from typing import Optional
import numpy as np

from .graph import laplacian, laplacian_with_ridge, inv_sqrt_psd
from .contexts import sample_context_pool
from .kernels import sqexp_kernel, matern52_kernel


# Lookup table from a base-kernel name to the kernel function imported from
# .kernels; regime2_kernel_prepare indexes it with its ``base_kernel`` argument.
KERNELS = dict(
    SE=sqexp_kernel,
    Matern52=matern52_kernel,
)


def regime1_linear_gob(n: int, d: int, M: int, eta: float, graph_W: np.ndarray, seed: int = 0):
    """Build a linear reward model whose per-user parameters are smoothed over a graph.

    An (n, d) matrix of standard-normal parameters ``Theta0`` is drawn and then
    smoothed by solving ``(I + eta * L) Theta = Theta0``, where ``L`` is
    ``laplacian(graph_W)``.

    Args:
        n: Number of users (rows of the parameter matrix).
        d: Parameter dimension (columns of the parameter matrix).
        M: Forwarded to ``sample_context_pool`` (presumably the pool size --
            confirm against that helper).
        eta: Weight of the Laplacian term in the smoothing system.
        graph_W: Graph matrix handed to ``laplacian``; the result must be
            (n, n) for the linear solve to be well-formed.
        seed: Seed used both for the parameter draw and for the context pool.

    Returns:
        ``(X, Theta, f_eval)``: the context pool, the smoothed (n, d) parameter
        matrix, and a callable ``f_eval(u_idx, x_vec)`` returning the inner
        product of ``x_vec`` with user ``u_idx``'s parameter row as a float.
    """
    rng = np.random.default_rng(seed)
    contexts = sample_context_pool(M, d, seed=seed)
    graph_laplacian = laplacian(graph_W)
    raw_params = rng.normal(size=(n, d))
    # Solve the linear system directly instead of forming an explicit inverse.
    smoothed_params = np.linalg.solve(np.eye(n) + eta * graph_laplacian, raw_params)

    def f_eval(u_idx: int, x_vec: np.ndarray) -> float:
        """Return the linear reward of context ``x_vec`` for user ``u_idx``."""
        user_params = smoothed_params[u_idx]
        return float(x_vec @ user_params)

    return contexts, smoothed_params, f_eval


def regime2_kernel_prepare(n: int, M: int, X: np.ndarray, graph_W: np.ndarray, rho: float, base_kernel: str, lengthscale: float, r_trunc: Optional[int] = None):
    """Assemble the user-side and arm-side kernel matrices for regime 2.

    Args:
        n: Unused here; kept so the signature parallels the draw functions.
        M: Unused here; kept so the signature parallels the draw functions.
        X: Context pool; the arm kernel is evaluated on ``(X, X)``.
        graph_W: Graph matrix handed to ``laplacian_with_ridge``.
        rho: Ridge parameter forwarded to ``laplacian_with_ridge``.
        base_kernel: Key into ``KERNELS`` selecting the arm kernel function.
        lengthscale: Lengthscale forwarded to the selected kernel.
        r_trunc: Optional truncation argument forwarded to ``inv_sqrt_psd``.

    Returns:
        ``(K_user, K_arm)``: the first element returned by ``inv_sqrt_psd`` on
        the ridge Laplacian (presumably its inverse square root -- confirm
        against that helper), and the base kernel evaluated on ``(X, X)``.

    Raises:
        KeyError: If ``base_kernel`` is not a key of ``KERNELS``.
    """
    ridge_laplacian = laplacian_with_ridge(graph_W, rho=rho)
    # inv_sqrt_psd returns a pair; only its first element is used here.
    user_kernel, _unused = inv_sqrt_psd(ridge_laplacian, jitter=1e-8, r_trunc=r_trunc)
    kernel_fn = KERNELS[base_kernel]
    arm_kernel = kernel_fn(X, X, lengthscale=lengthscale)
    return user_kernel, arm_kernel


def regime2A_gp_draw(n: int, M: int, K_user: np.ndarray, K_arm: np.ndarray, seed: int = 0) -> np.ndarray:
    """Draw one Gaussian sample with user covariance ``K_user`` and arm covariance ``K_arm``.

    The sample is ``F = K_user^{1/2} @ Z @ K_arm^{1/2}``, where ``Z`` is an
    (n, M) matrix of iid standard normals and ``^{1/2}`` is the symmetric
    square root obtained from an eigendecomposition. Rows of ``F`` are then
    correlated according to ``K_user`` and columns according to ``K_arm``.

    Args:
        n: Number of users; ``K_user`` must be (n, n).
        M: Number of arms; ``K_arm`` must be (M, M).
        K_user: Symmetric PSD user-side kernel matrix.
        K_arm: Symmetric PSD arm-side kernel matrix.
        seed: Seed for the NumPy random generator.

    Returns:
        The (n, M) sample matrix ``F``.
    """
    rng = np.random.default_rng(seed)
    wu, Uu = np.linalg.eigh(K_user)
    wa, Ua = np.linalg.eigh(K_arm)
    # Round-off can make eigenvalues of a PSD matrix slightly negative; clip
    # them at zero so the square roots stay real.
    sqrt_wu = np.sqrt(np.maximum(wu, 0.0))
    sqrt_wa = np.sqrt(np.maximum(wa, 0.0))
    Z = rng.normal(size=(n, M))
    # Left factor: K_user^{1/2} @ Z = Uu diag(sqrt(wu)) Uu^T Z.
    F = Uu @ (sqrt_wu[:, None] * (Uu.T @ Z))
    # Right factor: F @ K_arm^{1/2} = F Ua diag(sqrt(wa)) Ua^T. The trailing
    # Ua.T rotates back out of the eigenbasis; without it the column covariance
    # would be diag(wa) rather than K_arm.
    F = ((F @ Ua) * sqrt_wa[None, :]) @ Ua.T
    return F


def regime2B_representer_draw(n: int, M: int, K_user: np.ndarray, K_arm: np.ndarray, tau: float = 1.0, seed: int = 0) -> np.ndarray:
    """Draw a function in representer form, ``F = K_user @ A @ K_arm``.

    ``A`` is an (n, M) coefficient matrix with iid zero-mean Gaussian entries
    of standard deviation ``tau / sqrt(n * M)``.

    Args:
        n: Number of rows of the coefficient matrix; ``K_user`` must have n columns.
        M: Number of columns of the coefficient matrix; ``K_arm`` must have M rows.
        K_user: User-side kernel matrix (left factor).
        K_arm: Arm-side kernel matrix (right factor).
        tau: Overall scale of the random coefficients.
        seed: Seed for the NumPy random generator.

    Returns:
        The matrix product ``K_user @ A @ K_arm``.
    """
    coeff_std = tau / np.sqrt(n * M)
    generator = np.random.default_rng(seed)
    coeffs = generator.normal(scale=coeff_std, size=(n, M))
    # Multiply left-to-right, (K_user @ coeffs) @ K_arm.
    left_applied = K_user @ coeffs
    return left_applied @ K_arm
