fix: block ALL LS follow-up requests across connections

Move the in-flight blocking check to the top of the LLM request flow,
BEFORE request modification. This catches follow-ups on ALL connections
(the LS opens multiple parallel TLS connections). Only the very first
modified request reaches Google — all others get fake STOP responses.

Previously, each new connection independently allowed one request
through before blocking, letting 4-5 requests leak per turn.
This commit is contained in:
Nikketryhard
2026-02-16 00:57:33 -06:00
parent a8f3c8915f
commit 3fdd0368a0
23 changed files with 992 additions and 568 deletions

View File

@@ -10,9 +10,11 @@ use std::sync::Arc;
use tracing::{debug, info, warn};
use super::models::{lookup_model, DEFAULT_MODEL, MODELS};
use super::polling::{extract_response_text, extract_thinking_content, is_response_done, poll_for_response};
use super::polling::{
extract_response_text, extract_thinking_content, is_response_done, poll_for_response,
};
use super::types::*;
use super::util::{err_response, upstream_err_response, now_unix};
use super::util::{err_response, now_unix, upstream_err_response};
use super::AppState;
/// Extract a conversation/session ID from a flexible JSON value.
@@ -33,7 +35,8 @@ fn system_fingerprint() -> String {
/// Build a streaming chunk JSON with all required OpenAI fields.
/// Includes system_fingerprint, service_tier, and logprobs:null in choices.
fn chunk_json(
id: &str, model: &str,
id: &str,
model: &str,
choices: serde_json::Value,
usage: Option<serde_json::Value>,
) -> String {
@@ -53,7 +56,11 @@ fn chunk_json(
}
/// Build a single choice for a streaming chunk (delta + finish_reason + logprobs).
fn chunk_choice(index: u32, delta: serde_json::Value, finish_reason: Option<&str>) -> serde_json::Value {
fn chunk_choice(
index: u32,
delta: serde_json::Value,
finish_reason: Option<&str>,
) -> serde_json::Value {
serde_json::json!({
"index": index,
"delta": delta,
@@ -70,7 +77,9 @@ fn chunk_choice(index: u32, delta: serde_json::Value, finish_reason: Option<&str
fn google_to_openai_finish_reason(stop_reason: Option<&str>) -> &'static str {
match stop_reason {
Some("MAX_TOKENS") => "length",
Some("SAFETY") | Some("RECITATION") | Some("BLOCKLIST") | Some("PROHIBITED_CONTENT") => "content_filter",
Some("SAFETY") | Some("RECITATION") | Some("BLOCKLIST") | Some("PROHIBITED_CONTENT") => {
"content_filter"
}
_ => "stop",
}
}
@@ -84,7 +93,9 @@ fn google_to_openai_finish_reason(stop_reason: Option<&str>) -> &'static str {
/// sends the entire messages array to the model.
fn extract_chat_input(messages: &[CompletionMessage]) -> (String, Option<crate::proto::ImageData>) {
// Extract image from last user message content array
let image = messages.iter().rev()
let image = messages
.iter()
.rev()
.find(|m| m.role == "user")
.and_then(|m| super::util::extract_first_image(&m.content));
// Always build the full conversation
@@ -141,10 +152,7 @@ fn build_conversation_with_tools(messages: &[CompletionMessage]) -> String {
if let Some(func) = tc.get("function") {
let name = func["name"].as_str().unwrap_or("unknown");
let args = func["arguments"].as_str().unwrap_or("{}");
parts.push(format!(
"[Tool call: {}({})]",
name, args
));
parts.push(format!("[Tool call: {}({})]", name, args));
}
}
}
@@ -153,10 +161,7 @@ fn build_conversation_with_tools(messages: &[CompletionMessage]) -> String {
let text = extract_message_text(&msg.content);
let tool_id = msg.tool_call_id.as_deref().unwrap_or("unknown");
if !text.is_empty() {
parts.push(format!(
"[Tool result ({})]:\n{}",
tool_id, text
));
parts.push(format!("[Tool result ({})]:\n{}", tool_id, text));
}
}
_ => {}
@@ -202,7 +207,10 @@ pub(crate) async fn handle_completions(
let gemini_config = crate::mitm::modify::openai_tool_choice_to_gemini(choice);
state.mitm_store.set_tool_config(gemini_config).await;
}
info!(count = tools.len(), "Completions: stored client tools for MITM injection");
info!(
count = tools.len(),
"Completions: stored client tools for MITM injection"
);
} else {
state.mitm_store.clear_tools().await;
}
@@ -239,10 +247,15 @@ pub(crate) async fn handle_completions(
google_search: body.web_search,
};
// Only store if at least one param is set
if gp.temperature.is_some() || gp.top_p.is_some() || gp.max_output_tokens.is_some()
|| gp.frequency_penalty.is_some() || gp.presence_penalty.is_some()
|| gp.reasoning_effort.is_some() || gp.stop_sequences.is_some()
|| gp.response_mime_type.is_some() || gp.response_schema.is_some()
if gp.temperature.is_some()
|| gp.top_p.is_some()
|| gp.max_output_tokens.is_some()
|| gp.frequency_penalty.is_some()
|| gp.presence_penalty.is_some()
|| gp.reasoning_effort.is_some()
|| gp.stop_sequences.is_some()
|| gp.response_mime_type.is_some()
|| gp.response_schema.is_some()
|| gp.google_search
{
state.mitm_store.set_generation_params(gp).await;
@@ -306,12 +319,13 @@ pub(crate) async fn handle_completions(
// Store image for MITM injection (LS doesn't forward images to Google API)
if let Some(ref img) = image {
use base64::Engine;
state.mitm_store.set_pending_image(
crate::mitm::store::PendingImage {
state
.mitm_store
.set_pending_image(crate::mitm::store::PendingImage {
base64_data: base64::engine::general_purpose::STANDARD.encode(&img.data),
mime_type: img.mime_type.clone(),
}
).await;
})
.await;
}
match state
.backend
@@ -346,7 +360,10 @@ pub(crate) async fn handle_completions(
uuid::Uuid::new_v4().to_string().replace('-', "")
);
let include_usage = body.stream_options.as_ref().map_or(false, |o| o.include_usage);
let include_usage = body
.stream_options
.as_ref()
.map_or(false, |o| o.include_usage);
if body.stream {
chat_completions_stream(
@@ -374,11 +391,17 @@ pub(crate) async fn handle_completions(
match state.backend.create_cascade().await {
Ok(cid) => {
// Send the same message on each extra cascade
match state.backend.send_message_with_image(&cid, &user_text, model.model_enum, image.as_ref()).await {
match state
.backend
.send_message_with_image(&cid, &user_text, model.model_enum, image.as_ref())
.await
{
Ok((200, _)) => {
let bg = Arc::clone(&state.backend);
let cid2 = cid.clone();
tokio::spawn(async move { let _ = bg.update_annotations(&cid2).await; });
tokio::spawn(async move {
let _ = bg.update_annotations(&cid2).await;
});
extra_cascade_ids.push(cid);
}
_ => {} // Skip failed cascades
@@ -420,7 +443,12 @@ pub(crate) async fn handle_completions(
mitm.as_ref().and_then(|u| u.stop_reason.as_deref()),
);
let (pt, ct, cached, thinking) = if let Some(ref mu) = mitm {
(mu.input_tokens, mu.output_tokens, mu.cache_read_input_tokens, mu.thinking_output_tokens)
(
mu.input_tokens,
mu.output_tokens,
mu.cache_read_input_tokens,
mu.thinking_output_tokens,
)
} else if let Some(u) = &result.usage {
(u.input_tokens, u.output_tokens, 0, 0)
} else {
@@ -874,15 +902,22 @@ async fn chat_completions_sync(
None => state.mitm_store.take_usage("_latest").await,
};
let finish_reason = google_to_openai_finish_reason(mitm.as_ref().and_then(|u| u.stop_reason.as_deref()));
let finish_reason =
google_to_openai_finish_reason(mitm.as_ref().and_then(|u| u.stop_reason.as_deref()));
let (prompt_tokens, completion_tokens, cached_tokens, thinking_tokens) = if let Some(ref mitm_usage) = mitm {
(mitm_usage.input_tokens, mitm_usage.output_tokens, mitm_usage.cache_read_input_tokens, mitm_usage.thinking_output_tokens)
} else if let Some(u) = &result.usage {
(u.input_tokens, u.output_tokens, 0, 0)
} else {
(0, 0, 0, 0)
};
let (prompt_tokens, completion_tokens, cached_tokens, thinking_tokens) =
if let Some(ref mitm_usage) = mitm {
(
mitm_usage.input_tokens,
mitm_usage.output_tokens,
mitm_usage.cache_read_input_tokens,
mitm_usage.thinking_output_tokens,
)
} else if let Some(u) = &result.usage {
(u.input_tokens, u.output_tokens, 0, 0)
} else {
(0, 0, 0, 0)
};
// Build message object, including reasoning_content if thinking is present
let mut message = serde_json::json!({

View File

@@ -15,7 +15,9 @@ use std::sync::Arc;
use tracing::{info, warn};
use super::models::{lookup_model, DEFAULT_MODEL, MODELS};
use super::polling::{extract_response_text, extract_thinking_content, is_response_done, poll_for_response};
use super::polling::{
extract_response_text, extract_thinking_content, is_response_done, poll_for_response,
};
use super::util::{err_response, upstream_err_response};
use super::AppState;
use crate::mitm::store::PendingToolResult;
@@ -84,7 +86,9 @@ async fn build_usage_metadata(
store: &crate::mitm::store::MitmStore,
cascade_id: &str,
) -> serde_json::Value {
let usage = store.take_usage(cascade_id).await
let usage = store
.take_usage(cascade_id)
.await
.or(store.take_usage("_latest").await);
if let Some(usage) = usage {
serde_json::json!({
@@ -152,13 +156,12 @@ pub(crate) async fn handle_gemini(
// Gemini-native inlineData format
if image.is_none() {
if let Some(inline) = obj.get("inlineData") {
if let (Some(mime), Some(b64)) = (
inline["mimeType"].as_str(),
inline["data"].as_str(),
) {
if let Some(img) = super::util::parse_data_uri(
&format!("data:{mime};base64,{b64}")
) {
if let (Some(mime), Some(b64)) =
(inline["mimeType"].as_str(), inline["data"].as_str())
{
if let Some(img) = super::util::parse_data_uri(&format!(
"data:{mime};base64,{b64}"
)) {
image = Some(img);
}
}
@@ -194,7 +197,10 @@ pub(crate) async fn handle_gemini(
if let Some(ref tools) = body.tools {
if !tools.is_empty() {
state.mitm_store.set_tools(tools.clone()).await;
info!(count = tools.len(), "Stored Gemini-native tools for MITM injection");
info!(
count = tools.len(),
"Stored Gemini-native tools for MITM injection"
);
}
}
if let Some(ref config) = body.tool_config {
@@ -207,13 +213,19 @@ pub(crate) async fn handle_gemini(
if let Some(fr) = r.get("functionResponse") {
let name = fr["name"].as_str().unwrap_or("unknown").to_string();
let response = fr.get("response").cloned().unwrap_or(serde_json::json!({}));
state.mitm_store.add_tool_result(PendingToolResult {
name,
result: response,
}).await;
state
.mitm_store
.add_tool_result(PendingToolResult {
name,
result: response,
})
.await;
}
}
info!(count = results.len(), "Stored Gemini-native tool results for MITM injection");
info!(
count = results.len(),
"Stored Gemini-native tool results for MITM injection"
);
}
// Store generation parameters for MITM injection
@@ -232,9 +244,13 @@ pub(crate) async fn handle_gemini(
response_schema: None,
google_search: body.google_search,
};
if gp.temperature.is_some() || gp.top_p.is_some() || gp.top_k.is_some()
|| gp.max_output_tokens.is_some() || gp.stop_sequences.is_some()
|| gp.reasoning_effort.is_some() || gp.google_search
if gp.temperature.is_some()
|| gp.top_p.is_some()
|| gp.top_k.is_some()
|| gp.max_output_tokens.is_some()
|| gp.stop_sequences.is_some()
|| gp.reasoning_effort.is_some()
|| gp.google_search
{
state.mitm_store.set_generation_params(gp).await;
} else {
@@ -277,12 +293,13 @@ pub(crate) async fn handle_gemini(
// Store image for MITM injection (LS doesn't forward images to Google API)
if let Some(ref img) = image {
use base64::Engine;
state.mitm_store.set_pending_image(
crate::mitm::store::PendingImage {
state
.mitm_store
.set_pending_image(crate::mitm::store::PendingImage {
base64_data: base64::engine::general_purpose::STANDARD.encode(&img.data),
mime_type: img.mime_type.clone(),
}
).await;
})
.await;
}
match state
.backend
@@ -372,7 +389,11 @@ async fn gemini_sync(
// Check for completed text response
if state.mitm_store.is_response_complete() {
let text = state.mitm_store.take_response_text().await.unwrap_or_default();
let text = state
.mitm_store
.take_response_text()
.await
.unwrap_or_default();
let thinking = state.mitm_store.take_thinking_text().await;
// Guard against stale response_complete with no data

View File

@@ -44,7 +44,6 @@ pub fn router(state: Arc<AppState>) -> Router {
post(completions::handle_completions),
)
.route("/v1/gemini", post(gemini::handle_gemini))
.route("/v1/models", get(handle_models))
.route("/v1/sessions", get(handle_list_sessions))
.route("/v1/sessions/{id}", delete(handle_delete_session))
@@ -106,9 +105,7 @@ async fn handle_models() -> Json<serde_json::Value> {
Json(serde_json::json!({"object": "list", "data": models}))
}
async fn handle_list_sessions(
State(state): State<Arc<AppState>>,
) -> Json<serde_json::Value> {
async fn handle_list_sessions(State(state): State<Arc<AppState>>) -> Json<serde_json::Value> {
let sessions = state.sessions.list_sessions().await;
Json(serde_json::json!({"sessions": sessions}))
}
@@ -155,9 +152,7 @@ async fn handle_set_token(
)
}
async fn handle_usage(
State(state): State<Arc<AppState>>,
) -> Json<serde_json::Value> {
async fn handle_usage(State(state): State<Arc<AppState>>) -> Json<serde_json::Value> {
let stats = state.mitm_store.stats().await;
Json(serde_json::json!({
"mitm": {
@@ -174,9 +169,7 @@ async fn handle_usage(
}))
}
async fn handle_quota(
State(state): State<Arc<AppState>>,
) -> Json<serde_json::Value> {
async fn handle_quota(State(state): State<Arc<AppState>>) -> Json<serde_json::Value> {
let snap = state.quota_store.snapshot().await;
Json(serde_json::to_value(snap).unwrap_or_default())
}

View File

@@ -84,14 +84,8 @@ pub(crate) fn extract_model_usage(steps: &[serde_json::Value]) -> Option<ModelUs
return Some(ModelUsage {
input_tokens: input,
output_tokens: output,
api_provider: usage["apiProvider"]
.as_str()
.unwrap_or("")
.to_string(),
model: usage["model"]
.as_str()
.unwrap_or("")
.to_string(),
api_provider: usage["apiProvider"].as_str().unwrap_or("").to_string(),
model: usage["model"].as_str().unwrap_or("").to_string(),
});
}
}
@@ -263,23 +257,36 @@ pub(crate) async fn poll_for_response(
} else {
info!(
"Response done ({short_id}), {:.1}s, {} chars (no usage){}{}",
elapsed, text.len(),
thinking.as_ref().map_or(String::new(), |t| format!(", thinking: {} chars", t.len())),
if thinking_signature.is_some() { ", has sig" } else { "" }
elapsed,
text.len(),
thinking.as_ref().map_or(String::new(), |t| format!(
", thinking: {} chars",
t.len()
)),
if thinking_signature.is_some() {
", has sig"
} else {
""
}
);
}
return PollResult { text, usage, thinking_signature, thinking, thinking_duration, upstream_error: None };
return PollResult {
text,
usage,
thinking_signature,
thinking,
thinking_duration,
upstream_error: None,
};
}
}
// Fallback: check trajectory IDLE status (catches edge cases)
// Only check every 5th poll to reduce network calls
if step_count > 4 && step_count % 5 == 0 {
if let Ok((ts, td)) = state.backend.get_trajectory(cascade_id).await
{
if let Ok((ts, td)) = state.backend.get_trajectory(cascade_id).await {
if ts == 200 {
let run_status =
td["status"].as_str().unwrap_or("");
let run_status = td["status"].as_str().unwrap_or("");
if run_status.contains("IDLE") {
let text = extract_response_text(steps);
if !text.is_empty() {
@@ -293,7 +300,14 @@ pub(crate) async fn poll_for_response(
elapsed,
text.len()
);
return PollResult { text, usage, thinking_signature, thinking, thinking_duration, upstream_error: None };
return PollResult {
text,
usage,
thinking_signature,
thinking,
thinking_duration,
upstream_error: None,
};
}
}
}

View File

@@ -14,12 +14,15 @@ use std::sync::Arc;
use tracing::{debug, info};
use super::models::{lookup_model, DEFAULT_MODEL, MODELS};
use super::polling::{extract_response_text, is_response_done, poll_for_response, extract_model_usage, extract_thinking_signature, extract_thinking_content};
use super::polling::{
extract_model_usage, extract_response_text, extract_thinking_content,
extract_thinking_signature, is_response_done, poll_for_response,
};
use super::types::*;
use super::util::{err_response, upstream_err_response, now_unix, responses_sse_event};
use super::util::{err_response, now_unix, responses_sse_event, upstream_err_response};
use super::AppState;
use crate::mitm::modify::{openai_tool_choice_to_gemini, openai_tools_to_gemini};
use crate::mitm::store::PendingToolResult;
use crate::mitm::modify::{openai_tools_to_gemini, openai_tool_choice_to_gemini};
// ─── Input extraction ────────────────────────────────────────────────────────
@@ -35,7 +38,11 @@ struct ToolResultInput {
fn extract_responses_input(
input: &serde_json::Value,
instructions: Option<&str>,
) -> (String, Vec<ToolResultInput>, Option<crate::proto::ImageData>) {
) -> (
String,
Vec<ToolResultInput>,
Option<crate::proto::ImageData>,
) {
let mut tool_results: Vec<ToolResultInput> = Vec::new();
let mut image: Option<crate::proto::ImageData> = None;
@@ -45,10 +52,9 @@ fn extract_responses_input(
// Check for function_call_output items
for item in items {
if item["type"].as_str() == Some("function_call_output") {
if let (Some(call_id), Some(output)) = (
item["call_id"].as_str(),
item["output"].as_str(),
) {
if let (Some(call_id), Some(output)) =
(item["call_id"].as_str(), item["output"].as_str())
{
tool_results.push(ToolResultInput {
call_id: call_id.to_string(),
output: output.to_string(),
@@ -230,24 +236,31 @@ pub(crate) async fn handle_responses(
);
}
let (user_text, tool_results, image) = extract_responses_input(&body.input, body.instructions.as_deref());
let (user_text, tool_results, image) =
extract_responses_input(&body.input, body.instructions.as_deref());
// Handle tool result submission (function_call_output in input)
let is_tool_result_turn = !tool_results.is_empty();
if is_tool_result_turn {
for tr in &tool_results {
// Look up function name from call_id
let name = state.mitm_store.lookup_call_id(&tr.call_id).await
let name = state
.mitm_store
.lookup_call_id(&tr.call_id)
.await
.unwrap_or_else(|| "unknown_function".to_string());
// Parse the output as JSON, fall back to string wrapper
let result_value = serde_json::from_str::<serde_json::Value>(&tr.output)
.unwrap_or_else(|_| serde_json::json!({"result": tr.output}));
state.mitm_store.add_tool_result(PendingToolResult {
name,
result: result_value,
}).await;
state
.mitm_store
.add_tool_result(PendingToolResult {
name,
result: result_value,
})
.await;
}
info!(
count = tool_results.len(),
@@ -275,7 +288,10 @@ pub(crate) async fn handle_responses(
let gemini_tools = openai_tools_to_gemini(tools);
if !gemini_tools.is_empty() {
state.mitm_store.set_tools(gemini_tools).await;
info!(count = tools.len(), "Stored client tools for MITM injection");
info!(
count = tools.len(),
"Stored client tools for MITM injection"
);
}
}
if let Some(ref choice) = body.tool_choice {
@@ -289,7 +305,9 @@ pub(crate) async fn handle_responses(
let fmt_type = text_val["format"]["type"].as_str().unwrap_or("text");
if fmt_type == "json_schema" {
let name = text_val["format"]["name"].as_str().map(|s| s.to_string());
let schema = text_val["format"]["schema"].as_object().map(|o| serde_json::Value::Object(o.clone()));
let schema = text_val["format"]["schema"]
.as_object()
.map(|o| serde_json::Value::Object(o.clone()));
let strict = text_val["format"]["strict"].as_bool();
let tf = TextFormat {
format: TextFormatInner {
@@ -321,9 +339,13 @@ pub(crate) async fn handle_responses(
response_schema,
google_search: has_web_search,
};
if gp.temperature.is_some() || gp.top_p.is_some() || gp.max_output_tokens.is_some()
|| gp.reasoning_effort.is_some() || gp.response_mime_type.is_some()
|| gp.response_schema.is_some() || gp.google_search
if gp.temperature.is_some()
|| gp.top_p.is_some()
|| gp.max_output_tokens.is_some()
|| gp.reasoning_effort.is_some()
|| gp.response_mime_type.is_some()
|| gp.response_schema.is_some()
|| gp.google_search
{
state.mitm_store.set_generation_params(gp).await;
} else {
@@ -331,10 +353,7 @@ pub(crate) async fn handle_responses(
}
}
let response_id = format!(
"resp_{}",
uuid::Uuid::new_v4().to_string().replace('-', "")
);
let response_id = format!("resp_{}", uuid::Uuid::new_v4().to_string().replace('-', ""));
// Session/conversation management
let session_id_str = extract_conversation_id(&body.conversation);
@@ -371,12 +390,13 @@ pub(crate) async fn handle_responses(
// Store image for MITM injection (LS doesn't forward images to Google API)
if let Some(ref img) = image {
use base64::Engine;
state.mitm_store.set_pending_image(
crate::mitm::store::PendingImage {
state
.mitm_store
.set_pending_image(crate::mitm::store::PendingImage {
base64_data: base64::engine::general_purpose::STANDARD.encode(&img.data),
mime_type: img.mime_type.clone(),
}
).await;
})
.await;
}
match state
.backend
@@ -419,21 +439,32 @@ pub(crate) async fn handle_responses(
metadata: body.metadata.clone().unwrap_or(serde_json::json!({})),
max_tool_calls: body.max_tool_calls,
reasoning_effort: body.reasoning_effort.clone(),
tool_choice: body.tool_choice.clone().unwrap_or(serde_json::json!("auto")),
tool_choice: body
.tool_choice
.clone()
.unwrap_or(serde_json::json!("auto")),
tools: body.tools.clone().unwrap_or_default(),
text_format,
};
if body.stream {
handle_responses_stream(
state, response_id, model_name.to_string(), cascade_id,
body.timeout, req_params,
state,
response_id,
model_name.to_string(),
cascade_id,
body.timeout,
req_params,
)
.await
} else {
handle_responses_sync(
state, response_id, model_name.to_string(), cascade_id,
body.timeout, req_params,
state,
response_id,
model_name.to_string(),
cascade_id,
body.timeout,
req_params,
)
.await
}
@@ -485,7 +516,9 @@ async fn usage_from_poll(
if let Some(u) = mitm_store.peek_usage(key).await {
if u.thinking_output_tokens > 0 && u.thinking_text.is_none() {
// Call 2 hasn't arrived yet — wait briefly for the merge
tracing::debug!("MITM: thinking tokens found but no text, waiting for summary merge...");
tracing::debug!(
"MITM: thinking tokens found but no text, waiting for summary merge..."
);
for _ in 0..10 {
tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
if let Some(u2) = mitm_store.peek_usage(key).await {
@@ -526,13 +559,18 @@ async fn usage_from_poll(
// Priority 2: LS trajectory data (from CHECKPOINT/metadata steps)
if let Some(u) = model_usage {
return (Usage {
input_tokens: u.input_tokens,
input_tokens_details: InputTokensDetails { cached_tokens: 0 },
output_tokens: u.output_tokens,
output_tokens_details: OutputTokensDetails { reasoning_tokens: 0 },
total_tokens: u.input_tokens + u.output_tokens,
}, None);
return (
Usage {
input_tokens: u.input_tokens,
input_tokens_details: InputTokensDetails { cached_tokens: 0 },
output_tokens: u.output_tokens,
output_tokens_details: OutputTokensDetails {
reasoning_tokens: 0,
},
total_tokens: u.input_tokens + u.output_tokens,
},
None,
);
}
// Priority 3: Estimate from text lengths
@@ -575,14 +613,22 @@ async fn handle_responses_sync(
"call_{}",
uuid::Uuid::new_v4().to_string().replace('-', "")[..24].to_string()
);
state.mitm_store.register_call_id(call_id.clone(), fc.name.clone()).await;
state
.mitm_store
.register_call_id(call_id.clone(), fc.name.clone())
.await;
let arguments = serde_json::to_string(&fc.args).unwrap_or_default();
output_items.push(build_function_call_output(&call_id, &fc.name, &arguments));
output_items
.push(build_function_call_output(&call_id, &fc.name, &arguments));
}
let (usage, _) = usage_from_poll(
&state.mitm_store, &cascade_id, &None,
&params.user_text, "",
).await;
&state.mitm_store,
&cascade_id,
&None,
&params.user_text,
"",
)
.await;
let resp = build_response_object(
ResponseData {
id: response_id,
@@ -602,12 +648,20 @@ async fn handle_responses_sync(
// Check for completed text response
if state.mitm_store.is_response_complete() {
let text = state.mitm_store.take_response_text().await.unwrap_or_default();
let text = state
.mitm_store
.take_response_text()
.await
.unwrap_or_default();
let thinking = state.mitm_store.take_thinking_text().await;
let (usage, _) = usage_from_poll(
&state.mitm_store, &cascade_id, &None,
&params.user_text, &text,
).await;
&state.mitm_store,
&cascade_id,
&None,
&params.user_text,
&text,
)
.await;
let mut output_items: Vec<serde_json::Value> = Vec::new();
if let Some(ref t) = thinking {
@@ -658,10 +712,7 @@ async fn handle_responses_sync(
return upstream_err_response(err);
}
let completed_at = now_unix();
let msg_id = format!(
"msg_{}",
uuid::Uuid::new_v4().to_string().replace('-', "")
);
let msg_id = format!("msg_{}", uuid::Uuid::new_v4().to_string().replace('-', ""));
// Check for captured function calls from MITM (clears the active flag)
let captured_tool_calls = state.mitm_store.take_any_function_calls().await;
@@ -689,7 +740,10 @@ async fn handle_responses_sync(
uuid::Uuid::new_v4().to_string().replace('-', "")[..24].to_string()
);
// Register call_id → name mapping for tool result routing
state.mitm_store.register_call_id(call_id.clone(), fc.name.clone()).await;
state
.mitm_store
.register_call_id(call_id.clone(), fc.name.clone())
.await;
// Stringify args (OpenAI sends arguments as JSON string)
let arguments = serde_json::to_string(&fc.args).unwrap_or_default();
@@ -697,9 +751,13 @@ async fn handle_responses_sync(
}
let (usage, _) = usage_from_poll(
&state.mitm_store, &cascade_id, &poll_result.usage,
&params.user_text, &poll_result.text,
).await;
&state.mitm_store,
&cascade_id,
&poll_result.usage,
&params.user_text,
&poll_result.text,
)
.await;
let resp = build_response_object(
ResponseData {
@@ -719,7 +777,14 @@ async fn handle_responses_sync(
}
// Normal text response (no tool calls)
let (usage, mitm_thinking) = usage_from_poll(&state.mitm_store, &cascade_id, &poll_result.usage, &params.user_text, &poll_result.text).await;
let (usage, mitm_thinking) = usage_from_poll(
&state.mitm_store,
&cascade_id,
&poll_result.usage,
&params.user_text,
&poll_result.text,
)
.await;
// Thinking text priority: MITM-captured (raw API) > LS-extracted (steps)
let thinking_text = mitm_thinking.or(poll_result.thinking);
@@ -1560,4 +1625,3 @@ fn completion_events(
events
}

View File

@@ -126,7 +126,9 @@ pub(crate) struct CompletionRequest {
pub web_search: bool,
}
fn default_n() -> u32 { 1 }
fn default_n() -> u32 {
1
}
/// Stop sequence can be a single string or array of strings (OpenAI accepts both).
#[derive(Deserialize, Clone)]
@@ -254,8 +256,7 @@ pub(crate) struct OutputTokensDetails {
pub reasoning_tokens: u64,
}
#[derive(Serialize, Clone)]
#[derive(Default)]
#[derive(Serialize, Clone, Default)]
pub(crate) struct Reasoning {
pub effort: Option<String>,
pub summary: Option<String>,
@@ -313,7 +314,6 @@ impl Default for Usage {
}
}
impl Default for TextFormat {
fn default() -> Self {
Self {

View File

@@ -27,7 +27,9 @@ pub(crate) fn err_response(
/// Convert a MITM-captured upstream error from Google into an HTTP response.
/// Maps Google's HTTP status codes and preserves the error message.
pub(crate) fn upstream_err_response(err: &crate::mitm::store::UpstreamError) -> axum::response::Response {
pub(crate) fn upstream_err_response(
err: &crate::mitm::store::UpstreamError,
) -> axum::response::Response {
// Map Google's status code to HTTP status
let status = StatusCode::from_u16(err.status).unwrap_or(StatusCode::BAD_GATEWAY);
@@ -41,7 +43,9 @@ pub(crate) fn upstream_err_response(err: &crate::mitm::store::UpstreamError) ->
_ => "upstream_error",
};
let message = err.message.clone()
let message = err
.message
.clone()
.unwrap_or_else(|| format!("Google API returned HTTP {}", err.status));
err_response(status, message, error_type)
@@ -99,7 +103,8 @@ pub(crate) fn extract_image_from_content(item: &serde_json::Value) -> Option<Ima
}
// OpenAI Responses API format
"input_image" => {
let url = item["image_url"].as_str()
let url = item["image_url"]
.as_str()
.or_else(|| item["url"].as_str())?;
parse_data_uri(url)
}
@@ -109,5 +114,8 @@ pub(crate) fn extract_image_from_content(item: &serde_json::Value) -> Option<Ima
/// Extract the first image from a content array (Value::Array of content parts).
pub(crate) fn extract_first_image(content: &serde_json::Value) -> Option<ImageData> {
content.as_array()?.iter().find_map(extract_image_from_content)
content
.as_array()?
.iter()
.find_map(extract_image_from_content)
}

View File

@@ -48,10 +48,7 @@ static STATIC_HEADERS: LazyLock<HeaderMap> = LazyLock::new(|| {
*CHROME_MAJOR,
)),
);
h.insert(
HeaderName::from_static("sec-ch-ua-mobile"),
hv("?0"),
);
h.insert(HeaderName::from_static("sec-ch-ua-mobile"), hv("?0"));
h.insert(
HeaderName::from_static("sec-ch-ua-platform"),
hv("\"Linux\""),
@@ -72,7 +69,7 @@ impl Backend {
// wreq with Chrome impersonation: BoringSSL + Chrome JA3/JA4 + H2 fingerprint
let client = wreq::Client::builder()
.emulation(wreq_util::Emulation::Chrome142)
.cert_verification(false) // LS uses self-signed cert
.cert_verification(false) // LS uses self-signed cert
.verify_hostname(false)
.build()
.map_err(|e| format!("wreq client build failed: {e}"))?;
@@ -86,11 +83,7 @@ impl Backend {
/// Create a Backend with known connection details (for standalone LS).
///
/// Skips auto-discovery — the caller provides the port, CSRF, and OAuth token.
pub fn new_with_config(
port: u16,
csrf: String,
oauth_token: String,
) -> Result<Self, String> {
pub fn new_with_config(port: u16, csrf: String, oauth_token: String) -> Result<Self, String> {
let inner = BackendInner {
pid: "standalone".to_string(),
csrf,
@@ -212,10 +205,7 @@ impl Backend {
fn common_headers(csrf: &str) -> HeaderMap {
let mut h = STATIC_HEADERS.clone();
if let Ok(val) = HeaderValue::from_str(csrf) {
h.insert(
HeaderName::from_static("x-codeium-csrf-token"),
val,
);
h.insert(HeaderName::from_static("x-codeium-csrf-token"), val);
} else {
warn!("CSRF token contains invalid header characters, omitting");
}
@@ -239,8 +229,8 @@ impl Backend {
let mut headers = Self::common_headers(&csrf);
headers.insert("Content-Type", HeaderValue::from_static("application/json"));
let body_bytes = serde_json::to_vec(body)
.map_err(|e| format!("JSON serialize error: {e}"))?;
let body_bytes =
serde_json::to_vec(body).map_err(|e| format!("JSON serialize error: {e}"))?;
let resp = self
.client
@@ -258,7 +248,9 @@ impl Backend {
.and_then(|v| v.to_str().ok())
.unwrap_or("")
.to_string();
let raw = resp.bytes().await
let raw = resp
.bytes()
.await
.map_err(|e| format!("Read body error: {e}"))?;
let resp_bytes = decompress(method, &raw, &encoding);
// High-frequency polling methods → trace; everything else → debug
@@ -288,11 +280,7 @@ impl Backend {
}
/// Call a binary protobuf RPC method.
pub async fn call_proto(
&self,
method: &str,
body: Vec<u8>,
) -> Result<(u16, Vec<u8>), String> {
pub async fn call_proto(&self, method: &str, body: Vec<u8>) -> Result<(u16, Vec<u8>), String> {
let (base, csrf) = {
let guard = self.inner.read().await;
(
@@ -302,7 +290,10 @@ impl Backend {
};
let url = format!("{base}/{LS_SERVICE}/{method}");
let mut headers = Self::common_headers(&csrf);
headers.insert("Content-Type", HeaderValue::from_static("application/proto"));
headers.insert(
"Content-Type",
HeaderValue::from_static("application/proto"),
);
let resp = self
.client
@@ -350,7 +341,8 @@ impl Backend {
text: &str,
model_enum: u32,
) -> Result<(u16, Vec<u8>), String> {
self.send_message_with_image(cascade_id, text, model_enum, None).await
self.send_message_with_image(cascade_id, text, model_enum, None)
.await
}
/// SendUserCascadeMessage with optional image attachment.
@@ -365,7 +357,8 @@ impl Backend {
if token.is_empty() {
return Err("No OAuth token available".to_string());
}
let proto = crate::proto::build_request_with_image(cascade_id, text, &token, model_enum, image);
let proto =
crate::proto::build_request_with_image(cascade_id, text, &token, model_enum, image);
if image.is_some() {
tracing::info!(
proto_size = proto.len(),
@@ -376,10 +369,7 @@ impl Backend {
}
/// GetCascadeTrajectorySteps → JSON with steps array.
pub async fn get_steps(
&self,
cascade_id: &str,
) -> Result<(u16, serde_json::Value), String> {
pub async fn get_steps(&self, cascade_id: &str) -> Result<(u16, serde_json::Value), String> {
let body = serde_json::json!({"cascadeId": cascade_id});
self.call_json("GetCascadeTrajectorySteps", &body).await
}
@@ -415,7 +405,10 @@ impl Backend {
});
let mut headers = Self::common_headers(&csrf);
headers.insert("Content-Type", HeaderValue::from_static("application/connect+json"));
headers.insert(
"Content-Type",
HeaderValue::from_static("application/connect+json"),
);
headers.insert("Connect-Protocol-Version", HeaderValue::from_static("1"));
// Connect protocol envelope: [flags:1][length:4][payload]
@@ -441,7 +434,8 @@ impl Backend {
return Err(format!("{rpc_method} failed: {status}{err_text}"));
}
let resp_ct = resp.headers()
let resp_ct = resp
.headers()
.get("content-type")
.and_then(|v| v.to_str().ok())
.unwrap_or("unknown")
@@ -495,7 +489,8 @@ impl Backend {
&self,
cascade_id: &str,
) -> Result<tokio::sync::mpsc::Receiver<serde_json::Value>, String> {
self.stream_reactive_rpc("StreamCascadeReactiveUpdates", cascade_id).await
self.stream_reactive_rpc("StreamCascadeReactiveUpdates", cascade_id)
.await
}
}
@@ -506,7 +501,10 @@ fn discover() -> Result<BackendInner, String> {
// the wrapper is a shell script named language_server_linux_x64, while
// the real binary is language_server_linux_x64.real)
let pid_output = Command::new("sh")
.args(["-c", "pgrep -f 'language_server_linux_x64\\.real' | head -1"])
.args([
"-c",
"pgrep -f 'language_server_linux_x64\\.real' | head -1",
])
.output()
.map_err(|e| format!("pgrep failed: {e}"))?;
@@ -564,9 +562,8 @@ fn discover() -> Result<BackendInner, String> {
LazyLock::new(|| regex::Regex::new(r"port at (\d+) for HTTPS").unwrap());
for d in &dirs {
let log_path = format!(
"{log_base}/{d}/window1/exthost/google.antigravity/Antigravity.log"
);
let log_path =
format!("{log_base}/{d}/window1/exthost/google.antigravity/Antigravity.log");
if let Ok(contents) = fs::read_to_string(&log_path) {
for line in contents.lines() {
if line.contains(&pid) && line.contains("listening") && line.contains("HTTPS") {
@@ -584,10 +581,7 @@ fn discover() -> Result<BackendInner, String> {
if https_port.is_empty() {
// Fallback: find the LS HTTPS port via `ss` (when log file hasn't caught up)
if let Ok(output) = std::process::Command::new("ss")
.args(["-tlnp"])
.output()
{
if let Ok(output) = std::process::Command::new("ss").args(["-tlnp"]).output() {
let ss_out = String::from_utf8_lossy(&output.stdout);
// Find listening ports for this PID — typically the first is HTTPS
for line in ss_out.lines() {
@@ -653,7 +647,11 @@ fn decompress(method: &str, data: &[u8], encoding: &str) -> Vec<u8> {
Err(e) => {
if !encoding.is_empty() {
let preview = String::from_utf8_lossy(&data[..data.len().min(100)]);
warn!("{method}: {encoding} decompress failed ({} bytes): {e}. Raw: {}", data.len(), preview);
warn!(
"{method}: {encoding} decompress failed ({} bytes): {e}. Raw: {}",
data.len(),
preview
);
}
data.to_vec()
}

View File

@@ -115,9 +115,7 @@ fn detect_versions() -> DetectedVersions {
const FALLBACK_CLIENT: &str = "1.16.5";
let Some(install_dir) = find_install_dir() else {
tracing::warn!(
"Could not find Antigravity install — using fallback versions"
);
tracing::warn!("Could not find Antigravity install — using fallback versions");
return DetectedVersions {
antigravity: FALLBACK_ANTIGRAVITY.to_string(),
chrome: FALLBACK_CHROME.to_string(),

View File

@@ -24,7 +24,10 @@ use tracing::{info, warn};
use mitm::store::MitmStore;
#[derive(Parser)]
#[command(name = "antigravity-proxy", about = "Antigravity OpenAI Proxy (stealth)")]
#[command(
name = "antigravity-proxy",
about = "Antigravity OpenAI Proxy (stealth)"
)]
struct Cli {
/// Port to listen on
#[arg(long, default_value_t = 8741)]
@@ -93,15 +96,12 @@ async fn main() {
};
let filter = if log_level.is_empty() {
tracing_subscriber::EnvFilter::try_from_default_env()
.unwrap_or_else(|_| "warn".into())
tracing_subscriber::EnvFilter::try_from_default_env().unwrap_or_else(|_| "warn".into())
} else {
tracing_subscriber::EnvFilter::new(log_level)
};
tracing_subscriber::fmt()
.with_env_filter(filter)
.init();
tracing_subscriber::fmt().with_env_filter(filter).init();
// ── Step 1: Bind main port (auto-kill stale process if needed) ─────────────
let addr = format!("127.0.0.1:{}", cli.port);
@@ -111,7 +111,10 @@ async fn main() {
// Port in use — try to kill whatever's holding it
eprintln!(" Port {} in use, killing stale process...", cli.port);
let _ = std::process::Command::new("sh")
.args(["-c", &format!("kill $(lsof -ti:{}) 2>/dev/null; sleep 0.3", cli.port)])
.args([
"-c",
&format!("kill $(lsof -ti:{}) 2>/dev/null; sleep 0.3", cli.port),
])
.status();
// Also kill any leftover standalone LS processes
let _ = std::process::Command::new("pkill")
@@ -180,7 +183,9 @@ async fn main() {
Ok(c) => c,
Err(e) => {
eprintln!("Fatal: {e}");
eprintln!("Hint: start Antigravity first, or remove --classic to use headless mode");
eprintln!(
"Hint: start Antigravity first, or remove --classic to use headless mode"
);
std::process::exit(1);
}
}
@@ -199,13 +204,14 @@ async fn main() {
None
};
let mut ls = match standalone::StandaloneLS::spawn(&main_config, mitm_cfg.as_ref(), headless) {
Ok(ls) => ls,
Err(e) => {
eprintln!("Fatal: failed to spawn standalone LS: {e}");
std::process::exit(1);
}
};
let mut ls =
match standalone::StandaloneLS::spawn(&main_config, mitm_cfg.as_ref(), headless) {
Ok(ls) => ls,
Err(e) => {
eprintln!("Fatal: failed to spawn standalone LS: {e}");
std::process::exit(1);
}
};
// Wait for it to be ready
let rt_ls_port = ls.port;
let rt_ls_csrf = ls.csrf.clone();
@@ -294,7 +300,15 @@ async fn main() {
// ── Step 5: Start serving ─────────────────────────────────────────────────
let app = api::router(state.clone());
print_banner(cli.port, &pid, &https_port, &csrf, &token, &mitm_port_actual, is_standalone);
print_banner(
cli.port,
&pid,
&https_port,
&csrf,
&token,
&mitm_port_actual,
is_standalone,
);
info!("Listening on http://{addr}");
axum::serve(listener, app)
@@ -349,7 +363,15 @@ async fn shutdown_signal() {
}
}
fn print_banner(port: u16, pid: &str, https_port: &str, csrf: &str, token: &str, mitm: &Option<(u16, String)>, is_standalone: bool) {
fn print_banner(
port: u16,
pid: &str,
https_port: &str,
csrf: &str,
token: &str,
mitm: &Option<(u16, String)>,
is_standalone: bool,
) {
let chrome_major = &*constants::CHROME_MAJOR;
let ver = crate::constants::antigravity_version();
@@ -401,7 +423,11 @@ fn print_banner(port: u16, pid: &str, https_port: &str, csrf: &str, token: &str,
println!();
// Status line
let mitm_tag = if mitm.is_some() { "\x1b[32mmitm\x1b[0m" } else { "\x1b[31mmitm\x1b[0m" };
let mitm_tag = if mitm.is_some() {
"\x1b[32mmitm\x1b[0m"
} else {
"\x1b[31mmitm\x1b[0m"
};
println!(" \x1b[2mstealth:\x1b[0m \x1b[32mwarmup\x1b[0m \x1b[32mheartbeat\x1b[0m \x1b[32mjitter\x1b[0m {mitm_tag}");
println!();
@@ -421,7 +447,9 @@ fn print_banner(port: u16, pid: &str, https_port: &str, csrf: &str, token: &str,
if token == "NOT SET" {
println!(" \x1b[1;33m[!]\x1b[0m no oauth token");
println!(" export ANTIGRAVITY_OAUTH_TOKEN=ya29.xxx");
println!(" curl -X POST http://127.0.0.1:{port}/v1/token -d '{{\"token\":\"ya29.xxx\"}}'");
println!(
" curl -X POST http://127.0.0.1:{port}/v1/token -d '{{\"token\":\"ya29.xxx\"}}'"
);
println!(" echo 'ya29.xxx' > ~/.config/antigravity-proxy-token");
println!();
}
@@ -476,5 +504,7 @@ fn find_ls_binary_path() -> Option<String> {
/// Get the data directory for storing MITM CA cert/key.
fn dirs_data_dir() -> std::path::PathBuf {
let home = std::env::var("HOME").unwrap_or_else(|_| "/tmp".to_string());
std::path::PathBuf::from(home).join(".config").join("antigravity-proxy")
std::path::PathBuf::from(home)
.join(".config")
.join("antigravity-proxy")
}

View File

@@ -4,8 +4,8 @@
//! Dynamically generates per-domain leaf certificates signed by this CA.
use rcgen::{
BasicConstraints, CertificateParams, DistinguishedName, DnType, ExtendedKeyUsagePurpose,
IsCa, KeyPair, KeyUsagePurpose, SanType,
BasicConstraints, CertificateParams, DistinguishedName, DnType, ExtendedKeyUsagePurpose, IsCa,
KeyPair, KeyUsagePurpose, SanType,
};
use rustls::pki_types::{CertificateDer, PrivateKeyDer, PrivatePkcs8KeyDer};
use std::collections::HashMap;
@@ -45,15 +45,16 @@ impl MitmCa {
let key_pem = std::fs::read_to_string(&key_path)
.map_err(|e| format!("Failed to read CA key: {e}"))?;
let ca_key = KeyPair::from_pem(&key_pem)
.map_err(|e| format!("Failed to parse CA key: {e}"))?;
let ca_key =
KeyPair::from_pem(&key_pem).map_err(|e| format!("Failed to parse CA key: {e}"))?;
// Re-create params and self-sign to get the rcgen Certificate object
// (needed for signing leaf certs — rcgen 0.13 doesn't have from_ca_cert_pem).
// The re-signed cert will have a different serial/notBefore, but that's fine
// because we only use it for the rcgen signing API, NOT for the on-disk PEM.
let params = Self::ca_params();
let ca_signed = params.self_signed(&ca_key)
let ca_signed = params
.self_signed(&ca_key)
.map_err(|e| format!("Failed to self-sign CA: {e}"))?;
// Use the ORIGINAL on-disk PEM cert for DER — this is what the LS trusts
@@ -76,11 +77,12 @@ impl MitmCa {
std::fs::create_dir_all(data_dir)
.map_err(|e| format!("Failed to create data dir: {e}"))?;
let ca_key = KeyPair::generate()
.map_err(|e| format!("Failed to generate CA key: {e}"))?;
let ca_key =
KeyPair::generate().map_err(|e| format!("Failed to generate CA key: {e}"))?;
let params = Self::ca_params();
let ca_signed = params.self_signed(&ca_key)
let ca_signed = params
.self_signed(&ca_key)
.map_err(|e| format!("Failed to self-sign CA: {e}"))?;
// Write cert and key to disk
@@ -117,10 +119,7 @@ impl MitmCa {
params.distinguished_name = dn;
params.is_ca = IsCa::Ca(BasicConstraints::Unconstrained);
params.key_usages = vec![
KeyUsagePurpose::KeyCertSign,
KeyUsagePurpose::CrlSign,
];
params.key_usages = vec![KeyUsagePurpose::KeyCertSign, KeyUsagePurpose::CrlSign];
// Valid for 10 years
let now = time::OffsetDateTime::now_utc();
@@ -151,12 +150,17 @@ impl MitmCa {
return None;
}
use base64::Engine;
let der = base64::engine::general_purpose::STANDARD.decode(&b64).ok()?;
let der = base64::engine::general_purpose::STANDARD
.decode(&b64)
.ok()?;
Some(CertificateDer::from(der))
}
/// Get or create a TLS ServerConfig for the given domain.
pub async fn server_config_for_domain(&self, domain: &str) -> Result<Arc<rustls::ServerConfig>, String> {
pub async fn server_config_for_domain(
&self,
domain: &str,
) -> Result<Arc<rustls::ServerConfig>, String> {
// Check cache first
{
let cache = self.domain_cache.read().await;
@@ -172,7 +176,11 @@ impl MitmCa {
dn.push(DnType::CommonName, domain);
params.distinguished_name = dn;
params.subject_alt_names = vec![SanType::DnsName(domain.try_into().map_err(|e| format!("Invalid domain: {e}"))?)];
params.subject_alt_names = vec![SanType::DnsName(
domain
.try_into()
.map_err(|e| format!("Invalid domain: {e}"))?,
)];
params.extended_key_usages = vec![ExtendedKeyUsagePurpose::ServerAuth];
params.key_usages = vec![
KeyUsagePurpose::DigitalSignature,
@@ -184,10 +192,11 @@ impl MitmCa {
params.not_before = now;
params.not_after = now + time::Duration::days(365);
let leaf_key = KeyPair::generate()
.map_err(|e| format!("Failed to generate leaf key: {e}"))?;
let leaf_key =
KeyPair::generate().map_err(|e| format!("Failed to generate leaf key: {e}"))?;
let leaf_cert = params.signed_by(&leaf_key, &self.ca_signed, &self.ca_key)
let leaf_cert = params
.signed_by(&leaf_key, &self.ca_signed, &self.ca_key)
.map_err(|e| format!("Failed to sign leaf cert for {domain}: {e}"))?;
// Build rustls ServerConfig
@@ -196,10 +205,7 @@ impl MitmCa {
let mut config = rustls::ServerConfig::builder()
.with_no_client_auth()
.with_single_cert(
vec![leaf_cert_der, self.ca_cert_der.clone()],
leaf_key_der,
)
.with_single_cert(vec![leaf_cert_der, self.ca_cert_der.clone()], leaf_key_der)
.map_err(|e| format!("Failed to build ServerConfig for {domain}: {e}"))?;
// Advertise both h2 and http/1.1 so gRPC clients can negotiate HTTP/2

View File

@@ -92,11 +92,10 @@ impl UpstreamPool {
.map_err(|e| format!("upstream TLS to {} failed: {e}", self.domain))?;
let upstream_io = TokioIo::new(upstream_tls);
let (sender, conn) =
hyper::client::conn::http2::Builder::new(TokioExecutor::new())
.handshake(upstream_io)
.await
.map_err(|e| format!("upstream h2 handshake to {} failed: {e}", self.domain))?;
let (sender, conn) = hyper::client::conn::http2::Builder::new(TokioExecutor::new())
.handshake(upstream_io)
.await
.map_err(|e| format!("upstream h2 handshake to {} failed: {e}", self.domain))?;
let domain = self.domain.clone();
tokio::spawn(async move {
@@ -215,12 +214,10 @@ async fn handle_h2_request(
.unwrap_or(false);
// Check if this method carries usage data
let is_usage_method = is_grpc
&& USAGE_METHODS.iter().any(|m| path.contains(m));
let is_usage_method = is_grpc && USAGE_METHODS.iter().any(|m| path.contains(m));
// Check if this is a streaming method
let is_streaming = is_grpc
&& (path.contains("Stream") || path.contains("stream"));
let is_streaming = is_grpc && (path.contains("Stream") || path.contains("stream"));
debug!(
domain,
@@ -249,9 +246,9 @@ async fn handle_h2_request(
warn!(error = %e, domain, "MITM H2: upstream connect failed");
let resp = Response::builder()
.status(502)
.body(http_body_util::Either::Left(Full::new(
Bytes::from(format!("upstream connect failed: {e}")),
)))
.body(http_body_util::Either::Left(Full::new(Bytes::from(
format!("upstream connect failed: {e}"),
))))
.unwrap();
return Ok(resp);
}
@@ -261,17 +258,11 @@ async fn handle_h2_request(
let upstream_uri = http::Uri::builder()
.scheme("https")
.authority(domain)
.path_and_query(
uri.path_and_query()
.map(|pq| pq.as_str())
.unwrap_or("/"),
)
.path_and_query(uri.path_and_query().map(|pq| pq.as_str()).unwrap_or("/"))
.build()
.unwrap_or(uri);
let mut upstream_req = Request::builder()
.method(parts.method)
.uri(upstream_uri);
let mut upstream_req = Request::builder().method(parts.method).uri(upstream_uri);
// Copy headers, skip hop-by-hop
for (name, value) in &parts.headers {
@@ -287,9 +278,9 @@ async fn handle_h2_request(
Err(e) => {
let resp = Response::builder()
.status(502)
.body(http_body_util::Either::Left(Full::new(
Bytes::from(format!("build request failed: {e}")),
)))
.body(http_body_util::Either::Left(Full::new(Bytes::from(
format!("build request failed: {e}"),
))))
.unwrap();
return Ok(resp);
}
@@ -302,9 +293,9 @@ async fn handle_h2_request(
warn!(error = %e, domain, path = %path, "MITM H2: upstream request failed");
let resp = Response::builder()
.status(502)
.body(http_body_util::Either::Left(Full::new(
Bytes::from(format!("upstream request failed: {e}")),
)))
.body(http_body_util::Either::Left(Full::new(Bytes::from(
format!("upstream request failed: {e}"),
))))
.unwrap();
return Ok(resp);
}
@@ -326,13 +317,18 @@ async fn handle_h2_request(
// Spawn a task to forward body chunks and tee for usage extraction
tokio::spawn(async move {
let mut tee_buffer = if should_track_usage { Some(Vec::new()) } else { None };
let mut tee_buffer = if should_track_usage {
Some(Vec::new())
} else {
None
};
let mut body = resp_body;
loop {
match body.frame().await {
Some(Ok(frame)) => {
if let (Some(ref mut buf), Some(data)) = (&mut tee_buffer, frame.data_ref()) {
if let (Some(ref mut buf), Some(data)) = (&mut tee_buffer, frame.data_ref())
{
buf.extend_from_slice(data);
}
if tx.send(Ok(frame)).await.is_err() {
@@ -354,7 +350,9 @@ async fn handle_h2_request(
if let Some(grpc_usage) = parse_grpc_response_for_usage(&tee_buffer) {
let usage = grpc_usage.into_api_usage(path_clone.clone());
let cascade_hint = extract_cascade_from_grpc_request(&request_body_clone);
store_clone.record_usage(cascade_hint.as_deref(), usage).await;
store_clone
.record_usage(cascade_hint.as_deref(), usage)
.await;
}
}
}

View File

@@ -78,15 +78,21 @@ impl StreamingAccumulator {
Self::default()
}
/// Process a single SSE event.
/// Process a single SSE event.
pub fn process_event(&mut self, event: &Value) {
// ── Google format: {"response": {"usageMetadata": {...}, "modelVersion": "..."}} ──
if let Some(response) = event.get("response") {
// Extract usage metadata (each event has cumulative counts)
if let Some(usage) = response.get("usageMetadata") {
self.input_tokens = usage["promptTokenCount"].as_u64().unwrap_or(self.input_tokens);
self.output_tokens = usage["candidatesTokenCount"].as_u64().unwrap_or(self.output_tokens);
self.thinking_tokens = usage["thoughtsTokenCount"].as_u64().unwrap_or(self.thinking_tokens);
self.input_tokens = usage["promptTokenCount"]
.as_u64()
.unwrap_or(self.input_tokens);
self.output_tokens = usage["candidatesTokenCount"]
.as_u64()
.unwrap_or(self.output_tokens);
self.thinking_tokens = usage["thoughtsTokenCount"]
.as_u64()
.unwrap_or(self.thinking_tokens);
}
if let Some(model) = response["modelVersion"].as_str() {
self.model = Some(model.to_string());
@@ -170,8 +176,10 @@ impl StreamingAccumulator {
"message_start" => {
if let Some(usage) = event.get("message").and_then(|m| m.get("usage")) {
self.input_tokens = usage["input_tokens"].as_u64().unwrap_or(0);
self.cache_creation_input_tokens = usage["cache_creation_input_tokens"].as_u64().unwrap_or(0);
self.cache_read_input_tokens = usage["cache_read_input_tokens"].as_u64().unwrap_or(0);
self.cache_creation_input_tokens =
usage["cache_creation_input_tokens"].as_u64().unwrap_or(0);
self.cache_read_input_tokens =
usage["cache_read_input_tokens"].as_u64().unwrap_or(0);
}
if let Some(model) = event.get("message").and_then(|m| m["model"].as_str()) {
self.model = Some(model.to_string());
@@ -181,7 +189,9 @@ impl StreamingAccumulator {
}
"message_delta" => {
if let Some(usage) = event.get("usage") {
self.output_tokens = usage["output_tokens"].as_u64().unwrap_or(self.output_tokens);
self.output_tokens = usage["output_tokens"]
.as_u64()
.unwrap_or(self.output_tokens);
}
if let Some(reason) = event["delta"]["stop_reason"].as_str() {
self.stop_reason = Some(reason.to_string());
@@ -235,7 +245,10 @@ impl StreamingAccumulator {
response_output_tokens: 0,
model: self.model,
stop_reason: self.stop_reason,
api_provider: self.api_provider.unwrap_or_else(|| "unknown".to_string()).into(),
api_provider: self
.api_provider
.unwrap_or_else(|| "unknown".to_string())
.into(),
grpc_method: None,
captured_at: std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)

View File

@@ -68,14 +68,14 @@ pub fn modify_request(body: &[u8], tool_ctx: Option<&ToolContext>) -> Option<Vec
"system instruction: keep <identity> only ({original_len}{} chars, -{stripped})",
new_sys.len()
));
json["request"]["systemInstruction"]["parts"][0]["text"] =
Value::String(new_sys);
json["request"]["systemInstruction"]["parts"][0]["text"] = Value::String(new_sys);
}
} else {
// No identity tag found — clear the whole thing
changes.push(format!("system instruction: cleared ({original_len} chars)"));
json["request"]["systemInstruction"]["parts"][0]["text"] =
Value::String(String::new());
changes.push(format!(
"system instruction: cleared ({original_len} chars)"
));
json["request"]["systemInstruction"]["parts"][0]["text"] = Value::String(String::new());
}
}
@@ -125,7 +125,11 @@ pub fn modify_request(body: &[u8], tool_ctx: Option<&ToolContext>) -> Option<Vec
let mut modified = text.clone();
// Strip conversation summaries block
if let Some(cleaned) = strip_between(&modified, "# Conversation History\n", "</conversation_summaries>") {
if let Some(cleaned) = strip_between(
&modified,
"# Conversation History\n",
"</conversation_summaries>",
) {
modified = cleaned;
}
@@ -147,7 +151,9 @@ pub fn modify_request(body: &[u8], tool_ctx: Option<&ToolContext>) -> Option<Vec
}
// Strip knowledge item blocks
if let Some(cleaned) = strip_between(&modified, "Here are the ", "</knowledge_item>") {
if let Some(cleaned) =
strip_between(&modified, "Here are the ", "</knowledge_item>")
{
// Only strip if it's about knowledge items
if cleaned.len() < modified.len() && modified.contains("knowledge item") {
modified = cleaned;
@@ -202,7 +208,8 @@ pub fn modify_request(body: &[u8], tool_ctx: Option<&ToolContext>) -> Option<Vec
// Inject client-provided tools from ToolContext
if let Some(ref ctx) = tool_ctx {
if let Some(ref custom_tools) = ctx.tools {
let total_decls: usize = custom_tools.iter()
let total_decls: usize = custom_tools
.iter()
.filter_map(|t| t.get("functionDeclarations").and_then(|d| d.as_array()))
.map(|a| a.len())
.sum();
@@ -210,7 +217,10 @@ pub fn modify_request(body: &[u8], tool_ctx: Option<&ToolContext>) -> Option<Vec
tools.push(tool.clone());
}
has_custom_tools = true;
changes.push(format!("inject {} custom tool group(s)", custom_tools.len()));
changes.push(format!(
"inject {} custom tool group(s)",
custom_tools.len()
));
// Override LS's VALIDATED toolConfig → AUTO for custom tools.
// VALIDATED mode forces Google to validate function calls against a
@@ -218,16 +228,20 @@ pub fn modify_request(body: &[u8], tool_ctx: Option<&ToolContext>) -> Option<Vec
// that list, so they'd be rejected. AUTO lets the model freely choose
// between text and function calls.
if let Some(req) = json.get_mut("request").and_then(|v| v.as_object_mut()) {
let has_validated = req.get("toolConfig")
let has_validated = req
.get("toolConfig")
.and_then(|tc| tc.pointer("/functionCallingConfig/mode"))
.and_then(|m| m.as_str())
.map_or(false, |m| m == "VALIDATED");
if has_validated {
req.insert("toolConfig".to_string(), serde_json::json!({
"functionCallingConfig": {
"mode": "AUTO"
}
}));
req.insert(
"toolConfig".to_string(),
serde_json::json!({
"functionCallingConfig": {
"mode": "AUTO"
}
}),
);
changes.push("override toolConfig VALIDATED → AUTO".to_string());
}
}
@@ -243,7 +257,11 @@ pub fn modify_request(body: &[u8], tool_ctx: Option<&ToolContext>) -> Option<Vec
if STRIP_ALL_TOOLS && !has_custom_tools {
if let Some(req) = json.get_mut("request").and_then(|v| v.as_object_mut()) {
// Remove the empty tools array entirely
if req.get("tools").and_then(|v| v.as_array()).map_or(false, |a| a.is_empty()) {
if req
.get("tools")
.and_then(|v| v.as_array())
.map_or(false, |a| a.is_empty())
{
req.remove("tools");
changes.push("remove empty tools array".to_string());
}
@@ -266,7 +284,8 @@ pub fn modify_request(body: &[u8], tool_ctx: Option<&ToolContext>) -> Option<Vec
.as_ref()
.and_then(|ctx| ctx.tools.as_ref())
.map(|tools| {
tools.iter()
tools
.iter()
.filter_map(|t| t["functionDeclarations"].as_array())
.flatten()
.filter_map(|decl| decl["name"].as_str().map(|s| s.to_string()))
@@ -309,7 +328,9 @@ pub fn modify_request(body: &[u8], tool_ctx: Option<&ToolContext>) -> Option<Vec
.map_or(true, |parts| !parts.is_empty())
});
if stripped_fc > 0 {
changes.push(format!("strip {stripped_fc} functionCall/Response parts from history"));
changes.push(format!(
"strip {stripped_fc} functionCall/Response parts from history"
));
}
}
}
@@ -336,16 +357,22 @@ pub fn modify_request(body: &[u8], tool_ctx: Option<&ToolContext>) -> Option<Vec
for msg in contents.iter_mut() {
if msg["role"].as_str() == Some("model") {
if let Some(text) = msg["parts"][0]["text"].as_str() {
if text.contains("Tool call completed") || text.contains("Awaiting external tool result") {
if text.contains("Tool call completed")
|| text.contains("Awaiting external tool result")
{
// Replace with functionCall parts
let fc_parts: Vec<Value> = ctx.last_calls.iter().map(|fc| {
serde_json::json!({
"functionCall": {
"name": fc.name,
"args": fc.args,
}
let fc_parts: Vec<Value> = ctx
.last_calls
.iter()
.map(|fc| {
serde_json::json!({
"functionCall": {
"name": fc.name,
"args": fc.args,
}
})
})
}).collect();
.collect();
msg["parts"] = Value::Array(fc_parts);
changes.push("rewrite model turn with functionCall".to_string());
break;
@@ -355,29 +382,36 @@ pub fn modify_request(body: &[u8], tool_ctx: Option<&ToolContext>) -> Option<Vec
}
// Add functionResponse as a user turn before the last user message
let fn_response_parts: Vec<Value> = ctx.pending_results.iter().map(|r| {
serde_json::json!({
"functionResponse": {
"name": r.name,
"response": r.result,
}
let fn_response_parts: Vec<Value> = ctx
.pending_results
.iter()
.map(|r| {
serde_json::json!({
"functionResponse": {
"name": r.name,
"response": r.result,
}
})
})
}).collect();
.collect();
let fn_response_turn = serde_json::json!({
"role": "user",
"parts": fn_response_parts,
});
// Insert before the last user message
let last_user_idx = contents.iter().rposition(|msg| {
msg["role"].as_str() == Some("user")
});
let last_user_idx = contents
.iter()
.rposition(|msg| msg["role"].as_str() == Some("user"));
if let Some(idx) = last_user_idx {
contents.insert(idx, fn_response_turn);
} else {
contents.push(fn_response_turn);
}
changes.push(format!("inject {} functionResponse(s)", ctx.pending_results.len()));
changes.push(format!(
"inject {} functionResponse(s)",
ctx.pending_results.len()
));
}
}
}
@@ -420,8 +454,10 @@ pub fn modify_request(body: &[u8], tool_ctx: Option<&ToolContext>) -> Option<Vec
} else {
// Not wrapped in request — try top-level (public API format)
let gen_config = json.as_object_mut().and_then(|o| {
Some(o.entry("generationConfig")
.or_insert_with(|| serde_json::json!({})))
Some(
o.entry("generationConfig")
.or_insert_with(|| serde_json::json!({})),
)
});
if let Some(gc) = gen_config.and_then(|v| v.as_object_mut()) {
let thinking_config = gc
@@ -449,8 +485,10 @@ pub fn modify_request(body: &[u8], tool_ctx: Option<&ToolContext>) -> Option<Vec
if let Some(ref gp) = ctx.generation_params {
// Find or create generationConfig (same path as above)
let gc = if let Some(req) = json.get_mut("request").and_then(|v| v.as_object_mut()) {
Some(req.entry("generationConfig")
.or_insert_with(|| serde_json::json!({})))
Some(
req.entry("generationConfig")
.or_insert_with(|| serde_json::json!({})),
)
} else {
json.as_object_mut().map(|o| {
o.entry("generationConfig")
@@ -564,8 +602,6 @@ pub fn modify_request(body: &[u8], tool_ctx: Option<&ToolContext>) -> Option<Vec
changes.join(", ")
);
Some(modified_bytes)
}
@@ -832,8 +868,10 @@ mod tests {
let result: Value = serde_json::from_slice(&modified).unwrap();
// With no ToolContext, tools should be removed entirely
assert!(result["request"]["tools"].is_null() || result.pointer("/request/tools").is_none(),
"tools should be removed when no custom tools provided");
assert!(
result["request"]["tools"].is_null() || result.pointer("/request/tools").is_none(),
"tools should be removed when no custom tools provided"
);
}
#[test]
@@ -892,13 +930,23 @@ mod tests {
let contents = result["request"]["contents"].as_array().unwrap();
// Should have removed user_information, user_rules, workflows (3 messages)
// Kept: USER_REQUEST message (with ADDITIONAL_METADATA stripped) + model response
assert_eq!(contents.len(), 2, "should keep only user request + model response");
assert_eq!(
contents.len(),
2,
"should keep only user request + model response"
);
// Check USER_REQUEST message had metadata stripped
let user_msg = contents[0]["parts"][0]["text"].as_str().unwrap();
assert!(user_msg.contains("Say hello"), "should keep user request");
assert!(!user_msg.contains("ADDITIONAL_METADATA"), "should strip metadata");
assert!(!user_msg.contains("cursor stuff"), "should strip cursor info");
assert!(
!user_msg.contains("ADDITIONAL_METADATA"),
"should strip metadata"
);
assert!(
!user_msg.contains("cursor stuff"),
"should strip cursor info"
);
assert!(!user_msg.starts_with("Step Id:"), "should strip step id");
// Model response kept intact
@@ -921,8 +969,14 @@ mod tests {
#[test]
fn test_strip_between() {
let text = "keep this # Conversation History\nlots of stuff\n</conversation_summaries>\nand this";
let result = strip_between(text, "# Conversation History\n", "</conversation_summaries>").unwrap();
let text =
"keep this # Conversation History\nlots of stuff\n</conversation_summaries>\nand this";
let result = strip_between(
text,
"# Conversation History\n",
"</conversation_summaries>",
)
.unwrap();
assert_eq!(result, "keep this and this");
}
}
@@ -977,7 +1031,9 @@ pub fn modify_response_chunk(chunk: &[u8]) -> Option<Vec<u8>> {
// Replace the JSON in the result string
result.replace_range(json_start..json_start + json_end, &new_json);
changed = true;
info!("MITM: rewrote functionCall in response → text placeholder for LS");
info!(
"MITM: rewrote functionCall in response → text placeholder for LS"
);
search_from = json_start + new_json.len();
continue;
}
@@ -1117,7 +1173,10 @@ fn rewrite_function_calls_in_response(json: &mut Value) -> bool {
}
// Try nested "response.candidates"
if let Some(candidates) = json.pointer_mut("/response/candidates").and_then(|v| v.as_array_mut()) {
if let Some(candidates) = json
.pointer_mut("/response/candidates")
.and_then(|v| v.as_array_mut())
{
changed |= rewrite_candidates(candidates);
}

View File

@@ -251,7 +251,10 @@ fn looks_like_valid_message(fields: &[ProtoField], original_len: usize) -> bool
// (e.g., a long string that happened to have a valid first-field prefix)
if fields.len() == 1 && original_len > 100 {
// Single-field messages of >100 bytes are suspicious unless the field is bytes/message
matches!(&fields[0].value, ProtoValue::Bytes(_) | ProtoValue::Message(_))
matches!(
&fields[0].value,
ProtoValue::Bytes(_) | ProtoValue::Message(_)
)
} else {
true
}
@@ -328,7 +331,9 @@ fn try_extract_usage(fields: &[ProtoField]) -> Option<GrpcUsage> {
.iter()
.filter_map(|f| {
if let ProtoValue::Bytes(ref b) = f.value {
std::str::from_utf8(b).ok().map(|s| (f.number, s.to_string()))
std::str::from_utf8(b)
.ok()
.map(|s| (f.number, s.to_string()))
} else {
None
}
@@ -361,14 +366,23 @@ fn try_extract_usage(fields: &[ProtoField]) -> Option<GrpcUsage> {
// Check if there's a model-like string (field 7 = message_id or field 11 = response_id
// can contain model names, or model enum values map to known names)
let has_model_string = string_fields.iter().any(|(_, s)| {
s.contains("claude") || s.contains("gemini") || s.contains("gpt")
|| s.starts_with("models/") || s.contains("sonnet") || s.contains("opus")
|| s.contains("flash") || s.contains("pro")
s.contains("claude")
|| s.contains("gemini")
|| s.contains("gpt")
|| s.starts_with("models/")
|| s.contains("sonnet")
|| s.contains("opus")
|| s.contains("flash")
|| s.contains("pro")
});
// Check for fields at the known ModelUsageStats field numbers
let has_field_2 = fields.iter().any(|f| f.number == 2 && matches!(f.value, ProtoValue::Varint(_)));
let has_field_3 = fields.iter().any(|f| f.number == 3 && matches!(f.value, ProtoValue::Varint(_)));
let has_field_2 = fields
.iter()
.any(|f| f.number == 2 && matches!(f.value, ProtoValue::Varint(_)));
let has_field_3 = fields
.iter()
.any(|f| f.number == 3 && matches!(f.value, ProtoValue::Varint(_)));
// Strong signal: has both input and output token fields
let is_likely_usage = (has_field_2 && has_field_3) || has_model_string;
@@ -392,8 +406,8 @@ fn try_extract_usage(fields: &[ProtoField]) -> Option<GrpcUsage> {
// field 1 = model enum (varint, not string!)
2 => usage.input_tokens = v,
3 => usage.output_tokens = v,
4 => usage.cache_write_tokens = v, // VERIFIED: field 4
5 => usage.cache_read_tokens = v, // VERIFIED: field 5
4 => usage.cache_write_tokens = v, // VERIFIED: field 4
5 => usage.cache_read_tokens = v, // VERIFIED: field 5
// field 6 = api_provider enum (varint)
9 => usage.thinking_output_tokens = v, // VERIFIED: field 9
10 => usage.response_output_tokens = v, // VERIFIED: field 10
@@ -486,11 +500,11 @@ pub fn parse_grpc_response_for_usage(body: &[u8]) -> Option<GrpcUsage> {
fn model_enum_name(enum_val: u64) -> &'static str {
match enum_val {
// Placeholder models (1000 + N)
1007 => "gemini-3-pro", // MODEL_PLACEHOLDER_M7
1008 => "gemini-3-pro-high", // MODEL_PLACEHOLDER_M8
1012 => "claude-opus-4.5", // MODEL_PLACEHOLDER_M12
1018 => "gemini-3-flash", // MODEL_PLACEHOLDER_M18
1026 => "claude-opus-4.6", // MODEL_PLACEHOLDER_M26
1007 => "gemini-3-pro", // MODEL_PLACEHOLDER_M7
1008 => "gemini-3-pro-high", // MODEL_PLACEHOLDER_M8
1012 => "claude-opus-4.5", // MODEL_PLACEHOLDER_M12
1018 => "gemini-3-flash", // MODEL_PLACEHOLDER_M18
1026 => "claude-opus-4.6", // MODEL_PLACEHOLDER_M26
// Claude models (named)
281 => "claude-4-sonnet",
@@ -629,13 +643,13 @@ mod tests {
data.push(v as u8);
}
encode_varint_field(&mut data, 1, 5); // model enum
encode_varint_field(&mut data, 2, 1000); // input_tokens
encode_varint_field(&mut data, 3, 500); // output_tokens
encode_varint_field(&mut data, 4, 100); // cache_write_tokens
encode_varint_field(&mut data, 5, 200); // cache_read_tokens
encode_varint_field(&mut data, 9, 300); // thinking_output_tokens
encode_varint_field(&mut data, 10, 200); // response_output_tokens
encode_varint_field(&mut data, 1, 5); // model enum
encode_varint_field(&mut data, 2, 1000); // input_tokens
encode_varint_field(&mut data, 3, 500); // output_tokens
encode_varint_field(&mut data, 4, 100); // cache_write_tokens
encode_varint_field(&mut data, 5, 200); // cache_read_tokens
encode_varint_field(&mut data, 9, 300); // thinking_output_tokens
encode_varint_field(&mut data, 10, 200); // response_output_tokens
let fields = decode_proto(&data);
let usage = try_extract_usage(&fields).expect("should extract usage");

View File

@@ -11,8 +11,7 @@
use super::ca::MitmCa;
use super::intercept::{
extract_cascade_hint, parse_non_streaming_response, parse_streaming_chunk,
StreamingAccumulator,
extract_cascade_hint, parse_non_streaming_response, parse_streaming_chunk, StreamingAccumulator,
};
use super::store::MitmStore;
use std::sync::Arc;
@@ -54,7 +53,6 @@ pub struct MitmConfig {
pub modify_requests: bool,
}
/// Run the MITM proxy server.
///
/// Returns (port, task_handle) — port it's listening on, handle to abort on shutdown.
@@ -84,7 +82,8 @@ pub async fn run(
let ca = ca.clone();
let store = store.clone();
tokio::spawn(async move {
if let Err(e) = handle_connection(stream, ca, store, modify_requests).await {
if let Err(e) = handle_connection(stream, ca, store, modify_requests).await
{
warn!(error = %e, "MITM connection error");
}
});
@@ -131,8 +130,7 @@ async fn handle_connection(
.await
.map_err(|e| format!("Peek ClientHello: {e}"))?;
let domain = extract_sni(&hello_buf[..n])
.unwrap_or_else(|| "unknown".to_string());
let domain = extract_sni(&hello_buf[..n]).unwrap_or_else(|| "unknown".to_string());
info!(domain, "MITM: transparent redirect (iptables)");
@@ -224,22 +222,30 @@ fn extract_sni(buf: &[u8]) -> Option<String> {
let mut pos = 34; // skip version + random
// Session ID
if pos >= body.len() { return None; }
if pos >= body.len() {
return None;
}
let sid_len = body[pos] as usize;
pos += 1 + sid_len;
// Cipher suites
if pos + 2 > body.len() { return None; }
if pos + 2 > body.len() {
return None;
}
let cs_len = u16::from_be_bytes([body[pos], body[pos + 1]]) as usize;
pos += 2 + cs_len;
// Compression methods
if pos >= body.len() { return None; }
if pos >= body.len() {
return None;
}
let cm_len = body[pos] as usize;
pos += 1 + cm_len;
// Extensions
if pos + 2 > body.len() { return None; }
if pos + 2 > body.len() {
return None;
}
let ext_len = u16::from_be_bytes([body[pos], body[pos + 1]]) as usize;
pos += 2;
let ext_end = pos + ext_len.min(body.len() - pos);
@@ -304,32 +310,32 @@ async fn handle_intercepted(
info!(domain, "MITM: intercepting TLS");
// Get or create server TLS config for this domain
let server_config = ca
.server_config_for_domain(domain)
.await?;
let server_config = ca.server_config_for_domain(domain).await?;
let acceptor = TlsAcceptor::from(server_config);
// Perform TLS handshake with the client (LS) — 10s timeout
let tls_stream = match tokio::time::timeout(
std::time::Duration::from_secs(10),
acceptor.accept(stream),
)
.await
{
Ok(Ok(s)) => s,
Ok(Err(e)) => {
warn!(domain, error = %e, "MITM: TLS handshake FAILED (client rejected cert?)");
return Err(format!("TLS handshake with client failed for {domain}: {e}"));
}
Err(_) => {
warn!(domain, "MITM: TLS handshake TIMED OUT after 10s");
return Err(format!("TLS handshake timed out for {domain}"));
}
};
let tls_stream =
match tokio::time::timeout(std::time::Duration::from_secs(10), acceptor.accept(stream))
.await
{
Ok(Ok(s)) => s,
Ok(Err(e)) => {
warn!(domain, error = %e, "MITM: TLS handshake FAILED (client rejected cert?)");
return Err(format!(
"TLS handshake with client failed for {domain}: {e}"
));
}
Err(_) => {
warn!(domain, "MITM: TLS handshake TIMED OUT after 10s");
return Err(format!("TLS handshake timed out for {domain}"));
}
};
// Check negotiated ALPN protocol
let alpn = tls_stream.get_ref().1
let alpn = tls_stream
.get_ref()
.1
.alpn_protocol()
.map(|p| String::from_utf8_lossy(p).to_string());
@@ -339,12 +345,7 @@ async fn handle_intercepted(
Some("h2") => {
// HTTP/2 — use the hyper-based gRPC handler
info!(domain, "MITM: routing to HTTP/2 handler (gRPC)");
super::h2_handler::handle_h2_connection(
tls_stream,
domain.to_string(),
store,
)
.await
super::h2_handler::handle_h2_connection(tls_stream, domain.to_string(), store).await
}
_ => {
// HTTP/1.1 or no ALPN — use the existing handler
@@ -434,7 +435,10 @@ async fn handle_http_over_tls(
.await
{
let out = String::from_utf8_lossy(&output.stdout);
if let Some(ip) = out.lines().find(|l| l.parse::<std::net::Ipv4Addr>().is_ok()) {
if let Some(ip) = out
.lines()
.find(|l| l.parse::<std::net::Ipv4Addr>().is_ok())
{
return format!("{ip}:443");
}
}
@@ -458,7 +462,6 @@ async fn handle_http_over_tls(
loop {
// ── Read the HTTP request from the client ─────────────────────────
let mut request_buf = Vec::with_capacity(1024 * 64);
let mut is_our_request = false;
// 60s timeout on initial read (LS may open connection without sending immediately)
const IDLE_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(60);
@@ -513,7 +516,8 @@ async fn handle_http_over_tls(
}
// Parse the HTTP request to find headers and body
let (headers_end, content_length, _is_streaming_request) = parse_http_request_meta(&request_buf);
let (headers_end, content_length, _is_streaming_request) =
parse_http_request_meta(&request_buf);
// Try to extract cascade hint from request body
let cascade_hint = if headers_end < request_buf.len() {
@@ -545,6 +549,27 @@ async fn handle_http_over_tls(
"MITM: forwarding LLM request"
);
// ── Block ALL requests when one is already in-flight ─────────
// The LS opens multiple connections and sends parallel requests.
// When custom tools are active, only the FIRST request should reach
// Google. Block everything else with a fake response.
if store.is_request_in_flight() {
info!("MITM: blocking LS request — another request already in-flight");
let fake_response = "HTTP/1.1 200 OK\r\n\
Content-Type: text/event-stream\r\n\
Transfer-Encoding: chunked\r\n\
\r\n";
let fake_sse = "data: {\"response\":{\"candidates\":[{\"content\":{\"parts\":[{\"text\":\"Request handled.\"}],\"role\":\"model\"},\"finishReason\":\"STOP\"}],\"usageMetadata\":{\"promptTokenCount\":0,\"candidatesTokenCount\":1,\"totalTokenCount\":1}}}\n\ndata: [DONE]\n\n";
let chunked_body = super::modify::rechunk(fake_sse.as_bytes());
let mut response = fake_response.as_bytes().to_vec();
response.extend_from_slice(&chunked_body);
if let Err(e) = client.write_all(&response).await {
warn!(error = %e, "MITM: failed to write fake response");
}
let _ = client.flush().await;
continue;
}
// ── Request modification ─────────────────────────────────────
// Dechunk body → check if agent request → modify → rechunk
if modify_requests && body_len > 0 {
@@ -565,7 +590,11 @@ async fn handle_http_over_tls(
let generation_params = store.get_generation_params().await;
let pending_image = store.take_pending_image().await;
let tool_ctx = if tools.is_some() || !pending_results.is_empty() || generation_params.is_some() || pending_image.is_some() {
let tool_ctx = if tools.is_some()
|| !pending_results.is_empty()
|| generation_params.is_some()
|| pending_image.is_some()
{
Some(super::modify::ToolContext {
tools,
tool_config,
@@ -578,7 +607,9 @@ async fn handle_http_over_tls(
None
};
if let Some(modified_body) = super::modify::modify_request(&raw_body, tool_ctx.as_ref()) {
if let Some(modified_body) =
super::modify::modify_request(&raw_body, tool_ctx.as_ref())
{
// Rebuild request_buf: headers (with updated Content-Length) + rechunked modified body
let new_chunked = super::modify::rechunk(&modified_body);
@@ -588,39 +619,12 @@ async fn handle_http_over_tls(
let mut new_buf = updated_headers.into_bytes();
new_buf.extend_from_slice(&new_chunked);
request_buf = new_buf;
// Mark this as our modified request and set in-flight flag
is_our_request = true;
// Mark in-flight IMMEDIATELY — blocks all subsequent requests
store.mark_request_in_flight();
}
}
}
// ── Block ALL LS follow-up requests once first is in-flight ──
// When custom tools are active, we only need ONE request to Google.
// The LS tries to send multiple requests (its own agentic loop +
// internal requests on gemini-2.5-flash-lite). Block them ALL
// immediately — don't wait for response_complete.
let has_tools = store.get_tools().await.is_some();
if has_tools && store.is_request_in_flight() && !is_our_request {
info!(
"MITM: blocking LS follow-up — request already in-flight"
);
// Return a fake SSE response that makes the LS stop
let fake_response = "HTTP/1.1 200 OK\r\n\
Content-Type: text/event-stream\r\n\
Transfer-Encoding: chunked\r\n\
\r\n";
let fake_sse = "data: {\"response\":{\"candidates\":[{\"content\":{\"parts\":[{\"text\":\"Request handled.\"}],\"role\":\"model\"},\"finishReason\":\"STOP\"}],\"usageMetadata\":{\"promptTokenCount\":0,\"candidatesTokenCount\":1,\"totalTokenCount\":1}}}\n\ndata: [DONE]\n\n";
let chunked_body = super::modify::rechunk(fake_sse.as_bytes());
let mut response = fake_response.as_bytes().to_vec();
response.extend_from_slice(&chunked_body);
if let Err(e) = client.write_all(&response).await {
warn!(error = %e, "MITM: failed to write fake response");
}
let _ = client.flush().await;
continue; // Skip the real upstream call
}
} else {
debug!(
domain,
@@ -674,7 +678,10 @@ async fn handle_http_over_tls(
};
let n = match tokio::time::timeout(timeout, conn.read(&mut tmp)).await {
Ok(Ok(0)) => { upstream_ok = false; break; }
Ok(Ok(0)) => {
upstream_ok = false;
break;
}
Ok(Ok(n)) => n,
Ok(Err(e)) => {
debug!(domain, error = %e, "MITM: upstream read ended");
@@ -711,7 +718,9 @@ async fn handle_http_over_tls(
if header.name.eq_ignore_ascii_case("content-type") {
if let Ok(v) = std::str::from_utf8(header.value) {
content_type = v.to_string();
if v.contains("text/event-stream") { is_streaming_response = true; }
if v.contains("text/event-stream") {
is_streaming_response = true;
}
}
}
if header.name.eq_ignore_ascii_case("content-length") {
@@ -721,12 +730,16 @@ async fn handle_http_over_tls(
}
if header.name.eq_ignore_ascii_case("connection") {
if let Ok(v) = std::str::from_utf8(header.value) {
if v.trim().eq_ignore_ascii_case("close") { upstream_ok = false; }
if v.trim().eq_ignore_ascii_case("close") {
upstream_ok = false;
}
}
}
if header.name.eq_ignore_ascii_case("transfer-encoding") {
if let Ok(v) = std::str::from_utf8(header.value) {
if v.trim().eq_ignore_ascii_case("chunked") { is_chunked = true; }
if v.trim().eq_ignore_ascii_case("chunked") {
is_chunked = true;
}
}
}
}
@@ -749,22 +762,31 @@ async fn handle_http_over_tls(
warn!(domain, status = http_status, body = %body_str, "MITM: upstream error response");
// Parse Google's error JSON: {"error": {"code": N, "message": "...", "status": "..."}}
let (message, error_status) = serde_json::from_str::<serde_json::Value>(&body_str)
.ok()
.and_then(|v| {
let err = v.get("error")?;
let msg = err.get("message").and_then(|m| m.as_str()).map(|s| s.to_string());
let status = err.get("status").and_then(|s| s.as_str()).map(|s| s.to_string());
Some((msg, status))
})
.unwrap_or((None, None));
let (message, error_status) =
serde_json::from_str::<serde_json::Value>(&body_str)
.ok()
.and_then(|v| {
let err = v.get("error")?;
let msg = err
.get("message")
.and_then(|m| m.as_str())
.map(|s| s.to_string());
let status = err
.get("status")
.and_then(|s| s.as_str())
.map(|s| s.to_string());
Some((msg, status))
})
.unwrap_or((None, None));
store.set_upstream_error(super::store::UpstreamError {
status: http_status,
body: body_str,
message,
error_status,
}).await;
store
.set_upstream_error(super::store::UpstreamError {
status: http_status,
body: body_str,
message,
error_status,
})
.await;
}
// Save body for usage parsing
@@ -779,10 +801,15 @@ async fn handle_http_over_tls(
if !streaming_acc.function_calls.is_empty() {
let calls: Vec<_> = streaming_acc.function_calls.drain(..).collect();
for fc in &calls {
store.record_function_call(cascade_hint.as_deref(), fc.clone()).await;
store
.record_function_call(cascade_hint.as_deref(), fc.clone())
.await;
}
store.set_last_function_calls(calls.clone()).await;
info!("MITM: stored {} function call(s) from initial body", calls.len());
info!(
"MITM: stored {} function call(s) from initial body",
calls.len()
);
}
// Capture response + thinking text + grounding into MitmStore
@@ -816,7 +843,9 @@ async fn handle_http_over_tls(
}
if let Some(cl) = response_content_length {
if response_body_buf.len() >= cl { break; }
if response_body_buf.len() >= cl {
break;
}
}
// Check chunked terminator in initial body
if is_chunked && has_chunked_terminator(&response_body_buf) {
@@ -837,10 +866,15 @@ async fn handle_http_over_tls(
if !streaming_acc.function_calls.is_empty() {
let calls: Vec<_> = streaming_acc.function_calls.drain(..).collect();
for fc in &calls {
store.record_function_call(cascade_hint.as_deref(), fc.clone()).await;
store
.record_function_call(cascade_hint.as_deref(), fc.clone())
.await;
}
store.set_last_function_calls(calls.clone()).await;
info!("MITM: stored {} function call(s) from body chunk", calls.len());
info!(
"MITM: stored {} function call(s) from body chunk",
calls.len()
);
}
// Capture response + thinking text + grounding into MitmStore
@@ -875,7 +909,9 @@ async fn handle_http_over_tls(
response_body_buf.extend_from_slice(chunk);
if let Some(cl) = response_content_length {
if response_body_buf.len() >= cl { break; }
if response_body_buf.len() >= cl {
break;
}
}
if is_chunked && has_chunked_terminator(&response_body_buf) {
debug!(domain, "MITM: chunked response complete");
@@ -912,11 +948,7 @@ async fn handle_http_over_tls(
}
/// Handle a passthrough connection: transparent TCP tunnel to upstream.
async fn handle_passthrough(
mut client: TcpStream,
domain: &str,
port: u16,
) -> Result<(), String> {
async fn handle_passthrough(mut client: TcpStream, domain: &str, port: u16) -> Result<(), String> {
trace!(domain, port, "MITM: transparent tunnel");
let mut upstream = TcpStream::connect(format!("{domain}:{port}"))
@@ -926,7 +958,12 @@ async fn handle_passthrough(
// Bidirectional copy
match tokio::io::copy_bidirectional(&mut client, &mut upstream).await {
Ok((client_to_server, server_to_client)) => {
trace!(domain, client_to_server, server_to_client, "MITM: tunnel closed");
trace!(
domain,
client_to_server,
server_to_client,
"MITM: tunnel closed"
);
}
Err(e) => {
trace!(domain, error = %e, "MITM: tunnel error (likely clean close)");
@@ -945,7 +982,11 @@ fn has_chunked_terminator(body: &[u8]) -> bool {
return false;
}
// Check last 7 bytes to account for possible trailing whitespace
let tail = if body.len() > 7 { &body[body.len() - 7..] } else { body };
let tail = if body.len() > 7 {
&body[body.len() - 7..]
} else {
body
};
// Look for \r\n0\r\n\r\n anywhere in the tail
tail.windows(5).any(|w| w == b"0\r\n\r\n")
}

View File

@@ -2,11 +2,11 @@
//!
//! The MITM proxy writes usage data here; the API handlers read from it.
use std::collections::HashMap;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};
use tokio::sync::RwLock;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use tokio::sync::RwLock;
use tracing::{debug, info};
/// Token usage from an intercepted API response.
@@ -342,7 +342,9 @@ impl MitmStore {
/// Record a captured function call from Google's response.
pub async fn record_function_call(&self, cascade_id: Option<&str>, fc: CapturedFunctionCall) {
let key = cascade_id.map(|s| s.to_string()).unwrap_or_else(|| "_latest".to_string());
let key = cascade_id
.map(|s| s.to_string())
.unwrap_or_else(|| "_latest".to_string());
info!(
cascade = %key,
tool = %fc.name,
@@ -377,7 +379,6 @@ impl MitmStore {
self.awaiting_tool_result.store(false, Ordering::SeqCst);
}
/// Take any pending function calls (ignoring cascade ID).
pub async fn take_any_function_calls(&self) -> Option<Vec<CapturedFunctionCall>> {
let mut pending = self.pending_function_calls.write().await;
@@ -457,8 +458,6 @@ impl MitmStore {
// ── Direct response capture (bypass LS) ──────────────────────────────
/// Set (replace) the captured response text.
pub async fn set_response_text(&self, text: &str) {
*self.captured_response_text.write().await = Some(text.to_string());
@@ -484,8 +483,6 @@ impl MitmStore {
self.response_complete.load(Ordering::SeqCst)
}
/// Async version of clear_response.
pub async fn clear_response_async(&self) {
self.response_complete.store(false, Ordering::SeqCst);

View File

@@ -293,8 +293,7 @@ mod tests {
let cascade_bytes = b"test-cascade-id";
assert!(
msg.windows(cascade_bytes.len())
.any(|w| w == cascade_bytes),
msg.windows(cascade_bytes.len()).any(|w| w == cascade_bytes),
"cascade_id must appear in output"
);

View File

@@ -93,9 +93,8 @@ impl QuotaStore {
// Initial poll immediately.
self.poll_once(&backend).await;
let mut interval = tokio::time::interval(
std::time::Duration::from_secs(POLL_INTERVAL_SECS),
);
let mut interval =
tokio::time::interval(std::time::Duration::from_secs(POLL_INTERVAL_SECS));
interval.tick().await; // consume the first immediate tick
loop {
@@ -125,7 +124,9 @@ impl QuotaStore {
// Profile picture fetch fails through iptables — harmless, suppress
let data_str = data.to_string();
if data_str.contains("profile picture") {
tracing::debug!("GetUserStatus: profile picture fetch failed (expected with iptables)");
tracing::debug!(
"GetUserStatus: profile picture fetch failed (expected with iptables)"
);
} else {
warn!("GetUserStatus returned {status}: {data_str}");
}
@@ -172,9 +173,7 @@ fn parse_user_status(data: &serde_json::Value) -> QuotaSnapshot {
.as_str()
.unwrap_or("")
.to_string();
let frac = m["quotaInfo"]["remainingFraction"]
.as_f64()
.unwrap_or(0.0);
let frac = m["quotaInfo"]["remainingFraction"].as_f64().unwrap_or(0.0);
let reset_str = m["quotaInfo"]["resetTime"]
.as_str()
.unwrap_or("")
@@ -224,9 +223,7 @@ fn parse_user_status(data: &serde_json::Value) -> QuotaSnapshot {
flow_available: flow_avail,
flow_total,
flow_used_pct,
flex_purchasable: pi["monthlyFlexCreditPurchaseAmount"]
.as_i64()
.unwrap_or(0),
flex_purchasable: pi["monthlyFlexCreditPurchaseAmount"].as_i64().unwrap_or(0),
can_buy_more: pi["canBuyMoreCredits"].as_bool().unwrap_or(false),
},
models,

View File

@@ -66,9 +66,7 @@ impl SessionManager {
msg_count: 0,
},
);
return Ok(SessionResult {
cascade_id,
});
return Ok(SessionResult { cascade_id });
}
let sid = session_id.unwrap_or(DEFAULT_SESSION).to_string();
@@ -111,9 +109,7 @@ impl SessionManager {
},
);
}
Ok(SessionResult {
cascade_id,
})
Ok(SessionResult { cascade_id })
}
/// List all active sessions.
@@ -146,7 +142,5 @@ impl SessionManager {
fn cleanup_expired(sessions: &mut HashMap<String, Session>) {
let now = Instant::now();
sessions.retain(|_, s| {
now.duration_since(s.last_used).as_secs() < SESSION_TTL_SECS
});
sessions.retain(|_, s| now.duration_since(s.last_used).as_secs() < SESSION_TTL_SECS);
}

View File

@@ -10,16 +10,44 @@ use std::io::{self, Read};
// ── Domain metadata ──────────────────────────────────────────────────────────
const DOMAIN_INFO: &[(&str, &str, &str)] = &[
("antigravity-unleash.goog", "Feature Flags", "Unleash SDK — controls A/B tests and feature rollouts"),
("daily-cloudcode-pa.googleapis.com", "LLM API (gRPC)", "Primary Gemini/Claude API endpoint"),
("cloudcode-pa.googleapis.com", "LLM API (gRPC)", "Production Gemini/Claude API endpoint"),
("api.anthropic.com", "Claude API", "Direct Anthropic API calls"),
("lh3.googleusercontent.com", "Profile Picture", "User avatar"),
(
"antigravity-unleash.goog",
"Feature Flags",
"Unleash SDK — controls A/B tests and feature rollouts",
),
(
"daily-cloudcode-pa.googleapis.com",
"LLM API (gRPC)",
"Primary Gemini/Claude API endpoint",
),
(
"cloudcode-pa.googleapis.com",
"LLM API (gRPC)",
"Production Gemini/Claude API endpoint",
),
(
"api.anthropic.com",
"Claude API",
"Direct Anthropic API calls",
),
(
"lh3.googleusercontent.com",
"Profile Picture",
"User avatar",
),
("play.googleapis.com", "Telemetry", "Google Play telemetry"),
("firebaseinstallations.googleapis.com", "Firebase", "Installation tracking"),
(
"firebaseinstallations.googleapis.com",
"Firebase",
"Installation tracking",
),
("oauth2.googleapis.com", "OAuth", "Token refresh/exchange"),
("speech.googleapis.com", "Speech", "Voice input processing"),
("modelarmor.googleapis.com", "Safety", "Content safety/filtering"),
(
"modelarmor.googleapis.com",
"Safety",
"Content safety/filtering",
),
];
fn domain_label(domain: &str) -> (&str, &str) {
@@ -57,8 +85,8 @@ struct HttpExchange {
#[derive(Debug, Clone, Copy, PartialEq)]
enum Direction {
Outgoing, // LS → upstream
Incoming, // external → LS (our curl calls)
Outgoing, // LS → upstream
Incoming, // external → LS (our curl calls)
}
#[derive(Default)]
@@ -101,10 +129,12 @@ impl Snapshot {
// LS process logs
if (line.starts_with('I') || line.starts_with('W') || line.starts_with('E'))
&& line.len() > 4 && line.chars().nth(1).is_some_and(|c| c.is_ascii_digit()) {
snap.ls_logs.push(line.to_string());
continue;
}
&& line.len() > 4
&& line.chars().nth(1).is_some_and(|c| c.is_ascii_digit())
{
snap.ls_logs.push(line.to_string());
continue;
}
if line.contains("maxprocs:") {
snap.ls_logs.push(line.to_string());
continue;
@@ -128,8 +158,15 @@ impl Snapshot {
if let Some((key, val)) = extract_header(line, "Transport encoding header") {
if key == ":method" {
// Finalize previous exchange
if current_pseudo.contains_key(":path") || current_pseudo.contains_key(":method") {
snap.finalize_exchange(&current_pseudo, &current_headers, current_direction, current_stream.clone());
if current_pseudo.contains_key(":path")
|| current_pseudo.contains_key(":method")
{
snap.finalize_exchange(
&current_pseudo,
&current_headers,
current_direction,
current_stream.clone(),
);
}
current_headers.clear();
current_pseudo.clear();
@@ -147,8 +184,15 @@ impl Snapshot {
// Incoming / server-received headers
if let Some((key, val)) = extract_header(line, "decoded hpack field header field") {
if key == ":authority" && !line.contains("server read frame") {
if current_pseudo.contains_key(":path") || current_pseudo.contains_key(":method") {
snap.finalize_exchange(&current_pseudo, &current_headers, current_direction, current_stream.clone());
if current_pseudo.contains_key(":path")
|| current_pseudo.contains_key(":method")
{
snap.finalize_exchange(
&current_pseudo,
&current_headers,
current_direction,
current_stream.clone(),
);
}
current_headers.clear();
current_pseudo.clear();
@@ -167,8 +211,15 @@ impl Snapshot {
if line.contains("wrote HEADERS") {
if let Some(stream) = extract_stream_id(line) {
current_stream = Some(stream.clone());
if current_pseudo.contains_key(":path") || current_pseudo.contains_key(":method") {
let ex = snap.finalize_exchange(&current_pseudo, &current_headers, current_direction, Some(stream));
if current_pseudo.contains_key(":path")
|| current_pseudo.contains_key(":method")
{
let ex = snap.finalize_exchange(
&current_pseudo,
&current_headers,
current_direction,
Some(stream),
);
if ex.is_some() {
current_headers.clear();
current_pseudo.clear();
@@ -179,10 +230,13 @@ impl Snapshot {
}
// DATA frames
if (line.contains("wrote DATA") || line.contains("read DATA") || line.contains("server read frame DATA"))
if (line.contains("wrote DATA")
|| line.contains("read DATA")
|| line.contains("server read frame DATA"))
&& line.contains("data=\"")
{
let is_outgoing = line.contains("wrote DATA") || line.contains("server read frame DATA");
let is_outgoing =
line.contains("wrote DATA") || line.contains("server read frame DATA");
if let Some(stream) = extract_stream_id(line) {
if let Some(data_str) = extract_data(line) {
let raw = decode_go_escaped(&data_str);
@@ -203,7 +257,12 @@ impl Snapshot {
// Finalize remaining
if current_pseudo.contains_key(":path") || current_pseudo.contains_key(":method") {
snap.finalize_exchange(&current_pseudo, &current_headers, current_direction, current_stream);
snap.finalize_exchange(
&current_pseudo,
&current_headers,
current_direction,
current_stream,
);
}
snap
@@ -226,7 +285,11 @@ impl Snapshot {
self.exchanges.push(HttpExchange {
authority,
method: if method.is_empty() { "GET".into() } else { method },
method: if method.is_empty() {
"GET".into()
} else {
method
},
path,
headers: headers.to_vec(),
body: Vec::new(),
@@ -245,7 +308,9 @@ impl Snapshot {
let sep = "".repeat(70);
let sep_thin = "".repeat(60);
out.push_str(&format!("\n{BOLD}{CYAN}{sep}{NC}\n"));
out.push_str(&format!("{BOLD}{CYAN} STANDALONE LS TRAFFIC SNAPSHOT{NC}\n"));
out.push_str(&format!(
"{BOLD}{CYAN} STANDALONE LS TRAFFIC SNAPSHOT{NC}\n"
));
out.push_str(&format!("{BOLD}{CYAN}{sep}{NC}\n\n"));
// LS Logs
@@ -265,7 +330,9 @@ impl Snapshot {
for target in &self.connections {
let domain = target.split(':').next().unwrap_or(target);
let (label, desc) = domain_label(domain);
out.push_str(&format!(" {GREEN}{NC} {BOLD}{target}{NC} {DIM}({label}){NC}\n"));
out.push_str(&format!(
" {GREEN}{NC} {BOLD}{target}{NC} {DIM}({label}){NC}\n"
));
if !desc.is_empty() {
out.push_str(&format!(" {DIM}{desc}{NC}\n"));
}
@@ -276,7 +343,10 @@ impl Snapshot {
// Group by domain
let mut by_domain: Vec<(&str, Vec<&HttpExchange>)> = Vec::new();
for ex in &self.exchanges {
if let Some(entry) = by_domain.iter_mut().find(|(d, _)| *d == ex.authority.as_str()) {
if let Some(entry) = by_domain
.iter_mut()
.find(|(d, _)| *d == ex.authority.as_str())
{
entry.1.push(ex);
} else {
by_domain.push((&ex.authority, vec![ex]));
@@ -293,12 +363,17 @@ impl Snapshot {
let color = if label.contains("API") { YELLOW } else { CYAN };
out.push_str(&format!("\n{BOLD}{sep}{NC}\n"));
out.push_str(&format!("{BOLD}{color} {domain}{NC} {DIM}{label}{NC}\n"));
out.push_str(&format!(
"{BOLD}{color} {domain}{NC} {DIM}{label}{NC}\n"
));
out.push_str(&format!("{BOLD}{sep}{NC}\n"));
for ex in exchanges {
let method_color = if ex.method == "GET" { GREEN } else { YELLOW };
out.push_str(&format!("\n {BOLD}{method_color}{}{NC} {}\n", ex.method, ex.path));
out.push_str(&format!(
"\n {BOLD}{method_color}{}{NC} {}\n",
ex.method, ex.path
));
// Interesting headers
for (key, val) in &ex.headers {
@@ -342,7 +417,10 @@ fn render_body(data: &[u8], total_len: usize) -> String {
out.push_str(&format!(" {BOLD}Body ({len} bytes, JSON):{NC}\n"));
for (i, line) in pretty.lines().enumerate() {
if i >= 40 {
out.push_str(&format!(" {DIM}... ({} more lines){NC}\n", pretty.lines().count() - 40));
out.push_str(&format!(
" {DIM}... ({} more lines){NC}\n",
pretty.lines().count() - 40
));
break;
}
out.push_str(&format!(" {GREEN}{line}{NC}\n"));
@@ -357,10 +435,16 @@ fn render_body(data: &[u8], total_len: usize) -> String {
if let Ok(text) = std::str::from_utf8(&decompressed) {
if let Ok(val) = serde_json::from_str::<serde_json::Value>(text) {
let pretty = serde_json::to_string_pretty(&val).unwrap_or_default();
out.push_str(&format!(" {BOLD}Body ({len} bytes gzip → {} bytes, JSON):{NC}\n", decompressed.len()));
out.push_str(&format!(
" {BOLD}Body ({len} bytes gzip → {} bytes, JSON):{NC}\n",
decompressed.len()
));
for (i, line) in pretty.lines().enumerate() {
if i >= 50 {
out.push_str(&format!(" {DIM}... ({} more lines){NC}\n", pretty.lines().count() - 50));
out.push_str(&format!(
" {DIM}... ({} more lines){NC}\n",
pretty.lines().count() - 50
));
break;
}
out.push_str(&format!(" {GREEN}{line}{NC}\n"));
@@ -368,14 +452,20 @@ fn render_body(data: &[u8], total_len: usize) -> String {
return out;
}
// Plain text
out.push_str(&format!(" {BOLD}Body ({len} bytes gzip → {} bytes, text):{NC}\n", decompressed.len()));
out.push_str(&format!(
" {BOLD}Body ({len} bytes gzip → {} bytes, text):{NC}\n",
decompressed.len()
));
for line in text.lines().take(20) {
out.push_str(&format!(" {line}\n"));
}
return out;
}
// Binary gzip
out.push_str(&format!(" {BOLD}Body ({len} bytes gzip → {} bytes, binary):{NC}\n", decompressed.len()));
out.push_str(&format!(
" {BOLD}Body ({len} bytes gzip → {} bytes, binary):{NC}\n",
decompressed.len()
));
let strings = extract_strings(&decompressed);
for s in strings.iter().take(15) {
out.push_str(&format!(" {MAGENTA}{s}{NC}\n"));
@@ -393,7 +483,11 @@ fn render_body(data: &[u8], total_len: usize) -> String {
// Protobuf / binary with string extraction
let strings = extract_strings(data);
if !strings.is_empty() {
let kind = if !data.is_empty() && matches!(data[0], 0x08 | 0x0a | 0x10 | 0x12 | 0x18 | 0x1a | 0x20 | 0x22) {
let kind = if !data.is_empty()
&& matches!(
data[0],
0x08 | 0x0a | 0x10 | 0x12 | 0x18 | 0x1a | 0x20 | 0x22
) {
"protobuf"
} else {
"binary"
@@ -448,7 +542,9 @@ fn extract_header(line: &str, pattern: &str) -> Option<(String, String)> {
fn extract_stream_id(line: &str) -> Option<String> {
let pos = line.find("stream=")?;
let rest = &line[pos + 7..];
let end = rest.find(|c: char| !c.is_ascii_digit()).unwrap_or(rest.len());
let end = rest
.find(|c: char| !c.is_ascii_digit())
.unwrap_or(rest.len());
Some(rest[..end].to_string())
}
@@ -470,7 +566,9 @@ fn extract_data(line: &str) -> Option<String> {
fn extract_data_len(line: &str) -> Option<usize> {
let pos = line.find("len=")?;
let rest = &line[pos + 4..];
let end = rest.find(|c: char| !c.is_ascii_digit()).unwrap_or(rest.len());
let end = rest
.find(|c: char| !c.is_ascii_digit())
.unwrap_or(rest.len());
rest[..end].parse().ok()
}
@@ -482,17 +580,40 @@ fn decode_go_escaped(s: &str) -> Vec<u8> {
if bytes[i] == b'\\' && i + 1 < bytes.len() {
match bytes[i + 1] {
b'x' if i + 3 < bytes.len() => {
if let Ok(b) = u8::from_str_radix(std::str::from_utf8(&bytes[i + 2..i + 4]).unwrap_or(""), 16) {
if let Ok(b) = u8::from_str_radix(
std::str::from_utf8(&bytes[i + 2..i + 4]).unwrap_or(""),
16,
) {
result.push(b);
i += 4;
continue;
}
}
b'n' => { result.push(b'\n'); i += 2; continue; }
b'r' => { result.push(b'\r'); i += 2; continue; }
b't' => { result.push(b'\t'); i += 2; continue; }
b'\\' => { result.push(b'\\'); i += 2; continue; }
b'"' => { result.push(b'"'); i += 2; continue; }
b'n' => {
result.push(b'\n');
i += 2;
continue;
}
b'r' => {
result.push(b'\r');
i += 2;
continue;
}
b't' => {
result.push(b'\t');
i += 2;
continue;
}
b'\\' => {
result.push(b'\\');
i += 2;
continue;
}
b'"' => {
result.push(b'"');
i += 2;
continue;
}
_ => {}
}
}
@@ -562,7 +683,10 @@ pub fn run_cli() {
})
} else {
let mut buf = String::new();
io::stdin().lock().read_to_string(&mut buf).expect("Failed to read stdin");
io::stdin()
.lock()
.read_to_string(&mut buf)
.expect("Failed to read stdin");
buf
};

View File

@@ -108,7 +108,10 @@ pub struct MainLSConfig {
/// and CSRF is a random UUID.
pub fn generate_standalone_config() -> MainLSConfig {
let csrf = Uuid::new_v4().to_string();
info!(csrf_len = csrf.len(), "Generated standalone config (headless)");
info!(
csrf_len = csrf.len(),
"Generated standalone config (headless)"
);
MainLSConfig {
extension_server_port: "0".to_string(), // disables extension server
csrf,
@@ -159,7 +162,13 @@ impl StandaloneLS {
let app_data_dir = format!("{DATA_DIR}/.gemini/antigravity-standalone");
let annotations_dir = format!("{app_data_dir}/annotations");
let brain_dir = format!("{app_data_dir}/brain");
for dir in [DATA_DIR, &gemini_dir, &app_data_dir, &annotations_dir, &brain_dir] {
for dir in [
DATA_DIR,
&gemini_dir,
&app_data_dir,
&annotations_dir,
&brain_dir,
] {
let _ = std::fs::create_dir_all(dir);
#[cfg(unix)]
{
@@ -194,7 +203,10 @@ impl StandaloneLS {
#[cfg(unix)]
{
use std::os::unix::fs::PermissionsExt;
let _ = std::fs::set_permissions(&settings_path, std::fs::Permissions::from_mode(0o0666));
let _ = std::fs::set_permissions(
&settings_path,
std::fs::Permissions::from_mode(0o0666),
);
}
tracing::info!("Pre-seeded user_settings.pb (detect_and_use_proxy=ENABLED)");
}
@@ -203,10 +215,7 @@ impl StandaloneLS {
// The LS connects to this port and calls LanguageServerStarted — without it,
// the LS never fully initializes and won't accept connections on its server_port.
let _stub_listener = if headless {
let stub_port: u16 = main_config
.extension_server_port
.parse()
.unwrap_or(0);
let stub_port: u16 = main_config.extension_server_port.parse().unwrap_or(0);
if stub_port == 0 {
// Create a real listener so the LS can connect
let listener = TcpListener::bind("127.0.0.1:0")
@@ -215,7 +224,10 @@ impl StandaloneLS {
.local_addr()
.map_err(|e| format!("Failed to get stub port: {e}"))?
.port();
info!(port = actual_port, "Stub extension server listening (headless)");
info!(
port = actual_port,
"Stub extension server listening (headless)"
);
// Read OAuth state from Antigravity's state.vscdb if available.
// The DB stores the exact Topic proto (access_token + refresh_token + expiry)
// which lets the LS auto-refresh tokens via its built-in Google OAuth2 client.
@@ -306,10 +318,7 @@ impl StandaloneLS {
// 3. MITM proxy intercepts the transparent TLS connection via SNI
if let Some(mitm) = mitm_config {
// Extract port from proxy_addr (e.g. "http://127.0.0.1:8742" → "8742")
let mitm_port = mitm.proxy_addr
.rsplit(':')
.next()
.unwrap_or("8742");
let mitm_port = mitm.proxy_addr.rsplit(':').next().unwrap_or("8742");
format!("https://daily-cloudcode-pa.googleapis.com:{mitm_port}")
} else {
"https://daily-cloudcode-pa.googleapis.com".to_string()
@@ -324,9 +333,8 @@ impl StandaloneLS {
debug!(?args, "LS args");
// Build env vars for the LS process
let mut env_vars: Vec<(String, String)> = vec![
("ANTIGRAVITY_EDITOR_APP_ROOT".into(), APP_ROOT.into()),
];
let mut env_vars: Vec<(String, String)> =
vec![("ANTIGRAVITY_EDITOR_APP_ROOT".into(), APP_ROOT.into())];
// If MITM is enabled, add SSL + proxy env vars
if let Some(mitm) = mitm_config {
@@ -335,8 +343,8 @@ impl StandaloneLS {
// Write to /tmp — accessible by antigravity-ls user
// (user's ~/.config/ is not traversable by other UIDs)
let combined_ca_path = "/tmp/antigravity-mitm-combined-ca.pem".to_string();
let system_ca = std::fs::read_to_string("/etc/ssl/certs/ca-certificates.crt")
.unwrap_or_default();
let system_ca =
std::fs::read_to_string("/etc/ssl/certs/ca-certificates.crt").unwrap_or_default();
let mitm_ca = std::fs::read_to_string(&mitm.ca_cert_path)
.map_err(|e| format!("Failed to read MITM CA cert: {e}"))?;
std::fs::write(&combined_ca_path, format!("{system_ca}\n{mitm_ca}"))
@@ -441,7 +449,11 @@ impl StandaloneLS {
};
if let Some(pid) = ls_pid {
info!(ls_pid = pid, sudo = use_sudo, "Discovered actual LS process");
info!(
ls_pid = pid,
sudo = use_sudo,
"Discovered actual LS process"
);
}
Ok(StandaloneLS {
@@ -617,8 +629,7 @@ fn find_main_ls_pid() -> Result<String, String> {
return Err("No /proc filesystem".to_string());
}
let entries = std::fs::read_dir(proc)
.map_err(|e| format!("Cannot read /proc: {e}"))?;
let entries = std::fs::read_dir(proc).map_err(|e| format!("Cannot read /proc: {e}"))?;
for entry in entries.flatten() {
let name = entry.file_name();
@@ -704,12 +715,10 @@ fn cleanup_orphaned_ls() {
.output();
let pids: Vec<u32> = match output {
Ok(out) => {
String::from_utf8_lossy(&out.stdout)
.lines()
.filter_map(|l| l.trim().parse().ok())
.collect()
}
Ok(out) => String::from_utf8_lossy(&out.stdout)
.lines()
.filter_map(|l| l.trim().parse().ok())
.collect(),
Err(_) => return,
};
@@ -717,7 +726,11 @@ fn cleanup_orphaned_ls() {
return;
}
info!(count = pids.len(), ?pids, "Cleaning up orphaned standalone LS processes");
info!(
count = pids.len(),
?pids,
"Cleaning up orphaned standalone LS processes"
);
// Kill each PID by running `kill` AS the antigravity-ls user.
// This works because same-UID processes can signal each other,
@@ -870,7 +883,8 @@ fn extract_access_token_from_topic(topic_bytes: &[u8]) -> Option<String> {
// Simple approach: convert to string and find base64 pattern
let as_str = String::from_utf8_lossy(topic_bytes);
// The base64 OAuthTokenInfo starts with "Co" (0x0A = field 1, len-delimited)
for segment in as_str.split(|c: char| !c.is_alphanumeric() && c != '+' && c != '/' && c != '=') {
for segment in as_str.split(|c: char| !c.is_alphanumeric() && c != '+' && c != '/' && c != '=')
{
if segment.len() > 50 {
if let Ok(decoded) = base64::engine::general_purpose::STANDARD.decode(segment) {
// Try to extract field 1 (access_token) from the OAuthTokenInfo proto
@@ -951,7 +965,11 @@ fn decode_varint_at(buf: &[u8], offset: usize) -> Option<(u64, usize)> {
/// IMPORTANT: `SubscribeToUnifiedStateSyncTopic` is a long-lived stream.
/// If we immediately close it, the LS reconnects in a tight loop and never
/// proceeds to fetch OAuth tokens. We keep subscription connections OPEN.
fn stub_handle_connection(conn: std::net::TcpStream, oauth_token: &str, oauth_topic_bytes: &Option<Vec<u8>>) {
fn stub_handle_connection(
conn: std::net::TcpStream,
oauth_token: &str,
oauth_topic_bytes: &Option<Vec<u8>>,
) {
use std::io::{BufRead, BufReader, Read, Write};
let mut reader = BufReader::new(match conn.try_clone() {
@@ -1028,7 +1046,7 @@ fn stub_handle_connection(conn: std::net::TcpStream, oauth_token: &str, oauth_to
i += 1;
if i + len <= proto_body.len() {
if field_num == 1 {
topic_name = String::from_utf8_lossy(&proto_body[i..i+len]).to_string();
topic_name = String::from_utf8_lossy(&proto_body[i..i + len]).to_string();
}
i += len;
} else {
@@ -1084,7 +1102,10 @@ fn stub_handle_connection(conn: std::net::TcpStream, oauth_token: &str, oauth_to
// This includes access_token + refresh_token + expiry, so the
// LS can auto-refresh tokens via its built-in Google OAuth2 client.
initial_state_bytes = topic_bytes.clone();
eprintln!("[stub-ext] using state.vscdb topic ({} bytes)", topic_bytes.len());
eprintln!(
"[stub-ext] using state.vscdb topic ({} bytes)",
topic_bytes.len()
);
} else if !oauth_token.is_empty() {
// Manual token fallback — construct OAuthTokenInfo with far-future expiry
// (no refresh_token, so the LS can't auto-refresh)
@@ -1155,7 +1176,10 @@ fn stub_handle_connection(conn: std::net::TcpStream, oauth_token: &str, oauth_to
if !send_chunk(&mut writer, &initial_env) {
return;
}
eprintln!("[stub-ext] STREAM → sent initial_state ({} bytes)", initial_state_bytes.len());
eprintln!(
"[stub-ext] STREAM → sent initial_state ({} bytes)",
initial_state_bytes.len()
);
// (applied_update removed — data is in initial_state)
@@ -1197,7 +1221,10 @@ fn stub_handle_connection(conn: std::net::TcpStream, oauth_token: &str, oauth_to
if !oauth_token.is_empty() {
// Build protobuf: GetSecretValueResponse { string value = 1 }
let proto = encode_proto_string(1, oauth_token.as_bytes());
eprintln!("[stub-ext] → serving token ({} bytes) for key={key:?}", oauth_token.len());
eprintln!(
"[stub-ext] → serving token ({} bytes) for key={key:?}",
oauth_token.len()
);
// Data envelope: flag=0x00, length, data
envelope.push(0x00u8);

View File

@@ -34,7 +34,9 @@ pub async fn warmup_sequence(backend: &Backend, headless: bool) {
)
.await
{
Ok(Ok((status, _))) => info!("SetUserSettings (detect_and_use_proxy=ENABLED): {status}"),
Ok(Ok((status, _))) => {
info!("SetUserSettings (detect_and_use_proxy=ENABLED): {status}")
}
Ok(Err(e)) => warn!("SetUserSettings failed: {e}"),
Err(_) => warn!("SetUserSettings timed out"),
}
@@ -59,12 +61,7 @@ pub async fn warmup_sequence(backend: &Backend, headless: bool) {
for (method, body) in calls {
// Timeout per call — in headless mode, the LS can't reach Google's API
// so these would hang forever without a timeout. Warmup is best-effort.
match tokio::time::timeout(
Duration::from_secs(5),
backend.call_json(method, body),
)
.await
{
match tokio::time::timeout(Duration::from_secs(5), backend.call_json(method, body)).await {
Ok(Ok((status, _))) => debug!("Warmup {method}: {status}"),
Ok(Err(e)) => warn!("Warmup {method} failed: {e}"),
Err(_) => warn!("Warmup {method} timed out"),
@@ -87,10 +84,7 @@ pub fn start_heartbeat(backend: Arc<Backend>) -> JoinHandle<()> {
let interval_ms = rand::thread_rng().gen_range(29_500..30_500);
tokio::time::sleep(Duration::from_millis(interval_ms)).await;
match backend
.call_json("Heartbeat", &serde_json::json!({}))
.await
{
match backend.call_json("Heartbeat", &serde_json::json!({})).await {
Ok((status, _)) => debug!("Heartbeat: {status}"),
Err(e) => warn!("Heartbeat failed: {e}"),
}