import numpy as np
from torch.utils.data import Dataset, DataLoader
import torch
import torchvision
import torchvision.transforms as transforms
from torch.nn.functional import one_hot
import os, ipdb
from os.path import join as pjoin
import csv
import pandas as pd
# from sklearn.cluster import KMeans
from fast_pytorch_kmeans import KMeans
import torch
import timm

def unpickle(file):
    """Load and return the pickled object stored in ``file``.

    The file is read in binary mode and unpickled with ``encoding='bytes'``,
    so 8-bit strings written by Python 2 come back as ``bytes`` (callers such
    as the CIFAR-10 loader index the result with ``b'data'`` / ``b'labels'``).

    Security: ``pickle.load`` can execute arbitrary code -- only use this on
    trusted files.
    """
    import pickle
    with open(file, 'rb') as fo:
        data = pickle.load(fo, encoding='bytes')
    return data

class svhn(Dataset):
    """SVHN split backed by pre-saved ``.pt`` tensors.

    File paths are built as ``DATADIR + name``, so ``DATADIR`` must be empty
    or end with a path separator.

    Modes:
      * ``"train"``: ``x``/``y`` from ``SVHN_{x,y}_train.pt`` plus knots.
      * ``"test"``: ``x``/``y`` from ``SVHN_{x,y}_test.pt``.
      * ``"augment"``: knots are the original training set; ``x``/``y`` are
        the concatenation of ``SVHN_{x,y}_train_seed={s}.pt`` for each ``s``
        in ``seeds``.

    Args:
        num_knots: number of k-means centroids (``knots_method="kmeans"``).
        mode: ``"train"``, ``"test"`` or ``"augment"``.
        knots_method: ``"random"`` uses every training point as a knot
            (moved to ``cuda:0``); ``"kmeans"`` fits centroids on ``cuda:0``.
        DATADIR: path prefix prepended to every file name.
        seeds: iterable of seed values naming the augmentation files
            (augment mode only).
        **kwargs: accepted and ignored.

    Raises:
        ValueError: in augment mode when ``seeds`` is missing or empty.
    """

    def __init__(self, num_knots=50_000, mode="train", knots_method="random",
                 DATADIR='', seeds=None,
                 **kwargs):
        print("SVHN")
        if mode == "train":
            self.x = torch.load(DATADIR + "SVHN_x_train.pt").cpu()
            # Cast once so both knot strategies see int64 class indices,
            # which one_hot requires.
            self.y = torch.load(DATADIR + "SVHN_y_train.pt").cpu().long()
            if knots_method == "random":
                # Despite the name, every training point is used as a knot.
                self.knots_x, self.knots_y = self.x.to('cuda:0'), one_hot(self.y)
            elif knots_method == "kmeans":
                print('kmeans ...')
                kmeans = KMeans(n_clusters=num_knots, mode='euclidean', max_iter=100, verbose=1)
                kmeans.fit(self.x.to('cuda:0'))
                self.knots_x = kmeans.centroids
                # NOTE(review): centroids carry no labels; knots_y holds the
                # one-hot labels of all training points, so its row count
                # differs from knots_x.
                self.knots_y = one_hot(self.y)
                print('kmeans done.')

        elif mode == "test":
            self.x = torch.load(DATADIR + "SVHN_x_test.pt")
            self.y = torch.load(DATADIR + "SVHN_y_test.pt").long()

        elif mode == "augment":
            print("dataset:augmentation")
            seeds = list(seeds) if seeds is not None else []
            if not seeds:
                # torch.cat([]) would fail with an opaque error otherwise.
                raise ValueError(
                    "svhn(mode='augment') needs a non-empty `seeds` to locate "
                    "the SVHN_x_train_seed={s}.pt files")
            self.knots_x = torch.load(DATADIR + "SVHN_x_train.pt")
            # one_hot only accepts int64 tensors.
            self.knots_y = one_hot(torch.load(DATADIR + "SVHN_y_train.pt").long())

            xs, ys = [], []
            for ind, s in enumerate(seeds):
                print(f'part {ind}')
                # One file pair per seed value.
                xs.append(torch.load(DATADIR + f'SVHN_x_train_seed={s}.pt').cpu())
                ys.append(torch.load(DATADIR + f'SVHN_y_train_seed={s}.pt').cpu())
            self.x = torch.cat(xs)
            self.y = torch.cat(ys)

    def __len__(self):
        return self.y.shape[0]

    def __getitem__(self, idx):
        return (
            self.x[idx],
            self.y[idx]
            )

class woof(Dataset):
    """ImageNet-subset (MobileNet features) split backed by ``.pt`` tensors.

    File paths are built as ``DATADIR + name``, so ``DATADIR`` must be empty
    or end with a path separator.

    Modes:
      * ``"train"``: ``imagenet_mobilenet_{x,y}train.pt`` plus knots.
      * ``"test"``: ``imagenet_mobilenet_{x,y}test.pt``.
      * ``"augment"``: knots are the training set; ``x``/``y`` come from
        ``imagenet_mobilenet_{x,y}augment.pt``.

    Args:
        num_knots: number of k-means centroids (``knots_method="kmeans"``).
        mode: ``"train"``, ``"test"`` or ``"augment"``.
        knots_method: ``"random"`` uses every training point as a knot
            (moved to ``cuda:0``); ``"kmeans"`` fits centroids on ``cuda:0``.
        DATADIR: path prefix prepended to every file name.
        **kwargs: accepted and ignored.
    """

    def __init__(self, num_knots=50_000, mode="train", knots_method="random",
                 DATADIR='',
                 **kwargs):
        print("woof_augment")
        if mode == "train":
            self.x = torch.load(DATADIR + "imagenet_mobilenet_xtrain.pt").cpu()
            # Cast once so both knot strategies see int64 class indices.
            self.y = torch.load(DATADIR + "imagenet_mobilenet_ytrain.pt").cpu().long()
            if knots_method == "random":
                # Despite the name, every training point is used as a knot.
                self.knots_x, self.knots_y = self.x.to('cuda:0'), one_hot(self.y)
            elif knots_method == "kmeans":
                print('kmeans ...')
                kmeans = KMeans(n_clusters=num_knots, mode='euclidean', max_iter=100, verbose=1)
                kmeans.fit(self.x.to('cuda:0'))
                self.knots_x = kmeans.centroids
                # NOTE(review): centroids carry no labels; knots_y holds the
                # one-hot labels of all training points.
                self.knots_y = one_hot(self.y)
                print('kmeans done.')

        elif mode == "test":
            self.x = torch.load(DATADIR + "imagenet_mobilenet_xtest.pt")
            self.y = torch.load(DATADIR + "imagenet_mobilenet_ytest.pt").long()

        elif mode == "augment":
            print("dataset:augmentation")
            self.knots_x = torch.load(DATADIR + "imagenet_mobilenet_xtrain.pt")
            # one_hot only accepts int64 tensors.
            self.knots_y = one_hot(torch.load(DATADIR + "imagenet_mobilenet_ytrain.pt").long())
            # The augmented data is a single, seed-independent file pair, so
            # it is loaded exactly once.
            self.x = torch.load(DATADIR + 'imagenet_mobilenet_xaugment.pt').cpu()
            self.y = torch.load(DATADIR + 'imagenet_mobilenet_yaugment.pt').cpu()

    def __len__(self):
        return self.y.shape[0]

    def __getitem__(self, idx):
        return (
            self.x[idx],
            self.y[idx]
            )


class fashionmnist_augment(Dataset):
    """FashionMNIST split backed by ``.pt`` tensors, flattened to 784 features.

    File paths are built as ``DATADIR + name``, so ``DATADIR`` must be empty
    or end with a path separator.

    Modes:
      * ``"train"``: ``FashionMNIST_{x,y}_train.pt`` plus knots.
      * ``"test"``: ``FashionMNIST_{x,y}_test.pt``.
      * ``"augment"``: knots are the training set; ``x``/``y`` come from
        ``FashionMNIST_{x,y}_train_seed.pt``.

    Args:
        num_knots: number of k-means centroids (``knots_method="kmeans"``).
        mode: ``"train"``, ``"test"`` or ``"augment"``.
        knots_method: ``"random"`` uses every training point as a knot
            (moved to ``cuda:0``); ``"kmeans"`` fits centroids on ``cuda:0``.
        DATADIR: path prefix prepended to every file name.
        **kwargs: accepted and ignored.
    """

    def __init__(self, num_knots=50_000, mode="train", knots_method="random",
                 DATADIR='',
                 **kwargs):
        print("fashionmnist_augment")
        if mode == "train":
            self.x = torch.load(DATADIR + "FashionMNIST_x_train.pt").reshape(-1, 28 * 28)
            # Cast once so both knot strategies see int64 class indices.
            self.y = torch.load(DATADIR + "FashionMNIST_y_train.pt").long()
            if knots_method == "random":
                # Despite the name, every training point is used as a knot.
                self.knots_x, self.knots_y = self.x.to('cuda:0'), one_hot(self.y)
            elif knots_method == "kmeans":
                print('kmeans ...')
                kmeans = KMeans(n_clusters=num_knots, mode='euclidean', max_iter=100, verbose=1)
                kmeans.fit(self.x.to('cuda:0'))
                self.knots_x = kmeans.centroids
                # NOTE(review): centroids carry no labels; knots_y holds the
                # one-hot labels of all training points.
                self.knots_y = one_hot(self.y)
                print('kmeans done.')

        elif mode == "test":
            self.x = torch.load(DATADIR + "FashionMNIST_x_test.pt").reshape(-1, 28 * 28)
            self.y = torch.load(DATADIR + "FashionMNIST_y_test.pt").long()

        elif mode == "augment":
            print("dataset:augmentation")
            self.knots_x = torch.load(DATADIR + "FashionMNIST_x_train.pt").reshape(-1, 28 * 28)
            # one_hot only accepts int64 tensors.
            self.knots_y = one_hot(torch.load(DATADIR + "FashionMNIST_y_train.pt").long())
            # The augmented data is a single, seed-independent file pair, so
            # it is loaded exactly once.
            self.x = torch.load(DATADIR + 'FashionMNIST_x_train_seed.pt').reshape(-1, 28 * 28)
            self.y = torch.load(DATADIR + 'FashionMNIST_y_train_seed.pt')

    def __len__(self):
        return self.y.shape[0]

    def __getitem__(self, idx):
        return (
            self.x[idx],
            self.y[idx]
            )
class mnist_augment(Dataset):
    """MNIST split backed by ``.pt`` tensors, flattened to 784 features.

    File paths are built as ``DATADIR + name``, so ``DATADIR`` must be empty
    or end with a path separator.

    Modes:
      * ``"train"``: ``MNIST_{x,y}_train.pt`` plus knots.
      * ``"test"``: ``MNIST_{x,y}_test.pt``.
      * ``"augment"``: knots are the training set; ``x``/``y`` come from
        ``MNIST_{x,y}_train_seed.pt``.

    Args:
        num_knots: number of k-means centroids (``knots_method="kmeans"``).
        mode: ``"train"``, ``"test"`` or ``"augment"``.
        knots_method: ``"random"`` uses every training point as a knot
            (moved to ``cuda:0``); ``"kmeans"`` fits centroids on ``cuda:0``.
        DATADIR: path prefix prepended to every file name.
        **kwargs: accepted and ignored.
    """

    def __init__(self, num_knots=50_000, mode="train", knots_method="random",
                 DATADIR='',
                 **kwargs):
        print("MNIST_augment")
        if mode == "train":
            self.x = torch.load(DATADIR + "MNIST_x_train.pt").reshape(-1, 28 * 28)
            # Cast once so both knot strategies see int64 class indices.
            self.y = torch.load(DATADIR + "MNIST_y_train.pt").long()
            if knots_method == "random":
                # Despite the name, every training point is used as a knot.
                self.knots_x, self.knots_y = self.x.to('cuda:0'), one_hot(self.y)
            elif knots_method == "kmeans":
                print('kmeans ...')
                kmeans = KMeans(n_clusters=num_knots, mode='euclidean', max_iter=100, verbose=1)
                kmeans.fit(self.x.to('cuda:0'))
                self.knots_x = kmeans.centroids
                # NOTE(review): centroids carry no labels; knots_y holds the
                # one-hot labels of all training points.
                self.knots_y = one_hot(self.y)
                print('kmeans done.')

        elif mode == "test":
            self.x = torch.load(DATADIR + "MNIST_x_test.pt").reshape(-1, 28 * 28)
            self.y = torch.load(DATADIR + "MNIST_y_test.pt").long()

        elif mode == "augment":
            print("dataset:augmentation")
            self.knots_x = torch.load(DATADIR + "MNIST_x_train.pt").reshape(-1, 28 * 28)
            # one_hot only accepts int64 tensors.
            self.knots_y = one_hot(torch.load(DATADIR + "MNIST_y_train.pt").long())
            # The augmented data is a single, seed-independent file pair, so
            # it is loaded exactly once.
            self.x = torch.load(DATADIR + 'MNIST_x_train_seed.pt').reshape(-1, 28 * 28)
            self.y = torch.load(DATADIR + 'MNIST_y_train_seed.pt')

    def __len__(self):
        return self.y.shape[0]

    def __getitem__(self, idx):
        return (
            self.x[idx],
            self.y[idx]
            )


class Cifar10Dataset(Dataset):
    """CIFAR-10 read from the python-pickle batch files in ``DATADIR``.

    Pixel values are divided by 255.  In ``"train"`` mode knots are built
    as well: ``"random"`` samples ``num_knots`` training points (moved to
    ``cuda:0``) together with their one-hot labels; ``"kmeans"`` fits
    centroids on ``cuda:0``.

    Args:
        DATADIR: directory containing ``data_batch_*`` and ``test_batch``.
        num_knots: number of knots to sample / centroids to fit.
        mode: ``"train"`` or ``"test"``; any other value leaves ``x`` and
            ``y`` as empty lists.
        knots_method: ``"random"`` or ``"kmeans"``.
        device: stored on ``self.device``.
        validation: accepted but unused.
        id_offset: stored on ``self.id_offset``.
        **kwargs: accepted and ignored.
    """

    def __init__(self, DATADIR='',
                 num_knots=50_000, mode="train", knots_method="random",
                 device=torch.device('cpu'), validation=False, id_offset=0,
                 **kwargs):
        # torch's Dataset takes no constructor arguments; forwarding **kwargs
        # would raise TypeError for any extra keyword.
        super().__init__()
        parts = 4
        self.device = device
        self.x = []
        self.y = []
        self.id_offset = id_offset
        if mode == "train":
            # Reads data_batch_1 .. data_batch_{parts+1}, capped at 6 files.
            for i in range(min(6, parts + 1)):
                data_dict = unpickle(os.path.join(DATADIR, f'data_batch_{i+1}'))
                self.x.append(data_dict[b'data'])
                self.y.append(np.array(data_dict[b'labels']))
            self.x = torch.from_numpy(np.concatenate(self.x)) / 255.0
            self.y = torch.from_numpy(np.concatenate(self.y))

            train_size = self.x.shape[0]
            if knots_method == "random":
                # Sample without replacement whenever enough points exist so
                # the knots are distinct.
                self.knot_index = np.random.choice(
                    train_size, num_knots, replace=num_knots > train_size)
                self.y = self.y.long()
                self.knots_x = self.x[self.knot_index].to('cuda:0')
                # knots_y must be the labels of the sampled knots.  Encode the
                # full label vector first so the class count does not depend
                # on which classes happen to be sampled.
                self.knots_y = one_hot(self.y)[self.knot_index]
            elif knots_method == "kmeans":
                print('kmeans ...')
                kmeans = KMeans(n_clusters=num_knots, mode='euclidean', max_iter=100, verbose=1)
                kmeans.fit(self.x.to('cuda:0'))
                self.knots_x = kmeans.centroids
                # NOTE(review): centroids carry no labels; knots_y holds the
                # one-hot labels of all training points.
                self.knots_y = one_hot(self.y.long())
                print('kmeans done.')

        elif mode == "test":
            data_dict = unpickle(os.path.join(DATADIR, 'test_batch'))
            self.x = torch.from_numpy(np.concatenate([data_dict[b'data']])) / 255.0
            self.y = torch.from_numpy(np.array(data_dict[b'labels']))

    def __len__(self):
        return self.y.shape[0]

    def __getitem__(self, idx):
        return (
            self.x[idx],
            self.y[idx]
            )

class cifar10mobilenetDataset(Dataset):
    """CIFAR-10 MobileNetV2 features backed by ``.pt`` tensors.

    File paths are built as ``DATADIR + name`` (prefix
    ``ciar10_mobilenetv2_100_``), so ``DATADIR`` must be empty or end with a
    path separator.

    Modes:
      * ``"train"``: ``..._feature_train.pt`` / ``..._y_train.pt`` plus knots.
      * ``"test"``: ``..._feature_test.pt`` / ``..._y_test.pt``.
      * ``"augment"``: knots are the training set; ``x``/``y`` are the
        concatenation of ``..._feature_train_seed={s}.pt`` /
        ``..._y_train_seed={s}.pt`` for each ``s`` in ``seeds``.

    Args:
        num_knots: number of k-means centroids (``knots_method="kmeans"``).
        mode: ``"train"``, ``"test"`` or ``"augment"``.
        knots_method: ``"random"`` uses every training point as a knot
            (moved to ``cuda:0``); ``"kmeans"`` fits centroids on ``cuda:0``.
        DATADIR: path prefix prepended to every file name.
        seeds: iterable of seed values naming the augmentation files
            (augment mode only).
        **kwargs: accepted and ignored.

    Raises:
        ValueError: in augment mode when ``seeds`` is missing or empty.
    """

    def __init__(self, num_knots=50_000, mode="train", knots_method="random",
                 DATADIR='', seeds=None,
                 **kwargs):
        print("cifar10mobilenetDataset")
        if mode == "train":
            self.x = torch.load(DATADIR + "ciar10_mobilenetv2_100_feature_train.pt")
            # Cast once so both knot strategies see int64 class indices.
            self.y = torch.load(DATADIR + "ciar10_mobilenetv2_100_y_train.pt").long()
            if knots_method == "random":
                # Despite the name, every training point is used as a knot.
                self.knots_x, self.knots_y = self.x.to('cuda:0'), one_hot(self.y)
            elif knots_method == "kmeans":
                print('kmeans ...')
                kmeans = KMeans(n_clusters=num_knots, mode='euclidean', max_iter=100, verbose=1)
                kmeans.fit(self.x.to('cuda:0'))
                self.knots_x = kmeans.centroids
                # NOTE(review): centroids carry no labels; knots_y holds the
                # one-hot labels of all training points.
                self.knots_y = one_hot(self.y)
                print('kmeans done.')

        elif mode == "test":
            self.x = torch.load(DATADIR + "ciar10_mobilenetv2_100_feature_test.pt")
            self.y = torch.load(DATADIR + "ciar10_mobilenetv2_100_y_test.pt").long()

        elif mode == "augment":
            print("dataset:augmentation")
            seeds = list(seeds) if seeds is not None else []
            if not seeds:
                # torch.cat([]) would fail with an opaque error otherwise.
                raise ValueError(
                    "cifar10mobilenetDataset(mode='augment') needs a non-empty "
                    "`seeds` to locate the *_train_seed={s}.pt files")
            knots_x = torch.load(DATADIR + "ciar10_mobilenetv2_100_feature_train.pt")
            knots_y = torch.load(DATADIR + "ciar10_mobilenetv2_100_y_train.pt").long()
            self.knots_x, self.knots_y = knots_x, one_hot(knots_y)

            xs, ys = [], []
            for ind, s in enumerate(seeds):
                print(f'part {ind}')
                xs.append(torch.load(DATADIR + f'ciar10_mobilenetv2_100_feature_train_seed={s}.pt'))
                ys.append(torch.load(DATADIR + f'ciar10_mobilenetv2_100_y_train_seed={s}.pt'))
            self.x = torch.cat(xs)
            self.y = torch.cat(ys)

    def __len__(self):
        return self.y.shape[0]

    def __getitem__(self, idx):
        return (
            self.x[idx],
            self.y[idx]
            )






class Cifar5mDataset(Dataset):
    """CIFAR-5m tensors loaded from ``part{i}_X.pt`` / ``part{i}_y.pt`` files.

    Parts ``0..parts`` (inclusive) are concatenated into the training set
    (on CPU); the first ``n_test`` rows of part 5 form the test set, loaded
    onto ``device[0]``.  The (optionally subsampled) training set is then
    split into contiguous chunks, one per device, stored in ``X_train_all``
    / ``y_train_all``; the last device also receives the remainder.

    NOTE(review): with ``parts >= 5`` part 5 is used for both training and
    testing.

    Args:
        DATADIR: directory containing the part files.
        parts: index of the last training part to load.
        device: a single device or a list of devices.
        subsample: if given, keep this many training rows, drawn at random
            without replacement.
        n_test: number of test rows kept.
        num_knots: when not ``None``, the whole (pre-subsample) training set
            is exposed as ``knots_x`` / ``knots_y``; the value itself is not
            otherwise used.
        **kwargs: accepted and ignored.
    """

    def __init__(self,
                 DATADIR='',
                 parts=1,
                 device=torch.device('cpu'), subsample=None,
                 n_test=10000, num_knots=None,
                 **kwargs):
        # torch's Dataset takes no constructor arguments; forwarding **kwargs
        # would raise TypeError for any extra keyword.
        super().__init__()
        # The code below indexes and iterates ``device``; accept a single
        # device (the default) as well as a list.
        if isinstance(device, (torch.device, str)):
            device = [device]
        self.device = device

        print('Loading cifar5m train set...')
        X_parts, y_parts = [], []
        for ind in range(parts + 1):
            print(f'part={ind}')
            X_parts.append(torch.load(pjoin(DATADIR, f'part{ind}_X.pt'), torch.device('cpu')))
            y_parts.append(torch.load(pjoin(DATADIR, f'part{ind}_y.pt'), torch.device('cpu')))
        print("Loading cifar5m test set...")
        test_device = torch.device(device[0])
        self.X_test = torch.load(pjoin(DATADIR, 'part5_X.pt'), test_device)[:n_test]
        self.y_test = torch.load(pjoin(DATADIR, 'part5_y.pt'), test_device)[:n_test]

        self.X_train = torch.cat(X_parts)
        self.y_train = torch.cat(y_parts)

        if num_knots is not None:
            self.knots_x = self.X_train
            self.knots_y = self.y_train

        if subsample is not None:
            # The knots (when requested) are the entire training set, so there
            # are no non-knot rows to restrict sampling to; draw from all rows.
            randomind = np.random.choice(
                np.arange(self.y_train.shape[0]), size=subsample, replace=False)
            self.X_train = self.X_train[randomind]
            self.y_train = self.y_train[randomind]

        rows_per_device = self.X_train.shape[0] // len(device)
        self.X_train_all = []
        self.y_train_all = []
        for ind, g in enumerate(device):
            start = ind * rows_per_device
            # The last device takes everything that is left.
            stop = None if ind == len(device) - 1 else start + rows_per_device
            self.X_train_all.append(self.X_train[start:stop].to(g))
            self.y_train_all.append(self.y_train[start:stop].to(g))


class taxidataset_all(Dataset):

    print('Dataset all...')

    def __init__(self,
                 DATADIR='',
                 parts=9,
                 device=[torch.device('cpu')], train_size=1000,
                 n_test=500000, num_knots=None,
                 **kwargs):
        """Load, shuffle, split and standardize the taxi regression data.

        Training data is the concatenation of ``x_part{i}_taxi.pt`` /
        ``y_part{i}_taxi.pt`` (targets cast to float64) for
        ``i in range(parts)``; the test set is ``x_test_taxi.pt`` /
        ``y_test_taxi.pt``.  After one shuffle, the first ``num_knots`` rows
        (if requested) become the knots and the following ``train_size`` rows
        the training set, so the two never overlap.  Knots, train and test
        are standardized with the training-set mean/std and cast to float32;
        the training set is also split across ``device`` into
        ``X_train_all`` / ``y_train_all``.

        Args:
            DATADIR: directory containing the ``.pt`` files.
            parts: number of training part files to read.
            device: list of devices; knots are moved to ``device[0]``.
            train_size: number of training rows kept.
            n_test: accepted but currently unused.
            num_knots: number of knot rows, or ``None`` for no knots.
            **kwargs: accepted and ignored.
        """
        # torch's Dataset takes no constructor arguments; forwarding **kwargs
        # would raise TypeError for any extra keyword.
        super().__init__()
        self.device = device

        print('loding training and testing for taxi data set:')
        x_parts, y_parts = [], []
        for i in range(parts):
            print(f'part:{i}')
            x_parts.append(torch.load(os.path.join(DATADIR, f'x_part{i}_taxi.pt')))
            y_parts.append(torch.load(os.path.join(DATADIR, f'y_part{i}_taxi.pt')).double())
        X_all = torch.cat(x_parts)
        y_all = torch.cat(y_parts)

        self.X_test = torch.load(os.path.join(DATADIR, 'x_test_taxi.pt'))
        self.y_test = torch.load(os.path.join(DATADIR, 'y_test_taxi.pt'))

        indices = np.arange(X_all.shape[0])
        np.random.shuffle(indices)
        X_all = X_all[indices]
        y_all = y_all[indices]

        offset = 0
        if num_knots is not None:
            self.knots_x = X_all[:num_knots]
            self.knots_y = y_all[:num_knots]
            offset = num_knots

        print('indices...')
        self.X_train = X_all[offset:offset + train_size]
        self.y_train = y_all[offset:offset + train_size]

        def _safe_std(t):
            # Column std with 0/NaN entries (constant columns, single row)
            # replaced by 1, so standardization never divides by zero.
            s = torch.std(t, dim=0)
            return torch.where((s == 0) | torch.isnan(s), torch.ones_like(s), s)

        print('mean and std calculation...')
        self.mean_x_tr = torch.mean(self.X_train, dim=0)
        self.std_x_tr = _safe_std(self.X_train)
        self.mean_y_tr = torch.mean(self.y_train, dim=0)
        self.std_y_tr = _safe_std(self.y_train)

        print(f'std of y_train:{self.std_y_tr}')

        if num_knots is not None:
            self.knots_x = ((self.knots_x - self.mean_x_tr) / self.std_x_tr).float().to(device[0])
            self.knots_y = ((self.knots_y - self.mean_y_tr) / self.std_y_tr).float().to(device[0])

        self.X_train = (self.X_train - self.mean_x_tr) / self.std_x_tr
        self.y_train = (self.y_train - self.mean_y_tr) / self.std_y_tr
        self.X_test = (self.X_test - self.mean_x_tr) / self.std_x_tr
        self.y_test = (self.y_test - self.mean_y_tr) / self.std_y_tr

        rows_per_device = self.X_train.shape[0] // len(device)
        self.X_train_all = []
        self.y_train_all = []
        for ind, g in enumerate(device):
            start = ind * rows_per_device
            # The last device takes everything that is left.
            stop = None if ind == len(device) - 1 else start + rows_per_device
            self.X_train_all.append(self.X_train[start:stop].float().to(g))
            self.y_train_all.append(self.y_train[start:stop].float().to(g))

        self.X_train = self.X_train.float()
        self.y_train = self.y_train.float()
        self.X_test = self.X_test.float()
        self.y_test = self.y_test.float()

        self.mean_x_tr = self.mean_x_tr.float()
        self.std_x_tr = self.std_x_tr.float()
        self.mean_y_tr = self.mean_y_tr.float()
        self.std_y_tr = self.std_y_tr.float()

        print(f'number of training set:{self.X_train.shape[0]}')
        print(f'number of testing set:{self.X_test.shape[0]}')




class HIGGS(Dataset):
    """HIGGS data read from a single CSV file.

    ``DATADIR`` is passed directly to ``pandas.read_csv``, so it is the path
    of the CSV file itself.  Column 0 is the only column not used as a
    feature, so it is taken as the target ``self.y``; columns ``1:`` are the
    features ``self.x`` (both NumPy arrays).

    NOTE(review): ``read_csv`` treats the first row as a header by default;
    if the CSV has no header row, the first sample is consumed as column
    names -- confirm against the data file.

    Args:
        DATADIR: path to the CSV file.
        device: stored on ``self.device``.
        **kwargs: accepted and ignored.
    """

    def __init__(self,
                 DATADIR='',
                 device=[torch.device('cpu')],
                 **kwargs):
        # torch's Dataset takes no constructor arguments; forwarding **kwargs
        # would raise TypeError for any extra keyword.
        super().__init__()
        self.device = device

        df = pd.read_csv(DATADIR)
        df_numpy = df.to_numpy()

        self.x = df_numpy[:, 1:]
        self.y = df_numpy[:, 0]



class dataset_custom(Dataset):
    """Wrap explicit ``X``/``Y`` arrays while borrowing metadata from ``dataset``.

    Copies ``knots_x``/``knots_y`` from ``dataset`` (required).  It also copies
    ``transform_train``/``transform_test`` when ``dataset`` has
    ``transform_train``, and the ``mean_*``/``std_*`` normalization
    statistics when it has ``mean_x_tr``.

    Args:
        X: inputs, stored as ``self.X``.
        Y: targets, stored as ``self.y``.
        dataset: source object for knots and optional metadata.
        **kwargs: accepted and ignored.
    """

    def __init__(self, X, Y, dataset,
                 **kwargs):
        # torch's Dataset takes no constructor arguments; forwarding **kwargs
        # would raise TypeError for any extra keyword.
        super().__init__()
        self.X = X
        self.y = Y
        self.knots_x = dataset.knots_x
        self.knots_y = dataset.knots_y

        if hasattr(dataset, 'transform_train'):
            self.transform_train = dataset.transform_train
            self.transform_test = dataset.transform_test

        if hasattr(dataset, 'mean_x_tr'):
            self.mean_x_tr = dataset.mean_x_tr
            self.std_x_tr = dataset.std_x_tr
            self.mean_y_tr = dataset.mean_y_tr
            self.std_y_tr = dataset.std_y_tr