yrshi committed
Commit 8b4913f · 1 Parent(s): c55da95

first commit

Files changed (6)
  1. app.py +212 -53
  2. infer.py +150 -0
  3. install_cuda.sh +7 -0
  4. install_env.sh +9 -0
  5. retreival_launch.sh +11 -0
  6. retrieval_server.py +390 -0
app.py CHANGED
@@ -1,70 +1,229 @@
  import gradio as gr
- from huggingface_hub import InferenceClient


  def respond(
      message,
      history: list[dict[str, str]],
-     system_message,
      max_tokens,
      temperature,
      top_p,
-     hf_token: gr.OAuthToken,
  ):
      """
-     For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
      """
-     client = InferenceClient(token=hf_token.token, model="openai/gpt-oss-20b")
-
-     messages = [{"role": "system", "content": system_message}]
-
-     messages.extend(history)
-
-     messages.append({"role": "user", "content": message})
-
-     response = ""
-
-     for message in client.chat_completion(
-         messages,
-         max_tokens=max_tokens,
-         stream=True,
-         temperature=temperature,
-         top_p=top_p,
-     ):
-         choices = message.choices
-         token = ""
-         if len(choices) and choices[0].delta.content:
-             token = choices[0].delta.content
-
-         response += token
-         yield response
-
-
- """
- For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
- """
- chatbot = gr.ChatInterface(
-     respond,
-     type="messages",
-     additional_inputs=[
-         gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
-         gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
-         gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
-         gr.Slider(
-             minimum=0.1,
-             maximum=1.0,
-             value=0.95,
-             step=0.05,
-             label="Top-p (nucleus sampling)",
-         ),
-     ],
- )

  with gr.Blocks() as demo:
-     with gr.Sidebar():
-         gr.LoginButton()
-     chatbot.render()

  if __name__ == "__main__":
-     demo.launch()

+ import transformers
+ import torch
+ import requests
+ import re
  import gradio as gr
+ from threading import Thread

+ # --- Configuration --------------------------------------------------
+
+ # 1. DEFINE YOUR MODEL
+ model_id = "yrshi/AutoRefine-Qwen2.5-3B-Base"
+
+ # 2. !!! CRITICAL: UPDATE THIS URL !!!
+ # Your local 'http://127.0.0.1:8000/retrieve' will NOT work on Hugging Face.
+ # You must deploy your retrieval service and provide its public URL here.
+ RETRIEVER_URL = "http://127.0.0.1:8000/retrieve"  # <-- UPDATE ME
+
+ # 3. MODEL & SEARCH CONSTANTS
+ curr_eos = [151645, 151643]  # for Qwen2.5 series models
+ curr_search_template = '\n\n{output_text}<documents>{search_results}</documents>\n\n'
+ target_sequences = ["</search>", " </search>", "</search>\n", " </search>\n", "</search>\n\n", " </search>\n\n"]
+
+ # --- Global Model & Tokenizer Loading -------------------------------
+ # This happens once when the Space starts.
+ # Ensure your Space has a GPU assigned (e.g., T4, A10G).
+
+ print("Loading model and tokenizer...")
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+ tokenizer = transformers.AutoTokenizer.from_pretrained(model_id)
+ model = transformers.AutoModelForCausalLM.from_pretrained(
+     model_id,
+     torch_dtype=torch.bfloat16,
+     device_map="auto"
+ )
+ print("Model and tokenizer loaded successfully.")
+
+ # --- Custom Stopping Criteria Class ---------------------------------
+
+ class StopOnSequence(transformers.StoppingCriteria):
+     def __init__(self, target_sequences, tokenizer):
+         self.target_ids = [tokenizer.encode(target_sequence, add_special_tokens=False) for target_sequence in target_sequences]
+         self.target_lengths = [len(target_id) for target_id in self.target_ids]
+         self._tokenizer = tokenizer
+
+     def __call__(self, input_ids, scores, **kwargs):
+         targets = [torch.as_tensor(target_id, device=input_ids.device) for target_id in self.target_ids]
+         if input_ids.shape[1] < min(self.target_lengths):
+             return False
+         for i, target in enumerate(targets):
+             if torch.equal(input_ids[0, -self.target_lengths[i]:], target):
+                 return True
+         return False
+
+ # Initialize stopping criteria globally
+ stopping_criteria = transformers.StoppingCriteriaList([StopOnSequence(target_sequences, tokenizer)])
+
+ # --- Helper Functions (Search & Parse) ------------------------------
+
+ def get_query(text):
+     pattern = re.compile(r"<search>(.*?)</search>", re.DOTALL)
+     matches = pattern.findall(text)
+     return matches[-1] if matches else None
+
+ def search(query: str):
+     """
+     Calls your deployed retriever service.
+     """
+     payload = {"queries": [query], "topk": 3, "return_scores": True}
+
+     if RETRIEVER_URL == "http://127.0.0.1:8000/retrieve":
+         print("WARNING: Using default local retriever URL. This will likely fail.")
+         print("Please update RETRIEVER_URL in app.py to your deployed service.")
+
+     try:
+         response = requests.post(RETRIEVER_URL, json=payload, timeout=10)
+         response.raise_for_status()  # Raise an error for bad responses
+         results = response.json()['result']
+
+         format_reference = ''
+         for idx, doc_item in enumerate(results[0]):
+             content = doc_item['document']['contents']
+             title = content.split("\n")[0]
+             text = "\n".join(content.split("\n")[1:])
+             format_reference += f"Doc {idx+1}(Title: {title}) {text}\n"
+         return format_reference
+
+     except requests.exceptions.RequestException as e:
+         print(f"Error calling retriever: {e}")
+         return f"Error: Could not retrieve search results for query: {query}"
+     except (KeyError, IndexError):
+         print("Error parsing retriever response")
+         return "Error: Malformed response from retriever."
+
+ # --- Main Gradio 'respond' Function ---------------------------------

  def respond(
      message,
      history: list[dict[str, str]],
+     system_message,  # This is now our base prompt
      max_tokens,
      temperature,
      top_p,
+     hf_token: gr.OAuthToken = None,  # Not used here, but in template
  ):
      """
+     This function implements your local multi-turn search logic as a
+     streaming generator for the Gradio interface.
      """
+
+     question = message.strip()
+
+     # Use the system_message from the UI as the base prompt
+     # Or, if empty, use your default.
+     if not system_message:
+         system_message = """You are a helpful assistant excel at answering questions with multi-turn search engine calling. \
+ To answer questions, you must first reason through the available information using <think> and </think>. \
+ If you identify missing knowledge, you may issue a search request using <search> query </search> at any time. The retrieval system will provide you with the three most relevant documents enclosed in <documents> and </documents>. \
+ After each search, you need to summarize and refine the existing documents in <refine> and </refine>. \
+ You may send multiple search requests if needed. \
+ Once you have sufficient information, provide a concise final answer using <answer> and </answer>. For example, <answer> Donald Trump </answer>."""
+
+     prompt = f"{system_message} Question: {question}\n"
+
+     if tokenizer.chat_template:
+         # Apply chat template if it exists
+         # Note: Your logic builds the prompt manually, but this ensures
+         # correct special tokens if the model needs them.
+         chat_prompt = [{"role": "user", "content": prompt}]
+         prompt = tokenizer.apply_chat_template(chat_prompt, add_generation_prompt=True, tokenize=False)
+
+     # This string will accumulate the full agent trajectory
+     full_response_trajectory = ""
+
+     while True:
+         input_ids = tokenizer.encode(prompt, return_tensors='pt').to(device)
+         attention_mask = torch.ones_like(input_ids)
+
+         # Check for context overflow
+         if input_ids.shape[1] > model.config.max_position_embeddings - max_tokens:
+             print("Context limit reached.")
+             full_response_trajectory += "\n\n[Error: Context limit reached. Aborting.]"
+             yield full_response_trajectory
+             break
+
+         # Generate text with the stopping criteria
+         outputs = model.generate(
+             input_ids,
+             attention_mask=attention_mask,
+             max_new_tokens=max_tokens,
+             stopping_criteria=stopping_criteria,
+             pad_token_id=tokenizer.eos_token_id,
+             do_sample=True,
+             temperature=temperature,
+             top_p=top_p
+         )
+
+         # Decode the *newly* generated tokens
+         generated_token_ids = outputs[0][input_ids.shape[1]:]
+         output_text = tokenizer.decode(generated_token_ids, skip_special_tokens=True)
+
+         # Check if generation ended with an EOS token
+         if outputs[0][-1].item() in curr_eos:
+             full_response_trajectory += output_text
+             yield full_response_trajectory  # Yield the final text
+             break  # Exit the loop
+
+         # --- Generation stopped at </search> ---
+
+         # Get the full text (prompt + new generation) to parse the *last* query
+         full_generation_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
+         query_text = get_query(full_generation_text)
+
+         if query_text:
+             search_results = search(query_text)
+         else:
+             search_results = 'Error: Stop token found but no <search> query was parsed.'
+
+         # Construct the text to append to the prompt
+         search_text = curr_search_template.format(
+             output_text=output_text,
+             search_results=search_results
+         )
+
+         # Append to the prompt for the next loop
+         prompt += search_text
+
+         # Append to the trajectory string and yield to the UI
+         full_response_trajectory += search_text
+         yield full_response_trajectory
+
+
+ # --- Gradio UI (Example) -------------------------------------------
+ # This part is just to make the file runnable.
+ # You can customize your Gradio UI as needed.

  with gr.Blocks() as demo:
+     gr.Markdown("# Multi-Turn Search Agent")
+     gr.Markdown(f"Running model: `{model_id}`")
+
+     with gr.Accordion("Prompt & Parameters"):
+         system_message = gr.Textbox(
+             label="System Message",
+             value="""You are a helpful assistant... (full prompt from code)""",
+             lines=10
+         )
+         max_tokens = gr.Slider(50, 2048, value=1024, label="Max New Tokens")
+         temperature = gr.Slider(0.1, 1.0, value=0.7, label="Temperature")
+         top_p = gr.Slider(0.1, 1.0, value=1.0, label="Top-p")
+
+     chatbot = gr.Chatbot(label="Agent Trajectory")
+     msg = gr.Textbox(label="Your Question")
+
+     def user_turn(user_message, history):
+         return "", history + [[user_message, None]]

+     msg.submit(
+         user_turn,
+         [msg, chatbot],
+         [msg, chatbot],
+         queue=False
+     ).then(
+         respond,
+         [msg, chatbot, system_message, max_tokens, temperature, top_p],
+         chatbot
+     )

  if __name__ == "__main__":
+     demo.queue().launch(debug=True)
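Because respond() is a plain Python generator, it can also be exercised without the Gradio UI. The following is a minimal sketch (not part of the commit), assuming the model and tokenizer above loaded successfully and RETRIEVER_URL points at a reachable retrieval service; the `app` import path is illustrative:

# Sketch only: drive app.respond() outside Gradio and print the final trajectory.
from app import respond  # illustrative import path

final_trajectory = ""
for partial in respond(
    "Who is the sibling of the author of Kapalkundala?",  # message
    [],    # history (not consumed by the search loop)
    "",    # empty system_message -> the default multi-turn search prompt is used
    1024,  # max_tokens
    0.7,   # temperature
    0.95,  # top_p
):
    final_trajectory = partial  # each yield is the trajectory accumulated so far

print(final_trajectory)  # should end with <answer> ... </answer>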
infer.py ADDED
@@ -0,0 +1,150 @@
+ import transformers
+ import torch
+ import requests
+ import re
+
+ question_list = [
+     "Who was born first out of Cameron Mitchell (Singer) and Léopold De Saussure?",  # Ground Truth: "Léopold De Saussure"
+     "The Clavivox was invented by an American composer who was born Harry Warnow in what year?",  # Ground Truth: "1908"
+     "Which movie did Disney produce first, The Many Adventures of Winnie the Pooh or Ride a Wild Pony?",  # Ground Truth: "Ride a Wild Pony"
+     "Who is the sibling of the author of Kapalkundala?",  # Ground Truth: "Sanjib Chandra" or "Sanjib Chandra Chattopadhyay"
+ ]
+
+ # Model ID and device setup
+ model_id = "yrshi/AutoRefine-Qwen2.5-3B-Base"
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+ curr_eos = [151645, 151643]  # for Qwen2.5 series models
+ curr_search_template = '{output_text}\n\n<documents>{search_results}</documents>\n\n'
+
+ # Initialize the tokenizer and model
+ tokenizer = transformers.AutoTokenizer.from_pretrained(model_id)
+ model = transformers.AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map="auto")
+
+ # Define the custom stopping criterion
+ class StopOnSequence(transformers.StoppingCriteria):
+     def __init__(self, target_sequences, tokenizer):
+         # Encode the string so we have the exact token-IDs pattern
+         self.target_ids = [tokenizer.encode(target_sequence, add_special_tokens=False) for target_sequence in target_sequences]
+         self.target_lengths = [len(target_id) for target_id in self.target_ids]
+         self._tokenizer = tokenizer
+
+     def __call__(self, input_ids, scores, **kwargs):
+         # Make sure the target IDs are on the same device
+         targets = [torch.as_tensor(target_id, device=input_ids.device) for target_id in self.target_ids]
+
+         if input_ids.shape[1] < min(self.target_lengths):
+             return False
+
+         # Compare the tail of input_ids with our target_ids
+         for i, target in enumerate(targets):
+             if torch.equal(input_ids[0, -self.target_lengths[i]:], target):
+                 return True
+
+         return False
+
+ def get_query(text):
+     import re
+     pattern = re.compile(r"<search>(.*?)</search>", re.DOTALL)
+     matches = pattern.findall(text)
+     if matches:
+         return matches[-1]
+     else:
+         return None
+
+ def search(query: str):
+     payload = {
+         "queries": [query],
+         "topk": 3,
+         "return_scores": True
+     }
+     results = requests.post("http://127.0.0.1:8000/retrieve", json=payload).json()['result']
+
+     def _passages2string(retrieval_result):
+         format_reference = ''
+         for idx, doc_item in enumerate(retrieval_result):
+             content = doc_item['document']['contents']
+             title = content.split("\n")[0]
+             text = "\n".join(content.split("\n")[1:])
+             format_reference += f"Doc {idx+1}(Title: {title}) {text}\n"
+         return format_reference
+
+     return _passages2string(results[0])
+
+
+ # Initialize the stopping criteria
+ target_sequences = ["</search>", " </search>", "</search>\n", " </search>\n", "</search>\n\n", " </search>\n\n"]
+ stopping_criteria = transformers.StoppingCriteriaList([StopOnSequence(target_sequences, tokenizer)])
+
+
+ def run_search(question):
+     question = question.strip()
+     cnt = 0
+     trajectory = []
+
+     # Prepare the message
+     prompt = f"""You are a helpful assistant excel at answering questions with multi-turn search engine calling. \
+ To answer questions, you must first reason through the available information using <think> and </think>. \
+ If you identify missing knowledge, you may issue a search request using <search> query </search> at any time. The retrieval system will provide you with the three most relevant documents enclosed in <documents> and </documents>. \
+ After each search, you need to summarize and refine the existing documents in <refine> and </refine>. \
+ You may send multiple search requests if needed. \
+ Once you have sufficient information, provide a concise final answer using <answer> and </answer>. For example, <answer> Donald Trump </answer>. Question: {question}\n"""
+
+
+     if tokenizer.chat_template:
+         prompt = tokenizer.apply_chat_template([{"role": "user", "content": prompt}], add_generation_prompt=True, tokenize=False)
+
+     print(prompt)
+     # Encode the chat-formatted prompt and move it to the correct device
+     while True:
+         input_ids = tokenizer.encode(prompt, return_tensors='pt').to(device)
+         attention_mask = torch.ones_like(input_ids)
+
+         # Generate text with the stopping criteria
+         outputs = model.generate(
+             input_ids,
+             attention_mask=attention_mask,
+             max_new_tokens=1024,
+             stopping_criteria=stopping_criteria,
+             pad_token_id=tokenizer.eos_token_id,
+             do_sample=True,
+             temperature=0.7
+         )
+
+         if outputs[0][-1].item() in curr_eos:
+             generated_tokens = outputs[0][input_ids.shape[1]:]
+             output_text = tokenizer.decode(generated_tokens, skip_special_tokens=True)
+             trajectory.append(output_text)
+             print(output_text)
+             break
+
+         generated_tokens = outputs[0][input_ids.shape[1]:]
+         output_text = tokenizer.decode(generated_tokens, skip_special_tokens=True)
+
+         query_text = get_query(tokenizer.decode(outputs[0], skip_special_tokens=True))
+         if query_text:
+             search_results = search(query_text)
+         else:
+             search_results = ''
+
+         search_text = curr_search_template.format(output_text=output_text.strip(), search_results=search_results.strip())
+         prompt += search_text
+         cnt += 1
+         print(search_text)
+         trajectory.append(search_text)
+     print(f"Total iterations: {cnt}")
+     answer_pattern = re.compile(r"<answer>(.*?)</answer>", re.DOTALL)
+     answer_match = answer_pattern.search(trajectory[-1])
+     if answer_match:
+         final_answer = answer_match.group(1).strip()
+         print(f"Final answer found: {final_answer}")
+     else:
+         print("No final answer found in the output.")
+         final_answer = "No final answer found."
+     return ''.join([text for text in trajectory]), final_answer
+
+ if __name__ == "__main__":
+     output_text, final_answer = run_search(question_list[0])
+     print(f"Output trajectory: {output_text}")
+     print(f"Final answer: {final_answer}")
install_cuda.sh ADDED
@@ -0,0 +1,7 @@
+ mkdir -p ~/miniconda3
+ wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O ~/miniconda3/miniconda.sh
+ bash ~/miniconda3/miniconda.sh -b -u -p ~/miniconda3
+ rm ~/miniconda3/miniconda.sh
+
+ source ~/miniconda3/bin/activate
+ conda init --all
install_env.sh ADDED
@@ -0,0 +1,9 @@
+ conda create -n faiss_env python=3.10
+ conda activate faiss_env
+
+ conda install pytorch==2.4.0 torchvision==0.19.0 torchaudio==2.4.0 pytorch-cuda=12.1 -c pytorch -c nvidia
+ pip install transformers datasets pyserini
+
+ conda install -c pytorch -c nvidia faiss-gpu=1.8.0
+
+ pip install uvicorn fastapi
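A quick way to confirm that the faiss_env environment came up with a working GPU stack is a short Python check (a sketch, not part of the commit):

# Run inside `conda activate faiss_env` to sanity-check the installed packages.
import torch
import faiss
import transformers

print("torch", torch.__version__, "| CUDA available:", torch.cuda.is_available())
print("faiss GPUs visible:", faiss.get_num_gpus())
print("transformers", transformers.__version__)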
retreival_launch.sh ADDED
@@ -0,0 +1,11 @@
+
+ file_path=./data
+ index_file=$file_path/e5_Flat.index
+ corpus_file=$file_path/wiki-18.jsonl
+ retriever=intfloat/e5-base-v2
+
+ export CUDA_VISIBLE_DEVICES="1,3"
+ python search_r1/search/retrieval_server.py --index_path $index_file \
+     --corpus_path $corpus_file \
+     --topk 3 \
+     --retriever_model $retriever
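Once the launch script has the server listening (uvicorn defaults to port 8000), it accepts the same payload that infer.py and app.py send. A minimal query sketch, assuming the index and corpus loaded successfully:

# Sketch only: query the local retrieval server the way infer.py's search() does.
import requests

payload = {"queries": ["Who invented the Clavivox?"], "topk": 3, "return_scores": True}
resp = requests.post("http://127.0.0.1:8000/retrieve", json=payload, timeout=10)
resp.raise_for_status()

for hit in resp.json()["result"][0]:
    title = hit["document"]["contents"].split("\n")[0]
    print(f"{hit['score']:.3f}  {title}")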
retrieval_server.py ADDED
@@ -0,0 +1,390 @@
+ import json
+ import os
+ import warnings
+ from typing import List, Dict, Optional
+ import argparse
+
+ import faiss
+ import torch
+ import numpy as np
+ from transformers import AutoConfig, AutoTokenizer, AutoModel
+ from tqdm import tqdm
+ import datasets
+
+ import uvicorn
+ from fastapi import FastAPI
+ from pydantic import BaseModel
+
+
+ parser = argparse.ArgumentParser(description="Launch the local faiss retriever.")
+ parser.add_argument("--index_path", type=str, help="Corpus indexing file.")
+ parser.add_argument("--corpus_path", type=str, help="Local corpus file.")
+ parser.add_argument("--topk", type=int, default=3, help="Number of retrieved passages for one query.")
+ parser.add_argument("--retriever_model", type=str, default="intfloat/e5-base-v2", help="Name of the retriever model.")
+
+ args = parser.parse_args()
+
+ def load_corpus(corpus_path: str):
+     corpus = datasets.load_dataset(
+         'json',
+         data_files=corpus_path,
+         split="train",
+         num_proc=4
+     )
+     return corpus
+
+ def read_jsonl(file_path):
+     data = []
+     with open(file_path, "r") as f:
+         for line in f:
+             data.append(json.loads(line))
+     return data
+
+ def load_docs(corpus, doc_idxs):
+     results = [corpus[int(idx)] for idx in doc_idxs]
+     return results
+
+ def load_model(model_path: str, use_fp16: bool = False):
+     model_config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
+     model = AutoModel.from_pretrained(model_path, trust_remote_code=True)
+     model.eval()
+     model.cuda()
+     if use_fp16:
+         model = model.half()
+     tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True, trust_remote_code=True)
+     return model, tokenizer
+
+ def pooling(
+     pooler_output,
+     last_hidden_state,
+     attention_mask = None,
+     pooling_method = "mean"
+ ):
+     if pooling_method == "mean":
+         last_hidden = last_hidden_state.masked_fill(~attention_mask[..., None].bool(), 0.0)
+         return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]
+     elif pooling_method == "cls":
+         return last_hidden_state[:, 0]
+     elif pooling_method == "pooler":
+         return pooler_output
+     else:
+         raise NotImplementedError("Pooling method not implemented!")
+
+ class Encoder:
+     def __init__(self, model_name, model_path, pooling_method, max_length, use_fp16):
+         self.model_name = model_name
+         self.model_path = model_path
+         self.pooling_method = pooling_method
+         self.max_length = max_length
+         self.use_fp16 = use_fp16
+
+         self.model, self.tokenizer = load_model(model_path=model_path, use_fp16=use_fp16)
+         self.model.eval()
+
+     @torch.no_grad()
+     def encode(self, query_list: List[str], is_query=True) -> np.ndarray:
+         # processing query for different encoders
+         if isinstance(query_list, str):
+             query_list = [query_list]
+
+         if "e5" in self.model_name.lower():
+             if is_query:
+                 query_list = [f"query: {query}" for query in query_list]
+             else:
+                 query_list = [f"passage: {query}" for query in query_list]
+
+         if "bge" in self.model_name.lower():
+             if is_query:
+                 query_list = [f"Represent this sentence for searching relevant passages: {query}" for query in query_list]
+
+         inputs = self.tokenizer(query_list,
+                                 max_length=self.max_length,
+                                 padding=True,
+                                 truncation=True,
+                                 return_tensors="pt"
+                                 )
+         inputs = {k: v.cuda() for k, v in inputs.items()}
+
+         if "T5" in type(self.model).__name__:
+             # T5-based retrieval model
+             decoder_input_ids = torch.zeros(
+                 (inputs['input_ids'].shape[0], 1), dtype=torch.long
+             ).to(inputs['input_ids'].device)
+             output = self.model(
+                 **inputs, decoder_input_ids=decoder_input_ids, return_dict=True
+             )
+             query_emb = output.last_hidden_state[:, 0, :]
+         else:
+             output = self.model(**inputs, return_dict=True)
+             query_emb = pooling(output.pooler_output,
+                                 output.last_hidden_state,
+                                 inputs['attention_mask'],
+                                 self.pooling_method)
+             if "dpr" not in self.model_name.lower():
+                 query_emb = torch.nn.functional.normalize(query_emb, dim=-1)
+
+         query_emb = query_emb.detach().cpu().numpy()
+         query_emb = query_emb.astype(np.float32, order="C")
+
+         del inputs, output
+         torch.cuda.empty_cache()
+
+         return query_emb
+
+ class BaseRetriever:
+     def __init__(self, config):
+         self.config = config
+         self.retrieval_method = config.retrieval_method
+         self.topk = config.retrieval_topk
+
+         self.index_path = config.index_path
+         self.corpus_path = config.corpus_path
+
+     def _search(self, query: str, num: int, return_score: bool):
+         raise NotImplementedError
+
+     def _batch_search(self, query_list: List[str], num: int, return_score: bool):
+         raise NotImplementedError
+
+     def search(self, query: str, num: int = None, return_score: bool = False):
+         return self._search(query, num, return_score)
+
+     def batch_search(self, query_list: List[str], num: int = None, return_score: bool = False):
+         return self._batch_search(query_list, num, return_score)
+
+ class BM25Retriever(BaseRetriever):
+     def __init__(self, config):
+         super().__init__(config)
+         from pyserini.search.lucene import LuceneSearcher
+         self.searcher = LuceneSearcher(self.index_path)
+         self.contain_doc = self._check_contain_doc()
+         if not self.contain_doc:
+             self.corpus = load_corpus(self.corpus_path)
+         self.max_process_num = 8
+
+     def _check_contain_doc(self):
+         return self.searcher.doc(0).raw() is not None
+
+     def _search(self, query: str, num: int = None, return_score: bool = False):
+         if num is None:
+             num = self.topk
+         hits = self.searcher.search(query, num)
+         if len(hits) < 1:
+             if return_score:
+                 return [], []
+             else:
+                 return []
+         scores = [hit.score for hit in hits]
+         if len(hits) < num:
+             warnings.warn('Not enough documents retrieved!')
+         else:
+             hits = hits[:num]
+
+         if self.contain_doc:
+             all_contents = [
+                 json.loads(self.searcher.doc(hit.docid).raw())['contents']
+                 for hit in hits
+             ]
+             results = [
+                 {
+                     'title': content.split("\n")[0].strip("\""),
+                     'text': "\n".join(content.split("\n")[1:]),
+                     'contents': content
+                 }
+                 for content in all_contents
+             ]
+         else:
+             results = load_docs(self.corpus, [hit.docid for hit in hits])
+
+         if return_score:
+             return results, scores
+         else:
+             return results
+
+     def _batch_search(self, query_list: List[str], num: int = None, return_score: bool = False):
+         results = []
+         scores = []
+         for query in query_list:
+             item_result, item_score = self._search(query, num, True)
+             results.append(item_result)
+             scores.append(item_score)
+         if return_score:
+             return results, scores
+         else:
+             return results
+
+ class DenseRetriever(BaseRetriever):
+     def __init__(self, config):
+         super().__init__(config)
+         self.index = faiss.read_index(self.index_path)
+         if config.faiss_gpu:
+             co = faiss.GpuMultipleClonerOptions()
+             co.useFloat16 = True
+             co.shard = True
+             self.index = faiss.index_cpu_to_all_gpus(self.index, co=co)
+
+         self.corpus = load_corpus(self.corpus_path)
+         self.encoder = Encoder(
+             model_name = self.retrieval_method,
+             model_path = config.retrieval_model_path,
+             pooling_method = config.retrieval_pooling_method,
+             max_length = config.retrieval_query_max_length,
+             use_fp16 = config.retrieval_use_fp16
+         )
+         self.topk = config.retrieval_topk
+         self.batch_size = config.retrieval_batch_size
+
+     def _search(self, query: str, num: int = None, return_score: bool = False):
+         if num is None:
+             num = self.topk
+         query_emb = self.encoder.encode(query)
+         scores, idxs = self.index.search(query_emb, k=num)
+         idxs = idxs[0]
+         scores = scores[0]
+         results = load_docs(self.corpus, idxs)
+         if return_score:
+             return results, scores.tolist()
+         else:
+             return results
+
+     def _batch_search(self, query_list: List[str], num: int = None, return_score: bool = False):
+         if isinstance(query_list, str):
+             query_list = [query_list]
+         if num is None:
+             num = self.topk
+
+         results = []
+         scores = []
+         for start_idx in tqdm(range(0, len(query_list), self.batch_size), desc='Retrieval process: '):
+             query_batch = query_list[start_idx:start_idx + self.batch_size]
+             batch_emb = self.encoder.encode(query_batch)
+             batch_scores, batch_idxs = self.index.search(batch_emb, k=num)
+             batch_scores = batch_scores.tolist()
+             batch_idxs = batch_idxs.tolist()
+
+             # load_docs is not vectorized, but is a python list approach
+             flat_idxs = sum(batch_idxs, [])
+             batch_results = load_docs(self.corpus, flat_idxs)
+             # chunk them back
+             batch_results = [batch_results[i*num : (i+1)*num] for i in range(len(batch_idxs))]
+
+             results.extend(batch_results)
+             scores.extend(batch_scores)
+
+             del batch_emb, batch_scores, batch_idxs, query_batch, flat_idxs, batch_results
+             torch.cuda.empty_cache()
+
+         if return_score:
+             return results, scores
+         else:
+             return results
+
+ def get_retriever(config):
+     if config.retrieval_method == "bm25":
+         return BM25Retriever(config)
+     else:
+         return DenseRetriever(config)
+
+
+ #####################################
+ # FastAPI server below
+ #####################################
+
+ class Config:
+     """
+     Minimal config class (simulating your argparse)
+     Replace this with your real arguments or load them dynamically.
+     """
+     def __init__(
+         self,
+         retrieval_method: str = "bm25",
+         retrieval_topk: int = 10,
+         index_path: str = "./index/bm25",
+         corpus_path: str = "./data/corpus.jsonl",
+         dataset_path: str = "./data",
+         data_split: str = "train",
+         faiss_gpu: bool = True,
+         retrieval_model_path: str = "./model",
+         retrieval_pooling_method: str = "mean",
+         retrieval_query_max_length: int = 256,
+         retrieval_use_fp16: bool = False,
+         retrieval_batch_size: int = 128
+     ):
+         self.retrieval_method = retrieval_method
+         self.retrieval_topk = retrieval_topk
+         self.index_path = index_path
+         self.corpus_path = corpus_path
+         self.dataset_path = dataset_path
+         self.data_split = data_split
+         self.faiss_gpu = faiss_gpu
+         self.retrieval_model_path = retrieval_model_path
+         self.retrieval_pooling_method = retrieval_pooling_method
+         self.retrieval_query_max_length = retrieval_query_max_length
+         self.retrieval_use_fp16 = retrieval_use_fp16
+         self.retrieval_batch_size = retrieval_batch_size
+
+
+ class QueryRequest(BaseModel):
+     queries: List[str]
+     topk: Optional[int] = None
+     return_scores: bool = False
+
+
+ app = FastAPI()
+
+ # 1) Build a config (could also parse from arguments).
+ #    In real usage, you'd parse your CLI arguments or environment variables.
+ config = Config(
+     retrieval_method = "e5",  # or "dense"
+     index_path=args.index_path,
+     corpus_path=args.corpus_path,
+     retrieval_topk=args.topk,
+     faiss_gpu=True,
+     retrieval_model_path=args.retriever_model,
+     retrieval_pooling_method="mean",
+     retrieval_query_max_length=256,
+     retrieval_use_fp16=True,
+     retrieval_batch_size=512,
+ )
+
+ # 2) Instantiate a global retriever so it is loaded once and reused.
+ retriever = get_retriever(config)
+
+ @app.post("/retrieve")
+ def retrieve_endpoint(request: QueryRequest):
+     """
+     Endpoint that accepts queries and performs retrieval.
+     Input format:
+     {
+       "queries": ["What is Python?", "Tell me about neural networks."],
+       "topk": 3,
+       "return_scores": true
+     }
+     """
+     if not request.topk:
+         request.topk = config.retrieval_topk  # fallback to default
+
+     # Perform batch retrieval
+     results, scores = retriever.batch_search(
+         query_list=request.queries,
+         num=request.topk,
+         return_score=request.return_scores
+     )
+
+     # Format response
+     resp = []
+     for i, single_result in enumerate(results):
+         if request.return_scores:
+             # If scores are returned, combine them with results
+             combined = []
+             for doc, score in zip(single_result, scores[i]):
+                 combined.append({"document": doc, "score": score})
+             resp.append(combined)
+         else:
+             resp.append(single_result)
+     return {"result": resp}
+
+
+ if __name__ == "__main__":
+     # 3) Launch the server. By default, it listens on http://127.0.0.1:8000
+     uvicorn.run(app, host="0.0.0.0", port=8000)
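Both retrievers, and the clients in app.py and infer.py, assume each corpus record exposes a `contents` field whose first line is the title and whose remaining lines are the passage text. A sketch of one wiki-18.jsonl line under that assumption (field names other than `contents`, and the text itself, are illustrative):

# Sketch only: the shape of a single corpus line expected by load_corpus()/load_docs().
import json

example_line = {
    "id": "0",  # illustrative extra field
    "contents": "Example Article Title\nFirst sentence of the passage. Second sentence of the passage.",
}
print(json.dumps(example_line))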