fix: serialize auth-sensitive tool calls
Build and Push Docker Image / build (push) Successful in 34s

This commit is contained in:
2026-06-01 01:42:24 +00:00
parent b06fc47d29
commit 4b2275fa0b
4 changed files with 345 additions and 47 deletions
+9
View File
@@ -19,9 +19,18 @@ dependencies = [
"typing-extensions>=4.14.0",
]
[dependency-groups]
dev = [
"pytest>=8.0.0",
"pytest-asyncio>=0.23.0",
]
[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"
[tool.hatch.metadata]
allow-direct-references = true
[tool.hatch.build.targets.wheel]
only-include = ["server.py"]
+133 -19
View File
@@ -1,3 +1,4 @@
import asyncio
import dataclasses
import io
import json
@@ -6,7 +7,7 @@ import os
import sys
import time
from contextlib import contextmanager
from typing import Optional, Any, Tuple
from typing import Optional, Any, Tuple, Awaitable, Callable, TypeVar
from fastmcp import FastMCP
from starlette.applications import Starlette
@@ -239,6 +240,86 @@ class LoginManager:
login_manager = LoginManager()
mcp = FastMCP("SchwabScraper")
T = TypeVar("T")
_auth_gate_lock = asyncio.Lock()
_auth_active_task: asyncio.Task[Any] | None = None
_auth_active_operation: str | None = None
_auth_started_at: float | None = None
_auth_waiters = 0
def _auth_gate_status() -> dict:
return {
"in_progress": _auth_active_task is not None and not _auth_active_task.done(),
"operation": _auth_active_operation,
"started_at": _auth_started_at,
"elapsed_seconds": int(time.time() - _auth_started_at) if _auth_started_at else None,
"waiters": _auth_waiters,
}
async def _decrement_auth_waiters() -> None:
global _auth_waiters
async with _auth_gate_lock:
_auth_waiters = max(0, _auth_waiters - 1)
async def _clear_auth_task(task: asyncio.Task[Any]) -> None:
global _auth_active_task, _auth_active_operation, _auth_started_at
async with _auth_gate_lock:
if _auth_active_task is task:
_auth_active_task = None
_auth_active_operation = None
_auth_started_at = None
def _schedule_auth_task_clear(task: asyncio.Task[Any]) -> None:
try:
asyncio.create_task(_clear_auth_task(task))
except RuntimeError:
pass
async def _run_auth_serialized(
operation: str,
coro_factory: Callable[[], Awaitable[T]],
*,
share_same_operation: bool = False,
) -> T:
"""Run auth-sensitive work without letting request cancellation cancel it."""
global _auth_active_task, _auth_active_operation, _auth_started_at, _auth_waiters
while True:
async with _auth_gate_lock:
active_task = _auth_active_task
active_operation = _auth_active_operation
if active_task is None or active_task.done():
task = asyncio.create_task(coro_factory())
_auth_active_task = task
_auth_active_operation = operation
_auth_started_at = time.time()
task.add_done_callback(_schedule_auth_task_clear)
break
_auth_waiters += 1
try:
if share_same_operation and active_operation == operation:
return await asyncio.shield(active_task)
try:
await asyncio.shield(active_task)
except Exception:
# The current operation should still get a chance to run after a
# prior auth-sensitive task fails.
pass
finally:
await _decrement_auth_waiters()
return await asyncio.shield(task)
def _json_default(obj: Any) -> Any:
"""JSON fallback handler that converts dataclasses to dicts before str()."""
@@ -276,6 +357,7 @@ async def get_session_status(debug: bool = False) -> str:
if result.get("success"):
data = result.get("data", {})
data["login_safety"] = login_manager.get_status()
data["auth_gate"] = _auth_gate_status()
return serialize(result)
@@ -285,7 +367,9 @@ async def get_login_safety_status() -> str:
Useful to check if a login attempt is likely to be blocked.
"""
return json.dumps(login_manager.get_status())
status = login_manager.get_status()
status["auth_gate"] = _auth_gate_status()
return json.dumps(status)
@mcp.tool()
@@ -320,20 +404,18 @@ async def login(
config_exists = os.path.exists(config_path)
mcp_logger.info(f"Config fallback: path={config_path}, exists={config_exists}")
with capture_logs(level=logging.DEBUG if debug else logging.INFO) as log_buf:
async def _login_impl() -> dict:
mcp_logger.info("capture_logs context entered")
if debug:
mcp_logger.info("DEBUG MODE ENABLED — verbose logging active")
# api.login does not exist in unified_api; call the underlying scraper directly
from schwab_scraper.browser.auth import login_to_schwab
from schwab_scraper.core.config import get_schwab_credentials, load_config
if not username or not password:
resolved_username = username
resolved_password = password
if not resolved_username or not resolved_password:
config = load_config()
username, password = get_schwab_credentials(config)
resolved_username, resolved_password = get_schwab_credentials(config)
if not username or not password:
if not resolved_username or not resolved_password:
result = {
"success": False,
"error": "Username and password are required (or set in config.json)",
@@ -341,9 +423,11 @@ async def login(
"retryable": False,
"data": None,
}
else:
login_manager.record_attempt(False)
return result
try:
cookies = await login_to_schwab(username, password)
cookies = await login_to_schwab(resolved_username, resolved_password)
if cookies:
result = {
"success": True,
@@ -369,8 +453,20 @@ async def login(
"data": None,
}
login_manager.record_attempt(result.get("success", False))
return result
with capture_logs(level=logging.DEBUG if debug else logging.INFO) as log_buf:
if debug:
mcp_logger.info("DEBUG MODE ENABLED — verbose logging active")
result = await _run_auth_serialized(
"login",
_login_impl,
share_same_operation=not username and not password,
)
success = result.get("success", False)
login_manager.record_attempt(success)
mcp_logger.info(f"login completed — success={success}")
result = _enrich_with_logs(result, log_buf, debug)
mcp_logger.info("capture_logs context exited, returning result")
@@ -385,7 +481,10 @@ async def refresh_session(debug: bool = False) -> str:
debug: Enable debug logging
"""
with capture_logs(level=logging.DEBUG if debug else logging.INFO) as log_buf:
result = await api.refresh_session(debug=debug)
result = await _run_auth_serialized(
"refresh_session",
lambda: api.refresh_session(debug=debug),
)
result = _enrich_with_logs(result, log_buf, debug)
return serialize(result)
@@ -397,7 +496,10 @@ async def list_accounts(debug: bool = False) -> str:
Args:
debug: Enable debug logging
"""
result = await api.list_accounts(debug=debug)
result = await _run_auth_serialized(
"list_accounts",
lambda: api.list_accounts(debug=debug),
)
return serialize(result)
@@ -409,7 +511,10 @@ async def get_account_overview(account: Optional[str] = None, debug: bool = Fals
account: Account summary or ID (optional)
debug: Enable debug logging
"""
result = await api.get_account_overview(account=account, debug=debug)
result = await _run_auth_serialized(
"get_account_overview",
lambda: api.get_account_overview(account=account, debug=debug),
)
return serialize(result)
@@ -426,8 +531,11 @@ async def get_positions(
include_non_equity: Whether to include non-equity positions
debug: Enable debug logging
"""
result = await api.get_positions(
result = await _run_auth_serialized(
"get_positions",
lambda: api.get_positions(
account=account, include_non_equity=include_non_equity, debug=debug
),
)
return serialize(result)
@@ -449,12 +557,15 @@ async def get_transactions(
time_period: Time period (e.g., '1D', '1M') (optional)
debug: Enable debug logging
"""
result = await api.get_transaction_history(
result = await _run_auth_serialized(
"get_transactions",
lambda: api.get_transaction_history(
account=account,
start_date=start_date,
end_date=end_date,
time_period=time_period,
debug=debug,
),
)
return serialize(result)
@@ -467,7 +578,10 @@ async def get_morningstar_data(ticker: str, debug: bool = False) -> str:
ticker: Stock ticker symbol
debug: Enable debug logging
"""
result = await api.get_morningstar_data(ticker, debug=debug)
result = await _run_auth_serialized(
"get_morningstar_data",
lambda: api.get_morningstar_data(ticker, debug=debug),
)
# When the scraper used blob URLs (modern Schwab web components), report_url
# is None even though the PDF was downloaded and parsed successfully. Point
+116
View File
@@ -0,0 +1,116 @@
import asyncio
import pytest
import pytest_asyncio
import server
@pytest_asyncio.fixture(autouse=True)
async def reset_auth_gate():
async with server._auth_gate_lock:
task = server._auth_active_task
server._auth_active_task = None
server._auth_active_operation = None
server._auth_started_at = None
server._auth_waiters = 0
if task and not task.done():
task.cancel()
yield
async with server._auth_gate_lock:
task = server._auth_active_task
server._auth_active_task = None
server._auth_active_operation = None
server._auth_started_at = None
server._auth_waiters = 0
if task and not task.done():
task.cancel()
@pytest.mark.asyncio
async def test_same_auth_operation_shares_active_task():
calls = 0
release = asyncio.Event()
async def auth_work():
nonlocal calls
calls += 1
await release.wait()
return {"success": True}
first = asyncio.create_task(
server._run_auth_serialized("login", auth_work, share_same_operation=True)
)
await asyncio.sleep(0)
second = asyncio.create_task(
server._run_auth_serialized("login", auth_work, share_same_operation=True)
)
await asyncio.sleep(0)
assert calls == 1
release.set()
assert await first == {"success": True}
assert await second == {"success": True}
@pytest.mark.asyncio
async def test_cancelled_waiter_does_not_cancel_auth_task():
calls = 0
release = asyncio.Event()
async def auth_work():
nonlocal calls
calls += 1
await release.wait()
return "done"
request = asyncio.create_task(
server._run_auth_serialized("login", auth_work, share_same_operation=True)
)
await asyncio.sleep(0)
request.cancel()
with pytest.raises(asyncio.CancelledError):
await request
assert calls == 1
assert server._auth_active_task is not None
assert not server._auth_active_task.cancelled()
release.set()
assert await asyncio.shield(server._auth_active_task) == "done"
@pytest.mark.asyncio
async def test_different_operation_waits_then_runs_after_active_task():
order = []
release = asyncio.Event()
async def login_work():
order.append("login-start")
await release.wait()
order.append("login-end")
return "login"
async def data_work():
order.append("data")
return "data"
login_task = asyncio.create_task(server._run_auth_serialized("login", login_work))
await asyncio.sleep(0)
data_task = asyncio.create_task(server._run_auth_serialized("list_accounts", data_work))
await asyncio.sleep(0)
assert order == ["login-start"]
release.set()
assert await login_task == "login"
assert await data_task == "data"
assert order == ["login-start", "login-end", "data"]
Generated
+61 -2
View File
@@ -754,6 +754,15 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/fa/5e/f8e9a1d23b9c20a551a8a02ea3637b4642e22c2626e3a13a9a29cdea99eb/importlib_metadata-8.7.1-py3-none-any.whl", hash = "sha256:5a1f80bf1daa489495071efbb095d75a634cf28a8bc299581244063b53176151", size = 27865, upload-time = "2025-12-21T10:00:18.329Z" },
]
[[package]]
name = "iniconfig"
version = "2.3.0"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/72/34/14ca021ce8e5dfedc35312d08ba8bf51fdd999c576889fc2c24cb97f4f10/iniconfig-2.3.0.tar.gz", hash = "sha256:c76315c77db068650d49c5b56314774a7804df16fee4402c1f19d6d15d8c4730", size = 20503, upload-time = "2025-10-18T21:55:43.219Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/cb/b1/3846dd7f199d53cb17f49cba7e651e9ce294d8497c8c150530ed11865bb8/iniconfig-2.3.0-py3-none-any.whl", hash = "sha256:f631c04d2c48c52b84d0d0549c99ff3859c98df65b3101406327ecc7d53fbf12", size = 7484, upload-time = "2025-10-18T21:55:41.639Z" },
]
[[package]]
name = "jaraco-classes"
version = "3.4.0"
@@ -1196,6 +1205,15 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/33/ff/99a6f4292a90504f2927d34032a4baf6adb498dc3f7cf0f3e0e22899e310/playwright-1.54.0-py3-none-win_arm64.whl", hash = "sha256:a975815971f7b8dca505c441a4c56de1aeb56a211290f8cc214eeef5524e8d75", size = 31239119, upload-time = "2025-07-22T13:58:27.56Z" },
]
[[package]]
name = "pluggy"
version = "1.6.0"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/f9/e2/3e91f31a7d2b083fe6ef3fa267035b518369d9511ffab804f839851d2779/pluggy-1.6.0.tar.gz", hash = "sha256:7dcc130b76258d33b90f61b658791dede3486c3e6bfb003ee5c9bfb396dd22f3", size = 69412, upload-time = "2025-05-15T12:30:07.975Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/54/20/4d324d65cc6d9205fabedc306948156824eb9f0ee1633355a8f7ec5c66bf/pluggy-1.6.0-py3-none-any.whl", hash = "sha256:e920276dd6813095e9377c0bc5566d94c932c33b27a3e3945d8389c374dd4746", size = 20538, upload-time = "2025-05-15T12:30:06.134Z" },
]
[[package]]
name = "propcache"
version = "0.4.1"
@@ -1496,6 +1514,35 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/df/80/fc9d01d5ed37ba4c42ca2b55b4339ae6e200b456be3a1aaddf4a9fa99b8c/pyperclip-1.11.0-py3-none-any.whl", hash = "sha256:299403e9ff44581cb9ba2ffeed69c7aa96a008622ad0c46cb575ca75b5b84273", size = 11063, upload-time = "2025-09-26T14:40:36.069Z" },
]
[[package]]
name = "pytest"
version = "9.0.3"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "colorama", marker = "sys_platform == 'win32'" },
{ name = "iniconfig" },
{ name = "packaging" },
{ name = "pluggy" },
{ name = "pygments" },
]
sdist = { url = "https://files.pythonhosted.org/packages/7d/0d/549bd94f1a0a402dc8cf64563a117c0f3765662e2e668477624baeec44d5/pytest-9.0.3.tar.gz", hash = "sha256:b86ada508af81d19edeb213c681b1d48246c1a91d304c6c81a427674c17eb91c", size = 1572165, upload-time = "2026-04-07T17:16:18.027Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/d4/24/a372aaf5c9b7208e7112038812994107bc65a84cd00e0354a88c2c77a617/pytest-9.0.3-py3-none-any.whl", hash = "sha256:2c5efc453d45394fdd706ade797c0a81091eccd1d6e4bccfcd476e2b8e0ab5d9", size = 375249, upload-time = "2026-04-07T17:16:16.13Z" },
]
[[package]]
name = "pytest-asyncio"
version = "1.4.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "pytest" },
{ name = "typing-extensions", marker = "python_full_version < '3.13'" },
]
sdist = { url = "https://files.pythonhosted.org/packages/43/7c/d36d04db312ecf4298932ef77e6e4a9e8ad017906e24e34f0b0c361a2473/pytest_asyncio-1.4.0.tar.gz", hash = "sha256:c6c0d2259945122819f171a32ecea2c349ead889ee28176caaf492143424be42", size = 58514, upload-time = "2026-05-26T09:56:04.083Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/03/e2/08a497ef684b88559c9cc5f4ad53a37e7b99e727094a86d6ea32536d5d3c/pytest_asyncio-1.4.0-py3-none-any.whl", hash = "sha256:933ca923a23075a87fb7070c0ec272a6848489824d887c85c812670932835aa1", size = 16930, upload-time = "2026-05-26T09:56:02.576Z" },
]
[[package]]
name = "python-dotenv"
version = "1.2.2"
@@ -1708,7 +1755,7 @@ wheels = [
[[package]]
name = "schwab-mcp-custom"
version = "0.1.0"
version = "0.2.1"
source = { editable = "." }
dependencies = [
{ name = "aiohttp" },
@@ -1725,13 +1772,19 @@ dependencies = [
{ name = "uvicorn" },
]
[package.dev-dependencies]
dev = [
{ name = "pytest" },
{ name = "pytest-asyncio" },
]
[package.metadata]
requires-dist = [
{ name = "aiohttp", specifier = ">=3.9.0" },
{ name = "fastapi", specifier = ">=0.136.1" },
{ name = "fastmcp", specifier = ">=0.4.1" },
{ name = "greenlet", specifier = ">=3.2.3" },
{ name = "mcp", specifier = ">=1.2.0" },
{ name = "mcp", specifier = ">=1.27.0" },
{ name = "pdfplumber", specifier = ">=0.11.4" },
{ name = "playwright", specifier = ">=1.54.0" },
{ name = "pyee", specifier = ">=13.0.0" },
@@ -1741,6 +1794,12 @@ requires-dist = [
{ name = "uvicorn", specifier = ">=0.32.0" },
]
[package.metadata.requires-dev]
dev = [
{ name = "pytest", specifier = ">=8.0.0" },
{ name = "pytest-asyncio", specifier = ">=0.23.0" },
]
[[package]]
name = "schwab-scraper"
version = "0.6.16"