| """
|
| VLM Soft Biometrics - Gradio Interface
|
| A web application for analyzing facial soft biometrics (age, gender, emotion) using Vision-Language Models.
|
| """
|
| import os
|
| import gradio as gr
|
| import torch
|
| import cv2
|
| import numpy as np
|
| from PIL import Image, ImageDraw, ImageFont
|
| import base64
|
| from io import BytesIO
|
| import traceback
|
| from huggingface_hub import snapshot_download
|
| from utils.face_detector import FaceDetector
|
|
|
|
|
| from src.model import MTLModel
|
| from utils.commons import get_backbone_pe
|
| from utils.task_config import Task
|
|
|
|
|
# Label sets shared by the model task definitions (TASKS) and the UI (CLASSES).
# Each consumer receives its own list copy.
_AGE_LABELS = ("0-2", "3-9", "10-19", "20-29", "30-39", "40-49", "50-59", "60-69", "70+")
_EMOTION_LABELS = ("Surprise", "Fear", "Disgust", "Happy", "Sad", "Angry", "Neutral")

# Model heads, matched by name ('Age', 'Gender', 'Emotion') when reading the
# model's output dict. criterion=None: presumably no loss is needed for
# inference-only use — TODO confirm against MTLModel.
TASKS = [
    Task(name='Age', class_labels=list(_AGE_LABELS), criterion=None),
    Task(name='Gender', class_labels=["Male", "Female"], criterion=None),
    Task(name='Emotion', class_labels=list(_EMOTION_LABELS), criterion=None),
]

# Display labels indexed as [age, gender, emotion]; gender is abbreviated for the UI.
CLASSES = [
    list(_AGE_LABELS),
    ["M", "F"],
    list(_EMOTION_LABELS),
]

# Runtime state shared across requests; (re)assigned by init_model() and
# load_model_and_update_status().
model = transform = detector = device = None
current_ckpt_dir = None  # checkpoint path most recently loaded successfully

CHECKPOINTS_DIR = './checkpoints/'  # download target and scan location for checkpoints
MODEL_REPO_ID = "Antuke/FaR-FT-PE"  # Hub repo the .pt/.pth checkpoints are fetched from
|
|
|
def scan_checkpoints(ckpt_dir):
    """Find checkpoint files (``.pt`` / ``.pth``) directly inside ``ckpt_dir``.

    Returns:
        tuple: ``(choices, default)``.
            - ``choices`` is a list of ``(file_name, full_path)`` pairs in
              sorted file-name order, suitable for a ``gr.Dropdown``.
            - ``default`` is the path of ``mtlora.pt`` when present, otherwise
              the first listed checkpoint.
        ``([], None)`` is returned when the directory is missing, cannot be
        listed, or contains no checkpoints.
    """
    if not os.path.exists(ckpt_dir):
        print(f"Warning: Checkpoint directory not found: {ckpt_dir}")
        return [], None

    try:
        found = [
            os.path.join(ckpt_dir, name)
            for name in sorted(os.listdir(ckpt_dir))
            if name.endswith(('.pt', '.pth'))
        ]
    except Exception as e:
        print(f"Error scanning checkpoint directory {ckpt_dir}: {e}")
        return [], None

    if not found:
        print(f"No checkpoints found in {ckpt_dir}")
        return [], None

    choices = [(os.path.basename(path), path) for path in found]
    preferred = os.path.join(ckpt_dir, 'mtlora.pt')
    default = preferred if preferred in found else found[0]
    return choices, default
|
|
|
def load_model(device, ckpt_dir='./checkpoints/mtlora.pt', pe_vision_config="PE-Core-L14-336"):
    """Build the multi-task model on a PE backbone and load checkpoint weights.

    Args:
        device: Target device (``torch.device`` or a string such as "cuda" or
            "cpu"). ``None`` selects CUDA when available, otherwise CPU.
        ckpt_dir: Path to the checkpoint *file* to load (despite the name).
        pe_vision_config: Backbone version passed to ``get_backbone_pe``.

    Returns:
        tuple: ``(model, transform)``, the ``MTLModel`` with loaded weights
        and the preprocessing transform returned by ``get_backbone_pe``.
    """
    # Honour the caller's device instead of always re-detecting it.
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    else:
        device = torch.device(device)

    backbone, transform, _ = get_backbone_pe(
        version=pe_vision_config, apply_migration_flag=True, pretrained=False
    )

    # MTL-LoRA checkpoints are recognised by "mtlora" in the file name. Test
    # only the basename so a parent directory name cannot switch the variant.
    use_mtl_lora = 'mtlora' in os.path.basename(ckpt_dir)

    model = MTLModel(
        backbone,
        device=device,
        tasks=TASKS,
        use_lora=True,
        use_deep_head=True,
        use_mtl_lora=use_mtl_lora,
    )
    print(f'loading from {ckpt_dir}')
    model.load_model(filepath=ckpt_dir, map_location=device)
    return model, transform
|
|
|
def load_model_and_update_status(model_filepath):
    """Load ``model_filepath`` unless it is already loaded, and report the outcome.

    This is used both at startup and as the checkpoint dropdown's ``change``
    handler. It returns a human-readable status string. Callers in this module
    detect failure by looking for "Failed" in that string, so that wording is
    kept stable.
    """
    global model, current_ckpt_dir

    if not model_filepath:
        return "No checkpoint selected."

    name = os.path.basename(model_filepath)

    if model is not None and model_filepath == current_ckpt_dir:
        status = f"Model already loaded: {name}"
        print(status)
        return status

    gr.Info(f"Loading model: {name}...")
    previous_model = model
    try:
        init_model(ckpt_dir=model_filepath, detection_confidence=0.5)
    except Exception as e:
        traceback.print_exc()
        # If init_model() replaced the global model before failing, the
        # remembered path no longer describes what is loaded. Forget it so the
        # next request reloads instead of reporting "already loaded".
        if model is not previous_model:
            current_ckpt_dir = None
        status = f"Failed to load {name}: {e}"
        # Errors are shown as a warning toast, not as an informational one.
        gr.Warning(f"Error: {status}")
        print(f"ERROR: {status}")
        return status

    current_ckpt_dir = model_filepath
    status = f"Successfully loaded: {name}"
    gr.Info("Model loaded successfully!")
    print(status)
    return status
|
|
|
def predict(model, image, class_labels=None):
    """Run the multi-task model on a batch and decode age, gender and emotion.

    Args:
        model: Callable returning a dict with 'Age', 'Gender' and 'Emotion'
            logit tensors. Each is assumed to be (batch, num_classes).
        image: Batch tensor passed straight to ``model``.
        class_labels: Optional ``[age, gender, emotion]`` label lists.
            Defaults to the module-level ``CLASSES``.

    Returns:
        list[dict]: One dict per batch item with 'age', 'gender' and 'emotion'
        entries. Each entry holds 'predicted_class', 'predicted_confidence'
        (a float) and 'all_probabilities' (label -> probability, in label order).

    Raises:
        ValueError: if a head's class count differs from its label list.
    """
    labels = CLASSES if class_labels is None else class_labels
    heads = (
        ('Age', 'age', labels[0]),
        ('Gender', 'gender', labels[1]),
        ('Emotion', 'emotion', labels[2]),
    )

    decoded = []
    with torch.no_grad():
        outputs = model(image)
        for head_name, key, names in heads:
            logits = outputs[head_name]
            if logits.shape[-1] != len(names):
                raise ValueError(
                    f"{head_name} head produced {logits.shape[-1]} scores "
                    f"but {len(names)} class labels were given"
                )
            # Copy each head to the host once instead of syncing per element.
            probs = torch.softmax(logits, dim=-1).cpu().tolist()
            top = torch.argmax(logits, dim=-1).cpu().tolist()
            decoded.append((key, names, probs, top))

    batch_size = len(decoded[0][3])
    results = []
    for i in range(batch_size):
        item = {}
        for key, names, probs, top in decoded:
            row, best = probs[i], top[i]
            item[key] = {
                'predicted_class': names[best],
                'predicted_confidence': row[best],
                'all_probabilities': dict(zip(names, row)),
            }
        results.append(item)
    return results
|
|
|
def get_centroid_weighted_age(probs, centroids=None):
    """Estimate a scalar age as the probability-weighted mean of age-bin centroids.

    Args:
        probs: Mapping from age-bin label to probability, in bin order. Only
            the values are used, positionally.
        centroids: Optional representative age for each bin, with the same
            length and order as ``probs``. Defaults to the midpoints of the
            nine bins "0-2", "3-9", ..., "60-69", plus 80 for the open-ended
            "70+" bin.

    Returns:
        float: the expected age.

    Raises:
        ValueError: if ``probs`` and ``centroids`` differ in length.
    """
    if centroids is None:
        # Midpoint (lo + hi) / 2 of each inclusive bin: "0-2" -> 1,
        # "3-9" -> 6, "10-19" -> 14.5, ...; 80 is a chosen value for "70+".
        centroids = (1, 6, 14.5, 24.5, 34.5, 44.5, 54.5, 64.5, 80)

    weights = list(probs.values())
    if len(weights) != len(centroids):
        raise ValueError(
            f"Expected {len(centroids)} age-bin probabilities, got {len(weights)}"
        )
    return sum(p * c for p, c in zip(weights, centroids))
|
|
|
|
|
def init_model(ckpt_dir="./checkpoints/mtlora.pt", detection_confidence=0.5):
    """Load a checkpoint and a face detector into the module-level globals.

    The new model, transform, detector and device are first built into
    locals. They are published to the globals together, only after every step
    has succeeded. A failure (missing file, bad checkpoint, detector error)
    therefore leaves the previously loaded state untouched.

    Args:
        ckpt_dir: Path to the checkpoint file.
        detection_confidence: Confidence threshold passed to ``FaceDetector``.

    Raises:
        FileNotFoundError: if ``ckpt_dir`` does not exist.
        Exception: anything raised while building the model or detector is
            propagated unchanged.
    """
    global model, transform, detector, device

    print(f"\n{'='*60}")
    print(f"INITIALIZING MODEL: {ckpt_dir}")
    print(f"{'='*60}")

    new_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {new_device}")

    if not os.path.exists(ckpt_dir):
        error_msg = f"Model weights not found: {ckpt_dir}."
        print(f"ERROR: {error_msg}")
        raise FileNotFoundError(error_msg)

    print(f"Model weights found: {ckpt_dir}")

    new_model, new_transform = load_model(ckpt_dir=ckpt_dir, device=new_device)
    new_model.eval()
    new_model.to(new_device)

    new_detector = FaceDetector(confidence_threshold=detection_confidence)

    # Commit all state at once so the globals never mix old and new objects.
    device = new_device
    model, transform, detector = new_model, new_transform, new_detector

    print("✓ Model and detector initialized successfully")
    print(f"{'='*60}\n")
|
|
|
# Stylesheet embedded at the top of every results panel built by process_image().
_RESULTS_CSS = """
<style>
:root {
    --primary-color: #4f46e5;
    --success-color: #10b981;

    --text-primary: var(--body-text-color);
    --text-secondary: var(--body-text-color-subdued);
    --background-dark: var(--background-fill-primary);
    --background-darker: var(--background-fill-secondary);
    --border-color: var(--border-color-primary);
}
.results-container {
    font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif;
    background: var(--background-darker);
    padding: 20px;
    border-radius: 12px;
    color: var(--text-primary);
}
.results-container h2 {
    color: var(--text-primary);
    margin-bottom: 20px;
}
.face-count {
    display: inline-block;
    background: var(--primary-color);
    color: white;
    padding: 4px 12px;
    border-radius: 20px;
    font-size: 0.9rem;
    font-weight: 500;
    margin-left: 8px;
}
.face-card {
    background: var(--background-dark);
    border-radius: 8px;
    padding: 20px;
    margin-top: 15px;
    border: 1px solid var(--border-color);
    display: flex;
    gap: 20px;
    align-items: flex-start;
}
.face-header {
    font-size: 1rem;
    font-weight: 600;
    margin-bottom: 20px;
    color: var(--text-primary);
}
.face-image-left {
    flex-shrink: 0;
    width: 336px;
    height: 336px;
    background: var(--background-darker);
    border-radius: 8px;
    overflow: hidden;
    border: 1px solid var(--border-color);
}
.face-image-left img {
    width: 100%;
    height: 100%;
    object-fit: cover;
}
.face-predictions-right {
    flex: 1;
    display: flex;
    flex-direction: column;
    gap: 10px;
}
.predictions-horizontal {
    display: flex;
    flex-direction: row;
    gap: 30px;
    justify-content: space-between;
}
.prediction-section {
    flex: 1;
    min-width: 0;
}
.prediction-category-label {
    font-size: 0.8rem;
    font-weight: 700;
    text-transform: uppercase;
    letter-spacing: 0.5px;
    color: var(--primary-color);
    margin-bottom: 8px;
    border-bottom: 2px solid var(--primary-color);
    padding-bottom: 4px;
}
.probabilities-list {
    display: flex;
    flex-direction: column;
    gap: 6px;
}
.probability-item {
    display: grid;
    grid-template-columns: 70px 1fr 55px;
    align-items: center;
    gap: 8px;
    padding: 4px 6px;
    border-radius: 4px;
}
.probability-item.predicted {
    background: rgba(79, 70, 229, 0.2);
    border-left: 3px solid var(--primary-color);
    padding-left: 8px;
}
.prob-class {
    font-size: 0.8rem;
    font-weight: 600;
    color: var(--text-primary);
    word-wrap: break-word; /* Ensure long class names wrap */
}
.probability-item.predicted .prob-class {
    color: var(--primary-color);
    font-weight: 700;
}
.prob-bar-container {
    height: 6px;
    background: var(--border-color);
    border-radius: 3px;
    overflow: hidden;
}
.prob-bar {
    height: 100%;
    background: linear-gradient(90deg, var(--primary-color), var(--success-color));
    border-radius: 3px;
    transition: width 0.6s ease;
}
.probability-item.predicted .prob-bar {
    background: var(--primary-color);
}
.prob-percentage {
    font-size: 0.75rem;
    font-weight: 500;
    color: var(--text-secondary);
    text-align: right;
}
.probability-item.predicted .prob-percentage {
    color: var(--primary-color);
    font-weight: 700;
}
@media (max-width: 1200px) {
    .predictions-horizontal {
        flex-direction: column;
        gap: 15px;
    }
}
@media (max-width: 900px) {
    .face-card {
        flex-direction: column;
    }
    .face-image-left {
        width: 100%;
        max-width: 336px;
        margin: 0 auto;
    }
    .probability-item {
        grid-template-columns: 60px 1fr 50px; /* Adjust for smaller screens */
    }
    .prob-class {
        font-size: 0.75rem;
    }
}
</style>
"""


def _to_rgb_pil(image):
    """Return ``image`` (PIL image or numpy array) as a new RGB PIL image."""
    if isinstance(image, Image.Image):
        # convert() always returns a new image, so drawing never mutates the input.
        return image.convert("RGB")
    return Image.fromarray(np.asarray(image)).convert("RGB")


def _load_label_font(size):
    """Return Pillow's default font at ``size``, or the unsized default if that fails."""
    try:
        return ImageFont.load_default(size=size)
    except (OSError, TypeError, ImportError):
        # Pillow < 10.1 has no ``size`` argument (TypeError). A sized default
        # font may also be unavailable without FreeType support.
        return ImageFont.load_default()


def _draw_face_annotation(draw, bbox, pred):
    """Draw the face box and its three prediction labels on ``draw``.

    ``bbox`` is unpacked as ``(x, y, w, h)``. Labels are stacked above the
    box when they fit, otherwise below it.
    """
    x, y, w, h = bbox
    # Scale the label font with face width, clamped to 12..48 px.
    font = _load_label_font(max(12, min(int(w * 0.08), 48)))

    draw.rectangle([(x, y), (x + w, y + h)], outline="lime", width=2)

    lines = [
        f"{title}: {pred[key]['predicted_class']} ({pred[key]['predicted_confidence']*100:.0f}%)"
        for title, key in (("Age", "age"), ("Gender", "gender"), ("Emotion", "emotion"))
    ]

    line_spacing = 10
    total_text_height = 0
    for line in lines:
        _left, top, _right, bottom = draw.textbbox((0, 0), line, font=font)
        total_text_height += (bottom - top) + line_spacing

    if y - total_text_height > 0:
        # Room above the box: stack lines upward from its top edge (baseline anchor).
        text_y = y - line_spacing
        for line in reversed(lines):
            left, top, right, bottom = draw.textbbox((x, text_y), line, font=font, anchor="ls")
            draw.rectangle([(left - 2, top - 2), (right + 2, bottom + 2)], fill="black")
            draw.text((x, text_y), line, font=font, fill="white", anchor="ls")
            text_y = top - line_spacing
    else:
        # Otherwise stack lines downward below the box (top anchor).
        text_y = y + h + line_spacing
        for line in lines:
            left, top, right, bottom = draw.textbbox((x, text_y), line, font=font, anchor="lt")
            draw.rectangle([(left - 2, top - 2), (right + 2, bottom + 2)], fill="black")
            draw.text((x, text_y), line, font=font, fill="white", anchor="lt")
            text_y = bottom + line_spacing


def _pil_to_data_uri(img_pil):
    """Encode a PIL image as a base64 JPEG ``data:`` URI for inline <img> tags."""
    buffered = BytesIO()
    img_pil.save(buffered, format="JPEG")
    return f"data:image/jpeg;base64,{base64.b64encode(buffered.getvalue()).decode()}"


def _prediction_section_html(title, task_pred):
    """Render one task column: a probability bar per class, with the predicted class highlighted."""
    rows = []
    for label, prob in task_pred['all_probabilities'].items():
        highlight = 'predicted' if label == task_pred['predicted_class'] else ''
        rows.append(f"""
<div class='probability-item {highlight}'>
    <span class='prob-class'>{label}</span>
    <div class='prob-bar-container'>
        <div class='prob-bar' style='width: {prob*100}%'></div>
    </div>
    <span class='prob-percentage'>{prob*100:.1f}%</span>
</div>
""")
    return f"""
<div class='prediction-section'>
    <div class='prediction-category-label'>{title}</div>
    <div class='probabilities-list'>
{''.join(rows)}
    </div>
</div>
"""


def _face_card_html(idx, face):
    """Render the card for one face: its thumbnail plus age/gender/emotion columns."""
    pred = face['predictions']
    age = get_centroid_weighted_age(pred['age']['all_probabilities'])
    sections = "".join(
        _prediction_section_html(title, pred[key])
        for title, key in (("Age", "age"), ("Gender", "gender"), ("Emotion", "emotion"))
    )
    return f"""
<div class='face-card'>
    <div class='face-image-left'>
        <img src='{_pil_to_data_uri(face['crop_image'])}' alt='Face {idx+1}'>
    </div>
    <div class='face-predictions-right'>
        <div class='face-header'>Face {idx+1} - Detection Confidence: {face['detection_confidence']:.1%} - Centroid Age: {int(age)}</div>
        <div class='predictions-horizontal'>
{sections}
        </div>
    </div>
</div>
"""


def _build_results_html(face_data):
    """Assemble the full results panel (stylesheet, header and one card per face)."""
    header = f"""
<div class='results-container'>
    <h2 style='margin-top: 0;'>Classification Results <span class='face-count'>{len(face_data)} face(s)</span></h2>
"""
    cards = "".join(_face_card_html(idx, face) for idx, face in enumerate(face_data))
    return _RESULTS_CSS + header + cards + "</div>"


def process_image(image, selected_checkpoint_path):
    """
    Process an uploaded image and return predictions with annotated image.

    If the selected checkpoint differs from the loaded one, it is loaded
    first.

    Args:
        image: PIL Image (any mode) or numpy array; it is converted to RGB.
        selected_checkpoint_path: The path from the checkpoint dropdown.

    Returns:
        tuple: (annotated_image, results_html). On an early exit or an error,
        the original ``image`` is returned together with an HTML message.
    """
    if image is None:
        return None, "<p style='color: red;'>Please upload an image</p>"

    if model is None or selected_checkpoint_path != current_ckpt_dir:
        print(f"Model mismatch or not loaded. Selected: {selected_checkpoint_path}, Current: {current_ckpt_dir}")
        status = load_model_and_update_status(selected_checkpoint_path)
        if "Failed" in status or "Error" in status:
            return image, f"<p style='color: red;'>Model Error: {status}</p>"

    # E.g. an empty dropdown selection loads nothing; stop instead of calling None.
    if model is None or detector is None:
        return image, "<p style='color: red;'>Model Error: No model is loaded. Please select a checkpoint.</p>"

    try:
        annotated = _to_rgb_pil(image)
        img_cv = cv2.cvtColor(np.array(annotated), cv2.COLOR_RGB2BGR)

        faces = detector.detect(img_cv, pad_rect=True)
        if faces is None or len(faces) == 0:
            return image, "<p style='color: orange;'>No faces detected in the image</p>"

        crops_pil = []
        face_data = []
        for crop, confidence, bbox in faces:
            crop_pil = Image.fromarray(cv2.cvtColor(crop, cv2.COLOR_BGR2RGB))
            crops_pil.append(crop_pil)
            face_data.append({
                'bbox': bbox,
                'detection_confidence': float(confidence),
                # Fixed-size thumbnail for the 336px image slot in the results HTML.
                'crop_image': crop_pil.resize((336, 336), Image.Resampling.LANCZOS),
            })

        batch_tensor = torch.stack([transform(crop_pil) for crop_pil in crops_pil]).to(device)
        predictions = predict(model, batch_tensor)
        for face, pred in zip(face_data, predictions):
            face['predictions'] = pred

        draw = ImageDraw.Draw(annotated)
        for face in face_data:
            _draw_face_annotation(draw, face['bbox'], face['predictions'])

        return annotated, _build_results_html(face_data)

    except Exception as e:
        traceback.print_exc()
        return image, f"<p style='color: red;'>Error processing image: {str(e)}</p>"
|
|
|
def _find_example_images(example_dir):
    """Return sorted image paths (.jpg/.jpeg/.png/.webp, any case) inside ``example_dir``.

    A missing directory yields ``[]``. A listing error is printed and also
    yields ``[]``.
    """
    if not os.path.exists(example_dir):
        return []
    try:
        return [
            os.path.join(example_dir, name)
            for name in sorted(os.listdir(example_dir))
            if name.lower().endswith(('.jpg', '.jpeg', '.png', '.webp'))
        ]
    except Exception as e:
        print(f"Error reading example images from {example_dir}: {e}")
        return []


def create_interface(checkpoint_list, default_checkpoint, initial_status):
    """Assemble the Gradio Blocks UI (not launched).

    Args:
        checkpoint_list: ``(label, path)`` choices for the checkpoint dropdown.
        default_checkpoint: Path pre-selected in the dropdown (may be None).
        initial_status: Text initially shown in the read-only status box.

    Returns:
        gr.Blocks: the assembled app.
    """
    custom_css = """
    .gradio-container {
        font-family: 'Arial', sans-serif;
    }
    .output-html {
        max-height: none !important;
        overflow-y: auto;
    }
    :root {
        --primary-color: #4f46e5;
        --success-color: #10b981;

        --text-primary: var(--body-text-color);
        --text-secondary: var(--body-text-color-subdued);
        --background-dark: var(--background-fill-primary);
        --background-darker: var(--background-fill-secondary);
        --border-color: var(--border-color-primary);
    }
    .results-container {
        font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif;
        background: var(--background-darker);
        padding: 20px;
        border-radius: 12px;
        color: var(--text-primary);
    }
    .results-container h2 {
        color: var(--text-primary);
        margin-bottom: 20px;
    }
    .face-count {
        display: inline-block;
        background: var(--primary-color);
        color: white;
        padding: 4px 12px;
        border-radius: 20px;
        font-size: 0.9rem;
        font-weight: 500;
        margin-left: 8px;
    }
    .face-card {
        background: var(--background-dark);
        border-radius: 8px;
        padding: 20px;
        margin-top: 15px;
        border: 1px solid var(--border-color);
        display: flex;
        gap: 20px;
        align-items: flex-start;
    }
    .face-header {
        font-size: 1rem;
        font-weight: 600;
        margin-bottom: 20px;
        color: var(--text-primary);
    }
    .face-image-left {
        flex-shrink: 0;
        width: 336px;
        height: 336px;
        background: var(--background-darker);
        border-radius: 8px;
        overflow: hidden;
        border: 1px solid var(--border-color);
    }
    .face-image-left img {
        width: 100%;
        height: 100%;
        object-fit: cover;
    }
    .face-predictions-right {
        flex: 1;
        display: flex;
        flex-direction: column;
        gap: 10px;
    }
    .predictions-horizontal {
        display: flex;
        flex-direction: row;
        gap: 30px;
        justify-content: space-between;
    }
    .prediction-section {
        flex: 1;
        min-width: 0;
    }
    .prediction-category-label {
        font-size: 0.8rem;
        font-weight: 700;
        text-transform: uppercase;
        letter-spacing: 0.5px;
        color: var(--primary-color);
        margin-bottom: 8px;
        border-bottom: 2px solid var(--primary-color);
        padding-bottom: 4px;
    }
    .probabilities-list {
        display: flex;
        flex-direction: column;
        gap: 6px;
    }
    .probability-item {
        display: grid;
        grid-template-columns: 70px 1fr 55px;
        align-items: center;
        gap: 8px;
        padding: 4px 6px;
        border-radius: 4px;
    }
    .probability-item.predicted {
        background: rgba(79, 70, 229, 0.2);
        border-left: 3px solid var(--primary-color);
        padding-left: 8px;
    }
    .prob-class {
        font-size: 0.8rem;
        font-weight: 600;
        color: var(--text-primary);
        word-wrap: break-word; /* Ensure long class names wrap */
    }
    .probability-item.predicted .prob-class {
        color: var(--primary-color);
        font-weight: 700;
    }
    .prob-bar-container {
        height: 6px;
        background: var(--border-color);
        border-radius: 3px;
        overflow: hidden;
    }
    .prob-bar {
        height: 100%;
        background: linear-gradient(90deg, var(--primary-color), var(--success-color));
        border-radius: 3px;
        transition: width 0.6s ease;
    }
    .probability-item.predicted .prob-bar {
        background: var(--primary-color);
    }
    .prob-percentage {
        font-size: 0.75rem;
        font-weight: 500;
        color: var(--text-secondary);
        text-align: right;
    }
    .probability-item.predicted .prob-percentage {
        color: var(--primary-color);
        font-weight: 700;
    }
    @media (max-width: 1200px) {
        .predictions-horizontal {
            flex-direction: column;
            gap: 15px;
        }
    }
    @media (max-width: 900px) {
        .face-card {
            flex-direction: column;
        }
        .face-image-left {
            width: 100%;
            max-width: 336px;
            margin: 0 auto;
        }
        .probability-item {
            grid-template-columns: 60px 1fr 50px; /* Adjust for smaller screens */
        }
        .prob-class {
            font-size: 0.75rem;
        }
    }
    """

    features_md = """
### Features
- **Age Classification**: 9 categories (0-2, 3-9, 10-19, 20-29, 30-39, 40-49, 50-59, 60-69, 70+) + Age estimation with weighted centroid average
- **Gender Classification**: M/F
- **Emotion Recognition**: 7 categories (Surprise, Fear, Disgust, Happy, Sad, Angry, Neutral)
- **Automatic Face Detection**: Detects and analyzes multiple faces
- **Detailed Probability Distributions**: View confidence for all classes
"""

    instructions_md = """
### Instructions
1. (Optional) Select a model checkpoint from the dropdown.
2. Upload an image or capture from webcam (or select an example below)
3. Click "Classify Image"
4. View detected faces with age, gender, and emotion predictions below
\n
Demo video of usage of this space: https://youtu.be/V6-9QTf1xaQ
"""

    with gr.Blocks(css=custom_css, title="Face Classification System", theme=gr.themes.Default()) as demo:
        with gr.Row():
            gr.Markdown("# Face Classification System")

        # Checkpoint selector next to the read-only load status.
        with gr.Row():
            with gr.Column(scale=3):
                ckpt_selector = gr.Dropdown(
                    label="Select Model Checkpoint",
                    choices=checkpoint_list,
                    value=default_checkpoint,
                )
            with gr.Column(scale=2):
                status_box = gr.Textbox(
                    label="Model Status",
                    value=initial_status,
                    interactive=False,
                )

        # Two equal-width help panels: features, then instructions.
        with gr.Row():
            for panel_text in (features_md, instructions_md):
                with gr.Column(scale=1):
                    gr.Markdown(panel_text)

        # Input image on the left, annotated result on the right.
        with gr.Row():
            with gr.Column(scale=1):
                image_in = gr.Image(
                    label="Upload Image",
                    type="pil",
                    sources=["upload", "webcam"],
                    height=400,
                )
            with gr.Column(scale=1):
                image_out = gr.Image(
                    label="Annotated Image",
                    type="pil",
                    height=400,
                )

        with gr.Row():
            with gr.Column(scale=1):
                classify_button = gr.Button("Classify Image", variant="primary", size="lg")

        examples = _find_example_images("example")
        if examples:
            gr.Markdown("### 📸 Try with example images")
            gr.Examples(examples=examples, inputs=image_in, cache_examples=False)

        with gr.Row():
            with gr.Column(scale=1):
                results_panel = gr.HTML(label="Classification Results", elem_classes="output-html")

        classify_button.click(
            fn=process_image,
            inputs=[image_in, ckpt_selector],
            outputs=[image_out, results_panel],
        )
        ckpt_selector.change(
            fn=load_model_and_update_status,
            inputs=[ckpt_selector],
            outputs=[status_box],
        )

    return demo
|
|
|
|
|
|
|
print("=" * 60)
print("VLM SOFT BIOMETRICS - GRADIO INTERFACE")
print("=" * 60)

# Fetch only the checkpoint files (*.pt / *.pth) from the Hub into
# CHECKPOINTS_DIR. A failed download is reported but not fatal: whatever is
# already on disk is still scanned below.
print(f"Downloading model weights from {MODEL_REPO_ID} to {CHECKPOINTS_DIR}...")
os.makedirs(CHECKPOINTS_DIR, exist_ok=True)
try:
    snapshot_download(
        repo_id=MODEL_REPO_ID,
        local_dir=CHECKPOINTS_DIR,
        allow_patterns=["*.pt", "*.pth"],
        # NOTE(review): newer huggingface_hub releases deprecate this flag;
        # confirm it is still needed for the pinned version.
        local_dir_use_symlinks=False,
    )
except Exception as e:
    print(f"CRITICAL: Failed to download models from Hub. {e}")
    traceback.print_exc()
else:
    print("Model download complete.")

checkpoint_list, default_checkpoint = scan_checkpoints(CHECKPOINTS_DIR)
if checkpoint_list:
    print(f"Found checkpoints: {len(checkpoint_list)} file(s).")
    print(f"Default checkpoint: {default_checkpoint}")
else:
    print(f"CRITICAL: No checkpoints found in {CHECKPOINTS_DIR}. App may not function.")

# Load the default checkpoint at startup; the resulting status text seeds the
# UI's status box.
if default_checkpoint:
    print(f"\nInitializing default model: {default_checkpoint}")
    initial_status_msg = load_model_and_update_status(default_checkpoint)
    print(initial_status_msg)
else:
    initial_status_msg = "No default model found. Please select one."
    print("Warning: No default model to load.")

print("Creating Gradio interface...")
demo = create_interface(checkpoint_list, default_checkpoint, initial_status_msg)
print("✓ Interface created successfully!")
|
|
|
|
|
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="VLM Soft Biometrics - Gradio Interface")
    parser.add_argument("--ckpt_dir", type=str, default="./checkpoints/",
                        help="Directory scanned for .pt/.pth checkpoints "
                             "(Hub weights are downloaded to ./checkpoints/ at startup)")
    parser.add_argument("--detection_confidence", type=float, default=0.5,
                        help="Confidence threshold for face detection")
    parser.add_argument("--port", type=int, default=7860,
                        help="Port to run the Gradio app")
    parser.add_argument("--share", action="store_true",
                        help="Create a public share link")
    parser.add_argument("--server_name", type=str, default="0.0.0.0",
                        help="Server name/IP to bind to")
    args = parser.parse_args()

    # The module-level code above (download, scan, default-model load, UI
    # construction) already ran against the default CHECKPOINTS_DIR before
    # argparse. When another directory is requested, redo the scan and rebuild
    # the UI; otherwise --ckpt_dir would only change the log line below.
    if os.path.abspath(args.ckpt_dir) != os.path.abspath(CHECKPOINTS_DIR):
        CHECKPOINTS_DIR = args.ckpt_dir
        checkpoint_list, default_checkpoint = scan_checkpoints(CHECKPOINTS_DIR)
        if default_checkpoint:
            initial_status_msg = load_model_and_update_status(default_checkpoint)
        else:
            initial_status_msg = "No default model found. Please select one."
        demo = create_interface(checkpoint_list, default_checkpoint, initial_status_msg)
    CHECKPOINTS_DIR = args.ckpt_dir

    # The checkpoint-loading path initialises the detector with a 0.5
    # threshold, so apply a non-default threshold to the loaded detector here.
    # NOTE: selecting another checkpoint in the UI re-runs that path and
    # resets the threshold to 0.5.
    if detector is not None and args.detection_confidence != 0.5:
        detector = FaceDetector(confidence_threshold=args.detection_confidence)

    print(f"\nLaunching server on {args.server_name}:{args.port}")
    print(f"Monitoring checkpoint directory: {CHECKPOINTS_DIR}")
    print("=" * 60)

    demo.launch(
        share=args.share,
        server_name=args.server_name,
        server_port=args.port,
        show_error=True,
    )