Any-to-Any
Safetensors
Transformers
LongCat-Next
longcat_next
text-generation
multimodal
custom_code
Instructions to use meituan-longcat/LongCat-Next with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use meituan-longcat/LongCat-Next with Transformers:
```python
# Load model directly
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("meituan-longcat/LongCat-Next", trust_remote_code=True, dtype="auto")
```
- Notebooks
- Google Colab
- Kaggle
| from typing import Iterable, Optional, Tuple | |
| import numpy as np | |
| from safetensors.torch import load_file | |
| import torch | |
| import torch.utils.checkpoint | |
| from torch import nn | |
| from torch.amp import autocast | |
| from torch.nn import functional as F | |
| from einops import rearrange | |
| from flash_attn import flash_attn_varlen_func | |
| from transformers.activations import ACT2FN | |
| from transformers.modeling_outputs import BaseModelOutput | |
| from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import ( | |
| Qwen2RMSNorm, | |
| Qwen2_5_VisionTransformerPretrainedModel, | |
| ) | |
| from transformers.utils import logging | |
| from .image_refiner import ( | |
| ImageRefinerContainer, | |
| RefinerImageProcessor, | |
| RefinerPipeline, | |
| de_transform, | |
| tensor2pil, | |
| ) | |
| from .refiner_modules import FlowMatchEulerDiscreteScheduler | |
# Module-level logger from transformers' logging utilities (``transformers.utils.logging``).
logger = logging.get_logger(__name__)
def uniform_init(*shape):
    """Return a fresh tensor of ``shape`` initialised with ``nn.init.kaiming_uniform_``.

    The tensor is allocated with ``torch.zeros`` and then overwritten in place.
    ``kaiming_uniform_`` needs at least two dimensions to compute the fan-in.
    """
    return nn.init.kaiming_uniform_(torch.zeros(shape))
class VQEmbedding(nn.Module):
    """Frozen VQ codebook with nearest-neighbour lookup (EMA-style parameter layout).

    ``embed`` has ``n_embed + 1`` rows. The trailing row is excluded from the
    nearest-neighbour search (presumably a padding entry). ``embed_ema`` and
    ``cluster_size_ema`` are registered but never read by this class, presumably to
    keep checkpoint keys compatible. Every parameter has ``requires_grad=False``.
    """

    def __init__(self, n_embed, embed_dim, ema=True, decay=0.99, restart_unused_codes=True, eps=1e-5, init_std=0.02):
        super().__init__()
        # Plain hyper-parameters; only stored, not used by the lookup below.
        self.n_embed = n_embed
        self.ema = ema
        self.decay = decay
        self.eps = eps
        self.restart_unused_codes = restart_unused_codes
        self.init_std = init_std
        assert self.ema

        weights = uniform_init(n_embed + 1, embed_dim).to(torch.float32)
        # Registration order (embed, embed_ema, cluster_size_ema) is kept for state_dict layout.
        self.embed = nn.Parameter(weights)
        self.embed_ema = nn.Parameter(weights[:-1, :].clone())
        self.cluster_size_ema = nn.Parameter(torch.ones(n_embed))

        for param in self.parameters():
            param.requires_grad_(False)

    def compute_distances(self, inputs):
        """Squared L2 distance from every input vector to every non-padding code.

        ``inputs`` has shape ``(..., embed_dim)``; the result has shape ``(..., n_embed)``.
        """
        codebook_t = self.embed[:-1, :].t()
        (embed_dim, _) = codebook_t.shape
        lead_shape = inputs.shape[:-1]
        assert inputs.shape[-1] == embed_dim

        flat = inputs.reshape(-1, embed_dim)
        sq_inputs = flat.pow(2.).sum(dim=1, keepdim=True)
        sq_codes = codebook_t.pow(2.).sum(dim=0, keepdim=True)
        # ||x - c||^2 = ||x||^2 + ||c||^2 - 2 x.c, fused into one addmm.
        dist = torch.addmm(sq_inputs + sq_codes, flat, codebook_t, alpha=-2.0)
        return dist.reshape(*lead_shape, -1)

    def find_nearest_embedding(self, inputs):
        """Index of the closest non-padding code for each input vector."""
        return self.compute_distances(inputs).argmin(dim=-1)

    def forward(self, inputs):
        """Quantize ``inputs`` (cast to float32) and return ``(code_vectors, code_indices)``."""
        inputs = inputs.float()
        nearest = self.find_nearest_embedding(inputs)
        return self.embed[nearest], nearest
class RQBottleneck(nn.Module):
    """Quantization bottleneck via Residual Quantization (RQ).

    Each code position is quantized by ``code_shape[-1]`` codebooks in sequence: codebook
    ``i`` quantizes the residual left by codebooks ``0..i-1``, and the quantized feature
    is the running sum of the selected code vectors.

    Arguments:
        latent_shape (Tuple[int, int, int]): shape of the latents, denoted (H, W, D).
        code_shape (Tuple[int, int, int]): shape of the codes, denoted (h, w, d); ``d`` is
            the quantization depth (number of codebooks).
        n_embed (int, List, or Tuple): number of embeddings (codebook size). If an int,
            every codebook has the same size.
        decay (float, List, or Tuple): decay value(s) forwarded to each ``VQEmbedding``.
        shared_codebook (bool): if True, one codebook is shared by all depths; if False,
            separate codebooks are used along the depth dimension. (default: False)
        restart_unused_codes (bool): forwarded to each ``VQEmbedding`` (intended to re-seed
            unused codes during training). (default: True)
        commitment_loss (str): stored as ``self.commitment_loss``; not read by this class.
            (default: 'cumsum')

    Raises:
        ValueError: if the shapes are not rank 3, if ``latent_shape[:2]`` is not divisible
            by ``code_shape[:2]``, if a shared codebook is combined with per-depth lists, or
            if per-depth lists do not have ``code_shape[-1]`` entries.
    """

    def __init__(self,
                 latent_shape,
                 code_shape,
                 n_embed,
                 decay=0.99,
                 shared_codebook=False,
                 restart_unused_codes=True,
                 commitment_loss='cumsum'
                 ):
        super().__init__()

        if not len(code_shape) == len(latent_shape) == 3:
            raise ValueError("incompatible code shape or latent shape")
        if any(y % x != 0 for x, y in zip(code_shape[:2], latent_shape[:2])):
            raise ValueError("incompatible code shape or latent shape")

        # Residual quantization does not split the feature dim: one code vector covers a
        # whole (rH x rW) patch of latents with all D channels. int() avoids a numpy scalar.
        embed_dim = int(np.prod(latent_shape[:2]) // np.prod(code_shape[:2]) * latent_shape[2])

        self.latent_shape = torch.Size(latent_shape)
        self.code_shape = torch.Size(code_shape)
        self.shape_divisor = torch.Size([latent_shape[i] // code_shape[i] for i in range(len(latent_shape))])

        self.shared_codebook = shared_codebook
        if self.shared_codebook:
            if isinstance(n_embed, Iterable) or isinstance(decay, Iterable):
                # Adjacent literals: no continuation whitespace leaks into the message.
                raise ValueError(
                    "Shared codebooks are incompatible with list types of momentums or sizes: "
                    "Change it into int"
                )

        self.restart_unused_codes = restart_unused_codes
        depth = self.code_shape[-1]
        self.n_embed = n_embed if isinstance(n_embed, Iterable) else [n_embed for _ in range(depth)]
        self.decay = decay if isinstance(decay, Iterable) else [decay for _ in range(depth)]
        if len(self.n_embed) != depth:
            raise ValueError(f"n_embed must have {depth} entries, got {len(self.n_embed)}")
        if len(self.decay) != depth:
            raise ValueError(f"decay must have {depth} entries, got {len(self.decay)}")

        if self.shared_codebook:
            codebook0 = VQEmbedding(self.n_embed[0],
                                    embed_dim,
                                    decay=self.decay[0],
                                    restart_unused_codes=restart_unused_codes,
                                    ).to(torch.float32)
            # The same module object is repeated, so every depth uses one codebook.
            self.codebooks = nn.ModuleList([codebook0 for _ in range(depth)])
        else:
            codebooks = [VQEmbedding(self.n_embed[idx],
                                     embed_dim,
                                     decay=self.decay[idx],
                                     restart_unused_codes=restart_unused_codes,
                                     ).to(torch.float32) for idx in range(depth)]
            self.codebooks = nn.ModuleList(codebooks)

        self.commitment_loss = commitment_loss

    def to_code_shape(self, x):
        """Fold (B, H, W, D) latents into (B, H/rH, W/rW, rH*rW*D) code-position features."""
        (B, H, W, D) = x.shape
        (rH, rW, _) = self.shape_divisor

        x = x.reshape(B, H // rH, rH, W // rW, rW, D)
        x = x.permute(0, 1, 3, 2, 4, 5)
        x = x.reshape(B, H // rH, W // rW, -1)
        return x

    def to_latent_shape(self, x):
        """Inverse of ``to_code_shape``: unfold code-position features back to (B, H, W, D)."""
        (B, h, w, _) = x.shape
        (_, _, D) = self.latent_shape
        (rH, rW, _) = self.shape_divisor

        x = x.reshape(B, h, w, rH, rW, D)
        x = x.permute(0, 1, 3, 2, 4, 5)
        x = x.reshape(B, h * rH, w * rW, D)
        return x

    def quantize(self, x):
        r"""
        Return the list of quantized features and the codewords selected by residual quantization.

        Codebook ``i`` is applied to the residual between ``x`` and the sum of the
        quantized features from codebooks ``0..i-1``. The lookup runs in float32, and
        ``self.codebooks`` is cast to float32 in place on every call.

        Arguments:
            x (Tensor): bottleneck feature maps to quantize.

        Returns:
            quant_list (list): sequentially aggregated quantized feature maps, one per
                codebook, cast back to ``x``'s original dtype.
            codes (LongTensor): codeword indices corresponding to ``quant_list``.

        Shape:
            - x: (B, h, w, embed_dim)
            - quant_list[i]: (B, h, w, embed_dim)
            - codes: (B, h, w, d)
        """
        ori_dtype = x.dtype
        x = x.to(torch.float32)
        self.codebooks = self.codebooks.to(torch.float32)

        residual_feature = x.detach().clone()
        quant_list = []
        code_list = []
        aggregated_quants = torch.zeros_like(x)
        for i in range(self.code_shape[-1]):
            quant, code = self.codebooks[i](residual_feature)

            residual_feature.sub_(quant)
            aggregated_quants.add_(quant)

            quant_list.append(aggregated_quants.clone().to(dtype=ori_dtype))
            code_list.append(code.unsqueeze(-1))

        codes = torch.cat(code_list, dim=-1)
        return quant_list, codes

    def forward(self, x):
        """Quantize latents ``x`` of shape (B, H, W, D).

        Returns:
            quants_trunc: quantized latents of shape (B, H, W, D), using a straight-through
                estimator so gradients flow to ``x``.
            codebook_loss: ``[vq_loss, commitment_loss, entropy_loss, codebook_usage]``;
                only the commitment loss is computed, the other entries are 0.
            codes: code indices of shape (B, h, w, d).
        """
        x_reshaped = self.to_code_shape(x)
        # quantize() performs the codebook lookup in float32.
        quant_list, codes = self.quantize(x_reshaped)

        commitment_loss = self.compute_commitment_loss(x_reshaped, quant_list)
        quants_trunc = self.to_latent_shape(quant_list[-1])
        # Straight-through estimator: forward value is quantized, gradient is identity.
        quants_trunc = x + (quants_trunc - x).detach()

        # Codebook-usage tracking is not implemented; report 0 to keep the layout.
        codebook_usage = 0
        # Layout aligned with other quantizers: (vq_loss, commit_loss, entropy_loss, codebook_usage)
        codebook_loss = [0, commitment_loss, 0, codebook_usage]
        return quants_trunc, codebook_loss, codes

    def compute_commitment_loss(self, x, quant_list):
        r"""
        Compute the commitment loss for the residual quantization.

        The loss is the mean, over codebook depths, of the MSE between ``x`` and each
        (detached) aggregated quantized feature.
        """
        loss_list = [(x - quant.detach()).pow(2.0).mean() for quant in quant_list]
        return torch.mean(torch.stack(loss_list))
class Qwen2_5_VisionRotaryEmbedding_Modified(nn.Module):
    """Rotary frequency table whose ``inv_freq`` is a plain tensor attribute, not a buffer.

    Because ``inv_freq`` is not registered, it is absent from ``state_dict`` and is not
    converted by ``Module.to``. ``forward`` instead moves it to the requested device and
    caches the moved tensor on the instance.
    """

    def __init__(self, dim: int, theta: float = 10000.0) -> None:
        super().__init__()
        exponents = torch.arange(0, dim, 2, dtype=torch.float) / dim
        self.inv_freq = 1.0 / (theta ** exponents)

    def forward(self, seqlen: int, device: torch.device) -> torch.Tensor:
        """Return the (seqlen, dim // 2) table of ``position * inv_freq``."""
        inv_freq = self.inv_freq.to(device)
        self.inv_freq = inv_freq
        positions = torch.arange(seqlen, device=inv_freq.device, dtype=inv_freq.dtype)
        return torch.outer(positions, inv_freq)
class VisualEncoder(Qwen2_5_VisionTransformerPretrainedModel):
    """Qwen2.5-VL vision transformer used as the image encoder, without its patch merger.

    Changes relative to the parent that are visible here:
    - ``config._attn_implementation`` is forced to ``'flash_attention_2'``;
    - ``rotary_pos_emb`` is replaced by ``Qwen2_5_VisionRotaryEmbedding_Modified``;
    - ``self.merger`` is deleted (the merger is registered in ``visual_bridge_model``);
    - ``forward`` can also return the window permutation (``require_window_index``).
    """

    def __init__(self, config):
        # NOTE: mutates the caller's config object, presumably so the parent builds its
        # flash-attention blocks.
        config._attn_implementation = 'flash_attention_2'
        super().__init__(config)
        self.rotary_pos_emb = Qwen2_5_VisionRotaryEmbedding_Modified(config.hidden_size // config.num_heads // 2)
        self.gradient_checkpointing = False
        self._gradient_checkpointing_func = torch.utils.checkpoint.checkpoint
        # Stored for reference; the code below reads the parent's spatial_merge_size instead.
        self.merge_size = config.merge_size if hasattr(config, 'merge_size') else 2
        del self.merger  # register visual.merger in visual_bridge_model

    def get_dtype(self) -> torch.dtype:
        """Dtype of the first block's MLP down-projection weight."""
        return self.blocks[0].mlp.down_proj.weight.dtype

    def get_device(self) -> torch.device:
        """Device of the first block's MLP down-projection weight."""
        return self.blocks[0].mlp.down_proj.weight.device

    def rot_pos_emb(self, grid_thw):
        """Build per-patch 2D rotary angles for every sample in ``grid_thw``.

        For each (t, h, w), the row and column ids are reordered so that each
        spatial_merge_size x spatial_merge_size group of patches is contiguous. The ids are
        then repeated ``t`` times and used to index a frequency table sized to the largest
        h/w in the batch.
        """
        pos_ids = []
        for t, h, w in grid_thw:
            # Row index of every patch, grouped into merge-size blocks.
            hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
            hpos_ids = hpos_ids.reshape(
                h // self.spatial_merge_size,
                self.spatial_merge_size,
                w // self.spatial_merge_size,
                self.spatial_merge_size,
            )
            hpos_ids = hpos_ids.permute(0, 2, 1, 3)
            hpos_ids = hpos_ids.flatten()
            # Column index of every patch, with the same grouping.
            wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
            wpos_ids = wpos_ids.reshape(
                h // self.spatial_merge_size,
                self.spatial_merge_size,
                w // self.spatial_merge_size,
                self.spatial_merge_size,
            )
            wpos_ids = wpos_ids.permute(0, 2, 1, 3)
            wpos_ids = wpos_ids.flatten()
            pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
        pos_ids = torch.cat(pos_ids, dim=0)
        # 0-d tensor, passed as ``seqlen`` to the rotary module (torch.arange accepts it).
        max_grid_size = grid_thw[:, 1:].max()
        rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size, device=grid_thw.device)
        # (num_patches, 2, dim/2) -> (num_patches, dim): row angles then column angles.
        rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
        return rotary_pos_emb

    def forward(
        self,
        pixel_values: torch.Tensor,
        grid_thw: torch.Tensor,
        require_window_index: bool = False,
    ):
        """Encode flattened image patches.

        pixel_values.shape=[NumOfPatches, 1176]
        grid_thw.shape=[NumOfSamples, 3]. [grid_t,grid_h,grid_w]

        Returns the hidden states in window-permuted token order. If
        ``require_window_index`` is True, also returns ``window_index`` so callers can undo
        the permutation (e.g. via ``torch.argsort``).
        """
        # NOTE(review): input is cast to bfloat16 regardless of get_dtype(); confirm the
        # encoder always runs in bf16.
        hidden_states = pixel_values.to(torch.bfloat16)
        grid_thw = grid_thw.to(pixel_values.device)
        hidden_states = self.patch_embed(hidden_states)
        rotary_pos_emb = self.rot_pos_emb(grid_thw)
        # Parent helper: presumably a permutation grouping merge-units by attention window,
        # plus cumulative window boundaries.
        window_index, cu_window_seqlens = self.get_window_index(grid_thw)
        cu_window_seqlens = torch.tensor(
            cu_window_seqlens,
            device=hidden_states.device,
            dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
        )
        # Drop repeated boundaries (zero-length windows).
        cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens)

        # Reorder tokens (in groups of spatial_merge_unit) into window order.
        seq_len, _ = hidden_states.size()
        hidden_states = hidden_states.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
        hidden_states = hidden_states[window_index, :, :]
        hidden_states = hidden_states.reshape(seq_len, -1)
        # Apply the same reordering to the rotary angles.
        rotary_pos_emb = rotary_pos_emb.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
        rotary_pos_emb = rotary_pos_emb[window_index, :, :]
        rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1)
        emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
        position_embeddings = (emb.cos(), emb.sin())

        # Per-frame boundaries: h*w tokens per frame, repeated t times, with a leading 0.
        cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(
            dim=0,
            # Select dtype based on the following factors:
            #  - FA2 requires that cu_seqlens_q must have dtype int32
            #  - torch.onnx.export requires that cu_seqlens_q must have same dtype as grid_thw
            # See https://github.com/huggingface/transformers/pull/34852 for more information
            dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
        )
        cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)

        for layer_num, blk in enumerate(self.blocks):
            # Full-attention blocks attend within each frame; the others within each window.
            if layer_num in self.fullatt_block_indexes:
                cu_seqlens_now = cu_seqlens
            else:
                cu_seqlens_now = cu_window_seqlens
            if self.gradient_checkpointing and self.training:
                # NOTE(review): passes None positionally as the block's third argument;
                # confirm it matches the installed transformers' vision-block signature.
                hidden_states = self._gradient_checkpointing_func(blk.__call__, hidden_states, cu_seqlens_now, None, position_embeddings)
            else:
                hidden_states = blk(
                    hidden_states,
                    cu_seqlens=cu_seqlens_now,
                    position_embeddings=position_embeddings,
                )

        if require_window_index:
            return hidden_states, window_index
        return hidden_states
class OmniVisualBridge(nn.Module):
    """Merge groups of encoder tokens, project them, and restore the original token order.

    ``forward`` normalises each token, concatenates every ``merge_size ** 2`` consecutive
    tokens, runs a two-layer GELU MLP to ``config.out_hidden_size``, and reorders rows with
    ``argsort(window_index)``. That argsort is the inverse permutation when
    ``window_index`` is a permutation.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.merge_size = getattr(self.config, 'merge_size', 2)
        self.hidden_size = self.config.hidden_size * self.merge_size ** 2
        # NOTE: despite its name, this attribute holds config.window_size.
        self.window_index = self.config.window_size

        self.ln_q = Qwen2RMSNorm(self.config.hidden_size, eps=1e-6)
        self.mlp = nn.Sequential(
            nn.Linear(self.hidden_size, self.hidden_size),
            nn.GELU(),
            nn.Linear(self.hidden_size, self.config.out_hidden_size),
        )

    def forward(self, x: torch.Tensor, window_index) -> torch.Tensor:
        grouped = self.ln_q(x).view(-1, self.hidden_size)
        projected = self.mlp(grouped)
        restore_order = torch.argsort(window_index)
        return projected[restore_order, :]
class VisualQuantizer(nn.Module):
    """Residual-quantization head that maps a sequence of visual features to code indices.

    Only ``quantizer_type == "rq"`` is supported; ``encode`` raises ``NotImplementedError``
    otherwise.
    """

    def __init__(self, quantizer_config):
        super().__init__()
        self.config = quantizer_config
        self.depth = self.config.depth
        self.decay = self.config.decay
        self.codebook_size = self.config.codebook_size
        self.codebook_dim = self.config.codebook_dim
        self.shared_codebook = self.config.shared_codebook
        self.restart_unused_codes = self.config.restart_unused_codes
        self.in_channels = self.config.in_channels
        self.vq_loss_ratio = self.config.vq_loss_ratio
        self.entropy_loss_ratio = self.config.entropy_loss_ratio
        self.commit_loss_ratio = self.config.commit_loss_ratio

        # 448 / 14 = 32 (presumably image size / patch size). Because latent and code grids
        # share (h, w), the bottleneck's spatial divisor is (1, 1), so encode() can feed it
        # any (1, L) layout.
        code_h_w = int(448 / 14)
        latent_shape = [code_h_w, code_h_w, self.codebook_dim]
        code_shape = [code_h_w, code_h_w, self.depth]

        self.quantize = RQBottleneck(
            latent_shape=latent_shape,
            code_shape=code_shape,
            n_embed=self.codebook_size,
            decay=self.decay,
            shared_codebook=self.shared_codebook,
            restart_unused_codes=self.restart_unused_codes,
        )

        if self.config.quant_conv:
            self.quant_conv = nn.Sequential(
                nn.LayerNorm(self.in_channels),
                nn.Linear(self.in_channels, self.in_channels),
                nn.GELU(),
                nn.Linear(self.in_channels, self.codebook_dim)
            )
        else:
            self.quant_conv = None

    def encode(self, x):
        """Quantize a 2-D feature sequence ``x`` of shape (L, D).

        Returns:
            quant: quantized features, shape (1, codebook_dim, 1, L).
            emb_loss: the bottleneck's loss list.
            info: ``[None, None, codes]``, where ``codes`` has shape (L, depth).
            align_feature: ``x.detach()``.

        Raises:
            NotImplementedError: if ``config.quantizer_type`` is not ``"rq"`` (the only
                bottleneck this class builds expects channels-last input).
        """
        if self.config.quantizer_type != "rq":
            raise NotImplementedError(
                f"unsupported quantizer_type {self.config.quantizer_type!r}; only 'rq' is implemented"
            )

        seq_len, _ = x.shape
        feats = x.clone().unsqueeze(0)  # [L, D] -> [1, L, D]
        if self.quant_conv is not None:
            feats = self.quant_conv(feats)
        # N, L, d -> N, 1, L, d: channels-last layout expected by RQBottleneck.
        feats = feats.reshape(1, 1, seq_len, self.codebook_dim).contiguous()

        quant, emb_loss, codes = self.quantize(feats)
        codes = codes.reshape(-1, codes.shape[-1])  # n,h,w,lv -> n*h*w,lv
        info = [None, None, codes]
        quant = quant.permute(0, 3, 1, 2).contiguous()  # N,1,L,d -> N,d,1,L
        return quant, emb_loss, info, x.detach()

    def forward(self, x):
        """Return the per-token code indices of ``x``, shape (L, depth)."""
        (
            _quant,
            (_vq_loss, _commit_loss, _entropy_loss, _codebook_usage),
            (_perplexity, _min_encodings, min_encoding_indices),
            _align_feature,
        ) = self.encode(x)
        return min_encoding_indices
class MLP(nn.Module):
    """Gated MLP: ``down_proj(act(gate_proj(x)) * up_proj(x))``, all projections bias-free."""

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
    ):
        super().__init__()
        # Creation order kept (gate, down, up) so parameter init and state_dict order match.
        self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
        self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.act_fn = ACT2FN[hidden_act]

    def forward(self, x):
        gate = self.act_fn(self.gate_proj(x))
        value = self.up_proj(x)
        return self.down_proj(gate * value)
class DecoderLayer(nn.Module):
    """Pre-norm residual MLP layer: ``x + MLP(LayerNorm(x))``.

    Uses ``nn.LayerNorm``, with its epsilon taken from ``config.rms_norm_eps``.
    """

    def __init__(self, config):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.mlp = MLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.visual_embedding_layer_intermediate_size,
            hidden_act=config.visual_embedding_layer_hidden_act,
        )
        self.pre_layernorm = nn.LayerNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
    ):
        update = self.mlp(self.pre_layernorm(hidden_states))
        return hidden_states + update
class VisualEmbeddingBridge(nn.Module):
    """Thin wrapper that applies one residual ``DecoderLayer`` to visual embeddings."""

    def __init__(self, config):
        super().__init__()
        self.pre_buffer = DecoderLayer(config)

    def forward(self, embeding):
        buffered = self.pre_buffer(embeding)
        return buffered
class VisualVQBridge(nn.Module):
    """Project encoder outputs with ``OmniVisualBridge``, then quantize them to code indices."""

    def __init__(self, visual_config):
        super().__init__()
        self.bridge = OmniVisualBridge(visual_config)
        self.quantizer = VisualQuantizer(visual_config.vq_config)

    def forward(
        self,
        visual_embed: torch.Tensor,
        window_index: torch.Tensor,
    ):
        return self.quantizer(self.bridge(visual_embed, window_index))
class LongcatNextVisualTokenizer(nn.Module):
    """Visual tokenizer: encodes images to RQ code indices and lazily decodes codes to images.

    ``encode`` runs the vision encoder followed by the VQ bridge. ``lazy_decode_and_save``
    builds the image decoder and refiner pipeline on first use, turns code indices back
    into images, refines them, and writes the first refined image to disk.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.visual_model = VisualEncoder(config.visual_config)
        self.visual_bridge_model = VisualVQBridge(config.visual_config)
        self.visual_embedding_layer = VisualEmbeddingBridge(config)
        # Built on demand by lazy_decode_and_save(); not created at construction time.
        self.image_decoder = None
        self._refiner_pipeline = None

    def encode(self, pixel_values: torch.Tensor, visual_grid_thw: torch.Tensor):
        """Encode patches to RQ code indices (one row of per-level codes per merged token)."""
        visual_embed, window_index = self.visual_model(pixel_values, grid_thw=visual_grid_thw, require_window_index=True)
        indices = self.visual_bridge_model(visual_embed, window_index)
        return indices

    def _load_decoders(self, device):
        """Build ``self.image_decoder`` and ``self._refiner_pipeline`` in bfloat16 on ``device``."""
        logger.info("lazy load image_decoder / image_refiner / _refiner_pipeline ...")
        vdc = self.config.visual_config.visual_decoder_config
        self.image_decoder = VisionTransformerDecoder.from_pretrained(
            vdc.image_decoder_config,
            vdc.weight_path,
        ).to(device=device, dtype=torch.bfloat16)
        image_refiner = ImageRefinerContainer.from_pretrained(vdc, vdc.weight_path).to(device=device, dtype=torch.bfloat16)
        sc = vdc.scheduler_config
        scheduler = FlowMatchEulerDiscreteScheduler(
            num_train_timesteps=sc.num_train_timesteps,
            dynamic_time_shift=sc.dynamic_time_shift)
        self._refiner_pipeline = RefinerPipeline(
            vae=image_refiner.vae,
            transformer=image_refiner.base_transformer,
            scheduler=scheduler,
            cond_proj=image_refiner.cond_proj,
        )
        self._refiner_pipeline.set_progress_bar_config(disable=False)

    def lazy_decode_and_save(self, visual_ids, tokens_h, tokens_w, save_path):
        """Decode RQ code indices to images, refine them, and save the first refined image.

        On first call, builds the image decoder and refiner pipeline.

        Args:
            visual_ids: code indices accepted by ``torch.as_tensor``. A 1-D input is reshaped
                to (num_tokens, num_levels); a 2-D input gets a leading batch dimension.
            tokens_h, tokens_w: merged token grid; multiplied by the decoder's
                ``spatial_merge_size`` to get the full grid passed to the decoder.
            save_path: output path; the refined image is written with ``_refined``
                inserted before the file extension.

        Returns:
            A one-element list containing the path of the saved refined image.
        """
        import os  # local import: only needed to derive the output path

        device = next(self.parameters()).device
        if self.image_decoder is None:
            self._load_decoders(device)

        num_levels = len(self.config.visual_config.vq_config.codebook_sizes)
        data = torch.as_tensor(visual_ids, dtype=torch.long)
        if data.ndim == 1:
            data = data.view(-1, num_levels)
        if data.ndim == 2:
            data = data.unsqueeze(0)
        batch_size = data.shape[0]

        # Residual quantization: a token's feature is the sum of its per-level code vectors.
        codebooks = self.visual_bridge_model.quantizer.quantize.codebooks
        quant_features = None
        for idx in range(num_levels):
            embed = codebooks[idx].embed
            feat = embed[data[..., idx].to(embed.device)]
            quant_features = feat if quant_features is None else quant_features + feat
        quant_features = quant_features.to(device)

        # tokens_h/tokens_w are the merged grid; expand to the full (unmerged) grid.
        s = self.image_decoder.spatial_merge_size
        grid_thw_list = [(1, tokens_h * s, tokens_w * s)]
        grid_thw_batch = list(grid_thw_list) * batch_size

        image_mean = [0.48145466, 0.4578275, 0.40821073]
        image_std = [0.26862954, 0.26130258, 0.27577711]

        emb_2d = quant_features.reshape(-1, quant_features.shape[-1]).contiguous()
        device_type = "cuda" if str(device).startswith("cuda") else str(device)
        with torch.amp.autocast(device_type=device_type, enabled=True, dtype=torch.float32):
            decoder_out = self.image_decoder(emb_2d, grid_thw_batch, return_pixel_features=False)
        decoded_tensors = decoder_out.get("images") or []

        # De-normalize the decoder output, then re-normalize it for the refiner.
        ref_input = []
        for t in decoded_tensors:
            img_01 = de_transform(t, mean=image_mean, std=image_std, rescale_factor=1 / 255)
            img_norm = RefinerImageProcessor.normalize(img_01)
            ref_input.append(img_norm.squeeze(0).to(device))

        # Fixed per-sample seeds keep refinement reproducible.
        generators = [torch.Generator(device=device).manual_seed(42 + b) for b in range(batch_size)]
        # NOTE(review): only one grid entry is passed and only refined_images[0] is saved,
        # even when batch_size > 1 -- confirm batched decoding is meant to be supported.
        out = self._refiner_pipeline(
            encoder_hidden_states=quant_features,
            grid_thw_list=grid_thw_list,
            image=ref_input,
            generator=generators[0] if batch_size == 1 else generators,
            output_type="pil",
            return_dict=True,
        )
        refined_images = out.images

        # splitext only touches the final extension (str.replace would also rewrite
        # ".png" inside directory names).
        root, ext = os.path.splitext(save_path)
        refined_path = f"{root}_refined{ext}"
        refined_images[0].save(refined_path)
        return [refined_path]
# ---------------------------------------------------------------------------
# Vision Transformer Decoder
# ---------------------------------------------------------------------------
def _rotate_half(x):
    """Rotate interleaved channel pairs: ``(x0, x1, x2, x3, ...) -> (-x1, x0, -x3, x2, ...)``.

    The last dimension must have even size; it is viewed as (d, 2) pairs.
    """
    pairs = x.unflatten(-1, (-1, 2))
    even, odd = pairs.unbind(dim=-1)
    return torch.stack((-odd, even), dim=-1).flatten(-2)
class VisionRoPE2D(nn.Module):
    """2D Rotary Position Embedding for Q/K in vision decoder attention.

    The first half of the last dimension is rotated by ``positions_2d[:, 0]`` (named y)
    and the second half by ``positions_2d[:, 1]`` (named x). Frequencies are applied to
    interleaved channel pairs, matching ``_rotate_half``.

    NOTE(review): angles are computed in the input dtype, so with bf16 inputs
    ``position * inv_freq`` is coarsely rounded. Kept as is because pretrained weights may
    depend on it.
    """

    def __init__(self, theta: float = 10000.0):
        super().__init__()
        self.theta = theta

    def _rope_half(self, x_half, pos_1d, theta):
        # Expects a 3-D (BH, T, d_half) tensor; the unpack raises otherwise.
        _, _, half_dim = x_half.shape
        dtype = x_half.dtype
        even_idx = torch.arange(0, half_dim, 2, device=x_half.device, dtype=torch.float32)
        inv_freq = (1.0 / (theta ** (even_idx / half_dim))).to(dtype)
        angles = pos_1d.to(dtype)[:, None] * inv_freq[None, :]
        # Each frequency drives one (even, odd) channel pair.
        cos = torch.repeat_interleave(torch.cos(angles), 2, dim=-1).unsqueeze(0)
        sin = torch.repeat_interleave(torch.sin(angles), 2, dim=-1).unsqueeze(0)
        return x_half * cos + _rotate_half(x_half) * sin

    def forward(self, x, positions_2d):
        split = x.shape[-1] // 2
        rotated_y = self._rope_half(x[:, :, :split], positions_2d[:, 0], self.theta)
        rotated_x = self._rope_half(x[:, :, split:], positions_2d[:, 1], self.theta)
        return torch.cat([rotated_y, rotated_x], dim=-1)
class VisionAttention(nn.Module):
    """Multi-headed attention with 2D RoPE + FlashAttention varlen.

    Inputs are (batch, seq_len, hidden_size). The FlashAttention varlen path is taken only
    for CUDA tensors when ``seq_lens`` is given and covers exactly the ``seq_len`` tokens of
    a single batch row. Otherwise an eager ``bmm`` + softmax implementation is used; that
    path does not read ``seq_lens``, so packed sequences attend to each other unless
    ``attention_mask`` prevents it.
    """

    def __init__(self, config, rope=None, rope_shift=0):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got embed_dim={self.embed_dim}, num_heads={self.num_heads})"
            )
        self.scale = self.head_dim ** -0.5
        self.dropout = config.attention_dropout
        self.subln = config.subln
        self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=getattr(config, "k_bias", True))
        self.v_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=getattr(config, "v_bias", True))
        self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=getattr(config, "q_bias", True))
        self.inner_attn_ln = Qwen2RMSNorm(self.embed_dim, eps=config.layer_norm_eps) if config.subln else nn.Identity()
        self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=True)
        # Optional RoPE module for q/k; the first ``rope_shift`` tokens are left unrotated.
        self.rope = rope
        self.rope_shift = int(rope_shift)

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        """(bsz, seq_len, embed_dim) -> (bsz, num_heads, seq_len, head_dim), contiguous."""
        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

    def _maybe_flash_attention(self, query_states, key_states, value_states, seq_lens, training):
        """Run FlashAttention varlen on (B*H, T, head_dim) tensors, or return None for eager.

        NOTE(review): ``query_states`` arrives already multiplied by ``self.scale`` (see
        ``forward``), but ``softmax_scale=None`` makes flash-attn apply its default
        ``head_dim ** -0.5`` again. This path therefore scales logits by ``1 / head_dim``,
        while the eager path uses ``head_dim ** -0.5``. Left unchanged because pretrained
        weights may depend on it -- confirm which scaling the checkpoint was trained with.
        """
        if not (query_states.is_cuda and (query_states.dtype in (torch.float16, torch.bfloat16, torch.float32))):
            return None
        if seq_lens is None:
            return None
        try:
            BxH, T, hd = query_states.shape
            H = self.num_heads
            if BxH % H != 0:
                return None
            B = BxH // H
            # cu_seqlens below describe T packed tokens; with B > 1 the remaining
            # (B - 1) * T rows would never be written by the kernel, so fall back.
            if B != 1 or int(seq_lens.sum().item()) != T:
                return None

            q = query_states.view(B, H, T, hd).transpose(1, 2).reshape(-1, H, hd).contiguous()
            k = key_states.view(B, H, T, hd).transpose(1, 2).reshape(-1, H, hd).contiguous()
            v = value_states.view(B, H, T, hd).transpose(1, 2).reshape(-1, H, hd).contiguous()

            cu_q = torch.zeros(seq_lens.numel() + 1, dtype=torch.int32, device=seq_lens.device)
            cu_q[1:] = torch.cumsum(seq_lens.to(torch.int32), dim=0)
            cu_k = cu_q
            max_seqlen = int(seq_lens.max().item())

            # flash-attn only accepts fp16/bf16; fp32 inputs are run in fp16 and cast back.
            orig_dtype = q.dtype
            use_dtype = q.dtype if q.dtype in (torch.float16, torch.bfloat16) else torch.float16
            if q.dtype != use_dtype:
                q = q.to(use_dtype)
                k = k.to(use_dtype)
                v = v.to(use_dtype)

            out = flash_attn_varlen_func(
                q, k, v, cu_q, cu_k, max_seqlen, max_seqlen,
                dropout_p=self.dropout if training else 0.0,
                softmax_scale=None, causal=False, return_attn_probs=False
            )
            if out.dtype != orig_dtype:
                out = out.to(orig_dtype)
            return out.view(B, -1, H, hd).transpose(1, 2).contiguous().view(B * H, T, hd)
        except Exception:
            # Best-effort fast path: record why it failed, then use the eager implementation.
            logger.debug("FlashAttention varlen path failed; falling back to eager attention.", exc_info=True)
            return None

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        causal_attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = False,
        positions_2d: Optional[torch.Tensor] = None,
        seq_lens: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Attend over ``hidden_states`` (bsz, tgt_len, embed_dim).

        Returns ``(attn_output, attn_weights)``. ``attn_weights`` has shape
        (bsz, num_heads, tgt_len, src_len) only on the eager path with
        ``output_attentions=True``; otherwise it is None.
        """
        bsz, tgt_len, embed_dim = hidden_states.size()

        # The softmax scale is folded into q here.
        query_states = self.q_proj(hidden_states) * self.scale
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        query_states = self._shape(query_states, tgt_len, bsz).view(bsz * self.num_heads, tgt_len, self.head_dim)
        key_states = self._shape(key_states, tgt_len, bsz).view(bsz * self.num_heads, tgt_len, self.head_dim)
        value_states = self._shape(value_states, tgt_len, bsz).view(bsz * self.num_heads, tgt_len, self.head_dim)

        if self.rope is not None and positions_2d is not None:
            if self.rope_shift > 0:
                # Leave the first rope_shift tokens unrotated; rotate the rest.
                q_pref = query_states[:, :self.rope_shift, :]
                k_pref = key_states[:, :self.rope_shift, :]
                q_rot = self.rope(query_states[:, self.rope_shift:, :], positions_2d[self.rope_shift:])
                k_rot = self.rope(key_states[:, self.rope_shift:, :], positions_2d[self.rope_shift:])
                query_states = torch.cat([q_pref, q_rot], dim=1).type_as(value_states)
                key_states = torch.cat([k_pref, k_rot], dim=1).type_as(value_states)
            else:
                query_states = self.rope(query_states, positions_2d).type_as(value_states)
                key_states = self.rope(key_states, positions_2d).type_as(value_states)

        attn_output = self._maybe_flash_attention(
            query_states, key_states, value_states, seq_lens=seq_lens, training=self.training
        )

        if attn_output is not None:
            attn_weights_reshaped = None
        else:
            src_len = key_states.size(1)
            attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))

            if causal_attention_mask is not None:
                attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + causal_attention_mask
                attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

            if attention_mask is not None:
                attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
                attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

            attn_weights = nn.functional.softmax(attn_weights, dim=-1)

            if output_attentions:
                attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            else:
                attn_weights_reshaped = None

            attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
            attn_output = torch.bmm(attn_probs, value_states)

        attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
        attn_output = attn_output.transpose(1, 2).reshape(bsz, tgt_len, embed_dim)
        attn_output = self.inner_attn_ln(attn_output)
        attn_output = self.out_proj(attn_output)
        return attn_output, attn_weights_reshaped
class VisionSwiGLU(nn.Module):
    """SwiGLU feed-forward block: ``w3(ffn_ln(silu(w1(x)) * w2(x)))``.

    ``ffn_ln`` is an RMSNorm over the intermediate width when ``config.subln`` is set,
    otherwise an identity.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        hidden, inner = self.hidden_size, self.intermediate_size

        self.w1 = nn.Linear(hidden, inner)
        self.w2 = nn.Linear(hidden, inner)
        self.w3 = nn.Linear(inner, hidden)
        self.act_fn = nn.SiLU()
        if config.subln:
            self.ffn_ln = Qwen2RMSNorm(inner, eps=config.layer_norm_eps)
        else:
            self.ffn_ln = nn.Identity()

    def forward(self, x):
        gated = self.act_fn(self.w1(x)) * self.w2(x)
        return self.w3(self.ffn_ln(gated))
class VisionMLP(nn.Module):
    """Plain feed-forward block: ``fc1 -> activation_fn -> ffn_ln -> fc2``.

    ``activation_fn`` is looked up in ``ACT2FN`` by ``config.hidden_act``. ``ffn_ln`` is a
    ``Qwen2RMSNorm`` over the intermediate features when ``config.subln`` is truthy and an
    ``nn.Identity`` otherwise.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.activation_fn = ACT2FN[config.hidden_act]
        self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
        self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
        self.ffn_ln = (
            Qwen2RMSNorm(config.intermediate_size, eps=config.layer_norm_eps)
            if config.subln
            else nn.Identity()
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Apply the four stages in sequence.
        for stage in (self.fc1, self.activation_fn, self.ffn_ln, self.fc2):
            hidden_states = stage(hidden_states)
        return hidden_states
class VisionEncoderLayer(nn.Module):
    """Pre-norm transformer block used by ``VisionEncoder``.

    ``x = x + self_attn(layer_norm1(x))`` followed by ``x = x + mlp(layer_norm2(x))``.
    The MLP is a ``VisionSwiGLU`` when ``config.swiglu`` is truthy, otherwise a ``VisionMLP``.
    """

    def __init__(self, config, rope=None, rope_shift=0):
        super().__init__()
        self.embed_dim = config.hidden_size
        # Submodules are created in this exact order so parameter initialisation
        # draws from the RNG in the same sequence.
        self.self_attn = VisionAttention(config, rope=rope, rope_shift=rope_shift)
        self.layer_norm1 = Qwen2RMSNorm(self.embed_dim, eps=config.layer_norm_eps)
        if config.swiglu:
            self.mlp = VisionSwiGLU(config)
        else:
            self.mlp = VisionMLP(config)
        self.layer_norm2 = Qwen2RMSNorm(self.embed_dim, eps=config.layer_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor],
        causal_attention_mask: Optional[torch.Tensor],
        output_attentions: Optional[bool] = False,
        positions_2d: Optional[torch.Tensor] = None,
        seq_lens: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.FloatTensor, Optional[torch.Tensor]]:
        """Run one block.

        The masks, ``positions_2d`` and ``seq_lens`` are forwarded unchanged to ``self_attn``.

        Returns:
            ``(hidden_states,)``, or ``(hidden_states, attn_weights)`` when
            ``output_attentions`` is truthy.
        """
        attn_out, attn_weights = self.self_attn(
            hidden_states=self.layer_norm1(hidden_states),
            attention_mask=attention_mask,
            causal_attention_mask=causal_attention_mask,
            output_attentions=output_attentions,
            positions_2d=positions_2d,
            seq_lens=seq_lens,
        )
        after_attn = hidden_states + attn_out
        block_out = after_attn + self.mlp(self.layer_norm2(after_attn))
        if not output_attentions:
            return (block_out,)
        return (block_out, attn_weights)
class VisionEncoder(nn.Module):
    """Stack of ``config.num_hidden_layers`` ``VisionEncoderLayer`` blocks.

    Every layer receives the same ``rope`` module and ``rope_shift``.
    """

    def __init__(self, config, rope=None, rope_shift=0):
        super().__init__()
        self.config = config
        self.layers = nn.ModuleList(
            [VisionEncoderLayer(config, rope=rope, rope_shift=rope_shift) for _ in range(config.num_hidden_layers)]
        )
        self.gradient_checkpointing = False
        self._gradient_checkpointing_func = torch.utils.checkpoint.checkpoint

    def forward(
        self,
        inputs_embeds: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        causal_attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        positions_2d: Optional[torch.Tensor] = None,
        seq_lens: Optional[torch.Tensor] = None,
    ):
        """Run all layers over ``inputs_embeds``.

        Args:
            inputs_embeds: input hidden states for the first layer.
            attention_mask / causal_attention_mask / positions_2d / seq_lens: forwarded
                unchanged to every layer.
            output_attentions: collect each layer's attention weights (default False).
            output_hidden_states: collect the input plus every layer's output (default False).
            return_dict: return a ``BaseModelOutput`` (default True); otherwise a tuple of the
                non-None values among (last hidden state, hidden states, attentions).

        When ``self.gradient_checkpointing`` is set and the module is training, each layer
        runs under ``self._gradient_checkpointing_func`` with ``use_reentrant=False``.
        """
        output_attentions = output_attentions if output_attentions is not None else False
        output_hidden_states = output_hidden_states if output_hidden_states is not None else False
        return_dict = True if return_dict is None else return_dict
        encoder_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None
        hidden_states = inputs_embeds
        for layer in self.layers:
            if output_hidden_states:
                encoder_states = encoder_states + (hidden_states,)
            if self.gradient_checkpointing and self.training:
                # Checkpoint this layer's bound ``__call__`` rather than a closure over the
                # loop variable. Non-reentrant checkpointing re-invokes the function during
                # backward, after the loop has advanced. A closure would then recompute with
                # the *last* layer and produce wrong gradients. Arguments match the eager
                # branch, so both paths see identical inputs and return the same tuple.
                layer_outputs = self._gradient_checkpointing_func(
                    layer.__call__,
                    hidden_states,
                    attention_mask,
                    causal_attention_mask,
                    output_attentions,
                    positions_2d,
                    seq_lens,
                    use_reentrant=False,
                )
            else:
                layer_outputs = layer(
                    hidden_states,
                    attention_mask,
                    causal_attention_mask,
                    output_attentions=output_attentions,
                    positions_2d=positions_2d,
                    seq_lens=seq_lens,
                )
            hidden_states = layer_outputs[0]
            if output_attentions:
                all_attentions = all_attentions + (layer_outputs[1],)
        if output_hidden_states:
            encoder_states = encoder_states + (hidden_states,)
        if not return_dict:
            return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
        return BaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=encoder_states,
            attentions=all_attentions,
        )
class PatchUnMerger(nn.Module):
    """Learnable inverse of Qwen2_5_VLPatchMerger.

    Each input row of width ``dim`` is RMS-normalised (``ln_q``) and expanded by a
    Linear-GELU-Linear MLP to ``context_dim * spatial_merge_size ** 2`` features. The result
    is viewed as ``(rows * spatial_merge_size ** 2, context_dim)``.
    NOTE(review): the view uses ``shape[0]`` as the row count, so the input is presumably
    2-D ``(tokens, dim)``; confirm no batched callers.
    """

    def __init__(self, dim, context_dim, spatial_merge_size=2):
        super().__init__()
        self.spatial_merge_size = spatial_merge_size
        self.context_dim = context_dim
        expanded = context_dim * spatial_merge_size ** 2
        self.ln_q = Qwen2RMSNorm(dim, eps=1e-6)
        self.mlp = nn.Sequential(
            nn.Linear(dim, expanded),
            nn.GELU(),
            nn.Linear(expanded, expanded),
        )

    def forward(self, x):
        tokens_per_merge = self.spatial_merge_size ** 2
        projected = self.mlp(self.ln_q(x))
        return projected.view(projected.shape[0] * tokens_per_merge, self.context_dim)
def restore_spatial_structure_and_convert_to_images(patches, grid_thw_list, patch_size,
                                                     channel_dim=3, temporal_patch_size=2, merge_size=2):
    """Convert decoder pixel features back to image tensors [3, H, W].

    ``patches`` (or ``patches[0]`` when a tuple is given) is consumed sequentially, taking
    ``t * h * w`` rows per entry of ``grid_thw_list``. Each chunk is reshaped as
    ``(t, h // merge_size, w // merge_size, merge_size, merge_size, channel_dim,
    temporal_patch_size, patch_size, patch_size)``. It is permuted back to raster order and
    reshaped to ``(t * temporal_patch_size, channel_dim, h * patch_size, w * patch_size)``.
    Only the first frame (index 0) of each entry is returned.
    """
    if isinstance(patches, tuple):
        patches = patches[0]
    images = []
    offset = 0
    for grid in grid_thw_list:
        values = grid.tolist() if isinstance(grid, torch.Tensor) else grid
        frames, rows, cols = (int(v) for v in values)
        count = frames * rows * cols
        block = patches[offset:offset + count]
        offset += count
        grid_view = block.reshape(
            frames, rows // merge_size, cols // merge_size, merge_size, merge_size,
            channel_dim, temporal_patch_size, patch_size, patch_size,
        )
        # -> (frames, temporal, channel, row_blocks, merge_row, patch_row,
        #     col_blocks, merge_col, patch_col)
        grid_view = grid_view.permute(0, 6, 5, 1, 3, 7, 2, 4, 8)
        video = grid_view.reshape(
            frames * temporal_patch_size, channel_dim, rows * patch_size, cols * patch_size
        )
        images.append(video[0])
    return images
class VisionTransformerDecoder(nn.Module):
    """Decode vision-token embeddings back into per-patch pixel values / images.

    As implemented in ``forward``:
      1. if the input width equals ``codebook_dim``: ``post_quant_conv`` + ``post_quant_norm``;
      2. ``PatchUnMerger`` expands each merged token to ``spatial_merge_size ** 2`` tokens;
      3. ``norm_in`` -> ``VisionEncoder`` (2-D RoPE, per-image ``seq_lens``) -> ``norm_out``;
      4. ``decoder_head`` predicts ``3 * patch_size * patch_size * temporal_patch_size``
         values per token, optionally reassembled into images.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.patch_size = config.patch_size
        self.spatial_merge_size = config.spatial_merge_size
        self.codebook_dim = config.codebook_dim
        self.temporal_patch_size = config.temporal_patch_size
        self.rope2d = VisionRoPE2D(theta=10000.0)
        self.post_quant_conv = nn.Linear(self.codebook_dim, self.embed_dim)
        self.post_quant_norm = Qwen2RMSNorm(self.embed_dim, eps=config.layer_norm_eps)
        self.patch_unmerger = PatchUnMerger(self.embed_dim, self.embed_dim, self.spatial_merge_size)
        self.norm_in = Qwen2RMSNorm(self.embed_dim, eps=config.layer_norm_eps)
        self.encoder = VisionEncoder(config, rope=self.rope2d, rope_shift=0)
        self.norm_out = Qwen2RMSNorm(self.embed_dim, eps=config.layer_norm_eps)
        self.decoder_head = nn.Sequential(
            nn.Linear(self.embed_dim, config.intermediate_size), nn.GELU(),
            nn.Linear(config.intermediate_size, 3 * self.patch_size * self.patch_size * self.temporal_patch_size),
        )

    @classmethod
    def from_pretrained(cls, config, model_path: str):
        """Load a pretrained model from a checkpoint.

        Builds ``cls(config)`` and loads the safetensors file at ``model_path`` onto the CPU.
        Only keys prefixed with ``image_decoder.`` are kept, with the prefix stripped, and they
        are loaded with ``strict=True``. The model is returned in eval mode.
        """
        model = cls(config)
        weight_dict = load_file(model_path, device="cpu")
        prefix = "image_decoder."
        state_dict = {k.removeprefix(prefix): v for k, v in weight_dict.items() if k.startswith(prefix)}
        model.load_state_dict(state_dict, strict=True)
        model.eval()
        return model

    def _build_2d_positions(self, grid_thw_list):
        """Return (row, col) indices for every token as a ``(N, 2)`` long tensor.

        For each ``(t, h, w)`` grid the raster order (row-major over ``h x w``) is repeated
        ``t`` times. Grids are concatenated in order.
        """
        chunks = []
        for t, gh, gw in grid_thw_list:
            t, gh, gw = int(t), int(gh), int(gw)
            rows = torch.arange(gh, dtype=torch.long).repeat_interleave(gw)
            cols = torch.arange(gw, dtype=torch.long).repeat(gh)
            chunks.append(torch.stack([rows, cols], dim=1).repeat(t, 1))
        if not chunks:
            # Keep a consistent (N, 2) rank even when there are no tokens.
            return torch.zeros((0, 2), dtype=torch.long)
        return torch.cat(chunks, dim=0)

    def _build_attention_mask(self, grid_thw_list, device, dtype, B, num_heads):
        """Build a ``(B, num_heads, L, L)`` additive block-diagonal mask.

        Tokens may only attend within their own ``t * h * w`` segment. Entries outside the
        block get ``-inf``, entries inside get 0. ``forward`` does not call this helper: it
        passes ``attention_mask=None`` and relies on ``seq_lens`` instead.
        """
        counts = [int(t) * int(h) * int(w) for (t, h, w) in grid_thw_list]
        L = sum(counts)
        mask = torch.zeros((B, num_heads, L, L), device=device, dtype=dtype)
        s = 0
        for c in counts:
            e = s + c
            if s > 0:
                mask[:, :, s:e, :s] = float("-inf")
            if e < L:
                mask[:, :, s:e, e:] = float("-inf")
            s = e
        return mask

    def forward(self, embeddings, grid_thw, return_pixel_features=False, return_last_latent=False):
        """Decode ``embeddings`` into pixel features and, optionally, images.

        Args:
            embeddings: merged-token embeddings. If the last dim equals ``codebook_dim`` they
                are first projected by ``post_quant_conv`` and normalised. NOTE(review):
                ``PatchUnMerger`` views its output using ``shape[0]``, so a 2-D
                ``(tokens, dim)`` input is presumably expected.
            grid_thw: per-image ``(t, h, w)`` grid sizes, either a tensor whose rows are
                triples or an iterable of triples.
            return_pixel_features: if True, skip image reconstruction (``"images"`` is None).
            return_last_latent: if True, also return the detached input of the encoder stack.

        Returns:
            dict with ``"images"`` (list from
            ``restore_spatial_structure_and_convert_to_images`` or None), ``"pixel_features"``
            (decoder-head output with a leading size-1 dim squeezed) and, if requested,
            ``"last_latent"``.
        """
        device = embeddings.device
        grid_thw_list = ([(int(t), int(h), int(w)) for t, h, w in grid_thw.detach().cpu().numpy()]
                         if isinstance(grid_thw, torch.Tensor) else list(grid_thw))
        if embeddings.shape[-1] == self.codebook_dim:
            embeddings = self.post_quant_conv(embeddings)
            embeddings = self.post_quant_norm(embeddings)
        unmerged = self.patch_unmerger(embeddings)
        if unmerged.dim() == 2:
            unmerged = unmerged.unsqueeze(0)  # add a batch dim of 1 for the encoder stack
        L = unmerged.shape[1]
        hidden_states = self.norm_in(unmerged)
        positions_2d = self._build_2d_positions(grid_thw_list).to(device)
        # Per-image token counts; the attention layers receive these instead of a dense mask.
        seq_lens = torch.tensor([int(t) * int(h) * int(w) for (t, h, w) in grid_thw_list],
                                device=device, dtype=torch.int32)
        assert positions_2d.shape[0] == L, f"positions_2d {positions_2d.shape[0]} != L {L}"
        last_latent = hidden_states.detach().squeeze(0) if return_last_latent else None
        enc_out = self.encoder(
            inputs_embeds=hidden_states,
            attention_mask=None,
            causal_attention_mask=None,
            output_attentions=False,
            output_hidden_states=False,
            return_dict=True,
            positions_2d=positions_2d,
            seq_lens=seq_lens,
        )
        hidden_states = self.norm_out(enc_out.last_hidden_state)
        pixel_features = self.decoder_head(hidden_states).squeeze(0)
        out_imgs = (None if return_pixel_features else
                    restore_spatial_structure_and_convert_to_images(
                        pixel_features, grid_thw_list, self.patch_size,
                        temporal_patch_size=self.temporal_patch_size, merge_size=self.spatial_merge_size))
        ret = {"images": out_imgs, "pixel_features": pixel_features}
        if last_latent is not None:
            ret["last_latent"] = last_latent
        return ret