refactor: endpoint parity and proxy improvements
Mixed changes from recent sessions: endpoint feature parity improvements, proxy bug fixes, and store cleanup.
This commit is contained in:
@@ -223,7 +223,6 @@ pub(crate) async fn handle_completions(
|
|||||||
} else {
|
} else {
|
||||||
state.mitm_store.clear_tools().await;
|
state.mitm_store.clear_tools().await;
|
||||||
}
|
}
|
||||||
state.mitm_store.clear_active_function_call();
|
|
||||||
|
|
||||||
// ── Extract tool results from messages for MITM injection ──────────
|
// ── Extract tool results from messages for MITM injection ──────────
|
||||||
// When OpenCode sends back tool results, the messages array contains:
|
// When OpenCode sends back tool results, the messages array contains:
|
||||||
@@ -340,7 +339,6 @@ pub(crate) async fn handle_completions(
|
|||||||
state.mitm_store.set_last_function_calls(last_round.calls.clone()).await;
|
state.mitm_store.set_last_function_calls(last_round.calls.clone()).await;
|
||||||
}
|
}
|
||||||
state.mitm_store.set_tool_rounds(rounds).await;
|
state.mitm_store.set_tool_rounds(rounds).await;
|
||||||
state.mitm_store.clear_awaiting_tool_result();
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -441,6 +439,8 @@ pub(crate) async fn handle_completions(
|
|||||||
|
|
||||||
// Send message on primary cascade
|
// Send message on primary cascade
|
||||||
state.mitm_store.set_active_cascade(&cascade_id).await;
|
state.mitm_store.set_active_cascade(&cascade_id).await;
|
||||||
|
// Store real user text for MITM injection — LS gets a dummy prompt
|
||||||
|
state.mitm_store.set_pending_user_text(user_text.clone()).await;
|
||||||
// Store image for MITM injection (LS doesn't forward images to Google API)
|
// Store image for MITM injection (LS doesn't forward images to Google API)
|
||||||
if let Some(ref img) = image {
|
if let Some(ref img) = image {
|
||||||
use base64::Engine;
|
use base64::Engine;
|
||||||
@@ -452,9 +452,25 @@ pub(crate) async fn handle_completions(
|
|||||||
})
|
})
|
||||||
.await;
|
.await;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Pre-flight: install channel BEFORE send_message so the MITM proxy
|
||||||
|
// can grab it when the LS fires its API call.
|
||||||
|
// Only for streaming — sync paths use poll_for_response (legacy store).
|
||||||
|
let has_custom_tools = state.mitm_store.get_tools().await.is_some();
|
||||||
|
let mitm_rx = if has_custom_tools && body.stream {
|
||||||
|
state.mitm_store.clear_response_async().await;
|
||||||
|
state.mitm_store.clear_upstream_error().await;
|
||||||
|
let _ = state.mitm_store.take_any_function_calls().await;
|
||||||
|
let (tx, rx) = tokio::sync::mpsc::channel(64);
|
||||||
|
state.mitm_store.set_channel(tx).await;
|
||||||
|
Some(rx)
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
|
|
||||||
match state
|
match state
|
||||||
.backend
|
.backend
|
||||||
.send_message_with_image(&cascade_id, &user_text, model.model_enum, image.as_ref())
|
.send_message_with_image(&cascade_id, ".", model.model_enum, image.as_ref())
|
||||||
.await
|
.await
|
||||||
{
|
{
|
||||||
Ok((200, _)) => {
|
Ok((200, _)) => {
|
||||||
@@ -465,6 +481,7 @@ pub(crate) async fn handle_completions(
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
Ok((status, _)) => {
|
Ok((status, _)) => {
|
||||||
|
state.mitm_store.drop_channel().await;
|
||||||
return err_response(
|
return err_response(
|
||||||
StatusCode::BAD_GATEWAY,
|
StatusCode::BAD_GATEWAY,
|
||||||
format!("Backend returned {status}"),
|
format!("Backend returned {status}"),
|
||||||
@@ -472,6 +489,7 @@ pub(crate) async fn handle_completions(
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
|
state.mitm_store.drop_channel().await;
|
||||||
return err_response(
|
return err_response(
|
||||||
StatusCode::BAD_GATEWAY,
|
StatusCode::BAD_GATEWAY,
|
||||||
format!("Send failed: {e}"),
|
format!("Send failed: {e}"),
|
||||||
@@ -498,6 +516,7 @@ pub(crate) async fn handle_completions(
|
|||||||
cascade_id,
|
cascade_id,
|
||||||
body.timeout,
|
body.timeout,
|
||||||
include_usage,
|
include_usage,
|
||||||
|
mitm_rx,
|
||||||
)
|
)
|
||||||
.await
|
.await
|
||||||
} else if n <= 1 {
|
} else if n <= 1 {
|
||||||
@@ -518,7 +537,7 @@ pub(crate) async fn handle_completions(
|
|||||||
// Send the same message on each extra cascade
|
// Send the same message on each extra cascade
|
||||||
match state
|
match state
|
||||||
.backend
|
.backend
|
||||||
.send_message_with_image(&cid, &user_text, model.model_enum, image.as_ref())
|
.send_message_with_image(&cid, ".", model.model_enum, image.as_ref())
|
||||||
.await
|
.await
|
||||||
{
|
{
|
||||||
Ok((200, _)) => {
|
Ok((200, _)) => {
|
||||||
@@ -635,18 +654,18 @@ async fn chat_completions_stream(
|
|||||||
cascade_id: String,
|
cascade_id: String,
|
||||||
timeout: u64,
|
timeout: u64,
|
||||||
include_usage: bool,
|
include_usage: bool,
|
||||||
|
mitm_rx: Option<tokio::sync::mpsc::Receiver<crate::mitm::store::MitmEvent>>,
|
||||||
) -> axum::response::Response {
|
) -> axum::response::Response {
|
||||||
let stream = async_stream::stream! {
|
let stream = async_stream::stream! {
|
||||||
let start = std::time::Instant::now();
|
let start = std::time::Instant::now();
|
||||||
let mut last_text = String::new();
|
let mut last_text = String::new();
|
||||||
let has_custom_tools = state.mitm_store.get_tools().await.is_some();
|
let has_custom_tools = mitm_rx.is_some();
|
||||||
|
|
||||||
// Clear ALL stale state from previous requests
|
if !has_custom_tools {
|
||||||
state.mitm_store.clear_response_async().await;
|
state.mitm_store.clear_response_async().await;
|
||||||
state.mitm_store.clear_upstream_error().await;
|
state.mitm_store.clear_upstream_error().await;
|
||||||
state.mitm_store.clear_active_function_call();
|
let _ = state.mitm_store.take_any_function_calls().await;
|
||||||
// Drain any stale function calls from previous requests
|
}
|
||||||
let _ = state.mitm_store.take_any_function_calls().await;
|
|
||||||
|
|
||||||
// Initial role chunk
|
// Initial role chunk
|
||||||
yield Ok::<_, std::convert::Infallible>(Event::default().data(chunk_json(
|
yield Ok::<_, std::convert::Infallible>(Event::default().data(chunk_json(
|
||||||
@@ -659,7 +678,6 @@ async fn chat_completions_stream(
|
|||||||
let mut last_thinking_len: usize = 0;
|
let mut last_thinking_len: usize = 0;
|
||||||
let mut complete_polls: u32 = 0;
|
let mut complete_polls: u32 = 0;
|
||||||
let mut did_unblock_ls = false; // Prevents infinite unblock loops
|
let mut did_unblock_ls = false; // Prevents infinite unblock loops
|
||||||
let mut my_generation = state.mitm_store.current_generation();
|
|
||||||
|
|
||||||
// Helper: build usage JSON from MITM tokens
|
// Helper: build usage JSON from MITM tokens
|
||||||
let build_usage = |pt: u64, ct: u64, crt: u64, tt: u64| -> serde_json::Value {
|
let build_usage = |pt: u64, ct: u64, crt: u64, tt: u64| -> serde_json::Value {
|
||||||
@@ -672,212 +690,219 @@ async fn chat_completions_stream(
|
|||||||
})
|
})
|
||||||
};
|
};
|
||||||
|
|
||||||
while start.elapsed().as_secs() < timeout {
|
// Take ownership of the pre-installed channel receiver
|
||||||
// Check for upstream errors from MITM (Google API errors)
|
let mut rx_opt = mitm_rx;
|
||||||
if let Some(err) = state.mitm_store.take_upstream_error().await {
|
|
||||||
let error_msg = super::util::upstream_error_message(&err);
|
|
||||||
let error_type = super::util::upstream_error_type(&err);
|
|
||||||
yield Ok(Event::default().data(serde_json::to_string(&serde_json::json!({
|
|
||||||
"error": {
|
|
||||||
"message": error_msg,
|
|
||||||
"type": error_type,
|
|
||||||
"code": err.status,
|
|
||||||
}
|
|
||||||
})).unwrap()));
|
|
||||||
yield Ok(Event::default().data("[DONE]".to_string()));
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Bail if another completions handler has superseded us
|
while start.elapsed().as_secs() < timeout {
|
||||||
if state.mitm_store.current_generation() != my_generation {
|
if let Some(ref mut rx) = rx_opt {
|
||||||
debug!("Completions: generation changed (superseded), ending stream");
|
// ── Channel-based MITM pipeline ──
|
||||||
|
|
||||||
|
// Track accumulated text for delta computation
|
||||||
|
let mut acc_text = String::new();
|
||||||
|
let mut acc_thinking = String::new();
|
||||||
|
let mut last_usage: Option<crate::mitm::store::ApiUsage> = None;
|
||||||
|
|
||||||
|
'channel_loop: while let Some(event) = tokio::time::timeout(
|
||||||
|
std::time::Duration::from_secs(timeout.saturating_sub(start.elapsed().as_secs())),
|
||||||
|
rx.recv(),
|
||||||
|
).await.ok().flatten() {
|
||||||
|
use crate::mitm::store::MitmEvent;
|
||||||
|
match event {
|
||||||
|
MitmEvent::ThinkingDelta(full_thinking) => {
|
||||||
|
if full_thinking.len() > acc_thinking.len() {
|
||||||
|
let delta = full_thinking[acc_thinking.len()..].to_string();
|
||||||
|
acc_thinking = full_thinking;
|
||||||
|
last_thinking_len = acc_thinking.len();
|
||||||
|
yield Ok(Event::default().data(chunk_json(
|
||||||
|
&completion_id, &model_name,
|
||||||
|
serde_json::json!([chunk_choice(0, serde_json::json!({"reasoning_content": delta}), None)]),
|
||||||
|
None,
|
||||||
|
)));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
MitmEvent::TextDelta(full_text) => {
|
||||||
|
if full_text.len() > acc_text.len() {
|
||||||
|
let delta = full_text[acc_text.len()..].to_string();
|
||||||
|
acc_text = full_text;
|
||||||
|
last_text = acc_text.clone();
|
||||||
|
yield Ok(Event::default().data(chunk_json(
|
||||||
|
&completion_id, &model_name,
|
||||||
|
serde_json::json!([chunk_choice(0, serde_json::json!({"content": delta}), None)]),
|
||||||
|
None,
|
||||||
|
)));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
MitmEvent::FunctionCall(calls) => {
|
||||||
|
let mut tool_calls = Vec::new();
|
||||||
|
for (i, fc) in calls.iter().enumerate() {
|
||||||
|
let call_id = format!(
|
||||||
|
"call_{}",
|
||||||
|
uuid::Uuid::new_v4().to_string().replace('-', "")[..24].to_string()
|
||||||
|
);
|
||||||
|
let arguments = serde_json::to_string(&fc.args).unwrap_or_default();
|
||||||
|
tool_calls.push(serde_json::json!({
|
||||||
|
"index": i,
|
||||||
|
"id": call_id,
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": fc.name,
|
||||||
|
"arguments": arguments,
|
||||||
|
},
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
yield Ok(Event::default().data(chunk_json(
|
||||||
|
&completion_id, &model_name,
|
||||||
|
serde_json::json!([chunk_choice(0, serde_json::json!({"tool_calls": tool_calls}), None)]),
|
||||||
|
None,
|
||||||
|
)));
|
||||||
|
yield Ok(Event::default().data(chunk_json(
|
||||||
|
&completion_id, &model_name,
|
||||||
|
serde_json::json!([chunk_choice(0, serde_json::json!({}), Some("tool_calls"))]),
|
||||||
|
None,
|
||||||
|
)));
|
||||||
|
if include_usage {
|
||||||
|
let mitm = state.mitm_store.take_usage(&cascade_id).await
|
||||||
|
.or(state.mitm_store.take_usage("_latest").await);
|
||||||
|
let (pt, ct, crt, tt) = if let Some(ref u) = mitm {
|
||||||
|
(u.input_tokens, u.output_tokens, u.cache_read_input_tokens, u.thinking_output_tokens)
|
||||||
|
} else if let Some(ref u) = last_usage {
|
||||||
|
(u.input_tokens, u.output_tokens, u.cache_read_input_tokens, u.thinking_output_tokens)
|
||||||
|
} else { (0, 0, 0, 0) };
|
||||||
|
yield Ok(Event::default().data(chunk_json(
|
||||||
|
&completion_id, &model_name,
|
||||||
|
serde_json::json!([]),
|
||||||
|
Some(build_usage(pt, ct, crt, tt)),
|
||||||
|
)));
|
||||||
|
}
|
||||||
|
yield Ok(Event::default().data("[DONE]"));
|
||||||
|
state.mitm_store.drop_channel().await;
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
MitmEvent::ResponseComplete => {
|
||||||
|
if !acc_text.is_empty() {
|
||||||
|
// Have response text — done
|
||||||
|
debug!("Completions: channel response complete, text_len={}, thinking_len={}",
|
||||||
|
acc_text.len(), acc_thinking.len());
|
||||||
|
let mitm = state.mitm_store.take_usage(&cascade_id).await
|
||||||
|
.or(state.mitm_store.take_usage("_latest").await)
|
||||||
|
.or(last_usage.take());
|
||||||
|
let fr = google_to_openai_finish_reason(mitm.as_ref().and_then(|u| u.stop_reason.as_deref()));
|
||||||
|
yield Ok(Event::default().data(chunk_json(
|
||||||
|
&completion_id, &model_name,
|
||||||
|
serde_json::json!([chunk_choice(0, serde_json::json!({}), Some(fr))]),
|
||||||
|
None,
|
||||||
|
)));
|
||||||
|
if include_usage {
|
||||||
|
let (pt, ct, crt, tt) = if let Some(ref u) = mitm {
|
||||||
|
(u.input_tokens, u.output_tokens, u.cache_read_input_tokens, u.thinking_output_tokens)
|
||||||
|
} else { (0, 0, 0, 0) };
|
||||||
|
yield Ok(Event::default().data(chunk_json(
|
||||||
|
&completion_id, &model_name,
|
||||||
|
serde_json::json!([]),
|
||||||
|
Some(build_usage(pt, ct, crt, tt)),
|
||||||
|
)));
|
||||||
|
}
|
||||||
|
yield Ok(Event::default().data("[DONE]"));
|
||||||
|
state.mitm_store.drop_channel().await;
|
||||||
|
return;
|
||||||
|
} else if !acc_thinking.is_empty() && !did_unblock_ls {
|
||||||
|
// Thinking-only response — LS needs follow-up API calls.
|
||||||
|
// Create a new channel and unblock the gate.
|
||||||
|
did_unblock_ls = true;
|
||||||
|
let (new_tx, new_rx) = tokio::sync::mpsc::channel(64);
|
||||||
|
state.mitm_store.set_channel(new_tx).await;
|
||||||
|
state.mitm_store.clear_request_in_flight();
|
||||||
|
let _ = state.mitm_store.take_any_function_calls().await;
|
||||||
|
*rx = new_rx;
|
||||||
|
debug!(
|
||||||
|
"Completions: thinking-only — new channel for follow-up, thinking_len={}",
|
||||||
|
acc_thinking.len()
|
||||||
|
);
|
||||||
|
continue 'channel_loop;
|
||||||
|
} else if !acc_thinking.is_empty() && did_unblock_ls {
|
||||||
|
// Already unblocked once, still thinking-only.
|
||||||
|
// Wait a bit for potential follow-up events.
|
||||||
|
complete_polls += 1;
|
||||||
|
if complete_polls >= 25 {
|
||||||
|
info!("Completions: thinking-only timeout, thinking_len={}", acc_thinking.len());
|
||||||
|
let mitm = state.mitm_store.take_usage(&cascade_id).await
|
||||||
|
.or(state.mitm_store.take_usage("_latest").await)
|
||||||
|
.or(last_usage.take());
|
||||||
|
let fr = google_to_openai_finish_reason(mitm.as_ref().and_then(|u| u.stop_reason.as_deref()));
|
||||||
|
yield Ok(Event::default().data(chunk_json(
|
||||||
|
&completion_id, &model_name,
|
||||||
|
serde_json::json!([chunk_choice(0, serde_json::json!({}), Some(fr))]),
|
||||||
|
None,
|
||||||
|
)));
|
||||||
|
if include_usage {
|
||||||
|
let (pt, ct, crt, tt) = if let Some(ref u) = mitm {
|
||||||
|
(u.input_tokens, u.output_tokens, u.cache_read_input_tokens, u.thinking_output_tokens)
|
||||||
|
} else { (0, 0, 0, 0) };
|
||||||
|
yield Ok(Event::default().data(chunk_json(
|
||||||
|
&completion_id, &model_name,
|
||||||
|
serde_json::json!([]),
|
||||||
|
Some(build_usage(pt, ct, crt, tt)),
|
||||||
|
)));
|
||||||
|
}
|
||||||
|
yield Ok(Event::default().data("[DONE]"));
|
||||||
|
state.mitm_store.drop_channel().await;
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
// Don't break — wait for more channel events
|
||||||
|
continue 'channel_loop;
|
||||||
|
} else {
|
||||||
|
// Empty response (no text, no thinking, no tools)
|
||||||
|
complete_polls += 1;
|
||||||
|
if complete_polls >= 4 {
|
||||||
|
info!("Completions: channel response complete but empty, ending stream");
|
||||||
|
yield Ok(Event::default().data(chunk_json(
|
||||||
|
&completion_id, &model_name,
|
||||||
|
serde_json::json!([chunk_choice(0, serde_json::json!({}), Some("stop"))]),
|
||||||
|
None,
|
||||||
|
)));
|
||||||
|
yield Ok(Event::default().data("[DONE]"));
|
||||||
|
state.mitm_store.drop_channel().await;
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
continue 'channel_loop;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
MitmEvent::UpstreamError(err) => {
|
||||||
|
let error_msg = super::util::upstream_error_message(&err);
|
||||||
|
let error_type = super::util::upstream_error_type(&err);
|
||||||
|
yield Ok(Event::default().data(serde_json::to_string(&serde_json::json!({
|
||||||
|
"error": {
|
||||||
|
"message": error_msg,
|
||||||
|
"type": error_type,
|
||||||
|
"code": err.status,
|
||||||
|
}
|
||||||
|
})).unwrap()));
|
||||||
|
yield Ok(Event::default().data("[DONE]"));
|
||||||
|
state.mitm_store.drop_channel().await;
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
MitmEvent::Usage(u) => {
|
||||||
|
last_usage = Some(u);
|
||||||
|
}
|
||||||
|
MitmEvent::Grounding(_) => {
|
||||||
|
// Grounding metadata handled by store directly
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Channel closed or timeout — clean up
|
||||||
|
state.mitm_store.drop_channel().await;
|
||||||
|
|
||||||
|
// Reached on channel close or timeout: if any text or thinking was streamed, emit a final "stop" chunk
|
||||||
|
if !last_text.is_empty() || last_thinking_len > 0 {
|
||||||
|
yield Ok(Event::default().data(chunk_json(
|
||||||
|
&completion_id, &model_name,
|
||||||
|
serde_json::json!([chunk_choice(0, serde_json::json!({}), Some("stop"))]),
|
||||||
|
None,
|
||||||
|
)));
|
||||||
|
}
|
||||||
yield Ok(Event::default().data("[DONE]"));
|
yield Ok(Event::default().data("[DONE]"));
|
||||||
return;
|
return;
|
||||||
}
|
|
||||||
|
|
||||||
// ── Check for MITM-captured function calls FIRST ──
|
|
||||||
// This runs independently of LS steps — the MITM captures tool calls
|
|
||||||
// at the proxy layer, so we don't need to wait for LS processing.
|
|
||||||
let captured = state.mitm_store.take_function_calls(&cascade_id).await;
|
|
||||||
if let Some(ref calls) = captured {
|
|
||||||
if !calls.is_empty() {
|
|
||||||
let mut tool_calls = Vec::new();
|
|
||||||
for (i, fc) in calls.iter().enumerate() {
|
|
||||||
let call_id = format!(
|
|
||||||
"call_{}",
|
|
||||||
uuid::Uuid::new_v4().to_string().replace('-', "")[..24].to_string()
|
|
||||||
);
|
|
||||||
let arguments = serde_json::to_string(&fc.args).unwrap_or_default();
|
|
||||||
tool_calls.push(serde_json::json!({
|
|
||||||
"index": i,
|
|
||||||
"id": call_id,
|
|
||||||
"type": "function",
|
|
||||||
"function": {
|
|
||||||
"name": fc.name,
|
|
||||||
"arguments": arguments,
|
|
||||||
},
|
|
||||||
}));
|
|
||||||
}
|
|
||||||
yield Ok(Event::default().data(chunk_json(
|
|
||||||
&completion_id, &model_name,
|
|
||||||
serde_json::json!([chunk_choice(0, serde_json::json!({"tool_calls": tool_calls}), None)]),
|
|
||||||
None,
|
|
||||||
)));
|
|
||||||
|
|
||||||
yield Ok(Event::default().data(chunk_json(
|
|
||||||
&completion_id, &model_name,
|
|
||||||
serde_json::json!([chunk_choice(0, serde_json::json!({}), Some("tool_calls"))]),
|
|
||||||
None,
|
|
||||||
)));
|
|
||||||
if include_usage {
|
|
||||||
let mitm = state.mitm_store.take_usage(&cascade_id).await
|
|
||||||
.or(state.mitm_store.take_usage("_latest").await);
|
|
||||||
let (pt, ct, crt, tt) = if let Some(ref u) = mitm {
|
|
||||||
(u.input_tokens, u.output_tokens, u.cache_read_input_tokens, u.thinking_output_tokens)
|
|
||||||
} else { (0, 0, 0, 0) };
|
|
||||||
yield Ok(Event::default().data(chunk_json(
|
|
||||||
&completion_id, &model_name,
|
|
||||||
serde_json::json!([]),
|
|
||||||
Some(build_usage(pt, ct, crt, tt)),
|
|
||||||
)));
|
|
||||||
}
|
|
||||||
yield Ok(Event::default().data("[DONE]"));
|
|
||||||
// Clear in-flight flag so the next turn's requests can get through
|
|
||||||
state.mitm_store.clear_response_async().await;
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// ── Primary: MITM-captured response (when custom tools are active) ──
|
|
||||||
// The MITM intercepts the real Google SSE stream and captures text,
|
|
||||||
// thinking, and function calls. This is the authoritative data source.
|
|
||||||
// The LS only gets rewritten responses (function calls → text placeholders)
|
|
||||||
// so it doesn't provide useful streaming data when MITM is active.
|
|
||||||
if has_custom_tools {
|
|
||||||
// Stream thinking text as reasoning_content deltas
|
|
||||||
if let Some(tc) = state.mitm_store.peek_thinking_text().await {
|
|
||||||
if tc.len() > last_thinking_len {
|
|
||||||
let delta = &tc[last_thinking_len..];
|
|
||||||
last_thinking_len = tc.len();
|
|
||||||
|
|
||||||
yield Ok(Event::default().data(chunk_json(
|
|
||||||
&completion_id, &model_name,
|
|
||||||
serde_json::json!([chunk_choice(0, serde_json::json!({"reasoning_content": delta}), None)]),
|
|
||||||
None,
|
|
||||||
)));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Stream response text as content deltas
|
|
||||||
if let Some(text) = state.mitm_store.peek_response_text().await {
|
|
||||||
if !text.is_empty() && text != last_text {
|
|
||||||
let delta = if text.len() > last_text.len() && text.starts_with(&*last_text) {
|
|
||||||
text[last_text.len()..].to_string()
|
|
||||||
} else {
|
|
||||||
text.clone()
|
|
||||||
};
|
|
||||||
|
|
||||||
if !delta.is_empty() {
|
|
||||||
yield Ok(Event::default().data(chunk_json(
|
|
||||||
&completion_id, &model_name,
|
|
||||||
serde_json::json!([chunk_choice(0, serde_json::json!({"content": delta}), None)]),
|
|
||||||
None,
|
|
||||||
)));
|
|
||||||
last_text = text;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check if MITM response is complete
|
|
||||||
if state.mitm_store.is_response_complete() {
|
|
||||||
if !last_text.is_empty() {
|
|
||||||
// Have actual response text — done
|
|
||||||
complete_polls += 1;
|
|
||||||
if complete_polls >= 2 {
|
|
||||||
debug!("Completions: MITM response complete, text_len={}, thinking_len={}", last_text.len(), last_thinking_len);
|
|
||||||
let mitm = state.mitm_store.take_usage(&cascade_id).await
|
|
||||||
.or(state.mitm_store.take_usage("_latest").await);
|
|
||||||
let fr = google_to_openai_finish_reason(mitm.as_ref().and_then(|u| u.stop_reason.as_deref()));
|
|
||||||
yield Ok(Event::default().data(chunk_json(
|
|
||||||
&completion_id, &model_name,
|
|
||||||
serde_json::json!([chunk_choice(0, serde_json::json!({}), Some(fr))]),
|
|
||||||
None,
|
|
||||||
)));
|
|
||||||
if include_usage {
|
|
||||||
let (pt, ct, crt, tt) = if let Some(ref u) = mitm {
|
|
||||||
(u.input_tokens, u.output_tokens, u.cache_read_input_tokens, u.thinking_output_tokens)
|
|
||||||
} else { (0, 0, 0, 0) };
|
|
||||||
yield Ok(Event::default().data(chunk_json(
|
|
||||||
&completion_id, &model_name,
|
|
||||||
serde_json::json!([]),
|
|
||||||
Some(build_usage(pt, ct, crt, tt)),
|
|
||||||
)));
|
|
||||||
}
|
|
||||||
yield Ok(Event::default().data("[DONE]"));
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
} else if last_thinking_len > 0 && !did_unblock_ls {
|
|
||||||
// Thinking-only response. The LS needs follow-up API calls
|
|
||||||
// to get actual function calls or text. Unblock once.
|
|
||||||
did_unblock_ls = true;
|
|
||||||
complete_polls = 0;
|
|
||||||
// Bump generation FIRST — invalidates old MITM connection's store writes
|
|
||||||
my_generation = state.mitm_store.bump_generation();
|
|
||||||
state.mitm_store.clear_request_in_flight();
|
|
||||||
state.mitm_store.clear_response_complete();
|
|
||||||
// Drain store so leaked connections can't produce stale content
|
|
||||||
state.mitm_store.set_response_text("").await;
|
|
||||||
state.mitm_store.set_thinking_text("").await;
|
|
||||||
let _ = state.mitm_store.take_any_function_calls().await;
|
|
||||||
debug!(
|
|
||||||
"Completions: thinking-only — unblocking LS for follow-up, thinking_len={}, new_gen={}",
|
|
||||||
last_thinking_len, my_generation
|
|
||||||
);
|
|
||||||
} else if last_thinking_len > 0 && did_unblock_ls {
|
|
||||||
// Already unblocked once. Still only thinking after follow-up.
|
|
||||||
complete_polls += 1;
|
|
||||||
if complete_polls >= 25 {
|
|
||||||
info!("Completions: thinking-only timeout after ~10s, thinking_len={}", last_thinking_len);
|
|
||||||
let mitm = state.mitm_store.take_usage(&cascade_id).await
|
|
||||||
.or(state.mitm_store.take_usage("_latest").await);
|
|
||||||
let fr = google_to_openai_finish_reason(mitm.as_ref().and_then(|u| u.stop_reason.as_deref()));
|
|
||||||
yield Ok(Event::default().data(chunk_json(
|
|
||||||
&completion_id, &model_name,
|
|
||||||
serde_json::json!([chunk_choice(0, serde_json::json!({}), Some(fr))]),
|
|
||||||
None,
|
|
||||||
)));
|
|
||||||
if include_usage {
|
|
||||||
let (pt, ct, crt, tt) = if let Some(ref u) = mitm {
|
|
||||||
(u.input_tokens, u.output_tokens, u.cache_read_input_tokens, u.thinking_output_tokens)
|
|
||||||
} else { (0, 0, 0, 0) };
|
|
||||||
yield Ok(Event::default().data(chunk_json(
|
|
||||||
&completion_id, &model_name,
|
|
||||||
serde_json::json!([]),
|
|
||||||
Some(build_usage(pt, ct, crt, tt)),
|
|
||||||
)));
|
|
||||||
}
|
|
||||||
yield Ok(Event::default().data("[DONE]"));
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
// response_complete but no text AND no thinking — might be
|
|
||||||
// a function-call-only response that was already consumed,
|
|
||||||
// or empty response. Wait a bit then give up.
|
|
||||||
complete_polls += 1;
|
|
||||||
if complete_polls >= 4 {
|
|
||||||
info!("Completions: MITM response complete but no content (text/thinking/tools all empty), ending stream");
|
|
||||||
yield Ok(Event::default().data(chunk_json(
|
|
||||||
&completion_id, &model_name,
|
|
||||||
serde_json::json!([chunk_choice(0, serde_json::json!({}), Some("stop"))]),
|
|
||||||
None,
|
|
||||||
)));
|
|
||||||
yield Ok(Event::default().data("[DONE]"));
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
complete_polls = 0; // Reset — not complete yet
|
|
||||||
}
|
|
||||||
} else {
|
} else {
|
||||||
// ── Fallback: LS steps (no MITM capture active) ──
|
// ── Fallback: LS steps (no MITM capture active) ──
|
||||||
if let Ok((status, data)) = state.backend.get_steps(&cascade_id).await {
|
if let Ok((status, data)) = state.backend.get_steps(&cascade_id).await {
|
||||||
@@ -1001,7 +1026,7 @@ async fn chat_completions_stream(
|
|||||||
}
|
}
|
||||||
})).unwrap()));
|
})).unwrap()));
|
||||||
// Always clear in-flight flag when stream ends
|
// Always drop the MITM channel when the stream ends
|
||||||
state.mitm_store.clear_response_async().await;
|
state.mitm_store.drop_channel().await;
|
||||||
yield Ok(Event::default().data("[DONE]"));
|
yield Ok(Event::default().data("[DONE]"));
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|||||||
@@ -309,7 +309,11 @@ pub(crate) async fn handle_responses(
|
|||||||
count = tools.len(),
|
count = tools.len(),
|
||||||
"Stored client tools for MITM injection"
|
"Stored client tools for MITM injection"
|
||||||
);
|
);
|
||||||
|
} else {
|
||||||
|
state.mitm_store.clear_tools().await;
|
||||||
}
|
}
|
||||||
|
} else {
|
||||||
|
state.mitm_store.clear_tools().await;
|
||||||
}
|
}
|
||||||
if let Some(ref choice) = body.tool_choice {
|
if let Some(ref choice) = body.tool_choice {
|
||||||
let gemini_config = openai_tool_choice_to_gemini(choice);
|
let gemini_config = openai_tool_choice_to_gemini(choice);
|
||||||
@@ -404,6 +408,8 @@ pub(crate) async fn handle_responses(
|
|||||||
|
|
||||||
// Send message
|
// Send message
|
||||||
state.mitm_store.set_active_cascade(&cascade_id).await;
|
state.mitm_store.set_active_cascade(&cascade_id).await;
|
||||||
|
// Store real user text for MITM injection — LS gets a dummy prompt
|
||||||
|
state.mitm_store.set_pending_user_text(user_text.clone()).await;
|
||||||
// Store image for MITM injection (LS doesn't forward images to Google API)
|
// Store image for MITM injection (LS doesn't forward images to Google API)
|
||||||
if let Some(ref img) = image {
|
if let Some(ref img) = image {
|
||||||
use base64::Engine;
|
use base64::Engine;
|
||||||
@@ -415,9 +421,24 @@ pub(crate) async fn handle_responses(
|
|||||||
})
|
})
|
||||||
.await;
|
.await;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Pre-flight: install channel BEFORE send_message so the MITM proxy
|
||||||
|
// can grab it when the LS fires its API call.
|
||||||
|
let has_custom_tools = state.mitm_store.get_tools().await.is_some();
|
||||||
|
let mitm_rx = if has_custom_tools {
|
||||||
|
state.mitm_store.clear_response_async().await;
|
||||||
|
state.mitm_store.clear_upstream_error().await;
|
||||||
|
let _ = state.mitm_store.take_any_function_calls().await;
|
||||||
|
let (tx, rx) = tokio::sync::mpsc::channel(64);
|
||||||
|
state.mitm_store.set_channel(tx).await;
|
||||||
|
Some(rx)
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
|
|
||||||
match state
|
match state
|
||||||
.backend
|
.backend
|
||||||
.send_message_with_image(&cascade_id, &user_text, model.model_enum, image.as_ref())
|
.send_message_with_image(&cascade_id, ".", model.model_enum, image.as_ref())
|
||||||
.await
|
.await
|
||||||
{
|
{
|
||||||
Ok((200, _)) => {
|
Ok((200, _)) => {
|
||||||
@@ -428,6 +449,7 @@ pub(crate) async fn handle_responses(
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
Ok((status, _)) => {
|
Ok((status, _)) => {
|
||||||
|
state.mitm_store.drop_channel().await;
|
||||||
return err_response(
|
return err_response(
|
||||||
StatusCode::BAD_GATEWAY,
|
StatusCode::BAD_GATEWAY,
|
||||||
format!("Antigravity returned {status}"),
|
format!("Antigravity returned {status}"),
|
||||||
@@ -435,6 +457,7 @@ pub(crate) async fn handle_responses(
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
|
state.mitm_store.drop_channel().await;
|
||||||
return err_response(
|
return err_response(
|
||||||
StatusCode::BAD_GATEWAY,
|
StatusCode::BAD_GATEWAY,
|
||||||
format!("Send message failed: {e}"),
|
format!("Send message failed: {e}"),
|
||||||
@@ -472,6 +495,7 @@ pub(crate) async fn handle_responses(
|
|||||||
cascade_id,
|
cascade_id,
|
||||||
body.timeout,
|
body.timeout,
|
||||||
req_params,
|
req_params,
|
||||||
|
mitm_rx,
|
||||||
)
|
)
|
||||||
.await
|
.await
|
||||||
} else {
|
} else {
|
||||||
@@ -482,6 +506,7 @@ pub(crate) async fn handle_responses(
|
|||||||
cascade_id,
|
cascade_id,
|
||||||
body.timeout,
|
body.timeout,
|
||||||
req_params,
|
req_params,
|
||||||
|
mitm_rx,
|
||||||
)
|
)
|
||||||
.await
|
.await
|
||||||
}
|
}
|
||||||
@@ -603,54 +628,54 @@ async fn handle_responses_sync(
|
|||||||
cascade_id: String,
|
cascade_id: String,
|
||||||
timeout: u64,
|
timeout: u64,
|
||||||
params: RequestParams,
|
params: RequestParams,
|
||||||
|
mitm_rx: Option<tokio::sync::mpsc::Receiver<crate::mitm::store::MitmEvent>>,
|
||||||
) -> axum::response::Response {
|
) -> axum::response::Response {
|
||||||
let created_at = now_unix();
|
let created_at = now_unix();
|
||||||
let has_custom_tools = state.mitm_store.get_tools().await.is_some();
|
|
||||||
|
|
||||||
// Clear stale captured response and upstream errors
|
// Clear stale captured response and upstream errors (only if no pre-installed channel)
|
||||||
state.mitm_store.clear_response_async().await;
|
if mitm_rx.is_none() {
|
||||||
state.mitm_store.clear_upstream_error().await;
|
state.mitm_store.clear_response_async().await;
|
||||||
|
state.mitm_store.clear_upstream_error().await;
|
||||||
|
}
|
||||||
|
|
||||||
// ── MITM bypass: poll MitmStore directly when custom tools active ──
|
// ── MITM bypass: channel-based pipeline when custom tools active ──
|
||||||
if has_custom_tools {
|
if let Some(mut rx) = mitm_rx {
|
||||||
let start = std::time::Instant::now();
|
let start = std::time::Instant::now();
|
||||||
while start.elapsed().as_secs() < timeout {
|
|
||||||
// Check for upstream errors from MITM (Google API errors)
|
|
||||||
if let Some(err) = state.mitm_store.take_upstream_error().await {
|
|
||||||
return upstream_err_response(&err);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check for function calls
|
let mut acc_text = String::new();
|
||||||
let captured = state.mitm_store.take_function_calls(&cascade_id).await;
|
let mut acc_thinking: Option<String> = None;
|
||||||
if let Some(ref raw_calls) = captured {
|
let mut last_usage: Option<crate::mitm::store::ApiUsage> = None;
|
||||||
let calls: Vec<_> = if let Some(max) = params.max_tool_calls {
|
|
||||||
raw_calls.iter().take(max as usize).collect()
|
while let Some(event) = tokio::time::timeout(
|
||||||
} else {
|
std::time::Duration::from_secs(timeout.saturating_sub(start.elapsed().as_secs())),
|
||||||
raw_calls.iter().collect()
|
rx.recv(),
|
||||||
};
|
).await.ok().flatten() {
|
||||||
if !calls.is_empty() {
|
use crate::mitm::store::MitmEvent;
|
||||||
|
match event {
|
||||||
|
MitmEvent::ThinkingDelta(t) => { acc_thinking = Some(t); }
|
||||||
|
MitmEvent::TextDelta(t) => { acc_text = t; }
|
||||||
|
MitmEvent::Usage(u) => { last_usage = Some(u); }
|
||||||
|
MitmEvent::Grounding(_) => {} // stored by proxy directly
|
||||||
|
MitmEvent::FunctionCall(raw_calls) => {
|
||||||
|
let calls: Vec<_> = if let Some(max) = params.max_tool_calls {
|
||||||
|
raw_calls.iter().take(max as usize).collect()
|
||||||
|
} else {
|
||||||
|
raw_calls.iter().collect()
|
||||||
|
};
|
||||||
let mut output_items: Vec<serde_json::Value> = Vec::new();
|
let mut output_items: Vec<serde_json::Value> = Vec::new();
|
||||||
for fc in &calls {
|
for fc in &calls {
|
||||||
let call_id = format!(
|
let call_id = format!(
|
||||||
"call_{}",
|
"call_{}",
|
||||||
uuid::Uuid::new_v4().to_string().replace('-', "")[..24].to_string()
|
uuid::Uuid::new_v4().to_string().replace('-', "")[..24].to_string()
|
||||||
);
|
);
|
||||||
state
|
state.mitm_store.register_call_id(call_id.clone(), fc.name.clone()).await;
|
||||||
.mitm_store
|
|
||||||
.register_call_id(call_id.clone(), fc.name.clone())
|
|
||||||
.await;
|
|
||||||
let arguments = serde_json::to_string(&fc.args).unwrap_or_default();
|
let arguments = serde_json::to_string(&fc.args).unwrap_or_default();
|
||||||
output_items
|
output_items.push(build_function_call_output(&call_id, &fc.name, &arguments));
|
||||||
.push(build_function_call_output(&call_id, &fc.name, &arguments));
|
|
||||||
}
|
}
|
||||||
let (usage, _) = usage_from_poll(
|
let (usage, _) = usage_from_poll(
|
||||||
&state.mitm_store,
|
&state.mitm_store, &cascade_id, &None, &params.user_text, "",
|
||||||
&cascade_id,
|
).await;
|
||||||
&None,
|
state.mitm_store.drop_channel().await;
|
||||||
&params.user_text,
|
|
||||||
"",
|
|
||||||
)
|
|
||||||
.await;
|
|
||||||
let resp = build_response_object(
|
let resp = build_response_object(
|
||||||
ResponseData {
|
ResponseData {
|
||||||
id: response_id,
|
id: response_id,
|
||||||
@@ -666,52 +691,61 @@ async fn handle_responses_sync(
|
|||||||
);
|
);
|
||||||
return Json(resp).into_response();
|
return Json(resp).into_response();
|
||||||
}
|
}
|
||||||
}
|
MitmEvent::ResponseComplete => {
|
||||||
|
if acc_text.is_empty() && acc_thinking.is_none() {
|
||||||
|
// Empty response — continue waiting
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if acc_text.is_empty() && acc_thinking.is_some() {
|
||||||
|
// Thinking-only — LS needs to make a follow-up request.
|
||||||
|
// Reinstall channel and unblock gate.
|
||||||
|
let (new_tx, new_rx) = tokio::sync::mpsc::channel(64);
|
||||||
|
state.mitm_store.set_channel(new_tx).await;
|
||||||
|
state.mitm_store.clear_request_in_flight();
|
||||||
|
let _ = state.mitm_store.take_any_function_calls().await;
|
||||||
|
rx = new_rx;
|
||||||
|
debug!(
|
||||||
|
"Responses sync: thinking-only — new channel for follow-up, thinking_len={}",
|
||||||
|
acc_thinking.as_ref().map(|t| t.len()).unwrap_or(0)
|
||||||
|
);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
let (usage, _) = usage_from_poll(
|
||||||
|
&state.mitm_store, &cascade_id, &None, &params.user_text, &acc_text,
|
||||||
|
).await;
|
||||||
|
state.mitm_store.drop_channel().await;
|
||||||
|
|
||||||
// Check for completed text response
|
let mut output_items: Vec<serde_json::Value> = Vec::new();
|
||||||
if state.mitm_store.is_response_complete() {
|
if let Some(ref t) = acc_thinking {
|
||||||
let text = state
|
output_items.push(build_reasoning_output(t));
|
||||||
.mitm_store
|
}
|
||||||
.take_response_text()
|
let msg_id = format!("msg_{}", uuid::Uuid::new_v4().to_string().replace('-', ""));
|
||||||
.await
|
output_items.push(build_message_output(&msg_id, &acc_text));
|
||||||
.unwrap_or_default();
|
|
||||||
let thinking = state.mitm_store.take_thinking_text().await;
|
|
||||||
let (usage, _) = usage_from_poll(
|
|
||||||
&state.mitm_store,
|
|
||||||
&cascade_id,
|
|
||||||
&None,
|
|
||||||
&params.user_text,
|
|
||||||
&text,
|
|
||||||
)
|
|
||||||
.await;
|
|
||||||
|
|
||||||
let mut output_items: Vec<serde_json::Value> = Vec::new();
|
let resp = build_response_object(
|
||||||
if let Some(ref t) = thinking {
|
ResponseData {
|
||||||
output_items.push(build_reasoning_output(t));
|
id: response_id,
|
||||||
|
model: model_name,
|
||||||
|
status: "completed",
|
||||||
|
created_at,
|
||||||
|
completed_at: Some(now_unix()),
|
||||||
|
output: output_items,
|
||||||
|
usage: Some(usage),
|
||||||
|
thinking_signature: None,
|
||||||
|
},
|
||||||
|
&params,
|
||||||
|
);
|
||||||
|
return Json(resp).into_response();
|
||||||
|
}
|
||||||
|
MitmEvent::UpstreamError(err) => {
|
||||||
|
state.mitm_store.drop_channel().await;
|
||||||
|
return upstream_err_response(&err);
|
||||||
}
|
}
|
||||||
let msg_id = format!("msg_{}", uuid::Uuid::new_v4().to_string().replace('-', ""));
|
|
||||||
output_items.push(build_message_output(&msg_id, &text));
|
|
||||||
|
|
||||||
let resp = build_response_object(
|
|
||||||
ResponseData {
|
|
||||||
id: response_id,
|
|
||||||
model: model_name,
|
|
||||||
status: "completed",
|
|
||||||
created_at,
|
|
||||||
completed_at: Some(now_unix()),
|
|
||||||
output: output_items,
|
|
||||||
usage: Some(usage),
|
|
||||||
thinking_signature: None,
|
|
||||||
},
|
|
||||||
&params,
|
|
||||||
);
|
|
||||||
return Json(resp).into_response();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
tokio::time::sleep(tokio::time::Duration::from_millis(200)).await;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Timeout — return proper error, not fake incomplete response
|
// Timeout
|
||||||
|
state.mitm_store.drop_channel().await;
|
||||||
return err_response(
|
return err_response(
|
||||||
StatusCode::GATEWAY_TIMEOUT,
|
StatusCode::GATEWAY_TIMEOUT,
|
||||||
format!("Timeout: no response from Google API after {timeout}s"),
|
format!("Timeout: no response from Google API after {timeout}s"),
|
||||||
@@ -835,6 +869,7 @@ async fn handle_responses_stream(
|
|||||||
cascade_id: String,
|
cascade_id: String,
|
||||||
timeout: u64,
|
timeout: u64,
|
||||||
params: RequestParams,
|
params: RequestParams,
|
||||||
|
mitm_rx: Option<tokio::sync::mpsc::Receiver<crate::mitm::store::MitmEvent>>,
|
||||||
) -> axum::response::Response {
|
) -> axum::response::Response {
|
||||||
let stream = async_stream::stream! {
|
let stream = async_stream::stream! {
|
||||||
let msg_id = format!("msg_{}", uuid::Uuid::new_v4().to_string().replace('-', ""));
|
let msg_id = format!("msg_{}", uuid::Uuid::new_v4().to_string().replace('-', ""));
|
||||||
@@ -886,50 +921,170 @@ async fn handle_responses_stream(
|
|||||||
let mut thinking_text: Option<String> = None;
|
let mut thinking_text: Option<String> = None;
|
||||||
let mut message_started = false;
|
let mut message_started = false;
|
||||||
let reasoning_id = format!("rs_{}", uuid::Uuid::new_v4().to_string().replace('-', ""));
|
let reasoning_id = format!("rs_{}", uuid::Uuid::new_v4().to_string().replace('-', ""));
|
||||||
let has_custom_tools = state.mitm_store.get_tools().await.is_some();
|
|
||||||
|
|
||||||
// Clear stale captured response and upstream errors
|
// Clear stale response (only if no pre-installed channel)
|
||||||
state.mitm_store.clear_response_async().await;
|
if mitm_rx.is_none() {
|
||||||
state.mitm_store.clear_upstream_error().await;
|
state.mitm_store.clear_response_async().await;
|
||||||
|
state.mitm_store.clear_upstream_error().await;
|
||||||
|
}
|
||||||
|
|
||||||
// ── MITM bypass mode (when custom tools are active) ──
|
// ── MITM bypass mode (when custom tools are active) ──
|
||||||
// Skip LS entirely — read text, thinking, and tool calls directly from MitmStore.
|
// Channel-based pipeline: read events directly from MITM proxy.
|
||||||
if has_custom_tools {
|
// Channel is pre-installed before send_message to avoid race conditions.
|
||||||
|
if let Some(mut rx) = mitm_rx {
|
||||||
let mut last_thinking = String::new();
|
let mut last_thinking = String::new();
|
||||||
|
|
||||||
while start.elapsed().as_secs() < timeout {
|
while let Some(event) = tokio::time::timeout(
|
||||||
// Check for upstream errors from MITM (Google API errors)
|
std::time::Duration::from_secs(timeout.saturating_sub(start.elapsed().as_secs())),
|
||||||
if let Some(err) = state.mitm_store.take_upstream_error().await {
|
rx.recv(),
|
||||||
let error_msg = super::util::upstream_error_message(&err);
|
).await.ok().flatten() {
|
||||||
let error_type = super::util::upstream_error_type(&err);
|
use crate::mitm::store::MitmEvent;
|
||||||
yield Ok(responses_sse_event(
|
match event {
|
||||||
"response.failed",
|
MitmEvent::ThinkingDelta(full_thinking) => {
|
||||||
serde_json::json!({
|
if !thinking_emitted && full_thinking.len() > last_thinking.len() {
|
||||||
"type": "response.failed",
|
// First thinking text — emit reasoning output_item.added
|
||||||
"sequence_number": next_seq(),
|
if last_thinking.is_empty() {
|
||||||
"response": {
|
yield Ok(responses_sse_event(
|
||||||
"id": &response_id,
|
"response.output_item.added",
|
||||||
"status": "failed",
|
serde_json::json!({
|
||||||
"error": {
|
"type": "response.output_item.added",
|
||||||
"type": error_type,
|
"sequence_number": next_seq(),
|
||||||
"message": error_msg,
|
"output_index": 0,
|
||||||
"code": err.status,
|
"item": {
|
||||||
},
|
"id": &reasoning_id,
|
||||||
},
|
"type": "reasoning",
|
||||||
}),
|
"summary": [],
|
||||||
));
|
},
|
||||||
break;
|
}),
|
||||||
}
|
));
|
||||||
|
yield Ok(responses_sse_event(
|
||||||
|
"response.reasoning_summary_part.added",
|
||||||
|
serde_json::json!({
|
||||||
|
"type": "response.reasoning_summary_part.added",
|
||||||
|
"sequence_number": next_seq(),
|
||||||
|
"item_id": &reasoning_id,
|
||||||
|
"output_index": 0,
|
||||||
|
"summary_index": 0,
|
||||||
|
"part": { "type": "summary_text", "text": "" },
|
||||||
|
}),
|
||||||
|
));
|
||||||
|
}
|
||||||
|
let delta = &full_thinking[last_thinking.len()..];
|
||||||
|
if !delta.is_empty() {
|
||||||
|
yield Ok(responses_sse_event(
|
||||||
|
"response.reasoning_summary_text.delta",
|
||||||
|
serde_json::json!({
|
||||||
|
"type": "response.reasoning_summary_text.delta",
|
||||||
|
"sequence_number": next_seq(),
|
||||||
|
"item_id": &reasoning_id,
|
||||||
|
"output_index": 0,
|
||||||
|
"summary_index": 0,
|
||||||
|
"delta": delta,
|
||||||
|
}),
|
||||||
|
));
|
||||||
|
}
|
||||||
|
last_thinking = full_thinking;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
MitmEvent::TextDelta(full_text) => {
|
||||||
|
if full_text.len() > last_text.len() {
|
||||||
|
// Finalize thinking if started but not done
|
||||||
|
if !thinking_emitted && !last_thinking.is_empty() {
|
||||||
|
thinking_emitted = true;
|
||||||
|
thinking_text = Some(last_thinking.clone());
|
||||||
|
yield Ok(responses_sse_event(
|
||||||
|
"response.reasoning_summary_text.done",
|
||||||
|
serde_json::json!({
|
||||||
|
"type": "response.reasoning_summary_text.done",
|
||||||
|
"sequence_number": next_seq(),
|
||||||
|
"item_id": &reasoning_id,
|
||||||
|
"output_index": 0,
|
||||||
|
"summary_index": 0,
|
||||||
|
"text": &last_thinking,
|
||||||
|
}),
|
||||||
|
));
|
||||||
|
yield Ok(responses_sse_event(
|
||||||
|
"response.reasoning_summary_part.done",
|
||||||
|
serde_json::json!({
|
||||||
|
"type": "response.reasoning_summary_part.done",
|
||||||
|
"sequence_number": next_seq(),
|
||||||
|
"item_id": &reasoning_id,
|
||||||
|
"output_index": 0,
|
||||||
|
"summary_index": 0,
|
||||||
|
"part": { "type": "summary_text", "text": &last_thinking },
|
||||||
|
}),
|
||||||
|
));
|
||||||
|
yield Ok(responses_sse_event(
|
||||||
|
"response.output_item.done",
|
||||||
|
serde_json::json!({
|
||||||
|
"type": "response.output_item.done",
|
||||||
|
"sequence_number": next_seq(),
|
||||||
|
"output_index": 0,
|
||||||
|
"item": {
|
||||||
|
"id": &reasoning_id,
|
||||||
|
"type": "reasoning",
|
||||||
|
"summary": [{
|
||||||
|
"type": "summary_text",
|
||||||
|
"text": &last_thinking,
|
||||||
|
}],
|
||||||
|
},
|
||||||
|
}),
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
// Check for function calls first
|
let msg_output_index: u32 = if thinking_emitted { 1 } else { 0 };
|
||||||
let captured = state.mitm_store.take_function_calls(&cascade_id).await;
|
|
||||||
if let Some(ref raw_calls) = captured {
|
if !message_started {
|
||||||
let calls: Vec<_> = if let Some(max) = params.max_tool_calls {
|
message_started = true;
|
||||||
raw_calls.iter().take(max as usize).collect()
|
yield Ok(responses_sse_event(
|
||||||
} else {
|
"response.output_item.added",
|
||||||
raw_calls.iter().collect()
|
serde_json::json!({
|
||||||
};
|
"type": "response.output_item.added",
|
||||||
if !calls.is_empty() {
|
"sequence_number": next_seq(),
|
||||||
|
"output_index": msg_output_index,
|
||||||
|
"item": build_message_output_in_progress(&msg_id),
|
||||||
|
}),
|
||||||
|
));
|
||||||
|
yield Ok(responses_sse_event(
|
||||||
|
"response.content_part.added",
|
||||||
|
serde_json::json!({
|
||||||
|
"type": "response.content_part.added",
|
||||||
|
"sequence_number": next_seq(),
|
||||||
|
"output_index": msg_output_index,
|
||||||
|
"content_index": CONTENT_IDX,
|
||||||
|
"part": {
|
||||||
|
"type": "output_text",
|
||||||
|
"text": "",
|
||||||
|
"annotations": [],
|
||||||
|
}
|
||||||
|
}),
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
let delta = &full_text[last_text.len()..];
|
||||||
|
if !delta.is_empty() {
|
||||||
|
let msg_output_index: u32 = if thinking_emitted { 1 } else { 0 };
|
||||||
|
yield Ok(responses_sse_event(
|
||||||
|
"response.output_text.delta",
|
||||||
|
serde_json::json!({
|
||||||
|
"type": "response.output_text.delta",
|
||||||
|
"sequence_number": next_seq(),
|
||||||
|
"item_id": &msg_id,
|
||||||
|
"output_index": msg_output_index,
|
||||||
|
"content_index": CONTENT_IDX,
|
||||||
|
"delta": delta,
|
||||||
|
}),
|
||||||
|
));
|
||||||
|
last_text = full_text;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
MitmEvent::FunctionCall(raw_calls) => {
|
||||||
|
let calls: Vec<_> = if let Some(max) = params.max_tool_calls {
|
||||||
|
raw_calls.iter().take(max as usize).collect()
|
||||||
|
} else {
|
||||||
|
raw_calls.iter().collect()
|
||||||
|
};
|
||||||
let msg_output_index: u32 = if thinking_emitted { 1 } else { 0 };
|
let msg_output_index: u32 = if thinking_emitted { 1 } else { 0 };
|
||||||
for (i, fc) in calls.iter().enumerate() {
|
for (i, fc) in calls.iter().enumerate() {
|
||||||
let call_id = format!(
|
let call_id = format!(
|
||||||
@@ -1011,194 +1166,71 @@ async fn handle_responses_stream(
|
|||||||
"response": response_to_json(&final_resp),
|
"response": response_to_json(&final_resp),
|
||||||
}),
|
}),
|
||||||
));
|
));
|
||||||
|
state.mitm_store.drop_channel().await;
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
}
|
MitmEvent::ResponseComplete => {
|
||||||
|
if !last_text.is_empty() {
|
||||||
// Stream thinking text in real-time
|
let msg_idx: u32 = if thinking_emitted { 1 } else { 0 };
|
||||||
if !thinking_emitted {
|
let (usage, _) = usage_from_poll(
|
||||||
if let Some(thinking) = state.mitm_store.peek_thinking_text().await {
|
&state.mitm_store, &cascade_id, &None,
|
||||||
if !thinking.is_empty() && thinking != last_thinking {
|
&params.user_text, &last_text,
|
||||||
// First thinking text — emit reasoning output_item.added
|
).await;
|
||||||
if last_thinking.is_empty() {
|
let tc = thinking_text.clone();
|
||||||
yield Ok(responses_sse_event(
|
for evt in completion_events(
|
||||||
"response.output_item.added",
|
&response_id, &model_name, &msg_id, &reasoning_id,
|
||||||
serde_json::json!({
|
msg_idx, CONTENT_IDX, &last_text, usage,
|
||||||
"type": "response.output_item.added",
|
created_at, &seq, &params, None, tc,
|
||||||
"sequence_number": next_seq(),
|
) {
|
||||||
"output_index": 0,
|
yield Ok(evt);
|
||||||
"item": {
|
|
||||||
"id": &reasoning_id,
|
|
||||||
"type": "reasoning",
|
|
||||||
"summary": [],
|
|
||||||
},
|
|
||||||
}),
|
|
||||||
));
|
|
||||||
yield Ok(responses_sse_event(
|
|
||||||
"response.reasoning_summary_part.added",
|
|
||||||
serde_json::json!({
|
|
||||||
"type": "response.reasoning_summary_part.added",
|
|
||||||
"sequence_number": next_seq(),
|
|
||||||
"item_id": &reasoning_id,
|
|
||||||
"output_index": 0,
|
|
||||||
"summary_index": 0,
|
|
||||||
"part": { "type": "summary_text", "text": "" },
|
|
||||||
}),
|
|
||||||
));
|
|
||||||
}
|
}
|
||||||
|
state.mitm_store.drop_channel().await;
|
||||||
// Delta of new thinking text
|
return;
|
||||||
let delta = if thinking.len() > last_thinking.len()
|
} else if !last_thinking.is_empty() {
|
||||||
&& thinking.starts_with(&*last_thinking)
|
// Thinking-only response — LS needs follow-up API calls.
|
||||||
{
|
// Create a new channel and unblock the gate.
|
||||||
thinking[last_thinking.len()..].to_string()
|
let (new_tx, new_rx) = tokio::sync::mpsc::channel(64);
|
||||||
} else {
|
state.mitm_store.set_channel(new_tx).await;
|
||||||
thinking.clone()
|
state.mitm_store.clear_request_in_flight();
|
||||||
};
|
let _ = state.mitm_store.take_any_function_calls().await;
|
||||||
|
rx = new_rx;
|
||||||
if !delta.is_empty() {
|
debug!(
|
||||||
yield Ok(responses_sse_event(
|
"Responses stream: thinking-only — new channel for follow-up, thinking_len={}",
|
||||||
"response.reasoning_summary_text.delta",
|
last_thinking.len()
|
||||||
serde_json::json!({
|
);
|
||||||
"type": "response.reasoning_summary_text.delta",
|
|
||||||
"sequence_number": next_seq(),
|
|
||||||
"item_id": &reasoning_id,
|
|
||||||
"output_index": 0,
|
|
||||||
"summary_index": 0,
|
|
||||||
"delta": &delta,
|
|
||||||
}),
|
|
||||||
));
|
|
||||||
}
|
|
||||||
last_thinking = thinking;
|
|
||||||
}
|
}
|
||||||
|
// ResponseComplete with no text and no thinking — continue waiting
|
||||||
}
|
}
|
||||||
}
|
MitmEvent::UpstreamError(err) => {
|
||||||
|
let error_msg = super::util::upstream_error_message(&err);
|
||||||
// Stream response text
|
let error_type = super::util::upstream_error_type(&err);
|
||||||
if let Some(text) = state.mitm_store.peek_response_text().await {
|
yield Ok(responses_sse_event(
|
||||||
if !text.is_empty() && text != last_text {
|
"response.failed",
|
||||||
// Finalize thinking if started but not done
|
serde_json::json!({
|
||||||
if !thinking_emitted && !last_thinking.is_empty() {
|
"type": "response.failed",
|
||||||
thinking_emitted = true;
|
"sequence_number": next_seq(),
|
||||||
thinking_text = Some(last_thinking.clone());
|
"response": {
|
||||||
yield Ok(responses_sse_event(
|
"id": &response_id,
|
||||||
"response.reasoning_summary_text.done",
|
"status": "failed",
|
||||||
serde_json::json!({
|
"error": {
|
||||||
"type": "response.reasoning_summary_text.done",
|
"type": error_type,
|
||||||
"sequence_number": next_seq(),
|
"message": error_msg,
|
||||||
"item_id": &reasoning_id,
|
"code": err.status,
|
||||||
"output_index": 0,
|
|
||||||
"summary_index": 0,
|
|
||||||
"text": &last_thinking,
|
|
||||||
}),
|
|
||||||
));
|
|
||||||
yield Ok(responses_sse_event(
|
|
||||||
"response.reasoning_summary_part.done",
|
|
||||||
serde_json::json!({
|
|
||||||
"type": "response.reasoning_summary_part.done",
|
|
||||||
"sequence_number": next_seq(),
|
|
||||||
"item_id": &reasoning_id,
|
|
||||||
"output_index": 0,
|
|
||||||
"summary_index": 0,
|
|
||||||
"part": { "type": "summary_text", "text": &last_thinking },
|
|
||||||
}),
|
|
||||||
));
|
|
||||||
yield Ok(responses_sse_event(
|
|
||||||
"response.output_item.done",
|
|
||||||
serde_json::json!({
|
|
||||||
"type": "response.output_item.done",
|
|
||||||
"sequence_number": next_seq(),
|
|
||||||
"output_index": 0,
|
|
||||||
"item": {
|
|
||||||
"id": &reasoning_id,
|
|
||||||
"type": "reasoning",
|
|
||||||
"summary": [{
|
|
||||||
"type": "summary_text",
|
|
||||||
"text": &last_thinking,
|
|
||||||
}],
|
|
||||||
},
|
},
|
||||||
}),
|
},
|
||||||
));
|
}),
|
||||||
}
|
));
|
||||||
|
state.mitm_store.drop_channel().await;
|
||||||
let msg_output_index: u32 = if thinking_emitted { 1 } else { 0 };
|
|
||||||
|
|
||||||
if !message_started {
|
|
||||||
message_started = true;
|
|
||||||
yield Ok(responses_sse_event(
|
|
||||||
"response.output_item.added",
|
|
||||||
serde_json::json!({
|
|
||||||
"type": "response.output_item.added",
|
|
||||||
"sequence_number": next_seq(),
|
|
||||||
"output_index": msg_output_index,
|
|
||||||
"item": build_message_output_in_progress(&msg_id),
|
|
||||||
}),
|
|
||||||
));
|
|
||||||
yield Ok(responses_sse_event(
|
|
||||||
"response.content_part.added",
|
|
||||||
serde_json::json!({
|
|
||||||
"type": "response.content_part.added",
|
|
||||||
"sequence_number": next_seq(),
|
|
||||||
"output_index": msg_output_index,
|
|
||||||
"content_index": CONTENT_IDX,
|
|
||||||
"part": {
|
|
||||||
"type": "output_text",
|
|
||||||
"text": "",
|
|
||||||
"annotations": [],
|
|
||||||
}
|
|
||||||
}),
|
|
||||||
));
|
|
||||||
}
|
|
||||||
|
|
||||||
let new_content = if text.len() > last_text.len()
|
|
||||||
&& text.starts_with(&*last_text)
|
|
||||||
{
|
|
||||||
text[last_text.len()..].to_string()
|
|
||||||
} else {
|
|
||||||
text.clone()
|
|
||||||
};
|
|
||||||
|
|
||||||
if !new_content.is_empty() {
|
|
||||||
yield Ok(responses_sse_event(
|
|
||||||
"response.output_text.delta",
|
|
||||||
serde_json::json!({
|
|
||||||
"type": "response.output_text.delta",
|
|
||||||
"sequence_number": next_seq(),
|
|
||||||
"item_id": &msg_id,
|
|
||||||
"output_index": msg_output_index,
|
|
||||||
"content_index": CONTENT_IDX,
|
|
||||||
"delta": &new_content,
|
|
||||||
}),
|
|
||||||
));
|
|
||||||
last_text = text;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check if response is complete
|
|
||||||
if state.mitm_store.is_response_complete() && !last_text.is_empty() {
|
|
||||||
let msg_idx: u32 = if thinking_emitted { 1 } else { 0 };
|
|
||||||
let (usage, _) = usage_from_poll(
|
|
||||||
&state.mitm_store, &cascade_id, &None,
|
|
||||||
&params.user_text, &last_text,
|
|
||||||
).await;
|
|
||||||
let tc = thinking_text.clone();
|
|
||||||
for evt in completion_events(
|
|
||||||
&response_id, &model_name, &msg_id, &reasoning_id,
|
|
||||||
msg_idx, CONTENT_IDX, &last_text, usage,
|
|
||||||
created_at, &seq, &params, None, tc,
|
|
||||||
) {
|
|
||||||
yield Ok(evt);
|
|
||||||
}
|
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
MitmEvent::Usage(_) | MitmEvent::Grounding(_) => {
|
||||||
|
// Usage/grounding stored by proxy, consumed via usage_from_poll
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Poll interval
|
|
||||||
let poll_ms: u64 = rand::thread_rng().gen_range(150..300);
|
|
||||||
tokio::time::sleep(tokio::time::Duration::from_millis(poll_ms)).await;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Timeout in bypass mode — emit error, not fake incomplete
|
// Timeout in channel mode
|
||||||
|
state.mitm_store.drop_channel().await;
|
||||||
yield Ok(responses_sse_event(
|
yield Ok(responses_sse_event(
|
||||||
"response.failed",
|
"response.failed",
|
||||||
serde_json::json!({
|
serde_json::json!({
|
||||||
|
|||||||
@@ -163,11 +163,13 @@ async fn do_search(state: Arc<AppState>, body: SearchRequest) -> axum::response:
|
|||||||
|
|
||||||
// Set active cascade for MITM correlation
|
// Set active cascade for MITM correlation
|
||||||
state.mitm_store.set_active_cascade(&cascade_id).await;
|
state.mitm_store.set_active_cascade(&cascade_id).await;
|
||||||
|
// Store real search prompt for MITM injection — LS gets a dummy prompt
|
||||||
|
state.mitm_store.set_pending_user_text(search_prompt.clone()).await;
|
||||||
|
|
||||||
// Send the search message
|
// Send the search message
|
||||||
if let Err(e) = state
|
if let Err(e) = state
|
||||||
.backend
|
.backend
|
||||||
.send_message(&cascade_id, &search_prompt, model.model_enum)
|
.send_message(&cascade_id, ".", model.model_enum)
|
||||||
.await
|
.await
|
||||||
{
|
{
|
||||||
state.mitm_store.clear_active_cascade().await;
|
state.mitm_store.clear_active_cascade().await;
|
||||||
|
|||||||
@@ -538,9 +538,10 @@ async fn handle_http_over_tls(
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
// Generation tracking for store write guards
|
// Channel-based event pipeline: grab the channel sender if one exists.
|
||||||
let mut won_gate = false;
|
// If the API handler set up a channel, we send events through it.
|
||||||
let mut conn_generation = store.current_generation();
|
// Otherwise, we fall back to legacy store writes (search endpoint, etc.).
|
||||||
|
let mut event_tx: Option<tokio::sync::mpsc::Sender<super::store::MitmEvent>> = None;
|
||||||
|
|
||||||
// Log LLM calls at info, everything else at debug
|
// Log LLM calls at info, everything else at debug
|
||||||
if req_path.contains("streamGenerateContent") {
|
if req_path.contains("streamGenerateContent") {
|
||||||
@@ -558,7 +559,7 @@ async fn handle_http_over_tls(
|
|||||||
// When custom tools are active, only the FIRST request wins the
|
// When custom tools are active, only the FIRST request wins the
|
||||||
// atomic compare_exchange. All others get fake STOP responses.
|
// atomic compare_exchange. All others get fake STOP responses.
|
||||||
let has_tools = store.get_tools().await.is_some();
|
let has_tools = store.get_tools().await.is_some();
|
||||||
won_gate = if has_tools {
|
if has_tools {
|
||||||
if !store.try_mark_request_in_flight() {
|
if !store.try_mark_request_in_flight() {
|
||||||
info!("MITM: blocking LS request — another request already in-flight");
|
info!("MITM: blocking LS request — another request already in-flight");
|
||||||
let fake_response = "HTTP/1.1 200 OK\r\n\
|
let fake_response = "HTTP/1.1 200 OK\r\n\
|
||||||
@@ -575,13 +576,11 @@ async fn handle_http_over_tls(
|
|||||||
let _ = client.flush().await;
|
let _ = client.flush().await;
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
true
|
// Grab the channel sender — the API handler installed it before
|
||||||
} else {
|
// sending the LS message. If it's gone, we still proceed but
|
||||||
false
|
// fall back to legacy store writes.
|
||||||
};
|
event_tx = store.take_channel().await;
|
||||||
// Snapshot the generation at gate-win time. If it changes later,
|
}
|
||||||
// another completions turn started and our data is stale.
|
|
||||||
conn_generation = store.current_generation();
|
|
||||||
|
|
||||||
// ── Request modification ─────────────────────────────────────
|
// ── Request modification ─────────────────────────────────────
|
||||||
// Dechunk body → check if agent request → modify → rechunk
|
// Dechunk body → check if agent request → modify → rechunk
|
||||||
@@ -603,12 +602,14 @@ async fn handle_http_over_tls(
|
|||||||
let generation_params = store.get_generation_params().await;
|
let generation_params = store.get_generation_params().await;
|
||||||
let pending_image = store.take_pending_image().await;
|
let pending_image = store.take_pending_image().await;
|
||||||
let tool_rounds = store.get_tool_rounds().await;
|
let tool_rounds = store.get_tool_rounds().await;
|
||||||
|
let pending_user_text = store.take_pending_user_text().await;
|
||||||
|
|
||||||
let tool_ctx = if tools.is_some()
|
let tool_ctx = if tools.is_some()
|
||||||
|| !pending_results.is_empty()
|
|| !pending_results.is_empty()
|
||||||
|| !tool_rounds.is_empty()
|
|| !tool_rounds.is_empty()
|
||||||
|| generation_params.is_some()
|
|| generation_params.is_some()
|
||||||
|| pending_image.is_some()
|
|| pending_image.is_some()
|
||||||
|
|| pending_user_text.is_some()
|
||||||
{
|
{
|
||||||
Some(super::modify::ToolContext {
|
Some(super::modify::ToolContext {
|
||||||
tools,
|
tools,
|
||||||
@@ -618,6 +619,7 @@ async fn handle_http_over_tls(
|
|||||||
generation_params,
|
generation_params,
|
||||||
pending_image,
|
pending_image,
|
||||||
tool_rounds,
|
tool_rounds,
|
||||||
|
pending_user_text,
|
||||||
})
|
})
|
||||||
} else {
|
} else {
|
||||||
None
|
None
|
||||||
@@ -794,14 +796,18 @@ async fn handle_http_over_tls(
|
|||||||
})
|
})
|
||||||
.unwrap_or((None, None));
|
.unwrap_or((None, None));
|
||||||
|
|
||||||
store
|
let upstream_err = super::store::UpstreamError {
|
||||||
.set_upstream_error(super::store::UpstreamError {
|
status: http_status,
|
||||||
status: http_status,
|
body: body_str,
|
||||||
body: body_str,
|
message,
|
||||||
message,
|
error_status,
|
||||||
error_status,
|
};
|
||||||
})
|
// Send through channel if available, otherwise store for legacy consumers
|
||||||
.await;
|
if let Some(ref tx) = event_tx {
|
||||||
|
let _ = tx.send(super::store::MitmEvent::UpstreamError(upstream_err)).await;
|
||||||
|
} else {
|
||||||
|
store.set_upstream_error(upstream_err).await;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Save body for usage parsing
|
// Save body for usage parsing
|
||||||
@@ -812,46 +818,59 @@ async fn handle_http_over_tls(
|
|||||||
let body = String::from_utf8_lossy(&header_buf[hdr_end..]);
|
let body = String::from_utf8_lossy(&header_buf[hdr_end..]);
|
||||||
parse_streaming_chunk(&body, &mut streaming_acc);
|
parse_streaming_chunk(&body, &mut streaming_acc);
|
||||||
|
|
||||||
// Only write to store if our generation is still current.
|
// Send events through channel if available, otherwise use legacy store
|
||||||
// If another completions turn started, our data is stale.
|
if let Some(ref tx) = event_tx {
|
||||||
let gen_valid = !won_gate || store.current_generation() == conn_generation;
|
// Function calls → channel event
|
||||||
if gen_valid {
|
|
||||||
// Store captured function calls (drain to avoid re-storing on next chunk)
|
|
||||||
if !streaming_acc.function_calls.is_empty() {
|
if !streaming_acc.function_calls.is_empty() {
|
||||||
let calls: Vec<_> =
|
let calls: Vec<_> = streaming_acc.function_calls.drain(..).collect();
|
||||||
streaming_acc.function_calls.drain(..).collect();
|
|
||||||
for fc in &calls {
|
|
||||||
store
|
|
||||||
.record_function_call(cascade_hint.as_deref(), fc.clone())
|
|
||||||
.await;
|
|
||||||
}
|
|
||||||
store.set_last_function_calls(calls.clone()).await;
|
store.set_last_function_calls(calls.clone()).await;
|
||||||
info!(
|
store.record_function_call(cascade_hint.as_deref(), calls[0].clone()).await;
|
||||||
"MITM: stored {} function call(s) from initial body",
|
info!("MITM: sending {} function call(s) via channel (initial body)", calls.len());
|
||||||
calls.len()
|
let _ = tx.send(super::store::MitmEvent::FunctionCall(calls)).await;
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Capture response + thinking text + grounding into MitmStore
|
|
||||||
if !streaming_acc.response_text.is_empty() {
|
|
||||||
store.set_response_text(&streaming_acc.response_text).await;
|
|
||||||
}
|
}
|
||||||
|
// Thinking delta → channel event
|
||||||
if !streaming_acc.thinking_text.is_empty() {
|
if !streaming_acc.thinking_text.is_empty() {
|
||||||
store.set_thinking_text(&streaming_acc.thinking_text).await;
|
let _ = tx.send(super::store::MitmEvent::ThinkingDelta(
|
||||||
|
streaming_acc.thinking_text.clone(),
|
||||||
|
)).await;
|
||||||
}
|
}
|
||||||
|
// Text delta → channel event
|
||||||
|
if !streaming_acc.response_text.is_empty() {
|
||||||
|
let _ = tx.send(super::store::MitmEvent::TextDelta(
|
||||||
|
streaming_acc.response_text.clone(),
|
||||||
|
)).await;
|
||||||
|
}
|
||||||
|
// Grounding → channel event
|
||||||
if let Some(ref gm) = streaming_acc.grounding_metadata {
|
if let Some(ref gm) = streaming_acc.grounding_metadata {
|
||||||
store.set_grounding(gm.clone()).await;
|
store.set_grounding(gm.clone()).await;
|
||||||
|
let _ = tx.send(super::store::MitmEvent::Grounding(gm.clone())).await;
|
||||||
}
|
}
|
||||||
|
// Response complete → channel event
|
||||||
if streaming_acc.is_complete {
|
if streaming_acc.is_complete {
|
||||||
info!(
|
info!(
|
||||||
response_text_len = streaming_acc.response_text.len(),
|
response_text_len = streaming_acc.response_text.len(),
|
||||||
thinking_text_len = streaming_acc.thinking_text.len(),
|
thinking_text_len = streaming_acc.thinking_text.len(),
|
||||||
"MITM: response complete (initial body) — marking store"
|
"MITM: response complete (initial body) — sending via channel"
|
||||||
);
|
);
|
||||||
store.mark_response_complete();
|
let _ = tx.send(super::store::MitmEvent::ResponseComplete).await;
|
||||||
|
streaming_acc.is_complete = false; // prevent duplicate sends
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Legacy path: store writes for non-channel consumers (search, etc.)
|
||||||
|
if !streaming_acc.function_calls.is_empty() {
|
||||||
|
let calls: Vec<_> = streaming_acc.function_calls.drain(..).collect();
|
||||||
|
for fc in &calls {
|
||||||
|
store.record_function_call(cascade_hint.as_deref(), fc.clone()).await;
|
||||||
|
}
|
||||||
|
store.set_last_function_calls(calls.clone()).await;
|
||||||
|
info!("MITM: stored {} function call(s) from initial body", calls.len());
|
||||||
|
}
|
||||||
|
if !streaming_acc.response_text.is_empty() {
|
||||||
|
store.set_response_text(&streaming_acc.response_text).await;
|
||||||
|
}
|
||||||
|
if let Some(ref gm) = streaming_acc.grounding_metadata {
|
||||||
|
store.set_grounding(gm.clone()).await;
|
||||||
}
|
}
|
||||||
} else if streaming_acc.is_complete {
|
|
||||||
debug!("MITM: skipping store write — generation stale (initial body)");
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -890,45 +909,60 @@ async fn handle_http_over_tls(
|
|||||||
let s = String::from_utf8_lossy(chunk);
|
let s = String::from_utf8_lossy(chunk);
|
||||||
parse_streaming_chunk(&s, &mut streaming_acc);
|
parse_streaming_chunk(&s, &mut streaming_acc);
|
||||||
|
|
||||||
// Only write to store if our generation is still current.
|
// Send events through channel if available, otherwise use legacy store
|
||||||
let gen_valid = !won_gate || store.current_generation() == conn_generation;
|
if let Some(ref tx) = event_tx {
|
||||||
if gen_valid {
|
// Function calls → channel event
|
||||||
// Store captured function calls (drain to avoid re-storing on next chunk)
|
|
||||||
if !streaming_acc.function_calls.is_empty() {
|
if !streaming_acc.function_calls.is_empty() {
|
||||||
let calls: Vec<_> = streaming_acc.function_calls.drain(..).collect();
|
let calls: Vec<_> = streaming_acc.function_calls.drain(..).collect();
|
||||||
for fc in &calls {
|
|
||||||
store
|
|
||||||
.record_function_call(cascade_hint.as_deref(), fc.clone())
|
|
||||||
.await;
|
|
||||||
}
|
|
||||||
store.set_last_function_calls(calls.clone()).await;
|
store.set_last_function_calls(calls.clone()).await;
|
||||||
info!(
|
store.record_function_call(cascade_hint.as_deref(), calls[0].clone()).await;
|
||||||
"MITM: stored {} function call(s) from body chunk",
|
info!("MITM: sending {} function call(s) via channel (body chunk)", calls.len());
|
||||||
calls.len()
|
let _ = tx.send(super::store::MitmEvent::FunctionCall(calls)).await;
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Capture response + thinking text + grounding into MitmStore
|
|
||||||
if !streaming_acc.response_text.is_empty() {
|
|
||||||
store.set_response_text(&streaming_acc.response_text).await;
|
|
||||||
}
|
}
|
||||||
|
// Thinking delta → channel event (send accumulated, handler tracks last len)
|
||||||
if !streaming_acc.thinking_text.is_empty() {
|
if !streaming_acc.thinking_text.is_empty() {
|
||||||
store.set_thinking_text(&streaming_acc.thinking_text).await;
|
let _ = tx.send(super::store::MitmEvent::ThinkingDelta(
|
||||||
|
streaming_acc.thinking_text.clone(),
|
||||||
|
)).await;
|
||||||
}
|
}
|
||||||
|
// Text delta → channel event (send accumulated, handler tracks last len)
|
||||||
|
if !streaming_acc.response_text.is_empty() {
|
||||||
|
let _ = tx.send(super::store::MitmEvent::TextDelta(
|
||||||
|
streaming_acc.response_text.clone(),
|
||||||
|
)).await;
|
||||||
|
}
|
||||||
|
// Grounding → channel event
|
||||||
if let Some(ref gm) = streaming_acc.grounding_metadata {
|
if let Some(ref gm) = streaming_acc.grounding_metadata {
|
||||||
store.set_grounding(gm.clone()).await;
|
store.set_grounding(gm.clone()).await;
|
||||||
|
let _ = tx.send(super::store::MitmEvent::Grounding(gm.clone())).await;
|
||||||
}
|
}
|
||||||
|
// Response complete → channel event
|
||||||
if streaming_acc.is_complete {
|
if streaming_acc.is_complete {
|
||||||
info!(
|
info!(
|
||||||
response_text_len = streaming_acc.response_text.len(),
|
response_text_len = streaming_acc.response_text.len(),
|
||||||
thinking_text_len = streaming_acc.thinking_text.len(),
|
thinking_text_len = streaming_acc.thinking_text.len(),
|
||||||
function_calls = streaming_acc.function_calls.len(),
|
function_calls = streaming_acc.function_calls.len(),
|
||||||
"MITM: response complete — marking store"
|
"MITM: response complete — sending via channel"
|
||||||
);
|
);
|
||||||
store.mark_response_complete();
|
let _ = tx.send(super::store::MitmEvent::ResponseComplete).await;
|
||||||
|
streaming_acc.is_complete = false; // prevent duplicate sends
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Legacy path: store writes for non-channel consumers
|
||||||
|
if !streaming_acc.function_calls.is_empty() {
|
||||||
|
let calls: Vec<_> = streaming_acc.function_calls.drain(..).collect();
|
||||||
|
for fc in &calls {
|
||||||
|
store.record_function_call(cascade_hint.as_deref(), fc.clone()).await;
|
||||||
|
}
|
||||||
|
store.set_last_function_calls(calls.clone()).await;
|
||||||
|
info!("MITM: stored {} function call(s) from body chunk", calls.len());
|
||||||
|
}
|
||||||
|
if !streaming_acc.response_text.is_empty() {
|
||||||
|
store.set_response_text(&streaming_acc.response_text).await;
|
||||||
|
}
|
||||||
|
if let Some(ref gm) = streaming_acc.grounding_metadata {
|
||||||
|
store.set_grounding(gm.clone()).await;
|
||||||
}
|
}
|
||||||
} else if streaming_acc.is_complete {
|
|
||||||
debug!("MITM: skipping store write — generation stale (body chunk)");
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -969,9 +1003,11 @@ async fn handle_http_over_tls(
|
|||||||
store.set_grounding(gm.clone()).await;
|
store.set_grounding(gm.clone()).await;
|
||||||
}
|
}
|
||||||
if streaming_acc.is_complete || streaming_acc.output_tokens > 0 {
|
if streaming_acc.is_complete || streaming_acc.output_tokens > 0 {
|
||||||
// Function calls are stored immediately when detected (above),
|
|
||||||
// so no need to store them again here.
|
|
||||||
let usage = streaming_acc.into_usage();
|
let usage = streaming_acc.into_usage();
|
||||||
|
// Send usage through channel if available
|
||||||
|
if let Some(ref tx) = event_tx {
|
||||||
|
let _ = tx.send(super::store::MitmEvent::Usage(usage.clone())).await;
|
||||||
|
}
|
||||||
store.record_usage(cascade_hint.as_deref(), usage).await;
|
store.record_usage(cascade_hint.as_deref(), usage).await;
|
||||||
}
|
}
|
||||||
} else if !response_body_buf.is_empty() {
|
} else if !response_body_buf.is_empty() {
|
||||||
|
|||||||
@@ -1,12 +1,14 @@
|
|||||||
//! Shared store for intercepted API usage data.
|
//! Shared store for intercepted API usage data.
|
||||||
//!
|
//!
|
||||||
//! The MITM proxy writes usage data here; the API handlers read from it.
|
//! The MITM proxy writes usage data here; the API handlers read from it.
|
||||||
|
//! When custom tools are active, the MITM proxy sends real-time events
|
||||||
|
//! through a channel instead of writing to shared state.
|
||||||
|
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
|
use std::sync::atomic::{AtomicBool, Ordering};
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use tokio::sync::RwLock;
|
use tokio::sync::{mpsc, RwLock};
|
||||||
use tracing::{debug, info};
|
use tracing::{debug, info};
|
||||||
|
|
||||||
/// Token usage from an intercepted API response.
|
/// Token usage from an intercepted API response.
|
||||||
@@ -126,6 +128,29 @@ pub struct GenerationParams {
|
|||||||
pub google_search: bool,
|
pub google_search: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ─── Channel-based event pipeline ────────────────────────────────────────────
|
||||||
|
|
||||||
|
/// Events sent from the MITM proxy to API handlers through a per-request channel.
|
||||||
|
/// Replaces the old polling-based approach (shared atomics + RwLocks) with
|
||||||
|
/// instant, race-free delivery.
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub enum MitmEvent {
|
||||||
|
/// Thinking/reasoning text accumulated so far; the proxy resends the running total and the handler tracks the last length seen.
|
||||||
|
ThinkingDelta(String),
|
||||||
|
/// Response text accumulated so far; the proxy resends the running total and the handler tracks the last length seen.
|
||||||
|
TextDelta(String),
|
||||||
|
/// Model requested function call(s).
|
||||||
|
FunctionCall(Vec<CapturedFunctionCall>),
|
||||||
|
/// Response streaming is complete (finishReason received).
|
||||||
|
ResponseComplete,
|
||||||
|
/// Google API returned an error.
|
||||||
|
UpstreamError(UpstreamError),
|
||||||
|
/// Grounding metadata (search results) from the response.
|
||||||
|
Grounding(serde_json::Value),
|
||||||
|
/// Token usage data from the response.
|
||||||
|
Usage(ApiUsage),
|
||||||
|
}
|
||||||
|
|
||||||
/// Thread-safe store for intercepted data.
|
/// Thread-safe store for intercepted data.
|
||||||
///
|
///
|
||||||
/// Keyed by a unique request ID that we can correlate with cascade operations.
|
/// Keyed by a unique request ID that we can correlate with cascade operations.
|
||||||
@@ -138,20 +163,17 @@ pub struct MitmStore {
|
|||||||
stats: Arc<RwLock<MitmStats>>,
|
stats: Arc<RwLock<MitmStats>>,
|
||||||
/// Pending function calls captured from Google responses.
|
/// Pending function calls captured from Google responses.
|
||||||
/// Key: cascade hint or "_latest". Value: list of function calls.
|
/// Key: cascade hint or "_latest". Value: list of function calls.
|
||||||
|
/// Used by the non-tool LS path (normal sync responses).
|
||||||
pending_function_calls: Arc<RwLock<HashMap<String, Vec<CapturedFunctionCall>>>>,
|
pending_function_calls: Arc<RwLock<HashMap<String, Vec<CapturedFunctionCall>>>>,
|
||||||
/// Simple flag: set when a functionCall is captured, cleared when consumed.
|
|
||||||
/// Used to block follow-up requests regardless of cascade identification.
|
|
||||||
has_active_function_call: Arc<AtomicBool>,
|
|
||||||
/// Persistent flag: set when a function call is captured, cleared ONLY when
|
|
||||||
/// a tool result is submitted. Prevents the LS from making follow-up API
|
|
||||||
/// calls during the entire tool execution cycle.
|
|
||||||
awaiting_tool_result: Arc<AtomicBool>,
|
|
||||||
/// Set when the MITM forwards the first LLM request with custom tools.
|
/// Set when the MITM forwards the first LLM request with custom tools.
|
||||||
/// Blocks ALL subsequent LS requests until the API handler clears it.
|
/// Blocks ALL subsequent LS requests until the API handler clears it.
|
||||||
request_in_flight: Arc<AtomicBool>,
|
request_in_flight: Arc<AtomicBool>,
|
||||||
/// Generation counter — incremented each time a new completions turn starts.
|
|
||||||
/// Used to discard stale data from leaked LS connections.
|
// ── Channel-based event pipeline (replaces old polling) ──────────────
|
||||||
request_generation: Arc<AtomicU64>,
|
/// Active channel sender for the current tool-path request.
|
||||||
|
/// When present, the MITM proxy sends events through this instead of
|
||||||
|
/// writing to shared state. The channel's existence = request in-flight.
|
||||||
|
active_channel: Arc<RwLock<Option<mpsc::Sender<MitmEvent>>>>,
|
||||||
|
|
||||||
// ── Tool call support ────────────────────────────────────────────────
|
// ── Tool call support ────────────────────────────────────────────────
|
||||||
/// Active tool definitions (Gemini format) for MITM injection.
|
/// Active tool definitions (Gemini format) for MITM injection.
|
||||||
@@ -173,14 +195,9 @@ pub struct MitmStore {
|
|||||||
/// Used by the MITM proxy to correlate intercepted traffic to cascades.
|
/// Used by the MITM proxy to correlate intercepted traffic to cascades.
|
||||||
active_cascade_id: Arc<RwLock<Option<String>>>,
|
active_cascade_id: Arc<RwLock<Option<String>>>,
|
||||||
|
|
||||||
// ── Direct response capture (bypasses LS) ────────────────────────────
|
// ── Legacy direct response capture (used by search.rs) ───────────────
|
||||||
/// Captured response text from MITM when custom tools are active.
|
/// Captured response text from MITM. Used as fallback by search endpoint.
|
||||||
/// The completions/responses handler reads this instead of polling LS steps.
|
|
||||||
captured_response_text: Arc<RwLock<Option<String>>>,
|
captured_response_text: Arc<RwLock<Option<String>>>,
|
||||||
/// Captured thinking/reasoning text from MITM (for real-time streaming).
|
|
||||||
captured_thinking_text: Arc<RwLock<Option<String>>>,
|
|
||||||
/// Whether the captured response is complete (finishReason received).
|
|
||||||
response_complete: Arc<AtomicBool>,
|
|
||||||
|
|
||||||
// ── Generation parameters for MITM injection ─────────────────────────
|
// ── Generation parameters for MITM injection ─────────────────────────
|
||||||
/// Client-specified sampling parameters to inject into Google API requests.
|
/// Client-specified sampling parameters to inject into Google API requests.
|
||||||
@@ -194,9 +211,14 @@ pub struct MitmStore {
|
|||||||
/// Image to inject into the next Google API request via MITM.
|
/// Image to inject into the next Google API request via MITM.
|
||||||
pending_image: Arc<RwLock<Option<PendingImage>>>,
|
pending_image: Arc<RwLock<Option<PendingImage>>>,
|
||||||
|
|
||||||
// ── Upstream error capture ───────────────────────────────────────────
|
// ── Upstream error capture (legacy, used when no channel) ────────────
|
||||||
/// Error from Google's API, captured by MITM for forwarding to client.
|
/// Error from Google's API, captured by MITM for forwarding to client.
|
||||||
upstream_error: Arc<RwLock<Option<UpstreamError>>>,
|
upstream_error: Arc<RwLock<Option<UpstreamError>>>,
|
||||||
|
|
||||||
|
// ── Standard LS input: real user text for MITM injection ─────────────
|
||||||
|
/// The real user text to inject into the Google API request.
|
||||||
|
/// API handlers store this before sending a dummy prompt to the LS.
|
||||||
|
pending_user_text: Arc<RwLock<Option<String>>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Aggregate statistics across all intercepted traffic.
|
/// Aggregate statistics across all intercepted traffic.
|
||||||
@@ -229,10 +251,8 @@ impl MitmStore {
|
|||||||
latest_usage: Arc::new(RwLock::new(HashMap::new())),
|
latest_usage: Arc::new(RwLock::new(HashMap::new())),
|
||||||
stats: Arc::new(RwLock::new(MitmStats::default())),
|
stats: Arc::new(RwLock::new(MitmStats::default())),
|
||||||
pending_function_calls: Arc::new(RwLock::new(HashMap::new())),
|
pending_function_calls: Arc::new(RwLock::new(HashMap::new())),
|
||||||
has_active_function_call: Arc::new(AtomicBool::new(false)),
|
|
||||||
awaiting_tool_result: Arc::new(AtomicBool::new(false)),
|
|
||||||
request_in_flight: Arc::new(AtomicBool::new(false)),
|
request_in_flight: Arc::new(AtomicBool::new(false)),
|
||||||
request_generation: Arc::new(AtomicU64::new(0)),
|
active_channel: Arc::new(RwLock::new(None)),
|
||||||
active_tools: Arc::new(RwLock::new(None)),
|
active_tools: Arc::new(RwLock::new(None)),
|
||||||
active_tool_config: Arc::new(RwLock::new(None)),
|
active_tool_config: Arc::new(RwLock::new(None)),
|
||||||
pending_tool_results: Arc::new(RwLock::new(Vec::new())),
|
pending_tool_results: Arc::new(RwLock::new(Vec::new())),
|
||||||
@@ -241,12 +261,11 @@ impl MitmStore {
|
|||||||
tool_rounds: Arc::new(RwLock::new(Vec::new())),
|
tool_rounds: Arc::new(RwLock::new(Vec::new())),
|
||||||
active_cascade_id: Arc::new(RwLock::new(None)),
|
active_cascade_id: Arc::new(RwLock::new(None)),
|
||||||
captured_response_text: Arc::new(RwLock::new(None)),
|
captured_response_text: Arc::new(RwLock::new(None)),
|
||||||
captured_thinking_text: Arc::new(RwLock::new(None)),
|
|
||||||
response_complete: Arc::new(AtomicBool::new(false)),
|
|
||||||
generation_params: Arc::new(RwLock::new(None)),
|
generation_params: Arc::new(RwLock::new(None)),
|
||||||
captured_grounding: Arc::new(RwLock::new(None)),
|
captured_grounding: Arc::new(RwLock::new(None)),
|
||||||
pending_image: Arc::new(RwLock::new(None)),
|
pending_image: Arc::new(RwLock::new(None)),
|
||||||
upstream_error: Arc::new(RwLock::new(None)),
|
upstream_error: Arc::new(RwLock::new(None)),
|
||||||
|
pending_user_text: Arc::new(RwLock::new(None)),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -381,30 +400,6 @@ impl MitmStore {
|
|||||||
);
|
);
|
||||||
let mut pending = self.pending_function_calls.write().await;
|
let mut pending = self.pending_function_calls.write().await;
|
||||||
pending.entry(key).or_default().push(fc);
|
pending.entry(key).or_default().push(fc);
|
||||||
self.has_active_function_call.store(true, Ordering::SeqCst);
|
|
||||||
self.awaiting_tool_result.store(true, Ordering::SeqCst);
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Check if there's an active (unclaimed) function call.
|
|
||||||
pub fn has_active_function_call(&self) -> bool {
|
|
||||||
self.has_active_function_call.load(Ordering::SeqCst)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Force-clear the active function call flag (used to reset stale state).
|
|
||||||
pub fn clear_active_function_call(&self) {
|
|
||||||
self.has_active_function_call.store(false, Ordering::SeqCst);
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Check if we're awaiting a tool result (blocks LS follow-up requests).
|
|
||||||
/// This persists across function call consumption — only cleared when
|
|
||||||
/// actual tool results are submitted.
|
|
||||||
pub fn is_awaiting_tool_result(&self) -> bool {
|
|
||||||
self.awaiting_tool_result.load(Ordering::SeqCst)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Clear the awaiting-tool-result flag (called when tool results arrive).
|
|
||||||
pub fn clear_awaiting_tool_result(&self) {
|
|
||||||
self.awaiting_tool_result.store(false, Ordering::SeqCst);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Take pending function calls for a specific cascade.
|
/// Take pending function calls for a specific cascade.
|
||||||
@@ -417,7 +412,6 @@ impl MitmStore {
|
|||||||
|
|
||||||
// 1. Exact cascade match
|
// 1. Exact cascade match
|
||||||
if let Some(result) = pending.remove(cascade_id) {
|
if let Some(result) = pending.remove(cascade_id) {
|
||||||
self.has_active_function_call.store(false, Ordering::SeqCst);
|
|
||||||
return Some(result);
|
return Some(result);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -425,7 +419,6 @@ impl MitmStore {
|
|||||||
if let Some(active) = self.active_cascade_id.read().await.as_ref() {
|
if let Some(active) = self.active_cascade_id.read().await.as_ref() {
|
||||||
if active != cascade_id {
|
if active != cascade_id {
|
||||||
if let Some(result) = pending.remove(active.as_str()) {
|
if let Some(result) = pending.remove(active.as_str()) {
|
||||||
self.has_active_function_call.store(false, Ordering::SeqCst);
|
|
||||||
return Some(result);
|
return Some(result);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -433,17 +426,12 @@ impl MitmStore {
|
|||||||
|
|
||||||
// 3. Fallback to _latest
|
// 3. Fallback to _latest
|
||||||
if let Some(result) = pending.remove("_latest") {
|
if let Some(result) = pending.remove("_latest") {
|
||||||
self.has_active_function_call.store(false, Ordering::SeqCst);
|
|
||||||
return Some(result);
|
return Some(result);
|
||||||
}
|
}
|
||||||
|
|
||||||
// 4. Last resort: any key
|
// 4. Last resort: any key
|
||||||
if let Some(key) = pending.keys().next().cloned() {
|
if let Some(key) = pending.keys().next().cloned() {
|
||||||
let result = pending.remove(&key);
|
return pending.remove(&key);
|
||||||
if result.is_some() {
|
|
||||||
self.has_active_function_call.store(false, Ordering::SeqCst);
|
|
||||||
}
|
|
||||||
return result;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
None
|
None
|
||||||
@@ -455,19 +443,40 @@ impl MitmStore {
|
|||||||
let mut pending = self.pending_function_calls.write().await;
|
let mut pending = self.pending_function_calls.write().await;
|
||||||
let result = pending.remove("_latest");
|
let result = pending.remove("_latest");
|
||||||
if result.is_some() {
|
if result.is_some() {
|
||||||
self.has_active_function_call.store(false, Ordering::SeqCst);
|
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
if let Some(key) = pending.keys().next().cloned() {
|
if let Some(key) = pending.keys().next().cloned() {
|
||||||
let result = pending.remove(&key);
|
return pending.remove(&key);
|
||||||
if result.is_some() {
|
|
||||||
self.has_active_function_call.store(false, Ordering::SeqCst);
|
|
||||||
}
|
|
||||||
return result;
|
|
||||||
}
|
}
|
||||||
None
|
None
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ── Channel-based event pipeline ─────────────────────────────────────
|
||||||
|
|
||||||
|
/// Install a channel sender for the current tool-path request.
|
||||||
|
/// The MITM proxy will send events through this channel.
|
||||||
|
pub async fn set_channel(&self, tx: mpsc::Sender<MitmEvent>) {
|
||||||
|
*self.active_channel.write().await = Some(tx);
|
||||||
|
// NOTE: Do NOT set request_in_flight here. The MITM proxy's
|
||||||
|
// try_mark_request_in_flight() is the sole setter — setting it
|
||||||
|
// here causes compare_exchange(false,true) to always fail,
|
||||||
|
// blocking every real LS request.
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Take the active channel sender (used by MITM proxy to grab it).
|
||||||
|
/// Returns None if no channel is active.
|
||||||
|
pub async fn take_channel(&self) -> Option<mpsc::Sender<MitmEvent>> {
|
||||||
|
self.active_channel.write().await.take()
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
/// Drop the active channel and clear in-flight state.
|
||||||
|
/// Called when the API handler is done with the current request.
|
||||||
|
pub async fn drop_channel(&self) {
|
||||||
|
*self.active_channel.write().await = None;
|
||||||
|
self.request_in_flight.store(false, Ordering::SeqCst);
|
||||||
|
}
|
||||||
|
|
||||||
// ── Tool context methods ─────────────────────────────────────────────
|
// ── Tool context methods ─────────────────────────────────────────────
|
||||||
|
|
||||||
/// Set active tool definitions (already in Gemini format).
|
/// Set active tool definitions (already in Gemini format).
|
||||||
@@ -546,10 +555,10 @@ impl MitmStore {
|
|||||||
self.tool_rounds.read().await.clone()
|
self.tool_rounds.read().await.clone()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ── Legacy direct response capture (search.rs fallback) ──────────────
|
||||||
// ── Direct response capture (bypass LS) ──────────────────────────────
|
|
||||||
|
|
||||||
/// Set (replace) the captured response text.
|
/// Set (replace) the captured response text.
|
||||||
|
/// Used by MITM proxy for non-channel path (search endpoint fallback).
|
||||||
pub async fn set_response_text(&self, text: &str) {
|
pub async fn set_response_text(&self, text: &str) {
|
||||||
*self.captured_response_text.write().await = Some(text.to_string());
|
*self.captured_response_text.write().await = Some(text.to_string());
|
||||||
}
|
}
|
||||||
@@ -559,28 +568,12 @@ impl MitmStore {
|
|||||||
self.captured_response_text.write().await.take()
|
self.captured_response_text.write().await.take()
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Peek at the captured response text without consuming it.
|
/// Clear stale state between requests.
|
||||||
pub async fn peek_response_text(&self) -> Option<String> {
|
/// Drops any active channel, clears the in-flight flag, and discards any captured response text.
|
||||||
self.captured_response_text.read().await.clone()
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Mark the response as complete.
|
|
||||||
pub fn mark_response_complete(&self) {
|
|
||||||
self.response_complete.store(true, Ordering::SeqCst);
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Check if the response is complete.
|
|
||||||
pub fn is_response_complete(&self) -> bool {
|
|
||||||
self.response_complete.load(Ordering::SeqCst)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Async version of clear_response. Bumps generation counter.
|
|
||||||
pub async fn clear_response_async(&self) {
|
pub async fn clear_response_async(&self) {
|
||||||
self.response_complete.store(false, Ordering::SeqCst);
|
|
||||||
self.request_in_flight.store(false, Ordering::SeqCst);
|
self.request_in_flight.store(false, Ordering::SeqCst);
|
||||||
self.request_generation.fetch_add(1, Ordering::SeqCst);
|
*self.active_channel.write().await = None;
|
||||||
*self.captured_response_text.write().await = None;
|
*self.captured_response_text.write().await = None;
|
||||||
*self.captured_thinking_text.write().await = None;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Atomically try to mark request as in-flight.
|
/// Atomically try to mark request as in-flight.
|
||||||
@@ -593,6 +586,7 @@ impl MitmStore {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Check if a request is currently in-flight.
|
/// Check if a request is currently in-flight.
|
||||||
|
#[allow(dead_code)]
|
||||||
pub fn is_request_in_flight(&self) -> bool {
|
pub fn is_request_in_flight(&self) -> bool {
|
||||||
self.request_in_flight.load(Ordering::SeqCst)
|
self.request_in_flight.load(Ordering::SeqCst)
|
||||||
}
|
}
|
||||||
@@ -602,38 +596,6 @@ impl MitmStore {
|
|||||||
self.request_in_flight.store(false, Ordering::SeqCst);
|
self.request_in_flight.store(false, Ordering::SeqCst);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Reset response_complete so we can wait for the next response.
|
|
||||||
pub fn clear_response_complete(&self) {
|
|
||||||
self.response_complete.store(false, Ordering::SeqCst);
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Get current generation number.
|
|
||||||
pub fn current_generation(&self) -> u64 {
|
|
||||||
self.request_generation.load(Ordering::SeqCst)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Bump generation counter (invalidates all pending data from old generation).
|
|
||||||
pub fn bump_generation(&self) -> u64 {
|
|
||||||
self.request_generation.fetch_add(1, Ordering::SeqCst) + 1
|
|
||||||
}
|
|
||||||
|
|
||||||
// ── Thinking text capture ────────────────────────────────────────────
|
|
||||||
|
|
||||||
/// Set (replace) the captured thinking text.
|
|
||||||
pub async fn set_thinking_text(&self, text: &str) {
|
|
||||||
*self.captured_thinking_text.write().await = Some(text.to_string());
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Peek at the captured thinking text without consuming it.
|
|
||||||
pub async fn peek_thinking_text(&self) -> Option<String> {
|
|
||||||
self.captured_thinking_text.read().await.clone()
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Take the captured thinking text (consumes it).
|
|
||||||
pub async fn take_thinking_text(&self) -> Option<String> {
|
|
||||||
self.captured_thinking_text.write().await.take()
|
|
||||||
}
|
|
||||||
|
|
||||||
// ── Cascade correlation ──────────────────────────────────────────────
|
// ── Cascade correlation ──────────────────────────────────────────────
|
||||||
|
|
||||||
/// Set the active cascade ID (called by API handlers before sending a message).
|
/// Set the active cascade ID (called by API handlers before sending a message).
|
||||||
@@ -718,4 +680,18 @@ impl MitmStore {
|
|||||||
pub async fn clear_upstream_error(&self) {
|
pub async fn clear_upstream_error(&self) {
|
||||||
*self.upstream_error.write().await = None;
|
*self.upstream_error.write().await = None;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ── Pending user text for MITM injection ─────────────────────────────
|
||||||
|
|
||||||
|
/// Store the real user text for MITM injection.
|
||||||
|
/// Called by API handlers before sending a dummy prompt to the LS.
|
||||||
|
pub async fn set_pending_user_text(&self, text: String) {
|
||||||
|
*self.pending_user_text.write().await = Some(text);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Take (consume) the pending user text.
|
||||||
|
/// Called by the MITM proxy when building ToolContext.
|
||||||
|
pub async fn take_pending_user_text(&self) -> Option<String> {
|
||||||
|
self.pending_user_text.write().await.take()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user