feat: initial commit — antigravity proxy with MITM, standalone LS, and snapshot tooling
This commit is contained in:
584
src/mitm/proto.rs
Normal file
584
src/mitm/proto.rs
Normal file
@@ -0,0 +1,584 @@
|
||||
//! Raw protobuf decoder for extracting ModelUsageStats from gRPC responses.
|
||||
//!
|
||||
//! We don't have the .proto schema, so we decode protobuf messages generically
|
||||
//! and search for usage-like structures by matching field patterns.
|
||||
//!
|
||||
//! gRPC wire format:
|
||||
//! - 1 byte: compression flag (0 = uncompressed, 1 = compressed)
|
||||
//! - 4 bytes: message length (big-endian u32)
|
||||
//! - N bytes: protobuf message
|
||||
//!
|
||||
//! Protobuf wire format:
|
||||
//! - Each field: (field_number << 3 | wire_type) as varint, then value
|
||||
//! - Wire type 0: varint
|
||||
//! - Wire type 1: 64-bit fixed
|
||||
//! - Wire type 2: length-delimited (string, bytes, embedded message)
|
||||
//! - Wire type 5: 32-bit fixed
|
||||
//!
|
||||
//! ## ModelUsageStats schema (reverse-engineered from LS binary):
|
||||
//!
|
||||
//! ```protobuf
|
||||
//! message ModelUsageStats {
|
||||
//! Model model = 1; // enum (varint)
|
||||
//! uint64 input_tokens = 2;
|
||||
//! uint64 output_tokens = 3;
|
||||
//! uint64 cache_write_tokens = 4;
|
||||
//! uint64 cache_read_tokens = 5;
|
||||
//! APIProvider api_provider = 6; // enum (varint)
|
||||
//! string message_id = 7;
|
||||
//! map<string,string> response_header = 8; // repeated message
|
||||
//! uint64 thinking_output_tokens = 9;
|
||||
//! uint64 response_output_tokens = 10;
|
||||
//! string response_id = 11;
|
||||
//! }
|
||||
//! ```
|
||||
|
||||
use flate2::read::GzDecoder;
|
||||
use std::io::Read;
|
||||
use tracing::{debug, trace, warn};
|
||||
|
||||
/// The value of a decoded protobuf field, tagged by the wire type it was read from.
///
/// Derives `PartialEq`/`Eq` so decoded trees can be compared directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ProtoValue {
    /// Wire type 0: a varint.
    Varint(u64),
    /// Wire type 1: a little-endian 64-bit fixed-width value.
    // The payload is decoded for completeness but is not read by the usage
    // extraction visible in this file, hence the lint suppression.
    #[allow(dead_code)]
    Fixed64(u64),
    /// Wire type 5: a little-endian 32-bit fixed-width value.
    // Same as `Fixed64`: decoded, but not read by the usage extraction here.
    #[allow(dead_code)]
    Fixed32(u32),
    /// Wire type 2 payload that was kept as raw bytes (string, bytes, or
    /// anything that did not pass as a nested message).
    Bytes(Vec<u8>),
    /// Nested message (parsed recursively)
    Message(Vec<ProtoField>),
}

/// A single protobuf field with its number and value.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ProtoField {
    /// Field number taken from the tag (`tag >> 3`).
    pub number: u32,
    /// Decoded value of the field.
    pub value: ProtoValue,
}
|
||||
|
||||
/// Extracted usage data from a gRPC response.
///
/// Field numbers in the per-field docs refer to the reverse-engineered
/// `ModelUsageStats` schema described in the module documentation. All counters
/// default to `0` and all optional fields to `None`.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct GrpcUsage {
    /// `input_tokens` (field 2).
    pub input_tokens: u64,
    /// `output_tokens` (field 3).
    pub output_tokens: u64,
    /// `thinking_output_tokens` (field 9).
    pub thinking_output_tokens: u64,
    /// `response_output_tokens` (field 10).
    pub response_output_tokens: u64,
    /// `cache_read_tokens` (field 5).
    pub cache_read_tokens: u64,
    /// `cache_write_tokens` (field 4).
    pub cache_write_tokens: u64,
    /// String rendering of the `model` enum value (field 1), if present.
    pub model: Option<String>,
    /// String rendering of the `api_provider` enum value (field 6), if present.
    pub api_provider: Option<String>,
    /// `message_id` (field 7), if present and valid UTF-8.
    pub message_id: Option<String>,
    /// `response_id` (field 11), if present and valid UTF-8.
    pub response_id: Option<String>,
}
|
||||
|
||||
/// Extract gRPC message frames from a buffer.
|
||||
///
|
||||
/// A gRPC message is:
|
||||
/// [1 byte compressed flag] [4 bytes length BE] [N bytes protobuf]
|
||||
///
|
||||
/// Multiple messages can be concatenated in a single buffer.
|
||||
/// If compressed flag is 1, the message is gzip-decompressed.
|
||||
pub fn extract_grpc_messages(data: &[u8]) -> Vec<Vec<u8>> {
|
||||
let mut messages = Vec::new();
|
||||
let mut offset = 0;
|
||||
|
||||
while offset + 5 <= data.len() {
|
||||
let compressed = data[offset];
|
||||
let length = u32::from_be_bytes([
|
||||
data[offset + 1],
|
||||
data[offset + 2],
|
||||
data[offset + 3],
|
||||
data[offset + 4],
|
||||
]) as usize;
|
||||
|
||||
offset += 5;
|
||||
|
||||
if offset + length > data.len() {
|
||||
break;
|
||||
}
|
||||
|
||||
let payload = &data[offset..offset + length];
|
||||
|
||||
if compressed == 1 {
|
||||
// gzip-compressed frame
|
||||
let mut decoder = GzDecoder::new(payload);
|
||||
let mut decompressed = Vec::new();
|
||||
match decoder.read_to_end(&mut decompressed) {
|
||||
Ok(_) => messages.push(decompressed),
|
||||
Err(e) => {
|
||||
warn!(error = %e, "Proto: failed to decompress gRPC frame");
|
||||
}
|
||||
}
|
||||
} else {
|
||||
messages.push(payload.to_vec());
|
||||
}
|
||||
|
||||
offset += length;
|
||||
}
|
||||
|
||||
messages
|
||||
}
|
||||
|
||||
/// Decode a protobuf message into a list of fields.
|
||||
///
|
||||
/// This is a best-effort decoder that handles the common wire types.
|
||||
/// Embedded messages (wire type 2) are attempted to be parsed recursively.
|
||||
pub fn decode_proto(data: &[u8]) -> Vec<ProtoField> {
|
||||
let mut fields = Vec::new();
|
||||
let mut offset = 0;
|
||||
|
||||
while offset < data.len() {
|
||||
// Read tag (varint)
|
||||
let (tag, bytes_read) = match read_varint(&data[offset..]) {
|
||||
Some(v) => v,
|
||||
None => break,
|
||||
};
|
||||
offset += bytes_read;
|
||||
|
||||
let field_number = (tag >> 3) as u32;
|
||||
let wire_type = (tag & 0x07) as u8;
|
||||
|
||||
if field_number == 0 {
|
||||
break; // invalid
|
||||
}
|
||||
|
||||
let value = match wire_type {
|
||||
0 => {
|
||||
// Varint
|
||||
let (val, bytes_read) = match read_varint(&data[offset..]) {
|
||||
Some(v) => v,
|
||||
None => break,
|
||||
};
|
||||
offset += bytes_read;
|
||||
ProtoValue::Varint(val)
|
||||
}
|
||||
1 => {
|
||||
// 64-bit fixed
|
||||
if offset + 8 > data.len() {
|
||||
break;
|
||||
}
|
||||
let val = u64::from_le_bytes(data[offset..offset + 8].try_into().unwrap());
|
||||
offset += 8;
|
||||
ProtoValue::Fixed64(val)
|
||||
}
|
||||
2 => {
|
||||
// Length-delimited
|
||||
let (len, bytes_read) = match read_varint(&data[offset..]) {
|
||||
Some(v) => v,
|
||||
None => break,
|
||||
};
|
||||
offset += bytes_read;
|
||||
let len = len as usize;
|
||||
|
||||
if offset + len > data.len() {
|
||||
break;
|
||||
}
|
||||
|
||||
let payload = &data[offset..offset + len];
|
||||
offset += len;
|
||||
|
||||
// Try to parse as a nested message
|
||||
let nested = decode_proto(payload);
|
||||
if !nested.is_empty() && looks_like_valid_message(&nested, payload.len()) {
|
||||
ProtoValue::Message(nested)
|
||||
} else {
|
||||
ProtoValue::Bytes(payload.to_vec())
|
||||
}
|
||||
}
|
||||
5 => {
|
||||
// 32-bit fixed
|
||||
if offset + 4 > data.len() {
|
||||
break;
|
||||
}
|
||||
let val = u32::from_le_bytes(data[offset..offset + 4].try_into().unwrap());
|
||||
offset += 4;
|
||||
ProtoValue::Fixed32(val)
|
||||
}
|
||||
_ => {
|
||||
// Unknown wire type — stop parsing
|
||||
break;
|
||||
}
|
||||
};
|
||||
|
||||
fields.push(ProtoField {
|
||||
number: field_number,
|
||||
value,
|
||||
});
|
||||
}
|
||||
|
||||
fields
|
||||
}
|
||||
|
||||
/// Heuristic: does this list of fields look like a valid protobuf message?
|
||||
/// (vs. a random string that happened to partially decode)
|
||||
fn looks_like_valid_message(fields: &[ProtoField], original_len: usize) -> bool {
|
||||
if fields.is_empty() {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Check that field numbers are reasonable (< 10000)
|
||||
let valid_numbers = fields.iter().all(|f| f.number < 10000);
|
||||
if !valid_numbers {
|
||||
return false;
|
||||
}
|
||||
|
||||
// If we have very few fields relative to the data size, it's probably not a message
|
||||
// (e.g., a long string that happened to have a valid first-field prefix)
|
||||
if fields.len() == 1 && original_len > 100 {
|
||||
// Single-field messages of >100 bytes are suspicious unless the field is bytes/message
|
||||
match &fields[0].value {
|
||||
ProtoValue::Bytes(_) | ProtoValue::Message(_) => true,
|
||||
_ => false,
|
||||
}
|
||||
} else {
|
||||
true
|
||||
}
|
||||
}
|
||||
|
||||
/// Read a varint from the start of `data`.
///
/// Returns `Some((value, bytes_consumed))`, or `None` when `data` is empty, when
/// the input ends before a byte without the continuation bit is seen, or when
/// no such byte appears within the first 10 bytes.
pub fn read_varint(data: &[u8]) -> Option<(u64, usize)> {
    let mut value: u64 = 0;

    // A 64-bit varint occupies at most 10 bytes; each byte carries 7 payload
    // bits, least-significant group first.
    for (index, byte) in data.iter().copied().take(10).enumerate() {
        value |= u64::from(byte & 0x7F) << (7 * index);

        // A clear high bit marks the final byte of the varint.
        if byte < 0x80 {
            return Some((value, index + 1));
        }
    }

    None
}
|
||||
|
||||
/// Search a decoded protobuf message tree for usage-like structures.
|
||||
///
|
||||
/// Uses the exact field numbers from the reverse-engineered ModelUsageStats schema:
|
||||
///
|
||||
/// field 1: model (enum/varint)
|
||||
/// field 2: input_tokens (uint64)
|
||||
/// field 3: output_tokens (uint64)
|
||||
/// field 4: cache_write_tokens (uint64)
|
||||
/// field 5: cache_read_tokens (uint64)
|
||||
/// field 6: api_provider (enum/varint)
|
||||
/// field 7: message_id (string)
|
||||
/// field 8: response_header (map, repeated message)
|
||||
/// field 9: thinking_output_tokens (uint64)
|
||||
/// field 10: response_output_tokens (uint64)
|
||||
/// field 11: response_id (string)
|
||||
pub fn extract_usage_from_proto(fields: &[ProtoField]) -> Option<GrpcUsage> {
|
||||
// Strategy: recursively search for any sub-message that looks like usage data
|
||||
// Try this level first
|
||||
if let Some(usage) = try_extract_usage(fields) {
|
||||
return Some(usage);
|
||||
}
|
||||
|
||||
// Recurse into nested messages
|
||||
for field in fields {
|
||||
if let ProtoValue::Message(ref nested) = field.value {
|
||||
if let Some(usage) = extract_usage_from_proto(nested) {
|
||||
return Some(usage);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
|
||||
/// Try to extract usage from this specific set of fields.
|
||||
///
|
||||
/// Uses verified field numbers from the binary's embedded proto descriptor.
|
||||
fn try_extract_usage(fields: &[ProtoField]) -> Option<GrpcUsage> {
|
||||
// We need:
|
||||
// - At least 2 varint fields with values in token range
|
||||
// - Ideally field 2 (input_tokens) or field 3 (output_tokens) present
|
||||
let varint_fields: Vec<_> = fields
|
||||
.iter()
|
||||
.filter(|f| matches!(f.value, ProtoValue::Varint(_)))
|
||||
.collect();
|
||||
|
||||
let string_fields: Vec<_> = fields
|
||||
.iter()
|
||||
.filter_map(|f| {
|
||||
if let ProtoValue::Bytes(ref b) = f.value {
|
||||
std::str::from_utf8(b).ok().map(|s| (f.number, s.to_string()))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
// Need at least 2 varint fields to be a candidate
|
||||
if varint_fields.len() < 2 {
|
||||
return None;
|
||||
}
|
||||
|
||||
// Check if the varint values make sense as token counts
|
||||
let plausible_token_count = |v: u64| v <= 10_000_000;
|
||||
let plausible_varints = varint_fields
|
||||
.iter()
|
||||
.filter(|f| {
|
||||
if let ProtoValue::Varint(v) = f.value {
|
||||
plausible_token_count(v) && v > 0
|
||||
} else {
|
||||
false
|
||||
}
|
||||
})
|
||||
.count();
|
||||
|
||||
// Need at least 2 non-zero plausible values
|
||||
if plausible_varints < 2 {
|
||||
return None;
|
||||
}
|
||||
|
||||
// Check if there's a model-like string (field 7 = message_id or field 11 = response_id
|
||||
// can contain model names, or model enum values map to known names)
|
||||
let has_model_string = string_fields.iter().any(|(_, s)| {
|
||||
s.contains("claude") || s.contains("gemini") || s.contains("gpt")
|
||||
|| s.starts_with("models/") || s.contains("sonnet") || s.contains("opus")
|
||||
|| s.contains("flash") || s.contains("pro")
|
||||
});
|
||||
|
||||
// Check for fields at the known ModelUsageStats field numbers
|
||||
let has_field_2 = fields.iter().any(|f| f.number == 2 && matches!(f.value, ProtoValue::Varint(_)));
|
||||
let has_field_3 = fields.iter().any(|f| f.number == 3 && matches!(f.value, ProtoValue::Varint(_)));
|
||||
|
||||
// Strong signal: has both input and output token fields
|
||||
let is_likely_usage = (has_field_2 && has_field_3) || has_model_string;
|
||||
|
||||
if !is_likely_usage && varint_fields.len() < 3 {
|
||||
// Without strong signal, need more fields
|
||||
return None;
|
||||
}
|
||||
|
||||
// Build usage from exact field numbers (verified from binary)
|
||||
let mut usage = GrpcUsage::default();
|
||||
|
||||
for field in fields {
|
||||
match &field.value {
|
||||
ProtoValue::Varint(v) => {
|
||||
let v = *v;
|
||||
if !plausible_token_count(v) {
|
||||
continue;
|
||||
}
|
||||
match field.number {
|
||||
// field 1 = model enum (varint, not string!)
|
||||
2 => usage.input_tokens = v,
|
||||
3 => usage.output_tokens = v,
|
||||
4 => usage.cache_write_tokens = v, // VERIFIED: field 4
|
||||
5 => usage.cache_read_tokens = v, // VERIFIED: field 5
|
||||
// field 6 = api_provider enum (varint)
|
||||
9 => usage.thinking_output_tokens = v, // VERIFIED: field 9
|
||||
10 => usage.response_output_tokens = v, // VERIFIED: field 10
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
ProtoValue::Bytes(ref b) => {
|
||||
if let Ok(s) = std::str::from_utf8(b) {
|
||||
match field.number {
|
||||
7 => usage.message_id = Some(s.to_string()),
|
||||
11 => usage.response_id = Some(s.to_string()),
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
// Model and api_provider are enums (varints), not strings
|
||||
// We can map known enum values later if needed
|
||||
// For now, extract the enum value as a string representation
|
||||
for field in fields {
|
||||
if let ProtoValue::Varint(v) = &field.value {
|
||||
match field.number {
|
||||
1 => {
|
||||
// Model enum — we don't have the mapping, store as number
|
||||
usage.model = Some(format!("model_enum_{v}"));
|
||||
}
|
||||
6 => {
|
||||
// APIProvider enum
|
||||
usage.api_provider = Some(match *v {
|
||||
0 => "unknown".to_string(),
|
||||
1 => "google".to_string(),
|
||||
2 => "anthropic".to_string(),
|
||||
_ => format!("provider_{v}"),
|
||||
});
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Validate — we should have at least input OR output tokens
|
||||
if usage.input_tokens == 0 && usage.output_tokens == 0 {
|
||||
return None;
|
||||
}
|
||||
|
||||
debug!(
|
||||
input = usage.input_tokens,
|
||||
output = usage.output_tokens,
|
||||
thinking = usage.thinking_output_tokens,
|
||||
response = usage.response_output_tokens,
|
||||
cache_read = usage.cache_read_tokens,
|
||||
cache_write = usage.cache_write_tokens,
|
||||
model = ?usage.model,
|
||||
api_provider = ?usage.api_provider,
|
||||
"Proto: extracted ModelUsageStats from protobuf"
|
||||
);
|
||||
|
||||
Some(usage)
|
||||
}
|
||||
|
||||
/// Parse a gRPC response body (may contain multiple messages) for usage data.
|
||||
///
|
||||
/// Handles both compressed and uncompressed gRPC frames.
|
||||
pub fn parse_grpc_response_for_usage(body: &[u8]) -> Option<GrpcUsage> {
|
||||
let messages = extract_grpc_messages(body);
|
||||
|
||||
trace!(count = messages.len(), "Proto: extracted gRPC messages");
|
||||
|
||||
// Check each message for usage data (last message usually has it)
|
||||
for msg in messages.iter().rev() {
|
||||
let fields = decode_proto(msg);
|
||||
if let Some(usage) = extract_usage_from_proto(&fields) {
|
||||
return Some(usage);
|
||||
}
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// Encoded message: field 1 = varint 150, field 2 = varint 66.
    const SAMPLE_PAYLOAD: [u8; 5] = [0x08, 0x96, 0x01, 0x10, 0x42];

    /// Append `value` to `buf` as a varint.
    fn push_varint(buf: &mut Vec<u8>, mut value: u64) {
        while value >= 0x80 {
            buf.push((value as u8) | 0x80);
            value >>= 7;
        }
        buf.push(value as u8);
    }

    /// Append a wire-type-0 field (tag followed by varint value) to `buf`.
    fn push_varint_field(buf: &mut Vec<u8>, field_num: u32, value: u64) {
        push_varint(buf, u64::from(field_num << 3));
        push_varint(buf, value);
    }

    /// Wrap `payload` in a gRPC frame header carrying the given flag byte.
    fn grpc_frame(flag: u8, payload: &[u8]) -> Vec<u8> {
        let mut frame = vec![flag];
        frame.extend_from_slice(&(payload.len() as u32).to_be_bytes());
        frame.extend_from_slice(payload);
        frame
    }

    #[test]
    fn test_read_varint() {
        let cases: [(&[u8], (u64, usize)); 4] = [
            (&[0x00], (0, 1)),
            (&[0x01], (1, 1)),
            (&[0x96, 0x01], (150, 2)),
            (&[0xAC, 0x02], (300, 2)),
        ];
        for (input, expected) in cases.iter() {
            assert_eq!(read_varint(input), Some(*expected));
        }
    }

    #[test]
    fn test_extract_grpc_messages_uncompressed() {
        // Flag 0x00 = not compressed, followed by a 5-byte payload.
        let buf = grpc_frame(0, &SAMPLE_PAYLOAD);

        let messages = extract_grpc_messages(&buf);
        assert_eq!(messages.len(), 1);
        assert_eq!(messages[0].len(), 5);
    }

    #[test]
    fn test_extract_grpc_messages_compressed() {
        use flate2::write::GzEncoder;
        use flate2::Compression;
        use std::io::Write;

        let payload = SAMPLE_PAYLOAD.to_vec();

        // Gzip the payload and frame it with the compressed flag set.
        let mut encoder = GzEncoder::new(Vec::new(), Compression::default());
        encoder.write_all(&payload).unwrap();
        let compressed = encoder.finish().unwrap();
        let buf = grpc_frame(1, &compressed);

        let messages = extract_grpc_messages(&buf);
        assert_eq!(messages.len(), 1);
        assert_eq!(messages[0], payload);
    }

    #[test]
    fn test_decode_proto_varints() {
        let fields = decode_proto(&SAMPLE_PAYLOAD);

        assert_eq!(fields.len(), 2);
        assert_eq!(fields[0].number, 1);
        assert!(matches!(fields[0].value, ProtoValue::Varint(150)));
        assert_eq!(fields[1].number, 2);
        assert!(matches!(fields[1].value, ProtoValue::Varint(66)));
    }

    #[test]
    fn test_decode_proto_with_string() {
        // field 1 = "hello" (wire type 2), field 2 = varint 42
        let mut data = vec![
            0x0A, // (1 << 3) | 2
            0x05, // length 5
        ];
        data.extend_from_slice(b"hello");
        data.extend_from_slice(&[
            0x10, // (2 << 3) | 0
            0x2A, // 42
        ]);

        let fields = decode_proto(&data);
        assert!(fields.len() >= 2);
        assert_eq!(fields[0].number, 1);
    }

    #[test]
    fn test_extract_usage_correct_field_numbers() {
        // Mock ModelUsageStats as (field number, value) pairs.
        let entries: [(u32, u64); 7] = [
            (1, 5),     // model enum
            (2, 1000),  // input_tokens
            (3, 500),   // output_tokens
            (4, 100),   // cache_write_tokens
            (5, 200),   // cache_read_tokens
            (9, 300),   // thinking_output_tokens
            (10, 200),  // response_output_tokens
        ];

        let mut data = Vec::new();
        for &(field_num, value) in entries.iter() {
            push_varint_field(&mut data, field_num, value);
        }

        let fields = decode_proto(&data);
        let usage = try_extract_usage(&fields).expect("should extract usage");

        assert_eq!(usage.input_tokens, 1000);
        assert_eq!(usage.output_tokens, 500);
        assert_eq!(usage.cache_write_tokens, 100);
        assert_eq!(usage.cache_read_tokens, 200);
        assert_eq!(usage.thinking_output_tokens, 300);
        assert_eq!(usage.response_output_tokens, 200);
    }
}
|
||||
Reference in New Issue
Block a user