Matis Despujols committed on
Commit
c76abe2
·
verified ·
1 Parent(s): 3fee73e

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +241 -0
app.py ADDED
@@ -0,0 +1,241 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'
3
+
4
+ import gradio as gr
5
+ import torch
6
+ import cv2
7
+ import numpy as np
8
+ from PIL import Image
9
+ from typing import Tuple, List
10
+ from rfdetr.detr import RFDETRMedium
11
+
12
# UI element classes. The position of each name is the integer class id used
# to look it up (CLASSES[cls_id]).
CLASSES = ['button', 'field', 'heading', 'iframe', 'image', 'label', 'link', 'text']

# Color palette for the different element types, stored as (B, G, R) tuples.
# The boxes are drawn with OpenCV on a BGR image that is converted to RGB
# afterwards, so the tuples must be in BGR order for the displayed color to
# match the name in the trailing comment (e.g. 'field' shows up as blue).
CLASS_COLORS = {
    'button': (113, 204, 46),    # Green
    'field': (219, 152, 52),     # Blue
    'heading': (182, 89, 155),   # Purple
    'iframe': (15, 196, 241),    # Yellow
    'image': (34, 126, 230),     # Orange
    'label': (156, 188, 26),     # Turquoise
    'link': (60, 76, 231),       # Red
    'text': (166, 165, 149),     # Gray
}

# Shared model instance; stays None until load_model() assigns it.
model = None
29
+
30
def load_model(model_path: str = "model/full_29.pth"):
    """Return the shared RF-DETR model, building it on the first call.

    The instance lives in the module-level ``model`` variable. Once that
    variable is set, later calls return it unchanged and ``model_path`` is
    not consulted again.

    Args:
        model_path: Weights file passed to ``RFDETRMedium`` as
            ``pretrain_weights`` when the model is first created.

    Returns:
        The module-level model instance.
    """
    global model

    # Fast path: the model was already created by an earlier call.
    if model is not None:
        return model

    print("Loading RF-DETR model...")
    model = RFDETRMedium(pretrain_weights=model_path, resolution=1600)
    model.eval()
    print("Model loaded successfully!")
    return model
39
+
40
def draw_detections(
    image: np.ndarray,
    boxes: List[Tuple[int, int, int, int]],
    scores: List[float],
    classes: List[int],
    thickness: int = 3,
    font_scale: float = 0.6
) -> np.ndarray:
    """Draw detection boxes and labels on a copy of the image.

    Args:
        image: Image to annotate. It is copied first; the input array is not
            modified.
        boxes: One ``(x1, y1, x2, y2)`` box per detection. Coordinates are
            converted with ``int()`` before drawing.
        scores: Confidence per detection, shown with two decimals in the label.
        classes: Class id per detection, used as an index into ``CLASSES``.
            An id outside that list is labelled ``class_<id>`` and drawn in
            white instead of raising.
        thickness: Line thickness of the box outline.
        font_scale: Scale factor for the label text.

    Returns:
        The annotated copy of ``image``.
    """
    font = cv2.FONT_HERSHEY_SIMPLEX
    text_thickness = 2
    img_with_boxes = image.copy()

    for box, score, cls_id in zip(boxes, scores, classes):
        x1, y1, x2, y2 = map(int, box)

        # Bounds-check the id: a plain CLASSES[cls_id] would raise on an id
        # past the end and silently pick the wrong name for a negative id.
        if 0 <= cls_id < len(CLASSES):
            class_name = CLASSES[cls_id]
        else:
            class_name = f"class_{cls_id}"
        color = CLASS_COLORS.get(class_name, (255, 255, 255))

        # Draw rectangle
        cv2.rectangle(img_with_boxes, (x1, y1), (x2, y2), color, thickness)

        # Prepare label
        label = f"{class_name} {score:.2f}"

        # Calculate label size and position
        (label_width, label_height), baseline = cv2.getTextSize(
            label, font, font_scale, thickness=text_thickness
        )

        # Place the label above the box, but never so high that the text
        # would start above the top edge of the image.
        label_y = max(y1 - 10, label_height + 10)

        # Draw label background (filled rectangle, thickness -1)
        cv2.rectangle(
            img_with_boxes,
            (x1, label_y - label_height - baseline - 5),
            (x1 + label_width + 5, label_y + baseline - 5),
            color,
            -1
        )

        # Draw label text
        cv2.putText(
            img_with_boxes,
            label,
            (x1 + 2, label_y - baseline - 5),
            font,
            font_scale,
            (255, 255, 255),
            thickness=text_thickness
        )

    return img_with_boxes
89
+
90
@torch.inference_mode()
def detect_ui_elements(
    image: Image.Image,
    confidence_threshold: float,
    line_thickness: int
) -> Tuple[Image.Image, str]:
    """
    Detect UI elements in the uploaded image

    Args:
        image: Input PIL Image, or None when nothing has been uploaded
        confidence_threshold: Minimum confidence score, passed to the model's
            ``predict`` call as ``threshold``
        line_thickness: Thickness of bounding box lines; converted to int
            before drawing

    Returns:
        Annotated image and detection summary text (Markdown). When ``image``
        is None the first element is None and the text asks for an upload.
    """
    if image is None:
        return None, "Please upload an image first."

    # Load model
    model = load_model()

    # Normalise to 3-channel RGB so RGBA / grayscale / palette inputs are
    # handled the same way as plain RGB ones.
    img_array = np.array(image.convert("RGB"))

    # Convert RGB to BGR for OpenCV
    img_bgr = cv2.cvtColor(img_array, cv2.COLOR_RGB2BGR)

    # Run detection (returns supervision Detections object)
    detections = model.predict(img_array, threshold=confidence_threshold)

    # Extract detection data as plain Python lists
    boxes = detections.xyxy.tolist()            # Bounding boxes in xyxy format
    scores = detections.confidence.tolist()     # Confidence scores
    class_ids = detections.class_id.tolist()    # Class IDs

    # Draw detections. The thickness is cast because OpenCV only accepts an
    # integer and a UI slider may deliver the value as a float.
    annotated_img = draw_detections(
        img_bgr,
        boxes,
        scores,
        class_ids,
        thickness=int(line_thickness)
    )

    # Convert back to RGB for display
    annotated_img_rgb = cv2.cvtColor(annotated_img, cv2.COLOR_BGR2RGB)
    annotated_pil = Image.fromarray(annotated_img_rgb)

    # Create summary text
    summary_lines = [f"**Total detections:** {len(boxes)}\n"]

    # Count by class; ids outside CLASSES get a generic name instead of
    # raising or wrapping around to the wrong entry.
    class_counts = {}
    for cls_id in class_ids:
        if 0 <= cls_id < len(CLASSES):
            class_name = CLASSES[cls_id]
        else:
            class_name = f"class_{cls_id}"
        class_counts[class_name] = class_counts.get(class_name, 0) + 1

    if class_counts:
        summary_lines.append("**Detected elements:**")
        for class_name in sorted(class_counts):
            summary_lines.append(f"- {class_name}: {class_counts[class_name]}")
    else:
        # Avoid an empty "Detected elements" header when nothing was found.
        summary_lines.append("_No elements detected at this confidence threshold._")

    summary_text = "\n".join(summary_lines)

    return annotated_pil, summary_text
157
+
158
# Gradio interface: a two-column page with the upload and settings on the
# left and the annotated result plus summary on the right.
with gr.Blocks(title="RF-DETR UI Element Detector", theme=gr.themes.Soft()) as demo:

    gr.Markdown("""
    # 🎯 RF-DETR UI Element Detector

    Upload a screenshot or UI mockup to automatically detect interactive elements.
    This model identifies 8 types of UI components: buttons, fields, headings, iframes, images, labels, links, and text.
    """)

    with gr.Row():
        with gr.Column(scale=1):
            input_image = gr.Image(
                type="pil",
                label="📤 Upload Screenshot",
                height=400
            )

            with gr.Accordion("⚙️ Detection Settings", open=True):
                confidence_slider = gr.Slider(
                    minimum=0.1,
                    maximum=0.9,
                    value=0.35,
                    step=0.05,
                    label="Confidence Threshold",
                    info="Higher values = fewer but more confident detections"
                )

                thickness_slider = gr.Slider(
                    minimum=1,
                    maximum=6,
                    value=3,
                    step=1,
                    label="Box Line Thickness"
                )

            detect_button = gr.Button("🔍 Detect Elements", variant="primary", size="lg")

            # Legend for the box colors, one entry per name in CLASSES.
            gr.Markdown("""
            ### 📊 Detected Classes:
            - 🟢 **button** - Interactive buttons
            - 🔵 **field** - Input fields
            - 🟣 **heading** - Headers and titles
            - 🟡 **iframe** - Embedded frames
            - 🟠 **image** - Images and icons
            - 🔷 **label** - Text labels
            - 🔴 **link** - Hyperlinks
            - ⚪ **text** - Plain text
            """)

        with gr.Column(scale=1):
            output_image = gr.Image(
                type="pil",
                label="🎨 Detected Elements",
                height=400
            )

            summary_output = gr.Markdown(label="📋 Detection Summary")

    # Examples. The heading and gallery are only rendered when at least one
    # path is listed, so the page does not show an empty section.
    example_images = [
        # Add example image paths here if available
    ]
    if example_images:
        gr.Markdown("### 💡 Try with example images:")
        gr.Examples(
            examples=example_images,
            inputs=input_image,
            label="Example Screenshots"
        )

    # Connect button: run detection on click and fill both outputs.
    detect_button.click(
        fn=detect_ui_elements,
        inputs=[input_image, confidence_slider, thickness_slider],
        outputs=[output_image, summary_output]
    )

    gr.Markdown("""
    ---
    **Model:** RF-DETR Medium (Resolution: 1600px) | **Framework:** PyTorch
    """)

# Launch only when executed as a script, not when imported.
if __name__ == "__main__":
    demo.queue().launch(share=False)