import jax.numpy as jnp


def generalized_inv(evs, V, n):
    """Return the pseudo-inverse of ``V @ diag(evs) @ V.T`` and its numerical rank.

    Eigenvalues below ``tol = max(evs) * V.shape[0] * eps`` are treated as
    zero, i.e. their reciprocal is replaced by 0.  A single mask
    (``evs >= tol``) decides both the returned rank and which eigenvalues are
    inverted, so the two always agree.

    Args:
        evs: eigenvalues, one per column of ``V``.
        V: eigenvectors stored as columns.
        n: number of eigenpairs.  Not needed by the computation any more
            (the full-rank case is just the mask being all True); kept so
            existing callers keep working.

    Returns:
        ``(inv, r)`` where ``inv = V @ diag(evs_inv) @ V.T`` and ``r`` is the
        number of eigenvalues kept.
    """
    tol = evs.max() * V.shape[0] * jnp.finfo(V.dtype).eps
    keep = evs >= tol
    r = jnp.sum(keep)

    # Inner where avoids evaluating 1/0 on dropped eigenvalues (keeps
    # gradients free of NaN); outer where zeroes their contribution.
    evs_inv = jnp.where(keep, 1 / jnp.where(keep, evs, 1), 0)

    # (V * d) @ V.T == V @ diag(d) @ V.T without materialising diag(d).
    inv = (V * evs_inv) @ V.T

    return inv, r


def leave_one_out(evs, V, r, n, y, t=None, reg=None):
    """Closed-form leave-one-out error ``(1/n) * sum_i ((A y)_i / A_ii)**2``.

    The matrix ``A`` depends on the regime:

    * ``t`` is None or inf, rank-deficient (``r < n``): ``A`` is the projector
      onto the span of ``V[:, r:]`` (the discarded eigen-directions).
    * ``t`` is None or inf, full rank: ``A = V diag(1/evs) V.T``.
    * finite ``t``: ``A = V diag(exp(-evs_eff * t)) V.T`` where ``evs_eff``
      equals ``evs`` for the first ``r`` entries and 0 beyond.  For
      ``r == n`` this is ``V diag(exp(-evs t)) V.T``; as ``t -> inf`` it
      tends to the projector used in the first case.

    Args:
        evs: eigenvalues; assumed sorted so that the first ``r`` are the kept
            ones (the code slices ``V[:, r:]`` as the discarded part).
        V: eigenvectors as columns, in the same order as ``evs``.
        r: numerical rank (number of kept eigenvalues).
        n: number of training points; used for the rank test and the mean.
        y: targets; the first axis indexes training points.
        t: training time; None or ``jnp.inf`` selects the ``t -> inf`` case.
        reg: must be None (ridge regularisation is not handled here).

    Returns:
        ``(1/n) * beta.T @ (A @ y)**2`` with ``beta = 1 / diag(A)**2``.

    Raises:
        NotImplementedError: if ``reg`` is not None.  Previously this path
            fell through to an undefined ``res``.
    """
    if reg is not None:
        raise NotImplementedError(
            "leave_one_out only supports reg=None; use leave_one_out_reg for "
            "ridge regression."
        )

    if t is None or t == jnp.inf:
        if r <= n - 1:
            V_red = V[:, r:]
            A = V_red @ V_red.T
        else:
            A = (V * (1 / evs)) @ V.T
    else:
        # Treat eigenvalues beyond the numerical rank as exactly zero, so
        # exp(-0 * t) = 1 keeps the null space (it is never fitted).
        evs_eff = jnp.where(jnp.arange(evs.shape[0]) < r, evs, 0)
        A = (V * jnp.exp(-evs_eff * t)) @ V.T

    beta = 1 / jnp.diag(A) ** 2
    return 1 / n * beta.T @ (A @ y) ** 2


def acc(targets, preds):
    """Sign accuracy: fraction of entries where ``sign(preds) == targets``.

    The average is taken over axis 0.  Presumably ``targets`` holds +/-1
    labels (TODO confirm); a prediction of exactly 0 has sign 0 and never
    matches such a label.
    """
    hits = jnp.sign(preds) == targets
    return hits.mean(axis=0)


def acc_multi(targets, preds):
    """Multi-class accuracy comparing arg-max classes along axis 2.

    Class indices are taken with ``argmax`` over axis 2 of both arrays, and
    the match rate is averaged over axis 1.  This implies 3-D inputs; the
    layout is presumably (runs, samples, classes), TODO confirm.
    """
    true_cls = jnp.argmax(targets, axis=2)
    pred_cls = jnp.argmax(preds, axis=2)
    return (true_cls == pred_cls).mean(axis=1)


def leave_one_out_reg(kernel_train_train, preds, y_train, reg):
    """Closed-form leave-one-out MSE for kernel ridge regression.

    With hat matrix ``H = K (K + reg I)^{-1}`` the LOO residual of sample i is
    ``(y_i - preds_i) / (1 - H_ii)``; the squared LOO residuals are averaged
    over all entries.

    Args:
        kernel_train_train: square train/train kernel matrix ``K`` (n x n).
        preds: in-sample predictions; first axis indexes the n samples.
        y_train: targets with the same shape as ``preds`` (1-D or N-D).
        reg: ridge parameter.

    Returns:
        Scalar mean of the squared leave-one-out residuals.
    """
    n = kernel_train_train.shape[0]
    # I - K (K + reg I)^{-1} == reg (K + reg I)^{-1} exactly, so 1 - H_ii is
    # obtained directly instead of via the cancellation-prone 1 - diag(H).
    one_minus_h = reg * jnp.diag(jnp.linalg.inv(kernel_train_train + reg * jnp.eye(n)))

    residuals = y_train - preds
    # Align the per-sample term with the sample axis of the residuals; a fixed
    # (-1, 1) reshape would broadcast 1-D residuals into an (n, n) matrix.
    one_minus_h = jnp.reshape(one_minus_h, (n,) + (1,) * (residuals.ndim - 1))

    cross = jnp.mean((residuals / one_minus_h) ** 2)
    return cross


def predict_fn(x, t, kernel_fn, data):
    """Kernel ('ntk') predictions at ``x`` after training for time ``t``.

    Computes ``K_x V diag(f(lambda)) V.T y_train`` with the spectral filter
    ``f(lambda) = (1 - exp(-lambda t)) / lambda`` for finite ``t``, and
    ``f(lambda) = 1 / lambda`` when ``t`` is None or ``jnp.inf``.  Eigenvalues
    below ``tol = max(eig) * n * eps`` are treated as zero (``f = 0``), so the
    rank-deficient case uses the pseudo-inverse.

    Args:
        x: query inputs.
        t: training time; None or ``jnp.inf`` means the fully-trained limit.
        kernel_fn: called as ``kernel_fn(a, b, 'ntk')`` to build kernel matrices.
        data: object exposing ``x_train`` and ``y_train``.

    Returns:
        Predictions ``kernel_fn(x, x_train, 'ntk') @ (...) @ y_train``.
    """
    kernel_train_train = kernel_fn(data.x_train, data.x_train, 'ntk')
    kernel_train_x = kernel_fn(x, data.x_train, 'ntk')

    eigh, V = jnp.linalg.eigh(kernel_train_train)
    # eigh returns ascending order; flip to descending.
    eigh = jnp.flip(eigh)
    V = jnp.flip(V, axis=1)

    tol = eigh.max() * V.shape[0] * jnp.finfo(V.dtype).eps
    keep = eigh >= tol
    # Sub-tolerance eigenvalues are set to exactly 0 (avoids 1/0 and the
    # exp overflow a tiny negative eigenvalue would cause for large t).
    eigh_kept = jnp.where(keep, eigh, 0)
    eigh_inv = jnp.where(keep, 1 / jnp.where(keep, eigh, 1), 0)

    # `and`, not `or`: t=None must not reach the comparison, and t=inf must
    # take the closed-form limit instead of evaluating exp(-0 * inf) = NaN.
    if t is not None and t < jnp.inf:
        spectral = eigh_inv * (1 - jnp.exp(-eigh_kept * t))
    else:
        spectral = eigh_inv

    # (V * s) @ V.T == V @ diag(s) @ V.T
    pred = kernel_train_x @ ((V * spectral) @ (V.T @ data.y_train))

    return pred
