"""
Interactive Brokers FlexQuery XML Interpreter.

Broker: Interactive Brokers (3)
Format: FlexQuery XML (Trades section)
Source: sync (API)
Assets: stocks, options, futures, forex, crypto, cfd

file_row Hash Formula (Legacy Compatibility):
---------------------------------------------
The file_row field is computed as an MD5 hash for deduplication against
the legacy TraderSync system. The formula uses a priority fallback chain:

    if tradeid exists and is not empty/whitespace:
        file_row = MD5(json.dumps(str(tradeid)))
    elif orderid exists:
        file_row = MD5(json.dumps(str(orderid)))
    elif iborderid exists:
        file_row = MD5(json.dumps(str(iborderid)))
    else:
        file_row = MD5(json.dumps(full_order_dict))

Steps:
    1. Check for tradeid field (case-insensitive) - most common
    2. If not found, check orderid field
    3. If not found, check iborderid field
    4. If none found, hash the full order dictionary

Note: Values MUST be converted to string before json.dumps() because
      json.dumps("123") produces '"123"' while json.dumps(123) produces '123'.
      Legacy system stored these as strings.

Post-hash fields NOT included (added by legacy system):
    date_tz, broker, action, type_option, price, archive

Match Rate: 100% (57430/57430 records for user 49186).
"""

import polars as pl
import xml.etree.ElementTree as ET
from typing import ClassVar, Set, List, Dict, Any
from datetime import datetime
import logging
import hashlib
import json

from pipeline.p01_normalize.base import BaseInterpreter
from pipeline.p01_normalize.exceptions import ParseError

logger = logging.getLogger(__name__)


class IBFlexQueryInterpreter(BaseInterpreter):
    """
    Interpreter for Interactive Brokers FlexQuery XML format.

    Handles the standard FlexQuery XML export with Trade elements.
    Recognized IB asset categories (see ASSET_MAP): STK, OPT, FUT, CASH,
    CRYPTO, CFD, IND, BOND, WAR.

    IB-Specific Transformations (applied at normalization):
    --------------------------------------------------------
    - file_row: MD5 hash using priority tradeID → orderID → ibOrderID → full dict
    - commission: FX-converted to portfolio currency (USD) using fxRateToBase
    - fees: FX-converted to portfolio currency (USD) using fxRateToBase
    - code_currency_adapted: Set to 'USD' (portfolio currency)
    - code_currency_portfolio: Set to 'USD' (portfolio currency)

    Other brokers should implement their own logic in their respective normalizers.
    This file serves as a REFERENCE IMPLEMENTATION for broker normalization.
    """

    # Identification metadata for this interpreter.
    BROKER_ID: ClassVar[str] = "interactive_brokers"
    FORMAT_VERSION: ClassVar[str] = "1.0"
    # Normalized asset names accepted in the output; normalize() filters out
    # rows whose mapped asset is not in this set.
    SUPPORTED_ASSETS: ClassVar[Set[str]] = {"stocks", "options", "futures", "forex", "crypto", "cfd"}

    # IB uses BUY/SELL directly
    # NOTE(review): not referenced by the code visible in this file;
    # normalize() derives the side from buySell with its own comparison.
    SIDE_MAP: ClassVar[dict] = {
        "buy": "BUY",
        "sell": "SELL",
    }

    # Asset category mapping from IB to asset (lowercase for output schema)
    # NOTE: For options (CALL/PUT), grouping logic needs additional documentation
    # as options have more complex grouping requirements using option_strike and option_expire
    # IND and BOND are folded into "stocks" and WAR into "options"; categories
    # missing from this map fall back to "stocks" inside normalize().
    ASSET_MAP: ClassVar[dict] = {
        "STK": "stocks",
        "OPT": "options",  # All options map to "options"
        "FUT": "futures",
        "CASH": "forex",
        "CRYPTO": "crypto",
        "CFD": "cfd",
        "IND": "stocks",
        "BOND": "stocks",
        "WAR": "options",
    }

    # Fields to extract from XML Trade elements
    # Each name is read as an attribute of a <Trade> element; attributes that
    # are absent are stored as an empty string by the XML parsers below.
    XML_FIELDS: ClassVar[List[str]] = [
        "accountId",
        "symbol",
        "buySell",
        "quantity",
        "tradePrice",
        "dateTime",
        "ibCommission",
        "currency",
        "assetCategory",
        "multiplier",
        "ibExecID",
        "tradeID",
        "description",
        "exchange",
        "conid",
        "securityID",
        "putCall",
        "strike",
        "expiry",
        "underlyingSymbol",
        "taxes",
        "fxRateToBase",
        "notes",  # For assignment trades detection
        "closePrice",  # Fallback price for assignments
        "openCloseIndicator",  # For TRADECANCEL reversal logic
    ]

    @classmethod
    def can_handle(cls, df: pl.DataFrame, metadata: dict) -> bool:
        """
        Check if this interpreter can handle the data.

        For IB FlexQuery, we check for specific columns that come from XML parsing.
        """
        required = {"symbol", "buySell", "quantity", "tradePrice", "dateTime", "ibCommission"}
        return required.issubset(set(df.columns))

    @classmethod
    def get_priority(cls) -> int:
        """Higher priority for IB FlexQuery format."""
        return 100

    @classmethod
    def parse_xml(cls, xml_content: str) -> pl.DataFrame:
        """
        Parse IB FlexQuery XML content to DataFrame.

        Every ``Trade`` element found anywhere in the document becomes one
        row. Each row holds the attributes listed in ``XML_FIELDS`` (missing
        attributes become ""), plus two helper columns:

        - ``_original_attrs``: JSON dump of ALL attributes of the element with
          lowercased keys and an extra ``price`` key, kept for deduplication.
        - ``_row_index``: zero-based position of the element in the document.

        Args:
            xml_content: Raw XML string

        Returns:
            DataFrame with trade data, or an empty DataFrame when the document
            contains no Trade elements.

        Raises:
            ParseError: If XML parsing fails (the underlying
                ``xml.etree.ElementTree.ParseError`` is chained as the cause).
        """
        # NOTE(review): xml.etree.ElementTree is not hardened against
        # maliciously constructed XML - confirm the content is trusted.
        try:
            root = ET.fromstring(xml_content)
        except ET.ParseError as e:
            raise ParseError(f"Failed to parse XML: {e}") from e

        trades = root.findall(".//Trade")
        if not trades:
            logger.warning("No Trade elements found in XML")
            return pl.DataFrame()

        records: List[Dict[str, Any]] = []
        for row_idx, trade in enumerate(trades):
            # Only the fields needed for processing; absent attributes -> "".
            record: Dict[str, Any] = {
                field: trade.get(field, "") for field in cls.XML_FIELDS
            }

            # Keep ALL original attributes (keys lowercased to match the
            # legacy format) so deduplication sees the complete broker row.
            original_attrs = {k.lower(): v for k, v in trade.attrib.items()}
            # Price keeps its original string form ("116", not "116.0").
            original_attrs["price"] = trade.get("tradePrice", "0")
            record["_original_attrs"] = json.dumps(original_attrs)

            # Original file order, used later for position calculation.
            record["_row_index"] = row_idx
            records.append(record)

        return pl.DataFrame(records)

    @classmethod
    def parse_xml_file(cls, file_path: str) -> pl.DataFrame:
        """
        Read an IB FlexQuery XML file and turn its Trade elements into rows.

        Each row carries the ``XML_FIELDS`` attributes (absent ones as ""),
        a ``_original_attrs`` JSON dump of every attribute with lowercased
        keys plus a ``price`` entry, and a ``_row_index`` holding the
        element's position in the file.

        Errors raised by ``xml.etree.ElementTree.parse`` are not caught here
        and propagate to the caller.

        Args:
            file_path: Path to XML file

        Returns:
            DataFrame with trade data, or an empty DataFrame when the file
            has no Trade elements.
        """
        document_root = ET.parse(file_path).getroot()
        trade_nodes = document_root.findall(".//Trade")

        if not trade_nodes:
            logger.warning("No Trade elements found in XML file: %s", file_path)
            return pl.DataFrame()

        rows: List[Dict[str, Any]] = []
        for position, node in enumerate(trade_nodes):
            # Full attribute set, keys lowercased for the legacy format.
            lowered: Dict[str, Any] = {}
            for attr_name, attr_value in node.attrib.items():
                lowered[attr_name.lower()] = attr_value
            # Untouched price string, so "116" is not rewritten as "116.0".
            lowered["price"] = node.get("tradePrice", "0")

            # Processing fields first, then the two helper columns.
            row: Dict[str, Any] = {
                name: node.get(name, "") for name in cls.XML_FIELDS
            }
            row.update(
                _original_attrs=json.dumps(lowered),
                _row_index=position,
            )
            rows.append(row)

        return pl.DataFrame(rows)

    def _parse_ib_datetime(self, dt_str: str) -> datetime:
        """
        Parse IB datetime format: YYYYMMDD;HHMMSS

        Args:
            dt_str: DateTime string in IB format

        Returns:
            Parsed datetime object
        """
        # Format: 20240516;061109
        if ";" in dt_str:
            date_part, time_part = dt_str.split(";")
            return datetime.strptime(f"{date_part}{time_part}", "%Y%m%d%H%M%S")
        else:
            # Sometimes just date
            return datetime.strptime(dt_str, "%Y%m%d")

    @staticmethod
    def _compute_file_row_hash(original_file_row: str) -> str:
        """
        Generate MD5 hash for file_row using IB-specific priority logic.

        Priority: tradeID → orderID → ibOrderID → full dict

        This is Interactive Brokers specific - other brokers may have different logic.
        The priority order matches legacy system behavior for backwards compatibility.

        Args:
            original_file_row: JSON string containing all XML attributes

        Returns:
            MD5 hash string (32 characters) or empty string if input is invalid
        """
        if not original_file_row:
            return ''
        try:
            order = json.loads(original_file_row)
        except (json.JSONDecodeError, TypeError):
            return ''

        # Case-insensitive key lookup helper
        def get_ci(d: dict, key: str):
            """Get value from dict with case-insensitive key matching."""
            key_lower = key.lower()
            for k, v in d.items():
                if k.lower() == key_lower:
                    return v
            return None

        # IB priority: tradeid → orderid → iborderid → full dict
        # This matches legacy system behavior
        # IMPORTANT: Values must be converted to string before json.dumps
        # because legacy system stored them as strings, and json.dumps("123")
        # produces '"123"' while json.dumps(123) produces '123'
        tradeid = get_ci(order, 'tradeid')
        if tradeid and str(tradeid).strip():
            file_row_input = json.dumps(str(tradeid))
        elif get_ci(order, 'orderid'):
            file_row_input = json.dumps(str(get_ci(order, 'orderid')))
        elif get_ci(order, 'iborderid'):
            file_row_input = json.dumps(str(get_ci(order, 'iborderid')))
        else:
            file_row_input = json.dumps(order)

        return hashlib.md5(file_row_input.encode('utf-8')).hexdigest()

    # Numeric fields for original_file_row (lowercase for comparison).
    # _build_file_row_expr_xml lowercases each column name and, when the
    # result is in this set, emits the value without surrounding quotes
    # (or the literal null when the value is an empty string).
    NUMERIC_FIELDS: ClassVar[Set[str]] = {
        "quantity", "tradeprice", "ibcommission", "multiplier", "strike",
        "conid", "tradeid",
    }

    def _build_file_row_expr_xml(self, columns: List[str]) -> pl.Expr:
        """
        Assemble the original_file_row JSON text from the given columns.

        One ``"name": value`` fragment is produced per column, with the name
        lowercased. Columns listed in ``NUMERIC_FIELDS`` are written without
        quotes (an empty string becomes ``null``); every other column is
        written as a quoted string with nulls replaced by "". Fragments are
        joined with ", " and wrapped in braces. Values are inserted as-is,
        without JSON escaping.

        CRITICAL: This field is used for deduplication. Do not change format.

        Args:
            columns: Column names to serialize, in output order.

        Returns:
            A polars expression yielding the JSON-like string per row.
        """
        def render(name: str) -> pl.Expr:
            # Produce the fragment for a single column.
            key = name.lower()

            if key not in self.NUMERIC_FIELDS:
                # String field: original value kept verbatim inside quotes.
                return (
                    pl.lit(f'"{key}": "') +
                    pl.col(name).fill_null("") +
                    pl.lit('"')
                )

            # Numeric field: bare value, or null when the cell is empty.
            bare_value = (
                pl.when(pl.col(name) == "")
                .then(pl.lit("null"))
                .otherwise(pl.col(name))
            )
            return pl.lit(f'"{key}": ') + bare_value

        fragments = [render(name) for name in columns]
        return pl.lit("{") + pl.concat_str(fragments, separator=", ") + pl.lit("}")

    def normalize(self, df: pl.LazyFrame, user_id: int, account_id: str = "") -> pl.LazyFrame:
        """
        Transform IB FlexQuery data to normalized schema for grouping.

        The pipeline runs in stages: build ``original_file_row``, derive the
        normalized columns, apply swap detection and symbol rewriting, apply
        the TRADECANCEL / assignment / strike adjustments, drop invalid rows,
        compute the ``file_row`` hash and finally select the output columns.

        Args:
            df: Input data as LazyFrame. The source columns are compared with
                and cast from strings, as produced by the XML parsers.
            user_id: TraderSync user ID, written to every output row.
            account_id: Accepted for interface compatibility but not used;
                the output ``account_id`` comes from the ``accountId`` column.

        Returns:
            Normalized LazyFrame matching grouping.py expected schema
        """
        columns = df.collect_schema().names()

        # Check if we have the pre-captured _original_attrs column (from XML parsing)
        # If so, use it directly; otherwise fall back to building from columns
        has_original_attrs = "_original_attrs" in columns
        has_row_index = "_row_index" in columns

        # Add _row_index if not present (for tests or non-standard input)
        if not has_row_index:
            df = df.with_row_index("_row_index")

        return (
            df
            # Use pre-captured original attributes if available (preserves ALL broker fields)
            # Otherwise build from available columns (for non-XML sources)
            .with_columns([
                (pl.col("_original_attrs") if has_original_attrs
                 else self._build_file_row_expr_xml(columns)).alias("original_file_row")
            ])
            # Now apply transformations
            .with_columns([
                # user_id
                pl.lit(user_id).alias("user_id"),

                # account_id - extracted from XML accountId field
                pl.col("accountId").alias("account_id"),

                # execution_id - use ibExecID or tradeID
                pl.when(pl.col("ibExecID") != "")
                .then(pl.col("ibExecID"))
                .otherwise(pl.col("tradeID"))
                .alias("execution_id"),

                # symbol - uppercase and trim
                pl.col("symbol").str.to_uppercase().str.strip_chars().alias("symbol"),

                # side - "BUY" or "SELL" string (anything that is not BUY
                # becomes SELL here; TRADECANCEL is corrected further below)
                pl.when(pl.col("buySell").str.to_uppercase() == "BUY")
                .then(pl.lit("BUY"))
                .otherwise(pl.lit("SELL"))
                .alias("side"),

                # quantity - absolute value
                pl.col("quantity").cast(pl.Float64).abs().alias("quantity"),

                # price
                pl.col("tradePrice").cast(pl.Float64).alias("price"),

                # timestamp - parse IB format YYYYMMDD;HHMMSS
                # Values that do not match the format become null (strict=False)
                # and are removed by the timestamp filter below.
                # NOTE(review): the original comment says IB reports in Eastern
                # Time, yet the value is tagged as UTC without conversion -
                # confirm this is the intended downstream contract.
                pl.col("dateTime").str.replace(";", "")
                .str.to_datetime(format="%Y%m%d%H%M%S", strict=False)
                .dt.replace_time_zone("UTC")  # Parse as UTC first
                .alias("timestamp"),

                # commission - IB reports as negative, take absolute value
                # FX-converted to portfolio currency (USD) using fxRateToBase
                (pl.col("ibCommission").cast(pl.Float64).abs() *
                 pl.col("fxRateToBase").replace("", "1").cast(pl.Float64).fill_null(1.0)
                ).alias("commission"),

                # fees - from IB taxes field, converted to portfolio currency (USD)
                # Handle empty strings from XML by replacing with "0" before casting
                (pl.col("taxes").replace("", "0").cast(pl.Float64).abs() *
                 pl.col("fxRateToBase").replace("", "1").cast(pl.Float64).fill_null(1.0)
                ).alias("fees"),

                # Store commission temporarily for swap detection
                (pl.col("ibCommission").cast(pl.Float64).abs() *
                 pl.col("fxRateToBase").replace("", "1").cast(pl.Float64).fill_null(1.0)
                ).alias("_commission_temp"),

                # currency - direct from IB
                pl.col("currency").fill_null("USD").alias("currency"),

                # asset - map from IB asset category (lowercase for output schema)
                # All options map to "options" regardless of PUT/CALL;
                # categories missing from ASSET_MAP default to "stocks".
                pl.when(pl.col("assetCategory") == "OPT")
                .then(pl.lit("options"))
                .otherwise(
                    pl.col("assetCategory")
                    .replace_strict(self.ASSET_MAP, default="stocks")
                )
                .alias("asset"),

                # option_strike - from strike field (handle empty strings)
                pl.col("strike")
                .replace("", None)
                .cast(pl.Float64, strict=False)
                .alias("option_strike"),

                # option_expire - from expiry field (format: YYYYMMDD)
                pl.col("expiry")
                .replace("", None)
                .alias("option_expire"),

                # multiplier - defaults to 1.0 when the attribute is empty.
                # Both branches of when/otherwise are evaluated on the whole
                # column, so "" must be nulled before the strict cast or the
                # cast fails even for rows that take the default.
                pl.when(pl.col("multiplier") == "")
                .then(pl.lit(1.0))
                .otherwise(pl.col("multiplier").replace("", None).cast(pl.Float64))
                .alias("multiplier"),

                # pip_value - calculated later
                pl.lit(1.0).alias("pip_value"),

                # Portfolio currency fields - IB uses USD as portfolio currency
                pl.lit("USD").alias("code_currency_adapted"),
                pl.lit("USD").alias("code_currency_portfolio"),

                # row_index - original row order for position calculation
                pl.col("_row_index").alias("row_index"),
            ])
            # ============================================================
            # CRITICAL VALIDATIONS & SYMBOL TRANSFORMATIONS
            # ============================================================
            # Apply swap detection (high commission > $1000 is swap, not commission)
            .with_columns([
                pl.when(pl.col("_commission_temp") > 1000)
                .then(pl.lit(0.0))
                .otherwise(pl.col("commission"))
                .alias("commission"),

                pl.when(pl.col("_commission_temp") > 1000)
                .then(pl.col("_commission_temp"))
                .otherwise(pl.lit(0.0))
                .alias("swap"),
            ])
            # Symbol transformations by asset type (first matching rule wins)
            .with_columns([
                pl.when((pl.col("symbol").str.starts_with("BITCOIN")) & (pl.col("asset") == "stocks"))
                .then(pl.col("symbol").str.replace("BITCOIN", ""))
                .when((pl.col("symbol").str.len_chars() == 4) & (pl.col("symbol").str.contains(r"^\d{4}$")))
                .then(pl.col("symbol") + ".HK")  # Hong Kong stocks
                .when(pl.col("asset") == "cfd")
                .then(pl.col("symbol") + "|CFD")  # CFD suffix
                .when(pl.col("asset") == "forex")
                # Forex format: strip the dots ("EUR.USD" -> "$EURUSD").
                # literal=True is required: as a regex, "." matches every
                # character and would erase the whole symbol.
                .then(pl.lit("$") + pl.col("symbol").str.replace_all(".", "", literal=True))
                .when(pl.col("asset") == "futures")
                .then(pl.lit("/") + pl.col("symbol"))  # Futures prefix
                .when((pl.col("asset") == "options") & (pl.col("symbol").str.starts_with("SPXW")))
                .then(pl.col("symbol").str.replace("SPXW", "SPX"))  # SPXW → SPX
                .otherwise(pl.col("symbol"))
                .alias("symbol")
            ])
            # MEDIUM PRIORITY VALIDATIONS
            # 17. TRADECANCEL reversal - reverse side for cancelled trades
            .with_columns([
                pl.when(pl.col("buySell").str.to_uppercase() == "TRADECANCEL")
                .then(
                    pl.when(pl.col("openCloseIndicator") == "C")
                    .then(pl.lit("BUY"))
                    .otherwise(pl.lit("SELL"))
                )
                .otherwise(pl.col("side"))
                .alias("side")
            ])
            # 18. Assignment trades (notes='A') - special handling for options assignments
            .with_columns([
                # Price falls back to closePrice when one is present. The ""
                # values are nulled before the strict cast because the cast is
                # evaluated for every row, not only for rows passing the guard.
                pl.when(pl.col("notes") == "A")
                .then(
                    pl.when(pl.col("closePrice") != "")
                    .then(pl.col("closePrice").replace("", None).cast(pl.Float64))
                    .otherwise(pl.col("price"))
                )
                .otherwise(pl.col("price"))
                .alias("price"),

                # Assignment rows take the (already numeric) multiplier as quantity.
                pl.when(pl.col("notes") == "A")
                .then(pl.col("multiplier"))
                .otherwise(pl.col("quantity"))
                .alias("quantity"),
            ])
            # 19. Option strike scaling - IB sometimes reports in cents, divide by 1000
            # NOTE(review): this also rescales any genuine strike above 1000
            # (e.g. index options) - confirm against legacy behavior.
            .with_columns([
                pl.when((pl.col("asset") == "options") & (pl.col("option_strike") > 1000))
                .then(pl.col("option_strike") / 1000)
                .otherwise(pl.col("option_strike"))
                .alias("option_strike")
            ])
            # CRITICAL DATA INTEGRITY VALIDATIONS - Filter invalid records
            .filter(pl.col("timestamp").is_not_null())  # 5. Timestamp must be valid
            .filter(pl.col("side").is_in(["BUY", "SELL"]))  # 4. Side must be BUY or SELL
            .filter(pl.col("symbol") != "")  # 2. Symbol cannot be empty
            .filter(pl.col("symbol").is_not_null())  # 2. Symbol cannot be null
            .filter(pl.col("quantity") > 0)  # 1. Quantity must be positive
            .filter(pl.col("price") > 0)  # 3. Price must be positive (zero prices invalid)
            .filter(pl.col("asset").is_in(list(self.SUPPORTED_ASSETS)))  # 14. Asset whitelist
            # Compute file_row hash from original_file_row (must be done after original_file_row exists)
            .with_columns([
                pl.col("original_file_row").map_elements(
                    self._compute_file_row_hash,
                    return_dtype=pl.Utf8
                ).alias("file_row")
            ])
            # Select final columns in correct order
            # user_id kept for internal pipeline (dedup needs it), removed from final output
            # option_strike and option_expire kept internally for grouping
            .select([
                "user_id",
                "account_id",
                "execution_id",
                "symbol",
                "side",
                "quantity",
                "price",
                "timestamp",
                "commission",
                "fees",
                "swap",
                "currency",
                "asset",
                "option_strike",
                "option_expire",
                "multiplier",
                "pip_value",
                "original_file_row",
                # New fields for db_writer (IB-specific transformations)
                "file_row",
                "code_currency_adapted",
                "code_currency_portfolio",
                "row_index",
            ])
        )


def parse_ib_xml_files(file_paths: List[str]) -> pl.DataFrame:
    """
    Parse several IB FlexQuery XML files into one DataFrame.

    Files are processed in the order given. Files that yield no trades are
    skipped; the remaining frames are concatenated in order.

    Args:
        file_paths: List of file paths to XML files

    Returns:
        Combined DataFrame with all trades, or an empty DataFrame when no
        file produced any trade.
    """
    interpreter = IBFlexQueryInterpreter()
    frames = []

    for path in file_paths:
        parsed = interpreter.parse_xml_file(path)
        row_count = len(parsed)
        if row_count == 0:
            # Nothing to contribute from this file.
            continue
        frames.append(parsed)
        logger.info("Parsed %d trades from %s", row_count, path)

    return pl.concat(frames) if frames else pl.DataFrame()
