feat: transparent proxy mode with SNI extraction and DNS bypass for upstream
This commit is contained in:
@@ -1,11 +1,13 @@
|
||||
//! MITM proxy server: handles CONNECT tunnels and TLS interception.
|
||||
//! MITM proxy server: handles CONNECT tunnels and transparent TLS interception.
|
||||
//!
|
||||
//! Listens on a local port for HTTP CONNECT requests from the LS.
|
||||
//! For intercepted domains, it terminates TLS with our CA-signed cert,
|
||||
//! reads/modifies the request, forwards to the real upstream, and captures
|
||||
//! the response (especially usage data).
|
||||
//! Supports two modes:
|
||||
//! 1. **HTTP CONNECT** — standard proxy mode, LS sends `CONNECT host:port`
|
||||
//! 2. **Transparent (iptables)** — raw TLS arrives via REDIRECT, SNI extracted from ClientHello
|
||||
//!
|
||||
//! For non-intercepted domains, it acts as a transparent TCP tunnel.
|
||||
//! For intercepted domains, terminates TLS with our CA-signed cert,
|
||||
//! reads/modifies the request, forwards to the real upstream, and captures usage.
|
||||
//!
|
||||
//! For non-intercepted domains, acts as a transparent TCP tunnel.
|
||||
|
||||
use super::ca::MitmCa;
|
||||
use super::intercept::{
|
||||
@@ -104,64 +106,173 @@ pub async fn run(
|
||||
Ok((port, handle))
|
||||
}
|
||||
|
||||
/// Handle a single incoming connection from the LS.
|
||||
/// Handle a single incoming connection.
|
||||
///
|
||||
/// The LS sends an HTTP CONNECT request to establish a tunnel.
|
||||
/// We then decide whether to intercept or passthrough.
|
||||
/// Supports two modes:
|
||||
/// 1. **HTTP CONNECT** (standard proxy) — LS sends `CONNECT host:port HTTP/1.1`
|
||||
/// 2. **Transparent/iptables redirect** — raw TLS ClientHello arrives directly
|
||||
/// (first byte is 0x16). We extract the domain from SNI and intercept.
|
||||
async fn handle_connection(
|
||||
mut stream: TcpStream,
|
||||
ca: Arc<MitmCa>,
|
||||
store: MitmStore,
|
||||
modify_requests: bool,
|
||||
) -> Result<(), String> {
|
||||
// Read the CONNECT request
|
||||
let mut buf = vec![0u8; 8192];
|
||||
// Peek at the first byte to distinguish CONNECT vs raw TLS
|
||||
let mut peek = [0u8; 1];
|
||||
let n = stream
|
||||
.read(&mut buf)
|
||||
.peek(&mut peek)
|
||||
.await
|
||||
.map_err(|e| format!("Read CONNECT: {e}"))?;
|
||||
.map_err(|e| format!("Peek failed: {e}"))?;
|
||||
|
||||
if n == 0 {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let request = String::from_utf8_lossy(&buf[..n]);
|
||||
let first_line = request.lines().next().unwrap_or("");
|
||||
if peek[0] == 0x16 {
|
||||
// TLS ClientHello — transparent/iptables redirect mode.
|
||||
// Peek enough bytes to extract SNI from the ClientHello.
|
||||
let mut hello_buf = vec![0u8; 16384];
|
||||
let n = stream
|
||||
.peek(&mut hello_buf)
|
||||
.await
|
||||
.map_err(|e| format!("Peek ClientHello: {e}"))?;
|
||||
|
||||
// Parse "CONNECT host:port HTTP/1.1"
|
||||
let parts: Vec<&str> = first_line.split_whitespace().collect();
|
||||
if parts.len() < 3 || parts[0] != "CONNECT" {
|
||||
// Not a CONNECT request — return 400
|
||||
let resp = "HTTP/1.1 400 Bad Request\r\n\r\n";
|
||||
let _ = stream.write_all(resp.as_bytes()).await;
|
||||
return Ok(());
|
||||
}
|
||||
let domain = extract_sni(&hello_buf[..n])
|
||||
.unwrap_or_else(|| "unknown".to_string());
|
||||
|
||||
let host_port = parts[1];
|
||||
let (domain, _port) = match host_port.rsplit_once(':') {
|
||||
Some((h, p)) => (h, p.parse::<u16>().unwrap_or(443)),
|
||||
None => (host_port, 443),
|
||||
};
|
||||
info!(domain, "MITM: transparent redirect (iptables)");
|
||||
|
||||
debug!(domain, "MITM: CONNECT request");
|
||||
|
||||
// Decide: intercept or passthrough
|
||||
let should_intercept = should_intercept_domain(domain);
|
||||
|
||||
// Send 200 Connection Established
|
||||
let response = "HTTP/1.1 200 Connection Established\r\n\r\n";
|
||||
stream
|
||||
.write_all(response.as_bytes())
|
||||
.await
|
||||
.map_err(|e| format!("Write 200: {e}"))?;
|
||||
|
||||
if should_intercept {
|
||||
handle_intercepted(stream, domain, ca, store, modify_requests).await
|
||||
let should_intercept = should_intercept_domain(&domain);
|
||||
if should_intercept {
|
||||
handle_intercepted(stream, &domain, ca, store, modify_requests).await
|
||||
} else {
|
||||
// For non-intercepted domains via iptables, we need the original dest.
|
||||
// Since we only have SNI, resolve and passthrough.
|
||||
handle_passthrough(stream, &domain, 443).await
|
||||
}
|
||||
} else {
|
||||
handle_passthrough(stream, domain, _port).await
|
||||
// Standard HTTP CONNECT proxy mode
|
||||
let mut buf = vec![0u8; 8192];
|
||||
let n = stream
|
||||
.read(&mut buf)
|
||||
.await
|
||||
.map_err(|e| format!("Read CONNECT: {e}"))?;
|
||||
|
||||
if n == 0 {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let request = String::from_utf8_lossy(&buf[..n]);
|
||||
let first_line = request.lines().next().unwrap_or("");
|
||||
|
||||
// Parse "CONNECT host:port HTTP/1.1"
|
||||
let parts: Vec<&str> = first_line.split_whitespace().collect();
|
||||
if parts.len() < 3 || parts[0] != "CONNECT" {
|
||||
let resp = "HTTP/1.1 400 Bad Request\r\n\r\n";
|
||||
let _ = stream.write_all(resp.as_bytes()).await;
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let host_port = parts[1];
|
||||
let (domain, _port) = match host_port.rsplit_once(':') {
|
||||
Some((h, p)) => (h, p.parse::<u16>().unwrap_or(443)),
|
||||
None => (host_port, 443),
|
||||
};
|
||||
|
||||
debug!(domain, "MITM: CONNECT request");
|
||||
|
||||
let should_intercept = should_intercept_domain(domain);
|
||||
|
||||
// Send 200 Connection Established
|
||||
let response = "HTTP/1.1 200 Connection Established\r\n\r\n";
|
||||
stream
|
||||
.write_all(response.as_bytes())
|
||||
.await
|
||||
.map_err(|e| format!("Write 200: {e}"))?;
|
||||
|
||||
if should_intercept {
|
||||
handle_intercepted(stream, domain, ca, store, modify_requests).await
|
||||
} else {
|
||||
handle_passthrough(stream, domain, _port).await
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Extract SNI (Server Name Indication) from a TLS ClientHello message.
|
||||
///
|
||||
/// Parses the raw TLS record to find the `server_name` extension (type 0x0000).
|
||||
/// Returns `None` if the SNI can't be found (not TLS, no SNI extension, etc.).
|
||||
fn extract_sni(buf: &[u8]) -> Option<String> {
|
||||
// TLS record: type(1) + version(2) + length(2) + handshake
|
||||
if buf.len() < 5 || buf[0] != 0x16 {
|
||||
return None;
|
||||
}
|
||||
let record_len = u16::from_be_bytes([buf[3], buf[4]]) as usize;
|
||||
let handshake = &buf[5..5 + record_len.min(buf.len() - 5)];
|
||||
|
||||
// Handshake: type(1) + length(3) + body
|
||||
if handshake.is_empty() || handshake[0] != 0x01 {
|
||||
return None; // Not ClientHello
|
||||
}
|
||||
if handshake.len() < 4 {
|
||||
return None;
|
||||
}
|
||||
let hs_len = u32::from_be_bytes([0, handshake[1], handshake[2], handshake[3]]) as usize;
|
||||
let body = &handshake[4..4 + hs_len.min(handshake.len() - 4)];
|
||||
|
||||
// ClientHello: version(2) + random(32) + session_id_len(1) + session_id(var)
|
||||
// + cipher_suites_len(2) + cipher_suites(var)
|
||||
// + compression_len(1) + compression(var)
|
||||
// + extensions_len(2) + extensions(var)
|
||||
if body.len() < 34 {
|
||||
return None;
|
||||
}
|
||||
let mut pos = 34; // skip version + random
|
||||
|
||||
// Session ID
|
||||
if pos >= body.len() { return None; }
|
||||
let sid_len = body[pos] as usize;
|
||||
pos += 1 + sid_len;
|
||||
|
||||
// Cipher suites
|
||||
if pos + 2 > body.len() { return None; }
|
||||
let cs_len = u16::from_be_bytes([body[pos], body[pos + 1]]) as usize;
|
||||
pos += 2 + cs_len;
|
||||
|
||||
// Compression methods
|
||||
if pos >= body.len() { return None; }
|
||||
let cm_len = body[pos] as usize;
|
||||
pos += 1 + cm_len;
|
||||
|
||||
// Extensions
|
||||
if pos + 2 > body.len() { return None; }
|
||||
let ext_len = u16::from_be_bytes([body[pos], body[pos + 1]]) as usize;
|
||||
pos += 2;
|
||||
let ext_end = pos + ext_len.min(body.len() - pos);
|
||||
|
||||
while pos + 4 <= ext_end {
|
||||
let ext_type = u16::from_be_bytes([body[pos], body[pos + 1]]);
|
||||
let ext_data_len = u16::from_be_bytes([body[pos + 2], body[pos + 3]]) as usize;
|
||||
pos += 4;
|
||||
|
||||
if ext_type == 0x0000 {
|
||||
// SNI extension — server_name_list_len(2) + type(1) + name_len(2) + name
|
||||
if ext_data_len >= 5 && pos + ext_data_len <= ext_end {
|
||||
let name_len = u16::from_be_bytes([body[pos + 3], body[pos + 4]]) as usize;
|
||||
if pos + 5 + name_len <= ext_end {
|
||||
return String::from_utf8(body[pos + 5..pos + 5 + name_len].to_vec()).ok();
|
||||
}
|
||||
}
|
||||
return None;
|
||||
}
|
||||
|
||||
pos += ext_data_len;
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
|
||||
/// Check if a domain should be intercepted.
|
||||
fn should_intercept_domain(domain: &str) -> bool {
|
||||
// Never intercept passthrough domains
|
||||
@@ -267,12 +378,19 @@ async fn handle_http_over_tls(
|
||||
let mut upstream: Option<tokio_rustls::client::TlsStream<TcpStream>> = None;
|
||||
|
||||
/// Connect (or reconnect) to the real upstream via TLS.
|
||||
///
|
||||
/// Bypasses /etc/hosts by resolving via direct DNS query (dig @8.8.8.8),
|
||||
/// then falls back to cached IPs file, then to normal system resolution.
|
||||
async fn connect_upstream(
|
||||
domain: &str,
|
||||
config: &Arc<rustls::ClientConfig>,
|
||||
) -> Result<tokio_rustls::client::TlsStream<TcpStream>, String> {
|
||||
let connector = tokio_rustls::TlsConnector::from(config.clone());
|
||||
let tcp = TcpStream::connect(format!("{domain}:443"))
|
||||
|
||||
// Try to resolve the real IP, bypassing /etc/hosts
|
||||
let addr = resolve_upstream(domain).await;
|
||||
|
||||
let tcp = TcpStream::connect(addr)
|
||||
.await
|
||||
.map_err(|e| format!("Connect to upstream {domain}: {e}"))?;
|
||||
let server_name = rustls::pki_types::ServerName::try_from(domain.to_string())
|
||||
@@ -283,6 +401,35 @@ async fn handle_http_over_tls(
|
||||
.map_err(|e| format!("TLS connect to upstream {domain}: {e}"))
|
||||
}
|
||||
|
||||
/// Resolve upstream IP bypassing /etc/hosts.
|
||||
async fn resolve_upstream(domain: &str) -> String {
|
||||
// 1. Try dig @8.8.8.8 (bypasses /etc/hosts)
|
||||
if let Ok(output) = tokio::process::Command::new("dig")
|
||||
.args(["+short", &format!("@8.8.8.8"), domain])
|
||||
.output()
|
||||
.await
|
||||
{
|
||||
let out = String::from_utf8_lossy(&output.stdout);
|
||||
if let Some(ip) = out.lines().find(|l| l.parse::<std::net::Ipv4Addr>().is_ok()) {
|
||||
return format!("{ip}:443");
|
||||
}
|
||||
}
|
||||
|
||||
// 2. Try cached IPs file (written by dns-redirect.sh install)
|
||||
if let Ok(contents) = tokio::fs::read_to_string("/tmp/antigravity-mitm-real-ips").await {
|
||||
for line in contents.lines() {
|
||||
if let Some((d, ip)) = line.split_once('=') {
|
||||
if d == domain {
|
||||
return format!("{ip}:443");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 3. Fallback to normal resolution (may hit /etc/hosts)
|
||||
format!("{domain}:443")
|
||||
}
|
||||
|
||||
// Keep-alive loop: handle multiple requests on this connection
|
||||
loop {
|
||||
// ── Read the HTTP request from the client ─────────────────────────
|
||||
|
||||
Reference in New Issue
Block a user