//! HTTP/2 handler for gRPC traffic interception. //! //! When the LS negotiates HTTP/2 via ALPN (which all gRPC connections do), //! this module handles the bidirectional HTTP/2 connection: //! 1. Accepts HTTP/2 frames from the client (LS) //! 2. Connects to the real upstream via TLS + HTTP/2 (single connection reused) //! 3. Forwards each request stream to upstream //! 4. For non-streaming: buffers response, extracts usage, forwards //! 5. For streaming: forwards response body chunks in real-time, tees to a //! side buffer for usage extraction after stream completes //! //! ## Streaming vs Non-streaming //! //! gRPC has both unary (non-streaming) and server-streaming RPCs. //! The LS uses server-streaming for methods like `StreamGenerateContent`. //! We MUST forward streaming responses immediately — buffering would break //! the LS's perception of real-time generation. //! //! For usage extraction: ModelUsageStats is typically in the LAST message //! of a streaming response, so we tee the data and parse after stream ends. use crate::mitm::proto::parse_grpc_response_for_usage; use crate::mitm::store::{ApiUsage, MitmStore}; use bytes::Bytes; use http_body_util::{BodyExt, Full, StreamBody}; use hyper::body::{Frame, Incoming}; use hyper::server::conn::http2::Builder as H2ServerBuilder; use hyper::service::service_fn; use hyper::{Request, Response}; use hyper_util::rt::TokioExecutor; use hyper_util::rt::TokioIo; use std::sync::Arc; use tokio::io::{AsyncRead, AsyncWrite}; use tokio::net::TcpStream; use tokio::sync::Mutex; use tracing::{debug, info, trace, warn}; /// A lazily-initialized, shared HTTP/2 connection to the upstream server. /// /// gRPC multiplexes many requests over a single HTTP/2 connection. /// We mirror this by maintaining a single upstream connection per domain. struct UpstreamPool { domain: String, tls_config: Arc, sender: Mutex>>>, } impl UpstreamPool { fn new(domain: String, tls_config: Arc) -> Self { Self { domain, tls_config, sender: Mutex::new(None), } } /// Get or create the upstream HTTP/2 sender. async fn get_sender( &self, ) -> Result>, String> { let mut guard = self.sender.lock().await; // Check if existing sender is still usable if let Some(ref sender) = *guard { if !sender.is_closed() { return Ok(sender.clone()); } debug!(domain = %self.domain, "MITM H2: upstream connection closed, reconnecting"); } // Create new connection let sender = self.connect().await?; *guard = Some(sender.clone()); Ok(sender) } async fn connect( &self, ) -> Result>, String> { let upstream_tcp = TcpStream::connect(format!("{}:443", self.domain)) .await .map_err(|e| format!("upstream TCP connect to {} failed: {e}", self.domain))?; let connector = tokio_rustls::TlsConnector::from(self.tls_config.clone()); let server_name = rustls::pki_types::ServerName::try_from(self.domain.clone()) .map_err(|e| format!("invalid domain {}: {e}", self.domain))?; let upstream_tls = connector .connect(server_name, upstream_tcp) .await .map_err(|e| format!("upstream TLS to {} failed: {e}", self.domain))?; let upstream_io = TokioIo::new(upstream_tls); let (sender, conn) = hyper::client::conn::http2::Builder::new(TokioExecutor::new()) .handshake(upstream_io) .await .map_err(|e| format!("upstream h2 handshake to {} failed: {e}", self.domain))?; let domain = self.domain.clone(); tokio::spawn(async move { if let Err(e) = conn.await { debug!(domain = %domain, error = %e, "MITM H2: upstream connection driver ended"); } }); info!(domain = %self.domain, "MITM H2: established upstream HTTP/2 connection"); Ok(sender) } } /// gRPC methods that carry ModelUsageStats in their responses. const USAGE_METHODS: &[&str] = &[ // Unary methods "GenerateContent", "AsyncGenerateContent", "GenerateChat", "GenerateCode", "CompleteCode", "InternalAtomicAgenticChat", "Predict", "DirectPredict", // Streaming methods "StreamGenerateContent", "StreamAsyncGenerateContent", "StreamGenerateChat", ]; /// Handle an HTTP/2 connection from the LS after TLS termination. /// /// Uses hyper's HTTP/2 server to accept requests and a shared upstream /// HTTP/2 connection to forward them. pub async fn handle_h2_connection( tls_stream: S, domain: String, store: MitmStore, ) -> Result<(), String> where S: AsyncRead + AsyncWrite + Unpin + Send + 'static, { info!(domain = %domain, "MITM H2: handling HTTP/2 connection"); // Build TLS config for upstream connections let mut root_store = rustls::RootCertStore::empty(); let native_certs = rustls_native_certs::load_native_certs(); for cert in native_certs.certs { let _ = root_store.add(cert); } let mut upstream_tls_config = rustls::ClientConfig::builder() .with_root_certificates(root_store) .with_no_client_auth(); upstream_tls_config.alpn_protocols = vec![b"h2".to_vec()]; // Shared upstream connection pool (single connection, multiplexed) let pool = Arc::new(UpstreamPool::new( domain.clone(), Arc::new(upstream_tls_config), )); let io = TokioIo::new(tls_stream); let domain = Arc::new(domain); let result = H2ServerBuilder::new(TokioExecutor::new()) .serve_connection( io, service_fn(move |req: Request| { let domain = domain.clone(); let store = store.clone(); let pool = pool.clone(); async move { handle_h2_request(req, &domain, store, pool).await } }), ) .await; match result { Ok(()) => { debug!("MITM H2: connection closed cleanly"); Ok(()) } Err(e) => { // Connection errors are expected on clean close debug!(error = %e, "MITM H2: connection ended"); Ok(()) } } } /// Response body type — either buffered or streaming. type BoxBody = http_body_util::Either< Full, StreamBody, hyper::Error>>>, >; /// Handle a single HTTP/2 request: forward to upstream, capture usage. /// /// For streaming responses, forwards chunks in real-time while teeing /// data to a side buffer for post-stream usage extraction. async fn handle_h2_request( req: Request, domain: &str, store: MitmStore, pool: Arc, ) -> Result, hyper::Error> { let method = req.method().clone(); let uri = req.uri().clone(); let path = uri.path().to_string(); // Identify gRPC method let is_grpc = req .headers() .get("content-type") .and_then(|v| v.to_str().ok()) .map(|ct| ct.starts_with("application/grpc")) .unwrap_or(false); // Check if this method carries usage data let is_usage_method = is_grpc && USAGE_METHODS.iter().any(|m| path.contains(m)); // Check if this is a streaming method let is_streaming = is_grpc && (path.contains("Stream") || path.contains("stream")); debug!( domain, %method, path = %path, grpc = is_grpc, usage_method = is_usage_method, streaming = is_streaming, "MITM H2: forwarding request" ); // Collect request body (we need it for cascade ID extraction) let (parts, body) = req.into_parts(); let request_body = match body.collect().await { Ok(collected) => collected.to_bytes(), Err(e) => { warn!(error = %e, "MITM H2: failed to collect request body"); Bytes::new() } }; // Get upstream sender from pool let mut upstream_sender = match pool.get_sender().await { Ok(s) => s, Err(e) => { warn!(error = %e, domain, "MITM H2: upstream connect failed"); let resp = Response::builder() .status(502) .body(http_body_util::Either::Left(Full::new( Bytes::from(format!("upstream connect failed: {e}")), ))) .unwrap(); return Ok(resp); } }; // Build the upstream request with proper authority let upstream_uri = http::Uri::builder() .scheme("https") .authority(domain) .path_and_query( uri.path_and_query() .map(|pq| pq.as_str()) .unwrap_or("/"), ) .build() .unwrap_or(uri); let mut upstream_req = Request::builder() .method(parts.method) .uri(upstream_uri); // Copy headers, skip hop-by-hop for (name, value) in &parts.headers { let n = name.as_str(); if n == "host" || n == "connection" || n == "transfer-encoding" { continue; } upstream_req = upstream_req.header(name, value); } let upstream_req = match upstream_req.body(Full::new(request_body.clone())) { Ok(r) => r, Err(e) => { let resp = Response::builder() .status(502) .body(http_body_util::Either::Left(Full::new( Bytes::from(format!("build request failed: {e}")), ))) .unwrap(); return Ok(resp); } }; // Send to upstream let upstream_resp = match upstream_sender.send_request(upstream_req).await { Ok(r) => r, Err(e) => { warn!(error = %e, domain, path = %path, "MITM H2: upstream request failed"); let resp = Response::builder() .status(502) .body(http_body_util::Either::Left(Full::new( Bytes::from(format!("upstream request failed: {e}")), ))) .unwrap(); return Ok(resp); } }; let (resp_parts, resp_body) = upstream_resp.into_parts(); let status = resp_parts.status; // ────────────────────────────────────────────────────────────────── // Streaming path: forward chunks immediately, tee for usage parsing // ────────────────────────────────────────────────────────────────── if is_streaming && status.is_success() { let should_track_usage = is_usage_method; let (tx, rx) = tokio::sync::mpsc::channel::, hyper::Error>>(32); let store_clone = store.clone(); let path_clone = path.clone(); let request_body_clone = request_body.clone(); // Spawn a task to forward body chunks and tee for usage extraction tokio::spawn(async move { let mut tee_buffer = if should_track_usage { Some(Vec::new()) } else { None }; let mut body = resp_body; loop { match body.frame().await { Some(Ok(frame)) => { if let (Some(ref mut buf), Some(data)) = (&mut tee_buffer, frame.data_ref()) { buf.extend_from_slice(data); } if tx.send(Ok(frame)).await.is_err() { break; // client disconnected } } Some(Err(e)) => { warn!(error = %e, path = %path_clone, "MITM H2: streaming error"); let _ = tx.send(Err(e)).await; break; } None => break, // stream ended } } // Stream completed — parse the tee buffer for usage if let Some(tee_buffer) = tee_buffer { if !tee_buffer.is_empty() { if let Some(grpc_usage) = parse_grpc_response_for_usage(&tee_buffer) { let usage = ApiUsage { input_tokens: grpc_usage.input_tokens, output_tokens: grpc_usage.output_tokens, thinking_output_tokens: grpc_usage.thinking_output_tokens, response_output_tokens: grpc_usage.response_output_tokens, cache_creation_input_tokens: grpc_usage.cache_write_tokens, cache_read_input_tokens: grpc_usage.cache_read_tokens, model: grpc_usage.model, api_provider: grpc_usage.api_provider, grpc_method: Some(path_clone.clone()), stop_reason: None, total_cost_usd: None, captured_at: std::time::SystemTime::now() .duration_since(std::time::UNIX_EPOCH) .unwrap_or_default() .as_secs(), }; let cascade_hint = extract_cascade_from_grpc_request(&request_body_clone); store_clone.record_usage(cascade_hint.as_deref(), usage).await; } } } }); let stream = tokio_stream::wrappers::ReceiverStream::new(rx); let stream_body = StreamBody::new(stream); let mut client_resp = Response::builder().status(resp_parts.status); for (name, value) in &resp_parts.headers { client_resp = client_resp.header(name, value); } let client_resp = client_resp .body(http_body_util::Either::Right(stream_body)) .unwrap_or_else(|_| { Response::builder() .status(500) .body(http_body_util::Either::Left(Full::new(Bytes::from( "internal error", )))) .unwrap() }); return Ok(client_resp); } // ────────────────────────────────────────────────────────────────── // Non-streaming path: buffer full response, extract usage, forward // ────────────────────────────────────────────────────────────────── let response_body = match resp_body.collect().await { Ok(collected) => collected.to_bytes(), Err(e) => { warn!(error = %e, "MITM H2: failed to collect response body"); Bytes::new() } }; trace!( domain, path = %path, status = %status, body_len = response_body.len(), "MITM H2: got upstream response" ); // Extract usage data from usage-carrying gRPC methods if is_usage_method && !response_body.is_empty() && status.is_success() { if let Some(grpc_usage) = parse_grpc_response_for_usage(&response_body) { let usage = ApiUsage { input_tokens: grpc_usage.input_tokens, output_tokens: grpc_usage.output_tokens, thinking_output_tokens: grpc_usage.thinking_output_tokens, response_output_tokens: grpc_usage.response_output_tokens, cache_creation_input_tokens: grpc_usage.cache_write_tokens, cache_read_input_tokens: grpc_usage.cache_read_tokens, model: grpc_usage.model, api_provider: grpc_usage.api_provider, grpc_method: Some(path.clone()), stop_reason: None, total_cost_usd: None, captured_at: std::time::SystemTime::now() .duration_since(std::time::UNIX_EPOCH) .unwrap_or_default() .as_secs(), }; let cascade_hint = extract_cascade_from_grpc_request(&request_body); store.record_usage(cascade_hint.as_deref(), usage).await; } } // Build response for the client let mut client_resp = Response::builder().status(resp_parts.status); for (name, value) in &resp_parts.headers { client_resp = client_resp.header(name, value); } let client_resp = client_resp .body(http_body_util::Either::Left(Full::new(response_body))) .unwrap_or_else(|_| { Response::builder() .status(500) .body(http_body_util::Either::Left(Full::new(Bytes::from( "internal error", )))) .unwrap() }); Ok(client_resp) } /// Try to extract a cascade ID from a gRPC request body. /// /// Looks for UUID-formatted strings in the protobuf fields. fn extract_cascade_from_grpc_request(body: &[u8]) -> Option { use crate::mitm::proto::{decode_proto, extract_grpc_messages}; let messages = extract_grpc_messages(body); for msg in &messages { let fields = decode_proto(msg); for field in &fields { if let Some(id) = extract_uuid_from_field(field) { return Some(id); } } } None } fn extract_uuid_from_field(field: &crate::mitm::proto::ProtoField) -> Option { use crate::mitm::proto::ProtoValue; match &field.value { ProtoValue::Bytes(b) => { if let Ok(s) = std::str::from_utf8(b) { if is_uuid(s) { return Some(s.to_string()); } } } ProtoValue::Message(nested) => { for nf in nested { if let Some(id) = extract_uuid_from_field(nf) { return Some(id); } } } _ => {} } None } fn is_uuid(s: &str) -> bool { s.len() == 36 && s.chars().all(|c| c.is_ascii_hexdigit() || c == '-') && s.chars().filter(|&c| c == '-').count() == 4 }