# SAM3-Demo / app.py

import os
import gradio as gr
import numpy as np
import torch
from PIL import Image
from typing import Iterable
from gradio.themes import Soft
from gradio.themes.utils import colors, fonts, sizes
from transformers import Sam3Processor, Sam3Model

# --- Handle optional 'spaces' import for local compatibility ---
try:
    import spaces
except ImportError:
    class spaces:
        @staticmethod
        def GPU(duration=60):
            def decorator(func):
                return func
            return decorator
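
# On Hugging Face ZeroGPU Spaces, spaces.GPU requests a GPU for each call to
# the decorated function (for up to `duration` seconds); the fallback class
# above turns the decorator into a no-op so the app also runs locally on
# whatever `device` resolves to below.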

# --- Custom Theme Setup (Plum) ---
colors.plum = colors.Color(
    name="plum",
    c50="#FDF4FD",
    c100="#F7E6F7",
    c200="#ECD0EC",
    c300="#DDA0DD",  # Plum
    c400="#C98BC9",
    c500="#B060B0",
    c600="#964B96",
    c700="#7A3A7A",
    c800="#602C60",
    c900="#451E45",
    c950="#2B122B",
)

class PlumTheme(Soft):
    def __init__(
        self,
        *,
        primary_hue: colors.Color | str = colors.plum,
        secondary_hue: colors.Color | str = colors.plum,
        neutral_hue: colors.Color | str = colors.slate,
        text_size: sizes.Size | str = sizes.text_lg,
        font: fonts.Font | str | Iterable[fonts.Font | str] = (
            fonts.GoogleFont("Outfit"), "Arial", "sans-serif",
        ),
        font_mono: fonts.Font | str | Iterable[fonts.Font | str] = (
            fonts.GoogleFont("IBM Plex Mono"), "ui-monospace", "monospace",
        ),
    ):
        super().__init__(
            primary_hue=primary_hue,
            secondary_hue=secondary_hue,
            neutral_hue=neutral_hue,
            text_size=text_size,
            font=font,
            font_mono=font_mono,
        )
        self.set(
            background_fill_primary="*primary_50",
            background_fill_primary_dark="*primary_900",
            body_background_fill="linear-gradient(135deg, *primary_100, *primary_50)",
            body_background_fill_dark="linear-gradient(135deg, *primary_900, *primary_800)",
            button_primary_text_color="white",
            button_primary_text_color_hover="white",
            button_primary_background_fill="linear-gradient(90deg, *primary_500, *primary_600)",
            button_primary_background_fill_hover="linear-gradient(90deg, *primary_600, *primary_700)",
            button_primary_background_fill_dark="linear-gradient(90deg, *primary_600, *primary_800)",
            button_primary_background_fill_hover_dark="linear-gradient(90deg, *primary_500, *primary_500)",
            button_secondary_text_color="black",
            button_secondary_text_color_hover="white",
            button_secondary_background_fill="linear-gradient(90deg, *primary_200, *primary_200)",
            button_secondary_background_fill_hover="linear-gradient(90deg, *primary_400, *primary_400)",
            button_secondary_background_fill_dark="linear-gradient(90deg, *primary_500, *primary_600)",
            button_secondary_background_fill_hover_dark="linear-gradient(90deg, *primary_500, *primary_500)",
            slider_color="*primary_500",
            slider_color_dark="*primary_600",
            block_title_text_weight="600",
            block_border_width="3px",
            block_shadow="*shadow_drop_lg",
            button_primary_shadow="*shadow_drop_lg",
            button_large_padding="11px",
            color_accent_soft="*primary_100",
            block_label_background_fill="*primary_200",
        )

plum_theme = PlumTheme()

# --- Hardware Setup ---
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

# --- Model Loading ---
try:
    print("Loading SAM3 Model and Processor...")
    model = Sam3Model.from_pretrained("facebook/sam3").to(device)
    processor = Sam3Processor.from_pretrained("facebook/sam3")
    print("Model loaded successfully.")
except Exception as e:
    print(f"Error loading model: {e}")
    model = None
    processor = None
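
# Optional memory saver (an untested sketch, not part of the original app):
# transformers' from_pretrained accepts a torch_dtype argument, so on a CUDA
# device the weights could be loaded in half precision instead, e.g.:
#   model = Sam3Model.from_pretrained("facebook/sam3", torch_dtype=torch.float16).to(device)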

@spaces.GPU(duration=60)
def process_image(input_image, task_type, text_prompt, threshold=0.5):
    if input_image is None:
        raise gr.Error("Please upload an image.")
    if model is None or processor is None:
        raise gr.Error("Model not loaded correctly.")

    # Work in RGB regardless of the uploaded format
    image_pil = input_image.convert("RGB")
    annotations = []

    with torch.no_grad():
        if task_type == "Instance Segmentation":
            if not text_prompt:
                raise gr.Error("Please enter a text prompt for Instance Segmentation.")

            # 1. Instance segmentation flow: prompt the model with text
            inputs = processor(images=image_pil, text=text_prompt, return_tensors="pt").to(device)
            outputs = model(**inputs)

            # Post-process into per-instance masks and confidence scores
            results = processor.post_process_instance_segmentation(
                outputs,
                threshold=threshold,
                mask_threshold=0.5,
                target_sizes=inputs.get("original_sizes").tolist()
            )[0]

            masks_np = results['masks'].cpu().numpy()  # [N, H, W]
            scores_np = results['scores'].cpu().numpy()

            for i, mask in enumerate(masks_np):
                score_val = scores_np[i]
                label = f"{text_prompt} ({score_val:.2f})"
                annotations.append((mask, label))

        elif task_type == "Semantic Segmentation":
            # 2. Semantic segmentation flow: no text prompt needed
            inputs = processor(images=image_pil, return_tensors="pt").to(device)
            outputs = model(**inputs)

            # outputs.semantic_seg has shape [batch, channels, H, W];
            # drop the batch dimension -> [C, H, W]
            seg_map = outputs.semantic_seg.squeeze(0)

            if seg_map.shape[0] == 1:
                # Single channel: treat the values as logits, apply sigmoid,
                # then binarize at the user-selected threshold.
                mask_tensor = torch.sigmoid(seg_map[0]) > threshold
                mask_np = mask_tensor.cpu().numpy()

                # The map may come back at feature-map resolution; resize it
                # to the original image size so AnnotatedImage's overlay
                # matches the base image.
                if mask_np.shape != (image_pil.height, image_pil.width):
                    mask_img = Image.fromarray(mask_np.astype(np.uint8) * 255)
                    mask_img = mask_img.resize(image_pil.size, Image.NEAREST)
                    mask_np = np.array(mask_img) > 128
                annotations.append((mask_np, "Semantic Region"))
            else:
                # Multiple channels (classes): argmax over the class axis and
                # visualize everything that is not background (class 0).
                mask_idx = torch.argmax(seg_map, dim=0).cpu().numpy()
                mask_np = mask_idx > 0
                if mask_np.shape != (image_pil.height, image_pil.width):
                    mask_img = Image.fromarray(mask_np.astype(np.uint8) * 255)
                    mask_img = mask_img.resize(image_pil.size, Image.NEAREST)
                    mask_np = np.array(mask_img) > 128
                annotations.append((mask_np, "Segmented Objects"))

    # AnnotatedImage expects (base_image, [(mask, label), ...])
    return (image_pil, annotations)
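
# Illustrative usage outside Gradio (a sketch only, not called by the app;
# the path is one of the bundled example assets listed further below):
#   img = Image.open("examples/cat.jpg")
#   base_img, anns = process_image(img, "Instance Segmentation", "cat", threshold=0.4)
#   # anns is a list of (HxW mask array, "label (score)") tuples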

# --- UI Logic ---
css = """
#col-container {
    margin: 0 auto;
    max-width: 1100px;
}
#main-title h1 {
    font-size: 2.1em !important;
    display: flex;
    align-items: center;
    justify-content: center;
    gap: 10px;
}
"""

def update_visibility(task):
    if task == "Instance Segmentation":
        return gr.update(visible=True)
    else:
        return gr.update(visible=False)

with gr.Blocks(css=css, theme=plum_theme) as demo:
    with gr.Column(elem_id="col-container"):
        # Header with Logo
        gr.Markdown(
            "# **SAM3 Image Segmentation** <img src='https://huggingface.co/spaces/prithivMLmods/Qwen-Image-Edit-2509-LoRAs-Fast-Fusion/resolve/main/Lora%20Huggy.png' alt='Logo' width='35' height='35' style='display: inline-block; vertical-align: text-bottom; margin-left: 5px;'>",
            elem_id="main-title"
        )
        gr.Markdown("Segment objects using **SAM3** (Segment Anything Model 3). Choose **Instance** for specific text prompts or **Semantic** for automatic segmentation.")

        with gr.Row():
            # Left Column: Inputs (Gradio expects integer scales, so the
            # 2:3 split below preserves the intended 1:1.5 width ratio)
            with gr.Column(scale=2):
                input_image = gr.Image(label="Input Image", type="pil", height=350)
                task_type = gr.Radio(
                    choices=["Instance Segmentation", "Semantic Segmentation"],
                    value="Instance Segmentation",
                    label="Task Type",
                    interactive=True
                )
                text_prompt = gr.Textbox(
                    label="Text Prompt",
                    placeholder="e.g., cat, ear, car wheel...",
                    info="Required for Instance Segmentation",
                    visible=True
                )
                threshold = gr.Slider(label="Confidence Threshold", minimum=0.0, maximum=1.0, value=0.4, step=0.05)
                run_button = gr.Button("Run Segmentation", variant="primary")
            # Right Column: Output
            with gr.Column(scale=3):
                output_image = gr.AnnotatedImage(label="Segmented Output", height=500)

        # Event: Hide text prompt when Semantic Segmentation is selected
        task_type.change(
            fn=update_visibility,
            inputs=[task_type],
            outputs=[text_prompt]
        )

        # Examples
        gr.Examples(
            examples=[
                ["examples/cat.jpg", "Instance Segmentation", "cat", 0.5],
                ["examples/room.jpg", "Semantic Segmentation", "", 0.5],
                ["examples/car.jpg", "Instance Segmentation", "tire", 0.4],
            ],
            inputs=[input_image, task_type, text_prompt, threshold],
            outputs=[output_image],
            fn=process_image,
            cache_examples=False,
            label="Examples"
        )

        run_button.click(
            fn=process_image,
            inputs=[input_image, task_type, text_prompt, threshold],
            outputs=[output_image]
        )

if __name__ == "__main__":
    demo.launch(ssr_mode=False, show_error=True)