- Decompose modify_request() into 7 single-responsibility helpers
- Decompose handle_http_over_tls(): extract read_full_request, dispatch_stream_events
- Promote connect_upstream/resolve_upstream to module-level functions
- Split standalone.rs (1238 lines) into 4 submodules: standalone/mod.rs, spawn.rs, discovery.rs, stub.rs
- Extract proto wire primitives into proto/wire.rs
- Remove 6 dead MitmStore methods
- Remove dead SessionResult, DEFAULT_SESSION, get_or_create
- Remove dead decode_varint_at, extract_conversation_id
- Clean all unused imports across 10 files
- Suppress structural dead_code warnings on deserialization fields

Warnings: 20 -> 0. All 43 tests pass.
1114 lines
49 KiB
Rust
1114 lines
49 KiB
Rust
//! OpenAI Chat Completions API (/v1/chat/completions) handler.
|
|
|
|
use axum::{
|
|
extract::State,
|
|
http::StatusCode,
|
|
response::{sse::Event, IntoResponse, Json, Sse},
|
|
};
|
|
use rand::Rng;
|
|
use std::sync::Arc;
|
|
use tracing::{debug, info, warn};
|
|
|
|
use super::models::{lookup_model, DEFAULT_MODEL, MODELS};
|
|
use super::polling::{
|
|
extract_response_text, extract_thinking_content, is_response_done, poll_for_response,
|
|
};
|
|
use super::types::*;
|
|
use super::util::{err_response, now_unix, upstream_err_response};
|
|
use super::AppState;
|
|
use crate::mitm::store::{CapturedFunctionCall, PendingToolResult, ToolRound};
|
|
|
|
|
|
|
|
|
|
|
|
/// System fingerprint for completions responses (derived from crate version at compile time).
|
|
fn system_fingerprint() -> String {
|
|
format!("fp_{}", env!("CARGO_PKG_VERSION").replace('.', ""))
|
|
}
|
|
|
|
/// Build a streaming chunk JSON with all required OpenAI fields.
|
|
/// Includes system_fingerprint, service_tier, and logprobs:null in choices.
|
|
fn chunk_json(
|
|
id: &str,
|
|
model: &str,
|
|
choices: serde_json::Value,
|
|
usage: Option<serde_json::Value>,
|
|
) -> String {
|
|
let mut obj = serde_json::json!({
|
|
"id": id,
|
|
"object": "chat.completion.chunk",
|
|
"created": now_unix(),
|
|
"model": model,
|
|
"system_fingerprint": system_fingerprint(),
|
|
"service_tier": "default",
|
|
"choices": choices,
|
|
});
|
|
if let Some(u) = usage {
|
|
obj["usage"] = u;
|
|
}
|
|
serde_json::to_string(&obj).unwrap_or_default()
|
|
}
|
|
|
|
/// Build a single choice for a streaming chunk (delta + finish_reason + logprobs).
|
|
fn chunk_choice(
|
|
index: u32,
|
|
delta: serde_json::Value,
|
|
finish_reason: Option<&str>,
|
|
) -> serde_json::Value {
|
|
serde_json::json!({
|
|
"index": index,
|
|
"delta": delta,
|
|
"logprobs": serde_json::Value::Null,
|
|
"finish_reason": finish_reason,
|
|
})
|
|
}
|
|
|
|
// ─── Finish reason mapping ───────────────────────────────────────────────────
|
|
|
|
/// Map Google's `finishReason` onto OpenAI's `finish_reason`.
///
/// * `MAX_TOKENS` → `"length"`
/// * safety / policy stops (`SAFETY`, `RECITATION`, `BLOCKLIST`,
///   `PROHIBITED_CONTENT`, `SPII`, `IMAGE_SAFETY`) → `"content_filter"`
/// * everything else (`STOP`, `OTHER`, unrecognised values, or no reason at
///   all) → `"stop"`
///
/// OpenAI's `"tool_calls"` is never produced here. Callers emit it directly
/// when the model returned function calls.
fn google_to_openai_finish_reason(stop_reason: Option<&str>) -> &'static str {
    match stop_reason {
        Some("MAX_TOKENS") => "length",
        // SPII (sensitive personal data) and IMAGE_SAFETY are also
        // content-policy terminations, so they map like SAFETY.
        Some(
            "SAFETY" | "RECITATION" | "BLOCKLIST" | "PROHIBITED_CONTENT" | "SPII" | "IMAGE_SAFETY",
        ) => "content_filter",
        _ => "stop",
    }
}
|
|
|
|
// ─── Input extraction ────────────────────────────────────────────────────────
|
|
|
|
/// Extract user text from Chat Completions messages array.
|
|
///
|
|
/// Builds the full conversation context including all messages (system, user,
|
|
/// assistant, tool) so the model has complete history — matching how OpenAI
|
|
/// sends the entire messages array to the model.
|
|
fn extract_chat_input(messages: &[CompletionMessage]) -> (String, Option<crate::proto::ImageData>) {
|
|
// Extract image from last user message content array
|
|
let image = messages
|
|
.iter()
|
|
.rev()
|
|
.find(|m| m.role == "user")
|
|
.and_then(|m| super::util::extract_first_image(&m.content));
|
|
// Always build the full conversation
|
|
(build_conversation_with_tools(messages), image)
|
|
}
|
|
|
|
/// Extract text content from a message's content field (string or array).
|
|
fn extract_message_text(content: &serde_json::Value) -> String {
|
|
match content {
|
|
serde_json::Value::String(s) => s.clone(),
|
|
serde_json::Value::Array(arr) => arr
|
|
.iter()
|
|
.filter_map(|item| item["text"].as_str())
|
|
.collect::<Vec<_>>()
|
|
.join("\n"),
|
|
_ => String::new(),
|
|
}
|
|
}
|
|
|
|
/// Build conversation text that includes tool call results.
|
|
///
|
|
/// Format:
|
|
/// [system prompt]
|
|
/// [user message]
|
|
/// [assistant called tool X with args Y]
|
|
/// [tool result: Z]
|
|
/// [user followup if any]
|
|
fn build_conversation_with_tools(messages: &[CompletionMessage]) -> String {
|
|
let mut parts = Vec::new();
|
|
|
|
for msg in messages {
|
|
match msg.role.as_str() {
|
|
"system" | "developer" => {
|
|
let text = extract_message_text(&msg.content);
|
|
if !text.is_empty() {
|
|
parts.push(text);
|
|
}
|
|
}
|
|
"user" => {
|
|
let text = extract_message_text(&msg.content);
|
|
if !text.is_empty() {
|
|
parts.push(text);
|
|
}
|
|
}
|
|
"assistant" => {
|
|
// Include assistant text if any
|
|
let text = extract_message_text(&msg.content);
|
|
if !text.is_empty() {
|
|
parts.push(text);
|
|
}
|
|
// Include tool calls as context
|
|
if let Some(ref tool_calls) = msg.tool_calls {
|
|
for tc in tool_calls {
|
|
if let Some(func) = tc.get("function") {
|
|
let name = func["name"].as_str().unwrap_or("unknown");
|
|
let args = func["arguments"].as_str().unwrap_or("{}");
|
|
parts.push(format!("[Tool call: {}({})]", name, args));
|
|
}
|
|
}
|
|
}
|
|
}
|
|
"tool" => {
|
|
let text = extract_message_text(&msg.content);
|
|
let tool_id = msg.tool_call_id.as_deref().unwrap_or("unknown");
|
|
if !text.is_empty() {
|
|
parts.push(format!("[Tool result ({})]:\n{}", tool_id, text));
|
|
}
|
|
}
|
|
_ => {}
|
|
}
|
|
}
|
|
|
|
parts.join("\n\n")
|
|
}
|
|
|
|
// ─── Handler ─────────────────────────────────────────────────────────────────
|
|
|
|
/// POST /v1/chat/completions — OpenAI Chat Completions API compatibility shim.
|
|
/// Accepts standard messages format, reuses the same backend cascade, and
|
|
/// outputs in the Chat Completions streaming/sync format.
|
|
pub(crate) async fn handle_completions(
|
|
State(state): State<Arc<AppState>>,
|
|
Json(body): Json<CompletionRequest>,
|
|
) -> axum::response::Response {
|
|
let model_name = body.model.as_deref().unwrap_or(DEFAULT_MODEL);
|
|
info!(
|
|
"POST /v1/chat/completions model={} stream={}",
|
|
model_name, body.stream
|
|
);
|
|
|
|
|
|
|
|
let model = match lookup_model(model_name) {
|
|
Some(m) => m,
|
|
None => {
|
|
let names: Vec<&str> = MODELS.iter().map(|m| m.name).collect();
|
|
return err_response(
|
|
StatusCode::BAD_REQUEST,
|
|
format!("Unknown model: {model_name}. Available: {names:?}"),
|
|
"invalid_request_error",
|
|
);
|
|
}
|
|
};
|
|
|
|
// ── Build per-request state locally ──────────────────────────────────
|
|
|
|
// Convert OpenAI tools to Gemini format
|
|
let tools = body.tools.as_ref().and_then(|t| {
|
|
let gemini_tools = crate::mitm::modify::openai_tools_to_gemini(t);
|
|
if gemini_tools.is_empty() { None } else {
|
|
info!(count = t.len(), "Completions: client tools for MITM injection");
|
|
Some(gemini_tools)
|
|
}
|
|
});
|
|
let tool_config = body.tools.as_ref().and_then(|_| {
|
|
body.tool_choice.as_ref().map(|choice| {
|
|
crate::mitm::modify::openai_tool_choice_to_gemini(choice)
|
|
})
|
|
});
|
|
|
|
// ── Extract tool results from messages for MITM injection ──────────
|
|
// Build ToolRounds from message history: each round pairs assistant tool_calls
|
|
// with subsequent tool result messages. Local call_id_to_name mapping.
|
|
let mut tool_rounds: Vec<ToolRound> = Vec::new();
|
|
let mut call_id_to_name: std::collections::HashMap<String, String> = std::collections::HashMap::new();
|
|
{
|
|
let mut current_round: Option<ToolRound> = None;
|
|
|
|
for msg in &body.messages {
|
|
match msg.role.as_str() {
|
|
"assistant" => {
|
|
// Finalize any open round
|
|
if let Some(round) = current_round.take() {
|
|
if !round.calls.is_empty() {
|
|
tool_rounds.push(round);
|
|
}
|
|
}
|
|
// Start new round if this assistant has tool_calls
|
|
if let Some(ref tool_calls) = msg.tool_calls {
|
|
let mut calls = Vec::new();
|
|
for tc in tool_calls {
|
|
if let Some(func) = tc.get("function") {
|
|
let name = func["name"].as_str().unwrap_or("unknown").to_string();
|
|
let args_str = func["arguments"].as_str().unwrap_or("{}");
|
|
let args = serde_json::from_str::<serde_json::Value>(args_str)
|
|
.unwrap_or(serde_json::json!({}));
|
|
let call_id = tc["id"].as_str().unwrap_or("").to_string();
|
|
|
|
// Register call_id → name locally
|
|
if !call_id.is_empty() {
|
|
call_id_to_name.insert(call_id, name.clone());
|
|
}
|
|
|
|
calls.push(CapturedFunctionCall {
|
|
name,
|
|
args,
|
|
thought_signature: None,
|
|
captured_at: std::time::SystemTime::now()
|
|
.duration_since(std::time::UNIX_EPOCH)
|
|
.unwrap_or_default()
|
|
.as_secs(),
|
|
});
|
|
}
|
|
}
|
|
if !calls.is_empty() {
|
|
current_round = Some(ToolRound {
|
|
calls,
|
|
results: Vec::new(),
|
|
});
|
|
}
|
|
}
|
|
}
|
|
"tool" => {
|
|
let text = extract_message_text(&msg.content);
|
|
if let Some(ref call_id) = msg.tool_call_id {
|
|
let result_index = current_round
|
|
.as_ref()
|
|
.map(|r| r.results.len())
|
|
.unwrap_or(0);
|
|
let name = call_id_to_name
|
|
.get(call_id.as_str())
|
|
.cloned()
|
|
.unwrap_or_else(|| {
|
|
current_round
|
|
.as_ref()
|
|
.and_then(|r| r.calls.get(result_index))
|
|
.map(|fc| fc.name.clone())
|
|
.unwrap_or_else(|| "unknown_function".to_string())
|
|
});
|
|
|
|
let result_value = serde_json::from_str::<serde_json::Value>(&text)
|
|
.unwrap_or_else(|_| serde_json::json!({"result": text}));
|
|
|
|
if let Some(ref mut round) = current_round {
|
|
round.results.push(PendingToolResult {
|
|
name,
|
|
result: result_value,
|
|
});
|
|
}
|
|
}
|
|
}
|
|
_ => {
|
|
// Any other role (user, system) finalizes the current round
|
|
if let Some(round) = current_round.take() {
|
|
if !round.calls.is_empty() {
|
|
tool_rounds.push(round);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
// Finalize last round
|
|
if let Some(round) = current_round.take() {
|
|
if !round.calls.is_empty() {
|
|
tool_rounds.push(round);
|
|
}
|
|
}
|
|
|
|
if !tool_rounds.is_empty() {
|
|
info!(
|
|
round_count = tool_rounds.len(),
|
|
calls = ?tool_rounds.iter().map(|r| r.calls.iter().map(|c| &c.name).collect::<Vec<_>>()).collect::<Vec<_>>(),
|
|
"Completions: {} tool round(s) for MITM history rewrite",
|
|
tool_rounds.len(),
|
|
);
|
|
|
|
// Merge thought_signatures from MITM-captured function calls.
|
|
// OpenAI format doesn't carry thought signatures, but Google requires
|
|
// them when injecting functionCall parts back into history.
|
|
let sigs = state.mitm_store.peek_thought_signatures().await;
|
|
if !sigs.is_empty() {
|
|
let mut merged = 0usize;
|
|
for round in &mut tool_rounds {
|
|
for fc in &mut round.calls {
|
|
if fc.thought_signature.is_none() {
|
|
if let Some(sig) = sigs.get(&fc.name) {
|
|
fc.thought_signature = Some(sig.clone());
|
|
merged += 1;
|
|
}
|
|
}
|
|
}
|
|
}
|
|
if merged > 0 {
|
|
info!(
|
|
merged_count = merged,
|
|
"Completions: merged {} thought_signature(s) from MITM capture",
|
|
merged,
|
|
);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// Build generation parameters locally
|
|
use crate::mitm::store::GenerationParams;
|
|
let (response_mime_type, response_schema) = match body.response_format.as_ref() {
|
|
Some(rf) => match rf.format_type.as_str() {
|
|
"json_object" | "json" => (Some("application/json".to_string()), None),
|
|
"json_schema" => {
|
|
let schema = rf.json_schema.as_ref().and_then(|js| js.schema.clone());
|
|
(Some("application/json".to_string()), schema)
|
|
}
|
|
_ => (None, None),
|
|
},
|
|
None => (None, None),
|
|
};
|
|
let gp = GenerationParams {
|
|
temperature: body.temperature,
|
|
top_p: body.top_p,
|
|
top_k: None,
|
|
max_output_tokens: body.max_tokens.or(body.max_completion_tokens),
|
|
stop_sequences: body.stop.clone().map(|s| s.into_vec()),
|
|
frequency_penalty: body.frequency_penalty,
|
|
presence_penalty: body.presence_penalty,
|
|
reasoning_effort: body.reasoning_effort.clone(),
|
|
response_mime_type,
|
|
response_schema,
|
|
google_search: body.web_search,
|
|
};
|
|
let generation_params = if gp.temperature.is_some()
|
|
|| gp.top_p.is_some()
|
|
|| gp.max_output_tokens.is_some()
|
|
|| gp.frequency_penalty.is_some()
|
|
|| gp.presence_penalty.is_some()
|
|
|| gp.reasoning_effort.is_some()
|
|
|| gp.stop_sequences.is_some()
|
|
|| gp.response_mime_type.is_some()
|
|
|| gp.response_schema.is_some()
|
|
|| gp.google_search
|
|
{
|
|
Some(gp)
|
|
} else {
|
|
None
|
|
};
|
|
|
|
let token = state.backend.oauth_token().await;
|
|
if token.is_empty() {
|
|
return err_response(
|
|
StatusCode::UNAUTHORIZED,
|
|
"No OAuth token. POST to /v1/token or set ANTIGRAVITY_OAUTH_TOKEN env var.".into(),
|
|
"authentication_error",
|
|
);
|
|
}
|
|
|
|
let (user_text, image) = extract_chat_input(&body.messages);
|
|
if user_text.is_empty() {
|
|
return err_response(
|
|
StatusCode::BAD_REQUEST,
|
|
"No user message found".to_string(),
|
|
"invalid_request_error",
|
|
);
|
|
}
|
|
|
|
let n = (body.n.max(1)).min(5); // Cap at 5 to prevent abuse
|
|
if n > 1 && body.stream {
|
|
warn!("n={n} requested with streaming — streaming only supports n=1, ignoring n");
|
|
}
|
|
|
|
// Always create a new cascade for every request
|
|
let cascade_id = match state.backend.create_cascade().await {
|
|
Ok(cid) => cid,
|
|
Err(e) => {
|
|
return err_response(
|
|
StatusCode::BAD_GATEWAY,
|
|
format!("StartCascade failed: {e}"),
|
|
"server_error",
|
|
);
|
|
}
|
|
};
|
|
|
|
// Image for MITM injection
|
|
let pending_image = image.as_ref().map(|img| {
|
|
use base64::Engine;
|
|
crate::mitm::store::PendingImage {
|
|
base64_data: base64::engine::general_purpose::STANDARD.encode(&img.data),
|
|
mime_type: img.mime_type.clone(),
|
|
}
|
|
});
|
|
|
|
// Get last calls from the latest tool round (if any) for proxy recording compat
|
|
let last_function_calls = tool_rounds.last()
|
|
.map(|r| r.calls.clone())
|
|
.unwrap_or_default();
|
|
|
|
// Build event channel for streaming
|
|
let has_custom_tools = tools.is_some();
|
|
let (mitm_rx, event_tx) = if has_custom_tools && body.stream {
|
|
let (tx, rx) = tokio::sync::mpsc::channel(64);
|
|
(Some(rx), Some(tx))
|
|
} else {
|
|
(None, None)
|
|
};
|
|
|
|
// Build pending tool results from latest round
|
|
let pending_tool_results = tool_rounds.last()
|
|
.map(|r| r.results.clone())
|
|
.unwrap_or_default();
|
|
|
|
// Register all per-request state atomically
|
|
state.mitm_store.register_request(crate::mitm::store::RequestContext {
|
|
cascade_id: cascade_id.clone(),
|
|
pending_user_text: user_text.clone(),
|
|
event_channel: event_tx,
|
|
generation_params,
|
|
pending_image,
|
|
tools,
|
|
tool_config,
|
|
pending_tool_results,
|
|
tool_rounds,
|
|
last_function_calls,
|
|
call_id_to_name,
|
|
created_at: std::time::Instant::now(),
|
|
}).await;
|
|
|
|
// Send REAL user text to LS
|
|
match state
|
|
.backend
|
|
.send_message_with_image(&cascade_id, &format!(".<cid:{}>", cascade_id), model.model_enum, image.as_ref())
|
|
.await
|
|
{
|
|
Ok((200, _)) => {
|
|
let bg = Arc::clone(&state.backend);
|
|
let cid = cascade_id.clone();
|
|
tokio::spawn(async move {
|
|
let _ = bg.update_annotations(&cid).await;
|
|
});
|
|
}
|
|
Ok((status, _)) => {
|
|
state.mitm_store.remove_request(&cascade_id).await;
|
|
return err_response(
|
|
StatusCode::BAD_GATEWAY,
|
|
format!("Backend returned {status}"),
|
|
"server_error",
|
|
);
|
|
}
|
|
Err(e) => {
|
|
state.mitm_store.remove_request(&cascade_id).await;
|
|
return err_response(
|
|
StatusCode::BAD_GATEWAY,
|
|
format!("Send failed: {e}"),
|
|
"server_error",
|
|
);
|
|
}
|
|
}
|
|
|
|
let completion_id = format!(
|
|
"chatcmpl-{}",
|
|
uuid::Uuid::new_v4().to_string().replace('-', "")
|
|
);
|
|
|
|
let include_usage = body
|
|
.stream_options
|
|
.as_ref()
|
|
.map_or(false, |o| o.include_usage);
|
|
|
|
if body.stream {
|
|
chat_completions_stream(
|
|
state,
|
|
completion_id,
|
|
model_name.to_string(),
|
|
cascade_id,
|
|
body.timeout,
|
|
include_usage,
|
|
mitm_rx,
|
|
)
|
|
.await
|
|
} else if n <= 1 {
|
|
chat_completions_sync(
|
|
state,
|
|
completion_id,
|
|
model_name.to_string(),
|
|
cascade_id,
|
|
body.timeout,
|
|
)
|
|
.await
|
|
} else {
|
|
// n > 1: fire additional (n-1) parallel cascades
|
|
let mut extra_cascade_ids = Vec::with_capacity((n - 1) as usize);
|
|
for _ in 1..n {
|
|
match state.backend.create_cascade().await {
|
|
Ok(cid) => {
|
|
// Send the same message on each extra cascade
|
|
match state
|
|
.backend
|
|
.send_message_with_image(&cid, &format!(".<cid:{}>", cid), model.model_enum, image.as_ref())
|
|
.await
|
|
{
|
|
Ok((200, _)) => {
|
|
let bg = Arc::clone(&state.backend);
|
|
let cid2 = cid.clone();
|
|
tokio::spawn(async move {
|
|
let _ = bg.update_annotations(&cid2).await;
|
|
});
|
|
extra_cascade_ids.push(cid);
|
|
}
|
|
_ => {} // Skip failed cascades
|
|
}
|
|
}
|
|
Err(_) => {} // Skip failed cascade creation
|
|
}
|
|
}
|
|
|
|
// Poll all cascades in parallel
|
|
let mut handles = Vec::with_capacity(n as usize);
|
|
let all_cascade_ids: Vec<String> = std::iter::once(cascade_id.clone())
|
|
.chain(extra_cascade_ids)
|
|
.collect();
|
|
|
|
for cid in &all_cascade_ids {
|
|
let st = Arc::clone(&state);
|
|
let cid = cid.clone();
|
|
let timeout = body.timeout;
|
|
handles.push(tokio::spawn(async move {
|
|
let result = poll_for_response(&st, &cid, timeout).await;
|
|
let mitm = match st.mitm_store.take_usage(&cid).await {
|
|
Some(u) => Some(u),
|
|
None => st.mitm_store.take_usage("_latest").await,
|
|
};
|
|
(result, mitm)
|
|
}));
|
|
}
|
|
|
|
let mut choices = Vec::with_capacity(n as usize);
|
|
let mut total_prompt = 0u64;
|
|
let mut total_completion = 0u64;
|
|
let mut total_cached = 0u64;
|
|
let mut total_thinking = 0u64;
|
|
|
|
for (i, handle) in handles.into_iter().enumerate() {
|
|
if let Ok((result, mitm)) = handle.await {
|
|
let finish_reason = google_to_openai_finish_reason(
|
|
mitm.as_ref().and_then(|u| u.stop_reason.as_deref()),
|
|
);
|
|
let (pt, ct, cached, thinking) = if let Some(ref mu) = mitm {
|
|
(
|
|
mu.input_tokens,
|
|
mu.output_tokens,
|
|
mu.cache_read_input_tokens,
|
|
mu.thinking_output_tokens,
|
|
)
|
|
} else if let Some(u) = &result.usage {
|
|
(u.input_tokens, u.output_tokens, 0, 0)
|
|
} else {
|
|
(0, 0, 0, 0)
|
|
};
|
|
total_prompt += pt;
|
|
total_completion += ct;
|
|
total_cached += cached;
|
|
total_thinking += thinking;
|
|
|
|
let mut message = serde_json::json!({
|
|
"role": "assistant",
|
|
"content": result.text,
|
|
});
|
|
if let Some(ref thinking_text) = result.thinking {
|
|
message["reasoning_content"] = serde_json::json!(thinking_text);
|
|
}
|
|
|
|
choices.push(serde_json::json!({
|
|
"index": i,
|
|
"message": message,
|
|
"logprobs": serde_json::Value::Null,
|
|
"finish_reason": finish_reason,
|
|
}));
|
|
}
|
|
}
|
|
|
|
Json(serde_json::json!({
|
|
"id": completion_id,
|
|
"object": "chat.completion",
|
|
"created": now_unix(),
|
|
"model": model_name,
|
|
"system_fingerprint": system_fingerprint(),
|
|
"service_tier": "default",
|
|
"choices": choices,
|
|
"usage": {
|
|
"prompt_tokens": total_prompt,
|
|
"completion_tokens": total_completion,
|
|
"total_tokens": total_prompt + total_completion,
|
|
"prompt_tokens_details": {
|
|
"cached_tokens": total_cached,
|
|
},
|
|
"completion_tokens_details": {
|
|
"reasoning_tokens": total_thinking,
|
|
},
|
|
},
|
|
}))
|
|
.into_response()
|
|
}
|
|
}
|
|
|
|
// ─── Streaming ───────────────────────────────────────────────────────────────
|
|
|
|
/// Streaming output in Chat Completions format.
|
|
async fn chat_completions_stream(
|
|
state: Arc<AppState>,
|
|
completion_id: String,
|
|
model_name: String,
|
|
cascade_id: String,
|
|
timeout: u64,
|
|
include_usage: bool,
|
|
mitm_rx: Option<tokio::sync::mpsc::Receiver<crate::mitm::store::MitmEvent>>,
|
|
) -> axum::response::Response {
|
|
let stream = async_stream::stream! {
|
|
let start = std::time::Instant::now();
|
|
let mut last_text = String::new();
|
|
let has_custom_tools = mitm_rx.is_some();
|
|
|
|
if !has_custom_tools {
|
|
state.mitm_store.clear_response_async().await;
|
|
state.mitm_store.clear_upstream_error().await;
|
|
let _ = state.mitm_store.take_any_function_calls().await;
|
|
}
|
|
|
|
// Initial role chunk
|
|
yield Ok::<_, std::convert::Infallible>(Event::default().data(chunk_json(
|
|
&completion_id, &model_name,
|
|
serde_json::json!([chunk_choice(0, serde_json::json!({"role": "assistant", "content": ""}), None)]),
|
|
None,
|
|
)));
|
|
|
|
let mut keepalive_counter: u64 = 0;
|
|
let mut last_thinking_len: usize = 0;
|
|
let mut complete_polls: u32 = 0;
|
|
let mut did_unblock_ls = false; // Prevents infinite unblock loops
|
|
|
|
// Helper: build usage JSON from MITM tokens
|
|
let build_usage = |pt: u64, ct: u64, crt: u64, tt: u64| -> serde_json::Value {
|
|
serde_json::json!({
|
|
"prompt_tokens": pt,
|
|
"completion_tokens": ct,
|
|
"total_tokens": pt + ct,
|
|
"prompt_tokens_details": { "cached_tokens": crt },
|
|
"completion_tokens_details": { "reasoning_tokens": tt },
|
|
})
|
|
};
|
|
|
|
// Take ownership of the pre-installed channel receiver
|
|
let mut rx_opt = mitm_rx;
|
|
|
|
while start.elapsed().as_secs() < timeout {
|
|
if let Some(ref mut rx) = rx_opt {
|
|
// ── Channel-based MITM pipeline ──
|
|
|
|
// Track accumulated text for delta computation
|
|
let mut acc_text = String::new();
|
|
let mut acc_thinking = String::new();
|
|
let mut last_usage: Option<crate::mitm::store::ApiUsage> = None;
|
|
|
|
'channel_loop: while let Some(event) = tokio::time::timeout(
|
|
std::time::Duration::from_secs(timeout.saturating_sub(start.elapsed().as_secs())),
|
|
rx.recv(),
|
|
).await.ok().flatten() {
|
|
use crate::mitm::store::MitmEvent;
|
|
match event {
|
|
MitmEvent::ThinkingDelta(full_thinking) => {
|
|
if full_thinking.len() > acc_thinking.len() {
|
|
let delta = full_thinking[acc_thinking.len()..].to_string();
|
|
acc_thinking = full_thinking;
|
|
last_thinking_len = acc_thinking.len();
|
|
yield Ok(Event::default().data(chunk_json(
|
|
&completion_id, &model_name,
|
|
serde_json::json!([chunk_choice(0, serde_json::json!({"reasoning_content": delta}), None)]),
|
|
None,
|
|
)));
|
|
}
|
|
}
|
|
MitmEvent::TextDelta(full_text) => {
|
|
if full_text.len() > acc_text.len() {
|
|
let delta = full_text[acc_text.len()..].to_string();
|
|
acc_text = full_text;
|
|
last_text = acc_text.clone();
|
|
yield Ok(Event::default().data(chunk_json(
|
|
&completion_id, &model_name,
|
|
serde_json::json!([chunk_choice(0, serde_json::json!({"content": delta}), None)]),
|
|
None,
|
|
)));
|
|
}
|
|
}
|
|
MitmEvent::FunctionCall(calls) => {
|
|
let mut tool_calls = Vec::new();
|
|
for (i, fc) in calls.iter().enumerate() {
|
|
let call_id = format!(
|
|
"call_{}",
|
|
uuid::Uuid::new_v4().to_string().replace('-', "")[..24].to_string()
|
|
);
|
|
let arguments = serde_json::to_string(&fc.args).unwrap_or_default();
|
|
tool_calls.push(serde_json::json!({
|
|
"index": i,
|
|
"id": call_id,
|
|
"type": "function",
|
|
"function": {
|
|
"name": fc.name,
|
|
"arguments": arguments,
|
|
},
|
|
}));
|
|
}
|
|
yield Ok(Event::default().data(chunk_json(
|
|
&completion_id, &model_name,
|
|
serde_json::json!([chunk_choice(0, serde_json::json!({"tool_calls": tool_calls}), None)]),
|
|
None,
|
|
)));
|
|
yield Ok(Event::default().data(chunk_json(
|
|
&completion_id, &model_name,
|
|
serde_json::json!([chunk_choice(0, serde_json::json!({}), Some("tool_calls"))]),
|
|
None,
|
|
)));
|
|
if include_usage {
|
|
let mitm = state.mitm_store.take_usage(&cascade_id).await
|
|
.or(state.mitm_store.take_usage("_latest").await);
|
|
let (pt, ct, crt, tt) = if let Some(ref u) = mitm {
|
|
(u.input_tokens, u.output_tokens, u.cache_read_input_tokens, u.thinking_output_tokens)
|
|
} else if let Some(ref u) = last_usage {
|
|
(u.input_tokens, u.output_tokens, u.cache_read_input_tokens, u.thinking_output_tokens)
|
|
} else { (0, 0, 0, 0) };
|
|
yield Ok(Event::default().data(chunk_json(
|
|
&completion_id, &model_name,
|
|
serde_json::json!([]),
|
|
Some(build_usage(pt, ct, crt, tt)),
|
|
)));
|
|
}
|
|
yield Ok(Event::default().data("[DONE]"));
|
|
state.mitm_store.remove_request(&cascade_id).await;
|
|
return;
|
|
}
|
|
MitmEvent::ResponseComplete => {
|
|
if !acc_text.is_empty() {
|
|
// Have response text — done
|
|
debug!("Completions: channel response complete, text_len={}, thinking_len={}",
|
|
acc_text.len(), acc_thinking.len());
|
|
let mitm = state.mitm_store.take_usage(&cascade_id).await
|
|
.or(state.mitm_store.take_usage("_latest").await)
|
|
.or(last_usage.take());
|
|
let fr = google_to_openai_finish_reason(mitm.as_ref().and_then(|u| u.stop_reason.as_deref()));
|
|
yield Ok(Event::default().data(chunk_json(
|
|
&completion_id, &model_name,
|
|
serde_json::json!([chunk_choice(0, serde_json::json!({}), Some(fr))]),
|
|
None,
|
|
)));
|
|
if include_usage {
|
|
let (pt, ct, crt, tt) = if let Some(ref u) = mitm {
|
|
(u.input_tokens, u.output_tokens, u.cache_read_input_tokens, u.thinking_output_tokens)
|
|
} else { (0, 0, 0, 0) };
|
|
yield Ok(Event::default().data(chunk_json(
|
|
&completion_id, &model_name,
|
|
serde_json::json!([]),
|
|
Some(build_usage(pt, ct, crt, tt)),
|
|
)));
|
|
}
|
|
yield Ok(Event::default().data("[DONE]"));
|
|
state.mitm_store.remove_request(&cascade_id).await;
|
|
return;
|
|
} else if !acc_thinking.is_empty() && !did_unblock_ls {
|
|
// Thinking-only response — LS needs follow-up API calls.
|
|
// Create a new channel and unblock the gate.
|
|
did_unblock_ls = true;
|
|
let (new_tx, new_rx) = tokio::sync::mpsc::channel(64);
|
|
state.mitm_store.set_channel(&cascade_id, new_tx).await;
|
|
|
|
let _ = state.mitm_store.take_any_function_calls().await;
|
|
*rx = new_rx;
|
|
debug!(
|
|
"Completions: thinking-only — new channel for follow-up, thinking_len={}",
|
|
acc_thinking.len()
|
|
);
|
|
continue 'channel_loop;
|
|
} else if !acc_thinking.is_empty() && did_unblock_ls {
|
|
// Already unblocked once, still thinking-only.
|
|
// Wait a bit for potential follow-up events.
|
|
complete_polls += 1;
|
|
if complete_polls >= 25 {
|
|
info!("Completions: thinking-only timeout, thinking_len={}", acc_thinking.len());
|
|
let mitm = state.mitm_store.take_usage(&cascade_id).await
|
|
.or(state.mitm_store.take_usage("_latest").await)
|
|
.or(last_usage.take());
|
|
let fr = google_to_openai_finish_reason(mitm.as_ref().and_then(|u| u.stop_reason.as_deref()));
|
|
yield Ok(Event::default().data(chunk_json(
|
|
&completion_id, &model_name,
|
|
serde_json::json!([chunk_choice(0, serde_json::json!({}), Some(fr))]),
|
|
None,
|
|
)));
|
|
if include_usage {
|
|
let (pt, ct, crt, tt) = if let Some(ref u) = mitm {
|
|
(u.input_tokens, u.output_tokens, u.cache_read_input_tokens, u.thinking_output_tokens)
|
|
} else { (0, 0, 0, 0) };
|
|
yield Ok(Event::default().data(chunk_json(
|
|
&completion_id, &model_name,
|
|
serde_json::json!([]),
|
|
Some(build_usage(pt, ct, crt, tt)),
|
|
)));
|
|
}
|
|
yield Ok(Event::default().data("[DONE]"));
|
|
state.mitm_store.remove_request(&cascade_id).await;
|
|
return;
|
|
}
|
|
// Don't break — wait for more channel events
|
|
continue 'channel_loop;
|
|
} else {
|
|
// Empty response (no text, no thinking, no tools)
|
|
complete_polls += 1;
|
|
if complete_polls >= 4 {
|
|
info!("Completions: channel response complete but empty, ending stream");
|
|
yield Ok(Event::default().data(chunk_json(
|
|
&completion_id, &model_name,
|
|
serde_json::json!([chunk_choice(0, serde_json::json!({}), Some("stop"))]),
|
|
None,
|
|
)));
|
|
yield Ok(Event::default().data("[DONE]"));
|
|
state.mitm_store.remove_request(&cascade_id).await;
|
|
return;
|
|
}
|
|
continue 'channel_loop;
|
|
}
|
|
}
|
|
MitmEvent::UpstreamError(err) => {
|
|
let error_msg = super::util::upstream_error_message(&err);
|
|
let error_type = super::util::upstream_error_type(&err);
|
|
yield Ok(Event::default().data(serde_json::to_string(&serde_json::json!({
|
|
"error": {
|
|
"message": error_msg,
|
|
"type": error_type,
|
|
"code": err.status,
|
|
}
|
|
})).unwrap()));
|
|
yield Ok(Event::default().data("[DONE]"));
|
|
state.mitm_store.remove_request(&cascade_id).await;
|
|
return;
|
|
}
|
|
MitmEvent::Usage(u) => {
|
|
last_usage = Some(u);
|
|
}
|
|
MitmEvent::Grounding(_) => {
|
|
// Grounding metadata handled by store directly
|
|
}
|
|
}
|
|
}
|
|
|
|
// Channel closed or timeout — clean up
|
|
state.mitm_store.remove_request(&cascade_id).await;
|
|
|
|
// If we got here from timeout with content, emit what we have
|
|
if !last_text.is_empty() || last_thinking_len > 0 {
|
|
yield Ok(Event::default().data(chunk_json(
|
|
&completion_id, &model_name,
|
|
serde_json::json!([chunk_choice(0, serde_json::json!({}), Some("stop"))]),
|
|
None,
|
|
)));
|
|
}
|
|
yield Ok(Event::default().data("[DONE]"));
|
|
return;
|
|
} else {
|
|
// ── Fallback: LS steps (no MITM capture active) ──
|
|
if let Ok((status, data)) = state.backend.get_steps(&cascade_id).await {
|
|
if status == 200 {
|
|
if let Some(steps) = data["steps"].as_array() {
|
|
// Stream thinking deltas (reasoning_content)
|
|
if let Some(tc) = extract_thinking_content(steps) {
|
|
if tc.len() > last_thinking_len {
|
|
let delta = &tc[last_thinking_len..];
|
|
last_thinking_len = tc.len();
|
|
|
|
yield Ok(Event::default().data(chunk_json(
|
|
&completion_id, &model_name,
|
|
serde_json::json!([chunk_choice(0, serde_json::json!({"reasoning_content": delta}), None)]),
|
|
None,
|
|
)));
|
|
}
|
|
}
|
|
|
|
let text = extract_response_text(steps);
|
|
|
|
if !text.is_empty() && text != last_text {
|
|
let delta = if text.len() > last_text.len() && text.starts_with(&*last_text) {
|
|
&text[last_text.len()..]
|
|
} else {
|
|
&text
|
|
};
|
|
|
|
if !delta.is_empty() {
|
|
yield Ok(Event::default().data(chunk_json(
|
|
&completion_id, &model_name,
|
|
serde_json::json!([chunk_choice(0, serde_json::json!({"content": delta}), None)]),
|
|
None,
|
|
)));
|
|
last_text = text.to_string();
|
|
}
|
|
}
|
|
|
|
// Done check
|
|
let has_content = !last_text.is_empty() || last_thinking_len > 0;
|
|
if is_response_done(steps) && has_content {
|
|
debug!("Completions stream done, text length={}, thinking_len={}", last_text.len(), last_thinking_len);
|
|
let mitm = state.mitm_store.take_usage(&cascade_id).await
|
|
.or(state.mitm_store.take_usage("_latest").await);
|
|
let fr = google_to_openai_finish_reason(mitm.as_ref().and_then(|u| u.stop_reason.as_deref()));
|
|
yield Ok(Event::default().data(chunk_json(
|
|
&completion_id, &model_name,
|
|
serde_json::json!([chunk_choice(0, serde_json::json!({}), Some(fr))]),
|
|
None,
|
|
)));
|
|
if include_usage {
|
|
let (pt, ct, crt, tt) = if let Some(ref u) = mitm {
|
|
(u.input_tokens, u.output_tokens, u.cache_read_input_tokens, u.thinking_output_tokens)
|
|
} else { (0, 0, 0, 0) };
|
|
yield Ok(Event::default().data(chunk_json(
|
|
&completion_id, &model_name,
|
|
serde_json::json!([]),
|
|
Some(build_usage(pt, ct, crt, tt)),
|
|
)));
|
|
}
|
|
yield Ok(Event::default().data("[DONE]"));
|
|
return;
|
|
}
|
|
|
|
// IDLE fallback
|
|
let step_count = steps.len();
|
|
if step_count > 4 && step_count % 5 == 0 {
|
|
if let Ok((ts, td)) = state.backend.get_trajectory(&cascade_id).await {
|
|
if ts == 200 {
|
|
let run_status = td["status"].as_str().unwrap_or("");
|
|
let has_content_idle = !last_text.is_empty() || last_thinking_len > 0;
|
|
if run_status.contains("IDLE") && has_content_idle {
|
|
debug!("Completions IDLE, text length={}, thinking_len={}", last_text.len(), last_thinking_len);
|
|
let mitm = state.mitm_store.take_usage(&cascade_id).await
|
|
.or(state.mitm_store.take_usage("_latest").await);
|
|
let fr = google_to_openai_finish_reason(mitm.as_ref().and_then(|u| u.stop_reason.as_deref()));
|
|
yield Ok(Event::default().data(chunk_json(
|
|
&completion_id, &model_name,
|
|
serde_json::json!([chunk_choice(0, serde_json::json!({}), Some(fr))]),
|
|
None,
|
|
)));
|
|
if include_usage {
|
|
let (pt, ct, crt, tt) = if let Some(ref u) = mitm {
|
|
(u.input_tokens, u.output_tokens, u.cache_read_input_tokens, u.thinking_output_tokens)
|
|
} else { (0, 0, 0, 0) };
|
|
yield Ok(Event::default().data(chunk_json(
|
|
&completion_id, &model_name,
|
|
serde_json::json!([]),
|
|
Some(build_usage(pt, ct, crt, tt)),
|
|
)));
|
|
}
|
|
yield Ok(Event::default().data("[DONE]"));
|
|
return;
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// Keep-alive comment every ~5 iterations
|
|
keepalive_counter += 1;
|
|
if keepalive_counter % 5 == 0 {
|
|
yield Ok(Event::default().comment("keepalive"));
|
|
}
|
|
|
|
// Fast poll — 300ms so we pick up MITM captures quickly
|
|
let poll_ms: u64 = rand::thread_rng().gen_range(250..400);
|
|
tokio::time::sleep(tokio::time::Duration::from_millis(poll_ms)).await;
|
|
}
|
|
|
|
// Timeout — emit error, not placeholder content
|
|
warn!("Completions stream timeout after {}s", timeout);
|
|
yield Ok(Event::default().data(serde_json::to_string(&serde_json::json!({
|
|
"error": {
|
|
"message": format!("Timeout: no response from Google API after {timeout}s"),
|
|
"type": "upstream_error",
|
|
"code": 504,
|
|
}
|
|
})).unwrap()));
|
|
// Always clear in-flight flag when stream ends
|
|
state.mitm_store.remove_request(&cascade_id).await;
|
|
yield Ok(Event::default().data("[DONE]"));
|
|
};
|
|
|
|
Sse::new(stream)
|
|
.keep_alive(
|
|
axum::response::sse::KeepAlive::new()
|
|
.interval(std::time::Duration::from_secs(15))
|
|
.text(""),
|
|
)
|
|
.into_response()
|
|
}
|
|
|
|
// ─── Sync ────────────────────────────────────────────────────────────────────
|
|
|
|
/// Sync output in Chat Completions format.
|
|
async fn chat_completions_sync(
|
|
state: Arc<AppState>,
|
|
completion_id: String,
|
|
model_name: String,
|
|
cascade_id: String,
|
|
timeout: u64,
|
|
) -> axum::response::Response {
|
|
let result = poll_for_response(&state, &cascade_id, timeout).await;
|
|
if let Some(ref err) = result.upstream_error {
|
|
return upstream_err_response(err);
|
|
}
|
|
|
|
// Check MITM store first for real intercepted usage (fallback to _latest)
|
|
let mitm = match state.mitm_store.take_usage(&cascade_id).await {
|
|
Some(u) => Some(u),
|
|
None => state.mitm_store.take_usage("_latest").await,
|
|
};
|
|
|
|
let finish_reason =
|
|
google_to_openai_finish_reason(mitm.as_ref().and_then(|u| u.stop_reason.as_deref()));
|
|
|
|
let (prompt_tokens, completion_tokens, cached_tokens, thinking_tokens) =
|
|
if let Some(ref mitm_usage) = mitm {
|
|
(
|
|
mitm_usage.input_tokens,
|
|
mitm_usage.output_tokens,
|
|
mitm_usage.cache_read_input_tokens,
|
|
mitm_usage.thinking_output_tokens,
|
|
)
|
|
} else if let Some(u) = &result.usage {
|
|
(u.input_tokens, u.output_tokens, 0, 0)
|
|
} else {
|
|
(0, 0, 0, 0)
|
|
};
|
|
|
|
// Build message object, including reasoning_content if thinking is present
|
|
let mut message = serde_json::json!({
|
|
"role": "assistant",
|
|
"content": result.text,
|
|
});
|
|
if let Some(ref thinking) = result.thinking {
|
|
message["reasoning_content"] = serde_json::json!(thinking);
|
|
}
|
|
|
|
Json(serde_json::json!({
|
|
"id": completion_id,
|
|
"object": "chat.completion",
|
|
"created": now_unix(),
|
|
"model": model_name,
|
|
"system_fingerprint": system_fingerprint(),
|
|
"service_tier": "default",
|
|
"choices": [{
|
|
"index": 0,
|
|
"message": message,
|
|
"logprobs": serde_json::Value::Null,
|
|
"finish_reason": finish_reason,
|
|
}],
|
|
"usage": {
|
|
"prompt_tokens": prompt_tokens,
|
|
"completion_tokens": completion_tokens,
|
|
"total_tokens": prompt_tokens + completion_tokens,
|
|
"prompt_tokens_details": {
|
|
"cached_tokens": cached_tokens,
|
|
},
|
|
"completion_tokens_details": {
|
|
"reasoning_tokens": thinking_tokens,
|
|
},
|
|
},
|
|
}))
|
|
.into_response()
|
|
}
|