Spaces:
Runtime error
Runtime error
| import streamlit as st | |
| from transformers import T5ForConditionalGeneration, T5Tokenizer | |
| import torch | |
def set_seed(seed):
    """Seed PyTorch's random number generators.

    Seeds the default generator via ``torch.manual_seed`` and, when CUDA is
    available, every CUDA device via ``torch.cuda.manual_seed_all``.

    Args:
        seed: Integer seed passed to the PyTorch seeding functions.
    """
    torch.manual_seed(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)
# Tokenizer for the 'Deep1994/t5-paraphrase-quora' checkpoint, used both to
# encode the user's input and to decode the generated sequences.
TOKENIZER_CHECKPOINT = 'Deep1994/t5-paraphrase-quora'
tokenizer = T5Tokenizer.from_pretrained(TOKENIZER_CHECKPOINT)
# Streamlit re-executes the whole script on every widget interaction, so an
# uncached loader would re-read the model weights on each rerun. Prefer
# `st.cache_resource` where it exists and fall back to the legacy `st.cache`
# on older Streamlit releases that do not provide it.
_cache_model = getattr(st, "cache_resource", None) or st.cache(allow_output_mutation=True)


@_cache_model
def load_model():
    """Load the 'Deep1994/t5-paraphrase-quora' T5 model.

    The result is cached by Streamlit, so the weights are loaded once per
    server process instead of once per script rerun.

    Returns:
        The ``T5ForConditionalGeneration`` instance returned by
        ``from_pretrained``.
    """
    return T5ForConditionalGeneration.from_pretrained('Deep1994/t5-paraphrase-quora')
model = load_model()

# Arguments passed to `model.generate` regardless of the decoding strategy.
SHARED_GENERATION_KWARGS = {'max_length': 20, 'num_return_sequences': 10}

# Strategy-specific `model.generate` arguments. The keys double as the labels
# offered in the sidebar, so the UI and the dispatch cannot drift apart.
# `early_stopping` is a beam-search option, so it is only set for that branch.
STRATEGY_GENERATION_KWARGS = {
    'Top k/p sampling': {'do_sample': True, 'top_k': 50, 'top_p': 0.95},
    'Beam Search': {
        'num_beams': 10,
        'no_repeat_ngram_size': 2,
        'early_stopping': True,
    },
}

# Maximum number of distinct paraphrases shown to the user.
MAX_PARAPHRASES = 5

st.sidebar.subheader('Select decoding strategy below.')
decoding_strategy = st.sidebar.selectbox("decoding_strategy", list(STRATEGY_GENERATION_KWARGS))

st.title('Paraphrase a question in English.')
st.write(
    'This is a fine-tuned t5 model that will paraphrase '
    'your English input text into another English output '
    'by leveraging a pre-trained [Text-To-Text Transfer Transformers](https://arxiv.org/abs/1910.10683) model.'
)

st.subheader('Input Text')
text = st.text_area(' ', height=100)

# Trim surrounding whitespace so blank input does not trigger generation and
# a trailing newline does not defeat the "same as the input" filter below.
query = text.strip()
if query:
    set_seed(1234)  # for reproducibility
    prefix = 'paraphrase: '
    encoding = tokenizer(prefix + query, return_tensors="pt")
    beam_outputs = model.generate(
        input_ids=encoding["input_ids"],
        attention_mask=encoding["attention_mask"],
        **SHARED_GENERATION_KWARGS,
        **STRATEGY_GENERATION_KWARGS[decoding_strategy],
    )

    # Keep the first MAX_PARAPHRASES candidates that differ, ignoring case,
    # from the input and from each other.
    final_outputs = []
    seen = {query.lower()}
    for beam_output in beam_outputs:
        sent = tokenizer.decode(beam_output, skip_special_tokens=True, clean_up_tokenization_spaces=True)
        key = sent.lower()
        if key in seen:
            continue
        seen.add(key)
        final_outputs.append(sent)
        if len(final_outputs) == MAX_PARAPHRASES:
            break

    st.subheader('Paraphrased Text')
    for final_output in final_outputs:
        st.write(final_output + '\n')