Spaces:
Runtime error
Runtime error
| import argparse | |
| import cv2 | |
| import numpy as np | |
| import torch | |
| import kornia as K | |
| from kornia.contrib import FaceDetector, FaceDetectorResult | |
| import gradio as gr | |
| import face_detection | |
def compare_detect_faces(img: np.ndarray,
                         confidence_threshold,
                         nms_threshold,
                         kornia_toggle,
                         retina_toggle,
                         retina_mobile_toggle,
                         dsfd_toggle
                         ):
    """Run every enabled face detector on ``img`` and return their outputs.

    Each ``*_toggle`` is compared against the exact string ``"On"``; any
    other value disables that detector and its slot in the result is ``None``.

    Args:
        img: Image handed unchanged to each enabled detector function.
        confidence_threshold: Forwarded to every enabled detector.
        nms_threshold: Forwarded to every enabled detector.
        kornia_toggle: ``"On"`` to run ``kornia_detect``.
        retina_toggle: ``"On"`` to run ``retina_detect``.
        retina_mobile_toggle: ``"On"`` to run ``retina_mobilenet_detect``.
        dsfd_toggle: ``"On"`` to run ``dsfd_detect``.

    Returns:
        A 4-tuple ``(kornia, retina, retina_mobile, dsfd)`` holding each
        detector's return value, or ``None`` where the detector was disabled.
    """
    thresholds = {
        "confidence_threshold": confidence_threshold,
        "nms_threshold": nms_threshold,
    }

    # Detectors are evaluated in the same order as the returned tuple. The
    # previous implementation also collected some results in a side list that
    # was never read; that dead accumulator has been dropped.
    kornia_detections = (
        kornia_detect(img, **thresholds) if kornia_toggle == "On" else None
    )
    retina_detections = (
        retina_detect(img, **thresholds) if retina_toggle == "On" else None
    )
    retina_mobile_detections = (
        retina_mobilenet_detect(img, **thresholds)
        if retina_mobile_toggle == "On" else None
    )
    dsfd_detections = (
        dsfd_detect(img, **thresholds) if dsfd_toggle == "On" else None
    )

    return kornia_detections, retina_detections, retina_mobile_detections, dsfd_detections
def scale_image(img: np.ndarray, size: int) -> np.ndarray:
    """Resize ``img`` to be ``size`` pixels wide, keeping its aspect ratio.

    Args:
        img: Image array whose first two dimensions are (height, width).
        size: Target width in pixels.

    Returns:
        The resized image as returned by ``cv2.resize``.

    Raises:
        ZeroDivisionError: If ``img`` has a width of zero.
    """
    h, w = img.shape[:2]
    # Use the requested width directly: computing it as ``w * (size / w)``
    # can land a hair below ``size`` in floating point and truncate to
    # ``size - 1``.
    new_w = int(size)
    # Integer arithmetic gives the exact floor of the scaled height. Clamp to
    # one pixel so extremely wide inputs do not ask cv2.resize for a
    # zero-height output.
    new_h = max(1, int(h * size // w))
    return cv2.resize(img, (new_w, new_h))
def base_detect(detector, img):
    """Run ``detector`` on a 640-wide copy of ``img`` and draw its boxes.

    Args:
        detector: Object exposing ``detect(image)`` that yields one array per
            detection.
        img: Image to analyse; it is rescaled with ``scale_image`` first.

    Returns:
        A copy of the rescaled image with one green, 1-pixel rectangle per
        detection.
    """
    resized = scale_image(img, 640)
    boxes = detector.detect(resized)

    annotated = resized.copy()
    for det in boxes:
        # The first two values are used as one corner and the next two as the
        # opposite corner; any further values (presumably a score - TODO
        # confirm) are ignored.
        corner_a = det[:2].astype(int).tolist()
        corner_b = det[2:4].astype(int).tolist()
        annotated = cv2.rectangle(annotated, corner_a, corner_b, (0, 255, 0), 1)
    return annotated
def retina_detect(img, confidence_threshold, nms_threshold):
    """Detect faces with the ``"RetinaNetResNet50"`` detector.

    A detector is built on every call with the given thresholds and then
    handed to ``base_detect``.

    Returns:
        The annotated image produced by ``base_detect``.
    """
    model = face_detection.build_detector(
        "RetinaNetResNet50",
        confidence_threshold=confidence_threshold,
        nms_iou_threshold=nms_threshold,
    )
    return base_detect(model, img)
def retina_mobilenet_detect(img, confidence_threshold, nms_threshold):
    """Detect faces with the ``"RetinaNetMobileNetV1"`` detector.

    A detector is built on every call with the given thresholds and then
    handed to ``base_detect``.

    Returns:
        The annotated image produced by ``base_detect``.
    """
    model = face_detection.build_detector(
        "RetinaNetMobileNetV1",
        confidence_threshold=confidence_threshold,
        nms_iou_threshold=nms_threshold,
    )
    return base_detect(model, img)
def dsfd_detect(img, confidence_threshold, nms_threshold):
    """Detect faces with the ``"DSFDDetector"`` detector.

    A detector is built on every call with the given thresholds and then
    handed to ``base_detect``.

    Returns:
        The annotated image produced by ``base_detect``.
    """
    model = face_detection.build_detector(
        "DSFDDetector",
        confidence_threshold=confidence_threshold,
        nms_iou_threshold=nms_threshold,
    )
    return base_detect(model, img)
def kornia_detect(img, confidence_threshold, nms_threshold):
    """Detect faces with Kornia's ``FaceDetector`` and draw their boxes.

    The image is rescaled to 400 pixels wide, run through the detector on
    the CPU, and every detection is outlined in green.

    Args:
        img: Image to analyse.
        confidence_threshold: Passed to ``FaceDetector``.
        nms_threshold: Passed to ``FaceDetector``.

    Returns:
        A copy of the rescaled image with one 1-pixel rectangle per face.
    """
    cpu = torch.device('cpu')
    resized = scale_image(img, 400)

    # Convert the array to a float tensor and swap its colour channels.
    # NOTE(review): bgr_to_rgb is only correct if ``img`` arrives in BGR
    # order - confirm what the caller actually supplies.
    tensor = K.image_to_tensor(resized, keepdim=False).to(cpu)
    tensor = K.color.bgr_to_rgb(tensor.float())

    # Named ``detector`` rather than ``face_detection`` so the local does not
    # hide the module of that name imported at the top of the file.
    detector = FaceDetector(confidence_threshold=confidence_threshold,
                            nms_threshold=nms_threshold).to(cpu)
    with torch.no_grad():
        raw = detector(tensor)
    # Only the first entry of the output is used (a single image was passed).
    faces = [FaceDetectorResult(row) for row in raw[0]]

    annotated = resized.copy()
    for face in faces:
        annotated = cv2.rectangle(
            annotated,
            face.top_left.int().tolist(),
            face.bottom_right.int().tolist(),
            (0, 255, 0),
            1,
        )
    return annotated
# Image widgets: one input plus one output panel per detector.
input_image = gr.components.Image()
image_kornia = gr.components.Image(label="Kornia YuNet")
image_retina = gr.components.Image(label="RetinaFace")
image_retina_mobile = gr.components.Image(label="Retina Mobilenet")
image_dsfd = gr.components.Image(label="DSFD")

# Thresholds shared by every detector.
confidence_slider = gr.components.Slider(
    minimum=0.1, maximum=0.95, value=0.5, step=0.05,
    label="Confidence Threshold",
)
nms_slider = gr.components.Slider(
    minimum=0.1, maximum=0.95, value=0.3, step=0.05,
    label="Non Maximum Suppression (NMS) Threshold",
)

# Per-detector on/off switches; the handler compares against the string "On".
# Labels match the corresponding output panels.
kornia_radio = gr.Radio(["On", "Off"], value="On", label="Kornia YuNet")
retinanet_radio = gr.Radio(["On", "Off"], value="On", label="RetinaFace")
retina_mobile_radio = gr.Radio(["On", "Off"], value="On", label="Retina Mobilenet")
dsfd_radio = gr.Radio(["On", "Off"], value="On", label="DSFD")
# Markdown shown under the title of the Gradio interface.
description = """This space lets you compare different face detection algorithms, based on Convolutional Neural Networks (CNNs).
The models used here are:
* Kornia YuNet: High Speed. Using the [Kornia Face Detection](https://kornia.readthedocs.io/en/latest/applications/face_detection.html) implementation
* RetinaFace: High Accuracy. Using the [RetinaFace](https://arxiv.org/pdf/1905.00641.pdf) implementation with ResNet50 backbone from the [face-detection library](https://github.com/hukkelas/DSFD-Pytorch-Inference)
* RetinaMobileNet: Mid Speed, Mid Accuracy. RetinaFace with a MobileNetV1 backbone, also from the [face-detection library](https://github.com/hukkelas/DSFD-Pytorch-Inference)
* DSFD: High Accuracy. [Dual Shot Face Detector](http://openaccess.thecvf.com/content_CVPR_2019/papers/Li_DSFD_Dual_Shot_Face_Detector_CVPR_2019_paper.pdf) from the [face-detection library](https://github.com/hukkelas/DSFD-Pytorch-Inference) as well.
"""
# Build the UI first so ``compare_iface`` refers to the Interface object;
# chaining ``.launch()`` onto the constructor bound the name to launch()'s
# return value instead.
compare_iface = gr.Interface(
    fn=compare_detect_faces,
    inputs=[
        input_image,
        confidence_slider,
        nms_slider,
        kornia_radio,
        retinanet_radio,
        retina_mobile_radio,
        dsfd_radio,
    ],
    outputs=[image_kornia, image_retina, image_retina_mobile, image_dsfd],
    # Each example row supplies one value per input, in the order above.
    examples=[
        ["data/50_Celebration_Or_Party_birthdayparty_50_25.jpg", 0.5, 0.3, "On", "On", "On", "On"],
        ["data/12_Group_Group_12_Group_Group_12_39.jpg", 0.5, 0.3, "On", "On", "On", "On"],
        ["data/31_Waiter_Waitress_Waiter_Waitress_31_55.jpg", 0.5, 0.3, "On", "On", "On", "On"],
        ["data/12_Group_Group_12_Group_Group_12_283.jpg", 0.5, 0.3, "On", "On", "On", "On"],
    ],
    title="Face Detections",
    description=description,
)
compare_iface.launch()