refactor: decompose large functions and remove dead code

- Decompose modify_request() into 7 single-responsibility helpers
- Decompose handle_http_over_tls(): extract read_full_request, dispatch_stream_events
- Promote connect_upstream/resolve_upstream to module-level functions
- Split standalone.rs (1238 lines) into 4 submodules:
  standalone/mod.rs, spawn.rs, discovery.rs, stub.rs
- Extract proto wire primitives into proto/wire.rs
- Remove 6 dead MitmStore methods
- Remove dead SessionResult, DEFAULT_SESSION, get_or_create
- Remove dead decode_varint_at, extract_conversation_id
- Clean all unused imports across 10 files
- Suppress structural dead_code warnings on deserialization fields

Warnings: 20 -> 0. All 43 tests pass.
This commit is contained in:
Nikketryhard
2026-02-17 22:27:26 -06:00
parent 637fbc0e54
commit 48674f65da
21 changed files with 3099 additions and 3346 deletions

View File

@@ -18,15 +18,9 @@ use super::util::{err_response, now_unix, upstream_err_response};
use super::AppState; use super::AppState;
use crate::mitm::store::{CapturedFunctionCall, PendingToolResult, ToolRound}; use crate::mitm::store::{CapturedFunctionCall, PendingToolResult, ToolRound};
/// Extract a conversation/session ID from a flexible JSON value.
/// Accepts a plain string or an object with an "id" field.
fn extract_conversation_id(conv: &Option<serde_json::Value>) -> Option<String> {
match conv {
Some(serde_json::Value::String(s)) => Some(s.clone()),
Some(obj) => obj["id"].as_str().map(|s| s.to_string()),
None => None,
}
}
/// System fingerprint for completions responses (derived from crate version at compile time). /// System fingerprint for completions responses (derived from crate version at compile time).
fn system_fingerprint() -> String { fn system_fingerprint() -> String {
@@ -187,10 +181,7 @@ pub(crate) async fn handle_completions(
model_name, body.stream model_name, body.stream
); );
// Diagnostic: dump OpenCode's raw request
if let Ok(pretty) = serde_json::to_string_pretty(&body) {
let _ = std::fs::write("/tmp/opencode-request.json", &pretty);
}
let model = match lookup_model(model_name) { let model = match lookup_model(model_name) {
Some(m) => m, Some(m) => m,
@@ -204,35 +195,28 @@ pub(crate) async fn handle_completions(
} }
}; };
// Store client tools from this request (or clear stale ones from other endpoints) // ── Build per-request state locally ──────────────────────────────────
if let Some(ref tools) = body.tools {
let gemini_tools = crate::mitm::modify::openai_tools_to_gemini(tools); // Convert OpenAI tools to Gemini format
if !gemini_tools.is_empty() { let tools = body.tools.as_ref().and_then(|t| {
state.mitm_store.set_tools(gemini_tools).await; let gemini_tools = crate::mitm::modify::openai_tools_to_gemini(t);
if let Some(ref choice) = body.tool_choice { if gemini_tools.is_empty() { None } else {
let gemini_config = crate::mitm::modify::openai_tool_choice_to_gemini(choice); info!(count = t.len(), "Completions: client tools for MITM injection");
state.mitm_store.set_tool_config(gemini_config).await; Some(gemini_tools)
}
info!(
count = tools.len(),
"Completions: stored client tools for MITM injection"
);
} else {
state.mitm_store.clear_tools().await;
}
} else {
state.mitm_store.clear_tools().await;
} }
});
let tool_config = body.tools.as_ref().and_then(|_| {
body.tool_choice.as_ref().map(|choice| {
crate::mitm::modify::openai_tool_choice_to_gemini(choice)
})
});
// ── Extract tool results from messages for MITM injection ────────── // ── Extract tool results from messages for MITM injection ──────────
// When OpenCode sends back tool results, the messages array contains: // Build ToolRounds from message history: each round pairs assistant tool_calls
// 1. assistant message with tool_calls (the model's previous function calls) // with subsequent tool result messages. Local call_id_to_name mapping.
// 2. tool messages with results (the executed tool outputs) let mut tool_rounds: Vec<ToolRound> = Vec::new();
// We build ToolRounds: each round pairs one assistant's tool_calls with let mut call_id_to_name: std::collections::HashMap<String, String> = std::collections::HashMap::new();
// the subsequent tool result messages. This enables correct per-turn
// history rewriting for multi-step tool use.
{ {
let mut rounds: Vec<ToolRound> = Vec::new();
let mut current_round: Option<ToolRound> = None; let mut current_round: Option<ToolRound> = None;
for msg in &body.messages { for msg in &body.messages {
@@ -241,7 +225,7 @@ pub(crate) async fn handle_completions(
// Finalize any open round // Finalize any open round
if let Some(round) = current_round.take() { if let Some(round) = current_round.take() {
if !round.calls.is_empty() { if !round.calls.is_empty() {
rounds.push(round); tool_rounds.push(round);
} }
} }
// Start new round if this assistant has tool_calls // Start new round if this assistant has tool_calls
@@ -255,14 +239,15 @@ pub(crate) async fn handle_completions(
.unwrap_or(serde_json::json!({})); .unwrap_or(serde_json::json!({}));
let call_id = tc["id"].as_str().unwrap_or("").to_string(); let call_id = tc["id"].as_str().unwrap_or("").to_string();
// Register call_id → name for lookup // Register call_id → name locally
if !call_id.is_empty() { if !call_id.is_empty() {
state.mitm_store.register_call_id(call_id, name.clone()).await; call_id_to_name.insert(call_id, name.clone());
} }
calls.push(CapturedFunctionCall { calls.push(CapturedFunctionCall {
name, name,
args, args,
thought_signature: None,
captured_at: std::time::SystemTime::now() captured_at: std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH) .duration_since(std::time::UNIX_EPOCH)
.unwrap_or_default() .unwrap_or_default()
@@ -281,16 +266,13 @@ pub(crate) async fn handle_completions(
"tool" => { "tool" => {
let text = extract_message_text(&msg.content); let text = extract_message_text(&msg.content);
if let Some(ref call_id) = msg.tool_call_id { if let Some(ref call_id) = msg.tool_call_id {
// Look up function name from call_id, fall back to
// positional index within the current round's calls
let result_index = current_round let result_index = current_round
.as_ref() .as_ref()
.map(|r| r.results.len()) .map(|r| r.results.len())
.unwrap_or(0); .unwrap_or(0);
let name = state let name = call_id_to_name
.mitm_store .get(call_id.as_str())
.lookup_call_id(call_id) .cloned()
.await
.unwrap_or_else(|| { .unwrap_or_else(|| {
current_round current_round
.as_ref() .as_ref()
@@ -314,7 +296,7 @@ pub(crate) async fn handle_completions(
// Any other role (user, system) finalizes the current round // Any other role (user, system) finalizes the current round
if let Some(round) = current_round.take() { if let Some(round) = current_round.take() {
if !round.calls.is_empty() { if !round.calls.is_empty() {
rounds.push(round); tool_rounds.push(round);
} }
} }
} }
@@ -323,27 +305,46 @@ pub(crate) async fn handle_completions(
// Finalize last round // Finalize last round
if let Some(round) = current_round.take() { if let Some(round) = current_round.take() {
if !round.calls.is_empty() { if !round.calls.is_empty() {
rounds.push(round); tool_rounds.push(round);
} }
} }
if !rounds.is_empty() { if !tool_rounds.is_empty() {
info!( info!(
round_count = rounds.len(), round_count = tool_rounds.len(),
calls = ?rounds.iter().map(|r| r.calls.iter().map(|c| &c.name).collect::<Vec<_>>()).collect::<Vec<_>>(), calls = ?tool_rounds.iter().map(|r| r.calls.iter().map(|c| &c.name).collect::<Vec<_>>()).collect::<Vec<_>>(),
"Completions: stored {} tool round(s) for MITM history rewrite", "Completions: {} tool round(s) for MITM history rewrite",
rounds.len(), tool_rounds.len(),
); );
// Also set last_function_calls from the latest round for proxy.rs recording compat
if let Some(last_round) = rounds.last() { // Merge thought_signatures from MITM-captured function calls.
state.mitm_store.set_last_function_calls(last_round.calls.clone()).await; // OpenAI format doesn't carry thought signatures, but Google requires
// them when injecting functionCall parts back into history.
let sigs = state.mitm_store.peek_thought_signatures().await;
if !sigs.is_empty() {
let mut merged = 0usize;
for round in &mut tool_rounds {
for fc in &mut round.calls {
if fc.thought_signature.is_none() {
if let Some(sig) = sigs.get(&fc.name) {
fc.thought_signature = Some(sig.clone());
merged += 1;
}
}
}
}
if merged > 0 {
info!(
merged_count = merged,
"Completions: merged {} thought_signature(s) from MITM capture",
merged,
);
}
} }
state.mitm_store.set_tool_rounds(rounds).await;
} }
} }
// Store generation parameters for MITM injection // Build generation parameters locally
{
use crate::mitm::store::GenerationParams; use crate::mitm::store::GenerationParams;
let (response_mime_type, response_schema) = match body.response_format.as_ref() { let (response_mime_type, response_schema) = match body.response_format.as_ref() {
Some(rf) => match rf.format_type.as_str() { Some(rf) => match rf.format_type.as_str() {
@@ -359,7 +360,7 @@ pub(crate) async fn handle_completions(
let gp = GenerationParams { let gp = GenerationParams {
temperature: body.temperature, temperature: body.temperature,
top_p: body.top_p, top_p: body.top_p,
top_k: None, // OpenAI doesn't have top_k top_k: None,
max_output_tokens: body.max_tokens.or(body.max_completion_tokens), max_output_tokens: body.max_tokens.or(body.max_completion_tokens),
stop_sequences: body.stop.clone().map(|s| s.into_vec()), stop_sequences: body.stop.clone().map(|s| s.into_vec()),
frequency_penalty: body.frequency_penalty, frequency_penalty: body.frequency_penalty,
@@ -369,8 +370,7 @@ pub(crate) async fn handle_completions(
response_schema, response_schema,
google_search: body.web_search, google_search: body.web_search,
}; };
// Only store if at least one param is set let generation_params = if gp.temperature.is_some()
if gp.temperature.is_some()
|| gp.top_p.is_some() || gp.top_p.is_some()
|| gp.max_output_tokens.is_some() || gp.max_output_tokens.is_some()
|| gp.frequency_penalty.is_some() || gp.frequency_penalty.is_some()
@@ -381,11 +381,10 @@ pub(crate) async fn handle_completions(
|| gp.response_schema.is_some() || gp.response_schema.is_some()
|| gp.google_search || gp.google_search
{ {
state.mitm_store.set_generation_params(gp).await; Some(gp)
} else { } else {
state.mitm_store.clear_generation_params().await; None
} };
}
let token = state.backend.oauth_token().await; let token = state.backend.oauth_token().await;
if token.is_empty() { if token.is_empty() {
@@ -410,23 +409,8 @@ pub(crate) async fn handle_completions(
warn!("n={n} requested with streaming — streaming only supports n=1, ignoring n"); warn!("n={n} requested with streaming — streaming only supports n=1, ignoring n");
} }
// Session/conversation: reuse cascade if conversation ID provided // Always create a new cascade for every request
let session_id_str = extract_conversation_id(&body.conversation); let cascade_id = match state.backend.create_cascade().await {
// Helper to create a cascade (reuses session or creates fresh)
let create_cascade = |state: Arc<AppState>, session_id: Option<String>| async move {
if let Some(ref sid) = session_id {
state
.sessions
.get_or_create(Some(sid), || state.backend.create_cascade())
.await
.map(|sr| sr.cascade_id)
} else {
state.backend.create_cascade().await
}
};
let cascade_id = match create_cascade(Arc::clone(&state), session_id_str.clone()).await {
Ok(cid) => cid, Ok(cid) => cid,
Err(e) => { Err(e) => {
return err_response( return err_response(
@@ -437,40 +421,54 @@ pub(crate) async fn handle_completions(
} }
}; };
// Send message on primary cascade // Image for MITM injection
state.mitm_store.set_active_cascade(&cascade_id).await; let pending_image = image.as_ref().map(|img| {
// Store real user text for MITM injection — LS gets a dummy prompt
state.mitm_store.set_pending_user_text(user_text.clone()).await;
// Store image for MITM injection (LS doesn't forward images to Google API)
if let Some(ref img) = image {
use base64::Engine; use base64::Engine;
state crate::mitm::store::PendingImage {
.mitm_store
.set_pending_image(crate::mitm::store::PendingImage {
base64_data: base64::engine::general_purpose::STANDARD.encode(&img.data), base64_data: base64::engine::general_purpose::STANDARD.encode(&img.data),
mime_type: img.mime_type.clone(), mime_type: img.mime_type.clone(),
})
.await;
} }
});
// Pre-flight: install channel BEFORE send_message so the MITM proxy // Get last calls from the latest tool round (if any) for proxy recording compat
// can grab it when the LS fires its API call. let last_function_calls = tool_rounds.last()
// Only for streaming — sync paths use poll_for_response (legacy store). .map(|r| r.calls.clone())
let has_custom_tools = state.mitm_store.get_tools().await.is_some(); .unwrap_or_default();
let mitm_rx = if has_custom_tools && body.stream {
state.mitm_store.clear_response_async().await; // Build event channel for streaming
state.mitm_store.clear_upstream_error().await; let has_custom_tools = tools.is_some();
let _ = state.mitm_store.take_any_function_calls().await; let (mitm_rx, event_tx) = if has_custom_tools && body.stream {
let (tx, rx) = tokio::sync::mpsc::channel(64); let (tx, rx) = tokio::sync::mpsc::channel(64);
state.mitm_store.set_channel(tx).await; (Some(rx), Some(tx))
Some(rx)
} else { } else {
None (None, None)
}; };
// Build pending tool results from latest round
let pending_tool_results = tool_rounds.last()
.map(|r| r.results.clone())
.unwrap_or_default();
// Register all per-request state atomically
state.mitm_store.register_request(crate::mitm::store::RequestContext {
cascade_id: cascade_id.clone(),
pending_user_text: user_text.clone(),
event_channel: event_tx,
generation_params,
pending_image,
tools,
tool_config,
pending_tool_results,
tool_rounds,
last_function_calls,
call_id_to_name,
created_at: std::time::Instant::now(),
}).await;
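// Send a placeholder prompt tagged with the cascade ID to LS; the real user text travels in the registered RequestContext (pending_user_text)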
match state match state
.backend .backend
.send_message_with_image(&cascade_id, ".", model.model_enum, image.as_ref()) .send_message_with_image(&cascade_id, &format!(".<cid:{}>", cascade_id), model.model_enum, image.as_ref())
.await .await
{ {
Ok((200, _)) => { Ok((200, _)) => {
@@ -481,7 +479,7 @@ pub(crate) async fn handle_completions(
}); });
} }
Ok((status, _)) => { Ok((status, _)) => {
state.mitm_store.drop_channel().await; state.mitm_store.remove_request(&cascade_id).await;
return err_response( return err_response(
StatusCode::BAD_GATEWAY, StatusCode::BAD_GATEWAY,
format!("Backend returned {status}"), format!("Backend returned {status}"),
@@ -489,7 +487,7 @@ pub(crate) async fn handle_completions(
); );
} }
Err(e) => { Err(e) => {
state.mitm_store.drop_channel().await; state.mitm_store.remove_request(&cascade_id).await;
return err_response( return err_response(
StatusCode::BAD_GATEWAY, StatusCode::BAD_GATEWAY,
format!("Send failed: {e}"), format!("Send failed: {e}"),
@@ -537,7 +535,7 @@ pub(crate) async fn handle_completions(
// Send the same message on each extra cascade // Send the same message on each extra cascade
match state match state
.backend .backend
.send_message_with_image(&cid, ".", model.model_enum, image.as_ref()) .send_message_with_image(&cid, &format!(".<cid:{}>", cid), model.model_enum, image.as_ref())
.await .await
{ {
Ok((200, _)) => { Ok((200, _)) => {
@@ -775,7 +773,7 @@ async fn chat_completions_stream(
))); )));
} }
yield Ok(Event::default().data("[DONE]")); yield Ok(Event::default().data("[DONE]"));
state.mitm_store.drop_channel().await; state.mitm_store.remove_request(&cascade_id).await;
return; return;
} }
MitmEvent::ResponseComplete => { MitmEvent::ResponseComplete => {
@@ -803,15 +801,15 @@ async fn chat_completions_stream(
))); )));
} }
yield Ok(Event::default().data("[DONE]")); yield Ok(Event::default().data("[DONE]"));
state.mitm_store.drop_channel().await; state.mitm_store.remove_request(&cascade_id).await;
return; return;
} else if !acc_thinking.is_empty() && !did_unblock_ls { } else if !acc_thinking.is_empty() && !did_unblock_ls {
// Thinking-only response — LS needs follow-up API calls. // Thinking-only response — LS needs follow-up API calls.
// Create a new channel and unblock the gate. // Create a new channel and unblock the gate.
did_unblock_ls = true; did_unblock_ls = true;
let (new_tx, new_rx) = tokio::sync::mpsc::channel(64); let (new_tx, new_rx) = tokio::sync::mpsc::channel(64);
state.mitm_store.set_channel(new_tx).await; state.mitm_store.set_channel(&cascade_id, new_tx).await;
state.mitm_store.clear_request_in_flight();
let _ = state.mitm_store.take_any_function_calls().await; let _ = state.mitm_store.take_any_function_calls().await;
*rx = new_rx; *rx = new_rx;
debug!( debug!(
@@ -845,7 +843,7 @@ async fn chat_completions_stream(
))); )));
} }
yield Ok(Event::default().data("[DONE]")); yield Ok(Event::default().data("[DONE]"));
state.mitm_store.drop_channel().await; state.mitm_store.remove_request(&cascade_id).await;
return; return;
} }
// Don't break — wait for more channel events // Don't break — wait for more channel events
@@ -861,7 +859,7 @@ async fn chat_completions_stream(
None, None,
))); )));
yield Ok(Event::default().data("[DONE]")); yield Ok(Event::default().data("[DONE]"));
state.mitm_store.drop_channel().await; state.mitm_store.remove_request(&cascade_id).await;
return; return;
} }
continue 'channel_loop; continue 'channel_loop;
@@ -878,7 +876,7 @@ async fn chat_completions_stream(
} }
})).unwrap())); })).unwrap()));
yield Ok(Event::default().data("[DONE]")); yield Ok(Event::default().data("[DONE]"));
state.mitm_store.drop_channel().await; state.mitm_store.remove_request(&cascade_id).await;
return; return;
} }
MitmEvent::Usage(u) => { MitmEvent::Usage(u) => {
@@ -891,7 +889,7 @@ async fn chat_completions_stream(
} }
// Channel closed or timeout — clean up // Channel closed or timeout — clean up
state.mitm_store.drop_channel().await; state.mitm_store.remove_request(&cascade_id).await;
// If we got here from timeout with content, emit what we have // If we got here from timeout with content, emit what we have
if !last_text.is_empty() || last_thinking_len > 0 { if !last_text.is_empty() || last_thinking_len > 0 {
@@ -1026,7 +1024,7 @@ async fn chat_completions_stream(
} }
})).unwrap())); })).unwrap()));
// Always clear in-flight flag when stream ends // Always clear in-flight flag when stream ends
state.mitm_store.drop_channel().await; state.mitm_store.remove_request(&cascade_id).await;
yield Ok(Event::default().data("[DONE]")); yield Ok(Event::default().data("[DONE]"));
}; };

View File

@@ -16,7 +16,7 @@ use axum::{
}; };
use rand::Rng; use rand::Rng;
use std::sync::Arc; use std::sync::Arc;
use tracing::{debug, info, warn}; use tracing::{debug, info};
use super::models::{lookup_model, DEFAULT_MODEL, MODELS}; use super::models::{lookup_model, DEFAULT_MODEL, MODELS};
use super::polling::{ use super::polling::{
@@ -40,6 +40,7 @@ pub(crate) struct GeminiRequest {
pub tool_config: Option<serde_json::Value>, pub tool_config: Option<serde_json::Value>,
/// Session/conversation ID. /// Session/conversation ID.
#[serde(default)] #[serde(default)]
#[allow(dead_code)]
pub conversation: Option<serde_json::Value>, pub conversation: Option<serde_json::Value>,
#[serde(default = "default_timeout")] #[serde(default = "default_timeout")]
pub timeout: u64, pub timeout: u64,
@@ -81,17 +82,8 @@ pub(crate) struct GeminiRequest {
pub response_schema: Option<serde_json::Value>, pub response_schema: Option<serde_json::Value>,
} }
fn default_timeout() -> u64 { use super::util::default_timeout;
120
}
fn extract_conversation_id(conv: &Option<serde_json::Value>) -> Option<String> {
match conv {
Some(serde_json::Value::String(s)) => Some(s.clone()),
Some(obj) => obj["id"].as_str().map(|s| s.to_string()),
None => None,
}
}
/// Build Gemini-format usageMetadata from MITM store. /// Build Gemini-format usageMetadata from MITM store.
async fn build_usage_metadata( async fn build_usage_metadata(
@@ -247,64 +239,37 @@ async fn handle_gemini_inner(
} }
}; };
// Store tools directly in Gemini format (no conversion needed!) // ── Build per-request state locally ──────────────────────────────────
if let Some(ref tools) = body.tools {
if !tools.is_empty() {
state.mitm_store.set_tools(tools.clone()).await;
info!(
count = tools.len(),
"Stored Gemini-native tools for MITM injection"
);
} else {
state.mitm_store.clear_tools().await;
}
} else {
state.mitm_store.clear_tools().await;
}
if let Some(ref config) = body.tool_config {
state.mitm_store.set_tool_config(config.clone()).await;
}
// Handle tool results (Gemini format: functionResponse) // Tools (already in Gemini format)
let tools = body.tools.as_ref().and_then(|t| {
if t.is_empty() { None } else {
info!(count = t.len(), "Gemini-native tools for MITM injection");
Some(t.clone())
}
});
let tool_config = body.tool_config.clone();
// Tool results → collect (ToolRound built after cascade_id is known)
let mut pending_tool_results: Vec<PendingToolResult> = Vec::new();
if let Some(ref results) = body.tool_results { if let Some(ref results) = body.tool_results {
let mut pending: Vec<PendingToolResult> = Vec::new();
for r in results { for r in results {
if let Some(fr) = r.get("functionResponse") { if let Some(fr) = r.get("functionResponse") {
let name = fr["name"].as_str().unwrap_or("unknown").to_string(); let name = fr["name"].as_str().unwrap_or("unknown").to_string();
let response = fr.get("response").cloned().unwrap_or(serde_json::json!({})); let response = fr.get("response").cloned().unwrap_or(serde_json::json!({}));
// Legacy compat pending_tool_results.push(PendingToolResult {
state
.mitm_store
.add_tool_result(PendingToolResult {
name: name.clone(),
result: response.clone(),
})
.await;
pending.push(PendingToolResult {
name, name,
result: response, result: response,
}); });
} }
} }
if !pending.is_empty() {
// Build a ToolRound from captured function calls + client results.
// Accumulate with existing rounds for multi-round history rewriting.
let last_calls = state.mitm_store.get_last_function_calls().await;
let mut rounds = state.mitm_store.take_tool_rounds().await;
rounds.push(crate::mitm::store::ToolRound {
calls: last_calls,
results: pending,
});
state.mitm_store.set_tool_rounds(rounds).await;
}
info!( info!(
count = results.len(), count = results.len(),
"Stored Gemini-native tool results for MITM injection (built tool round)" "Gemini-native tool results (will build tool round after cascade_id)"
); );
} }
// Store generation parameters for MITM injection // Generation parameters
{
use crate::mitm::store::GenerationParams; use crate::mitm::store::GenerationParams;
let gp = GenerationParams { let gp = GenerationParams {
temperature: body.temperature, temperature: body.temperature,
@@ -319,7 +284,7 @@ async fn handle_gemini_inner(
response_schema: body.response_schema.clone(), response_schema: body.response_schema.clone(),
google_search: body.google_search, google_search: body.google_search,
}; };
if gp.temperature.is_some() let generation_params = if gp.temperature.is_some()
|| gp.top_p.is_some() || gp.top_p.is_some()
|| gp.top_k.is_some() || gp.top_k.is_some()
|| gp.max_output_tokens.is_some() || gp.max_output_tokens.is_some()
@@ -329,31 +294,13 @@ async fn handle_gemini_inner(
|| gp.response_schema.is_some() || gp.response_schema.is_some()
|| gp.google_search || gp.google_search
{ {
state.mitm_store.set_generation_params(gp).await; Some(gp)
} else { } else {
state.mitm_store.clear_generation_params().await; None
} };
}
// Session/conversation management // Always create a new cascade for every request
let session_id_str = extract_conversation_id(&body.conversation); let cascade_id = match state.backend.create_cascade().await {
let cascade_id = if let Some(ref sid) = session_id_str {
match state
.sessions
.get_or_create(Some(sid), || state.backend.create_cascade())
.await
{
Ok(sr) => sr.cascade_id,
Err(e) => {
return err_response(
StatusCode::BAD_GATEWAY,
format!("StartCascade failed: {e}"),
"server_error",
);
}
}
} else {
match state.backend.create_cascade().await {
Ok(cid) => cid, Ok(cid) => cid,
Err(e) => { Err(e) => {
return err_response( return err_response(
@@ -362,42 +309,57 @@ async fn handle_gemini_inner(
"server_error", "server_error",
); );
} }
}
}; };
// Send message // Image for MITM injection
state.mitm_store.set_active_cascade(&cascade_id).await; let pending_image = image.as_ref().map(|img| {
// Store real user text for MITM injection — LS gets a dummy prompt
state.mitm_store.set_pending_user_text(user_text.clone()).await;
// Store image for MITM injection (LS doesn't forward images to Google API)
if let Some(ref img) = image {
use base64::Engine; use base64::Engine;
state crate::mitm::store::PendingImage {
.mitm_store
.set_pending_image(crate::mitm::store::PendingImage {
base64_data: base64::engine::general_purpose::STANDARD.encode(&img.data), base64_data: base64::engine::general_purpose::STANDARD.encode(&img.data),
mime_type: img.mime_type.clone(), mime_type: img.mime_type.clone(),
})
.await;
} }
});
// Pre-flight: install channel BEFORE send_message so the MITM proxy // Build event channel for streaming
// can grab it when the LS fires its API call. let has_custom_tools = tools.is_some();
let has_custom_tools = state.mitm_store.get_tools().await.is_some(); let (mitm_rx, event_tx) = if has_custom_tools {
let mitm_rx = if has_custom_tools {
state.mitm_store.clear_response_async().await;
state.mitm_store.clear_upstream_error().await;
let _ = state.mitm_store.take_any_function_calls().await;
let (tx, rx) = tokio::sync::mpsc::channel(64); let (tx, rx) = tokio::sync::mpsc::channel(64);
state.mitm_store.set_channel(tx).await; (Some(rx), Some(tx))
Some(rx)
} else { } else {
None (None, None)
}; };
// Build tool rounds now that cascade_id is known
let mut tool_rounds: Vec<crate::mitm::store::ToolRound> = Vec::new();
if !pending_tool_results.is_empty() {
let last_calls = state.mitm_store.take_function_calls(&cascade_id).await
.unwrap_or_default();
tool_rounds.push(crate::mitm::store::ToolRound {
calls: last_calls,
results: pending_tool_results.clone(),
});
}
// Register all per-request state atomically
state.mitm_store.register_request(crate::mitm::store::RequestContext {
cascade_id: cascade_id.clone(),
pending_user_text: user_text.clone(),
event_channel: event_tx,
generation_params,
pending_image,
tools,
tool_config,
pending_tool_results,
tool_rounds,
last_function_calls: Vec::new(),
call_id_to_name: std::collections::HashMap::new(),
created_at: std::time::Instant::now(),
}).await;
// Send a placeholder prompt tagged with the cascade ID to LS; the real user text travels in the registered RequestContext (pending_user_text)
match state match state
.backend .backend
.send_message_with_image(&cascade_id, ".", model.model_enum, image.as_ref()) .send_message_with_image(&cascade_id, &format!(".<cid:{}>", cascade_id), model.model_enum, image.as_ref())
.await .await
{ {
Ok((200, _)) => { Ok((200, _)) => {
@@ -408,7 +370,7 @@ async fn handle_gemini_inner(
}); });
} }
Ok((status, _)) => { Ok((status, _)) => {
state.mitm_store.drop_channel().await; state.mitm_store.remove_request(&cascade_id).await;
return err_response( return err_response(
StatusCode::BAD_GATEWAY, StatusCode::BAD_GATEWAY,
format!("Antigravity returned {status}"), format!("Antigravity returned {status}"),
@@ -416,7 +378,7 @@ async fn handle_gemini_inner(
); );
} }
Err(e) => { Err(e) => {
state.mitm_store.drop_channel().await; state.mitm_store.remove_request(&cascade_id).await;
return err_response( return err_response(
StatusCode::BAD_GATEWAY, StatusCode::BAD_GATEWAY,
format!("Send message failed: {e}"), format!("Send message failed: {e}"),
@@ -478,7 +440,7 @@ async fn gemini_sync(
}) })
}) })
.collect(); .collect();
state.mitm_store.drop_channel().await; state.mitm_store.remove_request(&cascade_id).await;
return Json(serde_json::json!({ return Json(serde_json::json!({
"candidates": [{ "candidates": [{
"content": { "content": {
@@ -500,8 +462,8 @@ async fn gemini_sync(
// Thinking-only — LS needs to make a follow-up request. // Thinking-only — LS needs to make a follow-up request.
// Reinstall channel and unblock gate. // Reinstall channel and unblock gate.
let (new_tx, new_rx) = tokio::sync::mpsc::channel(64); let (new_tx, new_rx) = tokio::sync::mpsc::channel(64);
state.mitm_store.set_channel(new_tx).await; state.mitm_store.set_channel(&cascade_id, new_tx).await;
state.mitm_store.clear_request_in_flight();
let _ = state.mitm_store.take_any_function_calls().await; let _ = state.mitm_store.take_any_function_calls().await;
rx = new_rx; rx = new_rx;
debug!( debug!(
@@ -515,7 +477,7 @@ async fn gemini_sync(
parts.push(serde_json::json!({"text": t, "thought": true})); parts.push(serde_json::json!({"text": t, "thought": true}));
} }
parts.push(serde_json::json!({"text": acc_text})); parts.push(serde_json::json!({"text": acc_text}));
state.mitm_store.drop_channel().await; state.mitm_store.remove_request(&cascade_id).await;
return Json(serde_json::json!({ return Json(serde_json::json!({
"candidates": [{ "candidates": [{
"content": { "content": {
@@ -530,14 +492,14 @@ async fn gemini_sync(
.into_response(); .into_response();
} }
MitmEvent::UpstreamError(err) => { MitmEvent::UpstreamError(err) => {
state.mitm_store.drop_channel().await; state.mitm_store.remove_request(&cascade_id).await;
return upstream_err_response(&err); return upstream_err_response(&err);
} }
} }
} }
// Timeout // Timeout
state.mitm_store.drop_channel().await; state.mitm_store.remove_request(&cascade_id).await;
return ( return (
axum::http::StatusCode::GATEWAY_TIMEOUT, axum::http::StatusCode::GATEWAY_TIMEOUT,
Json(serde_json::json!({ Json(serde_json::json!({
@@ -703,7 +665,7 @@ async fn gemini_stream(
"modelVersion": model_name, "modelVersion": model_name,
})).unwrap_or_default())); })).unwrap_or_default()));
yield Ok(Event::default().data("[DONE]")); yield Ok(Event::default().data("[DONE]"));
state.mitm_store.drop_channel().await; state.mitm_store.remove_request(&cascade_id).await;
return; return;
} }
MitmEvent::ResponseComplete => { MitmEvent::ResponseComplete => {
@@ -722,15 +684,15 @@ async fn gemini_stream(
"modelVersion": model_name, "modelVersion": model_name,
})).unwrap_or_default())); })).unwrap_or_default()));
yield Ok(Event::default().data("[DONE]")); yield Ok(Event::default().data("[DONE]"));
state.mitm_store.drop_channel().await; state.mitm_store.remove_request(&cascade_id).await;
return; return;
} else if !last_thinking.is_empty() && !did_unblock_ls { } else if !last_thinking.is_empty() && !did_unblock_ls {
// Thinking-only response — LS needs follow-up API calls. // Thinking-only response — LS needs follow-up API calls.
// Create a new channel and unblock the gate. // Create a new channel and unblock the gate.
did_unblock_ls = true; did_unblock_ls = true;
let (new_tx, new_rx) = tokio::sync::mpsc::channel(64); let (new_tx, new_rx) = tokio::sync::mpsc::channel(64);
state.mitm_store.set_channel(new_tx).await; state.mitm_store.set_channel(&cascade_id, new_tx).await;
state.mitm_store.clear_request_in_flight();
let _ = state.mitm_store.take_any_function_calls().await; let _ = state.mitm_store.take_any_function_calls().await;
rx = new_rx; rx = new_rx;
debug!( debug!(
@@ -752,7 +714,7 @@ async fn gemini_stream(
} }
})).unwrap())); })).unwrap()));
yield Ok(Event::default().data("[DONE]")); yield Ok(Event::default().data("[DONE]"));
state.mitm_store.drop_channel().await; state.mitm_store.remove_request(&cascade_id).await;
return; return;
} }
MitmEvent::Usage(_) | MitmEvent::Grounding(_) => {} MitmEvent::Usage(_) | MitmEvent::Grounding(_) => {}
@@ -760,7 +722,7 @@ async fn gemini_stream(
} }
// Timeout or channel closed // Timeout or channel closed
state.mitm_store.drop_channel().await; state.mitm_store.remove_request(&cascade_id).await;
yield Ok::<_, std::convert::Infallible>(Event::default().data(serde_json::to_string(&serde_json::json!({ yield Ok::<_, std::convert::Infallible>(Event::default().data(serde_json::to_string(&serde_json::json!({
"error": { "error": {
"message": format!("Timeout: no response from Google API after {timeout}s"), "message": format!("Timeout: no response from Google API after {timeout}s"),

View File

@@ -5,6 +5,7 @@ mod gemini;
mod models; mod models;
mod polling; mod polling;
mod responses; mod responses;
mod search;
mod types; mod types;
mod util; mod util;
@@ -48,6 +49,8 @@ pub fn router(state: Arc<AppState>) -> Router {
post(gemini::handle_gemini_v1beta), post(gemini::handle_gemini_v1beta),
) )
.route("/v1/models", get(handle_models)) .route("/v1/models", get(handle_models))
.route("/v1/search", get(search::handle_search_get))
.route("/v1/search", post(search::handle_search_post))
.route("/v1/sessions", get(handle_list_sessions)) .route("/v1/sessions", get(handle_list_sessions))
.route("/v1/sessions/{id}", delete(handle_delete_session)) .route("/v1/sessions/{id}", delete(handle_delete_session))
.route("/v1/token", post(handle_set_token)) .route("/v1/token", post(handle_set_token))

View File

@@ -142,14 +142,9 @@ fn extract_responses_input(
(final_text, tool_results, image) (final_text, tool_results, image)
} }
/// Pull the conversation/session ID out of the Responses API `conversation` field.
///
/// The field may be a bare JSON string (returned as-is) or any other JSON
/// value, in which case its `"id"` member is used when that member is a string.
/// Returns `None` when the field is absent or no string ID can be found.
fn extract_conversation_id(conv: &Option<serde_json::Value>) -> Option<String> {
    let value = conv.as_ref()?;
    if let serde_json::Value::String(id) = value {
        return Some(id.clone());
    }
    // Indexing a non-object value yields `Null`, so `as_str()` is `None` here.
    value["id"].as_str().map(str::to_owned)
}
/// Response-specific data for building a Response object. /// Response-specific data for building a Response object.
struct ResponseData { struct ResponseData {
@@ -241,47 +236,26 @@ pub(crate) async fn handle_responses(
// Handle tool result submission (function_call_output in input) // Handle tool result submission (function_call_output in input)
let is_tool_result_turn = !tool_results.is_empty(); let is_tool_result_turn = !tool_results.is_empty();
if is_tool_result_turn { let mut pending_tool_results: Vec<PendingToolResult> = Vec::new();
let mut pending: Vec<PendingToolResult> = Vec::new();
for tr in &tool_results { if is_tool_result_turn {
// Look up function name from call_id for tr in &tool_results {
let name = state // For tool result turns, we use the call_id as the name directly.
.mitm_store // The proxy captured function calls (with real names) are paired in
.lookup_call_id(&tr.call_id) // the ToolRound when we know the cascade_id later.
.await let name = tr.call_id.clone();
.unwrap_or_else(|| "unknown_function".to_string());
// Parse the output as JSON, fall back to string wrapper
let result_value = serde_json::from_str::<serde_json::Value>(&tr.output) let result_value = serde_json::from_str::<serde_json::Value>(&tr.output)
.unwrap_or_else(|_| serde_json::json!({"result": tr.output})); .unwrap_or_else(|_| serde_json::json!({"result": tr.output}));
// Also store as pending (legacy compat) pending_tool_results.push(PendingToolResult {
state
.mitm_store
.add_tool_result(PendingToolResult {
name: name.clone(),
result: result_value.clone(),
})
.await;
pending.push(PendingToolResult {
name, name,
result: result_value, result: result_value,
}); });
} }
// Build a ToolRound from the MITM-captured function calls + client results.
// get_last_function_calls() has the calls from Google's previous response.
// We take existing accumulated rounds and append this new round.
let last_calls = state.mitm_store.get_last_function_calls().await;
let mut rounds = state.mitm_store.take_tool_rounds().await;
rounds.push(crate::mitm::store::ToolRound {
calls: last_calls,
results: pending,
});
state.mitm_store.set_tool_rounds(rounds).await;
info!( info!(
count = tool_results.len(), count = tool_results.len(),
"Stored tool results for MITM injection (built tool round)" "Tool results for MITM injection (will build tool round after cascade_id)"
); );
} }
@@ -293,7 +267,8 @@ pub(crate) async fn handle_responses(
); );
} }
// Store client tools in MitmStore for MITM injection // ── Build per-request state locally ──────────────────────────────────
// Detect web_search_preview tool (OpenAI spec) → enable Google Search grounding // Detect web_search_preview tool (OpenAI spec) → enable Google Search grounding
let has_web_search = body.tools.as_ref().map_or(false, |tools| { let has_web_search = body.tools.as_ref().map_or(false, |tools| {
tools.iter().any(|t| { tools.iter().any(|t| {
@@ -301,27 +276,20 @@ pub(crate) async fn handle_responses(
t_type == "web_search_preview" || t_type == "web_search" t_type == "web_search_preview" || t_type == "web_search"
}) })
}); });
if let Some(ref tools) = body.tools {
let gemini_tools = openai_tools_to_gemini(tools);
if !gemini_tools.is_empty() {
state.mitm_store.set_tools(gemini_tools).await;
info!(
count = tools.len(),
"Stored client tools for MITM injection"
);
} else {
state.mitm_store.clear_tools().await;
}
} else {
state.mitm_store.clear_tools().await;
}
if let Some(ref choice) = body.tool_choice {
let gemini_config = openai_tool_choice_to_gemini(choice);
state.mitm_store.set_tool_config(gemini_config).await;
}
// Store generation parameters for MITM injection // Convert OpenAI tools to Gemini format
// Extract text.format for structured output (json_schema) let tools = body.tools.as_ref().and_then(|t| {
let gemini_tools = openai_tools_to_gemini(t);
if gemini_tools.is_empty() { None } else {
info!(count = t.len(), "Client tools for MITM injection");
Some(gemini_tools)
}
});
let tool_config = body.tool_choice.as_ref().map(|choice| {
openai_tool_choice_to_gemini(choice)
});
// Build generation params locally
let (response_mime_type, response_schema, text_format) = if let Some(ref text_val) = body.text { let (response_mime_type, response_schema, text_format) = if let Some(ref text_val) = body.text {
let fmt_type = text_val["format"]["type"].as_str().unwrap_or("text"); let fmt_type = text_val["format"]["type"].as_str().unwrap_or("text");
if fmt_type == "json_schema" { if fmt_type == "json_schema" {
@@ -345,7 +313,7 @@ pub(crate) async fn handle_responses(
} else { } else {
(None, None, TextFormat::default()) (None, None, TextFormat::default())
}; };
{
use crate::mitm::store::GenerationParams; use crate::mitm::store::GenerationParams;
let gp = GenerationParams { let gp = GenerationParams {
temperature: body.temperature, temperature: body.temperature,
@@ -360,7 +328,7 @@ pub(crate) async fn handle_responses(
response_schema, response_schema,
google_search: has_web_search, google_search: has_web_search,
}; };
if gp.temperature.is_some() let generation_params = if gp.temperature.is_some()
|| gp.top_p.is_some() || gp.top_p.is_some()
|| gp.max_output_tokens.is_some() || gp.max_output_tokens.is_some()
|| gp.reasoning_effort.is_some() || gp.reasoning_effort.is_some()
@@ -368,33 +336,15 @@ pub(crate) async fn handle_responses(
|| gp.response_schema.is_some() || gp.response_schema.is_some()
|| gp.google_search || gp.google_search
{ {
state.mitm_store.set_generation_params(gp).await; Some(gp)
} else { } else {
state.mitm_store.clear_generation_params().await; None
} };
}
let response_id = format!("resp_{}", uuid::Uuid::new_v4().to_string().replace('-', "")); let response_id = format!("resp_{}", uuid::Uuid::new_v4().to_string().replace('-', ""));
// Session/conversation management // Always create a new cascade for every request
let session_id_str = extract_conversation_id(&body.conversation); let cascade_id = match state.backend.create_cascade().await {
let cascade_id = if let Some(ref sid) = session_id_str {
match state
.sessions
.get_or_create(Some(sid), || state.backend.create_cascade())
.await
{
Ok(sr) => sr.cascade_id,
Err(e) => {
return err_response(
StatusCode::BAD_GATEWAY,
format!("StartCascade failed: {e}"),
"server_error",
);
}
}
} else {
match state.backend.create_cascade().await {
Ok(cid) => cid, Ok(cid) => cid,
Err(e) => { Err(e) => {
return err_response( return err_response(
@@ -403,42 +353,58 @@ pub(crate) async fn handle_responses(
"server_error", "server_error",
); );
} }
}
}; };
// Send message // Image for MITM injection
state.mitm_store.set_active_cascade(&cascade_id).await; let pending_image = image.as_ref().map(|img| {
// Store real user text for MITM injection — LS gets a dummy prompt
state.mitm_store.set_pending_user_text(user_text.clone()).await;
// Store image for MITM injection (LS doesn't forward images to Google API)
if let Some(ref img) = image {
use base64::Engine; use base64::Engine;
state crate::mitm::store::PendingImage {
.mitm_store
.set_pending_image(crate::mitm::store::PendingImage {
base64_data: base64::engine::general_purpose::STANDARD.encode(&img.data), base64_data: base64::engine::general_purpose::STANDARD.encode(&img.data),
mime_type: img.mime_type.clone(), mime_type: img.mime_type.clone(),
})
.await;
} }
});
// Pre-flight: install channel BEFORE send_message so the MITM proxy // Build event channel
// can grab it when the LS fires its API call. let has_custom_tools = tools.is_some();
let has_custom_tools = state.mitm_store.get_tools().await.is_some(); let (mitm_rx, event_tx) = if has_custom_tools {
let mitm_rx = if has_custom_tools {
state.mitm_store.clear_response_async().await;
state.mitm_store.clear_upstream_error().await;
let _ = state.mitm_store.take_any_function_calls().await;
let (tx, rx) = tokio::sync::mpsc::channel(64); let (tx, rx) = tokio::sync::mpsc::channel(64);
state.mitm_store.set_channel(tx).await; (Some(rx), Some(tx))
Some(rx)
} else { } else {
None (None, None)
}; };
// Build tool rounds now that cascade_id is known
let mut tool_rounds: Vec<crate::mitm::store::ToolRound> = Vec::new();
if is_tool_result_turn && !pending_tool_results.is_empty() {
// Get last captured function calls from the previous request context
let last_calls = state.mitm_store.take_function_calls(&cascade_id).await
.unwrap_or_default();
tool_rounds.push(crate::mitm::store::ToolRound {
calls: last_calls,
results: pending_tool_results.clone(),
});
}
// Register all per-request state atomically
state.mitm_store.register_request(crate::mitm::store::RequestContext {
cascade_id: cascade_id.clone(),
pending_user_text: user_text.clone(),
event_channel: event_tx,
generation_params,
pending_image,
tools,
tool_config,
pending_tool_results,
tool_rounds,
last_function_calls: Vec::new(),
call_id_to_name: std::collections::HashMap::new(),
created_at: std::time::Instant::now(),
}).await;
// Send REAL user text to LS
match state match state
.backend .backend
.send_message_with_image(&cascade_id, ".", model.model_enum, image.as_ref()) .send_message_with_image(&cascade_id, &format!(".<cid:{}>", cascade_id), model.model_enum, image.as_ref())
.await .await
{ {
Ok((200, _)) => { Ok((200, _)) => {
@@ -449,7 +415,7 @@ pub(crate) async fn handle_responses(
}); });
} }
Ok((status, _)) => { Ok((status, _)) => {
state.mitm_store.drop_channel().await; state.mitm_store.remove_request(&cascade_id).await;
return err_response( return err_response(
StatusCode::BAD_GATEWAY, StatusCode::BAD_GATEWAY,
format!("Antigravity returned {status}"), format!("Antigravity returned {status}"),
@@ -457,7 +423,7 @@ pub(crate) async fn handle_responses(
); );
} }
Err(e) => { Err(e) => {
state.mitm_store.drop_channel().await; state.mitm_store.remove_request(&cascade_id).await;
return err_response( return err_response(
StatusCode::BAD_GATEWAY, StatusCode::BAD_GATEWAY,
format!("Send message failed: {e}"), format!("Send message failed: {e}"),
@@ -644,7 +610,7 @@ async fn handle_responses_sync(
let mut acc_text = String::new(); let mut acc_text = String::new();
let mut acc_thinking: Option<String> = None; let mut acc_thinking: Option<String> = None;
let mut last_usage: Option<crate::mitm::store::ApiUsage> = None; let mut _last_usage: Option<crate::mitm::store::ApiUsage> = None;
while let Some(event) = tokio::time::timeout( while let Some(event) = tokio::time::timeout(
std::time::Duration::from_secs(timeout.saturating_sub(start.elapsed().as_secs())), std::time::Duration::from_secs(timeout.saturating_sub(start.elapsed().as_secs())),
@@ -654,7 +620,7 @@ async fn handle_responses_sync(
match event { match event {
MitmEvent::ThinkingDelta(t) => { acc_thinking = Some(t); } MitmEvent::ThinkingDelta(t) => { acc_thinking = Some(t); }
MitmEvent::TextDelta(t) => { acc_text = t; } MitmEvent::TextDelta(t) => { acc_text = t; }
MitmEvent::Usage(u) => { last_usage = Some(u); } MitmEvent::Usage(u) => { _last_usage = Some(u); }
MitmEvent::Grounding(_) => {} // stored by proxy directly MitmEvent::Grounding(_) => {} // stored by proxy directly
MitmEvent::FunctionCall(raw_calls) => { MitmEvent::FunctionCall(raw_calls) => {
let calls: Vec<_> = if let Some(max) = params.max_tool_calls { let calls: Vec<_> = if let Some(max) = params.max_tool_calls {
@@ -668,14 +634,14 @@ async fn handle_responses_sync(
"call_{}", "call_{}",
uuid::Uuid::new_v4().to_string().replace('-', "")[..24].to_string() uuid::Uuid::new_v4().to_string().replace('-', "")[..24].to_string()
); );
state.mitm_store.register_call_id(call_id.clone(), fc.name.clone()).await; state.mitm_store.register_call_id(&cascade_id, call_id.clone(), fc.name.clone()).await;
let arguments = serde_json::to_string(&fc.args).unwrap_or_default(); let arguments = serde_json::to_string(&fc.args).unwrap_or_default();
output_items.push(build_function_call_output(&call_id, &fc.name, &arguments)); output_items.push(build_function_call_output(&call_id, &fc.name, &arguments));
} }
let (usage, _) = usage_from_poll( let (usage, _) = usage_from_poll(
&state.mitm_store, &cascade_id, &None, &params.user_text, "", &state.mitm_store, &cascade_id, &None, &params.user_text, "",
).await; ).await;
state.mitm_store.drop_channel().await; state.mitm_store.remove_request(&cascade_id).await;
let resp = build_response_object( let resp = build_response_object(
ResponseData { ResponseData {
id: response_id, id: response_id,
@@ -700,8 +666,8 @@ async fn handle_responses_sync(
// Thinking-only — LS needs to make a follow-up request. // Thinking-only — LS needs to make a follow-up request.
// Reinstall channel and unblock gate. // Reinstall channel and unblock gate.
let (new_tx, new_rx) = tokio::sync::mpsc::channel(64); let (new_tx, new_rx) = tokio::sync::mpsc::channel(64);
state.mitm_store.set_channel(new_tx).await; state.mitm_store.set_channel(&cascade_id, new_tx).await;
state.mitm_store.clear_request_in_flight();
let _ = state.mitm_store.take_any_function_calls().await; let _ = state.mitm_store.take_any_function_calls().await;
rx = new_rx; rx = new_rx;
debug!( debug!(
@@ -713,7 +679,7 @@ async fn handle_responses_sync(
let (usage, _) = usage_from_poll( let (usage, _) = usage_from_poll(
&state.mitm_store, &cascade_id, &None, &params.user_text, &acc_text, &state.mitm_store, &cascade_id, &None, &params.user_text, &acc_text,
).await; ).await;
state.mitm_store.drop_channel().await; state.mitm_store.remove_request(&cascade_id).await;
let mut output_items: Vec<serde_json::Value> = Vec::new(); let mut output_items: Vec<serde_json::Value> = Vec::new();
if let Some(ref t) = acc_thinking { if let Some(ref t) = acc_thinking {
@@ -738,14 +704,14 @@ async fn handle_responses_sync(
return Json(resp).into_response(); return Json(resp).into_response();
} }
MitmEvent::UpstreamError(err) => { MitmEvent::UpstreamError(err) => {
state.mitm_store.drop_channel().await; state.mitm_store.remove_request(&cascade_id).await;
return upstream_err_response(&err); return upstream_err_response(&err);
} }
} }
} }
// Timeout // Timeout
state.mitm_store.drop_channel().await; state.mitm_store.remove_request(&cascade_id).await;
return err_response( return err_response(
StatusCode::GATEWAY_TIMEOUT, StatusCode::GATEWAY_TIMEOUT,
format!("Timeout: no response from Google API after {timeout}s"), format!("Timeout: no response from Google API after {timeout}s"),
@@ -789,7 +755,7 @@ async fn handle_responses_sync(
// Register call_id → name mapping for tool result routing // Register call_id → name mapping for tool result routing
state state
.mitm_store .mitm_store
.register_call_id(call_id.clone(), fc.name.clone()) .register_call_id(&cascade_id, call_id.clone(), fc.name.clone())
.await; .await;
// Stringify args (OpenAI sends arguments as JSON string) // Stringify args (OpenAI sends arguments as JSON string)
@@ -1092,7 +1058,7 @@ async fn handle_responses_stream(
uuid::Uuid::new_v4().to_string().replace('-', "")[..24].to_string() uuid::Uuid::new_v4().to_string().replace('-', "")[..24].to_string()
); );
let arguments = serde_json::to_string(&fc.args).unwrap_or_default(); let arguments = serde_json::to_string(&fc.args).unwrap_or_default();
state.mitm_store.register_call_id(call_id.clone(), fc.name.clone()).await; state.mitm_store.register_call_id(&cascade_id, call_id.clone(), fc.name.clone()).await;
let fc_item_id = format!("fc_{}", uuid::Uuid::new_v4().to_string().replace('-', "")); let fc_item_id = format!("fc_{}", uuid::Uuid::new_v4().to_string().replace('-', ""));
yield Ok(responses_sse_event( yield Ok(responses_sse_event(
@@ -1166,7 +1132,7 @@ async fn handle_responses_stream(
"response": response_to_json(&final_resp), "response": response_to_json(&final_resp),
}), }),
)); ));
state.mitm_store.drop_channel().await; state.mitm_store.remove_request(&cascade_id).await;
return; return;
} }
MitmEvent::ResponseComplete => { MitmEvent::ResponseComplete => {
@@ -1184,14 +1150,14 @@ async fn handle_responses_stream(
) { ) {
yield Ok(evt); yield Ok(evt);
} }
state.mitm_store.drop_channel().await; state.mitm_store.remove_request(&cascade_id).await;
return; return;
} else if !last_thinking.is_empty() { } else if !last_thinking.is_empty() {
// Thinking-only response — LS needs follow-up API calls. // Thinking-only response — LS needs follow-up API calls.
// Create a new channel and unblock the gate. // Create a new channel and unblock the gate.
let (new_tx, new_rx) = tokio::sync::mpsc::channel(64); let (new_tx, new_rx) = tokio::sync::mpsc::channel(64);
state.mitm_store.set_channel(new_tx).await; state.mitm_store.set_channel(&cascade_id, new_tx).await;
state.mitm_store.clear_request_in_flight();
let _ = state.mitm_store.take_any_function_calls().await; let _ = state.mitm_store.take_any_function_calls().await;
rx = new_rx; rx = new_rx;
debug!( debug!(
@@ -1220,7 +1186,7 @@ async fn handle_responses_stream(
}, },
}), }),
)); ));
state.mitm_store.drop_channel().await; state.mitm_store.remove_request(&cascade_id).await;
return; return;
} }
MitmEvent::Usage(_) | MitmEvent::Grounding(_) => { MitmEvent::Usage(_) | MitmEvent::Grounding(_) => {
@@ -1230,7 +1196,7 @@ async fn handle_responses_stream(
} }
// Timeout in channel mode // Timeout in channel mode
state.mitm_store.drop_channel().await; state.mitm_store.remove_request(&cascade_id).await;
yield Ok(responses_sse_event( yield Ok(responses_sse_event(
"response.failed", "response.failed",
serde_json::json!({ serde_json::json!({

View File

@@ -33,6 +33,7 @@ pub(crate) struct SearchRequest {
pub timeout: u64, pub timeout: u64,
/// Conversation/session ID for context reuse. /// Conversation/session ID for context reuse.
#[serde(default)] #[serde(default)]
#[allow(dead_code)]
pub conversation: Option<String>, pub conversation: Option<String>,
/// Max output tokens — keep low since we only want grounding metadata. /// Max output tokens — keep low since we only want grounding metadata.
#[serde(default = "default_search_max_tokens")] #[serde(default = "default_search_max_tokens")]
@@ -111,19 +112,13 @@ async fn do_search(state: Arc<AppState>, body: SearchRequest) -> axum::response:
); );
} }
// Enable Google Search grounding via GenerationParams // Build generation params with Google Search grounding enabled
{
use crate::mitm::store::GenerationParams; use crate::mitm::store::GenerationParams;
let gp = GenerationParams { let gp = GenerationParams {
max_output_tokens: Some(body.max_output_tokens), max_output_tokens: Some(body.max_output_tokens),
google_search: true, google_search: true,
..Default::default() ..Default::default()
}; };
state.mitm_store.set_generation_params(gp).await;
}
// Clear any stale tools — we only want googleSearch
state.mitm_store.clear_tools().await;
// Create a prompt that encourages the model to ground its response // Create a prompt that encourages the model to ground its response
let search_prompt = format!( let search_prompt = format!(
@@ -131,26 +126,9 @@ async fn do_search(state: Arc<AppState>, body: SearchRequest) -> axum::response:
body.query body.query
); );
// Session management // Always create a new cascade for every request
let session_id_str = body.conversation.clone(); let cascade_id = match state.backend.create_cascade().await {
let cascade_id = if let Some(ref sid) = session_id_str { Ok(cid) => cid,
match state
.sessions
.get_or_create(Some(sid), || state.backend.create_cascade())
.await
{
Ok(sr) => sr.cascade_id,
Err(e) => {
return err_response(
StatusCode::INTERNAL_SERVER_ERROR,
format!("Failed to create session: {e}"),
"server_error",
);
}
}
} else {
match state.backend.create_cascade().await {
Ok(id) => id,
Err(e) => { Err(e) => {
return err_response( return err_response(
StatusCode::INTERNAL_SERVER_ERROR, StatusCode::INTERNAL_SERVER_ERROR,
@@ -158,22 +136,31 @@ async fn do_search(state: Arc<AppState>, body: SearchRequest) -> axum::response:
"server_error", "server_error",
); );
} }
}
}; };
// Set active cascade for MITM correlation // Register per-request state — no tools, just generation params for search grounding
state.mitm_store.set_active_cascade(&cascade_id).await; state.mitm_store.register_request(crate::mitm::store::RequestContext {
// Store real search prompt for MITM injection — LS gets a dummy prompt cascade_id: cascade_id.clone(),
state.mitm_store.set_pending_user_text(search_prompt.clone()).await; pending_user_text: search_prompt.clone(),
event_channel: None,
generation_params: Some(gp),
pending_image: None,
tools: None,
tool_config: None,
pending_tool_results: Vec::new(),
tool_rounds: Vec::new(),
last_function_calls: Vec::new(),
call_id_to_name: std::collections::HashMap::new(),
created_at: std::time::Instant::now(),
}).await;
// Send the search message // Send dot to LS — real search prompt injected by MITM proxy
if let Err(e) = state if let Err(e) = state
.backend .backend
.send_message(&cascade_id, ".", model.model_enum) .send_message(&cascade_id, &format!(".<cid:{}>", cascade_id), model.model_enum)
.await .await
{ {
state.mitm_store.clear_active_cascade().await; state.mitm_store.remove_request(&cascade_id).await;
state.mitm_store.clear_generation_params().await;
return err_response( return err_response(
StatusCode::INTERNAL_SERVER_ERROR, StatusCode::INTERNAL_SERVER_ERROR,
format!("Failed to send search message: {e}"), format!("Failed to send search message: {e}"),
@@ -199,8 +186,7 @@ async fn do_search(state: Arc<AppState>, body: SearchRequest) -> axum::response:
}; };
// Clean up // Clean up
state.mitm_store.clear_active_cascade().await; state.mitm_store.remove_request(&cascade_id).await;
state.mitm_store.clear_generation_params().await;
state.mitm_store.clear_response_async().await; state.mitm_store.clear_response_async().await;
// Build the search response // Build the search response

View File

@@ -17,6 +17,7 @@ pub(crate) struct ResponsesRequest {
pub stream: bool, pub stream: bool,
#[serde(default = "default_timeout")] #[serde(default = "default_timeout")]
pub timeout: u64, pub timeout: u64,
#[allow(dead_code)]
pub conversation: Option<serde_json::Value>, pub conversation: Option<serde_json::Value>,
#[serde(default = "default_true")] #[serde(default = "default_true")]
pub store: bool, pub store: bool,
@@ -189,9 +190,7 @@ pub(crate) struct CompletionMessage {
pub tool_call_id: Option<String>, pub tool_call_id: Option<String>,
} }
fn default_timeout() -> u64 { use super::util::default_timeout;
120
}
fn default_true() -> bool { fn default_true() -> bool {
true true

View File

@@ -122,10 +122,17 @@ pub(crate) fn now_unix() -> u64 {
.as_secs() .as_secs()
} }
/// Fallback request timeout, in seconds, supplied to serde when a request
/// omits its `timeout` field.
pub(crate) fn default_timeout() -> u64 {
    const DEFAULT_TIMEOUT_SECS: u64 = 120;
    DEFAULT_TIMEOUT_SECS
}
pub(crate) fn responses_sse_event(event_type: &str, data: serde_json::Value) -> Event { pub(crate) fn responses_sse_event(event_type: &str, data: serde_json::Value) -> Event {
Event::default() Event::default()
.event(event_type) .event(event_type)
.data(serde_json::to_string(&data).unwrap()) .data(serde_json::to_string(&data).unwrap_or_default())
} }
// ─── Image extraction ──────────────────────────────────────────────────────── // ─── Image extraction ────────────────────────────────────────────────────────

View File

@@ -412,7 +412,8 @@ impl Backend {
headers.insert("Connect-Protocol-Version", HeaderValue::from_static("1")); headers.insert("Connect-Protocol-Version", HeaderValue::from_static("1"));
// Connect protocol envelope: [flags:1][length:4][payload] // Connect protocol envelope: [flags:1][length:4][payload]
let json_bytes = serde_json::to_vec(&body).unwrap(); let json_bytes = serde_json::to_vec(&body)
.map_err(|e| format!("{rpc_method} JSON serialize error: {e}"))?;
let mut envelope = Vec::with_capacity(5 + json_bytes.len()); let mut envelope = Vec::with_capacity(5 + json_bytes.len());
envelope.push(0x00); envelope.push(0x00);
envelope.extend_from_slice(&(json_bytes.len() as u32).to_be_bytes()); envelope.extend_from_slice(&(json_bytes.len() as u32).to_be_bytes());

View File

@@ -129,14 +129,21 @@ impl StreamingAccumulator {
else if let Some(fc) = part.get("functionCall") { else if let Some(fc) = part.get("functionCall") {
let name = fc["name"].as_str().unwrap_or("unknown").to_string(); let name = fc["name"].as_str().unwrap_or("unknown").to_string();
let args = fc["args"].clone(); let args = fc["args"].clone();
// thoughtSignature is a SIBLING of functionCall in the part,
// not nested inside functionCall
let thought_signature = part.get("thoughtSignature")
.and_then(|v| v.as_str())
.map(|s| s.to_string());
info!( info!(
tool_name = %name, tool_name = %name,
tool_args = %args, tool_args = %args,
has_thought_sig = thought_signature.is_some(),
"MITM: Google returned functionCall!" "MITM: Google returned functionCall!"
); );
self.function_calls.push(CapturedFunctionCall { self.function_calls.push(CapturedFunctionCall {
name, name,
args, args,
thought_signature,
captured_at: std::time::SystemTime::now() captured_at: std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH) .duration_since(std::time::UNIX_EPOCH)
.unwrap_or_default() .unwrap_or_default()
@@ -300,39 +307,48 @@ fn extract_usage_from_message(msg: &Value) -> Option<ApiUsage> {
/// Try to identify a cascade ID from the request body. /// Try to identify a cascade ID from the request body.
/// ///
/// The LS includes cascade-related metadata in its API requests (as part of /// Priority:
/// the system prompt or metadata field). We try to find it. /// 1. `<cid:UUID>` marker embedded by our proxy in the user message content
/// 2. `requestId` field: `agent/{timestamp}/{cascade_uuid}/{sequence}` format
/// 3. `metadata.user_id` fallback
pub fn extract_cascade_hint(request_body: &[u8]) -> Option<String> { pub fn extract_cascade_hint(request_body: &[u8]) -> Option<String> {
// Fast path: look for <cid:UUID> marker in raw bytes (avoid JSON parse)
let body_str = std::str::from_utf8(request_body).ok()?;
if let Some(start) = body_str.find("<cid:") {
let rest = &body_str[start + 5..];
if let Some(end) = rest.find('>') {
let candidate = &rest[..end];
// Validate UUID format
if candidate.len() == 36
&& candidate.chars().filter(|c| *c == '-').count() == 4
&& candidate.chars().all(|c| c.is_ascii_hexdigit() || c == '-')
{
return Some(candidate.to_string());
}
}
}
let json: Value = serde_json::from_slice(request_body).ok()?; let json: Value = serde_json::from_slice(request_body).ok()?;
// Check for metadata field (some API configurations include it) // Secondary: extract cascade UUID from requestId field
if let Some(metadata) = json.get("metadata") { // Format: "agent/{timestamp}/{cascade_uuid}/{sequence}"
if let Some(user_id) = metadata["user_id"].as_str() { if let Some(request_id) = json.get("requestId").and_then(|v| v.as_str()) {
// The LS often sets user_id to the cascadeId let parts: Vec<&str> = request_id.split('/').collect();
return Some(user_id.to_string()); if parts.len() >= 3 {
let candidate = parts[2];
if candidate.len() == 36
&& candidate.chars().filter(|c| *c == '-').count() == 4
&& candidate.chars().all(|c| c.is_ascii_hexdigit() || c == '-')
{
return Some(candidate.to_string());
}
} }
} }
// Check system prompt for cascade/workspace markers // Fallback: check metadata.user_id
if let Some(system) = json.get("system") { if let Some(metadata) = json.get("metadata") {
let system_str = match system { if let Some(user_id) = metadata["user_id"].as_str() {
Value::String(s) => s.clone(), return Some(user_id.to_string());
Value::Array(arr) => {
// Array of content blocks
arr.iter()
.filter_map(|b| b["text"].as_str())
.collect::<Vec<_>>()
.join(" ")
}
_ => return None,
};
// Look for workspace_id or cascade_id patterns
if let Some(pos) = system_str.find("workspace_id") {
let rest = &system_str[pos..];
// Extract the value after workspace_id
if let Some(val) = rest.split_whitespace().nth(1) {
return Some(val.to_string());
}
} }
} }

File diff suppressed because it is too large Load Diff

View File

@@ -37,6 +37,9 @@ use flate2::read::GzDecoder;
use std::io::Read; use std::io::Read;
use tracing::{debug, trace, warn}; use tracing::{debug, trace, warn};
// Re-import the shared varint decoder under the name used throughout this module
use crate::proto::wire::decode_varint as read_varint;
/// A decoded protobuf field. /// A decoded protobuf field.
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub enum ProtoValue { pub enum ProtoValue {
@@ -260,26 +263,7 @@ fn looks_like_valid_message(fields: &[ProtoField], original_len: usize) -> bool
} }
} }
/// Decode a protobuf base-128 varint from the start of `data`.
///
/// Returns `Some((value, bytes_consumed))` on success. Returns `None` when
/// `data` is empty, ends before a terminating byte (high bit clear) is seen,
/// the encoding runs past 10 bytes, or the encoded value does not fit in a
/// `u64`.
pub fn read_varint(data: &[u8]) -> Option<(u64, usize)> {
    let mut result: u64 = 0;
    // A u64 needs at most ten 7-bit groups; anything longer is malformed.
    for (i, &byte) in data.iter().take(10).enumerate() {
        let payload = u64::from(byte & 0x7F);
        // The tenth group lands at bit 63, so only its lowest bit fits.
        // Larger payloads would be silently truncated by the shift: reject them.
        if i == 9 && payload > 1 {
            return None;
        }
        result |= payload << (7 * i);
        if byte & 0x80 == 0 {
            return Some((result, i + 1));
        }
    }
    None
}
/// Search a decoded protobuf message tree for usage-like structures. /// Search a decoded protobuf message tree for usage-like structures.
/// ///

View File

@@ -383,138 +383,14 @@ async fn handle_http_over_tls(
// Reusable upstream connection — created lazily, reconnected if stale // Reusable upstream connection — created lazily, reconnected if stale
let mut upstream: Option<tokio_rustls::client::TlsStream<TcpStream>> = None; let mut upstream: Option<tokio_rustls::client::TlsStream<TcpStream>> = None;
/// Open a fresh TLS session to the real upstream host for `domain`.
///
/// The TCP target address comes from `resolve_upstream`, while `domain`
/// itself is used as the TLS server name. The TCP connect and the TLS
/// handshake are each capped at 15 seconds. Every failure is reported as a
/// human-readable `String`.
async fn connect_upstream(
    domain: &str,
    config: &Arc<rustls::ClientConfig>,
) -> Result<tokio_rustls::client::TlsStream<TcpStream>, String> {
    let step_timeout = std::time::Duration::from_secs(15);

    // Resolve the address to dial (see `resolve_upstream` for the lookup order).
    let addr = resolve_upstream(domain).await;
    info!(domain, addr = %addr, "MITM: connecting upstream");

    // Outer error = timeout elapsed, inner error = the connect itself failed.
    let tcp = tokio::time::timeout(step_timeout, TcpStream::connect(&addr))
        .await
        .map_err(|_| format!("Connect to upstream {domain} ({addr}): timed out"))?
        .map_err(|e| format!("Connect to upstream {domain} ({addr}): {e}"))?;

    let server_name = rustls::pki_types::ServerName::try_from(domain.to_string())
        .map_err(|e| format!("Invalid server name: {e}"))?;

    let connector = tokio_rustls::TlsConnector::from(config.clone());
    let tls = tokio::time::timeout(step_timeout, connector.connect(server_name, tcp))
        .await
        .map_err(|_| format!("TLS connect to upstream {domain}: timed out"))?
        .map_err(|e| format!("TLS connect to upstream {domain}: {e}"))?;

    info!(domain, "MITM: upstream TLS connected ✓");
    Ok(tls)
}
/// Work out the `host:443` address to dial for `domain`, preferring sources
/// that do not consult /etc/hosts.
///
/// Lookup order:
/// 1. `dig +short @8.8.8.8 <domain>` — the first output line that parses as an IPv4 address.
/// 2. The cached file `/tmp/antigravity-mitm-real-ips`, holding `domain=ip`
///    lines (written by dns-redirect.sh install).
/// 3. `<domain>:443`, left to normal system resolution (may hit /etc/hosts).
async fn resolve_upstream(domain: &str) -> String {
    // 1. Direct DNS query; a failure to run `dig` simply falls through.
    let dig = tokio::process::Command::new("dig")
        .args(["+short", "@8.8.8.8", domain])
        .output()
        .await;
    if let Ok(output) = dig {
        let stdout = String::from_utf8_lossy(&output.stdout);
        for line in stdout.lines() {
            if line.parse::<std::net::Ipv4Addr>().is_ok() {
                return format!("{line}:443");
            }
        }
    }

    // 2. Cached real IPs; the first entry whose key equals `domain` wins.
    let cached = tokio::fs::read_to_string("/tmp/antigravity-mitm-real-ips").await;
    if let Ok(contents) = cached {
        let hit = contents
            .lines()
            .filter_map(|entry| entry.split_once('='))
            .find(|(name, _)| *name == domain);
        if let Some((_, ip)) = hit {
            return format!("{ip}:443");
        }
    }

    // 3. Last resort: let the system resolver handle the hostname.
    format!("{domain}:443")
}
// Keep-alive loop: handle multiple requests on this connection // Keep-alive loop: handle multiple requests on this connection
loop { loop {
// ── Read the HTTP request from the client ───────────────────────── // ── Read the HTTP request from the client ─────────────────────────
let mut request_buf = Vec::with_capacity(1024 * 64); let mut request_buf = match read_full_request(&mut client, &mut tmp, domain).await {
Some(buf) if !buf.is_empty() => buf,
// 60s timeout on initial read (LS may open connection without sending immediately) _ => return Ok(()),
const IDLE_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(60);
loop {
let read_result = if request_buf.is_empty() {
// First read — apply idle timeout
match tokio::time::timeout(IDLE_TIMEOUT, client.read(&mut tmp)).await {
Ok(r) => r,
Err(_) => {
// Idle timeout — connection pool warmup, no data sent
debug!(domain, "MITM: client idle timeout (60s), closing");
return Ok(());
}
}
} else {
// Subsequent reads — wait up to 30s for rest of request
match tokio::time::timeout(
std::time::Duration::from_secs(30),
client.read(&mut tmp),
)
.await
{
Ok(r) => r,
Err(_) => {
warn!(domain, "MITM: partial request read timed out");
return Err("Partial request read timed out".into());
}
}
}; };
let n = match read_result {
Ok(0) => return Ok(()), // Client closed connection cleanly
Ok(n) => n,
Err(e) => {
// Connection reset / broken pipe is normal for keep-alive end
debug!(domain, error = %e, "MITM: client read finished");
return Ok(());
}
};
request_buf.extend_from_slice(&tmp[..n]);
// Check if we have the full request (headers + body)
if has_complete_http_request(&request_buf) {
break;
}
}
if request_buf.is_empty() {
return Ok(());
}
// Parse the HTTP request to find headers and body // Parse the HTTP request to find headers and body
let (headers_end, content_length, _is_streaming_request) = let (headers_end, content_length, _is_streaming_request) =
parse_http_request_meta(&request_buf); parse_http_request_meta(&request_buf);
@@ -554,33 +430,10 @@ async fn handle_http_over_tls(
"MITM: forwarding LLM request" "MITM: forwarding LLM request"
); );
// ── Atomic in-flight gate ───────────────────────────────── // ── Per-request context lookup ────────────────────────────
// The LS opens multiple connections and sends parallel requests. // Deferred until we know this is an agent request containing our
// When custom tools are active, only the FIRST request wins the // dummy dot. This prevents LS internal requests (title generation,
// atomic compare_exchange. All others get fake STOP responses. // checkpoints) from stealing the RequestContext.
let has_tools = store.get_tools().await.is_some();
if has_tools {
if !store.try_mark_request_in_flight() {
info!("MITM: blocking LS request — another request already in-flight");
let fake_response = "HTTP/1.1 200 OK\r\n\
Content-Type: text/event-stream\r\n\
Transfer-Encoding: chunked\r\n\
\r\n";
let fake_sse = "data: {\"response\":{\"candidates\":[{\"content\":{\"parts\":[{\"text\":\"Request handled.\"}],\"role\":\"model\"},\"finishReason\":\"STOP\"}],\"usageMetadata\":{\"promptTokenCount\":0,\"candidatesTokenCount\":1,\"totalTokenCount\":1}}}\n\ndata: [DONE]\n\n";
let chunked_body = super::modify::rechunk(fake_sse.as_bytes());
let mut response = fake_response.as_bytes().to_vec();
response.extend_from_slice(&chunked_body);
if let Err(e) = client.write_all(&response).await {
warn!(error = %e, "MITM: failed to write fake response");
}
let _ = client.flush().await;
continue;
}
// Grab the channel sender — the API handler installed it before
// sending the LS message. If it's gone, we still proceed but
// fall back to legacy store writes.
event_tx = store.take_channel().await;
}
// ── Request modification ───────────────────────────────────── // ── Request modification ─────────────────────────────────────
// Dechunk body → check if agent request → modify → rechunk // Dechunk body → check if agent request → modify → rechunk
@@ -594,33 +447,128 @@ async fn handle_http_over_tls(
|| body_str.contains("\"requestType\": \"agent\""); || body_str.contains("\"requestType\": \"agent\"");
if is_agent { if is_agent {
// Build ToolContext from store // Re-extract cascade_hint from the dechunked (JSON-parseable) body.
let tools = store.get_tools().await; // The chunked transfer encoding body at `request_buf[headers_end..]`
let tool_config = store.get_tool_config().await; // can't be JSON-parsed, but `raw_body` (dechunked) can.
let pending_results = store.take_tool_results().await; let precise_cascade = extract_cascade_hint(&raw_body);
let last_calls = store.get_last_function_calls().await; debug!(
let generation_params = store.get_generation_params().await; cascade = ?precise_cascade,
let pending_image = store.take_pending_image().await; "MITM: cascade from dechunked requestId"
let tool_rounds = store.get_tool_rounds().await; );
let pending_user_text = store.take_pending_user_text().await;
let tool_ctx = if tools.is_some() // Check if ANY user message contains our dummy dot prompt
|| !pending_results.is_empty() // within a <USER_REQUEST> wrapper.
|| !tool_rounds.is_empty() // Only then should we consume the pending RequestContext.
|| generation_params.is_some() // This prevents LS internal requests (title gen, etc.) from
|| pending_image.is_some() // consuming the context meant for the user's actual request.
|| pending_user_text.is_some() // NOTE: We check ALL user messages because the LS appends context
{ // messages AFTER the dot prompt (conversation summaries, etc.).
Some(super::modify::ToolContext { // We look for <USER_REQUEST> + dot specifically to avoid matching
tools, // old <cid:> markers in history (which are in model messages).
tool_config, let contains_our_dot = serde_json::from_slice::<serde_json::Value>(&raw_body)
pending_results, .ok()
last_calls, .and_then(|json| {
generation_params, let contents = json.pointer("/request/contents")?.as_array()?;
pending_image, for msg in contents.iter() {
tool_rounds, let is_user = msg.get("role")
pending_user_text, .and_then(|r| r.as_str())
.map_or(true, |r| r == "user");
if !is_user { continue; }
if let Some(text) = msg.pointer("/parts/0/text").and_then(|v| v.as_str()) {
// Check for dot in <USER_REQUEST> wrapper
if text.contains("<USER_REQUEST>") {
if let (Some(s), Some(e)) = (text.find("<USER_REQUEST>"), text.find("</USER_REQUEST>")) {
let inner = &text[s + 14..e]; // 14 = len("<USER_REQUEST>")
let it = inner.trim();
if it == "." || it.starts_with(".<cid:") {
return Some(true);
}
}
}
// Check for bare dot (no wrapper)
let t = text.trim();
if t == "." || t == ".<cid:" || (t.starts_with(".<cid:") && t.ends_with(">")) {
return Some(true);
}
}
}
Some(false)
}) })
.unwrap_or(false);
// Only take the RequestContext if this request has our dot
let effective_cascade = precise_cascade.or(cascade_hint.clone());
let mut request_ctx: Option<super::store::RequestContext> = if contains_our_dot {
let ctx = if let Some(ref cid) = effective_cascade {
store.take_request(cid).await
} else {
None
};
if ctx.is_some() {
ctx
} else if let Some(ref cid) = effective_cascade {
// Check if this is a subsequent turn (turn 1+) of an
// already-processed cascade. If so, DON'T fall through
// to take_latest_request — that would steal an unrelated
// cascade's context.
if store.has_cascade_cache(cid).await {
debug!(cascade = %cid, "MITM: subsequent turn — using cached context");
None
} else {
// Unknown cascade, try latest fallback
store.take_latest_request().await
}
} else {
store.take_latest_request().await
}
} else {
None
};
// Extract event channel from matched context
if let Some(ref mut ctx) = request_ctx {
event_tx = ctx.event_channel.take();
}
// Build ToolContext from RequestContext (turn 0) or cached
// context (turn 1+). On turn 0, we also cache the context
// for subsequent turns.
let tool_ctx = if let Some(ctx) = request_ctx.take() {
// Turn 0: cache context for subsequent turns
if let Some(ref cid) = effective_cascade {
store.cache_cascade(cid, super::store::CascadeCache {
user_text: ctx.pending_user_text.clone(),
tools: ctx.tools.clone(),
tool_config: ctx.tool_config.clone(),
generation_params: ctx.generation_params.clone(),
}).await;
}
Some(super::modify::ToolContext {
pending_user_text: ctx.pending_user_text,
tools: ctx.tools,
tool_config: ctx.tool_config,
pending_results: ctx.pending_tool_results,
last_calls: ctx.last_function_calls,
generation_params: ctx.generation_params,
pending_image: ctx.pending_image,
tool_rounds: ctx.tool_rounds,
})
} else if let Some(ref cid) = effective_cascade {
// Turn 1+: rebuild lite ToolContext from cache
if let Some(cached) = store.get_cascade_cache(cid).await {
Some(super::modify::ToolContext {
pending_user_text: cached.user_text,
tools: cached.tools,
tool_config: cached.tool_config,
pending_results: vec![],
last_calls: vec![],
generation_params: cached.generation_params,
pending_image: None,
tool_rounds: vec![],
})
} else {
None
}
} else { } else {
None None
}; };
@@ -637,8 +585,6 @@ async fn handle_http_over_tls(
let mut new_buf = updated_headers.into_bytes(); let mut new_buf = updated_headers.into_bytes();
new_buf.extend_from_slice(&new_chunked); new_buf.extend_from_slice(&new_chunked);
request_buf = new_buf; request_buf = new_buf;
// In-flight already marked atomically above
} }
} }
} }
@@ -677,6 +623,7 @@ async fn handle_http_over_tls(
// ALWAYS forward data to client immediately (no buffering). // ALWAYS forward data to client immediately (no buffering).
// Buffer body on the side for usage parsing. // Buffer body on the side for usage parsing.
let mut streaming_acc = StreamingAccumulator::new(); let mut streaming_acc = StreamingAccumulator::new();
let mut response_rewriter: Option<super::modify::ResponseRewriter> = None;
let mut is_streaming_response = false; let mut is_streaming_response = false;
let mut headers_parsed = false; let mut headers_parsed = false;
let mut upstream_ok = true; let mut upstream_ok = true;
@@ -737,6 +684,10 @@ async fn handle_http_over_tls(
content_type = v.to_string(); content_type = v.to_string();
if v.contains("text/event-stream") { if v.contains("text/event-stream") {
is_streaming_response = true; is_streaming_response = true;
// Lazily initialize the response rewriter for SSE streams
if modify_requests {
response_rewriter = Some(super::modify::ResponseRewriter::new());
}
} }
} }
} }
@@ -802,11 +753,11 @@ async fn handle_http_over_tls(
message, message,
error_status, error_status,
}; };
// Send through channel if available, otherwise store for legacy consumers // Send through channel if available
if let Some(ref tx) = event_tx { if let Some(ref tx) = event_tx {
let _ = tx.send(super::store::MitmEvent::UpstreamError(upstream_err)).await; let _ = tx.send(super::store::MitmEvent::UpstreamError(upstream_err)).await;
} else { } else {
store.set_upstream_error(upstream_err).await; warn!("MITM: upstream error but no channel to forward it");
} }
} }
@@ -817,77 +768,21 @@ async fn handle_http_over_tls(
if is_streaming_response && hdr_end < header_buf.len() { if is_streaming_response && hdr_end < header_buf.len() {
let body = String::from_utf8_lossy(&header_buf[hdr_end..]); let body = String::from_utf8_lossy(&header_buf[hdr_end..]);
parse_streaming_chunk(&body, &mut streaming_acc); parse_streaming_chunk(&body, &mut streaming_acc);
dispatch_stream_events(&mut streaming_acc, &event_tx, &store, cascade_hint.as_deref()).await;
// Send events through channel if available, otherwise use legacy store
if let Some(ref tx) = event_tx {
// Function calls → channel event
if !streaming_acc.function_calls.is_empty() {
let calls: Vec<_> = streaming_acc.function_calls.drain(..).collect();
store.set_last_function_calls(calls.clone()).await;
store.record_function_call(cascade_hint.as_deref(), calls[0].clone()).await;
info!("MITM: sending {} function call(s) via channel (initial body)", calls.len());
let _ = tx.send(super::store::MitmEvent::FunctionCall(calls)).await;
}
// Thinking delta → channel event
if !streaming_acc.thinking_text.is_empty() {
let _ = tx.send(super::store::MitmEvent::ThinkingDelta(
streaming_acc.thinking_text.clone(),
)).await;
}
// Text delta → channel event
if !streaming_acc.response_text.is_empty() {
let _ = tx.send(super::store::MitmEvent::TextDelta(
streaming_acc.response_text.clone(),
)).await;
}
// Grounding → channel event
if let Some(ref gm) = streaming_acc.grounding_metadata {
store.set_grounding(gm.clone()).await;
let _ = tx.send(super::store::MitmEvent::Grounding(gm.clone())).await;
}
// Response complete → channel event
if streaming_acc.is_complete {
info!(
response_text_len = streaming_acc.response_text.len(),
thinking_text_len = streaming_acc.thinking_text.len(),
"MITM: response complete (initial body) — sending via channel"
);
let _ = tx.send(super::store::MitmEvent::ResponseComplete).await;
streaming_acc.is_complete = false; // prevent duplicate sends
}
} else {
// Legacy path: store writes for non-channel consumers (search, etc.)
if !streaming_acc.function_calls.is_empty() {
let calls: Vec<_> = streaming_acc.function_calls.drain(..).collect();
for fc in &calls {
store.record_function_call(cascade_hint.as_deref(), fc.clone()).await;
}
store.set_last_function_calls(calls.clone()).await;
info!("MITM: stored {} function call(s) from initial body", calls.len());
}
if !streaming_acc.response_text.is_empty() {
store.set_response_text(&streaming_acc.response_text).await;
}
if let Some(ref gm) = streaming_acc.grounding_metadata {
store.set_grounding(gm.clone()).await;
}
}
} }
// Forward to client — rewrite function calls if custom tools are injected // Forward to client — rewrite function calls if custom tools are injected
let forward_buf = if modify_requests { let forward_buf = if let Some(ref mut rewriter) = response_rewriter {
if let Some(modified) = super::modify::modify_response_chunk(&header_buf) { rewriter.feed(&header_buf)
modified
} else {
header_buf.clone()
}
} else { } else {
header_buf.clone() header_buf.clone()
}; };
if !forward_buf.is_empty() {
if let Err(e) = client.write_all(&forward_buf).await { if let Err(e) = client.write_all(&forward_buf).await {
warn!(error = %e, "MITM: write to client failed"); warn!(error = %e, "MITM: write to client failed");
break; break;
} }
}
if let Some(cl) = response_content_length { if let Some(cl) = response_content_length {
if response_body_buf.len() >= cl { if response_body_buf.len() >= cl {
@@ -908,80 +803,24 @@ async fn handle_http_over_tls(
if is_streaming_response { if is_streaming_response {
let s = String::from_utf8_lossy(chunk); let s = String::from_utf8_lossy(chunk);
parse_streaming_chunk(&s, &mut streaming_acc); parse_streaming_chunk(&s, &mut streaming_acc);
dispatch_stream_events(&mut streaming_acc, &event_tx, &store, cascade_hint.as_deref()).await;
// Send events through channel if available, otherwise use legacy store
if let Some(ref tx) = event_tx {
// Function calls → channel event
if !streaming_acc.function_calls.is_empty() {
let calls: Vec<_> = streaming_acc.function_calls.drain(..).collect();
store.set_last_function_calls(calls.clone()).await;
store.record_function_call(cascade_hint.as_deref(), calls[0].clone()).await;
info!("MITM: sending {} function call(s) via channel (body chunk)", calls.len());
let _ = tx.send(super::store::MitmEvent::FunctionCall(calls)).await;
}
// Thinking delta → channel event (send accumulated, handler tracks last len)
if !streaming_acc.thinking_text.is_empty() {
let _ = tx.send(super::store::MitmEvent::ThinkingDelta(
streaming_acc.thinking_text.clone(),
)).await;
}
// Text delta → channel event (send accumulated, handler tracks last len)
if !streaming_acc.response_text.is_empty() {
let _ = tx.send(super::store::MitmEvent::TextDelta(
streaming_acc.response_text.clone(),
)).await;
}
// Grounding → channel event
if let Some(ref gm) = streaming_acc.grounding_metadata {
store.set_grounding(gm.clone()).await;
let _ = tx.send(super::store::MitmEvent::Grounding(gm.clone())).await;
}
// Response complete → channel event
if streaming_acc.is_complete {
info!(
response_text_len = streaming_acc.response_text.len(),
thinking_text_len = streaming_acc.thinking_text.len(),
function_calls = streaming_acc.function_calls.len(),
"MITM: response complete — sending via channel"
);
let _ = tx.send(super::store::MitmEvent::ResponseComplete).await;
streaming_acc.is_complete = false; // prevent duplicate sends
}
} else {
// Legacy path: store writes for non-channel consumers
if !streaming_acc.function_calls.is_empty() {
let calls: Vec<_> = streaming_acc.function_calls.drain(..).collect();
for fc in &calls {
store.record_function_call(cascade_hint.as_deref(), fc.clone()).await;
}
store.set_last_function_calls(calls.clone()).await;
info!("MITM: stored {} function call(s) from body chunk", calls.len());
}
if !streaming_acc.response_text.is_empty() {
store.set_response_text(&streaming_acc.response_text).await;
}
if let Some(ref gm) = streaming_acc.grounding_metadata {
store.set_grounding(gm.clone()).await;
}
}
} }
// Forward chunk to client (LS) — rewrite function calls if custom tools // Forward chunk to client (LS) — rewrite function calls if custom tools
let forward_chunk = if modify_requests { let forward_chunk = if let Some(ref mut rewriter) = response_rewriter {
if let Some(modified) = super::modify::modify_response_chunk(chunk) { rewriter.feed(chunk)
modified
} else {
chunk.to_vec()
}
} else { } else {
chunk.to_vec() chunk.to_vec()
}; };
if !forward_chunk.is_empty() {
if let Err(e) = client.write_all(&forward_chunk).await { if let Err(e) = client.write_all(&forward_chunk).await {
warn!(error = %e, "MITM: write to client failed"); warn!(error = %e, "MITM: write to client failed");
break; break;
} }
}
response_body_buf.extend_from_slice(chunk); response_body_buf.extend_from_slice(chunk);
if let Some(cl) = response_content_length { if let Some(cl) = response_content_length {
if response_body_buf.len() >= cl { if response_body_buf.len() >= cl {
break; break;
@@ -992,6 +831,13 @@ async fn handle_http_over_tls(
break; break;
} }
} }
// Flush any remaining buffered response data through the rewriter
if let Some(ref mut rewriter) = response_rewriter {
let remaining = rewriter.flush();
if !remaining.is_empty() {
let _ = client.write_all(&remaining).await;
}
}
// Flush client // Flush client
let _ = client.flush().await; let _ = client.flush().await;
@@ -1023,6 +869,176 @@ async fn handle_http_over_tls(
} // end keep-alive loop } // end keep-alive loop
} }
/// Accumulate bytes from the client until `has_complete_http_request`
/// accepts the buffer.
///
/// Two timeouts apply: 60 seconds while nothing has been received yet, and
/// 30 seconds for every read once part of a request has arrived.
///
/// Returns `Some(bytes)` holding the raw request. Returns `None` when a read
/// times out, the client closes the connection (zero-byte read), or the read
/// fails; any partially received bytes are discarded in those cases.
async fn read_full_request(
    client: &mut tokio_rustls::server::TlsStream<TcpStream>,
    tmp: &mut [u8],
    domain: &str,
) -> Option<Vec<u8>> {
    const IDLE_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(60);
    const PARTIAL_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(30);

    let mut request = Vec::with_capacity(1024 * 64);

    loop {
        // Nothing buffered yet means we are still waiting for the request to start.
        let awaiting_first_byte = request.is_empty();
        let limit = if awaiting_first_byte {
            IDLE_TIMEOUT
        } else {
            PARTIAL_TIMEOUT
        };

        let outcome = match tokio::time::timeout(limit, client.read(tmp)).await {
            Ok(read) => read,
            Err(_) => {
                if awaiting_first_byte {
                    debug!(domain, "MITM: client idle timeout (60s), closing");
                } else {
                    warn!(domain, "MITM: partial request read timed out");
                }
                return None;
            }
        };

        match outcome {
            // Zero bytes: the peer closed the connection.
            Ok(0) => return None,
            Ok(n) => request.extend_from_slice(&tmp[..n]),
            Err(e) => {
                debug!(domain, error = %e, "MITM: client read finished");
                return None;
            }
        }

        if has_complete_http_request(&request) {
            return Some(request);
        }
    }
}
/// Open a TLS connection to the real upstream server for `domain`.
///
/// The address to dial comes from `resolve_upstream`, which tries a direct
/// DNS query (dig @8.8.8.8), then the cached IPs file, then normal system
/// resolution. The TLS server name is `domain`.
///
/// The TCP connect and the TLS handshake each get their own 15-second limit.
///
/// # Errors
///
/// Returns a descriptive `String` if the TCP connect fails or times out, if
/// `domain` cannot be used as a TLS server name, or if the TLS handshake
/// fails or times out.
async fn connect_upstream(
    domain: &str,
    config: &Arc<rustls::ClientConfig>,
) -> Result<tokio_rustls::client::TlsStream<TcpStream>, String> {
    const STEP_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(15);

    let connector = tokio_rustls::TlsConnector::from(config.clone());
    let addr = resolve_upstream(domain).await;
    info!(domain, addr = %addr, "MITM: connecting upstream");

    // First `map_err` handles the elapsed timer, second the I/O failure.
    let tcp = tokio::time::timeout(STEP_TIMEOUT, TcpStream::connect(&addr))
        .await
        .map_err(|_| format!("Connect to upstream {domain} ({addr}): timed out"))?
        .map_err(|e| format!("Connect to upstream {domain} ({addr}): {e}"))?;

    let server_name = rustls::pki_types::ServerName::try_from(domain.to_string())
        .map_err(|e| format!("Invalid server name: {e}"))?;

    let tls = tokio::time::timeout(STEP_TIMEOUT, connector.connect(server_name, tcp))
        .await
        .map_err(|_| format!("TLS connect to upstream {domain}: timed out"))?
        .map_err(|e| format!("TLS connect to upstream {domain}: {e}"))?;

    info!(domain, "MITM: upstream TLS connected ✓");
    Ok(tls)
}
/// Resolve the `host:port` string to dial for `domain`, bypassing /etc/hosts.
///
/// Resolution order:
/// 1. `dig +short @8.8.8.8 <domain>` — the first output line that parses as
///    an IPv4 address wins; other lines (e.g. CNAME targets) are skipped.
/// 2. The cache file `/tmp/antigravity-mitm-real-ips`, one `domain=ip` entry
///    per line.
/// 3. `"{domain}:443"`, left to normal resolution.
///
/// Never fails: a step that yields nothing simply falls through to the next.
async fn resolve_upstream(domain: &str) -> String {
    const UPSTREAM_PORT: u16 = 443;

    // 1. Try dig @8.8.8.8 (bypasses /etc/hosts)
    if let Ok(output) = tokio::process::Command::new("dig")
        .args(["+short", "@8.8.8.8", domain])
        .output()
        .await
    {
        let out = String::from_utf8_lossy(&output.stdout);
        if let Some(ip) = out
            .lines()
            .map(str::trim)
            .find(|l| l.parse::<std::net::Ipv4Addr>().is_ok())
        {
            return format!("{ip}:{UPSTREAM_PORT}");
        }
    }

    // 2. Try cached IPs file
    // NOTE(review): the cache lives in world-writable /tmp; upstream TLS
    // verification is what protects against a poisoned entry — confirm.
    if let Ok(contents) = tokio::fs::read_to_string("/tmp/antigravity-mitm-real-ips").await {
        for line in contents.lines() {
            if let Some((d, ip)) = line.split_once('=') {
                // Tolerate whitespace around key and value; previously a
                // stray space produced an address that could never connect.
                let ip = ip.trim();
                if d.trim() != domain || ip.is_empty() {
                    continue;
                }
                return match ip.parse::<std::net::IpAddr>() {
                    // SocketAddr's Display brackets IPv6 literals ("[::1]:443").
                    Ok(parsed) => std::net::SocketAddr::new(parsed, UPSTREAM_PORT).to_string(),
                    // Not an IP literal: keep the historical pass-through form.
                    Err(_) => format!("{ip}:{UPSTREAM_PORT}"),
                };
            }
        }
    }

    // 3. Fallback to normal resolution
    format!("{domain}:{UPSTREAM_PORT}")
}
/// Dispatch parsed streaming events to the channel or legacy store.
///
/// Deduplicates the event dispatch logic used both for initial body parsing
/// and subsequent body chunk processing.
async fn dispatch_stream_events(
acc: &mut StreamingAccumulator,
event_tx: &Option<tokio::sync::mpsc::Sender<super::store::MitmEvent>>,
store: &MitmStore,
cascade_hint: Option<&str>,
) {
if let Some(ref tx) = event_tx {
if !acc.function_calls.is_empty() {
let calls: Vec<_> = acc.function_calls.drain(..).collect();
store.record_function_call(cascade_hint, calls[0].clone()).await;
info!("MITM: sending {} function call(s) via channel", calls.len());
let _ = tx.send(super::store::MitmEvent::FunctionCall(calls)).await;
}
if !acc.thinking_text.is_empty() {
let _ = tx.send(super::store::MitmEvent::ThinkingDelta(acc.thinking_text.clone())).await;
}
if !acc.response_text.is_empty() {
let _ = tx.send(super::store::MitmEvent::TextDelta(acc.response_text.clone())).await;
}
if let Some(ref gm) = acc.grounding_metadata {
store.set_grounding(gm.clone()).await;
let _ = tx.send(super::store::MitmEvent::Grounding(gm.clone())).await;
}
if acc.is_complete {
info!(
response_text_len = acc.response_text.len(),
thinking_text_len = acc.thinking_text.len(),
"MITM: response complete — sending via channel"
);
let _ = tx.send(super::store::MitmEvent::ResponseComplete).await;
acc.is_complete = false;
}
} else {
if !acc.function_calls.is_empty() {
let calls: Vec<_> = acc.function_calls.drain(..).collect();
for fc in &calls {
store.record_function_call(cascade_hint, fc.clone()).await;
}
info!("MITM: stored {} function call(s)", calls.len());
}
if !acc.response_text.is_empty() {
store.set_response_text(&acc.response_text).await;
}
if let Some(ref gm) = acc.grounding_metadata {
store.set_grounding(gm.clone()).await;
}
}
}
/// Handle a passthrough connection: transparent TCP tunnel to upstream. /// Handle a passthrough connection: transparent TCP tunnel to upstream.
async fn handle_passthrough(mut client: TcpStream, domain: &str, port: u16) -> Result<(), String> { async fn handle_passthrough(mut client: TcpStream, domain: &str, port: u16) -> Result<(), String> {
trace!(domain, port, "MITM: transparent tunnel"); trace!(domain, port, "MITM: transparent tunnel");

View File

@@ -1,13 +1,13 @@
//! Shared store for intercepted API usage data. //! Shared store for intercepted API usage data.
//! //!
//! The MITM proxy writes usage data here; the API handlers read from it. //! Per-request state is stored in `RequestContext`, keyed by cascade ID.
//! When custom tools are active, the MITM proxy sends real-time events //! The MITM proxy looks up the context when intercepting LS requests,
//! through a channel instead of writing to shared state. //! enabling concurrent request processing without global locks.
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::collections::HashMap; use std::collections::HashMap;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc; use std::sync::Arc;
use std::time::Instant;
use tokio::sync::{mpsc, RwLock}; use tokio::sync::{mpsc, RwLock};
use tracing::{debug, info}; use tracing::{debug, info};
@@ -52,6 +52,10 @@ pub struct ApiUsage {
pub struct CapturedFunctionCall { pub struct CapturedFunctionCall {
pub name: String, pub name: String,
pub args: serde_json::Value, pub args: serde_json::Value,
/// Google's thought signature — required when injecting functionCall back
/// into conversation history. Without it, Google returns INVALID_ARGUMENT.
#[serde(skip_serializing_if = "Option::is_none")]
pub thought_signature: Option<String>,
pub captured_at: u64, pub captured_at: u64,
} }
@@ -128,6 +132,25 @@ pub struct GenerationParams {
pub google_search: bool, pub google_search: bool,
} }
/// Cached context from turn 0 of a cascade.
///
/// On the first turn, the MITM proxy consumes the `RequestContext` and builds
/// a `ToolContext`. On subsequent turns (tool-call loops), the `RequestContext`
/// is gone. This cache stores the essential fields so we can rebuild a lite
/// `ToolContext` on every turn — ensuring the model always sees the real user
/// text and has access to custom tools.
///
/// A `ToolContext` rebuilt from this cache carries no pending tool results,
/// last function calls, tool rounds or image; those come only from the
/// turn-0 `RequestContext`.
#[derive(Debug, Clone)]
pub struct CascadeCache {
/// The real user text (used to replace the "." dot prompt).
/// Copied from `RequestContext::pending_user_text` on turn 0.
pub user_text: String,
/// Custom tool definitions (Gemini format).
/// `None` when the request registered no tools.
pub tools: Option<Vec<serde_json::Value>>,
/// Custom tool config.
pub tool_config: Option<serde_json::Value>,
/// Client generation parameters.
pub generation_params: Option<GenerationParams>,
}
// ─── Channel-based event pipeline ──────────────────────────────────────────── // ─── Channel-based event pipeline ────────────────────────────────────────────
/// Events sent from the MITM proxy to API handlers through a per-request channel. /// Events sent from the MITM proxy to API handlers through a per-request channel.
@@ -146,15 +169,53 @@ pub enum MitmEvent {
/// Google API returned an error. /// Google API returned an error.
UpstreamError(UpstreamError), UpstreamError(UpstreamError),
/// Grounding metadata (search results) from the response. /// Grounding metadata (search results) from the response.
#[allow(dead_code)]
Grounding(serde_json::Value), Grounding(serde_json::Value),
/// Token usage data from the response. /// Token usage data from the response.
Usage(ApiUsage), Usage(ApiUsage),
} }
// ─── Per-request context ─────────────────────────────────────────────────────
/// All per-request state. Keyed by cascade ID in `MitmStore.pending_requests`.
///
/// API handlers build this before `send_message`, and the MITM proxy consumes
/// it when the LS's outbound request is intercepted.
///
/// The proxy moves most fields into a `ToolContext` for request modification
/// and copies the user text, tools, tool config and generation parameters
/// into a `CascadeCache` so later turns of the same cascade can reuse them.
#[derive(Debug)]
pub struct RequestContext {
/// Cascade ID this context belongs to.
pub cascade_id: String,
/// Real user text for MITM injection (LS receives "." instead).
pub pending_user_text: String,
/// Event channel for real-time streaming from MITM → API handler.
/// Only present when custom tools are active.
/// The proxy takes the sender out of this field (leaving `None`) once the
/// context has been matched to an intercepted request.
pub event_channel: Option<mpsc::Sender<MitmEvent>>,
/// Client-specified generation parameters (temperature, top_p, etc.).
pub generation_params: Option<GenerationParams>,
/// Image to inject into the Google API request.
pub pending_image: Option<PendingImage>,
/// Gemini-format tool declarations for MITM injection.
pub tools: Option<Vec<serde_json::Value>>,
/// Gemini-format toolConfig.
pub tool_config: Option<serde_json::Value>,
/// Pending tool results to inject as functionResponse.
pub pending_tool_results: Vec<PendingToolResult>,
/// Multi-round tool call history for history rewriting.
pub tool_rounds: Vec<ToolRound>,
/// Last captured function calls for history rewriting.
pub last_function_calls: Vec<CapturedFunctionCall>,
/// Mapping call_id → function name for tool result routing.
pub call_id_to_name: HashMap<String, String>,
/// When this context was created (for TTL cleanup).
pub created_at: Instant,
}
// ─── MitmStore ───────────────────────────────────────────────────────────────
/// Thread-safe store for intercepted data. /// Thread-safe store for intercepted data.
/// ///
/// Keyed by a unique request ID that we can correlate with cascade operations. /// Per-request state lives in `pending_requests`, keyed by cascade ID.
/// In practice, we use the cascade ID + a sequence number. /// Global state (usage stats, function call capture) remains shared.
#[derive(Clone)] #[derive(Clone)]
pub struct MitmStore { pub struct MitmStore {
/// Most recent usage per cascade ID. /// Most recent usage per cascade ID.
@@ -163,62 +224,24 @@ pub struct MitmStore {
stats: Arc<RwLock<MitmStats>>, stats: Arc<RwLock<MitmStats>>,
/// Pending function calls captured from Google responses. /// Pending function calls captured from Google responses.
/// Key: cascade hint or "_latest". Value: list of function calls. /// Key: cascade hint or "_latest". Value: list of function calls.
/// Used by the non-tool LS path (normal sync responses).
pending_function_calls: Arc<RwLock<HashMap<String, Vec<CapturedFunctionCall>>>>, pending_function_calls: Arc<RwLock<HashMap<String, Vec<CapturedFunctionCall>>>>,
/// Set when the MITM forwards the first LLM request with custom tools.
/// Blocks ALL subsequent LS requests until the API handler clears it.
request_in_flight: Arc<AtomicBool>,
// ── Channel-based event pipeline (replaces old polling) ────────────── // ── Per-request state (keyed by cascade ID) ──────────────────────────
/// Active channel sender for the current tool-path request. /// Active request contexts. API handlers register before send_message,
/// When present, the MITM proxy sends events through this instead of /// MITM proxy consumes when intercepting the LS request.
/// writing to shared state. The channel's existence = request in-flight. pending_requests: Arc<RwLock<HashMap<String, RequestContext>>>,
active_channel: Arc<RwLock<Option<mpsc::Sender<MitmEvent>>>>,
// ── Tool call support ──────────────────────────────────────────────── /// Cached context from turn 0, keyed by cascade ID.
/// Active tool definitions (Gemini format) for MITM injection. /// Used to rebuild ToolContext on subsequent turns of the same cascade.
active_tools: Arc<RwLock<Option<Vec<serde_json::Value>>>>, cascade_cache: Arc<RwLock<HashMap<String, CascadeCache>>>,
/// Active tool config (Gemini toolConfig format).
active_tool_config: Arc<RwLock<Option<serde_json::Value>>>,
/// Pending tool results for MITM to inject as functionResponse.
pending_tool_results: Arc<RwLock<Vec<PendingToolResult>>>,
/// Mapping call_id → function name for tool result routing.
call_id_to_name: Arc<RwLock<HashMap<String, String>>>,
/// Last captured function calls (for conversation history rewriting).
last_function_calls: Arc<RwLock<Vec<CapturedFunctionCall>>>,
/// Multi-round tool call history for correct per-turn history rewriting.
/// Set by completions/responses handler, consumed by modify_request.
tool_rounds: Arc<RwLock<Vec<ToolRound>>>,
// ── Cascade correlation ──────────────────────────────────────────────
/// Active cascade ID set by the API layer before sending a message.
/// Used by the MITM proxy to correlate intercepted traffic to cascades.
active_cascade_id: Arc<RwLock<Option<String>>>,
// ── Legacy direct response capture (used by search.rs) ─────────────── // ── Legacy direct response capture (used by search.rs) ───────────────
/// Captured response text from MITM. Used as fallback by search endpoint. /// Captured response text from MITM. Used as fallback by search endpoint.
captured_response_text: Arc<RwLock<Option<String>>>, captured_response_text: Arc<RwLock<Option<String>>>,
// ── Generation parameters for MITM injection ─────────────────────────
/// Client-specified sampling parameters to inject into Google API requests.
generation_params: Arc<RwLock<Option<GenerationParams>>>,
// ── Grounding metadata capture ────────────────────────────────────── // ── Grounding metadata capture ──────────────────────────────────────
/// Captured grounding metadata from Google API responses (search results). /// Captured grounding metadata from Google API responses (search results).
captured_grounding: Arc<RwLock<Option<serde_json::Value>>>, captured_grounding: Arc<RwLock<Option<serde_json::Value>>>,
// ── Pending image for MITM injection ─────────────────────────────────
/// Image to inject into the next Google API request via MITM.
pending_image: Arc<RwLock<Option<PendingImage>>>,
// ── Upstream error capture (legacy, used when no channel) ────────────
/// Error from Google's API, captured by MITM for forwarding to client.
upstream_error: Arc<RwLock<Option<UpstreamError>>>,
// ── Standard LS input: real user text for MITM injection ─────────────
/// The real user text to inject into the Google API request.
/// API handlers store this before sending a dummy prompt to the LS.
pending_user_text: Arc<RwLock<Option<String>>>,
} }
/// Aggregate statistics across all intercepted traffic. /// Aggregate statistics across all intercepted traffic.
@@ -251,24 +274,106 @@ impl MitmStore {
latest_usage: Arc::new(RwLock::new(HashMap::new())), latest_usage: Arc::new(RwLock::new(HashMap::new())),
stats: Arc::new(RwLock::new(MitmStats::default())), stats: Arc::new(RwLock::new(MitmStats::default())),
pending_function_calls: Arc::new(RwLock::new(HashMap::new())), pending_function_calls: Arc::new(RwLock::new(HashMap::new())),
request_in_flight: Arc::new(AtomicBool::new(false)), pending_requests: Arc::new(RwLock::new(HashMap::new())),
active_channel: Arc::new(RwLock::new(None)), cascade_cache: Arc::new(RwLock::new(HashMap::new())),
active_tools: Arc::new(RwLock::new(None)),
active_tool_config: Arc::new(RwLock::new(None)),
pending_tool_results: Arc::new(RwLock::new(Vec::new())),
call_id_to_name: Arc::new(RwLock::new(HashMap::new())),
last_function_calls: Arc::new(RwLock::new(Vec::new())),
tool_rounds: Arc::new(RwLock::new(Vec::new())),
active_cascade_id: Arc::new(RwLock::new(None)),
captured_response_text: Arc::new(RwLock::new(None)), captured_response_text: Arc::new(RwLock::new(None)),
generation_params: Arc::new(RwLock::new(None)),
captured_grounding: Arc::new(RwLock::new(None)), captured_grounding: Arc::new(RwLock::new(None)),
pending_image: Arc::new(RwLock::new(None)),
upstream_error: Arc::new(RwLock::new(None)),
pending_user_text: Arc::new(RwLock::new(None)),
} }
} }
// ── Per-request context management ───────────────────────────────────
/// Store `ctx` in the pending-request map under its own `cascade_id`.
///
/// API handlers call this before `send_message` so that the MITM proxy can
/// look the context up when it intercepts the matching outbound request.
/// An existing entry for the same cascade ID is replaced.
pub async fn register_request(&self, ctx: RequestContext) {
    info!(cascade = %ctx.cascade_id, "Registered request context");
    let key = ctx.cascade_id.clone();
    let mut pending = self.pending_requests.write().await;
    pending.insert(key, ctx);
}
/// Remove and return the pending request context registered for `cascade_id`.
///
/// Called by the MITM proxy when it intercepts the LS's outbound request.
/// Returns `None` when nothing is registered under that ID.
pub async fn take_request(&self, cascade_id: &str) -> Option<RequestContext> {
    let taken = self.pending_requests.write().await.remove(cascade_id);
    match taken {
        Some(ctx) => {
            debug!(cascade = %cascade_id, "Took request context");
            Some(ctx)
        }
        None => None,
    }
}
/// Remove and return the pending request context with the greatest
/// `created_at` value.
///
/// Fallback for when the cascade ID cannot be extracted from the Google API
/// request. Returns `None` when no request is pending.
pub async fn take_latest_request(&self) -> Option<RequestContext> {
    let mut pending = self.pending_requests.write().await;

    // Clone the winning key first so the immutable iteration borrow ends
    // before the map is mutated. An empty map yields `None` here.
    let newest = pending
        .iter()
        .max_by_key(|(_, ctx)| ctx.created_at)
        .map(|(id, _)| id.clone())?;

    let ctx = pending.remove(&newest)?;
    debug!(cascade = %newest, "Took latest request context (fallback)");
    Some(ctx)
}
/// Apply `updater` to the pending request context for `cascade_id`.
///
/// The closure runs while the write lock on the pending-request map is held.
/// Returns `true` if a context was found and updated, `false` otherwise
/// (in which case `updater` is never invoked).
pub async fn update_request<F>(&self, cascade_id: &str, updater: F) -> bool
where
    F: FnOnce(&mut RequestContext),
{
    match self.pending_requests.write().await.get_mut(cascade_id) {
        Some(ctx) => {
            updater(ctx);
            true
        }
        None => false,
    }
}
/// Discard the pending request context for `cascade_id`, if one exists.
///
/// Intended as cleanup once a response is complete; a missing entry is not
/// an error and produces no log line.
pub async fn remove_request(&self, cascade_id: &str) {
    let was_present = self
        .pending_requests
        .write()
        .await
        .remove(cascade_id)
        .is_some();
    if was_present {
        debug!(cascade = %cascade_id, "Removed request context");
    }
}
// ── Cascade cache (turn 0 context for re-injection on turn 1+) ──────
/// Store `cache` under `cascade_id` so later turns of the same cascade can
/// re-use the turn-0 context.
///
/// Any previously cached entry for the same cascade ID is overwritten.
pub async fn cache_cascade(&self, cascade_id: &str, cache: CascadeCache) {
    // Capture the logged properties before `cache` is moved into the map.
    let user_text_len = cache.user_text.len();
    let has_tools = cache.tools.is_some();
    debug!(
        cascade = %cascade_id,
        user_text_len = user_text_len,
        has_tools = has_tools,
        "Cached cascade context for subsequent turns"
    );

    let mut cached = self.cascade_cache.write().await;
    cached.insert(cascade_id.to_string(), cache);
}
/// Return a clone of the cached context for `cascade_id`, leaving the cache
/// entry in place (it is needed again on every subsequent turn).
///
/// Returns `None` when the cascade has no cached context.
pub async fn get_cascade_cache(&self, cascade_id: &str) -> Option<CascadeCache> {
    let cached = self.cascade_cache.read().await;
    cached.get(cascade_id).cloned()
}
/// Report whether a cached context exists for `cascade_id`
/// (i.e. turn 0 of that cascade has already been cached).
pub async fn has_cascade_cache(&self, cascade_id: &str) -> bool {
    let cached = self.cascade_cache.read().await;
    cached.contains_key(cascade_id)
}
// ── Usage recording ──────────────────────────────────────────────────
/// Record a completed API exchange with usage data. /// Record a completed API exchange with usage data.
pub async fn record_usage(&self, cascade_id: Option<&str>, usage: ApiUsage) { pub async fn record_usage(&self, cascade_id: Option<&str>, usage: ApiUsage) {
debug!( debug!(
@@ -314,13 +419,7 @@ impl MitmStore {
// Call 2: thinking summary text (thinking_output_tokens == 0, response_text has the summary) // Call 2: thinking summary text (thinking_output_tokens == 0, response_text has the summary)
// //
// When Call 2 arrives, we merge its response_text as thinking_text into Call 1's usage. // When Call 2 arrives, we merge its response_text as thinking_text into Call 1's usage.
let key = if let Some(cid) = cascade_id { let key = cascade_id.map_or_else(|| "_latest".to_string(), |s| s.to_string());
cid.to_string()
} else if let Some(active) = self.active_cascade_id.read().await.as_ref() {
active.clone()
} else {
"_latest".to_string()
};
let mut latest = self.latest_usage.write().await; let mut latest = self.latest_usage.write().await;
if let Some(existing) = latest.get_mut(&key) { if let Some(existing) = latest.get_mut(&key) {
@@ -346,7 +445,6 @@ impl MitmStore {
// Evict old entries to prevent unbounded memory growth // Evict old entries to prevent unbounded memory growth
const MAX_ENTRIES: usize = 500; const MAX_ENTRIES: usize = 500;
if latest.len() > MAX_ENTRIES { if latest.len() > MAX_ENTRIES {
// Find the oldest entry by captured_at and remove it
let oldest_key = latest let oldest_key = latest
.iter() .iter()
.min_by_key(|(_, v)| v.captured_at) .min_by_key(|(_, v)| v.captured_at)
@@ -357,18 +455,13 @@ impl MitmStore {
} }
} }
/// Get the latest usage for a cascade, consuming it (one-shot read).
///
/// Peek at usage data for a cascade without consuming it. /// Peek at usage data for a cascade without consuming it.
/// Used to check if thinking text has been merged before taking.
pub async fn peek_usage(&self, cascade_id: &str) -> Option<ApiUsage> { pub async fn peek_usage(&self, cascade_id: &str) -> Option<ApiUsage> {
let latest = self.latest_usage.read().await; let latest = self.latest_usage.read().await;
latest.get(cascade_id).cloned() latest.get(cascade_id).cloned()
} }
/// Only returns exact cascade_id matches — no cross-cascade fallback. /// Only returns exact cascade_id matches — no cross-cascade fallback.
/// The `_latest` key is only consumed when the caller explicitly requests it
/// (i.e., when the MITM couldn't identify the cascade).
pub async fn take_usage(&self, cascade_id: &str) -> Option<ApiUsage> { pub async fn take_usage(&self, cascade_id: &str) -> Option<ApiUsage> {
let mut latest = self.latest_usage.write().await; let mut latest = self.latest_usage.write().await;
latest.remove(cascade_id) latest.remove(cascade_id)
@@ -379,19 +472,11 @@ impl MitmStore {
self.stats.read().await.clone() self.stats.read().await.clone()
} }
// ── Function call capture ────────────────────────────────────────────
/// Record a captured function call from Google's response. /// Record a captured function call from Google's response.
///
/// Falls back to `active_cascade_id` (set by the API handler) when no
/// cascade hint is available from the request body, matching
/// `record_usage`'s fallback behavior for consistent correlation.
pub async fn record_function_call(&self, cascade_id: Option<&str>, fc: CapturedFunctionCall) { pub async fn record_function_call(&self, cascade_id: Option<&str>, fc: CapturedFunctionCall) {
let key = if let Some(cid) = cascade_id { let key = cascade_id.map_or_else(|| "_latest".to_string(), |s| s.to_string());
cid.to_string()
} else if let Some(active) = self.active_cascade_id.read().await.as_ref() {
active.clone()
} else {
"_latest".to_string()
};
info!( info!(
cascade = %key, cascade = %key,
tool = %fc.name, tool = %fc.name,
@@ -404,9 +489,7 @@ impl MitmStore {
/// Take pending function calls for a specific cascade. /// Take pending function calls for a specific cascade.
/// ///
/// Priority: exact cascade_id → active_cascade_id → `_latest` → any key. /// Priority: exact cascade_id → `_latest` → any key.
/// This prevents cross-cascade contamination when multiple requests are
/// in-flight simultaneously.
pub async fn take_function_calls(&self, cascade_id: &str) -> Option<Vec<CapturedFunctionCall>> { pub async fn take_function_calls(&self, cascade_id: &str) -> Option<Vec<CapturedFunctionCall>> {
let mut pending = self.pending_function_calls.write().await; let mut pending = self.pending_function_calls.write().await;
@@ -415,21 +498,12 @@ impl MitmStore {
return Some(result); return Some(result);
} }
// 2. Active cascade (set by API handler) // 2. Fallback to _latest
if let Some(active) = self.active_cascade_id.read().await.as_ref() {
if active != cascade_id {
if let Some(result) = pending.remove(active.as_str()) {
return Some(result);
}
}
}
// 3. Fallback to _latest
if let Some(result) = pending.remove("_latest") { if let Some(result) = pending.remove("_latest") {
return Some(result); return Some(result);
} }
// 4. Last resort: any key // 3. Last resort: any key
if let Some(key) = pending.keys().next().cloned() { if let Some(key) = pending.keys().next().cloned() {
return pending.remove(&key); return pending.remove(&key);
} }
@@ -438,7 +512,6 @@ impl MitmStore {
} }
/// Take any pending function calls (ignoring cascade ID). /// Take any pending function calls (ignoring cascade ID).
/// Legacy method — prefer `take_function_calls(cascade_id)` for proper correlation.
pub async fn take_any_function_calls(&self) -> Option<Vec<CapturedFunctionCall>> { pub async fn take_any_function_calls(&self) -> Option<Vec<CapturedFunctionCall>> {
let mut pending = self.pending_function_calls.write().await; let mut pending = self.pending_function_calls.write().await;
let result = pending.remove("_latest"); let result = pending.remove("_latest");
@@ -451,114 +524,24 @@ impl MitmStore {
None None
} }
// ── Channel-based event pipeline ───────────────────────────────────── /// Peek at the thought_signatures of recently captured function calls.
/// Returns a map of function_name → thought_signature (non-destructive).
/// Install a channel sender for the current tool-path request. pub async fn peek_thought_signatures(&self) -> std::collections::HashMap<String, String> {
/// The MITM proxy will send events through this channel. let pending = self.pending_function_calls.read().await;
pub async fn set_channel(&self, tx: mpsc::Sender<MitmEvent>) { let mut sigs = std::collections::HashMap::new();
*self.active_channel.write().await = Some(tx); for calls in pending.values() {
// NOTE: Do NOT set request_in_flight here. The MITM proxy's for fc in calls {
// try_mark_request_in_flight() is the sole setter — setting it if let Some(ref sig) = fc.thought_signature {
// here causes compare_exchange(false,true) to always fail, sigs.insert(fc.name.clone(), sig.clone());
// blocking every real LS request.
} }
/// Take the active channel sender (used by MITM proxy to grab it).
/// Returns None if no channel is active.
pub async fn take_channel(&self) -> Option<mpsc::Sender<MitmEvent>> {
self.active_channel.write().await.take()
} }
/// Drop the active channel and clear in-flight state.
/// Called when the API handler is done with the current request.
pub async fn drop_channel(&self) {
*self.active_channel.write().await = None;
self.request_in_flight.store(false, Ordering::SeqCst);
} }
sigs
// ── Tool context methods ─────────────────────────────────────────────
/// Set active tool definitions (already in Gemini format).
pub async fn set_tools(&self, tools: Vec<serde_json::Value>) {
*self.active_tools.write().await = Some(tools);
}
/// Get active tool definitions.
pub async fn get_tools(&self) -> Option<Vec<serde_json::Value>> {
self.active_tools.read().await.clone()
}
/// Clear active tool definitions.
pub async fn clear_tools(&self) {
*self.active_tools.write().await = None;
*self.active_tool_config.write().await = None;
// Also clear accumulated tool rounds to prevent stale data
self.tool_rounds.write().await.clear();
}
/// Set active tool config (Gemini toolConfig format).
pub async fn set_tool_config(&self, config: serde_json::Value) {
*self.active_tool_config.write().await = Some(config);
}
/// Get active tool config.
pub async fn get_tool_config(&self) -> Option<serde_json::Value> {
self.active_tool_config.read().await.clone()
}
/// Add a pending tool result for MITM injection.
pub async fn add_tool_result(&self, result: PendingToolResult) {
info!(name = %result.name, "Storing pending tool result");
self.pending_tool_results.write().await.push(result);
}
/// Take (consume) all pending tool results.
pub async fn take_tool_results(&self) -> Vec<PendingToolResult> {
std::mem::take(&mut *self.pending_tool_results.write().await)
}
/// Register a call_id → function name mapping.
pub async fn register_call_id(&self, call_id: String, name: String) {
self.call_id_to_name.write().await.insert(call_id, name);
}
/// Look up function name by call_id.
pub async fn lookup_call_id(&self, call_id: &str) -> Option<String> {
self.call_id_to_name.read().await.get(call_id).cloned()
}
/// Save the last captured function calls (for history rewriting).
pub async fn set_last_function_calls(&self, calls: Vec<CapturedFunctionCall>) {
*self.last_function_calls.write().await = calls;
}
/// Get the last captured function calls.
pub async fn get_last_function_calls(&self) -> Vec<CapturedFunctionCall> {
self.last_function_calls.read().await.clone()
}
/// Store multi-round tool call history for correct per-turn history rewriting.
pub async fn set_tool_rounds(&self, rounds: Vec<ToolRound>) {
*self.tool_rounds.write().await = rounds;
}
/// Take (consume) multi-round tool call history.
pub async fn take_tool_rounds(&self) -> Vec<ToolRound> {
std::mem::take(&mut *self.tool_rounds.write().await)
}
/// Get (non-destructive clone) multi-round tool call history.
/// Used by proxy.rs to read rounds without consuming them, so they
/// persist across multiple LS requests in the same cascade.
pub async fn get_tool_rounds(&self) -> Vec<ToolRound> {
self.tool_rounds.read().await.clone()
} }
// ── Legacy direct response capture (search.rs fallback) ────────────── // ── Legacy direct response capture (search.rs fallback) ──────────────
/// Set (replace) the captured response text. /// Set (replace) the captured response text.
/// Used by MITM proxy for non-channel path (search endpoint fallback).
pub async fn set_response_text(&self, text: &str) { pub async fn set_response_text(&self, text: &str) {
*self.captured_response_text.write().await = Some(text.to_string()); *self.captured_response_text.write().await = Some(text.to_string());
} }
@@ -568,71 +551,11 @@ impl MitmStore {
self.captured_response_text.write().await.take() self.captured_response_text.write().await.take()
} }
/// Clear stale state between requests. /// Clear stale legacy response state.
/// Drops any active channel and clears in-flight flags.
pub async fn clear_response_async(&self) { pub async fn clear_response_async(&self) {
self.request_in_flight.store(false, Ordering::SeqCst);
*self.active_channel.write().await = None;
*self.captured_response_text.write().await = None; *self.captured_response_text.write().await = None;
} }
/// Atomically try to mark request as in-flight.
/// Returns true if this caller won the race (was first to set it).
/// Returns false if already in-flight (someone else set it first).
pub fn try_mark_request_in_flight(&self) -> bool {
self.request_in_flight
.compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst)
.is_ok()
}
/// Check if a request is currently in-flight.
#[allow(dead_code)]
pub fn is_request_in_flight(&self) -> bool {
self.request_in_flight.load(Ordering::SeqCst)
}
/// Clear the in-flight flag so the LS can make follow-up requests.
pub fn clear_request_in_flight(&self) {
self.request_in_flight.store(false, Ordering::SeqCst);
}
// ── Cascade correlation ──────────────────────────────────────────────
/// Set the active cascade ID (called by API handlers before sending a message).
/// The MITM proxy will use this to correlate intercepted traffic.
pub async fn set_active_cascade(&self, cascade_id: &str) {
*self.active_cascade_id.write().await = Some(cascade_id.to_string());
}
/// Get the active cascade ID.
#[allow(dead_code)]
pub async fn get_active_cascade(&self) -> Option<String> {
self.active_cascade_id.read().await.clone()
}
/// Clear the active cascade ID (called after response is complete).
#[allow(dead_code)]
pub async fn clear_active_cascade(&self) {
*self.active_cascade_id.write().await = None;
}
// ── Generation parameters ────────────────────────────────────────────
/// Store client-specified generation parameters for MITM injection.
pub async fn set_generation_params(&self, params: GenerationParams) {
*self.generation_params.write().await = Some(params);
}
/// Read current generation parameters (non-consuming).
pub async fn get_generation_params(&self) -> Option<GenerationParams> {
self.generation_params.read().await.clone()
}
/// Clear generation parameters.
pub async fn clear_generation_params(&self) {
*self.generation_params.write().await = None;
}
// ── Grounding metadata capture ────────────────────────────────────── // ── Grounding metadata capture ──────────────────────────────────────
/// Store captured grounding metadata from API response. /// Store captured grounding metadata from API response.
@@ -652,46 +575,35 @@ impl MitmStore {
self.captured_grounding.read().await.clone() self.captured_grounding.read().await.clone()
} }
// ── Pending image for MITM injection ───────────────────────────────── // ── Compat shims for streaming tool-call loops ──────────────────────
/// Store a pending image for MITM injection. /// Update the event channel on an existing request context.
pub async fn set_pending_image(&self, image: PendingImage) { /// Used by streaming loop handlers when re-registering for a new tool round.
*self.pending_image.write().await = Some(image); pub async fn set_channel(&self, cascade_id: &str, tx: mpsc::Sender<MitmEvent>) {
self.update_request(cascade_id, |ctx| {
ctx.event_channel = Some(tx);
}).await;
} }
/// Take (consume) pending image for injection. /// No-op. Upstream errors are now delivered through the event channel.
pub async fn take_pending_image(&self) -> Option<PendingImage> { /// Kept for API handler compatibility.
self.pending_image.write().await.take()
}
// ── Upstream error capture ───────────────────────────────────────────
/// Store an upstream error from Google's API.
pub async fn set_upstream_error(&self, error: UpstreamError) {
*self.upstream_error.write().await = Some(error);
}
/// Take (consume) captured upstream error.
pub async fn take_upstream_error(&self) -> Option<UpstreamError> {
self.upstream_error.write().await.take()
}
/// Clear any stored upstream error.
pub async fn clear_upstream_error(&self) { pub async fn clear_upstream_error(&self) {
*self.upstream_error.write().await = None; // Intentionally empty — errors flow through MitmEvent::UpstreamError
} }
// ── Pending user text for MITM injection ───────────────────────────── /// Returns None. Upstream errors are now captured and delivered via the
/// per-request event channel rather than stored globally.
/// Store the real user text for MITM injection. pub async fn take_upstream_error(&self) -> Option<UpstreamError> {
/// Called by API handlers before sending a dummy prompt to the LS. None
pub async fn set_pending_user_text(&self, text: String) {
*self.pending_user_text.write().await = Some(text);
} }
/// Take (consume) the pending user text. /// Store a call_id → function_name mapping in the request context.
/// Called by the MITM proxy when building ToolContext. /// Used by streaming tool-call loops when the model returns function calls.
pub async fn take_pending_user_text(&self) -> Option<String> { pub async fn register_call_id(&self, cascade_id: &str, call_id: String, name: String) {
self.pending_user_text.write().await.take() self.update_request(cascade_id, |ctx| {
ctx.call_id_to_name.insert(call_id, name);
}).await;
} }
} }

View File

@@ -9,6 +9,10 @@
//! carries the `detect_and_use_proxy` enum, model selection, and version info. //! carries the `detect_and_use_proxy` enum, model selection, and version info.
//! See `docs/ls-binary-analysis.md` for the full proto schema reverse engineering. //! See `docs/ls-binary-analysis.md` for the full proto schema reverse engineering.
pub mod wire;
use crate::constants::{client_version, CLIENT_NAME}; use crate::constants::{client_version, CLIENT_NAME};
// ─── Wire primitives ──────────────────────────────────────────────────────── // ─── Wire primitives ────────────────────────────────────────────────────────

159
src/proto/wire.rs Normal file
View File

@@ -0,0 +1,159 @@
//! Shared protobuf wire-format primitives — decode + encode.
//!
//! This module is the single source of truth for varint encoding/decoding,
//! proto string encoding/extraction, etc. All other modules should import
//! from here instead of rolling their own.
/// Decode a base-128 varint from the start of `buf`.
///
/// Returns `Some((value, bytes_consumed))` on success. Trailing bytes after
/// the varint are ignored.
///
/// Returns `None` when the input is malformed:
/// - `buf` is empty or ends before a byte without the continuation bit,
/// - the varint is longer than 10 bytes,
/// - the 10th byte carries payload bits that do not fit in a `u64`
///   (previously these bits were silently shifted out, yielding a
///   truncated value for an over-long encoding).
///
/// This is the canonical decoder — all other modules should use this.
pub fn decode_varint(buf: &[u8]) -> Option<(u64, usize)> {
    // A u64 needs at most ceil(64 / 7) = 10 varint bytes.
    const MAX_VARINT_BYTES: usize = 10;

    let mut result: u64 = 0;
    for (i, &byte) in buf.iter().enumerate().take(MAX_VARINT_BYTES) {
        let payload = u64::from(byte & 0x7F);
        // The 10th byte is shifted by 63, so only its lowest bit fits.
        if i == MAX_VARINT_BYTES - 1 && payload > 1 {
            return None;
        }
        result |= payload << (7 * i);
        if byte & 0x80 == 0 {
            return Some((result, i + 1));
        }
    }
    // Ran out of input, or the 10th byte still had its continuation bit set.
    None
}
/// Append the base-128 varint encoding of `val` to `buf`.
///
/// Emits the value seven bits at a time, least-significant group first;
/// every byte except the last has its high (continuation) bit set.
/// Existing contents of `buf` are preserved.
pub fn encode_varint(buf: &mut Vec<u8>, mut val: u64) {
    // While more than seven significant bits remain, emit a continuation byte.
    while val >= 0x80 {
        buf.push(((val & 0x7F) as u8) | 0x80);
        val >>= 7;
    }
    // Final group: fits in seven bits, continuation bit clear.
    buf.push(val as u8);
}
/// Build a protobuf length-delimited field containing `data`.
///
/// Output layout: `[tag(field_num, wire_type=2)] [len] [data]`, where the
/// tag and the length are both varint-encoded.
pub fn encode_proto_string(field_num: u32, data: &[u8]) -> Vec<u8> {
    // Wire type 2 = length-delimited.
    const WIRE_TYPE_LEN: u32 = 2;

    let key = (field_num << 3) | WIRE_TYPE_LEN;
    // Capacity hint: one byte of tag, up to five of length, plus the payload.
    let mut out = Vec::with_capacity(1 + 5 + data.len());
    encode_varint(&mut out, u64::from(key));
    encode_varint(&mut out, data.len() as u64);
    out.extend_from_slice(data);
    out
}
/// Extract a string field from raw protobuf bytes by field number.
///
/// Walks the top-level fields of `buf`, skipping varint, 64-bit, 32-bit and
/// non-matching length-delimited fields, and returns the payload of the
/// first length-delimited field whose number equals `target_field`.
///
/// Returns `None` when:
/// - the field is not present,
/// - the first matching field is not valid UTF-8,
/// - the buffer is malformed or truncated (bad varint, declared length or
///   fixed-width value running past the end of `buf`),
/// - an unsupported wire type (anything other than 0, 1, 2, 5) is met.
///
/// Declared lengths are validated against the bytes actually remaining, so
/// a hostile length prefix cannot overflow the cursor arithmetic.
pub fn extract_proto_string(buf: &[u8], target_field: u32) -> Option<String> {
    let mut i = 0;
    while i < buf.len() {
        let (tag, consumed) = decode_varint(&buf[i..])?;
        i += consumed;
        // Keep the field number as u64 so an oversized number cannot be
        // truncated into a false match with `target_field`.
        let field_num = tag >> 3;
        let wire_type = (tag & 0x07) as u8;
        // `i <= buf.len()` always holds: the decoder never consumes more
        // than the slice it was given.
        let remaining = buf.len() - i;

        match wire_type {
            0 => {
                // Varint — skip
                let (_, c) = decode_varint(&buf[i..])?;
                i += c;
            }
            1 => {
                // 64-bit fixed — skip 8 bytes
                if remaining < 8 {
                    return None;
                }
                i += 8;
            }
            2 => {
                // Length-delimited
                let (len, c) = decode_varint(&buf[i..])?;
                i += c;
                let remaining = buf.len() - i;
                // Compare in u64 before narrowing: rejects lengths that
                // exceed the buffer (or usize) without any overflow.
                if len > remaining as u64 {
                    return None;
                }
                let end = i + len as usize;
                if field_num == u64::from(target_field) {
                    return std::str::from_utf8(&buf[i..end]).ok().map(str::to_owned);
                }
                i = end;
            }
            5 => {
                // 32-bit fixed — skip 4 bytes
                if remaining < 4 {
                    return None;
                }
                i += 4;
            }
            _ => return None, // Unknown wire type
        }
    }
    None
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Decoding of hand-picked single- and two-byte varints.
    #[test]
    fn test_decode_varint_basic() {
        let cases: [(&[u8], u64, usize); 6] = [
            (&[0x00], 0, 1),
            (&[0x01], 1, 1),
            (&[0x7F], 127, 1),
            (&[0x80, 0x01], 128, 2),
            (&[0x96, 0x01], 150, 2),
            (&[0xAC, 0x02], 300, 2),
        ];
        for &(bytes, value, consumed) in cases.iter() {
            assert_eq!(decode_varint(bytes), Some((value, consumed)));
        }
    }

    /// Every encoded value must decode back to itself and consume the
    /// whole encoding.
    #[test]
    fn test_encode_decode_roundtrip() {
        let samples = [0u64, 1, 127, 128, 300, 1026, u32::MAX as u64, u64::MAX];
        for &val in samples.iter() {
            let mut encoded = Vec::new();
            encode_varint(&mut encoded, val);
            let (decoded, consumed) = decode_varint(&encoded).unwrap();
            assert_eq!(decoded, val, "roundtrip failed for {val}");
            assert_eq!(consumed, encoded.len());
        }
    }

    /// Field 1 with payload "hello": tag(1,2) = 0x0A, len = 5, then the bytes.
    #[test]
    fn test_encode_proto_string() {
        let encoded = encode_proto_string(1, b"hello");
        assert_eq!(encoded[0], 0x0A);
        assert_eq!(encoded[1], 0x05);
        assert_eq!(&encoded[2..], b"hello");
    }

    /// Extraction must skip an interleaved varint field and report absent
    /// fields as `None`.
    #[test]
    fn test_extract_proto_string() {
        // Layout: field 1 = "abc", field 2 (varint) = 42, field 3 = "xyz"
        let mut message = encode_proto_string(1, b"abc");
        // field 2 varint 42: tag = (2<<3)|0 = 0x10, value = 0x2A
        message.extend_from_slice(&[0x10, 0x2A]);
        message.extend_from_slice(&encode_proto_string(3, b"xyz"));

        assert_eq!(extract_proto_string(&message, 1), Some("abc".to_string()));
        assert_eq!(extract_proto_string(&message, 3), Some("xyz".to_string()));
        assert_eq!(extract_proto_string(&message, 99), None);
    }
}

View File

@@ -8,7 +8,6 @@ use std::collections::HashMap;
use std::time::Instant; use std::time::Instant;
use tokio::sync::RwLock; use tokio::sync::RwLock;
const DEFAULT_SESSION: &str = "__default__";
const SESSION_TTL_SECS: u64 = 3600 * 4; // 4 hours const SESSION_TTL_SECS: u64 = 3600 * 4; // 4 hours
#[derive(Clone)] #[derive(Clone)]
@@ -23,10 +22,7 @@ pub struct SessionManager {
sessions: RwLock<HashMap<String, Session>>, sessions: RwLock<HashMap<String, Session>>,
} }
/// Result of session resolution.
pub struct SessionResult {
pub cascade_id: String,
}
impl SessionManager { impl SessionManager {
pub fn new() -> Self { pub fn new() -> Self {
@@ -35,82 +31,7 @@ impl SessionManager {
} }
} }
/// Get existing cascade for session, or create a new one.
///
/// - `session_id = None` → use default session
/// - `session_id = Some("new")` → always create fresh cascade
/// - `session_id = Some("my-task")` → reuse cascade for that task
///
/// Uses double-check locking to avoid TOCTOU races: after creating a cascade,
/// re-acquires the lock and checks if another request raced us.
pub async fn get_or_create<F, Fut>(
&self,
session_id: Option<&str>,
create_fn: F,
) -> Result<SessionResult, String>
where
F: FnOnce() -> Fut,
Fut: std::future::Future<Output = Result<String, String>>,
{
// "new" always creates a fresh cascade
if session_id == Some("new") {
let cascade_id = create_fn().await?;
let new_sid = format!("s-{}", &uuid::Uuid::new_v4().to_string()[..8]);
let mut sessions = self.sessions.write().await;
sessions.insert(
new_sid.clone(),
Session {
cascade_id: cascade_id.clone(),
created: Instant::now(),
last_used: Instant::now(),
msg_count: 0,
},
);
return Ok(SessionResult { cascade_id });
}
let sid = session_id.unwrap_or(DEFAULT_SESSION).to_string();
// Check existing — only need write lock for cleanup + mutation
{
let mut sessions = self.sessions.write().await;
cleanup_expired(&mut sessions);
if let Some(sess) = sessions.get_mut(&sid) {
sess.last_used = Instant::now();
sess.msg_count += 1;
return Ok(SessionResult {
cascade_id: sess.cascade_id.clone(),
});
}
}
// Lock released before async create_fn
// Create new cascade (this may take a while — lock is NOT held)
let cascade_id = create_fn().await?;
// Double-check: another request may have raced us and created the same session
{
let mut sessions = self.sessions.write().await;
if let Some(existing) = sessions.get_mut(&sid) {
// Another request won the race — use their cascade, discard ours
existing.last_used = Instant::now();
existing.msg_count += 1;
return Ok(SessionResult {
cascade_id: existing.cascade_id.clone(),
});
}
sessions.insert(
sid.clone(),
Session {
cascade_id: cascade_id.clone(),
created: Instant::now(),
last_used: Instant::now(),
msg_count: 1,
},
);
}
Ok(SessionResult { cascade_id })
}
/// List all active sessions. /// List all active sessions.
pub async fn list_sessions(&self) -> serde_json::Value { pub async fn list_sessions(&self) -> serde_json::Value {

File diff suppressed because it is too large Load Diff

340
src/standalone/discovery.rs Normal file
View File

@@ -0,0 +1,340 @@
//! LS process discovery — finding, inspecting, and managing LS processes.
use super::{MainLSConfig, LS_USER};
use crate::proto::wire::extract_proto_string;
use std::net::TcpListener;
use std::process::{Command, Stdio};
use tracing::info;
/// Read the extension server port and CSRF token of the running main LS.
///
/// Locates the main LS process, reads its NUL-separated
/// `/proc/<pid>/cmdline`, and picks out the values following the
/// `--csrf_token` / `--extension_server_port` flags (single-dash spellings
/// are accepted too). If a flag appears more than once, the last value wins.
///
/// The HTTPS port is deliberately not discovered — only the extension
/// server connection info is needed.
///
/// # Errors
///
/// Returns an error string if the main LS process cannot be found, its
/// command line cannot be read, or either flag is missing or has an empty
/// value.
pub fn discover_main_ls_config() -> Result<MainLSConfig, String> {
    let pid = find_main_ls_pid()?;
    let raw = std::fs::read(format!("/proc/{pid}/cmdline"))
        .map_err(|e| format!("Can't read cmdline for PID {pid}: {e}"))?;
    let argv: Vec<&[u8]> = raw.split(|&b| b == 0).collect();

    let mut csrf = String::new();
    let mut ext_port = String::new();

    // Each window is a (flag, value) candidate; pairs that are not valid
    // UTF-8 on both sides are ignored.
    for pair in argv.windows(2) {
        if let (Ok(flag), Ok(value)) = (
            std::str::from_utf8(pair[0]),
            std::str::from_utf8(pair[1]),
        ) {
            match flag {
                "--csrf_token" | "-csrf_token" => csrf = value.to_string(),
                "--extension_server_port" | "-extension_server_port" => {
                    ext_port = value.to_string();
                }
                _ => {}
            }
        }
    }

    if csrf.is_empty() {
        return Err("Could not find CSRF token from main LS".to_string());
    }
    if ext_port.is_empty() {
        return Err("Could not find extension_server_port from main LS".to_string());
    }

    info!(
        pid,
        ext_port,
        csrf_len = csrf.len(),
        "Discovered main LS config"
    );

    Ok(MainLSConfig {
        extension_server_port: ext_port,
        csrf,
    })
}
/// Find the PID of the main (real) LS process.
///
/// Scans the numeric entries of `/proc` and resolves each `/proc/<pid>/exe`
/// symlink. The first entry whose executable path contains
/// `language_server_linux` or `antigravity-language-server` is returned
/// (a trailing `" (deleted)"` marker is ignored). Matching on the
/// executable rather than the command line avoids picking up shell scripts
/// that merely mention the LS binary in their arguments.
///
/// # Errors
///
/// Returns an error string if `/proc` is missing or unreadable, or if no
/// matching process is found.
pub(super) fn find_main_ls_pid() -> Result<String, String> {
    const LS_BINARY_MARKERS: [&str; 2] = ["language_server_linux", "antigravity-language-server"];

    let proc_root = std::path::Path::new("/proc");
    if !proc_root.exists() {
        return Err("No /proc filesystem".to_string());
    }

    std::fs::read_dir(proc_root)
        .map_err(|e| format!("Cannot read /proc: {e}"))?
        .flatten()
        .find_map(|entry| {
            let pid = entry.file_name().to_string_lossy().into_owned();
            // Only numeric dirs (PIDs)
            if pid.chars().any(|c| !c.is_ascii_digit()) {
                return None;
            }

            // Unreadable or missing exe links are skipped.
            let exe_target = std::fs::read_link(entry.path().join("exe")).ok()?;
            let exe_text = exe_target.to_string_lossy();
            let exe_path = exe_text.trim_end_matches(" (deleted)");

            // Must be the actual LS binary, not a bash script
            if LS_BINARY_MARKERS.iter().any(|&marker| exe_path.contains(marker)) {
                Some(pid)
            } else {
                None
            }
        })
        .ok_or_else(|| "No main LS process found — Antigravity must be running".to_string())
}
/// Ask the OS for a currently unused TCP port on the loopback interface.
///
/// Binds `127.0.0.1:0` and reads back the port that was assigned. The probe
/// listener is dropped when the function returns, so the port is only known
/// to have been free at the time of the call.
///
/// # Errors
///
/// Returns a message if the bind fails or the local address cannot be read.
pub(super) fn find_free_port() -> Result<u16, String> {
    let probe = TcpListener::bind("127.0.0.1:0")
        .map_err(|e| format!("Failed to bind for port: {e}"))?;
    match probe.local_addr() {
        Ok(addr) => Ok(addr.port()),
        Err(e) => Err(format!("Failed to get port: {e}")),
    }
}
/// Report whether the dedicated LS system user (`LS_USER`) exists.
///
/// Runs `id -u <user>` with its output discarded and treats a successful exit
/// status as "exists"; failure to launch `id` counts as "does not exist".
///
/// When the user exists, the proxy spawns the LS as that UID so iptables
/// can scope the :443 redirect to only the standalone LS process.
pub(super) fn has_ls_user() -> bool {
    let probe = Command::new("id")
        .args(["-u", LS_USER])
        .stdout(Stdio::null())
        .stderr(Stdio::null())
        .status();
    matches!(probe, Ok(code) if code.success())
}
/// Look up the PID of a `language_server_linux` process owned by `user`.
///
/// Used to discover the actual LS process after sudo spawns it as a different
/// user. Runs `pgrep -u <user> -f language_server_linux` and parses the first
/// line of its output; any further lines are ignored.
///
/// # Errors
///
/// Returns a message if `pgrep` cannot be run, or if its first output line is
/// missing or is not a valid `u32`.
pub(super) fn find_ls_pid_for_user(user: &str) -> Result<u32, String> {
    let pgrep = Command::new("pgrep")
        .args(["-u", user, "-f", "language_server_linux"])
        .output()
        .map_err(|e| format!("pgrep failed: {e}"))?;

    let listing = String::from_utf8_lossy(&pgrep.stdout);
    let first_pid = match listing.lines().next() {
        Some(line) => line.trim().parse::<u32>().ok(),
        None => None,
    };
    first_pid.ok_or_else(|| format!("No LS process found for user {user}"))
}
/// Terminate standalone LS processes left behind by previous runs.
///
/// This covers the case where the proxy crashed or was killed without
/// cleaning up the sudo-spawned LS. It is a no-op when the dedicated LS user
/// does not exist, when `pgrep` cannot be run, or when no process is found.
///
/// Sequence: `SIGTERM` every PID, wait 500 ms, and if `pgrep` still reports
/// anything, `SIGKILL` the same PIDs, wait 300 ms and print manual
/// instructions to stderr if processes are still alive.
///
/// Signals are sent by running `kill` AS the LS user: the sudoers rule allows
/// running commands as that user (`ALL=(antigravity-ls) NOPASSWD: ALL`), and a
/// process may signal other processes with the same UID, so root is not
/// required. A root `sudo kill` is attempted only as a fallback for `SIGTERM`.
pub(super) fn cleanup_orphaned_ls() {
    /// Run `pgrep` for LS processes owned by the dedicated user.
    fn pgrep_ls() -> std::io::Result<std::process::Output> {
        Command::new("pgrep")
            .args(["-u", LS_USER, "-f", "language_server_linux"])
            .output()
    }

    /// True when `pgrep` ran and printed anything at all.
    fn any_ls_left() -> bool {
        pgrep_ls().map(|o| !o.stdout.is_empty()).unwrap_or(false)
    }

    /// Run `sudo -n [-u LS_USER] kill <signal> <pid>` with output discarded.
    fn sudo_kill(as_ls_user: bool, signal: &str, pid: u32) -> bool {
        let mut cmd = Command::new("sudo");
        cmd.arg("-n");
        if as_ls_user {
            cmd.args(["-u", LS_USER]);
        }
        cmd.args(["kill", signal]).arg(pid.to_string());
        cmd.stdout(Stdio::null())
            .stderr(Stdio::null())
            .status()
            .map(|s| s.success())
            .unwrap_or(false)
    }

    if !has_ls_user() {
        return;
    }

    // Find all LS processes owned by the LS user.
    let pids: Vec<u32> = match pgrep_ls() {
        Ok(out) => String::from_utf8_lossy(&out.stdout)
            .lines()
            .filter_map(|l| l.trim().parse().ok())
            .collect(),
        Err(_) => return,
    };
    if pids.is_empty() {
        return;
    }

    info!(
        count = pids.len(),
        ?pids,
        "Cleaning up orphaned standalone LS processes"
    );

    // Graceful pass: same-UID kill first, root kill as a fallback
    // (the fallback needs a separate sudoers entry).
    for pid in &pids {
        if sudo_kill(true, "-TERM", *pid) {
            info!(pid, "Killed orphaned LS process");
        } else {
            let _ = sudo_kill(false, "-TERM", *pid);
        }
    }

    // Wait for graceful exit.
    std::thread::sleep(std::time::Duration::from_millis(500));

    if !any_ls_left() {
        info!("Orphaned LS processes cleaned up");
        return;
    }

    // Force-kill any survivors.
    info!("Orphaned LS still alive, force killing");
    for pid in &pids {
        let _ = sudo_kill(true, "-KILL", *pid);
    }
    std::thread::sleep(std::time::Duration::from_millis(300));

    // Final check — nothing more we can do automatically.
    if any_ls_left() {
        eprintln!("\n \x1b[1;31m⚠ Cannot kill orphaned LS process\x1b[0m");
        eprintln!(" Run: \x1b[1msudo pkill -u {LS_USER} -f language_server_linux\x1b[0m\n");
    }
}
/// Load the OAuth state that Antigravity persisted in its `state.vscdb`.
///
/// The database (under `$HOME/.config/Antigravity/User/globalStorage/`) stores
/// the Topic proto, base64-encoded, under the key
/// `antigravityUnifiedStateSync.oauthToken`. It is queried by shelling out to
/// the `sqlite3` CLI. Per the original notes, the Topic carries access token,
/// refresh token and expiry, which lets the LS refresh tokens on its own.
///
/// Returns `(access_token, topic_proto_bytes)`. The access token is taken from
/// the `antigravityAuthStatus` entry when possible, otherwise extracted from
/// the Topic bytes, and is an empty string if neither source yields one.
///
/// Returns `None` when `HOME` is unset, the DB file is absent, `sqlite3`
/// cannot be run or fails, or the stored value is empty / not valid base64.
pub(super) fn read_oauth_from_state_db() -> Option<(String, Vec<u8>)> {
    use base64::Engine;

    let home = std::env::var("HOME").ok()?;
    let db_path = format!("{home}/.config/Antigravity/User/globalStorage/state.vscdb");
    if !std::path::Path::new(&db_path).exists() {
        return None;
    }

    // The Topic proto is stored base64-encoded in the DB.
    let query = std::process::Command::new("sqlite3")
        .args([
            db_path.as_str(),
            "SELECT value FROM ItemTable WHERE key='antigravityUnifiedStateSync.oauthToken'",
        ])
        .output()
        .ok()?;
    if !query.status.success() {
        return None;
    }

    let encoded = String::from_utf8_lossy(&query.stdout);
    let encoded = encoded.trim();
    if encoded.is_empty() {
        return None;
    }

    // Raw Topic proto bytes; an empty decode result is treated as "no data".
    let topic_bytes = base64::engine::general_purpose::STANDARD
        .decode(encoded)
        .ok()
        .filter(|bytes| !bytes.is_empty())?;

    // Prefer the token from the auth status entry; fall back to the
    // OAuthTokenInfo embedded (base64) inside the Topic proto.
    let access_token = match read_access_token_from_auth_status(&db_path) {
        Some(token) => token,
        None => extract_access_token_from_topic(&topic_bytes).unwrap_or_default(),
    };

    Some((access_token, topic_bytes))
}
/// Fetch the current access token from the `antigravityAuthStatus` entry of
/// `state.vscdb`.
///
/// The entry is a JSON object; the token is the value of its `apiKey` field.
/// The value is located with a plain substring search for `"apiKey":"` (no
/// JSON parsing) and runs up to the next `"`.
///
/// Returns `None` when `sqlite3` cannot be run or fails, the field is not
/// found, or the value does not start with `ya29.`.
fn read_access_token_from_auth_status(db_path: &str) -> Option<String> {
    const MARKER: &str = "\"apiKey\":\"";

    let result = std::process::Command::new("sqlite3")
        .args([
            db_path,
            "SELECT value FROM ItemTable WHERE key='antigravityAuthStatus'",
        ])
        .output()
        .ok()?;
    if !result.status.success() {
        return None;
    }

    let raw = String::from_utf8_lossy(&result.stdout);
    let json_str = raw.trim();

    // Text right after the first marker, cut at the next double quote.
    let (_, after_marker) = json_str.split_once(MARKER)?;
    let (api_key, _) = after_marker.split_once('"')?;

    Some(api_key)
        .filter(|key| key.starts_with("ya29."))
        .map(str::to_string)
}
/// Extract access_token from the Topic proto bytes by finding the inner
/// base64-encoded OAuthTokenInfo and decoding its first string field.
fn extract_access_token_from_topic(topic_bytes: &[u8]) -> Option<String> {
use base64::Engine;
let as_str = String::from_utf8_lossy(topic_bytes);
for segment in as_str.split(|c: char| !c.is_alphanumeric() && c != '+' && c != '/' && c != '=')
{
if segment.len() > 50 {
if let Ok(decoded) = base64::engine::general_purpose::STANDARD.decode(segment) {
// Use shared proto decoder
if let Some(token) = extract_proto_string(&decoded, 1) {
if token.starts_with("ya29.") {
return Some(token);
}
}
}
}
}
None
}

137
src/standalone/mod.rs Normal file
View File

@@ -0,0 +1,137 @@
//! Standalone Language Server — spawn and lifecycle management.
//!
//! Launches an isolated LS instance as a child process that the proxy fully owns.
//! In **headless** mode, the LS runs completely independently — no running
//! Antigravity app required. The extension server is replaced by a local stub
//! (requested via config port `0`), CSRF is self-generated, and MITM uses
//! `HTTPS_PROXY` plus an `LD_PRELOAD` DNS redirect instead of iptables.
mod discovery;
mod spawn;
mod stub;
use std::process::Command;
use tracing::info;
use uuid::Uuid;
// Re-export public API
pub use spawn::StandaloneLS;
/// Default path to the LS binary.
const LS_BINARY_PATH: &str =
"/usr/share/antigravity/resources/app/extensions/antigravity/bin/language_server_linux_x64";
/// App root, exported to the LS as the `ANTIGRAVITY_EDITOR_APP_ROOT` env var.
const APP_ROOT: &str = "/usr/share/antigravity/resources/app";
/// Data directory for the standalone LS.
const DATA_DIR: &str = "/tmp/antigravity-standalone";
/// System user for UID-scoped iptables isolation.
const LS_USER: &str = "antigravity-ls";
/// Path where the compiled dns_redirect.so preload library is written.
const DNS_REDIRECT_SO_PATH: &str = "/tmp/antigravity-dns-redirect.so";
/// C source of the DNS redirect preload library, embedded into this binary at
/// compile time (`include_str!` resolves the path relative to this source file).
const DNS_REDIRECT_C_SOURCE: &str = include_str!("../mitm/dns_redirect.c");
/// Config needed to bootstrap the standalone LS.
///
/// In normal mode, discovered from the running main LS.
/// In headless mode, generated entirely by the proxy.
pub struct MainLSConfig {
/// Extension server port, kept as the string that is passed on the LS
/// command line. Headless configs use `"0"`.
pub extension_server_port: String,
/// CSRF token passed to the LS on its command line.
pub csrf: String,
}
/// Optional MITM proxy config for the standalone LS.
pub struct StandaloneMitmConfig {
/// Full proxy URL including the scheme, e.g. "http://127.0.0.1:8742".
pub proxy_addr: String,
/// Path to the MITM CA certificate (`.pem`).
pub ca_cert_path: String,
}
/// Build a self-contained [`MainLSConfig`] for headless mode.
///
/// No running Antigravity instance is needed: the CSRF token is a freshly
/// generated random (v4) UUID and the extension server port is set to `"0"`,
/// meaning there is no real extension server to connect to. Only the token
/// length is logged.
pub fn generate_standalone_config() -> MainLSConfig {
    let config = MainLSConfig {
        // "0" = no real extension server
        extension_server_port: String::from("0"),
        csrf: Uuid::new_v4().to_string(),
    };
    info!(
        csrf_len = config.csrf.len(),
        "Generated standalone config (headless)"
    );
    config
}
/// Discover only the extension_server_port and csrf_token from the running main LS.
///
/// This does NOT discover the HTTPS port — we don't need to talk to the real LS,
/// only steal its extension server connection info.
///
/// Public entry point that forwards to the implementation in the private
/// `discovery` submodule and returns its result unchanged.
pub fn discover_main_ls_config() -> Result<MainLSConfig, String> {
discovery::discover_main_ls_config()
}
/// Build the dns_redirect.so preload library if it doesn't already exist.
///
/// The library hooks `getaddrinfo()` via LD_PRELOAD to redirect Google API
/// domain lookups to 127.0.0.1. This is needed because the LS binary uses
/// CGO for DNS resolution (libc getaddrinfo) but raw syscalls for connect(),
/// so only DNS can be intercepted via LD_PRELOAD.
///
/// The embedded C source is written next to the target as `<so_path>.c`,
/// compiled with `gcc -shared -fPIC`, and the scratch source is removed again
/// whether or not compilation succeeded.
///
/// Returns the path to the .so on success, `None` on failure (source could not
/// be written, gcc is missing, or compilation failed — each is logged as a
/// warning).
fn build_dns_redirect_so() -> Option<String> {
    let so_path = DNS_REDIRECT_SO_PATH;

    // Skip rebuild if already exists.
    // NOTE(review): the path is a fixed, predictable name and an existing file
    // is trusted as-is and later LD_PRELOADed into the LS. If it lives in a
    // world-writable directory, another local user could pre-create it, and a
    // stale build is never refreshed when the embedded source changes.
    // Consider a private directory or an ownership/content check.
    if std::path::Path::new(so_path).exists() {
        return Some(so_path.to_string());
    }

    // Write C source to a scratch file next to the target.
    let c_path = format!("{so_path}.c");
    if let Err(e) = std::fs::write(&c_path, DNS_REDIRECT_C_SOURCE) {
        tracing::warn!("Failed to write dns_redirect.c: {e}");
        return None;
    }

    // Compile: gcc -shared -fPIC -o dns_redirect.so dns_redirect.c -ldl
    let output = Command::new("gcc")
        .args(["-shared", "-fPIC", "-o", so_path, c_path.as_str(), "-ldl"])
        .output();

    // The scratch source is no longer needed once gcc has run (or failed to
    // start); remove it on every path so a failed build leaves nothing behind.
    let _ = std::fs::remove_file(&c_path);

    match output {
        Ok(out) if out.status.success() => {
            info!("Built dns_redirect.so at {so_path}");
            Some(so_path.to_string())
        }
        Ok(out) => {
            let stderr = String::from_utf8_lossy(&out.stderr);
            tracing::warn!("Failed to compile dns_redirect.so: {stderr}");
            None
        }
        Err(e) => {
            tracing::warn!("gcc not found, cannot build dns_redirect.so: {e}");
            None
        }
    }
}
#[cfg(test)]
mod tests {
    use super::discovery::find_free_port;
    use std::net::TcpListener;

    /// The port returned by `find_free_port` is non-zero and can be bound
    /// again immediately afterwards.
    #[test]
    fn test_find_free_port() {
        let port = find_free_port().unwrap();
        assert!(port > 0);

        // Port should be available — try binding to it
        let rebound = TcpListener::bind(("127.0.0.1", port));
        assert!(rebound.is_ok(), "Port {port} should be free");
    }
}

464
src/standalone/spawn.rs Normal file
View File

@@ -0,0 +1,464 @@
//! StandaloneLS — process lifecycle (spawn, wait, kill).
use super::discovery::{cleanup_orphaned_ls, find_free_port, find_ls_pid_for_user, has_ls_user, read_oauth_from_state_db};
use super::stub::stub_handle_connection;
use super::{build_dns_redirect_so, MainLSConfig, StandaloneMitmConfig, APP_ROOT, DATA_DIR, LS_BINARY_PATH, LS_USER};
use crate::constants;
use crate::proto;
use std::io::Write;
use std::net::TcpListener;
use std::process::{Child, Command, Stdio};
use tokio::time::{sleep, Duration};
use tracing::{debug, info};
/// A running standalone LS process.
///
/// Created by `StandaloneLS::spawn`; the process is terminated by `kill`,
/// which is also invoked automatically on drop.
pub struct StandaloneLS {
/// Handle of the process we spawned directly: the LS binary itself, or the
/// `sudo` wrapper when UID isolation is used.
child: Child,
/// The actual LS process PID (may differ from child PID when spawned via sudo).
ls_pid: Option<u32>,
/// Whether the LS was spawned via sudo (needs sudo kill).
use_sudo: bool,
/// Whether kill() has already been called.
killed: bool,
/// Port passed to the LS as `-server_port`; `wait_ready` polls it.
pub port: u16,
/// CSRF token passed to the LS as `-csrf_token`.
pub csrf: String,
}
impl StandaloneLS {
/// Spawn a standalone LS process.
///
/// Takes the extension server port and CSRF token from `main_config`
/// (nothing is discovered here), picks free ports for the server and LSP
/// endpoints, builds the init metadata, prepares the data directory and
/// launches the binary, feeding the metadata through stdin.
///
/// If `mitm_config` is provided, a combined CA bundle is written and
/// `SSL_CERT_FILE` (plus related variables) is set; `HTTPS_PROXY` /
/// `HTTP_PROXY` and the LD_PRELOAD DNS redirect are added only when running
/// headless or without the dedicated LS user.
///
/// If `headless` is true and `main_config.extension_server_port` parses to
/// `0` (or does not parse), a stub extension server is started on a
/// background thread and the LS is pointed at it. Headless mode never uses
/// sudo; otherwise the LS is spawned via `sudo -u` when the LS user exists.
///
/// Orphaned LS processes from previous runs are cleaned up first.
///
/// # Errors
///
/// Returns a message when no free port can be obtained, the stub listener
/// cannot be set up, the MITM CA cannot be read or the combined bundle cannot
/// be written, the debug log file cannot be created, the process fails to
/// spawn, or the init metadata cannot be written to its stdin.
pub fn spawn(
main_config: &MainLSConfig,
mitm_config: Option<&StandaloneMitmConfig>,
headless: bool,
) -> Result<Self, String> {
// Kill any orphaned LS processes from previous runs
cleanup_orphaned_ls();
let port = find_free_port()?;
let lsp_port = find_free_port()?;
// Unix timestamp (seconds); used to make the API key, session id and
// workspace id unique per spawn. Falls back to 0 if the clock is before the epoch.
let ts = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_default()
.as_secs();
// Build init metadata protobuf
let api_key = format!("standalone-api-key-{ts}");
let session_id = format!("standalone-session-{ts}");
let metadata = proto::build_init_metadata(
&api_key,
constants::antigravity_version(),
constants::client_version(),
&session_id,
1, // DETECT_AND_USE_PROXY_ENABLED
);
// Setup data dir (mode 1777 so both current user and antigravity-ls can write)
let gemini_dir = format!("{DATA_DIR}/.gemini");
let app_data_dir = format!("{DATA_DIR}/.gemini/antigravity-standalone");
let annotations_dir = format!("{app_data_dir}/annotations");
let brain_dir = format!("{app_data_dir}/brain");
for dir in [
DATA_DIR,
&gemini_dir,
&app_data_dir,
&annotations_dir,
&brain_dir,
] {
// Best effort: failures are ignored here and surfaced by the
// write test below.
let _ = std::fs::create_dir_all(dir);
#[cfg(unix)]
{
use std::os::unix::fs::PermissionsExt;
let _ = std::fs::set_permissions(dir, std::fs::Permissions::from_mode(0o1777));
}
}
// Check if data dir is writable by writing a test file.
// Old runs as `antigravity-ls` user leave dirs owned by that user.
// A failure only prints a warning; spawning continues regardless.
let test_path = format!("{app_data_dir}/.write_test");
if std::fs::write(&test_path, b"ok").is_err() {
eprintln!(
"\n ⚠ Data dir {} is not writable (owned by another user from previous sudo run)\n \
Fix with: sudo chmod -R a+rwX {}\n",
app_data_dir, DATA_DIR
);
} else {
let _ = std::fs::remove_file(&test_path);
}
// Pre-seed user_settings.pb with detect_and_use_proxy = ENABLED.
// The LS reads this at startup when creating its HTTP transport.
// Without it, the LS ignores HTTPS_PROXY and API traffic bypasses MITM.
// UserSettings proto: field 34 (varint) = 1 (DETECT_AND_USE_PROXY_ENABLED)
// Tag: (34 << 3) | 0 = 272 → varint [0x90, 0x02]
// Value: 1 → varint [0x01]
let settings_path = format!("{app_data_dir}/user_settings.pb");
let settings_bytes: &[u8] = &[0x90, 0x02, 0x01];
if let Err(e) = std::fs::write(&settings_path, settings_bytes) {
tracing::warn!("Failed to pre-seed user_settings.pb: {e}");
} else {
#[cfg(unix)]
{
use std::os::unix::fs::PermissionsExt;
let _ = std::fs::set_permissions(
&settings_path,
std::fs::Permissions::from_mode(0o0666),
);
}
tracing::info!("Pre-seeded user_settings.pb (detect_and_use_proxy=ENABLED)");
}
// In headless mode, spawn a stub TCP listener to serve as the extension server.
// The LS connects to this port and calls LanguageServerStarted — without it,
// the LS never fully initializes and won't accept connections on its server_port.
// Despite the leading underscore, this binding is read below to pick the port.
let _stub_listener = if headless {
// An unparsable port is treated like 0, i.e. "start the stub".
let stub_port: u16 = main_config.extension_server_port.parse().unwrap_or(0);
if stub_port == 0 {
// Create a real listener so the LS can connect
let listener = TcpListener::bind("127.0.0.1:0")
.map_err(|e| format!("Failed to bind stub extension server: {e}"))?;
let actual_port = listener
.local_addr()
.map_err(|e| format!("Failed to get stub port: {e}"))?
.port();
info!(
port = actual_port,
"Stub extension server listening (headless)"
);
// Read OAuth state from Antigravity's state.vscdb if available.
// The DB stores the exact Topic proto (access_token + refresh_token + expiry)
// which lets the LS auto-refresh tokens via its built-in Google OAuth2 client.
let (oauth_token, oauth_topic_bytes) = read_oauth_from_state_db()
.map(|(token, topic)| {
info!("Loaded OAuth token from Antigravity state.vscdb");
(token, Some(topic))
})
.unwrap_or_else(|| {
// Fall back to env var / token file
let token = std::env::var("ANTIGRAVITY_OAUTH_TOKEN")
.ok()
.filter(|s| !s.is_empty())
.or_else(|| {
let home = std::env::var("HOME").unwrap_or_default();
let path = format!("{home}/.config/antigravity-proxy-token");
std::fs::read_to_string(&path)
.ok()
.map(|s| s.trim().to_string())
.filter(|s| !s.is_empty())
})
.unwrap_or_default();
if !token.is_empty() {
info!("Loaded OAuth token from env/file (no refresh token — manual refresh needed)");
} else {
eprintln!("[headless] ⚠ No OAuth token found. Login to Antigravity first, or set ANTIGRAVITY_OAUTH_TOKEN");
}
(token, None)
});
// Shared read-only with every connection-handler thread.
let oauth_arc = std::sync::Arc::new(oauth_token);
let topic_arc = std::sync::Arc::new(oauth_topic_bytes);
// Spawn an accept thread that hands each connection to the stub handler.
// The thread owns a clone of the listener and stops at the first accept error.
let listener_clone = listener
.try_clone()
.map_err(|e| format!("Failed to clone stub listener: {e}"))?;
std::thread::spawn(move || {
for stream in listener_clone.incoming() {
match stream {
Ok(conn) => {
let token = std::sync::Arc::clone(&oauth_arc);
let topic = std::sync::Arc::clone(&topic_arc);
// Handle each connection in its own thread
std::thread::spawn(move || {
stub_handle_connection(conn, &token, &topic);
});
}
Err(_) => break,
}
}
});
// Update the extension_server_port to the stub's port
// (we need to use this in args below)
Some((listener, actual_port))
} else {
None
}
} else {
None
};
// Determine the actual extension_server_port to use
let ext_port = if let Some((_, stub_port)) = &_stub_listener {
stub_port.to_string()
} else {
main_config.extension_server_port.clone()
};
// LS args — NO -standalone flag (it disables TCP listeners entirely)
// NOTE: do NOT use -random_port — it overrides -server_port and the LS
// would listen on a random port we can't discover.
let args = vec![
"-enable_lsp".to_string(),
format!("-lsp_port={}", lsp_port),
"-extension_server_port".to_string(),
ext_port,
"-csrf_token".to_string(),
main_config.csrf.clone(),
"-server_port".to_string(),
port.to_string(),
"-workspace_id".to_string(),
format!("standalone_{ts}"),
"-cloud_code_endpoint".to_string(),
// When MITM is active, append the MITM port to the endpoint URL.
// The LS's CodeAssistClient ignores HTTPS_PROXY (hardcoded Proxy:nil),
// so we redirect at the DNS+port level instead:
// 1. LD_PRELOAD hooks getaddrinfo() → 127.0.0.1 for API domains
// 2. Custom port in URL → LS connects to 127.0.0.1:MITM_PORT
// 3. MITM proxy intercepts the transparent TLS connection via SNI
if let Some(mitm) = mitm_config {
// Extract port from proxy_addr (e.g. "http://127.0.0.1:8742" → "8742")
// NOTE(review): `rsplit(':').next()` always yields `Some`, so the
// "8742" fallback is never used; a proxy_addr without an explicit
// port (or with a trailing path) would give a malformed port here.
let mitm_port = mitm.proxy_addr.rsplit(':').next().unwrap_or("8742");
format!("https://daily-cloudcode-pa.googleapis.com:{mitm_port}")
} else {
"https://daily-cloudcode-pa.googleapis.com".to_string()
},
"-app_data_dir".to_string(),
"antigravity-standalone".to_string(),
"-gemini_dir".to_string(),
gemini_dir,
];
info!(port, "Spawning standalone LS");
// NOTE(review): `args` contains the CSRF token, so it appears in debug-level logs.
debug!(?args, "LS args");
// Build env vars for the LS process
let mut env_vars: Vec<(String, String)> =
vec![("ANTIGRAVITY_EDITOR_APP_ROOT".into(), APP_ROOT.into())];
// If MITM is enabled, add SSL + proxy env vars
if let Some(mitm) = mitm_config {
// Go's SSL_CERT_FILE replaces the entire system cert pool, so we
// need a combined bundle: system CAs + our MITM CA
// Write to /tmp — accessible by antigravity-ls user
// (user's ~/.config/ is not traversable by other UIDs)
let combined_ca_path = "/tmp/antigravity-mitm-combined-ca.pem".to_string();
// A missing system bundle is tolerated (empty string); a missing MITM CA is an error.
let system_ca =
std::fs::read_to_string("/etc/ssl/certs/ca-certificates.crt").unwrap_or_default();
let mitm_ca = std::fs::read_to_string(&mitm.ca_cert_path)
.map_err(|e| format!("Failed to read MITM CA cert: {e}"))?;
std::fs::write(&combined_ca_path, format!("{system_ca}\n{mitm_ca}"))
.map_err(|e| format!("Failed to write combined CA bundle: {e}"))?;
// Make readable by antigravity-ls user
#[cfg(unix)]
{
use std::os::unix::fs::PermissionsExt;
let _ = std::fs::set_permissions(
&combined_ca_path,
std::fs::Permissions::from_mode(0o644),
);
}
info!(
proxy = %mitm.proxy_addr,
ca = %combined_ca_path,
"Setting MITM env vars on standalone LS (combined CA bundle)"
);
env_vars.push(("SSL_CERT_FILE".into(), combined_ca_path));
env_vars.push(("SSL_CERT_DIR".into(), "/dev/null".into()));
env_vars.push(("NODE_EXTRA_CA_CERTS".into(), mitm.ca_cert_path.clone()));
// Only set HTTPS_PROXY when iptables UID isolation is NOT available
// OR when running in headless mode (no sudo at all).
// With iptables, all outbound traffic is transparently redirected at the
// kernel level — setting HTTPS_PROXY on top causes double-proxying.
if headless || !has_ls_user() {
// proxy_addr already includes the scheme (e.g. "http://127.0.0.1:8742")
env_vars.push(("HTTPS_PROXY".into(), mitm.proxy_addr.clone()));
env_vars.push(("HTTP_PROXY".into(), mitm.proxy_addr.clone()));
// LD_PRELOAD DNS redirect: hooks getaddrinfo() so Google API domains
// resolve to 127.0.0.1. Combined with the port-modified endpoint URL,
// this makes the LS connect to our MITM proxy for ALL API calls —
// even the CodeAssistClient which has Proxy:nil hardcoded.
// If the library cannot be built, the LS is still spawned without it.
let so_path = build_dns_redirect_so();
if let Some(so) = so_path {
info!(path = %so, "Enabling LD_PRELOAD DNS redirect for headless MITM");
env_vars.push(("LD_PRELOAD".into(), so));
env_vars.push((
"DNS_REDIRECT_LOG".into(),
"/tmp/antigravity-dns-redirect.log".into(),
));
}
}
}
// In headless mode, never use sudo — run as current user
// In normal mode, use sudo if 'antigravity-ls' user exists
let use_sudo = !headless && has_ls_user();
let mut cmd = if use_sudo {
info!("Using UID isolation: spawning LS as 'antigravity-ls' user");
// Env vars are passed as `KEY=VALUE` arguments to /usr/bin/env, which
// then runs the LS binary as the target user.
let mut c = Command::new("sudo");
c.args(["-n", "-u", LS_USER, "--", "/usr/bin/env"]);
for (k, v) in &env_vars {
c.arg(format!("{k}={v}"));
}
c.arg(LS_BINARY_PATH);
c.args(&args);
c
} else {
debug!("Spawning LS as current user");
let mut c = Command::new(LS_BINARY_PATH);
c.args(&args);
for (k, v) in &env_vars {
c.env(k, v);
}
c
};
// Capture stderr for debugging — logs to /tmp so we can diagnose LS failures
// (the file is truncated on every spawn)
let stderr_file = std::fs::File::create("/tmp/antigravity-ls-debug.log")
.map_err(|e| format!("Failed to create LS debug log: {e}"))?;
cmd.stdin(Stdio::piped())
.stdout(Stdio::null())
.stderr(Stdio::from(stderr_file));
let mut child = cmd
.spawn()
.map_err(|e| format!("Failed to spawn LS binary: {e}"))?;
// Feed init metadata via stdin, then close it
if let Some(mut stdin) = child.stdin.take() {
stdin
.write_all(&metadata)
.map_err(|e| format!("Failed to write init metadata to stdin: {e}"))?;
// stdin drops here → EOF (LS handles this fine in non-standalone mode)
}
info!(pid = child.id(), port, "Standalone LS spawned");
// When spawned via sudo, the child is the sudo process which exits after
// launching the LS as the target user. We need the actual LS PID for cleanup.
let ls_pid = if use_sudo {
// Give sudo a moment to spawn the real process
std::thread::sleep(std::time::Duration::from_millis(500));
// Find the LS process owned by antigravity-ls user
// (a lookup failure leaves `ls_pid` as None; `kill` then falls back
// to the blanket orphan cleanup)
find_ls_pid_for_user(LS_USER).ok()
} else {
Some(child.id())
};
if let Some(pid) = ls_pid {
info!(
ls_pid = pid,
sudo = use_sudo,
"Discovered actual LS process"
);
}
Ok(StandaloneLS {
child,
ls_pid,
use_sudo,
killed: false,
port,
csrf: main_config.csrf.clone(),
})
}
/// Block (asynchronously) until the standalone LS accepts TCP connections.
///
/// Makes up to `max_attempts` attempts. Each attempt first sleeps for one
/// second, then checks that the spawned child has not exited, then tries to
/// connect to `127.0.0.1:<port>`; the first successful connect means ready.
///
/// # Errors
///
/// Returns a message if the child exits (or its status cannot be queried)
/// while waiting, or if every attempt fails to connect.
pub async fn wait_ready(&mut self, max_attempts: u32) -> Result<(), String> {
    info!(port = self.port, "Waiting for standalone LS to be ready...");
    let addr = format!("127.0.0.1:{}", self.port);

    for attempt in 1..=max_attempts {
        sleep(Duration::from_secs(1)).await;

        // Bail out early if the process is gone.
        let exit = self
            .child
            .try_wait()
            .map_err(|e| format!("Failed to check LS process status: {e}"))?;
        if let Some(status) = exit {
            return Err(format!(
                "Standalone LS exited prematurely with status: {status}"
            ));
        }

        // Simple TCP connect check — if the LS is listening, it's ready.
        if let Err(e) = tokio::net::TcpStream::connect(addr.as_str()).await {
            debug!(attempt, error = %e, "LS not ready yet");
            continue;
        }
        info!(attempt, "Standalone LS is ready (accepting connections)");
        return Ok(());
    }

    Err(format!(
        "Standalone LS failed to become ready after {max_attempts} attempts on port {}",
        self.port
    ))
}
/// Report whether the spawned child process is still running.
///
/// Returns `false` both when the child has exited and when its status cannot
/// be queried.
#[allow(dead_code)]
pub fn is_alive(&mut self) -> bool {
    match self.child.try_wait() {
        Ok(None) => true,
        Ok(Some(_)) | Err(_) => false,
    }
}
/// Kill the standalone LS process.
///
/// Idempotent: only the first call does anything (guarded by `killed`), so an
/// explicit call followed by the `Drop` impl is fine.
///
/// * Without sudo, the child is killed and waited on directly.
/// * With sudo, the recorded LS PID is sent `SIGTERM` by running `kill` as the
///   LS user (a process may signal others with the same UID), followed 500 ms
///   later by `SIGKILL`. If the `SIGTERM` command fails, or no PID was
///   recorded, the blanket orphan cleanup is used instead. Finally the `sudo`
///   wrapper is reaped with a non-blocking `try_wait` so it does not linger as
///   a zombie once it has exited.
pub fn kill(&mut self) {
    if self.killed {
        return;
    }
    self.killed = true;
    info!("Killing standalone LS");

    if !self.use_sudo {
        // Direct child: signal it and reap it.
        let _ = self.child.kill();
        let _ = self.child.wait();
        return;
    }

    // The child handle is the sudo wrapper; the actual LS must be signalled
    // by PID.
    match self.ls_pid {
        Some(pid) => {
            info!(pid, "Killing LS process via sudo -u {}", LS_USER);
            // Run kill AS the antigravity-ls user (same UID can signal)
            let term_ok = std::process::Command::new("sudo")
                .args(["-n", "-u", LS_USER, "kill", "-TERM", &pid.to_string()])
                .stdout(Stdio::null())
                .stderr(Stdio::null())
                .status()
                .map(|s| s.success())
                .unwrap_or(false);
            if term_ok {
                std::thread::sleep(std::time::Duration::from_millis(500));
                // Force kill if still alive
                let _ = std::process::Command::new("sudo")
                    .args(["-n", "-u", LS_USER, "kill", "-KILL", &pid.to_string()])
                    .stdout(Stdio::null())
                    .stderr(Stdio::null())
                    .status();
            } else {
                // Fallback: try with root sudo, then cleanup
                info!("sudo -u kill failed, trying fallback cleanup");
                cleanup_orphaned_ls();
            }
        }
        // No PID recorded, try blanket cleanup
        None => cleanup_orphaned_ls(),
    }

    // Reap the sudo wrapper. An exited child that is never waited on remains
    // a zombie for as long as this process lives. `try_wait` never blocks, so
    // this stays best-effort if sudo has not exited yet.
    let _ = self.child.try_wait();
}
}
/// Terminates the LS process when the handle goes out of scope.
impl Drop for StandaloneLS {
fn drop(&mut self) {
// `kill` returns immediately if it has already run (see the `killed`
// flag), so an explicit `kill()` before the drop is harmless.
self.kill();
}
}

330
src/standalone/stub.rs Normal file
View File

@@ -0,0 +1,330 @@
//! Stub extension server — handles LS connections in headless mode.
use crate::proto::wire::{encode_proto_string, encode_varint, extract_proto_string};
use std::io::{BufRead, BufReader, Read, Write};
/// Serve one connection from the LS to the stub extension server.
///
/// The LS uses Connect RPC (HTTP/1.1, ServerStream) to call
/// ExtensionServerService methods. ALL methods are ServerStream — responses
/// use Connect streaming envelope framing:
///   [0x00 | len(4) | protobuf_data]...   (0+ data messages)
///   [0x02 | len(4) | json_trailer]       (exactly 1 end-of-stream)
///
/// This function reads a single HTTP request (request line, headers, and a
/// body sized by `Content-Length`) and dispatches on the request path. Any
/// read failure or premature EOF silently drops the connection.
///
/// IMPORTANT: `SubscribeToUnifiedStateSyncTopic` is a long-lived stream.
/// If we immediately close it, the LS reconnects in a tight loop and never
/// proceeds to fetch OAuth tokens, so subscription connections are kept OPEN
/// by the streaming handler; everything else goes to the short-lived handler.
pub fn stub_handle_connection(
    conn: std::net::TcpStream,
    oauth_token: &str,
    oauth_topic_bytes: &Option<Vec<u8>>,
) {
    // Separate handles for reading (buffered) and writing.
    let read_half = match conn.try_clone() {
        Ok(c) => c,
        Err(_) => return,
    };
    let mut reader = BufReader::new(read_half);
    let mut writer = conn;

    // Request line: "<METHOD> <PATH> <VERSION>".
    let mut request_line = String::new();
    if !matches!(reader.read_line(&mut request_line), Ok(n) if n > 0) {
        return;
    }
    let path = request_line
        .split_whitespace()
        .nth(1)
        .unwrap_or("/unknown")
        .to_string();

    // Headers: only Content-Length matters; stop at the blank line.
    let mut content_len: usize = 0;
    let mut line = String::new();
    loop {
        line.clear();
        match reader.read_line(&mut line) {
            Ok(n) if n > 0 => {}
            _ => return,
        }
        if line.trim().is_empty() {
            break;
        }
        if line.to_lowercase().starts_with("content-length:") {
            // An unparsable value is treated as "no body".
            content_len = line
                .split(':')
                .nth(1)
                .and_then(|v| v.trim().parse().ok())
                .unwrap_or(0);
        }
    }

    // Body: exactly `content_len` bytes, if any.
    let mut body = vec![0u8; content_len];
    if content_len > 0 && reader.read_exact(&mut body).is_err() {
        return;
    }

    if path.contains("SubscribeToUnifiedStateSyncTopic") {
        // Long-lived stream: the LS subscribes once and expects updates
        // (OAuth, settings) over it. Closing immediately makes the LS
        // reconnect in a tight loop (~30/sec).
        handle_subscribe_stream(&mut writer, &body, &path, oauth_token, oauth_topic_bytes);
    } else {
        // Short-lived methods (everything else).
        handle_short_lived(&mut writer, &body, &path, oauth_token);
    }
}
/// Serve the long-lived `SubscribeToUnifiedStateSyncTopic` stream.
///
/// Steps:
/// 1. Parse the topic name (string field 1) out of the request body, skipping
///    the 5-byte Connect envelope header when the body is longer than that.
/// 2. Send HTTP response headers (chunked transfer encoding) followed by one
///    data envelope holding `UnifiedStateSyncUpdate { initial_state }`.
/// 3. Send an empty `AppliedUpdate` envelope every 5 seconds as keepalive
///    until a write fails, i.e. until the peer goes away.
///
/// This call blocks its thread for the lifetime of the connection.
fn handle_subscribe_stream(
    writer: &mut std::net::TcpStream,
    body: &[u8],
    path: &str,
    oauth_token: &str,
    oauth_topic_bytes: &Option<Vec<u8>>,
) {
    /// Wrap protobuf bytes in a Connect data envelope:
    /// flag 0x00, big-endian u32 length, payload.
    fn data_envelope(proto: &[u8]) -> Vec<u8> {
        let mut env = Vec::with_capacity(5 + proto.len());
        env.push(0x00u8);
        env.extend_from_slice(&(proto.len() as u32).to_be_bytes());
        env.extend_from_slice(proto);
        env
    }

    /// Write one HTTP chunk (hex size line, data, CRLF) and flush.
    /// Returns `false` as soon as any step fails.
    fn write_chunk(w: &mut std::net::TcpStream, data: &[u8]) -> bool {
        let size_line = format!("{:x}\r\n", data.len());
        w.write_all(size_line.as_bytes()).is_ok()
            && w.write_all(data).is_ok()
            && w.write_all(b"\r\n").is_ok()
            && w.flush().is_ok()
    }

    const RESPONSE_HEAD: &str = "HTTP/1.1 200 OK\r\n\
        Content-Type: application/connect+proto\r\n\
        Transfer-Encoding: chunked\r\n\
        \r\n";

    // Connect envelope: [flag(1)] [len(4)] [proto(N)]
    let proto_body = match body.get(5..) {
        Some(rest) if !rest.is_empty() => rest,
        _ => body,
    };

    // SubscribeToUnifiedStateSyncTopicRequest { string topic = 1; }
    // Minimal decoder: single-byte tags and lengths, length-delimited fields
    // only; anything else ends the scan.
    let mut topic_name = String::new();
    let mut rest = proto_body;
    while let Some((&tag, after_tag)) = rest.split_first() {
        if tag & 0x07 != 2 {
            break;
        }
        let (len, payload) = match after_tag.split_first() {
            Some((&len_byte, payload)) => (len_byte as usize, payload),
            None => break,
        };
        if len > payload.len() {
            break;
        }
        let (value, tail) = payload.split_at(len);
        if tag >> 3 == 1 {
            topic_name = String::from_utf8_lossy(value).to_string();
        }
        rest = tail;
    }
    eprintln!("[stub-ext] STREAM → {path} topic={topic_name:?}");

    let initial_state_bytes = build_initial_state(&topic_name, oauth_token, oauth_topic_bytes);

    // UnifiedStateSyncUpdate { initial_state = initial_state_bytes }
    let mut initial_proto = vec![0x0A]; // field 1 (initial_state), LEN
    encode_varint(&mut initial_proto, initial_state_bytes.len() as u64);
    initial_proto.extend_from_slice(&initial_state_bytes);

    if writer.write_all(RESPONSE_HEAD.as_bytes()).is_err() {
        return;
    }
    if !write_chunk(writer, &data_envelope(&initial_proto)) {
        return;
    }
    eprintln!(
        "[stub-ext] STREAM → sent initial_state ({} bytes)",
        initial_state_bytes.len()
    );

    // Keep the stream alive with periodic valid messages.
    // The LS has a ~10s read timeout on streams. After the initial_state,
    // the LS only accepts AppliedUpdate (field 2 in the oneof), so an empty
    // AppliedUpdate {} is sent every 5s.
    let keepalive_env = data_envelope(&[0x12, 0x00]); // field 2 (applied_update), LEN=0
    loop {
        std::thread::sleep(std::time::Duration::from_secs(5));
        if !write_chunk(writer, &keepalive_env) {
            return;
        }
    }
}
/// Build the serialized `Topic` bytes sent as `initial_state` for a USS topic subscription.
///
/// Only the `"uss-oauth"` topic produces data; every other topic yields an empty vector.
/// For `"uss-oauth"` the sources are tried in this order:
///
/// 1. `oauth_topic_bytes`, when present, is returned as-is.
/// 2. Otherwise, when `oauth_token` is non-empty, a Topic is hand-encoded containing a single
///    map entry keyed `"oauthTokenInfoSentinelKey"` whose Row value is the base64 encoding of
///    an OAuthTokenInfo message (access token, `"Bearer"` token type, far-future expiry).
/// 3. Otherwise the result is empty.
fn build_initial_state(
    topic_name: &str,
    oauth_token: &str,
    oauth_topic_bytes: &Option<Vec<u8>>,
) -> Vec<u8> {
    use base64::Engine;

    // Append one length-delimited field: tag byte, varint payload length, payload bytes.
    fn put_len_field(buf: &mut Vec<u8>, tag: u8, payload: &[u8]) {
        buf.push(tag);
        encode_varint(buf, payload.len() as u64);
        buf.extend_from_slice(payload);
    }

    if topic_name != "uss-oauth" {
        return Vec::new();
    }

    if let Some(captured) = oauth_topic_bytes {
        // Use the exact Topic proto from Antigravity's state.vscdb.
        let topic = captured.clone();
        eprintln!(
            "[stub-ext] using state.vscdb topic ({} bytes)",
            captured.len()
        );
        return topic;
    }

    if oauth_token.is_empty() {
        return Vec::new();
    }

    // Manual token fallback: hand-encode an OAuthTokenInfo with a far-future expiry.

    // Timestamp { seconds = 4_102_444_800 }, i.e. 2100-01-01T00:00:00Z.
    let mut expiry = vec![0x08u8]; // field 1 (seconds), varint
    encode_varint(&mut expiry, 4_102_444_800u64);

    let mut token_info = Vec::new();
    put_len_field(&mut token_info, 0x0A, oauth_token.as_bytes()); // field 1 (access_token)
    put_len_field(&mut token_info, 0x12, b"Bearer"); // field 2 (token_type)
    put_len_field(&mut token_info, 0x22, &expiry); // field 4 (expiry)

    let encoded = base64::engine::general_purpose::STANDARD.encode(&token_info);

    // Row { value = encoded, e_tag = 1 }
    let mut row = Vec::new();
    put_len_field(&mut row, 0x0A, encoded.as_bytes()); // field 1 (value)
    row.extend_from_slice(&[0x10, 0x01]); // field 2 (e_tag), varint 1

    // Map entry { key = "oauthTokenInfoSentinelKey", value = row }
    let mut entry = Vec::new();
    put_len_field(&mut entry, 0x0A, b"oauthTokenInfoSentinelKey"); // field 1 (key)
    put_len_field(&mut entry, 0x12, &row); // field 2 (value = Row)

    // Topic { data = [entry] }
    let mut topic = Vec::new();
    put_len_field(&mut topic, 0x0A, &entry); // field 1 (data map)
    topic
}
/// Handle short-lived extension server methods.
///
/// Writes exactly one HTTP/1.1 response to `writer` with `Connection: close`. Write errors
/// are ignored (best effort).
///
/// * `GetSecretValue` — logs the requested key (request field 1). When `oauth_token` is
///   non-empty, a Connect data envelope carrying the token as field 1 is placed before the
///   end-of-stream envelope; otherwise only a warning is logged.
/// * `StoreSecretValue` — logs the key (field 1) and a preview of at most 32 bytes of the
///   value (field 2); nothing is persisted here.
/// * `PushUnifiedStateSyncUpdate` — answered with an empty `application/proto` body.
/// * Everything else — answered with just the end-of-stream envelope.
///
/// `path` is matched by substring, and the "200 OK" log line is suppressed for a few
/// high-frequency methods.
fn handle_short_lived(
    writer: &mut std::net::TcpStream,
    body: &[u8],
    path: &str,
    oauth_token: &str,
) {
    let is_noisy = path.contains("GetChromeDevtoolsMcpUrl")
        || path.contains("FetchMCPAuthToken")
        || path.contains("PushUnifiedStateSyncUpdate");
    if !is_noisy {
        eprintln!("[stub-ext] 200 OK → {path}");
    }

    // Build Connect streaming response body with proper envelope framing.
    let mut envelope = Vec::new();
    if path.contains("GetSecretValue") {
        // Parse request body to extract the key (protobuf: field 1 = key, string)
        let key = extract_proto_string(body, 1).unwrap_or_default();
        eprintln!("[stub-ext] ← GetSecretValue key={key:?}");
        if !oauth_token.is_empty() {
            // Build protobuf: GetSecretValueResponse { string value = 1 }
            let proto = encode_proto_string(1, oauth_token.as_bytes());
            eprintln!(
                "[stub-ext] → serving token ({} bytes) for key={key:?}",
                oauth_token.len()
            );
            // Data envelope: flag=0x00, length, data
            envelope.push(0x00u8);
            envelope.extend_from_slice(&(proto.len() as u32).to_be_bytes());
            envelope.extend_from_slice(&proto);
        } else {
            eprintln!("[stub-ext] ⚠ no OAuth token available for key={key:?}");
        }
    } else if path.contains("StoreSecretValue") {
        // Parse and log what the LS is storing (for debugging)
        let key = extract_proto_string(body, 1).unwrap_or_default();
        let value = extract_proto_string(body, 2).unwrap_or_default();
        let val_preview = if value.len() > 32 {
            // Slicing a string at a fixed byte offset panics when that offset falls inside
            // a multi-byte UTF-8 character, so back off to the nearest char boundary.
            // Offset 0 is always a boundary, which guarantees termination.
            let mut cut = 32;
            while !value.is_char_boundary(cut) {
                cut -= 1;
            }
            format!("{}...", &value[..cut])
        } else {
            value
        };
        eprintln!("[stub-ext] ← StoreSecretValue key={key:?} value={val_preview:?}");
    }

    if path.contains("PushUnifiedStateSyncUpdate") {
        // Unary proto — respond with empty PushUnifiedStateSyncUpdateResponse (0 bytes body)
        let header = "HTTP/1.1 200 OK\r\n\
                      Content-Type: application/proto\r\n\
                      Content-Length: 0\r\n\
                      Connection: close\r\n\
                      \r\n";
        let _ = writer.write_all(header.as_bytes());
        let _ = writer.flush();
        return;
    }

    // End-of-stream envelope: flag=0x02, length=2, data="{}"
    envelope.push(0x02u8);
    envelope.extend_from_slice(&2u32.to_be_bytes());
    envelope.extend_from_slice(b"{}");

    // Respond with 200 OK + Connection: close (one request per connection)
    let header = format!(
        "HTTP/1.1 200 OK\r\n\
         Content-Type: application/connect+proto\r\n\
         Content-Length: {}\r\n\
         Connection: close\r\n\
         \r\n",
        envelope.len()
    );
    let _ = writer.write_all(header.as_bytes());
    let _ = writer.write_all(&envelope);
    let _ = writer.flush();
}