//! MITM proxy server: handles CONNECT tunnels and transparent TLS interception.
//!
//! Supports two modes:
//! 1. **HTTP CONNECT** — standard proxy mode, LS sends `CONNECT host:port`
//! 2. **Transparent (iptables)** — raw TLS arrives via REDIRECT, SNI extracted from ClientHello
//!
//! For intercepted domains, terminates TLS with our CA-signed cert,
//! reads/modifies the request, forwards to the real upstream, and captures usage.
//!
//! For non-intercepted domains, acts as a transparent TCP tunnel.

use super::ca::MitmCa;
|
|
use super::intercept::{
|
|
extract_cascade_hint, parse_non_streaming_response, parse_streaming_chunk, StreamingAccumulator,
|
|
};
|
|
use super::store::MitmStore;
|
|
use std::sync::Arc;
|
|
use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
|
use tokio::net::{TcpListener, TcpStream};
|
|
use tokio_rustls::TlsAcceptor;
|
|
use tracing::{debug, error, info, trace, warn};
|
|
|
|
/// Hosts whose TLS is terminated locally so the traffic can be inspected.
///
/// `should_intercept_domain` treats each entry as an exact host name, as a
/// parent domain (`<sub>.<entry>`), and as a regional-endpoint suffix
/// (`<region>-<entry>`, e.g. `us-central1-aiplatform.googleapis.com`).
const INTERCEPT_DOMAINS: &[&str] = &[
    "cloudcode-pa.googleapis.com",
    "aiplatform.googleapis.com",
    "api.anthropic.com",
    "speech.googleapis.com",
    "modelarmor.googleapis.com",
];
|
|
|
|
/// Hosts that are always tunnelled untouched (never TLS-terminated).
///
/// `should_intercept_domain` consults this list before the intercept list,
/// so an entry here wins over any intercept match.
const PASSTHROUGH_DOMAINS: &[&str] = &[
    "oauth2.googleapis.com",
    "accounts.google.com",
    "storage.googleapis.com",
    "www.googleapis.com",
    "firebaseinstallations.googleapis.com",
    "crashlyticsreports-pa.googleapis.com",
    "play.googleapis.com",
    "update.googleapis.com",
    "dl.google.com",
];
|
|
|
|
/// Configuration for the MITM proxy.
///
/// The derived `Default` yields port `0` (auto-assign) with request
/// modification disabled. `Debug` and `Clone` are derived so the
/// configuration can be logged and reused by callers.
#[derive(Debug, Clone, Default)]
pub struct MitmConfig {
    /// Port to listen on (0 = auto-assign).
    pub port: u16,
    /// Whether to enable request modification.
    pub modify_requests: bool,
}
|
|
|
|
/// Run the MITM proxy server.
|
|
///
|
|
/// Returns (port, task_handle) — port it's listening on, handle to abort on shutdown.
|
|
pub async fn run(
|
|
ca: Arc<MitmCa>,
|
|
store: MitmStore,
|
|
config: MitmConfig,
|
|
) -> Result<(u16, tokio::task::JoinHandle<()>), String> {
|
|
let listener = TcpListener::bind(format!("127.0.0.1:{}", config.port))
|
|
.await
|
|
.map_err(|e| format!("MITM bind failed: {e}"))?;
|
|
|
|
let port = listener
|
|
.local_addr()
|
|
.map_err(|e| format!("MITM local_addr failed: {e}"))?
|
|
.port();
|
|
|
|
info!(port, "MITM proxy listening");
|
|
|
|
let modify_requests = config.modify_requests;
|
|
|
|
let handle = tokio::spawn(async move {
|
|
loop {
|
|
match listener.accept().await {
|
|
Ok((stream, addr)) => {
|
|
trace!(?addr, "MITM: new connection");
|
|
let ca = ca.clone();
|
|
let store = store.clone();
|
|
tokio::spawn(async move {
|
|
if let Err(e) = handle_connection(stream, ca, store, modify_requests).await
|
|
{
|
|
warn!(error = %e, "MITM connection error");
|
|
}
|
|
});
|
|
}
|
|
Err(e) => {
|
|
error!(error = %e, "MITM accept error");
|
|
}
|
|
}
|
|
}
|
|
});
|
|
|
|
Ok((port, handle))
|
|
}
|
|
|
|
/// Handle a single incoming connection.
|
|
///
|
|
/// Supports two modes:
|
|
/// 1. **HTTP CONNECT** (standard proxy) — LS sends `CONNECT host:port HTTP/1.1`
|
|
/// 2. **Transparent/iptables redirect** — raw TLS ClientHello arrives directly
|
|
/// (first byte is 0x16). We extract the domain from SNI and intercept.
|
|
async fn handle_connection(
|
|
mut stream: TcpStream,
|
|
ca: Arc<MitmCa>,
|
|
store: MitmStore,
|
|
modify_requests: bool,
|
|
) -> Result<(), String> {
|
|
// Peek at the first byte to distinguish CONNECT vs raw TLS
|
|
let mut peek = [0u8; 1];
|
|
let n = stream
|
|
.peek(&mut peek)
|
|
.await
|
|
.map_err(|e| format!("Peek failed: {e}"))?;
|
|
|
|
if n == 0 {
|
|
return Ok(());
|
|
}
|
|
|
|
if peek[0] == 0x16 {
|
|
// TLS ClientHello — transparent/iptables redirect mode.
|
|
// Peek enough bytes to extract SNI from the ClientHello.
|
|
let mut hello_buf = vec![0u8; 16384];
|
|
let n = stream
|
|
.peek(&mut hello_buf)
|
|
.await
|
|
.map_err(|e| format!("Peek ClientHello: {e}"))?;
|
|
|
|
let domain = extract_sni(&hello_buf[..n]).unwrap_or_else(|| "unknown".to_string());
|
|
|
|
info!(domain, "MITM: transparent redirect (iptables)");
|
|
|
|
let should_intercept = should_intercept_domain(&domain);
|
|
if should_intercept {
|
|
handle_intercepted(stream, &domain, ca, store, modify_requests).await
|
|
} else {
|
|
// For non-intercepted domains via iptables, we need the original dest.
|
|
// Since we only have SNI, resolve and passthrough.
|
|
handle_passthrough(stream, &domain, 443).await
|
|
}
|
|
} else {
|
|
// Standard HTTP CONNECT proxy mode
|
|
let mut buf = vec![0u8; 8192];
|
|
let n = stream
|
|
.read(&mut buf)
|
|
.await
|
|
.map_err(|e| format!("Read CONNECT: {e}"))?;
|
|
|
|
if n == 0 {
|
|
return Ok(());
|
|
}
|
|
|
|
let request = String::from_utf8_lossy(&buf[..n]);
|
|
let first_line = request.lines().next().unwrap_or("");
|
|
|
|
// Parse "CONNECT host:port HTTP/1.1"
|
|
let parts: Vec<&str> = first_line.split_whitespace().collect();
|
|
if parts.len() < 3 || parts[0] != "CONNECT" {
|
|
let resp = "HTTP/1.1 400 Bad Request\r\n\r\n";
|
|
let _ = stream.write_all(resp.as_bytes()).await;
|
|
return Ok(());
|
|
}
|
|
|
|
let host_port = parts[1];
|
|
let (domain, _port) = match host_port.rsplit_once(':') {
|
|
Some((h, p)) => (h, p.parse::<u16>().unwrap_or(443)),
|
|
None => (host_port, 443),
|
|
};
|
|
|
|
debug!(domain, "MITM: CONNECT request");
|
|
|
|
let should_intercept = should_intercept_domain(domain);
|
|
|
|
// Send 200 Connection Established
|
|
let response = "HTTP/1.1 200 Connection Established\r\n\r\n";
|
|
stream
|
|
.write_all(response.as_bytes())
|
|
.await
|
|
.map_err(|e| format!("Write 200: {e}"))?;
|
|
|
|
if should_intercept {
|
|
handle_intercepted(stream, domain, ca, store, modify_requests).await
|
|
} else {
|
|
handle_passthrough(stream, domain, _port).await
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Extract SNI (Server Name Indication) from a TLS ClientHello message.
///
/// Parses the raw TLS record to find the `server_name` extension (type 0x0000)
/// and returns the first `host_name` (name type 0) entry of its list.
///
/// The record and handshake length fields are clamped to the bytes actually
/// present, so a truncated capture is still parsed as far as it goes. The
/// host name itself must lie completely inside the `server_name` extension;
/// a name length that overruns it yields `None` rather than bytes belonging
/// to a neighbouring extension.
///
/// Returns `None` if the SNI can't be found (not TLS, no SNI extension, etc.)
/// or the name is not valid UTF-8.
fn extract_sni(buf: &[u8]) -> Option<String> {
    /// Big-endian `u16` at offset `at`, or `None` when out of bounds.
    fn be_u16_at(bytes: &[u8], at: usize) -> Option<usize> {
        let pair = bytes.get(at..at.checked_add(2)?)?;
        Some(usize::from(u16::from_be_bytes([pair[0], pair[1]])))
    }

    /// Pull the first `host_name` entry out of a `server_name` extension body:
    /// list_len(2), then entries of name_type(1) + name_len(2) + name.
    fn host_name_from_extension(data: &[u8]) -> Option<String> {
        let list_len = be_u16_at(data, 0)?;
        let listed = &data[2..];
        let mut entries = &listed[..list_len.min(listed.len())];

        while entries.len() >= 3 {
            let name_type = entries[0];
            let name_len = be_u16_at(entries, 1)?;
            // `get` fails when the declared length overruns the list.
            let name = entries.get(3..3 + name_len)?;
            if name_type == 0x00 {
                return std::str::from_utf8(name).ok().map(str::to_owned);
            }
            entries = &entries[3 + name_len..];
        }
        None
    }

    // TLS record: type(1) + version(2) + length(2) + handshake
    if buf.len() < 5 || buf[0] != 0x16 {
        return None;
    }
    let record_len = be_u16_at(buf, 3)?;
    let record = &buf[5..];
    let handshake = &record[..record_len.min(record.len())];

    // Handshake: type(1) + length(3) + body; 0x01 = ClientHello
    if handshake.len() < 4 || handshake[0] != 0x01 {
        return None;
    }
    let hs_len = (usize::from(handshake[1]) << 16)
        | (usize::from(handshake[2]) << 8)
        | usize::from(handshake[3]);
    let hs_rest = &handshake[4..];
    let body = &hs_rest[..hs_len.min(hs_rest.len())];

    // ClientHello: version(2) + random(32) + session_id_len(1) + session_id(var)
    //            + cipher_suites_len(2) + cipher_suites(var)
    //            + compression_len(1) + compression(var)
    //            + extensions_len(2) + extensions(var)
    let mut pos = 34; // skip version + random

    // Session ID
    let sid_len = usize::from(*body.get(pos)?);
    pos += 1 + sid_len;

    // Cipher suites
    let cs_len = be_u16_at(body, pos)?;
    pos += 2 + cs_len;

    // Compression methods
    let cm_len = usize::from(*body.get(pos)?);
    pos += 1 + cm_len;

    // Extensions
    let ext_total = be_u16_at(body, pos)?;
    pos += 2;
    let ext_avail = &body[pos..];
    let mut extensions = &ext_avail[..ext_total.min(ext_avail.len())];

    // Each extension: type(2) + data_len(2) + data
    while extensions.len() >= 4 {
        let ext_type = u16::from_be_bytes([extensions[0], extensions[1]]);
        let ext_data_len = usize::from(u16::from_be_bytes([extensions[2], extensions[3]]));
        let payload = &extensions[4..];

        if ext_type == 0x0000 {
            // Confine parsing to this extension's own bytes.
            return host_name_from_extension(payload.get(..ext_data_len)?);
        }

        extensions = payload.get(ext_data_len..)?;
    }

    None
}
|
|
|
|
/// Check if a domain should be intercepted.
|
|
fn should_intercept_domain(domain: &str) -> bool {
|
|
// Never intercept passthrough domains
|
|
for &pt in PASSTHROUGH_DOMAINS {
|
|
if domain == pt {
|
|
return false;
|
|
}
|
|
}
|
|
|
|
// Intercept known API domains (exact match, subdomain, or regional prefix)
|
|
for &intercept in INTERCEPT_DOMAINS {
|
|
if domain == intercept
|
|
|| domain.ends_with(&format!(".{intercept}"))
|
|
|| domain.ends_with(&format!("-{intercept}"))
|
|
{
|
|
return true;
|
|
}
|
|
}
|
|
|
|
// Default: passthrough
|
|
false
|
|
}
|
|
|
|
/// Handle an intercepted connection: terminate TLS, inspect traffic.
|
|
///
|
|
/// After TLS termination, checks the negotiated ALPN protocol:
|
|
/// - `h2` → HTTP/2 handler (for gRPC traffic to Google APIs)
|
|
/// - `http/1.1` or none → HTTP/1.1 handler (for REST/SSE traffic)
|
|
async fn handle_intercepted(
|
|
stream: TcpStream,
|
|
domain: &str,
|
|
ca: Arc<MitmCa>,
|
|
store: MitmStore,
|
|
modify_requests: bool,
|
|
) -> Result<(), String> {
|
|
info!(domain, "MITM: intercepting TLS");
|
|
|
|
// Get or create server TLS config for this domain
|
|
let server_config = ca.server_config_for_domain(domain).await?;
|
|
|
|
let acceptor = TlsAcceptor::from(server_config);
|
|
|
|
// Perform TLS handshake with the client (LS) — 10s timeout
|
|
let tls_stream =
|
|
match tokio::time::timeout(std::time::Duration::from_secs(10), acceptor.accept(stream))
|
|
.await
|
|
{
|
|
Ok(Ok(s)) => s,
|
|
Ok(Err(e)) => {
|
|
warn!(domain, error = %e, "MITM: TLS handshake FAILED (client rejected cert?)");
|
|
return Err(format!(
|
|
"TLS handshake with client failed for {domain}: {e}"
|
|
));
|
|
}
|
|
Err(_) => {
|
|
warn!(domain, "MITM: TLS handshake TIMED OUT after 10s");
|
|
return Err(format!("TLS handshake timed out for {domain}"));
|
|
}
|
|
};
|
|
|
|
// Check negotiated ALPN protocol
|
|
let alpn = tls_stream
|
|
.get_ref()
|
|
.1
|
|
.alpn_protocol()
|
|
.map(|p| String::from_utf8_lossy(p).to_string());
|
|
|
|
info!(domain, alpn = ?alpn, "MITM: TLS handshake successful ✓");
|
|
|
|
match alpn.as_deref() {
|
|
Some("h2") => {
|
|
// HTTP/2 — use the hyper-based gRPC handler
|
|
info!(domain, "MITM: routing to HTTP/2 handler (gRPC)");
|
|
super::h2_handler::handle_h2_connection(tls_stream, domain.to_string(), store).await
|
|
}
|
|
_ => {
|
|
// HTTP/1.1 or no ALPN — use the existing handler
|
|
info!(domain, "MITM: routing to HTTP/1.1 handler");
|
|
handle_http_over_tls(tls_stream, domain, store, modify_requests).await
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Handle HTTP traffic over the decrypted TLS connection.
|
|
///
|
|
/// Loops to handle multiple requests on the same connection (HTTP keep-alive).
|
|
/// Reads full request, connects to upstream, forwards request, streams response
|
|
/// back to client while capturing usage data.
|
|
async fn handle_http_over_tls(
|
|
mut client: tokio_rustls::server::TlsStream<TcpStream>,
|
|
domain: &str,
|
|
store: MitmStore,
|
|
modify_requests: bool,
|
|
) -> Result<(), String> {
|
|
let mut tmp = vec![0u8; 32768];
|
|
|
|
// Build upstream TLS connector once for this connection
|
|
let mut root_store = rustls::RootCertStore::empty();
|
|
let native_certs = rustls_native_certs::load_native_certs();
|
|
for cert in native_certs.certs {
|
|
let _ = root_store.add(cert);
|
|
}
|
|
let upstream_config = Arc::new(
|
|
rustls::ClientConfig::builder()
|
|
.with_root_certificates(root_store)
|
|
.with_no_client_auth(),
|
|
);
|
|
|
|
// Reusable upstream connection — created lazily, reconnected if stale
|
|
let mut upstream: Option<tokio_rustls::client::TlsStream<TcpStream>> = None;
|
|
|
|
/// Connect (or reconnect) to the real upstream via TLS.
|
|
///
|
|
/// Bypasses /etc/hosts by resolving via direct DNS query (dig @8.8.8.8),
|
|
/// then falls back to cached IPs file, then to normal system resolution.
|
|
async fn connect_upstream(
|
|
domain: &str,
|
|
config: &Arc<rustls::ClientConfig>,
|
|
) -> Result<tokio_rustls::client::TlsStream<TcpStream>, String> {
|
|
let connector = tokio_rustls::TlsConnector::from(config.clone());
|
|
|
|
// Try to resolve the real IP, bypassing /etc/hosts
|
|
let addr = resolve_upstream(domain).await;
|
|
info!(domain, addr = %addr, "MITM: connecting upstream");
|
|
|
|
let tcp = match tokio::time::timeout(
|
|
std::time::Duration::from_secs(15),
|
|
TcpStream::connect(&addr),
|
|
)
|
|
.await
|
|
{
|
|
Ok(Ok(s)) => s,
|
|
Ok(Err(e)) => return Err(format!("Connect to upstream {domain} ({addr}): {e}")),
|
|
Err(_) => return Err(format!("Connect to upstream {domain} ({addr}): timed out")),
|
|
};
|
|
|
|
let server_name = rustls::pki_types::ServerName::try_from(domain.to_string())
|
|
.map_err(|e| format!("Invalid server name: {e}"))?;
|
|
|
|
match tokio::time::timeout(
|
|
std::time::Duration::from_secs(15),
|
|
connector.connect(server_name, tcp),
|
|
)
|
|
.await
|
|
{
|
|
Ok(Ok(s)) => {
|
|
info!(domain, "MITM: upstream TLS connected ✓");
|
|
Ok(s)
|
|
}
|
|
Ok(Err(e)) => Err(format!("TLS connect to upstream {domain}: {e}")),
|
|
Err(_) => Err(format!("TLS connect to upstream {domain}: timed out")),
|
|
}
|
|
}
|
|
|
|
/// Resolve upstream IP bypassing /etc/hosts.
|
|
async fn resolve_upstream(domain: &str) -> String {
|
|
// 1. Try dig @8.8.8.8 (bypasses /etc/hosts)
|
|
if let Ok(output) = tokio::process::Command::new("dig")
|
|
.args(["+short", "@8.8.8.8", domain])
|
|
.output()
|
|
.await
|
|
{
|
|
let out = String::from_utf8_lossy(&output.stdout);
|
|
if let Some(ip) = out
|
|
.lines()
|
|
.find(|l| l.parse::<std::net::Ipv4Addr>().is_ok())
|
|
{
|
|
return format!("{ip}:443");
|
|
}
|
|
}
|
|
|
|
// 2. Try cached IPs file (written by dns-redirect.sh install)
|
|
if let Ok(contents) = tokio::fs::read_to_string("/tmp/antigravity-mitm-real-ips").await {
|
|
for line in contents.lines() {
|
|
if let Some((d, ip)) = line.split_once('=') {
|
|
if d == domain {
|
|
return format!("{ip}:443");
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// 3. Fallback to normal resolution (may hit /etc/hosts)
|
|
format!("{domain}:443")
|
|
}
|
|
|
|
// Keep-alive loop: handle multiple requests on this connection
|
|
loop {
|
|
// ── Read the HTTP request from the client ─────────────────────────
|
|
let mut request_buf = Vec::with_capacity(1024 * 64);
|
|
|
|
// 60s timeout on initial read (LS may open connection without sending immediately)
|
|
const IDLE_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(60);
|
|
|
|
loop {
|
|
let read_result = if request_buf.is_empty() {
|
|
// First read — apply idle timeout
|
|
match tokio::time::timeout(IDLE_TIMEOUT, client.read(&mut tmp)).await {
|
|
Ok(r) => r,
|
|
Err(_) => {
|
|
// Idle timeout — connection pool warmup, no data sent
|
|
debug!(domain, "MITM: client idle timeout (60s), closing");
|
|
return Ok(());
|
|
}
|
|
}
|
|
} else {
|
|
// Subsequent reads — wait up to 30s for rest of request
|
|
match tokio::time::timeout(
|
|
std::time::Duration::from_secs(30),
|
|
client.read(&mut tmp),
|
|
)
|
|
.await
|
|
{
|
|
Ok(r) => r,
|
|
Err(_) => {
|
|
warn!(domain, "MITM: partial request read timed out");
|
|
return Err("Partial request read timed out".into());
|
|
}
|
|
}
|
|
};
|
|
|
|
let n = match read_result {
|
|
Ok(0) => return Ok(()), // Client closed connection cleanly
|
|
Ok(n) => n,
|
|
Err(e) => {
|
|
// Connection reset / broken pipe is normal for keep-alive end
|
|
debug!(domain, error = %e, "MITM: client read finished");
|
|
return Ok(());
|
|
}
|
|
};
|
|
|
|
request_buf.extend_from_slice(&tmp[..n]);
|
|
|
|
// Check if we have the full request (headers + body)
|
|
if has_complete_http_request(&request_buf) {
|
|
break;
|
|
}
|
|
}
|
|
|
|
if request_buf.is_empty() {
|
|
return Ok(());
|
|
}
|
|
|
|
// Parse the HTTP request to find headers and body
|
|
let (headers_end, content_length, _is_streaming_request) =
|
|
parse_http_request_meta(&request_buf);
|
|
|
|
// Try to extract cascade hint from request body
|
|
let cascade_hint = if headers_end < request_buf.len() {
|
|
extract_cascade_hint(&request_buf[headers_end..])
|
|
} else {
|
|
None
|
|
};
|
|
|
|
// Extract request method and path for logging
|
|
let req_path = {
|
|
let mut headers = [httparse::EMPTY_HEADER; 64];
|
|
let mut req = httparse::Request::new(&mut headers);
|
|
match req.parse(&request_buf) {
|
|
Ok(httparse::Status::Complete(_)) => {
|
|
format!("{} {}", req.method.unwrap_or("?"), req.path.unwrap_or("?"))
|
|
}
|
|
_ => "?".to_string(),
|
|
}
|
|
};
|
|
|
|
// Generation tracking for store write guards
|
|
let mut won_gate = false;
|
|
let mut conn_generation = store.current_generation();
|
|
|
|
// Log LLM calls at info, everything else at debug
|
|
if req_path.contains("streamGenerateContent") {
|
|
let body_len = request_buf.len() - headers_end;
|
|
info!(
|
|
domain,
|
|
req_path = %req_path,
|
|
body_len,
|
|
cascade = ?cascade_hint,
|
|
"MITM: forwarding LLM request"
|
|
);
|
|
|
|
// ── Atomic in-flight gate ─────────────────────────────────
|
|
// The LS opens multiple connections and sends parallel requests.
|
|
// When custom tools are active, only the FIRST request wins the
|
|
// atomic compare_exchange. All others get fake STOP responses.
|
|
let has_tools = store.get_tools().await.is_some();
|
|
won_gate = if has_tools {
|
|
if !store.try_mark_request_in_flight() {
|
|
info!("MITM: blocking LS request — another request already in-flight");
|
|
let fake_response = "HTTP/1.1 200 OK\r\n\
|
|
Content-Type: text/event-stream\r\n\
|
|
Transfer-Encoding: chunked\r\n\
|
|
\r\n";
|
|
let fake_sse = "data: {\"response\":{\"candidates\":[{\"content\":{\"parts\":[{\"text\":\"Request handled.\"}],\"role\":\"model\"},\"finishReason\":\"STOP\"}],\"usageMetadata\":{\"promptTokenCount\":0,\"candidatesTokenCount\":1,\"totalTokenCount\":1}}}\n\ndata: [DONE]\n\n";
|
|
let chunked_body = super::modify::rechunk(fake_sse.as_bytes());
|
|
let mut response = fake_response.as_bytes().to_vec();
|
|
response.extend_from_slice(&chunked_body);
|
|
if let Err(e) = client.write_all(&response).await {
|
|
warn!(error = %e, "MITM: failed to write fake response");
|
|
}
|
|
let _ = client.flush().await;
|
|
continue;
|
|
}
|
|
true
|
|
} else {
|
|
false
|
|
};
|
|
// Snapshot the generation at gate-win time. If it changes later,
|
|
// another completions turn started and our data is stale.
|
|
conn_generation = store.current_generation();
|
|
|
|
// ── Request modification ─────────────────────────────────────
|
|
// Dechunk body → check if agent request → modify → rechunk
|
|
if modify_requests && body_len > 0 {
|
|
let body_slice = &request_buf[headers_end..];
|
|
let raw_body = super::modify::dechunk(body_slice);
|
|
|
|
// Only modify "agent" requests, not "checkpoint" (LS internal)
|
|
let body_str = String::from_utf8_lossy(&raw_body);
|
|
let is_agent = body_str.contains("\"requestType\":\"agent\"")
|
|
|| body_str.contains("\"requestType\": \"agent\"");
|
|
|
|
if is_agent {
|
|
// Build ToolContext from store
|
|
let tools = store.get_tools().await;
|
|
let tool_config = store.get_tool_config().await;
|
|
let pending_results = store.take_tool_results().await;
|
|
let last_calls = store.get_last_function_calls().await;
|
|
let generation_params = store.get_generation_params().await;
|
|
let pending_image = store.take_pending_image().await;
|
|
|
|
let tool_ctx = if tools.is_some()
|
|
|| !pending_results.is_empty()
|
|
|| generation_params.is_some()
|
|
|| pending_image.is_some()
|
|
{
|
|
Some(super::modify::ToolContext {
|
|
tools,
|
|
tool_config,
|
|
pending_results,
|
|
last_calls,
|
|
generation_params,
|
|
pending_image,
|
|
})
|
|
} else {
|
|
None
|
|
};
|
|
|
|
if let Some(modified_body) =
|
|
super::modify::modify_request(&raw_body, tool_ctx.as_ref())
|
|
{
|
|
// Rebuild request_buf: headers (with updated Content-Length) + rechunked modified body
|
|
let new_chunked = super::modify::rechunk(&modified_body);
|
|
|
|
// Fix Content-Length header to match new body size
|
|
let header_str = String::from_utf8_lossy(&request_buf[..headers_end]);
|
|
let updated_headers = update_content_length(&header_str, new_chunked.len());
|
|
let mut new_buf = updated_headers.into_bytes();
|
|
new_buf.extend_from_slice(&new_chunked);
|
|
request_buf = new_buf;
|
|
|
|
// In-flight already marked atomically above
|
|
}
|
|
}
|
|
}
|
|
} else {
|
|
debug!(
|
|
domain,
|
|
req_path = %req_path,
|
|
content_length,
|
|
"MITM: forwarding request"
|
|
);
|
|
}
|
|
|
|
// ── Ensure upstream connection is alive ──────────────────────────────
|
|
// Lazily connect on first request, or reconnect if the previous connection died
|
|
let conn = match upstream.as_mut() {
|
|
Some(c) => c,
|
|
None => {
|
|
let c = connect_upstream(domain, &upstream_config).await?;
|
|
upstream.insert(c)
|
|
}
|
|
};
|
|
|
|
// Forward the request — if write fails, reconnect and retry once
|
|
if let Err(e) = conn.write_all(&request_buf).await {
|
|
debug!(domain, error = %e, "MITM: upstream write failed, reconnecting");
|
|
let c = connect_upstream(domain, &upstream_config).await?;
|
|
let conn = upstream.insert(c);
|
|
conn.write_all(&request_buf)
|
|
.await
|
|
.map_err(|e| format!("Write to upstream (retry): {e}"))?;
|
|
}
|
|
|
|
let conn = upstream.as_mut().unwrap();
|
|
|
|
// ── Stream response back to client ──────────────────────────────────
|
|
// ALWAYS forward data to client immediately (no buffering).
|
|
// Buffer body on the side for usage parsing.
|
|
let mut streaming_acc = StreamingAccumulator::new();
|
|
let mut is_streaming_response = false;
|
|
let mut headers_parsed = false;
|
|
let mut upstream_ok = true;
|
|
let mut response_body_buf = Vec::new();
|
|
let mut response_content_length: Option<usize> = None;
|
|
let mut is_chunked = false;
|
|
let mut got_first_byte = false;
|
|
let mut header_buf = Vec::with_capacity(8192);
|
|
|
|
loop {
|
|
// 15s idle timeout after first byte, 60s for initial response
|
|
let timeout = if got_first_byte {
|
|
std::time::Duration::from_secs(15)
|
|
} else {
|
|
std::time::Duration::from_secs(60)
|
|
};
|
|
|
|
let n = match tokio::time::timeout(timeout, conn.read(&mut tmp)).await {
|
|
Ok(Ok(0)) => {
|
|
upstream_ok = false;
|
|
break;
|
|
}
|
|
Ok(Ok(n)) => n,
|
|
Ok(Err(e)) => {
|
|
debug!(domain, error = %e, "MITM: upstream read ended");
|
|
upstream_ok = false;
|
|
break;
|
|
}
|
|
Err(_) => {
|
|
if got_first_byte {
|
|
debug!(domain, "MITM: response idle timeout (complete)");
|
|
} else {
|
|
warn!(domain, "MITM: no upstream response in 60s");
|
|
}
|
|
upstream_ok = false;
|
|
break;
|
|
}
|
|
};
|
|
|
|
got_first_byte = true;
|
|
let chunk = &tmp[..n];
|
|
|
|
if !headers_parsed {
|
|
header_buf.extend_from_slice(chunk);
|
|
if let Some(_hdr_end) = find_headers_end(&header_buf) {
|
|
let mut resp_headers = [httparse::EMPTY_HEADER; 64];
|
|
let mut resp = httparse::Response::new(&mut resp_headers);
|
|
let hdr_end = match resp.parse(&header_buf) {
|
|
Ok(httparse::Status::Complete(n)) => n,
|
|
_ => _hdr_end,
|
|
};
|
|
|
|
let mut content_type = String::new();
|
|
|
|
for header in resp.headers.iter() {
|
|
if header.name.eq_ignore_ascii_case("content-type") {
|
|
if let Ok(v) = std::str::from_utf8(header.value) {
|
|
content_type = v.to_string();
|
|
if v.contains("text/event-stream") {
|
|
is_streaming_response = true;
|
|
}
|
|
}
|
|
}
|
|
if header.name.eq_ignore_ascii_case("content-length") {
|
|
if let Ok(v) = std::str::from_utf8(header.value) {
|
|
response_content_length = v.trim().parse().ok();
|
|
}
|
|
}
|
|
if header.name.eq_ignore_ascii_case("connection") {
|
|
if let Ok(v) = std::str::from_utf8(header.value) {
|
|
if v.trim().eq_ignore_ascii_case("close") {
|
|
upstream_ok = false;
|
|
}
|
|
}
|
|
}
|
|
if header.name.eq_ignore_ascii_case("transfer-encoding") {
|
|
if let Ok(v) = std::str::from_utf8(header.value) {
|
|
if v.trim().eq_ignore_ascii_case("chunked") {
|
|
is_chunked = true;
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
if is_streaming_response {
|
|
info!(domain,
|
|
content_type = %content_type,
|
|
status = resp.code, "MITM: streaming response");
|
|
} else {
|
|
debug!(domain,
|
|
content_type = %content_type,
|
|
status = resp.code, "MITM: response headers");
|
|
}
|
|
headers_parsed = true;
|
|
|
|
// Capture upstream errors for forwarding to client
|
|
let http_status = resp.code.unwrap_or(0) as u16;
|
|
if http_status >= 400 {
|
|
let body_str = String::from_utf8_lossy(&header_buf[hdr_end..]).to_string();
|
|
warn!(domain, status = http_status, body = %body_str, "MITM: upstream error response");
|
|
|
|
// Parse Google's error JSON: {"error": {"code": N, "message": "...", "status": "..."}}
|
|
let (message, error_status) =
|
|
serde_json::from_str::<serde_json::Value>(&body_str)
|
|
.ok()
|
|
.and_then(|v| {
|
|
let err = v.get("error")?;
|
|
let msg = err
|
|
.get("message")
|
|
.and_then(|m| m.as_str())
|
|
.map(|s| s.to_string());
|
|
let status = err
|
|
.get("status")
|
|
.and_then(|s| s.as_str())
|
|
.map(|s| s.to_string());
|
|
Some((msg, status))
|
|
})
|
|
.unwrap_or((None, None));
|
|
|
|
store
|
|
.set_upstream_error(super::store::UpstreamError {
|
|
status: http_status,
|
|
body: body_str,
|
|
message,
|
|
error_status,
|
|
})
|
|
.await;
|
|
}
|
|
|
|
// Save body for usage parsing
|
|
response_body_buf.extend_from_slice(&header_buf[hdr_end..]);
|
|
|
|
// Parse ORIGINAL initial body for MITM interception
|
|
if is_streaming_response && hdr_end < header_buf.len() {
|
|
let body = String::from_utf8_lossy(&header_buf[hdr_end..]);
|
|
parse_streaming_chunk(&body, &mut streaming_acc);
|
|
|
|
// Only write to store if our generation is still current.
|
|
// If another completions turn started, our data is stale.
|
|
let gen_valid = !won_gate || store.current_generation() == conn_generation;
|
|
if gen_valid {
|
|
// Store captured function calls (drain to avoid re-storing on next chunk)
|
|
if !streaming_acc.function_calls.is_empty() {
|
|
let calls: Vec<_> =
|
|
streaming_acc.function_calls.drain(..).collect();
|
|
for fc in &calls {
|
|
store
|
|
.record_function_call(cascade_hint.as_deref(), fc.clone())
|
|
.await;
|
|
}
|
|
store.set_last_function_calls(calls.clone()).await;
|
|
info!(
|
|
"MITM: stored {} function call(s) from initial body",
|
|
calls.len()
|
|
);
|
|
}
|
|
|
|
// Capture response + thinking text + grounding into MitmStore
|
|
if !streaming_acc.response_text.is_empty() {
|
|
store.set_response_text(&streaming_acc.response_text).await;
|
|
}
|
|
if !streaming_acc.thinking_text.is_empty() {
|
|
store.set_thinking_text(&streaming_acc.thinking_text).await;
|
|
}
|
|
if let Some(ref gm) = streaming_acc.grounding_metadata {
|
|
store.set_grounding(gm.clone()).await;
|
|
}
|
|
if streaming_acc.is_complete {
|
|
info!(
|
|
response_text_len = streaming_acc.response_text.len(),
|
|
thinking_text_len = streaming_acc.thinking_text.len(),
|
|
"MITM: response complete (initial body) — marking store"
|
|
);
|
|
store.mark_response_complete();
|
|
}
|
|
} else if streaming_acc.is_complete {
|
|
debug!("MITM: skipping store write — generation stale (initial body)");
|
|
}
|
|
}
|
|
|
|
// Forward to client — rewrite function calls if custom tools are injected
|
|
let forward_buf = if modify_requests {
|
|
if let Some(modified) = super::modify::modify_response_chunk(&header_buf) {
|
|
modified
|
|
} else {
|
|
header_buf.clone()
|
|
}
|
|
} else {
|
|
header_buf.clone()
|
|
};
|
|
if let Err(e) = client.write_all(&forward_buf).await {
|
|
warn!(error = %e, "MITM: write to client failed");
|
|
break;
|
|
}
|
|
|
|
if let Some(cl) = response_content_length {
|
|
if response_body_buf.len() >= cl {
|
|
break;
|
|
}
|
|
}
|
|
// Check chunked terminator in initial body
|
|
if is_chunked && has_chunked_terminator(&response_body_buf) {
|
|
debug!(domain, "MITM: chunked response complete (initial)");
|
|
break;
|
|
}
|
|
continue;
|
|
}
|
|
continue;
|
|
}
|
|
|
|
// ── Response body interception ────────────────────────────────
|
|
if is_streaming_response {
|
|
let s = String::from_utf8_lossy(chunk);
|
|
parse_streaming_chunk(&s, &mut streaming_acc);
|
|
|
|
// Only write to store if our generation is still current.
|
|
let gen_valid = !won_gate || store.current_generation() == conn_generation;
|
|
if gen_valid {
|
|
// Store captured function calls (drain to avoid re-storing on next chunk)
|
|
if !streaming_acc.function_calls.is_empty() {
|
|
let calls: Vec<_> = streaming_acc.function_calls.drain(..).collect();
|
|
for fc in &calls {
|
|
store
|
|
.record_function_call(cascade_hint.as_deref(), fc.clone())
|
|
.await;
|
|
}
|
|
store.set_last_function_calls(calls.clone()).await;
|
|
info!(
|
|
"MITM: stored {} function call(s) from body chunk",
|
|
calls.len()
|
|
);
|
|
}
|
|
|
|
// Capture response + thinking text + grounding into MitmStore
|
|
if !streaming_acc.response_text.is_empty() {
|
|
store.set_response_text(&streaming_acc.response_text).await;
|
|
}
|
|
if !streaming_acc.thinking_text.is_empty() {
|
|
store.set_thinking_text(&streaming_acc.thinking_text).await;
|
|
}
|
|
if let Some(ref gm) = streaming_acc.grounding_metadata {
|
|
store.set_grounding(gm.clone()).await;
|
|
}
|
|
if streaming_acc.is_complete {
|
|
info!(
|
|
response_text_len = streaming_acc.response_text.len(),
|
|
thinking_text_len = streaming_acc.thinking_text.len(),
|
|
function_calls = streaming_acc.function_calls.len(),
|
|
"MITM: response complete — marking store"
|
|
);
|
|
store.mark_response_complete();
|
|
}
|
|
} else if streaming_acc.is_complete {
|
|
debug!("MITM: skipping store write — generation stale (body chunk)");
|
|
}
|
|
}
|
|
|
|
// Forward chunk to client (LS) — rewrite function calls if custom tools
|
|
let forward_chunk = if modify_requests {
|
|
if let Some(modified) = super::modify::modify_response_chunk(chunk) {
|
|
modified
|
|
} else {
|
|
chunk.to_vec()
|
|
}
|
|
} else {
|
|
chunk.to_vec()
|
|
};
|
|
if let Err(e) = client.write_all(&forward_chunk).await {
|
|
warn!(error = %e, "MITM: write to client failed");
|
|
break;
|
|
}
|
|
response_body_buf.extend_from_slice(chunk);
|
|
|
|
if let Some(cl) = response_content_length {
|
|
if response_body_buf.len() >= cl {
|
|
break;
|
|
}
|
|
}
|
|
if is_chunked && has_chunked_terminator(&response_body_buf) {
|
|
debug!(domain, "MITM: chunked response complete");
|
|
break;
|
|
}
|
|
}
|
|
|
|
// Flush client
|
|
let _ = client.flush().await;
|
|
|
|
// Capture usage data
|
|
if is_streaming_response {
|
|
// Store grounding metadata before consuming the accumulator
|
|
if let Some(ref gm) = streaming_acc.grounding_metadata {
|
|
store.set_grounding(gm.clone()).await;
|
|
}
|
|
if streaming_acc.is_complete || streaming_acc.output_tokens > 0 {
|
|
// Function calls are stored immediately when detected (above),
|
|
// so no need to store them again here.
|
|
let usage = streaming_acc.into_usage();
|
|
store.record_usage(cascade_hint.as_deref(), usage).await;
|
|
}
|
|
} else if !response_body_buf.is_empty() {
|
|
if let Some(usage) = parse_non_streaming_response(&response_body_buf) {
|
|
store.record_usage(cascade_hint.as_deref(), usage).await;
|
|
}
|
|
}
|
|
|
|
// If upstream closed, drop the connection so next iteration reconnects
|
|
if !upstream_ok {
|
|
upstream = None;
|
|
}
|
|
} // end keep-alive loop
|
|
}
|
|
|
|
/// Handle a passthrough connection: transparent TCP tunnel to upstream.
|
|
async fn handle_passthrough(mut client: TcpStream, domain: &str, port: u16) -> Result<(), String> {
|
|
trace!(domain, port, "MITM: transparent tunnel");
|
|
|
|
let mut upstream = TcpStream::connect(format!("{domain}:{port}"))
|
|
.await
|
|
.map_err(|e| format!("Connect to {domain}:{port}: {e}"))?;
|
|
|
|
// Bidirectional copy
|
|
match tokio::io::copy_bidirectional(&mut client, &mut upstream).await {
|
|
Ok((client_to_server, server_to_client)) => {
|
|
trace!(
|
|
domain,
|
|
client_to_server,
|
|
server_to_client,
|
|
"MITM: tunnel closed"
|
|
);
|
|
}
|
|
Err(e) => {
|
|
trace!(domain, error = %e, "MITM: tunnel error (likely clean close)");
|
|
}
|
|
}
|
|
|
|
Ok(())
|
|
}
|
|
|
|
/// Detect the end of an HTTP chunked transfer-encoded body.
///
/// A chunked body ends with the last-chunk marker `0\r\n\r\n` (a zero-size
/// chunk followed by an empty trailer section). Only the tail of `body` is
/// inspected: the marker may be followed by at most two stray bytes.
///
/// The `0` must begin a chunk-size line, i.e. it has to sit at the very start
/// of `body` or directly after a `\r\n`. Without that requirement a size line
/// such as `10\r\n` followed by chunk data that starts with `\r\n` would be
/// mistaken for the terminator and the relayed response would be cut short.
///
/// This is a tail heuristic, not a chunk parser: a last chunk that carries
/// trailer fields is not recognised, and chunk payload that happens to end
/// with `\r\n0\r\n\r\n` still matches.
fn has_chunked_terminator(body: &[u8]) -> bool {
    const LAST_CHUNK: &[u8] = b"0\r\n\r\n";
    const LINE_BREAK: &[u8] = b"\r\n";
    // Number of stray bytes tolerated after the marker.
    const MAX_TRAILING: usize = 2;

    if body.len() < LAST_CHUNK.len() {
        return false;
    }

    // Candidate offsets at which the marker may start within the tail.
    let latest_start = body.len() - LAST_CHUNK.len();
    let earliest_start = latest_start.saturating_sub(MAX_TRAILING);

    (earliest_start..=latest_start).any(|start| {
        body[start..].starts_with(LAST_CHUNK)
            && (start == 0 || body[..start].ends_with(LINE_BREAK))
    })
}
|
|
|
|
/// Check if buffer contains a complete HTTP request (headers + full body).
|
|
/// Uses `httparse` for zero-copy, case-insensitive header parsing.
|
|
fn has_complete_http_request(buf: &[u8]) -> bool {
|
|
let mut headers = [httparse::EMPTY_HEADER; 64];
|
|
let mut req = httparse::Request::new(&mut headers);
|
|
|
|
let headers_end = match req.parse(buf) {
|
|
Ok(httparse::Status::Complete(n)) => n,
|
|
_ => return false, // Incomplete or parse error — need more data
|
|
};
|
|
|
|
// Look for Content-Length
|
|
for header in req.headers.iter() {
|
|
if header.name.eq_ignore_ascii_case("content-length") {
|
|
if let Ok(val) = std::str::from_utf8(header.value) {
|
|
if let Ok(len) = val.trim().parse::<usize>() {
|
|
return buf.len() >= headers_end + len;
|
|
}
|
|
}
|
|
}
|
|
|
|
if header.name.eq_ignore_ascii_case("transfer-encoding") {
|
|
if let Ok(val) = std::str::from_utf8(header.value) {
|
|
if val.trim().eq_ignore_ascii_case("chunked") {
|
|
let body = &buf[headers_end..];
|
|
return body.len() >= 5 && body.ends_with(b"0\r\n\r\n");
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// No Content-Length or Transfer-Encoding — no body expected (e.g., GET)
|
|
true
|
|
}
|
|
|
|
/// Locate the end of the HTTP header block in `buf`.
///
/// Returns the offset of the first byte after the first `\r\n\r\n` sequence,
/// or `None` when `buf` contains no such sequence.
fn find_headers_end(buf: &[u8]) -> Option<usize> {
    const HEADER_TERMINATOR: &[u8] = b"\r\n\r\n";

    let terminator_start = buf
        .windows(HEADER_TERMINATOR.len())
        .position(|window| window == HEADER_TERMINATOR)?;

    Some(terminator_start + HEADER_TERMINATOR.len())
}
|
|
|
|
/// Parse HTTP request metadata from raw bytes using `httparse`.
|
|
/// Returns (headers_end, content_length, is_streaming_request).
|
|
fn parse_http_request_meta(buf: &[u8]) -> (usize, usize, bool) {
|
|
let mut headers = [httparse::EMPTY_HEADER; 64];
|
|
let mut req = httparse::Request::new(&mut headers);
|
|
|
|
let headers_end = match req.parse(buf) {
|
|
Ok(httparse::Status::Complete(n)) => n,
|
|
_ => {
|
|
// Fallback if httparse can't parse
|
|
return (find_headers_end(buf).unwrap_or(buf.len()), 0, false);
|
|
}
|
|
};
|
|
|
|
let mut content_length = 0usize;
|
|
|
|
for header in req.headers.iter() {
|
|
if header.name.eq_ignore_ascii_case("content-length") {
|
|
if let Ok(val) = std::str::from_utf8(header.value) {
|
|
content_length = val.trim().parse().unwrap_or(0);
|
|
}
|
|
}
|
|
}
|
|
|
|
// Check if request body asks for streaming
|
|
let is_streaming = if headers_end < buf.len() {
|
|
let body_str = String::from_utf8_lossy(&buf[headers_end..]);
|
|
body_str.contains("\"stream\":true") || body_str.contains("\"stream\": true")
|
|
} else {
|
|
false
|
|
};
|
|
|
|
(headers_end, content_length, is_streaming)
|
|
}
|
|
|
|
/// Rewrite the Content-Length header in raw HTTP headers to match a new body size.
|
|
/// If no Content-Length header is found, returns the headers unchanged.
|
|
fn update_content_length(headers: &str, new_body_len: usize) -> String {
|
|
use regex::Regex;
|
|
let re = Regex::new(r"(?im)^content-length:\s*\d+").unwrap();
|
|
if re.is_match(headers) {
|
|
re.replace(headers, format!("Content-Length: {new_body_len}"))
|
|
.to_string()
|
|
} else {
|
|
headers.to_string()
|
|
}
|
|
}
|