import math
import random

import torch
import torch.nn as nn


class Pooler(nn.Module):
    """Average-pool a square grid of patch tokens and project them with an MLP.

    ``pool_out_size`` is a comma-separated string of candidate configurations.
    Each candidate is either a single integer (``"16"``) or several integers
    joined by ``+`` (``"16+32"``), meaning the outputs pooled to each of those
    grid sizes are concatenated along the token axis. One candidate is drawn
    with ``random.choice`` on every forward pass.
    """

    def __init__(self, dim_in, dim_out, pool_out_size):
        """Build the pooler.

        Args:
            dim_in: feature size of the incoming tokens.
            dim_out: feature size produced by the MLP.
            pool_out_size: comma-separated candidates, e.g. ``"16,16+32"``.

        Raises:
            TypeError: if ``pool_out_size`` is not a string.
        """
        super().__init__()
        # Explicit raise instead of `assert`, which is stripped under `-O`.
        if not isinstance(pool_out_size, str):
            raise TypeError(
                "pool_out_size must be a str, got {}".format(
                    type(pool_out_size).__name__
                )
            )
        # Kept as raw strings; each one is parsed when it is drawn in forward().
        self.pool_out_size = pool_out_size.split(",")
        print("pool_out_size: {}".format(self.pool_out_size))
        self.mlp = nn.Sequential(
            nn.Linear(dim_in, dim_out),
            nn.GELU(),
            nn.Linear(dim_out, dim_out),
        )

    @staticmethod
    def _parse_pool_sizes(spec: str) -> "list[int]":
        """Turn one candidate such as ``"16+32"`` into ``[32, 16]`` (descending)."""
        return sorted((int(part) for part in spec.split("+")), reverse=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Pool the patch tokens of ``x`` and project them.

        Args:
            x: tensor of shape ``(b, v, d)``. The first token along ``v`` is
                dropped (treated as the cls token); the remaining ``v - 1``
                tokens must form an ``s x s`` grid.

        Returns:
            Tensor of shape ``(b, n, dim_out)`` where ``n`` is the sum of
            ``p * p`` over the pool sizes ``p`` of the drawn candidate,
            ordered from the largest pool size to the smallest.

        Raises:
            ValueError: if ``v - 1`` is not a positive perfect square, or a
                pool size is not a positive divisor of the grid side ``s``.
        """
        b, v, d = x.shape
        num_patches = v - 1
        if num_patches < 1:
            raise ValueError(
                "expected at least one patch token after the cls token, "
                "got sequence length {}".format(v)
            )
        # Integer square root: exact, unlike int(math.sqrt(...)) on floats.
        s = math.isqrt(num_patches)
        if s * s != num_patches:
            raise ValueError(
                "number of patch tokens ({}) is not a perfect square".format(
                    num_patches
                )
            )

        grid = x[:, 1:, :].reshape(b, s, s, d)  # remove cls_token

        # One candidate per call; uses the global `random` state.
        pool_sizes = self._parse_pool_sizes(random.choice(self.pool_out_size))

        pooled_tokens = []
        for pool_size in pool_sizes:
            if pool_size <= 0 or s % pool_size != 0:
                raise ValueError(
                    "pool size {} must be a positive divisor of the grid "
                    "side {}".format(pool_size, s)
                )
            cell = s // pool_size
            # Split each spatial axis into (pool_size, cell) and average over
            # the two within-cell axes -> [b, pool_size, pool_size, d].
            pooled = grid.reshape(b, pool_size, cell, pool_size, cell, d).mean(
                dim=(2, 4)
            )
            pooled = pooled.reshape(b, pool_size * pool_size, d)
            pooled_tokens.append(self.mlp(pooled))  # [b, h*w, dim_out]

        return torch.cat(pooled_tokens, dim=-2)