
from __future__ import annotations

from dataclasses import dataclass
from typing import Dict, Iterable, Optional, Sequence, Union

import numpy as np
from scipy import linalg
from scipy.special import expit
from sklearn.gaussian_process.kernels import Matern, RBF
from sklearn.metrics.pairwise import pairwise_kernels

# Numba is an optional accelerator. When it cannot be imported, ``njit`` is
# replaced by a no-op decorator so that the kernels below run as plain Python.
try:
    from numba import njit

    NUMBA_AVAILABLE = True
except Exception:  # deliberately broad: any import-time failure means "no numba"
    NUMBA_AVAILABLE = False

    def njit(*args, **kwargs):
        """No-op stand-in for ``numba.njit``.

        Supports both decorator spellings used in this file: the bare form
        ``@njit`` and the configured form ``@njit(fastmath=True, ...)``.
        The returned function also gets a ``py_func`` attribute pointing to
        itself, mirroring the attribute that real numba dispatchers expose.
        """

        def decorator(func):
            func.py_func = func
            return func

        # Bare usage: ``@njit`` passes the decorated function as the only argument.
        if len(args) == 1 and not kwargs and callable(args[0]):
            return decorator(args[0])
        return decorator


@njit(fastmath=True, cache=True)
def _sigmoid_nb(x: np.ndarray) -> np.ndarray:
    """Logistic function kernel.

    The input is clamped to [-500, 500] before exponentiation so that
    ``np.exp`` cannot overflow.
    """
    floored = np.maximum(x, -500.0)
    clamped = np.minimum(floored, 500.0)
    denominator = 1.0 + np.exp(-clamped)
    return 1.0 / denominator


def sigmoid(x):
    """Return the logistic function of ``x`` evaluated element-wise.

    ``x`` is converted to a float64 array. The compiled kernel is used when
    numba is available; otherwise SciPy's ``expit`` is applied to the input
    clipped to [-500, 500].
    """
    values = np.asarray(x, dtype=np.float64)
    if not NUMBA_AVAILABLE:
        clipped = np.clip(values, -500.0, 500.0)
        return expit(clipped)
    return _sigmoid_nb(values)


@njit(fastmath=True, cache=True)
def _dsigmoid_nb(x: np.ndarray) -> np.ndarray:
    """Derivative of the logistic function, computed as ``s * (1 - s)``."""
    prob = _sigmoid_nb(x)
    complement = 1.0 - prob
    return prob * complement


def dsigmoid(x):
    """Return the derivative of the logistic function at ``x``, element-wise.

    ``x`` is converted to a float64 array. The compiled kernel is used when
    numba is available; otherwise the value is built from ``sigmoid``.
    """
    values = np.asarray(x, dtype=np.float64)
    if not NUMBA_AVAILABLE:
        prob = sigmoid(values)
        return prob * (1.0 - prob)
    return _dsigmoid_nb(values)


@njit(fastmath=True, cache=True)
def _weighted_norm_nb(x: np.ndarray, A: np.ndarray) -> float:
    """Kernel returning ``sqrt(x^T A x)``."""
    row = x @ A
    quadratic_form = row @ x
    return float(np.sqrt(quadratic_form))


def weighted_norm(x: np.ndarray, A: np.ndarray) -> float:
    """Return ``sqrt(x^T A x)`` for a vector ``x`` and a matrix ``A``.

    Both arguments are converted to float64 arrays first. No check is made
    that ``A`` is positive semi-definite.
    """
    vec = np.asarray(x, dtype=np.float64)
    mat = np.asarray(A, dtype=np.float64)
    if not NUMBA_AVAILABLE:
        quadratic_form = vec @ mat @ vec
        return float(np.sqrt(quadratic_form))
    return _weighted_norm_nb(vec, mat)


@njit(fastmath=True, cache=True)
def _gauss_sample_nb(center: np.ndarray, design: np.ndarray, radius: float) -> np.ndarray:
    """Kernel drawing ``center + radius * design^{-1/2} z`` with ``z ~ N(0, I)``.

    With the Cholesky factorisation ``design = L L^T``, the perturbation
    ``L^{-T} z`` has covariance ``L^{-T} L^{-1} = design^{-1}``. Solving
    against ``L`` instead of ``L^T`` would give covariance ``(L^T L)^{-1}``,
    which differs from ``design^{-1}`` unless ``design`` is diagonal.
    """
    L = np.linalg.cholesky(design)
    z = np.random.normal(0.0, 1.0, center.size)
    # Contiguous copy of L^T keeps the compiled linear solve on a supported layout.
    upper = np.ascontiguousarray(L.T)
    delta = np.linalg.solve(upper, z) * radius
    return center + delta


def gaussian_sample_ellipsoid(center: np.ndarray, design: np.ndarray, radius: float) -> np.ndarray:
    """Draw one Gaussian sample around ``center`` shaped by ``design``.

    The sample is ``center + radius * delta`` where ``delta`` is Gaussian with
    zero mean and covariance ``design^{-1}``.

    Args:
        center: Centre of the distribution, converted to a float64 array.
        design: Symmetric positive-definite matrix; ``np.linalg.cholesky``
            raises ``LinAlgError`` otherwise.
        radius: Scale applied to the perturbation.

    Returns:
        A new array with the same shape as ``center``.

    Note:
        When numba is available the draw happens inside compiled code, which
        uses numba's own random state rather than NumPy's global one, so
        ``np.random.seed`` called from interpreted code does not control it.
    """
    center = np.asarray(center, dtype=np.float64)
    design = np.asarray(design, dtype=np.float64)
    # With the no-op njit fallback the kernel is ordinary Python, so a single
    # code path serves both configurations.
    return _gauss_sample_nb(center, design, float(radius))


@njit(fastmath=True, cache=True)
def _proj_nb(x: np.ndarray, center: np.ndarray, A: np.ndarray, r: float) -> np.ndarray:
    """Pull ``x`` back into the ellipsoid ``(x - c)^T A^{-1} (x - c) <= r^2``.

    Points already inside (up to a 1e-12 slack on the squared distance) are
    returned unchanged: the same array object, not a copy. Points outside are
    moved along the ray from ``center`` until they sit on the boundary.

    The scaling factor solves ``u2 / (1 + lam)^2 = r^2`` for ``1 / (1 + lam)``,
    which has the closed form ``sqrt(r^2 / u2)``, so no iteration is needed
    and ``r == 0`` maps exactly onto ``center``.

    NOTE(review): this is a radial shrink. It coincides with the Euclidean
    nearest-point projection only when ``A`` is a multiple of the identity.
    Confirm that this is what callers expect.
    """
    y = x - center
    L = np.linalg.cholesky(A)
    u = np.linalg.solve(L, y)
    # u.u == y^T A^{-1} y because A = L L^T.
    u2 = np.dot(u, u)
    r2 = r * r
    if u2 <= r2 + 1e-12:
        return x
    shrink = np.sqrt(r2 / u2)
    return center + y * shrink


def project_onto_ellipsoid(x: np.ndarray, center: np.ndarray, A: np.ndarray, r: float) -> np.ndarray:
    """Map ``x`` into the ellipsoid ``(x - center)^T A^{-1} (x - center) <= r^2``.

    Thin wrapper that converts the inputs to float64 arrays and delegates to
    ``_proj_nb``.

    Args:
        x: Point to map.
        center: Centre of the ellipsoid.
        A: Symmetric positive-definite shape matrix; the kernel's Cholesky
            factorisation raises ``LinAlgError`` otherwise.
        r: Radius of the ellipsoid.

    Returns:
        The mapped point. When ``x`` is already inside, the converted input
        array itself may be returned rather than a copy.
    """
    x_arr = np.asarray(x, dtype=np.float64)
    center_arr = np.asarray(center, dtype=np.float64)
    A_arr = np.asarray(A, dtype=np.float64)
    # ``_proj_nb`` is a plain function when numba is missing, so it is called
    # directly in both configurations; ``.py_func`` exists only on real numba
    # dispatchers and must not be relied on here.
    return _proj_nb(x_arr, center_arr, A_arr, float(r))


@njit(fastmath=True, cache=True)
def _pgd_nb(
    arm: np.ndarray,
    theta0: np.ndarray,
    V: np.ndarray,
    V_inv: np.ndarray,
    S: float,
    steps: int,
    mode: int,
    reward: float,
) -> np.ndarray:
    """Projected gradient descent for one online logistic update.

    Works in the whitened coordinates ``z = L^T theta`` with ``V = L L^T``:

    * ``||z - z0||^2`` equals ``(theta - theta0)^T V (theta - theta0)``;
    * ``z . (L^{-1} arm)`` equals ``theta . arm``, so the logistic prediction
      is evaluated at the true linear score;
    * ``||theta||^2`` equals ``z^T (L^T L)^{-1} z``, so the constraint
      ``||theta|| <= S`` is the ellipsoid with shape matrix ``L^T L``.

    Pairing ``chol(V)`` with ``chol(V_inv)`` does not satisfy the second
    identity unless ``V`` is diagonal, which is why ``V_inv`` is not used.
    It stays in the signature for compatibility.

    ``mode == 0`` uses the gradient coefficient ``pred - reward``; any other
    mode uses ``2 * pred - 1``.
    """
    L = np.linalg.cholesky(V)
    upper = np.ascontiguousarray(L.T)
    z0 = upper @ theta0
    z = z0.copy()
    arm_w = np.linalg.solve(L, arm)
    # Shape matrix of {z : ||theta|| <= S} in whitened coordinates.
    ball = upper @ L
    origin = np.zeros_like(z0)
    step = 0.5
    for _ in range(steps):
        pred = _sigmoid_nb(np.dot(z, arm_w))
        coef = (pred - reward) if mode == 0 else (2.0 * pred - 1.0)
        grad = z - z0 + coef * arm_w
        z -= step * grad
        z[:] = _proj_nb(z, origin, ball, S)
    # Undo the whitening: theta = L^{-T} z.
    return np.linalg.solve(upper, z)


def _online_logistic_driver(
    arm: np.ndarray,
    theta0: np.ndarray,
    V: np.ndarray,
    V_inv: np.ndarray,
    S: float,
    precision: float,
    mode: int,
    reward: float,
) -> np.ndarray:
    """Choose the iteration budget and run the projected-gradient kernel.

    Args:
        arm: Feature vector of the played arm.
        theta0: Current parameter estimate, the centre of the proximal term.
        V: Symmetric positive-definite matrix weighting the proximal term.
        V_inv: Accepted for interface compatibility and forwarded unchanged;
            the kernel does not need it.
        S: Radius of the constraint ``||theta|| <= S``. Also used as the
            diameter in the step-count formula.
        precision: Target precision; a smaller value gives more iterations.
        mode: ``0`` for the reward-dependent loss, anything else for the
            reward-free variant.
        reward: Observed reward, only used when ``mode == 0``.

    Returns:
        The updated parameter estimate.
    """
    diam = S
    steps = int(np.ceil((9 / 4 + diam / 8) * np.log(max(diam / precision, 1.0))))
    # The kernel is plain Python under the no-op njit fallback, so one call
    # covers both configurations and the two paths cannot drift apart.
    return _pgd_nb(
        np.ascontiguousarray(arm, dtype=np.float64),
        np.ascontiguousarray(theta0, dtype=np.float64),
        np.ascontiguousarray(V, dtype=np.float64),
        np.ascontiguousarray(V_inv, dtype=np.float64),
        float(S),
        steps,
        int(mode),
        float(reward),
    )


def fit_online_logistic_estimate(
    arm: np.ndarray,
    reward: float,
    current_estimate: np.ndarray,
    vtilde_matrix: np.ndarray,
    vtilde_inv_matrix: np.ndarray,
    constraint_set_radius: float,
    precision: float = 1e-1,
) -> np.ndarray:
    """Update the logistic estimate from one ``(arm, reward)`` observation.

    Delegates to ``_online_logistic_driver`` with ``mode=0``, i.e. the
    variant whose gradient depends on the observed reward.
    """
    driver_kwargs = dict(
        arm=arm,
        theta0=current_estimate,
        V=vtilde_matrix,
        V_inv=vtilde_inv_matrix,
        S=constraint_set_radius,
        precision=precision,
        mode=0,
        reward=reward,
    )
    return _online_logistic_driver(**driver_kwargs)


def fit_online_logistic_estimate_bar(
    arm: np.ndarray,
    current_estimate: np.ndarray,
    vtilde_matrix: np.ndarray,
    vtilde_inv_matrix: np.ndarray,
    constraint_set_radius: float,
    precision: float = 1e-1,
) -> np.ndarray:
    """Compute the reward-free ("bar") logistic estimate for ``arm``.

    Delegates to ``_online_logistic_driver`` with ``mode=1``. The reward
    argument is a placeholder (``0.0``) that this mode does not read.
    """
    driver_kwargs = dict(
        arm=arm,
        theta0=current_estimate,
        V=vtilde_matrix,
        V_inv=vtilde_inv_matrix,
        S=constraint_set_radius,
        precision=precision,
        mode=1,
        reward=0.0,
    )
    return _online_logistic_driver(**driver_kwargs)


def mu_dot(scalar: float) -> float:
    """Return the derivative of the logistic function at ``scalar`` as a float.

    The value is wrapped in a 1x1 array because ``dsigmoid`` operates on
    arrays. The single element is then extracted with ``.item()``: calling
    ``float()`` on an array with ``ndim > 0`` is deprecated in recent NumPy.
    """
    derivative = dsigmoid(np.array([[scalar]], dtype=np.float64))
    return float(derivative.item())


def S(p: np.ndarray, A: np.ndarray, gamma: float = 0.0) -> np.ndarray:
    """Return the weighted design matrix ``sum_i p[i] * a_i a_i^T + gamma * I``.

    Args:
        p: One weight per arm (length ``k``). A column vector of shape
            ``(k, 1)`` is flattened.
        A: Arm features of shape ``(k, d)``, one arm per row.
        gamma: Ridge term added on the diagonal.

    Returns:
        A ``(d, d)`` float64 matrix.

    Raises:
        ValueError: If the number of weights differs from the number of arms.
    """
    A = np.asarray(A, dtype=np.float64)
    p = np.asarray(p, dtype=np.float64)
    k, d = A.shape
    if p.ndim != 1:
        p = p.reshape(-1)
    if p.shape[0] != k:
        raise ValueError("Length of p must match number of arms.")
    # (A^T * p) scales column i of A^T by p[i]; one matrix product then sums
    # all k rank-one terms without a Python-level loop.
    weighted_gram = (A.T * p) @ A
    return weighted_gram + gamma * np.eye(d)


def s_inv(V: np.ndarray) -> np.ndarray:
    """Invert a symmetric positive-definite matrix through its Cholesky factor.

    With ``V = L L^T`` the inverse is ``L^{-T} L^{-1}``. The triangular factor
    is inverted once and reused for both terms.

    Raises:
        numpy.linalg.LinAlgError: If ``V`` is not positive definite.
    """
    L_inv = np.linalg.inv(np.linalg.cholesky(V))
    return L_inv.T @ L_inv


def optimal_design(A: np.ndarray, p0: Optional[np.ndarray] = None, *, iterations: int = 200) -> np.ndarray:
    """Approximate a D-optimal design over the rows of ``A``.

    Each iteration moves weight towards the arm with the largest leverage
    ``a^T V(p)^{-1} a``, where ``V(p) = sum_i p[i] a_i a_i^T``, using the step
    ``gamma = (v/d - 1) / (v - 1)``.

    Args:
        A: Arm features of shape ``(k, d)``, one arm per row.
        p0: Optional starting weights of length ``k``; uniform when omitted.
            The array is copied, never modified.
        iterations: Maximum number of update steps.

    Returns:
        The design weights, a length-``k`` float64 array.

    Raises:
        ValueError: If ``p0`` does not have one entry per arm.
        numpy.linalg.LinAlgError: If ``V(p)`` is not positive definite, e.g.
            when the weighted arms do not span all ``d`` dimensions.
    """
    A = np.asarray(A, dtype=np.float64)
    k, d = A.shape
    if p0 is None:
        p = np.ones(k, dtype=np.float64) / k
    else:
        p = np.array(p0, dtype=np.float64)
        if p.shape[0] != k:
            raise ValueError("Length of p must match number of arms.")
    for _ in range(iterations):
        V = (A.T * p) @ A
        L = np.linalg.cholesky(V)
        # Column j of ``whitened`` is L^{-1} a_j, so its squared norm is the
        # leverage a_j^T V^{-1} a_j.
        whitened = np.linalg.solve(L, A.T)
        vs = np.einsum("ij,ij->j", whitened, whitened)
        i_star = int(np.argmax(vs))
        v_star = vs[i_star]
        # The maximum leverage is never below d in exact arithmetic, so
        # reaching d means the design is optimal. Stopping here also avoids
        # the 0/0 step that arises for d == 1 when v_star == 1.
        if v_star <= d:
            break
        gamma = (v_star / d - 1.0) / (v_star - 1.0)
        p *= (1.0 - gamma)
        p[i_star] += gamma
    return p


@njit(cache=True)
def optimal_design_numba(A: np.ndarray, p0: Optional[np.ndarray] = None, iterations: int = 200) -> np.ndarray:
    """Compiled variant of the D-optimal design iteration.

    The body is self-contained and uses only NumPy calls: compiled nopython
    code cannot call the interpreted helpers ``S`` and ``s_inv``.

    Args:
        A: Arm features of shape ``(k, d)``, one arm per row.
        p0: Optional starting weights of length ``k``; uniform when omitted.
            A float64 copy is used, so the input is not modified.
        iterations: Maximum number of update steps.

    Returns:
        The design weights, a length-``k`` float64 array.
    """
    k, d = A.shape
    arms = A.astype(np.float64)
    if p0 is None:
        p = np.ones(k) / k
    else:
        p = p0.astype(np.float64)
    for _ in range(iterations):
        V = np.zeros((d, d))
        for i in range(k):
            V += p[i] * np.outer(arms[i], arms[i])
        # Inverse through the Cholesky factor: V^{-1} = L^{-T} L^{-1}.
        L_inv = np.linalg.inv(np.linalg.cholesky(V))
        V_inv = np.dot(np.ascontiguousarray(L_inv.T), L_inv)
        vs = np.empty(k)
        for i in range(k):
            a = arms[i]
            vs[i] = np.dot(a, np.dot(V_inv, a))
        i_star = int(np.argmax(vs))
        v_star = vs[i_star]
        # Max leverage <= d means the design is optimal; stopping also avoids
        # the 0/0 step that arises for d == 1 when v_star == 1.
        if v_star <= d:
            break
        gamma = (v_star / d - 1.0) / (v_star - 1.0)
        p *= (1.0 - gamma)
        p[i_star] += gamma
    return p


def glm_fit(X: np.ndarray, y: np.ndarray, link: str = "logistic") -> np.ndarray:
    """Fit a logistic regression without intercept by Newton's method.

    Runs at most 30 Newton steps on the log-likelihood starting from zero and
    stops early once the gradient is numerically zero (``np.allclose``).

    Args:
        X: Design matrix of shape ``(n, d)``.
        y: Responses, ``n`` values (any shape that reshapes to ``(n,)``).
        link: Must be ``"logistic"``.

    Returns:
        The coefficient vector of length ``d``.

    Raises:
        ValueError: If ``link`` is not ``"logistic"`` or ``y`` does not hold
            exactly ``n`` values.
    """
    if link != "logistic":
        raise ValueError("Only logistic link supported.")
    X = np.asarray(X, dtype=np.float64)
    y = np.asarray(y, dtype=np.float64)
    n, d = X.shape
    theta = np.zeros(d, dtype=np.float64)
    S_vec = X.T @ y.reshape(n)
    for _ in range(30):
        # expit is stable for large |score|; the variance mu * (1 - mu) then
        # never evaluates inf / inf, unlike exp(-x) / (1 + exp(-x)) ** 2.
        mu = expit(X @ theta)
        g = S_vec - X.T @ mu
        if np.allclose(g, 0.0):
            break
        variances = mu * (1.0 - mu)
        # Hessian of the log-likelihood: -sum_i var_i * x_i x_i^T.
        H = -(X.T * variances) @ X
        theta = theta - np.linalg.pinv(H) @ g
    return theta


def G(x: np.ndarray, A: np.ndarray, gamma: float) -> np.ndarray:
    """Return the regularised design matrix ``S(x, A) + gamma * I``."""
    dim = A.shape[1]
    ridge = gamma * np.eye(dim)
    return S(x, A) + ridge


def phi(c: np.ndarray, t: float, x: np.ndarray, A: np.ndarray, gamma: float) -> float:
    """Barrier-augmented objective evaluated at ``x``.

    Returns ``t * (c.x - logdet G(x)) - sum(log x) - log(1 - sum x)``: the
    linear-minus-log-determinant objective scaled by ``t``, plus logarithmic
    barriers for ``x > 0`` and ``sum(x) < 1``.
    """
    _, logdet = np.linalg.slogdet(G(x, A, gamma))
    objective = np.sum(c * x) - logdet
    positivity_barrier = np.sum(np.log(x))
    simplex_barrier = np.log(1 - np.sum(x))
    return float(t * objective - positivity_barrier - simplex_barrier)


def J(c: np.ndarray, x: np.ndarray, A: np.ndarray, gamma: float) -> float:
    """Objective without barrier terms: ``c.x - logdet G(x)``."""
    _, logdet = np.linalg.slogdet(G(x, A, gamma))
    linear_cost = np.sum(c * x)
    return float(linear_cost - logdet)


def newton_direction(c: np.ndarray, t: float, x: np.ndarray, A: np.ndarray, gamma: float):
    """Compute the Newton step for the barrier objective at ``x``.

    Returns:
        A pair ``(direction, decrement)``: ``direction`` is minus the inverse
        Hessian applied to the gradient, and ``decrement`` is
        ``sqrt(direction^T H direction)``.
    """
    num_arms, dim = A.shape
    design = S(x, A) + gamma * np.eye(dim)
    design_inv = np.linalg.inv(design)

    # Per-arm quadratic form a^T design^{-1} a.
    leverage = np.zeros(num_arms, dtype=np.float64)
    for idx, row in enumerate(A):
        leverage[idx] = design_inv.dot(row).dot(row)

    slack = 1.0 - np.sum(x)
    gradient = t * (c - leverage) - 1.0 / x + 1.0 / slack
    hessian = (
        t * np.power(A @ design_inv @ A.T, 2)
        + 1.0 / (slack ** 2)
        + np.diag(1.0 / np.power(x, 2))
    )

    direction = -np.linalg.inv(hessian) @ gradient
    column = direction.reshape((num_arms, 1))
    decrement = np.sqrt(column.T @ hessian @ column).item()
    return direction, decrement


def generalized_eigenvalues(x: np.ndarray, direction: np.ndarray, A: np.ndarray, gamma: float):
    """Return the two eigenvalue sets used by the exact line search.

    Returns:
        A pair ``(eigen1, eigen2)``. ``eigen1`` holds the real parts of the
        generalised eigenvalues of ``S(direction, A)`` with respect to
        ``S(x, A) + gamma * I``. ``eigen2`` holds ``direction / x`` followed
        by ``-sum(direction) / (1 - sum(x))``.
    """
    dim = A.shape[1]
    numerator_matrix = S(direction, A)
    denominator_matrix = S(x, A) + gamma * np.eye(dim)
    design_eigs = np.real(linalg.eigvals(numerator_matrix, denominator_matrix))

    slack_ratio = -np.sum(direction) / (1 - np.sum(x))
    barrier_eigs = np.append(direction / x, slack_ratio)
    return design_eigs, barrier_eigs


def line_search(c: np.ndarray, t: float, x: np.ndarray, direction: np.ndarray, A: np.ndarray, gamma: float) -> float:
    """Find a step length along ``direction`` with 30 safeguarded Newton steps.

    The one-dimensional objective is expressed through the eigenvalues from
    ``generalized_eigenvalues``. A Newton update that would make any
    ``1 + h * eigenvalue`` drop to 1e-8 or below is replaced by the midpoint
    between the current ``h`` and ``-1 / min_eigenvalue``.
    """
    design_eigs, barrier_eigs = generalized_eigenvalues(x, direction, A, gamma)
    # Linear part of the derivative; it does not depend on the step length.
    linear_term = t * np.sum(c * direction)
    h = 0.0
    for _ in range(30):
        first = (
            linear_term
            - np.sum(t * design_eigs / (1 + h * design_eigs))
            - np.sum(barrier_eigs / (1 + h * barrier_eigs))
        )
        second = np.sum(t / ((1 / design_eigs + h) ** 2)) + np.sum(1 / ((1 / barrier_eigs + h) ** 2))
        increment = -first / second
        candidate = h + increment
        leaves_domain = np.any(1 + candidate * design_eigs <= 1e-8) or np.any(
            1 + candidate * barrier_eigs <= 1e-8
        )
        if leaves_domain:
            min_eig = min(np.min(design_eigs), np.min(barrier_eigs))
            h = (h - 1 / min_eig) / 2
        else:
            h = candidate
    return float(h)


def newton_optimize(c: np.ndarray, t: float, A: np.ndarray, gamma: float, x: Optional[np.ndarray] = None) -> np.ndarray:
    """Minimise the barrier objective ``phi`` for a fixed ``t`` by Newton's method.

    Starts from ``x`` (or from the uniform point ``1 / (2n)`` when omitted)
    and iterates until the Newton decrement is no larger than
    ``1e-6 / max(gamma, 1e-12) * t``. A full step is taken while the
    decrement is at most 0.5; otherwise the step comes from ``line_search``.

    Raises:
        RuntimeError: If the objective increases by more than the
            ``np.allclose`` tolerance (``atol=1e-3``), or if more than 200
            iterations are performed.
    """
    num_arms = A.shape[0]
    if x is None:
        x = np.ones(num_arms, dtype=np.float64) / (2 * num_arms)

    tolerance = 1e-6 / max(gamma, 1e-12) * t
    decrement = 1.0
    previous_phi = np.inf
    iterations = 0
    while decrement > tolerance:
        direction, decrement = newton_direction(c, t, x, A, gamma)
        if decrement > 0.5:
            step = line_search(c, t, x, direction, A, gamma)
        else:
            step = 1.0
        x = x + step * direction

        new_phi = phi(c, t, x, A, gamma)
        decreased = new_phi < previous_phi
        if not decreased and not np.allclose(new_phi, previous_phi, atol=1e-3):
            raise RuntimeError("Newton loop failed to decrease objective.")
        previous_phi = new_phi

        iterations += 1
        if iterations > 200:
            raise RuntimeError("Newton optimisation did not converge.")
    return x


def central_path(c: np.ndarray, A: np.ndarray, gamma: float) -> np.ndarray:
    """Follow the barrier central path by increasing ``t`` geometrically.

    Starts at the uniform point ``1 / (2n)`` with ``t = 1`` and re-optimises
    with ``newton_optimize`` after each 10% increase of ``t``. Iteration runs
    while ``t < 1e7``, and continues up to ``t < 1e10`` as long as
    ``sum(x) < 0.8``.
    """
    num_arms = A.shape[0]
    x = np.ones(num_arms, dtype=np.float64) / (2 * num_arms)
    t = 1.0
    while True:
        keep_going = t < 1e7 or (t < 1e10 and np.sum(x) < 0.8)
        if not keep_going:
            break
        x = newton_optimize(c, t, A, gamma, x=x)
        t *= 1.1
    return x


def OP(A: np.ndarray, delta: np.ndarray, beta: Union[np.ndarray, float], gamma: float) -> np.ndarray:
    """Solve the design problem with costs ``delta * beta / 2`` and normalise.

    The weights come from ``central_path``. As a sanity check on the solver
    output, every arm must satisfy
    ``a^T (S(p, A) + gamma I)^{-1} a <= beta * delta[a] + 2 n``
    up to ``np.allclose`` tolerance.

    Args:
        A: Arm features of shape ``(n, d)``.
        delta: One value per arm.
        beta: Scalar, or one value per arm (paired element-wise with ``delta``).
        gamma: Ridge term of the design matrix.

    Returns:
        The weights rescaled to sum to one.

    Raises:
        AssertionError: If the check above fails for some arm.
    """
    A = np.asarray(A, dtype=np.float64)
    delta = np.asarray(delta, dtype=np.float64)
    n, d = A.shape
    c = delta * beta / 2
    p = central_path(c, A, gamma)
    design = S(p, A) + gamma * np.eye(d)
    # Named ``design_inv`` so the module-level ``s_inv`` function is not shadowed.
    design_inv = np.linalg.inv(design)
    # Per-arm bound built element-wise so an array-valued ``beta`` is matched
    # with its own arm rather than broadcast into an ambiguous comparison.
    bounds = delta * beta + 2 * n
    for a in range(n):
        arm = A[[a], :]
        norm = (arm @ design_inv @ arm.T).item()
        assert norm < bounds[a] or np.allclose(norm, bounds[a]), (
            f"design constraint violated for arm {a}"
        )
    return p / np.sum(p)


def kernel_matrix(A: np.ndarray, config: Dict[str, object]) -> np.ndarray:
    """Build the pairwise kernel matrix of the rows of ``A``.

    Args:
        A: Data matrix passed to ``pairwise_kernels``.
        config: Must contain ``"name"`` (``"matern"``, ``"rbf"`` or
            ``"linear"``). Every other entry is forwarded as a keyword
            argument to the Matern/RBF constructor and ignored for the
            linear kernel.

    Raises:
        KeyError: If ``config`` has no ``"name"`` entry.
        ValueError: If the kernel name is not recognised.
    """
    kernel_name = config["name"]
    hyperparams = {key: value for key, value in config.items() if key != "name"}
    for label, factory in (("matern", Matern), ("rbf", RBF)):
        if kernel_name == label:
            return pairwise_kernels(A, metric=factory(**hyperparams))
    if kernel_name == "linear":
        return pairwise_kernels(A, metric="linear")
    raise ValueError(f"Invalid kernel: {kernel_name}")


def compute_information_gain(Phi: np.ndarray, t: int, gamma: float) -> float:
    """Return ``0.5 * logdet(I + (t / gamma) * S(p, Phi))``.

    ``p`` is the design returned by ``optimal_design(Phi)``.
    """
    design_weights = optimal_design(Phi)
    dim = Phi.shape[1]
    scaled_design = S(design_weights, Phi) * (t / gamma) + np.eye(dim)
    _, logdet = np.linalg.slogdet(scaled_design)
    return float(logdet / 2)


@dataclass
class InformationGain:
    """Cached information-gain values for horizons up to ``T``.

    Exact values are computed with ``compute_information_gain`` at every power
    of two not exceeding ``T`` when the object is created. ``get`` interpolates
    linearly between those nodes; ``get_exact`` always returns a computed
    value.
    """

    # Feature matrix forwarded to ``compute_information_gain``.
    Phi: np.ndarray
    # Largest horizon for which power-of-two nodes are precomputed.
    T: int
    # Regularisation forwarded to ``compute_information_gain``.
    gamma: float

    def __post_init__(self):
        # Exactly computed values only; ``get_exact`` relies on this.
        self.cache: Dict[int, float] = {}
        # Interpolated values are kept apart so that they can never be
        # returned by ``get_exact``.
        self._interp_cache: Dict[int, float] = {}
        self.precomputed_powers: Dict[int, float] = {}
        t_power = 1
        while t_power <= self.T:
            self.precomputed_powers[t_power] = self.get_or_compute(t_power)
            t_power *= 2

    def get_or_compute(self, t: int) -> float:
        """Return the exact value at ``t``, computing and caching it if needed."""
        if t in self.cache:
            return self.cache[t]
        info = compute_information_gain(self.Phi, t, self.gamma)
        self.cache[t] = info
        return info

    def get(self, t: int) -> float:
        """Return the value at ``t``, interpolating between bracketing nodes.

        An exact cached value is returned when one exists. Otherwise the
        result is the linear interpolation between the largest power of two
        not exceeding ``t`` and ``min(T, 2 * that power)``.

        Raises:
            ValueError: If ``t`` is smaller than 1.
        """
        if t in self.cache:
            return self.cache[t]
        if t in self._interp_cache:
            return self._interp_cache[t]
        if t < 1:
            raise ValueError("t must be at least 1.")
        # Integer arithmetic: float log2 can round up just below a large
        # power of two and pick the wrong bracket.
        t_left = 1 << (int(t).bit_length() - 1)
        dim_left = self.get_or_compute(t_left)
        if t_left == t:
            return dim_left
        t_right = min(self.T, t_left * 2)
        if t_right == t_left:
            # t lies beyond T while T is itself the left node, so there is no
            # second node to interpolate with; use the exact value instead of
            # dividing by zero.
            return self.get_or_compute(t)
        dim_right = self.get_or_compute(t_right)
        p = (t - t_left) / (t_right - t_left)
        result = dim_left * (1 - p) + dim_right * p
        self._interp_cache[t] = result
        return result

    def get_exact(self, t: int) -> float:
        """Return the exactly computed value at ``t`` (never an interpolation)."""
        return self.get_or_compute(t)


# Names exported by ``from <module> import *``.
# NOTE(review): ``generalized_eigenvalues`` is defined above without a leading
# underscore but is not listed here; confirm the omission is intentional.
__all__ = [
    "sigmoid",
    "dsigmoid",
    "weighted_norm",
    "gaussian_sample_ellipsoid",
    "project_onto_ellipsoid",
    "fit_online_logistic_estimate",
    "fit_online_logistic_estimate_bar",
    "mu_dot",
    "S",
    "s_inv",
    "optimal_design",
    "optimal_design_numba",
    "glm_fit",
    "G",
    "phi",
    "J",
    "newton_direction",
    "line_search",
    "newton_optimize",
    "central_path",
    "OP",
    "kernel_matrix",
    "InformationGain",
    "compute_information_gain",
]
