from __future__ import annotations

import argparse
from dataclasses import dataclass
from pathlib import Path
from typing import Dict

import numpy as np
import yaml

from .utils.seed import seed_all
from .utils.io import ensure_dir, timestamp_now, save_json, save_csv
from .metrics.risk import risk_ce_logits
from .metrics.grads import gradient_quantile_proxy
from .metrics.spearman import spearman_rho
from .metrics.mmd import MMDConfig
from .models.logistic import LogisticConfig, train_logistic
from .models.mlp import MLPConfig, train_mlp_classifier
from .data_gen.blobs_shift import BlobsShiftConfig, sample_blobs_shift
from .data_gen.moons_warp import MoonsWarpConfig, sample_moons_warp
from .eval.diagnostic import diagnostic_value_ot, diagnostic_value_mmd, model_change


def parse_args() -> argparse.Namespace:
    """Define the benchmark's command-line options and parse ``sys.argv``.

    Options (all optional except ``--world``):
        --world: synthetic world to sweep, ``blobs_shift`` or ``moons_warp``.
        --config / --grid: paths to YAML files (settings / grid values).
        --seed: integer seed, default 0.
        --use_sinkhorn / --use_mmd / --use_features / --sym_model_change:
            boolean switches, off unless given.
        --q: float, default 0.90.
        --sinkhorn_eps: float, default 0.05.
        --iters: int, default 300.
        --device: string, default ``"cpu"``.
        --output: output directory path, default None.

    Returns:
        The parsed ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--world", type=str, choices=["blobs_shift", "moons_warp"], required=True)

    # Optional YAML file paths.
    for path_option in ("--config", "--grid"):
        parser.add_argument(path_option, type=str, default=None)

    parser.add_argument("--seed", type=int, default=0)

    # Boolean switches that default to False.
    for switch in ("--use_sinkhorn", "--use_mmd", "--use_features", "--sym_model_change"):
        parser.add_argument(switch, action="store_true")

    # Typed scalar options, registered in the original order.
    scalar_options = (
        ("--q", float, 0.90),
        ("--sinkhorn_eps", float, 0.05),
        ("--iters", int, 300),
        ("--device", str, "cpu"),
        ("--output", str, None),
    )
    for option, converter, fallback in scalar_options:
        parser.add_argument(option, type=converter, default=fallback)

    return parser.parse_args()


    


def run_blobs(cfg: dict, args: argparse.Namespace) -> Dict:
    """Sweep the blobs-shift world and collect one result row per grid cell.

    For every (rotation, translation, n) combination taken from ``cfg["grid"]``
    (or the built-in defaults) this samples a source/target pair, fits one
    logistic-regression model per domain on a training split, and records the
    observed risk difference ``delta_R`` next to the estimated bounds
    ``Bhat_ot`` / ``Bhat_mmd`` and each of their terms.

    Args:
        cfg: Parsed configuration mapping; only its ``"grid"`` entry is read.
        args: Parsed CLI arguments. ``args.use_features`` has no effect here:
            shift terms are always measured on the raw inputs.

    Returns:
        ``{"rows": [...]}`` with one dict per grid cell.
    """
    seed_all(args.seed)
    results = []
    grid = cfg.get("grid") or {}
    rotations = grid.get("rotation_deg", [0.0, 10.0, 20.0, 45.0])
    translations = grid.get("translation", [0.0, 0.25, 0.5, 1.0, 2.0])
    ns = grid.get("n", [1000, 2000, 5000])

    def holdout_split(X, y, max_val=500):
        # Hold out the first rows for validation and train on the rest. The
        # sets must be disjoint for |R_train - R_val| to be a generalization
        # gap. Capping at half the sample keeps the training part non-empty.
        n_val = min(max_val, X.shape[0] // 2)
        return X[n_val:], y[n_val:], X[:n_val], y[:n_val]

    def log_proba_logits(model, X):
        # softmax(log p) == p, so log-probabilities can stand in for logits in CE;
        # the epsilon guards against log(0).
        return np.log(model.predict_proba(X) + 1e-12)

    # Large probe set for the risk difference. It depends only on the seed, so
    # it is drawn once instead of being regenerated for every grid cell.
    # NOTE(review): a standard-normal probe is not the blobs source
    # distribution — confirm this is the intended test distribution.
    rng = np.random.default_rng(args.seed)
    Xtest = rng.normal(size=(100_000, 2))

    for rot in rotations:
        for tr in translations:
            for n in ns:
                ds_cfg = BlobsShiftConfig(rotation_deg=rot, translation=tr, n_source=n, n_target=n, seed=args.seed)
                Xs, ys, Xt, yt = sample_blobs_shift(ds_cfg)
                Xs_tr, ys_tr, Xval, yval = holdout_split(Xs, ys)
                Xt_tr, yt_tr, Xt_val, yt_val = holdout_split(Xt, yt)

                # One logistic model per domain, fitted on the training split only.
                logi = train_logistic(Xs_tr, ys_tr, LogisticConfig(max_iter=200, random_state=args.seed))
                logi_t = train_logistic(Xt_tr, yt_tr, LogisticConfig(max_iter=200, random_state=args.seed))

                # Generalization gaps |R_train - R_val| for each domain's model.
                G_Q_val = abs(
                    risk_ce_logits(log_proba_logits(logi, Xs_tr), ys_tr)
                    - risk_ce_logits(log_proba_logits(logi, Xval), yval)
                )
                G_Qt_val = abs(
                    risk_ce_logits(log_proba_logits(logi_t, Xt_tr), yt_tr)
                    - risk_ce_logits(log_proba_logits(logi_t, Xt_val), yt_val)
                )

                # Risk difference on the probe set. Labels come from the source
                # model's predictions so both models are scored on the same labels.
                ytest = logi.predict(Xtest)
                RQ = risk_ce_logits(log_proba_logits(logi, Xtest), ytest)
                RQt = risk_ce_logits(log_proba_logits(logi_t, Xtest), ytest)
                delta_R = abs(RQ - RQt)

                # Gradient-quantile proxies, evaluated on the held-out splits.
                Lx_Q = gradient_quantile_proxy(logi, Xval, yval, model_type="sklearn_logistic", loss="ce", q=args.q)
                Lx_Qt = gradient_quantile_proxy(logi_t, Xt_val, yt_val, model_type="sklearn_logistic", loss="ce", q=args.q)

                # Shift terms are computed on the raw inputs (no learned features here).
                term_emp_ot = (
                    diagnostic_value_ot(Xs, Xt, Lx_Q, Lx_Qt, ch=1.0, epsilon=args.sinkhorn_eps, iters=args.iters, device=args.device)
                    if args.use_sinkhorn
                    else 0.0
                )
                term_emp_mmd = (
                    diagnostic_value_mmd(Xs, Xt, Lx_Q, Lx_Qt, ch=1.0, mmd_cfg=MMDConfig())
                    if args.use_mmd
                    else 0.0
                )
                term_mc = model_change(Xt, logi, logi_t, L_ell=2.0, sym_source=(Xs if args.sym_model_change else None))

                # Each bound = shift term + model change + both generalization gaps.
                Bhat_ot = term_emp_ot + term_mc + G_Q_val + G_Qt_val
                Bhat_mmd = term_emp_mmd + term_mc + G_Q_val + G_Qt_val

                results.append({
                    "world": "blobs_shift",
                    "rot": rot,
                    "tr": tr,
                    "n": n,
                    "delta_R": float(delta_R),
                    "Bhat_ot": float(Bhat_ot),
                    "Bhat_mmd": float(Bhat_mmd),
                    "term_EmpShift_ot": float(term_emp_ot),
                    "term_EmpShift_mmd": float(term_emp_mmd),
                    "term_ModelChange": float(term_mc),
                    "term_G_Q_val": float(G_Q_val),
                    "term_G_Qt_val": float(G_Qt_val),
                })
    return {"rows": results}


def run_moons(cfg: dict, args: argparse.Namespace) -> Dict:
    """Sweep the warped-moons world and collect one result row per grid cell.

    For every (alpha, n) combination taken from ``cfg["grid"]`` (or the
    built-in defaults) this samples a source/target pair, fits one MLP per
    domain on a training split, and records the observed risk difference
    ``delta_R`` next to the estimated bounds ``Bhat_ot`` / ``Bhat_mmd`` and
    each of their terms. With ``args.use_features`` the shift terms are
    measured on the source MLP's ``embedding`` output instead of raw inputs.

    Args:
        cfg: Parsed configuration mapping; reads ``"grid"`` and ``"noise"``.
        args: Parsed CLI arguments.

    Returns:
        ``{"rows": [...]}`` with one dict per grid cell.
    """
    # Local imports, hoisted out of the grid loop.
    import torch
    from sklearn.datasets import make_moons

    seed_all(args.seed)
    results = []
    grid = cfg.get("grid") or {}
    alphas = grid.get("alpha", [0.0, 0.25, 0.5, 1.0, 2.0])
    ns = grid.get("n", [1000, 2000, 5000])
    # NOTE(review): only wds[0] is used below; the other weight decays are never swept.
    wds = grid.get("weight_decay", [0.0, 1e-4, 1e-3])
    noise = cfg.get("noise", 0.1)

    def holdout_split(X, y, max_val=500):
        # Hold out the first rows for validation and train on the rest. The
        # sets must be disjoint for |R_train - R_val| to be a generalization
        # gap. Capping at half the sample keeps the training part non-empty.
        n_val = min(max_val, X.shape[0] // 2)
        return X[n_val:], y[n_val:], X[:n_val], y[:n_val]

    def logits(model, X):
        # Forward pass in eval mode without autograd; inputs are cast to float64.
        model.eval()
        with torch.no_grad():
            return model(torch.tensor(X, dtype=torch.float64)).cpu().numpy()

    def embed(model, X):
        # Same as logits(), but returns the model's embedding output.
        model.eval()
        with torch.no_grad():
            return model.embedding(torch.tensor(X, dtype=torch.float64)).cpu().numpy()

    # Labelled probe set for the risk difference. It depends only on the seed
    # and noise, so it is drawn once. It uses the same noise level as the
    # training data so both are drawn from a matching distribution.
    # NOTE(review): assumes the moons_warp source is plain make_moons at this noise — confirm.
    Xtest, ytest = make_moons(n_samples=100_000, noise=noise, random_state=args.seed)

    for a in alphas:
        for n in ns:
            mw_cfg = MoonsWarpConfig(n_source=n, n_target=n, noise=noise, alpha=a, seed=args.seed)
            Xs, ys, Xt, yt = sample_moons_warp(mw_cfg)
            Xs_tr, ys_tr, Xval, yval = holdout_split(Xs, ys)
            Xt_tr, yt_tr, Xt_val, yt_val = holdout_split(Xt, yt)

            # One MLP per domain, fitted on the training split only.
            mlp = train_mlp_classifier(Xs_tr, ys_tr, MLPConfig(weight_decay=wds[0], epochs=100, device="cpu"))
            mlp_t = train_mlp_classifier(Xt_tr, yt_tr, MLPConfig(weight_decay=wds[0], epochs=100, device="cpu"))

            # Generalization gaps |R_train - R_val| for each domain's model.
            G_Q_val = abs(
                risk_ce_logits(logits(mlp, Xs_tr), ys_tr)
                - risk_ce_logits(logits(mlp, Xval), yval)
            )
            G_Qt_val = abs(
                risk_ce_logits(logits(mlp_t, Xt_tr), yt_tr)
                - risk_ce_logits(logits(mlp_t, Xt_val), yt_val)
            )

            # Risk difference of the two models on the labelled probe set.
            RQ = risk_ce_logits(logits(mlp, Xtest), ytest)
            RQt = risk_ce_logits(logits(mlp_t, Xtest), ytest)
            delta_R = abs(RQ - RQt)

            # Gradient-quantile proxies, evaluated on the held-out splits.
            Lx_Q = gradient_quantile_proxy(mlp, Xval, yval, model_type="torch", loss="ce", q=args.q)
            Lx_Qt = gradient_quantile_proxy(mlp_t, Xt_val, yt_val, model_type="torch", loss="ce", q=args.q)

            # Both domains are mapped through the *source* MLP so the shift is
            # measured in a single shared feature space.
            hXs = embed(mlp, Xs) if args.use_features else Xs
            hXt = embed(mlp, Xt) if args.use_features else Xt

            term_emp_ot = (
                diagnostic_value_ot(hXs, hXt, Lx_Q, Lx_Qt, ch=1.0, epsilon=args.sinkhorn_eps, iters=args.iters, device=args.device)
                if args.use_sinkhorn
                else 0.0
            )
            term_emp_mmd = (
                diagnostic_value_mmd(hXs, hXt, Lx_Q, Lx_Qt, ch=1.0, mmd_cfg=MMDConfig())
                if args.use_mmd
                else 0.0
            )
            term_mc = model_change(Xt, mlp, mlp_t, L_ell=2.0, sym_source=(Xs if args.sym_model_change else None))

            # Each bound = shift term + model change + both generalization gaps.
            Bhat_ot = term_emp_ot + term_mc + G_Q_val + G_Qt_val
            Bhat_mmd = term_emp_mmd + term_mc + G_Q_val + G_Qt_val

            results.append({
                "world": "moons_warp",
                "alpha": a,
                "n": n,
                "delta_R": float(delta_R),
                "Bhat_ot": float(Bhat_ot),
                "Bhat_mmd": float(Bhat_mmd),
                "term_EmpShift_ot": float(term_emp_ot),
                "term_EmpShift_mmd": float(term_emp_mmd),
                "term_ModelChange": float(term_mc),
                "term_G_Q_val": float(G_Q_val),
                "term_G_Qt_val": float(G_Qt_val),
            })
    return {"rows": results}


def main():
    """Command-line entry point: run one world's sweep and write its reports.

    Writes ``results.json``, ``results.csv`` (best effort; needs pandas) and
    ``summary.json`` to ``--output`` or to a timestamped directory under
    ``trace_bench/reports/``. The summary holds Spearman correlations between
    ``delta_R`` and each enabled bound.

    Raises:
        ValueError: if the ``--config`` file does not contain a YAML mapping.
    """
    import sys

    args = parse_args()
    seed_all(args.seed)

    cfg = {}
    if args.config is not None:
        # An empty YAML document loads as None; treat it as "no settings"
        # instead of crashing later on cfg["grid"] / cfg.get(...).
        cfg = yaml.safe_load(Path(args.config).read_text(encoding="utf-8")) or {}
        if not isinstance(cfg, dict):
            raise ValueError(
                f"--config {args.config} must contain a YAML mapping, got {type(cfg).__name__}"
            )
    if args.grid is not None:
        grid_cfg = yaml.safe_load(Path(args.grid).read_text(encoding="utf-8"))
        if isinstance(grid_cfg, dict):
            # Accept either {"grid": {...}} or the bare grid mapping; it
            # replaces any grid that came from --config.
            cfg["grid"] = grid_cfg.get("grid", grid_cfg)
        else:
            print(f"Warning: --grid {args.grid} is not a YAML mapping; ignoring it.", file=sys.stderr)

    run = run_blobs if args.world == "blobs_shift" else run_moons
    out = run(cfg, args)

    outdir = Path(args.output or f"trace_bench/reports/{timestamp_now()}")
    ensure_dir(outdir)
    save_json(out, outdir / "results.json")

    # The CSV is a convenience copy of results.json: report failures, don't abort.
    try:
        import pandas as pd
    except ImportError:
        print("Warning: pandas not available; skipping results.csv", file=sys.stderr)
    else:
        try:
            save_csv(pd.DataFrame(out["rows"]), outdir / "results.csv")
        except Exception as err:  # best effort at the top-level boundary
            print(f"Warning: could not write results.csv: {err}", file=sys.stderr)

    # Summary metrics (only Spearman correlations of delta_R vs. each bound).
    rows = out["rows"]
    D = np.array([r["delta_R"] for r in rows])
    summary: Dict[str, float] = {}
    if args.use_sinkhorn:
        summary["spearman_ot"] = spearman_rho(D, np.array([r["Bhat_ot"] for r in rows]))
    if args.use_mmd:
        summary["spearman_mmd"] = spearman_rho(D, np.array([r["Bhat_mmd"] for r in rows]))

    save_json(summary, outdir / "summary.json")
    print(f"Saved to {outdir}. Summary: {summary}")


if __name__ == "__main__":
    # Run the CLI only when this module is executed directly, not when imported.
    main()


