fix: block ALL LS follow-up requests across connections
Move the in-flight blocking check to the top of the LLM request flow, BEFORE request modification. This catches follow-ups on ALL connections (the LS opens multiple parallel TLS connections). Only the very first modified request reaches Google — all others get fake STOP responses. Previously, each new connection independently allowed one request through before blocking, letting 4-5 requests leak per turn.
This commit is contained in:
@@ -10,9 +10,11 @@ use std::sync::Arc;
|
||||
use tracing::{debug, info, warn};
|
||||
|
||||
use super::models::{lookup_model, DEFAULT_MODEL, MODELS};
|
||||
use super::polling::{extract_response_text, extract_thinking_content, is_response_done, poll_for_response};
|
||||
use super::polling::{
|
||||
extract_response_text, extract_thinking_content, is_response_done, poll_for_response,
|
||||
};
|
||||
use super::types::*;
|
||||
use super::util::{err_response, upstream_err_response, now_unix};
|
||||
use super::util::{err_response, now_unix, upstream_err_response};
|
||||
use super::AppState;
|
||||
|
||||
/// Extract a conversation/session ID from a flexible JSON value.
|
||||
@@ -33,7 +35,8 @@ fn system_fingerprint() -> String {
|
||||
/// Build a streaming chunk JSON with all required OpenAI fields.
|
||||
/// Includes system_fingerprint, service_tier, and logprobs:null in choices.
|
||||
fn chunk_json(
|
||||
id: &str, model: &str,
|
||||
id: &str,
|
||||
model: &str,
|
||||
choices: serde_json::Value,
|
||||
usage: Option<serde_json::Value>,
|
||||
) -> String {
|
||||
@@ -53,7 +56,11 @@ fn chunk_json(
|
||||
}
|
||||
|
||||
/// Build a single choice for a streaming chunk (delta + finish_reason + logprobs).
|
||||
fn chunk_choice(index: u32, delta: serde_json::Value, finish_reason: Option<&str>) -> serde_json::Value {
|
||||
fn chunk_choice(
|
||||
index: u32,
|
||||
delta: serde_json::Value,
|
||||
finish_reason: Option<&str>,
|
||||
) -> serde_json::Value {
|
||||
serde_json::json!({
|
||||
"index": index,
|
||||
"delta": delta,
|
||||
@@ -70,7 +77,9 @@ fn chunk_choice(index: u32, delta: serde_json::Value, finish_reason: Option<&str
|
||||
fn google_to_openai_finish_reason(stop_reason: Option<&str>) -> &'static str {
|
||||
match stop_reason {
|
||||
Some("MAX_TOKENS") => "length",
|
||||
Some("SAFETY") | Some("RECITATION") | Some("BLOCKLIST") | Some("PROHIBITED_CONTENT") => "content_filter",
|
||||
Some("SAFETY") | Some("RECITATION") | Some("BLOCKLIST") | Some("PROHIBITED_CONTENT") => {
|
||||
"content_filter"
|
||||
}
|
||||
_ => "stop",
|
||||
}
|
||||
}
|
||||
@@ -84,7 +93,9 @@ fn google_to_openai_finish_reason(stop_reason: Option<&str>) -> &'static str {
|
||||
/// sends the entire messages array to the model.
|
||||
fn extract_chat_input(messages: &[CompletionMessage]) -> (String, Option<crate::proto::ImageData>) {
|
||||
// Extract image from last user message content array
|
||||
let image = messages.iter().rev()
|
||||
let image = messages
|
||||
.iter()
|
||||
.rev()
|
||||
.find(|m| m.role == "user")
|
||||
.and_then(|m| super::util::extract_first_image(&m.content));
|
||||
// Always build the full conversation
|
||||
@@ -141,10 +152,7 @@ fn build_conversation_with_tools(messages: &[CompletionMessage]) -> String {
|
||||
if let Some(func) = tc.get("function") {
|
||||
let name = func["name"].as_str().unwrap_or("unknown");
|
||||
let args = func["arguments"].as_str().unwrap_or("{}");
|
||||
parts.push(format!(
|
||||
"[Tool call: {}({})]",
|
||||
name, args
|
||||
));
|
||||
parts.push(format!("[Tool call: {}({})]", name, args));
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -153,10 +161,7 @@ fn build_conversation_with_tools(messages: &[CompletionMessage]) -> String {
|
||||
let text = extract_message_text(&msg.content);
|
||||
let tool_id = msg.tool_call_id.as_deref().unwrap_or("unknown");
|
||||
if !text.is_empty() {
|
||||
parts.push(format!(
|
||||
"[Tool result ({})]:\n{}",
|
||||
tool_id, text
|
||||
));
|
||||
parts.push(format!("[Tool result ({})]:\n{}", tool_id, text));
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
@@ -202,7 +207,10 @@ pub(crate) async fn handle_completions(
|
||||
let gemini_config = crate::mitm::modify::openai_tool_choice_to_gemini(choice);
|
||||
state.mitm_store.set_tool_config(gemini_config).await;
|
||||
}
|
||||
info!(count = tools.len(), "Completions: stored client tools for MITM injection");
|
||||
info!(
|
||||
count = tools.len(),
|
||||
"Completions: stored client tools for MITM injection"
|
||||
);
|
||||
} else {
|
||||
state.mitm_store.clear_tools().await;
|
||||
}
|
||||
@@ -239,10 +247,15 @@ pub(crate) async fn handle_completions(
|
||||
google_search: body.web_search,
|
||||
};
|
||||
// Only store if at least one param is set
|
||||
if gp.temperature.is_some() || gp.top_p.is_some() || gp.max_output_tokens.is_some()
|
||||
|| gp.frequency_penalty.is_some() || gp.presence_penalty.is_some()
|
||||
|| gp.reasoning_effort.is_some() || gp.stop_sequences.is_some()
|
||||
|| gp.response_mime_type.is_some() || gp.response_schema.is_some()
|
||||
if gp.temperature.is_some()
|
||||
|| gp.top_p.is_some()
|
||||
|| gp.max_output_tokens.is_some()
|
||||
|| gp.frequency_penalty.is_some()
|
||||
|| gp.presence_penalty.is_some()
|
||||
|| gp.reasoning_effort.is_some()
|
||||
|| gp.stop_sequences.is_some()
|
||||
|| gp.response_mime_type.is_some()
|
||||
|| gp.response_schema.is_some()
|
||||
|| gp.google_search
|
||||
{
|
||||
state.mitm_store.set_generation_params(gp).await;
|
||||
@@ -306,12 +319,13 @@ pub(crate) async fn handle_completions(
|
||||
// Store image for MITM injection (LS doesn't forward images to Google API)
|
||||
if let Some(ref img) = image {
|
||||
use base64::Engine;
|
||||
state.mitm_store.set_pending_image(
|
||||
crate::mitm::store::PendingImage {
|
||||
state
|
||||
.mitm_store
|
||||
.set_pending_image(crate::mitm::store::PendingImage {
|
||||
base64_data: base64::engine::general_purpose::STANDARD.encode(&img.data),
|
||||
mime_type: img.mime_type.clone(),
|
||||
}
|
||||
).await;
|
||||
})
|
||||
.await;
|
||||
}
|
||||
match state
|
||||
.backend
|
||||
@@ -346,7 +360,10 @@ pub(crate) async fn handle_completions(
|
||||
uuid::Uuid::new_v4().to_string().replace('-', "")
|
||||
);
|
||||
|
||||
let include_usage = body.stream_options.as_ref().map_or(false, |o| o.include_usage);
|
||||
let include_usage = body
|
||||
.stream_options
|
||||
.as_ref()
|
||||
.map_or(false, |o| o.include_usage);
|
||||
|
||||
if body.stream {
|
||||
chat_completions_stream(
|
||||
@@ -374,11 +391,17 @@ pub(crate) async fn handle_completions(
|
||||
match state.backend.create_cascade().await {
|
||||
Ok(cid) => {
|
||||
// Send the same message on each extra cascade
|
||||
match state.backend.send_message_with_image(&cid, &user_text, model.model_enum, image.as_ref()).await {
|
||||
match state
|
||||
.backend
|
||||
.send_message_with_image(&cid, &user_text, model.model_enum, image.as_ref())
|
||||
.await
|
||||
{
|
||||
Ok((200, _)) => {
|
||||
let bg = Arc::clone(&state.backend);
|
||||
let cid2 = cid.clone();
|
||||
tokio::spawn(async move { let _ = bg.update_annotations(&cid2).await; });
|
||||
tokio::spawn(async move {
|
||||
let _ = bg.update_annotations(&cid2).await;
|
||||
});
|
||||
extra_cascade_ids.push(cid);
|
||||
}
|
||||
_ => {} // Skip failed cascades
|
||||
@@ -420,7 +443,12 @@ pub(crate) async fn handle_completions(
|
||||
mitm.as_ref().and_then(|u| u.stop_reason.as_deref()),
|
||||
);
|
||||
let (pt, ct, cached, thinking) = if let Some(ref mu) = mitm {
|
||||
(mu.input_tokens, mu.output_tokens, mu.cache_read_input_tokens, mu.thinking_output_tokens)
|
||||
(
|
||||
mu.input_tokens,
|
||||
mu.output_tokens,
|
||||
mu.cache_read_input_tokens,
|
||||
mu.thinking_output_tokens,
|
||||
)
|
||||
} else if let Some(u) = &result.usage {
|
||||
(u.input_tokens, u.output_tokens, 0, 0)
|
||||
} else {
|
||||
@@ -874,15 +902,22 @@ async fn chat_completions_sync(
|
||||
None => state.mitm_store.take_usage("_latest").await,
|
||||
};
|
||||
|
||||
let finish_reason = google_to_openai_finish_reason(mitm.as_ref().and_then(|u| u.stop_reason.as_deref()));
|
||||
let finish_reason =
|
||||
google_to_openai_finish_reason(mitm.as_ref().and_then(|u| u.stop_reason.as_deref()));
|
||||
|
||||
let (prompt_tokens, completion_tokens, cached_tokens, thinking_tokens) = if let Some(ref mitm_usage) = mitm {
|
||||
(mitm_usage.input_tokens, mitm_usage.output_tokens, mitm_usage.cache_read_input_tokens, mitm_usage.thinking_output_tokens)
|
||||
} else if let Some(u) = &result.usage {
|
||||
(u.input_tokens, u.output_tokens, 0, 0)
|
||||
} else {
|
||||
(0, 0, 0, 0)
|
||||
};
|
||||
let (prompt_tokens, completion_tokens, cached_tokens, thinking_tokens) =
|
||||
if let Some(ref mitm_usage) = mitm {
|
||||
(
|
||||
mitm_usage.input_tokens,
|
||||
mitm_usage.output_tokens,
|
||||
mitm_usage.cache_read_input_tokens,
|
||||
mitm_usage.thinking_output_tokens,
|
||||
)
|
||||
} else if let Some(u) = &result.usage {
|
||||
(u.input_tokens, u.output_tokens, 0, 0)
|
||||
} else {
|
||||
(0, 0, 0, 0)
|
||||
};
|
||||
|
||||
// Build message object, including reasoning_content if thinking is present
|
||||
let mut message = serde_json::json!({
|
||||
|
||||
@@ -15,7 +15,9 @@ use std::sync::Arc;
|
||||
use tracing::{info, warn};
|
||||
|
||||
use super::models::{lookup_model, DEFAULT_MODEL, MODELS};
|
||||
use super::polling::{extract_response_text, extract_thinking_content, is_response_done, poll_for_response};
|
||||
use super::polling::{
|
||||
extract_response_text, extract_thinking_content, is_response_done, poll_for_response,
|
||||
};
|
||||
use super::util::{err_response, upstream_err_response};
|
||||
use super::AppState;
|
||||
use crate::mitm::store::PendingToolResult;
|
||||
@@ -84,7 +86,9 @@ async fn build_usage_metadata(
|
||||
store: &crate::mitm::store::MitmStore,
|
||||
cascade_id: &str,
|
||||
) -> serde_json::Value {
|
||||
let usage = store.take_usage(cascade_id).await
|
||||
let usage = store
|
||||
.take_usage(cascade_id)
|
||||
.await
|
||||
.or(store.take_usage("_latest").await);
|
||||
if let Some(usage) = usage {
|
||||
serde_json::json!({
|
||||
@@ -152,13 +156,12 @@ pub(crate) async fn handle_gemini(
|
||||
// Gemini-native inlineData format
|
||||
if image.is_none() {
|
||||
if let Some(inline) = obj.get("inlineData") {
|
||||
if let (Some(mime), Some(b64)) = (
|
||||
inline["mimeType"].as_str(),
|
||||
inline["data"].as_str(),
|
||||
) {
|
||||
if let Some(img) = super::util::parse_data_uri(
|
||||
&format!("data:{mime};base64,{b64}")
|
||||
) {
|
||||
if let (Some(mime), Some(b64)) =
|
||||
(inline["mimeType"].as_str(), inline["data"].as_str())
|
||||
{
|
||||
if let Some(img) = super::util::parse_data_uri(&format!(
|
||||
"data:{mime};base64,{b64}"
|
||||
)) {
|
||||
image = Some(img);
|
||||
}
|
||||
}
|
||||
@@ -194,7 +197,10 @@ pub(crate) async fn handle_gemini(
|
||||
if let Some(ref tools) = body.tools {
|
||||
if !tools.is_empty() {
|
||||
state.mitm_store.set_tools(tools.clone()).await;
|
||||
info!(count = tools.len(), "Stored Gemini-native tools for MITM injection");
|
||||
info!(
|
||||
count = tools.len(),
|
||||
"Stored Gemini-native tools for MITM injection"
|
||||
);
|
||||
}
|
||||
}
|
||||
if let Some(ref config) = body.tool_config {
|
||||
@@ -207,13 +213,19 @@ pub(crate) async fn handle_gemini(
|
||||
if let Some(fr) = r.get("functionResponse") {
|
||||
let name = fr["name"].as_str().unwrap_or("unknown").to_string();
|
||||
let response = fr.get("response").cloned().unwrap_or(serde_json::json!({}));
|
||||
state.mitm_store.add_tool_result(PendingToolResult {
|
||||
name,
|
||||
result: response,
|
||||
}).await;
|
||||
state
|
||||
.mitm_store
|
||||
.add_tool_result(PendingToolResult {
|
||||
name,
|
||||
result: response,
|
||||
})
|
||||
.await;
|
||||
}
|
||||
}
|
||||
info!(count = results.len(), "Stored Gemini-native tool results for MITM injection");
|
||||
info!(
|
||||
count = results.len(),
|
||||
"Stored Gemini-native tool results for MITM injection"
|
||||
);
|
||||
}
|
||||
|
||||
// Store generation parameters for MITM injection
|
||||
@@ -232,9 +244,13 @@ pub(crate) async fn handle_gemini(
|
||||
response_schema: None,
|
||||
google_search: body.google_search,
|
||||
};
|
||||
if gp.temperature.is_some() || gp.top_p.is_some() || gp.top_k.is_some()
|
||||
|| gp.max_output_tokens.is_some() || gp.stop_sequences.is_some()
|
||||
|| gp.reasoning_effort.is_some() || gp.google_search
|
||||
if gp.temperature.is_some()
|
||||
|| gp.top_p.is_some()
|
||||
|| gp.top_k.is_some()
|
||||
|| gp.max_output_tokens.is_some()
|
||||
|| gp.stop_sequences.is_some()
|
||||
|| gp.reasoning_effort.is_some()
|
||||
|| gp.google_search
|
||||
{
|
||||
state.mitm_store.set_generation_params(gp).await;
|
||||
} else {
|
||||
@@ -277,12 +293,13 @@ pub(crate) async fn handle_gemini(
|
||||
// Store image for MITM injection (LS doesn't forward images to Google API)
|
||||
if let Some(ref img) = image {
|
||||
use base64::Engine;
|
||||
state.mitm_store.set_pending_image(
|
||||
crate::mitm::store::PendingImage {
|
||||
state
|
||||
.mitm_store
|
||||
.set_pending_image(crate::mitm::store::PendingImage {
|
||||
base64_data: base64::engine::general_purpose::STANDARD.encode(&img.data),
|
||||
mime_type: img.mime_type.clone(),
|
||||
}
|
||||
).await;
|
||||
})
|
||||
.await;
|
||||
}
|
||||
match state
|
||||
.backend
|
||||
@@ -372,7 +389,11 @@ async fn gemini_sync(
|
||||
|
||||
// Check for completed text response
|
||||
if state.mitm_store.is_response_complete() {
|
||||
let text = state.mitm_store.take_response_text().await.unwrap_or_default();
|
||||
let text = state
|
||||
.mitm_store
|
||||
.take_response_text()
|
||||
.await
|
||||
.unwrap_or_default();
|
||||
let thinking = state.mitm_store.take_thinking_text().await;
|
||||
|
||||
// Guard against stale response_complete with no data
|
||||
|
||||
@@ -44,7 +44,6 @@ pub fn router(state: Arc<AppState>) -> Router {
|
||||
post(completions::handle_completions),
|
||||
)
|
||||
.route("/v1/gemini", post(gemini::handle_gemini))
|
||||
|
||||
.route("/v1/models", get(handle_models))
|
||||
.route("/v1/sessions", get(handle_list_sessions))
|
||||
.route("/v1/sessions/{id}", delete(handle_delete_session))
|
||||
@@ -106,9 +105,7 @@ async fn handle_models() -> Json<serde_json::Value> {
|
||||
Json(serde_json::json!({"object": "list", "data": models}))
|
||||
}
|
||||
|
||||
async fn handle_list_sessions(
|
||||
State(state): State<Arc<AppState>>,
|
||||
) -> Json<serde_json::Value> {
|
||||
async fn handle_list_sessions(State(state): State<Arc<AppState>>) -> Json<serde_json::Value> {
|
||||
let sessions = state.sessions.list_sessions().await;
|
||||
Json(serde_json::json!({"sessions": sessions}))
|
||||
}
|
||||
@@ -155,9 +152,7 @@ async fn handle_set_token(
|
||||
)
|
||||
}
|
||||
|
||||
async fn handle_usage(
|
||||
State(state): State<Arc<AppState>>,
|
||||
) -> Json<serde_json::Value> {
|
||||
async fn handle_usage(State(state): State<Arc<AppState>>) -> Json<serde_json::Value> {
|
||||
let stats = state.mitm_store.stats().await;
|
||||
Json(serde_json::json!({
|
||||
"mitm": {
|
||||
@@ -174,9 +169,7 @@ async fn handle_usage(
|
||||
}))
|
||||
}
|
||||
|
||||
async fn handle_quota(
|
||||
State(state): State<Arc<AppState>>,
|
||||
) -> Json<serde_json::Value> {
|
||||
async fn handle_quota(State(state): State<Arc<AppState>>) -> Json<serde_json::Value> {
|
||||
let snap = state.quota_store.snapshot().await;
|
||||
Json(serde_json::to_value(snap).unwrap_or_default())
|
||||
}
|
||||
|
||||
@@ -84,14 +84,8 @@ pub(crate) fn extract_model_usage(steps: &[serde_json::Value]) -> Option<ModelUs
|
||||
return Some(ModelUsage {
|
||||
input_tokens: input,
|
||||
output_tokens: output,
|
||||
api_provider: usage["apiProvider"]
|
||||
.as_str()
|
||||
.unwrap_or("")
|
||||
.to_string(),
|
||||
model: usage["model"]
|
||||
.as_str()
|
||||
.unwrap_or("")
|
||||
.to_string(),
|
||||
api_provider: usage["apiProvider"].as_str().unwrap_or("").to_string(),
|
||||
model: usage["model"].as_str().unwrap_or("").to_string(),
|
||||
});
|
||||
}
|
||||
}
|
||||
@@ -263,23 +257,36 @@ pub(crate) async fn poll_for_response(
|
||||
} else {
|
||||
info!(
|
||||
"Response done ({short_id}), {:.1}s, {} chars (no usage){}{}",
|
||||
elapsed, text.len(),
|
||||
thinking.as_ref().map_or(String::new(), |t| format!(", thinking: {} chars", t.len())),
|
||||
if thinking_signature.is_some() { ", has sig" } else { "" }
|
||||
elapsed,
|
||||
text.len(),
|
||||
thinking.as_ref().map_or(String::new(), |t| format!(
|
||||
", thinking: {} chars",
|
||||
t.len()
|
||||
)),
|
||||
if thinking_signature.is_some() {
|
||||
", has sig"
|
||||
} else {
|
||||
""
|
||||
}
|
||||
);
|
||||
}
|
||||
return PollResult { text, usage, thinking_signature, thinking, thinking_duration, upstream_error: None };
|
||||
return PollResult {
|
||||
text,
|
||||
usage,
|
||||
thinking_signature,
|
||||
thinking,
|
||||
thinking_duration,
|
||||
upstream_error: None,
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback: check trajectory IDLE status (catches edge cases)
|
||||
// Only check every 5th poll to reduce network calls
|
||||
if step_count > 4 && step_count % 5 == 0 {
|
||||
if let Ok((ts, td)) = state.backend.get_trajectory(cascade_id).await
|
||||
{
|
||||
if let Ok((ts, td)) = state.backend.get_trajectory(cascade_id).await {
|
||||
if ts == 200 {
|
||||
let run_status =
|
||||
td["status"].as_str().unwrap_or("");
|
||||
let run_status = td["status"].as_str().unwrap_or("");
|
||||
if run_status.contains("IDLE") {
|
||||
let text = extract_response_text(steps);
|
||||
if !text.is_empty() {
|
||||
@@ -293,7 +300,14 @@ pub(crate) async fn poll_for_response(
|
||||
elapsed,
|
||||
text.len()
|
||||
);
|
||||
return PollResult { text, usage, thinking_signature, thinking, thinking_duration, upstream_error: None };
|
||||
return PollResult {
|
||||
text,
|
||||
usage,
|
||||
thinking_signature,
|
||||
thinking,
|
||||
thinking_duration,
|
||||
upstream_error: None,
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -14,12 +14,15 @@ use std::sync::Arc;
|
||||
use tracing::{debug, info};
|
||||
|
||||
use super::models::{lookup_model, DEFAULT_MODEL, MODELS};
|
||||
use super::polling::{extract_response_text, is_response_done, poll_for_response, extract_model_usage, extract_thinking_signature, extract_thinking_content};
|
||||
use super::polling::{
|
||||
extract_model_usage, extract_response_text, extract_thinking_content,
|
||||
extract_thinking_signature, is_response_done, poll_for_response,
|
||||
};
|
||||
use super::types::*;
|
||||
use super::util::{err_response, upstream_err_response, now_unix, responses_sse_event};
|
||||
use super::util::{err_response, now_unix, responses_sse_event, upstream_err_response};
|
||||
use super::AppState;
|
||||
use crate::mitm::modify::{openai_tool_choice_to_gemini, openai_tools_to_gemini};
|
||||
use crate::mitm::store::PendingToolResult;
|
||||
use crate::mitm::modify::{openai_tools_to_gemini, openai_tool_choice_to_gemini};
|
||||
|
||||
// ─── Input extraction ────────────────────────────────────────────────────────
|
||||
|
||||
@@ -35,7 +38,11 @@ struct ToolResultInput {
|
||||
fn extract_responses_input(
|
||||
input: &serde_json::Value,
|
||||
instructions: Option<&str>,
|
||||
) -> (String, Vec<ToolResultInput>, Option<crate::proto::ImageData>) {
|
||||
) -> (
|
||||
String,
|
||||
Vec<ToolResultInput>,
|
||||
Option<crate::proto::ImageData>,
|
||||
) {
|
||||
let mut tool_results: Vec<ToolResultInput> = Vec::new();
|
||||
let mut image: Option<crate::proto::ImageData> = None;
|
||||
|
||||
@@ -45,10 +52,9 @@ fn extract_responses_input(
|
||||
// Check for function_call_output items
|
||||
for item in items {
|
||||
if item["type"].as_str() == Some("function_call_output") {
|
||||
if let (Some(call_id), Some(output)) = (
|
||||
item["call_id"].as_str(),
|
||||
item["output"].as_str(),
|
||||
) {
|
||||
if let (Some(call_id), Some(output)) =
|
||||
(item["call_id"].as_str(), item["output"].as_str())
|
||||
{
|
||||
tool_results.push(ToolResultInput {
|
||||
call_id: call_id.to_string(),
|
||||
output: output.to_string(),
|
||||
@@ -230,24 +236,31 @@ pub(crate) async fn handle_responses(
|
||||
);
|
||||
}
|
||||
|
||||
let (user_text, tool_results, image) = extract_responses_input(&body.input, body.instructions.as_deref());
|
||||
let (user_text, tool_results, image) =
|
||||
extract_responses_input(&body.input, body.instructions.as_deref());
|
||||
|
||||
// Handle tool result submission (function_call_output in input)
|
||||
let is_tool_result_turn = !tool_results.is_empty();
|
||||
if is_tool_result_turn {
|
||||
for tr in &tool_results {
|
||||
// Look up function name from call_id
|
||||
let name = state.mitm_store.lookup_call_id(&tr.call_id).await
|
||||
let name = state
|
||||
.mitm_store
|
||||
.lookup_call_id(&tr.call_id)
|
||||
.await
|
||||
.unwrap_or_else(|| "unknown_function".to_string());
|
||||
|
||||
// Parse the output as JSON, fall back to string wrapper
|
||||
let result_value = serde_json::from_str::<serde_json::Value>(&tr.output)
|
||||
.unwrap_or_else(|_| serde_json::json!({"result": tr.output}));
|
||||
|
||||
state.mitm_store.add_tool_result(PendingToolResult {
|
||||
name,
|
||||
result: result_value,
|
||||
}).await;
|
||||
state
|
||||
.mitm_store
|
||||
.add_tool_result(PendingToolResult {
|
||||
name,
|
||||
result: result_value,
|
||||
})
|
||||
.await;
|
||||
}
|
||||
info!(
|
||||
count = tool_results.len(),
|
||||
@@ -275,7 +288,10 @@ pub(crate) async fn handle_responses(
|
||||
let gemini_tools = openai_tools_to_gemini(tools);
|
||||
if !gemini_tools.is_empty() {
|
||||
state.mitm_store.set_tools(gemini_tools).await;
|
||||
info!(count = tools.len(), "Stored client tools for MITM injection");
|
||||
info!(
|
||||
count = tools.len(),
|
||||
"Stored client tools for MITM injection"
|
||||
);
|
||||
}
|
||||
}
|
||||
if let Some(ref choice) = body.tool_choice {
|
||||
@@ -289,7 +305,9 @@ pub(crate) async fn handle_responses(
|
||||
let fmt_type = text_val["format"]["type"].as_str().unwrap_or("text");
|
||||
if fmt_type == "json_schema" {
|
||||
let name = text_val["format"]["name"].as_str().map(|s| s.to_string());
|
||||
let schema = text_val["format"]["schema"].as_object().map(|o| serde_json::Value::Object(o.clone()));
|
||||
let schema = text_val["format"]["schema"]
|
||||
.as_object()
|
||||
.map(|o| serde_json::Value::Object(o.clone()));
|
||||
let strict = text_val["format"]["strict"].as_bool();
|
||||
let tf = TextFormat {
|
||||
format: TextFormatInner {
|
||||
@@ -321,9 +339,13 @@ pub(crate) async fn handle_responses(
|
||||
response_schema,
|
||||
google_search: has_web_search,
|
||||
};
|
||||
if gp.temperature.is_some() || gp.top_p.is_some() || gp.max_output_tokens.is_some()
|
||||
|| gp.reasoning_effort.is_some() || gp.response_mime_type.is_some()
|
||||
|| gp.response_schema.is_some() || gp.google_search
|
||||
if gp.temperature.is_some()
|
||||
|| gp.top_p.is_some()
|
||||
|| gp.max_output_tokens.is_some()
|
||||
|| gp.reasoning_effort.is_some()
|
||||
|| gp.response_mime_type.is_some()
|
||||
|| gp.response_schema.is_some()
|
||||
|| gp.google_search
|
||||
{
|
||||
state.mitm_store.set_generation_params(gp).await;
|
||||
} else {
|
||||
@@ -331,10 +353,7 @@ pub(crate) async fn handle_responses(
|
||||
}
|
||||
}
|
||||
|
||||
let response_id = format!(
|
||||
"resp_{}",
|
||||
uuid::Uuid::new_v4().to_string().replace('-', "")
|
||||
);
|
||||
let response_id = format!("resp_{}", uuid::Uuid::new_v4().to_string().replace('-', ""));
|
||||
|
||||
// Session/conversation management
|
||||
let session_id_str = extract_conversation_id(&body.conversation);
|
||||
@@ -371,12 +390,13 @@ pub(crate) async fn handle_responses(
|
||||
// Store image for MITM injection (LS doesn't forward images to Google API)
|
||||
if let Some(ref img) = image {
|
||||
use base64::Engine;
|
||||
state.mitm_store.set_pending_image(
|
||||
crate::mitm::store::PendingImage {
|
||||
state
|
||||
.mitm_store
|
||||
.set_pending_image(crate::mitm::store::PendingImage {
|
||||
base64_data: base64::engine::general_purpose::STANDARD.encode(&img.data),
|
||||
mime_type: img.mime_type.clone(),
|
||||
}
|
||||
).await;
|
||||
})
|
||||
.await;
|
||||
}
|
||||
match state
|
||||
.backend
|
||||
@@ -419,21 +439,32 @@ pub(crate) async fn handle_responses(
|
||||
metadata: body.metadata.clone().unwrap_or(serde_json::json!({})),
|
||||
max_tool_calls: body.max_tool_calls,
|
||||
reasoning_effort: body.reasoning_effort.clone(),
|
||||
tool_choice: body.tool_choice.clone().unwrap_or(serde_json::json!("auto")),
|
||||
tool_choice: body
|
||||
.tool_choice
|
||||
.clone()
|
||||
.unwrap_or(serde_json::json!("auto")),
|
||||
tools: body.tools.clone().unwrap_or_default(),
|
||||
text_format,
|
||||
};
|
||||
|
||||
if body.stream {
|
||||
handle_responses_stream(
|
||||
state, response_id, model_name.to_string(), cascade_id,
|
||||
body.timeout, req_params,
|
||||
state,
|
||||
response_id,
|
||||
model_name.to_string(),
|
||||
cascade_id,
|
||||
body.timeout,
|
||||
req_params,
|
||||
)
|
||||
.await
|
||||
} else {
|
||||
handle_responses_sync(
|
||||
state, response_id, model_name.to_string(), cascade_id,
|
||||
body.timeout, req_params,
|
||||
state,
|
||||
response_id,
|
||||
model_name.to_string(),
|
||||
cascade_id,
|
||||
body.timeout,
|
||||
req_params,
|
||||
)
|
||||
.await
|
||||
}
|
||||
@@ -485,7 +516,9 @@ async fn usage_from_poll(
|
||||
if let Some(u) = mitm_store.peek_usage(key).await {
|
||||
if u.thinking_output_tokens > 0 && u.thinking_text.is_none() {
|
||||
// Call 2 hasn't arrived yet — wait briefly for the merge
|
||||
tracing::debug!("MITM: thinking tokens found but no text, waiting for summary merge...");
|
||||
tracing::debug!(
|
||||
"MITM: thinking tokens found but no text, waiting for summary merge..."
|
||||
);
|
||||
for _ in 0..10 {
|
||||
tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
|
||||
if let Some(u2) = mitm_store.peek_usage(key).await {
|
||||
@@ -526,13 +559,18 @@ async fn usage_from_poll(
|
||||
|
||||
// Priority 2: LS trajectory data (from CHECKPOINT/metadata steps)
|
||||
if let Some(u) = model_usage {
|
||||
return (Usage {
|
||||
input_tokens: u.input_tokens,
|
||||
input_tokens_details: InputTokensDetails { cached_tokens: 0 },
|
||||
output_tokens: u.output_tokens,
|
||||
output_tokens_details: OutputTokensDetails { reasoning_tokens: 0 },
|
||||
total_tokens: u.input_tokens + u.output_tokens,
|
||||
}, None);
|
||||
return (
|
||||
Usage {
|
||||
input_tokens: u.input_tokens,
|
||||
input_tokens_details: InputTokensDetails { cached_tokens: 0 },
|
||||
output_tokens: u.output_tokens,
|
||||
output_tokens_details: OutputTokensDetails {
|
||||
reasoning_tokens: 0,
|
||||
},
|
||||
total_tokens: u.input_tokens + u.output_tokens,
|
||||
},
|
||||
None,
|
||||
);
|
||||
}
|
||||
|
||||
// Priority 3: Estimate from text lengths
|
||||
@@ -575,14 +613,22 @@ async fn handle_responses_sync(
|
||||
"call_{}",
|
||||
uuid::Uuid::new_v4().to_string().replace('-', "")[..24].to_string()
|
||||
);
|
||||
state.mitm_store.register_call_id(call_id.clone(), fc.name.clone()).await;
|
||||
state
|
||||
.mitm_store
|
||||
.register_call_id(call_id.clone(), fc.name.clone())
|
||||
.await;
|
||||
let arguments = serde_json::to_string(&fc.args).unwrap_or_default();
|
||||
output_items.push(build_function_call_output(&call_id, &fc.name, &arguments));
|
||||
output_items
|
||||
.push(build_function_call_output(&call_id, &fc.name, &arguments));
|
||||
}
|
||||
let (usage, _) = usage_from_poll(
|
||||
&state.mitm_store, &cascade_id, &None,
|
||||
¶ms.user_text, "",
|
||||
).await;
|
||||
&state.mitm_store,
|
||||
&cascade_id,
|
||||
&None,
|
||||
¶ms.user_text,
|
||||
"",
|
||||
)
|
||||
.await;
|
||||
let resp = build_response_object(
|
||||
ResponseData {
|
||||
id: response_id,
|
||||
@@ -602,12 +648,20 @@ async fn handle_responses_sync(
|
||||
|
||||
// Check for completed text response
|
||||
if state.mitm_store.is_response_complete() {
|
||||
let text = state.mitm_store.take_response_text().await.unwrap_or_default();
|
||||
let text = state
|
||||
.mitm_store
|
||||
.take_response_text()
|
||||
.await
|
||||
.unwrap_or_default();
|
||||
let thinking = state.mitm_store.take_thinking_text().await;
|
||||
let (usage, _) = usage_from_poll(
|
||||
&state.mitm_store, &cascade_id, &None,
|
||||
¶ms.user_text, &text,
|
||||
).await;
|
||||
&state.mitm_store,
|
||||
&cascade_id,
|
||||
&None,
|
||||
¶ms.user_text,
|
||||
&text,
|
||||
)
|
||||
.await;
|
||||
|
||||
let mut output_items: Vec<serde_json::Value> = Vec::new();
|
||||
if let Some(ref t) = thinking {
|
||||
@@ -658,10 +712,7 @@ async fn handle_responses_sync(
|
||||
return upstream_err_response(err);
|
||||
}
|
||||
let completed_at = now_unix();
|
||||
let msg_id = format!(
|
||||
"msg_{}",
|
||||
uuid::Uuid::new_v4().to_string().replace('-', "")
|
||||
);
|
||||
let msg_id = format!("msg_{}", uuid::Uuid::new_v4().to_string().replace('-', ""));
|
||||
|
||||
// Check for captured function calls from MITM (clears the active flag)
|
||||
let captured_tool_calls = state.mitm_store.take_any_function_calls().await;
|
||||
@@ -689,7 +740,10 @@ async fn handle_responses_sync(
|
||||
uuid::Uuid::new_v4().to_string().replace('-', "")[..24].to_string()
|
||||
);
|
||||
// Register call_id → name mapping for tool result routing
|
||||
state.mitm_store.register_call_id(call_id.clone(), fc.name.clone()).await;
|
||||
state
|
||||
.mitm_store
|
||||
.register_call_id(call_id.clone(), fc.name.clone())
|
||||
.await;
|
||||
|
||||
// Stringify args (OpenAI sends arguments as JSON string)
|
||||
let arguments = serde_json::to_string(&fc.args).unwrap_or_default();
|
||||
@@ -697,9 +751,13 @@ async fn handle_responses_sync(
|
||||
}
|
||||
|
||||
let (usage, _) = usage_from_poll(
|
||||
&state.mitm_store, &cascade_id, &poll_result.usage,
|
||||
¶ms.user_text, &poll_result.text,
|
||||
).await;
|
||||
&state.mitm_store,
|
||||
&cascade_id,
|
||||
&poll_result.usage,
|
||||
¶ms.user_text,
|
||||
&poll_result.text,
|
||||
)
|
||||
.await;
|
||||
|
||||
let resp = build_response_object(
|
||||
ResponseData {
|
||||
@@ -719,7 +777,14 @@ async fn handle_responses_sync(
|
||||
}
|
||||
|
||||
// Normal text response (no tool calls)
|
||||
let (usage, mitm_thinking) = usage_from_poll(&state.mitm_store, &cascade_id, &poll_result.usage, ¶ms.user_text, &poll_result.text).await;
|
||||
let (usage, mitm_thinking) = usage_from_poll(
|
||||
&state.mitm_store,
|
||||
&cascade_id,
|
||||
&poll_result.usage,
|
||||
¶ms.user_text,
|
||||
&poll_result.text,
|
||||
)
|
||||
.await;
|
||||
|
||||
// Thinking text priority: MITM-captured (raw API) > LS-extracted (steps)
|
||||
let thinking_text = mitm_thinking.or(poll_result.thinking);
|
||||
@@ -1560,4 +1625,3 @@ fn completion_events(
|
||||
|
||||
events
|
||||
}
|
||||
|
||||
|
||||
@@ -126,7 +126,9 @@ pub(crate) struct CompletionRequest {
|
||||
pub web_search: bool,
|
||||
}
|
||||
|
||||
fn default_n() -> u32 { 1 }
|
||||
fn default_n() -> u32 {
|
||||
1
|
||||
}
|
||||
|
||||
/// Stop sequence can be a single string or array of strings (OpenAI accepts both).
|
||||
#[derive(Deserialize, Clone)]
|
||||
@@ -254,8 +256,7 @@ pub(crate) struct OutputTokensDetails {
|
||||
pub reasoning_tokens: u64,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Clone)]
|
||||
#[derive(Default)]
|
||||
#[derive(Serialize, Clone, Default)]
|
||||
pub(crate) struct Reasoning {
|
||||
pub effort: Option<String>,
|
||||
pub summary: Option<String>,
|
||||
@@ -313,7 +314,6 @@ impl Default for Usage {
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
impl Default for TextFormat {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
|
||||
@@ -27,7 +27,9 @@ pub(crate) fn err_response(
|
||||
|
||||
/// Convert a MITM-captured upstream error from Google into an HTTP response.
|
||||
/// Maps Google's HTTP status codes and preserves the error message.
|
||||
pub(crate) fn upstream_err_response(err: &crate::mitm::store::UpstreamError) -> axum::response::Response {
|
||||
pub(crate) fn upstream_err_response(
|
||||
err: &crate::mitm::store::UpstreamError,
|
||||
) -> axum::response::Response {
|
||||
// Map Google's status code to HTTP status
|
||||
let status = StatusCode::from_u16(err.status).unwrap_or(StatusCode::BAD_GATEWAY);
|
||||
|
||||
@@ -41,7 +43,9 @@ pub(crate) fn upstream_err_response(err: &crate::mitm::store::UpstreamError) ->
|
||||
_ => "upstream_error",
|
||||
};
|
||||
|
||||
let message = err.message.clone()
|
||||
let message = err
|
||||
.message
|
||||
.clone()
|
||||
.unwrap_or_else(|| format!("Google API returned HTTP {}", err.status));
|
||||
|
||||
err_response(status, message, error_type)
|
||||
@@ -99,7 +103,8 @@ pub(crate) fn extract_image_from_content(item: &serde_json::Value) -> Option<Ima
|
||||
}
|
||||
// OpenAI Responses API format
|
||||
"input_image" => {
|
||||
let url = item["image_url"].as_str()
|
||||
let url = item["image_url"]
|
||||
.as_str()
|
||||
.or_else(|| item["url"].as_str())?;
|
||||
parse_data_uri(url)
|
||||
}
|
||||
@@ -109,5 +114,8 @@ pub(crate) fn extract_image_from_content(item: &serde_json::Value) -> Option<Ima
|
||||
|
||||
/// Extract the first image from a content array (Value::Array of content parts).
|
||||
pub(crate) fn extract_first_image(content: &serde_json::Value) -> Option<ImageData> {
|
||||
content.as_array()?.iter().find_map(extract_image_from_content)
|
||||
content
|
||||
.as_array()?
|
||||
.iter()
|
||||
.find_map(extract_image_from_content)
|
||||
}
|
||||
|
||||
@@ -48,10 +48,7 @@ static STATIC_HEADERS: LazyLock<HeaderMap> = LazyLock::new(|| {
|
||||
*CHROME_MAJOR,
|
||||
)),
|
||||
);
|
||||
h.insert(
|
||||
HeaderName::from_static("sec-ch-ua-mobile"),
|
||||
hv("?0"),
|
||||
);
|
||||
h.insert(HeaderName::from_static("sec-ch-ua-mobile"), hv("?0"));
|
||||
h.insert(
|
||||
HeaderName::from_static("sec-ch-ua-platform"),
|
||||
hv("\"Linux\""),
|
||||
@@ -72,7 +69,7 @@ impl Backend {
|
||||
// wreq with Chrome impersonation: BoringSSL + Chrome JA3/JA4 + H2 fingerprint
|
||||
let client = wreq::Client::builder()
|
||||
.emulation(wreq_util::Emulation::Chrome142)
|
||||
.cert_verification(false) // LS uses self-signed cert
|
||||
.cert_verification(false) // LS uses self-signed cert
|
||||
.verify_hostname(false)
|
||||
.build()
|
||||
.map_err(|e| format!("wreq client build failed: {e}"))?;
|
||||
@@ -86,11 +83,7 @@ impl Backend {
|
||||
/// Create a Backend with known connection details (for standalone LS).
|
||||
///
|
||||
/// Skips auto-discovery — the caller provides the port, CSRF, and OAuth token.
|
||||
pub fn new_with_config(
|
||||
port: u16,
|
||||
csrf: String,
|
||||
oauth_token: String,
|
||||
) -> Result<Self, String> {
|
||||
pub fn new_with_config(port: u16, csrf: String, oauth_token: String) -> Result<Self, String> {
|
||||
let inner = BackendInner {
|
||||
pid: "standalone".to_string(),
|
||||
csrf,
|
||||
@@ -212,10 +205,7 @@ impl Backend {
|
||||
fn common_headers(csrf: &str) -> HeaderMap {
|
||||
let mut h = STATIC_HEADERS.clone();
|
||||
if let Ok(val) = HeaderValue::from_str(csrf) {
|
||||
h.insert(
|
||||
HeaderName::from_static("x-codeium-csrf-token"),
|
||||
val,
|
||||
);
|
||||
h.insert(HeaderName::from_static("x-codeium-csrf-token"), val);
|
||||
} else {
|
||||
warn!("CSRF token contains invalid header characters, omitting");
|
||||
}
|
||||
@@ -239,8 +229,8 @@ impl Backend {
|
||||
let mut headers = Self::common_headers(&csrf);
|
||||
headers.insert("Content-Type", HeaderValue::from_static("application/json"));
|
||||
|
||||
let body_bytes = serde_json::to_vec(body)
|
||||
.map_err(|e| format!("JSON serialize error: {e}"))?;
|
||||
let body_bytes =
|
||||
serde_json::to_vec(body).map_err(|e| format!("JSON serialize error: {e}"))?;
|
||||
|
||||
let resp = self
|
||||
.client
|
||||
@@ -258,7 +248,9 @@ impl Backend {
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.unwrap_or("")
|
||||
.to_string();
|
||||
let raw = resp.bytes().await
|
||||
let raw = resp
|
||||
.bytes()
|
||||
.await
|
||||
.map_err(|e| format!("Read body error: {e}"))?;
|
||||
let resp_bytes = decompress(method, &raw, &encoding);
|
||||
// High-frequency polling methods → trace; everything else → debug
|
||||
@@ -288,11 +280,7 @@ impl Backend {
|
||||
}
|
||||
|
||||
/// Call a binary protobuf RPC method.
|
||||
pub async fn call_proto(
|
||||
&self,
|
||||
method: &str,
|
||||
body: Vec<u8>,
|
||||
) -> Result<(u16, Vec<u8>), String> {
|
||||
pub async fn call_proto(&self, method: &str, body: Vec<u8>) -> Result<(u16, Vec<u8>), String> {
|
||||
let (base, csrf) = {
|
||||
let guard = self.inner.read().await;
|
||||
(
|
||||
@@ -302,7 +290,10 @@ impl Backend {
|
||||
};
|
||||
let url = format!("{base}/{LS_SERVICE}/{method}");
|
||||
let mut headers = Self::common_headers(&csrf);
|
||||
headers.insert("Content-Type", HeaderValue::from_static("application/proto"));
|
||||
headers.insert(
|
||||
"Content-Type",
|
||||
HeaderValue::from_static("application/proto"),
|
||||
);
|
||||
|
||||
let resp = self
|
||||
.client
|
||||
@@ -350,7 +341,8 @@ impl Backend {
|
||||
text: &str,
|
||||
model_enum: u32,
|
||||
) -> Result<(u16, Vec<u8>), String> {
|
||||
self.send_message_with_image(cascade_id, text, model_enum, None).await
|
||||
self.send_message_with_image(cascade_id, text, model_enum, None)
|
||||
.await
|
||||
}
|
||||
|
||||
/// SendUserCascadeMessage with optional image attachment.
|
||||
@@ -365,7 +357,8 @@ impl Backend {
|
||||
if token.is_empty() {
|
||||
return Err("No OAuth token available".to_string());
|
||||
}
|
||||
let proto = crate::proto::build_request_with_image(cascade_id, text, &token, model_enum, image);
|
||||
let proto =
|
||||
crate::proto::build_request_with_image(cascade_id, text, &token, model_enum, image);
|
||||
if image.is_some() {
|
||||
tracing::info!(
|
||||
proto_size = proto.len(),
|
||||
@@ -376,10 +369,7 @@ impl Backend {
|
||||
}
|
||||
|
||||
/// GetCascadeTrajectorySteps → JSON with steps array.
|
||||
pub async fn get_steps(
|
||||
&self,
|
||||
cascade_id: &str,
|
||||
) -> Result<(u16, serde_json::Value), String> {
|
||||
pub async fn get_steps(&self, cascade_id: &str) -> Result<(u16, serde_json::Value), String> {
|
||||
let body = serde_json::json!({"cascadeId": cascade_id});
|
||||
self.call_json("GetCascadeTrajectorySteps", &body).await
|
||||
}
|
||||
@@ -415,7 +405,10 @@ impl Backend {
|
||||
});
|
||||
|
||||
let mut headers = Self::common_headers(&csrf);
|
||||
headers.insert("Content-Type", HeaderValue::from_static("application/connect+json"));
|
||||
headers.insert(
|
||||
"Content-Type",
|
||||
HeaderValue::from_static("application/connect+json"),
|
||||
);
|
||||
headers.insert("Connect-Protocol-Version", HeaderValue::from_static("1"));
|
||||
|
||||
// Connect protocol envelope: [flags:1][length:4][payload]
|
||||
@@ -441,7 +434,8 @@ impl Backend {
|
||||
return Err(format!("{rpc_method} failed: {status} — {err_text}"));
|
||||
}
|
||||
|
||||
let resp_ct = resp.headers()
|
||||
let resp_ct = resp
|
||||
.headers()
|
||||
.get("content-type")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.unwrap_or("unknown")
|
||||
@@ -495,7 +489,8 @@ impl Backend {
|
||||
&self,
|
||||
cascade_id: &str,
|
||||
) -> Result<tokio::sync::mpsc::Receiver<serde_json::Value>, String> {
|
||||
self.stream_reactive_rpc("StreamCascadeReactiveUpdates", cascade_id).await
|
||||
self.stream_reactive_rpc("StreamCascadeReactiveUpdates", cascade_id)
|
||||
.await
|
||||
}
|
||||
}
|
||||
|
||||
@@ -506,7 +501,10 @@ fn discover() -> Result<BackendInner, String> {
|
||||
// the wrapper is a shell script named language_server_linux_x64, while
|
||||
// the real binary is language_server_linux_x64.real)
|
||||
let pid_output = Command::new("sh")
|
||||
.args(["-c", "pgrep -f 'language_server_linux_x64\\.real' | head -1"])
|
||||
.args([
|
||||
"-c",
|
||||
"pgrep -f 'language_server_linux_x64\\.real' | head -1",
|
||||
])
|
||||
.output()
|
||||
.map_err(|e| format!("pgrep failed: {e}"))?;
|
||||
|
||||
@@ -564,9 +562,8 @@ fn discover() -> Result<BackendInner, String> {
|
||||
LazyLock::new(|| regex::Regex::new(r"port at (\d+) for HTTPS").unwrap());
|
||||
|
||||
for d in &dirs {
|
||||
let log_path = format!(
|
||||
"{log_base}/{d}/window1/exthost/google.antigravity/Antigravity.log"
|
||||
);
|
||||
let log_path =
|
||||
format!("{log_base}/{d}/window1/exthost/google.antigravity/Antigravity.log");
|
||||
if let Ok(contents) = fs::read_to_string(&log_path) {
|
||||
for line in contents.lines() {
|
||||
if line.contains(&pid) && line.contains("listening") && line.contains("HTTPS") {
|
||||
@@ -584,10 +581,7 @@ fn discover() -> Result<BackendInner, String> {
|
||||
|
||||
if https_port.is_empty() {
|
||||
// Fallback: find the LS HTTPS port via `ss` (when log file hasn't caught up)
|
||||
if let Ok(output) = std::process::Command::new("ss")
|
||||
.args(["-tlnp"])
|
||||
.output()
|
||||
{
|
||||
if let Ok(output) = std::process::Command::new("ss").args(["-tlnp"]).output() {
|
||||
let ss_out = String::from_utf8_lossy(&output.stdout);
|
||||
// Find listening ports for this PID — typically the first is HTTPS
|
||||
for line in ss_out.lines() {
|
||||
@@ -653,7 +647,11 @@ fn decompress(method: &str, data: &[u8], encoding: &str) -> Vec<u8> {
|
||||
Err(e) => {
|
||||
if !encoding.is_empty() {
|
||||
let preview = String::from_utf8_lossy(&data[..data.len().min(100)]);
|
||||
warn!("{method}: {encoding} decompress failed ({} bytes): {e}. Raw: {}", data.len(), preview);
|
||||
warn!(
|
||||
"{method}: {encoding} decompress failed ({} bytes): {e}. Raw: {}",
|
||||
data.len(),
|
||||
preview
|
||||
);
|
||||
}
|
||||
data.to_vec()
|
||||
}
|
||||
|
||||
@@ -115,9 +115,7 @@ fn detect_versions() -> DetectedVersions {
|
||||
const FALLBACK_CLIENT: &str = "1.16.5";
|
||||
|
||||
let Some(install_dir) = find_install_dir() else {
|
||||
tracing::warn!(
|
||||
"Could not find Antigravity install — using fallback versions"
|
||||
);
|
||||
tracing::warn!("Could not find Antigravity install — using fallback versions");
|
||||
return DetectedVersions {
|
||||
antigravity: FALLBACK_ANTIGRAVITY.to_string(),
|
||||
chrome: FALLBACK_CHROME.to_string(),
|
||||
|
||||
70
src/main.rs
70
src/main.rs
@@ -24,7 +24,10 @@ use tracing::{info, warn};
|
||||
use mitm::store::MitmStore;
|
||||
|
||||
#[derive(Parser)]
|
||||
#[command(name = "antigravity-proxy", about = "Antigravity OpenAI Proxy (stealth)")]
|
||||
#[command(
|
||||
name = "antigravity-proxy",
|
||||
about = "Antigravity OpenAI Proxy (stealth)"
|
||||
)]
|
||||
struct Cli {
|
||||
/// Port to listen on
|
||||
#[arg(long, default_value_t = 8741)]
|
||||
@@ -93,15 +96,12 @@ async fn main() {
|
||||
};
|
||||
|
||||
let filter = if log_level.is_empty() {
|
||||
tracing_subscriber::EnvFilter::try_from_default_env()
|
||||
.unwrap_or_else(|_| "warn".into())
|
||||
tracing_subscriber::EnvFilter::try_from_default_env().unwrap_or_else(|_| "warn".into())
|
||||
} else {
|
||||
tracing_subscriber::EnvFilter::new(log_level)
|
||||
};
|
||||
|
||||
tracing_subscriber::fmt()
|
||||
.with_env_filter(filter)
|
||||
.init();
|
||||
tracing_subscriber::fmt().with_env_filter(filter).init();
|
||||
|
||||
// ── Step 1: Bind main port (auto-kill stale process if needed) ─────────────
|
||||
let addr = format!("127.0.0.1:{}", cli.port);
|
||||
@@ -111,7 +111,10 @@ async fn main() {
|
||||
// Port in use — try to kill whatever's holding it
|
||||
eprintln!(" Port {} in use, killing stale process...", cli.port);
|
||||
let _ = std::process::Command::new("sh")
|
||||
.args(["-c", &format!("kill $(lsof -ti:{}) 2>/dev/null; sleep 0.3", cli.port)])
|
||||
.args([
|
||||
"-c",
|
||||
&format!("kill $(lsof -ti:{}) 2>/dev/null; sleep 0.3", cli.port),
|
||||
])
|
||||
.status();
|
||||
// Also kill any leftover standalone LS processes
|
||||
let _ = std::process::Command::new("pkill")
|
||||
@@ -180,7 +183,9 @@ async fn main() {
|
||||
Ok(c) => c,
|
||||
Err(e) => {
|
||||
eprintln!("Fatal: {e}");
|
||||
eprintln!("Hint: start Antigravity first, or remove --classic to use headless mode");
|
||||
eprintln!(
|
||||
"Hint: start Antigravity first, or remove --classic to use headless mode"
|
||||
);
|
||||
std::process::exit(1);
|
||||
}
|
||||
}
|
||||
@@ -199,13 +204,14 @@ async fn main() {
|
||||
None
|
||||
};
|
||||
|
||||
let mut ls = match standalone::StandaloneLS::spawn(&main_config, mitm_cfg.as_ref(), headless) {
|
||||
Ok(ls) => ls,
|
||||
Err(e) => {
|
||||
eprintln!("Fatal: failed to spawn standalone LS: {e}");
|
||||
std::process::exit(1);
|
||||
}
|
||||
};
|
||||
let mut ls =
|
||||
match standalone::StandaloneLS::spawn(&main_config, mitm_cfg.as_ref(), headless) {
|
||||
Ok(ls) => ls,
|
||||
Err(e) => {
|
||||
eprintln!("Fatal: failed to spawn standalone LS: {e}");
|
||||
std::process::exit(1);
|
||||
}
|
||||
};
|
||||
// Wait for it to be ready
|
||||
let rt_ls_port = ls.port;
|
||||
let rt_ls_csrf = ls.csrf.clone();
|
||||
@@ -294,7 +300,15 @@ async fn main() {
|
||||
// ── Step 5: Start serving ─────────────────────────────────────────────────
|
||||
let app = api::router(state.clone());
|
||||
|
||||
print_banner(cli.port, &pid, &https_port, &csrf, &token, &mitm_port_actual, is_standalone);
|
||||
print_banner(
|
||||
cli.port,
|
||||
&pid,
|
||||
&https_port,
|
||||
&csrf,
|
||||
&token,
|
||||
&mitm_port_actual,
|
||||
is_standalone,
|
||||
);
|
||||
info!("Listening on http://{addr}");
|
||||
|
||||
axum::serve(listener, app)
|
||||
@@ -349,7 +363,15 @@ async fn shutdown_signal() {
|
||||
}
|
||||
}
|
||||
|
||||
fn print_banner(port: u16, pid: &str, https_port: &str, csrf: &str, token: &str, mitm: &Option<(u16, String)>, is_standalone: bool) {
|
||||
fn print_banner(
|
||||
port: u16,
|
||||
pid: &str,
|
||||
https_port: &str,
|
||||
csrf: &str,
|
||||
token: &str,
|
||||
mitm: &Option<(u16, String)>,
|
||||
is_standalone: bool,
|
||||
) {
|
||||
let chrome_major = &*constants::CHROME_MAJOR;
|
||||
let ver = crate::constants::antigravity_version();
|
||||
|
||||
@@ -401,7 +423,11 @@ fn print_banner(port: u16, pid: &str, https_port: &str, csrf: &str, token: &str,
|
||||
println!();
|
||||
|
||||
// Status line
|
||||
let mitm_tag = if mitm.is_some() { "\x1b[32mmitm\x1b[0m" } else { "\x1b[31mmitm\x1b[0m" };
|
||||
let mitm_tag = if mitm.is_some() {
|
||||
"\x1b[32mmitm\x1b[0m"
|
||||
} else {
|
||||
"\x1b[31mmitm\x1b[0m"
|
||||
};
|
||||
println!(" \x1b[2mstealth:\x1b[0m \x1b[32mwarmup\x1b[0m \x1b[32mheartbeat\x1b[0m \x1b[32mjitter\x1b[0m {mitm_tag}");
|
||||
println!();
|
||||
|
||||
@@ -421,7 +447,9 @@ fn print_banner(port: u16, pid: &str, https_port: &str, csrf: &str, token: &str,
|
||||
if token == "NOT SET" {
|
||||
println!(" \x1b[1;33m[!]\x1b[0m no oauth token");
|
||||
println!(" export ANTIGRAVITY_OAUTH_TOKEN=ya29.xxx");
|
||||
println!(" curl -X POST http://127.0.0.1:{port}/v1/token -d '{{\"token\":\"ya29.xxx\"}}'");
|
||||
println!(
|
||||
" curl -X POST http://127.0.0.1:{port}/v1/token -d '{{\"token\":\"ya29.xxx\"}}'"
|
||||
);
|
||||
println!(" echo 'ya29.xxx' > ~/.config/antigravity-proxy-token");
|
||||
println!();
|
||||
}
|
||||
@@ -476,5 +504,7 @@ fn find_ls_binary_path() -> Option<String> {
|
||||
/// Get the data directory for storing MITM CA cert/key.
|
||||
fn dirs_data_dir() -> std::path::PathBuf {
|
||||
let home = std::env::var("HOME").unwrap_or_else(|_| "/tmp".to_string());
|
||||
std::path::PathBuf::from(home).join(".config").join("antigravity-proxy")
|
||||
std::path::PathBuf::from(home)
|
||||
.join(".config")
|
||||
.join("antigravity-proxy")
|
||||
}
|
||||
|
||||
@@ -4,8 +4,8 @@
|
||||
//! Dynamically generates per-domain leaf certificates signed by this CA.
|
||||
|
||||
use rcgen::{
|
||||
BasicConstraints, CertificateParams, DistinguishedName, DnType, ExtendedKeyUsagePurpose,
|
||||
IsCa, KeyPair, KeyUsagePurpose, SanType,
|
||||
BasicConstraints, CertificateParams, DistinguishedName, DnType, ExtendedKeyUsagePurpose, IsCa,
|
||||
KeyPair, KeyUsagePurpose, SanType,
|
||||
};
|
||||
use rustls::pki_types::{CertificateDer, PrivateKeyDer, PrivatePkcs8KeyDer};
|
||||
use std::collections::HashMap;
|
||||
@@ -45,15 +45,16 @@ impl MitmCa {
|
||||
let key_pem = std::fs::read_to_string(&key_path)
|
||||
.map_err(|e| format!("Failed to read CA key: {e}"))?;
|
||||
|
||||
let ca_key = KeyPair::from_pem(&key_pem)
|
||||
.map_err(|e| format!("Failed to parse CA key: {e}"))?;
|
||||
let ca_key =
|
||||
KeyPair::from_pem(&key_pem).map_err(|e| format!("Failed to parse CA key: {e}"))?;
|
||||
|
||||
// Re-create params and self-sign to get the rcgen Certificate object
|
||||
// (needed for signing leaf certs — rcgen 0.13 doesn't have from_ca_cert_pem).
|
||||
// The re-signed cert will have a different serial/notBefore, but that's fine
|
||||
// because we only use it for the rcgen signing API, NOT for the on-disk PEM.
|
||||
let params = Self::ca_params();
|
||||
let ca_signed = params.self_signed(&ca_key)
|
||||
let ca_signed = params
|
||||
.self_signed(&ca_key)
|
||||
.map_err(|e| format!("Failed to self-sign CA: {e}"))?;
|
||||
|
||||
// Use the ORIGINAL on-disk PEM cert for DER — this is what the LS trusts
|
||||
@@ -76,11 +77,12 @@ impl MitmCa {
|
||||
std::fs::create_dir_all(data_dir)
|
||||
.map_err(|e| format!("Failed to create data dir: {e}"))?;
|
||||
|
||||
let ca_key = KeyPair::generate()
|
||||
.map_err(|e| format!("Failed to generate CA key: {e}"))?;
|
||||
let ca_key =
|
||||
KeyPair::generate().map_err(|e| format!("Failed to generate CA key: {e}"))?;
|
||||
|
||||
let params = Self::ca_params();
|
||||
let ca_signed = params.self_signed(&ca_key)
|
||||
let ca_signed = params
|
||||
.self_signed(&ca_key)
|
||||
.map_err(|e| format!("Failed to self-sign CA: {e}"))?;
|
||||
|
||||
// Write cert and key to disk
|
||||
@@ -117,10 +119,7 @@ impl MitmCa {
|
||||
params.distinguished_name = dn;
|
||||
|
||||
params.is_ca = IsCa::Ca(BasicConstraints::Unconstrained);
|
||||
params.key_usages = vec![
|
||||
KeyUsagePurpose::KeyCertSign,
|
||||
KeyUsagePurpose::CrlSign,
|
||||
];
|
||||
params.key_usages = vec![KeyUsagePurpose::KeyCertSign, KeyUsagePurpose::CrlSign];
|
||||
|
||||
// Valid for 10 years
|
||||
let now = time::OffsetDateTime::now_utc();
|
||||
@@ -151,12 +150,17 @@ impl MitmCa {
|
||||
return None;
|
||||
}
|
||||
use base64::Engine;
|
||||
let der = base64::engine::general_purpose::STANDARD.decode(&b64).ok()?;
|
||||
let der = base64::engine::general_purpose::STANDARD
|
||||
.decode(&b64)
|
||||
.ok()?;
|
||||
Some(CertificateDer::from(der))
|
||||
}
|
||||
|
||||
/// Get or create a TLS ServerConfig for the given domain.
|
||||
pub async fn server_config_for_domain(&self, domain: &str) -> Result<Arc<rustls::ServerConfig>, String> {
|
||||
pub async fn server_config_for_domain(
|
||||
&self,
|
||||
domain: &str,
|
||||
) -> Result<Arc<rustls::ServerConfig>, String> {
|
||||
// Check cache first
|
||||
{
|
||||
let cache = self.domain_cache.read().await;
|
||||
@@ -172,7 +176,11 @@ impl MitmCa {
|
||||
dn.push(DnType::CommonName, domain);
|
||||
params.distinguished_name = dn;
|
||||
|
||||
params.subject_alt_names = vec![SanType::DnsName(domain.try_into().map_err(|e| format!("Invalid domain: {e}"))?)];
|
||||
params.subject_alt_names = vec![SanType::DnsName(
|
||||
domain
|
||||
.try_into()
|
||||
.map_err(|e| format!("Invalid domain: {e}"))?,
|
||||
)];
|
||||
params.extended_key_usages = vec![ExtendedKeyUsagePurpose::ServerAuth];
|
||||
params.key_usages = vec![
|
||||
KeyUsagePurpose::DigitalSignature,
|
||||
@@ -184,10 +192,11 @@ impl MitmCa {
|
||||
params.not_before = now;
|
||||
params.not_after = now + time::Duration::days(365);
|
||||
|
||||
let leaf_key = KeyPair::generate()
|
||||
.map_err(|e| format!("Failed to generate leaf key: {e}"))?;
|
||||
let leaf_key =
|
||||
KeyPair::generate().map_err(|e| format!("Failed to generate leaf key: {e}"))?;
|
||||
|
||||
let leaf_cert = params.signed_by(&leaf_key, &self.ca_signed, &self.ca_key)
|
||||
let leaf_cert = params
|
||||
.signed_by(&leaf_key, &self.ca_signed, &self.ca_key)
|
||||
.map_err(|e| format!("Failed to sign leaf cert for {domain}: {e}"))?;
|
||||
|
||||
// Build rustls ServerConfig
|
||||
@@ -196,10 +205,7 @@ impl MitmCa {
|
||||
|
||||
let mut config = rustls::ServerConfig::builder()
|
||||
.with_no_client_auth()
|
||||
.with_single_cert(
|
||||
vec![leaf_cert_der, self.ca_cert_der.clone()],
|
||||
leaf_key_der,
|
||||
)
|
||||
.with_single_cert(vec![leaf_cert_der, self.ca_cert_der.clone()], leaf_key_der)
|
||||
.map_err(|e| format!("Failed to build ServerConfig for {domain}: {e}"))?;
|
||||
|
||||
// Advertise both h2 and http/1.1 so gRPC clients can negotiate HTTP/2
|
||||
|
||||
@@ -92,11 +92,10 @@ impl UpstreamPool {
|
||||
.map_err(|e| format!("upstream TLS to {} failed: {e}", self.domain))?;
|
||||
|
||||
let upstream_io = TokioIo::new(upstream_tls);
|
||||
let (sender, conn) =
|
||||
hyper::client::conn::http2::Builder::new(TokioExecutor::new())
|
||||
.handshake(upstream_io)
|
||||
.await
|
||||
.map_err(|e| format!("upstream h2 handshake to {} failed: {e}", self.domain))?;
|
||||
let (sender, conn) = hyper::client::conn::http2::Builder::new(TokioExecutor::new())
|
||||
.handshake(upstream_io)
|
||||
.await
|
||||
.map_err(|e| format!("upstream h2 handshake to {} failed: {e}", self.domain))?;
|
||||
|
||||
let domain = self.domain.clone();
|
||||
tokio::spawn(async move {
|
||||
@@ -215,12 +214,10 @@ async fn handle_h2_request(
|
||||
.unwrap_or(false);
|
||||
|
||||
// Check if this method carries usage data
|
||||
let is_usage_method = is_grpc
|
||||
&& USAGE_METHODS.iter().any(|m| path.contains(m));
|
||||
let is_usage_method = is_grpc && USAGE_METHODS.iter().any(|m| path.contains(m));
|
||||
|
||||
// Check if this is a streaming method
|
||||
let is_streaming = is_grpc
|
||||
&& (path.contains("Stream") || path.contains("stream"));
|
||||
let is_streaming = is_grpc && (path.contains("Stream") || path.contains("stream"));
|
||||
|
||||
debug!(
|
||||
domain,
|
||||
@@ -249,9 +246,9 @@ async fn handle_h2_request(
|
||||
warn!(error = %e, domain, "MITM H2: upstream connect failed");
|
||||
let resp = Response::builder()
|
||||
.status(502)
|
||||
.body(http_body_util::Either::Left(Full::new(
|
||||
Bytes::from(format!("upstream connect failed: {e}")),
|
||||
)))
|
||||
.body(http_body_util::Either::Left(Full::new(Bytes::from(
|
||||
format!("upstream connect failed: {e}"),
|
||||
))))
|
||||
.unwrap();
|
||||
return Ok(resp);
|
||||
}
|
||||
@@ -261,17 +258,11 @@ async fn handle_h2_request(
|
||||
let upstream_uri = http::Uri::builder()
|
||||
.scheme("https")
|
||||
.authority(domain)
|
||||
.path_and_query(
|
||||
uri.path_and_query()
|
||||
.map(|pq| pq.as_str())
|
||||
.unwrap_or("/"),
|
||||
)
|
||||
.path_and_query(uri.path_and_query().map(|pq| pq.as_str()).unwrap_or("/"))
|
||||
.build()
|
||||
.unwrap_or(uri);
|
||||
|
||||
let mut upstream_req = Request::builder()
|
||||
.method(parts.method)
|
||||
.uri(upstream_uri);
|
||||
let mut upstream_req = Request::builder().method(parts.method).uri(upstream_uri);
|
||||
|
||||
// Copy headers, skip hop-by-hop
|
||||
for (name, value) in &parts.headers {
|
||||
@@ -287,9 +278,9 @@ async fn handle_h2_request(
|
||||
Err(e) => {
|
||||
let resp = Response::builder()
|
||||
.status(502)
|
||||
.body(http_body_util::Either::Left(Full::new(
|
||||
Bytes::from(format!("build request failed: {e}")),
|
||||
)))
|
||||
.body(http_body_util::Either::Left(Full::new(Bytes::from(
|
||||
format!("build request failed: {e}"),
|
||||
))))
|
||||
.unwrap();
|
||||
return Ok(resp);
|
||||
}
|
||||
@@ -302,9 +293,9 @@ async fn handle_h2_request(
|
||||
warn!(error = %e, domain, path = %path, "MITM H2: upstream request failed");
|
||||
let resp = Response::builder()
|
||||
.status(502)
|
||||
.body(http_body_util::Either::Left(Full::new(
|
||||
Bytes::from(format!("upstream request failed: {e}")),
|
||||
)))
|
||||
.body(http_body_util::Either::Left(Full::new(Bytes::from(
|
||||
format!("upstream request failed: {e}"),
|
||||
))))
|
||||
.unwrap();
|
||||
return Ok(resp);
|
||||
}
|
||||
@@ -326,13 +317,18 @@ async fn handle_h2_request(
|
||||
|
||||
// Spawn a task to forward body chunks and tee for usage extraction
|
||||
tokio::spawn(async move {
|
||||
let mut tee_buffer = if should_track_usage { Some(Vec::new()) } else { None };
|
||||
let mut tee_buffer = if should_track_usage {
|
||||
Some(Vec::new())
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let mut body = resp_body;
|
||||
|
||||
loop {
|
||||
match body.frame().await {
|
||||
Some(Ok(frame)) => {
|
||||
if let (Some(ref mut buf), Some(data)) = (&mut tee_buffer, frame.data_ref()) {
|
||||
if let (Some(ref mut buf), Some(data)) = (&mut tee_buffer, frame.data_ref())
|
||||
{
|
||||
buf.extend_from_slice(data);
|
||||
}
|
||||
if tx.send(Ok(frame)).await.is_err() {
|
||||
@@ -354,7 +350,9 @@ async fn handle_h2_request(
|
||||
if let Some(grpc_usage) = parse_grpc_response_for_usage(&tee_buffer) {
|
||||
let usage = grpc_usage.into_api_usage(path_clone.clone());
|
||||
let cascade_hint = extract_cascade_from_grpc_request(&request_body_clone);
|
||||
store_clone.record_usage(cascade_hint.as_deref(), usage).await;
|
||||
store_clone
|
||||
.record_usage(cascade_hint.as_deref(), usage)
|
||||
.await;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -78,15 +78,21 @@ impl StreamingAccumulator {
|
||||
Self::default()
|
||||
}
|
||||
|
||||
/// Process a single SSE event.
|
||||
/// Process a single SSE event.
|
||||
pub fn process_event(&mut self, event: &Value) {
|
||||
// ── Google format: {"response": {"usageMetadata": {...}, "modelVersion": "..."}} ──
|
||||
if let Some(response) = event.get("response") {
|
||||
// Extract usage metadata (each event has cumulative counts)
|
||||
if let Some(usage) = response.get("usageMetadata") {
|
||||
self.input_tokens = usage["promptTokenCount"].as_u64().unwrap_or(self.input_tokens);
|
||||
self.output_tokens = usage["candidatesTokenCount"].as_u64().unwrap_or(self.output_tokens);
|
||||
self.thinking_tokens = usage["thoughtsTokenCount"].as_u64().unwrap_or(self.thinking_tokens);
|
||||
self.input_tokens = usage["promptTokenCount"]
|
||||
.as_u64()
|
||||
.unwrap_or(self.input_tokens);
|
||||
self.output_tokens = usage["candidatesTokenCount"]
|
||||
.as_u64()
|
||||
.unwrap_or(self.output_tokens);
|
||||
self.thinking_tokens = usage["thoughtsTokenCount"]
|
||||
.as_u64()
|
||||
.unwrap_or(self.thinking_tokens);
|
||||
}
|
||||
if let Some(model) = response["modelVersion"].as_str() {
|
||||
self.model = Some(model.to_string());
|
||||
@@ -170,8 +176,10 @@ impl StreamingAccumulator {
|
||||
"message_start" => {
|
||||
if let Some(usage) = event.get("message").and_then(|m| m.get("usage")) {
|
||||
self.input_tokens = usage["input_tokens"].as_u64().unwrap_or(0);
|
||||
self.cache_creation_input_tokens = usage["cache_creation_input_tokens"].as_u64().unwrap_or(0);
|
||||
self.cache_read_input_tokens = usage["cache_read_input_tokens"].as_u64().unwrap_or(0);
|
||||
self.cache_creation_input_tokens =
|
||||
usage["cache_creation_input_tokens"].as_u64().unwrap_or(0);
|
||||
self.cache_read_input_tokens =
|
||||
usage["cache_read_input_tokens"].as_u64().unwrap_or(0);
|
||||
}
|
||||
if let Some(model) = event.get("message").and_then(|m| m["model"].as_str()) {
|
||||
self.model = Some(model.to_string());
|
||||
@@ -181,7 +189,9 @@ impl StreamingAccumulator {
|
||||
}
|
||||
"message_delta" => {
|
||||
if let Some(usage) = event.get("usage") {
|
||||
self.output_tokens = usage["output_tokens"].as_u64().unwrap_or(self.output_tokens);
|
||||
self.output_tokens = usage["output_tokens"]
|
||||
.as_u64()
|
||||
.unwrap_or(self.output_tokens);
|
||||
}
|
||||
if let Some(reason) = event["delta"]["stop_reason"].as_str() {
|
||||
self.stop_reason = Some(reason.to_string());
|
||||
@@ -235,7 +245,10 @@ impl StreamingAccumulator {
|
||||
response_output_tokens: 0,
|
||||
model: self.model,
|
||||
stop_reason: self.stop_reason,
|
||||
api_provider: self.api_provider.unwrap_or_else(|| "unknown".to_string()).into(),
|
||||
api_provider: self
|
||||
.api_provider
|
||||
.unwrap_or_else(|| "unknown".to_string())
|
||||
.into(),
|
||||
grpc_method: None,
|
||||
captured_at: std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
|
||||
@@ -68,14 +68,14 @@ pub fn modify_request(body: &[u8], tool_ctx: Option<&ToolContext>) -> Option<Vec
|
||||
"system instruction: keep <identity> only ({original_len} → {} chars, -{stripped})",
|
||||
new_sys.len()
|
||||
));
|
||||
json["request"]["systemInstruction"]["parts"][0]["text"] =
|
||||
Value::String(new_sys);
|
||||
json["request"]["systemInstruction"]["parts"][0]["text"] = Value::String(new_sys);
|
||||
}
|
||||
} else {
|
||||
// No identity tag found — clear the whole thing
|
||||
changes.push(format!("system instruction: cleared ({original_len} chars)"));
|
||||
json["request"]["systemInstruction"]["parts"][0]["text"] =
|
||||
Value::String(String::new());
|
||||
changes.push(format!(
|
||||
"system instruction: cleared ({original_len} chars)"
|
||||
));
|
||||
json["request"]["systemInstruction"]["parts"][0]["text"] = Value::String(String::new());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -125,7 +125,11 @@ pub fn modify_request(body: &[u8], tool_ctx: Option<&ToolContext>) -> Option<Vec
|
||||
let mut modified = text.clone();
|
||||
|
||||
// Strip conversation summaries block
|
||||
if let Some(cleaned) = strip_between(&modified, "# Conversation History\n", "</conversation_summaries>") {
|
||||
if let Some(cleaned) = strip_between(
|
||||
&modified,
|
||||
"# Conversation History\n",
|
||||
"</conversation_summaries>",
|
||||
) {
|
||||
modified = cleaned;
|
||||
}
|
||||
|
||||
@@ -147,7 +151,9 @@ pub fn modify_request(body: &[u8], tool_ctx: Option<&ToolContext>) -> Option<Vec
|
||||
}
|
||||
|
||||
// Strip knowledge item blocks
|
||||
if let Some(cleaned) = strip_between(&modified, "Here are the ", "</knowledge_item>") {
|
||||
if let Some(cleaned) =
|
||||
strip_between(&modified, "Here are the ", "</knowledge_item>")
|
||||
{
|
||||
// Only strip if it's about knowledge items
|
||||
if cleaned.len() < modified.len() && modified.contains("knowledge item") {
|
||||
modified = cleaned;
|
||||
@@ -202,7 +208,8 @@ pub fn modify_request(body: &[u8], tool_ctx: Option<&ToolContext>) -> Option<Vec
|
||||
// Inject client-provided tools from ToolContext
|
||||
if let Some(ref ctx) = tool_ctx {
|
||||
if let Some(ref custom_tools) = ctx.tools {
|
||||
let total_decls: usize = custom_tools.iter()
|
||||
let total_decls: usize = custom_tools
|
||||
.iter()
|
||||
.filter_map(|t| t.get("functionDeclarations").and_then(|d| d.as_array()))
|
||||
.map(|a| a.len())
|
||||
.sum();
|
||||
@@ -210,7 +217,10 @@ pub fn modify_request(body: &[u8], tool_ctx: Option<&ToolContext>) -> Option<Vec
|
||||
tools.push(tool.clone());
|
||||
}
|
||||
has_custom_tools = true;
|
||||
changes.push(format!("inject {} custom tool group(s)", custom_tools.len()));
|
||||
changes.push(format!(
|
||||
"inject {} custom tool group(s)",
|
||||
custom_tools.len()
|
||||
));
|
||||
|
||||
// Override LS's VALIDATED toolConfig → AUTO for custom tools.
|
||||
// VALIDATED mode forces Google to validate function calls against a
|
||||
@@ -218,16 +228,20 @@ pub fn modify_request(body: &[u8], tool_ctx: Option<&ToolContext>) -> Option<Vec
|
||||
// that list, so they'd be rejected. AUTO lets the model freely choose
|
||||
// between text and function calls.
|
||||
if let Some(req) = json.get_mut("request").and_then(|v| v.as_object_mut()) {
|
||||
let has_validated = req.get("toolConfig")
|
||||
let has_validated = req
|
||||
.get("toolConfig")
|
||||
.and_then(|tc| tc.pointer("/functionCallingConfig/mode"))
|
||||
.and_then(|m| m.as_str())
|
||||
.map_or(false, |m| m == "VALIDATED");
|
||||
if has_validated {
|
||||
req.insert("toolConfig".to_string(), serde_json::json!({
|
||||
"functionCallingConfig": {
|
||||
"mode": "AUTO"
|
||||
}
|
||||
}));
|
||||
req.insert(
|
||||
"toolConfig".to_string(),
|
||||
serde_json::json!({
|
||||
"functionCallingConfig": {
|
||||
"mode": "AUTO"
|
||||
}
|
||||
}),
|
||||
);
|
||||
changes.push("override toolConfig VALIDATED → AUTO".to_string());
|
||||
}
|
||||
}
|
||||
@@ -243,7 +257,11 @@ pub fn modify_request(body: &[u8], tool_ctx: Option<&ToolContext>) -> Option<Vec
|
||||
if STRIP_ALL_TOOLS && !has_custom_tools {
|
||||
if let Some(req) = json.get_mut("request").and_then(|v| v.as_object_mut()) {
|
||||
// Remove the empty tools array entirely
|
||||
if req.get("tools").and_then(|v| v.as_array()).map_or(false, |a| a.is_empty()) {
|
||||
if req
|
||||
.get("tools")
|
||||
.and_then(|v| v.as_array())
|
||||
.map_or(false, |a| a.is_empty())
|
||||
{
|
||||
req.remove("tools");
|
||||
changes.push("remove empty tools array".to_string());
|
||||
}
|
||||
@@ -266,7 +284,8 @@ pub fn modify_request(body: &[u8], tool_ctx: Option<&ToolContext>) -> Option<Vec
|
||||
.as_ref()
|
||||
.and_then(|ctx| ctx.tools.as_ref())
|
||||
.map(|tools| {
|
||||
tools.iter()
|
||||
tools
|
||||
.iter()
|
||||
.filter_map(|t| t["functionDeclarations"].as_array())
|
||||
.flatten()
|
||||
.filter_map(|decl| decl["name"].as_str().map(|s| s.to_string()))
|
||||
@@ -309,7 +328,9 @@ pub fn modify_request(body: &[u8], tool_ctx: Option<&ToolContext>) -> Option<Vec
|
||||
.map_or(true, |parts| !parts.is_empty())
|
||||
});
|
||||
if stripped_fc > 0 {
|
||||
changes.push(format!("strip {stripped_fc} functionCall/Response parts from history"));
|
||||
changes.push(format!(
|
||||
"strip {stripped_fc} functionCall/Response parts from history"
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -336,16 +357,22 @@ pub fn modify_request(body: &[u8], tool_ctx: Option<&ToolContext>) -> Option<Vec
|
||||
for msg in contents.iter_mut() {
|
||||
if msg["role"].as_str() == Some("model") {
|
||||
if let Some(text) = msg["parts"][0]["text"].as_str() {
|
||||
if text.contains("Tool call completed") || text.contains("Awaiting external tool result") {
|
||||
if text.contains("Tool call completed")
|
||||
|| text.contains("Awaiting external tool result")
|
||||
{
|
||||
// Replace with functionCall parts
|
||||
let fc_parts: Vec<Value> = ctx.last_calls.iter().map(|fc| {
|
||||
serde_json::json!({
|
||||
"functionCall": {
|
||||
"name": fc.name,
|
||||
"args": fc.args,
|
||||
}
|
||||
let fc_parts: Vec<Value> = ctx
|
||||
.last_calls
|
||||
.iter()
|
||||
.map(|fc| {
|
||||
serde_json::json!({
|
||||
"functionCall": {
|
||||
"name": fc.name,
|
||||
"args": fc.args,
|
||||
}
|
||||
})
|
||||
})
|
||||
}).collect();
|
||||
.collect();
|
||||
msg["parts"] = Value::Array(fc_parts);
|
||||
changes.push("rewrite model turn with functionCall".to_string());
|
||||
break;
|
||||
@@ -355,29 +382,36 @@ pub fn modify_request(body: &[u8], tool_ctx: Option<&ToolContext>) -> Option<Vec
|
||||
}
|
||||
|
||||
// Add functionResponse as a user turn before the last user message
|
||||
let fn_response_parts: Vec<Value> = ctx.pending_results.iter().map(|r| {
|
||||
serde_json::json!({
|
||||
"functionResponse": {
|
||||
"name": r.name,
|
||||
"response": r.result,
|
||||
}
|
||||
let fn_response_parts: Vec<Value> = ctx
|
||||
.pending_results
|
||||
.iter()
|
||||
.map(|r| {
|
||||
serde_json::json!({
|
||||
"functionResponse": {
|
||||
"name": r.name,
|
||||
"response": r.result,
|
||||
}
|
||||
})
|
||||
})
|
||||
}).collect();
|
||||
.collect();
|
||||
let fn_response_turn = serde_json::json!({
|
||||
"role": "user",
|
||||
"parts": fn_response_parts,
|
||||
});
|
||||
|
||||
// Insert before the last user message
|
||||
let last_user_idx = contents.iter().rposition(|msg| {
|
||||
msg["role"].as_str() == Some("user")
|
||||
});
|
||||
let last_user_idx = contents
|
||||
.iter()
|
||||
.rposition(|msg| msg["role"].as_str() == Some("user"));
|
||||
if let Some(idx) = last_user_idx {
|
||||
contents.insert(idx, fn_response_turn);
|
||||
} else {
|
||||
contents.push(fn_response_turn);
|
||||
}
|
||||
changes.push(format!("inject {} functionResponse(s)", ctx.pending_results.len()));
|
||||
changes.push(format!(
|
||||
"inject {} functionResponse(s)",
|
||||
ctx.pending_results.len()
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -420,8 +454,10 @@ pub fn modify_request(body: &[u8], tool_ctx: Option<&ToolContext>) -> Option<Vec
|
||||
} else {
|
||||
// Not wrapped in request — try top-level (public API format)
|
||||
let gen_config = json.as_object_mut().and_then(|o| {
|
||||
Some(o.entry("generationConfig")
|
||||
.or_insert_with(|| serde_json::json!({})))
|
||||
Some(
|
||||
o.entry("generationConfig")
|
||||
.or_insert_with(|| serde_json::json!({})),
|
||||
)
|
||||
});
|
||||
if let Some(gc) = gen_config.and_then(|v| v.as_object_mut()) {
|
||||
let thinking_config = gc
|
||||
@@ -449,8 +485,10 @@ pub fn modify_request(body: &[u8], tool_ctx: Option<&ToolContext>) -> Option<Vec
|
||||
if let Some(ref gp) = ctx.generation_params {
|
||||
// Find or create generationConfig (same path as above)
|
||||
let gc = if let Some(req) = json.get_mut("request").and_then(|v| v.as_object_mut()) {
|
||||
Some(req.entry("generationConfig")
|
||||
.or_insert_with(|| serde_json::json!({})))
|
||||
Some(
|
||||
req.entry("generationConfig")
|
||||
.or_insert_with(|| serde_json::json!({})),
|
||||
)
|
||||
} else {
|
||||
json.as_object_mut().map(|o| {
|
||||
o.entry("generationConfig")
|
||||
@@ -564,8 +602,6 @@ pub fn modify_request(body: &[u8], tool_ctx: Option<&ToolContext>) -> Option<Vec
|
||||
changes.join(", ")
|
||||
);
|
||||
|
||||
|
||||
|
||||
Some(modified_bytes)
|
||||
}
|
||||
|
||||
@@ -832,8 +868,10 @@ mod tests {
|
||||
let result: Value = serde_json::from_slice(&modified).unwrap();
|
||||
|
||||
// With no ToolContext, tools should be removed entirely
|
||||
assert!(result["request"]["tools"].is_null() || result.pointer("/request/tools").is_none(),
|
||||
"tools should be removed when no custom tools provided");
|
||||
assert!(
|
||||
result["request"]["tools"].is_null() || result.pointer("/request/tools").is_none(),
|
||||
"tools should be removed when no custom tools provided"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -892,13 +930,23 @@ mod tests {
|
||||
let contents = result["request"]["contents"].as_array().unwrap();
|
||||
// Should have removed user_information, user_rules, workflows (3 messages)
|
||||
// Kept: USER_REQUEST message (with ADDITIONAL_METADATA stripped) + model response
|
||||
assert_eq!(contents.len(), 2, "should keep only user request + model response");
|
||||
assert_eq!(
|
||||
contents.len(),
|
||||
2,
|
||||
"should keep only user request + model response"
|
||||
);
|
||||
|
||||
// Check USER_REQUEST message had metadata stripped
|
||||
let user_msg = contents[0]["parts"][0]["text"].as_str().unwrap();
|
||||
assert!(user_msg.contains("Say hello"), "should keep user request");
|
||||
assert!(!user_msg.contains("ADDITIONAL_METADATA"), "should strip metadata");
|
||||
assert!(!user_msg.contains("cursor stuff"), "should strip cursor info");
|
||||
assert!(
|
||||
!user_msg.contains("ADDITIONAL_METADATA"),
|
||||
"should strip metadata"
|
||||
);
|
||||
assert!(
|
||||
!user_msg.contains("cursor stuff"),
|
||||
"should strip cursor info"
|
||||
);
|
||||
assert!(!user_msg.starts_with("Step Id:"), "should strip step id");
|
||||
|
||||
// Model response kept intact
|
||||
@@ -921,8 +969,14 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_strip_between() {
|
||||
let text = "keep this # Conversation History\nlots of stuff\n</conversation_summaries>\nand this";
|
||||
let result = strip_between(text, "# Conversation History\n", "</conversation_summaries>").unwrap();
|
||||
let text =
|
||||
"keep this # Conversation History\nlots of stuff\n</conversation_summaries>\nand this";
|
||||
let result = strip_between(
|
||||
text,
|
||||
"# Conversation History\n",
|
||||
"</conversation_summaries>",
|
||||
)
|
||||
.unwrap();
|
||||
assert_eq!(result, "keep this and this");
|
||||
}
|
||||
}
|
||||
@@ -977,7 +1031,9 @@ pub fn modify_response_chunk(chunk: &[u8]) -> Option<Vec<u8>> {
|
||||
// Replace the JSON in the result string
|
||||
result.replace_range(json_start..json_start + json_end, &new_json);
|
||||
changed = true;
|
||||
info!("MITM: rewrote functionCall in response → text placeholder for LS");
|
||||
info!(
|
||||
"MITM: rewrote functionCall in response → text placeholder for LS"
|
||||
);
|
||||
search_from = json_start + new_json.len();
|
||||
continue;
|
||||
}
|
||||
@@ -1117,7 +1173,10 @@ fn rewrite_function_calls_in_response(json: &mut Value) -> bool {
|
||||
}
|
||||
|
||||
// Try nested "response.candidates"
|
||||
if let Some(candidates) = json.pointer_mut("/response/candidates").and_then(|v| v.as_array_mut()) {
|
||||
if let Some(candidates) = json
|
||||
.pointer_mut("/response/candidates")
|
||||
.and_then(|v| v.as_array_mut())
|
||||
{
|
||||
changed |= rewrite_candidates(candidates);
|
||||
}
|
||||
|
||||
|
||||
@@ -251,7 +251,10 @@ fn looks_like_valid_message(fields: &[ProtoField], original_len: usize) -> bool
|
||||
// (e.g., a long string that happened to have a valid first-field prefix)
|
||||
if fields.len() == 1 && original_len > 100 {
|
||||
// Single-field messages of >100 bytes are suspicious unless the field is bytes/message
|
||||
matches!(&fields[0].value, ProtoValue::Bytes(_) | ProtoValue::Message(_))
|
||||
matches!(
|
||||
&fields[0].value,
|
||||
ProtoValue::Bytes(_) | ProtoValue::Message(_)
|
||||
)
|
||||
} else {
|
||||
true
|
||||
}
|
||||
@@ -328,7 +331,9 @@ fn try_extract_usage(fields: &[ProtoField]) -> Option<GrpcUsage> {
|
||||
.iter()
|
||||
.filter_map(|f| {
|
||||
if let ProtoValue::Bytes(ref b) = f.value {
|
||||
std::str::from_utf8(b).ok().map(|s| (f.number, s.to_string()))
|
||||
std::str::from_utf8(b)
|
||||
.ok()
|
||||
.map(|s| (f.number, s.to_string()))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
@@ -361,14 +366,23 @@ fn try_extract_usage(fields: &[ProtoField]) -> Option<GrpcUsage> {
|
||||
// Check if there's a model-like string (field 7 = message_id or field 11 = response_id
|
||||
// can contain model names, or model enum values map to known names)
|
||||
let has_model_string = string_fields.iter().any(|(_, s)| {
|
||||
s.contains("claude") || s.contains("gemini") || s.contains("gpt")
|
||||
|| s.starts_with("models/") || s.contains("sonnet") || s.contains("opus")
|
||||
|| s.contains("flash") || s.contains("pro")
|
||||
s.contains("claude")
|
||||
|| s.contains("gemini")
|
||||
|| s.contains("gpt")
|
||||
|| s.starts_with("models/")
|
||||
|| s.contains("sonnet")
|
||||
|| s.contains("opus")
|
||||
|| s.contains("flash")
|
||||
|| s.contains("pro")
|
||||
});
|
||||
|
||||
// Check for fields at the known ModelUsageStats field numbers
|
||||
let has_field_2 = fields.iter().any(|f| f.number == 2 && matches!(f.value, ProtoValue::Varint(_)));
|
||||
let has_field_3 = fields.iter().any(|f| f.number == 3 && matches!(f.value, ProtoValue::Varint(_)));
|
||||
let has_field_2 = fields
|
||||
.iter()
|
||||
.any(|f| f.number == 2 && matches!(f.value, ProtoValue::Varint(_)));
|
||||
let has_field_3 = fields
|
||||
.iter()
|
||||
.any(|f| f.number == 3 && matches!(f.value, ProtoValue::Varint(_)));
|
||||
|
||||
// Strong signal: has both input and output token fields
|
||||
let is_likely_usage = (has_field_2 && has_field_3) || has_model_string;
|
||||
@@ -392,8 +406,8 @@ fn try_extract_usage(fields: &[ProtoField]) -> Option<GrpcUsage> {
|
||||
// field 1 = model enum (varint, not string!)
|
||||
2 => usage.input_tokens = v,
|
||||
3 => usage.output_tokens = v,
|
||||
4 => usage.cache_write_tokens = v, // VERIFIED: field 4
|
||||
5 => usage.cache_read_tokens = v, // VERIFIED: field 5
|
||||
4 => usage.cache_write_tokens = v, // VERIFIED: field 4
|
||||
5 => usage.cache_read_tokens = v, // VERIFIED: field 5
|
||||
// field 6 = api_provider enum (varint)
|
||||
9 => usage.thinking_output_tokens = v, // VERIFIED: field 9
|
||||
10 => usage.response_output_tokens = v, // VERIFIED: field 10
|
||||
@@ -486,11 +500,11 @@ pub fn parse_grpc_response_for_usage(body: &[u8]) -> Option<GrpcUsage> {
|
||||
fn model_enum_name(enum_val: u64) -> &'static str {
|
||||
match enum_val {
|
||||
// Placeholder models (1000 + N)
|
||||
1007 => "gemini-3-pro", // MODEL_PLACEHOLDER_M7
|
||||
1008 => "gemini-3-pro-high", // MODEL_PLACEHOLDER_M8
|
||||
1012 => "claude-opus-4.5", // MODEL_PLACEHOLDER_M12
|
||||
1018 => "gemini-3-flash", // MODEL_PLACEHOLDER_M18
|
||||
1026 => "claude-opus-4.6", // MODEL_PLACEHOLDER_M26
|
||||
1007 => "gemini-3-pro", // MODEL_PLACEHOLDER_M7
|
||||
1008 => "gemini-3-pro-high", // MODEL_PLACEHOLDER_M8
|
||||
1012 => "claude-opus-4.5", // MODEL_PLACEHOLDER_M12
|
||||
1018 => "gemini-3-flash", // MODEL_PLACEHOLDER_M18
|
||||
1026 => "claude-opus-4.6", // MODEL_PLACEHOLDER_M26
|
||||
|
||||
// Claude models (named)
|
||||
281 => "claude-4-sonnet",
|
||||
@@ -629,13 +643,13 @@ mod tests {
|
||||
data.push(v as u8);
|
||||
}
|
||||
|
||||
encode_varint_field(&mut data, 1, 5); // model enum
|
||||
encode_varint_field(&mut data, 2, 1000); // input_tokens
|
||||
encode_varint_field(&mut data, 3, 500); // output_tokens
|
||||
encode_varint_field(&mut data, 4, 100); // cache_write_tokens
|
||||
encode_varint_field(&mut data, 5, 200); // cache_read_tokens
|
||||
encode_varint_field(&mut data, 9, 300); // thinking_output_tokens
|
||||
encode_varint_field(&mut data, 10, 200); // response_output_tokens
|
||||
encode_varint_field(&mut data, 1, 5); // model enum
|
||||
encode_varint_field(&mut data, 2, 1000); // input_tokens
|
||||
encode_varint_field(&mut data, 3, 500); // output_tokens
|
||||
encode_varint_field(&mut data, 4, 100); // cache_write_tokens
|
||||
encode_varint_field(&mut data, 5, 200); // cache_read_tokens
|
||||
encode_varint_field(&mut data, 9, 300); // thinking_output_tokens
|
||||
encode_varint_field(&mut data, 10, 200); // response_output_tokens
|
||||
|
||||
let fields = decode_proto(&data);
|
||||
let usage = try_extract_usage(&fields).expect("should extract usage");
|
||||
|
||||
@@ -11,8 +11,7 @@
|
||||
|
||||
use super::ca::MitmCa;
|
||||
use super::intercept::{
|
||||
extract_cascade_hint, parse_non_streaming_response, parse_streaming_chunk,
|
||||
StreamingAccumulator,
|
||||
extract_cascade_hint, parse_non_streaming_response, parse_streaming_chunk, StreamingAccumulator,
|
||||
};
|
||||
use super::store::MitmStore;
|
||||
use std::sync::Arc;
|
||||
@@ -54,7 +53,6 @@ pub struct MitmConfig {
|
||||
pub modify_requests: bool,
|
||||
}
|
||||
|
||||
|
||||
/// Run the MITM proxy server.
|
||||
///
|
||||
/// Returns (port, task_handle) — port it's listening on, handle to abort on shutdown.
|
||||
@@ -84,7 +82,8 @@ pub async fn run(
|
||||
let ca = ca.clone();
|
||||
let store = store.clone();
|
||||
tokio::spawn(async move {
|
||||
if let Err(e) = handle_connection(stream, ca, store, modify_requests).await {
|
||||
if let Err(e) = handle_connection(stream, ca, store, modify_requests).await
|
||||
{
|
||||
warn!(error = %e, "MITM connection error");
|
||||
}
|
||||
});
|
||||
@@ -131,8 +130,7 @@ async fn handle_connection(
|
||||
.await
|
||||
.map_err(|e| format!("Peek ClientHello: {e}"))?;
|
||||
|
||||
let domain = extract_sni(&hello_buf[..n])
|
||||
.unwrap_or_else(|| "unknown".to_string());
|
||||
let domain = extract_sni(&hello_buf[..n]).unwrap_or_else(|| "unknown".to_string());
|
||||
|
||||
info!(domain, "MITM: transparent redirect (iptables)");
|
||||
|
||||
@@ -224,22 +222,30 @@ fn extract_sni(buf: &[u8]) -> Option<String> {
|
||||
let mut pos = 34; // skip version + random
|
||||
|
||||
// Session ID
|
||||
if pos >= body.len() { return None; }
|
||||
if pos >= body.len() {
|
||||
return None;
|
||||
}
|
||||
let sid_len = body[pos] as usize;
|
||||
pos += 1 + sid_len;
|
||||
|
||||
// Cipher suites
|
||||
if pos + 2 > body.len() { return None; }
|
||||
if pos + 2 > body.len() {
|
||||
return None;
|
||||
}
|
||||
let cs_len = u16::from_be_bytes([body[pos], body[pos + 1]]) as usize;
|
||||
pos += 2 + cs_len;
|
||||
|
||||
// Compression methods
|
||||
if pos >= body.len() { return None; }
|
||||
if pos >= body.len() {
|
||||
return None;
|
||||
}
|
||||
let cm_len = body[pos] as usize;
|
||||
pos += 1 + cm_len;
|
||||
|
||||
// Extensions
|
||||
if pos + 2 > body.len() { return None; }
|
||||
if pos + 2 > body.len() {
|
||||
return None;
|
||||
}
|
||||
let ext_len = u16::from_be_bytes([body[pos], body[pos + 1]]) as usize;
|
||||
pos += 2;
|
||||
let ext_end = pos + ext_len.min(body.len() - pos);
|
||||
@@ -304,32 +310,32 @@ async fn handle_intercepted(
|
||||
info!(domain, "MITM: intercepting TLS");
|
||||
|
||||
// Get or create server TLS config for this domain
|
||||
let server_config = ca
|
||||
.server_config_for_domain(domain)
|
||||
.await?;
|
||||
let server_config = ca.server_config_for_domain(domain).await?;
|
||||
|
||||
let acceptor = TlsAcceptor::from(server_config);
|
||||
|
||||
// Perform TLS handshake with the client (LS) — 10s timeout
|
||||
let tls_stream = match tokio::time::timeout(
|
||||
std::time::Duration::from_secs(10),
|
||||
acceptor.accept(stream),
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(Ok(s)) => s,
|
||||
Ok(Err(e)) => {
|
||||
warn!(domain, error = %e, "MITM: TLS handshake FAILED (client rejected cert?)");
|
||||
return Err(format!("TLS handshake with client failed for {domain}: {e}"));
|
||||
}
|
||||
Err(_) => {
|
||||
warn!(domain, "MITM: TLS handshake TIMED OUT after 10s");
|
||||
return Err(format!("TLS handshake timed out for {domain}"));
|
||||
}
|
||||
};
|
||||
let tls_stream =
|
||||
match tokio::time::timeout(std::time::Duration::from_secs(10), acceptor.accept(stream))
|
||||
.await
|
||||
{
|
||||
Ok(Ok(s)) => s,
|
||||
Ok(Err(e)) => {
|
||||
warn!(domain, error = %e, "MITM: TLS handshake FAILED (client rejected cert?)");
|
||||
return Err(format!(
|
||||
"TLS handshake with client failed for {domain}: {e}"
|
||||
));
|
||||
}
|
||||
Err(_) => {
|
||||
warn!(domain, "MITM: TLS handshake TIMED OUT after 10s");
|
||||
return Err(format!("TLS handshake timed out for {domain}"));
|
||||
}
|
||||
};
|
||||
|
||||
// Check negotiated ALPN protocol
|
||||
let alpn = tls_stream.get_ref().1
|
||||
let alpn = tls_stream
|
||||
.get_ref()
|
||||
.1
|
||||
.alpn_protocol()
|
||||
.map(|p| String::from_utf8_lossy(p).to_string());
|
||||
|
||||
@@ -339,12 +345,7 @@ async fn handle_intercepted(
|
||||
Some("h2") => {
|
||||
// HTTP/2 — use the hyper-based gRPC handler
|
||||
info!(domain, "MITM: routing to HTTP/2 handler (gRPC)");
|
||||
super::h2_handler::handle_h2_connection(
|
||||
tls_stream,
|
||||
domain.to_string(),
|
||||
store,
|
||||
)
|
||||
.await
|
||||
super::h2_handler::handle_h2_connection(tls_stream, domain.to_string(), store).await
|
||||
}
|
||||
_ => {
|
||||
// HTTP/1.1 or no ALPN — use the existing handler
|
||||
@@ -434,7 +435,10 @@ async fn handle_http_over_tls(
|
||||
.await
|
||||
{
|
||||
let out = String::from_utf8_lossy(&output.stdout);
|
||||
if let Some(ip) = out.lines().find(|l| l.parse::<std::net::Ipv4Addr>().is_ok()) {
|
||||
if let Some(ip) = out
|
||||
.lines()
|
||||
.find(|l| l.parse::<std::net::Ipv4Addr>().is_ok())
|
||||
{
|
||||
return format!("{ip}:443");
|
||||
}
|
||||
}
|
||||
@@ -458,7 +462,6 @@ async fn handle_http_over_tls(
|
||||
loop {
|
||||
// ── Read the HTTP request from the client ─────────────────────────
|
||||
let mut request_buf = Vec::with_capacity(1024 * 64);
|
||||
let mut is_our_request = false;
|
||||
|
||||
// 60s timeout on initial read (LS may open connection without sending immediately)
|
||||
const IDLE_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(60);
|
||||
@@ -513,7 +516,8 @@ async fn handle_http_over_tls(
|
||||
}
|
||||
|
||||
// Parse the HTTP request to find headers and body
|
||||
let (headers_end, content_length, _is_streaming_request) = parse_http_request_meta(&request_buf);
|
||||
let (headers_end, content_length, _is_streaming_request) =
|
||||
parse_http_request_meta(&request_buf);
|
||||
|
||||
// Try to extract cascade hint from request body
|
||||
let cascade_hint = if headers_end < request_buf.len() {
|
||||
@@ -545,6 +549,27 @@ async fn handle_http_over_tls(
|
||||
"MITM: forwarding LLM request"
|
||||
);
|
||||
|
||||
// ── Block ALL requests when one is already in-flight ─────────
|
||||
// The LS opens multiple connections and sends parallel requests.
|
||||
// When custom tools are active, only the FIRST request should reach
|
||||
// Google. Block everything else with a fake response.
|
||||
if store.is_request_in_flight() {
|
||||
info!("MITM: blocking LS request — another request already in-flight");
|
||||
let fake_response = "HTTP/1.1 200 OK\r\n\
|
||||
Content-Type: text/event-stream\r\n\
|
||||
Transfer-Encoding: chunked\r\n\
|
||||
\r\n";
|
||||
let fake_sse = "data: {\"response\":{\"candidates\":[{\"content\":{\"parts\":[{\"text\":\"Request handled.\"}],\"role\":\"model\"},\"finishReason\":\"STOP\"}],\"usageMetadata\":{\"promptTokenCount\":0,\"candidatesTokenCount\":1,\"totalTokenCount\":1}}}\n\ndata: [DONE]\n\n";
|
||||
let chunked_body = super::modify::rechunk(fake_sse.as_bytes());
|
||||
let mut response = fake_response.as_bytes().to_vec();
|
||||
response.extend_from_slice(&chunked_body);
|
||||
if let Err(e) = client.write_all(&response).await {
|
||||
warn!(error = %e, "MITM: failed to write fake response");
|
||||
}
|
||||
let _ = client.flush().await;
|
||||
continue;
|
||||
}
|
||||
|
||||
// ── Request modification ─────────────────────────────────────
|
||||
// Dechunk body → check if agent request → modify → rechunk
|
||||
if modify_requests && body_len > 0 {
|
||||
@@ -565,7 +590,11 @@ async fn handle_http_over_tls(
|
||||
let generation_params = store.get_generation_params().await;
|
||||
let pending_image = store.take_pending_image().await;
|
||||
|
||||
let tool_ctx = if tools.is_some() || !pending_results.is_empty() || generation_params.is_some() || pending_image.is_some() {
|
||||
let tool_ctx = if tools.is_some()
|
||||
|| !pending_results.is_empty()
|
||||
|| generation_params.is_some()
|
||||
|| pending_image.is_some()
|
||||
{
|
||||
Some(super::modify::ToolContext {
|
||||
tools,
|
||||
tool_config,
|
||||
@@ -578,7 +607,9 @@ async fn handle_http_over_tls(
|
||||
None
|
||||
};
|
||||
|
||||
if let Some(modified_body) = super::modify::modify_request(&raw_body, tool_ctx.as_ref()) {
|
||||
if let Some(modified_body) =
|
||||
super::modify::modify_request(&raw_body, tool_ctx.as_ref())
|
||||
{
|
||||
// Rebuild request_buf: headers (with updated Content-Length) + rechunked modified body
|
||||
let new_chunked = super::modify::rechunk(&modified_body);
|
||||
|
||||
@@ -589,38 +620,11 @@ async fn handle_http_over_tls(
|
||||
new_buf.extend_from_slice(&new_chunked);
|
||||
request_buf = new_buf;
|
||||
|
||||
// Mark this as our modified request and set in-flight flag
|
||||
is_our_request = true;
|
||||
// Mark in-flight IMMEDIATELY — blocks all subsequent requests
|
||||
store.mark_request_in_flight();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ── Block ALL LS follow-up requests once first is in-flight ──
|
||||
// When custom tools are active, we only need ONE request to Google.
|
||||
// The LS tries to send multiple requests (its own agentic loop +
|
||||
// internal requests on gemini-2.5-flash-lite). Block them ALL
|
||||
// immediately — don't wait for response_complete.
|
||||
let has_tools = store.get_tools().await.is_some();
|
||||
if has_tools && store.is_request_in_flight() && !is_our_request {
|
||||
info!(
|
||||
"MITM: blocking LS follow-up — request already in-flight"
|
||||
);
|
||||
// Return a fake SSE response that makes the LS stop
|
||||
let fake_response = "HTTP/1.1 200 OK\r\n\
|
||||
Content-Type: text/event-stream\r\n\
|
||||
Transfer-Encoding: chunked\r\n\
|
||||
\r\n";
|
||||
let fake_sse = "data: {\"response\":{\"candidates\":[{\"content\":{\"parts\":[{\"text\":\"Request handled.\"}],\"role\":\"model\"},\"finishReason\":\"STOP\"}],\"usageMetadata\":{\"promptTokenCount\":0,\"candidatesTokenCount\":1,\"totalTokenCount\":1}}}\n\ndata: [DONE]\n\n";
|
||||
let chunked_body = super::modify::rechunk(fake_sse.as_bytes());
|
||||
let mut response = fake_response.as_bytes().to_vec();
|
||||
response.extend_from_slice(&chunked_body);
|
||||
if let Err(e) = client.write_all(&response).await {
|
||||
warn!(error = %e, "MITM: failed to write fake response");
|
||||
}
|
||||
let _ = client.flush().await;
|
||||
continue; // Skip the real upstream call
|
||||
}
|
||||
} else {
|
||||
debug!(
|
||||
domain,
|
||||
@@ -674,7 +678,10 @@ async fn handle_http_over_tls(
|
||||
};
|
||||
|
||||
let n = match tokio::time::timeout(timeout, conn.read(&mut tmp)).await {
|
||||
Ok(Ok(0)) => { upstream_ok = false; break; }
|
||||
Ok(Ok(0)) => {
|
||||
upstream_ok = false;
|
||||
break;
|
||||
}
|
||||
Ok(Ok(n)) => n,
|
||||
Ok(Err(e)) => {
|
||||
debug!(domain, error = %e, "MITM: upstream read ended");
|
||||
@@ -711,7 +718,9 @@ async fn handle_http_over_tls(
|
||||
if header.name.eq_ignore_ascii_case("content-type") {
|
||||
if let Ok(v) = std::str::from_utf8(header.value) {
|
||||
content_type = v.to_string();
|
||||
if v.contains("text/event-stream") { is_streaming_response = true; }
|
||||
if v.contains("text/event-stream") {
|
||||
is_streaming_response = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
if header.name.eq_ignore_ascii_case("content-length") {
|
||||
@@ -721,12 +730,16 @@ async fn handle_http_over_tls(
|
||||
}
|
||||
if header.name.eq_ignore_ascii_case("connection") {
|
||||
if let Ok(v) = std::str::from_utf8(header.value) {
|
||||
if v.trim().eq_ignore_ascii_case("close") { upstream_ok = false; }
|
||||
if v.trim().eq_ignore_ascii_case("close") {
|
||||
upstream_ok = false;
|
||||
}
|
||||
}
|
||||
}
|
||||
if header.name.eq_ignore_ascii_case("transfer-encoding") {
|
||||
if let Ok(v) = std::str::from_utf8(header.value) {
|
||||
if v.trim().eq_ignore_ascii_case("chunked") { is_chunked = true; }
|
||||
if v.trim().eq_ignore_ascii_case("chunked") {
|
||||
is_chunked = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -749,22 +762,31 @@ async fn handle_http_over_tls(
|
||||
warn!(domain, status = http_status, body = %body_str, "MITM: upstream error response");
|
||||
|
||||
// Parse Google's error JSON: {"error": {"code": N, "message": "...", "status": "..."}}
|
||||
let (message, error_status) = serde_json::from_str::<serde_json::Value>(&body_str)
|
||||
.ok()
|
||||
.and_then(|v| {
|
||||
let err = v.get("error")?;
|
||||
let msg = err.get("message").and_then(|m| m.as_str()).map(|s| s.to_string());
|
||||
let status = err.get("status").and_then(|s| s.as_str()).map(|s| s.to_string());
|
||||
Some((msg, status))
|
||||
})
|
||||
.unwrap_or((None, None));
|
||||
let (message, error_status) =
|
||||
serde_json::from_str::<serde_json::Value>(&body_str)
|
||||
.ok()
|
||||
.and_then(|v| {
|
||||
let err = v.get("error")?;
|
||||
let msg = err
|
||||
.get("message")
|
||||
.and_then(|m| m.as_str())
|
||||
.map(|s| s.to_string());
|
||||
let status = err
|
||||
.get("status")
|
||||
.and_then(|s| s.as_str())
|
||||
.map(|s| s.to_string());
|
||||
Some((msg, status))
|
||||
})
|
||||
.unwrap_or((None, None));
|
||||
|
||||
store.set_upstream_error(super::store::UpstreamError {
|
||||
status: http_status,
|
||||
body: body_str,
|
||||
message,
|
||||
error_status,
|
||||
}).await;
|
||||
store
|
||||
.set_upstream_error(super::store::UpstreamError {
|
||||
status: http_status,
|
||||
body: body_str,
|
||||
message,
|
||||
error_status,
|
||||
})
|
||||
.await;
|
||||
}
|
||||
|
||||
// Save body for usage parsing
|
||||
@@ -779,10 +801,15 @@ async fn handle_http_over_tls(
|
||||
if !streaming_acc.function_calls.is_empty() {
|
||||
let calls: Vec<_> = streaming_acc.function_calls.drain(..).collect();
|
||||
for fc in &calls {
|
||||
store.record_function_call(cascade_hint.as_deref(), fc.clone()).await;
|
||||
store
|
||||
.record_function_call(cascade_hint.as_deref(), fc.clone())
|
||||
.await;
|
||||
}
|
||||
store.set_last_function_calls(calls.clone()).await;
|
||||
info!("MITM: stored {} function call(s) from initial body", calls.len());
|
||||
info!(
|
||||
"MITM: stored {} function call(s) from initial body",
|
||||
calls.len()
|
||||
);
|
||||
}
|
||||
|
||||
// Capture response + thinking text + grounding into MitmStore
|
||||
@@ -816,7 +843,9 @@ async fn handle_http_over_tls(
|
||||
}
|
||||
|
||||
if let Some(cl) = response_content_length {
|
||||
if response_body_buf.len() >= cl { break; }
|
||||
if response_body_buf.len() >= cl {
|
||||
break;
|
||||
}
|
||||
}
|
||||
// Check chunked terminator in initial body
|
||||
if is_chunked && has_chunked_terminator(&response_body_buf) {
|
||||
@@ -837,10 +866,15 @@ async fn handle_http_over_tls(
|
||||
if !streaming_acc.function_calls.is_empty() {
|
||||
let calls: Vec<_> = streaming_acc.function_calls.drain(..).collect();
|
||||
for fc in &calls {
|
||||
store.record_function_call(cascade_hint.as_deref(), fc.clone()).await;
|
||||
store
|
||||
.record_function_call(cascade_hint.as_deref(), fc.clone())
|
||||
.await;
|
||||
}
|
||||
store.set_last_function_calls(calls.clone()).await;
|
||||
info!("MITM: stored {} function call(s) from body chunk", calls.len());
|
||||
info!(
|
||||
"MITM: stored {} function call(s) from body chunk",
|
||||
calls.len()
|
||||
);
|
||||
}
|
||||
|
||||
// Capture response + thinking text + grounding into MitmStore
|
||||
@@ -875,7 +909,9 @@ async fn handle_http_over_tls(
|
||||
response_body_buf.extend_from_slice(chunk);
|
||||
|
||||
if let Some(cl) = response_content_length {
|
||||
if response_body_buf.len() >= cl { break; }
|
||||
if response_body_buf.len() >= cl {
|
||||
break;
|
||||
}
|
||||
}
|
||||
if is_chunked && has_chunked_terminator(&response_body_buf) {
|
||||
debug!(domain, "MITM: chunked response complete");
|
||||
@@ -912,11 +948,7 @@ async fn handle_http_over_tls(
|
||||
}
|
||||
|
||||
/// Handle a passthrough connection: transparent TCP tunnel to upstream.
|
||||
async fn handle_passthrough(
|
||||
mut client: TcpStream,
|
||||
domain: &str,
|
||||
port: u16,
|
||||
) -> Result<(), String> {
|
||||
async fn handle_passthrough(mut client: TcpStream, domain: &str, port: u16) -> Result<(), String> {
|
||||
trace!(domain, port, "MITM: transparent tunnel");
|
||||
|
||||
let mut upstream = TcpStream::connect(format!("{domain}:{port}"))
|
||||
@@ -926,7 +958,12 @@ async fn handle_passthrough(
|
||||
// Bidirectional copy
|
||||
match tokio::io::copy_bidirectional(&mut client, &mut upstream).await {
|
||||
Ok((client_to_server, server_to_client)) => {
|
||||
trace!(domain, client_to_server, server_to_client, "MITM: tunnel closed");
|
||||
trace!(
|
||||
domain,
|
||||
client_to_server,
|
||||
server_to_client,
|
||||
"MITM: tunnel closed"
|
||||
);
|
||||
}
|
||||
Err(e) => {
|
||||
trace!(domain, error = %e, "MITM: tunnel error (likely clean close)");
|
||||
@@ -945,7 +982,11 @@ fn has_chunked_terminator(body: &[u8]) -> bool {
|
||||
return false;
|
||||
}
|
||||
// Check last 7 bytes to account for possible trailing whitespace
|
||||
let tail = if body.len() > 7 { &body[body.len() - 7..] } else { body };
|
||||
let tail = if body.len() > 7 {
|
||||
&body[body.len() - 7..]
|
||||
} else {
|
||||
body
|
||||
};
|
||||
// Look for \r\n0\r\n\r\n anywhere in the tail
|
||||
tail.windows(5).any(|w| w == b"0\r\n\r\n")
|
||||
}
|
||||
|
||||
@@ -2,11 +2,11 @@
|
||||
//!
|
||||
//! The MITM proxy writes usage data here; the API handlers read from it.
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use std::sync::atomic::{AtomicBool, Ordering};
|
||||
use tokio::sync::RwLock;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
use std::sync::atomic::{AtomicBool, Ordering};
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::RwLock;
|
||||
use tracing::{debug, info};
|
||||
|
||||
/// Token usage from an intercepted API response.
|
||||
@@ -342,7 +342,9 @@ impl MitmStore {
|
||||
|
||||
/// Record a captured function call from Google's response.
|
||||
pub async fn record_function_call(&self, cascade_id: Option<&str>, fc: CapturedFunctionCall) {
|
||||
let key = cascade_id.map(|s| s.to_string()).unwrap_or_else(|| "_latest".to_string());
|
||||
let key = cascade_id
|
||||
.map(|s| s.to_string())
|
||||
.unwrap_or_else(|| "_latest".to_string());
|
||||
info!(
|
||||
cascade = %key,
|
||||
tool = %fc.name,
|
||||
@@ -377,7 +379,6 @@ impl MitmStore {
|
||||
self.awaiting_tool_result.store(false, Ordering::SeqCst);
|
||||
}
|
||||
|
||||
|
||||
/// Take any pending function calls (ignoring cascade ID).
|
||||
pub async fn take_any_function_calls(&self) -> Option<Vec<CapturedFunctionCall>> {
|
||||
let mut pending = self.pending_function_calls.write().await;
|
||||
@@ -457,8 +458,6 @@ impl MitmStore {
|
||||
|
||||
// ── Direct response capture (bypass LS) ──────────────────────────────
|
||||
|
||||
|
||||
|
||||
/// Set (replace) the captured response text.
|
||||
pub async fn set_response_text(&self, text: &str) {
|
||||
*self.captured_response_text.write().await = Some(text.to_string());
|
||||
@@ -484,8 +483,6 @@ impl MitmStore {
|
||||
self.response_complete.load(Ordering::SeqCst)
|
||||
}
|
||||
|
||||
|
||||
|
||||
/// Async version of clear_response.
|
||||
pub async fn clear_response_async(&self) {
|
||||
self.response_complete.store(false, Ordering::SeqCst);
|
||||
|
||||
@@ -293,8 +293,7 @@ mod tests {
|
||||
|
||||
let cascade_bytes = b"test-cascade-id";
|
||||
assert!(
|
||||
msg.windows(cascade_bytes.len())
|
||||
.any(|w| w == cascade_bytes),
|
||||
msg.windows(cascade_bytes.len()).any(|w| w == cascade_bytes),
|
||||
"cascade_id must appear in output"
|
||||
);
|
||||
|
||||
|
||||
17
src/quota.rs
17
src/quota.rs
@@ -93,9 +93,8 @@ impl QuotaStore {
|
||||
// Initial poll immediately.
|
||||
self.poll_once(&backend).await;
|
||||
|
||||
let mut interval = tokio::time::interval(
|
||||
std::time::Duration::from_secs(POLL_INTERVAL_SECS),
|
||||
);
|
||||
let mut interval =
|
||||
tokio::time::interval(std::time::Duration::from_secs(POLL_INTERVAL_SECS));
|
||||
interval.tick().await; // consume the first immediate tick
|
||||
|
||||
loop {
|
||||
@@ -125,7 +124,9 @@ impl QuotaStore {
|
||||
// Profile picture fetch fails through iptables — harmless, suppress
|
||||
let data_str = data.to_string();
|
||||
if data_str.contains("profile picture") {
|
||||
tracing::debug!("GetUserStatus: profile picture fetch failed (expected with iptables)");
|
||||
tracing::debug!(
|
||||
"GetUserStatus: profile picture fetch failed (expected with iptables)"
|
||||
);
|
||||
} else {
|
||||
warn!("GetUserStatus returned {status}: {data_str}");
|
||||
}
|
||||
@@ -172,9 +173,7 @@ fn parse_user_status(data: &serde_json::Value) -> QuotaSnapshot {
|
||||
.as_str()
|
||||
.unwrap_or("")
|
||||
.to_string();
|
||||
let frac = m["quotaInfo"]["remainingFraction"]
|
||||
.as_f64()
|
||||
.unwrap_or(0.0);
|
||||
let frac = m["quotaInfo"]["remainingFraction"].as_f64().unwrap_or(0.0);
|
||||
let reset_str = m["quotaInfo"]["resetTime"]
|
||||
.as_str()
|
||||
.unwrap_or("")
|
||||
@@ -224,9 +223,7 @@ fn parse_user_status(data: &serde_json::Value) -> QuotaSnapshot {
|
||||
flow_available: flow_avail,
|
||||
flow_total,
|
||||
flow_used_pct,
|
||||
flex_purchasable: pi["monthlyFlexCreditPurchaseAmount"]
|
||||
.as_i64()
|
||||
.unwrap_or(0),
|
||||
flex_purchasable: pi["monthlyFlexCreditPurchaseAmount"].as_i64().unwrap_or(0),
|
||||
can_buy_more: pi["canBuyMoreCredits"].as_bool().unwrap_or(false),
|
||||
},
|
||||
models,
|
||||
|
||||
@@ -66,9 +66,7 @@ impl SessionManager {
|
||||
msg_count: 0,
|
||||
},
|
||||
);
|
||||
return Ok(SessionResult {
|
||||
cascade_id,
|
||||
});
|
||||
return Ok(SessionResult { cascade_id });
|
||||
}
|
||||
|
||||
let sid = session_id.unwrap_or(DEFAULT_SESSION).to_string();
|
||||
@@ -111,9 +109,7 @@ impl SessionManager {
|
||||
},
|
||||
);
|
||||
}
|
||||
Ok(SessionResult {
|
||||
cascade_id,
|
||||
})
|
||||
Ok(SessionResult { cascade_id })
|
||||
}
|
||||
|
||||
/// List all active sessions.
|
||||
@@ -146,7 +142,5 @@ impl SessionManager {
|
||||
|
||||
fn cleanup_expired(sessions: &mut HashMap<String, Session>) {
|
||||
let now = Instant::now();
|
||||
sessions.retain(|_, s| {
|
||||
now.duration_since(s.last_used).as_secs() < SESSION_TTL_SECS
|
||||
});
|
||||
sessions.retain(|_, s| now.duration_since(s.last_used).as_secs() < SESSION_TTL_SECS);
|
||||
}
|
||||
|
||||
210
src/snapshot.rs
210
src/snapshot.rs
@@ -10,16 +10,44 @@ use std::io::{self, Read};
|
||||
// ── Domain metadata ──────────────────────────────────────────────────────────
|
||||
|
||||
const DOMAIN_INFO: &[(&str, &str, &str)] = &[
|
||||
("antigravity-unleash.goog", "Feature Flags", "Unleash SDK — controls A/B tests and feature rollouts"),
|
||||
("daily-cloudcode-pa.googleapis.com", "LLM API (gRPC)", "Primary Gemini/Claude API endpoint"),
|
||||
("cloudcode-pa.googleapis.com", "LLM API (gRPC)", "Production Gemini/Claude API endpoint"),
|
||||
("api.anthropic.com", "Claude API", "Direct Anthropic API calls"),
|
||||
("lh3.googleusercontent.com", "Profile Picture", "User avatar"),
|
||||
(
|
||||
"antigravity-unleash.goog",
|
||||
"Feature Flags",
|
||||
"Unleash SDK — controls A/B tests and feature rollouts",
|
||||
),
|
||||
(
|
||||
"daily-cloudcode-pa.googleapis.com",
|
||||
"LLM API (gRPC)",
|
||||
"Primary Gemini/Claude API endpoint",
|
||||
),
|
||||
(
|
||||
"cloudcode-pa.googleapis.com",
|
||||
"LLM API (gRPC)",
|
||||
"Production Gemini/Claude API endpoint",
|
||||
),
|
||||
(
|
||||
"api.anthropic.com",
|
||||
"Claude API",
|
||||
"Direct Anthropic API calls",
|
||||
),
|
||||
(
|
||||
"lh3.googleusercontent.com",
|
||||
"Profile Picture",
|
||||
"User avatar",
|
||||
),
|
||||
("play.googleapis.com", "Telemetry", "Google Play telemetry"),
|
||||
("firebaseinstallations.googleapis.com", "Firebase", "Installation tracking"),
|
||||
(
|
||||
"firebaseinstallations.googleapis.com",
|
||||
"Firebase",
|
||||
"Installation tracking",
|
||||
),
|
||||
("oauth2.googleapis.com", "OAuth", "Token refresh/exchange"),
|
||||
("speech.googleapis.com", "Speech", "Voice input processing"),
|
||||
("modelarmor.googleapis.com", "Safety", "Content safety/filtering"),
|
||||
(
|
||||
"modelarmor.googleapis.com",
|
||||
"Safety",
|
||||
"Content safety/filtering",
|
||||
),
|
||||
];
|
||||
|
||||
fn domain_label(domain: &str) -> (&str, &str) {
|
||||
@@ -57,8 +85,8 @@ struct HttpExchange {
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq)]
|
||||
enum Direction {
|
||||
Outgoing, // LS → upstream
|
||||
Incoming, // external → LS (our curl calls)
|
||||
Outgoing, // LS → upstream
|
||||
Incoming, // external → LS (our curl calls)
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
@@ -101,10 +129,12 @@ impl Snapshot {
|
||||
|
||||
// LS process logs
|
||||
if (line.starts_with('I') || line.starts_with('W') || line.starts_with('E'))
|
||||
&& line.len() > 4 && line.chars().nth(1).is_some_and(|c| c.is_ascii_digit()) {
|
||||
snap.ls_logs.push(line.to_string());
|
||||
continue;
|
||||
}
|
||||
&& line.len() > 4
|
||||
&& line.chars().nth(1).is_some_and(|c| c.is_ascii_digit())
|
||||
{
|
||||
snap.ls_logs.push(line.to_string());
|
||||
continue;
|
||||
}
|
||||
if line.contains("maxprocs:") {
|
||||
snap.ls_logs.push(line.to_string());
|
||||
continue;
|
||||
@@ -128,8 +158,15 @@ impl Snapshot {
|
||||
if let Some((key, val)) = extract_header(line, "Transport encoding header") {
|
||||
if key == ":method" {
|
||||
// Finalize previous exchange
|
||||
if current_pseudo.contains_key(":path") || current_pseudo.contains_key(":method") {
|
||||
snap.finalize_exchange(¤t_pseudo, ¤t_headers, current_direction, current_stream.clone());
|
||||
if current_pseudo.contains_key(":path")
|
||||
|| current_pseudo.contains_key(":method")
|
||||
{
|
||||
snap.finalize_exchange(
|
||||
¤t_pseudo,
|
||||
¤t_headers,
|
||||
current_direction,
|
||||
current_stream.clone(),
|
||||
);
|
||||
}
|
||||
current_headers.clear();
|
||||
current_pseudo.clear();
|
||||
@@ -147,8 +184,15 @@ impl Snapshot {
|
||||
// Incoming / server-received headers
|
||||
if let Some((key, val)) = extract_header(line, "decoded hpack field header field") {
|
||||
if key == ":authority" && !line.contains("server read frame") {
|
||||
if current_pseudo.contains_key(":path") || current_pseudo.contains_key(":method") {
|
||||
snap.finalize_exchange(¤t_pseudo, ¤t_headers, current_direction, current_stream.clone());
|
||||
if current_pseudo.contains_key(":path")
|
||||
|| current_pseudo.contains_key(":method")
|
||||
{
|
||||
snap.finalize_exchange(
|
||||
¤t_pseudo,
|
||||
¤t_headers,
|
||||
current_direction,
|
||||
current_stream.clone(),
|
||||
);
|
||||
}
|
||||
current_headers.clear();
|
||||
current_pseudo.clear();
|
||||
@@ -167,8 +211,15 @@ impl Snapshot {
|
||||
if line.contains("wrote HEADERS") {
|
||||
if let Some(stream) = extract_stream_id(line) {
|
||||
current_stream = Some(stream.clone());
|
||||
if current_pseudo.contains_key(":path") || current_pseudo.contains_key(":method") {
|
||||
let ex = snap.finalize_exchange(¤t_pseudo, ¤t_headers, current_direction, Some(stream));
|
||||
if current_pseudo.contains_key(":path")
|
||||
|| current_pseudo.contains_key(":method")
|
||||
{
|
||||
let ex = snap.finalize_exchange(
|
||||
¤t_pseudo,
|
||||
¤t_headers,
|
||||
current_direction,
|
||||
Some(stream),
|
||||
);
|
||||
if ex.is_some() {
|
||||
current_headers.clear();
|
||||
current_pseudo.clear();
|
||||
@@ -179,10 +230,13 @@ impl Snapshot {
|
||||
}
|
||||
|
||||
// DATA frames
|
||||
if (line.contains("wrote DATA") || line.contains("read DATA") || line.contains("server read frame DATA"))
|
||||
if (line.contains("wrote DATA")
|
||||
|| line.contains("read DATA")
|
||||
|| line.contains("server read frame DATA"))
|
||||
&& line.contains("data=\"")
|
||||
{
|
||||
let is_outgoing = line.contains("wrote DATA") || line.contains("server read frame DATA");
|
||||
let is_outgoing =
|
||||
line.contains("wrote DATA") || line.contains("server read frame DATA");
|
||||
if let Some(stream) = extract_stream_id(line) {
|
||||
if let Some(data_str) = extract_data(line) {
|
||||
let raw = decode_go_escaped(&data_str);
|
||||
@@ -203,7 +257,12 @@ impl Snapshot {
|
||||
|
||||
// Finalize remaining
|
||||
if current_pseudo.contains_key(":path") || current_pseudo.contains_key(":method") {
|
||||
snap.finalize_exchange(¤t_pseudo, ¤t_headers, current_direction, current_stream);
|
||||
snap.finalize_exchange(
|
||||
¤t_pseudo,
|
||||
¤t_headers,
|
||||
current_direction,
|
||||
current_stream,
|
||||
);
|
||||
}
|
||||
|
||||
snap
|
||||
@@ -226,7 +285,11 @@ impl Snapshot {
|
||||
|
||||
self.exchanges.push(HttpExchange {
|
||||
authority,
|
||||
method: if method.is_empty() { "GET".into() } else { method },
|
||||
method: if method.is_empty() {
|
||||
"GET".into()
|
||||
} else {
|
||||
method
|
||||
},
|
||||
path,
|
||||
headers: headers.to_vec(),
|
||||
body: Vec::new(),
|
||||
@@ -245,7 +308,9 @@ impl Snapshot {
|
||||
let sep = "═".repeat(70);
|
||||
let sep_thin = "─".repeat(60);
|
||||
out.push_str(&format!("\n{BOLD}{CYAN}{sep}{NC}\n"));
|
||||
out.push_str(&format!("{BOLD}{CYAN} STANDALONE LS TRAFFIC SNAPSHOT{NC}\n"));
|
||||
out.push_str(&format!(
|
||||
"{BOLD}{CYAN} STANDALONE LS TRAFFIC SNAPSHOT{NC}\n"
|
||||
));
|
||||
out.push_str(&format!("{BOLD}{CYAN}{sep}{NC}\n\n"));
|
||||
|
||||
// LS Logs
|
||||
@@ -265,7 +330,9 @@ impl Snapshot {
|
||||
for target in &self.connections {
|
||||
let domain = target.split(':').next().unwrap_or(target);
|
||||
let (label, desc) = domain_label(domain);
|
||||
out.push_str(&format!(" {GREEN}→{NC} {BOLD}{target}{NC} {DIM}({label}){NC}\n"));
|
||||
out.push_str(&format!(
|
||||
" {GREEN}→{NC} {BOLD}{target}{NC} {DIM}({label}){NC}\n"
|
||||
));
|
||||
if !desc.is_empty() {
|
||||
out.push_str(&format!(" {DIM}{desc}{NC}\n"));
|
||||
}
|
||||
@@ -276,7 +343,10 @@ impl Snapshot {
|
||||
// Group by domain
|
||||
let mut by_domain: Vec<(&str, Vec<&HttpExchange>)> = Vec::new();
|
||||
for ex in &self.exchanges {
|
||||
if let Some(entry) = by_domain.iter_mut().find(|(d, _)| *d == ex.authority.as_str()) {
|
||||
if let Some(entry) = by_domain
|
||||
.iter_mut()
|
||||
.find(|(d, _)| *d == ex.authority.as_str())
|
||||
{
|
||||
entry.1.push(ex);
|
||||
} else {
|
||||
by_domain.push((&ex.authority, vec![ex]));
|
||||
@@ -293,12 +363,17 @@ impl Snapshot {
|
||||
let color = if label.contains("API") { YELLOW } else { CYAN };
|
||||
|
||||
out.push_str(&format!("\n{BOLD}{sep}{NC}\n"));
|
||||
out.push_str(&format!("{BOLD}{color} {domain}{NC} {DIM}— {label}{NC}\n"));
|
||||
out.push_str(&format!(
|
||||
"{BOLD}{color} {domain}{NC} {DIM}— {label}{NC}\n"
|
||||
));
|
||||
out.push_str(&format!("{BOLD}{sep}{NC}\n"));
|
||||
|
||||
for ex in exchanges {
|
||||
let method_color = if ex.method == "GET" { GREEN } else { YELLOW };
|
||||
out.push_str(&format!("\n {BOLD}→ {method_color}{}{NC} {}\n", ex.method, ex.path));
|
||||
out.push_str(&format!(
|
||||
"\n {BOLD}→ {method_color}{}{NC} {}\n",
|
||||
ex.method, ex.path
|
||||
));
|
||||
|
||||
// Interesting headers
|
||||
for (key, val) in &ex.headers {
|
||||
@@ -342,7 +417,10 @@ fn render_body(data: &[u8], total_len: usize) -> String {
|
||||
out.push_str(&format!(" {BOLD}Body ({len} bytes, JSON):{NC}\n"));
|
||||
for (i, line) in pretty.lines().enumerate() {
|
||||
if i >= 40 {
|
||||
out.push_str(&format!(" {DIM}... ({} more lines){NC}\n", pretty.lines().count() - 40));
|
||||
out.push_str(&format!(
|
||||
" {DIM}... ({} more lines){NC}\n",
|
||||
pretty.lines().count() - 40
|
||||
));
|
||||
break;
|
||||
}
|
||||
out.push_str(&format!(" {GREEN}{line}{NC}\n"));
|
||||
@@ -357,10 +435,16 @@ fn render_body(data: &[u8], total_len: usize) -> String {
|
||||
if let Ok(text) = std::str::from_utf8(&decompressed) {
|
||||
if let Ok(val) = serde_json::from_str::<serde_json::Value>(text) {
|
||||
let pretty = serde_json::to_string_pretty(&val).unwrap_or_default();
|
||||
out.push_str(&format!(" {BOLD}Body ({len} bytes gzip → {} bytes, JSON):{NC}\n", decompressed.len()));
|
||||
out.push_str(&format!(
|
||||
" {BOLD}Body ({len} bytes gzip → {} bytes, JSON):{NC}\n",
|
||||
decompressed.len()
|
||||
));
|
||||
for (i, line) in pretty.lines().enumerate() {
|
||||
if i >= 50 {
|
||||
out.push_str(&format!(" {DIM}... ({} more lines){NC}\n", pretty.lines().count() - 50));
|
||||
out.push_str(&format!(
|
||||
" {DIM}... ({} more lines){NC}\n",
|
||||
pretty.lines().count() - 50
|
||||
));
|
||||
break;
|
||||
}
|
||||
out.push_str(&format!(" {GREEN}{line}{NC}\n"));
|
||||
@@ -368,14 +452,20 @@ fn render_body(data: &[u8], total_len: usize) -> String {
|
||||
return out;
|
||||
}
|
||||
// Plain text
|
||||
out.push_str(&format!(" {BOLD}Body ({len} bytes gzip → {} bytes, text):{NC}\n", decompressed.len()));
|
||||
out.push_str(&format!(
|
||||
" {BOLD}Body ({len} bytes gzip → {} bytes, text):{NC}\n",
|
||||
decompressed.len()
|
||||
));
|
||||
for line in text.lines().take(20) {
|
||||
out.push_str(&format!(" {line}\n"));
|
||||
}
|
||||
return out;
|
||||
}
|
||||
// Binary gzip
|
||||
out.push_str(&format!(" {BOLD}Body ({len} bytes gzip → {} bytes, binary):{NC}\n", decompressed.len()));
|
||||
out.push_str(&format!(
|
||||
" {BOLD}Body ({len} bytes gzip → {} bytes, binary):{NC}\n",
|
||||
decompressed.len()
|
||||
));
|
||||
let strings = extract_strings(&decompressed);
|
||||
for s in strings.iter().take(15) {
|
||||
out.push_str(&format!(" {MAGENTA}{s}{NC}\n"));
|
||||
@@ -393,7 +483,11 @@ fn render_body(data: &[u8], total_len: usize) -> String {
|
||||
// Protobuf / binary with string extraction
|
||||
let strings = extract_strings(data);
|
||||
if !strings.is_empty() {
|
||||
let kind = if !data.is_empty() && matches!(data[0], 0x08 | 0x0a | 0x10 | 0x12 | 0x18 | 0x1a | 0x20 | 0x22) {
|
||||
let kind = if !data.is_empty()
|
||||
&& matches!(
|
||||
data[0],
|
||||
0x08 | 0x0a | 0x10 | 0x12 | 0x18 | 0x1a | 0x20 | 0x22
|
||||
) {
|
||||
"protobuf"
|
||||
} else {
|
||||
"binary"
|
||||
@@ -448,7 +542,9 @@ fn extract_header(line: &str, pattern: &str) -> Option<(String, String)> {
|
||||
fn extract_stream_id(line: &str) -> Option<String> {
|
||||
let pos = line.find("stream=")?;
|
||||
let rest = &line[pos + 7..];
|
||||
let end = rest.find(|c: char| !c.is_ascii_digit()).unwrap_or(rest.len());
|
||||
let end = rest
|
||||
.find(|c: char| !c.is_ascii_digit())
|
||||
.unwrap_or(rest.len());
|
||||
Some(rest[..end].to_string())
|
||||
}
|
||||
|
||||
@@ -470,7 +566,9 @@ fn extract_data(line: &str) -> Option<String> {
|
||||
fn extract_data_len(line: &str) -> Option<usize> {
|
||||
let pos = line.find("len=")?;
|
||||
let rest = &line[pos + 4..];
|
||||
let end = rest.find(|c: char| !c.is_ascii_digit()).unwrap_or(rest.len());
|
||||
let end = rest
|
||||
.find(|c: char| !c.is_ascii_digit())
|
||||
.unwrap_or(rest.len());
|
||||
rest[..end].parse().ok()
|
||||
}
|
||||
|
||||
@@ -482,17 +580,40 @@ fn decode_go_escaped(s: &str) -> Vec<u8> {
|
||||
if bytes[i] == b'\\' && i + 1 < bytes.len() {
|
||||
match bytes[i + 1] {
|
||||
b'x' if i + 3 < bytes.len() => {
|
||||
if let Ok(b) = u8::from_str_radix(std::str::from_utf8(&bytes[i + 2..i + 4]).unwrap_or(""), 16) {
|
||||
if let Ok(b) = u8::from_str_radix(
|
||||
std::str::from_utf8(&bytes[i + 2..i + 4]).unwrap_or(""),
|
||||
16,
|
||||
) {
|
||||
result.push(b);
|
||||
i += 4;
|
||||
continue;
|
||||
}
|
||||
}
|
||||
b'n' => { result.push(b'\n'); i += 2; continue; }
|
||||
b'r' => { result.push(b'\r'); i += 2; continue; }
|
||||
b't' => { result.push(b'\t'); i += 2; continue; }
|
||||
b'\\' => { result.push(b'\\'); i += 2; continue; }
|
||||
b'"' => { result.push(b'"'); i += 2; continue; }
|
||||
b'n' => {
|
||||
result.push(b'\n');
|
||||
i += 2;
|
||||
continue;
|
||||
}
|
||||
b'r' => {
|
||||
result.push(b'\r');
|
||||
i += 2;
|
||||
continue;
|
||||
}
|
||||
b't' => {
|
||||
result.push(b'\t');
|
||||
i += 2;
|
||||
continue;
|
||||
}
|
||||
b'\\' => {
|
||||
result.push(b'\\');
|
||||
i += 2;
|
||||
continue;
|
||||
}
|
||||
b'"' => {
|
||||
result.push(b'"');
|
||||
i += 2;
|
||||
continue;
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
@@ -562,7 +683,10 @@ pub fn run_cli() {
|
||||
})
|
||||
} else {
|
||||
let mut buf = String::new();
|
||||
io::stdin().lock().read_to_string(&mut buf).expect("Failed to read stdin");
|
||||
io::stdin()
|
||||
.lock()
|
||||
.read_to_string(&mut buf)
|
||||
.expect("Failed to read stdin");
|
||||
buf
|
||||
};
|
||||
|
||||
|
||||
@@ -108,7 +108,10 @@ pub struct MainLSConfig {
|
||||
/// and CSRF is a random UUID.
|
||||
pub fn generate_standalone_config() -> MainLSConfig {
|
||||
let csrf = Uuid::new_v4().to_string();
|
||||
info!(csrf_len = csrf.len(), "Generated standalone config (headless)");
|
||||
info!(
|
||||
csrf_len = csrf.len(),
|
||||
"Generated standalone config (headless)"
|
||||
);
|
||||
MainLSConfig {
|
||||
extension_server_port: "0".to_string(), // disables extension server
|
||||
csrf,
|
||||
@@ -159,7 +162,13 @@ impl StandaloneLS {
|
||||
let app_data_dir = format!("{DATA_DIR}/.gemini/antigravity-standalone");
|
||||
let annotations_dir = format!("{app_data_dir}/annotations");
|
||||
let brain_dir = format!("{app_data_dir}/brain");
|
||||
for dir in [DATA_DIR, &gemini_dir, &app_data_dir, &annotations_dir, &brain_dir] {
|
||||
for dir in [
|
||||
DATA_DIR,
|
||||
&gemini_dir,
|
||||
&app_data_dir,
|
||||
&annotations_dir,
|
||||
&brain_dir,
|
||||
] {
|
||||
let _ = std::fs::create_dir_all(dir);
|
||||
#[cfg(unix)]
|
||||
{
|
||||
@@ -194,7 +203,10 @@ impl StandaloneLS {
|
||||
#[cfg(unix)]
|
||||
{
|
||||
use std::os::unix::fs::PermissionsExt;
|
||||
let _ = std::fs::set_permissions(&settings_path, std::fs::Permissions::from_mode(0o0666));
|
||||
let _ = std::fs::set_permissions(
|
||||
&settings_path,
|
||||
std::fs::Permissions::from_mode(0o0666),
|
||||
);
|
||||
}
|
||||
tracing::info!("Pre-seeded user_settings.pb (detect_and_use_proxy=ENABLED)");
|
||||
}
|
||||
@@ -203,10 +215,7 @@ impl StandaloneLS {
|
||||
// The LS connects to this port and calls LanguageServerStarted — without it,
|
||||
// the LS never fully initializes and won't accept connections on its server_port.
|
||||
let _stub_listener = if headless {
|
||||
let stub_port: u16 = main_config
|
||||
.extension_server_port
|
||||
.parse()
|
||||
.unwrap_or(0);
|
||||
let stub_port: u16 = main_config.extension_server_port.parse().unwrap_or(0);
|
||||
if stub_port == 0 {
|
||||
// Create a real listener so the LS can connect
|
||||
let listener = TcpListener::bind("127.0.0.1:0")
|
||||
@@ -215,7 +224,10 @@ impl StandaloneLS {
|
||||
.local_addr()
|
||||
.map_err(|e| format!("Failed to get stub port: {e}"))?
|
||||
.port();
|
||||
info!(port = actual_port, "Stub extension server listening (headless)");
|
||||
info!(
|
||||
port = actual_port,
|
||||
"Stub extension server listening (headless)"
|
||||
);
|
||||
// Read OAuth state from Antigravity's state.vscdb if available.
|
||||
// The DB stores the exact Topic proto (access_token + refresh_token + expiry)
|
||||
// which lets the LS auto-refresh tokens via its built-in Google OAuth2 client.
|
||||
@@ -306,10 +318,7 @@ impl StandaloneLS {
|
||||
// 3. MITM proxy intercepts the transparent TLS connection via SNI
|
||||
if let Some(mitm) = mitm_config {
|
||||
// Extract port from proxy_addr (e.g. "http://127.0.0.1:8742" → "8742")
|
||||
let mitm_port = mitm.proxy_addr
|
||||
.rsplit(':')
|
||||
.next()
|
||||
.unwrap_or("8742");
|
||||
let mitm_port = mitm.proxy_addr.rsplit(':').next().unwrap_or("8742");
|
||||
format!("https://daily-cloudcode-pa.googleapis.com:{mitm_port}")
|
||||
} else {
|
||||
"https://daily-cloudcode-pa.googleapis.com".to_string()
|
||||
@@ -324,9 +333,8 @@ impl StandaloneLS {
|
||||
debug!(?args, "LS args");
|
||||
|
||||
// Build env vars for the LS process
|
||||
let mut env_vars: Vec<(String, String)> = vec![
|
||||
("ANTIGRAVITY_EDITOR_APP_ROOT".into(), APP_ROOT.into()),
|
||||
];
|
||||
let mut env_vars: Vec<(String, String)> =
|
||||
vec![("ANTIGRAVITY_EDITOR_APP_ROOT".into(), APP_ROOT.into())];
|
||||
|
||||
// If MITM is enabled, add SSL + proxy env vars
|
||||
if let Some(mitm) = mitm_config {
|
||||
@@ -335,8 +343,8 @@ impl StandaloneLS {
|
||||
// Write to /tmp — accessible by antigravity-ls user
|
||||
// (user's ~/.config/ is not traversable by other UIDs)
|
||||
let combined_ca_path = "/tmp/antigravity-mitm-combined-ca.pem".to_string();
|
||||
let system_ca = std::fs::read_to_string("/etc/ssl/certs/ca-certificates.crt")
|
||||
.unwrap_or_default();
|
||||
let system_ca =
|
||||
std::fs::read_to_string("/etc/ssl/certs/ca-certificates.crt").unwrap_or_default();
|
||||
let mitm_ca = std::fs::read_to_string(&mitm.ca_cert_path)
|
||||
.map_err(|e| format!("Failed to read MITM CA cert: {e}"))?;
|
||||
std::fs::write(&combined_ca_path, format!("{system_ca}\n{mitm_ca}"))
|
||||
@@ -441,7 +449,11 @@ impl StandaloneLS {
|
||||
};
|
||||
|
||||
if let Some(pid) = ls_pid {
|
||||
info!(ls_pid = pid, sudo = use_sudo, "Discovered actual LS process");
|
||||
info!(
|
||||
ls_pid = pid,
|
||||
sudo = use_sudo,
|
||||
"Discovered actual LS process"
|
||||
);
|
||||
}
|
||||
|
||||
Ok(StandaloneLS {
|
||||
@@ -617,8 +629,7 @@ fn find_main_ls_pid() -> Result<String, String> {
|
||||
return Err("No /proc filesystem".to_string());
|
||||
}
|
||||
|
||||
let entries = std::fs::read_dir(proc)
|
||||
.map_err(|e| format!("Cannot read /proc: {e}"))?;
|
||||
let entries = std::fs::read_dir(proc).map_err(|e| format!("Cannot read /proc: {e}"))?;
|
||||
|
||||
for entry in entries.flatten() {
|
||||
let name = entry.file_name();
|
||||
@@ -704,12 +715,10 @@ fn cleanup_orphaned_ls() {
|
||||
.output();
|
||||
|
||||
let pids: Vec<u32> = match output {
|
||||
Ok(out) => {
|
||||
String::from_utf8_lossy(&out.stdout)
|
||||
.lines()
|
||||
.filter_map(|l| l.trim().parse().ok())
|
||||
.collect()
|
||||
}
|
||||
Ok(out) => String::from_utf8_lossy(&out.stdout)
|
||||
.lines()
|
||||
.filter_map(|l| l.trim().parse().ok())
|
||||
.collect(),
|
||||
Err(_) => return,
|
||||
};
|
||||
|
||||
@@ -717,7 +726,11 @@ fn cleanup_orphaned_ls() {
|
||||
return;
|
||||
}
|
||||
|
||||
info!(count = pids.len(), ?pids, "Cleaning up orphaned standalone LS processes");
|
||||
info!(
|
||||
count = pids.len(),
|
||||
?pids,
|
||||
"Cleaning up orphaned standalone LS processes"
|
||||
);
|
||||
|
||||
// Kill each PID by running `kill` AS the antigravity-ls user.
|
||||
// This works because same-UID processes can signal each other,
|
||||
@@ -870,7 +883,8 @@ fn extract_access_token_from_topic(topic_bytes: &[u8]) -> Option<String> {
|
||||
// Simple approach: convert to string and find base64 pattern
|
||||
let as_str = String::from_utf8_lossy(topic_bytes);
|
||||
// The base64 OAuthTokenInfo starts with "Co" (0x0A = field 1, len-delimited)
|
||||
for segment in as_str.split(|c: char| !c.is_alphanumeric() && c != '+' && c != '/' && c != '=') {
|
||||
for segment in as_str.split(|c: char| !c.is_alphanumeric() && c != '+' && c != '/' && c != '=')
|
||||
{
|
||||
if segment.len() > 50 {
|
||||
if let Ok(decoded) = base64::engine::general_purpose::STANDARD.decode(segment) {
|
||||
// Try to extract field 1 (access_token) from the OAuthTokenInfo proto
|
||||
@@ -951,7 +965,11 @@ fn decode_varint_at(buf: &[u8], offset: usize) -> Option<(u64, usize)> {
|
||||
/// IMPORTANT: `SubscribeToUnifiedStateSyncTopic` is a long-lived stream.
|
||||
/// If we immediately close it, the LS reconnects in a tight loop and never
|
||||
/// proceeds to fetch OAuth tokens. We keep subscription connections OPEN.
|
||||
fn stub_handle_connection(conn: std::net::TcpStream, oauth_token: &str, oauth_topic_bytes: &Option<Vec<u8>>) {
|
||||
fn stub_handle_connection(
|
||||
conn: std::net::TcpStream,
|
||||
oauth_token: &str,
|
||||
oauth_topic_bytes: &Option<Vec<u8>>,
|
||||
) {
|
||||
use std::io::{BufRead, BufReader, Read, Write};
|
||||
|
||||
let mut reader = BufReader::new(match conn.try_clone() {
|
||||
@@ -1028,7 +1046,7 @@ fn stub_handle_connection(conn: std::net::TcpStream, oauth_token: &str, oauth_to
|
||||
i += 1;
|
||||
if i + len <= proto_body.len() {
|
||||
if field_num == 1 {
|
||||
topic_name = String::from_utf8_lossy(&proto_body[i..i+len]).to_string();
|
||||
topic_name = String::from_utf8_lossy(&proto_body[i..i + len]).to_string();
|
||||
}
|
||||
i += len;
|
||||
} else {
|
||||
@@ -1084,7 +1102,10 @@ fn stub_handle_connection(conn: std::net::TcpStream, oauth_token: &str, oauth_to
|
||||
// This includes access_token + refresh_token + expiry, so the
|
||||
// LS can auto-refresh tokens via its built-in Google OAuth2 client.
|
||||
initial_state_bytes = topic_bytes.clone();
|
||||
eprintln!("[stub-ext] using state.vscdb topic ({} bytes)", topic_bytes.len());
|
||||
eprintln!(
|
||||
"[stub-ext] using state.vscdb topic ({} bytes)",
|
||||
topic_bytes.len()
|
||||
);
|
||||
} else if !oauth_token.is_empty() {
|
||||
// Manual token fallback — construct OAuthTokenInfo with far-future expiry
|
||||
// (no refresh_token, so the LS can't auto-refresh)
|
||||
@@ -1155,7 +1176,10 @@ fn stub_handle_connection(conn: std::net::TcpStream, oauth_token: &str, oauth_to
|
||||
if !send_chunk(&mut writer, &initial_env) {
|
||||
return;
|
||||
}
|
||||
eprintln!("[stub-ext] STREAM → sent initial_state ({} bytes)", initial_state_bytes.len());
|
||||
eprintln!(
|
||||
"[stub-ext] STREAM → sent initial_state ({} bytes)",
|
||||
initial_state_bytes.len()
|
||||
);
|
||||
|
||||
// (applied_update removed — data is in initial_state)
|
||||
|
||||
@@ -1197,7 +1221,10 @@ fn stub_handle_connection(conn: std::net::TcpStream, oauth_token: &str, oauth_to
|
||||
if !oauth_token.is_empty() {
|
||||
// Build protobuf: GetSecretValueResponse { string value = 1 }
|
||||
let proto = encode_proto_string(1, oauth_token.as_bytes());
|
||||
eprintln!("[stub-ext] → serving token ({} bytes) for key={key:?}", oauth_token.len());
|
||||
eprintln!(
|
||||
"[stub-ext] → serving token ({} bytes) for key={key:?}",
|
||||
oauth_token.len()
|
||||
);
|
||||
|
||||
// Data envelope: flag=0x00, length, data
|
||||
envelope.push(0x00u8);
|
||||
|
||||
@@ -34,7 +34,9 @@ pub async fn warmup_sequence(backend: &Backend, headless: bool) {
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(Ok((status, _))) => info!("SetUserSettings (detect_and_use_proxy=ENABLED): {status}"),
|
||||
Ok(Ok((status, _))) => {
|
||||
info!("SetUserSettings (detect_and_use_proxy=ENABLED): {status}")
|
||||
}
|
||||
Ok(Err(e)) => warn!("SetUserSettings failed: {e}"),
|
||||
Err(_) => warn!("SetUserSettings timed out"),
|
||||
}
|
||||
@@ -59,12 +61,7 @@ pub async fn warmup_sequence(backend: &Backend, headless: bool) {
|
||||
for (method, body) in calls {
|
||||
// Timeout per call — in headless mode, the LS can't reach Google's API
|
||||
// so these would hang forever without a timeout. Warmup is best-effort.
|
||||
match tokio::time::timeout(
|
||||
Duration::from_secs(5),
|
||||
backend.call_json(method, body),
|
||||
)
|
||||
.await
|
||||
{
|
||||
match tokio::time::timeout(Duration::from_secs(5), backend.call_json(method, body)).await {
|
||||
Ok(Ok((status, _))) => debug!("Warmup {method}: {status}"),
|
||||
Ok(Err(e)) => warn!("Warmup {method} failed: {e}"),
|
||||
Err(_) => warn!("Warmup {method} timed out"),
|
||||
@@ -87,10 +84,7 @@ pub fn start_heartbeat(backend: Arc<Backend>) -> JoinHandle<()> {
|
||||
let interval_ms = rand::thread_rng().gen_range(29_500..30_500);
|
||||
tokio::time::sleep(Duration::from_millis(interval_ms)).await;
|
||||
|
||||
match backend
|
||||
.call_json("Heartbeat", &serde_json::json!({}))
|
||||
.await
|
||||
{
|
||||
match backend.call_json("Heartbeat", &serde_json::json!({})).await {
|
||||
Ok((status, _)) => debug!("Heartbeat: {status}"),
|
||||
Err(e) => warn!("Heartbeat failed: {e}"),
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user