fix: decouple function call detection from LS step polling
Move MitmStore function call check outside get_steps() block so tool calls are detected immediately when captured by MITM, regardless of LS processing state. Also reduce poll interval to 300ms. The LS can take 20-30s for its internal multi-turn loop. Previously, function call checks were nested inside the steps block and required LS to have produced steps. Now the MITM capture is picked up within 300ms of detection.
This commit is contained in:
@@ -288,63 +288,64 @@ async fn chat_completions_stream(
|
||||
}],
|
||||
})).unwrap_or_default()));
|
||||
|
||||
let mut keepalive_counter: u64 = 0;
|
||||
|
||||
while start.elapsed().as_secs() < timeout {
|
||||
// ── Check for MITM-captured function calls FIRST ──
|
||||
// This runs independently of LS steps — the MITM captures tool calls
|
||||
// at the proxy layer, so we don't need to wait for LS processing.
|
||||
let captured = state.mitm_store.take_any_function_calls().await;
|
||||
if let Some(ref calls) = captured {
|
||||
if !calls.is_empty() {
|
||||
let mut tool_calls = Vec::new();
|
||||
for (i, fc) in calls.iter().enumerate() {
|
||||
let call_id = format!(
|
||||
"call_{}",
|
||||
uuid::Uuid::new_v4().to_string().replace('-', "")[..24].to_string()
|
||||
);
|
||||
let arguments = serde_json::to_string(&fc.args).unwrap_or_default();
|
||||
tool_calls.push(serde_json::json!({
|
||||
"index": i,
|
||||
"id": call_id,
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": fc.name,
|
||||
"arguments": arguments,
|
||||
},
|
||||
}));
|
||||
}
|
||||
yield Ok(Event::default().data(serde_json::to_string(&serde_json::json!({
|
||||
"id": completion_id,
|
||||
"object": "chat.completion.chunk",
|
||||
"created": now_unix(),
|
||||
"model": model_name,
|
||||
"choices": [{
|
||||
"index": 0,
|
||||
"delta": {"tool_calls": tool_calls},
|
||||
"finish_reason": serde_json::Value::Null,
|
||||
}],
|
||||
})).unwrap_or_default()));
|
||||
|
||||
yield Ok(Event::default().data(serde_json::to_string(&serde_json::json!({
|
||||
"id": completion_id,
|
||||
"object": "chat.completion.chunk",
|
||||
"created": now_unix(),
|
||||
"model": model_name,
|
||||
"choices": [{
|
||||
"index": 0,
|
||||
"delta": {},
|
||||
"finish_reason": "tool_calls",
|
||||
}],
|
||||
})).unwrap_or_default()));
|
||||
yield Ok(Event::default().data("[DONE]"));
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
// ── Check LS steps for text streaming ──
|
||||
if let Ok((status, data)) = state.backend.get_steps(&cascade_id).await {
|
||||
if status == 200 {
|
||||
if let Some(steps) = data["steps"].as_array() {
|
||||
// Check for MITM-captured function calls FIRST (before text)
|
||||
// This prevents dummy placeholder text from leaking to client
|
||||
let captured = state.mitm_store.take_any_function_calls().await;
|
||||
if let Some(ref calls) = captured {
|
||||
if !calls.is_empty() {
|
||||
// Emit tool_calls in OpenAI streaming format — NO text
|
||||
let mut tool_calls = Vec::new();
|
||||
for (i, fc) in calls.iter().enumerate() {
|
||||
let call_id = format!(
|
||||
"call_{}",
|
||||
uuid::Uuid::new_v4().to_string().replace('-', "")[..24].to_string()
|
||||
);
|
||||
let arguments = serde_json::to_string(&fc.args).unwrap_or_default();
|
||||
tool_calls.push(serde_json::json!({
|
||||
"index": i,
|
||||
"id": call_id,
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": fc.name,
|
||||
"arguments": arguments,
|
||||
},
|
||||
}));
|
||||
}
|
||||
yield Ok(Event::default().data(serde_json::to_string(&serde_json::json!({
|
||||
"id": completion_id,
|
||||
"object": "chat.completion.chunk",
|
||||
"created": now_unix(),
|
||||
"model": model_name,
|
||||
"choices": [{
|
||||
"index": 0,
|
||||
"delta": {"tool_calls": tool_calls},
|
||||
"finish_reason": serde_json::Value::Null,
|
||||
}],
|
||||
})).unwrap_or_default()));
|
||||
|
||||
// Finish with tool_calls reason
|
||||
yield Ok(Event::default().data(serde_json::to_string(&serde_json::json!({
|
||||
"id": completion_id,
|
||||
"object": "chat.completion.chunk",
|
||||
"created": now_unix(),
|
||||
"model": model_name,
|
||||
"choices": [{
|
||||
"index": 0,
|
||||
"delta": {},
|
||||
"finish_reason": "tool_calls",
|
||||
}],
|
||||
})).unwrap_or_default()));
|
||||
yield Ok(Event::default().data("[DONE]"));
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
// Normal text streaming (only when no function calls)
|
||||
let text = extract_response_text(steps);
|
||||
|
||||
if !text.is_empty() && text != last_text {
|
||||
@@ -388,8 +389,7 @@ async fn chat_completions_stream(
|
||||
return;
|
||||
}
|
||||
|
||||
// IDLE fallback: check trajectory status periodically
|
||||
// Only check every 5th step count to reduce backend traffic
|
||||
// IDLE fallback
|
||||
let step_count = steps.len();
|
||||
if step_count > 4 && step_count % 5 == 0 {
|
||||
if let Ok((ts, td)) = state.backend.get_trajectory(&cascade_id).await {
|
||||
@@ -418,7 +418,14 @@ async fn chat_completions_stream(
|
||||
}
|
||||
}
|
||||
|
||||
let poll_ms: u64 = rand::thread_rng().gen_range(800..1200);
|
||||
// Keep-alive comment every ~5 iterations
|
||||
keepalive_counter += 1;
|
||||
if keepalive_counter % 5 == 0 {
|
||||
yield Ok(Event::default().comment("keepalive"));
|
||||
}
|
||||
|
||||
// Fast poll — 300ms so we pick up MITM captures quickly
|
||||
let poll_ms: u64 = rand::thread_rng().gen_range(250..400);
|
||||
tokio::time::sleep(tokio::time::Duration::from_millis(poll_ms)).await;
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user