feat: initial commit — antigravity proxy with MITM, standalone LS, and snapshot tooling
This commit is contained in:
218
src/mitm/ca.rs
Normal file
218
src/mitm/ca.rs
Normal file
@@ -0,0 +1,218 @@
|
||||
//! Certificate Authority for MITM proxy.
|
||||
//!
|
||||
//! Generates a self-signed root CA at first run and caches it to disk.
|
||||
//! Dynamically generates per-domain leaf certificates signed by this CA.
|
||||
|
||||
use rcgen::{
|
||||
BasicConstraints, CertificateParams, DistinguishedName, DnType, ExtendedKeyUsagePurpose,
|
||||
IsCa, KeyPair, KeyUsagePurpose, SanType,
|
||||
};
|
||||
use rustls::pki_types::{CertificateDer, PrivateKeyDer, PrivatePkcs8KeyDer};
|
||||
use std::collections::HashMap;
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::RwLock;
|
||||
use tracing::info;
|
||||
|
||||
/// MITM Certificate Authority.
|
||||
pub struct MitmCa {
|
||||
/// Root CA certificate (DER-encoded for rustls).
|
||||
ca_cert_der: CertificateDer<'static>,
|
||||
/// Root CA private key.
|
||||
ca_key: KeyPair,
|
||||
/// Signed root CA cert (needed by rcgen to sign leaf certs).
|
||||
ca_signed: rcgen::Certificate,
|
||||
/// Cache of per-domain TLS configs.
|
||||
domain_cache: Arc<RwLock<HashMap<String, Arc<rustls::ServerConfig>>>>,
|
||||
/// Path to the CA PEM file (for SSL_CERT_FILE combined bundle).
|
||||
pub ca_pem_path: PathBuf,
|
||||
}
|
||||
|
||||
impl MitmCa {
|
||||
/// Load or generate the MITM CA.
|
||||
///
|
||||
/// The CA cert/key are stored at:
|
||||
/// `<data_dir>/mitm-ca.pem` (cert, for NODE_EXTRA_CA_CERTS)
|
||||
/// `<data_dir>/mitm-ca.key` (private key)
|
||||
pub fn load_or_generate(data_dir: &Path) -> Result<Self, String> {
|
||||
let cert_path = data_dir.join("mitm-ca.pem");
|
||||
let key_path = data_dir.join("mitm-ca.key");
|
||||
|
||||
if cert_path.exists() && key_path.exists() {
|
||||
info!("Loading existing MITM CA from {}", cert_path.display());
|
||||
let cert_pem = std::fs::read_to_string(&cert_path)
|
||||
.map_err(|e| format!("Failed to read CA cert: {e}"))?;
|
||||
let key_pem = std::fs::read_to_string(&key_path)
|
||||
.map_err(|e| format!("Failed to read CA key: {e}"))?;
|
||||
|
||||
let ca_key = KeyPair::from_pem(&key_pem)
|
||||
.map_err(|e| format!("Failed to parse CA key: {e}"))?;
|
||||
|
||||
// Re-create params and self-sign to get the rcgen Certificate object
|
||||
// (needed for signing leaf certs — rcgen 0.13 doesn't have from_ca_cert_pem).
|
||||
// The re-signed cert will have a different serial/notBefore, but that's fine
|
||||
// because we only use it for the rcgen signing API, NOT for the on-disk PEM.
|
||||
let params = Self::ca_params();
|
||||
let ca_signed = params.self_signed(&ca_key)
|
||||
.map_err(|e| format!("Failed to self-sign CA: {e}"))?;
|
||||
|
||||
// Use the ORIGINAL on-disk PEM cert for DER — this is what the LS trusts
|
||||
// (via the combined CA bundle built by the wrapper script). Writing the
|
||||
// re-signed cert back would desync the LS's trust anchor.
|
||||
let ca_cert_der = Self::pem_to_der(&cert_pem)
|
||||
.unwrap_or_else(|| CertificateDer::from(ca_signed.der().to_vec()));
|
||||
|
||||
Ok(Self {
|
||||
ca_cert_der,
|
||||
ca_key,
|
||||
ca_signed,
|
||||
domain_cache: Arc::new(RwLock::new(HashMap::new())),
|
||||
ca_pem_path: cert_path,
|
||||
})
|
||||
} else {
|
||||
info!("Generating new MITM CA at {}", cert_path.display());
|
||||
|
||||
// Ensure data dir exists
|
||||
std::fs::create_dir_all(data_dir)
|
||||
.map_err(|e| format!("Failed to create data dir: {e}"))?;
|
||||
|
||||
let ca_key = KeyPair::generate()
|
||||
.map_err(|e| format!("Failed to generate CA key: {e}"))?;
|
||||
|
||||
let params = Self::ca_params();
|
||||
let ca_signed = params.self_signed(&ca_key)
|
||||
.map_err(|e| format!("Failed to self-sign CA: {e}"))?;
|
||||
|
||||
// Write cert and key to disk
|
||||
std::fs::write(&cert_path, ca_signed.pem())
|
||||
.map_err(|e| format!("Failed to write CA cert: {e}"))?;
|
||||
std::fs::write(&key_path, ca_key.serialize_pem())
|
||||
.map_err(|e| format!("Failed to write CA key: {e}"))?;
|
||||
|
||||
#[cfg(unix)]
|
||||
{
|
||||
use std::os::unix::fs::PermissionsExt;
|
||||
let _ = std::fs::set_permissions(&key_path, std::fs::Permissions::from_mode(0o600));
|
||||
}
|
||||
|
||||
let ca_cert_der = CertificateDer::from(ca_signed.der().to_vec());
|
||||
|
||||
Ok(Self {
|
||||
ca_cert_der,
|
||||
ca_key,
|
||||
ca_signed,
|
||||
domain_cache: Arc::new(RwLock::new(HashMap::new())),
|
||||
ca_pem_path: cert_path,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// Build the CA certificate parameters (reusable for both generate and load).
|
||||
fn ca_params() -> CertificateParams {
|
||||
let mut params = CertificateParams::default();
|
||||
|
||||
let mut dn = DistinguishedName::new();
|
||||
dn.push(DnType::CommonName, "Antigravity MITM CA");
|
||||
dn.push(DnType::OrganizationName, "Antigravity Proxy");
|
||||
params.distinguished_name = dn;
|
||||
|
||||
params.is_ca = IsCa::Ca(BasicConstraints::Unconstrained);
|
||||
params.key_usages = vec![
|
||||
KeyUsagePurpose::KeyCertSign,
|
||||
KeyUsagePurpose::CrlSign,
|
||||
];
|
||||
|
||||
// Valid for 10 years
|
||||
let now = time::OffsetDateTime::now_utc();
|
||||
params.not_before = now;
|
||||
params.not_after = now + time::Duration::days(3650);
|
||||
|
||||
params
|
||||
}
|
||||
|
||||
/// Parse a PEM certificate into a DER-encoded CertificateDer.
|
||||
fn pem_to_der(pem: &str) -> Option<CertificateDer<'static>> {
|
||||
// Extract base64 content between BEGIN/END markers
|
||||
let mut in_cert = false;
|
||||
let mut b64 = String::new();
|
||||
for line in pem.lines() {
|
||||
if line.contains("BEGIN CERTIFICATE") {
|
||||
in_cert = true;
|
||||
continue;
|
||||
}
|
||||
if line.contains("END CERTIFICATE") {
|
||||
break;
|
||||
}
|
||||
if in_cert {
|
||||
b64.push_str(line.trim());
|
||||
}
|
||||
}
|
||||
if b64.is_empty() {
|
||||
return None;
|
||||
}
|
||||
use base64::Engine;
|
||||
let der = base64::engine::general_purpose::STANDARD.decode(&b64).ok()?;
|
||||
Some(CertificateDer::from(der))
|
||||
}
|
||||
|
||||
/// Get or create a TLS ServerConfig for the given domain.
|
||||
pub async fn server_config_for_domain(&self, domain: &str) -> Result<Arc<rustls::ServerConfig>, String> {
|
||||
// Check cache first
|
||||
{
|
||||
let cache = self.domain_cache.read().await;
|
||||
if let Some(config) = cache.get(domain) {
|
||||
return Ok(config.clone());
|
||||
}
|
||||
}
|
||||
|
||||
// Generate leaf cert for this domain
|
||||
let mut params = CertificateParams::default();
|
||||
|
||||
let mut dn = DistinguishedName::new();
|
||||
dn.push(DnType::CommonName, domain);
|
||||
params.distinguished_name = dn;
|
||||
|
||||
params.subject_alt_names = vec![SanType::DnsName(domain.try_into().map_err(|e| format!("Invalid domain: {e}"))?)];
|
||||
params.extended_key_usages = vec![ExtendedKeyUsagePurpose::ServerAuth];
|
||||
params.key_usages = vec![
|
||||
KeyUsagePurpose::DigitalSignature,
|
||||
KeyUsagePurpose::KeyEncipherment,
|
||||
];
|
||||
|
||||
// Valid for 1 year
|
||||
let now = time::OffsetDateTime::now_utc();
|
||||
params.not_before = now;
|
||||
params.not_after = now + time::Duration::days(365);
|
||||
|
||||
let leaf_key = KeyPair::generate()
|
||||
.map_err(|e| format!("Failed to generate leaf key: {e}"))?;
|
||||
|
||||
let leaf_cert = params.signed_by(&leaf_key, &self.ca_signed, &self.ca_key)
|
||||
.map_err(|e| format!("Failed to sign leaf cert for {domain}: {e}"))?;
|
||||
|
||||
// Build rustls ServerConfig
|
||||
let leaf_cert_der = CertificateDer::from(leaf_cert.der().to_vec());
|
||||
let leaf_key_der = PrivateKeyDer::Pkcs8(PrivatePkcs8KeyDer::from(leaf_key.serialize_der()));
|
||||
|
||||
let mut config = rustls::ServerConfig::builder()
|
||||
.with_no_client_auth()
|
||||
.with_single_cert(
|
||||
vec![leaf_cert_der, self.ca_cert_der.clone()],
|
||||
leaf_key_der,
|
||||
)
|
||||
.map_err(|e| format!("Failed to build ServerConfig for {domain}: {e}"))?;
|
||||
|
||||
// Advertise both h2 and http/1.1 so gRPC clients can negotiate HTTP/2
|
||||
config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
|
||||
|
||||
let config = Arc::new(config);
|
||||
|
||||
// Cache it
|
||||
{
|
||||
let mut cache = self.domain_cache.write().await;
|
||||
cache.insert(domain.to_string(), config.clone());
|
||||
}
|
||||
|
||||
Ok(config)
|
||||
}
|
||||
}
|
||||
512
src/mitm/h2_handler.rs
Normal file
512
src/mitm/h2_handler.rs
Normal file
@@ -0,0 +1,512 @@
|
||||
//! HTTP/2 handler for gRPC traffic interception.
|
||||
//!
|
||||
//! When the LS negotiates HTTP/2 via ALPN (which all gRPC connections do),
|
||||
//! this module handles the bidirectional HTTP/2 connection:
|
||||
//! 1. Accepts HTTP/2 frames from the client (LS)
|
||||
//! 2. Connects to the real upstream via TLS + HTTP/2 (single connection reused)
|
||||
//! 3. Forwards each request stream to upstream
|
||||
//! 4. For non-streaming: buffers response, extracts usage, forwards
|
||||
//! 5. For streaming: forwards response body chunks in real-time, tees to a
|
||||
//! side buffer for usage extraction after stream completes
|
||||
//!
|
||||
//! ## Streaming vs Non-streaming
|
||||
//!
|
||||
//! gRPC has both unary (non-streaming) and server-streaming RPCs.
|
||||
//! The LS uses server-streaming for methods like `StreamGenerateContent`.
|
||||
//! We MUST forward streaming responses immediately — buffering would break
|
||||
//! the LS's perception of real-time generation.
|
||||
//!
|
||||
//! For usage extraction: ModelUsageStats is typically in the LAST message
|
||||
//! of a streaming response, so we tee the data and parse after stream ends.
|
||||
|
||||
use crate::mitm::proto::parse_grpc_response_for_usage;
|
||||
use crate::mitm::store::{ApiUsage, MitmStore};
|
||||
|
||||
use bytes::Bytes;
|
||||
use http_body_util::{BodyExt, Full, StreamBody};
|
||||
use hyper::body::{Frame, Incoming};
|
||||
use hyper::server::conn::http2::Builder as H2ServerBuilder;
|
||||
use hyper::service::service_fn;
|
||||
use hyper::{Request, Response};
|
||||
use hyper_util::rt::TokioExecutor;
|
||||
use hyper_util::rt::TokioIo;
|
||||
use std::sync::Arc;
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use tokio::net::TcpStream;
|
||||
use tokio::sync::Mutex;
|
||||
use tracing::{debug, info, trace, warn};
|
||||
|
||||
/// A lazily-initialized, shared HTTP/2 connection to the upstream server.
|
||||
///
|
||||
/// gRPC multiplexes many requests over a single HTTP/2 connection.
|
||||
/// We mirror this by maintaining a single upstream connection per domain.
|
||||
struct UpstreamPool {
|
||||
domain: String,
|
||||
tls_config: Arc<rustls::ClientConfig>,
|
||||
sender: Mutex<Option<hyper::client::conn::http2::SendRequest<Full<Bytes>>>>,
|
||||
}
|
||||
|
||||
impl UpstreamPool {
|
||||
fn new(domain: String, tls_config: Arc<rustls::ClientConfig>) -> Self {
|
||||
Self {
|
||||
domain,
|
||||
tls_config,
|
||||
sender: Mutex::new(None),
|
||||
}
|
||||
}
|
||||
|
||||
/// Get or create the upstream HTTP/2 sender.
|
||||
async fn get_sender(
|
||||
&self,
|
||||
) -> Result<hyper::client::conn::http2::SendRequest<Full<Bytes>>, String> {
|
||||
let mut guard = self.sender.lock().await;
|
||||
|
||||
// Check if existing sender is still usable
|
||||
if let Some(ref sender) = *guard {
|
||||
if !sender.is_closed() {
|
||||
return Ok(sender.clone());
|
||||
}
|
||||
debug!(domain = %self.domain, "MITM H2: upstream connection closed, reconnecting");
|
||||
}
|
||||
|
||||
// Create new connection
|
||||
let sender = self.connect().await?;
|
||||
*guard = Some(sender.clone());
|
||||
Ok(sender)
|
||||
}
|
||||
|
||||
async fn connect(
|
||||
&self,
|
||||
) -> Result<hyper::client::conn::http2::SendRequest<Full<Bytes>>, String> {
|
||||
let upstream_tcp = TcpStream::connect(format!("{}:443", self.domain))
|
||||
.await
|
||||
.map_err(|e| format!("upstream TCP connect to {} failed: {e}", self.domain))?;
|
||||
|
||||
let connector = tokio_rustls::TlsConnector::from(self.tls_config.clone());
|
||||
let server_name = rustls::pki_types::ServerName::try_from(self.domain.clone())
|
||||
.map_err(|e| format!("invalid domain {}: {e}", self.domain))?;
|
||||
|
||||
let upstream_tls = connector
|
||||
.connect(server_name, upstream_tcp)
|
||||
.await
|
||||
.map_err(|e| format!("upstream TLS to {} failed: {e}", self.domain))?;
|
||||
|
||||
let upstream_io = TokioIo::new(upstream_tls);
|
||||
let (sender, conn) =
|
||||
hyper::client::conn::http2::Builder::new(TokioExecutor::new())
|
||||
.handshake(upstream_io)
|
||||
.await
|
||||
.map_err(|e| format!("upstream h2 handshake to {} failed: {e}", self.domain))?;
|
||||
|
||||
let domain = self.domain.clone();
|
||||
tokio::spawn(async move {
|
||||
if let Err(e) = conn.await {
|
||||
debug!(domain = %domain, error = %e, "MITM H2: upstream connection driver ended");
|
||||
}
|
||||
});
|
||||
|
||||
info!(domain = %self.domain, "MITM H2: established upstream HTTP/2 connection");
|
||||
Ok(sender)
|
||||
}
|
||||
}
|
||||
|
||||
/// Names of the gRPC methods whose responses carry ModelUsageStats.
///
/// Entries are matched as substrings of the request path.
const USAGE_METHODS: &[&str] = &[
    // Unary
    "GenerateContent",
    "AsyncGenerateContent",
    "GenerateChat",
    "GenerateCode",
    "CompleteCode",
    "InternalAtomicAgenticChat",
    "Predict",
    "DirectPredict",
    // Server-streaming
    "StreamGenerateContent",
    "StreamAsyncGenerateContent",
    "StreamGenerateChat",
];
|
||||
|
||||
/// Handle an HTTP/2 connection from the LS after TLS termination.
|
||||
///
|
||||
/// Uses hyper's HTTP/2 server to accept requests and a shared upstream
|
||||
/// HTTP/2 connection to forward them.
|
||||
pub async fn handle_h2_connection<S>(
|
||||
tls_stream: S,
|
||||
domain: String,
|
||||
store: MitmStore,
|
||||
) -> Result<(), String>
|
||||
where
|
||||
S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
|
||||
{
|
||||
info!(domain = %domain, "MITM H2: handling HTTP/2 connection");
|
||||
|
||||
// Build TLS config for upstream connections
|
||||
let mut root_store = rustls::RootCertStore::empty();
|
||||
let native_certs = rustls_native_certs::load_native_certs();
|
||||
for cert in native_certs.certs {
|
||||
let _ = root_store.add(cert);
|
||||
}
|
||||
let mut upstream_tls_config = rustls::ClientConfig::builder()
|
||||
.with_root_certificates(root_store)
|
||||
.with_no_client_auth();
|
||||
upstream_tls_config.alpn_protocols = vec![b"h2".to_vec()];
|
||||
|
||||
// Shared upstream connection pool (single connection, multiplexed)
|
||||
let pool = Arc::new(UpstreamPool::new(
|
||||
domain.clone(),
|
||||
Arc::new(upstream_tls_config),
|
||||
));
|
||||
|
||||
let io = TokioIo::new(tls_stream);
|
||||
let domain = Arc::new(domain);
|
||||
|
||||
let result = H2ServerBuilder::new(TokioExecutor::new())
|
||||
.serve_connection(
|
||||
io,
|
||||
service_fn(move |req: Request<Incoming>| {
|
||||
let domain = domain.clone();
|
||||
let store = store.clone();
|
||||
let pool = pool.clone();
|
||||
async move { handle_h2_request(req, &domain, store, pool).await }
|
||||
}),
|
||||
)
|
||||
.await;
|
||||
|
||||
match result {
|
||||
Ok(()) => {
|
||||
debug!("MITM H2: connection closed cleanly");
|
||||
Ok(())
|
||||
}
|
||||
Err(e) => {
|
||||
// Connection errors are expected on clean close
|
||||
debug!(error = %e, "MITM H2: connection ended");
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Response body type — either buffered or streaming.
|
||||
type BoxBody = http_body_util::Either<
|
||||
Full<Bytes>,
|
||||
StreamBody<tokio_stream::wrappers::ReceiverStream<Result<Frame<Bytes>, hyper::Error>>>,
|
||||
>;
|
||||
|
||||
/// Handle a single HTTP/2 request: forward to upstream, capture usage.
|
||||
///
|
||||
/// For streaming responses, forwards chunks in real-time while teeing
|
||||
/// data to a side buffer for post-stream usage extraction.
|
||||
async fn handle_h2_request(
|
||||
req: Request<Incoming>,
|
||||
domain: &str,
|
||||
store: MitmStore,
|
||||
pool: Arc<UpstreamPool>,
|
||||
) -> Result<Response<BoxBody>, hyper::Error> {
|
||||
let method = req.method().clone();
|
||||
let uri = req.uri().clone();
|
||||
let path = uri.path().to_string();
|
||||
|
||||
// Identify gRPC method
|
||||
let is_grpc = req
|
||||
.headers()
|
||||
.get("content-type")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.map(|ct| ct.starts_with("application/grpc"))
|
||||
.unwrap_or(false);
|
||||
|
||||
// Check if this method carries usage data
|
||||
let is_usage_method = is_grpc
|
||||
&& USAGE_METHODS.iter().any(|m| path.contains(m));
|
||||
|
||||
// Check if this is a streaming method
|
||||
let is_streaming = is_grpc
|
||||
&& (path.contains("Stream") || path.contains("stream"));
|
||||
|
||||
debug!(
|
||||
domain,
|
||||
%method,
|
||||
path = %path,
|
||||
grpc = is_grpc,
|
||||
usage_method = is_usage_method,
|
||||
streaming = is_streaming,
|
||||
"MITM H2: forwarding request"
|
||||
);
|
||||
|
||||
// Collect request body (we need it for cascade ID extraction)
|
||||
let (parts, body) = req.into_parts();
|
||||
let request_body = match body.collect().await {
|
||||
Ok(collected) => collected.to_bytes(),
|
||||
Err(e) => {
|
||||
warn!(error = %e, "MITM H2: failed to collect request body");
|
||||
Bytes::new()
|
||||
}
|
||||
};
|
||||
|
||||
// Get upstream sender from pool
|
||||
let mut upstream_sender = match pool.get_sender().await {
|
||||
Ok(s) => s,
|
||||
Err(e) => {
|
||||
warn!(error = %e, domain, "MITM H2: upstream connect failed");
|
||||
let resp = Response::builder()
|
||||
.status(502)
|
||||
.body(http_body_util::Either::Left(Full::new(
|
||||
Bytes::from(format!("upstream connect failed: {e}")),
|
||||
)))
|
||||
.unwrap();
|
||||
return Ok(resp);
|
||||
}
|
||||
};
|
||||
|
||||
// Build the upstream request with proper authority
|
||||
let upstream_uri = http::Uri::builder()
|
||||
.scheme("https")
|
||||
.authority(domain)
|
||||
.path_and_query(
|
||||
uri.path_and_query()
|
||||
.map(|pq| pq.as_str())
|
||||
.unwrap_or("/"),
|
||||
)
|
||||
.build()
|
||||
.unwrap_or(uri);
|
||||
|
||||
let mut upstream_req = Request::builder()
|
||||
.method(parts.method)
|
||||
.uri(upstream_uri);
|
||||
|
||||
// Copy headers, skip hop-by-hop
|
||||
for (name, value) in &parts.headers {
|
||||
let n = name.as_str();
|
||||
if n == "host" || n == "connection" || n == "transfer-encoding" {
|
||||
continue;
|
||||
}
|
||||
upstream_req = upstream_req.header(name, value);
|
||||
}
|
||||
|
||||
let upstream_req = match upstream_req.body(Full::new(request_body.clone())) {
|
||||
Ok(r) => r,
|
||||
Err(e) => {
|
||||
let resp = Response::builder()
|
||||
.status(502)
|
||||
.body(http_body_util::Either::Left(Full::new(
|
||||
Bytes::from(format!("build request failed: {e}")),
|
||||
)))
|
||||
.unwrap();
|
||||
return Ok(resp);
|
||||
}
|
||||
};
|
||||
|
||||
// Send to upstream
|
||||
let upstream_resp = match upstream_sender.send_request(upstream_req).await {
|
||||
Ok(r) => r,
|
||||
Err(e) => {
|
||||
warn!(error = %e, domain, path = %path, "MITM H2: upstream request failed");
|
||||
let resp = Response::builder()
|
||||
.status(502)
|
||||
.body(http_body_util::Either::Left(Full::new(
|
||||
Bytes::from(format!("upstream request failed: {e}")),
|
||||
)))
|
||||
.unwrap();
|
||||
return Ok(resp);
|
||||
}
|
||||
};
|
||||
|
||||
let (resp_parts, resp_body) = upstream_resp.into_parts();
|
||||
let status = resp_parts.status;
|
||||
|
||||
// ──────────────────────────────────────────────────────────────────
|
||||
// Streaming path: forward chunks immediately, tee for usage parsing
|
||||
// ──────────────────────────────────────────────────────────────────
|
||||
if is_streaming && status.is_success() {
|
||||
let should_track_usage = is_usage_method;
|
||||
let (tx, rx) = tokio::sync::mpsc::channel::<Result<Frame<Bytes>, hyper::Error>>(32);
|
||||
|
||||
let store_clone = store.clone();
|
||||
let path_clone = path.clone();
|
||||
let request_body_clone = request_body.clone();
|
||||
|
||||
// Spawn a task to forward body chunks and tee for usage extraction
|
||||
tokio::spawn(async move {
|
||||
let mut tee_buffer = if should_track_usage { Some(Vec::new()) } else { None };
|
||||
let mut body = resp_body;
|
||||
|
||||
loop {
|
||||
match body.frame().await {
|
||||
Some(Ok(frame)) => {
|
||||
if let (Some(ref mut buf), Some(data)) = (&mut tee_buffer, frame.data_ref()) {
|
||||
buf.extend_from_slice(data);
|
||||
}
|
||||
if tx.send(Ok(frame)).await.is_err() {
|
||||
break; // client disconnected
|
||||
}
|
||||
}
|
||||
Some(Err(e)) => {
|
||||
warn!(error = %e, path = %path_clone, "MITM H2: streaming error");
|
||||
let _ = tx.send(Err(e)).await;
|
||||
break;
|
||||
}
|
||||
None => break, // stream ended
|
||||
}
|
||||
}
|
||||
|
||||
// Stream completed — parse the tee buffer for usage
|
||||
if let Some(tee_buffer) = tee_buffer {
|
||||
if !tee_buffer.is_empty() {
|
||||
if let Some(grpc_usage) = parse_grpc_response_for_usage(&tee_buffer) {
|
||||
let usage = ApiUsage {
|
||||
input_tokens: grpc_usage.input_tokens,
|
||||
output_tokens: grpc_usage.output_tokens,
|
||||
thinking_output_tokens: grpc_usage.thinking_output_tokens,
|
||||
response_output_tokens: grpc_usage.response_output_tokens,
|
||||
cache_creation_input_tokens: grpc_usage.cache_write_tokens,
|
||||
cache_read_input_tokens: grpc_usage.cache_read_tokens,
|
||||
model: grpc_usage.model,
|
||||
api_provider: grpc_usage.api_provider,
|
||||
grpc_method: Some(path_clone.clone()),
|
||||
stop_reason: None,
|
||||
total_cost_usd: None,
|
||||
captured_at: std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.unwrap_or_default()
|
||||
.as_secs(),
|
||||
};
|
||||
let cascade_hint = extract_cascade_from_grpc_request(&request_body_clone);
|
||||
store_clone.record_usage(cascade_hint.as_deref(), usage).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
let stream = tokio_stream::wrappers::ReceiverStream::new(rx);
|
||||
let stream_body = StreamBody::new(stream);
|
||||
|
||||
let mut client_resp = Response::builder().status(resp_parts.status);
|
||||
for (name, value) in &resp_parts.headers {
|
||||
client_resp = client_resp.header(name, value);
|
||||
}
|
||||
|
||||
let client_resp = client_resp
|
||||
.body(http_body_util::Either::Right(stream_body))
|
||||
.unwrap_or_else(|_| {
|
||||
Response::builder()
|
||||
.status(500)
|
||||
.body(http_body_util::Either::Left(Full::new(Bytes::from(
|
||||
"internal error",
|
||||
))))
|
||||
.unwrap()
|
||||
});
|
||||
|
||||
return Ok(client_resp);
|
||||
}
|
||||
|
||||
// ──────────────────────────────────────────────────────────────────
|
||||
// Non-streaming path: buffer full response, extract usage, forward
|
||||
// ──────────────────────────────────────────────────────────────────
|
||||
let response_body = match resp_body.collect().await {
|
||||
Ok(collected) => collected.to_bytes(),
|
||||
Err(e) => {
|
||||
warn!(error = %e, "MITM H2: failed to collect response body");
|
||||
Bytes::new()
|
||||
}
|
||||
};
|
||||
|
||||
trace!(
|
||||
domain,
|
||||
path = %path,
|
||||
status = %status,
|
||||
body_len = response_body.len(),
|
||||
"MITM H2: got upstream response"
|
||||
);
|
||||
|
||||
// Extract usage data from usage-carrying gRPC methods
|
||||
if is_usage_method && !response_body.is_empty() && status.is_success() {
|
||||
if let Some(grpc_usage) = parse_grpc_response_for_usage(&response_body) {
|
||||
let usage = ApiUsage {
|
||||
input_tokens: grpc_usage.input_tokens,
|
||||
output_tokens: grpc_usage.output_tokens,
|
||||
thinking_output_tokens: grpc_usage.thinking_output_tokens,
|
||||
response_output_tokens: grpc_usage.response_output_tokens,
|
||||
cache_creation_input_tokens: grpc_usage.cache_write_tokens,
|
||||
cache_read_input_tokens: grpc_usage.cache_read_tokens,
|
||||
model: grpc_usage.model,
|
||||
api_provider: grpc_usage.api_provider,
|
||||
grpc_method: Some(path.clone()),
|
||||
stop_reason: None,
|
||||
total_cost_usd: None,
|
||||
captured_at: std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.unwrap_or_default()
|
||||
.as_secs(),
|
||||
};
|
||||
|
||||
let cascade_hint = extract_cascade_from_grpc_request(&request_body);
|
||||
store.record_usage(cascade_hint.as_deref(), usage).await;
|
||||
}
|
||||
}
|
||||
|
||||
// Build response for the client
|
||||
let mut client_resp = Response::builder().status(resp_parts.status);
|
||||
for (name, value) in &resp_parts.headers {
|
||||
client_resp = client_resp.header(name, value);
|
||||
}
|
||||
|
||||
let client_resp = client_resp
|
||||
.body(http_body_util::Either::Left(Full::new(response_body)))
|
||||
.unwrap_or_else(|_| {
|
||||
Response::builder()
|
||||
.status(500)
|
||||
.body(http_body_util::Either::Left(Full::new(Bytes::from(
|
||||
"internal error",
|
||||
))))
|
||||
.unwrap()
|
||||
});
|
||||
|
||||
Ok(client_resp)
|
||||
}
|
||||
|
||||
/// Try to extract a cascade ID from a gRPC request body.
|
||||
///
|
||||
/// Looks for UUID-formatted strings in the protobuf fields.
|
||||
fn extract_cascade_from_grpc_request(body: &[u8]) -> Option<String> {
|
||||
use crate::mitm::proto::{decode_proto, extract_grpc_messages};
|
||||
|
||||
let messages = extract_grpc_messages(body);
|
||||
for msg in &messages {
|
||||
let fields = decode_proto(msg);
|
||||
for field in &fields {
|
||||
if let Some(id) = extract_uuid_from_field(field) {
|
||||
return Some(id);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
|
||||
fn extract_uuid_from_field(field: &crate::mitm::proto::ProtoField) -> Option<String> {
|
||||
use crate::mitm::proto::ProtoValue;
|
||||
|
||||
match &field.value {
|
||||
ProtoValue::Bytes(b) => {
|
||||
if let Ok(s) = std::str::from_utf8(b) {
|
||||
if is_uuid(s) {
|
||||
return Some(s.to_string());
|
||||
}
|
||||
}
|
||||
}
|
||||
ProtoValue::Message(nested) => {
|
||||
for nf in nested {
|
||||
if let Some(id) = extract_uuid_from_field(nf) {
|
||||
return Some(id);
|
||||
}
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
/// Return `true` when `s` has the canonical textual UUID layout.
///
/// That is 36 ASCII characters: hex digits (either case) grouped 8-4-4-4-12,
/// with `-` at byte offsets 8, 13, 18 and 23. Checking the dash positions,
/// and not just the dash count, rejects strings such as `----` followed by
/// 32 hex digits.
fn is_uuid(s: &str) -> bool {
    let bytes = s.as_bytes();
    bytes.len() == 36
        && bytes.iter().enumerate().all(|(i, &b)| match i {
            8 | 13 | 18 | 23 => b == b'-',
            _ => b.is_ascii_hexdigit(),
        })
}
|
||||
271
src/mitm/intercept.rs
Normal file
271
src/mitm/intercept.rs
Normal file
@@ -0,0 +1,271 @@
|
||||
//! API response interceptor: parses Anthropic/Google API responses to extract usage data.
|
||||
//!
|
||||
//! Handles both streaming (SSE) and non-streaming (JSON) responses.
|
||||
|
||||
use super::store::ApiUsage;
|
||||
use serde_json::Value;
|
||||
use tracing::{debug, trace};
|
||||
|
||||
/// Parse a complete (non-streaming) Anthropic Messages API response body.
|
||||
///
|
||||
/// Response format:
|
||||
/// ```json
|
||||
/// {
|
||||
/// "id": "msg_...",
|
||||
/// "type": "message",
|
||||
/// "model": "claude-sonnet-4-20250514",
|
||||
/// "usage": {
|
||||
/// "input_tokens": 1234,
|
||||
/// "output_tokens": 567,
|
||||
/// "cache_creation_input_tokens": 0,
|
||||
/// "cache_read_input_tokens": 890
|
||||
/// },
|
||||
/// "stop_reason": "end_turn"
|
||||
/// }
|
||||
/// ```
|
||||
pub fn parse_non_streaming_response(body: &[u8]) -> Option<ApiUsage> {
|
||||
let json: Value = serde_json::from_slice(body).ok()?;
|
||||
extract_usage_from_message(&json)
|
||||
}
|
||||
|
||||
/// Parse SSE events from a streaming Anthropic response body chunk.
|
||||
///
|
||||
/// Events of interest:
|
||||
/// - `message_start` — contains `message.usage.input_tokens` + cache tokens
|
||||
/// - `message_delta` — contains `usage.output_tokens`
|
||||
/// - `message_stop` — marks end (no usage data)
|
||||
///
|
||||
/// Returns accumulated usage across all events in this chunk.
|
||||
pub fn parse_streaming_chunk(chunk: &str, accumulator: &mut StreamingAccumulator) {
|
||||
for line in chunk.lines() {
|
||||
if let Some(data) = line.strip_prefix("data: ") {
|
||||
if data.trim() == "[DONE]" {
|
||||
continue;
|
||||
}
|
||||
if let Ok(event) = serde_json::from_str::<Value>(data) {
|
||||
accumulator.process_event(&event);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Running usage totals gathered from the SSE events of one streaming response.
#[derive(Debug, Default)]
pub struct StreamingAccumulator {
    /// Input token count taken from `message_start`.
    pub input_tokens: u64,
    /// Latest output token count taken from `message_delta`.
    pub output_tokens: u64,
    /// Cache-creation input tokens taken from `message_start`.
    pub cache_creation_input_tokens: u64,
    /// Cache-read input tokens taken from `message_start`.
    pub cache_read_input_tokens: u64,
    /// Model name taken from `message_start`, when present.
    pub model: Option<String>,
    /// Stop reason taken from `message_delta`, when present.
    pub stop_reason: Option<String>,
    /// Set once `message_stop` has been seen.
    pub is_complete: bool,
}
|
||||
|
||||
impl StreamingAccumulator {
|
||||
pub fn new() -> Self {
|
||||
Self::default()
|
||||
}
|
||||
|
||||
/// Process a single SSE event.
|
||||
pub fn process_event(&mut self, event: &Value) {
|
||||
let event_type = event["type"].as_str().unwrap_or("");
|
||||
|
||||
match event_type {
|
||||
"message_start" => {
|
||||
// message_start contains the initial usage (input tokens + cache)
|
||||
if let Some(usage) = event.get("message").and_then(|m| m.get("usage")) {
|
||||
self.input_tokens = usage["input_tokens"].as_u64().unwrap_or(0);
|
||||
self.cache_creation_input_tokens = usage["cache_creation_input_tokens"].as_u64().unwrap_or(0);
|
||||
self.cache_read_input_tokens = usage["cache_read_input_tokens"].as_u64().unwrap_or(0);
|
||||
}
|
||||
if let Some(model) = event.get("message").and_then(|m| m["model"].as_str()) {
|
||||
self.model = Some(model.to_string());
|
||||
}
|
||||
trace!(
|
||||
input = self.input_tokens,
|
||||
cache_read = self.cache_read_input_tokens,
|
||||
cache_create = self.cache_creation_input_tokens,
|
||||
"SSE message_start: captured input usage"
|
||||
);
|
||||
}
|
||||
"message_delta" => {
|
||||
// message_delta contains the output usage
|
||||
if let Some(usage) = event.get("usage") {
|
||||
self.output_tokens = usage["output_tokens"].as_u64().unwrap_or(self.output_tokens);
|
||||
}
|
||||
if let Some(reason) = event["delta"]["stop_reason"].as_str() {
|
||||
self.stop_reason = Some(reason.to_string());
|
||||
}
|
||||
trace!(output = self.output_tokens, "SSE message_delta: updated output tokens");
|
||||
}
|
||||
"message_stop" => {
|
||||
self.is_complete = true;
|
||||
debug!(
|
||||
input = self.input_tokens,
|
||||
output = self.output_tokens,
|
||||
cache_read = self.cache_read_input_tokens,
|
||||
model = ?self.model,
|
||||
"SSE message_stop: stream complete"
|
||||
);
|
||||
}
|
||||
"content_block_start" | "content_block_delta" | "content_block_stop" | "ping" => {
|
||||
// Content events — no usage data, just pass through
|
||||
}
|
||||
_ => {
|
||||
trace!(event_type, "SSE: unknown event type");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert accumulated data to an ApiUsage.
|
||||
pub fn into_usage(self) -> ApiUsage {
|
||||
ApiUsage {
|
||||
input_tokens: self.input_tokens,
|
||||
output_tokens: self.output_tokens,
|
||||
cache_creation_input_tokens: self.cache_creation_input_tokens,
|
||||
cache_read_input_tokens: self.cache_read_input_tokens,
|
||||
thinking_output_tokens: 0,
|
||||
response_output_tokens: 0,
|
||||
total_cost_usd: None,
|
||||
model: self.model,
|
||||
stop_reason: self.stop_reason,
|
||||
api_provider: Some("anthropic".to_string()),
|
||||
grpc_method: None,
|
||||
captured_at: std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.unwrap_or_default()
|
||||
.as_secs(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Extract usage from a complete Message JSON object.
|
||||
fn extract_usage_from_message(msg: &Value) -> Option<ApiUsage> {
|
||||
let usage = msg.get("usage")?;
|
||||
|
||||
Some(ApiUsage {
|
||||
input_tokens: usage["input_tokens"].as_u64().unwrap_or(0),
|
||||
output_tokens: usage["output_tokens"].as_u64().unwrap_or(0),
|
||||
cache_creation_input_tokens: usage["cache_creation_input_tokens"].as_u64().unwrap_or(0),
|
||||
cache_read_input_tokens: usage["cache_read_input_tokens"].as_u64().unwrap_or(0),
|
||||
thinking_output_tokens: 0,
|
||||
response_output_tokens: 0,
|
||||
total_cost_usd: None,
|
||||
model: msg["model"].as_str().map(|s| s.to_string()),
|
||||
stop_reason: msg["stop_reason"].as_str().map(|s| s.to_string()),
|
||||
api_provider: Some("anthropic".to_string()),
|
||||
grpc_method: None,
|
||||
captured_at: std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.unwrap_or_default()
|
||||
.as_secs(),
|
||||
})
|
||||
}
|
||||
|
||||
/// Try to identify a cascade ID from the request body.
|
||||
///
|
||||
/// The LS includes cascade-related metadata in its API requests (as part of
|
||||
/// the system prompt or metadata field). We try to find it.
|
||||
pub fn extract_cascade_hint(request_body: &[u8]) -> Option<String> {
|
||||
let json: Value = serde_json::from_slice(request_body).ok()?;
|
||||
|
||||
// Check for metadata field (some API configurations include it)
|
||||
if let Some(metadata) = json.get("metadata") {
|
||||
if let Some(user_id) = metadata["user_id"].as_str() {
|
||||
// The LS often sets user_id to the cascadeId
|
||||
return Some(user_id.to_string());
|
||||
}
|
||||
}
|
||||
|
||||
// Check system prompt for cascade/workspace markers
|
||||
if let Some(system) = json.get("system") {
|
||||
let system_str = match system {
|
||||
Value::String(s) => s.clone(),
|
||||
Value::Array(arr) => {
|
||||
// Array of content blocks
|
||||
arr.iter()
|
||||
.filter_map(|b| b["text"].as_str())
|
||||
.collect::<Vec<_>>()
|
||||
.join(" ")
|
||||
}
|
||||
_ => return None,
|
||||
};
|
||||
// Look for workspace_id or cascade_id patterns
|
||||
if let Some(pos) = system_str.find("workspace_id") {
|
||||
let rest = &system_str[pos..];
|
||||
// Extract the value after workspace_id
|
||||
if let Some(val) = rest.split_whitespace().nth(1) {
|
||||
return Some(val.to_string());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// A complete (non-streaming) Message body yields every counter and the model.
    #[test]
    fn test_parse_non_streaming() {
        let body = r#"{
"id": "msg_123",
"type": "message",
"model": "claude-sonnet-4-20250514",
"usage": {
"input_tokens": 100,
"output_tokens": 50,
"cache_creation_input_tokens": 10,
"cache_read_input_tokens": 30
},
"stop_reason": "end_turn"
}"#;

        let parsed = parse_non_streaming_response(body.as_bytes()).unwrap();

        assert_eq!(parsed.input_tokens, 100);
        assert_eq!(parsed.output_tokens, 50);
        assert_eq!(parsed.cache_creation_input_tokens, 10);
        assert_eq!(parsed.cache_read_input_tokens, 30);
        assert_eq!(parsed.model.as_deref(), Some("claude-sonnet-4-20250514"));
    }

    /// Feeding start → delta → stop events accumulates input then output usage.
    #[test]
    fn test_streaming_accumulator() {
        let mut acc = StreamingAccumulator::new();

        // message_start carries the input-side usage.
        acc.process_event(&serde_json::json!({
            "type": "message_start",
            "message": {
                "model": "claude-sonnet-4-20250514",
                "usage": {
                    "input_tokens": 200,
                    "cache_creation_input_tokens": 5,
                    "cache_read_input_tokens": 50
                }
            }
        }));
        assert_eq!(acc.input_tokens, 200);
        assert_eq!(acc.cache_read_input_tokens, 50);

        // message_delta carries the output-side usage.
        acc.process_event(&serde_json::json!({
            "type": "message_delta",
            "delta": { "stop_reason": "end_turn" },
            "usage": { "output_tokens": 75 }
        }));
        assert_eq!(acc.output_tokens, 75);

        // message_stop marks the stream as finished.
        acc.process_event(&serde_json::json!({ "type": "message_stop" }));
        assert!(acc.is_complete);

        let totals = acc.into_usage();
        assert_eq!(totals.input_tokens, 200);
        assert_eq!(totals.output_tokens, 75);
    }
}
|
||||
19
src/mitm/mod.rs
Normal file
19
src/mitm/mod.rs
Normal file
@@ -0,0 +1,19 @@
|
||||
//! MITM proxy module: intercepts LS ↔ Google/Anthropic API traffic.
|
||||
//!
|
||||
//! The LS (Go binary with BoringCrypto) respects `HTTPS_PROXY` and `SSL_CERT_FILE`.
|
||||
//! By setting these env vars via the wrapper script, we route all outbound HTTPS
|
||||
//! traffic through our local MITM proxy, which:
|
||||
//!
|
||||
//! 1. Terminates TLS using dynamically-generated per-domain certificates
|
||||
//! 2. Detects protocol: HTTP/1.1 (REST) or HTTP/2 (gRPC)
|
||||
//! 3. For HTTP/1.1: parses JSON/SSE responses (Anthropic format)
|
||||
//! 4. For HTTP/2: decodes gRPC protobuf responses (Google format)
|
||||
//! 5. Captures token usage data (input, output, thinking, cache)
|
||||
//! 6. Forwards everything transparently to real upstream servers
|
||||
|
||||
pub mod ca;
|
||||
pub mod h2_handler;
|
||||
pub mod intercept;
|
||||
pub mod proto;
|
||||
pub mod proxy;
|
||||
pub mod store;
|
||||
584
src/mitm/proto.rs
Normal file
584
src/mitm/proto.rs
Normal file
@@ -0,0 +1,584 @@
|
||||
//! Raw protobuf decoder for extracting ModelUsageStats from gRPC responses.
|
||||
//!
|
||||
//! We don't have the .proto schema, so we decode protobuf messages generically
|
||||
//! and search for usage-like structures by matching field patterns.
|
||||
//!
|
||||
//! gRPC wire format:
|
||||
//! - 1 byte: compression flag (0 = uncompressed, 1 = compressed)
|
||||
//! - 4 bytes: message length (big-endian u32)
|
||||
//! - N bytes: protobuf message
|
||||
//!
|
||||
//! Protobuf wire format:
|
||||
//! - Each field: (field_number << 3 | wire_type) as varint, then value
|
||||
//! - Wire type 0: varint
|
||||
//! - Wire type 1: 64-bit fixed
|
||||
//! - Wire type 2: length-delimited (string, bytes, embedded message)
|
||||
//! - Wire type 5: 32-bit fixed
|
||||
//!
|
||||
//! ## ModelUsageStats schema (reverse-engineered from LS binary):
|
||||
//!
|
||||
//! ```protobuf
|
||||
//! message ModelUsageStats {
|
||||
//! Model model = 1; // enum (varint)
|
||||
//! uint64 input_tokens = 2;
|
||||
//! uint64 output_tokens = 3;
|
||||
//! uint64 cache_write_tokens = 4;
|
||||
//! uint64 cache_read_tokens = 5;
|
||||
//! APIProvider api_provider = 6; // enum (varint)
|
||||
//! string message_id = 7;
|
||||
//! map<string,string> response_header = 8; // repeated message
|
||||
//! uint64 thinking_output_tokens = 9;
|
||||
//! uint64 response_output_tokens = 10;
|
||||
//! string response_id = 11;
|
||||
//! }
|
||||
//! ```
|
||||
|
||||
use flate2::read::GzDecoder;
|
||||
use std::io::Read;
|
||||
use tracing::{debug, trace, warn};
|
||||
|
||||
/// The value of one decoded protobuf field, tagged by how it was encoded.
///
/// `PartialEq`/`Eq` are derived so decoded trees can be compared directly
/// (e.g. with `assert_eq!`) instead of pattern-matching each node.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ProtoValue {
    /// Wire type 0: a base-128 varint.
    Varint(u64),
    /// Wire type 1: eight little-endian bytes.
    // The payload is decoded for completeness but never read, hence the allow.
    #[allow(dead_code)]
    Fixed64(u64),
    /// Wire type 5: four little-endian bytes.
    // The payload is decoded for completeness but never read, hence the allow.
    #[allow(dead_code)]
    Fixed32(u32),
    /// Wire type 2 payload kept verbatim (a string, raw bytes, or anything
    /// that did not decode as a nested message).
    Bytes(Vec<u8>),
    /// Wire type 2 payload that decoded as a nested message (parsed recursively).
    Message(Vec<ProtoField>),
}

/// A single protobuf field: its field number plus the decoded value.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ProtoField {
    /// Field number taken from the tag (`tag >> 3`).
    pub number: u32,
    /// Decoded payload of the field.
    pub value: ProtoValue,
}
|
||||
|
||||
/// Usage data extracted from a gRPC response.
///
/// The `Default` value has every counter at `0` and every optional field
/// `None`. `Clone`/`PartialEq`/`Eq` are derived so callers can copy a record
/// and compare records directly.
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub struct GrpcUsage {
    /// Prompt-side token count (ModelUsageStats field 2).
    pub input_tokens: u64,
    /// Completion-side token count (field 3).
    pub output_tokens: u64,
    /// Tokens spent on thinking output (field 9).
    pub thinking_output_tokens: u64,
    /// Tokens spent on response output (field 10).
    pub response_output_tokens: u64,
    /// Tokens read from cache (field 5).
    pub cache_read_tokens: u64,
    /// Tokens written to cache (field 4).
    pub cache_write_tokens: u64,
    /// Model identifier; the wire value is an enum (field 1), so this holds
    /// a textual rendering of it.
    pub model: Option<String>,
    /// API provider name derived from the provider enum (field 6).
    pub api_provider: Option<String>,
    /// Message ID string (field 7).
    pub message_id: Option<String>,
    /// Response ID string (field 11).
    pub response_id: Option<String>,
}
|
||||
|
||||
/// Extract gRPC message frames from a buffer.
|
||||
///
|
||||
/// A gRPC message is:
|
||||
/// [1 byte compressed flag] [4 bytes length BE] [N bytes protobuf]
|
||||
///
|
||||
/// Multiple messages can be concatenated in a single buffer.
|
||||
/// If compressed flag is 1, the message is gzip-decompressed.
|
||||
pub fn extract_grpc_messages(data: &[u8]) -> Vec<Vec<u8>> {
|
||||
let mut messages = Vec::new();
|
||||
let mut offset = 0;
|
||||
|
||||
while offset + 5 <= data.len() {
|
||||
let compressed = data[offset];
|
||||
let length = u32::from_be_bytes([
|
||||
data[offset + 1],
|
||||
data[offset + 2],
|
||||
data[offset + 3],
|
||||
data[offset + 4],
|
||||
]) as usize;
|
||||
|
||||
offset += 5;
|
||||
|
||||
if offset + length > data.len() {
|
||||
break;
|
||||
}
|
||||
|
||||
let payload = &data[offset..offset + length];
|
||||
|
||||
if compressed == 1 {
|
||||
// gzip-compressed frame
|
||||
let mut decoder = GzDecoder::new(payload);
|
||||
let mut decompressed = Vec::new();
|
||||
match decoder.read_to_end(&mut decompressed) {
|
||||
Ok(_) => messages.push(decompressed),
|
||||
Err(e) => {
|
||||
warn!(error = %e, "Proto: failed to decompress gRPC frame");
|
||||
}
|
||||
}
|
||||
} else {
|
||||
messages.push(payload.to_vec());
|
||||
}
|
||||
|
||||
offset += length;
|
||||
}
|
||||
|
||||
messages
|
||||
}
|
||||
|
||||
/// Decode a protobuf message into a list of fields.
|
||||
///
|
||||
/// This is a best-effort decoder that handles the common wire types.
|
||||
/// Embedded messages (wire type 2) are attempted to be parsed recursively.
|
||||
pub fn decode_proto(data: &[u8]) -> Vec<ProtoField> {
|
||||
let mut fields = Vec::new();
|
||||
let mut offset = 0;
|
||||
|
||||
while offset < data.len() {
|
||||
// Read tag (varint)
|
||||
let (tag, bytes_read) = match read_varint(&data[offset..]) {
|
||||
Some(v) => v,
|
||||
None => break,
|
||||
};
|
||||
offset += bytes_read;
|
||||
|
||||
let field_number = (tag >> 3) as u32;
|
||||
let wire_type = (tag & 0x07) as u8;
|
||||
|
||||
if field_number == 0 {
|
||||
break; // invalid
|
||||
}
|
||||
|
||||
let value = match wire_type {
|
||||
0 => {
|
||||
// Varint
|
||||
let (val, bytes_read) = match read_varint(&data[offset..]) {
|
||||
Some(v) => v,
|
||||
None => break,
|
||||
};
|
||||
offset += bytes_read;
|
||||
ProtoValue::Varint(val)
|
||||
}
|
||||
1 => {
|
||||
// 64-bit fixed
|
||||
if offset + 8 > data.len() {
|
||||
break;
|
||||
}
|
||||
let val = u64::from_le_bytes(data[offset..offset + 8].try_into().unwrap());
|
||||
offset += 8;
|
||||
ProtoValue::Fixed64(val)
|
||||
}
|
||||
2 => {
|
||||
// Length-delimited
|
||||
let (len, bytes_read) = match read_varint(&data[offset..]) {
|
||||
Some(v) => v,
|
||||
None => break,
|
||||
};
|
||||
offset += bytes_read;
|
||||
let len = len as usize;
|
||||
|
||||
if offset + len > data.len() {
|
||||
break;
|
||||
}
|
||||
|
||||
let payload = &data[offset..offset + len];
|
||||
offset += len;
|
||||
|
||||
// Try to parse as a nested message
|
||||
let nested = decode_proto(payload);
|
||||
if !nested.is_empty() && looks_like_valid_message(&nested, payload.len()) {
|
||||
ProtoValue::Message(nested)
|
||||
} else {
|
||||
ProtoValue::Bytes(payload.to_vec())
|
||||
}
|
||||
}
|
||||
5 => {
|
||||
// 32-bit fixed
|
||||
if offset + 4 > data.len() {
|
||||
break;
|
||||
}
|
||||
let val = u32::from_le_bytes(data[offset..offset + 4].try_into().unwrap());
|
||||
offset += 4;
|
||||
ProtoValue::Fixed32(val)
|
||||
}
|
||||
_ => {
|
||||
// Unknown wire type — stop parsing
|
||||
break;
|
||||
}
|
||||
};
|
||||
|
||||
fields.push(ProtoField {
|
||||
number: field_number,
|
||||
value,
|
||||
});
|
||||
}
|
||||
|
||||
fields
|
||||
}
|
||||
|
||||
/// Heuristic: does this list of fields look like a valid protobuf message?
|
||||
/// (vs. a random string that happened to partially decode)
|
||||
fn looks_like_valid_message(fields: &[ProtoField], original_len: usize) -> bool {
|
||||
if fields.is_empty() {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Check that field numbers are reasonable (< 10000)
|
||||
let valid_numbers = fields.iter().all(|f| f.number < 10000);
|
||||
if !valid_numbers {
|
||||
return false;
|
||||
}
|
||||
|
||||
// If we have very few fields relative to the data size, it's probably not a message
|
||||
// (e.g., a long string that happened to have a valid first-field prefix)
|
||||
if fields.len() == 1 && original_len > 100 {
|
||||
// Single-field messages of >100 bytes are suspicious unless the field is bytes/message
|
||||
match &fields[0].value {
|
||||
ProtoValue::Bytes(_) | ProtoValue::Message(_) => true,
|
||||
_ => false,
|
||||
}
|
||||
} else {
|
||||
true
|
||||
}
|
||||
}
|
||||
|
||||
/// Read a base-128 varint from the start of `data`.
///
/// Returns `(value, bytes_consumed)`, or `None` when `data` is empty, ends
/// before a terminating byte (one without the high bit set), or has no
/// terminating byte within the first 10 bytes.
pub fn read_varint(data: &[u8]) -> Option<(u64, usize)> {
    let mut value: u64 = 0;

    // A u64 varint occupies at most 10 bytes; each contributes 7 payload bits.
    for (index, byte) in data.iter().copied().take(10).enumerate() {
        value |= u64::from(byte & 0x7F) << (7 * index);

        if byte & 0x80 == 0 {
            return Some((value, index + 1));
        }
    }

    None
}
|
||||
|
||||
/// Search a decoded protobuf message tree for usage-like structures.
|
||||
///
|
||||
/// Uses the exact field numbers from the reverse-engineered ModelUsageStats schema:
|
||||
///
|
||||
/// field 1: model (enum/varint)
|
||||
/// field 2: input_tokens (uint64)
|
||||
/// field 3: output_tokens (uint64)
|
||||
/// field 4: cache_write_tokens (uint64)
|
||||
/// field 5: cache_read_tokens (uint64)
|
||||
/// field 6: api_provider (enum/varint)
|
||||
/// field 7: message_id (string)
|
||||
/// field 8: response_header (map, repeated message)
|
||||
/// field 9: thinking_output_tokens (uint64)
|
||||
/// field 10: response_output_tokens (uint64)
|
||||
/// field 11: response_id (string)
|
||||
pub fn extract_usage_from_proto(fields: &[ProtoField]) -> Option<GrpcUsage> {
|
||||
// Strategy: recursively search for any sub-message that looks like usage data
|
||||
// Try this level first
|
||||
if let Some(usage) = try_extract_usage(fields) {
|
||||
return Some(usage);
|
||||
}
|
||||
|
||||
// Recurse into nested messages
|
||||
for field in fields {
|
||||
if let ProtoValue::Message(ref nested) = field.value {
|
||||
if let Some(usage) = extract_usage_from_proto(nested) {
|
||||
return Some(usage);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
|
||||
/// Try to extract usage from this specific set of fields.
|
||||
///
|
||||
/// Uses verified field numbers from the binary's embedded proto descriptor.
|
||||
fn try_extract_usage(fields: &[ProtoField]) -> Option<GrpcUsage> {
|
||||
// We need:
|
||||
// - At least 2 varint fields with values in token range
|
||||
// - Ideally field 2 (input_tokens) or field 3 (output_tokens) present
|
||||
let varint_fields: Vec<_> = fields
|
||||
.iter()
|
||||
.filter(|f| matches!(f.value, ProtoValue::Varint(_)))
|
||||
.collect();
|
||||
|
||||
let string_fields: Vec<_> = fields
|
||||
.iter()
|
||||
.filter_map(|f| {
|
||||
if let ProtoValue::Bytes(ref b) = f.value {
|
||||
std::str::from_utf8(b).ok().map(|s| (f.number, s.to_string()))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
// Need at least 2 varint fields to be a candidate
|
||||
if varint_fields.len() < 2 {
|
||||
return None;
|
||||
}
|
||||
|
||||
// Check if the varint values make sense as token counts
|
||||
let plausible_token_count = |v: u64| v <= 10_000_000;
|
||||
let plausible_varints = varint_fields
|
||||
.iter()
|
||||
.filter(|f| {
|
||||
if let ProtoValue::Varint(v) = f.value {
|
||||
plausible_token_count(v) && v > 0
|
||||
} else {
|
||||
false
|
||||
}
|
||||
})
|
||||
.count();
|
||||
|
||||
// Need at least 2 non-zero plausible values
|
||||
if plausible_varints < 2 {
|
||||
return None;
|
||||
}
|
||||
|
||||
// Check if there's a model-like string (field 7 = message_id or field 11 = response_id
|
||||
// can contain model names, or model enum values map to known names)
|
||||
let has_model_string = string_fields.iter().any(|(_, s)| {
|
||||
s.contains("claude") || s.contains("gemini") || s.contains("gpt")
|
||||
|| s.starts_with("models/") || s.contains("sonnet") || s.contains("opus")
|
||||
|| s.contains("flash") || s.contains("pro")
|
||||
});
|
||||
|
||||
// Check for fields at the known ModelUsageStats field numbers
|
||||
let has_field_2 = fields.iter().any(|f| f.number == 2 && matches!(f.value, ProtoValue::Varint(_)));
|
||||
let has_field_3 = fields.iter().any(|f| f.number == 3 && matches!(f.value, ProtoValue::Varint(_)));
|
||||
|
||||
// Strong signal: has both input and output token fields
|
||||
let is_likely_usage = (has_field_2 && has_field_3) || has_model_string;
|
||||
|
||||
if !is_likely_usage && varint_fields.len() < 3 {
|
||||
// Without strong signal, need more fields
|
||||
return None;
|
||||
}
|
||||
|
||||
// Build usage from exact field numbers (verified from binary)
|
||||
let mut usage = GrpcUsage::default();
|
||||
|
||||
for field in fields {
|
||||
match &field.value {
|
||||
ProtoValue::Varint(v) => {
|
||||
let v = *v;
|
||||
if !plausible_token_count(v) {
|
||||
continue;
|
||||
}
|
||||
match field.number {
|
||||
// field 1 = model enum (varint, not string!)
|
||||
2 => usage.input_tokens = v,
|
||||
3 => usage.output_tokens = v,
|
||||
4 => usage.cache_write_tokens = v, // VERIFIED: field 4
|
||||
5 => usage.cache_read_tokens = v, // VERIFIED: field 5
|
||||
// field 6 = api_provider enum (varint)
|
||||
9 => usage.thinking_output_tokens = v, // VERIFIED: field 9
|
||||
10 => usage.response_output_tokens = v, // VERIFIED: field 10
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
ProtoValue::Bytes(ref b) => {
|
||||
if let Ok(s) = std::str::from_utf8(b) {
|
||||
match field.number {
|
||||
7 => usage.message_id = Some(s.to_string()),
|
||||
11 => usage.response_id = Some(s.to_string()),
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
// Model and api_provider are enums (varints), not strings
|
||||
// We can map known enum values later if needed
|
||||
// For now, extract the enum value as a string representation
|
||||
for field in fields {
|
||||
if let ProtoValue::Varint(v) = &field.value {
|
||||
match field.number {
|
||||
1 => {
|
||||
// Model enum — we don't have the mapping, store as number
|
||||
usage.model = Some(format!("model_enum_{v}"));
|
||||
}
|
||||
6 => {
|
||||
// APIProvider enum
|
||||
usage.api_provider = Some(match *v {
|
||||
0 => "unknown".to_string(),
|
||||
1 => "google".to_string(),
|
||||
2 => "anthropic".to_string(),
|
||||
_ => format!("provider_{v}"),
|
||||
});
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Validate — we should have at least input OR output tokens
|
||||
if usage.input_tokens == 0 && usage.output_tokens == 0 {
|
||||
return None;
|
||||
}
|
||||
|
||||
debug!(
|
||||
input = usage.input_tokens,
|
||||
output = usage.output_tokens,
|
||||
thinking = usage.thinking_output_tokens,
|
||||
response = usage.response_output_tokens,
|
||||
cache_read = usage.cache_read_tokens,
|
||||
cache_write = usage.cache_write_tokens,
|
||||
model = ?usage.model,
|
||||
api_provider = ?usage.api_provider,
|
||||
"Proto: extracted ModelUsageStats from protobuf"
|
||||
);
|
||||
|
||||
Some(usage)
|
||||
}
|
||||
|
||||
/// Parse a gRPC response body (may contain multiple messages) for usage data.
|
||||
///
|
||||
/// Handles both compressed and uncompressed gRPC frames.
|
||||
pub fn parse_grpc_response_for_usage(body: &[u8]) -> Option<GrpcUsage> {
|
||||
let messages = extract_grpc_messages(body);
|
||||
|
||||
trace!(count = messages.len(), "Proto: extracted gRPC messages");
|
||||
|
||||
// Check each message for usage data (last message usually has it)
|
||||
for msg in messages.iter().rev() {
|
||||
let fields = decode_proto(msg);
|
||||
if let Some(usage) = extract_usage_from_proto(&fields) {
|
||||
return Some(usage);
|
||||
}
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// Sample message reused below: field 1 = varint 150, field 2 = varint 66.
    const TWO_VARINTS: [u8; 5] = [0x08, 0x96, 0x01, 0x10, 0x42];

    /// Append `value` to `out` as a base-128 varint.
    fn push_varint(out: &mut Vec<u8>, mut value: u64) {
        while value >= 0x80 {
            out.push((value as u8) | 0x80);
            value >>= 7;
        }
        out.push(value as u8);
    }

    /// Append a wire-type-0 field (tag followed by varint value) to `out`.
    fn push_varint_field(out: &mut Vec<u8>, field_num: u32, value: u64) {
        push_varint(out, u64::from(field_num << 3)); // wire type 0
        push_varint(out, value);
    }

    #[test]
    fn test_read_varint() {
        let cases: [(&[u8], (u64, usize)); 4] = [
            (&[0x00], (0, 1)),
            (&[0x01], (1, 1)),
            (&[0x96, 0x01], (150, 2)),
            (&[0xAC, 0x02], (300, 2)),
        ];
        for (bytes, expected) in cases.iter() {
            assert_eq!(read_varint(bytes), Some(*expected));
        }
    }

    #[test]
    fn test_extract_grpc_messages_uncompressed() {
        // Frame layout: [0x00] [0x00, 0x00, 0x00, 0x05] [5 bytes data]
        let mut frame = vec![0u8]; // not compressed
        frame.extend_from_slice(&5u32.to_be_bytes());
        frame.extend_from_slice(&TWO_VARINTS);

        let messages = extract_grpc_messages(&frame);
        assert_eq!(messages.len(), 1);
        assert_eq!(messages[0].len(), 5);
    }

    #[test]
    fn test_extract_grpc_messages_compressed() {
        use flate2::write::GzEncoder;
        use flate2::Compression;
        use std::io::Write;

        let payload = TWO_VARINTS.to_vec();

        // Gzip the payload.
        let mut encoder = GzEncoder::new(Vec::new(), Compression::default());
        encoder.write_all(&payload).unwrap();
        let gzipped = encoder.finish().unwrap();

        // Wrap it in a gRPC frame with the compressed flag set.
        let mut frame = vec![1u8];
        frame.extend_from_slice(&(gzipped.len() as u32).to_be_bytes());
        frame.extend_from_slice(&gzipped);

        let messages = extract_grpc_messages(&frame);
        assert_eq!(messages.len(), 1);
        assert_eq!(messages[0], payload);
    }

    #[test]
    fn test_decode_proto_varints() {
        let fields = decode_proto(&TWO_VARINTS);
        assert_eq!(fields.len(), 2);

        assert_eq!(fields[0].number, 1);
        assert!(matches!(fields[0].value, ProtoValue::Varint(150)));

        assert_eq!(fields[1].number, 2);
        assert!(matches!(fields[1].value, ProtoValue::Varint(66)));
    }

    #[test]
    fn test_decode_proto_with_string() {
        // field 1 = "hello" (length-delimited), field 2 = varint 42
        let mut data = vec![
            0x0A, // (1 << 3) | 2
            0x05, // length 5
        ];
        data.extend_from_slice(b"hello");
        data.extend_from_slice(&[
            0x10, // (2 << 3) | 0
            0x2A, // 42
        ]);

        let fields = decode_proto(&data);
        assert!(fields.len() >= 2);
        assert_eq!(fields[0].number, 1);
    }

    #[test]
    fn test_extract_usage_correct_field_numbers() {
        // Mock ModelUsageStats using the real field numbers.
        let layout: [(u32, u64); 7] = [
            (1, 5),    // model enum
            (2, 1000), // input_tokens
            (3, 500),  // output_tokens
            (4, 100),  // cache_write_tokens
            (5, 200),  // cache_read_tokens
            (9, 300),  // thinking_output_tokens
            (10, 200), // response_output_tokens
        ];
        let mut data = Vec::new();
        for &(field_num, value) in layout.iter() {
            push_varint_field(&mut data, field_num, value);
        }

        let fields = decode_proto(&data);
        let usage = try_extract_usage(&fields).expect("should extract usage");

        assert_eq!(usage.input_tokens, 1000);
        assert_eq!(usage.output_tokens, 500);
        assert_eq!(usage.cache_write_tokens, 100);
        assert_eq!(usage.cache_read_tokens, 200);
        assert_eq!(usage.thinking_output_tokens, 300);
        assert_eq!(usage.response_output_tokens, 200);
    }
}
|
||||
591
src/mitm/proxy.rs
Normal file
591
src/mitm/proxy.rs
Normal file
@@ -0,0 +1,591 @@
|
||||
//! MITM proxy server: handles CONNECT tunnels and TLS interception.
|
||||
//!
|
||||
//! Listens on a local port for HTTP CONNECT requests from the LS.
|
||||
//! For intercepted domains, it terminates TLS with our CA-signed cert,
|
||||
//! reads/modifies the request, forwards to the real upstream, and captures
|
||||
//! the response (especially usage data).
|
||||
//!
|
||||
//! For non-intercepted domains, it acts as a transparent TCP tunnel.
|
||||
|
||||
use super::ca::MitmCa;
|
||||
use super::intercept::{
|
||||
extract_cascade_hint, parse_non_streaming_response, parse_streaming_chunk,
|
||||
StreamingAccumulator,
|
||||
};
|
||||
use super::store::MitmStore;
|
||||
use std::sync::Arc;
|
||||
use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
||||
use tokio::net::{TcpListener, TcpStream};
|
||||
use tokio_rustls::TlsAcceptor;
|
||||
use tracing::{debug, error, info, trace, warn};
|
||||
|
||||
/// API hosts whose TLS is terminated so the traffic can be inspected.
///
/// `should_intercept_domain` matches a CONNECT target against each entry
/// exactly, as a `.`-separated subdomain, or as a `-`-joined regional
/// endpoint (e.g. `us-central1-aiplatform.googleapis.com`).
const INTERCEPT_DOMAINS: &[&str] = &[
    "cloudcode-pa.googleapis.com",
    "aiplatform.googleapis.com",
    "api.anthropic.com",
    "speech.googleapis.com",
    "modelarmor.googleapis.com",
];
|
||||
|
||||
/// Hosts that are NEVER intercepted: a CONNECT to exactly one of these names
/// is always relayed as a transparent tunnel.
const PASSTHROUGH_DOMAINS: &[&str] = &[
    "oauth2.googleapis.com",
    "accounts.google.com",
    "storage.googleapis.com",
    "www.googleapis.com",
    "firebaseinstallations.googleapis.com",
    "crashlyticsreports-pa.googleapis.com",
    "play.googleapis.com",
    "update.googleapis.com",
    "dl.google.com",
];
|
||||
|
||||
/// Configuration for the MITM proxy.
///
/// The derived `Default` is port `0` (let the OS pick a free port) with
/// request modification disabled — the same values the previous hand-written
/// `Default` impl produced.
#[derive(Debug, Clone, Default)]
pub struct MitmConfig {
    /// Port to listen on (0 = auto-assign).
    pub port: u16,
    /// Whether to enable request modification.
    pub modify_requests: bool,
}
|
||||
|
||||
/// Run the MITM proxy server.
|
||||
///
|
||||
/// Returns (port, task_handle) — port it's listening on, handle to abort on shutdown.
|
||||
pub async fn run(
|
||||
ca: Arc<MitmCa>,
|
||||
store: MitmStore,
|
||||
config: MitmConfig,
|
||||
) -> Result<(u16, tokio::task::JoinHandle<()>), String> {
|
||||
let listener = TcpListener::bind(format!("127.0.0.1:{}", config.port))
|
||||
.await
|
||||
.map_err(|e| format!("MITM bind failed: {e}"))?;
|
||||
|
||||
let port = listener
|
||||
.local_addr()
|
||||
.map_err(|e| format!("MITM local_addr failed: {e}"))?
|
||||
.port();
|
||||
|
||||
info!(port, "MITM proxy listening");
|
||||
|
||||
let modify_requests = config.modify_requests;
|
||||
|
||||
let handle = tokio::spawn(async move {
|
||||
loop {
|
||||
match listener.accept().await {
|
||||
Ok((stream, addr)) => {
|
||||
trace!(?addr, "MITM: new connection");
|
||||
let ca = ca.clone();
|
||||
let store = store.clone();
|
||||
tokio::spawn(async move {
|
||||
if let Err(e) = handle_connection(stream, ca, store, modify_requests).await {
|
||||
debug!(error = %e, "MITM connection error");
|
||||
}
|
||||
});
|
||||
}
|
||||
Err(e) => {
|
||||
error!(error = %e, "MITM accept error");
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
Ok((port, handle))
|
||||
}
|
||||
|
||||
/// Handle a single incoming connection from the LS.
|
||||
///
|
||||
/// The LS sends an HTTP CONNECT request to establish a tunnel.
|
||||
/// We then decide whether to intercept or passthrough.
|
||||
async fn handle_connection(
|
||||
mut stream: TcpStream,
|
||||
ca: Arc<MitmCa>,
|
||||
store: MitmStore,
|
||||
modify_requests: bool,
|
||||
) -> Result<(), String> {
|
||||
// Read the CONNECT request
|
||||
let mut buf = vec![0u8; 8192];
|
||||
let n = stream
|
||||
.read(&mut buf)
|
||||
.await
|
||||
.map_err(|e| format!("Read CONNECT: {e}"))?;
|
||||
|
||||
if n == 0 {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let request = String::from_utf8_lossy(&buf[..n]);
|
||||
let first_line = request.lines().next().unwrap_or("");
|
||||
|
||||
// Parse "CONNECT host:port HTTP/1.1"
|
||||
let parts: Vec<&str> = first_line.split_whitespace().collect();
|
||||
if parts.len() < 3 || parts[0] != "CONNECT" {
|
||||
// Not a CONNECT request — return 400
|
||||
let resp = "HTTP/1.1 400 Bad Request\r\n\r\n";
|
||||
let _ = stream.write_all(resp.as_bytes()).await;
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let host_port = parts[1];
|
||||
let (domain, _port) = match host_port.rsplit_once(':') {
|
||||
Some((h, p)) => (h, p.parse::<u16>().unwrap_or(443)),
|
||||
None => (host_port, 443),
|
||||
};
|
||||
|
||||
debug!(domain, "MITM: CONNECT request");
|
||||
|
||||
// Decide: intercept or passthrough
|
||||
let should_intercept = should_intercept_domain(domain);
|
||||
|
||||
// Send 200 Connection Established
|
||||
let response = "HTTP/1.1 200 Connection Established\r\n\r\n";
|
||||
stream
|
||||
.write_all(response.as_bytes())
|
||||
.await
|
||||
.map_err(|e| format!("Write 200: {e}"))?;
|
||||
|
||||
if should_intercept {
|
||||
handle_intercepted(stream, domain, ca, store, modify_requests).await
|
||||
} else {
|
||||
handle_passthrough(stream, domain, _port).await
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if a domain should be intercepted.
|
||||
fn should_intercept_domain(domain: &str) -> bool {
|
||||
// Never intercept passthrough domains
|
||||
for &pt in PASSTHROUGH_DOMAINS {
|
||||
if domain == pt {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
// Intercept known API domains (exact match, subdomain, or regional prefix)
|
||||
for &intercept in INTERCEPT_DOMAINS {
|
||||
if domain == intercept
|
||||
|| domain.ends_with(&format!(".{intercept}"))
|
||||
|| domain.ends_with(&format!("-{intercept}"))
|
||||
{
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
// Default: passthrough
|
||||
false
|
||||
}
|
||||
|
||||
/// Handle an intercepted connection: terminate TLS, inspect traffic.
|
||||
///
|
||||
/// After TLS termination, checks the negotiated ALPN protocol:
|
||||
/// - `h2` → HTTP/2 handler (for gRPC traffic to Google APIs)
|
||||
/// - `http/1.1` or none → HTTP/1.1 handler (for REST/SSE traffic)
|
||||
async fn handle_intercepted(
|
||||
stream: TcpStream,
|
||||
domain: &str,
|
||||
ca: Arc<MitmCa>,
|
||||
store: MitmStore,
|
||||
modify_requests: bool,
|
||||
) -> Result<(), String> {
|
||||
info!(domain, "MITM: intercepting TLS");
|
||||
|
||||
// Get or create server TLS config for this domain
|
||||
let server_config = ca
|
||||
.server_config_for_domain(domain)
|
||||
.await?;
|
||||
|
||||
let acceptor = TlsAcceptor::from(server_config);
|
||||
|
||||
// Perform TLS handshake with the client (LS)
|
||||
let tls_stream = acceptor
|
||||
.accept(stream)
|
||||
.await
|
||||
.map_err(|e| format!("TLS handshake with client failed for {domain}: {e}"))?;
|
||||
|
||||
// Check negotiated ALPN protocol
|
||||
let alpn = tls_stream.get_ref().1
|
||||
.alpn_protocol()
|
||||
.map(|p| String::from_utf8_lossy(p).to_string());
|
||||
|
||||
debug!(domain, alpn = ?alpn, "MITM: TLS handshake successful");
|
||||
|
||||
match alpn.as_deref() {
|
||||
Some("h2") => {
|
||||
// HTTP/2 — use the hyper-based gRPC handler
|
||||
info!(domain, "MITM: routing to HTTP/2 handler (gRPC)");
|
||||
super::h2_handler::handle_h2_connection(
|
||||
tls_stream,
|
||||
domain.to_string(),
|
||||
store,
|
||||
)
|
||||
.await
|
||||
}
|
||||
_ => {
|
||||
// HTTP/1.1 or no ALPN — use the existing handler
|
||||
debug!(domain, "MITM: routing to HTTP/1.1 handler");
|
||||
handle_http_over_tls(tls_stream, domain, store, modify_requests).await
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Handle HTTP traffic over the decrypted TLS connection.
|
||||
///
|
||||
/// Loops to handle multiple requests on the same connection (HTTP keep-alive).
|
||||
/// Reads full request, connects to upstream, forwards request, streams response
|
||||
/// back to client while capturing usage data.
|
||||
async fn handle_http_over_tls(
|
||||
mut client: tokio_rustls::server::TlsStream<TcpStream>,
|
||||
domain: &str,
|
||||
store: MitmStore,
|
||||
_modify_requests: bool,
|
||||
) -> Result<(), String> {
|
||||
let mut tmp = vec![0u8; 32768];
|
||||
|
||||
// Build upstream TLS connector once for this connection
|
||||
let mut root_store = rustls::RootCertStore::empty();
|
||||
let native_certs = rustls_native_certs::load_native_certs();
|
||||
for cert in native_certs.certs {
|
||||
let _ = root_store.add(cert);
|
||||
}
|
||||
let upstream_config = Arc::new(
|
||||
rustls::ClientConfig::builder()
|
||||
.with_root_certificates(root_store)
|
||||
.with_no_client_auth(),
|
||||
);
|
||||
|
||||
// Reusable upstream connection — created lazily, reconnected if stale
|
||||
let mut upstream: Option<tokio_rustls::client::TlsStream<TcpStream>> = None;
|
||||
|
||||
/// Connect (or reconnect) to the real upstream via TLS.
|
||||
async fn connect_upstream(
|
||||
domain: &str,
|
||||
config: &Arc<rustls::ClientConfig>,
|
||||
) -> Result<tokio_rustls::client::TlsStream<TcpStream>, String> {
|
||||
let connector = tokio_rustls::TlsConnector::from(config.clone());
|
||||
let tcp = TcpStream::connect(format!("{domain}:443"))
|
||||
.await
|
||||
.map_err(|e| format!("Connect to upstream {domain}: {e}"))?;
|
||||
let server_name = rustls::pki_types::ServerName::try_from(domain.to_string())
|
||||
.map_err(|e| format!("Invalid server name: {e}"))?;
|
||||
connector
|
||||
.connect(server_name, tcp)
|
||||
.await
|
||||
.map_err(|e| format!("TLS connect to upstream {domain}: {e}"))
|
||||
}
|
||||
|
||||
// Keep-alive loop: handle multiple requests on this connection
|
||||
loop {
|
||||
// ── Read the HTTP request from the client ─────────────────────────
|
||||
let mut request_buf = Vec::with_capacity(1024 * 64);
|
||||
|
||||
loop {
|
||||
let n = match client.read(&mut tmp).await {
|
||||
Ok(0) => return Ok(()), // Client closed connection cleanly
|
||||
Ok(n) => n,
|
||||
Err(e) => {
|
||||
// Connection reset / broken pipe is normal for keep-alive end
|
||||
debug!(domain, error = %e, "MITM: client read finished");
|
||||
return Ok(());
|
||||
}
|
||||
};
|
||||
|
||||
request_buf.extend_from_slice(&tmp[..n]);
|
||||
|
||||
// Check if we have the full request (headers + body)
|
||||
if has_complete_http_request(&request_buf) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if request_buf.is_empty() {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// Parse the HTTP request to find headers and body
|
||||
let (headers_end, content_length, is_streaming_request) = parse_http_request_meta(&request_buf);
|
||||
|
||||
// Try to extract cascade hint from request body
|
||||
let cascade_hint = if headers_end < request_buf.len() {
|
||||
extract_cascade_hint(&request_buf[headers_end..])
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
debug!(
|
||||
domain,
|
||||
content_length,
|
||||
streaming = is_streaming_request,
|
||||
cascade = ?cascade_hint,
|
||||
"MITM: forwarding request to upstream"
|
||||
);
|
||||
|
||||
// ── Ensure upstream connection is alive ──────────────────────────────
|
||||
// Lazily connect on first request, or reconnect if the previous connection died
|
||||
let conn = match upstream.as_mut() {
|
||||
Some(c) => c,
|
||||
None => {
|
||||
let c = connect_upstream(domain, &upstream_config).await?;
|
||||
upstream.insert(c)
|
||||
}
|
||||
};
|
||||
|
||||
// Forward the request — if write fails, reconnect and retry once
|
||||
if let Err(e) = conn.write_all(&request_buf).await {
|
||||
debug!(domain, error = %e, "MITM: upstream write failed, reconnecting");
|
||||
let c = connect_upstream(domain, &upstream_config).await?;
|
||||
let conn = upstream.insert(c);
|
||||
conn.write_all(&request_buf)
|
||||
.await
|
||||
.map_err(|e| format!("Write to upstream (retry): {e}"))?;
|
||||
}
|
||||
|
||||
let conn = upstream.as_mut().unwrap();
|
||||
|
||||
// ── Stream response back to client ──────────────────────────────────
|
||||
let mut streaming_acc = StreamingAccumulator::new();
|
||||
let mut is_streaming_response = false;
|
||||
let mut headers_parsed = false;
|
||||
// Only buffer response body for non-streaming (for usage parsing)
|
||||
let mut non_streaming_buf: Option<Vec<u8>> = None;
|
||||
// Track if upstream connection is still usable after this response
|
||||
let mut upstream_ok = true;
|
||||
|
||||
// Per-request timeout: 5 minutes (covers large context API calls)
|
||||
const READ_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(300);
|
||||
|
||||
loop {
|
||||
let n = match tokio::time::timeout(READ_TIMEOUT, conn.read(&mut tmp)).await {
|
||||
Ok(Ok(0)) => {
|
||||
// Upstream closed — connection is no longer reusable
|
||||
upstream_ok = false;
|
||||
break;
|
||||
}
|
||||
Ok(Ok(n)) => n,
|
||||
Ok(Err(e)) => {
|
||||
debug!(domain, error = %e, "MITM: upstream read finished");
|
||||
upstream_ok = false;
|
||||
break;
|
||||
}
|
||||
Err(_) => {
|
||||
warn!(domain, "MITM: upstream read timed out after 5 minutes");
|
||||
upstream_ok = false;
|
||||
break;
|
||||
}
|
||||
};
|
||||
|
||||
let chunk = &tmp[..n];
|
||||
|
||||
// Check response headers for content-type
|
||||
if !headers_parsed {
|
||||
// We need to buffer until we see the end of headers
|
||||
let buf = non_streaming_buf.get_or_insert_with(|| Vec::with_capacity(1024 * 64));
|
||||
buf.extend_from_slice(chunk);
|
||||
if let Some(_hdr_end) = find_headers_end(buf) {
|
||||
// Use httparse for response header parsing
|
||||
let mut resp_headers = [httparse::EMPTY_HEADER; 64];
|
||||
let mut resp = httparse::Response::new(&mut resp_headers);
|
||||
let hdr_end = match resp.parse(buf) {
|
||||
Ok(httparse::Status::Complete(n)) => n,
|
||||
_ => _hdr_end, // Fallback to manual detection
|
||||
};
|
||||
|
||||
// Detect content type and connection handling from parsed headers
|
||||
for header in resp.headers.iter() {
|
||||
if header.name.eq_ignore_ascii_case("content-type") {
|
||||
if let Ok(val) = std::str::from_utf8(header.value) {
|
||||
if val.contains("text/event-stream") {
|
||||
is_streaming_response = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
if header.name.eq_ignore_ascii_case("connection") {
|
||||
if let Ok(val) = std::str::from_utf8(header.value) {
|
||||
if val.trim().eq_ignore_ascii_case("close") {
|
||||
upstream_ok = false;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
headers_parsed = true;
|
||||
|
||||
if is_streaming_response {
|
||||
// For streaming, parse any SSE data already in the buffer
|
||||
let body_so_far = String::from_utf8_lossy(&buf[hdr_end..]);
|
||||
if !body_so_far.is_empty() {
|
||||
parse_streaming_chunk(&body_so_far, &mut streaming_acc);
|
||||
}
|
||||
// Forward the accumulated buffer to client
|
||||
if let Err(e) = client.write_all(buf).await {
|
||||
warn!(error = %e, "MITM: write to client failed");
|
||||
break;
|
||||
}
|
||||
non_streaming_buf = None;
|
||||
continue;
|
||||
}
|
||||
// Non-streaming: keep buffering the response body for parsing
|
||||
continue;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
// If streaming, parse SSE events and forward immediately
|
||||
if is_streaming_response {
|
||||
let chunk_str = String::from_utf8_lossy(chunk);
|
||||
parse_streaming_chunk(&chunk_str, &mut streaming_acc);
|
||||
|
||||
if let Err(e) = client.write_all(chunk).await {
|
||||
warn!(error = %e, "MITM: write to client failed (client disconnected?)");
|
||||
break;
|
||||
}
|
||||
} else {
|
||||
// Non-streaming: keep accumulating to parse usage at the end
|
||||
if let Some(ref mut buf) = non_streaming_buf {
|
||||
buf.extend_from_slice(chunk);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Forward non-streaming response all at once
|
||||
if !is_streaming_response {
|
||||
if let Some(ref buf) = non_streaming_buf {
|
||||
if let Err(e) = client.write_all(buf).await {
|
||||
warn!(error = %e, "MITM: write to client failed");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Capture usage data
|
||||
if is_streaming_response {
|
||||
if streaming_acc.is_complete || streaming_acc.output_tokens > 0 {
|
||||
let usage = streaming_acc.into_usage();
|
||||
store.record_usage(cascade_hint.as_deref(), usage).await;
|
||||
}
|
||||
} else if let Some(ref buf) = non_streaming_buf {
|
||||
if let Some(body_start) = find_headers_end(buf) {
|
||||
let body = &buf[body_start..];
|
||||
if let Some(usage) = parse_non_streaming_response(body) {
|
||||
store.record_usage(cascade_hint.as_deref(), usage).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// If upstream closed, drop the connection so next iteration reconnects
|
||||
if !upstream_ok {
|
||||
upstream = None;
|
||||
}
|
||||
} // end keep-alive loop
|
||||
}
|
||||
|
||||
/// Handle a passthrough connection: transparent TCP tunnel to upstream.
|
||||
async fn handle_passthrough(
|
||||
mut client: TcpStream,
|
||||
domain: &str,
|
||||
port: u16,
|
||||
) -> Result<(), String> {
|
||||
trace!(domain, port, "MITM: transparent tunnel");
|
||||
|
||||
let mut upstream = TcpStream::connect(format!("{domain}:{port}"))
|
||||
.await
|
||||
.map_err(|e| format!("Connect to {domain}:{port}: {e}"))?;
|
||||
|
||||
// Bidirectional copy
|
||||
match tokio::io::copy_bidirectional(&mut client, &mut upstream).await {
|
||||
Ok((client_to_server, server_to_client)) => {
|
||||
trace!(domain, client_to_server, server_to_client, "MITM: tunnel closed");
|
||||
}
|
||||
Err(e) => {
|
||||
trace!(domain, error = %e, "MITM: tunnel error (likely clean close)");
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Check if buffer contains a complete HTTP request (headers + full body).
|
||||
/// Uses `httparse` for zero-copy, case-insensitive header parsing.
|
||||
fn has_complete_http_request(buf: &[u8]) -> bool {
|
||||
let mut headers = [httparse::EMPTY_HEADER; 64];
|
||||
let mut req = httparse::Request::new(&mut headers);
|
||||
|
||||
let headers_end = match req.parse(buf) {
|
||||
Ok(httparse::Status::Complete(n)) => n,
|
||||
_ => return false, // Incomplete or parse error — need more data
|
||||
};
|
||||
|
||||
// Look for Content-Length
|
||||
for header in req.headers.iter() {
|
||||
if header.name.eq_ignore_ascii_case("content-length") {
|
||||
if let Ok(val) = std::str::from_utf8(header.value) {
|
||||
if let Ok(len) = val.trim().parse::<usize>() {
|
||||
return buf.len() >= headers_end + len;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if header.name.eq_ignore_ascii_case("transfer-encoding") {
|
||||
if let Ok(val) = std::str::from_utf8(header.value) {
|
||||
if val.trim().eq_ignore_ascii_case("chunked") {
|
||||
let body = &buf[headers_end..];
|
||||
return body.len() >= 5 && body.ends_with(b"0\r\n\r\n");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// No Content-Length or Transfer-Encoding — no body expected (e.g., GET)
|
||||
true
|
||||
}
|
||||
|
||||
/// Locate the end of an HTTP header section in `buf`.
///
/// Returns the offset of the first byte after the first `\r\n\r\n`, or `None`
/// when that terminator does not occur.
fn find_headers_end(buf: &[u8]) -> Option<usize> {
    const TERMINATOR: &[u8] = b"\r\n\r\n";

    let start = buf
        .windows(TERMINATOR.len())
        .position(|window| window == TERMINATOR)?;
    Some(start + TERMINATOR.len())
}
|
||||
|
||||
/// Parse HTTP request metadata from raw bytes using `httparse`.
|
||||
/// Returns (headers_end, content_length, is_streaming_request).
|
||||
fn parse_http_request_meta(buf: &[u8]) -> (usize, usize, bool) {
|
||||
let mut headers = [httparse::EMPTY_HEADER; 64];
|
||||
let mut req = httparse::Request::new(&mut headers);
|
||||
|
||||
let headers_end = match req.parse(buf) {
|
||||
Ok(httparse::Status::Complete(n)) => n,
|
||||
_ => {
|
||||
// Fallback if httparse can't parse
|
||||
return (find_headers_end(buf).unwrap_or(buf.len()), 0, false);
|
||||
}
|
||||
};
|
||||
|
||||
let mut content_length = 0usize;
|
||||
|
||||
for header in req.headers.iter() {
|
||||
if header.name.eq_ignore_ascii_case("content-length") {
|
||||
if let Ok(val) = std::str::from_utf8(header.value) {
|
||||
content_length = val.trim().parse().unwrap_or(0);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Check if request body asks for streaming
|
||||
let is_streaming = if headers_end < buf.len() {
|
||||
let body_str = String::from_utf8_lossy(&buf[headers_end..]);
|
||||
body_str.contains("\"stream\":true") || body_str.contains("\"stream\": true")
|
||||
} else {
|
||||
false
|
||||
};
|
||||
|
||||
(headers_end, content_length, is_streaming)
|
||||
}
|
||||
|
||||
163
src/mitm/store.rs
Normal file
163
src/mitm/store.rs
Normal file
@@ -0,0 +1,163 @@
|
||||
//! Shared store for intercepted API usage data.
|
||||
//!
|
||||
//! The MITM proxy writes usage data here; the API handlers read from it.
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::RwLock;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tracing::debug;
|
||||
|
||||
/// Token usage from an intercepted API response.
|
||||
///
|
||||
/// Covers both Anthropic JSON/SSE responses and Google gRPC protobuf responses.
|
||||
/// Fields map to the superset of Anthropic's `usage` object and Google's `ModelUsageStats` proto.
|
||||
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
|
||||
pub struct ApiUsage {
|
||||
pub input_tokens: u64,
|
||||
pub output_tokens: u64,
|
||||
/// Anthropic: cache_creation_input_tokens / Google: cache_write_tokens
|
||||
pub cache_creation_input_tokens: u64,
|
||||
/// Anthropic: cache_read_input_tokens / Google: cache_read_tokens
|
||||
pub cache_read_input_tokens: u64,
|
||||
/// Google-specific: thinking/reasoning output tokens (extended thinking)
|
||||
pub thinking_output_tokens: u64,
|
||||
/// Google-specific: response output tokens (non-thinking portion)
|
||||
pub response_output_tokens: u64,
|
||||
/// Total cost in USD (if provided by the API).
|
||||
pub total_cost_usd: Option<f64>,
|
||||
/// The actual model that served the request.
|
||||
pub model: Option<String>,
|
||||
/// Stop reason / finish reason from the API.
|
||||
pub stop_reason: Option<String>,
|
||||
/// API provider (e.g. "anthropic", "google")
|
||||
pub api_provider: Option<String>,
|
||||
/// gRPC method path (e.g. "/google.internal.cloud.code.v1internal.PredictionService/GenerateContent")
|
||||
pub grpc_method: Option<String>,
|
||||
/// Timestamp when this usage was captured.
|
||||
pub captured_at: u64,
|
||||
}
|
||||
|
||||
/// Thread-safe store for intercepted data.
|
||||
///
|
||||
/// Keyed by a unique request ID that we can correlate with cascade operations.
|
||||
/// In practice, we use the cascade ID + a sequence number.
|
||||
#[derive(Clone)]
|
||||
pub struct MitmStore {
|
||||
/// Most recent usage per cascade ID.
|
||||
latest_usage: Arc<RwLock<HashMap<String, ApiUsage>>>,
|
||||
/// Global aggregate stats.
|
||||
stats: Arc<RwLock<MitmStats>>,
|
||||
}
|
||||
|
||||
/// Aggregate statistics across all intercepted traffic.
|
||||
#[derive(Debug, Clone, Default, Serialize)]
|
||||
pub struct MitmStats {
|
||||
pub total_requests: u64,
|
||||
pub total_input_tokens: u64,
|
||||
pub total_output_tokens: u64,
|
||||
pub total_cache_read_tokens: u64,
|
||||
pub total_cache_creation_tokens: u64,
|
||||
pub total_thinking_output_tokens: u64,
|
||||
pub total_response_output_tokens: u64,
|
||||
/// Per-model usage breakdown (model name → stats).
|
||||
pub per_model: HashMap<String, ModelStats>,
|
||||
}
|
||||
|
||||
/// Per-model usage counters.
|
||||
#[derive(Debug, Clone, Default, Serialize)]
|
||||
pub struct ModelStats {
|
||||
pub requests: u64,
|
||||
pub input_tokens: u64,
|
||||
pub output_tokens: u64,
|
||||
pub cache_read_tokens: u64,
|
||||
pub cache_creation_tokens: u64,
|
||||
}
|
||||
|
||||
impl MitmStore {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
latest_usage: Arc::new(RwLock::new(HashMap::new())),
|
||||
stats: Arc::new(RwLock::new(MitmStats::default())),
|
||||
}
|
||||
}
|
||||
|
||||
/// Record a completed API exchange with usage data.
|
||||
pub async fn record_usage(&self, cascade_id: Option<&str>, usage: ApiUsage) {
|
||||
debug!(
|
||||
input = usage.input_tokens,
|
||||
output = usage.output_tokens,
|
||||
cache_read = usage.cache_read_input_tokens,
|
||||
cache_create = usage.cache_creation_input_tokens,
|
||||
thinking = usage.thinking_output_tokens,
|
||||
response = usage.response_output_tokens,
|
||||
model = ?usage.model,
|
||||
provider = ?usage.api_provider,
|
||||
grpc = ?usage.grpc_method,
|
||||
"MITM captured API usage"
|
||||
);
|
||||
|
||||
// Update aggregate stats
|
||||
{
|
||||
let mut stats = self.stats.write().await;
|
||||
stats.total_requests += 1;
|
||||
stats.total_input_tokens += usage.input_tokens;
|
||||
stats.total_output_tokens += usage.output_tokens;
|
||||
stats.total_cache_read_tokens += usage.cache_read_input_tokens;
|
||||
stats.total_cache_creation_tokens += usage.cache_creation_input_tokens;
|
||||
stats.total_thinking_output_tokens += usage.thinking_output_tokens;
|
||||
stats.total_response_output_tokens += usage.response_output_tokens;
|
||||
|
||||
// Per-model breakdown
|
||||
if let Some(ref model_name) = usage.model {
|
||||
let model_stats = stats.per_model.entry(model_name.clone()).or_default();
|
||||
model_stats.requests += 1;
|
||||
model_stats.input_tokens += usage.input_tokens;
|
||||
model_stats.output_tokens += usage.output_tokens;
|
||||
model_stats.cache_read_tokens += usage.cache_read_input_tokens;
|
||||
model_stats.cache_creation_tokens += usage.cache_creation_input_tokens;
|
||||
}
|
||||
}
|
||||
|
||||
// Store latest usage for the cascade (if we can identify it)
|
||||
let key = cascade_id.map(|s| s.to_string()).unwrap_or_else(|| "_latest".to_string());
|
||||
let mut latest = self.latest_usage.write().await;
|
||||
latest.insert(key, usage);
|
||||
|
||||
// Evict old entries to prevent unbounded memory growth
|
||||
const MAX_ENTRIES: usize = 500;
|
||||
if latest.len() > MAX_ENTRIES {
|
||||
// Find the oldest entry by captured_at and remove it
|
||||
let oldest_key = latest
|
||||
.iter()
|
||||
.min_by_key(|(_, v)| v.captured_at)
|
||||
.map(|(k, _)| k.clone());
|
||||
if let Some(key) = oldest_key {
|
||||
latest.remove(&key);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the latest usage for a cascade, consuming it (one-shot read).
|
||||
///
|
||||
/// Only returns exact cascade_id matches — no cross-cascade fallback.
|
||||
/// The `_latest` key is only consumed when the caller explicitly requests it
|
||||
/// (i.e., when the MITM couldn't identify the cascade).
|
||||
pub async fn take_usage(&self, cascade_id: &str) -> Option<ApiUsage> {
|
||||
let mut latest = self.latest_usage.write().await;
|
||||
latest.remove(cascade_id)
|
||||
}
|
||||
|
||||
/// Peek at the latest usage without consuming it.
|
||||
#[allow(dead_code)]
|
||||
pub async fn peek_usage(&self, cascade_id: &str) -> Option<ApiUsage> {
|
||||
let latest = self.latest_usage.read().await;
|
||||
latest.get(cascade_id)
|
||||
.cloned()
|
||||
}
|
||||
|
||||
/// Get aggregate stats.
|
||||
pub async fn stats(&self) -> MitmStats {
|
||||
self.stats.read().await.clone()
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user