import torch
import torch.nn as nn

from diffusers.models.activations import FP32SiLU  # used by the "silu_fp32" activation branch below


def zero_module(module):
    """
    Zero out the parameters of a module and return it.
    """
    for p in module.parameters():
        p.detach().zero_()
    return module


def get_fourier_embeds_from_boundingbox(embed_dim, box):
    """
    Args:
        embed_dim: int
        box: a 3-D tensor [B x N x 4] representing the bounding boxes for the GLIGEN pipeline
    Returns:
        [B x N x embed_dim * 2 * 4] tensor of positional embeddings
    """
    batch_size, num_boxes = box.shape[:2]

    emb = 100 ** (torch.arange(embed_dim) / embed_dim)
    emb = emb[None, None, None].to(device=box.device, dtype=box.dtype)
    emb = emb * box.unsqueeze(-1)

    emb = torch.stack((emb.sin(), emb.cos()), dim=-1)
    emb = emb.permute(0, 1, 3, 4, 2).reshape(batch_size, num_boxes, embed_dim * 2 * 4)

    return emb
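

# Illustrative shape check (not from the original file): with fourier_freqs=8,
# each box (4 xyxy coordinates) yields 8 frequencies x sin/cos x 4 coords = 64 dims.
#
#     boxes = torch.rand(2, 10, 4)                         # [B, N, 4], assumed normalized xyxy
#     emb = get_fourier_embeds_from_boundingbox(8, boxes)
#     emb.shape                                            # torch.Size([2, 10, 64])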


class PixArtAlphaTextProjection(nn.Module):
    """
    Projects caption embeddings. Also handles dropout for classifier-free guidance.

    Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py
    """

    def __init__(self, in_features, hidden_size, out_features=None, act_fn="gelu_tanh"):
        super().__init__()
        if out_features is None:
            out_features = hidden_size
        self.linear_1 = nn.Linear(in_features=in_features, out_features=hidden_size, bias=True)
        if act_fn == "gelu_tanh":
            self.act_1 = nn.GELU(approximate="tanh")
        elif act_fn == "silu":
            self.act_1 = nn.SiLU()
        elif act_fn == "silu_fp32":
            self.act_1 = FP32SiLU()
        else:
            raise ValueError(f"Unknown activation function: {act_fn}")
        self.linear_2 = nn.Linear(in_features=hidden_size, out_features=out_features, bias=True)

    def forward(self, caption):
        hidden_states = self.linear_1(caption)
        hidden_states = self.act_1(hidden_states)
        hidden_states = self.linear_2(hidden_states)
        return hidden_states
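

# Illustrative usage (assumed sizes, matching how ObjectLayoutEncoder below wires
# this projection up when positive_len = 3072 and out_dim = 3072):
#
#     proj = PixArtAlphaTextProjection(in_features=3072 + 64, hidden_size=1536,
#                                      out_features=3072, act_fn="silu")
#     proj(torch.randn(20, 30, 3136)).shape                # torch.Size([20, 30, 3072])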


class ObjectLayoutEncoder(nn.Module):
    def __init__(self, positive_len, out_dim, fourier_freqs=8, max_boxes_token_length=30):
        super().__init__()
        self.positive_len = positive_len
        self.out_dim = out_dim

        self.fourier_embedder_dim = fourier_freqs
        self.position_dim = fourier_freqs * 2 * 4  # 2: sin/cos, 4: xyxy -> 64 for fourier_freqs=8

        if isinstance(out_dim, tuple):
            out_dim = out_dim[0]

        self.null_positive_feature = torch.nn.Parameter(torch.zeros([max_boxes_token_length, self.positive_len]))
        self.null_position_feature = torch.nn.Parameter(torch.zeros([self.position_dim]))

        self.linears = PixArtAlphaTextProjection(
            in_features=self.positive_len + self.position_dim,
            hidden_size=out_dim // 2,
            out_features=out_dim,
            act_fn="silu",
        )

    def forward(
        self,
        boxes,  # [B, 10, 4]
        masks,  # [B, 10]
        positive_embeddings,  # [B, 10, 30, 3072]
    ):
        B, N, S, C = positive_embeddings.shape  # B: batch_size, N: 10, S: 30, C: 3072
        positive_embeddings = positive_embeddings.reshape(B * N, S, C)  # [B*10, 30, 3072]
        masks = masks.reshape(B * N, 1, 1)  # [B*10, 1, 1]

        # Process positional encoding
        xyxy_embedding = get_fourier_embeds_from_boundingbox(self.fourier_embedder_dim, boxes)  # [B, 10, 64]
        xyxy_embedding = xyxy_embedding.reshape(B * N, -1)  # [B*10, 64]
        xyxy_null = self.null_position_feature.view(1, -1)  # [1, 64]

        # Expand positional encoding to match the sequence dimension
        xyxy_embedding = xyxy_embedding.unsqueeze(1).expand(-1, S, -1)  # [B*10, 30, 64]
        xyxy_null = xyxy_null.unsqueeze(0).expand(B * N, S, -1)  # [B*10, 30, 64]

        # Apply mask: masked-out boxes fall back to the learned null embedding
        xyxy_embedding = xyxy_embedding * masks + (1 - masks) * xyxy_null  # [B*10, 30, 64]

        # Process feature encoding
        positive_null = self.null_positive_feature.view(1, S, -1).expand(B * N, -1, -1)  # [B*10, 30, 3072]
        positive_embeddings = positive_embeddings * masks + (1 - masks) * positive_null  # [B*10, 30, 3072]

        # Concatenate positional encoding and feature encoding
        combined = torch.cat([positive_embeddings, xyxy_embedding], dim=-1)  # [B*10, 30, 3072+64]

        # Process each box's features independently
        objs = self.linears(combined)  # [B*10, 30, 3072]

        # Restore the original shape
        objs = objs.reshape(B, N, S, -1)  # [B, 10, 30, 3072]
        return objs


class ObjectLayoutEncoder_noFourier(nn.Module):
    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.linears = PixArtAlphaTextProjection(
            in_features=self.in_dim,
            hidden_size=out_dim // 2,
            out_features=out_dim,
            act_fn="silu",
        )

    def forward(
        self,
        positive_embeddings,  # [B, 10, 30, 3072]
    ):
        B, N, S, C = positive_embeddings.shape  # B: batch_size, N: 10, S: 30, C: 3072
        positive_embeddings = positive_embeddings.reshape(B * N, S, C)  # [B*10, 30, 3072]

        # Process each box's features independently
        objs = self.linears(positive_embeddings)  # [B*10, 30, 3072]

        # Restore the original shape
        objs = objs.reshape(B, N, S, -1)  # [B, 10, 30, 3072]
        return objs
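

if __name__ == "__main__":
    # Minimal smoke test (not part of the original module). The batch size and
    # out_dim below are arbitrary assumptions; N=10 boxes, S=30 tokens per box,
    # and C=3072 features per token follow the shape comments above.
    B, N, S, C = 2, 10, 30, 3072

    boxes = torch.rand(B, N, 4)                # xyxy coordinates, assumed normalized to [0, 1]
    masks = (torch.rand(B, N) > 0.5).float()   # 1 = box present, 0 = padded slot
    feats = torch.randn(B, N, S, C)            # per-box token embeddings

    encoder = ObjectLayoutEncoder(positive_len=C, out_dim=C, fourier_freqs=8, max_boxes_token_length=S)
    print(encoder(boxes, masks, feats).shape)  # torch.Size([2, 10, 30, 3072])

    encoder_nf = ObjectLayoutEncoder_noFourier(in_dim=C, out_dim=C)
    print(encoder_nf(feats).shape)             # torch.Size([2, 10, 30, 3072])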