feat: completions API improvements, gemini endpoint, response types
This commit is contained in:
@@ -142,12 +142,15 @@ fn build_response_object(data: ResponseData, params: &RequestParams) -> Response
|
||||
output: data.output,
|
||||
parallel_tool_calls: true,
|
||||
previous_response_id: params.previous_response_id.clone(),
|
||||
reasoning: Reasoning::default(),
|
||||
reasoning: Reasoning {
|
||||
effort: params.reasoning_effort.clone(),
|
||||
summary: None,
|
||||
},
|
||||
store: params.store,
|
||||
temperature: params.temperature,
|
||||
text: TextFormat::default(),
|
||||
tool_choice: "auto",
|
||||
tools: vec![],
|
||||
text: params.text_format.clone(),
|
||||
tool_choice: params.tool_choice.clone(),
|
||||
tools: params.tools.clone(),
|
||||
top_p: params.top_p,
|
||||
truncation: "disabled",
|
||||
usage: data.usage,
|
||||
@@ -230,6 +233,13 @@ pub(crate) async fn handle_responses(
|
||||
}
|
||||
|
||||
// Store client tools in MitmStore for MITM injection
|
||||
// Detect web_search_preview tool (OpenAI spec) → enable Google Search grounding
|
||||
let has_web_search = body.tools.as_ref().map_or(false, |tools| {
|
||||
tools.iter().any(|t| {
|
||||
let t_type = t["type"].as_str().unwrap_or("");
|
||||
t_type == "web_search_preview" || t_type == "web_search"
|
||||
})
|
||||
});
|
||||
if let Some(ref tools) = body.tools {
|
||||
let gemini_tools = openai_tools_to_gemini(tools);
|
||||
if !gemini_tools.is_empty() {
|
||||
@@ -243,6 +253,28 @@ pub(crate) async fn handle_responses(
|
||||
}
|
||||
|
||||
// Store generation parameters for MITM injection
|
||||
// Extract text.format for structured output (json_schema)
|
||||
let (response_mime_type, response_schema, text_format) = if let Some(ref text_val) = body.text {
|
||||
let fmt_type = text_val["format"]["type"].as_str().unwrap_or("text");
|
||||
if fmt_type == "json_schema" {
|
||||
let name = text_val["format"]["name"].as_str().map(|s| s.to_string());
|
||||
let schema = text_val["format"]["schema"].as_object().map(|o| serde_json::Value::Object(o.clone()));
|
||||
let strict = text_val["format"]["strict"].as_bool();
|
||||
let tf = TextFormat {
|
||||
format: TextFormatInner {
|
||||
format_type: "json_schema".to_string(),
|
||||
name: name.clone(),
|
||||
schema: schema.clone(),
|
||||
strict,
|
||||
},
|
||||
};
|
||||
(Some("application/json".to_string()), schema, tf)
|
||||
} else {
|
||||
(None, None, TextFormat::default())
|
||||
}
|
||||
} else {
|
||||
(None, None, TextFormat::default())
|
||||
};
|
||||
{
|
||||
use crate::mitm::store::GenerationParams;
|
||||
let gp = GenerationParams {
|
||||
@@ -253,8 +285,15 @@ pub(crate) async fn handle_responses(
|
||||
stop_sequences: None,
|
||||
frequency_penalty: None,
|
||||
presence_penalty: None,
|
||||
reasoning_effort: body.reasoning_effort.clone(),
|
||||
response_mime_type,
|
||||
response_schema,
|
||||
google_search: has_web_search,
|
||||
};
|
||||
if gp.temperature.is_some() || gp.top_p.is_some() || gp.max_output_tokens.is_some() {
|
||||
if gp.temperature.is_some() || gp.top_p.is_some() || gp.max_output_tokens.is_some()
|
||||
|| gp.reasoning_effort.is_some() || gp.response_mime_type.is_some()
|
||||
|| gp.response_schema.is_some() || gp.google_search
|
||||
{
|
||||
state.mitm_store.set_generation_params(gp).await;
|
||||
} else {
|
||||
state.mitm_store.clear_generation_params().await;
|
||||
@@ -337,6 +376,11 @@ pub(crate) async fn handle_responses(
|
||||
previous_response_id: body.previous_response_id.clone(),
|
||||
user: body.user.clone(),
|
||||
metadata: body.metadata.clone().unwrap_or(serde_json::json!({})),
|
||||
max_tool_calls: body.max_tool_calls,
|
||||
reasoning_effort: body.reasoning_effort.clone(),
|
||||
tool_choice: body.tool_choice.clone().unwrap_or(serde_json::json!("auto")),
|
||||
tools: body.tools.clone().unwrap_or_default(),
|
||||
text_format,
|
||||
};
|
||||
|
||||
if body.stream {
|
||||
@@ -365,6 +409,11 @@ struct RequestParams {
|
||||
previous_response_id: Option<String>,
|
||||
user: Option<String>,
|
||||
metadata: serde_json::Value,
|
||||
max_tool_calls: Option<u32>,
|
||||
reasoning_effort: Option<String>,
|
||||
tool_choice: serde_json::Value,
|
||||
tools: Vec<serde_json::Value>,
|
||||
text_format: TextFormat,
|
||||
}
|
||||
|
||||
/// Build Usage from the best available source, and extract thinking text from MITM:
|
||||
@@ -471,10 +520,15 @@ async fn handle_responses_sync(
|
||||
while start.elapsed().as_secs() < timeout {
|
||||
// Check for function calls
|
||||
let captured = state.mitm_store.take_any_function_calls().await;
|
||||
if let Some(ref calls) = captured {
|
||||
if let Some(ref raw_calls) = captured {
|
||||
let calls: Vec<_> = if let Some(max) = params.max_tool_calls {
|
||||
raw_calls.iter().take(max as usize).collect()
|
||||
} else {
|
||||
raw_calls.iter().collect()
|
||||
};
|
||||
if !calls.is_empty() {
|
||||
let mut output_items: Vec<serde_json::Value> = Vec::new();
|
||||
for fc in calls {
|
||||
for fc in &calls {
|
||||
let call_id = format!(
|
||||
"call_{}",
|
||||
uuid::Uuid::new_v4().to_string().replace('-', "")[..24].to_string()
|
||||
@@ -567,6 +621,14 @@ async fn handle_responses_sync(
|
||||
// Check for captured function calls from MITM (clears the active flag)
|
||||
let captured_tool_calls = state.mitm_store.take_any_function_calls().await;
|
||||
|
||||
// Enforce max_tool_calls limit
|
||||
let captured_tool_calls = captured_tool_calls.map(|mut calls| {
|
||||
if let Some(max) = params.max_tool_calls {
|
||||
calls.truncate(max as usize);
|
||||
}
|
||||
calls
|
||||
});
|
||||
|
||||
// If we have captured tool calls, return them as function_call output items
|
||||
if let Some(ref calls) = captured_tool_calls {
|
||||
info!(
|
||||
@@ -714,7 +776,12 @@ async fn handle_responses_stream(
|
||||
while start.elapsed().as_secs() < timeout {
|
||||
// Check for function calls first
|
||||
let captured = state.mitm_store.take_any_function_calls().await;
|
||||
if let Some(ref calls) = captured {
|
||||
if let Some(ref raw_calls) = captured {
|
||||
let calls: Vec<_> = if let Some(max) = params.max_tool_calls {
|
||||
raw_calls.iter().take(max as usize).collect()
|
||||
} else {
|
||||
raw_calls.iter().collect()
|
||||
};
|
||||
if !calls.is_empty() {
|
||||
let msg_output_index: u32 = if thinking_emitted { 1 } else { 0 };
|
||||
for (i, fc) in calls.iter().enumerate() {
|
||||
@@ -762,7 +829,7 @@ async fn handle_responses_stream(
|
||||
|
||||
// Build output for final response
|
||||
let mut output_items: Vec<serde_json::Value> = Vec::new();
|
||||
for fc in calls {
|
||||
for fc in &calls {
|
||||
let call_id = format!(
|
||||
"call_{}",
|
||||
uuid::Uuid::new_v4().to_string().replace('-', "")[..24].to_string()
|
||||
|
||||
Reference in New Issue
Block a user