Files
zerogravity/src/mitm/proxy.rs

1120 lines
44 KiB
Rust

//! MITM proxy server: handles CONNECT tunnels and transparent TLS interception.
//!
//! Supports two modes:
//! 1. **HTTP CONNECT** — standard proxy mode, LS sends `CONNECT host:port`
//! 2. **Transparent (iptables)** — raw TLS arrives via REDIRECT, SNI extracted from ClientHello
//!
//! For intercepted domains, terminates TLS with our CA-signed cert,
//! reads/modifies the request, forwards to the real upstream, and captures usage.
//!
//! For non-intercepted domains, acts as a transparent TCP tunnel.
use super::ca::MitmCa;
use super::intercept::{
extract_cascade_hint, parse_non_streaming_response, parse_streaming_chunk, StreamingAccumulator,
};
use super::store::MitmStore;
use std::sync::Arc;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::{TcpListener, TcpStream};
use tokio_rustls::TlsAcceptor;
use tracing::{debug, error, info, trace, warn};
/// Hosts whose TLS we terminate so the traffic can be inspected.
///
/// Each entry also covers its subdomains (`*.entry`) and dash-prefixed regional
/// variants such as `us-central1-aiplatform.googleapis.com`; the matching itself
/// lives in `should_intercept_domain`.
const INTERCEPT_DOMAINS: &[&str] = &[
    "cloudcode-pa.googleapis.com",
    "aiplatform.googleapis.com",
    "api.anthropic.com",
    "speech.googleapis.com",
    "modelarmor.googleapis.com",
];
/// Hosts that are always tunnelled byte-for-byte and never TLS-terminated.
///
/// These are matched by full host name (not by suffix) and take priority over
/// `INTERCEPT_DOMAINS`.
const PASSTHROUGH_DOMAINS: &[&str] = &[
    "oauth2.googleapis.com",
    "accounts.google.com",
    "storage.googleapis.com",
    "www.googleapis.com",
    "firebaseinstallations.googleapis.com",
    "crashlyticsreports-pa.googleapis.com",
    "play.googleapis.com",
    "update.googleapis.com",
    "dl.google.com",
];
/// Configuration for the MITM proxy, consumed by `run`.
///
/// `MitmConfig::default()` binds an OS-assigned port and leaves traffic unmodified.
#[derive(Debug, Clone, Default)]
pub struct MitmConfig {
    /// Port to bind on `127.0.0.1` (0 = let the OS choose; `run` returns the actual port).
    pub port: u16,
    /// Whether intercepted HTTP/1.1 requests and responses may be rewritten.
    pub modify_requests: bool,
}
/// Run the MITM proxy server.
///
/// Returns (port, task_handle) — port it's listening on, handle to abort on shutdown.
pub async fn run(
ca: Arc<MitmCa>,
store: MitmStore,
config: MitmConfig,
) -> Result<(u16, tokio::task::JoinHandle<()>), String> {
let listener = TcpListener::bind(format!("127.0.0.1:{}", config.port))
.await
.map_err(|e| format!("MITM bind failed: {e}"))?;
let port = listener
.local_addr()
.map_err(|e| format!("MITM local_addr failed: {e}"))?
.port();
info!(port, "MITM proxy listening");
let modify_requests = config.modify_requests;
let handle = tokio::spawn(async move {
loop {
match listener.accept().await {
Ok((stream, addr)) => {
trace!(?addr, "MITM: new connection");
let ca = ca.clone();
let store = store.clone();
tokio::spawn(async move {
if let Err(e) = handle_connection(stream, ca, store, modify_requests).await
{
warn!(error = %e, "MITM connection error");
}
});
}
Err(e) => {
error!(error = %e, "MITM accept error");
}
}
}
});
Ok((port, handle))
}
/// Handle a single accepted client connection.
///
/// The first byte (peeked, not consumed) selects the mode:
/// 1. **Transparent / iptables redirect** (`0x16`, a TLS handshake record): the
///    target host comes from the ClientHello's SNI. The connection is then
///    intercepted, or tunnelled to `host:443`.
/// 2. **HTTP CONNECT** (anything else): a single read of up to 8 KiB must start with
///    `CONNECT host[:port] HTTP/1.x`. We answer `200 Connection Established`, then
///    intercept or tunnel. A missing or unparseable port defaults to 443.
///
/// # Errors
/// Returns a description when peeking, reading or writing the client socket fails,
/// when a transparent ClientHello has no readable SNI, or when the selected handler
/// fails. A connection closed before any data arrives, or a non-CONNECT request
/// (answered with `400 Bad Request`), yields `Ok(())`.
async fn handle_connection(
    mut stream: TcpStream,
    ca: Arc<MitmCa>,
    store: MitmStore,
    modify_requests: bool,
) -> Result<(), String> {
    // Peek rather than read, so the TLS path can still hand the untouched
    // ClientHello to the acceptor.
    let mut first = [0u8; 1];
    let n = stream
        .peek(&mut first)
        .await
        .map_err(|e| format!("Peek failed: {e}"))?;
    if n == 0 {
        return Ok(());
    }

    if first[0] == 0x16 {
        // Transparent mode. `peek` only returns bytes that have already arrived, so a
        // ClientHello split across segments may be incomplete here.
        let mut hello_buf = vec![0u8; 16384];
        let n = stream
            .peek(&mut hello_buf)
            .await
            .map_err(|e| format!("Peek ClientHello: {e}"))?;
        // Without SNI the original destination is unknown. Fail fast instead of
        // trying to tunnel to a made-up host name.
        let domain = match extract_sni(&hello_buf[..n]) {
            Some(domain) => domain,
            None => {
                return Err(
                    "Transparent redirect: no SNI in ClientHello, cannot determine destination"
                        .to_string(),
                )
            }
        };
        info!(domain, "MITM: transparent redirect (iptables)");
        if should_intercept_domain(&domain) {
            handle_intercepted(stream, &domain, ca, store, modify_requests).await
        } else {
            // Only SNI is known (not the original destination address), so tunnel by name.
            handle_passthrough(stream, &domain, 443).await
        }
    } else {
        // Standard HTTP CONNECT proxy mode
        let mut buf = vec![0u8; 8192];
        let n = stream
            .read(&mut buf)
            .await
            .map_err(|e| format!("Read CONNECT: {e}"))?;
        if n == 0 {
            return Ok(());
        }
        let request = String::from_utf8_lossy(&buf[..n]);
        let first_line = request.lines().next().unwrap_or("");
        // Expect "CONNECT host:port HTTP/1.1"
        let parts: Vec<&str> = first_line.split_whitespace().collect();
        if parts.len() < 3 || parts[0] != "CONNECT" {
            let _ = stream.write_all(b"HTTP/1.1 400 Bad Request\r\n\r\n").await;
            return Ok(());
        }
        let host_port = parts[1];
        let (domain, port) = match host_port.rsplit_once(':') {
            Some((host, port)) => (host, port.parse::<u16>().unwrap_or(443)),
            None => (host_port, 443),
        };
        debug!(domain, "MITM: CONNECT request");
        let should_intercept = should_intercept_domain(domain);
        stream
            .write_all(b"HTTP/1.1 200 Connection Established\r\n\r\n")
            .await
            .map_err(|e| format!("Write 200: {e}"))?;
        if should_intercept {
            handle_intercepted(stream, domain, ca, store, modify_requests).await
        } else {
            handle_passthrough(stream, domain, port).await
        }
    }
}
/// Pull the SNI host name out of a raw TLS ClientHello.
///
/// Walks the record header, handshake header, and the fixed and length-prefixed
/// ClientHello fields, then scans the extensions for `server_name` (type `0x0000`)
/// and returns the first name entry decoded as UTF-8. Declared lengths are clamped
/// to the bytes actually present. Any structural problem (not a handshake record,
/// not a ClientHello, truncated data, no SNI, malformed SNI, non-UTF-8 name) yields
/// `None`.
fn extract_sni(buf: &[u8]) -> Option<String> {
    /// Big-endian u16 at `at`, or `None` when it would run past the slice.
    fn be_u16(bytes: &[u8], at: usize) -> Option<usize> {
        bytes
            .get(at..at + 2)
            .map(|pair| usize::from(u16::from_be_bytes([pair[0], pair[1]])))
    }

    // Record layer: content_type(1) + version(2) + length(2).
    if buf.first() != Some(&0x16) || buf.len() < 5 {
        return None;
    }
    let record_len = be_u16(buf, 3)?;
    let record = &buf[5..];
    let handshake = &record[..record_len.min(record.len())];

    // Handshake header: msg_type(1) + length(3). msg_type 0x01 is ClientHello.
    if handshake.first() != Some(&0x01) || handshake.len() < 4 {
        return None;
    }
    let hello_len = (usize::from(handshake[1]) << 16)
        | (usize::from(handshake[2]) << 8)
        | usize::from(handshake[3]);
    let hello = &handshake[4..];
    let body = &hello[..hello_len.min(hello.len())];

    // Skip legacy_version(2) + random(32). Then session_id has a 1-byte length.
    let mut cursor = 34;
    let session_id_len = usize::from(*body.get(cursor)?);
    cursor += 1 + session_id_len;
    // cipher_suites: 2-byte length prefix.
    let cipher_suites_len = be_u16(body, cursor)?;
    cursor += 2 + cipher_suites_len;
    // compression_methods: 1-byte length prefix.
    let compression_len = usize::from(*body.get(cursor)?);
    cursor += 1 + compression_len;
    // extensions: 2-byte total length, clamped to what was received.
    let extensions_len = be_u16(body, cursor)?;
    cursor += 2;
    let extensions_end = cursor + extensions_len.min(body.len() - cursor);

    while cursor + 4 <= extensions_end {
        let ext_type = be_u16(body, cursor)?;
        let ext_data_len = be_u16(body, cursor + 2)?;
        cursor += 4;
        if ext_type != 0x0000 {
            cursor += ext_data_len;
            continue;
        }
        // server_name: list_len(2) + name_type(1) + name_len(2) + name
        if ext_data_len < 5 || cursor + ext_data_len > extensions_end {
            return None;
        }
        let name_len = be_u16(body, cursor + 3)?;
        let name_start = cursor + 5;
        let name_end = name_start + name_len;
        if name_end > extensions_end {
            return None;
        }
        return String::from_utf8(body[name_start..name_end].to_vec()).ok();
    }
    None
}
/// Decide whether traffic for `domain` should be TLS-terminated and inspected.
///
/// Matching is ASCII case-insensitive, because DNS names are case-insensitive
/// (RFC 4343). This keeps a mixed-case CONNECT host or SNI value from slipping
/// past either list.
///
/// Rules, in order:
/// 1. A name listed in `PASSTHROUGH_DOMAINS` is never intercepted.
/// 2. An entry of `INTERCEPT_DOMAINS` matches when equal to the name, as a parent
///    domain (`x.<entry>`), or with a regional dash prefix (`us-central1-<entry>`).
/// 3. Everything else is tunnelled untouched.
fn should_intercept_domain(domain: &str) -> bool {
    let domain = domain.to_ascii_lowercase();
    // Never intercept passthrough domains
    if PASSTHROUGH_DOMAINS.iter().any(|&pt| domain == pt) {
        return false;
    }
    // Exact match, subdomain, or regional prefix. Uses strip_suffix instead of
    // allocating ".{entry}" / "-{entry}" strings per entry.
    INTERCEPT_DOMAINS.iter().any(|&intercept| {
        domain == intercept
            || matches!(
                domain.strip_suffix(intercept),
                Some(prefix) if prefix.ends_with('.') || prefix.ends_with('-')
            )
    })
}
/// TLS-terminate an intercepted connection and dispatch it by negotiated ALPN.
///
/// A server config for `domain` is obtained from `ca`, and the client handshake
/// must finish within 10 seconds. Afterwards:
/// - ALPN `h2` goes to `h2_handler::handle_h2_connection` (gRPC over HTTP/2).
/// - Anything else, including no ALPN, goes to `handle_http_over_tls` (HTTP/1.1 REST/SSE).
///
/// # Errors
/// Propagates errors from `ca.server_config_for_domain`, reports a failed or
/// timed-out client handshake, and otherwise returns the selected handler's result.
async fn handle_intercepted(
    stream: TcpStream,
    domain: &str,
    ca: Arc<MitmCa>,
    store: MitmStore,
    modify_requests: bool,
) -> Result<(), String> {
    const HANDSHAKE_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(10);

    info!(domain, "MITM: intercepting TLS");
    let acceptor = TlsAcceptor::from(ca.server_config_for_domain(domain).await?);

    let handshake = tokio::time::timeout(HANDSHAKE_TIMEOUT, acceptor.accept(stream)).await;
    let tls_stream = match handshake {
        Err(_elapsed) => {
            warn!(domain, "MITM: TLS handshake TIMED OUT after 10s");
            return Err(format!("TLS handshake timed out for {domain}"));
        }
        Ok(Err(e)) => {
            warn!(domain, error = %e, "MITM: TLS handshake FAILED (client rejected cert?)");
            return Err(format!(
                "TLS handshake with client failed for {domain}: {e}"
            ));
        }
        Ok(Ok(established)) => established,
    };

    let (_, server_conn) = tls_stream.get_ref();
    let alpn: Option<String> = server_conn
        .alpn_protocol()
        .map(|proto| String::from_utf8_lossy(proto).to_string());
    info!(domain, alpn = ?alpn, "MITM: TLS handshake successful ✓");

    if alpn.as_deref() == Some("h2") {
        info!(domain, "MITM: routing to HTTP/2 handler (gRPC)");
        super::h2_handler::handle_h2_connection(tls_stream, domain.to_string(), store).await
    } else {
        info!(domain, "MITM: routing to HTTP/1.1 handler");
        handle_http_over_tls(tls_stream, domain, store, modify_requests).await
    }
}
/// Handle HTTP/1.1 traffic over the decrypted TLS connection from the client.
///
/// Loops to handle multiple requests on the same connection (HTTP keep-alive).
/// For each request it:
/// - reads the full request,
/// - optionally rewrites it,
/// - forwards it over a lazily (re)connected upstream TLS connection,
/// - relays the response back to the client as it arrives.
///
/// While relaying, it captures response text, thinking text, function calls,
/// grounding metadata, upstream errors and usage into `store`.
///
/// Parameters:
/// - `client`: TLS stream already accepted from the client.
/// - `domain`: upstream host name, used for resolution, the TLS server name and logs.
/// - `store`: shared capture/state store.
/// - `modify_requests`: enables rewriting of LLM agent request bodies
///   (`modify::modify_request`) and of response bytes (`modify::modify_response_chunk`).
///
/// Returns `Ok(())` when:
/// - the client closes the connection,
/// - a client read fails, or
/// - the client sends nothing for 60s while a request is awaited.
///
/// Returns `Err` when:
/// - a request is only partially received within 30s, or
/// - connecting to or writing to the upstream fails. The write is retried once on
///   a fresh connection.
async fn handle_http_over_tls(
mut client: tokio_rustls::server::TlsStream<TcpStream>,
domain: &str,
store: MitmStore,
modify_requests: bool,
) -> Result<(), String> {
// Scratch buffer reused for every client read and every upstream read on this connection.
let mut tmp = vec![0u8; 32768];
// Build upstream TLS connector once for this connection
// Trust anchors come from the platform's native store. Only `.certs` of the load result
// is read, and certificates that `root_store.add` rejects are silently skipped.
// NOTE(review): this is rebuilt for every intercepted connection; consider caching it.
let mut root_store = rustls::RootCertStore::empty();
let native_certs = rustls_native_certs::load_native_certs();
for cert in native_certs.certs {
let _ = root_store.add(cert);
}
let upstream_config = Arc::new(
rustls::ClientConfig::builder()
.with_root_certificates(root_store)
.with_no_client_auth(),
);
// Reusable upstream connection — created lazily, reconnected if stale
let mut upstream: Option<tokio_rustls::client::TlsStream<TcpStream>> = None;
/// Connect (or reconnect) to the real upstream via TLS.
///
/// Bypasses /etc/hosts by resolving via direct DNS query (dig @8.8.8.8),
/// then falls back to cached IPs file, then to normal system resolution.
///
/// The TCP connect and the TLS handshake each get their own 15s timeout.
async fn connect_upstream(
domain: &str,
config: &Arc<rustls::ClientConfig>,
) -> Result<tokio_rustls::client::TlsStream<TcpStream>, String> {
let connector = tokio_rustls::TlsConnector::from(config.clone());
// Try to resolve the real IP, bypassing /etc/hosts
let addr = resolve_upstream(domain).await;
info!(domain, addr = %addr, "MITM: connecting upstream");
let tcp = match tokio::time::timeout(
std::time::Duration::from_secs(15),
TcpStream::connect(&addr),
)
.await
{
Ok(Ok(s)) => s,
Ok(Err(e)) => return Err(format!("Connect to upstream {domain} ({addr}): {e}")),
Err(_) => return Err(format!("Connect to upstream {domain} ({addr}): timed out")),
};
// The TLS server name (SNI + certificate check) is the original domain, not the resolved IP.
let server_name = rustls::pki_types::ServerName::try_from(domain.to_string())
.map_err(|e| format!("Invalid server name: {e}"))?;
match tokio::time::timeout(
std::time::Duration::from_secs(15),
connector.connect(server_name, tcp),
)
.await
{
Ok(Ok(s)) => {
info!(domain, "MITM: upstream TLS connected ✓");
Ok(s)
}
Ok(Err(e)) => Err(format!("TLS connect to upstream {domain}: {e}")),
Err(_) => Err(format!("TLS connect to upstream {domain}: timed out")),
}
}
/// Resolve upstream IP bypassing /etc/hosts.
///
/// Lookup order:
/// 1. the first IPv4 line printed by `dig +short @8.8.8.8`;
/// 2. a `domain=ip` line in the cached IPs file;
/// 3. the bare host name.
///
/// The returned address always uses port 443.
/// NOTE(review): no timeout is applied to the `dig` subprocess. Also, `domain` is passed
/// straight to `dig` as an argument; confirm callers can never supply a value dig would
/// parse as an option (e.g. one starting with '-').
async fn resolve_upstream(domain: &str) -> String {
// 1. Try dig @8.8.8.8 (bypasses /etc/hosts)
// If dig cannot be spawned, this step is skipped silently.
if let Ok(output) = tokio::process::Command::new("dig")
.args(["+short", "@8.8.8.8", domain])
.output()
.await
{
let out = String::from_utf8_lossy(&output.stdout);
if let Some(ip) = out
.lines()
.find(|l| l.parse::<std::net::Ipv4Addr>().is_ok())
{
return format!("{ip}:443");
}
}
// 2. Try cached IPs file (written by dns-redirect.sh install)
if let Ok(contents) = tokio::fs::read_to_string("/tmp/antigravity-mitm-real-ips").await {
for line in contents.lines() {
if let Some((d, ip)) = line.split_once('=') {
if d == domain {
return format!("{ip}:443");
}
}
}
}
// 3. Fallback to normal resolution (may hit /etc/hosts)
format!("{domain}:443")
}
// Keep-alive loop: handle multiple requests on this connection
loop {
// ── Read the HTTP request from the client ─────────────────────────
let mut request_buf = Vec::with_capacity(1024 * 64);
// 60s timeout on initial read (LS may open connection without sending immediately)
const IDLE_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(60);
// Accumulate reads until has_complete_http_request reports a full head + body.
loop {
let read_result = if request_buf.is_empty() {
// First read — apply idle timeout
match tokio::time::timeout(IDLE_TIMEOUT, client.read(&mut tmp)).await {
Ok(r) => r,
Err(_) => {
// Idle timeout — connection pool warmup, no data sent
debug!(domain, "MITM: client idle timeout (60s), closing");
return Ok(());
}
}
} else {
// Subsequent reads — wait up to 30s for rest of request
match tokio::time::timeout(
std::time::Duration::from_secs(30),
client.read(&mut tmp),
)
.await
{
Ok(r) => r,
Err(_) => {
warn!(domain, "MITM: partial request read timed out");
return Err("Partial request read timed out".into());
}
}
};
let n = match read_result {
Ok(0) => return Ok(()), // Client closed connection cleanly
Ok(n) => n,
Err(e) => {
// Connection reset / broken pipe is normal for keep-alive end
debug!(domain, error = %e, "MITM: client read finished");
return Ok(());
}
};
request_buf.extend_from_slice(&tmp[..n]);
// Check if we have the full request (headers + body)
if has_complete_http_request(&request_buf) {
break;
}
}
// Defensive: the read loop above only breaks after appending at least one byte.
if request_buf.is_empty() {
return Ok(());
}
// Parse the HTTP request to find headers and body
// content_length is only used for logging below; the streaming flag is unused.
let (headers_end, content_length, _is_streaming_request) =
parse_http_request_meta(&request_buf);
// Try to extract cascade hint from request body
// (taken from the raw body bytes as received, i.e. before any dechunking)
let cascade_hint = if headers_end < request_buf.len() {
extract_cascade_hint(&request_buf[headers_end..])
} else {
None
};
// Extract request method and path for logging
let req_path = {
let mut headers = [httparse::EMPTY_HEADER; 64];
let mut req = httparse::Request::new(&mut headers);
match req.parse(&request_buf) {
Ok(httparse::Status::Complete(_)) => {
format!("{} {}", req.method.unwrap_or("?"), req.path.unwrap_or("?"))
}
_ => "?".to_string(),
}
};
// Generation tracking for store write guards
// won_gate: this request took the in-flight slot (only possible while custom tools are set).
// conn_generation: the store generation this request's data belongs to. When won_gate is
// true, store writes during the response are skipped once the store's generation differs.
let mut won_gate = false;
let mut conn_generation = store.current_generation();
// Log LLM calls at info, everything else at debug
if req_path.contains("streamGenerateContent") {
let body_len = request_buf.len() - headers_end;
info!(
domain,
req_path = %req_path,
body_len,
cascade = ?cascade_hint,
"MITM: forwarding LLM request"
);
// ── Atomic in-flight gate ─────────────────────────────────
// The LS opens multiple connections and sends parallel requests.
// When custom tools are active, only the FIRST request wins the
// atomic compare_exchange. All others get fake STOP responses.
let has_tools = store.get_tools().await.is_some();
won_gate = if has_tools {
// NOTE(review): nothing in this function clears the in-flight mark; presumably
// MitmStore releases it elsewhere. Confirm the release path.
if !store.try_mark_request_in_flight() {
info!("MITM: blocking LS request — another request already in-flight");
let fake_response = "HTTP/1.1 200 OK\r\n\
Content-Type: text/event-stream\r\n\
Transfer-Encoding: chunked\r\n\
\r\n";
let fake_sse = "data: {\"response\":{\"candidates\":[{\"content\":{\"parts\":[{\"text\":\"Request handled.\"}],\"role\":\"model\"},\"finishReason\":\"STOP\"}],\"usageMetadata\":{\"promptTokenCount\":0,\"candidatesTokenCount\":1,\"totalTokenCount\":1}}}\n\ndata: [DONE]\n\n";
let chunked_body = super::modify::rechunk(fake_sse.as_bytes());
let mut response = fake_response.as_bytes().to_vec();
response.extend_from_slice(&chunked_body);
if let Err(e) = client.write_all(&response).await {
warn!(error = %e, "MITM: failed to write fake response");
}
let _ = client.flush().await;
// The blocked request is never forwarded upstream; wait for the next one.
continue;
}
true
} else {
false
};
// Snapshot the generation at gate-win time. If it changes later,
// another completions turn started and our data is stale.
conn_generation = store.current_generation();
// ── Request modification ─────────────────────────────────────
// Dechunk body → check if agent request → modify → rechunk
if modify_requests && body_len > 0 {
let body_slice = &request_buf[headers_end..];
let raw_body = super::modify::dechunk(body_slice);
// Only modify "agent" requests, not "checkpoint" (LS internal)
// Plain substring test on the decoded text (compact and spaced JSON forms).
let body_str = String::from_utf8_lossy(&raw_body);
let is_agent = body_str.contains("\"requestType\":\"agent\"")
|| body_str.contains("\"requestType\": \"agent\"");
if is_agent {
// Build ToolContext from store
// NOTE(review): the take_* calls presumably consume pending state. Confirm that
// state is not lost when modify_request below returns None.
let tools = store.get_tools().await;
let tool_config = store.get_tool_config().await;
let pending_results = store.take_tool_results().await;
let last_calls = store.get_last_function_calls().await;
let generation_params = store.get_generation_params().await;
let pending_image = store.take_pending_image().await;
let tool_ctx = if tools.is_some()
|| !pending_results.is_empty()
|| generation_params.is_some()
|| pending_image.is_some()
{
Some(super::modify::ToolContext {
tools,
tool_config,
pending_results,
last_calls,
generation_params,
pending_image,
})
} else {
None
};
if let Some(modified_body) =
super::modify::modify_request(&raw_body, tool_ctx.as_ref())
{
// Rebuild request_buf: headers (with updated Content-Length) + rechunked modified body
let new_chunked = super::modify::rechunk(&modified_body);
// Fix Content-Length header to match new body size
// NOTE(review): Content-Length is set to the re-chunked byte length.
// update_content_length leaves the headers unchanged when no Content-Length
// header is present.
let header_str = String::from_utf8_lossy(&request_buf[..headers_end]);
let updated_headers = update_content_length(&header_str, new_chunked.len());
let mut new_buf = updated_headers.into_bytes();
new_buf.extend_from_slice(&new_chunked);
request_buf = new_buf;
// In-flight already marked atomically above
}
}
}
} else {
debug!(
domain,
req_path = %req_path,
content_length,
"MITM: forwarding request"
);
}
// ── Ensure upstream connection is alive ──────────────────────────────
// Lazily connect on first request, or reconnect if the previous connection died
let conn = match upstream.as_mut() {
Some(c) => c,
None => {
let c = connect_upstream(domain, &upstream_config).await?;
upstream.insert(c)
}
};
// Forward the request — if write fails, reconnect and retry once
if let Err(e) = conn.write_all(&request_buf).await {
debug!(domain, error = %e, "MITM: upstream write failed, reconnecting");
let c = connect_upstream(domain, &upstream_config).await?;
let conn = upstream.insert(c);
conn.write_all(&request_buf)
.await
.map_err(|e| format!("Write to upstream (retry): {e}"))?;
}
// NOTE(review): a write can succeed on a connection the server has already closed. The
// read loop below then hits EOF before any byte, and no response is relayed for this
// request (there is no retry on read).
let conn = upstream.as_mut().unwrap();
// ── Stream response back to client ──────────────────────────────────
// ALWAYS forward data to client immediately (no buffering).
// Buffer body on the side for usage parsing.
let mut streaming_acc = StreamingAccumulator::new();
let mut is_streaming_response = false;
let mut headers_parsed = false;
let mut upstream_ok = true;
let mut response_body_buf = Vec::new();
let mut response_content_length: Option<usize> = None;
let mut is_chunked = false;
let mut got_first_byte = false;
let mut header_buf = Vec::with_capacity(8192);
loop {
// 15s idle timeout after first byte, 60s for initial response
let timeout = if got_first_byte {
std::time::Duration::from_secs(15)
} else {
std::time::Duration::from_secs(60)
};
let n = match tokio::time::timeout(timeout, conn.read(&mut tmp)).await {
Ok(Ok(0)) => {
upstream_ok = false;
break;
}
Ok(Ok(n)) => n,
Ok(Err(e)) => {
debug!(domain, error = %e, "MITM: upstream read ended");
upstream_ok = false;
break;
}
Err(_) => {
if got_first_byte {
debug!(domain, "MITM: response idle timeout (complete)");
} else {
warn!(domain, "MITM: no upstream response in 60s");
}
upstream_ok = false;
break;
}
};
got_first_byte = true;
let chunk = &tmp[..n];
if !headers_parsed {
// Header phase: buffer until the blank line, then parse status and framing headers.
header_buf.extend_from_slice(chunk);
if let Some(_hdr_end) = find_headers_end(&header_buf) {
let mut resp_headers = [httparse::EMPTY_HEADER; 64];
let mut resp = httparse::Response::new(&mut resp_headers);
// If httparse does not report a complete head, fall back to the raw \r\n\r\n position.
let hdr_end = match resp.parse(&header_buf) {
Ok(httparse::Status::Complete(n)) => n,
_ => _hdr_end,
};
let mut content_type = String::new();
for header in resp.headers.iter() {
if header.name.eq_ignore_ascii_case("content-type") {
if let Ok(v) = std::str::from_utf8(header.value) {
content_type = v.to_string();
if v.contains("text/event-stream") {
is_streaming_response = true;
}
}
}
if header.name.eq_ignore_ascii_case("content-length") {
if let Ok(v) = std::str::from_utf8(header.value) {
response_content_length = v.trim().parse().ok();
}
}
// "Connection: close" means this upstream connection must not be reused.
if header.name.eq_ignore_ascii_case("connection") {
if let Ok(v) = std::str::from_utf8(header.value) {
if v.trim().eq_ignore_ascii_case("close") {
upstream_ok = false;
}
}
}
if header.name.eq_ignore_ascii_case("transfer-encoding") {
if let Ok(v) = std::str::from_utf8(header.value) {
if v.trim().eq_ignore_ascii_case("chunked") {
is_chunked = true;
}
}
}
}
if is_streaming_response {
info!(domain,
content_type = %content_type,
status = resp.code, "MITM: streaming response");
} else {
debug!(domain,
content_type = %content_type,
status = resp.code, "MITM: response headers");
}
headers_parsed = true;
// Capture upstream errors for forwarding to client
// Only the body bytes that arrived together with the headers are captured. For a
// chunked body they still include the chunk-size lines, in which case the JSON
// parse below yields no message/status.
let http_status = resp.code.unwrap_or(0) as u16;
if http_status >= 400 {
let body_str = String::from_utf8_lossy(&header_buf[hdr_end..]).to_string();
warn!(domain, status = http_status, body = %body_str, "MITM: upstream error response");
// Parse Google's error JSON: {"error": {"code": N, "message": "...", "status": "..."}}
let (message, error_status) =
serde_json::from_str::<serde_json::Value>(&body_str)
.ok()
.and_then(|v| {
let err = v.get("error")?;
let msg = err
.get("message")
.and_then(|m| m.as_str())
.map(|s| s.to_string());
let status = err
.get("status")
.and_then(|s| s.as_str())
.map(|s| s.to_string());
Some((msg, status))
})
.unwrap_or((None, None));
store
.set_upstream_error(super::store::UpstreamError {
status: http_status,
body: body_str,
message,
error_status,
})
.await;
}
// Save body for usage parsing
response_body_buf.extend_from_slice(&header_buf[hdr_end..]);
// Parse ORIGINAL initial body for MITM interception
if is_streaming_response && hdr_end < header_buf.len() {
let body = String::from_utf8_lossy(&header_buf[hdr_end..]);
parse_streaming_chunk(&body, &mut streaming_acc);
// Only write to store if our generation is still current.
// If another completions turn started, our data is stale.
let gen_valid = !won_gate || store.current_generation() == conn_generation;
if gen_valid {
// Store captured function calls (drain to avoid re-storing on next chunk)
if !streaming_acc.function_calls.is_empty() {
let calls: Vec<_> =
streaming_acc.function_calls.drain(..).collect();
for fc in &calls {
store
.record_function_call(cascade_hint.as_deref(), fc.clone())
.await;
}
store.set_last_function_calls(calls.clone()).await;
info!(
"MITM: stored {} function call(s) from initial body",
calls.len()
);
}
// Capture response + thinking text + grounding into MitmStore
if !streaming_acc.response_text.is_empty() {
store.set_response_text(&streaming_acc.response_text).await;
}
if !streaming_acc.thinking_text.is_empty() {
store.set_thinking_text(&streaming_acc.thinking_text).await;
}
if let Some(ref gm) = streaming_acc.grounding_metadata {
store.set_grounding(gm.clone()).await;
}
if streaming_acc.is_complete {
info!(
response_text_len = streaming_acc.response_text.len(),
thinking_text_len = streaming_acc.thinking_text.len(),
"MITM: response complete (initial body) — marking store"
);
store.mark_response_complete();
}
} else if streaming_acc.is_complete {
debug!("MITM: skipping store write — generation stale (initial body)");
}
}
// Forward to client — rewrite function calls if custom tools are injected
// Note: header_buf holds the status line and headers as well as the first body bytes.
let forward_buf = if modify_requests {
if let Some(modified) = super::modify::modify_response_chunk(&header_buf) {
modified
} else {
header_buf.clone()
}
} else {
header_buf.clone()
};
if let Err(e) = client.write_all(&forward_buf).await {
warn!(error = %e, "MITM: write to client failed");
break;
}
if let Some(cl) = response_content_length {
if response_body_buf.len() >= cl {
break;
}
}
// Check chunked terminator in initial body
if is_chunked && has_chunked_terminator(&response_body_buf) {
debug!(domain, "MITM: chunked response complete (initial)");
break;
}
continue;
}
// Head not complete yet; keep reading.
continue;
}
// ── Response body interception ────────────────────────────────
if is_streaming_response {
let s = String::from_utf8_lossy(chunk);
parse_streaming_chunk(&s, &mut streaming_acc);
// Only write to store if our generation is still current.
let gen_valid = !won_gate || store.current_generation() == conn_generation;
if gen_valid {
// Store captured function calls (drain to avoid re-storing on next chunk)
if !streaming_acc.function_calls.is_empty() {
let calls: Vec<_> = streaming_acc.function_calls.drain(..).collect();
for fc in &calls {
store
.record_function_call(cascade_hint.as_deref(), fc.clone())
.await;
}
store.set_last_function_calls(calls.clone()).await;
info!(
"MITM: stored {} function call(s) from body chunk",
calls.len()
);
}
// Capture response + thinking text + grounding into MitmStore
if !streaming_acc.response_text.is_empty() {
store.set_response_text(&streaming_acc.response_text).await;
}
if !streaming_acc.thinking_text.is_empty() {
store.set_thinking_text(&streaming_acc.thinking_text).await;
}
if let Some(ref gm) = streaming_acc.grounding_metadata {
store.set_grounding(gm.clone()).await;
}
// NOTE(review): function_calls was drained just above, so the function_calls field
// logged below is always 0.
if streaming_acc.is_complete {
info!(
response_text_len = streaming_acc.response_text.len(),
thinking_text_len = streaming_acc.thinking_text.len(),
function_calls = streaming_acc.function_calls.len(),
"MITM: response complete — marking store"
);
store.mark_response_complete();
}
} else if streaming_acc.is_complete {
debug!("MITM: skipping store write — generation stale (body chunk)");
}
}
// Forward chunk to client (LS) — rewrite function calls if custom tools
let forward_chunk = if modify_requests {
if let Some(modified) = super::modify::modify_response_chunk(chunk) {
modified
} else {
chunk.to_vec()
}
} else {
chunk.to_vec()
};
if let Err(e) = client.write_all(&forward_chunk).await {
warn!(error = %e, "MITM: write to client failed");
break;
}
// The original (unmodified) bytes are kept for completion checks and usage parsing.
response_body_buf.extend_from_slice(chunk);
if let Some(cl) = response_content_length {
if response_body_buf.len() >= cl {
break;
}
}
if is_chunked && has_chunked_terminator(&response_body_buf) {
debug!(domain, "MITM: chunked response complete");
break;
}
}
// Flush client
let _ = client.flush().await;
// Capture usage data
// Unlike the per-chunk writes above, these final store writes are not generation-gated.
if is_streaming_response {
// Store grounding metadata before consuming the accumulator
if let Some(ref gm) = streaming_acc.grounding_metadata {
store.set_grounding(gm.clone()).await;
}
if streaming_acc.is_complete || streaming_acc.output_tokens > 0 {
// Function calls are stored immediately when detected (above),
// so no need to store them again here.
let usage = streaming_acc.into_usage();
store.record_usage(cascade_hint.as_deref(), usage).await;
}
} else if !response_body_buf.is_empty() {
if let Some(usage) = parse_non_streaming_response(&response_body_buf) {
store.record_usage(cascade_hint.as_deref(), usage).await;
}
}
// If upstream closed, drop the connection so next iteration reconnects
if !upstream_ok {
upstream = None;
}
} // end keep-alive loop
}
/// Tunnel a non-intercepted connection to `domain:port` without inspecting it.
///
/// Bytes are copied in both directions until either side closes. The upstream TCP
/// connect is bounded by a 15s timeout, the same budget used for intercepted
/// upstream connections, so an unreachable host cannot pin the task indefinitely.
///
/// # Errors
/// Returns a message when the upstream connect fails or times out. Errors during the
/// copy are only traced (they usually mean one side hung up) and yield `Ok(())`.
async fn handle_passthrough(mut client: TcpStream, domain: &str, port: u16) -> Result<(), String> {
    const CONNECT_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(15);

    trace!(domain, port, "MITM: transparent tunnel");
    let target = format!("{domain}:{port}");
    let mut upstream =
        match tokio::time::timeout(CONNECT_TIMEOUT, TcpStream::connect(&target)).await {
            Ok(Ok(stream)) => stream,
            Ok(Err(e)) => return Err(format!("Connect to {target}: {e}")),
            Err(_) => return Err(format!("Connect to {target}: timed out")),
        };

    // Bidirectional copy
    match tokio::io::copy_bidirectional(&mut client, &mut upstream).await {
        Ok((client_to_server, server_to_client)) => {
            trace!(
                domain,
                client_to_server,
                server_to_client,
                "MITM: tunnel closed"
            );
        }
        Err(e) => {
            trace!(domain, error = %e, "MITM: tunnel error (likely clean close)");
        }
    }
    Ok(())
}
/// Detect the end of an HTTP/1.1 chunked body.
///
/// `body` is the body collected so far (everything after the header block). A
/// chunked body ends with the zero-size last-chunk plus an empty trailer section,
/// `0\r\n\r\n`. Unless it is the only chunk, that last-chunk is preceded by the CRLF
/// closing the previous chunk's data, so the tail must be `\r\n0\r\n\r\n`.
///
/// Matching a bare `0\r\n\r\n` would misfire in two cases:
/// - a chunk-size line such as `10\r\n` / `a0\r\n`, followed by data that begins
///   with CRLF and is cut at a read boundary;
/// - chunk data that itself ends in `0\r\n`.
///
/// Either case would end the relay early. Trailer fields and last-chunk extensions
/// are not recognised.
fn has_chunked_terminator(body: &[u8]) -> bool {
    const LAST_CHUNK: &[u8] = b"0\r\n\r\n";
    const LAST_CHUNK_AFTER_DATA: &[u8] = b"\r\n0\r\n\r\n";
    body == LAST_CHUNK || body.ends_with(LAST_CHUNK_AFTER_DATA)
}
/// Check whether `buf` holds a complete HTTP/1.1 request (head plus full body).
///
/// Uses `httparse` for case-insensitive header parsing. Body framing follows
/// RFC 9112 §6.3:
/// - A `Transfer-Encoding` whose final coding is `chunked` takes precedence over any
///   `Content-Length`. The request is complete once the body ends with the
///   last-chunk: `0\r\n\r\n`, preceded by the CRLF closing the previous chunk unless
///   it is the only chunk.
/// - Otherwise a parseable `Content-Length` gives the exact body size.
/// - Otherwise no body is expected (e.g. `GET`), so a complete head is enough.
///
/// Returns `false` while the head is incomplete or fails to parse (more data needed).
fn has_complete_http_request(buf: &[u8]) -> bool {
    /// Value of the first header named `name` (ASCII case-insensitive), if valid UTF-8.
    fn header_value<'b>(headers: &[httparse::Header<'b>], name: &str) -> Option<&'b str> {
        headers
            .iter()
            .find(|h| h.name.eq_ignore_ascii_case(name))
            .and_then(|h| std::str::from_utf8(h.value).ok())
    }

    let mut headers = [httparse::EMPTY_HEADER; 64];
    let mut req = httparse::Request::new(&mut headers);
    let headers_end = match req.parse(buf) {
        Ok(httparse::Status::Complete(n)) => n,
        _ => return false, // Incomplete or parse error — need more data
    };
    let body = &buf[headers_end..];

    if let Some(te) = header_value(req.headers, "transfer-encoding") {
        // Only the final transfer coding decides framing ("gzip, chunked" is chunked).
        let last_coding = te.rsplit(',').next().unwrap_or("").trim();
        if last_coding.eq_ignore_ascii_case("chunked") {
            // Requiring the CRLF before "0" avoids mistaking a chunk-size line such as
            // "a0\r\n" (followed by data starting with CRLF), or chunk data ending in
            // "0\r\n", for the terminating last-chunk.
            return body == b"0\r\n\r\n" || body.ends_with(b"\r\n0\r\n\r\n");
        }
    }

    match header_value(req.headers, "content-length").map(|v| v.trim().parse::<usize>()) {
        Some(Ok(len)) => body.len() >= len,
        // No usable Content-Length or Transfer-Encoding — no body expected (e.g., GET)
        _ => true,
    }
}
/// Locate the blank line that terminates an HTTP header block.
///
/// Returns the offset just past the first `\r\n\r\n` (i.e. where the body starts),
/// or `None` if `buf` does not contain one yet.
fn find_headers_end(buf: &[u8]) -> Option<usize> {
    const BLANK_LINE: &[u8] = b"\r\n\r\n";
    for (offset, window) in buf.windows(BLANK_LINE.len()).enumerate() {
        if window == BLANK_LINE {
            return Some(offset + BLANK_LINE.len());
        }
    }
    None
}
/// Extract request metadata from raw bytes with `httparse`.
///
/// Returns `(headers_end, content_length, is_streaming_request)`:
/// - `headers_end`: offset of the first body byte. If `httparse` cannot produce a
///   complete head, it falls back to a raw `\r\n\r\n` scan (or `buf.len()` when there
///   is none) and reports `0` / `false` for the other two fields.
/// - `content_length`: value of the last `Content-Length` header; 0 when absent or
///   unparseable.
/// - `is_streaming_request`: whether the body text contains `"stream":true`, with or
///   without a space after the colon.
fn parse_http_request_meta(buf: &[u8]) -> (usize, usize, bool) {
    let mut slots = [httparse::EMPTY_HEADER; 64];
    let mut parsed = httparse::Request::new(&mut slots);
    let headers_end = if let Ok(httparse::Status::Complete(end)) = parsed.parse(buf) {
        end
    } else {
        // httparse rejected the head or it is incomplete: use a plain byte scan instead.
        return (find_headers_end(buf).unwrap_or(buf.len()), 0, false);
    };

    // Later Content-Length headers override earlier ones; an unparseable value counts as 0.
    let content_length = parsed
        .headers
        .iter()
        .filter(|h| h.name.eq_ignore_ascii_case("content-length"))
        .filter_map(|h| std::str::from_utf8(h.value).ok())
        .map(|raw| raw.trim().parse::<usize>().unwrap_or(0))
        .next_back()
        .unwrap_or(0);

    let body = &buf[headers_end..];
    let is_streaming = !body.is_empty() && {
        let text = String::from_utf8_lossy(body);
        text.contains("\"stream\":true") || text.contains("\"stream\": true")
    };

    (headers_end, content_length, is_streaming)
}
/// Replace the value of the first `Content-Length` header in a raw header block.
///
/// The header name is matched case-insensitively at the start of any line and is
/// rewritten as `Content-Length: {new_body_len}`. Headers without a
/// `Content-Length` line are returned unchanged (as a new `String`).
fn update_content_length(headers: &str, new_body_len: usize) -> String {
    use regex::Regex;
    // (?i): case-insensitive; (?m): `^` anchors at the start of every header line.
    let content_length_re = Regex::new(r"(?im)^content-length:\s*\d+").unwrap();
    // `replace` rewrites only the first match and hands back the input untouched
    // (borrowed) when nothing matches, so `into_owned` covers both outcomes.
    content_length_re
        .replace(headers, format!("Content-Length: {new_body_len}"))
        .into_owned()
}