refactor: decompose large functions and remove dead code
- Decompose modify_request() into 7 single-responsibility helpers - Decompose handle_http_over_tls(): extract read_full_request, dispatch_stream_events - Promote connect_upstream/resolve_upstream to module-level functions - Split standalone.rs (1238 lines) into 4 submodules: standalone/mod.rs, spawn.rs, discovery.rs, stub.rs - Extract proto wire primitives into proto/wire.rs - Remove 6 dead MitmStore methods - Remove dead SessionResult, DEFAULT_SESSION, get_or_create - Remove dead decode_varint_at, extract_conversation_id - Clean all unused imports across 10 files - Suppress structural dead_code warnings on deserialization fields Warnings: 20 -> 0. All 43 tests pass.
This commit is contained in:
@@ -18,15 +18,9 @@ use super::util::{err_response, now_unix, upstream_err_response};
|
||||
use super::AppState;
|
||||
use crate::mitm::store::{CapturedFunctionCall, PendingToolResult, ToolRound};
|
||||
|
||||
/// Extract a conversation/session ID from a flexible JSON value.
|
||||
/// Accepts a plain string or an object with an "id" field.
|
||||
fn extract_conversation_id(conv: &Option<serde_json::Value>) -> Option<String> {
|
||||
match conv {
|
||||
Some(serde_json::Value::String(s)) => Some(s.clone()),
|
||||
Some(obj) => obj["id"].as_str().map(|s| s.to_string()),
|
||||
None => None,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
/// System fingerprint for completions responses (derived from crate version at compile time).
|
||||
fn system_fingerprint() -> String {
|
||||
@@ -187,10 +181,7 @@ pub(crate) async fn handle_completions(
|
||||
model_name, body.stream
|
||||
);
|
||||
|
||||
// Diagnostic: dump OpenCode's raw request
|
||||
if let Ok(pretty) = serde_json::to_string_pretty(&body) {
|
||||
let _ = std::fs::write("/tmp/opencode-request.json", &pretty);
|
||||
}
|
||||
|
||||
|
||||
let model = match lookup_model(model_name) {
|
||||
Some(m) => m,
|
||||
@@ -204,35 +195,28 @@ pub(crate) async fn handle_completions(
|
||||
}
|
||||
};
|
||||
|
||||
// Store client tools from this request (or clear stale ones from other endpoints)
|
||||
if let Some(ref tools) = body.tools {
|
||||
let gemini_tools = crate::mitm::modify::openai_tools_to_gemini(tools);
|
||||
if !gemini_tools.is_empty() {
|
||||
state.mitm_store.set_tools(gemini_tools).await;
|
||||
if let Some(ref choice) = body.tool_choice {
|
||||
let gemini_config = crate::mitm::modify::openai_tool_choice_to_gemini(choice);
|
||||
state.mitm_store.set_tool_config(gemini_config).await;
|
||||
}
|
||||
info!(
|
||||
count = tools.len(),
|
||||
"Completions: stored client tools for MITM injection"
|
||||
);
|
||||
} else {
|
||||
state.mitm_store.clear_tools().await;
|
||||
// ── Build per-request state locally ──────────────────────────────────
|
||||
|
||||
// Convert OpenAI tools to Gemini format
|
||||
let tools = body.tools.as_ref().and_then(|t| {
|
||||
let gemini_tools = crate::mitm::modify::openai_tools_to_gemini(t);
|
||||
if gemini_tools.is_empty() { None } else {
|
||||
info!(count = t.len(), "Completions: client tools for MITM injection");
|
||||
Some(gemini_tools)
|
||||
}
|
||||
} else {
|
||||
state.mitm_store.clear_tools().await;
|
||||
}
|
||||
});
|
||||
let tool_config = body.tools.as_ref().and_then(|_| {
|
||||
body.tool_choice.as_ref().map(|choice| {
|
||||
crate::mitm::modify::openai_tool_choice_to_gemini(choice)
|
||||
})
|
||||
});
|
||||
|
||||
// ── Extract tool results from messages for MITM injection ──────────
|
||||
// When OpenCode sends back tool results, the messages array contains:
|
||||
// 1. assistant message with tool_calls (the model's previous function calls)
|
||||
// 2. tool messages with results (the executed tool outputs)
|
||||
// We build ToolRounds: each round pairs one assistant's tool_calls with
|
||||
// the subsequent tool result messages. This enables correct per-turn
|
||||
// history rewriting for multi-step tool use.
|
||||
// Build ToolRounds from message history: each round pairs assistant tool_calls
|
||||
// with subsequent tool result messages. Local call_id_to_name mapping.
|
||||
let mut tool_rounds: Vec<ToolRound> = Vec::new();
|
||||
let mut call_id_to_name: std::collections::HashMap<String, String> = std::collections::HashMap::new();
|
||||
{
|
||||
let mut rounds: Vec<ToolRound> = Vec::new();
|
||||
let mut current_round: Option<ToolRound> = None;
|
||||
|
||||
for msg in &body.messages {
|
||||
@@ -241,7 +225,7 @@ pub(crate) async fn handle_completions(
|
||||
// Finalize any open round
|
||||
if let Some(round) = current_round.take() {
|
||||
if !round.calls.is_empty() {
|
||||
rounds.push(round);
|
||||
tool_rounds.push(round);
|
||||
}
|
||||
}
|
||||
// Start new round if this assistant has tool_calls
|
||||
@@ -255,14 +239,15 @@ pub(crate) async fn handle_completions(
|
||||
.unwrap_or(serde_json::json!({}));
|
||||
let call_id = tc["id"].as_str().unwrap_or("").to_string();
|
||||
|
||||
// Register call_id → name for lookup
|
||||
// Register call_id → name locally
|
||||
if !call_id.is_empty() {
|
||||
state.mitm_store.register_call_id(call_id, name.clone()).await;
|
||||
call_id_to_name.insert(call_id, name.clone());
|
||||
}
|
||||
|
||||
calls.push(CapturedFunctionCall {
|
||||
name,
|
||||
args,
|
||||
thought_signature: None,
|
||||
captured_at: std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.unwrap_or_default()
|
||||
@@ -281,16 +266,13 @@ pub(crate) async fn handle_completions(
|
||||
"tool" => {
|
||||
let text = extract_message_text(&msg.content);
|
||||
if let Some(ref call_id) = msg.tool_call_id {
|
||||
// Look up function name from call_id, fall back to
|
||||
// positional index within the current round's calls
|
||||
let result_index = current_round
|
||||
.as_ref()
|
||||
.map(|r| r.results.len())
|
||||
.unwrap_or(0);
|
||||
let name = state
|
||||
.mitm_store
|
||||
.lookup_call_id(call_id)
|
||||
.await
|
||||
let name = call_id_to_name
|
||||
.get(call_id.as_str())
|
||||
.cloned()
|
||||
.unwrap_or_else(|| {
|
||||
current_round
|
||||
.as_ref()
|
||||
@@ -314,7 +296,7 @@ pub(crate) async fn handle_completions(
|
||||
// Any other role (user, system) finalizes the current round
|
||||
if let Some(round) = current_round.take() {
|
||||
if !round.calls.is_empty() {
|
||||
rounds.push(round);
|
||||
tool_rounds.push(round);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -323,69 +305,86 @@ pub(crate) async fn handle_completions(
|
||||
// Finalize last round
|
||||
if let Some(round) = current_round.take() {
|
||||
if !round.calls.is_empty() {
|
||||
rounds.push(round);
|
||||
tool_rounds.push(round);
|
||||
}
|
||||
}
|
||||
|
||||
if !rounds.is_empty() {
|
||||
if !tool_rounds.is_empty() {
|
||||
info!(
|
||||
round_count = rounds.len(),
|
||||
calls = ?rounds.iter().map(|r| r.calls.iter().map(|c| &c.name).collect::<Vec<_>>()).collect::<Vec<_>>(),
|
||||
"Completions: stored {} tool round(s) for MITM history rewrite",
|
||||
rounds.len(),
|
||||
round_count = tool_rounds.len(),
|
||||
calls = ?tool_rounds.iter().map(|r| r.calls.iter().map(|c| &c.name).collect::<Vec<_>>()).collect::<Vec<_>>(),
|
||||
"Completions: {} tool round(s) for MITM history rewrite",
|
||||
tool_rounds.len(),
|
||||
);
|
||||
// Also set last_function_calls from the latest round for proxy.rs recording compat
|
||||
if let Some(last_round) = rounds.last() {
|
||||
state.mitm_store.set_last_function_calls(last_round.calls.clone()).await;
|
||||
|
||||
// Merge thought_signatures from MITM-captured function calls.
|
||||
// OpenAI format doesn't carry thought signatures, but Google requires
|
||||
// them when injecting functionCall parts back into history.
|
||||
let sigs = state.mitm_store.peek_thought_signatures().await;
|
||||
if !sigs.is_empty() {
|
||||
let mut merged = 0usize;
|
||||
for round in &mut tool_rounds {
|
||||
for fc in &mut round.calls {
|
||||
if fc.thought_signature.is_none() {
|
||||
if let Some(sig) = sigs.get(&fc.name) {
|
||||
fc.thought_signature = Some(sig.clone());
|
||||
merged += 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if merged > 0 {
|
||||
info!(
|
||||
merged_count = merged,
|
||||
"Completions: merged {} thought_signature(s) from MITM capture",
|
||||
merged,
|
||||
);
|
||||
}
|
||||
}
|
||||
state.mitm_store.set_tool_rounds(rounds).await;
|
||||
}
|
||||
}
|
||||
|
||||
// Store generation parameters for MITM injection
|
||||
// Build generation parameters locally
|
||||
use crate::mitm::store::GenerationParams;
|
||||
let (response_mime_type, response_schema) = match body.response_format.as_ref() {
|
||||
Some(rf) => match rf.format_type.as_str() {
|
||||
"json_object" | "json" => (Some("application/json".to_string()), None),
|
||||
"json_schema" => {
|
||||
let schema = rf.json_schema.as_ref().and_then(|js| js.schema.clone());
|
||||
(Some("application/json".to_string()), schema)
|
||||
}
|
||||
_ => (None, None),
|
||||
},
|
||||
None => (None, None),
|
||||
};
|
||||
let gp = GenerationParams {
|
||||
temperature: body.temperature,
|
||||
top_p: body.top_p,
|
||||
top_k: None,
|
||||
max_output_tokens: body.max_tokens.or(body.max_completion_tokens),
|
||||
stop_sequences: body.stop.clone().map(|s| s.into_vec()),
|
||||
frequency_penalty: body.frequency_penalty,
|
||||
presence_penalty: body.presence_penalty,
|
||||
reasoning_effort: body.reasoning_effort.clone(),
|
||||
response_mime_type,
|
||||
response_schema,
|
||||
google_search: body.web_search,
|
||||
};
|
||||
let generation_params = if gp.temperature.is_some()
|
||||
|| gp.top_p.is_some()
|
||||
|| gp.max_output_tokens.is_some()
|
||||
|| gp.frequency_penalty.is_some()
|
||||
|| gp.presence_penalty.is_some()
|
||||
|| gp.reasoning_effort.is_some()
|
||||
|| gp.stop_sequences.is_some()
|
||||
|| gp.response_mime_type.is_some()
|
||||
|| gp.response_schema.is_some()
|
||||
|| gp.google_search
|
||||
{
|
||||
use crate::mitm::store::GenerationParams;
|
||||
let (response_mime_type, response_schema) = match body.response_format.as_ref() {
|
||||
Some(rf) => match rf.format_type.as_str() {
|
||||
"json_object" | "json" => (Some("application/json".to_string()), None),
|
||||
"json_schema" => {
|
||||
let schema = rf.json_schema.as_ref().and_then(|js| js.schema.clone());
|
||||
(Some("application/json".to_string()), schema)
|
||||
}
|
||||
_ => (None, None),
|
||||
},
|
||||
None => (None, None),
|
||||
};
|
||||
let gp = GenerationParams {
|
||||
temperature: body.temperature,
|
||||
top_p: body.top_p,
|
||||
top_k: None, // OpenAI doesn't have top_k
|
||||
max_output_tokens: body.max_tokens.or(body.max_completion_tokens),
|
||||
stop_sequences: body.stop.clone().map(|s| s.into_vec()),
|
||||
frequency_penalty: body.frequency_penalty,
|
||||
presence_penalty: body.presence_penalty,
|
||||
reasoning_effort: body.reasoning_effort.clone(),
|
||||
response_mime_type,
|
||||
response_schema,
|
||||
google_search: body.web_search,
|
||||
};
|
||||
// Only store if at least one param is set
|
||||
if gp.temperature.is_some()
|
||||
|| gp.top_p.is_some()
|
||||
|| gp.max_output_tokens.is_some()
|
||||
|| gp.frequency_penalty.is_some()
|
||||
|| gp.presence_penalty.is_some()
|
||||
|| gp.reasoning_effort.is_some()
|
||||
|| gp.stop_sequences.is_some()
|
||||
|| gp.response_mime_type.is_some()
|
||||
|| gp.response_schema.is_some()
|
||||
|| gp.google_search
|
||||
{
|
||||
state.mitm_store.set_generation_params(gp).await;
|
||||
} else {
|
||||
state.mitm_store.clear_generation_params().await;
|
||||
}
|
||||
}
|
||||
Some(gp)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let token = state.backend.oauth_token().await;
|
||||
if token.is_empty() {
|
||||
@@ -410,23 +409,8 @@ pub(crate) async fn handle_completions(
|
||||
warn!("n={n} requested with streaming — streaming only supports n=1, ignoring n");
|
||||
}
|
||||
|
||||
// Session/conversation: reuse cascade if conversation ID provided
|
||||
let session_id_str = extract_conversation_id(&body.conversation);
|
||||
|
||||
// Helper to create a cascade (reuses session or creates fresh)
|
||||
let create_cascade = |state: Arc<AppState>, session_id: Option<String>| async move {
|
||||
if let Some(ref sid) = session_id {
|
||||
state
|
||||
.sessions
|
||||
.get_or_create(Some(sid), || state.backend.create_cascade())
|
||||
.await
|
||||
.map(|sr| sr.cascade_id)
|
||||
} else {
|
||||
state.backend.create_cascade().await
|
||||
}
|
||||
};
|
||||
|
||||
let cascade_id = match create_cascade(Arc::clone(&state), session_id_str.clone()).await {
|
||||
// Always create a new cascade for every request
|
||||
let cascade_id = match state.backend.create_cascade().await {
|
||||
Ok(cid) => cid,
|
||||
Err(e) => {
|
||||
return err_response(
|
||||
@@ -437,40 +421,54 @@ pub(crate) async fn handle_completions(
|
||||
}
|
||||
};
|
||||
|
||||
// Send message on primary cascade
|
||||
state.mitm_store.set_active_cascade(&cascade_id).await;
|
||||
// Store real user text for MITM injection — LS gets a dummy prompt
|
||||
state.mitm_store.set_pending_user_text(user_text.clone()).await;
|
||||
// Store image for MITM injection (LS doesn't forward images to Google API)
|
||||
if let Some(ref img) = image {
|
||||
// Image for MITM injection
|
||||
let pending_image = image.as_ref().map(|img| {
|
||||
use base64::Engine;
|
||||
state
|
||||
.mitm_store
|
||||
.set_pending_image(crate::mitm::store::PendingImage {
|
||||
base64_data: base64::engine::general_purpose::STANDARD.encode(&img.data),
|
||||
mime_type: img.mime_type.clone(),
|
||||
})
|
||||
.await;
|
||||
}
|
||||
crate::mitm::store::PendingImage {
|
||||
base64_data: base64::engine::general_purpose::STANDARD.encode(&img.data),
|
||||
mime_type: img.mime_type.clone(),
|
||||
}
|
||||
});
|
||||
|
||||
// Pre-flight: install channel BEFORE send_message so the MITM proxy
|
||||
// can grab it when the LS fires its API call.
|
||||
// Only for streaming — sync paths use poll_for_response (legacy store).
|
||||
let has_custom_tools = state.mitm_store.get_tools().await.is_some();
|
||||
let mitm_rx = if has_custom_tools && body.stream {
|
||||
state.mitm_store.clear_response_async().await;
|
||||
state.mitm_store.clear_upstream_error().await;
|
||||
let _ = state.mitm_store.take_any_function_calls().await;
|
||||
// Get last calls from the latest tool round (if any) for proxy recording compat
|
||||
let last_function_calls = tool_rounds.last()
|
||||
.map(|r| r.calls.clone())
|
||||
.unwrap_or_default();
|
||||
|
||||
// Build event channel for streaming
|
||||
let has_custom_tools = tools.is_some();
|
||||
let (mitm_rx, event_tx) = if has_custom_tools && body.stream {
|
||||
let (tx, rx) = tokio::sync::mpsc::channel(64);
|
||||
state.mitm_store.set_channel(tx).await;
|
||||
Some(rx)
|
||||
(Some(rx), Some(tx))
|
||||
} else {
|
||||
None
|
||||
(None, None)
|
||||
};
|
||||
|
||||
// Build pending tool results from latest round
|
||||
let pending_tool_results = tool_rounds.last()
|
||||
.map(|r| r.results.clone())
|
||||
.unwrap_or_default();
|
||||
|
||||
// Register all per-request state atomically
|
||||
state.mitm_store.register_request(crate::mitm::store::RequestContext {
|
||||
cascade_id: cascade_id.clone(),
|
||||
pending_user_text: user_text.clone(),
|
||||
event_channel: event_tx,
|
||||
generation_params,
|
||||
pending_image,
|
||||
tools,
|
||||
tool_config,
|
||||
pending_tool_results,
|
||||
tool_rounds,
|
||||
last_function_calls,
|
||||
call_id_to_name,
|
||||
created_at: std::time::Instant::now(),
|
||||
}).await;
|
||||
|
||||
// Send REAL user text to LS
|
||||
match state
|
||||
.backend
|
||||
.send_message_with_image(&cascade_id, ".", model.model_enum, image.as_ref())
|
||||
.send_message_with_image(&cascade_id, &format!(".<cid:{}>", cascade_id), model.model_enum, image.as_ref())
|
||||
.await
|
||||
{
|
||||
Ok((200, _)) => {
|
||||
@@ -481,7 +479,7 @@ pub(crate) async fn handle_completions(
|
||||
});
|
||||
}
|
||||
Ok((status, _)) => {
|
||||
state.mitm_store.drop_channel().await;
|
||||
state.mitm_store.remove_request(&cascade_id).await;
|
||||
return err_response(
|
||||
StatusCode::BAD_GATEWAY,
|
||||
format!("Backend returned {status}"),
|
||||
@@ -489,7 +487,7 @@ pub(crate) async fn handle_completions(
|
||||
);
|
||||
}
|
||||
Err(e) => {
|
||||
state.mitm_store.drop_channel().await;
|
||||
state.mitm_store.remove_request(&cascade_id).await;
|
||||
return err_response(
|
||||
StatusCode::BAD_GATEWAY,
|
||||
format!("Send failed: {e}"),
|
||||
@@ -537,7 +535,7 @@ pub(crate) async fn handle_completions(
|
||||
// Send the same message on each extra cascade
|
||||
match state
|
||||
.backend
|
||||
.send_message_with_image(&cid, ".", model.model_enum, image.as_ref())
|
||||
.send_message_with_image(&cid, &format!(".<cid:{}>", cid), model.model_enum, image.as_ref())
|
||||
.await
|
||||
{
|
||||
Ok((200, _)) => {
|
||||
@@ -775,7 +773,7 @@ async fn chat_completions_stream(
|
||||
)));
|
||||
}
|
||||
yield Ok(Event::default().data("[DONE]"));
|
||||
state.mitm_store.drop_channel().await;
|
||||
state.mitm_store.remove_request(&cascade_id).await;
|
||||
return;
|
||||
}
|
||||
MitmEvent::ResponseComplete => {
|
||||
@@ -803,15 +801,15 @@ async fn chat_completions_stream(
|
||||
)));
|
||||
}
|
||||
yield Ok(Event::default().data("[DONE]"));
|
||||
state.mitm_store.drop_channel().await;
|
||||
state.mitm_store.remove_request(&cascade_id).await;
|
||||
return;
|
||||
} else if !acc_thinking.is_empty() && !did_unblock_ls {
|
||||
// Thinking-only response — LS needs follow-up API calls.
|
||||
// Create a new channel and unblock the gate.
|
||||
did_unblock_ls = true;
|
||||
let (new_tx, new_rx) = tokio::sync::mpsc::channel(64);
|
||||
state.mitm_store.set_channel(new_tx).await;
|
||||
state.mitm_store.clear_request_in_flight();
|
||||
state.mitm_store.set_channel(&cascade_id, new_tx).await;
|
||||
|
||||
let _ = state.mitm_store.take_any_function_calls().await;
|
||||
*rx = new_rx;
|
||||
debug!(
|
||||
@@ -845,7 +843,7 @@ async fn chat_completions_stream(
|
||||
)));
|
||||
}
|
||||
yield Ok(Event::default().data("[DONE]"));
|
||||
state.mitm_store.drop_channel().await;
|
||||
state.mitm_store.remove_request(&cascade_id).await;
|
||||
return;
|
||||
}
|
||||
// Don't break — wait for more channel events
|
||||
@@ -861,7 +859,7 @@ async fn chat_completions_stream(
|
||||
None,
|
||||
)));
|
||||
yield Ok(Event::default().data("[DONE]"));
|
||||
state.mitm_store.drop_channel().await;
|
||||
state.mitm_store.remove_request(&cascade_id).await;
|
||||
return;
|
||||
}
|
||||
continue 'channel_loop;
|
||||
@@ -878,7 +876,7 @@ async fn chat_completions_stream(
|
||||
}
|
||||
})).unwrap()));
|
||||
yield Ok(Event::default().data("[DONE]"));
|
||||
state.mitm_store.drop_channel().await;
|
||||
state.mitm_store.remove_request(&cascade_id).await;
|
||||
return;
|
||||
}
|
||||
MitmEvent::Usage(u) => {
|
||||
@@ -891,7 +889,7 @@ async fn chat_completions_stream(
|
||||
}
|
||||
|
||||
// Channel closed or timeout — clean up
|
||||
state.mitm_store.drop_channel().await;
|
||||
state.mitm_store.remove_request(&cascade_id).await;
|
||||
|
||||
// If we got here from timeout with content, emit what we have
|
||||
if !last_text.is_empty() || last_thinking_len > 0 {
|
||||
@@ -1026,7 +1024,7 @@ async fn chat_completions_stream(
|
||||
}
|
||||
})).unwrap()));
|
||||
// Always clear in-flight flag when stream ends
|
||||
state.mitm_store.drop_channel().await;
|
||||
state.mitm_store.remove_request(&cascade_id).await;
|
||||
yield Ok(Event::default().data("[DONE]"));
|
||||
};
|
||||
|
||||
|
||||
@@ -16,7 +16,7 @@ use axum::{
|
||||
};
|
||||
use rand::Rng;
|
||||
use std::sync::Arc;
|
||||
use tracing::{debug, info, warn};
|
||||
use tracing::{debug, info};
|
||||
|
||||
use super::models::{lookup_model, DEFAULT_MODEL, MODELS};
|
||||
use super::polling::{
|
||||
@@ -40,6 +40,7 @@ pub(crate) struct GeminiRequest {
|
||||
pub tool_config: Option<serde_json::Value>,
|
||||
/// Session/conversation ID.
|
||||
#[serde(default)]
|
||||
#[allow(dead_code)]
|
||||
pub conversation: Option<serde_json::Value>,
|
||||
#[serde(default = "default_timeout")]
|
||||
pub timeout: u64,
|
||||
@@ -81,17 +82,8 @@ pub(crate) struct GeminiRequest {
|
||||
pub response_schema: Option<serde_json::Value>,
|
||||
}
|
||||
|
||||
fn default_timeout() -> u64 {
|
||||
120
|
||||
}
|
||||
use super::util::default_timeout;
|
||||
|
||||
fn extract_conversation_id(conv: &Option<serde_json::Value>) -> Option<String> {
|
||||
match conv {
|
||||
Some(serde_json::Value::String(s)) => Some(s.clone()),
|
||||
Some(obj) => obj["id"].as_str().map(|s| s.to_string()),
|
||||
None => None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Build Gemini-format usageMetadata from MITM store.
|
||||
async fn build_usage_metadata(
|
||||
@@ -247,157 +239,127 @@ async fn handle_gemini_inner(
|
||||
}
|
||||
};
|
||||
|
||||
// Store tools directly in Gemini format (no conversion needed!)
|
||||
if let Some(ref tools) = body.tools {
|
||||
if !tools.is_empty() {
|
||||
state.mitm_store.set_tools(tools.clone()).await;
|
||||
info!(
|
||||
count = tools.len(),
|
||||
"Stored Gemini-native tools for MITM injection"
|
||||
);
|
||||
} else {
|
||||
state.mitm_store.clear_tools().await;
|
||||
}
|
||||
} else {
|
||||
state.mitm_store.clear_tools().await;
|
||||
}
|
||||
if let Some(ref config) = body.tool_config {
|
||||
state.mitm_store.set_tool_config(config.clone()).await;
|
||||
}
|
||||
// ── Build per-request state locally ──────────────────────────────────
|
||||
|
||||
// Handle tool results (Gemini format: functionResponse)
|
||||
// Tools (already in Gemini format)
|
||||
let tools = body.tools.as_ref().and_then(|t| {
|
||||
if t.is_empty() { None } else {
|
||||
info!(count = t.len(), "Gemini-native tools for MITM injection");
|
||||
Some(t.clone())
|
||||
}
|
||||
});
|
||||
let tool_config = body.tool_config.clone();
|
||||
|
||||
// Tool results → collect (ToolRound built after cascade_id is known)
|
||||
let mut pending_tool_results: Vec<PendingToolResult> = Vec::new();
|
||||
if let Some(ref results) = body.tool_results {
|
||||
let mut pending: Vec<PendingToolResult> = Vec::new();
|
||||
for r in results {
|
||||
if let Some(fr) = r.get("functionResponse") {
|
||||
let name = fr["name"].as_str().unwrap_or("unknown").to_string();
|
||||
let response = fr.get("response").cloned().unwrap_or(serde_json::json!({}));
|
||||
// Legacy compat
|
||||
state
|
||||
.mitm_store
|
||||
.add_tool_result(PendingToolResult {
|
||||
name: name.clone(),
|
||||
result: response.clone(),
|
||||
})
|
||||
.await;
|
||||
pending.push(PendingToolResult {
|
||||
pending_tool_results.push(PendingToolResult {
|
||||
name,
|
||||
result: response,
|
||||
});
|
||||
}
|
||||
}
|
||||
if !pending.is_empty() {
|
||||
// Build a ToolRound from captured function calls + client results.
|
||||
// Accumulate with existing rounds for multi-round history rewriting.
|
||||
let last_calls = state.mitm_store.get_last_function_calls().await;
|
||||
let mut rounds = state.mitm_store.take_tool_rounds().await;
|
||||
rounds.push(crate::mitm::store::ToolRound {
|
||||
calls: last_calls,
|
||||
results: pending,
|
||||
});
|
||||
state.mitm_store.set_tool_rounds(rounds).await;
|
||||
}
|
||||
info!(
|
||||
count = results.len(),
|
||||
"Stored Gemini-native tool results for MITM injection (built tool round)"
|
||||
"Gemini-native tool results (will build tool round after cascade_id)"
|
||||
);
|
||||
}
|
||||
|
||||
// Store generation parameters for MITM injection
|
||||
{
|
||||
use crate::mitm::store::GenerationParams;
|
||||
let gp = GenerationParams {
|
||||
temperature: body.temperature,
|
||||
top_p: body.top_p,
|
||||
top_k: body.top_k,
|
||||
max_output_tokens: body.max_output_tokens,
|
||||
stop_sequences: body.stop_sequences.clone(),
|
||||
frequency_penalty: None,
|
||||
presence_penalty: None,
|
||||
reasoning_effort: body.thinking_level.clone(),
|
||||
response_mime_type: body.response_mime_type.clone(),
|
||||
response_schema: body.response_schema.clone(),
|
||||
google_search: body.google_search,
|
||||
};
|
||||
if gp.temperature.is_some()
|
||||
|| gp.top_p.is_some()
|
||||
|| gp.top_k.is_some()
|
||||
|| gp.max_output_tokens.is_some()
|
||||
|| gp.stop_sequences.is_some()
|
||||
|| gp.reasoning_effort.is_some()
|
||||
|| gp.response_mime_type.is_some()
|
||||
|| gp.response_schema.is_some()
|
||||
|| gp.google_search
|
||||
{
|
||||
state.mitm_store.set_generation_params(gp).await;
|
||||
} else {
|
||||
state.mitm_store.clear_generation_params().await;
|
||||
}
|
||||
}
|
||||
|
||||
// Session/conversation management
|
||||
let session_id_str = extract_conversation_id(&body.conversation);
|
||||
let cascade_id = if let Some(ref sid) = session_id_str {
|
||||
match state
|
||||
.sessions
|
||||
.get_or_create(Some(sid), || state.backend.create_cascade())
|
||||
.await
|
||||
{
|
||||
Ok(sr) => sr.cascade_id,
|
||||
Err(e) => {
|
||||
return err_response(
|
||||
StatusCode::BAD_GATEWAY,
|
||||
format!("StartCascade failed: {e}"),
|
||||
"server_error",
|
||||
);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
match state.backend.create_cascade().await {
|
||||
Ok(cid) => cid,
|
||||
Err(e) => {
|
||||
return err_response(
|
||||
StatusCode::BAD_GATEWAY,
|
||||
format!("StartCascade failed: {e}"),
|
||||
"server_error",
|
||||
);
|
||||
}
|
||||
}
|
||||
// Generation parameters
|
||||
use crate::mitm::store::GenerationParams;
|
||||
let gp = GenerationParams {
|
||||
temperature: body.temperature,
|
||||
top_p: body.top_p,
|
||||
top_k: body.top_k,
|
||||
max_output_tokens: body.max_output_tokens,
|
||||
stop_sequences: body.stop_sequences.clone(),
|
||||
frequency_penalty: None,
|
||||
presence_penalty: None,
|
||||
reasoning_effort: body.thinking_level.clone(),
|
||||
response_mime_type: body.response_mime_type.clone(),
|
||||
response_schema: body.response_schema.clone(),
|
||||
google_search: body.google_search,
|
||||
};
|
||||
|
||||
// Send message
|
||||
state.mitm_store.set_active_cascade(&cascade_id).await;
|
||||
// Store real user text for MITM injection — LS gets a dummy prompt
|
||||
state.mitm_store.set_pending_user_text(user_text.clone()).await;
|
||||
// Store image for MITM injection (LS doesn't forward images to Google API)
|
||||
if let Some(ref img) = image {
|
||||
use base64::Engine;
|
||||
state
|
||||
.mitm_store
|
||||
.set_pending_image(crate::mitm::store::PendingImage {
|
||||
base64_data: base64::engine::general_purpose::STANDARD.encode(&img.data),
|
||||
mime_type: img.mime_type.clone(),
|
||||
})
|
||||
.await;
|
||||
}
|
||||
|
||||
// Pre-flight: install channel BEFORE send_message so the MITM proxy
|
||||
// can grab it when the LS fires its API call.
|
||||
let has_custom_tools = state.mitm_store.get_tools().await.is_some();
|
||||
let mitm_rx = if has_custom_tools {
|
||||
state.mitm_store.clear_response_async().await;
|
||||
state.mitm_store.clear_upstream_error().await;
|
||||
let _ = state.mitm_store.take_any_function_calls().await;
|
||||
let (tx, rx) = tokio::sync::mpsc::channel(64);
|
||||
state.mitm_store.set_channel(tx).await;
|
||||
Some(rx)
|
||||
let generation_params = if gp.temperature.is_some()
|
||||
|| gp.top_p.is_some()
|
||||
|| gp.top_k.is_some()
|
||||
|| gp.max_output_tokens.is_some()
|
||||
|| gp.stop_sequences.is_some()
|
||||
|| gp.reasoning_effort.is_some()
|
||||
|| gp.response_mime_type.is_some()
|
||||
|| gp.response_schema.is_some()
|
||||
|| gp.google_search
|
||||
{
|
||||
Some(gp)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
// Always create a new cascade for every request
|
||||
let cascade_id = match state.backend.create_cascade().await {
|
||||
Ok(cid) => cid,
|
||||
Err(e) => {
|
||||
return err_response(
|
||||
StatusCode::BAD_GATEWAY,
|
||||
format!("StartCascade failed: {e}"),
|
||||
"server_error",
|
||||
);
|
||||
}
|
||||
};
|
||||
|
||||
// Image for MITM injection
|
||||
let pending_image = image.as_ref().map(|img| {
|
||||
use base64::Engine;
|
||||
crate::mitm::store::PendingImage {
|
||||
base64_data: base64::engine::general_purpose::STANDARD.encode(&img.data),
|
||||
mime_type: img.mime_type.clone(),
|
||||
}
|
||||
});
|
||||
|
||||
// Build event channel for streaming
|
||||
let has_custom_tools = tools.is_some();
|
||||
let (mitm_rx, event_tx) = if has_custom_tools {
|
||||
let (tx, rx) = tokio::sync::mpsc::channel(64);
|
||||
(Some(rx), Some(tx))
|
||||
} else {
|
||||
(None, None)
|
||||
};
|
||||
|
||||
// Build tool rounds now that cascade_id is known
|
||||
let mut tool_rounds: Vec<crate::mitm::store::ToolRound> = Vec::new();
|
||||
if !pending_tool_results.is_empty() {
|
||||
let last_calls = state.mitm_store.take_function_calls(&cascade_id).await
|
||||
.unwrap_or_default();
|
||||
tool_rounds.push(crate::mitm::store::ToolRound {
|
||||
calls: last_calls,
|
||||
results: pending_tool_results.clone(),
|
||||
});
|
||||
}
|
||||
|
||||
// Register all per-request state atomically
|
||||
state.mitm_store.register_request(crate::mitm::store::RequestContext {
|
||||
cascade_id: cascade_id.clone(),
|
||||
pending_user_text: user_text.clone(),
|
||||
event_channel: event_tx,
|
||||
generation_params,
|
||||
pending_image,
|
||||
tools,
|
||||
tool_config,
|
||||
pending_tool_results,
|
||||
tool_rounds,
|
||||
last_function_calls: Vec::new(),
|
||||
call_id_to_name: std::collections::HashMap::new(),
|
||||
created_at: std::time::Instant::now(),
|
||||
}).await;
|
||||
|
||||
// Send REAL user text to LS (no more dummy ".")
|
||||
match state
|
||||
.backend
|
||||
.send_message_with_image(&cascade_id, ".", model.model_enum, image.as_ref())
|
||||
.send_message_with_image(&cascade_id, &format!(".<cid:{}>", cascade_id), model.model_enum, image.as_ref())
|
||||
.await
|
||||
{
|
||||
Ok((200, _)) => {
|
||||
@@ -408,7 +370,7 @@ async fn handle_gemini_inner(
|
||||
});
|
||||
}
|
||||
Ok((status, _)) => {
|
||||
state.mitm_store.drop_channel().await;
|
||||
state.mitm_store.remove_request(&cascade_id).await;
|
||||
return err_response(
|
||||
StatusCode::BAD_GATEWAY,
|
||||
format!("Antigravity returned {status}"),
|
||||
@@ -416,7 +378,7 @@ async fn handle_gemini_inner(
|
||||
);
|
||||
}
|
||||
Err(e) => {
|
||||
state.mitm_store.drop_channel().await;
|
||||
state.mitm_store.remove_request(&cascade_id).await;
|
||||
return err_response(
|
||||
StatusCode::BAD_GATEWAY,
|
||||
format!("Send message failed: {e}"),
|
||||
@@ -478,7 +440,7 @@ async fn gemini_sync(
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
state.mitm_store.drop_channel().await;
|
||||
state.mitm_store.remove_request(&cascade_id).await;
|
||||
return Json(serde_json::json!({
|
||||
"candidates": [{
|
||||
"content": {
|
||||
@@ -500,8 +462,8 @@ async fn gemini_sync(
|
||||
// Thinking-only — LS needs to make a follow-up request.
|
||||
// Reinstall channel and unblock gate.
|
||||
let (new_tx, new_rx) = tokio::sync::mpsc::channel(64);
|
||||
state.mitm_store.set_channel(new_tx).await;
|
||||
state.mitm_store.clear_request_in_flight();
|
||||
state.mitm_store.set_channel(&cascade_id, new_tx).await;
|
||||
|
||||
let _ = state.mitm_store.take_any_function_calls().await;
|
||||
rx = new_rx;
|
||||
debug!(
|
||||
@@ -515,7 +477,7 @@ async fn gemini_sync(
|
||||
parts.push(serde_json::json!({"text": t, "thought": true}));
|
||||
}
|
||||
parts.push(serde_json::json!({"text": acc_text}));
|
||||
state.mitm_store.drop_channel().await;
|
||||
state.mitm_store.remove_request(&cascade_id).await;
|
||||
return Json(serde_json::json!({
|
||||
"candidates": [{
|
||||
"content": {
|
||||
@@ -530,14 +492,14 @@ async fn gemini_sync(
|
||||
.into_response();
|
||||
}
|
||||
MitmEvent::UpstreamError(err) => {
|
||||
state.mitm_store.drop_channel().await;
|
||||
state.mitm_store.remove_request(&cascade_id).await;
|
||||
return upstream_err_response(&err);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Timeout
|
||||
state.mitm_store.drop_channel().await;
|
||||
state.mitm_store.remove_request(&cascade_id).await;
|
||||
return (
|
||||
axum::http::StatusCode::GATEWAY_TIMEOUT,
|
||||
Json(serde_json::json!({
|
||||
@@ -703,7 +665,7 @@ async fn gemini_stream(
|
||||
"modelVersion": model_name,
|
||||
})).unwrap_or_default()));
|
||||
yield Ok(Event::default().data("[DONE]"));
|
||||
state.mitm_store.drop_channel().await;
|
||||
state.mitm_store.remove_request(&cascade_id).await;
|
||||
return;
|
||||
}
|
||||
MitmEvent::ResponseComplete => {
|
||||
@@ -722,15 +684,15 @@ async fn gemini_stream(
|
||||
"modelVersion": model_name,
|
||||
})).unwrap_or_default()));
|
||||
yield Ok(Event::default().data("[DONE]"));
|
||||
state.mitm_store.drop_channel().await;
|
||||
state.mitm_store.remove_request(&cascade_id).await;
|
||||
return;
|
||||
} else if !last_thinking.is_empty() && !did_unblock_ls {
|
||||
// Thinking-only response — LS needs follow-up API calls.
|
||||
// Create a new channel and unblock the gate.
|
||||
did_unblock_ls = true;
|
||||
let (new_tx, new_rx) = tokio::sync::mpsc::channel(64);
|
||||
state.mitm_store.set_channel(new_tx).await;
|
||||
state.mitm_store.clear_request_in_flight();
|
||||
state.mitm_store.set_channel(&cascade_id, new_tx).await;
|
||||
|
||||
let _ = state.mitm_store.take_any_function_calls().await;
|
||||
rx = new_rx;
|
||||
debug!(
|
||||
@@ -752,7 +714,7 @@ async fn gemini_stream(
|
||||
}
|
||||
})).unwrap()));
|
||||
yield Ok(Event::default().data("[DONE]"));
|
||||
state.mitm_store.drop_channel().await;
|
||||
state.mitm_store.remove_request(&cascade_id).await;
|
||||
return;
|
||||
}
|
||||
MitmEvent::Usage(_) | MitmEvent::Grounding(_) => {}
|
||||
@@ -760,7 +722,7 @@ async fn gemini_stream(
|
||||
}
|
||||
|
||||
// Timeout or channel closed
|
||||
state.mitm_store.drop_channel().await;
|
||||
state.mitm_store.remove_request(&cascade_id).await;
|
||||
yield Ok::<_, std::convert::Infallible>(Event::default().data(serde_json::to_string(&serde_json::json!({
|
||||
"error": {
|
||||
"message": format!("Timeout: no response from Google API after {timeout}s"),
|
||||
|
||||
@@ -5,6 +5,7 @@ mod gemini;
|
||||
mod models;
|
||||
mod polling;
|
||||
mod responses;
|
||||
mod search;
|
||||
|
||||
mod types;
|
||||
mod util;
|
||||
@@ -48,6 +49,8 @@ pub fn router(state: Arc<AppState>) -> Router {
|
||||
post(gemini::handle_gemini_v1beta),
|
||||
)
|
||||
.route("/v1/models", get(handle_models))
|
||||
.route("/v1/search", get(search::handle_search_get))
|
||||
.route("/v1/search", post(search::handle_search_post))
|
||||
.route("/v1/sessions", get(handle_list_sessions))
|
||||
.route("/v1/sessions/{id}", delete(handle_delete_session))
|
||||
.route("/v1/token", post(handle_set_token))
|
||||
|
||||
@@ -142,14 +142,9 @@ fn extract_responses_input(
|
||||
(final_text, tool_results, image)
|
||||
}
|
||||
|
||||
/// Extract conversation/session ID from Responses API `conversation` field.
|
||||
fn extract_conversation_id(conv: &Option<serde_json::Value>) -> Option<String> {
|
||||
match conv {
|
||||
Some(serde_json::Value::String(s)) => Some(s.clone()),
|
||||
Some(obj) => obj["id"].as_str().map(|s| s.to_string()),
|
||||
None => None,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
/// Response-specific data for building a Response object.
|
||||
struct ResponseData {
|
||||
@@ -241,47 +236,26 @@ pub(crate) async fn handle_responses(
|
||||
|
||||
// Handle tool result submission (function_call_output in input)
|
||||
let is_tool_result_turn = !tool_results.is_empty();
|
||||
if is_tool_result_turn {
|
||||
let mut pending: Vec<PendingToolResult> = Vec::new();
|
||||
for tr in &tool_results {
|
||||
// Look up function name from call_id
|
||||
let name = state
|
||||
.mitm_store
|
||||
.lookup_call_id(&tr.call_id)
|
||||
.await
|
||||
.unwrap_or_else(|| "unknown_function".to_string());
|
||||
let mut pending_tool_results: Vec<PendingToolResult> = Vec::new();
|
||||
|
||||
if is_tool_result_turn {
|
||||
for tr in &tool_results {
|
||||
// For tool result turns, we use the call_id as the name directly.
|
||||
// The proxy captured function calls (with real names) are paired in
|
||||
// the ToolRound when we know the cascade_id later.
|
||||
let name = tr.call_id.clone();
|
||||
|
||||
// Parse the output as JSON, fall back to string wrapper
|
||||
let result_value = serde_json::from_str::<serde_json::Value>(&tr.output)
|
||||
.unwrap_or_else(|_| serde_json::json!({"result": tr.output}));
|
||||
|
||||
// Also store as pending (legacy compat)
|
||||
state
|
||||
.mitm_store
|
||||
.add_tool_result(PendingToolResult {
|
||||
name: name.clone(),
|
||||
result: result_value.clone(),
|
||||
})
|
||||
.await;
|
||||
|
||||
pending.push(PendingToolResult {
|
||||
pending_tool_results.push(PendingToolResult {
|
||||
name,
|
||||
result: result_value,
|
||||
});
|
||||
}
|
||||
// Build a ToolRound from the MITM-captured function calls + client results.
|
||||
// get_last_function_calls() has the calls from Google's previous response.
|
||||
// We take existing accumulated rounds and append this new round.
|
||||
let last_calls = state.mitm_store.get_last_function_calls().await;
|
||||
let mut rounds = state.mitm_store.take_tool_rounds().await;
|
||||
rounds.push(crate::mitm::store::ToolRound {
|
||||
calls: last_calls,
|
||||
results: pending,
|
||||
});
|
||||
state.mitm_store.set_tool_rounds(rounds).await;
|
||||
info!(
|
||||
count = tool_results.len(),
|
||||
"Stored tool results for MITM injection (built tool round)"
|
||||
"Tool results for MITM injection (will build tool round after cascade_id)"
|
||||
);
|
||||
}
|
||||
|
||||
@@ -293,7 +267,8 @@ pub(crate) async fn handle_responses(
|
||||
);
|
||||
}
|
||||
|
||||
// Store client tools in MitmStore for MITM injection
|
||||
// ── Build per-request state locally ──────────────────────────────────
|
||||
|
||||
// Detect web_search_preview tool (OpenAI spec) → enable Google Search grounding
|
||||
let has_web_search = body.tools.as_ref().map_or(false, |tools| {
|
||||
tools.iter().any(|t| {
|
||||
@@ -301,27 +276,20 @@ pub(crate) async fn handle_responses(
|
||||
t_type == "web_search_preview" || t_type == "web_search"
|
||||
})
|
||||
});
|
||||
if let Some(ref tools) = body.tools {
|
||||
let gemini_tools = openai_tools_to_gemini(tools);
|
||||
if !gemini_tools.is_empty() {
|
||||
state.mitm_store.set_tools(gemini_tools).await;
|
||||
info!(
|
||||
count = tools.len(),
|
||||
"Stored client tools for MITM injection"
|
||||
);
|
||||
} else {
|
||||
state.mitm_store.clear_tools().await;
|
||||
}
|
||||
} else {
|
||||
state.mitm_store.clear_tools().await;
|
||||
}
|
||||
if let Some(ref choice) = body.tool_choice {
|
||||
let gemini_config = openai_tool_choice_to_gemini(choice);
|
||||
state.mitm_store.set_tool_config(gemini_config).await;
|
||||
}
|
||||
|
||||
// Store generation parameters for MITM injection
|
||||
// Extract text.format for structured output (json_schema)
|
||||
// Convert OpenAI tools to Gemini format
|
||||
let tools = body.tools.as_ref().and_then(|t| {
|
||||
let gemini_tools = openai_tools_to_gemini(t);
|
||||
if gemini_tools.is_empty() { None } else {
|
||||
info!(count = t.len(), "Client tools for MITM injection");
|
||||
Some(gemini_tools)
|
||||
}
|
||||
});
|
||||
let tool_config = body.tool_choice.as_ref().map(|choice| {
|
||||
openai_tool_choice_to_gemini(choice)
|
||||
});
|
||||
|
||||
// Build generation params locally
|
||||
let (response_mime_type, response_schema, text_format) = if let Some(ref text_val) = body.text {
|
||||
let fmt_type = text_val["format"]["type"].as_str().unwrap_or("text");
|
||||
if fmt_type == "json_schema" {
|
||||
@@ -345,100 +313,98 @@ pub(crate) async fn handle_responses(
|
||||
} else {
|
||||
(None, None, TextFormat::default())
|
||||
};
|
||||
{
|
||||
use crate::mitm::store::GenerationParams;
|
||||
let gp = GenerationParams {
|
||||
temperature: body.temperature,
|
||||
top_p: body.top_p,
|
||||
top_k: None,
|
||||
max_output_tokens: body.max_output_tokens,
|
||||
stop_sequences: None,
|
||||
frequency_penalty: None,
|
||||
presence_penalty: None,
|
||||
reasoning_effort: body.reasoning_effort.clone(),
|
||||
response_mime_type,
|
||||
response_schema,
|
||||
google_search: has_web_search,
|
||||
};
|
||||
if gp.temperature.is_some()
|
||||
|| gp.top_p.is_some()
|
||||
|| gp.max_output_tokens.is_some()
|
||||
|| gp.reasoning_effort.is_some()
|
||||
|| gp.response_mime_type.is_some()
|
||||
|| gp.response_schema.is_some()
|
||||
|| gp.google_search
|
||||
{
|
||||
state.mitm_store.set_generation_params(gp).await;
|
||||
} else {
|
||||
state.mitm_store.clear_generation_params().await;
|
||||
}
|
||||
}
|
||||
|
||||
let response_id = format!("resp_{}", uuid::Uuid::new_v4().to_string().replace('-', ""));
|
||||
|
||||
// Session/conversation management
|
||||
let session_id_str = extract_conversation_id(&body.conversation);
|
||||
let cascade_id = if let Some(ref sid) = session_id_str {
|
||||
match state
|
||||
.sessions
|
||||
.get_or_create(Some(sid), || state.backend.create_cascade())
|
||||
.await
|
||||
{
|
||||
Ok(sr) => sr.cascade_id,
|
||||
Err(e) => {
|
||||
return err_response(
|
||||
StatusCode::BAD_GATEWAY,
|
||||
format!("StartCascade failed: {e}"),
|
||||
"server_error",
|
||||
);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
match state.backend.create_cascade().await {
|
||||
Ok(cid) => cid,
|
||||
Err(e) => {
|
||||
return err_response(
|
||||
StatusCode::BAD_GATEWAY,
|
||||
format!("StartCascade failed: {e}"),
|
||||
"server_error",
|
||||
);
|
||||
}
|
||||
}
|
||||
use crate::mitm::store::GenerationParams;
|
||||
let gp = GenerationParams {
|
||||
temperature: body.temperature,
|
||||
top_p: body.top_p,
|
||||
top_k: None,
|
||||
max_output_tokens: body.max_output_tokens,
|
||||
stop_sequences: None,
|
||||
frequency_penalty: None,
|
||||
presence_penalty: None,
|
||||
reasoning_effort: body.reasoning_effort.clone(),
|
||||
response_mime_type,
|
||||
response_schema,
|
||||
google_search: has_web_search,
|
||||
};
|
||||
|
||||
// Send message
|
||||
state.mitm_store.set_active_cascade(&cascade_id).await;
|
||||
// Store real user text for MITM injection — LS gets a dummy prompt
|
||||
state.mitm_store.set_pending_user_text(user_text.clone()).await;
|
||||
// Store image for MITM injection (LS doesn't forward images to Google API)
|
||||
if let Some(ref img) = image {
|
||||
use base64::Engine;
|
||||
state
|
||||
.mitm_store
|
||||
.set_pending_image(crate::mitm::store::PendingImage {
|
||||
base64_data: base64::engine::general_purpose::STANDARD.encode(&img.data),
|
||||
mime_type: img.mime_type.clone(),
|
||||
})
|
||||
.await;
|
||||
}
|
||||
|
||||
// Pre-flight: install channel BEFORE send_message so the MITM proxy
|
||||
// can grab it when the LS fires its API call.
|
||||
let has_custom_tools = state.mitm_store.get_tools().await.is_some();
|
||||
let mitm_rx = if has_custom_tools {
|
||||
state.mitm_store.clear_response_async().await;
|
||||
state.mitm_store.clear_upstream_error().await;
|
||||
let _ = state.mitm_store.take_any_function_calls().await;
|
||||
let (tx, rx) = tokio::sync::mpsc::channel(64);
|
||||
state.mitm_store.set_channel(tx).await;
|
||||
Some(rx)
|
||||
let generation_params = if gp.temperature.is_some()
|
||||
|| gp.top_p.is_some()
|
||||
|| gp.max_output_tokens.is_some()
|
||||
|| gp.reasoning_effort.is_some()
|
||||
|| gp.response_mime_type.is_some()
|
||||
|| gp.response_schema.is_some()
|
||||
|| gp.google_search
|
||||
{
|
||||
Some(gp)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let response_id = format!("resp_{}", uuid::Uuid::new_v4().to_string().replace('-', ""));
|
||||
|
||||
// Always create a new cascade for every request
|
||||
let cascade_id = match state.backend.create_cascade().await {
|
||||
Ok(cid) => cid,
|
||||
Err(e) => {
|
||||
return err_response(
|
||||
StatusCode::BAD_GATEWAY,
|
||||
format!("StartCascade failed: {e}"),
|
||||
"server_error",
|
||||
);
|
||||
}
|
||||
};
|
||||
|
||||
// Image for MITM injection
|
||||
let pending_image = image.as_ref().map(|img| {
|
||||
use base64::Engine;
|
||||
crate::mitm::store::PendingImage {
|
||||
base64_data: base64::engine::general_purpose::STANDARD.encode(&img.data),
|
||||
mime_type: img.mime_type.clone(),
|
||||
}
|
||||
});
|
||||
|
||||
// Build event channel
|
||||
let has_custom_tools = tools.is_some();
|
||||
let (mitm_rx, event_tx) = if has_custom_tools {
|
||||
let (tx, rx) = tokio::sync::mpsc::channel(64);
|
||||
(Some(rx), Some(tx))
|
||||
} else {
|
||||
(None, None)
|
||||
};
|
||||
|
||||
// Build tool rounds now that cascade_id is known
|
||||
let mut tool_rounds: Vec<crate::mitm::store::ToolRound> = Vec::new();
|
||||
if is_tool_result_turn && !pending_tool_results.is_empty() {
|
||||
// Get last captured function calls from the previous request context
|
||||
let last_calls = state.mitm_store.take_function_calls(&cascade_id).await
|
||||
.unwrap_or_default();
|
||||
tool_rounds.push(crate::mitm::store::ToolRound {
|
||||
calls: last_calls,
|
||||
results: pending_tool_results.clone(),
|
||||
});
|
||||
}
|
||||
|
||||
// Register all per-request state atomically
|
||||
state.mitm_store.register_request(crate::mitm::store::RequestContext {
|
||||
cascade_id: cascade_id.clone(),
|
||||
pending_user_text: user_text.clone(),
|
||||
event_channel: event_tx,
|
||||
generation_params,
|
||||
pending_image,
|
||||
tools,
|
||||
tool_config,
|
||||
pending_tool_results,
|
||||
tool_rounds,
|
||||
last_function_calls: Vec::new(),
|
||||
call_id_to_name: std::collections::HashMap::new(),
|
||||
created_at: std::time::Instant::now(),
|
||||
}).await;
|
||||
|
||||
// Send REAL user text to LS
|
||||
match state
|
||||
.backend
|
||||
.send_message_with_image(&cascade_id, ".", model.model_enum, image.as_ref())
|
||||
.send_message_with_image(&cascade_id, &format!(".<cid:{}>", cascade_id), model.model_enum, image.as_ref())
|
||||
.await
|
||||
{
|
||||
Ok((200, _)) => {
|
||||
@@ -449,7 +415,7 @@ pub(crate) async fn handle_responses(
|
||||
});
|
||||
}
|
||||
Ok((status, _)) => {
|
||||
state.mitm_store.drop_channel().await;
|
||||
state.mitm_store.remove_request(&cascade_id).await;
|
||||
return err_response(
|
||||
StatusCode::BAD_GATEWAY,
|
||||
format!("Antigravity returned {status}"),
|
||||
@@ -457,7 +423,7 @@ pub(crate) async fn handle_responses(
|
||||
);
|
||||
}
|
||||
Err(e) => {
|
||||
state.mitm_store.drop_channel().await;
|
||||
state.mitm_store.remove_request(&cascade_id).await;
|
||||
return err_response(
|
||||
StatusCode::BAD_GATEWAY,
|
||||
format!("Send message failed: {e}"),
|
||||
@@ -644,7 +610,7 @@ async fn handle_responses_sync(
|
||||
|
||||
let mut acc_text = String::new();
|
||||
let mut acc_thinking: Option<String> = None;
|
||||
let mut last_usage: Option<crate::mitm::store::ApiUsage> = None;
|
||||
let mut _last_usage: Option<crate::mitm::store::ApiUsage> = None;
|
||||
|
||||
while let Some(event) = tokio::time::timeout(
|
||||
std::time::Duration::from_secs(timeout.saturating_sub(start.elapsed().as_secs())),
|
||||
@@ -654,7 +620,7 @@ async fn handle_responses_sync(
|
||||
match event {
|
||||
MitmEvent::ThinkingDelta(t) => { acc_thinking = Some(t); }
|
||||
MitmEvent::TextDelta(t) => { acc_text = t; }
|
||||
MitmEvent::Usage(u) => { last_usage = Some(u); }
|
||||
MitmEvent::Usage(u) => { _last_usage = Some(u); }
|
||||
MitmEvent::Grounding(_) => {} // stored by proxy directly
|
||||
MitmEvent::FunctionCall(raw_calls) => {
|
||||
let calls: Vec<_> = if let Some(max) = params.max_tool_calls {
|
||||
@@ -668,14 +634,14 @@ async fn handle_responses_sync(
|
||||
"call_{}",
|
||||
uuid::Uuid::new_v4().to_string().replace('-', "")[..24].to_string()
|
||||
);
|
||||
state.mitm_store.register_call_id(call_id.clone(), fc.name.clone()).await;
|
||||
state.mitm_store.register_call_id(&cascade_id, call_id.clone(), fc.name.clone()).await;
|
||||
let arguments = serde_json::to_string(&fc.args).unwrap_or_default();
|
||||
output_items.push(build_function_call_output(&call_id, &fc.name, &arguments));
|
||||
}
|
||||
let (usage, _) = usage_from_poll(
|
||||
&state.mitm_store, &cascade_id, &None, ¶ms.user_text, "",
|
||||
).await;
|
||||
state.mitm_store.drop_channel().await;
|
||||
state.mitm_store.remove_request(&cascade_id).await;
|
||||
let resp = build_response_object(
|
||||
ResponseData {
|
||||
id: response_id,
|
||||
@@ -700,8 +666,8 @@ async fn handle_responses_sync(
|
||||
// Thinking-only — LS needs to make a follow-up request.
|
||||
// Reinstall channel and unblock gate.
|
||||
let (new_tx, new_rx) = tokio::sync::mpsc::channel(64);
|
||||
state.mitm_store.set_channel(new_tx).await;
|
||||
state.mitm_store.clear_request_in_flight();
|
||||
state.mitm_store.set_channel(&cascade_id, new_tx).await;
|
||||
|
||||
let _ = state.mitm_store.take_any_function_calls().await;
|
||||
rx = new_rx;
|
||||
debug!(
|
||||
@@ -713,7 +679,7 @@ async fn handle_responses_sync(
|
||||
let (usage, _) = usage_from_poll(
|
||||
&state.mitm_store, &cascade_id, &None, ¶ms.user_text, &acc_text,
|
||||
).await;
|
||||
state.mitm_store.drop_channel().await;
|
||||
state.mitm_store.remove_request(&cascade_id).await;
|
||||
|
||||
let mut output_items: Vec<serde_json::Value> = Vec::new();
|
||||
if let Some(ref t) = acc_thinking {
|
||||
@@ -738,14 +704,14 @@ async fn handle_responses_sync(
|
||||
return Json(resp).into_response();
|
||||
}
|
||||
MitmEvent::UpstreamError(err) => {
|
||||
state.mitm_store.drop_channel().await;
|
||||
state.mitm_store.remove_request(&cascade_id).await;
|
||||
return upstream_err_response(&err);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Timeout
|
||||
state.mitm_store.drop_channel().await;
|
||||
state.mitm_store.remove_request(&cascade_id).await;
|
||||
return err_response(
|
||||
StatusCode::GATEWAY_TIMEOUT,
|
||||
format!("Timeout: no response from Google API after {timeout}s"),
|
||||
@@ -789,7 +755,7 @@ async fn handle_responses_sync(
|
||||
// Register call_id → name mapping for tool result routing
|
||||
state
|
||||
.mitm_store
|
||||
.register_call_id(call_id.clone(), fc.name.clone())
|
||||
.register_call_id(&cascade_id, call_id.clone(), fc.name.clone())
|
||||
.await;
|
||||
|
||||
// Stringify args (OpenAI sends arguments as JSON string)
|
||||
@@ -1092,7 +1058,7 @@ async fn handle_responses_stream(
|
||||
uuid::Uuid::new_v4().to_string().replace('-', "")[..24].to_string()
|
||||
);
|
||||
let arguments = serde_json::to_string(&fc.args).unwrap_or_default();
|
||||
state.mitm_store.register_call_id(call_id.clone(), fc.name.clone()).await;
|
||||
state.mitm_store.register_call_id(&cascade_id, call_id.clone(), fc.name.clone()).await;
|
||||
let fc_item_id = format!("fc_{}", uuid::Uuid::new_v4().to_string().replace('-', ""));
|
||||
|
||||
yield Ok(responses_sse_event(
|
||||
@@ -1166,7 +1132,7 @@ async fn handle_responses_stream(
|
||||
"response": response_to_json(&final_resp),
|
||||
}),
|
||||
));
|
||||
state.mitm_store.drop_channel().await;
|
||||
state.mitm_store.remove_request(&cascade_id).await;
|
||||
return;
|
||||
}
|
||||
MitmEvent::ResponseComplete => {
|
||||
@@ -1184,14 +1150,14 @@ async fn handle_responses_stream(
|
||||
) {
|
||||
yield Ok(evt);
|
||||
}
|
||||
state.mitm_store.drop_channel().await;
|
||||
state.mitm_store.remove_request(&cascade_id).await;
|
||||
return;
|
||||
} else if !last_thinking.is_empty() {
|
||||
// Thinking-only response — LS needs follow-up API calls.
|
||||
// Create a new channel and unblock the gate.
|
||||
let (new_tx, new_rx) = tokio::sync::mpsc::channel(64);
|
||||
state.mitm_store.set_channel(new_tx).await;
|
||||
state.mitm_store.clear_request_in_flight();
|
||||
state.mitm_store.set_channel(&cascade_id, new_tx).await;
|
||||
|
||||
let _ = state.mitm_store.take_any_function_calls().await;
|
||||
rx = new_rx;
|
||||
debug!(
|
||||
@@ -1220,7 +1186,7 @@ async fn handle_responses_stream(
|
||||
},
|
||||
}),
|
||||
));
|
||||
state.mitm_store.drop_channel().await;
|
||||
state.mitm_store.remove_request(&cascade_id).await;
|
||||
return;
|
||||
}
|
||||
MitmEvent::Usage(_) | MitmEvent::Grounding(_) => {
|
||||
@@ -1230,7 +1196,7 @@ async fn handle_responses_stream(
|
||||
}
|
||||
|
||||
// Timeout in channel mode
|
||||
state.mitm_store.drop_channel().await;
|
||||
state.mitm_store.remove_request(&cascade_id).await;
|
||||
yield Ok(responses_sse_event(
|
||||
"response.failed",
|
||||
serde_json::json!({
|
||||
|
||||
@@ -33,6 +33,7 @@ pub(crate) struct SearchRequest {
|
||||
pub timeout: u64,
|
||||
/// Conversation/session ID for context reuse.
|
||||
#[serde(default)]
|
||||
#[allow(dead_code)]
|
||||
pub conversation: Option<String>,
|
||||
/// Max output tokens — keep low since we only want grounding metadata.
|
||||
#[serde(default = "default_search_max_tokens")]
|
||||
@@ -111,19 +112,13 @@ async fn do_search(state: Arc<AppState>, body: SearchRequest) -> axum::response:
|
||||
);
|
||||
}
|
||||
|
||||
// Enable Google Search grounding via GenerationParams
|
||||
{
|
||||
use crate::mitm::store::GenerationParams;
|
||||
let gp = GenerationParams {
|
||||
max_output_tokens: Some(body.max_output_tokens),
|
||||
google_search: true,
|
||||
..Default::default()
|
||||
};
|
||||
state.mitm_store.set_generation_params(gp).await;
|
||||
}
|
||||
|
||||
// Clear any stale tools — we only want googleSearch
|
||||
state.mitm_store.clear_tools().await;
|
||||
// Build generation params with Google Search grounding enabled
|
||||
use crate::mitm::store::GenerationParams;
|
||||
let gp = GenerationParams {
|
||||
max_output_tokens: Some(body.max_output_tokens),
|
||||
google_search: true,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
// Create a prompt that encourages the model to ground its response
|
||||
let search_prompt = format!(
|
||||
@@ -131,49 +126,41 @@ async fn do_search(state: Arc<AppState>, body: SearchRequest) -> axum::response:
|
||||
body.query
|
||||
);
|
||||
|
||||
// Session management
|
||||
let session_id_str = body.conversation.clone();
|
||||
let cascade_id = if let Some(ref sid) = session_id_str {
|
||||
match state
|
||||
.sessions
|
||||
.get_or_create(Some(sid), || state.backend.create_cascade())
|
||||
.await
|
||||
{
|
||||
Ok(sr) => sr.cascade_id,
|
||||
Err(e) => {
|
||||
return err_response(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
format!("Failed to create session: {e}"),
|
||||
"server_error",
|
||||
);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
match state.backend.create_cascade().await {
|
||||
Ok(id) => id,
|
||||
Err(e) => {
|
||||
return err_response(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
format!("Failed to create cascade: {e}"),
|
||||
"server_error",
|
||||
);
|
||||
}
|
||||
// Always create a new cascade for every request
|
||||
let cascade_id = match state.backend.create_cascade().await {
|
||||
Ok(cid) => cid,
|
||||
Err(e) => {
|
||||
return err_response(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
format!("Failed to create cascade: {e}"),
|
||||
"server_error",
|
||||
);
|
||||
}
|
||||
};
|
||||
|
||||
// Set active cascade for MITM correlation
|
||||
state.mitm_store.set_active_cascade(&cascade_id).await;
|
||||
// Store real search prompt for MITM injection — LS gets a dummy prompt
|
||||
state.mitm_store.set_pending_user_text(search_prompt.clone()).await;
|
||||
// Register per-request state — no tools, just generation params for search grounding
|
||||
state.mitm_store.register_request(crate::mitm::store::RequestContext {
|
||||
cascade_id: cascade_id.clone(),
|
||||
pending_user_text: search_prompt.clone(),
|
||||
event_channel: None,
|
||||
generation_params: Some(gp),
|
||||
pending_image: None,
|
||||
tools: None,
|
||||
tool_config: None,
|
||||
pending_tool_results: Vec::new(),
|
||||
tool_rounds: Vec::new(),
|
||||
last_function_calls: Vec::new(),
|
||||
call_id_to_name: std::collections::HashMap::new(),
|
||||
created_at: std::time::Instant::now(),
|
||||
}).await;
|
||||
|
||||
// Send the search message
|
||||
// Send dot to LS — real search prompt injected by MITM proxy
|
||||
if let Err(e) = state
|
||||
.backend
|
||||
.send_message(&cascade_id, ".", model.model_enum)
|
||||
.send_message(&cascade_id, &format!(".<cid:{}>", cascade_id), model.model_enum)
|
||||
.await
|
||||
{
|
||||
state.mitm_store.clear_active_cascade().await;
|
||||
state.mitm_store.clear_generation_params().await;
|
||||
state.mitm_store.remove_request(&cascade_id).await;
|
||||
return err_response(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
format!("Failed to send search message: {e}"),
|
||||
@@ -199,8 +186,7 @@ async fn do_search(state: Arc<AppState>, body: SearchRequest) -> axum::response:
|
||||
};
|
||||
|
||||
// Clean up
|
||||
state.mitm_store.clear_active_cascade().await;
|
||||
state.mitm_store.clear_generation_params().await;
|
||||
state.mitm_store.remove_request(&cascade_id).await;
|
||||
state.mitm_store.clear_response_async().await;
|
||||
|
||||
// Build the search response
|
||||
|
||||
@@ -17,6 +17,7 @@ pub(crate) struct ResponsesRequest {
|
||||
pub stream: bool,
|
||||
#[serde(default = "default_timeout")]
|
||||
pub timeout: u64,
|
||||
#[allow(dead_code)]
|
||||
pub conversation: Option<serde_json::Value>,
|
||||
#[serde(default = "default_true")]
|
||||
pub store: bool,
|
||||
@@ -189,9 +190,7 @@ pub(crate) struct CompletionMessage {
|
||||
pub tool_call_id: Option<String>,
|
||||
}
|
||||
|
||||
fn default_timeout() -> u64 {
|
||||
120
|
||||
}
|
||||
use super::util::default_timeout;
|
||||
|
||||
fn default_true() -> bool {
|
||||
true
|
||||
|
||||
@@ -122,10 +122,17 @@ pub(crate) fn now_unix() -> u64 {
|
||||
.as_secs()
|
||||
}
|
||||
|
||||
/// Default request timeout in seconds (used by serde defaults).
|
||||
pub(crate) fn default_timeout() -> u64 {
|
||||
120
|
||||
}
|
||||
|
||||
|
||||
|
||||
pub(crate) fn responses_sse_event(event_type: &str, data: serde_json::Value) -> Event {
|
||||
Event::default()
|
||||
.event(event_type)
|
||||
.data(serde_json::to_string(&data).unwrap())
|
||||
.data(serde_json::to_string(&data).unwrap_or_default())
|
||||
}
|
||||
|
||||
// ─── Image extraction ────────────────────────────────────────────────────────
|
||||
|
||||
@@ -412,7 +412,8 @@ impl Backend {
|
||||
headers.insert("Connect-Protocol-Version", HeaderValue::from_static("1"));
|
||||
|
||||
// Connect protocol envelope: [flags:1][length:4][payload]
|
||||
let json_bytes = serde_json::to_vec(&body).unwrap();
|
||||
let json_bytes = serde_json::to_vec(&body)
|
||||
.map_err(|e| format!("{rpc_method} JSON serialize error: {e}"))?;
|
||||
let mut envelope = Vec::with_capacity(5 + json_bytes.len());
|
||||
envelope.push(0x00);
|
||||
envelope.extend_from_slice(&(json_bytes.len() as u32).to_be_bytes());
|
||||
|
||||
@@ -129,14 +129,21 @@ impl StreamingAccumulator {
|
||||
else if let Some(fc) = part.get("functionCall") {
|
||||
let name = fc["name"].as_str().unwrap_or("unknown").to_string();
|
||||
let args = fc["args"].clone();
|
||||
// thoughtSignature is a SIBLING of functionCall in the part,
|
||||
// not nested inside functionCall
|
||||
let thought_signature = part.get("thoughtSignature")
|
||||
.and_then(|v| v.as_str())
|
||||
.map(|s| s.to_string());
|
||||
info!(
|
||||
tool_name = %name,
|
||||
tool_args = %args,
|
||||
has_thought_sig = thought_signature.is_some(),
|
||||
"MITM: Google returned functionCall!"
|
||||
);
|
||||
self.function_calls.push(CapturedFunctionCall {
|
||||
name,
|
||||
args,
|
||||
thought_signature,
|
||||
captured_at: std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.unwrap_or_default()
|
||||
@@ -300,42 +307,51 @@ fn extract_usage_from_message(msg: &Value) -> Option<ApiUsage> {
|
||||
|
||||
/// Try to identify a cascade ID from the request body.
|
||||
///
|
||||
/// The LS includes cascade-related metadata in its API requests (as part of
|
||||
/// the system prompt or metadata field). We try to find it.
|
||||
/// Priority:
|
||||
/// 1. `<cid:UUID>` marker embedded by our proxy in the user message content
|
||||
/// 2. `requestId` field: `agent/{timestamp}/{cascade_uuid}/{sequence}` format
|
||||
/// 3. `metadata.user_id` fallback
|
||||
pub fn extract_cascade_hint(request_body: &[u8]) -> Option<String> {
|
||||
let json: Value = serde_json::from_slice(request_body).ok()?;
|
||||
|
||||
// Check for metadata field (some API configurations include it)
|
||||
if let Some(metadata) = json.get("metadata") {
|
||||
if let Some(user_id) = metadata["user_id"].as_str() {
|
||||
// The LS often sets user_id to the cascadeId
|
||||
return Some(user_id.to_string());
|
||||
// Fast path: look for <cid:UUID> marker in raw bytes (avoid JSON parse)
|
||||
let body_str = std::str::from_utf8(request_body).ok()?;
|
||||
if let Some(start) = body_str.find("<cid:") {
|
||||
let rest = &body_str[start + 5..];
|
||||
if let Some(end) = rest.find('>') {
|
||||
let candidate = &rest[..end];
|
||||
// Validate UUID format
|
||||
if candidate.len() == 36
|
||||
&& candidate.chars().filter(|c| *c == '-').count() == 4
|
||||
&& candidate.chars().all(|c| c.is_ascii_hexdigit() || c == '-')
|
||||
{
|
||||
return Some(candidate.to_string());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Check system prompt for cascade/workspace markers
|
||||
if let Some(system) = json.get("system") {
|
||||
let system_str = match system {
|
||||
Value::String(s) => s.clone(),
|
||||
Value::Array(arr) => {
|
||||
// Array of content blocks
|
||||
arr.iter()
|
||||
.filter_map(|b| b["text"].as_str())
|
||||
.collect::<Vec<_>>()
|
||||
.join(" ")
|
||||
}
|
||||
_ => return None,
|
||||
};
|
||||
// Look for workspace_id or cascade_id patterns
|
||||
if let Some(pos) = system_str.find("workspace_id") {
|
||||
let rest = &system_str[pos..];
|
||||
// Extract the value after workspace_id
|
||||
if let Some(val) = rest.split_whitespace().nth(1) {
|
||||
return Some(val.to_string());
|
||||
let json: Value = serde_json::from_slice(request_body).ok()?;
|
||||
|
||||
// Secondary: extract cascade UUID from requestId field
|
||||
// Format: "agent/{timestamp}/{cascade_uuid}/{sequence}"
|
||||
if let Some(request_id) = json.get("requestId").and_then(|v| v.as_str()) {
|
||||
let parts: Vec<&str> = request_id.split('/').collect();
|
||||
if parts.len() >= 3 {
|
||||
let candidate = parts[2];
|
||||
if candidate.len() == 36
|
||||
&& candidate.chars().filter(|c| *c == '-').count() == 4
|
||||
&& candidate.chars().all(|c| c.is_ascii_hexdigit() || c == '-')
|
||||
{
|
||||
return Some(candidate.to_string());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback: check metadata.user_id
|
||||
if let Some(metadata) = json.get("metadata") {
|
||||
if let Some(user_id) = metadata["user_id"].as_str() {
|
||||
return Some(user_id.to_string());
|
||||
}
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
|
||||
|
||||
1327
src/mitm/modify.rs
1327
src/mitm/modify.rs
File diff suppressed because it is too large
Load Diff
@@ -37,6 +37,9 @@ use flate2::read::GzDecoder;
|
||||
use std::io::Read;
|
||||
use tracing::{debug, trace, warn};
|
||||
|
||||
// Re-import the shared varint decoder under the name used throughout this module
|
||||
use crate::proto::wire::decode_varint as read_varint;
|
||||
|
||||
/// A decoded protobuf field.
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum ProtoValue {
|
||||
@@ -260,26 +263,7 @@ fn looks_like_valid_message(fields: &[ProtoField], original_len: usize) -> bool
|
||||
}
|
||||
}
|
||||
|
||||
/// Read a varint from a byte slice. Returns (value, bytes_consumed).
|
||||
pub fn read_varint(data: &[u8]) -> Option<(u64, usize)> {
|
||||
let mut result: u64 = 0;
|
||||
let mut shift = 0;
|
||||
|
||||
for (i, &byte) in data.iter().enumerate() {
|
||||
if i >= 10 {
|
||||
return None; // Too many bytes for a varint
|
||||
}
|
||||
|
||||
result |= ((byte & 0x7F) as u64) << shift;
|
||||
shift += 7;
|
||||
|
||||
if byte & 0x80 == 0 {
|
||||
return Some((result, i + 1));
|
||||
}
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
|
||||
/// Search a decoded protobuf message tree for usage-like structures.
|
||||
///
|
||||
|
||||
@@ -383,137 +383,13 @@ async fn handle_http_over_tls(
|
||||
// Reusable upstream connection — created lazily, reconnected if stale
|
||||
let mut upstream: Option<tokio_rustls::client::TlsStream<TcpStream>> = None;
|
||||
|
||||
/// Connect (or reconnect) to the real upstream via TLS.
|
||||
///
|
||||
/// Bypasses /etc/hosts by resolving via direct DNS query (dig @8.8.8.8),
|
||||
/// then falls back to cached IPs file, then to normal system resolution.
|
||||
async fn connect_upstream(
|
||||
domain: &str,
|
||||
config: &Arc<rustls::ClientConfig>,
|
||||
) -> Result<tokio_rustls::client::TlsStream<TcpStream>, String> {
|
||||
let connector = tokio_rustls::TlsConnector::from(config.clone());
|
||||
|
||||
// Try to resolve the real IP, bypassing /etc/hosts
|
||||
let addr = resolve_upstream(domain).await;
|
||||
info!(domain, addr = %addr, "MITM: connecting upstream");
|
||||
|
||||
let tcp = match tokio::time::timeout(
|
||||
std::time::Duration::from_secs(15),
|
||||
TcpStream::connect(&addr),
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(Ok(s)) => s,
|
||||
Ok(Err(e)) => return Err(format!("Connect to upstream {domain} ({addr}): {e}")),
|
||||
Err(_) => return Err(format!("Connect to upstream {domain} ({addr}): timed out")),
|
||||
};
|
||||
|
||||
let server_name = rustls::pki_types::ServerName::try_from(domain.to_string())
|
||||
.map_err(|e| format!("Invalid server name: {e}"))?;
|
||||
|
||||
match tokio::time::timeout(
|
||||
std::time::Duration::from_secs(15),
|
||||
connector.connect(server_name, tcp),
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(Ok(s)) => {
|
||||
info!(domain, "MITM: upstream TLS connected ✓");
|
||||
Ok(s)
|
||||
}
|
||||
Ok(Err(e)) => Err(format!("TLS connect to upstream {domain}: {e}")),
|
||||
Err(_) => Err(format!("TLS connect to upstream {domain}: timed out")),
|
||||
}
|
||||
}
|
||||
|
||||
/// Resolve upstream IP bypassing /etc/hosts.
|
||||
async fn resolve_upstream(domain: &str) -> String {
|
||||
// 1. Try dig @8.8.8.8 (bypasses /etc/hosts)
|
||||
if let Ok(output) = tokio::process::Command::new("dig")
|
||||
.args(["+short", "@8.8.8.8", domain])
|
||||
.output()
|
||||
.await
|
||||
{
|
||||
let out = String::from_utf8_lossy(&output.stdout);
|
||||
if let Some(ip) = out
|
||||
.lines()
|
||||
.find(|l| l.parse::<std::net::Ipv4Addr>().is_ok())
|
||||
{
|
||||
return format!("{ip}:443");
|
||||
}
|
||||
}
|
||||
|
||||
// 2. Try cached IPs file (written by dns-redirect.sh install)
|
||||
if let Ok(contents) = tokio::fs::read_to_string("/tmp/antigravity-mitm-real-ips").await {
|
||||
for line in contents.lines() {
|
||||
if let Some((d, ip)) = line.split_once('=') {
|
||||
if d == domain {
|
||||
return format!("{ip}:443");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 3. Fallback to normal resolution (may hit /etc/hosts)
|
||||
format!("{domain}:443")
|
||||
}
|
||||
|
||||
// Keep-alive loop: handle multiple requests on this connection
|
||||
loop {
|
||||
// ── Read the HTTP request from the client ─────────────────────────
|
||||
let mut request_buf = Vec::with_capacity(1024 * 64);
|
||||
|
||||
// 60s timeout on initial read (LS may open connection without sending immediately)
|
||||
const IDLE_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(60);
|
||||
|
||||
loop {
|
||||
let read_result = if request_buf.is_empty() {
|
||||
// First read — apply idle timeout
|
||||
match tokio::time::timeout(IDLE_TIMEOUT, client.read(&mut tmp)).await {
|
||||
Ok(r) => r,
|
||||
Err(_) => {
|
||||
// Idle timeout — connection pool warmup, no data sent
|
||||
debug!(domain, "MITM: client idle timeout (60s), closing");
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Subsequent reads — wait up to 30s for rest of request
|
||||
match tokio::time::timeout(
|
||||
std::time::Duration::from_secs(30),
|
||||
client.read(&mut tmp),
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(r) => r,
|
||||
Err(_) => {
|
||||
warn!(domain, "MITM: partial request read timed out");
|
||||
return Err("Partial request read timed out".into());
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
let n = match read_result {
|
||||
Ok(0) => return Ok(()), // Client closed connection cleanly
|
||||
Ok(n) => n,
|
||||
Err(e) => {
|
||||
// Connection reset / broken pipe is normal for keep-alive end
|
||||
debug!(domain, error = %e, "MITM: client read finished");
|
||||
return Ok(());
|
||||
}
|
||||
};
|
||||
|
||||
request_buf.extend_from_slice(&tmp[..n]);
|
||||
|
||||
// Check if we have the full request (headers + body)
|
||||
if has_complete_http_request(&request_buf) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if request_buf.is_empty() {
|
||||
return Ok(());
|
||||
}
|
||||
let mut request_buf = match read_full_request(&mut client, &mut tmp, domain).await {
|
||||
Some(buf) if !buf.is_empty() => buf,
|
||||
_ => return Ok(()),
|
||||
};
|
||||
|
||||
// Parse the HTTP request to find headers and body
|
||||
let (headers_end, content_length, _is_streaming_request) =
|
||||
@@ -554,33 +430,10 @@ async fn handle_http_over_tls(
|
||||
"MITM: forwarding LLM request"
|
||||
);
|
||||
|
||||
// ── Atomic in-flight gate ─────────────────────────────────
|
||||
// The LS opens multiple connections and sends parallel requests.
|
||||
// When custom tools are active, only the FIRST request wins the
|
||||
// atomic compare_exchange. All others get fake STOP responses.
|
||||
let has_tools = store.get_tools().await.is_some();
|
||||
if has_tools {
|
||||
if !store.try_mark_request_in_flight() {
|
||||
info!("MITM: blocking LS request — another request already in-flight");
|
||||
let fake_response = "HTTP/1.1 200 OK\r\n\
|
||||
Content-Type: text/event-stream\r\n\
|
||||
Transfer-Encoding: chunked\r\n\
|
||||
\r\n";
|
||||
let fake_sse = "data: {\"response\":{\"candidates\":[{\"content\":{\"parts\":[{\"text\":\"Request handled.\"}],\"role\":\"model\"},\"finishReason\":\"STOP\"}],\"usageMetadata\":{\"promptTokenCount\":0,\"candidatesTokenCount\":1,\"totalTokenCount\":1}}}\n\ndata: [DONE]\n\n";
|
||||
let chunked_body = super::modify::rechunk(fake_sse.as_bytes());
|
||||
let mut response = fake_response.as_bytes().to_vec();
|
||||
response.extend_from_slice(&chunked_body);
|
||||
if let Err(e) = client.write_all(&response).await {
|
||||
warn!(error = %e, "MITM: failed to write fake response");
|
||||
}
|
||||
let _ = client.flush().await;
|
||||
continue;
|
||||
}
|
||||
// Grab the channel sender — the API handler installed it before
|
||||
// sending the LS message. If it's gone, we still proceed but
|
||||
// fall back to legacy store writes.
|
||||
event_tx = store.take_channel().await;
|
||||
}
|
||||
// ── Per-request context lookup ────────────────────────────
|
||||
// Deferred until we know this is an agent request containing our
|
||||
// dummy dot. This prevents LS internal requests (title generation,
|
||||
// checkpoints) from stealing the RequestContext.
|
||||
|
||||
// ── Request modification ─────────────────────────────────────
|
||||
// Dechunk body → check if agent request → modify → rechunk
|
||||
@@ -594,33 +447,128 @@ async fn handle_http_over_tls(
|
||||
|| body_str.contains("\"requestType\": \"agent\"");
|
||||
|
||||
if is_agent {
|
||||
// Build ToolContext from store
|
||||
let tools = store.get_tools().await;
|
||||
let tool_config = store.get_tool_config().await;
|
||||
let pending_results = store.take_tool_results().await;
|
||||
let last_calls = store.get_last_function_calls().await;
|
||||
let generation_params = store.get_generation_params().await;
|
||||
let pending_image = store.take_pending_image().await;
|
||||
let tool_rounds = store.get_tool_rounds().await;
|
||||
let pending_user_text = store.take_pending_user_text().await;
|
||||
// Re-extract cascade_hint from the dechunked (JSON-parseable) body.
|
||||
// The chunked transfer encoding body at `request_buf[headers_end..]`
|
||||
// can't be JSON-parsed, but `raw_body` (dechunked) can.
|
||||
let precise_cascade = extract_cascade_hint(&raw_body);
|
||||
debug!(
|
||||
cascade = ?precise_cascade,
|
||||
"MITM: cascade from dechunked requestId"
|
||||
);
|
||||
|
||||
let tool_ctx = if tools.is_some()
|
||||
|| !pending_results.is_empty()
|
||||
|| !tool_rounds.is_empty()
|
||||
|| generation_params.is_some()
|
||||
|| pending_image.is_some()
|
||||
|| pending_user_text.is_some()
|
||||
{
|
||||
Some(super::modify::ToolContext {
|
||||
tools,
|
||||
tool_config,
|
||||
pending_results,
|
||||
last_calls,
|
||||
generation_params,
|
||||
pending_image,
|
||||
tool_rounds,
|
||||
pending_user_text,
|
||||
// Check if ANY user message contains our dummy dot prompt
|
||||
// within a <USER_REQUEST> wrapper.
|
||||
// Only then should we consume the pending RequestContext.
|
||||
// This prevents LS internal requests (title gen, etc.) from
|
||||
// consuming the context meant for the user's actual request.
|
||||
// NOTE: We check ALL user messages because the LS appends context
|
||||
// messages AFTER the dot prompt (conversation summaries, etc.).
|
||||
// We look for <USER_REQUEST> + dot specifically to avoid matching
|
||||
// old <cid:> markers in history (which are in model messages).
|
||||
let contains_our_dot = serde_json::from_slice::<serde_json::Value>(&raw_body)
|
||||
.ok()
|
||||
.and_then(|json| {
|
||||
let contents = json.pointer("/request/contents")?.as_array()?;
|
||||
for msg in contents.iter() {
|
||||
let is_user = msg.get("role")
|
||||
.and_then(|r| r.as_str())
|
||||
.map_or(true, |r| r == "user");
|
||||
if !is_user { continue; }
|
||||
if let Some(text) = msg.pointer("/parts/0/text").and_then(|v| v.as_str()) {
|
||||
// Check for dot in <USER_REQUEST> wrapper
|
||||
if text.contains("<USER_REQUEST>") {
|
||||
if let (Some(s), Some(e)) = (text.find("<USER_REQUEST>"), text.find("</USER_REQUEST>")) {
|
||||
let inner = &text[s + 14..e]; // 14 = len("<USER_REQUEST>")
|
||||
let it = inner.trim();
|
||||
if it == "." || it.starts_with(".<cid:") {
|
||||
return Some(true);
|
||||
}
|
||||
}
|
||||
}
|
||||
// Check for bare dot (no wrapper)
|
||||
let t = text.trim();
|
||||
if t == "." || t == ".<cid:" || (t.starts_with(".<cid:") && t.ends_with(">")) {
|
||||
return Some(true);
|
||||
}
|
||||
}
|
||||
}
|
||||
Some(false)
|
||||
})
|
||||
.unwrap_or(false);
|
||||
|
||||
// Only take the RequestContext if this request has our dot
|
||||
let effective_cascade = precise_cascade.or(cascade_hint.clone());
|
||||
let mut request_ctx: Option<super::store::RequestContext> = if contains_our_dot {
|
||||
let ctx = if let Some(ref cid) = effective_cascade {
|
||||
store.take_request(cid).await
|
||||
} else {
|
||||
None
|
||||
};
|
||||
if ctx.is_some() {
|
||||
ctx
|
||||
} else if let Some(ref cid) = effective_cascade {
|
||||
// Check if this is a subsequent turn (turn 1+) of an
|
||||
// already-processed cascade. If so, DON'T fall through
|
||||
// to take_latest_request — that would steal an unrelated
|
||||
// cascade's context.
|
||||
if store.has_cascade_cache(cid).await {
|
||||
debug!(cascade = %cid, "MITM: subsequent turn — using cached context");
|
||||
None
|
||||
} else {
|
||||
// Unknown cascade, try latest fallback
|
||||
store.take_latest_request().await
|
||||
}
|
||||
} else {
|
||||
store.take_latest_request().await
|
||||
}
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
// Extract event channel from matched context
|
||||
if let Some(ref mut ctx) = request_ctx {
|
||||
event_tx = ctx.event_channel.take();
|
||||
}
|
||||
|
||||
// Build ToolContext from RequestContext (turn 0) or cached
|
||||
// context (turn 1+). On turn 0, we also cache the context
|
||||
// for subsequent turns.
|
||||
let tool_ctx = if let Some(ctx) = request_ctx.take() {
|
||||
// Turn 0: cache context for subsequent turns
|
||||
if let Some(ref cid) = effective_cascade {
|
||||
store.cache_cascade(cid, super::store::CascadeCache {
|
||||
user_text: ctx.pending_user_text.clone(),
|
||||
tools: ctx.tools.clone(),
|
||||
tool_config: ctx.tool_config.clone(),
|
||||
generation_params: ctx.generation_params.clone(),
|
||||
}).await;
|
||||
}
|
||||
Some(super::modify::ToolContext {
|
||||
pending_user_text: ctx.pending_user_text,
|
||||
tools: ctx.tools,
|
||||
tool_config: ctx.tool_config,
|
||||
pending_results: ctx.pending_tool_results,
|
||||
last_calls: ctx.last_function_calls,
|
||||
generation_params: ctx.generation_params,
|
||||
pending_image: ctx.pending_image,
|
||||
tool_rounds: ctx.tool_rounds,
|
||||
})
|
||||
} else if let Some(ref cid) = effective_cascade {
|
||||
// Turn 1+: rebuild lite ToolContext from cache
|
||||
if let Some(cached) = store.get_cascade_cache(cid).await {
|
||||
Some(super::modify::ToolContext {
|
||||
pending_user_text: cached.user_text,
|
||||
tools: cached.tools,
|
||||
tool_config: cached.tool_config,
|
||||
pending_results: vec![],
|
||||
last_calls: vec![],
|
||||
generation_params: cached.generation_params,
|
||||
pending_image: None,
|
||||
tool_rounds: vec![],
|
||||
})
|
||||
} else {
|
||||
None
|
||||
}
|
||||
} else {
|
||||
None
|
||||
};
|
||||
@@ -637,8 +585,6 @@ async fn handle_http_over_tls(
|
||||
let mut new_buf = updated_headers.into_bytes();
|
||||
new_buf.extend_from_slice(&new_chunked);
|
||||
request_buf = new_buf;
|
||||
|
||||
// In-flight already marked atomically above
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -677,6 +623,7 @@ async fn handle_http_over_tls(
|
||||
// ALWAYS forward data to client immediately (no buffering).
|
||||
// Buffer body on the side for usage parsing.
|
||||
let mut streaming_acc = StreamingAccumulator::new();
|
||||
let mut response_rewriter: Option<super::modify::ResponseRewriter> = None;
|
||||
let mut is_streaming_response = false;
|
||||
let mut headers_parsed = false;
|
||||
let mut upstream_ok = true;
|
||||
@@ -737,6 +684,10 @@ async fn handle_http_over_tls(
|
||||
content_type = v.to_string();
|
||||
if v.contains("text/event-stream") {
|
||||
is_streaming_response = true;
|
||||
// Lazily initialize the response rewriter for SSE streams
|
||||
if modify_requests {
|
||||
response_rewriter = Some(super::modify::ResponseRewriter::new());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -802,11 +753,11 @@ async fn handle_http_over_tls(
|
||||
message,
|
||||
error_status,
|
||||
};
|
||||
// Send through channel if available, otherwise store for legacy consumers
|
||||
// Send through channel if available
|
||||
if let Some(ref tx) = event_tx {
|
||||
let _ = tx.send(super::store::MitmEvent::UpstreamError(upstream_err)).await;
|
||||
} else {
|
||||
store.set_upstream_error(upstream_err).await;
|
||||
warn!("MITM: upstream error but no channel to forward it");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -817,76 +768,20 @@ async fn handle_http_over_tls(
|
||||
if is_streaming_response && hdr_end < header_buf.len() {
|
||||
let body = String::from_utf8_lossy(&header_buf[hdr_end..]);
|
||||
parse_streaming_chunk(&body, &mut streaming_acc);
|
||||
|
||||
// Send events through channel if available, otherwise use legacy store
|
||||
if let Some(ref tx) = event_tx {
|
||||
// Function calls → channel event
|
||||
if !streaming_acc.function_calls.is_empty() {
|
||||
let calls: Vec<_> = streaming_acc.function_calls.drain(..).collect();
|
||||
store.set_last_function_calls(calls.clone()).await;
|
||||
store.record_function_call(cascade_hint.as_deref(), calls[0].clone()).await;
|
||||
info!("MITM: sending {} function call(s) via channel (initial body)", calls.len());
|
||||
let _ = tx.send(super::store::MitmEvent::FunctionCall(calls)).await;
|
||||
}
|
||||
// Thinking delta → channel event
|
||||
if !streaming_acc.thinking_text.is_empty() {
|
||||
let _ = tx.send(super::store::MitmEvent::ThinkingDelta(
|
||||
streaming_acc.thinking_text.clone(),
|
||||
)).await;
|
||||
}
|
||||
// Text delta → channel event
|
||||
if !streaming_acc.response_text.is_empty() {
|
||||
let _ = tx.send(super::store::MitmEvent::TextDelta(
|
||||
streaming_acc.response_text.clone(),
|
||||
)).await;
|
||||
}
|
||||
// Grounding → channel event
|
||||
if let Some(ref gm) = streaming_acc.grounding_metadata {
|
||||
store.set_grounding(gm.clone()).await;
|
||||
let _ = tx.send(super::store::MitmEvent::Grounding(gm.clone())).await;
|
||||
}
|
||||
// Response complete → channel event
|
||||
if streaming_acc.is_complete {
|
||||
info!(
|
||||
response_text_len = streaming_acc.response_text.len(),
|
||||
thinking_text_len = streaming_acc.thinking_text.len(),
|
||||
"MITM: response complete (initial body) — sending via channel"
|
||||
);
|
||||
let _ = tx.send(super::store::MitmEvent::ResponseComplete).await;
|
||||
streaming_acc.is_complete = false; // prevent duplicate sends
|
||||
}
|
||||
} else {
|
||||
// Legacy path: store writes for non-channel consumers (search, etc.)
|
||||
if !streaming_acc.function_calls.is_empty() {
|
||||
let calls: Vec<_> = streaming_acc.function_calls.drain(..).collect();
|
||||
for fc in &calls {
|
||||
store.record_function_call(cascade_hint.as_deref(), fc.clone()).await;
|
||||
}
|
||||
store.set_last_function_calls(calls.clone()).await;
|
||||
info!("MITM: stored {} function call(s) from initial body", calls.len());
|
||||
}
|
||||
if !streaming_acc.response_text.is_empty() {
|
||||
store.set_response_text(&streaming_acc.response_text).await;
|
||||
}
|
||||
if let Some(ref gm) = streaming_acc.grounding_metadata {
|
||||
store.set_grounding(gm.clone()).await;
|
||||
}
|
||||
}
|
||||
dispatch_stream_events(&mut streaming_acc, &event_tx, &store, cascade_hint.as_deref()).await;
|
||||
}
|
||||
|
||||
// Forward to client — rewrite function calls if custom tools are injected
|
||||
let forward_buf = if modify_requests {
|
||||
if let Some(modified) = super::modify::modify_response_chunk(&header_buf) {
|
||||
modified
|
||||
} else {
|
||||
header_buf.clone()
|
||||
}
|
||||
let forward_buf = if let Some(ref mut rewriter) = response_rewriter {
|
||||
rewriter.feed(&header_buf)
|
||||
} else {
|
||||
header_buf.clone()
|
||||
};
|
||||
if let Err(e) = client.write_all(&forward_buf).await {
|
||||
warn!(error = %e, "MITM: write to client failed");
|
||||
break;
|
||||
if !forward_buf.is_empty() {
|
||||
if let Err(e) = client.write_all(&forward_buf).await {
|
||||
warn!(error = %e, "MITM: write to client failed");
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(cl) = response_content_length {
|
||||
@@ -908,80 +803,24 @@ async fn handle_http_over_tls(
|
||||
if is_streaming_response {
|
||||
let s = String::from_utf8_lossy(chunk);
|
||||
parse_streaming_chunk(&s, &mut streaming_acc);
|
||||
|
||||
// Send events through channel if available, otherwise use legacy store
|
||||
if let Some(ref tx) = event_tx {
|
||||
// Function calls → channel event
|
||||
if !streaming_acc.function_calls.is_empty() {
|
||||
let calls: Vec<_> = streaming_acc.function_calls.drain(..).collect();
|
||||
store.set_last_function_calls(calls.clone()).await;
|
||||
store.record_function_call(cascade_hint.as_deref(), calls[0].clone()).await;
|
||||
info!("MITM: sending {} function call(s) via channel (body chunk)", calls.len());
|
||||
let _ = tx.send(super::store::MitmEvent::FunctionCall(calls)).await;
|
||||
}
|
||||
// Thinking delta → channel event (send accumulated, handler tracks last len)
|
||||
if !streaming_acc.thinking_text.is_empty() {
|
||||
let _ = tx.send(super::store::MitmEvent::ThinkingDelta(
|
||||
streaming_acc.thinking_text.clone(),
|
||||
)).await;
|
||||
}
|
||||
// Text delta → channel event (send accumulated, handler tracks last len)
|
||||
if !streaming_acc.response_text.is_empty() {
|
||||
let _ = tx.send(super::store::MitmEvent::TextDelta(
|
||||
streaming_acc.response_text.clone(),
|
||||
)).await;
|
||||
}
|
||||
// Grounding → channel event
|
||||
if let Some(ref gm) = streaming_acc.grounding_metadata {
|
||||
store.set_grounding(gm.clone()).await;
|
||||
let _ = tx.send(super::store::MitmEvent::Grounding(gm.clone())).await;
|
||||
}
|
||||
// Response complete → channel event
|
||||
if streaming_acc.is_complete {
|
||||
info!(
|
||||
response_text_len = streaming_acc.response_text.len(),
|
||||
thinking_text_len = streaming_acc.thinking_text.len(),
|
||||
function_calls = streaming_acc.function_calls.len(),
|
||||
"MITM: response complete — sending via channel"
|
||||
);
|
||||
let _ = tx.send(super::store::MitmEvent::ResponseComplete).await;
|
||||
streaming_acc.is_complete = false; // prevent duplicate sends
|
||||
}
|
||||
} else {
|
||||
// Legacy path: store writes for non-channel consumers
|
||||
if !streaming_acc.function_calls.is_empty() {
|
||||
let calls: Vec<_> = streaming_acc.function_calls.drain(..).collect();
|
||||
for fc in &calls {
|
||||
store.record_function_call(cascade_hint.as_deref(), fc.clone()).await;
|
||||
}
|
||||
store.set_last_function_calls(calls.clone()).await;
|
||||
info!("MITM: stored {} function call(s) from body chunk", calls.len());
|
||||
}
|
||||
if !streaming_acc.response_text.is_empty() {
|
||||
store.set_response_text(&streaming_acc.response_text).await;
|
||||
}
|
||||
if let Some(ref gm) = streaming_acc.grounding_metadata {
|
||||
store.set_grounding(gm.clone()).await;
|
||||
}
|
||||
}
|
||||
dispatch_stream_events(&mut streaming_acc, &event_tx, &store, cascade_hint.as_deref()).await;
|
||||
}
|
||||
|
||||
// Forward chunk to client (LS) — rewrite function calls if custom tools
|
||||
let forward_chunk = if modify_requests {
|
||||
if let Some(modified) = super::modify::modify_response_chunk(chunk) {
|
||||
modified
|
||||
} else {
|
||||
chunk.to_vec()
|
||||
}
|
||||
let forward_chunk = if let Some(ref mut rewriter) = response_rewriter {
|
||||
rewriter.feed(chunk)
|
||||
} else {
|
||||
chunk.to_vec()
|
||||
};
|
||||
if let Err(e) = client.write_all(&forward_chunk).await {
|
||||
warn!(error = %e, "MITM: write to client failed");
|
||||
break;
|
||||
if !forward_chunk.is_empty() {
|
||||
if let Err(e) = client.write_all(&forward_chunk).await {
|
||||
warn!(error = %e, "MITM: write to client failed");
|
||||
break;
|
||||
}
|
||||
}
|
||||
response_body_buf.extend_from_slice(chunk);
|
||||
|
||||
|
||||
if let Some(cl) = response_content_length {
|
||||
if response_body_buf.len() >= cl {
|
||||
break;
|
||||
@@ -992,6 +831,13 @@ async fn handle_http_over_tls(
|
||||
break;
|
||||
}
|
||||
}
|
||||
// Flush any remaining buffered response data through the rewriter
|
||||
if let Some(ref mut rewriter) = response_rewriter {
|
||||
let remaining = rewriter.flush();
|
||||
if !remaining.is_empty() {
|
||||
let _ = client.write_all(&remaining).await;
|
||||
}
|
||||
}
|
||||
|
||||
// Flush client
|
||||
let _ = client.flush().await;
|
||||
@@ -1023,6 +869,176 @@ async fn handle_http_over_tls(
|
||||
} // end keep-alive loop
|
||||
}
|
||||
|
||||
/// Read a complete HTTP request from the client with idle/partial timeouts.
|
||||
///
|
||||
/// Returns `Some(buf)` on success, `None` if the client closed cleanly or timed out.
|
||||
async fn read_full_request(
|
||||
client: &mut tokio_rustls::server::TlsStream<TcpStream>,
|
||||
tmp: &mut [u8],
|
||||
domain: &str,
|
||||
) -> Option<Vec<u8>> {
|
||||
let mut buf = Vec::with_capacity(1024 * 64);
|
||||
const IDLE_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(60);
|
||||
|
||||
loop {
|
||||
let read_result = if buf.is_empty() {
|
||||
match tokio::time::timeout(IDLE_TIMEOUT, client.read(tmp)).await {
|
||||
Ok(r) => r,
|
||||
Err(_) => {
|
||||
debug!(domain, "MITM: client idle timeout (60s), closing");
|
||||
return None;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
match tokio::time::timeout(std::time::Duration::from_secs(30), client.read(tmp)).await {
|
||||
Ok(r) => r,
|
||||
Err(_) => {
|
||||
warn!(domain, "MITM: partial request read timed out");
|
||||
return None;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
let n = match read_result {
|
||||
Ok(0) => return None,
|
||||
Ok(n) => n,
|
||||
Err(e) => {
|
||||
debug!(domain, error = %e, "MITM: client read finished");
|
||||
return None;
|
||||
}
|
||||
};
|
||||
|
||||
buf.extend_from_slice(&tmp[..n]);
|
||||
if has_complete_http_request(&buf) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
Some(buf)
|
||||
}
|
||||
|
||||
/// Connect (or reconnect) to the real upstream via TLS.
|
||||
///
|
||||
/// Bypasses /etc/hosts by resolving via direct DNS query (dig @8.8.8.8),
|
||||
/// then falls back to cached IPs file, then to normal system resolution.
|
||||
async fn connect_upstream(
|
||||
domain: &str,
|
||||
config: &Arc<rustls::ClientConfig>,
|
||||
) -> Result<tokio_rustls::client::TlsStream<TcpStream>, String> {
|
||||
let connector = tokio_rustls::TlsConnector::from(config.clone());
|
||||
let addr = resolve_upstream(domain).await;
|
||||
info!(domain, addr = %addr, "MITM: connecting upstream");
|
||||
|
||||
let tcp = match tokio::time::timeout(
|
||||
std::time::Duration::from_secs(15),
|
||||
TcpStream::connect(&addr),
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(Ok(s)) => s,
|
||||
Ok(Err(e)) => return Err(format!("Connect to upstream {domain} ({addr}): {e}")),
|
||||
Err(_) => return Err(format!("Connect to upstream {domain} ({addr}): timed out")),
|
||||
};
|
||||
|
||||
let server_name = rustls::pki_types::ServerName::try_from(domain.to_string())
|
||||
.map_err(|e| format!("Invalid server name: {e}"))?;
|
||||
|
||||
match tokio::time::timeout(
|
||||
std::time::Duration::from_secs(15),
|
||||
connector.connect(server_name, tcp),
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(Ok(s)) => {
|
||||
info!(domain, "MITM: upstream TLS connected ✓");
|
||||
Ok(s)
|
||||
}
|
||||
Ok(Err(e)) => Err(format!("TLS connect to upstream {domain}: {e}")),
|
||||
Err(_) => Err(format!("TLS connect to upstream {domain}: timed out")),
|
||||
}
|
||||
}
|
||||
|
||||
/// Resolve upstream IP bypassing /etc/hosts.
|
||||
async fn resolve_upstream(domain: &str) -> String {
|
||||
// 1. Try dig @8.8.8.8 (bypasses /etc/hosts)
|
||||
if let Ok(output) = tokio::process::Command::new("dig")
|
||||
.args(["+short", "@8.8.8.8", domain])
|
||||
.output()
|
||||
.await
|
||||
{
|
||||
let out = String::from_utf8_lossy(&output.stdout);
|
||||
if let Some(ip) = out.lines().find(|l| l.parse::<std::net::Ipv4Addr>().is_ok()) {
|
||||
return format!("{ip}:443");
|
||||
}
|
||||
}
|
||||
|
||||
// 2. Try cached IPs file
|
||||
if let Ok(contents) = tokio::fs::read_to_string("/tmp/antigravity-mitm-real-ips").await {
|
||||
for line in contents.lines() {
|
||||
if let Some((d, ip)) = line.split_once('=') {
|
||||
if d == domain {
|
||||
return format!("{ip}:443");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 3. Fallback to normal resolution
|
||||
format!("{domain}:443")
|
||||
}
|
||||
|
||||
/// Dispatch parsed streaming events to the channel or legacy store.
|
||||
///
|
||||
/// Deduplicates the event dispatch logic used both for initial body parsing
|
||||
/// and subsequent body chunk processing.
|
||||
async fn dispatch_stream_events(
|
||||
acc: &mut StreamingAccumulator,
|
||||
event_tx: &Option<tokio::sync::mpsc::Sender<super::store::MitmEvent>>,
|
||||
store: &MitmStore,
|
||||
cascade_hint: Option<&str>,
|
||||
) {
|
||||
if let Some(ref tx) = event_tx {
|
||||
if !acc.function_calls.is_empty() {
|
||||
let calls: Vec<_> = acc.function_calls.drain(..).collect();
|
||||
store.record_function_call(cascade_hint, calls[0].clone()).await;
|
||||
info!("MITM: sending {} function call(s) via channel", calls.len());
|
||||
let _ = tx.send(super::store::MitmEvent::FunctionCall(calls)).await;
|
||||
}
|
||||
if !acc.thinking_text.is_empty() {
|
||||
let _ = tx.send(super::store::MitmEvent::ThinkingDelta(acc.thinking_text.clone())).await;
|
||||
}
|
||||
if !acc.response_text.is_empty() {
|
||||
let _ = tx.send(super::store::MitmEvent::TextDelta(acc.response_text.clone())).await;
|
||||
}
|
||||
if let Some(ref gm) = acc.grounding_metadata {
|
||||
store.set_grounding(gm.clone()).await;
|
||||
let _ = tx.send(super::store::MitmEvent::Grounding(gm.clone())).await;
|
||||
}
|
||||
if acc.is_complete {
|
||||
info!(
|
||||
response_text_len = acc.response_text.len(),
|
||||
thinking_text_len = acc.thinking_text.len(),
|
||||
"MITM: response complete — sending via channel"
|
||||
);
|
||||
let _ = tx.send(super::store::MitmEvent::ResponseComplete).await;
|
||||
acc.is_complete = false;
|
||||
}
|
||||
} else {
|
||||
if !acc.function_calls.is_empty() {
|
||||
let calls: Vec<_> = acc.function_calls.drain(..).collect();
|
||||
for fc in &calls {
|
||||
store.record_function_call(cascade_hint, fc.clone()).await;
|
||||
}
|
||||
info!("MITM: stored {} function call(s)", calls.len());
|
||||
}
|
||||
if !acc.response_text.is_empty() {
|
||||
store.set_response_text(&acc.response_text).await;
|
||||
}
|
||||
if let Some(ref gm) = acc.grounding_metadata {
|
||||
store.set_grounding(gm.clone()).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Handle a passthrough connection: transparent TCP tunnel to upstream.
|
||||
async fn handle_passthrough(mut client: TcpStream, domain: &str, port: u16) -> Result<(), String> {
|
||||
trace!(domain, port, "MITM: transparent tunnel");
|
||||
|
||||
@@ -1,13 +1,13 @@
|
||||
//! Shared store for intercepted API usage data.
|
||||
//!
|
||||
//! The MITM proxy writes usage data here; the API handlers read from it.
|
||||
//! When custom tools are active, the MITM proxy sends real-time events
|
||||
//! through a channel instead of writing to shared state.
|
||||
//! Per-request state is stored in `RequestContext`, keyed by cascade ID.
|
||||
//! The MITM proxy looks up the context when intercepting LS requests,
|
||||
//! enabling concurrent request processing without global locks.
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
use std::sync::atomic::{AtomicBool, Ordering};
|
||||
use std::sync::Arc;
|
||||
use std::time::Instant;
|
||||
use tokio::sync::{mpsc, RwLock};
|
||||
use tracing::{debug, info};
|
||||
|
||||
@@ -52,6 +52,10 @@ pub struct ApiUsage {
|
||||
pub struct CapturedFunctionCall {
|
||||
pub name: String,
|
||||
pub args: serde_json::Value,
|
||||
/// Google's thought signature — required when injecting functionCall back
|
||||
/// into conversation history. Without it, Google returns INVALID_ARGUMENT.
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub thought_signature: Option<String>,
|
||||
pub captured_at: u64,
|
||||
}
|
||||
|
||||
@@ -128,6 +132,25 @@ pub struct GenerationParams {
|
||||
pub google_search: bool,
|
||||
}
|
||||
|
||||
/// Cached context from turn 0 of a cascade.
|
||||
///
|
||||
/// On the first turn, the MITM proxy consumes the `RequestContext` and builds
|
||||
/// a `ToolContext`. On subsequent turns (tool-call loops), the `RequestContext`
|
||||
/// is gone. This cache stores the essential fields so we can rebuild a lite
|
||||
/// `ToolContext` on every turn — ensuring the model always sees the real user
|
||||
/// text and has access to custom tools.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct CascadeCache {
|
||||
/// The real user text (used to replace the "." dot prompt).
|
||||
pub user_text: String,
|
||||
/// Custom tool definitions (Gemini format).
|
||||
pub tools: Option<Vec<serde_json::Value>>,
|
||||
/// Custom tool config.
|
||||
pub tool_config: Option<serde_json::Value>,
|
||||
/// Client generation parameters.
|
||||
pub generation_params: Option<GenerationParams>,
|
||||
}
|
||||
|
||||
// ─── Channel-based event pipeline ────────────────────────────────────────────
|
||||
|
||||
/// Events sent from the MITM proxy to API handlers through a per-request channel.
|
||||
@@ -146,15 +169,53 @@ pub enum MitmEvent {
|
||||
/// Google API returned an error.
|
||||
UpstreamError(UpstreamError),
|
||||
/// Grounding metadata (search results) from the response.
|
||||
#[allow(dead_code)]
|
||||
Grounding(serde_json::Value),
|
||||
/// Token usage data from the response.
|
||||
Usage(ApiUsage),
|
||||
}
|
||||
|
||||
// ─── Per-request context ─────────────────────────────────────────────────────
|
||||
|
||||
/// All per-request state. Keyed by cascade ID in `MitmStore.pending_requests`.
|
||||
///
|
||||
/// API handlers build this before `send_message`, and the MITM proxy consumes
|
||||
/// it when the LS's outbound request is intercepted.
|
||||
#[derive(Debug)]
|
||||
pub struct RequestContext {
|
||||
/// Cascade ID this context belongs to.
|
||||
pub cascade_id: String,
|
||||
/// Real user text for MITM injection (LS receives "." instead).
|
||||
pub pending_user_text: String,
|
||||
/// Event channel for real-time streaming from MITM → API handler.
|
||||
/// Only present when custom tools are active.
|
||||
pub event_channel: Option<mpsc::Sender<MitmEvent>>,
|
||||
/// Client-specified generation parameters (temperature, top_p, etc.).
|
||||
pub generation_params: Option<GenerationParams>,
|
||||
/// Image to inject into the Google API request.
|
||||
pub pending_image: Option<PendingImage>,
|
||||
/// Gemini-format tool declarations for MITM injection.
|
||||
pub tools: Option<Vec<serde_json::Value>>,
|
||||
/// Gemini-format toolConfig.
|
||||
pub tool_config: Option<serde_json::Value>,
|
||||
/// Pending tool results to inject as functionResponse.
|
||||
pub pending_tool_results: Vec<PendingToolResult>,
|
||||
/// Multi-round tool call history for history rewriting.
|
||||
pub tool_rounds: Vec<ToolRound>,
|
||||
/// Last captured function calls for history rewriting.
|
||||
pub last_function_calls: Vec<CapturedFunctionCall>,
|
||||
/// Mapping call_id → function name for tool result routing.
|
||||
pub call_id_to_name: HashMap<String, String>,
|
||||
/// When this context was created (for TTL cleanup).
|
||||
pub created_at: Instant,
|
||||
}
|
||||
|
||||
// ─── MitmStore ───────────────────────────────────────────────────────────────
|
||||
|
||||
/// Thread-safe store for intercepted data.
|
||||
///
|
||||
/// Keyed by a unique request ID that we can correlate with cascade operations.
|
||||
/// In practice, we use the cascade ID + a sequence number.
|
||||
/// Per-request state lives in `pending_requests`, keyed by cascade ID.
|
||||
/// Global state (usage stats, function call capture) remains shared.
|
||||
#[derive(Clone)]
|
||||
pub struct MitmStore {
|
||||
/// Most recent usage per cascade ID.
|
||||
@@ -163,62 +224,24 @@ pub struct MitmStore {
|
||||
stats: Arc<RwLock<MitmStats>>,
|
||||
/// Pending function calls captured from Google responses.
|
||||
/// Key: cascade hint or "_latest". Value: list of function calls.
|
||||
/// Used by the non-tool LS path (normal sync responses).
|
||||
pending_function_calls: Arc<RwLock<HashMap<String, Vec<CapturedFunctionCall>>>>,
|
||||
/// Set when the MITM forwards the first LLM request with custom tools.
|
||||
/// Blocks ALL subsequent LS requests until the API handler clears it.
|
||||
request_in_flight: Arc<AtomicBool>,
|
||||
|
||||
// ── Channel-based event pipeline (replaces old polling) ──────────────
|
||||
/// Active channel sender for the current tool-path request.
|
||||
/// When present, the MITM proxy sends events through this instead of
|
||||
/// writing to shared state. The channel's existence = request in-flight.
|
||||
active_channel: Arc<RwLock<Option<mpsc::Sender<MitmEvent>>>>,
|
||||
// ── Per-request state (keyed by cascade ID) ──────────────────────────
|
||||
/// Active request contexts. API handlers register before send_message,
|
||||
/// MITM proxy consumes when intercepting the LS request.
|
||||
pending_requests: Arc<RwLock<HashMap<String, RequestContext>>>,
|
||||
|
||||
// ── Tool call support ────────────────────────────────────────────────
|
||||
/// Active tool definitions (Gemini format) for MITM injection.
|
||||
active_tools: Arc<RwLock<Option<Vec<serde_json::Value>>>>,
|
||||
/// Active tool config (Gemini toolConfig format).
|
||||
active_tool_config: Arc<RwLock<Option<serde_json::Value>>>,
|
||||
/// Pending tool results for MITM to inject as functionResponse.
|
||||
pending_tool_results: Arc<RwLock<Vec<PendingToolResult>>>,
|
||||
/// Mapping call_id → function name for tool result routing.
|
||||
call_id_to_name: Arc<RwLock<HashMap<String, String>>>,
|
||||
/// Last captured function calls (for conversation history rewriting).
|
||||
last_function_calls: Arc<RwLock<Vec<CapturedFunctionCall>>>,
|
||||
/// Multi-round tool call history for correct per-turn history rewriting.
|
||||
/// Set by completions/responses handler, consumed by modify_request.
|
||||
tool_rounds: Arc<RwLock<Vec<ToolRound>>>,
|
||||
|
||||
// ── Cascade correlation ──────────────────────────────────────────────
|
||||
/// Active cascade ID set by the API layer before sending a message.
|
||||
/// Used by the MITM proxy to correlate intercepted traffic to cascades.
|
||||
active_cascade_id: Arc<RwLock<Option<String>>>,
|
||||
/// Cached context from turn 0, keyed by cascade ID.
|
||||
/// Used to rebuild ToolContext on subsequent turns of the same cascade.
|
||||
cascade_cache: Arc<RwLock<HashMap<String, CascadeCache>>>,
|
||||
|
||||
// ── Legacy direct response capture (used by search.rs) ───────────────
|
||||
/// Captured response text from MITM. Used as fallback by search endpoint.
|
||||
captured_response_text: Arc<RwLock<Option<String>>>,
|
||||
|
||||
// ── Generation parameters for MITM injection ─────────────────────────
|
||||
/// Client-specified sampling parameters to inject into Google API requests.
|
||||
generation_params: Arc<RwLock<Option<GenerationParams>>>,
|
||||
|
||||
// ── Grounding metadata capture ──────────────────────────────────────
|
||||
/// Captured grounding metadata from Google API responses (search results).
|
||||
captured_grounding: Arc<RwLock<Option<serde_json::Value>>>,
|
||||
|
||||
// ── Pending image for MITM injection ─────────────────────────────────
|
||||
/// Image to inject into the next Google API request via MITM.
|
||||
pending_image: Arc<RwLock<Option<PendingImage>>>,
|
||||
|
||||
// ── Upstream error capture (legacy, used when no channel) ────────────
|
||||
/// Error from Google's API, captured by MITM for forwarding to client.
|
||||
upstream_error: Arc<RwLock<Option<UpstreamError>>>,
|
||||
|
||||
// ── Standard LS input: real user text for MITM injection ─────────────
|
||||
/// The real user text to inject into the Google API request.
|
||||
/// API handlers store this before sending a dummy prompt to the LS.
|
||||
pending_user_text: Arc<RwLock<Option<String>>>,
|
||||
}
|
||||
|
||||
/// Aggregate statistics across all intercepted traffic.
|
||||
@@ -251,24 +274,106 @@ impl MitmStore {
|
||||
latest_usage: Arc::new(RwLock::new(HashMap::new())),
|
||||
stats: Arc::new(RwLock::new(MitmStats::default())),
|
||||
pending_function_calls: Arc::new(RwLock::new(HashMap::new())),
|
||||
request_in_flight: Arc::new(AtomicBool::new(false)),
|
||||
active_channel: Arc::new(RwLock::new(None)),
|
||||
active_tools: Arc::new(RwLock::new(None)),
|
||||
active_tool_config: Arc::new(RwLock::new(None)),
|
||||
pending_tool_results: Arc::new(RwLock::new(Vec::new())),
|
||||
call_id_to_name: Arc::new(RwLock::new(HashMap::new())),
|
||||
last_function_calls: Arc::new(RwLock::new(Vec::new())),
|
||||
tool_rounds: Arc::new(RwLock::new(Vec::new())),
|
||||
active_cascade_id: Arc::new(RwLock::new(None)),
|
||||
pending_requests: Arc::new(RwLock::new(HashMap::new())),
|
||||
cascade_cache: Arc::new(RwLock::new(HashMap::new())),
|
||||
captured_response_text: Arc::new(RwLock::new(None)),
|
||||
generation_params: Arc::new(RwLock::new(None)),
|
||||
captured_grounding: Arc::new(RwLock::new(None)),
|
||||
pending_image: Arc::new(RwLock::new(None)),
|
||||
upstream_error: Arc::new(RwLock::new(None)),
|
||||
pending_user_text: Arc::new(RwLock::new(None)),
|
||||
}
|
||||
}
|
||||
|
||||
// ── Per-request context management ───────────────────────────────────
|
||||
|
||||
/// Register a request context for a cascade. Called by API handlers
|
||||
/// before `send_message` so the MITM proxy can find it.
|
||||
pub async fn register_request(&self, ctx: RequestContext) {
|
||||
let cascade_id = ctx.cascade_id.clone();
|
||||
info!(cascade = %cascade_id, "Registered request context");
|
||||
self.pending_requests.write().await.insert(cascade_id, ctx);
|
||||
}
|
||||
|
||||
/// Take (consume) the request context for a cascade.
|
||||
/// Called by the MITM proxy when intercepting the LS's outbound request.
|
||||
pub async fn take_request(&self, cascade_id: &str) -> Option<RequestContext> {
|
||||
let ctx = self.pending_requests.write().await.remove(cascade_id);
|
||||
if ctx.is_some() {
|
||||
debug!(cascade = %cascade_id, "Took request context");
|
||||
}
|
||||
ctx
|
||||
}
|
||||
|
||||
/// Take the most recently registered request context (by creation time).
|
||||
/// Fallback when cascade_id can't be extracted from the Google API request.
|
||||
pub async fn take_latest_request(&self) -> Option<RequestContext> {
|
||||
let mut pending = self.pending_requests.write().await;
|
||||
if pending.is_empty() {
|
||||
return None;
|
||||
}
|
||||
// Find the most recently created request
|
||||
let latest_key = pending
|
||||
.iter()
|
||||
.max_by_key(|(_, ctx)| ctx.created_at)
|
||||
.map(|(k, _)| k.clone());
|
||||
if let Some(key) = latest_key {
|
||||
let ctx = pending.remove(&key);
|
||||
if ctx.is_some() {
|
||||
debug!(cascade = %key, "Took latest request context (fallback)");
|
||||
}
|
||||
ctx
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
/// Update a request context in-place. Returns false if not found.
|
||||
pub async fn update_request<F>(&self, cascade_id: &str, updater: F) -> bool
|
||||
where
|
||||
F: FnOnce(&mut RequestContext),
|
||||
{
|
||||
let mut map = self.pending_requests.write().await;
|
||||
if let Some(ctx) = map.get_mut(cascade_id) {
|
||||
updater(ctx);
|
||||
true
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
/// Remove a request context (cleanup after response is complete).
|
||||
pub async fn remove_request(&self, cascade_id: &str) {
|
||||
if self.pending_requests.write().await.remove(cascade_id).is_some() {
|
||||
debug!(cascade = %cascade_id, "Removed request context");
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
// ── Cascade cache (turn 0 context for re-injection on turn 1+) ──────
|
||||
|
||||
/// Cache the essential context from turn 0 so it can be re-used on
|
||||
/// subsequent turns of the same cascade.
|
||||
pub async fn cache_cascade(&self, cascade_id: &str, cache: CascadeCache) {
|
||||
debug!(cascade = %cascade_id, user_text_len = cache.user_text.len(),
|
||||
has_tools = cache.tools.is_some(),
|
||||
"Cached cascade context for subsequent turns");
|
||||
self.cascade_cache.write().await.insert(cascade_id.to_string(), cache);
|
||||
}
|
||||
|
||||
/// Get cached context for a cascade (non-consuming — needed on every turn).
|
||||
pub async fn get_cascade_cache(&self, cascade_id: &str) -> Option<CascadeCache> {
|
||||
self.cascade_cache.read().await.get(cascade_id).cloned()
|
||||
}
|
||||
|
||||
/// Check if a cascade has been processed (turn 0 complete).
|
||||
pub async fn has_cascade_cache(&self, cascade_id: &str) -> bool {
|
||||
self.cascade_cache.read().await.contains_key(cascade_id)
|
||||
}
|
||||
|
||||
|
||||
|
||||
// ── Usage recording ──────────────────────────────────────────────────
|
||||
|
||||
/// Record a completed API exchange with usage data.
|
||||
pub async fn record_usage(&self, cascade_id: Option<&str>, usage: ApiUsage) {
|
||||
debug!(
|
||||
@@ -314,13 +419,7 @@ impl MitmStore {
|
||||
// Call 2: thinking summary text (thinking_output_tokens == 0, response_text has the summary)
|
||||
//
|
||||
// When Call 2 arrives, we merge its response_text as thinking_text into Call 1's usage.
|
||||
let key = if let Some(cid) = cascade_id {
|
||||
cid.to_string()
|
||||
} else if let Some(active) = self.active_cascade_id.read().await.as_ref() {
|
||||
active.clone()
|
||||
} else {
|
||||
"_latest".to_string()
|
||||
};
|
||||
let key = cascade_id.map_or_else(|| "_latest".to_string(), |s| s.to_string());
|
||||
let mut latest = self.latest_usage.write().await;
|
||||
|
||||
if let Some(existing) = latest.get_mut(&key) {
|
||||
@@ -346,7 +445,6 @@ impl MitmStore {
|
||||
// Evict old entries to prevent unbounded memory growth
|
||||
const MAX_ENTRIES: usize = 500;
|
||||
if latest.len() > MAX_ENTRIES {
|
||||
// Find the oldest entry by captured_at and remove it
|
||||
let oldest_key = latest
|
||||
.iter()
|
||||
.min_by_key(|(_, v)| v.captured_at)
|
||||
@@ -357,18 +455,13 @@ impl MitmStore {
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the latest usage for a cascade, consuming it (one-shot read).
|
||||
///
|
||||
/// Peek at usage data for a cascade without consuming it.
|
||||
/// Used to check if thinking text has been merged before taking.
|
||||
pub async fn peek_usage(&self, cascade_id: &str) -> Option<ApiUsage> {
|
||||
let latest = self.latest_usage.read().await;
|
||||
latest.get(cascade_id).cloned()
|
||||
}
|
||||
|
||||
/// Only returns exact cascade_id matches — no cross-cascade fallback.
|
||||
/// The `_latest` key is only consumed when the caller explicitly requests it
|
||||
/// (i.e., when the MITM couldn't identify the cascade).
|
||||
pub async fn take_usage(&self, cascade_id: &str) -> Option<ApiUsage> {
|
||||
let mut latest = self.latest_usage.write().await;
|
||||
latest.remove(cascade_id)
|
||||
@@ -379,19 +472,11 @@ impl MitmStore {
|
||||
self.stats.read().await.clone()
|
||||
}
|
||||
|
||||
// ── Function call capture ────────────────────────────────────────────
|
||||
|
||||
/// Record a captured function call from Google's response.
|
||||
///
|
||||
/// Falls back to `active_cascade_id` (set by the API handler) when no
|
||||
/// cascade hint is available from the request body, matching
|
||||
/// `record_usage`'s fallback behavior for consistent correlation.
|
||||
pub async fn record_function_call(&self, cascade_id: Option<&str>, fc: CapturedFunctionCall) {
|
||||
let key = if let Some(cid) = cascade_id {
|
||||
cid.to_string()
|
||||
} else if let Some(active) = self.active_cascade_id.read().await.as_ref() {
|
||||
active.clone()
|
||||
} else {
|
||||
"_latest".to_string()
|
||||
};
|
||||
let key = cascade_id.map_or_else(|| "_latest".to_string(), |s| s.to_string());
|
||||
info!(
|
||||
cascade = %key,
|
||||
tool = %fc.name,
|
||||
@@ -404,9 +489,7 @@ impl MitmStore {
|
||||
|
||||
/// Take pending function calls for a specific cascade.
|
||||
///
|
||||
/// Priority: exact cascade_id → active_cascade_id → `_latest` → any key.
|
||||
/// This prevents cross-cascade contamination when multiple requests are
|
||||
/// in-flight simultaneously.
|
||||
/// Priority: exact cascade_id → `_latest` → any key.
|
||||
pub async fn take_function_calls(&self, cascade_id: &str) -> Option<Vec<CapturedFunctionCall>> {
|
||||
let mut pending = self.pending_function_calls.write().await;
|
||||
|
||||
@@ -415,21 +498,12 @@ impl MitmStore {
|
||||
return Some(result);
|
||||
}
|
||||
|
||||
// 2. Active cascade (set by API handler)
|
||||
if let Some(active) = self.active_cascade_id.read().await.as_ref() {
|
||||
if active != cascade_id {
|
||||
if let Some(result) = pending.remove(active.as_str()) {
|
||||
return Some(result);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 3. Fallback to _latest
|
||||
// 2. Fallback to _latest
|
||||
if let Some(result) = pending.remove("_latest") {
|
||||
return Some(result);
|
||||
}
|
||||
|
||||
// 4. Last resort: any key
|
||||
// 3. Last resort: any key
|
||||
if let Some(key) = pending.keys().next().cloned() {
|
||||
return pending.remove(&key);
|
||||
}
|
||||
@@ -438,7 +512,6 @@ impl MitmStore {
|
||||
}
|
||||
|
||||
/// Take any pending function calls (ignoring cascade ID).
|
||||
/// Legacy method — prefer `take_function_calls(cascade_id)` for proper correlation.
|
||||
pub async fn take_any_function_calls(&self) -> Option<Vec<CapturedFunctionCall>> {
|
||||
let mut pending = self.pending_function_calls.write().await;
|
||||
let result = pending.remove("_latest");
|
||||
@@ -451,114 +524,24 @@ impl MitmStore {
|
||||
None
|
||||
}
|
||||
|
||||
// ── Channel-based event pipeline ─────────────────────────────────────
|
||||
|
||||
/// Install a channel sender for the current tool-path request.
|
||||
/// The MITM proxy will send events through this channel.
|
||||
pub async fn set_channel(&self, tx: mpsc::Sender<MitmEvent>) {
|
||||
*self.active_channel.write().await = Some(tx);
|
||||
// NOTE: Do NOT set request_in_flight here. The MITM proxy's
|
||||
// try_mark_request_in_flight() is the sole setter — setting it
|
||||
// here causes compare_exchange(false,true) to always fail,
|
||||
// blocking every real LS request.
|
||||
}
|
||||
|
||||
/// Take the active channel sender (used by MITM proxy to grab it).
|
||||
/// Returns None if no channel is active.
|
||||
pub async fn take_channel(&self) -> Option<mpsc::Sender<MitmEvent>> {
|
||||
self.active_channel.write().await.take()
|
||||
}
|
||||
|
||||
|
||||
/// Drop the active channel and clear in-flight state.
|
||||
/// Called when the API handler is done with the current request.
|
||||
pub async fn drop_channel(&self) {
|
||||
*self.active_channel.write().await = None;
|
||||
self.request_in_flight.store(false, Ordering::SeqCst);
|
||||
}
|
||||
|
||||
// ── Tool context methods ─────────────────────────────────────────────
|
||||
|
||||
/// Set active tool definitions (already in Gemini format).
|
||||
pub async fn set_tools(&self, tools: Vec<serde_json::Value>) {
|
||||
*self.active_tools.write().await = Some(tools);
|
||||
}
|
||||
|
||||
/// Get active tool definitions.
|
||||
pub async fn get_tools(&self) -> Option<Vec<serde_json::Value>> {
|
||||
self.active_tools.read().await.clone()
|
||||
}
|
||||
|
||||
/// Clear active tool definitions.
|
||||
pub async fn clear_tools(&self) {
|
||||
*self.active_tools.write().await = None;
|
||||
*self.active_tool_config.write().await = None;
|
||||
// Also clear accumulated tool rounds to prevent stale data
|
||||
self.tool_rounds.write().await.clear();
|
||||
}
|
||||
|
||||
/// Set active tool config (Gemini toolConfig format).
|
||||
pub async fn set_tool_config(&self, config: serde_json::Value) {
|
||||
*self.active_tool_config.write().await = Some(config);
|
||||
}
|
||||
|
||||
/// Get active tool config.
|
||||
pub async fn get_tool_config(&self) -> Option<serde_json::Value> {
|
||||
self.active_tool_config.read().await.clone()
|
||||
}
|
||||
|
||||
/// Add a pending tool result for MITM injection.
|
||||
pub async fn add_tool_result(&self, result: PendingToolResult) {
|
||||
info!(name = %result.name, "Storing pending tool result");
|
||||
self.pending_tool_results.write().await.push(result);
|
||||
}
|
||||
|
||||
/// Take (consume) all pending tool results.
|
||||
pub async fn take_tool_results(&self) -> Vec<PendingToolResult> {
|
||||
std::mem::take(&mut *self.pending_tool_results.write().await)
|
||||
}
|
||||
|
||||
/// Register a call_id → function name mapping.
|
||||
pub async fn register_call_id(&self, call_id: String, name: String) {
|
||||
self.call_id_to_name.write().await.insert(call_id, name);
|
||||
}
|
||||
|
||||
/// Look up function name by call_id.
|
||||
pub async fn lookup_call_id(&self, call_id: &str) -> Option<String> {
|
||||
self.call_id_to_name.read().await.get(call_id).cloned()
|
||||
}
|
||||
|
||||
/// Save the last captured function calls (for history rewriting).
|
||||
pub async fn set_last_function_calls(&self, calls: Vec<CapturedFunctionCall>) {
|
||||
*self.last_function_calls.write().await = calls;
|
||||
}
|
||||
|
||||
/// Get the last captured function calls.
|
||||
pub async fn get_last_function_calls(&self) -> Vec<CapturedFunctionCall> {
|
||||
self.last_function_calls.read().await.clone()
|
||||
}
|
||||
|
||||
/// Store multi-round tool call history for correct per-turn history rewriting.
|
||||
pub async fn set_tool_rounds(&self, rounds: Vec<ToolRound>) {
|
||||
*self.tool_rounds.write().await = rounds;
|
||||
}
|
||||
|
||||
/// Take (consume) multi-round tool call history.
|
||||
pub async fn take_tool_rounds(&self) -> Vec<ToolRound> {
|
||||
std::mem::take(&mut *self.tool_rounds.write().await)
|
||||
}
|
||||
|
||||
/// Get (non-destructive clone) multi-round tool call history.
|
||||
/// Used by proxy.rs to read rounds without consuming them, so they
|
||||
/// persist across multiple LS requests in the same cascade.
|
||||
pub async fn get_tool_rounds(&self) -> Vec<ToolRound> {
|
||||
self.tool_rounds.read().await.clone()
|
||||
/// Peek at the thought_signatures of recently captured function calls.
|
||||
/// Returns a map of function_name → thought_signature (non-destructive).
|
||||
pub async fn peek_thought_signatures(&self) -> std::collections::HashMap<String, String> {
|
||||
let pending = self.pending_function_calls.read().await;
|
||||
let mut sigs = std::collections::HashMap::new();
|
||||
for calls in pending.values() {
|
||||
for fc in calls {
|
||||
if let Some(ref sig) = fc.thought_signature {
|
||||
sigs.insert(fc.name.clone(), sig.clone());
|
||||
}
|
||||
}
|
||||
}
|
||||
sigs
|
||||
}
|
||||
|
||||
// ── Legacy direct response capture (search.rs fallback) ──────────────
|
||||
|
||||
/// Set (replace) the captured response text.
|
||||
/// Used by MITM proxy for non-channel path (search endpoint fallback).
|
||||
pub async fn set_response_text(&self, text: &str) {
|
||||
*self.captured_response_text.write().await = Some(text.to_string());
|
||||
}
|
||||
@@ -568,71 +551,11 @@ impl MitmStore {
|
||||
self.captured_response_text.write().await.take()
|
||||
}
|
||||
|
||||
/// Clear stale state between requests.
|
||||
/// Drops any active channel and clears in-flight flags.
|
||||
/// Clear stale legacy response state.
|
||||
pub async fn clear_response_async(&self) {
|
||||
self.request_in_flight.store(false, Ordering::SeqCst);
|
||||
*self.active_channel.write().await = None;
|
||||
*self.captured_response_text.write().await = None;
|
||||
}
|
||||
|
||||
/// Atomically try to mark request as in-flight.
|
||||
/// Returns true if this caller won the race (was first to set it).
|
||||
/// Returns false if already in-flight (someone else set it first).
|
||||
pub fn try_mark_request_in_flight(&self) -> bool {
|
||||
self.request_in_flight
|
||||
.compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst)
|
||||
.is_ok()
|
||||
}
|
||||
|
||||
/// Check if a request is currently in-flight.
|
||||
#[allow(dead_code)]
|
||||
pub fn is_request_in_flight(&self) -> bool {
|
||||
self.request_in_flight.load(Ordering::SeqCst)
|
||||
}
|
||||
|
||||
/// Clear the in-flight flag so the LS can make follow-up requests.
|
||||
pub fn clear_request_in_flight(&self) {
|
||||
self.request_in_flight.store(false, Ordering::SeqCst);
|
||||
}
|
||||
|
||||
// ── Cascade correlation ──────────────────────────────────────────────
|
||||
|
||||
/// Set the active cascade ID (called by API handlers before sending a message).
|
||||
/// The MITM proxy will use this to correlate intercepted traffic.
|
||||
pub async fn set_active_cascade(&self, cascade_id: &str) {
|
||||
*self.active_cascade_id.write().await = Some(cascade_id.to_string());
|
||||
}
|
||||
|
||||
/// Get the active cascade ID.
|
||||
#[allow(dead_code)]
|
||||
pub async fn get_active_cascade(&self) -> Option<String> {
|
||||
self.active_cascade_id.read().await.clone()
|
||||
}
|
||||
|
||||
/// Clear the active cascade ID (called after response is complete).
|
||||
#[allow(dead_code)]
|
||||
pub async fn clear_active_cascade(&self) {
|
||||
*self.active_cascade_id.write().await = None;
|
||||
}
|
||||
|
||||
// ── Generation parameters ────────────────────────────────────────────
|
||||
|
||||
/// Store client-specified generation parameters for MITM injection.
|
||||
pub async fn set_generation_params(&self, params: GenerationParams) {
|
||||
*self.generation_params.write().await = Some(params);
|
||||
}
|
||||
|
||||
/// Read current generation parameters (non-consuming).
|
||||
pub async fn get_generation_params(&self) -> Option<GenerationParams> {
|
||||
self.generation_params.read().await.clone()
|
||||
}
|
||||
|
||||
/// Clear generation parameters.
|
||||
pub async fn clear_generation_params(&self) {
|
||||
*self.generation_params.write().await = None;
|
||||
}
|
||||
|
||||
// ── Grounding metadata capture ──────────────────────────────────────
|
||||
|
||||
/// Store captured grounding metadata from API response.
|
||||
@@ -652,46 +575,35 @@ impl MitmStore {
|
||||
self.captured_grounding.read().await.clone()
|
||||
}
|
||||
|
||||
// ── Pending image for MITM injection ─────────────────────────────────
|
||||
// ── Compat shims for streaming tool-call loops ──────────────────────
|
||||
|
||||
/// Store a pending image for MITM injection.
|
||||
pub async fn set_pending_image(&self, image: PendingImage) {
|
||||
*self.pending_image.write().await = Some(image);
|
||||
/// Update the event channel on an existing request context.
|
||||
/// Used by streaming loop handlers when re-registering for a new tool round.
|
||||
pub async fn set_channel(&self, cascade_id: &str, tx: mpsc::Sender<MitmEvent>) {
|
||||
self.update_request(cascade_id, |ctx| {
|
||||
ctx.event_channel = Some(tx);
|
||||
}).await;
|
||||
}
|
||||
|
||||
/// Take (consume) pending image for injection.
|
||||
pub async fn take_pending_image(&self) -> Option<PendingImage> {
|
||||
self.pending_image.write().await.take()
|
||||
}
|
||||
|
||||
// ── Upstream error capture ───────────────────────────────────────────
|
||||
|
||||
/// Store an upstream error from Google's API.
|
||||
pub async fn set_upstream_error(&self, error: UpstreamError) {
|
||||
*self.upstream_error.write().await = Some(error);
|
||||
}
|
||||
|
||||
/// Take (consume) captured upstream error.
|
||||
pub async fn take_upstream_error(&self) -> Option<UpstreamError> {
|
||||
self.upstream_error.write().await.take()
|
||||
}
|
||||
|
||||
/// Clear any stored upstream error.
|
||||
/// No-op. Upstream errors are now delivered through the event channel.
|
||||
/// Kept for API handler compatibility.
|
||||
pub async fn clear_upstream_error(&self) {
|
||||
*self.upstream_error.write().await = None;
|
||||
// Intentionally empty — errors flow through MitmEvent::UpstreamError
|
||||
}
|
||||
|
||||
// ── Pending user text for MITM injection ─────────────────────────────
|
||||
|
||||
/// Store the real user text for MITM injection.
|
||||
/// Called by API handlers before sending a dummy prompt to the LS.
|
||||
pub async fn set_pending_user_text(&self, text: String) {
|
||||
*self.pending_user_text.write().await = Some(text);
|
||||
/// Returns None. Upstream errors are now captured and delivered via the
|
||||
/// per-request event channel rather than stored globally.
|
||||
pub async fn take_upstream_error(&self) -> Option<UpstreamError> {
|
||||
None
|
||||
}
|
||||
|
||||
/// Take (consume) the pending user text.
|
||||
/// Called by the MITM proxy when building ToolContext.
|
||||
pub async fn take_pending_user_text(&self) -> Option<String> {
|
||||
self.pending_user_text.write().await.take()
|
||||
/// Store a call_id → function_name mapping in the request context.
|
||||
/// Used by streaming tool-call loops when the model returns function calls.
|
||||
pub async fn register_call_id(&self, cascade_id: &str, call_id: String, name: String) {
|
||||
self.update_request(cascade_id, |ctx| {
|
||||
ctx.call_id_to_name.insert(call_id, name);
|
||||
}).await;
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
|
||||
@@ -9,6 +9,10 @@
|
||||
//! carries the `detect_and_use_proxy` enum, model selection, and version info.
|
||||
//! See `docs/ls-binary-analysis.md` for the full proto schema reverse engineering.
|
||||
|
||||
pub mod wire;
|
||||
|
||||
|
||||
|
||||
use crate::constants::{client_version, CLIENT_NAME};
|
||||
|
||||
// ─── Wire primitives ────────────────────────────────────────────────────────
|
||||
159
src/proto/wire.rs
Normal file
159
src/proto/wire.rs
Normal file
@@ -0,0 +1,159 @@
|
||||
//! Shared protobuf wire-format primitives — decode + encode.
|
||||
//!
|
||||
//! This module is the single source of truth for varint encoding/decoding,
|
||||
//! proto string encoding/extraction, etc. All other modules should import
|
||||
//! from here instead of rolling their own.
|
||||
|
||||
/// Decode a varint from a byte slice. Returns `(value, bytes_consumed)`.
|
||||
///
|
||||
/// This is the canonical decoder — all other modules should use this.
|
||||
pub fn decode_varint(buf: &[u8]) -> Option<(u64, usize)> {
|
||||
let mut result: u64 = 0;
|
||||
let mut shift = 0u32;
|
||||
for (i, &byte) in buf.iter().enumerate() {
|
||||
if i >= 10 {
|
||||
return None; // Too many bytes for a varint
|
||||
}
|
||||
result |= ((byte & 0x7F) as u64) << shift;
|
||||
if byte & 0x80 == 0 {
|
||||
return Some((result, i + 1));
|
||||
}
|
||||
shift += 7;
|
||||
if shift >= 64 {
|
||||
return None;
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
|
||||
|
||||
/// Encode a varint into an existing buffer.
|
||||
pub fn encode_varint(buf: &mut Vec<u8>, mut val: u64) {
|
||||
loop {
|
||||
let byte = (val & 0x7F) as u8;
|
||||
val >>= 7;
|
||||
if val == 0 {
|
||||
buf.push(byte);
|
||||
break;
|
||||
}
|
||||
buf.push(byte | 0x80);
|
||||
}
|
||||
}
|
||||
|
||||
/// Encode a string/bytes value as a protobuf length-delimited field.
|
||||
///
|
||||
/// Produces: `[tag(field_num, wire_type=2)] [len] [data]`
|
||||
pub fn encode_proto_string(field_num: u32, data: &[u8]) -> Vec<u8> {
|
||||
let tag = (field_num << 3) | 2; // wire type 2 = length-delimited
|
||||
let mut buf = Vec::with_capacity(1 + 5 + data.len());
|
||||
encode_varint(&mut buf, tag as u64);
|
||||
encode_varint(&mut buf, data.len() as u64);
|
||||
buf.extend_from_slice(data);
|
||||
buf
|
||||
}
|
||||
|
||||
/// Extract a string field from raw protobuf bytes by field number.
|
||||
///
|
||||
/// Walks top-level fields, skipping varints, 64-bit, 32-bit, and other
|
||||
/// length-delimited fields until the target field number is found.
|
||||
/// Only returns the first occurrence.
|
||||
pub fn extract_proto_string(buf: &[u8], target_field: u32) -> Option<String> {
|
||||
let mut i = 0;
|
||||
while i < buf.len() {
|
||||
let (tag, consumed) = decode_varint(&buf[i..])?;
|
||||
i += consumed;
|
||||
let field_num = (tag >> 3) as u32;
|
||||
let wire_type = (tag & 0x07) as u8;
|
||||
|
||||
match wire_type {
|
||||
0 => {
|
||||
// Varint — skip
|
||||
let (_, c) = decode_varint(&buf[i..])?;
|
||||
i += c;
|
||||
}
|
||||
1 => {
|
||||
// 64-bit fixed — skip 8 bytes
|
||||
if i + 8 > buf.len() {
|
||||
return None;
|
||||
}
|
||||
i += 8;
|
||||
}
|
||||
2 => {
|
||||
// Length-delimited
|
||||
let (len, c) = decode_varint(&buf[i..])?;
|
||||
i += c;
|
||||
let len = len as usize;
|
||||
if i + len > buf.len() {
|
||||
return None;
|
||||
}
|
||||
if field_num == target_field {
|
||||
return String::from_utf8(buf[i..i + len].to_vec()).ok();
|
||||
}
|
||||
i += len;
|
||||
}
|
||||
5 => {
|
||||
// 32-bit fixed — skip 4 bytes
|
||||
if i + 4 > buf.len() {
|
||||
return None;
|
||||
}
|
||||
i += 4;
|
||||
}
|
||||
_ => return None, // Unknown wire type
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_decode_varint_basic() {
|
||||
assert_eq!(decode_varint(&[0x00]), Some((0, 1)));
|
||||
assert_eq!(decode_varint(&[0x01]), Some((1, 1)));
|
||||
assert_eq!(decode_varint(&[0x7F]), Some((127, 1)));
|
||||
assert_eq!(decode_varint(&[0x80, 0x01]), Some((128, 2)));
|
||||
assert_eq!(decode_varint(&[0x96, 0x01]), Some((150, 2)));
|
||||
assert_eq!(decode_varint(&[0xAC, 0x02]), Some((300, 2)));
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
#[test]
|
||||
fn test_encode_decode_roundtrip() {
|
||||
for val in [0u64, 1, 127, 128, 300, 1026, u32::MAX as u64, u64::MAX] {
|
||||
let mut buf = Vec::new();
|
||||
encode_varint(&mut buf, val);
|
||||
let (decoded, consumed) = decode_varint(&buf).unwrap();
|
||||
assert_eq!(decoded, val, "roundtrip failed for {val}");
|
||||
assert_eq!(consumed, buf.len());
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_encode_proto_string() {
|
||||
let result = encode_proto_string(1, b"hello");
|
||||
// tag(1,2) = 0x0A, len=5, h,e,l,l,o
|
||||
assert_eq!(result[0], 0x0A);
|
||||
assert_eq!(result[1], 0x05);
|
||||
assert_eq!(&result[2..], b"hello");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_extract_proto_string() {
|
||||
// Build: field 1 = "abc", field 2 (varint) = 42, field 3 = "xyz"
|
||||
let mut buf = Vec::new();
|
||||
buf.extend_from_slice(&encode_proto_string(1, b"abc"));
|
||||
// field 2 varint 42: tag = (2<<3)|0 = 0x10, value = 0x2A
|
||||
buf.push(0x10);
|
||||
buf.push(0x2A);
|
||||
buf.extend_from_slice(&encode_proto_string(3, b"xyz"));
|
||||
|
||||
assert_eq!(extract_proto_string(&buf, 1), Some("abc".to_string()));
|
||||
assert_eq!(extract_proto_string(&buf, 3), Some("xyz".to_string()));
|
||||
assert_eq!(extract_proto_string(&buf, 99), None);
|
||||
}
|
||||
}
|
||||
@@ -8,7 +8,6 @@ use std::collections::HashMap;
|
||||
use std::time::Instant;
|
||||
use tokio::sync::RwLock;
|
||||
|
||||
const DEFAULT_SESSION: &str = "__default__";
|
||||
const SESSION_TTL_SECS: u64 = 3600 * 4; // 4 hours
|
||||
|
||||
#[derive(Clone)]
|
||||
@@ -23,10 +22,7 @@ pub struct SessionManager {
|
||||
sessions: RwLock<HashMap<String, Session>>,
|
||||
}
|
||||
|
||||
/// Result of session resolution.
|
||||
pub struct SessionResult {
|
||||
pub cascade_id: String,
|
||||
}
|
||||
|
||||
|
||||
impl SessionManager {
|
||||
pub fn new() -> Self {
|
||||
@@ -35,82 +31,7 @@ impl SessionManager {
|
||||
}
|
||||
}
|
||||
|
||||
/// Get existing cascade for session, or create a new one.
|
||||
///
|
||||
/// - `session_id = None` → use default session
|
||||
/// - `session_id = Some("new")` → always create fresh cascade
|
||||
/// - `session_id = Some("my-task")` → reuse cascade for that task
|
||||
///
|
||||
/// Uses double-check locking to avoid TOCTOU races: after creating a cascade,
|
||||
/// re-acquires the lock and checks if another request raced us.
|
||||
pub async fn get_or_create<F, Fut>(
|
||||
&self,
|
||||
session_id: Option<&str>,
|
||||
create_fn: F,
|
||||
) -> Result<SessionResult, String>
|
||||
where
|
||||
F: FnOnce() -> Fut,
|
||||
Fut: std::future::Future<Output = Result<String, String>>,
|
||||
{
|
||||
// "new" always creates a fresh cascade
|
||||
if session_id == Some("new") {
|
||||
let cascade_id = create_fn().await?;
|
||||
let new_sid = format!("s-{}", &uuid::Uuid::new_v4().to_string()[..8]);
|
||||
let mut sessions = self.sessions.write().await;
|
||||
sessions.insert(
|
||||
new_sid.clone(),
|
||||
Session {
|
||||
cascade_id: cascade_id.clone(),
|
||||
created: Instant::now(),
|
||||
last_used: Instant::now(),
|
||||
msg_count: 0,
|
||||
},
|
||||
);
|
||||
return Ok(SessionResult { cascade_id });
|
||||
}
|
||||
|
||||
let sid = session_id.unwrap_or(DEFAULT_SESSION).to_string();
|
||||
|
||||
// Check existing — only need write lock for cleanup + mutation
|
||||
{
|
||||
let mut sessions = self.sessions.write().await;
|
||||
cleanup_expired(&mut sessions);
|
||||
if let Some(sess) = sessions.get_mut(&sid) {
|
||||
sess.last_used = Instant::now();
|
||||
sess.msg_count += 1;
|
||||
return Ok(SessionResult {
|
||||
cascade_id: sess.cascade_id.clone(),
|
||||
});
|
||||
}
|
||||
}
|
||||
// Lock released before async create_fn
|
||||
|
||||
// Create new cascade (this may take a while — lock is NOT held)
|
||||
let cascade_id = create_fn().await?;
|
||||
|
||||
// Double-check: another request may have raced us and created the same session
|
||||
{
|
||||
let mut sessions = self.sessions.write().await;
|
||||
if let Some(existing) = sessions.get_mut(&sid) {
|
||||
// Another request won the race — use their cascade, discard ours
|
||||
existing.last_used = Instant::now();
|
||||
existing.msg_count += 1;
|
||||
return Ok(SessionResult {
|
||||
cascade_id: existing.cascade_id.clone(),
|
||||
});
|
||||
}
|
||||
sessions.insert(
|
||||
sid.clone(),
|
||||
Session {
|
||||
cascade_id: cascade_id.clone(),
|
||||
created: Instant::now(),
|
||||
last_used: Instant::now(),
|
||||
msg_count: 1,
|
||||
},
|
||||
);
|
||||
}
|
||||
Ok(SessionResult { cascade_id })
|
||||
}
|
||||
|
||||
/// List all active sessions.
|
||||
pub async fn list_sessions(&self) -> serde_json::Value {
|
||||
|
||||
1375
src/standalone.rs
1375
src/standalone.rs
File diff suppressed because it is too large
Load Diff
340
src/standalone/discovery.rs
Normal file
340
src/standalone/discovery.rs
Normal file
@@ -0,0 +1,340 @@
|
||||
//! LS process discovery — finding, inspecting, and managing LS processes.
|
||||
|
||||
use super::{MainLSConfig, LS_USER};
|
||||
use crate::proto::wire::extract_proto_string;
|
||||
use std::net::TcpListener;
|
||||
use std::process::{Command, Stdio};
|
||||
use tracing::info;
|
||||
|
||||
/// Discover only the extension_server_port and csrf_token from the running main LS.
|
||||
///
|
||||
/// This does NOT discover the HTTPS port — we don't need to talk to the real LS,
|
||||
/// only steal its extension server connection info.
|
||||
pub fn discover_main_ls_config() -> Result<MainLSConfig, String> {
|
||||
let pid = find_main_ls_pid()?;
|
||||
|
||||
let cmdline = std::fs::read(format!("/proc/{pid}/cmdline"))
|
||||
.map_err(|e| format!("Can't read cmdline for PID {pid}: {e}"))?;
|
||||
let args: Vec<&[u8]> = cmdline.split(|&b| b == 0).collect();
|
||||
|
||||
let mut csrf = String::new();
|
||||
let mut ext_port = String::new();
|
||||
|
||||
for (i, arg) in args.iter().enumerate() {
|
||||
if let Ok(s) = std::str::from_utf8(arg) {
|
||||
match s {
|
||||
"--csrf_token" | "-csrf_token" => {
|
||||
if let Some(next) = args.get(i + 1) {
|
||||
if let Ok(val) = std::str::from_utf8(next) {
|
||||
csrf = val.to_string();
|
||||
}
|
||||
}
|
||||
}
|
||||
"--extension_server_port" | "-extension_server_port" => {
|
||||
if let Some(next) = args.get(i + 1) {
|
||||
if let Ok(val) = std::str::from_utf8(next) {
|
||||
ext_port = val.to_string();
|
||||
}
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if csrf.is_empty() {
|
||||
return Err("Could not find CSRF token from main LS".to_string());
|
||||
}
|
||||
if ext_port.is_empty() {
|
||||
return Err("Could not find extension_server_port from main LS".to_string());
|
||||
}
|
||||
|
||||
info!(
|
||||
pid,
|
||||
ext_port,
|
||||
csrf_len = csrf.len(),
|
||||
"Discovered main LS config"
|
||||
);
|
||||
|
||||
Ok(MainLSConfig {
|
||||
extension_server_port: ext_port,
|
||||
csrf,
|
||||
})
|
||||
}
|
||||
|
||||
/// Find the PID of the main (real) LS process.
|
||||
///
|
||||
/// Checks `/proc/<pid>/exe` to ensure we find the actual LS binary,
|
||||
/// not bash scripts that happen to mention `language_server_linux` in their args.
|
||||
pub(super) fn find_main_ls_pid() -> Result<String, String> {
|
||||
let proc = std::path::Path::new("/proc");
|
||||
if !proc.exists() {
|
||||
return Err("No /proc filesystem".to_string());
|
||||
}
|
||||
|
||||
let entries = std::fs::read_dir(proc).map_err(|e| format!("Cannot read /proc: {e}"))?;
|
||||
|
||||
for entry in entries.flatten() {
|
||||
let name = entry.file_name();
|
||||
let name_str = name.to_string_lossy();
|
||||
// Only numeric dirs (PIDs)
|
||||
if !name_str.chars().all(|c| c.is_ascii_digit()) {
|
||||
continue;
|
||||
}
|
||||
let exe_link = entry.path().join("exe");
|
||||
if let Ok(target) = std::fs::read_link(&exe_link) {
|
||||
let target_str = target.to_string_lossy().to_string();
|
||||
let target_clean = target_str.trim_end_matches(" (deleted)");
|
||||
// Must be the actual LS binary, not a bash script
|
||||
if target_clean.contains("language_server_linux")
|
||||
|| target_clean.contains("antigravity-language-server")
|
||||
{
|
||||
return Ok(name_str.to_string());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Err("No main LS process found — Antigravity must be running".to_string())
|
||||
}
|
||||
|
||||
/// Find a free TCP port by binding to port 0.
|
||||
pub(super) fn find_free_port() -> Result<u16, String> {
|
||||
let listener =
|
||||
TcpListener::bind("127.0.0.1:0").map_err(|e| format!("Failed to bind for port: {e}"))?;
|
||||
listener
|
||||
.local_addr()
|
||||
.map(|a| a.port())
|
||||
.map_err(|e| format!("Failed to get port: {e}"))
|
||||
}
|
||||
|
||||
/// Check if the dedicated LS system user exists.
|
||||
///
|
||||
/// When the user exists, the proxy spawns the LS as that UID so iptables
|
||||
/// can scope the :443 redirect to only the standalone LS process.
|
||||
pub(super) fn has_ls_user() -> bool {
|
||||
Command::new("id")
|
||||
.args(["-u", LS_USER])
|
||||
.stdout(Stdio::null())
|
||||
.stderr(Stdio::null())
|
||||
.status()
|
||||
.map(|s| s.success())
|
||||
.unwrap_or(false)
|
||||
}
|
||||
|
||||
/// Find the PID of a language_server process owned by a specific user.
|
||||
///
|
||||
/// Used to discover the actual LS process after sudo spawns it as a different user.
|
||||
pub(super) fn find_ls_pid_for_user(user: &str) -> Result<u32, String> {
|
||||
let output = Command::new("pgrep")
|
||||
.args(["-u", user, "-f", "language_server_linux"])
|
||||
.output()
|
||||
.map_err(|e| format!("pgrep failed: {e}"))?;
|
||||
|
||||
let stdout = String::from_utf8_lossy(&output.stdout);
|
||||
stdout
|
||||
.lines()
|
||||
.next()
|
||||
.and_then(|line| line.trim().parse::<u32>().ok())
|
||||
.ok_or_else(|| format!("No LS process found for user {user}"))
|
||||
}
|
||||
|
||||
/// Kill any orphaned standalone LS processes from previous runs.
|
||||
///
|
||||
/// This handles the case where the proxy crashed or was killed without
|
||||
/// properly cleaning up the sudo-spawned LS process.
|
||||
///
|
||||
/// Key insight: the sudoers rule allows running commands AS antigravity-ls
|
||||
/// (`ALL=(antigravity-ls) NOPASSWD: ALL`). A process can send signals to
|
||||
/// other processes with the same UID, so we run `kill` as antigravity-ls
|
||||
/// rather than as root.
|
||||
pub(super) fn cleanup_orphaned_ls() {
|
||||
if !has_ls_user() {
|
||||
return;
|
||||
}
|
||||
|
||||
// Find all LS processes owned by antigravity-ls user
|
||||
let output = Command::new("pgrep")
|
||||
.args(["-u", LS_USER, "-f", "language_server_linux"])
|
||||
.output();
|
||||
|
||||
let pids: Vec<u32> = match output {
|
||||
Ok(out) => String::from_utf8_lossy(&out.stdout)
|
||||
.lines()
|
||||
.filter_map(|l| l.trim().parse().ok())
|
||||
.collect(),
|
||||
Err(_) => return,
|
||||
};
|
||||
|
||||
if pids.is_empty() {
|
||||
return;
|
||||
}
|
||||
|
||||
info!(
|
||||
count = pids.len(),
|
||||
?pids,
|
||||
"Cleaning up orphaned standalone LS processes"
|
||||
);
|
||||
|
||||
// Kill each PID by running `kill` AS the antigravity-ls user.
|
||||
// This works because same-UID processes can signal each other,
|
||||
// and the sudoers rule allows ALL commands as antigravity-ls.
|
||||
for pid in &pids {
|
||||
let ok = Command::new("sudo")
|
||||
.args(["-n", "-u", LS_USER, "kill", "-TERM", &pid.to_string()])
|
||||
.stdout(Stdio::null())
|
||||
.stderr(Stdio::null())
|
||||
.status()
|
||||
.map(|s| s.success())
|
||||
.unwrap_or(false);
|
||||
|
||||
if ok {
|
||||
info!(pid, "Killed orphaned LS process");
|
||||
} else {
|
||||
// Fallback: try as root (needs separate sudoers entry)
|
||||
let _ = Command::new("sudo")
|
||||
.args(["-n", "kill", "-TERM", &pid.to_string()])
|
||||
.stdout(Stdio::null())
|
||||
.stderr(Stdio::null())
|
||||
.status();
|
||||
}
|
||||
}
|
||||
|
||||
// Wait for graceful exit
|
||||
std::thread::sleep(std::time::Duration::from_millis(500));
|
||||
|
||||
// Force-kill any survivors
|
||||
let still_alive = Command::new("pgrep")
|
||||
.args(["-u", LS_USER, "-f", "language_server_linux"])
|
||||
.output()
|
||||
.map(|o| !o.stdout.is_empty())
|
||||
.unwrap_or(false);
|
||||
|
||||
if still_alive {
|
||||
info!("Orphaned LS still alive, force killing");
|
||||
for pid in &pids {
|
||||
let _ = Command::new("sudo")
|
||||
.args(["-n", "-u", LS_USER, "kill", "-KILL", &pid.to_string()])
|
||||
.stdout(Stdio::null())
|
||||
.stderr(Stdio::null())
|
||||
.status();
|
||||
}
|
||||
std::thread::sleep(std::time::Duration::from_millis(300));
|
||||
|
||||
// Final check
|
||||
let still_alive = Command::new("pgrep")
|
||||
.args(["-u", LS_USER, "-f", "language_server_linux"])
|
||||
.output()
|
||||
.map(|o| !o.stdout.is_empty())
|
||||
.unwrap_or(false);
|
||||
|
||||
if still_alive {
|
||||
eprintln!("\n \x1b[1;31m⚠ Cannot kill orphaned LS process\x1b[0m");
|
||||
eprintln!(" Run: \x1b[1msudo pkill -u {LS_USER} -f language_server_linux\x1b[0m\n");
|
||||
}
|
||||
} else {
|
||||
info!("Orphaned LS processes cleaned up");
|
||||
}
|
||||
}
|
||||
|
||||
/// Read OAuth token state directly from Antigravity's state.vscdb.
|
||||
///
|
||||
/// The DB stores the exact Topic proto bytes under key `antigravityUnifiedStateSync.oauthToken`.
|
||||
/// This includes access_token + refresh_token + expiry, allowing the LS to auto-refresh.
|
||||
/// Returns (access_token, topic_proto_bytes) or None if unavailable.
|
||||
pub(super) fn read_oauth_from_state_db() -> Option<(String, Vec<u8>)> {
|
||||
use base64::Engine;
|
||||
|
||||
let home = std::env::var("HOME").ok()?;
|
||||
let db_path = format!("{home}/.config/Antigravity/User/globalStorage/state.vscdb");
|
||||
|
||||
// Check the DB file exists
|
||||
if !std::path::Path::new(&db_path).exists() {
|
||||
return None;
|
||||
}
|
||||
|
||||
// Read the Topic proto (base64-encoded in the DB)
|
||||
let output = std::process::Command::new("sqlite3")
|
||||
.args([
|
||||
&db_path,
|
||||
"SELECT value FROM ItemTable WHERE key='antigravityUnifiedStateSync.oauthToken'",
|
||||
])
|
||||
.output()
|
||||
.ok()?;
|
||||
|
||||
if !output.status.success() {
|
||||
return None;
|
||||
}
|
||||
|
||||
let b64_str = String::from_utf8_lossy(&output.stdout).trim().to_string();
|
||||
if b64_str.is_empty() {
|
||||
return None;
|
||||
}
|
||||
|
||||
// Decode the base64 to get the raw Topic proto bytes
|
||||
let topic_bytes = base64::engine::general_purpose::STANDARD
|
||||
.decode(&b64_str)
|
||||
.ok()?;
|
||||
|
||||
if topic_bytes.is_empty() {
|
||||
return None;
|
||||
}
|
||||
|
||||
// Extract the access_token from the OAuthTokenInfo inside the Topic proto.
|
||||
// The inner value (Row.value) is also base64, containing a serialized OAuthTokenInfo.
|
||||
// For the access_token (used by GetSecretValue), we can read it from the authStatus.
|
||||
let access_token = read_access_token_from_auth_status(&db_path)
|
||||
.or_else(|| extract_access_token_from_topic(&topic_bytes))
|
||||
.unwrap_or_default();
|
||||
|
||||
Some((access_token, topic_bytes))
|
||||
}
|
||||
|
||||
/// Read the current access token from `antigravityAuthStatus` in state.vscdb.
|
||||
/// This JSON object has an `apiKey` field with the latest access token.
|
||||
fn read_access_token_from_auth_status(db_path: &str) -> Option<String> {
|
||||
let output = std::process::Command::new("sqlite3")
|
||||
.args([
|
||||
db_path,
|
||||
"SELECT value FROM ItemTable WHERE key='antigravityAuthStatus'",
|
||||
])
|
||||
.output()
|
||||
.ok()?;
|
||||
|
||||
if !output.status.success() {
|
||||
return None;
|
||||
}
|
||||
|
||||
let json_str = String::from_utf8_lossy(&output.stdout).trim().to_string();
|
||||
// Simple extraction: find "apiKey":"..." pattern
|
||||
let marker = "\"apiKey\":\"";
|
||||
let start = json_str.find(marker)? + marker.len();
|
||||
let end = json_str[start..].find('"')? + start;
|
||||
let api_key = &json_str[start..end];
|
||||
if api_key.starts_with("ya29.") {
|
||||
Some(api_key.to_string())
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
/// Extract access_token from the Topic proto bytes by finding the inner
|
||||
/// base64-encoded OAuthTokenInfo and decoding its first string field.
|
||||
fn extract_access_token_from_topic(topic_bytes: &[u8]) -> Option<String> {
|
||||
use base64::Engine;
|
||||
|
||||
let as_str = String::from_utf8_lossy(topic_bytes);
|
||||
for segment in as_str.split(|c: char| !c.is_alphanumeric() && c != '+' && c != '/' && c != '=')
|
||||
{
|
||||
if segment.len() > 50 {
|
||||
if let Ok(decoded) = base64::engine::general_purpose::STANDARD.decode(segment) {
|
||||
// Use shared proto decoder
|
||||
if let Some(token) = extract_proto_string(&decoded, 1) {
|
||||
if token.starts_with("ya29.") {
|
||||
return Some(token);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
137
src/standalone/mod.rs
Normal file
137
src/standalone/mod.rs
Normal file
@@ -0,0 +1,137 @@
|
||||
//! Standalone Language Server — spawn and lifecycle management.
|
||||
//!
|
||||
//! Launches an isolated LS instance as a child process that the proxy fully owns.
|
||||
//! In **headless** mode, the LS runs completely independently — no running
|
||||
//! Antigravity app required. Extension server is disabled (`port=0`), CSRF is
|
||||
//! self-generated, and MITM uses `HTTPS_PROXY` instead of iptables.
|
||||
|
||||
mod discovery;
|
||||
mod spawn;
|
||||
mod stub;
|
||||
|
||||
use std::process::Command;
|
||||
use tracing::info;
|
||||
use uuid::Uuid;
|
||||
|
||||
// Re-export public API
|
||||
pub use spawn::StandaloneLS;
|
||||
|
||||
/// Default path to the LS binary.
|
||||
const LS_BINARY_PATH: &str =
|
||||
"/usr/share/antigravity/resources/app/extensions/antigravity/bin/language_server_linux_x64";
|
||||
|
||||
/// App root for ANTIGRAVITY_EDITOR_APP_ROOT env var.
|
||||
const APP_ROOT: &str = "/usr/share/antigravity/resources/app";
|
||||
|
||||
/// Data directory for the standalone LS.
|
||||
const DATA_DIR: &str = "/tmp/antigravity-standalone";
|
||||
|
||||
/// System user for UID-scoped iptables isolation.
|
||||
const LS_USER: &str = "antigravity-ls";
|
||||
|
||||
/// Path for the compiled dns_redirect.so preload library.
|
||||
const DNS_REDIRECT_SO_PATH: &str = "/tmp/antigravity-dns-redirect.so";
|
||||
|
||||
/// Source file for the DNS redirect preload library (relative to binary).
|
||||
const DNS_REDIRECT_C_SOURCE: &str = include_str!("../mitm/dns_redirect.c");
|
||||
|
||||
/// Config needed to bootstrap the standalone LS.
|
||||
///
|
||||
/// In normal mode, discovered from the running main LS.
|
||||
/// In headless mode, generated entirely by the proxy.
|
||||
pub struct MainLSConfig {
|
||||
pub extension_server_port: String,
|
||||
pub csrf: String,
|
||||
}
|
||||
|
||||
/// Optional MITM proxy config for the standalone LS.
|
||||
pub struct StandaloneMitmConfig {
|
||||
pub proxy_addr: String, // Full URL with scheme, e.g. "http://127.0.0.1:8742"
|
||||
pub ca_cert_path: String, // path to MITM CA .pem
|
||||
}
|
||||
|
||||
/// Generate a fully self-contained config for headless mode.
|
||||
///
|
||||
/// No running Antigravity instance needed — extension server is disabled
|
||||
/// and CSRF is a random UUID.
|
||||
pub fn generate_standalone_config() -> MainLSConfig {
|
||||
let csrf = Uuid::new_v4().to_string();
|
||||
info!(
|
||||
csrf_len = csrf.len(),
|
||||
"Generated standalone config (headless)"
|
||||
);
|
||||
MainLSConfig {
|
||||
extension_server_port: "0".to_string(), // disables extension server
|
||||
csrf,
|
||||
}
|
||||
}
|
||||
|
||||
/// Discover only the extension_server_port and csrf_token from the running main LS.
|
||||
///
|
||||
/// This does NOT discover the HTTPS port — we don't need to talk to the real LS,
|
||||
/// only steal its extension server connection info.
|
||||
pub fn discover_main_ls_config() -> Result<MainLSConfig, String> {
|
||||
discovery::discover_main_ls_config()
|
||||
}
|
||||
|
||||
/// Build the dns_redirect.so preload library if it doesn't already exist.
|
||||
///
|
||||
/// The library hooks `getaddrinfo()` via LD_PRELOAD to redirect Google API
|
||||
/// domain lookups to 127.0.0.1. This is needed because the LS binary uses
|
||||
/// CGO for DNS resolution (libc getaddrinfo) but raw syscalls for connect(),
|
||||
/// so only DNS can be intercepted via LD_PRELOAD.
|
||||
///
|
||||
/// Returns the path to the .so on success, None on failure.
|
||||
fn build_dns_redirect_so() -> Option<String> {
|
||||
let so_path = DNS_REDIRECT_SO_PATH;
|
||||
|
||||
// Skip rebuild if already exists
|
||||
if std::path::Path::new(so_path).exists() {
|
||||
return Some(so_path.to_string());
|
||||
}
|
||||
|
||||
// Write C source to a temp file
|
||||
let c_path = format!("{so_path}.c");
|
||||
if let Err(e) = std::fs::write(&c_path, DNS_REDIRECT_C_SOURCE) {
|
||||
tracing::warn!("Failed to write dns_redirect.c: {e}");
|
||||
return None;
|
||||
}
|
||||
|
||||
// Compile: gcc -shared -fPIC -o dns_redirect.so dns_redirect.c -ldl
|
||||
let output = Command::new("gcc")
|
||||
.args(["-shared", "-fPIC", "-o", so_path, &c_path, "-ldl"])
|
||||
.output();
|
||||
|
||||
match output {
|
||||
Ok(out) if out.status.success() => {
|
||||
info!("Built dns_redirect.so at {so_path}");
|
||||
// Clean up source
|
||||
let _ = std::fs::remove_file(&c_path);
|
||||
Some(so_path.to_string())
|
||||
}
|
||||
Ok(out) => {
|
||||
let stderr = String::from_utf8_lossy(&out.stderr);
|
||||
tracing::warn!("Failed to compile dns_redirect.so: {stderr}");
|
||||
None
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::warn!("gcc not found, cannot build dns_redirect.so: {e}");
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::discovery::find_free_port;
|
||||
use std::net::TcpListener;
|
||||
|
||||
#[test]
|
||||
fn test_find_free_port() {
|
||||
let port = find_free_port().unwrap();
|
||||
assert!(port > 0);
|
||||
// Port should be available — try binding to it
|
||||
let listener = TcpListener::bind(format!("127.0.0.1:{port}"));
|
||||
assert!(listener.is_ok(), "Port {port} should be free");
|
||||
}
|
||||
}
|
||||
464
src/standalone/spawn.rs
Normal file
464
src/standalone/spawn.rs
Normal file
@@ -0,0 +1,464 @@
|
||||
//! StandaloneLS — process lifecycle (spawn, wait, kill).
|
||||
|
||||
use super::discovery::{cleanup_orphaned_ls, find_free_port, find_ls_pid_for_user, has_ls_user, read_oauth_from_state_db};
|
||||
use super::stub::stub_handle_connection;
|
||||
use super::{build_dns_redirect_so, MainLSConfig, StandaloneMitmConfig, APP_ROOT, DATA_DIR, LS_BINARY_PATH, LS_USER};
|
||||
use crate::constants;
|
||||
use crate::proto;
|
||||
use std::io::Write;
|
||||
use std::net::TcpListener;
|
||||
use std::process::{Child, Command, Stdio};
|
||||
use tokio::time::{sleep, Duration};
|
||||
use tracing::{debug, info};
|
||||
|
||||
/// A running standalone LS process.
|
||||
pub struct StandaloneLS {
|
||||
child: Child,
|
||||
/// The actual LS process PID (may differ from child PID when spawned via sudo).
|
||||
ls_pid: Option<u32>,
|
||||
/// Whether the LS was spawned via sudo (needs sudo kill).
|
||||
use_sudo: bool,
|
||||
/// Whether kill() has already been called.
|
||||
killed: bool,
|
||||
pub port: u16,
|
||||
pub csrf: String,
|
||||
}
|
||||
|
||||
impl StandaloneLS {
|
||||
/// Spawn a standalone LS process.
|
||||
///
|
||||
/// Discovers the main LS's extension server port and CSRF token,
|
||||
/// picks a free port, builds init metadata, and launches the binary.
|
||||
///
|
||||
/// If `mitm_config` is provided, sets HTTPS_PROXY and SSL_CERT_FILE
|
||||
/// so the LS routes LLM API calls through the MITM proxy.
|
||||
pub fn spawn(
|
||||
main_config: &MainLSConfig,
|
||||
mitm_config: Option<&StandaloneMitmConfig>,
|
||||
headless: bool,
|
||||
) -> Result<Self, String> {
|
||||
// Kill any orphaned LS processes from previous runs
|
||||
cleanup_orphaned_ls();
|
||||
let port = find_free_port()?;
|
||||
let lsp_port = find_free_port()?;
|
||||
let ts = std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.unwrap_or_default()
|
||||
.as_secs();
|
||||
|
||||
// Build init metadata protobuf
|
||||
let api_key = format!("standalone-api-key-{ts}");
|
||||
let session_id = format!("standalone-session-{ts}");
|
||||
let metadata = proto::build_init_metadata(
|
||||
&api_key,
|
||||
constants::antigravity_version(),
|
||||
constants::client_version(),
|
||||
&session_id,
|
||||
1, // DETECT_AND_USE_PROXY_ENABLED
|
||||
);
|
||||
|
||||
// Setup data dir (mode 1777 so both current user and antigravity-ls can write)
|
||||
let gemini_dir = format!("{DATA_DIR}/.gemini");
|
||||
let app_data_dir = format!("{DATA_DIR}/.gemini/antigravity-standalone");
|
||||
let annotations_dir = format!("{app_data_dir}/annotations");
|
||||
let brain_dir = format!("{app_data_dir}/brain");
|
||||
for dir in [
|
||||
DATA_DIR,
|
||||
&gemini_dir,
|
||||
&app_data_dir,
|
||||
&annotations_dir,
|
||||
&brain_dir,
|
||||
] {
|
||||
let _ = std::fs::create_dir_all(dir);
|
||||
#[cfg(unix)]
|
||||
{
|
||||
use std::os::unix::fs::PermissionsExt;
|
||||
let _ = std::fs::set_permissions(dir, std::fs::Permissions::from_mode(0o1777));
|
||||
}
|
||||
}
|
||||
// Check if data dir is writable by writing a test file.
|
||||
// Old runs as `antigravity-ls` user leave dirs owned by that user.
|
||||
let test_path = format!("{app_data_dir}/.write_test");
|
||||
if std::fs::write(&test_path, b"ok").is_err() {
|
||||
eprintln!(
|
||||
"\n ⚠ Data dir {} is not writable (owned by another user from previous sudo run)\n \
|
||||
Fix with: sudo chmod -R a+rwX {}\n",
|
||||
app_data_dir, DATA_DIR
|
||||
);
|
||||
} else {
|
||||
let _ = std::fs::remove_file(&test_path);
|
||||
}
|
||||
|
||||
// Pre-seed user_settings.pb with detect_and_use_proxy = ENABLED.
|
||||
// The LS reads this at startup when creating its HTTP transport.
|
||||
// Without it, the LS ignores HTTPS_PROXY and API traffic bypasses MITM.
|
||||
// UserSettings proto: field 34 (varint) = 1 (DETECT_AND_USE_PROXY_ENABLED)
|
||||
// Tag: (34 << 3) | 0 = 272 → varint [0x90, 0x02]
|
||||
// Value: 1 → varint [0x01]
|
||||
let settings_path = format!("{app_data_dir}/user_settings.pb");
|
||||
let settings_bytes: &[u8] = &[0x90, 0x02, 0x01];
|
||||
if let Err(e) = std::fs::write(&settings_path, settings_bytes) {
|
||||
tracing::warn!("Failed to pre-seed user_settings.pb: {e}");
|
||||
} else {
|
||||
#[cfg(unix)]
|
||||
{
|
||||
use std::os::unix::fs::PermissionsExt;
|
||||
let _ = std::fs::set_permissions(
|
||||
&settings_path,
|
||||
std::fs::Permissions::from_mode(0o0666),
|
||||
);
|
||||
}
|
||||
tracing::info!("Pre-seeded user_settings.pb (detect_and_use_proxy=ENABLED)");
|
||||
}
|
||||
|
||||
// In headless mode, spawn a stub TCP listener to serve as the extension server.
|
||||
// The LS connects to this port and calls LanguageServerStarted — without it,
|
||||
// the LS never fully initializes and won't accept connections on its server_port.
|
||||
let _stub_listener = if headless {
|
||||
let stub_port: u16 = main_config.extension_server_port.parse().unwrap_or(0);
|
||||
if stub_port == 0 {
|
||||
// Create a real listener so the LS can connect
|
||||
let listener = TcpListener::bind("127.0.0.1:0")
|
||||
.map_err(|e| format!("Failed to bind stub extension server: {e}"))?;
|
||||
let actual_port = listener
|
||||
.local_addr()
|
||||
.map_err(|e| format!("Failed to get stub port: {e}"))?
|
||||
.port();
|
||||
info!(
|
||||
port = actual_port,
|
||||
"Stub extension server listening (headless)"
|
||||
);
|
||||
// Read OAuth state from Antigravity's state.vscdb if available.
|
||||
// The DB stores the exact Topic proto (access_token + refresh_token + expiry)
|
||||
// which lets the LS auto-refresh tokens via its built-in Google OAuth2 client.
|
||||
let (oauth_token, oauth_topic_bytes) = read_oauth_from_state_db()
|
||||
.map(|(token, topic)| {
|
||||
info!("Loaded OAuth token from Antigravity state.vscdb");
|
||||
(token, Some(topic))
|
||||
})
|
||||
.unwrap_or_else(|| {
|
||||
// Fall back to env var / token file
|
||||
let token = std::env::var("ANTIGRAVITY_OAUTH_TOKEN")
|
||||
.ok()
|
||||
.filter(|s| !s.is_empty())
|
||||
.or_else(|| {
|
||||
let home = std::env::var("HOME").unwrap_or_default();
|
||||
let path = format!("{home}/.config/antigravity-proxy-token");
|
||||
std::fs::read_to_string(&path)
|
||||
.ok()
|
||||
.map(|s| s.trim().to_string())
|
||||
.filter(|s| !s.is_empty())
|
||||
})
|
||||
.unwrap_or_default();
|
||||
if !token.is_empty() {
|
||||
info!("Loaded OAuth token from env/file (no refresh token — manual refresh needed)");
|
||||
} else {
|
||||
eprintln!("[headless] ⚠ No OAuth token found. Login to Antigravity first, or set ANTIGRAVITY_OAUTH_TOKEN");
|
||||
}
|
||||
(token, None)
|
||||
});
|
||||
let oauth_arc = std::sync::Arc::new(oauth_token);
|
||||
let topic_arc = std::sync::Arc::new(oauth_topic_bytes);
|
||||
// Spawn a thread to accept connections (just hold them open)
|
||||
let listener_clone = listener
|
||||
.try_clone()
|
||||
.map_err(|e| format!("Failed to clone stub listener: {e}"))?;
|
||||
std::thread::spawn(move || {
|
||||
for stream in listener_clone.incoming() {
|
||||
match stream {
|
||||
Ok(conn) => {
|
||||
let token = std::sync::Arc::clone(&oauth_arc);
|
||||
let topic = std::sync::Arc::clone(&topic_arc);
|
||||
// Handle each connection in its own thread
|
||||
std::thread::spawn(move || {
|
||||
stub_handle_connection(conn, &token, &topic);
|
||||
});
|
||||
}
|
||||
Err(_) => break,
|
||||
}
|
||||
}
|
||||
});
|
||||
// Update the extension_server_port to the stub's port
|
||||
// (we need to use this in args below)
|
||||
Some((listener, actual_port))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
// Determine the actual extension_server_port to use
|
||||
let ext_port = if let Some((_, stub_port)) = &_stub_listener {
|
||||
stub_port.to_string()
|
||||
} else {
|
||||
main_config.extension_server_port.clone()
|
||||
};
|
||||
|
||||
// LS args — NO -standalone flag (it disables TCP listeners entirely)
|
||||
// NOTE: do NOT use -random_port — it overrides -server_port and the LS
|
||||
// would listen on a random port we can't discover.
|
||||
let args = vec![
|
||||
"-enable_lsp".to_string(),
|
||||
format!("-lsp_port={}", lsp_port),
|
||||
"-extension_server_port".to_string(),
|
||||
ext_port,
|
||||
"-csrf_token".to_string(),
|
||||
main_config.csrf.clone(),
|
||||
"-server_port".to_string(),
|
||||
port.to_string(),
|
||||
"-workspace_id".to_string(),
|
||||
format!("standalone_{ts}"),
|
||||
"-cloud_code_endpoint".to_string(),
|
||||
// When MITM is active, append the MITM port to the endpoint URL.
|
||||
// The LS's CodeAssistClient ignores HTTPS_PROXY (hardcoded Proxy:nil),
|
||||
// so we redirect at the DNS+port level instead:
|
||||
// 1. LD_PRELOAD hooks getaddrinfo() → 127.0.0.1 for API domains
|
||||
// 2. Custom port in URL → LS connects to 127.0.0.1:MITM_PORT
|
||||
// 3. MITM proxy intercepts the transparent TLS connection via SNI
|
||||
if let Some(mitm) = mitm_config {
|
||||
// Extract port from proxy_addr (e.g. "http://127.0.0.1:8742" → "8742")
|
||||
let mitm_port = mitm.proxy_addr.rsplit(':').next().unwrap_or("8742");
|
||||
format!("https://daily-cloudcode-pa.googleapis.com:{mitm_port}")
|
||||
} else {
|
||||
"https://daily-cloudcode-pa.googleapis.com".to_string()
|
||||
},
|
||||
"-app_data_dir".to_string(),
|
||||
"antigravity-standalone".to_string(),
|
||||
"-gemini_dir".to_string(),
|
||||
gemini_dir,
|
||||
];
|
||||
|
||||
info!(port, "Spawning standalone LS");
|
||||
debug!(?args, "LS args");
|
||||
|
||||
// Build env vars for the LS process
|
||||
let mut env_vars: Vec<(String, String)> =
|
||||
vec![("ANTIGRAVITY_EDITOR_APP_ROOT".into(), APP_ROOT.into())];
|
||||
|
||||
// If MITM is enabled, add SSL + proxy env vars
|
||||
if let Some(mitm) = mitm_config {
|
||||
// Go's SSL_CERT_FILE replaces the entire system cert pool, so we
|
||||
// need a combined bundle: system CAs + our MITM CA
|
||||
// Write to /tmp — accessible by antigravity-ls user
|
||||
// (user's ~/.config/ is not traversable by other UIDs)
|
||||
let combined_ca_path = "/tmp/antigravity-mitm-combined-ca.pem".to_string();
|
||||
let system_ca =
|
||||
std::fs::read_to_string("/etc/ssl/certs/ca-certificates.crt").unwrap_or_default();
|
||||
let mitm_ca = std::fs::read_to_string(&mitm.ca_cert_path)
|
||||
.map_err(|e| format!("Failed to read MITM CA cert: {e}"))?;
|
||||
std::fs::write(&combined_ca_path, format!("{system_ca}\n{mitm_ca}"))
|
||||
.map_err(|e| format!("Failed to write combined CA bundle: {e}"))?;
|
||||
// Make readable by antigravity-ls user
|
||||
#[cfg(unix)]
|
||||
{
|
||||
use std::os::unix::fs::PermissionsExt;
|
||||
let _ = std::fs::set_permissions(
|
||||
&combined_ca_path,
|
||||
std::fs::Permissions::from_mode(0o644),
|
||||
);
|
||||
}
|
||||
|
||||
info!(
|
||||
proxy = %mitm.proxy_addr,
|
||||
ca = %combined_ca_path,
|
||||
"Setting MITM env vars on standalone LS (combined CA bundle)"
|
||||
);
|
||||
env_vars.push(("SSL_CERT_FILE".into(), combined_ca_path));
|
||||
env_vars.push(("SSL_CERT_DIR".into(), "/dev/null".into()));
|
||||
env_vars.push(("NODE_EXTRA_CA_CERTS".into(), mitm.ca_cert_path.clone()));
|
||||
// Only set HTTPS_PROXY when iptables UID isolation is NOT available
|
||||
// OR when running in headless mode (no sudo at all).
|
||||
// With iptables, all outbound traffic is transparently redirected at the
|
||||
// kernel level — setting HTTPS_PROXY on top causes double-proxying.
|
||||
if headless || !has_ls_user() {
|
||||
// proxy_addr already includes the scheme (e.g. "http://127.0.0.1:8742")
|
||||
env_vars.push(("HTTPS_PROXY".into(), mitm.proxy_addr.clone()));
|
||||
env_vars.push(("HTTP_PROXY".into(), mitm.proxy_addr.clone()));
|
||||
|
||||
// LD_PRELOAD DNS redirect: hooks getaddrinfo() so Google API domains
|
||||
// resolve to 127.0.0.1. Combined with the port-modified endpoint URL,
|
||||
// this makes the LS connect to our MITM proxy for ALL API calls —
|
||||
// even the CodeAssistClient which has Proxy:nil hardcoded.
|
||||
let so_path = build_dns_redirect_so();
|
||||
if let Some(so) = so_path {
|
||||
info!(path = %so, "Enabling LD_PRELOAD DNS redirect for headless MITM");
|
||||
env_vars.push(("LD_PRELOAD".into(), so));
|
||||
env_vars.push((
|
||||
"DNS_REDIRECT_LOG".into(),
|
||||
"/tmp/antigravity-dns-redirect.log".into(),
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// In headless mode, never use sudo — run as current user
|
||||
// In normal mode, use sudo if 'antigravity-ls' user exists
|
||||
let use_sudo = !headless && has_ls_user();
|
||||
|
||||
let mut cmd = if use_sudo {
|
||||
info!("Using UID isolation: spawning LS as 'antigravity-ls' user");
|
||||
let mut c = Command::new("sudo");
|
||||
c.args(["-n", "-u", LS_USER, "--", "/usr/bin/env"]);
|
||||
for (k, v) in &env_vars {
|
||||
c.arg(format!("{k}={v}"));
|
||||
}
|
||||
c.arg(LS_BINARY_PATH);
|
||||
c.args(&args);
|
||||
c
|
||||
} else {
|
||||
debug!("Spawning LS as current user");
|
||||
let mut c = Command::new(LS_BINARY_PATH);
|
||||
c.args(&args);
|
||||
for (k, v) in &env_vars {
|
||||
c.env(k, v);
|
||||
}
|
||||
c
|
||||
};
|
||||
|
||||
// Capture stderr for debugging — logs to /tmp so we can diagnose LS failures
|
||||
let stderr_file = std::fs::File::create("/tmp/antigravity-ls-debug.log")
|
||||
.map_err(|e| format!("Failed to create LS debug log: {e}"))?;
|
||||
cmd.stdin(Stdio::piped())
|
||||
.stdout(Stdio::null())
|
||||
.stderr(Stdio::from(stderr_file));
|
||||
|
||||
let mut child = cmd
|
||||
.spawn()
|
||||
.map_err(|e| format!("Failed to spawn LS binary: {e}"))?;
|
||||
|
||||
// Feed init metadata via stdin, then close it
|
||||
if let Some(mut stdin) = child.stdin.take() {
|
||||
stdin
|
||||
.write_all(&metadata)
|
||||
.map_err(|e| format!("Failed to write init metadata to stdin: {e}"))?;
|
||||
// stdin drops here → EOF (LS handles this fine in non-standalone mode)
|
||||
}
|
||||
|
||||
info!(pid = child.id(), port, "Standalone LS spawned");
|
||||
|
||||
// When spawned via sudo, the child is the sudo process which exits after
|
||||
// launching the LS as the target user. We need the actual LS PID for cleanup.
|
||||
let ls_pid = if use_sudo {
|
||||
// Give sudo a moment to spawn the real process
|
||||
std::thread::sleep(std::time::Duration::from_millis(500));
|
||||
// Find the LS process owned by antigravity-ls user
|
||||
find_ls_pid_for_user(LS_USER).ok()
|
||||
} else {
|
||||
Some(child.id())
|
||||
};
|
||||
|
||||
if let Some(pid) = ls_pid {
|
||||
info!(
|
||||
ls_pid = pid,
|
||||
sudo = use_sudo,
|
||||
"Discovered actual LS process"
|
||||
);
|
||||
}
|
||||
|
||||
Ok(StandaloneLS {
|
||||
child,
|
||||
ls_pid,
|
||||
use_sudo,
|
||||
killed: false,
|
||||
port,
|
||||
csrf: main_config.csrf.clone(),
|
||||
})
|
||||
}
|
||||
|
||||
/// Wait for the standalone LS to be ready (accepting TCP connections).
|
||||
///
|
||||
/// Retries up to `max_attempts` times with a 1-second delay between each.
|
||||
pub async fn wait_ready(&mut self, max_attempts: u32) -> Result<(), String> {
|
||||
info!(port = self.port, "Waiting for standalone LS to be ready...");
|
||||
|
||||
for attempt in 1..=max_attempts {
|
||||
sleep(Duration::from_secs(1)).await;
|
||||
|
||||
// Check if the process is still alive
|
||||
match self.child.try_wait() {
|
||||
Ok(Some(status)) => {
|
||||
return Err(format!(
|
||||
"Standalone LS exited prematurely with status: {status}"
|
||||
));
|
||||
}
|
||||
Ok(None) => {} // still running
|
||||
Err(e) => {
|
||||
return Err(format!("Failed to check LS process status: {e}"));
|
||||
}
|
||||
}
|
||||
|
||||
// Simple TCP connect check — if the LS is listening, it's ready
|
||||
match tokio::net::TcpStream::connect(format!("127.0.0.1:{}", self.port)).await {
|
||||
Ok(_) => {
|
||||
info!(attempt, "Standalone LS is ready (accepting connections)");
|
||||
return Ok(());
|
||||
}
|
||||
Err(e) => {
|
||||
debug!(attempt, error = %e, "LS not ready yet");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Err(format!(
|
||||
"Standalone LS failed to become ready after {max_attempts} attempts on port {}",
|
||||
self.port
|
||||
))
|
||||
}
|
||||
|
||||
/// Check if the child process is still running.
|
||||
#[allow(dead_code)]
|
||||
pub fn is_alive(&mut self) -> bool {
|
||||
matches!(self.child.try_wait(), Ok(None))
|
||||
}
|
||||
|
||||
/// Kill the standalone LS process.
|
||||
pub fn kill(&mut self) {
|
||||
if self.killed {
|
||||
return;
|
||||
}
|
||||
self.killed = true;
|
||||
info!("Killing standalone LS");
|
||||
|
||||
if self.use_sudo {
|
||||
// The child is sudo which already exited. Kill the actual LS.
|
||||
if let Some(pid) = self.ls_pid {
|
||||
info!(pid, "Killing LS process via sudo -u {}", LS_USER);
|
||||
// Run kill AS the antigravity-ls user (same UID can signal)
|
||||
let ok = std::process::Command::new("sudo")
|
||||
.args(["-n", "-u", LS_USER, "kill", "-TERM", &pid.to_string()])
|
||||
.stdout(Stdio::null())
|
||||
.stderr(Stdio::null())
|
||||
.status()
|
||||
.map(|s| s.success())
|
||||
.unwrap_or(false);
|
||||
|
||||
if ok {
|
||||
std::thread::sleep(std::time::Duration::from_millis(500));
|
||||
// Force kill if still alive
|
||||
let _ = std::process::Command::new("sudo")
|
||||
.args(["-n", "-u", LS_USER, "kill", "-KILL", &pid.to_string()])
|
||||
.stdout(Stdio::null())
|
||||
.stderr(Stdio::null())
|
||||
.status();
|
||||
} else {
|
||||
// Fallback: try with root sudo, then cleanup
|
||||
info!("sudo -u kill failed, trying fallback cleanup");
|
||||
cleanup_orphaned_ls();
|
||||
}
|
||||
} else {
|
||||
// No PID recorded, try blanket cleanup
|
||||
cleanup_orphaned_ls();
|
||||
}
|
||||
} else {
|
||||
let _ = self.child.kill();
|
||||
let _ = self.child.wait();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for StandaloneLS {
|
||||
fn drop(&mut self) {
|
||||
self.kill();
|
||||
}
|
||||
}
|
||||
330
src/standalone/stub.rs
Normal file
330
src/standalone/stub.rs
Normal file
@@ -0,0 +1,330 @@
|
||||
//! Stub extension server — handles LS connections in headless mode.
|
||||
|
||||
use crate::proto::wire::{encode_proto_string, encode_varint, extract_proto_string};
|
||||
use std::io::{BufRead, BufReader, Read, Write};
|
||||
|
||||
/// Handle a single connection from the LS to the stub extension server.
|
||||
///
|
||||
/// The LS uses Connect RPC (HTTP/1.1, ServerStream) to call ExtensionServerService methods.
|
||||
/// ALL methods are ServerStream — responses use Connect streaming envelope framing:
|
||||
/// [0x00 | len(4) | protobuf_data]... (0+ data messages)
|
||||
/// [0x02 | len(4) | json_trailer] (exactly 1 end-of-stream)
|
||||
///
|
||||
/// IMPORTANT: `SubscribeToUnifiedStateSyncTopic` is a long-lived stream.
|
||||
/// If we immediately close it, the LS reconnects in a tight loop and never
|
||||
/// proceeds to fetch OAuth tokens. We keep subscription connections OPEN.
|
||||
pub fn stub_handle_connection(
|
||||
conn: std::net::TcpStream,
|
||||
oauth_token: &str,
|
||||
oauth_topic_bytes: &Option<Vec<u8>>,
|
||||
) {
|
||||
let mut reader = BufReader::new(match conn.try_clone() {
|
||||
Ok(c) => c,
|
||||
Err(_) => return,
|
||||
});
|
||||
let mut writer = conn;
|
||||
|
||||
// Read the HTTP request line
|
||||
let mut request_line = String::new();
|
||||
match reader.read_line(&mut request_line) {
|
||||
Ok(0) | Err(_) => return,
|
||||
_ => {}
|
||||
}
|
||||
|
||||
// Extract method path for logging
|
||||
let path = request_line
|
||||
.split_whitespace()
|
||||
.nth(1)
|
||||
.unwrap_or("/unknown")
|
||||
.to_string();
|
||||
|
||||
// Read headers
|
||||
let mut content_len: usize = 0;
|
||||
loop {
|
||||
let mut line = String::new();
|
||||
if reader.read_line(&mut line).unwrap_or(0) == 0 {
|
||||
return;
|
||||
}
|
||||
if line.trim().is_empty() {
|
||||
break;
|
||||
}
|
||||
if line.to_lowercase().starts_with("content-length:") {
|
||||
content_len = line
|
||||
.split(':')
|
||||
.nth(1)
|
||||
.and_then(|v| v.trim().parse().ok())
|
||||
.unwrap_or(0);
|
||||
}
|
||||
}
|
||||
|
||||
// Read body
|
||||
let mut body = Vec::new();
|
||||
if content_len > 0 {
|
||||
body.resize(content_len, 0u8);
|
||||
if Read::read_exact(&mut reader, &mut body).is_err() {
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
// ─── Long-lived streams ──────────────────────────────────────────────
|
||||
// SubscribeToUnifiedStateSyncTopic must stay open — the LS subscribes
|
||||
// once and expects updates (OAuth, settings) delivered over this stream.
|
||||
// If we close immediately, the LS reconnects in a tight loop (~30/sec).
|
||||
if path.contains("SubscribeToUnifiedStateSyncTopic") {
|
||||
handle_subscribe_stream(&mut writer, &body, &path, oauth_token, oauth_topic_bytes);
|
||||
return;
|
||||
}
|
||||
|
||||
// ─── Short-lived methods (everything else) ───────────────────────────
|
||||
handle_short_lived(&mut writer, &body, &path, oauth_token);
|
||||
}
|
||||
|
||||
/// Handle the long-lived SubscribeToUnifiedStateSyncTopic stream.
|
||||
fn handle_subscribe_stream(
|
||||
writer: &mut std::net::TcpStream,
|
||||
body: &[u8],
|
||||
path: &str,
|
||||
oauth_token: &str,
|
||||
oauth_topic_bytes: &Option<Vec<u8>>,
|
||||
) {
|
||||
// Parse the request body to extract the topic name.
|
||||
// Connect envelope: [flag(1)] [len(4)] [proto(N)]
|
||||
let proto_body = if body.len() > 5 {
|
||||
&body[5..]
|
||||
} else {
|
||||
&body[..]
|
||||
};
|
||||
|
||||
// SubscribeToUnifiedStateSyncTopicRequest { string topic = 1; }
|
||||
let mut topic_name = String::new();
|
||||
let mut i = 0;
|
||||
while i < proto_body.len() {
|
||||
let tag_byte = proto_body[i];
|
||||
let field_num = tag_byte >> 3;
|
||||
let wire_type = tag_byte & 0x07;
|
||||
i += 1;
|
||||
if wire_type == 2 && i < proto_body.len() {
|
||||
let len = proto_body[i] as usize;
|
||||
i += 1;
|
||||
if i + len <= proto_body.len() {
|
||||
if field_num == 1 {
|
||||
topic_name = String::from_utf8_lossy(&proto_body[i..i + len]).to_string();
|
||||
}
|
||||
i += len;
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
eprintln!("[stub-ext] STREAM → {path} topic={topic_name:?}");
|
||||
|
||||
// Build initial_state bytes
|
||||
let initial_state_bytes = build_initial_state(&topic_name, oauth_token, oauth_topic_bytes);
|
||||
|
||||
// Helper: wrap protobuf bytes in a Connect data envelope
|
||||
let make_envelope = |proto: &[u8]| -> Vec<u8> {
|
||||
let mut env = Vec::with_capacity(5 + proto.len());
|
||||
env.push(0x00u8); // data flag
|
||||
env.extend_from_slice(&(proto.len() as u32).to_be_bytes());
|
||||
env.extend_from_slice(proto);
|
||||
env
|
||||
};
|
||||
|
||||
// Helper: write a chunk
|
||||
let send_chunk = |w: &mut std::net::TcpStream, data: &[u8]| -> bool {
|
||||
let hdr = format!("{:x}\r\n", data.len());
|
||||
w.write_all(hdr.as_bytes()).is_ok()
|
||||
&& w.write_all(data).is_ok()
|
||||
&& w.write_all(b"\r\n").is_ok()
|
||||
&& w.flush().is_ok()
|
||||
};
|
||||
|
||||
// Build UnifiedStateSyncUpdate { initial_state = initial_state_bytes }
|
||||
let mut initial_proto = Vec::new();
|
||||
initial_proto.push(0x0A); // field 1 (initial_state), LEN
|
||||
encode_varint(&mut initial_proto, initial_state_bytes.len() as u64);
|
||||
initial_proto.extend_from_slice(&initial_state_bytes);
|
||||
|
||||
let initial_env = make_envelope(&initial_proto);
|
||||
|
||||
let header = format!(
|
||||
"HTTP/1.1 200 OK\r\n\
|
||||
Content-Type: application/connect+proto\r\n\
|
||||
Transfer-Encoding: chunked\r\n\
|
||||
\r\n"
|
||||
);
|
||||
if writer.write_all(header.as_bytes()).is_err() {
|
||||
return;
|
||||
}
|
||||
|
||||
if !send_chunk(writer, &initial_env) {
|
||||
return;
|
||||
}
|
||||
eprintln!(
|
||||
"[stub-ext] STREAM → sent initial_state ({} bytes)",
|
||||
initial_state_bytes.len()
|
||||
);
|
||||
|
||||
// Keep the stream alive with periodic valid messages.
|
||||
// The LS has a ~10s read timeout on streams. After the initial_state,
|
||||
// the LS only accepts AppliedUpdate (field 2 in the oneof).
|
||||
// We send an empty AppliedUpdate {} every 5s as keepalive.
|
||||
let keepalive_proto: &[u8] = &[0x12, 0x00]; // field 2 (applied_update), LEN=0
|
||||
let keepalive_env = make_envelope(keepalive_proto);
|
||||
loop {
|
||||
std::thread::sleep(std::time::Duration::from_secs(5));
|
||||
if !send_chunk(writer, &keepalive_env) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Build the initial_state bytes for a USS topic subscription.
|
||||
fn build_initial_state(
|
||||
topic_name: &str,
|
||||
oauth_token: &str,
|
||||
oauth_topic_bytes: &Option<Vec<u8>>,
|
||||
) -> Vec<u8> {
|
||||
let mut initial_state_bytes = Vec::new();
|
||||
|
||||
if topic_name == "uss-oauth" {
|
||||
if let Some(topic_bytes) = oauth_topic_bytes {
|
||||
// Use the exact Topic proto from Antigravity's state.vscdb.
|
||||
initial_state_bytes = topic_bytes.clone();
|
||||
eprintln!(
|
||||
"[stub-ext] using state.vscdb topic ({} bytes)",
|
||||
topic_bytes.len()
|
||||
);
|
||||
} else if !oauth_token.is_empty() {
|
||||
// Manual token fallback — construct OAuthTokenInfo with far-future expiry
|
||||
let mut oauth_proto = Vec::new();
|
||||
// field 1 (access_token), LEN
|
||||
oauth_proto.push(0x0A);
|
||||
encode_varint(&mut oauth_proto, oauth_token.len() as u64);
|
||||
oauth_proto.extend_from_slice(oauth_token.as_bytes());
|
||||
// field 2 (token_type), LEN
|
||||
let token_type = b"Bearer";
|
||||
oauth_proto.push(0x12);
|
||||
encode_varint(&mut oauth_proto, token_type.len() as u64);
|
||||
oauth_proto.extend_from_slice(token_type);
|
||||
// field 4 (expiry) = Timestamp { seconds = 4_102_444_800 } (year 2099-12-31)
|
||||
let mut ts_proto = Vec::new();
|
||||
ts_proto.push(0x08); // field 1 (seconds), varint
|
||||
encode_varint(&mut ts_proto, 4_102_444_800u64);
|
||||
oauth_proto.push(0x22); // field 4 (expiry), LEN
|
||||
encode_varint(&mut oauth_proto, ts_proto.len() as u64);
|
||||
oauth_proto.extend_from_slice(&ts_proto);
|
||||
|
||||
use base64::Engine;
|
||||
let b64_value = base64::engine::general_purpose::STANDARD.encode(&oauth_proto);
|
||||
|
||||
// Build Row { value = b64_value, e_tag = 1 }
|
||||
let mut row = Vec::new();
|
||||
row.push(0x0A); // field 1 (value), LEN
|
||||
encode_varint(&mut row, b64_value.len() as u64);
|
||||
row.extend_from_slice(b64_value.as_bytes());
|
||||
row.push(0x10); // field 2 (e_tag), varint
|
||||
row.push(0x01);
|
||||
|
||||
// Build map entry: { key = "oauthTokenInfoSentinelKey", value = row }
|
||||
let key_str = b"oauthTokenInfoSentinelKey";
|
||||
let mut map_entry = Vec::new();
|
||||
map_entry.push(0x0A); // field 1 (key), LEN
|
||||
encode_varint(&mut map_entry, key_str.len() as u64);
|
||||
map_entry.extend_from_slice(key_str);
|
||||
map_entry.push(0x12); // field 2 (value = Row), LEN
|
||||
encode_varint(&mut map_entry, row.len() as u64);
|
||||
map_entry.extend_from_slice(&row);
|
||||
|
||||
// Build Topic { data = [map_entry] }
|
||||
initial_state_bytes.push(0x0A); // field 1 (data map), LEN
|
||||
encode_varint(&mut initial_state_bytes, map_entry.len() as u64);
|
||||
initial_state_bytes.extend_from_slice(&map_entry);
|
||||
}
|
||||
}
|
||||
|
||||
initial_state_bytes
|
||||
}
|
||||
|
||||
/// Handle short-lived extension server methods.
|
||||
fn handle_short_lived(
|
||||
writer: &mut std::net::TcpStream,
|
||||
body: &[u8],
|
||||
path: &str,
|
||||
oauth_token: &str,
|
||||
) {
|
||||
let is_noisy = path.contains("GetChromeDevtoolsMcpUrl")
|
||||
|| path.contains("FetchMCPAuthToken")
|
||||
|| path.contains("PushUnifiedStateSyncUpdate");
|
||||
if !is_noisy {
|
||||
eprintln!("[stub-ext] 200 OK → {path}");
|
||||
}
|
||||
|
||||
// Build Connect streaming response body with proper envelope framing.
|
||||
let mut envelope = Vec::new();
|
||||
|
||||
if path.contains("GetSecretValue") {
|
||||
// Parse request body to extract the key (protobuf: field 1 = key, string)
|
||||
let key = extract_proto_string(body, 1).unwrap_or_default();
|
||||
eprintln!("[stub-ext] ← GetSecretValue key={key:?}");
|
||||
|
||||
if !oauth_token.is_empty() {
|
||||
// Build protobuf: GetSecretValueResponse { string value = 1 }
|
||||
let proto = encode_proto_string(1, oauth_token.as_bytes());
|
||||
eprintln!(
|
||||
"[stub-ext] → serving token ({} bytes) for key={key:?}",
|
||||
oauth_token.len()
|
||||
);
|
||||
|
||||
// Data envelope: flag=0x00, length, data
|
||||
envelope.push(0x00u8);
|
||||
envelope.extend_from_slice(&(proto.len() as u32).to_be_bytes());
|
||||
envelope.extend_from_slice(&proto);
|
||||
} else {
|
||||
eprintln!("[stub-ext] ⚠ no OAuth token available for key={key:?}");
|
||||
}
|
||||
} else if path.contains("StoreSecretValue") {
|
||||
// Parse and log what the LS is storing (for debugging)
|
||||
let key = extract_proto_string(body, 1).unwrap_or_default();
|
||||
let value = extract_proto_string(body, 2).unwrap_or_default();
|
||||
let val_preview = if value.len() > 32 {
|
||||
format!("{}...", &value[..32])
|
||||
} else {
|
||||
value
|
||||
};
|
||||
eprintln!("[stub-ext] ← StoreSecretValue key={key:?} value={val_preview:?}");
|
||||
}
|
||||
|
||||
if path.contains("PushUnifiedStateSyncUpdate") {
|
||||
// Unary proto — respond with empty PushUnifiedStateSyncUpdateResponse (0 bytes body)
|
||||
let header = "HTTP/1.1 200 OK\r\n\
|
||||
Content-Type: application/proto\r\n\
|
||||
Content-Length: 0\r\n\
|
||||
Connection: close\r\n\
|
||||
\r\n";
|
||||
let _ = writer.write_all(header.as_bytes());
|
||||
let _ = writer.flush();
|
||||
return;
|
||||
}
|
||||
|
||||
// End-of-stream envelope: flag=0x02, length=2, data="{}"
|
||||
envelope.push(0x02u8);
|
||||
envelope.extend_from_slice(&2u32.to_be_bytes());
|
||||
envelope.extend_from_slice(b"{}");
|
||||
|
||||
// Respond with 200 OK + Connection: close (one request per connection)
|
||||
let header = format!(
|
||||
"HTTP/1.1 200 OK\r\n\
|
||||
Content-Type: application/connect+proto\r\n\
|
||||
Content-Length: {}\r\n\
|
||||
Connection: close\r\n\
|
||||
\r\n",
|
||||
envelope.len()
|
||||
);
|
||||
let _ = writer.write_all(header.as_bytes());
|
||||
let _ = writer.write_all(&envelope);
|
||||
let _ = writer.flush();
|
||||
}
|
||||
Reference in New Issue
Block a user