from __future__ import annotations
from dataclasses import dataclass
from typing import List, Optional, Tuple, Literal, Dict
import numpy as np


# Admissible values for ``CoopKernelUCBConfig.agent_kernel``.
# "laplacian" is a legacy alias of "laplacian_inv", kept for backward compatibility.
AgentKernelKind = Literal[
    "laplacian_inv",
    "learned_mmd",
    "spectral_rbf",
    "heat",
    "all ones",
    "laplacian",
]

@dataclass
class CoopKernelUCBConfig:
    """Hyper-parameters for ``CoopKernelUCB``.

    Fields are grouped by the component that reads them. Fields that belong to
    an agent kernel other than the one selected by ``agent_kernel`` are not read.
    """

    # GP / UCB
    lambda_noise: float = 1e-3  # added to the Gram-matrix diagonal (noise / ridge term)
    beta: float = 1.0  # exploration weight: score = mu + beta * sqrt(var)

    # Arm kernel (on contexts X)
    base_kernel: Literal["SE", "Matern52"] = "SE"
    lengthscale: float = 1.0  # base_kernel lengthscale; also the RFF lengthscale for learned_mmd/"rff"

    # Agent kernel selection
    agent_kernel: AgentKernelKind = "laplacian_inv"

    # laplacian_inv
    rho: float = 0.1  # Kz = (L + rho I)^{-1}

    # heat
    tau: float = 1.0  # Kz = exp(-tau L)

    # spectral_rbf
    spec_k: int = 8  # number of Laplacian eigenvectors (after the first) used as agent coordinates
    spec_sigma: float | Literal["median"] = "median"  # RBF bandwidth; "median" = median heuristic

    # learned_mmd (efficient, RFF-based by default)
    mmd_mode: Literal["rff", "exact"] = "rff"  # "rff": random-feature mean embeddings; "exact": MMD under Kx
    rff_dim: int = 256  # number of random Fourier features ("rff" mode)
    mmd_sigma: float | Literal["median"] = "median"  # bandwidth of the learned Kz; "median" = data-driven
    update_every: int = 200  # rebuild learned Kz when select()'s t is a positive multiple; <= 0 disables
    min_count: int = 5  # agents with fewer observations are not coupled to others in the learned Kz

# ----------------------------------------
# Utilities (kernels, linear algebra)
# ----------------------------------------
def _pairwise_sq_dists(X: np.ndarray) -> np.ndarray:
    XX = np.sum(X * X, axis=1, keepdims=True)
    D2 = np.maximum(XX + XX.T - 2.0 * (X @ X.T), 0.0)
    return D2

def _kernel_matrix_X(X: np.ndarray, kind: str, ell: float) -> np.ndarray:
    """Build the arm (context) kernel matrix Kx over the rows of X.

    Args:
        X: arm contexts, one row per arm.
        kind: "SE" (squared exponential) or "Matern52".
        ell: lengthscale; a 1e-12 guard keeps the divisions finite at ell == 0.

    Raises:
        ValueError: for any other ``kind``.
    """
    if kind not in ("SE", "Matern52"):
        raise ValueError(f"Unknown base_kernel: {kind}")
    sq = _pairwise_sq_dists(X)
    ell_sq = ell ** 2 + 1e-12
    if kind == "SE":
        return np.exp(-0.5 * sq / ell_sq)
    # Matern-5/2: (1 + sqrt(5) r / ell + 5 r^2 / (3 ell^2)) * exp(-sqrt(5) r / ell)
    r = np.sqrt(sq + 1e-12)
    scaled = np.sqrt(5.0) / (ell + 1e-12)
    poly = 1.0 + scaled * r + 5.0 * sq / (3.0 * ell_sq)
    return poly * np.exp(-scaled * r)

def _laplacian(W: np.ndarray) -> np.ndarray:
    d = W.sum(axis=1)
    return np.diag(d) - W

def _laplacian_inv(W: np.ndarray, rho: float) -> np.ndarray:
    """Regularized inverse-Laplacian agent kernel Kz = (L + rho I)^{-1}.

    Before inversion the matrix is symmetrized, which absorbs a slightly
    asymmetric W. It also gets a 1e-10 diagonal jitter as a guard against exact
    singularity (e.g. rho == 0 on a connected graph).
    """
    eye = np.eye(W.shape[0])
    shifted = _laplacian(W) + float(rho) * eye
    sym = 0.5 * (shifted + shifted.T) + 1e-10 * eye
    return np.linalg.inv(sym)

def _heat_kernel(W: np.ndarray, tau: float) -> np.ndarray:
    """Graph heat (diffusion) agent kernel Kz = exp(-tau L), via eigendecomposition.

    Eigenvalues are clipped at zero to discard tiny negative values from round-off,
    and the result is re-symmetrized.
    """
    lap = _laplacian(W)
    lap_sym = 0.5 * (lap + lap.T)
    lam, V = np.linalg.eigh(lap_sym)
    weights = np.exp(-float(tau) * np.maximum(lam, 0.0))
    heat = (V * weights) @ V.T
    return 0.5 * (heat + heat.T)

def _spectral_rbf(W: np.ndarray, k: int, sigma: float | Literal["median"]) -> np.ndarray:
    """RBF agent kernel on spectral embeddings of the graph.

    Each agent is embedded by its coordinates in (up to) k Laplacian eigenvectors
    that follow the smallest-eigenvalue one. Then
    Kz[i, j] = exp(-|z_i - z_j|^2 / (2 sigma^2)).

    With ``sigma == "median"`` the bandwidth is sqrt(median) of the strictly
    positive pairwise squared distances, or 1.0 if there are none.
    """
    lap = _laplacian(W)
    lap = 0.5 * (lap + lap.T)
    eigvals, eigvecs = np.linalg.eigh(lap)
    order = np.argsort(eigvals)
    n_eig = len(order)
    chosen = order[1 : min(k + 1, n_eig)]
    coords = eigvecs[:, chosen]  # (n, <=k)
    sq = _pairwise_sq_dists(coords)
    if sigma == "median":
        upper = sq[np.triu_indices(sq.shape[0], 1)]
        positive = upper[upper > 0]
        med = np.median(positive) if positive.size > 0 else 1.0
        sigma = float(np.sqrt(max(med, 1e-12)))
    two_sigma_sq = 2.0 * (float(sigma) ** 2 + 1e-12)
    K = np.exp(-sq / two_sigma_sq)
    return 0.5 * (K + K.T)

def _all_ones(n: int) -> np.ndarray:
    return np.ones((n, n), dtype=float)

# --- RFF for mean embeddings (SE-compatible; used to *learn* Kz efficiently) ---
def _rff_params(d: int, D: int, ell: float, rng: np.random.Generator) -> Tuple[np.ndarray, np.ndarray]:
    W = rng.normal(loc=0.0, scale=1.0 / (ell + 1e-12), size=(d, D))
    b = rng.uniform(0.0, 2.0 * np.pi, size=(D,))
    return W, b

def _rff_phi(x: np.ndarray, W: np.ndarray, b: np.ndarray) -> np.ndarray:
    z = x @ W + b
    return np.sqrt(2.0 / W.shape[1]) * np.cos(z)

# Triangular solves (no SciPy dependency)
def _forward_sub(L: np.ndarray, b: np.ndarray) -> np.ndarray:
    n = L.shape[0]
    x = np.zeros_like(b, dtype=float)
    for i in range(n):
        s = b[i] - (L[i, :i] @ x[:i])
        x[i] = s / (L[i, i] + 1e-12)
    return x

def _back_sub(LT: np.ndarray, b: np.ndarray) -> np.ndarray:
    n = LT.shape[0]
    x = np.zeros_like(b, dtype=float)
    for i in range(n - 1, -1, -1):
        s = b[i] - (LT[i, i + 1 :] @ x[i + 1 :])
        x[i] = s / (LT[i, i] + 1e-12)
    return x

def _chol_append(L: np.ndarray, k: np.ndarray, kpp: float) -> np.ndarray:
    """Grow a lower-triangular Cholesky factor by one row and column.

    Given L with L L^T = K, return the factor of [[K, k], [k^T, kpp]].
    The new squared pivot (Schur complement) is floored at 1e-12 so the update
    never fails on a numerically non-positive value.
    """
    if L.size == 0:
        return np.array([[np.sqrt(max(kpp, 1e-12))]], dtype=float)
    n = L.shape[0]
    row = _forward_sub(L, k)
    schur = max(kpp - float(row @ row), 1e-12)
    grown = np.zeros((n + 1, n + 1), dtype=float)
    grown[:n, :n] = L
    grown[n, :n] = row
    grown[n, n] = float(np.sqrt(schur))
    return grown

# ----------------------------------------
# CoopKernelUCB (with pluggable / learned Kz)
# ----------------------------------------
class CoopKernelUCB:
    """
    GP-UCB with product kernel K((u,m),(u',m')) = Kz[u,u'] * Kx[m,m'].

    - Supports fixed agent kernels: laplacian_inv, heat, spectral_rbf, all ones
      ("laplacian" is accepted as an alias of laplacian_inv).
    - Supports learned_mmd: Kz is a Gaussian kernel on distances between
      per-agent mean embeddings of the observed arm contexts (RFF features, or
      exact MMD under Kx). It is rebuilt every ``cfg.update_every`` rounds, after
      which the GP state is recomputed from the logged history.

    The GP posterior is kept as a Cholesky factor ``L`` of
    ``K_t + lambda_noise * I`` over the logged (agent, arm) pairs, together with
    ``alpha = (K_t + lambda_noise * I)^{-1} y``.
    """

    def __init__(self, n_users: int, X: np.ndarray, W: np.ndarray, cfg: "CoopKernelUCBConfig"):
        """Set up the arm kernel, an empty GP state and the agent kernel.

        Args:
            n_users: number of agents; agent indices u index the rows/columns of Kz.
            X: arm contexts, one row per arm; Kx is built over its rows.
            W: agent graph weights used by the graph-based agent kernels
               (presumably (n_users, n_users) — confirm against callers).
            cfg: hyper-parameters.
        """
        self.n = int(n_users)
        self.X = np.asarray(X)
        self.M = self.X.shape[0]
        self.d = self.X.shape[1]
        self.W = np.asarray(W, dtype=float)
        self.cfg = cfg

        # Arm kernel (fixed)
        self.Kx = _kernel_matrix_X(self.X, cfg.base_kernel, cfg.lengthscale)
        self.kx_diag = np.diag(self.Kx).copy()

        # GP state
        self.lambda_noise = float(cfg.lambda_noise)
        self.beta = float(cfg.beta)

        # History buffers: agent index, arm index and reward of every observation.
        self.u_hist: List[int] = []
        self.m_hist: List[int] = []
        self.y_hist: List[float] = []

        # Cholesky factor of K_t + lambda_noise * I and the corresponding GP weights.
        self.L = np.zeros((0, 0), dtype=float)
        self.alpha = np.zeros((0,), dtype=float)
        self.t = 0

        # Agent kernel (may be dynamic) — initialize state BEFORE building Kz.
        # The fixed seed makes the RFF features reproducible across runs.
        self.rng = np.random.default_rng(12345)
        self._init_agent_kernel_state()

    # ---------- Agent kernel helpers ----------
    def _init_agent_kernel_state(self):
        """Resolve the agent-kernel kind and set the initial Kz.

        For learned_mmd, per-agent statistics are allocated and Kz starts as the
        identity (agents uncoupled) until the first scheduled rebuild.
        """
        ak = self.cfg.agent_kernel
        if ak == "laplacian":
            ak = "laplacian_inv"
        self.agent_kernel = ak

        # Initialize learned-MMD state first, then set initial Kz
        if self.agent_kernel == "learned_mmd":
            if self.cfg.mmd_mode == "rff":
                self.W_rff, self.b_rff = _rff_params(self.d, self.cfg.rff_dim, self.cfg.lengthscale, self.rng)
                self.mu_feat = np.zeros((self.n, self.cfg.rff_dim), dtype=float)
                self.counts = np.zeros((self.n,), dtype=int)
            else:
                self.obs_idx_per_user: List[List[int]] = [[] for _ in range(self.n)]
            # Start with identity; will be updated epochically
            self.Kz = np.eye(self.n, dtype=float)
        else:
            # Fixed Kz
            self.Kz = self._build_Kz_fixed(self.agent_kernel)

    def _build_Kz_fixed(self, kind: "AgentKernelKind") -> np.ndarray:
        """Build a data-independent agent kernel from the graph W.

        Raises:
            ValueError: for an unknown ``kind``.
        """
        if kind == "laplacian_inv" or kind == "laplacian":
            return _laplacian_inv(self.W, self.cfg.rho)
        if kind == "heat":
            return _heat_kernel(self.W, self.cfg.tau)
        if kind == "spectral_rbf":
            return _spectral_rbf(self.W, self.cfg.spec_k, self.cfg.spec_sigma)
        if kind == "all ones":
            return _all_ones(self.n)
        raise ValueError(f"Unknown agent kernel kind: {kind}")

    def _gaussian_kz_from_sq_dists(self, D2: np.ndarray, valid: np.ndarray) -> np.ndarray:
        """Map squared agent distances to a Gaussian Kz with ONE global bandwidth.

        With ``cfg.mmd_sigma == "median"`` the bandwidth is sqrt of the median of the
        strictly positive squared distances between pairs of valid agents. If none
        are positive it falls back to the median of all such pairs, and to 1.0 when
        there are no valid pairs. Sharing a single bandwidth across all pairs keeps
        Kz a proper Gaussian kernel on the embedding space.

        Invalid agents are decoupled: their row/column is zero except for a unit diagonal.
        """
        valid = np.asarray(valid, dtype=bool)
        n = D2.shape[0]
        if self.cfg.mmd_sigma == "median":
            iu, ju = np.triu_indices(n, 1)
            pair_mask = valid[iu] & valid[ju]
            vals = D2[iu[pair_mask], ju[pair_mask]]
            if vals.size == 0:
                sigma = 1.0
            else:
                pos = vals[vals > 0]
                med = np.median(pos) if pos.size > 0 else np.median(vals)
                sigma = float(np.sqrt(max(med, 1e-12)))
        else:
            sigma = float(self.cfg.mmd_sigma)
        denom = 2.0 * (sigma ** 2 + 1e-12)

        Kz = np.exp(-D2 / denom)
        Kz[~valid, :] = 0.0
        Kz[:, ~valid] = 0.0
        np.fill_diagonal(Kz, 1.0)
        return 0.5 * (Kz + Kz.T)

    def _build_Kz_learned(self) -> np.ndarray:
        """Build the learned agent kernel Kz from the data observed so far.

        "rff" mode: squared distances between the agents' mean RFF feature vectors.
        "exact" mode: MMD^2(i, j) = mean(Kx[Xi, Xi]) + mean(Kx[Xj, Xj]) - 2 mean(Kx[Xi, Xj])
        over the arm indices each agent has observed.

        Both are turned into Kz with a single shared bandwidth. Agents with fewer than
        ``cfg.min_count`` observations (or none at all) stay uncoupled.
        """
        if getattr(self, "counts", None) is not None and self.cfg.mmd_mode == "rff":
            counts = np.asarray(self.counts)
            valid = (counts >= self.cfg.min_count) & (counts > 0)
            D2 = _pairwise_sq_dists(self.mu_feat)
            return self._gaussian_kz_from_sq_dists(D2, valid)

        # exact mode
        obs = self.obs_idx_per_user
        counts = np.array([len(ix) for ix in obs], dtype=int)
        valid = (counts >= self.cfg.min_count) & (counts > 0)
        users = [int(i) for i in np.flatnonzero(valid)]

        # mean(Kx[Xi, Xi]) depends on one agent only: compute it once per agent, not per pair.
        self_terms = {i: float(self.Kx[np.ix_(obs[i], obs[i])].mean()) for i in users}

        D2 = np.zeros((self.n, self.n), dtype=float)
        for a, i in enumerate(users):
            for j in users[a + 1 :]:
                cross = float(self.Kx[np.ix_(obs[i], obs[j])].mean())
                # Clamp tiny negative values caused by round-off.
                mmd2 = max(self_terms[i] + self_terms[j] - 2.0 * cross, 0.0)
                D2[i, j] = D2[j, i] = mmd2
        return self._gaussian_kz_from_sq_dists(D2, valid)

    # ---------- GP core ----------
    def _K_cross_candidate(self, u: int, m: int) -> np.ndarray:
        """Covariances between candidate (u, m) and every logged observation (empty if none)."""
        if len(self.u_hist) == 0:
            return np.zeros((0,), dtype=float)
        Ku = self.Kz[u, np.asarray(self.u_hist, dtype=int)]
        Km = self.Kx[m, np.asarray(self.m_hist, dtype=int)]
        return Ku * Km

    def _K_self(self, u: int, m: int) -> float:
        """Prior variance K((u, m), (u, m)) = Kz[u, u] * Kx[m, m] (noise excluded)."""
        return float(self.Kz[u, u] * self.Kx[m, m])

    def _recompute_cholesky(self):
        """Refactor K_t + lambda_noise * I from scratch and refresh alpha.

        Used after Kz changes. If Cholesky fails, diagonal jitter is escalated
        (1e-8, 1e-7, 1e-6, then a final attempt with 2e-6).
        """
        t = len(self.u_hist)
        if t == 0:
            self.L = np.zeros((0, 0), dtype=float)
            self.alpha = np.zeros((0,), dtype=float)
            return
        U = np.asarray(self.u_hist, dtype=int)
        M = np.asarray(self.m_hist, dtype=int)
        Kt = (self.Kz[U[:, None], U[None, :]] * self.Kx[M[:, None], M[None, :]]).astype(float)
        Kt = Kt + self.lambda_noise * np.eye(t)
        jitter = 0.0
        for _ in range(3):
            try:
                L = np.linalg.cholesky(0.5 * (Kt + Kt.T) + jitter * np.eye(t))
                break
            except np.linalg.LinAlgError:
                jitter = 1e-8 if jitter == 0.0 else jitter * 10.0
        else:
            L = np.linalg.cholesky(0.5 * (Kt + Kt.T) + (jitter + 1e-6) * np.eye(t))
        self.L = L
        y = np.asarray(self.y_hist, dtype=float)
        v = _forward_sub(self.L, y)
        self.alpha = _back_sub(self.L.T, v)

    def _maybe_update_Kz_and_rebuild(self, t_now: int):
        """For learned_mmd, rebuild Kz and the GP factor when t_now is a positive multiple of update_every."""
        if self.agent_kernel != "learned_mmd":
            return
        if self.cfg.update_every <= 0 or t_now == 0:
            return
        if (t_now % self.cfg.update_every) != 0:
            return
        self.Kz = self._build_Kz_learned()
        self._recompute_cholesky()

    # ---------- Public API ----------
    def select(self, t: int, u: int, cand_list: List[int]) -> int:
        """Pick the arm in ``cand_list`` maximizing mu + beta * sqrt(var) for agent u.

        Before any observation, the arm with the largest prior variance is chosen.

        Raises:
            ValueError: if ``cand_list`` is empty.
        """
        if len(cand_list) == 0:
            raise ValueError("cand_list must contain at least one arm")

        self._maybe_update_Kz_and_rebuild(t)

        if len(self.u_hist) == 0:
            variances = [self._K_self(u, m) + self.lambda_noise for m in cand_list]
            j = int(np.argmax(variances))
            return int(cand_list[j])

        scores = []
        for m in cand_list:
            k = self._K_cross_candidate(u, m)
            v = _forward_sub(self.L, k)
            w = _back_sub(self.L.T, v)  # w = (K_t + lambda I)^{-1} k
            mu = float(k @ self.alpha)
            kcc = self._K_self(u, m) + self.lambda_noise
            var = max(kcc - float(k @ w), 1e-12)
            scores.append(mu + self.beta * np.sqrt(var))
        j = int(np.argmax(scores))
        return int(cand_list[j])

    def update(self, u: int, m: int, y: float):
        """Record reward y for agent u playing arm m.

        Updates the learned-kernel statistics (if any), extends the Cholesky
        factor by one row and recomputes alpha.
        """
        if self.agent_kernel == "learned_mmd":
            if self.cfg.mmd_mode == "rff":
                phi = _rff_phi(self.X[m], self.W_rff, self.b_rff)
                c = int(self.counts[u])
                if c == 0:
                    self.mu_feat[u] = phi
                else:
                    # Running mean of the agent's feature vectors.
                    self.mu_feat[u] = (c * self.mu_feat[u] + phi) / (c + 1)
                self.counts[u] = c + 1
            else:
                # exact mode: keep indices by user
                if not hasattr(self, "obs_idx_per_user"):
                    self.obs_idx_per_user = [[] for _ in range(self.n)]
                self.obs_idx_per_user[u].append(int(m))

        # Cross-covariances must be taken against the history BEFORE appending (u, m).
        k = self._K_cross_candidate(u, m)
        kpp = self._K_self(u, m) + self.lambda_noise
        self.L = _chol_append(self.L, k, kpp)

        self.u_hist.append(int(u))
        self.m_hist.append(int(m))
        self.y_hist.append(float(y))
        y_vec = np.asarray(self.y_hist, dtype=float)
        v = _forward_sub(self.L, y_vec)
        self.alpha = _back_sub(self.L.T, v)

        self.t += 1
