"""
Bitstamp API Interpreter.

Broker: Bitstamp (31)
Format: JSON (API sync)
Source: sync (API)
Assets: SPOT (crypto)

Data Format:
------------
Bitstamp exports transactions as a JSON array with dynamic currency columns.
Each transaction has:
- id: unique transaction ID
- datetime: "YYYY-MM-DD HH:MM:SS.ffffff"
- type: "2" for trades (all other types are ignored by this interpreter)
- fee: fee amount in quote currency
- order_id: order ID
- Dynamic currency columns: btc, ltc, eth, xrp, xlm, usdc, usdt, shib, doge, etc.
- Dynamic rate columns: btc_usd, ltc_usd, eth_btc, etc.
- usd, eur: fiat amounts

Trade Direction:
- BUY: crypto amount positive, fiat amount negative
- SELL: crypto amount negative, fiat amount positive

file_row Hash Formula (Legacy Compatibility):
---------------------------------------------
The file_row field is computed as an MD5 hash for deduplication against
the legacy TraderSync system. The formula is:

    file_row = MD5(json.dumps(pre_hash_object))

Where pre_hash_object contains:
    1. All raw API fields in original key order (id, datetime, type, fee,
       [currency columns], order_id)
    2. Pre-hash fields appended:
       - created_at = datetime value
       - created_at_formated = datetime value
       - symbol = constructed symbol (e.g., "BTCUSD")

Note: The hash uses default json.dumps separators (", " and ": "), NOT compact.

Post-hash fields (added by the legacy system after hashing, so NOT part of the hash):
    date_tz, broker, action, type_option, price, archive

Match Rate: 100% (94/94 records for user 49186).
"""

import polars as pl
import json
import hashlib
from typing import ClassVar, Set, List, Dict, Any, Optional, Tuple
import logging

from pipeline.p01_normalize.base import BaseInterpreter

logger = logging.getLogger(__name__)

# Currency column names recognised in Bitstamp transactions (kept alphabetical).
# A column counts as a currency amount if it is in this set OR in FIAT_CURRENCIES.
CRYPTO_CURRENCIES = {
    "1inch", "aave", "algo", "alpha", "audio", "axs", "bat", "bch", "btc",
    "cel", "chz", "comp", "crv", "doge", "enj", "eth", "ftm", "grt",
    "knc", "link", "lrc", "ltc", "mana", "matic", "mkr", "omg", "perp",
    "ren", "sand", "shib", "skl", "snx", "storj", "sushi", "uma", "uni",
    "usdc", "usdt", "xlm", "xrp", "yfi", "zrx",
    # NOTE(review): "eur" and "gbp" are fiat and are also in FIAT_CURRENCIES.
    # They stay here so the set's contents are unchanged for any importer.
    "eur", "gbp",
}

# Fiat currency column names.
FIAT_CURRENCIES = {"eur", "gbp", "usd"}

# ⚠️ Legacy fee-estimation factor (0.507038%, i.e. 0.00507038 as a decimal).
# Used only when an old export has no usable 'fee': fee ≈ |total| * factor.
# WARNING: origin unknown — may be outdated if Bitstamp commission rates changed.
FEE_ESTIMATION_FACTOR = 0.00507038


class BitstampInterpreter(BaseInterpreter):
    """
    Interpreter for Bitstamp API JSON format.

    Accepts either a bare JSON array of transactions or an object wrapping
    that array under "transactions" / "data". Only type="2" (trade)
    transactions are kept. Each becomes one execution row that carries a
    legacy-compatible ``file_row`` MD5 hash (see the module docstring).
    """

    BROKER_ID: ClassVar[str] = "bitstamp"
    FORMAT_VERSION: ClassVar[str] = "1.0"
    SUPPORTED_ASSETS: ClassVar[Set[str]] = {"SPOT"}

    # Quote-currency preference used by _detect_trade_pair (lower number =
    # more likely to be the quote side). Pure fiat first, then stablecoins,
    # then major cryptos used as quote.
    _QUOTE_PRIORITY: ClassVar[Dict[str, int]] = {
        "usd": 1, "eur": 2, "gbp": 3, "usdt": 4, "usdc": 5, "btc": 6, "eth": 7,
    }

    # Polars/chrono parse pattern for Bitstamp's "YYYY-MM-DD HH:MM:SS.ffffff".
    _DATETIME_FORMAT: ClassVar[str] = "%Y-%m-%d %H:%M:%S%.f"

    @classmethod
    def can_handle(cls, df: pl.DataFrame, metadata: dict) -> bool:
        """
        Check if this interpreter can handle the data.

        For Bitstamp, we check for specific columns from flattened JSON.
        """
        required = {"id", "datetime", "type", "fee", "order_id"}
        return required.issubset(set(df.columns))

    @classmethod
    def get_priority(cls) -> int:
        """Higher priority for Bitstamp format."""
        return 100

    @staticmethod
    def _to_float(value: Any) -> float:
        """
        Leniently convert a raw API value to float.

        Falsy values (None, "", 0) and values that cannot be parsed as a
        number map to 0.0. As a result, one malformed field makes the
        transaction fail validation instead of aborting the whole batch.
        """
        if not value:
            return 0.0
        try:
            return float(value)
        except (ValueError, TypeError):
            return 0.0

    @classmethod
    def _detect_trade_pair(
        cls, transaction: Dict[str, Any]
    ) -> Optional[Tuple[str, str, float, float, str]]:
        """
        Detect the trading pair from a Bitstamp transaction.

        The base currency is the alphabetically first non-zero currency
        column that is not a known quote currency. The quote is the non-zero
        known quote currency with the best priority. When every non-zero
        currency is a known quote (e.g. USDC/USD, BTC/USD, ETH/BTC), the two
        best-priority ones become (quote, base).

        Returns:
            Tuple of (base_currency, quote_currency, quantity, price, side),
            or None if the transaction does not describe a two-sided trade.
            ``price`` may be 0.0 when no usable rate column is present;
            callers must validate it.
        """
        quote_priority = cls._QUOTE_PRIORITY

        # Non-zero currency amounts, in the transaction's own key order.
        non_zero_currencies: List[Tuple[str, float]] = []
        for key, value in transaction.items():
            if key not in CRYPTO_CURRENCIES and key not in FIAT_CURRENCIES:
                continue
            amount = cls._to_float(value)
            if amount != 0.0:
                non_zero_currencies.append((key, amount))

        if len(non_zero_currencies) < 2:
            return None

        non_quotes = sorted(
            (item for item in non_zero_currencies if item[0] not in quote_priority),
            key=lambda item: item[0],  # alphabetical
        )
        quotes = sorted(
            (item for item in non_zero_currencies if item[0] in quote_priority),
            key=lambda item: quote_priority[item[0]],  # best priority first
        )

        if non_quotes and quotes:
            base_currency, base_amount = non_quotes[0]
            quote_currency = quotes[0][0]
        elif len(quotes) >= 2:
            # Quote-only pair: the higher-priority currency is the quote and
            # the next one is the base.
            quote_currency = quotes[0][0]
            base_currency, base_amount = quotes[1]
        else:
            return None

        # Positive base amount = BUY (receiving the base currency);
        # negative = SELL.
        side = "BUY" if base_amount > 0 else "SELL"
        quantity = abs(base_amount)

        # ❌ VALIDATION 4 (in _detect_trade_pair): Quantity must be > 0
        if quantity == 0:
            return None

        # Price from the rate column (e.g. btc_usd, eth_btc), falling back to
        # the inverse of the reversed pair's rate when the direct one is 0.
        price = cls._to_float(transaction.get(f"{base_currency}_{quote_currency}"))
        if price == 0:
            reverse_price = cls._to_float(
                transaction.get(f"{quote_currency}_{base_currency}")
            )
            if reverse_price > 0:
                price = 1.0 / reverse_price

        return (base_currency, quote_currency, quantity, price, side)

    @classmethod
    def parse_json_content(cls, json_content: str) -> pl.DataFrame:
        """
        Parse Bitstamp JSON content to DataFrame.

        Args:
            json_content: Raw JSON string (array of transactions, or an object
                holding the array under "transactions" or "data").

        Returns:
            DataFrame with one row per accepted type=2 trade, sorted in legacy
            order (datetime ASC, id DESC) with ``_row_index`` renumbered
            0..n-1. An empty DataFrame is returned when nothing is accepted.

        Raises:
            json.JSONDecodeError: if ``json_content`` is not valid JSON.
        """
        data = json.loads(json_content)

        if isinstance(data, dict):
            transactions = data.get("transactions", data.get("data", []))
        elif isinstance(data, list):
            transactions = data
        else:
            logger.warning("Unexpected Bitstamp data format")
            return pl.DataFrame()

        if not transactions:
            logger.warning("No transactions found in JSON")
            return pl.DataFrame()

        # row_idx tracks the original file position of each transaction.
        records: List[Dict[str, Any]] = []
        for row_idx, txn in enumerate(transactions):
            record = cls._build_record(txn, row_idx)
            if record is not None:
                records.append(record)

        if not records:
            logger.warning("No trade transactions found after filtering")
            return pl.DataFrame()

        return cls._sort_legacy_order(pl.DataFrame(records))

    @classmethod
    def _build_record(cls, txn: Any, row_idx: int) -> Optional[Dict[str, Any]]:
        """
        Validate one raw transaction and turn it into a record dict.

        Returns None (and logs where useful) when the transaction is not an
        acceptable trade.
        """
        if not isinstance(txn, dict):
            logger.warning("Skipping non-object transaction at index %d", row_idx)
            return None

        # Filter for trade transactions only (type=2)
        if str(txn.get("type", "")) != "2":
            return None

        # ❌ VALIDATION 1: Status = 'FILLED' check (only if the field exists)
        if "status" in txn and txn.get("status") != "FILLED":
            return None

        trade_info = cls._detect_trade_pair(txn)
        if trade_info is None:
            return None
        base_currency, quote_currency, quantity, price, side = trade_info

        # ❌ VALIDATION 2: Price > 0 check
        if price <= 0:
            logger.warning("Invalid price %s for txn %s, skipping", price, txn.get("id"))
            return None

        # ❌ VALIDATION 4: Quantity > 0 check
        if quantity <= 0:
            logger.warning(
                "Invalid quantity %s for txn %s, skipping", quantity, txn.get("id")
            )
            return None

        # Build symbol (e.g., BTCUSD, LTCUSD)
        symbol = f"{base_currency}{quote_currency}".upper()

        txn_id = str(txn.get("id", ""))
        order_id = str(txn.get("order_id", ""))
        # A JSON null must count as "missing". str(None) would yield "None",
        # pass the check below, and then fail datetime parsing for the whole batch.
        raw_datetime = txn.get("datetime")
        datetime_str = "" if raw_datetime is None else str(raw_datetime)

        # ❌ VALIDATION 3: Datetime not empty check
        if not datetime_str:
            logger.warning("Missing datetime for txn %s, skipping", txn_id)
            return None

        fee = cls._to_float(txn.get("fee"))

        # ⚠️ VALIDATION 7 (HIGH): Fee estimation fallback for legacy exports.
        # If fee is 0 and a 'total' field exists, estimate the fee from |total|
        # using the legacy factor (origin unknown - may need verification).
        if fee == 0 and "total" in txn:
            try:
                # abs() handles the sign. Stripping '-' from the string form
                # would corrupt exponent notation (str(1e-05) -> "1e05").
                total = abs(float(txn["total"]))
            except (ValueError, TypeError):
                total = None  # not numeric: keep fee at 0
            if total is not None:
                fee = total * FEE_ESTIMATION_FACTOR
                logger.debug(
                    "Estimated fee %s from total %s for txn %s", fee, total, txn_id
                )

        # Pre-hash object for file_row: raw API fields in their original key
        # order + created_at + created_at_formated + symbol.
        pre_hash = dict(txn)
        pre_hash["created_at"] = datetime_str
        pre_hash["created_at_formated"] = datetime_str
        pre_hash["symbol"] = symbol

        # Legacy formula: MD5(json.dumps(pre_hash)) with DEFAULT separators.
        # MD5 is used for deduplication parity, not for security.
        file_row_hash = hashlib.md5(json.dumps(pre_hash).encode("utf-8")).hexdigest()

        # Full original_file_row: pre-hash fields, then the post-hash fields
        # (except 'archive', which is session-specific).
        full_ofr = dict(pre_hash)
        full_ofr["date_tz"] = datetime_str  # Same format as datetime in legacy
        full_ofr["broker"] = "bitstamp"
        full_ofr["action"] = side
        full_ofr["type_option"] = "CRYPTO"
        full_ofr["price"] = price  # Numeric, not string

        # Stablecoin quotes are reported as USD.
        currency = quote_currency.upper()
        if currency in {"USDC", "USDT"}:
            currency = "USD"

        # Key order defines the DataFrame column order.
        return {
            "txn_id": txn_id,
            "order_id": order_id,
            "datetime": datetime_str,
            "symbol": symbol,
            "side": side,
            "quantity": quantity,
            "price": price,
            "fee": fee,
            "currency": currency,
            "base_currency": base_currency.upper(),
            "quote_currency": quote_currency.upper(),
            # Full original_file_row with all legacy fields
            "_original_txn": json.dumps(full_ofr),
            # Computed legacy file_row hash
            "_file_row_hash": file_row_hash,
            # Original file position; renumbered after sorting
            "_row_index": row_idx,
        }

    @classmethod
    def _sort_legacy_order(cls, df: pl.DataFrame) -> pl.DataFrame:
        """
        Sort records by (datetime ASC, txn id DESC) and renumber ``_row_index``.

        CRITICAL for legacy compatibility. Position is a running sum, so a
        different execution order gives different intermediate positions.
        Example at 2024-02-21 07:07:32: legacy order is id=321107219,
        321107218, 321107217, 321107216 (id DESC), not ascending.
        Non-integer ids are cast to null for the sort.
        """
        df = df.with_columns([
            pl.col("datetime").str.to_datetime(cls._DATETIME_FORMAT).alias("_datetime"),
            pl.col("txn_id").cast(pl.Int64, strict=False).alias("_txn_id_int"),
        ])

        df = df.sort(
            ["_datetime", "_txn_id_int"],
            descending=[False, True],  # datetime ASC, id DESC
        )

        # Re-assign row_index to reflect the sorted order.
        df = df.with_columns([
            pl.arange(0, df.height).alias("_row_index")
        ])

        return df.drop(["_datetime", "_txn_id_int"])

    def _build_file_row_expr_json(self) -> pl.Expr:
        """
        Build original_file_row from the stored original transaction JSON.

        The original transaction is already stored as JSON string during parsing.
        """
        return pl.col("_original_txn")

    def normalize(self, df: pl.LazyFrame, user_id: int, account_id: str = "") -> pl.LazyFrame:
        """
        Transform Bitstamp data to normalized schema for grouping.

        Args:
            df: Input data as LazyFrame (from parse_json_content)
            user_id: TraderSync user ID
            account_id: Account ID from input metadata

        Returns:
            Normalized LazyFrame matching grouping.py expected schema
            (20 columns, in the order given by the final select).
        """
        return (
            df
            # Build original_file_row
            .with_columns([
                self._build_file_row_expr_json().alias("original_file_row")
            ])
            .with_columns([
                pl.lit(user_id).alias("user_id"),
                # account_id - from input metadata
                pl.lit(account_id).alias("account_id"),
                # execution_id - txn_id (unique per transaction)
                pl.col("txn_id").alias("execution_id"),
                # symbol/side/quantity/price - already built in parse_json_content
                pl.col("symbol").alias("symbol"),
                pl.col("side").alias("side"),
                pl.col("quantity").alias("quantity"),
                pl.col("price").alias("price"),
                # timestamp - parsed as UTC and converted to America/New_York,
                # then made naive for storage. The legacy system stores local
                # time. NOTE(review): the zone is hard-coded, not per-user.
                pl.col("datetime")
                .str.to_datetime(self._DATETIME_FORMAT)
                .dt.replace_time_zone("UTC")
                .dt.convert_time_zone("America/New_York")
                .dt.replace_time_zone(None)
                .alias("timestamp"),
                # commission/fees - legacy stores 0 for Bitstamp even though
                # the raw data carries a fee
                pl.lit(0.0).alias("commission"),
                pl.lit(0.0).alias("fees"),
                # swap - not applicable
                pl.lit(0.0).alias("swap"),
                pl.col("currency").alias("currency"),
                # asset - crypto for all Bitstamp
                pl.lit("crypto").alias("asset"),
                # option fields - not applicable for crypto
                pl.lit(None).cast(pl.Float64).alias("option_strike"),
                pl.lit(None).alias("option_expire"),
                # multiplier / pip_value - 1 for crypto
                pl.lit(1.0).alias("multiplier"),
                pl.lit(1.0).alias("pip_value"),
                # file_row - legacy MD5 of the pre-hash object (see _build_record)
                pl.col("_file_row_hash").alias("file_row"),
                # row_index - legacy-sorted order for position calculation
                pl.col("_row_index").alias("row_index"),
            ])
            # ⚠️ VALIDATION 6 (HIGH): round price and quantity to 8 decimals
            .with_columns([
                pl.col("price").round(8).alias("price"),
                pl.col("quantity").round(8).alias("quantity"),
            ])
            # ❌ VALIDATION 5 (CRITICAL): side must be BUY or SELL
            # (explicit, even though _detect_trade_pair only emits these)
            .filter(pl.col("side").is_in(["BUY", "SELL"]))
            .select([
                "user_id",
                "account_id",
                "execution_id",
                "symbol",
                "side",
                "quantity",
                "price",
                "timestamp",
                "commission",
                "fees",
                "swap",
                "currency",
                "asset",
                "option_strike",
                "option_expire",
                "multiplier",
                "pip_value",
                "original_file_row",
                "file_row",
                "row_index",
            ])
        )
