# cat_in_space/handler.py
# Custom inference handler. Created by freshcodestech (revision f7a076b, verified).
# handler.py
import os
import io
import base64
import torch
from diffusers import DiffusionPipeline
class EndpointHandler:
    """Hugging Face Inference Endpoints handler for an SDXL-style diffusers pipeline.

    Loads the pipeline once at startup and serves text-to-image requests,
    returning the generated image as a base64-encoded PNG.
    """

    def __init__(self, path=""):
        """Load the pipeline from *path* (the endpoint mounts the repo at /repository).

        Falls back to CPU/fp32 when CUDA is unavailable so the handler can still
        start (the original hard-coded ``.to("cuda")`` crashes on CPU-only hosts).
        """
        model_dir = path or "/repository"
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        # fp16 on CPU is unsupported/unstable for many ops — only use it on GPU.
        dtype = torch.float16 if self.device == "cuda" else torch.float32
        self.pipe = DiffusionPipeline.from_pretrained(
            model_dir,
            torch_dtype=dtype,
            use_safetensors=True,
        ).to(self.device)
        # Progress bars only pollute the endpoint logs.
        self.pipe.set_progress_bar_config(disable=True)

    def __call__(self, data: dict) -> dict:
        """Generate one image from a request payload.

        Accepts either ``{"inputs": "..."}`` or ``{"prompt": "..."}`` plus an
        optional ``"parameters"`` dict (width, height, num_inference_steps,
        guidance_scale, negative_prompt, seed).

        Returns:
            ``{"image_base64": "<base64-encoded PNG>"}``.
        """
        prompt = data.get("inputs") or data.get("prompt") or ""
        params = data.get("parameters") or {}
        width = int(params.get("width", 768))
        height = int(params.get("height", 768))
        steps = int(params.get("num_inference_steps", 25))
        guidance = float(params.get("guidance_scale", 7.0))
        negative = params.get("negative_prompt")
        seed = params.get("seed")
        generator = None
        if seed is not None:
            # Seed a generator on the same device the pipeline runs on,
            # so results are reproducible on both GPU and CPU fallback.
            generator = torch.Generator(device=self.device).manual_seed(int(seed))
        image = self.pipe(
            prompt=prompt,
            negative_prompt=negative,
            width=width,
            height=height,
            num_inference_steps=steps,
            guidance_scale=guidance,
            generator=generator,
        ).images[0]
        # Serialize the PIL image to an in-memory PNG and base64-encode it.
        buf = io.BytesIO()
        image.save(buf, format="PNG")
        return {"image_base64": base64.b64encode(buf.getvalue()).decode("utf-8")}