fix: block ALL LS follow-up requests across connections

Move the in-flight blocking check to the top of the LLM request flow,
BEFORE request modification. This catches follow-ups on ALL connections
(the LS opens multiple parallel TLS connections). Only the very first
modified request reaches Google — all others get fake STOP responses.

Previously, each new connection independently allowed one request
through before blocking, letting 4-5 requests leak per turn.
This commit is contained in:
Nikketryhard
2026-02-16 00:57:33 -06:00
parent a8f3c8915f
commit 3fdd0368a0
23 changed files with 992 additions and 568 deletions

View File

@@ -10,9 +10,11 @@ use std::sync::Arc;
use tracing::{debug, info, warn}; use tracing::{debug, info, warn};
use super::models::{lookup_model, DEFAULT_MODEL, MODELS}; use super::models::{lookup_model, DEFAULT_MODEL, MODELS};
use super::polling::{extract_response_text, extract_thinking_content, is_response_done, poll_for_response}; use super::polling::{
extract_response_text, extract_thinking_content, is_response_done, poll_for_response,
};
use super::types::*; use super::types::*;
use super::util::{err_response, upstream_err_response, now_unix}; use super::util::{err_response, now_unix, upstream_err_response};
use super::AppState; use super::AppState;
/// Extract a conversation/session ID from a flexible JSON value. /// Extract a conversation/session ID from a flexible JSON value.
@@ -33,7 +35,8 @@ fn system_fingerprint() -> String {
/// Build a streaming chunk JSON with all required OpenAI fields. /// Build a streaming chunk JSON with all required OpenAI fields.
/// Includes system_fingerprint, service_tier, and logprobs:null in choices. /// Includes system_fingerprint, service_tier, and logprobs:null in choices.
fn chunk_json( fn chunk_json(
id: &str, model: &str, id: &str,
model: &str,
choices: serde_json::Value, choices: serde_json::Value,
usage: Option<serde_json::Value>, usage: Option<serde_json::Value>,
) -> String { ) -> String {
@@ -53,7 +56,11 @@ fn chunk_json(
} }
/// Build a single choice for a streaming chunk (delta + finish_reason + logprobs). /// Build a single choice for a streaming chunk (delta + finish_reason + logprobs).
fn chunk_choice(index: u32, delta: serde_json::Value, finish_reason: Option<&str>) -> serde_json::Value { fn chunk_choice(
index: u32,
delta: serde_json::Value,
finish_reason: Option<&str>,
) -> serde_json::Value {
serde_json::json!({ serde_json::json!({
"index": index, "index": index,
"delta": delta, "delta": delta,
@@ -70,7 +77,9 @@ fn chunk_choice(index: u32, delta: serde_json::Value, finish_reason: Option<&str
fn google_to_openai_finish_reason(stop_reason: Option<&str>) -> &'static str { fn google_to_openai_finish_reason(stop_reason: Option<&str>) -> &'static str {
match stop_reason { match stop_reason {
Some("MAX_TOKENS") => "length", Some("MAX_TOKENS") => "length",
Some("SAFETY") | Some("RECITATION") | Some("BLOCKLIST") | Some("PROHIBITED_CONTENT") => "content_filter", Some("SAFETY") | Some("RECITATION") | Some("BLOCKLIST") | Some("PROHIBITED_CONTENT") => {
"content_filter"
}
_ => "stop", _ => "stop",
} }
} }
@@ -84,7 +93,9 @@ fn google_to_openai_finish_reason(stop_reason: Option<&str>) -> &'static str {
/// sends the entire messages array to the model. /// sends the entire messages array to the model.
fn extract_chat_input(messages: &[CompletionMessage]) -> (String, Option<crate::proto::ImageData>) { fn extract_chat_input(messages: &[CompletionMessage]) -> (String, Option<crate::proto::ImageData>) {
// Extract image from last user message content array // Extract image from last user message content array
let image = messages.iter().rev() let image = messages
.iter()
.rev()
.find(|m| m.role == "user") .find(|m| m.role == "user")
.and_then(|m| super::util::extract_first_image(&m.content)); .and_then(|m| super::util::extract_first_image(&m.content));
// Always build the full conversation // Always build the full conversation
@@ -141,10 +152,7 @@ fn build_conversation_with_tools(messages: &[CompletionMessage]) -> String {
if let Some(func) = tc.get("function") { if let Some(func) = tc.get("function") {
let name = func["name"].as_str().unwrap_or("unknown"); let name = func["name"].as_str().unwrap_or("unknown");
let args = func["arguments"].as_str().unwrap_or("{}"); let args = func["arguments"].as_str().unwrap_or("{}");
parts.push(format!( parts.push(format!("[Tool call: {}({})]", name, args));
"[Tool call: {}({})]",
name, args
));
} }
} }
} }
@@ -153,10 +161,7 @@ fn build_conversation_with_tools(messages: &[CompletionMessage]) -> String {
let text = extract_message_text(&msg.content); let text = extract_message_text(&msg.content);
let tool_id = msg.tool_call_id.as_deref().unwrap_or("unknown"); let tool_id = msg.tool_call_id.as_deref().unwrap_or("unknown");
if !text.is_empty() { if !text.is_empty() {
parts.push(format!( parts.push(format!("[Tool result ({})]:\n{}", tool_id, text));
"[Tool result ({})]:\n{}",
tool_id, text
));
} }
} }
_ => {} _ => {}
@@ -202,7 +207,10 @@ pub(crate) async fn handle_completions(
let gemini_config = crate::mitm::modify::openai_tool_choice_to_gemini(choice); let gemini_config = crate::mitm::modify::openai_tool_choice_to_gemini(choice);
state.mitm_store.set_tool_config(gemini_config).await; state.mitm_store.set_tool_config(gemini_config).await;
} }
info!(count = tools.len(), "Completions: stored client tools for MITM injection"); info!(
count = tools.len(),
"Completions: stored client tools for MITM injection"
);
} else { } else {
state.mitm_store.clear_tools().await; state.mitm_store.clear_tools().await;
} }
@@ -239,10 +247,15 @@ pub(crate) async fn handle_completions(
google_search: body.web_search, google_search: body.web_search,
}; };
// Only store if at least one param is set // Only store if at least one param is set
if gp.temperature.is_some() || gp.top_p.is_some() || gp.max_output_tokens.is_some() if gp.temperature.is_some()
|| gp.frequency_penalty.is_some() || gp.presence_penalty.is_some() || gp.top_p.is_some()
|| gp.reasoning_effort.is_some() || gp.stop_sequences.is_some() || gp.max_output_tokens.is_some()
|| gp.response_mime_type.is_some() || gp.response_schema.is_some() || gp.frequency_penalty.is_some()
|| gp.presence_penalty.is_some()
|| gp.reasoning_effort.is_some()
|| gp.stop_sequences.is_some()
|| gp.response_mime_type.is_some()
|| gp.response_schema.is_some()
|| gp.google_search || gp.google_search
{ {
state.mitm_store.set_generation_params(gp).await; state.mitm_store.set_generation_params(gp).await;
@@ -306,12 +319,13 @@ pub(crate) async fn handle_completions(
// Store image for MITM injection (LS doesn't forward images to Google API) // Store image for MITM injection (LS doesn't forward images to Google API)
if let Some(ref img) = image { if let Some(ref img) = image {
use base64::Engine; use base64::Engine;
state.mitm_store.set_pending_image( state
crate::mitm::store::PendingImage { .mitm_store
.set_pending_image(crate::mitm::store::PendingImage {
base64_data: base64::engine::general_purpose::STANDARD.encode(&img.data), base64_data: base64::engine::general_purpose::STANDARD.encode(&img.data),
mime_type: img.mime_type.clone(), mime_type: img.mime_type.clone(),
} })
).await; .await;
} }
match state match state
.backend .backend
@@ -346,7 +360,10 @@ pub(crate) async fn handle_completions(
uuid::Uuid::new_v4().to_string().replace('-', "") uuid::Uuid::new_v4().to_string().replace('-', "")
); );
let include_usage = body.stream_options.as_ref().map_or(false, |o| o.include_usage); let include_usage = body
.stream_options
.as_ref()
.map_or(false, |o| o.include_usage);
if body.stream { if body.stream {
chat_completions_stream( chat_completions_stream(
@@ -374,11 +391,17 @@ pub(crate) async fn handle_completions(
match state.backend.create_cascade().await { match state.backend.create_cascade().await {
Ok(cid) => { Ok(cid) => {
// Send the same message on each extra cascade // Send the same message on each extra cascade
match state.backend.send_message_with_image(&cid, &user_text, model.model_enum, image.as_ref()).await { match state
.backend
.send_message_with_image(&cid, &user_text, model.model_enum, image.as_ref())
.await
{
Ok((200, _)) => { Ok((200, _)) => {
let bg = Arc::clone(&state.backend); let bg = Arc::clone(&state.backend);
let cid2 = cid.clone(); let cid2 = cid.clone();
tokio::spawn(async move { let _ = bg.update_annotations(&cid2).await; }); tokio::spawn(async move {
let _ = bg.update_annotations(&cid2).await;
});
extra_cascade_ids.push(cid); extra_cascade_ids.push(cid);
} }
_ => {} // Skip failed cascades _ => {} // Skip failed cascades
@@ -420,7 +443,12 @@ pub(crate) async fn handle_completions(
mitm.as_ref().and_then(|u| u.stop_reason.as_deref()), mitm.as_ref().and_then(|u| u.stop_reason.as_deref()),
); );
let (pt, ct, cached, thinking) = if let Some(ref mu) = mitm { let (pt, ct, cached, thinking) = if let Some(ref mu) = mitm {
(mu.input_tokens, mu.output_tokens, mu.cache_read_input_tokens, mu.thinking_output_tokens) (
mu.input_tokens,
mu.output_tokens,
mu.cache_read_input_tokens,
mu.thinking_output_tokens,
)
} else if let Some(u) = &result.usage { } else if let Some(u) = &result.usage {
(u.input_tokens, u.output_tokens, 0, 0) (u.input_tokens, u.output_tokens, 0, 0)
} else { } else {
@@ -874,15 +902,22 @@ async fn chat_completions_sync(
None => state.mitm_store.take_usage("_latest").await, None => state.mitm_store.take_usage("_latest").await,
}; };
let finish_reason = google_to_openai_finish_reason(mitm.as_ref().and_then(|u| u.stop_reason.as_deref())); let finish_reason =
google_to_openai_finish_reason(mitm.as_ref().and_then(|u| u.stop_reason.as_deref()));
let (prompt_tokens, completion_tokens, cached_tokens, thinking_tokens) = if let Some(ref mitm_usage) = mitm { let (prompt_tokens, completion_tokens, cached_tokens, thinking_tokens) =
(mitm_usage.input_tokens, mitm_usage.output_tokens, mitm_usage.cache_read_input_tokens, mitm_usage.thinking_output_tokens) if let Some(ref mitm_usage) = mitm {
} else if let Some(u) = &result.usage { (
(u.input_tokens, u.output_tokens, 0, 0) mitm_usage.input_tokens,
} else { mitm_usage.output_tokens,
(0, 0, 0, 0) mitm_usage.cache_read_input_tokens,
}; mitm_usage.thinking_output_tokens,
)
} else if let Some(u) = &result.usage {
(u.input_tokens, u.output_tokens, 0, 0)
} else {
(0, 0, 0, 0)
};
// Build message object, including reasoning_content if thinking is present // Build message object, including reasoning_content if thinking is present
let mut message = serde_json::json!({ let mut message = serde_json::json!({

View File

@@ -15,7 +15,9 @@ use std::sync::Arc;
use tracing::{info, warn}; use tracing::{info, warn};
use super::models::{lookup_model, DEFAULT_MODEL, MODELS}; use super::models::{lookup_model, DEFAULT_MODEL, MODELS};
use super::polling::{extract_response_text, extract_thinking_content, is_response_done, poll_for_response}; use super::polling::{
extract_response_text, extract_thinking_content, is_response_done, poll_for_response,
};
use super::util::{err_response, upstream_err_response}; use super::util::{err_response, upstream_err_response};
use super::AppState; use super::AppState;
use crate::mitm::store::PendingToolResult; use crate::mitm::store::PendingToolResult;
@@ -84,7 +86,9 @@ async fn build_usage_metadata(
store: &crate::mitm::store::MitmStore, store: &crate::mitm::store::MitmStore,
cascade_id: &str, cascade_id: &str,
) -> serde_json::Value { ) -> serde_json::Value {
let usage = store.take_usage(cascade_id).await let usage = store
.take_usage(cascade_id)
.await
.or(store.take_usage("_latest").await); .or(store.take_usage("_latest").await);
if let Some(usage) = usage { if let Some(usage) = usage {
serde_json::json!({ serde_json::json!({
@@ -152,13 +156,12 @@ pub(crate) async fn handle_gemini(
// Gemini-native inlineData format // Gemini-native inlineData format
if image.is_none() { if image.is_none() {
if let Some(inline) = obj.get("inlineData") { if let Some(inline) = obj.get("inlineData") {
if let (Some(mime), Some(b64)) = ( if let (Some(mime), Some(b64)) =
inline["mimeType"].as_str(), (inline["mimeType"].as_str(), inline["data"].as_str())
inline["data"].as_str(), {
) { if let Some(img) = super::util::parse_data_uri(&format!(
if let Some(img) = super::util::parse_data_uri( "data:{mime};base64,{b64}"
&format!("data:{mime};base64,{b64}") )) {
) {
image = Some(img); image = Some(img);
} }
} }
@@ -194,7 +197,10 @@ pub(crate) async fn handle_gemini(
if let Some(ref tools) = body.tools { if let Some(ref tools) = body.tools {
if !tools.is_empty() { if !tools.is_empty() {
state.mitm_store.set_tools(tools.clone()).await; state.mitm_store.set_tools(tools.clone()).await;
info!(count = tools.len(), "Stored Gemini-native tools for MITM injection"); info!(
count = tools.len(),
"Stored Gemini-native tools for MITM injection"
);
} }
} }
if let Some(ref config) = body.tool_config { if let Some(ref config) = body.tool_config {
@@ -207,13 +213,19 @@ pub(crate) async fn handle_gemini(
if let Some(fr) = r.get("functionResponse") { if let Some(fr) = r.get("functionResponse") {
let name = fr["name"].as_str().unwrap_or("unknown").to_string(); let name = fr["name"].as_str().unwrap_or("unknown").to_string();
let response = fr.get("response").cloned().unwrap_or(serde_json::json!({})); let response = fr.get("response").cloned().unwrap_or(serde_json::json!({}));
state.mitm_store.add_tool_result(PendingToolResult { state
name, .mitm_store
result: response, .add_tool_result(PendingToolResult {
}).await; name,
result: response,
})
.await;
} }
} }
info!(count = results.len(), "Stored Gemini-native tool results for MITM injection"); info!(
count = results.len(),
"Stored Gemini-native tool results for MITM injection"
);
} }
// Store generation parameters for MITM injection // Store generation parameters for MITM injection
@@ -232,9 +244,13 @@ pub(crate) async fn handle_gemini(
response_schema: None, response_schema: None,
google_search: body.google_search, google_search: body.google_search,
}; };
if gp.temperature.is_some() || gp.top_p.is_some() || gp.top_k.is_some() if gp.temperature.is_some()
|| gp.max_output_tokens.is_some() || gp.stop_sequences.is_some() || gp.top_p.is_some()
|| gp.reasoning_effort.is_some() || gp.google_search || gp.top_k.is_some()
|| gp.max_output_tokens.is_some()
|| gp.stop_sequences.is_some()
|| gp.reasoning_effort.is_some()
|| gp.google_search
{ {
state.mitm_store.set_generation_params(gp).await; state.mitm_store.set_generation_params(gp).await;
} else { } else {
@@ -277,12 +293,13 @@ pub(crate) async fn handle_gemini(
// Store image for MITM injection (LS doesn't forward images to Google API) // Store image for MITM injection (LS doesn't forward images to Google API)
if let Some(ref img) = image { if let Some(ref img) = image {
use base64::Engine; use base64::Engine;
state.mitm_store.set_pending_image( state
crate::mitm::store::PendingImage { .mitm_store
.set_pending_image(crate::mitm::store::PendingImage {
base64_data: base64::engine::general_purpose::STANDARD.encode(&img.data), base64_data: base64::engine::general_purpose::STANDARD.encode(&img.data),
mime_type: img.mime_type.clone(), mime_type: img.mime_type.clone(),
} })
).await; .await;
} }
match state match state
.backend .backend
@@ -372,7 +389,11 @@ async fn gemini_sync(
// Check for completed text response // Check for completed text response
if state.mitm_store.is_response_complete() { if state.mitm_store.is_response_complete() {
let text = state.mitm_store.take_response_text().await.unwrap_or_default(); let text = state
.mitm_store
.take_response_text()
.await
.unwrap_or_default();
let thinking = state.mitm_store.take_thinking_text().await; let thinking = state.mitm_store.take_thinking_text().await;
// Guard against stale response_complete with no data // Guard against stale response_complete with no data

View File

@@ -44,7 +44,6 @@ pub fn router(state: Arc<AppState>) -> Router {
post(completions::handle_completions), post(completions::handle_completions),
) )
.route("/v1/gemini", post(gemini::handle_gemini)) .route("/v1/gemini", post(gemini::handle_gemini))
.route("/v1/models", get(handle_models)) .route("/v1/models", get(handle_models))
.route("/v1/sessions", get(handle_list_sessions)) .route("/v1/sessions", get(handle_list_sessions))
.route("/v1/sessions/{id}", delete(handle_delete_session)) .route("/v1/sessions/{id}", delete(handle_delete_session))
@@ -106,9 +105,7 @@ async fn handle_models() -> Json<serde_json::Value> {
Json(serde_json::json!({"object": "list", "data": models})) Json(serde_json::json!({"object": "list", "data": models}))
} }
async fn handle_list_sessions( async fn handle_list_sessions(State(state): State<Arc<AppState>>) -> Json<serde_json::Value> {
State(state): State<Arc<AppState>>,
) -> Json<serde_json::Value> {
let sessions = state.sessions.list_sessions().await; let sessions = state.sessions.list_sessions().await;
Json(serde_json::json!({"sessions": sessions})) Json(serde_json::json!({"sessions": sessions}))
} }
@@ -155,9 +152,7 @@ async fn handle_set_token(
) )
} }
async fn handle_usage( async fn handle_usage(State(state): State<Arc<AppState>>) -> Json<serde_json::Value> {
State(state): State<Arc<AppState>>,
) -> Json<serde_json::Value> {
let stats = state.mitm_store.stats().await; let stats = state.mitm_store.stats().await;
Json(serde_json::json!({ Json(serde_json::json!({
"mitm": { "mitm": {
@@ -174,9 +169,7 @@ async fn handle_usage(
})) }))
} }
async fn handle_quota( async fn handle_quota(State(state): State<Arc<AppState>>) -> Json<serde_json::Value> {
State(state): State<Arc<AppState>>,
) -> Json<serde_json::Value> {
let snap = state.quota_store.snapshot().await; let snap = state.quota_store.snapshot().await;
Json(serde_json::to_value(snap).unwrap_or_default()) Json(serde_json::to_value(snap).unwrap_or_default())
} }

View File

@@ -84,14 +84,8 @@ pub(crate) fn extract_model_usage(steps: &[serde_json::Value]) -> Option<ModelUs
return Some(ModelUsage { return Some(ModelUsage {
input_tokens: input, input_tokens: input,
output_tokens: output, output_tokens: output,
api_provider: usage["apiProvider"] api_provider: usage["apiProvider"].as_str().unwrap_or("").to_string(),
.as_str() model: usage["model"].as_str().unwrap_or("").to_string(),
.unwrap_or("")
.to_string(),
model: usage["model"]
.as_str()
.unwrap_or("")
.to_string(),
}); });
} }
} }
@@ -263,23 +257,36 @@ pub(crate) async fn poll_for_response(
} else { } else {
info!( info!(
"Response done ({short_id}), {:.1}s, {} chars (no usage){}{}", "Response done ({short_id}), {:.1}s, {} chars (no usage){}{}",
elapsed, text.len(), elapsed,
thinking.as_ref().map_or(String::new(), |t| format!(", thinking: {} chars", t.len())), text.len(),
if thinking_signature.is_some() { ", has sig" } else { "" } thinking.as_ref().map_or(String::new(), |t| format!(
", thinking: {} chars",
t.len()
)),
if thinking_signature.is_some() {
", has sig"
} else {
""
}
); );
} }
return PollResult { text, usage, thinking_signature, thinking, thinking_duration, upstream_error: None }; return PollResult {
text,
usage,
thinking_signature,
thinking,
thinking_duration,
upstream_error: None,
};
} }
} }
// Fallback: check trajectory IDLE status (catches edge cases) // Fallback: check trajectory IDLE status (catches edge cases)
// Only check every 5th poll to reduce network calls // Only check every 5th poll to reduce network calls
if step_count > 4 && step_count % 5 == 0 { if step_count > 4 && step_count % 5 == 0 {
if let Ok((ts, td)) = state.backend.get_trajectory(cascade_id).await if let Ok((ts, td)) = state.backend.get_trajectory(cascade_id).await {
{
if ts == 200 { if ts == 200 {
let run_status = let run_status = td["status"].as_str().unwrap_or("");
td["status"].as_str().unwrap_or("");
if run_status.contains("IDLE") { if run_status.contains("IDLE") {
let text = extract_response_text(steps); let text = extract_response_text(steps);
if !text.is_empty() { if !text.is_empty() {
@@ -293,7 +300,14 @@ pub(crate) async fn poll_for_response(
elapsed, elapsed,
text.len() text.len()
); );
return PollResult { text, usage, thinking_signature, thinking, thinking_duration, upstream_error: None }; return PollResult {
text,
usage,
thinking_signature,
thinking,
thinking_duration,
upstream_error: None,
};
} }
} }
} }

View File

@@ -14,12 +14,15 @@ use std::sync::Arc;
use tracing::{debug, info}; use tracing::{debug, info};
use super::models::{lookup_model, DEFAULT_MODEL, MODELS}; use super::models::{lookup_model, DEFAULT_MODEL, MODELS};
use super::polling::{extract_response_text, is_response_done, poll_for_response, extract_model_usage, extract_thinking_signature, extract_thinking_content}; use super::polling::{
extract_model_usage, extract_response_text, extract_thinking_content,
extract_thinking_signature, is_response_done, poll_for_response,
};
use super::types::*; use super::types::*;
use super::util::{err_response, upstream_err_response, now_unix, responses_sse_event}; use super::util::{err_response, now_unix, responses_sse_event, upstream_err_response};
use super::AppState; use super::AppState;
use crate::mitm::modify::{openai_tool_choice_to_gemini, openai_tools_to_gemini};
use crate::mitm::store::PendingToolResult; use crate::mitm::store::PendingToolResult;
use crate::mitm::modify::{openai_tools_to_gemini, openai_tool_choice_to_gemini};
// ─── Input extraction ──────────────────────────────────────────────────────── // ─── Input extraction ────────────────────────────────────────────────────────
@@ -35,7 +38,11 @@ struct ToolResultInput {
fn extract_responses_input( fn extract_responses_input(
input: &serde_json::Value, input: &serde_json::Value,
instructions: Option<&str>, instructions: Option<&str>,
) -> (String, Vec<ToolResultInput>, Option<crate::proto::ImageData>) { ) -> (
String,
Vec<ToolResultInput>,
Option<crate::proto::ImageData>,
) {
let mut tool_results: Vec<ToolResultInput> = Vec::new(); let mut tool_results: Vec<ToolResultInput> = Vec::new();
let mut image: Option<crate::proto::ImageData> = None; let mut image: Option<crate::proto::ImageData> = None;
@@ -45,10 +52,9 @@ fn extract_responses_input(
// Check for function_call_output items // Check for function_call_output items
for item in items { for item in items {
if item["type"].as_str() == Some("function_call_output") { if item["type"].as_str() == Some("function_call_output") {
if let (Some(call_id), Some(output)) = ( if let (Some(call_id), Some(output)) =
item["call_id"].as_str(), (item["call_id"].as_str(), item["output"].as_str())
item["output"].as_str(), {
) {
tool_results.push(ToolResultInput { tool_results.push(ToolResultInput {
call_id: call_id.to_string(), call_id: call_id.to_string(),
output: output.to_string(), output: output.to_string(),
@@ -230,24 +236,31 @@ pub(crate) async fn handle_responses(
); );
} }
let (user_text, tool_results, image) = extract_responses_input(&body.input, body.instructions.as_deref()); let (user_text, tool_results, image) =
extract_responses_input(&body.input, body.instructions.as_deref());
// Handle tool result submission (function_call_output in input) // Handle tool result submission (function_call_output in input)
let is_tool_result_turn = !tool_results.is_empty(); let is_tool_result_turn = !tool_results.is_empty();
if is_tool_result_turn { if is_tool_result_turn {
for tr in &tool_results { for tr in &tool_results {
// Look up function name from call_id // Look up function name from call_id
let name = state.mitm_store.lookup_call_id(&tr.call_id).await let name = state
.mitm_store
.lookup_call_id(&tr.call_id)
.await
.unwrap_or_else(|| "unknown_function".to_string()); .unwrap_or_else(|| "unknown_function".to_string());
// Parse the output as JSON, fall back to string wrapper // Parse the output as JSON, fall back to string wrapper
let result_value = serde_json::from_str::<serde_json::Value>(&tr.output) let result_value = serde_json::from_str::<serde_json::Value>(&tr.output)
.unwrap_or_else(|_| serde_json::json!({"result": tr.output})); .unwrap_or_else(|_| serde_json::json!({"result": tr.output}));
state.mitm_store.add_tool_result(PendingToolResult { state
name, .mitm_store
result: result_value, .add_tool_result(PendingToolResult {
}).await; name,
result: result_value,
})
.await;
} }
info!( info!(
count = tool_results.len(), count = tool_results.len(),
@@ -275,7 +288,10 @@ pub(crate) async fn handle_responses(
let gemini_tools = openai_tools_to_gemini(tools); let gemini_tools = openai_tools_to_gemini(tools);
if !gemini_tools.is_empty() { if !gemini_tools.is_empty() {
state.mitm_store.set_tools(gemini_tools).await; state.mitm_store.set_tools(gemini_tools).await;
info!(count = tools.len(), "Stored client tools for MITM injection"); info!(
count = tools.len(),
"Stored client tools for MITM injection"
);
} }
} }
if let Some(ref choice) = body.tool_choice { if let Some(ref choice) = body.tool_choice {
@@ -289,7 +305,9 @@ pub(crate) async fn handle_responses(
let fmt_type = text_val["format"]["type"].as_str().unwrap_or("text"); let fmt_type = text_val["format"]["type"].as_str().unwrap_or("text");
if fmt_type == "json_schema" { if fmt_type == "json_schema" {
let name = text_val["format"]["name"].as_str().map(|s| s.to_string()); let name = text_val["format"]["name"].as_str().map(|s| s.to_string());
let schema = text_val["format"]["schema"].as_object().map(|o| serde_json::Value::Object(o.clone())); let schema = text_val["format"]["schema"]
.as_object()
.map(|o| serde_json::Value::Object(o.clone()));
let strict = text_val["format"]["strict"].as_bool(); let strict = text_val["format"]["strict"].as_bool();
let tf = TextFormat { let tf = TextFormat {
format: TextFormatInner { format: TextFormatInner {
@@ -321,9 +339,13 @@ pub(crate) async fn handle_responses(
response_schema, response_schema,
google_search: has_web_search, google_search: has_web_search,
}; };
if gp.temperature.is_some() || gp.top_p.is_some() || gp.max_output_tokens.is_some() if gp.temperature.is_some()
|| gp.reasoning_effort.is_some() || gp.response_mime_type.is_some() || gp.top_p.is_some()
|| gp.response_schema.is_some() || gp.google_search || gp.max_output_tokens.is_some()
|| gp.reasoning_effort.is_some()
|| gp.response_mime_type.is_some()
|| gp.response_schema.is_some()
|| gp.google_search
{ {
state.mitm_store.set_generation_params(gp).await; state.mitm_store.set_generation_params(gp).await;
} else { } else {
@@ -331,10 +353,7 @@ pub(crate) async fn handle_responses(
} }
} }
let response_id = format!( let response_id = format!("resp_{}", uuid::Uuid::new_v4().to_string().replace('-', ""));
"resp_{}",
uuid::Uuid::new_v4().to_string().replace('-', "")
);
// Session/conversation management // Session/conversation management
let session_id_str = extract_conversation_id(&body.conversation); let session_id_str = extract_conversation_id(&body.conversation);
@@ -371,12 +390,13 @@ pub(crate) async fn handle_responses(
// Store image for MITM injection (LS doesn't forward images to Google API) // Store image for MITM injection (LS doesn't forward images to Google API)
if let Some(ref img) = image { if let Some(ref img) = image {
use base64::Engine; use base64::Engine;
state.mitm_store.set_pending_image( state
crate::mitm::store::PendingImage { .mitm_store
.set_pending_image(crate::mitm::store::PendingImage {
base64_data: base64::engine::general_purpose::STANDARD.encode(&img.data), base64_data: base64::engine::general_purpose::STANDARD.encode(&img.data),
mime_type: img.mime_type.clone(), mime_type: img.mime_type.clone(),
} })
).await; .await;
} }
match state match state
.backend .backend
@@ -419,21 +439,32 @@ pub(crate) async fn handle_responses(
metadata: body.metadata.clone().unwrap_or(serde_json::json!({})), metadata: body.metadata.clone().unwrap_or(serde_json::json!({})),
max_tool_calls: body.max_tool_calls, max_tool_calls: body.max_tool_calls,
reasoning_effort: body.reasoning_effort.clone(), reasoning_effort: body.reasoning_effort.clone(),
tool_choice: body.tool_choice.clone().unwrap_or(serde_json::json!("auto")), tool_choice: body
.tool_choice
.clone()
.unwrap_or(serde_json::json!("auto")),
tools: body.tools.clone().unwrap_or_default(), tools: body.tools.clone().unwrap_or_default(),
text_format, text_format,
}; };
if body.stream { if body.stream {
handle_responses_stream( handle_responses_stream(
state, response_id, model_name.to_string(), cascade_id, state,
body.timeout, req_params, response_id,
model_name.to_string(),
cascade_id,
body.timeout,
req_params,
) )
.await .await
} else { } else {
handle_responses_sync( handle_responses_sync(
state, response_id, model_name.to_string(), cascade_id, state,
body.timeout, req_params, response_id,
model_name.to_string(),
cascade_id,
body.timeout,
req_params,
) )
.await .await
} }
@@ -485,7 +516,9 @@ async fn usage_from_poll(
if let Some(u) = mitm_store.peek_usage(key).await { if let Some(u) = mitm_store.peek_usage(key).await {
if u.thinking_output_tokens > 0 && u.thinking_text.is_none() { if u.thinking_output_tokens > 0 && u.thinking_text.is_none() {
// Call 2 hasn't arrived yet — wait briefly for the merge // Call 2 hasn't arrived yet — wait briefly for the merge
tracing::debug!("MITM: thinking tokens found but no text, waiting for summary merge..."); tracing::debug!(
"MITM: thinking tokens found but no text, waiting for summary merge..."
);
for _ in 0..10 { for _ in 0..10 {
tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
if let Some(u2) = mitm_store.peek_usage(key).await { if let Some(u2) = mitm_store.peek_usage(key).await {
@@ -526,13 +559,18 @@ async fn usage_from_poll(
// Priority 2: LS trajectory data (from CHECKPOINT/metadata steps) // Priority 2: LS trajectory data (from CHECKPOINT/metadata steps)
if let Some(u) = model_usage { if let Some(u) = model_usage {
return (Usage { return (
input_tokens: u.input_tokens, Usage {
input_tokens_details: InputTokensDetails { cached_tokens: 0 }, input_tokens: u.input_tokens,
output_tokens: u.output_tokens, input_tokens_details: InputTokensDetails { cached_tokens: 0 },
output_tokens_details: OutputTokensDetails { reasoning_tokens: 0 }, output_tokens: u.output_tokens,
total_tokens: u.input_tokens + u.output_tokens, output_tokens_details: OutputTokensDetails {
}, None); reasoning_tokens: 0,
},
total_tokens: u.input_tokens + u.output_tokens,
},
None,
);
} }
// Priority 3: Estimate from text lengths // Priority 3: Estimate from text lengths
@@ -575,14 +613,22 @@ async fn handle_responses_sync(
"call_{}", "call_{}",
uuid::Uuid::new_v4().to_string().replace('-', "")[..24].to_string() uuid::Uuid::new_v4().to_string().replace('-', "")[..24].to_string()
); );
state.mitm_store.register_call_id(call_id.clone(), fc.name.clone()).await; state
.mitm_store
.register_call_id(call_id.clone(), fc.name.clone())
.await;
let arguments = serde_json::to_string(&fc.args).unwrap_or_default(); let arguments = serde_json::to_string(&fc.args).unwrap_or_default();
output_items.push(build_function_call_output(&call_id, &fc.name, &arguments)); output_items
.push(build_function_call_output(&call_id, &fc.name, &arguments));
} }
let (usage, _) = usage_from_poll( let (usage, _) = usage_from_poll(
&state.mitm_store, &cascade_id, &None, &state.mitm_store,
&params.user_text, "", &cascade_id,
).await; &None,
&params.user_text,
"",
)
.await;
let resp = build_response_object( let resp = build_response_object(
ResponseData { ResponseData {
id: response_id, id: response_id,
@@ -602,12 +648,20 @@ async fn handle_responses_sync(
// Check for completed text response // Check for completed text response
if state.mitm_store.is_response_complete() { if state.mitm_store.is_response_complete() {
let text = state.mitm_store.take_response_text().await.unwrap_or_default(); let text = state
.mitm_store
.take_response_text()
.await
.unwrap_or_default();
let thinking = state.mitm_store.take_thinking_text().await; let thinking = state.mitm_store.take_thinking_text().await;
let (usage, _) = usage_from_poll( let (usage, _) = usage_from_poll(
&state.mitm_store, &cascade_id, &None, &state.mitm_store,
&params.user_text, &text, &cascade_id,
).await; &None,
&params.user_text,
&text,
)
.await;
let mut output_items: Vec<serde_json::Value> = Vec::new(); let mut output_items: Vec<serde_json::Value> = Vec::new();
if let Some(ref t) = thinking { if let Some(ref t) = thinking {
@@ -658,10 +712,7 @@ async fn handle_responses_sync(
return upstream_err_response(err); return upstream_err_response(err);
} }
let completed_at = now_unix(); let completed_at = now_unix();
let msg_id = format!( let msg_id = format!("msg_{}", uuid::Uuid::new_v4().to_string().replace('-', ""));
"msg_{}",
uuid::Uuid::new_v4().to_string().replace('-', "")
);
// Check for captured function calls from MITM (clears the active flag) // Check for captured function calls from MITM (clears the active flag)
let captured_tool_calls = state.mitm_store.take_any_function_calls().await; let captured_tool_calls = state.mitm_store.take_any_function_calls().await;
@@ -689,7 +740,10 @@ async fn handle_responses_sync(
uuid::Uuid::new_v4().to_string().replace('-', "")[..24].to_string() uuid::Uuid::new_v4().to_string().replace('-', "")[..24].to_string()
); );
// Register call_id → name mapping for tool result routing // Register call_id → name mapping for tool result routing
state.mitm_store.register_call_id(call_id.clone(), fc.name.clone()).await; state
.mitm_store
.register_call_id(call_id.clone(), fc.name.clone())
.await;
// Stringify args (OpenAI sends arguments as JSON string) // Stringify args (OpenAI sends arguments as JSON string)
let arguments = serde_json::to_string(&fc.args).unwrap_or_default(); let arguments = serde_json::to_string(&fc.args).unwrap_or_default();
@@ -697,9 +751,13 @@ async fn handle_responses_sync(
} }
let (usage, _) = usage_from_poll( let (usage, _) = usage_from_poll(
&state.mitm_store, &cascade_id, &poll_result.usage, &state.mitm_store,
&params.user_text, &poll_result.text, &cascade_id,
).await; &poll_result.usage,
&params.user_text,
&poll_result.text,
)
.await;
let resp = build_response_object( let resp = build_response_object(
ResponseData { ResponseData {
@@ -719,7 +777,14 @@ async fn handle_responses_sync(
} }
// Normal text response (no tool calls) // Normal text response (no tool calls)
let (usage, mitm_thinking) = usage_from_poll(&state.mitm_store, &cascade_id, &poll_result.usage, &params.user_text, &poll_result.text).await; let (usage, mitm_thinking) = usage_from_poll(
&state.mitm_store,
&cascade_id,
&poll_result.usage,
&params.user_text,
&poll_result.text,
)
.await;
// Thinking text priority: MITM-captured (raw API) > LS-extracted (steps) // Thinking text priority: MITM-captured (raw API) > LS-extracted (steps)
let thinking_text = mitm_thinking.or(poll_result.thinking); let thinking_text = mitm_thinking.or(poll_result.thinking);
@@ -1560,4 +1625,3 @@ fn completion_events(
events events
} }

View File

@@ -126,7 +126,9 @@ pub(crate) struct CompletionRequest {
pub web_search: bool, pub web_search: bool,
} }
fn default_n() -> u32 { 1 } fn default_n() -> u32 {
1
}
/// Stop sequence can be a single string or array of strings (OpenAI accepts both). /// Stop sequence can be a single string or array of strings (OpenAI accepts both).
#[derive(Deserialize, Clone)] #[derive(Deserialize, Clone)]
@@ -254,8 +256,7 @@ pub(crate) struct OutputTokensDetails {
pub reasoning_tokens: u64, pub reasoning_tokens: u64,
} }
#[derive(Serialize, Clone)] #[derive(Serialize, Clone, Default)]
#[derive(Default)]
pub(crate) struct Reasoning { pub(crate) struct Reasoning {
pub effort: Option<String>, pub effort: Option<String>,
pub summary: Option<String>, pub summary: Option<String>,
@@ -313,7 +314,6 @@ impl Default for Usage {
} }
} }
impl Default for TextFormat { impl Default for TextFormat {
fn default() -> Self { fn default() -> Self {
Self { Self {

View File

@@ -27,7 +27,9 @@ pub(crate) fn err_response(
/// Convert a MITM-captured upstream error from Google into an HTTP response. /// Convert a MITM-captured upstream error from Google into an HTTP response.
/// Maps Google's HTTP status codes and preserves the error message. /// Maps Google's HTTP status codes and preserves the error message.
pub(crate) fn upstream_err_response(err: &crate::mitm::store::UpstreamError) -> axum::response::Response { pub(crate) fn upstream_err_response(
err: &crate::mitm::store::UpstreamError,
) -> axum::response::Response {
// Map Google's status code to HTTP status // Map Google's status code to HTTP status
let status = StatusCode::from_u16(err.status).unwrap_or(StatusCode::BAD_GATEWAY); let status = StatusCode::from_u16(err.status).unwrap_or(StatusCode::BAD_GATEWAY);
@@ -41,7 +43,9 @@ pub(crate) fn upstream_err_response(err: &crate::mitm::store::UpstreamError) ->
_ => "upstream_error", _ => "upstream_error",
}; };
let message = err.message.clone() let message = err
.message
.clone()
.unwrap_or_else(|| format!("Google API returned HTTP {}", err.status)); .unwrap_or_else(|| format!("Google API returned HTTP {}", err.status));
err_response(status, message, error_type) err_response(status, message, error_type)
@@ -99,7 +103,8 @@ pub(crate) fn extract_image_from_content(item: &serde_json::Value) -> Option<Ima
} }
// OpenAI Responses API format // OpenAI Responses API format
"input_image" => { "input_image" => {
let url = item["image_url"].as_str() let url = item["image_url"]
.as_str()
.or_else(|| item["url"].as_str())?; .or_else(|| item["url"].as_str())?;
parse_data_uri(url) parse_data_uri(url)
} }
@@ -109,5 +114,8 @@ pub(crate) fn extract_image_from_content(item: &serde_json::Value) -> Option<Ima
/// Extract the first image from a content array (Value::Array of content parts). /// Extract the first image from a content array (Value::Array of content parts).
pub(crate) fn extract_first_image(content: &serde_json::Value) -> Option<ImageData> { pub(crate) fn extract_first_image(content: &serde_json::Value) -> Option<ImageData> {
content.as_array()?.iter().find_map(extract_image_from_content) content
.as_array()?
.iter()
.find_map(extract_image_from_content)
} }

View File

@@ -48,10 +48,7 @@ static STATIC_HEADERS: LazyLock<HeaderMap> = LazyLock::new(|| {
*CHROME_MAJOR, *CHROME_MAJOR,
)), )),
); );
h.insert( h.insert(HeaderName::from_static("sec-ch-ua-mobile"), hv("?0"));
HeaderName::from_static("sec-ch-ua-mobile"),
hv("?0"),
);
h.insert( h.insert(
HeaderName::from_static("sec-ch-ua-platform"), HeaderName::from_static("sec-ch-ua-platform"),
hv("\"Linux\""), hv("\"Linux\""),
@@ -72,7 +69,7 @@ impl Backend {
// wreq with Chrome impersonation: BoringSSL + Chrome JA3/JA4 + H2 fingerprint // wreq with Chrome impersonation: BoringSSL + Chrome JA3/JA4 + H2 fingerprint
let client = wreq::Client::builder() let client = wreq::Client::builder()
.emulation(wreq_util::Emulation::Chrome142) .emulation(wreq_util::Emulation::Chrome142)
.cert_verification(false) // LS uses self-signed cert .cert_verification(false) // LS uses self-signed cert
.verify_hostname(false) .verify_hostname(false)
.build() .build()
.map_err(|e| format!("wreq client build failed: {e}"))?; .map_err(|e| format!("wreq client build failed: {e}"))?;
@@ -86,11 +83,7 @@ impl Backend {
/// Create a Backend with known connection details (for standalone LS). /// Create a Backend with known connection details (for standalone LS).
/// ///
/// Skips auto-discovery — the caller provides the port, CSRF, and OAuth token. /// Skips auto-discovery — the caller provides the port, CSRF, and OAuth token.
pub fn new_with_config( pub fn new_with_config(port: u16, csrf: String, oauth_token: String) -> Result<Self, String> {
port: u16,
csrf: String,
oauth_token: String,
) -> Result<Self, String> {
let inner = BackendInner { let inner = BackendInner {
pid: "standalone".to_string(), pid: "standalone".to_string(),
csrf, csrf,
@@ -212,10 +205,7 @@ impl Backend {
fn common_headers(csrf: &str) -> HeaderMap { fn common_headers(csrf: &str) -> HeaderMap {
let mut h = STATIC_HEADERS.clone(); let mut h = STATIC_HEADERS.clone();
if let Ok(val) = HeaderValue::from_str(csrf) { if let Ok(val) = HeaderValue::from_str(csrf) {
h.insert( h.insert(HeaderName::from_static("x-codeium-csrf-token"), val);
HeaderName::from_static("x-codeium-csrf-token"),
val,
);
} else { } else {
warn!("CSRF token contains invalid header characters, omitting"); warn!("CSRF token contains invalid header characters, omitting");
} }
@@ -239,8 +229,8 @@ impl Backend {
let mut headers = Self::common_headers(&csrf); let mut headers = Self::common_headers(&csrf);
headers.insert("Content-Type", HeaderValue::from_static("application/json")); headers.insert("Content-Type", HeaderValue::from_static("application/json"));
let body_bytes = serde_json::to_vec(body) let body_bytes =
.map_err(|e| format!("JSON serialize error: {e}"))?; serde_json::to_vec(body).map_err(|e| format!("JSON serialize error: {e}"))?;
let resp = self let resp = self
.client .client
@@ -258,7 +248,9 @@ impl Backend {
.and_then(|v| v.to_str().ok()) .and_then(|v| v.to_str().ok())
.unwrap_or("") .unwrap_or("")
.to_string(); .to_string();
let raw = resp.bytes().await let raw = resp
.bytes()
.await
.map_err(|e| format!("Read body error: {e}"))?; .map_err(|e| format!("Read body error: {e}"))?;
let resp_bytes = decompress(method, &raw, &encoding); let resp_bytes = decompress(method, &raw, &encoding);
// High-frequency polling methods → trace; everything else → debug // High-frequency polling methods → trace; everything else → debug
@@ -288,11 +280,7 @@ impl Backend {
} }
/// Call a binary protobuf RPC method. /// Call a binary protobuf RPC method.
pub async fn call_proto( pub async fn call_proto(&self, method: &str, body: Vec<u8>) -> Result<(u16, Vec<u8>), String> {
&self,
method: &str,
body: Vec<u8>,
) -> Result<(u16, Vec<u8>), String> {
let (base, csrf) = { let (base, csrf) = {
let guard = self.inner.read().await; let guard = self.inner.read().await;
( (
@@ -302,7 +290,10 @@ impl Backend {
}; };
let url = format!("{base}/{LS_SERVICE}/{method}"); let url = format!("{base}/{LS_SERVICE}/{method}");
let mut headers = Self::common_headers(&csrf); let mut headers = Self::common_headers(&csrf);
headers.insert("Content-Type", HeaderValue::from_static("application/proto")); headers.insert(
"Content-Type",
HeaderValue::from_static("application/proto"),
);
let resp = self let resp = self
.client .client
@@ -350,7 +341,8 @@ impl Backend {
text: &str, text: &str,
model_enum: u32, model_enum: u32,
) -> Result<(u16, Vec<u8>), String> { ) -> Result<(u16, Vec<u8>), String> {
self.send_message_with_image(cascade_id, text, model_enum, None).await self.send_message_with_image(cascade_id, text, model_enum, None)
.await
} }
/// SendUserCascadeMessage with optional image attachment. /// SendUserCascadeMessage with optional image attachment.
@@ -365,7 +357,8 @@ impl Backend {
if token.is_empty() { if token.is_empty() {
return Err("No OAuth token available".to_string()); return Err("No OAuth token available".to_string());
} }
let proto = crate::proto::build_request_with_image(cascade_id, text, &token, model_enum, image); let proto =
crate::proto::build_request_with_image(cascade_id, text, &token, model_enum, image);
if image.is_some() { if image.is_some() {
tracing::info!( tracing::info!(
proto_size = proto.len(), proto_size = proto.len(),
@@ -376,10 +369,7 @@ impl Backend {
} }
/// GetCascadeTrajectorySteps → JSON with steps array. /// GetCascadeTrajectorySteps → JSON with steps array.
pub async fn get_steps( pub async fn get_steps(&self, cascade_id: &str) -> Result<(u16, serde_json::Value), String> {
&self,
cascade_id: &str,
) -> Result<(u16, serde_json::Value), String> {
let body = serde_json::json!({"cascadeId": cascade_id}); let body = serde_json::json!({"cascadeId": cascade_id});
self.call_json("GetCascadeTrajectorySteps", &body).await self.call_json("GetCascadeTrajectorySteps", &body).await
} }
@@ -415,7 +405,10 @@ impl Backend {
}); });
let mut headers = Self::common_headers(&csrf); let mut headers = Self::common_headers(&csrf);
headers.insert("Content-Type", HeaderValue::from_static("application/connect+json")); headers.insert(
"Content-Type",
HeaderValue::from_static("application/connect+json"),
);
headers.insert("Connect-Protocol-Version", HeaderValue::from_static("1")); headers.insert("Connect-Protocol-Version", HeaderValue::from_static("1"));
// Connect protocol envelope: [flags:1][length:4][payload] // Connect protocol envelope: [flags:1][length:4][payload]
@@ -441,7 +434,8 @@ impl Backend {
return Err(format!("{rpc_method} failed: {status}{err_text}")); return Err(format!("{rpc_method} failed: {status}{err_text}"));
} }
let resp_ct = resp.headers() let resp_ct = resp
.headers()
.get("content-type") .get("content-type")
.and_then(|v| v.to_str().ok()) .and_then(|v| v.to_str().ok())
.unwrap_or("unknown") .unwrap_or("unknown")
@@ -495,7 +489,8 @@ impl Backend {
&self, &self,
cascade_id: &str, cascade_id: &str,
) -> Result<tokio::sync::mpsc::Receiver<serde_json::Value>, String> { ) -> Result<tokio::sync::mpsc::Receiver<serde_json::Value>, String> {
self.stream_reactive_rpc("StreamCascadeReactiveUpdates", cascade_id).await self.stream_reactive_rpc("StreamCascadeReactiveUpdates", cascade_id)
.await
} }
} }
@@ -506,7 +501,10 @@ fn discover() -> Result<BackendInner, String> {
// the wrapper is a shell script named language_server_linux_x64, while // the wrapper is a shell script named language_server_linux_x64, while
// the real binary is language_server_linux_x64.real) // the real binary is language_server_linux_x64.real)
let pid_output = Command::new("sh") let pid_output = Command::new("sh")
.args(["-c", "pgrep -f 'language_server_linux_x64\\.real' | head -1"]) .args([
"-c",
"pgrep -f 'language_server_linux_x64\\.real' | head -1",
])
.output() .output()
.map_err(|e| format!("pgrep failed: {e}"))?; .map_err(|e| format!("pgrep failed: {e}"))?;
@@ -564,9 +562,8 @@ fn discover() -> Result<BackendInner, String> {
LazyLock::new(|| regex::Regex::new(r"port at (\d+) for HTTPS").unwrap()); LazyLock::new(|| regex::Regex::new(r"port at (\d+) for HTTPS").unwrap());
for d in &dirs { for d in &dirs {
let log_path = format!( let log_path =
"{log_base}/{d}/window1/exthost/google.antigravity/Antigravity.log" format!("{log_base}/{d}/window1/exthost/google.antigravity/Antigravity.log");
);
if let Ok(contents) = fs::read_to_string(&log_path) { if let Ok(contents) = fs::read_to_string(&log_path) {
for line in contents.lines() { for line in contents.lines() {
if line.contains(&pid) && line.contains("listening") && line.contains("HTTPS") { if line.contains(&pid) && line.contains("listening") && line.contains("HTTPS") {
@@ -584,10 +581,7 @@ fn discover() -> Result<BackendInner, String> {
if https_port.is_empty() { if https_port.is_empty() {
// Fallback: find the LS HTTPS port via `ss` (when log file hasn't caught up) // Fallback: find the LS HTTPS port via `ss` (when log file hasn't caught up)
if let Ok(output) = std::process::Command::new("ss") if let Ok(output) = std::process::Command::new("ss").args(["-tlnp"]).output() {
.args(["-tlnp"])
.output()
{
let ss_out = String::from_utf8_lossy(&output.stdout); let ss_out = String::from_utf8_lossy(&output.stdout);
// Find listening ports for this PID — typically the first is HTTPS // Find listening ports for this PID — typically the first is HTTPS
for line in ss_out.lines() { for line in ss_out.lines() {
@@ -653,7 +647,11 @@ fn decompress(method: &str, data: &[u8], encoding: &str) -> Vec<u8> {
Err(e) => { Err(e) => {
if !encoding.is_empty() { if !encoding.is_empty() {
let preview = String::from_utf8_lossy(&data[..data.len().min(100)]); let preview = String::from_utf8_lossy(&data[..data.len().min(100)]);
warn!("{method}: {encoding} decompress failed ({} bytes): {e}. Raw: {}", data.len(), preview); warn!(
"{method}: {encoding} decompress failed ({} bytes): {e}. Raw: {}",
data.len(),
preview
);
} }
data.to_vec() data.to_vec()
} }

View File

@@ -115,9 +115,7 @@ fn detect_versions() -> DetectedVersions {
const FALLBACK_CLIENT: &str = "1.16.5"; const FALLBACK_CLIENT: &str = "1.16.5";
let Some(install_dir) = find_install_dir() else { let Some(install_dir) = find_install_dir() else {
tracing::warn!( tracing::warn!("Could not find Antigravity install — using fallback versions");
"Could not find Antigravity install — using fallback versions"
);
return DetectedVersions { return DetectedVersions {
antigravity: FALLBACK_ANTIGRAVITY.to_string(), antigravity: FALLBACK_ANTIGRAVITY.to_string(),
chrome: FALLBACK_CHROME.to_string(), chrome: FALLBACK_CHROME.to_string(),

View File

@@ -24,7 +24,10 @@ use tracing::{info, warn};
use mitm::store::MitmStore; use mitm::store::MitmStore;
#[derive(Parser)] #[derive(Parser)]
#[command(name = "antigravity-proxy", about = "Antigravity OpenAI Proxy (stealth)")] #[command(
name = "antigravity-proxy",
about = "Antigravity OpenAI Proxy (stealth)"
)]
struct Cli { struct Cli {
/// Port to listen on /// Port to listen on
#[arg(long, default_value_t = 8741)] #[arg(long, default_value_t = 8741)]
@@ -93,15 +96,12 @@ async fn main() {
}; };
let filter = if log_level.is_empty() { let filter = if log_level.is_empty() {
tracing_subscriber::EnvFilter::try_from_default_env() tracing_subscriber::EnvFilter::try_from_default_env().unwrap_or_else(|_| "warn".into())
.unwrap_or_else(|_| "warn".into())
} else { } else {
tracing_subscriber::EnvFilter::new(log_level) tracing_subscriber::EnvFilter::new(log_level)
}; };
tracing_subscriber::fmt() tracing_subscriber::fmt().with_env_filter(filter).init();
.with_env_filter(filter)
.init();
// ── Step 1: Bind main port (auto-kill stale process if needed) ───────────── // ── Step 1: Bind main port (auto-kill stale process if needed) ─────────────
let addr = format!("127.0.0.1:{}", cli.port); let addr = format!("127.0.0.1:{}", cli.port);
@@ -111,7 +111,10 @@ async fn main() {
// Port in use — try to kill whatever's holding it // Port in use — try to kill whatever's holding it
eprintln!(" Port {} in use, killing stale process...", cli.port); eprintln!(" Port {} in use, killing stale process...", cli.port);
let _ = std::process::Command::new("sh") let _ = std::process::Command::new("sh")
.args(["-c", &format!("kill $(lsof -ti:{}) 2>/dev/null; sleep 0.3", cli.port)]) .args([
"-c",
&format!("kill $(lsof -ti:{}) 2>/dev/null; sleep 0.3", cli.port),
])
.status(); .status();
// Also kill any leftover standalone LS processes // Also kill any leftover standalone LS processes
let _ = std::process::Command::new("pkill") let _ = std::process::Command::new("pkill")
@@ -180,7 +183,9 @@ async fn main() {
Ok(c) => c, Ok(c) => c,
Err(e) => { Err(e) => {
eprintln!("Fatal: {e}"); eprintln!("Fatal: {e}");
eprintln!("Hint: start Antigravity first, or remove --classic to use headless mode"); eprintln!(
"Hint: start Antigravity first, or remove --classic to use headless mode"
);
std::process::exit(1); std::process::exit(1);
} }
} }
@@ -199,13 +204,14 @@ async fn main() {
None None
}; };
let mut ls = match standalone::StandaloneLS::spawn(&main_config, mitm_cfg.as_ref(), headless) { let mut ls =
Ok(ls) => ls, match standalone::StandaloneLS::spawn(&main_config, mitm_cfg.as_ref(), headless) {
Err(e) => { Ok(ls) => ls,
eprintln!("Fatal: failed to spawn standalone LS: {e}"); Err(e) => {
std::process::exit(1); eprintln!("Fatal: failed to spawn standalone LS: {e}");
} std::process::exit(1);
}; }
};
// Wait for it to be ready // Wait for it to be ready
let rt_ls_port = ls.port; let rt_ls_port = ls.port;
let rt_ls_csrf = ls.csrf.clone(); let rt_ls_csrf = ls.csrf.clone();
@@ -294,7 +300,15 @@ async fn main() {
// ── Step 5: Start serving ───────────────────────────────────────────────── // ── Step 5: Start serving ─────────────────────────────────────────────────
let app = api::router(state.clone()); let app = api::router(state.clone());
print_banner(cli.port, &pid, &https_port, &csrf, &token, &mitm_port_actual, is_standalone); print_banner(
cli.port,
&pid,
&https_port,
&csrf,
&token,
&mitm_port_actual,
is_standalone,
);
info!("Listening on http://{addr}"); info!("Listening on http://{addr}");
axum::serve(listener, app) axum::serve(listener, app)
@@ -349,7 +363,15 @@ async fn shutdown_signal() {
} }
} }
fn print_banner(port: u16, pid: &str, https_port: &str, csrf: &str, token: &str, mitm: &Option<(u16, String)>, is_standalone: bool) { fn print_banner(
port: u16,
pid: &str,
https_port: &str,
csrf: &str,
token: &str,
mitm: &Option<(u16, String)>,
is_standalone: bool,
) {
let chrome_major = &*constants::CHROME_MAJOR; let chrome_major = &*constants::CHROME_MAJOR;
let ver = crate::constants::antigravity_version(); let ver = crate::constants::antigravity_version();
@@ -401,7 +423,11 @@ fn print_banner(port: u16, pid: &str, https_port: &str, csrf: &str, token: &str,
println!(); println!();
// Status line // Status line
let mitm_tag = if mitm.is_some() { "\x1b[32mmitm\x1b[0m" } else { "\x1b[31mmitm\x1b[0m" }; let mitm_tag = if mitm.is_some() {
"\x1b[32mmitm\x1b[0m"
} else {
"\x1b[31mmitm\x1b[0m"
};
println!(" \x1b[2mstealth:\x1b[0m \x1b[32mwarmup\x1b[0m \x1b[32mheartbeat\x1b[0m \x1b[32mjitter\x1b[0m {mitm_tag}"); println!(" \x1b[2mstealth:\x1b[0m \x1b[32mwarmup\x1b[0m \x1b[32mheartbeat\x1b[0m \x1b[32mjitter\x1b[0m {mitm_tag}");
println!(); println!();
@@ -421,7 +447,9 @@ fn print_banner(port: u16, pid: &str, https_port: &str, csrf: &str, token: &str,
if token == "NOT SET" { if token == "NOT SET" {
println!(" \x1b[1;33m[!]\x1b[0m no oauth token"); println!(" \x1b[1;33m[!]\x1b[0m no oauth token");
println!(" export ANTIGRAVITY_OAUTH_TOKEN=ya29.xxx"); println!(" export ANTIGRAVITY_OAUTH_TOKEN=ya29.xxx");
println!(" curl -X POST http://127.0.0.1:{port}/v1/token -d '{{\"token\":\"ya29.xxx\"}}'"); println!(
" curl -X POST http://127.0.0.1:{port}/v1/token -d '{{\"token\":\"ya29.xxx\"}}'"
);
println!(" echo 'ya29.xxx' > ~/.config/antigravity-proxy-token"); println!(" echo 'ya29.xxx' > ~/.config/antigravity-proxy-token");
println!(); println!();
} }
@@ -476,5 +504,7 @@ fn find_ls_binary_path() -> Option<String> {
/// Get the data directory for storing MITM CA cert/key. /// Get the data directory for storing MITM CA cert/key.
fn dirs_data_dir() -> std::path::PathBuf { fn dirs_data_dir() -> std::path::PathBuf {
let home = std::env::var("HOME").unwrap_or_else(|_| "/tmp".to_string()); let home = std::env::var("HOME").unwrap_or_else(|_| "/tmp".to_string());
std::path::PathBuf::from(home).join(".config").join("antigravity-proxy") std::path::PathBuf::from(home)
.join(".config")
.join("antigravity-proxy")
} }

View File

@@ -4,8 +4,8 @@
//! Dynamically generates per-domain leaf certificates signed by this CA. //! Dynamically generates per-domain leaf certificates signed by this CA.
use rcgen::{ use rcgen::{
BasicConstraints, CertificateParams, DistinguishedName, DnType, ExtendedKeyUsagePurpose, BasicConstraints, CertificateParams, DistinguishedName, DnType, ExtendedKeyUsagePurpose, IsCa,
IsCa, KeyPair, KeyUsagePurpose, SanType, KeyPair, KeyUsagePurpose, SanType,
}; };
use rustls::pki_types::{CertificateDer, PrivateKeyDer, PrivatePkcs8KeyDer}; use rustls::pki_types::{CertificateDer, PrivateKeyDer, PrivatePkcs8KeyDer};
use std::collections::HashMap; use std::collections::HashMap;
@@ -45,15 +45,16 @@ impl MitmCa {
let key_pem = std::fs::read_to_string(&key_path) let key_pem = std::fs::read_to_string(&key_path)
.map_err(|e| format!("Failed to read CA key: {e}"))?; .map_err(|e| format!("Failed to read CA key: {e}"))?;
let ca_key = KeyPair::from_pem(&key_pem) let ca_key =
.map_err(|e| format!("Failed to parse CA key: {e}"))?; KeyPair::from_pem(&key_pem).map_err(|e| format!("Failed to parse CA key: {e}"))?;
// Re-create params and self-sign to get the rcgen Certificate object // Re-create params and self-sign to get the rcgen Certificate object
// (needed for signing leaf certs — rcgen 0.13 doesn't have from_ca_cert_pem). // (needed for signing leaf certs — rcgen 0.13 doesn't have from_ca_cert_pem).
// The re-signed cert will have a different serial/notBefore, but that's fine // The re-signed cert will have a different serial/notBefore, but that's fine
// because we only use it for the rcgen signing API, NOT for the on-disk PEM. // because we only use it for the rcgen signing API, NOT for the on-disk PEM.
let params = Self::ca_params(); let params = Self::ca_params();
let ca_signed = params.self_signed(&ca_key) let ca_signed = params
.self_signed(&ca_key)
.map_err(|e| format!("Failed to self-sign CA: {e}"))?; .map_err(|e| format!("Failed to self-sign CA: {e}"))?;
// Use the ORIGINAL on-disk PEM cert for DER — this is what the LS trusts // Use the ORIGINAL on-disk PEM cert for DER — this is what the LS trusts
@@ -76,11 +77,12 @@ impl MitmCa {
std::fs::create_dir_all(data_dir) std::fs::create_dir_all(data_dir)
.map_err(|e| format!("Failed to create data dir: {e}"))?; .map_err(|e| format!("Failed to create data dir: {e}"))?;
let ca_key = KeyPair::generate() let ca_key =
.map_err(|e| format!("Failed to generate CA key: {e}"))?; KeyPair::generate().map_err(|e| format!("Failed to generate CA key: {e}"))?;
let params = Self::ca_params(); let params = Self::ca_params();
let ca_signed = params.self_signed(&ca_key) let ca_signed = params
.self_signed(&ca_key)
.map_err(|e| format!("Failed to self-sign CA: {e}"))?; .map_err(|e| format!("Failed to self-sign CA: {e}"))?;
// Write cert and key to disk // Write cert and key to disk
@@ -117,10 +119,7 @@ impl MitmCa {
params.distinguished_name = dn; params.distinguished_name = dn;
params.is_ca = IsCa::Ca(BasicConstraints::Unconstrained); params.is_ca = IsCa::Ca(BasicConstraints::Unconstrained);
params.key_usages = vec![ params.key_usages = vec![KeyUsagePurpose::KeyCertSign, KeyUsagePurpose::CrlSign];
KeyUsagePurpose::KeyCertSign,
KeyUsagePurpose::CrlSign,
];
// Valid for 10 years // Valid for 10 years
let now = time::OffsetDateTime::now_utc(); let now = time::OffsetDateTime::now_utc();
@@ -151,12 +150,17 @@ impl MitmCa {
return None; return None;
} }
use base64::Engine; use base64::Engine;
let der = base64::engine::general_purpose::STANDARD.decode(&b64).ok()?; let der = base64::engine::general_purpose::STANDARD
.decode(&b64)
.ok()?;
Some(CertificateDer::from(der)) Some(CertificateDer::from(der))
} }
/// Get or create a TLS ServerConfig for the given domain. /// Get or create a TLS ServerConfig for the given domain.
pub async fn server_config_for_domain(&self, domain: &str) -> Result<Arc<rustls::ServerConfig>, String> { pub async fn server_config_for_domain(
&self,
domain: &str,
) -> Result<Arc<rustls::ServerConfig>, String> {
// Check cache first // Check cache first
{ {
let cache = self.domain_cache.read().await; let cache = self.domain_cache.read().await;
@@ -172,7 +176,11 @@ impl MitmCa {
dn.push(DnType::CommonName, domain); dn.push(DnType::CommonName, domain);
params.distinguished_name = dn; params.distinguished_name = dn;
params.subject_alt_names = vec![SanType::DnsName(domain.try_into().map_err(|e| format!("Invalid domain: {e}"))?)]; params.subject_alt_names = vec![SanType::DnsName(
domain
.try_into()
.map_err(|e| format!("Invalid domain: {e}"))?,
)];
params.extended_key_usages = vec![ExtendedKeyUsagePurpose::ServerAuth]; params.extended_key_usages = vec![ExtendedKeyUsagePurpose::ServerAuth];
params.key_usages = vec![ params.key_usages = vec![
KeyUsagePurpose::DigitalSignature, KeyUsagePurpose::DigitalSignature,
@@ -184,10 +192,11 @@ impl MitmCa {
params.not_before = now; params.not_before = now;
params.not_after = now + time::Duration::days(365); params.not_after = now + time::Duration::days(365);
let leaf_key = KeyPair::generate() let leaf_key =
.map_err(|e| format!("Failed to generate leaf key: {e}"))?; KeyPair::generate().map_err(|e| format!("Failed to generate leaf key: {e}"))?;
let leaf_cert = params.signed_by(&leaf_key, &self.ca_signed, &self.ca_key) let leaf_cert = params
.signed_by(&leaf_key, &self.ca_signed, &self.ca_key)
.map_err(|e| format!("Failed to sign leaf cert for {domain}: {e}"))?; .map_err(|e| format!("Failed to sign leaf cert for {domain}: {e}"))?;
// Build rustls ServerConfig // Build rustls ServerConfig
@@ -196,10 +205,7 @@ impl MitmCa {
let mut config = rustls::ServerConfig::builder() let mut config = rustls::ServerConfig::builder()
.with_no_client_auth() .with_no_client_auth()
.with_single_cert( .with_single_cert(vec![leaf_cert_der, self.ca_cert_der.clone()], leaf_key_der)
vec![leaf_cert_der, self.ca_cert_der.clone()],
leaf_key_der,
)
.map_err(|e| format!("Failed to build ServerConfig for {domain}: {e}"))?; .map_err(|e| format!("Failed to build ServerConfig for {domain}: {e}"))?;
// Advertise both h2 and http/1.1 so gRPC clients can negotiate HTTP/2 // Advertise both h2 and http/1.1 so gRPC clients can negotiate HTTP/2

View File

@@ -92,11 +92,10 @@ impl UpstreamPool {
.map_err(|e| format!("upstream TLS to {} failed: {e}", self.domain))?; .map_err(|e| format!("upstream TLS to {} failed: {e}", self.domain))?;
let upstream_io = TokioIo::new(upstream_tls); let upstream_io = TokioIo::new(upstream_tls);
let (sender, conn) = let (sender, conn) = hyper::client::conn::http2::Builder::new(TokioExecutor::new())
hyper::client::conn::http2::Builder::new(TokioExecutor::new()) .handshake(upstream_io)
.handshake(upstream_io) .await
.await .map_err(|e| format!("upstream h2 handshake to {} failed: {e}", self.domain))?;
.map_err(|e| format!("upstream h2 handshake to {} failed: {e}", self.domain))?;
let domain = self.domain.clone(); let domain = self.domain.clone();
tokio::spawn(async move { tokio::spawn(async move {
@@ -215,12 +214,10 @@ async fn handle_h2_request(
.unwrap_or(false); .unwrap_or(false);
// Check if this method carries usage data // Check if this method carries usage data
let is_usage_method = is_grpc let is_usage_method = is_grpc && USAGE_METHODS.iter().any(|m| path.contains(m));
&& USAGE_METHODS.iter().any(|m| path.contains(m));
// Check if this is a streaming method // Check if this is a streaming method
let is_streaming = is_grpc let is_streaming = is_grpc && (path.contains("Stream") || path.contains("stream"));
&& (path.contains("Stream") || path.contains("stream"));
debug!( debug!(
domain, domain,
@@ -249,9 +246,9 @@ async fn handle_h2_request(
warn!(error = %e, domain, "MITM H2: upstream connect failed"); warn!(error = %e, domain, "MITM H2: upstream connect failed");
let resp = Response::builder() let resp = Response::builder()
.status(502) .status(502)
.body(http_body_util::Either::Left(Full::new( .body(http_body_util::Either::Left(Full::new(Bytes::from(
Bytes::from(format!("upstream connect failed: {e}")), format!("upstream connect failed: {e}"),
))) ))))
.unwrap(); .unwrap();
return Ok(resp); return Ok(resp);
} }
@@ -261,17 +258,11 @@ async fn handle_h2_request(
let upstream_uri = http::Uri::builder() let upstream_uri = http::Uri::builder()
.scheme("https") .scheme("https")
.authority(domain) .authority(domain)
.path_and_query( .path_and_query(uri.path_and_query().map(|pq| pq.as_str()).unwrap_or("/"))
uri.path_and_query()
.map(|pq| pq.as_str())
.unwrap_or("/"),
)
.build() .build()
.unwrap_or(uri); .unwrap_or(uri);
let mut upstream_req = Request::builder() let mut upstream_req = Request::builder().method(parts.method).uri(upstream_uri);
.method(parts.method)
.uri(upstream_uri);
// Copy headers, skip hop-by-hop // Copy headers, skip hop-by-hop
for (name, value) in &parts.headers { for (name, value) in &parts.headers {
@@ -287,9 +278,9 @@ async fn handle_h2_request(
Err(e) => { Err(e) => {
let resp = Response::builder() let resp = Response::builder()
.status(502) .status(502)
.body(http_body_util::Either::Left(Full::new( .body(http_body_util::Either::Left(Full::new(Bytes::from(
Bytes::from(format!("build request failed: {e}")), format!("build request failed: {e}"),
))) ))))
.unwrap(); .unwrap();
return Ok(resp); return Ok(resp);
} }
@@ -302,9 +293,9 @@ async fn handle_h2_request(
warn!(error = %e, domain, path = %path, "MITM H2: upstream request failed"); warn!(error = %e, domain, path = %path, "MITM H2: upstream request failed");
let resp = Response::builder() let resp = Response::builder()
.status(502) .status(502)
.body(http_body_util::Either::Left(Full::new( .body(http_body_util::Either::Left(Full::new(Bytes::from(
Bytes::from(format!("upstream request failed: {e}")), format!("upstream request failed: {e}"),
))) ))))
.unwrap(); .unwrap();
return Ok(resp); return Ok(resp);
} }
@@ -326,13 +317,18 @@ async fn handle_h2_request(
// Spawn a task to forward body chunks and tee for usage extraction // Spawn a task to forward body chunks and tee for usage extraction
tokio::spawn(async move { tokio::spawn(async move {
let mut tee_buffer = if should_track_usage { Some(Vec::new()) } else { None }; let mut tee_buffer = if should_track_usage {
Some(Vec::new())
} else {
None
};
let mut body = resp_body; let mut body = resp_body;
loop { loop {
match body.frame().await { match body.frame().await {
Some(Ok(frame)) => { Some(Ok(frame)) => {
if let (Some(ref mut buf), Some(data)) = (&mut tee_buffer, frame.data_ref()) { if let (Some(ref mut buf), Some(data)) = (&mut tee_buffer, frame.data_ref())
{
buf.extend_from_slice(data); buf.extend_from_slice(data);
} }
if tx.send(Ok(frame)).await.is_err() { if tx.send(Ok(frame)).await.is_err() {
@@ -354,7 +350,9 @@ async fn handle_h2_request(
if let Some(grpc_usage) = parse_grpc_response_for_usage(&tee_buffer) { if let Some(grpc_usage) = parse_grpc_response_for_usage(&tee_buffer) {
let usage = grpc_usage.into_api_usage(path_clone.clone()); let usage = grpc_usage.into_api_usage(path_clone.clone());
let cascade_hint = extract_cascade_from_grpc_request(&request_body_clone); let cascade_hint = extract_cascade_from_grpc_request(&request_body_clone);
store_clone.record_usage(cascade_hint.as_deref(), usage).await; store_clone
.record_usage(cascade_hint.as_deref(), usage)
.await;
} }
} }
} }

View File

@@ -78,15 +78,21 @@ impl StreamingAccumulator {
Self::default() Self::default()
} }
/// Process a single SSE event. /// Process a single SSE event.
pub fn process_event(&mut self, event: &Value) { pub fn process_event(&mut self, event: &Value) {
// ── Google format: {"response": {"usageMetadata": {...}, "modelVersion": "..."}} ── // ── Google format: {"response": {"usageMetadata": {...}, "modelVersion": "..."}} ──
if let Some(response) = event.get("response") { if let Some(response) = event.get("response") {
// Extract usage metadata (each event has cumulative counts) // Extract usage metadata (each event has cumulative counts)
if let Some(usage) = response.get("usageMetadata") { if let Some(usage) = response.get("usageMetadata") {
self.input_tokens = usage["promptTokenCount"].as_u64().unwrap_or(self.input_tokens); self.input_tokens = usage["promptTokenCount"]
self.output_tokens = usage["candidatesTokenCount"].as_u64().unwrap_or(self.output_tokens); .as_u64()
self.thinking_tokens = usage["thoughtsTokenCount"].as_u64().unwrap_or(self.thinking_tokens); .unwrap_or(self.input_tokens);
self.output_tokens = usage["candidatesTokenCount"]
.as_u64()
.unwrap_or(self.output_tokens);
self.thinking_tokens = usage["thoughtsTokenCount"]
.as_u64()
.unwrap_or(self.thinking_tokens);
} }
if let Some(model) = response["modelVersion"].as_str() { if let Some(model) = response["modelVersion"].as_str() {
self.model = Some(model.to_string()); self.model = Some(model.to_string());
@@ -170,8 +176,10 @@ impl StreamingAccumulator {
"message_start" => { "message_start" => {
if let Some(usage) = event.get("message").and_then(|m| m.get("usage")) { if let Some(usage) = event.get("message").and_then(|m| m.get("usage")) {
self.input_tokens = usage["input_tokens"].as_u64().unwrap_or(0); self.input_tokens = usage["input_tokens"].as_u64().unwrap_or(0);
self.cache_creation_input_tokens = usage["cache_creation_input_tokens"].as_u64().unwrap_or(0); self.cache_creation_input_tokens =
self.cache_read_input_tokens = usage["cache_read_input_tokens"].as_u64().unwrap_or(0); usage["cache_creation_input_tokens"].as_u64().unwrap_or(0);
self.cache_read_input_tokens =
usage["cache_read_input_tokens"].as_u64().unwrap_or(0);
} }
if let Some(model) = event.get("message").and_then(|m| m["model"].as_str()) { if let Some(model) = event.get("message").and_then(|m| m["model"].as_str()) {
self.model = Some(model.to_string()); self.model = Some(model.to_string());
@@ -181,7 +189,9 @@ impl StreamingAccumulator {
} }
"message_delta" => { "message_delta" => {
if let Some(usage) = event.get("usage") { if let Some(usage) = event.get("usage") {
self.output_tokens = usage["output_tokens"].as_u64().unwrap_or(self.output_tokens); self.output_tokens = usage["output_tokens"]
.as_u64()
.unwrap_or(self.output_tokens);
} }
if let Some(reason) = event["delta"]["stop_reason"].as_str() { if let Some(reason) = event["delta"]["stop_reason"].as_str() {
self.stop_reason = Some(reason.to_string()); self.stop_reason = Some(reason.to_string());
@@ -235,7 +245,10 @@ impl StreamingAccumulator {
response_output_tokens: 0, response_output_tokens: 0,
model: self.model, model: self.model,
stop_reason: self.stop_reason, stop_reason: self.stop_reason,
api_provider: self.api_provider.unwrap_or_else(|| "unknown".to_string()).into(), api_provider: self
.api_provider
.unwrap_or_else(|| "unknown".to_string())
.into(),
grpc_method: None, grpc_method: None,
captured_at: std::time::SystemTime::now() captured_at: std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH) .duration_since(std::time::UNIX_EPOCH)

View File

@@ -68,14 +68,14 @@ pub fn modify_request(body: &[u8], tool_ctx: Option<&ToolContext>) -> Option<Vec
"system instruction: keep <identity> only ({original_len}{} chars, -{stripped})", "system instruction: keep <identity> only ({original_len}{} chars, -{stripped})",
new_sys.len() new_sys.len()
)); ));
json["request"]["systemInstruction"]["parts"][0]["text"] = json["request"]["systemInstruction"]["parts"][0]["text"] = Value::String(new_sys);
Value::String(new_sys);
} }
} else { } else {
// No identity tag found — clear the whole thing // No identity tag found — clear the whole thing
changes.push(format!("system instruction: cleared ({original_len} chars)")); changes.push(format!(
json["request"]["systemInstruction"]["parts"][0]["text"] = "system instruction: cleared ({original_len} chars)"
Value::String(String::new()); ));
json["request"]["systemInstruction"]["parts"][0]["text"] = Value::String(String::new());
} }
} }
@@ -125,7 +125,11 @@ pub fn modify_request(body: &[u8], tool_ctx: Option<&ToolContext>) -> Option<Vec
let mut modified = text.clone(); let mut modified = text.clone();
// Strip conversation summaries block // Strip conversation summaries block
if let Some(cleaned) = strip_between(&modified, "# Conversation History\n", "</conversation_summaries>") { if let Some(cleaned) = strip_between(
&modified,
"# Conversation History\n",
"</conversation_summaries>",
) {
modified = cleaned; modified = cleaned;
} }
@@ -147,7 +151,9 @@ pub fn modify_request(body: &[u8], tool_ctx: Option<&ToolContext>) -> Option<Vec
} }
// Strip knowledge item blocks // Strip knowledge item blocks
if let Some(cleaned) = strip_between(&modified, "Here are the ", "</knowledge_item>") { if let Some(cleaned) =
strip_between(&modified, "Here are the ", "</knowledge_item>")
{
// Only strip if it's about knowledge items // Only strip if it's about knowledge items
if cleaned.len() < modified.len() && modified.contains("knowledge item") { if cleaned.len() < modified.len() && modified.contains("knowledge item") {
modified = cleaned; modified = cleaned;
@@ -202,7 +208,8 @@ pub fn modify_request(body: &[u8], tool_ctx: Option<&ToolContext>) -> Option<Vec
// Inject client-provided tools from ToolContext // Inject client-provided tools from ToolContext
if let Some(ref ctx) = tool_ctx { if let Some(ref ctx) = tool_ctx {
if let Some(ref custom_tools) = ctx.tools { if let Some(ref custom_tools) = ctx.tools {
let total_decls: usize = custom_tools.iter() let total_decls: usize = custom_tools
.iter()
.filter_map(|t| t.get("functionDeclarations").and_then(|d| d.as_array())) .filter_map(|t| t.get("functionDeclarations").and_then(|d| d.as_array()))
.map(|a| a.len()) .map(|a| a.len())
.sum(); .sum();
@@ -210,7 +217,10 @@ pub fn modify_request(body: &[u8], tool_ctx: Option<&ToolContext>) -> Option<Vec
tools.push(tool.clone()); tools.push(tool.clone());
} }
has_custom_tools = true; has_custom_tools = true;
changes.push(format!("inject {} custom tool group(s)", custom_tools.len())); changes.push(format!(
"inject {} custom tool group(s)",
custom_tools.len()
));
// Override LS's VALIDATED toolConfig → AUTO for custom tools. // Override LS's VALIDATED toolConfig → AUTO for custom tools.
// VALIDATED mode forces Google to validate function calls against a // VALIDATED mode forces Google to validate function calls against a
@@ -218,16 +228,20 @@ pub fn modify_request(body: &[u8], tool_ctx: Option<&ToolContext>) -> Option<Vec
// that list, so they'd be rejected. AUTO lets the model freely choose // that list, so they'd be rejected. AUTO lets the model freely choose
// between text and function calls. // between text and function calls.
if let Some(req) = json.get_mut("request").and_then(|v| v.as_object_mut()) { if let Some(req) = json.get_mut("request").and_then(|v| v.as_object_mut()) {
let has_validated = req.get("toolConfig") let has_validated = req
.get("toolConfig")
.and_then(|tc| tc.pointer("/functionCallingConfig/mode")) .and_then(|tc| tc.pointer("/functionCallingConfig/mode"))
.and_then(|m| m.as_str()) .and_then(|m| m.as_str())
.map_or(false, |m| m == "VALIDATED"); .map_or(false, |m| m == "VALIDATED");
if has_validated { if has_validated {
req.insert("toolConfig".to_string(), serde_json::json!({ req.insert(
"functionCallingConfig": { "toolConfig".to_string(),
"mode": "AUTO" serde_json::json!({
} "functionCallingConfig": {
})); "mode": "AUTO"
}
}),
);
changes.push("override toolConfig VALIDATED → AUTO".to_string()); changes.push("override toolConfig VALIDATED → AUTO".to_string());
} }
} }
@@ -243,7 +257,11 @@ pub fn modify_request(body: &[u8], tool_ctx: Option<&ToolContext>) -> Option<Vec
if STRIP_ALL_TOOLS && !has_custom_tools { if STRIP_ALL_TOOLS && !has_custom_tools {
if let Some(req) = json.get_mut("request").and_then(|v| v.as_object_mut()) { if let Some(req) = json.get_mut("request").and_then(|v| v.as_object_mut()) {
// Remove the empty tools array entirely // Remove the empty tools array entirely
if req.get("tools").and_then(|v| v.as_array()).map_or(false, |a| a.is_empty()) { if req
.get("tools")
.and_then(|v| v.as_array())
.map_or(false, |a| a.is_empty())
{
req.remove("tools"); req.remove("tools");
changes.push("remove empty tools array".to_string()); changes.push("remove empty tools array".to_string());
} }
@@ -266,7 +284,8 @@ pub fn modify_request(body: &[u8], tool_ctx: Option<&ToolContext>) -> Option<Vec
.as_ref() .as_ref()
.and_then(|ctx| ctx.tools.as_ref()) .and_then(|ctx| ctx.tools.as_ref())
.map(|tools| { .map(|tools| {
tools.iter() tools
.iter()
.filter_map(|t| t["functionDeclarations"].as_array()) .filter_map(|t| t["functionDeclarations"].as_array())
.flatten() .flatten()
.filter_map(|decl| decl["name"].as_str().map(|s| s.to_string())) .filter_map(|decl| decl["name"].as_str().map(|s| s.to_string()))
@@ -309,7 +328,9 @@ pub fn modify_request(body: &[u8], tool_ctx: Option<&ToolContext>) -> Option<Vec
.map_or(true, |parts| !parts.is_empty()) .map_or(true, |parts| !parts.is_empty())
}); });
if stripped_fc > 0 { if stripped_fc > 0 {
changes.push(format!("strip {stripped_fc} functionCall/Response parts from history")); changes.push(format!(
"strip {stripped_fc} functionCall/Response parts from history"
));
} }
} }
} }
@@ -336,16 +357,22 @@ pub fn modify_request(body: &[u8], tool_ctx: Option<&ToolContext>) -> Option<Vec
for msg in contents.iter_mut() { for msg in contents.iter_mut() {
if msg["role"].as_str() == Some("model") { if msg["role"].as_str() == Some("model") {
if let Some(text) = msg["parts"][0]["text"].as_str() { if let Some(text) = msg["parts"][0]["text"].as_str() {
if text.contains("Tool call completed") || text.contains("Awaiting external tool result") { if text.contains("Tool call completed")
|| text.contains("Awaiting external tool result")
{
// Replace with functionCall parts // Replace with functionCall parts
let fc_parts: Vec<Value> = ctx.last_calls.iter().map(|fc| { let fc_parts: Vec<Value> = ctx
serde_json::json!({ .last_calls
"functionCall": { .iter()
"name": fc.name, .map(|fc| {
"args": fc.args, serde_json::json!({
} "functionCall": {
"name": fc.name,
"args": fc.args,
}
})
}) })
}).collect(); .collect();
msg["parts"] = Value::Array(fc_parts); msg["parts"] = Value::Array(fc_parts);
changes.push("rewrite model turn with functionCall".to_string()); changes.push("rewrite model turn with functionCall".to_string());
break; break;
@@ -355,29 +382,36 @@ pub fn modify_request(body: &[u8], tool_ctx: Option<&ToolContext>) -> Option<Vec
} }
// Add functionResponse as a user turn before the last user message // Add functionResponse as a user turn before the last user message
let fn_response_parts: Vec<Value> = ctx.pending_results.iter().map(|r| { let fn_response_parts: Vec<Value> = ctx
serde_json::json!({ .pending_results
"functionResponse": { .iter()
"name": r.name, .map(|r| {
"response": r.result, serde_json::json!({
} "functionResponse": {
"name": r.name,
"response": r.result,
}
})
}) })
}).collect(); .collect();
let fn_response_turn = serde_json::json!({ let fn_response_turn = serde_json::json!({
"role": "user", "role": "user",
"parts": fn_response_parts, "parts": fn_response_parts,
}); });
// Insert before the last user message // Insert before the last user message
let last_user_idx = contents.iter().rposition(|msg| { let last_user_idx = contents
msg["role"].as_str() == Some("user") .iter()
}); .rposition(|msg| msg["role"].as_str() == Some("user"));
if let Some(idx) = last_user_idx { if let Some(idx) = last_user_idx {
contents.insert(idx, fn_response_turn); contents.insert(idx, fn_response_turn);
} else { } else {
contents.push(fn_response_turn); contents.push(fn_response_turn);
} }
changes.push(format!("inject {} functionResponse(s)", ctx.pending_results.len())); changes.push(format!(
"inject {} functionResponse(s)",
ctx.pending_results.len()
));
} }
} }
} }
@@ -420,8 +454,10 @@ pub fn modify_request(body: &[u8], tool_ctx: Option<&ToolContext>) -> Option<Vec
} else { } else {
// Not wrapped in request — try top-level (public API format) // Not wrapped in request — try top-level (public API format)
let gen_config = json.as_object_mut().and_then(|o| { let gen_config = json.as_object_mut().and_then(|o| {
Some(o.entry("generationConfig") Some(
.or_insert_with(|| serde_json::json!({}))) o.entry("generationConfig")
.or_insert_with(|| serde_json::json!({})),
)
}); });
if let Some(gc) = gen_config.and_then(|v| v.as_object_mut()) { if let Some(gc) = gen_config.and_then(|v| v.as_object_mut()) {
let thinking_config = gc let thinking_config = gc
@@ -449,8 +485,10 @@ pub fn modify_request(body: &[u8], tool_ctx: Option<&ToolContext>) -> Option<Vec
if let Some(ref gp) = ctx.generation_params { if let Some(ref gp) = ctx.generation_params {
// Find or create generationConfig (same path as above) // Find or create generationConfig (same path as above)
let gc = if let Some(req) = json.get_mut("request").and_then(|v| v.as_object_mut()) { let gc = if let Some(req) = json.get_mut("request").and_then(|v| v.as_object_mut()) {
Some(req.entry("generationConfig") Some(
.or_insert_with(|| serde_json::json!({}))) req.entry("generationConfig")
.or_insert_with(|| serde_json::json!({})),
)
} else { } else {
json.as_object_mut().map(|o| { json.as_object_mut().map(|o| {
o.entry("generationConfig") o.entry("generationConfig")
@@ -564,8 +602,6 @@ pub fn modify_request(body: &[u8], tool_ctx: Option<&ToolContext>) -> Option<Vec
changes.join(", ") changes.join(", ")
); );
Some(modified_bytes) Some(modified_bytes)
} }
@@ -832,8 +868,10 @@ mod tests {
let result: Value = serde_json::from_slice(&modified).unwrap(); let result: Value = serde_json::from_slice(&modified).unwrap();
// With no ToolContext, tools should be removed entirely // With no ToolContext, tools should be removed entirely
assert!(result["request"]["tools"].is_null() || result.pointer("/request/tools").is_none(), assert!(
"tools should be removed when no custom tools provided"); result["request"]["tools"].is_null() || result.pointer("/request/tools").is_none(),
"tools should be removed when no custom tools provided"
);
} }
#[test] #[test]
@@ -892,13 +930,23 @@ mod tests {
let contents = result["request"]["contents"].as_array().unwrap(); let contents = result["request"]["contents"].as_array().unwrap();
// Should have removed user_information, user_rules, workflows (3 messages) // Should have removed user_information, user_rules, workflows (3 messages)
// Kept: USER_REQUEST message (with ADDITIONAL_METADATA stripped) + model response // Kept: USER_REQUEST message (with ADDITIONAL_METADATA stripped) + model response
assert_eq!(contents.len(), 2, "should keep only user request + model response"); assert_eq!(
contents.len(),
2,
"should keep only user request + model response"
);
// Check USER_REQUEST message had metadata stripped // Check USER_REQUEST message had metadata stripped
let user_msg = contents[0]["parts"][0]["text"].as_str().unwrap(); let user_msg = contents[0]["parts"][0]["text"].as_str().unwrap();
assert!(user_msg.contains("Say hello"), "should keep user request"); assert!(user_msg.contains("Say hello"), "should keep user request");
assert!(!user_msg.contains("ADDITIONAL_METADATA"), "should strip metadata"); assert!(
assert!(!user_msg.contains("cursor stuff"), "should strip cursor info"); !user_msg.contains("ADDITIONAL_METADATA"),
"should strip metadata"
);
assert!(
!user_msg.contains("cursor stuff"),
"should strip cursor info"
);
assert!(!user_msg.starts_with("Step Id:"), "should strip step id"); assert!(!user_msg.starts_with("Step Id:"), "should strip step id");
// Model response kept intact // Model response kept intact
@@ -921,8 +969,14 @@ mod tests {
#[test] #[test]
fn test_strip_between() { fn test_strip_between() {
let text = "keep this # Conversation History\nlots of stuff\n</conversation_summaries>\nand this"; let text =
let result = strip_between(text, "# Conversation History\n", "</conversation_summaries>").unwrap(); "keep this # Conversation History\nlots of stuff\n</conversation_summaries>\nand this";
let result = strip_between(
text,
"# Conversation History\n",
"</conversation_summaries>",
)
.unwrap();
assert_eq!(result, "keep this and this"); assert_eq!(result, "keep this and this");
} }
} }
@@ -977,7 +1031,9 @@ pub fn modify_response_chunk(chunk: &[u8]) -> Option<Vec<u8>> {
// Replace the JSON in the result string // Replace the JSON in the result string
result.replace_range(json_start..json_start + json_end, &new_json); result.replace_range(json_start..json_start + json_end, &new_json);
changed = true; changed = true;
info!("MITM: rewrote functionCall in response → text placeholder for LS"); info!(
"MITM: rewrote functionCall in response → text placeholder for LS"
);
search_from = json_start + new_json.len(); search_from = json_start + new_json.len();
continue; continue;
} }
@@ -1117,7 +1173,10 @@ fn rewrite_function_calls_in_response(json: &mut Value) -> bool {
} }
// Try nested "response.candidates" // Try nested "response.candidates"
if let Some(candidates) = json.pointer_mut("/response/candidates").and_then(|v| v.as_array_mut()) { if let Some(candidates) = json
.pointer_mut("/response/candidates")
.and_then(|v| v.as_array_mut())
{
changed |= rewrite_candidates(candidates); changed |= rewrite_candidates(candidates);
} }

View File

@@ -251,7 +251,10 @@ fn looks_like_valid_message(fields: &[ProtoField], original_len: usize) -> bool
// (e.g., a long string that happened to have a valid first-field prefix) // (e.g., a long string that happened to have a valid first-field prefix)
if fields.len() == 1 && original_len > 100 { if fields.len() == 1 && original_len > 100 {
// Single-field messages of >100 bytes are suspicious unless the field is bytes/message // Single-field messages of >100 bytes are suspicious unless the field is bytes/message
matches!(&fields[0].value, ProtoValue::Bytes(_) | ProtoValue::Message(_)) matches!(
&fields[0].value,
ProtoValue::Bytes(_) | ProtoValue::Message(_)
)
} else { } else {
true true
} }
@@ -328,7 +331,9 @@ fn try_extract_usage(fields: &[ProtoField]) -> Option<GrpcUsage> {
.iter() .iter()
.filter_map(|f| { .filter_map(|f| {
if let ProtoValue::Bytes(ref b) = f.value { if let ProtoValue::Bytes(ref b) = f.value {
std::str::from_utf8(b).ok().map(|s| (f.number, s.to_string())) std::str::from_utf8(b)
.ok()
.map(|s| (f.number, s.to_string()))
} else { } else {
None None
} }
@@ -361,14 +366,23 @@ fn try_extract_usage(fields: &[ProtoField]) -> Option<GrpcUsage> {
// Check if there's a model-like string (field 7 = message_id or field 11 = response_id // Check if there's a model-like string (field 7 = message_id or field 11 = response_id
// can contain model names, or model enum values map to known names) // can contain model names, or model enum values map to known names)
let has_model_string = string_fields.iter().any(|(_, s)| { let has_model_string = string_fields.iter().any(|(_, s)| {
s.contains("claude") || s.contains("gemini") || s.contains("gpt") s.contains("claude")
|| s.starts_with("models/") || s.contains("sonnet") || s.contains("opus") || s.contains("gemini")
|| s.contains("flash") || s.contains("pro") || s.contains("gpt")
|| s.starts_with("models/")
|| s.contains("sonnet")
|| s.contains("opus")
|| s.contains("flash")
|| s.contains("pro")
}); });
// Check for fields at the known ModelUsageStats field numbers // Check for fields at the known ModelUsageStats field numbers
let has_field_2 = fields.iter().any(|f| f.number == 2 && matches!(f.value, ProtoValue::Varint(_))); let has_field_2 = fields
let has_field_3 = fields.iter().any(|f| f.number == 3 && matches!(f.value, ProtoValue::Varint(_))); .iter()
.any(|f| f.number == 2 && matches!(f.value, ProtoValue::Varint(_)));
let has_field_3 = fields
.iter()
.any(|f| f.number == 3 && matches!(f.value, ProtoValue::Varint(_)));
// Strong signal: has both input and output token fields // Strong signal: has both input and output token fields
let is_likely_usage = (has_field_2 && has_field_3) || has_model_string; let is_likely_usage = (has_field_2 && has_field_3) || has_model_string;
@@ -392,8 +406,8 @@ fn try_extract_usage(fields: &[ProtoField]) -> Option<GrpcUsage> {
// field 1 = model enum (varint, not string!) // field 1 = model enum (varint, not string!)
2 => usage.input_tokens = v, 2 => usage.input_tokens = v,
3 => usage.output_tokens = v, 3 => usage.output_tokens = v,
4 => usage.cache_write_tokens = v, // VERIFIED: field 4 4 => usage.cache_write_tokens = v, // VERIFIED: field 4
5 => usage.cache_read_tokens = v, // VERIFIED: field 5 5 => usage.cache_read_tokens = v, // VERIFIED: field 5
// field 6 = api_provider enum (varint) // field 6 = api_provider enum (varint)
9 => usage.thinking_output_tokens = v, // VERIFIED: field 9 9 => usage.thinking_output_tokens = v, // VERIFIED: field 9
10 => usage.response_output_tokens = v, // VERIFIED: field 10 10 => usage.response_output_tokens = v, // VERIFIED: field 10
@@ -486,11 +500,11 @@ pub fn parse_grpc_response_for_usage(body: &[u8]) -> Option<GrpcUsage> {
fn model_enum_name(enum_val: u64) -> &'static str { fn model_enum_name(enum_val: u64) -> &'static str {
match enum_val { match enum_val {
// Placeholder models (1000 + N) // Placeholder models (1000 + N)
1007 => "gemini-3-pro", // MODEL_PLACEHOLDER_M7 1007 => "gemini-3-pro", // MODEL_PLACEHOLDER_M7
1008 => "gemini-3-pro-high", // MODEL_PLACEHOLDER_M8 1008 => "gemini-3-pro-high", // MODEL_PLACEHOLDER_M8
1012 => "claude-opus-4.5", // MODEL_PLACEHOLDER_M12 1012 => "claude-opus-4.5", // MODEL_PLACEHOLDER_M12
1018 => "gemini-3-flash", // MODEL_PLACEHOLDER_M18 1018 => "gemini-3-flash", // MODEL_PLACEHOLDER_M18
1026 => "claude-opus-4.6", // MODEL_PLACEHOLDER_M26 1026 => "claude-opus-4.6", // MODEL_PLACEHOLDER_M26
// Claude models (named) // Claude models (named)
281 => "claude-4-sonnet", 281 => "claude-4-sonnet",
@@ -629,13 +643,13 @@ mod tests {
data.push(v as u8); data.push(v as u8);
} }
encode_varint_field(&mut data, 1, 5); // model enum encode_varint_field(&mut data, 1, 5); // model enum
encode_varint_field(&mut data, 2, 1000); // input_tokens encode_varint_field(&mut data, 2, 1000); // input_tokens
encode_varint_field(&mut data, 3, 500); // output_tokens encode_varint_field(&mut data, 3, 500); // output_tokens
encode_varint_field(&mut data, 4, 100); // cache_write_tokens encode_varint_field(&mut data, 4, 100); // cache_write_tokens
encode_varint_field(&mut data, 5, 200); // cache_read_tokens encode_varint_field(&mut data, 5, 200); // cache_read_tokens
encode_varint_field(&mut data, 9, 300); // thinking_output_tokens encode_varint_field(&mut data, 9, 300); // thinking_output_tokens
encode_varint_field(&mut data, 10, 200); // response_output_tokens encode_varint_field(&mut data, 10, 200); // response_output_tokens
let fields = decode_proto(&data); let fields = decode_proto(&data);
let usage = try_extract_usage(&fields).expect("should extract usage"); let usage = try_extract_usage(&fields).expect("should extract usage");

View File

@@ -11,8 +11,7 @@
use super::ca::MitmCa; use super::ca::MitmCa;
use super::intercept::{ use super::intercept::{
extract_cascade_hint, parse_non_streaming_response, parse_streaming_chunk, extract_cascade_hint, parse_non_streaming_response, parse_streaming_chunk, StreamingAccumulator,
StreamingAccumulator,
}; };
use super::store::MitmStore; use super::store::MitmStore;
use std::sync::Arc; use std::sync::Arc;
@@ -54,7 +53,6 @@ pub struct MitmConfig {
pub modify_requests: bool, pub modify_requests: bool,
} }
/// Run the MITM proxy server. /// Run the MITM proxy server.
/// ///
/// Returns (port, task_handle) — port it's listening on, handle to abort on shutdown. /// Returns (port, task_handle) — port it's listening on, handle to abort on shutdown.
@@ -84,7 +82,8 @@ pub async fn run(
let ca = ca.clone(); let ca = ca.clone();
let store = store.clone(); let store = store.clone();
tokio::spawn(async move { tokio::spawn(async move {
if let Err(e) = handle_connection(stream, ca, store, modify_requests).await { if let Err(e) = handle_connection(stream, ca, store, modify_requests).await
{
warn!(error = %e, "MITM connection error"); warn!(error = %e, "MITM connection error");
} }
}); });
@@ -131,8 +130,7 @@ async fn handle_connection(
.await .await
.map_err(|e| format!("Peek ClientHello: {e}"))?; .map_err(|e| format!("Peek ClientHello: {e}"))?;
let domain = extract_sni(&hello_buf[..n]) let domain = extract_sni(&hello_buf[..n]).unwrap_or_else(|| "unknown".to_string());
.unwrap_or_else(|| "unknown".to_string());
info!(domain, "MITM: transparent redirect (iptables)"); info!(domain, "MITM: transparent redirect (iptables)");
@@ -224,22 +222,30 @@ fn extract_sni(buf: &[u8]) -> Option<String> {
let mut pos = 34; // skip version + random let mut pos = 34; // skip version + random
// Session ID // Session ID
if pos >= body.len() { return None; } if pos >= body.len() {
return None;
}
let sid_len = body[pos] as usize; let sid_len = body[pos] as usize;
pos += 1 + sid_len; pos += 1 + sid_len;
// Cipher suites // Cipher suites
if pos + 2 > body.len() { return None; } if pos + 2 > body.len() {
return None;
}
let cs_len = u16::from_be_bytes([body[pos], body[pos + 1]]) as usize; let cs_len = u16::from_be_bytes([body[pos], body[pos + 1]]) as usize;
pos += 2 + cs_len; pos += 2 + cs_len;
// Compression methods // Compression methods
if pos >= body.len() { return None; } if pos >= body.len() {
return None;
}
let cm_len = body[pos] as usize; let cm_len = body[pos] as usize;
pos += 1 + cm_len; pos += 1 + cm_len;
// Extensions // Extensions
if pos + 2 > body.len() { return None; } if pos + 2 > body.len() {
return None;
}
let ext_len = u16::from_be_bytes([body[pos], body[pos + 1]]) as usize; let ext_len = u16::from_be_bytes([body[pos], body[pos + 1]]) as usize;
pos += 2; pos += 2;
let ext_end = pos + ext_len.min(body.len() - pos); let ext_end = pos + ext_len.min(body.len() - pos);
@@ -304,32 +310,32 @@ async fn handle_intercepted(
info!(domain, "MITM: intercepting TLS"); info!(domain, "MITM: intercepting TLS");
// Get or create server TLS config for this domain // Get or create server TLS config for this domain
let server_config = ca let server_config = ca.server_config_for_domain(domain).await?;
.server_config_for_domain(domain)
.await?;
let acceptor = TlsAcceptor::from(server_config); let acceptor = TlsAcceptor::from(server_config);
// Perform TLS handshake with the client (LS) — 10s timeout // Perform TLS handshake with the client (LS) — 10s timeout
let tls_stream = match tokio::time::timeout( let tls_stream =
std::time::Duration::from_secs(10), match tokio::time::timeout(std::time::Duration::from_secs(10), acceptor.accept(stream))
acceptor.accept(stream), .await
) {
.await Ok(Ok(s)) => s,
{ Ok(Err(e)) => {
Ok(Ok(s)) => s, warn!(domain, error = %e, "MITM: TLS handshake FAILED (client rejected cert?)");
Ok(Err(e)) => { return Err(format!(
warn!(domain, error = %e, "MITM: TLS handshake FAILED (client rejected cert?)"); "TLS handshake with client failed for {domain}: {e}"
return Err(format!("TLS handshake with client failed for {domain}: {e}")); ));
} }
Err(_) => { Err(_) => {
warn!(domain, "MITM: TLS handshake TIMED OUT after 10s"); warn!(domain, "MITM: TLS handshake TIMED OUT after 10s");
return Err(format!("TLS handshake timed out for {domain}")); return Err(format!("TLS handshake timed out for {domain}"));
} }
}; };
// Check negotiated ALPN protocol // Check negotiated ALPN protocol
let alpn = tls_stream.get_ref().1 let alpn = tls_stream
.get_ref()
.1
.alpn_protocol() .alpn_protocol()
.map(|p| String::from_utf8_lossy(p).to_string()); .map(|p| String::from_utf8_lossy(p).to_string());
@@ -339,12 +345,7 @@ async fn handle_intercepted(
Some("h2") => { Some("h2") => {
// HTTP/2 — use the hyper-based gRPC handler // HTTP/2 — use the hyper-based gRPC handler
info!(domain, "MITM: routing to HTTP/2 handler (gRPC)"); info!(domain, "MITM: routing to HTTP/2 handler (gRPC)");
super::h2_handler::handle_h2_connection( super::h2_handler::handle_h2_connection(tls_stream, domain.to_string(), store).await
tls_stream,
domain.to_string(),
store,
)
.await
} }
_ => { _ => {
// HTTP/1.1 or no ALPN — use the existing handler // HTTP/1.1 or no ALPN — use the existing handler
@@ -434,7 +435,10 @@ async fn handle_http_over_tls(
.await .await
{ {
let out = String::from_utf8_lossy(&output.stdout); let out = String::from_utf8_lossy(&output.stdout);
if let Some(ip) = out.lines().find(|l| l.parse::<std::net::Ipv4Addr>().is_ok()) { if let Some(ip) = out
.lines()
.find(|l| l.parse::<std::net::Ipv4Addr>().is_ok())
{
return format!("{ip}:443"); return format!("{ip}:443");
} }
} }
@@ -458,7 +462,6 @@ async fn handle_http_over_tls(
loop { loop {
// ── Read the HTTP request from the client ───────────────────────── // ── Read the HTTP request from the client ─────────────────────────
let mut request_buf = Vec::with_capacity(1024 * 64); let mut request_buf = Vec::with_capacity(1024 * 64);
let mut is_our_request = false;
// 60s timeout on initial read (LS may open connection without sending immediately) // 60s timeout on initial read (LS may open connection without sending immediately)
const IDLE_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(60); const IDLE_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(60);
@@ -513,7 +516,8 @@ async fn handle_http_over_tls(
} }
// Parse the HTTP request to find headers and body // Parse the HTTP request to find headers and body
let (headers_end, content_length, _is_streaming_request) = parse_http_request_meta(&request_buf); let (headers_end, content_length, _is_streaming_request) =
parse_http_request_meta(&request_buf);
// Try to extract cascade hint from request body // Try to extract cascade hint from request body
let cascade_hint = if headers_end < request_buf.len() { let cascade_hint = if headers_end < request_buf.len() {
@@ -545,6 +549,27 @@ async fn handle_http_over_tls(
"MITM: forwarding LLM request" "MITM: forwarding LLM request"
); );
// ── Block ALL requests when one is already in-flight ─────────
// The LS opens multiple connections and sends parallel requests.
// When custom tools are active, only the FIRST request should reach
// Google. Block everything else with a fake response.
if store.is_request_in_flight() {
info!("MITM: blocking LS request — another request already in-flight");
let fake_response = "HTTP/1.1 200 OK\r\n\
Content-Type: text/event-stream\r\n\
Transfer-Encoding: chunked\r\n\
\r\n";
let fake_sse = "data: {\"response\":{\"candidates\":[{\"content\":{\"parts\":[{\"text\":\"Request handled.\"}],\"role\":\"model\"},\"finishReason\":\"STOP\"}],\"usageMetadata\":{\"promptTokenCount\":0,\"candidatesTokenCount\":1,\"totalTokenCount\":1}}}\n\ndata: [DONE]\n\n";
let chunked_body = super::modify::rechunk(fake_sse.as_bytes());
let mut response = fake_response.as_bytes().to_vec();
response.extend_from_slice(&chunked_body);
if let Err(e) = client.write_all(&response).await {
warn!(error = %e, "MITM: failed to write fake response");
}
let _ = client.flush().await;
continue;
}
// ── Request modification ───────────────────────────────────── // ── Request modification ─────────────────────────────────────
// Dechunk body → check if agent request → modify → rechunk // Dechunk body → check if agent request → modify → rechunk
if modify_requests && body_len > 0 { if modify_requests && body_len > 0 {
@@ -565,7 +590,11 @@ async fn handle_http_over_tls(
let generation_params = store.get_generation_params().await; let generation_params = store.get_generation_params().await;
let pending_image = store.take_pending_image().await; let pending_image = store.take_pending_image().await;
let tool_ctx = if tools.is_some() || !pending_results.is_empty() || generation_params.is_some() || pending_image.is_some() { let tool_ctx = if tools.is_some()
|| !pending_results.is_empty()
|| generation_params.is_some()
|| pending_image.is_some()
{
Some(super::modify::ToolContext { Some(super::modify::ToolContext {
tools, tools,
tool_config, tool_config,
@@ -578,7 +607,9 @@ async fn handle_http_over_tls(
None None
}; };
if let Some(modified_body) = super::modify::modify_request(&raw_body, tool_ctx.as_ref()) { if let Some(modified_body) =
super::modify::modify_request(&raw_body, tool_ctx.as_ref())
{
// Rebuild request_buf: headers (with updated Content-Length) + rechunked modified body // Rebuild request_buf: headers (with updated Content-Length) + rechunked modified body
let new_chunked = super::modify::rechunk(&modified_body); let new_chunked = super::modify::rechunk(&modified_body);
@@ -588,39 +619,12 @@ async fn handle_http_over_tls(
let mut new_buf = updated_headers.into_bytes(); let mut new_buf = updated_headers.into_bytes();
new_buf.extend_from_slice(&new_chunked); new_buf.extend_from_slice(&new_chunked);
request_buf = new_buf; request_buf = new_buf;
// Mark this as our modified request and set in-flight flag // Mark in-flight IMMEDIATELY — blocks all subsequent requests
is_our_request = true;
store.mark_request_in_flight(); store.mark_request_in_flight();
} }
} }
} }
// ── Block ALL LS follow-up requests once first is in-flight ──
// When custom tools are active, we only need ONE request to Google.
// The LS tries to send multiple requests (its own agentic loop +
// internal requests on gemini-2.5-flash-lite). Block them ALL
// immediately — don't wait for response_complete.
let has_tools = store.get_tools().await.is_some();
if has_tools && store.is_request_in_flight() && !is_our_request {
info!(
"MITM: blocking LS follow-up — request already in-flight"
);
// Return a fake SSE response that makes the LS stop
let fake_response = "HTTP/1.1 200 OK\r\n\
Content-Type: text/event-stream\r\n\
Transfer-Encoding: chunked\r\n\
\r\n";
let fake_sse = "data: {\"response\":{\"candidates\":[{\"content\":{\"parts\":[{\"text\":\"Request handled.\"}],\"role\":\"model\"},\"finishReason\":\"STOP\"}],\"usageMetadata\":{\"promptTokenCount\":0,\"candidatesTokenCount\":1,\"totalTokenCount\":1}}}\n\ndata: [DONE]\n\n";
let chunked_body = super::modify::rechunk(fake_sse.as_bytes());
let mut response = fake_response.as_bytes().to_vec();
response.extend_from_slice(&chunked_body);
if let Err(e) = client.write_all(&response).await {
warn!(error = %e, "MITM: failed to write fake response");
}
let _ = client.flush().await;
continue; // Skip the real upstream call
}
} else { } else {
debug!( debug!(
domain, domain,
@@ -674,7 +678,10 @@ async fn handle_http_over_tls(
}; };
let n = match tokio::time::timeout(timeout, conn.read(&mut tmp)).await { let n = match tokio::time::timeout(timeout, conn.read(&mut tmp)).await {
Ok(Ok(0)) => { upstream_ok = false; break; } Ok(Ok(0)) => {
upstream_ok = false;
break;
}
Ok(Ok(n)) => n, Ok(Ok(n)) => n,
Ok(Err(e)) => { Ok(Err(e)) => {
debug!(domain, error = %e, "MITM: upstream read ended"); debug!(domain, error = %e, "MITM: upstream read ended");
@@ -711,7 +718,9 @@ async fn handle_http_over_tls(
if header.name.eq_ignore_ascii_case("content-type") { if header.name.eq_ignore_ascii_case("content-type") {
if let Ok(v) = std::str::from_utf8(header.value) { if let Ok(v) = std::str::from_utf8(header.value) {
content_type = v.to_string(); content_type = v.to_string();
if v.contains("text/event-stream") { is_streaming_response = true; } if v.contains("text/event-stream") {
is_streaming_response = true;
}
} }
} }
if header.name.eq_ignore_ascii_case("content-length") { if header.name.eq_ignore_ascii_case("content-length") {
@@ -721,12 +730,16 @@ async fn handle_http_over_tls(
} }
if header.name.eq_ignore_ascii_case("connection") { if header.name.eq_ignore_ascii_case("connection") {
if let Ok(v) = std::str::from_utf8(header.value) { if let Ok(v) = std::str::from_utf8(header.value) {
if v.trim().eq_ignore_ascii_case("close") { upstream_ok = false; } if v.trim().eq_ignore_ascii_case("close") {
upstream_ok = false;
}
} }
} }
if header.name.eq_ignore_ascii_case("transfer-encoding") { if header.name.eq_ignore_ascii_case("transfer-encoding") {
if let Ok(v) = std::str::from_utf8(header.value) { if let Ok(v) = std::str::from_utf8(header.value) {
if v.trim().eq_ignore_ascii_case("chunked") { is_chunked = true; } if v.trim().eq_ignore_ascii_case("chunked") {
is_chunked = true;
}
} }
} }
} }
@@ -749,22 +762,31 @@ async fn handle_http_over_tls(
warn!(domain, status = http_status, body = %body_str, "MITM: upstream error response"); warn!(domain, status = http_status, body = %body_str, "MITM: upstream error response");
// Parse Google's error JSON: {"error": {"code": N, "message": "...", "status": "..."}} // Parse Google's error JSON: {"error": {"code": N, "message": "...", "status": "..."}}
let (message, error_status) = serde_json::from_str::<serde_json::Value>(&body_str) let (message, error_status) =
.ok() serde_json::from_str::<serde_json::Value>(&body_str)
.and_then(|v| { .ok()
let err = v.get("error")?; .and_then(|v| {
let msg = err.get("message").and_then(|m| m.as_str()).map(|s| s.to_string()); let err = v.get("error")?;
let status = err.get("status").and_then(|s| s.as_str()).map(|s| s.to_string()); let msg = err
Some((msg, status)) .get("message")
}) .and_then(|m| m.as_str())
.unwrap_or((None, None)); .map(|s| s.to_string());
let status = err
.get("status")
.and_then(|s| s.as_str())
.map(|s| s.to_string());
Some((msg, status))
})
.unwrap_or((None, None));
store.set_upstream_error(super::store::UpstreamError { store
status: http_status, .set_upstream_error(super::store::UpstreamError {
body: body_str, status: http_status,
message, body: body_str,
error_status, message,
}).await; error_status,
})
.await;
} }
// Save body for usage parsing // Save body for usage parsing
@@ -779,10 +801,15 @@ async fn handle_http_over_tls(
if !streaming_acc.function_calls.is_empty() { if !streaming_acc.function_calls.is_empty() {
let calls: Vec<_> = streaming_acc.function_calls.drain(..).collect(); let calls: Vec<_> = streaming_acc.function_calls.drain(..).collect();
for fc in &calls { for fc in &calls {
store.record_function_call(cascade_hint.as_deref(), fc.clone()).await; store
.record_function_call(cascade_hint.as_deref(), fc.clone())
.await;
} }
store.set_last_function_calls(calls.clone()).await; store.set_last_function_calls(calls.clone()).await;
info!("MITM: stored {} function call(s) from initial body", calls.len()); info!(
"MITM: stored {} function call(s) from initial body",
calls.len()
);
} }
// Capture response + thinking text + grounding into MitmStore // Capture response + thinking text + grounding into MitmStore
@@ -816,7 +843,9 @@ async fn handle_http_over_tls(
} }
if let Some(cl) = response_content_length { if let Some(cl) = response_content_length {
if response_body_buf.len() >= cl { break; } if response_body_buf.len() >= cl {
break;
}
} }
// Check chunked terminator in initial body // Check chunked terminator in initial body
if is_chunked && has_chunked_terminator(&response_body_buf) { if is_chunked && has_chunked_terminator(&response_body_buf) {
@@ -837,10 +866,15 @@ async fn handle_http_over_tls(
if !streaming_acc.function_calls.is_empty() { if !streaming_acc.function_calls.is_empty() {
let calls: Vec<_> = streaming_acc.function_calls.drain(..).collect(); let calls: Vec<_> = streaming_acc.function_calls.drain(..).collect();
for fc in &calls { for fc in &calls {
store.record_function_call(cascade_hint.as_deref(), fc.clone()).await; store
.record_function_call(cascade_hint.as_deref(), fc.clone())
.await;
} }
store.set_last_function_calls(calls.clone()).await; store.set_last_function_calls(calls.clone()).await;
info!("MITM: stored {} function call(s) from body chunk", calls.len()); info!(
"MITM: stored {} function call(s) from body chunk",
calls.len()
);
} }
// Capture response + thinking text + grounding into MitmStore // Capture response + thinking text + grounding into MitmStore
@@ -875,7 +909,9 @@ async fn handle_http_over_tls(
response_body_buf.extend_from_slice(chunk); response_body_buf.extend_from_slice(chunk);
if let Some(cl) = response_content_length { if let Some(cl) = response_content_length {
if response_body_buf.len() >= cl { break; } if response_body_buf.len() >= cl {
break;
}
} }
if is_chunked && has_chunked_terminator(&response_body_buf) { if is_chunked && has_chunked_terminator(&response_body_buf) {
debug!(domain, "MITM: chunked response complete"); debug!(domain, "MITM: chunked response complete");
@@ -912,11 +948,7 @@ async fn handle_http_over_tls(
} }
/// Handle a passthrough connection: transparent TCP tunnel to upstream. /// Handle a passthrough connection: transparent TCP tunnel to upstream.
async fn handle_passthrough( async fn handle_passthrough(mut client: TcpStream, domain: &str, port: u16) -> Result<(), String> {
mut client: TcpStream,
domain: &str,
port: u16,
) -> Result<(), String> {
trace!(domain, port, "MITM: transparent tunnel"); trace!(domain, port, "MITM: transparent tunnel");
let mut upstream = TcpStream::connect(format!("{domain}:{port}")) let mut upstream = TcpStream::connect(format!("{domain}:{port}"))
@@ -926,7 +958,12 @@ async fn handle_passthrough(
// Bidirectional copy // Bidirectional copy
match tokio::io::copy_bidirectional(&mut client, &mut upstream).await { match tokio::io::copy_bidirectional(&mut client, &mut upstream).await {
Ok((client_to_server, server_to_client)) => { Ok((client_to_server, server_to_client)) => {
trace!(domain, client_to_server, server_to_client, "MITM: tunnel closed"); trace!(
domain,
client_to_server,
server_to_client,
"MITM: tunnel closed"
);
} }
Err(e) => { Err(e) => {
trace!(domain, error = %e, "MITM: tunnel error (likely clean close)"); trace!(domain, error = %e, "MITM: tunnel error (likely clean close)");
@@ -945,7 +982,11 @@ fn has_chunked_terminator(body: &[u8]) -> bool {
return false; return false;
} }
// Check last 7 bytes to account for possible trailing whitespace // Check last 7 bytes to account for possible trailing whitespace
let tail = if body.len() > 7 { &body[body.len() - 7..] } else { body }; let tail = if body.len() > 7 {
&body[body.len() - 7..]
} else {
body
};
// Look for \r\n0\r\n\r\n anywhere in the tail // Look for \r\n0\r\n\r\n anywhere in the tail
tail.windows(5).any(|w| w == b"0\r\n\r\n") tail.windows(5).any(|w| w == b"0\r\n\r\n")
} }

View File

@@ -2,11 +2,11 @@
//! //!
//! The MITM proxy writes usage data here; the API handlers read from it. //! The MITM proxy writes usage data here; the API handlers read from it.
use std::collections::HashMap;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};
use tokio::sync::RwLock;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use tokio::sync::RwLock;
use tracing::{debug, info}; use tracing::{debug, info};
/// Token usage from an intercepted API response. /// Token usage from an intercepted API response.
@@ -342,7 +342,9 @@ impl MitmStore {
/// Record a captured function call from Google's response. /// Record a captured function call from Google's response.
pub async fn record_function_call(&self, cascade_id: Option<&str>, fc: CapturedFunctionCall) { pub async fn record_function_call(&self, cascade_id: Option<&str>, fc: CapturedFunctionCall) {
let key = cascade_id.map(|s| s.to_string()).unwrap_or_else(|| "_latest".to_string()); let key = cascade_id
.map(|s| s.to_string())
.unwrap_or_else(|| "_latest".to_string());
info!( info!(
cascade = %key, cascade = %key,
tool = %fc.name, tool = %fc.name,
@@ -377,7 +379,6 @@ impl MitmStore {
self.awaiting_tool_result.store(false, Ordering::SeqCst); self.awaiting_tool_result.store(false, Ordering::SeqCst);
} }
/// Take any pending function calls (ignoring cascade ID). /// Take any pending function calls (ignoring cascade ID).
pub async fn take_any_function_calls(&self) -> Option<Vec<CapturedFunctionCall>> { pub async fn take_any_function_calls(&self) -> Option<Vec<CapturedFunctionCall>> {
let mut pending = self.pending_function_calls.write().await; let mut pending = self.pending_function_calls.write().await;
@@ -457,8 +458,6 @@ impl MitmStore {
// ── Direct response capture (bypass LS) ────────────────────────────── // ── Direct response capture (bypass LS) ──────────────────────────────
/// Set (replace) the captured response text. /// Set (replace) the captured response text.
pub async fn set_response_text(&self, text: &str) { pub async fn set_response_text(&self, text: &str) {
*self.captured_response_text.write().await = Some(text.to_string()); *self.captured_response_text.write().await = Some(text.to_string());
@@ -484,8 +483,6 @@ impl MitmStore {
self.response_complete.load(Ordering::SeqCst) self.response_complete.load(Ordering::SeqCst)
} }
/// Async version of clear_response. /// Async version of clear_response.
pub async fn clear_response_async(&self) { pub async fn clear_response_async(&self) {
self.response_complete.store(false, Ordering::SeqCst); self.response_complete.store(false, Ordering::SeqCst);

View File

@@ -293,8 +293,7 @@ mod tests {
let cascade_bytes = b"test-cascade-id"; let cascade_bytes = b"test-cascade-id";
assert!( assert!(
msg.windows(cascade_bytes.len()) msg.windows(cascade_bytes.len()).any(|w| w == cascade_bytes),
.any(|w| w == cascade_bytes),
"cascade_id must appear in output" "cascade_id must appear in output"
); );

View File

@@ -93,9 +93,8 @@ impl QuotaStore {
// Initial poll immediately. // Initial poll immediately.
self.poll_once(&backend).await; self.poll_once(&backend).await;
let mut interval = tokio::time::interval( let mut interval =
std::time::Duration::from_secs(POLL_INTERVAL_SECS), tokio::time::interval(std::time::Duration::from_secs(POLL_INTERVAL_SECS));
);
interval.tick().await; // consume the first immediate tick interval.tick().await; // consume the first immediate tick
loop { loop {
@@ -125,7 +124,9 @@ impl QuotaStore {
// Profile picture fetch fails through iptables — harmless, suppress // Profile picture fetch fails through iptables — harmless, suppress
let data_str = data.to_string(); let data_str = data.to_string();
if data_str.contains("profile picture") { if data_str.contains("profile picture") {
tracing::debug!("GetUserStatus: profile picture fetch failed (expected with iptables)"); tracing::debug!(
"GetUserStatus: profile picture fetch failed (expected with iptables)"
);
} else { } else {
warn!("GetUserStatus returned {status}: {data_str}"); warn!("GetUserStatus returned {status}: {data_str}");
} }
@@ -172,9 +173,7 @@ fn parse_user_status(data: &serde_json::Value) -> QuotaSnapshot {
.as_str() .as_str()
.unwrap_or("") .unwrap_or("")
.to_string(); .to_string();
let frac = m["quotaInfo"]["remainingFraction"] let frac = m["quotaInfo"]["remainingFraction"].as_f64().unwrap_or(0.0);
.as_f64()
.unwrap_or(0.0);
let reset_str = m["quotaInfo"]["resetTime"] let reset_str = m["quotaInfo"]["resetTime"]
.as_str() .as_str()
.unwrap_or("") .unwrap_or("")
@@ -224,9 +223,7 @@ fn parse_user_status(data: &serde_json::Value) -> QuotaSnapshot {
flow_available: flow_avail, flow_available: flow_avail,
flow_total, flow_total,
flow_used_pct, flow_used_pct,
flex_purchasable: pi["monthlyFlexCreditPurchaseAmount"] flex_purchasable: pi["monthlyFlexCreditPurchaseAmount"].as_i64().unwrap_or(0),
.as_i64()
.unwrap_or(0),
can_buy_more: pi["canBuyMoreCredits"].as_bool().unwrap_or(false), can_buy_more: pi["canBuyMoreCredits"].as_bool().unwrap_or(false),
}, },
models, models,

View File

@@ -66,9 +66,7 @@ impl SessionManager {
msg_count: 0, msg_count: 0,
}, },
); );
return Ok(SessionResult { return Ok(SessionResult { cascade_id });
cascade_id,
});
} }
let sid = session_id.unwrap_or(DEFAULT_SESSION).to_string(); let sid = session_id.unwrap_or(DEFAULT_SESSION).to_string();
@@ -111,9 +109,7 @@ impl SessionManager {
}, },
); );
} }
Ok(SessionResult { Ok(SessionResult { cascade_id })
cascade_id,
})
} }
/// List all active sessions. /// List all active sessions.
@@ -146,7 +142,5 @@ impl SessionManager {
fn cleanup_expired(sessions: &mut HashMap<String, Session>) { fn cleanup_expired(sessions: &mut HashMap<String, Session>) {
let now = Instant::now(); let now = Instant::now();
sessions.retain(|_, s| { sessions.retain(|_, s| now.duration_since(s.last_used).as_secs() < SESSION_TTL_SECS);
now.duration_since(s.last_used).as_secs() < SESSION_TTL_SECS
});
} }

View File

@@ -10,16 +10,44 @@ use std::io::{self, Read};
// ── Domain metadata ────────────────────────────────────────────────────────── // ── Domain metadata ──────────────────────────────────────────────────────────
const DOMAIN_INFO: &[(&str, &str, &str)] = &[ const DOMAIN_INFO: &[(&str, &str, &str)] = &[
("antigravity-unleash.goog", "Feature Flags", "Unleash SDK — controls A/B tests and feature rollouts"), (
("daily-cloudcode-pa.googleapis.com", "LLM API (gRPC)", "Primary Gemini/Claude API endpoint"), "antigravity-unleash.goog",
("cloudcode-pa.googleapis.com", "LLM API (gRPC)", "Production Gemini/Claude API endpoint"), "Feature Flags",
("api.anthropic.com", "Claude API", "Direct Anthropic API calls"), "Unleash SDK — controls A/B tests and feature rollouts",
("lh3.googleusercontent.com", "Profile Picture", "User avatar"), ),
(
"daily-cloudcode-pa.googleapis.com",
"LLM API (gRPC)",
"Primary Gemini/Claude API endpoint",
),
(
"cloudcode-pa.googleapis.com",
"LLM API (gRPC)",
"Production Gemini/Claude API endpoint",
),
(
"api.anthropic.com",
"Claude API",
"Direct Anthropic API calls",
),
(
"lh3.googleusercontent.com",
"Profile Picture",
"User avatar",
),
("play.googleapis.com", "Telemetry", "Google Play telemetry"), ("play.googleapis.com", "Telemetry", "Google Play telemetry"),
("firebaseinstallations.googleapis.com", "Firebase", "Installation tracking"), (
"firebaseinstallations.googleapis.com",
"Firebase",
"Installation tracking",
),
("oauth2.googleapis.com", "OAuth", "Token refresh/exchange"), ("oauth2.googleapis.com", "OAuth", "Token refresh/exchange"),
("speech.googleapis.com", "Speech", "Voice input processing"), ("speech.googleapis.com", "Speech", "Voice input processing"),
("modelarmor.googleapis.com", "Safety", "Content safety/filtering"), (
"modelarmor.googleapis.com",
"Safety",
"Content safety/filtering",
),
]; ];
fn domain_label(domain: &str) -> (&str, &str) { fn domain_label(domain: &str) -> (&str, &str) {
@@ -57,8 +85,8 @@ struct HttpExchange {
#[derive(Debug, Clone, Copy, PartialEq)] #[derive(Debug, Clone, Copy, PartialEq)]
enum Direction { enum Direction {
Outgoing, // LS → upstream Outgoing, // LS → upstream
Incoming, // external → LS (our curl calls) Incoming, // external → LS (our curl calls)
} }
#[derive(Default)] #[derive(Default)]
@@ -101,10 +129,12 @@ impl Snapshot {
// LS process logs // LS process logs
if (line.starts_with('I') || line.starts_with('W') || line.starts_with('E')) if (line.starts_with('I') || line.starts_with('W') || line.starts_with('E'))
&& line.len() > 4 && line.chars().nth(1).is_some_and(|c| c.is_ascii_digit()) { && line.len() > 4
snap.ls_logs.push(line.to_string()); && line.chars().nth(1).is_some_and(|c| c.is_ascii_digit())
continue; {
} snap.ls_logs.push(line.to_string());
continue;
}
if line.contains("maxprocs:") { if line.contains("maxprocs:") {
snap.ls_logs.push(line.to_string()); snap.ls_logs.push(line.to_string());
continue; continue;
@@ -128,8 +158,15 @@ impl Snapshot {
if let Some((key, val)) = extract_header(line, "Transport encoding header") { if let Some((key, val)) = extract_header(line, "Transport encoding header") {
if key == ":method" { if key == ":method" {
// Finalize previous exchange // Finalize previous exchange
if current_pseudo.contains_key(":path") || current_pseudo.contains_key(":method") { if current_pseudo.contains_key(":path")
snap.finalize_exchange(&current_pseudo, &current_headers, current_direction, current_stream.clone()); || current_pseudo.contains_key(":method")
{
snap.finalize_exchange(
&current_pseudo,
&current_headers,
current_direction,
current_stream.clone(),
);
} }
current_headers.clear(); current_headers.clear();
current_pseudo.clear(); current_pseudo.clear();
@@ -147,8 +184,15 @@ impl Snapshot {
// Incoming / server-received headers // Incoming / server-received headers
if let Some((key, val)) = extract_header(line, "decoded hpack field header field") { if let Some((key, val)) = extract_header(line, "decoded hpack field header field") {
if key == ":authority" && !line.contains("server read frame") { if key == ":authority" && !line.contains("server read frame") {
if current_pseudo.contains_key(":path") || current_pseudo.contains_key(":method") { if current_pseudo.contains_key(":path")
snap.finalize_exchange(&current_pseudo, &current_headers, current_direction, current_stream.clone()); || current_pseudo.contains_key(":method")
{
snap.finalize_exchange(
&current_pseudo,
&current_headers,
current_direction,
current_stream.clone(),
);
} }
current_headers.clear(); current_headers.clear();
current_pseudo.clear(); current_pseudo.clear();
@@ -167,8 +211,15 @@ impl Snapshot {
if line.contains("wrote HEADERS") { if line.contains("wrote HEADERS") {
if let Some(stream) = extract_stream_id(line) { if let Some(stream) = extract_stream_id(line) {
current_stream = Some(stream.clone()); current_stream = Some(stream.clone());
if current_pseudo.contains_key(":path") || current_pseudo.contains_key(":method") { if current_pseudo.contains_key(":path")
let ex = snap.finalize_exchange(&current_pseudo, &current_headers, current_direction, Some(stream)); || current_pseudo.contains_key(":method")
{
let ex = snap.finalize_exchange(
&current_pseudo,
&current_headers,
current_direction,
Some(stream),
);
if ex.is_some() { if ex.is_some() {
current_headers.clear(); current_headers.clear();
current_pseudo.clear(); current_pseudo.clear();
@@ -179,10 +230,13 @@ impl Snapshot {
} }
// DATA frames // DATA frames
if (line.contains("wrote DATA") || line.contains("read DATA") || line.contains("server read frame DATA")) if (line.contains("wrote DATA")
|| line.contains("read DATA")
|| line.contains("server read frame DATA"))
&& line.contains("data=\"") && line.contains("data=\"")
{ {
let is_outgoing = line.contains("wrote DATA") || line.contains("server read frame DATA"); let is_outgoing =
line.contains("wrote DATA") || line.contains("server read frame DATA");
if let Some(stream) = extract_stream_id(line) { if let Some(stream) = extract_stream_id(line) {
if let Some(data_str) = extract_data(line) { if let Some(data_str) = extract_data(line) {
let raw = decode_go_escaped(&data_str); let raw = decode_go_escaped(&data_str);
@@ -203,7 +257,12 @@ impl Snapshot {
// Finalize remaining // Finalize remaining
if current_pseudo.contains_key(":path") || current_pseudo.contains_key(":method") { if current_pseudo.contains_key(":path") || current_pseudo.contains_key(":method") {
snap.finalize_exchange(&current_pseudo, &current_headers, current_direction, current_stream); snap.finalize_exchange(
&current_pseudo,
&current_headers,
current_direction,
current_stream,
);
} }
snap snap
@@ -226,7 +285,11 @@ impl Snapshot {
self.exchanges.push(HttpExchange { self.exchanges.push(HttpExchange {
authority, authority,
method: if method.is_empty() { "GET".into() } else { method }, method: if method.is_empty() {
"GET".into()
} else {
method
},
path, path,
headers: headers.to_vec(), headers: headers.to_vec(),
body: Vec::new(), body: Vec::new(),
@@ -245,7 +308,9 @@ impl Snapshot {
let sep = "".repeat(70); let sep = "".repeat(70);
let sep_thin = "".repeat(60); let sep_thin = "".repeat(60);
out.push_str(&format!("\n{BOLD}{CYAN}{sep}{NC}\n")); out.push_str(&format!("\n{BOLD}{CYAN}{sep}{NC}\n"));
out.push_str(&format!("{BOLD}{CYAN} STANDALONE LS TRAFFIC SNAPSHOT{NC}\n")); out.push_str(&format!(
"{BOLD}{CYAN} STANDALONE LS TRAFFIC SNAPSHOT{NC}\n"
));
out.push_str(&format!("{BOLD}{CYAN}{sep}{NC}\n\n")); out.push_str(&format!("{BOLD}{CYAN}{sep}{NC}\n\n"));
// LS Logs // LS Logs
@@ -265,7 +330,9 @@ impl Snapshot {
for target in &self.connections { for target in &self.connections {
let domain = target.split(':').next().unwrap_or(target); let domain = target.split(':').next().unwrap_or(target);
let (label, desc) = domain_label(domain); let (label, desc) = domain_label(domain);
out.push_str(&format!(" {GREEN}{NC} {BOLD}{target}{NC} {DIM}({label}){NC}\n")); out.push_str(&format!(
" {GREEN}{NC} {BOLD}{target}{NC} {DIM}({label}){NC}\n"
));
if !desc.is_empty() { if !desc.is_empty() {
out.push_str(&format!(" {DIM}{desc}{NC}\n")); out.push_str(&format!(" {DIM}{desc}{NC}\n"));
} }
@@ -276,7 +343,10 @@ impl Snapshot {
// Group by domain // Group by domain
let mut by_domain: Vec<(&str, Vec<&HttpExchange>)> = Vec::new(); let mut by_domain: Vec<(&str, Vec<&HttpExchange>)> = Vec::new();
for ex in &self.exchanges { for ex in &self.exchanges {
if let Some(entry) = by_domain.iter_mut().find(|(d, _)| *d == ex.authority.as_str()) { if let Some(entry) = by_domain
.iter_mut()
.find(|(d, _)| *d == ex.authority.as_str())
{
entry.1.push(ex); entry.1.push(ex);
} else { } else {
by_domain.push((&ex.authority, vec![ex])); by_domain.push((&ex.authority, vec![ex]));
@@ -293,12 +363,17 @@ impl Snapshot {
let color = if label.contains("API") { YELLOW } else { CYAN }; let color = if label.contains("API") { YELLOW } else { CYAN };
out.push_str(&format!("\n{BOLD}{sep}{NC}\n")); out.push_str(&format!("\n{BOLD}{sep}{NC}\n"));
out.push_str(&format!("{BOLD}{color} {domain}{NC} {DIM}{label}{NC}\n")); out.push_str(&format!(
"{BOLD}{color} {domain}{NC} {DIM}{label}{NC}\n"
));
out.push_str(&format!("{BOLD}{sep}{NC}\n")); out.push_str(&format!("{BOLD}{sep}{NC}\n"));
for ex in exchanges { for ex in exchanges {
let method_color = if ex.method == "GET" { GREEN } else { YELLOW }; let method_color = if ex.method == "GET" { GREEN } else { YELLOW };
out.push_str(&format!("\n {BOLD}{method_color}{}{NC} {}\n", ex.method, ex.path)); out.push_str(&format!(
"\n {BOLD}{method_color}{}{NC} {}\n",
ex.method, ex.path
));
// Interesting headers // Interesting headers
for (key, val) in &ex.headers { for (key, val) in &ex.headers {
@@ -342,7 +417,10 @@ fn render_body(data: &[u8], total_len: usize) -> String {
out.push_str(&format!(" {BOLD}Body ({len} bytes, JSON):{NC}\n")); out.push_str(&format!(" {BOLD}Body ({len} bytes, JSON):{NC}\n"));
for (i, line) in pretty.lines().enumerate() { for (i, line) in pretty.lines().enumerate() {
if i >= 40 { if i >= 40 {
out.push_str(&format!(" {DIM}... ({} more lines){NC}\n", pretty.lines().count() - 40)); out.push_str(&format!(
" {DIM}... ({} more lines){NC}\n",
pretty.lines().count() - 40
));
break; break;
} }
out.push_str(&format!(" {GREEN}{line}{NC}\n")); out.push_str(&format!(" {GREEN}{line}{NC}\n"));
@@ -357,10 +435,16 @@ fn render_body(data: &[u8], total_len: usize) -> String {
if let Ok(text) = std::str::from_utf8(&decompressed) { if let Ok(text) = std::str::from_utf8(&decompressed) {
if let Ok(val) = serde_json::from_str::<serde_json::Value>(text) { if let Ok(val) = serde_json::from_str::<serde_json::Value>(text) {
let pretty = serde_json::to_string_pretty(&val).unwrap_or_default(); let pretty = serde_json::to_string_pretty(&val).unwrap_or_default();
out.push_str(&format!(" {BOLD}Body ({len} bytes gzip → {} bytes, JSON):{NC}\n", decompressed.len())); out.push_str(&format!(
" {BOLD}Body ({len} bytes gzip → {} bytes, JSON):{NC}\n",
decompressed.len()
));
for (i, line) in pretty.lines().enumerate() { for (i, line) in pretty.lines().enumerate() {
if i >= 50 { if i >= 50 {
out.push_str(&format!(" {DIM}... ({} more lines){NC}\n", pretty.lines().count() - 50)); out.push_str(&format!(
" {DIM}... ({} more lines){NC}\n",
pretty.lines().count() - 50
));
break; break;
} }
out.push_str(&format!(" {GREEN}{line}{NC}\n")); out.push_str(&format!(" {GREEN}{line}{NC}\n"));
@@ -368,14 +452,20 @@ fn render_body(data: &[u8], total_len: usize) -> String {
return out; return out;
} }
// Plain text // Plain text
out.push_str(&format!(" {BOLD}Body ({len} bytes gzip → {} bytes, text):{NC}\n", decompressed.len())); out.push_str(&format!(
" {BOLD}Body ({len} bytes gzip → {} bytes, text):{NC}\n",
decompressed.len()
));
for line in text.lines().take(20) { for line in text.lines().take(20) {
out.push_str(&format!(" {line}\n")); out.push_str(&format!(" {line}\n"));
} }
return out; return out;
} }
// Binary gzip // Binary gzip
out.push_str(&format!(" {BOLD}Body ({len} bytes gzip → {} bytes, binary):{NC}\n", decompressed.len())); out.push_str(&format!(
" {BOLD}Body ({len} bytes gzip → {} bytes, binary):{NC}\n",
decompressed.len()
));
let strings = extract_strings(&decompressed); let strings = extract_strings(&decompressed);
for s in strings.iter().take(15) { for s in strings.iter().take(15) {
out.push_str(&format!(" {MAGENTA}{s}{NC}\n")); out.push_str(&format!(" {MAGENTA}{s}{NC}\n"));
@@ -393,7 +483,11 @@ fn render_body(data: &[u8], total_len: usize) -> String {
// Protobuf / binary with string extraction // Protobuf / binary with string extraction
let strings = extract_strings(data); let strings = extract_strings(data);
if !strings.is_empty() { if !strings.is_empty() {
let kind = if !data.is_empty() && matches!(data[0], 0x08 | 0x0a | 0x10 | 0x12 | 0x18 | 0x1a | 0x20 | 0x22) { let kind = if !data.is_empty()
&& matches!(
data[0],
0x08 | 0x0a | 0x10 | 0x12 | 0x18 | 0x1a | 0x20 | 0x22
) {
"protobuf" "protobuf"
} else { } else {
"binary" "binary"
@@ -448,7 +542,9 @@ fn extract_header(line: &str, pattern: &str) -> Option<(String, String)> {
fn extract_stream_id(line: &str) -> Option<String> { fn extract_stream_id(line: &str) -> Option<String> {
let pos = line.find("stream=")?; let pos = line.find("stream=")?;
let rest = &line[pos + 7..]; let rest = &line[pos + 7..];
let end = rest.find(|c: char| !c.is_ascii_digit()).unwrap_or(rest.len()); let end = rest
.find(|c: char| !c.is_ascii_digit())
.unwrap_or(rest.len());
Some(rest[..end].to_string()) Some(rest[..end].to_string())
} }
@@ -470,7 +566,9 @@ fn extract_data(line: &str) -> Option<String> {
fn extract_data_len(line: &str) -> Option<usize> { fn extract_data_len(line: &str) -> Option<usize> {
let pos = line.find("len=")?; let pos = line.find("len=")?;
let rest = &line[pos + 4..]; let rest = &line[pos + 4..];
let end = rest.find(|c: char| !c.is_ascii_digit()).unwrap_or(rest.len()); let end = rest
.find(|c: char| !c.is_ascii_digit())
.unwrap_or(rest.len());
rest[..end].parse().ok() rest[..end].parse().ok()
} }
@@ -482,17 +580,40 @@ fn decode_go_escaped(s: &str) -> Vec<u8> {
if bytes[i] == b'\\' && i + 1 < bytes.len() { if bytes[i] == b'\\' && i + 1 < bytes.len() {
match bytes[i + 1] { match bytes[i + 1] {
b'x' if i + 3 < bytes.len() => { b'x' if i + 3 < bytes.len() => {
if let Ok(b) = u8::from_str_radix(std::str::from_utf8(&bytes[i + 2..i + 4]).unwrap_or(""), 16) { if let Ok(b) = u8::from_str_radix(
std::str::from_utf8(&bytes[i + 2..i + 4]).unwrap_or(""),
16,
) {
result.push(b); result.push(b);
i += 4; i += 4;
continue; continue;
} }
} }
b'n' => { result.push(b'\n'); i += 2; continue; } b'n' => {
b'r' => { result.push(b'\r'); i += 2; continue; } result.push(b'\n');
b't' => { result.push(b'\t'); i += 2; continue; } i += 2;
b'\\' => { result.push(b'\\'); i += 2; continue; } continue;
b'"' => { result.push(b'"'); i += 2; continue; } }
b'r' => {
result.push(b'\r');
i += 2;
continue;
}
b't' => {
result.push(b'\t');
i += 2;
continue;
}
b'\\' => {
result.push(b'\\');
i += 2;
continue;
}
b'"' => {
result.push(b'"');
i += 2;
continue;
}
_ => {} _ => {}
} }
} }
@@ -562,7 +683,10 @@ pub fn run_cli() {
}) })
} else { } else {
let mut buf = String::new(); let mut buf = String::new();
io::stdin().lock().read_to_string(&mut buf).expect("Failed to read stdin"); io::stdin()
.lock()
.read_to_string(&mut buf)
.expect("Failed to read stdin");
buf buf
}; };

View File

@@ -108,7 +108,10 @@ pub struct MainLSConfig {
/// and CSRF is a random UUID. /// and CSRF is a random UUID.
pub fn generate_standalone_config() -> MainLSConfig { pub fn generate_standalone_config() -> MainLSConfig {
let csrf = Uuid::new_v4().to_string(); let csrf = Uuid::new_v4().to_string();
info!(csrf_len = csrf.len(), "Generated standalone config (headless)"); info!(
csrf_len = csrf.len(),
"Generated standalone config (headless)"
);
MainLSConfig { MainLSConfig {
extension_server_port: "0".to_string(), // disables extension server extension_server_port: "0".to_string(), // disables extension server
csrf, csrf,
@@ -159,7 +162,13 @@ impl StandaloneLS {
let app_data_dir = format!("{DATA_DIR}/.gemini/antigravity-standalone"); let app_data_dir = format!("{DATA_DIR}/.gemini/antigravity-standalone");
let annotations_dir = format!("{app_data_dir}/annotations"); let annotations_dir = format!("{app_data_dir}/annotations");
let brain_dir = format!("{app_data_dir}/brain"); let brain_dir = format!("{app_data_dir}/brain");
for dir in [DATA_DIR, &gemini_dir, &app_data_dir, &annotations_dir, &brain_dir] { for dir in [
DATA_DIR,
&gemini_dir,
&app_data_dir,
&annotations_dir,
&brain_dir,
] {
let _ = std::fs::create_dir_all(dir); let _ = std::fs::create_dir_all(dir);
#[cfg(unix)] #[cfg(unix)]
{ {
@@ -194,7 +203,10 @@ impl StandaloneLS {
#[cfg(unix)] #[cfg(unix)]
{ {
use std::os::unix::fs::PermissionsExt; use std::os::unix::fs::PermissionsExt;
let _ = std::fs::set_permissions(&settings_path, std::fs::Permissions::from_mode(0o0666)); let _ = std::fs::set_permissions(
&settings_path,
std::fs::Permissions::from_mode(0o0666),
);
} }
tracing::info!("Pre-seeded user_settings.pb (detect_and_use_proxy=ENABLED)"); tracing::info!("Pre-seeded user_settings.pb (detect_and_use_proxy=ENABLED)");
} }
@@ -203,10 +215,7 @@ impl StandaloneLS {
// The LS connects to this port and calls LanguageServerStarted — without it, // The LS connects to this port and calls LanguageServerStarted — without it,
// the LS never fully initializes and won't accept connections on its server_port. // the LS never fully initializes and won't accept connections on its server_port.
let _stub_listener = if headless { let _stub_listener = if headless {
let stub_port: u16 = main_config let stub_port: u16 = main_config.extension_server_port.parse().unwrap_or(0);
.extension_server_port
.parse()
.unwrap_or(0);
if stub_port == 0 { if stub_port == 0 {
// Create a real listener so the LS can connect // Create a real listener so the LS can connect
let listener = TcpListener::bind("127.0.0.1:0") let listener = TcpListener::bind("127.0.0.1:0")
@@ -215,7 +224,10 @@ impl StandaloneLS {
.local_addr() .local_addr()
.map_err(|e| format!("Failed to get stub port: {e}"))? .map_err(|e| format!("Failed to get stub port: {e}"))?
.port(); .port();
info!(port = actual_port, "Stub extension server listening (headless)"); info!(
port = actual_port,
"Stub extension server listening (headless)"
);
// Read OAuth state from Antigravity's state.vscdb if available. // Read OAuth state from Antigravity's state.vscdb if available.
// The DB stores the exact Topic proto (access_token + refresh_token + expiry) // The DB stores the exact Topic proto (access_token + refresh_token + expiry)
// which lets the LS auto-refresh tokens via its built-in Google OAuth2 client. // which lets the LS auto-refresh tokens via its built-in Google OAuth2 client.
@@ -306,10 +318,7 @@ impl StandaloneLS {
// 3. MITM proxy intercepts the transparent TLS connection via SNI // 3. MITM proxy intercepts the transparent TLS connection via SNI
if let Some(mitm) = mitm_config { if let Some(mitm) = mitm_config {
// Extract port from proxy_addr (e.g. "http://127.0.0.1:8742" → "8742") // Extract port from proxy_addr (e.g. "http://127.0.0.1:8742" → "8742")
let mitm_port = mitm.proxy_addr let mitm_port = mitm.proxy_addr.rsplit(':').next().unwrap_or("8742");
.rsplit(':')
.next()
.unwrap_or("8742");
format!("https://daily-cloudcode-pa.googleapis.com:{mitm_port}") format!("https://daily-cloudcode-pa.googleapis.com:{mitm_port}")
} else { } else {
"https://daily-cloudcode-pa.googleapis.com".to_string() "https://daily-cloudcode-pa.googleapis.com".to_string()
@@ -324,9 +333,8 @@ impl StandaloneLS {
debug!(?args, "LS args"); debug!(?args, "LS args");
// Build env vars for the LS process // Build env vars for the LS process
let mut env_vars: Vec<(String, String)> = vec![ let mut env_vars: Vec<(String, String)> =
("ANTIGRAVITY_EDITOR_APP_ROOT".into(), APP_ROOT.into()), vec![("ANTIGRAVITY_EDITOR_APP_ROOT".into(), APP_ROOT.into())];
];
// If MITM is enabled, add SSL + proxy env vars // If MITM is enabled, add SSL + proxy env vars
if let Some(mitm) = mitm_config { if let Some(mitm) = mitm_config {
@@ -335,8 +343,8 @@ impl StandaloneLS {
// Write to /tmp — accessible by antigravity-ls user // Write to /tmp — accessible by antigravity-ls user
// (user's ~/.config/ is not traversable by other UIDs) // (user's ~/.config/ is not traversable by other UIDs)
let combined_ca_path = "/tmp/antigravity-mitm-combined-ca.pem".to_string(); let combined_ca_path = "/tmp/antigravity-mitm-combined-ca.pem".to_string();
let system_ca = std::fs::read_to_string("/etc/ssl/certs/ca-certificates.crt") let system_ca =
.unwrap_or_default(); std::fs::read_to_string("/etc/ssl/certs/ca-certificates.crt").unwrap_or_default();
let mitm_ca = std::fs::read_to_string(&mitm.ca_cert_path) let mitm_ca = std::fs::read_to_string(&mitm.ca_cert_path)
.map_err(|e| format!("Failed to read MITM CA cert: {e}"))?; .map_err(|e| format!("Failed to read MITM CA cert: {e}"))?;
std::fs::write(&combined_ca_path, format!("{system_ca}\n{mitm_ca}")) std::fs::write(&combined_ca_path, format!("{system_ca}\n{mitm_ca}"))
@@ -441,7 +449,11 @@ impl StandaloneLS {
}; };
if let Some(pid) = ls_pid { if let Some(pid) = ls_pid {
info!(ls_pid = pid, sudo = use_sudo, "Discovered actual LS process"); info!(
ls_pid = pid,
sudo = use_sudo,
"Discovered actual LS process"
);
} }
Ok(StandaloneLS { Ok(StandaloneLS {
@@ -617,8 +629,7 @@ fn find_main_ls_pid() -> Result<String, String> {
return Err("No /proc filesystem".to_string()); return Err("No /proc filesystem".to_string());
} }
let entries = std::fs::read_dir(proc) let entries = std::fs::read_dir(proc).map_err(|e| format!("Cannot read /proc: {e}"))?;
.map_err(|e| format!("Cannot read /proc: {e}"))?;
for entry in entries.flatten() { for entry in entries.flatten() {
let name = entry.file_name(); let name = entry.file_name();
@@ -704,12 +715,10 @@ fn cleanup_orphaned_ls() {
.output(); .output();
let pids: Vec<u32> = match output { let pids: Vec<u32> = match output {
Ok(out) => { Ok(out) => String::from_utf8_lossy(&out.stdout)
String::from_utf8_lossy(&out.stdout) .lines()
.lines() .filter_map(|l| l.trim().parse().ok())
.filter_map(|l| l.trim().parse().ok()) .collect(),
.collect()
}
Err(_) => return, Err(_) => return,
}; };
@@ -717,7 +726,11 @@ fn cleanup_orphaned_ls() {
return; return;
} }
info!(count = pids.len(), ?pids, "Cleaning up orphaned standalone LS processes"); info!(
count = pids.len(),
?pids,
"Cleaning up orphaned standalone LS processes"
);
// Kill each PID by running `kill` AS the antigravity-ls user. // Kill each PID by running `kill` AS the antigravity-ls user.
// This works because same-UID processes can signal each other, // This works because same-UID processes can signal each other,
@@ -870,7 +883,8 @@ fn extract_access_token_from_topic(topic_bytes: &[u8]) -> Option<String> {
// Simple approach: convert to string and find base64 pattern // Simple approach: convert to string and find base64 pattern
let as_str = String::from_utf8_lossy(topic_bytes); let as_str = String::from_utf8_lossy(topic_bytes);
// The base64 OAuthTokenInfo starts with "Co" (0x0A = field 1, len-delimited) // The base64 OAuthTokenInfo starts with "Co" (0x0A = field 1, len-delimited)
for segment in as_str.split(|c: char| !c.is_alphanumeric() && c != '+' && c != '/' && c != '=') { for segment in as_str.split(|c: char| !c.is_alphanumeric() && c != '+' && c != '/' && c != '=')
{
if segment.len() > 50 { if segment.len() > 50 {
if let Ok(decoded) = base64::engine::general_purpose::STANDARD.decode(segment) { if let Ok(decoded) = base64::engine::general_purpose::STANDARD.decode(segment) {
// Try to extract field 1 (access_token) from the OAuthTokenInfo proto // Try to extract field 1 (access_token) from the OAuthTokenInfo proto
@@ -951,7 +965,11 @@ fn decode_varint_at(buf: &[u8], offset: usize) -> Option<(u64, usize)> {
/// IMPORTANT: `SubscribeToUnifiedStateSyncTopic` is a long-lived stream. /// IMPORTANT: `SubscribeToUnifiedStateSyncTopic` is a long-lived stream.
/// If we immediately close it, the LS reconnects in a tight loop and never /// If we immediately close it, the LS reconnects in a tight loop and never
/// proceeds to fetch OAuth tokens. We keep subscription connections OPEN. /// proceeds to fetch OAuth tokens. We keep subscription connections OPEN.
fn stub_handle_connection(conn: std::net::TcpStream, oauth_token: &str, oauth_topic_bytes: &Option<Vec<u8>>) { fn stub_handle_connection(
conn: std::net::TcpStream,
oauth_token: &str,
oauth_topic_bytes: &Option<Vec<u8>>,
) {
use std::io::{BufRead, BufReader, Read, Write}; use std::io::{BufRead, BufReader, Read, Write};
let mut reader = BufReader::new(match conn.try_clone() { let mut reader = BufReader::new(match conn.try_clone() {
@@ -1028,7 +1046,7 @@ fn stub_handle_connection(conn: std::net::TcpStream, oauth_token: &str, oauth_to
i += 1; i += 1;
if i + len <= proto_body.len() { if i + len <= proto_body.len() {
if field_num == 1 { if field_num == 1 {
topic_name = String::from_utf8_lossy(&proto_body[i..i+len]).to_string(); topic_name = String::from_utf8_lossy(&proto_body[i..i + len]).to_string();
} }
i += len; i += len;
} else { } else {
@@ -1084,7 +1102,10 @@ fn stub_handle_connection(conn: std::net::TcpStream, oauth_token: &str, oauth_to
// This includes access_token + refresh_token + expiry, so the // This includes access_token + refresh_token + expiry, so the
// LS can auto-refresh tokens via its built-in Google OAuth2 client. // LS can auto-refresh tokens via its built-in Google OAuth2 client.
initial_state_bytes = topic_bytes.clone(); initial_state_bytes = topic_bytes.clone();
eprintln!("[stub-ext] using state.vscdb topic ({} bytes)", topic_bytes.len()); eprintln!(
"[stub-ext] using state.vscdb topic ({} bytes)",
topic_bytes.len()
);
} else if !oauth_token.is_empty() { } else if !oauth_token.is_empty() {
// Manual token fallback — construct OAuthTokenInfo with far-future expiry // Manual token fallback — construct OAuthTokenInfo with far-future expiry
// (no refresh_token, so the LS can't auto-refresh) // (no refresh_token, so the LS can't auto-refresh)
@@ -1155,7 +1176,10 @@ fn stub_handle_connection(conn: std::net::TcpStream, oauth_token: &str, oauth_to
if !send_chunk(&mut writer, &initial_env) { if !send_chunk(&mut writer, &initial_env) {
return; return;
} }
eprintln!("[stub-ext] STREAM → sent initial_state ({} bytes)", initial_state_bytes.len()); eprintln!(
"[stub-ext] STREAM → sent initial_state ({} bytes)",
initial_state_bytes.len()
);
// (applied_update removed — data is in initial_state) // (applied_update removed — data is in initial_state)
@@ -1197,7 +1221,10 @@ fn stub_handle_connection(conn: std::net::TcpStream, oauth_token: &str, oauth_to
if !oauth_token.is_empty() { if !oauth_token.is_empty() {
// Build protobuf: GetSecretValueResponse { string value = 1 } // Build protobuf: GetSecretValueResponse { string value = 1 }
let proto = encode_proto_string(1, oauth_token.as_bytes()); let proto = encode_proto_string(1, oauth_token.as_bytes());
eprintln!("[stub-ext] → serving token ({} bytes) for key={key:?}", oauth_token.len()); eprintln!(
"[stub-ext] → serving token ({} bytes) for key={key:?}",
oauth_token.len()
);
// Data envelope: flag=0x00, length, data // Data envelope: flag=0x00, length, data
envelope.push(0x00u8); envelope.push(0x00u8);

View File

@@ -34,7 +34,9 @@ pub async fn warmup_sequence(backend: &Backend, headless: bool) {
) )
.await .await
{ {
Ok(Ok((status, _))) => info!("SetUserSettings (detect_and_use_proxy=ENABLED): {status}"), Ok(Ok((status, _))) => {
info!("SetUserSettings (detect_and_use_proxy=ENABLED): {status}")
}
Ok(Err(e)) => warn!("SetUserSettings failed: {e}"), Ok(Err(e)) => warn!("SetUserSettings failed: {e}"),
Err(_) => warn!("SetUserSettings timed out"), Err(_) => warn!("SetUserSettings timed out"),
} }
@@ -59,12 +61,7 @@ pub async fn warmup_sequence(backend: &Backend, headless: bool) {
for (method, body) in calls { for (method, body) in calls {
// Timeout per call — in headless mode, the LS can't reach Google's API // Timeout per call — in headless mode, the LS can't reach Google's API
// so these would hang forever without a timeout. Warmup is best-effort. // so these would hang forever without a timeout. Warmup is best-effort.
match tokio::time::timeout( match tokio::time::timeout(Duration::from_secs(5), backend.call_json(method, body)).await {
Duration::from_secs(5),
backend.call_json(method, body),
)
.await
{
Ok(Ok((status, _))) => debug!("Warmup {method}: {status}"), Ok(Ok((status, _))) => debug!("Warmup {method}: {status}"),
Ok(Err(e)) => warn!("Warmup {method} failed: {e}"), Ok(Err(e)) => warn!("Warmup {method} failed: {e}"),
Err(_) => warn!("Warmup {method} timed out"), Err(_) => warn!("Warmup {method} timed out"),
@@ -87,10 +84,7 @@ pub fn start_heartbeat(backend: Arc<Backend>) -> JoinHandle<()> {
let interval_ms = rand::thread_rng().gen_range(29_500..30_500); let interval_ms = rand::thread_rng().gen_range(29_500..30_500);
tokio::time::sleep(Duration::from_millis(interval_ms)).await; tokio::time::sleep(Duration::from_millis(interval_ms)).await;
match backend match backend.call_json("Heartbeat", &serde_json::json!({})).await {
.call_json("Heartbeat", &serde_json::json!({}))
.await
{
Ok((status, _)) => debug!("Heartbeat: {status}"), Ok((status, _)) => debug!("Heartbeat: {status}"),
Err(e) => warn!("Heartbeat failed: {e}"), Err(e) => warn!("Heartbeat failed: {e}"),
} }