refactor: endpoint parity and proxy improvements

Mixed changes from recent sessions: endpoint feature parity
improvements, proxy bug fixes, and store cleanup.
This commit is contained in:
Nikketryhard
2026-02-16 21:47:00 -06:00
parent 86675fd960
commit 637fbc0e54
5 changed files with 763 additions and 692 deletions

View File

@@ -223,7 +223,6 @@ pub(crate) async fn handle_completions(
} else { } else {
state.mitm_store.clear_tools().await; state.mitm_store.clear_tools().await;
} }
state.mitm_store.clear_active_function_call();
// ── Extract tool results from messages for MITM injection ────────── // ── Extract tool results from messages for MITM injection ──────────
// When OpenCode sends back tool results, the messages array contains: // When OpenCode sends back tool results, the messages array contains:
@@ -340,7 +339,6 @@ pub(crate) async fn handle_completions(
state.mitm_store.set_last_function_calls(last_round.calls.clone()).await; state.mitm_store.set_last_function_calls(last_round.calls.clone()).await;
} }
state.mitm_store.set_tool_rounds(rounds).await; state.mitm_store.set_tool_rounds(rounds).await;
state.mitm_store.clear_awaiting_tool_result();
} }
} }
@@ -441,6 +439,8 @@ pub(crate) async fn handle_completions(
// Send message on primary cascade // Send message on primary cascade
state.mitm_store.set_active_cascade(&cascade_id).await; state.mitm_store.set_active_cascade(&cascade_id).await;
// Store real user text for MITM injection — LS gets a dummy prompt
state.mitm_store.set_pending_user_text(user_text.clone()).await;
// Store image for MITM injection (LS doesn't forward images to Google API) // Store image for MITM injection (LS doesn't forward images to Google API)
if let Some(ref img) = image { if let Some(ref img) = image {
use base64::Engine; use base64::Engine;
@@ -452,9 +452,25 @@ pub(crate) async fn handle_completions(
}) })
.await; .await;
} }
// Pre-flight: install channel BEFORE send_message so the MITM proxy
// can grab it when the LS fires its API call.
// Only for streaming — sync paths use poll_for_response (legacy store).
let has_custom_tools = state.mitm_store.get_tools().await.is_some();
let mitm_rx = if has_custom_tools && body.stream {
state.mitm_store.clear_response_async().await;
state.mitm_store.clear_upstream_error().await;
let _ = state.mitm_store.take_any_function_calls().await;
let (tx, rx) = tokio::sync::mpsc::channel(64);
state.mitm_store.set_channel(tx).await;
Some(rx)
} else {
None
};
match state match state
.backend .backend
.send_message_with_image(&cascade_id, &user_text, model.model_enum, image.as_ref()) .send_message_with_image(&cascade_id, ".", model.model_enum, image.as_ref())
.await .await
{ {
Ok((200, _)) => { Ok((200, _)) => {
@@ -465,6 +481,7 @@ pub(crate) async fn handle_completions(
}); });
} }
Ok((status, _)) => { Ok((status, _)) => {
state.mitm_store.drop_channel().await;
return err_response( return err_response(
StatusCode::BAD_GATEWAY, StatusCode::BAD_GATEWAY,
format!("Backend returned {status}"), format!("Backend returned {status}"),
@@ -472,6 +489,7 @@ pub(crate) async fn handle_completions(
); );
} }
Err(e) => { Err(e) => {
state.mitm_store.drop_channel().await;
return err_response( return err_response(
StatusCode::BAD_GATEWAY, StatusCode::BAD_GATEWAY,
format!("Send failed: {e}"), format!("Send failed: {e}"),
@@ -498,6 +516,7 @@ pub(crate) async fn handle_completions(
cascade_id, cascade_id,
body.timeout, body.timeout,
include_usage, include_usage,
mitm_rx,
) )
.await .await
} else if n <= 1 { } else if n <= 1 {
@@ -518,7 +537,7 @@ pub(crate) async fn handle_completions(
// Send the same message on each extra cascade // Send the same message on each extra cascade
match state match state
.backend .backend
.send_message_with_image(&cid, &user_text, model.model_enum, image.as_ref()) .send_message_with_image(&cid, ".", model.model_enum, image.as_ref())
.await .await
{ {
Ok((200, _)) => { Ok((200, _)) => {
@@ -635,18 +654,18 @@ async fn chat_completions_stream(
cascade_id: String, cascade_id: String,
timeout: u64, timeout: u64,
include_usage: bool, include_usage: bool,
mitm_rx: Option<tokio::sync::mpsc::Receiver<crate::mitm::store::MitmEvent>>,
) -> axum::response::Response { ) -> axum::response::Response {
let stream = async_stream::stream! { let stream = async_stream::stream! {
let start = std::time::Instant::now(); let start = std::time::Instant::now();
let mut last_text = String::new(); let mut last_text = String::new();
let has_custom_tools = state.mitm_store.get_tools().await.is_some(); let has_custom_tools = mitm_rx.is_some();
// Clear ALL stale state from previous requests if !has_custom_tools {
state.mitm_store.clear_response_async().await; state.mitm_store.clear_response_async().await;
state.mitm_store.clear_upstream_error().await; state.mitm_store.clear_upstream_error().await;
state.mitm_store.clear_active_function_call(); let _ = state.mitm_store.take_any_function_calls().await;
// Drain any stale function calls from previous requests }
let _ = state.mitm_store.take_any_function_calls().await;
// Initial role chunk // Initial role chunk
yield Ok::<_, std::convert::Infallible>(Event::default().data(chunk_json( yield Ok::<_, std::convert::Infallible>(Event::default().data(chunk_json(
@@ -659,7 +678,6 @@ async fn chat_completions_stream(
let mut last_thinking_len: usize = 0; let mut last_thinking_len: usize = 0;
let mut complete_polls: u32 = 0; let mut complete_polls: u32 = 0;
let mut did_unblock_ls = false; // Prevents infinite unblock loops let mut did_unblock_ls = false; // Prevents infinite unblock loops
let mut my_generation = state.mitm_store.current_generation();
// Helper: build usage JSON from MITM tokens // Helper: build usage JSON from MITM tokens
let build_usage = |pt: u64, ct: u64, crt: u64, tt: u64| -> serde_json::Value { let build_usage = |pt: u64, ct: u64, crt: u64, tt: u64| -> serde_json::Value {
@@ -672,212 +690,219 @@ async fn chat_completions_stream(
}) })
}; };
while start.elapsed().as_secs() < timeout { // Take ownership of the pre-installed channel receiver
// Check for upstream errors from MITM (Google API errors) let mut rx_opt = mitm_rx;
if let Some(err) = state.mitm_store.take_upstream_error().await {
let error_msg = super::util::upstream_error_message(&err);
let error_type = super::util::upstream_error_type(&err);
yield Ok(Event::default().data(serde_json::to_string(&serde_json::json!({
"error": {
"message": error_msg,
"type": error_type,
"code": err.status,
}
})).unwrap()));
yield Ok(Event::default().data("[DONE]".to_string()));
break;
}
// Bail if another completions handler has superseded us while start.elapsed().as_secs() < timeout {
if state.mitm_store.current_generation() != my_generation { if let Some(ref mut rx) = rx_opt {
debug!("Completions: generation changed (superseded), ending stream"); // ── Channel-based MITM pipeline ──
// Track accumulated text for delta computation
let mut acc_text = String::new();
let mut acc_thinking = String::new();
let mut last_usage: Option<crate::mitm::store::ApiUsage> = None;
'channel_loop: while let Some(event) = tokio::time::timeout(
std::time::Duration::from_secs(timeout.saturating_sub(start.elapsed().as_secs())),
rx.recv(),
).await.ok().flatten() {
use crate::mitm::store::MitmEvent;
match event {
MitmEvent::ThinkingDelta(full_thinking) => {
if full_thinking.len() > acc_thinking.len() {
let delta = full_thinking[acc_thinking.len()..].to_string();
acc_thinking = full_thinking;
last_thinking_len = acc_thinking.len();
yield Ok(Event::default().data(chunk_json(
&completion_id, &model_name,
serde_json::json!([chunk_choice(0, serde_json::json!({"reasoning_content": delta}), None)]),
None,
)));
}
}
MitmEvent::TextDelta(full_text) => {
if full_text.len() > acc_text.len() {
let delta = full_text[acc_text.len()..].to_string();
acc_text = full_text;
last_text = acc_text.clone();
yield Ok(Event::default().data(chunk_json(
&completion_id, &model_name,
serde_json::json!([chunk_choice(0, serde_json::json!({"content": delta}), None)]),
None,
)));
}
}
MitmEvent::FunctionCall(calls) => {
let mut tool_calls = Vec::new();
for (i, fc) in calls.iter().enumerate() {
let call_id = format!(
"call_{}",
uuid::Uuid::new_v4().to_string().replace('-', "")[..24].to_string()
);
let arguments = serde_json::to_string(&fc.args).unwrap_or_default();
tool_calls.push(serde_json::json!({
"index": i,
"id": call_id,
"type": "function",
"function": {
"name": fc.name,
"arguments": arguments,
},
}));
}
yield Ok(Event::default().data(chunk_json(
&completion_id, &model_name,
serde_json::json!([chunk_choice(0, serde_json::json!({"tool_calls": tool_calls}), None)]),
None,
)));
yield Ok(Event::default().data(chunk_json(
&completion_id, &model_name,
serde_json::json!([chunk_choice(0, serde_json::json!({}), Some("tool_calls"))]),
None,
)));
if include_usage {
let mitm = state.mitm_store.take_usage(&cascade_id).await
.or(state.mitm_store.take_usage("_latest").await);
let (pt, ct, crt, tt) = if let Some(ref u) = mitm {
(u.input_tokens, u.output_tokens, u.cache_read_input_tokens, u.thinking_output_tokens)
} else if let Some(ref u) = last_usage {
(u.input_tokens, u.output_tokens, u.cache_read_input_tokens, u.thinking_output_tokens)
} else { (0, 0, 0, 0) };
yield Ok(Event::default().data(chunk_json(
&completion_id, &model_name,
serde_json::json!([]),
Some(build_usage(pt, ct, crt, tt)),
)));
}
yield Ok(Event::default().data("[DONE]"));
state.mitm_store.drop_channel().await;
return;
}
MitmEvent::ResponseComplete => {
if !acc_text.is_empty() {
// Have response text — done
debug!("Completions: channel response complete, text_len={}, thinking_len={}",
acc_text.len(), acc_thinking.len());
let mitm = state.mitm_store.take_usage(&cascade_id).await
.or(state.mitm_store.take_usage("_latest").await)
.or(last_usage.take());
let fr = google_to_openai_finish_reason(mitm.as_ref().and_then(|u| u.stop_reason.as_deref()));
yield Ok(Event::default().data(chunk_json(
&completion_id, &model_name,
serde_json::json!([chunk_choice(0, serde_json::json!({}), Some(fr))]),
None,
)));
if include_usage {
let (pt, ct, crt, tt) = if let Some(ref u) = mitm {
(u.input_tokens, u.output_tokens, u.cache_read_input_tokens, u.thinking_output_tokens)
} else { (0, 0, 0, 0) };
yield Ok(Event::default().data(chunk_json(
&completion_id, &model_name,
serde_json::json!([]),
Some(build_usage(pt, ct, crt, tt)),
)));
}
yield Ok(Event::default().data("[DONE]"));
state.mitm_store.drop_channel().await;
return;
} else if !acc_thinking.is_empty() && !did_unblock_ls {
// Thinking-only response — LS needs follow-up API calls.
// Create a new channel and unblock the gate.
did_unblock_ls = true;
let (new_tx, new_rx) = tokio::sync::mpsc::channel(64);
state.mitm_store.set_channel(new_tx).await;
state.mitm_store.clear_request_in_flight();
let _ = state.mitm_store.take_any_function_calls().await;
*rx = new_rx;
debug!(
"Completions: thinking-only — new channel for follow-up, thinking_len={}",
acc_thinking.len()
);
continue 'channel_loop;
} else if !acc_thinking.is_empty() && did_unblock_ls {
// Already unblocked once, still thinking-only.
// Wait a bit for potential follow-up events.
complete_polls += 1;
if complete_polls >= 25 {
info!("Completions: thinking-only timeout, thinking_len={}", acc_thinking.len());
let mitm = state.mitm_store.take_usage(&cascade_id).await
.or(state.mitm_store.take_usage("_latest").await)
.or(last_usage.take());
let fr = google_to_openai_finish_reason(mitm.as_ref().and_then(|u| u.stop_reason.as_deref()));
yield Ok(Event::default().data(chunk_json(
&completion_id, &model_name,
serde_json::json!([chunk_choice(0, serde_json::json!({}), Some(fr))]),
None,
)));
if include_usage {
let (pt, ct, crt, tt) = if let Some(ref u) = mitm {
(u.input_tokens, u.output_tokens, u.cache_read_input_tokens, u.thinking_output_tokens)
} else { (0, 0, 0, 0) };
yield Ok(Event::default().data(chunk_json(
&completion_id, &model_name,
serde_json::json!([]),
Some(build_usage(pt, ct, crt, tt)),
)));
}
yield Ok(Event::default().data("[DONE]"));
state.mitm_store.drop_channel().await;
return;
}
// Don't break — wait for more channel events
continue 'channel_loop;
} else {
// Empty response (no text, no thinking, no tools)
complete_polls += 1;
if complete_polls >= 4 {
info!("Completions: channel response complete but empty, ending stream");
yield Ok(Event::default().data(chunk_json(
&completion_id, &model_name,
serde_json::json!([chunk_choice(0, serde_json::json!({}), Some("stop"))]),
None,
)));
yield Ok(Event::default().data("[DONE]"));
state.mitm_store.drop_channel().await;
return;
}
continue 'channel_loop;
}
}
MitmEvent::UpstreamError(err) => {
let error_msg = super::util::upstream_error_message(&err);
let error_type = super::util::upstream_error_type(&err);
yield Ok(Event::default().data(serde_json::to_string(&serde_json::json!({
"error": {
"message": error_msg,
"type": error_type,
"code": err.status,
}
})).unwrap()));
yield Ok(Event::default().data("[DONE]"));
state.mitm_store.drop_channel().await;
return;
}
MitmEvent::Usage(u) => {
last_usage = Some(u);
}
MitmEvent::Grounding(_) => {
// Grounding metadata handled by store directly
}
}
}
// Channel closed or timeout — clean up
state.mitm_store.drop_channel().await;
// If we got here from timeout with content, emit what we have
if !last_text.is_empty() || last_thinking_len > 0 {
yield Ok(Event::default().data(chunk_json(
&completion_id, &model_name,
serde_json::json!([chunk_choice(0, serde_json::json!({}), Some("stop"))]),
None,
)));
}
yield Ok(Event::default().data("[DONE]")); yield Ok(Event::default().data("[DONE]"));
return; return;
}
// ── Check for MITM-captured function calls FIRST ──
// This runs independently of LS steps — the MITM captures tool calls
// at the proxy layer, so we don't need to wait for LS processing.
let captured = state.mitm_store.take_function_calls(&cascade_id).await;
if let Some(ref calls) = captured {
if !calls.is_empty() {
let mut tool_calls = Vec::new();
for (i, fc) in calls.iter().enumerate() {
let call_id = format!(
"call_{}",
uuid::Uuid::new_v4().to_string().replace('-', "")[..24].to_string()
);
let arguments = serde_json::to_string(&fc.args).unwrap_or_default();
tool_calls.push(serde_json::json!({
"index": i,
"id": call_id,
"type": "function",
"function": {
"name": fc.name,
"arguments": arguments,
},
}));
}
yield Ok(Event::default().data(chunk_json(
&completion_id, &model_name,
serde_json::json!([chunk_choice(0, serde_json::json!({"tool_calls": tool_calls}), None)]),
None,
)));
yield Ok(Event::default().data(chunk_json(
&completion_id, &model_name,
serde_json::json!([chunk_choice(0, serde_json::json!({}), Some("tool_calls"))]),
None,
)));
if include_usage {
let mitm = state.mitm_store.take_usage(&cascade_id).await
.or(state.mitm_store.take_usage("_latest").await);
let (pt, ct, crt, tt) = if let Some(ref u) = mitm {
(u.input_tokens, u.output_tokens, u.cache_read_input_tokens, u.thinking_output_tokens)
} else { (0, 0, 0, 0) };
yield Ok(Event::default().data(chunk_json(
&completion_id, &model_name,
serde_json::json!([]),
Some(build_usage(pt, ct, crt, tt)),
)));
}
yield Ok(Event::default().data("[DONE]"));
// Clear in-flight flag so the next turn's requests can get through
state.mitm_store.clear_response_async().await;
return;
}
}
// ── Primary: MITM-captured response (when custom tools are active) ──
// The MITM intercepts the real Google SSE stream and captures text,
// thinking, and function calls. This is the authoritative data source.
// The LS only gets rewritten responses (function calls → text placeholders)
// so it doesn't provide useful streaming data when MITM is active.
if has_custom_tools {
// Stream thinking text as reasoning_content deltas
if let Some(tc) = state.mitm_store.peek_thinking_text().await {
if tc.len() > last_thinking_len {
let delta = &tc[last_thinking_len..];
last_thinking_len = tc.len();
yield Ok(Event::default().data(chunk_json(
&completion_id, &model_name,
serde_json::json!([chunk_choice(0, serde_json::json!({"reasoning_content": delta}), None)]),
None,
)));
}
}
// Stream response text as content deltas
if let Some(text) = state.mitm_store.peek_response_text().await {
if !text.is_empty() && text != last_text {
let delta = if text.len() > last_text.len() && text.starts_with(&*last_text) {
text[last_text.len()..].to_string()
} else {
text.clone()
};
if !delta.is_empty() {
yield Ok(Event::default().data(chunk_json(
&completion_id, &model_name,
serde_json::json!([chunk_choice(0, serde_json::json!({"content": delta}), None)]),
None,
)));
last_text = text;
}
}
}
// Check if MITM response is complete
if state.mitm_store.is_response_complete() {
if !last_text.is_empty() {
// Have actual response text — done
complete_polls += 1;
if complete_polls >= 2 {
debug!("Completions: MITM response complete, text_len={}, thinking_len={}", last_text.len(), last_thinking_len);
let mitm = state.mitm_store.take_usage(&cascade_id).await
.or(state.mitm_store.take_usage("_latest").await);
let fr = google_to_openai_finish_reason(mitm.as_ref().and_then(|u| u.stop_reason.as_deref()));
yield Ok(Event::default().data(chunk_json(
&completion_id, &model_name,
serde_json::json!([chunk_choice(0, serde_json::json!({}), Some(fr))]),
None,
)));
if include_usage {
let (pt, ct, crt, tt) = if let Some(ref u) = mitm {
(u.input_tokens, u.output_tokens, u.cache_read_input_tokens, u.thinking_output_tokens)
} else { (0, 0, 0, 0) };
yield Ok(Event::default().data(chunk_json(
&completion_id, &model_name,
serde_json::json!([]),
Some(build_usage(pt, ct, crt, tt)),
)));
}
yield Ok(Event::default().data("[DONE]"));
return;
}
} else if last_thinking_len > 0 && !did_unblock_ls {
// Thinking-only response. The LS needs follow-up API calls
// to get actual function calls or text. Unblock once.
did_unblock_ls = true;
complete_polls = 0;
// Bump generation FIRST — invalidates old MITM connection's store writes
my_generation = state.mitm_store.bump_generation();
state.mitm_store.clear_request_in_flight();
state.mitm_store.clear_response_complete();
// Drain store so leaked connections can't produce stale content
state.mitm_store.set_response_text("").await;
state.mitm_store.set_thinking_text("").await;
let _ = state.mitm_store.take_any_function_calls().await;
debug!(
"Completions: thinking-only — unblocking LS for follow-up, thinking_len={}, new_gen={}",
last_thinking_len, my_generation
);
} else if last_thinking_len > 0 && did_unblock_ls {
// Already unblocked once. Still only thinking after follow-up.
complete_polls += 1;
if complete_polls >= 25 {
info!("Completions: thinking-only timeout after ~10s, thinking_len={}", last_thinking_len);
let mitm = state.mitm_store.take_usage(&cascade_id).await
.or(state.mitm_store.take_usage("_latest").await);
let fr = google_to_openai_finish_reason(mitm.as_ref().and_then(|u| u.stop_reason.as_deref()));
yield Ok(Event::default().data(chunk_json(
&completion_id, &model_name,
serde_json::json!([chunk_choice(0, serde_json::json!({}), Some(fr))]),
None,
)));
if include_usage {
let (pt, ct, crt, tt) = if let Some(ref u) = mitm {
(u.input_tokens, u.output_tokens, u.cache_read_input_tokens, u.thinking_output_tokens)
} else { (0, 0, 0, 0) };
yield Ok(Event::default().data(chunk_json(
&completion_id, &model_name,
serde_json::json!([]),
Some(build_usage(pt, ct, crt, tt)),
)));
}
yield Ok(Event::default().data("[DONE]"));
return;
}
} else {
// response_complete but no text AND no thinking — might be
// a function-call-only response that was already consumed,
// or empty response. Wait a bit then give up.
complete_polls += 1;
if complete_polls >= 4 {
info!("Completions: MITM response complete but no content (text/thinking/tools all empty), ending stream");
yield Ok(Event::default().data(chunk_json(
&completion_id, &model_name,
serde_json::json!([chunk_choice(0, serde_json::json!({}), Some("stop"))]),
None,
)));
yield Ok(Event::default().data("[DONE]"));
return;
}
}
} else {
complete_polls = 0; // Reset — not complete yet
}
} else { } else {
// ── Fallback: LS steps (no MITM capture active) ── // ── Fallback: LS steps (no MITM capture active) ──
if let Ok((status, data)) = state.backend.get_steps(&cascade_id).await { if let Ok((status, data)) = state.backend.get_steps(&cascade_id).await {
@@ -1001,7 +1026,7 @@ async fn chat_completions_stream(
} }
})).unwrap())); })).unwrap()));
// Always clear in-flight flag when stream ends // Always clear in-flight flag when stream ends
state.mitm_store.clear_response_async().await; state.mitm_store.drop_channel().await;
yield Ok(Event::default().data("[DONE]")); yield Ok(Event::default().data("[DONE]"));
}; };

View File

@@ -309,7 +309,11 @@ pub(crate) async fn handle_responses(
count = tools.len(), count = tools.len(),
"Stored client tools for MITM injection" "Stored client tools for MITM injection"
); );
} else {
state.mitm_store.clear_tools().await;
} }
} else {
state.mitm_store.clear_tools().await;
} }
if let Some(ref choice) = body.tool_choice { if let Some(ref choice) = body.tool_choice {
let gemini_config = openai_tool_choice_to_gemini(choice); let gemini_config = openai_tool_choice_to_gemini(choice);
@@ -404,6 +408,8 @@ pub(crate) async fn handle_responses(
// Send message // Send message
state.mitm_store.set_active_cascade(&cascade_id).await; state.mitm_store.set_active_cascade(&cascade_id).await;
// Store real user text for MITM injection — LS gets a dummy prompt
state.mitm_store.set_pending_user_text(user_text.clone()).await;
// Store image for MITM injection (LS doesn't forward images to Google API) // Store image for MITM injection (LS doesn't forward images to Google API)
if let Some(ref img) = image { if let Some(ref img) = image {
use base64::Engine; use base64::Engine;
@@ -415,9 +421,24 @@ pub(crate) async fn handle_responses(
}) })
.await; .await;
} }
// Pre-flight: install channel BEFORE send_message so the MITM proxy
// can grab it when the LS fires its API call.
let has_custom_tools = state.mitm_store.get_tools().await.is_some();
let mitm_rx = if has_custom_tools {
state.mitm_store.clear_response_async().await;
state.mitm_store.clear_upstream_error().await;
let _ = state.mitm_store.take_any_function_calls().await;
let (tx, rx) = tokio::sync::mpsc::channel(64);
state.mitm_store.set_channel(tx).await;
Some(rx)
} else {
None
};
match state match state
.backend .backend
.send_message_with_image(&cascade_id, &user_text, model.model_enum, image.as_ref()) .send_message_with_image(&cascade_id, ".", model.model_enum, image.as_ref())
.await .await
{ {
Ok((200, _)) => { Ok((200, _)) => {
@@ -428,6 +449,7 @@ pub(crate) async fn handle_responses(
}); });
} }
Ok((status, _)) => { Ok((status, _)) => {
state.mitm_store.drop_channel().await;
return err_response( return err_response(
StatusCode::BAD_GATEWAY, StatusCode::BAD_GATEWAY,
format!("Antigravity returned {status}"), format!("Antigravity returned {status}"),
@@ -435,6 +457,7 @@ pub(crate) async fn handle_responses(
); );
} }
Err(e) => { Err(e) => {
state.mitm_store.drop_channel().await;
return err_response( return err_response(
StatusCode::BAD_GATEWAY, StatusCode::BAD_GATEWAY,
format!("Send message failed: {e}"), format!("Send message failed: {e}"),
@@ -472,6 +495,7 @@ pub(crate) async fn handle_responses(
cascade_id, cascade_id,
body.timeout, body.timeout,
req_params, req_params,
mitm_rx,
) )
.await .await
} else { } else {
@@ -482,6 +506,7 @@ pub(crate) async fn handle_responses(
cascade_id, cascade_id,
body.timeout, body.timeout,
req_params, req_params,
mitm_rx,
) )
.await .await
} }
@@ -603,54 +628,54 @@ async fn handle_responses_sync(
cascade_id: String, cascade_id: String,
timeout: u64, timeout: u64,
params: RequestParams, params: RequestParams,
mitm_rx: Option<tokio::sync::mpsc::Receiver<crate::mitm::store::MitmEvent>>,
) -> axum::response::Response { ) -> axum::response::Response {
let created_at = now_unix(); let created_at = now_unix();
let has_custom_tools = state.mitm_store.get_tools().await.is_some();
// Clear stale captured response and upstream errors // Clear stale captured response and upstream errors (only if no pre-installed channel)
state.mitm_store.clear_response_async().await; if mitm_rx.is_none() {
state.mitm_store.clear_upstream_error().await; state.mitm_store.clear_response_async().await;
state.mitm_store.clear_upstream_error().await;
}
// ── MITM bypass: poll MitmStore directly when custom tools active ── // ── MITM bypass: channel-based pipeline when custom tools active ──
if has_custom_tools { if let Some(mut rx) = mitm_rx {
let start = std::time::Instant::now(); let start = std::time::Instant::now();
while start.elapsed().as_secs() < timeout {
// Check for upstream errors from MITM (Google API errors)
if let Some(err) = state.mitm_store.take_upstream_error().await {
return upstream_err_response(&err);
}
// Check for function calls let mut acc_text = String::new();
let captured = state.mitm_store.take_function_calls(&cascade_id).await; let mut acc_thinking: Option<String> = None;
if let Some(ref raw_calls) = captured { let mut last_usage: Option<crate::mitm::store::ApiUsage> = None;
let calls: Vec<_> = if let Some(max) = params.max_tool_calls {
raw_calls.iter().take(max as usize).collect() while let Some(event) = tokio::time::timeout(
} else { std::time::Duration::from_secs(timeout.saturating_sub(start.elapsed().as_secs())),
raw_calls.iter().collect() rx.recv(),
}; ).await.ok().flatten() {
if !calls.is_empty() { use crate::mitm::store::MitmEvent;
match event {
MitmEvent::ThinkingDelta(t) => { acc_thinking = Some(t); }
MitmEvent::TextDelta(t) => { acc_text = t; }
MitmEvent::Usage(u) => { last_usage = Some(u); }
MitmEvent::Grounding(_) => {} // stored by proxy directly
MitmEvent::FunctionCall(raw_calls) => {
let calls: Vec<_> = if let Some(max) = params.max_tool_calls {
raw_calls.iter().take(max as usize).collect()
} else {
raw_calls.iter().collect()
};
let mut output_items: Vec<serde_json::Value> = Vec::new(); let mut output_items: Vec<serde_json::Value> = Vec::new();
for fc in &calls { for fc in &calls {
let call_id = format!( let call_id = format!(
"call_{}", "call_{}",
uuid::Uuid::new_v4().to_string().replace('-', "")[..24].to_string() uuid::Uuid::new_v4().to_string().replace('-', "")[..24].to_string()
); );
state state.mitm_store.register_call_id(call_id.clone(), fc.name.clone()).await;
.mitm_store
.register_call_id(call_id.clone(), fc.name.clone())
.await;
let arguments = serde_json::to_string(&fc.args).unwrap_or_default(); let arguments = serde_json::to_string(&fc.args).unwrap_or_default();
output_items output_items.push(build_function_call_output(&call_id, &fc.name, &arguments));
.push(build_function_call_output(&call_id, &fc.name, &arguments));
} }
let (usage, _) = usage_from_poll( let (usage, _) = usage_from_poll(
&state.mitm_store, &state.mitm_store, &cascade_id, &None, &params.user_text, "",
&cascade_id, ).await;
&None, state.mitm_store.drop_channel().await;
&params.user_text,
"",
)
.await;
let resp = build_response_object( let resp = build_response_object(
ResponseData { ResponseData {
id: response_id, id: response_id,
@@ -666,52 +691,61 @@ async fn handle_responses_sync(
); );
return Json(resp).into_response(); return Json(resp).into_response();
} }
} MitmEvent::ResponseComplete => {
if acc_text.is_empty() && acc_thinking.is_none() {
// Empty response — continue waiting
continue;
}
if acc_text.is_empty() && acc_thinking.is_some() {
// Thinking-only — LS needs to make a follow-up request.
// Reinstall channel and unblock gate.
let (new_tx, new_rx) = tokio::sync::mpsc::channel(64);
state.mitm_store.set_channel(new_tx).await;
state.mitm_store.clear_request_in_flight();
let _ = state.mitm_store.take_any_function_calls().await;
rx = new_rx;
debug!(
"Responses sync: thinking-only — new channel for follow-up, thinking_len={}",
acc_thinking.as_ref().map(|t| t.len()).unwrap_or(0)
);
continue;
}
let (usage, _) = usage_from_poll(
&state.mitm_store, &cascade_id, &None, &params.user_text, &acc_text,
).await;
state.mitm_store.drop_channel().await;
// Check for completed text response let mut output_items: Vec<serde_json::Value> = Vec::new();
if state.mitm_store.is_response_complete() { if let Some(ref t) = acc_thinking {
let text = state output_items.push(build_reasoning_output(t));
.mitm_store }
.take_response_text() let msg_id = format!("msg_{}", uuid::Uuid::new_v4().to_string().replace('-', ""));
.await output_items.push(build_message_output(&msg_id, &acc_text));
.unwrap_or_default();
let thinking = state.mitm_store.take_thinking_text().await;
let (usage, _) = usage_from_poll(
&state.mitm_store,
&cascade_id,
&None,
&params.user_text,
&text,
)
.await;
let mut output_items: Vec<serde_json::Value> = Vec::new(); let resp = build_response_object(
if let Some(ref t) = thinking { ResponseData {
output_items.push(build_reasoning_output(t)); id: response_id,
model: model_name,
status: "completed",
created_at,
completed_at: Some(now_unix()),
output: output_items,
usage: Some(usage),
thinking_signature: None,
},
&params,
);
return Json(resp).into_response();
}
MitmEvent::UpstreamError(err) => {
state.mitm_store.drop_channel().await;
return upstream_err_response(&err);
} }
let msg_id = format!("msg_{}", uuid::Uuid::new_v4().to_string().replace('-', ""));
output_items.push(build_message_output(&msg_id, &text));
let resp = build_response_object(
ResponseData {
id: response_id,
model: model_name,
status: "completed",
created_at,
completed_at: Some(now_unix()),
output: output_items,
usage: Some(usage),
thinking_signature: None,
},
&params,
);
return Json(resp).into_response();
} }
tokio::time::sleep(tokio::time::Duration::from_millis(200)).await;
} }
// Timeout — return proper error, not fake incomplete response // Timeout
state.mitm_store.drop_channel().await;
return err_response( return err_response(
StatusCode::GATEWAY_TIMEOUT, StatusCode::GATEWAY_TIMEOUT,
format!("Timeout: no response from Google API after {timeout}s"), format!("Timeout: no response from Google API after {timeout}s"),
@@ -835,6 +869,7 @@ async fn handle_responses_stream(
cascade_id: String, cascade_id: String,
timeout: u64, timeout: u64,
params: RequestParams, params: RequestParams,
mitm_rx: Option<tokio::sync::mpsc::Receiver<crate::mitm::store::MitmEvent>>,
) -> axum::response::Response { ) -> axum::response::Response {
let stream = async_stream::stream! { let stream = async_stream::stream! {
let msg_id = format!("msg_{}", uuid::Uuid::new_v4().to_string().replace('-', "")); let msg_id = format!("msg_{}", uuid::Uuid::new_v4().to_string().replace('-', ""));
@@ -886,50 +921,170 @@ async fn handle_responses_stream(
let mut thinking_text: Option<String> = None; let mut thinking_text: Option<String> = None;
let mut message_started = false; let mut message_started = false;
let reasoning_id = format!("rs_{}", uuid::Uuid::new_v4().to_string().replace('-', "")); let reasoning_id = format!("rs_{}", uuid::Uuid::new_v4().to_string().replace('-', ""));
let has_custom_tools = state.mitm_store.get_tools().await.is_some();
// Clear stale captured response and upstream errors // Clear stale response (only if no pre-installed channel)
state.mitm_store.clear_response_async().await; if mitm_rx.is_none() {
state.mitm_store.clear_upstream_error().await; state.mitm_store.clear_response_async().await;
state.mitm_store.clear_upstream_error().await;
}
// ── MITM bypass mode (when custom tools are active) ── // ── MITM bypass mode (when custom tools are active) ──
// Skip LS entirely — read text, thinking, and tool calls directly from MitmStore. // Channel-based pipeline: read events directly from MITM proxy.
if has_custom_tools { // Channel is pre-installed before send_message to avoid race conditions.
if let Some(mut rx) = mitm_rx {
let mut last_thinking = String::new(); let mut last_thinking = String::new();
while start.elapsed().as_secs() < timeout { while let Some(event) = tokio::time::timeout(
// Check for upstream errors from MITM (Google API errors) std::time::Duration::from_secs(timeout.saturating_sub(start.elapsed().as_secs())),
if let Some(err) = state.mitm_store.take_upstream_error().await { rx.recv(),
let error_msg = super::util::upstream_error_message(&err); ).await.ok().flatten() {
let error_type = super::util::upstream_error_type(&err); use crate::mitm::store::MitmEvent;
yield Ok(responses_sse_event( match event {
"response.failed", MitmEvent::ThinkingDelta(full_thinking) => {
serde_json::json!({ if !thinking_emitted && full_thinking.len() > last_thinking.len() {
"type": "response.failed", // First thinking text — emit reasoning output_item.added
"sequence_number": next_seq(), if last_thinking.is_empty() {
"response": { yield Ok(responses_sse_event(
"id": &response_id, "response.output_item.added",
"status": "failed", serde_json::json!({
"error": { "type": "response.output_item.added",
"type": error_type, "sequence_number": next_seq(),
"message": error_msg, "output_index": 0,
"code": err.status, "item": {
}, "id": &reasoning_id,
}, "type": "reasoning",
}), "summary": [],
)); },
break; }),
} ));
yield Ok(responses_sse_event(
"response.reasoning_summary_part.added",
serde_json::json!({
"type": "response.reasoning_summary_part.added",
"sequence_number": next_seq(),
"item_id": &reasoning_id,
"output_index": 0,
"summary_index": 0,
"part": { "type": "summary_text", "text": "" },
}),
));
}
let delta = &full_thinking[last_thinking.len()..];
if !delta.is_empty() {
yield Ok(responses_sse_event(
"response.reasoning_summary_text.delta",
serde_json::json!({
"type": "response.reasoning_summary_text.delta",
"sequence_number": next_seq(),
"item_id": &reasoning_id,
"output_index": 0,
"summary_index": 0,
"delta": delta,
}),
));
}
last_thinking = full_thinking;
}
}
MitmEvent::TextDelta(full_text) => {
if full_text.len() > last_text.len() {
// Finalize thinking if started but not done
if !thinking_emitted && !last_thinking.is_empty() {
thinking_emitted = true;
thinking_text = Some(last_thinking.clone());
yield Ok(responses_sse_event(
"response.reasoning_summary_text.done",
serde_json::json!({
"type": "response.reasoning_summary_text.done",
"sequence_number": next_seq(),
"item_id": &reasoning_id,
"output_index": 0,
"summary_index": 0,
"text": &last_thinking,
}),
));
yield Ok(responses_sse_event(
"response.reasoning_summary_part.done",
serde_json::json!({
"type": "response.reasoning_summary_part.done",
"sequence_number": next_seq(),
"item_id": &reasoning_id,
"output_index": 0,
"summary_index": 0,
"part": { "type": "summary_text", "text": &last_thinking },
}),
));
yield Ok(responses_sse_event(
"response.output_item.done",
serde_json::json!({
"type": "response.output_item.done",
"sequence_number": next_seq(),
"output_index": 0,
"item": {
"id": &reasoning_id,
"type": "reasoning",
"summary": [{
"type": "summary_text",
"text": &last_thinking,
}],
},
}),
));
}
// Check for function calls first let msg_output_index: u32 = if thinking_emitted { 1 } else { 0 };
let captured = state.mitm_store.take_function_calls(&cascade_id).await;
if let Some(ref raw_calls) = captured { if !message_started {
let calls: Vec<_> = if let Some(max) = params.max_tool_calls { message_started = true;
raw_calls.iter().take(max as usize).collect() yield Ok(responses_sse_event(
} else { "response.output_item.added",
raw_calls.iter().collect() serde_json::json!({
}; "type": "response.output_item.added",
if !calls.is_empty() { "sequence_number": next_seq(),
"output_index": msg_output_index,
"item": build_message_output_in_progress(&msg_id),
}),
));
yield Ok(responses_sse_event(
"response.content_part.added",
serde_json::json!({
"type": "response.content_part.added",
"sequence_number": next_seq(),
"output_index": msg_output_index,
"content_index": CONTENT_IDX,
"part": {
"type": "output_text",
"text": "",
"annotations": [],
}
}),
));
}
let delta = &full_text[last_text.len()..];
if !delta.is_empty() {
let msg_output_index: u32 = if thinking_emitted { 1 } else { 0 };
yield Ok(responses_sse_event(
"response.output_text.delta",
serde_json::json!({
"type": "response.output_text.delta",
"sequence_number": next_seq(),
"item_id": &msg_id,
"output_index": msg_output_index,
"content_index": CONTENT_IDX,
"delta": delta,
}),
));
last_text = full_text;
}
}
}
MitmEvent::FunctionCall(raw_calls) => {
let calls: Vec<_> = if let Some(max) = params.max_tool_calls {
raw_calls.iter().take(max as usize).collect()
} else {
raw_calls.iter().collect()
};
let msg_output_index: u32 = if thinking_emitted { 1 } else { 0 }; let msg_output_index: u32 = if thinking_emitted { 1 } else { 0 };
for (i, fc) in calls.iter().enumerate() { for (i, fc) in calls.iter().enumerate() {
let call_id = format!( let call_id = format!(
@@ -1011,194 +1166,71 @@ async fn handle_responses_stream(
"response": response_to_json(&final_resp), "response": response_to_json(&final_resp),
}), }),
)); ));
state.mitm_store.drop_channel().await;
return; return;
} }
} MitmEvent::ResponseComplete => {
if !last_text.is_empty() {
// Stream thinking text in real-time let msg_idx: u32 = if thinking_emitted { 1 } else { 0 };
if !thinking_emitted { let (usage, _) = usage_from_poll(
if let Some(thinking) = state.mitm_store.peek_thinking_text().await { &state.mitm_store, &cascade_id, &None,
if !thinking.is_empty() && thinking != last_thinking { &params.user_text, &last_text,
// First thinking text — emit reasoning output_item.added ).await;
if last_thinking.is_empty() { let tc = thinking_text.clone();
yield Ok(responses_sse_event( for evt in completion_events(
"response.output_item.added", &response_id, &model_name, &msg_id, &reasoning_id,
serde_json::json!({ msg_idx, CONTENT_IDX, &last_text, usage,
"type": "response.output_item.added", created_at, &seq, &params, None, tc,
"sequence_number": next_seq(), ) {
"output_index": 0, yield Ok(evt);
"item": {
"id": &reasoning_id,
"type": "reasoning",
"summary": [],
},
}),
));
yield Ok(responses_sse_event(
"response.reasoning_summary_part.added",
serde_json::json!({
"type": "response.reasoning_summary_part.added",
"sequence_number": next_seq(),
"item_id": &reasoning_id,
"output_index": 0,
"summary_index": 0,
"part": { "type": "summary_text", "text": "" },
}),
));
} }
state.mitm_store.drop_channel().await;
// Delta of new thinking text return;
let delta = if thinking.len() > last_thinking.len() } else if !last_thinking.is_empty() {
&& thinking.starts_with(&*last_thinking) // Thinking-only response — LS needs follow-up API calls.
{ // Create a new channel and unblock the gate.
thinking[last_thinking.len()..].to_string() let (new_tx, new_rx) = tokio::sync::mpsc::channel(64);
} else { state.mitm_store.set_channel(new_tx).await;
thinking.clone() state.mitm_store.clear_request_in_flight();
}; let _ = state.mitm_store.take_any_function_calls().await;
rx = new_rx;
if !delta.is_empty() { debug!(
yield Ok(responses_sse_event( "Responses stream: thinking-only — new channel for follow-up, thinking_len={}",
"response.reasoning_summary_text.delta", last_thinking.len()
serde_json::json!({ );
"type": "response.reasoning_summary_text.delta",
"sequence_number": next_seq(),
"item_id": &reasoning_id,
"output_index": 0,
"summary_index": 0,
"delta": &delta,
}),
));
}
last_thinking = thinking;
} }
// ResponseComplete with no text and no thinking — continue waiting
} }
} MitmEvent::UpstreamError(err) => {
let error_msg = super::util::upstream_error_message(&err);
// Stream response text let error_type = super::util::upstream_error_type(&err);
if let Some(text) = state.mitm_store.peek_response_text().await { yield Ok(responses_sse_event(
if !text.is_empty() && text != last_text { "response.failed",
// Finalize thinking if started but not done serde_json::json!({
if !thinking_emitted && !last_thinking.is_empty() { "type": "response.failed",
thinking_emitted = true; "sequence_number": next_seq(),
thinking_text = Some(last_thinking.clone()); "response": {
yield Ok(responses_sse_event( "id": &response_id,
"response.reasoning_summary_text.done", "status": "failed",
serde_json::json!({ "error": {
"type": "response.reasoning_summary_text.done", "type": error_type,
"sequence_number": next_seq(), "message": error_msg,
"item_id": &reasoning_id, "code": err.status,
"output_index": 0,
"summary_index": 0,
"text": &last_thinking,
}),
));
yield Ok(responses_sse_event(
"response.reasoning_summary_part.done",
serde_json::json!({
"type": "response.reasoning_summary_part.done",
"sequence_number": next_seq(),
"item_id": &reasoning_id,
"output_index": 0,
"summary_index": 0,
"part": { "type": "summary_text", "text": &last_thinking },
}),
));
yield Ok(responses_sse_event(
"response.output_item.done",
serde_json::json!({
"type": "response.output_item.done",
"sequence_number": next_seq(),
"output_index": 0,
"item": {
"id": &reasoning_id,
"type": "reasoning",
"summary": [{
"type": "summary_text",
"text": &last_thinking,
}],
}, },
}), },
)); }),
} ));
state.mitm_store.drop_channel().await;
let msg_output_index: u32 = if thinking_emitted { 1 } else { 0 };
if !message_started {
message_started = true;
yield Ok(responses_sse_event(
"response.output_item.added",
serde_json::json!({
"type": "response.output_item.added",
"sequence_number": next_seq(),
"output_index": msg_output_index,
"item": build_message_output_in_progress(&msg_id),
}),
));
yield Ok(responses_sse_event(
"response.content_part.added",
serde_json::json!({
"type": "response.content_part.added",
"sequence_number": next_seq(),
"output_index": msg_output_index,
"content_index": CONTENT_IDX,
"part": {
"type": "output_text",
"text": "",
"annotations": [],
}
}),
));
}
let new_content = if text.len() > last_text.len()
&& text.starts_with(&*last_text)
{
text[last_text.len()..].to_string()
} else {
text.clone()
};
if !new_content.is_empty() {
yield Ok(responses_sse_event(
"response.output_text.delta",
serde_json::json!({
"type": "response.output_text.delta",
"sequence_number": next_seq(),
"item_id": &msg_id,
"output_index": msg_output_index,
"content_index": CONTENT_IDX,
"delta": &new_content,
}),
));
last_text = text;
}
}
// Check if response is complete
if state.mitm_store.is_response_complete() && !last_text.is_empty() {
let msg_idx: u32 = if thinking_emitted { 1 } else { 0 };
let (usage, _) = usage_from_poll(
&state.mitm_store, &cascade_id, &None,
&params.user_text, &last_text,
).await;
let tc = thinking_text.clone();
for evt in completion_events(
&response_id, &model_name, &msg_id, &reasoning_id,
msg_idx, CONTENT_IDX, &last_text, usage,
created_at, &seq, &params, None, tc,
) {
yield Ok(evt);
}
return; return;
} }
MitmEvent::Usage(_) | MitmEvent::Grounding(_) => {
// Usage/grounding stored by proxy, consumed via usage_from_poll
}
} }
// Poll interval
let poll_ms: u64 = rand::thread_rng().gen_range(150..300);
tokio::time::sleep(tokio::time::Duration::from_millis(poll_ms)).await;
} }
// Timeout in bypass mode — emit error, not fake incomplete // Timeout in channel mode
state.mitm_store.drop_channel().await;
yield Ok(responses_sse_event( yield Ok(responses_sse_event(
"response.failed", "response.failed",
serde_json::json!({ serde_json::json!({

View File

@@ -163,11 +163,13 @@ async fn do_search(state: Arc<AppState>, body: SearchRequest) -> axum::response:
// Set active cascade for MITM correlation // Set active cascade for MITM correlation
state.mitm_store.set_active_cascade(&cascade_id).await; state.mitm_store.set_active_cascade(&cascade_id).await;
// Store real search prompt for MITM injection — LS gets a dummy prompt
state.mitm_store.set_pending_user_text(search_prompt.clone()).await;
// Send the search message // Send the search message
if let Err(e) = state if let Err(e) = state
.backend .backend
.send_message(&cascade_id, &search_prompt, model.model_enum) .send_message(&cascade_id, ".", model.model_enum)
.await .await
{ {
state.mitm_store.clear_active_cascade().await; state.mitm_store.clear_active_cascade().await;

View File

@@ -538,9 +538,10 @@ async fn handle_http_over_tls(
} }
}; };
// Generation tracking for store write guards // Channel-based event pipeline: grab the channel sender if one exists.
let mut won_gate = false; // If the API handler set up a channel, we send events through it.
let mut conn_generation = store.current_generation(); // Otherwise, we fall back to legacy store writes (search endpoint, etc.).
let mut event_tx: Option<tokio::sync::mpsc::Sender<super::store::MitmEvent>> = None;
// Log LLM calls at info, everything else at debug // Log LLM calls at info, everything else at debug
if req_path.contains("streamGenerateContent") { if req_path.contains("streamGenerateContent") {
@@ -558,7 +559,7 @@ async fn handle_http_over_tls(
// When custom tools are active, only the FIRST request wins the // When custom tools are active, only the FIRST request wins the
// atomic compare_exchange. All others get fake STOP responses. // atomic compare_exchange. All others get fake STOP responses.
let has_tools = store.get_tools().await.is_some(); let has_tools = store.get_tools().await.is_some();
won_gate = if has_tools { if has_tools {
if !store.try_mark_request_in_flight() { if !store.try_mark_request_in_flight() {
info!("MITM: blocking LS request — another request already in-flight"); info!("MITM: blocking LS request — another request already in-flight");
let fake_response = "HTTP/1.1 200 OK\r\n\ let fake_response = "HTTP/1.1 200 OK\r\n\
@@ -575,13 +576,11 @@ async fn handle_http_over_tls(
let _ = client.flush().await; let _ = client.flush().await;
continue; continue;
} }
true // Grab the channel sender — the API handler installed it before
} else { // sending the LS message. If it's gone, we still proceed but
false // fall back to legacy store writes.
}; event_tx = store.take_channel().await;
// Snapshot the generation at gate-win time. If it changes later, }
// another completions turn started and our data is stale.
conn_generation = store.current_generation();
// ── Request modification ───────────────────────────────────── // ── Request modification ─────────────────────────────────────
// Dechunk body → check if agent request → modify → rechunk // Dechunk body → check if agent request → modify → rechunk
@@ -603,12 +602,14 @@ async fn handle_http_over_tls(
let generation_params = store.get_generation_params().await; let generation_params = store.get_generation_params().await;
let pending_image = store.take_pending_image().await; let pending_image = store.take_pending_image().await;
let tool_rounds = store.get_tool_rounds().await; let tool_rounds = store.get_tool_rounds().await;
let pending_user_text = store.take_pending_user_text().await;
let tool_ctx = if tools.is_some() let tool_ctx = if tools.is_some()
|| !pending_results.is_empty() || !pending_results.is_empty()
|| !tool_rounds.is_empty() || !tool_rounds.is_empty()
|| generation_params.is_some() || generation_params.is_some()
|| pending_image.is_some() || pending_image.is_some()
|| pending_user_text.is_some()
{ {
Some(super::modify::ToolContext { Some(super::modify::ToolContext {
tools, tools,
@@ -618,6 +619,7 @@ async fn handle_http_over_tls(
generation_params, generation_params,
pending_image, pending_image,
tool_rounds, tool_rounds,
pending_user_text,
}) })
} else { } else {
None None
@@ -794,14 +796,18 @@ async fn handle_http_over_tls(
}) })
.unwrap_or((None, None)); .unwrap_or((None, None));
store let upstream_err = super::store::UpstreamError {
.set_upstream_error(super::store::UpstreamError { status: http_status,
status: http_status, body: body_str,
body: body_str, message,
message, error_status,
error_status, };
}) // Send through channel if available, otherwise store for legacy consumers
.await; if let Some(ref tx) = event_tx {
let _ = tx.send(super::store::MitmEvent::UpstreamError(upstream_err)).await;
} else {
store.set_upstream_error(upstream_err).await;
}
} }
// Save body for usage parsing // Save body for usage parsing
@@ -812,46 +818,59 @@ async fn handle_http_over_tls(
let body = String::from_utf8_lossy(&header_buf[hdr_end..]); let body = String::from_utf8_lossy(&header_buf[hdr_end..]);
parse_streaming_chunk(&body, &mut streaming_acc); parse_streaming_chunk(&body, &mut streaming_acc);
// Only write to store if our generation is still current. // Send events through channel if available, otherwise use legacy store
// If another completions turn started, our data is stale. if let Some(ref tx) = event_tx {
let gen_valid = !won_gate || store.current_generation() == conn_generation; // Function calls → channel event
if gen_valid {
// Store captured function calls (drain to avoid re-storing on next chunk)
if !streaming_acc.function_calls.is_empty() { if !streaming_acc.function_calls.is_empty() {
let calls: Vec<_> = let calls: Vec<_> = streaming_acc.function_calls.drain(..).collect();
streaming_acc.function_calls.drain(..).collect();
for fc in &calls {
store
.record_function_call(cascade_hint.as_deref(), fc.clone())
.await;
}
store.set_last_function_calls(calls.clone()).await; store.set_last_function_calls(calls.clone()).await;
info!( store.record_function_call(cascade_hint.as_deref(), calls[0].clone()).await;
"MITM: stored {} function call(s) from initial body", info!("MITM: sending {} function call(s) via channel (initial body)", calls.len());
calls.len() let _ = tx.send(super::store::MitmEvent::FunctionCall(calls)).await;
);
}
// Capture response + thinking text + grounding into MitmStore
if !streaming_acc.response_text.is_empty() {
store.set_response_text(&streaming_acc.response_text).await;
} }
// Thinking delta → channel event
if !streaming_acc.thinking_text.is_empty() { if !streaming_acc.thinking_text.is_empty() {
store.set_thinking_text(&streaming_acc.thinking_text).await; let _ = tx.send(super::store::MitmEvent::ThinkingDelta(
streaming_acc.thinking_text.clone(),
)).await;
} }
// Text delta → channel event
if !streaming_acc.response_text.is_empty() {
let _ = tx.send(super::store::MitmEvent::TextDelta(
streaming_acc.response_text.clone(),
)).await;
}
// Grounding → channel event
if let Some(ref gm) = streaming_acc.grounding_metadata { if let Some(ref gm) = streaming_acc.grounding_metadata {
store.set_grounding(gm.clone()).await; store.set_grounding(gm.clone()).await;
let _ = tx.send(super::store::MitmEvent::Grounding(gm.clone())).await;
} }
// Response complete → channel event
if streaming_acc.is_complete { if streaming_acc.is_complete {
info!( info!(
response_text_len = streaming_acc.response_text.len(), response_text_len = streaming_acc.response_text.len(),
thinking_text_len = streaming_acc.thinking_text.len(), thinking_text_len = streaming_acc.thinking_text.len(),
"MITM: response complete (initial body) — marking store" "MITM: response complete (initial body) — sending via channel"
); );
store.mark_response_complete(); let _ = tx.send(super::store::MitmEvent::ResponseComplete).await;
streaming_acc.is_complete = false; // prevent duplicate sends
}
} else {
// Legacy path: store writes for non-channel consumers (search, etc.)
if !streaming_acc.function_calls.is_empty() {
let calls: Vec<_> = streaming_acc.function_calls.drain(..).collect();
for fc in &calls {
store.record_function_call(cascade_hint.as_deref(), fc.clone()).await;
}
store.set_last_function_calls(calls.clone()).await;
info!("MITM: stored {} function call(s) from initial body", calls.len());
}
if !streaming_acc.response_text.is_empty() {
store.set_response_text(&streaming_acc.response_text).await;
}
if let Some(ref gm) = streaming_acc.grounding_metadata {
store.set_grounding(gm.clone()).await;
} }
} else if streaming_acc.is_complete {
debug!("MITM: skipping store write — generation stale (initial body)");
} }
} }
@@ -890,45 +909,60 @@ async fn handle_http_over_tls(
let s = String::from_utf8_lossy(chunk); let s = String::from_utf8_lossy(chunk);
parse_streaming_chunk(&s, &mut streaming_acc); parse_streaming_chunk(&s, &mut streaming_acc);
// Only write to store if our generation is still current. // Send events through channel if available, otherwise use legacy store
let gen_valid = !won_gate || store.current_generation() == conn_generation; if let Some(ref tx) = event_tx {
if gen_valid { // Function calls → channel event
// Store captured function calls (drain to avoid re-storing on next chunk)
if !streaming_acc.function_calls.is_empty() { if !streaming_acc.function_calls.is_empty() {
let calls: Vec<_> = streaming_acc.function_calls.drain(..).collect(); let calls: Vec<_> = streaming_acc.function_calls.drain(..).collect();
for fc in &calls {
store
.record_function_call(cascade_hint.as_deref(), fc.clone())
.await;
}
store.set_last_function_calls(calls.clone()).await; store.set_last_function_calls(calls.clone()).await;
info!( store.record_function_call(cascade_hint.as_deref(), calls[0].clone()).await;
"MITM: stored {} function call(s) from body chunk", info!("MITM: sending {} function call(s) via channel (body chunk)", calls.len());
calls.len() let _ = tx.send(super::store::MitmEvent::FunctionCall(calls)).await;
);
}
// Capture response + thinking text + grounding into MitmStore
if !streaming_acc.response_text.is_empty() {
store.set_response_text(&streaming_acc.response_text).await;
} }
// Thinking delta → channel event (send accumulated, handler tracks last len)
if !streaming_acc.thinking_text.is_empty() { if !streaming_acc.thinking_text.is_empty() {
store.set_thinking_text(&streaming_acc.thinking_text).await; let _ = tx.send(super::store::MitmEvent::ThinkingDelta(
streaming_acc.thinking_text.clone(),
)).await;
} }
// Text delta → channel event (send accumulated, handler tracks last len)
if !streaming_acc.response_text.is_empty() {
let _ = tx.send(super::store::MitmEvent::TextDelta(
streaming_acc.response_text.clone(),
)).await;
}
// Grounding → channel event
if let Some(ref gm) = streaming_acc.grounding_metadata { if let Some(ref gm) = streaming_acc.grounding_metadata {
store.set_grounding(gm.clone()).await; store.set_grounding(gm.clone()).await;
let _ = tx.send(super::store::MitmEvent::Grounding(gm.clone())).await;
} }
// Response complete → channel event
if streaming_acc.is_complete { if streaming_acc.is_complete {
info!( info!(
response_text_len = streaming_acc.response_text.len(), response_text_len = streaming_acc.response_text.len(),
thinking_text_len = streaming_acc.thinking_text.len(), thinking_text_len = streaming_acc.thinking_text.len(),
function_calls = streaming_acc.function_calls.len(), function_calls = streaming_acc.function_calls.len(),
"MITM: response complete — marking store" "MITM: response complete — sending via channel"
); );
store.mark_response_complete(); let _ = tx.send(super::store::MitmEvent::ResponseComplete).await;
streaming_acc.is_complete = false; // prevent duplicate sends
}
} else {
// Legacy path: store writes for non-channel consumers
if !streaming_acc.function_calls.is_empty() {
let calls: Vec<_> = streaming_acc.function_calls.drain(..).collect();
for fc in &calls {
store.record_function_call(cascade_hint.as_deref(), fc.clone()).await;
}
store.set_last_function_calls(calls.clone()).await;
info!("MITM: stored {} function call(s) from body chunk", calls.len());
}
if !streaming_acc.response_text.is_empty() {
store.set_response_text(&streaming_acc.response_text).await;
}
if let Some(ref gm) = streaming_acc.grounding_metadata {
store.set_grounding(gm.clone()).await;
} }
} else if streaming_acc.is_complete {
debug!("MITM: skipping store write — generation stale (body chunk)");
} }
} }
@@ -969,9 +1003,11 @@ async fn handle_http_over_tls(
store.set_grounding(gm.clone()).await; store.set_grounding(gm.clone()).await;
} }
if streaming_acc.is_complete || streaming_acc.output_tokens > 0 { if streaming_acc.is_complete || streaming_acc.output_tokens > 0 {
// Function calls are stored immediately when detected (above),
// so no need to store them again here.
let usage = streaming_acc.into_usage(); let usage = streaming_acc.into_usage();
// Send usage through channel if available
if let Some(ref tx) = event_tx {
let _ = tx.send(super::store::MitmEvent::Usage(usage.clone())).await;
}
store.record_usage(cascade_hint.as_deref(), usage).await; store.record_usage(cascade_hint.as_deref(), usage).await;
} }
} else if !response_body_buf.is_empty() { } else if !response_body_buf.is_empty() {

View File

@@ -1,12 +1,14 @@
//! Shared store for intercepted API usage data. //! Shared store for intercepted API usage data.
//! //!
//! The MITM proxy writes usage data here; the API handlers read from it. //! The MITM proxy writes usage data here; the API handlers read from it.
//! When custom tools are active, the MITM proxy sends real-time events
//! through a channel instead of writing to shared state.
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::collections::HashMap; use std::collections::HashMap;
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering}; use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc; use std::sync::Arc;
use tokio::sync::RwLock; use tokio::sync::{mpsc, RwLock};
use tracing::{debug, info}; use tracing::{debug, info};
/// Token usage from an intercepted API response. /// Token usage from an intercepted API response.
@@ -126,6 +128,29 @@ pub struct GenerationParams {
pub google_search: bool, pub google_search: bool,
} }
// ─── Channel-based event pipeline ────────────────────────────────────────────
/// Events sent from the MITM proxy to API handlers through a per-request channel.
/// Replaces the old polling-based approach (shared atomics + RwLocks) with
/// instant, race-free delivery.
#[derive(Debug, Clone)]
pub enum MitmEvent {
/// Incremental thinking/reasoning text from the model.
ThinkingDelta(String),
/// Incremental response text from the model.
TextDelta(String),
/// Model requested function call(s).
FunctionCall(Vec<CapturedFunctionCall>),
/// Response streaming is complete (finishReason received).
ResponseComplete,
/// Google API returned an error.
UpstreamError(UpstreamError),
/// Grounding metadata (search results) from the response.
Grounding(serde_json::Value),
/// Token usage data from the response.
Usage(ApiUsage),
}
/// Thread-safe store for intercepted data. /// Thread-safe store for intercepted data.
/// ///
/// Keyed by a unique request ID that we can correlate with cascade operations. /// Keyed by a unique request ID that we can correlate with cascade operations.
@@ -138,20 +163,17 @@ pub struct MitmStore {
stats: Arc<RwLock<MitmStats>>, stats: Arc<RwLock<MitmStats>>,
/// Pending function calls captured from Google responses. /// Pending function calls captured from Google responses.
/// Key: cascade hint or "_latest". Value: list of function calls. /// Key: cascade hint or "_latest". Value: list of function calls.
/// Used by the non-tool LS path (normal sync responses).
pending_function_calls: Arc<RwLock<HashMap<String, Vec<CapturedFunctionCall>>>>, pending_function_calls: Arc<RwLock<HashMap<String, Vec<CapturedFunctionCall>>>>,
/// Simple flag: set when a functionCall is captured, cleared when consumed.
/// Used to block follow-up requests regardless of cascade identification.
has_active_function_call: Arc<AtomicBool>,
/// Persistent flag: set when a function call is captured, cleared ONLY when
/// a tool result is submitted. Prevents the LS from making follow-up API
/// calls during the entire tool execution cycle.
awaiting_tool_result: Arc<AtomicBool>,
/// Set when the MITM forwards the first LLM request with custom tools. /// Set when the MITM forwards the first LLM request with custom tools.
/// Blocks ALL subsequent LS requests until the API handler clears it. /// Blocks ALL subsequent LS requests until the API handler clears it.
request_in_flight: Arc<AtomicBool>, request_in_flight: Arc<AtomicBool>,
/// Generation counter — incremented each time a new completions turn starts.
/// Used to discard stale data from leaked LS connections. // ── Channel-based event pipeline (replaces old polling) ──────────────
request_generation: Arc<AtomicU64>, /// Active channel sender for the current tool-path request.
/// When present, the MITM proxy sends events through this instead of
/// writing to shared state. The channel's existence = request in-flight.
active_channel: Arc<RwLock<Option<mpsc::Sender<MitmEvent>>>>,
// ── Tool call support ──────────────────────────────────────────────── // ── Tool call support ────────────────────────────────────────────────
/// Active tool definitions (Gemini format) for MITM injection. /// Active tool definitions (Gemini format) for MITM injection.
@@ -173,14 +195,9 @@ pub struct MitmStore {
/// Used by the MITM proxy to correlate intercepted traffic to cascades. /// Used by the MITM proxy to correlate intercepted traffic to cascades.
active_cascade_id: Arc<RwLock<Option<String>>>, active_cascade_id: Arc<RwLock<Option<String>>>,
// ── Direct response capture (bypasses LS) ──────────────────────────── // ── Legacy direct response capture (used by search.rs) ───────────────
/// Captured response text from MITM when custom tools are active. /// Captured response text from MITM. Used as fallback by search endpoint.
/// The completions/responses handler reads this instead of polling LS steps.
captured_response_text: Arc<RwLock<Option<String>>>, captured_response_text: Arc<RwLock<Option<String>>>,
/// Captured thinking/reasoning text from MITM (for real-time streaming).
captured_thinking_text: Arc<RwLock<Option<String>>>,
/// Whether the captured response is complete (finishReason received).
response_complete: Arc<AtomicBool>,
// ── Generation parameters for MITM injection ───────────────────────── // ── Generation parameters for MITM injection ─────────────────────────
/// Client-specified sampling parameters to inject into Google API requests. /// Client-specified sampling parameters to inject into Google API requests.
@@ -194,9 +211,14 @@ pub struct MitmStore {
/// Image to inject into the next Google API request via MITM. /// Image to inject into the next Google API request via MITM.
pending_image: Arc<RwLock<Option<PendingImage>>>, pending_image: Arc<RwLock<Option<PendingImage>>>,
// ── Upstream error capture ─────────────────────────────────────────── // ── Upstream error capture (legacy, used when no channel) ────────────
/// Error from Google's API, captured by MITM for forwarding to client. /// Error from Google's API, captured by MITM for forwarding to client.
upstream_error: Arc<RwLock<Option<UpstreamError>>>, upstream_error: Arc<RwLock<Option<UpstreamError>>>,
// ── Standard LS input: real user text for MITM injection ─────────────
/// The real user text to inject into the Google API request.
/// API handlers store this before sending a dummy prompt to the LS.
pending_user_text: Arc<RwLock<Option<String>>>,
} }
/// Aggregate statistics across all intercepted traffic. /// Aggregate statistics across all intercepted traffic.
@@ -229,10 +251,8 @@ impl MitmStore {
latest_usage: Arc::new(RwLock::new(HashMap::new())), latest_usage: Arc::new(RwLock::new(HashMap::new())),
stats: Arc::new(RwLock::new(MitmStats::default())), stats: Arc::new(RwLock::new(MitmStats::default())),
pending_function_calls: Arc::new(RwLock::new(HashMap::new())), pending_function_calls: Arc::new(RwLock::new(HashMap::new())),
has_active_function_call: Arc::new(AtomicBool::new(false)),
awaiting_tool_result: Arc::new(AtomicBool::new(false)),
request_in_flight: Arc::new(AtomicBool::new(false)), request_in_flight: Arc::new(AtomicBool::new(false)),
request_generation: Arc::new(AtomicU64::new(0)), active_channel: Arc::new(RwLock::new(None)),
active_tools: Arc::new(RwLock::new(None)), active_tools: Arc::new(RwLock::new(None)),
active_tool_config: Arc::new(RwLock::new(None)), active_tool_config: Arc::new(RwLock::new(None)),
pending_tool_results: Arc::new(RwLock::new(Vec::new())), pending_tool_results: Arc::new(RwLock::new(Vec::new())),
@@ -241,12 +261,11 @@ impl MitmStore {
tool_rounds: Arc::new(RwLock::new(Vec::new())), tool_rounds: Arc::new(RwLock::new(Vec::new())),
active_cascade_id: Arc::new(RwLock::new(None)), active_cascade_id: Arc::new(RwLock::new(None)),
captured_response_text: Arc::new(RwLock::new(None)), captured_response_text: Arc::new(RwLock::new(None)),
captured_thinking_text: Arc::new(RwLock::new(None)),
response_complete: Arc::new(AtomicBool::new(false)),
generation_params: Arc::new(RwLock::new(None)), generation_params: Arc::new(RwLock::new(None)),
captured_grounding: Arc::new(RwLock::new(None)), captured_grounding: Arc::new(RwLock::new(None)),
pending_image: Arc::new(RwLock::new(None)), pending_image: Arc::new(RwLock::new(None)),
upstream_error: Arc::new(RwLock::new(None)), upstream_error: Arc::new(RwLock::new(None)),
pending_user_text: Arc::new(RwLock::new(None)),
} }
} }
@@ -381,30 +400,6 @@ impl MitmStore {
); );
let mut pending = self.pending_function_calls.write().await; let mut pending = self.pending_function_calls.write().await;
pending.entry(key).or_default().push(fc); pending.entry(key).or_default().push(fc);
self.has_active_function_call.store(true, Ordering::SeqCst);
self.awaiting_tool_result.store(true, Ordering::SeqCst);
}
/// Check if there's an active (unclaimed) function call.
pub fn has_active_function_call(&self) -> bool {
self.has_active_function_call.load(Ordering::SeqCst)
}
/// Force-clear the active function call flag (used to reset stale state).
pub fn clear_active_function_call(&self) {
self.has_active_function_call.store(false, Ordering::SeqCst);
}
/// Check if we're awaiting a tool result (blocks LS follow-up requests).
/// This persists across function call consumption — only cleared when
/// actual tool results are submitted.
pub fn is_awaiting_tool_result(&self) -> bool {
self.awaiting_tool_result.load(Ordering::SeqCst)
}
/// Clear the awaiting-tool-result flag (called when tool results arrive).
pub fn clear_awaiting_tool_result(&self) {
self.awaiting_tool_result.store(false, Ordering::SeqCst);
} }
/// Take pending function calls for a specific cascade. /// Take pending function calls for a specific cascade.
@@ -417,7 +412,6 @@ impl MitmStore {
// 1. Exact cascade match // 1. Exact cascade match
if let Some(result) = pending.remove(cascade_id) { if let Some(result) = pending.remove(cascade_id) {
self.has_active_function_call.store(false, Ordering::SeqCst);
return Some(result); return Some(result);
} }
@@ -425,7 +419,6 @@ impl MitmStore {
if let Some(active) = self.active_cascade_id.read().await.as_ref() { if let Some(active) = self.active_cascade_id.read().await.as_ref() {
if active != cascade_id { if active != cascade_id {
if let Some(result) = pending.remove(active.as_str()) { if let Some(result) = pending.remove(active.as_str()) {
self.has_active_function_call.store(false, Ordering::SeqCst);
return Some(result); return Some(result);
} }
} }
@@ -433,17 +426,12 @@ impl MitmStore {
// 3. Fallback to _latest // 3. Fallback to _latest
if let Some(result) = pending.remove("_latest") { if let Some(result) = pending.remove("_latest") {
self.has_active_function_call.store(false, Ordering::SeqCst);
return Some(result); return Some(result);
} }
// 4. Last resort: any key // 4. Last resort: any key
if let Some(key) = pending.keys().next().cloned() { if let Some(key) = pending.keys().next().cloned() {
let result = pending.remove(&key); return pending.remove(&key);
if result.is_some() {
self.has_active_function_call.store(false, Ordering::SeqCst);
}
return result;
} }
None None
@@ -455,19 +443,40 @@ impl MitmStore {
let mut pending = self.pending_function_calls.write().await; let mut pending = self.pending_function_calls.write().await;
let result = pending.remove("_latest"); let result = pending.remove("_latest");
if result.is_some() { if result.is_some() {
self.has_active_function_call.store(false, Ordering::SeqCst);
return result; return result;
} }
if let Some(key) = pending.keys().next().cloned() { if let Some(key) = pending.keys().next().cloned() {
let result = pending.remove(&key); return pending.remove(&key);
if result.is_some() {
self.has_active_function_call.store(false, Ordering::SeqCst);
}
return result;
} }
None None
} }
// ── Channel-based event pipeline ─────────────────────────────────────
/// Install a channel sender for the current tool-path request.
/// The MITM proxy will send events through this channel.
pub async fn set_channel(&self, tx: mpsc::Sender<MitmEvent>) {
*self.active_channel.write().await = Some(tx);
// NOTE: Do NOT set request_in_flight here. The MITM proxy's
// try_mark_request_in_flight() is the sole setter — setting it
// here causes compare_exchange(false,true) to always fail,
// blocking every real LS request.
}
/// Take the active channel sender (used by MITM proxy to grab it).
/// Returns None if no channel is active.
pub async fn take_channel(&self) -> Option<mpsc::Sender<MitmEvent>> {
self.active_channel.write().await.take()
}
/// Drop the active channel and clear in-flight state.
/// Called when the API handler is done with the current request.
pub async fn drop_channel(&self) {
*self.active_channel.write().await = None;
self.request_in_flight.store(false, Ordering::SeqCst);
}
// ── Tool context methods ───────────────────────────────────────────── // ── Tool context methods ─────────────────────────────────────────────
/// Set active tool definitions (already in Gemini format). /// Set active tool definitions (already in Gemini format).
@@ -546,10 +555,10 @@ impl MitmStore {
self.tool_rounds.read().await.clone() self.tool_rounds.read().await.clone()
} }
// ── Legacy direct response capture (search.rs fallback) ──────────────
// ── Direct response capture (bypass LS) ──────────────────────────────
/// Set (replace) the captured response text. /// Set (replace) the captured response text.
/// Used by MITM proxy for non-channel path (search endpoint fallback).
pub async fn set_response_text(&self, text: &str) { pub async fn set_response_text(&self, text: &str) {
*self.captured_response_text.write().await = Some(text.to_string()); *self.captured_response_text.write().await = Some(text.to_string());
} }
@@ -559,28 +568,12 @@ impl MitmStore {
self.captured_response_text.write().await.take() self.captured_response_text.write().await.take()
} }
/// Peek at the captured response text without consuming it. /// Clear stale state between requests.
pub async fn peek_response_text(&self) -> Option<String> { /// Drops any active channel and clears in-flight flags.
self.captured_response_text.read().await.clone()
}
/// Mark the response as complete.
pub fn mark_response_complete(&self) {
self.response_complete.store(true, Ordering::SeqCst);
}
/// Check if the response is complete.
pub fn is_response_complete(&self) -> bool {
self.response_complete.load(Ordering::SeqCst)
}
/// Async version of clear_response. Bumps generation counter.
pub async fn clear_response_async(&self) { pub async fn clear_response_async(&self) {
self.response_complete.store(false, Ordering::SeqCst);
self.request_in_flight.store(false, Ordering::SeqCst); self.request_in_flight.store(false, Ordering::SeqCst);
self.request_generation.fetch_add(1, Ordering::SeqCst); *self.active_channel.write().await = None;
*self.captured_response_text.write().await = None; *self.captured_response_text.write().await = None;
*self.captured_thinking_text.write().await = None;
} }
/// Atomically try to mark request as in-flight. /// Atomically try to mark request as in-flight.
@@ -593,6 +586,7 @@ impl MitmStore {
} }
/// Check if a request is currently in-flight. /// Check if a request is currently in-flight.
#[allow(dead_code)]
pub fn is_request_in_flight(&self) -> bool { pub fn is_request_in_flight(&self) -> bool {
self.request_in_flight.load(Ordering::SeqCst) self.request_in_flight.load(Ordering::SeqCst)
} }
@@ -602,38 +596,6 @@ impl MitmStore {
self.request_in_flight.store(false, Ordering::SeqCst); self.request_in_flight.store(false, Ordering::SeqCst);
} }
/// Reset response_complete so we can wait for the next response.
pub fn clear_response_complete(&self) {
self.response_complete.store(false, Ordering::SeqCst);
}
/// Get current generation number.
pub fn current_generation(&self) -> u64 {
self.request_generation.load(Ordering::SeqCst)
}
/// Bump generation counter (invalidates all pending data from old generation).
pub fn bump_generation(&self) -> u64 {
self.request_generation.fetch_add(1, Ordering::SeqCst) + 1
}
// ── Thinking text capture ────────────────────────────────────────────
/// Set (replace) the captured thinking text.
pub async fn set_thinking_text(&self, text: &str) {
*self.captured_thinking_text.write().await = Some(text.to_string());
}
/// Peek at the captured thinking text without consuming it.
pub async fn peek_thinking_text(&self) -> Option<String> {
self.captured_thinking_text.read().await.clone()
}
/// Take the captured thinking text (consumes it).
pub async fn take_thinking_text(&self) -> Option<String> {
self.captured_thinking_text.write().await.take()
}
// ── Cascade correlation ────────────────────────────────────────────── // ── Cascade correlation ──────────────────────────────────────────────
/// Set the active cascade ID (called by API handlers before sending a message). /// Set the active cascade ID (called by API handlers before sending a message).
@@ -718,4 +680,18 @@ impl MitmStore {
pub async fn clear_upstream_error(&self) { pub async fn clear_upstream_error(&self) {
*self.upstream_error.write().await = None; *self.upstream_error.write().await = None;
} }
// ── Pending user text for MITM injection ─────────────────────────────
/// Store the real user text for MITM injection.
/// Called by API handlers before sending a dummy prompt to the LS.
pub async fn set_pending_user_text(&self, text: String) {
*self.pending_user_text.write().await = Some(text);
}
/// Take (consume) the pending user text.
/// Called by the MITM proxy when building ToolContext.
pub async fn take_pending_user_text(&self) -> Option<String> {
self.pending_user_text.write().await.take()
}
} }