feat: add tool call support to chat completions endpoint
- Accept tools and tool_choice fields in CompletionRequest
- Convert OpenAI tools to Gemini format and store in MitmStore
- Detect MITM-captured function calls in streaming poll loop
- Emit tool_calls delta chunks in OpenAI streaming format
- Finish with 'tool_calls' reason instead of 'stop' when tools are used
- Only clear tools when request has none (prevents stale state leak)
This commit is contained in:
@@ -78,10 +78,22 @@ pub(crate) async fn handle_completions(
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
// Clear any stale tool definitions from other endpoints (e.g. /v1/responses)
|
// Store client tools from this request (or clear stale ones from other endpoints)
|
||||||
// to prevent them leaking into completions requests. The completions endpoint
|
if let Some(ref tools) = body.tools {
|
||||||
// does not support our custom tool call flow, so tools must never be injected.
|
let gemini_tools = crate::mitm::modify::openai_tools_to_gemini(tools);
|
||||||
state.mitm_store.clear_tools().await;
|
if !gemini_tools.is_empty() {
|
||||||
|
state.mitm_store.set_tools(gemini_tools).await;
|
||||||
|
if let Some(ref choice) = body.tool_choice {
|
||||||
|
let gemini_config = crate::mitm::modify::openai_tool_choice_to_gemini(choice);
|
||||||
|
state.mitm_store.set_tool_config(gemini_config).await;
|
||||||
|
}
|
||||||
|
info!(count = tools.len(), "Completions: stored client tools for MITM injection");
|
||||||
|
} else {
|
||||||
|
state.mitm_store.clear_tools().await;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
state.mitm_store.clear_tools().await;
|
||||||
|
}
|
||||||
state.mitm_store.clear_active_function_call();
|
state.mitm_store.clear_active_function_call();
|
||||||
|
|
||||||
let token = state.backend.oauth_token().await;
|
let token = state.backend.oauth_token().await;
|
||||||
@@ -225,6 +237,57 @@ async fn chat_completions_stream(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Check for MITM-captured function calls (tool use)
|
||||||
|
let captured = state.mitm_store.take_any_function_calls().await;
|
||||||
|
if let Some(ref calls) = captured {
|
||||||
|
if !calls.is_empty() {
|
||||||
|
// Emit tool_calls in OpenAI streaming format
|
||||||
|
let mut tool_calls = Vec::new();
|
||||||
|
for (i, fc) in calls.iter().enumerate() {
|
||||||
|
let call_id = format!(
|
||||||
|
"call_{}",
|
||||||
|
uuid::Uuid::new_v4().to_string().replace('-', "")[..24].to_string()
|
||||||
|
);
|
||||||
|
let arguments = serde_json::to_string(&fc.args).unwrap_or_default();
|
||||||
|
tool_calls.push(serde_json::json!({
|
||||||
|
"index": i,
|
||||||
|
"id": call_id,
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": fc.name,
|
||||||
|
"arguments": arguments,
|
||||||
|
},
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
yield Ok(Event::default().data(serde_json::to_string(&serde_json::json!({
|
||||||
|
"id": completion_id,
|
||||||
|
"object": "chat.completion.chunk",
|
||||||
|
"created": now_unix(),
|
||||||
|
"model": model_name,
|
||||||
|
"choices": [{
|
||||||
|
"index": 0,
|
||||||
|
"delta": {"tool_calls": tool_calls},
|
||||||
|
"finish_reason": serde_json::Value::Null,
|
||||||
|
}],
|
||||||
|
})).unwrap_or_default()));
|
||||||
|
|
||||||
|
// Finish with tool_calls reason
|
||||||
|
yield Ok(Event::default().data(serde_json::to_string(&serde_json::json!({
|
||||||
|
"id": completion_id,
|
||||||
|
"object": "chat.completion.chunk",
|
||||||
|
"created": now_unix(),
|
||||||
|
"model": model_name,
|
||||||
|
"choices": [{
|
||||||
|
"index": 0,
|
||||||
|
"delta": {},
|
||||||
|
"finish_reason": "tool_calls",
|
||||||
|
}],
|
||||||
|
})).unwrap_or_default()));
|
||||||
|
yield Ok(Event::default().data("[DONE]"));
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Done check: need DONE status AND non-empty text
|
// Done check: need DONE status AND non-empty text
|
||||||
if is_response_done(steps) && !last_text.is_empty() {
|
if is_response_done(steps) && !last_text.is_empty() {
|
||||||
debug!("Completions stream done, text length={}", last_text.len());
|
debug!("Completions stream done, text length={}", last_text.len());
|
||||||
|
|||||||
@@ -49,6 +49,10 @@ pub(crate) struct CompletionRequest {
|
|||||||
pub stream: bool,
|
pub stream: bool,
|
||||||
#[serde(default = "default_timeout")]
|
#[serde(default = "default_timeout")]
|
||||||
pub timeout: u64,
|
pub timeout: u64,
|
||||||
|
/// OpenAI-format tool definitions
|
||||||
|
pub tools: Option<Vec<serde_json::Value>>,
|
||||||
|
/// Tool choice: "auto", "none", "required", or {"type":"function","function":{"name":"..."}}
|
||||||
|
pub tool_choice: Option<serde_json::Value>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Deserialize)]
|
#[derive(Deserialize)]
|
||||||
|
|||||||
Reference in New Issue
Block a user