Spaces:
Runtime error
Runtime error
from typing import Dict, List, Optional

import torch

from modules.module_customPllLabel import CustomPllLabel
from modules.module_pllScore import PllScore
class RankSents:
    """Score candidate words for the single ``*`` placeholder of a sentence.

    The class wraps a tokenizer/model pair obtained from ``language_model``
    and a ``PllScore`` helper.  ``rank`` is the entry point: it validates the
    sentence, optionally asks the model for the most likely fillers of the
    placeholder, and returns one score per candidate sentence.
    """

    # Function words that can be filtered out of the model's suggestions,
    # keyed by the language code accepted by ``__init__``.
    _FUNCTION_WORDS = {
        "es": {
            "articles": [
                'un', 'una', 'unos', 'unas', 'el', 'los', 'la', 'las', 'lo',
            ],
            "prepositions": [
                'a', 'ante', 'bajo', 'cabe', 'con', 'contra', 'de', 'desde',
                'en', 'entre', 'hacia', 'hasta', 'para', 'por', 'según',
                'sin', 'so', 'sobre', 'tras', 'durante', 'mediante', 'vía',
                'versus',
            ],
            "conjunctions": [
                'y', 'o', 'ni', 'que', 'pero', 'si',
            ],
        },
        "en": {
            "articles": [
                'a', 'an', 'the',
            ],
            "prepositions": [
                'above', 'across', 'against', 'along', 'among', 'around',
                'at', 'before', 'behind', 'below', 'beneath', 'beside',
                'between', 'by', 'down', 'from', 'in', 'into', 'near', 'of',
                'off', 'on', 'to', 'toward', 'under', 'upon', 'with',
                'within',
            ],
            "conjunctions": [
                'and', 'or', 'but', 'that', 'if', 'whether',
            ],
        },
    }

    def __init__(
        self,
        language_model,  # LanguageModel class instance
        lang: str,
        errorManager  # ErrorManager class instance
    ) -> None:
        """Load the tokenizer/model and the function-word lists for ``lang``.

        Args:
            language_model: Object exposing ``initTokenizer()`` and
                ``initModel()``; it is also handed to ``PllScore``.
            lang: Language code.  ``"es"`` and ``"en"`` have word lists; any
                other value yields empty lists, so the ``exclude_*`` options
                of ``getTopPredictions`` simply filter nothing instead of
                failing on a missing attribute.
            errorManager: Object whose ``process(msg)`` result is returned by
                ``errorChecking``.
        """
        self.tokenizer = language_model.initTokenizer()
        self.model = language_model.initModel()
        self.model.eval()

        self.Label = CustomPllLabel()
        self.pllScore = PllScore(
            language_model=language_model
        )
        self.softmax = torch.nn.Softmax(dim=-1)

        # Copy the lists so instances never share (and cannot corrupt) the
        # class-level tables.
        words = self._FUNCTION_WORDS.get(lang, {})
        self.articles = list(words.get("articles", []))
        self.prepositions = list(words.get("prepositions", []))
        self.conjunctions = list(words.get("conjunctions", []))

        self.errorManager = errorManager

    def errorChecking(
        self,
        sent: str
    ) -> str:
        """Validate ``sent`` and return the error manager's verdict.

        A valid sentence is non-empty, contains exactly one ``*`` and, once
        the ``*`` is replaced by the tokenizer's mask token, encodes to no
        more than ``tokenizer.max_len_single_sentence`` tokens.

        Args:
            sent: Sentence to validate.

        Returns:
            Whatever ``errorManager.process`` returns for the error code list
            (or for ``""`` when no problem was found).  ``rank`` treats a
            truthy result as an error.
        """
        out_msj = ""
        if not sent:
            out_msj = ['RANKSENTS_NO_SENTENCE_PROVIDED']
        else:
            n_masks = sent.count("*")
            if n_masks > 1:
                out_msj = ['RANKSENTS_TOO_MANY_MASKS_IN_SENTENCE']
            elif n_masks == 0:
                out_msj = ['RANKSENTS_NO_MASK_IN_SENTENCE']
            else:
                masked = sent.replace("*", self.tokenizer.mask_token)
                sent_len = len(self.tokenizer.encode(masked))
                max_len = self.tokenizer.max_len_single_sentence
                if sent_len > max_len:
                    out_msj = ['RANKSENTS_TOKENIZER_MAX_TOKENS_REACHED', max_len]

        return self.errorManager.process(out_msj)

    def getTopPredictions(
        self,
        n: int,
        sent: str,
        banned_word_list: List[str],
        exclude_articles: bool,
        exclude_prepositions: bool,
        exclude_conjunctions: bool,
    ) -> List[str]:
        """Return up to ``n`` model suggestions for the ``*`` in ``sent``.

        Tokens are visited from most to least probable at the mask position.
        A token is skipped when its decoded string is banned, is not purely
        alphanumeric, starts with ``##``, is one of the tokenizer's special
        tokens, or belongs to an excluded function-word category.

        Args:
            n: Maximum number of suggestions to return.
            sent: Sentence containing the ``*`` placeholder.
            banned_word_list: Words that must never be suggested.
            exclude_articles: Skip words listed in ``self.articles``.
            exclude_prepositions: Skip words listed in ``self.prepositions``.
            exclude_conjunctions: Skip words listed in ``self.conjunctions``.

        Returns:
            The accepted token strings, most probable first; fewer than ``n``
            if the vocabulary runs out.
        """
        sent_masked = sent.replace("*", self.tokenizer.mask_token)
        inputs = self.tokenizer.encode_plus(
            sent_masked,
            add_special_tokens=True,
            return_tensors='pt',
            return_attention_mask=True,
            truncation=True
        )

        # ``.item()` requires exactly one mask token in the encoded input.
        tk_position_mask = torch.where(
            inputs['input_ids'][0] == self.tokenizer.mask_token_id
        )[0].item()

        with torch.no_grad():
            out = self.model(**inputs)
            # Only the mask position is needed, so normalise that row alone.
            # Assumes logits are (1, seq_len, vocab) for the single sentence.
            probabilities = self.softmax(out.logits[0, tk_position_mask])

        ranked_tk_ids = torch.argsort(probabilities, descending=True).tolist()

        # Build the exclusion sets once instead of on every iteration.
        excluded_words = set(banned_word_list or ())
        if exclude_articles:
            excluded_words.update(self.articles)
        if exclude_prepositions:
            excluded_words.update(self.prepositions)
        if exclude_conjunctions:
            excluded_words.update(self.conjunctions)
        special_tokens = set(self.tokenizer.all_special_tokens)

        top_tks_pred: List[str] = []
        for tk_id in ranked_tk_ids:
            if len(top_tks_pred) >= n:
                break

            tk_string = self.tokenizer.decode([tk_id])
            rejected = (
                tk_string in excluded_words
                or not tk_string.isalnum()      # punctuation / symbols
                or tk_string.startswith("##")   # word-piece continuation
                or tk_string in special_tokens
            )
            if not rejected:
                top_tks_pred.append(tk_string)

        return top_tks_pred

    def rank(
        self,
        sent: str,
        interest_word_list: Optional[List[str]] = None,
        banned_word_list: Optional[List[str]] = None,
        exclude_articles: bool = False,
        exclude_prepositions: bool = False,
        exclude_conjunctions: bool = False,
        n_predictions: int = 5
    ) -> Dict[str, float]:
        """Score ``sent`` once per candidate word substituted for ``*``.

        Args:
            sent: Sentence with exactly one ``*`` placeholder.
            interest_word_list: Candidate words.  When empty or ``None`` the
                candidates come from ``getTopPredictions``.
            banned_word_list: Words excluded from the model's suggestions
                (only used when candidates are predicted).
            exclude_articles: Forwarded to ``getTopPredictions``.
            exclude_prepositions: Forwarded to ``getTopPredictions``.
            exclude_conjunctions: Forwarded to ``getTopPredictions``.
            n_predictions: Number of suggestions to request from the model.

        Returns:
            Mapping from each candidate sentence (the ``*`` replaced by
            ``<word>``) to the value of ``pllScore.compute`` for it.

        Raises:
            Exception: If ``errorChecking`` reports a problem with ``sent``.
        """
        err = self.errorChecking(sent)
        if err:
            raise Exception(err)

        if not interest_word_list:
            interest_word_list = self.getTopPredictions(
                n_predictions,
                sent,
                banned_word_list or [],
                exclude_articles,
                exclude_prepositions,
                exclude_conjunctions
            )

        all_plls_scores = {}
        for word in interest_word_list:
            # The candidate is wrapped in angle brackets before scoring;
            # presumably PllScore uses them to locate the word - confirm.
            candidate = sent.replace("*", "<" + word + ">")
            all_plls_scores[candidate] = self.pllScore.compute(candidate)

        return all_plls_scores