File size: 2,261 Bytes
060ca74
ac43fd0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
060ca74
 
 
 
 
ac43fd0
 
 
 
 
060ca74
ac43fd0
 
060ca74
 
ac43fd0
 
060ca74
 
a8d3569
060ca74
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import cv2
import numpy as np
from PIL import Image
from huggingface_hub import hf_hub_download
from lang_sam import LangSAM


class AutoInpaintMaskGenerator:
    """Build inpainting masks from a text prompt using a LangSAM model."""

    def __init__(
        self,
        langsam_model: LangSAM | None = None,
    ):
        """
        Store the LangSAM model used for mask prediction.

        Args:
            langsam_model: an already-loaded LangSAM instance. When None, the
                "sam2.1_hiera_large.pt" checkpoint is fetched from the
                "facebook/sam2.1-hiera-large" Hugging Face repo and a new
                LangSAM instance is built from it.
        """
        if langsam_model is None:
            sam_path = hf_hub_download(
                repo_id="facebook/sam2.1-hiera-large",
                filename="sam2.1_hiera_large.pt",
            )
            langsam_model = LangSAM(
                "sam2.1_hiera_large",
                sam_path,
            )

        self.model = langsam_model

    def generate_mask(
        self,
        image: Image.Image,
        prompt: str,
        threshold: float = 0.3,
    ) -> np.ndarray:
        """
        Generate a binary mask for inpainting.

        Every predicted mask whose score is at least `threshold` is merged
        (logical OR) into a single mask, which is then dilated.

        Args:
            image: the image to segment.
            prompt: text describing the region(s) to mask.
            threshold: minimum mask score for a mask to be kept.

        Returns:
            A uint8 array with 255 for masked regions and 0 elsewhere
            (2D, (H, W), when the model returns masks shaped (N, H, W)).

        Raises:
            ValueError: if the model returns no masks, or if no mask reaches
                `threshold`.
        """
        result = self.model.predict(
            texts_prompt=[prompt],
            images_pil=[image],
        )[0]

        # Coerce to an array so that a plain (possibly empty) list of masks is
        # handled the same way as an ndarray.
        masks = np.asarray(result["masks"])  # expected (N, H, W)
        scores = np.atleast_1d(result["mask_scores"])  # always at least 1D

        # If only one mask was returned as (H, W), expand it to (1, H, W).
        if masks.ndim == 2:
            masks = masks[np.newaxis, :, :]

        if masks.ndim == 0 or masks.shape[0] == 0:
            raise ValueError("No masks found.")

        # Boolean selector, one entry per mask. Its length always equals the
        # number of masks, so emptiness must be tested with .any(), not len().
        valid_indices = scores >= threshold
        if not valid_indices.any():
            raise ValueError("No masks scored the required threshold.")

        combined_mask = np.any(masks[valid_indices], axis=0)

        # Convert to uint8 binary mask for inpainting (0 or 255).
        binary_mask = combined_mask.astype(np.uint8) * 255

        # Apply dilation, to give more flexibility to the inpainting model.
        kernel = np.ones((10, 10), np.uint8)
        return cv2.dilate(binary_mask, kernel, iterations=1)