//! MITM proxy server: handles CONNECT tunnels and transparent TLS interception. //! //! Supports two modes: //! 1. **HTTP CONNECT** — standard proxy mode, LS sends `CONNECT host:port` //! 2. **Transparent (iptables)** — raw TLS arrives via REDIRECT, SNI extracted from ClientHello //! //! For intercepted domains, terminates TLS with our CA-signed cert, //! reads/modifies the request, forwards to the real upstream, and captures usage. //! //! For non-intercepted domains, acts as a transparent TCP tunnel. use super::ca::MitmCa; use super::intercept::{ extract_cascade_hint, parse_non_streaming_response, parse_streaming_chunk, StreamingAccumulator, }; use super::store::MitmStore; use std::sync::Arc; use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::net::{TcpListener, TcpStream}; use tokio_rustls::TlsAcceptor; use tracing::{debug, error, info, trace, warn}; /// Domains we intercept (terminate TLS and inspect traffic). /// This includes exact matches and suffix matches for regional endpoints /// (e.g., us-central1-aiplatform.googleapis.com). const INTERCEPT_DOMAINS: &[&str] = &[ "cloudcode-pa.googleapis.com", "aiplatform.googleapis.com", "api.anthropic.com", "speech.googleapis.com", "modelarmor.googleapis.com", ]; /// Domains we NEVER intercept (transparent tunnel). const PASSTHROUGH_DOMAINS: &[&str] = &[ "oauth2.googleapis.com", "accounts.google.com", "storage.googleapis.com", "www.googleapis.com", "firebaseinstallations.googleapis.com", "crashlyticsreports-pa.googleapis.com", "play.googleapis.com", "update.googleapis.com", "dl.google.com", ]; /// Configuration for the MITM proxy. #[derive(Default)] pub struct MitmConfig { /// Port to listen on (0 = auto-assign). pub port: u16, /// Whether to enable request modification. pub modify_requests: bool, } /// Run the MITM proxy server. /// /// Returns (port, task_handle) — port it's listening on, handle to abort on shutdown. pub async fn run( ca: Arc, store: MitmStore, config: MitmConfig, ) -> Result<(u16, tokio::task::JoinHandle<()>), String> { let listener = TcpListener::bind(format!("127.0.0.1:{}", config.port)) .await .map_err(|e| format!("MITM bind failed: {e}"))?; let port = listener .local_addr() .map_err(|e| format!("MITM local_addr failed: {e}"))? .port(); info!(port, "MITM proxy listening"); let modify_requests = config.modify_requests; let handle = tokio::spawn(async move { loop { match listener.accept().await { Ok((stream, addr)) => { trace!(?addr, "MITM: new connection"); let ca = ca.clone(); let store = store.clone(); tokio::spawn(async move { if let Err(e) = handle_connection(stream, ca, store, modify_requests).await { warn!(error = %e, "MITM connection error"); } }); } Err(e) => { error!(error = %e, "MITM accept error"); } } } }); Ok((port, handle)) } /// Handle a single incoming connection. /// /// Supports two modes: /// 1. **HTTP CONNECT** (standard proxy) — LS sends `CONNECT host:port HTTP/1.1` /// 2. **Transparent/iptables redirect** — raw TLS ClientHello arrives directly /// (first byte is 0x16). We extract the domain from SNI and intercept. async fn handle_connection( mut stream: TcpStream, ca: Arc, store: MitmStore, modify_requests: bool, ) -> Result<(), String> { // Peek at the first byte to distinguish CONNECT vs raw TLS let mut peek = [0u8; 1]; let n = stream .peek(&mut peek) .await .map_err(|e| format!("Peek failed: {e}"))?; if n == 0 { return Ok(()); } if peek[0] == 0x16 { // TLS ClientHello — transparent/iptables redirect mode. // Peek enough bytes to extract SNI from the ClientHello. let mut hello_buf = vec![0u8; 16384]; let n = stream .peek(&mut hello_buf) .await .map_err(|e| format!("Peek ClientHello: {e}"))?; let domain = extract_sni(&hello_buf[..n]).unwrap_or_else(|| "unknown".to_string()); info!(domain, "MITM: transparent redirect (iptables)"); let should_intercept = should_intercept_domain(&domain); if should_intercept { handle_intercepted(stream, &domain, ca, store, modify_requests).await } else { // For non-intercepted domains via iptables, we need the original dest. // Since we only have SNI, resolve and passthrough. handle_passthrough(stream, &domain, 443).await } } else { // Standard HTTP CONNECT proxy mode let mut buf = vec![0u8; 8192]; let n = stream .read(&mut buf) .await .map_err(|e| format!("Read CONNECT: {e}"))?; if n == 0 { return Ok(()); } let request = String::from_utf8_lossy(&buf[..n]); let first_line = request.lines().next().unwrap_or(""); // Parse "CONNECT host:port HTTP/1.1" let parts: Vec<&str> = first_line.split_whitespace().collect(); if parts.len() < 3 || parts[0] != "CONNECT" { let resp = "HTTP/1.1 400 Bad Request\r\n\r\n"; let _ = stream.write_all(resp.as_bytes()).await; return Ok(()); } let host_port = parts[1]; let (domain, _port) = match host_port.rsplit_once(':') { Some((h, p)) => (h, p.parse::().unwrap_or(443)), None => (host_port, 443), }; debug!(domain, "MITM: CONNECT request"); let should_intercept = should_intercept_domain(domain); // Send 200 Connection Established let response = "HTTP/1.1 200 Connection Established\r\n\r\n"; stream .write_all(response.as_bytes()) .await .map_err(|e| format!("Write 200: {e}"))?; if should_intercept { handle_intercepted(stream, domain, ca, store, modify_requests).await } else { handle_passthrough(stream, domain, _port).await } } } /// Extract SNI (Server Name Indication) from a TLS ClientHello message. /// /// Parses the raw TLS record to find the `server_name` extension (type 0x0000). /// Returns `None` if the SNI can't be found (not TLS, no SNI extension, etc.). fn extract_sni(buf: &[u8]) -> Option { // TLS record: type(1) + version(2) + length(2) + handshake if buf.len() < 5 || buf[0] != 0x16 { return None; } let record_len = u16::from_be_bytes([buf[3], buf[4]]) as usize; let handshake = &buf[5..5 + record_len.min(buf.len() - 5)]; // Handshake: type(1) + length(3) + body if handshake.is_empty() || handshake[0] != 0x01 { return None; // Not ClientHello } if handshake.len() < 4 { return None; } let hs_len = u32::from_be_bytes([0, handshake[1], handshake[2], handshake[3]]) as usize; let body = &handshake[4..4 + hs_len.min(handshake.len() - 4)]; // ClientHello: version(2) + random(32) + session_id_len(1) + session_id(var) // + cipher_suites_len(2) + cipher_suites(var) // + compression_len(1) + compression(var) // + extensions_len(2) + extensions(var) if body.len() < 34 { return None; } let mut pos = 34; // skip version + random // Session ID if pos >= body.len() { return None; } let sid_len = body[pos] as usize; pos += 1 + sid_len; // Cipher suites if pos + 2 > body.len() { return None; } let cs_len = u16::from_be_bytes([body[pos], body[pos + 1]]) as usize; pos += 2 + cs_len; // Compression methods if pos >= body.len() { return None; } let cm_len = body[pos] as usize; pos += 1 + cm_len; // Extensions if pos + 2 > body.len() { return None; } let ext_len = u16::from_be_bytes([body[pos], body[pos + 1]]) as usize; pos += 2; let ext_end = pos + ext_len.min(body.len() - pos); while pos + 4 <= ext_end { let ext_type = u16::from_be_bytes([body[pos], body[pos + 1]]); let ext_data_len = u16::from_be_bytes([body[pos + 2], body[pos + 3]]) as usize; pos += 4; if ext_type == 0x0000 { // SNI extension — server_name_list_len(2) + type(1) + name_len(2) + name if ext_data_len >= 5 && pos + ext_data_len <= ext_end { let name_len = u16::from_be_bytes([body[pos + 3], body[pos + 4]]) as usize; if pos + 5 + name_len <= ext_end { return String::from_utf8(body[pos + 5..pos + 5 + name_len].to_vec()).ok(); } } return None; } pos += ext_data_len; } None } /// Check if a domain should be intercepted. fn should_intercept_domain(domain: &str) -> bool { // Never intercept passthrough domains for &pt in PASSTHROUGH_DOMAINS { if domain == pt { return false; } } // Intercept known API domains (exact match, subdomain, or regional prefix) for &intercept in INTERCEPT_DOMAINS { if domain == intercept || domain.ends_with(&format!(".{intercept}")) || domain.ends_with(&format!("-{intercept}")) { return true; } } // Default: passthrough false } /// Handle an intercepted connection: terminate TLS, inspect traffic. /// /// After TLS termination, checks the negotiated ALPN protocol: /// - `h2` → HTTP/2 handler (for gRPC traffic to Google APIs) /// - `http/1.1` or none → HTTP/1.1 handler (for REST/SSE traffic) async fn handle_intercepted( stream: TcpStream, domain: &str, ca: Arc, store: MitmStore, modify_requests: bool, ) -> Result<(), String> { info!(domain, "MITM: intercepting TLS"); // Get or create server TLS config for this domain let server_config = ca.server_config_for_domain(domain).await?; let acceptor = TlsAcceptor::from(server_config); // Perform TLS handshake with the client (LS) — 10s timeout let tls_stream = match tokio::time::timeout(std::time::Duration::from_secs(10), acceptor.accept(stream)) .await { Ok(Ok(s)) => s, Ok(Err(e)) => { warn!(domain, error = %e, "MITM: TLS handshake FAILED (client rejected cert?)"); return Err(format!( "TLS handshake with client failed for {domain}: {e}" )); } Err(_) => { warn!(domain, "MITM: TLS handshake TIMED OUT after 10s"); return Err(format!("TLS handshake timed out for {domain}")); } }; // Check negotiated ALPN protocol let alpn = tls_stream .get_ref() .1 .alpn_protocol() .map(|p| String::from_utf8_lossy(p).to_string()); info!(domain, alpn = ?alpn, "MITM: TLS handshake successful ✓"); match alpn.as_deref() { Some("h2") => { // HTTP/2 — use the hyper-based gRPC handler info!(domain, "MITM: routing to HTTP/2 handler (gRPC)"); super::h2_handler::handle_h2_connection(tls_stream, domain.to_string(), store).await } _ => { // HTTP/1.1 or no ALPN — use the existing handler info!(domain, "MITM: routing to HTTP/1.1 handler"); handle_http_over_tls(tls_stream, domain, store, modify_requests).await } } } /// Handle HTTP traffic over the decrypted TLS connection. /// /// Loops to handle multiple requests on the same connection (HTTP keep-alive). /// Reads full request, connects to upstream, forwards request, streams response /// back to client while capturing usage data. async fn handle_http_over_tls( mut client: tokio_rustls::server::TlsStream, domain: &str, store: MitmStore, modify_requests: bool, ) -> Result<(), String> { let mut tmp = vec![0u8; 32768]; // Build upstream TLS connector once for this connection let mut root_store = rustls::RootCertStore::empty(); let native_certs = rustls_native_certs::load_native_certs(); for cert in native_certs.certs { let _ = root_store.add(cert); } let upstream_config = Arc::new( rustls::ClientConfig::builder() .with_root_certificates(root_store) .with_no_client_auth(), ); // Reusable upstream connection — created lazily, reconnected if stale let mut upstream: Option> = None; /// Connect (or reconnect) to the real upstream via TLS. /// /// Bypasses /etc/hosts by resolving via direct DNS query (dig @8.8.8.8), /// then falls back to cached IPs file, then to normal system resolution. async fn connect_upstream( domain: &str, config: &Arc, ) -> Result, String> { let connector = tokio_rustls::TlsConnector::from(config.clone()); // Try to resolve the real IP, bypassing /etc/hosts let addr = resolve_upstream(domain).await; info!(domain, addr = %addr, "MITM: connecting upstream"); let tcp = match tokio::time::timeout( std::time::Duration::from_secs(15), TcpStream::connect(&addr), ) .await { Ok(Ok(s)) => s, Ok(Err(e)) => return Err(format!("Connect to upstream {domain} ({addr}): {e}")), Err(_) => return Err(format!("Connect to upstream {domain} ({addr}): timed out")), }; let server_name = rustls::pki_types::ServerName::try_from(domain.to_string()) .map_err(|e| format!("Invalid server name: {e}"))?; match tokio::time::timeout( std::time::Duration::from_secs(15), connector.connect(server_name, tcp), ) .await { Ok(Ok(s)) => { info!(domain, "MITM: upstream TLS connected ✓"); Ok(s) } Ok(Err(e)) => Err(format!("TLS connect to upstream {domain}: {e}")), Err(_) => Err(format!("TLS connect to upstream {domain}: timed out")), } } /// Resolve upstream IP bypassing /etc/hosts. async fn resolve_upstream(domain: &str) -> String { // 1. Try dig @8.8.8.8 (bypasses /etc/hosts) if let Ok(output) = tokio::process::Command::new("dig") .args(["+short", "@8.8.8.8", domain]) .output() .await { let out = String::from_utf8_lossy(&output.stdout); if let Some(ip) = out .lines() .find(|l| l.parse::().is_ok()) { return format!("{ip}:443"); } } // 2. Try cached IPs file (written by dns-redirect.sh install) if let Ok(contents) = tokio::fs::read_to_string("/tmp/antigravity-mitm-real-ips").await { for line in contents.lines() { if let Some((d, ip)) = line.split_once('=') { if d == domain { return format!("{ip}:443"); } } } } // 3. Fallback to normal resolution (may hit /etc/hosts) format!("{domain}:443") } // Keep-alive loop: handle multiple requests on this connection loop { // ── Read the HTTP request from the client ───────────────────────── let mut request_buf = Vec::with_capacity(1024 * 64); // 60s timeout on initial read (LS may open connection without sending immediately) const IDLE_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(60); loop { let read_result = if request_buf.is_empty() { // First read — apply idle timeout match tokio::time::timeout(IDLE_TIMEOUT, client.read(&mut tmp)).await { Ok(r) => r, Err(_) => { // Idle timeout — connection pool warmup, no data sent debug!(domain, "MITM: client idle timeout (60s), closing"); return Ok(()); } } } else { // Subsequent reads — wait up to 30s for rest of request match tokio::time::timeout( std::time::Duration::from_secs(30), client.read(&mut tmp), ) .await { Ok(r) => r, Err(_) => { warn!(domain, "MITM: partial request read timed out"); return Err("Partial request read timed out".into()); } } }; let n = match read_result { Ok(0) => return Ok(()), // Client closed connection cleanly Ok(n) => n, Err(e) => { // Connection reset / broken pipe is normal for keep-alive end debug!(domain, error = %e, "MITM: client read finished"); return Ok(()); } }; request_buf.extend_from_slice(&tmp[..n]); // Check if we have the full request (headers + body) if has_complete_http_request(&request_buf) { break; } } if request_buf.is_empty() { return Ok(()); } // Parse the HTTP request to find headers and body let (headers_end, content_length, _is_streaming_request) = parse_http_request_meta(&request_buf); // Try to extract cascade hint from request body let cascade_hint = if headers_end < request_buf.len() { extract_cascade_hint(&request_buf[headers_end..]) } else { None }; // Extract request method and path for logging let req_path = { let mut headers = [httparse::EMPTY_HEADER; 64]; let mut req = httparse::Request::new(&mut headers); match req.parse(&request_buf) { Ok(httparse::Status::Complete(_)) => { format!("{} {}", req.method.unwrap_or("?"), req.path.unwrap_or("?")) } _ => "?".to_string(), } }; // Generation tracking for store write guards let mut won_gate = false; let mut conn_generation = store.current_generation(); // Log LLM calls at info, everything else at debug if req_path.contains("streamGenerateContent") { let body_len = request_buf.len() - headers_end; info!( domain, req_path = %req_path, body_len, cascade = ?cascade_hint, "MITM: forwarding LLM request" ); // ── Atomic in-flight gate ───────────────────────────────── // The LS opens multiple connections and sends parallel requests. // When custom tools are active, only the FIRST request wins the // atomic compare_exchange. All others get fake STOP responses. let has_tools = store.get_tools().await.is_some(); won_gate = if has_tools { if !store.try_mark_request_in_flight() { info!("MITM: blocking LS request — another request already in-flight"); let fake_response = "HTTP/1.1 200 OK\r\n\ Content-Type: text/event-stream\r\n\ Transfer-Encoding: chunked\r\n\ \r\n"; let fake_sse = "data: {\"response\":{\"candidates\":[{\"content\":{\"parts\":[{\"text\":\"Request handled.\"}],\"role\":\"model\"},\"finishReason\":\"STOP\"}],\"usageMetadata\":{\"promptTokenCount\":0,\"candidatesTokenCount\":1,\"totalTokenCount\":1}}}\n\ndata: [DONE]\n\n"; let chunked_body = super::modify::rechunk(fake_sse.as_bytes()); let mut response = fake_response.as_bytes().to_vec(); response.extend_from_slice(&chunked_body); if let Err(e) = client.write_all(&response).await { warn!(error = %e, "MITM: failed to write fake response"); } let _ = client.flush().await; continue; } true } else { false }; // Snapshot the generation at gate-win time. If it changes later, // another completions turn started and our data is stale. conn_generation = store.current_generation(); // ── Request modification ───────────────────────────────────── // Dechunk body → check if agent request → modify → rechunk if modify_requests && body_len > 0 { let body_slice = &request_buf[headers_end..]; let raw_body = super::modify::dechunk(body_slice); // Only modify "agent" requests, not "checkpoint" (LS internal) let body_str = String::from_utf8_lossy(&raw_body); let is_agent = body_str.contains("\"requestType\":\"agent\"") || body_str.contains("\"requestType\": \"agent\""); if is_agent { // Build ToolContext from store let tools = store.get_tools().await; let tool_config = store.get_tool_config().await; let pending_results = store.take_tool_results().await; let last_calls = store.get_last_function_calls().await; let generation_params = store.get_generation_params().await; let pending_image = store.take_pending_image().await; let tool_ctx = if tools.is_some() || !pending_results.is_empty() || generation_params.is_some() || pending_image.is_some() { Some(super::modify::ToolContext { tools, tool_config, pending_results, last_calls, generation_params, pending_image, }) } else { None }; if let Some(modified_body) = super::modify::modify_request(&raw_body, tool_ctx.as_ref()) { // Rebuild request_buf: headers (with updated Content-Length) + rechunked modified body let new_chunked = super::modify::rechunk(&modified_body); // Fix Content-Length header to match new body size let header_str = String::from_utf8_lossy(&request_buf[..headers_end]); let updated_headers = update_content_length(&header_str, new_chunked.len()); let mut new_buf = updated_headers.into_bytes(); new_buf.extend_from_slice(&new_chunked); request_buf = new_buf; // In-flight already marked atomically above } } } } else { debug!( domain, req_path = %req_path, content_length, "MITM: forwarding request" ); } // ── Ensure upstream connection is alive ────────────────────────────── // Lazily connect on first request, or reconnect if the previous connection died let conn = match upstream.as_mut() { Some(c) => c, None => { let c = connect_upstream(domain, &upstream_config).await?; upstream.insert(c) } }; // Forward the request — if write fails, reconnect and retry once if let Err(e) = conn.write_all(&request_buf).await { debug!(domain, error = %e, "MITM: upstream write failed, reconnecting"); let c = connect_upstream(domain, &upstream_config).await?; let conn = upstream.insert(c); conn.write_all(&request_buf) .await .map_err(|e| format!("Write to upstream (retry): {e}"))?; } let conn = upstream.as_mut().unwrap(); // ── Stream response back to client ────────────────────────────────── // ALWAYS forward data to client immediately (no buffering). // Buffer body on the side for usage parsing. let mut streaming_acc = StreamingAccumulator::new(); let mut is_streaming_response = false; let mut headers_parsed = false; let mut upstream_ok = true; let mut response_body_buf = Vec::new(); let mut response_content_length: Option = None; let mut is_chunked = false; let mut got_first_byte = false; let mut header_buf = Vec::with_capacity(8192); loop { // 15s idle timeout after first byte, 60s for initial response let timeout = if got_first_byte { std::time::Duration::from_secs(15) } else { std::time::Duration::from_secs(60) }; let n = match tokio::time::timeout(timeout, conn.read(&mut tmp)).await { Ok(Ok(0)) => { upstream_ok = false; break; } Ok(Ok(n)) => n, Ok(Err(e)) => { debug!(domain, error = %e, "MITM: upstream read ended"); upstream_ok = false; break; } Err(_) => { if got_first_byte { debug!(domain, "MITM: response idle timeout (complete)"); } else { warn!(domain, "MITM: no upstream response in 60s"); } upstream_ok = false; break; } }; got_first_byte = true; let chunk = &tmp[..n]; if !headers_parsed { header_buf.extend_from_slice(chunk); if let Some(_hdr_end) = find_headers_end(&header_buf) { let mut resp_headers = [httparse::EMPTY_HEADER; 64]; let mut resp = httparse::Response::new(&mut resp_headers); let hdr_end = match resp.parse(&header_buf) { Ok(httparse::Status::Complete(n)) => n, _ => _hdr_end, }; let mut content_type = String::new(); for header in resp.headers.iter() { if header.name.eq_ignore_ascii_case("content-type") { if let Ok(v) = std::str::from_utf8(header.value) { content_type = v.to_string(); if v.contains("text/event-stream") { is_streaming_response = true; } } } if header.name.eq_ignore_ascii_case("content-length") { if let Ok(v) = std::str::from_utf8(header.value) { response_content_length = v.trim().parse().ok(); } } if header.name.eq_ignore_ascii_case("connection") { if let Ok(v) = std::str::from_utf8(header.value) { if v.trim().eq_ignore_ascii_case("close") { upstream_ok = false; } } } if header.name.eq_ignore_ascii_case("transfer-encoding") { if let Ok(v) = std::str::from_utf8(header.value) { if v.trim().eq_ignore_ascii_case("chunked") { is_chunked = true; } } } } if is_streaming_response { info!(domain, content_type = %content_type, status = resp.code, "MITM: streaming response"); } else { debug!(domain, content_type = %content_type, status = resp.code, "MITM: response headers"); } headers_parsed = true; // Capture upstream errors for forwarding to client let http_status = resp.code.unwrap_or(0) as u16; if http_status >= 400 { let body_str = String::from_utf8_lossy(&header_buf[hdr_end..]).to_string(); warn!(domain, status = http_status, body = %body_str, "MITM: upstream error response"); // Parse Google's error JSON: {"error": {"code": N, "message": "...", "status": "..."}} let (message, error_status) = serde_json::from_str::(&body_str) .ok() .and_then(|v| { let err = v.get("error")?; let msg = err .get("message") .and_then(|m| m.as_str()) .map(|s| s.to_string()); let status = err .get("status") .and_then(|s| s.as_str()) .map(|s| s.to_string()); Some((msg, status)) }) .unwrap_or((None, None)); store .set_upstream_error(super::store::UpstreamError { status: http_status, body: body_str, message, error_status, }) .await; } // Save body for usage parsing response_body_buf.extend_from_slice(&header_buf[hdr_end..]); // Parse ORIGINAL initial body for MITM interception if is_streaming_response && hdr_end < header_buf.len() { let body = String::from_utf8_lossy(&header_buf[hdr_end..]); parse_streaming_chunk(&body, &mut streaming_acc); // Only write to store if our generation is still current. // If another completions turn started, our data is stale. let gen_valid = !won_gate || store.current_generation() == conn_generation; if gen_valid { // Store captured function calls (drain to avoid re-storing on next chunk) if !streaming_acc.function_calls.is_empty() { let calls: Vec<_> = streaming_acc.function_calls.drain(..).collect(); for fc in &calls { store .record_function_call(cascade_hint.as_deref(), fc.clone()) .await; } store.set_last_function_calls(calls.clone()).await; info!( "MITM: stored {} function call(s) from initial body", calls.len() ); } // Capture response + thinking text + grounding into MitmStore if !streaming_acc.response_text.is_empty() { store.set_response_text(&streaming_acc.response_text).await; } if !streaming_acc.thinking_text.is_empty() { store.set_thinking_text(&streaming_acc.thinking_text).await; } if let Some(ref gm) = streaming_acc.grounding_metadata { store.set_grounding(gm.clone()).await; } if streaming_acc.is_complete { info!( response_text_len = streaming_acc.response_text.len(), thinking_text_len = streaming_acc.thinking_text.len(), "MITM: response complete (initial body) — marking store" ); store.mark_response_complete(); } } else if streaming_acc.is_complete { debug!("MITM: skipping store write — generation stale (initial body)"); } } // Forward to client — rewrite function calls if custom tools are injected let forward_buf = if modify_requests { if let Some(modified) = super::modify::modify_response_chunk(&header_buf) { modified } else { header_buf.clone() } } else { header_buf.clone() }; if let Err(e) = client.write_all(&forward_buf).await { warn!(error = %e, "MITM: write to client failed"); break; } if let Some(cl) = response_content_length { if response_body_buf.len() >= cl { break; } } // Check chunked terminator in initial body if is_chunked && has_chunked_terminator(&response_body_buf) { debug!(domain, "MITM: chunked response complete (initial)"); break; } continue; } continue; } // ── Response body interception ──────────────────────────────── if is_streaming_response { let s = String::from_utf8_lossy(chunk); parse_streaming_chunk(&s, &mut streaming_acc); // Only write to store if our generation is still current. let gen_valid = !won_gate || store.current_generation() == conn_generation; if gen_valid { // Store captured function calls (drain to avoid re-storing on next chunk) if !streaming_acc.function_calls.is_empty() { let calls: Vec<_> = streaming_acc.function_calls.drain(..).collect(); for fc in &calls { store .record_function_call(cascade_hint.as_deref(), fc.clone()) .await; } store.set_last_function_calls(calls.clone()).await; info!( "MITM: stored {} function call(s) from body chunk", calls.len() ); } // Capture response + thinking text + grounding into MitmStore if !streaming_acc.response_text.is_empty() { store.set_response_text(&streaming_acc.response_text).await; } if !streaming_acc.thinking_text.is_empty() { store.set_thinking_text(&streaming_acc.thinking_text).await; } if let Some(ref gm) = streaming_acc.grounding_metadata { store.set_grounding(gm.clone()).await; } if streaming_acc.is_complete { info!( response_text_len = streaming_acc.response_text.len(), thinking_text_len = streaming_acc.thinking_text.len(), function_calls = streaming_acc.function_calls.len(), "MITM: response complete — marking store" ); store.mark_response_complete(); } } else if streaming_acc.is_complete { debug!("MITM: skipping store write — generation stale (body chunk)"); } } // Forward chunk to client (LS) — rewrite function calls if custom tools let forward_chunk = if modify_requests { if let Some(modified) = super::modify::modify_response_chunk(chunk) { modified } else { chunk.to_vec() } } else { chunk.to_vec() }; if let Err(e) = client.write_all(&forward_chunk).await { warn!(error = %e, "MITM: write to client failed"); break; } response_body_buf.extend_from_slice(chunk); if let Some(cl) = response_content_length { if response_body_buf.len() >= cl { break; } } if is_chunked && has_chunked_terminator(&response_body_buf) { debug!(domain, "MITM: chunked response complete"); break; } } // Flush client let _ = client.flush().await; // Capture usage data if is_streaming_response { // Store grounding metadata before consuming the accumulator if let Some(ref gm) = streaming_acc.grounding_metadata { store.set_grounding(gm.clone()).await; } if streaming_acc.is_complete || streaming_acc.output_tokens > 0 { // Function calls are stored immediately when detected (above), // so no need to store them again here. let usage = streaming_acc.into_usage(); store.record_usage(cascade_hint.as_deref(), usage).await; } } else if !response_body_buf.is_empty() { if let Some(usage) = parse_non_streaming_response(&response_body_buf) { store.record_usage(cascade_hint.as_deref(), usage).await; } } // If upstream closed, drop the connection so next iteration reconnects if !upstream_ok { upstream = None; } } // end keep-alive loop } /// Handle a passthrough connection: transparent TCP tunnel to upstream. async fn handle_passthrough(mut client: TcpStream, domain: &str, port: u16) -> Result<(), String> { trace!(domain, port, "MITM: transparent tunnel"); let mut upstream = TcpStream::connect(format!("{domain}:{port}")) .await .map_err(|e| format!("Connect to {domain}:{port}: {e}"))?; // Bidirectional copy match tokio::io::copy_bidirectional(&mut client, &mut upstream).await { Ok((client_to_server, server_to_client)) => { trace!( domain, client_to_server, server_to_client, "MITM: tunnel closed" ); } Err(e) => { trace!(domain, error = %e, "MITM: tunnel error (likely clean close)"); } } Ok(()) } /// Detect end of HTTP chunked transfer encoding. /// A chunked response ends with "0\r\n\r\n" (zero-length chunk + empty trailer). /// We check the tail of the buffer for this pattern. fn has_chunked_terminator(body: &[u8]) -> bool { // The minimal terminator is "0\r\n\r\n" (5 bytes) if body.len() < 5 { return false; } // Check last 7 bytes to account for possible trailing whitespace let tail = if body.len() > 7 { &body[body.len() - 7..] } else { body }; // Look for \r\n0\r\n\r\n anywhere in the tail tail.windows(5).any(|w| w == b"0\r\n\r\n") } /// Check if buffer contains a complete HTTP request (headers + full body). /// Uses `httparse` for zero-copy, case-insensitive header parsing. fn has_complete_http_request(buf: &[u8]) -> bool { let mut headers = [httparse::EMPTY_HEADER; 64]; let mut req = httparse::Request::new(&mut headers); let headers_end = match req.parse(buf) { Ok(httparse::Status::Complete(n)) => n, _ => return false, // Incomplete or parse error — need more data }; // Look for Content-Length for header in req.headers.iter() { if header.name.eq_ignore_ascii_case("content-length") { if let Ok(val) = std::str::from_utf8(header.value) { if let Ok(len) = val.trim().parse::() { return buf.len() >= headers_end + len; } } } if header.name.eq_ignore_ascii_case("transfer-encoding") { if let Ok(val) = std::str::from_utf8(header.value) { if val.trim().eq_ignore_ascii_case("chunked") { let body = &buf[headers_end..]; return body.len() >= 5 && body.ends_with(b"0\r\n\r\n"); } } } } // No Content-Length or Transfer-Encoding — no body expected (e.g., GET) true } /// Find the end of HTTP headers (position after \r\n\r\n). fn find_headers_end(buf: &[u8]) -> Option { buf.windows(4) .position(|w| w == b"\r\n\r\n") .map(|pos| pos + 4) } /// Parse HTTP request metadata from raw bytes using `httparse`. /// Returns (headers_end, content_length, is_streaming_request). fn parse_http_request_meta(buf: &[u8]) -> (usize, usize, bool) { let mut headers = [httparse::EMPTY_HEADER; 64]; let mut req = httparse::Request::new(&mut headers); let headers_end = match req.parse(buf) { Ok(httparse::Status::Complete(n)) => n, _ => { // Fallback if httparse can't parse return (find_headers_end(buf).unwrap_or(buf.len()), 0, false); } }; let mut content_length = 0usize; for header in req.headers.iter() { if header.name.eq_ignore_ascii_case("content-length") { if let Ok(val) = std::str::from_utf8(header.value) { content_length = val.trim().parse().unwrap_or(0); } } } // Check if request body asks for streaming let is_streaming = if headers_end < buf.len() { let body_str = String::from_utf8_lossy(&buf[headers_end..]); body_str.contains("\"stream\":true") || body_str.contains("\"stream\": true") } else { false }; (headers_end, content_length, is_streaming) } /// Rewrite the Content-Length header in raw HTTP headers to match a new body size. /// If no Content-Length header is found, returns the headers unchanged. fn update_content_length(headers: &str, new_body_len: usize) -> String { use regex::Regex; let re = Regex::new(r"(?im)^content-length:\s*\d+").unwrap(); if re.is_match(headers) { re.replace(headers, format!("Content-Length: {new_body_len}")) .to_string() } else { headers.to_string() } }