from __future__ import annotations from decimal import Decimal, InvalidOperation from typing import Iterable, Optional from ...core import AccountSummary, Envelope, ErrorType, PortfolioSnapshot, Position, fail, ok from .positions_scraper import get_positions def _aggregate_positions(positions: Iterable[Position]) -> tuple[list[Position], Optional[Decimal]]: aggregated: dict[str, Position] = {} total_value = Decimal("0") has_value = False for position in positions: if position.market_value is not None: total_value += position.market_value has_value = True key = position.symbol.upper() if position.symbol else "UNKNOWN" if key not in aggregated: aggregated[key] = Position( symbol=position.symbol, description=position.description, asset_type=position.asset_type, quantity=position.quantity, market_price=position.market_price, market_value=position.market_value, cost_basis_total=position.cost_basis_total, unrealized_gain=position.unrealized_gain, unrealized_gain_pct=position.unrealized_gain_pct, lots=list(position.lots), ) continue existing = aggregated[key] if position.quantity is not None: if existing.quantity is None: existing.quantity = position.quantity else: existing.quantity += position.quantity if position.market_value is not None: if existing.market_value is None: existing.market_value = position.market_value else: existing.market_value += position.market_value if position.cost_basis_total is not None: if existing.cost_basis_total is None: existing.cost_basis_total = position.cost_basis_total else: existing.cost_basis_total += position.cost_basis_total if position.unrealized_gain is not None: if existing.unrealized_gain is None: existing.unrealized_gain = position.unrealized_gain else: existing.unrealized_gain += position.unrealized_gain if position.market_price is not None: existing.market_price = position.market_price if position.unrealized_gain_pct is not None: existing.unrealized_gain_pct = position.unrealized_gain_pct if position.description and not existing.description: existing.description = position.description if position.asset_type: existing.asset_type = position.asset_type if position.lots: existing.lots.extend(position.lots) for item in aggregated.values(): if item.unrealized_gain is not None and item.cost_basis_total not in (None, Decimal("0")): try: item.unrealized_gain_pct = float((item.unrealized_gain / item.cost_basis_total) * 100) except (InvalidOperation, ZeroDivisionError): item.unrealized_gain_pct = None total_value_out = total_value if has_value else None return list(aggregated.values()), total_value_out async def get_portfolio_snapshot( account: AccountSummary | str | None = None, *, aggregate_by_symbol: bool = True, include_non_equity: bool = False, debug: bool = False, ) -> Envelope[PortfolioSnapshot]: positions_envelope = await get_positions( account=account, include_non_equity=include_non_equity, debug=debug, ) if not positions_envelope["success"]: return fail( positions_envelope.get("error") or "Failed to retrieve positions.", positions_envelope.get("error_type") or ErrorType.UNKNOWN, positions_envelope.get("retryable", True), ) positions = positions_envelope["data"] or [] if aggregate_by_symbol: aggregated_positions, total_value = _aggregate_positions(positions) count = len(aggregated_positions) snapshot = PortfolioSnapshot( equities=aggregated_positions, total_value=total_value, count=count, ) return ok(snapshot) total_value = Decimal("0") has_value = False for position in positions: if position.market_value is not None: total_value += position.market_value has_value = True total_value_out = total_value if has_value else None snapshot = PortfolioSnapshot( equities=positions, total_value=total_value_out, count=len(positions), ) return ok(snapshot)