Plat
chore: better default value
f712c8b
import spaces
import json
import yaml
import os
import torch
import gradio as gr
from huggingface_hub import hf_hub_download
from model.pipeline import JiTModel, JiTConfig
from model.config import ClassContextConfig
# Hugging Face Hub locations of the model artifacts. Each one can be
# overridden with an environment variable of the same name.
MODEL_REPO = os.getenv("MODEL_REPO", "p1atdev/JiT-AnimeFace-experiment")
MODEL_PATH = os.getenv(
    "MODEL_PATH",
    "jit-b256-p16-cls/12-jit-animeface_00043e_033368s.safetensors",
)
LABEL2ID_PATH = os.getenv("LABEL2ID_PATH", "jit-b256-p16-cls/label2id.json")
CONFIG_PATH = os.getenv("CONFIG_PATH", "jit-b256-p16-cls/config.yml")
# Pick the best available accelerator: CUDA first, then Apple MPS, then CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
elif torch.backends.mps.is_available():
    DEVICE = torch.device("mps")
else:
    DEVICE = torch.device("cpu")

# bfloat16 on CUDA, float16 everywhere else.
# NOTE(review): float16 is also selected on CPU; half-precision/autocast
# support on CPU varies across torch versions — confirm this is intended.
DTYPE = torch.bfloat16 if DEVICE.type == "cuda" else torch.float16

# Upper bound on prompt tokens, forwarded to model.generate(max_token_length=...).
MAX_TOKEN_LENGTH = 32
# Process-level caches filled by load_model so that repeated requests reuse
# artifacts that are already loaded.
model_map: dict[str, JiTModel] = dict()  # model_path -> loaded JiTModel
label2id_map: dict[str, dict] = dict()  # label2id_path -> parsed label2id mapping
def get_file_path(repo: str, path: str) -> str:
    """Fetch ``path`` from the Hugging Face Hub repository ``repo``.

    Args:
        repo: Hub repository id.
        path: File path inside that repository.

    Returns:
        The local file path returned by ``hf_hub_download``.
    """
    local_path = hf_hub_download(repo, path)
    return local_path
def load_label2id(label2id_path: str) -> dict:
    """Load a ``label2id.json`` file from a local path.

    Args:
        label2id_path: Local path to the JSON file.

    Returns:
        The parsed JSON object (presumably a mapping of tag -> id; verify
        against the file's producer).
    """
    # JSON is UTF-8 by specification. Being explicit keeps non-ASCII tag names
    # from depending on the platform's default locale encoding.
    with open(label2id_path, "r", encoding="utf-8") as f:
        return json.load(f)
def load_config(config_path: str) -> JiTConfig:
    """Load a model configuration file and validate it into a ``JiTConfig``.

    Args:
        config_path: Local path (str or PathLike) to a ``.json``, ``.yaml`` or
            ``.yml`` file. The extension is matched case-insensitively.

    Returns:
        The file contents validated through ``JiTConfig.model_validate``.

    Raises:
        ValueError: If the extension is not a supported format. The extension
            is checked before the file is opened, so an unsupported path always
            gets this error rather than an I/O error.
    """
    suffix = os.path.splitext(os.fspath(config_path))[1].lower()
    if suffix == ".json":
        parse = json.load
    elif suffix in (".yaml", ".yml"):
        # safe_load only builds plain YAML types, never arbitrary Python objects.
        parse = yaml.safe_load
    else:
        raise ValueError("Unsupported config file format. Use .json or .yaml/.yml")

    # Explicit UTF-8 so parsing does not depend on the platform locale.
    with open(config_path, "r", encoding="utf-8") as f:
        config_dict = parse(f)
    return JiTConfig.model_validate(config_dict)
def load_model(
    model_path: str,
    label2id_path: str,
    config_path: str,
    device: torch.device,
    dtype: torch.dtype = DTYPE,
) -> tuple[JiTModel, dict]:
    """Load (or fetch from cache) the JiT model and its label2id mapping.

    All files are downloaded from ``MODEL_REPO`` via ``get_file_path``.

    Args:
        model_path: Checkpoint path inside the Hub repo. This is the model
            cache key.
        label2id_path: label2id JSON path inside the Hub repo. This is the
            label2id cache key.
        config_path: Config file path inside the Hub repo (.json/.yaml/.yml).
        device: Device the newly loaded model is moved to.
        dtype: Dtype the newly loaded model is cast to.

    Returns:
        ``(model, label2id)``. The model is in eval mode with gradients
        disabled.

    Note:
        The model cache is keyed only by ``model_path``. On a cache hit,
        ``config_path``, ``device`` and ``dtype`` are ignored and the cached
        model is returned as-is.
    """
    model = model_map.get(model_path)
    if model is None:
        config = load_config(get_file_path(MODEL_REPO, config_path))
        if isinstance(config.context_encoder, ClassContextConfig):
            # The class-context encoder reads its label map from a local
            # file, so point it at the downloaded label2id.json.
            config.context_encoder.label2id_map_path = get_file_path(
                MODEL_REPO, label2id_path
            )
        model = JiTModel.from_pretrained(
            config=config,
            checkpoint_path=get_file_path(MODEL_REPO, model_path),
        )
        model.eval()
        model.requires_grad_(False)
        model.to(device=device, dtype=dtype)
        model_map[model_path] = model  # cache

    # Looked up independently of the model cache. A cached model paired with a
    # new label2id_path, or a previous load that failed after caching the model,
    # must not raise KeyError here.
    label2id = label2id_map.get(label2id_path)
    if label2id is None:
        label2id = load_label2id(get_file_path(MODEL_REPO, label2id_path))
        label2id_map[label2id_path] = label2id  # cache

    return model, label2id
@spaces.GPU(duration=6)
def generate_images(
    prompt: str,
    negative_prompt: str,
    num_steps: int,
    cfg_scale: float,
    batch_size: int,
    size: int,
    seed: int,
    # Parameters below are not wired to UI inputs; they default to the env config.
    model_path: str = MODEL_PATH,
    label2id_path: str = LABEL2ID_PATH,
    config_path: str = CONFIG_PATH,
    progress=gr.Progress(track_tqdm=True),
):
    """Generate a batch of square images for a single tag prompt.

    Args:
        prompt: Space-separated tag prompt, repeated once per image in the batch.
        negative_prompt: Space-separated tags passed as the negative prompt.
        num_steps: Number of inference steps.
        cfg_scale: Classifier-free guidance scale.
        batch_size: Number of images to generate.
        size: Output height and width in pixels.
        seed: A non-negative value is used as the seed. A negative value or
            ``None`` (e.g. a cleared number box) passes ``seed=None``.
        model_path: Hub checkpoint path forwarded to ``load_model``.
        label2id_path: Hub label2id path forwarded to ``load_model``.
        config_path: Hub config path forwarded to ``load_model``.
        progress: Gradio progress tracker. It is unused in the body and is
            presumably declared so that ``track_tqdm=True`` surfaces the
            model's progress bars in the UI.

    Returns:
        Whatever ``model.generate`` returns. The UI shows it in a Gallery.
    """
    # Gradio number widgets can deliver floats (gr.Number without `precision`
    # yields e.g. 42.0) or None. List repetition, step counts, sizes and RNG
    # seeding all need ints. Check the seed's sign before truncating so that
    # negative fractions still mean "random".
    batch_size = int(batch_size)
    num_steps = int(num_steps)
    size = int(size)
    seed_value = int(seed) if seed is not None and seed >= 0 else None

    model, _label2id = load_model(
        model_path=model_path,
        label2id_path=label2id_path,
        config_path=config_path,
        device=DEVICE,
        dtype=DTYPE,
    )
    with torch.inference_mode(), torch.autocast(device_type=DEVICE.type, dtype=DTYPE):
        images = model.generate(
            prompt=[prompt] * batch_size,
            negative_prompt=negative_prompt,
            num_inference_steps=num_steps,
            cfg_scale=cfg_scale,
            height=size,
            width=size,
            max_token_length=MAX_TOKEN_LENGTH,
            # NOTE(review): presumably restricts CFG to this part of the
            # sampling time range; confirm against JiTModel.generate.
            cfg_time_range=[0.1, 1.0],
            seed=seed_value,
            device=DEVICE,
            execution_dtype=DTYPE,
        )
    return images
# Browser link to the label2id.json listing every supported tag.
LABEL2ID_URL = "https://huggingface.co/{}/blob/main/{}".format(MODEL_REPO, LABEL2ID_PATH)
def demo():
    """Build the Gradio Blocks UI for the demo.

    Returns:
        The ``gr.Blocks`` app (not launched). Clicking "Generate Images" or
        submitting the prompt box calls ``generate_images`` with the widget
        values.
    """
    with gr.Blocks() as ui:
        gr.Markdown(f"""
# JiT-AnimeFace Demo
Pixel-space x-prediction flow-matching 90M parameter model for anime face generation, trained from scratch.
- See full supported tags: [label2id.json]({LABEL2ID_URL}). 対応しているタグ一覧は [こちら]({LABEL2ID_URL}) から確認できます。ここに載っていないタグは反応しません。
- Current model: [{MODEL_PATH}](https://huggingface.co/{MODEL_REPO}/blob/main/{MODEL_PATH})
""")
        with gr.Row():
            # Left column: generation controls.
            with gr.Column():
                prompt = gr.TextArea(
                    label="Prompt",
                    info=f"Space-separated tags. Not all of danbooru tags are supported. See [the full supported tags]({LABEL2ID_URL}). スペースで区切ってください。カンマ区切りは対応してません。",
                    value="general 1girl solo portrait looking_at_viewer medium_hair parted_lips blue_ribbon hair_ornament hairclip half_updo halterneck bokeh depth_of_field blurry_background head_tilt",
                    placeholder="e.g.: general 1girl solo portrait looking_at_viewer",
                )
                negative_prompt = gr.TextArea(
                    label="Negative Prompt",
                    info="Space-separated negative tags to avoid in generation. スペースで区切ってください。カンマ区切りは対応してません。",
                    value="retro_artstyle 1990s_(style) sketch",
                    lines=2,
                    placeholder="e.g.: retro_artstyle 1990s_(style) sketch",
                )
                num_steps = gr.Slider(
                    minimum=1,
                    maximum=100,
                    value=25,
                    step=1,
                    label="Number of Steps",
                    info="Recommended: more than 20 steps for better quality.",
                )
                cfg_scale = gr.Slider(
                    minimum=1.0,
                    maximum=10.0,
                    value=5.0,
                    step=0.25,
                    label="CFG Scale",
                    info="Recommended: more than 2.0 for better adherence to the prompt.",
                )
                batch_size = gr.Slider(
                    minimum=1,
                    maximum=64,
                    value=25,
                    step=1,
                    label="Batch Size",
                    info="Number of images to generate in one batch.",
                )
                size = gr.Slider(
                    minimum=64,
                    maximum=320,
                    value=256,
                    step=64,
                    label="Image Size",
                    info="Only 256x256 is supported in the current model. Other sizes may cause quality degradation.",
                )
                # precision=0 makes Gradio round the value and deliver an int.
                # Without it the seed arrives as a float (e.g. 42.0).
                seed = gr.Number(
                    value=-1,
                    label="Seed (-1 for random)",
                    precision=0,
                )
            # Right column: trigger button and results.
            with gr.Column(scale=2):
                generate_button = gr.Button("Generate Images", variant="primary")
                output_gallery = gr.Gallery(
                    label="Generated Images",
                    columns=5,
                    height="768px",
                    preview=False,
                    show_label=True,
                )
        # Example rows fill only the prompt and negative prompt boxes.
        gr.Examples(
            examples=[
                [
                    "general 1girl solo portrait looking_at_viewer medium_hair parted_lips blue_ribbon hair_ornament hairclip half_updo halterneck bokeh depth_of_field blurry_background head_tilt",
                    "retro_artstyle 1990s_(style) sketch",
                ],
                [
                    "general 1girl solo portrait looking_at_viewer",
                    "retro_artstyle 1990s_(style) sketch",
                ],
                [
                    "general 1girl solo portrait looking_at_viewer blue_hair short_hair blush open_mouth cat_ears animal_ears red_eyes white_background",
                    "retro_artstyle 1990s_(style) sketch",
                ],
                [
                    "general 1girl aqua_eyes baseball_cap blonde_hair closed_mouth earrings green_background hat jewelry looking_at_viewer shirt short_hair simple_background solo portrait yellow_shirt",
                    "retro_artstyle 1990s_(style) sketch",
                ],
                [
                    "general 1girl solo portrait looking_at_viewer brown_hair ahoge long_hair :| expressionless closed_mouth swept_bangs pink_eyes pink_background simple_background dutch_angle",
                    "retro_artstyle 1990s_(style) sketch smile",
                ],
                [
                    "general 1girl solo portrait looking_at_viewer hatsune_miku twintails long_hair blue_eyes one_eye_closed simple_background green_background",
                    "retro_artstyle 1990s_(style) sketch",
                ],
                [
                    "general 1girl portrait looking_at_viewer sketch head_tilt white_background monochrome open_mouth long_hair",
                    "retro_artstyle 1990s_(style)",
                ],
                [
                    "general 1girl solo from_behind short_hair simple_background black_background",
                    "retro_artstyle 1990s_(style) sketch",
                ],
                [
                    "general 1girl portrait looking_to_the_side glasses",
                    "retro_artstyle 1990s_(style) sketch",
                ],
                [
                    "general 1girl portrait looking_at_viewer cat_ears purple_theme ;d forehead animal_ears animal_ear_fluff cat_ears",
                    "retro_artstyle 1990s_(style) sketch",
                ],
            ],
            inputs=[prompt, negative_prompt],
            label="Examples",
            examples_per_page=20,
        )
        # Input order must match generate_images' positional parameters.
        gr.on(
            triggers=[generate_button.click, prompt.submit],
            fn=generate_images,
            inputs=[
                prompt,
                negative_prompt,
                num_steps,
                cfg_scale,
                batch_size,
                size,
                seed,
            ],
            outputs=output_gallery,
        )
    return ui
if __name__ == "__main__":
    # Warm the cache before serving, so the first request does not pay the
    # download and model-load cost.
    load_model(
        model_path=MODEL_PATH,
        label2id_path=LABEL2ID_PATH,
        config_path=CONFIG_PATH,
        dtype=DTYPE,
        device=DEVICE,
    )
    app = demo()
    app.launch()