fix: block ALL LS follow-up requests, deduplicate function calls

- Add a request_in_flight flag to MitmStore, set as soon as the first
  LLM request is forwarded while custom tools are active
- Block ALL subsequent LS requests (agentic loop + internal flash-lite)
  with fake SSE responses instead of waiting for response_complete
- Fix function call deduplication: drain() the accumulator after storing
  its calls to prevent 3x duplicate tool calls across SSE chunks
- Clear all stale state (response, thinking, function calls, errors)
  at the start of each streaming request
- Handle response_complete with no content (thoughtSignature-only)
  gracefully: end the stream after a bounded wait instead of hanging forever
This commit is contained in:
Nikketryhard
2026-02-16 00:51:56 -06:00
parent 5f40385c8d
commit a8f3c8915f
6 changed files with 419 additions and 326 deletions

View File

@@ -488,9 +488,12 @@ async fn chat_completions_stream(
let mut last_text = String::new(); let mut last_text = String::new();
let has_custom_tools = state.mitm_store.get_tools().await.is_some(); let has_custom_tools = state.mitm_store.get_tools().await.is_some();
// Clear any stale captured response and upstream errors from previous requests // Clear ALL stale state from previous requests
state.mitm_store.clear_response_async().await; state.mitm_store.clear_response_async().await;
state.mitm_store.clear_upstream_error().await; state.mitm_store.clear_upstream_error().await;
state.mitm_store.clear_active_function_call();
// Drain any stale function calls from previous requests
let _ = state.mitm_store.take_any_function_calls().await;
// Initial role chunk // Initial role chunk
yield Ok::<_, std::convert::Infallible>(Event::default().data(chunk_json( yield Ok::<_, std::convert::Infallible>(Event::default().data(chunk_json(
@@ -501,6 +504,7 @@ async fn chat_completions_stream(
let mut keepalive_counter: u64 = 0; let mut keepalive_counter: u64 = 0;
let mut last_thinking_len: usize = 0; let mut last_thinking_len: usize = 0;
let mut complete_polls: u32 = 0;
// Helper: build usage JSON from MITM tokens // Helper: build usage JSON from MITM tokens
let build_usage = |pt: u64, ct: u64, crt: u64, tt: u64| -> serde_json::Value { let build_usage = |pt: u64, ct: u64, crt: u64, tt: u64| -> serde_json::Value {
@@ -586,19 +590,13 @@ async fn chat_completions_stream(
} }
} }
// ── Check for MITM-captured response text (bypass LS) ── // ── Primary: MITM-captured response (when custom tools are active) ──
// The MITM intercepts the real Google SSE stream and captures text,
// thinking, and function calls. This is the authoritative data source.
// The LS only gets rewritten responses (function calls → text placeholders)
// so it doesn't provide useful streaming data when MITM is active.
if has_custom_tools { if has_custom_tools {
let peek = state.mitm_store.peek_response_text().await; // Stream thinking text as reasoning_content deltas
let complete = state.mitm_store.is_response_complete();
let has_fc = state.mitm_store.has_active_function_call();
if keepalive_counter % 10 == 0 || peek.is_some() || complete || has_fc {
debug!(
"Completions bypass poll: peek={}, complete={}, has_fc={}, last_text_len={}",
peek.as_ref().map(|t| t.len()).unwrap_or(0),
complete, has_fc, last_text.len()
);
}
// Stream thinking text as reasoning_content deltas (MITM bypass)
if let Some(tc) = state.mitm_store.peek_thinking_text().await { if let Some(tc) = state.mitm_store.peek_thinking_text().await {
if tc.len() > last_thinking_len { if tc.len() > last_thinking_len {
let delta = &tc[last_thinking_len..]; let delta = &tc[last_thinking_len..];
@@ -612,7 +610,8 @@ async fn chat_completions_stream(
} }
} }
if let Some(text) = peek { // Stream response text as content deltas
if let Some(text) = state.mitm_store.peek_response_text().await {
if !text.is_empty() && text != last_text { if !text.is_empty() && text != last_text {
let delta = if text.len() > last_text.len() && text.starts_with(&*last_text) { let delta = if text.len() > last_text.len() && text.starts_with(&*last_text) {
text[last_text.len()..].to_string() text[last_text.len()..].to_string()
@@ -629,11 +628,18 @@ async fn chat_completions_stream(
last_text = text; last_text = text;
} }
} }
}
// Check if MITM response is complete // Check if MITM response is complete
if state.mitm_store.is_response_complete() && !last_text.is_empty() { // Must have ACTUAL content (response text or function calls) — not just thinking.
debug!("Completions: MITM response complete (bypass), text length={}", last_text.len()); // The LS makes multiple API calls and response_complete flips on each one,
// Take usage FIRST so we can read stop_reason for finish_reason // so we wait for it to be stable across 2+ polls with real content.
if state.mitm_store.is_response_complete() {
if !last_text.is_empty() {
// Have actual response text — done
complete_polls += 1;
if complete_polls >= 2 {
debug!("Completions: MITM response complete, text_len={}, thinking_len={}", last_text.len(), last_thinking_len);
let mitm = state.mitm_store.take_usage(&cascade_id).await let mitm = state.mitm_store.take_usage(&cascade_id).await
.or(state.mitm_store.take_usage("_latest").await); .or(state.mitm_store.take_usage("_latest").await);
let fr = google_to_openai_finish_reason(mitm.as_ref().and_then(|u| u.stop_reason.as_deref())); let fr = google_to_openai_finish_reason(mitm.as_ref().and_then(|u| u.stop_reason.as_deref()));
@@ -655,44 +661,22 @@ async fn chat_completions_stream(
yield Ok(Event::default().data("[DONE]")); yield Ok(Event::default().data("[DONE]"));
return; return;
} }
} else if complete { } else if last_thinking_len > 0 {
// Response complete but no text — might be a tool call arriving shortly, // Only thinking so far — wait for actual text/tools to arrive
// stale state from a previous request, or an empty response. // The LS may still be processing and will make follow-up API calls
tokio::time::sleep(tokio::time::Duration::from_millis(200)).await; complete_polls += 1;
// Re-check function calls one more time if complete_polls >= 6 {
let final_check = state.mitm_store.take_any_function_calls().await; // Waited ~2s with no text/tools after complete — emit what we have
if let Some(ref calls) = final_check { debug!("Completions: MITM thinking-only timeout, thinking_len={}", last_thinking_len);
if !calls.is_empty() { let mitm = state.mitm_store.take_usage(&cascade_id).await
let mut tool_calls = Vec::new(); .or(state.mitm_store.take_usage("_latest").await);
for (i, fc) in calls.iter().enumerate() { let fr = google_to_openai_finish_reason(mitm.as_ref().and_then(|u| u.stop_reason.as_deref()));
let call_id = format!(
"call_{}",
uuid::Uuid::new_v4().to_string().replace('-', "")[..24].to_string()
);
let arguments = serde_json::to_string(&fc.args).unwrap_or_default();
tool_calls.push(serde_json::json!({
"index": i,
"id": call_id,
"type": "function",
"function": {
"name": fc.name,
"arguments": arguments,
},
}));
}
yield Ok(Event::default().data(chunk_json( yield Ok(Event::default().data(chunk_json(
&completion_id, &model_name, &completion_id, &model_name,
serde_json::json!([chunk_choice(0, serde_json::json!({"tool_calls": tool_calls}), None)]), serde_json::json!([chunk_choice(0, serde_json::json!({}), Some(fr))]),
None,
)));
yield Ok(Event::default().data(chunk_json(
&completion_id, &model_name,
serde_json::json!([chunk_choice(0, serde_json::json!({}), Some("tool_calls"))]),
None, None,
))); )));
if include_usage { if include_usage {
let mitm = state.mitm_store.take_usage(&cascade_id).await
.or(state.mitm_store.take_usage("_latest").await);
let (pt, ct, crt, tt) = if let Some(ref u) = mitm { let (pt, ct, crt, tt) = if let Some(ref u) = mitm {
(u.input_tokens, u.output_tokens, u.cache_read_input_tokens, u.thinking_output_tokens) (u.input_tokens, u.output_tokens, u.cache_read_input_tokens, u.thinking_output_tokens)
} else { (0, 0, 0, 0) }; } else { (0, 0, 0, 0) };
@@ -705,24 +689,27 @@ async fn chat_completions_stream(
yield Ok(Event::default().data("[DONE]")); yield Ok(Event::default().data("[DONE]"));
return; return;
} }
} else {
// response_complete but no text AND no thinking — might be
// a function-call-only response that was already consumed,
// or empty response. Wait a bit then give up.
complete_polls += 1;
if complete_polls >= 4 {
info!("Completions: MITM response complete but no content (text/thinking/tools all empty), ending stream");
yield Ok(Event::default().data(chunk_json(
&completion_id, &model_name,
serde_json::json!([chunk_choice(0, serde_json::json!({}), Some("stop"))]),
None,
)));
yield Ok(Event::default().data("[DONE]"));
return;
} }
// No text and no function calls but complete=true: stale state.
// Clear the flag so we wait for the real response from this request.
warn!("Completions: stale response_complete detected (no text, no FC) — clearing");
state.mitm_store.clear_response_async().await;
} }
} else {
// When using bypass mode, skip LS step polling complete_polls = 0; // Reset — not complete yet
keepalive_counter += 1;
if keepalive_counter % 10 == 0 {
yield Ok(Event::default().comment("keepalive"));
} }
let poll_ms: u64 = rand::thread_rng().gen_range(200..350); } else {
tokio::time::sleep(tokio::time::Duration::from_millis(poll_ms)).await; // ── Fallback: LS steps (no MITM capture active) ──
continue;
}
// ── Check LS steps for text streaming ──
if let Ok((status, data)) = state.backend.get_steps(&cascade_id).await { if let Ok((status, data)) = state.backend.get_steps(&cascade_id).await {
if status == 200 { if status == 200 {
if let Some(steps) = data["steps"].as_array() { if let Some(steps) = data["steps"].as_array() {
@@ -759,9 +746,10 @@ async fn chat_completions_stream(
} }
} }
// Done check: need DONE status AND non-empty text // Done check
if is_response_done(steps) && !last_text.is_empty() { let has_content = !last_text.is_empty() || last_thinking_len > 0;
debug!("Completions stream done, text length={}", last_text.len()); if is_response_done(steps) && has_content {
debug!("Completions stream done, text length={}, thinking_len={}", last_text.len(), last_thinking_len);
let mitm = state.mitm_store.take_usage(&cascade_id).await let mitm = state.mitm_store.take_usage(&cascade_id).await
.or(state.mitm_store.take_usage("_latest").await); .or(state.mitm_store.take_usage("_latest").await);
let fr = google_to_openai_finish_reason(mitm.as_ref().and_then(|u| u.stop_reason.as_deref())); let fr = google_to_openai_finish_reason(mitm.as_ref().and_then(|u| u.stop_reason.as_deref()));
@@ -790,8 +778,9 @@ async fn chat_completions_stream(
if let Ok((ts, td)) = state.backend.get_trajectory(&cascade_id).await { if let Ok((ts, td)) = state.backend.get_trajectory(&cascade_id).await {
if ts == 200 { if ts == 200 {
let run_status = td["status"].as_str().unwrap_or(""); let run_status = td["status"].as_str().unwrap_or("");
if run_status.contains("IDLE") && !last_text.is_empty() { let has_content_idle = !last_text.is_empty() || last_thinking_len > 0;
debug!("Completions IDLE, text length={}", last_text.len()); if run_status.contains("IDLE") && has_content_idle {
debug!("Completions IDLE, text length={}, thinking_len={}", last_text.len(), last_thinking_len);
let mitm = state.mitm_store.take_usage(&cascade_id).await let mitm = state.mitm_store.take_usage(&cascade_id).await
.or(state.mitm_store.take_usage("_latest").await); .or(state.mitm_store.take_usage("_latest").await);
let fr = google_to_openai_finish_reason(mitm.as_ref().and_then(|u| u.stop_reason.as_deref())); let fr = google_to_openai_finish_reason(mitm.as_ref().and_then(|u| u.stop_reason.as_deref()));
@@ -819,6 +808,7 @@ async fn chat_completions_stream(
} }
} }
} }
}
// Keep-alive comment every ~5 iterations // Keep-alive comment every ~5 iterations
keepalive_counter += 1; keepalive_counter += 1;

View File

@@ -50,10 +50,15 @@ struct Cli {
#[arg(long)] #[arg(long)]
no_standalone: bool, no_standalone: bool,
/// Headless mode — no running Antigravity app required. /// Headless mode (DEFAULT) — no running Antigravity app required.
/// Generates its own CSRF, disables extension server, uses HTTPS_PROXY for MITM. /// Generates its own CSRF, disables extension server, uses HTTPS_PROXY for MITM.
#[arg(long)] /// Use --no-headless or --classic to attach to a running Antigravity instance instead.
#[arg(long, default_value_t = true)]
headless: bool, headless: bool,
/// Classic mode — requires a running Antigravity app. Alias for --no-headless.
#[arg(long, conflicts_with = "headless")]
classic: bool,
} }
#[tokio::main] #[tokio::main]
@@ -75,6 +80,7 @@ async fn main() {
let _ = rustls::crypto::ring::default_provider().install_default(); let _ = rustls::crypto::ring::default_provider().install_default();
let cli = Cli::parse(); let cli = Cli::parse();
let headless = cli.headless && !cli.classic;
// Flag > env var > default (warn) // Flag > env var > default (warn)
let log_level = if cli.debug { let log_level = if cli.debug {
@@ -97,22 +103,76 @@ async fn main() {
.with_env_filter(filter) .with_env_filter(filter)
.init(); .init();
// ── Step 1: Bind main port FIRST (fail fast, before spawning anything) ──── // ── Step 1: Bind main port (auto-kill stale process if needed) ─────────────
let addr = format!("127.0.0.1:{}", cli.port); let addr = format!("127.0.0.1:{}", cli.port);
let listener = match tokio::net::TcpListener::bind(&addr).await { let listener = match tokio::net::TcpListener::bind(&addr).await {
Ok(l) => l,
Err(_) => {
// Port in use — try to kill whatever's holding it
eprintln!(" Port {} in use, killing stale process...", cli.port);
let _ = std::process::Command::new("sh")
.args(["-c", &format!("kill $(lsof -ti:{}) 2>/dev/null; sleep 0.3", cli.port)])
.status();
// Also kill any leftover standalone LS processes
let _ = std::process::Command::new("pkill")
.args(["-f", "language_server_linux.*antigravity-standalone"])
.status();
// Retry once
match tokio::net::TcpListener::bind(&addr).await {
Ok(l) => l, Ok(l) => l,
Err(e) => { Err(e) => {
eprintln!("Fatal: cannot bind to {addr}: {e}"); eprintln!("Fatal: cannot bind to {addr} even after kill: {e}");
eprintln!("Hint: kill $(lsof -ti:{}) 2>/dev/null", cli.port);
std::process::exit(1); std::process::exit(1);
} }
}
}
}; };
// ── Step 2: Backend discovery (or standalone LS spawn) ───────────────────── // ── Step 2: MITM proxy (must be running BEFORE LS spawn in headless) ──────
// In headless mode, the LS is configured with HTTPS_PROXY pointing at our
// MITM proxy, so it must be listening before the LS tries to connect.
let mitm_store = MitmStore::new();
let (mitm_port_actual, mitm_handle) = if !cli.no_mitm {
let data_dir = dirs_data_dir();
match mitm::ca::MitmCa::load_or_generate(&data_dir) {
Ok(ca) => {
let ca = Arc::new(ca);
let ca_pem = ca.ca_pem_path.display().to_string();
let config = mitm::proxy::MitmConfig {
port: cli.mitm_port,
modify_requests: true,
};
match mitm::proxy::run(ca, mitm_store.clone(), config).await {
Ok((port, handle)) => {
info!(port, ca = %ca_pem, "MITM proxy started");
// Write actual port to file for wrapper script discovery
let port_file = data_dir.join("mitm-port");
if let Err(e) = std::fs::write(&port_file, port.to_string()) {
warn!("Failed to write MITM port file: {e}");
}
(Some((port, ca_pem)), Some(handle))
}
Err(e) => {
warn!("MITM proxy failed to start: {e}");
(None, None)
}
}
}
Err(e) => {
warn!("MITM CA generation failed: {e}");
(None, None)
}
}
} else {
info!("MITM proxy disabled (--no-mitm)");
(None, None)
};
// ── Step 3: Backend discovery (or standalone LS spawn) ─────────────────────
// --headless implies standalone mode // --headless implies standalone mode
let standalone_ls = if cli.headless || !cli.no_standalone { let standalone_ls = if headless || !cli.no_standalone {
// Get LS config: headless generates its own, normal steals from running LS // Get LS config: headless generates its own, normal steals from running LS
let main_config = if cli.headless { let main_config = if headless {
info!("Headless mode: generating self-contained config"); info!("Headless mode: generating self-contained config");
standalone::generate_standalone_config() standalone::generate_standalone_config()
} else { } else {
@@ -120,26 +180,26 @@ async fn main() {
Ok(c) => c, Ok(c) => c,
Err(e) => { Err(e) => {
eprintln!("Fatal: {e}"); eprintln!("Fatal: {e}");
eprintln!("Hint: start Antigravity first, or use --headless for full independence"); eprintln!("Hint: start Antigravity first, or remove --classic to use headless mode");
std::process::exit(1); std::process::exit(1);
} }
} }
}; };
// Build MITM config if MITM is enabled // Build MITM config using the actual MITM port (not just the CLI default)
let mitm_cfg = if !cli.no_mitm { let mitm_cfg = if let Some((mitm_port, _)) = &mitm_port_actual {
let ca_path = dirs_data_dir() let ca_path = dirs_data_dir()
.join("mitm-ca.pem") .join("mitm-ca.pem")
.to_string_lossy() .to_string_lossy()
.to_string(); .to_string();
Some(standalone::StandaloneMitmConfig { Some(standalone::StandaloneMitmConfig {
proxy_addr: format!("http://127.0.0.1:{}", cli.mitm_port), proxy_addr: format!("http://127.0.0.1:{}", mitm_port),
ca_cert_path: ca_path, ca_cert_path: ca_path,
}) })
} else { } else {
None None
}; };
let mut ls = match standalone::StandaloneLS::spawn(&main_config, mitm_cfg.as_ref(), cli.headless) { let mut ls = match standalone::StandaloneLS::spawn(&main_config, mitm_cfg.as_ref(), headless) {
Ok(ls) => ls, Ok(ls) => ls,
Err(e) => { Err(e) => {
eprintln!("Fatal: failed to spawn standalone LS: {e}"); eprintln!("Fatal: failed to spawn standalone LS: {e}");
@@ -197,46 +257,8 @@ async fn main() {
let (pid, https_port, csrf, token) = backend.info().await; let (pid, https_port, csrf, token) = backend.info().await;
// ── Step 3: MITM proxy (after port is secured) ────────────────────────────
let mitm_store = MitmStore::new();
let (mitm_port_actual, mitm_handle) = if !cli.no_mitm {
let data_dir = dirs_data_dir();
match mitm::ca::MitmCa::load_or_generate(&data_dir) {
Ok(ca) => {
let ca = Arc::new(ca);
let ca_pem = ca.ca_pem_path.display().to_string();
let config = mitm::proxy::MitmConfig {
port: cli.mitm_port,
modify_requests: true,
};
match mitm::proxy::run(ca, mitm_store.clone(), config).await {
Ok((port, handle)) => {
info!(port, ca = %ca_pem, "MITM proxy started");
// Write actual port to file for wrapper script discovery
let port_file = data_dir.join("mitm-port");
if let Err(e) = std::fs::write(&port_file, port.to_string()) {
warn!("Failed to write MITM port file: {e}");
}
(Some((port, ca_pem)), Some(handle))
}
Err(e) => {
warn!("MITM proxy failed to start: {e}");
(None, None)
}
}
}
Err(e) => {
warn!("MITM CA generation failed: {e}");
(None, None)
}
}
} else {
info!("MITM proxy disabled (--no-mitm)");
(None, None)
};
// ── Step 4: Warmup + heartbeat ──────────────────────────────────────────── // ── Step 4: Warmup + heartbeat ────────────────────────────────────────────
warmup::warmup_sequence(&backend).await; warmup::warmup_sequence(&backend, headless).await;
let heartbeat_handle = warmup::start_heartbeat(Arc::clone(&backend)); let heartbeat_handle = warmup::start_heartbeat(Arc::clone(&backend));
// ── Step 4b: Quota monitor ──────────────────────────────────────────────── // ── Step 4b: Quota monitor ────────────────────────────────────────────────

View File

@@ -202,11 +202,35 @@ pub fn modify_request(body: &[u8], tool_ctx: Option<&ToolContext>) -> Option<Vec
// Inject client-provided tools from ToolContext // Inject client-provided tools from ToolContext
if let Some(ref ctx) = tool_ctx { if let Some(ref ctx) = tool_ctx {
if let Some(ref custom_tools) = ctx.tools { if let Some(ref custom_tools) = ctx.tools {
let total_decls: usize = custom_tools.iter()
.filter_map(|t| t.get("functionDeclarations").and_then(|d| d.as_array()))
.map(|a| a.len())
.sum();
for tool in custom_tools { for tool in custom_tools {
tools.push(tool.clone()); tools.push(tool.clone());
} }
has_custom_tools = true; has_custom_tools = true;
changes.push(format!("inject {} custom tool group(s)", custom_tools.len())); changes.push(format!("inject {} custom tool group(s)", custom_tools.len()));
// Override LS's VALIDATED toolConfig → AUTO for custom tools.
// VALIDATED mode forces Google to validate function calls against a
// specific tool list that the LS controls. Our custom tools aren't in
// that list, so they'd be rejected. AUTO lets the model freely choose
// between text and function calls.
if let Some(req) = json.get_mut("request").and_then(|v| v.as_object_mut()) {
let has_validated = req.get("toolConfig")
.and_then(|tc| tc.pointer("/functionCallingConfig/mode"))
.and_then(|m| m.as_str())
.map_or(false, |m| m == "VALIDATED");
if has_validated {
req.insert("toolConfig".to_string(), serde_json::json!({
"functionCallingConfig": {
"mode": "AUTO"
}
}));
changes.push("override toolConfig VALIDATED → AUTO".to_string());
}
}
} }
} }
} }
@@ -230,11 +254,26 @@ pub fn modify_request(body: &[u8], tool_ctx: Option<&ToolContext>) -> Option<Vec
} }
} }
// ── 3b. ALWAYS strip old functionCall/functionResponse from history ─── // ── 3b. Strip old functionCall/functionResponse from history ────────
// Even when custom tools are injected, the LS history contains function // The LS history contains function call parts for LS-internal tools
// call parts for LS-internal tools we stripped. Google rejects these as // we stripped. Google rejects these as MALFORMED_FUNCTION_CALL because
// MALFORMED_FUNCTION_CALL because the referenced tools don't exist. // the referenced tools don't exist. However, when custom tools are
// injected, we must PRESERVE function calls for those tools so the
// model retains its tool call history and doesn't re-execute them.
if STRIP_ALL_TOOLS { if STRIP_ALL_TOOLS {
// Build set of custom tool names to preserve
let custom_tool_names: std::collections::HashSet<String> = tool_ctx
.as_ref()
.and_then(|ctx| ctx.tools.as_ref())
.map(|tools| {
tools.iter()
.filter_map(|t| t["functionDeclarations"].as_array())
.flatten()
.filter_map(|decl| decl["name"].as_str().map(|s| s.to_string()))
.collect()
})
.unwrap_or_default();
if let Some(contents) = json if let Some(contents) = json
.pointer_mut("/request/contents") .pointer_mut("/request/contents")
.and_then(|v| v.as_array_mut()) .and_then(|v| v.as_array_mut())
@@ -244,8 +283,21 @@ pub fn modify_request(body: &[u8], tool_ctx: Option<&ToolContext>) -> Option<Vec
if let Some(parts) = msg.get_mut("parts").and_then(|v| v.as_array_mut()) { if let Some(parts) = msg.get_mut("parts").and_then(|v| v.as_array_mut()) {
let before = parts.len(); let before = parts.len();
parts.retain(|part| { parts.retain(|part| {
!part.get("functionCall").is_some() // Check functionCall — keep if it's for a custom tool
&& !part.get("functionResponse").is_some() if let Some(fc) = part.get("functionCall") {
if let Some(name) = fc.get("name").and_then(|v| v.as_str()) {
return custom_tool_names.contains(name);
}
return false; // No name → strip
}
// Check functionResponse — keep if it's for a custom tool
if let Some(fr) = part.get("functionResponse") {
if let Some(name) = fr.get("name").and_then(|v| v.as_str()) {
return custom_tool_names.contains(name);
}
return false; // No name → strip
}
true // Not a function part → keep
}); });
stripped_fc += before - parts.len(); stripped_fc += before - parts.len();
} }

View File

@@ -458,6 +458,7 @@ async fn handle_http_over_tls(
loop { loop {
// ── Read the HTTP request from the client ───────────────────────── // ── Read the HTTP request from the client ─────────────────────────
let mut request_buf = Vec::with_capacity(1024 * 64); let mut request_buf = Vec::with_capacity(1024 * 64);
let mut is_our_request = false;
// 60s timeout on initial read (LS may open connection without sending immediately) // 60s timeout on initial read (LS may open connection without sending immediately)
const IDLE_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(60); const IDLE_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(60);
@@ -587,23 +588,30 @@ async fn handle_http_over_tls(
let mut new_buf = updated_headers.into_bytes(); let mut new_buf = updated_headers.into_bytes();
new_buf.extend_from_slice(&new_chunked); new_buf.extend_from_slice(&new_chunked);
request_buf = new_buf; request_buf = new_buf;
// Mark this as our modified request and set in-flight flag
is_our_request = true;
store.mark_request_in_flight();
} }
} }
} }
// ── Block follow-up requests when we already have a captured functionCall ── // ── Block ALL LS follow-up requests once first is in-flight ──
// The LS doesn't know what to do with the functionCall, so it tries more // When custom tools are active, we only need ONE request to Google.
// Google API calls. Block those to save quota. // The LS tries to send multiple requests (its own agentic loop +
if store.has_active_function_call() { // internal requests on gemini-2.5-flash-lite). Block them ALL
// immediately — don't wait for response_complete.
let has_tools = store.get_tools().await.is_some();
if has_tools && store.is_request_in_flight() && !is_our_request {
info!( info!(
"MITM: blocking follow-up request — functionCall already captured" "MITM: blocking LS follow-up request already in-flight"
); );
// Return a fake SSE response that makes the LS stop // Return a fake SSE response that makes the LS stop
let fake_response = "HTTP/1.1 200 OK\r\n\ let fake_response = "HTTP/1.1 200 OK\r\n\
Content-Type: text/event-stream\r\n\ Content-Type: text/event-stream\r\n\
Transfer-Encoding: chunked\r\n\ Transfer-Encoding: chunked\r\n\
\r\n"; \r\n";
let fake_sse = "data: {\"response\":{\"candidates\":[{\"content\":{\"parts\":[{\"text\":\"Tool call completed. Awaiting external tool result.\"}],\"role\":\"model\"},\"finishReason\":\"STOP\"}],\"usageMetadata\":{\"promptTokenCount\":0,\"candidatesTokenCount\":1,\"totalTokenCount\":1}}}\n\ndata: [DONE]\n\n"; let fake_sse = "data: {\"response\":{\"candidates\":[{\"content\":{\"parts\":[{\"text\":\"Request handled.\"}],\"role\":\"model\"},\"finishReason\":\"STOP\"}],\"usageMetadata\":{\"promptTokenCount\":0,\"candidatesTokenCount\":1,\"totalTokenCount\":1}}}\n\ndata: [DONE]\n\n";
let chunked_body = super::modify::rechunk(fake_sse.as_bytes()); let chunked_body = super::modify::rechunk(fake_sse.as_bytes());
let mut response = fake_response.as_bytes().to_vec(); let mut response = fake_response.as_bytes().to_vec();
response.extend_from_slice(&chunked_body); response.extend_from_slice(&chunked_body);
@@ -763,25 +771,21 @@ async fn handle_http_over_tls(
response_body_buf.extend_from_slice(&header_buf[hdr_end..]); response_body_buf.extend_from_slice(&header_buf[hdr_end..]);
// Parse ORIGINAL initial body for MITM interception // Parse ORIGINAL initial body for MITM interception
let mut has_function_call = false;
let bypass_ls = modify_requests && store.get_tools().await.is_some();
if is_streaming_response && hdr_end < header_buf.len() { if is_streaming_response && hdr_end < header_buf.len() {
let body = String::from_utf8_lossy(&header_buf[hdr_end..]); let body = String::from_utf8_lossy(&header_buf[hdr_end..]);
parse_streaming_chunk(&body, &mut streaming_acc); parse_streaming_chunk(&body, &mut streaming_acc);
has_function_call = !streaming_acc.function_calls.is_empty();
// Immediately store captured function calls // Store captured function calls (drain to avoid re-storing on next chunk)
if has_function_call { if !streaming_acc.function_calls.is_empty() {
for fc in &streaming_acc.function_calls { let calls: Vec<_> = streaming_acc.function_calls.drain(..).collect();
for fc in &calls {
store.record_function_call(cascade_hint.as_deref(), fc.clone()).await; store.record_function_call(cascade_hint.as_deref(), fc.clone()).await;
} }
store.set_last_function_calls(streaming_acc.function_calls.clone()).await; store.set_last_function_calls(calls.clone()).await;
info!("MITM: stored {} function call(s) from initial body", streaming_acc.function_calls.len()); info!("MITM: stored {} function call(s) from initial body", calls.len());
} }
// Capture response + thinking text + grounding directly into MitmStore // Capture response + thinking text + grounding into MitmStore
if bypass_ls {
if !streaming_acc.response_text.is_empty() { if !streaming_acc.response_text.is_empty() {
store.set_response_text(&streaming_acc.response_text).await; store.set_response_text(&streaming_acc.response_text).await;
} }
@@ -795,31 +799,18 @@ async fn handle_http_over_tls(
store.mark_response_complete(); store.mark_response_complete();
} }
} }
}
if bypass_ls { // Forward to client — rewrite function calls if custom tools are injected
if has_function_call { let forward_buf = if modify_requests {
info!("MITM: functionCall captured → NOT forwarding to LS (bypass mode)"); if let Some(modified) = super::modify::modify_response_chunk(&header_buf) {
store.mark_response_complete(); modified
break; } else {
header_buf.clone()
} }
// Don't forward to LS — just continue reading chunks } else {
// Send headers only so upstream doesn't close header_buf.clone()
if let Some(cl) = response_content_length { };
if response_body_buf.len() >= cl { if let Err(e) = client.write_all(&forward_buf).await {
store.mark_response_complete();
break;
}
}
if is_chunked && has_chunked_terminator(&response_body_buf) {
store.mark_response_complete();
break;
}
continue;
}
// Normal path (no custom tools): forward headers+body as-is
if let Err(e) = client.write_all(&header_buf).await {
warn!(error = %e, "MITM: write to client failed"); warn!(error = %e, "MITM: write to client failed");
break; break;
} }
@@ -838,25 +829,21 @@ async fn handle_http_over_tls(
} }
// ── Response body interception ──────────────────────────────── // ── Response body interception ────────────────────────────────
let mut chunk_has_fc = false;
let bypass_ls = modify_requests && store.get_tools().await.is_some();
if is_streaming_response { if is_streaming_response {
let s = String::from_utf8_lossy(chunk); let s = String::from_utf8_lossy(chunk);
parse_streaming_chunk(&s, &mut streaming_acc); parse_streaming_chunk(&s, &mut streaming_acc);
chunk_has_fc = !streaming_acc.function_calls.is_empty();
// Immediately store captured function calls // Store captured function calls (drain to avoid re-storing on next chunk)
if chunk_has_fc { if !streaming_acc.function_calls.is_empty() {
for fc in &streaming_acc.function_calls { let calls: Vec<_> = streaming_acc.function_calls.drain(..).collect();
for fc in &calls {
store.record_function_call(cascade_hint.as_deref(), fc.clone()).await; store.record_function_call(cascade_hint.as_deref(), fc.clone()).await;
} }
store.set_last_function_calls(streaming_acc.function_calls.clone()).await; store.set_last_function_calls(calls.clone()).await;
info!("MITM: stored {} function call(s) from body chunk", streaming_acc.function_calls.len()); info!("MITM: stored {} function call(s) from body chunk", calls.len());
} }
// Capture response + thinking text + grounding directly into MitmStore // Capture response + thinking text + grounding into MitmStore
if bypass_ls {
if !streaming_acc.response_text.is_empty() { if !streaming_acc.response_text.is_empty() {
store.set_response_text(&streaming_acc.response_text).await; store.set_response_text(&streaming_acc.response_text).await;
} }
@@ -870,31 +857,18 @@ async fn handle_http_over_tls(
store.mark_response_complete(); store.mark_response_complete();
} }
} }
}
if bypass_ls { // Forward chunk to client (LS) — rewrite function calls if custom tools
if chunk_has_fc || streaming_acc.is_complete { let forward_chunk = if modify_requests {
info!("MITM: response captured → NOT forwarding to LS (bypass mode)"); if let Some(modified) = super::modify::modify_response_chunk(chunk) {
store.mark_response_complete(); modified
break; } else {
chunk.to_vec()
} }
// Keep reading chunks without forwarding to LS } else {
response_body_buf.extend_from_slice(chunk); chunk.to_vec()
if let Some(cl) = response_content_length { };
if response_body_buf.len() >= cl { if let Err(e) = client.write_all(&forward_chunk).await {
store.mark_response_complete();
break;
}
}
if is_chunked && has_chunked_terminator(&response_body_buf) {
store.mark_response_complete();
break;
}
continue;
}
// Normal path: forward chunk to client (LS)
if let Err(e) = client.write_all(chunk).await {
warn!(error = %e, "MITM: write to client failed"); warn!(error = %e, "MITM: write to client failed");
break; break;
} }

View File

@@ -130,6 +130,13 @@ pub struct MitmStore {
/// Simple flag: set when a functionCall is captured, cleared when consumed. /// Simple flag: set when a functionCall is captured, cleared when consumed.
/// Used to block follow-up requests regardless of cascade identification. /// Used to block follow-up requests regardless of cascade identification.
has_active_function_call: Arc<AtomicBool>, has_active_function_call: Arc<AtomicBool>,
/// Persistent flag: set when a function call is captured, cleared ONLY when
/// a tool result is submitted. Prevents the LS from making follow-up API
/// calls during the entire tool execution cycle.
awaiting_tool_result: Arc<AtomicBool>,
/// Set when the MITM forwards the first LLM request with custom tools.
/// Blocks ALL subsequent LS requests until the API handler clears it.
request_in_flight: Arc<AtomicBool>,
// ── Tool call support ──────────────────────────────────────────────── // ── Tool call support ────────────────────────────────────────────────
/// Active tool definitions (Gemini format) for MITM injection. /// Active tool definitions (Gemini format) for MITM injection.
@@ -205,6 +212,8 @@ impl MitmStore {
stats: Arc::new(RwLock::new(MitmStats::default())), stats: Arc::new(RwLock::new(MitmStats::default())),
pending_function_calls: Arc::new(RwLock::new(HashMap::new())), pending_function_calls: Arc::new(RwLock::new(HashMap::new())),
has_active_function_call: Arc::new(AtomicBool::new(false)), has_active_function_call: Arc::new(AtomicBool::new(false)),
awaiting_tool_result: Arc::new(AtomicBool::new(false)),
request_in_flight: Arc::new(AtomicBool::new(false)),
active_tools: Arc::new(RwLock::new(None)), active_tools: Arc::new(RwLock::new(None)),
active_tool_config: Arc::new(RwLock::new(None)), active_tool_config: Arc::new(RwLock::new(None)),
pending_tool_results: Arc::new(RwLock::new(Vec::new())), pending_tool_results: Arc::new(RwLock::new(Vec::new())),
@@ -343,6 +352,7 @@ impl MitmStore {
let mut pending = self.pending_function_calls.write().await; let mut pending = self.pending_function_calls.write().await;
pending.entry(key).or_default().push(fc); pending.entry(key).or_default().push(fc);
self.has_active_function_call.store(true, Ordering::SeqCst); self.has_active_function_call.store(true, Ordering::SeqCst);
self.awaiting_tool_result.store(true, Ordering::SeqCst);
} }
/// Check if there's an active (unclaimed) function call. /// Check if there's an active (unclaimed) function call.
@@ -355,6 +365,18 @@ impl MitmStore {
self.has_active_function_call.store(false, Ordering::SeqCst); self.has_active_function_call.store(false, Ordering::SeqCst);
} }
/// Check if we're awaiting a tool result (blocks LS follow-up requests).
/// This persists across function call consumption — only cleared when
/// actual tool results are submitted.
pub fn is_awaiting_tool_result(&self) -> bool {
self.awaiting_tool_result.load(Ordering::SeqCst)
}
/// Clear the awaiting-tool-result flag (called when tool results arrive).
pub fn clear_awaiting_tool_result(&self) {
self.awaiting_tool_result.store(false, Ordering::SeqCst);
}
/// Take any pending function calls (ignoring cascade ID). /// Take any pending function calls (ignoring cascade ID).
pub async fn take_any_function_calls(&self) -> Option<Vec<CapturedFunctionCall>> { pub async fn take_any_function_calls(&self) -> Option<Vec<CapturedFunctionCall>> {
@@ -467,10 +489,21 @@ impl MitmStore {
/// Async version of clear_response. /// Async version of clear_response.
pub async fn clear_response_async(&self) { pub async fn clear_response_async(&self) {
// Reset the response-complete flag.
self.response_complete.store(false, Ordering::SeqCst); self.response_complete.store(false, Ordering::SeqCst);
// Also reset the in-flight flag (the one `mark_request_in_flight` sets to true).
self.request_in_flight.store(false, Ordering::SeqCst);
// Set both captured text slots (response and thinking) back to `None`.
*self.captured_response_text.write().await = None; *self.captured_response_text.write().await = None;
*self.captured_thinking_text.write().await = None; *self.captured_thinking_text.write().await = None;
} }
/// Mark the request as in-flight (first LLM request forwarded).
pub fn mark_request_in_flight(&self) {
self.request_in_flight.store(true, Ordering::SeqCst);
}
/// Check if a request is currently in-flight.
pub fn is_request_in_flight(&self) -> bool {
self.request_in_flight.load(Ordering::SeqCst)
}
// ── Thinking text capture ──────────────────────────────────────────── // ── Thinking text capture ────────────────────────────────────────────
/// Set (replace) the captured thinking text. /// Set (replace) the captured thinking text.

View File

@@ -15,9 +15,31 @@ use tracing::{debug, info, warn};
/// ///
/// Called BEFORE accepting any API requests. Each call is fire-and-forget /// Called BEFORE accepting any API requests. Each call is fire-and-forget
/// (we don't care if some fail — the LS might not support all methods). /// (we don't care if some fail — the LS might not support all methods).
pub async fn warmup_sequence(backend: &Backend) { pub async fn warmup_sequence(backend: &Backend, headless: bool) {
info!("Running webview warmup sequence..."); info!("Running webview warmup sequence...");
// ── CRITICAL: Set detect_and_use_proxy BEFORE any API-triggering call ──
// The LS creates its HTTP transport lazily on the first API call.
// If we set ENABLED before that, the LS will honor HTTPS_PROXY env var,
// routing API traffic through the MITM proxy without iptables/sudo.
if headless {
let settings_body = serde_json::json!({
"userSettings": {
"detectAndUseProxy": 1 // DETECT_AND_USE_PROXY_ENABLED
}
});
match tokio::time::timeout(
Duration::from_secs(5),
backend.call_json("SetUserSettings", &settings_body),
)
.await
{
Ok(Ok((status, _))) => info!("SetUserSettings (detect_and_use_proxy=ENABLED): {status}"),
Ok(Err(e)) => warn!("SetUserSettings failed: {e}"),
Err(_) => warn!("SetUserSettings timed out"),
}
}
let calls: &[(&str, serde_json::Value)] = &[ let calls: &[(&str, serde_json::Value)] = &[
("GetStatus", serde_json::json!({})), ("GetStatus", serde_json::json!({})),
("Heartbeat", serde_json::json!({})), ("Heartbeat", serde_json::json!({})),