diff --git a/pyproject.toml b/pyproject.toml index 1d3ba42..9325748 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,9 +19,18 @@ dependencies = [ "typing-extensions>=4.14.0", ] +[dependency-groups] +dev = [ + "pytest>=8.0.0", + "pytest-asyncio>=0.23.0", +] + [build-system] requires = ["hatchling"] build-backend = "hatchling.build" [tool.hatch.metadata] allow-direct-references = true + +[tool.hatch.build.targets.wheel] +only-include = ["server.py"] diff --git a/server.py b/server.py index 716ace1..c363e85 100644 --- a/server.py +++ b/server.py @@ -1,3 +1,4 @@ +import asyncio import dataclasses import io import json @@ -6,7 +7,7 @@ import os import sys import time from contextlib import contextmanager -from typing import Optional, Any, Tuple +from typing import Optional, Any, Tuple, Awaitable, Callable, TypeVar from fastmcp import FastMCP from starlette.applications import Starlette @@ -239,6 +240,86 @@ class LoginManager: login_manager = LoginManager() mcp = FastMCP("SchwabScraper") +T = TypeVar("T") + +_auth_gate_lock = asyncio.Lock() +_auth_active_task: asyncio.Task[Any] | None = None +_auth_active_operation: str | None = None +_auth_started_at: float | None = None +_auth_waiters = 0 + + +def _auth_gate_status() -> dict: + return { + "in_progress": _auth_active_task is not None and not _auth_active_task.done(), + "operation": _auth_active_operation, + "started_at": _auth_started_at, + "elapsed_seconds": int(time.time() - _auth_started_at) if _auth_started_at else None, + "waiters": _auth_waiters, + } + + +async def _decrement_auth_waiters() -> None: + global _auth_waiters + async with _auth_gate_lock: + _auth_waiters = max(0, _auth_waiters - 1) + + +async def _clear_auth_task(task: asyncio.Task[Any]) -> None: + global _auth_active_task, _auth_active_operation, _auth_started_at + async with _auth_gate_lock: + if _auth_active_task is task: + _auth_active_task = None + _auth_active_operation = None + _auth_started_at = None + + +def _schedule_auth_task_clear(task: asyncio.Task[Any]) -> None: + try: + asyncio.create_task(_clear_auth_task(task)) + except RuntimeError: + pass + + +async def _run_auth_serialized( + operation: str, + coro_factory: Callable[[], Awaitable[T]], + *, + share_same_operation: bool = False, +) -> T: + """Run auth-sensitive work without letting request cancellation cancel it.""" + global _auth_active_task, _auth_active_operation, _auth_started_at, _auth_waiters + + while True: + async with _auth_gate_lock: + active_task = _auth_active_task + active_operation = _auth_active_operation + + if active_task is None or active_task.done(): + task = asyncio.create_task(coro_factory()) + _auth_active_task = task + _auth_active_operation = operation + _auth_started_at = time.time() + task.add_done_callback(_schedule_auth_task_clear) + break + + _auth_waiters += 1 + + try: + if share_same_operation and active_operation == operation: + return await asyncio.shield(active_task) + + try: + await asyncio.shield(active_task) + except Exception: + # The current operation should still get a chance to run after a + # prior auth-sensitive task fails. + pass + finally: + await _decrement_auth_waiters() + + return await asyncio.shield(task) + def _json_default(obj: Any) -> Any: """JSON fallback handler that converts dataclasses to dicts before str().""" @@ -276,6 +357,7 @@ async def get_session_status(debug: bool = False) -> str: if result.get("success"): data = result.get("data", {}) data["login_safety"] = login_manager.get_status() + data["auth_gate"] = _auth_gate_status() return serialize(result) @@ -285,7 +367,9 @@ async def get_login_safety_status() -> str: Useful to check if a login attempt is likely to be blocked. """ - return json.dumps(login_manager.get_status()) + status = login_manager.get_status() + status["auth_gate"] = _auth_gate_status() + return json.dumps(status) @mcp.tool() @@ -320,20 +404,18 @@ async def login( config_exists = os.path.exists(config_path) mcp_logger.info(f"Config fallback: path={config_path}, exists={config_exists}") - with capture_logs(level=logging.DEBUG if debug else logging.INFO) as log_buf: + async def _login_impl() -> dict: mcp_logger.info("capture_logs context entered") - if debug: - mcp_logger.info("DEBUG MODE ENABLED — verbose logging active") - - # api.login does not exist in unified_api; call the underlying scraper directly from schwab_scraper.browser.auth import login_to_schwab from schwab_scraper.core.config import get_schwab_credentials, load_config - if not username or not password: + resolved_username = username + resolved_password = password + if not resolved_username or not resolved_password: config = load_config() - username, password = get_schwab_credentials(config) + resolved_username, resolved_password = get_schwab_credentials(config) - if not username or not password: + if not resolved_username or not resolved_password: result = { "success": False, "error": "Username and password are required (or set in config.json)", @@ -341,36 +423,50 @@ async def login( "retryable": False, "data": None, } - else: - try: - cookies = await login_to_schwab(username, password) - if cookies: - result = { - "success": True, - "data": {"cookies_count": len(cookies)}, - "error": None, - "error_type": None, - "retryable": False, - } - else: - result = { - "success": False, - "error": "Login failed — no cookies returned. Check credentials or 2FA status.", - "error_type": "AUTHENTICATION", - "retryable": True, - "data": None, - } - except Exception as exc: + login_manager.record_attempt(False) + return result + + try: + cookies = await login_to_schwab(resolved_username, resolved_password) + if cookies: + result = { + "success": True, + "data": {"cookies_count": len(cookies)}, + "error": None, + "error_type": None, + "retryable": False, + } + else: result = { "success": False, - "error": str(exc), - "error_type": "UNKNOWN", + "error": "Login failed — no cookies returned. Check credentials or 2FA status.", + "error_type": "AUTHENTICATION", "retryable": True, "data": None, } + except Exception as exc: + result = { + "success": False, + "error": str(exc), + "error_type": "UNKNOWN", + "retryable": True, + "data": None, + } + + login_manager.record_attempt(result.get("success", False)) + return result + + with capture_logs(level=logging.DEBUG if debug else logging.INFO) as log_buf: + if debug: + mcp_logger.info("DEBUG MODE ENABLED — verbose logging active") + + result = await _run_auth_serialized( + "login", + _login_impl, + share_same_operation=not username and not password, + ) success = result.get("success", False) - login_manager.record_attempt(success) mcp_logger.info(f"login completed — success={success}") result = _enrich_with_logs(result, log_buf, debug) mcp_logger.info("capture_logs context exited, returning result") @@ -385,7 +481,10 @@ async def refresh_session(debug: bool = False) -> str: debug: Enable debug logging """ with capture_logs(level=logging.DEBUG if debug else logging.INFO) as log_buf: - result = await api.refresh_session(debug=debug) + result = await _run_auth_serialized( + "refresh_session", + lambda: api.refresh_session(debug=debug), + ) result = _enrich_with_logs(result, log_buf, debug) return serialize(result) @@ -397,7 +496,10 @@ async def list_accounts(debug: bool = False) -> str: Args: debug: Enable debug logging """ - result = await api.list_accounts(debug=debug) + result = await _run_auth_serialized( + "list_accounts", + lambda: api.list_accounts(debug=debug), + ) return serialize(result) @@ -409,7 +511,10 @@ async def get_account_overview(account: Optional[str] = None, debug: bool = Fals account: Account summary or ID (optional) debug: Enable debug logging """ - result = await api.get_account_overview(account=account, debug=debug) + result = await _run_auth_serialized( + "get_account_overview", + lambda: api.get_account_overview(account=account, debug=debug), + ) return serialize(result) @@ -426,8 +531,11 @@ async def get_positions( include_non_equity: Whether to include non-equity positions debug: Enable debug logging """ - result = await api.get_positions( - account=account, include_non_equity=include_non_equity, debug=debug + result = await _run_auth_serialized( + "get_positions", + lambda: api.get_positions( + account=account, include_non_equity=include_non_equity, debug=debug + ), ) return serialize(result) @@ -449,12 +557,15 @@ async def get_transactions( time_period: Time period (e.g., '1D', '1M') (optional) debug: Enable debug logging """ - result = await api.get_transaction_history( - account=account, - start_date=start_date, - end_date=end_date, - time_period=time_period, - debug=debug, + result = await _run_auth_serialized( + "get_transactions", + lambda: api.get_transaction_history( + account=account, + start_date=start_date, + end_date=end_date, + time_period=time_period, + debug=debug, + ), ) return serialize(result) @@ -467,7 +578,10 @@ async def get_morningstar_data(ticker: str, debug: bool = False) -> str: ticker: Stock ticker symbol debug: Enable debug logging """ - result = await api.get_morningstar_data(ticker, debug=debug) + result = await _run_auth_serialized( + "get_morningstar_data", + lambda: api.get_morningstar_data(ticker, debug=debug), + ) # When the scraper used blob URLs (modern Schwab web components), report_url # is None even though the PDF was downloaded and parsed successfully. Point diff --git a/tests/test_auth_gate.py b/tests/test_auth_gate.py new file mode 100644 index 0000000..7eff098 --- /dev/null +++ b/tests/test_auth_gate.py @@ -0,0 +1,116 @@ +import asyncio + +import pytest +import pytest_asyncio + +import server + + +@pytest_asyncio.fixture(autouse=True) +async def reset_auth_gate(): + async with server._auth_gate_lock: + task = server._auth_active_task + server._auth_active_task = None + server._auth_active_operation = None + server._auth_started_at = None + server._auth_waiters = 0 + + if task and not task.done(): + task.cancel() + + yield + + async with server._auth_gate_lock: + task = server._auth_active_task + server._auth_active_task = None + server._auth_active_operation = None + server._auth_started_at = None + server._auth_waiters = 0 + + if task and not task.done(): + task.cancel() + + +@pytest.mark.asyncio +async def test_same_auth_operation_shares_active_task(): + calls = 0 + release = asyncio.Event() + + async def auth_work(): + nonlocal calls + calls += 1 + await release.wait() + return {"success": True} + + first = asyncio.create_task( + server._run_auth_serialized("login", auth_work, share_same_operation=True) + ) + await asyncio.sleep(0) + + second = asyncio.create_task( + server._run_auth_serialized("login", auth_work, share_same_operation=True) + ) + await asyncio.sleep(0) + + assert calls == 1 + release.set() + + assert await first == {"success": True} + assert await second == {"success": True} + + +@pytest.mark.asyncio +async def test_cancelled_waiter_does_not_cancel_auth_task(): + calls = 0 + release = asyncio.Event() + + async def auth_work(): + nonlocal calls + calls += 1 + await release.wait() + return "done" + + request = asyncio.create_task( + server._run_auth_serialized("login", auth_work, share_same_operation=True) + ) + await asyncio.sleep(0) + + request.cancel() + with pytest.raises(asyncio.CancelledError): + await request + + assert calls == 1 + assert server._auth_active_task is not None + assert not server._auth_active_task.cancelled() + + release.set() + assert await asyncio.shield(server._auth_active_task) == "done" + + +@pytest.mark.asyncio +async def test_different_operation_waits_then_runs_after_active_task(): + order = [] + release = asyncio.Event() + + async def login_work(): + order.append("login-start") + await release.wait() + order.append("login-end") + return "login" + + async def data_work(): + order.append("data") + return "data" + + login_task = asyncio.create_task(server._run_auth_serialized("login", login_work)) + await asyncio.sleep(0) + + data_task = asyncio.create_task(server._run_auth_serialized("list_accounts", data_work)) + await asyncio.sleep(0) + + assert order == ["login-start"] + release.set() + + assert await login_task == "login" + assert await data_task == "data" + assert order == ["login-start", "login-end", "data"] diff --git a/uv.lock b/uv.lock index 62c4162..7e06bb2 100644 --- a/uv.lock +++ b/uv.lock @@ -754,6 +754,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/fa/5e/f8e9a1d23b9c20a551a8a02ea3637b4642e22c2626e3a13a9a29cdea99eb/importlib_metadata-8.7.1-py3-none-any.whl", hash = "sha256:5a1f80bf1daa489495071efbb095d75a634cf28a8bc299581244063b53176151", size = 27865, upload-time = "2025-12-21T10:00:18.329Z" }, ] +[[package]] +name = "iniconfig" +version = "2.3.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/72/34/14ca021ce8e5dfedc35312d08ba8bf51fdd999c576889fc2c24cb97f4f10/iniconfig-2.3.0.tar.gz", hash = "sha256:c76315c77db068650d49c5b56314774a7804df16fee4402c1f19d6d15d8c4730", size = 20503, upload-time = "2025-10-18T21:55:43.219Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/cb/b1/3846dd7f199d53cb17f49cba7e651e9ce294d8497c8c150530ed11865bb8/iniconfig-2.3.0-py3-none-any.whl", hash = "sha256:f631c04d2c48c52b84d0d0549c99ff3859c98df65b3101406327ecc7d53fbf12", size = 7484, upload-time = "2025-10-18T21:55:41.639Z" }, +] + [[package]] name = "jaraco-classes" version = "3.4.0" @@ -1196,6 +1205,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/33/ff/99a6f4292a90504f2927d34032a4baf6adb498dc3f7cf0f3e0e22899e310/playwright-1.54.0-py3-none-win_arm64.whl", hash = "sha256:a975815971f7b8dca505c441a4c56de1aeb56a211290f8cc214eeef5524e8d75", size = 31239119, upload-time = "2025-07-22T13:58:27.56Z" }, ] +[[package]] +name = "pluggy" +version = "1.6.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/f9/e2/3e91f31a7d2b083fe6ef3fa267035b518369d9511ffab804f839851d2779/pluggy-1.6.0.tar.gz", hash = "sha256:7dcc130b76258d33b90f61b658791dede3486c3e6bfb003ee5c9bfb396dd22f3", size = 69412, upload-time = "2025-05-15T12:30:07.975Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/54/20/4d324d65cc6d9205fabedc306948156824eb9f0ee1633355a8f7ec5c66bf/pluggy-1.6.0-py3-none-any.whl", hash = "sha256:e920276dd6813095e9377c0bc5566d94c932c33b27a3e3945d8389c374dd4746", size = 20538, upload-time = "2025-05-15T12:30:06.134Z" }, +] + [[package]] name = "propcache" version = "0.4.1" @@ -1496,6 +1514,35 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/df/80/fc9d01d5ed37ba4c42ca2b55b4339ae6e200b456be3a1aaddf4a9fa99b8c/pyperclip-1.11.0-py3-none-any.whl", hash = "sha256:299403e9ff44581cb9ba2ffeed69c7aa96a008622ad0c46cb575ca75b5b84273", size = 11063, upload-time = "2025-09-26T14:40:36.069Z" }, ] +[[package]] +name = "pytest" +version = "9.0.3" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "colorama", marker = "sys_platform == 'win32'" }, + { name = "iniconfig" }, + { name = "packaging" }, + { name = "pluggy" }, + { name = "pygments" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/7d/0d/549bd94f1a0a402dc8cf64563a117c0f3765662e2e668477624baeec44d5/pytest-9.0.3.tar.gz", hash = "sha256:b86ada508af81d19edeb213c681b1d48246c1a91d304c6c81a427674c17eb91c", size = 1572165, upload-time = "2026-04-07T17:16:18.027Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/d4/24/a372aaf5c9b7208e7112038812994107bc65a84cd00e0354a88c2c77a617/pytest-9.0.3-py3-none-any.whl", hash = "sha256:2c5efc453d45394fdd706ade797c0a81091eccd1d6e4bccfcd476e2b8e0ab5d9", size = 375249, upload-time = "2026-04-07T17:16:16.13Z" }, +] + +[[package]] +name = "pytest-asyncio" +version = "1.4.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pytest" }, + { name = "typing-extensions", marker = "python_full_version < '3.13'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/43/7c/d36d04db312ecf4298932ef77e6e4a9e8ad017906e24e34f0b0c361a2473/pytest_asyncio-1.4.0.tar.gz", hash = "sha256:c6c0d2259945122819f171a32ecea2c349ead889ee28176caaf492143424be42", size = 58514, upload-time = "2026-05-26T09:56:04.083Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/03/e2/08a497ef684b88559c9cc5f4ad53a37e7b99e727094a86d6ea32536d5d3c/pytest_asyncio-1.4.0-py3-none-any.whl", hash = "sha256:933ca923a23075a87fb7070c0ec272a6848489824d887c85c812670932835aa1", size = 16930, upload-time = "2026-05-26T09:56:02.576Z" }, +] + [[package]] name = "python-dotenv" version = "1.2.2" @@ -1708,7 +1755,7 @@ wheels = [ [[package]] name = "schwab-mcp-custom" -version = "0.1.0" +version = "0.2.1" source = { editable = "." } dependencies = [ { name = "aiohttp" }, @@ -1725,13 +1772,19 @@ dependencies = [ { name = "uvicorn" }, ] +[package.dev-dependencies] +dev = [ + { name = "pytest" }, + { name = "pytest-asyncio" }, +] + [package.metadata] requires-dist = [ { name = "aiohttp", specifier = ">=3.9.0" }, { name = "fastapi", specifier = ">=0.136.1" }, { name = "fastmcp", specifier = ">=0.4.1" }, { name = "greenlet", specifier = ">=3.2.3" }, - { name = "mcp", specifier = ">=1.2.0" }, + { name = "mcp", specifier = ">=1.27.0" }, { name = "pdfplumber", specifier = ">=0.11.4" }, { name = "playwright", specifier = ">=1.54.0" }, { name = "pyee", specifier = ">=13.0.0" }, @@ -1741,6 +1794,12 @@ requires-dist = [ { name = "uvicorn", specifier = ">=0.32.0" }, ] +[package.metadata.requires-dev] +dev = [ + { name = "pytest", specifier = ">=8.0.0" }, + { name = "pytest-asyncio", specifier = ">=0.23.0" }, +] + [[package]] name = "schwab-scraper" version = "0.6.16"