import numpy as np
from torch.utils.data import Dataset, DataLoader
import torch
import torchvision
import torchvision.transforms as transforms
from torch.nn.functional import one_hot
import os, ipdb
from os.path import join as pjoin
import gzip
import pandas as pd


class imagenetmobilenetDataset(Dataset):
    """ImageNet MobileNet features loaded from pre-saved ``.pt`` tensors.

    Reads ``imagenet_mobilenet_{xtrain,ytrain,xtest,ytest}.pt``. File names
    are appended to ``DATADIR`` by plain string concatenation, so
    ``DATADIR`` must end with a path separator (or be empty).

    Args:
        DATADIR: prefix prepended to every file name.
        parts: unused; kept for a signature shared with the other datasets.
        device: a single device or a sequence of devices. The test set is
            moved to the first one; the training set is sharded across all.
        subsample: if not None, number of training rows to keep (drawn
            uniformly without replacement).
        n_test: unused; kept for a signature shared with the other datasets.
        num_knots: if not None, number of training rows drawn as
            ``knots_x`` / ``knots_y``.
        knot_include: if truthy, the training set is subsampled first and
            the knots are drawn from the kept rows (so knots are also
            training rows). Otherwise the knots are drawn from the full
            training set first and excluded from the subsample.

    Attributes set: ``X_train``/``y_train`` (left where ``torch.load`` put
    them), ``X_test``/``y_test`` (on ``device[0]``), ``X_train_all``/
    ``y_train_all`` (one shard per device) and, when ``num_knots`` is given,
    ``knots_x``/``knots_y``.
    """

    def __init__(self,
                 DATADIR='',
                 parts=4,
                 device=torch.device('cpu'), subsample=None,
                 n_test=100000, num_knots=None, knot_include=0,
                 **kwargs):
        super().__init__(**kwargs)

        print("imagenet mobilenet...")
        # The default is a single device, but the code below indexes
        # device[0] and iterates over all devices: normalize to a list.
        if isinstance(device, (torch.device, str, int)):
            device = [device]
        self.device = device
        test_device = torch.device(device[0])

        self.X_train = torch.load(DATADIR + "imagenet_mobilenet_xtrain.pt")
        self.y_train = torch.load(DATADIR + "imagenet_mobilenet_ytrain.pt")
        self.X_test = torch.load(DATADIR + "imagenet_mobilenet_xtest.pt").to(test_device)
        self.y_test = torch.load(DATADIR + "imagenet_mobilenet_ytest.pt").to(test_device)

        if knot_include:
            # Knots are drawn from (and remain part of) the training set.
            if subsample is not None:
                randomind = np.random.choice(
                    self.y_train.shape[0], size=subsample, replace=False)
                self.X_train = self.X_train[randomind]
                self.y_train = self.y_train[randomind]
            if num_knots is not None:
                randomind_knots = np.random.choice(
                    self.y_train.shape[0], size=num_knots, replace=False)
                self.knots_x = self.X_train[randomind_knots]
                self.knots_y = self.y_train[randomind_knots]
        else:
            n_rows = self.y_train.shape[0]
            if num_knots is not None:
                randomind_knots = np.random.choice(
                    n_rows, size=num_knots, replace=False)
                self.knots_x = self.X_train[randomind_knots]
                self.knots_y = self.y_train[randomind_knots]
            if subsample is not None:
                # Subsample only among rows not picked as knots so the two
                # sets are disjoint; a boolean mask scales far better than
                # Python sets of millions of indices.
                candidates = np.ones(n_rows, dtype=bool)
                if num_knots is not None:
                    candidates[randomind_knots] = False
                randomind = np.random.choice(
                    np.flatnonzero(candidates), size=subsample, replace=False)
                self.X_train = self.X_train[randomind]
                self.y_train = self.y_train[randomind]

        # Shard the training set across devices: each device gets
        # n // len(device) rows and the last one also takes the remainder.
        per_device = self.X_train.shape[0] // len(device)
        self.X_train_all = []
        self.y_train_all = []
        for ind, g in enumerate(device):
            lo = ind * per_device
            hi = None if ind == len(device) - 1 else lo + per_device
            self.X_train_all.append(self.X_train[lo:hi].to(g))
            self.y_train_all.append(self.y_train[lo:hi].to(g))



class HIGGS(Dataset):
    """HIGGS classification data read from a single CSV file.

    Column 0 of the CSV is the label (cast to ``long``); the remaining
    columns are the features (cast to ``float32``). The last ``test_size``
    rows form the test set. Features are standardized with the training
    rows' per-column mean and standard deviation.

    Args:
        DATADIR: path of the CSV file itself (not a directory).
        parts: unused; kept for a signature shared with the other datasets.
        device: a single device or a sequence of devices. The test set is
            moved to the first one; the training set is sharded across all.
        subsample: if not None, number of training rows to keep, drawn
            among the rows not picked as knots.
        n_test: unused; the split size is controlled by ``test_size``.
        num_knots: if not None, number of training rows drawn as
            ``knots_x`` / ``knots_y`` (before subsampling).
        knot_include: unused; kept for a signature shared with the other
            datasets.
        test_size: number of trailing CSV rows used as the test set.

    Raises:
        ValueError: if ``test_size`` is negative or leaves no training row.
    """

    def __init__(self,
                 DATADIR='',
                 parts=4,
                 device=torch.device('cpu'), subsample=None,
                 n_test=100_000, num_knots=None, knot_include=0,
                 test_size=500_000,
                 **kwargs):
        super().__init__(**kwargs)
        # The default is a single device, but the code below indexes
        # device[0] and iterates over all devices: normalize to a list.
        if isinstance(device, (torch.device, str, int)):
            device = [device]
        self.device = device
        test_device = torch.device(device[0])

        print("HIGGS...")
        # NOTE(review): read_csv treats the first line as a header. The raw
        # UCI HIGGS.csv has no header row -- confirm the file used here has
        # one, otherwise the first sample is silently dropped.
        data = pd.read_csv(DATADIR).to_numpy()
        n_rows = data.shape[0]
        n_train = n_rows - test_size
        if test_size < 0 or n_train < 1:
            raise ValueError(
                f"test_size={test_size} must leave at least one training row "
                f"out of {n_rows}")

        features = torch.from_numpy(data[:, 1:]).to(torch.float)
        labels = torch.from_numpy(data[:, 0]).long()

        train_features = features[:n_train]
        mean = torch.mean(train_features, dim=0)
        std = torch.std(train_features, dim=0)
        # Constant columns (std == 0) or a single training row (std is NaN)
        # would turn features into inf/NaN; leave such columns unscaled.
        std = torch.where(std > 0, std, torch.ones_like(std))

        self.X_train = (train_features - mean) / std
        self.y_train = labels[:n_train]
        self.X_test = ((features[n_train:] - mean) / std).to(test_device)
        self.y_test = labels[n_train:].to(test_device)

        n_train_rows = self.y_train.shape[0]
        if num_knots is not None:
            randomind_knots = np.random.choice(
                n_train_rows, size=num_knots, replace=False)
            self.knots_x = self.X_train[randomind_knots]
            self.knots_y = self.y_train[randomind_knots]

        if subsample is not None:
            # Subsample only among rows not picked as knots so the two sets
            # are disjoint; a boolean mask scales far better than Python sets.
            candidates = np.ones(n_train_rows, dtype=bool)
            if num_knots is not None:
                candidates[randomind_knots] = False
            randomind = np.random.choice(
                np.flatnonzero(candidates), size=subsample, replace=False)
            self.X_train = self.X_train[randomind]
            self.y_train = self.y_train[randomind]

        # Shard the training set across devices: each device gets
        # n // len(device) rows and the last one also takes the remainder.
        per_device = self.X_train.shape[0] // len(device)
        self.X_train_all = []
        self.y_train_all = []
        for ind, g in enumerate(device):
            lo = ind * per_device
            hi = None if ind == len(device) - 1 else lo + per_device
            self.X_train_all.append(self.X_train[lo:hi].to(g))
            self.y_train_all.append(self.y_train[lo:hi].to(g))

class mnist8mDataset(Dataset):
    """MNIST-8M (4M-sample subset) loaded from pre-saved ``.pt`` tensors.

    Reads ``{X,y}_{train,test}_mnist4M_int8.pt``. File names are appended
    to ``DATADIR`` by plain string concatenation, so ``DATADIR`` must end
    with a path separator (or be empty). Labels are cast to ``long``.

    Args:
        DATADIR: prefix prepended to every file name.
        parts: unused; kept for a signature shared with the other datasets.
        device: a single device or a sequence of devices. The test set is
            moved to the first one; the training set is sharded across all.
        subsample: if not None, number of training rows to keep, drawn
            among the rows not picked as knots.
        n_test: unused; kept for a signature shared with the other datasets.
        num_knots: if not None, number of training rows drawn as
            ``knots_x`` / ``knots_y`` (before subsampling).
        knot_include: unused; kept for a signature shared with the other
            datasets.
    """

    def __init__(self,
                 DATADIR='',
                 parts=4,
                 device=torch.device('cpu'), subsample=None,
                 n_test=100_000, num_knots=None, knot_include=0,
                 **kwargs):
        super().__init__(**kwargs)

        print("mnist8M--4Msubset...")
        # The default is a single device, but the code below indexes
        # device[0] and iterates over all devices: normalize to a list.
        if isinstance(device, (torch.device, str, int)):
            device = [device]
        self.device = device
        test_device = torch.device(device[0])

        self.X_train = torch.load(DATADIR + 'X_train_mnist4M_int8.pt')
        self.y_train = torch.load(DATADIR + 'y_train_mnist4M_int8.pt').long()
        self.X_test = torch.load(DATADIR + 'X_test_mnist4M_int8.pt').to(test_device)
        self.y_test = torch.load(DATADIR + 'y_test_mnist4M_int8.pt').long().to(test_device)

        n_rows = self.y_train.shape[0]
        if num_knots is not None:
            randomind_knots = np.random.choice(
                n_rows, size=num_knots, replace=False)
            self.knots_x = self.X_train[randomind_knots]
            self.knots_y = self.y_train[randomind_knots]

        if subsample is not None:
            # Subsample only among rows not picked as knots so the two sets
            # are disjoint; a boolean mask scales far better than Python sets
            # of millions of indices.
            candidates = np.ones(n_rows, dtype=bool)
            if num_knots is not None:
                candidates[randomind_knots] = False
            randomind = np.random.choice(
                np.flatnonzero(candidates), size=subsample, replace=False)
            self.X_train = self.X_train[randomind]
            self.y_train = self.y_train[randomind]

        # Shard the training set across devices: each device gets
        # n // len(device) rows and the last one also takes the remainder.
        per_device = self.X_train.shape[0] // len(device)
        self.X_train_all = []
        self.y_train_all = []
        for ind, g in enumerate(device):
            lo = ind * per_device
            hi = None if ind == len(device) - 1 else lo + per_device
            self.X_train_all.append(self.X_train[lo:hi].to(g))
            self.y_train_all.append(self.y_train[lo:hi].to(g))


class Cifar5mmobilenetDataset(Dataset):
    """CIFAR-5m MobileNetV2 features loaded from sharded ``.pt`` files.

    Training shards ``ciar5m_mobilenetv2_100_{feature,y}_train_{i}.pt`` for
    ``i = 0 .. parts`` (inclusive, i.e. ``parts + 1`` shards) are loaded on
    the CPU and concatenated. The test files are loaded on ``device[0]``.
    Labels are cast to ``long``. (The ``ciar5m`` spelling is kept as-is;
    presumably it matches the on-disk file names.)

    Args:
        DATADIR: directory containing the files.
        parts: index of the last training shard to load.
        device: a single device or a sequence of devices. The test set is
            placed on the first one; the training set is sharded across all.
        subsample: if not None, number of training rows to keep.
        n_test: number of test rows to keep, drawn at random without
            replacement (capped at the size of the test set).
        num_knots: if not None, number of training rows drawn as
            ``knots_x`` / ``knots_y``.
        knot_include: if truthy, the training set is subsampled first and
            the knots are drawn from the kept rows (so knots are also
            training rows). Otherwise the knots are drawn from the full
            training set first and excluded from the subsample.
    """

    def __init__(self,
                 DATADIR='',
                 parts=4,
                 device=torch.device('cpu'), subsample=None,
                 n_test=100000, num_knots=None, knot_include=0,
                 **kwargs):
        super().__init__(**kwargs)
        # The default is a single device, but the code below indexes
        # device[0] and iterates over all devices: normalize to a list.
        if isinstance(device, (torch.device, str, int)):
            device = [device]
        self.device = device
        test_device = torch.device(device[0])
        cpu = torch.device('cpu')

        print('Loading cifar5mmobilenet train set...')
        X_parts, y_parts = [], []
        for ind in range(parts + 1):
            print(f'part={ind}')
            X_parts.append(torch.load(
                pjoin(DATADIR, f'ciar5m_mobilenetv2_100_feature_train_{ind}.pt'), cpu))
            y_parts.append(torch.load(
                pjoin(DATADIR, f'ciar5m_mobilenetv2_100_y_train_{ind}.pt'), cpu))
        print("Loading cifar5mmobilenet test set...")
        self.X_test = torch.load(
            pjoin(DATADIR, 'ciar5m_mobilenetv2_100_feature_test.pt'), test_device)
        self.y_test = torch.load(
            pjoin(DATADIR, 'ciar5m_mobilenetv2_100_y_test.pt'), test_device).long()

        self.X_train = torch.cat(X_parts)
        self.y_train = torch.cat(y_parts).long()

        # Random (shuffled) subset of the test set; capped at its size so a
        # large n_test does not make np.random.choice raise.
        n_keep = min(n_test, self.X_test.shape[0])
        test_ind = np.random.choice(self.X_test.shape[0], size=n_keep, replace=False)
        self.X_test = self.X_test[test_ind]
        self.y_test = self.y_test[test_ind]

        if knot_include:
            # Knots are drawn from (and remain part of) the training set.
            if subsample is not None:
                randomind = np.random.choice(
                    self.y_train.shape[0], size=subsample, replace=False)
                self.X_train = self.X_train[randomind]
                self.y_train = self.y_train[randomind]
            if num_knots is not None:
                randomind_knots = np.random.choice(
                    self.y_train.shape[0], size=num_knots, replace=False)
                self.knots_x = self.X_train[randomind_knots]
                self.knots_y = self.y_train[randomind_knots]
        else:
            n_rows = self.y_train.shape[0]
            if num_knots is not None:
                randomind_knots = np.random.choice(
                    n_rows, size=num_knots, replace=False)
                self.knots_x = self.X_train[randomind_knots]
                self.knots_y = self.y_train[randomind_knots]
            if subsample is not None:
                # Subsample only among rows not picked as knots so the two
                # sets are disjoint; a boolean mask scales far better than
                # Python sets of millions of indices.
                candidates = np.ones(n_rows, dtype=bool)
                if num_knots is not None:
                    candidates[randomind_knots] = False
                randomind = np.random.choice(
                    np.flatnonzero(candidates), size=subsample, replace=False)
                self.X_train = self.X_train[randomind]
                self.y_train = self.y_train[randomind]

        # Shard the training set across devices: each device gets
        # n // len(device) rows and the last one also takes the remainder.
        per_device = self.X_train.shape[0] // len(device)
        self.X_train_all = []
        self.y_train_all = []
        for ind, g in enumerate(device):
            lo = ind * per_device
            hi = None if ind == len(device) - 1 else lo + per_device
            self.X_train_all.append(self.X_train[lo:hi].to(g))
            self.y_train_all.append(self.y_train[lo:hi].to(g))


class Cifar5mDataset(Dataset):
    """CIFAR-5m images loaded from per-part ``.pt`` files.

    Training parts ``part{i}_X.pt`` / ``part{i}_y.pt`` for
    ``i = 0 .. parts`` (inclusive, i.e. ``parts + 1`` files each) are loaded
    on the CPU and concatenated. The first ``n_test`` rows of ``part5`` form
    the test set and are placed on ``device[0]``.

    Args:
        DATADIR: directory containing the part files.
        parts: index of the last training part to load.
        device: a single device or a sequence of devices. The test set is
            placed on the first one; the training set is sharded across all.
        subsample: if not None, number of training rows to keep.
        n_test: number of leading ``part5`` rows used as the test set.
        num_knots: if not None, number of training rows drawn as
            ``knots_x`` / ``knots_y``.
        knot_include: if truthy, the training set is subsampled first and
            the knots are drawn from the kept rows (so knots are also
            training rows). Otherwise (the default) the knots are drawn from
            the full training set first and excluded from the subsample.
    """

    def __init__(self,
                 DATADIR='',
                 parts=4,
                 device=torch.device('cpu'), subsample=None,
                 n_test=100_000, num_knots=None, knot_include=0,
                 **kwargs):
        super().__init__(**kwargs)
        # The default is a single device, but the code below indexes
        # device[0] and iterates over all devices: normalize to a list.
        if isinstance(device, (torch.device, str, int)):
            device = [device]
        self.device = device
        test_device = torch.device(device[0])
        cpu = torch.device('cpu')

        print('Loading cifar5m train set...')
        X_parts, y_parts = [], []
        for ind in range(parts + 1):
            print(f'part={ind}')
            X_parts.append(torch.load(pjoin(DATADIR, f'part{ind}_X.pt'), cpu))
            y_parts.append(torch.load(pjoin(DATADIR, f'part{ind}_y.pt'), cpu))
        print("Loading cifar5m test set...")
        # Slice on the CPU and move only the kept rows, instead of loading
        # the whole test part onto device[0]; copy=True makes sure the result
        # does not keep the full part alive as a view.
        self.X_test = torch.load(pjoin(DATADIR, 'part5_X.pt'), cpu)[:n_test].to(test_device, copy=True)
        self.y_test = torch.load(pjoin(DATADIR, 'part5_y.pt'), cpu)[:n_test].to(test_device, copy=True)

        self.X_train = torch.cat(X_parts)
        self.y_train = torch.cat(y_parts)

        if knot_include:
            # Knots are drawn from (and remain part of) the training set.
            if subsample is not None:
                randomind = np.random.choice(
                    self.y_train.shape[0], size=subsample, replace=False)
                self.X_train = self.X_train[randomind]
                self.y_train = self.y_train[randomind]
            if num_knots is not None:
                randomind_knots = np.random.choice(
                    self.y_train.shape[0], size=num_knots, replace=False)
                self.knots_x = self.X_train[randomind_knots]
                self.knots_y = self.y_train[randomind_knots]
        else:
            n_rows = self.y_train.shape[0]
            if num_knots is not None:
                randomind_knots = np.random.choice(
                    n_rows, size=num_knots, replace=False)
                self.knots_x = self.X_train[randomind_knots]
                self.knots_y = self.y_train[randomind_knots]
            if subsample is not None:
                # Subsample only among rows not picked as knots so the two
                # sets are disjoint; a boolean mask scales far better than
                # Python sets of millions of indices.
                candidates = np.ones(n_rows, dtype=bool)
                if num_knots is not None:
                    candidates[randomind_knots] = False
                randomind = np.random.choice(
                    np.flatnonzero(candidates), size=subsample, replace=False)
                self.X_train = self.X_train[randomind]
                self.y_train = self.y_train[randomind]

        # Shard the training set across devices: each device gets
        # n // len(device) rows and the last one also takes the remainder.
        per_device = self.X_train.shape[0] // len(device)
        self.X_train_all = []
        self.y_train_all = []
        for ind, g in enumerate(device):
            lo = ind * per_device
            hi = None if ind == len(device) - 1 else lo + per_device
            self.X_train_all.append(self.X_train[lo:hi].to(g))
            self.y_train_all.append(self.y_train[lo:hi].to(g))


class taxidataset_all(Dataset):
    """Taxi regression data split into disjoint test, knot and train subsets.

    Loads ``x_part{i}_taxi.pt`` / ``y_part{i}_taxi.pt`` for ``i`` in
    ``range(parts)`` (targets cast to ``float64``), shuffles all rows once
    with ``np.random``, then takes consecutive, non-overlapping slices:
    ``n_test`` test rows, ``num_knots`` knot rows (if requested) and
    ``train_size`` training rows. Features and targets are standardized with
    the training subset's mean/std, cast to ``float32`` and moved to
    ``device[0]``; the training set is also sharded across all devices in
    ``X_train_all`` / ``y_train_all``. The full shuffled data stays in
    ``self.X`` / ``self.y``.

    Args:
        DATADIR: directory containing the part files.
        parts: number of part files to load.
        device: a single device or a sequence of devices; defaults to
            ``[torch.device('cpu')]``.
        train_size: number of training rows.
        n_test: number of test rows.
        num_knots: if not None, number of knot rows.

    Raises:
        ValueError: if ``n_test + num_knots + train_size`` exceeds the number
            of loaded rows.
    """

    def __init__(self,
                 DATADIR='',
                 parts=2,
                 device=None, train_size=1000,
                 n_test=500000, num_knots=None,
                 **kwargs):
        super().__init__(**kwargs)
        print('Dataset all...')
        # Default to the CPU; a bare device is wrapped in a list because
        # device[0] and len(device) are used below.
        if device is None:
            device = [torch.device('cpu')]
        elif isinstance(device, (torch.device, str, int)):
            device = [device]
        self.device = device

        print('loding training and testing for taxi data set:')
        X_parts, y_parts = [], []
        for i in range(parts):
            print(f'part:{i}')
            X_parts.append(torch.load(os.path.join(DATADIR, f'x_part{i}_taxi.pt')))
            y_parts.append(torch.load(os.path.join(DATADIR, f'y_part{i}_taxi.pt')).double())
        self.X = torch.cat(X_parts)
        self.y = torch.cat(y_parts)

        total_data = self.X.shape[0]
        n_knots = 0 if num_knots is None else num_knots
        if n_test + n_knots + train_size > total_data:
            raise ValueError(
                f"n_test={n_test} + num_knots={n_knots} + train_size={train_size} "
                f"exceeds the {total_data} available rows")

        indices = np.arange(total_data)
        np.random.shuffle(indices)
        self.X = self.X[indices]
        self.y = self.y[indices]

        # Carve consecutive slices out of the shuffled rows. `start` always
        # points past everything already handed out, so the test, knot and
        # training subsets never overlap (no test-set leakage).
        start = 0
        self.X_test = self.X[start:start + n_test]
        self.y_test = self.y[start:start + n_test]
        start += n_test

        if num_knots is not None:
            self.knots_x = self.X[start:start + num_knots]
            self.knots_y = self.y[start:start + num_knots]
            start += num_knots

        print('indices...')
        self.X_train = self.X[start:start + train_size]
        self.y_train = self.y[start:start + train_size]

        print('mean and std calculation...')
        self.mean_x_tr = torch.mean(self.X_train, dim=0)
        std_x = torch.std(self.X_train, dim=0)
        # Constant feature columns (std == 0) would become NaN/inf; leave
        # them unscaled.
        self.std_x_tr = torch.where(std_x > 0, std_x, torch.ones_like(std_x))
        self.mean_y_tr = torch.mean(self.y_train, dim=0)
        self.std_y_tr = torch.std(self.y_train, dim=0)

        print(f'std of y_train:{self.std_y_tr}')

        if num_knots is not None:
            self.knots_x = ((self.knots_x - self.mean_x_tr) / self.std_x_tr).float().to(device[0])
            self.knots_y = ((self.knots_y - self.mean_y_tr) / self.std_y_tr).float().to(device[0])

        self.X_train = (self.X_train - self.mean_x_tr) / self.std_x_tr
        self.y_train = (self.y_train - self.mean_y_tr) / self.std_y_tr
        self.X_test = (self.X_test - self.mean_x_tr) / self.std_x_tr
        self.y_test = (self.y_test - self.mean_y_tr) / self.std_y_tr

        # Shard the training set across devices: each device gets
        # n // len(device) rows and the last one also takes the remainder.
        per_device = self.X_train.shape[0] // len(device)
        self.X_train_all = []
        self.y_train_all = []
        for ind, g in enumerate(device):
            lo = ind * per_device
            hi = None if ind == len(device) - 1 else lo + per_device
            self.X_train_all.append(self.X_train[lo:hi].float().to(g))
            self.y_train_all.append(self.y_train[lo:hi].float().to(g))

        self.X_train = self.X_train.float().to(device[0])
        self.y_train = self.y_train.float().to(device[0])
        self.X_test = self.X_test.float().to(device[0])
        self.y_test = self.y_test.float().to(device[0])

        self.mean_x_tr = self.mean_x_tr.float()
        self.std_x_tr = self.std_x_tr.float()
        self.mean_y_tr = self.mean_y_tr.float()
        self.std_y_tr = self.std_y_tr.float()

class dataset_custom(Dataset):
    """Wrap explicit inputs ``X`` and targets ``Y`` as an indexable Dataset.

    Knots are taken from the ``knots_x`` / ``knots_y`` arguments when given,
    otherwise copied from ``dataset`` if it has them; the attributes are left
    unset when neither source provides them. When ``dataset`` defines
    ``mean_x_tr``, its ``mean_x_tr`` / ``std_x_tr`` / ``mean_y_tr`` /
    ``std_y_tr`` statistics are copied as well.

    Args:
        X: inputs; indexed along the first dimension.
        Y: targets, aligned with ``X``; stored as ``self.y``.
        knots_x: optional explicit knot inputs.
        knots_y: optional explicit knot targets.
        dataset: optional source object to copy knots/statistics from.
    """

    def __init__(self, X, Y, knots_x=None, knots_y=None, dataset=None,
                 **kwargs):
        super().__init__(**kwargs)
        self.X = X
        self.y = Y
        # Explicit knots take precedence; fall back to the source dataset's.
        if knots_x is None:
            knots_x = getattr(dataset, 'knots_x', None)
        if knots_y is None:
            knots_y = getattr(dataset, 'knots_y', None)
        if knots_x is not None:
            self.knots_x = knots_x
        if knots_y is not None:
            self.knots_y = knots_y
        if hasattr(dataset, 'mean_x_tr'):
            self.mean_x_tr = dataset.mean_x_tr
            self.std_x_tr = dataset.std_x_tr
            self.mean_y_tr = dataset.mean_y_tr
            self.std_y_tr = dataset.std_y_tr

    def __len__(self):
        """Return the number of samples (length of ``X``)."""
        return len(self.X)

    def __getitem__(self, index):
        """Return the ``(input, target)`` pair at ``index``."""
        return self.X[index], self.y[index]
