ryomo committed
Commit 056e98d · 1 Parent(s): e20b84b

fix: update generate_stream to support async streaming for Modal and ZeroGPU

Files changed (1):
  1. src/unpredictable_lord/chat/chat.py (+66 -99)
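Summary: on both backends, `generate_stream` is now an async generator, so callers consume tokens through a single `async for` path whether Modal or ZeroGPU serves the model. A minimal consumption sketch (assuming `generate_stream` is imported from `chat.py`; `collect_tokens` is a hypothetical helper, not part of this commit):

```python
# Hypothetical helper showing the uniform interface this commit creates.
# input_tokens comes from harmony's render_conversation_for_completion().
async def collect_tokens(input_tokens) -> list:
    tokens = []
    async for token in generate_stream(input_tokens):
        if token is None:  # chat.py's parser loop skips None tokens; mirrored here
            continue
        tokens.append(token)
    return tokens
```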
src/unpredictable_lord/chat/chat.py CHANGED
@@ -30,15 +30,19 @@ if USE_MODAL:
     APP_NAME = "unpredictable-lord"
     _generate_stream = modal.Function.from_name(APP_NAME, "generate_stream")

-    def generate_stream(input_tokens):
-        logger.info("Calling Modal LLM generate_stream")
-        return _generate_stream.remote_gen(input_tokens)
+    async def generate_stream(input_tokens):
+        logger.info("Calling Modal LLM generate_stream (async)")
+        async for token in _generate_stream.remote_gen.aio(input_tokens):
+            yield token
 else:
     from unpredictable_lord.chat.llm_zerogpu import generate_stream as _generate_stream

-    def generate_stream(input_tokens):
-        logger.info("Calling ZeroGPU LLM generate_stream")
-        return _generate_stream(input_tokens)
+    async def generate_stream(input_tokens):
+        logger.info("Calling ZeroGPU LLM generate_stream (sync wrapper)")
+        # Note: This blocks the event loop, but is acceptable for ZeroGPU/Spaces
+        # where concurrency is limited anyway.
+        for token in _generate_stream(input_tokens):
+            yield token


 def _get_encoding():
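The ZeroGPU branch iterates a sync generator inside an async function, which blocks the event loop between tokens; the inline note accepts that trade-off for Spaces. If it ever matters, one possible variant (not part of this commit) drives the generator from a worker thread:

```python
import asyncio

# Hypothetical non-blocking alternative to the wrapper above: each next()
# call runs via asyncio.to_thread, so the event loop stays responsive
# while the ZeroGPU generator produces tokens.
async def generate_stream_nonblocking(input_tokens):
    gen = _generate_stream(input_tokens)
    sentinel = object()
    while True:
        token = await asyncio.to_thread(next, gen, sentinel)
        if token is sentinel:  # generator exhausted
            break
        yield token
```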
@@ -46,16 +50,8 @@ def _get_encoding():
     return oh.load_harmony_encoding(oh.HarmonyEncodingName.HARMONY_GPT_OSS)


-def _build_developer_message(session_id: str, personality: str) -> oh.Message:
-    """Build developer message with system prompt and tool definitions.
-
-    Args:
-        session_id: The game session ID.
-        personality: The lord's personality type.
-
-    Returns:
-        Developer message with instructions and tool definitions.
-    """
+def _build_system_message(session_id: str, personality: str) -> oh.Message:
+    """Build developer message with system prompt and tool definitions."""
     personality_desc = PERSONALITY_DESCRIPTIONS.get(personality, "")

     system_prompt = f"""You are a {personality} lord of a medieval fantasy kingdom.
@@ -96,14 +92,7 @@ What counsel do you offer, advisor? Shall we address their grievances or press o


 def _convert_history_to_messages(chat_history: list[dict]) -> list[oh.Message]:
-    """Convert Gradio chat history to openai-harmony messages.
-
-    Args:
-        chat_history: Chat history in Gradio format.
-
-    Returns:
-        List of openai-harmony messages.
-    """
+    """Convert Gradio chat history to openai-harmony messages."""
     messages = []
     for msg in chat_history:
         if msg["role"] == "user":
@@ -117,67 +106,51 @@ def _convert_history_to_messages(chat_history: list[dict]) -> list[oh.Message]:
     return messages


-def _stream_llm_response(messages: list[oh.Message], encoding):
-    """Stream LLM response and return full text with parsed messages.
-
-    Args:
-        messages: List of messages to send to LLM.
-        encoding: Harmony encoding instance.
-
-    Yields:
-        Tuple of (response_text, parsed_messages or None).
-    """
+class StreamResult:
+    """Holder for streaming result."""
+
+    def __init__(self):
+        self.response_text = ""
+        self.parsed_messages = []
+
+
+async def _stream_response(
+    messages: list[oh.Message],
+    encoding: oh.HarmonyEncoding,
+    partial_history: list[dict],
+    accumulated_response: str,
+    result: StreamResult,
+) -> AsyncGenerator[list[dict], None]:
+    """Stream LLM response and yield history updates."""
     convo = oh.Conversation.from_messages(messages)
     input_tokens = encoding.render_conversation_for_completion(convo, oh.Role.ASSISTANT)

     parser = oh.StreamableParser(encoding, role=oh.Role.ASSISTANT)
+    current_iteration_response = ""

-    response_text = ""
-    token_count = 0
-    parser_error = False
-    all_content = ""  # Capture all content regardless of channel
-
-    for token in generate_stream(input_tokens):
+    # Stream LLM response
+    async for token in generate_stream(input_tokens):
         if token is None:
             continue
-        token_count += 1
+
         try:
             parser.process(token)
-        except oh.HarmonyError as e:
-            # Parser error - LLM generated invalid format (e.g., after tool result)
-            # This can happen when LLM copies tool call patterns in its response
-            logger.warning(f"Parser error at token {token_count}: {e}")
-            logger.warning("Treating this as end of valid response")
-            parser_error = True
-            break
         except Exception as e:
-            logger.error(
-                f"Unexpected parser error at token {token_count} (token={token}): {e}"
-            )
-            raise
-
-        # Get content from any channel for fallback
-        delta = parser.last_content_delta
-        if delta:
-            all_content += delta
-
-        # Get content only from final channel for display
-        if parser.current_channel == "final" and delta:
-            response_text += delta
-            yield response_text, None
-
-    # Finish parsing and return parsed messages
-    if not parser_error:
-        parser.process_eos()
-        parsed_messages = parser.messages
-    else:
-        # On parser error, return empty list to stop tool calling loop
-        # Use all_content as fallback if response_text is empty
-        if not response_text and all_content:
-            response_text = all_content
-            logger.info(f"Using fallback content (length: {len(all_content)})")
-        parsed_messages = []
-    yield response_text, parsed_messages
+            logger.error(f"Parser error: {e}")
+            break
+
+        if parser.current_channel == "final":
+            delta = parser.last_content_delta
+            if delta:
+                current_iteration_response += delta
+                partial_history[-1]["content"] = (
+                    accumulated_response + current_iteration_response
+                )
+                yield partial_history
+
+    # Store results
+    result.response_text = current_iteration_response
+    result.parsed_messages = parser.messages


 async def chat_with_mcp_tools(
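`StreamResult` exists because a Python async generator cannot `return` a value (PEP 525): `_stream_response` yields history updates while streaming, then writes its final text and parsed messages into the mutable holder for the caller to read. A self-contained illustration of the pattern:

```python
import asyncio

# Minimal demo of the out-parameter pattern behind StreamResult:
# the producer streams items and records its final state on the holder.
class Holder:
    def __init__(self):
        self.total = 0

async def produce(result: Holder):
    for i in range(3):
        yield i            # streamed to the consumer immediately
        result.total += i  # final state accumulated as a side effect

async def main():
    result = Holder()
    async for item in produce(result):
        print("item:", item)
    print("final:", result.total)  # available once iteration ends

asyncio.run(main())
```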
@@ -189,8 +162,6 @@ async def chat_with_mcp_tools(
     """
     Chat with LLM with MCP tool support (async streaming version).

-    This version includes tool calling capabilities for game interactions.
-
     Args:
         user_message: User's message
         chat_history: Past chat history (list of dictionaries in Gradio format)
@@ -203,8 +174,8 @@ async def chat_with_mcp_tools(
     try:
         encoding = _get_encoding()

-        # Build messages with tool definitions
-        messages = [_build_developer_message(session_id, personality)]
+        # Build messages
+        messages = [_build_system_message(session_id, personality)]
         messages.extend(_convert_history_to_messages(chat_history))
         messages.append(oh.Message.from_role_and_content(oh.Role.USER, user_message))
 
@@ -220,29 +191,17 @@ async def chat_with_mcp_tools(
         for iteration in range(MAX_TOOL_CALL_ITERATIONS):
             logger.info(f"Tool calling iteration {iteration + 1}")

-            parsed_messages = None
-            current_iteration_response = ""
-
-            # Stream LLM response
-            for response_text, parsed in _stream_llm_response(messages, encoding):
-                current_iteration_response = response_text
-                if parsed is not None:
-                    parsed_messages = parsed
-
-                partial_history[-1]["content"] = (
-                    accumulated_response + current_iteration_response
-                )
-                yield partial_history
-
-            if parsed_messages is None:
-                logger.warning("No parsed messages returned from LLM")
-                break
-
-            # Update accumulated response with the final text from this iteration
-            accumulated_response += current_iteration_response
+            result = StreamResult()
+            async for history in _stream_response(
+                messages, encoding, partial_history, accumulated_response, result
+            ):
+                yield history
+
+            # Update accumulated response
+            accumulated_response += result.response_text

             # Check for tool calls
-            tool_calls = extract_tool_calls(parsed_messages)
+            tool_calls = extract_tool_calls(result.parsed_messages)

             if not tool_calls:
                 logger.info("No tool calls found, ending loop")
@@ -251,14 +210,22 @@ async def chat_with_mcp_tools(
             logger.info(f"Found {len(tool_calls)} tool call(s)")

             # Add parsed messages to conversation
-            messages.extend(parsed_messages)
+            messages.extend(result.parsed_messages)
+
+            # Indicate tool execution in UI
+            partial_history[-1]["content"] += "\n\n*(Executing orders...)*"
+            yield partial_history

             # Execute tools via MCP
             tool_result_messages = await execute_tool_calls(tool_calls)

             messages.extend(tool_result_messages)

-        # Ensure final response is yielded (even if empty after tool calls)
+            # Remove status message
+            partial_history[-1]["content"] = accumulated_response
+            yield partial_history
+
+        # Ensure final response is yielded
         yield partial_history

     except Exception:
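The last hunk also introduces a small UI pattern: append a transient status marker to the newest chat entry while tools execute, then reset the entry to `accumulated_response` before the next iteration streams. A stripped-down sketch (`run_tools` is a hypothetical stand-in for `execute_tool_calls`):

```python
# Sketch of the transient-status pattern; history holds Gradio-style
# {"role": ..., "content": ...} dicts, run_tools is hypothetical.
async def stream_with_status(history: list[dict], base_text: str, run_tools):
    history[-1]["content"] = base_text + "\n\n*(Executing orders...)*"
    yield history              # UI shows the marker while tools run
    await run_tools()
    history[-1]["content"] = base_text
    yield history              # marker removed, clean text restored
```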
 
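For reference, one way a UI layer might consume the rewritten entry point; the handler shape and argument values are illustrative, only the positional order of `user_message` and `chat_history` is taken from the docstring:

```python
# Hypothetical Gradio-style handler: chat_with_mcp_tools is an async
# generator of history lists, so the UI re-yields each update as it arrives.
async def respond(user_message: str, chat_history: list[dict]):
    async for history in chat_with_mcp_tools(
        user_message,
        chat_history,
        session_id="demo-session",  # illustrative value
        personality="stern",        # illustrative; ideally a PERSONALITY_DESCRIPTIONS key
    ):
        yield history
```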