refactor: endpoint parity and proxy improvements

Mixed changes from recent sessions: endpoint feature parity
improvements, proxy bug fixes, and store cleanup.
This commit is contained in:
Nikketryhard
2026-02-16 21:47:00 -06:00
parent 86675fd960
commit 637fbc0e54
5 changed files with 763 additions and 692 deletions

View File

@@ -223,7 +223,6 @@ pub(crate) async fn handle_completions(
} else {
state.mitm_store.clear_tools().await;
}
state.mitm_store.clear_active_function_call();
// ── Extract tool results from messages for MITM injection ──────────
// When OpenCode sends back tool results, the messages array contains:
@@ -340,7 +339,6 @@ pub(crate) async fn handle_completions(
state.mitm_store.set_last_function_calls(last_round.calls.clone()).await;
}
state.mitm_store.set_tool_rounds(rounds).await;
state.mitm_store.clear_awaiting_tool_result();
}
}
@@ -441,6 +439,8 @@ pub(crate) async fn handle_completions(
// Send message on primary cascade
state.mitm_store.set_active_cascade(&cascade_id).await;
// Store real user text for MITM injection — LS gets a dummy prompt
state.mitm_store.set_pending_user_text(user_text.clone()).await;
// Store image for MITM injection (LS doesn't forward images to Google API)
if let Some(ref img) = image {
use base64::Engine;
@@ -452,9 +452,25 @@ pub(crate) async fn handle_completions(
})
.await;
}
// Pre-flight: install channel BEFORE send_message so the MITM proxy
// can grab it when the LS fires its API call.
// Only for streaming — sync paths use poll_for_response (legacy store).
let has_custom_tools = state.mitm_store.get_tools().await.is_some();
let mitm_rx = if has_custom_tools && body.stream {
state.mitm_store.clear_response_async().await;
state.mitm_store.clear_upstream_error().await;
let _ = state.mitm_store.take_any_function_calls().await;
let (tx, rx) = tokio::sync::mpsc::channel(64);
state.mitm_store.set_channel(tx).await;
Some(rx)
} else {
None
};
match state
.backend
.send_message_with_image(&cascade_id, &user_text, model.model_enum, image.as_ref())
.send_message_with_image(&cascade_id, ".", model.model_enum, image.as_ref())
.await
{
Ok((200, _)) => {
@@ -465,6 +481,7 @@ pub(crate) async fn handle_completions(
});
}
Ok((status, _)) => {
state.mitm_store.drop_channel().await;
return err_response(
StatusCode::BAD_GATEWAY,
format!("Backend returned {status}"),
@@ -472,6 +489,7 @@ pub(crate) async fn handle_completions(
);
}
Err(e) => {
state.mitm_store.drop_channel().await;
return err_response(
StatusCode::BAD_GATEWAY,
format!("Send failed: {e}"),
@@ -498,6 +516,7 @@ pub(crate) async fn handle_completions(
cascade_id,
body.timeout,
include_usage,
mitm_rx,
)
.await
} else if n <= 1 {
@@ -518,7 +537,7 @@ pub(crate) async fn handle_completions(
// Send the same message on each extra cascade
match state
.backend
.send_message_with_image(&cid, &user_text, model.model_enum, image.as_ref())
.send_message_with_image(&cid, ".", model.model_enum, image.as_ref())
.await
{
Ok((200, _)) => {
@@ -635,18 +654,18 @@ async fn chat_completions_stream(
cascade_id: String,
timeout: u64,
include_usage: bool,
mitm_rx: Option<tokio::sync::mpsc::Receiver<crate::mitm::store::MitmEvent>>,
) -> axum::response::Response {
let stream = async_stream::stream! {
let start = std::time::Instant::now();
let mut last_text = String::new();
let has_custom_tools = state.mitm_store.get_tools().await.is_some();
let has_custom_tools = mitm_rx.is_some();
// Clear ALL stale state from previous requests
state.mitm_store.clear_response_async().await;
state.mitm_store.clear_upstream_error().await;
state.mitm_store.clear_active_function_call();
// Drain any stale function calls from previous requests
let _ = state.mitm_store.take_any_function_calls().await;
if !has_custom_tools {
state.mitm_store.clear_response_async().await;
state.mitm_store.clear_upstream_error().await;
let _ = state.mitm_store.take_any_function_calls().await;
}
// Initial role chunk
yield Ok::<_, std::convert::Infallible>(Event::default().data(chunk_json(
@@ -659,7 +678,6 @@ async fn chat_completions_stream(
let mut last_thinking_len: usize = 0;
let mut complete_polls: u32 = 0;
let mut did_unblock_ls = false; // Prevents infinite unblock loops
let mut my_generation = state.mitm_store.current_generation();
// Helper: build usage JSON from MITM tokens
let build_usage = |pt: u64, ct: u64, crt: u64, tt: u64| -> serde_json::Value {
@@ -672,212 +690,219 @@ async fn chat_completions_stream(
})
};
while start.elapsed().as_secs() < timeout {
// Check for upstream errors from MITM (Google API errors)
if let Some(err) = state.mitm_store.take_upstream_error().await {
let error_msg = super::util::upstream_error_message(&err);
let error_type = super::util::upstream_error_type(&err);
yield Ok(Event::default().data(serde_json::to_string(&serde_json::json!({
"error": {
"message": error_msg,
"type": error_type,
"code": err.status,
}
})).unwrap()));
yield Ok(Event::default().data("[DONE]".to_string()));
break;
}
// Take ownership of the pre-installed channel receiver
let mut rx_opt = mitm_rx;
// Bail if another completions handler has superseded us
if state.mitm_store.current_generation() != my_generation {
debug!("Completions: generation changed (superseded), ending stream");
while start.elapsed().as_secs() < timeout {
if let Some(ref mut rx) = rx_opt {
// ── Channel-based MITM pipeline ──
// Track accumulated text for delta computation
let mut acc_text = String::new();
let mut acc_thinking = String::new();
let mut last_usage: Option<crate::mitm::store::ApiUsage> = None;
'channel_loop: while let Some(event) = tokio::time::timeout(
std::time::Duration::from_secs(timeout.saturating_sub(start.elapsed().as_secs())),
rx.recv(),
).await.ok().flatten() {
use crate::mitm::store::MitmEvent;
match event {
MitmEvent::ThinkingDelta(full_thinking) => {
if full_thinking.len() > acc_thinking.len() {
let delta = full_thinking[acc_thinking.len()..].to_string();
acc_thinking = full_thinking;
last_thinking_len = acc_thinking.len();
yield Ok(Event::default().data(chunk_json(
&completion_id, &model_name,
serde_json::json!([chunk_choice(0, serde_json::json!({"reasoning_content": delta}), None)]),
None,
)));
}
}
MitmEvent::TextDelta(full_text) => {
if full_text.len() > acc_text.len() {
let delta = full_text[acc_text.len()..].to_string();
acc_text = full_text;
last_text = acc_text.clone();
yield Ok(Event::default().data(chunk_json(
&completion_id, &model_name,
serde_json::json!([chunk_choice(0, serde_json::json!({"content": delta}), None)]),
None,
)));
}
}
MitmEvent::FunctionCall(calls) => {
let mut tool_calls = Vec::new();
for (i, fc) in calls.iter().enumerate() {
let call_id = format!(
"call_{}",
uuid::Uuid::new_v4().to_string().replace('-', "")[..24].to_string()
);
let arguments = serde_json::to_string(&fc.args).unwrap_or_default();
tool_calls.push(serde_json::json!({
"index": i,
"id": call_id,
"type": "function",
"function": {
"name": fc.name,
"arguments": arguments,
},
}));
}
yield Ok(Event::default().data(chunk_json(
&completion_id, &model_name,
serde_json::json!([chunk_choice(0, serde_json::json!({"tool_calls": tool_calls}), None)]),
None,
)));
yield Ok(Event::default().data(chunk_json(
&completion_id, &model_name,
serde_json::json!([chunk_choice(0, serde_json::json!({}), Some("tool_calls"))]),
None,
)));
if include_usage {
let mitm = state.mitm_store.take_usage(&cascade_id).await
.or(state.mitm_store.take_usage("_latest").await);
let (pt, ct, crt, tt) = if let Some(ref u) = mitm {
(u.input_tokens, u.output_tokens, u.cache_read_input_tokens, u.thinking_output_tokens)
} else if let Some(ref u) = last_usage {
(u.input_tokens, u.output_tokens, u.cache_read_input_tokens, u.thinking_output_tokens)
} else { (0, 0, 0, 0) };
yield Ok(Event::default().data(chunk_json(
&completion_id, &model_name,
serde_json::json!([]),
Some(build_usage(pt, ct, crt, tt)),
)));
}
yield Ok(Event::default().data("[DONE]"));
state.mitm_store.drop_channel().await;
return;
}
MitmEvent::ResponseComplete => {
if !acc_text.is_empty() {
// Have response text — done
debug!("Completions: channel response complete, text_len={}, thinking_len={}",
acc_text.len(), acc_thinking.len());
let mitm = state.mitm_store.take_usage(&cascade_id).await
.or(state.mitm_store.take_usage("_latest").await)
.or(last_usage.take());
let fr = google_to_openai_finish_reason(mitm.as_ref().and_then(|u| u.stop_reason.as_deref()));
yield Ok(Event::default().data(chunk_json(
&completion_id, &model_name,
serde_json::json!([chunk_choice(0, serde_json::json!({}), Some(fr))]),
None,
)));
if include_usage {
let (pt, ct, crt, tt) = if let Some(ref u) = mitm {
(u.input_tokens, u.output_tokens, u.cache_read_input_tokens, u.thinking_output_tokens)
} else { (0, 0, 0, 0) };
yield Ok(Event::default().data(chunk_json(
&completion_id, &model_name,
serde_json::json!([]),
Some(build_usage(pt, ct, crt, tt)),
)));
}
yield Ok(Event::default().data("[DONE]"));
state.mitm_store.drop_channel().await;
return;
} else if !acc_thinking.is_empty() && !did_unblock_ls {
// Thinking-only response — LS needs follow-up API calls.
// Create a new channel and unblock the gate.
did_unblock_ls = true;
let (new_tx, new_rx) = tokio::sync::mpsc::channel(64);
state.mitm_store.set_channel(new_tx).await;
state.mitm_store.clear_request_in_flight();
let _ = state.mitm_store.take_any_function_calls().await;
*rx = new_rx;
debug!(
"Completions: thinking-only — new channel for follow-up, thinking_len={}",
acc_thinking.len()
);
continue 'channel_loop;
} else if !acc_thinking.is_empty() && did_unblock_ls {
// Already unblocked once, still thinking-only.
// Wait a bit for potential follow-up events.
complete_polls += 1;
if complete_polls >= 25 {
info!("Completions: thinking-only timeout, thinking_len={}", acc_thinking.len());
let mitm = state.mitm_store.take_usage(&cascade_id).await
.or(state.mitm_store.take_usage("_latest").await)
.or(last_usage.take());
let fr = google_to_openai_finish_reason(mitm.as_ref().and_then(|u| u.stop_reason.as_deref()));
yield Ok(Event::default().data(chunk_json(
&completion_id, &model_name,
serde_json::json!([chunk_choice(0, serde_json::json!({}), Some(fr))]),
None,
)));
if include_usage {
let (pt, ct, crt, tt) = if let Some(ref u) = mitm {
(u.input_tokens, u.output_tokens, u.cache_read_input_tokens, u.thinking_output_tokens)
} else { (0, 0, 0, 0) };
yield Ok(Event::default().data(chunk_json(
&completion_id, &model_name,
serde_json::json!([]),
Some(build_usage(pt, ct, crt, tt)),
)));
}
yield Ok(Event::default().data("[DONE]"));
state.mitm_store.drop_channel().await;
return;
}
// Don't break — wait for more channel events
continue 'channel_loop;
} else {
// Empty response (no text, no thinking, no tools)
complete_polls += 1;
if complete_polls >= 4 {
info!("Completions: channel response complete but empty, ending stream");
yield Ok(Event::default().data(chunk_json(
&completion_id, &model_name,
serde_json::json!([chunk_choice(0, serde_json::json!({}), Some("stop"))]),
None,
)));
yield Ok(Event::default().data("[DONE]"));
state.mitm_store.drop_channel().await;
return;
}
continue 'channel_loop;
}
}
MitmEvent::UpstreamError(err) => {
let error_msg = super::util::upstream_error_message(&err);
let error_type = super::util::upstream_error_type(&err);
yield Ok(Event::default().data(serde_json::to_string(&serde_json::json!({
"error": {
"message": error_msg,
"type": error_type,
"code": err.status,
}
})).unwrap()));
yield Ok(Event::default().data("[DONE]"));
state.mitm_store.drop_channel().await;
return;
}
MitmEvent::Usage(u) => {
last_usage = Some(u);
}
MitmEvent::Grounding(_) => {
// Grounding metadata handled by store directly
}
}
}
// Channel closed or timeout — clean up
state.mitm_store.drop_channel().await;
// If we got here from timeout with content, emit what we have
if !last_text.is_empty() || last_thinking_len > 0 {
yield Ok(Event::default().data(chunk_json(
&completion_id, &model_name,
serde_json::json!([chunk_choice(0, serde_json::json!({}), Some("stop"))]),
None,
)));
}
yield Ok(Event::default().data("[DONE]"));
return;
}
// ── Check for MITM-captured function calls FIRST ──
// This runs independently of LS steps — the MITM captures tool calls
// at the proxy layer, so we don't need to wait for LS processing.
let captured = state.mitm_store.take_function_calls(&cascade_id).await;
if let Some(ref calls) = captured {
if !calls.is_empty() {
let mut tool_calls = Vec::new();
for (i, fc) in calls.iter().enumerate() {
let call_id = format!(
"call_{}",
uuid::Uuid::new_v4().to_string().replace('-', "")[..24].to_string()
);
let arguments = serde_json::to_string(&fc.args).unwrap_or_default();
tool_calls.push(serde_json::json!({
"index": i,
"id": call_id,
"type": "function",
"function": {
"name": fc.name,
"arguments": arguments,
},
}));
}
yield Ok(Event::default().data(chunk_json(
&completion_id, &model_name,
serde_json::json!([chunk_choice(0, serde_json::json!({"tool_calls": tool_calls}), None)]),
None,
)));
yield Ok(Event::default().data(chunk_json(
&completion_id, &model_name,
serde_json::json!([chunk_choice(0, serde_json::json!({}), Some("tool_calls"))]),
None,
)));
if include_usage {
let mitm = state.mitm_store.take_usage(&cascade_id).await
.or(state.mitm_store.take_usage("_latest").await);
let (pt, ct, crt, tt) = if let Some(ref u) = mitm {
(u.input_tokens, u.output_tokens, u.cache_read_input_tokens, u.thinking_output_tokens)
} else { (0, 0, 0, 0) };
yield Ok(Event::default().data(chunk_json(
&completion_id, &model_name,
serde_json::json!([]),
Some(build_usage(pt, ct, crt, tt)),
)));
}
yield Ok(Event::default().data("[DONE]"));
// Clear in-flight flag so the next turn's requests can get through
state.mitm_store.clear_response_async().await;
return;
}
}
// ── Primary: MITM-captured response (when custom tools are active) ──
// The MITM intercepts the real Google SSE stream and captures text,
// thinking, and function calls. This is the authoritative data source.
// The LS only gets rewritten responses (function calls → text placeholders)
// so it doesn't provide useful streaming data when MITM is active.
if has_custom_tools {
// Stream thinking text as reasoning_content deltas
if let Some(tc) = state.mitm_store.peek_thinking_text().await {
if tc.len() > last_thinking_len {
let delta = &tc[last_thinking_len..];
last_thinking_len = tc.len();
yield Ok(Event::default().data(chunk_json(
&completion_id, &model_name,
serde_json::json!([chunk_choice(0, serde_json::json!({"reasoning_content": delta}), None)]),
None,
)));
}
}
// Stream response text as content deltas
if let Some(text) = state.mitm_store.peek_response_text().await {
if !text.is_empty() && text != last_text {
let delta = if text.len() > last_text.len() && text.starts_with(&*last_text) {
text[last_text.len()..].to_string()
} else {
text.clone()
};
if !delta.is_empty() {
yield Ok(Event::default().data(chunk_json(
&completion_id, &model_name,
serde_json::json!([chunk_choice(0, serde_json::json!({"content": delta}), None)]),
None,
)));
last_text = text;
}
}
}
// Check if MITM response is complete
if state.mitm_store.is_response_complete() {
if !last_text.is_empty() {
// Have actual response text — done
complete_polls += 1;
if complete_polls >= 2 {
debug!("Completions: MITM response complete, text_len={}, thinking_len={}", last_text.len(), last_thinking_len);
let mitm = state.mitm_store.take_usage(&cascade_id).await
.or(state.mitm_store.take_usage("_latest").await);
let fr = google_to_openai_finish_reason(mitm.as_ref().and_then(|u| u.stop_reason.as_deref()));
yield Ok(Event::default().data(chunk_json(
&completion_id, &model_name,
serde_json::json!([chunk_choice(0, serde_json::json!({}), Some(fr))]),
None,
)));
if include_usage {
let (pt, ct, crt, tt) = if let Some(ref u) = mitm {
(u.input_tokens, u.output_tokens, u.cache_read_input_tokens, u.thinking_output_tokens)
} else { (0, 0, 0, 0) };
yield Ok(Event::default().data(chunk_json(
&completion_id, &model_name,
serde_json::json!([]),
Some(build_usage(pt, ct, crt, tt)),
)));
}
yield Ok(Event::default().data("[DONE]"));
return;
}
} else if last_thinking_len > 0 && !did_unblock_ls {
// Thinking-only response. The LS needs follow-up API calls
// to get actual function calls or text. Unblock once.
did_unblock_ls = true;
complete_polls = 0;
// Bump generation FIRST — invalidates old MITM connection's store writes
my_generation = state.mitm_store.bump_generation();
state.mitm_store.clear_request_in_flight();
state.mitm_store.clear_response_complete();
// Drain store so leaked connections can't produce stale content
state.mitm_store.set_response_text("").await;
state.mitm_store.set_thinking_text("").await;
let _ = state.mitm_store.take_any_function_calls().await;
debug!(
"Completions: thinking-only — unblocking LS for follow-up, thinking_len={}, new_gen={}",
last_thinking_len, my_generation
);
} else if last_thinking_len > 0 && did_unblock_ls {
// Already unblocked once. Still only thinking after follow-up.
complete_polls += 1;
if complete_polls >= 25 {
info!("Completions: thinking-only timeout after ~10s, thinking_len={}", last_thinking_len);
let mitm = state.mitm_store.take_usage(&cascade_id).await
.or(state.mitm_store.take_usage("_latest").await);
let fr = google_to_openai_finish_reason(mitm.as_ref().and_then(|u| u.stop_reason.as_deref()));
yield Ok(Event::default().data(chunk_json(
&completion_id, &model_name,
serde_json::json!([chunk_choice(0, serde_json::json!({}), Some(fr))]),
None,
)));
if include_usage {
let (pt, ct, crt, tt) = if let Some(ref u) = mitm {
(u.input_tokens, u.output_tokens, u.cache_read_input_tokens, u.thinking_output_tokens)
} else { (0, 0, 0, 0) };
yield Ok(Event::default().data(chunk_json(
&completion_id, &model_name,
serde_json::json!([]),
Some(build_usage(pt, ct, crt, tt)),
)));
}
yield Ok(Event::default().data("[DONE]"));
return;
}
} else {
// response_complete but no text AND no thinking — might be
// a function-call-only response that was already consumed,
// or empty response. Wait a bit then give up.
complete_polls += 1;
if complete_polls >= 4 {
info!("Completions: MITM response complete but no content (text/thinking/tools all empty), ending stream");
yield Ok(Event::default().data(chunk_json(
&completion_id, &model_name,
serde_json::json!([chunk_choice(0, serde_json::json!({}), Some("stop"))]),
None,
)));
yield Ok(Event::default().data("[DONE]"));
return;
}
}
} else {
complete_polls = 0; // Reset — not complete yet
}
} else {
// ── Fallback: LS steps (no MITM capture active) ──
if let Ok((status, data)) = state.backend.get_steps(&cascade_id).await {
@@ -1001,7 +1026,7 @@ async fn chat_completions_stream(
}
})).unwrap()));
// Always clear in-flight flag when stream ends
state.mitm_store.clear_response_async().await;
state.mitm_store.drop_channel().await;
yield Ok(Event::default().data("[DONE]"));
};

View File

@@ -309,7 +309,11 @@ pub(crate) async fn handle_responses(
count = tools.len(),
"Stored client tools for MITM injection"
);
} else {
state.mitm_store.clear_tools().await;
}
} else {
state.mitm_store.clear_tools().await;
}
if let Some(ref choice) = body.tool_choice {
let gemini_config = openai_tool_choice_to_gemini(choice);
@@ -404,6 +408,8 @@ pub(crate) async fn handle_responses(
// Send message
state.mitm_store.set_active_cascade(&cascade_id).await;
// Store real user text for MITM injection — LS gets a dummy prompt
state.mitm_store.set_pending_user_text(user_text.clone()).await;
// Store image for MITM injection (LS doesn't forward images to Google API)
if let Some(ref img) = image {
use base64::Engine;
@@ -415,9 +421,24 @@ pub(crate) async fn handle_responses(
})
.await;
}
// Pre-flight: install channel BEFORE send_message so the MITM proxy
// can grab it when the LS fires its API call.
let has_custom_tools = state.mitm_store.get_tools().await.is_some();
let mitm_rx = if has_custom_tools {
state.mitm_store.clear_response_async().await;
state.mitm_store.clear_upstream_error().await;
let _ = state.mitm_store.take_any_function_calls().await;
let (tx, rx) = tokio::sync::mpsc::channel(64);
state.mitm_store.set_channel(tx).await;
Some(rx)
} else {
None
};
match state
.backend
.send_message_with_image(&cascade_id, &user_text, model.model_enum, image.as_ref())
.send_message_with_image(&cascade_id, ".", model.model_enum, image.as_ref())
.await
{
Ok((200, _)) => {
@@ -428,6 +449,7 @@ pub(crate) async fn handle_responses(
});
}
Ok((status, _)) => {
state.mitm_store.drop_channel().await;
return err_response(
StatusCode::BAD_GATEWAY,
format!("Antigravity returned {status}"),
@@ -435,6 +457,7 @@ pub(crate) async fn handle_responses(
);
}
Err(e) => {
state.mitm_store.drop_channel().await;
return err_response(
StatusCode::BAD_GATEWAY,
format!("Send message failed: {e}"),
@@ -472,6 +495,7 @@ pub(crate) async fn handle_responses(
cascade_id,
body.timeout,
req_params,
mitm_rx,
)
.await
} else {
@@ -482,6 +506,7 @@ pub(crate) async fn handle_responses(
cascade_id,
body.timeout,
req_params,
mitm_rx,
)
.await
}
@@ -603,54 +628,54 @@ async fn handle_responses_sync(
cascade_id: String,
timeout: u64,
params: RequestParams,
mitm_rx: Option<tokio::sync::mpsc::Receiver<crate::mitm::store::MitmEvent>>,
) -> axum::response::Response {
let created_at = now_unix();
let has_custom_tools = state.mitm_store.get_tools().await.is_some();
// Clear stale captured response and upstream errors
state.mitm_store.clear_response_async().await;
state.mitm_store.clear_upstream_error().await;
// Clear stale captured response and upstream errors (only if no pre-installed channel)
if mitm_rx.is_none() {
state.mitm_store.clear_response_async().await;
state.mitm_store.clear_upstream_error().await;
}
// ── MITM bypass: poll MitmStore directly when custom tools active ──
if has_custom_tools {
// ── MITM bypass: channel-based pipeline when custom tools active ──
if let Some(mut rx) = mitm_rx {
let start = std::time::Instant::now();
while start.elapsed().as_secs() < timeout {
// Check for upstream errors from MITM (Google API errors)
if let Some(err) = state.mitm_store.take_upstream_error().await {
return upstream_err_response(&err);
}
// Check for function calls
let captured = state.mitm_store.take_function_calls(&cascade_id).await;
if let Some(ref raw_calls) = captured {
let calls: Vec<_> = if let Some(max) = params.max_tool_calls {
raw_calls.iter().take(max as usize).collect()
} else {
raw_calls.iter().collect()
};
if !calls.is_empty() {
let mut acc_text = String::new();
let mut acc_thinking: Option<String> = None;
let mut last_usage: Option<crate::mitm::store::ApiUsage> = None;
while let Some(event) = tokio::time::timeout(
std::time::Duration::from_secs(timeout.saturating_sub(start.elapsed().as_secs())),
rx.recv(),
).await.ok().flatten() {
use crate::mitm::store::MitmEvent;
match event {
MitmEvent::ThinkingDelta(t) => { acc_thinking = Some(t); }
MitmEvent::TextDelta(t) => { acc_text = t; }
MitmEvent::Usage(u) => { last_usage = Some(u); }
MitmEvent::Grounding(_) => {} // stored by proxy directly
MitmEvent::FunctionCall(raw_calls) => {
let calls: Vec<_> = if let Some(max) = params.max_tool_calls {
raw_calls.iter().take(max as usize).collect()
} else {
raw_calls.iter().collect()
};
let mut output_items: Vec<serde_json::Value> = Vec::new();
for fc in &calls {
let call_id = format!(
"call_{}",
uuid::Uuid::new_v4().to_string().replace('-', "")[..24].to_string()
);
state
.mitm_store
.register_call_id(call_id.clone(), fc.name.clone())
.await;
state.mitm_store.register_call_id(call_id.clone(), fc.name.clone()).await;
let arguments = serde_json::to_string(&fc.args).unwrap_or_default();
output_items
.push(build_function_call_output(&call_id, &fc.name, &arguments));
output_items.push(build_function_call_output(&call_id, &fc.name, &arguments));
}
let (usage, _) = usage_from_poll(
&state.mitm_store,
&cascade_id,
&None,
&params.user_text,
"",
)
.await;
&state.mitm_store, &cascade_id, &None, &params.user_text, "",
).await;
state.mitm_store.drop_channel().await;
let resp = build_response_object(
ResponseData {
id: response_id,
@@ -666,52 +691,61 @@ async fn handle_responses_sync(
);
return Json(resp).into_response();
}
}
MitmEvent::ResponseComplete => {
if acc_text.is_empty() && acc_thinking.is_none() {
// Empty response — continue waiting
continue;
}
if acc_text.is_empty() && acc_thinking.is_some() {
// Thinking-only — LS needs to make a follow-up request.
// Reinstall channel and unblock gate.
let (new_tx, new_rx) = tokio::sync::mpsc::channel(64);
state.mitm_store.set_channel(new_tx).await;
state.mitm_store.clear_request_in_flight();
let _ = state.mitm_store.take_any_function_calls().await;
rx = new_rx;
debug!(
"Responses sync: thinking-only — new channel for follow-up, thinking_len={}",
acc_thinking.as_ref().map(|t| t.len()).unwrap_or(0)
);
continue;
}
let (usage, _) = usage_from_poll(
&state.mitm_store, &cascade_id, &None, &params.user_text, &acc_text,
).await;
state.mitm_store.drop_channel().await;
// Check for completed text response
if state.mitm_store.is_response_complete() {
let text = state
.mitm_store
.take_response_text()
.await
.unwrap_or_default();
let thinking = state.mitm_store.take_thinking_text().await;
let (usage, _) = usage_from_poll(
&state.mitm_store,
&cascade_id,
&None,
&params.user_text,
&text,
)
.await;
let mut output_items: Vec<serde_json::Value> = Vec::new();
if let Some(ref t) = acc_thinking {
output_items.push(build_reasoning_output(t));
}
let msg_id = format!("msg_{}", uuid::Uuid::new_v4().to_string().replace('-', ""));
output_items.push(build_message_output(&msg_id, &acc_text));
let mut output_items: Vec<serde_json::Value> = Vec::new();
if let Some(ref t) = thinking {
output_items.push(build_reasoning_output(t));
let resp = build_response_object(
ResponseData {
id: response_id,
model: model_name,
status: "completed",
created_at,
completed_at: Some(now_unix()),
output: output_items,
usage: Some(usage),
thinking_signature: None,
},
&params,
);
return Json(resp).into_response();
}
MitmEvent::UpstreamError(err) => {
state.mitm_store.drop_channel().await;
return upstream_err_response(&err);
}
let msg_id = format!("msg_{}", uuid::Uuid::new_v4().to_string().replace('-', ""));
output_items.push(build_message_output(&msg_id, &text));
let resp = build_response_object(
ResponseData {
id: response_id,
model: model_name,
status: "completed",
created_at,
completed_at: Some(now_unix()),
output: output_items,
usage: Some(usage),
thinking_signature: None,
},
&params,
);
return Json(resp).into_response();
}
tokio::time::sleep(tokio::time::Duration::from_millis(200)).await;
}
// Timeout — return proper error, not fake incomplete response
// Timeout
state.mitm_store.drop_channel().await;
return err_response(
StatusCode::GATEWAY_TIMEOUT,
format!("Timeout: no response from Google API after {timeout}s"),
@@ -835,6 +869,7 @@ async fn handle_responses_stream(
cascade_id: String,
timeout: u64,
params: RequestParams,
mitm_rx: Option<tokio::sync::mpsc::Receiver<crate::mitm::store::MitmEvent>>,
) -> axum::response::Response {
let stream = async_stream::stream! {
let msg_id = format!("msg_{}", uuid::Uuid::new_v4().to_string().replace('-', ""));
@@ -886,50 +921,170 @@ async fn handle_responses_stream(
let mut thinking_text: Option<String> = None;
let mut message_started = false;
let reasoning_id = format!("rs_{}", uuid::Uuid::new_v4().to_string().replace('-', ""));
let has_custom_tools = state.mitm_store.get_tools().await.is_some();
// Clear stale captured response and upstream errors
state.mitm_store.clear_response_async().await;
state.mitm_store.clear_upstream_error().await;
// Clear stale response (only if no pre-installed channel)
if mitm_rx.is_none() {
state.mitm_store.clear_response_async().await;
state.mitm_store.clear_upstream_error().await;
}
// ── MITM bypass mode (when custom tools are active) ──
// Skip LS entirely — read text, thinking, and tool calls directly from MitmStore.
if has_custom_tools {
// Channel-based pipeline: read events directly from MITM proxy.
// Channel is pre-installed before send_message to avoid race conditions.
if let Some(mut rx) = mitm_rx {
let mut last_thinking = String::new();
while start.elapsed().as_secs() < timeout {
// Check for upstream errors from MITM (Google API errors)
if let Some(err) = state.mitm_store.take_upstream_error().await {
let error_msg = super::util::upstream_error_message(&err);
let error_type = super::util::upstream_error_type(&err);
yield Ok(responses_sse_event(
"response.failed",
serde_json::json!({
"type": "response.failed",
"sequence_number": next_seq(),
"response": {
"id": &response_id,
"status": "failed",
"error": {
"type": error_type,
"message": error_msg,
"code": err.status,
},
},
}),
));
break;
}
while let Some(event) = tokio::time::timeout(
std::time::Duration::from_secs(timeout.saturating_sub(start.elapsed().as_secs())),
rx.recv(),
).await.ok().flatten() {
use crate::mitm::store::MitmEvent;
match event {
MitmEvent::ThinkingDelta(full_thinking) => {
if !thinking_emitted && full_thinking.len() > last_thinking.len() {
// First thinking text — emit reasoning output_item.added
if last_thinking.is_empty() {
yield Ok(responses_sse_event(
"response.output_item.added",
serde_json::json!({
"type": "response.output_item.added",
"sequence_number": next_seq(),
"output_index": 0,
"item": {
"id": &reasoning_id,
"type": "reasoning",
"summary": [],
},
}),
));
yield Ok(responses_sse_event(
"response.reasoning_summary_part.added",
serde_json::json!({
"type": "response.reasoning_summary_part.added",
"sequence_number": next_seq(),
"item_id": &reasoning_id,
"output_index": 0,
"summary_index": 0,
"part": { "type": "summary_text", "text": "" },
}),
));
}
let delta = &full_thinking[last_thinking.len()..];
if !delta.is_empty() {
yield Ok(responses_sse_event(
"response.reasoning_summary_text.delta",
serde_json::json!({
"type": "response.reasoning_summary_text.delta",
"sequence_number": next_seq(),
"item_id": &reasoning_id,
"output_index": 0,
"summary_index": 0,
"delta": delta,
}),
));
}
last_thinking = full_thinking;
}
}
MitmEvent::TextDelta(full_text) => {
if full_text.len() > last_text.len() {
// Finalize thinking if started but not done
if !thinking_emitted && !last_thinking.is_empty() {
thinking_emitted = true;
thinking_text = Some(last_thinking.clone());
yield Ok(responses_sse_event(
"response.reasoning_summary_text.done",
serde_json::json!({
"type": "response.reasoning_summary_text.done",
"sequence_number": next_seq(),
"item_id": &reasoning_id,
"output_index": 0,
"summary_index": 0,
"text": &last_thinking,
}),
));
yield Ok(responses_sse_event(
"response.reasoning_summary_part.done",
serde_json::json!({
"type": "response.reasoning_summary_part.done",
"sequence_number": next_seq(),
"item_id": &reasoning_id,
"output_index": 0,
"summary_index": 0,
"part": { "type": "summary_text", "text": &last_thinking },
}),
));
yield Ok(responses_sse_event(
"response.output_item.done",
serde_json::json!({
"type": "response.output_item.done",
"sequence_number": next_seq(),
"output_index": 0,
"item": {
"id": &reasoning_id,
"type": "reasoning",
"summary": [{
"type": "summary_text",
"text": &last_thinking,
}],
},
}),
));
}
// Check for function calls first
let captured = state.mitm_store.take_function_calls(&cascade_id).await;
if let Some(ref raw_calls) = captured {
let calls: Vec<_> = if let Some(max) = params.max_tool_calls {
raw_calls.iter().take(max as usize).collect()
} else {
raw_calls.iter().collect()
};
if !calls.is_empty() {
let msg_output_index: u32 = if thinking_emitted { 1 } else { 0 };
if !message_started {
message_started = true;
yield Ok(responses_sse_event(
"response.output_item.added",
serde_json::json!({
"type": "response.output_item.added",
"sequence_number": next_seq(),
"output_index": msg_output_index,
"item": build_message_output_in_progress(&msg_id),
}),
));
yield Ok(responses_sse_event(
"response.content_part.added",
serde_json::json!({
"type": "response.content_part.added",
"sequence_number": next_seq(),
"output_index": msg_output_index,
"content_index": CONTENT_IDX,
"part": {
"type": "output_text",
"text": "",
"annotations": [],
}
}),
));
}
let delta = &full_text[last_text.len()..];
if !delta.is_empty() {
let msg_output_index: u32 = if thinking_emitted { 1 } else { 0 };
yield Ok(responses_sse_event(
"response.output_text.delta",
serde_json::json!({
"type": "response.output_text.delta",
"sequence_number": next_seq(),
"item_id": &msg_id,
"output_index": msg_output_index,
"content_index": CONTENT_IDX,
"delta": delta,
}),
));
last_text = full_text;
}
}
}
MitmEvent::FunctionCall(raw_calls) => {
let calls: Vec<_> = if let Some(max) = params.max_tool_calls {
raw_calls.iter().take(max as usize).collect()
} else {
raw_calls.iter().collect()
};
let msg_output_index: u32 = if thinking_emitted { 1 } else { 0 };
for (i, fc) in calls.iter().enumerate() {
let call_id = format!(
@@ -1011,194 +1166,71 @@ async fn handle_responses_stream(
"response": response_to_json(&final_resp),
}),
));
state.mitm_store.drop_channel().await;
return;
}
}
// Stream thinking text in real-time
if !thinking_emitted {
if let Some(thinking) = state.mitm_store.peek_thinking_text().await {
if !thinking.is_empty() && thinking != last_thinking {
// First thinking text — emit reasoning output_item.added
if last_thinking.is_empty() {
yield Ok(responses_sse_event(
"response.output_item.added",
serde_json::json!({
"type": "response.output_item.added",
"sequence_number": next_seq(),
"output_index": 0,
"item": {
"id": &reasoning_id,
"type": "reasoning",
"summary": [],
},
}),
));
yield Ok(responses_sse_event(
"response.reasoning_summary_part.added",
serde_json::json!({
"type": "response.reasoning_summary_part.added",
"sequence_number": next_seq(),
"item_id": &reasoning_id,
"output_index": 0,
"summary_index": 0,
"part": { "type": "summary_text", "text": "" },
}),
));
MitmEvent::ResponseComplete => {
if !last_text.is_empty() {
let msg_idx: u32 = if thinking_emitted { 1 } else { 0 };
let (usage, _) = usage_from_poll(
&state.mitm_store, &cascade_id, &None,
&params.user_text, &last_text,
).await;
let tc = thinking_text.clone();
for evt in completion_events(
&response_id, &model_name, &msg_id, &reasoning_id,
msg_idx, CONTENT_IDX, &last_text, usage,
created_at, &seq, &params, None, tc,
) {
yield Ok(evt);
}
// Delta of new thinking text
let delta = if thinking.len() > last_thinking.len()
&& thinking.starts_with(&*last_thinking)
{
thinking[last_thinking.len()..].to_string()
} else {
thinking.clone()
};
if !delta.is_empty() {
yield Ok(responses_sse_event(
"response.reasoning_summary_text.delta",
serde_json::json!({
"type": "response.reasoning_summary_text.delta",
"sequence_number": next_seq(),
"item_id": &reasoning_id,
"output_index": 0,
"summary_index": 0,
"delta": &delta,
}),
));
}
last_thinking = thinking;
state.mitm_store.drop_channel().await;
return;
} else if !last_thinking.is_empty() {
// Thinking-only response — LS needs follow-up API calls.
// Create a new channel and unblock the gate.
let (new_tx, new_rx) = tokio::sync::mpsc::channel(64);
state.mitm_store.set_channel(new_tx).await;
state.mitm_store.clear_request_in_flight();
let _ = state.mitm_store.take_any_function_calls().await;
rx = new_rx;
debug!(
"Responses stream: thinking-only — new channel for follow-up, thinking_len={}",
last_thinking.len()
);
}
// ResponseComplete with no text and no thinking — continue waiting
}
}
// Stream response text
if let Some(text) = state.mitm_store.peek_response_text().await {
if !text.is_empty() && text != last_text {
// Finalize thinking if started but not done
if !thinking_emitted && !last_thinking.is_empty() {
thinking_emitted = true;
thinking_text = Some(last_thinking.clone());
yield Ok(responses_sse_event(
"response.reasoning_summary_text.done",
serde_json::json!({
"type": "response.reasoning_summary_text.done",
"sequence_number": next_seq(),
"item_id": &reasoning_id,
"output_index": 0,
"summary_index": 0,
"text": &last_thinking,
}),
));
yield Ok(responses_sse_event(
"response.reasoning_summary_part.done",
serde_json::json!({
"type": "response.reasoning_summary_part.done",
"sequence_number": next_seq(),
"item_id": &reasoning_id,
"output_index": 0,
"summary_index": 0,
"part": { "type": "summary_text", "text": &last_thinking },
}),
));
yield Ok(responses_sse_event(
"response.output_item.done",
serde_json::json!({
"type": "response.output_item.done",
"sequence_number": next_seq(),
"output_index": 0,
"item": {
"id": &reasoning_id,
"type": "reasoning",
"summary": [{
"type": "summary_text",
"text": &last_thinking,
}],
MitmEvent::UpstreamError(err) => {
let error_msg = super::util::upstream_error_message(&err);
let error_type = super::util::upstream_error_type(&err);
yield Ok(responses_sse_event(
"response.failed",
serde_json::json!({
"type": "response.failed",
"sequence_number": next_seq(),
"response": {
"id": &response_id,
"status": "failed",
"error": {
"type": error_type,
"message": error_msg,
"code": err.status,
},
}),
));
}
let msg_output_index: u32 = if thinking_emitted { 1 } else { 0 };
if !message_started {
message_started = true;
yield Ok(responses_sse_event(
"response.output_item.added",
serde_json::json!({
"type": "response.output_item.added",
"sequence_number": next_seq(),
"output_index": msg_output_index,
"item": build_message_output_in_progress(&msg_id),
}),
));
yield Ok(responses_sse_event(
"response.content_part.added",
serde_json::json!({
"type": "response.content_part.added",
"sequence_number": next_seq(),
"output_index": msg_output_index,
"content_index": CONTENT_IDX,
"part": {
"type": "output_text",
"text": "",
"annotations": [],
}
}),
));
}
let new_content = if text.len() > last_text.len()
&& text.starts_with(&*last_text)
{
text[last_text.len()..].to_string()
} else {
text.clone()
};
if !new_content.is_empty() {
yield Ok(responses_sse_event(
"response.output_text.delta",
serde_json::json!({
"type": "response.output_text.delta",
"sequence_number": next_seq(),
"item_id": &msg_id,
"output_index": msg_output_index,
"content_index": CONTENT_IDX,
"delta": &new_content,
}),
));
last_text = text;
}
}
// Check if response is complete
if state.mitm_store.is_response_complete() && !last_text.is_empty() {
let msg_idx: u32 = if thinking_emitted { 1 } else { 0 };
let (usage, _) = usage_from_poll(
&state.mitm_store, &cascade_id, &None,
&params.user_text, &last_text,
).await;
let tc = thinking_text.clone();
for evt in completion_events(
&response_id, &model_name, &msg_id, &reasoning_id,
msg_idx, CONTENT_IDX, &last_text, usage,
created_at, &seq, &params, None, tc,
) {
yield Ok(evt);
}
},
}),
));
state.mitm_store.drop_channel().await;
return;
}
MitmEvent::Usage(_) | MitmEvent::Grounding(_) => {
// Usage/grounding stored by proxy, consumed via usage_from_poll
}
}
// Poll interval
let poll_ms: u64 = rand::thread_rng().gen_range(150..300);
tokio::time::sleep(tokio::time::Duration::from_millis(poll_ms)).await;
}
// Timeout in bypass mode — emit error, not fake incomplete
// Timeout in channel mode
state.mitm_store.drop_channel().await;
yield Ok(responses_sse_event(
"response.failed",
serde_json::json!({

View File

@@ -163,11 +163,13 @@ async fn do_search(state: Arc<AppState>, body: SearchRequest) -> axum::response:
// Set active cascade for MITM correlation
state.mitm_store.set_active_cascade(&cascade_id).await;
// Store real search prompt for MITM injection — LS gets a dummy prompt
state.mitm_store.set_pending_user_text(search_prompt.clone()).await;
// Send the search message
if let Err(e) = state
.backend
.send_message(&cascade_id, &search_prompt, model.model_enum)
.send_message(&cascade_id, ".", model.model_enum)
.await
{
state.mitm_store.clear_active_cascade().await;

View File

@@ -538,9 +538,10 @@ async fn handle_http_over_tls(
}
};
// Generation tracking for store write guards
let mut won_gate = false;
let mut conn_generation = store.current_generation();
// Channel-based event pipeline: grab the channel sender if one exists.
// If the API handler set up a channel, we send events through it.
// Otherwise, we fall back to legacy store writes (search endpoint, etc.).
let mut event_tx: Option<tokio::sync::mpsc::Sender<super::store::MitmEvent>> = None;
// Log LLM calls at info, everything else at debug
if req_path.contains("streamGenerateContent") {
@@ -558,7 +559,7 @@ async fn handle_http_over_tls(
// When custom tools are active, only the FIRST request wins the
// atomic compare_exchange. All others get fake STOP responses.
let has_tools = store.get_tools().await.is_some();
won_gate = if has_tools {
if has_tools {
if !store.try_mark_request_in_flight() {
info!("MITM: blocking LS request — another request already in-flight");
let fake_response = "HTTP/1.1 200 OK\r\n\
@@ -575,13 +576,11 @@ async fn handle_http_over_tls(
let _ = client.flush().await;
continue;
}
true
} else {
false
};
// Snapshot the generation at gate-win time. If it changes later,
// another completions turn started and our data is stale.
conn_generation = store.current_generation();
// Grab the channel sender — the API handler installed it before
// sending the LS message. If it's gone, we still proceed but
// fall back to legacy store writes.
event_tx = store.take_channel().await;
}
// ── Request modification ─────────────────────────────────────
// Dechunk body → check if agent request → modify → rechunk
@@ -603,12 +602,14 @@ async fn handle_http_over_tls(
let generation_params = store.get_generation_params().await;
let pending_image = store.take_pending_image().await;
let tool_rounds = store.get_tool_rounds().await;
let pending_user_text = store.take_pending_user_text().await;
let tool_ctx = if tools.is_some()
|| !pending_results.is_empty()
|| !tool_rounds.is_empty()
|| generation_params.is_some()
|| pending_image.is_some()
|| pending_user_text.is_some()
{
Some(super::modify::ToolContext {
tools,
@@ -618,6 +619,7 @@ async fn handle_http_over_tls(
generation_params,
pending_image,
tool_rounds,
pending_user_text,
})
} else {
None
@@ -794,14 +796,18 @@ async fn handle_http_over_tls(
})
.unwrap_or((None, None));
store
.set_upstream_error(super::store::UpstreamError {
status: http_status,
body: body_str,
message,
error_status,
})
.await;
let upstream_err = super::store::UpstreamError {
status: http_status,
body: body_str,
message,
error_status,
};
// Send through channel if available, otherwise store for legacy consumers
if let Some(ref tx) = event_tx {
let _ = tx.send(super::store::MitmEvent::UpstreamError(upstream_err)).await;
} else {
store.set_upstream_error(upstream_err).await;
}
}
// Save body for usage parsing
@@ -812,46 +818,59 @@ async fn handle_http_over_tls(
let body = String::from_utf8_lossy(&header_buf[hdr_end..]);
parse_streaming_chunk(&body, &mut streaming_acc);
// Only write to store if our generation is still current.
// If another completions turn started, our data is stale.
let gen_valid = !won_gate || store.current_generation() == conn_generation;
if gen_valid {
// Store captured function calls (drain to avoid re-storing on next chunk)
// Send events through channel if available, otherwise use legacy store
if let Some(ref tx) = event_tx {
// Function calls → channel event
if !streaming_acc.function_calls.is_empty() {
let calls: Vec<_> =
streaming_acc.function_calls.drain(..).collect();
for fc in &calls {
store
.record_function_call(cascade_hint.as_deref(), fc.clone())
.await;
}
let calls: Vec<_> = streaming_acc.function_calls.drain(..).collect();
store.set_last_function_calls(calls.clone()).await;
info!(
"MITM: stored {} function call(s) from initial body",
calls.len()
);
}
// Capture response + thinking text + grounding into MitmStore
if !streaming_acc.response_text.is_empty() {
store.set_response_text(&streaming_acc.response_text).await;
store.record_function_call(cascade_hint.as_deref(), calls[0].clone()).await;
info!("MITM: sending {} function call(s) via channel (initial body)", calls.len());
let _ = tx.send(super::store::MitmEvent::FunctionCall(calls)).await;
}
// Thinking delta → channel event
if !streaming_acc.thinking_text.is_empty() {
store.set_thinking_text(&streaming_acc.thinking_text).await;
let _ = tx.send(super::store::MitmEvent::ThinkingDelta(
streaming_acc.thinking_text.clone(),
)).await;
}
// Text delta → channel event
if !streaming_acc.response_text.is_empty() {
let _ = tx.send(super::store::MitmEvent::TextDelta(
streaming_acc.response_text.clone(),
)).await;
}
// Grounding → channel event
if let Some(ref gm) = streaming_acc.grounding_metadata {
store.set_grounding(gm.clone()).await;
let _ = tx.send(super::store::MitmEvent::Grounding(gm.clone())).await;
}
// Response complete → channel event
if streaming_acc.is_complete {
info!(
response_text_len = streaming_acc.response_text.len(),
thinking_text_len = streaming_acc.thinking_text.len(),
"MITM: response complete (initial body) — marking store"
"MITM: response complete (initial body) — sending via channel"
);
store.mark_response_complete();
let _ = tx.send(super::store::MitmEvent::ResponseComplete).await;
streaming_acc.is_complete = false; // prevent duplicate sends
}
} else {
// Legacy path: store writes for non-channel consumers (search, etc.)
if !streaming_acc.function_calls.is_empty() {
let calls: Vec<_> = streaming_acc.function_calls.drain(..).collect();
for fc in &calls {
store.record_function_call(cascade_hint.as_deref(), fc.clone()).await;
}
store.set_last_function_calls(calls.clone()).await;
info!("MITM: stored {} function call(s) from initial body", calls.len());
}
if !streaming_acc.response_text.is_empty() {
store.set_response_text(&streaming_acc.response_text).await;
}
if let Some(ref gm) = streaming_acc.grounding_metadata {
store.set_grounding(gm.clone()).await;
}
} else if streaming_acc.is_complete {
debug!("MITM: skipping store write — generation stale (initial body)");
}
}
@@ -890,45 +909,60 @@ async fn handle_http_over_tls(
let s = String::from_utf8_lossy(chunk);
parse_streaming_chunk(&s, &mut streaming_acc);
// Only write to store if our generation is still current.
let gen_valid = !won_gate || store.current_generation() == conn_generation;
if gen_valid {
// Store captured function calls (drain to avoid re-storing on next chunk)
// Send events through channel if available, otherwise use legacy store
if let Some(ref tx) = event_tx {
// Function calls → channel event
if !streaming_acc.function_calls.is_empty() {
let calls: Vec<_> = streaming_acc.function_calls.drain(..).collect();
for fc in &calls {
store
.record_function_call(cascade_hint.as_deref(), fc.clone())
.await;
}
store.set_last_function_calls(calls.clone()).await;
info!(
"MITM: stored {} function call(s) from body chunk",
calls.len()
);
}
// Capture response + thinking text + grounding into MitmStore
if !streaming_acc.response_text.is_empty() {
store.set_response_text(&streaming_acc.response_text).await;
store.record_function_call(cascade_hint.as_deref(), calls[0].clone()).await;
info!("MITM: sending {} function call(s) via channel (body chunk)", calls.len());
let _ = tx.send(super::store::MitmEvent::FunctionCall(calls)).await;
}
// Thinking delta → channel event (send accumulated, handler tracks last len)
if !streaming_acc.thinking_text.is_empty() {
store.set_thinking_text(&streaming_acc.thinking_text).await;
let _ = tx.send(super::store::MitmEvent::ThinkingDelta(
streaming_acc.thinking_text.clone(),
)).await;
}
// Text delta → channel event (send accumulated, handler tracks last len)
if !streaming_acc.response_text.is_empty() {
let _ = tx.send(super::store::MitmEvent::TextDelta(
streaming_acc.response_text.clone(),
)).await;
}
// Grounding → channel event
if let Some(ref gm) = streaming_acc.grounding_metadata {
store.set_grounding(gm.clone()).await;
let _ = tx.send(super::store::MitmEvent::Grounding(gm.clone())).await;
}
// Response complete → channel event
if streaming_acc.is_complete {
info!(
response_text_len = streaming_acc.response_text.len(),
thinking_text_len = streaming_acc.thinking_text.len(),
function_calls = streaming_acc.function_calls.len(),
"MITM: response complete — marking store"
"MITM: response complete — sending via channel"
);
store.mark_response_complete();
let _ = tx.send(super::store::MitmEvent::ResponseComplete).await;
streaming_acc.is_complete = false; // prevent duplicate sends
}
} else {
// Legacy path: store writes for non-channel consumers
if !streaming_acc.function_calls.is_empty() {
let calls: Vec<_> = streaming_acc.function_calls.drain(..).collect();
for fc in &calls {
store.record_function_call(cascade_hint.as_deref(), fc.clone()).await;
}
store.set_last_function_calls(calls.clone()).await;
info!("MITM: stored {} function call(s) from body chunk", calls.len());
}
if !streaming_acc.response_text.is_empty() {
store.set_response_text(&streaming_acc.response_text).await;
}
if let Some(ref gm) = streaming_acc.grounding_metadata {
store.set_grounding(gm.clone()).await;
}
} else if streaming_acc.is_complete {
debug!("MITM: skipping store write — generation stale (body chunk)");
}
}
@@ -969,9 +1003,11 @@ async fn handle_http_over_tls(
store.set_grounding(gm.clone()).await;
}
if streaming_acc.is_complete || streaming_acc.output_tokens > 0 {
// Function calls are stored immediately when detected (above),
// so no need to store them again here.
let usage = streaming_acc.into_usage();
// Send usage through channel if available
if let Some(ref tx) = event_tx {
let _ = tx.send(super::store::MitmEvent::Usage(usage.clone())).await;
}
store.record_usage(cascade_hint.as_deref(), usage).await;
}
} else if !response_body_buf.is_empty() {

View File

@@ -1,12 +1,14 @@
//! Shared store for intercepted API usage data.
//!
//! The MITM proxy writes usage data here; the API handlers read from it.
//! When custom tools are active, the MITM proxy sends real-time events
//! through a channel instead of writing to shared state.
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use tokio::sync::RwLock;
use tokio::sync::{mpsc, RwLock};
use tracing::{debug, info};
/// Token usage from an intercepted API response.
@@ -126,6 +128,29 @@ pub struct GenerationParams {
pub google_search: bool,
}
// ─── Channel-based event pipeline ────────────────────────────────────────────
/// Events sent from the MITM proxy to API handlers through a per-request channel.
/// Replaces the old polling-based approach (shared atomics + RwLocks) with
/// instant, race-free delivery.
#[derive(Debug, Clone)]
pub enum MitmEvent {
/// Incremental thinking/reasoning text from the model.
ThinkingDelta(String),
/// Incremental response text from the model.
TextDelta(String),
/// Model requested function call(s).
FunctionCall(Vec<CapturedFunctionCall>),
/// Response streaming is complete (finishReason received).
ResponseComplete,
/// Google API returned an error.
UpstreamError(UpstreamError),
/// Grounding metadata (search results) from the response.
Grounding(serde_json::Value),
/// Token usage data from the response.
Usage(ApiUsage),
}
/// Thread-safe store for intercepted data.
///
/// Keyed by a unique request ID that we can correlate with cascade operations.
@@ -138,20 +163,17 @@ pub struct MitmStore {
stats: Arc<RwLock<MitmStats>>,
/// Pending function calls captured from Google responses.
/// Key: cascade hint or "_latest". Value: list of function calls.
/// Used by the non-tool LS path (normal sync responses).
pending_function_calls: Arc<RwLock<HashMap<String, Vec<CapturedFunctionCall>>>>,
/// Simple flag: set when a functionCall is captured, cleared when consumed.
/// Used to block follow-up requests regardless of cascade identification.
has_active_function_call: Arc<AtomicBool>,
/// Persistent flag: set when a function call is captured, cleared ONLY when
/// a tool result is submitted. Prevents the LS from making follow-up API
/// calls during the entire tool execution cycle.
awaiting_tool_result: Arc<AtomicBool>,
/// Set when the MITM forwards the first LLM request with custom tools.
/// Blocks ALL subsequent LS requests until the API handler clears it.
request_in_flight: Arc<AtomicBool>,
/// Generation counter — incremented each time a new completions turn starts.
/// Used to discard stale data from leaked LS connections.
request_generation: Arc<AtomicU64>,
// ── Channel-based event pipeline (replaces old polling) ──────────────
/// Active channel sender for the current tool-path request.
/// When present, the MITM proxy sends events through this instead of
/// writing to shared state. The channel's existence = request in-flight.
active_channel: Arc<RwLock<Option<mpsc::Sender<MitmEvent>>>>,
// ── Tool call support ────────────────────────────────────────────────
/// Active tool definitions (Gemini format) for MITM injection.
@@ -173,14 +195,9 @@ pub struct MitmStore {
/// Used by the MITM proxy to correlate intercepted traffic to cascades.
active_cascade_id: Arc<RwLock<Option<String>>>,
// ── Direct response capture (bypasses LS) ────────────────────────────
/// Captured response text from MITM when custom tools are active.
/// The completions/responses handler reads this instead of polling LS steps.
// ── Legacy direct response capture (used by search.rs) ───────────────
/// Captured response text from MITM. Used as fallback by search endpoint.
captured_response_text: Arc<RwLock<Option<String>>>,
/// Captured thinking/reasoning text from MITM (for real-time streaming).
captured_thinking_text: Arc<RwLock<Option<String>>>,
/// Whether the captured response is complete (finishReason received).
response_complete: Arc<AtomicBool>,
// ── Generation parameters for MITM injection ─────────────────────────
/// Client-specified sampling parameters to inject into Google API requests.
@@ -194,9 +211,14 @@ pub struct MitmStore {
/// Image to inject into the next Google API request via MITM.
pending_image: Arc<RwLock<Option<PendingImage>>>,
// ── Upstream error capture ───────────────────────────────────────────
// ── Upstream error capture (legacy, used when no channel) ────────────
/// Error from Google's API, captured by MITM for forwarding to client.
upstream_error: Arc<RwLock<Option<UpstreamError>>>,
// ── Standard LS input: real user text for MITM injection ─────────────
/// The real user text to inject into the Google API request.
/// API handlers store this before sending a dummy prompt to the LS.
pending_user_text: Arc<RwLock<Option<String>>>,
}
/// Aggregate statistics across all intercepted traffic.
@@ -229,10 +251,8 @@ impl MitmStore {
latest_usage: Arc::new(RwLock::new(HashMap::new())),
stats: Arc::new(RwLock::new(MitmStats::default())),
pending_function_calls: Arc::new(RwLock::new(HashMap::new())),
has_active_function_call: Arc::new(AtomicBool::new(false)),
awaiting_tool_result: Arc::new(AtomicBool::new(false)),
request_in_flight: Arc::new(AtomicBool::new(false)),
request_generation: Arc::new(AtomicU64::new(0)),
active_channel: Arc::new(RwLock::new(None)),
active_tools: Arc::new(RwLock::new(None)),
active_tool_config: Arc::new(RwLock::new(None)),
pending_tool_results: Arc::new(RwLock::new(Vec::new())),
@@ -241,12 +261,11 @@ impl MitmStore {
tool_rounds: Arc::new(RwLock::new(Vec::new())),
active_cascade_id: Arc::new(RwLock::new(None)),
captured_response_text: Arc::new(RwLock::new(None)),
captured_thinking_text: Arc::new(RwLock::new(None)),
response_complete: Arc::new(AtomicBool::new(false)),
generation_params: Arc::new(RwLock::new(None)),
captured_grounding: Arc::new(RwLock::new(None)),
pending_image: Arc::new(RwLock::new(None)),
upstream_error: Arc::new(RwLock::new(None)),
pending_user_text: Arc::new(RwLock::new(None)),
}
}
@@ -381,30 +400,6 @@ impl MitmStore {
);
let mut pending = self.pending_function_calls.write().await;
pending.entry(key).or_default().push(fc);
self.has_active_function_call.store(true, Ordering::SeqCst);
self.awaiting_tool_result.store(true, Ordering::SeqCst);
}
/// Check if there's an active (unclaimed) function call.
pub fn has_active_function_call(&self) -> bool {
self.has_active_function_call.load(Ordering::SeqCst)
}
/// Force-clear the active function call flag (used to reset stale state).
pub fn clear_active_function_call(&self) {
self.has_active_function_call.store(false, Ordering::SeqCst);
}
/// Check if we're awaiting a tool result (blocks LS follow-up requests).
/// This persists across function call consumption — only cleared when
/// actual tool results are submitted.
pub fn is_awaiting_tool_result(&self) -> bool {
self.awaiting_tool_result.load(Ordering::SeqCst)
}
/// Clear the awaiting-tool-result flag (called when tool results arrive).
pub fn clear_awaiting_tool_result(&self) {
self.awaiting_tool_result.store(false, Ordering::SeqCst);
}
/// Take pending function calls for a specific cascade.
@@ -417,7 +412,6 @@ impl MitmStore {
// 1. Exact cascade match
if let Some(result) = pending.remove(cascade_id) {
self.has_active_function_call.store(false, Ordering::SeqCst);
return Some(result);
}
@@ -425,7 +419,6 @@ impl MitmStore {
if let Some(active) = self.active_cascade_id.read().await.as_ref() {
if active != cascade_id {
if let Some(result) = pending.remove(active.as_str()) {
self.has_active_function_call.store(false, Ordering::SeqCst);
return Some(result);
}
}
@@ -433,17 +426,12 @@ impl MitmStore {
// 3. Fallback to _latest
if let Some(result) = pending.remove("_latest") {
self.has_active_function_call.store(false, Ordering::SeqCst);
return Some(result);
}
// 4. Last resort: any key
if let Some(key) = pending.keys().next().cloned() {
let result = pending.remove(&key);
if result.is_some() {
self.has_active_function_call.store(false, Ordering::SeqCst);
}
return result;
return pending.remove(&key);
}
None
@@ -455,19 +443,40 @@ impl MitmStore {
let mut pending = self.pending_function_calls.write().await;
let result = pending.remove("_latest");
if result.is_some() {
self.has_active_function_call.store(false, Ordering::SeqCst);
return result;
}
if let Some(key) = pending.keys().next().cloned() {
let result = pending.remove(&key);
if result.is_some() {
self.has_active_function_call.store(false, Ordering::SeqCst);
}
return result;
return pending.remove(&key);
}
None
}
// ── Channel-based event pipeline ─────────────────────────────────────
/// Install a channel sender for the current tool-path request.
/// The MITM proxy will send events through this channel.
pub async fn set_channel(&self, tx: mpsc::Sender<MitmEvent>) {
*self.active_channel.write().await = Some(tx);
// NOTE: Do NOT set request_in_flight here. The MITM proxy's
// try_mark_request_in_flight() is the sole setter — setting it
// here causes compare_exchange(false,true) to always fail,
// blocking every real LS request.
}
/// Take the active channel sender (used by MITM proxy to grab it).
/// Returns None if no channel is active.
pub async fn take_channel(&self) -> Option<mpsc::Sender<MitmEvent>> {
self.active_channel.write().await.take()
}
/// Drop the active channel and clear in-flight state.
/// Called when the API handler is done with the current request.
pub async fn drop_channel(&self) {
*self.active_channel.write().await = None;
self.request_in_flight.store(false, Ordering::SeqCst);
}
// ── Tool context methods ─────────────────────────────────────────────
/// Set active tool definitions (already in Gemini format).
@@ -546,10 +555,10 @@ impl MitmStore {
self.tool_rounds.read().await.clone()
}
// ── Direct response capture (bypass LS) ──────────────────────────────
// ── Legacy direct response capture (search.rs fallback) ──────────────
/// Set (replace) the captured response text.
/// Used by MITM proxy for non-channel path (search endpoint fallback).
pub async fn set_response_text(&self, text: &str) {
*self.captured_response_text.write().await = Some(text.to_string());
}
@@ -559,28 +568,12 @@ impl MitmStore {
self.captured_response_text.write().await.take()
}
/// Peek at the captured response text without consuming it.
pub async fn peek_response_text(&self) -> Option<String> {
self.captured_response_text.read().await.clone()
}
/// Mark the response as complete.
pub fn mark_response_complete(&self) {
self.response_complete.store(true, Ordering::SeqCst);
}
/// Check if the response is complete.
pub fn is_response_complete(&self) -> bool {
self.response_complete.load(Ordering::SeqCst)
}
/// Async version of clear_response. Bumps generation counter.
/// Clear stale state between requests.
/// Drops any active channel and clears in-flight flags.
pub async fn clear_response_async(&self) {
self.response_complete.store(false, Ordering::SeqCst);
self.request_in_flight.store(false, Ordering::SeqCst);
self.request_generation.fetch_add(1, Ordering::SeqCst);
*self.active_channel.write().await = None;
*self.captured_response_text.write().await = None;
*self.captured_thinking_text.write().await = None;
}
/// Atomically try to mark request as in-flight.
@@ -593,6 +586,7 @@ impl MitmStore {
}
/// Check if a request is currently in-flight.
#[allow(dead_code)]
pub fn is_request_in_flight(&self) -> bool {
self.request_in_flight.load(Ordering::SeqCst)
}
@@ -602,38 +596,6 @@ impl MitmStore {
self.request_in_flight.store(false, Ordering::SeqCst);
}
/// Reset response_complete so we can wait for the next response.
pub fn clear_response_complete(&self) {
self.response_complete.store(false, Ordering::SeqCst);
}
/// Get current generation number.
pub fn current_generation(&self) -> u64 {
self.request_generation.load(Ordering::SeqCst)
}
/// Bump generation counter (invalidates all pending data from old generation).
pub fn bump_generation(&self) -> u64 {
self.request_generation.fetch_add(1, Ordering::SeqCst) + 1
}
// ── Thinking text capture ────────────────────────────────────────────
/// Set (replace) the captured thinking text.
pub async fn set_thinking_text(&self, text: &str) {
*self.captured_thinking_text.write().await = Some(text.to_string());
}
/// Peek at the captured thinking text without consuming it.
pub async fn peek_thinking_text(&self) -> Option<String> {
self.captured_thinking_text.read().await.clone()
}
/// Take the captured thinking text (consumes it).
pub async fn take_thinking_text(&self) -> Option<String> {
self.captured_thinking_text.write().await.take()
}
// ── Cascade correlation ──────────────────────────────────────────────
/// Set the active cascade ID (called by API handlers before sending a message).
@@ -718,4 +680,18 @@ impl MitmStore {
pub async fn clear_upstream_error(&self) {
*self.upstream_error.write().await = None;
}
// ── Pending user text for MITM injection ─────────────────────────────
/// Store the real user text for MITM injection.
/// Called by API handlers before sending a dummy prompt to the LS.
pub async fn set_pending_user_text(&self, text: String) {
*self.pending_user_text.write().await = Some(text);
}
/// Take (consume) the pending user text.
/// Called by the MITM proxy when building ToolContext.
pub async fn take_pending_user_text(&self) -> Option<String> {
self.pending_user_text.write().await.take()
}
}