feat: initial commit — antigravity proxy with MITM, standalone LS, and snapshot tooling

This commit is contained in:
Nikketryhard
2026-02-14 02:24:35 -06:00
commit d5e7f09225
30 changed files with 9980 additions and 0 deletions

591
src/mitm/proxy.rs Normal file
View File

@@ -0,0 +1,591 @@
//! MITM proxy server: handles CONNECT tunnels and TLS interception.
//!
//! Listens on a local port for HTTP CONNECT requests from the LS.
//! For intercepted domains, it terminates TLS with our CA-signed cert,
//! reads/modifies the request, forwards to the real upstream, and captures
//! the response (especially usage data).
//!
//! For non-intercepted domains, it acts as a transparent TCP tunnel.
use super::ca::MitmCa;
use super::intercept::{
extract_cascade_hint, parse_non_streaming_response, parse_streaming_chunk,
StreamingAccumulator,
};
use super::store::MitmStore;
use std::sync::Arc;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::{TcpListener, TcpStream};
use tokio_rustls::TlsAcceptor;
use tracing::{debug, error, info, trace, warn};
/// Domains we intercept (terminate TLS and inspect traffic).
/// This includes exact matches and suffix matches for regional endpoints
/// (e.g., us-central1-aiplatform.googleapis.com).
const INTERCEPT_DOMAINS: &[&str] = &[
"cloudcode-pa.googleapis.com",
"aiplatform.googleapis.com",
"api.anthropic.com",
"speech.googleapis.com",
"modelarmor.googleapis.com",
];
/// Domains we NEVER intercept (transparent tunnel).
const PASSTHROUGH_DOMAINS: &[&str] = &[
"oauth2.googleapis.com",
"accounts.google.com",
"storage.googleapis.com",
"www.googleapis.com",
"firebaseinstallations.googleapis.com",
"crashlyticsreports-pa.googleapis.com",
"play.googleapis.com",
"update.googleapis.com",
"dl.google.com",
];
/// Configuration for the MITM proxy.
pub struct MitmConfig {
/// Port to listen on (0 = auto-assign).
pub port: u16,
/// Whether to enable request modification.
pub modify_requests: bool,
}
impl Default for MitmConfig {
fn default() -> Self {
Self {
port: 0,
modify_requests: false,
}
}
}
/// Run the MITM proxy server.
///
/// Returns (port, task_handle) — port it's listening on, handle to abort on shutdown.
pub async fn run(
ca: Arc<MitmCa>,
store: MitmStore,
config: MitmConfig,
) -> Result<(u16, tokio::task::JoinHandle<()>), String> {
let listener = TcpListener::bind(format!("127.0.0.1:{}", config.port))
.await
.map_err(|e| format!("MITM bind failed: {e}"))?;
let port = listener
.local_addr()
.map_err(|e| format!("MITM local_addr failed: {e}"))?
.port();
info!(port, "MITM proxy listening");
let modify_requests = config.modify_requests;
let handle = tokio::spawn(async move {
loop {
match listener.accept().await {
Ok((stream, addr)) => {
trace!(?addr, "MITM: new connection");
let ca = ca.clone();
let store = store.clone();
tokio::spawn(async move {
if let Err(e) = handle_connection(stream, ca, store, modify_requests).await {
debug!(error = %e, "MITM connection error");
}
});
}
Err(e) => {
error!(error = %e, "MITM accept error");
}
}
}
});
Ok((port, handle))
}
/// Handle a single incoming connection from the LS.
///
/// The LS sends an HTTP CONNECT request to establish a tunnel.
/// We then decide whether to intercept or passthrough.
async fn handle_connection(
mut stream: TcpStream,
ca: Arc<MitmCa>,
store: MitmStore,
modify_requests: bool,
) -> Result<(), String> {
// Read the CONNECT request
let mut buf = vec![0u8; 8192];
let n = stream
.read(&mut buf)
.await
.map_err(|e| format!("Read CONNECT: {e}"))?;
if n == 0 {
return Ok(());
}
let request = String::from_utf8_lossy(&buf[..n]);
let first_line = request.lines().next().unwrap_or("");
// Parse "CONNECT host:port HTTP/1.1"
let parts: Vec<&str> = first_line.split_whitespace().collect();
if parts.len() < 3 || parts[0] != "CONNECT" {
// Not a CONNECT request — return 400
let resp = "HTTP/1.1 400 Bad Request\r\n\r\n";
let _ = stream.write_all(resp.as_bytes()).await;
return Ok(());
}
let host_port = parts[1];
let (domain, _port) = match host_port.rsplit_once(':') {
Some((h, p)) => (h, p.parse::<u16>().unwrap_or(443)),
None => (host_port, 443),
};
debug!(domain, "MITM: CONNECT request");
// Decide: intercept or passthrough
let should_intercept = should_intercept_domain(domain);
// Send 200 Connection Established
let response = "HTTP/1.1 200 Connection Established\r\n\r\n";
stream
.write_all(response.as_bytes())
.await
.map_err(|e| format!("Write 200: {e}"))?;
if should_intercept {
handle_intercepted(stream, domain, ca, store, modify_requests).await
} else {
handle_passthrough(stream, domain, _port).await
}
}
/// Check if a domain should be intercepted.
fn should_intercept_domain(domain: &str) -> bool {
// Never intercept passthrough domains
for &pt in PASSTHROUGH_DOMAINS {
if domain == pt {
return false;
}
}
// Intercept known API domains (exact match, subdomain, or regional prefix)
for &intercept in INTERCEPT_DOMAINS {
if domain == intercept
|| domain.ends_with(&format!(".{intercept}"))
|| domain.ends_with(&format!("-{intercept}"))
{
return true;
}
}
// Default: passthrough
false
}
/// Handle an intercepted connection: terminate TLS, inspect traffic.
///
/// After TLS termination, checks the negotiated ALPN protocol:
/// - `h2` → HTTP/2 handler (for gRPC traffic to Google APIs)
/// - `http/1.1` or none → HTTP/1.1 handler (for REST/SSE traffic)
async fn handle_intercepted(
stream: TcpStream,
domain: &str,
ca: Arc<MitmCa>,
store: MitmStore,
modify_requests: bool,
) -> Result<(), String> {
info!(domain, "MITM: intercepting TLS");
// Get or create server TLS config for this domain
let server_config = ca
.server_config_for_domain(domain)
.await?;
let acceptor = TlsAcceptor::from(server_config);
// Perform TLS handshake with the client (LS)
let tls_stream = acceptor
.accept(stream)
.await
.map_err(|e| format!("TLS handshake with client failed for {domain}: {e}"))?;
// Check negotiated ALPN protocol
let alpn = tls_stream.get_ref().1
.alpn_protocol()
.map(|p| String::from_utf8_lossy(p).to_string());
debug!(domain, alpn = ?alpn, "MITM: TLS handshake successful");
match alpn.as_deref() {
Some("h2") => {
// HTTP/2 — use the hyper-based gRPC handler
info!(domain, "MITM: routing to HTTP/2 handler (gRPC)");
super::h2_handler::handle_h2_connection(
tls_stream,
domain.to_string(),
store,
)
.await
}
_ => {
// HTTP/1.1 or no ALPN — use the existing handler
debug!(domain, "MITM: routing to HTTP/1.1 handler");
handle_http_over_tls(tls_stream, domain, store, modify_requests).await
}
}
}
/// Handle HTTP traffic over the decrypted TLS connection.
///
/// Loops to handle multiple requests on the same connection (HTTP keep-alive).
/// Reads full request, connects to upstream, forwards request, streams response
/// back to client while capturing usage data.
async fn handle_http_over_tls(
mut client: tokio_rustls::server::TlsStream<TcpStream>,
domain: &str,
store: MitmStore,
_modify_requests: bool,
) -> Result<(), String> {
let mut tmp = vec![0u8; 32768];
// Build upstream TLS connector once for this connection
let mut root_store = rustls::RootCertStore::empty();
let native_certs = rustls_native_certs::load_native_certs();
for cert in native_certs.certs {
let _ = root_store.add(cert);
}
let upstream_config = Arc::new(
rustls::ClientConfig::builder()
.with_root_certificates(root_store)
.with_no_client_auth(),
);
// Reusable upstream connection — created lazily, reconnected if stale
let mut upstream: Option<tokio_rustls::client::TlsStream<TcpStream>> = None;
/// Connect (or reconnect) to the real upstream via TLS.
async fn connect_upstream(
domain: &str,
config: &Arc<rustls::ClientConfig>,
) -> Result<tokio_rustls::client::TlsStream<TcpStream>, String> {
let connector = tokio_rustls::TlsConnector::from(config.clone());
let tcp = TcpStream::connect(format!("{domain}:443"))
.await
.map_err(|e| format!("Connect to upstream {domain}: {e}"))?;
let server_name = rustls::pki_types::ServerName::try_from(domain.to_string())
.map_err(|e| format!("Invalid server name: {e}"))?;
connector
.connect(server_name, tcp)
.await
.map_err(|e| format!("TLS connect to upstream {domain}: {e}"))
}
// Keep-alive loop: handle multiple requests on this connection
loop {
// ── Read the HTTP request from the client ─────────────────────────
let mut request_buf = Vec::with_capacity(1024 * 64);
loop {
let n = match client.read(&mut tmp).await {
Ok(0) => return Ok(()), // Client closed connection cleanly
Ok(n) => n,
Err(e) => {
// Connection reset / broken pipe is normal for keep-alive end
debug!(domain, error = %e, "MITM: client read finished");
return Ok(());
}
};
request_buf.extend_from_slice(&tmp[..n]);
// Check if we have the full request (headers + body)
if has_complete_http_request(&request_buf) {
break;
}
}
if request_buf.is_empty() {
return Ok(());
}
// Parse the HTTP request to find headers and body
let (headers_end, content_length, is_streaming_request) = parse_http_request_meta(&request_buf);
// Try to extract cascade hint from request body
let cascade_hint = if headers_end < request_buf.len() {
extract_cascade_hint(&request_buf[headers_end..])
} else {
None
};
debug!(
domain,
content_length,
streaming = is_streaming_request,
cascade = ?cascade_hint,
"MITM: forwarding request to upstream"
);
// ── Ensure upstream connection is alive ──────────────────────────────
// Lazily connect on first request, or reconnect if the previous connection died
let conn = match upstream.as_mut() {
Some(c) => c,
None => {
let c = connect_upstream(domain, &upstream_config).await?;
upstream.insert(c)
}
};
// Forward the request — if write fails, reconnect and retry once
if let Err(e) = conn.write_all(&request_buf).await {
debug!(domain, error = %e, "MITM: upstream write failed, reconnecting");
let c = connect_upstream(domain, &upstream_config).await?;
let conn = upstream.insert(c);
conn.write_all(&request_buf)
.await
.map_err(|e| format!("Write to upstream (retry): {e}"))?;
}
let conn = upstream.as_mut().unwrap();
// ── Stream response back to client ──────────────────────────────────
let mut streaming_acc = StreamingAccumulator::new();
let mut is_streaming_response = false;
let mut headers_parsed = false;
// Only buffer response body for non-streaming (for usage parsing)
let mut non_streaming_buf: Option<Vec<u8>> = None;
// Track if upstream connection is still usable after this response
let mut upstream_ok = true;
// Per-request timeout: 5 minutes (covers large context API calls)
const READ_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(300);
loop {
let n = match tokio::time::timeout(READ_TIMEOUT, conn.read(&mut tmp)).await {
Ok(Ok(0)) => {
// Upstream closed — connection is no longer reusable
upstream_ok = false;
break;
}
Ok(Ok(n)) => n,
Ok(Err(e)) => {
debug!(domain, error = %e, "MITM: upstream read finished");
upstream_ok = false;
break;
}
Err(_) => {
warn!(domain, "MITM: upstream read timed out after 5 minutes");
upstream_ok = false;
break;
}
};
let chunk = &tmp[..n];
// Check response headers for content-type
if !headers_parsed {
// We need to buffer until we see the end of headers
let buf = non_streaming_buf.get_or_insert_with(|| Vec::with_capacity(1024 * 64));
buf.extend_from_slice(chunk);
if let Some(_hdr_end) = find_headers_end(buf) {
// Use httparse for response header parsing
let mut resp_headers = [httparse::EMPTY_HEADER; 64];
let mut resp = httparse::Response::new(&mut resp_headers);
let hdr_end = match resp.parse(buf) {
Ok(httparse::Status::Complete(n)) => n,
_ => _hdr_end, // Fallback to manual detection
};
// Detect content type and connection handling from parsed headers
for header in resp.headers.iter() {
if header.name.eq_ignore_ascii_case("content-type") {
if let Ok(val) = std::str::from_utf8(header.value) {
if val.contains("text/event-stream") {
is_streaming_response = true;
}
}
}
if header.name.eq_ignore_ascii_case("connection") {
if let Ok(val) = std::str::from_utf8(header.value) {
if val.trim().eq_ignore_ascii_case("close") {
upstream_ok = false;
}
}
}
}
headers_parsed = true;
if is_streaming_response {
// For streaming, parse any SSE data already in the buffer
let body_so_far = String::from_utf8_lossy(&buf[hdr_end..]);
if !body_so_far.is_empty() {
parse_streaming_chunk(&body_so_far, &mut streaming_acc);
}
// Forward the accumulated buffer to client
if let Err(e) = client.write_all(buf).await {
warn!(error = %e, "MITM: write to client failed");
break;
}
non_streaming_buf = None;
continue;
}
// Non-streaming: keep buffering the response body for parsing
continue;
}
continue;
}
// If streaming, parse SSE events and forward immediately
if is_streaming_response {
let chunk_str = String::from_utf8_lossy(chunk);
parse_streaming_chunk(&chunk_str, &mut streaming_acc);
if let Err(e) = client.write_all(chunk).await {
warn!(error = %e, "MITM: write to client failed (client disconnected?)");
break;
}
} else {
// Non-streaming: keep accumulating to parse usage at the end
if let Some(ref mut buf) = non_streaming_buf {
buf.extend_from_slice(chunk);
}
}
}
// Forward non-streaming response all at once
if !is_streaming_response {
if let Some(ref buf) = non_streaming_buf {
if let Err(e) = client.write_all(buf).await {
warn!(error = %e, "MITM: write to client failed");
}
}
}
// Capture usage data
if is_streaming_response {
if streaming_acc.is_complete || streaming_acc.output_tokens > 0 {
let usage = streaming_acc.into_usage();
store.record_usage(cascade_hint.as_deref(), usage).await;
}
} else if let Some(ref buf) = non_streaming_buf {
if let Some(body_start) = find_headers_end(buf) {
let body = &buf[body_start..];
if let Some(usage) = parse_non_streaming_response(body) {
store.record_usage(cascade_hint.as_deref(), usage).await;
}
}
}
// If upstream closed, drop the connection so next iteration reconnects
if !upstream_ok {
upstream = None;
}
} // end keep-alive loop
}
/// Handle a passthrough connection: transparent TCP tunnel to upstream.
async fn handle_passthrough(
mut client: TcpStream,
domain: &str,
port: u16,
) -> Result<(), String> {
trace!(domain, port, "MITM: transparent tunnel");
let mut upstream = TcpStream::connect(format!("{domain}:{port}"))
.await
.map_err(|e| format!("Connect to {domain}:{port}: {e}"))?;
// Bidirectional copy
match tokio::io::copy_bidirectional(&mut client, &mut upstream).await {
Ok((client_to_server, server_to_client)) => {
trace!(domain, client_to_server, server_to_client, "MITM: tunnel closed");
}
Err(e) => {
trace!(domain, error = %e, "MITM: tunnel error (likely clean close)");
}
}
Ok(())
}
/// Check if buffer contains a complete HTTP request (headers + full body).
/// Uses `httparse` for zero-copy, case-insensitive header parsing.
fn has_complete_http_request(buf: &[u8]) -> bool {
let mut headers = [httparse::EMPTY_HEADER; 64];
let mut req = httparse::Request::new(&mut headers);
let headers_end = match req.parse(buf) {
Ok(httparse::Status::Complete(n)) => n,
_ => return false, // Incomplete or parse error — need more data
};
// Look for Content-Length
for header in req.headers.iter() {
if header.name.eq_ignore_ascii_case("content-length") {
if let Ok(val) = std::str::from_utf8(header.value) {
if let Ok(len) = val.trim().parse::<usize>() {
return buf.len() >= headers_end + len;
}
}
}
if header.name.eq_ignore_ascii_case("transfer-encoding") {
if let Ok(val) = std::str::from_utf8(header.value) {
if val.trim().eq_ignore_ascii_case("chunked") {
let body = &buf[headers_end..];
return body.len() >= 5 && body.ends_with(b"0\r\n\r\n");
}
}
}
}
// No Content-Length or Transfer-Encoding — no body expected (e.g., GET)
true
}
/// Find the end of HTTP headers (position after \r\n\r\n).
fn find_headers_end(buf: &[u8]) -> Option<usize> {
buf.windows(4)
.position(|w| w == b"\r\n\r\n")
.map(|pos| pos + 4)
}
/// Parse HTTP request metadata from raw bytes using `httparse`.
/// Returns (headers_end, content_length, is_streaming_request).
fn parse_http_request_meta(buf: &[u8]) -> (usize, usize, bool) {
let mut headers = [httparse::EMPTY_HEADER; 64];
let mut req = httparse::Request::new(&mut headers);
let headers_end = match req.parse(buf) {
Ok(httparse::Status::Complete(n)) => n,
_ => {
// Fallback if httparse can't parse
return (find_headers_end(buf).unwrap_or(buf.len()), 0, false);
}
};
let mut content_length = 0usize;
for header in req.headers.iter() {
if header.name.eq_ignore_ascii_case("content-length") {
if let Ok(val) = std::str::from_utf8(header.value) {
content_length = val.trim().parse().unwrap_or(0);
}
}
}
// Check if request body asks for streaming
let is_streaming = if headers_end < buf.len() {
let body_str = String::from_utf8_lossy(&buf[headers_end..]);
body_str.contains("\"stream\":true") || body_str.contains("\"stream\": true")
} else {
false
};
(headers_end, content_length, is_streaming)
}