From 637fbc0e54d5703def18520208fe707ce3ed06ff Mon Sep 17 00:00:00 2001 From: Nikketryhard Date: Mon, 16 Feb 2026 21:47:00 -0600 Subject: [PATCH] refactor: endpoint parity and proxy improvements Mixed changes from recent sessions: endpoint feature parity improvements, proxy bug fixes, and store cleanup. --- src/api/completions.rs | 457 +++++++++++++++--------------- src/api/responses.rs | 610 ++++++++++++++++++++++------------------- src/api/search.rs | 4 +- src/mitm/proxy.rs | 176 +++++++----- src/mitm/store.rs | 208 +++++++------- 5 files changed, 763 insertions(+), 692 deletions(-) diff --git a/src/api/completions.rs b/src/api/completions.rs index 426a789..ff0bc69 100644 --- a/src/api/completions.rs +++ b/src/api/completions.rs @@ -223,7 +223,6 @@ pub(crate) async fn handle_completions( } else { state.mitm_store.clear_tools().await; } - state.mitm_store.clear_active_function_call(); // ── Extract tool results from messages for MITM injection ────────── // When OpenCode sends back tool results, the messages array contains: @@ -340,7 +339,6 @@ pub(crate) async fn handle_completions( state.mitm_store.set_last_function_calls(last_round.calls.clone()).await; } state.mitm_store.set_tool_rounds(rounds).await; - state.mitm_store.clear_awaiting_tool_result(); } } @@ -441,6 +439,8 @@ pub(crate) async fn handle_completions( // Send message on primary cascade state.mitm_store.set_active_cascade(&cascade_id).await; + // Store real user text for MITM injection — LS gets a dummy prompt + state.mitm_store.set_pending_user_text(user_text.clone()).await; // Store image for MITM injection (LS doesn't forward images to Google API) if let Some(ref img) = image { use base64::Engine; @@ -452,9 +452,25 @@ pub(crate) async fn handle_completions( }) .await; } + + // Pre-flight: install channel BEFORE send_message so the MITM proxy + // can grab it when the LS fires its API call. + // Only for streaming — sync paths use poll_for_response (legacy store). + let has_custom_tools = state.mitm_store.get_tools().await.is_some(); + let mitm_rx = if has_custom_tools && body.stream { + state.mitm_store.clear_response_async().await; + state.mitm_store.clear_upstream_error().await; + let _ = state.mitm_store.take_any_function_calls().await; + let (tx, rx) = tokio::sync::mpsc::channel(64); + state.mitm_store.set_channel(tx).await; + Some(rx) + } else { + None + }; + match state .backend - .send_message_with_image(&cascade_id, &user_text, model.model_enum, image.as_ref()) + .send_message_with_image(&cascade_id, ".", model.model_enum, image.as_ref()) .await { Ok((200, _)) => { @@ -465,6 +481,7 @@ pub(crate) async fn handle_completions( }); } Ok((status, _)) => { + state.mitm_store.drop_channel().await; return err_response( StatusCode::BAD_GATEWAY, format!("Backend returned {status}"), @@ -472,6 +489,7 @@ pub(crate) async fn handle_completions( ); } Err(e) => { + state.mitm_store.drop_channel().await; return err_response( StatusCode::BAD_GATEWAY, format!("Send failed: {e}"), @@ -498,6 +516,7 @@ pub(crate) async fn handle_completions( cascade_id, body.timeout, include_usage, + mitm_rx, ) .await } else if n <= 1 { @@ -518,7 +537,7 @@ pub(crate) async fn handle_completions( // Send the same message on each extra cascade match state .backend - .send_message_with_image(&cid, &user_text, model.model_enum, image.as_ref()) + .send_message_with_image(&cid, ".", model.model_enum, image.as_ref()) .await { Ok((200, _)) => { @@ -635,18 +654,18 @@ async fn chat_completions_stream( cascade_id: String, timeout: u64, include_usage: bool, + mitm_rx: Option>, ) -> axum::response::Response { let stream = async_stream::stream! { let start = std::time::Instant::now(); let mut last_text = String::new(); - let has_custom_tools = state.mitm_store.get_tools().await.is_some(); + let has_custom_tools = mitm_rx.is_some(); - // Clear ALL stale state from previous requests - state.mitm_store.clear_response_async().await; - state.mitm_store.clear_upstream_error().await; - state.mitm_store.clear_active_function_call(); - // Drain any stale function calls from previous requests - let _ = state.mitm_store.take_any_function_calls().await; + if !has_custom_tools { + state.mitm_store.clear_response_async().await; + state.mitm_store.clear_upstream_error().await; + let _ = state.mitm_store.take_any_function_calls().await; + } // Initial role chunk yield Ok::<_, std::convert::Infallible>(Event::default().data(chunk_json( @@ -659,7 +678,6 @@ async fn chat_completions_stream( let mut last_thinking_len: usize = 0; let mut complete_polls: u32 = 0; let mut did_unblock_ls = false; // Prevents infinite unblock loops - let mut my_generation = state.mitm_store.current_generation(); // Helper: build usage JSON from MITM tokens let build_usage = |pt: u64, ct: u64, crt: u64, tt: u64| -> serde_json::Value { @@ -672,212 +690,219 @@ async fn chat_completions_stream( }) }; - while start.elapsed().as_secs() < timeout { - // Check for upstream errors from MITM (Google API errors) - if let Some(err) = state.mitm_store.take_upstream_error().await { - let error_msg = super::util::upstream_error_message(&err); - let error_type = super::util::upstream_error_type(&err); - yield Ok(Event::default().data(serde_json::to_string(&serde_json::json!({ - "error": { - "message": error_msg, - "type": error_type, - "code": err.status, - } - })).unwrap())); - yield Ok(Event::default().data("[DONE]".to_string())); - break; - } + // Take ownership of the pre-installed channel receiver + let mut rx_opt = mitm_rx; - // Bail if another completions handler has superseded us - if state.mitm_store.current_generation() != my_generation { - debug!("Completions: generation changed (superseded), ending stream"); + while start.elapsed().as_secs() < timeout { + if let Some(ref mut rx) = rx_opt { + // ── Channel-based MITM pipeline ── + + // Track accumulated text for delta computation + let mut acc_text = String::new(); + let mut acc_thinking = String::new(); + let mut last_usage: Option = None; + + 'channel_loop: while let Some(event) = tokio::time::timeout( + std::time::Duration::from_secs(timeout.saturating_sub(start.elapsed().as_secs())), + rx.recv(), + ).await.ok().flatten() { + use crate::mitm::store::MitmEvent; + match event { + MitmEvent::ThinkingDelta(full_thinking) => { + if full_thinking.len() > acc_thinking.len() { + let delta = full_thinking[acc_thinking.len()..].to_string(); + acc_thinking = full_thinking; + last_thinking_len = acc_thinking.len(); + yield Ok(Event::default().data(chunk_json( + &completion_id, &model_name, + serde_json::json!([chunk_choice(0, serde_json::json!({"reasoning_content": delta}), None)]), + None, + ))); + } + } + MitmEvent::TextDelta(full_text) => { + if full_text.len() > acc_text.len() { + let delta = full_text[acc_text.len()..].to_string(); + acc_text = full_text; + last_text = acc_text.clone(); + yield Ok(Event::default().data(chunk_json( + &completion_id, &model_name, + serde_json::json!([chunk_choice(0, serde_json::json!({"content": delta}), None)]), + None, + ))); + } + } + MitmEvent::FunctionCall(calls) => { + let mut tool_calls = Vec::new(); + for (i, fc) in calls.iter().enumerate() { + let call_id = format!( + "call_{}", + uuid::Uuid::new_v4().to_string().replace('-', "")[..24].to_string() + ); + let arguments = serde_json::to_string(&fc.args).unwrap_or_default(); + tool_calls.push(serde_json::json!({ + "index": i, + "id": call_id, + "type": "function", + "function": { + "name": fc.name, + "arguments": arguments, + }, + })); + } + yield Ok(Event::default().data(chunk_json( + &completion_id, &model_name, + serde_json::json!([chunk_choice(0, serde_json::json!({"tool_calls": tool_calls}), None)]), + None, + ))); + yield Ok(Event::default().data(chunk_json( + &completion_id, &model_name, + serde_json::json!([chunk_choice(0, serde_json::json!({}), Some("tool_calls"))]), + None, + ))); + if include_usage { + let mitm = state.mitm_store.take_usage(&cascade_id).await + .or(state.mitm_store.take_usage("_latest").await); + let (pt, ct, crt, tt) = if let Some(ref u) = mitm { + (u.input_tokens, u.output_tokens, u.cache_read_input_tokens, u.thinking_output_tokens) + } else if let Some(ref u) = last_usage { + (u.input_tokens, u.output_tokens, u.cache_read_input_tokens, u.thinking_output_tokens) + } else { (0, 0, 0, 0) }; + yield Ok(Event::default().data(chunk_json( + &completion_id, &model_name, + serde_json::json!([]), + Some(build_usage(pt, ct, crt, tt)), + ))); + } + yield Ok(Event::default().data("[DONE]")); + state.mitm_store.drop_channel().await; + return; + } + MitmEvent::ResponseComplete => { + if !acc_text.is_empty() { + // Have response text — done + debug!("Completions: channel response complete, text_len={}, thinking_len={}", + acc_text.len(), acc_thinking.len()); + let mitm = state.mitm_store.take_usage(&cascade_id).await + .or(state.mitm_store.take_usage("_latest").await) + .or(last_usage.take()); + let fr = google_to_openai_finish_reason(mitm.as_ref().and_then(|u| u.stop_reason.as_deref())); + yield Ok(Event::default().data(chunk_json( + &completion_id, &model_name, + serde_json::json!([chunk_choice(0, serde_json::json!({}), Some(fr))]), + None, + ))); + if include_usage { + let (pt, ct, crt, tt) = if let Some(ref u) = mitm { + (u.input_tokens, u.output_tokens, u.cache_read_input_tokens, u.thinking_output_tokens) + } else { (0, 0, 0, 0) }; + yield Ok(Event::default().data(chunk_json( + &completion_id, &model_name, + serde_json::json!([]), + Some(build_usage(pt, ct, crt, tt)), + ))); + } + yield Ok(Event::default().data("[DONE]")); + state.mitm_store.drop_channel().await; + return; + } else if !acc_thinking.is_empty() && !did_unblock_ls { + // Thinking-only response — LS needs follow-up API calls. + // Create a new channel and unblock the gate. + did_unblock_ls = true; + let (new_tx, new_rx) = tokio::sync::mpsc::channel(64); + state.mitm_store.set_channel(new_tx).await; + state.mitm_store.clear_request_in_flight(); + let _ = state.mitm_store.take_any_function_calls().await; + *rx = new_rx; + debug!( + "Completions: thinking-only — new channel for follow-up, thinking_len={}", + acc_thinking.len() + ); + continue 'channel_loop; + } else if !acc_thinking.is_empty() && did_unblock_ls { + // Already unblocked once, still thinking-only. + // Wait a bit for potential follow-up events. + complete_polls += 1; + if complete_polls >= 25 { + info!("Completions: thinking-only timeout, thinking_len={}", acc_thinking.len()); + let mitm = state.mitm_store.take_usage(&cascade_id).await + .or(state.mitm_store.take_usage("_latest").await) + .or(last_usage.take()); + let fr = google_to_openai_finish_reason(mitm.as_ref().and_then(|u| u.stop_reason.as_deref())); + yield Ok(Event::default().data(chunk_json( + &completion_id, &model_name, + serde_json::json!([chunk_choice(0, serde_json::json!({}), Some(fr))]), + None, + ))); + if include_usage { + let (pt, ct, crt, tt) = if let Some(ref u) = mitm { + (u.input_tokens, u.output_tokens, u.cache_read_input_tokens, u.thinking_output_tokens) + } else { (0, 0, 0, 0) }; + yield Ok(Event::default().data(chunk_json( + &completion_id, &model_name, + serde_json::json!([]), + Some(build_usage(pt, ct, crt, tt)), + ))); + } + yield Ok(Event::default().data("[DONE]")); + state.mitm_store.drop_channel().await; + return; + } + // Don't break — wait for more channel events + continue 'channel_loop; + } else { + // Empty response (no text, no thinking, no tools) + complete_polls += 1; + if complete_polls >= 4 { + info!("Completions: channel response complete but empty, ending stream"); + yield Ok(Event::default().data(chunk_json( + &completion_id, &model_name, + serde_json::json!([chunk_choice(0, serde_json::json!({}), Some("stop"))]), + None, + ))); + yield Ok(Event::default().data("[DONE]")); + state.mitm_store.drop_channel().await; + return; + } + continue 'channel_loop; + } + } + MitmEvent::UpstreamError(err) => { + let error_msg = super::util::upstream_error_message(&err); + let error_type = super::util::upstream_error_type(&err); + yield Ok(Event::default().data(serde_json::to_string(&serde_json::json!({ + "error": { + "message": error_msg, + "type": error_type, + "code": err.status, + } + })).unwrap())); + yield Ok(Event::default().data("[DONE]")); + state.mitm_store.drop_channel().await; + return; + } + MitmEvent::Usage(u) => { + last_usage = Some(u); + } + MitmEvent::Grounding(_) => { + // Grounding metadata handled by store directly + } + } + } + + // Channel closed or timeout — clean up + state.mitm_store.drop_channel().await; + + // If we got here from timeout with content, emit what we have + if !last_text.is_empty() || last_thinking_len > 0 { + yield Ok(Event::default().data(chunk_json( + &completion_id, &model_name, + serde_json::json!([chunk_choice(0, serde_json::json!({}), Some("stop"))]), + None, + ))); + } yield Ok(Event::default().data("[DONE]")); return; - } - - // ── Check for MITM-captured function calls FIRST ── - // This runs independently of LS steps — the MITM captures tool calls - // at the proxy layer, so we don't need to wait for LS processing. - let captured = state.mitm_store.take_function_calls(&cascade_id).await; - if let Some(ref calls) = captured { - if !calls.is_empty() { - let mut tool_calls = Vec::new(); - for (i, fc) in calls.iter().enumerate() { - let call_id = format!( - "call_{}", - uuid::Uuid::new_v4().to_string().replace('-', "")[..24].to_string() - ); - let arguments = serde_json::to_string(&fc.args).unwrap_or_default(); - tool_calls.push(serde_json::json!({ - "index": i, - "id": call_id, - "type": "function", - "function": { - "name": fc.name, - "arguments": arguments, - }, - })); - } - yield Ok(Event::default().data(chunk_json( - &completion_id, &model_name, - serde_json::json!([chunk_choice(0, serde_json::json!({"tool_calls": tool_calls}), None)]), - None, - ))); - - yield Ok(Event::default().data(chunk_json( - &completion_id, &model_name, - serde_json::json!([chunk_choice(0, serde_json::json!({}), Some("tool_calls"))]), - None, - ))); - if include_usage { - let mitm = state.mitm_store.take_usage(&cascade_id).await - .or(state.mitm_store.take_usage("_latest").await); - let (pt, ct, crt, tt) = if let Some(ref u) = mitm { - (u.input_tokens, u.output_tokens, u.cache_read_input_tokens, u.thinking_output_tokens) - } else { (0, 0, 0, 0) }; - yield Ok(Event::default().data(chunk_json( - &completion_id, &model_name, - serde_json::json!([]), - Some(build_usage(pt, ct, crt, tt)), - ))); - } - yield Ok(Event::default().data("[DONE]")); - // Clear in-flight flag so the next turn's requests can get through - state.mitm_store.clear_response_async().await; - return; - } - } - - // ── Primary: MITM-captured response (when custom tools are active) ── - // The MITM intercepts the real Google SSE stream and captures text, - // thinking, and function calls. This is the authoritative data source. - // The LS only gets rewritten responses (function calls → text placeholders) - // so it doesn't provide useful streaming data when MITM is active. - if has_custom_tools { - // Stream thinking text as reasoning_content deltas - if let Some(tc) = state.mitm_store.peek_thinking_text().await { - if tc.len() > last_thinking_len { - let delta = &tc[last_thinking_len..]; - last_thinking_len = tc.len(); - - yield Ok(Event::default().data(chunk_json( - &completion_id, &model_name, - serde_json::json!([chunk_choice(0, serde_json::json!({"reasoning_content": delta}), None)]), - None, - ))); - } - } - - // Stream response text as content deltas - if let Some(text) = state.mitm_store.peek_response_text().await { - if !text.is_empty() && text != last_text { - let delta = if text.len() > last_text.len() && text.starts_with(&*last_text) { - text[last_text.len()..].to_string() - } else { - text.clone() - }; - - if !delta.is_empty() { - yield Ok(Event::default().data(chunk_json( - &completion_id, &model_name, - serde_json::json!([chunk_choice(0, serde_json::json!({"content": delta}), None)]), - None, - ))); - last_text = text; - } - } - } - - // Check if MITM response is complete - if state.mitm_store.is_response_complete() { - if !last_text.is_empty() { - // Have actual response text — done - complete_polls += 1; - if complete_polls >= 2 { - debug!("Completions: MITM response complete, text_len={}, thinking_len={}", last_text.len(), last_thinking_len); - let mitm = state.mitm_store.take_usage(&cascade_id).await - .or(state.mitm_store.take_usage("_latest").await); - let fr = google_to_openai_finish_reason(mitm.as_ref().and_then(|u| u.stop_reason.as_deref())); - yield Ok(Event::default().data(chunk_json( - &completion_id, &model_name, - serde_json::json!([chunk_choice(0, serde_json::json!({}), Some(fr))]), - None, - ))); - if include_usage { - let (pt, ct, crt, tt) = if let Some(ref u) = mitm { - (u.input_tokens, u.output_tokens, u.cache_read_input_tokens, u.thinking_output_tokens) - } else { (0, 0, 0, 0) }; - yield Ok(Event::default().data(chunk_json( - &completion_id, &model_name, - serde_json::json!([]), - Some(build_usage(pt, ct, crt, tt)), - ))); - } - yield Ok(Event::default().data("[DONE]")); - return; - } - } else if last_thinking_len > 0 && !did_unblock_ls { - // Thinking-only response. The LS needs follow-up API calls - // to get actual function calls or text. Unblock once. - did_unblock_ls = true; - complete_polls = 0; - // Bump generation FIRST — invalidates old MITM connection's store writes - my_generation = state.mitm_store.bump_generation(); - state.mitm_store.clear_request_in_flight(); - state.mitm_store.clear_response_complete(); - // Drain store so leaked connections can't produce stale content - state.mitm_store.set_response_text("").await; - state.mitm_store.set_thinking_text("").await; - let _ = state.mitm_store.take_any_function_calls().await; - debug!( - "Completions: thinking-only — unblocking LS for follow-up, thinking_len={}, new_gen={}", - last_thinking_len, my_generation - ); - } else if last_thinking_len > 0 && did_unblock_ls { - // Already unblocked once. Still only thinking after follow-up. - complete_polls += 1; - if complete_polls >= 25 { - info!("Completions: thinking-only timeout after ~10s, thinking_len={}", last_thinking_len); - let mitm = state.mitm_store.take_usage(&cascade_id).await - .or(state.mitm_store.take_usage("_latest").await); - let fr = google_to_openai_finish_reason(mitm.as_ref().and_then(|u| u.stop_reason.as_deref())); - yield Ok(Event::default().data(chunk_json( - &completion_id, &model_name, - serde_json::json!([chunk_choice(0, serde_json::json!({}), Some(fr))]), - None, - ))); - if include_usage { - let (pt, ct, crt, tt) = if let Some(ref u) = mitm { - (u.input_tokens, u.output_tokens, u.cache_read_input_tokens, u.thinking_output_tokens) - } else { (0, 0, 0, 0) }; - yield Ok(Event::default().data(chunk_json( - &completion_id, &model_name, - serde_json::json!([]), - Some(build_usage(pt, ct, crt, tt)), - ))); - } - yield Ok(Event::default().data("[DONE]")); - return; - } - } else { - // response_complete but no text AND no thinking — might be - // a function-call-only response that was already consumed, - // or empty response. Wait a bit then give up. - complete_polls += 1; - if complete_polls >= 4 { - info!("Completions: MITM response complete but no content (text/thinking/tools all empty), ending stream"); - yield Ok(Event::default().data(chunk_json( - &completion_id, &model_name, - serde_json::json!([chunk_choice(0, serde_json::json!({}), Some("stop"))]), - None, - ))); - yield Ok(Event::default().data("[DONE]")); - return; - } - } - } else { - complete_polls = 0; // Reset — not complete yet - } } else { // ── Fallback: LS steps (no MITM capture active) ── if let Ok((status, data)) = state.backend.get_steps(&cascade_id).await { @@ -1001,7 +1026,7 @@ async fn chat_completions_stream( } })).unwrap())); // Always clear in-flight flag when stream ends - state.mitm_store.clear_response_async().await; + state.mitm_store.drop_channel().await; yield Ok(Event::default().data("[DONE]")); }; diff --git a/src/api/responses.rs b/src/api/responses.rs index 48b996c..36e7bd1 100644 --- a/src/api/responses.rs +++ b/src/api/responses.rs @@ -309,7 +309,11 @@ pub(crate) async fn handle_responses( count = tools.len(), "Stored client tools for MITM injection" ); + } else { + state.mitm_store.clear_tools().await; } + } else { + state.mitm_store.clear_tools().await; } if let Some(ref choice) = body.tool_choice { let gemini_config = openai_tool_choice_to_gemini(choice); @@ -404,6 +408,8 @@ pub(crate) async fn handle_responses( // Send message state.mitm_store.set_active_cascade(&cascade_id).await; + // Store real user text for MITM injection — LS gets a dummy prompt + state.mitm_store.set_pending_user_text(user_text.clone()).await; // Store image for MITM injection (LS doesn't forward images to Google API) if let Some(ref img) = image { use base64::Engine; @@ -415,9 +421,24 @@ pub(crate) async fn handle_responses( }) .await; } + + // Pre-flight: install channel BEFORE send_message so the MITM proxy + // can grab it when the LS fires its API call. + let has_custom_tools = state.mitm_store.get_tools().await.is_some(); + let mitm_rx = if has_custom_tools { + state.mitm_store.clear_response_async().await; + state.mitm_store.clear_upstream_error().await; + let _ = state.mitm_store.take_any_function_calls().await; + let (tx, rx) = tokio::sync::mpsc::channel(64); + state.mitm_store.set_channel(tx).await; + Some(rx) + } else { + None + }; + match state .backend - .send_message_with_image(&cascade_id, &user_text, model.model_enum, image.as_ref()) + .send_message_with_image(&cascade_id, ".", model.model_enum, image.as_ref()) .await { Ok((200, _)) => { @@ -428,6 +449,7 @@ pub(crate) async fn handle_responses( }); } Ok((status, _)) => { + state.mitm_store.drop_channel().await; return err_response( StatusCode::BAD_GATEWAY, format!("Antigravity returned {status}"), @@ -435,6 +457,7 @@ pub(crate) async fn handle_responses( ); } Err(e) => { + state.mitm_store.drop_channel().await; return err_response( StatusCode::BAD_GATEWAY, format!("Send message failed: {e}"), @@ -472,6 +495,7 @@ pub(crate) async fn handle_responses( cascade_id, body.timeout, req_params, + mitm_rx, ) .await } else { @@ -482,6 +506,7 @@ pub(crate) async fn handle_responses( cascade_id, body.timeout, req_params, + mitm_rx, ) .await } @@ -603,54 +628,54 @@ async fn handle_responses_sync( cascade_id: String, timeout: u64, params: RequestParams, + mitm_rx: Option>, ) -> axum::response::Response { let created_at = now_unix(); - let has_custom_tools = state.mitm_store.get_tools().await.is_some(); - // Clear stale captured response and upstream errors - state.mitm_store.clear_response_async().await; - state.mitm_store.clear_upstream_error().await; + // Clear stale captured response and upstream errors (only if no pre-installed channel) + if mitm_rx.is_none() { + state.mitm_store.clear_response_async().await; + state.mitm_store.clear_upstream_error().await; + } - // ── MITM bypass: poll MitmStore directly when custom tools active ── - if has_custom_tools { + // ── MITM bypass: channel-based pipeline when custom tools active ── + if let Some(mut rx) = mitm_rx { let start = std::time::Instant::now(); - while start.elapsed().as_secs() < timeout { - // Check for upstream errors from MITM (Google API errors) - if let Some(err) = state.mitm_store.take_upstream_error().await { - return upstream_err_response(&err); - } - // Check for function calls - let captured = state.mitm_store.take_function_calls(&cascade_id).await; - if let Some(ref raw_calls) = captured { - let calls: Vec<_> = if let Some(max) = params.max_tool_calls { - raw_calls.iter().take(max as usize).collect() - } else { - raw_calls.iter().collect() - }; - if !calls.is_empty() { + let mut acc_text = String::new(); + let mut acc_thinking: Option = None; + let mut last_usage: Option = None; + + while let Some(event) = tokio::time::timeout( + std::time::Duration::from_secs(timeout.saturating_sub(start.elapsed().as_secs())), + rx.recv(), + ).await.ok().flatten() { + use crate::mitm::store::MitmEvent; + match event { + MitmEvent::ThinkingDelta(t) => { acc_thinking = Some(t); } + MitmEvent::TextDelta(t) => { acc_text = t; } + MitmEvent::Usage(u) => { last_usage = Some(u); } + MitmEvent::Grounding(_) => {} // stored by proxy directly + MitmEvent::FunctionCall(raw_calls) => { + let calls: Vec<_> = if let Some(max) = params.max_tool_calls { + raw_calls.iter().take(max as usize).collect() + } else { + raw_calls.iter().collect() + }; let mut output_items: Vec = Vec::new(); for fc in &calls { let call_id = format!( "call_{}", uuid::Uuid::new_v4().to_string().replace('-', "")[..24].to_string() ); - state - .mitm_store - .register_call_id(call_id.clone(), fc.name.clone()) - .await; + state.mitm_store.register_call_id(call_id.clone(), fc.name.clone()).await; let arguments = serde_json::to_string(&fc.args).unwrap_or_default(); - output_items - .push(build_function_call_output(&call_id, &fc.name, &arguments)); + output_items.push(build_function_call_output(&call_id, &fc.name, &arguments)); } let (usage, _) = usage_from_poll( - &state.mitm_store, - &cascade_id, - &None, - ¶ms.user_text, - "", - ) - .await; + &state.mitm_store, &cascade_id, &None, ¶ms.user_text, "", + ).await; + state.mitm_store.drop_channel().await; let resp = build_response_object( ResponseData { id: response_id, @@ -666,52 +691,61 @@ async fn handle_responses_sync( ); return Json(resp).into_response(); } - } + MitmEvent::ResponseComplete => { + if acc_text.is_empty() && acc_thinking.is_none() { + // Empty response — continue waiting + continue; + } + if acc_text.is_empty() && acc_thinking.is_some() { + // Thinking-only — LS needs to make a follow-up request. + // Reinstall channel and unblock gate. + let (new_tx, new_rx) = tokio::sync::mpsc::channel(64); + state.mitm_store.set_channel(new_tx).await; + state.mitm_store.clear_request_in_flight(); + let _ = state.mitm_store.take_any_function_calls().await; + rx = new_rx; + debug!( + "Responses sync: thinking-only — new channel for follow-up, thinking_len={}", + acc_thinking.as_ref().map(|t| t.len()).unwrap_or(0) + ); + continue; + } + let (usage, _) = usage_from_poll( + &state.mitm_store, &cascade_id, &None, ¶ms.user_text, &acc_text, + ).await; + state.mitm_store.drop_channel().await; - // Check for completed text response - if state.mitm_store.is_response_complete() { - let text = state - .mitm_store - .take_response_text() - .await - .unwrap_or_default(); - let thinking = state.mitm_store.take_thinking_text().await; - let (usage, _) = usage_from_poll( - &state.mitm_store, - &cascade_id, - &None, - ¶ms.user_text, - &text, - ) - .await; + let mut output_items: Vec = Vec::new(); + if let Some(ref t) = acc_thinking { + output_items.push(build_reasoning_output(t)); + } + let msg_id = format!("msg_{}", uuid::Uuid::new_v4().to_string().replace('-', "")); + output_items.push(build_message_output(&msg_id, &acc_text)); - let mut output_items: Vec = Vec::new(); - if let Some(ref t) = thinking { - output_items.push(build_reasoning_output(t)); + let resp = build_response_object( + ResponseData { + id: response_id, + model: model_name, + status: "completed", + created_at, + completed_at: Some(now_unix()), + output: output_items, + usage: Some(usage), + thinking_signature: None, + }, + ¶ms, + ); + return Json(resp).into_response(); + } + MitmEvent::UpstreamError(err) => { + state.mitm_store.drop_channel().await; + return upstream_err_response(&err); } - let msg_id = format!("msg_{}", uuid::Uuid::new_v4().to_string().replace('-', "")); - output_items.push(build_message_output(&msg_id, &text)); - - let resp = build_response_object( - ResponseData { - id: response_id, - model: model_name, - status: "completed", - created_at, - completed_at: Some(now_unix()), - output: output_items, - usage: Some(usage), - thinking_signature: None, - }, - ¶ms, - ); - return Json(resp).into_response(); } - - tokio::time::sleep(tokio::time::Duration::from_millis(200)).await; } - // Timeout — return proper error, not fake incomplete response + // Timeout + state.mitm_store.drop_channel().await; return err_response( StatusCode::GATEWAY_TIMEOUT, format!("Timeout: no response from Google API after {timeout}s"), @@ -835,6 +869,7 @@ async fn handle_responses_stream( cascade_id: String, timeout: u64, params: RequestParams, + mitm_rx: Option>, ) -> axum::response::Response { let stream = async_stream::stream! { let msg_id = format!("msg_{}", uuid::Uuid::new_v4().to_string().replace('-', "")); @@ -886,50 +921,170 @@ async fn handle_responses_stream( let mut thinking_text: Option = None; let mut message_started = false; let reasoning_id = format!("rs_{}", uuid::Uuid::new_v4().to_string().replace('-', "")); - let has_custom_tools = state.mitm_store.get_tools().await.is_some(); - // Clear stale captured response and upstream errors - state.mitm_store.clear_response_async().await; - state.mitm_store.clear_upstream_error().await; + // Clear stale response (only if no pre-installed channel) + if mitm_rx.is_none() { + state.mitm_store.clear_response_async().await; + state.mitm_store.clear_upstream_error().await; + } // ── MITM bypass mode (when custom tools are active) ── - // Skip LS entirely — read text, thinking, and tool calls directly from MitmStore. - if has_custom_tools { + // Channel-based pipeline: read events directly from MITM proxy. + // Channel is pre-installed before send_message to avoid race conditions. + if let Some(mut rx) = mitm_rx { let mut last_thinking = String::new(); - while start.elapsed().as_secs() < timeout { - // Check for upstream errors from MITM (Google API errors) - if let Some(err) = state.mitm_store.take_upstream_error().await { - let error_msg = super::util::upstream_error_message(&err); - let error_type = super::util::upstream_error_type(&err); - yield Ok(responses_sse_event( - "response.failed", - serde_json::json!({ - "type": "response.failed", - "sequence_number": next_seq(), - "response": { - "id": &response_id, - "status": "failed", - "error": { - "type": error_type, - "message": error_msg, - "code": err.status, - }, - }, - }), - )); - break; - } + while let Some(event) = tokio::time::timeout( + std::time::Duration::from_secs(timeout.saturating_sub(start.elapsed().as_secs())), + rx.recv(), + ).await.ok().flatten() { + use crate::mitm::store::MitmEvent; + match event { + MitmEvent::ThinkingDelta(full_thinking) => { + if !thinking_emitted && full_thinking.len() > last_thinking.len() { + // First thinking text — emit reasoning output_item.added + if last_thinking.is_empty() { + yield Ok(responses_sse_event( + "response.output_item.added", + serde_json::json!({ + "type": "response.output_item.added", + "sequence_number": next_seq(), + "output_index": 0, + "item": { + "id": &reasoning_id, + "type": "reasoning", + "summary": [], + }, + }), + )); + yield Ok(responses_sse_event( + "response.reasoning_summary_part.added", + serde_json::json!({ + "type": "response.reasoning_summary_part.added", + "sequence_number": next_seq(), + "item_id": &reasoning_id, + "output_index": 0, + "summary_index": 0, + "part": { "type": "summary_text", "text": "" }, + }), + )); + } + let delta = &full_thinking[last_thinking.len()..]; + if !delta.is_empty() { + yield Ok(responses_sse_event( + "response.reasoning_summary_text.delta", + serde_json::json!({ + "type": "response.reasoning_summary_text.delta", + "sequence_number": next_seq(), + "item_id": &reasoning_id, + "output_index": 0, + "summary_index": 0, + "delta": delta, + }), + )); + } + last_thinking = full_thinking; + } + } + MitmEvent::TextDelta(full_text) => { + if full_text.len() > last_text.len() { + // Finalize thinking if started but not done + if !thinking_emitted && !last_thinking.is_empty() { + thinking_emitted = true; + thinking_text = Some(last_thinking.clone()); + yield Ok(responses_sse_event( + "response.reasoning_summary_text.done", + serde_json::json!({ + "type": "response.reasoning_summary_text.done", + "sequence_number": next_seq(), + "item_id": &reasoning_id, + "output_index": 0, + "summary_index": 0, + "text": &last_thinking, + }), + )); + yield Ok(responses_sse_event( + "response.reasoning_summary_part.done", + serde_json::json!({ + "type": "response.reasoning_summary_part.done", + "sequence_number": next_seq(), + "item_id": &reasoning_id, + "output_index": 0, + "summary_index": 0, + "part": { "type": "summary_text", "text": &last_thinking }, + }), + )); + yield Ok(responses_sse_event( + "response.output_item.done", + serde_json::json!({ + "type": "response.output_item.done", + "sequence_number": next_seq(), + "output_index": 0, + "item": { + "id": &reasoning_id, + "type": "reasoning", + "summary": [{ + "type": "summary_text", + "text": &last_thinking, + }], + }, + }), + )); + } - // Check for function calls first - let captured = state.mitm_store.take_function_calls(&cascade_id).await; - if let Some(ref raw_calls) = captured { - let calls: Vec<_> = if let Some(max) = params.max_tool_calls { - raw_calls.iter().take(max as usize).collect() - } else { - raw_calls.iter().collect() - }; - if !calls.is_empty() { + let msg_output_index: u32 = if thinking_emitted { 1 } else { 0 }; + + if !message_started { + message_started = true; + yield Ok(responses_sse_event( + "response.output_item.added", + serde_json::json!({ + "type": "response.output_item.added", + "sequence_number": next_seq(), + "output_index": msg_output_index, + "item": build_message_output_in_progress(&msg_id), + }), + )); + yield Ok(responses_sse_event( + "response.content_part.added", + serde_json::json!({ + "type": "response.content_part.added", + "sequence_number": next_seq(), + "output_index": msg_output_index, + "content_index": CONTENT_IDX, + "part": { + "type": "output_text", + "text": "", + "annotations": [], + } + }), + )); + } + + let delta = &full_text[last_text.len()..]; + if !delta.is_empty() { + let msg_output_index: u32 = if thinking_emitted { 1 } else { 0 }; + yield Ok(responses_sse_event( + "response.output_text.delta", + serde_json::json!({ + "type": "response.output_text.delta", + "sequence_number": next_seq(), + "item_id": &msg_id, + "output_index": msg_output_index, + "content_index": CONTENT_IDX, + "delta": delta, + }), + )); + last_text = full_text; + } + } + } + MitmEvent::FunctionCall(raw_calls) => { + let calls: Vec<_> = if let Some(max) = params.max_tool_calls { + raw_calls.iter().take(max as usize).collect() + } else { + raw_calls.iter().collect() + }; let msg_output_index: u32 = if thinking_emitted { 1 } else { 0 }; for (i, fc) in calls.iter().enumerate() { let call_id = format!( @@ -1011,194 +1166,71 @@ async fn handle_responses_stream( "response": response_to_json(&final_resp), }), )); + state.mitm_store.drop_channel().await; return; } - } - - // Stream thinking text in real-time - if !thinking_emitted { - if let Some(thinking) = state.mitm_store.peek_thinking_text().await { - if !thinking.is_empty() && thinking != last_thinking { - // First thinking text — emit reasoning output_item.added - if last_thinking.is_empty() { - yield Ok(responses_sse_event( - "response.output_item.added", - serde_json::json!({ - "type": "response.output_item.added", - "sequence_number": next_seq(), - "output_index": 0, - "item": { - "id": &reasoning_id, - "type": "reasoning", - "summary": [], - }, - }), - )); - yield Ok(responses_sse_event( - "response.reasoning_summary_part.added", - serde_json::json!({ - "type": "response.reasoning_summary_part.added", - "sequence_number": next_seq(), - "item_id": &reasoning_id, - "output_index": 0, - "summary_index": 0, - "part": { "type": "summary_text", "text": "" }, - }), - )); + MitmEvent::ResponseComplete => { + if !last_text.is_empty() { + let msg_idx: u32 = if thinking_emitted { 1 } else { 0 }; + let (usage, _) = usage_from_poll( + &state.mitm_store, &cascade_id, &None, + ¶ms.user_text, &last_text, + ).await; + let tc = thinking_text.clone(); + for evt in completion_events( + &response_id, &model_name, &msg_id, &reasoning_id, + msg_idx, CONTENT_IDX, &last_text, usage, + created_at, &seq, ¶ms, None, tc, + ) { + yield Ok(evt); } - - // Delta of new thinking text - let delta = if thinking.len() > last_thinking.len() - && thinking.starts_with(&*last_thinking) - { - thinking[last_thinking.len()..].to_string() - } else { - thinking.clone() - }; - - if !delta.is_empty() { - yield Ok(responses_sse_event( - "response.reasoning_summary_text.delta", - serde_json::json!({ - "type": "response.reasoning_summary_text.delta", - "sequence_number": next_seq(), - "item_id": &reasoning_id, - "output_index": 0, - "summary_index": 0, - "delta": &delta, - }), - )); - } - last_thinking = thinking; + state.mitm_store.drop_channel().await; + return; + } else if !last_thinking.is_empty() { + // Thinking-only response — LS needs follow-up API calls. + // Create a new channel and unblock the gate. + let (new_tx, new_rx) = tokio::sync::mpsc::channel(64); + state.mitm_store.set_channel(new_tx).await; + state.mitm_store.clear_request_in_flight(); + let _ = state.mitm_store.take_any_function_calls().await; + rx = new_rx; + debug!( + "Responses stream: thinking-only — new channel for follow-up, thinking_len={}", + last_thinking.len() + ); } + // ResponseComplete with no text and no thinking — continue waiting } - } - - // Stream response text - if let Some(text) = state.mitm_store.peek_response_text().await { - if !text.is_empty() && text != last_text { - // Finalize thinking if started but not done - if !thinking_emitted && !last_thinking.is_empty() { - thinking_emitted = true; - thinking_text = Some(last_thinking.clone()); - yield Ok(responses_sse_event( - "response.reasoning_summary_text.done", - serde_json::json!({ - "type": "response.reasoning_summary_text.done", - "sequence_number": next_seq(), - "item_id": &reasoning_id, - "output_index": 0, - "summary_index": 0, - "text": &last_thinking, - }), - )); - yield Ok(responses_sse_event( - "response.reasoning_summary_part.done", - serde_json::json!({ - "type": "response.reasoning_summary_part.done", - "sequence_number": next_seq(), - "item_id": &reasoning_id, - "output_index": 0, - "summary_index": 0, - "part": { "type": "summary_text", "text": &last_thinking }, - }), - )); - yield Ok(responses_sse_event( - "response.output_item.done", - serde_json::json!({ - "type": "response.output_item.done", - "sequence_number": next_seq(), - "output_index": 0, - "item": { - "id": &reasoning_id, - "type": "reasoning", - "summary": [{ - "type": "summary_text", - "text": &last_thinking, - }], + MitmEvent::UpstreamError(err) => { + let error_msg = super::util::upstream_error_message(&err); + let error_type = super::util::upstream_error_type(&err); + yield Ok(responses_sse_event( + "response.failed", + serde_json::json!({ + "type": "response.failed", + "sequence_number": next_seq(), + "response": { + "id": &response_id, + "status": "failed", + "error": { + "type": error_type, + "message": error_msg, + "code": err.status, }, - }), - )); - } - - let msg_output_index: u32 = if thinking_emitted { 1 } else { 0 }; - - if !message_started { - message_started = true; - yield Ok(responses_sse_event( - "response.output_item.added", - serde_json::json!({ - "type": "response.output_item.added", - "sequence_number": next_seq(), - "output_index": msg_output_index, - "item": build_message_output_in_progress(&msg_id), - }), - )); - yield Ok(responses_sse_event( - "response.content_part.added", - serde_json::json!({ - "type": "response.content_part.added", - "sequence_number": next_seq(), - "output_index": msg_output_index, - "content_index": CONTENT_IDX, - "part": { - "type": "output_text", - "text": "", - "annotations": [], - } - }), - )); - } - - let new_content = if text.len() > last_text.len() - && text.starts_with(&*last_text) - { - text[last_text.len()..].to_string() - } else { - text.clone() - }; - - if !new_content.is_empty() { - yield Ok(responses_sse_event( - "response.output_text.delta", - serde_json::json!({ - "type": "response.output_text.delta", - "sequence_number": next_seq(), - "item_id": &msg_id, - "output_index": msg_output_index, - "content_index": CONTENT_IDX, - "delta": &new_content, - }), - )); - last_text = text; - } - } - - // Check if response is complete - if state.mitm_store.is_response_complete() && !last_text.is_empty() { - let msg_idx: u32 = if thinking_emitted { 1 } else { 0 }; - let (usage, _) = usage_from_poll( - &state.mitm_store, &cascade_id, &None, - ¶ms.user_text, &last_text, - ).await; - let tc = thinking_text.clone(); - for evt in completion_events( - &response_id, &model_name, &msg_id, &reasoning_id, - msg_idx, CONTENT_IDX, &last_text, usage, - created_at, &seq, ¶ms, None, tc, - ) { - yield Ok(evt); - } + }, + }), + )); + state.mitm_store.drop_channel().await; return; } + MitmEvent::Usage(_) | MitmEvent::Grounding(_) => { + // Usage/grounding stored by proxy, consumed via usage_from_poll + } } - - // Poll interval - let poll_ms: u64 = rand::thread_rng().gen_range(150..300); - tokio::time::sleep(tokio::time::Duration::from_millis(poll_ms)).await; } - // Timeout in bypass mode — emit error, not fake incomplete + // Timeout in channel mode + state.mitm_store.drop_channel().await; yield Ok(responses_sse_event( "response.failed", serde_json::json!({ diff --git a/src/api/search.rs b/src/api/search.rs index 8c40110..6b525d0 100644 --- a/src/api/search.rs +++ b/src/api/search.rs @@ -163,11 +163,13 @@ async fn do_search(state: Arc, body: SearchRequest) -> axum::response: // Set active cascade for MITM correlation state.mitm_store.set_active_cascade(&cascade_id).await; + // Store real search prompt for MITM injection — LS gets a dummy prompt + state.mitm_store.set_pending_user_text(search_prompt.clone()).await; // Send the search message if let Err(e) = state .backend - .send_message(&cascade_id, &search_prompt, model.model_enum) + .send_message(&cascade_id, ".", model.model_enum) .await { state.mitm_store.clear_active_cascade().await; diff --git a/src/mitm/proxy.rs b/src/mitm/proxy.rs index 632cc55..066be40 100644 --- a/src/mitm/proxy.rs +++ b/src/mitm/proxy.rs @@ -538,9 +538,10 @@ async fn handle_http_over_tls( } }; - // Generation tracking for store write guards - let mut won_gate = false; - let mut conn_generation = store.current_generation(); + // Channel-based event pipeline: grab the channel sender if one exists. + // If the API handler set up a channel, we send events through it. + // Otherwise, we fall back to legacy store writes (search endpoint, etc.). + let mut event_tx: Option> = None; // Log LLM calls at info, everything else at debug if req_path.contains("streamGenerateContent") { @@ -558,7 +559,7 @@ async fn handle_http_over_tls( // When custom tools are active, only the FIRST request wins the // atomic compare_exchange. All others get fake STOP responses. let has_tools = store.get_tools().await.is_some(); - won_gate = if has_tools { + if has_tools { if !store.try_mark_request_in_flight() { info!("MITM: blocking LS request — another request already in-flight"); let fake_response = "HTTP/1.1 200 OK\r\n\ @@ -575,13 +576,11 @@ async fn handle_http_over_tls( let _ = client.flush().await; continue; } - true - } else { - false - }; - // Snapshot the generation at gate-win time. If it changes later, - // another completions turn started and our data is stale. - conn_generation = store.current_generation(); + // Grab the channel sender — the API handler installed it before + // sending the LS message. If it's gone, we still proceed but + // fall back to legacy store writes. + event_tx = store.take_channel().await; + } // ── Request modification ───────────────────────────────────── // Dechunk body → check if agent request → modify → rechunk @@ -603,12 +602,14 @@ async fn handle_http_over_tls( let generation_params = store.get_generation_params().await; let pending_image = store.take_pending_image().await; let tool_rounds = store.get_tool_rounds().await; + let pending_user_text = store.take_pending_user_text().await; let tool_ctx = if tools.is_some() || !pending_results.is_empty() || !tool_rounds.is_empty() || generation_params.is_some() || pending_image.is_some() + || pending_user_text.is_some() { Some(super::modify::ToolContext { tools, @@ -618,6 +619,7 @@ async fn handle_http_over_tls( generation_params, pending_image, tool_rounds, + pending_user_text, }) } else { None @@ -794,14 +796,18 @@ async fn handle_http_over_tls( }) .unwrap_or((None, None)); - store - .set_upstream_error(super::store::UpstreamError { - status: http_status, - body: body_str, - message, - error_status, - }) - .await; + let upstream_err = super::store::UpstreamError { + status: http_status, + body: body_str, + message, + error_status, + }; + // Send through channel if available, otherwise store for legacy consumers + if let Some(ref tx) = event_tx { + let _ = tx.send(super::store::MitmEvent::UpstreamError(upstream_err)).await; + } else { + store.set_upstream_error(upstream_err).await; + } } // Save body for usage parsing @@ -812,46 +818,59 @@ async fn handle_http_over_tls( let body = String::from_utf8_lossy(&header_buf[hdr_end..]); parse_streaming_chunk(&body, &mut streaming_acc); - // Only write to store if our generation is still current. - // If another completions turn started, our data is stale. - let gen_valid = !won_gate || store.current_generation() == conn_generation; - if gen_valid { - // Store captured function calls (drain to avoid re-storing on next chunk) + // Send events through channel if available, otherwise use legacy store + if let Some(ref tx) = event_tx { + // Function calls → channel event if !streaming_acc.function_calls.is_empty() { - let calls: Vec<_> = - streaming_acc.function_calls.drain(..).collect(); - for fc in &calls { - store - .record_function_call(cascade_hint.as_deref(), fc.clone()) - .await; - } + let calls: Vec<_> = streaming_acc.function_calls.drain(..).collect(); store.set_last_function_calls(calls.clone()).await; - info!( - "MITM: stored {} function call(s) from initial body", - calls.len() - ); - } - - // Capture response + thinking text + grounding into MitmStore - if !streaming_acc.response_text.is_empty() { - store.set_response_text(&streaming_acc.response_text).await; + store.record_function_call(cascade_hint.as_deref(), calls[0].clone()).await; + info!("MITM: sending {} function call(s) via channel (initial body)", calls.len()); + let _ = tx.send(super::store::MitmEvent::FunctionCall(calls)).await; } + // Thinking delta → channel event if !streaming_acc.thinking_text.is_empty() { - store.set_thinking_text(&streaming_acc.thinking_text).await; + let _ = tx.send(super::store::MitmEvent::ThinkingDelta( + streaming_acc.thinking_text.clone(), + )).await; } + // Text delta → channel event + if !streaming_acc.response_text.is_empty() { + let _ = tx.send(super::store::MitmEvent::TextDelta( + streaming_acc.response_text.clone(), + )).await; + } + // Grounding → channel event if let Some(ref gm) = streaming_acc.grounding_metadata { store.set_grounding(gm.clone()).await; + let _ = tx.send(super::store::MitmEvent::Grounding(gm.clone())).await; } + // Response complete → channel event if streaming_acc.is_complete { info!( response_text_len = streaming_acc.response_text.len(), thinking_text_len = streaming_acc.thinking_text.len(), - "MITM: response complete (initial body) — marking store" + "MITM: response complete (initial body) — sending via channel" ); - store.mark_response_complete(); + let _ = tx.send(super::store::MitmEvent::ResponseComplete).await; + streaming_acc.is_complete = false; // prevent duplicate sends + } + } else { + // Legacy path: store writes for non-channel consumers (search, etc.) + if !streaming_acc.function_calls.is_empty() { + let calls: Vec<_> = streaming_acc.function_calls.drain(..).collect(); + for fc in &calls { + store.record_function_call(cascade_hint.as_deref(), fc.clone()).await; + } + store.set_last_function_calls(calls.clone()).await; + info!("MITM: stored {} function call(s) from initial body", calls.len()); + } + if !streaming_acc.response_text.is_empty() { + store.set_response_text(&streaming_acc.response_text).await; + } + if let Some(ref gm) = streaming_acc.grounding_metadata { + store.set_grounding(gm.clone()).await; } - } else if streaming_acc.is_complete { - debug!("MITM: skipping store write — generation stale (initial body)"); } } @@ -890,45 +909,60 @@ async fn handle_http_over_tls( let s = String::from_utf8_lossy(chunk); parse_streaming_chunk(&s, &mut streaming_acc); - // Only write to store if our generation is still current. - let gen_valid = !won_gate || store.current_generation() == conn_generation; - if gen_valid { - // Store captured function calls (drain to avoid re-storing on next chunk) + // Send events through channel if available, otherwise use legacy store + if let Some(ref tx) = event_tx { + // Function calls → channel event if !streaming_acc.function_calls.is_empty() { let calls: Vec<_> = streaming_acc.function_calls.drain(..).collect(); - for fc in &calls { - store - .record_function_call(cascade_hint.as_deref(), fc.clone()) - .await; - } store.set_last_function_calls(calls.clone()).await; - info!( - "MITM: stored {} function call(s) from body chunk", - calls.len() - ); - } - - // Capture response + thinking text + grounding into MitmStore - if !streaming_acc.response_text.is_empty() { - store.set_response_text(&streaming_acc.response_text).await; + store.record_function_call(cascade_hint.as_deref(), calls[0].clone()).await; + info!("MITM: sending {} function call(s) via channel (body chunk)", calls.len()); + let _ = tx.send(super::store::MitmEvent::FunctionCall(calls)).await; } + // Thinking delta → channel event (send accumulated, handler tracks last len) if !streaming_acc.thinking_text.is_empty() { - store.set_thinking_text(&streaming_acc.thinking_text).await; + let _ = tx.send(super::store::MitmEvent::ThinkingDelta( + streaming_acc.thinking_text.clone(), + )).await; } + // Text delta → channel event (send accumulated, handler tracks last len) + if !streaming_acc.response_text.is_empty() { + let _ = tx.send(super::store::MitmEvent::TextDelta( + streaming_acc.response_text.clone(), + )).await; + } + // Grounding → channel event if let Some(ref gm) = streaming_acc.grounding_metadata { store.set_grounding(gm.clone()).await; + let _ = tx.send(super::store::MitmEvent::Grounding(gm.clone())).await; } + // Response complete → channel event if streaming_acc.is_complete { info!( response_text_len = streaming_acc.response_text.len(), thinking_text_len = streaming_acc.thinking_text.len(), function_calls = streaming_acc.function_calls.len(), - "MITM: response complete — marking store" + "MITM: response complete — sending via channel" ); - store.mark_response_complete(); + let _ = tx.send(super::store::MitmEvent::ResponseComplete).await; + streaming_acc.is_complete = false; // prevent duplicate sends + } + } else { + // Legacy path: store writes for non-channel consumers + if !streaming_acc.function_calls.is_empty() { + let calls: Vec<_> = streaming_acc.function_calls.drain(..).collect(); + for fc in &calls { + store.record_function_call(cascade_hint.as_deref(), fc.clone()).await; + } + store.set_last_function_calls(calls.clone()).await; + info!("MITM: stored {} function call(s) from body chunk", calls.len()); + } + if !streaming_acc.response_text.is_empty() { + store.set_response_text(&streaming_acc.response_text).await; + } + if let Some(ref gm) = streaming_acc.grounding_metadata { + store.set_grounding(gm.clone()).await; } - } else if streaming_acc.is_complete { - debug!("MITM: skipping store write — generation stale (body chunk)"); } } @@ -969,9 +1003,11 @@ async fn handle_http_over_tls( store.set_grounding(gm.clone()).await; } if streaming_acc.is_complete || streaming_acc.output_tokens > 0 { - // Function calls are stored immediately when detected (above), - // so no need to store them again here. let usage = streaming_acc.into_usage(); + // Send usage through channel if available + if let Some(ref tx) = event_tx { + let _ = tx.send(super::store::MitmEvent::Usage(usage.clone())).await; + } store.record_usage(cascade_hint.as_deref(), usage).await; } } else if !response_body_buf.is_empty() { diff --git a/src/mitm/store.rs b/src/mitm/store.rs index 5b23114..66e5cdb 100644 --- a/src/mitm/store.rs +++ b/src/mitm/store.rs @@ -1,12 +1,14 @@ //! Shared store for intercepted API usage data. //! //! The MITM proxy writes usage data here; the API handlers read from it. +//! When custom tools are active, the MITM proxy sends real-time events +//! through a channel instead of writing to shared state. use serde::{Deserialize, Serialize}; use std::collections::HashMap; -use std::sync::atomic::{AtomicBool, AtomicU64, Ordering}; +use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::Arc; -use tokio::sync::RwLock; +use tokio::sync::{mpsc, RwLock}; use tracing::{debug, info}; /// Token usage from an intercepted API response. @@ -126,6 +128,29 @@ pub struct GenerationParams { pub google_search: bool, } +// ─── Channel-based event pipeline ──────────────────────────────────────────── + +/// Events sent from the MITM proxy to API handlers through a per-request channel. +/// Replaces the old polling-based approach (shared atomics + RwLocks) with +/// instant, race-free delivery. +#[derive(Debug, Clone)] +pub enum MitmEvent { + /// Incremental thinking/reasoning text from the model. + ThinkingDelta(String), + /// Incremental response text from the model. + TextDelta(String), + /// Model requested function call(s). + FunctionCall(Vec), + /// Response streaming is complete (finishReason received). + ResponseComplete, + /// Google API returned an error. + UpstreamError(UpstreamError), + /// Grounding metadata (search results) from the response. + Grounding(serde_json::Value), + /// Token usage data from the response. + Usage(ApiUsage), +} + /// Thread-safe store for intercepted data. /// /// Keyed by a unique request ID that we can correlate with cascade operations. @@ -138,20 +163,17 @@ pub struct MitmStore { stats: Arc>, /// Pending function calls captured from Google responses. /// Key: cascade hint or "_latest". Value: list of function calls. + /// Used by the non-tool LS path (normal sync responses). pending_function_calls: Arc>>>, - /// Simple flag: set when a functionCall is captured, cleared when consumed. - /// Used to block follow-up requests regardless of cascade identification. - has_active_function_call: Arc, - /// Persistent flag: set when a function call is captured, cleared ONLY when - /// a tool result is submitted. Prevents the LS from making follow-up API - /// calls during the entire tool execution cycle. - awaiting_tool_result: Arc, /// Set when the MITM forwards the first LLM request with custom tools. /// Blocks ALL subsequent LS requests until the API handler clears it. request_in_flight: Arc, - /// Generation counter — incremented each time a new completions turn starts. - /// Used to discard stale data from leaked LS connections. - request_generation: Arc, + + // ── Channel-based event pipeline (replaces old polling) ────────────── + /// Active channel sender for the current tool-path request. + /// When present, the MITM proxy sends events through this instead of + /// writing to shared state. The channel's existence = request in-flight. + active_channel: Arc>>>, // ── Tool call support ──────────────────────────────────────────────── /// Active tool definitions (Gemini format) for MITM injection. @@ -173,14 +195,9 @@ pub struct MitmStore { /// Used by the MITM proxy to correlate intercepted traffic to cascades. active_cascade_id: Arc>>, - // ── Direct response capture (bypasses LS) ──────────────────────────── - /// Captured response text from MITM when custom tools are active. - /// The completions/responses handler reads this instead of polling LS steps. + // ── Legacy direct response capture (used by search.rs) ─────────────── + /// Captured response text from MITM. Used as fallback by search endpoint. captured_response_text: Arc>>, - /// Captured thinking/reasoning text from MITM (for real-time streaming). - captured_thinking_text: Arc>>, - /// Whether the captured response is complete (finishReason received). - response_complete: Arc, // ── Generation parameters for MITM injection ───────────────────────── /// Client-specified sampling parameters to inject into Google API requests. @@ -194,9 +211,14 @@ pub struct MitmStore { /// Image to inject into the next Google API request via MITM. pending_image: Arc>>, - // ── Upstream error capture ─────────────────────────────────────────── + // ── Upstream error capture (legacy, used when no channel) ──────────── /// Error from Google's API, captured by MITM for forwarding to client. upstream_error: Arc>>, + + // ── Standard LS input: real user text for MITM injection ───────────── + /// The real user text to inject into the Google API request. + /// API handlers store this before sending a dummy prompt to the LS. + pending_user_text: Arc>>, } /// Aggregate statistics across all intercepted traffic. @@ -229,10 +251,8 @@ impl MitmStore { latest_usage: Arc::new(RwLock::new(HashMap::new())), stats: Arc::new(RwLock::new(MitmStats::default())), pending_function_calls: Arc::new(RwLock::new(HashMap::new())), - has_active_function_call: Arc::new(AtomicBool::new(false)), - awaiting_tool_result: Arc::new(AtomicBool::new(false)), request_in_flight: Arc::new(AtomicBool::new(false)), - request_generation: Arc::new(AtomicU64::new(0)), + active_channel: Arc::new(RwLock::new(None)), active_tools: Arc::new(RwLock::new(None)), active_tool_config: Arc::new(RwLock::new(None)), pending_tool_results: Arc::new(RwLock::new(Vec::new())), @@ -241,12 +261,11 @@ impl MitmStore { tool_rounds: Arc::new(RwLock::new(Vec::new())), active_cascade_id: Arc::new(RwLock::new(None)), captured_response_text: Arc::new(RwLock::new(None)), - captured_thinking_text: Arc::new(RwLock::new(None)), - response_complete: Arc::new(AtomicBool::new(false)), generation_params: Arc::new(RwLock::new(None)), captured_grounding: Arc::new(RwLock::new(None)), pending_image: Arc::new(RwLock::new(None)), upstream_error: Arc::new(RwLock::new(None)), + pending_user_text: Arc::new(RwLock::new(None)), } } @@ -381,30 +400,6 @@ impl MitmStore { ); let mut pending = self.pending_function_calls.write().await; pending.entry(key).or_default().push(fc); - self.has_active_function_call.store(true, Ordering::SeqCst); - self.awaiting_tool_result.store(true, Ordering::SeqCst); - } - - /// Check if there's an active (unclaimed) function call. - pub fn has_active_function_call(&self) -> bool { - self.has_active_function_call.load(Ordering::SeqCst) - } - - /// Force-clear the active function call flag (used to reset stale state). - pub fn clear_active_function_call(&self) { - self.has_active_function_call.store(false, Ordering::SeqCst); - } - - /// Check if we're awaiting a tool result (blocks LS follow-up requests). - /// This persists across function call consumption — only cleared when - /// actual tool results are submitted. - pub fn is_awaiting_tool_result(&self) -> bool { - self.awaiting_tool_result.load(Ordering::SeqCst) - } - - /// Clear the awaiting-tool-result flag (called when tool results arrive). - pub fn clear_awaiting_tool_result(&self) { - self.awaiting_tool_result.store(false, Ordering::SeqCst); } /// Take pending function calls for a specific cascade. @@ -417,7 +412,6 @@ impl MitmStore { // 1. Exact cascade match if let Some(result) = pending.remove(cascade_id) { - self.has_active_function_call.store(false, Ordering::SeqCst); return Some(result); } @@ -425,7 +419,6 @@ impl MitmStore { if let Some(active) = self.active_cascade_id.read().await.as_ref() { if active != cascade_id { if let Some(result) = pending.remove(active.as_str()) { - self.has_active_function_call.store(false, Ordering::SeqCst); return Some(result); } } @@ -433,17 +426,12 @@ impl MitmStore { // 3. Fallback to _latest if let Some(result) = pending.remove("_latest") { - self.has_active_function_call.store(false, Ordering::SeqCst); return Some(result); } // 4. Last resort: any key if let Some(key) = pending.keys().next().cloned() { - let result = pending.remove(&key); - if result.is_some() { - self.has_active_function_call.store(false, Ordering::SeqCst); - } - return result; + return pending.remove(&key); } None @@ -455,19 +443,40 @@ impl MitmStore { let mut pending = self.pending_function_calls.write().await; let result = pending.remove("_latest"); if result.is_some() { - self.has_active_function_call.store(false, Ordering::SeqCst); return result; } if let Some(key) = pending.keys().next().cloned() { - let result = pending.remove(&key); - if result.is_some() { - self.has_active_function_call.store(false, Ordering::SeqCst); - } - return result; + return pending.remove(&key); } None } + // ── Channel-based event pipeline ───────────────────────────────────── + + /// Install a channel sender for the current tool-path request. + /// The MITM proxy will send events through this channel. + pub async fn set_channel(&self, tx: mpsc::Sender) { + *self.active_channel.write().await = Some(tx); + // NOTE: Do NOT set request_in_flight here. The MITM proxy's + // try_mark_request_in_flight() is the sole setter — setting it + // here causes compare_exchange(false,true) to always fail, + // blocking every real LS request. + } + + /// Take the active channel sender (used by MITM proxy to grab it). + /// Returns None if no channel is active. + pub async fn take_channel(&self) -> Option> { + self.active_channel.write().await.take() + } + + + /// Drop the active channel and clear in-flight state. + /// Called when the API handler is done with the current request. + pub async fn drop_channel(&self) { + *self.active_channel.write().await = None; + self.request_in_flight.store(false, Ordering::SeqCst); + } + // ── Tool context methods ───────────────────────────────────────────── /// Set active tool definitions (already in Gemini format). @@ -546,10 +555,10 @@ impl MitmStore { self.tool_rounds.read().await.clone() } - - // ── Direct response capture (bypass LS) ────────────────────────────── + // ── Legacy direct response capture (search.rs fallback) ────────────── /// Set (replace) the captured response text. + /// Used by MITM proxy for non-channel path (search endpoint fallback). pub async fn set_response_text(&self, text: &str) { *self.captured_response_text.write().await = Some(text.to_string()); } @@ -559,28 +568,12 @@ impl MitmStore { self.captured_response_text.write().await.take() } - /// Peek at the captured response text without consuming it. - pub async fn peek_response_text(&self) -> Option { - self.captured_response_text.read().await.clone() - } - - /// Mark the response as complete. - pub fn mark_response_complete(&self) { - self.response_complete.store(true, Ordering::SeqCst); - } - - /// Check if the response is complete. - pub fn is_response_complete(&self) -> bool { - self.response_complete.load(Ordering::SeqCst) - } - - /// Async version of clear_response. Bumps generation counter. + /// Clear stale state between requests. + /// Drops any active channel and clears in-flight flags. pub async fn clear_response_async(&self) { - self.response_complete.store(false, Ordering::SeqCst); self.request_in_flight.store(false, Ordering::SeqCst); - self.request_generation.fetch_add(1, Ordering::SeqCst); + *self.active_channel.write().await = None; *self.captured_response_text.write().await = None; - *self.captured_thinking_text.write().await = None; } /// Atomically try to mark request as in-flight. @@ -593,6 +586,7 @@ impl MitmStore { } /// Check if a request is currently in-flight. + #[allow(dead_code)] pub fn is_request_in_flight(&self) -> bool { self.request_in_flight.load(Ordering::SeqCst) } @@ -602,38 +596,6 @@ impl MitmStore { self.request_in_flight.store(false, Ordering::SeqCst); } - /// Reset response_complete so we can wait for the next response. - pub fn clear_response_complete(&self) { - self.response_complete.store(false, Ordering::SeqCst); - } - - /// Get current generation number. - pub fn current_generation(&self) -> u64 { - self.request_generation.load(Ordering::SeqCst) - } - - /// Bump generation counter (invalidates all pending data from old generation). - pub fn bump_generation(&self) -> u64 { - self.request_generation.fetch_add(1, Ordering::SeqCst) + 1 - } - - // ── Thinking text capture ──────────────────────────────────────────── - - /// Set (replace) the captured thinking text. - pub async fn set_thinking_text(&self, text: &str) { - *self.captured_thinking_text.write().await = Some(text.to_string()); - } - - /// Peek at the captured thinking text without consuming it. - pub async fn peek_thinking_text(&self) -> Option { - self.captured_thinking_text.read().await.clone() - } - - /// Take the captured thinking text (consumes it). - pub async fn take_thinking_text(&self) -> Option { - self.captured_thinking_text.write().await.take() - } - // ── Cascade correlation ────────────────────────────────────────────── /// Set the active cascade ID (called by API handlers before sending a message). @@ -718,4 +680,18 @@ impl MitmStore { pub async fn clear_upstream_error(&self) { *self.upstream_error.write().await = None; } + + // ── Pending user text for MITM injection ───────────────────────────── + + /// Store the real user text for MITM injection. + /// Called by API handlers before sending a dummy prompt to the LS. + pub async fn set_pending_user_text(&self, text: String) { + *self.pending_user_text.write().await = Some(text); + } + + /// Take (consume) the pending user text. + /// Called by the MITM proxy when building ToolContext. + pub async fn take_pending_user_text(&self) -> Option { + self.pending_user_text.write().await.take() + } }