import numpy as np
import ipdb
import torch
import torchvision
import torchvision.transforms as transforms
from torch.nn.functional import one_hot
from kernel import laplacian, gaussian
import unitnorm_svd
import matplotlib.pyplot as plt
import scipy.linalg as linalg
import falkon
from hilbert_projection import HilbertProjection
from fast_pytorch_kmeans import KMeans
import pickle
import time

class cifar10:
    """CIFAR-10 as flattened pixel vectors with one-hot labels.

    Keeps a random subset of ``n_subsamples`` training images (drawn without
    replacement) plus the full test split, all moved to ``device``.

    Attributes:
        train_x, test_x: images flattened to ``3 * 32 * 32`` values and
            divided by 255.
        train_y, test_y: one-hot labels with exactly 10 columns.
    """

    def __init__(self, n_subsamples, device=torch.device('cpu')):
        # Features are read from the raw ``.data`` arrays, so no torchvision
        # transform is attached: it would never be applied to ``.data``.
        trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                                download=True)
        train_ind = np.random.choice(trainset.data.shape[0], n_subsamples,
                                     replace=False)
        # NOTE(review): assumes ``.data`` is an (N, 32, 32, 3) uint8 array, so
        # this flattens in HWC order — confirm if channel order matters.
        self.train_x = (torch.from_numpy(trainset.data[train_ind])
                        .reshape(-1, 3 * 32 * 32) / 255.0).to(device)
        # Pin num_classes: a small subsample can miss the highest label,
        # which would otherwise shrink the one-hot width below 10.
        self.train_y = one_hot(torch.as_tensor(trainset.targets)[train_ind],
                               num_classes=10).to(device)

        testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                               download=True)
        self.test_x = (torch.from_numpy(testset.data)
                       .reshape(-1, 3 * 32 * 32) / 255.0).to(device)
        self.test_y = one_hot(torch.as_tensor(testset.targets),
                              num_classes=10).to(device)


class MNIST:
    """MNIST digits as flattened 28*28 pixel vectors with one-hot labels.

    Keeps a random subset of ``n_subsamples`` training images (drawn without
    replacement) plus the full test split, all moved to ``device``.

    Attributes:
        train_x, test_x: images flattened to ``28 * 28`` values and divided
            by 255.
        train_y, test_y: one-hot labels with exactly 10 columns.
    """

    def __init__(self, n_subsamples=60_000, device=torch.device('cpu')):
        # Features are read from the raw ``.data`` tensors, so no torchvision
        # transform is attached (a transform would never be applied to
        # ``.data``, and a 3-channel Normalize does not fit 1-channel MNIST).
        trainset = torchvision.datasets.MNIST(root='./data', train=True,
                                              download=True)
        train_ind = np.random.choice(trainset.data.shape[0], n_subsamples,
                                     replace=False)
        self.train_x = (trainset.data[train_ind].reshape(-1, 28 * 28)
                        / 255.0).to(device)
        # as_tensor accepts a list or a tensor without the copy-construct
        # warning torch.tensor() emits for tensor input; num_classes pins the
        # one-hot width even if a subsample misses the highest label.
        self.train_y = one_hot(torch.as_tensor(trainset.targets)[train_ind],
                               num_classes=10).to(device)

        testset = torchvision.datasets.MNIST(root='./data', train=False,
                                             download=True)
        self.test_x = (testset.data.reshape(-1, 28 * 28) / 255.0).to(device)
        self.test_y = one_hot(torch.as_tensor(testset.targets),
                              num_classes=10).to(device)

class Fashion:
    """Fashion-MNIST as flattened 28*28 pixel vectors with one-hot labels.

    Keeps a random subset of ``n_subsamples`` training images (drawn without
    replacement) plus the full test split, all moved to ``device``.

    Attributes:
        train_x, test_x: images flattened to ``28 * 28`` values and divided
            by 255.
        train_y, test_y: one-hot labels with exactly 10 columns.
    """

    def __init__(self, n_subsamples=60_000, device=torch.device('cpu')):
        # Features are read from the raw ``.data`` tensors, so no torchvision
        # transform is attached (a transform would never be applied to
        # ``.data``, and a 3-channel Normalize does not fit 1-channel images).
        trainset = torchvision.datasets.FashionMNIST(root='./data', train=True,
                                                     download=True)
        train_ind = np.random.choice(trainset.data.shape[0], n_subsamples,
                                     replace=False)
        self.train_x = (trainset.data[train_ind].reshape(-1, 28 * 28)
                        / 255.0).to(device)
        # as_tensor accepts a list or a tensor without the copy-construct
        # warning torch.tensor() emits for tensor input; num_classes pins the
        # one-hot width even if a subsample misses the highest label.
        self.train_y = one_hot(torch.as_tensor(trainset.targets)[train_ind],
                               num_classes=10).to(device)

        testset = torchvision.datasets.FashionMNIST(root='./data', train=False,
                                                    download=True)
        self.test_x = (testset.data.reshape(-1, 28 * 28) / 255.0).to(device)
        self.test_y = one_hot(torch.as_tensor(testset.targets),
                              num_classes=10).to(device)

class cifar10mobilenetDataset():
    """Pre-extracted MobileNetV2 CIFAR-10 features loaded from ``.pt`` files.

    Args:
        num_knots: accepted but not used.
        DATADIR: string prefix prepended verbatim to every file name, so it
            should end with a path separator (or be empty).
        **kwargs: accepted and ignored.

    Attributes:
        train_x, test_x: the saved feature tensors, unchanged.
        train_y, test_y: the saved labels cast to ``long`` and one-hot encoded.
    """

    def __init__(self, num_knots=50_000,
                 DATADIR='',
                 **kwargs):
        print("cifar10mobilenetDataset")

        def _load(fname):
            # File names on disk are spelled "ciar10_..." — keep as-is.
            return torch.load(DATADIR + fname)

        self.train_x = _load("ciar10_mobilenetv2_100_feature_train.pt")
        raw_train_labels = _load("ciar10_mobilenetv2_100_y_train.pt")
        self.train_y = one_hot(raw_train_labels.long())

        self.test_x = _load("ciar10_mobilenetv2_100_feature_test.pt")
        raw_test_labels = _load("ciar10_mobilenetv2_100_y_test.pt")
        self.test_y = one_hot(raw_test_labels.long())



def mse(y, y_hat):
    """Return the mean squared error between two tensors as a Python float."""
    squared_error = torch.nn.functional.mse_loss(y, y_hat, reduction='mean')
    return squared_error.item()


def GD(cifar10_dataset, kernel_fn, knots, epochs=100, lr=0.01):
    """Fit kernel-regression weights on ``knots`` by plain gradient descent.

    Each step applies ``alpha -= step * K_xz^T (K_xz @ alpha - y_train)`` with
    ``K_xz = kernel_fn(train_x, knots)``. This is gradient descent on the
    squared training loss.

    Args:
        cifar10_dataset: object exposing ``train_x``, ``train_y``, ``test_x``
            and one-hot ``test_y`` (e.g. the dataset classes in this file).
        kernel_fn: callable ``kernel_fn(a, b)`` returning the kernel matrix
            between the rows of ``a`` and the rows of ``b``.
        knots: model centres, one per row.
        epochs: number of gradient steps.
        lr: ignored. The step size is always ``0.99 / lambda_max`` of
            ``K_xz^T K_xz``. The argument is kept for signature compatibility.

    Returns:
        ``(loss, accu)``: the per-epoch training MSE, measured before each
        update, and the test accuracy in percent (0-d tensors), measured after
        each update.
    """
    x_train = cifar10_dataset.train_x
    y_train = cifar10_dataset.train_y
    x_test = cifar10_dataset.test_x
    y_test = cifar10_dataset.test_y

    n_labels = y_train.shape[1]
    p = knots.shape[0]

    k_xz = kernel_fn(x_train, knots)
    # Match the kernel's dtype/device so float64 or GPU kernels also work.
    alpha = torch.zeros((p, n_labels), dtype=k_xz.dtype, device=k_xz.device)

    # The largest eigenvalue of K_xz^T K_xz bounds the stable step size.
    # Use torch instead of scipy.linalg.eigh: its `eigvals=` keyword is
    # deprecated/removed in recent SciPy, and it cannot take GPU tensors.
    # eigvalsh returns eigenvalues in ascending order.
    top_eig = torch.linalg.eigvalsh(k_xz.T @ k_xz)[-1]
    lr = 0.99 / top_eig.item()

    # Test-side quantities do not change across epochs: compute them once.
    k_testz = kernel_fn(x_test, knots)
    y_true = y_test.argmax(dim=1)

    loss = []
    accu = []
    for _ in range(epochs):
        pred = k_xz @ alpha
        loss.append(mse(pred, y_train))
        residual = pred - y_train
        alpha = alpha - lr * (k_xz.T @ residual)

        y_hat = (k_testz @ alpha).argmax(dim=1)
        accu.append(100 * torch.sum(y_hat == y_true) / y_hat.shape[0])

    return loss, accu


def nystrom(nystrom_samples, knots, kernel_fn, s=1000, q=10):
    """Build the Nyström-based preconditioner pieces consumed by ``EP3``.

    Asks ``unitnorm_svd.nystrom_kernel_svd`` for ``q + 1`` eigenpairs of the
    kernel on ``nystrom_samples``. The first ``q`` are kept, and entry ``q``
    is used as the spectral "tail". This presumes eigenvalues come back in
    descending order; confirm against ``unitnorm_svd``.

    Args:
        nystrom_samples: the subsampled training points.
        knots: model centres, one per row.
        kernel_fn: callable ``kernel_fn(a, b)`` giving a kernel matrix.
        s: subsample size, used (as an int) in the eigenvalue scaling.
        q: number of retained eigen-directions (cast to int).

    Returns:
        ``(preconditioner_matrix, E_x, step)``:

        - ``preconditioner_matrix`` is ``K(nystrom_samples, knots)^T @ (D * E_x)``
          with ``D = (1 - tail / Lam) / Lam / s``.
        - ``E_x`` holds the top-``q`` eigenvectors.
        - ``step`` is ``1 / (2 * tail)``.
    """
    sample_count = int(s)
    top_k = int(q)

    all_vals, all_vecs = unitnorm_svd.nystrom_kernel_svd(
        nystrom_samples,
        kernel_fn, top_k + 1
    )

    tail = all_vals[top_k]
    head_vals = all_vals[:top_k]
    head_vecs = all_vecs[:, :top_k]
    scaling = (1 - tail / head_vals) / head_vals / sample_count

    k_sample_knots = kernel_fn(nystrom_samples, knots)
    return k_sample_knots.T @ (scaling * head_vecs), head_vecs, 1 / (2 * tail)


def EP3(cifar10_dataset, kernel_fn, knots, epochs=100, s=1000, q=10, lr=0.05,
        device=torch.device('cuda:0'), mem_gb=33):
    """Fit kernel-regression weights on ``knots`` with preconditioned steps.

    The name presumably refers to EigenPro 3; confirm. Each epoch does four
    things:

    1. Form the residual ``g = K_xz @ alpha - y``.
    2. Subtract a Nyström-based correction from ``K_xz^T g``.
    3. Hand the result to ``HilbertProjection.fit_hilbert_projection`` (run
       on ``device``).
    4. Step ``alpha`` by the returned ``theta``.

    Args:
        cifar10_dataset: object exposing ``train_x``, ``train_y``, ``test_x``
            and one-hot ``test_y``.
        kernel_fn: callable ``kernel_fn(a, b)`` returning a kernel matrix.
        knots: model centres, one per row.
        epochs: number of preconditioned steps.
        s: Nyström subsample size. It must not exceed the number of training
            points because sampling is without replacement (numpy raises
            ``ValueError`` otherwise).
        q: number of eigen-directions used by ``nystrom``.
        lr: ignored. The step size comes from ``nystrom``, divided by the
            number of training points. Kept for signature compatibility.
        device: device passed to ``HilbertProjection``; knots and gradients
            are moved there for the projection step.
        mem_gb: forwarded as ``mem_gb`` to ``fit_hilbert_projection``
            (presumably a memory budget in GB).

    Returns:
        ``(loss, accu)``: the per-epoch training MSE, measured before each
        update, and the test accuracy in percent (0-d tensors), measured after
        each update.
    """
    x_train = cifar10_dataset.train_x
    y_train = cifar10_dataset.train_y
    x_test = cifar10_dataset.test_x
    y_test = cifar10_dataset.test_y
    n_train = x_train.shape[0]
    n_labels = y_train.shape[1]
    p = knots.shape[0]

    k_xz = kernel_fn(x_train, knots)
    # Match the kernel's dtype/device so float64 or GPU kernels also work.
    alpha = torch.zeros((p, n_labels), dtype=k_xz.dtype, device=k_xz.device)

    # Sample without replacement: duplicated points would repeat rows of the
    # subsample kernel matrix and distort its spectrum.
    nystrom_ind = np.random.choice(n_train, s, replace=False)
    nystrom_samples = x_train[nystrom_ind]
    K_xsx = kernel_fn(nystrom_samples, x_train)
    preconditioner_matrix, E_x, lr = nystrom(nystrom_samples, knots, kernel_fn, s=s, q=q)
    lr /= n_train
    # E_x^T K(x_s, x) is identical every epoch: compute it once and drop the
    # (s x n) kernel block.
    Ex_Kxsx = E_x.T @ K_xsx
    del K_xsx

    knots_dev = knots.to(device)
    h_projection = HilbertProjection(kernel_fn, knots_dev, n_labels, device=device)

    # Test-side quantities do not change across epochs: compute them once.
    k_testz = kernel_fn(x_test, knots)
    y_true = y_test.argmax(dim=1)

    loss = []
    accu = []
    for t in range(epochs):
        print(f'number of epoch={t}')
        pred = k_xz @ alpha
        loss.append(mse(pred, y_train))
        g = pred - y_train
        gz_projection = k_xz.T @ g - preconditioner_matrix @ (Ex_Kxsx @ g)
        theta, _ = h_projection.fit_hilbert_projection(
            knots_dev,
            gz_projection.to(device), mem_gb=mem_gb,
            return_log=False)
        alpha = alpha - lr * theta.to(alpha.device)

        y_hat = (k_testz @ alpha).argmax(dim=1)
        accu.append(100 * torch.sum(y_hat == y_true) / y_hat.shape[0])
    return loss, accu

def falkon_run(cifar10_dataset, kernel_fn, options, p=100, epochs=100, lambda_reg=5e-3, bw=5.0):
    """Train a fresh FALKON model per iteration budget and score each one.

    Args:
        cifar10_dataset: object exposing ``train_x``, ``train_y``, ``test_x``
            and one-hot ``test_y``.
        kernel_fn: kernel object forwarded to ``falkon.Falkon(kernel=...)``.
        options: options forwarded to ``falkon.Falkon(options=...)``.
        p: number of centres, forwarded as Falkon's ``M``.
        epochs: an int, or an iterable of ints. Each value is used as
            ``maxiter`` for a separately trained model.
        lambda_reg: penalty forwarded to Falkon.
        bw: unused. The bandwidth is configured on ``kernel_fn`` by the caller.

    Returns:
        ``(loss, accu)``: one training MSE and one test accuracy in percent
        (a 0-d tensor) per entry of ``epochs``.
    """
    x_train = cifar10_dataset.train_x
    y_train = cifar10_dataset.train_y.float()
    x_test = cifar10_dataset.test_x
    y_true = cifar10_dataset.test_y.argmax(dim=1)

    # Accept a scalar budget as well as a sequence of budgets, normalised to
    # plain Python ints for Falkon's `maxiter`.
    maxiters = [int(t) for t in np.atleast_1d(epochs)]

    loss = []
    accu = []
    for t in maxiters:
        print(f't is:{t}')
        flk = falkon.Falkon(kernel=kernel_fn, penalty=lambda_reg, M=p, options=options, maxiter=t)

        flk.fit(x_train, y_train)
        loss.append(mse(flk.predict(x_train), y_train))

        y_hat = flk.predict(x_test).argmax(dim=1)
        accu.append(100 * torch.sum(y_hat == y_true) / y_hat.shape[0])

    return loss, accu


if __name__ == "__main__":
    # Uncomment for reproducible runs.
    # torch.manual_seed(0)
    # np.random.seed(0)

    n_train = 60000
    dataset = MNIST(n_train)
    name = 'mnist'
    bw = 5.0
    kernel_fn = lambda x, y: laplacian(x, y, bandwidth=bw)
    p = 1000                 # number of knots / Falkon centres
    s = int(n_train / 10)    # Nyström subsample size for EP3
    q = int(s / 10)          # eigen-directions kept by EP3's preconditioner
    epochs = 100
    runs = 1

    ###### regularizer ######
    lambda_reg = 0  # 1e-5
    print(f'reg. is: {lambda_reg }')
    print(f'bandwidth is: {bw}')

    falkon_acc = []
    falkon_time = []
    for r in range(runs):
        print(f'run {r}')
        t_start = time.perf_counter()
        # use_cpu must be a real bool: the string "False" is truthy.
        options = falkon.FalkonOptions(keops_active="yes", use_cpu=False)
        kernel_fn_flk = falkon.kernels.LaplacianKernel(sigma=bw, opt=options)
        Falkon_loss, accu_falkon = falkon_run(dataset, kernel_fn_flk, options, p=p, epochs=epochs, lambda_reg=lambda_reg, bw=bw)
        falkon_time.append(time.perf_counter() - t_start)
        falkon_acc.append(accu_falkon[-1])

    print("Falkon done.")

    EP3_accu_random = []
    EP3_accu_kmeans = []
    EP3_time = []  # wall-clock seconds of each random-knot EP3 run

    for r in range(runs):
        print(f"run = {r}")

        # Random knots: a subset of the training points.
        knots_ind = np.random.choice(n_train, p, replace=False)
        knots_random = dataset.train_x[knots_ind]
        GD_loss_random, accu_gd_random = GD(dataset, kernel_fn, knots_random, epochs)
        t_start = time.perf_counter()
        EP3_loss_random, accu_ep3_random = EP3(dataset, kernel_fn, knots_random, s=s, q=q, epochs=epochs)
        EP3_time.append(time.perf_counter() - t_start)

        # k-means knots: cluster centroids of the training set.
        kmeans = KMeans(n_clusters=p, mode='euclidean', verbose=1)
        kmeans.fit_predict(dataset.train_x)
        knots_kmeans = kmeans.centroids
        GD_loss_kmeans, accu_gd_kmeans = GD(dataset, kernel_fn, knots_kmeans, epochs)
        EP3_loss_kmeans, accu_ep3_kmeans = EP3(dataset, kernel_fn, knots_kmeans, s=s, q=q, epochs=epochs)
        EP3_accu_random.append(accu_ep3_random[-1])
        EP3_accu_kmeans.append(accu_ep3_kmeans[-1])

    print(f'FALKON accuracy(p={p}--n={n_train}): mean ={np.mean(falkon_acc)}--std={np.std(falkon_acc)}')
    print(f'EP3_random accuracy(p={p}--n={n_train}): mean ={np.mean(EP3_accu_random)}--std={np.std(EP3_accu_random)}')
    print(f'EP3_kmeans accuracy(p={p}--n={n_train}): mean ={np.mean(EP3_accu_kmeans)}--std={np.std(EP3_accu_kmeans)}')
    print(f'EP3_time seconds(p={p}--n={n_train}): mean ={np.mean(EP3_time)}--std={np.std(EP3_time)}')
    print(f'Falkon_time seconds(p={p}--n={n_train}): mean ={np.mean(falkon_time)}--std={np.std(falkon_time)}')

    results = {"accu_gd_random": accu_gd_random, "accu_gd_kmeans": accu_gd_kmeans}
    with open(f'{name}-accu-gd.pickle', 'wb') as f_pkl:
        pickle.dump(results, f_pkl)

    plot_size = 10
    fig, ax = plt.subplots(1, 1, figsize=(plot_size, plot_size))

    ax.plot(range(epochs), accu_gd_random, label='GD--random')
    ax.plot(range(epochs), accu_ep3_random, label='EP3--random')
    ax.plot(range(epochs), accu_gd_kmeans, label='GD--kmeans')
    ax.plot(range(epochs), accu_ep3_kmeans, label='EP3--kmeans')
    # falkon_run is given a single maxiter budget, so it yields one final
    # accuracy rather than a per-epoch curve: draw it as a reference line.
    ax.axhline(float(accu_falkon[-1]), color='k', linestyle='--', label='Falkon')

    ax.set_xlabel("Number of epochs")
    ax.set_ylabel("Test accuracy (%)")
    ax.grid()
    ax.legend()

    fig.savefig(f'{name}-baseline-gd.png', format='png', bbox_inches="tight")









