import torch
import torch.nn as nn
import torch.nn.functional as F

def zero_module(module):
    """
    Zero out the parameters of a module and return it.
    """
    for p in module.parameters():
        p.detach().zero_()
    return module

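# Usage sketch (hedged): zero-initializing the last layer of a newly added branch so it
# starts as a no-op contribution, the GLIGEN/ControlNet-style trick. `nn.Linear(64, 64)`
# below is purely illustrative, not part of this module:
#   proj = zero_module(nn.Linear(64, 64))
#   assert all((p == 0).all() for p in proj.parameters())
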

def get_fourier_embeds_from_boundingbox(embed_dim, box):
    """
    Args:
        embed_dim: int, number of Fourier frequencies per box coordinate
        box: a 3-D tensor [B x N x 4] representing the bounding boxes for the GLIGEN pipeline
    Returns:
        [B x N x embed_dim * 2 * 4] tensor of positional embeddings
    """
    batch_size, num_boxes = box.shape[:2]

    emb = 100 ** (torch.arange(embed_dim) / embed_dim)
    emb = emb[None, None, None].to(device=box.device, dtype=box.dtype)
    emb = emb * box.unsqueeze(-1)

    emb = torch.stack((emb.sin(), emb.cos()), dim=-1)
    emb = emb.permute(0, 1, 3, 4, 2).reshape(batch_size, num_boxes, embed_dim * 2 * 4)

    return emb

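# Shape sketch (assuming xyxy boxes normalized to [0, 1]): each of the 4 coordinates gets
# `embed_dim` frequencies with a sin and a cos, so the output width is embed_dim * 2 * 4:
#   boxes = torch.rand(2, 10, 4)                         # [B=2, N=10, 4]
#   emb = get_fourier_embeds_from_boundingbox(8, boxes)  # [2, 10, 8 * 2 * 4] = [2, 10, 64]
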

class FP32SiLU(nn.Module):
    """
    SiLU computed in float32 and cast back to the input dtype. Defined here so the
    "silu_fp32" branch below is self-contained (mirrors diffusers' FP32SiLU).
    """

    def forward(self, x):
        return F.silu(x.float(), inplace=False).to(x.dtype)


class PixArtAlphaTextProjection(nn.Module):
    """
    Projects caption embeddings. Also handles dropout for classifier-free guidance.

    Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py
    """

    def __init__(self, in_features, hidden_size, out_features=None, act_fn="gelu_tanh"):
        super().__init__()
        if out_features is None:
            out_features = hidden_size
        self.linear_1 = nn.Linear(in_features=in_features, out_features=hidden_size, bias=True)
        if act_fn == "gelu_tanh":
            self.act_1 = nn.GELU(approximate="tanh")
        elif act_fn == "silu":
            self.act_1 = nn.SiLU()
        elif act_fn == "silu_fp32":
            self.act_1 = FP32SiLU()
        else:
            raise ValueError(f"Unknown activation function: {act_fn}")
        self.linear_2 = nn.Linear(in_features=hidden_size, out_features=out_features, bias=True)

    def forward(self, caption):
        hidden_states = self.linear_1(caption)
        hidden_states = self.act_1(hidden_states)
        hidden_states = self.linear_2(hidden_states)
        return hidden_states

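# Usage sketch (the feature sizes below are illustrative, not fixed by this class);
# the MLP acts on the last dimension only, so any leading shape is preserved:
#   proj = PixArtAlphaTextProjection(in_features=4096, hidden_size=3072)
#   out = proj(torch.randn(2, 30, 4096))  # [2, 30, 3072]
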

class ObjectLayoutEncoder(nn.Module):
    def __init__(self, positive_len, out_dim, fourier_freqs=8, max_boxes_token_length=30):
        super().__init__()
        self.positive_len = positive_len
        self.out_dim = out_dim

        self.fourier_embedder_dim = fourier_freqs
        self.position_dim = fourier_freqs * 2 * 4  # 2: sin/cos, 4: xyxy -> 64 for fourier_freqs=8

        if isinstance(out_dim, tuple):
            out_dim = out_dim[0]

        # Learned "null" features substituted for padded (masked-out) box slots.
        self.null_positive_feature = torch.nn.Parameter(torch.zeros([max_boxes_token_length, self.positive_len]))
        self.null_position_feature = torch.nn.Parameter(torch.zeros([self.position_dim]))

        self.linears = PixArtAlphaTextProjection(
            in_features=self.positive_len + self.position_dim,
            hidden_size=out_dim // 2,
            out_features=out_dim,
            act_fn="silu",
        )
    def forward(
        self,
        boxes,                # [B, 10, 4]
        masks,                # [B, 10]
        positive_embeddings,  # [B, 10, 30, 3072]
    ):
        B, N, S, C = positive_embeddings.shape  # B: batch_size, N: 10 boxes, S: 30 tokens, C: 3072
        positive_embeddings = positive_embeddings.reshape(B * N, S, C)  # [B*10, 30, 3072]
        masks = masks.reshape(B * N, 1, 1)  # [B*10, 1, 1]

        # Fourier positional encoding of the boxes
        xyxy_embedding = get_fourier_embeds_from_boundingbox(self.fourier_embedder_dim, boxes)  # [B, 10, 64]
        xyxy_embedding = xyxy_embedding.reshape(B * N, -1)  # [B*10, 64]
        xyxy_null = self.null_position_feature.view(1, -1)  # [1, 64]

        # Expand positional encoding to match the token (sequence) dimension
        xyxy_embedding = xyxy_embedding.unsqueeze(1).expand(-1, S, -1)  # [B*10, 30, 64]
        xyxy_null = xyxy_null.unsqueeze(0).expand(B * N, S, -1)  # [B*10, 30, 64]

        # Replace padded slots with the learned null position feature
        xyxy_embedding = xyxy_embedding * masks + (1 - masks) * xyxy_null  # [B*10, 30, 64]

        # Replace padded slots with the learned null text feature
        positive_null = self.null_positive_feature.view(1, S, -1).expand(B * N, -1, -1)  # [B*10, 30, 3072]
        positive_embeddings = positive_embeddings * masks + (1 - masks) * positive_null  # [B*10, 30, 3072]

        # Concatenate text features and positional encoding
        combined = torch.cat([positive_embeddings, xyxy_embedding], dim=-1)  # [B*10, 30, 3072+64]

        # Project each box's token features independently
        objs = self.linears(combined)  # [B*10, 30, 3072]

        # Restore the original batch/box layout
        objs = objs.reshape(B, N, S, -1)  # [B, 10, 30, 3072]
        return objs

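# Usage sketch (dims follow the shape comments above; C = positive_len = out_dim here):
#   enc = ObjectLayoutEncoder(positive_len=3072, out_dim=3072)
#   objs = enc(
#       boxes=torch.rand(2, 10, 4),                 # normalized xyxy boxes
#       masks=torch.ones(2, 10),                    # 1 = real box, 0 = padded slot
#       positive_embeddings=torch.randn(2, 10, 30, 3072),
#   )                                               # [2, 10, 30, 3072]
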

class ObjectLayoutEncoder_noFourier(nn.Module):
    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.linears = PixArtAlphaTextProjection(
            in_features=self.in_dim,
            hidden_size=out_dim // 2,
            out_features=out_dim,
            act_fn="silu",
        )

    def forward(
        self,
        positive_embeddings,  # [B, 10, 30, 3072]
    ):
        B, N, S, C = positive_embeddings.shape  # B: batch_size, N: 10 boxes, S: 30 tokens, C: 3072
        positive_embeddings = positive_embeddings.reshape(B * N, S, C)  # [B*10, 30, 3072]

        # Project each box's token features independently (no box-position information)
        objs = self.linears(positive_embeddings)  # [B*10, 30, 3072]

        # Restore the original batch/box layout
        objs = objs.reshape(B, N, S, -1)  # [B, 10, 30, 3072]
        return objs