nicka360 commited on
Commit
08a5a82
·
1 Parent(s): 72a4044

Initial A360 migration

Browse files
app.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
"""Application entry point: import and launch the A360 WARP Gradio demo."""

from warp.gradio_app.app import demo

if __name__ == "__main__":
    demo.launch()
requirements.txt ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ supabase
2
+ gradio
3
+ pandas
4
+ pillow
5
+ exifread
6
+ requests
7
+ pydantic
8
+ boto3
9
+ python-dotenv
10
+
11
+ # Phase 2: Local background removal models
12
+ rembg==2.0.67
13
+ onnxruntime==1.23.2
14
+ scikit-image==0.25.2
15
+
16
+ # ML/AI dependencies for image upscaling
17
+ torch>=2.0.0
18
+ torchvision>=0.15.0
19
+ transformers>=4.30.0
20
+ diffusers>=0.21.0
21
+ accelerate>=0.20.0
warp/data/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
"""Data management for A360 WARP."""

from .image_loader import ImageLoader, list_practices

# Public API of the warp.data package.
__all__ = ["ImageLoader", "list_practices"]
warp/data/__pycache__/__init__.cpython-313.pyc ADDED
Binary file (314 Bytes). View file
 
warp/data/__pycache__/image_loader.cpython-313.pyc ADDED
Binary file (9.17 kB). View file
 
warp/data/image_loader.py ADDED
@@ -0,0 +1,227 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
"""Image loader for scraped medical practice images."""

from pathlib import Path

from PIL import Image

# Default scraped images directory (now under top-level data/scrapedimages).
# Resolved relative to this file: warp/data/ -> warp/ -> project root.
DEFAULT_SCRAPED_IMAGES_DIR = Path(__file__).parent.parent.parent / "data" / "scrapedimages"
9
+
10
+
11
+ class ImageLoader:
12
+ """Load and manage scraped medical practice images.
13
+
14
+ Attributes:
15
+ base_path: Root directory containing scraped images organized by practice
16
+ practices: List of available practice directories
17
+ """
18
+
19
+ def __init__(self, base_path: Path | str | None = None):
20
+ """Initialize the ImageLoader.
21
+
22
+ Args:
23
+ base_path: Root directory containing scraped images.
24
+ Defaults to project's scrapedimages folder.
25
+ """
26
+ if base_path is None:
27
+ self.base_path = DEFAULT_SCRAPED_IMAGES_DIR
28
+ else:
29
+ self.base_path = Path(base_path)
30
+
31
+ if not self.base_path.exists():
32
+ raise ValueError(f"Image directory does not exist: {self.base_path}")
33
+
34
+ self._practices: list[str] | None = None
35
+
36
+ @property
37
+ def practices(self) -> list[str]:
38
+ """Get list of available practice directories.
39
+
40
+ Returns:
41
+ List of practice directory names (e.g., ['drleedy.com', 'drbirely.com'])
42
+ """
43
+ if self._practices is None:
44
+ self._practices = sorted(
45
+ [
46
+ d.name
47
+ for d in self.base_path.iterdir()
48
+ if d.is_dir() and not d.name.startswith(".")
49
+ ]
50
+ )
51
+ return self._practices
52
+
53
+ def get_practice_path(self, practice_name: str) -> Path:
54
+ """Get the full path to a practice directory.
55
+
56
+ Args:
57
+ practice_name: Name of the practice (e.g., 'drleedy.com')
58
+
59
+ Returns:
60
+ Path object pointing to the practice directory
61
+
62
+ Raises:
63
+ ValueError: If practice does not exist
64
+ """
65
+ practice_path = self.base_path / practice_name
66
+ if not practice_path.exists():
67
+ raise ValueError(
68
+ f"Practice '{practice_name}' not found. "
69
+ f"Available practices: {', '.join(self.practices)}"
70
+ )
71
+ return practice_path
72
+
73
+ def list_images(self, practice_name: str, extensions: list[str] | None = None) -> list[Path]:
74
+ """List all images for a given practice.
75
+
76
+ Args:
77
+ practice_name: Name of the practice
78
+ extensions: List of file extensions to filter (e.g., ['.jpg', '.png'])
79
+ If None, includes common image formats
80
+
81
+ Returns:
82
+ List of Path objects for all matching images
83
+ """
84
+ if extensions is None:
85
+ extensions = [".jpg", ".jpeg", ".png", ".gif", ".webp", ".bmp"]
86
+
87
+ practice_path = self.get_practice_path(practice_name)
88
+ images: list[Path] = []
89
+
90
+ for ext in extensions:
91
+ images.extend(practice_path.glob(f"**/*{ext}"))
92
+ images.extend(practice_path.glob(f"**/*{ext.upper()}"))
93
+
94
+ return sorted(images)
95
+
96
+ def count_images(self, practice_name: str) -> int:
97
+ """Count total images for a practice.
98
+
99
+ Args:
100
+ practice_name: Name of the practice
101
+
102
+ Returns:
103
+ Number of images
104
+ """
105
+ return len(self.list_images(practice_name))
106
+
107
+ def load_image(self, image_path: Path | str) -> Image.Image:
108
+ """Load a single image.
109
+
110
+ Args:
111
+ image_path: Path to the image file
112
+
113
+ Returns:
114
+ PIL Image object
115
+
116
+ Raises:
117
+ FileNotFoundError: If image does not exist
118
+ """
119
+ image_path = Path(image_path)
120
+ if not image_path.exists():
121
+ raise FileNotFoundError(f"Image not found: {image_path}")
122
+
123
+ return Image.open(image_path)
124
+
125
+ def get_image_info(self, image_path: Path | str) -> dict:
126
+ """Get metadata about an image.
127
+
128
+ Args:
129
+ image_path: Path to the image file
130
+
131
+ Returns:
132
+ Dictionary with image metadata (size, format, mode, etc.)
133
+ """
134
+ image_path = Path(image_path)
135
+ img = self.load_image(image_path)
136
+
137
+ return {
138
+ "path": str(image_path),
139
+ "filename": image_path.name,
140
+ "practice": (
141
+ image_path.parent.name if image_path.is_relative_to(self.base_path) else None
142
+ ),
143
+ "size": img.size,
144
+ "width": img.width,
145
+ "height": img.height,
146
+ "format": img.format,
147
+ "mode": img.mode,
148
+ "file_size_bytes": image_path.stat().st_size,
149
+ }
150
+
151
+ def get_random_images(
152
+ self, practice_name: str, n: int = 5, seed: int | None = None
153
+ ) -> list[Path]:
154
+ """Get random sample of images from a practice.
155
+
156
+ Args:
157
+ practice_name: Name of the practice
158
+ n: Number of images to return
159
+ seed: Random seed for reproducibility
160
+
161
+ Returns:
162
+ List of n random image paths
163
+ """
164
+ import random
165
+
166
+ images = self.list_images(practice_name)
167
+
168
+ if seed is not None:
169
+ random.seed(seed)
170
+
171
+ return random.sample(images, min(n, len(images)))
172
+
173
+ def get_practice_stats(self, practice_name: str) -> dict:
174
+ """Get statistics for a practice's images.
175
+
176
+ Args:
177
+ practice_name: Name of the practice
178
+
179
+ Returns:
180
+ Dictionary with practice statistics
181
+ """
182
+ images = self.list_images(practice_name)
183
+ total_size = sum(img.stat().st_size for img in images)
184
+
185
+ # Get format distribution
186
+ formats: dict[str, int] = {}
187
+ for img_path in images:
188
+ ext = img_path.suffix.lower()
189
+ formats[ext] = formats.get(ext, 0) + 1
190
+
191
+ return {
192
+ "practice": practice_name,
193
+ "total_images": len(images),
194
+ "total_size_mb": total_size / (1024 * 1024),
195
+ "formats": formats,
196
+ "practice_path": str(self.get_practice_path(practice_name)),
197
+ }
198
+
199
+ def get_all_stats(self) -> dict:
200
+ """Get statistics for all practices.
201
+
202
+ Returns:
203
+ Dictionary with overall statistics
204
+ """
205
+ all_stats: dict = {"practices": {}, "total_images": 0, "total_size_mb": 0.0}
206
+
207
+ for practice in self.practices:
208
+ practice_stats = self.get_practice_stats(practice)
209
+ all_stats["practices"][practice] = practice_stats
210
+ all_stats["total_images"] += practice_stats["total_images"]
211
+ all_stats["total_size_mb"] += practice_stats["total_size_mb"]
212
+
213
+ return all_stats
def list_practices(base_path: Path | str | None = None) -> list[str]:
    """Convenience wrapper returning the available practice directories.

    Args:
        base_path: Root directory containing scraped images.
            Defaults to project's scrapedimages folder.

    Returns:
        List of practice directory names
    """
    return ImageLoader(base_path).practices
warp/gradio_app/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ """Gradio application for A360 WARP experimentation UI."""
warp/gradio_app/app.py ADDED
@@ -0,0 +1,359 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
from __future__ import annotations

import os
from typing import TYPE_CHECKING

import gradio as gr
from dotenv import load_dotenv

if TYPE_CHECKING:
    from supabase import Client as ClientType

    from warp.data import ImageLoader as ImageLoaderType
    from warp.gradio_app.models.upscaler import ImageUpscaler
else:
    # Runtime stand-ins so the annotations below remain valid objects even
    # when the optional dependencies are not installed.
    ClientType = object
    ImageLoaderType = object
    ImageUpscaler = object

# Each optional dependency is imported defensively; on failure the symbol is
# set to None and the feature degrades gracefully at runtime.
try:
    from supabase import Client, create_client
except Exception:
    create_client = None  # type: ignore
    Client = None  # type: ignore

try:
    from warp.data import ImageLoader
except ImportError:
    ImageLoader = None  # type: ignore

try:
    from warp.gradio_app.models.upscaler import create_upscaler

    UPSCALER_AVAILABLE = True
    print("✓ Upscaler module loaded successfully")
except ImportError as e:
    create_upscaler = None  # type: ignore
    UPSCALER_AVAILABLE = False
    print(f"✗ Upscaler import failed: {e}")

# Temporarily disable Advanced Upscaling tab due to import issues
# Will re-enable after fixing module resolution
COMPARE_TAB_AVAILABLE = False
build_upscale_compare = None

# Read Supabase credentials from the environment (.env supported).
load_dotenv()
SUPABASE_URL: str = os.getenv("SUPABASE_URL", "")
SUPABASE_ANON_KEY: str = os.getenv("SUPABASE_ANON_KEY", "")

# Only create the client when the import succeeded AND both credentials exist.
supabase: ClientType | None = None
if callable(create_client) and SUPABASE_URL and SUPABASE_ANON_KEY:
    supabase = create_client(SUPABASE_URL, SUPABASE_ANON_KEY)

# Initialize image loader
image_loader: ImageLoaderType | None = None
try:
    if callable(ImageLoader):
        image_loader = ImageLoader()
        print(f"✓ Loaded {len(image_loader.practices)} practices with scraped images")
except Exception as e:
    # If initialization fails for any reason (including mocked import errors),
    # fall back to no image loader so the rest of the app can still import.
    image_loader = None
    print(f"Warning: Could not initialize ImageLoader: {e}")

# Initialize upscaler (lazy load: created on first upscale request).
upscaler: ImageUpscaler | None = None
def load_practice_images(practice_name: str) -> tuple[list, str]:
    """Load a random sample of images for the selected practice.

    Args:
        practice_name: Name of the practice to load images from

    Returns:
        Tuple of (list of image path strings, status message)
    """
    # Guard clauses: missing loader or no selection.
    if not image_loader:
        return [], "Image loader not available"
    if not practice_name:
        return [], "Please select a practice"

    try:
        sample = image_loader.get_random_images(practice_name, n=10)
        stats = image_loader.get_practice_stats(practice_name)
        message = (
            f"Loaded {len(sample)} sample images from {practice_name} "
            f"(Total: {stats['total_images']} images)"
        )
        return [str(path) for path in sample], message
    except Exception as exc:
        return [], f"Error loading images: {exc}"
95
+
96
+
97
+ def run_model(procedure: str | None, notes: str | None) -> str:
98
+ """Run a placeholder model execution.
99
+
100
+ Args:
101
+ procedure: The selected procedure type
102
+ notes: Additional context or parameters
103
+
104
+ Returns:
105
+ A formatted string with procedure and notes information
106
+ """
107
+ return f"Procedure={procedure or 'n/a'} | Notes={notes or 'n/a'}"
def upscale_images(
    before_img, after_img, prompt: str, num_steps: int, guidance: float, progress=gr.Progress()
):
    """Upscale before/after image pair (synchronous helper).

    This version is a simple function (not a generator) so tests can call it
    and make assertions about the returned tuple. The Gradio UI wraps this in
    a streaming function that updates the progress bar and status text.

    Args:
        before_img: Before image (numpy array or PIL Image), or None.
        after_img: After image (numpy array or PIL Image), or None.
        prompt: Text prompt guiding upscale quality.
        num_steps: Number of diffusion inference steps.
        guidance: Guidance scale forwarded to the pipeline.
        progress: Gradio progress tracker; the gr.Progress() default is the
            Gradio idiom for progress injection.

    Returns:
        Tuple of (upscaled_before, upscaled_after, status_message); the image
        slots are None on any failure.
    """
    global upscaler

    # Handle missing upscaler dependency
    if not UPSCALER_AVAILABLE:
        return (
            None,
            None,
            "Upscaler not available. Install: pip install torch diffusers transformers",
        )

    # Validate inputs
    if before_img is None or after_img is None:
        return None, None, "Please upload both before and after images"

    try:
        # Lazy load upscaler on first use (cached in the module global).
        if upscaler is None and callable(create_upscaler):
            upscaler = create_upscaler(model_type="sd-x4")

        # Import PIL here to handle the images
        from PIL import Image

        # Convert numpy arrays to PIL Images if needed
        if not isinstance(before_img, Image.Image):
            before_img = Image.fromarray(before_img)
        if not isinstance(after_img, Image.Image):
            after_img = Image.fromarray(after_img)

        orig_size = before_img.size

        # Use the pair upscaling helper with a callback that updates the
        # Gradio progress bar more granularly during diffusion steps.
        callback_state = {"phase": "before", "last_step": -1}

        def progress_callback(step, timestep, latents):  # type: ignore[unused-argument]
            """Update progress bar for each diffusion step.

            We see steps 0..num_steps-1 for the "before" image first, then
            again for the "after" image. When the step counter resets, we
            switch to the "after" phase and map progress into [0.5, 0.9].
            """
            try:
                # Detect phase change when step counter resets
                if step < callback_state["last_step"]:
                    callback_state["phase"] = "after"
                callback_state["last_step"] = step

                # Guard against num_steps == 0 to avoid ZeroDivisionError.
                frac = step / max(num_steps, 1)
                if callback_state["phase"] == "before":
                    # Map to [0.1, 0.5]
                    pct = 0.1 + 0.4 * frac
                    desc = f"Upscaling BEFORE image ({step}/{num_steps})"
                else:
                    # Map to [0.5, 0.9]
                    pct = 0.5 + 0.4 * frac
                    desc = f"Upscaling AFTER image ({step}/{num_steps})"

                try:
                    progress(pct, desc=desc)
                except Exception:
                    # In tests or non-Gradio contexts, progress may be a no-op
                    pass
            except Exception:
                # Never allow progress UI issues to break the core upscaling
                pass

        before_upscaled, after_upscaled = upscaler.upscale_pair(
            before_img,
            after_img,
            prompt=prompt,
            num_inference_steps=num_steps,
            guidance_scale=guidance,
            callback=progress_callback,
            callback_steps=1,
        )

        status = (
            "Successfully upscaled both images 4x\n"
            f"Original: {orig_size[0]}×{orig_size[1]} → "
            f"Upscaled: {before_upscaled.size[0]}×{before_upscaled.size[1]}"
        )
        return before_upscaled, after_upscaled, status

    except Exception as e:
        # Graceful error handling for tests and UI
        return None, None, f"Error during upscaling: {str(e)}"
def upscale_images_stream(
    before_img, after_img, prompt: str, num_steps: int, guidance: float, progress=gr.Progress()
):
    """Streaming wrapper for ``upscale_images`` used by the Gradio UI.

    Yields intermediate status updates so the user sees a live progress bar
    and status text while the heavy model runs.

    Yields:
        Tuples of (before_image_or_None, after_image_or_None, status_message).
    """
    # Handle missing upscaler dependency
    if not UPSCALER_AVAILABLE:
        yield (
            None,
            None,
            "Upscaler not available. Install: pip install torch diffusers transformers",
        )
        return

    # Validate inputs
    if before_img is None or after_img is None:
        yield None, None, "Please upload both before and after images"
        return

    try:
        # Initial progress; progress() may be a no-op outside the Gradio UI,
        # so every call is wrapped defensively.
        try:
            progress(0.0, desc="Initializing upscaler...")
        except Exception:
            pass
        yield None, None, "Initializing upscaler..."

        # Coarse progress while running the model; upscale_images refines it
        # per diffusion step via its own callback.
        try:
            progress(0.3, desc="Upscaling images...")
        except Exception:
            pass

        before_upscaled, after_upscaled, status = upscale_images(
            before_img, after_img, prompt, num_steps, guidance, progress
        )

        try:
            progress(1.0, desc="Complete")
        except Exception:
            pass

        yield before_upscaled, after_upscaled, status

    except Exception as e:
        yield None, None, f"Error during upscaling: {str(e)}"
# Build the UI
with gr.Blocks(title="A360 WARP — Gradio") as demo:
    gr.Markdown("# A360 WARP — Experimentation UI (MVP)")
    gr.Markdown("Load and experiment with before/after images from scraped medical practices.")

    # Practice selection and image loading
    with gr.Tab("Image Browser"):
        with gr.Row():
            practice_dropdown = gr.Dropdown(
                label="Select Practice",
                # Empty choice list when the loader failed to initialize.
                choices=image_loader.practices if image_loader else [],
                value=None,
            )
            load_btn = gr.Button("Load Sample Images", variant="primary")

        status_text = gr.Textbox(label="Status", interactive=False)
        image_gallery = gr.Gallery(label="Sample Images", show_label=True, columns=5, height="auto")

        load_btn.click(
            fn=load_practice_images,
            inputs=[practice_dropdown],
            outputs=[image_gallery, status_text],
        )

    # Image Enhancement (Upscaling)
    with gr.Tab("Image Enhancement"):
        gr.Markdown(
            "### Upscale Before/After Images\n"
            "Upload medical before/after photos to upscale them 4x using AI. "
            "This improves image quality and detail for better comparison."
        )

        with gr.Row():
            with gr.Column():
                gr.Markdown("#### Original Images")
                before_input = gr.Image(label="Before Image", type="numpy")
                after_input = gr.Image(label="After Image", type="numpy")

            with gr.Column():
                gr.Markdown("#### Upscaled Images (4x)")
                before_output = gr.Image(label="Upscaled Before")
                after_output = gr.Image(label="Upscaled After")

        with gr.Row():
            with gr.Column():
                prompt_input = gr.Textbox(
                    label="Quality Prompt",
                    value="high quality medical photography, sharp details, professional lighting",
                    placeholder="Describe desired image quality...",
                )
            with gr.Column():
                num_steps = gr.Slider(
                    minimum=20,
                    maximum=100,
                    value=50,
                    step=5,
                    label="Inference Steps (higher = better quality, slower)",
                )
                guidance_scale = gr.Slider(
                    minimum=1.0, maximum=15.0, value=7.5, step=0.5, label="Guidance Scale"
                )

        upscale_btn = gr.Button("Upscale Images", variant="primary", size="lg")
        upscale_status = gr.Textbox(label="Status", interactive=False)

        # Use the streaming wrapper so users see live progress/status updates
        upscale_btn.click(
            fn=upscale_images_stream,
            inputs=[before_input, after_input, prompt_input, num_steps, guidance_scale],
            outputs=[before_output, after_output, upscale_status],
        )

    # Advanced Upscaling with Comparison (currently disabled via
    # COMPARE_TAB_AVAILABLE = False at module top).
    if COMPARE_TAB_AVAILABLE and build_upscale_compare:
        with gr.Tab("Advanced Upscaling"):
            build_upscale_compare()

    # Model experimentation
    with gr.Tab("Model Experiments"):
        with gr.Row():
            procedure = gr.Dropdown(
                label="Procedure",
                choices=[
                    "breast-augmentation",
                    "liposuction",
                    "rhinoplasty",
                    "ftm-top-surgery",
                    "coolsculpting",
                ],
                value=None,
            )
            notes = gr.Textbox(label="Notes", placeholder="Run context / params…")
        run = gr.Button("Run")
        out = gr.Textbox(label="Output")

        run.click(run_model, inputs=[procedure, notes], outputs=out)

if __name__ == "__main__":
    demo.launch()
warp/gradio_app/models/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
"""Model registry for A360 WARP."""

from .registry import MODELS

# Public API of the models package.
__all__ = ["MODELS"]
warp/gradio_app/models/__pycache__/__init__.cpython-313.pyc ADDED
Binary file (319 Bytes). View file
 
warp/gradio_app/models/__pycache__/registry.cpython-313.pyc ADDED
Binary file (341 Bytes). View file
 
warp/gradio_app/models/__pycache__/upscaler.cpython-313.pyc ADDED
Binary file (7.53 kB). View file
 
warp/gradio_app/models/registry-Nick.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
# Hugging Face model identifiers for embedding / captioning experiments.
# NOTE(review): this file duplicates MODELS in registry.py — consider
# consolidating so the two copies cannot drift apart.
MODELS = {
    "CLIP": "openai/clip-vit-base-patch32",
    "BLIP-2": "Salesforce/blip2-flan-t5-xl",
    "DINOv2": "facebook/dinov2-base",
}

# Diffusion-based upscaler checkpoints (same checkpoint used by
# gradio_app.models.upscaler's default model_id).
UPSCALER_MODELS = {
    "SD-X4-Upscaler": "stabilityai/stable-diffusion-x4-upscaler",
}
warp/gradio_app/models/registry.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
# Hugging Face model identifiers exported as warp.gradio_app.models.MODELS.
MODELS = {
    "CLIP": "openai/clip-vit-base-patch32",
    "BLIP-2": "Salesforce/blip2-flan-t5-xl",
    "DINOv2": "facebook/dinov2-base",
}
warp/gradio_app/models/upscaler.py ADDED
@@ -0,0 +1,211 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
"""Image upscaling using HuggingFace models."""

from pathlib import Path
from typing import TYPE_CHECKING, Literal

from PIL import Image

if TYPE_CHECKING:
    # Real pipeline type is only needed by the type checker.
    from diffusers import StableDiffusionUpscalePipeline as PipelineType
else:
    PipelineType = object

try:
    import torch
    from diffusers import StableDiffusionUpscalePipeline

    TORCH_AVAILABLE = True
except ImportError:
    # If either torch or diffusers is missing, mark them unavailable. The
    # tests patch these symbols as needed, and the runtime gracefully degrades
    # by raising a clear ImportError from ImageUpscaler.__init__.
    torch = None  # type: ignore[assignment]
    StableDiffusionUpscalePipeline = None  # type: ignore[assignment]
    TORCH_AVAILABLE = False
class ImageUpscaler:
    """Handle image upscaling using HuggingFace models."""

    def __init__(
        self, model_id: str = "stabilityai/stable-diffusion-x4-upscaler", device: str | None = None
    ):
        """Initialize the upscaler.

        Args:
            model_id: HuggingFace model identifier
            device: Device to run model on ('cuda', 'cpu', or None for auto)

        Raises:
            ImportError: If torch or diffusers are not installed.
        """
        if not TORCH_AVAILABLE:
            raise ImportError(
                "torch and diffusers are required for upscaling. "
                "Install with: pip install torch diffusers transformers"
            )

        self.model_id = model_id

        # Auto-detect device
        if device is None:
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
        else:
            self.device = device

        self.pipeline: PipelineType | None = None
        self._load_model()

    def _load_model(self) -> None:
        """Load the upscaling model from the HuggingFace hub (or local cache)."""
        print(f"Loading upscaler model: {self.model_id} on {self.device}...")

        # Determine torch dtype based on device: fp16 halves GPU memory use,
        # fp32 keeps CPU inference numerically safe.
        torch_dtype = torch.float16 if self.device == "cuda" else torch.float32

        self.pipeline = StableDiffusionUpscalePipeline.from_pretrained(
            self.model_id, torch_dtype=torch_dtype
        )
        self.pipeline = self.pipeline.to(self.device)

        # Enable memory optimizations if on CUDA
        if self.device == "cuda":
            self.pipeline.enable_attention_slicing()

        print(f"✓ Model loaded successfully on {self.device}")

    def upscale(
        self,
        image: Image.Image | str | Path,
        prompt: str = "high quality, detailed, sharp",
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        callback=None,
        callback_steps: int = 1,
    ) -> Image.Image:
        """Upscale an image 4x.

        Args:
            image: PIL Image or path to image file
            prompt: Text prompt to guide upscaling (helps with quality)
            num_inference_steps: Number of denoising steps (higher = better quality, slower)
            guidance_scale: How closely to follow the prompt (7.5 is good default)
            callback: Optional callback function(step, timestep, latents) called each step
            callback_steps: How often to call the callback (default: every step)

        Returns:
            Upscaled PIL Image

        Raises:
            RuntimeError: If the pipeline was not initialized.
        """
        # Load image if path is provided
        if isinstance(image, (str, Path)):
            image = Image.open(image).convert("RGB")

        # Ensure RGB mode
        if image.mode != "RGB":
            image = image.convert("RGB")

        # Run upscaling
        if self.pipeline is None:
            raise RuntimeError("Pipeline not initialized")
        # NOTE(review): newer diffusers releases deprecate `callback` /
        # `callback_steps` in favor of `callback_on_step_end` — verify
        # against the pinned diffusers version before upgrading.
        result = self.pipeline(
            prompt=prompt,
            image=image,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            callback=callback,
            callback_steps=callback_steps,
        )
        upscaled: Image.Image = result.images[0]

        return upscaled

    def upscale_pair(
        self,
        before_image: Image.Image | str | Path,
        after_image: Image.Image | str | Path,
        prompt: str = "high quality medical photography, sharp details, professional lighting",
        **kwargs,
    ) -> tuple[Image.Image, Image.Image]:
        """Upscale a before/after image pair.

        Args:
            before_image: Before image (PIL Image or path)
            after_image: After image (PIL Image or path)
            prompt: Text prompt for upscaling quality
            **kwargs: Additional arguments for upscale()

        Returns:
            Tuple of (upscaled_before, upscaled_after)
        """
        print("Upscaling before image...")
        before_upscaled = self.upscale(before_image, prompt=prompt, **kwargs)

        print("Upscaling after image...")
        after_upscaled = self.upscale(after_image, prompt=prompt, **kwargs)

        return before_upscaled, after_upscaled

    def batch_upscale(
        self,
        images: list[Image.Image | str | Path],
        prompt: str = "high quality, detailed, sharp",
        **kwargs,
    ) -> list[Image.Image]:
        """Upscale multiple images sequentially.

        Args:
            images: List of PIL Images or paths
            prompt: Text prompt for upscaling
            **kwargs: Additional arguments for upscale()

        Returns:
            List of upscaled PIL Images, in input order
        """
        results = []
        for i, img in enumerate(images, 1):
            print(f"Upscaling image {i}/{len(images)}...")
            upscaled = self.upscale(img, prompt=prompt, **kwargs)
            results.append(upscaled)
        return results
def create_upscaler(
    model_type: Literal["sd-x4", "fast"] = "sd-x4", device: str | None = None
) -> ImageUpscaler:
    """Factory function to create an upscaler.

    Args:
        model_type: Type of upscaler model
            - "sd-x4": Stable Diffusion 4x upscaler (high quality, slower)
            - "fast": Faster alternative (to be implemented)
        device: Device to run on ('cuda', 'cpu', or None for auto)

    Returns:
        Initialized ImageUpscaler
    """
    # Known checkpoints; unrecognized types fall back to the sd-x4 default.
    checkpoints = {
        "sd-x4": "stabilityai/stable-diffusion-x4-upscaler",
        # Can add more models here later
    }
    chosen = checkpoints.get(model_type, checkpoints["sd-x4"])
    return ImageUpscaler(model_id=chosen, device=device)
# NOTE:
# -----
# When this module is imported as a submodule of ``warp.gradio_app.models``
# (e.g. via ``from warp.gradio_app.models import upscaler``), Python normally
# caches it as an attribute on the parent package. That caching can interfere
# with tests that manipulate ``sys.modules`` to simulate import failures
# (like removing ``torch``/``diffusers`` and re-importing this module).
#
# To ensure those tests can reliably exercise the fallback path, we avoid
# permanently caching this submodule on the parent package by removing the
# attribute if it exists. The module itself remains available via
# ``sys.modules['warp.gradio_app.models.upscaler']``.
#
# NOTE(review): deleting the attribute from the parent package is unusual;
# keep this only as long as the sys.modules-patching tests need it.
try:  # Best-effort; never fail import because of this cleanup.
    import sys as _sys

    _parent_pkg = _sys.modules.get("warp.gradio_app.models")
    if _parent_pkg is not None and hasattr(_parent_pkg, "upscaler"):
        delattr(_parent_pkg, "upscaler")
except Exception:
    pass
warp/gradio_app/upscale_compare_tab.py ADDED
@@ -0,0 +1,231 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Advanced upscaling tab with before/after comparison and detailed metrics."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from typing import TYPE_CHECKING
6
+
7
+ import gradio as gr
8
+ import numpy as np
9
+ from PIL import Image, ImageDraw
10
+
11
if TYPE_CHECKING:
    # Type-checker-only import: keeps torch/diffusers out of the runtime path.
    from .models.upscaler import ImageUpscaler

# The upscaler pulls in heavy optional dependencies (torch, diffusers, ...).
# Degrade gracefully when they are missing so the rest of the app still loads;
# UPSCALER_AVAILABLE gates every code path that needs the model.
try:
    from .models.upscaler import create_upscaler

    UPSCALER_AVAILABLE = True
except (ImportError, ModuleNotFoundError):
    create_upscaler = None  # type: ignore[assignment]
    UPSCALER_AVAILABLE = False


# Global upscaler instance (lazy load)
# Process-wide singleton so the multi-GB pipeline is loaded at most once.
_upscaler: ImageUpscaler | None = None

# Configuration
MAX_INPUT_WIDTH = 1024  # inputs wider than this are downscaled before upscaling
MAX_INPUT_HEIGHT = 1024  # inputs taller than this are downscaled before upscaling
UPSCALE_FACTOR = 4  # fixed by the SD x4 upscaler model
QUALITY_PROMPT = "ultra realistic, natural contrast, high clarity, clean skin texture, professional medical photography"
31
+
32
+
33
def _get_upscaler() -> ImageUpscaler | None:
    """Return the shared upscaler, creating it on first use.

    Returns:
        The process-wide ImageUpscaler, or None when the optional ML
        dependencies are not installed.

    Raises:
        RuntimeError: if the dependencies are present but the pipeline
            fails to initialize.
    """
    global _upscaler
    # Fast path: already created, or creation is impossible anyway.
    if _upscaler is not None or not UPSCALER_AVAILABLE:
        return _upscaler
    try:
        _upscaler = create_upscaler(model_type="sd-x4")
    except Exception as e:
        raise RuntimeError(f"Failed to load upscaler: {e}") from e
    return _upscaler
42
+
43
+
44
def _validate_and_resize_image(
    img: np.ndarray | Image.Image, max_w: int = MAX_INPUT_WIDTH, max_h: int = MAX_INPUT_HEIGHT
) -> Image.Image:
    """Normalize *img* to an RGB PIL image no larger than max_w x max_h.

    Args:
        img: Input image as a numpy uint8 array or a PIL image.
        max_w: Maximum allowed width in pixels.
        max_h: Maximum allowed height in pixels.

    Returns:
        An RGB PIL image, downscaled aspect-preservingly if it exceeded the
        limits. The caller's image object is never mutated.

    Raises:
        ValueError: if *img* is neither a numpy array nor a PIL image.
    """
    # Convert numpy to PIL if needed
    if isinstance(img, np.ndarray):
        img = Image.fromarray(img.astype("uint8"))
    elif not isinstance(img, Image.Image):
        raise ValueError(f"Invalid image type: {type(img)}")

    # Convert to RGB (returns a new image when the mode differs)
    if img.mode != "RGB":
        img = img.convert("RGB")

    # Check and resize if too large.
    w, h = img.size
    if w > max_w or h > max_h:
        # Bug fix: Image.thumbnail resizes IN PLACE, so the previous code
        # silently shrank the caller's image object when a PIL image was
        # passed directly. Work on a copy instead.
        img = img.copy()
        img.thumbnail((max_w, max_h), Image.Resampling.LANCZOS)

    return img
64
+
65
+
66
def _create_comparison_grid(
    before_orig: Image.Image, after_orig: Image.Image, before_up: Image.Image, after_up: Image.Image
) -> Image.Image:
    """Create a 2x2 grid: columns are before/after, rows are original/upscaled.

    The upscaled images (4x) are displayed at half size, i.e. 2x the original,
    so the grid stays a reasonable size.

    Fixes over the previous version:
    - the canvas height assumed both rows were upscaled-image height, leaving
      a large blank band below the (half-height) original row;
    - one size (`before`'s) was used for both columns, misaligning pairs whose
      before/after photos have different dimensions.
    """
    display_scale = 0.5  # show the 4x result at 2x (half of 4x)

    def _shrink(im: Image.Image) -> Image.Image:
        # Display version of an upscaled image, scaled down for the grid.
        w, h = im.size
        return im.resize(
            (int(w * display_scale), int(h * display_scale)), Image.Resampling.LANCZOS
        )

    before_up_display = _shrink(before_up)
    after_up_display = _shrink(after_up)

    # Column widths fit the widest image in each column; row heights fit the
    # tallest image in each row (before/after may differ in aspect ratio).
    left_w = max(before_orig.width, before_up_display.width)
    right_w = max(after_orig.width, after_up_display.width)
    top_h = max(before_orig.height, after_orig.height)
    bottom_h = max(before_up_display.height, after_up_display.height)

    margin = 10
    label_space = 30  # labels drawn at y=10, images start at y=40
    grid_w = left_w + right_w + 3 * margin
    grid_h = margin + label_space + top_h + margin + bottom_h + margin

    grid = Image.new("RGB", (grid_w, grid_h), color=(30, 30, 30))
    draw = ImageDraw.Draw(grid)

    # Add labels (default bitmap font, no system font dependencies).
    draw.text((margin, margin), "BEFORE (Orig → Upscaled 2x)", fill=(255, 150, 0))
    draw.text((left_w + 2 * margin, margin), "AFTER (Orig → Upscaled 2x)", fill=(255, 150, 0))

    # Paste images: originals on the top row, upscaled below.
    top_y = margin + label_space
    bottom_y = top_y + top_h + margin
    right_x = left_w + 2 * margin

    grid.paste(before_orig, (margin, top_y))
    grid.paste(before_up_display, (margin, bottom_y))

    grid.paste(after_orig, (right_x, top_y))
    grid.paste(after_up_display, (right_x, bottom_y))

    return grid
106
+
107
+
108
def upscale_and_compare(
    before_img, after_img, prompt: str = QUALITY_PROMPT, num_steps: int = 50, guidance: float = 7.5
) -> tuple[Image.Image, Image.Image, Image.Image, str]:
    """Upscale a before/after pair with identical settings and build a comparison.

    Args:
        before_img: "Before" photo (numpy array or PIL image).
        after_img: "After" photo (numpy array or PIL image).
        prompt: Quality prompt fed to the diffusion upscaler.
        num_steps: Number of inference steps (20-100).
        guidance: Classifier-free guidance scale (1.0-15.0).

    Returns:
        Tuple of (before_upscaled, after_upscaled, comparison_grid, status);
        the three image slots are None when upscaling could not run.
    """
    if not UPSCALER_AVAILABLE:
        return (
            None,
            None,
            None,
            "❌ Upscaler not available. Install: pip install torch diffusers transformers",
        )  # type: ignore[return-value]

    if before_img is None or after_img is None:
        return None, None, None, "❌ Please upload both before and after images"  # type: ignore[return-value]

    try:
        upscaler = _get_upscaler()
        if upscaler is None:
            return None, None, None, "❌ Upscaler not available" # type: ignore[return-value]

        # Normalize both inputs (RGB, bounded size), remembering original sizes.
        originals = {
            "before": _validate_and_resize_image(before_img),
            "after": _validate_and_resize_image(after_img),
        }
        sizes = {name: im.size for name, im in originals.items()}

        # Upscale both with the exact same parameters for a fair comparison.
        upscaled: dict[str, Image.Image] = {}
        for name, im in originals.items():
            print(f"Upscaling {name} image ({sizes[name]})...")
            upscaled[name] = upscaler.upscale(
                im, prompt=prompt, num_inference_steps=num_steps, guidance_scale=guidance
            )

        grid = _create_comparison_grid(
            originals["before"], originals["after"], upscaled["before"], upscaled["after"]
        )

        status = (
            f"✅ Successfully upscaled both images!\n\n"
            f"Before: {sizes['before']} → {upscaled['before'].size}\n"
            f"After: {sizes['after']} → {upscaled['after'].size}\n"
            f"Upscale Factor: {UPSCALE_FACTOR}x\n"
            f"Steps: {num_steps} | Guidance: {guidance}\n"
            f"\nNote: Comparison shows upscaled images at 2x (50% of 4x for display)"
        )

        return upscaled["before"], upscaled["after"], grid, status

    except Exception as e:
        error_msg = f"❌ Error during upscaling: {str(e)}"
        print(error_msg)
        return None, None, None, error_msg  # type: ignore[return-value]
178
+
179
+
180
def build_ui() -> None:
    """Build the advanced upscaling UI tab.

    Must be called inside an active Gradio Blocks/Tab context: component
    creation order determines the on-screen layout. Wires the "Upscale Both"
    button to ``upscale_and_compare``.
    """

    gr.Markdown("### Upscale Before/After Images (Max Detail & Clarity)")
    gr.Markdown(
        "Upload medical before/after photos to upscale them 4x using Stable Diffusion x4 Upscaler. "
        "Both images are processed with identical parameters for fair comparison.\n\n"
        f"⚠️ **Note:** Processing takes 30-60 seconds per image (CPU) or 5-10 seconds (GPU). "
        f"Maximum input size: {MAX_INPUT_WIDTH}x{MAX_INPUT_HEIGHT}px (automatically resized if larger)."
    )

    # Top row: inputs + prompt on the left, upscaled results on the right.
    with gr.Row():
        with gr.Column():
            gr.Markdown("#### Original Images")
            before_input = gr.Image(label="Before Image", type="numpy")
            after_input = gr.Image(label="After Image", type="numpy")

            # Parameters
            prompt_input = gr.Textbox(
                label="Quality Prompt",
                value=QUALITY_PROMPT,
                placeholder="Describe desired image quality...",
                lines=3,
            )

        with gr.Column():
            gr.Markdown("#### Upscaled Results (4x)")
            # type="pil" so upscale_and_compare's PIL outputs render directly.
            before_output = gr.Image(label="Upscaled Before", type="pil")
            after_output = gr.Image(label="Upscaled After", type="pil")

    # Second row: tuning sliders, matching upscale_and_compare's defaults.
    with gr.Row():
        with gr.Column(scale=1):
            num_steps = gr.Slider(
                minimum=20, maximum=100, value=50, step=5, label="Inference Steps"
            )
        with gr.Column(scale=1):
            guidance_scale = gr.Slider(
                minimum=1.0, maximum=15.0, value=7.5, step=0.5, label="Guidance Scale"
            )

    upscale_btn = gr.Button("🚀 Upscale Both", variant="primary", size="lg")
    upscale_status = gr.Textbox(label="Status", interactive=False, lines=4)

    gr.Markdown("#### Side-by-Side Comparison")
    comparison_output = gr.Image(label="Comparison Grid", type="pil")

    # Button click handler
    upscale_btn.click(
        fn=upscale_and_compare,
        inputs=[before_input, after_input, prompt_input, num_steps, guidance_scale],
        outputs=[before_output, after_output, comparison_output, upscale_status],
    )