117 lines
2.8 KiB
Python
117 lines
2.8 KiB
Python
import asyncio
|
|
|
|
import pytest
|
|
import pytest_asyncio
|
|
|
|
import server
|
|
|
|
|
|
async def _drain_auth_gate():
    """Stop any leftover auth task and clear the auth-gate globals in ``server``.

    The active task (if any) is read under ``server._auth_gate_lock``. It is
    then cancelled and awaited *outside* the lock, so any cleanup it runs
    that itself acquires the lock cannot deadlock against this helper. The
    globals are cleared last, overwriting anything that cleanup may have
    written to them.
    """
    async with server._auth_gate_lock:
        task = server._auth_active_task

    if task is not None and not task.done():
        task.cancel()
        # Only a task bound to the running loop can be awaited here; a task
        # from another loop is just cancelled, as the original fixture did.
        if task.get_loop() is asyncio.get_running_loop():
            # return_exceptions=True turns the task's CancelledError (or any
            # other failure) into a result instead of failing the fixture.
            await asyncio.gather(task, return_exceptions=True)

    async with server._auth_gate_lock:
        server._auth_active_task = None
        server._auth_active_operation = None
        server._auth_started_at = None
        server._auth_waiters = 0


@pytest_asyncio.fixture(autouse=True)
async def reset_auth_gate():
    """Give every test a clean ``server`` auth gate, and clean up after it.

    The same reset runs before and after the test. It cancels any auth task
    still running, waits for it to finish when it belongs to the current
    loop, and clears the gate's module-level state.
    """
    await _drain_auth_gate()
    yield
    await _drain_auth_gate()
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_same_auth_operation_shares_active_task():
    """Concurrent same-operation calls with sharing enabled run the work once."""
    calls = 0
    started = asyncio.Event()
    release = asyncio.Event()

    async def auth_work():
        nonlocal calls
        calls += 1
        started.set()
        await release.wait()
        return {"success": True}

    first = asyncio.create_task(
        server._run_auth_serialized("login", auth_work, share_same_operation=True)
    )
    # Wait until the work is actually in flight instead of assuming how many
    # loop iterations _run_auth_serialized needs to start it; the timeout
    # turns a never-started task into a failure rather than a hang.
    await asyncio.wait_for(started.wait(), timeout=1)

    second = asyncio.create_task(
        server._run_auth_serialized("login", auth_work, share_same_operation=True)
    )
    await asyncio.sleep(0)

    assert calls == 1
    release.set()

    assert await first == {"success": True}
    assert await second == {"success": True}
    # Re-checked after completion: a serializing (non-sharing) implementation
    # also shows calls == 1 before release, but would then run the work a
    # second time for the second caller.
    assert calls == 1
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_cancelled_waiter_does_not_cancel_auth_task():
    """Cancelling the caller that started the auth work leaves the work running."""
    invocations = []
    release = asyncio.Event()

    async def auth_work():
        invocations.append(None)
        await release.wait()
        return "done"

    waiter = asyncio.create_task(
        server._run_auth_serialized("login", auth_work, share_same_operation=True)
    )
    await asyncio.sleep(0)

    # Cancel only the calling task; the auth task it started must survive.
    waiter.cancel()
    with pytest.raises(asyncio.CancelledError):
        await waiter

    assert len(invocations) == 1
    active = server._auth_active_task
    assert active is not None
    assert not active.cancelled()

    release.set()
    # shield() so that, should this test itself be cancelled, the auth task is not.
    assert await asyncio.shield(active) == "done"
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_different_operation_waits_then_runs_after_active_task():
    """A call for a different operation is held back until the active one finishes."""
    events = []
    gate = asyncio.Event()

    async def login_work():
        events.append("login-start")
        await gate.wait()
        events.append("login-end")
        return "login"

    async def data_work():
        events.append("data")
        return "data"

    login_task = asyncio.create_task(
        server._run_auth_serialized("login", login_work)
    )
    await asyncio.sleep(0)

    data_task = asyncio.create_task(
        server._run_auth_serialized("list_accounts", data_work)
    )
    await asyncio.sleep(0)

    # Login work has begun; the data work must not have run yet.
    assert events == ["login-start"]
    gate.set()

    login_result = await login_task
    assert login_result == "login"
    data_result = await data_task
    assert data_result == "data"
    assert events == ["login-start", "login-end", "data"]
|