diff --git a/src/api/completions.rs b/src/api/completions.rs index 33a7c22..211ff65 100644 --- a/src/api/completions.rs +++ b/src/api/completions.rs @@ -10,9 +10,11 @@ use std::sync::Arc; use tracing::{debug, info, warn}; use super::models::{lookup_model, DEFAULT_MODEL, MODELS}; -use super::polling::{extract_response_text, extract_thinking_content, is_response_done, poll_for_response}; +use super::polling::{ + extract_response_text, extract_thinking_content, is_response_done, poll_for_response, +}; use super::types::*; -use super::util::{err_response, upstream_err_response, now_unix}; +use super::util::{err_response, now_unix, upstream_err_response}; use super::AppState; /// Extract a conversation/session ID from a flexible JSON value. @@ -33,7 +35,8 @@ fn system_fingerprint() -> String { /// Build a streaming chunk JSON with all required OpenAI fields. /// Includes system_fingerprint, service_tier, and logprobs:null in choices. fn chunk_json( - id: &str, model: &str, + id: &str, + model: &str, choices: serde_json::Value, usage: Option, ) -> String { @@ -53,7 +56,11 @@ fn chunk_json( } /// Build a single choice for a streaming chunk (delta + finish_reason + logprobs). -fn chunk_choice(index: u32, delta: serde_json::Value, finish_reason: Option<&str>) -> serde_json::Value { +fn chunk_choice( + index: u32, + delta: serde_json::Value, + finish_reason: Option<&str>, +) -> serde_json::Value { serde_json::json!({ "index": index, "delta": delta, @@ -70,7 +77,9 @@ fn chunk_choice(index: u32, delta: serde_json::Value, finish_reason: Option<&str fn google_to_openai_finish_reason(stop_reason: Option<&str>) -> &'static str { match stop_reason { Some("MAX_TOKENS") => "length", - Some("SAFETY") | Some("RECITATION") | Some("BLOCKLIST") | Some("PROHIBITED_CONTENT") => "content_filter", + Some("SAFETY") | Some("RECITATION") | Some("BLOCKLIST") | Some("PROHIBITED_CONTENT") => { + "content_filter" + } _ => "stop", } } @@ -84,7 +93,9 @@ fn google_to_openai_finish_reason(stop_reason: Option<&str>) -> &'static str { /// sends the entire messages array to the model. fn extract_chat_input(messages: &[CompletionMessage]) -> (String, Option) { // Extract image from last user message content array - let image = messages.iter().rev() + let image = messages + .iter() + .rev() .find(|m| m.role == "user") .and_then(|m| super::util::extract_first_image(&m.content)); // Always build the full conversation @@ -141,10 +152,7 @@ fn build_conversation_with_tools(messages: &[CompletionMessage]) -> String { if let Some(func) = tc.get("function") { let name = func["name"].as_str().unwrap_or("unknown"); let args = func["arguments"].as_str().unwrap_or("{}"); - parts.push(format!( - "[Tool call: {}({})]", - name, args - )); + parts.push(format!("[Tool call: {}({})]", name, args)); } } } @@ -153,10 +161,7 @@ fn build_conversation_with_tools(messages: &[CompletionMessage]) -> String { let text = extract_message_text(&msg.content); let tool_id = msg.tool_call_id.as_deref().unwrap_or("unknown"); if !text.is_empty() { - parts.push(format!( - "[Tool result ({})]:\n{}", - tool_id, text - )); + parts.push(format!("[Tool result ({})]:\n{}", tool_id, text)); } } _ => {} @@ -202,7 +207,10 @@ pub(crate) async fn handle_completions( let gemini_config = crate::mitm::modify::openai_tool_choice_to_gemini(choice); state.mitm_store.set_tool_config(gemini_config).await; } - info!(count = tools.len(), "Completions: stored client tools for MITM injection"); + info!( + count = tools.len(), + "Completions: stored client tools for MITM injection" + ); } else { state.mitm_store.clear_tools().await; } @@ -239,10 +247,15 @@ pub(crate) async fn handle_completions( google_search: body.web_search, }; // Only store if at least one param is set - if gp.temperature.is_some() || gp.top_p.is_some() || gp.max_output_tokens.is_some() - || gp.frequency_penalty.is_some() || gp.presence_penalty.is_some() - || gp.reasoning_effort.is_some() || gp.stop_sequences.is_some() - || gp.response_mime_type.is_some() || gp.response_schema.is_some() + if gp.temperature.is_some() + || gp.top_p.is_some() + || gp.max_output_tokens.is_some() + || gp.frequency_penalty.is_some() + || gp.presence_penalty.is_some() + || gp.reasoning_effort.is_some() + || gp.stop_sequences.is_some() + || gp.response_mime_type.is_some() + || gp.response_schema.is_some() || gp.google_search { state.mitm_store.set_generation_params(gp).await; @@ -306,12 +319,13 @@ pub(crate) async fn handle_completions( // Store image for MITM injection (LS doesn't forward images to Google API) if let Some(ref img) = image { use base64::Engine; - state.mitm_store.set_pending_image( - crate::mitm::store::PendingImage { + state + .mitm_store + .set_pending_image(crate::mitm::store::PendingImage { base64_data: base64::engine::general_purpose::STANDARD.encode(&img.data), mime_type: img.mime_type.clone(), - } - ).await; + }) + .await; } match state .backend @@ -346,7 +360,10 @@ pub(crate) async fn handle_completions( uuid::Uuid::new_v4().to_string().replace('-', "") ); - let include_usage = body.stream_options.as_ref().map_or(false, |o| o.include_usage); + let include_usage = body + .stream_options + .as_ref() + .map_or(false, |o| o.include_usage); if body.stream { chat_completions_stream( @@ -374,11 +391,17 @@ pub(crate) async fn handle_completions( match state.backend.create_cascade().await { Ok(cid) => { // Send the same message on each extra cascade - match state.backend.send_message_with_image(&cid, &user_text, model.model_enum, image.as_ref()).await { + match state + .backend + .send_message_with_image(&cid, &user_text, model.model_enum, image.as_ref()) + .await + { Ok((200, _)) => { let bg = Arc::clone(&state.backend); let cid2 = cid.clone(); - tokio::spawn(async move { let _ = bg.update_annotations(&cid2).await; }); + tokio::spawn(async move { + let _ = bg.update_annotations(&cid2).await; + }); extra_cascade_ids.push(cid); } _ => {} // Skip failed cascades @@ -420,7 +443,12 @@ pub(crate) async fn handle_completions( mitm.as_ref().and_then(|u| u.stop_reason.as_deref()), ); let (pt, ct, cached, thinking) = if let Some(ref mu) = mitm { - (mu.input_tokens, mu.output_tokens, mu.cache_read_input_tokens, mu.thinking_output_tokens) + ( + mu.input_tokens, + mu.output_tokens, + mu.cache_read_input_tokens, + mu.thinking_output_tokens, + ) } else if let Some(u) = &result.usage { (u.input_tokens, u.output_tokens, 0, 0) } else { @@ -874,15 +902,22 @@ async fn chat_completions_sync( None => state.mitm_store.take_usage("_latest").await, }; - let finish_reason = google_to_openai_finish_reason(mitm.as_ref().and_then(|u| u.stop_reason.as_deref())); + let finish_reason = + google_to_openai_finish_reason(mitm.as_ref().and_then(|u| u.stop_reason.as_deref())); - let (prompt_tokens, completion_tokens, cached_tokens, thinking_tokens) = if let Some(ref mitm_usage) = mitm { - (mitm_usage.input_tokens, mitm_usage.output_tokens, mitm_usage.cache_read_input_tokens, mitm_usage.thinking_output_tokens) - } else if let Some(u) = &result.usage { - (u.input_tokens, u.output_tokens, 0, 0) - } else { - (0, 0, 0, 0) - }; + let (prompt_tokens, completion_tokens, cached_tokens, thinking_tokens) = + if let Some(ref mitm_usage) = mitm { + ( + mitm_usage.input_tokens, + mitm_usage.output_tokens, + mitm_usage.cache_read_input_tokens, + mitm_usage.thinking_output_tokens, + ) + } else if let Some(u) = &result.usage { + (u.input_tokens, u.output_tokens, 0, 0) + } else { + (0, 0, 0, 0) + }; // Build message object, including reasoning_content if thinking is present let mut message = serde_json::json!({ diff --git a/src/api/gemini.rs b/src/api/gemini.rs index 5f050ea..5eea228 100644 --- a/src/api/gemini.rs +++ b/src/api/gemini.rs @@ -15,7 +15,9 @@ use std::sync::Arc; use tracing::{info, warn}; use super::models::{lookup_model, DEFAULT_MODEL, MODELS}; -use super::polling::{extract_response_text, extract_thinking_content, is_response_done, poll_for_response}; +use super::polling::{ + extract_response_text, extract_thinking_content, is_response_done, poll_for_response, +}; use super::util::{err_response, upstream_err_response}; use super::AppState; use crate::mitm::store::PendingToolResult; @@ -84,7 +86,9 @@ async fn build_usage_metadata( store: &crate::mitm::store::MitmStore, cascade_id: &str, ) -> serde_json::Value { - let usage = store.take_usage(cascade_id).await + let usage = store + .take_usage(cascade_id) + .await .or(store.take_usage("_latest").await); if let Some(usage) = usage { serde_json::json!({ @@ -152,13 +156,12 @@ pub(crate) async fn handle_gemini( // Gemini-native inlineData format if image.is_none() { if let Some(inline) = obj.get("inlineData") { - if let (Some(mime), Some(b64)) = ( - inline["mimeType"].as_str(), - inline["data"].as_str(), - ) { - if let Some(img) = super::util::parse_data_uri( - &format!("data:{mime};base64,{b64}") - ) { + if let (Some(mime), Some(b64)) = + (inline["mimeType"].as_str(), inline["data"].as_str()) + { + if let Some(img) = super::util::parse_data_uri(&format!( + "data:{mime};base64,{b64}" + )) { image = Some(img); } } @@ -194,7 +197,10 @@ pub(crate) async fn handle_gemini( if let Some(ref tools) = body.tools { if !tools.is_empty() { state.mitm_store.set_tools(tools.clone()).await; - info!(count = tools.len(), "Stored Gemini-native tools for MITM injection"); + info!( + count = tools.len(), + "Stored Gemini-native tools for MITM injection" + ); } } if let Some(ref config) = body.tool_config { @@ -207,13 +213,19 @@ pub(crate) async fn handle_gemini( if let Some(fr) = r.get("functionResponse") { let name = fr["name"].as_str().unwrap_or("unknown").to_string(); let response = fr.get("response").cloned().unwrap_or(serde_json::json!({})); - state.mitm_store.add_tool_result(PendingToolResult { - name, - result: response, - }).await; + state + .mitm_store + .add_tool_result(PendingToolResult { + name, + result: response, + }) + .await; } } - info!(count = results.len(), "Stored Gemini-native tool results for MITM injection"); + info!( + count = results.len(), + "Stored Gemini-native tool results for MITM injection" + ); } // Store generation parameters for MITM injection @@ -232,9 +244,13 @@ pub(crate) async fn handle_gemini( response_schema: None, google_search: body.google_search, }; - if gp.temperature.is_some() || gp.top_p.is_some() || gp.top_k.is_some() - || gp.max_output_tokens.is_some() || gp.stop_sequences.is_some() - || gp.reasoning_effort.is_some() || gp.google_search + if gp.temperature.is_some() + || gp.top_p.is_some() + || gp.top_k.is_some() + || gp.max_output_tokens.is_some() + || gp.stop_sequences.is_some() + || gp.reasoning_effort.is_some() + || gp.google_search { state.mitm_store.set_generation_params(gp).await; } else { @@ -277,12 +293,13 @@ pub(crate) async fn handle_gemini( // Store image for MITM injection (LS doesn't forward images to Google API) if let Some(ref img) = image { use base64::Engine; - state.mitm_store.set_pending_image( - crate::mitm::store::PendingImage { + state + .mitm_store + .set_pending_image(crate::mitm::store::PendingImage { base64_data: base64::engine::general_purpose::STANDARD.encode(&img.data), mime_type: img.mime_type.clone(), - } - ).await; + }) + .await; } match state .backend @@ -372,7 +389,11 @@ async fn gemini_sync( // Check for completed text response if state.mitm_store.is_response_complete() { - let text = state.mitm_store.take_response_text().await.unwrap_or_default(); + let text = state + .mitm_store + .take_response_text() + .await + .unwrap_or_default(); let thinking = state.mitm_store.take_thinking_text().await; // Guard against stale response_complete with no data diff --git a/src/api/mod.rs b/src/api/mod.rs index 6ca19b3..29a0f68 100644 --- a/src/api/mod.rs +++ b/src/api/mod.rs @@ -44,7 +44,6 @@ pub fn router(state: Arc) -> Router { post(completions::handle_completions), ) .route("/v1/gemini", post(gemini::handle_gemini)) - .route("/v1/models", get(handle_models)) .route("/v1/sessions", get(handle_list_sessions)) .route("/v1/sessions/{id}", delete(handle_delete_session)) @@ -106,9 +105,7 @@ async fn handle_models() -> Json { Json(serde_json::json!({"object": "list", "data": models})) } -async fn handle_list_sessions( - State(state): State>, -) -> Json { +async fn handle_list_sessions(State(state): State>) -> Json { let sessions = state.sessions.list_sessions().await; Json(serde_json::json!({"sessions": sessions})) } @@ -155,9 +152,7 @@ async fn handle_set_token( ) } -async fn handle_usage( - State(state): State>, -) -> Json { +async fn handle_usage(State(state): State>) -> Json { let stats = state.mitm_store.stats().await; Json(serde_json::json!({ "mitm": { @@ -174,9 +169,7 @@ async fn handle_usage( })) } -async fn handle_quota( - State(state): State>, -) -> Json { +async fn handle_quota(State(state): State>) -> Json { let snap = state.quota_store.snapshot().await; Json(serde_json::to_value(snap).unwrap_or_default()) } diff --git a/src/api/polling.rs b/src/api/polling.rs index 373d6c6..6050ba0 100644 --- a/src/api/polling.rs +++ b/src/api/polling.rs @@ -84,14 +84,8 @@ pub(crate) fn extract_model_usage(steps: &[serde_json::Value]) -> Option 4 && step_count % 5 == 0 { - if let Ok((ts, td)) = state.backend.get_trajectory(cascade_id).await - { + if let Ok((ts, td)) = state.backend.get_trajectory(cascade_id).await { if ts == 200 { - let run_status = - td["status"].as_str().unwrap_or(""); + let run_status = td["status"].as_str().unwrap_or(""); if run_status.contains("IDLE") { let text = extract_response_text(steps); if !text.is_empty() { @@ -293,7 +300,14 @@ pub(crate) async fn poll_for_response( elapsed, text.len() ); - return PollResult { text, usage, thinking_signature, thinking, thinking_duration, upstream_error: None }; + return PollResult { + text, + usage, + thinking_signature, + thinking, + thinking_duration, + upstream_error: None, + }; } } } diff --git a/src/api/responses.rs b/src/api/responses.rs index 6864dcd..38bf973 100644 --- a/src/api/responses.rs +++ b/src/api/responses.rs @@ -14,12 +14,15 @@ use std::sync::Arc; use tracing::{debug, info}; use super::models::{lookup_model, DEFAULT_MODEL, MODELS}; -use super::polling::{extract_response_text, is_response_done, poll_for_response, extract_model_usage, extract_thinking_signature, extract_thinking_content}; +use super::polling::{ + extract_model_usage, extract_response_text, extract_thinking_content, + extract_thinking_signature, is_response_done, poll_for_response, +}; use super::types::*; -use super::util::{err_response, upstream_err_response, now_unix, responses_sse_event}; +use super::util::{err_response, now_unix, responses_sse_event, upstream_err_response}; use super::AppState; +use crate::mitm::modify::{openai_tool_choice_to_gemini, openai_tools_to_gemini}; use crate::mitm::store::PendingToolResult; -use crate::mitm::modify::{openai_tools_to_gemini, openai_tool_choice_to_gemini}; // ─── Input extraction ──────────────────────────────────────────────────────── @@ -35,7 +38,11 @@ struct ToolResultInput { fn extract_responses_input( input: &serde_json::Value, instructions: Option<&str>, -) -> (String, Vec, Option) { +) -> ( + String, + Vec, + Option, +) { let mut tool_results: Vec = Vec::new(); let mut image: Option = None; @@ -45,10 +52,9 @@ fn extract_responses_input( // Check for function_call_output items for item in items { if item["type"].as_str() == Some("function_call_output") { - if let (Some(call_id), Some(output)) = ( - item["call_id"].as_str(), - item["output"].as_str(), - ) { + if let (Some(call_id), Some(output)) = + (item["call_id"].as_str(), item["output"].as_str()) + { tool_results.push(ToolResultInput { call_id: call_id.to_string(), output: output.to_string(), @@ -230,24 +236,31 @@ pub(crate) async fn handle_responses( ); } - let (user_text, tool_results, image) = extract_responses_input(&body.input, body.instructions.as_deref()); + let (user_text, tool_results, image) = + extract_responses_input(&body.input, body.instructions.as_deref()); // Handle tool result submission (function_call_output in input) let is_tool_result_turn = !tool_results.is_empty(); if is_tool_result_turn { for tr in &tool_results { // Look up function name from call_id - let name = state.mitm_store.lookup_call_id(&tr.call_id).await + let name = state + .mitm_store + .lookup_call_id(&tr.call_id) + .await .unwrap_or_else(|| "unknown_function".to_string()); // Parse the output as JSON, fall back to string wrapper let result_value = serde_json::from_str::(&tr.output) .unwrap_or_else(|_| serde_json::json!({"result": tr.output})); - state.mitm_store.add_tool_result(PendingToolResult { - name, - result: result_value, - }).await; + state + .mitm_store + .add_tool_result(PendingToolResult { + name, + result: result_value, + }) + .await; } info!( count = tool_results.len(), @@ -275,7 +288,10 @@ pub(crate) async fn handle_responses( let gemini_tools = openai_tools_to_gemini(tools); if !gemini_tools.is_empty() { state.mitm_store.set_tools(gemini_tools).await; - info!(count = tools.len(), "Stored client tools for MITM injection"); + info!( + count = tools.len(), + "Stored client tools for MITM injection" + ); } } if let Some(ref choice) = body.tool_choice { @@ -289,7 +305,9 @@ pub(crate) async fn handle_responses( let fmt_type = text_val["format"]["type"].as_str().unwrap_or("text"); if fmt_type == "json_schema" { let name = text_val["format"]["name"].as_str().map(|s| s.to_string()); - let schema = text_val["format"]["schema"].as_object().map(|o| serde_json::Value::Object(o.clone())); + let schema = text_val["format"]["schema"] + .as_object() + .map(|o| serde_json::Value::Object(o.clone())); let strict = text_val["format"]["strict"].as_bool(); let tf = TextFormat { format: TextFormatInner { @@ -321,9 +339,13 @@ pub(crate) async fn handle_responses( response_schema, google_search: has_web_search, }; - if gp.temperature.is_some() || gp.top_p.is_some() || gp.max_output_tokens.is_some() - || gp.reasoning_effort.is_some() || gp.response_mime_type.is_some() - || gp.response_schema.is_some() || gp.google_search + if gp.temperature.is_some() + || gp.top_p.is_some() + || gp.max_output_tokens.is_some() + || gp.reasoning_effort.is_some() + || gp.response_mime_type.is_some() + || gp.response_schema.is_some() + || gp.google_search { state.mitm_store.set_generation_params(gp).await; } else { @@ -331,10 +353,7 @@ pub(crate) async fn handle_responses( } } - let response_id = format!( - "resp_{}", - uuid::Uuid::new_v4().to_string().replace('-', "") - ); + let response_id = format!("resp_{}", uuid::Uuid::new_v4().to_string().replace('-', "")); // Session/conversation management let session_id_str = extract_conversation_id(&body.conversation); @@ -371,12 +390,13 @@ pub(crate) async fn handle_responses( // Store image for MITM injection (LS doesn't forward images to Google API) if let Some(ref img) = image { use base64::Engine; - state.mitm_store.set_pending_image( - crate::mitm::store::PendingImage { + state + .mitm_store + .set_pending_image(crate::mitm::store::PendingImage { base64_data: base64::engine::general_purpose::STANDARD.encode(&img.data), mime_type: img.mime_type.clone(), - } - ).await; + }) + .await; } match state .backend @@ -419,21 +439,32 @@ pub(crate) async fn handle_responses( metadata: body.metadata.clone().unwrap_or(serde_json::json!({})), max_tool_calls: body.max_tool_calls, reasoning_effort: body.reasoning_effort.clone(), - tool_choice: body.tool_choice.clone().unwrap_or(serde_json::json!("auto")), + tool_choice: body + .tool_choice + .clone() + .unwrap_or(serde_json::json!("auto")), tools: body.tools.clone().unwrap_or_default(), text_format, }; if body.stream { handle_responses_stream( - state, response_id, model_name.to_string(), cascade_id, - body.timeout, req_params, + state, + response_id, + model_name.to_string(), + cascade_id, + body.timeout, + req_params, ) .await } else { handle_responses_sync( - state, response_id, model_name.to_string(), cascade_id, - body.timeout, req_params, + state, + response_id, + model_name.to_string(), + cascade_id, + body.timeout, + req_params, ) .await } @@ -485,7 +516,9 @@ async fn usage_from_poll( if let Some(u) = mitm_store.peek_usage(key).await { if u.thinking_output_tokens > 0 && u.thinking_text.is_none() { // Call 2 hasn't arrived yet — wait briefly for the merge - tracing::debug!("MITM: thinking tokens found but no text, waiting for summary merge..."); + tracing::debug!( + "MITM: thinking tokens found but no text, waiting for summary merge..." + ); for _ in 0..10 { tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; if let Some(u2) = mitm_store.peek_usage(key).await { @@ -526,13 +559,18 @@ async fn usage_from_poll( // Priority 2: LS trajectory data (from CHECKPOINT/metadata steps) if let Some(u) = model_usage { - return (Usage { - input_tokens: u.input_tokens, - input_tokens_details: InputTokensDetails { cached_tokens: 0 }, - output_tokens: u.output_tokens, - output_tokens_details: OutputTokensDetails { reasoning_tokens: 0 }, - total_tokens: u.input_tokens + u.output_tokens, - }, None); + return ( + Usage { + input_tokens: u.input_tokens, + input_tokens_details: InputTokensDetails { cached_tokens: 0 }, + output_tokens: u.output_tokens, + output_tokens_details: OutputTokensDetails { + reasoning_tokens: 0, + }, + total_tokens: u.input_tokens + u.output_tokens, + }, + None, + ); } // Priority 3: Estimate from text lengths @@ -575,14 +613,22 @@ async fn handle_responses_sync( "call_{}", uuid::Uuid::new_v4().to_string().replace('-', "")[..24].to_string() ); - state.mitm_store.register_call_id(call_id.clone(), fc.name.clone()).await; + state + .mitm_store + .register_call_id(call_id.clone(), fc.name.clone()) + .await; let arguments = serde_json::to_string(&fc.args).unwrap_or_default(); - output_items.push(build_function_call_output(&call_id, &fc.name, &arguments)); + output_items + .push(build_function_call_output(&call_id, &fc.name, &arguments)); } let (usage, _) = usage_from_poll( - &state.mitm_store, &cascade_id, &None, - ¶ms.user_text, "", - ).await; + &state.mitm_store, + &cascade_id, + &None, + ¶ms.user_text, + "", + ) + .await; let resp = build_response_object( ResponseData { id: response_id, @@ -602,12 +648,20 @@ async fn handle_responses_sync( // Check for completed text response if state.mitm_store.is_response_complete() { - let text = state.mitm_store.take_response_text().await.unwrap_or_default(); + let text = state + .mitm_store + .take_response_text() + .await + .unwrap_or_default(); let thinking = state.mitm_store.take_thinking_text().await; let (usage, _) = usage_from_poll( - &state.mitm_store, &cascade_id, &None, - ¶ms.user_text, &text, - ).await; + &state.mitm_store, + &cascade_id, + &None, + ¶ms.user_text, + &text, + ) + .await; let mut output_items: Vec = Vec::new(); if let Some(ref t) = thinking { @@ -658,10 +712,7 @@ async fn handle_responses_sync( return upstream_err_response(err); } let completed_at = now_unix(); - let msg_id = format!( - "msg_{}", - uuid::Uuid::new_v4().to_string().replace('-', "") - ); + let msg_id = format!("msg_{}", uuid::Uuid::new_v4().to_string().replace('-', "")); // Check for captured function calls from MITM (clears the active flag) let captured_tool_calls = state.mitm_store.take_any_function_calls().await; @@ -689,7 +740,10 @@ async fn handle_responses_sync( uuid::Uuid::new_v4().to_string().replace('-', "")[..24].to_string() ); // Register call_id → name mapping for tool result routing - state.mitm_store.register_call_id(call_id.clone(), fc.name.clone()).await; + state + .mitm_store + .register_call_id(call_id.clone(), fc.name.clone()) + .await; // Stringify args (OpenAI sends arguments as JSON string) let arguments = serde_json::to_string(&fc.args).unwrap_or_default(); @@ -697,9 +751,13 @@ async fn handle_responses_sync( } let (usage, _) = usage_from_poll( - &state.mitm_store, &cascade_id, &poll_result.usage, - ¶ms.user_text, &poll_result.text, - ).await; + &state.mitm_store, + &cascade_id, + &poll_result.usage, + ¶ms.user_text, + &poll_result.text, + ) + .await; let resp = build_response_object( ResponseData { @@ -719,7 +777,14 @@ async fn handle_responses_sync( } // Normal text response (no tool calls) - let (usage, mitm_thinking) = usage_from_poll(&state.mitm_store, &cascade_id, &poll_result.usage, ¶ms.user_text, &poll_result.text).await; + let (usage, mitm_thinking) = usage_from_poll( + &state.mitm_store, + &cascade_id, + &poll_result.usage, + ¶ms.user_text, + &poll_result.text, + ) + .await; // Thinking text priority: MITM-captured (raw API) > LS-extracted (steps) let thinking_text = mitm_thinking.or(poll_result.thinking); @@ -1560,4 +1625,3 @@ fn completion_events( events } - diff --git a/src/api/types.rs b/src/api/types.rs index 0c04959..a694089 100644 --- a/src/api/types.rs +++ b/src/api/types.rs @@ -126,7 +126,9 @@ pub(crate) struct CompletionRequest { pub web_search: bool, } -fn default_n() -> u32 { 1 } +fn default_n() -> u32 { + 1 +} /// Stop sequence can be a single string or array of strings (OpenAI accepts both). #[derive(Deserialize, Clone)] @@ -254,8 +256,7 @@ pub(crate) struct OutputTokensDetails { pub reasoning_tokens: u64, } -#[derive(Serialize, Clone)] -#[derive(Default)] +#[derive(Serialize, Clone, Default)] pub(crate) struct Reasoning { pub effort: Option, pub summary: Option, @@ -313,7 +314,6 @@ impl Default for Usage { } } - impl Default for TextFormat { fn default() -> Self { Self { diff --git a/src/api/util.rs b/src/api/util.rs index 3fe57c6..a489f63 100644 --- a/src/api/util.rs +++ b/src/api/util.rs @@ -27,7 +27,9 @@ pub(crate) fn err_response( /// Convert a MITM-captured upstream error from Google into an HTTP response. /// Maps Google's HTTP status codes and preserves the error message. -pub(crate) fn upstream_err_response(err: &crate::mitm::store::UpstreamError) -> axum::response::Response { +pub(crate) fn upstream_err_response( + err: &crate::mitm::store::UpstreamError, +) -> axum::response::Response { // Map Google's status code to HTTP status let status = StatusCode::from_u16(err.status).unwrap_or(StatusCode::BAD_GATEWAY); @@ -41,7 +43,9 @@ pub(crate) fn upstream_err_response(err: &crate::mitm::store::UpstreamError) -> _ => "upstream_error", }; - let message = err.message.clone() + let message = err + .message + .clone() .unwrap_or_else(|| format!("Google API returned HTTP {}", err.status)); err_response(status, message, error_type) @@ -99,7 +103,8 @@ pub(crate) fn extract_image_from_content(item: &serde_json::Value) -> Option { - let url = item["image_url"].as_str() + let url = item["image_url"] + .as_str() .or_else(|| item["url"].as_str())?; parse_data_uri(url) } @@ -109,5 +114,8 @@ pub(crate) fn extract_image_from_content(item: &serde_json::Value) -> Option Option { - content.as_array()?.iter().find_map(extract_image_from_content) + content + .as_array()? + .iter() + .find_map(extract_image_from_content) } diff --git a/src/backend.rs b/src/backend.rs index 8c232c3..dd0565b 100644 --- a/src/backend.rs +++ b/src/backend.rs @@ -48,10 +48,7 @@ static STATIC_HEADERS: LazyLock = LazyLock::new(|| { *CHROME_MAJOR, )), ); - h.insert( - HeaderName::from_static("sec-ch-ua-mobile"), - hv("?0"), - ); + h.insert(HeaderName::from_static("sec-ch-ua-mobile"), hv("?0")); h.insert( HeaderName::from_static("sec-ch-ua-platform"), hv("\"Linux\""), @@ -72,7 +69,7 @@ impl Backend { // wreq with Chrome impersonation: BoringSSL + Chrome JA3/JA4 + H2 fingerprint let client = wreq::Client::builder() .emulation(wreq_util::Emulation::Chrome142) - .cert_verification(false) // LS uses self-signed cert + .cert_verification(false) // LS uses self-signed cert .verify_hostname(false) .build() .map_err(|e| format!("wreq client build failed: {e}"))?; @@ -86,11 +83,7 @@ impl Backend { /// Create a Backend with known connection details (for standalone LS). /// /// Skips auto-discovery — the caller provides the port, CSRF, and OAuth token. - pub fn new_with_config( - port: u16, - csrf: String, - oauth_token: String, - ) -> Result { + pub fn new_with_config(port: u16, csrf: String, oauth_token: String) -> Result { let inner = BackendInner { pid: "standalone".to_string(), csrf, @@ -212,10 +205,7 @@ impl Backend { fn common_headers(csrf: &str) -> HeaderMap { let mut h = STATIC_HEADERS.clone(); if let Ok(val) = HeaderValue::from_str(csrf) { - h.insert( - HeaderName::from_static("x-codeium-csrf-token"), - val, - ); + h.insert(HeaderName::from_static("x-codeium-csrf-token"), val); } else { warn!("CSRF token contains invalid header characters, omitting"); } @@ -239,8 +229,8 @@ impl Backend { let mut headers = Self::common_headers(&csrf); headers.insert("Content-Type", HeaderValue::from_static("application/json")); - let body_bytes = serde_json::to_vec(body) - .map_err(|e| format!("JSON serialize error: {e}"))?; + let body_bytes = + serde_json::to_vec(body).map_err(|e| format!("JSON serialize error: {e}"))?; let resp = self .client @@ -258,7 +248,9 @@ impl Backend { .and_then(|v| v.to_str().ok()) .unwrap_or("") .to_string(); - let raw = resp.bytes().await + let raw = resp + .bytes() + .await .map_err(|e| format!("Read body error: {e}"))?; let resp_bytes = decompress(method, &raw, &encoding); // High-frequency polling methods → trace; everything else → debug @@ -288,11 +280,7 @@ impl Backend { } /// Call a binary protobuf RPC method. - pub async fn call_proto( - &self, - method: &str, - body: Vec, - ) -> Result<(u16, Vec), String> { + pub async fn call_proto(&self, method: &str, body: Vec) -> Result<(u16, Vec), String> { let (base, csrf) = { let guard = self.inner.read().await; ( @@ -302,7 +290,10 @@ impl Backend { }; let url = format!("{base}/{LS_SERVICE}/{method}"); let mut headers = Self::common_headers(&csrf); - headers.insert("Content-Type", HeaderValue::from_static("application/proto")); + headers.insert( + "Content-Type", + HeaderValue::from_static("application/proto"), + ); let resp = self .client @@ -350,7 +341,8 @@ impl Backend { text: &str, model_enum: u32, ) -> Result<(u16, Vec), String> { - self.send_message_with_image(cascade_id, text, model_enum, None).await + self.send_message_with_image(cascade_id, text, model_enum, None) + .await } /// SendUserCascadeMessage with optional image attachment. @@ -365,7 +357,8 @@ impl Backend { if token.is_empty() { return Err("No OAuth token available".to_string()); } - let proto = crate::proto::build_request_with_image(cascade_id, text, &token, model_enum, image); + let proto = + crate::proto::build_request_with_image(cascade_id, text, &token, model_enum, image); if image.is_some() { tracing::info!( proto_size = proto.len(), @@ -376,10 +369,7 @@ impl Backend { } /// GetCascadeTrajectorySteps → JSON with steps array. - pub async fn get_steps( - &self, - cascade_id: &str, - ) -> Result<(u16, serde_json::Value), String> { + pub async fn get_steps(&self, cascade_id: &str) -> Result<(u16, serde_json::Value), String> { let body = serde_json::json!({"cascadeId": cascade_id}); self.call_json("GetCascadeTrajectorySteps", &body).await } @@ -415,7 +405,10 @@ impl Backend { }); let mut headers = Self::common_headers(&csrf); - headers.insert("Content-Type", HeaderValue::from_static("application/connect+json")); + headers.insert( + "Content-Type", + HeaderValue::from_static("application/connect+json"), + ); headers.insert("Connect-Protocol-Version", HeaderValue::from_static("1")); // Connect protocol envelope: [flags:1][length:4][payload] @@ -441,7 +434,8 @@ impl Backend { return Err(format!("{rpc_method} failed: {status} — {err_text}")); } - let resp_ct = resp.headers() + let resp_ct = resp + .headers() .get("content-type") .and_then(|v| v.to_str().ok()) .unwrap_or("unknown") @@ -495,7 +489,8 @@ impl Backend { &self, cascade_id: &str, ) -> Result, String> { - self.stream_reactive_rpc("StreamCascadeReactiveUpdates", cascade_id).await + self.stream_reactive_rpc("StreamCascadeReactiveUpdates", cascade_id) + .await } } @@ -506,7 +501,10 @@ fn discover() -> Result { // the wrapper is a shell script named language_server_linux_x64, while // the real binary is language_server_linux_x64.real) let pid_output = Command::new("sh") - .args(["-c", "pgrep -f 'language_server_linux_x64\\.real' | head -1"]) + .args([ + "-c", + "pgrep -f 'language_server_linux_x64\\.real' | head -1", + ]) .output() .map_err(|e| format!("pgrep failed: {e}"))?; @@ -564,9 +562,8 @@ fn discover() -> Result { LazyLock::new(|| regex::Regex::new(r"port at (\d+) for HTTPS").unwrap()); for d in &dirs { - let log_path = format!( - "{log_base}/{d}/window1/exthost/google.antigravity/Antigravity.log" - ); + let log_path = + format!("{log_base}/{d}/window1/exthost/google.antigravity/Antigravity.log"); if let Ok(contents) = fs::read_to_string(&log_path) { for line in contents.lines() { if line.contains(&pid) && line.contains("listening") && line.contains("HTTPS") { @@ -584,10 +581,7 @@ fn discover() -> Result { if https_port.is_empty() { // Fallback: find the LS HTTPS port via `ss` (when log file hasn't caught up) - if let Ok(output) = std::process::Command::new("ss") - .args(["-tlnp"]) - .output() - { + if let Ok(output) = std::process::Command::new("ss").args(["-tlnp"]).output() { let ss_out = String::from_utf8_lossy(&output.stdout); // Find listening ports for this PID — typically the first is HTTPS for line in ss_out.lines() { @@ -653,7 +647,11 @@ fn decompress(method: &str, data: &[u8], encoding: &str) -> Vec { Err(e) => { if !encoding.is_empty() { let preview = String::from_utf8_lossy(&data[..data.len().min(100)]); - warn!("{method}: {encoding} decompress failed ({} bytes): {e}. Raw: {}", data.len(), preview); + warn!( + "{method}: {encoding} decompress failed ({} bytes): {e}. Raw: {}", + data.len(), + preview + ); } data.to_vec() } diff --git a/src/constants.rs b/src/constants.rs index 044e5d6..4f55673 100644 --- a/src/constants.rs +++ b/src/constants.rs @@ -115,9 +115,7 @@ fn detect_versions() -> DetectedVersions { const FALLBACK_CLIENT: &str = "1.16.5"; let Some(install_dir) = find_install_dir() else { - tracing::warn!( - "Could not find Antigravity install — using fallback versions" - ); + tracing::warn!("Could not find Antigravity install — using fallback versions"); return DetectedVersions { antigravity: FALLBACK_ANTIGRAVITY.to_string(), chrome: FALLBACK_CHROME.to_string(), diff --git a/src/main.rs b/src/main.rs index 7f0e580..ca017dd 100644 --- a/src/main.rs +++ b/src/main.rs @@ -24,7 +24,10 @@ use tracing::{info, warn}; use mitm::store::MitmStore; #[derive(Parser)] -#[command(name = "antigravity-proxy", about = "Antigravity OpenAI Proxy (stealth)")] +#[command( + name = "antigravity-proxy", + about = "Antigravity OpenAI Proxy (stealth)" +)] struct Cli { /// Port to listen on #[arg(long, default_value_t = 8741)] @@ -93,15 +96,12 @@ async fn main() { }; let filter = if log_level.is_empty() { - tracing_subscriber::EnvFilter::try_from_default_env() - .unwrap_or_else(|_| "warn".into()) + tracing_subscriber::EnvFilter::try_from_default_env().unwrap_or_else(|_| "warn".into()) } else { tracing_subscriber::EnvFilter::new(log_level) }; - tracing_subscriber::fmt() - .with_env_filter(filter) - .init(); + tracing_subscriber::fmt().with_env_filter(filter).init(); // ── Step 1: Bind main port (auto-kill stale process if needed) ───────────── let addr = format!("127.0.0.1:{}", cli.port); @@ -111,7 +111,10 @@ async fn main() { // Port in use — try to kill whatever's holding it eprintln!(" Port {} in use, killing stale process...", cli.port); let _ = std::process::Command::new("sh") - .args(["-c", &format!("kill $(lsof -ti:{}) 2>/dev/null; sleep 0.3", cli.port)]) + .args([ + "-c", + &format!("kill $(lsof -ti:{}) 2>/dev/null; sleep 0.3", cli.port), + ]) .status(); // Also kill any leftover standalone LS processes let _ = std::process::Command::new("pkill") @@ -180,7 +183,9 @@ async fn main() { Ok(c) => c, Err(e) => { eprintln!("Fatal: {e}"); - eprintln!("Hint: start Antigravity first, or remove --classic to use headless mode"); + eprintln!( + "Hint: start Antigravity first, or remove --classic to use headless mode" + ); std::process::exit(1); } } @@ -199,13 +204,14 @@ async fn main() { None }; - let mut ls = match standalone::StandaloneLS::spawn(&main_config, mitm_cfg.as_ref(), headless) { - Ok(ls) => ls, - Err(e) => { - eprintln!("Fatal: failed to spawn standalone LS: {e}"); - std::process::exit(1); - } - }; + let mut ls = + match standalone::StandaloneLS::spawn(&main_config, mitm_cfg.as_ref(), headless) { + Ok(ls) => ls, + Err(e) => { + eprintln!("Fatal: failed to spawn standalone LS: {e}"); + std::process::exit(1); + } + }; // Wait for it to be ready let rt_ls_port = ls.port; let rt_ls_csrf = ls.csrf.clone(); @@ -294,7 +300,15 @@ async fn main() { // ── Step 5: Start serving ───────────────────────────────────────────────── let app = api::router(state.clone()); - print_banner(cli.port, &pid, &https_port, &csrf, &token, &mitm_port_actual, is_standalone); + print_banner( + cli.port, + &pid, + &https_port, + &csrf, + &token, + &mitm_port_actual, + is_standalone, + ); info!("Listening on http://{addr}"); axum::serve(listener, app) @@ -349,7 +363,15 @@ async fn shutdown_signal() { } } -fn print_banner(port: u16, pid: &str, https_port: &str, csrf: &str, token: &str, mitm: &Option<(u16, String)>, is_standalone: bool) { +fn print_banner( + port: u16, + pid: &str, + https_port: &str, + csrf: &str, + token: &str, + mitm: &Option<(u16, String)>, + is_standalone: bool, +) { let chrome_major = &*constants::CHROME_MAJOR; let ver = crate::constants::antigravity_version(); @@ -401,7 +423,11 @@ fn print_banner(port: u16, pid: &str, https_port: &str, csrf: &str, token: &str, println!(); // Status line - let mitm_tag = if mitm.is_some() { "\x1b[32mmitm\x1b[0m" } else { "\x1b[31mmitm\x1b[0m" }; + let mitm_tag = if mitm.is_some() { + "\x1b[32mmitm\x1b[0m" + } else { + "\x1b[31mmitm\x1b[0m" + }; println!(" \x1b[2mstealth:\x1b[0m \x1b[32mwarmup\x1b[0m \x1b[32mheartbeat\x1b[0m \x1b[32mjitter\x1b[0m {mitm_tag}"); println!(); @@ -421,7 +447,9 @@ fn print_banner(port: u16, pid: &str, https_port: &str, csrf: &str, token: &str, if token == "NOT SET" { println!(" \x1b[1;33m[!]\x1b[0m no oauth token"); println!(" export ANTIGRAVITY_OAUTH_TOKEN=ya29.xxx"); - println!(" curl -X POST http://127.0.0.1:{port}/v1/token -d '{{\"token\":\"ya29.xxx\"}}'"); + println!( + " curl -X POST http://127.0.0.1:{port}/v1/token -d '{{\"token\":\"ya29.xxx\"}}'" + ); println!(" echo 'ya29.xxx' > ~/.config/antigravity-proxy-token"); println!(); } @@ -476,5 +504,7 @@ fn find_ls_binary_path() -> Option { /// Get the data directory for storing MITM CA cert/key. fn dirs_data_dir() -> std::path::PathBuf { let home = std::env::var("HOME").unwrap_or_else(|_| "/tmp".to_string()); - std::path::PathBuf::from(home).join(".config").join("antigravity-proxy") + std::path::PathBuf::from(home) + .join(".config") + .join("antigravity-proxy") } diff --git a/src/mitm/ca.rs b/src/mitm/ca.rs index 000b656..778bda4 100644 --- a/src/mitm/ca.rs +++ b/src/mitm/ca.rs @@ -4,8 +4,8 @@ //! Dynamically generates per-domain leaf certificates signed by this CA. use rcgen::{ - BasicConstraints, CertificateParams, DistinguishedName, DnType, ExtendedKeyUsagePurpose, - IsCa, KeyPair, KeyUsagePurpose, SanType, + BasicConstraints, CertificateParams, DistinguishedName, DnType, ExtendedKeyUsagePurpose, IsCa, + KeyPair, KeyUsagePurpose, SanType, }; use rustls::pki_types::{CertificateDer, PrivateKeyDer, PrivatePkcs8KeyDer}; use std::collections::HashMap; @@ -45,15 +45,16 @@ impl MitmCa { let key_pem = std::fs::read_to_string(&key_path) .map_err(|e| format!("Failed to read CA key: {e}"))?; - let ca_key = KeyPair::from_pem(&key_pem) - .map_err(|e| format!("Failed to parse CA key: {e}"))?; + let ca_key = + KeyPair::from_pem(&key_pem).map_err(|e| format!("Failed to parse CA key: {e}"))?; // Re-create params and self-sign to get the rcgen Certificate object // (needed for signing leaf certs — rcgen 0.13 doesn't have from_ca_cert_pem). // The re-signed cert will have a different serial/notBefore, but that's fine // because we only use it for the rcgen signing API, NOT for the on-disk PEM. let params = Self::ca_params(); - let ca_signed = params.self_signed(&ca_key) + let ca_signed = params + .self_signed(&ca_key) .map_err(|e| format!("Failed to self-sign CA: {e}"))?; // Use the ORIGINAL on-disk PEM cert for DER — this is what the LS trusts @@ -76,11 +77,12 @@ impl MitmCa { std::fs::create_dir_all(data_dir) .map_err(|e| format!("Failed to create data dir: {e}"))?; - let ca_key = KeyPair::generate() - .map_err(|e| format!("Failed to generate CA key: {e}"))?; + let ca_key = + KeyPair::generate().map_err(|e| format!("Failed to generate CA key: {e}"))?; let params = Self::ca_params(); - let ca_signed = params.self_signed(&ca_key) + let ca_signed = params + .self_signed(&ca_key) .map_err(|e| format!("Failed to self-sign CA: {e}"))?; // Write cert and key to disk @@ -117,10 +119,7 @@ impl MitmCa { params.distinguished_name = dn; params.is_ca = IsCa::Ca(BasicConstraints::Unconstrained); - params.key_usages = vec![ - KeyUsagePurpose::KeyCertSign, - KeyUsagePurpose::CrlSign, - ]; + params.key_usages = vec![KeyUsagePurpose::KeyCertSign, KeyUsagePurpose::CrlSign]; // Valid for 10 years let now = time::OffsetDateTime::now_utc(); @@ -151,12 +150,17 @@ impl MitmCa { return None; } use base64::Engine; - let der = base64::engine::general_purpose::STANDARD.decode(&b64).ok()?; + let der = base64::engine::general_purpose::STANDARD + .decode(&b64) + .ok()?; Some(CertificateDer::from(der)) } /// Get or create a TLS ServerConfig for the given domain. - pub async fn server_config_for_domain(&self, domain: &str) -> Result, String> { + pub async fn server_config_for_domain( + &self, + domain: &str, + ) -> Result, String> { // Check cache first { let cache = self.domain_cache.read().await; @@ -172,7 +176,11 @@ impl MitmCa { dn.push(DnType::CommonName, domain); params.distinguished_name = dn; - params.subject_alt_names = vec![SanType::DnsName(domain.try_into().map_err(|e| format!("Invalid domain: {e}"))?)]; + params.subject_alt_names = vec![SanType::DnsName( + domain + .try_into() + .map_err(|e| format!("Invalid domain: {e}"))?, + )]; params.extended_key_usages = vec![ExtendedKeyUsagePurpose::ServerAuth]; params.key_usages = vec![ KeyUsagePurpose::DigitalSignature, @@ -184,10 +192,11 @@ impl MitmCa { params.not_before = now; params.not_after = now + time::Duration::days(365); - let leaf_key = KeyPair::generate() - .map_err(|e| format!("Failed to generate leaf key: {e}"))?; + let leaf_key = + KeyPair::generate().map_err(|e| format!("Failed to generate leaf key: {e}"))?; - let leaf_cert = params.signed_by(&leaf_key, &self.ca_signed, &self.ca_key) + let leaf_cert = params + .signed_by(&leaf_key, &self.ca_signed, &self.ca_key) .map_err(|e| format!("Failed to sign leaf cert for {domain}: {e}"))?; // Build rustls ServerConfig @@ -196,10 +205,7 @@ impl MitmCa { let mut config = rustls::ServerConfig::builder() .with_no_client_auth() - .with_single_cert( - vec![leaf_cert_der, self.ca_cert_der.clone()], - leaf_key_der, - ) + .with_single_cert(vec![leaf_cert_der, self.ca_cert_der.clone()], leaf_key_der) .map_err(|e| format!("Failed to build ServerConfig for {domain}: {e}"))?; // Advertise both h2 and http/1.1 so gRPC clients can negotiate HTTP/2 diff --git a/src/mitm/h2_handler.rs b/src/mitm/h2_handler.rs index abff754..aaed8ca 100644 --- a/src/mitm/h2_handler.rs +++ b/src/mitm/h2_handler.rs @@ -92,11 +92,10 @@ impl UpstreamPool { .map_err(|e| format!("upstream TLS to {} failed: {e}", self.domain))?; let upstream_io = TokioIo::new(upstream_tls); - let (sender, conn) = - hyper::client::conn::http2::Builder::new(TokioExecutor::new()) - .handshake(upstream_io) - .await - .map_err(|e| format!("upstream h2 handshake to {} failed: {e}", self.domain))?; + let (sender, conn) = hyper::client::conn::http2::Builder::new(TokioExecutor::new()) + .handshake(upstream_io) + .await + .map_err(|e| format!("upstream h2 handshake to {} failed: {e}", self.domain))?; let domain = self.domain.clone(); tokio::spawn(async move { @@ -215,12 +214,10 @@ async fn handle_h2_request( .unwrap_or(false); // Check if this method carries usage data - let is_usage_method = is_grpc - && USAGE_METHODS.iter().any(|m| path.contains(m)); + let is_usage_method = is_grpc && USAGE_METHODS.iter().any(|m| path.contains(m)); // Check if this is a streaming method - let is_streaming = is_grpc - && (path.contains("Stream") || path.contains("stream")); + let is_streaming = is_grpc && (path.contains("Stream") || path.contains("stream")); debug!( domain, @@ -249,9 +246,9 @@ async fn handle_h2_request( warn!(error = %e, domain, "MITM H2: upstream connect failed"); let resp = Response::builder() .status(502) - .body(http_body_util::Either::Left(Full::new( - Bytes::from(format!("upstream connect failed: {e}")), - ))) + .body(http_body_util::Either::Left(Full::new(Bytes::from( + format!("upstream connect failed: {e}"), + )))) .unwrap(); return Ok(resp); } @@ -261,17 +258,11 @@ async fn handle_h2_request( let upstream_uri = http::Uri::builder() .scheme("https") .authority(domain) - .path_and_query( - uri.path_and_query() - .map(|pq| pq.as_str()) - .unwrap_or("/"), - ) + .path_and_query(uri.path_and_query().map(|pq| pq.as_str()).unwrap_or("/")) .build() .unwrap_or(uri); - let mut upstream_req = Request::builder() - .method(parts.method) - .uri(upstream_uri); + let mut upstream_req = Request::builder().method(parts.method).uri(upstream_uri); // Copy headers, skip hop-by-hop for (name, value) in &parts.headers { @@ -287,9 +278,9 @@ async fn handle_h2_request( Err(e) => { let resp = Response::builder() .status(502) - .body(http_body_util::Either::Left(Full::new( - Bytes::from(format!("build request failed: {e}")), - ))) + .body(http_body_util::Either::Left(Full::new(Bytes::from( + format!("build request failed: {e}"), + )))) .unwrap(); return Ok(resp); } @@ -302,9 +293,9 @@ async fn handle_h2_request( warn!(error = %e, domain, path = %path, "MITM H2: upstream request failed"); let resp = Response::builder() .status(502) - .body(http_body_util::Either::Left(Full::new( - Bytes::from(format!("upstream request failed: {e}")), - ))) + .body(http_body_util::Either::Left(Full::new(Bytes::from( + format!("upstream request failed: {e}"), + )))) .unwrap(); return Ok(resp); } @@ -326,13 +317,18 @@ async fn handle_h2_request( // Spawn a task to forward body chunks and tee for usage extraction tokio::spawn(async move { - let mut tee_buffer = if should_track_usage { Some(Vec::new()) } else { None }; + let mut tee_buffer = if should_track_usage { + Some(Vec::new()) + } else { + None + }; let mut body = resp_body; loop { match body.frame().await { Some(Ok(frame)) => { - if let (Some(ref mut buf), Some(data)) = (&mut tee_buffer, frame.data_ref()) { + if let (Some(ref mut buf), Some(data)) = (&mut tee_buffer, frame.data_ref()) + { buf.extend_from_slice(data); } if tx.send(Ok(frame)).await.is_err() { @@ -354,7 +350,9 @@ async fn handle_h2_request( if let Some(grpc_usage) = parse_grpc_response_for_usage(&tee_buffer) { let usage = grpc_usage.into_api_usage(path_clone.clone()); let cascade_hint = extract_cascade_from_grpc_request(&request_body_clone); - store_clone.record_usage(cascade_hint.as_deref(), usage).await; + store_clone + .record_usage(cascade_hint.as_deref(), usage) + .await; } } } diff --git a/src/mitm/intercept.rs b/src/mitm/intercept.rs index 90711e5..e0cefb1 100644 --- a/src/mitm/intercept.rs +++ b/src/mitm/intercept.rs @@ -78,15 +78,21 @@ impl StreamingAccumulator { Self::default() } -/// Process a single SSE event. + /// Process a single SSE event. pub fn process_event(&mut self, event: &Value) { // ── Google format: {"response": {"usageMetadata": {...}, "modelVersion": "..."}} ── if let Some(response) = event.get("response") { // Extract usage metadata (each event has cumulative counts) if let Some(usage) = response.get("usageMetadata") { - self.input_tokens = usage["promptTokenCount"].as_u64().unwrap_or(self.input_tokens); - self.output_tokens = usage["candidatesTokenCount"].as_u64().unwrap_or(self.output_tokens); - self.thinking_tokens = usage["thoughtsTokenCount"].as_u64().unwrap_or(self.thinking_tokens); + self.input_tokens = usage["promptTokenCount"] + .as_u64() + .unwrap_or(self.input_tokens); + self.output_tokens = usage["candidatesTokenCount"] + .as_u64() + .unwrap_or(self.output_tokens); + self.thinking_tokens = usage["thoughtsTokenCount"] + .as_u64() + .unwrap_or(self.thinking_tokens); } if let Some(model) = response["modelVersion"].as_str() { self.model = Some(model.to_string()); @@ -170,8 +176,10 @@ impl StreamingAccumulator { "message_start" => { if let Some(usage) = event.get("message").and_then(|m| m.get("usage")) { self.input_tokens = usage["input_tokens"].as_u64().unwrap_or(0); - self.cache_creation_input_tokens = usage["cache_creation_input_tokens"].as_u64().unwrap_or(0); - self.cache_read_input_tokens = usage["cache_read_input_tokens"].as_u64().unwrap_or(0); + self.cache_creation_input_tokens = + usage["cache_creation_input_tokens"].as_u64().unwrap_or(0); + self.cache_read_input_tokens = + usage["cache_read_input_tokens"].as_u64().unwrap_or(0); } if let Some(model) = event.get("message").and_then(|m| m["model"].as_str()) { self.model = Some(model.to_string()); @@ -181,7 +189,9 @@ impl StreamingAccumulator { } "message_delta" => { if let Some(usage) = event.get("usage") { - self.output_tokens = usage["output_tokens"].as_u64().unwrap_or(self.output_tokens); + self.output_tokens = usage["output_tokens"] + .as_u64() + .unwrap_or(self.output_tokens); } if let Some(reason) = event["delta"]["stop_reason"].as_str() { self.stop_reason = Some(reason.to_string()); @@ -235,7 +245,10 @@ impl StreamingAccumulator { response_output_tokens: 0, model: self.model, stop_reason: self.stop_reason, - api_provider: self.api_provider.unwrap_or_else(|| "unknown".to_string()).into(), + api_provider: self + .api_provider + .unwrap_or_else(|| "unknown".to_string()) + .into(), grpc_method: None, captured_at: std::time::SystemTime::now() .duration_since(std::time::UNIX_EPOCH) diff --git a/src/mitm/modify.rs b/src/mitm/modify.rs index a159136..996cb19 100644 --- a/src/mitm/modify.rs +++ b/src/mitm/modify.rs @@ -68,14 +68,14 @@ pub fn modify_request(body: &[u8], tool_ctx: Option<&ToolContext>) -> Option only ({original_len} → {} chars, -{stripped})", new_sys.len() )); - json["request"]["systemInstruction"]["parts"][0]["text"] = - Value::String(new_sys); + json["request"]["systemInstruction"]["parts"][0]["text"] = Value::String(new_sys); } } else { // No identity tag found — clear the whole thing - changes.push(format!("system instruction: cleared ({original_len} chars)")); - json["request"]["systemInstruction"]["parts"][0]["text"] = - Value::String(String::new()); + changes.push(format!( + "system instruction: cleared ({original_len} chars)" + )); + json["request"]["systemInstruction"]["parts"][0]["text"] = Value::String(String::new()); } } @@ -125,7 +125,11 @@ pub fn modify_request(body: &[u8], tool_ctx: Option<&ToolContext>) -> Option") { + if let Some(cleaned) = strip_between( + &modified, + "# Conversation History\n", + "", + ) { modified = cleaned; } @@ -147,7 +151,9 @@ pub fn modify_request(body: &[u8], tool_ctx: Option<&ToolContext>) -> Option") { + if let Some(cleaned) = + strip_between(&modified, "Here are the ", "") + { // Only strip if it's about knowledge items if cleaned.len() < modified.len() && modified.contains("knowledge item") { modified = cleaned; @@ -202,7 +208,8 @@ pub fn modify_request(body: &[u8], tool_ctx: Option<&ToolContext>) -> Option) -> Option) -> Option) -> Option) -> Option) -> Option 0 { - changes.push(format!("strip {stripped_fc} functionCall/Response parts from history")); + changes.push(format!( + "strip {stripped_fc} functionCall/Response parts from history" + )); } } } @@ -336,16 +357,22 @@ pub fn modify_request(body: &[u8], tool_ctx: Option<&ToolContext>) -> Option = ctx.last_calls.iter().map(|fc| { - serde_json::json!({ - "functionCall": { - "name": fc.name, - "args": fc.args, - } + let fc_parts: Vec = ctx + .last_calls + .iter() + .map(|fc| { + serde_json::json!({ + "functionCall": { + "name": fc.name, + "args": fc.args, + } + }) }) - }).collect(); + .collect(); msg["parts"] = Value::Array(fc_parts); changes.push("rewrite model turn with functionCall".to_string()); break; @@ -355,29 +382,36 @@ pub fn modify_request(body: &[u8], tool_ctx: Option<&ToolContext>) -> Option = ctx.pending_results.iter().map(|r| { - serde_json::json!({ - "functionResponse": { - "name": r.name, - "response": r.result, - } + let fn_response_parts: Vec = ctx + .pending_results + .iter() + .map(|r| { + serde_json::json!({ + "functionResponse": { + "name": r.name, + "response": r.result, + } + }) }) - }).collect(); + .collect(); let fn_response_turn = serde_json::json!({ "role": "user", "parts": fn_response_parts, }); // Insert before the last user message - let last_user_idx = contents.iter().rposition(|msg| { - msg["role"].as_str() == Some("user") - }); + let last_user_idx = contents + .iter() + .rposition(|msg| msg["role"].as_str() == Some("user")); if let Some(idx) = last_user_idx { contents.insert(idx, fn_response_turn); } else { contents.push(fn_response_turn); } - changes.push(format!("inject {} functionResponse(s)", ctx.pending_results.len())); + changes.push(format!( + "inject {} functionResponse(s)", + ctx.pending_results.len() + )); } } } @@ -420,8 +454,10 @@ pub fn modify_request(body: &[u8], tool_ctx: Option<&ToolContext>) -> Option) -> Option) -> Option").unwrap(); + let text = + "keep this # Conversation History\nlots of stuff\n\nand this"; + let result = strip_between( + text, + "# Conversation History\n", + "", + ) + .unwrap(); assert_eq!(result, "keep this and this"); } } @@ -977,7 +1031,9 @@ pub fn modify_response_chunk(chunk: &[u8]) -> Option> { // Replace the JSON in the result string result.replace_range(json_start..json_start + json_end, &new_json); changed = true; - info!("MITM: rewrote functionCall in response → text placeholder for LS"); + info!( + "MITM: rewrote functionCall in response → text placeholder for LS" + ); search_from = json_start + new_json.len(); continue; } @@ -1117,7 +1173,10 @@ fn rewrite_function_calls_in_response(json: &mut Value) -> bool { } // Try nested "response.candidates" - if let Some(candidates) = json.pointer_mut("/response/candidates").and_then(|v| v.as_array_mut()) { + if let Some(candidates) = json + .pointer_mut("/response/candidates") + .and_then(|v| v.as_array_mut()) + { changed |= rewrite_candidates(candidates); } diff --git a/src/mitm/proto.rs b/src/mitm/proto.rs index 8561a59..99fc3a8 100644 --- a/src/mitm/proto.rs +++ b/src/mitm/proto.rs @@ -251,7 +251,10 @@ fn looks_like_valid_message(fields: &[ProtoField], original_len: usize) -> bool // (e.g., a long string that happened to have a valid first-field prefix) if fields.len() == 1 && original_len > 100 { // Single-field messages of >100 bytes are suspicious unless the field is bytes/message - matches!(&fields[0].value, ProtoValue::Bytes(_) | ProtoValue::Message(_)) + matches!( + &fields[0].value, + ProtoValue::Bytes(_) | ProtoValue::Message(_) + ) } else { true } @@ -328,7 +331,9 @@ fn try_extract_usage(fields: &[ProtoField]) -> Option { .iter() .filter_map(|f| { if let ProtoValue::Bytes(ref b) = f.value { - std::str::from_utf8(b).ok().map(|s| (f.number, s.to_string())) + std::str::from_utf8(b) + .ok() + .map(|s| (f.number, s.to_string())) } else { None } @@ -361,14 +366,23 @@ fn try_extract_usage(fields: &[ProtoField]) -> Option { // Check if there's a model-like string (field 7 = message_id or field 11 = response_id // can contain model names, or model enum values map to known names) let has_model_string = string_fields.iter().any(|(_, s)| { - s.contains("claude") || s.contains("gemini") || s.contains("gpt") - || s.starts_with("models/") || s.contains("sonnet") || s.contains("opus") - || s.contains("flash") || s.contains("pro") + s.contains("claude") + || s.contains("gemini") + || s.contains("gpt") + || s.starts_with("models/") + || s.contains("sonnet") + || s.contains("opus") + || s.contains("flash") + || s.contains("pro") }); // Check for fields at the known ModelUsageStats field numbers - let has_field_2 = fields.iter().any(|f| f.number == 2 && matches!(f.value, ProtoValue::Varint(_))); - let has_field_3 = fields.iter().any(|f| f.number == 3 && matches!(f.value, ProtoValue::Varint(_))); + let has_field_2 = fields + .iter() + .any(|f| f.number == 2 && matches!(f.value, ProtoValue::Varint(_))); + let has_field_3 = fields + .iter() + .any(|f| f.number == 3 && matches!(f.value, ProtoValue::Varint(_))); // Strong signal: has both input and output token fields let is_likely_usage = (has_field_2 && has_field_3) || has_model_string; @@ -392,8 +406,8 @@ fn try_extract_usage(fields: &[ProtoField]) -> Option { // field 1 = model enum (varint, not string!) 2 => usage.input_tokens = v, 3 => usage.output_tokens = v, - 4 => usage.cache_write_tokens = v, // VERIFIED: field 4 - 5 => usage.cache_read_tokens = v, // VERIFIED: field 5 + 4 => usage.cache_write_tokens = v, // VERIFIED: field 4 + 5 => usage.cache_read_tokens = v, // VERIFIED: field 5 // field 6 = api_provider enum (varint) 9 => usage.thinking_output_tokens = v, // VERIFIED: field 9 10 => usage.response_output_tokens = v, // VERIFIED: field 10 @@ -486,11 +500,11 @@ pub fn parse_grpc_response_for_usage(body: &[u8]) -> Option { fn model_enum_name(enum_val: u64) -> &'static str { match enum_val { // Placeholder models (1000 + N) - 1007 => "gemini-3-pro", // MODEL_PLACEHOLDER_M7 - 1008 => "gemini-3-pro-high", // MODEL_PLACEHOLDER_M8 - 1012 => "claude-opus-4.5", // MODEL_PLACEHOLDER_M12 - 1018 => "gemini-3-flash", // MODEL_PLACEHOLDER_M18 - 1026 => "claude-opus-4.6", // MODEL_PLACEHOLDER_M26 + 1007 => "gemini-3-pro", // MODEL_PLACEHOLDER_M7 + 1008 => "gemini-3-pro-high", // MODEL_PLACEHOLDER_M8 + 1012 => "claude-opus-4.5", // MODEL_PLACEHOLDER_M12 + 1018 => "gemini-3-flash", // MODEL_PLACEHOLDER_M18 + 1026 => "claude-opus-4.6", // MODEL_PLACEHOLDER_M26 // Claude models (named) 281 => "claude-4-sonnet", @@ -629,13 +643,13 @@ mod tests { data.push(v as u8); } - encode_varint_field(&mut data, 1, 5); // model enum - encode_varint_field(&mut data, 2, 1000); // input_tokens - encode_varint_field(&mut data, 3, 500); // output_tokens - encode_varint_field(&mut data, 4, 100); // cache_write_tokens - encode_varint_field(&mut data, 5, 200); // cache_read_tokens - encode_varint_field(&mut data, 9, 300); // thinking_output_tokens - encode_varint_field(&mut data, 10, 200); // response_output_tokens + encode_varint_field(&mut data, 1, 5); // model enum + encode_varint_field(&mut data, 2, 1000); // input_tokens + encode_varint_field(&mut data, 3, 500); // output_tokens + encode_varint_field(&mut data, 4, 100); // cache_write_tokens + encode_varint_field(&mut data, 5, 200); // cache_read_tokens + encode_varint_field(&mut data, 9, 300); // thinking_output_tokens + encode_varint_field(&mut data, 10, 200); // response_output_tokens let fields = decode_proto(&data); let usage = try_extract_usage(&fields).expect("should extract usage"); diff --git a/src/mitm/proxy.rs b/src/mitm/proxy.rs index f852a42..f3a392c 100644 --- a/src/mitm/proxy.rs +++ b/src/mitm/proxy.rs @@ -11,8 +11,7 @@ use super::ca::MitmCa; use super::intercept::{ - extract_cascade_hint, parse_non_streaming_response, parse_streaming_chunk, - StreamingAccumulator, + extract_cascade_hint, parse_non_streaming_response, parse_streaming_chunk, StreamingAccumulator, }; use super::store::MitmStore; use std::sync::Arc; @@ -54,7 +53,6 @@ pub struct MitmConfig { pub modify_requests: bool, } - /// Run the MITM proxy server. /// /// Returns (port, task_handle) — port it's listening on, handle to abort on shutdown. @@ -84,7 +82,8 @@ pub async fn run( let ca = ca.clone(); let store = store.clone(); tokio::spawn(async move { - if let Err(e) = handle_connection(stream, ca, store, modify_requests).await { + if let Err(e) = handle_connection(stream, ca, store, modify_requests).await + { warn!(error = %e, "MITM connection error"); } }); @@ -131,8 +130,7 @@ async fn handle_connection( .await .map_err(|e| format!("Peek ClientHello: {e}"))?; - let domain = extract_sni(&hello_buf[..n]) - .unwrap_or_else(|| "unknown".to_string()); + let domain = extract_sni(&hello_buf[..n]).unwrap_or_else(|| "unknown".to_string()); info!(domain, "MITM: transparent redirect (iptables)"); @@ -224,22 +222,30 @@ fn extract_sni(buf: &[u8]) -> Option { let mut pos = 34; // skip version + random // Session ID - if pos >= body.len() { return None; } + if pos >= body.len() { + return None; + } let sid_len = body[pos] as usize; pos += 1 + sid_len; // Cipher suites - if pos + 2 > body.len() { return None; } + if pos + 2 > body.len() { + return None; + } let cs_len = u16::from_be_bytes([body[pos], body[pos + 1]]) as usize; pos += 2 + cs_len; // Compression methods - if pos >= body.len() { return None; } + if pos >= body.len() { + return None; + } let cm_len = body[pos] as usize; pos += 1 + cm_len; // Extensions - if pos + 2 > body.len() { return None; } + if pos + 2 > body.len() { + return None; + } let ext_len = u16::from_be_bytes([body[pos], body[pos + 1]]) as usize; pos += 2; let ext_end = pos + ext_len.min(body.len() - pos); @@ -304,32 +310,32 @@ async fn handle_intercepted( info!(domain, "MITM: intercepting TLS"); // Get or create server TLS config for this domain - let server_config = ca - .server_config_for_domain(domain) - .await?; + let server_config = ca.server_config_for_domain(domain).await?; let acceptor = TlsAcceptor::from(server_config); // Perform TLS handshake with the client (LS) — 10s timeout - let tls_stream = match tokio::time::timeout( - std::time::Duration::from_secs(10), - acceptor.accept(stream), - ) - .await - { - Ok(Ok(s)) => s, - Ok(Err(e)) => { - warn!(domain, error = %e, "MITM: TLS handshake FAILED (client rejected cert?)"); - return Err(format!("TLS handshake with client failed for {domain}: {e}")); - } - Err(_) => { - warn!(domain, "MITM: TLS handshake TIMED OUT after 10s"); - return Err(format!("TLS handshake timed out for {domain}")); - } - }; + let tls_stream = + match tokio::time::timeout(std::time::Duration::from_secs(10), acceptor.accept(stream)) + .await + { + Ok(Ok(s)) => s, + Ok(Err(e)) => { + warn!(domain, error = %e, "MITM: TLS handshake FAILED (client rejected cert?)"); + return Err(format!( + "TLS handshake with client failed for {domain}: {e}" + )); + } + Err(_) => { + warn!(domain, "MITM: TLS handshake TIMED OUT after 10s"); + return Err(format!("TLS handshake timed out for {domain}")); + } + }; // Check negotiated ALPN protocol - let alpn = tls_stream.get_ref().1 + let alpn = tls_stream + .get_ref() + .1 .alpn_protocol() .map(|p| String::from_utf8_lossy(p).to_string()); @@ -339,12 +345,7 @@ async fn handle_intercepted( Some("h2") => { // HTTP/2 — use the hyper-based gRPC handler info!(domain, "MITM: routing to HTTP/2 handler (gRPC)"); - super::h2_handler::handle_h2_connection( - tls_stream, - domain.to_string(), - store, - ) - .await + super::h2_handler::handle_h2_connection(tls_stream, domain.to_string(), store).await } _ => { // HTTP/1.1 or no ALPN — use the existing handler @@ -434,7 +435,10 @@ async fn handle_http_over_tls( .await { let out = String::from_utf8_lossy(&output.stdout); - if let Some(ip) = out.lines().find(|l| l.parse::().is_ok()) { + if let Some(ip) = out + .lines() + .find(|l| l.parse::().is_ok()) + { return format!("{ip}:443"); } } @@ -458,7 +462,6 @@ async fn handle_http_over_tls( loop { // ── Read the HTTP request from the client ───────────────────────── let mut request_buf = Vec::with_capacity(1024 * 64); - let mut is_our_request = false; // 60s timeout on initial read (LS may open connection without sending immediately) const IDLE_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(60); @@ -513,7 +516,8 @@ async fn handle_http_over_tls( } // Parse the HTTP request to find headers and body - let (headers_end, content_length, _is_streaming_request) = parse_http_request_meta(&request_buf); + let (headers_end, content_length, _is_streaming_request) = + parse_http_request_meta(&request_buf); // Try to extract cascade hint from request body let cascade_hint = if headers_end < request_buf.len() { @@ -545,6 +549,27 @@ async fn handle_http_over_tls( "MITM: forwarding LLM request" ); + // ── Block ALL requests when one is already in-flight ───────── + // The LS opens multiple connections and sends parallel requests. + // When custom tools are active, only the FIRST request should reach + // Google. Block everything else with a fake response. + if store.is_request_in_flight() { + info!("MITM: blocking LS request — another request already in-flight"); + let fake_response = "HTTP/1.1 200 OK\r\n\ + Content-Type: text/event-stream\r\n\ + Transfer-Encoding: chunked\r\n\ + \r\n"; + let fake_sse = "data: {\"response\":{\"candidates\":[{\"content\":{\"parts\":[{\"text\":\"Request handled.\"}],\"role\":\"model\"},\"finishReason\":\"STOP\"}],\"usageMetadata\":{\"promptTokenCount\":0,\"candidatesTokenCount\":1,\"totalTokenCount\":1}}}\n\ndata: [DONE]\n\n"; + let chunked_body = super::modify::rechunk(fake_sse.as_bytes()); + let mut response = fake_response.as_bytes().to_vec(); + response.extend_from_slice(&chunked_body); + if let Err(e) = client.write_all(&response).await { + warn!(error = %e, "MITM: failed to write fake response"); + } + let _ = client.flush().await; + continue; + } + // ── Request modification ───────────────────────────────────── // Dechunk body → check if agent request → modify → rechunk if modify_requests && body_len > 0 { @@ -565,7 +590,11 @@ async fn handle_http_over_tls( let generation_params = store.get_generation_params().await; let pending_image = store.take_pending_image().await; - let tool_ctx = if tools.is_some() || !pending_results.is_empty() || generation_params.is_some() || pending_image.is_some() { + let tool_ctx = if tools.is_some() + || !pending_results.is_empty() + || generation_params.is_some() + || pending_image.is_some() + { Some(super::modify::ToolContext { tools, tool_config, @@ -578,7 +607,9 @@ async fn handle_http_over_tls( None }; - if let Some(modified_body) = super::modify::modify_request(&raw_body, tool_ctx.as_ref()) { + if let Some(modified_body) = + super::modify::modify_request(&raw_body, tool_ctx.as_ref()) + { // Rebuild request_buf: headers (with updated Content-Length) + rechunked modified body let new_chunked = super::modify::rechunk(&modified_body); @@ -588,39 +619,12 @@ async fn handle_http_over_tls( let mut new_buf = updated_headers.into_bytes(); new_buf.extend_from_slice(&new_chunked); request_buf = new_buf; - - // Mark this as our modified request and set in-flight flag - is_our_request = true; + + // Mark in-flight IMMEDIATELY — blocks all subsequent requests store.mark_request_in_flight(); } } } - - // ── Block ALL LS follow-up requests once first is in-flight ── - // When custom tools are active, we only need ONE request to Google. - // The LS tries to send multiple requests (its own agentic loop + - // internal requests on gemini-2.5-flash-lite). Block them ALL - // immediately — don't wait for response_complete. - let has_tools = store.get_tools().await.is_some(); - if has_tools && store.is_request_in_flight() && !is_our_request { - info!( - "MITM: blocking LS follow-up — request already in-flight" - ); - // Return a fake SSE response that makes the LS stop - let fake_response = "HTTP/1.1 200 OK\r\n\ - Content-Type: text/event-stream\r\n\ - Transfer-Encoding: chunked\r\n\ - \r\n"; - let fake_sse = "data: {\"response\":{\"candidates\":[{\"content\":{\"parts\":[{\"text\":\"Request handled.\"}],\"role\":\"model\"},\"finishReason\":\"STOP\"}],\"usageMetadata\":{\"promptTokenCount\":0,\"candidatesTokenCount\":1,\"totalTokenCount\":1}}}\n\ndata: [DONE]\n\n"; - let chunked_body = super::modify::rechunk(fake_sse.as_bytes()); - let mut response = fake_response.as_bytes().to_vec(); - response.extend_from_slice(&chunked_body); - if let Err(e) = client.write_all(&response).await { - warn!(error = %e, "MITM: failed to write fake response"); - } - let _ = client.flush().await; - continue; // Skip the real upstream call - } } else { debug!( domain, @@ -674,7 +678,10 @@ async fn handle_http_over_tls( }; let n = match tokio::time::timeout(timeout, conn.read(&mut tmp)).await { - Ok(Ok(0)) => { upstream_ok = false; break; } + Ok(Ok(0)) => { + upstream_ok = false; + break; + } Ok(Ok(n)) => n, Ok(Err(e)) => { debug!(domain, error = %e, "MITM: upstream read ended"); @@ -711,7 +718,9 @@ async fn handle_http_over_tls( if header.name.eq_ignore_ascii_case("content-type") { if let Ok(v) = std::str::from_utf8(header.value) { content_type = v.to_string(); - if v.contains("text/event-stream") { is_streaming_response = true; } + if v.contains("text/event-stream") { + is_streaming_response = true; + } } } if header.name.eq_ignore_ascii_case("content-length") { @@ -721,12 +730,16 @@ async fn handle_http_over_tls( } if header.name.eq_ignore_ascii_case("connection") { if let Ok(v) = std::str::from_utf8(header.value) { - if v.trim().eq_ignore_ascii_case("close") { upstream_ok = false; } + if v.trim().eq_ignore_ascii_case("close") { + upstream_ok = false; + } } } if header.name.eq_ignore_ascii_case("transfer-encoding") { if let Ok(v) = std::str::from_utf8(header.value) { - if v.trim().eq_ignore_ascii_case("chunked") { is_chunked = true; } + if v.trim().eq_ignore_ascii_case("chunked") { + is_chunked = true; + } } } } @@ -749,22 +762,31 @@ async fn handle_http_over_tls( warn!(domain, status = http_status, body = %body_str, "MITM: upstream error response"); // Parse Google's error JSON: {"error": {"code": N, "message": "...", "status": "..."}} - let (message, error_status) = serde_json::from_str::(&body_str) - .ok() - .and_then(|v| { - let err = v.get("error")?; - let msg = err.get("message").and_then(|m| m.as_str()).map(|s| s.to_string()); - let status = err.get("status").and_then(|s| s.as_str()).map(|s| s.to_string()); - Some((msg, status)) - }) - .unwrap_or((None, None)); + let (message, error_status) = + serde_json::from_str::(&body_str) + .ok() + .and_then(|v| { + let err = v.get("error")?; + let msg = err + .get("message") + .and_then(|m| m.as_str()) + .map(|s| s.to_string()); + let status = err + .get("status") + .and_then(|s| s.as_str()) + .map(|s| s.to_string()); + Some((msg, status)) + }) + .unwrap_or((None, None)); - store.set_upstream_error(super::store::UpstreamError { - status: http_status, - body: body_str, - message, - error_status, - }).await; + store + .set_upstream_error(super::store::UpstreamError { + status: http_status, + body: body_str, + message, + error_status, + }) + .await; } // Save body for usage parsing @@ -779,10 +801,15 @@ async fn handle_http_over_tls( if !streaming_acc.function_calls.is_empty() { let calls: Vec<_> = streaming_acc.function_calls.drain(..).collect(); for fc in &calls { - store.record_function_call(cascade_hint.as_deref(), fc.clone()).await; + store + .record_function_call(cascade_hint.as_deref(), fc.clone()) + .await; } store.set_last_function_calls(calls.clone()).await; - info!("MITM: stored {} function call(s) from initial body", calls.len()); + info!( + "MITM: stored {} function call(s) from initial body", + calls.len() + ); } // Capture response + thinking text + grounding into MitmStore @@ -816,7 +843,9 @@ async fn handle_http_over_tls( } if let Some(cl) = response_content_length { - if response_body_buf.len() >= cl { break; } + if response_body_buf.len() >= cl { + break; + } } // Check chunked terminator in initial body if is_chunked && has_chunked_terminator(&response_body_buf) { @@ -837,10 +866,15 @@ async fn handle_http_over_tls( if !streaming_acc.function_calls.is_empty() { let calls: Vec<_> = streaming_acc.function_calls.drain(..).collect(); for fc in &calls { - store.record_function_call(cascade_hint.as_deref(), fc.clone()).await; + store + .record_function_call(cascade_hint.as_deref(), fc.clone()) + .await; } store.set_last_function_calls(calls.clone()).await; - info!("MITM: stored {} function call(s) from body chunk", calls.len()); + info!( + "MITM: stored {} function call(s) from body chunk", + calls.len() + ); } // Capture response + thinking text + grounding into MitmStore @@ -875,7 +909,9 @@ async fn handle_http_over_tls( response_body_buf.extend_from_slice(chunk); if let Some(cl) = response_content_length { - if response_body_buf.len() >= cl { break; } + if response_body_buf.len() >= cl { + break; + } } if is_chunked && has_chunked_terminator(&response_body_buf) { debug!(domain, "MITM: chunked response complete"); @@ -912,11 +948,7 @@ async fn handle_http_over_tls( } /// Handle a passthrough connection: transparent TCP tunnel to upstream. -async fn handle_passthrough( - mut client: TcpStream, - domain: &str, - port: u16, -) -> Result<(), String> { +async fn handle_passthrough(mut client: TcpStream, domain: &str, port: u16) -> Result<(), String> { trace!(domain, port, "MITM: transparent tunnel"); let mut upstream = TcpStream::connect(format!("{domain}:{port}")) @@ -926,7 +958,12 @@ async fn handle_passthrough( // Bidirectional copy match tokio::io::copy_bidirectional(&mut client, &mut upstream).await { Ok((client_to_server, server_to_client)) => { - trace!(domain, client_to_server, server_to_client, "MITM: tunnel closed"); + trace!( + domain, + client_to_server, + server_to_client, + "MITM: tunnel closed" + ); } Err(e) => { trace!(domain, error = %e, "MITM: tunnel error (likely clean close)"); @@ -945,7 +982,11 @@ fn has_chunked_terminator(body: &[u8]) -> bool { return false; } // Check last 7 bytes to account for possible trailing whitespace - let tail = if body.len() > 7 { &body[body.len() - 7..] } else { body }; + let tail = if body.len() > 7 { + &body[body.len() - 7..] + } else { + body + }; // Look for \r\n0\r\n\r\n anywhere in the tail tail.windows(5).any(|w| w == b"0\r\n\r\n") } diff --git a/src/mitm/store.rs b/src/mitm/store.rs index c30aa9c..c753953 100644 --- a/src/mitm/store.rs +++ b/src/mitm/store.rs @@ -2,11 +2,11 @@ //! //! The MITM proxy writes usage data here; the API handlers read from it. -use std::collections::HashMap; -use std::sync::Arc; -use std::sync::atomic::{AtomicBool, Ordering}; -use tokio::sync::RwLock; use serde::{Deserialize, Serialize}; +use std::collections::HashMap; +use std::sync::atomic::{AtomicBool, Ordering}; +use std::sync::Arc; +use tokio::sync::RwLock; use tracing::{debug, info}; /// Token usage from an intercepted API response. @@ -342,7 +342,9 @@ impl MitmStore { /// Record a captured function call from Google's response. pub async fn record_function_call(&self, cascade_id: Option<&str>, fc: CapturedFunctionCall) { - let key = cascade_id.map(|s| s.to_string()).unwrap_or_else(|| "_latest".to_string()); + let key = cascade_id + .map(|s| s.to_string()) + .unwrap_or_else(|| "_latest".to_string()); info!( cascade = %key, tool = %fc.name, @@ -377,7 +379,6 @@ impl MitmStore { self.awaiting_tool_result.store(false, Ordering::SeqCst); } - /// Take any pending function calls (ignoring cascade ID). pub async fn take_any_function_calls(&self) -> Option> { let mut pending = self.pending_function_calls.write().await; @@ -457,8 +458,6 @@ impl MitmStore { // ── Direct response capture (bypass LS) ────────────────────────────── - - /// Set (replace) the captured response text. pub async fn set_response_text(&self, text: &str) { *self.captured_response_text.write().await = Some(text.to_string()); @@ -484,8 +483,6 @@ impl MitmStore { self.response_complete.load(Ordering::SeqCst) } - - /// Async version of clear_response. pub async fn clear_response_async(&self) { self.response_complete.store(false, Ordering::SeqCst); diff --git a/src/proto.rs b/src/proto.rs index 8801284..1333192 100644 --- a/src/proto.rs +++ b/src/proto.rs @@ -293,8 +293,7 @@ mod tests { let cascade_bytes = b"test-cascade-id"; assert!( - msg.windows(cascade_bytes.len()) - .any(|w| w == cascade_bytes), + msg.windows(cascade_bytes.len()).any(|w| w == cascade_bytes), "cascade_id must appear in output" ); diff --git a/src/quota.rs b/src/quota.rs index 686a34a..f1ba439 100644 --- a/src/quota.rs +++ b/src/quota.rs @@ -93,9 +93,8 @@ impl QuotaStore { // Initial poll immediately. self.poll_once(&backend).await; - let mut interval = tokio::time::interval( - std::time::Duration::from_secs(POLL_INTERVAL_SECS), - ); + let mut interval = + tokio::time::interval(std::time::Duration::from_secs(POLL_INTERVAL_SECS)); interval.tick().await; // consume the first immediate tick loop { @@ -125,7 +124,9 @@ impl QuotaStore { // Profile picture fetch fails through iptables — harmless, suppress let data_str = data.to_string(); if data_str.contains("profile picture") { - tracing::debug!("GetUserStatus: profile picture fetch failed (expected with iptables)"); + tracing::debug!( + "GetUserStatus: profile picture fetch failed (expected with iptables)" + ); } else { warn!("GetUserStatus returned {status}: {data_str}"); } @@ -172,9 +173,7 @@ fn parse_user_status(data: &serde_json::Value) -> QuotaSnapshot { .as_str() .unwrap_or("") .to_string(); - let frac = m["quotaInfo"]["remainingFraction"] - .as_f64() - .unwrap_or(0.0); + let frac = m["quotaInfo"]["remainingFraction"].as_f64().unwrap_or(0.0); let reset_str = m["quotaInfo"]["resetTime"] .as_str() .unwrap_or("") @@ -224,9 +223,7 @@ fn parse_user_status(data: &serde_json::Value) -> QuotaSnapshot { flow_available: flow_avail, flow_total, flow_used_pct, - flex_purchasable: pi["monthlyFlexCreditPurchaseAmount"] - .as_i64() - .unwrap_or(0), + flex_purchasable: pi["monthlyFlexCreditPurchaseAmount"].as_i64().unwrap_or(0), can_buy_more: pi["canBuyMoreCredits"].as_bool().unwrap_or(false), }, models, diff --git a/src/session.rs b/src/session.rs index 2daeba0..679b9cc 100644 --- a/src/session.rs +++ b/src/session.rs @@ -66,9 +66,7 @@ impl SessionManager { msg_count: 0, }, ); - return Ok(SessionResult { - cascade_id, - }); + return Ok(SessionResult { cascade_id }); } let sid = session_id.unwrap_or(DEFAULT_SESSION).to_string(); @@ -111,9 +109,7 @@ impl SessionManager { }, ); } - Ok(SessionResult { - cascade_id, - }) + Ok(SessionResult { cascade_id }) } /// List all active sessions. @@ -146,7 +142,5 @@ impl SessionManager { fn cleanup_expired(sessions: &mut HashMap) { let now = Instant::now(); - sessions.retain(|_, s| { - now.duration_since(s.last_used).as_secs() < SESSION_TTL_SECS - }); + sessions.retain(|_, s| now.duration_since(s.last_used).as_secs() < SESSION_TTL_SECS); } diff --git a/src/snapshot.rs b/src/snapshot.rs index 80f8cb8..2c4647f 100644 --- a/src/snapshot.rs +++ b/src/snapshot.rs @@ -10,16 +10,44 @@ use std::io::{self, Read}; // ── Domain metadata ────────────────────────────────────────────────────────── const DOMAIN_INFO: &[(&str, &str, &str)] = &[ - ("antigravity-unleash.goog", "Feature Flags", "Unleash SDK — controls A/B tests and feature rollouts"), - ("daily-cloudcode-pa.googleapis.com", "LLM API (gRPC)", "Primary Gemini/Claude API endpoint"), - ("cloudcode-pa.googleapis.com", "LLM API (gRPC)", "Production Gemini/Claude API endpoint"), - ("api.anthropic.com", "Claude API", "Direct Anthropic API calls"), - ("lh3.googleusercontent.com", "Profile Picture", "User avatar"), + ( + "antigravity-unleash.goog", + "Feature Flags", + "Unleash SDK — controls A/B tests and feature rollouts", + ), + ( + "daily-cloudcode-pa.googleapis.com", + "LLM API (gRPC)", + "Primary Gemini/Claude API endpoint", + ), + ( + "cloudcode-pa.googleapis.com", + "LLM API (gRPC)", + "Production Gemini/Claude API endpoint", + ), + ( + "api.anthropic.com", + "Claude API", + "Direct Anthropic API calls", + ), + ( + "lh3.googleusercontent.com", + "Profile Picture", + "User avatar", + ), ("play.googleapis.com", "Telemetry", "Google Play telemetry"), - ("firebaseinstallations.googleapis.com", "Firebase", "Installation tracking"), + ( + "firebaseinstallations.googleapis.com", + "Firebase", + "Installation tracking", + ), ("oauth2.googleapis.com", "OAuth", "Token refresh/exchange"), ("speech.googleapis.com", "Speech", "Voice input processing"), - ("modelarmor.googleapis.com", "Safety", "Content safety/filtering"), + ( + "modelarmor.googleapis.com", + "Safety", + "Content safety/filtering", + ), ]; fn domain_label(domain: &str) -> (&str, &str) { @@ -57,8 +85,8 @@ struct HttpExchange { #[derive(Debug, Clone, Copy, PartialEq)] enum Direction { - Outgoing, // LS → upstream - Incoming, // external → LS (our curl calls) + Outgoing, // LS → upstream + Incoming, // external → LS (our curl calls) } #[derive(Default)] @@ -101,10 +129,12 @@ impl Snapshot { // LS process logs if (line.starts_with('I') || line.starts_with('W') || line.starts_with('E')) - && line.len() > 4 && line.chars().nth(1).is_some_and(|c| c.is_ascii_digit()) { - snap.ls_logs.push(line.to_string()); - continue; - } + && line.len() > 4 + && line.chars().nth(1).is_some_and(|c| c.is_ascii_digit()) + { + snap.ls_logs.push(line.to_string()); + continue; + } if line.contains("maxprocs:") { snap.ls_logs.push(line.to_string()); continue; @@ -128,8 +158,15 @@ impl Snapshot { if let Some((key, val)) = extract_header(line, "Transport encoding header") { if key == ":method" { // Finalize previous exchange - if current_pseudo.contains_key(":path") || current_pseudo.contains_key(":method") { - snap.finalize_exchange(¤t_pseudo, ¤t_headers, current_direction, current_stream.clone()); + if current_pseudo.contains_key(":path") + || current_pseudo.contains_key(":method") + { + snap.finalize_exchange( + ¤t_pseudo, + ¤t_headers, + current_direction, + current_stream.clone(), + ); } current_headers.clear(); current_pseudo.clear(); @@ -147,8 +184,15 @@ impl Snapshot { // Incoming / server-received headers if let Some((key, val)) = extract_header(line, "decoded hpack field header field") { if key == ":authority" && !line.contains("server read frame") { - if current_pseudo.contains_key(":path") || current_pseudo.contains_key(":method") { - snap.finalize_exchange(¤t_pseudo, ¤t_headers, current_direction, current_stream.clone()); + if current_pseudo.contains_key(":path") + || current_pseudo.contains_key(":method") + { + snap.finalize_exchange( + ¤t_pseudo, + ¤t_headers, + current_direction, + current_stream.clone(), + ); } current_headers.clear(); current_pseudo.clear(); @@ -167,8 +211,15 @@ impl Snapshot { if line.contains("wrote HEADERS") { if let Some(stream) = extract_stream_id(line) { current_stream = Some(stream.clone()); - if current_pseudo.contains_key(":path") || current_pseudo.contains_key(":method") { - let ex = snap.finalize_exchange(¤t_pseudo, ¤t_headers, current_direction, Some(stream)); + if current_pseudo.contains_key(":path") + || current_pseudo.contains_key(":method") + { + let ex = snap.finalize_exchange( + ¤t_pseudo, + ¤t_headers, + current_direction, + Some(stream), + ); if ex.is_some() { current_headers.clear(); current_pseudo.clear(); @@ -179,10 +230,13 @@ impl Snapshot { } // DATA frames - if (line.contains("wrote DATA") || line.contains("read DATA") || line.contains("server read frame DATA")) + if (line.contains("wrote DATA") + || line.contains("read DATA") + || line.contains("server read frame DATA")) && line.contains("data=\"") { - let is_outgoing = line.contains("wrote DATA") || line.contains("server read frame DATA"); + let is_outgoing = + line.contains("wrote DATA") || line.contains("server read frame DATA"); if let Some(stream) = extract_stream_id(line) { if let Some(data_str) = extract_data(line) { let raw = decode_go_escaped(&data_str); @@ -203,7 +257,12 @@ impl Snapshot { // Finalize remaining if current_pseudo.contains_key(":path") || current_pseudo.contains_key(":method") { - snap.finalize_exchange(¤t_pseudo, ¤t_headers, current_direction, current_stream); + snap.finalize_exchange( + ¤t_pseudo, + ¤t_headers, + current_direction, + current_stream, + ); } snap @@ -226,7 +285,11 @@ impl Snapshot { self.exchanges.push(HttpExchange { authority, - method: if method.is_empty() { "GET".into() } else { method }, + method: if method.is_empty() { + "GET".into() + } else { + method + }, path, headers: headers.to_vec(), body: Vec::new(), @@ -245,7 +308,9 @@ impl Snapshot { let sep = "═".repeat(70); let sep_thin = "─".repeat(60); out.push_str(&format!("\n{BOLD}{CYAN}{sep}{NC}\n")); - out.push_str(&format!("{BOLD}{CYAN} STANDALONE LS TRAFFIC SNAPSHOT{NC}\n")); + out.push_str(&format!( + "{BOLD}{CYAN} STANDALONE LS TRAFFIC SNAPSHOT{NC}\n" + )); out.push_str(&format!("{BOLD}{CYAN}{sep}{NC}\n\n")); // LS Logs @@ -265,7 +330,9 @@ impl Snapshot { for target in &self.connections { let domain = target.split(':').next().unwrap_or(target); let (label, desc) = domain_label(domain); - out.push_str(&format!(" {GREEN}→{NC} {BOLD}{target}{NC} {DIM}({label}){NC}\n")); + out.push_str(&format!( + " {GREEN}→{NC} {BOLD}{target}{NC} {DIM}({label}){NC}\n" + )); if !desc.is_empty() { out.push_str(&format!(" {DIM}{desc}{NC}\n")); } @@ -276,7 +343,10 @@ impl Snapshot { // Group by domain let mut by_domain: Vec<(&str, Vec<&HttpExchange>)> = Vec::new(); for ex in &self.exchanges { - if let Some(entry) = by_domain.iter_mut().find(|(d, _)| *d == ex.authority.as_str()) { + if let Some(entry) = by_domain + .iter_mut() + .find(|(d, _)| *d == ex.authority.as_str()) + { entry.1.push(ex); } else { by_domain.push((&ex.authority, vec![ex])); @@ -293,12 +363,17 @@ impl Snapshot { let color = if label.contains("API") { YELLOW } else { CYAN }; out.push_str(&format!("\n{BOLD}{sep}{NC}\n")); - out.push_str(&format!("{BOLD}{color} {domain}{NC} {DIM}— {label}{NC}\n")); + out.push_str(&format!( + "{BOLD}{color} {domain}{NC} {DIM}— {label}{NC}\n" + )); out.push_str(&format!("{BOLD}{sep}{NC}\n")); for ex in exchanges { let method_color = if ex.method == "GET" { GREEN } else { YELLOW }; - out.push_str(&format!("\n {BOLD}→ {method_color}{}{NC} {}\n", ex.method, ex.path)); + out.push_str(&format!( + "\n {BOLD}→ {method_color}{}{NC} {}\n", + ex.method, ex.path + )); // Interesting headers for (key, val) in &ex.headers { @@ -342,7 +417,10 @@ fn render_body(data: &[u8], total_len: usize) -> String { out.push_str(&format!(" {BOLD}Body ({len} bytes, JSON):{NC}\n")); for (i, line) in pretty.lines().enumerate() { if i >= 40 { - out.push_str(&format!(" {DIM}... ({} more lines){NC}\n", pretty.lines().count() - 40)); + out.push_str(&format!( + " {DIM}... ({} more lines){NC}\n", + pretty.lines().count() - 40 + )); break; } out.push_str(&format!(" {GREEN}{line}{NC}\n")); @@ -357,10 +435,16 @@ fn render_body(data: &[u8], total_len: usize) -> String { if let Ok(text) = std::str::from_utf8(&decompressed) { if let Ok(val) = serde_json::from_str::(text) { let pretty = serde_json::to_string_pretty(&val).unwrap_or_default(); - out.push_str(&format!(" {BOLD}Body ({len} bytes gzip → {} bytes, JSON):{NC}\n", decompressed.len())); + out.push_str(&format!( + " {BOLD}Body ({len} bytes gzip → {} bytes, JSON):{NC}\n", + decompressed.len() + )); for (i, line) in pretty.lines().enumerate() { if i >= 50 { - out.push_str(&format!(" {DIM}... ({} more lines){NC}\n", pretty.lines().count() - 50)); + out.push_str(&format!( + " {DIM}... ({} more lines){NC}\n", + pretty.lines().count() - 50 + )); break; } out.push_str(&format!(" {GREEN}{line}{NC}\n")); @@ -368,14 +452,20 @@ fn render_body(data: &[u8], total_len: usize) -> String { return out; } // Plain text - out.push_str(&format!(" {BOLD}Body ({len} bytes gzip → {} bytes, text):{NC}\n", decompressed.len())); + out.push_str(&format!( + " {BOLD}Body ({len} bytes gzip → {} bytes, text):{NC}\n", + decompressed.len() + )); for line in text.lines().take(20) { out.push_str(&format!(" {line}\n")); } return out; } // Binary gzip - out.push_str(&format!(" {BOLD}Body ({len} bytes gzip → {} bytes, binary):{NC}\n", decompressed.len())); + out.push_str(&format!( + " {BOLD}Body ({len} bytes gzip → {} bytes, binary):{NC}\n", + decompressed.len() + )); let strings = extract_strings(&decompressed); for s in strings.iter().take(15) { out.push_str(&format!(" {MAGENTA}{s}{NC}\n")); @@ -393,7 +483,11 @@ fn render_body(data: &[u8], total_len: usize) -> String { // Protobuf / binary with string extraction let strings = extract_strings(data); if !strings.is_empty() { - let kind = if !data.is_empty() && matches!(data[0], 0x08 | 0x0a | 0x10 | 0x12 | 0x18 | 0x1a | 0x20 | 0x22) { + let kind = if !data.is_empty() + && matches!( + data[0], + 0x08 | 0x0a | 0x10 | 0x12 | 0x18 | 0x1a | 0x20 | 0x22 + ) { "protobuf" } else { "binary" @@ -448,7 +542,9 @@ fn extract_header(line: &str, pattern: &str) -> Option<(String, String)> { fn extract_stream_id(line: &str) -> Option { let pos = line.find("stream=")?; let rest = &line[pos + 7..]; - let end = rest.find(|c: char| !c.is_ascii_digit()).unwrap_or(rest.len()); + let end = rest + .find(|c: char| !c.is_ascii_digit()) + .unwrap_or(rest.len()); Some(rest[..end].to_string()) } @@ -470,7 +566,9 @@ fn extract_data(line: &str) -> Option { fn extract_data_len(line: &str) -> Option { let pos = line.find("len=")?; let rest = &line[pos + 4..]; - let end = rest.find(|c: char| !c.is_ascii_digit()).unwrap_or(rest.len()); + let end = rest + .find(|c: char| !c.is_ascii_digit()) + .unwrap_or(rest.len()); rest[..end].parse().ok() } @@ -482,17 +580,40 @@ fn decode_go_escaped(s: &str) -> Vec { if bytes[i] == b'\\' && i + 1 < bytes.len() { match bytes[i + 1] { b'x' if i + 3 < bytes.len() => { - if let Ok(b) = u8::from_str_radix(std::str::from_utf8(&bytes[i + 2..i + 4]).unwrap_or(""), 16) { + if let Ok(b) = u8::from_str_radix( + std::str::from_utf8(&bytes[i + 2..i + 4]).unwrap_or(""), + 16, + ) { result.push(b); i += 4; continue; } } - b'n' => { result.push(b'\n'); i += 2; continue; } - b'r' => { result.push(b'\r'); i += 2; continue; } - b't' => { result.push(b'\t'); i += 2; continue; } - b'\\' => { result.push(b'\\'); i += 2; continue; } - b'"' => { result.push(b'"'); i += 2; continue; } + b'n' => { + result.push(b'\n'); + i += 2; + continue; + } + b'r' => { + result.push(b'\r'); + i += 2; + continue; + } + b't' => { + result.push(b'\t'); + i += 2; + continue; + } + b'\\' => { + result.push(b'\\'); + i += 2; + continue; + } + b'"' => { + result.push(b'"'); + i += 2; + continue; + } _ => {} } } @@ -562,7 +683,10 @@ pub fn run_cli() { }) } else { let mut buf = String::new(); - io::stdin().lock().read_to_string(&mut buf).expect("Failed to read stdin"); + io::stdin() + .lock() + .read_to_string(&mut buf) + .expect("Failed to read stdin"); buf }; diff --git a/src/standalone.rs b/src/standalone.rs index 776a3bb..6806b32 100644 --- a/src/standalone.rs +++ b/src/standalone.rs @@ -108,7 +108,10 @@ pub struct MainLSConfig { /// and CSRF is a random UUID. pub fn generate_standalone_config() -> MainLSConfig { let csrf = Uuid::new_v4().to_string(); - info!(csrf_len = csrf.len(), "Generated standalone config (headless)"); + info!( + csrf_len = csrf.len(), + "Generated standalone config (headless)" + ); MainLSConfig { extension_server_port: "0".to_string(), // disables extension server csrf, @@ -159,7 +162,13 @@ impl StandaloneLS { let app_data_dir = format!("{DATA_DIR}/.gemini/antigravity-standalone"); let annotations_dir = format!("{app_data_dir}/annotations"); let brain_dir = format!("{app_data_dir}/brain"); - for dir in [DATA_DIR, &gemini_dir, &app_data_dir, &annotations_dir, &brain_dir] { + for dir in [ + DATA_DIR, + &gemini_dir, + &app_data_dir, + &annotations_dir, + &brain_dir, + ] { let _ = std::fs::create_dir_all(dir); #[cfg(unix)] { @@ -194,7 +203,10 @@ impl StandaloneLS { #[cfg(unix)] { use std::os::unix::fs::PermissionsExt; - let _ = std::fs::set_permissions(&settings_path, std::fs::Permissions::from_mode(0o0666)); + let _ = std::fs::set_permissions( + &settings_path, + std::fs::Permissions::from_mode(0o0666), + ); } tracing::info!("Pre-seeded user_settings.pb (detect_and_use_proxy=ENABLED)"); } @@ -203,10 +215,7 @@ impl StandaloneLS { // The LS connects to this port and calls LanguageServerStarted — without it, // the LS never fully initializes and won't accept connections on its server_port. let _stub_listener = if headless { - let stub_port: u16 = main_config - .extension_server_port - .parse() - .unwrap_or(0); + let stub_port: u16 = main_config.extension_server_port.parse().unwrap_or(0); if stub_port == 0 { // Create a real listener so the LS can connect let listener = TcpListener::bind("127.0.0.1:0") @@ -215,7 +224,10 @@ impl StandaloneLS { .local_addr() .map_err(|e| format!("Failed to get stub port: {e}"))? .port(); - info!(port = actual_port, "Stub extension server listening (headless)"); + info!( + port = actual_port, + "Stub extension server listening (headless)" + ); // Read OAuth state from Antigravity's state.vscdb if available. // The DB stores the exact Topic proto (access_token + refresh_token + expiry) // which lets the LS auto-refresh tokens via its built-in Google OAuth2 client. @@ -306,10 +318,7 @@ impl StandaloneLS { // 3. MITM proxy intercepts the transparent TLS connection via SNI if let Some(mitm) = mitm_config { // Extract port from proxy_addr (e.g. "http://127.0.0.1:8742" → "8742") - let mitm_port = mitm.proxy_addr - .rsplit(':') - .next() - .unwrap_or("8742"); + let mitm_port = mitm.proxy_addr.rsplit(':').next().unwrap_or("8742"); format!("https://daily-cloudcode-pa.googleapis.com:{mitm_port}") } else { "https://daily-cloudcode-pa.googleapis.com".to_string() @@ -324,9 +333,8 @@ impl StandaloneLS { debug!(?args, "LS args"); // Build env vars for the LS process - let mut env_vars: Vec<(String, String)> = vec![ - ("ANTIGRAVITY_EDITOR_APP_ROOT".into(), APP_ROOT.into()), - ]; + let mut env_vars: Vec<(String, String)> = + vec![("ANTIGRAVITY_EDITOR_APP_ROOT".into(), APP_ROOT.into())]; // If MITM is enabled, add SSL + proxy env vars if let Some(mitm) = mitm_config { @@ -335,8 +343,8 @@ impl StandaloneLS { // Write to /tmp — accessible by antigravity-ls user // (user's ~/.config/ is not traversable by other UIDs) let combined_ca_path = "/tmp/antigravity-mitm-combined-ca.pem".to_string(); - let system_ca = std::fs::read_to_string("/etc/ssl/certs/ca-certificates.crt") - .unwrap_or_default(); + let system_ca = + std::fs::read_to_string("/etc/ssl/certs/ca-certificates.crt").unwrap_or_default(); let mitm_ca = std::fs::read_to_string(&mitm.ca_cert_path) .map_err(|e| format!("Failed to read MITM CA cert: {e}"))?; std::fs::write(&combined_ca_path, format!("{system_ca}\n{mitm_ca}")) @@ -441,7 +449,11 @@ impl StandaloneLS { }; if let Some(pid) = ls_pid { - info!(ls_pid = pid, sudo = use_sudo, "Discovered actual LS process"); + info!( + ls_pid = pid, + sudo = use_sudo, + "Discovered actual LS process" + ); } Ok(StandaloneLS { @@ -617,8 +629,7 @@ fn find_main_ls_pid() -> Result { return Err("No /proc filesystem".to_string()); } - let entries = std::fs::read_dir(proc) - .map_err(|e| format!("Cannot read /proc: {e}"))?; + let entries = std::fs::read_dir(proc).map_err(|e| format!("Cannot read /proc: {e}"))?; for entry in entries.flatten() { let name = entry.file_name(); @@ -704,12 +715,10 @@ fn cleanup_orphaned_ls() { .output(); let pids: Vec = match output { - Ok(out) => { - String::from_utf8_lossy(&out.stdout) - .lines() - .filter_map(|l| l.trim().parse().ok()) - .collect() - } + Ok(out) => String::from_utf8_lossy(&out.stdout) + .lines() + .filter_map(|l| l.trim().parse().ok()) + .collect(), Err(_) => return, }; @@ -717,7 +726,11 @@ fn cleanup_orphaned_ls() { return; } - info!(count = pids.len(), ?pids, "Cleaning up orphaned standalone LS processes"); + info!( + count = pids.len(), + ?pids, + "Cleaning up orphaned standalone LS processes" + ); // Kill each PID by running `kill` AS the antigravity-ls user. // This works because same-UID processes can signal each other, @@ -870,7 +883,8 @@ fn extract_access_token_from_topic(topic_bytes: &[u8]) -> Option { // Simple approach: convert to string and find base64 pattern let as_str = String::from_utf8_lossy(topic_bytes); // The base64 OAuthTokenInfo starts with "Co" (0x0A = field 1, len-delimited) - for segment in as_str.split(|c: char| !c.is_alphanumeric() && c != '+' && c != '/' && c != '=') { + for segment in as_str.split(|c: char| !c.is_alphanumeric() && c != '+' && c != '/' && c != '=') + { if segment.len() > 50 { if let Ok(decoded) = base64::engine::general_purpose::STANDARD.decode(segment) { // Try to extract field 1 (access_token) from the OAuthTokenInfo proto @@ -951,7 +965,11 @@ fn decode_varint_at(buf: &[u8], offset: usize) -> Option<(u64, usize)> { /// IMPORTANT: `SubscribeToUnifiedStateSyncTopic` is a long-lived stream. /// If we immediately close it, the LS reconnects in a tight loop and never /// proceeds to fetch OAuth tokens. We keep subscription connections OPEN. -fn stub_handle_connection(conn: std::net::TcpStream, oauth_token: &str, oauth_topic_bytes: &Option>) { +fn stub_handle_connection( + conn: std::net::TcpStream, + oauth_token: &str, + oauth_topic_bytes: &Option>, +) { use std::io::{BufRead, BufReader, Read, Write}; let mut reader = BufReader::new(match conn.try_clone() { @@ -1028,7 +1046,7 @@ fn stub_handle_connection(conn: std::net::TcpStream, oauth_token: &str, oauth_to i += 1; if i + len <= proto_body.len() { if field_num == 1 { - topic_name = String::from_utf8_lossy(&proto_body[i..i+len]).to_string(); + topic_name = String::from_utf8_lossy(&proto_body[i..i + len]).to_string(); } i += len; } else { @@ -1084,7 +1102,10 @@ fn stub_handle_connection(conn: std::net::TcpStream, oauth_token: &str, oauth_to // This includes access_token + refresh_token + expiry, so the // LS can auto-refresh tokens via its built-in Google OAuth2 client. initial_state_bytes = topic_bytes.clone(); - eprintln!("[stub-ext] using state.vscdb topic ({} bytes)", topic_bytes.len()); + eprintln!( + "[stub-ext] using state.vscdb topic ({} bytes)", + topic_bytes.len() + ); } else if !oauth_token.is_empty() { // Manual token fallback — construct OAuthTokenInfo with far-future expiry // (no refresh_token, so the LS can't auto-refresh) @@ -1155,7 +1176,10 @@ fn stub_handle_connection(conn: std::net::TcpStream, oauth_token: &str, oauth_to if !send_chunk(&mut writer, &initial_env) { return; } - eprintln!("[stub-ext] STREAM → sent initial_state ({} bytes)", initial_state_bytes.len()); + eprintln!( + "[stub-ext] STREAM → sent initial_state ({} bytes)", + initial_state_bytes.len() + ); // (applied_update removed — data is in initial_state) @@ -1197,7 +1221,10 @@ fn stub_handle_connection(conn: std::net::TcpStream, oauth_token: &str, oauth_to if !oauth_token.is_empty() { // Build protobuf: GetSecretValueResponse { string value = 1 } let proto = encode_proto_string(1, oauth_token.as_bytes()); - eprintln!("[stub-ext] → serving token ({} bytes) for key={key:?}", oauth_token.len()); + eprintln!( + "[stub-ext] → serving token ({} bytes) for key={key:?}", + oauth_token.len() + ); // Data envelope: flag=0x00, length, data envelope.push(0x00u8); diff --git a/src/warmup.rs b/src/warmup.rs index a447b95..9e44947 100644 --- a/src/warmup.rs +++ b/src/warmup.rs @@ -34,7 +34,9 @@ pub async fn warmup_sequence(backend: &Backend, headless: bool) { ) .await { - Ok(Ok((status, _))) => info!("SetUserSettings (detect_and_use_proxy=ENABLED): {status}"), + Ok(Ok((status, _))) => { + info!("SetUserSettings (detect_and_use_proxy=ENABLED): {status}") + } Ok(Err(e)) => warn!("SetUserSettings failed: {e}"), Err(_) => warn!("SetUserSettings timed out"), } @@ -59,12 +61,7 @@ pub async fn warmup_sequence(backend: &Backend, headless: bool) { for (method, body) in calls { // Timeout per call — in headless mode, the LS can't reach Google's API // so these would hang forever without a timeout. Warmup is best-effort. - match tokio::time::timeout( - Duration::from_secs(5), - backend.call_json(method, body), - ) - .await - { + match tokio::time::timeout(Duration::from_secs(5), backend.call_json(method, body)).await { Ok(Ok((status, _))) => debug!("Warmup {method}: {status}"), Ok(Err(e)) => warn!("Warmup {method} failed: {e}"), Err(_) => warn!("Warmup {method} timed out"), @@ -87,10 +84,7 @@ pub fn start_heartbeat(backend: Arc) -> JoinHandle<()> { let interval_ms = rand::thread_rng().gen_range(29_500..30_500); tokio::time::sleep(Duration::from_millis(interval_ms)).await; - match backend - .call_json("Heartbeat", &serde_json::json!({})) - .await - { + match backend.call_json("Heartbeat", &serde_json::json!({})).await { Ok((status, _)) => debug!("Heartbeat: {status}"), Err(e) => warn!("Heartbeat failed: {e}"), }