from dataclasses import dataclass
import numpy as np

from ..graph import laplacian_with_ridge, inv_sqrt_psd


@dataclass
class GraphUCBConfig:
    """Tunable hyper-parameters consumed by :class:`GraphUCB`."""

    # Multiplier applied to the confidence radius when GraphUCB.select
    # scores a candidate (score = estimate + alpha * radius).
    alpha: float = 1.0

    # GraphUCB starts its inverse design matrix at identity / lambda_reg.
    lambda_reg: float = 0.01

    # Forwarded as the ``rho`` argument of ``laplacian_with_ridge``.
    rho_lap: float = 0.1


class GraphUCB:
    """Upper-confidence-bound bandit over graph-coupled joint features.

    A (node ``u``, item ``m``) pair is represented by the joint feature
    ``z = kron(B[:, u], X[m])``.  A single ridge-regression model is kept
    over these features: ``A_inv`` is the inverse of the regularised design
    matrix, ``b`` accumulates ``z * y`` and ``theta = A_inv @ b`` is the
    current estimate.

    Attributes
    ----------
    X : np.ndarray
        Item feature matrix cast to float; rows are indexed by ``m``.
    n, d, M, D : int
        ``W.shape[0]``, ``X.shape[1]``, ``X.shape[0]`` and ``n * d``.
    B : np.ndarray
        First value returned by ``inv_sqrt_psd``; column ``u`` is used as
        the node part of the joint feature.
    A_inv : np.ndarray
        ``(D, D)`` inverse design matrix, initialised to ``I / lambda_reg``.
    b, theta : np.ndarray
        Length-``D`` response accumulator and parameter estimate.
    """

    def __init__(self, X: np.ndarray, W: np.ndarray, config: "GraphUCBConfig"):
        """Set up the graph transform and an empty ridge model.

        Parameters
        ----------
        X : array-like
            2-D item features, one row per item.
        W : np.ndarray
            Graph weight matrix; only ``W.shape[0]`` is read here, the rest
            is handled by ``laplacian_with_ridge``.
        config : GraphUCBConfig
            Hyper-parameters (``alpha``, ``lambda_reg``, ``rho_lap``).
        """
        self.X = np.asarray(X, dtype=float)
        self.n = W.shape[0]
        self.d = self.X.shape[1]
        self.M = self.X.shape[0]
        self.D = self.n * self.d

        self.config = config

        # NOTE(review): presumably L_rho is the graph Laplacian of W plus a
        # rho-scaled ridge and B its inverse square root -- confirm against
        # the helpers in ..graph.  _z assumes B has n rows so that the joint
        # feature has length D = n * d.
        L_rho = laplacian_with_ridge(W, rho=config.rho_lap)
        B, _ = inv_sqrt_psd(L_rho, jitter=1e-8, r_trunc=None)
        self.B = B

        # Inverse of the initial design matrix lambda_reg * I.
        self.A_inv = (1.0 / config.lambda_reg) * np.eye(self.D)
        self.b = np.zeros(self.D)
        self.theta = np.zeros(self.D)

    def _z(self, u: int, m: int) -> np.ndarray:
        """Return the joint feature ``kron(B[:, u], X[m])`` for node/item."""
        return np.kron(self.B[:, u], self.X[m])

    def select(self, u: int, cand_idx: list) -> int:
        """Pick the candidate item with the highest UCB score for node ``u``.

        The score of item ``m`` is ``z @ theta + alpha * sqrt(z @ A_inv @ z)``
        with ``z = self._z(u, m)``.  Ties are resolved in favour of the
        candidate that appears first in ``cand_idx``.

        Parameters
        ----------
        u : int
            Node index (selects a column of ``B``).
        cand_idx : iterable of int
            Row indices into ``X`` to choose from.

        Returns
        -------
        int
            The selected element of ``cand_idx``.

        Raises
        ------
        ValueError
            If ``cand_idx`` is empty, or no candidate produced a comparable
            (non-NaN) score.
        """
        # Loop invariants: read them once instead of on every candidate.
        alpha = self.config.alpha
        A_inv = self.A_inv
        theta = self.theta

        best = None
        best_val = -np.inf
        for m in cand_idx:
            z = self._z(u, m)
            mu = float(z @ theta)
            # Clamp the quadratic form so round-off can never push a
            # negative value into the square root.
            rad = float(np.sqrt(max(z @ (A_inv @ z), 1e-12)))
            ucb = mu + alpha * rad
            # Strict comparison keeps the earliest candidate on ties.
            if ucb > best_val:
                best_val = ucb
                best = m

        if best is None:
            # Previously this fell through to int(None) and surfaced as an
            # opaque TypeError.
            raise ValueError(
                "select() requires at least one candidate in cand_idx "
                "with a comparable (non-NaN) UCB score"
            )
        return int(best)

    def update(self, u: int, m: int, y: float) -> None:
        """Fold the observation ``(u, m, y)`` into the ridge model.

        Parameters
        ----------
        u : int
            Node index of the observation.
        m : int
            Item index (row of ``X``) that was played.
        y : float
            Observed response.
        """
        z = self._z(u, m)
        Az = self.A_inv @ z
        denom = 1.0 + z @ Az
        # Sherman-Morrison rank-one update of (A + z z^T)^{-1}.  Using
        # outer(Az, Az) relies on A_inv being symmetric, which holds because
        # it starts as a scaled identity and each step subtracts a
        # symmetric matrix.
        self.A_inv = self.A_inv - np.outer(Az, Az) / denom
        self.b += z * y
        self.theta = self.A_inv @ self.b
