feat: MITM interception for standalone LS with UID isolation

- Spawn standalone LS as dedicated 'antigravity-ls' user via sudo
- UID-scoped iptables redirect (port 443 → MITM proxy) via mitm-redirect.sh
- Combined CA bundle (system CAs + MITM CA) for Go TLS trust
- Transparent TLS interception with chunked response detection
- Google SSE parser for streamGenerateContent usage extraction
- Timeouts on all MITM operations (TLS handshake, upstream, idle)
- Forward response data immediately (no buffering)
- Per-model token usage capture (input, output, thinking)
- Update docs and known issues to reflect resolved TLS blocker
This commit is contained in:
Nikketryhard
2026-02-14 17:50:12 -06:00
parent 6842bfeaa5
commit d4de436856
10 changed files with 1156 additions and 478 deletions

View File

@@ -85,7 +85,7 @@ pub async fn run(
let store = store.clone();
tokio::spawn(async move {
if let Err(e) = handle_connection(stream, ca, store, modify_requests).await {
debug!(error = %e, "MITM connection error");
warn!(error = %e, "MITM connection error");
}
});
}
@@ -310,18 +310,30 @@ async fn handle_intercepted(
let acceptor = TlsAcceptor::from(server_config);
// Perform TLS handshake with the client (LS)
let tls_stream = acceptor
.accept(stream)
.await
.map_err(|e| format!("TLS handshake with client failed for {domain}: {e}"))?;
// Perform TLS handshake with the client (LS) — 10s timeout
let tls_stream = match tokio::time::timeout(
std::time::Duration::from_secs(10),
acceptor.accept(stream),
)
.await
{
Ok(Ok(s)) => s,
Ok(Err(e)) => {
warn!(domain, error = %e, "MITM: TLS handshake FAILED (client rejected cert?)");
return Err(format!("TLS handshake with client failed for {domain}: {e}"));
}
Err(_) => {
warn!(domain, "MITM: TLS handshake TIMED OUT after 10s");
return Err(format!("TLS handshake timed out for {domain}"));
}
};
// Check negotiated ALPN protocol
let alpn = tls_stream.get_ref().1
.alpn_protocol()
.map(|p| String::from_utf8_lossy(p).to_string());
debug!(domain, alpn = ?alpn, "MITM: TLS handshake successful");
info!(domain, alpn = ?alpn, "MITM: TLS handshake successful");
match alpn.as_deref() {
Some("h2") => {
@@ -336,7 +348,7 @@ async fn handle_intercepted(
}
_ => {
// HTTP/1.1 or no ALPN — use the existing handler
debug!(domain, "MITM: routing to HTTP/1.1 handler");
info!(domain, "MITM: routing to HTTP/1.1 handler");
handle_http_over_tls(tls_stream, domain, store, modify_requests).await
}
}
@@ -382,16 +394,35 @@ async fn handle_http_over_tls(
// Try to resolve the real IP, bypassing /etc/hosts
let addr = resolve_upstream(domain).await;
info!(domain, addr = %addr, "MITM: connecting upstream");
let tcp = match tokio::time::timeout(
std::time::Duration::from_secs(15),
TcpStream::connect(&addr),
)
.await
{
Ok(Ok(s)) => s,
Ok(Err(e)) => return Err(format!("Connect to upstream {domain} ({addr}): {e}")),
Err(_) => return Err(format!("Connect to upstream {domain} ({addr}): timed out")),
};
let tcp = TcpStream::connect(addr)
.await
.map_err(|e| format!("Connect to upstream {domain}: {e}"))?;
let server_name = rustls::pki_types::ServerName::try_from(domain.to_string())
.map_err(|e| format!("Invalid server name: {e}"))?;
connector
.connect(server_name, tcp)
.await
.map_err(|e| format!("TLS connect to upstream {domain}: {e}"))
match tokio::time::timeout(
std::time::Duration::from_secs(15),
connector.connect(server_name, tcp),
)
.await
{
Ok(Ok(s)) => {
info!(domain, "MITM: upstream TLS connected ✓");
Ok(s)
}
Ok(Err(e)) => Err(format!("TLS connect to upstream {domain}: {e}")),
Err(_) => Err(format!("TLS connect to upstream {domain}: timed out")),
}
}
/// Resolve upstream IP bypassing /etc/hosts.
@@ -428,8 +459,37 @@ async fn handle_http_over_tls(
// ── Read the HTTP request from the client ─────────────────────────
let mut request_buf = Vec::with_capacity(1024 * 64);
// 60s timeout on initial read (LS may open connection without sending immediately)
const IDLE_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(60);
loop {
let n = match client.read(&mut tmp).await {
let read_result = if request_buf.is_empty() {
// First read — apply idle timeout
match tokio::time::timeout(IDLE_TIMEOUT, client.read(&mut tmp)).await {
Ok(r) => r,
Err(_) => {
// Idle timeout — connection pool warmup, no data sent
debug!(domain, "MITM: client idle timeout (60s), closing");
return Ok(());
}
}
} else {
// Subsequent reads — wait up to 30s for rest of request
match tokio::time::timeout(
std::time::Duration::from_secs(30),
client.read(&mut tmp),
)
.await
{
Ok(r) => r,
Err(_) => {
warn!(domain, "MITM: partial request read timed out");
return Err("Partial request read timed out".into());
}
}
};
let n = match read_result {
Ok(0) => return Ok(()), // Client closed connection cleanly
Ok(n) => n,
Err(e) => {
@@ -461,12 +521,25 @@ async fn handle_http_over_tls(
None
};
debug!(
// Extract request method and path for logging
let req_path = {
let mut headers = [httparse::EMPTY_HEADER; 64];
let mut req = httparse::Request::new(&mut headers);
match req.parse(&request_buf) {
Ok(httparse::Status::Complete(_)) => {
format!("{} {}", req.method.unwrap_or("?"), req.path.unwrap_or("?"))
}
_ => "?".to_string(),
}
};
info!(
domain,
req_path = %req_path,
content_length,
streaming = is_streaming_request,
cascade = ?cascade_hint,
"MITM: forwarding request to upstream"
"MITM: forwarding request"
);
// ── Ensure upstream connection is alive ──────────────────────────────
@@ -492,118 +565,139 @@ async fn handle_http_over_tls(
let conn = upstream.as_mut().unwrap();
// ── Stream response back to client ──────────────────────────────────
// ALWAYS forward data to client immediately (no buffering).
// Buffer body on the side for usage parsing.
let mut streaming_acc = StreamingAccumulator::new();
let mut is_streaming_response = false;
let mut headers_parsed = false;
// Only buffer response body for non-streaming (for usage parsing)
let mut non_streaming_buf: Option<Vec<u8>> = None;
// Track if upstream connection is still usable after this response
let mut upstream_ok = true;
// Per-request timeout: 5 minutes (covers large context API calls)
const READ_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(300);
let mut response_body_buf = Vec::new();
let mut response_content_length: Option<usize> = None;
let mut is_chunked = false;
let mut got_first_byte = false;
let mut header_buf = Vec::with_capacity(8192);
loop {
let n = match tokio::time::timeout(READ_TIMEOUT, conn.read(&mut tmp)).await {
Ok(Ok(0)) => {
// Upstream closed — connection is no longer reusable
upstream_ok = false;
break;
}
// 15s idle timeout after first byte, 60s for initial response
let timeout = if got_first_byte {
std::time::Duration::from_secs(15)
} else {
std::time::Duration::from_secs(60)
};
let n = match tokio::time::timeout(timeout, conn.read(&mut tmp)).await {
Ok(Ok(0)) => { upstream_ok = false; break; }
Ok(Ok(n)) => n,
Ok(Err(e)) => {
debug!(domain, error = %e, "MITM: upstream read finished");
debug!(domain, error = %e, "MITM: upstream read ended");
upstream_ok = false;
break;
}
Err(_) => {
warn!(domain, "MITM: upstream read timed out after 5 minutes");
if got_first_byte {
debug!(domain, "MITM: response idle timeout (complete)");
} else {
warn!(domain, "MITM: no upstream response in 60s");
}
upstream_ok = false;
break;
}
};
got_first_byte = true;
let chunk = &tmp[..n];
// Check response headers for content-type
if !headers_parsed {
// We need to buffer until we see the end of headers
let buf = non_streaming_buf.get_or_insert_with(|| Vec::with_capacity(1024 * 64));
buf.extend_from_slice(chunk);
if let Some(_hdr_end) = find_headers_end(buf) {
// Use httparse for response header parsing
header_buf.extend_from_slice(chunk);
if let Some(_hdr_end) = find_headers_end(&header_buf) {
let mut resp_headers = [httparse::EMPTY_HEADER; 64];
let mut resp = httparse::Response::new(&mut resp_headers);
let hdr_end = match resp.parse(buf) {
let hdr_end = match resp.parse(&header_buf) {
Ok(httparse::Status::Complete(n)) => n,
_ => _hdr_end, // Fallback to manual detection
_ => _hdr_end,
};
// Detect content type and connection handling from parsed headers
let mut content_type = String::new();
for header in resp.headers.iter() {
if header.name.eq_ignore_ascii_case("content-type") {
if let Ok(val) = std::str::from_utf8(header.value) {
if val.contains("text/event-stream") {
is_streaming_response = true;
}
if let Ok(v) = std::str::from_utf8(header.value) {
content_type = v.to_string();
if v.contains("text/event-stream") { is_streaming_response = true; }
}
}
if header.name.eq_ignore_ascii_case("content-length") {
if let Ok(v) = std::str::from_utf8(header.value) {
response_content_length = v.trim().parse().ok();
}
}
if header.name.eq_ignore_ascii_case("connection") {
if let Ok(val) = std::str::from_utf8(header.value) {
if val.trim().eq_ignore_ascii_case("close") {
upstream_ok = false;
}
if let Ok(v) = std::str::from_utf8(header.value) {
if v.trim().eq_ignore_ascii_case("close") { upstream_ok = false; }
}
}
if header.name.eq_ignore_ascii_case("transfer-encoding") {
if let Ok(v) = std::str::from_utf8(header.value) {
if v.trim().eq_ignore_ascii_case("chunked") { is_chunked = true; }
}
}
}
info!(domain, streaming = is_streaming_response,
content_length = ?response_content_length,
content_type = %content_type,
status = resp.code, "MITM: got response headers");
headers_parsed = true;
if is_streaming_response {
// For streaming, parse any SSE data already in the buffer
let body_so_far = String::from_utf8_lossy(&buf[hdr_end..]);
if !body_so_far.is_empty() {
parse_streaming_chunk(&body_so_far, &mut streaming_acc);
}
// Forward the accumulated buffer to client
if let Err(e) = client.write_all(buf).await {
warn!(error = %e, "MITM: write to client failed");
break;
}
non_streaming_buf = None;
continue;
// Save body for usage parsing
response_body_buf.extend_from_slice(&header_buf[hdr_end..]);
// Forward to client immediately
if let Err(e) = client.write_all(&header_buf).await {
warn!(error = %e, "MITM: write to client failed");
break;
}
if is_streaming_response && hdr_end < header_buf.len() {
let body = String::from_utf8_lossy(&header_buf[hdr_end..]);
parse_streaming_chunk(&body, &mut streaming_acc);
}
if let Some(cl) = response_content_length {
if response_body_buf.len() >= cl { break; }
}
// Check chunked terminator in initial body
if is_chunked && has_chunked_terminator(&response_body_buf) {
debug!(domain, "MITM: chunked response complete (initial)");
break;
}
// Non-streaming: keep buffering the response body for parsing
continue;
}
continue;
}
// If streaming, parse SSE events and forward immediately
// Forward to client immediately
if let Err(e) = client.write_all(chunk).await {
warn!(error = %e, "MITM: write to client failed");
break;
}
response_body_buf.extend_from_slice(chunk);
if is_streaming_response {
let chunk_str = String::from_utf8_lossy(chunk);
parse_streaming_chunk(&chunk_str, &mut streaming_acc);
if let Err(e) = client.write_all(chunk).await {
warn!(error = %e, "MITM: write to client failed (client disconnected?)");
break;
}
} else {
// Non-streaming: keep accumulating to parse usage at the end
if let Some(ref mut buf) = non_streaming_buf {
buf.extend_from_slice(chunk);
}
let s = String::from_utf8_lossy(chunk);
parse_streaming_chunk(&s, &mut streaming_acc);
}
if let Some(cl) = response_content_length {
if response_body_buf.len() >= cl { break; }
}
if is_chunked && has_chunked_terminator(&response_body_buf) {
debug!(domain, "MITM: chunked response complete");
break;
}
}
// Forward non-streaming response all at once
if !is_streaming_response {
if let Some(ref buf) = non_streaming_buf {
if let Err(e) = client.write_all(buf).await {
warn!(error = %e, "MITM: write to client failed");
}
}
}
// Flush client
let _ = client.flush().await;
// Capture usage data
if is_streaming_response {
@@ -611,12 +705,9 @@ async fn handle_http_over_tls(
let usage = streaming_acc.into_usage();
store.record_usage(cascade_hint.as_deref(), usage).await;
}
} else if let Some(ref buf) = non_streaming_buf {
if let Some(body_start) = find_headers_end(buf) {
let body = &buf[body_start..];
if let Some(usage) = parse_non_streaming_response(body) {
store.record_usage(cascade_hint.as_deref(), usage).await;
}
} else if !response_body_buf.is_empty() {
if let Some(usage) = parse_non_streaming_response(&response_body_buf) {
store.record_usage(cascade_hint.as_deref(), usage).await;
}
}
@@ -652,6 +743,20 @@ async fn handle_passthrough(
Ok(())
}
/// Detect end of HTTP chunked transfer encoding.
/// A chunked response ends with "0\r\n\r\n" (zero-length chunk + empty trailer).
/// We check the tail of the buffer for this pattern.
fn has_chunked_terminator(body: &[u8]) -> bool {
// The minimal terminator is "0\r\n\r\n" (5 bytes)
if body.len() < 5 {
return false;
}
// Check last 7 bytes to account for possible trailing whitespace
let tail = if body.len() > 7 { &body[body.len() - 7..] } else { body };
// Look for \r\n0\r\n\r\n anywhere in the tail
tail.windows(5).any(|w| w == b"0\r\n\r\n")
}
/// Check if buffer contains a complete HTTP request (headers + full body).
/// Uses `httparse` for zero-copy, case-insensitive header parsing.
fn has_complete_http_request(buf: &[u8]) -> bool {