Instructions for using nvidia/Cosmos-Embed1-336p with libraries, inference providers, notebooks, and local apps. Follow the links below to get started.
- Libraries
- Cosmos
How to use nvidia/Cosmos-Embed1-336p with Cosmos:
# No code snippets available yet for this library.
# To use this model, check the repository files and the library's documentation.
# Want to help? PRs adding snippets are welcome at:
# https://github.com/huggingface/huggingface.js
(A generic, hedged loading sketch is included after the Notebooks list below.)
- NeMo
How to use nvidia/Cosmos-Embed1-336p with NeMo:
# No NeMo snippet is available: the model's tag does not map to a valid NeMo domain.
- Notebooks
- Google Colab
- Kaggle
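Neither library lists a ready-made snippet yet. As a starting point, the sketch below loads the checkpoint through Hugging Face Transformers with remote code enabled; whether this repository actually exposes an AutoModel/AutoProcessor interface is an assumption, so check the repository files and the model card before relying on it.

# Hypothetical loading sketch; assumes the repository ships transformers-compatible
# remote code (AutoModel / AutoProcessor). Verify against the repository files.
from transformers import AutoModel, AutoProcessor

model = AutoModel.from_pretrained("nvidia/Cosmos-Embed1-336p", trust_remote_code=True)
processor = AutoProcessor.from_pretrained("nvidia/Cosmos-Embed1-336p", trust_remote_code=True)
# Inspect `model` and `processor` for the video/text embedding entry points documented in the repo.

The repository also ships the following utility module for spatio-temporal encodings (Cosmos-Embed1).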
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
| """Misc functions and modules for Cosmos-Embed1.""" | |
| import functools | |
| from logging import getLogger | |
| from typing import Callable, Optional, Protocol | |
| import torch | |
| import torch.distributed as dist | |
| import torch.nn as nn | |
| logger = getLogger(__file__) | |

def get_rank(group: Optional[dist.ProcessGroup] = None) -> int:
    """Get the rank (GPU device) of the worker.

    Returns:
        rank (int): The rank of the worker.
    """
    rank = 0
    if dist.is_available() and dist.is_initialized():
        rank = dist.get_rank(group)
    return rank

def barrier() -> None:
    """Barrier for all GPUs."""
    if dist.is_available() and dist.is_initialized():
        dist.barrier()

def rank0_first(func: Callable) -> Callable:
    """Run the function on rank 0 first, then on other ranks."""

    @functools.wraps(func)
    def wrapper(*args, **kwargs):  # noqa: ANN202
        if get_rank() == 0:
            result = func(*args, **kwargs)
        barrier()
        if get_rank() != 0:
            result = func(*args, **kwargs)
        return result

    return wrapper
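
# Illustrative sketch, not part of the original module: `rank0_first` is typically
# used to wrap a function with filesystem side effects (for example, downloading
# weights) so that rank 0 runs it before the remaining ranks, which then hit the
# local cache. The helper name below is hypothetical.
@rank0_first
def _demo_prepare_assets(path: str) -> str:
    # download or preprocess shared files here; every rank returns the same path
    return path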

def add_docstring(docstring: str):
    def decorator(func):
        func.__doc__ = docstring
        return func

    return decorator

INIT_DOCSTRING = """
Constructor for encoding module.

Args:
    embed_dim: size of embedding vectors, e.g. x.shape[3].
    max_len: maximum length of temporal sequence, e.g. x.shape[1].
"""

FORWARD_DOCSTRING = """
Forward function.

Args:
    x (`torch.Tensor`): rank 4 tensor to add spatio-temporal encodings to.

Returns:
    `torch.Tensor` of rank 4.
"""

class EncodingProtocol(Protocol):
    def __init__(self, embed_dim: int, max_len: int) -> None:
        pass

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        pass

def interpolate_temp_pos_embed(temp_embed: torch.Tensor, num_frames: int) -> torch.Tensor:
    """Linearly interpolates temporal encoding from `temp_embed.shape[0]` to `num_frames`."""
    temp_embed_resized = temp_embed.permute(1, 0).unsqueeze(0)
    temp_embed_resized = nn.functional.interpolate(
        temp_embed_resized,
        size=num_frames,
        mode="linear",
        align_corners=False,
    )
    return temp_embed_resized.squeeze(0).permute(1, 0)
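
# Illustrative sketch, not part of the original module: resizing a learned
# (max_len, embed_dim) table to a clip with a different number of frames.
def _demo_interpolate() -> torch.Tensor:
    temp_embed = torch.zeros(8, 32)  # max_len=8, embed_dim=32
    return interpolate_temp_pos_embed(temp_embed, 12)  # -> shape (12, 32)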

class TemporalParameterEncoding(nn.Module, EncodingProtocol):
    """Learned per-frame temporal encoding, interpolated when the clip length differs from `max_len`."""

    def __init__(self, embed_dim: int, max_len: int) -> None:
        super().__init__()
        self.embed_dim = embed_dim
        self.max_len = max_len
        self.temp_embed = nn.Parameter(torch.zeros(self.max_len, self.embed_dim))
        nn.init.trunc_normal_(self.temp_embed, std=0.02)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        _, t, _, _ = x.shape
        if t != self.temp_embed.shape[0]:
            logger.debug(f"Interpolating temporal encodings from {self.temp_embed.shape[0]} to {t}.")
            temp_embed = interpolate_temp_pos_embed(self.temp_embed, t)
        else:
            temp_embed = self.temp_embed
        # (t, d) -> (1, t, 1, d) so the encoding broadcasts over batch and spatial tokens.
        temp_embed = temp_embed.unsqueeze(0).unsqueeze(2)
        return x + temp_embed

def create_neighbor_weight_matrix(num_tokens: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
    """Build a (num_tokens, num_tokens) matrix with weight 1 / 2**|i - j| between positions i and j."""
    indices = torch.arange(num_tokens, dtype=dtype, device=device)
    abs_diff = torch.abs(indices.unsqueeze(0) - indices.unsqueeze(1))
    weights = 1.0 / (2.0**abs_diff)
    return weights


def compute_t_adj(x: torch.Tensor, weights: torch.Tensor) -> torch.Tensor:
    """Mix each token with its neighbours along the token axis using the weight matrix."""
    return torch.einsum("bfnd,nk->bfkd", x, weights)


def token_propagation(x: torch.Tensor, num_tokens: int) -> torch.Tensor:
    """Apply neighboring token propagation update."""
    weights = create_neighbor_weight_matrix(num_tokens, x.device, x.dtype)
    t_adj = compute_t_adj(x, weights)
    # Forward value stays equal to x; the neighbour-mixing term only contributes to the gradient.
    return x + t_adj - t_adj.detach()
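
# Illustrative sketch, not part of the original module: the update above is a
# straight-through-style trick. The output equals `x` numerically, while the
# backward pass also receives gradients through the neighbour-mixing term.
def _demo_token_propagation() -> None:
    x = torch.randn(1, 2, 4, 8, requires_grad=True)  # (batch, frames, tokens, dim)
    out = token_propagation(x, num_tokens=4)
    assert torch.allclose(out, x)  # forward values are unchanged
    out.sum().backward()
    print(x.grad[0, 0, 0, :4])  # gradient reflects the propagation weights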

class NeighboringTokenPropagationEncoding(TemporalParameterEncoding):
    """
    Neighboring Token Propagation method inspired by Momentor (https://arxiv.org/abs/2402.11435).
    """

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        _, t, q, _ = x.shape
        if t != self.temp_embed.shape[0]:
            logger.debug(f"Interpolating temporal encodings from {self.temp_embed.shape[0]} to {t}.")
            temp_embed = interpolate_temp_pos_embed(self.temp_embed, t)
        else:
            temp_embed = self.temp_embed
        temp_embed = temp_embed.unsqueeze(0).unsqueeze(2)
        if self.training:
            temp_embed = token_propagation(temp_embed, q)
        return x + temp_embed

class EncodingFactory(nn.Module):
    def __init__(self, encoding_type: str, embed_dim: int, max_len: int) -> None:
        super().__init__()
        fn = {
            "temporal_parameter": TemporalParameterEncoding,
            "neighboring_token_propagation": NeighboringTokenPropagationEncoding,
        }[encoding_type]
        self.encoding = fn(embed_dim=embed_dim, max_len=max_len)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.encoding(x)
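
# Illustrative sketch, not part of the original module: build an encoding and apply
# it to a dummy rank-4 tensor of shape (batch, frames, tokens, embed_dim).
if __name__ == "__main__":
    encoder = EncodingFactory("temporal_parameter", embed_dim=32, max_len=8)
    video_tokens = torch.randn(2, 12, 16, 32)  # 12 frames: the temporal table is interpolated from max_len=8
    encoded = encoder(video_tokens)
    print(encoded.shape)  # torch.Size([2, 12, 16, 32])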