feat: transparent proxy mode with SNI extraction and DNS bypass for upstream

This commit is contained in:
Nikketryhard
2026-02-14 04:03:19 -06:00
parent df7dcc96db
commit 4fa8775b61

View File

@@ -1,11 +1,13 @@
//! MITM proxy server: handles CONNECT tunnels and TLS interception. //! MITM proxy server: handles CONNECT tunnels and transparent TLS interception.
//! //!
//! Listens on a local port for HTTP CONNECT requests from the LS. //! Supports two modes:
//! For intercepted domains, it terminates TLS with our CA-signed cert, //! 1. **HTTP CONNECT** — standard proxy mode, LS sends `CONNECT host:port`
//! reads/modifies the request, forwards to the real upstream, and captures //! 2. **Transparent (iptables)** — raw TLS arrives via REDIRECT, SNI extracted from ClientHello
//! the response (especially usage data).
//! //!
//! For non-intercepted domains, it acts as a transparent TCP tunnel. //! For intercepted domains, terminates TLS with our CA-signed cert,
//! reads/modifies the request, forwards to the real upstream, and captures usage.
//!
//! For non-intercepted domains, acts as a transparent TCP tunnel.
use super::ca::MitmCa; use super::ca::MitmCa;
use super::intercept::{ use super::intercept::{
@@ -104,17 +106,53 @@ pub async fn run(
Ok((port, handle)) Ok((port, handle))
} }
/// Handle a single incoming connection from the LS. /// Handle a single incoming connection.
/// ///
/// The LS sends an HTTP CONNECT request to establish a tunnel. /// Supports two modes:
/// We then decide whether to intercept or passthrough. /// 1. **HTTP CONNECT** (standard proxy) — LS sends `CONNECT host:port HTTP/1.1`
/// 2. **Transparent/iptables redirect** — raw TLS ClientHello arrives directly
/// (first byte is 0x16). We extract the domain from SNI and intercept.
async fn handle_connection( async fn handle_connection(
mut stream: TcpStream, mut stream: TcpStream,
ca: Arc<MitmCa>, ca: Arc<MitmCa>,
store: MitmStore, store: MitmStore,
modify_requests: bool, modify_requests: bool,
) -> Result<(), String> { ) -> Result<(), String> {
// Read the CONNECT request // Peek at the first byte to distinguish CONNECT vs raw TLS
let mut peek = [0u8; 1];
let n = stream
.peek(&mut peek)
.await
.map_err(|e| format!("Peek failed: {e}"))?;
if n == 0 {
return Ok(());
}
if peek[0] == 0x16 {
// TLS ClientHello — transparent/iptables redirect mode.
// Peek enough bytes to extract SNI from the ClientHello.
let mut hello_buf = vec![0u8; 16384];
let n = stream
.peek(&mut hello_buf)
.await
.map_err(|e| format!("Peek ClientHello: {e}"))?;
let domain = extract_sni(&hello_buf[..n])
.unwrap_or_else(|| "unknown".to_string());
info!(domain, "MITM: transparent redirect (iptables)");
let should_intercept = should_intercept_domain(&domain);
if should_intercept {
handle_intercepted(stream, &domain, ca, store, modify_requests).await
} else {
// For non-intercepted domains via iptables, we need the original dest.
// Since we only have SNI, resolve and passthrough.
handle_passthrough(stream, &domain, 443).await
}
} else {
// Standard HTTP CONNECT proxy mode
let mut buf = vec![0u8; 8192]; let mut buf = vec![0u8; 8192];
let n = stream let n = stream
.read(&mut buf) .read(&mut buf)
@@ -131,7 +169,6 @@ async fn handle_connection(
// Parse "CONNECT host:port HTTP/1.1" // Parse "CONNECT host:port HTTP/1.1"
let parts: Vec<&str> = first_line.split_whitespace().collect(); let parts: Vec<&str> = first_line.split_whitespace().collect();
if parts.len() < 3 || parts[0] != "CONNECT" { if parts.len() < 3 || parts[0] != "CONNECT" {
// Not a CONNECT request — return 400
let resp = "HTTP/1.1 400 Bad Request\r\n\r\n"; let resp = "HTTP/1.1 400 Bad Request\r\n\r\n";
let _ = stream.write_all(resp.as_bytes()).await; let _ = stream.write_all(resp.as_bytes()).await;
return Ok(()); return Ok(());
@@ -145,7 +182,6 @@ async fn handle_connection(
debug!(domain, "MITM: CONNECT request"); debug!(domain, "MITM: CONNECT request");
// Decide: intercept or passthrough
let should_intercept = should_intercept_domain(domain); let should_intercept = should_intercept_domain(domain);
// Send 200 Connection Established // Send 200 Connection Established
@@ -160,6 +196,81 @@ async fn handle_connection(
} else { } else {
handle_passthrough(stream, domain, _port).await handle_passthrough(stream, domain, _port).await
} }
}
}
/// Extract SNI (Server Name Indication) from a TLS ClientHello message.
///
/// Parses the raw TLS record to find the `server_name` extension (type 0x0000).
/// Returns `None` if the SNI can't be found (not TLS, no SNI extension, etc.).
fn extract_sni(buf: &[u8]) -> Option<String> {
// TLS record: type(1) + version(2) + length(2) + handshake
if buf.len() < 5 || buf[0] != 0x16 {
return None;
}
let record_len = u16::from_be_bytes([buf[3], buf[4]]) as usize;
let handshake = &buf[5..5 + record_len.min(buf.len() - 5)];
// Handshake: type(1) + length(3) + body
if handshake.is_empty() || handshake[0] != 0x01 {
return None; // Not ClientHello
}
if handshake.len() < 4 {
return None;
}
let hs_len = u32::from_be_bytes([0, handshake[1], handshake[2], handshake[3]]) as usize;
let body = &handshake[4..4 + hs_len.min(handshake.len() - 4)];
// ClientHello: version(2) + random(32) + session_id_len(1) + session_id(var)
// + cipher_suites_len(2) + cipher_suites(var)
// + compression_len(1) + compression(var)
// + extensions_len(2) + extensions(var)
if body.len() < 34 {
return None;
}
let mut pos = 34; // skip version + random
// Session ID
if pos >= body.len() { return None; }
let sid_len = body[pos] as usize;
pos += 1 + sid_len;
// Cipher suites
if pos + 2 > body.len() { return None; }
let cs_len = u16::from_be_bytes([body[pos], body[pos + 1]]) as usize;
pos += 2 + cs_len;
// Compression methods
if pos >= body.len() { return None; }
let cm_len = body[pos] as usize;
pos += 1 + cm_len;
// Extensions
if pos + 2 > body.len() { return None; }
let ext_len = u16::from_be_bytes([body[pos], body[pos + 1]]) as usize;
pos += 2;
let ext_end = pos + ext_len.min(body.len() - pos);
while pos + 4 <= ext_end {
let ext_type = u16::from_be_bytes([body[pos], body[pos + 1]]);
let ext_data_len = u16::from_be_bytes([body[pos + 2], body[pos + 3]]) as usize;
pos += 4;
if ext_type == 0x0000 {
// SNI extension — server_name_list_len(2) + type(1) + name_len(2) + name
if ext_data_len >= 5 && pos + ext_data_len <= ext_end {
let name_len = u16::from_be_bytes([body[pos + 3], body[pos + 4]]) as usize;
if pos + 5 + name_len <= ext_end {
return String::from_utf8(body[pos + 5..pos + 5 + name_len].to_vec()).ok();
}
}
return None;
}
pos += ext_data_len;
}
None
} }
/// Check if a domain should be intercepted. /// Check if a domain should be intercepted.
@@ -267,12 +378,19 @@ async fn handle_http_over_tls(
let mut upstream: Option<tokio_rustls::client::TlsStream<TcpStream>> = None; let mut upstream: Option<tokio_rustls::client::TlsStream<TcpStream>> = None;
/// Connect (or reconnect) to the real upstream via TLS. /// Connect (or reconnect) to the real upstream via TLS.
///
/// Bypasses /etc/hosts by resolving via direct DNS query (dig @8.8.8.8),
/// then falls back to cached IPs file, then to normal system resolution.
async fn connect_upstream( async fn connect_upstream(
domain: &str, domain: &str,
config: &Arc<rustls::ClientConfig>, config: &Arc<rustls::ClientConfig>,
) -> Result<tokio_rustls::client::TlsStream<TcpStream>, String> { ) -> Result<tokio_rustls::client::TlsStream<TcpStream>, String> {
let connector = tokio_rustls::TlsConnector::from(config.clone()); let connector = tokio_rustls::TlsConnector::from(config.clone());
let tcp = TcpStream::connect(format!("{domain}:443"))
// Try to resolve the real IP, bypassing /etc/hosts
let addr = resolve_upstream(domain).await;
let tcp = TcpStream::connect(addr)
.await .await
.map_err(|e| format!("Connect to upstream {domain}: {e}"))?; .map_err(|e| format!("Connect to upstream {domain}: {e}"))?;
let server_name = rustls::pki_types::ServerName::try_from(domain.to_string()) let server_name = rustls::pki_types::ServerName::try_from(domain.to_string())
@@ -283,6 +401,35 @@ async fn handle_http_over_tls(
.map_err(|e| format!("TLS connect to upstream {domain}: {e}")) .map_err(|e| format!("TLS connect to upstream {domain}: {e}"))
} }
/// Resolve upstream IP bypassing /etc/hosts.
async fn resolve_upstream(domain: &str) -> String {
// 1. Try dig @8.8.8.8 (bypasses /etc/hosts)
if let Ok(output) = tokio::process::Command::new("dig")
.args(["+short", &format!("@8.8.8.8"), domain])
.output()
.await
{
let out = String::from_utf8_lossy(&output.stdout);
if let Some(ip) = out.lines().find(|l| l.parse::<std::net::Ipv4Addr>().is_ok()) {
return format!("{ip}:443");
}
}
// 2. Try cached IPs file (written by dns-redirect.sh install)
if let Ok(contents) = tokio::fs::read_to_string("/tmp/antigravity-mitm-real-ips").await {
for line in contents.lines() {
if let Some((d, ip)) = line.split_once('=') {
if d == domain {
return format!("{ip}:443");
}
}
}
}
// 3. Fallback to normal resolution (may hit /etc/hosts)
format!("{domain}:443")
}
// Keep-alive loop: handle multiple requests on this connection // Keep-alive loop: handle multiple requests on this connection
loop { loop {
// ── Read the HTTP request from the client ───────────────────────── // ── Read the HTTP request from the client ─────────────────────────