fokan committed on
Commit 89f2683 · verified · 1 Parent(s): f91a057

Upload 2 files

Files changed (1)
  1. utils/cache_manager.py +35 -16
utils/cache_manager.py CHANGED
@@ -1,8 +1,27 @@
 from functools import lru_cache
 from typing import Iterable, List, Tuple
 
+import torch
+from PIL import Image
+
 
 BATCH_SIZE = 50
+_CACHE_MODEL = None
+_CACHE_PROCESSOR = None
+
+
+def configure_cache(model, processor) -> None:
+    """Bind the shared model and processor for cached inference."""
+    global _CACHE_MODEL, _CACHE_PROCESSOR
+    _CACHE_MODEL = model
+    _CACHE_PROCESSOR = processor
+
+
+def preprocess_image(img_path: str) -> Image.Image:
+    img = Image.open(img_path)
+    img = img.convert("RGB")
+    img.thumbnail((448, 448))
+    return img
 
 
 def _ensure_tuple(labels: Iterable[str]) -> Tuple[str, ...]:
@@ -12,23 +31,19 @@ def _ensure_tuple(labels: Iterable[str]) -> Tuple[str, ...]:
 
 
 @lru_cache(maxsize=5)
-def cached_inference(image_path, labels, model, processor):
-    import torch
-    from PIL import Image
+def _cached_logits(image_path: str, label_tuple: Tuple[str, ...]) -> List[float]:
+    if _CACHE_MODEL is None or _CACHE_PROCESSOR is None:
+        raise RuntimeError("Cache manager not configured with model and processor.")
 
-    label_tuple: Tuple[str, ...] = _ensure_tuple(labels)
-
-    with Image.open(image_path).convert("RGB") as img:
-        tensor_image = img.copy()
-
-    device = next(model.parameters()).device
-    dtype = next(model.parameters()).dtype
+    device = next(_CACHE_MODEL.parameters()).device
+    dtype = next(_CACHE_MODEL.parameters()).dtype
+    image = preprocess_image(image_path)
     logits: List[float] = []
 
     with torch.no_grad():
         for start in range(0, len(label_tuple), BATCH_SIZE):
-            batch = label_tuple[start : start + BATCH_SIZE]
-            inputs = processor(images=tensor_image, text=list(batch), return_tensors="pt", padding=True)
+            batch = list(label_tuple[start : start + BATCH_SIZE])
+            inputs = _CACHE_PROCESSOR(images=image, text=batch, return_tensors="pt", padding=True)
 
             prepared = {}
             for key, value in inputs.items():
@@ -40,10 +55,14 @@ def cached_inference(image_path, labels, model, processor):
                 else:
                     prepared[key] = value
 
-            outputs = model(**prepared)
-            batch_logits = outputs.logits_per_image[0].detach().cpu().tolist()
-            logits.extend(batch_logits)
+            outputs = _CACHE_MODEL(**prepared)
+            logits.extend(outputs.logits_per_image[0].detach().cpu().tolist())
 
-    tensor_image.close()
+    image.close()
     scores = torch.softmax(torch.tensor(logits), dim=0).tolist()
     return scores
+
+
+def cached_inference(image_path: str, labels: Iterable[str]) -> List[float]:
+    label_tuple = _ensure_tuple(labels)
+    return _cached_logits(image_path, label_tuple)
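
The refactor splits the old cached_inference into an lru_cache-decorated _cached_logits keyed only on hashable arguments (image path and label tuple), with configure_cache binding the shared model and processor once at startup. Below is a minimal usage sketch, not part of this commit: the model id, image path, and labels are placeholders, and it assumes the bound pair is a Hugging Face CLIP-style model/processor that exposes logits_per_image.

# Hypothetical usage sketch; model id and file names are illustrative only.
from transformers import CLIPModel, CLIPProcessor

from utils.cache_manager import cached_inference, configure_cache

model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

# Bind the shared model/processor once; cached_inference no longer takes them
# as arguments, so lru_cache only ever sees hashable keys.
configure_cache(model, processor)

labels = ["a photo of a cat", "a photo of a dog"]
scores = cached_inference("example.jpg", labels)  # softmax over all labels
print(dict(zip(labels, scores)))

# A second call with the same path and labels is served from the lru_cache
# entry and skips the forward pass entirely.
scores_again = cached_inference("example.jpg", labels)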