import torch, ipdb
import numpy as np
from datetime import datetime
from ..utils.datasets import taxidataset_all,Cifar5mDataset,dataset_custom,Cifar5mmobilenetDataset,\
    mnist8mDataset,susy,imagenetmobilenetDataset,HIGGS
import pickle
from ..utils.kernel import gaussian, laplacian,ntk_1layer
from ..kernel_model import KernelModel
from ..gradient_correction import GradientCorrection
import timeit
from ..eigenpro_2.vanilla_eigenpro import vanilaep
from torch.nn.functional import one_hot
import ipdb
#from fast_pytorch_kmeans import KMeans
from ..utils.cifar10 import Cifar10DataLoader,Cifar10Dataset


# Kernel names accepted by ``run``; anything else is rejected up front.
_SUPPORTED_KERNELS = ("gaussian", "laplacian", "ntk_1layer")

# Dataset names accepted by ``run``.
_SUPPORTED_DATASETS = (
    "taxi", "cifar5m", "imagenetmobilenetDataset", "Cifar5mmobilenet",
    "mnist8m", "susy", "HIGGS",
)

# Datasets whose inputs are cast to float and divided by 255 before use.
_SCALED_BY_255 = ("cifar5m", "mnist8m")


def run(
        batch_size=50000, num_knots=5000, n_train=50000,
        nystrom_subsample_size=1000, preconditioner_level=30,
        projection_cutoff=1e-5, bandwidth=5, epochs=10,
        kernel_type="gaussian", seed=0,
        wandb_run=None, learning_rate_prefactor=0.1, n_gpus_outer=3, n_gpus_inner=1,
        dataset_name="taxi", load_checkpoint=None, knot_include=0, EP2=0
):
    """Build the dataset, kernel and preconditioner, then fit a ``KernelModel``.

    Args:
        batch_size: Forwarded to ``GradientCorrection``.
        num_knots: Number of knots requested from the dataset class.
        n_train: Training-set size requested from the dataset class.
        nystrom_subsample_size: Number of training rows sampled (without
            replacement) as Nystrom samples. Ignored when ``load_checkpoint``
            is given.
        preconditioner_level: Forwarded to ``GradientCorrection`` as ``level``.
        projection_cutoff: Forwarded to ``KernelModel``.
        bandwidth: Bandwidth for the gaussian / laplacian kernels. Not used
            by ``ntk_1layer``.
        epochs: Number of epochs passed to ``model.fit``.
        kernel_type: One of ``"gaussian"``, ``"laplacian"``, ``"ntk_1layer"``.
        seed: Seed for ``torch.manual_seed`` and ``np.random.seed``.
        wandb_run: Forwarded to ``GradientCorrection`` and ``KernelModel``.
        learning_rate_prefactor: Forwarded to ``GradientCorrection``.
        n_gpus_outer: Number of training devices (``cuda:0`` ..
            ``cuda:n_gpus_outer-1``, or CPU when CUDA is unavailable). One
            training shard is built per device.
        n_gpus_inner: Number of projection devices; these are the CUDA
            indices that follow the training devices.
        dataset_name: One of ``"taxi"``, ``"cifar5m"``,
            ``"imagenetmobilenetDataset"``, ``"Cifar5mmobilenet"``,
            ``"mnist8m"``, ``"susy"``, ``"HIGGS"``.
        load_checkpoint: Optional path to a pickle file holding the keys
            ``'nystrom_samples'``, ``'knots_x'`` and ``'knots_y'``. Also
            forwarded to ``KernelModel`` as ``checkpoint``.
        knot_include: Forwarded to the dataset class (not used for
            ``"taxi"`` or ``"cifar5m"``).
        EP2: When truthy, also run ``vanilaep`` as a baseline. Always
            disabled for the ``"taxi"`` dataset.

    Returns:
        Tuple ``(alpha, acc_valid_ep3, nystrom_samples, knots_x,
        acc_ep2_test)``. The first two are the values returned by
        ``model.fit``; ``acc_ep2_test`` is an empty list when the EP2
        baseline was not run.

    Raises:
        ValueError: If ``kernel_type`` or ``dataset_name`` is not supported.
    """
    # Fail fast on unsupported names instead of hitting an unrelated
    # "referenced before assignment" error after the data has been loaded.
    if kernel_type not in _SUPPORTED_KERNELS:
        raise ValueError(
            f"Unknown kernel_type {kernel_type!r}; "
            f"expected one of {_SUPPORTED_KERNELS}"
        )
    if dataset_name not in _SUPPORTED_DATASETS:
        raise ValueError(
            f"Unknown dataset_name {dataset_name!r}; "
            f"expected one of {_SUPPORTED_DATASETS}"
        )

    torch.manual_seed(seed)
    np.random.seed(seed)

    use_cuda = torch.cuda.is_available()
    devices = [
        torch.device(f'cuda:{g}' if use_cuda else "cpu")
        for g in range(n_gpus_outer)
    ]
    # Projection devices take the CUDA indices right after the training ones.
    projection_device = [
        torch.device(f'cuda:{n_gpus_outer + g}' if use_cuda else "cpu")
        for g in range(n_gpus_inner)
    ]

    ################ defining kernel #################
    if kernel_type == "gaussian":
        kernel_fn = lambda x, y: gaussian(x, y, bandwidth=bandwidth)
    elif kernel_type == "laplacian":
        kernel_fn = lambda x, y: laplacian(x, y, bandwidth=bandwidth)
    else:  # "ntk_1layer" (validated above)
        kernel_fn = lambda x, y: ntk_1layer(x, y)

    ################ loading dataset #################
    is_taxi = dataset_name == "taxi"
    if is_taxi:
        dataset = taxidataset_all(
            train_size=n_train, n_test=100_000, num_knots=num_knots,
            device=devices)
    elif dataset_name == "cifar5m":
        dataset = Cifar5mDataset(
            subsample=n_train, num_knots=num_knots, device=devices)
    else:
        # These dataset classes share one constructor signature.
        dataset_cls = {
            "imagenetmobilenetDataset": imagenetmobilenetDataset,
            "Cifar5mmobilenet": Cifar5mmobilenetDataset,
            "mnist8m": mnist8mDataset,
            "susy": susy,
            "HIGGS": HIGGS,
        }[dataset_name]
        dataset = dataset_cls(
            subsample=n_train, num_knots=num_knots,
            device=devices, knot_include=knot_include)

    normalize = dataset_name in _SCALED_BY_255
    normalize_batch = normalize

    if is_taxi:
        task = "regression"
        loss_with_std = True
        run_ep2 = False  # the EP2 baseline is never run for taxi
        knots_x, knots_y = dataset.knots_x, dataset.knots_y
    else:
        task = "classification"
        loss_with_std = False
        run_ep2 = EP2
        knots_x = dataset.knots_x.float() / 255 if normalize else dataset.knots_x
        knots_y = one_hot(dataset.knots_y)

    # One training shard per outer device, plus a single test set.
    traindatasets = [
        dataset_custom(
            dataset.X_train_all[ind], dataset.y_train_all[ind], dataset=dataset)
        for ind in range(n_gpus_outer)
    ]
    testdatasets = dataset_custom(dataset.X_test, dataset.y_test, dataset=dataset)

    if load_checkpoint is not None:
        # NOTE: pickle.load can execute arbitrary code; only load checkpoint
        # files that come from a trusted source.
        with open(load_checkpoint, 'rb') as fp:
            checkpoint = pickle.load(fp)

    print('nystrom samples...')

    if load_checkpoint is not None:
        print("load from checkpoint")
        nystrom_samples = checkpoint['nystrom_samples']
        knots_x_all = checkpoint['knots_x']
        knots_y = checkpoint['knots_y']
        del checkpoint
    else:
        # Replicate the knots onto every training device.
        knots_x_all = [knots_x.to(device) for device in devices]

        # Passing the population size as an int draws the same indices as
        # passing range(N), without first converting the range to an array.
        nystrom_ids = np.random.choice(
            dataset.X_train.shape[0],
            size=nystrom_subsample_size, replace=False
        )
        nystrom_samples = dataset.X_train[nystrom_ids]

        if normalize:
            nystrom_samples = nystrom_samples.to(devices[0]).float() / 255

    ################ EP2 ######################
    acc_ep2_test = []

    if run_ep2:
        test_for_ep2 = testdatasets.X.cpu().numpy().astype('float32')
        if normalize:
            test_for_ep2 = test_for_ep2 / 255
        result = vanilaep(
            kernel_fn, knots_x.cpu().numpy().astype('float32'),
            knots_y.cpu().numpy().astype('float32'),
            test_for_ep2,
            one_hot(testdatasets.y).cpu().numpy().astype('float32'),
            devices[0])
        # Read the accuracy stored under the last key of the result mapping.
        last_epoch = list(result)[-1]
        acc_ep2_test = result[last_epoch][1]['multiclass-acc']

    print('Data Information', dataset.X_train.shape[0])

    gradient_correction_fn = GradientCorrection(
        kernel_fn=kernel_fn,
        nystrom_samples=nystrom_samples, level=preconditioner_level,
        knots=knots_x_all, device=devices, batch_size=batch_size,
        learning_rate_prefactor=learning_rate_prefactor, wandb_run=wandb_run
    )

    model = KernelModel(
        kernel_fn=kernel_fn, knots=knots_x_all, knots_y=knots_y,
        gradient_correction_fn=gradient_correction_fn,
        projection_cutoff=projection_cutoff, projection_device=projection_device,
        device=devices,
        n_labels=knots_y.shape[1],
        wandb_run=wandb_run, track_time=True, task=task,
        normalize_batch=normalize_batch,
        loss_with_std=loss_with_std, checkpoint=load_checkpoint,
        train_size=dataset.X_train.shape[0], dataset_name=dataset_name,
        kernel_type=kernel_type
    )
    alpha, acc_valid_ep3 = model.fit(
        traindatasets, epochs=epochs, valid_loader=testdatasets)

    return alpha, acc_valid_ep3, nystrom_samples, knots_x, acc_ep2_test


if __name__ == "__main__":
    # Smoke run: one epoch with a laplacian kernel. ``run`` returns five
    # values, so all of them are unpacked here.
    # NOTE(review): dataset_name="cifar5m" is presumably what the previous
    # ``run_cifar5m`` call intended — confirm.
    alpha, acc_valid_ep3, nystrom_samples, knots_x, acc_ep2_test = run(
        batch_size=8192, epochs=1, num_knots=2000,
        n_train=50000, bandwidth=5, kernel_type="laplacian",
        dataset_name="cifar5m",
    )


