fix: block ALL LS follow-up requests across connections
Move the in-flight blocking check to the top of the LLM request flow, BEFORE request modification. This catches follow-ups on ALL connections (the LS opens multiple parallel TLS connections). Only the very first modified request reaches Google — all others get fake STOP responses. Previously, each new connection independently allowed one request through before blocking, letting 4-5 requests leak per turn.
This commit is contained in:
@@ -10,9 +10,11 @@ use std::sync::Arc;
|
|||||||
use tracing::{debug, info, warn};
|
use tracing::{debug, info, warn};
|
||||||
|
|
||||||
use super::models::{lookup_model, DEFAULT_MODEL, MODELS};
|
use super::models::{lookup_model, DEFAULT_MODEL, MODELS};
|
||||||
use super::polling::{extract_response_text, extract_thinking_content, is_response_done, poll_for_response};
|
use super::polling::{
|
||||||
|
extract_response_text, extract_thinking_content, is_response_done, poll_for_response,
|
||||||
|
};
|
||||||
use super::types::*;
|
use super::types::*;
|
||||||
use super::util::{err_response, upstream_err_response, now_unix};
|
use super::util::{err_response, now_unix, upstream_err_response};
|
||||||
use super::AppState;
|
use super::AppState;
|
||||||
|
|
||||||
/// Extract a conversation/session ID from a flexible JSON value.
|
/// Extract a conversation/session ID from a flexible JSON value.
|
||||||
@@ -33,7 +35,8 @@ fn system_fingerprint() -> String {
|
|||||||
/// Build a streaming chunk JSON with all required OpenAI fields.
|
/// Build a streaming chunk JSON with all required OpenAI fields.
|
||||||
/// Includes system_fingerprint, service_tier, and logprobs:null in choices.
|
/// Includes system_fingerprint, service_tier, and logprobs:null in choices.
|
||||||
fn chunk_json(
|
fn chunk_json(
|
||||||
id: &str, model: &str,
|
id: &str,
|
||||||
|
model: &str,
|
||||||
choices: serde_json::Value,
|
choices: serde_json::Value,
|
||||||
usage: Option<serde_json::Value>,
|
usage: Option<serde_json::Value>,
|
||||||
) -> String {
|
) -> String {
|
||||||
@@ -53,7 +56,11 @@ fn chunk_json(
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Build a single choice for a streaming chunk (delta + finish_reason + logprobs).
|
/// Build a single choice for a streaming chunk (delta + finish_reason + logprobs).
|
||||||
fn chunk_choice(index: u32, delta: serde_json::Value, finish_reason: Option<&str>) -> serde_json::Value {
|
fn chunk_choice(
|
||||||
|
index: u32,
|
||||||
|
delta: serde_json::Value,
|
||||||
|
finish_reason: Option<&str>,
|
||||||
|
) -> serde_json::Value {
|
||||||
serde_json::json!({
|
serde_json::json!({
|
||||||
"index": index,
|
"index": index,
|
||||||
"delta": delta,
|
"delta": delta,
|
||||||
@@ -70,7 +77,9 @@ fn chunk_choice(index: u32, delta: serde_json::Value, finish_reason: Option<&str
|
|||||||
fn google_to_openai_finish_reason(stop_reason: Option<&str>) -> &'static str {
|
fn google_to_openai_finish_reason(stop_reason: Option<&str>) -> &'static str {
|
||||||
match stop_reason {
|
match stop_reason {
|
||||||
Some("MAX_TOKENS") => "length",
|
Some("MAX_TOKENS") => "length",
|
||||||
Some("SAFETY") | Some("RECITATION") | Some("BLOCKLIST") | Some("PROHIBITED_CONTENT") => "content_filter",
|
Some("SAFETY") | Some("RECITATION") | Some("BLOCKLIST") | Some("PROHIBITED_CONTENT") => {
|
||||||
|
"content_filter"
|
||||||
|
}
|
||||||
_ => "stop",
|
_ => "stop",
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -84,7 +93,9 @@ fn google_to_openai_finish_reason(stop_reason: Option<&str>) -> &'static str {
|
|||||||
/// sends the entire messages array to the model.
|
/// sends the entire messages array to the model.
|
||||||
fn extract_chat_input(messages: &[CompletionMessage]) -> (String, Option<crate::proto::ImageData>) {
|
fn extract_chat_input(messages: &[CompletionMessage]) -> (String, Option<crate::proto::ImageData>) {
|
||||||
// Extract image from last user message content array
|
// Extract image from last user message content array
|
||||||
let image = messages.iter().rev()
|
let image = messages
|
||||||
|
.iter()
|
||||||
|
.rev()
|
||||||
.find(|m| m.role == "user")
|
.find(|m| m.role == "user")
|
||||||
.and_then(|m| super::util::extract_first_image(&m.content));
|
.and_then(|m| super::util::extract_first_image(&m.content));
|
||||||
// Always build the full conversation
|
// Always build the full conversation
|
||||||
@@ -141,10 +152,7 @@ fn build_conversation_with_tools(messages: &[CompletionMessage]) -> String {
|
|||||||
if let Some(func) = tc.get("function") {
|
if let Some(func) = tc.get("function") {
|
||||||
let name = func["name"].as_str().unwrap_or("unknown");
|
let name = func["name"].as_str().unwrap_or("unknown");
|
||||||
let args = func["arguments"].as_str().unwrap_or("{}");
|
let args = func["arguments"].as_str().unwrap_or("{}");
|
||||||
parts.push(format!(
|
parts.push(format!("[Tool call: {}({})]", name, args));
|
||||||
"[Tool call: {}({})]",
|
|
||||||
name, args
|
|
||||||
));
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -153,10 +161,7 @@ fn build_conversation_with_tools(messages: &[CompletionMessage]) -> String {
|
|||||||
let text = extract_message_text(&msg.content);
|
let text = extract_message_text(&msg.content);
|
||||||
let tool_id = msg.tool_call_id.as_deref().unwrap_or("unknown");
|
let tool_id = msg.tool_call_id.as_deref().unwrap_or("unknown");
|
||||||
if !text.is_empty() {
|
if !text.is_empty() {
|
||||||
parts.push(format!(
|
parts.push(format!("[Tool result ({})]:\n{}", tool_id, text));
|
||||||
"[Tool result ({})]:\n{}",
|
|
||||||
tool_id, text
|
|
||||||
));
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
_ => {}
|
_ => {}
|
||||||
@@ -202,7 +207,10 @@ pub(crate) async fn handle_completions(
|
|||||||
let gemini_config = crate::mitm::modify::openai_tool_choice_to_gemini(choice);
|
let gemini_config = crate::mitm::modify::openai_tool_choice_to_gemini(choice);
|
||||||
state.mitm_store.set_tool_config(gemini_config).await;
|
state.mitm_store.set_tool_config(gemini_config).await;
|
||||||
}
|
}
|
||||||
info!(count = tools.len(), "Completions: stored client tools for MITM injection");
|
info!(
|
||||||
|
count = tools.len(),
|
||||||
|
"Completions: stored client tools for MITM injection"
|
||||||
|
);
|
||||||
} else {
|
} else {
|
||||||
state.mitm_store.clear_tools().await;
|
state.mitm_store.clear_tools().await;
|
||||||
}
|
}
|
||||||
@@ -239,10 +247,15 @@ pub(crate) async fn handle_completions(
|
|||||||
google_search: body.web_search,
|
google_search: body.web_search,
|
||||||
};
|
};
|
||||||
// Only store if at least one param is set
|
// Only store if at least one param is set
|
||||||
if gp.temperature.is_some() || gp.top_p.is_some() || gp.max_output_tokens.is_some()
|
if gp.temperature.is_some()
|
||||||
|| gp.frequency_penalty.is_some() || gp.presence_penalty.is_some()
|
|| gp.top_p.is_some()
|
||||||
|| gp.reasoning_effort.is_some() || gp.stop_sequences.is_some()
|
|| gp.max_output_tokens.is_some()
|
||||||
|| gp.response_mime_type.is_some() || gp.response_schema.is_some()
|
|| gp.frequency_penalty.is_some()
|
||||||
|
|| gp.presence_penalty.is_some()
|
||||||
|
|| gp.reasoning_effort.is_some()
|
||||||
|
|| gp.stop_sequences.is_some()
|
||||||
|
|| gp.response_mime_type.is_some()
|
||||||
|
|| gp.response_schema.is_some()
|
||||||
|| gp.google_search
|
|| gp.google_search
|
||||||
{
|
{
|
||||||
state.mitm_store.set_generation_params(gp).await;
|
state.mitm_store.set_generation_params(gp).await;
|
||||||
@@ -306,12 +319,13 @@ pub(crate) async fn handle_completions(
|
|||||||
// Store image for MITM injection (LS doesn't forward images to Google API)
|
// Store image for MITM injection (LS doesn't forward images to Google API)
|
||||||
if let Some(ref img) = image {
|
if let Some(ref img) = image {
|
||||||
use base64::Engine;
|
use base64::Engine;
|
||||||
state.mitm_store.set_pending_image(
|
state
|
||||||
crate::mitm::store::PendingImage {
|
.mitm_store
|
||||||
|
.set_pending_image(crate::mitm::store::PendingImage {
|
||||||
base64_data: base64::engine::general_purpose::STANDARD.encode(&img.data),
|
base64_data: base64::engine::general_purpose::STANDARD.encode(&img.data),
|
||||||
mime_type: img.mime_type.clone(),
|
mime_type: img.mime_type.clone(),
|
||||||
}
|
})
|
||||||
).await;
|
.await;
|
||||||
}
|
}
|
||||||
match state
|
match state
|
||||||
.backend
|
.backend
|
||||||
@@ -346,7 +360,10 @@ pub(crate) async fn handle_completions(
|
|||||||
uuid::Uuid::new_v4().to_string().replace('-', "")
|
uuid::Uuid::new_v4().to_string().replace('-', "")
|
||||||
);
|
);
|
||||||
|
|
||||||
let include_usage = body.stream_options.as_ref().map_or(false, |o| o.include_usage);
|
let include_usage = body
|
||||||
|
.stream_options
|
||||||
|
.as_ref()
|
||||||
|
.map_or(false, |o| o.include_usage);
|
||||||
|
|
||||||
if body.stream {
|
if body.stream {
|
||||||
chat_completions_stream(
|
chat_completions_stream(
|
||||||
@@ -374,11 +391,17 @@ pub(crate) async fn handle_completions(
|
|||||||
match state.backend.create_cascade().await {
|
match state.backend.create_cascade().await {
|
||||||
Ok(cid) => {
|
Ok(cid) => {
|
||||||
// Send the same message on each extra cascade
|
// Send the same message on each extra cascade
|
||||||
match state.backend.send_message_with_image(&cid, &user_text, model.model_enum, image.as_ref()).await {
|
match state
|
||||||
|
.backend
|
||||||
|
.send_message_with_image(&cid, &user_text, model.model_enum, image.as_ref())
|
||||||
|
.await
|
||||||
|
{
|
||||||
Ok((200, _)) => {
|
Ok((200, _)) => {
|
||||||
let bg = Arc::clone(&state.backend);
|
let bg = Arc::clone(&state.backend);
|
||||||
let cid2 = cid.clone();
|
let cid2 = cid.clone();
|
||||||
tokio::spawn(async move { let _ = bg.update_annotations(&cid2).await; });
|
tokio::spawn(async move {
|
||||||
|
let _ = bg.update_annotations(&cid2).await;
|
||||||
|
});
|
||||||
extra_cascade_ids.push(cid);
|
extra_cascade_ids.push(cid);
|
||||||
}
|
}
|
||||||
_ => {} // Skip failed cascades
|
_ => {} // Skip failed cascades
|
||||||
@@ -420,7 +443,12 @@ pub(crate) async fn handle_completions(
|
|||||||
mitm.as_ref().and_then(|u| u.stop_reason.as_deref()),
|
mitm.as_ref().and_then(|u| u.stop_reason.as_deref()),
|
||||||
);
|
);
|
||||||
let (pt, ct, cached, thinking) = if let Some(ref mu) = mitm {
|
let (pt, ct, cached, thinking) = if let Some(ref mu) = mitm {
|
||||||
(mu.input_tokens, mu.output_tokens, mu.cache_read_input_tokens, mu.thinking_output_tokens)
|
(
|
||||||
|
mu.input_tokens,
|
||||||
|
mu.output_tokens,
|
||||||
|
mu.cache_read_input_tokens,
|
||||||
|
mu.thinking_output_tokens,
|
||||||
|
)
|
||||||
} else if let Some(u) = &result.usage {
|
} else if let Some(u) = &result.usage {
|
||||||
(u.input_tokens, u.output_tokens, 0, 0)
|
(u.input_tokens, u.output_tokens, 0, 0)
|
||||||
} else {
|
} else {
|
||||||
@@ -874,15 +902,22 @@ async fn chat_completions_sync(
|
|||||||
None => state.mitm_store.take_usage("_latest").await,
|
None => state.mitm_store.take_usage("_latest").await,
|
||||||
};
|
};
|
||||||
|
|
||||||
let finish_reason = google_to_openai_finish_reason(mitm.as_ref().and_then(|u| u.stop_reason.as_deref()));
|
let finish_reason =
|
||||||
|
google_to_openai_finish_reason(mitm.as_ref().and_then(|u| u.stop_reason.as_deref()));
|
||||||
|
|
||||||
let (prompt_tokens, completion_tokens, cached_tokens, thinking_tokens) = if let Some(ref mitm_usage) = mitm {
|
let (prompt_tokens, completion_tokens, cached_tokens, thinking_tokens) =
|
||||||
(mitm_usage.input_tokens, mitm_usage.output_tokens, mitm_usage.cache_read_input_tokens, mitm_usage.thinking_output_tokens)
|
if let Some(ref mitm_usage) = mitm {
|
||||||
} else if let Some(u) = &result.usage {
|
(
|
||||||
(u.input_tokens, u.output_tokens, 0, 0)
|
mitm_usage.input_tokens,
|
||||||
} else {
|
mitm_usage.output_tokens,
|
||||||
(0, 0, 0, 0)
|
mitm_usage.cache_read_input_tokens,
|
||||||
};
|
mitm_usage.thinking_output_tokens,
|
||||||
|
)
|
||||||
|
} else if let Some(u) = &result.usage {
|
||||||
|
(u.input_tokens, u.output_tokens, 0, 0)
|
||||||
|
} else {
|
||||||
|
(0, 0, 0, 0)
|
||||||
|
};
|
||||||
|
|
||||||
// Build message object, including reasoning_content if thinking is present
|
// Build message object, including reasoning_content if thinking is present
|
||||||
let mut message = serde_json::json!({
|
let mut message = serde_json::json!({
|
||||||
|
|||||||
@@ -15,7 +15,9 @@ use std::sync::Arc;
|
|||||||
use tracing::{info, warn};
|
use tracing::{info, warn};
|
||||||
|
|
||||||
use super::models::{lookup_model, DEFAULT_MODEL, MODELS};
|
use super::models::{lookup_model, DEFAULT_MODEL, MODELS};
|
||||||
use super::polling::{extract_response_text, extract_thinking_content, is_response_done, poll_for_response};
|
use super::polling::{
|
||||||
|
extract_response_text, extract_thinking_content, is_response_done, poll_for_response,
|
||||||
|
};
|
||||||
use super::util::{err_response, upstream_err_response};
|
use super::util::{err_response, upstream_err_response};
|
||||||
use super::AppState;
|
use super::AppState;
|
||||||
use crate::mitm::store::PendingToolResult;
|
use crate::mitm::store::PendingToolResult;
|
||||||
@@ -84,7 +86,9 @@ async fn build_usage_metadata(
|
|||||||
store: &crate::mitm::store::MitmStore,
|
store: &crate::mitm::store::MitmStore,
|
||||||
cascade_id: &str,
|
cascade_id: &str,
|
||||||
) -> serde_json::Value {
|
) -> serde_json::Value {
|
||||||
let usage = store.take_usage(cascade_id).await
|
let usage = store
|
||||||
|
.take_usage(cascade_id)
|
||||||
|
.await
|
||||||
.or(store.take_usage("_latest").await);
|
.or(store.take_usage("_latest").await);
|
||||||
if let Some(usage) = usage {
|
if let Some(usage) = usage {
|
||||||
serde_json::json!({
|
serde_json::json!({
|
||||||
@@ -152,13 +156,12 @@ pub(crate) async fn handle_gemini(
|
|||||||
// Gemini-native inlineData format
|
// Gemini-native inlineData format
|
||||||
if image.is_none() {
|
if image.is_none() {
|
||||||
if let Some(inline) = obj.get("inlineData") {
|
if let Some(inline) = obj.get("inlineData") {
|
||||||
if let (Some(mime), Some(b64)) = (
|
if let (Some(mime), Some(b64)) =
|
||||||
inline["mimeType"].as_str(),
|
(inline["mimeType"].as_str(), inline["data"].as_str())
|
||||||
inline["data"].as_str(),
|
{
|
||||||
) {
|
if let Some(img) = super::util::parse_data_uri(&format!(
|
||||||
if let Some(img) = super::util::parse_data_uri(
|
"data:{mime};base64,{b64}"
|
||||||
&format!("data:{mime};base64,{b64}")
|
)) {
|
||||||
) {
|
|
||||||
image = Some(img);
|
image = Some(img);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -194,7 +197,10 @@ pub(crate) async fn handle_gemini(
|
|||||||
if let Some(ref tools) = body.tools {
|
if let Some(ref tools) = body.tools {
|
||||||
if !tools.is_empty() {
|
if !tools.is_empty() {
|
||||||
state.mitm_store.set_tools(tools.clone()).await;
|
state.mitm_store.set_tools(tools.clone()).await;
|
||||||
info!(count = tools.len(), "Stored Gemini-native tools for MITM injection");
|
info!(
|
||||||
|
count = tools.len(),
|
||||||
|
"Stored Gemini-native tools for MITM injection"
|
||||||
|
);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if let Some(ref config) = body.tool_config {
|
if let Some(ref config) = body.tool_config {
|
||||||
@@ -207,13 +213,19 @@ pub(crate) async fn handle_gemini(
|
|||||||
if let Some(fr) = r.get("functionResponse") {
|
if let Some(fr) = r.get("functionResponse") {
|
||||||
let name = fr["name"].as_str().unwrap_or("unknown").to_string();
|
let name = fr["name"].as_str().unwrap_or("unknown").to_string();
|
||||||
let response = fr.get("response").cloned().unwrap_or(serde_json::json!({}));
|
let response = fr.get("response").cloned().unwrap_or(serde_json::json!({}));
|
||||||
state.mitm_store.add_tool_result(PendingToolResult {
|
state
|
||||||
name,
|
.mitm_store
|
||||||
result: response,
|
.add_tool_result(PendingToolResult {
|
||||||
}).await;
|
name,
|
||||||
|
result: response,
|
||||||
|
})
|
||||||
|
.await;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
info!(count = results.len(), "Stored Gemini-native tool results for MITM injection");
|
info!(
|
||||||
|
count = results.len(),
|
||||||
|
"Stored Gemini-native tool results for MITM injection"
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Store generation parameters for MITM injection
|
// Store generation parameters for MITM injection
|
||||||
@@ -232,9 +244,13 @@ pub(crate) async fn handle_gemini(
|
|||||||
response_schema: None,
|
response_schema: None,
|
||||||
google_search: body.google_search,
|
google_search: body.google_search,
|
||||||
};
|
};
|
||||||
if gp.temperature.is_some() || gp.top_p.is_some() || gp.top_k.is_some()
|
if gp.temperature.is_some()
|
||||||
|| gp.max_output_tokens.is_some() || gp.stop_sequences.is_some()
|
|| gp.top_p.is_some()
|
||||||
|| gp.reasoning_effort.is_some() || gp.google_search
|
|| gp.top_k.is_some()
|
||||||
|
|| gp.max_output_tokens.is_some()
|
||||||
|
|| gp.stop_sequences.is_some()
|
||||||
|
|| gp.reasoning_effort.is_some()
|
||||||
|
|| gp.google_search
|
||||||
{
|
{
|
||||||
state.mitm_store.set_generation_params(gp).await;
|
state.mitm_store.set_generation_params(gp).await;
|
||||||
} else {
|
} else {
|
||||||
@@ -277,12 +293,13 @@ pub(crate) async fn handle_gemini(
|
|||||||
// Store image for MITM injection (LS doesn't forward images to Google API)
|
// Store image for MITM injection (LS doesn't forward images to Google API)
|
||||||
if let Some(ref img) = image {
|
if let Some(ref img) = image {
|
||||||
use base64::Engine;
|
use base64::Engine;
|
||||||
state.mitm_store.set_pending_image(
|
state
|
||||||
crate::mitm::store::PendingImage {
|
.mitm_store
|
||||||
|
.set_pending_image(crate::mitm::store::PendingImage {
|
||||||
base64_data: base64::engine::general_purpose::STANDARD.encode(&img.data),
|
base64_data: base64::engine::general_purpose::STANDARD.encode(&img.data),
|
||||||
mime_type: img.mime_type.clone(),
|
mime_type: img.mime_type.clone(),
|
||||||
}
|
})
|
||||||
).await;
|
.await;
|
||||||
}
|
}
|
||||||
match state
|
match state
|
||||||
.backend
|
.backend
|
||||||
@@ -372,7 +389,11 @@ async fn gemini_sync(
|
|||||||
|
|
||||||
// Check for completed text response
|
// Check for completed text response
|
||||||
if state.mitm_store.is_response_complete() {
|
if state.mitm_store.is_response_complete() {
|
||||||
let text = state.mitm_store.take_response_text().await.unwrap_or_default();
|
let text = state
|
||||||
|
.mitm_store
|
||||||
|
.take_response_text()
|
||||||
|
.await
|
||||||
|
.unwrap_or_default();
|
||||||
let thinking = state.mitm_store.take_thinking_text().await;
|
let thinking = state.mitm_store.take_thinking_text().await;
|
||||||
|
|
||||||
// Guard against stale response_complete with no data
|
// Guard against stale response_complete with no data
|
||||||
|
|||||||
@@ -44,7 +44,6 @@ pub fn router(state: Arc<AppState>) -> Router {
|
|||||||
post(completions::handle_completions),
|
post(completions::handle_completions),
|
||||||
)
|
)
|
||||||
.route("/v1/gemini", post(gemini::handle_gemini))
|
.route("/v1/gemini", post(gemini::handle_gemini))
|
||||||
|
|
||||||
.route("/v1/models", get(handle_models))
|
.route("/v1/models", get(handle_models))
|
||||||
.route("/v1/sessions", get(handle_list_sessions))
|
.route("/v1/sessions", get(handle_list_sessions))
|
||||||
.route("/v1/sessions/{id}", delete(handle_delete_session))
|
.route("/v1/sessions/{id}", delete(handle_delete_session))
|
||||||
@@ -106,9 +105,7 @@ async fn handle_models() -> Json<serde_json::Value> {
|
|||||||
Json(serde_json::json!({"object": "list", "data": models}))
|
Json(serde_json::json!({"object": "list", "data": models}))
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn handle_list_sessions(
|
async fn handle_list_sessions(State(state): State<Arc<AppState>>) -> Json<serde_json::Value> {
|
||||||
State(state): State<Arc<AppState>>,
|
|
||||||
) -> Json<serde_json::Value> {
|
|
||||||
let sessions = state.sessions.list_sessions().await;
|
let sessions = state.sessions.list_sessions().await;
|
||||||
Json(serde_json::json!({"sessions": sessions}))
|
Json(serde_json::json!({"sessions": sessions}))
|
||||||
}
|
}
|
||||||
@@ -155,9 +152,7 @@ async fn handle_set_token(
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn handle_usage(
|
async fn handle_usage(State(state): State<Arc<AppState>>) -> Json<serde_json::Value> {
|
||||||
State(state): State<Arc<AppState>>,
|
|
||||||
) -> Json<serde_json::Value> {
|
|
||||||
let stats = state.mitm_store.stats().await;
|
let stats = state.mitm_store.stats().await;
|
||||||
Json(serde_json::json!({
|
Json(serde_json::json!({
|
||||||
"mitm": {
|
"mitm": {
|
||||||
@@ -174,9 +169,7 @@ async fn handle_usage(
|
|||||||
}))
|
}))
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn handle_quota(
|
async fn handle_quota(State(state): State<Arc<AppState>>) -> Json<serde_json::Value> {
|
||||||
State(state): State<Arc<AppState>>,
|
|
||||||
) -> Json<serde_json::Value> {
|
|
||||||
let snap = state.quota_store.snapshot().await;
|
let snap = state.quota_store.snapshot().await;
|
||||||
Json(serde_json::to_value(snap).unwrap_or_default())
|
Json(serde_json::to_value(snap).unwrap_or_default())
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -84,14 +84,8 @@ pub(crate) fn extract_model_usage(steps: &[serde_json::Value]) -> Option<ModelUs
|
|||||||
return Some(ModelUsage {
|
return Some(ModelUsage {
|
||||||
input_tokens: input,
|
input_tokens: input,
|
||||||
output_tokens: output,
|
output_tokens: output,
|
||||||
api_provider: usage["apiProvider"]
|
api_provider: usage["apiProvider"].as_str().unwrap_or("").to_string(),
|
||||||
.as_str()
|
model: usage["model"].as_str().unwrap_or("").to_string(),
|
||||||
.unwrap_or("")
|
|
||||||
.to_string(),
|
|
||||||
model: usage["model"]
|
|
||||||
.as_str()
|
|
||||||
.unwrap_or("")
|
|
||||||
.to_string(),
|
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -263,23 +257,36 @@ pub(crate) async fn poll_for_response(
|
|||||||
} else {
|
} else {
|
||||||
info!(
|
info!(
|
||||||
"Response done ({short_id}), {:.1}s, {} chars (no usage){}{}",
|
"Response done ({short_id}), {:.1}s, {} chars (no usage){}{}",
|
||||||
elapsed, text.len(),
|
elapsed,
|
||||||
thinking.as_ref().map_or(String::new(), |t| format!(", thinking: {} chars", t.len())),
|
text.len(),
|
||||||
if thinking_signature.is_some() { ", has sig" } else { "" }
|
thinking.as_ref().map_or(String::new(), |t| format!(
|
||||||
|
", thinking: {} chars",
|
||||||
|
t.len()
|
||||||
|
)),
|
||||||
|
if thinking_signature.is_some() {
|
||||||
|
", has sig"
|
||||||
|
} else {
|
||||||
|
""
|
||||||
|
}
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
return PollResult { text, usage, thinking_signature, thinking, thinking_duration, upstream_error: None };
|
return PollResult {
|
||||||
|
text,
|
||||||
|
usage,
|
||||||
|
thinking_signature,
|
||||||
|
thinking,
|
||||||
|
thinking_duration,
|
||||||
|
upstream_error: None,
|
||||||
|
};
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Fallback: check trajectory IDLE status (catches edge cases)
|
// Fallback: check trajectory IDLE status (catches edge cases)
|
||||||
// Only check every 5th poll to reduce network calls
|
// Only check every 5th poll to reduce network calls
|
||||||
if step_count > 4 && step_count % 5 == 0 {
|
if step_count > 4 && step_count % 5 == 0 {
|
||||||
if let Ok((ts, td)) = state.backend.get_trajectory(cascade_id).await
|
if let Ok((ts, td)) = state.backend.get_trajectory(cascade_id).await {
|
||||||
{
|
|
||||||
if ts == 200 {
|
if ts == 200 {
|
||||||
let run_status =
|
let run_status = td["status"].as_str().unwrap_or("");
|
||||||
td["status"].as_str().unwrap_or("");
|
|
||||||
if run_status.contains("IDLE") {
|
if run_status.contains("IDLE") {
|
||||||
let text = extract_response_text(steps);
|
let text = extract_response_text(steps);
|
||||||
if !text.is_empty() {
|
if !text.is_empty() {
|
||||||
@@ -293,7 +300,14 @@ pub(crate) async fn poll_for_response(
|
|||||||
elapsed,
|
elapsed,
|
||||||
text.len()
|
text.len()
|
||||||
);
|
);
|
||||||
return PollResult { text, usage, thinking_signature, thinking, thinking_duration, upstream_error: None };
|
return PollResult {
|
||||||
|
text,
|
||||||
|
usage,
|
||||||
|
thinking_signature,
|
||||||
|
thinking,
|
||||||
|
thinking_duration,
|
||||||
|
upstream_error: None,
|
||||||
|
};
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -14,12 +14,15 @@ use std::sync::Arc;
|
|||||||
use tracing::{debug, info};
|
use tracing::{debug, info};
|
||||||
|
|
||||||
use super::models::{lookup_model, DEFAULT_MODEL, MODELS};
|
use super::models::{lookup_model, DEFAULT_MODEL, MODELS};
|
||||||
use super::polling::{extract_response_text, is_response_done, poll_for_response, extract_model_usage, extract_thinking_signature, extract_thinking_content};
|
use super::polling::{
|
||||||
|
extract_model_usage, extract_response_text, extract_thinking_content,
|
||||||
|
extract_thinking_signature, is_response_done, poll_for_response,
|
||||||
|
};
|
||||||
use super::types::*;
|
use super::types::*;
|
||||||
use super::util::{err_response, upstream_err_response, now_unix, responses_sse_event};
|
use super::util::{err_response, now_unix, responses_sse_event, upstream_err_response};
|
||||||
use super::AppState;
|
use super::AppState;
|
||||||
|
use crate::mitm::modify::{openai_tool_choice_to_gemini, openai_tools_to_gemini};
|
||||||
use crate::mitm::store::PendingToolResult;
|
use crate::mitm::store::PendingToolResult;
|
||||||
use crate::mitm::modify::{openai_tools_to_gemini, openai_tool_choice_to_gemini};
|
|
||||||
|
|
||||||
// ─── Input extraction ────────────────────────────────────────────────────────
|
// ─── Input extraction ────────────────────────────────────────────────────────
|
||||||
|
|
||||||
@@ -35,7 +38,11 @@ struct ToolResultInput {
|
|||||||
fn extract_responses_input(
|
fn extract_responses_input(
|
||||||
input: &serde_json::Value,
|
input: &serde_json::Value,
|
||||||
instructions: Option<&str>,
|
instructions: Option<&str>,
|
||||||
) -> (String, Vec<ToolResultInput>, Option<crate::proto::ImageData>) {
|
) -> (
|
||||||
|
String,
|
||||||
|
Vec<ToolResultInput>,
|
||||||
|
Option<crate::proto::ImageData>,
|
||||||
|
) {
|
||||||
let mut tool_results: Vec<ToolResultInput> = Vec::new();
|
let mut tool_results: Vec<ToolResultInput> = Vec::new();
|
||||||
let mut image: Option<crate::proto::ImageData> = None;
|
let mut image: Option<crate::proto::ImageData> = None;
|
||||||
|
|
||||||
@@ -45,10 +52,9 @@ fn extract_responses_input(
|
|||||||
// Check for function_call_output items
|
// Check for function_call_output items
|
||||||
for item in items {
|
for item in items {
|
||||||
if item["type"].as_str() == Some("function_call_output") {
|
if item["type"].as_str() == Some("function_call_output") {
|
||||||
if let (Some(call_id), Some(output)) = (
|
if let (Some(call_id), Some(output)) =
|
||||||
item["call_id"].as_str(),
|
(item["call_id"].as_str(), item["output"].as_str())
|
||||||
item["output"].as_str(),
|
{
|
||||||
) {
|
|
||||||
tool_results.push(ToolResultInput {
|
tool_results.push(ToolResultInput {
|
||||||
call_id: call_id.to_string(),
|
call_id: call_id.to_string(),
|
||||||
output: output.to_string(),
|
output: output.to_string(),
|
||||||
@@ -230,24 +236,31 @@ pub(crate) async fn handle_responses(
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
let (user_text, tool_results, image) = extract_responses_input(&body.input, body.instructions.as_deref());
|
let (user_text, tool_results, image) =
|
||||||
|
extract_responses_input(&body.input, body.instructions.as_deref());
|
||||||
|
|
||||||
// Handle tool result submission (function_call_output in input)
|
// Handle tool result submission (function_call_output in input)
|
||||||
let is_tool_result_turn = !tool_results.is_empty();
|
let is_tool_result_turn = !tool_results.is_empty();
|
||||||
if is_tool_result_turn {
|
if is_tool_result_turn {
|
||||||
for tr in &tool_results {
|
for tr in &tool_results {
|
||||||
// Look up function name from call_id
|
// Look up function name from call_id
|
||||||
let name = state.mitm_store.lookup_call_id(&tr.call_id).await
|
let name = state
|
||||||
|
.mitm_store
|
||||||
|
.lookup_call_id(&tr.call_id)
|
||||||
|
.await
|
||||||
.unwrap_or_else(|| "unknown_function".to_string());
|
.unwrap_or_else(|| "unknown_function".to_string());
|
||||||
|
|
||||||
// Parse the output as JSON, fall back to string wrapper
|
// Parse the output as JSON, fall back to string wrapper
|
||||||
let result_value = serde_json::from_str::<serde_json::Value>(&tr.output)
|
let result_value = serde_json::from_str::<serde_json::Value>(&tr.output)
|
||||||
.unwrap_or_else(|_| serde_json::json!({"result": tr.output}));
|
.unwrap_or_else(|_| serde_json::json!({"result": tr.output}));
|
||||||
|
|
||||||
state.mitm_store.add_tool_result(PendingToolResult {
|
state
|
||||||
name,
|
.mitm_store
|
||||||
result: result_value,
|
.add_tool_result(PendingToolResult {
|
||||||
}).await;
|
name,
|
||||||
|
result: result_value,
|
||||||
|
})
|
||||||
|
.await;
|
||||||
}
|
}
|
||||||
info!(
|
info!(
|
||||||
count = tool_results.len(),
|
count = tool_results.len(),
|
||||||
@@ -275,7 +288,10 @@ pub(crate) async fn handle_responses(
|
|||||||
let gemini_tools = openai_tools_to_gemini(tools);
|
let gemini_tools = openai_tools_to_gemini(tools);
|
||||||
if !gemini_tools.is_empty() {
|
if !gemini_tools.is_empty() {
|
||||||
state.mitm_store.set_tools(gemini_tools).await;
|
state.mitm_store.set_tools(gemini_tools).await;
|
||||||
info!(count = tools.len(), "Stored client tools for MITM injection");
|
info!(
|
||||||
|
count = tools.len(),
|
||||||
|
"Stored client tools for MITM injection"
|
||||||
|
);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if let Some(ref choice) = body.tool_choice {
|
if let Some(ref choice) = body.tool_choice {
|
||||||
@@ -289,7 +305,9 @@ pub(crate) async fn handle_responses(
|
|||||||
let fmt_type = text_val["format"]["type"].as_str().unwrap_or("text");
|
let fmt_type = text_val["format"]["type"].as_str().unwrap_or("text");
|
||||||
if fmt_type == "json_schema" {
|
if fmt_type == "json_schema" {
|
||||||
let name = text_val["format"]["name"].as_str().map(|s| s.to_string());
|
let name = text_val["format"]["name"].as_str().map(|s| s.to_string());
|
||||||
let schema = text_val["format"]["schema"].as_object().map(|o| serde_json::Value::Object(o.clone()));
|
let schema = text_val["format"]["schema"]
|
||||||
|
.as_object()
|
||||||
|
.map(|o| serde_json::Value::Object(o.clone()));
|
||||||
let strict = text_val["format"]["strict"].as_bool();
|
let strict = text_val["format"]["strict"].as_bool();
|
||||||
let tf = TextFormat {
|
let tf = TextFormat {
|
||||||
format: TextFormatInner {
|
format: TextFormatInner {
|
||||||
@@ -321,9 +339,13 @@ pub(crate) async fn handle_responses(
|
|||||||
response_schema,
|
response_schema,
|
||||||
google_search: has_web_search,
|
google_search: has_web_search,
|
||||||
};
|
};
|
||||||
if gp.temperature.is_some() || gp.top_p.is_some() || gp.max_output_tokens.is_some()
|
if gp.temperature.is_some()
|
||||||
|| gp.reasoning_effort.is_some() || gp.response_mime_type.is_some()
|
|| gp.top_p.is_some()
|
||||||
|| gp.response_schema.is_some() || gp.google_search
|
|| gp.max_output_tokens.is_some()
|
||||||
|
|| gp.reasoning_effort.is_some()
|
||||||
|
|| gp.response_mime_type.is_some()
|
||||||
|
|| gp.response_schema.is_some()
|
||||||
|
|| gp.google_search
|
||||||
{
|
{
|
||||||
state.mitm_store.set_generation_params(gp).await;
|
state.mitm_store.set_generation_params(gp).await;
|
||||||
} else {
|
} else {
|
||||||
@@ -331,10 +353,7 @@ pub(crate) async fn handle_responses(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
let response_id = format!(
|
let response_id = format!("resp_{}", uuid::Uuid::new_v4().to_string().replace('-', ""));
|
||||||
"resp_{}",
|
|
||||||
uuid::Uuid::new_v4().to_string().replace('-', "")
|
|
||||||
);
|
|
||||||
|
|
||||||
// Session/conversation management
|
// Session/conversation management
|
||||||
let session_id_str = extract_conversation_id(&body.conversation);
|
let session_id_str = extract_conversation_id(&body.conversation);
|
||||||
@@ -371,12 +390,13 @@ pub(crate) async fn handle_responses(
|
|||||||
// Store image for MITM injection (LS doesn't forward images to Google API)
|
// Store image for MITM injection (LS doesn't forward images to Google API)
|
||||||
if let Some(ref img) = image {
|
if let Some(ref img) = image {
|
||||||
use base64::Engine;
|
use base64::Engine;
|
||||||
state.mitm_store.set_pending_image(
|
state
|
||||||
crate::mitm::store::PendingImage {
|
.mitm_store
|
||||||
|
.set_pending_image(crate::mitm::store::PendingImage {
|
||||||
base64_data: base64::engine::general_purpose::STANDARD.encode(&img.data),
|
base64_data: base64::engine::general_purpose::STANDARD.encode(&img.data),
|
||||||
mime_type: img.mime_type.clone(),
|
mime_type: img.mime_type.clone(),
|
||||||
}
|
})
|
||||||
).await;
|
.await;
|
||||||
}
|
}
|
||||||
match state
|
match state
|
||||||
.backend
|
.backend
|
||||||
@@ -419,21 +439,32 @@ pub(crate) async fn handle_responses(
|
|||||||
metadata: body.metadata.clone().unwrap_or(serde_json::json!({})),
|
metadata: body.metadata.clone().unwrap_or(serde_json::json!({})),
|
||||||
max_tool_calls: body.max_tool_calls,
|
max_tool_calls: body.max_tool_calls,
|
||||||
reasoning_effort: body.reasoning_effort.clone(),
|
reasoning_effort: body.reasoning_effort.clone(),
|
||||||
tool_choice: body.tool_choice.clone().unwrap_or(serde_json::json!("auto")),
|
tool_choice: body
|
||||||
|
.tool_choice
|
||||||
|
.clone()
|
||||||
|
.unwrap_or(serde_json::json!("auto")),
|
||||||
tools: body.tools.clone().unwrap_or_default(),
|
tools: body.tools.clone().unwrap_or_default(),
|
||||||
text_format,
|
text_format,
|
||||||
};
|
};
|
||||||
|
|
||||||
if body.stream {
|
if body.stream {
|
||||||
handle_responses_stream(
|
handle_responses_stream(
|
||||||
state, response_id, model_name.to_string(), cascade_id,
|
state,
|
||||||
body.timeout, req_params,
|
response_id,
|
||||||
|
model_name.to_string(),
|
||||||
|
cascade_id,
|
||||||
|
body.timeout,
|
||||||
|
req_params,
|
||||||
)
|
)
|
||||||
.await
|
.await
|
||||||
} else {
|
} else {
|
||||||
handle_responses_sync(
|
handle_responses_sync(
|
||||||
state, response_id, model_name.to_string(), cascade_id,
|
state,
|
||||||
body.timeout, req_params,
|
response_id,
|
||||||
|
model_name.to_string(),
|
||||||
|
cascade_id,
|
||||||
|
body.timeout,
|
||||||
|
req_params,
|
||||||
)
|
)
|
||||||
.await
|
.await
|
||||||
}
|
}
|
||||||
@@ -485,7 +516,9 @@ async fn usage_from_poll(
|
|||||||
if let Some(u) = mitm_store.peek_usage(key).await {
|
if let Some(u) = mitm_store.peek_usage(key).await {
|
||||||
if u.thinking_output_tokens > 0 && u.thinking_text.is_none() {
|
if u.thinking_output_tokens > 0 && u.thinking_text.is_none() {
|
||||||
// Call 2 hasn't arrived yet — wait briefly for the merge
|
// Call 2 hasn't arrived yet — wait briefly for the merge
|
||||||
tracing::debug!("MITM: thinking tokens found but no text, waiting for summary merge...");
|
tracing::debug!(
|
||||||
|
"MITM: thinking tokens found but no text, waiting for summary merge..."
|
||||||
|
);
|
||||||
for _ in 0..10 {
|
for _ in 0..10 {
|
||||||
tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
|
tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
|
||||||
if let Some(u2) = mitm_store.peek_usage(key).await {
|
if let Some(u2) = mitm_store.peek_usage(key).await {
|
||||||
@@ -526,13 +559,18 @@ async fn usage_from_poll(
|
|||||||
|
|
||||||
// Priority 2: LS trajectory data (from CHECKPOINT/metadata steps)
|
// Priority 2: LS trajectory data (from CHECKPOINT/metadata steps)
|
||||||
if let Some(u) = model_usage {
|
if let Some(u) = model_usage {
|
||||||
return (Usage {
|
return (
|
||||||
input_tokens: u.input_tokens,
|
Usage {
|
||||||
input_tokens_details: InputTokensDetails { cached_tokens: 0 },
|
input_tokens: u.input_tokens,
|
||||||
output_tokens: u.output_tokens,
|
input_tokens_details: InputTokensDetails { cached_tokens: 0 },
|
||||||
output_tokens_details: OutputTokensDetails { reasoning_tokens: 0 },
|
output_tokens: u.output_tokens,
|
||||||
total_tokens: u.input_tokens + u.output_tokens,
|
output_tokens_details: OutputTokensDetails {
|
||||||
}, None);
|
reasoning_tokens: 0,
|
||||||
|
},
|
||||||
|
total_tokens: u.input_tokens + u.output_tokens,
|
||||||
|
},
|
||||||
|
None,
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Priority 3: Estimate from text lengths
|
// Priority 3: Estimate from text lengths
|
||||||
@@ -575,14 +613,22 @@ async fn handle_responses_sync(
|
|||||||
"call_{}",
|
"call_{}",
|
||||||
uuid::Uuid::new_v4().to_string().replace('-', "")[..24].to_string()
|
uuid::Uuid::new_v4().to_string().replace('-', "")[..24].to_string()
|
||||||
);
|
);
|
||||||
state.mitm_store.register_call_id(call_id.clone(), fc.name.clone()).await;
|
state
|
||||||
|
.mitm_store
|
||||||
|
.register_call_id(call_id.clone(), fc.name.clone())
|
||||||
|
.await;
|
||||||
let arguments = serde_json::to_string(&fc.args).unwrap_or_default();
|
let arguments = serde_json::to_string(&fc.args).unwrap_or_default();
|
||||||
output_items.push(build_function_call_output(&call_id, &fc.name, &arguments));
|
output_items
|
||||||
|
.push(build_function_call_output(&call_id, &fc.name, &arguments));
|
||||||
}
|
}
|
||||||
let (usage, _) = usage_from_poll(
|
let (usage, _) = usage_from_poll(
|
||||||
&state.mitm_store, &cascade_id, &None,
|
&state.mitm_store,
|
||||||
¶ms.user_text, "",
|
&cascade_id,
|
||||||
).await;
|
&None,
|
||||||
|
¶ms.user_text,
|
||||||
|
"",
|
||||||
|
)
|
||||||
|
.await;
|
||||||
let resp = build_response_object(
|
let resp = build_response_object(
|
||||||
ResponseData {
|
ResponseData {
|
||||||
id: response_id,
|
id: response_id,
|
||||||
@@ -602,12 +648,20 @@ async fn handle_responses_sync(
|
|||||||
|
|
||||||
// Check for completed text response
|
// Check for completed text response
|
||||||
if state.mitm_store.is_response_complete() {
|
if state.mitm_store.is_response_complete() {
|
||||||
let text = state.mitm_store.take_response_text().await.unwrap_or_default();
|
let text = state
|
||||||
|
.mitm_store
|
||||||
|
.take_response_text()
|
||||||
|
.await
|
||||||
|
.unwrap_or_default();
|
||||||
let thinking = state.mitm_store.take_thinking_text().await;
|
let thinking = state.mitm_store.take_thinking_text().await;
|
||||||
let (usage, _) = usage_from_poll(
|
let (usage, _) = usage_from_poll(
|
||||||
&state.mitm_store, &cascade_id, &None,
|
&state.mitm_store,
|
||||||
¶ms.user_text, &text,
|
&cascade_id,
|
||||||
).await;
|
&None,
|
||||||
|
¶ms.user_text,
|
||||||
|
&text,
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
|
||||||
let mut output_items: Vec<serde_json::Value> = Vec::new();
|
let mut output_items: Vec<serde_json::Value> = Vec::new();
|
||||||
if let Some(ref t) = thinking {
|
if let Some(ref t) = thinking {
|
||||||
@@ -658,10 +712,7 @@ async fn handle_responses_sync(
|
|||||||
return upstream_err_response(err);
|
return upstream_err_response(err);
|
||||||
}
|
}
|
||||||
let completed_at = now_unix();
|
let completed_at = now_unix();
|
||||||
let msg_id = format!(
|
let msg_id = format!("msg_{}", uuid::Uuid::new_v4().to_string().replace('-', ""));
|
||||||
"msg_{}",
|
|
||||||
uuid::Uuid::new_v4().to_string().replace('-', "")
|
|
||||||
);
|
|
||||||
|
|
||||||
// Check for captured function calls from MITM (clears the active flag)
|
// Check for captured function calls from MITM (clears the active flag)
|
||||||
let captured_tool_calls = state.mitm_store.take_any_function_calls().await;
|
let captured_tool_calls = state.mitm_store.take_any_function_calls().await;
|
||||||
@@ -689,7 +740,10 @@ async fn handle_responses_sync(
|
|||||||
uuid::Uuid::new_v4().to_string().replace('-', "")[..24].to_string()
|
uuid::Uuid::new_v4().to_string().replace('-', "")[..24].to_string()
|
||||||
);
|
);
|
||||||
// Register call_id → name mapping for tool result routing
|
// Register call_id → name mapping for tool result routing
|
||||||
state.mitm_store.register_call_id(call_id.clone(), fc.name.clone()).await;
|
state
|
||||||
|
.mitm_store
|
||||||
|
.register_call_id(call_id.clone(), fc.name.clone())
|
||||||
|
.await;
|
||||||
|
|
||||||
// Stringify args (OpenAI sends arguments as JSON string)
|
// Stringify args (OpenAI sends arguments as JSON string)
|
||||||
let arguments = serde_json::to_string(&fc.args).unwrap_or_default();
|
let arguments = serde_json::to_string(&fc.args).unwrap_or_default();
|
||||||
@@ -697,9 +751,13 @@ async fn handle_responses_sync(
|
|||||||
}
|
}
|
||||||
|
|
||||||
let (usage, _) = usage_from_poll(
|
let (usage, _) = usage_from_poll(
|
||||||
&state.mitm_store, &cascade_id, &poll_result.usage,
|
&state.mitm_store,
|
||||||
¶ms.user_text, &poll_result.text,
|
&cascade_id,
|
||||||
).await;
|
&poll_result.usage,
|
||||||
|
¶ms.user_text,
|
||||||
|
&poll_result.text,
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
|
||||||
let resp = build_response_object(
|
let resp = build_response_object(
|
||||||
ResponseData {
|
ResponseData {
|
||||||
@@ -719,7 +777,14 @@ async fn handle_responses_sync(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Normal text response (no tool calls)
|
// Normal text response (no tool calls)
|
||||||
let (usage, mitm_thinking) = usage_from_poll(&state.mitm_store, &cascade_id, &poll_result.usage, ¶ms.user_text, &poll_result.text).await;
|
let (usage, mitm_thinking) = usage_from_poll(
|
||||||
|
&state.mitm_store,
|
||||||
|
&cascade_id,
|
||||||
|
&poll_result.usage,
|
||||||
|
¶ms.user_text,
|
||||||
|
&poll_result.text,
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
|
||||||
// Thinking text priority: MITM-captured (raw API) > LS-extracted (steps)
|
// Thinking text priority: MITM-captured (raw API) > LS-extracted (steps)
|
||||||
let thinking_text = mitm_thinking.or(poll_result.thinking);
|
let thinking_text = mitm_thinking.or(poll_result.thinking);
|
||||||
@@ -1560,4 +1625,3 @@ fn completion_events(
|
|||||||
|
|
||||||
events
|
events
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -126,7 +126,9 @@ pub(crate) struct CompletionRequest {
|
|||||||
pub web_search: bool,
|
pub web_search: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
fn default_n() -> u32 { 1 }
|
fn default_n() -> u32 {
|
||||||
|
1
|
||||||
|
}
|
||||||
|
|
||||||
/// Stop sequence can be a single string or array of strings (OpenAI accepts both).
|
/// Stop sequence can be a single string or array of strings (OpenAI accepts both).
|
||||||
#[derive(Deserialize, Clone)]
|
#[derive(Deserialize, Clone)]
|
||||||
@@ -254,8 +256,7 @@ pub(crate) struct OutputTokensDetails {
|
|||||||
pub reasoning_tokens: u64,
|
pub reasoning_tokens: u64,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Serialize, Clone)]
|
#[derive(Serialize, Clone, Default)]
|
||||||
#[derive(Default)]
|
|
||||||
pub(crate) struct Reasoning {
|
pub(crate) struct Reasoning {
|
||||||
pub effort: Option<String>,
|
pub effort: Option<String>,
|
||||||
pub summary: Option<String>,
|
pub summary: Option<String>,
|
||||||
@@ -313,7 +314,6 @@ impl Default for Usage {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
impl Default for TextFormat {
|
impl Default for TextFormat {
|
||||||
fn default() -> Self {
|
fn default() -> Self {
|
||||||
Self {
|
Self {
|
||||||
|
|||||||
@@ -27,7 +27,9 @@ pub(crate) fn err_response(
|
|||||||
|
|
||||||
/// Convert a MITM-captured upstream error from Google into an HTTP response.
|
/// Convert a MITM-captured upstream error from Google into an HTTP response.
|
||||||
/// Maps Google's HTTP status codes and preserves the error message.
|
/// Maps Google's HTTP status codes and preserves the error message.
|
||||||
pub(crate) fn upstream_err_response(err: &crate::mitm::store::UpstreamError) -> axum::response::Response {
|
pub(crate) fn upstream_err_response(
|
||||||
|
err: &crate::mitm::store::UpstreamError,
|
||||||
|
) -> axum::response::Response {
|
||||||
// Map Google's status code to HTTP status
|
// Map Google's status code to HTTP status
|
||||||
let status = StatusCode::from_u16(err.status).unwrap_or(StatusCode::BAD_GATEWAY);
|
let status = StatusCode::from_u16(err.status).unwrap_or(StatusCode::BAD_GATEWAY);
|
||||||
|
|
||||||
@@ -41,7 +43,9 @@ pub(crate) fn upstream_err_response(err: &crate::mitm::store::UpstreamError) ->
|
|||||||
_ => "upstream_error",
|
_ => "upstream_error",
|
||||||
};
|
};
|
||||||
|
|
||||||
let message = err.message.clone()
|
let message = err
|
||||||
|
.message
|
||||||
|
.clone()
|
||||||
.unwrap_or_else(|| format!("Google API returned HTTP {}", err.status));
|
.unwrap_or_else(|| format!("Google API returned HTTP {}", err.status));
|
||||||
|
|
||||||
err_response(status, message, error_type)
|
err_response(status, message, error_type)
|
||||||
@@ -99,7 +103,8 @@ pub(crate) fn extract_image_from_content(item: &serde_json::Value) -> Option<Ima
|
|||||||
}
|
}
|
||||||
// OpenAI Responses API format
|
// OpenAI Responses API format
|
||||||
"input_image" => {
|
"input_image" => {
|
||||||
let url = item["image_url"].as_str()
|
let url = item["image_url"]
|
||||||
|
.as_str()
|
||||||
.or_else(|| item["url"].as_str())?;
|
.or_else(|| item["url"].as_str())?;
|
||||||
parse_data_uri(url)
|
parse_data_uri(url)
|
||||||
}
|
}
|
||||||
@@ -109,5 +114,8 @@ pub(crate) fn extract_image_from_content(item: &serde_json::Value) -> Option<Ima
|
|||||||
|
|
||||||
/// Extract the first image from a content array (Value::Array of content parts).
|
/// Extract the first image from a content array (Value::Array of content parts).
|
||||||
pub(crate) fn extract_first_image(content: &serde_json::Value) -> Option<ImageData> {
|
pub(crate) fn extract_first_image(content: &serde_json::Value) -> Option<ImageData> {
|
||||||
content.as_array()?.iter().find_map(extract_image_from_content)
|
content
|
||||||
|
.as_array()?
|
||||||
|
.iter()
|
||||||
|
.find_map(extract_image_from_content)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -48,10 +48,7 @@ static STATIC_HEADERS: LazyLock<HeaderMap> = LazyLock::new(|| {
|
|||||||
*CHROME_MAJOR,
|
*CHROME_MAJOR,
|
||||||
)),
|
)),
|
||||||
);
|
);
|
||||||
h.insert(
|
h.insert(HeaderName::from_static("sec-ch-ua-mobile"), hv("?0"));
|
||||||
HeaderName::from_static("sec-ch-ua-mobile"),
|
|
||||||
hv("?0"),
|
|
||||||
);
|
|
||||||
h.insert(
|
h.insert(
|
||||||
HeaderName::from_static("sec-ch-ua-platform"),
|
HeaderName::from_static("sec-ch-ua-platform"),
|
||||||
hv("\"Linux\""),
|
hv("\"Linux\""),
|
||||||
@@ -72,7 +69,7 @@ impl Backend {
|
|||||||
// wreq with Chrome impersonation: BoringSSL + Chrome JA3/JA4 + H2 fingerprint
|
// wreq with Chrome impersonation: BoringSSL + Chrome JA3/JA4 + H2 fingerprint
|
||||||
let client = wreq::Client::builder()
|
let client = wreq::Client::builder()
|
||||||
.emulation(wreq_util::Emulation::Chrome142)
|
.emulation(wreq_util::Emulation::Chrome142)
|
||||||
.cert_verification(false) // LS uses self-signed cert
|
.cert_verification(false) // LS uses self-signed cert
|
||||||
.verify_hostname(false)
|
.verify_hostname(false)
|
||||||
.build()
|
.build()
|
||||||
.map_err(|e| format!("wreq client build failed: {e}"))?;
|
.map_err(|e| format!("wreq client build failed: {e}"))?;
|
||||||
@@ -86,11 +83,7 @@ impl Backend {
|
|||||||
/// Create a Backend with known connection details (for standalone LS).
|
/// Create a Backend with known connection details (for standalone LS).
|
||||||
///
|
///
|
||||||
/// Skips auto-discovery — the caller provides the port, CSRF, and OAuth token.
|
/// Skips auto-discovery — the caller provides the port, CSRF, and OAuth token.
|
||||||
pub fn new_with_config(
|
pub fn new_with_config(port: u16, csrf: String, oauth_token: String) -> Result<Self, String> {
|
||||||
port: u16,
|
|
||||||
csrf: String,
|
|
||||||
oauth_token: String,
|
|
||||||
) -> Result<Self, String> {
|
|
||||||
let inner = BackendInner {
|
let inner = BackendInner {
|
||||||
pid: "standalone".to_string(),
|
pid: "standalone".to_string(),
|
||||||
csrf,
|
csrf,
|
||||||
@@ -212,10 +205,7 @@ impl Backend {
|
|||||||
fn common_headers(csrf: &str) -> HeaderMap {
|
fn common_headers(csrf: &str) -> HeaderMap {
|
||||||
let mut h = STATIC_HEADERS.clone();
|
let mut h = STATIC_HEADERS.clone();
|
||||||
if let Ok(val) = HeaderValue::from_str(csrf) {
|
if let Ok(val) = HeaderValue::from_str(csrf) {
|
||||||
h.insert(
|
h.insert(HeaderName::from_static("x-codeium-csrf-token"), val);
|
||||||
HeaderName::from_static("x-codeium-csrf-token"),
|
|
||||||
val,
|
|
||||||
);
|
|
||||||
} else {
|
} else {
|
||||||
warn!("CSRF token contains invalid header characters, omitting");
|
warn!("CSRF token contains invalid header characters, omitting");
|
||||||
}
|
}
|
||||||
@@ -239,8 +229,8 @@ impl Backend {
|
|||||||
let mut headers = Self::common_headers(&csrf);
|
let mut headers = Self::common_headers(&csrf);
|
||||||
headers.insert("Content-Type", HeaderValue::from_static("application/json"));
|
headers.insert("Content-Type", HeaderValue::from_static("application/json"));
|
||||||
|
|
||||||
let body_bytes = serde_json::to_vec(body)
|
let body_bytes =
|
||||||
.map_err(|e| format!("JSON serialize error: {e}"))?;
|
serde_json::to_vec(body).map_err(|e| format!("JSON serialize error: {e}"))?;
|
||||||
|
|
||||||
let resp = self
|
let resp = self
|
||||||
.client
|
.client
|
||||||
@@ -258,7 +248,9 @@ impl Backend {
|
|||||||
.and_then(|v| v.to_str().ok())
|
.and_then(|v| v.to_str().ok())
|
||||||
.unwrap_or("")
|
.unwrap_or("")
|
||||||
.to_string();
|
.to_string();
|
||||||
let raw = resp.bytes().await
|
let raw = resp
|
||||||
|
.bytes()
|
||||||
|
.await
|
||||||
.map_err(|e| format!("Read body error: {e}"))?;
|
.map_err(|e| format!("Read body error: {e}"))?;
|
||||||
let resp_bytes = decompress(method, &raw, &encoding);
|
let resp_bytes = decompress(method, &raw, &encoding);
|
||||||
// High-frequency polling methods → trace; everything else → debug
|
// High-frequency polling methods → trace; everything else → debug
|
||||||
@@ -288,11 +280,7 @@ impl Backend {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Call a binary protobuf RPC method.
|
/// Call a binary protobuf RPC method.
|
||||||
pub async fn call_proto(
|
pub async fn call_proto(&self, method: &str, body: Vec<u8>) -> Result<(u16, Vec<u8>), String> {
|
||||||
&self,
|
|
||||||
method: &str,
|
|
||||||
body: Vec<u8>,
|
|
||||||
) -> Result<(u16, Vec<u8>), String> {
|
|
||||||
let (base, csrf) = {
|
let (base, csrf) = {
|
||||||
let guard = self.inner.read().await;
|
let guard = self.inner.read().await;
|
||||||
(
|
(
|
||||||
@@ -302,7 +290,10 @@ impl Backend {
|
|||||||
};
|
};
|
||||||
let url = format!("{base}/{LS_SERVICE}/{method}");
|
let url = format!("{base}/{LS_SERVICE}/{method}");
|
||||||
let mut headers = Self::common_headers(&csrf);
|
let mut headers = Self::common_headers(&csrf);
|
||||||
headers.insert("Content-Type", HeaderValue::from_static("application/proto"));
|
headers.insert(
|
||||||
|
"Content-Type",
|
||||||
|
HeaderValue::from_static("application/proto"),
|
||||||
|
);
|
||||||
|
|
||||||
let resp = self
|
let resp = self
|
||||||
.client
|
.client
|
||||||
@@ -350,7 +341,8 @@ impl Backend {
|
|||||||
text: &str,
|
text: &str,
|
||||||
model_enum: u32,
|
model_enum: u32,
|
||||||
) -> Result<(u16, Vec<u8>), String> {
|
) -> Result<(u16, Vec<u8>), String> {
|
||||||
self.send_message_with_image(cascade_id, text, model_enum, None).await
|
self.send_message_with_image(cascade_id, text, model_enum, None)
|
||||||
|
.await
|
||||||
}
|
}
|
||||||
|
|
||||||
/// SendUserCascadeMessage with optional image attachment.
|
/// SendUserCascadeMessage with optional image attachment.
|
||||||
@@ -365,7 +357,8 @@ impl Backend {
|
|||||||
if token.is_empty() {
|
if token.is_empty() {
|
||||||
return Err("No OAuth token available".to_string());
|
return Err("No OAuth token available".to_string());
|
||||||
}
|
}
|
||||||
let proto = crate::proto::build_request_with_image(cascade_id, text, &token, model_enum, image);
|
let proto =
|
||||||
|
crate::proto::build_request_with_image(cascade_id, text, &token, model_enum, image);
|
||||||
if image.is_some() {
|
if image.is_some() {
|
||||||
tracing::info!(
|
tracing::info!(
|
||||||
proto_size = proto.len(),
|
proto_size = proto.len(),
|
||||||
@@ -376,10 +369,7 @@ impl Backend {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// GetCascadeTrajectorySteps → JSON with steps array.
|
/// GetCascadeTrajectorySteps → JSON with steps array.
|
||||||
pub async fn get_steps(
|
pub async fn get_steps(&self, cascade_id: &str) -> Result<(u16, serde_json::Value), String> {
|
||||||
&self,
|
|
||||||
cascade_id: &str,
|
|
||||||
) -> Result<(u16, serde_json::Value), String> {
|
|
||||||
let body = serde_json::json!({"cascadeId": cascade_id});
|
let body = serde_json::json!({"cascadeId": cascade_id});
|
||||||
self.call_json("GetCascadeTrajectorySteps", &body).await
|
self.call_json("GetCascadeTrajectorySteps", &body).await
|
||||||
}
|
}
|
||||||
@@ -415,7 +405,10 @@ impl Backend {
|
|||||||
});
|
});
|
||||||
|
|
||||||
let mut headers = Self::common_headers(&csrf);
|
let mut headers = Self::common_headers(&csrf);
|
||||||
headers.insert("Content-Type", HeaderValue::from_static("application/connect+json"));
|
headers.insert(
|
||||||
|
"Content-Type",
|
||||||
|
HeaderValue::from_static("application/connect+json"),
|
||||||
|
);
|
||||||
headers.insert("Connect-Protocol-Version", HeaderValue::from_static("1"));
|
headers.insert("Connect-Protocol-Version", HeaderValue::from_static("1"));
|
||||||
|
|
||||||
// Connect protocol envelope: [flags:1][length:4][payload]
|
// Connect protocol envelope: [flags:1][length:4][payload]
|
||||||
@@ -441,7 +434,8 @@ impl Backend {
|
|||||||
return Err(format!("{rpc_method} failed: {status} — {err_text}"));
|
return Err(format!("{rpc_method} failed: {status} — {err_text}"));
|
||||||
}
|
}
|
||||||
|
|
||||||
let resp_ct = resp.headers()
|
let resp_ct = resp
|
||||||
|
.headers()
|
||||||
.get("content-type")
|
.get("content-type")
|
||||||
.and_then(|v| v.to_str().ok())
|
.and_then(|v| v.to_str().ok())
|
||||||
.unwrap_or("unknown")
|
.unwrap_or("unknown")
|
||||||
@@ -495,7 +489,8 @@ impl Backend {
|
|||||||
&self,
|
&self,
|
||||||
cascade_id: &str,
|
cascade_id: &str,
|
||||||
) -> Result<tokio::sync::mpsc::Receiver<serde_json::Value>, String> {
|
) -> Result<tokio::sync::mpsc::Receiver<serde_json::Value>, String> {
|
||||||
self.stream_reactive_rpc("StreamCascadeReactiveUpdates", cascade_id).await
|
self.stream_reactive_rpc("StreamCascadeReactiveUpdates", cascade_id)
|
||||||
|
.await
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -506,7 +501,10 @@ fn discover() -> Result<BackendInner, String> {
|
|||||||
// the wrapper is a shell script named language_server_linux_x64, while
|
// the wrapper is a shell script named language_server_linux_x64, while
|
||||||
// the real binary is language_server_linux_x64.real)
|
// the real binary is language_server_linux_x64.real)
|
||||||
let pid_output = Command::new("sh")
|
let pid_output = Command::new("sh")
|
||||||
.args(["-c", "pgrep -f 'language_server_linux_x64\\.real' | head -1"])
|
.args([
|
||||||
|
"-c",
|
||||||
|
"pgrep -f 'language_server_linux_x64\\.real' | head -1",
|
||||||
|
])
|
||||||
.output()
|
.output()
|
||||||
.map_err(|e| format!("pgrep failed: {e}"))?;
|
.map_err(|e| format!("pgrep failed: {e}"))?;
|
||||||
|
|
||||||
@@ -564,9 +562,8 @@ fn discover() -> Result<BackendInner, String> {
|
|||||||
LazyLock::new(|| regex::Regex::new(r"port at (\d+) for HTTPS").unwrap());
|
LazyLock::new(|| regex::Regex::new(r"port at (\d+) for HTTPS").unwrap());
|
||||||
|
|
||||||
for d in &dirs {
|
for d in &dirs {
|
||||||
let log_path = format!(
|
let log_path =
|
||||||
"{log_base}/{d}/window1/exthost/google.antigravity/Antigravity.log"
|
format!("{log_base}/{d}/window1/exthost/google.antigravity/Antigravity.log");
|
||||||
);
|
|
||||||
if let Ok(contents) = fs::read_to_string(&log_path) {
|
if let Ok(contents) = fs::read_to_string(&log_path) {
|
||||||
for line in contents.lines() {
|
for line in contents.lines() {
|
||||||
if line.contains(&pid) && line.contains("listening") && line.contains("HTTPS") {
|
if line.contains(&pid) && line.contains("listening") && line.contains("HTTPS") {
|
||||||
@@ -584,10 +581,7 @@ fn discover() -> Result<BackendInner, String> {
|
|||||||
|
|
||||||
if https_port.is_empty() {
|
if https_port.is_empty() {
|
||||||
// Fallback: find the LS HTTPS port via `ss` (when log file hasn't caught up)
|
// Fallback: find the LS HTTPS port via `ss` (when log file hasn't caught up)
|
||||||
if let Ok(output) = std::process::Command::new("ss")
|
if let Ok(output) = std::process::Command::new("ss").args(["-tlnp"]).output() {
|
||||||
.args(["-tlnp"])
|
|
||||||
.output()
|
|
||||||
{
|
|
||||||
let ss_out = String::from_utf8_lossy(&output.stdout);
|
let ss_out = String::from_utf8_lossy(&output.stdout);
|
||||||
// Find listening ports for this PID — typically the first is HTTPS
|
// Find listening ports for this PID — typically the first is HTTPS
|
||||||
for line in ss_out.lines() {
|
for line in ss_out.lines() {
|
||||||
@@ -653,7 +647,11 @@ fn decompress(method: &str, data: &[u8], encoding: &str) -> Vec<u8> {
|
|||||||
Err(e) => {
|
Err(e) => {
|
||||||
if !encoding.is_empty() {
|
if !encoding.is_empty() {
|
||||||
let preview = String::from_utf8_lossy(&data[..data.len().min(100)]);
|
let preview = String::from_utf8_lossy(&data[..data.len().min(100)]);
|
||||||
warn!("{method}: {encoding} decompress failed ({} bytes): {e}. Raw: {}", data.len(), preview);
|
warn!(
|
||||||
|
"{method}: {encoding} decompress failed ({} bytes): {e}. Raw: {}",
|
||||||
|
data.len(),
|
||||||
|
preview
|
||||||
|
);
|
||||||
}
|
}
|
||||||
data.to_vec()
|
data.to_vec()
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -115,9 +115,7 @@ fn detect_versions() -> DetectedVersions {
|
|||||||
const FALLBACK_CLIENT: &str = "1.16.5";
|
const FALLBACK_CLIENT: &str = "1.16.5";
|
||||||
|
|
||||||
let Some(install_dir) = find_install_dir() else {
|
let Some(install_dir) = find_install_dir() else {
|
||||||
tracing::warn!(
|
tracing::warn!("Could not find Antigravity install — using fallback versions");
|
||||||
"Could not find Antigravity install — using fallback versions"
|
|
||||||
);
|
|
||||||
return DetectedVersions {
|
return DetectedVersions {
|
||||||
antigravity: FALLBACK_ANTIGRAVITY.to_string(),
|
antigravity: FALLBACK_ANTIGRAVITY.to_string(),
|
||||||
chrome: FALLBACK_CHROME.to_string(),
|
chrome: FALLBACK_CHROME.to_string(),
|
||||||
|
|||||||
70
src/main.rs
70
src/main.rs
@@ -24,7 +24,10 @@ use tracing::{info, warn};
|
|||||||
use mitm::store::MitmStore;
|
use mitm::store::MitmStore;
|
||||||
|
|
||||||
#[derive(Parser)]
|
#[derive(Parser)]
|
||||||
#[command(name = "antigravity-proxy", about = "Antigravity OpenAI Proxy (stealth)")]
|
#[command(
|
||||||
|
name = "antigravity-proxy",
|
||||||
|
about = "Antigravity OpenAI Proxy (stealth)"
|
||||||
|
)]
|
||||||
struct Cli {
|
struct Cli {
|
||||||
/// Port to listen on
|
/// Port to listen on
|
||||||
#[arg(long, default_value_t = 8741)]
|
#[arg(long, default_value_t = 8741)]
|
||||||
@@ -93,15 +96,12 @@ async fn main() {
|
|||||||
};
|
};
|
||||||
|
|
||||||
let filter = if log_level.is_empty() {
|
let filter = if log_level.is_empty() {
|
||||||
tracing_subscriber::EnvFilter::try_from_default_env()
|
tracing_subscriber::EnvFilter::try_from_default_env().unwrap_or_else(|_| "warn".into())
|
||||||
.unwrap_or_else(|_| "warn".into())
|
|
||||||
} else {
|
} else {
|
||||||
tracing_subscriber::EnvFilter::new(log_level)
|
tracing_subscriber::EnvFilter::new(log_level)
|
||||||
};
|
};
|
||||||
|
|
||||||
tracing_subscriber::fmt()
|
tracing_subscriber::fmt().with_env_filter(filter).init();
|
||||||
.with_env_filter(filter)
|
|
||||||
.init();
|
|
||||||
|
|
||||||
// ── Step 1: Bind main port (auto-kill stale process if needed) ─────────────
|
// ── Step 1: Bind main port (auto-kill stale process if needed) ─────────────
|
||||||
let addr = format!("127.0.0.1:{}", cli.port);
|
let addr = format!("127.0.0.1:{}", cli.port);
|
||||||
@@ -111,7 +111,10 @@ async fn main() {
|
|||||||
// Port in use — try to kill whatever's holding it
|
// Port in use — try to kill whatever's holding it
|
||||||
eprintln!(" Port {} in use, killing stale process...", cli.port);
|
eprintln!(" Port {} in use, killing stale process...", cli.port);
|
||||||
let _ = std::process::Command::new("sh")
|
let _ = std::process::Command::new("sh")
|
||||||
.args(["-c", &format!("kill $(lsof -ti:{}) 2>/dev/null; sleep 0.3", cli.port)])
|
.args([
|
||||||
|
"-c",
|
||||||
|
&format!("kill $(lsof -ti:{}) 2>/dev/null; sleep 0.3", cli.port),
|
||||||
|
])
|
||||||
.status();
|
.status();
|
||||||
// Also kill any leftover standalone LS processes
|
// Also kill any leftover standalone LS processes
|
||||||
let _ = std::process::Command::new("pkill")
|
let _ = std::process::Command::new("pkill")
|
||||||
@@ -180,7 +183,9 @@ async fn main() {
|
|||||||
Ok(c) => c,
|
Ok(c) => c,
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
eprintln!("Fatal: {e}");
|
eprintln!("Fatal: {e}");
|
||||||
eprintln!("Hint: start Antigravity first, or remove --classic to use headless mode");
|
eprintln!(
|
||||||
|
"Hint: start Antigravity first, or remove --classic to use headless mode"
|
||||||
|
);
|
||||||
std::process::exit(1);
|
std::process::exit(1);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -199,13 +204,14 @@ async fn main() {
|
|||||||
None
|
None
|
||||||
};
|
};
|
||||||
|
|
||||||
let mut ls = match standalone::StandaloneLS::spawn(&main_config, mitm_cfg.as_ref(), headless) {
|
let mut ls =
|
||||||
Ok(ls) => ls,
|
match standalone::StandaloneLS::spawn(&main_config, mitm_cfg.as_ref(), headless) {
|
||||||
Err(e) => {
|
Ok(ls) => ls,
|
||||||
eprintln!("Fatal: failed to spawn standalone LS: {e}");
|
Err(e) => {
|
||||||
std::process::exit(1);
|
eprintln!("Fatal: failed to spawn standalone LS: {e}");
|
||||||
}
|
std::process::exit(1);
|
||||||
};
|
}
|
||||||
|
};
|
||||||
// Wait for it to be ready
|
// Wait for it to be ready
|
||||||
let rt_ls_port = ls.port;
|
let rt_ls_port = ls.port;
|
||||||
let rt_ls_csrf = ls.csrf.clone();
|
let rt_ls_csrf = ls.csrf.clone();
|
||||||
@@ -294,7 +300,15 @@ async fn main() {
|
|||||||
// ── Step 5: Start serving ─────────────────────────────────────────────────
|
// ── Step 5: Start serving ─────────────────────────────────────────────────
|
||||||
let app = api::router(state.clone());
|
let app = api::router(state.clone());
|
||||||
|
|
||||||
print_banner(cli.port, &pid, &https_port, &csrf, &token, &mitm_port_actual, is_standalone);
|
print_banner(
|
||||||
|
cli.port,
|
||||||
|
&pid,
|
||||||
|
&https_port,
|
||||||
|
&csrf,
|
||||||
|
&token,
|
||||||
|
&mitm_port_actual,
|
||||||
|
is_standalone,
|
||||||
|
);
|
||||||
info!("Listening on http://{addr}");
|
info!("Listening on http://{addr}");
|
||||||
|
|
||||||
axum::serve(listener, app)
|
axum::serve(listener, app)
|
||||||
@@ -349,7 +363,15 @@ async fn shutdown_signal() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn print_banner(port: u16, pid: &str, https_port: &str, csrf: &str, token: &str, mitm: &Option<(u16, String)>, is_standalone: bool) {
|
fn print_banner(
|
||||||
|
port: u16,
|
||||||
|
pid: &str,
|
||||||
|
https_port: &str,
|
||||||
|
csrf: &str,
|
||||||
|
token: &str,
|
||||||
|
mitm: &Option<(u16, String)>,
|
||||||
|
is_standalone: bool,
|
||||||
|
) {
|
||||||
let chrome_major = &*constants::CHROME_MAJOR;
|
let chrome_major = &*constants::CHROME_MAJOR;
|
||||||
let ver = crate::constants::antigravity_version();
|
let ver = crate::constants::antigravity_version();
|
||||||
|
|
||||||
@@ -401,7 +423,11 @@ fn print_banner(port: u16, pid: &str, https_port: &str, csrf: &str, token: &str,
|
|||||||
println!();
|
println!();
|
||||||
|
|
||||||
// Status line
|
// Status line
|
||||||
let mitm_tag = if mitm.is_some() { "\x1b[32mmitm\x1b[0m" } else { "\x1b[31mmitm\x1b[0m" };
|
let mitm_tag = if mitm.is_some() {
|
||||||
|
"\x1b[32mmitm\x1b[0m"
|
||||||
|
} else {
|
||||||
|
"\x1b[31mmitm\x1b[0m"
|
||||||
|
};
|
||||||
println!(" \x1b[2mstealth:\x1b[0m \x1b[32mwarmup\x1b[0m \x1b[32mheartbeat\x1b[0m \x1b[32mjitter\x1b[0m {mitm_tag}");
|
println!(" \x1b[2mstealth:\x1b[0m \x1b[32mwarmup\x1b[0m \x1b[32mheartbeat\x1b[0m \x1b[32mjitter\x1b[0m {mitm_tag}");
|
||||||
println!();
|
println!();
|
||||||
|
|
||||||
@@ -421,7 +447,9 @@ fn print_banner(port: u16, pid: &str, https_port: &str, csrf: &str, token: &str,
|
|||||||
if token == "NOT SET" {
|
if token == "NOT SET" {
|
||||||
println!(" \x1b[1;33m[!]\x1b[0m no oauth token");
|
println!(" \x1b[1;33m[!]\x1b[0m no oauth token");
|
||||||
println!(" export ANTIGRAVITY_OAUTH_TOKEN=ya29.xxx");
|
println!(" export ANTIGRAVITY_OAUTH_TOKEN=ya29.xxx");
|
||||||
println!(" curl -X POST http://127.0.0.1:{port}/v1/token -d '{{\"token\":\"ya29.xxx\"}}'");
|
println!(
|
||||||
|
" curl -X POST http://127.0.0.1:{port}/v1/token -d '{{\"token\":\"ya29.xxx\"}}'"
|
||||||
|
);
|
||||||
println!(" echo 'ya29.xxx' > ~/.config/antigravity-proxy-token");
|
println!(" echo 'ya29.xxx' > ~/.config/antigravity-proxy-token");
|
||||||
println!();
|
println!();
|
||||||
}
|
}
|
||||||
@@ -476,5 +504,7 @@ fn find_ls_binary_path() -> Option<String> {
|
|||||||
/// Get the data directory for storing MITM CA cert/key.
|
/// Get the data directory for storing MITM CA cert/key.
|
||||||
fn dirs_data_dir() -> std::path::PathBuf {
|
fn dirs_data_dir() -> std::path::PathBuf {
|
||||||
let home = std::env::var("HOME").unwrap_or_else(|_| "/tmp".to_string());
|
let home = std::env::var("HOME").unwrap_or_else(|_| "/tmp".to_string());
|
||||||
std::path::PathBuf::from(home).join(".config").join("antigravity-proxy")
|
std::path::PathBuf::from(home)
|
||||||
|
.join(".config")
|
||||||
|
.join("antigravity-proxy")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -4,8 +4,8 @@
|
|||||||
//! Dynamically generates per-domain leaf certificates signed by this CA.
|
//! Dynamically generates per-domain leaf certificates signed by this CA.
|
||||||
|
|
||||||
use rcgen::{
|
use rcgen::{
|
||||||
BasicConstraints, CertificateParams, DistinguishedName, DnType, ExtendedKeyUsagePurpose,
|
BasicConstraints, CertificateParams, DistinguishedName, DnType, ExtendedKeyUsagePurpose, IsCa,
|
||||||
IsCa, KeyPair, KeyUsagePurpose, SanType,
|
KeyPair, KeyUsagePurpose, SanType,
|
||||||
};
|
};
|
||||||
use rustls::pki_types::{CertificateDer, PrivateKeyDer, PrivatePkcs8KeyDer};
|
use rustls::pki_types::{CertificateDer, PrivateKeyDer, PrivatePkcs8KeyDer};
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
@@ -45,15 +45,16 @@ impl MitmCa {
|
|||||||
let key_pem = std::fs::read_to_string(&key_path)
|
let key_pem = std::fs::read_to_string(&key_path)
|
||||||
.map_err(|e| format!("Failed to read CA key: {e}"))?;
|
.map_err(|e| format!("Failed to read CA key: {e}"))?;
|
||||||
|
|
||||||
let ca_key = KeyPair::from_pem(&key_pem)
|
let ca_key =
|
||||||
.map_err(|e| format!("Failed to parse CA key: {e}"))?;
|
KeyPair::from_pem(&key_pem).map_err(|e| format!("Failed to parse CA key: {e}"))?;
|
||||||
|
|
||||||
// Re-create params and self-sign to get the rcgen Certificate object
|
// Re-create params and self-sign to get the rcgen Certificate object
|
||||||
// (needed for signing leaf certs — rcgen 0.13 doesn't have from_ca_cert_pem).
|
// (needed for signing leaf certs — rcgen 0.13 doesn't have from_ca_cert_pem).
|
||||||
// The re-signed cert will have a different serial/notBefore, but that's fine
|
// The re-signed cert will have a different serial/notBefore, but that's fine
|
||||||
// because we only use it for the rcgen signing API, NOT for the on-disk PEM.
|
// because we only use it for the rcgen signing API, NOT for the on-disk PEM.
|
||||||
let params = Self::ca_params();
|
let params = Self::ca_params();
|
||||||
let ca_signed = params.self_signed(&ca_key)
|
let ca_signed = params
|
||||||
|
.self_signed(&ca_key)
|
||||||
.map_err(|e| format!("Failed to self-sign CA: {e}"))?;
|
.map_err(|e| format!("Failed to self-sign CA: {e}"))?;
|
||||||
|
|
||||||
// Use the ORIGINAL on-disk PEM cert for DER — this is what the LS trusts
|
// Use the ORIGINAL on-disk PEM cert for DER — this is what the LS trusts
|
||||||
@@ -76,11 +77,12 @@ impl MitmCa {
|
|||||||
std::fs::create_dir_all(data_dir)
|
std::fs::create_dir_all(data_dir)
|
||||||
.map_err(|e| format!("Failed to create data dir: {e}"))?;
|
.map_err(|e| format!("Failed to create data dir: {e}"))?;
|
||||||
|
|
||||||
let ca_key = KeyPair::generate()
|
let ca_key =
|
||||||
.map_err(|e| format!("Failed to generate CA key: {e}"))?;
|
KeyPair::generate().map_err(|e| format!("Failed to generate CA key: {e}"))?;
|
||||||
|
|
||||||
let params = Self::ca_params();
|
let params = Self::ca_params();
|
||||||
let ca_signed = params.self_signed(&ca_key)
|
let ca_signed = params
|
||||||
|
.self_signed(&ca_key)
|
||||||
.map_err(|e| format!("Failed to self-sign CA: {e}"))?;
|
.map_err(|e| format!("Failed to self-sign CA: {e}"))?;
|
||||||
|
|
||||||
// Write cert and key to disk
|
// Write cert and key to disk
|
||||||
@@ -117,10 +119,7 @@ impl MitmCa {
|
|||||||
params.distinguished_name = dn;
|
params.distinguished_name = dn;
|
||||||
|
|
||||||
params.is_ca = IsCa::Ca(BasicConstraints::Unconstrained);
|
params.is_ca = IsCa::Ca(BasicConstraints::Unconstrained);
|
||||||
params.key_usages = vec![
|
params.key_usages = vec![KeyUsagePurpose::KeyCertSign, KeyUsagePurpose::CrlSign];
|
||||||
KeyUsagePurpose::KeyCertSign,
|
|
||||||
KeyUsagePurpose::CrlSign,
|
|
||||||
];
|
|
||||||
|
|
||||||
// Valid for 10 years
|
// Valid for 10 years
|
||||||
let now = time::OffsetDateTime::now_utc();
|
let now = time::OffsetDateTime::now_utc();
|
||||||
@@ -151,12 +150,17 @@ impl MitmCa {
|
|||||||
return None;
|
return None;
|
||||||
}
|
}
|
||||||
use base64::Engine;
|
use base64::Engine;
|
||||||
let der = base64::engine::general_purpose::STANDARD.decode(&b64).ok()?;
|
let der = base64::engine::general_purpose::STANDARD
|
||||||
|
.decode(&b64)
|
||||||
|
.ok()?;
|
||||||
Some(CertificateDer::from(der))
|
Some(CertificateDer::from(der))
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Get or create a TLS ServerConfig for the given domain.
|
/// Get or create a TLS ServerConfig for the given domain.
|
||||||
pub async fn server_config_for_domain(&self, domain: &str) -> Result<Arc<rustls::ServerConfig>, String> {
|
pub async fn server_config_for_domain(
|
||||||
|
&self,
|
||||||
|
domain: &str,
|
||||||
|
) -> Result<Arc<rustls::ServerConfig>, String> {
|
||||||
// Check cache first
|
// Check cache first
|
||||||
{
|
{
|
||||||
let cache = self.domain_cache.read().await;
|
let cache = self.domain_cache.read().await;
|
||||||
@@ -172,7 +176,11 @@ impl MitmCa {
|
|||||||
dn.push(DnType::CommonName, domain);
|
dn.push(DnType::CommonName, domain);
|
||||||
params.distinguished_name = dn;
|
params.distinguished_name = dn;
|
||||||
|
|
||||||
params.subject_alt_names = vec![SanType::DnsName(domain.try_into().map_err(|e| format!("Invalid domain: {e}"))?)];
|
params.subject_alt_names = vec![SanType::DnsName(
|
||||||
|
domain
|
||||||
|
.try_into()
|
||||||
|
.map_err(|e| format!("Invalid domain: {e}"))?,
|
||||||
|
)];
|
||||||
params.extended_key_usages = vec![ExtendedKeyUsagePurpose::ServerAuth];
|
params.extended_key_usages = vec![ExtendedKeyUsagePurpose::ServerAuth];
|
||||||
params.key_usages = vec![
|
params.key_usages = vec![
|
||||||
KeyUsagePurpose::DigitalSignature,
|
KeyUsagePurpose::DigitalSignature,
|
||||||
@@ -184,10 +192,11 @@ impl MitmCa {
|
|||||||
params.not_before = now;
|
params.not_before = now;
|
||||||
params.not_after = now + time::Duration::days(365);
|
params.not_after = now + time::Duration::days(365);
|
||||||
|
|
||||||
let leaf_key = KeyPair::generate()
|
let leaf_key =
|
||||||
.map_err(|e| format!("Failed to generate leaf key: {e}"))?;
|
KeyPair::generate().map_err(|e| format!("Failed to generate leaf key: {e}"))?;
|
||||||
|
|
||||||
let leaf_cert = params.signed_by(&leaf_key, &self.ca_signed, &self.ca_key)
|
let leaf_cert = params
|
||||||
|
.signed_by(&leaf_key, &self.ca_signed, &self.ca_key)
|
||||||
.map_err(|e| format!("Failed to sign leaf cert for {domain}: {e}"))?;
|
.map_err(|e| format!("Failed to sign leaf cert for {domain}: {e}"))?;
|
||||||
|
|
||||||
// Build rustls ServerConfig
|
// Build rustls ServerConfig
|
||||||
@@ -196,10 +205,7 @@ impl MitmCa {
|
|||||||
|
|
||||||
let mut config = rustls::ServerConfig::builder()
|
let mut config = rustls::ServerConfig::builder()
|
||||||
.with_no_client_auth()
|
.with_no_client_auth()
|
||||||
.with_single_cert(
|
.with_single_cert(vec![leaf_cert_der, self.ca_cert_der.clone()], leaf_key_der)
|
||||||
vec![leaf_cert_der, self.ca_cert_der.clone()],
|
|
||||||
leaf_key_der,
|
|
||||||
)
|
|
||||||
.map_err(|e| format!("Failed to build ServerConfig for {domain}: {e}"))?;
|
.map_err(|e| format!("Failed to build ServerConfig for {domain}: {e}"))?;
|
||||||
|
|
||||||
// Advertise both h2 and http/1.1 so gRPC clients can negotiate HTTP/2
|
// Advertise both h2 and http/1.1 so gRPC clients can negotiate HTTP/2
|
||||||
|
|||||||
@@ -92,11 +92,10 @@ impl UpstreamPool {
|
|||||||
.map_err(|e| format!("upstream TLS to {} failed: {e}", self.domain))?;
|
.map_err(|e| format!("upstream TLS to {} failed: {e}", self.domain))?;
|
||||||
|
|
||||||
let upstream_io = TokioIo::new(upstream_tls);
|
let upstream_io = TokioIo::new(upstream_tls);
|
||||||
let (sender, conn) =
|
let (sender, conn) = hyper::client::conn::http2::Builder::new(TokioExecutor::new())
|
||||||
hyper::client::conn::http2::Builder::new(TokioExecutor::new())
|
.handshake(upstream_io)
|
||||||
.handshake(upstream_io)
|
.await
|
||||||
.await
|
.map_err(|e| format!("upstream h2 handshake to {} failed: {e}", self.domain))?;
|
||||||
.map_err(|e| format!("upstream h2 handshake to {} failed: {e}", self.domain))?;
|
|
||||||
|
|
||||||
let domain = self.domain.clone();
|
let domain = self.domain.clone();
|
||||||
tokio::spawn(async move {
|
tokio::spawn(async move {
|
||||||
@@ -215,12 +214,10 @@ async fn handle_h2_request(
|
|||||||
.unwrap_or(false);
|
.unwrap_or(false);
|
||||||
|
|
||||||
// Check if this method carries usage data
|
// Check if this method carries usage data
|
||||||
let is_usage_method = is_grpc
|
let is_usage_method = is_grpc && USAGE_METHODS.iter().any(|m| path.contains(m));
|
||||||
&& USAGE_METHODS.iter().any(|m| path.contains(m));
|
|
||||||
|
|
||||||
// Check if this is a streaming method
|
// Check if this is a streaming method
|
||||||
let is_streaming = is_grpc
|
let is_streaming = is_grpc && (path.contains("Stream") || path.contains("stream"));
|
||||||
&& (path.contains("Stream") || path.contains("stream"));
|
|
||||||
|
|
||||||
debug!(
|
debug!(
|
||||||
domain,
|
domain,
|
||||||
@@ -249,9 +246,9 @@ async fn handle_h2_request(
|
|||||||
warn!(error = %e, domain, "MITM H2: upstream connect failed");
|
warn!(error = %e, domain, "MITM H2: upstream connect failed");
|
||||||
let resp = Response::builder()
|
let resp = Response::builder()
|
||||||
.status(502)
|
.status(502)
|
||||||
.body(http_body_util::Either::Left(Full::new(
|
.body(http_body_util::Either::Left(Full::new(Bytes::from(
|
||||||
Bytes::from(format!("upstream connect failed: {e}")),
|
format!("upstream connect failed: {e}"),
|
||||||
)))
|
))))
|
||||||
.unwrap();
|
.unwrap();
|
||||||
return Ok(resp);
|
return Ok(resp);
|
||||||
}
|
}
|
||||||
@@ -261,17 +258,11 @@ async fn handle_h2_request(
|
|||||||
let upstream_uri = http::Uri::builder()
|
let upstream_uri = http::Uri::builder()
|
||||||
.scheme("https")
|
.scheme("https")
|
||||||
.authority(domain)
|
.authority(domain)
|
||||||
.path_and_query(
|
.path_and_query(uri.path_and_query().map(|pq| pq.as_str()).unwrap_or("/"))
|
||||||
uri.path_and_query()
|
|
||||||
.map(|pq| pq.as_str())
|
|
||||||
.unwrap_or("/"),
|
|
||||||
)
|
|
||||||
.build()
|
.build()
|
||||||
.unwrap_or(uri);
|
.unwrap_or(uri);
|
||||||
|
|
||||||
let mut upstream_req = Request::builder()
|
let mut upstream_req = Request::builder().method(parts.method).uri(upstream_uri);
|
||||||
.method(parts.method)
|
|
||||||
.uri(upstream_uri);
|
|
||||||
|
|
||||||
// Copy headers, skip hop-by-hop
|
// Copy headers, skip hop-by-hop
|
||||||
for (name, value) in &parts.headers {
|
for (name, value) in &parts.headers {
|
||||||
@@ -287,9 +278,9 @@ async fn handle_h2_request(
|
|||||||
Err(e) => {
|
Err(e) => {
|
||||||
let resp = Response::builder()
|
let resp = Response::builder()
|
||||||
.status(502)
|
.status(502)
|
||||||
.body(http_body_util::Either::Left(Full::new(
|
.body(http_body_util::Either::Left(Full::new(Bytes::from(
|
||||||
Bytes::from(format!("build request failed: {e}")),
|
format!("build request failed: {e}"),
|
||||||
)))
|
))))
|
||||||
.unwrap();
|
.unwrap();
|
||||||
return Ok(resp);
|
return Ok(resp);
|
||||||
}
|
}
|
||||||
@@ -302,9 +293,9 @@ async fn handle_h2_request(
|
|||||||
warn!(error = %e, domain, path = %path, "MITM H2: upstream request failed");
|
warn!(error = %e, domain, path = %path, "MITM H2: upstream request failed");
|
||||||
let resp = Response::builder()
|
let resp = Response::builder()
|
||||||
.status(502)
|
.status(502)
|
||||||
.body(http_body_util::Either::Left(Full::new(
|
.body(http_body_util::Either::Left(Full::new(Bytes::from(
|
||||||
Bytes::from(format!("upstream request failed: {e}")),
|
format!("upstream request failed: {e}"),
|
||||||
)))
|
))))
|
||||||
.unwrap();
|
.unwrap();
|
||||||
return Ok(resp);
|
return Ok(resp);
|
||||||
}
|
}
|
||||||
@@ -326,13 +317,18 @@ async fn handle_h2_request(
|
|||||||
|
|
||||||
// Spawn a task to forward body chunks and tee for usage extraction
|
// Spawn a task to forward body chunks and tee for usage extraction
|
||||||
tokio::spawn(async move {
|
tokio::spawn(async move {
|
||||||
let mut tee_buffer = if should_track_usage { Some(Vec::new()) } else { None };
|
let mut tee_buffer = if should_track_usage {
|
||||||
|
Some(Vec::new())
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
let mut body = resp_body;
|
let mut body = resp_body;
|
||||||
|
|
||||||
loop {
|
loop {
|
||||||
match body.frame().await {
|
match body.frame().await {
|
||||||
Some(Ok(frame)) => {
|
Some(Ok(frame)) => {
|
||||||
if let (Some(ref mut buf), Some(data)) = (&mut tee_buffer, frame.data_ref()) {
|
if let (Some(ref mut buf), Some(data)) = (&mut tee_buffer, frame.data_ref())
|
||||||
|
{
|
||||||
buf.extend_from_slice(data);
|
buf.extend_from_slice(data);
|
||||||
}
|
}
|
||||||
if tx.send(Ok(frame)).await.is_err() {
|
if tx.send(Ok(frame)).await.is_err() {
|
||||||
@@ -354,7 +350,9 @@ async fn handle_h2_request(
|
|||||||
if let Some(grpc_usage) = parse_grpc_response_for_usage(&tee_buffer) {
|
if let Some(grpc_usage) = parse_grpc_response_for_usage(&tee_buffer) {
|
||||||
let usage = grpc_usage.into_api_usage(path_clone.clone());
|
let usage = grpc_usage.into_api_usage(path_clone.clone());
|
||||||
let cascade_hint = extract_cascade_from_grpc_request(&request_body_clone);
|
let cascade_hint = extract_cascade_from_grpc_request(&request_body_clone);
|
||||||
store_clone.record_usage(cascade_hint.as_deref(), usage).await;
|
store_clone
|
||||||
|
.record_usage(cascade_hint.as_deref(), usage)
|
||||||
|
.await;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -78,15 +78,21 @@ impl StreamingAccumulator {
|
|||||||
Self::default()
|
Self::default()
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Process a single SSE event.
|
/// Process a single SSE event.
|
||||||
pub fn process_event(&mut self, event: &Value) {
|
pub fn process_event(&mut self, event: &Value) {
|
||||||
// ── Google format: {"response": {"usageMetadata": {...}, "modelVersion": "..."}} ──
|
// ── Google format: {"response": {"usageMetadata": {...}, "modelVersion": "..."}} ──
|
||||||
if let Some(response) = event.get("response") {
|
if let Some(response) = event.get("response") {
|
||||||
// Extract usage metadata (each event has cumulative counts)
|
// Extract usage metadata (each event has cumulative counts)
|
||||||
if let Some(usage) = response.get("usageMetadata") {
|
if let Some(usage) = response.get("usageMetadata") {
|
||||||
self.input_tokens = usage["promptTokenCount"].as_u64().unwrap_or(self.input_tokens);
|
self.input_tokens = usage["promptTokenCount"]
|
||||||
self.output_tokens = usage["candidatesTokenCount"].as_u64().unwrap_or(self.output_tokens);
|
.as_u64()
|
||||||
self.thinking_tokens = usage["thoughtsTokenCount"].as_u64().unwrap_or(self.thinking_tokens);
|
.unwrap_or(self.input_tokens);
|
||||||
|
self.output_tokens = usage["candidatesTokenCount"]
|
||||||
|
.as_u64()
|
||||||
|
.unwrap_or(self.output_tokens);
|
||||||
|
self.thinking_tokens = usage["thoughtsTokenCount"]
|
||||||
|
.as_u64()
|
||||||
|
.unwrap_or(self.thinking_tokens);
|
||||||
}
|
}
|
||||||
if let Some(model) = response["modelVersion"].as_str() {
|
if let Some(model) = response["modelVersion"].as_str() {
|
||||||
self.model = Some(model.to_string());
|
self.model = Some(model.to_string());
|
||||||
@@ -170,8 +176,10 @@ impl StreamingAccumulator {
|
|||||||
"message_start" => {
|
"message_start" => {
|
||||||
if let Some(usage) = event.get("message").and_then(|m| m.get("usage")) {
|
if let Some(usage) = event.get("message").and_then(|m| m.get("usage")) {
|
||||||
self.input_tokens = usage["input_tokens"].as_u64().unwrap_or(0);
|
self.input_tokens = usage["input_tokens"].as_u64().unwrap_or(0);
|
||||||
self.cache_creation_input_tokens = usage["cache_creation_input_tokens"].as_u64().unwrap_or(0);
|
self.cache_creation_input_tokens =
|
||||||
self.cache_read_input_tokens = usage["cache_read_input_tokens"].as_u64().unwrap_or(0);
|
usage["cache_creation_input_tokens"].as_u64().unwrap_or(0);
|
||||||
|
self.cache_read_input_tokens =
|
||||||
|
usage["cache_read_input_tokens"].as_u64().unwrap_or(0);
|
||||||
}
|
}
|
||||||
if let Some(model) = event.get("message").and_then(|m| m["model"].as_str()) {
|
if let Some(model) = event.get("message").and_then(|m| m["model"].as_str()) {
|
||||||
self.model = Some(model.to_string());
|
self.model = Some(model.to_string());
|
||||||
@@ -181,7 +189,9 @@ impl StreamingAccumulator {
|
|||||||
}
|
}
|
||||||
"message_delta" => {
|
"message_delta" => {
|
||||||
if let Some(usage) = event.get("usage") {
|
if let Some(usage) = event.get("usage") {
|
||||||
self.output_tokens = usage["output_tokens"].as_u64().unwrap_or(self.output_tokens);
|
self.output_tokens = usage["output_tokens"]
|
||||||
|
.as_u64()
|
||||||
|
.unwrap_or(self.output_tokens);
|
||||||
}
|
}
|
||||||
if let Some(reason) = event["delta"]["stop_reason"].as_str() {
|
if let Some(reason) = event["delta"]["stop_reason"].as_str() {
|
||||||
self.stop_reason = Some(reason.to_string());
|
self.stop_reason = Some(reason.to_string());
|
||||||
@@ -235,7 +245,10 @@ impl StreamingAccumulator {
|
|||||||
response_output_tokens: 0,
|
response_output_tokens: 0,
|
||||||
model: self.model,
|
model: self.model,
|
||||||
stop_reason: self.stop_reason,
|
stop_reason: self.stop_reason,
|
||||||
api_provider: self.api_provider.unwrap_or_else(|| "unknown".to_string()).into(),
|
api_provider: self
|
||||||
|
.api_provider
|
||||||
|
.unwrap_or_else(|| "unknown".to_string())
|
||||||
|
.into(),
|
||||||
grpc_method: None,
|
grpc_method: None,
|
||||||
captured_at: std::time::SystemTime::now()
|
captured_at: std::time::SystemTime::now()
|
||||||
.duration_since(std::time::UNIX_EPOCH)
|
.duration_since(std::time::UNIX_EPOCH)
|
||||||
|
|||||||
@@ -68,14 +68,14 @@ pub fn modify_request(body: &[u8], tool_ctx: Option<&ToolContext>) -> Option<Vec
|
|||||||
"system instruction: keep <identity> only ({original_len} → {} chars, -{stripped})",
|
"system instruction: keep <identity> only ({original_len} → {} chars, -{stripped})",
|
||||||
new_sys.len()
|
new_sys.len()
|
||||||
));
|
));
|
||||||
json["request"]["systemInstruction"]["parts"][0]["text"] =
|
json["request"]["systemInstruction"]["parts"][0]["text"] = Value::String(new_sys);
|
||||||
Value::String(new_sys);
|
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// No identity tag found — clear the whole thing
|
// No identity tag found — clear the whole thing
|
||||||
changes.push(format!("system instruction: cleared ({original_len} chars)"));
|
changes.push(format!(
|
||||||
json["request"]["systemInstruction"]["parts"][0]["text"] =
|
"system instruction: cleared ({original_len} chars)"
|
||||||
Value::String(String::new());
|
));
|
||||||
|
json["request"]["systemInstruction"]["parts"][0]["text"] = Value::String(String::new());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -125,7 +125,11 @@ pub fn modify_request(body: &[u8], tool_ctx: Option<&ToolContext>) -> Option<Vec
|
|||||||
let mut modified = text.clone();
|
let mut modified = text.clone();
|
||||||
|
|
||||||
// Strip conversation summaries block
|
// Strip conversation summaries block
|
||||||
if let Some(cleaned) = strip_between(&modified, "# Conversation History\n", "</conversation_summaries>") {
|
if let Some(cleaned) = strip_between(
|
||||||
|
&modified,
|
||||||
|
"# Conversation History\n",
|
||||||
|
"</conversation_summaries>",
|
||||||
|
) {
|
||||||
modified = cleaned;
|
modified = cleaned;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -147,7 +151,9 @@ pub fn modify_request(body: &[u8], tool_ctx: Option<&ToolContext>) -> Option<Vec
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Strip knowledge item blocks
|
// Strip knowledge item blocks
|
||||||
if let Some(cleaned) = strip_between(&modified, "Here are the ", "</knowledge_item>") {
|
if let Some(cleaned) =
|
||||||
|
strip_between(&modified, "Here are the ", "</knowledge_item>")
|
||||||
|
{
|
||||||
// Only strip if it's about knowledge items
|
// Only strip if it's about knowledge items
|
||||||
if cleaned.len() < modified.len() && modified.contains("knowledge item") {
|
if cleaned.len() < modified.len() && modified.contains("knowledge item") {
|
||||||
modified = cleaned;
|
modified = cleaned;
|
||||||
@@ -202,7 +208,8 @@ pub fn modify_request(body: &[u8], tool_ctx: Option<&ToolContext>) -> Option<Vec
|
|||||||
// Inject client-provided tools from ToolContext
|
// Inject client-provided tools from ToolContext
|
||||||
if let Some(ref ctx) = tool_ctx {
|
if let Some(ref ctx) = tool_ctx {
|
||||||
if let Some(ref custom_tools) = ctx.tools {
|
if let Some(ref custom_tools) = ctx.tools {
|
||||||
let total_decls: usize = custom_tools.iter()
|
let total_decls: usize = custom_tools
|
||||||
|
.iter()
|
||||||
.filter_map(|t| t.get("functionDeclarations").and_then(|d| d.as_array()))
|
.filter_map(|t| t.get("functionDeclarations").and_then(|d| d.as_array()))
|
||||||
.map(|a| a.len())
|
.map(|a| a.len())
|
||||||
.sum();
|
.sum();
|
||||||
@@ -210,7 +217,10 @@ pub fn modify_request(body: &[u8], tool_ctx: Option<&ToolContext>) -> Option<Vec
|
|||||||
tools.push(tool.clone());
|
tools.push(tool.clone());
|
||||||
}
|
}
|
||||||
has_custom_tools = true;
|
has_custom_tools = true;
|
||||||
changes.push(format!("inject {} custom tool group(s)", custom_tools.len()));
|
changes.push(format!(
|
||||||
|
"inject {} custom tool group(s)",
|
||||||
|
custom_tools.len()
|
||||||
|
));
|
||||||
|
|
||||||
// Override LS's VALIDATED toolConfig → AUTO for custom tools.
|
// Override LS's VALIDATED toolConfig → AUTO for custom tools.
|
||||||
// VALIDATED mode forces Google to validate function calls against a
|
// VALIDATED mode forces Google to validate function calls against a
|
||||||
@@ -218,16 +228,20 @@ pub fn modify_request(body: &[u8], tool_ctx: Option<&ToolContext>) -> Option<Vec
|
|||||||
// that list, so they'd be rejected. AUTO lets the model freely choose
|
// that list, so they'd be rejected. AUTO lets the model freely choose
|
||||||
// between text and function calls.
|
// between text and function calls.
|
||||||
if let Some(req) = json.get_mut("request").and_then(|v| v.as_object_mut()) {
|
if let Some(req) = json.get_mut("request").and_then(|v| v.as_object_mut()) {
|
||||||
let has_validated = req.get("toolConfig")
|
let has_validated = req
|
||||||
|
.get("toolConfig")
|
||||||
.and_then(|tc| tc.pointer("/functionCallingConfig/mode"))
|
.and_then(|tc| tc.pointer("/functionCallingConfig/mode"))
|
||||||
.and_then(|m| m.as_str())
|
.and_then(|m| m.as_str())
|
||||||
.map_or(false, |m| m == "VALIDATED");
|
.map_or(false, |m| m == "VALIDATED");
|
||||||
if has_validated {
|
if has_validated {
|
||||||
req.insert("toolConfig".to_string(), serde_json::json!({
|
req.insert(
|
||||||
"functionCallingConfig": {
|
"toolConfig".to_string(),
|
||||||
"mode": "AUTO"
|
serde_json::json!({
|
||||||
}
|
"functionCallingConfig": {
|
||||||
}));
|
"mode": "AUTO"
|
||||||
|
}
|
||||||
|
}),
|
||||||
|
);
|
||||||
changes.push("override toolConfig VALIDATED → AUTO".to_string());
|
changes.push("override toolConfig VALIDATED → AUTO".to_string());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -243,7 +257,11 @@ pub fn modify_request(body: &[u8], tool_ctx: Option<&ToolContext>) -> Option<Vec
|
|||||||
if STRIP_ALL_TOOLS && !has_custom_tools {
|
if STRIP_ALL_TOOLS && !has_custom_tools {
|
||||||
if let Some(req) = json.get_mut("request").and_then(|v| v.as_object_mut()) {
|
if let Some(req) = json.get_mut("request").and_then(|v| v.as_object_mut()) {
|
||||||
// Remove the empty tools array entirely
|
// Remove the empty tools array entirely
|
||||||
if req.get("tools").and_then(|v| v.as_array()).map_or(false, |a| a.is_empty()) {
|
if req
|
||||||
|
.get("tools")
|
||||||
|
.and_then(|v| v.as_array())
|
||||||
|
.map_or(false, |a| a.is_empty())
|
||||||
|
{
|
||||||
req.remove("tools");
|
req.remove("tools");
|
||||||
changes.push("remove empty tools array".to_string());
|
changes.push("remove empty tools array".to_string());
|
||||||
}
|
}
|
||||||
@@ -266,7 +284,8 @@ pub fn modify_request(body: &[u8], tool_ctx: Option<&ToolContext>) -> Option<Vec
|
|||||||
.as_ref()
|
.as_ref()
|
||||||
.and_then(|ctx| ctx.tools.as_ref())
|
.and_then(|ctx| ctx.tools.as_ref())
|
||||||
.map(|tools| {
|
.map(|tools| {
|
||||||
tools.iter()
|
tools
|
||||||
|
.iter()
|
||||||
.filter_map(|t| t["functionDeclarations"].as_array())
|
.filter_map(|t| t["functionDeclarations"].as_array())
|
||||||
.flatten()
|
.flatten()
|
||||||
.filter_map(|decl| decl["name"].as_str().map(|s| s.to_string()))
|
.filter_map(|decl| decl["name"].as_str().map(|s| s.to_string()))
|
||||||
@@ -309,7 +328,9 @@ pub fn modify_request(body: &[u8], tool_ctx: Option<&ToolContext>) -> Option<Vec
|
|||||||
.map_or(true, |parts| !parts.is_empty())
|
.map_or(true, |parts| !parts.is_empty())
|
||||||
});
|
});
|
||||||
if stripped_fc > 0 {
|
if stripped_fc > 0 {
|
||||||
changes.push(format!("strip {stripped_fc} functionCall/Response parts from history"));
|
changes.push(format!(
|
||||||
|
"strip {stripped_fc} functionCall/Response parts from history"
|
||||||
|
));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -336,16 +357,22 @@ pub fn modify_request(body: &[u8], tool_ctx: Option<&ToolContext>) -> Option<Vec
|
|||||||
for msg in contents.iter_mut() {
|
for msg in contents.iter_mut() {
|
||||||
if msg["role"].as_str() == Some("model") {
|
if msg["role"].as_str() == Some("model") {
|
||||||
if let Some(text) = msg["parts"][0]["text"].as_str() {
|
if let Some(text) = msg["parts"][0]["text"].as_str() {
|
||||||
if text.contains("Tool call completed") || text.contains("Awaiting external tool result") {
|
if text.contains("Tool call completed")
|
||||||
|
|| text.contains("Awaiting external tool result")
|
||||||
|
{
|
||||||
// Replace with functionCall parts
|
// Replace with functionCall parts
|
||||||
let fc_parts: Vec<Value> = ctx.last_calls.iter().map(|fc| {
|
let fc_parts: Vec<Value> = ctx
|
||||||
serde_json::json!({
|
.last_calls
|
||||||
"functionCall": {
|
.iter()
|
||||||
"name": fc.name,
|
.map(|fc| {
|
||||||
"args": fc.args,
|
serde_json::json!({
|
||||||
}
|
"functionCall": {
|
||||||
|
"name": fc.name,
|
||||||
|
"args": fc.args,
|
||||||
|
}
|
||||||
|
})
|
||||||
})
|
})
|
||||||
}).collect();
|
.collect();
|
||||||
msg["parts"] = Value::Array(fc_parts);
|
msg["parts"] = Value::Array(fc_parts);
|
||||||
changes.push("rewrite model turn with functionCall".to_string());
|
changes.push("rewrite model turn with functionCall".to_string());
|
||||||
break;
|
break;
|
||||||
@@ -355,29 +382,36 @@ pub fn modify_request(body: &[u8], tool_ctx: Option<&ToolContext>) -> Option<Vec
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Add functionResponse as a user turn before the last user message
|
// Add functionResponse as a user turn before the last user message
|
||||||
let fn_response_parts: Vec<Value> = ctx.pending_results.iter().map(|r| {
|
let fn_response_parts: Vec<Value> = ctx
|
||||||
serde_json::json!({
|
.pending_results
|
||||||
"functionResponse": {
|
.iter()
|
||||||
"name": r.name,
|
.map(|r| {
|
||||||
"response": r.result,
|
serde_json::json!({
|
||||||
}
|
"functionResponse": {
|
||||||
|
"name": r.name,
|
||||||
|
"response": r.result,
|
||||||
|
}
|
||||||
|
})
|
||||||
})
|
})
|
||||||
}).collect();
|
.collect();
|
||||||
let fn_response_turn = serde_json::json!({
|
let fn_response_turn = serde_json::json!({
|
||||||
"role": "user",
|
"role": "user",
|
||||||
"parts": fn_response_parts,
|
"parts": fn_response_parts,
|
||||||
});
|
});
|
||||||
|
|
||||||
// Insert before the last user message
|
// Insert before the last user message
|
||||||
let last_user_idx = contents.iter().rposition(|msg| {
|
let last_user_idx = contents
|
||||||
msg["role"].as_str() == Some("user")
|
.iter()
|
||||||
});
|
.rposition(|msg| msg["role"].as_str() == Some("user"));
|
||||||
if let Some(idx) = last_user_idx {
|
if let Some(idx) = last_user_idx {
|
||||||
contents.insert(idx, fn_response_turn);
|
contents.insert(idx, fn_response_turn);
|
||||||
} else {
|
} else {
|
||||||
contents.push(fn_response_turn);
|
contents.push(fn_response_turn);
|
||||||
}
|
}
|
||||||
changes.push(format!("inject {} functionResponse(s)", ctx.pending_results.len()));
|
changes.push(format!(
|
||||||
|
"inject {} functionResponse(s)",
|
||||||
|
ctx.pending_results.len()
|
||||||
|
));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -420,8 +454,10 @@ pub fn modify_request(body: &[u8], tool_ctx: Option<&ToolContext>) -> Option<Vec
|
|||||||
} else {
|
} else {
|
||||||
// Not wrapped in request — try top-level (public API format)
|
// Not wrapped in request — try top-level (public API format)
|
||||||
let gen_config = json.as_object_mut().and_then(|o| {
|
let gen_config = json.as_object_mut().and_then(|o| {
|
||||||
Some(o.entry("generationConfig")
|
Some(
|
||||||
.or_insert_with(|| serde_json::json!({})))
|
o.entry("generationConfig")
|
||||||
|
.or_insert_with(|| serde_json::json!({})),
|
||||||
|
)
|
||||||
});
|
});
|
||||||
if let Some(gc) = gen_config.and_then(|v| v.as_object_mut()) {
|
if let Some(gc) = gen_config.and_then(|v| v.as_object_mut()) {
|
||||||
let thinking_config = gc
|
let thinking_config = gc
|
||||||
@@ -449,8 +485,10 @@ pub fn modify_request(body: &[u8], tool_ctx: Option<&ToolContext>) -> Option<Vec
|
|||||||
if let Some(ref gp) = ctx.generation_params {
|
if let Some(ref gp) = ctx.generation_params {
|
||||||
// Find or create generationConfig (same path as above)
|
// Find or create generationConfig (same path as above)
|
||||||
let gc = if let Some(req) = json.get_mut("request").and_then(|v| v.as_object_mut()) {
|
let gc = if let Some(req) = json.get_mut("request").and_then(|v| v.as_object_mut()) {
|
||||||
Some(req.entry("generationConfig")
|
Some(
|
||||||
.or_insert_with(|| serde_json::json!({})))
|
req.entry("generationConfig")
|
||||||
|
.or_insert_with(|| serde_json::json!({})),
|
||||||
|
)
|
||||||
} else {
|
} else {
|
||||||
json.as_object_mut().map(|o| {
|
json.as_object_mut().map(|o| {
|
||||||
o.entry("generationConfig")
|
o.entry("generationConfig")
|
||||||
@@ -564,8 +602,6 @@ pub fn modify_request(body: &[u8], tool_ctx: Option<&ToolContext>) -> Option<Vec
|
|||||||
changes.join(", ")
|
changes.join(", ")
|
||||||
);
|
);
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
Some(modified_bytes)
|
Some(modified_bytes)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -832,8 +868,10 @@ mod tests {
|
|||||||
let result: Value = serde_json::from_slice(&modified).unwrap();
|
let result: Value = serde_json::from_slice(&modified).unwrap();
|
||||||
|
|
||||||
// With no ToolContext, tools should be removed entirely
|
// With no ToolContext, tools should be removed entirely
|
||||||
assert!(result["request"]["tools"].is_null() || result.pointer("/request/tools").is_none(),
|
assert!(
|
||||||
"tools should be removed when no custom tools provided");
|
result["request"]["tools"].is_null() || result.pointer("/request/tools").is_none(),
|
||||||
|
"tools should be removed when no custom tools provided"
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
@@ -892,13 +930,23 @@ mod tests {
|
|||||||
let contents = result["request"]["contents"].as_array().unwrap();
|
let contents = result["request"]["contents"].as_array().unwrap();
|
||||||
// Should have removed user_information, user_rules, workflows (3 messages)
|
// Should have removed user_information, user_rules, workflows (3 messages)
|
||||||
// Kept: USER_REQUEST message (with ADDITIONAL_METADATA stripped) + model response
|
// Kept: USER_REQUEST message (with ADDITIONAL_METADATA stripped) + model response
|
||||||
assert_eq!(contents.len(), 2, "should keep only user request + model response");
|
assert_eq!(
|
||||||
|
contents.len(),
|
||||||
|
2,
|
||||||
|
"should keep only user request + model response"
|
||||||
|
);
|
||||||
|
|
||||||
// Check USER_REQUEST message had metadata stripped
|
// Check USER_REQUEST message had metadata stripped
|
||||||
let user_msg = contents[0]["parts"][0]["text"].as_str().unwrap();
|
let user_msg = contents[0]["parts"][0]["text"].as_str().unwrap();
|
||||||
assert!(user_msg.contains("Say hello"), "should keep user request");
|
assert!(user_msg.contains("Say hello"), "should keep user request");
|
||||||
assert!(!user_msg.contains("ADDITIONAL_METADATA"), "should strip metadata");
|
assert!(
|
||||||
assert!(!user_msg.contains("cursor stuff"), "should strip cursor info");
|
!user_msg.contains("ADDITIONAL_METADATA"),
|
||||||
|
"should strip metadata"
|
||||||
|
);
|
||||||
|
assert!(
|
||||||
|
!user_msg.contains("cursor stuff"),
|
||||||
|
"should strip cursor info"
|
||||||
|
);
|
||||||
assert!(!user_msg.starts_with("Step Id:"), "should strip step id");
|
assert!(!user_msg.starts_with("Step Id:"), "should strip step id");
|
||||||
|
|
||||||
// Model response kept intact
|
// Model response kept intact
|
||||||
@@ -921,8 +969,14 @@ mod tests {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_strip_between() {
|
fn test_strip_between() {
|
||||||
let text = "keep this # Conversation History\nlots of stuff\n</conversation_summaries>\nand this";
|
let text =
|
||||||
let result = strip_between(text, "# Conversation History\n", "</conversation_summaries>").unwrap();
|
"keep this # Conversation History\nlots of stuff\n</conversation_summaries>\nand this";
|
||||||
|
let result = strip_between(
|
||||||
|
text,
|
||||||
|
"# Conversation History\n",
|
||||||
|
"</conversation_summaries>",
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
assert_eq!(result, "keep this and this");
|
assert_eq!(result, "keep this and this");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -977,7 +1031,9 @@ pub fn modify_response_chunk(chunk: &[u8]) -> Option<Vec<u8>> {
|
|||||||
// Replace the JSON in the result string
|
// Replace the JSON in the result string
|
||||||
result.replace_range(json_start..json_start + json_end, &new_json);
|
result.replace_range(json_start..json_start + json_end, &new_json);
|
||||||
changed = true;
|
changed = true;
|
||||||
info!("MITM: rewrote functionCall in response → text placeholder for LS");
|
info!(
|
||||||
|
"MITM: rewrote functionCall in response → text placeholder for LS"
|
||||||
|
);
|
||||||
search_from = json_start + new_json.len();
|
search_from = json_start + new_json.len();
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
@@ -1117,7 +1173,10 @@ fn rewrite_function_calls_in_response(json: &mut Value) -> bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Try nested "response.candidates"
|
// Try nested "response.candidates"
|
||||||
if let Some(candidates) = json.pointer_mut("/response/candidates").and_then(|v| v.as_array_mut()) {
|
if let Some(candidates) = json
|
||||||
|
.pointer_mut("/response/candidates")
|
||||||
|
.and_then(|v| v.as_array_mut())
|
||||||
|
{
|
||||||
changed |= rewrite_candidates(candidates);
|
changed |= rewrite_candidates(candidates);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -251,7 +251,10 @@ fn looks_like_valid_message(fields: &[ProtoField], original_len: usize) -> bool
|
|||||||
// (e.g., a long string that happened to have a valid first-field prefix)
|
// (e.g., a long string that happened to have a valid first-field prefix)
|
||||||
if fields.len() == 1 && original_len > 100 {
|
if fields.len() == 1 && original_len > 100 {
|
||||||
// Single-field messages of >100 bytes are suspicious unless the field is bytes/message
|
// Single-field messages of >100 bytes are suspicious unless the field is bytes/message
|
||||||
matches!(&fields[0].value, ProtoValue::Bytes(_) | ProtoValue::Message(_))
|
matches!(
|
||||||
|
&fields[0].value,
|
||||||
|
ProtoValue::Bytes(_) | ProtoValue::Message(_)
|
||||||
|
)
|
||||||
} else {
|
} else {
|
||||||
true
|
true
|
||||||
}
|
}
|
||||||
@@ -328,7 +331,9 @@ fn try_extract_usage(fields: &[ProtoField]) -> Option<GrpcUsage> {
|
|||||||
.iter()
|
.iter()
|
||||||
.filter_map(|f| {
|
.filter_map(|f| {
|
||||||
if let ProtoValue::Bytes(ref b) = f.value {
|
if let ProtoValue::Bytes(ref b) = f.value {
|
||||||
std::str::from_utf8(b).ok().map(|s| (f.number, s.to_string()))
|
std::str::from_utf8(b)
|
||||||
|
.ok()
|
||||||
|
.map(|s| (f.number, s.to_string()))
|
||||||
} else {
|
} else {
|
||||||
None
|
None
|
||||||
}
|
}
|
||||||
@@ -361,14 +366,23 @@ fn try_extract_usage(fields: &[ProtoField]) -> Option<GrpcUsage> {
|
|||||||
// Check if there's a model-like string (field 7 = message_id or field 11 = response_id
|
// Check if there's a model-like string (field 7 = message_id or field 11 = response_id
|
||||||
// can contain model names, or model enum values map to known names)
|
// can contain model names, or model enum values map to known names)
|
||||||
let has_model_string = string_fields.iter().any(|(_, s)| {
|
let has_model_string = string_fields.iter().any(|(_, s)| {
|
||||||
s.contains("claude") || s.contains("gemini") || s.contains("gpt")
|
s.contains("claude")
|
||||||
|| s.starts_with("models/") || s.contains("sonnet") || s.contains("opus")
|
|| s.contains("gemini")
|
||||||
|| s.contains("flash") || s.contains("pro")
|
|| s.contains("gpt")
|
||||||
|
|| s.starts_with("models/")
|
||||||
|
|| s.contains("sonnet")
|
||||||
|
|| s.contains("opus")
|
||||||
|
|| s.contains("flash")
|
||||||
|
|| s.contains("pro")
|
||||||
});
|
});
|
||||||
|
|
||||||
// Check for fields at the known ModelUsageStats field numbers
|
// Check for fields at the known ModelUsageStats field numbers
|
||||||
let has_field_2 = fields.iter().any(|f| f.number == 2 && matches!(f.value, ProtoValue::Varint(_)));
|
let has_field_2 = fields
|
||||||
let has_field_3 = fields.iter().any(|f| f.number == 3 && matches!(f.value, ProtoValue::Varint(_)));
|
.iter()
|
||||||
|
.any(|f| f.number == 2 && matches!(f.value, ProtoValue::Varint(_)));
|
||||||
|
let has_field_3 = fields
|
||||||
|
.iter()
|
||||||
|
.any(|f| f.number == 3 && matches!(f.value, ProtoValue::Varint(_)));
|
||||||
|
|
||||||
// Strong signal: has both input and output token fields
|
// Strong signal: has both input and output token fields
|
||||||
let is_likely_usage = (has_field_2 && has_field_3) || has_model_string;
|
let is_likely_usage = (has_field_2 && has_field_3) || has_model_string;
|
||||||
@@ -392,8 +406,8 @@ fn try_extract_usage(fields: &[ProtoField]) -> Option<GrpcUsage> {
|
|||||||
// field 1 = model enum (varint, not string!)
|
// field 1 = model enum (varint, not string!)
|
||||||
2 => usage.input_tokens = v,
|
2 => usage.input_tokens = v,
|
||||||
3 => usage.output_tokens = v,
|
3 => usage.output_tokens = v,
|
||||||
4 => usage.cache_write_tokens = v, // VERIFIED: field 4
|
4 => usage.cache_write_tokens = v, // VERIFIED: field 4
|
||||||
5 => usage.cache_read_tokens = v, // VERIFIED: field 5
|
5 => usage.cache_read_tokens = v, // VERIFIED: field 5
|
||||||
// field 6 = api_provider enum (varint)
|
// field 6 = api_provider enum (varint)
|
||||||
9 => usage.thinking_output_tokens = v, // VERIFIED: field 9
|
9 => usage.thinking_output_tokens = v, // VERIFIED: field 9
|
||||||
10 => usage.response_output_tokens = v, // VERIFIED: field 10
|
10 => usage.response_output_tokens = v, // VERIFIED: field 10
|
||||||
@@ -486,11 +500,11 @@ pub fn parse_grpc_response_for_usage(body: &[u8]) -> Option<GrpcUsage> {
|
|||||||
fn model_enum_name(enum_val: u64) -> &'static str {
|
fn model_enum_name(enum_val: u64) -> &'static str {
|
||||||
match enum_val {
|
match enum_val {
|
||||||
// Placeholder models (1000 + N)
|
// Placeholder models (1000 + N)
|
||||||
1007 => "gemini-3-pro", // MODEL_PLACEHOLDER_M7
|
1007 => "gemini-3-pro", // MODEL_PLACEHOLDER_M7
|
||||||
1008 => "gemini-3-pro-high", // MODEL_PLACEHOLDER_M8
|
1008 => "gemini-3-pro-high", // MODEL_PLACEHOLDER_M8
|
||||||
1012 => "claude-opus-4.5", // MODEL_PLACEHOLDER_M12
|
1012 => "claude-opus-4.5", // MODEL_PLACEHOLDER_M12
|
||||||
1018 => "gemini-3-flash", // MODEL_PLACEHOLDER_M18
|
1018 => "gemini-3-flash", // MODEL_PLACEHOLDER_M18
|
||||||
1026 => "claude-opus-4.6", // MODEL_PLACEHOLDER_M26
|
1026 => "claude-opus-4.6", // MODEL_PLACEHOLDER_M26
|
||||||
|
|
||||||
// Claude models (named)
|
// Claude models (named)
|
||||||
281 => "claude-4-sonnet",
|
281 => "claude-4-sonnet",
|
||||||
@@ -629,13 +643,13 @@ mod tests {
|
|||||||
data.push(v as u8);
|
data.push(v as u8);
|
||||||
}
|
}
|
||||||
|
|
||||||
encode_varint_field(&mut data, 1, 5); // model enum
|
encode_varint_field(&mut data, 1, 5); // model enum
|
||||||
encode_varint_field(&mut data, 2, 1000); // input_tokens
|
encode_varint_field(&mut data, 2, 1000); // input_tokens
|
||||||
encode_varint_field(&mut data, 3, 500); // output_tokens
|
encode_varint_field(&mut data, 3, 500); // output_tokens
|
||||||
encode_varint_field(&mut data, 4, 100); // cache_write_tokens
|
encode_varint_field(&mut data, 4, 100); // cache_write_tokens
|
||||||
encode_varint_field(&mut data, 5, 200); // cache_read_tokens
|
encode_varint_field(&mut data, 5, 200); // cache_read_tokens
|
||||||
encode_varint_field(&mut data, 9, 300); // thinking_output_tokens
|
encode_varint_field(&mut data, 9, 300); // thinking_output_tokens
|
||||||
encode_varint_field(&mut data, 10, 200); // response_output_tokens
|
encode_varint_field(&mut data, 10, 200); // response_output_tokens
|
||||||
|
|
||||||
let fields = decode_proto(&data);
|
let fields = decode_proto(&data);
|
||||||
let usage = try_extract_usage(&fields).expect("should extract usage");
|
let usage = try_extract_usage(&fields).expect("should extract usage");
|
||||||
|
|||||||
@@ -11,8 +11,7 @@
|
|||||||
|
|
||||||
use super::ca::MitmCa;
|
use super::ca::MitmCa;
|
||||||
use super::intercept::{
|
use super::intercept::{
|
||||||
extract_cascade_hint, parse_non_streaming_response, parse_streaming_chunk,
|
extract_cascade_hint, parse_non_streaming_response, parse_streaming_chunk, StreamingAccumulator,
|
||||||
StreamingAccumulator,
|
|
||||||
};
|
};
|
||||||
use super::store::MitmStore;
|
use super::store::MitmStore;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
@@ -54,7 +53,6 @@ pub struct MitmConfig {
|
|||||||
pub modify_requests: bool,
|
pub modify_requests: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
/// Run the MITM proxy server.
|
/// Run the MITM proxy server.
|
||||||
///
|
///
|
||||||
/// Returns (port, task_handle) — port it's listening on, handle to abort on shutdown.
|
/// Returns (port, task_handle) — port it's listening on, handle to abort on shutdown.
|
||||||
@@ -84,7 +82,8 @@ pub async fn run(
|
|||||||
let ca = ca.clone();
|
let ca = ca.clone();
|
||||||
let store = store.clone();
|
let store = store.clone();
|
||||||
tokio::spawn(async move {
|
tokio::spawn(async move {
|
||||||
if let Err(e) = handle_connection(stream, ca, store, modify_requests).await {
|
if let Err(e) = handle_connection(stream, ca, store, modify_requests).await
|
||||||
|
{
|
||||||
warn!(error = %e, "MITM connection error");
|
warn!(error = %e, "MITM connection error");
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
@@ -131,8 +130,7 @@ async fn handle_connection(
|
|||||||
.await
|
.await
|
||||||
.map_err(|e| format!("Peek ClientHello: {e}"))?;
|
.map_err(|e| format!("Peek ClientHello: {e}"))?;
|
||||||
|
|
||||||
let domain = extract_sni(&hello_buf[..n])
|
let domain = extract_sni(&hello_buf[..n]).unwrap_or_else(|| "unknown".to_string());
|
||||||
.unwrap_or_else(|| "unknown".to_string());
|
|
||||||
|
|
||||||
info!(domain, "MITM: transparent redirect (iptables)");
|
info!(domain, "MITM: transparent redirect (iptables)");
|
||||||
|
|
||||||
@@ -224,22 +222,30 @@ fn extract_sni(buf: &[u8]) -> Option<String> {
|
|||||||
let mut pos = 34; // skip version + random
|
let mut pos = 34; // skip version + random
|
||||||
|
|
||||||
// Session ID
|
// Session ID
|
||||||
if pos >= body.len() { return None; }
|
if pos >= body.len() {
|
||||||
|
return None;
|
||||||
|
}
|
||||||
let sid_len = body[pos] as usize;
|
let sid_len = body[pos] as usize;
|
||||||
pos += 1 + sid_len;
|
pos += 1 + sid_len;
|
||||||
|
|
||||||
// Cipher suites
|
// Cipher suites
|
||||||
if pos + 2 > body.len() { return None; }
|
if pos + 2 > body.len() {
|
||||||
|
return None;
|
||||||
|
}
|
||||||
let cs_len = u16::from_be_bytes([body[pos], body[pos + 1]]) as usize;
|
let cs_len = u16::from_be_bytes([body[pos], body[pos + 1]]) as usize;
|
||||||
pos += 2 + cs_len;
|
pos += 2 + cs_len;
|
||||||
|
|
||||||
// Compression methods
|
// Compression methods
|
||||||
if pos >= body.len() { return None; }
|
if pos >= body.len() {
|
||||||
|
return None;
|
||||||
|
}
|
||||||
let cm_len = body[pos] as usize;
|
let cm_len = body[pos] as usize;
|
||||||
pos += 1 + cm_len;
|
pos += 1 + cm_len;
|
||||||
|
|
||||||
// Extensions
|
// Extensions
|
||||||
if pos + 2 > body.len() { return None; }
|
if pos + 2 > body.len() {
|
||||||
|
return None;
|
||||||
|
}
|
||||||
let ext_len = u16::from_be_bytes([body[pos], body[pos + 1]]) as usize;
|
let ext_len = u16::from_be_bytes([body[pos], body[pos + 1]]) as usize;
|
||||||
pos += 2;
|
pos += 2;
|
||||||
let ext_end = pos + ext_len.min(body.len() - pos);
|
let ext_end = pos + ext_len.min(body.len() - pos);
|
||||||
@@ -304,32 +310,32 @@ async fn handle_intercepted(
|
|||||||
info!(domain, "MITM: intercepting TLS");
|
info!(domain, "MITM: intercepting TLS");
|
||||||
|
|
||||||
// Get or create server TLS config for this domain
|
// Get or create server TLS config for this domain
|
||||||
let server_config = ca
|
let server_config = ca.server_config_for_domain(domain).await?;
|
||||||
.server_config_for_domain(domain)
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
let acceptor = TlsAcceptor::from(server_config);
|
let acceptor = TlsAcceptor::from(server_config);
|
||||||
|
|
||||||
// Perform TLS handshake with the client (LS) — 10s timeout
|
// Perform TLS handshake with the client (LS) — 10s timeout
|
||||||
let tls_stream = match tokio::time::timeout(
|
let tls_stream =
|
||||||
std::time::Duration::from_secs(10),
|
match tokio::time::timeout(std::time::Duration::from_secs(10), acceptor.accept(stream))
|
||||||
acceptor.accept(stream),
|
.await
|
||||||
)
|
{
|
||||||
.await
|
Ok(Ok(s)) => s,
|
||||||
{
|
Ok(Err(e)) => {
|
||||||
Ok(Ok(s)) => s,
|
warn!(domain, error = %e, "MITM: TLS handshake FAILED (client rejected cert?)");
|
||||||
Ok(Err(e)) => {
|
return Err(format!(
|
||||||
warn!(domain, error = %e, "MITM: TLS handshake FAILED (client rejected cert?)");
|
"TLS handshake with client failed for {domain}: {e}"
|
||||||
return Err(format!("TLS handshake with client failed for {domain}: {e}"));
|
));
|
||||||
}
|
}
|
||||||
Err(_) => {
|
Err(_) => {
|
||||||
warn!(domain, "MITM: TLS handshake TIMED OUT after 10s");
|
warn!(domain, "MITM: TLS handshake TIMED OUT after 10s");
|
||||||
return Err(format!("TLS handshake timed out for {domain}"));
|
return Err(format!("TLS handshake timed out for {domain}"));
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
// Check negotiated ALPN protocol
|
// Check negotiated ALPN protocol
|
||||||
let alpn = tls_stream.get_ref().1
|
let alpn = tls_stream
|
||||||
|
.get_ref()
|
||||||
|
.1
|
||||||
.alpn_protocol()
|
.alpn_protocol()
|
||||||
.map(|p| String::from_utf8_lossy(p).to_string());
|
.map(|p| String::from_utf8_lossy(p).to_string());
|
||||||
|
|
||||||
@@ -339,12 +345,7 @@ async fn handle_intercepted(
|
|||||||
Some("h2") => {
|
Some("h2") => {
|
||||||
// HTTP/2 — use the hyper-based gRPC handler
|
// HTTP/2 — use the hyper-based gRPC handler
|
||||||
info!(domain, "MITM: routing to HTTP/2 handler (gRPC)");
|
info!(domain, "MITM: routing to HTTP/2 handler (gRPC)");
|
||||||
super::h2_handler::handle_h2_connection(
|
super::h2_handler::handle_h2_connection(tls_stream, domain.to_string(), store).await
|
||||||
tls_stream,
|
|
||||||
domain.to_string(),
|
|
||||||
store,
|
|
||||||
)
|
|
||||||
.await
|
|
||||||
}
|
}
|
||||||
_ => {
|
_ => {
|
||||||
// HTTP/1.1 or no ALPN — use the existing handler
|
// HTTP/1.1 or no ALPN — use the existing handler
|
||||||
@@ -434,7 +435,10 @@ async fn handle_http_over_tls(
|
|||||||
.await
|
.await
|
||||||
{
|
{
|
||||||
let out = String::from_utf8_lossy(&output.stdout);
|
let out = String::from_utf8_lossy(&output.stdout);
|
||||||
if let Some(ip) = out.lines().find(|l| l.parse::<std::net::Ipv4Addr>().is_ok()) {
|
if let Some(ip) = out
|
||||||
|
.lines()
|
||||||
|
.find(|l| l.parse::<std::net::Ipv4Addr>().is_ok())
|
||||||
|
{
|
||||||
return format!("{ip}:443");
|
return format!("{ip}:443");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -458,7 +462,6 @@ async fn handle_http_over_tls(
|
|||||||
loop {
|
loop {
|
||||||
// ── Read the HTTP request from the client ─────────────────────────
|
// ── Read the HTTP request from the client ─────────────────────────
|
||||||
let mut request_buf = Vec::with_capacity(1024 * 64);
|
let mut request_buf = Vec::with_capacity(1024 * 64);
|
||||||
let mut is_our_request = false;
|
|
||||||
|
|
||||||
// 60s timeout on initial read (LS may open connection without sending immediately)
|
// 60s timeout on initial read (LS may open connection without sending immediately)
|
||||||
const IDLE_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(60);
|
const IDLE_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(60);
|
||||||
@@ -513,7 +516,8 @@ async fn handle_http_over_tls(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Parse the HTTP request to find headers and body
|
// Parse the HTTP request to find headers and body
|
||||||
let (headers_end, content_length, _is_streaming_request) = parse_http_request_meta(&request_buf);
|
let (headers_end, content_length, _is_streaming_request) =
|
||||||
|
parse_http_request_meta(&request_buf);
|
||||||
|
|
||||||
// Try to extract cascade hint from request body
|
// Try to extract cascade hint from request body
|
||||||
let cascade_hint = if headers_end < request_buf.len() {
|
let cascade_hint = if headers_end < request_buf.len() {
|
||||||
@@ -545,6 +549,27 @@ async fn handle_http_over_tls(
|
|||||||
"MITM: forwarding LLM request"
|
"MITM: forwarding LLM request"
|
||||||
);
|
);
|
||||||
|
|
||||||
|
// ── Block ALL requests when one is already in-flight ─────────
|
||||||
|
// The LS opens multiple connections and sends parallel requests.
|
||||||
|
// When custom tools are active, only the FIRST request should reach
|
||||||
|
// Google. Block everything else with a fake response.
|
||||||
|
if store.is_request_in_flight() {
|
||||||
|
info!("MITM: blocking LS request — another request already in-flight");
|
||||||
|
let fake_response = "HTTP/1.1 200 OK\r\n\
|
||||||
|
Content-Type: text/event-stream\r\n\
|
||||||
|
Transfer-Encoding: chunked\r\n\
|
||||||
|
\r\n";
|
||||||
|
let fake_sse = "data: {\"response\":{\"candidates\":[{\"content\":{\"parts\":[{\"text\":\"Request handled.\"}],\"role\":\"model\"},\"finishReason\":\"STOP\"}],\"usageMetadata\":{\"promptTokenCount\":0,\"candidatesTokenCount\":1,\"totalTokenCount\":1}}}\n\ndata: [DONE]\n\n";
|
||||||
|
let chunked_body = super::modify::rechunk(fake_sse.as_bytes());
|
||||||
|
let mut response = fake_response.as_bytes().to_vec();
|
||||||
|
response.extend_from_slice(&chunked_body);
|
||||||
|
if let Err(e) = client.write_all(&response).await {
|
||||||
|
warn!(error = %e, "MITM: failed to write fake response");
|
||||||
|
}
|
||||||
|
let _ = client.flush().await;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
// ── Request modification ─────────────────────────────────────
|
// ── Request modification ─────────────────────────────────────
|
||||||
// Dechunk body → check if agent request → modify → rechunk
|
// Dechunk body → check if agent request → modify → rechunk
|
||||||
if modify_requests && body_len > 0 {
|
if modify_requests && body_len > 0 {
|
||||||
@@ -565,7 +590,11 @@ async fn handle_http_over_tls(
|
|||||||
let generation_params = store.get_generation_params().await;
|
let generation_params = store.get_generation_params().await;
|
||||||
let pending_image = store.take_pending_image().await;
|
let pending_image = store.take_pending_image().await;
|
||||||
|
|
||||||
let tool_ctx = if tools.is_some() || !pending_results.is_empty() || generation_params.is_some() || pending_image.is_some() {
|
let tool_ctx = if tools.is_some()
|
||||||
|
|| !pending_results.is_empty()
|
||||||
|
|| generation_params.is_some()
|
||||||
|
|| pending_image.is_some()
|
||||||
|
{
|
||||||
Some(super::modify::ToolContext {
|
Some(super::modify::ToolContext {
|
||||||
tools,
|
tools,
|
||||||
tool_config,
|
tool_config,
|
||||||
@@ -578,7 +607,9 @@ async fn handle_http_over_tls(
|
|||||||
None
|
None
|
||||||
};
|
};
|
||||||
|
|
||||||
if let Some(modified_body) = super::modify::modify_request(&raw_body, tool_ctx.as_ref()) {
|
if let Some(modified_body) =
|
||||||
|
super::modify::modify_request(&raw_body, tool_ctx.as_ref())
|
||||||
|
{
|
||||||
// Rebuild request_buf: headers (with updated Content-Length) + rechunked modified body
|
// Rebuild request_buf: headers (with updated Content-Length) + rechunked modified body
|
||||||
let new_chunked = super::modify::rechunk(&modified_body);
|
let new_chunked = super::modify::rechunk(&modified_body);
|
||||||
|
|
||||||
@@ -589,38 +620,11 @@ async fn handle_http_over_tls(
|
|||||||
new_buf.extend_from_slice(&new_chunked);
|
new_buf.extend_from_slice(&new_chunked);
|
||||||
request_buf = new_buf;
|
request_buf = new_buf;
|
||||||
|
|
||||||
// Mark this as our modified request and set in-flight flag
|
// Mark in-flight IMMEDIATELY — blocks all subsequent requests
|
||||||
is_our_request = true;
|
|
||||||
store.mark_request_in_flight();
|
store.mark_request_in_flight();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// ── Block ALL LS follow-up requests once first is in-flight ──
|
|
||||||
// When custom tools are active, we only need ONE request to Google.
|
|
||||||
// The LS tries to send multiple requests (its own agentic loop +
|
|
||||||
// internal requests on gemini-2.5-flash-lite). Block them ALL
|
|
||||||
// immediately — don't wait for response_complete.
|
|
||||||
let has_tools = store.get_tools().await.is_some();
|
|
||||||
if has_tools && store.is_request_in_flight() && !is_our_request {
|
|
||||||
info!(
|
|
||||||
"MITM: blocking LS follow-up — request already in-flight"
|
|
||||||
);
|
|
||||||
// Return a fake SSE response that makes the LS stop
|
|
||||||
let fake_response = "HTTP/1.1 200 OK\r\n\
|
|
||||||
Content-Type: text/event-stream\r\n\
|
|
||||||
Transfer-Encoding: chunked\r\n\
|
|
||||||
\r\n";
|
|
||||||
let fake_sse = "data: {\"response\":{\"candidates\":[{\"content\":{\"parts\":[{\"text\":\"Request handled.\"}],\"role\":\"model\"},\"finishReason\":\"STOP\"}],\"usageMetadata\":{\"promptTokenCount\":0,\"candidatesTokenCount\":1,\"totalTokenCount\":1}}}\n\ndata: [DONE]\n\n";
|
|
||||||
let chunked_body = super::modify::rechunk(fake_sse.as_bytes());
|
|
||||||
let mut response = fake_response.as_bytes().to_vec();
|
|
||||||
response.extend_from_slice(&chunked_body);
|
|
||||||
if let Err(e) = client.write_all(&response).await {
|
|
||||||
warn!(error = %e, "MITM: failed to write fake response");
|
|
||||||
}
|
|
||||||
let _ = client.flush().await;
|
|
||||||
continue; // Skip the real upstream call
|
|
||||||
}
|
|
||||||
} else {
|
} else {
|
||||||
debug!(
|
debug!(
|
||||||
domain,
|
domain,
|
||||||
@@ -674,7 +678,10 @@ async fn handle_http_over_tls(
|
|||||||
};
|
};
|
||||||
|
|
||||||
let n = match tokio::time::timeout(timeout, conn.read(&mut tmp)).await {
|
let n = match tokio::time::timeout(timeout, conn.read(&mut tmp)).await {
|
||||||
Ok(Ok(0)) => { upstream_ok = false; break; }
|
Ok(Ok(0)) => {
|
||||||
|
upstream_ok = false;
|
||||||
|
break;
|
||||||
|
}
|
||||||
Ok(Ok(n)) => n,
|
Ok(Ok(n)) => n,
|
||||||
Ok(Err(e)) => {
|
Ok(Err(e)) => {
|
||||||
debug!(domain, error = %e, "MITM: upstream read ended");
|
debug!(domain, error = %e, "MITM: upstream read ended");
|
||||||
@@ -711,7 +718,9 @@ async fn handle_http_over_tls(
|
|||||||
if header.name.eq_ignore_ascii_case("content-type") {
|
if header.name.eq_ignore_ascii_case("content-type") {
|
||||||
if let Ok(v) = std::str::from_utf8(header.value) {
|
if let Ok(v) = std::str::from_utf8(header.value) {
|
||||||
content_type = v.to_string();
|
content_type = v.to_string();
|
||||||
if v.contains("text/event-stream") { is_streaming_response = true; }
|
if v.contains("text/event-stream") {
|
||||||
|
is_streaming_response = true;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if header.name.eq_ignore_ascii_case("content-length") {
|
if header.name.eq_ignore_ascii_case("content-length") {
|
||||||
@@ -721,12 +730,16 @@ async fn handle_http_over_tls(
|
|||||||
}
|
}
|
||||||
if header.name.eq_ignore_ascii_case("connection") {
|
if header.name.eq_ignore_ascii_case("connection") {
|
||||||
if let Ok(v) = std::str::from_utf8(header.value) {
|
if let Ok(v) = std::str::from_utf8(header.value) {
|
||||||
if v.trim().eq_ignore_ascii_case("close") { upstream_ok = false; }
|
if v.trim().eq_ignore_ascii_case("close") {
|
||||||
|
upstream_ok = false;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if header.name.eq_ignore_ascii_case("transfer-encoding") {
|
if header.name.eq_ignore_ascii_case("transfer-encoding") {
|
||||||
if let Ok(v) = std::str::from_utf8(header.value) {
|
if let Ok(v) = std::str::from_utf8(header.value) {
|
||||||
if v.trim().eq_ignore_ascii_case("chunked") { is_chunked = true; }
|
if v.trim().eq_ignore_ascii_case("chunked") {
|
||||||
|
is_chunked = true;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -749,22 +762,31 @@ async fn handle_http_over_tls(
|
|||||||
warn!(domain, status = http_status, body = %body_str, "MITM: upstream error response");
|
warn!(domain, status = http_status, body = %body_str, "MITM: upstream error response");
|
||||||
|
|
||||||
// Parse Google's error JSON: {"error": {"code": N, "message": "...", "status": "..."}}
|
// Parse Google's error JSON: {"error": {"code": N, "message": "...", "status": "..."}}
|
||||||
let (message, error_status) = serde_json::from_str::<serde_json::Value>(&body_str)
|
let (message, error_status) =
|
||||||
.ok()
|
serde_json::from_str::<serde_json::Value>(&body_str)
|
||||||
.and_then(|v| {
|
.ok()
|
||||||
let err = v.get("error")?;
|
.and_then(|v| {
|
||||||
let msg = err.get("message").and_then(|m| m.as_str()).map(|s| s.to_string());
|
let err = v.get("error")?;
|
||||||
let status = err.get("status").and_then(|s| s.as_str()).map(|s| s.to_string());
|
let msg = err
|
||||||
Some((msg, status))
|
.get("message")
|
||||||
})
|
.and_then(|m| m.as_str())
|
||||||
.unwrap_or((None, None));
|
.map(|s| s.to_string());
|
||||||
|
let status = err
|
||||||
|
.get("status")
|
||||||
|
.and_then(|s| s.as_str())
|
||||||
|
.map(|s| s.to_string());
|
||||||
|
Some((msg, status))
|
||||||
|
})
|
||||||
|
.unwrap_or((None, None));
|
||||||
|
|
||||||
store.set_upstream_error(super::store::UpstreamError {
|
store
|
||||||
status: http_status,
|
.set_upstream_error(super::store::UpstreamError {
|
||||||
body: body_str,
|
status: http_status,
|
||||||
message,
|
body: body_str,
|
||||||
error_status,
|
message,
|
||||||
}).await;
|
error_status,
|
||||||
|
})
|
||||||
|
.await;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Save body for usage parsing
|
// Save body for usage parsing
|
||||||
@@ -779,10 +801,15 @@ async fn handle_http_over_tls(
|
|||||||
if !streaming_acc.function_calls.is_empty() {
|
if !streaming_acc.function_calls.is_empty() {
|
||||||
let calls: Vec<_> = streaming_acc.function_calls.drain(..).collect();
|
let calls: Vec<_> = streaming_acc.function_calls.drain(..).collect();
|
||||||
for fc in &calls {
|
for fc in &calls {
|
||||||
store.record_function_call(cascade_hint.as_deref(), fc.clone()).await;
|
store
|
||||||
|
.record_function_call(cascade_hint.as_deref(), fc.clone())
|
||||||
|
.await;
|
||||||
}
|
}
|
||||||
store.set_last_function_calls(calls.clone()).await;
|
store.set_last_function_calls(calls.clone()).await;
|
||||||
info!("MITM: stored {} function call(s) from initial body", calls.len());
|
info!(
|
||||||
|
"MITM: stored {} function call(s) from initial body",
|
||||||
|
calls.len()
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Capture response + thinking text + grounding into MitmStore
|
// Capture response + thinking text + grounding into MitmStore
|
||||||
@@ -816,7 +843,9 @@ async fn handle_http_over_tls(
|
|||||||
}
|
}
|
||||||
|
|
||||||
if let Some(cl) = response_content_length {
|
if let Some(cl) = response_content_length {
|
||||||
if response_body_buf.len() >= cl { break; }
|
if response_body_buf.len() >= cl {
|
||||||
|
break;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
// Check chunked terminator in initial body
|
// Check chunked terminator in initial body
|
||||||
if is_chunked && has_chunked_terminator(&response_body_buf) {
|
if is_chunked && has_chunked_terminator(&response_body_buf) {
|
||||||
@@ -837,10 +866,15 @@ async fn handle_http_over_tls(
|
|||||||
if !streaming_acc.function_calls.is_empty() {
|
if !streaming_acc.function_calls.is_empty() {
|
||||||
let calls: Vec<_> = streaming_acc.function_calls.drain(..).collect();
|
let calls: Vec<_> = streaming_acc.function_calls.drain(..).collect();
|
||||||
for fc in &calls {
|
for fc in &calls {
|
||||||
store.record_function_call(cascade_hint.as_deref(), fc.clone()).await;
|
store
|
||||||
|
.record_function_call(cascade_hint.as_deref(), fc.clone())
|
||||||
|
.await;
|
||||||
}
|
}
|
||||||
store.set_last_function_calls(calls.clone()).await;
|
store.set_last_function_calls(calls.clone()).await;
|
||||||
info!("MITM: stored {} function call(s) from body chunk", calls.len());
|
info!(
|
||||||
|
"MITM: stored {} function call(s) from body chunk",
|
||||||
|
calls.len()
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Capture response + thinking text + grounding into MitmStore
|
// Capture response + thinking text + grounding into MitmStore
|
||||||
@@ -875,7 +909,9 @@ async fn handle_http_over_tls(
|
|||||||
response_body_buf.extend_from_slice(chunk);
|
response_body_buf.extend_from_slice(chunk);
|
||||||
|
|
||||||
if let Some(cl) = response_content_length {
|
if let Some(cl) = response_content_length {
|
||||||
if response_body_buf.len() >= cl { break; }
|
if response_body_buf.len() >= cl {
|
||||||
|
break;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
if is_chunked && has_chunked_terminator(&response_body_buf) {
|
if is_chunked && has_chunked_terminator(&response_body_buf) {
|
||||||
debug!(domain, "MITM: chunked response complete");
|
debug!(domain, "MITM: chunked response complete");
|
||||||
@@ -912,11 +948,7 @@ async fn handle_http_over_tls(
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Handle a passthrough connection: transparent TCP tunnel to upstream.
|
/// Handle a passthrough connection: transparent TCP tunnel to upstream.
|
||||||
async fn handle_passthrough(
|
async fn handle_passthrough(mut client: TcpStream, domain: &str, port: u16) -> Result<(), String> {
|
||||||
mut client: TcpStream,
|
|
||||||
domain: &str,
|
|
||||||
port: u16,
|
|
||||||
) -> Result<(), String> {
|
|
||||||
trace!(domain, port, "MITM: transparent tunnel");
|
trace!(domain, port, "MITM: transparent tunnel");
|
||||||
|
|
||||||
let mut upstream = TcpStream::connect(format!("{domain}:{port}"))
|
let mut upstream = TcpStream::connect(format!("{domain}:{port}"))
|
||||||
@@ -926,7 +958,12 @@ async fn handle_passthrough(
|
|||||||
// Bidirectional copy
|
// Bidirectional copy
|
||||||
match tokio::io::copy_bidirectional(&mut client, &mut upstream).await {
|
match tokio::io::copy_bidirectional(&mut client, &mut upstream).await {
|
||||||
Ok((client_to_server, server_to_client)) => {
|
Ok((client_to_server, server_to_client)) => {
|
||||||
trace!(domain, client_to_server, server_to_client, "MITM: tunnel closed");
|
trace!(
|
||||||
|
domain,
|
||||||
|
client_to_server,
|
||||||
|
server_to_client,
|
||||||
|
"MITM: tunnel closed"
|
||||||
|
);
|
||||||
}
|
}
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
trace!(domain, error = %e, "MITM: tunnel error (likely clean close)");
|
trace!(domain, error = %e, "MITM: tunnel error (likely clean close)");
|
||||||
@@ -945,7 +982,11 @@ fn has_chunked_terminator(body: &[u8]) -> bool {
|
|||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
// Check last 7 bytes to account for possible trailing whitespace
|
// Check last 7 bytes to account for possible trailing whitespace
|
||||||
let tail = if body.len() > 7 { &body[body.len() - 7..] } else { body };
|
let tail = if body.len() > 7 {
|
||||||
|
&body[body.len() - 7..]
|
||||||
|
} else {
|
||||||
|
body
|
||||||
|
};
|
||||||
// Look for \r\n0\r\n\r\n anywhere in the tail
|
// Look for \r\n0\r\n\r\n anywhere in the tail
|
||||||
tail.windows(5).any(|w| w == b"0\r\n\r\n")
|
tail.windows(5).any(|w| w == b"0\r\n\r\n")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,11 +2,11 @@
|
|||||||
//!
|
//!
|
||||||
//! The MITM proxy writes usage data here; the API handlers read from it.
|
//! The MITM proxy writes usage data here; the API handlers read from it.
|
||||||
|
|
||||||
use std::collections::HashMap;
|
|
||||||
use std::sync::Arc;
|
|
||||||
use std::sync::atomic::{AtomicBool, Ordering};
|
|
||||||
use tokio::sync::RwLock;
|
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
use std::collections::HashMap;
|
||||||
|
use std::sync::atomic::{AtomicBool, Ordering};
|
||||||
|
use std::sync::Arc;
|
||||||
|
use tokio::sync::RwLock;
|
||||||
use tracing::{debug, info};
|
use tracing::{debug, info};
|
||||||
|
|
||||||
/// Token usage from an intercepted API response.
|
/// Token usage from an intercepted API response.
|
||||||
@@ -342,7 +342,9 @@ impl MitmStore {
|
|||||||
|
|
||||||
/// Record a captured function call from Google's response.
|
/// Record a captured function call from Google's response.
|
||||||
pub async fn record_function_call(&self, cascade_id: Option<&str>, fc: CapturedFunctionCall) {
|
pub async fn record_function_call(&self, cascade_id: Option<&str>, fc: CapturedFunctionCall) {
|
||||||
let key = cascade_id.map(|s| s.to_string()).unwrap_or_else(|| "_latest".to_string());
|
let key = cascade_id
|
||||||
|
.map(|s| s.to_string())
|
||||||
|
.unwrap_or_else(|| "_latest".to_string());
|
||||||
info!(
|
info!(
|
||||||
cascade = %key,
|
cascade = %key,
|
||||||
tool = %fc.name,
|
tool = %fc.name,
|
||||||
@@ -377,7 +379,6 @@ impl MitmStore {
|
|||||||
self.awaiting_tool_result.store(false, Ordering::SeqCst);
|
self.awaiting_tool_result.store(false, Ordering::SeqCst);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
/// Take any pending function calls (ignoring cascade ID).
|
/// Take any pending function calls (ignoring cascade ID).
|
||||||
pub async fn take_any_function_calls(&self) -> Option<Vec<CapturedFunctionCall>> {
|
pub async fn take_any_function_calls(&self) -> Option<Vec<CapturedFunctionCall>> {
|
||||||
let mut pending = self.pending_function_calls.write().await;
|
let mut pending = self.pending_function_calls.write().await;
|
||||||
@@ -457,8 +458,6 @@ impl MitmStore {
|
|||||||
|
|
||||||
// ── Direct response capture (bypass LS) ──────────────────────────────
|
// ── Direct response capture (bypass LS) ──────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
/// Set (replace) the captured response text.
|
/// Set (replace) the captured response text.
|
||||||
pub async fn set_response_text(&self, text: &str) {
|
pub async fn set_response_text(&self, text: &str) {
|
||||||
*self.captured_response_text.write().await = Some(text.to_string());
|
*self.captured_response_text.write().await = Some(text.to_string());
|
||||||
@@ -484,8 +483,6 @@ impl MitmStore {
|
|||||||
self.response_complete.load(Ordering::SeqCst)
|
self.response_complete.load(Ordering::SeqCst)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
/// Async version of clear_response.
|
/// Async version of clear_response.
|
||||||
pub async fn clear_response_async(&self) {
|
pub async fn clear_response_async(&self) {
|
||||||
self.response_complete.store(false, Ordering::SeqCst);
|
self.response_complete.store(false, Ordering::SeqCst);
|
||||||
|
|||||||
@@ -293,8 +293,7 @@ mod tests {
|
|||||||
|
|
||||||
let cascade_bytes = b"test-cascade-id";
|
let cascade_bytes = b"test-cascade-id";
|
||||||
assert!(
|
assert!(
|
||||||
msg.windows(cascade_bytes.len())
|
msg.windows(cascade_bytes.len()).any(|w| w == cascade_bytes),
|
||||||
.any(|w| w == cascade_bytes),
|
|
||||||
"cascade_id must appear in output"
|
"cascade_id must appear in output"
|
||||||
);
|
);
|
||||||
|
|
||||||
|
|||||||
17
src/quota.rs
17
src/quota.rs
@@ -93,9 +93,8 @@ impl QuotaStore {
|
|||||||
// Initial poll immediately.
|
// Initial poll immediately.
|
||||||
self.poll_once(&backend).await;
|
self.poll_once(&backend).await;
|
||||||
|
|
||||||
let mut interval = tokio::time::interval(
|
let mut interval =
|
||||||
std::time::Duration::from_secs(POLL_INTERVAL_SECS),
|
tokio::time::interval(std::time::Duration::from_secs(POLL_INTERVAL_SECS));
|
||||||
);
|
|
||||||
interval.tick().await; // consume the first immediate tick
|
interval.tick().await; // consume the first immediate tick
|
||||||
|
|
||||||
loop {
|
loop {
|
||||||
@@ -125,7 +124,9 @@ impl QuotaStore {
|
|||||||
// Profile picture fetch fails through iptables — harmless, suppress
|
// Profile picture fetch fails through iptables — harmless, suppress
|
||||||
let data_str = data.to_string();
|
let data_str = data.to_string();
|
||||||
if data_str.contains("profile picture") {
|
if data_str.contains("profile picture") {
|
||||||
tracing::debug!("GetUserStatus: profile picture fetch failed (expected with iptables)");
|
tracing::debug!(
|
||||||
|
"GetUserStatus: profile picture fetch failed (expected with iptables)"
|
||||||
|
);
|
||||||
} else {
|
} else {
|
||||||
warn!("GetUserStatus returned {status}: {data_str}");
|
warn!("GetUserStatus returned {status}: {data_str}");
|
||||||
}
|
}
|
||||||
@@ -172,9 +173,7 @@ fn parse_user_status(data: &serde_json::Value) -> QuotaSnapshot {
|
|||||||
.as_str()
|
.as_str()
|
||||||
.unwrap_or("")
|
.unwrap_or("")
|
||||||
.to_string();
|
.to_string();
|
||||||
let frac = m["quotaInfo"]["remainingFraction"]
|
let frac = m["quotaInfo"]["remainingFraction"].as_f64().unwrap_or(0.0);
|
||||||
.as_f64()
|
|
||||||
.unwrap_or(0.0);
|
|
||||||
let reset_str = m["quotaInfo"]["resetTime"]
|
let reset_str = m["quotaInfo"]["resetTime"]
|
||||||
.as_str()
|
.as_str()
|
||||||
.unwrap_or("")
|
.unwrap_or("")
|
||||||
@@ -224,9 +223,7 @@ fn parse_user_status(data: &serde_json::Value) -> QuotaSnapshot {
|
|||||||
flow_available: flow_avail,
|
flow_available: flow_avail,
|
||||||
flow_total,
|
flow_total,
|
||||||
flow_used_pct,
|
flow_used_pct,
|
||||||
flex_purchasable: pi["monthlyFlexCreditPurchaseAmount"]
|
flex_purchasable: pi["monthlyFlexCreditPurchaseAmount"].as_i64().unwrap_or(0),
|
||||||
.as_i64()
|
|
||||||
.unwrap_or(0),
|
|
||||||
can_buy_more: pi["canBuyMoreCredits"].as_bool().unwrap_or(false),
|
can_buy_more: pi["canBuyMoreCredits"].as_bool().unwrap_or(false),
|
||||||
},
|
},
|
||||||
models,
|
models,
|
||||||
|
|||||||
@@ -66,9 +66,7 @@ impl SessionManager {
|
|||||||
msg_count: 0,
|
msg_count: 0,
|
||||||
},
|
},
|
||||||
);
|
);
|
||||||
return Ok(SessionResult {
|
return Ok(SessionResult { cascade_id });
|
||||||
cascade_id,
|
|
||||||
});
|
|
||||||
}
|
}
|
||||||
|
|
||||||
let sid = session_id.unwrap_or(DEFAULT_SESSION).to_string();
|
let sid = session_id.unwrap_or(DEFAULT_SESSION).to_string();
|
||||||
@@ -111,9 +109,7 @@ impl SessionManager {
|
|||||||
},
|
},
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
Ok(SessionResult {
|
Ok(SessionResult { cascade_id })
|
||||||
cascade_id,
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// List all active sessions.
|
/// List all active sessions.
|
||||||
@@ -146,7 +142,5 @@ impl SessionManager {
|
|||||||
|
|
||||||
fn cleanup_expired(sessions: &mut HashMap<String, Session>) {
|
fn cleanup_expired(sessions: &mut HashMap<String, Session>) {
|
||||||
let now = Instant::now();
|
let now = Instant::now();
|
||||||
sessions.retain(|_, s| {
|
sessions.retain(|_, s| now.duration_since(s.last_used).as_secs() < SESSION_TTL_SECS);
|
||||||
now.duration_since(s.last_used).as_secs() < SESSION_TTL_SECS
|
|
||||||
});
|
|
||||||
}
|
}
|
||||||
|
|||||||
210
src/snapshot.rs
210
src/snapshot.rs
@@ -10,16 +10,44 @@ use std::io::{self, Read};
|
|||||||
// ── Domain metadata ──────────────────────────────────────────────────────────
|
// ── Domain metadata ──────────────────────────────────────────────────────────
|
||||||
|
|
||||||
const DOMAIN_INFO: &[(&str, &str, &str)] = &[
|
const DOMAIN_INFO: &[(&str, &str, &str)] = &[
|
||||||
("antigravity-unleash.goog", "Feature Flags", "Unleash SDK — controls A/B tests and feature rollouts"),
|
(
|
||||||
("daily-cloudcode-pa.googleapis.com", "LLM API (gRPC)", "Primary Gemini/Claude API endpoint"),
|
"antigravity-unleash.goog",
|
||||||
("cloudcode-pa.googleapis.com", "LLM API (gRPC)", "Production Gemini/Claude API endpoint"),
|
"Feature Flags",
|
||||||
("api.anthropic.com", "Claude API", "Direct Anthropic API calls"),
|
"Unleash SDK — controls A/B tests and feature rollouts",
|
||||||
("lh3.googleusercontent.com", "Profile Picture", "User avatar"),
|
),
|
||||||
|
(
|
||||||
|
"daily-cloudcode-pa.googleapis.com",
|
||||||
|
"LLM API (gRPC)",
|
||||||
|
"Primary Gemini/Claude API endpoint",
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"cloudcode-pa.googleapis.com",
|
||||||
|
"LLM API (gRPC)",
|
||||||
|
"Production Gemini/Claude API endpoint",
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"api.anthropic.com",
|
||||||
|
"Claude API",
|
||||||
|
"Direct Anthropic API calls",
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"lh3.googleusercontent.com",
|
||||||
|
"Profile Picture",
|
||||||
|
"User avatar",
|
||||||
|
),
|
||||||
("play.googleapis.com", "Telemetry", "Google Play telemetry"),
|
("play.googleapis.com", "Telemetry", "Google Play telemetry"),
|
||||||
("firebaseinstallations.googleapis.com", "Firebase", "Installation tracking"),
|
(
|
||||||
|
"firebaseinstallations.googleapis.com",
|
||||||
|
"Firebase",
|
||||||
|
"Installation tracking",
|
||||||
|
),
|
||||||
("oauth2.googleapis.com", "OAuth", "Token refresh/exchange"),
|
("oauth2.googleapis.com", "OAuth", "Token refresh/exchange"),
|
||||||
("speech.googleapis.com", "Speech", "Voice input processing"),
|
("speech.googleapis.com", "Speech", "Voice input processing"),
|
||||||
("modelarmor.googleapis.com", "Safety", "Content safety/filtering"),
|
(
|
||||||
|
"modelarmor.googleapis.com",
|
||||||
|
"Safety",
|
||||||
|
"Content safety/filtering",
|
||||||
|
),
|
||||||
];
|
];
|
||||||
|
|
||||||
fn domain_label(domain: &str) -> (&str, &str) {
|
fn domain_label(domain: &str) -> (&str, &str) {
|
||||||
@@ -57,8 +85,8 @@ struct HttpExchange {
|
|||||||
|
|
||||||
#[derive(Debug, Clone, Copy, PartialEq)]
|
#[derive(Debug, Clone, Copy, PartialEq)]
|
||||||
enum Direction {
|
enum Direction {
|
||||||
Outgoing, // LS → upstream
|
Outgoing, // LS → upstream
|
||||||
Incoming, // external → LS (our curl calls)
|
Incoming, // external → LS (our curl calls)
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Default)]
|
#[derive(Default)]
|
||||||
@@ -101,10 +129,12 @@ impl Snapshot {
|
|||||||
|
|
||||||
// LS process logs
|
// LS process logs
|
||||||
if (line.starts_with('I') || line.starts_with('W') || line.starts_with('E'))
|
if (line.starts_with('I') || line.starts_with('W') || line.starts_with('E'))
|
||||||
&& line.len() > 4 && line.chars().nth(1).is_some_and(|c| c.is_ascii_digit()) {
|
&& line.len() > 4
|
||||||
snap.ls_logs.push(line.to_string());
|
&& line.chars().nth(1).is_some_and(|c| c.is_ascii_digit())
|
||||||
continue;
|
{
|
||||||
}
|
snap.ls_logs.push(line.to_string());
|
||||||
|
continue;
|
||||||
|
}
|
||||||
if line.contains("maxprocs:") {
|
if line.contains("maxprocs:") {
|
||||||
snap.ls_logs.push(line.to_string());
|
snap.ls_logs.push(line.to_string());
|
||||||
continue;
|
continue;
|
||||||
@@ -128,8 +158,15 @@ impl Snapshot {
|
|||||||
if let Some((key, val)) = extract_header(line, "Transport encoding header") {
|
if let Some((key, val)) = extract_header(line, "Transport encoding header") {
|
||||||
if key == ":method" {
|
if key == ":method" {
|
||||||
// Finalize previous exchange
|
// Finalize previous exchange
|
||||||
if current_pseudo.contains_key(":path") || current_pseudo.contains_key(":method") {
|
if current_pseudo.contains_key(":path")
|
||||||
snap.finalize_exchange(¤t_pseudo, ¤t_headers, current_direction, current_stream.clone());
|
|| current_pseudo.contains_key(":method")
|
||||||
|
{
|
||||||
|
snap.finalize_exchange(
|
||||||
|
¤t_pseudo,
|
||||||
|
¤t_headers,
|
||||||
|
current_direction,
|
||||||
|
current_stream.clone(),
|
||||||
|
);
|
||||||
}
|
}
|
||||||
current_headers.clear();
|
current_headers.clear();
|
||||||
current_pseudo.clear();
|
current_pseudo.clear();
|
||||||
@@ -147,8 +184,15 @@ impl Snapshot {
|
|||||||
// Incoming / server-received headers
|
// Incoming / server-received headers
|
||||||
if let Some((key, val)) = extract_header(line, "decoded hpack field header field") {
|
if let Some((key, val)) = extract_header(line, "decoded hpack field header field") {
|
||||||
if key == ":authority" && !line.contains("server read frame") {
|
if key == ":authority" && !line.contains("server read frame") {
|
||||||
if current_pseudo.contains_key(":path") || current_pseudo.contains_key(":method") {
|
if current_pseudo.contains_key(":path")
|
||||||
snap.finalize_exchange(¤t_pseudo, ¤t_headers, current_direction, current_stream.clone());
|
|| current_pseudo.contains_key(":method")
|
||||||
|
{
|
||||||
|
snap.finalize_exchange(
|
||||||
|
¤t_pseudo,
|
||||||
|
¤t_headers,
|
||||||
|
current_direction,
|
||||||
|
current_stream.clone(),
|
||||||
|
);
|
||||||
}
|
}
|
||||||
current_headers.clear();
|
current_headers.clear();
|
||||||
current_pseudo.clear();
|
current_pseudo.clear();
|
||||||
@@ -167,8 +211,15 @@ impl Snapshot {
|
|||||||
if line.contains("wrote HEADERS") {
|
if line.contains("wrote HEADERS") {
|
||||||
if let Some(stream) = extract_stream_id(line) {
|
if let Some(stream) = extract_stream_id(line) {
|
||||||
current_stream = Some(stream.clone());
|
current_stream = Some(stream.clone());
|
||||||
if current_pseudo.contains_key(":path") || current_pseudo.contains_key(":method") {
|
if current_pseudo.contains_key(":path")
|
||||||
let ex = snap.finalize_exchange(¤t_pseudo, ¤t_headers, current_direction, Some(stream));
|
|| current_pseudo.contains_key(":method")
|
||||||
|
{
|
||||||
|
let ex = snap.finalize_exchange(
|
||||||
|
¤t_pseudo,
|
||||||
|
¤t_headers,
|
||||||
|
current_direction,
|
||||||
|
Some(stream),
|
||||||
|
);
|
||||||
if ex.is_some() {
|
if ex.is_some() {
|
||||||
current_headers.clear();
|
current_headers.clear();
|
||||||
current_pseudo.clear();
|
current_pseudo.clear();
|
||||||
@@ -179,10 +230,13 @@ impl Snapshot {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// DATA frames
|
// DATA frames
|
||||||
if (line.contains("wrote DATA") || line.contains("read DATA") || line.contains("server read frame DATA"))
|
if (line.contains("wrote DATA")
|
||||||
|
|| line.contains("read DATA")
|
||||||
|
|| line.contains("server read frame DATA"))
|
||||||
&& line.contains("data=\"")
|
&& line.contains("data=\"")
|
||||||
{
|
{
|
||||||
let is_outgoing = line.contains("wrote DATA") || line.contains("server read frame DATA");
|
let is_outgoing =
|
||||||
|
line.contains("wrote DATA") || line.contains("server read frame DATA");
|
||||||
if let Some(stream) = extract_stream_id(line) {
|
if let Some(stream) = extract_stream_id(line) {
|
||||||
if let Some(data_str) = extract_data(line) {
|
if let Some(data_str) = extract_data(line) {
|
||||||
let raw = decode_go_escaped(&data_str);
|
let raw = decode_go_escaped(&data_str);
|
||||||
@@ -203,7 +257,12 @@ impl Snapshot {
|
|||||||
|
|
||||||
// Finalize remaining
|
// Finalize remaining
|
||||||
if current_pseudo.contains_key(":path") || current_pseudo.contains_key(":method") {
|
if current_pseudo.contains_key(":path") || current_pseudo.contains_key(":method") {
|
||||||
snap.finalize_exchange(¤t_pseudo, ¤t_headers, current_direction, current_stream);
|
snap.finalize_exchange(
|
||||||
|
¤t_pseudo,
|
||||||
|
¤t_headers,
|
||||||
|
current_direction,
|
||||||
|
current_stream,
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
snap
|
snap
|
||||||
@@ -226,7 +285,11 @@ impl Snapshot {
|
|||||||
|
|
||||||
self.exchanges.push(HttpExchange {
|
self.exchanges.push(HttpExchange {
|
||||||
authority,
|
authority,
|
||||||
method: if method.is_empty() { "GET".into() } else { method },
|
method: if method.is_empty() {
|
||||||
|
"GET".into()
|
||||||
|
} else {
|
||||||
|
method
|
||||||
|
},
|
||||||
path,
|
path,
|
||||||
headers: headers.to_vec(),
|
headers: headers.to_vec(),
|
||||||
body: Vec::new(),
|
body: Vec::new(),
|
||||||
@@ -245,7 +308,9 @@ impl Snapshot {
|
|||||||
let sep = "═".repeat(70);
|
let sep = "═".repeat(70);
|
||||||
let sep_thin = "─".repeat(60);
|
let sep_thin = "─".repeat(60);
|
||||||
out.push_str(&format!("\n{BOLD}{CYAN}{sep}{NC}\n"));
|
out.push_str(&format!("\n{BOLD}{CYAN}{sep}{NC}\n"));
|
||||||
out.push_str(&format!("{BOLD}{CYAN} STANDALONE LS TRAFFIC SNAPSHOT{NC}\n"));
|
out.push_str(&format!(
|
||||||
|
"{BOLD}{CYAN} STANDALONE LS TRAFFIC SNAPSHOT{NC}\n"
|
||||||
|
));
|
||||||
out.push_str(&format!("{BOLD}{CYAN}{sep}{NC}\n\n"));
|
out.push_str(&format!("{BOLD}{CYAN}{sep}{NC}\n\n"));
|
||||||
|
|
||||||
// LS Logs
|
// LS Logs
|
||||||
@@ -265,7 +330,9 @@ impl Snapshot {
|
|||||||
for target in &self.connections {
|
for target in &self.connections {
|
||||||
let domain = target.split(':').next().unwrap_or(target);
|
let domain = target.split(':').next().unwrap_or(target);
|
||||||
let (label, desc) = domain_label(domain);
|
let (label, desc) = domain_label(domain);
|
||||||
out.push_str(&format!(" {GREEN}→{NC} {BOLD}{target}{NC} {DIM}({label}){NC}\n"));
|
out.push_str(&format!(
|
||||||
|
" {GREEN}→{NC} {BOLD}{target}{NC} {DIM}({label}){NC}\n"
|
||||||
|
));
|
||||||
if !desc.is_empty() {
|
if !desc.is_empty() {
|
||||||
out.push_str(&format!(" {DIM}{desc}{NC}\n"));
|
out.push_str(&format!(" {DIM}{desc}{NC}\n"));
|
||||||
}
|
}
|
||||||
@@ -276,7 +343,10 @@ impl Snapshot {
|
|||||||
// Group by domain
|
// Group by domain
|
||||||
let mut by_domain: Vec<(&str, Vec<&HttpExchange>)> = Vec::new();
|
let mut by_domain: Vec<(&str, Vec<&HttpExchange>)> = Vec::new();
|
||||||
for ex in &self.exchanges {
|
for ex in &self.exchanges {
|
||||||
if let Some(entry) = by_domain.iter_mut().find(|(d, _)| *d == ex.authority.as_str()) {
|
if let Some(entry) = by_domain
|
||||||
|
.iter_mut()
|
||||||
|
.find(|(d, _)| *d == ex.authority.as_str())
|
||||||
|
{
|
||||||
entry.1.push(ex);
|
entry.1.push(ex);
|
||||||
} else {
|
} else {
|
||||||
by_domain.push((&ex.authority, vec![ex]));
|
by_domain.push((&ex.authority, vec![ex]));
|
||||||
@@ -293,12 +363,17 @@ impl Snapshot {
|
|||||||
let color = if label.contains("API") { YELLOW } else { CYAN };
|
let color = if label.contains("API") { YELLOW } else { CYAN };
|
||||||
|
|
||||||
out.push_str(&format!("\n{BOLD}{sep}{NC}\n"));
|
out.push_str(&format!("\n{BOLD}{sep}{NC}\n"));
|
||||||
out.push_str(&format!("{BOLD}{color} {domain}{NC} {DIM}— {label}{NC}\n"));
|
out.push_str(&format!(
|
||||||
|
"{BOLD}{color} {domain}{NC} {DIM}— {label}{NC}\n"
|
||||||
|
));
|
||||||
out.push_str(&format!("{BOLD}{sep}{NC}\n"));
|
out.push_str(&format!("{BOLD}{sep}{NC}\n"));
|
||||||
|
|
||||||
for ex in exchanges {
|
for ex in exchanges {
|
||||||
let method_color = if ex.method == "GET" { GREEN } else { YELLOW };
|
let method_color = if ex.method == "GET" { GREEN } else { YELLOW };
|
||||||
out.push_str(&format!("\n {BOLD}→ {method_color}{}{NC} {}\n", ex.method, ex.path));
|
out.push_str(&format!(
|
||||||
|
"\n {BOLD}→ {method_color}{}{NC} {}\n",
|
||||||
|
ex.method, ex.path
|
||||||
|
));
|
||||||
|
|
||||||
// Interesting headers
|
// Interesting headers
|
||||||
for (key, val) in &ex.headers {
|
for (key, val) in &ex.headers {
|
||||||
@@ -342,7 +417,10 @@ fn render_body(data: &[u8], total_len: usize) -> String {
|
|||||||
out.push_str(&format!(" {BOLD}Body ({len} bytes, JSON):{NC}\n"));
|
out.push_str(&format!(" {BOLD}Body ({len} bytes, JSON):{NC}\n"));
|
||||||
for (i, line) in pretty.lines().enumerate() {
|
for (i, line) in pretty.lines().enumerate() {
|
||||||
if i >= 40 {
|
if i >= 40 {
|
||||||
out.push_str(&format!(" {DIM}... ({} more lines){NC}\n", pretty.lines().count() - 40));
|
out.push_str(&format!(
|
||||||
|
" {DIM}... ({} more lines){NC}\n",
|
||||||
|
pretty.lines().count() - 40
|
||||||
|
));
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
out.push_str(&format!(" {GREEN}{line}{NC}\n"));
|
out.push_str(&format!(" {GREEN}{line}{NC}\n"));
|
||||||
@@ -357,10 +435,16 @@ fn render_body(data: &[u8], total_len: usize) -> String {
|
|||||||
if let Ok(text) = std::str::from_utf8(&decompressed) {
|
if let Ok(text) = std::str::from_utf8(&decompressed) {
|
||||||
if let Ok(val) = serde_json::from_str::<serde_json::Value>(text) {
|
if let Ok(val) = serde_json::from_str::<serde_json::Value>(text) {
|
||||||
let pretty = serde_json::to_string_pretty(&val).unwrap_or_default();
|
let pretty = serde_json::to_string_pretty(&val).unwrap_or_default();
|
||||||
out.push_str(&format!(" {BOLD}Body ({len} bytes gzip → {} bytes, JSON):{NC}\n", decompressed.len()));
|
out.push_str(&format!(
|
||||||
|
" {BOLD}Body ({len} bytes gzip → {} bytes, JSON):{NC}\n",
|
||||||
|
decompressed.len()
|
||||||
|
));
|
||||||
for (i, line) in pretty.lines().enumerate() {
|
for (i, line) in pretty.lines().enumerate() {
|
||||||
if i >= 50 {
|
if i >= 50 {
|
||||||
out.push_str(&format!(" {DIM}... ({} more lines){NC}\n", pretty.lines().count() - 50));
|
out.push_str(&format!(
|
||||||
|
" {DIM}... ({} more lines){NC}\n",
|
||||||
|
pretty.lines().count() - 50
|
||||||
|
));
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
out.push_str(&format!(" {GREEN}{line}{NC}\n"));
|
out.push_str(&format!(" {GREEN}{line}{NC}\n"));
|
||||||
@@ -368,14 +452,20 @@ fn render_body(data: &[u8], total_len: usize) -> String {
|
|||||||
return out;
|
return out;
|
||||||
}
|
}
|
||||||
// Plain text
|
// Plain text
|
||||||
out.push_str(&format!(" {BOLD}Body ({len} bytes gzip → {} bytes, text):{NC}\n", decompressed.len()));
|
out.push_str(&format!(
|
||||||
|
" {BOLD}Body ({len} bytes gzip → {} bytes, text):{NC}\n",
|
||||||
|
decompressed.len()
|
||||||
|
));
|
||||||
for line in text.lines().take(20) {
|
for line in text.lines().take(20) {
|
||||||
out.push_str(&format!(" {line}\n"));
|
out.push_str(&format!(" {line}\n"));
|
||||||
}
|
}
|
||||||
return out;
|
return out;
|
||||||
}
|
}
|
||||||
// Binary gzip
|
// Binary gzip
|
||||||
out.push_str(&format!(" {BOLD}Body ({len} bytes gzip → {} bytes, binary):{NC}\n", decompressed.len()));
|
out.push_str(&format!(
|
||||||
|
" {BOLD}Body ({len} bytes gzip → {} bytes, binary):{NC}\n",
|
||||||
|
decompressed.len()
|
||||||
|
));
|
||||||
let strings = extract_strings(&decompressed);
|
let strings = extract_strings(&decompressed);
|
||||||
for s in strings.iter().take(15) {
|
for s in strings.iter().take(15) {
|
||||||
out.push_str(&format!(" {MAGENTA}{s}{NC}\n"));
|
out.push_str(&format!(" {MAGENTA}{s}{NC}\n"));
|
||||||
@@ -393,7 +483,11 @@ fn render_body(data: &[u8], total_len: usize) -> String {
|
|||||||
// Protobuf / binary with string extraction
|
// Protobuf / binary with string extraction
|
||||||
let strings = extract_strings(data);
|
let strings = extract_strings(data);
|
||||||
if !strings.is_empty() {
|
if !strings.is_empty() {
|
||||||
let kind = if !data.is_empty() && matches!(data[0], 0x08 | 0x0a | 0x10 | 0x12 | 0x18 | 0x1a | 0x20 | 0x22) {
|
let kind = if !data.is_empty()
|
||||||
|
&& matches!(
|
||||||
|
data[0],
|
||||||
|
0x08 | 0x0a | 0x10 | 0x12 | 0x18 | 0x1a | 0x20 | 0x22
|
||||||
|
) {
|
||||||
"protobuf"
|
"protobuf"
|
||||||
} else {
|
} else {
|
||||||
"binary"
|
"binary"
|
||||||
@@ -448,7 +542,9 @@ fn extract_header(line: &str, pattern: &str) -> Option<(String, String)> {
|
|||||||
fn extract_stream_id(line: &str) -> Option<String> {
|
fn extract_stream_id(line: &str) -> Option<String> {
|
||||||
let pos = line.find("stream=")?;
|
let pos = line.find("stream=")?;
|
||||||
let rest = &line[pos + 7..];
|
let rest = &line[pos + 7..];
|
||||||
let end = rest.find(|c: char| !c.is_ascii_digit()).unwrap_or(rest.len());
|
let end = rest
|
||||||
|
.find(|c: char| !c.is_ascii_digit())
|
||||||
|
.unwrap_or(rest.len());
|
||||||
Some(rest[..end].to_string())
|
Some(rest[..end].to_string())
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -470,7 +566,9 @@ fn extract_data(line: &str) -> Option<String> {
|
|||||||
fn extract_data_len(line: &str) -> Option<usize> {
|
fn extract_data_len(line: &str) -> Option<usize> {
|
||||||
let pos = line.find("len=")?;
|
let pos = line.find("len=")?;
|
||||||
let rest = &line[pos + 4..];
|
let rest = &line[pos + 4..];
|
||||||
let end = rest.find(|c: char| !c.is_ascii_digit()).unwrap_or(rest.len());
|
let end = rest
|
||||||
|
.find(|c: char| !c.is_ascii_digit())
|
||||||
|
.unwrap_or(rest.len());
|
||||||
rest[..end].parse().ok()
|
rest[..end].parse().ok()
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -482,17 +580,40 @@ fn decode_go_escaped(s: &str) -> Vec<u8> {
|
|||||||
if bytes[i] == b'\\' && i + 1 < bytes.len() {
|
if bytes[i] == b'\\' && i + 1 < bytes.len() {
|
||||||
match bytes[i + 1] {
|
match bytes[i + 1] {
|
||||||
b'x' if i + 3 < bytes.len() => {
|
b'x' if i + 3 < bytes.len() => {
|
||||||
if let Ok(b) = u8::from_str_radix(std::str::from_utf8(&bytes[i + 2..i + 4]).unwrap_or(""), 16) {
|
if let Ok(b) = u8::from_str_radix(
|
||||||
|
std::str::from_utf8(&bytes[i + 2..i + 4]).unwrap_or(""),
|
||||||
|
16,
|
||||||
|
) {
|
||||||
result.push(b);
|
result.push(b);
|
||||||
i += 4;
|
i += 4;
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
b'n' => { result.push(b'\n'); i += 2; continue; }
|
b'n' => {
|
||||||
b'r' => { result.push(b'\r'); i += 2; continue; }
|
result.push(b'\n');
|
||||||
b't' => { result.push(b'\t'); i += 2; continue; }
|
i += 2;
|
||||||
b'\\' => { result.push(b'\\'); i += 2; continue; }
|
continue;
|
||||||
b'"' => { result.push(b'"'); i += 2; continue; }
|
}
|
||||||
|
b'r' => {
|
||||||
|
result.push(b'\r');
|
||||||
|
i += 2;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
b't' => {
|
||||||
|
result.push(b'\t');
|
||||||
|
i += 2;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
b'\\' => {
|
||||||
|
result.push(b'\\');
|
||||||
|
i += 2;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
b'"' => {
|
||||||
|
result.push(b'"');
|
||||||
|
i += 2;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
_ => {}
|
_ => {}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -562,7 +683,10 @@ pub fn run_cli() {
|
|||||||
})
|
})
|
||||||
} else {
|
} else {
|
||||||
let mut buf = String::new();
|
let mut buf = String::new();
|
||||||
io::stdin().lock().read_to_string(&mut buf).expect("Failed to read stdin");
|
io::stdin()
|
||||||
|
.lock()
|
||||||
|
.read_to_string(&mut buf)
|
||||||
|
.expect("Failed to read stdin");
|
||||||
buf
|
buf
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|||||||
@@ -108,7 +108,10 @@ pub struct MainLSConfig {
|
|||||||
/// and CSRF is a random UUID.
|
/// and CSRF is a random UUID.
|
||||||
pub fn generate_standalone_config() -> MainLSConfig {
|
pub fn generate_standalone_config() -> MainLSConfig {
|
||||||
let csrf = Uuid::new_v4().to_string();
|
let csrf = Uuid::new_v4().to_string();
|
||||||
info!(csrf_len = csrf.len(), "Generated standalone config (headless)");
|
info!(
|
||||||
|
csrf_len = csrf.len(),
|
||||||
|
"Generated standalone config (headless)"
|
||||||
|
);
|
||||||
MainLSConfig {
|
MainLSConfig {
|
||||||
extension_server_port: "0".to_string(), // disables extension server
|
extension_server_port: "0".to_string(), // disables extension server
|
||||||
csrf,
|
csrf,
|
||||||
@@ -159,7 +162,13 @@ impl StandaloneLS {
|
|||||||
let app_data_dir = format!("{DATA_DIR}/.gemini/antigravity-standalone");
|
let app_data_dir = format!("{DATA_DIR}/.gemini/antigravity-standalone");
|
||||||
let annotations_dir = format!("{app_data_dir}/annotations");
|
let annotations_dir = format!("{app_data_dir}/annotations");
|
||||||
let brain_dir = format!("{app_data_dir}/brain");
|
let brain_dir = format!("{app_data_dir}/brain");
|
||||||
for dir in [DATA_DIR, &gemini_dir, &app_data_dir, &annotations_dir, &brain_dir] {
|
for dir in [
|
||||||
|
DATA_DIR,
|
||||||
|
&gemini_dir,
|
||||||
|
&app_data_dir,
|
||||||
|
&annotations_dir,
|
||||||
|
&brain_dir,
|
||||||
|
] {
|
||||||
let _ = std::fs::create_dir_all(dir);
|
let _ = std::fs::create_dir_all(dir);
|
||||||
#[cfg(unix)]
|
#[cfg(unix)]
|
||||||
{
|
{
|
||||||
@@ -194,7 +203,10 @@ impl StandaloneLS {
|
|||||||
#[cfg(unix)]
|
#[cfg(unix)]
|
||||||
{
|
{
|
||||||
use std::os::unix::fs::PermissionsExt;
|
use std::os::unix::fs::PermissionsExt;
|
||||||
let _ = std::fs::set_permissions(&settings_path, std::fs::Permissions::from_mode(0o0666));
|
let _ = std::fs::set_permissions(
|
||||||
|
&settings_path,
|
||||||
|
std::fs::Permissions::from_mode(0o0666),
|
||||||
|
);
|
||||||
}
|
}
|
||||||
tracing::info!("Pre-seeded user_settings.pb (detect_and_use_proxy=ENABLED)");
|
tracing::info!("Pre-seeded user_settings.pb (detect_and_use_proxy=ENABLED)");
|
||||||
}
|
}
|
||||||
@@ -203,10 +215,7 @@ impl StandaloneLS {
|
|||||||
// The LS connects to this port and calls LanguageServerStarted — without it,
|
// The LS connects to this port and calls LanguageServerStarted — without it,
|
||||||
// the LS never fully initializes and won't accept connections on its server_port.
|
// the LS never fully initializes and won't accept connections on its server_port.
|
||||||
let _stub_listener = if headless {
|
let _stub_listener = if headless {
|
||||||
let stub_port: u16 = main_config
|
let stub_port: u16 = main_config.extension_server_port.parse().unwrap_or(0);
|
||||||
.extension_server_port
|
|
||||||
.parse()
|
|
||||||
.unwrap_or(0);
|
|
||||||
if stub_port == 0 {
|
if stub_port == 0 {
|
||||||
// Create a real listener so the LS can connect
|
// Create a real listener so the LS can connect
|
||||||
let listener = TcpListener::bind("127.0.0.1:0")
|
let listener = TcpListener::bind("127.0.0.1:0")
|
||||||
@@ -215,7 +224,10 @@ impl StandaloneLS {
|
|||||||
.local_addr()
|
.local_addr()
|
||||||
.map_err(|e| format!("Failed to get stub port: {e}"))?
|
.map_err(|e| format!("Failed to get stub port: {e}"))?
|
||||||
.port();
|
.port();
|
||||||
info!(port = actual_port, "Stub extension server listening (headless)");
|
info!(
|
||||||
|
port = actual_port,
|
||||||
|
"Stub extension server listening (headless)"
|
||||||
|
);
|
||||||
// Read OAuth state from Antigravity's state.vscdb if available.
|
// Read OAuth state from Antigravity's state.vscdb if available.
|
||||||
// The DB stores the exact Topic proto (access_token + refresh_token + expiry)
|
// The DB stores the exact Topic proto (access_token + refresh_token + expiry)
|
||||||
// which lets the LS auto-refresh tokens via its built-in Google OAuth2 client.
|
// which lets the LS auto-refresh tokens via its built-in Google OAuth2 client.
|
||||||
@@ -306,10 +318,7 @@ impl StandaloneLS {
|
|||||||
// 3. MITM proxy intercepts the transparent TLS connection via SNI
|
// 3. MITM proxy intercepts the transparent TLS connection via SNI
|
||||||
if let Some(mitm) = mitm_config {
|
if let Some(mitm) = mitm_config {
|
||||||
// Extract port from proxy_addr (e.g. "http://127.0.0.1:8742" → "8742")
|
// Extract port from proxy_addr (e.g. "http://127.0.0.1:8742" → "8742")
|
||||||
let mitm_port = mitm.proxy_addr
|
let mitm_port = mitm.proxy_addr.rsplit(':').next().unwrap_or("8742");
|
||||||
.rsplit(':')
|
|
||||||
.next()
|
|
||||||
.unwrap_or("8742");
|
|
||||||
format!("https://daily-cloudcode-pa.googleapis.com:{mitm_port}")
|
format!("https://daily-cloudcode-pa.googleapis.com:{mitm_port}")
|
||||||
} else {
|
} else {
|
||||||
"https://daily-cloudcode-pa.googleapis.com".to_string()
|
"https://daily-cloudcode-pa.googleapis.com".to_string()
|
||||||
@@ -324,9 +333,8 @@ impl StandaloneLS {
|
|||||||
debug!(?args, "LS args");
|
debug!(?args, "LS args");
|
||||||
|
|
||||||
// Build env vars for the LS process
|
// Build env vars for the LS process
|
||||||
let mut env_vars: Vec<(String, String)> = vec![
|
let mut env_vars: Vec<(String, String)> =
|
||||||
("ANTIGRAVITY_EDITOR_APP_ROOT".into(), APP_ROOT.into()),
|
vec![("ANTIGRAVITY_EDITOR_APP_ROOT".into(), APP_ROOT.into())];
|
||||||
];
|
|
||||||
|
|
||||||
// If MITM is enabled, add SSL + proxy env vars
|
// If MITM is enabled, add SSL + proxy env vars
|
||||||
if let Some(mitm) = mitm_config {
|
if let Some(mitm) = mitm_config {
|
||||||
@@ -335,8 +343,8 @@ impl StandaloneLS {
|
|||||||
// Write to /tmp — accessible by antigravity-ls user
|
// Write to /tmp — accessible by antigravity-ls user
|
||||||
// (user's ~/.config/ is not traversable by other UIDs)
|
// (user's ~/.config/ is not traversable by other UIDs)
|
||||||
let combined_ca_path = "/tmp/antigravity-mitm-combined-ca.pem".to_string();
|
let combined_ca_path = "/tmp/antigravity-mitm-combined-ca.pem".to_string();
|
||||||
let system_ca = std::fs::read_to_string("/etc/ssl/certs/ca-certificates.crt")
|
let system_ca =
|
||||||
.unwrap_or_default();
|
std::fs::read_to_string("/etc/ssl/certs/ca-certificates.crt").unwrap_or_default();
|
||||||
let mitm_ca = std::fs::read_to_string(&mitm.ca_cert_path)
|
let mitm_ca = std::fs::read_to_string(&mitm.ca_cert_path)
|
||||||
.map_err(|e| format!("Failed to read MITM CA cert: {e}"))?;
|
.map_err(|e| format!("Failed to read MITM CA cert: {e}"))?;
|
||||||
std::fs::write(&combined_ca_path, format!("{system_ca}\n{mitm_ca}"))
|
std::fs::write(&combined_ca_path, format!("{system_ca}\n{mitm_ca}"))
|
||||||
@@ -441,7 +449,11 @@ impl StandaloneLS {
|
|||||||
};
|
};
|
||||||
|
|
||||||
if let Some(pid) = ls_pid {
|
if let Some(pid) = ls_pid {
|
||||||
info!(ls_pid = pid, sudo = use_sudo, "Discovered actual LS process");
|
info!(
|
||||||
|
ls_pid = pid,
|
||||||
|
sudo = use_sudo,
|
||||||
|
"Discovered actual LS process"
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(StandaloneLS {
|
Ok(StandaloneLS {
|
||||||
@@ -617,8 +629,7 @@ fn find_main_ls_pid() -> Result<String, String> {
|
|||||||
return Err("No /proc filesystem".to_string());
|
return Err("No /proc filesystem".to_string());
|
||||||
}
|
}
|
||||||
|
|
||||||
let entries = std::fs::read_dir(proc)
|
let entries = std::fs::read_dir(proc).map_err(|e| format!("Cannot read /proc: {e}"))?;
|
||||||
.map_err(|e| format!("Cannot read /proc: {e}"))?;
|
|
||||||
|
|
||||||
for entry in entries.flatten() {
|
for entry in entries.flatten() {
|
||||||
let name = entry.file_name();
|
let name = entry.file_name();
|
||||||
@@ -704,12 +715,10 @@ fn cleanup_orphaned_ls() {
|
|||||||
.output();
|
.output();
|
||||||
|
|
||||||
let pids: Vec<u32> = match output {
|
let pids: Vec<u32> = match output {
|
||||||
Ok(out) => {
|
Ok(out) => String::from_utf8_lossy(&out.stdout)
|
||||||
String::from_utf8_lossy(&out.stdout)
|
.lines()
|
||||||
.lines()
|
.filter_map(|l| l.trim().parse().ok())
|
||||||
.filter_map(|l| l.trim().parse().ok())
|
.collect(),
|
||||||
.collect()
|
|
||||||
}
|
|
||||||
Err(_) => return,
|
Err(_) => return,
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -717,7 +726,11 @@ fn cleanup_orphaned_ls() {
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
info!(count = pids.len(), ?pids, "Cleaning up orphaned standalone LS processes");
|
info!(
|
||||||
|
count = pids.len(),
|
||||||
|
?pids,
|
||||||
|
"Cleaning up orphaned standalone LS processes"
|
||||||
|
);
|
||||||
|
|
||||||
// Kill each PID by running `kill` AS the antigravity-ls user.
|
// Kill each PID by running `kill` AS the antigravity-ls user.
|
||||||
// This works because same-UID processes can signal each other,
|
// This works because same-UID processes can signal each other,
|
||||||
@@ -870,7 +883,8 @@ fn extract_access_token_from_topic(topic_bytes: &[u8]) -> Option<String> {
|
|||||||
// Simple approach: convert to string and find base64 pattern
|
// Simple approach: convert to string and find base64 pattern
|
||||||
let as_str = String::from_utf8_lossy(topic_bytes);
|
let as_str = String::from_utf8_lossy(topic_bytes);
|
||||||
// The base64 OAuthTokenInfo starts with "Co" (0x0A = field 1, len-delimited)
|
// The base64 OAuthTokenInfo starts with "Co" (0x0A = field 1, len-delimited)
|
||||||
for segment in as_str.split(|c: char| !c.is_alphanumeric() && c != '+' && c != '/' && c != '=') {
|
for segment in as_str.split(|c: char| !c.is_alphanumeric() && c != '+' && c != '/' && c != '=')
|
||||||
|
{
|
||||||
if segment.len() > 50 {
|
if segment.len() > 50 {
|
||||||
if let Ok(decoded) = base64::engine::general_purpose::STANDARD.decode(segment) {
|
if let Ok(decoded) = base64::engine::general_purpose::STANDARD.decode(segment) {
|
||||||
// Try to extract field 1 (access_token) from the OAuthTokenInfo proto
|
// Try to extract field 1 (access_token) from the OAuthTokenInfo proto
|
||||||
@@ -951,7 +965,11 @@ fn decode_varint_at(buf: &[u8], offset: usize) -> Option<(u64, usize)> {
|
|||||||
/// IMPORTANT: `SubscribeToUnifiedStateSyncTopic` is a long-lived stream.
|
/// IMPORTANT: `SubscribeToUnifiedStateSyncTopic` is a long-lived stream.
|
||||||
/// If we immediately close it, the LS reconnects in a tight loop and never
|
/// If we immediately close it, the LS reconnects in a tight loop and never
|
||||||
/// proceeds to fetch OAuth tokens. We keep subscription connections OPEN.
|
/// proceeds to fetch OAuth tokens. We keep subscription connections OPEN.
|
||||||
fn stub_handle_connection(conn: std::net::TcpStream, oauth_token: &str, oauth_topic_bytes: &Option<Vec<u8>>) {
|
fn stub_handle_connection(
|
||||||
|
conn: std::net::TcpStream,
|
||||||
|
oauth_token: &str,
|
||||||
|
oauth_topic_bytes: &Option<Vec<u8>>,
|
||||||
|
) {
|
||||||
use std::io::{BufRead, BufReader, Read, Write};
|
use std::io::{BufRead, BufReader, Read, Write};
|
||||||
|
|
||||||
let mut reader = BufReader::new(match conn.try_clone() {
|
let mut reader = BufReader::new(match conn.try_clone() {
|
||||||
@@ -1028,7 +1046,7 @@ fn stub_handle_connection(conn: std::net::TcpStream, oauth_token: &str, oauth_to
|
|||||||
i += 1;
|
i += 1;
|
||||||
if i + len <= proto_body.len() {
|
if i + len <= proto_body.len() {
|
||||||
if field_num == 1 {
|
if field_num == 1 {
|
||||||
topic_name = String::from_utf8_lossy(&proto_body[i..i+len]).to_string();
|
topic_name = String::from_utf8_lossy(&proto_body[i..i + len]).to_string();
|
||||||
}
|
}
|
||||||
i += len;
|
i += len;
|
||||||
} else {
|
} else {
|
||||||
@@ -1084,7 +1102,10 @@ fn stub_handle_connection(conn: std::net::TcpStream, oauth_token: &str, oauth_to
|
|||||||
// This includes access_token + refresh_token + expiry, so the
|
// This includes access_token + refresh_token + expiry, so the
|
||||||
// LS can auto-refresh tokens via its built-in Google OAuth2 client.
|
// LS can auto-refresh tokens via its built-in Google OAuth2 client.
|
||||||
initial_state_bytes = topic_bytes.clone();
|
initial_state_bytes = topic_bytes.clone();
|
||||||
eprintln!("[stub-ext] using state.vscdb topic ({} bytes)", topic_bytes.len());
|
eprintln!(
|
||||||
|
"[stub-ext] using state.vscdb topic ({} bytes)",
|
||||||
|
topic_bytes.len()
|
||||||
|
);
|
||||||
} else if !oauth_token.is_empty() {
|
} else if !oauth_token.is_empty() {
|
||||||
// Manual token fallback — construct OAuthTokenInfo with far-future expiry
|
// Manual token fallback — construct OAuthTokenInfo with far-future expiry
|
||||||
// (no refresh_token, so the LS can't auto-refresh)
|
// (no refresh_token, so the LS can't auto-refresh)
|
||||||
@@ -1155,7 +1176,10 @@ fn stub_handle_connection(conn: std::net::TcpStream, oauth_token: &str, oauth_to
|
|||||||
if !send_chunk(&mut writer, &initial_env) {
|
if !send_chunk(&mut writer, &initial_env) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
eprintln!("[stub-ext] STREAM → sent initial_state ({} bytes)", initial_state_bytes.len());
|
eprintln!(
|
||||||
|
"[stub-ext] STREAM → sent initial_state ({} bytes)",
|
||||||
|
initial_state_bytes.len()
|
||||||
|
);
|
||||||
|
|
||||||
// (applied_update removed — data is in initial_state)
|
// (applied_update removed — data is in initial_state)
|
||||||
|
|
||||||
@@ -1197,7 +1221,10 @@ fn stub_handle_connection(conn: std::net::TcpStream, oauth_token: &str, oauth_to
|
|||||||
if !oauth_token.is_empty() {
|
if !oauth_token.is_empty() {
|
||||||
// Build protobuf: GetSecretValueResponse { string value = 1 }
|
// Build protobuf: GetSecretValueResponse { string value = 1 }
|
||||||
let proto = encode_proto_string(1, oauth_token.as_bytes());
|
let proto = encode_proto_string(1, oauth_token.as_bytes());
|
||||||
eprintln!("[stub-ext] → serving token ({} bytes) for key={key:?}", oauth_token.len());
|
eprintln!(
|
||||||
|
"[stub-ext] → serving token ({} bytes) for key={key:?}",
|
||||||
|
oauth_token.len()
|
||||||
|
);
|
||||||
|
|
||||||
// Data envelope: flag=0x00, length, data
|
// Data envelope: flag=0x00, length, data
|
||||||
envelope.push(0x00u8);
|
envelope.push(0x00u8);
|
||||||
|
|||||||
@@ -34,7 +34,9 @@ pub async fn warmup_sequence(backend: &Backend, headless: bool) {
|
|||||||
)
|
)
|
||||||
.await
|
.await
|
||||||
{
|
{
|
||||||
Ok(Ok((status, _))) => info!("SetUserSettings (detect_and_use_proxy=ENABLED): {status}"),
|
Ok(Ok((status, _))) => {
|
||||||
|
info!("SetUserSettings (detect_and_use_proxy=ENABLED): {status}")
|
||||||
|
}
|
||||||
Ok(Err(e)) => warn!("SetUserSettings failed: {e}"),
|
Ok(Err(e)) => warn!("SetUserSettings failed: {e}"),
|
||||||
Err(_) => warn!("SetUserSettings timed out"),
|
Err(_) => warn!("SetUserSettings timed out"),
|
||||||
}
|
}
|
||||||
@@ -59,12 +61,7 @@ pub async fn warmup_sequence(backend: &Backend, headless: bool) {
|
|||||||
for (method, body) in calls {
|
for (method, body) in calls {
|
||||||
// Timeout per call — in headless mode, the LS can't reach Google's API
|
// Timeout per call — in headless mode, the LS can't reach Google's API
|
||||||
// so these would hang forever without a timeout. Warmup is best-effort.
|
// so these would hang forever without a timeout. Warmup is best-effort.
|
||||||
match tokio::time::timeout(
|
match tokio::time::timeout(Duration::from_secs(5), backend.call_json(method, body)).await {
|
||||||
Duration::from_secs(5),
|
|
||||||
backend.call_json(method, body),
|
|
||||||
)
|
|
||||||
.await
|
|
||||||
{
|
|
||||||
Ok(Ok((status, _))) => debug!("Warmup {method}: {status}"),
|
Ok(Ok((status, _))) => debug!("Warmup {method}: {status}"),
|
||||||
Ok(Err(e)) => warn!("Warmup {method} failed: {e}"),
|
Ok(Err(e)) => warn!("Warmup {method} failed: {e}"),
|
||||||
Err(_) => warn!("Warmup {method} timed out"),
|
Err(_) => warn!("Warmup {method} timed out"),
|
||||||
@@ -87,10 +84,7 @@ pub fn start_heartbeat(backend: Arc<Backend>) -> JoinHandle<()> {
|
|||||||
let interval_ms = rand::thread_rng().gen_range(29_500..30_500);
|
let interval_ms = rand::thread_rng().gen_range(29_500..30_500);
|
||||||
tokio::time::sleep(Duration::from_millis(interval_ms)).await;
|
tokio::time::sleep(Duration::from_millis(interval_ms)).await;
|
||||||
|
|
||||||
match backend
|
match backend.call_json("Heartbeat", &serde_json::json!({})).await {
|
||||||
.call_json("Heartbeat", &serde_json::json!({}))
|
|
||||||
.await
|
|
||||||
{
|
|
||||||
Ok((status, _)) => debug!("Heartbeat: {status}"),
|
Ok((status, _)) => debug!("Heartbeat: {status}"),
|
||||||
Err(e) => warn!("Heartbeat failed: {e}"),
|
Err(e) => warn!("Heartbeat failed: {e}"),
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user