import jax.numpy as jnp
from jax import random, jit
from neural_tangents import stax
import neural_tangents as nt
import os


# Directory named "store" located next to this module; the save_kernel
# methods below write their kernel matrices underneath it.
save_path = '{}/store'.format(os.path.dirname(os.path.realpath(__file__)))


class RandomFeatures:
    """Least-squares regression on random absolute-value features.

    The feature map is ``phi(x) = |x @ Z| / sqrt(m)`` where ``Z`` has shape
    ``(dim, m)`` and is drawn from a standard normal unless supplied.

    Two internal representations are used, selected by ``m <= s``:

    * feature space (``m <= s``): the train feature matrix is decomposed with
      a full SVD, ``feat_train = V @ svs @ U``; ``evs`` / ``evs_inv`` are
      ``(n_train, n_train)`` diagonal matrices.
    * kernel space (``m > s``): the train Gram matrix is eigendecomposed,
      ``kernel_train = V @ diag(evs) @ V.T`` with eigenvalues in descending
      order; ``evs`` / ``evs_inv`` are 1-D vectors.

    Args:
        m: number of random features.
        dim: input dimension (number of rows of ``Z``).
        key: JAX PRNG key used to draw ``Z`` when ``Z`` is None.
        x_train: training inputs; ``x_train.shape[0]`` is taken as the number
            of training points.
        x_test: test inputs.
        reg: ridge parameter added to the Gram eigenvalues, or None for the
            pseudo-inverse solution.
        Z: optional pre-drawn feature matrix of shape ``(dim, m)``.
        s: largest ``m`` for which the feature-space representation is used.
    """

    def __init__(self, m, dim, key, x_train, x_test, reg=None, Z=None, s=100000):
        self.m = m
        self.dim = dim
        self.key = key
        self.x_train = x_train
        self.x_test = x_test
        self.reg = reg
        self.Z = Z

        self.n_train = self.x_train.shape[0]
        self.s = s
        self.min = jnp.minimum(self.n_train, self.m)

        # Initialize random features
        if self.Z is None:
            self.Z = random.normal(key=self.key, shape=(self.dim, self.m))

        if self.m <= self.s:
            self._init_feature_space()
        else:
            self._init_kernel_space()
        print(self.r)

    def _init_feature_space(self):
        """Cache train/test features, their SVD and the (inverse) Gram spectrum."""
        # Compute the feature maps once on the training and test datasets
        self.feat_train = self.feature_map(self.x_train)
        self.feat_test = self.feature_map(self.x_test)

        # Full SVD: feat_train = V @ svs @ U once svs is padded below.
        self.V, self.svs, self.U = jnp.linalg.svd(self.feat_train)

        # Embed the singular values into an (n_train, m) rectangular matrix.
        sv_diag = jnp.diag(self.svs)
        if self.m <= self.n_train:
            pad = jnp.zeros(shape=(self.n_train - self.m, self.m))
            self.svs = jnp.concatenate([sv_diag, pad], axis=0)
        else:
            pad = jnp.zeros(shape=(self.n_train, self.m - self.n_train))
            self.svs = jnp.concatenate([sv_diag, pad], axis=1)

        # Diagonal (n_train, n_train) matrix of Gram eigenvalues.
        gram_evs = self.svs @ self.svs.T

        if self.reg is None:
            self.evs = gram_evs
            # U is (m, m), so this is the same factor as U.shape[0].
            tol = self.evs.max() * self.m * jnp.finfo(self.V.dtype).eps
            self.r = jnp.sum(self.evs >= tol)
            self.evs_inv = jnp.where(self.evs > tol, 1 / self.evs, jnp.zeros_like(self.evs))
        else:
            self.r = self.n_train
            self.evs_no = gram_evs
            # The shifted spectrum is built from the unregularised one; the
            # original read self.evs here before it had been assigned.
            self.evs = self.evs_no + self.reg * jnp.eye(self.n_train)
            self.evs_inv = jnp.diag(1 / jnp.diag(self.evs))

    def _init_kernel_space(self):
        """Cache the train / test-train Gram matrices and the train spectrum."""
        self.kernel_train = self.kernel_fn(self.x_train, self.x_train)
        self.kernel_train_test = self.kernel_fn(self.x_test, self.x_train)

        # eigh returns ascending eigenvalues; flip to descending order.
        self.evs, self.V = jnp.linalg.eigh(self.kernel_train)
        self.evs = jnp.flip(self.evs)
        self.V = jnp.flip(self.V, axis=1)

        if self.reg is None:
            tol = self.evs.max() * self.m * jnp.finfo(self.V.dtype).eps
            self.r = jnp.sum(self.evs >= tol)
            self.evs_inv = jnp.where(self.evs > tol, 1 / self.evs, jnp.zeros_like(self.evs))
        else:
            self.r = self.n_train
            self.evs_no = self.evs
            self.evs = self.evs_no + self.reg
            self.evs_inv = 1 / self.evs

    @staticmethod
    def _eigenvalues(evs):
        """Return the spectrum as a 1-D vector for either storage layout."""
        return evs if evs.ndim == 1 else jnp.diag(evs)

    def feature_map(self, x):
        """Return ``|x @ Z| / sqrt(m)``."""
        return 1/jnp.sqrt(self.m) * jnp.abs(x @ self.Z)

    def kernel_fn(self, x, z):
        """Return the Gram matrix ``feature_map(x) @ feature_map(z).T``."""
        x_feat = self.feature_map(x)
        z_feat = self.feature_map(z)

        return x_feat @ z_feat.T

    def predict_fn(self, x, labels, t=None, train=False, test=False):
        """Predict at the cached train/test inputs or at ``x``.

        Args:
            x: inputs to predict on; ignored when ``train`` or ``test`` is set.
            labels: training targets, contracted with the ``n_train`` axis.
            t: ``jnp.inf`` applies ``evs_inv`` directly; any other number
                additionally applies the factor ``1 - exp(-t * evs)`` to the
                spectrum. Must be a number (the ``None`` default is not
                usable).
            train: use the cached training features / kernel.
            test: use the cached test features / kernel (checked after
                ``train``).

        Returns:
            The predictions, one row per input point.
        """
        use_features = self.m <= self.s

        if train:
            lhs = self.feat_train if use_features else self.kernel_train
        elif test:
            lhs = self.feat_test if use_features else self.kernel_train_test
        else:
            lhs = self.feature_map(x) if use_features else self.kernel_fn(x, self.x_train)

        if use_features:
            if t == jnp.inf:
                A = self.U.T @ self.svs.T @ self.evs_inv @ self.V.T
            else:
                exp_t = jnp.diag(jnp.exp(-t * jnp.diag(self.evs)))
                A = self.U.T @ self.svs.T @ self.evs_inv @ (jnp.eye(self.n_train) - exp_t) @ self.V.T
        else:
            if t == jnp.inf:
                spectrum = self.evs_inv
            else:
                # Kernel-space counterpart of the feature-space expression above.
                spectrum = self.evs_inv * (1 - jnp.exp(-t * self.evs))
            A = self.V @ jnp.diag(spectrum) @ self.V.T

        return lhs @ A @ labels

    def leave_one_out(self, labels, pred_labels=None):
        """Closed-form leave-one-out residual and accuracy estimates.

        Args:
            labels: training targets. The reductions below run over axes
                ``(1, 2)``, so the broadcast of ``labels`` and ``pred_labels``
                must be 3-D — presumably ``(batch, n_train, outputs)``;
                TODO confirm against callers.
            pred_labels: targets to score against; defaults to ``labels``.

        Returns:
            ``(res, acc)``: the summed squared leave-one-out residual divided
            by ``n_train``, and the fraction of entries where
            ``pred_labels * alpha < pred_labels * labels``.
        """
        if pred_labels is None:
            pred_labels = labels

        V_r_n = self.V[:, self.r:]
        V_0_r = self.V[:, :self.r]
        A_r_n = V_r_n @ V_r_n.T

        if self.r <= self.n_train-1:
            if self.reg is None:
                beta = jnp.expand_dims(1 / jnp.diag(A_r_n), 1)
                alpha = beta * (A_r_n @ labels)
            else:
                # NOTE(review): the constructor forces r == n_train whenever
                # reg is set, so this branch is only reachable if r or reg is
                # changed afterwards; formula kept as in the original.
                evs_0_r_no = self._eigenvalues(self.evs_no)[:self.r]
                A_0_r = V_0_r @ jnp.diag(self.reg/(evs_0_r_no + self.reg)) @ V_0_r.T
                beta = 1 / (jnp.diag(A_r_n) + jnp.diag(A_0_r)) ** 2
                alpha = beta * ((A_r_n + A_0_r) @ labels)
        else:
            A_0_r = self.V @ jnp.diag(1 / self._eigenvalues(self.evs)) @ self.V.T
            beta = jnp.expand_dims(1 / jnp.diag(A_0_r), axis=1)
            alpha = beta * (A_0_r @ labels)

        res = 1 / self.n_train * jnp.sum((alpha + pred_labels - labels) ** 2, axis=(1, 2))
        acc = jnp.mean((pred_labels * alpha) < pred_labels * labels, axis=(1, 2))

        return res, acc


class NeuralTangent:
    """Kernel regression with the NTK of a fully connected ReLU network.

    The architecture is ``depth - 1`` blocks of ``Dense(512) -> Relu``
    followed by a ``Dense(classes)`` read-out, built with
    ``neural_tangents.stax``. The train kernel and the test/train kernel are
    evaluated once in the constructor, and the train kernel is
    eigendecomposed (eigenvalues stored in descending order) so predictions
    and leave-one-out estimates only need matrix products.

    Args:
        dim: stored on the instance; not otherwise used by this class.
        x_train: training inputs; ``x_train.shape[0]`` is the number of
            training points.
        x_test: test inputs.
        depth: number of dense layers including the read-out.
        reg: ridge parameter added to the eigenvalues, or None.
        bs: batch size handed to ``nt.batch`` for ``kernel_fn_batched``.
        classes: width of the read-out layer.
    """

    def __init__(self, dim, x_train, x_test, depth, reg=None, bs=None, classes=2):
        self.dim = dim
        self.x_train = x_train
        self.x_test = x_test
        self.depth = depth
        self.reg = reg
        self.bs = bs
        self.classes = classes

        self.n_train = self.x_train.shape[0]

        # Initialize kernel
        architecture = []
        for _ in range(depth-1):
            architecture += [stax.Dense(512, W_std=1., b_std=0.0), stax.Relu()]
        architecture += [stax.Dense(self.classes, W_std=1., b_std=0.0)]

        init_fn, apply_fn, self.kernel_fn_get = stax.serial(*architecture)

        if self.bs is not None:
            self.kernel_fn_batched = nt.batch(self.kernel_fn_get, device_count=-1, batch_size=self.bs)
        else:
            self.kernel_fn_batched = self.kernel_fn_get

        # NOTE(review): the kernels below are computed with the un-batched
        # kernel_fn_get, so `bs` currently has no effect on them — confirm
        # whether kernel_fn_batched was meant to be used here.

        # Compute the kernel once on the training dataset
        self.kernel_train = self.kernel_fn_get(x_train, x_train, 'ntk')

        # Compute the kernel once on the test set
        self.kernel_train_test = self.kernel_fn_get(x_test, x_train, 'ntk')

        # Eigendecompose once; eigh is ascending, flip to descending order.
        self.evs, self.V = jnp.linalg.eigh(self.kernel_train)
        self.evs = jnp.flip(self.evs)
        self.V = jnp.flip(self.V, axis=1)

        # Eigenvalues below tol are treated as zero; r counts the others.
        tol = self.evs.max() * self.V.shape[0] * jnp.finfo(self.V.dtype).eps
        self.r = jnp.sum(self.evs >= tol)

        shifted = self.evs if self.reg is None else self.evs + self.reg
        if self.r < self.n_train:
            self.evs_inv = jnp.where(self.evs > tol, 1 / shifted, jnp.zeros_like(self.evs))
        else:
            self.evs_inv = 1 / shifted

    def kernel_fn(self, x, z):
        """Return the NTK between ``x`` and ``z``."""
        return self.kernel_fn_get(x, z, 'ntk')

    def predict_fn(self, x, labels, t=None, train=False, test=False):
        """Predict at the cached train/test inputs or at ``x``.

        Args:
            x: inputs to predict on; ignored when ``train`` or ``test`` is set.
            labels: training targets, contracted with the ``n_train`` axis.
            t: ``jnp.inf`` applies ``evs_inv`` directly; any other number
                scales the inverse spectrum by ``1 - exp(-t * evs)``. Must be
                a number (the ``None`` default is not usable).
            train: use the cached train kernel.
            test: use the cached test/train kernel (checked after ``train``).

        Returns:
            ``kernel_x @ A @ labels`` with ``A = V @ diag(spectrum) @ V.T``.
        """
        if train:
            kernel_x = self.kernel_train
        elif test:
            kernel_x = self.kernel_train_test
        else:
            kernel_x = self.kernel_fn(x, self.x_train)

        if t == jnp.inf:
            spectrum = self.evs_inv
        else:
            # Element-wise on the 1-D spectrum. The original subtracted exp_t
            # from jnp.eye(n), which made jnp.diag return a vector and
            # collapsed the operator to 1-D.
            spectrum = self.evs_inv * (1 - jnp.exp(-t * self.evs))

        A = self.V @ jnp.diag(spectrum) @ self.V.T

        return kernel_x @ A @ labels

    def leave_one_out(self, labels, pred_labels=None):
        """Closed-form leave-one-out residual and accuracy estimates.

        Args:
            labels: training targets.
            pred_labels: targets to score against; defaults to ``labels``.
                In the regularised rank-deficient branch it only enters the
                accuracy, as ``res`` there is computed from ``labels`` alone.

        Returns:
            ``(res, acc)``. The reduction axes differ per branch (see inline
            notes), so the accepted label rank differs as well.
        """
        if pred_labels is None:
            pred_labels = labels

        V_r_n = self.V[:, self.r:]
        V_0_r = self.V[:, :self.r]
        A_r_n = V_r_n @ V_r_n.T

        if self.r <= self.n_train-1:
            if self.reg is None:
                beta = jnp.expand_dims(1 / jnp.diag(A_r_n), 1)
                alpha = beta * (A_r_n @ labels)

                # Reduces over axis 0 — assumes labels is (n_train, outputs);
                # TODO confirm.
                res = 1 / self.n_train * jnp.sum((alpha + pred_labels - labels) ** 2, axis=0)
                acc = jnp.mean((pred_labels * alpha) < pred_labels * labels, axis=0)

            else:
                # self.evs is 1-D, so slice it directly.
                evs_0_r = self.evs[:self.r]
                # A = V diag(reg / (reg + ev)) V^T with the discarded
                # eigenvalues taken as zero; this is the same weighting the
                # other kernel classes in this file use.
                A_0_r = V_0_r @ jnp.diag(self.reg / (self.reg + evs_0_r)) @ V_0_r.T
                A = A_r_n + A_0_r
                diag_A = jnp.diag(A)
                A_labels = A @ labels

                beta = 1 / diag_A ** 2
                res = 1 / self.n_train * beta.T @ A_labels ** 2

                # Leave-one-out residual per sample, scored like the other
                # branches; the sample axis is the second to last one, which
                # is also the axis contracted by `beta.T @ ...` above.
                alpha = A_labels / jnp.expand_dims(diag_A, 1)
                if self.classes == 2:
                    acc = jnp.mean((pred_labels * alpha) < pred_labels * labels, axis=-2)
                else:
                    preds = jnp.argmax(labels - alpha, axis=-1)
                    targets = jnp.argmax(pred_labels, axis=-1)
                    acc = jnp.mean(preds == targets, axis=-1)

        else:
            if self.reg is None:
                A_0_r = self.V @ jnp.diag(1/self.evs) @ self.V.T
            else:
                A_0_r = self.V @ jnp.diag(1 / (self.evs + self.reg)) @ self.V.T
            beta = jnp.expand_dims(1 / jnp.diag(A_0_r), 1)
            alpha = beta * (A_0_r @ labels)

            # Reduces over axes (1, 2) — needs a 3-D broadcast of labels and
            # pred_labels; TODO confirm the expected layout with callers.
            res = 1 / self.n_train * jnp.sum((alpha + pred_labels - labels)**2, axis=(1, 2))
            if self.classes == 2:
                acc = jnp.mean((pred_labels * alpha) < pred_labels * labels, axis=0)
            else:
                preds = jnp.argmax(labels - alpha, axis=2)
                targets = jnp.argmax(pred_labels, axis=2)
                acc = jnp.mean(preds == targets, axis=1)

        return res, acc

    def save_kernel(self, mode):
        """Save the train kernel (``mode == 'train'``) or the test/train kernel.

        The file name is built from the depth, the number of training points
        and ``mode``, and is written under the module-level ``save_path``.
        """
        n_train = self.x_train.shape[0]
        name = '/NTK_Fully' + str(self.depth) + 'num_samples_' + str(n_train) + mode
        if mode == 'train':
            jnp.save(save_path + name, self.kernel_train)
        else:
            jnp.save(save_path + name, self.kernel_train_test)


class NeuralTangentConv:
    """Kernel regression with the NTK of a convolutional ReLU network.

    The architecture is ``depth - 1`` blocks of
    ``Conv(widths[i], filter_shapes[i], strides[i]) -> Relu`` followed by
    ``Flatten -> Dense(classes)``, built with ``neural_tangents.stax``. The
    train kernel and the test/train kernel are evaluated once in the
    constructor, and the train kernel is eigendecomposed (eigenvalues stored
    in descending order).

    Args:
        dim: stored on the instance; not otherwise used by this class.
        widths, filter_shapes, strides: per-layer arguments of ``stax.Conv``;
            the first ``depth - 1`` entries of each are read.
        x_train: training inputs; ``x_train.shape[0]`` is the number of
            training points.
        x_test: test inputs.
        depth: number of layers including the dense read-out.
        reg: ridge parameter added to the eigenvalues, or None.
        bs: batch size handed to ``nt.batch`` for ``kernel_fn_batched``.
        classes: width of the read-out layer.
    """

    def __init__(self, dim, widths, filter_shapes, strides, x_train, x_test, depth, reg=None, bs=None, classes=2):
        self.dim = dim
        self.widths = widths
        self.filter_shapes = filter_shapes
        self.strides = strides
        self.x_train = x_train
        self.x_test = x_test
        self.depth = depth
        self.reg = reg
        self.bs = bs
        self.classes = classes

        self.n_train = self.x_train.shape[0]

        # Initialize kernel
        architecture = []
        for i in range(self.depth - 1):
            architecture += [stax.Conv(widths[i], filter_shapes[i], strides[i]), stax.Relu()]

        architecture += [stax.Flatten(), stax.Dense(self.classes)]

        init_fn, apply_fn, self.kernel_fn_get = stax.serial(
            *architecture
        )

        if self.bs is not None:
            self.kernel_fn_batched = nt.batch(self.kernel_fn_get, device_count=-1, batch_size=self.bs)
        else:
            self.kernel_fn_batched = self.kernel_fn_get

        # NOTE(review): the kernels below are computed with the un-batched
        # kernel_fn_get, so `bs` currently has no effect on them — confirm
        # whether kernel_fn_batched was meant to be used here.

        # Compute the kernel once on the training dataset
        self.kernel_train = self.kernel_fn_get(x_train, x_train, 'ntk')

        # Compute the kernel once on the test set
        self.kernel_train_test = self.kernel_fn_get(x_test, x_train, 'ntk')

        # Eigendecompose once; eigh is ascending, flip to descending order.
        self.evs, self.V = jnp.linalg.eigh(self.kernel_train)
        self.evs = jnp.flip(self.evs)
        self.V = jnp.flip(self.V, axis=1)

        # Eigenvalues below tol are treated as zero; r counts the others.
        tol = self.evs.max() * self.V.shape[0] * jnp.finfo(self.V.dtype).eps
        self.r = jnp.sum(self.evs >= tol)
        print(self.r)

        # `reg` now shifts the inverted spectrum, as NeuralTangent does;
        # previously it was used by leave_one_out but ignored here.
        shifted = self.evs if self.reg is None else self.evs + self.reg
        if self.r < self.n_train:
            self.evs_inv = jnp.where(self.evs > tol, 1 / shifted, jnp.zeros_like(self.evs))
        else:
            self.evs_inv = 1 / shifted

    def kernel_fn(self, x, z):
        """Return the NTK between ``x`` and ``z``."""
        return self.kernel_fn_get(x, z, 'ntk')

    def predict_fn(self, x, labels, t=None, train=False, test=False):
        """Predict at the cached train/test inputs or at ``x``.

        Args:
            x: inputs to predict on; ignored when ``train`` or ``test`` is set.
            labels: training targets, contracted with the ``n_train`` axis.
            t: ``jnp.inf`` applies ``evs_inv`` directly; any other number
                scales the inverse spectrum by ``1 - exp(-t * evs)``. Must be
                a number (the ``None`` default is not usable).
            train: use the cached train kernel.
            test: use the cached test/train kernel (checked after ``train``).

        Returns:
            ``kernel_x @ A @ labels`` with ``A = V @ diag(spectrum) @ V.T``.
        """
        if train:
            kernel_x = self.kernel_train
        elif test:
            kernel_x = self.kernel_train_test
        else:
            kernel_x = self.kernel_fn(x, self.x_train)

        if t == jnp.inf:
            spectrum = self.evs_inv
        else:
            # Element-wise on the 1-D spectrum. The original subtracted exp_t
            # from jnp.eye(n), which made jnp.diag return a vector and
            # collapsed the operator to 1-D.
            spectrum = self.evs_inv * (1 - jnp.exp(-t * self.evs))

        A = self.V @ jnp.diag(spectrum) @ self.V.T

        return kernel_x @ A @ labels

    def leave_one_out(self, labels, pred_labels=None):
        """Closed-form leave-one-out residual and accuracy estimates.

        Args:
            labels: training targets.
            pred_labels: targets to score against; defaults to ``labels``.
                In the regularised rank-deficient branch it only enters the
                accuracy, as ``res`` there is computed from ``labels`` alone.

        Returns:
            ``(res, acc)``. The reduction axes differ per branch (see inline
            notes), so the accepted label rank differs as well.
        """
        if pred_labels is None:
            pred_labels = labels

        V_r_n = self.V[:, self.r:]
        V_0_r = self.V[:, :self.r]
        A_r_n = V_r_n @ V_r_n.T

        if self.r <= self.n_train-1:
            if self.reg is None:
                beta = jnp.expand_dims(1 / jnp.diag(A_r_n), 1)
                alpha = beta * (A_r_n @ labels)

                # Reduces over axis 0 — assumes labels is (n_train, outputs);
                # TODO confirm.
                res = 1 / self.n_train * jnp.sum((alpha + pred_labels - labels) ** 2, axis=0)
                acc = jnp.mean((pred_labels * alpha) < pred_labels * labels, axis=0)

            else:
                # self.evs is 1-D, so slice it directly.
                evs_0_r = self.evs[:self.r]
                # A = V diag(reg / (reg + ev)) V^T with the discarded
                # eigenvalues taken as zero; this is the same weighting the
                # feature-map and empirical kernel classes in this file use.
                A_0_r = V_0_r @ jnp.diag(self.reg / (self.reg + evs_0_r)) @ V_0_r.T
                A = A_r_n + A_0_r
                diag_A = jnp.diag(A)
                A_labels = A @ labels

                beta = 1 / diag_A ** 2
                res = 1 / self.n_train * beta.T @ A_labels ** 2

                # Leave-one-out residual per sample, scored like the other
                # branches; the sample axis is the second to last one, which
                # is also the axis contracted by `beta.T @ ...` above.
                alpha = A_labels / jnp.expand_dims(diag_A, 1)
                if self.classes == 2:
                    acc = jnp.mean((pred_labels * alpha) < pred_labels * labels, axis=-2)
                else:
                    preds = jnp.argmax(labels - alpha, axis=-1)
                    targets = jnp.argmax(pred_labels, axis=-1)
                    acc = jnp.mean(preds == targets, axis=-1)

        else:
            if self.reg is None:
                A_0_r = self.V @ jnp.diag(1/self.evs) @ self.V.T
            else:
                A_0_r = self.V @ jnp.diag(1 / (self.evs + self.reg)) @ self.V.T
            beta = jnp.expand_dims(1 / jnp.diag(A_0_r), 1)
            alpha = beta * (A_0_r @ labels)

            # Reduces over axes (1, 2) — needs a 3-D broadcast of labels and
            # pred_labels; TODO confirm the expected layout with callers.
            res = 1 / self.n_train * jnp.sum((alpha + pred_labels - labels)**2, axis=(1, 2))
            if self.classes == 2:
                acc = jnp.mean((pred_labels * alpha) < pred_labels * labels, axis=0)
            else:
                preds = jnp.argmax(labels - alpha, axis=2)
                targets = jnp.argmax(pred_labels, axis=2)
                acc = jnp.mean(preds == targets, axis=1)

        return res, acc

    def save_kernel(self, mode):
        """Save the train kernel (``mode == 'train'``) or the test/train kernel.

        The file name is built from the depth, the number of training points
        and ``mode``, and is written under the module-level ``save_path``.
        """
        n_train = self.x_train.shape[0]
        name = '/NTK_Conv' + str(self.depth) + 'num_samples_' + str(n_train) + mode
        if mode == 'train':
            jnp.save(save_path + name, self.kernel_train)
        else:
            jnp.save(save_path + name, self.kernel_train_test)


class FeatureMapKernel:
    """Kernel regression with the linear kernel of a user-supplied feature map.

    The kernel is ``feature_map(x) @ feature_map(z).T``. The train kernel and
    the test/train kernel are evaluated once in the constructor, and the
    train kernel is eigendecomposed (eigenvalues stored in descending order).

    Args:
        feature_map: callable mapping a batch of inputs to a 2-D feature
            matrix (one row per input).
        x_train: training inputs; ``x_train.shape[0]`` is the number of
            training points.
        x_test: test inputs.
        reg: ridge parameter added to the eigenvalues, or None.
        classes: number of classes; selects how accuracy is scored.
    """

    def __init__(self, feature_map, x_train, x_test, reg=None, classes=2):
        self.feature_map = feature_map
        self.x_train = x_train
        self.x_test = x_test
        self.reg = reg
        self.classes = classes

        self.n_train = self.x_train.shape[0]

        # Compute the kernel once on the training dataset
        self.kernel_train = self.kernel_fn(x_train, x_train)

        # Compute the kernel once on the test set
        self.kernel_train_test = self.kernel_fn(x_test, x_train)

        # Eigendecompose once; eigh is ascending, flip to descending order.
        self.evs, self.V = jnp.linalg.eigh(self.kernel_train)
        self.evs = jnp.flip(self.evs)
        self.V = jnp.flip(self.V, axis=1)

        # Eigenvalues below tol are treated as zero; r counts the others.
        tol = self.evs.max() * self.V.shape[0] * jnp.finfo(self.V.dtype).eps
        self.r = jnp.sum(self.evs >= tol)

        if self.reg is not None:
            self.evs_inv = 1 / (self.evs + self.reg)
        elif self.r < self.n_train:
            self.evs_inv = jnp.where(self.evs > tol, 1 / self.evs, jnp.zeros_like(self.evs))
        else:
            self.evs_inv = 1 / self.evs

    def kernel_fn(self, x, z):
        """Return ``feature_map(x) @ feature_map(z).T``."""
        x_feat = self.feature_map(x)
        z_feat = self.feature_map(z)

        return x_feat @ z_feat.T

    def predict_fn(self, x, labels, t=None, train=False, test=False):
        """Predict at the cached train/test inputs or at ``x``.

        Args:
            x: inputs to predict on; ignored when ``train`` or ``test`` is set.
            labels: training targets, contracted with the ``n_train`` axis.
            t: ``jnp.inf`` applies ``evs_inv`` directly; any other number
                scales the inverse spectrum by ``1 - exp(-t * evs)``. Must be
                a number (the ``None`` default is not usable).
            train: use the cached train kernel.
            test: use the cached test/train kernel (checked after ``train``).

        Returns:
            ``kernel_x @ A @ labels`` with ``A = V @ diag(spectrum) @ V.T``.
        """
        if train:
            kernel_x = self.kernel_train
        elif test:
            kernel_x = self.kernel_train_test
        else:
            kernel_x = self.kernel_fn(x, self.x_train)

        if t == jnp.inf:
            spectrum = self.evs_inv
        else:
            # Element-wise on the 1-D spectrum. The original subtracted exp_t
            # from jnp.eye(n), which made jnp.diag return a vector and
            # collapsed the operator to 1-D.
            spectrum = self.evs_inv * (1 - jnp.exp(-t * self.evs))

        A = self.V @ jnp.diag(spectrum) @ self.V.T

        return kernel_x @ A @ labels

    def leave_one_out(self, labels, pred_labels=None):
        """Closed-form leave-one-out residual and accuracy estimates.

        Args:
            labels: training targets.
            pred_labels: targets to score against; defaults to ``labels``.
                In the regularised rank-deficient branch it only enters the
                accuracy, as ``res`` there is computed from ``labels`` alone.

        Returns:
            ``(res, acc)``. The reduction axes differ per branch (see inline
            notes), so the accepted label rank differs as well.
        """
        if pred_labels is None:
            pred_labels = labels

        V_r_n = self.V[:, self.r:]
        V_0_r = self.V[:, :self.r]
        A_r_n = V_r_n @ V_r_n.T

        if self.r <= self.n_train-1:
            if self.reg is None:
                beta = jnp.expand_dims(1 / jnp.diag(A_r_n), 1)
                alpha = beta * (A_r_n @ labels)

                # NOTE(review): the binary case reduces res over axis 1 but
                # acc over axis 0 — confirm the intended label layout.
                if self.classes == 2:
                    res = 1 / self.n_train * jnp.sum((alpha + pred_labels - labels) ** 2, axis=1)
                    acc = jnp.mean((pred_labels * alpha) < pred_labels * labels, axis=0)
                else:
                    res = 1 / self.n_train * jnp.sum((alpha + pred_labels - labels) ** 2, axis=(1, 2))
                    preds = jnp.argmax(labels - alpha, axis=2)
                    targets = jnp.argmax(pred_labels, axis=2)
                    acc = jnp.mean(preds == targets, axis=1)

            else:
                # self.evs is 1-D, so slice it directly; the original applied
                # jnp.diag first and ended up with a vector where a diagonal
                # matrix is required.
                evs_0_r = self.evs[:self.r]
                # A = V diag(reg / (reg + ev)) V^T with the discarded
                # eigenvalues taken as zero.
                A_0_r = V_0_r @ jnp.diag(self.reg/(self.reg + evs_0_r)) @ V_0_r.T
                A = A_r_n + A_0_r
                diag_A = jnp.diag(A)
                A_labels = A @ labels

                beta = 1 / diag_A ** 2
                res = 1 / self.n_train * beta.T @ A_labels ** 2

                # Leave-one-out residual per sample, scored like the other
                # branches; the sample axis is the second to last one, which
                # is also the axis contracted by `beta.T @ ...` above.
                alpha = A_labels / jnp.expand_dims(diag_A, 1)
                if self.classes == 2:
                    acc = jnp.mean((pred_labels * alpha) < pred_labels * labels, axis=-2)
                else:
                    preds = jnp.argmax(labels - alpha, axis=-1)
                    targets = jnp.argmax(pred_labels, axis=-1)
                    acc = jnp.mean(preds == targets, axis=-1)

        else:
            if self.reg is None:
                A_0_r = self.V @ jnp.diag(1/self.evs) @ self.V.T
            else:
                A_0_r = self.V @ jnp.diag(1 / (self.evs + self.reg)) @ self.V.T
            beta = jnp.expand_dims(1 / jnp.diag(A_0_r), 1)
            alpha = beta * (A_0_r @ labels)

            # Reduces over axes (1, 2) — needs a 3-D broadcast of labels and
            # pred_labels; TODO confirm the expected layout with callers.
            res = 1 / self.n_train * jnp.sum((alpha + pred_labels - labels)**2, axis=(1, 2))
            if self.classes == 2:
                acc = jnp.mean((pred_labels * alpha) < pred_labels * labels, axis=(1, 2))
            else:
                preds = jnp.argmax(labels - alpha, axis=2)
                targets = jnp.argmax(pred_labels, axis=2)
                acc = jnp.mean(preds == targets, axis=1)

        return res, acc


class EmpiricalNeuralTangent:
    """Kernel regression with the empirical NTK of a given model.

    The kernel function comes from ``nt.empirical_kernel_fn`` (optionally
    wrapped in ``nt.batch``) and is called as
    ``kernel_fn(x1, x2, 'ntk', params)``. The train kernel and the test/train
    kernel are evaluated once in the constructor, and the train kernel is
    eigendecomposed (eigenvalues stored in descending order).

    Args:
        model: function passed to ``nt.empirical_kernel_fn``.
        params: model parameters at which the kernel is evaluated.
        x_train: training inputs; ``x_train.shape[0]`` is the number of
            training points.
        x_test: test inputs.
        reg: ridge parameter added to the eigenvalues, or None.
        bs: batch size handed to ``nt.batch``; None disables batching.
        classes: number of classes; selects how accuracy is scored.
    """

    def __init__(self, model, params, x_train, x_test, reg=None, bs=None, classes=2):
        self.model = model
        self.params = params
        self.x_train = x_train
        self.x_test = x_test
        self.reg = reg
        self.bs = bs
        self.classes = classes

        self.n_train = self.x_train.shape[0]

        self.kernel_fn_get = nt.empirical_kernel_fn(self.model, trace_axes=(-1,), implementation=2, vmap_axes=0)

        if self.bs is not None:
            self.kernel_fn = nt.batch(self.kernel_fn_get, device_count=-1, batch_size=self.bs)
        else:
            self.kernel_fn = self.kernel_fn_get

        # Compute the kernel once on the training dataset
        self.kernel_train = self.kernel_fn(x_train, x_train, 'ntk', params)

        # Compute the kernel once on the test set
        self.kernel_train_test = self.kernel_fn(x_test, x_train, 'ntk', params)

        # Eigendecompose once; eigh is ascending, flip to descending order.
        self.evs, self.V = jnp.linalg.eigh(self.kernel_train)
        self.evs = jnp.flip(self.evs)
        self.V = jnp.flip(self.V, axis=1)

        # Eigenvalues below tol are treated as zero; r counts the others.
        tol = self.evs.max() * self.V.shape[0] * jnp.finfo(self.V.dtype).eps
        self.r = jnp.sum(self.evs >= tol)
        print(self.r)

        if self.reg is not None:
            self.evs_inv = 1 / (self.evs + self.reg)
        elif self.r < self.n_train:
            self.evs_inv = jnp.where(self.evs > tol, 1 / self.evs, jnp.zeros_like(self.evs))
        else:
            self.evs_inv = 1 / self.evs

    def predict_fn(self, x, labels, t=None, train=False, test=False):
        """Predict at the cached train/test inputs or at ``x``.

        Args:
            x: inputs to predict on; ignored when ``train`` or ``test`` is set.
            labels: training targets, contracted with the ``n_train`` axis.
            t: ``jnp.inf`` applies ``evs_inv`` directly; any other number
                scales the inverse spectrum by ``1 - exp(-t * evs)``. Must be
                a number (the ``None`` default is not usable).
            train: use the cached train kernel.
            test: use the cached test/train kernel (checked after ``train``).

        Returns:
            ``kernel_x @ A @ labels`` with ``A = V @ diag(spectrum) @ V.T``.
        """
        if train:
            kernel_x = self.kernel_train
        elif test:
            kernel_x = self.kernel_train_test
        else:
            # Same call signature as in the constructor; the original omitted
            # the 'ntk' selector and the parameters.
            kernel_x = self.kernel_fn(x, self.x_train, 'ntk', self.params)

        if t == jnp.inf:
            spectrum = self.evs_inv
        else:
            # Element-wise on the 1-D spectrum. The original subtracted exp_t
            # from jnp.eye(n), which made jnp.diag return a vector and
            # collapsed the operator to 1-D.
            spectrum = self.evs_inv * (1 - jnp.exp(-t * self.evs))

        A = self.V @ jnp.diag(spectrum) @ self.V.T

        return kernel_x @ A @ labels

    def leave_one_out(self, labels, pred_labels=None):
        """Closed-form leave-one-out residual and accuracy estimates.

        Args:
            labels: training targets. The unregularised rank-deficient branch
                reduces over axis 1 and the full-rank branch over axes
                ``(1, 2)`` — presumably ``(batch, n_train, outputs)``;
                TODO confirm against callers.
            pred_labels: targets to score against; defaults to ``labels``.
                In the regularised rank-deficient branch it only enters the
                accuracy, as ``res`` there is computed from ``labels`` alone.

        Returns:
            ``(res, acc)``.
        """
        if pred_labels is None:
            pred_labels = labels

        V_r_n = self.V[:, self.r:]
        V_0_r = self.V[:, :self.r]
        A_r_n = V_r_n @ V_r_n.T

        if self.r <= self.n_train-1:
            if self.reg is None:
                beta = jnp.expand_dims(1 / jnp.diag(A_r_n), 1)
                alpha = beta * (A_r_n @ labels)

                res = 1 / self.n_train * jnp.sum((alpha + pred_labels - labels) ** 2, axis=1)
                acc = jnp.mean((pred_labels * alpha) < pred_labels * labels, axis=1)

            else:
                # self.evs is 1-D, so slice it directly; the original applied
                # jnp.diag first and ended up with a vector where a diagonal
                # matrix is required.
                evs_0_r = self.evs[:self.r]
                # A = V diag(reg / (reg + ev)) V^T with the discarded
                # eigenvalues taken as zero.
                A_0_r = V_0_r @ jnp.diag(self.reg/(self.reg + evs_0_r)) @ V_0_r.T
                A = A_r_n + A_0_r
                diag_A = jnp.diag(A)
                A_labels = A @ labels

                beta = 1 / diag_A ** 2
                res = 1 / self.n_train * beta.T @ A_labels ** 2

                # Leave-one-out residual per sample, scored like the other
                # branches; the sample axis is the second to last one, which
                # is also the axis contracted by `beta.T @ ...` above.
                alpha = A_labels / jnp.expand_dims(diag_A, 1)
                if self.classes == 2:
                    acc = jnp.mean((pred_labels * alpha) < pred_labels * labels, axis=-2)
                else:
                    preds = jnp.argmax(labels - alpha, axis=-1)
                    targets = jnp.argmax(pred_labels, axis=-1)
                    acc = jnp.mean(preds == targets, axis=-1)

        else:
            if self.reg is None:
                A_0_r = self.V @ jnp.diag(1/self.evs) @ self.V.T
            else:
                A_0_r = self.V @ jnp.diag(1 / (self.evs + self.reg)) @ self.V.T
            beta = jnp.expand_dims(1 / jnp.diag(A_0_r), 1)
            alpha = beta * (A_0_r @ labels)

            res = 1 / self.n_train * jnp.sum((alpha + pred_labels - labels)**2, axis=(1, 2))
            if self.classes == 2:
                acc = jnp.mean((pred_labels * alpha) < pred_labels * labels, axis=(1, 2))
            else:
                preds = jnp.argmax(labels - alpha, axis=2)
                targets = jnp.argmax(pred_labels, axis=2)
                acc = jnp.mean(preds == targets, axis=1)

        return res, acc

    def save_kernel(self, mode):
        """Save the train kernel (``mode == 'train'``) or the test/train kernel.

        The file is written under the module-level ``save_path``. This class
        never sets ``depth``, so the original name (which reads
        ``self.depth``) is only used when a caller has attached that
        attribute; otherwise an ``NTK_Empirical`` prefix is used.
        """
        n_train = self.x_train.shape[0]
        depth = getattr(self, 'depth', None)
        prefix = '/NTK_Empirical' if depth is None else '/NTK_Fully' + str(depth)
        name = prefix + 'num_samples_' + str(n_train) + mode
        if mode == 'train':
            jnp.save(save_path + name, self.kernel_train)
        else:
            jnp.save(save_path + name, self.kernel_train_test)