import torch, ipdb
import numpy as np
from datetime import datetime
from ..utils.datasets import taxidataset_all,Cifar5mDataset,dataset_custom,Cifar10Dataset,\
    cifar10mobilenetDataset,mnist_augment,woof,svhn,fashionmnist_augment
import pickle
from ..utils.kernel import gaussian, laplacian,ntk_1layer
from ..kernel_model import KernelModel
from ..gradient_correction import GradientCorrection
import timeit
from ..eigenpro_2.vanilla_eigenpro import vanilaep
from torch.nn.functional import one_hot
import ipdb
from fast_pytorch_kmeans import KMeans


def _dataset_class(dataset_name):
    """Return the dataset class registered under ``dataset_name``.

    Raises:
        ValueError: if ``dataset_name`` is not one of the supported names.
    """
    # An if-chain rather than a dict literal: only the requested class name
    # is looked up, and an unknown name reaches the explicit error below.
    if dataset_name == "cifar10":
        return Cifar10Dataset
    if dataset_name == "cifar10mobilenet":
        return cifar10mobilenetDataset
    if dataset_name == "fashionmnist":
        return fashionmnist_augment
    if dataset_name == "mnist_augment":
        return mnist_augment
    if dataset_name == "woof":
        return woof
    if dataset_name == "svhn":
        return svhn
    raise ValueError(
        f"Unsupported dataset_name {dataset_name!r}; expected one of "
        "'cifar10', 'cifar10mobilenet', 'fashionmnist', 'mnist_augment', "
        "'woof', 'svhn'."
    )


def run(
        batch_size=50000, num_knots=5000, n_train=50000, knots_method="random",
        nystrom_subsample_size=1000, preconditioner_level=30,
        projection_cutoff=1e-5, bandwidth=5, epochs=10,
        kernel_type="gaussian", seed=0,
        wandb_run=None, learning_rate_prefactor=0.1, n_gpus=4,
        dataset_name="taxi", alpha_augment=0, train_mode="train"
):
    """Build the datasets, preconditioner and kernel model, then fit the model.

    Args:
        batch_size: Forwarded to ``GradientCorrection``. The data loaders use
            ``gradient_correction_fn.batch_size * (n_gpus - 1)``.
        num_knots: Forwarded to the training dataset constructor.
        n_train: Accepted for interface compatibility; not used here.
        knots_method: Forwarded to the training dataset constructor.
        nystrom_subsample_size: Number of training rows drawn (without
            replacement) as Nystrom samples.
        preconditioner_level: Passed to ``GradientCorrection`` as ``level``.
        projection_cutoff: Forwarded to ``KernelModel``.
        bandwidth: Bandwidth for the "gaussian" and "laplacian" kernels.
        epochs: Number of epochs passed to ``model.fit``.
        kernel_type: One of "gaussian", "laplacian", "ntk_1layer".
        seed: Accepted for interface compatibility; not used here.
        wandb_run: Forwarded to ``GradientCorrection`` and ``KernelModel``.
        learning_rate_prefactor: Forwarded to ``GradientCorrection``.
        n_gpus: Total device count: ``n_gpus - 1`` compute devices
            (``cuda:0`` .. ``cuda:n_gpus-2``) plus one projection device
            (``cuda:n_gpus-1``). Every entry is the CPU when CUDA is
            unavailable.
        dataset_name: One of "cifar10", "cifar10mobilenet", "fashionmnist",
            "mnist_augment", "woof", "svhn". The signature default ("taxi")
            is not supported and raises ``ValueError``.
        alpha_augment: Passed to ``KernelModel`` as ``augment_alpha``; the
            ``augment`` flag is 1 when this is positive, otherwise 0.
        train_mode: ``mode`` given to the training dataset constructor.

    Returns:
        Tuple ``(alpha, acc_valid_ep3, RMSE, nystrom_samples, knots_x)``
        where the first three items are what ``model.fit`` returned.

    Raises:
        ValueError: if ``n_gpus`` is below 2, or ``kernel_type`` /
            ``dataset_name`` is not supported.
    """
    # Fail fast on configurations that previously crashed later on with an
    # unbound local (unknown kernel / dataset) or a zero DataLoader batch
    # size (no compute device).
    if n_gpus < 2:
        raise ValueError(
            "n_gpus must be at least 2 (one projection device plus at least "
            f"one compute device), got {n_gpus!r}."
        )

    ################ defining kernel #################
    if kernel_type == "gaussian":
        kernel_fn = lambda x, y: gaussian(x, y, bandwidth=bandwidth)
    elif kernel_type == "laplacian":
        kernel_fn = lambda x, y: laplacian(x, y, bandwidth=bandwidth)
    elif kernel_type == "ntk_1layer":
        kernel_fn = lambda x, y: ntk_1layer(x, y)
    else:
        raise ValueError(
            f"Unsupported kernel_type {kernel_type!r}; expected one of "
            "'gaussian', 'laplacian', 'ntk_1layer'."
        )

    dataset_cls = _dataset_class(dataset_name)

    ################ devices #################
    use_cuda = torch.cuda.is_available()
    devices = [
        torch.device(f'cuda:{g}' if use_cuda else "cpu")
        for g in range(n_gpus - 1)
    ]
    # The last device index is reserved for the projection step.
    projection_device = torch.device(f'cuda:{n_gpus - 1}' if use_cuda else "cpu")

    ################ datasets #################
    # Every supported dataset shares the same task settings.
    task = "classification"
    normalize_batch = False
    loss_with_std = False
    augment = 1 if alpha_augment > 0 else 0

    dataset_train = dataset_cls(
        num_knots=num_knots, mode=train_mode,
        device=devices, knots_method=knots_method,
    )
    dataset_test = dataset_cls(mode="test", device=devices)
    knots_x, knots_y = dataset_train.knots_x, dataset_train.knots_y

    # One copy of the knots per compute device, plus one on the projection
    # device (kept last in the list).
    knots_x_all = [knots_x.to(device) for device in devices]
    knots_x_all.append(knots_x.to(projection_device))

    # NOTE: an EigenPro-2 baseline (``vanilaep``) used to be evaluated here;
    # it is currently disabled.

    print('nystrom samples...')
    nystrom_ids = np.random.choice(
        len(dataset_train),
        size=nystrom_subsample_size, replace=False
    )
    nystrom_samples = dataset_train.x[nystrom_ids]

    print('Data Information', len(dataset_train))

    gradient_correction_fn = GradientCorrection(
        kernel_fn=kernel_fn,
        nystrom_samples=nystrom_samples, level=preconditioner_level,
        knots=knots_x_all, device=devices, batch_size=batch_size,
        learning_rate_prefactor=learning_rate_prefactor, wandb_run=wandb_run
    )

    # Each loader batch is split across the compute devices, so scale the
    # per-device batch size by their count.
    loader_batch_size = gradient_correction_fn.batch_size * len(devices)
    trainloader = torch.utils.data.DataLoader(
        dataset_train, batch_size=loader_batch_size,
        shuffle=True, num_workers=2)
    testloader = torch.utils.data.DataLoader(
        dataset_test, batch_size=loader_batch_size,
        shuffle=False, num_workers=2)

    model = KernelModel(
        kernel_fn=kernel_fn, knots=knots_x_all,
        gradient_correction_fn=gradient_correction_fn,
        projection_cutoff=projection_cutoff,
        projection_device=projection_device, device=devices,
        n_labels=knots_y.shape[1],
        wandb_run=wandb_run, track_time=True, task=task,
        normalize_batch=normalize_batch,
        loss_with_std=loss_with_std,
        train_size=len(dataset_train), augment=augment,
        augment_alpha=alpha_augment
    )
    alpha, acc_valid_ep3, RMSE = model.fit(
        trainloader, epochs=epochs, valid_loader=testloader)

    return alpha, acc_valid_ep3, RMSE, nystrom_samples, knots_x


if __name__ == "__main__":
    # Quick single-epoch smoke run. Because of the relative imports above,
    # launch this module with ``python -m <package>.<module>`` rather than as
    # a plain script path.
    #
    # This previously called ``run_cifar5m``, which is not defined in this
    # module, and unpacked three values although ``run`` returns five.
    # ``dataset_name`` is given explicitly because ``run``'s default ("taxi")
    # has no matching dataset branch.
    alpha, acc_valid_ep3, rmse, nystrom_samples, knots_x = run(
        batch_size=8192, epochs=1, num_knots=2000,
        n_train=50000, bandwidth=5, kernel_type="laplacian",
        dataset_name="cifar10",
    )

