Spaces:
Runtime error
Runtime error
| import cohere | |
| from annoy import AnnoyIndex | |
| import numpy as np | |
| import dotenv | |
| import os | |
| import pandas as pd | |
# Load environment configuration before the API key is read below.
dotenv.load_dotenv()

# Embedding configuration: the topic list is embedded as documents, while
# queries use a different input type at lookup time.
model_name = "embed-english-v3.0"
input_type_embed = "search_document"
api_key = os.environ['COHERE_API_KEY']

# Cohere client shared by the indexing code here and by later lookups.
co = cohere.Client(api_key)

# Topic table; the 'topic_cleaned' column holds the texts to embed.
topics = pd.read_csv("aicovers_topics.csv")

# One embedding per topic, in the same order as the rows of `topics`.
list_embeds = co.embed(
    texts=list(topics['topic_cleaned']),
    model=model_name,
    input_type=input_type_embed,
).embeddings

# The index must be created with the embedding dimensionality.
search_index = AnnoyIndex(np.array(list_embeds).shape[1], metric='angular')

# Item ids are the row positions, so a hit maps straight back to `topics`.
for i, embedding in enumerate(list_embeds):
    search_index.add_item(i, embedding)

search_index.build(10)  # 10 trees
search_index.save('test.ann')
def topic_from_caption(caption: str) -> str:
    """
    Return the topic from the uploaded list that is semantically closest to the caption.

    The caption is embedded with the same model as the topics (using the
    query input type) and the single nearest item is looked up in the
    Annoy index. Index item ids are row positions in ``topics``.

    Args:
    - caption (str): The image caption generated by MS Azure.

    Returns:
    - str: The 'topic_cleaned' value of the nearest topic.

    Raises:
    - ValueError: If the index returns no neighbour for the caption.
    """
    input_type_query = "search_query"
    # Embed the caption; the response holds one embedding per input text.
    caption_embed = co.embed(
        texts=[caption],
        model=model_name,
        input_type=input_type_query,
    ).embeddings

    # Only the id of the nearest item is needed, so distances are not requested.
    nearest_ids = search_index.get_nns_by_vector(caption_embed[0], n=1)
    if not nearest_ids:
        raise ValueError("search index returned no neighbours for the caption")

    # Read the cell value directly instead of rendering a Series with
    # to_string(), which applies display formatting to the text.
    return str(topics['topic_cleaned'].iloc[nearest_ids[0]])