513 lines
18 KiB
Rust
513 lines
18 KiB
Rust
//! HTTP/2 handler for gRPC traffic interception.
|
|
//!
|
|
//! When the LS negotiates HTTP/2 via ALPN (which all gRPC connections do),
|
|
//! this module handles the bidirectional HTTP/2 connection:
|
|
//! 1. Accepts HTTP/2 frames from the client (LS)
|
|
//! 2. Connects to the real upstream via TLS + HTTP/2 (one upstream connection per client connection, shared by all of its streams)
|
|
//! 3. Forwards each request stream to upstream
|
|
//! 4. For non-streaming: buffers response, extracts usage, forwards
|
|
//! 5. For streaming: forwards response body chunks in real-time, tees to a
|
|
//! side buffer for usage extraction after stream completes
|
|
//!
|
|
//! ## Streaming vs Non-streaming
|
|
//!
|
|
//! gRPC has both unary (non-streaming) and server-streaming RPCs.
|
|
//! The LS uses server-streaming for methods like `StreamGenerateContent`.
|
|
//! We MUST forward streaming responses immediately — buffering would break
|
|
//! the LS's perception of real-time generation.
|
|
//!
|
|
//! For usage extraction: ModelUsageStats is typically in the LAST message
|
|
//! of a streaming response, so we tee the data and parse after stream ends.
|
|
|
|
use crate::mitm::proto::parse_grpc_response_for_usage;
|
|
use crate::mitm::store::{ApiUsage, MitmStore};
|
|
|
|
use bytes::Bytes;
|
|
use http_body_util::{BodyExt, Full, StreamBody};
|
|
use hyper::body::{Frame, Incoming};
|
|
use hyper::server::conn::http2::Builder as H2ServerBuilder;
|
|
use hyper::service::service_fn;
|
|
use hyper::{Request, Response};
|
|
use hyper_util::rt::TokioExecutor;
|
|
use hyper_util::rt::TokioIo;
|
|
use std::sync::Arc;
|
|
use tokio::io::{AsyncRead, AsyncWrite};
|
|
use tokio::net::TcpStream;
|
|
use tokio::sync::Mutex;
|
|
use tracing::{debug, info, trace, warn};
|
|
|
|
/// A lazily-initialized, shared HTTP/2 connection to the upstream server.
|
|
///
|
|
/// gRPC multiplexes many requests over a single HTTP/2 connection.
|
|
/// We mirror this by maintaining a single upstream connection per domain.
|
|
struct UpstreamPool {
|
|
domain: String,
|
|
tls_config: Arc<rustls::ClientConfig>,
|
|
sender: Mutex<Option<hyper::client::conn::http2::SendRequest<Full<Bytes>>>>,
|
|
}
|
|
|
|
impl UpstreamPool {
|
|
fn new(domain: String, tls_config: Arc<rustls::ClientConfig>) -> Self {
|
|
Self {
|
|
domain,
|
|
tls_config,
|
|
sender: Mutex::new(None),
|
|
}
|
|
}
|
|
|
|
/// Get or create the upstream HTTP/2 sender.
|
|
async fn get_sender(
|
|
&self,
|
|
) -> Result<hyper::client::conn::http2::SendRequest<Full<Bytes>>, String> {
|
|
let mut guard = self.sender.lock().await;
|
|
|
|
// Check if existing sender is still usable
|
|
if let Some(ref sender) = *guard {
|
|
if !sender.is_closed() {
|
|
return Ok(sender.clone());
|
|
}
|
|
debug!(domain = %self.domain, "MITM H2: upstream connection closed, reconnecting");
|
|
}
|
|
|
|
// Create new connection
|
|
let sender = self.connect().await?;
|
|
*guard = Some(sender.clone());
|
|
Ok(sender)
|
|
}
|
|
|
|
async fn connect(
|
|
&self,
|
|
) -> Result<hyper::client::conn::http2::SendRequest<Full<Bytes>>, String> {
|
|
let upstream_tcp = TcpStream::connect(format!("{}:443", self.domain))
|
|
.await
|
|
.map_err(|e| format!("upstream TCP connect to {} failed: {e}", self.domain))?;
|
|
|
|
let connector = tokio_rustls::TlsConnector::from(self.tls_config.clone());
|
|
let server_name = rustls::pki_types::ServerName::try_from(self.domain.clone())
|
|
.map_err(|e| format!("invalid domain {}: {e}", self.domain))?;
|
|
|
|
let upstream_tls = connector
|
|
.connect(server_name, upstream_tcp)
|
|
.await
|
|
.map_err(|e| format!("upstream TLS to {} failed: {e}", self.domain))?;
|
|
|
|
let upstream_io = TokioIo::new(upstream_tls);
|
|
let (sender, conn) =
|
|
hyper::client::conn::http2::Builder::new(TokioExecutor::new())
|
|
.handshake(upstream_io)
|
|
.await
|
|
.map_err(|e| format!("upstream h2 handshake to {} failed: {e}", self.domain))?;
|
|
|
|
let domain = self.domain.clone();
|
|
tokio::spawn(async move {
|
|
if let Err(e) = conn.await {
|
|
debug!(domain = %domain, error = %e, "MITM H2: upstream connection driver ended");
|
|
}
|
|
});
|
|
|
|
info!(domain = %self.domain, "MITM H2: established upstream HTTP/2 connection");
|
|
Ok(sender)
|
|
}
|
|
}
|
|
|
|
/// gRPC method names whose responses carry `ModelUsageStats`.
///
/// Each entry is matched as a substring of the request path, so an entry such
/// as `GenerateContent` also matches any longer path that contains it.
const USAGE_METHODS: &[&str] = &[
    // Unary RPCs.
    "GenerateContent",
    "AsyncGenerateContent",
    "GenerateChat",
    "GenerateCode",
    "CompleteCode",
    "InternalAtomicAgenticChat",
    "Predict",
    "DirectPredict",
    // Server-streaming RPCs.
    "StreamGenerateContent",
    "StreamAsyncGenerateContent",
    "StreamGenerateChat",
];
|
|
|
|
/// Handle an HTTP/2 connection from the LS after TLS termination.
|
|
///
|
|
/// Uses hyper's HTTP/2 server to accept requests and a shared upstream
|
|
/// HTTP/2 connection to forward them.
|
|
pub async fn handle_h2_connection<S>(
|
|
tls_stream: S,
|
|
domain: String,
|
|
store: MitmStore,
|
|
) -> Result<(), String>
|
|
where
|
|
S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
|
|
{
|
|
info!(domain = %domain, "MITM H2: handling HTTP/2 connection");
|
|
|
|
// Build TLS config for upstream connections
|
|
let mut root_store = rustls::RootCertStore::empty();
|
|
let native_certs = rustls_native_certs::load_native_certs();
|
|
for cert in native_certs.certs {
|
|
let _ = root_store.add(cert);
|
|
}
|
|
let mut upstream_tls_config = rustls::ClientConfig::builder()
|
|
.with_root_certificates(root_store)
|
|
.with_no_client_auth();
|
|
upstream_tls_config.alpn_protocols = vec![b"h2".to_vec()];
|
|
|
|
// Shared upstream connection pool (single connection, multiplexed)
|
|
let pool = Arc::new(UpstreamPool::new(
|
|
domain.clone(),
|
|
Arc::new(upstream_tls_config),
|
|
));
|
|
|
|
let io = TokioIo::new(tls_stream);
|
|
let domain = Arc::new(domain);
|
|
|
|
let result = H2ServerBuilder::new(TokioExecutor::new())
|
|
.serve_connection(
|
|
io,
|
|
service_fn(move |req: Request<Incoming>| {
|
|
let domain = domain.clone();
|
|
let store = store.clone();
|
|
let pool = pool.clone();
|
|
async move { handle_h2_request(req, &domain, store, pool).await }
|
|
}),
|
|
)
|
|
.await;
|
|
|
|
match result {
|
|
Ok(()) => {
|
|
debug!("MITM H2: connection closed cleanly");
|
|
Ok(())
|
|
}
|
|
Err(e) => {
|
|
// Connection errors are expected on clean close
|
|
debug!(error = %e, "MITM H2: connection ended");
|
|
Ok(())
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Response body type — either buffered or streaming.
|
|
type BoxBody = http_body_util::Either<
|
|
Full<Bytes>,
|
|
StreamBody<tokio_stream::wrappers::ReceiverStream<Result<Frame<Bytes>, hyper::Error>>>,
|
|
>;
|
|
|
|
/// Handle a single HTTP/2 request: forward to upstream, capture usage.
|
|
///
|
|
/// For streaming responses, forwards chunks in real-time while teeing
|
|
/// data to a side buffer for post-stream usage extraction.
|
|
async fn handle_h2_request(
|
|
req: Request<Incoming>,
|
|
domain: &str,
|
|
store: MitmStore,
|
|
pool: Arc<UpstreamPool>,
|
|
) -> Result<Response<BoxBody>, hyper::Error> {
|
|
let method = req.method().clone();
|
|
let uri = req.uri().clone();
|
|
let path = uri.path().to_string();
|
|
|
|
// Identify gRPC method
|
|
let is_grpc = req
|
|
.headers()
|
|
.get("content-type")
|
|
.and_then(|v| v.to_str().ok())
|
|
.map(|ct| ct.starts_with("application/grpc"))
|
|
.unwrap_or(false);
|
|
|
|
// Check if this method carries usage data
|
|
let is_usage_method = is_grpc
|
|
&& USAGE_METHODS.iter().any(|m| path.contains(m));
|
|
|
|
// Check if this is a streaming method
|
|
let is_streaming = is_grpc
|
|
&& (path.contains("Stream") || path.contains("stream"));
|
|
|
|
debug!(
|
|
domain,
|
|
%method,
|
|
path = %path,
|
|
grpc = is_grpc,
|
|
usage_method = is_usage_method,
|
|
streaming = is_streaming,
|
|
"MITM H2: forwarding request"
|
|
);
|
|
|
|
// Collect request body (we need it for cascade ID extraction)
|
|
let (parts, body) = req.into_parts();
|
|
let request_body = match body.collect().await {
|
|
Ok(collected) => collected.to_bytes(),
|
|
Err(e) => {
|
|
warn!(error = %e, "MITM H2: failed to collect request body");
|
|
Bytes::new()
|
|
}
|
|
};
|
|
|
|
// Get upstream sender from pool
|
|
let mut upstream_sender = match pool.get_sender().await {
|
|
Ok(s) => s,
|
|
Err(e) => {
|
|
warn!(error = %e, domain, "MITM H2: upstream connect failed");
|
|
let resp = Response::builder()
|
|
.status(502)
|
|
.body(http_body_util::Either::Left(Full::new(
|
|
Bytes::from(format!("upstream connect failed: {e}")),
|
|
)))
|
|
.unwrap();
|
|
return Ok(resp);
|
|
}
|
|
};
|
|
|
|
// Build the upstream request with proper authority
|
|
let upstream_uri = http::Uri::builder()
|
|
.scheme("https")
|
|
.authority(domain)
|
|
.path_and_query(
|
|
uri.path_and_query()
|
|
.map(|pq| pq.as_str())
|
|
.unwrap_or("/"),
|
|
)
|
|
.build()
|
|
.unwrap_or(uri);
|
|
|
|
let mut upstream_req = Request::builder()
|
|
.method(parts.method)
|
|
.uri(upstream_uri);
|
|
|
|
// Copy headers, skip hop-by-hop
|
|
for (name, value) in &parts.headers {
|
|
let n = name.as_str();
|
|
if n == "host" || n == "connection" || n == "transfer-encoding" {
|
|
continue;
|
|
}
|
|
upstream_req = upstream_req.header(name, value);
|
|
}
|
|
|
|
let upstream_req = match upstream_req.body(Full::new(request_body.clone())) {
|
|
Ok(r) => r,
|
|
Err(e) => {
|
|
let resp = Response::builder()
|
|
.status(502)
|
|
.body(http_body_util::Either::Left(Full::new(
|
|
Bytes::from(format!("build request failed: {e}")),
|
|
)))
|
|
.unwrap();
|
|
return Ok(resp);
|
|
}
|
|
};
|
|
|
|
// Send to upstream
|
|
let upstream_resp = match upstream_sender.send_request(upstream_req).await {
|
|
Ok(r) => r,
|
|
Err(e) => {
|
|
warn!(error = %e, domain, path = %path, "MITM H2: upstream request failed");
|
|
let resp = Response::builder()
|
|
.status(502)
|
|
.body(http_body_util::Either::Left(Full::new(
|
|
Bytes::from(format!("upstream request failed: {e}")),
|
|
)))
|
|
.unwrap();
|
|
return Ok(resp);
|
|
}
|
|
};
|
|
|
|
let (resp_parts, resp_body) = upstream_resp.into_parts();
|
|
let status = resp_parts.status;
|
|
|
|
// ──────────────────────────────────────────────────────────────────
|
|
// Streaming path: forward chunks immediately, tee for usage parsing
|
|
// ──────────────────────────────────────────────────────────────────
|
|
if is_streaming && status.is_success() {
|
|
let should_track_usage = is_usage_method;
|
|
let (tx, rx) = tokio::sync::mpsc::channel::<Result<Frame<Bytes>, hyper::Error>>(32);
|
|
|
|
let store_clone = store.clone();
|
|
let path_clone = path.clone();
|
|
let request_body_clone = request_body.clone();
|
|
|
|
// Spawn a task to forward body chunks and tee for usage extraction
|
|
tokio::spawn(async move {
|
|
let mut tee_buffer = if should_track_usage { Some(Vec::new()) } else { None };
|
|
let mut body = resp_body;
|
|
|
|
loop {
|
|
match body.frame().await {
|
|
Some(Ok(frame)) => {
|
|
if let (Some(ref mut buf), Some(data)) = (&mut tee_buffer, frame.data_ref()) {
|
|
buf.extend_from_slice(data);
|
|
}
|
|
if tx.send(Ok(frame)).await.is_err() {
|
|
break; // client disconnected
|
|
}
|
|
}
|
|
Some(Err(e)) => {
|
|
warn!(error = %e, path = %path_clone, "MITM H2: streaming error");
|
|
let _ = tx.send(Err(e)).await;
|
|
break;
|
|
}
|
|
None => break, // stream ended
|
|
}
|
|
}
|
|
|
|
// Stream completed — parse the tee buffer for usage
|
|
if let Some(tee_buffer) = tee_buffer {
|
|
if !tee_buffer.is_empty() {
|
|
if let Some(grpc_usage) = parse_grpc_response_for_usage(&tee_buffer) {
|
|
let usage = ApiUsage {
|
|
input_tokens: grpc_usage.input_tokens,
|
|
output_tokens: grpc_usage.output_tokens,
|
|
thinking_output_tokens: grpc_usage.thinking_output_tokens,
|
|
response_output_tokens: grpc_usage.response_output_tokens,
|
|
cache_creation_input_tokens: grpc_usage.cache_write_tokens,
|
|
cache_read_input_tokens: grpc_usage.cache_read_tokens,
|
|
model: grpc_usage.model,
|
|
api_provider: grpc_usage.api_provider,
|
|
grpc_method: Some(path_clone.clone()),
|
|
stop_reason: None,
|
|
total_cost_usd: None,
|
|
captured_at: std::time::SystemTime::now()
|
|
.duration_since(std::time::UNIX_EPOCH)
|
|
.unwrap_or_default()
|
|
.as_secs(),
|
|
};
|
|
let cascade_hint = extract_cascade_from_grpc_request(&request_body_clone);
|
|
store_clone.record_usage(cascade_hint.as_deref(), usage).await;
|
|
}
|
|
}
|
|
}
|
|
});
|
|
|
|
let stream = tokio_stream::wrappers::ReceiverStream::new(rx);
|
|
let stream_body = StreamBody::new(stream);
|
|
|
|
let mut client_resp = Response::builder().status(resp_parts.status);
|
|
for (name, value) in &resp_parts.headers {
|
|
client_resp = client_resp.header(name, value);
|
|
}
|
|
|
|
let client_resp = client_resp
|
|
.body(http_body_util::Either::Right(stream_body))
|
|
.unwrap_or_else(|_| {
|
|
Response::builder()
|
|
.status(500)
|
|
.body(http_body_util::Either::Left(Full::new(Bytes::from(
|
|
"internal error",
|
|
))))
|
|
.unwrap()
|
|
});
|
|
|
|
return Ok(client_resp);
|
|
}
|
|
|
|
// ──────────────────────────────────────────────────────────────────
|
|
// Non-streaming path: buffer full response, extract usage, forward
|
|
// ──────────────────────────────────────────────────────────────────
|
|
let response_body = match resp_body.collect().await {
|
|
Ok(collected) => collected.to_bytes(),
|
|
Err(e) => {
|
|
warn!(error = %e, "MITM H2: failed to collect response body");
|
|
Bytes::new()
|
|
}
|
|
};
|
|
|
|
trace!(
|
|
domain,
|
|
path = %path,
|
|
status = %status,
|
|
body_len = response_body.len(),
|
|
"MITM H2: got upstream response"
|
|
);
|
|
|
|
// Extract usage data from usage-carrying gRPC methods
|
|
if is_usage_method && !response_body.is_empty() && status.is_success() {
|
|
if let Some(grpc_usage) = parse_grpc_response_for_usage(&response_body) {
|
|
let usage = ApiUsage {
|
|
input_tokens: grpc_usage.input_tokens,
|
|
output_tokens: grpc_usage.output_tokens,
|
|
thinking_output_tokens: grpc_usage.thinking_output_tokens,
|
|
response_output_tokens: grpc_usage.response_output_tokens,
|
|
cache_creation_input_tokens: grpc_usage.cache_write_tokens,
|
|
cache_read_input_tokens: grpc_usage.cache_read_tokens,
|
|
model: grpc_usage.model,
|
|
api_provider: grpc_usage.api_provider,
|
|
grpc_method: Some(path.clone()),
|
|
stop_reason: None,
|
|
total_cost_usd: None,
|
|
captured_at: std::time::SystemTime::now()
|
|
.duration_since(std::time::UNIX_EPOCH)
|
|
.unwrap_or_default()
|
|
.as_secs(),
|
|
};
|
|
|
|
let cascade_hint = extract_cascade_from_grpc_request(&request_body);
|
|
store.record_usage(cascade_hint.as_deref(), usage).await;
|
|
}
|
|
}
|
|
|
|
// Build response for the client
|
|
let mut client_resp = Response::builder().status(resp_parts.status);
|
|
for (name, value) in &resp_parts.headers {
|
|
client_resp = client_resp.header(name, value);
|
|
}
|
|
|
|
let client_resp = client_resp
|
|
.body(http_body_util::Either::Left(Full::new(response_body)))
|
|
.unwrap_or_else(|_| {
|
|
Response::builder()
|
|
.status(500)
|
|
.body(http_body_util::Either::Left(Full::new(Bytes::from(
|
|
"internal error",
|
|
))))
|
|
.unwrap()
|
|
});
|
|
|
|
Ok(client_resp)
|
|
}
|
|
|
|
/// Try to extract a cascade ID from a gRPC request body.
|
|
///
|
|
/// Looks for UUID-formatted strings in the protobuf fields.
|
|
fn extract_cascade_from_grpc_request(body: &[u8]) -> Option<String> {
|
|
use crate::mitm::proto::{decode_proto, extract_grpc_messages};
|
|
|
|
let messages = extract_grpc_messages(body);
|
|
for msg in &messages {
|
|
let fields = decode_proto(msg);
|
|
for field in &fields {
|
|
if let Some(id) = extract_uuid_from_field(field) {
|
|
return Some(id);
|
|
}
|
|
}
|
|
}
|
|
|
|
None
|
|
}
|
|
|
|
fn extract_uuid_from_field(field: &crate::mitm::proto::ProtoField) -> Option<String> {
|
|
use crate::mitm::proto::ProtoValue;
|
|
|
|
match &field.value {
|
|
ProtoValue::Bytes(b) => {
|
|
if let Ok(s) = std::str::from_utf8(b) {
|
|
if is_uuid(s) {
|
|
return Some(s.to_string());
|
|
}
|
|
}
|
|
}
|
|
ProtoValue::Message(nested) => {
|
|
for nf in nested {
|
|
if let Some(id) = extract_uuid_from_field(nf) {
|
|
return Some(id);
|
|
}
|
|
}
|
|
}
|
|
_ => {}
|
|
}
|
|
None
|
|
}
|
|
|
|
/// Returns `true` if `s` has the canonical textual UUID layout
/// `xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx`.
///
/// This means 36 bytes, with hyphens at byte offsets 8, 13, 18 and 23 and ASCII
/// hex digits (either case) everywhere else. Version and variant nibbles are
/// not checked. The previous check only counted hyphens, so strings with
/// hyphens in arbitrary positions were wrongly accepted.
fn is_uuid(s: &str) -> bool {
    const HYPHENS: [usize; 4] = [8, 13, 18, 23];

    s.len() == 36
        && s.bytes().enumerate().all(|(i, b)| {
            if HYPHENS.contains(&i) {
                b == b'-'
            } else {
                b.is_ascii_hexdigit()
            }
        })
}
|