Spaces:
Runtime error
Runtime error
| import torch | |
| from transformers import ( | |
| BlipForQuestionAnswering, | |
| BlipProcessor, | |
| ViltForQuestionAnswering, | |
| ViltProcessor, | |
| ) | |
| class ModelManager: | |
| """ | |
| Class to manage loading and caching of various VQA models from Hugging Face | |
| """ | |
| def __init__(self, cache_dir=None): | |
| """ | |
| Initialize the model manager | |
| Args: | |
| cache_dir (str, optional): Directory to cache models. Defaults to None. | |
| """ | |
| self.device = "cuda" if torch.cuda.is_available() else "cpu" | |
| self.cache_dir = cache_dir | |
| self.models = {} | |
| self.processors = {} | |
| # Print device being used | |
| print(f"Using device: {self.device}") | |
| def load_blip(self): | |
| """ | |
| Load BLIP model for VQA | |
| Returns: | |
| tuple: (processor, model) | |
| """ | |
| if "blip" not in self.models: | |
| print("Loading BLIP model for visual question answering...") | |
| # Load processor and model | |
| processor = BlipProcessor.from_pretrained("Salesforce/blip-vqa-base") | |
| model = BlipForQuestionAnswering.from_pretrained("Salesforce/blip-vqa-base") | |
| # Move model to appropriate device | |
| model.to(self.device) | |
| # Store model and processor | |
| self.models["blip"] = model | |
| self.processors["blip"] = processor | |
| print("BLIP model loaded successfully!") | |
| return self.processors["blip"], self.models["blip"] | |
| def load_vilt(self): | |
| """ | |
| Load ViLT model for VQA | |
| Returns: | |
| tuple: (processor, model) | |
| """ | |
| if "vilt" not in self.models: | |
| print("Loading ViLT model for visual question answering...") | |
| # Load processor and model | |
| processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-vqa") | |
| model = ViltForQuestionAnswering.from_pretrained("dandelin/vilt-b32-vqa") | |
| # Move model to appropriate device | |
| model.to(self.device) | |
| # Store model and processor | |
| self.models["vilt"] = model | |
| self.processors["vilt"] = processor | |
| print("ViLT model loaded successfully!") | |
| return self.processors["vilt"], self.models["vilt"] | |
| def get_model(self, model_name="blip"): | |
| """ | |
| Get a model by name | |
| Args: | |
| model_name (str, optional): Name of model to load. Defaults to "blip". | |
| Options: "blip", "vilt" | |
| Returns: | |
| tuple: (processor, model) | |
| """ | |
| if model_name.lower() == "blip": | |
| return self.load_blip() | |
| elif model_name.lower() == "vilt": | |
| return self.load_vilt() | |
| else: | |
| raise ValueError( | |
| f"Unknown model: {model_name}. Available models: blip, vilt" | |
| ) | |