sitatech committed
Commit 19f8af1 · 1 Parent(s): ac43fd0

[vtry] Implement VirtualTryModel

Files changed (2)
  1. virtual_try/app.py +154 -0
  2. virtual_try/configs.py +21 -0
virtual_try/app.py ADDED
@@ -0,0 +1,154 @@
+ from typing import Callable, cast
+
+ import modal
+ from configs import (
+     image,
+     hf_cache_vol,
+     API_KEY,
+     MINUTE,
+     PORT,
+ )
+
+ with image.imports():
+     import torch
+     from torchvision import transforms
+     from PIL import Image
+     import numpy as np
+     from diffusers import FluxFillPipeline
+     from nunchaku import NunchakuFluxTransformer2dModel
+     from nunchaku.utils import get_precision
+     from nunchaku.lora.flux.compose import compose_lora
+
+     from virtual_try.auto_masker import AutoInpaintMaskGenerator
+
+ TransformType = Callable[[Image.Image | np.ndarray], torch.Tensor]
+
+ app = modal.App("vibe-shopping")
+
+
+ @app.cls(
+     image=image,
+     gpu="A100-40GB",
+     cpu=4,  # 4 physical cores (8 vCPUs)
+     memory=16384,  # 16 GiB RAM (Modal's memory parameter is in MiB)
+     volumes={
+         "/root/.cache/huggingface": hf_cache_vol,
+     },
+     secrets=[API_KEY],
+     scaledown_window=(
+         1 * MINUTE
+         # how long should we stay up with no requests? Keep it low to minimize credit usage for now.
+     ),
+     timeout=10 * MINUTE,  # how long should we wait for container start?
+ )
+ class VirtualTryModel:
+     @modal.enter()
+     def enter(self):
+         precision = get_precision()  # auto-detect precision ('int4' or 'fp4') based on the GPU
+         transformer = NunchakuFluxTransformer2dModel.from_pretrained(
+             f"mit-han-lab/nunchaku-flux.1-fill-dev/svdq-{precision}_r32-flux.1-fill-dev.safetensors"
+         )
+         transformer.set_attention_impl("nunchaku-fp16")
+         composed_lora = compose_lora(
+             [
+                 ("xiaozaa/catvton-flux-lora-alpha/pytorch_lora_weights.safetensors", 1),
+                 ("ByteDance/Hyper-SD/Hyper-FLUX.1-dev-8steps-lora.safetensors", 0.125),
+             ]
+         )
+         transformer.update_lora_params(composed_lora)
+
+         self.pipe = FluxFillPipeline.from_pretrained(
+             "black-forest-labs/FLUX.1-Fill-dev",
+             transformer=transformer,
+             torch_dtype=torch.bfloat16,
+         ).to("cuda")
+
+         self.auto_masker = AutoInpaintMaskGenerator()
+
+     def get_preprocessors(
+         self, input_size: tuple[int, int], target_megapixels: float = 1.0
+     ) -> tuple[TransformType, TransformType, tuple[int, int]]:
+         num_pixels = int(target_megapixels * 1024 * 1024)
+
+         input_width, input_height = input_size
+
+         # Resize the input dimensions to the target number of megapixels while maintaining
+         # the aspect ratio and ensuring the new dimensions are multiples of 64.
+         scale_by = np.sqrt(num_pixels / (input_height * input_width))
+         new_height = int(np.ceil((input_height * scale_by) / 64)) * 64
+         new_width = int(np.ceil((input_width * scale_by) / 64)) * 64
+
+         transform = cast(
+             TransformType,
+             transforms.Compose(
+                 [
+                     transforms.ToTensor(),
+                     transforms.Resize((new_height, new_width)),
+                     transforms.Normalize([0.5], [0.5]),
+                 ]
+             ),
+         )
+         mask_transform = cast(
+             TransformType,
+             transforms.Compose(
+                 [
+                     transforms.ToTensor(),
+                     transforms.Resize((new_height, new_width)),
+                 ]
+             ),
+         )
+         return transform, mask_transform, (new_width, new_height)
+
+     @modal.method()
+     def try_it(
+         self,
+         item_to_try: Image.Image,
+         image: Image.Image,
+         mask: Image.Image | np.ndarray | None = None,
+         prompt: str | None = None,
+         masking_prompt: str | None = None,
+     ) -> Image.Image:
+         assert mask is not None or masking_prompt, "Either mask or masking_prompt must be provided."
+
+         preprocessor, mask_preprocessor, output_size = self.get_preprocessors(
+             input_size=image.size,
+             target_megapixels=0.7,  # the images are stacked side by side below, doubling the pixel count
+         )
+
+         if mask is None:
+             # Generate a mask from the masking prompt using the auto-masker
+             mask = self.auto_masker.generate_mask(
+                 image,
+                 prompt=masking_prompt,  # type: ignore
+             )
+
+         image_tensor = preprocessor(image.convert("RGB"))
+         item_to_try_tensor = preprocessor(item_to_try.convert("RGB"))
+         mask_tensor = mask_preprocessor(mask)
+
+         # Create concatenated images
+         inpaint_image = torch.cat(
+             [item_to_try_tensor, image_tensor], dim=2
+         )  # concatenate along width
+         extended_mask = torch.cat([torch.zeros_like(mask_tensor), mask_tensor], dim=2)  # inpaint only the right half
+
+         prompt = prompt or (
+             "The pair of images highlights a product and its use in context, high resolution, 4K, 8K; "
+             "[IMAGE1] Detailed product shot of the item. "
+             "[IMAGE2] The same item shown in a realistic lifestyle or usage setting."
+         )
+
+         width, height = output_size
+         result = self.pipe(
+             height=height,
+             width=width * 2,
+             image=inpaint_image,
+             mask_image=extended_mask,
+             num_inference_steps=10,
+             generator=torch.Generator("cuda").manual_seed(11),
+             max_sequence_length=512,
+             guidance_scale=30,
+             prompt=prompt,
+         ).images[0]
+
+         return result.crop((width, 0, width * 2, height))  # keep only the right (in-context) half
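As a sanity check on the sizing logic in get_preprocessors, the sketch below mirrors the scale-then-round computation in plain NumPy (the target_dims helper is illustrative, not part of this commit):

import numpy as np

def target_dims(width: int, height: int, megapixels: float = 0.7) -> tuple[int, int]:
    # Scale to roughly `megapixels` MP while keeping the aspect ratio,
    # then round each side up to the next multiple of 64.
    scale = np.sqrt(megapixels * 1024 * 1024 / (width * height))
    return (
        int(np.ceil(width * scale / 64)) * 64,
        int(np.ceil(height * scale / 64)) * 64,
    )

print(target_dims(1080, 1920))  # (704, 1152): ~0.77 MP after rounding up

Because each side is rounded up, the output can land slightly above the requested megapixel budget; with the two images concatenated along the width, the pipeline ultimately generates at roughly twice that pixel count.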
virtual_try/configs.py ADDED
@@ -0,0 +1,21 @@
+ import modal
+
+ image = (
+     modal.Image.debian_slim(python_version="3.12")
+     .pip_install(
+         "torch==2.7.0",
+         "torchvision",
+         "diffusers==0.33.1",
+         "transformers==4.52.4",
+         "accelerate==1.7.0",
+         "huggingface_hub[hf_transfer]==0.32.4",
+         "git+https://github.com/luca-medeiros/lang-segment-anything.git@e9af744d999d85eb4d0bd59a83342ecdc2bd2461",
+         "https://github.com/mit-han-lab/nunchaku/releases/download/v0.3.0/nunchaku-0.3.0+torch2.7-cp312-cp312-linux_x86_64.whl#sha256=ed28665515075050c8ef1bacd16845b85aa4335f6c760d6fa716d3b090909d8d7",
+     )
+     .env({"HF_HUB_ENABLE_HF_TRANSFER": "1"})
+ )
+
+ hf_cache_vol = modal.Volume.from_name("huggingface-cache", create_if_missing=True)
+ API_KEY = modal.Secret.from_name("vibe-shopping-secrets", required_keys=["VT_API_KEY"])
+ MINUTE = 60
+ PORT = 8000
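For reference, once the app is deployed (e.g. via modal deploy), the class could be invoked from a client roughly as sketched below. This is an assumption-laden example: the lookup via modal.Cls.from_name and the local file names are illustrative, not part of this commit, and it presumes the vibe-shopping-secrets secret already exists with a VT_API_KEY entry:

import modal
from PIL import Image

# Look up the deployed class by app name and class name (assumes a prior `modal deploy`).
VirtualTryModel = modal.Cls.from_name("vibe-shopping", "VirtualTryModel")
model = VirtualTryModel()

# Hypothetical local inputs; either an explicit mask or a masking_prompt is required.
result = model.try_it.remote(
    item_to_try=Image.open("jacket.png"),
    image=Image.open("person.png"),
    masking_prompt="the jacket the person is wearing",
)
result.save("try_on.png")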