fix: decouple function call detection from LS step polling

Move the MitmStore function call check outside the get_steps() block so
tool calls are detected as soon as the MITM captures them, regardless of
LS processing state. Also reduce the poll interval from a jittered
800-1200ms to a jittered 250-400ms (roughly 300ms).

The LS can take 20-30s for its internal multi-turn loop. Previously, the
function call check was nested inside the steps block, so it only ran
once get_steps() had succeeded and the LS had produced a steps array.
Now a MITM capture is picked up within one poll interval (roughly 300ms,
at most about 400ms) of being captured.
This commit is contained in:
Nikketryhard
2026-02-15 00:48:14 -06:00
parent 4f08b994c7
commit ec1c0c700d

View File

@@ -288,16 +288,15 @@ async fn chat_completions_stream(
}], }],
})).unwrap_or_default())); })).unwrap_or_default()));
let mut keepalive_counter: u64 = 0;
while start.elapsed().as_secs() < timeout { while start.elapsed().as_secs() < timeout {
if let Ok((status, data)) = state.backend.get_steps(&cascade_id).await { // ── Check for MITM-captured function calls FIRST ──
if status == 200 { // This runs independently of LS steps — the MITM captures tool calls
if let Some(steps) = data["steps"].as_array() { // at the proxy layer, so we don't need to wait for LS processing.
// Check for MITM-captured function calls FIRST (before text)
// This prevents dummy placeholder text from leaking to client
let captured = state.mitm_store.take_any_function_calls().await; let captured = state.mitm_store.take_any_function_calls().await;
if let Some(ref calls) = captured { if let Some(ref calls) = captured {
if !calls.is_empty() { if !calls.is_empty() {
// Emit tool_calls in OpenAI streaming format — NO text
let mut tool_calls = Vec::new(); let mut tool_calls = Vec::new();
for (i, fc) in calls.iter().enumerate() { for (i, fc) in calls.iter().enumerate() {
let call_id = format!( let call_id = format!(
@@ -327,7 +326,6 @@ async fn chat_completions_stream(
}], }],
})).unwrap_or_default())); })).unwrap_or_default()));
// Finish with tool_calls reason
yield Ok(Event::default().data(serde_json::to_string(&serde_json::json!({ yield Ok(Event::default().data(serde_json::to_string(&serde_json::json!({
"id": completion_id, "id": completion_id,
"object": "chat.completion.chunk", "object": "chat.completion.chunk",
@@ -344,7 +342,10 @@ async fn chat_completions_stream(
} }
} }
// Normal text streaming (only when no function calls) // ── Check LS steps for text streaming ──
if let Ok((status, data)) = state.backend.get_steps(&cascade_id).await {
if status == 200 {
if let Some(steps) = data["steps"].as_array() {
let text = extract_response_text(steps); let text = extract_response_text(steps);
if !text.is_empty() && text != last_text { if !text.is_empty() && text != last_text {
@@ -388,8 +389,7 @@ async fn chat_completions_stream(
return; return;
} }
// IDLE fallback: check trajectory status periodically // IDLE fallback
// Only check every 5th step count to reduce backend traffic
let step_count = steps.len(); let step_count = steps.len();
if step_count > 4 && step_count % 5 == 0 { if step_count > 4 && step_count % 5 == 0 {
if let Ok((ts, td)) = state.backend.get_trajectory(&cascade_id).await { if let Ok((ts, td)) = state.backend.get_trajectory(&cascade_id).await {
@@ -418,7 +418,14 @@ async fn chat_completions_stream(
} }
} }
let poll_ms: u64 = rand::thread_rng().gen_range(800..1200); // Keep-alive comment every ~5 iterations
keepalive_counter += 1;
if keepalive_counter % 5 == 0 {
yield Ok(Event::default().comment("keepalive"));
}
// Fast poll — 300ms so we pick up MITM captures quickly
let poll_ms: u64 = rand::thread_rng().gen_range(250..400);
tokio::time::sleep(tokio::time::Duration::from_millis(poll_ms)).await; tokio::time::sleep(tokio::time::Duration::from_millis(poll_ms)).await;
} }