# forklift-video / app.py
# Step 1 — [app:video] adding video demo (commit 68f21ca)
import gradio as gr
import torch
import torch.nn as nn
from torchvision import transforms
from torchvision.models import resnet18
from transferwee import download
# Backbone: ResNet-18 with ImageNet weights (torchvision downloads them on
# first run). NOTE(review): `pretrained=True` is deprecated in recent
# torchvision releases in favour of `weights=ResNet18_Weights.DEFAULT` —
# confirm the installed version before changing.
model = resnet18(pretrained=True)

# Replace the 1000-way ImageNet classifier with a single-logit binary head
# (512 -> 16 -> 1); a sigmoid over this logit is used as the risk score in
# predict() below.
model.fc = nn.Sequential(
    nn.Linear(512, 16),
    nn.ReLU(),
    nn.Linear(16, 1)
)
# Download the latest fine-tuned checkpoint and restore it into the model.
# NOTE(review): this is commented out, so the app currently runs with the
# ImageNet backbone and a RANDOMLY INITIALISED head — predictions are not
# meaningful until a checkpoint is loaded. Re-enable before deploying.
# download("https://we.tl/t-bbgc3gXROZ", "best.pt") # 1
# download("https://we.tl/t-25s74dahjU", "best.pt") # 4 --> 0.92
# checkpoint = torch.load("best.pt", map_location=torch.device('cpu'))
# model.load_state_dict(checkpoint['model_state_dict'])

# Inference only: freezes batch-norm statistics and disables dropout.
model.eval()
# Human-readable names for the binary head's class indices
# (sigmoid output is the score for class 1, "risk").
labels_to_class = {0: "normal", 1: "risk"}
def predict(inp):
    """Classify a single video frame as normal vs. risky forklift operation.

    Args:
        inp: a PIL image (RGB) supplied by the Gradio ``Image`` input.

    Returns:
        dict mapping display label -> confidence in [0, 1], the format
        expected by ``gr.Label``. Both classes are always reported so the
        label component can rank them.
    """
    inp = transforms.ToTensor()(inp).unsqueeze(0)  # [1, C, H, W]
    with torch.no_grad():
        # Single-logit binary head: sigmoid gives P(class 1), i.e. P("risk")
        # per labels_to_class. BUGFIX: the previous version inverted the
        # branches, reporting "Normal" for high risk scores and vice versa.
        p_risk = float(torch.sigmoid(model(inp)[0])[0])
    confidences = {
        "Riesgo": p_risk,
        "Normal": 1.0 - p_risk,
    }
    print(confidences)
    return confidences
# HTML description shown above the Gradio interface.
# BUGFIX: the <i>…</i> tag was previously split across lines
# ("<i>vídeo<" / "/i>"), producing broken markup in the rendered page;
# the surrounding sentence was also ungrammatical.
description = """
<center>
Este es nuestro clasificador de video de uso de grúas horquillas.\n\n
A partir de múltiples frames, nuestro modelo compone un <i>vídeo</i>;
el objetivo es poder determinar si la operación es de riesgo o no.\n\n
Nuestro modelo utiliza una red convolucional pre-entrenada y, posteriormente, finetuneada en nuestro conjunto de datos.
</center>
"""
# Sample frames bundled with the demo, one per class (normal / risk).
examples = ["videos/normal-1_frame_2.jpg", "videos/risk-0_frame_177.jpg"]
# Wire the classifier into a simple image-in / label-out Gradio UI and
# serve it locally (share=False: no public tunnel).
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=gr.Label(num_top_classes=3),
    # title="Forklikt Risk Detection Demo",
    description=description,
    examples=examples,
)
demo.launch(share=False)