diff --git a/server.py b/server.py index f6f0797..c790f52 100644 --- a/server.py +++ b/server.py @@ -1,8 +1,18 @@ import json import logging import os +import sys from typing import Optional, Any +# Ensure local vendor and package modules are in path +project_root = os.path.dirname(os.path.abspath(__file__)) +if project_root not in sys.path: + sys.path.insert(0, project_root) + +vendor_path = os.path.join(project_root, "vendor", "schwab-scraper") +if vendor_path not in sys.path: + sys.path.insert(0, vendor_path) + from fastmcp import FastMCP from starlette.applications import Starlette from starlette.responses import JSONResponse @@ -74,6 +84,43 @@ async def get_session_status(debug: bool = False) -> str: result = await api.get_session_status(debug=debug) return serialize(result) +@mcp.tool() +async def get_login_safety_status() -> str: + """Get the current login safety status, including any active backoffs or limits. + + Useful to check if a login attempt is likely to be blocked. + """ + try: + result_str = await api.get_session_status() + result = json.loads(result_str) + if result.get("success") and "login_safety" in result.get("data", {}): + return json.dumps(result["data"]["login_safety"]) + return json.dumps({"status": "unknown", "message": "Login safety info not available"}) + except Exception as e: + return json.dumps({"error": str(e)}) + +@mcp.tool() +async def login(username: Optional[str] = None, password: Optional[str] = None, debug: bool = False) -> str: + """Perform an automated login to Schwab to establish a new session. + + Args: + username: Schwab username (optional, will use env if omitted) + password: Schwab password (optional, will use env if omitted) + debug: Enable debug logging + """ + result = await api.login(username=username, password=password, debug=debug) + return serialize(result) + +@mcp.tool() +async def refresh_session(debug: bool = False) -> str: + """Refresh the current Schwab session to prevent expiration. + + Args: + debug: Enable debug logging + """ + result = await api.refresh_session(debug=debug) + return serialize(result) + @mcp.tool() async def list_accounts(debug: bool = False) -> str: """List all Schwab accounts.