Files
zerogravity/src/mitm/proto.rs
Nikketryhard 48674f65da refactor: decompose large functions and remove dead code
- Decompose modify_request() into 7 single-responsibility helpers
- Decompose handle_http_over_tls(): extract read_full_request, dispatch_stream_events
- Promote connect_upstream/resolve_upstream to module-level functions
- Split standalone.rs (1238 lines) into 4 submodules:
  standalone/mod.rs, spawn.rs, discovery.rs, stub.rs
- Extract proto wire primitives into proto/wire.rs
- Remove 6 dead MitmStore methods
- Remove dead SessionResult, DEFAULT_SESSION, get_or_create
- Remove dead decode_varint_at, extract_conversation_id
- Clean all unused imports across 10 files
- Suppress structural dead_code warnings on deserialization fields

Warnings: 20 -> 0. All 43 tests pass.
2026-02-17 22:27:26 -06:00

649 lines
22 KiB
Rust

//! Raw protobuf decoder for extracting ModelUsageStats from gRPC responses.
//!
//! The official .proto files are not available, so protobuf messages are decoded
//! generically and then searched for sub-messages that match the reverse-engineered
//! `ModelUsageStats` layout documented below.
//!
//! gRPC wire format:
//! - 1 byte: compression flag (0 = uncompressed, 1 = compressed)
//! - 4 bytes: message length (big-endian u32)
//! - N bytes: protobuf message
//!
//! Protobuf wire format:
//! - Each field: (field_number << 3 | wire_type) as varint, then value
//! - Wire type 0: varint
//! - Wire type 1: 64-bit fixed
//! - Wire type 2: length-delimited (string, bytes, embedded message)
//! - Wire type 5: 32-bit fixed
//!
//! ## ModelUsageStats schema (reverse-engineered from LS binary):
//!
//! ```protobuf
//! message ModelUsageStats {
//! Model model = 1; // enum (varint)
//! uint64 input_tokens = 2;
//! uint64 output_tokens = 3;
//! uint64 cache_write_tokens = 4;
//! uint64 cache_read_tokens = 5;
//! APIProvider api_provider = 6; // enum (varint)
//! string message_id = 7;
//! map<string,string> response_header = 8; // repeated message
//! uint64 thinking_output_tokens = 9;
//! uint64 response_output_tokens = 10;
//! string response_id = 11;
//! }
//! ```
use flate2::read::GzDecoder;
use std::io::Read;
use tracing::{debug, trace, warn};
// Re-import the shared varint decoder under the name used throughout this module
use crate::proto::wire::decode_varint as read_varint;
/// A decoded protobuf field value, tagged by the wire type it was read from.
///
/// Length-delimited payloads are ambiguous on the wire: they are reported as
/// [`ProtoValue::Message`] when they decode as a plausible nested message and as
/// [`ProtoValue::Bytes`] otherwise.
///
/// `PartialEq`/`Eq` are derived so decoded trees can be compared directly
/// (for example with `assert_eq!` in tests).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ProtoValue {
    /// Wire type 0: a varint.
    Varint(u64),
    /// Wire type 1: a 64-bit fixed-width value (little-endian on the wire).
    // Decoded for completeness; the payload is not read anywhere in this file.
    #[allow(dead_code)]
    Fixed64(u64),
    /// Wire type 5: a 32-bit fixed-width value (little-endian on the wire).
    // Decoded for completeness; the payload is not read anywhere in this file.
    #[allow(dead_code)]
    Fixed32(u32),
    /// Wire type 2: raw length-delimited bytes (string, bytes, or anything that did
    /// not decode as a nested message).
    Bytes(Vec<u8>),
    /// Wire type 2: nested message (parsed recursively).
    Message(Vec<ProtoField>),
}

/// A single protobuf field with its number and value.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ProtoField {
    /// Field number taken from the tag (`tag >> 3`).
    pub number: u32,
    /// Decoded value; the variant reflects the wire type of the field.
    pub value: ProtoValue,
}
/// Extracted usage data from a gRPC response.
///
/// Mirrors the reverse-engineered `ModelUsageStats` message. Counters default to `0`
/// and optional fields to `None` when the corresponding proto field is absent.
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub struct GrpcUsage {
    /// `input_tokens` (proto field 2).
    pub input_tokens: u64,
    /// `output_tokens` (proto field 3).
    pub output_tokens: u64,
    /// `thinking_output_tokens` (proto field 9).
    pub thinking_output_tokens: u64,
    /// `response_output_tokens` (proto field 10).
    pub response_output_tokens: u64,
    /// `cache_read_tokens` (proto field 5).
    pub cache_read_tokens: u64,
    /// `cache_write_tokens` (proto field 4).
    pub cache_write_tokens: u64,
    /// Human-readable model name derived from the `model` enum (proto field 1).
    pub model: Option<String>,
    /// Provider name derived from the `api_provider` enum (proto field 6).
    pub api_provider: Option<String>,
    /// `message_id` (proto field 7).
    pub message_id: Option<String>,
    /// `response_id` (proto field 11).
    pub response_id: Option<String>,
}
impl GrpcUsage {
    /// Consume the extracted usage and build a full `ApiUsage` record for the store.
    ///
    /// `grpc_method` is recorded as the originating gRPC method path and the capture
    /// time is stamped with the current wall-clock time in seconds since the Unix
    /// epoch. `message_id` and `response_id` have no counterpart in the record and
    /// are dropped.
    pub fn into_api_usage(self, grpc_method: String) -> super::store::ApiUsage {
        // A clock set before the epoch yields an error; fall back to 0 in that case.
        let captured_at = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .map_or(0, |since_epoch| since_epoch.as_secs());

        let Self {
            input_tokens,
            output_tokens,
            thinking_output_tokens,
            response_output_tokens,
            cache_read_tokens,
            cache_write_tokens,
            model,
            api_provider,
            message_id: _,
            response_id: _,
        } = self;

        super::store::ApiUsage {
            input_tokens,
            output_tokens,
            thinking_output_tokens,
            // The gRPC usage message carries token counts only, never the text itself.
            thinking_text: None,
            response_text: None,
            response_output_tokens,
            cache_creation_input_tokens: cache_write_tokens,
            cache_read_input_tokens: cache_read_tokens,
            model,
            api_provider,
            grpc_method: Some(grpc_method),
            stop_reason: None,
            captured_at,
        }
    }
}
/// Extract gRPC message frames from a buffer.
///
/// A gRPC message is:
/// [1 byte compressed flag] [4 bytes length BE] [N bytes protobuf]
///
/// Multiple messages can be concatenated in a single buffer; they are returned in
/// wire order. Parsing stops at the first frame whose header or payload is
/// incomplete, and everything extracted up to that point is returned.
///
/// If the compressed flag is 1, the payload is gzip-decompressed. Frames that fail
/// to decompress, or that inflate beyond a fixed safety limit, are skipped with a
/// warning so a single bad frame cannot abort the scan or exhaust memory.
pub fn extract_grpc_messages(data: &[u8]) -> Vec<Vec<u8>> {
    // 1 flag byte + 4 length bytes.
    const HEADER_LEN: usize = 5;
    // Ceiling on the inflated size of one frame; guards against decompression bombs.
    const MAX_DECOMPRESSED_LEN: u64 = 64 * 1024 * 1024;

    let mut messages = Vec::new();
    let mut rest = data;

    while rest.len() >= HEADER_LEN {
        let (header, body) = rest.split_at(HEADER_LEN);
        let compressed = header[0];
        let length = u32::from_be_bytes([header[1], header[2], header[3], header[4]]) as usize;

        // Compare against the remaining bytes instead of computing `offset + length`,
        // which could overflow `usize` on 32-bit targets.
        if length > body.len() {
            break;
        }
        let (payload, tail) = body.split_at(length);
        rest = tail;

        if compressed == 1 {
            // Read at most one byte past the limit so an oversized frame is detectable
            // without inflating all of it.
            let mut limited = GzDecoder::new(payload).take(MAX_DECOMPRESSED_LEN + 1);
            let mut decompressed = Vec::new();
            match limited.read_to_end(&mut decompressed) {
                Ok(_) if decompressed.len() as u64 > MAX_DECOMPRESSED_LEN => {
                    warn!(
                        limit = MAX_DECOMPRESSED_LEN,
                        "Proto: decompressed gRPC frame exceeds size limit, skipping"
                    );
                }
                Ok(_) => messages.push(decompressed),
                Err(e) => {
                    warn!(error = %e, "Proto: failed to decompress gRPC frame");
                }
            }
        } else {
            messages.push(payload.to_vec());
        }
    }

    messages
}
/// Decode a protobuf message into a list of fields.
///
/// This is a best-effort decoder that handles the common wire types.
/// Embedded messages (wire type 2) are attempted to be parsed recursively.
pub fn decode_proto(data: &[u8]) -> Vec<ProtoField> {
let mut fields = Vec::new();
let mut offset = 0;
while offset < data.len() {
// Read tag (varint)
let (tag, bytes_read) = match read_varint(&data[offset..]) {
Some(v) => v,
None => break,
};
offset += bytes_read;
let field_number = (tag >> 3) as u32;
let wire_type = (tag & 0x07) as u8;
if field_number == 0 {
break; // invalid
}
let value = match wire_type {
0 => {
// Varint
let (val, bytes_read) = match read_varint(&data[offset..]) {
Some(v) => v,
None => break,
};
offset += bytes_read;
ProtoValue::Varint(val)
}
1 => {
// 64-bit fixed
if offset + 8 > data.len() {
break;
}
let val = u64::from_le_bytes(data[offset..offset + 8].try_into().unwrap());
offset += 8;
ProtoValue::Fixed64(val)
}
2 => {
// Length-delimited
let (len, bytes_read) = match read_varint(&data[offset..]) {
Some(v) => v,
None => break,
};
offset += bytes_read;
let len = len as usize;
if offset + len > data.len() {
break;
}
let payload = &data[offset..offset + len];
offset += len;
// Try to parse as a nested message
let nested = decode_proto(payload);
if !nested.is_empty() && looks_like_valid_message(&nested, payload.len()) {
ProtoValue::Message(nested)
} else {
ProtoValue::Bytes(payload.to_vec())
}
}
5 => {
// 32-bit fixed
if offset + 4 > data.len() {
break;
}
let val = u32::from_le_bytes(data[offset..offset + 4].try_into().unwrap());
offset += 4;
ProtoValue::Fixed32(val)
}
_ => {
// Unknown wire type — stop parsing
break;
}
};
fields.push(ProtoField {
number: field_number,
value,
});
}
fields
}
/// Heuristic filter separating genuine nested messages from arbitrary byte strings
/// that merely happened to decode.
///
/// `fields` is the decode result and `original_len` the size in bytes of the payload
/// it was decoded from. The result is `false` when there are no fields, when any
/// field number is 10000 or above, or when a payload longer than 100 bytes produced
/// a single scalar field (typical of a long string with a lucky first byte).
fn looks_like_valid_message(fields: &[ProtoField], original_len: usize) -> bool {
    // Real schemas use small field numbers; large ones are a sign of random data.
    if fields.iter().any(|field| field.number >= 10000) {
        return false;
    }

    match fields {
        [] => false,
        // A lone field covering a large payload is only credible when it is itself
        // length-delimited (bytes or a nested message).
        [only] if original_len > 100 => {
            matches!(only.value, ProtoValue::Bytes(_) | ProtoValue::Message(_))
        }
        _ => true,
    }
}
/// Locate a `ModelUsageStats`-shaped message anywhere in a decoded protobuf tree.
///
/// The field list passed in is tested first; if it does not look like usage data,
/// every nested message is searched depth-first in field order and the first match
/// is returned. Matching relies on the reverse-engineered field numbers:
///
/// field 1: model (enum/varint)
/// field 2: input_tokens (uint64)
/// field 3: output_tokens (uint64)
/// field 4: cache_write_tokens (uint64)
/// field 5: cache_read_tokens (uint64)
/// field 6: api_provider (enum/varint)
/// field 7: message_id (string)
/// field 8: response_header (map, repeated message)
/// field 9: thinking_output_tokens (uint64)
/// field 10: response_output_tokens (uint64)
/// field 11: response_id (string)
pub fn extract_usage_from_proto(fields: &[ProtoField]) -> Option<GrpcUsage> {
    try_extract_usage(fields).or_else(|| {
        // Nothing at this level: descend into each nested message in turn.
        fields.iter().find_map(|field| match &field.value {
            ProtoValue::Message(nested) => extract_usage_from_proto(nested),
            _ => None,
        })
    })
}
/// Try to interpret exactly this list of fields as a `ModelUsageStats` message.
///
/// Field numbers follow the reverse-engineered schema (2/3 = input/output tokens,
/// 4/5 = cache write/read, 9/10 = thinking/response output, 1 = model enum,
/// 6 = API provider enum, 7/11 = message/response id).
///
/// A candidate is accepted only when all of the following hold:
/// - at least two varint fields carry non-zero values no larger than 10,000,000;
/// - either both fields 2 and 3 are varints, or some string field looks like a
///   model name, or there are at least three varint fields;
/// - the extracted input or output token count is non-zero.
///
/// Returns `None` when the fields do not look like usage data.
fn try_extract_usage(fields: &[ProtoField]) -> Option<GrpcUsage> {
    // Largest value still accepted as a token count.
    const MAX_PLAUSIBLE_TOKENS: u64 = 10_000_000;
    // Substrings that mark a string field as naming a model.
    const MODEL_MARKERS: [&str; 7] = ["claude", "gemini", "gpt", "sonnet", "opus", "flash", "pro"];

    let plausible_token_count = |v: u64| v <= MAX_PLAUSIBLE_TOKENS;

    // Pass 1: gather the cheap signals used to accept or reject the candidate.
    let mut varint_count = 0usize;
    let mut nonzero_plausible = 0usize;
    let mut has_input_field = false;
    let mut has_output_field = false;
    for field in fields {
        if let ProtoValue::Varint(v) = field.value {
            varint_count += 1;
            if v > 0 && plausible_token_count(v) {
                nonzero_plausible += 1;
            }
            match field.number {
                2 => has_input_field = true,
                3 => has_output_field = true,
                _ => {}
            }
        }
    }

    // Need at least 2 varint fields, and at least 2 non-zero plausible values.
    if varint_count < 2 || nonzero_plausible < 2 {
        return None;
    }

    // Strong signal: both token fields present, or a model-like string somewhere.
    // Strings are inspected in place (no copies) and only when actually needed.
    let is_likely_usage = (has_input_field && has_output_field)
        || fields.iter().any(|field| match &field.value {
            ProtoValue::Bytes(raw) => std::str::from_utf8(raw).map_or(false, |s| {
                s.starts_with("models/") || MODEL_MARKERS.iter().any(|&marker| s.contains(marker))
            }),
            _ => false,
        });
    if !is_likely_usage && varint_count < 3 {
        // Without a strong signal, need more fields.
        return None;
    }

    // Pass 2: build usage from the exact field numbers. For repeated fields the
    // last occurrence wins.
    let mut usage = GrpcUsage::default();
    let mut model_enum = None;
    let mut provider_enum = None;
    for field in fields {
        match &field.value {
            ProtoValue::Varint(v) => {
                let v = *v;
                match field.number {
                    // Enums are recorded regardless of magnitude; they are not counts.
                    1 => model_enum = Some(v),
                    6 => provider_enum = Some(v),
                    // Implausibly large counters are ignored rather than stored.
                    _ if !plausible_token_count(v) => {}
                    2 => usage.input_tokens = v,
                    3 => usage.output_tokens = v,
                    4 => usage.cache_write_tokens = v,
                    5 => usage.cache_read_tokens = v,
                    9 => usage.thinking_output_tokens = v,
                    10 => usage.response_output_tokens = v,
                    _ => {}
                }
            }
            ProtoValue::Bytes(raw) => {
                if let Ok(s) = std::str::from_utf8(raw) {
                    match field.number {
                        7 => usage.message_id = Some(s.to_string()),
                        11 => usage.response_id = Some(s.to_string()),
                        _ => {}
                    }
                }
            }
            _ => {}
        }
    }

    // Validate — we should have at least input OR output tokens.
    if usage.input_tokens == 0 && usage.output_tokens == 0 {
        return None;
    }

    // Map enums to names only once the candidate is accepted, so rejected
    // candidates do not pay for (or retain) name strings.
    usage.model = model_enum.map(|v| model_enum_name(v).to_string());
    usage.api_provider = provider_enum.map(|v| match v {
        0 => "unknown".to_string(),
        1 => "google".to_string(),
        2 => "anthropic".to_string(),
        _ => format!("provider_{v}"),
    });

    debug!(
        input = usage.input_tokens,
        output = usage.output_tokens,
        thinking = usage.thinking_output_tokens,
        response = usage.response_output_tokens,
        cache_read = usage.cache_read_tokens,
        cache_write = usage.cache_write_tokens,
        model = ?usage.model,
        api_provider = ?usage.api_provider,
        "Proto: extracted ModelUsageStats from protobuf"
    );

    Some(usage)
}
/// Scan a gRPC response body for usage statistics.
///
/// The body is split into gRPC frames (compressed or not), each frame is decoded as
/// protobuf, and frames are examined from last to first because usage data usually
/// arrives in the final message. Returns the first usage record found, or `None`.
pub fn parse_grpc_response_for_usage(body: &[u8]) -> Option<GrpcUsage> {
    let messages = extract_grpc_messages(body);
    trace!(count = messages.len(), "Proto: extracted gRPC messages");

    messages
        .iter()
        .rev()
        .find_map(|msg| extract_usage_from_proto(&decode_proto(msg)))
}
// ─── Model enum → name mapping ──────────────────────────────────────────────
/// Map a proto model enum number to a human-readable name.
///
/// Numbers extracted from extension.js protobuf definitions.
/// See `docs/ls-binary-analysis.md` for full catalog.
///
/// Unknown values yield `model_enum_<N>`. Those names have to be leaked to satisfy
/// the `&'static str` return type, so they are interned: each distinct unknown value
/// is allocated once and later calls return the same string.
fn model_enum_name(enum_val: u64) -> &'static str {
    // Interned names for unknown enum values (at most one leaked string per value).
    static UNKNOWN_NAMES: std::sync::Mutex<std::collections::BTreeMap<u64, &'static str>> =
        std::sync::Mutex::new(std::collections::BTreeMap::new());

    match enum_val {
        // Placeholder models (1000 + N)
        1007 => "gemini-3-pro",      // MODEL_PLACEHOLDER_M7
        1008 => "gemini-3-pro-high", // MODEL_PLACEHOLDER_M8
        1012 => "claude-opus-4.5",   // MODEL_PLACEHOLDER_M12
        1018 => "gemini-3-flash",    // MODEL_PLACEHOLDER_M18
        1026 => "claude-opus-4.6",   // MODEL_PLACEHOLDER_M26
        // Claude models (named)
        281 => "claude-4-sonnet",
        282 => "claude-4-sonnet-thinking",
        290 => "claude-4-opus",
        291 => "claude-4-opus-thinking",
        333 => "claude-4.5-sonnet",
        334 => "claude-4.5-sonnet-thinking",
        340 => "claude-4.5-haiku",
        341 => "claude-4.5-haiku-thinking",
        // Gemini and other named models
        246 => "gemini-2.5-pro",
        312 => "gemini-2.5-flash",
        313 => "gemini-2.5-flash-thinking",
        329 => "gemini-2.5-flash-thinking-tools",
        330 => "gemini-2.5-flash-lite",
        335 => "gemini-computer-use-experimental",
        342 => "openai-gpt-oss-120b",
        346 => "jarvis-proxy",
        348 => "gemini-riftrunner",
        352 => "gemini-riftrunner-thinking-low",
        353 => "gemini-riftrunner-thinking-high",
        // Unknown — synthesise a name and intern it.
        _ => {
            // The map is always left in a consistent state, so a poisoned lock
            // (a panic in another thread while holding it) is safe to reuse.
            let mut names = UNKNOWN_NAMES
                .lock()
                .unwrap_or_else(std::sync::PoisonError::into_inner);
            *names.entry(enum_val).or_insert_with(|| {
                let leaked: &'static str =
                    Box::leak(format!("model_enum_{enum_val}").into_boxed_str());
                leaked
            })
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Append `value` to `out` as a protobuf base-128 varint.
    fn push_varint(out: &mut Vec<u8>, mut value: u64) {
        while value >= 0x80 {
            out.push((value as u8) | 0x80);
            value >>= 7;
        }
        out.push(value as u8);
    }

    /// Append a wire-type-0 (varint) field with the given number and value.
    fn push_varint_field(out: &mut Vec<u8>, field_num: u32, value: u64) {
        // Wire type 0 contributes nothing to the low three tag bits.
        push_varint(out, u64::from(field_num) << 3);
        push_varint(out, value);
    }

    /// Build a mock ModelUsageStats payload using the verified field numbers.
    fn sample_usage_payload() -> Vec<u8> {
        let mut data = Vec::new();
        push_varint_field(&mut data, 1, 5); // model enum (not in the known table)
        push_varint_field(&mut data, 2, 1000); // input_tokens
        push_varint_field(&mut data, 3, 500); // output_tokens
        push_varint_field(&mut data, 4, 100); // cache_write_tokens
        push_varint_field(&mut data, 5, 200); // cache_read_tokens
        push_varint_field(&mut data, 9, 300); // thinking_output_tokens
        push_varint_field(&mut data, 10, 200); // response_output_tokens
        data
    }

    #[test]
    fn test_read_varint() {
        assert_eq!(read_varint(&[0x00]), Some((0, 1)));
        assert_eq!(read_varint(&[0x01]), Some((1, 1)));
        assert_eq!(read_varint(&[0x96, 0x01]), Some((150, 2)));
        assert_eq!(read_varint(&[0xAC, 0x02]), Some((300, 2)));
    }

    #[test]
    fn test_extract_grpc_messages_uncompressed() {
        // Frame layout: [0x00] [0x00, 0x00, 0x00, 0x05] [5 bytes data]
        let payload = [0x08, 0x96, 0x01, 0x10, 0x42]; // field 1 varint 150, field 2 varint 66
        let mut buf = vec![0u8]; // not compressed
        buf.extend_from_slice(&5u32.to_be_bytes());
        buf.extend_from_slice(&payload);

        let messages = extract_grpc_messages(&buf);
        assert_eq!(messages.len(), 1);
        assert_eq!(messages[0], payload);
    }

    #[test]
    fn test_extract_grpc_messages_compressed() {
        use flate2::write::GzEncoder;
        use flate2::Compression;
        use std::io::Write;

        let payload = vec![0x08, 0x96, 0x01, 0x10, 0x42];

        let mut encoder = GzEncoder::new(Vec::new(), Compression::default());
        encoder.write_all(&payload).unwrap();
        let compressed = encoder.finish().unwrap();

        // Build a gRPC frame with the compressed flag set.
        let mut buf = vec![1u8];
        buf.extend_from_slice(&(compressed.len() as u32).to_be_bytes());
        buf.extend_from_slice(&compressed);

        let messages = extract_grpc_messages(&buf);
        assert_eq!(messages.len(), 1);
        assert_eq!(messages[0], payload);
    }

    #[test]
    fn test_decode_proto_varints() {
        // field 1 = 150, field 2 = 66
        let data = [0x08, 0x96, 0x01, 0x10, 0x42];
        let fields = decode_proto(&data);
        assert_eq!(fields.len(), 2);
        assert_eq!(fields[0].number, 1);
        assert!(matches!(fields[0].value, ProtoValue::Varint(150)));
        assert_eq!(fields[1].number, 2);
        assert!(matches!(fields[1].value, ProtoValue::Varint(66)));
    }

    #[test]
    fn test_decode_proto_with_string() {
        // field 1 = "hello" (length-delimited), field 2 = varint 42
        let mut data = vec![
            0x0A, // (1 << 3) | 2
            0x05, // length 5
        ];
        data.extend_from_slice(b"hello");
        data.push(0x10); // (2 << 3) | 0
        data.push(0x2A); // 42

        let fields = decode_proto(&data);
        // The string is skipped as a unit, so exactly two top-level fields result and
        // the varint that follows it is decoded intact.
        assert_eq!(fields.len(), 2);
        assert_eq!(fields[0].number, 1);
        assert_eq!(fields[1].number, 2);
        assert!(matches!(fields[1].value, ProtoValue::Varint(42)));
    }

    #[test]
    fn test_extract_usage_correct_field_numbers() {
        let fields = decode_proto(&sample_usage_payload());
        let usage = try_extract_usage(&fields).expect("should extract usage");
        assert_eq!(usage.input_tokens, 1000);
        assert_eq!(usage.output_tokens, 500);
        assert_eq!(usage.cache_write_tokens, 100);
        assert_eq!(usage.cache_read_tokens, 200);
        assert_eq!(usage.thinking_output_tokens, 300);
        assert_eq!(usage.response_output_tokens, 200);
        // Enum value 5 is not in the known table, so the fallback name is used.
        assert_eq!(usage.model.as_deref(), Some("model_enum_5"));
        assert_eq!(usage.api_provider, None);
    }

    #[test]
    fn test_parse_grpc_response_for_usage() {
        // End to end: frame the payload as one uncompressed gRPC message.
        let payload = sample_usage_payload();
        let mut body = vec![0u8];
        body.extend_from_slice(&(payload.len() as u32).to_be_bytes());
        body.extend_from_slice(&payload);

        let usage = parse_grpc_response_for_usage(&body).expect("should find usage");
        assert_eq!(usage.input_tokens, 1000);
        assert_eq!(usage.output_tokens, 500);

        // A body with no complete frame yields nothing.
        assert!(parse_grpc_response_for_usage(&[0x00, 0x00]).is_none());
    }
}