import numpy as np
from numba import njit
from scipy import linalg

@njit
def S(p, A, gamma=0):
    """Return the weighted Gram matrix ``A.T @ diag(p) @ A + gamma * I``.

    Parameters
    ----------
    p : 1-D array of per-row weights; it is broadcast as a column against
        the rows of ``A``.
    A : 2-D array whose rows are the vectors being weighted.
    gamma : scalar multiple of the identity added to the result. The
        default of 0 adds nothing.

    Returns
    -------
    Square array of side ``A.shape[1]``.
    """
    dim = A.shape[1]
    # Scale row i of A by p[i] before forming the Gram product.
    weighted_rows = p[:, None] * A
    gram = A.T @ weighted_rows
    return gram + gamma * np.eye(dim)

@njit
def s_inv(V):
    """Invert ``V`` through its Cholesky factorisation.

    With ``V = C @ C.T`` (``C`` lower triangular), the inverse is
    ``C^{-T} @ C^{-1}``. ``np.linalg.cholesky`` is applied to ``V`` directly,
    so ``V`` must be a matrix that routine accepts.

    Parameters
    ----------
    V : square 2-D array.

    Returns
    -------
    Array of the same shape as ``V`` holding its inverse.
    """
    chol = np.linalg.cholesky(V)
    identity = np.eye(chol.shape[0])
    # Solving C X = I yields X = C^{-1}.
    chol_inv = np.linalg.solve(chol, identity)
    return chol_inv.T @ chol_inv

@njit
def optimal_design_numba(A, p0=None, max_iter=200):
    """Iteratively compute design weights over the rows of ``A``.

    Each iteration forms ``V = S(p, A)``, evaluates ``v_i = a_i^T V^{-1} a_i``
    for every row ``a_i``, and moves mass towards the row with the largest
    ``v_i`` using the step size ``(v*/d - 1) / (v* - 1)``. This is the
    Frank-Wolfe style update commonly used for D-/G-optimal design.

    Parameters
    ----------
    A : 2-D array of shape ``(k, d)``; one candidate vector per row.
    p0 : optional 1-D array of ``k`` starting weights. It is copied, so the
        caller's array is left untouched. Defaults to uniform weights.
    max_iter : maximum number of update steps (default 200).

    Returns
    -------
    1-D array of ``k`` weights.
    """
    k, d = A.shape
    if p0 is None:
        p = np.ones(k) / k
    else:
        # The updates below are in place; work on a copy so the caller's
        # starting weights are not silently overwritten.
        p = p0.copy()

    for _ in range(max_iter):
        V_inv = s_inv(S(p, A))
        vs = np.empty(k)
        for i in range(k):
            a = A[i]
            vs[i] = a @ V_inv @ a
        i_star = np.argmax(vs)
        v_star = vs[i_star]
        # sum_i p_i * v_i = trace(V^{-1} V) = d, so max_i v_i >= d with
        # equality at the optimum. Stopping here avoids a zero step, a
        # negative step caused by round-off, and the division by zero that
        # occurs when v_star == 1 (possible for d == 1).
        if v_star <= d:
            break
        gamma = (v_star / d - 1) / (v_star - 1)
        p *= (1 - gamma)
        p[i_star] += gamma

    return p


@njit
def rbf_kernel_numba(X, length_scale):
    """Return the RBF (Gaussian) kernel matrix of the rows of ``X``.

    ``K[i, j] = exp(-||x_i - x_j||^2 / (2 * length_scale**2))``, with the
    squared distances obtained from the expansion
    ``||x_i||^2 - 2 x_i.x_j + ||x_j||^2``.

    Parameters
    ----------
    X : 2-D array with one sample per row.
    length_scale : scalar kernel length scale.

    Returns
    -------
    Square array of side ``X.shape[0]``.
    """
    # Row norms are needed on both sides of the expansion; compute them once.
    row_sq_norms = np.sum(X**2, axis=1)
    sq_dists = row_sq_norms[:, None] - 2 * X @ X.T + row_sq_norms[None, :]
    # Cancellation in the expansion can leave tiny negative "distances",
    # which would produce kernel values above 1. Clamp them to zero.
    sq_dists = np.maximum(sq_dists, 0.0)
    # A point is at distance exactly 0 from itself, so K[i, i] must be 1.
    for i in range(sq_dists.shape[0]):
        sq_dists[i, i] = 0.0
    return np.exp(-sq_dists / (2 * length_scale ** 2))

def kernel_matrix(A, config):
    """Build the kernel matrix of the rows of ``A`` selected by ``config``.

    Parameters
    ----------
    A : 2-D array with one sample per row.
    config : mapping with a ``'name'`` key. ``'linear'`` returns ``A @ A.T``;
        ``'rbf'`` additionally reads ``config['length_scale']`` and delegates
        to ``rbf_kernel_numba``.

    Returns
    -------
    The kernel matrix for the requested kernel.

    Raises
    ------
    KeyError
        If ``'name'`` (or ``'length_scale'`` for the RBF kernel) is missing.
    ValueError
        If the kernel name is not recognised.
    """
    name = config['name']
    if name == 'linear':
        return A @ A.T
    if name == 'rbf':
        return rbf_kernel_numba(A, config['length_scale'])
    # ValueError is a subclass of Exception, so existing handlers that catch
    # Exception still work; the message is unchanged.
    raise ValueError(f'Invalid kernel: {name}')

@njit
def compute_information_gain(Phi, t, gamma):
    """Return ``0.5 * logdet(I + (t / gamma) * S(p, Phi))``.

    ``p`` is the weight vector returned by ``optimal_design_numba(Phi)``.

    Parameters
    ----------
    Phi : 2-D array with one feature vector per row.
    t : scalar multiplier applied to the weighted Gram matrix.
    gamma : scalar divisor applied to ``t``; must be non-zero.

    Returns
    -------
    Half of the log-determinant reported by ``np.linalg.slogdet``.
    """
    dim = Phi.shape[1]
    design = optimal_design_numba(Phi)
    scale = t / gamma
    regularised = S(design, Phi) * scale + np.eye(dim)
    # Only the log-magnitude is used; the sign is discarded.
    _sign, log_det = np.linalg.slogdet(regularised)
    return log_det / 2

class InformationGain:
    """Cached information-gain values with interpolation between powers of two.

    On construction the exact value is computed for every power of two that
    does not exceed ``T``. ``get`` then answers other ``t`` by linear
    interpolation between the neighbouring anchor points, while
    ``get_exact`` / ``get_or_compute`` always return exactly computed values.
    """

    def __init__(self, Phi, T, gamma):
        """Store the inputs and precompute the values at powers of two <= T.

        ``Phi`` and ``gamma`` are passed through unchanged to
        ``compute_information_gain``; ``T`` is the upper anchor used by ``get``.
        """
        self.Phi = Phi
        self.T = T
        self.gamma = gamma
        # Exactly computed values only, keyed by t.
        self.cache = {}
        # Interpolated values are kept apart so they can never be returned
        # where an exact value was requested.
        self._interpolated = {}
        # Exact values at 1, 2, 4, ... <= T.
        self.precomputed_powers = {}
        if T >= 1:
            top = self._floor_power_of_two(T)
            t_power = 1
            while t_power <= top:
                self.precomputed_powers[t_power] = self.get_or_compute(t_power)
                t_power *= 2

    @staticmethod
    def _floor_power_of_two(t):
        """Return the largest power of two that is <= ``t`` (requires t >= 1)."""
        if isinstance(t, (int, np.integer)):
            # Exact for arbitrarily large integers, unlike a float log2.
            return 1 << (int(t).bit_length() - 1)
        return 2 ** int(np.log2(t))

    def get_or_compute(self, t):
        """Return the exact value at ``t``, computing and caching it if needed."""
        if t in self.cache:
            return self.cache[t]
        information_gain = compute_information_gain(self.Phi, t, self.gamma)
        self.cache[t] = information_gain
        return information_gain

    def get(self, t):
        """Return the value at ``t``, interpolating when it is not cached.

        An exact cached value is returned when available. Otherwise ``t`` is
        bracketed by the largest power of two ``t_left <= t`` and
        ``t_right = min(T, 2 * t_left)``, and the result is the linear
        interpolation of the exact values at those two points. When no valid
        bracket exists (``t < 1``, or ``T <= t_left``) the exact value is
        computed instead. For ``t_left < T < t`` the interpolation weight
        exceeds 1, i.e. the value is linearly extrapolated past ``T``.
        """
        if t in self.cache:
            return self.cache[t]
        if t in self._interpolated:
            return self._interpolated[t]
        if t < 1:
            # No power of two lies at or below t, so there is no left anchor.
            return self.get_or_compute(t)

        t_left = self._floor_power_of_two(t)
        if t_left == t:
            return self.get_or_compute(t_left)

        t_right = min(self.T, t_left * 2)
        if t_right <= t_left:
            # T does not lie beyond t_left: the bracket is empty or inverted
            # (this used to divide by zero or give a negative weight).
            return self.get_or_compute(t)

        dimension_left = self.get_or_compute(t_left)
        dimension_right = self.get_or_compute(t_right)
        p = (t - t_left) / (t_right - t_left)
        result = dimension_left * (1 - p) + dimension_right * p
        self._interpolated[t] = result
        return result

    def get_exact(self, t):
        """Return the exactly computed value at ``t`` (never an interpolation)."""
        return self.get_or_compute(t)
