Instructions to use namednil/STEP with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use namednil/STEP with Transformers:
```python
# Use a pipeline as a high-level helper
from transformers import pipeline

pipe = pipeline("feature-extraction", model="namednil/STEP", trust_remote_code=True)

# Load model directly
from transformers import AutoModel

model = AutoModel.from_pretrained("namednil/STEP", trust_remote_code=True, dtype="auto")
```
- Notebooks
- Google Colab
- Kaggle
| import torch | |
| from transformers import AutoTokenizer, PretrainedConfig, T5Config, PreTrainedModel, T5ForConditionalGeneration, \ | |
| AutoModelForSeq2SeqLM, Adafactor | |
| from typing import Optional, List, Callable, Mapping, Any, Union | |
| import os | |
class STEPFinetuningModelConfig(T5Config):
    """T5 configuration extended with the settings of STEP's tunable prefix.

    Every keyword not listed below is forwarded unchanged to ``T5Config``.

    :param num_examples: how many precomputed prefixes are averaged when the
        tunable prefix is initialized.
    :param prefix_length: number of prefix positions prepended to the input.
    :param random_selection: shuffle the precomputed prefixes before taking the
        first ``num_examples`` of them.
    :param prefix_max_init_length: length of each stored precomputed prefix.
        Only change this together with the tensor the prefix is initialized from.
    :param num_precomputed_examples: number of stored precomputed prefixes.
        Same caveat as ``prefix_max_init_length``.
    """

    model_type = "STEP_finetuning"

    def __init__(self,
                 num_examples: int = 512,
                 prefix_length: int = 10,
                 random_selection: bool = True,
                 prefix_max_init_length: int = 20,
                 num_precomputed_examples: int = 1000,
                 **kwargs):
        # Shape of the tunable prefix and how its initial value is chosen.
        self.prefix_length = prefix_length
        self.num_examples = num_examples
        self.random_selection = random_selection
        # These describe the stored tensor the prefix is initialized from.
        self.prefix_max_init_length = prefix_max_init_length
        self.num_precomputed_examples = num_precomputed_examples
        # Everything else is a regular T5Config argument.
        super().__init__(**kwargs)
class STEPFinetuningModel(PreTrainedModel):
    """T5 seq2seq model with a tunable soft prefix prepended to the input embeddings.

    ``prefix_embedding`` is a learnable parameter of shape
    ``(1, config.prefix_length, config.d_model)``. Two loading cases exist:

    1. The checkpoint comes straight from STEP pre-training and holds no tunable
       prefix. The prefix then stays NaN after loading, and ``from_pretrained``
       fills it with the average of rows of ``prefix_init_tensor``.
    2. The checkpoint was already fine-tuned. The stored prefix overwrites the
       NaN placeholder on load and is kept as-is.
    """

    config_class = STEPFinetuningModelConfig

    def __init__(self, config: STEPFinetuningModelConfig):
        super().__init__(config)
        self.model = T5ForConditionalGeneration(config)
        # Candidate prefixes that are averaged to initialize the tunable prefix.
        # Zero-filled here; presumably populated from the checkpoint on load (TODO confirm).
        self.register_buffer(
            "prefix_init_tensor",
            torch.zeros(config.num_precomputed_examples, config.prefix_max_init_length, config.d_model),
        )
        # NaN placeholder: it is either overwritten by a fine-tuned checkpoint's prefix,
        # or detected as all-NaN in from_pretrained and initialized from prefix_init_tensor.
        self.prefix_embedding = torch.nn.Parameter(
            torch.full((1, config.prefix_length, config.d_model), torch.nan)
        )
        self.prefix_has_been_initialized = False

    def _initialize_prefix(self):
        """Set the tunable prefix to the mean of a subset of the precomputed prefixes.

        If ``config.random_selection`` is set, the rows of ``prefix_init_tensor`` are
        shuffled first. Then the first ``config.num_examples`` rows are taken, each is
        truncated to ``config.prefix_length`` positions, and they are averaged into
        ``prefix_embedding``.
        """
        prefix_init_tensor = self.prefix_init_tensor
        if self.config.random_selection:
            # Randomize which precomputed transformations are averaged for the initialization.
            prefix_init_tensor = prefix_init_tensor[torch.randperm(prefix_init_tensor.shape[0]), :, :]
        # shape (num examples, prefix length, d_model)
        prefix_init_tensor = prefix_init_tensor[:self.config.num_examples, :self.config.prefix_length, :]
        with torch.no_grad():
            self.prefix_embedding.copy_(prefix_init_tensor.mean(dim=0, keepdim=True))
        self.prefix_has_been_initialized = True

    @classmethod
    def from_pretrained(
            cls,
            pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
            *model_args,
            **kwargs,
    ):
        """Load via ``PreTrainedModel.from_pretrained`` and initialize the prefix if needed.

        If the loaded ``prefix_embedding`` is entirely NaN (i.e. the checkpoint had no
        tunable prefix), it is initialized with ``_initialize_prefix``. The return value
        is whatever the parent returns, so ``output_loading_info=True`` still yields a
        ``(model, loading_info)`` tuple.
        """
        loaded = super(STEPFinetuningModel, cls).from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
        # With output_loading_info=True the parent returns (model, loading_info).
        model = loaded[0] if kwargs.get("output_loading_info", False) else loaded
        if torch.all(model.prefix_embedding.isnan()):
            model._initialize_prefix()
        return loaded

    def prepare_input(self, kwargs: Mapping[str, Any]) -> dict:
        """Return a copy of ``kwargs`` with the tunable prefix prepended to the input.

        ``input_ids`` is replaced by ``inputs_embeds``: the prefix followed by the token
        embeddings. A non-None ``attention_mask`` is extended with ones so the prefix
        positions are attended to. An ``attention_mask`` of ``None`` is passed through
        unchanged. The caller's mapping is not modified.

        :param kwargs: keyword arguments destined for the wrapped T5 model; must contain
            ``input_ids`` (indexed as ``(batch, seq)`` here).
        :return: new kwargs dict for ``self.model`` / ``self.model.generate``.
        """
        kwargs = dict(kwargs)
        input_ids = kwargs.pop("input_ids")
        batch_size = input_ids.shape[0]
        token_embeds = self.model.get_input_embeddings()(input_ids)
        prefix = torch.repeat_interleave(self.prefix_embedding, batch_size, 0)  # shape (batch, prefix length, embed dim)
        # shape (batch, prefix + seq length, embed dim)
        kwargs["inputs_embeds"] = torch.cat([prefix, token_embeds], dim=1)
        attention_mask = kwargs.get("attention_mask")
        if attention_mask is not None:
            # Built on the mask's own device/dtype so the concatenation is always valid.
            prefix_mask = torch.ones(
                (batch_size, prefix.shape[1]), device=attention_mask.device, dtype=attention_mask.dtype
            )
            kwargs["attention_mask"] = torch.cat([prefix_mask, attention_mask], dim=1)
        return kwargs

    def forward(self, **kwargs):
        """Run the wrapped T5 model on the prefix-augmented inputs."""
        return self.model(**self.prepare_input(kwargs))

    def generate(self, **kwargs):
        """Generate with the wrapped T5 model on the prefix-augmented inputs."""
        return self.model.generate(**self.prepare_input(kwargs))

    def get_optimizer(self, optimizer: Optional[Callable[..., torch.optim.Optimizer]] = None, prefix_lr: float = 10.0,
                      **kwargs) -> torch.optim.Optimizer:
        """
        Return an optimizer that uses a different learning rate (typically higher) for the prefix than for the rest of the model.

        :param optimizer: optimizer factory called as ``optimizer(params=param_groups, **kwargs)``.
            If ``None``, Adafactor with the hyperparameters used in the paper is returned.
        :param prefix_lr: learning rate for the ``prefix_embedding`` parameter group.
        :param kwargs: extra optimizer arguments; with the default Adafactor they override
            the paper's defaults (e.g. ``lr`` for the non-prefix parameters).
        :return: the constructed optimizer with two parameter groups (prefix, everything else).
        """
        prefix_params = []
        other_params = []
        for name, param in self.named_parameters():
            if name == "prefix_embedding":
                prefix_params.append(param)
            else:
                other_params.append(param)
        param_groups = [{"params": prefix_params, "lr": prefix_lr}, {"params": other_params}]
        if optimizer is None:
            # The optimizer used in the paper; user kwargs take precedence over these defaults.
            hparams = {"scale_parameter": False, "relative_step": False, "warmup_init": False, "lr": 1e-4}
            return Adafactor(params=param_groups, **(hparams | kwargs))
        return optimizer(params=param_groups, **kwargs)