| import gradio as gr |
| from random import randint |
| from all_models import models |
| from externalmod import gr_Interface_load |
| import asyncio |
| import os |
| from threading import RLock |
|
|
# Module-wide re-entrant lock.
lock = RLock()
# Token read from the HF_TOKEN environment variable (None when unset); it is
# passed to model loading (hf_token=) and to each inference call (token=).
HF_TOKEN = os.getenv("HF_TOKEN")
|
|
def load_fn(models):
    """Load every model id in *models* into the module-level ``models_load``.

    ``models_load`` is rebound to a fresh dict on each call, mapping each model
    id to the interface returned by ``gr_Interface_load('models/<id>')``.
    Duplicate ids are loaded only once.

    If loading a model raises, the error is printed and a placeholder
    ``gr.Interface`` is stored instead. The placeholder's ``fn`` accepts any
    arguments and returns ``None``, so later lookups and calls on that model
    still work and simply produce no image.

    Args:
        models: Iterable of model ids (presumably Hugging Face repo ids, given
            the ``models/`` prefix — confirm against ``all_models``).
    """
    global models_load
    models_load = {}

    for model in models:
        if model in models_load:
            continue  # duplicate id: already loaded (or stubbed)
        try:
            m = gr_Interface_load(f'models/{model}', hf_token=HF_TOKEN)
        except Exception as error:
            print(f"Failed to load {model}: {error}")
            # Accept any call shape: inference calls fn(prompt=..., seed=...,
            # token=...), which a zero-argument lambda would reject with a
            # TypeError instead of cleanly yielding "no result".
            m = gr.Interface(lambda *args, **kwargs: None, ['text'], ['image'])
        models_load[model] = m
|
|
# Populate models_load at import time so later inference can look models up.
load_fn(models)


num_models = 6
# Largest seed value; not referenced by the code visible in this module.
MAX_SEED = 3_999_999_999
# The first num_models entries of the model list.
default_models = models[:num_models]
# Default timeout, in seconds, handed to asyncio.wait_for by infer().
inference_timeout = 600
|
|
async def infer(model_str, prompt, seed=1, timeout=inference_timeout):
    """Run one loaded model on *prompt* and save its output as a PNG file.

    The model's ``fn`` is called in a worker thread (``asyncio.to_thread``)
    with ``prompt``, ``seed`` and ``HF_TOKEN`` as keyword arguments, and is
    given at most *timeout* seconds to finish.

    Args:
        model_str: Key into ``models_load`` selecting the model to run.
        prompt: Text prompt, forwarded as ``prompt=``.
        seed: Forwarded unchanged as ``seed=``.
        timeout: Seconds to wait before giving up (``asyncio.wait_for``).

    Returns:
        Path of a newly created ``.png`` file, or ``None`` when the call timed
        out, raised, or returned ``None``.

    Raises:
        KeyError: if *model_str* is not a key of ``models_load``.
    """
    import tempfile  # only this function needs it

    task = asyncio.create_task(
        asyncio.to_thread(models_load[model_str].fn, prompt=prompt, seed=seed, token=HF_TOKEN)
    )
    try:
        result = await asyncio.wait_for(task, timeout=timeout)
    except asyncio.TimeoutError:
        # wait_for has already cancelled the asyncio task, but the worker
        # thread started by to_thread cannot be interrupted and runs on.
        print(f"Task timed out: {model_str}")
        return None
    except Exception as e:
        # Report genuine failures separately from timeouts.
        print(e)
        print(f"Task failed: {model_str}")
        return None

    if result is None:
        return None

    # Unique file per call: one shared "image.png" could be overwritten by a
    # concurrent request before the caller reads the returned path.
    # NOTE(review): files are not cleaned up here; presumably Gradio copies
    # file outputs into its own cache — confirm and add cleanup if needed.
    fd, png_path = tempfile.mkstemp(suffix=".png")
    os.close(fd)
    try:
        # assumes result is a PIL-style image exposing .save(path) — TODO
        # confirm against externalmod.gr_Interface_load.
        result.save(png_path)
    except Exception:
        os.remove(png_path)  # don't leave an empty temp file behind
        raise
    return png_path
|
|
| |
def generate_api(model_str, prompt, seed=1):
    """Synchronous entry point for the Gradio interface: run ``infer``.

    Args:
        model_str: Model id; must be a key of ``models_load``.
        prompt: Text prompt for the model.
        seed: Seed forwarded to the model. Gradio's ``"number"`` input returns
            a float when no ``precision`` is set, so whole-number floats
            (e.g. ``42.0``) are converted to ``int``. Any other value is passed
            through unchanged.

    Returns:
        Path of the generated PNG file, or ``None`` if generation failed.
    """
    if isinstance(seed, float) and seed.is_integer():
        seed = int(seed)
    # asyncio.run starts a fresh event loop. This assumes the calling thread
    # has no loop already running (sync Gradio handlers are presumably run in
    # worker threads — confirm if this is ever called from async code).
    return asyncio.run(infer(model_str, prompt, seed))
|
|
| |
# Three inputs map positionally onto generate_api(model_str, prompt, seed);
# the returned path is served as a file output.
iface = gr.Interface(
    fn=generate_api,
    inputs=["text", "text", "number"],
    outputs="file",
)
# Launched at import time, with the API page enabled and a share link requested.
iface.launch(share=True, show_api=True)