from dataclasses import dataclass
from typing import List
import numpy as np

from ..kernels import sqexp_kernel, matern52_kernel


# Registry mapping ``GPConfig.base_kernel`` names to kernel constructors.
# Each entry is called as ``fn(X, X, lengthscale=...)`` by GPGreedyUCB.
KERNELS = dict(
    SE=sqexp_kernel,
    Matern52=matern52_kernel,
)


@dataclass
class GPConfig:
    """Hyperparameters for :class:`GPGreedyUCB`.

    Fields are positional in the order listed below.
    """

    # Value added to the diagonal of the observed Gram matrix before solving.
    lambda_noise: float = 1e-2
    # Key into the ``KERNELS`` registry selecting the kernel function.
    base_kernel: str = "SE"
    # Passed to the kernel function as its ``lengthscale`` keyword.
    lengthscale: float = 1.0


class GPGreedyUCB:
    """Greedy GP-UCB selector over a fixed, finite set of points.

    The kernel matrix over all rows of ``X`` is computed once in ``__init__``.
    Observations are recorded with :meth:`update`.

    :meth:`select` scores candidates by the GP posterior upper confidence bound
    ``mean + beta * std``. The posterior uses a zero prior mean, and
    ``config.lambda_noise`` is added to the diagonal of the observed Gram
    matrix.
    """

    def __init__(self, X: np.ndarray, config: "GPConfig", beta: float = 1.0):
        """Precompute the kernel matrix for every point in ``X``.

        Args:
            X: Points; the first axis indexes points (``M = X.shape[0]``).
            config: Kernel name, lengthscale and diagonal regulariser.
            beta: Weight on the posterior standard deviation in the UCB score.
        """
        self.X = X
        self.M = X.shape[0]
        self.config = config
        self.beta = beta
        # Presumably the (M, M) Gram matrix; select/update indices are row
        # indices into it.
        self.K = KERNELS[config.base_kernel](X, X, lengthscale=config.lengthscale)
        self.hist_idx: List[int] = []
        self.hist_y: List[float] = []

    def _gram(self) -> np.ndarray:
        """Return ``K[H, H] + lambda_noise * I`` for the observed indices ``H``."""
        t = len(self.hist_idx)
        Kt = self.K[np.ix_(self.hist_idx, self.hist_idx)]
        return Kt + self.config.lambda_noise * np.eye(t)

    def _solve(self):
        """Return ``alpha = (K[H, H] + lambda I)^-1 y``, or ``None`` with no data."""
        if not self.hist_idx:
            return None
        y = np.array(self.hist_y, dtype=float)
        return np.linalg.solve(self._gram(), y)

    def select(self, cand_idx: list) -> int:
        """Return the element of ``cand_idx`` with the highest UCB score.

        With no observations yet, the candidate with the largest prior
        variance ``K[m, m]`` is returned (first one on ties). Otherwise each
        candidate ``m`` is scored by ``mu(m) + beta * sigma(m)``, where::

            mu(m)      = k_m^T A^-1 y
            sigma(m)^2 = max(K[m, m] - k_m^T A^-1 k_m, 1e-12)
            A          = K[H, H] + lambda_noise * I,   k_m = K[H, m]

        Args:
            cand_idx: Row indices into ``K`` to choose among.

        Returns:
            The chosen index as a Python ``int``.

        Raises:
            ValueError: If ``cand_idx`` is empty.
        """
        cand = np.asarray(cand_idx, dtype=int)
        if cand.size == 0:
            raise ValueError("cand_idx must contain at least one index")

        if not self.hist_idx:
            # Nothing observed: pure exploration by prior variance.
            variances = np.diag(self.K)[cand]
            return int(cand[int(np.argmax(variances))])

        y = np.array(self.hist_y, dtype=float)
        # Cross-covariances: rows are observed points, columns are candidates.
        Kc = self.K[np.ix_(self.hist_idx, cand)]
        # Factor A once and solve for y and every k_m together, instead of
        # re-solving the same system once per candidate.
        sol = np.linalg.solve(self._gram(), np.column_stack((y, Kc)))
        alpha, V = sol[:, 0], sol[:, 1:]

        means = Kc.T @ alpha
        # Column-wise k_m^T A^-1 k_m gives the variance reduction per candidate.
        var = np.diag(self.K)[cand] - np.sum(Kc * V, axis=0)
        # Clamp tiny/negative round-off before the square root.
        stds = np.sqrt(np.maximum(var, 1e-12))
        ucb = means + self.beta * stds
        return int(cand[int(np.argmax(ucb))])

    def update(self, m: int, y: float) -> None:
        """Record that index ``m`` was evaluated with observed value ``y``."""
        self.hist_idx.append(m)
        self.hist_y.append(float(y))
