Files
zerogravity/src/mitm/h2_handler.rs

513 lines
18 KiB
Rust

//! HTTP/2 handler for gRPC traffic interception.
//!
//! When the LS negotiates HTTP/2 via ALPN (which all gRPC connections do),
//! this module handles the bidirectional HTTP/2 connection:
//! 1. Accepts HTTP/2 frames from the client (LS)
//! 2. Connects to the real upstream via TLS + HTTP/2 (single connection reused)
//! 3. Forwards each request stream to upstream
//! 4. For non-streaming: buffers response, extracts usage, forwards
//! 5. For streaming: forwards response body chunks in real-time, tees to a
//! side buffer for usage extraction after stream completes
//!
//! ## Streaming vs Non-streaming
//!
//! gRPC has both unary (non-streaming) and server-streaming RPCs.
//! The LS uses server-streaming for methods like `StreamGenerateContent`.
//! We MUST forward streaming responses immediately — buffering would break
//! the LS's perception of real-time generation.
//!
//! For usage extraction: ModelUsageStats is typically in the LAST message
//! of a streaming response, so we tee the data and parse after stream ends.
use crate::mitm::proto::parse_grpc_response_for_usage;
use crate::mitm::store::{ApiUsage, MitmStore};
use bytes::Bytes;
use http_body_util::{BodyExt, Full, StreamBody};
use hyper::body::{Frame, Incoming};
use hyper::server::conn::http2::Builder as H2ServerBuilder;
use hyper::service::service_fn;
use hyper::{Request, Response};
use hyper_util::rt::TokioExecutor;
use hyper_util::rt::TokioIo;
use std::sync::Arc;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::net::TcpStream;
use tokio::sync::Mutex;
use tracing::{debug, info, trace, warn};
/// A lazily-initialized, shared HTTP/2 connection to the upstream server.
///
/// gRPC multiplexes many requests over a single HTTP/2 connection.
/// We mirror this by maintaining a single upstream connection per domain.
///
/// One pool is created by `handle_h2_connection` for each inbound connection
/// it serves; every request arriving on that connection shares the pool.
struct UpstreamPool {
/// Upstream host name. Used as the TCP target (`<domain>:443`), as the TLS
/// server name, and as a field in log messages.
domain: String,
/// Client-side TLS configuration handed to the TLS connector on each
/// (re)connect.
tls_config: Arc<rustls::ClientConfig>,
/// Cached request sender for the current upstream connection. `None` until
/// the first `get_sender` call; replaced when the cached sender reports
/// that it is closed.
sender: Mutex<Option<hyper::client::conn::http2::SendRequest<Full<Bytes>>>>,
}
impl UpstreamPool {
fn new(domain: String, tls_config: Arc<rustls::ClientConfig>) -> Self {
Self {
domain,
tls_config,
sender: Mutex::new(None),
}
}
/// Get or create the upstream HTTP/2 sender.
async fn get_sender(
&self,
) -> Result<hyper::client::conn::http2::SendRequest<Full<Bytes>>, String> {
let mut guard = self.sender.lock().await;
// Check if existing sender is still usable
if let Some(ref sender) = *guard {
if !sender.is_closed() {
return Ok(sender.clone());
}
debug!(domain = %self.domain, "MITM H2: upstream connection closed, reconnecting");
}
// Create new connection
let sender = self.connect().await?;
*guard = Some(sender.clone());
Ok(sender)
}
async fn connect(
&self,
) -> Result<hyper::client::conn::http2::SendRequest<Full<Bytes>>, String> {
let upstream_tcp = TcpStream::connect(format!("{}:443", self.domain))
.await
.map_err(|e| format!("upstream TCP connect to {} failed: {e}", self.domain))?;
let connector = tokio_rustls::TlsConnector::from(self.tls_config.clone());
let server_name = rustls::pki_types::ServerName::try_from(self.domain.clone())
.map_err(|e| format!("invalid domain {}: {e}", self.domain))?;
let upstream_tls = connector
.connect(server_name, upstream_tcp)
.await
.map_err(|e| format!("upstream TLS to {} failed: {e}", self.domain))?;
let upstream_io = TokioIo::new(upstream_tls);
let (sender, conn) =
hyper::client::conn::http2::Builder::new(TokioExecutor::new())
.handshake(upstream_io)
.await
.map_err(|e| format!("upstream h2 handshake to {} failed: {e}", self.domain))?;
let domain = self.domain.clone();
tokio::spawn(async move {
if let Err(e) = conn.await {
debug!(domain = %domain, error = %e, "MITM H2: upstream connection driver ended");
}
});
info!(domain = %self.domain, "MITM H2: established upstream HTTP/2 connection");
Ok(sender)
}
}
/// gRPC methods that carry ModelUsageStats in their responses.
///
/// Entries are matched as substrings of the request path (`path.contains`),
/// not as exact method names. As a consequence the `Stream*` entries are
/// already covered by their unary counterparts (for example
/// `StreamGenerateContent` contains `GenerateContent`); they are listed
/// explicitly for readability.
const USAGE_METHODS: &[&str] = &[
// Unary methods
"GenerateContent",
"AsyncGenerateContent",
"GenerateChat",
"GenerateCode",
"CompleteCode",
"InternalAtomicAgenticChat",
"Predict",
"DirectPredict",
// Streaming methods
"StreamGenerateContent",
"StreamAsyncGenerateContent",
"StreamGenerateChat",
];
/// Handle an HTTP/2 connection from the LS after TLS termination.
///
/// Uses hyper's HTTP/2 server to accept requests and a shared upstream
/// HTTP/2 connection to forward them.
pub async fn handle_h2_connection<S>(
tls_stream: S,
domain: String,
store: MitmStore,
) -> Result<(), String>
where
S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
{
info!(domain = %domain, "MITM H2: handling HTTP/2 connection");
// Build TLS config for upstream connections
let mut root_store = rustls::RootCertStore::empty();
let native_certs = rustls_native_certs::load_native_certs();
for cert in native_certs.certs {
let _ = root_store.add(cert);
}
let mut upstream_tls_config = rustls::ClientConfig::builder()
.with_root_certificates(root_store)
.with_no_client_auth();
upstream_tls_config.alpn_protocols = vec![b"h2".to_vec()];
// Shared upstream connection pool (single connection, multiplexed)
let pool = Arc::new(UpstreamPool::new(
domain.clone(),
Arc::new(upstream_tls_config),
));
let io = TokioIo::new(tls_stream);
let domain = Arc::new(domain);
let result = H2ServerBuilder::new(TokioExecutor::new())
.serve_connection(
io,
service_fn(move |req: Request<Incoming>| {
let domain = domain.clone();
let store = store.clone();
let pool = pool.clone();
async move { handle_h2_request(req, &domain, store, pool).await }
}),
)
.await;
match result {
Ok(()) => {
debug!("MITM H2: connection closed cleanly");
Ok(())
}
Err(e) => {
// Connection errors are expected on clean close
debug!(error = %e, "MITM H2: connection ended");
Ok(())
}
}
}
/// Response body type — either buffered or streaming.
///
/// - `Left`: a fully buffered body (`Full<Bytes>`), sent in one piece.
/// - `Right`: a body whose frames (data and trailers) are delivered through
///   a tokio mpsc channel wrapped in `ReceiverStream`.
type BoxBody = http_body_util::Either<
Full<Bytes>,
StreamBody<tokio_stream::wrappers::ReceiverStream<Result<Frame<Bytes>, hyper::Error>>>,
>;
/// Handle a single HTTP/2 request: forward to upstream, capture usage.
///
/// For streaming responses, forwards chunks in real-time while teeing
/// data to a side buffer for post-stream usage extraction.
async fn handle_h2_request(
req: Request<Incoming>,
domain: &str,
store: MitmStore,
pool: Arc<UpstreamPool>,
) -> Result<Response<BoxBody>, hyper::Error> {
let method = req.method().clone();
let uri = req.uri().clone();
let path = uri.path().to_string();
// Identify gRPC method
let is_grpc = req
.headers()
.get("content-type")
.and_then(|v| v.to_str().ok())
.map(|ct| ct.starts_with("application/grpc"))
.unwrap_or(false);
// Check if this method carries usage data
let is_usage_method = is_grpc
&& USAGE_METHODS.iter().any(|m| path.contains(m));
// Check if this is a streaming method
let is_streaming = is_grpc
&& (path.contains("Stream") || path.contains("stream"));
debug!(
domain,
%method,
path = %path,
grpc = is_grpc,
usage_method = is_usage_method,
streaming = is_streaming,
"MITM H2: forwarding request"
);
// Collect request body (we need it for cascade ID extraction)
let (parts, body) = req.into_parts();
let request_body = match body.collect().await {
Ok(collected) => collected.to_bytes(),
Err(e) => {
warn!(error = %e, "MITM H2: failed to collect request body");
Bytes::new()
}
};
// Get upstream sender from pool
let mut upstream_sender = match pool.get_sender().await {
Ok(s) => s,
Err(e) => {
warn!(error = %e, domain, "MITM H2: upstream connect failed");
let resp = Response::builder()
.status(502)
.body(http_body_util::Either::Left(Full::new(
Bytes::from(format!("upstream connect failed: {e}")),
)))
.unwrap();
return Ok(resp);
}
};
// Build the upstream request with proper authority
let upstream_uri = http::Uri::builder()
.scheme("https")
.authority(domain)
.path_and_query(
uri.path_and_query()
.map(|pq| pq.as_str())
.unwrap_or("/"),
)
.build()
.unwrap_or(uri);
let mut upstream_req = Request::builder()
.method(parts.method)
.uri(upstream_uri);
// Copy headers, skip hop-by-hop
for (name, value) in &parts.headers {
let n = name.as_str();
if n == "host" || n == "connection" || n == "transfer-encoding" {
continue;
}
upstream_req = upstream_req.header(name, value);
}
let upstream_req = match upstream_req.body(Full::new(request_body.clone())) {
Ok(r) => r,
Err(e) => {
let resp = Response::builder()
.status(502)
.body(http_body_util::Either::Left(Full::new(
Bytes::from(format!("build request failed: {e}")),
)))
.unwrap();
return Ok(resp);
}
};
// Send to upstream
let upstream_resp = match upstream_sender.send_request(upstream_req).await {
Ok(r) => r,
Err(e) => {
warn!(error = %e, domain, path = %path, "MITM H2: upstream request failed");
let resp = Response::builder()
.status(502)
.body(http_body_util::Either::Left(Full::new(
Bytes::from(format!("upstream request failed: {e}")),
)))
.unwrap();
return Ok(resp);
}
};
let (resp_parts, resp_body) = upstream_resp.into_parts();
let status = resp_parts.status;
// ──────────────────────────────────────────────────────────────────
// Streaming path: forward chunks immediately, tee for usage parsing
// ──────────────────────────────────────────────────────────────────
if is_streaming && status.is_success() {
let should_track_usage = is_usage_method;
let (tx, rx) = tokio::sync::mpsc::channel::<Result<Frame<Bytes>, hyper::Error>>(32);
let store_clone = store.clone();
let path_clone = path.clone();
let request_body_clone = request_body.clone();
// Spawn a task to forward body chunks and tee for usage extraction
tokio::spawn(async move {
let mut tee_buffer = if should_track_usage { Some(Vec::new()) } else { None };
let mut body = resp_body;
loop {
match body.frame().await {
Some(Ok(frame)) => {
if let (Some(ref mut buf), Some(data)) = (&mut tee_buffer, frame.data_ref()) {
buf.extend_from_slice(data);
}
if tx.send(Ok(frame)).await.is_err() {
break; // client disconnected
}
}
Some(Err(e)) => {
warn!(error = %e, path = %path_clone, "MITM H2: streaming error");
let _ = tx.send(Err(e)).await;
break;
}
None => break, // stream ended
}
}
// Stream completed — parse the tee buffer for usage
if let Some(tee_buffer) = tee_buffer {
if !tee_buffer.is_empty() {
if let Some(grpc_usage) = parse_grpc_response_for_usage(&tee_buffer) {
let usage = ApiUsage {
input_tokens: grpc_usage.input_tokens,
output_tokens: grpc_usage.output_tokens,
thinking_output_tokens: grpc_usage.thinking_output_tokens,
response_output_tokens: grpc_usage.response_output_tokens,
cache_creation_input_tokens: grpc_usage.cache_write_tokens,
cache_read_input_tokens: grpc_usage.cache_read_tokens,
model: grpc_usage.model,
api_provider: grpc_usage.api_provider,
grpc_method: Some(path_clone.clone()),
stop_reason: None,
total_cost_usd: None,
captured_at: std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_default()
.as_secs(),
};
let cascade_hint = extract_cascade_from_grpc_request(&request_body_clone);
store_clone.record_usage(cascade_hint.as_deref(), usage).await;
}
}
}
});
let stream = tokio_stream::wrappers::ReceiverStream::new(rx);
let stream_body = StreamBody::new(stream);
let mut client_resp = Response::builder().status(resp_parts.status);
for (name, value) in &resp_parts.headers {
client_resp = client_resp.header(name, value);
}
let client_resp = client_resp
.body(http_body_util::Either::Right(stream_body))
.unwrap_or_else(|_| {
Response::builder()
.status(500)
.body(http_body_util::Either::Left(Full::new(Bytes::from(
"internal error",
))))
.unwrap()
});
return Ok(client_resp);
}
// ──────────────────────────────────────────────────────────────────
// Non-streaming path: buffer full response, extract usage, forward
// ──────────────────────────────────────────────────────────────────
let response_body = match resp_body.collect().await {
Ok(collected) => collected.to_bytes(),
Err(e) => {
warn!(error = %e, "MITM H2: failed to collect response body");
Bytes::new()
}
};
trace!(
domain,
path = %path,
status = %status,
body_len = response_body.len(),
"MITM H2: got upstream response"
);
// Extract usage data from usage-carrying gRPC methods
if is_usage_method && !response_body.is_empty() && status.is_success() {
if let Some(grpc_usage) = parse_grpc_response_for_usage(&response_body) {
let usage = ApiUsage {
input_tokens: grpc_usage.input_tokens,
output_tokens: grpc_usage.output_tokens,
thinking_output_tokens: grpc_usage.thinking_output_tokens,
response_output_tokens: grpc_usage.response_output_tokens,
cache_creation_input_tokens: grpc_usage.cache_write_tokens,
cache_read_input_tokens: grpc_usage.cache_read_tokens,
model: grpc_usage.model,
api_provider: grpc_usage.api_provider,
grpc_method: Some(path.clone()),
stop_reason: None,
total_cost_usd: None,
captured_at: std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_default()
.as_secs(),
};
let cascade_hint = extract_cascade_from_grpc_request(&request_body);
store.record_usage(cascade_hint.as_deref(), usage).await;
}
}
// Build response for the client
let mut client_resp = Response::builder().status(resp_parts.status);
for (name, value) in &resp_parts.headers {
client_resp = client_resp.header(name, value);
}
let client_resp = client_resp
.body(http_body_util::Either::Left(Full::new(response_body)))
.unwrap_or_else(|_| {
Response::builder()
.status(500)
.body(http_body_util::Either::Left(Full::new(Bytes::from(
"internal error",
))))
.unwrap()
});
Ok(client_resp)
}
/// Scans a gRPC request body for a cascade ID.
///
/// The body is split into gRPC messages, each message is decoded into
/// protobuf fields, and the first UUID-formatted string found (searching
/// messages and fields in order, descending into nested messages) is
/// returned. Returns `None` when no field holds a UUID.
fn extract_cascade_from_grpc_request(body: &[u8]) -> Option<String> {
    use crate::mitm::proto::{decode_proto, extract_grpc_messages};

    let frames = extract_grpc_messages(body);
    // `find_map` is lazy: later messages are not decoded once an ID is found.
    frames.iter().find_map(|frame| {
        let decoded = decode_proto(frame);
        decoded
            .iter()
            .find_map(|candidate| extract_uuid_from_field(candidate))
    })
}
/// Returns the first UUID-formatted string held by `field`.
///
/// A bytes field matches when its content is valid UTF-8 and passes
/// `is_uuid`. A nested message is searched depth-first, field by field.
/// Every other value kind yields `None`.
fn extract_uuid_from_field(field: &crate::mitm::proto::ProtoField) -> Option<String> {
    use crate::mitm::proto::ProtoValue;

    match &field.value {
        ProtoValue::Bytes(raw) => std::str::from_utf8(raw)
            .ok()
            .filter(|text| is_uuid(text))
            .map(|text| text.to_string()),
        ProtoValue::Message(children) => children
            .iter()
            .find_map(|child| extract_uuid_from_field(child)),
        _ => None,
    }
}
/// Reports whether `s` has the canonical textual UUID layout
/// `xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx`.
///
/// The string must be exactly 36 bytes long, with `-` at byte offsets 8, 13,
/// 18 and 23 and ASCII hex digits (either case) everywhere else. Checking the
/// dash positions, not just their count, rejects look-alikes such as a
/// 36-character hex string whose four dashes sit in the wrong places.
fn is_uuid(s: &str) -> bool {
    let bytes = s.as_bytes();
    bytes.len() == 36
        && bytes.iter().enumerate().all(|(idx, &b)| match idx {
            8 | 13 | 18 | 23 => b == b'-',
            _ => b.is_ascii_hexdigit(),
        })
}