from dataclasses import dataclass
import numpy as np


@dataclass
class LinUCBConfig:
    """Hyper-parameters consumed by ``PooledLinUCB`` and ``PerUserLinUCB``.

    Attributes:
        alpha: weight of the confidence-radius (exploration) term in the
            UCB score ``mean + alpha * radius``.
        lambda_reg: ridge penalty; every model's Gram matrix starts out as
            ``lambda_reg * I``.
    """

    # Exploration strength: multiplier on the confidence radius.
    alpha: float = 1.0
    # Ridge regularisation: initial Gram matrix is lambda_reg * identity.
    lambda_reg: float = 0.01


class PooledLinUCB:
    """LinUCB with a single ridge-regression model (no per-user state).

    Keeps the regularised Gram matrix ``A = lambda_reg * I + sum(x x^T)`` and
    the response vector ``b = sum(y * x)``; the point estimate is
    ``theta = A^{-1} b``.
    """

    def __init__(self, d: int, config: "LinUCBConfig"):
        self.d = d
        self.config = config
        self.alpha = config.alpha
        self.A = config.lambda_reg * np.eye(d)
        self.b = np.zeros(d)

    def select(self, cand_X: np.ndarray) -> int:
        """Return the row index of the candidate with the highest UCB score.

        Args:
            cand_X: candidate features, shape (n_candidates, d). A 1-D
                array is treated as a single candidate.

        Returns:
            Index of the argmax of ``x @ theta + alpha * sqrt(x^T A^{-1} x)``
            (first index on ties).

        Raises:
            numpy.linalg.LinAlgError: if ``A`` is singular (possible when
                ``lambda_reg == 0``).
            ValueError: if ``cand_X`` is empty or has the wrong width.
        """
        X = np.atleast_2d(np.asarray(cand_X, dtype=float))
        # Solve A @ [theta | A^{-1} X^T] = [b | X^T] in one call instead of
        # forming A^{-1} explicitly: cheaper and numerically more stable.
        sol = np.linalg.solve(self.A, np.column_stack([self.b, X.T]))
        theta, A_inv_Xt = sol[:, 0], sol[:, 1:]
        means = X @ theta
        # x^T A^{-1} x >= 0 in exact arithmetic, but round-off can make it
        # slightly negative; sqrt would give NaN and np.argmax returns the
        # index of a NaN, so clamp at zero.
        quad = np.sum(X * A_inv_Xt.T, axis=1)
        rad = np.sqrt(np.maximum(quad, 0.0))
        ucb = means + self.alpha * rad
        return int(np.argmax(ucb))

    def update(self, x: np.ndarray, y: float) -> None:
        """Fold one observed (features, reward) pair into the model.

        Args:
            x: features of the played candidate; any array-like holding
                exactly ``d`` values, e.g. shape (d,) or a (1, d) row slice.
            y: observed reward.

        Raises:
            ValueError: if ``x`` does not hold exactly ``d`` values.
        """
        # Coerce to a flat float vector: a (1, d) row cannot be added into
        # b of shape (d,), and for a plain list ``x * y`` is list repetition.
        x = np.asarray(x, dtype=float).ravel()
        if x.size != self.d:
            raise ValueError(f"expected {self.d} features, got {x.size}")
        self.A += np.outer(x, x)
        self.b += x * y


class PerUserLinUCB:
    """LinUCB with an independent ridge-regression model for every user.

    User ``u`` (``0 <= u < n_users``) owns its own Gram matrix ``A[u]``,
    initialised to ``lambda_reg * I``, and response vector ``b[u]``; no
    statistics are shared between users.
    """

    def __init__(self, n_users: int, d: int, config: "LinUCBConfig"):
        self.n = n_users
        self.d = d
        self.config = config
        self.alpha = config.alpha
        self.A = [config.lambda_reg * np.eye(d) for _ in range(n_users)]
        self.b = [np.zeros(d) for _ in range(n_users)]

    def _check_user(self, u: int) -> None:
        """Raise IndexError unless ``u`` is a valid user index."""
        # Plain list indexing accepts negative ids and would silently route
        # them to another user's model (e.g. -1 -> the last user).
        if not 0 <= u < self.n:
            raise IndexError(f"user index {u} out of range [0, {self.n})")

    def select(self, u: int, cand_X: np.ndarray) -> int:
        """Return the row index of the best candidate under user ``u``'s model.

        Args:
            u: user index in ``[0, n_users)``.
            cand_X: candidate features, shape (n_candidates, d). A 1-D
                array is treated as a single candidate.

        Returns:
            Index of the argmax of ``x @ theta_u + alpha * sqrt(x^T A_u^{-1} x)``
            (first index on ties).

        Raises:
            IndexError: if ``u`` is out of range.
            numpy.linalg.LinAlgError: if ``A[u]`` is singular.
            ValueError: if ``cand_X`` is empty or has the wrong width.
        """
        self._check_user(u)
        X = np.atleast_2d(np.asarray(cand_X, dtype=float))
        # One linear solve yields theta_u and A_u^{-1} X^T without forming
        # the inverse explicitly.
        sol = np.linalg.solve(self.A[u], np.column_stack([self.b[u], X.T]))
        theta, A_inv_Xt = sol[:, 0], sol[:, 1:]
        means = X @ theta
        # Clamp round-off negatives: sqrt(<0) is NaN and np.argmax would
        # return the NaN's index.
        quad = np.sum(X * A_inv_Xt.T, axis=1)
        rad = np.sqrt(np.maximum(quad, 0.0))
        ucb = means + self.alpha * rad
        return int(np.argmax(ucb))

    def update(self, u: int, x: np.ndarray, y: float) -> None:
        """Fold one observed (features, reward) pair into user ``u``'s model.

        Args:
            u: user index in ``[0, n_users)``.
            x: features of the played candidate; any array-like holding
                exactly ``d`` values, e.g. shape (d,) or a (1, d) row slice.
            y: observed reward.

        Raises:
            IndexError: if ``u`` is out of range.
            ValueError: if ``x`` does not hold exactly ``d`` values.
        """
        self._check_user(u)
        # Flatten/coerce so (1, d) rows and plain lists behave like (d,) arrays.
        x = np.asarray(x, dtype=float).ravel()
        if x.size != self.d:
            raise ValueError(f"expected {self.d} features, got {x.size}")
        self.A[u] += np.outer(x, x)
        self.b[u] += x * y
