fix: serialize auth-sensitive tool calls
Build and Push Docker Image / build (push) Successful in 34s
Build and Push Docker Image / build (push) Successful in 34s
This commit is contained in:
@@ -0,0 +1,116 @@
|
||||
import asyncio
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
|
||||
import server
|
||||
|
||||
|
||||
@pytest_asyncio.fixture(autouse=True)
|
||||
async def reset_auth_gate():
|
||||
async with server._auth_gate_lock:
|
||||
task = server._auth_active_task
|
||||
server._auth_active_task = None
|
||||
server._auth_active_operation = None
|
||||
server._auth_started_at = None
|
||||
server._auth_waiters = 0
|
||||
|
||||
if task and not task.done():
|
||||
task.cancel()
|
||||
|
||||
yield
|
||||
|
||||
async with server._auth_gate_lock:
|
||||
task = server._auth_active_task
|
||||
server._auth_active_task = None
|
||||
server._auth_active_operation = None
|
||||
server._auth_started_at = None
|
||||
server._auth_waiters = 0
|
||||
|
||||
if task and not task.done():
|
||||
task.cancel()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_same_auth_operation_shares_active_task():
|
||||
calls = 0
|
||||
release = asyncio.Event()
|
||||
|
||||
async def auth_work():
|
||||
nonlocal calls
|
||||
calls += 1
|
||||
await release.wait()
|
||||
return {"success": True}
|
||||
|
||||
first = asyncio.create_task(
|
||||
server._run_auth_serialized("login", auth_work, share_same_operation=True)
|
||||
)
|
||||
await asyncio.sleep(0)
|
||||
|
||||
second = asyncio.create_task(
|
||||
server._run_auth_serialized("login", auth_work, share_same_operation=True)
|
||||
)
|
||||
await asyncio.sleep(0)
|
||||
|
||||
assert calls == 1
|
||||
release.set()
|
||||
|
||||
assert await first == {"success": True}
|
||||
assert await second == {"success": True}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cancelled_waiter_does_not_cancel_auth_task():
|
||||
calls = 0
|
||||
release = asyncio.Event()
|
||||
|
||||
async def auth_work():
|
||||
nonlocal calls
|
||||
calls += 1
|
||||
await release.wait()
|
||||
return "done"
|
||||
|
||||
request = asyncio.create_task(
|
||||
server._run_auth_serialized("login", auth_work, share_same_operation=True)
|
||||
)
|
||||
await asyncio.sleep(0)
|
||||
|
||||
request.cancel()
|
||||
with pytest.raises(asyncio.CancelledError):
|
||||
await request
|
||||
|
||||
assert calls == 1
|
||||
assert server._auth_active_task is not None
|
||||
assert not server._auth_active_task.cancelled()
|
||||
|
||||
release.set()
|
||||
assert await asyncio.shield(server._auth_active_task) == "done"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_different_operation_waits_then_runs_after_active_task():
|
||||
order = []
|
||||
release = asyncio.Event()
|
||||
|
||||
async def login_work():
|
||||
order.append("login-start")
|
||||
await release.wait()
|
||||
order.append("login-end")
|
||||
return "login"
|
||||
|
||||
async def data_work():
|
||||
order.append("data")
|
||||
return "data"
|
||||
|
||||
login_task = asyncio.create_task(server._run_auth_serialized("login", login_work))
|
||||
await asyncio.sleep(0)
|
||||
|
||||
data_task = asyncio.create_task(server._run_auth_serialized("list_accounts", data_work))
|
||||
await asyncio.sleep(0)
|
||||
|
||||
assert order == ["login-start"]
|
||||
release.set()
|
||||
|
||||
assert await login_task == "login"
|
||||
assert await data_task == "data"
|
||||
assert order == ["login-start", "login-end", "data"]
|
||||
Reference in New Issue
Block a user