From a8f3c8915f3828023680fe003d4d9c3b8cee52aa Mon Sep 17 00:00:00 2001 From: Nikketryhard Date: Mon, 16 Feb 2026 00:51:56 -0600 Subject: [PATCH] fix: block ALL LS follow-up requests, deduplicate function calls - Add request_in_flight flag to MitmStore, set immediately when first LLM request is forwarded with custom tools active - Block ALL subsequent LS requests (agentic loop + internal flash-lite) with fake SSE responses instead of waiting for response_complete - Fix function call deduplication: drain() accumulator after storing to prevent 3x duplicate tool calls across SSE chunks - Clear all stale state (response, thinking, function calls, errors) at the start of each streaming request - Handle response_complete with no content (thoughtSignature-only) gracefully with timeout instead of infinite hang --- src/api/completions.rs | 328 ++++++++++++++++++++--------------------- src/main.rs | 130 +++++++++------- src/mitm/modify.rs | 64 +++++++- src/mitm/proxy.rs | 166 +++++++++------------ src/mitm/store.rs | 33 +++++ src/warmup.rs | 24 ++- 6 files changed, 419 insertions(+), 326 deletions(-) diff --git a/src/api/completions.rs b/src/api/completions.rs index a085fcb..33a7c22 100644 --- a/src/api/completions.rs +++ b/src/api/completions.rs @@ -488,9 +488,12 @@ async fn chat_completions_stream( let mut last_text = String::new(); let has_custom_tools = state.mitm_store.get_tools().await.is_some(); - // Clear any stale captured response and upstream errors from previous requests + // Clear ALL stale state from previous requests state.mitm_store.clear_response_async().await; state.mitm_store.clear_upstream_error().await; + state.mitm_store.clear_active_function_call(); + // Drain any stale function calls from previous requests + let _ = state.mitm_store.take_any_function_calls().await; // Initial role chunk yield Ok::<_, std::convert::Infallible>(Event::default().data(chunk_json( @@ -501,6 +504,7 @@ async fn chat_completions_stream( let mut keepalive_counter: u64 = 0; let mut last_thinking_len: usize = 0; + let mut complete_polls: u32 = 0; // Helper: build usage JSON from MITM tokens let build_usage = |pt: u64, ct: u64, crt: u64, tt: u64| -> serde_json::Value { @@ -586,19 +590,13 @@ async fn chat_completions_stream( } } - // ── Check for MITM-captured response text (bypass LS) ── + // ── Primary: MITM-captured response (when custom tools are active) ── + // The MITM intercepts the real Google SSE stream and captures text, + // thinking, and function calls. This is the authoritative data source. + // The LS only gets rewritten responses (function calls → text placeholders) + // so it doesn't provide useful streaming data when MITM is active. if has_custom_tools { - let peek = state.mitm_store.peek_response_text().await; - let complete = state.mitm_store.is_response_complete(); - let has_fc = state.mitm_store.has_active_function_call(); - if keepalive_counter % 10 == 0 || peek.is_some() || complete || has_fc { - debug!( - "Completions bypass poll: peek={}, complete={}, has_fc={}, last_text_len={}", - peek.as_ref().map(|t| t.len()).unwrap_or(0), - complete, has_fc, last_text.len() - ); - } - // Stream thinking text as reasoning_content deltas (MITM bypass) + // Stream thinking text as reasoning_content deltas if let Some(tc) = state.mitm_store.peek_thinking_text().await { if tc.len() > last_thinking_len { let delta = &tc[last_thinking_len..]; @@ -612,7 +610,8 @@ async fn chat_completions_stream( } } - if let Some(text) = peek { + // Stream response text as content deltas + if let Some(text) = state.mitm_store.peek_response_text().await { if !text.is_empty() && text != last_text { let delta = if text.len() > last_text.len() && text.starts_with(&*last_text) { text[last_text.len()..].to_string() @@ -629,139 +628,18 @@ async fn chat_completions_stream( last_text = text; } } - - // Check if MITM response is complete - if state.mitm_store.is_response_complete() && !last_text.is_empty() { - debug!("Completions: MITM response complete (bypass), text length={}", last_text.len()); - // Take usage FIRST so we can read stop_reason for finish_reason - let mitm = state.mitm_store.take_usage(&cascade_id).await - .or(state.mitm_store.take_usage("_latest").await); - let fr = google_to_openai_finish_reason(mitm.as_ref().and_then(|u| u.stop_reason.as_deref())); - yield Ok(Event::default().data(chunk_json( - &completion_id, &model_name, - serde_json::json!([chunk_choice(0, serde_json::json!({}), Some(fr))]), - None, - ))); - if include_usage { - let (pt, ct, crt, tt) = if let Some(ref u) = mitm { - (u.input_tokens, u.output_tokens, u.cache_read_input_tokens, u.thinking_output_tokens) - } else { (0, 0, 0, 0) }; - yield Ok(Event::default().data(chunk_json( - &completion_id, &model_name, - serde_json::json!([]), - Some(build_usage(pt, ct, crt, tt)), - ))); - } - yield Ok(Event::default().data("[DONE]")); - return; - } - } else if complete { - // Response complete but no text — might be a tool call arriving shortly, - // stale state from a previous request, or an empty response. - tokio::time::sleep(tokio::time::Duration::from_millis(200)).await; - // Re-check function calls one more time - let final_check = state.mitm_store.take_any_function_calls().await; - if let Some(ref calls) = final_check { - if !calls.is_empty() { - let mut tool_calls = Vec::new(); - for (i, fc) in calls.iter().enumerate() { - let call_id = format!( - "call_{}", - uuid::Uuid::new_v4().to_string().replace('-', "")[..24].to_string() - ); - let arguments = serde_json::to_string(&fc.args).unwrap_or_default(); - tool_calls.push(serde_json::json!({ - "index": i, - "id": call_id, - "type": "function", - "function": { - "name": fc.name, - "arguments": arguments, - }, - })); - } - yield Ok(Event::default().data(chunk_json( - &completion_id, &model_name, - serde_json::json!([chunk_choice(0, serde_json::json!({"tool_calls": tool_calls}), None)]), - None, - ))); - yield Ok(Event::default().data(chunk_json( - &completion_id, &model_name, - serde_json::json!([chunk_choice(0, serde_json::json!({}), Some("tool_calls"))]), - None, - ))); - if include_usage { - let mitm = state.mitm_store.take_usage(&cascade_id).await - .or(state.mitm_store.take_usage("_latest").await); - let (pt, ct, crt, tt) = if let Some(ref u) = mitm { - (u.input_tokens, u.output_tokens, u.cache_read_input_tokens, u.thinking_output_tokens) - } else { (0, 0, 0, 0) }; - yield Ok(Event::default().data(chunk_json( - &completion_id, &model_name, - serde_json::json!([]), - Some(build_usage(pt, ct, crt, tt)), - ))); - } - yield Ok(Event::default().data("[DONE]")); - return; - } - } - // No text and no function calls but complete=true: stale state. - // Clear the flag so we wait for the real response from this request. - warn!("Completions: stale response_complete detected (no text, no FC) — clearing"); - state.mitm_store.clear_response_async().await; } - // When using bypass mode, skip LS step polling - keepalive_counter += 1; - if keepalive_counter % 10 == 0 { - yield Ok(Event::default().comment("keepalive")); - } - let poll_ms: u64 = rand::thread_rng().gen_range(200..350); - tokio::time::sleep(tokio::time::Duration::from_millis(poll_ms)).await; - continue; - } - - // ── Check LS steps for text streaming ── - if let Ok((status, data)) = state.backend.get_steps(&cascade_id).await { - if status == 200 { - if let Some(steps) = data["steps"].as_array() { - // Stream thinking deltas (reasoning_content) - if let Some(tc) = extract_thinking_content(steps) { - if tc.len() > last_thinking_len { - let delta = &tc[last_thinking_len..]; - last_thinking_len = tc.len(); - - yield Ok(Event::default().data(chunk_json( - &completion_id, &model_name, - serde_json::json!([chunk_choice(0, serde_json::json!({"reasoning_content": delta}), None)]), - None, - ))); - } - } - - let text = extract_response_text(steps); - - if !text.is_empty() && text != last_text { - let delta = if text.len() > last_text.len() && text.starts_with(&*last_text) { - &text[last_text.len()..] - } else { - &text - }; - - if !delta.is_empty() { - yield Ok(Event::default().data(chunk_json( - &completion_id, &model_name, - serde_json::json!([chunk_choice(0, serde_json::json!({"content": delta}), None)]), - None, - ))); - last_text = text.to_string(); - } - } - - // Done check: need DONE status AND non-empty text - if is_response_done(steps) && !last_text.is_empty() { - debug!("Completions stream done, text length={}", last_text.len()); + // Check if MITM response is complete + // Must have ACTUAL content (response text or function calls) — not just thinking. + // The LS makes multiple API calls and response_complete flips on each one, + // so we wait for it to be stable across 2+ polls with real content. + if state.mitm_store.is_response_complete() { + if !last_text.is_empty() { + // Have actual response text — done + complete_polls += 1; + if complete_polls >= 2 { + debug!("Completions: MITM response complete, text_len={}, thinking_len={}", last_text.len(), last_thinking_len); let mitm = state.mitm_store.take_usage(&cascade_id).await .or(state.mitm_store.take_usage("_latest").await); let fr = google_to_openai_finish_reason(mitm.as_ref().and_then(|u| u.stop_reason.as_deref())); @@ -783,35 +661,147 @@ async fn chat_completions_stream( yield Ok(Event::default().data("[DONE]")); return; } + } else if last_thinking_len > 0 { + // Only thinking so far — wait for actual text/tools to arrive + // The LS may still be processing and will make follow-up API calls + complete_polls += 1; + if complete_polls >= 6 { + // Waited ~2s with no text/tools after complete — emit what we have + debug!("Completions: MITM thinking-only timeout, thinking_len={}", last_thinking_len); + let mitm = state.mitm_store.take_usage(&cascade_id).await + .or(state.mitm_store.take_usage("_latest").await); + let fr = google_to_openai_finish_reason(mitm.as_ref().and_then(|u| u.stop_reason.as_deref())); + yield Ok(Event::default().data(chunk_json( + &completion_id, &model_name, + serde_json::json!([chunk_choice(0, serde_json::json!({}), Some(fr))]), + None, + ))); + if include_usage { + let (pt, ct, crt, tt) = if let Some(ref u) = mitm { + (u.input_tokens, u.output_tokens, u.cache_read_input_tokens, u.thinking_output_tokens) + } else { (0, 0, 0, 0) }; + yield Ok(Event::default().data(chunk_json( + &completion_id, &model_name, + serde_json::json!([]), + Some(build_usage(pt, ct, crt, tt)), + ))); + } + yield Ok(Event::default().data("[DONE]")); + return; + } + } else { + // response_complete but no text AND no thinking — might be + // a function-call-only response that was already consumed, + // or empty response. Wait a bit then give up. + complete_polls += 1; + if complete_polls >= 4 { + info!("Completions: MITM response complete but no content (text/thinking/tools all empty), ending stream"); + yield Ok(Event::default().data(chunk_json( + &completion_id, &model_name, + serde_json::json!([chunk_choice(0, serde_json::json!({}), Some("stop"))]), + None, + ))); + yield Ok(Event::default().data("[DONE]")); + return; + } + } + } else { + complete_polls = 0; // Reset — not complete yet + } + } else { + // ── Fallback: LS steps (no MITM capture active) ── + if let Ok((status, data)) = state.backend.get_steps(&cascade_id).await { + if status == 200 { + if let Some(steps) = data["steps"].as_array() { + // Stream thinking deltas (reasoning_content) + if let Some(tc) = extract_thinking_content(steps) { + if tc.len() > last_thinking_len { + let delta = &tc[last_thinking_len..]; + last_thinking_len = tc.len(); - // IDLE fallback - let step_count = steps.len(); - if step_count > 4 && step_count % 5 == 0 { - if let Ok((ts, td)) = state.backend.get_trajectory(&cascade_id).await { - if ts == 200 { - let run_status = td["status"].as_str().unwrap_or(""); - if run_status.contains("IDLE") && !last_text.is_empty() { - debug!("Completions IDLE, text length={}", last_text.len()); - let mitm = state.mitm_store.take_usage(&cascade_id).await - .or(state.mitm_store.take_usage("_latest").await); - let fr = google_to_openai_finish_reason(mitm.as_ref().and_then(|u| u.stop_reason.as_deref())); - yield Ok(Event::default().data(chunk_json( - &completion_id, &model_name, - serde_json::json!([chunk_choice(0, serde_json::json!({}), Some(fr))]), - None, - ))); - if include_usage { - let (pt, ct, crt, tt) = if let Some(ref u) = mitm { - (u.input_tokens, u.output_tokens, u.cache_read_input_tokens, u.thinking_output_tokens) - } else { (0, 0, 0, 0) }; + yield Ok(Event::default().data(chunk_json( + &completion_id, &model_name, + serde_json::json!([chunk_choice(0, serde_json::json!({"reasoning_content": delta}), None)]), + None, + ))); + } + } + + let text = extract_response_text(steps); + + if !text.is_empty() && text != last_text { + let delta = if text.len() > last_text.len() && text.starts_with(&*last_text) { + &text[last_text.len()..] + } else { + &text + }; + + if !delta.is_empty() { + yield Ok(Event::default().data(chunk_json( + &completion_id, &model_name, + serde_json::json!([chunk_choice(0, serde_json::json!({"content": delta}), None)]), + None, + ))); + last_text = text.to_string(); + } + } + + // Done check + let has_content = !last_text.is_empty() || last_thinking_len > 0; + if is_response_done(steps) && has_content { + debug!("Completions stream done, text length={}, thinking_len={}", last_text.len(), last_thinking_len); + let mitm = state.mitm_store.take_usage(&cascade_id).await + .or(state.mitm_store.take_usage("_latest").await); + let fr = google_to_openai_finish_reason(mitm.as_ref().and_then(|u| u.stop_reason.as_deref())); + yield Ok(Event::default().data(chunk_json( + &completion_id, &model_name, + serde_json::json!([chunk_choice(0, serde_json::json!({}), Some(fr))]), + None, + ))); + if include_usage { + let (pt, ct, crt, tt) = if let Some(ref u) = mitm { + (u.input_tokens, u.output_tokens, u.cache_read_input_tokens, u.thinking_output_tokens) + } else { (0, 0, 0, 0) }; + yield Ok(Event::default().data(chunk_json( + &completion_id, &model_name, + serde_json::json!([]), + Some(build_usage(pt, ct, crt, tt)), + ))); + } + yield Ok(Event::default().data("[DONE]")); + return; + } + + // IDLE fallback + let step_count = steps.len(); + if step_count > 4 && step_count % 5 == 0 { + if let Ok((ts, td)) = state.backend.get_trajectory(&cascade_id).await { + if ts == 200 { + let run_status = td["status"].as_str().unwrap_or(""); + let has_content_idle = !last_text.is_empty() || last_thinking_len > 0; + if run_status.contains("IDLE") && has_content_idle { + debug!("Completions IDLE, text length={}, thinking_len={}", last_text.len(), last_thinking_len); + let mitm = state.mitm_store.take_usage(&cascade_id).await + .or(state.mitm_store.take_usage("_latest").await); + let fr = google_to_openai_finish_reason(mitm.as_ref().and_then(|u| u.stop_reason.as_deref())); yield Ok(Event::default().data(chunk_json( &completion_id, &model_name, - serde_json::json!([]), - Some(build_usage(pt, ct, crt, tt)), + serde_json::json!([chunk_choice(0, serde_json::json!({}), Some(fr))]), + None, ))); + if include_usage { + let (pt, ct, crt, tt) = if let Some(ref u) = mitm { + (u.input_tokens, u.output_tokens, u.cache_read_input_tokens, u.thinking_output_tokens) + } else { (0, 0, 0, 0) }; + yield Ok(Event::default().data(chunk_json( + &completion_id, &model_name, + serde_json::json!([]), + Some(build_usage(pt, ct, crt, tt)), + ))); + } + yield Ok(Event::default().data("[DONE]")); + return; } - yield Ok(Event::default().data("[DONE]")); - return; } } } diff --git a/src/main.rs b/src/main.rs index bcd8234..7f0e580 100644 --- a/src/main.rs +++ b/src/main.rs @@ -50,10 +50,15 @@ struct Cli { #[arg(long)] no_standalone: bool, - /// Headless mode — no running Antigravity app required. + /// Headless mode (DEFAULT) — no running Antigravity app required. /// Generates its own CSRF, disables extension server, uses HTTPS_PROXY for MITM. - #[arg(long)] + /// Use --no-headless or --classic to attach to a running Antigravity instance instead. + #[arg(long, default_value_t = true)] headless: bool, + + /// Classic mode — requires a running Antigravity app. Alias for --no-headless. + #[arg(long, conflicts_with = "headless")] + classic: bool, } #[tokio::main] @@ -75,6 +80,7 @@ async fn main() { let _ = rustls::crypto::ring::default_provider().install_default(); let cli = Cli::parse(); + let headless = cli.headless && !cli.classic; // Flag > env var > default (warn) let log_level = if cli.debug { @@ -97,22 +103,76 @@ async fn main() { .with_env_filter(filter) .init(); - // ── Step 1: Bind main port FIRST (fail fast, before spawning anything) ──── + // ── Step 1: Bind main port (auto-kill stale process if needed) ───────────── let addr = format!("127.0.0.1:{}", cli.port); let listener = match tokio::net::TcpListener::bind(&addr).await { Ok(l) => l, - Err(e) => { - eprintln!("Fatal: cannot bind to {addr}: {e}"); - eprintln!("Hint: kill $(lsof -ti:{}) 2>/dev/null", cli.port); - std::process::exit(1); + Err(_) => { + // Port in use — try to kill whatever's holding it + eprintln!(" Port {} in use, killing stale process...", cli.port); + let _ = std::process::Command::new("sh") + .args(["-c", &format!("kill $(lsof -ti:{}) 2>/dev/null; sleep 0.3", cli.port)]) + .status(); + // Also kill any leftover standalone LS processes + let _ = std::process::Command::new("pkill") + .args(["-f", "language_server_linux.*antigravity-standalone"]) + .status(); + // Retry once + match tokio::net::TcpListener::bind(&addr).await { + Ok(l) => l, + Err(e) => { + eprintln!("Fatal: cannot bind to {addr} even after kill: {e}"); + std::process::exit(1); + } + } } }; - // ── Step 2: Backend discovery (or standalone LS spawn) ───────────────────── + // ── Step 2: MITM proxy (must be running BEFORE LS spawn in headless) ────── + // In headless mode, the LS is configured with HTTPS_PROXY pointing at our + // MITM proxy, so it must be listening before the LS tries to connect. + let mitm_store = MitmStore::new(); + let (mitm_port_actual, mitm_handle) = if !cli.no_mitm { + let data_dir = dirs_data_dir(); + match mitm::ca::MitmCa::load_or_generate(&data_dir) { + Ok(ca) => { + let ca = Arc::new(ca); + let ca_pem = ca.ca_pem_path.display().to_string(); + let config = mitm::proxy::MitmConfig { + port: cli.mitm_port, + modify_requests: true, + }; + match mitm::proxy::run(ca, mitm_store.clone(), config).await { + Ok((port, handle)) => { + info!(port, ca = %ca_pem, "MITM proxy started"); + // Write actual port to file for wrapper script discovery + let port_file = data_dir.join("mitm-port"); + if let Err(e) = std::fs::write(&port_file, port.to_string()) { + warn!("Failed to write MITM port file: {e}"); + } + (Some((port, ca_pem)), Some(handle)) + } + Err(e) => { + warn!("MITM proxy failed to start: {e}"); + (None, None) + } + } + } + Err(e) => { + warn!("MITM CA generation failed: {e}"); + (None, None) + } + } + } else { + info!("MITM proxy disabled (--no-mitm)"); + (None, None) + }; + + // ── Step 3: Backend discovery (or standalone LS spawn) ───────────────────── // --headless implies standalone mode - let standalone_ls = if cli.headless || !cli.no_standalone { + let standalone_ls = if headless || !cli.no_standalone { // Get LS config: headless generates its own, normal steals from running LS - let main_config = if cli.headless { + let main_config = if headless { info!("Headless mode: generating self-contained config"); standalone::generate_standalone_config() } else { @@ -120,26 +180,26 @@ async fn main() { Ok(c) => c, Err(e) => { eprintln!("Fatal: {e}"); - eprintln!("Hint: start Antigravity first, or use --headless for full independence"); + eprintln!("Hint: start Antigravity first, or remove --classic to use headless mode"); std::process::exit(1); } } }; - // Build MITM config if MITM is enabled - let mitm_cfg = if !cli.no_mitm { + // Build MITM config using the actual MITM port (not just the CLI default) + let mitm_cfg = if let Some((mitm_port, _)) = &mitm_port_actual { let ca_path = dirs_data_dir() .join("mitm-ca.pem") .to_string_lossy() .to_string(); Some(standalone::StandaloneMitmConfig { - proxy_addr: format!("http://127.0.0.1:{}", cli.mitm_port), + proxy_addr: format!("http://127.0.0.1:{}", mitm_port), ca_cert_path: ca_path, }) } else { None }; - let mut ls = match standalone::StandaloneLS::spawn(&main_config, mitm_cfg.as_ref(), cli.headless) { + let mut ls = match standalone::StandaloneLS::spawn(&main_config, mitm_cfg.as_ref(), headless) { Ok(ls) => ls, Err(e) => { eprintln!("Fatal: failed to spawn standalone LS: {e}"); @@ -197,46 +257,8 @@ async fn main() { let (pid, https_port, csrf, token) = backend.info().await; - // ── Step 3: MITM proxy (after port is secured) ──────────────────────────── - let mitm_store = MitmStore::new(); - let (mitm_port_actual, mitm_handle) = if !cli.no_mitm { - let data_dir = dirs_data_dir(); - match mitm::ca::MitmCa::load_or_generate(&data_dir) { - Ok(ca) => { - let ca = Arc::new(ca); - let ca_pem = ca.ca_pem_path.display().to_string(); - let config = mitm::proxy::MitmConfig { - port: cli.mitm_port, - modify_requests: true, - }; - match mitm::proxy::run(ca, mitm_store.clone(), config).await { - Ok((port, handle)) => { - info!(port, ca = %ca_pem, "MITM proxy started"); - // Write actual port to file for wrapper script discovery - let port_file = data_dir.join("mitm-port"); - if let Err(e) = std::fs::write(&port_file, port.to_string()) { - warn!("Failed to write MITM port file: {e}"); - } - (Some((port, ca_pem)), Some(handle)) - } - Err(e) => { - warn!("MITM proxy failed to start: {e}"); - (None, None) - } - } - } - Err(e) => { - warn!("MITM CA generation failed: {e}"); - (None, None) - } - } - } else { - info!("MITM proxy disabled (--no-mitm)"); - (None, None) - }; - // ── Step 4: Warmup + heartbeat ──────────────────────────────────────────── - warmup::warmup_sequence(&backend).await; + warmup::warmup_sequence(&backend, headless).await; let heartbeat_handle = warmup::start_heartbeat(Arc::clone(&backend)); // ── Step 4b: Quota monitor ──────────────────────────────────────────────── diff --git a/src/mitm/modify.rs b/src/mitm/modify.rs index 67f525d..a159136 100644 --- a/src/mitm/modify.rs +++ b/src/mitm/modify.rs @@ -202,11 +202,35 @@ pub fn modify_request(body: &[u8], tool_ctx: Option<&ToolContext>) -> Option) -> Option = tool_ctx + .as_ref() + .and_then(|ctx| ctx.tools.as_ref()) + .map(|tools| { + tools.iter() + .filter_map(|t| t["functionDeclarations"].as_array()) + .flatten() + .filter_map(|decl| decl["name"].as_str().map(|s| s.to_string())) + .collect() + }) + .unwrap_or_default(); + if let Some(contents) = json .pointer_mut("/request/contents") .and_then(|v| v.as_array_mut()) @@ -244,8 +283,21 @@ pub fn modify_request(body: &[u8], tool_ctx: Option<&ToolContext>) -> Option = streaming_acc.function_calls.drain(..).collect(); + for fc in &calls { store.record_function_call(cascade_hint.as_deref(), fc.clone()).await; } - store.set_last_function_calls(streaming_acc.function_calls.clone()).await; - info!("MITM: stored {} function call(s) from initial body", streaming_acc.function_calls.len()); + store.set_last_function_calls(calls.clone()).await; + info!("MITM: stored {} function call(s) from initial body", calls.len()); } - // Capture response + thinking text + grounding directly into MitmStore - if bypass_ls { - if !streaming_acc.response_text.is_empty() { - store.set_response_text(&streaming_acc.response_text).await; - } - if !streaming_acc.thinking_text.is_empty() { - store.set_thinking_text(&streaming_acc.thinking_text).await; - } - if let Some(ref gm) = streaming_acc.grounding_metadata { - store.set_grounding(gm.clone()).await; - } - if streaming_acc.is_complete { - store.mark_response_complete(); - } + // Capture response + thinking text + grounding into MitmStore + if !streaming_acc.response_text.is_empty() { + store.set_response_text(&streaming_acc.response_text).await; + } + if !streaming_acc.thinking_text.is_empty() { + store.set_thinking_text(&streaming_acc.thinking_text).await; + } + if let Some(ref gm) = streaming_acc.grounding_metadata { + store.set_grounding(gm.clone()).await; + } + if streaming_acc.is_complete { + store.mark_response_complete(); } } - if bypass_ls { - if has_function_call { - info!("MITM: functionCall captured → NOT forwarding to LS (bypass mode)"); - store.mark_response_complete(); - break; + // Forward to client — rewrite function calls if custom tools are injected + let forward_buf = if modify_requests { + if let Some(modified) = super::modify::modify_response_chunk(&header_buf) { + modified + } else { + header_buf.clone() } - // Don't forward to LS — just continue reading chunks - // Send headers only so upstream doesn't close - if let Some(cl) = response_content_length { - if response_body_buf.len() >= cl { - store.mark_response_complete(); - break; - } - } - if is_chunked && has_chunked_terminator(&response_body_buf) { - store.mark_response_complete(); - break; - } - continue; - } - - // Normal path (no custom tools): forward headers+body as-is - if let Err(e) = client.write_all(&header_buf).await { + } else { + header_buf.clone() + }; + if let Err(e) = client.write_all(&forward_buf).await { warn!(error = %e, "MITM: write to client failed"); break; } @@ -838,63 +829,46 @@ async fn handle_http_over_tls( } // ── Response body interception ──────────────────────────────── - let mut chunk_has_fc = false; - let bypass_ls = modify_requests && store.get_tools().await.is_some(); - if is_streaming_response { let s = String::from_utf8_lossy(chunk); parse_streaming_chunk(&s, &mut streaming_acc); - chunk_has_fc = !streaming_acc.function_calls.is_empty(); - // Immediately store captured function calls - if chunk_has_fc { - for fc in &streaming_acc.function_calls { + // Store captured function calls (drain to avoid re-storing on next chunk) + if !streaming_acc.function_calls.is_empty() { + let calls: Vec<_> = streaming_acc.function_calls.drain(..).collect(); + for fc in &calls { store.record_function_call(cascade_hint.as_deref(), fc.clone()).await; } - store.set_last_function_calls(streaming_acc.function_calls.clone()).await; - info!("MITM: stored {} function call(s) from body chunk", streaming_acc.function_calls.len()); + store.set_last_function_calls(calls.clone()).await; + info!("MITM: stored {} function call(s) from body chunk", calls.len()); } - // Capture response + thinking text + grounding directly into MitmStore - if bypass_ls { - if !streaming_acc.response_text.is_empty() { - store.set_response_text(&streaming_acc.response_text).await; - } - if !streaming_acc.thinking_text.is_empty() { - store.set_thinking_text(&streaming_acc.thinking_text).await; - } - if let Some(ref gm) = streaming_acc.grounding_metadata { - store.set_grounding(gm.clone()).await; - } - if streaming_acc.is_complete { - store.mark_response_complete(); - } + // Capture response + thinking text + grounding into MitmStore + if !streaming_acc.response_text.is_empty() { + store.set_response_text(&streaming_acc.response_text).await; + } + if !streaming_acc.thinking_text.is_empty() { + store.set_thinking_text(&streaming_acc.thinking_text).await; + } + if let Some(ref gm) = streaming_acc.grounding_metadata { + store.set_grounding(gm.clone()).await; + } + if streaming_acc.is_complete { + store.mark_response_complete(); } } - if bypass_ls { - if chunk_has_fc || streaming_acc.is_complete { - info!("MITM: response captured → NOT forwarding to LS (bypass mode)"); - store.mark_response_complete(); - break; + // Forward chunk to client (LS) — rewrite function calls if custom tools + let forward_chunk = if modify_requests { + if let Some(modified) = super::modify::modify_response_chunk(chunk) { + modified + } else { + chunk.to_vec() } - // Keep reading chunks without forwarding to LS - response_body_buf.extend_from_slice(chunk); - if let Some(cl) = response_content_length { - if response_body_buf.len() >= cl { - store.mark_response_complete(); - break; - } - } - if is_chunked && has_chunked_terminator(&response_body_buf) { - store.mark_response_complete(); - break; - } - continue; - } - - // Normal path: forward chunk to client (LS) - if let Err(e) = client.write_all(chunk).await { + } else { + chunk.to_vec() + }; + if let Err(e) = client.write_all(&forward_chunk).await { warn!(error = %e, "MITM: write to client failed"); break; } diff --git a/src/mitm/store.rs b/src/mitm/store.rs index cce9d5a..c30aa9c 100644 --- a/src/mitm/store.rs +++ b/src/mitm/store.rs @@ -130,6 +130,13 @@ pub struct MitmStore { /// Simple flag: set when a functionCall is captured, cleared when consumed. /// Used to block follow-up requests regardless of cascade identification. has_active_function_call: Arc, + /// Persistent flag: set when a function call is captured, cleared ONLY when + /// a tool result is submitted. Prevents the LS from making follow-up API + /// calls during the entire tool execution cycle. + awaiting_tool_result: Arc, + /// Set when the MITM forwards the first LLM request with custom tools. + /// Blocks ALL subsequent LS requests until the API handler clears it. + request_in_flight: Arc, // ── Tool call support ──────────────────────────────────────────────── /// Active tool definitions (Gemini format) for MITM injection. @@ -205,6 +212,8 @@ impl MitmStore { stats: Arc::new(RwLock::new(MitmStats::default())), pending_function_calls: Arc::new(RwLock::new(HashMap::new())), has_active_function_call: Arc::new(AtomicBool::new(false)), + awaiting_tool_result: Arc::new(AtomicBool::new(false)), + request_in_flight: Arc::new(AtomicBool::new(false)), active_tools: Arc::new(RwLock::new(None)), active_tool_config: Arc::new(RwLock::new(None)), pending_tool_results: Arc::new(RwLock::new(Vec::new())), @@ -343,6 +352,7 @@ impl MitmStore { let mut pending = self.pending_function_calls.write().await; pending.entry(key).or_default().push(fc); self.has_active_function_call.store(true, Ordering::SeqCst); + self.awaiting_tool_result.store(true, Ordering::SeqCst); } /// Check if there's an active (unclaimed) function call. @@ -355,6 +365,18 @@ impl MitmStore { self.has_active_function_call.store(false, Ordering::SeqCst); } + /// Check if we're awaiting a tool result (blocks LS follow-up requests). + /// This persists across function call consumption — only cleared when + /// actual tool results are submitted. + pub fn is_awaiting_tool_result(&self) -> bool { + self.awaiting_tool_result.load(Ordering::SeqCst) + } + + /// Clear the awaiting-tool-result flag (called when tool results arrive). + pub fn clear_awaiting_tool_result(&self) { + self.awaiting_tool_result.store(false, Ordering::SeqCst); + } + /// Take any pending function calls (ignoring cascade ID). pub async fn take_any_function_calls(&self) -> Option> { @@ -467,10 +489,21 @@ impl MitmStore { /// Async version of clear_response. pub async fn clear_response_async(&self) { self.response_complete.store(false, Ordering::SeqCst); + self.request_in_flight.store(false, Ordering::SeqCst); *self.captured_response_text.write().await = None; *self.captured_thinking_text.write().await = None; } + /// Mark the request as in-flight (first LLM request forwarded). + pub fn mark_request_in_flight(&self) { + self.request_in_flight.store(true, Ordering::SeqCst); + } + + /// Check if a request is currently in-flight. + pub fn is_request_in_flight(&self) -> bool { + self.request_in_flight.load(Ordering::SeqCst) + } + // ── Thinking text capture ──────────────────────────────────────────── /// Set (replace) the captured thinking text. diff --git a/src/warmup.rs b/src/warmup.rs index c01bf32..a447b95 100644 --- a/src/warmup.rs +++ b/src/warmup.rs @@ -15,9 +15,31 @@ use tracing::{debug, info, warn}; /// /// Called BEFORE accepting any API requests. Each call is fire-and-forget /// (we don't care if some fail — the LS might not support all methods). -pub async fn warmup_sequence(backend: &Backend) { +pub async fn warmup_sequence(backend: &Backend, headless: bool) { info!("Running webview warmup sequence..."); + // ── CRITICAL: Set detect_and_use_proxy BEFORE any API-triggering call ── + // The LS creates its HTTP transport lazily on the first API call. + // If we set ENABLED before that, the LS will honor HTTPS_PROXY env var, + // routing API traffic through the MITM proxy without iptables/sudo. + if headless { + let settings_body = serde_json::json!({ + "userSettings": { + "detectAndUseProxy": 1 // DETECT_AND_USE_PROXY_ENABLED + } + }); + match tokio::time::timeout( + Duration::from_secs(5), + backend.call_json("SetUserSettings", &settings_body), + ) + .await + { + Ok(Ok((status, _))) => info!("SetUserSettings (detect_and_use_proxy=ENABLED): {status}"), + Ok(Err(e)) => warn!("SetUserSettings failed: {e}"), + Err(_) => warn!("SetUserSettings timed out"), + } + } + let calls: &[(&str, serde_json::Value)] = &[ ("GetStatus", serde_json::json!({})), ("Heartbeat", serde_json::json!({})),