import os
# os.system("pip uninstall -y gradio")
# os.system("pip install gradio==4.44.1")
# os.system("pip install gradio_image_prompter")
import gradio as gr
import torch
from PIL import ImageDraw, Image, ImageFont
import numpy as np
import requests
from io import BytesIO
import matplotlib.pyplot as plt
from transformers import SamModel, SamProcessor
from gradio_image_prompter import ImagePrompter
# define variables
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# model_id = "facebook/sam-vit-huge"  # ~60 s per run
model_id = "Zigeng/SlimSAM-uniform-50"  # ~50 s per run
# model_id = "facebook/sam-vit-base"  # ~50 s per run
model = SamModel.from_pretrained(model_id).to(device)
processor = SamProcessor.from_pretrained(model_id)
# Description
title = "<center><strong><font size='8'> 🍔 Segment food with clicks 🍔 </font></strong></center>"
instruction = """ # Instruction
This segmentation tool is built with the Hugging Face SAM model. To label a true mask, follow these steps: \n
🔥 Step 1: Copy the link of a candidate image, paste it into 'Enter Image URL', and click 'Upload Image'. \n
🔥 Step 2: Add positive points (right click), negative points (middle click), and a bounding box (click and drag; at most ONE box) for the food. \n
🔥 Step 3: Click 'Segment with prompts' to segment the image, then check whether one of the 3 candidate masks is correct. \n
🔥 Step 4: If not, repeat adding prompts and segmenting until a correct mask is generated. Prompt history is retained unless the image is reloaded. \n
🔥 Step 5: Download the satisfactory segmentation image via the icon in the top-right corner of the image, and name it 'correct_seg_xxx', where xxx is the photo ID.
"""
css = "h1 { text-align: center } .about { text-align: justify; padding-left: 10%; padding-right: 10%; }"
# functions
def read_image(url):
    # Download the image and wrap it in the dict format ImagePrompter expects
    response = requests.get(url)
    img = Image.open(BytesIO(response.content))
    formatted_image = {
        "image": np.array(img),
        "points": [],
    }
    return formatted_image
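# A minimal usage sketch (the URL below is hypothetical, not from the app):
#   value = read_image("https://example.com/food.jpg")
#   value["image"]   # (H, W, 3) uint8 array for ImagePrompter to display
#   value["points"]  # empty until the user adds clicks or a box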
def get_mask_image(raw_image, mask):
    # Turn the boolean mask into an alpha channel: masked pixels opaque (255),
    # everything else transparent (0)
    tmp_mask = np.array(mask * 1).astype(np.uint8)
    tmp_mask[tmp_mask == 1] = 255
    tmp_mask2 = np.expand_dims(tmp_mask, axis=2)
    # Append the alpha channel to the RGB image, yielding an RGBA array
    tmp_img_arr = np.array(raw_image)
    tmp_img_arr = np.concatenate((tmp_img_arr, tmp_mask2), axis=2)
    return tmp_img_arr
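# Hedged illustration of the RGBA composition above, on made-up toy arrays:
#   demo_img = np.zeros((2, 2, 3), dtype=np.uint8)
#   demo_mask = np.array([[True, False], [False, True]])
#   get_mask_image(demo_img, demo_mask).shape  # -> (2, 2, 4)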
def format_prompt_points(points):
    # ImagePrompter encodes each prompt as [x1, y1, label1, x2, y2, label2]:
    # a box carries corner labels (2, 3); a click carries label 1 (positive) or 0 (negative)
    prompt_points = []
    point_labels = []
    prompt_boxes = []
    for point in points:
        print(point)
        if point[2] == 2.0 and point[5] == 3.0:  # bounding box
            prompt_boxes.append([point[0], point[1], point[3], point[4]])
        else:  # positive or negative click
            prompt_points.append([point[0], point[1]])
            label = 1 if point[2] == 1.0 else 0
            point_labels.append(label)
    # Nest to the batch shapes SamProcessor expects; pass None for absent prompt types
    prompt_points = [[prompt_points]] if len(prompt_points) > 0 else None
    point_labels = [point_labels] if len(point_labels) > 0 else None
    prompt_boxes = [prompt_boxes] if len(prompt_boxes) > 0 else None
    return prompt_points, point_labels, prompt_boxes
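# A worked sketch of the encoding assumed above, with made-up coordinates
# (the trailing 0.0/4.0 entries for clicks are assumed placeholders; only
# indices 0-2 and 5 are inspected by format_prompt_points):
#   pts = [[100.0, 120.0, 1.0, 0.0, 0.0, 4.0],    # positive click at (100, 120)
#          [200.0, 220.0, 0.0, 0.0, 0.0, 4.0],    # negative click at (200, 220)
#          [50.0, 60.0, 2.0, 300.0, 310.0, 3.0]]  # box from (50, 60) to (300, 310)
#   format_prompt_points(pts)
#   # -> ([[[[100.0, 120.0], [200.0, 220.0]]]], [[1, 0]], [[[50.0, 60.0, 300.0, 310.0]]])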
def segment_with_points(prompts):
    image = np.array(prompts["image"])  # Convert the image to a numpy array
    points = prompts["points"]  # Get the prompt points drawn by the user
    prompt_points, point_labels, prompt_boxes = format_prompt_points(points)
    print(prompt_points, point_labels, prompt_boxes)
    # segment
    inputs = processor(image,
                       input_boxes=prompt_boxes,
                       input_points=prompt_points,
                       input_labels=point_labels,
                       return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = model(**inputs)
    # Upscale the predicted low-resolution masks back to the original image size
    masks = processor.image_processor.post_process_masks(
        outputs.pred_masks.cpu(), inputs["original_sizes"].cpu(), inputs["reshaped_input_sizes"].cpu())
    scores = outputs.iou_scores
    # SAM proposes 3 candidate masks per prompt; render each as an RGBA cutout
    mask_images = [get_mask_image(image, m) for m in masks[0][0]]
    mask_img1, mask_img2, mask_img3 = mask_images
    return mask_img1, mask_img2, mask_img3
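# Optional sketch, not wired into the UI: the currently unused iou_scores could
# rank the three candidates automatically, e.g.
#   best_idx = outputs.iou_scores[0][0].argmax().item()
#   best_mask = masks[0][0][best_idx]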
def clear():
    # Reset the prompter and all three output images
    return None, None, None, None
with gr.Blocks(css=css, title='Segment Food with Prompts') as demo:
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown(title)
            gr.Markdown("")
            image_url = gr.Textbox(label="Enter Image URL",
                                   value="https://img.cdn4dd.com/u/media/4da0fbcf-5e3d-45d4-8995-663fbcf3c3c8.jpg")
            run_with_url = gr.Button("Upload Image")
            segment_btn = gr.Button("Segment with prompts", variant='primary')
            clear_btn = gr.Button("Clear points", variant='secondary')
        with gr.Column(scale=1):
            gr.Markdown(instruction)
    # Images
    with gr.Row(variant="panel"):
        with gr.Column(scale=0):
            candidate_pic = ImagePrompter(show_label=False)
            segpic_output1 = gr.Image(format="png")
        with gr.Column(scale=0):
            segpic_output2 = gr.Image(format="png")
            segpic_output3 = gr.Image(format="png")
    # Define interaction relationships
    run_with_url.click(read_image,
                       inputs=[image_url],
                       outputs=[candidate_pic])
    segment_btn.click(segment_with_points,
                      inputs=candidate_pic,
                      outputs=[segpic_output1, segpic_output2, segpic_output3])
    clear_btn.click(clear, outputs=[candidate_pic, segpic_output1, segpic_output2, segpic_output3])

demo.queue()
demo.launch()