File size: 1,646 Bytes
f7a076b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48

# handler.py
import os
import io
import base64
import torch
from diffusers import DiffusionPipeline

class EndpointHandler:
    """Custom Inference Endpoints handler serving an SDXL text-to-image model.

    The pipeline is loaded once at container startup; each call generates a
    single image and returns it as a base64-encoded PNG inside a JSON dict.
    """

    def __init__(self, path=""):
        # The default container mounts your repo at /repository
        repo_dir = path if path else "/repository"
        # Load the SDXL pipeline in fp16 — no device_map, no CPU offloading —
        # then move it wholesale onto the GPU.
        pipeline = DiffusionPipeline.from_pretrained(
            repo_dir,
            torch_dtype=torch.float16,
            use_safetensors=True,
        )
        self.pipe = pipeline.to("cuda")
        # Progress bars only pollute container logs.
        self.pipe.set_progress_bar_config(disable=True)

    def __call__(self, data: dict):
        """Handle one generation request.

        Accepts either {"inputs": "..."} or {"prompt": "..."} plus an
        optional "parameters" dict (width, height, num_inference_steps,
        guidance_scale, negative_prompt, seed).

        Returns:
            {"image_base64": <base64-encoded PNG>}
        """
        prompt = data.get("inputs") or data.get("prompt") or ""
        opts = data.get("parameters") or {}

        # A fixed seed makes generation reproducible; otherwise let the
        # pipeline draw its own randomness.
        seed = opts.get("seed")
        if seed is None:
            generator = None
        else:
            generator = torch.Generator(device="cuda").manual_seed(int(seed))

        result = self.pipe(
            prompt=prompt,
            negative_prompt=opts.get("negative_prompt"),
            width=int(opts.get("width", 768)),
            height=int(opts.get("height", 768)),
            num_inference_steps=int(opts.get("num_inference_steps", 25)),
            guidance_scale=float(opts.get("guidance_scale", 7.0)),
            generator=generator,
        )
        image = result.images[0]

        # Serialize to PNG in memory so the response stays pure JSON.
        buffer = io.BytesIO()
        image.save(buffer, format="PNG")
        encoded = base64.b64encode(buffer.getvalue()).decode("utf-8")
        return {"image_base64": encoded}