fix: serialize auth-sensitive tool calls
Build and Push Docker Image / build (push) Successful in 34s
Build and Push Docker Image / build (push) Successful in 34s
This commit is contained in:
@@ -1,3 +1,4 @@
|
||||
import asyncio
|
||||
import dataclasses
|
||||
import io
|
||||
import json
|
||||
@@ -6,7 +7,7 @@ import os
|
||||
import sys
|
||||
import time
|
||||
from contextlib import contextmanager
|
||||
from typing import Optional, Any, Tuple
|
||||
from typing import Optional, Any, Tuple, Awaitable, Callable, TypeVar
|
||||
|
||||
from fastmcp import FastMCP
|
||||
from starlette.applications import Starlette
|
||||
@@ -239,6 +240,86 @@ class LoginManager:
|
||||
login_manager = LoginManager()
|
||||
mcp = FastMCP("SchwabScraper")
|
||||
|
||||
# Result type of the awaitable produced by the factory passed to
# _run_auth_serialized; that function returns the same type.
T = TypeVar("T")
|
||||
|
||||
# Held while the auth-gate bookkeeping below is mutated.
_auth_gate_lock = asyncio.Lock()
|
||||
# Task created by _run_auth_serialized for the current auth-sensitive
# operation; reset to None by _clear_auth_task once that task is done.
_auth_active_task: asyncio.Task[Any] | None = None
|
||||
# Operation name that was passed in when _auth_active_task was created.
_auth_active_operation: str | None = None
|
||||
# time.time() value captured when _auth_active_task was created.
_auth_started_at: float | None = None
|
||||
# Number of callers currently waiting on a task started by another caller.
_auth_waiters = 0
|
||||
|
||||
|
||||
def _auth_gate_status() -> dict:
    """Return a snapshot of the auth gate's bookkeeping.

    Returns:
        A dict with keys ``in_progress`` (an active task exists and is not
        done), ``operation``, ``started_at``, ``elapsed_seconds`` (whole
        seconds since ``started_at``, or None when no start time is
        recorded) and ``waiters``.
    """
    current = _auth_active_task
    began = _auth_started_at

    running = current is not None and not current.done()
    elapsed = int(time.time() - began) if began else None

    return {
        "in_progress": running,
        "operation": _auth_active_operation,
        "started_at": began,
        "elapsed_seconds": elapsed,
        "waiters": _auth_waiters,
    }
|
||||
|
||||
|
||||
async def _decrement_auth_waiters() -> None:
    """Take one waiter off the gate's count, clamping the count at zero."""
    global _auth_waiters

    async with _auth_gate_lock:
        remaining = _auth_waiters - 1
        # Clamp so an unmatched decrement can never drive the count negative.
        _auth_waiters = remaining if remaining > 0 else 0
|
||||
|
||||
|
||||
async def _clear_auth_task(task: asyncio.Task[Any]) -> None:
    """Reset the gate's bookkeeping if ``task`` is still the recorded one.

    Args:
        task: The task whose bookkeeping should be cleared. When a different
            task has since been recorded as active, nothing is changed.
    """
    global _auth_active_task, _auth_active_operation, _auth_started_at

    async with _auth_gate_lock:
        if _auth_active_task is not task:
            # A newer task owns the gate; leave its bookkeeping intact.
            return
        _auth_active_task = _auth_active_operation = _auth_started_at = None
|
||||
|
||||
|
||||
# Strong references to in-flight cleanup tasks. The event loop only keeps weak
# references to tasks, so an otherwise unreferenced task may be garbage
# collected before it has run.
_auth_cleanup_tasks: set[asyncio.Task[None]] = set()


def _schedule_auth_task_clear(task: asyncio.Task[Any]) -> None:
    """Done-callback that schedules ``_clear_auth_task`` for ``task``.

    Clearing has to acquire the gate lock, which needs ``await``; a done
    callback is synchronous, so the work is handed to a background task.

    Args:
        task: The finished auth task whose bookkeeping should be cleared.
    """
    cleanup = _clear_auth_task(task)
    try:
        background = asyncio.create_task(cleanup)
    except RuntimeError:
        # No running event loop: best-effort cleanup is skipped. Close the
        # coroutine so it does not emit a "never awaited" warning.
        cleanup.close()
        return

    _auth_cleanup_tasks.add(background)
    background.add_done_callback(_auth_cleanup_tasks.discard)
|
||||
|
||||
|
||||
async def _run_auth_serialized(
    operation: str,
    coro_factory: Callable[[], Awaitable[T]],
    *,
    share_same_operation: bool = False,
) -> T:
    """Run auth-sensitive work without letting request cancellation cancel it.

    Only one auth-sensitive task is recorded as active at a time. A caller
    that finds an unfinished task waits for it and then retries; the caller
    that finds the gate free starts its own task and awaits it through
    ``asyncio.shield`` so that cancelling the caller leaves the task running.

    Args:
        operation: Name recorded for the task while it is active.
        coro_factory: Zero-argument callable producing the coroutine to run.
            It is only called once this caller gets its turn.
        share_same_operation: When true and the active task was started for
            the same ``operation``, await that task's outcome instead of
            starting another one.

    Returns:
        The result of the task this caller started, or of the shared task.

    Raises:
        Whatever the awaited task raises. Failures of a prior task that this
        caller merely waited behind are not propagated.
    """
    global _auth_active_task, _auth_active_operation, _auth_started_at, _auth_waiters

    while True:
        async with _auth_gate_lock:
            active_task = _auth_active_task
            active_operation = _auth_active_operation

            if active_task is None or active_task.done():
                task = asyncio.create_task(coro_factory())
                _auth_active_task = task
                _auth_active_operation = operation
                _auth_started_at = time.time()
                task.add_done_callback(_schedule_auth_task_clear)
                break

            _auth_waiters += 1

        # Wait outside the lock so status reads and other callers are not
        # blocked for the duration of the active operation.
        try:
            if share_same_operation and active_operation == operation:
                return await asyncio.shield(active_task)

            # The current operation should still get a chance to run after a
            # prior auth-sensitive task fails or is cancelled. asyncio.wait
            # never raises the task's outcome, and cancelling this waiter
            # does not cancel the task being waited on.
            await asyncio.wait([active_task])
        finally:
            await _decrement_auth_waiters()

    return await asyncio.shield(task)
|
||||
|
||||
|
||||
def _json_default(obj: Any) -> Any:
|
||||
"""JSON fallback handler that converts dataclasses to dicts before str()."""
|
||||
@@ -276,6 +357,7 @@ async def get_session_status(debug: bool = False) -> str:
|
||||
if result.get("success"):
|
||||
data = result.get("data", {})
|
||||
data["login_safety"] = login_manager.get_status()
|
||||
data["auth_gate"] = _auth_gate_status()
|
||||
return serialize(result)
|
||||
|
||||
|
||||
@@ -285,7 +367,9 @@ async def get_login_safety_status() -> str:
|
||||
|
||||
Useful to check if a login attempt is likely to be blocked.
|
||||
"""
|
||||
return json.dumps(login_manager.get_status())
|
||||
status = login_manager.get_status()
|
||||
status["auth_gate"] = _auth_gate_status()
|
||||
return json.dumps(status)
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
@@ -320,20 +404,18 @@ async def login(
|
||||
config_exists = os.path.exists(config_path)
|
||||
mcp_logger.info(f"Config fallback: path={config_path}, exists={config_exists}")
|
||||
|
||||
with capture_logs(level=logging.DEBUG if debug else logging.INFO) as log_buf:
|
||||
async def _login_impl() -> dict:
|
||||
mcp_logger.info("capture_logs context entered")
|
||||
if debug:
|
||||
mcp_logger.info("DEBUG MODE ENABLED — verbose logging active")
|
||||
|
||||
# api.login does not exist in unified_api; call the underlying scraper directly
|
||||
from schwab_scraper.browser.auth import login_to_schwab
|
||||
from schwab_scraper.core.config import get_schwab_credentials, load_config
|
||||
|
||||
if not username or not password:
|
||||
resolved_username = username
|
||||
resolved_password = password
|
||||
if not resolved_username or not resolved_password:
|
||||
config = load_config()
|
||||
username, password = get_schwab_credentials(config)
|
||||
resolved_username, resolved_password = get_schwab_credentials(config)
|
||||
|
||||
if not username or not password:
|
||||
if not resolved_username or not resolved_password:
|
||||
result = {
|
||||
"success": False,
|
||||
"error": "Username and password are required (or set in config.json)",
|
||||
@@ -341,36 +423,50 @@ async def login(
|
||||
"retryable": False,
|
||||
"data": None,
|
||||
}
|
||||
else:
|
||||
try:
|
||||
cookies = await login_to_schwab(username, password)
|
||||
if cookies:
|
||||
result = {
|
||||
"success": True,
|
||||
"data": {"cookies_count": len(cookies)},
|
||||
"error": None,
|
||||
"error_type": None,
|
||||
"retryable": False,
|
||||
}
|
||||
else:
|
||||
result = {
|
||||
"success": False,
|
||||
"error": "Login failed — no cookies returned. Check credentials or 2FA status.",
|
||||
"error_type": "AUTHENTICATION",
|
||||
"retryable": True,
|
||||
"data": None,
|
||||
}
|
||||
except Exception as exc:
|
||||
login_manager.record_attempt(False)
|
||||
return result
|
||||
|
||||
try:
|
||||
cookies = await login_to_schwab(resolved_username, resolved_password)
|
||||
if cookies:
|
||||
result = {
|
||||
"success": True,
|
||||
"data": {"cookies_count": len(cookies)},
|
||||
"error": None,
|
||||
"error_type": None,
|
||||
"retryable": False,
|
||||
}
|
||||
else:
|
||||
result = {
|
||||
"success": False,
|
||||
"error": str(exc),
|
||||
"error_type": "UNKNOWN",
|
||||
"error": "Login failed — no cookies returned. Check credentials or 2FA status.",
|
||||
"error_type": "AUTHENTICATION",
|
||||
"retryable": True,
|
||||
"data": None,
|
||||
}
|
||||
except Exception as exc:
|
||||
result = {
|
||||
"success": False,
|
||||
"error": str(exc),
|
||||
"error_type": "UNKNOWN",
|
||||
"retryable": True,
|
||||
"data": None,
|
||||
}
|
||||
|
||||
login_manager.record_attempt(result.get("success", False))
|
||||
return result
|
||||
|
||||
with capture_logs(level=logging.DEBUG if debug else logging.INFO) as log_buf:
|
||||
if debug:
|
||||
mcp_logger.info("DEBUG MODE ENABLED — verbose logging active")
|
||||
|
||||
result = await _run_auth_serialized(
|
||||
"login",
|
||||
_login_impl,
|
||||
share_same_operation=not username and not password,
|
||||
)
|
||||
|
||||
success = result.get("success", False)
|
||||
login_manager.record_attempt(success)
|
||||
mcp_logger.info(f"login completed — success={success}")
|
||||
result = _enrich_with_logs(result, log_buf, debug)
|
||||
mcp_logger.info("capture_logs context exited, returning result")
|
||||
@@ -385,7 +481,10 @@ async def refresh_session(debug: bool = False) -> str:
|
||||
debug: Enable debug logging
|
||||
"""
|
||||
with capture_logs(level=logging.DEBUG if debug else logging.INFO) as log_buf:
|
||||
result = await api.refresh_session(debug=debug)
|
||||
result = await _run_auth_serialized(
|
||||
"refresh_session",
|
||||
lambda: api.refresh_session(debug=debug),
|
||||
)
|
||||
result = _enrich_with_logs(result, log_buf, debug)
|
||||
return serialize(result)
|
||||
|
||||
@@ -397,7 +496,10 @@ async def list_accounts(debug: bool = False) -> str:
|
||||
Args:
|
||||
debug: Enable debug logging
|
||||
"""
|
||||
result = await api.list_accounts(debug=debug)
|
||||
result = await _run_auth_serialized(
|
||||
"list_accounts",
|
||||
lambda: api.list_accounts(debug=debug),
|
||||
)
|
||||
return serialize(result)
|
||||
|
||||
|
||||
@@ -409,7 +511,10 @@ async def get_account_overview(account: Optional[str] = None, debug: bool = Fals
|
||||
account: Account summary or ID (optional)
|
||||
debug: Enable debug logging
|
||||
"""
|
||||
result = await api.get_account_overview(account=account, debug=debug)
|
||||
result = await _run_auth_serialized(
|
||||
"get_account_overview",
|
||||
lambda: api.get_account_overview(account=account, debug=debug),
|
||||
)
|
||||
return serialize(result)
|
||||
|
||||
|
||||
@@ -426,8 +531,11 @@ async def get_positions(
|
||||
include_non_equity: Whether to include non-equity positions
|
||||
debug: Enable debug logging
|
||||
"""
|
||||
result = await api.get_positions(
|
||||
account=account, include_non_equity=include_non_equity, debug=debug
|
||||
result = await _run_auth_serialized(
|
||||
"get_positions",
|
||||
lambda: api.get_positions(
|
||||
account=account, include_non_equity=include_non_equity, debug=debug
|
||||
),
|
||||
)
|
||||
return serialize(result)
|
||||
|
||||
@@ -449,12 +557,15 @@ async def get_transactions(
|
||||
time_period: Time period (e.g., '1D', '1M') (optional)
|
||||
debug: Enable debug logging
|
||||
"""
|
||||
result = await api.get_transaction_history(
|
||||
account=account,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
time_period=time_period,
|
||||
debug=debug,
|
||||
result = await _run_auth_serialized(
|
||||
"get_transactions",
|
||||
lambda: api.get_transaction_history(
|
||||
account=account,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
time_period=time_period,
|
||||
debug=debug,
|
||||
),
|
||||
)
|
||||
return serialize(result)
|
||||
|
||||
@@ -467,7 +578,10 @@ async def get_morningstar_data(ticker: str, debug: bool = False) -> str:
|
||||
ticker: Stock ticker symbol
|
||||
debug: Enable debug logging
|
||||
"""
|
||||
result = await api.get_morningstar_data(ticker, debug=debug)
|
||||
result = await _run_auth_serialized(
|
||||
"get_morningstar_data",
|
||||
lambda: api.get_morningstar_data(ticker, debug=debug),
|
||||
)
|
||||
|
||||
# When the scraper used blob URLs (modern Schwab web components), report_url
|
||||
# is None even though the PDF was downloaded and parsed successfully. Point
|
||||
|
||||
Reference in New Issue
Block a user