Spaces:
Runtime error
Runtime error
| from config import SHARE, MODELS, TRAINING_PARAMS, LORA_TRAINING_PARAMS, GENERATION_PARAMS | |
| import os | |
| import gradio as gr | |
| import random | |
| from trainer import Trainer | |
| LORA_DIR = 'lora' | |
| def random_name(): | |
| fruits = [ | |
| "dragonfruit", "kiwano", "rambutan", "durian", "mangosteen", | |
| "jabuticaba", "pitaya", "persimmon", "acai", "starfruit" | |
| ] | |
| return '-'.join(random.sample(fruits, 3)) | |
| class UI(): | |
| def __init__(self): | |
| self.trainer = Trainer() | |
| def load_loras(self): | |
| loaded_model_name = self.trainer.model_name | |
| if os.path.exists(LORA_DIR) and loaded_model_name is not None: | |
| loras = [f for f in os.listdir(LORA_DIR)] | |
| sanitized_model_name = loaded_model_name.replace('/', '_').replace('.', '_') | |
| loras = [f for f in loras if f.startswith(sanitized_model_name)] | |
| loras.insert(0, 'None') | |
| return gr.Dropdown.update(choices=loras) | |
| else: | |
| return gr.Dropdown.update(choices=['None'], value='None') | |
| def training_params_block(self): | |
| with gr.Row(): | |
| with gr.Column(): | |
| self.max_seq_length = gr.Slider( | |
| interactive=True, | |
| minimum=1, maximum=4096, value=TRAINING_PARAMS['max_seq_length'], | |
| label="Max Sequence Length", | |
| ) | |
| self.micro_batch_size = gr.Slider( | |
| minimum=1, maximum=100, step=1, value=TRAINING_PARAMS['micro_batch_size'], | |
| label="Micro Batch Size", | |
| ) | |
| self.gradient_accumulation_steps = gr.Slider( | |
| minimum=1, maximum=128, step=1, value=TRAINING_PARAMS['gradient_accumulation_steps'], | |
| label="Gradient Accumulation Steps", | |
| ) | |
| self.epochs = gr.Slider( | |
| minimum=1, maximum=100, step=1, value=TRAINING_PARAMS['epochs'], | |
| label="Epochs", | |
| ) | |
| self.learning_rate = gr.Slider( | |
| minimum=0.00001, maximum=0.01, value=TRAINING_PARAMS['learning_rate'], | |
| label="Learning Rate", | |
| ) | |
| with gr.Column(): | |
| self.lora_r = gr.Slider( | |
| minimum=1, maximum=64, step=1, value=LORA_TRAINING_PARAMS['lora_r'], | |
| label="LoRA R", | |
| ) | |
| self.lora_alpha = gr.Slider( | |
| minimum=1, maximum=128, step=1, value=LORA_TRAINING_PARAMS['lora_alpha'], | |
| label="LoRA Alpha", | |
| ) | |
| self.lora_dropout = gr.Slider( | |
| minimum=0, maximum=1, step=0.01, value=LORA_TRAINING_PARAMS['lora_dropout'], | |
| label="LoRA Dropout", | |
| ) | |
| def load_model(self, model_name, progress=gr.Progress(track_tqdm=True)): | |
| if model_name == '': return '' | |
| if model_name is None: return self.trainer.model_name | |
| progress(0, desc=f'Loading {model_name}...') | |
| self.trainer.load_model(model_name) | |
| return self.trainer.model_name | |
| def base_model_block(self): | |
| self.model_name = gr.Dropdown(label='Base Model', choices=MODELS) | |
| def training_data_block(self): | |
| training_text = gr.TextArea( | |
| lines=20, | |
| label="Training Data", | |
| info='Paste training data text here. Sequences must be separated with 2 blank lines' | |
| ) | |
| examples_dir = os.path.join(os.getcwd(), 'example-datasets') | |
| def load_example(filename): | |
| with open(os.path.join(examples_dir, filename) , 'r', encoding='utf-8') as f: | |
| return f.read() | |
| example_filename = gr.Textbox(visible=False) | |
| example_filename.change(fn=load_example, inputs=example_filename, outputs=training_text) | |
| gr.Examples("./example-datasets", inputs=example_filename) | |
| self.training_text = training_text | |
| def training_launch_block(self): | |
| with gr.Row(): | |
| with gr.Column(): | |
| self.new_lora_name = gr.Textbox(label='New PEFT Adapter Name', value=random_name()) | |
| with gr.Column(): | |
| train_button = gr.Button('Train', variant='primary') | |
| def train( | |
| training_text, | |
| new_lora_name, | |
| max_seq_length, | |
| micro_batch_size, | |
| gradient_accumulation_steps, | |
| epochs, | |
| learning_rate, | |
| lora_r, | |
| lora_alpha, | |
| lora_dropout, | |
| progress=gr.Progress(track_tqdm=True) | |
| ): | |
| self.trainer.unload_lora() | |
| self.trainer.train( | |
| training_text, | |
| new_lora_name, | |
| max_seq_length=max_seq_length, | |
| micro_batch_size=micro_batch_size, | |
| gradient_accumulation_steps=gradient_accumulation_steps, | |
| epochs=epochs, | |
| learning_rate=learning_rate, | |
| lora_r=lora_r, | |
| lora_alpha=lora_alpha, | |
| lora_dropout=lora_dropout | |
| ) | |
| return new_lora_name | |
| train_button.click( | |
| fn=train, | |
| inputs=[ | |
| self.training_text, | |
| self.new_lora_name, | |
| self.max_seq_length, | |
| self.micro_batch_size, | |
| self.gradient_accumulation_steps, | |
| self.epochs, | |
| self.learning_rate, | |
| self.lora_r, | |
| self.lora_alpha, | |
| self.lora_dropout, | |
| ], | |
| outputs=[self.new_lora_name] | |
| ).then( | |
| fn=lambda x: self.trainer.load_model(x, force=True), | |
| inputs=[self.model_name], | |
| outputs=[] | |
| ) | |
| def inference_block(self): | |
| with gr.Row(): | |
| with gr.Column(): | |
| self.lora_name = gr.Dropdown( | |
| interactive=True, | |
| choices=['None'], | |
| value='None', | |
| label='LoRA', | |
| ) | |
| def load_lora(lora_name, progress=gr.Progress(track_tqdm=True)): | |
| if lora_name == 'None': | |
| self.trainer.unload_lora() | |
| else: | |
| self.trainer.load_lora(f'{LORA_DIR}/{lora_name}') | |
| return lora_name | |
| self.lora_name.change( | |
| fn=load_lora, | |
| inputs=self.lora_name, | |
| outputs=self.lora_name | |
| ) | |
| self.prompt = gr.Textbox( | |
| interactive=True, | |
| lines=5, | |
| label="Prompt", | |
| value="Human: How is cheese made?\nAssistant:" | |
| ) | |
| self.generate_btn = gr.Button('Generate', variant='primary') | |
| with gr.Row(): | |
| with gr.Column(): | |
| self.max_new_tokens = gr.Slider( | |
| minimum=0, maximum=4096, step=1, value=GENERATION_PARAMS['max_new_tokens'], | |
| label="Max New Tokens", | |
| ) | |
| with gr.Column(): | |
| self.do_sample = gr.Checkbox( | |
| interactive=True, | |
| label="Enable Sampling (leave off for greedy search)", | |
| value=True, | |
| ) | |
| with gr.Row(): | |
| with gr.Column(): | |
| self.num_beams = gr.Slider( | |
| minimum=1, maximum=10, step=1, value=GENERATION_PARAMS['num_beams'], | |
| label="Num Beams", | |
| ) | |
| with gr.Column(): | |
| self.repeat_penalty = gr.Slider( | |
| minimum=0, maximum=4.5, step=0.01, value=GENERATION_PARAMS['repetition_penalty'], | |
| label="Repetition Penalty", | |
| ) | |
| with gr.Row(): | |
| with gr.Column(): | |
| self.temperature = gr.Slider( | |
| minimum=0.01, maximum=1.99, step=0.01, value=GENERATION_PARAMS['temperature'], | |
| label="Temperature", | |
| ) | |
| self.top_p = gr.Slider( | |
| minimum=0, maximum=1, step=0.01, value=GENERATION_PARAMS['top_p'], | |
| label="Top P", | |
| ) | |
| self.top_k = gr.Slider( | |
| minimum=0, maximum=200, step=1, value=GENERATION_PARAMS['top_k'], | |
| label="Top K", | |
| ) | |
| with gr.Column(): | |
| self.output = gr.Textbox( | |
| interactive=True, | |
| lines=20, | |
| label="Output" | |
| ) | |
| def generate( | |
| prompt, | |
| do_sample, | |
| max_new_tokens, | |
| num_beams, | |
| repeat_penalty, | |
| temperature, | |
| top_p, | |
| top_k, | |
| progress=gr.Progress(track_tqdm=True) | |
| ): | |
| return self.trainer.generate( | |
| prompt, | |
| do_sample=do_sample, | |
| max_new_tokens=max_new_tokens, | |
| num_beams=num_beams, | |
| repetition_penalty=repeat_penalty, | |
| temperature=temperature, | |
| top_p=top_p, | |
| top_k=top_k | |
| ) | |
| self.generate_btn.click( | |
| fn=generate, | |
| inputs=[ | |
| self.prompt, | |
| self.do_sample, | |
| self.max_new_tokens, | |
| self.num_beams, | |
| self.repeat_penalty, | |
| self.temperature, | |
| self.top_p, | |
| self.top_k | |
| ], | |
| outputs=[self.output] | |
| ) | |
| def layout(self): | |
| with gr.Blocks() as demo: | |
| with gr.Row(): | |
| with gr.Column(): | |
| gr.HTML("""<h2> | |
| <a style="text-decoration: none;" href="https://github.com/lxe/simple-llama-finetuner">🦙 Simple LLM Finetuner</a> <a href="https://huggingface.co/spaces/lxe/simple-llama-finetuner?duplicate=true"><img | |
| src="https://img.shields.io/badge/-Duplicate%20Space-blue?labelColor=white&style=flat&logo=&logoWidth=14" style="display:inline"> | |
| </a></h2><p>Finetune an LLM on your own text. Duplicate this space onto a GPU-enabled space to run.</p>""") | |
| with gr.Column(): | |
| self.base_model_block() | |
| with gr.Tab('Finetuning'): | |
| with gr.Row(): | |
| with gr.Column(): | |
| self.training_data_block() | |
| with gr.Column(): | |
| self.training_params_block() | |
| self.training_launch_block() | |
| with gr.Tab('Inference') as inference_tab: | |
| with gr.Row(): | |
| with gr.Column(): | |
| self.inference_block() | |
| inference_tab.select( | |
| fn=self.load_loras, | |
| inputs=[], | |
| outputs=[self.lora_name] | |
| ) | |
| self.model_name.change( | |
| fn=self.load_model, | |
| inputs=[self.model_name], | |
| outputs=[self.model_name] | |
| ).then( | |
| fn=self.load_loras, | |
| inputs=[], | |
| outputs=[self.lora_name] | |
| ) | |
| return demo | |
| def run(self): | |
| self.ui = self.layout() | |
| self.ui.queue().launch(show_error=True, share=SHARE) | |
| if (__name__ == '__main__'): | |
| ui = UI() | |
| ui.run() | |