feat: initial commit — antigravity proxy with MITM, standalone LS, and snapshot tooling

This commit is contained in:
Nikketryhard
2026-02-14 02:24:35 -06:00
commit d5e7f09225
30 changed files with 9980 additions and 0 deletions

218
src/mitm/ca.rs Normal file
View File

@@ -0,0 +1,218 @@
//! Certificate Authority for MITM proxy.
//!
//! Generates a self-signed root CA at first run and caches it to disk.
//! Dynamically generates per-domain leaf certificates signed by this CA.
use rcgen::{
BasicConstraints, CertificateParams, DistinguishedName, DnType, ExtendedKeyUsagePurpose,
IsCa, KeyPair, KeyUsagePurpose, SanType,
};
use rustls::pki_types::{CertificateDer, PrivateKeyDer, PrivatePkcs8KeyDer};
use std::collections::HashMap;
use std::path::{Path, PathBuf};
use std::sync::Arc;
use tokio::sync::RwLock;
use tracing::info;
/// MITM Certificate Authority.
///
/// Bundles the root CA material used to mint per-domain leaf certificates
/// together with a cache of the TLS server configurations built from them.
pub struct MitmCa {
/// Root CA certificate (DER-encoded for rustls).
/// Appended after the leaf certificate in every chain handed to rustls.
ca_cert_der: CertificateDer<'static>,
/// Root CA private key.
/// Passed to rcgen as the signing key for every generated leaf certificate.
ca_key: KeyPair,
/// Signed root CA cert (needed by rcgen to sign leaf certs).
ca_signed: rcgen::Certificate,
/// Cache of per-domain TLS configs.
/// Keyed by the domain string given to `server_config_for_domain`.
domain_cache: Arc<RwLock<HashMap<String, Arc<rustls::ServerConfig>>>>,
/// Path to the CA PEM file (for SSL_CERT_FILE combined bundle).
pub ca_pem_path: PathBuf,
}
impl MitmCa {
/// Load or generate the MITM CA.
///
/// The CA cert/key are stored at:
/// `<data_dir>/mitm-ca.pem` (cert, for NODE_EXTRA_CA_CERTS)
/// `<data_dir>/mitm-ca.key` (private key)
pub fn load_or_generate(data_dir: &Path) -> Result<Self, String> {
let cert_path = data_dir.join("mitm-ca.pem");
let key_path = data_dir.join("mitm-ca.key");
if cert_path.exists() && key_path.exists() {
info!("Loading existing MITM CA from {}", cert_path.display());
let cert_pem = std::fs::read_to_string(&cert_path)
.map_err(|e| format!("Failed to read CA cert: {e}"))?;
let key_pem = std::fs::read_to_string(&key_path)
.map_err(|e| format!("Failed to read CA key: {e}"))?;
let ca_key = KeyPair::from_pem(&key_pem)
.map_err(|e| format!("Failed to parse CA key: {e}"))?;
// Re-create params and self-sign to get the rcgen Certificate object
// (needed for signing leaf certs — rcgen 0.13 doesn't have from_ca_cert_pem).
// The re-signed cert will have a different serial/notBefore, but that's fine
// because we only use it for the rcgen signing API, NOT for the on-disk PEM.
let params = Self::ca_params();
let ca_signed = params.self_signed(&ca_key)
.map_err(|e| format!("Failed to self-sign CA: {e}"))?;
// Use the ORIGINAL on-disk PEM cert for DER — this is what the LS trusts
// (via the combined CA bundle built by the wrapper script). Writing the
// re-signed cert back would desync the LS's trust anchor.
let ca_cert_der = Self::pem_to_der(&cert_pem)
.unwrap_or_else(|| CertificateDer::from(ca_signed.der().to_vec()));
Ok(Self {
ca_cert_der,
ca_key,
ca_signed,
domain_cache: Arc::new(RwLock::new(HashMap::new())),
ca_pem_path: cert_path,
})
} else {
info!("Generating new MITM CA at {}", cert_path.display());
// Ensure data dir exists
std::fs::create_dir_all(data_dir)
.map_err(|e| format!("Failed to create data dir: {e}"))?;
let ca_key = KeyPair::generate()
.map_err(|e| format!("Failed to generate CA key: {e}"))?;
let params = Self::ca_params();
let ca_signed = params.self_signed(&ca_key)
.map_err(|e| format!("Failed to self-sign CA: {e}"))?;
// Write cert and key to disk
std::fs::write(&cert_path, ca_signed.pem())
.map_err(|e| format!("Failed to write CA cert: {e}"))?;
std::fs::write(&key_path, ca_key.serialize_pem())
.map_err(|e| format!("Failed to write CA key: {e}"))?;
#[cfg(unix)]
{
use std::os::unix::fs::PermissionsExt;
let _ = std::fs::set_permissions(&key_path, std::fs::Permissions::from_mode(0o600));
}
let ca_cert_der = CertificateDer::from(ca_signed.der().to_vec());
Ok(Self {
ca_cert_der,
ca_key,
ca_signed,
domain_cache: Arc::new(RwLock::new(HashMap::new())),
ca_pem_path: cert_path,
})
}
}
/// Assemble the parameters describing the root CA certificate.
///
/// Shared by the generate and load paths so both produce a certificate with
/// the same subject, CA constraints and key usages. The validity window
/// starts at the moment of the call and lasts ten years.
fn ca_params() -> CertificateParams {
    let issued_at = time::OffsetDateTime::now_utc();

    let mut subject = DistinguishedName::new();
    subject.push(DnType::CommonName, "Antigravity MITM CA");
    subject.push(DnType::OrganizationName, "Antigravity Proxy");

    let mut ca = CertificateParams::default();
    ca.distinguished_name = subject;
    ca.is_ca = IsCa::Ca(BasicConstraints::Unconstrained);
    ca.key_usages = vec![KeyUsagePurpose::KeyCertSign, KeyUsagePurpose::CrlSign];
    // Valid for 10 years
    ca.not_before = issued_at;
    ca.not_after = issued_at + time::Duration::days(3650);
    ca
}
/// Decode the first certificate in a PEM document into DER bytes.
///
/// Collects the base64 lines that follow a `BEGIN CERTIFICATE` marker and
/// stops at the first `END CERTIFICATE` marker. Returns `None` when no
/// payload was collected or when the payload is not valid base64.
fn pem_to_der(pem: &str) -> Option<CertificateDer<'static>> {
    use base64::Engine;

    let mut encoded = String::new();
    let mut inside = false;
    for raw in pem.lines() {
        match (
            raw.contains("BEGIN CERTIFICATE"),
            raw.contains("END CERTIFICATE"),
        ) {
            // Opening marker: start collecting from the next line.
            (true, _) => inside = true,
            // Closing marker: only the first certificate is of interest.
            (false, true) => break,
            (false, false) if inside => encoded.push_str(raw.trim()),
            _ => {}
        }
    }

    if encoded.is_empty() {
        return None;
    }
    base64::engine::general_purpose::STANDARD
        .decode(&encoded)
        .ok()
        .map(CertificateDer::from)
}
/// Return the TLS `ServerConfig` used to impersonate `domain`.
///
/// A cached configuration is returned when one exists. Otherwise a leaf
/// certificate for `domain` (valid one year, signed by this CA) is
/// generated, wrapped in a `ServerConfig` advertising `h2` and `http/1.1`
/// via ALPN, stored in the cache and returned.
///
/// # Errors
///
/// Returns a message when `domain` is not a valid DNS name, or when key
/// generation, signing or `ServerConfig` construction fails.
pub async fn server_config_for_domain(&self, domain: &str) -> Result<Arc<rustls::ServerConfig>, String> {
    // Fast path: already minted for this domain.
    if let Some(cached) = self.domain_cache.read().await.get(domain) {
        return Ok(Arc::clone(cached));
    }

    // Describe the leaf certificate.
    let mut subject = DistinguishedName::new();
    subject.push(DnType::CommonName, domain);
    let mut leaf_params = CertificateParams::default();
    leaf_params.distinguished_name = subject;
    leaf_params.subject_alt_names = vec![SanType::DnsName(domain.try_into().map_err(|e| format!("Invalid domain: {e}"))?)];
    leaf_params.extended_key_usages = vec![ExtendedKeyUsagePurpose::ServerAuth];
    leaf_params.key_usages = vec![
        KeyUsagePurpose::DigitalSignature,
        KeyUsagePurpose::KeyEncipherment,
    ];
    // Valid for 1 year
    let issued_at = time::OffsetDateTime::now_utc();
    leaf_params.not_before = issued_at;
    leaf_params.not_after = issued_at + time::Duration::days(365);

    // Mint and sign it with the CA.
    let leaf_key = KeyPair::generate()
        .map_err(|e| format!("Failed to generate leaf key: {e}"))?;
    let leaf_cert = leaf_params
        .signed_by(&leaf_key, &self.ca_signed, &self.ca_key)
        .map_err(|e| format!("Failed to sign leaf cert for {domain}: {e}"))?;

    // Chain is leaf first, then the CA certificate.
    let chain = vec![
        CertificateDer::from(leaf_cert.der().to_vec()),
        self.ca_cert_der.clone(),
    ];
    let private_key = PrivateKeyDer::Pkcs8(PrivatePkcs8KeyDer::from(leaf_key.serialize_der()));
    let mut tls = rustls::ServerConfig::builder()
        .with_no_client_auth()
        .with_single_cert(chain, private_key)
        .map_err(|e| format!("Failed to build ServerConfig for {domain}: {e}"))?;
    // Advertise both h2 and http/1.1 so gRPC clients can negotiate HTTP/2
    tls.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];

    let tls = Arc::new(tls);
    self.domain_cache
        .write()
        .await
        .insert(domain.to_string(), Arc::clone(&tls));
    Ok(tls)
}
}

512
src/mitm/h2_handler.rs Normal file
View File

@@ -0,0 +1,512 @@
//! HTTP/2 handler for gRPC traffic interception.
//!
//! When the LS negotiates HTTP/2 via ALPN (which all gRPC connections do),
//! this module handles the bidirectional HTTP/2 connection:
//! 1. Accepts HTTP/2 frames from the client (LS)
//! 2. Connects to the real upstream via TLS + HTTP/2 (single connection reused)
//! 3. Forwards each request stream to upstream
//! 4. For non-streaming: buffers response, extracts usage, forwards
//! 5. For streaming: forwards response body chunks in real-time, tees to a
//! side buffer for usage extraction after stream completes
//!
//! ## Streaming vs Non-streaming
//!
//! gRPC has both unary (non-streaming) and server-streaming RPCs.
//! The LS uses server-streaming for methods like `StreamGenerateContent`.
//! We MUST forward streaming responses immediately — buffering would break
//! the LS's perception of real-time generation.
//!
//! For usage extraction: ModelUsageStats is typically in the LAST message
//! of a streaming response, so we tee the data and parse after stream ends.
use crate::mitm::proto::parse_grpc_response_for_usage;
use crate::mitm::store::{ApiUsage, MitmStore};
use bytes::Bytes;
use http_body_util::{BodyExt, Full, StreamBody};
use hyper::body::{Frame, Incoming};
use hyper::server::conn::http2::Builder as H2ServerBuilder;
use hyper::service::service_fn;
use hyper::{Request, Response};
use hyper_util::rt::TokioExecutor;
use hyper_util::rt::TokioIo;
use std::sync::Arc;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::net::TcpStream;
use tokio::sync::Mutex;
use tracing::{debug, info, trace, warn};
/// A lazily-initialized, shared HTTP/2 connection to the upstream server.
///
/// gRPC multiplexes many requests over a single HTTP/2 connection.
/// We mirror this by maintaining a single upstream connection per domain.
struct UpstreamPool {
domain: String,
tls_config: Arc<rustls::ClientConfig>,
sender: Mutex<Option<hyper::client::conn::http2::SendRequest<Full<Bytes>>>>,
}
impl UpstreamPool {
fn new(domain: String, tls_config: Arc<rustls::ClientConfig>) -> Self {
Self {
domain,
tls_config,
sender: Mutex::new(None),
}
}
/// Get or create the upstream HTTP/2 sender.
async fn get_sender(
&self,
) -> Result<hyper::client::conn::http2::SendRequest<Full<Bytes>>, String> {
let mut guard = self.sender.lock().await;
// Check if existing sender is still usable
if let Some(ref sender) = *guard {
if !sender.is_closed() {
return Ok(sender.clone());
}
debug!(domain = %self.domain, "MITM H2: upstream connection closed, reconnecting");
}
// Create new connection
let sender = self.connect().await?;
*guard = Some(sender.clone());
Ok(sender)
}
async fn connect(
&self,
) -> Result<hyper::client::conn::http2::SendRequest<Full<Bytes>>, String> {
let upstream_tcp = TcpStream::connect(format!("{}:443", self.domain))
.await
.map_err(|e| format!("upstream TCP connect to {} failed: {e}", self.domain))?;
let connector = tokio_rustls::TlsConnector::from(self.tls_config.clone());
let server_name = rustls::pki_types::ServerName::try_from(self.domain.clone())
.map_err(|e| format!("invalid domain {}: {e}", self.domain))?;
let upstream_tls = connector
.connect(server_name, upstream_tcp)
.await
.map_err(|e| format!("upstream TLS to {} failed: {e}", self.domain))?;
let upstream_io = TokioIo::new(upstream_tls);
let (sender, conn) =
hyper::client::conn::http2::Builder::new(TokioExecutor::new())
.handshake(upstream_io)
.await
.map_err(|e| format!("upstream h2 handshake to {} failed: {e}", self.domain))?;
let domain = self.domain.clone();
tokio::spawn(async move {
if let Err(e) = conn.await {
debug!(domain = %domain, error = %e, "MITM H2: upstream connection driver ended");
}
});
info!(domain = %self.domain, "MITM H2: established upstream HTTP/2 connection");
Ok(sender)
}
}
/// gRPC methods that carry ModelUsageStats in their responses.
///
/// Entries are matched as substrings of the request path (`path.contains`),
/// so a short name such as `GenerateContent` also matches the longer
/// `StreamGenerateContent` path.
const USAGE_METHODS: &[&str] = &[
// Unary methods
"GenerateContent",
"AsyncGenerateContent",
"GenerateChat",
"GenerateCode",
"CompleteCode",
"InternalAtomicAgenticChat",
"Predict",
"DirectPredict",
// Streaming methods
"StreamGenerateContent",
"StreamAsyncGenerateContent",
"StreamGenerateChat",
];
/// Serve one HTTP/2 connection from the LS after TLS termination.
///
/// Builds a client TLS configuration from the platform's native root
/// certificates (ALPN restricted to `h2`), creates one shared upstream
/// pool for `domain`, and runs hyper's HTTP/2 server on `tls_stream`,
/// dispatching every request to `handle_h2_request`.
///
/// Always returns `Ok(())`: a connection-level error is only logged,
/// because errors are expected when the peer closes the connection.
pub async fn handle_h2_connection<S>(
    tls_stream: S,
    domain: String,
    store: MitmStore,
) -> Result<(), String>
where
    S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
{
    info!(domain = %domain, "MITM H2: handling HTTP/2 connection");

    // Build TLS config for upstream connections
    let loaded = rustls_native_certs::load_native_certs();
    let mut trusted_roots = rustls::RootCertStore::empty();
    for root in loaded.certs {
        // Best effort: a certificate the store rejects is simply skipped.
        let _ = trusted_roots.add(root);
    }
    let upstream_tls = {
        let mut cfg = rustls::ClientConfig::builder()
            .with_root_certificates(trusted_roots)
            .with_no_client_auth();
        cfg.alpn_protocols = vec![b"h2".to_vec()];
        Arc::new(cfg)
    };

    // Shared upstream connection pool (single connection, multiplexed)
    let pool = Arc::new(UpstreamPool::new(domain.clone(), upstream_tls));
    let shared_domain = Arc::new(domain);

    let service = service_fn(move |req: Request<Incoming>| {
        let domain = shared_domain.clone();
        let store = store.clone();
        let pool = pool.clone();
        async move { handle_h2_request(req, &domain, store, pool).await }
    });

    let outcome = H2ServerBuilder::new(TokioExecutor::new())
        .serve_connection(TokioIo::new(tls_stream), service)
        .await;

    if let Err(e) = outcome {
        // Connection errors are expected on clean close
        debug!(error = %e, "MITM H2: connection ended");
    } else {
        debug!("MITM H2: connection closed cleanly");
    }
    Ok(())
}
/// Response body type — either buffered or streaming.
///
/// `Left` holds a fully buffered body (`Full<Bytes>`); `Right` holds a body
/// whose frames are pulled from a tokio mpsc channel, which lets a separate
/// task feed frames to the client as they arrive.
type BoxBody = http_body_util::Either<
Full<Bytes>,
StreamBody<tokio_stream::wrappers::ReceiverStream<Result<Frame<Bytes>, hyper::Error>>>,
>;
/// Handle a single HTTP/2 request: forward to upstream, capture usage.
///
/// For streaming responses, forwards chunks in real-time while teeing
/// data to a side buffer for post-stream usage extraction.
async fn handle_h2_request(
req: Request<Incoming>,
domain: &str,
store: MitmStore,
pool: Arc<UpstreamPool>,
) -> Result<Response<BoxBody>, hyper::Error> {
let method = req.method().clone();
let uri = req.uri().clone();
let path = uri.path().to_string();
// Identify gRPC method
let is_grpc = req
.headers()
.get("content-type")
.and_then(|v| v.to_str().ok())
.map(|ct| ct.starts_with("application/grpc"))
.unwrap_or(false);
// Check if this method carries usage data
let is_usage_method = is_grpc
&& USAGE_METHODS.iter().any(|m| path.contains(m));
// Check if this is a streaming method
let is_streaming = is_grpc
&& (path.contains("Stream") || path.contains("stream"));
debug!(
domain,
%method,
path = %path,
grpc = is_grpc,
usage_method = is_usage_method,
streaming = is_streaming,
"MITM H2: forwarding request"
);
// Collect request body (we need it for cascade ID extraction)
let (parts, body) = req.into_parts();
let request_body = match body.collect().await {
Ok(collected) => collected.to_bytes(),
Err(e) => {
warn!(error = %e, "MITM H2: failed to collect request body");
Bytes::new()
}
};
// Get upstream sender from pool
let mut upstream_sender = match pool.get_sender().await {
Ok(s) => s,
Err(e) => {
warn!(error = %e, domain, "MITM H2: upstream connect failed");
let resp = Response::builder()
.status(502)
.body(http_body_util::Either::Left(Full::new(
Bytes::from(format!("upstream connect failed: {e}")),
)))
.unwrap();
return Ok(resp);
}
};
// Build the upstream request with proper authority
let upstream_uri = http::Uri::builder()
.scheme("https")
.authority(domain)
.path_and_query(
uri.path_and_query()
.map(|pq| pq.as_str())
.unwrap_or("/"),
)
.build()
.unwrap_or(uri);
let mut upstream_req = Request::builder()
.method(parts.method)
.uri(upstream_uri);
// Copy headers, skip hop-by-hop
for (name, value) in &parts.headers {
let n = name.as_str();
if n == "host" || n == "connection" || n == "transfer-encoding" {
continue;
}
upstream_req = upstream_req.header(name, value);
}
let upstream_req = match upstream_req.body(Full::new(request_body.clone())) {
Ok(r) => r,
Err(e) => {
let resp = Response::builder()
.status(502)
.body(http_body_util::Either::Left(Full::new(
Bytes::from(format!("build request failed: {e}")),
)))
.unwrap();
return Ok(resp);
}
};
// Send to upstream
let upstream_resp = match upstream_sender.send_request(upstream_req).await {
Ok(r) => r,
Err(e) => {
warn!(error = %e, domain, path = %path, "MITM H2: upstream request failed");
let resp = Response::builder()
.status(502)
.body(http_body_util::Either::Left(Full::new(
Bytes::from(format!("upstream request failed: {e}")),
)))
.unwrap();
return Ok(resp);
}
};
let (resp_parts, resp_body) = upstream_resp.into_parts();
let status = resp_parts.status;
// ──────────────────────────────────────────────────────────────────
// Streaming path: forward chunks immediately, tee for usage parsing
// ──────────────────────────────────────────────────────────────────
if is_streaming && status.is_success() {
let should_track_usage = is_usage_method;
let (tx, rx) = tokio::sync::mpsc::channel::<Result<Frame<Bytes>, hyper::Error>>(32);
let store_clone = store.clone();
let path_clone = path.clone();
let request_body_clone = request_body.clone();
// Spawn a task to forward body chunks and tee for usage extraction
tokio::spawn(async move {
let mut tee_buffer = if should_track_usage { Some(Vec::new()) } else { None };
let mut body = resp_body;
loop {
match body.frame().await {
Some(Ok(frame)) => {
if let (Some(ref mut buf), Some(data)) = (&mut tee_buffer, frame.data_ref()) {
buf.extend_from_slice(data);
}
if tx.send(Ok(frame)).await.is_err() {
break; // client disconnected
}
}
Some(Err(e)) => {
warn!(error = %e, path = %path_clone, "MITM H2: streaming error");
let _ = tx.send(Err(e)).await;
break;
}
None => break, // stream ended
}
}
// Stream completed — parse the tee buffer for usage
if let Some(tee_buffer) = tee_buffer {
if !tee_buffer.is_empty() {
if let Some(grpc_usage) = parse_grpc_response_for_usage(&tee_buffer) {
let usage = ApiUsage {
input_tokens: grpc_usage.input_tokens,
output_tokens: grpc_usage.output_tokens,
thinking_output_tokens: grpc_usage.thinking_output_tokens,
response_output_tokens: grpc_usage.response_output_tokens,
cache_creation_input_tokens: grpc_usage.cache_write_tokens,
cache_read_input_tokens: grpc_usage.cache_read_tokens,
model: grpc_usage.model,
api_provider: grpc_usage.api_provider,
grpc_method: Some(path_clone.clone()),
stop_reason: None,
total_cost_usd: None,
captured_at: std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_default()
.as_secs(),
};
let cascade_hint = extract_cascade_from_grpc_request(&request_body_clone);
store_clone.record_usage(cascade_hint.as_deref(), usage).await;
}
}
}
});
let stream = tokio_stream::wrappers::ReceiverStream::new(rx);
let stream_body = StreamBody::new(stream);
let mut client_resp = Response::builder().status(resp_parts.status);
for (name, value) in &resp_parts.headers {
client_resp = client_resp.header(name, value);
}
let client_resp = client_resp
.body(http_body_util::Either::Right(stream_body))
.unwrap_or_else(|_| {
Response::builder()
.status(500)
.body(http_body_util::Either::Left(Full::new(Bytes::from(
"internal error",
))))
.unwrap()
});
return Ok(client_resp);
}
// ──────────────────────────────────────────────────────────────────
// Non-streaming path: buffer full response, extract usage, forward
// ──────────────────────────────────────────────────────────────────
let response_body = match resp_body.collect().await {
Ok(collected) => collected.to_bytes(),
Err(e) => {
warn!(error = %e, "MITM H2: failed to collect response body");
Bytes::new()
}
};
trace!(
domain,
path = %path,
status = %status,
body_len = response_body.len(),
"MITM H2: got upstream response"
);
// Extract usage data from usage-carrying gRPC methods
if is_usage_method && !response_body.is_empty() && status.is_success() {
if let Some(grpc_usage) = parse_grpc_response_for_usage(&response_body) {
let usage = ApiUsage {
input_tokens: grpc_usage.input_tokens,
output_tokens: grpc_usage.output_tokens,
thinking_output_tokens: grpc_usage.thinking_output_tokens,
response_output_tokens: grpc_usage.response_output_tokens,
cache_creation_input_tokens: grpc_usage.cache_write_tokens,
cache_read_input_tokens: grpc_usage.cache_read_tokens,
model: grpc_usage.model,
api_provider: grpc_usage.api_provider,
grpc_method: Some(path.clone()),
stop_reason: None,
total_cost_usd: None,
captured_at: std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_default()
.as_secs(),
};
let cascade_hint = extract_cascade_from_grpc_request(&request_body);
store.record_usage(cascade_hint.as_deref(), usage).await;
}
}
// Build response for the client
let mut client_resp = Response::builder().status(resp_parts.status);
for (name, value) in &resp_parts.headers {
client_resp = client_resp.header(name, value);
}
let client_resp = client_resp
.body(http_body_util::Either::Left(Full::new(response_body)))
.unwrap_or_else(|_| {
Response::builder()
.status(500)
.body(http_body_util::Either::Left(Full::new(Bytes::from(
"internal error",
))))
.unwrap()
});
Ok(client_resp)
}
/// Try to extract a cascade ID from a gRPC request body.
///
/// Looks for UUID-formatted strings in the protobuf fields.
fn extract_cascade_from_grpc_request(body: &[u8]) -> Option<String> {
use crate::mitm::proto::{decode_proto, extract_grpc_messages};
let messages = extract_grpc_messages(body);
for msg in &messages {
let fields = decode_proto(msg);
for field in &fields {
if let Some(id) = extract_uuid_from_field(field) {
return Some(id);
}
}
}
None
}
/// Return the first UUID-formatted string held by `field`.
///
/// A bytes value is checked directly (it must be valid UTF-8 and pass
/// `is_uuid`); a nested message is searched depth-first in field order.
/// Every other value kind yields `None`.
fn extract_uuid_from_field(field: &crate::mitm::proto::ProtoField) -> Option<String> {
    use crate::mitm::proto::ProtoValue;

    match &field.value {
        ProtoValue::Bytes(raw) => std::str::from_utf8(raw)
            .ok()
            .filter(|text| is_uuid(text))
            .map(str::to_owned),
        ProtoValue::Message(children) => children.iter().find_map(extract_uuid_from_field),
        _ => None,
    }
}
/// Report whether `s` is a UUID in canonical textual form.
///
/// The canonical form is 36 ASCII characters: hexadecimal digits (either
/// case) in groups of 8-4-4-4-12 separated by hyphens, i.e. hyphens at byte
/// offsets 8, 13, 18 and 23 and nowhere else. Checking the positions, not
/// just the hyphen count, rejects look-alikes such as four leading hyphens
/// followed by 32 hex digits.
fn is_uuid(s: &str) -> bool {
    let bytes = s.as_bytes();
    bytes.len() == 36
        && bytes.iter().enumerate().all(|(index, &byte)| match index {
            8 | 13 | 18 | 23 => byte == b'-',
            _ => byte.is_ascii_hexdigit(),
        })
}

271
src/mitm/intercept.rs Normal file
View File

@@ -0,0 +1,271 @@
//! API response interceptor: parses Anthropic/Google API responses to extract usage data.
//!
//! Handles both streaming (SSE) and non-streaming (JSON) responses.
use super::store::ApiUsage;
use serde_json::Value;
use tracing::{debug, trace};
/// Parse a complete (non-streaming) Anthropic Messages API response body.
///
/// Response format:
/// ```json
/// {
/// "id": "msg_...",
/// "type": "message",
/// "model": "claude-sonnet-4-20250514",
/// "usage": {
/// "input_tokens": 1234,
/// "output_tokens": 567,
/// "cache_creation_input_tokens": 0,
/// "cache_read_input_tokens": 890
/// },
/// "stop_reason": "end_turn"
/// }
/// ```
pub fn parse_non_streaming_response(body: &[u8]) -> Option<ApiUsage> {
let json: Value = serde_json::from_slice(body).ok()?;
extract_usage_from_message(&json)
}
/// Parse SSE events from a streaming Anthropic response body chunk.
///
/// Events of interest:
/// - `message_start` — contains `message.usage.input_tokens` + cache tokens
/// - `message_delta` — contains `usage.output_tokens`
/// - `message_stop`  — marks end (no usage data)
///
/// Nothing is returned: every `data` line whose payload parses as JSON is
/// fed to `accumulator`, which carries the running totals across chunks.
/// Lines that are not `data` fields, the `[DONE]` sentinel and payloads
/// that are not valid JSON are skipped.
///
/// Per the SSE format the single space after `data:` is optional, so both
/// `data: {...}` and `data:{...}` are accepted.
///
/// NOTE(review): a line split across two chunks is not reassembled here;
/// the caller presumably passes line-aligned chunks — confirm.
pub fn parse_streaming_chunk(chunk: &str, accumulator: &mut StreamingAccumulator) {
    for line in chunk.lines() {
        let data = match line.strip_prefix("data:") {
            // Drop at most one leading space, as the SSE format prescribes.
            Some(rest) => rest.strip_prefix(' ').unwrap_or(rest),
            None => continue,
        };
        if data.trim() == "[DONE]" {
            continue;
        }
        if let Ok(event) = serde_json::from_str::<Value>(data) {
            accumulator.process_event(&event);
        }
    }
}
/// Running usage totals collected from the SSE events of one streamed
/// response.
#[derive(Debug, Default)]
pub struct StreamingAccumulator {
    /// Input tokens reported by `message_start`.
    pub input_tokens: u64,
    /// Latest output token count reported by `message_delta`.
    pub output_tokens: u64,
    /// Cache-creation input tokens reported by `message_start`.
    pub cache_creation_input_tokens: u64,
    /// Cache-read input tokens reported by `message_start`.
    pub cache_read_input_tokens: u64,
    /// Model name reported by `message_start`, if any.
    pub model: Option<String>,
    /// Stop reason reported by `message_delta`, if any.
    pub stop_reason: Option<String>,
    /// Set once `message_stop` has been seen.
    pub is_complete: bool,
}

impl StreamingAccumulator {
    /// Create an empty accumulator (all counters zero, nothing captured).
    pub fn new() -> Self {
        Self::default()
    }

    /// Fold one decoded SSE event into the totals, dispatching on its
    /// `type` field.
    pub fn process_event(&mut self, event: &Value) {
        let event_type = event["type"].as_str().unwrap_or("");
        match event_type {
            "message_start" => {
                // message_start contains the initial usage (input tokens + cache)
                let message = event.get("message");
                if let Some(usage) = message.and_then(|m| m.get("usage")) {
                    let count = |key: &str| usage[key].as_u64().unwrap_or(0);
                    self.input_tokens = count("input_tokens");
                    self.cache_creation_input_tokens = count("cache_creation_input_tokens");
                    self.cache_read_input_tokens = count("cache_read_input_tokens");
                }
                if let Some(name) = message.and_then(|m| m["model"].as_str()) {
                    self.model = Some(name.to_string());
                }
                trace!(
                    input = self.input_tokens,
                    cache_read = self.cache_read_input_tokens,
                    cache_create = self.cache_creation_input_tokens,
                    "SSE message_start: captured input usage"
                );
            }
            "message_delta" => {
                // message_delta contains the output usage; keep the previous
                // count when this event does not carry a number.
                if let Some(reported) = event
                    .get("usage")
                    .and_then(|usage| usage["output_tokens"].as_u64())
                {
                    self.output_tokens = reported;
                }
                if let Some(reason) = event["delta"]["stop_reason"].as_str() {
                    self.stop_reason = Some(reason.to_string());
                }
                trace!(output = self.output_tokens, "SSE message_delta: updated output tokens");
            }
            "message_stop" => {
                self.is_complete = true;
                debug!(
                    input = self.input_tokens,
                    output = self.output_tokens,
                    cache_read = self.cache_read_input_tokens,
                    model = ?self.model,
                    "SSE message_stop: stream complete"
                );
            }
            "content_block_start" | "content_block_delta" | "content_block_stop" | "ping" => {
                // Content events — no usage data, just pass through
            }
            _ => {
                trace!(event_type, "SSE: unknown event type");
            }
        }
    }

    /// Consume the accumulator and produce an `ApiUsage` stamped with the
    /// current time and the `anthropic` provider.
    pub fn into_usage(self) -> ApiUsage {
        let Self {
            input_tokens,
            output_tokens,
            cache_creation_input_tokens,
            cache_read_input_tokens,
            model,
            stop_reason,
            is_complete: _,
        } = self;
        let captured_at = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap_or_default()
            .as_secs();

        ApiUsage {
            input_tokens,
            output_tokens,
            cache_creation_input_tokens,
            cache_read_input_tokens,
            thinking_output_tokens: 0,
            response_output_tokens: 0,
            total_cost_usd: None,
            model,
            stop_reason,
            api_provider: Some("anthropic".to_string()),
            grpc_method: None,
            captured_at,
        }
    }
}
/// Build an `ApiUsage` from a complete Message JSON object.
///
/// Returns `None` when the object has no `usage` member. Counters that are
/// absent or not unsigned integers are recorded as zero.
fn extract_usage_from_message(msg: &Value) -> Option<ApiUsage> {
    let usage = msg.get("usage")?;
    let count = |key: &str| usage[key].as_u64().unwrap_or(0);
    let text = |key: &str| msg[key].as_str().map(|s| s.to_string());
    let captured_at = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .unwrap_or_default()
        .as_secs();

    Some(ApiUsage {
        input_tokens: count("input_tokens"),
        output_tokens: count("output_tokens"),
        cache_creation_input_tokens: count("cache_creation_input_tokens"),
        cache_read_input_tokens: count("cache_read_input_tokens"),
        thinking_output_tokens: 0,
        response_output_tokens: 0,
        total_cost_usd: None,
        model: text("model"),
        stop_reason: text("stop_reason"),
        api_provider: Some("anthropic".to_string()),
        grpc_method: None,
        captured_at,
    })
}
/// Try to identify a cascade ID from a JSON request body.
///
/// Two places are checked, in order:
/// 1. `metadata.user_id`, returned verbatim when it is a string;
/// 2. the `system` prompt (a string, or an array of blocks whose `text`
///    members are joined with spaces): the whitespace-separated token that
///    follows the first token containing `workspace_id` is returned.
///
/// Returns `None` when the body is not JSON or neither source yields a value.
pub fn extract_cascade_hint(request_body: &[u8]) -> Option<String> {
    let json: Value = serde_json::from_slice(request_body).ok()?;

    // Check for metadata field (some API configurations include it).
    // The LS often sets user_id to the cascadeId.
    if let Some(user_id) = json
        .get("metadata")
        .and_then(|metadata| metadata["user_id"].as_str())
    {
        return Some(user_id.to_string());
    }

    // Check system prompt for cascade/workspace markers
    let system_text = match json.get("system")? {
        Value::String(text) => text.clone(),
        // Array of content blocks
        Value::Array(blocks) => blocks
            .iter()
            .filter_map(|block| block["text"].as_str())
            .collect::<Vec<_>>()
            .join(" "),
        _ => return None,
    };

    // Look for workspace_id and take the value that follows it.
    system_text
        .find("workspace_id")
        .and_then(|pos| system_text[pos..].split_whitespace().nth(1))
        .map(|value| value.to_string())
}
// Unit tests for the JSON (non-streaming) parser and the SSE accumulator.
#[cfg(test)]
mod tests {
use super::*;
// A complete Messages response yields every usage counter and the model name.
#[test]
fn test_parse_non_streaming() {
let body = r#"{
"id": "msg_123",
"type": "message",
"model": "claude-sonnet-4-20250514",
"usage": {
"input_tokens": 100,
"output_tokens": 50,
"cache_creation_input_tokens": 10,
"cache_read_input_tokens": 30
},
"stop_reason": "end_turn"
}"#;
let usage = parse_non_streaming_response(body.as_bytes()).unwrap();
assert_eq!(usage.input_tokens, 100);
assert_eq!(usage.output_tokens, 50);
assert_eq!(usage.cache_creation_input_tokens, 10);
assert_eq!(usage.cache_read_input_tokens, 30);
assert_eq!(usage.model.as_deref(), Some("claude-sonnet-4-20250514"));
}
// Feeds message_start, message_delta and message_stop in order and checks
// that input, output and completion state are captured at each step.
#[test]
fn test_streaming_accumulator() {
let mut acc = StreamingAccumulator::new();
// message_start
let start = serde_json::json!({
"type": "message_start",
"message": {
"model": "claude-sonnet-4-20250514",
"usage": {
"input_tokens": 200,
"cache_creation_input_tokens": 5,
"cache_read_input_tokens": 50
}
}
});
acc.process_event(&start);
assert_eq!(acc.input_tokens, 200);
assert_eq!(acc.cache_read_input_tokens, 50);
// message_delta
let delta = serde_json::json!({
"type": "message_delta",
"delta": { "stop_reason": "end_turn" },
"usage": { "output_tokens": 75 }
});
acc.process_event(&delta);
assert_eq!(acc.output_tokens, 75);
// message_stop
let stop = serde_json::json!({ "type": "message_stop" });
acc.process_event(&stop);
assert!(acc.is_complete);
// Converting to ApiUsage must carry the accumulated counters over.
let usage = acc.into_usage();
assert_eq!(usage.input_tokens, 200);
assert_eq!(usage.output_tokens, 75);
}
}

19
src/mitm/mod.rs Normal file
View File

@@ -0,0 +1,19 @@
//! MITM proxy module: intercepts LS ↔ Google/Anthropic API traffic.
//!
//! The LS (Go binary with BoringCrypto) respects `HTTPS_PROXY` and `SSL_CERT_FILE`.
//! By setting these env vars via the wrapper script, we route all outbound HTTPS
//! traffic through our local MITM proxy, which:
//!
//! 1. Terminates TLS using dynamically-generated per-domain certificates
//! 2. Detects protocol: HTTP/1.1 (REST) or HTTP/2 (gRPC)
//! 3. For HTTP/1.1: parses JSON/SSE responses (Anthropic format)
//! 4. For HTTP/2: decodes gRPC protobuf responses (Google format)
//! 5. Captures token usage data (input, output, thinking, cache)
//! 6. Forwards everything transparently to real upstream servers
pub mod ca;
pub mod h2_handler;
pub mod intercept;
pub mod proto;
pub mod proxy;
pub mod store;

584
src/mitm/proto.rs Normal file
View File

@@ -0,0 +1,584 @@
//! Raw protobuf decoder for extracting ModelUsageStats from gRPC responses.
//!
//! We don't have the .proto schema, so we decode protobuf messages generically
//! and search for usage-like structures by matching field patterns.
//!
//! gRPC wire format:
//! - 1 byte: compression flag (0 = uncompressed, 1 = compressed)
//! - 4 bytes: message length (big-endian u32)
//! - N bytes: protobuf message
//!
//! Protobuf wire format:
//! - Each field: (field_number << 3 | wire_type) as varint, then value
//! - Wire type 0: varint
//! - Wire type 1: 64-bit fixed
//! - Wire type 2: length-delimited (string, bytes, embedded message)
//! - Wire type 5: 32-bit fixed
//!
//! ## ModelUsageStats schema (reverse-engineered from LS binary):
//!
//! ```protobuf
//! message ModelUsageStats {
//! Model model = 1; // enum (varint)
//! uint64 input_tokens = 2;
//! uint64 output_tokens = 3;
//! uint64 cache_write_tokens = 4;
//! uint64 cache_read_tokens = 5;
//! APIProvider api_provider = 6; // enum (varint)
//! string message_id = 7;
//! map<string,string> response_header = 8; // repeated message
//! uint64 thinking_output_tokens = 9;
//! uint64 response_output_tokens = 10;
//! string response_id = 11;
//! }
//! ```
use flate2::read::GzDecoder;
use std::io::Read;
use tracing::{debug, trace, warn};
/// The decoded value of a protobuf field, one variant per wire type the
/// decoder understands.
#[derive(Debug, Clone)]
pub enum ProtoValue {
/// Wire type 0: a varint.
Varint(u64),
/// Wire type 1: a 64-bit fixed value, read as little-endian.
// NOTE(review): the allow presumably silences the unused-payload lint
// because nothing reads this value yet — confirm.
#[allow(dead_code)]
Fixed64(u64),
/// Wire type 5: a 32-bit fixed value.
// NOTE(review): same lint allowance as `Fixed64` — confirm it is still needed.
#[allow(dead_code)]
Fixed32(u32),
/// Wire type 2: length-delimited payload kept as raw bytes (string or bytes).
Bytes(Vec<u8>),
/// Nested message (parsed recursively)
Message(Vec<ProtoField>),
}
/// A single protobuf field with its number and value.
#[derive(Debug, Clone)]
pub struct ProtoField {
/// Field number taken from the tag (`tag >> 3`).
pub number: u32,
/// Decoded payload of the field.
pub value: ProtoValue,
}
/// Extracted usage data from a gRPC response.
///
/// Field names mirror the `ModelUsageStats` schema described in this
/// module's header. All counters default to zero and all optional fields to
/// `None` via `Default`.
#[derive(Debug, Default)]
pub struct GrpcUsage {
/// Input token count.
pub input_tokens: u64,
/// Output token count.
pub output_tokens: u64,
/// Output tokens attributed to thinking.
pub thinking_output_tokens: u64,
/// Output tokens attributed to the response.
pub response_output_tokens: u64,
/// Tokens read from cache.
pub cache_read_tokens: u64,
/// Tokens written to cache.
pub cache_write_tokens: u64,
/// Model identifier.
// NOTE(review): the schema lists `model` as an enum; presumably it is
// rendered to text elsewhere — confirm.
pub model: Option<String>,
/// API provider identifier.
// NOTE(review): also an enum in the schema — confirm how it is rendered.
pub api_provider: Option<String>,
/// Message ID, when present.
pub message_id: Option<String>,
/// Response ID, when present.
pub response_id: Option<String>,
}
/// Extract gRPC message frames from a buffer.
///
/// A gRPC message is:
///   [1 byte compressed flag] [4 bytes length BE] [N bytes protobuf]
///
/// Multiple messages can be concatenated in a single buffer.
/// If compressed flag is 1, the message is gzip-decompressed; a frame that
/// fails to decompress is logged and skipped. Any other flag value is
/// treated as uncompressed.
///
/// Parsing stops at the first incomplete frame (fewer than five header
/// bytes left, or a declared length larger than the remaining data); the
/// frames decoded up to that point are returned. The bounds check compares
/// the declared length against the remaining slice instead of adding it to
/// an offset, so a hostile length cannot overflow `usize` on 32-bit targets.
pub fn extract_grpc_messages(data: &[u8]) -> Vec<Vec<u8>> {
    const HEADER_LEN: usize = 5;

    let mut messages = Vec::new();
    let mut remaining = data;
    while remaining.len() >= HEADER_LEN {
        let (header, body) = remaining.split_at(HEADER_LEN);
        let compressed = header[0];
        let length = u32::from_be_bytes([header[1], header[2], header[3], header[4]]) as usize;
        if length > body.len() {
            break; // truncated frame
        }
        let (payload, rest) = body.split_at(length);

        if compressed == 1 {
            // gzip-compressed frame
            let mut decoder = GzDecoder::new(payload);
            let mut decompressed = Vec::new();
            match decoder.read_to_end(&mut decompressed) {
                Ok(_) => messages.push(decompressed),
                Err(e) => {
                    warn!(error = %e, "Proto: failed to decompress gRPC frame");
                }
            }
        } else {
            messages.push(payload.to_vec());
        }
        remaining = rest;
    }
    messages
}
/// Decode a protobuf message into a list of fields.
///
/// This is a best-effort, schema-less decoder for wire types 0, 1, 2 and 5.
/// Decoding stops at the first malformed or unsupported record; every field
/// decoded before that point is still returned.
///
/// A length-delimited (wire type 2) payload is reported as
/// [`ProtoValue::Message`] only when it decodes cleanly **to its last byte**
/// and passes `looks_like_valid_message`; otherwise it is kept as
/// [`ProtoValue::Bytes`]. Requiring a complete parse keeps ordinary strings
/// (ids, header values, ...) whose first bytes merely resemble a field tag from
/// being mistaken for nested messages.
pub fn decode_proto(data: &[u8]) -> Vec<ProtoField> {
    decode_proto_at_depth(data, 0).0
}

/// Deepest level at which a length-delimited payload is still speculatively
/// parsed as a nested message. Payloads below this depth are kept as raw
/// bytes, which bounds both recursion depth and re-parsing work on hostile or
/// garbage input.
const MAX_PROTO_NESTING_DEPTH: usize = 100;

/// Decode `data` and report whether it was consumed completely.
///
/// Returns the fields decoded so far plus `true` when the whole buffer parsed
/// without hitting a malformed or unsupported record.
fn decode_proto_at_depth(data: &[u8], depth: usize) -> (Vec<ProtoField>, bool) {
    let mut fields = Vec::new();
    let mut offset = 0;

    while offset < data.len() {
        // Read tag (varint)
        let (tag, tag_len) = match read_varint(&data[offset..]) {
            Some(v) => v,
            None => return (fields, false),
        };
        offset += tag_len;

        // A field number that does not fit in `u32` (or is zero) is invalid;
        // a plain `as u32` cast would silently truncate it into a bogus
        // but plausible-looking number.
        let field_number = match u32::try_from(tag >> 3) {
            Ok(number) if number != 0 => number,
            _ => return (fields, false),
        };

        let value = match tag & 0x07 {
            0 => {
                // Varint
                let (val, val_len) = match read_varint(&data[offset..]) {
                    Some(v) => v,
                    None => return (fields, false),
                };
                offset += val_len;
                ProtoValue::Varint(val)
            }
            1 => {
                // 64-bit fixed, little-endian
                let raw = match data
                    .get(offset..offset + 8)
                    .and_then(|bytes| <[u8; 8]>::try_from(bytes).ok())
                {
                    Some(raw) => raw,
                    None => return (fields, false),
                };
                offset += 8;
                ProtoValue::Fixed64(u64::from_le_bytes(raw))
            }
            2 => {
                // Length-delimited
                let (len, len_bytes) = match read_varint(&data[offset..]) {
                    Some(v) => v,
                    None => return (fields, false),
                };
                offset += len_bytes;

                // Validate the declared length against what is left rather
                // than computing `offset + len`: the length is an arbitrary
                // u64 from the wire and the sum can overflow, which used to
                // end in an out-of-range slice panic.
                let remaining = data.len() - offset;
                let len = match usize::try_from(len) {
                    Ok(len) if len <= remaining => len,
                    _ => return (fields, false),
                };
                let payload = &data[offset..offset + len];
                offset += len;

                // Try to parse as a nested message
                let (nested, complete) = if depth < MAX_PROTO_NESTING_DEPTH {
                    decode_proto_at_depth(payload, depth + 1)
                } else {
                    (Vec::new(), false)
                };
                if complete && looks_like_valid_message(&nested, payload.len()) {
                    ProtoValue::Message(nested)
                } else {
                    ProtoValue::Bytes(payload.to_vec())
                }
            }
            5 => {
                // 32-bit fixed, little-endian
                let raw = match data
                    .get(offset..offset + 4)
                    .and_then(|bytes| <[u8; 4]>::try_from(bytes).ok())
                {
                    Some(raw) => raw,
                    None => return (fields, false),
                };
                offset += 4;
                ProtoValue::Fixed32(u32::from_le_bytes(raw))
            }
            // Unknown wire type — stop parsing
            _ => return (fields, false),
        };

        fields.push(ProtoField {
            number: field_number,
            value,
        });
    }

    (fields, true)
}
/// Heuristic filter applied to a speculative nested-message decode.
///
/// `fields` is what came out of decoding a length-delimited payload and
/// `original_len` is the size of that payload in bytes. The decode is
/// rejected when:
///
/// * nothing was decoded,
/// * any field number is 10000 or above, or
/// * a payload longer than 100 bytes produced exactly one field that is
///   neither bytes nor a nested message (most likely a long string whose
///   prefix happened to look like a tag).
fn looks_like_valid_message(fields: &[ProtoField], original_len: usize) -> bool {
    match fields {
        [] => false,
        _ if fields.iter().any(|field| field.number >= 10000) => false,
        [only] if original_len > 100 => {
            matches!(only.value, ProtoValue::Bytes(_) | ProtoValue::Message(_))
        }
        _ => true,
    }
}
/// Read a base-128 varint from the front of `data`.
///
/// Returns `(value, bytes_consumed)`, or `None` when `data` ends before a
/// terminating byte (high bit clear) is seen or when no terminator appears
/// within the first ten bytes.
pub fn read_varint(data: &[u8]) -> Option<(u64, usize)> {
    /// A 64-bit value never needs more than ten 7-bit groups.
    const MAX_VARINT_LEN: usize = 10;

    let mut value = 0u64;
    for (index, byte) in data.iter().copied().take(MAX_VARINT_LEN).enumerate() {
        // Groups arrive least-significant first.
        value |= u64::from(byte & 0x7F) << (7 * index);
        if byte & 0x80 == 0 {
            return Some((value, index + 1));
        }
    }
    None
}
/// Search a decoded protobuf message tree for a `ModelUsageStats`-shaped message.
///
/// The current level is checked first; if it does not qualify, nested
/// messages are searched depth-first in field order and the first hit wins.
///
/// Field numbers looked for (reverse-engineered `ModelUsageStats` schema):
///
/// field 1: model (enum/varint)
/// field 2: input_tokens (uint64)
/// field 3: output_tokens (uint64)
/// field 4: cache_write_tokens (uint64)
/// field 5: cache_read_tokens (uint64)
/// field 6: api_provider (enum/varint)
/// field 7: message_id (string)
/// field 8: response_header (map, repeated message)
/// field 9: thinking_output_tokens (uint64)
/// field 10: response_output_tokens (uint64)
/// field 11: response_id (string)
pub fn extract_usage_from_proto(fields: &[ProtoField]) -> Option<GrpcUsage> {
    try_extract_usage(fields).or_else(|| {
        fields.iter().find_map(|field| match &field.value {
            ProtoValue::Message(children) => extract_usage_from_proto(children),
            _ => None,
        })
    })
}
/// Try to interpret exactly this set of fields as a `ModelUsageStats` message.
///
/// The fields must pass three gates before anything is extracted:
///
/// 1. at least two varint fields are present;
/// 2. at least two of them are non-zero and no larger than 10,000,000;
/// 3. either there is a strong signal (varints at both field 2 and field 3,
///    or a UTF-8 bytes field containing a model-like name), or there are at
///    least three varint fields.
///
/// Values are then read from the known field numbers. The result is discarded
/// (`None`) when both `input_tokens` and `output_tokens` end up zero.
fn try_extract_usage(fields: &[ProtoField]) -> Option<GrpcUsage> {
    /// Upper bound for a varint to be accepted as a token count.
    const MAX_PLAUSIBLE_TOKENS: u64 = 10_000_000;
    /// Substrings that mark a string field as naming a model.
    const MODEL_MARKERS: [&str; 7] = [
        "claude", "gemini", "gpt", "sonnet", "opus", "flash", "pro",
    ];

    // Gate 1: enough varint fields to be a candidate at all.
    let varint_values: Vec<u64> = fields
        .iter()
        .filter_map(|field| match field.value {
            ProtoValue::Varint(v) => Some(v),
            _ => None,
        })
        .collect();
    if varint_values.len() < 2 {
        return None;
    }

    // Gate 2: at least two of them look like real, non-zero token counts.
    let plausible_nonzero = varint_values
        .iter()
        .filter(|&&v| v > 0 && v <= MAX_PLAUSIBLE_TOKENS)
        .count();
    if plausible_nonzero < 2 {
        return None;
    }

    // Gate 3: a strong signal, or failing that, more varint fields.
    let has_varint_at = |number: u32| {
        fields
            .iter()
            .any(|field| field.number == number && matches!(field.value, ProtoValue::Varint(_)))
    };
    let has_model_string = fields.iter().any(|field| match &field.value {
        ProtoValue::Bytes(raw) => std::str::from_utf8(raw).map_or(false, |text| {
            text.starts_with("models/") || MODEL_MARKERS.iter().any(|marker| text.contains(*marker))
        }),
        _ => false,
    });
    let strong_signal = (has_varint_at(2) && has_varint_at(3)) || has_model_string;
    if !strong_signal && varint_values.len() < 3 {
        return None;
    }

    // Read the values off the known field numbers.
    let mut usage = GrpcUsage::default();
    for field in fields {
        match &field.value {
            ProtoValue::Varint(raw) => {
                let value = *raw;
                match field.number {
                    // The two enum fields are taken as-is; the token-range
                    // filter below does not apply to them.
                    1 => usage.model = Some(format!("model_enum_{value}")),
                    6 => {
                        usage.api_provider = Some(match value {
                            0 => "unknown".to_string(),
                            1 => "google".to_string(),
                            2 => "anthropic".to_string(),
                            _ => format!("provider_{value}"),
                        });
                    }
                    // Anything else must be in token range to be recorded.
                    _ if value > MAX_PLAUSIBLE_TOKENS => {}
                    2 => usage.input_tokens = value,
                    3 => usage.output_tokens = value,
                    4 => usage.cache_write_tokens = value,
                    5 => usage.cache_read_tokens = value,
                    9 => usage.thinking_output_tokens = value,
                    10 => usage.response_output_tokens = value,
                    _ => {}
                }
            }
            ProtoValue::Bytes(raw) => {
                if let Ok(text) = std::str::from_utf8(raw) {
                    match field.number {
                        7 => usage.message_id = Some(text.to_string()),
                        11 => usage.response_id = Some(text.to_string()),
                        _ => {}
                    }
                }
            }
            _ => {}
        }
    }

    // A usage record with neither input nor output tokens is not useful.
    if usage.input_tokens == 0 && usage.output_tokens == 0 {
        return None;
    }

    debug!(
        input = usage.input_tokens,
        output = usage.output_tokens,
        thinking = usage.thinking_output_tokens,
        response = usage.response_output_tokens,
        cache_read = usage.cache_read_tokens,
        cache_write = usage.cache_write_tokens,
        model = ?usage.model,
        api_provider = ?usage.api_provider,
        "Proto: extracted ModelUsageStats from protobuf"
    );

    Some(usage)
}
/// Parse a gRPC response body for usage data.
///
/// The body is split into gRPC frames (compressed frames are inflated), each
/// frame is decoded as protobuf, and the frames are searched from last to
/// first because usage normally arrives in the final message. Returns the
/// first usage record found, or `None` if no frame contains one.
pub fn parse_grpc_response_for_usage(body: &[u8]) -> Option<GrpcUsage> {
    let frames = extract_grpc_messages(body);
    trace!(count = frames.len(), "Proto: extracted gRPC messages");

    frames
        .iter()
        .rev()
        .find_map(|frame| extract_usage_from_proto(&decode_proto(frame)))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The canonical two-field sample: field 1 = varint 150, field 2 = varint 66.
    const TWO_VARINTS: [u8; 5] = [0x08, 0x96, 0x01, 0x10, 0x42];

    /// Append `value` to `out` as a base-128 varint.
    fn push_varint(out: &mut Vec<u8>, mut value: u64) {
        while value >= 0x80 {
            out.push((value as u8) | 0x80);
            value >>= 7;
        }
        out.push(value as u8);
    }

    /// Append a wire-type-0 (varint) field to `out`.
    fn push_varint_field(out: &mut Vec<u8>, field_num: u32, value: u64) {
        // Wire type 0 contributes nothing to the low three tag bits.
        push_varint(out, u64::from(field_num) << 3);
        push_varint(out, value);
    }

    /// Wrap `body` in a gRPC frame header with the given compression flag.
    fn grpc_frame(flag: u8, body: &[u8]) -> Vec<u8> {
        let mut frame = vec![flag];
        frame.extend_from_slice(&(body.len() as u32).to_be_bytes());
        frame.extend_from_slice(body);
        frame
    }

    #[test]
    fn test_read_varint() {
        let cases = [
            (&[0x00][..], (0u64, 1usize)),
            (&[0x01][..], (1, 1)),
            (&[0x96, 0x01][..], (150, 2)),
            (&[0xAC, 0x02][..], (300, 2)),
        ];
        for (bytes, expected) in cases {
            assert_eq!(read_varint(bytes), Some(expected));
        }
    }

    #[test]
    fn test_extract_grpc_messages_uncompressed() {
        // Flag 0 = not compressed, followed by a 5-byte payload.
        let messages = extract_grpc_messages(&grpc_frame(0, &TWO_VARINTS));
        assert_eq!(messages.len(), 1);
        assert_eq!(messages[0].len(), 5);
    }

    #[test]
    fn test_extract_grpc_messages_compressed() {
        use flate2::write::GzEncoder;
        use flate2::Compression;
        use std::io::Write;

        // gzip the sample payload, then frame it with the compressed flag set.
        let mut gz = GzEncoder::new(Vec::new(), Compression::default());
        gz.write_all(&TWO_VARINTS).unwrap();
        let body = gz.finish().unwrap();

        let messages = extract_grpc_messages(&grpc_frame(1, &body));
        assert_eq!(messages.len(), 1);
        assert_eq!(messages[0], TWO_VARINTS);
    }

    #[test]
    fn test_decode_proto_varints() {
        let fields = decode_proto(&TWO_VARINTS);
        assert_eq!(fields.len(), 2);
        assert_eq!(fields[0].number, 1);
        assert!(matches!(fields[0].value, ProtoValue::Varint(150)));
        assert_eq!(fields[1].number, 2);
        assert!(matches!(fields[1].value, ProtoValue::Varint(66)));
    }

    #[test]
    fn test_decode_proto_with_string() {
        // field 1 = "hello" (length-delimited), field 2 = varint 42
        let mut data = vec![
            0x0A, // (1 << 3) | 2
            0x05, // length 5
        ];
        data.extend_from_slice(b"hello");
        data.extend_from_slice(&[
            0x10, // (2 << 3) | 0
            0x2A, // 42
        ]);

        let fields = decode_proto(&data);
        assert!(fields.len() >= 2);
        assert_eq!(fields[0].number, 1);
    }

    #[test]
    fn test_extract_usage_correct_field_numbers() {
        // A mock ModelUsageStats using the reverse-engineered field numbers.
        let stats: [(u32, u64); 7] = [
            (1, 5),     // model enum
            (2, 1000),  // input_tokens
            (3, 500),   // output_tokens
            (4, 100),   // cache_write_tokens
            (5, 200),   // cache_read_tokens
            (9, 300),   // thinking_output_tokens
            (10, 200),  // response_output_tokens
        ];
        let mut data = Vec::new();
        for (field_num, value) in stats {
            push_varint_field(&mut data, field_num, value);
        }

        let usage = try_extract_usage(&decode_proto(&data)).expect("should extract usage");
        assert_eq!(usage.input_tokens, 1000);
        assert_eq!(usage.output_tokens, 500);
        assert_eq!(usage.cache_write_tokens, 100);
        assert_eq!(usage.cache_read_tokens, 200);
        assert_eq!(usage.thinking_output_tokens, 300);
        assert_eq!(usage.response_output_tokens, 200);
    }
}

591
src/mitm/proxy.rs Normal file
View File

@@ -0,0 +1,591 @@
//! MITM proxy server: handles CONNECT tunnels and TLS interception.
//!
//! Listens on a local port for HTTP CONNECT requests from the LS.
//! For intercepted domains, it terminates TLS with our CA-signed cert,
//! reads/modifies the request, forwards to the real upstream, and captures
//! the response (especially usage data).
//!
//! For non-intercepted domains, it acts as a transparent TCP tunnel.
use super::ca::MitmCa;
use super::intercept::{
extract_cascade_hint, parse_non_streaming_response, parse_streaming_chunk,
StreamingAccumulator,
};
use super::store::MitmStore;
use std::sync::Arc;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::{TcpListener, TcpStream};
use tokio_rustls::TlsAcceptor;
use tracing::{debug, error, info, trace, warn};
/// Domains we intercept (terminate TLS and inspect traffic).
/// This includes exact matches and suffix matches for regional endpoints
/// (e.g., us-central1-aiplatform.googleapis.com).
///
/// `should_intercept_domain` accepts a host that equals an entry, or that
/// ends with `.<entry>` or `-<entry>`.
const INTERCEPT_DOMAINS: &[&str] = &[
"cloudcode-pa.googleapis.com",
"aiplatform.googleapis.com",
"api.anthropic.com",
"speech.googleapis.com",
"modelarmor.googleapis.com",
];
/// Domains we NEVER intercept (transparent tunnel).
///
/// `should_intercept_domain` checks this list first, by exact string
/// comparison only. Hosts that appear on neither list are tunnelled as well,
/// so this list only matters for hosts that would otherwise match
/// `INTERCEPT_DOMAINS`.
const PASSTHROUGH_DOMAINS: &[&str] = &[
"oauth2.googleapis.com",
"accounts.google.com",
"storage.googleapis.com",
"www.googleapis.com",
"firebaseinstallations.googleapis.com",
"crashlyticsreports-pa.googleapis.com",
"play.googleapis.com",
"update.googleapis.com",
"dl.google.com",
];
/// Configuration for the MITM proxy.
///
/// The [`Default`] value listens on port `0` (let the OS pick a free port)
/// with request modification disabled.
#[derive(Debug, Clone, Default)]
pub struct MitmConfig {
    /// Port to listen on (0 = auto-assign).
    pub port: u16,
    /// Whether to enable request modification.
    pub modify_requests: bool,
}
/// Run the MITM proxy server.
///
/// Returns (port, task_handle) — port it's listening on, handle to abort on shutdown.
pub async fn run(
ca: Arc<MitmCa>,
store: MitmStore,
config: MitmConfig,
) -> Result<(u16, tokio::task::JoinHandle<()>), String> {
let listener = TcpListener::bind(format!("127.0.0.1:{}", config.port))
.await
.map_err(|e| format!("MITM bind failed: {e}"))?;
let port = listener
.local_addr()
.map_err(|e| format!("MITM local_addr failed: {e}"))?
.port();
info!(port, "MITM proxy listening");
let modify_requests = config.modify_requests;
let handle = tokio::spawn(async move {
loop {
match listener.accept().await {
Ok((stream, addr)) => {
trace!(?addr, "MITM: new connection");
let ca = ca.clone();
let store = store.clone();
tokio::spawn(async move {
if let Err(e) = handle_connection(stream, ca, store, modify_requests).await {
debug!(error = %e, "MITM connection error");
}
});
}
Err(e) => {
error!(error = %e, "MITM accept error");
}
}
}
});
Ok((port, handle))
}
/// Handle a single incoming connection from the LS.
///
/// Expects the first read to contain an HTTP `CONNECT host:port HTTP/1.1`
/// line. Anything else is answered with `400 Bad Request`. For a valid
/// CONNECT the proxy replies `200 Connection Established` and then either
/// intercepts the tunnel (TLS termination) or relays it untouched, depending
/// on `should_intercept_domain`.
///
/// A missing or unparsable port falls back to 443.
///
/// # Errors
/// Returns `Err` when reading the request or writing the `200` reply fails,
/// or when the chosen tunnel handler fails.
async fn handle_connection(
    mut stream: TcpStream,
    ca: Arc<MitmCa>,
    store: MitmStore,
    modify_requests: bool,
) -> Result<(), String> {
    const BAD_REQUEST: &[u8] = b"HTTP/1.1 400 Bad Request\r\n\r\n";
    const ESTABLISHED: &[u8] = b"HTTP/1.1 200 Connection Established\r\n\r\n";

    // Read the CONNECT request
    let mut head = vec![0u8; 8192];
    let received = stream
        .read(&mut head)
        .await
        .map_err(|e| format!("Read CONNECT: {e}"))?;
    if received == 0 {
        return Ok(());
    }
    let text = String::from_utf8_lossy(&head[..received]);

    // Parse "CONNECT host:port HTTP/1.1": method, target and version must all be present.
    let target = {
        let mut tokens = text.lines().next().unwrap_or("").split_whitespace();
        match (tokens.next(), tokens.next(), tokens.next()) {
            (Some("CONNECT"), Some(target), Some(_version)) => Some(target),
            _ => None,
        }
    };
    let host_port = match target {
        Some(target) => target,
        None => {
            // Not a CONNECT request — return 400
            let _ = stream.write_all(BAD_REQUEST).await;
            return Ok(());
        }
    };

    let (domain, port) = match host_port.rsplit_once(':') {
        Some((host, raw_port)) => (host, raw_port.parse::<u16>().unwrap_or(443)),
        None => (host_port, 443),
    };
    debug!(domain, "MITM: CONNECT request");

    // Decide: intercept or passthrough
    let should_intercept = should_intercept_domain(domain);

    // Send 200 Connection Established
    stream
        .write_all(ESTABLISHED)
        .await
        .map_err(|e| format!("Write 200: {e}"))?;

    if should_intercept {
        handle_intercepted(stream, domain, ca, store, modify_requests).await
    } else {
        handle_passthrough(stream, domain, port).await
    }
}
/// Decide whether traffic to `domain` should be TLS-intercepted.
///
/// * Hosts listed verbatim in `PASSTHROUGH_DOMAINS` are never intercepted.
/// * Otherwise a host is intercepted when it equals an `INTERCEPT_DOMAINS`
///   entry, is a subdomain of one (`<x>.<entry>`), or carries a regional
///   prefix (`<x>-<entry>`).
/// * Everything else is passed through.
fn should_intercept_domain(domain: &str) -> bool {
    if PASSTHROUGH_DOMAINS.contains(&domain) {
        return false;
    }

    INTERCEPT_DOMAINS.iter().any(|&api_host| {
        // What remains in front of the API host must be empty (exact match)
        // or end in the separator of a subdomain / regional prefix.
        domain.strip_suffix(api_host).map_or(false, |prefix| {
            prefix.is_empty() || prefix.ends_with('.') || prefix.ends_with('-')
        })
    })
}
/// Handle an intercepted connection: terminate TLS, then dispatch on ALPN.
///
/// A server config for `domain` is obtained from the CA and used to complete
/// the TLS handshake with the client (LS). The negotiated ALPN protocol
/// selects the handler:
/// - `h2` → HTTP/2 handler (for gRPC traffic to Google APIs)
/// - `http/1.1` or none → HTTP/1.1 handler (for REST/SSE traffic)
///
/// # Errors
/// Returns `Err` when the server config cannot be produced, the handshake
/// fails, or the selected handler fails.
async fn handle_intercepted(
    stream: TcpStream,
    domain: &str,
    ca: Arc<MitmCa>,
    store: MitmStore,
    modify_requests: bool,
) -> Result<(), String> {
    info!(domain, "MITM: intercepting TLS");

    // Get or create server TLS config for this domain, then handshake with the client (LS).
    let acceptor = TlsAcceptor::from(ca.server_config_for_domain(domain).await?);
    let tls_stream = acceptor
        .accept(stream)
        .await
        .map_err(|e| format!("TLS handshake with client failed for {domain}: {e}"))?;

    // Check negotiated ALPN protocol
    let (_, session) = tls_stream.get_ref();
    let alpn = session
        .alpn_protocol()
        .map(|proto| String::from_utf8_lossy(proto).into_owned());
    debug!(domain, alpn = ?alpn, "MITM: TLS handshake successful");

    if alpn.as_deref() == Some("h2") {
        // HTTP/2 — use the hyper-based gRPC handler
        info!(domain, "MITM: routing to HTTP/2 handler (gRPC)");
        super::h2_handler::handle_h2_connection(tls_stream, domain.to_string(), store).await
    } else {
        // HTTP/1.1 or no ALPN — use the existing handler
        debug!(domain, "MITM: routing to HTTP/1.1 handler");
        handle_http_over_tls(tls_stream, domain, store, modify_requests).await
    }
}
/// Handle HTTP traffic over the decrypted TLS connection.
///
/// Loops to handle multiple requests on the same connection (HTTP keep-alive).
/// Reads full request, connects to upstream, forwards request, streams response
/// back to client while capturing usage data.
///
/// # Arguments
/// * `client` - TLS stream to the LS, already past the handshake.
/// * `domain` - Host named in the CONNECT request; used for the upstream dial,
///   the TLS server name and the log fields.
/// * `store` - Receives any usage parsed from the responses.
/// * `_modify_requests` - Not read in this function; request bytes are
///   forwarded unchanged.
///
/// # Errors
/// Returns `Err` when connecting to the upstream fails or when the retried
/// upstream write fails. Client read errors and client EOF end the loop with
/// `Ok(())`.
async fn handle_http_over_tls(
mut client: tokio_rustls::server::TlsStream<TcpStream>,
domain: &str,
store: MitmStore,
_modify_requests: bool,
) -> Result<(), String> {
// Scratch buffer shared by client reads and upstream reads.
let mut tmp = vec![0u8; 32768];
// Build upstream TLS connector once for this connection
// (the native root store is loaded here, i.e. once per intercepted connection;
// certificates the store rejects are skipped silently).
let mut root_store = rustls::RootCertStore::empty();
let native_certs = rustls_native_certs::load_native_certs();
for cert in native_certs.certs {
let _ = root_store.add(cert);
}
let upstream_config = Arc::new(
rustls::ClientConfig::builder()
.with_root_certificates(root_store)
.with_no_client_auth(),
);
// Reusable upstream connection — created lazily, reconnected if stale
let mut upstream: Option<tokio_rustls::client::TlsStream<TcpStream>> = None;
/// Connect (or reconnect) to the real upstream via TLS.
// NOTE(review): always dials port 443; the port from the CONNECT line is not
// passed down to this handler — confirm that is intended.
async fn connect_upstream(
domain: &str,
config: &Arc<rustls::ClientConfig>,
) -> Result<tokio_rustls::client::TlsStream<TcpStream>, String> {
let connector = tokio_rustls::TlsConnector::from(config.clone());
let tcp = TcpStream::connect(format!("{domain}:443"))
.await
.map_err(|e| format!("Connect to upstream {domain}: {e}"))?;
let server_name = rustls::pki_types::ServerName::try_from(domain.to_string())
.map_err(|e| format!("Invalid server name: {e}"))?;
connector
.connect(server_name, tcp)
.await
.map_err(|e| format!("TLS connect to upstream {domain}: {e}"))
}
// Keep-alive loop: handle multiple requests on this connection
loop {
// ── Read the HTTP request from the client ─────────────────────────
let mut request_buf = Vec::with_capacity(1024 * 64);
loop {
let n = match client.read(&mut tmp).await {
Ok(0) => return Ok(()), // Client closed connection cleanly
Ok(n) => n,
Err(e) => {
// Connection reset / broken pipe is normal for keep-alive end
debug!(domain, error = %e, "MITM: client read finished");
return Ok(());
}
};
request_buf.extend_from_slice(&tmp[..n]);
// Check if we have the full request (headers + body)
if has_complete_http_request(&request_buf) {
break;
}
}
// The inner loop only breaks after appending at least one byte, so this
// guard is purely defensive.
if request_buf.is_empty() {
return Ok(());
}
// Parse the HTTP request to find headers and body
let (headers_end, content_length, is_streaming_request) = parse_http_request_meta(&request_buf);
// Try to extract cascade hint from request body
let cascade_hint = if headers_end < request_buf.len() {
extract_cascade_hint(&request_buf[headers_end..])
} else {
None
};
debug!(
domain,
content_length,
streaming = is_streaming_request,
cascade = ?cascade_hint,
"MITM: forwarding request to upstream"
);
// ── Ensure upstream connection is alive ──────────────────────────────
// Lazily connect on first request, or reconnect if the previous connection died
let conn = match upstream.as_mut() {
Some(c) => c,
None => {
let c = connect_upstream(domain, &upstream_config).await?;
upstream.insert(c)
}
};
// Forward the request — if write fails, reconnect and retry once
if let Err(e) = conn.write_all(&request_buf).await {
debug!(domain, error = %e, "MITM: upstream write failed, reconnecting");
let c = connect_upstream(domain, &upstream_config).await?;
let conn = upstream.insert(c);
conn.write_all(&request_buf)
.await
.map_err(|e| format!("Write to upstream (retry): {e}"))?;
}
// `upstream` is always `Some` here: it was either already set or filled by
// one of the `insert` calls above.
let conn = upstream.as_mut().unwrap();
// ── Stream response back to client ──────────────────────────────────
let mut streaming_acc = StreamingAccumulator::new();
let mut is_streaming_response = false;
let mut headers_parsed = false;
// Holds the raw response (headers + body). It is filled until the headers
// are complete, then either dropped (streaming) or kept growing
// (non-streaming, for usage parsing).
let mut non_streaming_buf: Option<Vec<u8>> = None;
// Track if upstream connection is still usable after this response
let mut upstream_ok = true;
// Per-request timeout: 5 minutes (covers large context API calls)
const READ_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(300);
// NOTE(review): this loop ends only on upstream EOF, an upstream read error,
// the read timeout, or a failed write to the client. The response's own
// Content-Length / chunked terminator is never consulted, so on a keep-alive
// upstream a non-streaming body is forwarded to the client only once one of
// those events happens. Confirm whether that is intended.
loop {
let n = match tokio::time::timeout(READ_TIMEOUT, conn.read(&mut tmp)).await {
Ok(Ok(0)) => {
// Upstream closed — connection is no longer reusable
upstream_ok = false;
break;
}
Ok(Ok(n)) => n,
Ok(Err(e)) => {
debug!(domain, error = %e, "MITM: upstream read finished");
upstream_ok = false;
break;
}
Err(_) => {
warn!(domain, "MITM: upstream read timed out after 5 minutes");
upstream_ok = false;
break;
}
};
let chunk = &tmp[..n];
// Check response headers for content-type
if !headers_parsed {
// We need to buffer until we see the end of headers
let buf = non_streaming_buf.get_or_insert_with(|| Vec::with_capacity(1024 * 64));
buf.extend_from_slice(chunk);
if let Some(_hdr_end) = find_headers_end(buf) {
// Use httparse for response header parsing
let mut resp_headers = [httparse::EMPTY_HEADER; 64];
let mut resp = httparse::Response::new(&mut resp_headers);
let hdr_end = match resp.parse(buf) {
Ok(httparse::Status::Complete(n)) => n,
_ => _hdr_end, // Fallback to manual detection
};
// Detect content type and connection handling from parsed headers
for header in resp.headers.iter() {
if header.name.eq_ignore_ascii_case("content-type") {
if let Ok(val) = std::str::from_utf8(header.value) {
if val.contains("text/event-stream") {
is_streaming_response = true;
}
}
}
if header.name.eq_ignore_ascii_case("connection") {
if let Ok(val) = std::str::from_utf8(header.value) {
if val.trim().eq_ignore_ascii_case("close") {
upstream_ok = false;
}
}
}
}
headers_parsed = true;
if is_streaming_response {
// For streaming, parse any SSE data already in the buffer
let body_so_far = String::from_utf8_lossy(&buf[hdr_end..]);
if !body_so_far.is_empty() {
parse_streaming_chunk(&body_so_far, &mut streaming_acc);
}
// Forward the accumulated buffer to client
if let Err(e) = client.write_all(buf).await {
warn!(error = %e, "MITM: write to client failed");
break;
}
// From here on streaming chunks are forwarded as they arrive, so the
// buffered copy is no longer needed.
non_streaming_buf = None;
continue;
}
// Non-streaming: keep buffering the response body for parsing
continue;
}
// Headers still incomplete — read more.
continue;
}
// If streaming, parse SSE events and forward immediately
if is_streaming_response {
let chunk_str = String::from_utf8_lossy(chunk);
parse_streaming_chunk(&chunk_str, &mut streaming_acc);
if let Err(e) = client.write_all(chunk).await {
warn!(error = %e, "MITM: write to client failed (client disconnected?)");
break;
}
} else {
// Non-streaming: keep accumulating to parse usage at the end
if let Some(ref mut buf) = non_streaming_buf {
buf.extend_from_slice(chunk);
}
}
}
// Forward non-streaming response all at once
// (this also covers a response whose headers never completed: whatever was
// buffered is still passed on).
if !is_streaming_response {
if let Some(ref buf) = non_streaming_buf {
if let Err(e) = client.write_all(buf).await {
warn!(error = %e, "MITM: write to client failed");
}
}
}
// Capture usage data
if is_streaming_response {
if streaming_acc.is_complete || streaming_acc.output_tokens > 0 {
let usage = streaming_acc.into_usage();
store.record_usage(cascade_hint.as_deref(), usage).await;
}
} else if let Some(ref buf) = non_streaming_buf {
if let Some(body_start) = find_headers_end(buf) {
let body = &buf[body_start..];
if let Some(usage) = parse_non_streaming_response(body) {
store.record_usage(cascade_hint.as_deref(), usage).await;
}
}
}
// If upstream closed, drop the connection so next iteration reconnects
if !upstream_ok {
upstream = None;
}
} // end keep-alive loop
}
/// Handle a passthrough connection: relay bytes between `client` and
/// `domain:port` without inspecting them.
///
/// The tunnel runs until either side finishes. A copy error is only traced,
/// because it is usually just one side closing the connection.
///
/// # Errors
/// Returns `Err` only when the upstream TCP connection cannot be established.
async fn handle_passthrough(
    mut client: TcpStream,
    domain: &str,
    port: u16,
) -> Result<(), String> {
    trace!(domain, port, "MITM: transparent tunnel");

    let address = format!("{domain}:{port}");
    let mut upstream = TcpStream::connect(&address)
        .await
        .map_err(|e| format!("Connect to {address}: {e}"))?;

    // Bidirectional copy
    let outcome = tokio::io::copy_bidirectional(&mut client, &mut upstream).await;
    match outcome {
        Err(e) => {
            trace!(domain, error = %e, "MITM: tunnel error (likely clean close)");
        }
        Ok((sent, received)) => {
            trace!(
                domain,
                client_to_server = sent,
                server_to_client = received,
                "MITM: tunnel closed"
            );
        }
    }

    Ok(())
}
/// Check if buffer contains a complete HTTP request (headers + full body).
/// Uses `httparse` for zero-copy, case-insensitive header parsing.
///
/// Completion rules, applied to the first header that settles the question:
/// * `Content-Length: n` (parsable) → complete once `n` body bytes are buffered.
/// * `Transfer-Encoding` listing `chunked` → complete once the body ends with
///   the terminating `0\r\n\r\n` chunk.
/// * neither header → no body expected (e.g. GET); complete as soon as the
///   headers are.
///
/// Returns `false` while the headers are incomplete or unparsable, so the
/// caller keeps reading.
fn has_complete_http_request(buf: &[u8]) -> bool {
    let mut headers = [httparse::EMPTY_HEADER; 64];
    let mut req = httparse::Request::new(&mut headers);
    let headers_end = match req.parse(buf) {
        Ok(httparse::Status::Complete(n)) => n,
        _ => return false, // Incomplete or parse error — need more data
    };
    let body = &buf[headers_end..];

    for header in req.headers.iter() {
        // Header values that are not valid UTF-8 cannot settle anything.
        let value = match std::str::from_utf8(header.value) {
            Ok(value) => value.trim(),
            Err(_) => continue,
        };

        if header.name.eq_ignore_ascii_case("content-length") {
            if let Ok(len) = value.parse::<usize>() {
                // Compare against the buffered body instead of computing
                // `headers_end + len`: a huge declared length would overflow
                // that sum and could report the request complete too early.
                return body.len() >= len;
            }
        } else if header.name.eq_ignore_ascii_case("transfer-encoding") {
            // The value is a comma-separated list of codings
            // (e.g. "gzip, chunked"), not necessarily the bare word.
            let chunked = value
                .split(',')
                .any(|coding| coding.trim().eq_ignore_ascii_case("chunked"));
            if chunked {
                return body.len() >= 5 && body.ends_with(b"0\r\n\r\n");
            }
        }
    }

    // No Content-Length or Transfer-Encoding — no body expected (e.g., GET)
    true
}
/// Locate the end of the HTTP header section.
///
/// Returns the offset of the first byte after the first `\r\n\r\n` in `buf`,
/// or `None` when the blank line has not been received yet.
fn find_headers_end(buf: &[u8]) -> Option<usize> {
    const TERMINATOR: &[u8] = b"\r\n\r\n";

    // Try every possible end position, shortest first, so the earliest
    // terminator wins.
    (TERMINATOR.len()..=buf.len()).find(|&end| buf[..end].ends_with(TERMINATOR))
}
/// Parse HTTP request metadata from raw bytes using `httparse`.
///
/// Returns `(headers_end, content_length, is_streaming_request)`:
/// * `headers_end` – offset of the first body byte;
/// * `content_length` – value of the last readable `Content-Length` header,
///   `0` when absent or unparsable;
/// * `is_streaming_request` – whether the body contains `"stream":true` or
///   `"stream": true`.
///
/// When `httparse` cannot parse the request, the header end is located by
/// scanning for the blank line (falling back to the buffer length) and the
/// other two values are `0` and `false`.
fn parse_http_request_meta(buf: &[u8]) -> (usize, usize, bool) {
    let mut header_slots = [httparse::EMPTY_HEADER; 64];
    let mut parsed = httparse::Request::new(&mut header_slots);

    let headers_end = if let Ok(httparse::Status::Complete(end)) = parsed.parse(buf) {
        end
    } else {
        // Fallback if httparse can't parse
        return (find_headers_end(buf).unwrap_or(buf.len()), 0, false);
    };

    // When the header is repeated, the last readable occurrence wins.
    let content_length = parsed
        .headers
        .iter()
        .filter(|header| header.name.eq_ignore_ascii_case("content-length"))
        .filter_map(|header| std::str::from_utf8(header.value).ok())
        .map(|value| value.trim().parse::<usize>().unwrap_or(0))
        .last()
        .unwrap_or(0);

    // Check if request body asks for streaming
    let is_streaming = match buf.get(headers_end..) {
        Some(body) if !body.is_empty() => {
            let body_text = String::from_utf8_lossy(body);
            ["\"stream\":true", "\"stream\": true"]
                .iter()
                .any(|needle| body_text.contains(*needle))
        }
        _ => false,
    };

    (headers_end, content_length, is_streaming)
}

163
src/mitm/store.rs Normal file
View File

@@ -0,0 +1,163 @@
//! Shared store for intercepted API usage data.
//!
//! The MITM proxy writes usage data here; the API handlers read from it.
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::RwLock;
use serde::{Deserialize, Serialize};
use tracing::debug;
/// Token usage from an intercepted API response.
///
/// Covers both Anthropic JSON/SSE responses and Google gRPC protobuf responses.
/// Fields map to the superset of Anthropic's `usage` object and Google's `ModelUsageStats` proto.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct ApiUsage {
/// Input token count reported by the API.
pub input_tokens: u64,
/// Output token count reported by the API.
pub output_tokens: u64,
/// Anthropic: cache_creation_input_tokens / Google: cache_write_tokens
pub cache_creation_input_tokens: u64,
/// Anthropic: cache_read_input_tokens / Google: cache_read_tokens
pub cache_read_input_tokens: u64,
/// Google-specific: thinking/reasoning output tokens (extended thinking)
pub thinking_output_tokens: u64,
/// Google-specific: response output tokens (non-thinking portion)
pub response_output_tokens: u64,
/// Total cost in USD (if provided by the API).
pub total_cost_usd: Option<f64>,
/// The actual model that served the request.
/// When present it is also the key of the per-model breakdown in `MitmStats`.
pub model: Option<String>,
/// Stop reason / finish reason from the API.
pub stop_reason: Option<String>,
/// API provider (e.g. "anthropic", "google")
pub api_provider: Option<String>,
/// gRPC method path (e.g. "/google.internal.cloud.code.v1internal.PredictionService/GenerateContent")
pub grpc_method: Option<String>,
/// Timestamp when this usage was captured.
// NOTE(review): the unit and epoch are not visible in this part of the file —
// confirm (seconds vs. milliseconds) at the place where it is set.
pub captured_at: u64,
}
/// Thread-safe store for intercepted data.
///
/// Keyed by a unique request ID that we can correlate with cascade operations.
/// In practice, we use the cascade ID + a sequence number.
///
/// Cloning is cheap: both fields are `Arc`s, so every clone shares the same
/// underlying maps and counters.
// NOTE(review): `record_usage` keys `latest_usage` by the cascade id alone and
// falls back to the literal key "_latest" when none is known; the "sequence
// number" mentioned above is not visible there — confirm or update the text.
#[derive(Clone)]
pub struct MitmStore {
/// Most recent usage per cascade ID.
latest_usage: Arc<RwLock<HashMap<String, ApiUsage>>>,
/// Global aggregate stats.
stats: Arc<RwLock<MitmStats>>,
}
/// Aggregate statistics across all intercepted traffic.
///
/// Every counter is a running sum updated by `MitmStore::record_usage`.
#[derive(Debug, Clone, Default, Serialize)]
pub struct MitmStats {
/// Number of recorded exchanges (incremented once per `record_usage` call).
pub total_requests: u64,
/// Sum of `ApiUsage::input_tokens`.
pub total_input_tokens: u64,
/// Sum of `ApiUsage::output_tokens`.
pub total_output_tokens: u64,
/// Sum of `ApiUsage::cache_read_input_tokens`.
pub total_cache_read_tokens: u64,
/// Sum of `ApiUsage::cache_creation_input_tokens`.
pub total_cache_creation_tokens: u64,
/// Sum of `ApiUsage::thinking_output_tokens`.
pub total_thinking_output_tokens: u64,
/// Sum of `ApiUsage::response_output_tokens`.
pub total_response_output_tokens: u64,
/// Per-model usage breakdown (model name → stats).
/// Only exchanges whose usage carries a model name are counted here.
pub per_model: HashMap<String, ModelStats>,
}
/// Per-model usage counters.
///
/// One instance per distinct model name inside `MitmStats::per_model`.
/// Unlike the global totals, there is no thinking/response output split here.
#[derive(Debug, Clone, Default, Serialize)]
pub struct ModelStats {
/// Number of usage records attributed to this model.
pub requests: u64,
/// Sum of `ApiUsage::input_tokens` for this model.
pub input_tokens: u64,
/// Sum of `ApiUsage::output_tokens` for this model.
pub output_tokens: u64,
/// Sum of `ApiUsage::cache_read_input_tokens` for this model.
pub cache_read_tokens: u64,
/// Sum of `ApiUsage::cache_creation_input_tokens` for this model.
pub cache_creation_tokens: u64,
}
impl MitmStore {
    /// Create an empty store: no recorded usage and all statistics at zero.
    pub fn new() -> Self {
        Self {
            latest_usage: Arc::new(RwLock::new(HashMap::new())),
            stats: Arc::new(RwLock::new(MitmStats::default())),
        }
    }

    /// Record a completed API exchange with usage data.
    ///
    /// Two things happen, each under its own short-lived write lock:
    ///
    /// 1. The global and (when `usage.model` is set) per-model counters are
    ///    increased. Additions saturate at `u64::MAX` instead of overflowing,
    ///    because the numbers originate from intercepted traffic and must not
    ///    be able to panic (debug) or wrap (release) the proxy.
    /// 2. `usage` becomes the latest entry for `cascade_id`, replacing any
    ///    previous entry under that key. When `cascade_id` is `None` the entry
    ///    is filed under the fixed key `"_latest"`.
    ///
    /// The map is capped at 500 entries. If a *new* key would exceed the cap,
    /// the existing entry with the smallest `captured_at` is evicted first, so
    /// the entry being recorded is never the one thrown away.
    pub async fn record_usage(&self, cascade_id: Option<&str>, usage: ApiUsage) {
        /// Upper bound on retained per-cascade entries (prevents unbounded growth).
        const MAX_ENTRIES: usize = 500;

        /// Add `amount` to `counter`, clamping at `u64::MAX`.
        fn add(counter: &mut u64, amount: u64) {
            *counter = counter.saturating_add(amount);
        }

        debug!(
            input = usage.input_tokens,
            output = usage.output_tokens,
            cache_read = usage.cache_read_input_tokens,
            cache_create = usage.cache_creation_input_tokens,
            thinking = usage.thinking_output_tokens,
            response = usage.response_output_tokens,
            model = ?usage.model,
            provider = ?usage.api_provider,
            grpc = ?usage.grpc_method,
            "MITM captured API usage"
        );

        // Update aggregate stats. The guard is scoped so the stats lock is
        // released before the usage map lock is taken.
        {
            let mut guard = self.stats.write().await;
            // Reborrow once so the individual fields can be borrowed disjointly.
            let stats = &mut *guard;

            add(&mut stats.total_requests, 1);
            add(&mut stats.total_input_tokens, usage.input_tokens);
            add(&mut stats.total_output_tokens, usage.output_tokens);
            add(&mut stats.total_cache_read_tokens, usage.cache_read_input_tokens);
            add(
                &mut stats.total_cache_creation_tokens,
                usage.cache_creation_input_tokens,
            );
            add(
                &mut stats.total_thinking_output_tokens,
                usage.thinking_output_tokens,
            );
            add(
                &mut stats.total_response_output_tokens,
                usage.response_output_tokens,
            );

            // Per-model breakdown (skipped when the model is unknown).
            if let Some(ref model_name) = usage.model {
                let per_model = stats.per_model.entry(model_name.clone()).or_default();
                add(&mut per_model.requests, 1);
                add(&mut per_model.input_tokens, usage.input_tokens);
                add(&mut per_model.output_tokens, usage.output_tokens);
                add(&mut per_model.cache_read_tokens, usage.cache_read_input_tokens);
                add(
                    &mut per_model.cache_creation_tokens,
                    usage.cache_creation_input_tokens,
                );
            }
        }

        // Store latest usage for the cascade (if we can identify it).
        let key = cascade_id.unwrap_or("_latest");
        let mut latest = self.latest_usage.write().await;

        // Make room *before* inserting: overwriting an existing key never
        // grows the map, and evicting first guarantees the fresh entry cannot
        // be selected as the victim even if its `captured_at` is the smallest.
        if !latest.contains_key(key) && latest.len() >= MAX_ENTRIES {
            let oldest_key = latest
                .iter()
                .min_by_key(|(_, entry)| entry.captured_at)
                .map(|(k, _)| k.clone());
            if let Some(oldest) = oldest_key {
                latest.remove(&oldest);
            }
        }

        latest.insert(key.to_string(), usage);
    }

    /// Get the latest usage for a cascade, consuming it (one-shot read).
    ///
    /// Only returns exact cascade_id matches — no cross-cascade fallback.
    /// The `_latest` key is only consumed when the caller explicitly requests it
    /// (i.e., when the MITM couldn't identify the cascade).
    ///
    /// Returns `None` when nothing is stored under `cascade_id`.
    pub async fn take_usage(&self, cascade_id: &str) -> Option<ApiUsage> {
        self.latest_usage.write().await.remove(cascade_id)
    }

    /// Peek at the latest usage without consuming it.
    ///
    /// Returns a clone of the stored entry, or `None` when nothing is stored
    /// under `cascade_id`. Takes only a read lock.
    // Presumably has no in-crate caller yet, hence the lint allowance.
    #[allow(dead_code)]
    pub async fn peek_usage(&self, cascade_id: &str) -> Option<ApiUsage> {
        self.latest_usage.read().await.get(cascade_id).cloned()
    }

    /// Get aggregate stats.
    ///
    /// Returns a cloned snapshot; later `record_usage` calls do not affect it.
    pub async fn stats(&self) -> MitmStats {
        self.stats.read().await.clone()
    }
}