Files
schwab-mcp-custom/tests/test_auth_gate.py
T
b3nw 4b2275fa0b
Build and Push Docker Image / build (push) Successful in 34s
fix: serialize auth-sensitive tool calls
2026-06-01 01:42:24 +00:00

117 lines
2.8 KiB
Python

import asyncio
import pytest
import pytest_asyncio
import server
async def _clear_auth_gate_state():
    """Reset the server's module-level auth-gate bookkeeping.

    Clears the active task/operation, start time and waiter count while
    holding ``server._auth_gate_lock``, and requests cancellation of the
    active auth task if it has not finished yet.

    Returns:
        The task that was active before the reset (``None`` if there was
        none), so the caller can wait for it to finish unwinding.
    """
    async with server._auth_gate_lock:
        task = server._auth_active_task
        server._auth_active_task = None
        server._auth_active_operation = None
        server._auth_started_at = None
        server._auth_waiters = 0
        if task is not None and not task.done():
            task.cancel()
    return task


@pytest_asyncio.fixture(autouse=True)
async def reset_auth_gate():
    """Give every test a clean auth gate and tear down leftovers afterwards.

    The gate state lives in ``server`` module globals shared by all tests in
    the process, so anything one test leaves behind (e.g. an auth task still
    blocked because an assertion failed before its event was set) would
    otherwise leak into the next test.
    """
    await _clear_auth_gate_state()
    yield
    leftover = await _clear_auth_gate_state()
    if leftover is not None:
        # Wait for the cancelled task to actually finish so it does not
        # outlive this test's event loop. Done outside the lock in case the
        # task's own cleanup needs to acquire it. return_exceptions=True
        # absorbs the CancelledError (or any other failure) it ends with.
        await asyncio.gather(leftover, return_exceptions=True)
@pytest.mark.asyncio
async def test_same_auth_operation_shares_active_task():
    """A second request for the same operation joins the in-flight auth task.

    ``auth_work`` must run exactly once, and both callers must receive its
    result.
    """
    calls = 0
    started = asyncio.Event()
    release = asyncio.Event()

    async def auth_work():
        nonlocal calls
        calls += 1
        started.set()
        await release.wait()
        return {"success": True}

    first = asyncio.create_task(
        server._run_auth_serialized("login", auth_work, share_same_operation=True)
    )
    # Wait until the first request's auth work is genuinely in flight instead
    # of assuming it starts within a single event-loop tick.
    await started.wait()

    second = asyncio.create_task(
        server._run_auth_serialized("login", auth_work, share_same_operation=True)
    )
    # Let the second request run far enough to join the active task (or, if
    # sharing were broken, to start or queue its own auth work).
    for _ in range(5):
        await asyncio.sleep(0)
    assert calls == 1

    release.set()
    assert await first == {"success": True}
    assert await second == {"success": True}
    # Checking only before release is not enough: a serialized-but-unshared
    # second call would sit in the gate and run auth_work again once the
    # first finished, still returning {"success": True}.
    assert calls == 1
@pytest.mark.asyncio
async def test_cancelled_waiter_does_not_cancel_auth_task():
    """Cancelling a caller must not cancel the auth task it is waiting on."""
    calls = 0
    started = asyncio.Event()
    release = asyncio.Event()

    async def auth_work():
        nonlocal calls
        calls += 1
        started.set()
        await release.wait()
        return "done"

    request = asyncio.create_task(
        server._run_auth_serialized("login", auth_work, share_same_operation=True)
    )
    # Cancel only once the auth work is actually running; after a bare
    # sleep(0) the waiter could be cancelled before auth_work ever started.
    await started.wait()
    request.cancel()
    with pytest.raises(asyncio.CancelledError):
        await request

    auth_task = server._auth_active_task
    assert calls == 1
    assert auth_task is not None
    # The task is still blocked on ``release``. Task.cancelled() is always
    # False while a task is pending, so also check it has not finished.
    assert not auth_task.done()
    assert not auth_task.cancelled()

    release.set()
    # Had a cancellation been delivered to the task, awaiting it would raise
    # CancelledError instead of returning the result.
    assert await asyncio.shield(auth_task) == "done"
    assert calls == 1
@pytest.mark.asyncio
async def test_different_operation_waits_then_runs_after_active_task():
    """A different operation queues behind the active auth task, then runs."""
    order = []
    login_started = asyncio.Event()
    release = asyncio.Event()

    async def login_work():
        order.append("login-start")
        login_started.set()
        await release.wait()
        order.append("login-end")
        return "login"

    async def data_work():
        order.append("data")
        return "data"

    login_task = asyncio.create_task(server._run_auth_serialized("login", login_work))
    # Make sure the login work is in flight before issuing the second call.
    await login_started.wait()

    data_task = asyncio.create_task(server._run_auth_serialized("list_accounts", data_work))
    # A single tick is too short to notice a missing gate (data_work would
    # only be scheduled, not yet run); give it several.
    for _ in range(5):
        await asyncio.sleep(0)
    assert order == ["login-start"]
    assert not data_task.done()

    release.set()
    assert await login_task == "login"
    assert await data_task == "data"
    assert order == ["login-start", "login-end", "data"]