from __future__ import annotations import re from decimal import Decimal, InvalidOperation from typing import Any, Optional, Sequence from ...browser.auth import ensure_cookies from ...browser.client import connect, new_context, new_page from ...browser.navigation import goto_with_auth_check from ...core import AccountSummary, Envelope, ErrorType, Lot, Position, fail, ok from ...core.config import get_playwright_url, load_config POSITIONS_URL = "https://client.schwab.com/app/accounts/positions/#/" def _parse_decimal(value: str | None) -> Optional[Decimal]: if not value: return None cleaned = value.strip() if not cleaned or cleaned in {"-", "--"}: return None negative = False if cleaned.startswith("(") and cleaned.endswith(")"): negative = True cleaned = ( cleaned.replace("$", "") .replace(",", "") .replace("(", "") .replace(")", "") .replace("−", "-") .replace("%", "") .strip() ) if not cleaned: return None try: parsed = Decimal(cleaned) if negative or parsed < 0: parsed = -abs(parsed) return parsed except InvalidOperation: return None def _parse_float(value: str | None) -> Optional[float]: decimal_value = _parse_decimal(value) if decimal_value is None: return None try: return float(decimal_value) except (ValueError, InvalidOperation): return None def _normalize_account_label(label: str) -> AccountSummary: normalized = re.sub(r"\s+", " ", label).strip() last4_match = re.search(r"(\d{3,4})\b", normalized.replace(" ", "")) last4 = last4_match.group(1)[-4:] if last4_match else None type_match = re.search(r"^[A-Za-z&'\- ]+", normalized) account_type = re.sub(r"\s+", "_", type_match.group(0).strip()) if type_match else "Account" account_id = f"{account_type}-{last4}" if last4 else account_type return AccountSummary( id=account_id, label=normalized, type=account_type, last4=last4, is_margin="margin" in normalized.lower(), ) def _match_account(candidate: AccountSummary, requested: AccountSummary | str | None) -> bool: if requested is None: return True if isinstance(requested, AccountSummary): requested_values = { requested.id.lower(), requested.label.lower(), } if requested.last4: requested_values.add(requested.last4.lower()) else: lookup = requested.strip().lower() requested_values = {lookup} candidate_values = {candidate.id.lower(), candidate.label.lower()} if candidate.last4: candidate_values.add(candidate.last4.lower()) return bool(candidate_values & requested_values) def classify_asset(symbol: str | None, description: str | None) -> str: if symbol: sym = symbol.strip().upper() else: sym = "" desc = (description or "").strip().upper() if sym and re.fullmatch(r"[A-Z]{1,5}", sym): if "ETF" in desc: return "ETF" if any(kw in desc for kw in ["FUND", "MUTUAL"]): return "MUTUAL_FUND" return "EQUITY" if sym and re.search(r"\d", sym) and len(sym) > 5: return "OPTION" if any(kw in desc for kw in ["BOND", "CD", "TREASURY"]): return "BOND" if sym in {"CASH", "MMDA", "SWEEP"} or "CASH" in desc: return "CASH" if "ETF" in desc: return "ETF" if "FUND" in desc: return "MUTUAL_FUND" return "OTHER" async def _evaluate_table(page) -> dict[str, Any] | None: return await page.evaluate( """ () => { const table = document.querySelector('#positionsDetails'); if (!table) { return null; } const headers = Array.from(table.querySelectorAll('thead tr th')).map((th) => (th.innerText || th.textContent || '').trim() ); const rowElements = Array.from(table.querySelectorAll('tbody tr')); const rows = []; let current = null; let currentAccount = null; const isLotRow = (row) => { const klass = (row.className || '').toLowerCase(); if (klass.includes('lot') || klass.includes('sub') || klass.includes('child')) { return true; } const dataRole = (row.getAttribute('data-row-type') || '').toLowerCase(); return dataRole.includes('lot'); }; const isPositionRow = (row) => { const klass = (row.className || '').toLowerCase(); return klass.includes('position-row'); }; const isAccountHeader = (row) => { const klass = (row.className || '').toLowerCase(); const text = (row.textContent || '').trim(); return !klass.includes('position-row') && (klass.includes('highlight-row') || klass.includes('border-top-dark')) && text.includes('account panel'); }; for (const row of rowElements) { // Check if this is an account header row if (isAccountHeader(row)) { const text = row.textContent.trim(); // Extract account name from account panel text const match = text.match(/account panel[\\s\\n]+([^\\n]+)/); if (match) { currentAccount = match[1].trim(); } continue; } const cells = Array.from(row.querySelectorAll('td')).map((cell) => (cell.innerText || cell.textContent || '').trim() ); if (!cells.length) { continue; } if (isLotRow(row)) { if (current) { current.lots.push(cells); } } else if (isPositionRow(row)) { // Extract symbol from data-symbol attribute const symbol = row.getAttribute('data-symbol') || ''; current = { type: 'position', cells: cells, lots: [], symbol: symbol, account: currentAccount }; rows.push(current); } } return { headers, rows }; } """ ) def _map_row(headers: Sequence[str], cells: Sequence[str]) -> dict[str, str]: result: dict[str, str] = {} # Special handling: The table has columns in headers that don't correspond to cells # Headers: ['', 'Symbol', 'Description', 'Qty', 'Price', ...] # Cells: ['VANGUARD...', '192.5', '$328.17', ...] # The first two headers (empty checkbox and Symbol) have no corresponding cells # So: Cell 0 → 'Description', Cell 1 → 'Qty', Cell 2 → 'Price', etc. # Find the symbol header index to know where the offset starts symbol_header_idx = None for idx, header in enumerate(headers): key = header.strip().lower() if 'symbol' in key and 'description' not in key: symbol_header_idx = idx break # Calculate offset - typically 2 (empty column + symbol column) offset = symbol_header_idx + 1 if symbol_header_idx is not None else 0 for idx, header in enumerate(headers): # Normalize header: take first line, strip, lowercase # Headers often have format "Label\nsort\nfieldname" header_parts = header.strip().split('\n') key = header_parts[0].strip().lower() if header_parts else "" if not key: key = f"column_{idx}" # Map header to cell with offset if idx < offset: # These headers (empty, symbol) have no corresponding cells value = "" else: cell_idx = idx - offset value = cells[cell_idx].strip() if cell_idx < len(cells) else "" result[key] = value return result def _parse_lots(lot_rows: Sequence[Sequence[str]]) -> list[Lot]: lots: list[Lot] = [] for cells in lot_rows: if not cells: continue acquired_date = cells[0].strip() if len(cells) > 0 else None quantity = _parse_float(cells[1] if len(cells) > 1 else None) cost_basis = _parse_decimal(cells[2] if len(cells) > 2 else None) lot_id = cells[3].strip() if len(cells) > 3 else None lots.append( Lot( acquired_date=acquired_date or None, quantity=quantity, cost_basis=cost_basis, lot_id=lot_id or None, ) ) return lots def _row_to_position(row_map: dict[str, str], lots_rows: Sequence[Sequence[str]], symbol: str = "") -> Position: # Symbol is now passed from data-symbol attribute on row # Description is in the first visible cell description = row_map.get('description') or row_map.get('name') or row_map.get('column_1') or "" # Price is typically in column labeled 'price' or similar market_price = _parse_decimal( row_map.get('price') or row_map.get('market price') or row_map.get('last price') ) # Quantity - now in different column due to layout change quantity = _parse_float(row_map.get('quantity') or row_map.get('qty')) market_value = _parse_decimal(row_map.get('market value') or row_map.get('mkt val')) cost_basis_total = _parse_decimal(row_map.get('cost basis') or row_map.get('total cost')) unrealized_gain = _parse_decimal( row_map.get('gain/loss $') or row_map.get('unrealized gain') or row_map.get('gain/loss') ) unrealized_gain_pct = _parse_float( row_map.get('gain/loss %') or row_map.get('unrealized gain %') ) asset_type = classify_asset(symbol, description) lots = _parse_lots(lots_rows) return Position( symbol=symbol or "", description=description or None, asset_type=asset_type, quantity=quantity, market_price=market_price, market_value=market_value, cost_basis_total=cost_basis_total, unrealized_gain=unrealized_gain, unrealized_gain_pct=unrealized_gain_pct, lots=lots, ) async def get_positions( account: AccountSummary | str | None = None, *, include_non_equity: bool = False, debug: bool = False, ) -> Envelope[list[Position]]: cookies = await ensure_cookies() if not cookies: return fail("Unable to establish Schwab session.", ErrorType.AUTHENTICATION, retryable=False) config = load_config() playwright_url = get_playwright_url(config) playwright = browser = context = page = None try: playwright, browser = await connect(playwright_url) context = await new_context(browser, cookies=cookies) page = await new_page(context) if not await goto_with_auth_check(page, context, POSITIONS_URL, debug=debug): return fail("Failed to load Schwab positions page.", ErrorType.AUTHENTICATION, retryable=True) await page.wait_for_selector('#positionsDetails', timeout=45000) await page.wait_for_timeout(1000) await page.evaluate('window.scrollTo(0, document.body.scrollHeight)') await page.wait_for_timeout(1500) table_data = await _evaluate_table(page) if not table_data: return fail("Unable to locate positions table.", ErrorType.PARSING, retryable=True) headers = [header.strip().lower() for header in table_data.get('headers') or []] if not headers: return fail("Positions table headers not found.", ErrorType.PARSING, retryable=True) positions: list[Position] = [] for row in table_data.get('rows', []): if row.get('type') != 'position': continue cells = row.get('cells') or [] symbol = row.get('symbol') or "" account_label = row.get('account') or "" row_map = _map_row(headers, cells) position = _row_to_position(row_map, row.get('lots') or [], symbol=symbol) # Filter by account if requested if account is not None and account_label: # Normalize the account label from the row account_summary = _normalize_account_label(account_label) if not _match_account(account_summary, account): continue elif account is not None and not account_label: # If filtering by account but row has no account, skip it continue if not include_non_equity and position.asset_type not in {"EQUITY", "ETF"}: continue positions.append(position) if not positions: return fail("No positions matched the requested criteria.", ErrorType.VALIDATION, retryable=False) return ok(positions) except Exception as exc: return fail(str(exc), ErrorType.UNKNOWN, retryable=True) finally: await _safe_close_page(page) await _safe_close_context(context) await _safe_close_browser(browser) await _safe_stop_playwright(playwright) async def _safe_close_page(page) -> None: if page is None: return try: await page.close() except Exception: pass async def _safe_close_context(context) -> None: if context is None: return try: await context.close() except Exception: pass async def _safe_close_browser(browser) -> None: if browser is None: return try: await browser.close() except Exception: pass async def _safe_stop_playwright(playwright) -> None: if playwright is None: return try: await playwright.stop() except Exception: pass