| |
|
| | |
| | import os |
| | import io |
| | import base64 |
| | import torch |
| | from diffusers import DiffusionPipeline |
| |
|
class EndpointHandler:
    """Serve a ``diffusers`` text-to-image pipeline behind a callable handler.

    The pipeline is loaded once in ``__init__``; each call renders one image
    from the request payload and returns it as a base64-encoded PNG string.
    """

    def __init__(self, path=""):
        """Load the pipeline from ``path`` (falls back to ``"/repository"``).

        The pipeline is placed on CUDA in float16 when a GPU is available.
        Otherwise it is placed on CPU in float32, since diffusers warns that
        float16 pipelines cannot run on the CPU device.
        """
        model_dir = path or "/repository"
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        # Half precision only on GPU; CPU execution needs full precision.
        dtype = torch.float16 if self.device == "cuda" else torch.float32

        self.pipe = DiffusionPipeline.from_pretrained(
            model_dir,
            torch_dtype=dtype,
            use_safetensors=True,
        ).to(self.device)
        # Per-step progress bars are just noise in server logs.
        self.pipe.set_progress_bar_config(disable=True)

    def __call__(self, data: dict):
        """Generate one image for the request ``data``.

        Args:
            data: Request payload. The prompt is read from ``"inputs"`` or,
                failing that, ``"prompt"`` (empty string if neither is set).
                Optional settings are read from the ``"parameters"`` dict:
                ``width``, ``height``, ``num_inference_steps``,
                ``guidance_scale``, ``negative_prompt`` and ``seed``.

        Returns:
            ``{"image_base64": <base64-encoded PNG as str>}``.
        """
        pipe_kwargs, seed = self._parse_request(data)
        # Seeded generator lives on the pipeline's device (CUDA on GPU hosts).
        generator = self._make_generator(seed, self.device)
        image = self.pipe(generator=generator, **pipe_kwargs).images[0]
        return {"image_base64": self._encode_png(image)}

    @staticmethod
    def _param(params, key, default, cast):
        """Return ``cast(params[key])``, or ``default`` if the key is missing or null."""
        value = params.get(key)
        # JSON clients may send an explicit null to mean "use the default".
        return default if value is None else cast(value)

    @classmethod
    def _parse_request(cls, data: dict):
        """Split a request payload into pipeline keyword arguments and a seed.

        Returns:
            ``(pipe_kwargs, seed)``. ``pipe_kwargs`` holds the prompt,
            negative prompt, size, step count and guidance scale. ``seed`` is
            the raw ``"seed"`` parameter, or ``None`` if it is absent.
        """
        prompt = data.get("inputs") or data.get("prompt") or ""
        params = data.get("parameters") or {}

        pipe_kwargs = {
            "prompt": prompt,
            "negative_prompt": params.get("negative_prompt"),
            "width": cls._param(params, "width", 768, int),
            "height": cls._param(params, "height", 768, int),
            "num_inference_steps": cls._param(params, "num_inference_steps", 25, int),
            "guidance_scale": cls._param(params, "guidance_scale", 7.0, float),
        }
        return pipe_kwargs, params.get("seed")

    @staticmethod
    def _make_generator(seed, device):
        """Return a ``torch.Generator`` on ``device`` seeded with ``int(seed)``.

        Returns ``None`` when ``seed`` is ``None``, which lets the pipeline
        draw unseeded noise.
        """
        if seed is None:
            return None
        return torch.Generator(device=device).manual_seed(int(seed))

    @staticmethod
    def _encode_png(image):
        """Save ``image`` as PNG (via ``image.save``) and return it as base64 text."""
        buf = io.BytesIO()
        image.save(buf, format="PNG")
        return base64.b64encode(buf.getvalue()).decode("utf-8")
| |
|