fix: decouple function call detection from LS step polling

Move the MitmStore function call check outside the get_steps() block so tool
calls are detected as soon as they are captured by the MITM, regardless of
LS processing state. Also reduce the poll interval from 800-1200ms to
roughly 300ms (randomized between 250 and 400ms).

The LS can take 20-30s for its internal multi-turn loop. Previously, the
function call check was nested inside the steps block, so it only ran once
the LS had produced steps. Now a MITM capture is picked up within one poll
interval (about 300ms) of being captured.
This commit is contained in:
Nikketryhard
2026-02-15 00:48:14 -06:00
parent 4f08b994c7
commit ec1c0c700d

View File

@@ -288,63 +288,64 @@ async fn chat_completions_stream(
}],
})).unwrap_or_default()));
let mut keepalive_counter: u64 = 0;
while start.elapsed().as_secs() < timeout {
// ── Check for MITM-captured function calls FIRST ──
// This runs independently of LS steps — the MITM captures tool calls
// at the proxy layer, so we don't need to wait for LS processing.
let captured = state.mitm_store.take_any_function_calls().await;
if let Some(ref calls) = captured {
if !calls.is_empty() {
let mut tool_calls = Vec::new();
for (i, fc) in calls.iter().enumerate() {
let call_id = format!(
"call_{}",
uuid::Uuid::new_v4().to_string().replace('-', "")[..24].to_string()
);
let arguments = serde_json::to_string(&fc.args).unwrap_or_default();
tool_calls.push(serde_json::json!({
"index": i,
"id": call_id,
"type": "function",
"function": {
"name": fc.name,
"arguments": arguments,
},
}));
}
yield Ok(Event::default().data(serde_json::to_string(&serde_json::json!({
"id": completion_id,
"object": "chat.completion.chunk",
"created": now_unix(),
"model": model_name,
"choices": [{
"index": 0,
"delta": {"tool_calls": tool_calls},
"finish_reason": serde_json::Value::Null,
}],
})).unwrap_or_default()));
yield Ok(Event::default().data(serde_json::to_string(&serde_json::json!({
"id": completion_id,
"object": "chat.completion.chunk",
"created": now_unix(),
"model": model_name,
"choices": [{
"index": 0,
"delta": {},
"finish_reason": "tool_calls",
}],
})).unwrap_or_default()));
yield Ok(Event::default().data("[DONE]"));
return;
}
}
// ── Check LS steps for text streaming ──
if let Ok((status, data)) = state.backend.get_steps(&cascade_id).await {
if status == 200 {
if let Some(steps) = data["steps"].as_array() {
// Check for MITM-captured function calls FIRST (before text)
// This prevents dummy placeholder text from leaking to client
let captured = state.mitm_store.take_any_function_calls().await;
if let Some(ref calls) = captured {
if !calls.is_empty() {
// Emit tool_calls in OpenAI streaming format — NO text
let mut tool_calls = Vec::new();
for (i, fc) in calls.iter().enumerate() {
let call_id = format!(
"call_{}",
uuid::Uuid::new_v4().to_string().replace('-', "")[..24].to_string()
);
let arguments = serde_json::to_string(&fc.args).unwrap_or_default();
tool_calls.push(serde_json::json!({
"index": i,
"id": call_id,
"type": "function",
"function": {
"name": fc.name,
"arguments": arguments,
},
}));
}
yield Ok(Event::default().data(serde_json::to_string(&serde_json::json!({
"id": completion_id,
"object": "chat.completion.chunk",
"created": now_unix(),
"model": model_name,
"choices": [{
"index": 0,
"delta": {"tool_calls": tool_calls},
"finish_reason": serde_json::Value::Null,
}],
})).unwrap_or_default()));
// Finish with tool_calls reason
yield Ok(Event::default().data(serde_json::to_string(&serde_json::json!({
"id": completion_id,
"object": "chat.completion.chunk",
"created": now_unix(),
"model": model_name,
"choices": [{
"index": 0,
"delta": {},
"finish_reason": "tool_calls",
}],
})).unwrap_or_default()));
yield Ok(Event::default().data("[DONE]"));
return;
}
}
// Normal text streaming (only when no function calls)
let text = extract_response_text(steps);
if !text.is_empty() && text != last_text {
@@ -388,8 +389,7 @@ async fn chat_completions_stream(
return;
}
// IDLE fallback: check trajectory status periodically
// Only check every 5th step count to reduce backend traffic
// IDLE fallback
let step_count = steps.len();
if step_count > 4 && step_count % 5 == 0 {
if let Ok((ts, td)) = state.backend.get_trajectory(&cascade_id).await {
@@ -418,7 +418,14 @@ async fn chat_completions_stream(
}
}
let poll_ms: u64 = rand::thread_rng().gen_range(800..1200);
// Keep-alive comment every ~5 iterations
keepalive_counter += 1;
if keepalive_counter % 5 == 0 {
yield Ok(Event::default().comment("keepalive"));
}
// Fast poll — 300ms so we pick up MITM captures quickly
let poll_ms: u64 = rand::thread_rng().gen_range(250..400);
tokio::time::sleep(tokio::time::Duration::from_millis(poll_ms)).await;
}