# ============================================================================ # SD1.5-Flow-Sol Correct Inference (Colab Cell) # ============================================================================ # Matches trainer's sample() method exactly: # - DDPM scheduler timesteps # - Specifically aligned for the SOL training pipeline to ensure accurate inference. # - Model predicts velocity # - Convert velocity → epsilon for scheduler stepping # ============================================================================ !pip install -q diffusers transformers accelerate safetensors import torch import gc from huggingface_hub import hf_hub_download from diffusers import UNet2DConditionModel, AutoencoderKL, DDPMScheduler from transformers import CLIPTextModel, CLIPTokenizer from PIL import Image import numpy as np torch.cuda.empty_cache() gc.collect() # ============================================================================ # CONFIG # ============================================================================ DEVICE = "cuda" DTYPE = torch.float16 SOL_REPO = "AbstractPhil/sd15-flow-matching" SOL_FILENAME = "sd15_flowmatch_david_weighted_efinal.pt" # ============================================================================ # LOAD MODELS # ============================================================================ print("Loading CLIP...") clip_tok = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14") clip_enc = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14", torch_dtype=DTYPE).to(DEVICE).eval() print("Loading VAE...") vae = AutoencoderKL.from_pretrained( "stable-diffusion-v1-5/stable-diffusion-v1-5", subfolder="vae", torch_dtype=DTYPE ).to(DEVICE).eval() print("Loading UNet...") unet = UNet2DConditionModel.from_pretrained( "stable-diffusion-v1-5/stable-diffusion-v1-5", subfolder="unet", torch_dtype=DTYPE, ).to(DEVICE).eval() print("Loading DDPM Scheduler...") sched = DDPMScheduler(num_train_timesteps=1000) # 
# ============================================================================
# LOAD SOL WEIGHTS
# ============================================================================
print(f"\nLoading Sol from {SOL_REPO}...")
weights_path = hf_hub_download(repo_id=SOL_REPO, filename=SOL_FILENAME)
# NOTE(review): torch.load without weights_only=True unpickles arbitrary
# objects — acceptable here only because the repo is trusted/pinned.
checkpoint = torch.load(weights_path, map_location="cpu")
state_dict = checkpoint["student"]  # distilled student weights inside the checkpoint dict
print(f" gstep: {checkpoint.get('gstep', 'unknown')}")

# If keys carry a "unet." prefix (trainer wrapped the UNet), strip it and
# keep only the UNet entries.
if any(k.startswith("unet.") for k in state_dict.keys()):
    state_dict = {k.replace("unet.", ""): v for k, v in state_dict.items() if k.startswith("unet.")}
# Drop trainer-only auxiliary modules that have no home in the bare UNet.
state_dict = {k: v for k, v in state_dict.items() if not k.startswith(("hooks.", "local_heads."))}

# strict=False: tolerate the filtered/renamed key set; report counts for sanity.
missing, unexpected = unet.load_state_dict(state_dict, strict=False)
print(f" Loaded: {len(state_dict)} keys, missing: {len(missing)}, unexpected: {len(unexpected)}")

del checkpoint, state_dict
gc.collect()

# Inference only — freeze all UNet parameters.
for p in unet.parameters():
    p.requires_grad = False
print("✓ Sol ready!")

# ============================================================================
# HELPER: Alpha/Sigma from DDPM schedule (matches trainer)
# ============================================================================
def alpha_sigma(t: torch.LongTensor):
    """Get alpha and sigma from DDPM alphas_cumprod - matches trainer exactly.

    Returns (alpha, sigma) = (sqrt(abar_t), sqrt(1 - abar_t)), each reshaped
    to (-1, 1, 1, 1) in float32 so they broadcast over (B, C, H, W) latents.
    """
    ac = sched.alphas_cumprod.to(DEVICE)[t]
    alpha = ac.sqrt().view(-1, 1, 1, 1).float()
    sigma = (1.0 - ac).sqrt().view(-1, 1, 1, 1).float()
    return alpha, sigma

# ============================================================================
# CORRECT SAMPLER (matches trainer's sample() method)
# ============================================================================
@torch.inference_mode()
def generate_sol(prompt, negative_prompt="", seed=42, steps=30, cfg=7.5):
    """
    Matches trainer's sample() method exactly:
    1. Use DDPM scheduler timesteps
    2. Model predicts velocity v
    3. Convert v → x0_hat → eps_hat
    4. Use sched.step(eps_hat, t, x_t)

    Args:
        prompt: text prompt (truncated to 40 chars only in the log line).
        negative_prompt: CFG negative prompt; "" gives an unconditional branch.
        seed: manual torch seed; None leaves the RNG state untouched.
        steps: number of scheduler timesteps.
        cfg: classifier-free guidance scale applied to the velocity.

    Returns:
        PIL.Image.Image, 512x512 RGB (decoded from a 64x64 latent).
    """
    if seed is not None:
        torch.manual_seed(seed)

    # Encode positive and negative prompts to CLIP hidden states (77 tokens).
    inputs = clip_tok(prompt, return_tensors="pt", padding="max_length", max_length=77, truncation=True).to(DEVICE)
    cond = clip_enc(**inputs).last_hidden_state.to(DTYPE)
    inputs_neg = clip_tok(negative_prompt, return_tensors="pt", padding="max_length", max_length=77, truncation=True).to(DEVICE)
    uncond = clip_enc(**inputs_neg).last_hidden_state.to(DTYPE)

    # Set scheduler timesteps for this run.
    sched.set_timesteps(steps, device=DEVICE)

    # Start from pure Gaussian noise in latent space (1, 4, 64, 64).
    x_t = torch.randn(1, 4, 64, 64, device=DEVICE, dtype=DTYPE)
    print(f"Sampling '{prompt[:40]}' | {steps} steps, cfg={cfg}")

    for i, t_scalar in enumerate(sched.timesteps):
        # Batch-shaped long timestep for the UNet / alpha_sigma lookup.
        t = torch.full((1,), t_scalar, device=DEVICE, dtype=torch.long)

        # Model predicts VELOCITY (not epsilon!) — two passes for CFG.
        v_cond = unet(x_t.to(DTYPE), t, encoder_hidden_states=cond).sample
        v_uncond = unet(x_t.to(DTYPE), t, encoder_hidden_states=uncond).sample

        # CFG on velocity
        v_hat = v_uncond + cfg * (v_cond - v_uncond)

        # Convert velocity to epsilon (EXACTLY as trainer does)
        alpha, sigma = alpha_sigma(t)
        # v = alpha * eps - sigma * x0
        # x_t = alpha * x0 + sigma * eps
        # Solve for x0: x0 = (alpha * x_t - sigma * v) / (alpha^2 + sigma^2)
        # Then: eps = (x_t - alpha * x0) / sigma
        # NOTE(review): for this parameterization alpha^2 + sigma^2 =
        # abar + (1 - abar) = 1 identically; denom is kept to mirror the trainer.
        denom = alpha**2 + sigma**2
        x0_hat = (alpha * x_t.float() - sigma * v_hat.float()) / (denom + 1e-8)
        eps_hat = (x_t.float() - alpha * x0_hat) / (sigma + 1e-8)

        # Step with epsilon (DDPMScheduler's default prediction type).
        step_out = sched.step(eps_hat, t_scalar, x_t.float())
        x_t = step_out.prev_sample.to(DTYPE)

        # Progress log roughly 5 times over the run.
        if (i + 1) % max(1, steps // 5) == 0:
            print(f" Step {i+1}/{steps}, t={t_scalar}")

    # Decode: undo the SD1.5 latent scaling factor, then VAE → pixels.
    x_t = x_t / 0.18215
    img = vae.decode(x_t).sample
    # Map [-1, 1] → [0, 1], drop batch dim, CHW → HWC, move to CPU float.
    img = (img / 2 + 0.5).clamp(0, 1)[0].permute(1, 2, 0).cpu().float().numpy()
    return Image.fromarray((img * 255).astype(np.uint8))

# ============================================================================
# TEST
# ============================================================================
print("\n" + "="*60)
# Smoke-test the sampler on a handful of prompts and show each result inline.
print("Generating test images with Sol (correct sampler)")
print("=" * 60)

from IPython.display import display

# Fixed prompt set; identical seed/steps/cfg for every image so runs compare.
test_prompts = (
    "a castle at sunset",
    "a portrait of a woman",
    "a city street at night",
)

for text in test_prompts:
    print()  # blank line between images in the cell output
    display(generate_sol(text, negative_prompt="", seed=42, steps=4, cfg=5.0))

print("\n✓ Bask in the beauty of the geometric expert!")