Upload 2 files
Browse files- utils/cache_manager.py +35 -16
utils/cache_manager.py
CHANGED
|
@@ -1,8 +1,27 @@
|
|
| 1 |
from functools import lru_cache
|
| 2 |
from typing import Iterable, List, Tuple
|
| 3 |
|
|
|
|
|
|
|
|
|
|
| 4 |
|
| 5 |
BATCH_SIZE = 50
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6 |
|
| 7 |
|
| 8 |
def _ensure_tuple(labels: Iterable[str]) -> Tuple[str, ...]:
|
|
@@ -12,23 +31,19 @@ def _ensure_tuple(labels: Iterable[str]) -> Tuple[str, ...]:
|
|
| 12 |
|
| 13 |
|
| 14 |
@lru_cache(maxsize=5)
|
| 15 |
-
def
|
| 16 |
-
|
| 17 |
-
|
| 18 |
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
tensor_image = img.copy()
|
| 23 |
-
|
| 24 |
-
device = next(model.parameters()).device
|
| 25 |
-
dtype = next(model.parameters()).dtype
|
| 26 |
logits: List[float] = []
|
| 27 |
|
| 28 |
with torch.no_grad():
|
| 29 |
for start in range(0, len(label_tuple), BATCH_SIZE):
|
| 30 |
-
batch = label_tuple[start : start + BATCH_SIZE]
|
| 31 |
-
inputs =
|
| 32 |
|
| 33 |
prepared = {}
|
| 34 |
for key, value in inputs.items():
|
|
@@ -40,10 +55,14 @@ def cached_inference(image_path, labels, model, processor):
|
|
| 40 |
else:
|
| 41 |
prepared[key] = value
|
| 42 |
|
| 43 |
-
outputs =
|
| 44 |
-
|
| 45 |
-
logits.extend(batch_logits)
|
| 46 |
|
| 47 |
-
|
| 48 |
scores = torch.softmax(torch.tensor(logits), dim=0).tolist()
|
| 49 |
return scores
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
from functools import lru_cache
|
| 2 |
from typing import Iterable, List, Tuple
|
| 3 |
|
| 4 |
+
import torch
|
| 5 |
+
from PIL import Image
|
| 6 |
+
|
| 7 |
|
| 8 |
BATCH_SIZE = 50
|
| 9 |
+
_CACHE_MODEL = None
|
| 10 |
+
_CACHE_PROCESSOR = None
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def configure_cache(model, processor) -> None:
|
| 14 |
+
"""Bind the shared model and processor for cached inference."""
|
| 15 |
+
global _CACHE_MODEL, _CACHE_PROCESSOR
|
| 16 |
+
_CACHE_MODEL = model
|
| 17 |
+
_CACHE_PROCESSOR = processor
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def preprocess_image(img_path: str) -> Image.Image:
|
| 21 |
+
img = Image.open(img_path)
|
| 22 |
+
img = img.convert("RGB")
|
| 23 |
+
img.thumbnail((448, 448))
|
| 24 |
+
return img
|
| 25 |
|
| 26 |
|
| 27 |
def _ensure_tuple(labels: Iterable[str]) -> Tuple[str, ...]:
|
|
|
|
| 31 |
|
| 32 |
|
| 33 |
@lru_cache(maxsize=5)
|
| 34 |
+
def _cached_logits(image_path: str, label_tuple: Tuple[str, ...]) -> List[float]:
|
| 35 |
+
if _CACHE_MODEL is None or _CACHE_PROCESSOR is None:
|
| 36 |
+
raise RuntimeError("Cache manager not configured with model and processor.")
|
| 37 |
|
| 38 |
+
device = next(_CACHE_MODEL.parameters()).device
|
| 39 |
+
dtype = next(_CACHE_MODEL.parameters()).dtype
|
| 40 |
+
image = preprocess_image(image_path)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 41 |
logits: List[float] = []
|
| 42 |
|
| 43 |
with torch.no_grad():
|
| 44 |
for start in range(0, len(label_tuple), BATCH_SIZE):
|
| 45 |
+
batch = list(label_tuple[start : start + BATCH_SIZE])
|
| 46 |
+
inputs = _CACHE_PROCESSOR(images=image, text=batch, return_tensors="pt", padding=True)
|
| 47 |
|
| 48 |
prepared = {}
|
| 49 |
for key, value in inputs.items():
|
|
|
|
| 55 |
else:
|
| 56 |
prepared[key] = value
|
| 57 |
|
| 58 |
+
outputs = _CACHE_MODEL(**prepared)
|
| 59 |
+
logits.extend(outputs.logits_per_image[0].detach().cpu().tolist())
|
|
|
|
| 60 |
|
| 61 |
+
image.close()
|
| 62 |
scores = torch.softmax(torch.tensor(logits), dim=0).tolist()
|
| 63 |
return scores
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def cached_inference(image_path: str, labels: Iterable[str]) -> List[float]:
|
| 67 |
+
label_tuple = _ensure_tuple(labels)
|
| 68 |
+
return _cached_logits(image_path, label_tuple)
|