from __future__ import annotations
import json
import os
from copy import deepcopy
from dataclasses import asdict, dataclass
from typing import Dict, List, Literal, Optional, Tuple

import numpy as np
from tqdm import tqdm

from .algorithms.gp_baselines import GPConfig, GPGreedyUCB
from .algorithms.linucb import LinUCBConfig, PerUserLinUCB, PooledLinUCB
from .algorithms.graphucb import GraphUCB, GraphUCBConfig
from .algorithms.lk_gp_ucb import LKGPUCB, LKGPUCBConfig
from .algorithms.lk_gp_ts import LKGPTS, LKGPTSConfig
from .algorithms.coop_kernel_ucb import CoopKernelUCB, CoopKernelUCBConfig
from .contexts import sample_context_pool
from .graph import er_graph, rbf_graph_from_latents, sbm_graph, spectral_ratio
from .regimes import (
    regime1_linear_gob,
    regime2A_gp_draw,
    regime2B_representer_draw,
    regime2_kernel_prepare,
)
from .utils import set_all_seeds


@dataclass
class BenchmarkConfig:
    """Settings for one benchmark experiment: environment, algorithms, repetitions.

    The comments below describe how this module consumes each field. Values
    that are only forwarded to helpers defined in other modules (graph
    builders, reward regimes, algorithm configs) keep whatever meaning those
    helpers give them; confirm details there.
    """

    # Experiment name; results are written under ``out/<exp_name>``.
    exp_name: str = "benchmark"
    # Base seed; repetition ``r`` uses the environment seed ``seed + 1000 * r``.
    seed: int = 0

    # Reward regime. "linear_gob" uses regime1_linear_gob, "kernel_optionA"
    # uses regime2A_gp_draw; any other value takes the regime2B_representer_draw
    # branch.
    regime: Literal["linear_gob", "kernel_optionA", "kernel_optionB"] = "linear_gob"
    # Number of users; the active user of each round is drawn from [0, n_users).
    n_users: int = 20
    # Context dimension (passed to sample_context_pool and the LinUCB baselines).
    d: int = 5
    # Size of the arm/context pool; arms are indexed 0..M-1.
    M: int = 20
    # Number of interaction rounds per environment.
    T: int = 1000
    # Number of candidate arms offered per round (drawn without replacement).
    m: int = 5

    # Which graph builder build_graph() dispatches to.
    graph_type: Literal["ER", "RBF", "SBM"] = "ER"

    # ER params (second positional argument of er_graph; presumably the edge
    # probability -- confirm in .graph)
    p: float = 0.2

    # RBF params (forwarded to rbf_graph_from_latents as rho / latent dim / threshold)
    rho_graph: float = 1.0
    q_latent: int = 4
    threshold: float = 0.0

    # SBM params (used only when graph_type == "SBM")
    sbm_n_blocks: int = 4
    sbm_p_in: float = 0.4
    sbm_p_out: float = 0.05
    sbm_weight: float = 1.0

    # Forwarded to regime1_linear_gob (only used in the "linear_gob" regime).
    eta: float = 1.0

    # Arm-kernel settings shared by regime2_kernel_prepare and the kernel algorithms.
    base_kernel: Literal["SE", "Matern52"] = "SE"
    lengthscale: float = 1.0
    # Passed as ``rho`` to regime2_kernel_prepare and the LK-GP / Coop configs,
    # and as ``rho_lap`` to GraphUCB (GOB.Lin defaults to 1.0 instead).
    rho_lap: float = 0.1

    # Std of the Gaussian reward noise in the "linear_gob" regime. The kernel
    # regimes ignore this and derive the noise level from the range of F.
    sigma: float = 0.1

    # Algorithm names to run; None selects the built-in default list.
    algos: Optional[List[str]] = None

    # Default hyper-parameters, overridable per algorithm through ``algo_params``:
    # beta -> LK-GP-UCB / Coop-KernelUCB / GP-UCB, nu -> LK-GP-TS,
    # alpha_lin -> GraphUCB / GOB.Lin / LinUCB variants.
    beta: float = 1.0
    nu: float = 1.0
    alpha_lin: float = 1.0
    # Passed as ``lambda_reg`` to the GraphUCB and LinUCB configs (no override hook).
    lambda_lin: float = 1e-2
    lambda_gp: float = 1e-3  # interpreted as λ_base for GP-style algos

    # Number of repetitions; each one runs on a freshly seeded environment.
    R: int = 20

    # Optional per-algorithm overrides: {algo name: {parameter name: value}}.
    algo_params: dict | None = None


def build_graph(cfg: BenchmarkConfig, seed: int) -> np.ndarray:
    """Create the user graph requested by ``cfg.graph_type``.

    Dispatches to ``er_graph`` ("ER"), ``rbf_graph_from_latents`` ("RBF") or
    ``sbm_graph`` ("SBM"), forwarding the matching fields of ``cfg`` together
    with ``seed``, and returns the builder's result unchanged.

    Args:
        cfg: benchmark settings; only the fields relevant to the selected
            graph type are read.
        seed: seed handed to the graph builder.

    Returns:
        The matrix produced by the selected builder.

    Raises:
        ValueError: if ``cfg.graph_type`` is not one of "ER", "RBF", "SBM".
    """
    kind = cfg.graph_type

    if kind == "ER":
        # ER graphs always use unit edge weight here.
        graph = er_graph(cfg.n_users, cfg.p, weight=1.0, seed=seed)
    elif kind == "RBF":
        graph = rbf_graph_from_latents(
            cfg.n_users,
            cfg.q_latent,
            rho=cfg.rho_graph,
            threshold=cfg.threshold,
            seed=seed,
        )
    elif kind == "SBM":
        graph = sbm_graph(
            cfg.n_users,
            cfg.sbm_n_blocks,
            cfg.sbm_p_in,
            cfg.sbm_p_out,
            weight=cfg.sbm_weight,
            seed=seed,
        )
    else:
        raise ValueError(f"Unknown graph_type: {cfg.graph_type}")

    return graph



def make_outdir(exp_name: str) -> str:
    """Ensure ``out/<exp_name>`` exists (relative to the cwd) and return its path.

    An already existing directory is not an error.
    """
    target = os.path.join("out", exp_name)
    os.makedirs(name=target, exist_ok=True)
    return target


def instantiate_algo(
    name: str,
    cfg: BenchmarkConfig,
    X: np.ndarray,
    W: np.ndarray,
    BK0: Optional[np.ndarray],
    run_seed: int,
):
    """Build a fresh algorithm object for the benchmark name ``name``.

    Hyper-parameters come from ``cfg``; entries in ``cfg.algo_params[name]``
    (if any) take precedence for the keys each algorithm consults.

    Supported names: "LK-GP-UCB", "LK-GP-TS", "GraphUCB", "GOB.Lin", "GP-UCB",
    "Pooled-LinUCB", "PerUser-LinUCB", "Coop-KernelUCB" and the fixed-kernel
    aliases "Coop-KernelUCB (laplacian_inv|learned_mmd|spectral_rbf|heat|all ones)".

    Args:
        name: algorithm name (see above).
        cfg: benchmark settings supplying default hyper-parameters.
        X: context pool, forwarded to the algorithms that take it.
        W: user graph matrix, forwarded to the graph-aware algorithms.
        BK0: precomputed product kernel, forwarded to the LK-GP algorithms only.
        run_seed: base seed; LK-GP-TS is seeded with ``run_seed + 99``.

    Returns:
        The constructed algorithm instance.

    Raises:
        ValueError: if ``name`` is not a supported algorithm.
    """
    user_opts = (cfg.algo_params or {}).get(name, {})

    def opt(key, fallback):
        # A per-algorithm override wins over the supplied fallback.
        return user_opts.get(key, fallback)

    def kernel_base():
        # Settings common to every GP / product-kernel config.
        return {
            "lambda_noise": opt("lambda_gp", cfg.lambda_gp),
            "base_kernel": cfg.base_kernel,
            "lengthscale": cfg.lengthscale,
        }

    def linucb_config():
        return LinUCBConfig(
            alpha=opt("alpha_lin", cfg.alpha_lin),
            lambda_reg=cfg.lambda_lin,
        )

    def spectral_opts():
        # Options consulted by the "spectral_rbf" agent kernel.
        return {
            "spec_k": opt("spec_k", 8),
            "spec_sigma": opt("spec_sigma", "median"),
        }

    def mmd_opts():
        # Options consulted by the "learned_mmd" agent kernel.
        return {
            "mmd_mode": opt("mmd_mode", "rff"),
            "rff_dim": opt("rff_dim", 10),
            "mmd_sigma": opt("mmd_sigma", "median"),
            "update_every": opt("update_every", 200),
            "min_count": opt("min_count", 5),
        }

    def coop(agent_kernel, **extra):
        # Every Coop-KernelUCB variant shares this skeleton; ``extra`` holds
        # only the kernel-specific options that variant passes explicitly, so
        # anything omitted falls back to CoopKernelUCBConfig's own defaults.
        return CoopKernelUCB(
            cfg.n_users,
            X,
            W,
            CoopKernelUCBConfig(
                beta=opt("beta", cfg.beta),
                agent_kernel=agent_kernel,
                **kernel_base(),
                **extra,
            ),
        )

    # --- Laplacian-kernel GP algorithms ---
    if name == "LK-GP-UCB":
        ucb_cfg = LKGPUCBConfig(
            beta=opt("beta", cfg.beta),
            rho=cfg.rho_lap,
            **kernel_base(),
        )
        return LKGPUCB(cfg.n_users, X, W, ucb_cfg, BK0=BK0)

    if name == "LK-GP-TS":
        ts_cfg = LKGPTSConfig(
            nu=opt("nu", cfg.nu),
            rho=cfg.rho_lap,
            **kernel_base(),
        )
        return LKGPTS(cfg.n_users, X, W, ts_cfg, BK0=BK0, seed=run_seed + 99)

    # --- graph-regularised linear algorithms ---
    if name in ("GraphUCB", "GOB.Lin"):
        # GOB.Lin defaults to rho_lap = 1.0; GraphUCB follows the config.
        default_rho = 1.0 if name == "GOB.Lin" else cfg.rho_lap
        graph_cfg = GraphUCBConfig(
            alpha=opt("alpha_lin", cfg.alpha_lin),
            lambda_reg=cfg.lambda_lin,
            rho_lap=opt("rho_lap", default_rho),
        )
        return GraphUCB(X, W, graph_cfg)

    # --- graph-agnostic baselines ---
    if name == "GP-UCB":
        return GPGreedyUCB(X, GPConfig(**kernel_base()), beta=opt("beta", cfg.beta))

    if name == "Pooled-LinUCB":
        return PooledLinUCB(cfg.d, linucb_config())

    if name == "PerUser-LinUCB":
        return PerUserLinUCB(cfg.n_users, cfg.d, linucb_config())

    # --- Coop-KernelUCB main entry (pluggable agent kernel) ---
    if name == "Coop-KernelUCB":
        return coop(
            opt("agent_kernel", "learned_mmd"),
            rho=opt("rho", cfg.rho_lap),  # laplacian_inv
            tau=opt("tau", 1.0),  # heat
            **spectral_opts(),
            **mmd_opts(),
        )

    # --- Coop-KernelUCB aliases so variants can coexist in one run ---
    if name == "Coop-KernelUCB (laplacian_inv)":
        return coop("laplacian_inv", rho=opt("rho", cfg.rho_lap))

    if name == "Coop-KernelUCB (learned_mmd)":
        return coop("learned_mmd", **mmd_opts())

    if name == "Coop-KernelUCB (spectral_rbf)":
        return coop("spectral_rbf", **spectral_opts())

    if name == "Coop-KernelUCB (heat)":
        return coop("heat", tau=opt("tau", 1.0))

    if name == "Coop-KernelUCB (all ones)":
        return coop("all ones")

    raise ValueError(f"Unknown algo: {name}")


def _sanitize_key(name: str) -> str:
    return (
        name.lower()
        .replace("-", "_")
        .replace(" ", "_")
        .replace("(", "")
        .replace(")", "")
        .replace("{", "")
        .replace("}", "")
        .replace("^", "")
    )


# -------- λ scheduling utilities --------
def _doubling_schedule(T: int, t0_hint: int = 200) -> List[int]:
    """Epoch boundaries: t0, 2*t0, 4*t0, ... <= T-1."""
    t0 = max(50, min(t0_hint, max(1, T // 10)))
    bounds = []
    k = t0
    while k < T:
        bounds.append(k)
        k *= 2
    return bounds


def _compute_lambda(
    base: float,
    S_spec: float,
    T: int,
    t: int,
    lam_min: float = 1e-6,
    lam_max: float = 1e-1,
) -> float:
    """λ_t = base * S_spec * (T / (T + t)), clipped."""
    lam = float(base) * float(S_spec) * (float(T) / float(T + max(0, t)))
    return float(min(lam_max, max(lam_min, lam)))


def _cfg_with_lambda(cfg: BenchmarkConfig, lam_val: float) -> BenchmarkConfig:
    cfg2 = deepcopy(cfg)
    cfg2.lambda_gp = float(lam_val)
    ap = deepcopy(cfg2.algo_params or {})
    for nm in [
        "LK-GP-UCB",
        "LK-GP-TS",
        "Coop-KernelUCB",
        "Coop-KernelUCB (laplacian_inv)",
        "Coop-KernelUCB (learned_mmd)",
        "Coop-KernelUCB (spectral_rbf)",
        "Coop-KernelUCB (heat)",
        "Coop-KernelUCB (all ones)",
        "GP-UCB",
    ]:
        ap[nm] = {**ap.get(nm, {}), "lambda_gp": float(lam_val)}
    cfg2.algo_params = ap
    return cfg2


def _rebuild_algo(
    name: str,
    lam_val: float,
    cfg: BenchmarkConfig,
    X: np.ndarray,
    W: np.ndarray,
    BK0: Optional[np.ndarray],
    run_seed: int,
    buffer: List[Dict[str, object]],
):
    """Create algorithm ``name`` afresh with λ = ``lam_val`` and replay its history.

    Args:
        name: algorithm name understood by ``instantiate_algo``.
        lam_val: λ value applied through ``_cfg_with_lambda``.
        cfg: original benchmark settings (not modified).
        X, W, BK0, run_seed: forwarded to ``instantiate_algo``.
        buffer: the algorithm's own past observations in order; each record
            carries the keys "u" (user), "m" (arm), "y" (reward) and
            "x" (context).

    Returns:
        The new algorithm object after all buffered updates were applied.

    Raises:
        ValueError: if a record has to be replayed for an unknown algorithm.
    """
    # Algorithms whose update signature is (user, arm, reward).
    user_arm_reward = {
        "LK-GP-UCB",
        "LK-GP-TS",
        "Coop-KernelUCB",
        "Coop-KernelUCB (laplacian_inv)",
        "Coop-KernelUCB (learned_mmd)",
        "Coop-KernelUCB (spectral_rbf)",
        "Coop-KernelUCB (heat)",
        "Coop-KernelUCB (all ones)",
        "GraphUCB",
        "GOB.Lin",
    }

    algo = instantiate_algo(name, _cfg_with_lambda(cfg, lam_val), X, W, BK0, run_seed)

    for record in buffer:
        user, arm, reward, context = record["u"], record["m"], record["y"], record["x"]
        if name in user_arm_reward:
            algo.update(user, arm, reward)
        elif name == "GP-UCB":
            algo.update(arm, reward)
        elif name == "PerUser-LinUCB":
            algo.update(user, context, reward)
        elif name == "Pooled-LinUCB":
            algo.update(context, reward)
        else:
            raise ValueError(f"Unknown algo during replay: {name}")

    return algo


# --------------------------
# Core: one environment run
# --------------------------
def run_one_environment(cfg: BenchmarkConfig, env_seed: int) -> Tuple[Dict[str, list], List[Dict[str, float]]]:
    """Run every configured algorithm on one freshly sampled environment.

    All algorithms face the same user sequence, candidate sets and noise
    draws. λ for the GP-style algorithms follows
    ``_compute_lambda(lambda_gp, S_spec, T, t)``; it is re-evaluated at the
    boundaries from ``_doubling_schedule`` and, when it moved by at least 20%
    relative to the value in use, those algorithms are rebuilt with the new λ
    and their own history is replayed.

    Args:
        cfg: benchmark settings.
        env_seed: seed of this environment; fixed offsets of it seed the
            context pool, the reward function and the interaction sequences.

    Returns:
        ``(regrets, lambda_trace)``: ``regrets[name]`` is the list of per-round
        instantaneous regrets of algorithm ``name``; ``lambda_trace`` lists the
        λ values computed at t=0 and at every epoch boundary.
    """
    # ---- environment: graph, context pool, product kernel ----
    W = build_graph(cfg, env_seed)
    X = sample_context_pool(cfg.M, cfg.d, seed=env_seed + 19)
    K_user, K_arm = regime2_kernel_prepare(
        cfg.n_users,
        cfg.M,
        X,
        W,
        rho=cfg.rho_lap,
        base_kernel=cfg.base_kernel,
        lengthscale=cfg.lengthscale,
    )
    BK0 = np.kron(K_user, K_arm)

    # ---- ground-truth reward function and noise level ----
    if cfg.regime == "linear_gob":
        Xpool, _theta, f_eval = regime1_linear_gob(
            cfg.n_users, cfg.d, cfg.M, cfg.eta, W, seed=env_seed + 23
        )
        sigma = cfg.sigma

        def f_on(u, m):
            return f_eval(u, Xpool[m])

    else:
        if cfg.regime == "kernel_optionA":
            F = regime2A_gp_draw(cfg.n_users, cfg.M, K_user, K_arm, seed=env_seed + 37)
        else:
            F = regime2B_representer_draw(
                cfg.n_users, cfg.M, K_user, K_arm, tau=1.0, seed=env_seed + 41
            )

        # Noise std is 1% of the reward range (of 1.0 if F is constant).
        span = float(F.max() - F.min())
        scale = span if span > 0 else 1.0
        sigma = max(1e-6, 0.01 * scale)

        def f_on(u, m):
            return float(F[u, m])

    # ---- shared interaction sequences (draw order matters for the RNG) ----
    rng = np.random.default_rng(env_seed + 777)
    u_seq = rng.integers(low=0, high=cfg.n_users, size=cfg.T)
    cand_seq = [rng.choice(cfg.M, size=cfg.m, replace=False) for _ in range(cfg.T)]
    eps_seq = rng.normal(0.0, sigma, size=cfg.T)

    algos = cfg.algos or [
        "LK-GP-UCB",
        "LK-GP-TS",
        "Coop-KernelUCB",
        "Coop-KernelUCB (heat)",
        "GOB.Lin",
        "GraphUCB",
        "GP-UCB",
        "Pooled-LinUCB",
        "PerUser-LinUCB",
    ]

    # Algorithms sharing the select(t, user, cands) / update(user, arm, y) interface.
    kernel_family = frozenset(
        {
            "LK-GP-UCB",
            "LK-GP-TS",
            "Coop-KernelUCB",
            "Coop-KernelUCB (laplacian_inv)",
            "Coop-KernelUCB (learned_mmd)",
            "Coop-KernelUCB (spectral_rbf)",
            "Coop-KernelUCB (heat)",
            "Coop-KernelUCB (all ones)",
        }
    )
    graph_family = frozenset({"GraphUCB", "GOB.Lin"})
    # Algorithms whose regulariser follows the λ schedule.
    lambda_sensitive = kernel_family | {"GP-UCB"}

    # ---- λ schedule set-up ----
    L = np.diag(W.sum(axis=1)) - W
    S_spec = spectral_ratio(L)
    if S_spec <= 0.0:
        S_spec = 1e-3

    T = cfg.T
    lambda_base = float(cfg.lambda_gp)

    def trace_entry(step, lam):
        return {
            "t": int(step),
            "lambda": float(lam),
            "S_spec": float(S_spec),
            "lambda_base": float(lambda_base),
        }

    lam0 = _compute_lambda(lambda_base, S_spec, T, t=0)
    curr_lambda = lam0
    lambda_trace = [trace_entry(0, lam0)]
    pending_bounds = iter(_doubling_schedule(T, t0_hint=200))
    next_bound = next(pending_bounds, None)

    # ---- algorithm objects, regret logs and replay buffers ----
    cfg0 = _cfg_with_lambda(cfg, lam0)
    algo_objs = {
        name: instantiate_algo(
            name, cfg0 if name in lambda_sensitive else cfg, X, W, BK0, env_seed
        )
        for name in algos
    }
    regrets = {name: [] for name in algos}
    algo_buffers: Dict[str, List[Dict[str, object]]] = {name: [] for name in algos}

    # -------- main interaction loop --------
    for t in range(T):
        if t == next_bound:
            new_lambda = _compute_lambda(lambda_base, S_spec, T, t)
            lambda_trace.append(trace_entry(t, new_lambda))
            # Rebuild only when λ moved by at least 20% relative to the value in use.
            if abs(new_lambda / curr_lambda - 1.0) >= 0.20:
                for name in algos:
                    if name not in lambda_sensitive:
                        continue
                    algo_objs[name] = _rebuild_algo(
                        name, new_lambda, cfg, X, W, BK0, env_seed, algo_buffers[name]
                    )
                curr_lambda = new_lambda
            next_bound = next(pending_bounds, None)

        # Round t: same user, candidates and noise for every algorithm.
        u_t = int(u_seq[t])
        cands = cand_seq[t]
        best_val = max([f_on(u_t, m) for m in cands])

        for name, agent in algo_objs.items():
            # 1) arm selection
            if name == "Pooled-LinUCB":
                pick = agent.select(np.asarray([X[m] for m in cands]))
                m_sel = int(cands[pick])
            elif name == "PerUser-LinUCB":
                pick = agent.select(u_t, np.asarray([X[m] for m in cands]))
                m_sel = int(cands[pick])
            elif name in kernel_family:
                m_sel = agent.select(t + 1, u_t, cands.tolist())
            elif name == "GP-UCB":
                m_sel = agent.select(cands.tolist())
            elif name in graph_family:
                m_sel = agent.select(u_t, cands.tolist())
            else:
                raise ValueError(name)

            # 2) noisy reward
            y = f_on(u_t, m_sel) + eps_seq[t]

            # 3) update with the signature this algorithm expects
            if name == "Pooled-LinUCB":
                agent.update(X[m_sel], y)
            elif name == "PerUser-LinUCB":
                agent.update(u_t, X[m_sel], y)
            elif name == "GP-UCB":
                agent.update(m_sel, y)
            else:
                agent.update(u_t, m_sel, y)

            # 4) bookkeeping: replay buffer and instantaneous regret
            algo_buffers[name].append({"u": u_t, "m": m_sel, "y": y, "x": X[m_sel]})
            regrets[name].append(best_val - f_on(u_t, m_sel))

    return regrets, lambda_trace


def run_benchmark(cfg: BenchmarkConfig):
    """Run ``cfg.R`` repetitions of the benchmark and write all results to disk.

    Repetition ``r`` runs ``run_one_environment`` with the environment seed
    ``cfg.seed + 1000 * r``. Files written into ``out/<exp_name>``:

    * ``lambda_trace_rep<r>.json`` -- λ trace of repetition ``r``
    * ``runs.npz`` -- per algorithm, the stacked cumulative-regret curves
      (one row per repetition), stored under the sanitized algorithm name
    * ``algo_keymap.json`` -- algorithm name -> key used in ``runs.npz``
    * ``metadata.json`` and ``config.json``

    Returns:
        The output directory path.
    """
    outdir = make_outdir(cfg.exp_name)

    def write_json(filename, payload):
        with open(os.path.join(outdir, filename), "w") as fh:
            json.dump(payload, fh, indent=2)

    algos = cfg.algos or [
        "LK-GP-UCB",
        "LK-GP-TS",
        "Coop-KernelUCB",
        "Coop-KernelUCB (heat)",
        "GOB.Lin",
        "GraphUCB",
        "GP-UCB",
        "Pooled-LinUCB",
        "PerUser-LinUCB",
    ]
    all_runs = {algo: [] for algo in algos}
    # Collected per repetition; only appended to within this function.
    all_lambda_traces: List[List[Dict[str, float]]] = []

    print(f"[INFO] Starting benchmark '{cfg.exp_name}' with R={cfg.R}, T={cfg.T}, algos={algos}")
    for rep in tqdm(range(cfg.R), desc="Repetitions", ncols=80):
        regrets, lam_trace = run_one_environment(cfg, env_seed=cfg.seed + 1000 * rep)
        all_lambda_traces.append(lam_trace)
        for algo in algos:
            # store cumulative regret curves
            all_runs[algo].append(np.cumsum(regrets[algo]))
        write_json(f"lambda_trace_rep{rep}.json", lam_trace)

    keymap = {algo: _sanitize_key(algo) for algo in algos}
    stacked = {keymap[algo]: np.stack(all_runs[algo], axis=0) for algo in algos}
    np.savez_compressed(os.path.join(outdir, "runs.npz"), **stacked)

    write_json("algo_keymap.json", keymap)

    meta = {
        "exp_name": cfg.exp_name,
        "R": cfg.R,
        "T": cfg.T,
        "algos": algos,
        "lambda_base": float(cfg.lambda_gp),
        "note": "lambda_t = lambda_base * S_spec * T/(T+t); scheduled with epochic rebuilds for GP-style algos",
    }
    write_json("metadata.json", meta)

    with open(os.path.join(outdir, "config.json"), "w") as fh:
        json.dump(asdict(cfg), fh, indent=2)

    print(f"[INFO] Benchmark finished. Data saved in {outdir}")
    print(f"[INFO] Next: python -m lkb.scripts.plot_results --exp_dir {outdir}")
    return outdir
