import json
import os
import re
import shutil
from pprint import pprint

import chromadb
import gradio as gr
import numpy as np
import pandas as pd
import requests
import torch
import transformers
from sklearn.metrics.pairwise import cosine_similarity
from transformers import AutoConfig, AutoTokenizer, AutoModel, AutoModelForCausalLM
from vllm import LLM, SamplingParams

pd.set_option('display.max_columns', None)
sampling_params = SamplingParams(
    temperature=0.7,
    top_p=0.95,
    max_tokens=2000,
    presence_penalty=1.5,
    stop=["``", "### Fin ###", "<|eot_id|>"],
)
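# Decoding notes: temperature scales the logits and top_p applies nucleus
# sampling, max_tokens caps the completion length, presence_penalty discourages
# re-using tokens that already appeared, and generation halts as soon as any of
# the stop strings is produced.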
# Define the device
device = "cuda" if torch.cuda.is_available() else "cpu"

# Define generation variables
temperature = 0.2
max_new_tokens = 1000
top_p = 0.92
repetition_penalty = 1.7

model_name = "Inagua/code-model-2"
llm = LLM(model_name, max_model_len=4096)
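# max_model_len caps the total context (prompt plus generated tokens) that vLLM
# allocates KV cache for; prompts longer than this limit are rejected.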
# CSS for reference formatting
css = """
.generation {
    margin-left: 2em;
    margin-right: 2em;
}

:target {
    background-color: #CCF3DF; /* Highlight the reference targeted by an anchor link */
}

.source {
    float: left;
    max-width: 17%;
    margin-left: 2%;
}

.tooltip {
    position: relative;
    cursor: pointer;
    font-variant-position: super;
    color: #97999b;
}

.tooltip:hover::after {
    content: attr(data-text);
    position: absolute;
    left: 0;
    top: 120%; /* Controls the vertical spacing between the text and the tooltip */
    white-space: pre-wrap; /* Allows the tooltip text to wrap */
    width: 500px; /* Fixed width for the tooltip */
    max-width: 500px; /* Ensures the tooltip does not exceed this width */
    z-index: 1;
    background-color: #f9f9f9;
    color: #000;
    border: 1px solid #ddd;
    border-radius: 5px;
    padding: 5px;
    display: block;
    box-shadow: 0 4px 8px rgba(0,0,0,0.1); /* Subtle shadow for better visibility */
}"""
# Courtesy of ChatGPT
def format_references(text):
    """Convert <ref text="...">id</ref> tags into numbered HTML tooltip links."""
    # Define start and end markers for the reference
    ref_start_marker = '<ref text="'
    ref_end_marker = '</ref>'

    # Accumulate parts of the rewritten text
    parts = []
    current_pos = 0
    ref_number = 1

    # Loop until no more reference start markers are found
    while True:
        start_pos = text.find(ref_start_marker, current_pos)
        if start_pos == -1:
            # No more references found, add the rest of the text
            parts.append(text[current_pos:])
            break

        # Add text up to the start of the reference
        parts.append(text[current_pos:start_pos])

        # Find the end of the reference text attribute
        end_pos = text.find('">', start_pos)
        if end_pos == -1:
            # Malformed reference, break to avoid an infinite loop
            break

        # Extract the reference text and escape it for use in an HTML attribute
        ref_text = text[start_pos + len(ref_start_marker):end_pos].replace('\n', ' ').strip()
        ref_text_encoded = (
            ref_text.replace("&", "&amp;")
            .replace("<", "&lt;")
            .replace(">", "&gt;")
            .replace('"', "&quot;")
        )

        # Find the end of the reference tag
        ref_end_pos = text.find(ref_end_marker, end_pos)
        if ref_end_pos == -1:
            # Malformed reference, break to avoid an infinite loop
            break

        # Extract the reference ID
        ref_id = text[end_pos + 2:ref_end_pos].strip()

        # Create the HTML for the tooltip
        tooltip_html = (
            f'<span class="tooltip" data-refid="{ref_id}" data-text="{ref_id}: {ref_text_encoded}">'
            f'<a href="#{ref_id}">[{ref_number}]</a></span>'
        )
        parts.append(tooltip_html)

        # Move past the current reference
        current_pos = ref_end_pos + len(ref_end_marker)
        ref_number += 1

    # Join and return the parts
    return ''.join(parts)
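# Illustrative example (hypothetical input, not produced by this app):
#   format_references('See <ref text="Relevant excerpt">doc-1</ref> for details.')
# returns the text with the tag replaced by a numbered tooltip link, roughly:
#   'See <span class="tooltip" data-refid="doc-1" data-text="doc-1: Relevant excerpt">'
#   '<a href="#doc-1">[1]</a></span> for details.'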
# Class to encapsulate the Mistral chatbot
class MistralChatBot:
    def __init__(self, system_prompt="Le dialogue suivant est une conversation"):
        self.system_prompt = system_prompt

    def predict_field(self, user_message):
        """Classify the question into a documentation field."""
        detailed_prompt = "### Question ###\n" + user_message + "\n\n### Field ###\n"
        outputs = llm.generate([detailed_prompt], sampling_params, use_tqdm=False)
        generated_text = outputs[0].outputs[0].text
        print(generated_text)
        return generated_text

    def predict_answer(self, user_message, context):
        """Generate a formula for the question given a documentation context."""
        detailed_prompt = (
            "### Question ###\n" + user_message
            + "\n\n### Contexte ###\n" + context
            + "\n\n### Formule ###\n"
        )
        outputs = llm.generate([detailed_prompt], sampling_params, use_tqdm=False)
        generated_text = outputs[0].outputs[0].text
        print(generated_text)
        return generated_text
# Create the Mistral chatbot instance
mistral_bot = MistralChatBot()
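# Minimal usage sketch (illustrative; the question and context are made up):
#   field = mistral_bot.predict_field("How to calculate a linear regression?")
#   answer = mistral_bot.predict_answer(
#       "How to calculate a linear regression?",
#       "...relevant DAMAaaS documentation excerpt...",
#   )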
# Define the Gradio interface metadata (declared here but not attached to the Blocks UI below)
title = "Inagua"
description = "An experimental LLM to interact with DAMAaaS documentation"
examples = [
    [
        "How to calculate a linear regression?",  # user_message
        0.7,  # temperature
    ]
]

additional_inputs = [
    gr.Slider(
        label="Temperature",
        value=0.2,  # Default value
        minimum=0.05,
        maximum=1.0,
        step=0.05,
        interactive=True,
        info="Higher values give more creativity, but also more strangeness",
    ),
]
with gr.Blocks(theme='gradio/monochrome') as demo:
    gr.HTML("""<h1 style="text-align:center">InaguaLLM</h1>""")
    text_input = gr.Textbox(label="Your question", type="text", lines=1)
    context_input = gr.Textbox(label="Your context", type="text", lines=1)
    text_button = gr.Button("Query InaguaLLM")
    field_output = gr.HTML(label="Field")
    text_output = gr.HTML(label="Answer")
    text_button.click(mistral_bot.predict_field, inputs=[text_input], outputs=[field_output], api_name="convert-question")
    text_button.click(mistral_bot.predict_answer, inputs=[text_input, context_input], outputs=[text_output], api_name="convert-code")

if __name__ == "__main__":
    demo.queue().launch()
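# Minimal client sketch (illustrative; the local URL is an assumption based on
# Gradio's default address, not part of this script):
#   from gradio_client import Client
#   client = Client("http://127.0.0.1:7860")
#   field = client.predict("How to calculate a linear regression?",
#                          api_name="/convert-question")
#   answer = client.predict("How to calculate a linear regression?",
#                           "...documentation excerpt...",
#                           api_name="/convert-code")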