from __future__ import annotations

from dataclasses import dataclass
from typing import Callable, Dict, Literal, Optional, Tuple

import numpy as np

from ..metrics.mmd import MMDConfig, mmd2_rbf
from ..metrics.sinkhorn import sinkhorn_divergence
from ..metrics.grads import gradient_quantile_proxy


@dataclass
class DiagnosticConfig:
    """Settings consumed by ``build_Bhat_terms`` and the diagnostic helpers.

    NOTE: field order defines the positional ``__init__`` generated by
    ``@dataclass``; do not reorder fields.
    """

    use_sinkhorn: bool = True  # add the Sinkhorn-divergence "EmpShift" term
    use_mmd: bool = False  # add the RBF-MMD "EmpShift_MMD" term
    epsilon: float = 0.05  # forwarded as `epsilon` to sinkhorn_divergence
    iters: int = 300  # forwarded as `n_iters` to sinkhorn_divergence
    q: float = 0.90  # quantile level passed to gradient_quantile_proxy
    L_ell: float = 2.0  # multiplier for the ModelChange term (author's note: CE + l2)
    use_features: bool = False  # compare model_Q.embedding features instead of raw inputs
    feature_scale: float = 1.0  # passed as the `ch` multiplier of the shift terms
    sym_model_change: bool = False  # also evaluate ModelChange on source inputs and average
    device: str = "cpu"  # device name; presumably intended for the Sinkhorn solver — confirm


def build_features(X: np.ndarray, model=None, use_features: bool = False) -> np.ndarray:
    """Return the representation on which the shift distances are computed.

    When ``use_features`` is false or ``model`` is None, ``X`` itself is
    returned (the same object, no copy). Otherwise ``X`` is fed through
    ``model.embedding`` (a torch module exposing an ``embedding`` method is
    expected) under ``torch.no_grad()``. The result is returned as a NumPy
    array on the CPU.

    The input is cast to the dtype and device of the model's first parameter,
    so both float32 and float64 models work. Every submodule's train/eval flag
    is restored afterwards, so calling this does not change the model's mode.
    """
    if not use_features or model is None:
        return X
    import torch

    # eval() flips every submodule; remember each one's flag so the caller's
    # configuration (e.g. mid-training, or deliberately frozen submodules)
    # survives this call.
    saved_modes = [(mod, mod.training) for mod in model.modules()]
    model.eval()
    try:
        param = next(model.parameters())
        with torch.no_grad():
            X_t = torch.tensor(X, dtype=param.dtype, device=param.device)
            h = model.embedding(X_t).cpu().numpy()
    finally:
        for mod, mode in saved_modes:
            mod.training = mode
    return h


def empirical_shift(
    Xs: np.ndarray,
    Xt: np.ndarray,
    model_Q,
    model_Qt,
    q: float,
    use_features: bool,
    feature_scale: float,
    sinkhorn_eps: float,
    sinkhorn_iters: int,
    device: str = "cpu",
) -> float:
    """Unsupported standalone entry point for the empirical-shift term.

    The shift term scales a transport distance by gradient-quantile proxies
    of both models. Computing those proxies requires labels, which this
    signature does not receive. Within this module the term is produced by
    ``build_Bhat_terms`` (key ``"EmpShift"``), which is given the labels.
    The signature is kept only for backward compatibility.

    Raises:
        NotImplementedError: always, before inspecting any argument.
    """
    raise NotImplementedError("empirical_shift should be called via Diagnostic.run which supplies Lx and features.")


def model_change(
    X_eval: np.ndarray,
    model_Q,
    model_Qt,
    L_ell: float = 2.0,
    sym_source: Optional[np.ndarray] = None,
) -> float:
    """Return ``L_ell * mean_i ||f_Q(x_i) - f_Qt(x_i)||^2`` over ``X_eval``.

    Per-sample outputs that are compared:
      * sklearn-style classifiers (``predict_proba`` available): log-probabilities;
      * other sklearn-style estimators: ``predict`` output (1-D predictions are
        treated as a single column, multi-output ``(n, k)`` kept as-is);
      * torch modules: the raw forward output (1-D outputs treated as a column).

    Args:
        X_eval: Inputs on which both models are evaluated.
        model_Q: First model.
        model_Qt: Second model.
        L_ell: Multiplier applied to the mean squared output difference.
        sym_source: Optional second input set. When given, the result is
            ``0.5 * L_ell * (diff_eval + diff_source)``.

    Returns:
        The model-change value as a Python float.

    Raises:
        Any exception raised by a classifier's ``predict_proba`` other than
        ``AttributeError`` is propagated instead of being silently replaced
        by label predictions.
    """

    def as_2d(out) -> np.ndarray:
        # Give every output a (n_samples, n_outputs) shape so the per-sample
        # sum over axis 1 is valid and the mean is taken over samples only.
        arr = np.asarray(out)
        return arr.reshape(-1, 1) if arr.ndim == 1 else arr

    def predict_array(m, X) -> np.ndarray:
        # Objects with `predict` but no `weight_decay` attribute are treated as
        # sklearn-style; presumably `weight_decay` marks project torch
        # wrappers that also define `predict` — confirm.
        if hasattr(m, "predict") and not hasattr(m, "weight_decay"):
            try:
                proba = m.predict_proba(X)
            except AttributeError:
                # No probability interface: treat as a regressor.
                return as_2d(m.predict(X))
            # Classifier: compare log-probabilities; epsilon avoids log(0).
            return np.log(proba + 1e-12)

        import torch

        # Restore every submodule's train/eval flag afterwards so this
        # diagnostic does not change the caller's model state.
        saved_modes = [(mod, mod.training) for mod in m.modules()]
        m.eval()
        try:
            param = next(m.parameters())
            with torch.no_grad():
                # Match the model's dtype/device (float32 models would reject float64 input).
                X_t = torch.tensor(X, dtype=param.dtype, device=param.device)
                out = m(X_t).cpu().numpy()
        finally:
            for mod, mode in saved_modes:
                mod.training = mode
        return as_2d(out)

    def mean_sq_diff(X) -> float:
        diff = predict_array(model_Q, X) - predict_array(model_Qt, X)
        return np.mean(np.sum(diff ** 2, axis=1))

    diff2 = mean_sq_diff(X_eval)
    if sym_source is None:
        return float(L_ell * diff2)
    # Symmetric variant: average the target and source discrepancies.
    diff2_s = mean_sq_diff(sym_source)
    return float(0.5 * L_ell * (diff2 + diff2_s))


def diagnostic_value_ot(
    hXs: np.ndarray,
    hXt: np.ndarray,
    Lx_q_Q: float,
    Lx_q_Qt: float,
    ch: float,
    epsilon: float,
    iters: int,
    device: str = "cpu",
) -> float:
    """Turn a debiased Sinkhorn divergence into the OT shift term.

    Computes ``(Lx_q_Q + Lx_q_Qt) * ch * S`` and returns it as a float, where
    ``S = sinkhorn_divergence(hXs, hXt, ...)`` with ``debiased=True``.
    ``epsilon``, ``iters`` and ``device`` are forwarded to the solver
    (``iters`` as ``n_iters``).
    """
    divergence = sinkhorn_divergence(
        hXs,
        hXt,
        epsilon=epsilon,
        n_iters=iters,
        debiased=True,
        device=device,
    )
    # Combined gradient-quantile proxies of both models, scaled by ch.
    scale = (Lx_q_Q + Lx_q_Qt) * ch
    return float(scale * divergence)


def diagnostic_value_mmd(
    hXs: np.ndarray,
    hXt: np.ndarray,
    Lx_q_Q: float,
    Lx_q_Qt: float,
    ch: float,
    mmd_cfg: Optional[MMDConfig] = None,
) -> float:
    """Turn an RBF-kernel MMD into the MMD shift term.

    Returns ``(Lx_q_Q + Lx_q_Qt) * ch * sqrt(max(MMD^2, 0))`` as a float,
    with ``MMD^2`` taken from ``mmd2_rbf(hXs, hXt, config=mmd_cfg)``.
    """
    # The feature vectors are used exactly as given; this function applies no normalisation.
    mmd_squared, _aux = mmd2_rbf(hXs, hXt, config=mmd_cfg)
    # Clamp before the square root; presumably the estimate can dip below zero.
    distance = float(np.sqrt(max(mmd_squared, 0.0)))
    weight = Lx_q_Q + Lx_q_Qt
    return float(weight * ch * distance)


def build_Bhat_terms(
    Xs: np.ndarray,
    ys: np.ndarray,
    Xt: np.ndarray,
    yt: np.ndarray,
    model_Q,
    model_Qt,
    cfg: DiagnosticConfig,
    model_type: str,
) -> dict:
    """Compute the diagnostic terms comparing ``model_Q`` and ``model_Qt``.

    Args:
        Xs: Source inputs. Used for both models' gradient-quantile proxies,
            the source features, and (optionally) symmetric ModelChange.
        ys: Source labels for the gradient-quantile proxies.
        Xt: Target inputs. Used for the target features and ModelChange.
        yt: Target labels; currently unused.
        model_Q: First model. Its embedding supplies features when
            ``cfg.use_features`` is set.
        model_Qt: Second model.
        cfg: Diagnostic settings.
        model_type: ``"ridge"`` (sklearn ridge, MSE) or ``"logistic"``
            (sklearn logistic, CE). Any other value is treated as a torch
            model with CE loss.

    Returns:
        Dict with keys, in order:
          * ``"EmpShift"`` — only if ``cfg.use_sinkhorn``;
          * ``"EmpShift_MMD"`` — only if ``cfg.use_mmd``;
          * ``"ModelChange"``;
          * ``"G_Q_val"`` / ``"G_Qt_val"`` — placeholders, always 0.0 here.
    """
    # Gradient-quantile proxies for both models, evaluated on the source split.
    proxy_type, loss = {
        "ridge": ("sklearn_ridge", "mse"),
        "logistic": ("sklearn_logistic", "ce"),
    }.get(model_type, ("torch", "ce"))
    Lx_Q = gradient_quantile_proxy(model_Q, Xs, ys, model_type=proxy_type, loss=loss, q=cfg.q)
    Lx_Qt = gradient_quantile_proxy(model_Qt, Xs, ys, model_type=proxy_type, loss=loss, q=cfg.q)

    # Raw inputs, or model_Q's embedding when cfg.use_features is set
    # (build_features ignores the model otherwise).
    hXs = build_features(Xs, model_Q, use_features=cfg.use_features)
    hXt = build_features(Xt, model_Q, use_features=cfg.use_features)

    terms = {}
    if cfg.use_sinkhorn:
        terms["EmpShift"] = diagnostic_value_ot(
            hXs,
            hXt,
            Lx_Q,
            Lx_Qt,
            ch=cfg.feature_scale,
            epsilon=cfg.epsilon,
            iters=cfg.iters,
            # Forward the configured device; previously it was dropped and the
            # solver always ran with the "cpu" default.
            device=cfg.device,
        )
    if cfg.use_mmd:
        terms["EmpShift_MMD"] = diagnostic_value_mmd(
            hXs, hXt, Lx_Q, Lx_Qt, ch=cfg.feature_scale, mmd_cfg=MMDConfig()
        )

    terms["ModelChange"] = model_change(
        Xt,
        model_Q,
        model_Qt,
        L_ell=cfg.L_ell,
        sym_source=(Xs if cfg.sym_model_change else None),
    )
    terms["G_Q_val"] = 0.0
    terms["G_Qt_val"] = 0.0
    return terms


