feat: initial commit — antigravity proxy with MITM, standalone LS, and snapshot tooling
This commit is contained in:
512
src/mitm/h2_handler.rs
Normal file
512
src/mitm/h2_handler.rs
Normal file
@@ -0,0 +1,512 @@
|
||||
//! HTTP/2 handler for gRPC traffic interception.
|
||||
//!
|
||||
//! When the LS negotiates HTTP/2 via ALPN (which all gRPC connections do),
|
||||
//! this module handles the bidirectional HTTP/2 connection:
|
||||
//! 1. Accepts HTTP/2 frames from the client (LS)
|
||||
//! 2. Connects to the real upstream via TLS + HTTP/2 (single connection reused)
|
||||
//! 3. Forwards each request stream to upstream
|
||||
//! 4. For non-streaming: buffers response, extracts usage, forwards
|
||||
//! 5. For streaming: forwards response body chunks in real-time, tees to a
|
||||
//! side buffer for usage extraction after stream completes
|
||||
//!
|
||||
//! ## Streaming vs Non-streaming
|
||||
//!
|
||||
//! gRPC has both unary (non-streaming) and server-streaming RPCs.
|
||||
//! The LS uses server-streaming for methods like `StreamGenerateContent`.
|
||||
//! We MUST forward streaming responses immediately — buffering would break
|
||||
//! the LS's perception of real-time generation.
|
||||
//!
|
||||
//! For usage extraction: ModelUsageStats is typically in the LAST message
|
||||
//! of a streaming response, so we tee the data and parse after stream ends.
|
||||
|
||||
use crate::mitm::proto::parse_grpc_response_for_usage;
|
||||
use crate::mitm::store::{ApiUsage, MitmStore};
|
||||
|
||||
use bytes::Bytes;
|
||||
use http_body_util::{BodyExt, Full, StreamBody};
|
||||
use hyper::body::{Frame, Incoming};
|
||||
use hyper::server::conn::http2::Builder as H2ServerBuilder;
|
||||
use hyper::service::service_fn;
|
||||
use hyper::{Request, Response};
|
||||
use hyper_util::rt::TokioExecutor;
|
||||
use hyper_util::rt::TokioIo;
|
||||
use std::sync::Arc;
|
||||
use tokio::io::{AsyncRead, AsyncWrite};
|
||||
use tokio::net::TcpStream;
|
||||
use tokio::sync::Mutex;
|
||||
use tracing::{debug, info, trace, warn};
|
||||
|
||||
/// A lazily-initialized, shared HTTP/2 connection to the upstream server.
|
||||
///
|
||||
/// gRPC multiplexes many requests over a single HTTP/2 connection.
|
||||
/// We mirror this by maintaining a single upstream connection per domain.
|
||||
struct UpstreamPool {
|
||||
domain: String,
|
||||
tls_config: Arc<rustls::ClientConfig>,
|
||||
sender: Mutex<Option<hyper::client::conn::http2::SendRequest<Full<Bytes>>>>,
|
||||
}
|
||||
|
||||
impl UpstreamPool {
|
||||
fn new(domain: String, tls_config: Arc<rustls::ClientConfig>) -> Self {
|
||||
Self {
|
||||
domain,
|
||||
tls_config,
|
||||
sender: Mutex::new(None),
|
||||
}
|
||||
}
|
||||
|
||||
/// Get or create the upstream HTTP/2 sender.
|
||||
async fn get_sender(
|
||||
&self,
|
||||
) -> Result<hyper::client::conn::http2::SendRequest<Full<Bytes>>, String> {
|
||||
let mut guard = self.sender.lock().await;
|
||||
|
||||
// Check if existing sender is still usable
|
||||
if let Some(ref sender) = *guard {
|
||||
if !sender.is_closed() {
|
||||
return Ok(sender.clone());
|
||||
}
|
||||
debug!(domain = %self.domain, "MITM H2: upstream connection closed, reconnecting");
|
||||
}
|
||||
|
||||
// Create new connection
|
||||
let sender = self.connect().await?;
|
||||
*guard = Some(sender.clone());
|
||||
Ok(sender)
|
||||
}
|
||||
|
||||
async fn connect(
|
||||
&self,
|
||||
) -> Result<hyper::client::conn::http2::SendRequest<Full<Bytes>>, String> {
|
||||
let upstream_tcp = TcpStream::connect(format!("{}:443", self.domain))
|
||||
.await
|
||||
.map_err(|e| format!("upstream TCP connect to {} failed: {e}", self.domain))?;
|
||||
|
||||
let connector = tokio_rustls::TlsConnector::from(self.tls_config.clone());
|
||||
let server_name = rustls::pki_types::ServerName::try_from(self.domain.clone())
|
||||
.map_err(|e| format!("invalid domain {}: {e}", self.domain))?;
|
||||
|
||||
let upstream_tls = connector
|
||||
.connect(server_name, upstream_tcp)
|
||||
.await
|
||||
.map_err(|e| format!("upstream TLS to {} failed: {e}", self.domain))?;
|
||||
|
||||
let upstream_io = TokioIo::new(upstream_tls);
|
||||
let (sender, conn) =
|
||||
hyper::client::conn::http2::Builder::new(TokioExecutor::new())
|
||||
.handshake(upstream_io)
|
||||
.await
|
||||
.map_err(|e| format!("upstream h2 handshake to {} failed: {e}", self.domain))?;
|
||||
|
||||
let domain = self.domain.clone();
|
||||
tokio::spawn(async move {
|
||||
if let Err(e) = conn.await {
|
||||
debug!(domain = %domain, error = %e, "MITM H2: upstream connection driver ended");
|
||||
}
|
||||
});
|
||||
|
||||
info!(domain = %self.domain, "MITM H2: established upstream HTTP/2 connection");
|
||||
Ok(sender)
|
||||
}
|
||||
}
|
||||
|
||||
/// gRPC methods that carry ModelUsageStats in their responses.
|
||||
const USAGE_METHODS: &[&str] = &[
|
||||
// Unary methods
|
||||
"GenerateContent",
|
||||
"AsyncGenerateContent",
|
||||
"GenerateChat",
|
||||
"GenerateCode",
|
||||
"CompleteCode",
|
||||
"InternalAtomicAgenticChat",
|
||||
"Predict",
|
||||
"DirectPredict",
|
||||
// Streaming methods
|
||||
"StreamGenerateContent",
|
||||
"StreamAsyncGenerateContent",
|
||||
"StreamGenerateChat",
|
||||
];
|
||||
|
||||
/// Handle an HTTP/2 connection from the LS after TLS termination.
|
||||
///
|
||||
/// Uses hyper's HTTP/2 server to accept requests and a shared upstream
|
||||
/// HTTP/2 connection to forward them.
|
||||
pub async fn handle_h2_connection<S>(
|
||||
tls_stream: S,
|
||||
domain: String,
|
||||
store: MitmStore,
|
||||
) -> Result<(), String>
|
||||
where
|
||||
S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
|
||||
{
|
||||
info!(domain = %domain, "MITM H2: handling HTTP/2 connection");
|
||||
|
||||
// Build TLS config for upstream connections
|
||||
let mut root_store = rustls::RootCertStore::empty();
|
||||
let native_certs = rustls_native_certs::load_native_certs();
|
||||
for cert in native_certs.certs {
|
||||
let _ = root_store.add(cert);
|
||||
}
|
||||
let mut upstream_tls_config = rustls::ClientConfig::builder()
|
||||
.with_root_certificates(root_store)
|
||||
.with_no_client_auth();
|
||||
upstream_tls_config.alpn_protocols = vec![b"h2".to_vec()];
|
||||
|
||||
// Shared upstream connection pool (single connection, multiplexed)
|
||||
let pool = Arc::new(UpstreamPool::new(
|
||||
domain.clone(),
|
||||
Arc::new(upstream_tls_config),
|
||||
));
|
||||
|
||||
let io = TokioIo::new(tls_stream);
|
||||
let domain = Arc::new(domain);
|
||||
|
||||
let result = H2ServerBuilder::new(TokioExecutor::new())
|
||||
.serve_connection(
|
||||
io,
|
||||
service_fn(move |req: Request<Incoming>| {
|
||||
let domain = domain.clone();
|
||||
let store = store.clone();
|
||||
let pool = pool.clone();
|
||||
async move { handle_h2_request(req, &domain, store, pool).await }
|
||||
}),
|
||||
)
|
||||
.await;
|
||||
|
||||
match result {
|
||||
Ok(()) => {
|
||||
debug!("MITM H2: connection closed cleanly");
|
||||
Ok(())
|
||||
}
|
||||
Err(e) => {
|
||||
// Connection errors are expected on clean close
|
||||
debug!(error = %e, "MITM H2: connection ended");
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Response body type — either buffered or streaming.
|
||||
type BoxBody = http_body_util::Either<
|
||||
Full<Bytes>,
|
||||
StreamBody<tokio_stream::wrappers::ReceiverStream<Result<Frame<Bytes>, hyper::Error>>>,
|
||||
>;
|
||||
|
||||
/// Handle a single HTTP/2 request: forward to upstream, capture usage.
|
||||
///
|
||||
/// For streaming responses, forwards chunks in real-time while teeing
|
||||
/// data to a side buffer for post-stream usage extraction.
|
||||
async fn handle_h2_request(
|
||||
req: Request<Incoming>,
|
||||
domain: &str,
|
||||
store: MitmStore,
|
||||
pool: Arc<UpstreamPool>,
|
||||
) -> Result<Response<BoxBody>, hyper::Error> {
|
||||
let method = req.method().clone();
|
||||
let uri = req.uri().clone();
|
||||
let path = uri.path().to_string();
|
||||
|
||||
// Identify gRPC method
|
||||
let is_grpc = req
|
||||
.headers()
|
||||
.get("content-type")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.map(|ct| ct.starts_with("application/grpc"))
|
||||
.unwrap_or(false);
|
||||
|
||||
// Check if this method carries usage data
|
||||
let is_usage_method = is_grpc
|
||||
&& USAGE_METHODS.iter().any(|m| path.contains(m));
|
||||
|
||||
// Check if this is a streaming method
|
||||
let is_streaming = is_grpc
|
||||
&& (path.contains("Stream") || path.contains("stream"));
|
||||
|
||||
debug!(
|
||||
domain,
|
||||
%method,
|
||||
path = %path,
|
||||
grpc = is_grpc,
|
||||
usage_method = is_usage_method,
|
||||
streaming = is_streaming,
|
||||
"MITM H2: forwarding request"
|
||||
);
|
||||
|
||||
// Collect request body (we need it for cascade ID extraction)
|
||||
let (parts, body) = req.into_parts();
|
||||
let request_body = match body.collect().await {
|
||||
Ok(collected) => collected.to_bytes(),
|
||||
Err(e) => {
|
||||
warn!(error = %e, "MITM H2: failed to collect request body");
|
||||
Bytes::new()
|
||||
}
|
||||
};
|
||||
|
||||
// Get upstream sender from pool
|
||||
let mut upstream_sender = match pool.get_sender().await {
|
||||
Ok(s) => s,
|
||||
Err(e) => {
|
||||
warn!(error = %e, domain, "MITM H2: upstream connect failed");
|
||||
let resp = Response::builder()
|
||||
.status(502)
|
||||
.body(http_body_util::Either::Left(Full::new(
|
||||
Bytes::from(format!("upstream connect failed: {e}")),
|
||||
)))
|
||||
.unwrap();
|
||||
return Ok(resp);
|
||||
}
|
||||
};
|
||||
|
||||
// Build the upstream request with proper authority
|
||||
let upstream_uri = http::Uri::builder()
|
||||
.scheme("https")
|
||||
.authority(domain)
|
||||
.path_and_query(
|
||||
uri.path_and_query()
|
||||
.map(|pq| pq.as_str())
|
||||
.unwrap_or("/"),
|
||||
)
|
||||
.build()
|
||||
.unwrap_or(uri);
|
||||
|
||||
let mut upstream_req = Request::builder()
|
||||
.method(parts.method)
|
||||
.uri(upstream_uri);
|
||||
|
||||
// Copy headers, skip hop-by-hop
|
||||
for (name, value) in &parts.headers {
|
||||
let n = name.as_str();
|
||||
if n == "host" || n == "connection" || n == "transfer-encoding" {
|
||||
continue;
|
||||
}
|
||||
upstream_req = upstream_req.header(name, value);
|
||||
}
|
||||
|
||||
let upstream_req = match upstream_req.body(Full::new(request_body.clone())) {
|
||||
Ok(r) => r,
|
||||
Err(e) => {
|
||||
let resp = Response::builder()
|
||||
.status(502)
|
||||
.body(http_body_util::Either::Left(Full::new(
|
||||
Bytes::from(format!("build request failed: {e}")),
|
||||
)))
|
||||
.unwrap();
|
||||
return Ok(resp);
|
||||
}
|
||||
};
|
||||
|
||||
// Send to upstream
|
||||
let upstream_resp = match upstream_sender.send_request(upstream_req).await {
|
||||
Ok(r) => r,
|
||||
Err(e) => {
|
||||
warn!(error = %e, domain, path = %path, "MITM H2: upstream request failed");
|
||||
let resp = Response::builder()
|
||||
.status(502)
|
||||
.body(http_body_util::Either::Left(Full::new(
|
||||
Bytes::from(format!("upstream request failed: {e}")),
|
||||
)))
|
||||
.unwrap();
|
||||
return Ok(resp);
|
||||
}
|
||||
};
|
||||
|
||||
let (resp_parts, resp_body) = upstream_resp.into_parts();
|
||||
let status = resp_parts.status;
|
||||
|
||||
// ──────────────────────────────────────────────────────────────────
|
||||
// Streaming path: forward chunks immediately, tee for usage parsing
|
||||
// ──────────────────────────────────────────────────────────────────
|
||||
if is_streaming && status.is_success() {
|
||||
let should_track_usage = is_usage_method;
|
||||
let (tx, rx) = tokio::sync::mpsc::channel::<Result<Frame<Bytes>, hyper::Error>>(32);
|
||||
|
||||
let store_clone = store.clone();
|
||||
let path_clone = path.clone();
|
||||
let request_body_clone = request_body.clone();
|
||||
|
||||
// Spawn a task to forward body chunks and tee for usage extraction
|
||||
tokio::spawn(async move {
|
||||
let mut tee_buffer = if should_track_usage { Some(Vec::new()) } else { None };
|
||||
let mut body = resp_body;
|
||||
|
||||
loop {
|
||||
match body.frame().await {
|
||||
Some(Ok(frame)) => {
|
||||
if let (Some(ref mut buf), Some(data)) = (&mut tee_buffer, frame.data_ref()) {
|
||||
buf.extend_from_slice(data);
|
||||
}
|
||||
if tx.send(Ok(frame)).await.is_err() {
|
||||
break; // client disconnected
|
||||
}
|
||||
}
|
||||
Some(Err(e)) => {
|
||||
warn!(error = %e, path = %path_clone, "MITM H2: streaming error");
|
||||
let _ = tx.send(Err(e)).await;
|
||||
break;
|
||||
}
|
||||
None => break, // stream ended
|
||||
}
|
||||
}
|
||||
|
||||
// Stream completed — parse the tee buffer for usage
|
||||
if let Some(tee_buffer) = tee_buffer {
|
||||
if !tee_buffer.is_empty() {
|
||||
if let Some(grpc_usage) = parse_grpc_response_for_usage(&tee_buffer) {
|
||||
let usage = ApiUsage {
|
||||
input_tokens: grpc_usage.input_tokens,
|
||||
output_tokens: grpc_usage.output_tokens,
|
||||
thinking_output_tokens: grpc_usage.thinking_output_tokens,
|
||||
response_output_tokens: grpc_usage.response_output_tokens,
|
||||
cache_creation_input_tokens: grpc_usage.cache_write_tokens,
|
||||
cache_read_input_tokens: grpc_usage.cache_read_tokens,
|
||||
model: grpc_usage.model,
|
||||
api_provider: grpc_usage.api_provider,
|
||||
grpc_method: Some(path_clone.clone()),
|
||||
stop_reason: None,
|
||||
total_cost_usd: None,
|
||||
captured_at: std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.unwrap_or_default()
|
||||
.as_secs(),
|
||||
};
|
||||
let cascade_hint = extract_cascade_from_grpc_request(&request_body_clone);
|
||||
store_clone.record_usage(cascade_hint.as_deref(), usage).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
let stream = tokio_stream::wrappers::ReceiverStream::new(rx);
|
||||
let stream_body = StreamBody::new(stream);
|
||||
|
||||
let mut client_resp = Response::builder().status(resp_parts.status);
|
||||
for (name, value) in &resp_parts.headers {
|
||||
client_resp = client_resp.header(name, value);
|
||||
}
|
||||
|
||||
let client_resp = client_resp
|
||||
.body(http_body_util::Either::Right(stream_body))
|
||||
.unwrap_or_else(|_| {
|
||||
Response::builder()
|
||||
.status(500)
|
||||
.body(http_body_util::Either::Left(Full::new(Bytes::from(
|
||||
"internal error",
|
||||
))))
|
||||
.unwrap()
|
||||
});
|
||||
|
||||
return Ok(client_resp);
|
||||
}
|
||||
|
||||
// ──────────────────────────────────────────────────────────────────
|
||||
// Non-streaming path: buffer full response, extract usage, forward
|
||||
// ──────────────────────────────────────────────────────────────────
|
||||
let response_body = match resp_body.collect().await {
|
||||
Ok(collected) => collected.to_bytes(),
|
||||
Err(e) => {
|
||||
warn!(error = %e, "MITM H2: failed to collect response body");
|
||||
Bytes::new()
|
||||
}
|
||||
};
|
||||
|
||||
trace!(
|
||||
domain,
|
||||
path = %path,
|
||||
status = %status,
|
||||
body_len = response_body.len(),
|
||||
"MITM H2: got upstream response"
|
||||
);
|
||||
|
||||
// Extract usage data from usage-carrying gRPC methods
|
||||
if is_usage_method && !response_body.is_empty() && status.is_success() {
|
||||
if let Some(grpc_usage) = parse_grpc_response_for_usage(&response_body) {
|
||||
let usage = ApiUsage {
|
||||
input_tokens: grpc_usage.input_tokens,
|
||||
output_tokens: grpc_usage.output_tokens,
|
||||
thinking_output_tokens: grpc_usage.thinking_output_tokens,
|
||||
response_output_tokens: grpc_usage.response_output_tokens,
|
||||
cache_creation_input_tokens: grpc_usage.cache_write_tokens,
|
||||
cache_read_input_tokens: grpc_usage.cache_read_tokens,
|
||||
model: grpc_usage.model,
|
||||
api_provider: grpc_usage.api_provider,
|
||||
grpc_method: Some(path.clone()),
|
||||
stop_reason: None,
|
||||
total_cost_usd: None,
|
||||
captured_at: std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.unwrap_or_default()
|
||||
.as_secs(),
|
||||
};
|
||||
|
||||
let cascade_hint = extract_cascade_from_grpc_request(&request_body);
|
||||
store.record_usage(cascade_hint.as_deref(), usage).await;
|
||||
}
|
||||
}
|
||||
|
||||
// Build response for the client
|
||||
let mut client_resp = Response::builder().status(resp_parts.status);
|
||||
for (name, value) in &resp_parts.headers {
|
||||
client_resp = client_resp.header(name, value);
|
||||
}
|
||||
|
||||
let client_resp = client_resp
|
||||
.body(http_body_util::Either::Left(Full::new(response_body)))
|
||||
.unwrap_or_else(|_| {
|
||||
Response::builder()
|
||||
.status(500)
|
||||
.body(http_body_util::Either::Left(Full::new(Bytes::from(
|
||||
"internal error",
|
||||
))))
|
||||
.unwrap()
|
||||
});
|
||||
|
||||
Ok(client_resp)
|
||||
}
|
||||
|
||||
/// Try to extract a cascade ID from a gRPC request body.
|
||||
///
|
||||
/// Looks for UUID-formatted strings in the protobuf fields.
|
||||
fn extract_cascade_from_grpc_request(body: &[u8]) -> Option<String> {
|
||||
use crate::mitm::proto::{decode_proto, extract_grpc_messages};
|
||||
|
||||
let messages = extract_grpc_messages(body);
|
||||
for msg in &messages {
|
||||
let fields = decode_proto(msg);
|
||||
for field in &fields {
|
||||
if let Some(id) = extract_uuid_from_field(field) {
|
||||
return Some(id);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
|
||||
fn extract_uuid_from_field(field: &crate::mitm::proto::ProtoField) -> Option<String> {
|
||||
use crate::mitm::proto::ProtoValue;
|
||||
|
||||
match &field.value {
|
||||
ProtoValue::Bytes(b) => {
|
||||
if let Ok(s) = std::str::from_utf8(b) {
|
||||
if is_uuid(s) {
|
||||
return Some(s.to_string());
|
||||
}
|
||||
}
|
||||
}
|
||||
ProtoValue::Message(nested) => {
|
||||
for nf in nested {
|
||||
if let Some(id) = extract_uuid_from_field(nf) {
|
||||
return Some(id);
|
||||
}
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
fn is_uuid(s: &str) -> bool {
|
||||
s.len() == 36
|
||||
&& s.chars().all(|c| c.is_ascii_hexdigit() || c == '-')
|
||||
&& s.chars().filter(|&c| c == '-').count() == 4
|
||||
}
|
||||
Reference in New Issue
Block a user