from dataclasses import dataclass
from typing import Optional, List
import numpy as np

from ..regimes import regime2_kernel_prepare


@dataclass
class LKGPTSConfig:
    """Hyper-parameters consumed by :class:`LKGPTS`.

    All fields have defaults, so ``LKGPTSConfig()`` is a usable configuration.
    The kernel-related fields are only forwarded to ``regime2_kernel_prepare``;
    their exact meaning is defined there.
    """

    # Noise variance: added to the history Gram diagonal (Cholesky mode) and to
    # the rank-one update denominator (recursive mode).
    lambda_noise: float = 0.01
    # Scale applied to the posterior standard deviation when drawing samples.
    nu: float = 1.0
    # Forwarded to regime2_kernel_prepare (presumably the arm kernel family — confirm there).
    base_kernel: str = "SE"
    # Forwarded to regime2_kernel_prepare (presumably the base-kernel lengthscale).
    lengthscale: float = 1.0
    # Forwarded to regime2_kernel_prepare as ``rho``.
    rho: float = 0.1
    # Forwarded to regime2_kernel_prepare as ``r_trunc``; None leaves it to the helper.
    r_trunc: Optional[int] = None
    # Observation count at which LKGPTS switches to recursive updates;
    # None means ceil(N ** (1/3)) with N = n_users * n_arms.
    t_star: Optional[int] = None


class LKGPTS:
    """Thompson-sampling-style bandit over a (user, arm) grid with a product GP kernel.

    Grid point ``z = u * M + m`` pairs user ``u`` with arm ``m``.  Its prior
    covariance with ``z' = u' * M + m'`` is ``K_user[u, u'] * K_arm[m, m']``,
    which is exactly the ``(z, z')`` entry of ``np.kron(K_user, K_arm)``.

    Two inference modes are used:

    * ``"chol"``: while fewer than ``t_star`` observations have been seen, the
      raw history is stored and the exact GP posterior for the candidates is
      computed from a Cholesky factor of the regularised history Gram matrix.
    * ``"rec"``: from the ``t_star``-th observation on, a dense ``N x N``
      posterior covariance ``Q`` (with mean ``mu`` and clipped diagonal
      ``var``) is maintained by one rank-one update per observation.

    Selection draws an independent Gaussian sample per candidate from its
    posterior marginal (standard deviation scaled by ``config.nu``) and
    returns the candidate with the largest sample.

    Args:
        n: Number of users.
        X: Arm features; ``X.shape[0]`` is the number of arms ``M``.
        graph_W: User graph, forwarded to ``regime2_kernel_prepare``
            (presumably used to build ``K_user`` — confirm there).
        config: Hyper-parameters (noise, sampling scale, kernel settings, ``t_star``).
        BK0: Optional prior covariance over all ``N = n * M`` grid points, used
            when switching to recursive mode; defaults to ``np.kron(K_user, K_arm)``.
        seed: Seed for the sampling RNG.
    """

    def __init__(
        self,
        n: int,
        X: np.ndarray,
        graph_W: np.ndarray,
        config: "LKGPTSConfig",
        BK0: Optional[np.ndarray] = None,
        seed: int = 0,
    ):
        self.n = n
        self.M = X.shape[0]
        self.N = self.n * self.M
        self.config = config
        self.rng = np.random.default_rng(seed)

        self.K_user, self.K_arm = regime2_kernel_prepare(
            n,
            self.M,
            X,
            graph_W,
            rho=config.rho,
            base_kernel=config.base_kernel,
            lengthscale=config.lengthscale,
            r_trunc=config.r_trunc,
        )
        # Dense prior covariance over the whole grid; within this class it is
        # only read when entering recursive mode.
        self.BK0 = BK0 if BK0 is not None else np.kron(self.K_user, self.K_arm)

        self.lambda_noise = float(config.lambda_noise)
        # Default switch point is ceil(N ** (1/3)), computed exactly in integers:
        # the float cube root overshoots, e.g. 27 ** (1/3) == 3.0000000000000004.
        self.t_star = (
            self._ceil_cbrt(self.N) if config.t_star is None else int(config.t_star)
        )

        # Chol-mode state: observed grid indices, rewards and the Cholesky factor.
        self.hist_idx: List[int] = []
        self.hist_y: List[float] = []
        self.L: Optional[np.ndarray] = None

        # Rec-mode state, populated by _switch_to_recursive().
        self.mode = "chol"
        self.Q: Optional[np.ndarray] = None
        self.mu: Optional[np.ndarray] = None
        self.var: Optional[np.ndarray] = None

        # Number of update() calls so far.
        self.t = 0

    @staticmethod
    def _ceil_cbrt(N: int) -> int:
        """Return the smallest integer ``k >= 0`` with ``k ** 3 >= N`` (0 for ``N <= 0``)."""
        N = int(N)
        if N <= 0:
            return 0
        # Float estimate, then exact integer correction in both directions.
        k = max(int(round(N ** (1.0 / 3.0))), 0)
        while k ** 3 < N:
            k += 1
        while k > 0 and (k - 1) ** 3 >= N:
            k -= 1
        return k

    def _grid_index(self, u: int, m: int) -> int:
        """Flatten (user, arm) into the grid index ``u * M + m`` (also works on arrays)."""
        return u * self.M + m

    def _history_coords(self):
        """Return ``(u_hist, m_hist)`` integer arrays decoded from ``hist_idx``."""
        hist = np.asarray(self.hist_idx, dtype=int)
        return hist // self.M, hist % self.M

    def _kvec_to_history(self, u: int, m: int) -> np.ndarray:
        """Prior covariances between grid point (u, m) and every stored observation."""
        if not self.hist_idx:
            return np.zeros(0, dtype=float)
        u_hist, m_hist = self._history_coords()
        return self.K_user[u, u_hist] * self.K_arm[m, m_hist]

    def _kdiag(self, u: int, m: int) -> float:
        """Prior variance of grid point (u, m)."""
        return float(self.K_user[u, u] * self.K_arm[m, m])

    def _rebuild_cholesky(self):
        """Recompute ``self.L`` as the Cholesky factor of ``K_hist + lambda_noise * I``.

        Sets ``self.L = None`` when there is no history.  ``np.linalg.cholesky``
        raises ``LinAlgError`` if the matrix is not positive definite (possible,
        e.g., with ``lambda_noise <= 0`` and repeated pulls of the same point).
        """
        t = len(self.hist_idx)
        if t == 0:
            self.L = None
            return
        u_hist, m_hist = self._history_coords()
        gram = self.K_user[np.ix_(u_hist, u_hist)] * self.K_arm[np.ix_(m_hist, m_hist)]
        self.L = np.linalg.cholesky(gram + self.lambda_noise * np.eye(t))

    def _chol_solve(self, y: np.ndarray) -> np.ndarray:
        """Solve ``(L @ L.T) @ alpha = y`` using the current Cholesky factor."""
        v = np.linalg.solve(self.L, y)
        return np.linalg.solve(self.L.T, v)

    def _rec_update(self, z: int, y: float) -> None:
        """Condition the rec-mode posterior on one noisy observation ``y`` at grid index ``z``."""
        c = self.Q[:, z].copy()
        denom = self.lambda_noise + self.var[z]
        resid = y - self.mu[z]
        self.mu += (c / denom) * resid
        # Keep the tracked marginal variances strictly positive.
        self.var = np.clip(self.var - (c * c) / denom, 1e-12, None)
        row = self.Q[z, :].copy()
        self.Q -= np.outer(c, row) / denom

    def _switch_to_recursive(self):
        """Enter rec mode: start from the prior ``BK0`` and replay the stored history."""
        self.Q = self.BK0.astype(np.float64, copy=True)
        self.mu = np.zeros(self.N, dtype=np.float64)
        self.var = np.clip(np.diag(self.Q).astype(np.float64), 1e-12, None)
        for z, y in zip(self.hist_idx, self.hist_y):
            self._rec_update(z, y)
        self.mode = "rec"

    def _select_chol(self, u_t: int, cand_idx) -> int:
        """Chol-mode selection: exact GP posterior for all candidates, computed at once."""
        self._rebuild_cholesky()
        cand = np.asarray(cand_idx, dtype=int)
        if not self.hist_idx:
            # NOTE(review): with no data this picks the highest prior variance
            # (first on ties) without drawing a sample, unlike rec mode — confirm intended.
            prior_var = self.K_user[u_t, u_t] * self.K_arm[cand, cand]
            return int(cand_idx[int(np.argmax(prior_var))])

        alpha = self._chol_solve(np.asarray(self.hist_y, dtype=float))
        u_hist, m_hist = self._history_coords()
        # (C, t) prior covariances between each candidate and each observation.
        k_cand = self.K_user[u_t, u_hist] * self.K_arm[np.ix_(cand, m_hist)]
        means = k_cand @ alpha
        # One multi-RHS solve instead of one full solve per candidate.
        v = np.linalg.solve(self.L, k_cand.T)
        var = self.K_user[u_t, u_t] * self.K_arm[cand, cand] - np.sum(v * v, axis=0)
        stds = np.sqrt(np.maximum(var, 1e-12))
        # A vectorised draw consumes the RNG stream in the same order as
        # sequential per-candidate scalar draws.
        samples = self.rng.normal(loc=means, scale=self.config.nu * stds)
        return int(cand_idx[int(np.argmax(samples))])

    def _select_rec(self, u_t: int, cand_idx) -> int:
        """Rec-mode selection from the maintained posterior mean and variance."""
        idxs = self._grid_index(u_t, np.asarray(cand_idx, dtype=int))
        means = self.mu[idxs]
        stds = np.sqrt(np.clip(self.var[idxs], 1e-12, None))
        samples = self.rng.normal(loc=means, scale=self.config.nu * stds)
        return int(cand_idx[int(np.argmax(samples))])

    def select(self, t: int, u_t: int, cand_idx: list) -> int:
        """Choose one arm from ``cand_idx`` for user ``u_t``.

        Args:
            t: Round number; unused here (kept for interface compatibility).
            u_t: User index.
            cand_idx: Non-empty sequence of candidate arm indices.

        Returns:
            The chosen arm index (an element of ``cand_idx``) as a Python int.

        Raises:
            ValueError: If ``cand_idx`` is empty.
        """
        if len(cand_idx) == 0:
            raise ValueError("cand_idx must contain at least one candidate arm")
        if self.mode == "chol":
            return self._select_chol(u_t, cand_idx)
        return self._select_rec(u_t, cand_idx)

    def update(self, u_t: int, m_t: int, y_t: float) -> None:
        """Record reward ``y_t`` for user ``u_t`` pulling arm ``m_t``.

        On the ``t_star``-th call the model switches to recursive mode (replaying
        the stored history) before incorporating this observation.
        """
        self.t += 1
        if self.mode == "chol" and self.t >= self.t_star:
            self._switch_to_recursive()

        z = self._grid_index(u_t, m_t)
        if self.mode == "chol":
            self.hist_idx.append(z)
            self.hist_y.append(float(y_t))
        else:
            self._rec_update(z, float(y_t))
