650 lines
26 KiB
Rust
650 lines
26 KiB
Rust
//! API response interceptor: parses Anthropic/Google API responses to extract usage data.
//!
//! Handles both streaming (SSE; Anthropic and Google event formats) and non-streaming
//! (JSON; Anthropic Messages format) responses.
|
|
|
|
use super::store::{ApiUsage, CapturedFunctionCall};
|
|
use serde_json::Value;
|
|
use tracing::{debug, info, trace};
|
|
|
|
/// Parse a complete (non-streaming) Anthropic Messages API response body.
|
|
///
|
|
/// Response format:
|
|
/// ```json
|
|
/// {
|
|
/// "id": "msg_...",
|
|
/// "type": "message",
|
|
/// "model": "claude-sonnet-4-20250514",
|
|
/// "usage": {
|
|
/// "input_tokens": 1234,
|
|
/// "output_tokens": 567,
|
|
/// "cache_creation_input_tokens": 0,
|
|
/// "cache_read_input_tokens": 890
|
|
/// },
|
|
/// "stop_reason": "end_turn"
|
|
/// }
|
|
/// ```
|
|
pub fn parse_non_streaming_response(body: &[u8]) -> Option<ApiUsage> {
|
|
let json: Value = serde_json::from_slice(body).ok()?;
|
|
extract_usage_from_message(&json)
|
|
}
|
|
|
|
/// Parse SSE events from a streaming response body chunk.
|
|
///
|
|
/// Handles chunked transfer encoding where JSON data may be split across
|
|
/// TCP reads. Buffers raw data in the accumulator and only parses
|
|
/// complete newline-terminated lines.
|
|
pub fn parse_streaming_chunk(chunk: &str, accumulator: &mut StreamingAccumulator) {
|
|
accumulator.pending_data.push_str(chunk);
|
|
|
|
// Extract and process all complete lines (terminated by \n).
|
|
// Leave any trailing partial line in the buffer for the next read.
|
|
loop {
|
|
let pos = match accumulator.pending_data.find('\n') {
|
|
Some(p) => p,
|
|
None => break,
|
|
};
|
|
|
|
let line = accumulator.pending_data[..pos]
|
|
.trim_end_matches('\r')
|
|
.to_string();
|
|
accumulator.pending_data = accumulator.pending_data[pos + 1..].to_string();
|
|
|
|
// Skip empty lines and chunked TE size lines (pure hex)
|
|
let t = line.trim();
|
|
if t.is_empty() || t.chars().all(|c| c.is_ascii_hexdigit()) {
|
|
continue;
|
|
}
|
|
|
|
if let Some(data) = t.strip_prefix("data: ") {
|
|
if data.trim() == "[DONE]" {
|
|
continue;
|
|
}
|
|
if let Ok(event) = serde_json::from_str::<Value>(data) {
|
|
accumulator.process_event(&event);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Accumulates usage data across streaming SSE events.
///
/// Populated event by event through `StreamingAccumulator::process_event`
/// (Google and Anthropic event formats) and turned into an `ApiUsage` record
/// by `StreamingAccumulator::into_usage`.
#[derive(Debug, Default)]
pub struct StreamingAccumulator {
    /// Prompt token count (Google `promptTokenCount`, Anthropic `input_tokens`).
    pub input_tokens: u64,
    /// Generated token count (Google `candidatesTokenCount`, Anthropic `output_tokens`).
    pub output_tokens: u64,
    /// Anthropic `cache_creation_input_tokens`; the Google path does not set it.
    pub cache_creation_input_tokens: u64,
    /// Anthropic `cache_read_input_tokens`; the Google path does not set it.
    pub cache_read_input_tokens: u64,
    /// Google `thoughtsTokenCount`; the Anthropic path does not set it.
    pub thinking_tokens: u64,
    /// Accumulated thinking/reasoning text from the model.
    pub thinking_text: String,
    /// Accumulated response text (non-thinking parts).
    /// Used to identify "thinking summary" calls in the v1internal API.
    pub response_text: String,
    /// Model name (Google `modelVersion`, Anthropic `message.model`).
    pub model: Option<String>,
    /// Google `finishReason` or Anthropic `delta.stop_reason`.
    pub stop_reason: Option<String>,
    /// Set once a Google `finishReason` or an Anthropic `message_stop` is seen.
    pub is_complete: bool,
    /// `"google"` or `"anthropic"`, depending on which event format was processed.
    pub api_provider: Option<String>,
    /// Captured function calls from Google's response.
    pub function_calls: Vec<CapturedFunctionCall>,
    /// Captured grounding metadata from Google Search grounding
    /// (the most recently seen candidate's `groundingMetadata`).
    pub grounding_metadata: Option<serde_json::Value>,
    /// Buffer for reassembling lines split across TCP reads: holds the trailing
    /// partial line between `parse_streaming_chunk` calls.
    pub pending_data: String,
    /// Thinking signature (base64 opaque blob) from non-function-call response parts.
    /// The last signature seen wins.
    pub thinking_signature: Option<String>,
}
|
|
|
|
impl StreamingAccumulator {
|
|
pub fn new() -> Self {
|
|
Self::default()
|
|
}
|
|
|
|
/// Process a single SSE event.
|
|
pub fn process_event(&mut self, event: &Value) {
|
|
// ── Google format: {"response": {"usageMetadata": {...}, "modelVersion": "..."}} ──
|
|
if let Some(response) = event.get("response") {
|
|
// Extract usage metadata (each event has cumulative counts)
|
|
if let Some(usage) = response.get("usageMetadata") {
|
|
self.input_tokens = usage["promptTokenCount"]
|
|
.as_u64()
|
|
.unwrap_or(self.input_tokens);
|
|
self.output_tokens = usage["candidatesTokenCount"]
|
|
.as_u64()
|
|
.unwrap_or(self.output_tokens);
|
|
self.thinking_tokens = usage["thoughtsTokenCount"]
|
|
.as_u64()
|
|
.unwrap_or(self.thinking_tokens);
|
|
}
|
|
if let Some(model) = response["modelVersion"].as_str() {
|
|
self.model = Some(model.to_string());
|
|
}
|
|
if let Some(candidates) = response.get("candidates").and_then(|c| c.as_array()) {
|
|
for candidate in candidates {
|
|
if let Some(parts) = candidate["content"]["parts"].as_array() {
|
|
for part in parts {
|
|
// Public Gemini API: explicit thought flag
|
|
if part["thought"].as_bool() == Some(true) {
|
|
if let Some(text) = part["text"].as_str() {
|
|
self.thinking_text.push_str(text);
|
|
}
|
|
}
|
|
// Detect functionCall from Google (tool call response)
|
|
else if let Some(fc) = part.get("functionCall") {
|
|
let name = fc["name"].as_str().unwrap_or("unknown").to_string();
|
|
let args = fc["args"].clone();
|
|
// thoughtSignature is a SIBLING of functionCall in the part,
|
|
// not nested inside functionCall
|
|
let thought_signature = part
|
|
.get("thoughtSignature")
|
|
.and_then(|v| v.as_str())
|
|
.map(|s| s.to_string());
|
|
info!(
|
|
tool_name = %name,
|
|
tool_args = %args,
|
|
has_thought_sig = thought_signature.is_some(),
|
|
"MITM: Google returned functionCall!"
|
|
);
|
|
self.function_calls.push(CapturedFunctionCall {
|
|
name,
|
|
args,
|
|
thought_signature,
|
|
captured_at: std::time::SystemTime::now()
|
|
.duration_since(std::time::UNIX_EPOCH)
|
|
.unwrap_or_default()
|
|
.as_secs(),
|
|
});
|
|
}
|
|
// Capture non-thinking response text
|
|
else {
|
|
// Capture thoughtSignature from response parts (not function call parts)
|
|
if let Some(sig) =
|
|
part.get("thoughtSignature").and_then(|v| v.as_str())
|
|
{
|
|
self.thinking_signature = Some(sig.to_string());
|
|
}
|
|
if let Some(text) = part["text"].as_str() {
|
|
if !text.is_empty() {
|
|
self.response_text.push_str(text);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
// Check for completion — any finishReason means response is done
|
|
if let Some(reason) = candidate["finishReason"].as_str() {
|
|
self.stop_reason = Some(reason.to_string());
|
|
self.is_complete = true;
|
|
// Log non-STOP finish reasons
|
|
if reason != "STOP" {
|
|
info!(finish_reason = reason, "MITM: non-STOP finish reason");
|
|
}
|
|
}
|
|
// Capture grounding metadata (Google Search grounding results)
|
|
if let Some(gm) = candidate.get("groundingMetadata") {
|
|
self.grounding_metadata = Some(gm.clone());
|
|
debug!(
|
|
has_search_queries = gm.get("searchEntryPoint").is_some(),
|
|
has_web_results = gm.get("groundingChunks").is_some(),
|
|
"MITM: captured grounding metadata"
|
|
);
|
|
}
|
|
}
|
|
}
|
|
self.api_provider = Some("google".to_string());
|
|
trace!(
|
|
input = self.input_tokens,
|
|
output = self.output_tokens,
|
|
thinking = self.thinking_tokens,
|
|
thinking_text_len = self.thinking_text.len(),
|
|
complete = self.is_complete,
|
|
"SSE Google: usage update"
|
|
);
|
|
return;
|
|
}
|
|
|
|
// ── Anthropic format: {"type": "message_start"|"message_delta"|"message_stop"} ──
|
|
let event_type = event["type"].as_str().unwrap_or("");
|
|
|
|
match event_type {
|
|
"message_start" => {
|
|
if let Some(usage) = event.get("message").and_then(|m| m.get("usage")) {
|
|
self.input_tokens = usage["input_tokens"].as_u64().unwrap_or(0);
|
|
self.cache_creation_input_tokens =
|
|
usage["cache_creation_input_tokens"].as_u64().unwrap_or(0);
|
|
self.cache_read_input_tokens =
|
|
usage["cache_read_input_tokens"].as_u64().unwrap_or(0);
|
|
}
|
|
if let Some(model) = event.get("message").and_then(|m| m["model"].as_str()) {
|
|
self.model = Some(model.to_string());
|
|
}
|
|
self.api_provider = Some("anthropic".to_string());
|
|
trace!(input = self.input_tokens, "SSE Anthropic: message_start");
|
|
}
|
|
"message_delta" => {
|
|
if let Some(usage) = event.get("usage") {
|
|
self.output_tokens = usage["output_tokens"]
|
|
.as_u64()
|
|
.unwrap_or(self.output_tokens);
|
|
}
|
|
if let Some(reason) = event["delta"]["stop_reason"].as_str() {
|
|
self.stop_reason = Some(reason.to_string());
|
|
}
|
|
}
|
|
"message_stop" => {
|
|
self.is_complete = true;
|
|
debug!(
|
|
input = self.input_tokens,
|
|
output = self.output_tokens,
|
|
model = ?self.model,
|
|
"SSE Anthropic: stream complete"
|
|
);
|
|
}
|
|
// Anthropic thinking content blocks
|
|
"content_block_delta" => {
|
|
// type: "thinking" delta contains thinking text
|
|
if event["delta"]["type"].as_str() == Some("thinking_delta") {
|
|
if let Some(text) = event["delta"]["thinking"].as_str() {
|
|
self.thinking_text.push_str(text);
|
|
}
|
|
}
|
|
}
|
|
"content_block_start" | "content_block_stop" | "ping" => {}
|
|
_ => {
|
|
trace!(event_type, "SSE: unknown event type");
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Convert accumulated data to an ApiUsage.
|
|
pub fn into_usage(self) -> ApiUsage {
|
|
let thinking_text = if self.thinking_text.is_empty() {
|
|
None
|
|
} else {
|
|
Some(self.thinking_text)
|
|
};
|
|
let response_text = if self.response_text.is_empty() {
|
|
None
|
|
} else {
|
|
Some(self.response_text)
|
|
};
|
|
ApiUsage {
|
|
input_tokens: self.input_tokens,
|
|
output_tokens: self.output_tokens,
|
|
cache_creation_input_tokens: self.cache_creation_input_tokens,
|
|
cache_read_input_tokens: self.cache_read_input_tokens,
|
|
thinking_output_tokens: self.thinking_tokens,
|
|
thinking_text,
|
|
response_text,
|
|
response_output_tokens: 0,
|
|
model: self.model,
|
|
stop_reason: self.stop_reason,
|
|
api_provider: self
|
|
.api_provider
|
|
.unwrap_or_else(|| "unknown".to_string())
|
|
.into(),
|
|
grpc_method: None,
|
|
captured_at: std::time::SystemTime::now()
|
|
.duration_since(std::time::UNIX_EPOCH)
|
|
.unwrap_or_default()
|
|
.as_secs(),
|
|
thinking_signature: self.thinking_signature,
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Extract usage from a complete Message JSON object.
|
|
fn extract_usage_from_message(msg: &Value) -> Option<ApiUsage> {
|
|
let usage = msg.get("usage")?;
|
|
|
|
Some(ApiUsage {
|
|
input_tokens: usage["input_tokens"].as_u64().unwrap_or(0),
|
|
output_tokens: usage["output_tokens"].as_u64().unwrap_or(0),
|
|
cache_creation_input_tokens: usage["cache_creation_input_tokens"].as_u64().unwrap_or(0),
|
|
cache_read_input_tokens: usage["cache_read_input_tokens"].as_u64().unwrap_or(0),
|
|
thinking_output_tokens: 0,
|
|
thinking_text: None,
|
|
response_text: None,
|
|
response_output_tokens: 0,
|
|
model: msg["model"].as_str().map(|s| s.to_string()),
|
|
stop_reason: msg["stop_reason"].as_str().map(|s| s.to_string()),
|
|
api_provider: Some("anthropic".to_string()),
|
|
grpc_method: None,
|
|
captured_at: std::time::SystemTime::now()
|
|
.duration_since(std::time::UNIX_EPOCH)
|
|
.unwrap_or_default()
|
|
.as_secs(),
|
|
thinking_signature: None,
|
|
})
|
|
}
|
|
|
|
/// Try to identify a cascade ID from the request body.
|
|
///
|
|
/// Priority:
|
|
/// 1. `<cid:UUID>` marker embedded by our proxy in the user message content
|
|
/// 2. `requestId` field: `agent/{timestamp}/{cascade_uuid}/{sequence}` format
|
|
/// 3. `metadata.user_id` fallback
|
|
pub fn extract_cascade_hint(request_body: &[u8]) -> Option<String> {
|
|
// Fast path: look for <cid:UUID> marker in raw bytes (avoid JSON parse)
|
|
let body_str = std::str::from_utf8(request_body).ok()?;
|
|
if let Some(start) = body_str.find("<cid:") {
|
|
let rest = &body_str[start + 5..];
|
|
if let Some(end) = rest.find('>') {
|
|
let candidate = &rest[..end];
|
|
// Validate UUID format
|
|
if candidate.len() == 36
|
|
&& candidate.chars().filter(|c| *c == '-').count() == 4
|
|
&& candidate.chars().all(|c| c.is_ascii_hexdigit() || c == '-')
|
|
{
|
|
return Some(candidate.to_string());
|
|
}
|
|
}
|
|
}
|
|
|
|
let json: Value = serde_json::from_slice(request_body).ok()?;
|
|
|
|
// Secondary: extract cascade UUID from requestId field
|
|
// Format: "agent/{timestamp}/{cascade_uuid}/{sequence}"
|
|
if let Some(request_id) = json.get("requestId").and_then(|v| v.as_str()) {
|
|
let parts: Vec<&str> = request_id.split('/').collect();
|
|
if parts.len() >= 3 {
|
|
let candidate = parts[2];
|
|
if candidate.len() == 36
|
|
&& candidate.chars().filter(|c| *c == '-').count() == 4
|
|
&& candidate.chars().all(|c| c.is_ascii_hexdigit() || c == '-')
|
|
{
|
|
return Some(candidate.to_string());
|
|
}
|
|
}
|
|
}
|
|
|
|
// Fallback: check metadata.user_id
|
|
if let Some(metadata) = json.get("metadata") {
|
|
if let Some(user_id) = metadata["user_id"].as_str() {
|
|
return Some(user_id.to_string());
|
|
}
|
|
}
|
|
|
|
None
|
|
}
|
|
|
|
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parse_non_streaming() {
        let body = r#"{
            "id": "msg_123",
            "type": "message",
            "model": "claude-sonnet-4-20250514",
            "usage": {
                "input_tokens": 100,
                "output_tokens": 50,
                "cache_creation_input_tokens": 10,
                "cache_read_input_tokens": 30
            },
            "stop_reason": "end_turn"
        }"#;

        let usage = parse_non_streaming_response(body.as_bytes()).unwrap();
        assert_eq!(usage.input_tokens, 100);
        assert_eq!(usage.output_tokens, 50);
        assert_eq!(usage.cache_creation_input_tokens, 10);
        assert_eq!(usage.cache_read_input_tokens, 30);
        assert_eq!(usage.model.as_deref(), Some("claude-sonnet-4-20250514"));
    }

    /// Bodies that are not JSON, or carry no `usage` object, yield `None`.
    #[test]
    fn test_parse_non_streaming_rejects_invalid_input() {
        assert!(parse_non_streaming_response(b"not json").is_none());
        assert!(parse_non_streaming_response(br#"{"type": "message"}"#).is_none());
    }

    #[test]
    fn test_streaming_accumulator() {
        let mut acc = StreamingAccumulator::new();

        // message_start
        let start = serde_json::json!({
            "type": "message_start",
            "message": {
                "model": "claude-sonnet-4-20250514",
                "usage": {
                    "input_tokens": 200,
                    "cache_creation_input_tokens": 5,
                    "cache_read_input_tokens": 50
                }
            }
        });
        acc.process_event(&start);
        assert_eq!(acc.input_tokens, 200);
        assert_eq!(acc.cache_read_input_tokens, 50);

        // message_delta
        let delta = serde_json::json!({
            "type": "message_delta",
            "delta": { "stop_reason": "end_turn" },
            "usage": { "output_tokens": 75 }
        });
        acc.process_event(&delta);
        assert_eq!(acc.output_tokens, 75);

        // message_stop
        let stop = serde_json::json!({ "type": "message_stop" });
        acc.process_event(&stop);
        assert!(acc.is_complete);

        let usage = acc.into_usage();
        assert_eq!(usage.input_tokens, 200);
        assert_eq!(usage.output_tokens, 75);
        assert_eq!(usage.api_provider, Some("anthropic".to_string()));
    }

    /// `thinking_delta` text is concatenated; `text_delta` is not thinking text.
    #[test]
    fn test_anthropic_thinking_delta_accumulates() {
        let mut acc = StreamingAccumulator::new();

        acc.process_event(&serde_json::json!({
            "type": "content_block_delta",
            "index": 0,
            "delta": { "type": "thinking_delta", "thinking": "Let me " }
        }));
        acc.process_event(&serde_json::json!({
            "type": "content_block_delta",
            "index": 0,
            "delta": { "type": "thinking_delta", "thinking": "think" }
        }));
        acc.process_event(&serde_json::json!({
            "type": "content_block_delta",
            "index": 1,
            "delta": { "type": "text_delta", "text": "Answer" }
        }));

        assert_eq!(acc.thinking_text, "Let me think");
        assert!(acc.response_text.is_empty());
    }

    #[test]
    fn test_google_sse_single_event() {
        let mut acc = StreamingAccumulator::new();

        let event = serde_json::json!({
            "response": {
                "candidates": [{"content": {"role": "model", "parts": [{"text": "4"}]}}],
                "usageMetadata": {
                    "promptTokenCount": 1514,
                    "candidatesTokenCount": 25,
                    "totalTokenCount": 1539,
                    "thoughtsTokenCount": 52
                },
                "modelVersion": "gemini-3-flash",
                "responseId": "abc123"
            },
            "traceId": "trace456",
            "metadata": {}
        });

        acc.process_event(&event);
        assert_eq!(acc.input_tokens, 1514);
        assert_eq!(acc.output_tokens, 25);
        assert_eq!(acc.thinking_tokens, 52);
        assert_eq!(acc.model, Some("gemini-3-flash".to_string()));
        assert!(!acc.is_complete); // no finishReason yet
        assert_eq!(acc.api_provider, Some("google".to_string()));
    }

    /// Thought parts, function calls (with sibling thoughtSignature) and plain
    /// text parts (with their own thoughtSignature) land in separate fields.
    #[test]
    fn test_google_parts_are_classified() {
        let mut acc = StreamingAccumulator::new();

        let event = serde_json::json!({
            "response": {
                "candidates": [{"content": {"role": "model", "parts": [
                    {"text": "pondering", "thought": true},
                    {"functionCall": {"name": "ls", "args": {"dir": "/"}}, "thoughtSignature": "sig-fc"},
                    {"text": "done", "thoughtSignature": "sig-text"}
                ]}}]
            }
        });

        acc.process_event(&event);
        assert_eq!(acc.thinking_text, "pondering");
        assert_eq!(acc.response_text, "done");
        assert_eq!(acc.function_calls.len(), 1);
        assert_eq!(acc.function_calls[0].name, "ls");
        assert_eq!(acc.function_calls[0].args, serde_json::json!({"dir": "/"}));
        assert_eq!(
            acc.function_calls[0].thought_signature.as_deref(),
            Some("sig-fc")
        );
        assert_eq!(acc.thinking_signature.as_deref(), Some("sig-text"));
    }

    #[test]
    fn test_google_sse_multi_event_accumulation() {
        let mut acc = StreamingAccumulator::new();

        // First event — partial response
        let event1 = serde_json::json!({
            "response": {
                "candidates": [{"content": {"role": "model", "parts": [{"text": "Hello"}]}}],
                "usageMetadata": {
                    "promptTokenCount": 1514,
                    "candidatesTokenCount": 6,
                    "totalTokenCount": 1520
                },
                "modelVersion": "gemini-2.5-flash-lite"
            },
            "traceId": "t1",
            "metadata": {}
        });
        acc.process_event(&event1);
        assert_eq!(acc.output_tokens, 6);
        assert!(!acc.is_complete);

        // Second event — more output
        let event2 = serde_json::json!({
            "response": {
                "candidates": [{"content": {"role": "model", "parts": [{"text": " world"}]}}],
                "usageMetadata": {
                    "promptTokenCount": 1514,
                    "candidatesTokenCount": 22,
                    "totalTokenCount": 1536
                },
                "modelVersion": "gemini-2.5-flash-lite"
            },
            "traceId": "t1",
            "metadata": {}
        });
        acc.process_event(&event2);
        assert_eq!(acc.output_tokens, 22); // cumulative, not additive

        // Third event — completion
        let event3 = serde_json::json!({
            "response": {
                "candidates": [{"content": {"role": "model", "parts": [{"text": "!"}]},
                                "finishReason": "STOP"}],
                "usageMetadata": {
                    "promptTokenCount": 1514,
                    "candidatesTokenCount": 25,
                    "totalTokenCount": 1539,
                    "thoughtsTokenCount": 52
                },
                "modelVersion": "gemini-2.5-flash-lite"
            },
            "traceId": "t1",
            "metadata": {}
        });
        acc.process_event(&event3);
        assert!(acc.is_complete);
        assert_eq!(acc.output_tokens, 25);
        assert_eq!(acc.thinking_tokens, 52);
        assert_eq!(acc.stop_reason, Some("STOP".to_string()));

        let usage = acc.into_usage();
        assert_eq!(usage.input_tokens, 1514);
        assert_eq!(usage.output_tokens, 25);
        assert_eq!(usage.thinking_output_tokens, 52);
        assert_eq!(usage.model, Some("gemini-2.5-flash-lite".to_string()));
        assert_eq!(usage.api_provider, Some("google".to_string()));
    }

    #[test]
    fn test_google_sse_parse_streaming_chunk() {
        // Simulates real SSE data with HTTP chunked framing (hex sizes on their own lines)
        let chunk = r#"150
data: {"response": {"candidates": [{"content": {"role": "model","parts": [{"text": "4"}]}}],"usageMetadata": {"promptTokenCount": 14615,"candidatesTokenCount": 1,"totalTokenCount": 14668,"thoughtsTokenCount": 52},"modelVersion": "gemini-3-flash","responseId": "agaRacPLC4WHz7IPreOl8QM"},"traceId": "8145be7112baf823","metadata": {}}


2f1
data: {"response": {"candidates": [{"content": {"role": "model","parts": [{"text": ""}]},"finishReason": "STOP"}],"usageMetadata": {"promptTokenCount": 14615,"candidatesTokenCount": 1,"totalTokenCount": 14668,"thoughtsTokenCount": 52},"modelVersion": "gemini-3-flash","responseId": "agaRacPLC4WHz7IPreOl8QM"},"traceId": "8145be7112baf823","metadata": {}}


0

"#;

        let mut acc = StreamingAccumulator::new();
        parse_streaming_chunk(chunk, &mut acc);

        assert_eq!(acc.input_tokens, 14615);
        assert_eq!(acc.output_tokens, 1);
        assert_eq!(acc.thinking_tokens, 52);
        assert!(acc.is_complete);
        assert_eq!(acc.model, Some("gemini-3-flash".to_string()));
        assert_eq!(acc.stop_reason, Some("STOP".to_string()));
    }

    /// Non-data SSE fields, comments and the `[DONE]` sentinel are ignored,
    /// and fully consumed lines do not linger in the pending buffer.
    #[test]
    fn test_non_data_lines_and_done_sentinel_are_ignored() {
        let mut acc = StreamingAccumulator::new();
        parse_streaming_chunk("event: ping\n: keep-alive\ndata: [DONE]\n", &mut acc);

        assert!(acc.api_provider.is_none());
        assert!(!acc.is_complete);
        assert!(acc.pending_data.is_empty());
    }

    #[test]
    fn test_google_sse_no_thinking_tokens() {
        let mut acc = StreamingAccumulator::new();

        let event = serde_json::json!({
            "response": {
                "candidates": [{"content": {"role": "model", "parts": [{"text": "hi"}]},
                                "finishReason": "STOP"}],
                "usageMetadata": {
                    "promptTokenCount": 100,
                    "candidatesTokenCount": 5,
                    "totalTokenCount": 105
                },
                "modelVersion": "gemini-2.5-flash-lite"
            },
            "traceId": "t1",
            "metadata": {}
        });

        acc.process_event(&event);
        assert_eq!(acc.thinking_tokens, 0); // no thoughtsTokenCount field
        assert!(acc.is_complete);

        let usage = acc.into_usage();
        assert_eq!(usage.thinking_output_tokens, 0);
    }

    /// Regression test: reproduces the exact TCP fragmentation from the SSE dump.
    /// The `data:` line containing `finishReason: STOP` is split across two reads.
    #[test]
    fn test_split_tcp_reads() {
        let mut acc = StreamingAccumulator::new();

        // TCP read 1: complete first event
        let chunk1 = "164\r\ndata: {\"response\": {\"candidates\": [{\"content\": {\"role\": \"model\",\"parts\": [{\"text\": \"yo\"}]}}],\"usageMetadata\": {\"promptTokenCount\": 100,\"candidatesTokenCount\": 1,\"totalTokenCount\": 101},\"modelVersion\": \"gemini-3-flash\"},\"traceId\": \"abc\",\"metadata\": {}}\r\n\r\n\r\n";
        parse_streaming_chunk(chunk1, &mut acc);
        assert_eq!(acc.response_text, "yo");
        assert!(!acc.is_complete); // no finishReason yet

        // TCP read 2: PARTIAL second event — JSON cut mid-traceId
        let chunk2 = "200\r\ndata: {\"response\": {\"candidates\": [{\"content\": {\"role\": \"model\",\"parts\": [{\"text\": \"\"}]},\"finishReason\": \"STOP\"}],\"usageMetadata\": {\"promptTokenCount\": 100,\"candidatesTokenCount\": 1,\"totalTokenCount\": 101},\"modelVersion\": \"gemini-3-flash\"},\"traceId\": \"abc123";
        parse_streaming_chunk(chunk2, &mut acc);
        // Still not complete — the line hasn't ended yet (no \n)
        assert!(
            !acc.is_complete,
            "should NOT be complete yet — JSON line is still partial"
        );

        // TCP read 3: rest of the JSON + chunked TE terminator
        let chunk3 = "def\",\"metadata\": {}}\r\n\r\n\r\n0\r\n\r\n";
        parse_streaming_chunk(chunk3, &mut acc);
        // NOW the line is complete and should be parsed
        assert!(
            acc.is_complete,
            "finishReason: STOP should be detected after reassembly"
        );
        assert_eq!(acc.stop_reason, Some("STOP".to_string()));
    }

    #[test]
    fn test_function_call_finish_reason_sets_complete() {
        let mut acc = StreamingAccumulator::new();

        let event = "data: {\"response\": {\"candidates\": [{\"content\": {\"role\": \"model\", \"parts\": [{\"functionCall\": {\"name\": \"read_file\", \"args\": {\"path\": \"/foo\"}}}]}, \"finishReason\": \"FUNCTION_CALL\"}], \"usageMetadata\": {\"promptTokenCount\": 50, \"candidatesTokenCount\": 5, \"totalTokenCount\": 55}, \"modelVersion\": \"gemini-3-flash\"}}\n";
        parse_streaming_chunk(event, &mut acc);

        assert!(
            acc.is_complete,
            "FUNCTION_CALL finishReason should set is_complete"
        );
        assert_eq!(acc.stop_reason, Some("FUNCTION_CALL".to_string()));
        assert_eq!(acc.function_calls.len(), 1);
        assert_eq!(acc.function_calls[0].name, "read_file");
        assert_eq!(acc.output_tokens, 5);
    }

    #[test]
    fn test_max_tokens_finish_reason_sets_complete() {
        let mut acc = StreamingAccumulator::new();

        let event = "data: {\"response\": {\"candidates\": [{\"content\": {\"role\": \"model\", \"parts\": [{\"text\": \"truncated...\"}]}, \"finishReason\": \"MAX_TOKENS\"}], \"usageMetadata\": {\"promptTokenCount\": 50, \"candidatesTokenCount\": 100, \"totalTokenCount\": 150}}}\n";
        parse_streaming_chunk(event, &mut acc);

        assert!(
            acc.is_complete,
            "MAX_TOKENS finishReason should set is_complete"
        );
        assert_eq!(acc.stop_reason, Some("MAX_TOKENS".to_string()));
        assert_eq!(acc.response_text, "truncated...");
    }

    /// Each cascade-hint source in priority order, plus the `None` cases.
    #[test]
    fn test_extract_cascade_hint_sources() {
        let uuid = "123e4567-e89b-12d3-a456-426614174000";

        // 1. <cid:UUID> marker in the message content.
        let marker_body = format!(
            r#"{{"messages":[{{"content":"hello <cid:{}>"}}]}}"#,
            uuid
        );
        assert_eq!(
            extract_cascade_hint(marker_body.as_bytes()).as_deref(),
            Some(uuid)
        );

        // 2. requestId of the form agent/{timestamp}/{cascade_uuid}/{sequence}.
        let request_id_body = format!(r#"{{"requestId":"agent/1700000000/{}/7"}}"#, uuid);
        assert_eq!(
            extract_cascade_hint(request_id_body.as_bytes()).as_deref(),
            Some(uuid)
        );

        // 3. metadata.user_id fallback.
        let user_id_body = br#"{"metadata":{"user_id":"user-42"}}"#;
        assert_eq!(
            extract_cascade_hint(user_id_body).as_deref(),
            Some("user-42")
        );

        // Nothing usable.
        assert_eq!(extract_cascade_hint(b"{}"), None);
        assert_eq!(extract_cascade_hint(b"not json"), None);
    }
}
|