|
|
import json
from typing import Any, Dict, List, Optional

import numpy as np
from rank_bm25 import BM25Okapi
|
|
|
|
|
class BM25Retriever:
    """BM25 keyword retriever over a JSON file of instruction/code records.

    Each record loaded by :meth:`process` is read through its ``"code"`` and
    ``"description_1"`` keys. Depending on ``mode`` the BM25 index is built
    over the descriptions (``"instruction"``) or over the code snippets
    (``"code"``). Text is tokenized by splitting on a single space, both when
    indexing and when querying.
    """

    def __init__(self, mode="instruction"):
        """Create an empty retriever.

        Args:
            mode: Which field is indexed: ``"instruction"`` or ``"code"``.

        Raises:
            AssertionError: If ``mode`` is not one of the two supported values.
        """
        # Kept as an assertion so callers observe the same exception type as
        # before. NOTE: assertions are skipped when Python runs with -O.
        assert mode in ("instruction", "code"), (
            "mode must be 'instruction' or 'code'"
        )
        self.bm25: Optional[BM25Okapi] = None
        self.content_input_path: str = ""
        self.mode = mode
        # Initialised here so the instance is fully formed before `process`.
        self.chunks: List[str] = []
        self.corpus: List[str] = []

    def process(self, content_input_path: str):
        """Load records from a JSON file and (re)build the BM25 index.

        Args:
            content_input_path: Path of a UTF-8 JSON file whose top-level
                value is iterated; every item is indexed with the keys
                ``"code"`` and ``"description_1"``.

        Raises:
            OSError: If the file cannot be opened.
            json.JSONDecodeError: If the file is not valid JSON.
            KeyError: If a record lacks ``"code"`` or ``"description_1"``.
        """
        self.content_input_path = content_input_path
        with open(content_input_path, "r", encoding="utf-8") as f:
            content = json.load(f)

        # Build both lists locally first so a malformed record cannot leave
        # `chunks` and `corpus` with different lengths on the instance.
        chunks = []
        corpus = []
        for record in content:
            chunks.append(record["code"])
            corpus.append(record["description_1"])
        self.chunks = chunks
        self.corpus = corpus

        # Select the field to index; an unknown mode or an empty data set
        # leaves the retriever without an index.
        documents = {"instruction": corpus, "code": chunks}.get(self.mode)
        if documents:
            self.bm25 = BM25Okapi([doc.split(" ") for doc in documents])
        else:
            self.bm25 = None

    def query(
        self,
        query: str,
        top_k: int = 1
    ) -> List[Dict[str, Any]]:
        """Return the best-scoring records for ``query``.

        Args:
            query: Query text; split on single spaces before scoring.
            top_k: Maximum number of results. Values larger than the number
                of indexed records are clamped to that number.

        Returns:
            A list of dicts with the keys ``"similarity score"``,
            ``"original instruction"`` and ``"code"``, sorted by descending
            score.

        Raises:
            ValueError: If ``top_k`` is not positive, or if no index has been
                built yet.
        """
        if top_k <= 0:
            raise ValueError("top_k must be a positive integer.")
        if self.bm25 is None or not self.chunks:
            raise ValueError(
                "BM25 model is not initialized. Call `process` first."
            )

        processed_query = query.split(" ")
        scores = self.bm25.get_scores(processed_query)

        # np.argpartition raises when kth is out of bounds, so never ask for
        # more results than there are scored documents.
        k = min(top_k, len(scores))
        top_k_indices = np.argpartition(scores, -k)[-k:]

        formatted_results = [
            {
                "similarity score": scores[i],
                "original instruction": self.corpus[i],
                "code": self.chunks[i],
            }
            for i in top_k_indices
        ]
        # argpartition only isolates the top k; order them best-first.
        formatted_results.sort(
            key=lambda x: x['similarity score'], reverse=True
        )
        return formatted_results