

import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset
import pickle
import matplotlib.pyplot as plt
import sys
import pandas as pd
import h5py
from collections import OrderedDict
from helpers import set_seed

import os

class data():

    def __init__(self):
        """Create an empty container; the data attributes are attached later by the other methods."""

    def load_data(self, dir):
        print('Loading data from: {}'.format(dir))

        with open(dir, 'rb') as f:
            data = pickle.load(f)
        print('Data loaded')

        return data
    
    def get_k_eigenvectors(self, X, k, make_plot = False):
        """
        Calculates the k eigenvectors of X
        """

        # get the k eigenvectors
        U, S, V = torch.pca_lowrank(X, center= True, q = k)

        # if make plot, plot the variance explained
        if make_plot:
            U, S, V = torch.pca_lowrank(X, center= True, q = X.shape[1])
            eigenvalues = S**2/(X.shape[0] - 1)
            variance_explained = torch.cumsum(eigenvalues, dim = 0)/torch.sum(eigenvalues)
            plt.plot(list(range(1 , k  + 1)), variance_explained[:k])
            plt.xlabel('Number of components')
            plt.ylabel('Cumulative explained variance')
            plt.rcParams['savefig.dpi'] = 500
            plt.show()
            sys.exit()

        return V[:, :k]
    
    def demean_X(self, include_test = False, reset_mean = True):
        """
        Demean the data
        """

        # estimate the mean based on the training data
        if reset_mean:
            self.X_train_mean = self.X_train.mean(dim = 0)  

        # demean the data based on the mean estimated from the training data 
        self.X_train = self.X_train - self.X_train_mean
        self.X_val = self.X_val - self.X_train_mean

        # demean the test data based on the mean estimated frodat m the training data
        if include_test:
            self.X_test = self.X_test - self.X_train_mean 

    
    def unit_stdev_X(self, include_test = False, reset_std = True):
        """
        Set the standard deviation of each column in X to 1
        """
        
        # estimate the standard deviation based on the training data
        if reset_std:
            self.X_train_std = self.X_train.std(dim = 0)
        
        # standardize the data
        self.X_train = self.X_train/self.X_train_std
        self.X_val = self.X_val/self.X_train_std

        # standardize the test data
        if include_test:
            self.X_test = self.X_test/self.X_train_std
    
    def demean_scale_X(self, include_test=True):
        self.demean_X(include_test=include_test)
        self.unit_stdev_X(include_test=include_test)

        return self.X_train, self.y_train, self.X_val, self.y_val, self.X_test, self.y_test
    


    def transform_data_to_k_components(self, k, include_test = False, reset_V_k = True):
        """
        Turn the data into k components
        """

        # get the k eigenvectors
        if reset_V_k:
           self.V_k_train = self.get_k_eigenvectors(self.X_train, k)

        # transform the data to the k components
        self.X_train = torch.matmul(self.X_train, self.V_k_train)
        self.X_val = torch.matmul(self.X_val, self.V_k_train)

        # transform the test data to the k components
        if include_test:
            self.X_test = torch.matmul(self.X_test, self.V_k_train)

    
    def get_p_dict(self, p=None, g=None):

        if p == 'train':
            g = self.g_train
        elif p == 'val':
            g = self.g_val
        elif p == 'test':
            g = self.g_test
        elif g is None and p is None:
            raise ValueError('Please provide either p or g')


        # select all the groups
        groups = [int(group.item()) for group in torch.unique(g)]
        p_dict = dict.fromkeys(groups, 0)
        n = len(g)

        for group in groups:
            n_g = (g==group).sum()
            p_g = n_g/n
            p_dict[group] = torch.as_tensor(p_g, dtype=torch.float32)

        return p_dict
    
    def set_data_attributes(self, X_train,  y_train, X_val,  y_val, X_test, y_test, device, add_dim_y=True, g_train=None, g_val=None, g_test=None):

        # set the attributes - train
        self.X_train = X_train.to(device)
        self.y_train = y_train.to(device)

        # set the attributes - val
        self.X_val = X_val.to(device)
        self.y_val = y_val.to(device)

        # set the attributes - test
        self.X_test = X_test.to(device)   
        self.y_test = y_test.to(device)

        # add dimension to y if needed
        if add_dim_y:
            self.y_train = self.y_train.unsqueeze(-1)
            self.y_val = self.y_val.unsqueeze(-1)
            self.y_test = self.y_test.unsqueeze(-1)

        # set the group attribute
        if g_train is not None and g_val is not None and g_test is not None:
            self.g_train = g_train.to(device)
            self.g_val = g_val.to(device)
            self.g_test = g_test.to(device)
    
    def split_datasets(total_length, train_ratio=0.8, seed=None):
        if seed is not None:
            np.random.seed(seed)
        
        all_indices = np.arange(total_length)
        np.random.shuffle(all_indices)
        
        split_point = int(total_length * train_ratio)
        train_indices = all_indices[:split_point]
        val_indices = all_indices[split_point:]
        
        return train_indices, val_indices
        
    

    def create_loaders(self, batch_size, workers, shuffle=True, include_weights=False, train_weights = None, val_weights = None, pin_memory=True, h5_file_path=None, x_key_train=None, y_key_train=None, x_key_val=None, y_key_val=None, device='cpu', include_test=False, x_key_test=None, y_key_test=None):
        """ 
        Create train and validation loaders
        """

        # create the datasets
        if h5_file_path is not None:
            train = HDF5Dataset(h5_file_path, x_key_train, y_key_train, device, batch_size)
            val = HDF5Dataset(h5_file_path, x_key_val, y_key_val, device, batch_size)
            batch_size = 1

            if include_test:
                test = HDF5Dataset(h5_file_path, x_key_test, y_key_test, device)



        else:
            train = TensorDataset(self.X_train, self.y_train)
            val = TensorDataset(self.X_val, self.y_val)

            if include_test:
                test = TensorDataset(self.X_test, self.y_test)

        
        
        # if include weights,add weights to the sampler
        if include_weights:
            train_sampler = torch.utils.data.WeightedRandomSampler(train_weights, num_samples = len(train_weights))
            val_sampler = torch.utils.data.WeightedRandomSampler(val_weights, num_samples = len(val_weights))
            shuffle = False
        else:
            train_sampler = val_sampler = None

        # create the loaders
        self.train_loader = DataLoader(train, 
                                        batch_size=batch_size,
                                         shuffle=shuffle, 
                                         sampler =train_sampler,
                                         num_workers=workers,
                                         pin_memory=pin_memory)
        self.val_loader = DataLoader(val, 
                                        batch_size=batch_size, 
                                        shuffle=shuffle, 
                                        sampler = val_sampler,
                                        num_workers=workers,
                                        pin_memory=pin_memory)
        
        
            
        self.dict_loaders = {'train': self.train_loader, 'val': self.val_loader}

        print('Created a loader with batch size {}, of length {}'.format(batch_size, len(self.train_loader)))

        # include the test loader
        if include_test:
            self.test_loader = DataLoader(test, 
                                        batch_size=batch_size, 
                                        shuffle=shuffle, 
                                        num_workers=workers,
                                        pin_memory=pin_memory)
            self.dict_loaders['test'] = self.test_loader



   





class HDF5Dataset(TensorDataset):
    def __init__(self, file_path, x_key, y_key, device, batch_size=128, indices=None):
        """
        Open the HDF5 file at `file_path` read-only and expose the datasets
        stored under `x_key` and `y_key` in batches of `batch_size`.

        `indices` selects which rows are served; when it is None all rows
        0..len(y)-1 are used. The file stays open for the lifetime of the
        object and is closed in __del__.
        """
        self.file = h5py.File(file_path, 'r')
        self.x_key, self.y_key = x_key, y_key
        self.X, self.y = self.file[x_key], self.file[y_key]

        self.dataset_length = len(self.y)
        self.batch_size = batch_size
        self.device = device

        # default: serve every row of the file
        self.indices = np.arange(self.dataset_length) if indices is None else indices

        # number of complete batches
        self.num_batches = len(self.indices) // self.batch_size
        

    def __len__(self):
        return self.num_batches + 1

    def __getitem__(self, idx):
        start_idx = idx * self.batch_size
        end_idx = start_idx + self.batch_size

        if end_idx > self.dataset_length:
            end_idx = self.dataset_length

        batch_indices = self.indices[start_idx:end_idx]
        
        X_batch = torch.from_numpy(self.X[batch_indices]).squeeze(0).to(self.device)
        y_batch = torch.from_numpy(self.y[batch_indices]).to(self.device)
        return X_batch, y_batch
    
    def __del__(self):
        self.file.close()

    def update_indices(self, new_indices):
        self.indices = new_indices
        self.num_batches = len(self.indices) // self.batch_size

       

    

    def load_embeddings_folder(self, model_folder):

        data = torch.load(model_folder + '/data.pt')

        X_train = data['X_train']
        X_val = data['X_val']
        X_test = data['X_test']

        y_train = data['y_train']
        y_val = data['y_val']
        y_test = data['y_test']

        g_train = data['g_train']
        g_val = data['g_val']
        g_test = data['g_test']

        return X_train, X_val, X_test, y_train, y_val, y_test, g_train, g_val, g_test
    

    

    

   


class CelebA(data):

    def __init__(self):
        super().__init__()

    def load_y_c(self, file_path, include_test=False):
        """
        Read the y and c arrays from the HDF5 file at `file_path`.

        Returns (y_train, c_train, y_val, c_val), extended with
        (y_test, c_test) when include_test is True.
        """
        keys = ['y_train', 'c_train', 'y_val', 'c_val']
        if include_test:
            keys += ['y_test', 'c_test']

        # [:] copies each dataset into memory before the file is closed
        with h5py.File(file_path, 'r') as f:
            arrays = tuple(f[key][:] for key in keys)

        return arrays

    def map_to_g(self, y, c):
        if y==0 and c==0:
            g=1
        elif y==0 and c==1:
            g=2
        elif y==1 and c==0:
            g=3
        elif y==1 and c==1:
            g=4
        return g
    

    
    def create_g(self,y, c):

        # create a variable which indicates the combinations of y, c
        # if y ==0, c==0, g=1, y==0, c==1, g=2, etc.         
        g = torch.zeros(len(y))

        for i in range(len(y)):
            y_i, c_i = y[i], c[i]
            g_i = self.map_to_g(y_i, c_i)
            g[i]=g_i
        
        return g
    
    def load_embeddings_test(self, model_folder):

        data = torch.load(model_folder + '/data.pt')

        # get X
        X_test = data['X_test']

        # get y
        y_test = data['y_test']

        # get the group variables
        g_test = data['g_test']

        return X_test, y_test, g_test
    
    def load_embeddings_folder(self, model_folder, include_test=False):

        data = torch.load(model_folder + '/data.pt')

        # get X
        X_train = data['X_train']
        X_val = data['X_val']
       
        # get y
        y_train = data['y_train']
        y_val = data['y_val']

        # get the group variables
        g_train = data['g_train']
        g_val = data['g_val']

        # load the test data if needed
        if include_test:
            X_test = data['X_test']
            y_test = data['y_test']
            g_test = data['g_test']

            return X_train, X_val, X_test, y_train, y_val, y_test, g_train, g_val, g_test
        else:
            return X_train, X_val, y_train, y_val, g_train, g_val


      
        
    

    def return_CelebA_pred(self, seed):

        # load the pred
        pred_folder = 'embeddings/snellius_models/CelebA/JTT/CelebA_model_seed_{}'.format(seed)
        pred = torch.load(pred_folder + '/pred.pt')

        pred_train = pred['train']
        pred_val = pred['val']
        pred_test = pred['test']


        return pred_train, pred_val, pred_test


    def return_CelebA_data(self, seed, batch_size, early_stopping, augmentation, device , train_val_split='original',):
        """
        Load the stored CelebA embeddings for one model configuration,
        standardise the features and return
        (X_train, y_train, X_val, y_val, X_test, y_test, g_train, g_val, g_test).

        The folder name is built from early_stopping, batch_size, augmentation,
        train_val_split (appended, with the dots removed, unless it is
        'original') and seed. The y arrays are converted with torch.from_numpy
        and get a second dimension; all splits are then stored on the instance
        on the CPU and demeaned/scaled with the training statistics. The g
        tensors are returned as loaded.
        """

        # pieces of the folder name
        ES_str = 'True' if early_stopping else 'False'
        augmentation_str = 'random_crop_True_random_flip_True' if augmentation else 'random_crop_False_random_flip_False'

        embedding_str = 'param_set_ES_{}_BS_{}_data_celebA_Blond_Hair_Female_{}'.format(ES_str, batch_size, augmentation_str)

        # a non-default train/val split is part of the folder name
        if train_val_split != 'original':
            embedding_str += '_train_val_split_'+ str(train_val_split).replace('.', '')

        embedding_folder = 'embeddings/snellius_models/CelebA/{}/CelebA_model_seed_{}'.format(embedding_str, seed)

        # load everything, including the test split
        loaded = self.load_embeddings_folder(embedding_folder, include_test=True)
        X_train, X_val, X_test, y_train, y_val, y_test, g_train, g_val, g_test = loaded

        # numpy -> torch on `device`, with an added second dimension
        y_train, y_val, y_test = (
            torch.from_numpy(labels).to(device).unsqueeze(1)
            for labels in (y_train, y_val, y_test)
        )

        # store on the instance (on the CPU, without adding another y dimension)
        self.set_data_attributes(X_train, y_train, X_val, y_val, X_test, y_test, 'cpu', False, g_train, g_val, g_test)

        # demean and scale the data
        X_train, y_train, X_val, y_val, X_test, y_test = self.demean_scale_X(include_test=True)

        return X_train, y_train, X_val, y_val, X_test, y_test, g_train, g_val, g_test

        
class multiNLI(data):
    def __init__(self):
        super().__init__()
        self.n_class_balanced_train=50000
        self.n_class_balanced_val=20000
        self.n_class_balanced_test=30000

    def return_multiNLI_pred(self, seed):

        pred_folder = 'embeddings/snellius_models/multiNLI/JTT/param_set_ES_True_BS_16_WS_1e-4_multiNLI/multiNLI_model_seed_{}'.format(seed)
        pred = torch.load(pred_folder + '/pred.pt')


        pred_train = pred['train']
        pred_val = pred['val']
        pred_test = pred['test']

        return pred_train, pred_val, pred_test


    def create_loaders(self, batch_size, shuffle, workers=0, pin_memory=True, include_test=False):
          
            train = TensorDataset(self.X_train, self.y_train)
            val = TensorDataset(self.X_val, self.y_val)
            if include_test:
                test = TensorDataset(self.X_test, self.y_test)


           
            # create the loaders
            self.train_loader = DataLoader(train, 
                                            batch_size=batch_size,
                                            shuffle=shuffle, 
                                            sampler =None,
                                            num_workers=workers,
                                            pin_memory=pin_memory)
            self.val_loader = DataLoader(val, 
                                            batch_size=batch_size, 
                                            shuffle=shuffle, 
                                            sampler = None,
                                            num_workers=workers,
                                            pin_memory=pin_memory)
        

            self.dict_loaders = {'train': self.train_loader, 'val': self.val_loader}

            print(f'Created a loader with batch size {batch_size}, of length {len(self.train_loader)}')

            if include_test:
                self.test_loader = DataLoader(test, 
                                            batch_size=batch_size, 
                                            shuffle=shuffle, 
                                            num_workers=workers,
                                            pin_memory=pin_memory)
                self.dict_loaders['test'] = self.test_loader

    
    def load_embeddings(self, seed, batch_size_model, early_stopping_model, device):

        # folder
        #folder = 'embeddings/snellius_models/multiNLI/standard_param_balanced_fin/param_set_ES_False_BS_16_multiNLI/multiNLI_model_seed_1'
        folder = 'embeddings/snellius_models/multiNLI/standard_param_balanced_final/param_set_ES_{}_BS_{}_multiNLI/multiNLI_model_seed_{}'.format(early_stopping_model, batch_size_model, seed)

        # load the data
        data = torch.load(folder + '/data.pt')

        # get the data
        X_train, X_val, X_test, y_train, y_val, y_test, g_train, g_val, g_test = data['X_train'], data['X_val'], data['X_test'], data['y_train'], data['y_val'], data['y_test'], data['g_train'], data['g_val'], data['g_test']

        # set to device
        X_train, X_val, X_test = X_train.to(device), X_val.to(device), X_test.to(device)
        y_train, y_val, y_test = y_train.to(device), y_val.to(device), y_test.to(device)
        g_train, g_val, g_test = g_train.to(device), g_val.to(device), g_test.to(device)

        # unsqueeze the y's
        y_train, y_val, y_test = y_train.unsqueeze(1), y_val.unsqueeze(1), y_test.unsqueeze(1)

        # set the data attributes
        self.set_data_attributes(X_train, y_train, X_val, y_val, X_test, y_test, device, False, g_train, g_val, g_test)

        # demean and scale the data
        X_train, y_train, X_val, y_val, X_test, y_test = self.demean_scale_X( include_test=True)

        return X_train, y_train, X_val, y_val, X_test, y_test, g_train, g_val,g_test 

        
    def create_class_balanced_set(self, X, y,g, n, seed):
        """
        Draw a random subset with at most n//2 observations of class 0 and at
        most n//2 observations of class 1.

        The seed is set through set_seed before sampling. Returns
        (X_sample, y_sample, g_sample); the class 0 observations come first.
        A class with fewer than n//2 members contributes all of its members.
        """
        set_seed(seed)

        per_class = n//2

        # for each class: find its members, then keep a random selection of them
        picked = []
        for label in (0, 1):
            members = torch.where(y == label)[0]
            selection = torch.randperm(len(members))[:per_class]
            picked.append(members[selection])

        i_sample = torch.cat(picked)

        print('i of sample are: ', i_sample)

        return X[i_sample,:], y[i_sample], g[i_sample]
    



    def load_tokens(self, directory, turn_binary=True, class_balanced=True):
        """
        Load the original (tokenised) dataset and build the train/val/test tensors.

        Reads `<directory>/metadata_random.csv` and the three cached feature
        files in `<directory>/multiNLI_bert_features`. The following is stored
        on the instance:

        * features_array, all_input_ids, all_input_masks, all_segment_ids,
          all_label_ids and y for the complete data;
        * per split (train / val / test): <split>_ids, input_ids_<split>,
          input_masks_<split>, segment_ids_<split>, X_<split> (the three token
          tensors stacked along a new last dimension), y_<split>, g_<split>;
        * c (the 'sentence2_has_negation' column) and g (the group label).

        Parameters
        ----------
        directory : folder that holds the metadata file and the feature folder
        turn_binary : if True, y is 1 where the label id equals 0 and 0
            otherwise; if False, y is the one-hot encoding of the 3 label ids
        class_balanced : if True, every split is replaced by a class balanced
            sample (see create_class_balanced_set), always drawn with seed 0

        NOTE(review): with turn_binary=False y has 3 columns, while create_g is
        handed y and c with one row per observation; presumably the group
        construction is only meant for turn_binary=True - confirm.
        """

        # set the directory of the features
        directory_features = directory+'/multiNLI_bert_features'

        # load the metadata
        metadata_df = pd.read_csv(directory +  "/metadata_random.csv")

        # load the features of all three files into one list
        self.features_array = []
        feature_files = (
            'cached_train_bert-base-uncased_128_mnli',
            'cached_dev_bert-base-uncased_128_mnli',
            'cached_dev_bert-base-uncased_128_mnli-mm',
        )
        for feature_file in feature_files:
            feature_path = os.path.join(directory_features, feature_file)
            print(feature_path)
            self.features_array += torch.load(feature_path)

        # get the input ids, input masks, segment ids, label ids
        self.all_input_ids = torch.tensor([f.input_ids for f in self.features_array], dtype=torch.long)
        self.all_input_masks = torch.tensor([f.input_mask for f in self.features_array], dtype=torch.long)
        self.all_segment_ids = torch.tensor([f.segment_ids for f in self.features_array], dtype=torch.long)
        self.all_label_ids = torch.tensor([f.label_id for f in self.features_array], dtype=torch.long)

        # binary target (label id 0 vs. rest) or one-hot encoding of the 3 labels
        if turn_binary:
            self.y = (self.all_label_ids == 0).squeeze(-1).to(torch.float16)
        else:
            self.y = torch.nn.functional.one_hot(self.all_label_ids, num_classes=3).to(torch.float16)

        # the 'split' column of the metadata assigns every row to train (0), val (1) or test (2)
        split_column = metadata_df['split']
        for split_name, split_code in (('train', 0), ('val', 1), ('test', 2)):
            # build the row ids directly as integers; a detour via a float32
            # tensor cannot represent every integer above 2**24 exactly
            ids = torch.tensor(metadata_df[split_column == split_code].index.values, dtype=torch.long)
            setattr(self, split_name + '_ids', ids)

            # token level inputs of this split
            input_ids = self.all_input_ids[ids, :]
            input_masks = self.all_input_masks[ids, :]
            segment_ids = self.all_segment_ids[ids, :]
            setattr(self, 'input_ids_' + split_name, input_ids)
            setattr(self, 'input_masks_' + split_name, input_masks)
            setattr(self, 'segment_ids_' + split_name, segment_ids)

            # combine the input ids, input masks, segment ids
            setattr(self, 'X_' + split_name, torch.stack((input_ids, input_masks, segment_ids), dim=2))

            # labels of this split
            setattr(self, 'y_' + split_name, self.y[ids])

        # get the group label
        self.c = torch.tensor(metadata_df['sentence2_has_negation'].values).long()
        self.g = self.create_g(self.y, self.c)
        self.g_train = self.g[self.train_ids]
        self.g_val = self.g[self.val_ids]
        self.g_test = self.g[self.test_ids]

        if class_balanced:
            self.X_train, self.y_train, self.g_train = self.create_class_balanced_set(self.X_train, self.y_train, self.g_train, self.n_class_balanced_train, 0)
            self.X_val, self.y_val, self.g_val = self.create_class_balanced_set(self.X_val, self.y_val, self.g_val, self.n_class_balanced_val, 0)
            self.X_test, self.y_test, self.g_test = self.create_class_balanced_set(self.X_test, self.y_test, self.g_test, self.n_class_balanced_test, 0)
        print('group division in train: {}'.format(np.unique(self.g_train, return_counts=True)))
        

    def map_to_g(self, y, c):
            if y==0 and c==0:
                g=1
            elif y==0 and c==1:
                g=2
            elif y==1 and c==0:
                g=3
            elif y==1 and c==1:
                g=4
            return g
    
    def create_g(self,y, c):

        # create a variable which indicates the combinations of y, c
        # if y ==0, c==0, g=1, y==0, c==1, g=2, etc.         
        g = torch.zeros(len(y))

        for i in range(len(y)):
            y_i, c_i = y[i], c[i]
            g_i = self.map_to_g(y_i, c_i)
            g[i]=g_i
        
        return g
            

       


class WB(data):

    def __init__(self):
        super().__init__()


    def return_WB_model(self, seed, batch_size, early_stopping, augmentation, device):

        # load the embeddings
        if early_stopping:
            ES_str = 'True'
        else:
            ES_str = 'False'
        
        if augmentation:
            augmentation_str = 'jitter_True_random_crop_True_random_flip_True'
        else:
            augmentation_str = 'jitter_False_random_crop_False_random_flip_False'
        
        # set the model file
        model_file = 'models/snellius_models/WB/param_set_ES_{}_BS_{}_data_WB_95_{}/CUB_model_WB_95_seed_{}.pt'.format(ES_str, batch_size, augmentation_str, seed)

        # load the model
        model = torch.load(model_file, map_location=device)

        return model
    
    def return_WB_pred(self, seed, batch_size, early_stopping, augmentation,JTT_folder=False, weight_decay=1):

        # load the embeddings
        if early_stopping:
            ES_str = 'True'
        else:
            ES_str = 'False'
        
        if augmentation:
            augmentation_str = 'jitter_True_random_crop_True_random_flip_True'
        else:
            augmentation_str = 'jitter_False_random_crop_False_random_flip_False'
        
       # load the pred
        if JTT_folder:
            pred_folder = 'embeddings/snellius_models/WB/JTT/param_set_ES_{}_BS_{}_WD_{}_data_WB_95_{}/WB_model_seed_{}'.format(ES_str, batch_size, str(weight_decay), augmentation_str, seed)
        else:
            pred_folder = 'embeddings/snellius_models/WB/param_set_ES_{}_BS_{}_data_WB_95_{}/CUB_model_WB_95_seed_{}'.format(ES_str, batch_size, augmentation_str, seed)
        print('Loading pred from:', pred_folder)

        pred = torch.load(pred_folder + '/pred.pt')

        pred_train = pred['train']
        pred_val = pred['val']
        pred_test = pred['test']


        return pred_train, pred_val, pred_test



    def return_WB_embeddings(self, seed, batch_size, early_stopping, augmentation, JTT_folder=False, weight_decay=1):


        # load the embeddings
        if early_stopping:
            ES_str = 'True'
        else:
            ES_str = 'False'
        
        if augmentation:
            augmentation_str = 'jitter_True_random_crop_True_random_flip_True'
        else:
            augmentation_str = 'jitter_False_random_crop_False_random_flip_False'
        
        # load the embeddings
        if JTT_folder:
            embedding_folder = 'embeddings/snellius_models/WB/JTT/param_set_ES_{}_BS_{}_WD_{}_data_WB_95_{}/WB_model_seed_{}'.format(ES_str, batch_size, str(weight_decay), augmentation_str, seed)
        else:
            embedding_folder = 'embeddings/snellius_models/WB/param_set_ES_{}_BS_{}_data_WB_95_{}/CUB_model_WB_95_seed_{}'.format(ES_str, batch_size, augmentation_str, seed)
        print('Loading from:', embedding_folder)
        X_train, X_val, X_test, y_train, y_val, y_test, g_train, g_val, g_test = self.load_embeddings_folder(embedding_folder)
        X_train, X_val, y_train, y_val = X_train.cpu(), X_val.cpu(), y_train.cpu(), y_val.cpu()

        # set the data attributes
        self.set_data_attributes(X_train, y_train, X_val, y_val, X_test, y_test, 'cpu', False, g_train, g_val, g_test)

        # demean and scale the data
        X_train, y_train, X_val, y_val, X_test, y_test = self.demean_scale_X( include_test=True)

        return  X_train, y_train, X_val, y_val, X_test, y_test, g_train, g_val, g_test

    def load_embeddings_folder(self, model_folder):

        data = torch.load(model_folder + '/data.pt')

        X_train = data['X_train']
        X_val = data['X_val']
        X_test = data['X_test']

        y_train = data['y_train']
        y_val = data['y_val']
        y_test = data['y_test']

        g_train = data['g_train']
        g_val = data['g_val']
        g_test = data['g_test']

        return X_train, X_val, X_test, y_train, y_val, y_test, g_train, g_val, g_test
    
    def load_data(self, file_path):

        # load the data
        data = super().load_data(file_path)

        return data



    
    def map_to_g(self, y, c):
        if y==0 and c==0:
            g=1
        elif y==0 and c==1:
            g=2
        elif y==1 and c==0:
            g=3
        elif y==1 and c==1:
            g=4
        return g
    

    
    def create_g(self,y, c):

        # create a variable which indicates the combinations of y, c
        # if y ==0, c==0, g=1, y==0, c==1, g=2, etc.         
        g = torch.zeros(len(y))

        for i in range(len(y)):
            y_i, c_i = y[i], c[i]
            g_i = self.map_to_g(y_i, c_i)
            g[i]=g_i
        
        return g


    




class Toy(data):


    def __init__(self,  n, p_1, beta_1, beta_0, sigma_1, sigma_0, mu, gamma, a_0, a_1, d=1, add_third_group=False, p_2=None, beta_2=None, a_2=None, sigma_2=None):
        super().__init__()


        # set the parameters of the Toy data
        self.n = n
        self.p_1 = p_1
        self.beta_1 = beta_1
        self.beta_0 = beta_0
        self.sigma_1 = sigma_1
        self.sigma_0 = sigma_0
        self.mu = mu
        self.gamma = gamma
        self.a_0 = a_0
        self.a_1 = a_1
        self.d = d
        self.add_third_group = add_third_group

        if add_third_group:
            self.beta_2 = beta_2
            self.a_2 = a_2
            self.p_2 = p_2
            self.sigma_2 = sigma_2
            self.n_0 = int(np.round(n * (1 - p_1 - p_2)))
            self.n_1 = int(np.round(n * p_1))
            self.n_2 = n - self.n_0 - self.n_1

        else:
            self.n_0 = int(np.round(n * (1 - p_1)))
            self.n_1 = n - self.n_0

        
    def gen_for_sim(self, n_sims):

        # create matrix of size n_sims x n for x
        X_sims = np.zeros((n_sims, self.n, self.d))

        # create matrix of size n_sims x n for y
        y_sims = np.zeros((n_sims, self.n))

        # create matrix of size n_sims x n for g
        g_sims = np.zeros((n_sims, self.n))

        # generate for each sim
        for i in range(n_sims):
            X_sims[i], y_sims[i], g_sims[i] = self.dgp()

        return X_sims, y_sims, g_sims

    def dgp(self, uv=True, add_index_group=True, logistic=False):

        if uv:
            X, y, g = self.dgp_uv(add_index_group=add_index_group, logistic=logistic)
            return X, y, g
        else:
            X, y, g = self.dgp_mv(add_index_group=add_index_group, logistic=logistic)
        return X, y, g

    def dgp_mv(self, add_index_group=True, logistic=False):
        

        # generate multivariate x with mean mu and covariance matrix based on gamma
        Sigma = np.diag(np.ones(self.d)) * self.gamma
        X = np.random.multivariate_normal(self.mu, Sigma, self.n)

        # create n obs. for g
        g = np.zeros(self.n)

        # set n_1 random obs. to 1
        i_1 = np.random.choice(self.n, self.n_1, replace=False)
        g[i_1] = 1

        if self.add_third_group:
            # get the indices that are 0
            i_0 = np.where(g == 0)[0]

            # from these indices, set n_2 random obs. to 2
            i_2 = np.random.choice(i_0, self.n_2, replace=False)
            g[i_2] = 2

        # create empty y
        y = np.zeros(self.n)

        # if g = 1, generate y with mean beta_1*x + a and variance sigma_1
        X_1 = X[g == 1]
        eps_1 = np.random.normal(0, np.sqrt(self.sigma_1), self.n_1).reshape(-1, 1)
        B_1 = np.zeros((self.d, 1))
        B_1[0] = self.beta_1


        # if logistic, first generate the p(y=1)
        if logistic:
            p_1 = 1/(1 + np.exp(-np.matmul(X_1, B_1) - self.a_1))
            y_1 = np.random.binomial(1, p_1, size=(self.n_1, 1))
    
        else:
            y_1 = np.matmul(X_1, B_1) + self.a_1 + eps_1
        
        y[g == 1] = y_1.squeeze(-1)

        # if g = 0, generate y with mean beta_0*x and variance sigma_2
        X_0 = X[g == 0]
        eps_0 = np.random.normal(0, np.sqrt(self.sigma_0), self.n_0).reshape(-1, 1)
        B_0 =  np.zeros((self.d, 1))
        B_0[0] = self.beta_0

        # if logistic, first generate the p(y=1)
        if logistic:
            p_0 = 1/(1 + np.exp(-np.matmul(X_0, B_0) - self.a_0))
            y_0 = np.random.binomial(1, p_0, size=(self.n_0, 1))
        else:
            y_0 = np.matmul(X_0, B_0) + self.a_0 + eps_0
        y[g == 0] = y_0.squeeze(-1)

        if self.add_third_group:
            # if g = 2, generate y with mean beta_2*x and variance sigma_2
            X_2 = X[g == 2]
            eps_2 = np.random.normal(0, np.sqrt(self.sigma_2), self.n_2).reshape(-1, 1)
            B_2 = np.zeros((self.d, 1))
            B_2[0] = self.beta_2
            if logistic:
                p_2 = 1/(1 + np.exp(-np.matmul(X_2, B_2) - self.a_2))
                y_2 = np.random.binomial(1, p_2, size=(self.n_2, 1))
            else:
                y_2 = np.matmul(X_2, B_2) + self.a_2 + eps_2
            y[g == 2] = y_2.squeeze(-1)

        y, g = np.expand_dims(y, -1), np.expand_dims(g, -1)

        # save the parameters
        self.B_1 = B_1
        self.B_0 = B_0

        if add_index_group:
            g += 1

        return X, y, g


            

    def dgp_uv(self):
        """
        Generate Toy data based on parameters
        """

        # generate an x with mean mu and variance gamma
        x = np.random.normal(self.mu, np.sqrt(self.gamma), self.n)

        # create n obs. for g
        g = np.zeros(self.n)

        # set n_1 random obs. to 1
        i_1 = np.random.choice(self.n, self.n_1, replace=False)
        g[i_1] = 1

        # set n_2 random obs. to 2
        if self.add_third_group:
            i_2 = np.random.choice(self.n, self.n_2, replace=False)
            g[i_2] = 2
       
        # create empty y
        y = np.zeros(self.n)

        # if g = 1, generate y with mean beta_1*x + a and variance sigma_1
        x_1 = x[g == 1]
        n_1= len(x_1)
        eps_1 = np.random.normal(0, np.sqrt(self.sigma_1), n_1)
        y_1 = (x_1 * self.beta_1) + self.a_1 + eps_1
        y[g == 1] = y_1

        # if g = 0, generate y with mean beta_0*x and variance sigma_2
        x_0 = x[g == 0]
        n_0 = len(x_0)
        eps_0 = np.random.normal(0, np.sqrt(self.sigma_0), n_0)
        y_0 = (x_0 * self.beta_0) + self.a_0 + eps_0
        y[g == 0] = y_0

        if self.add_third_group:
            # if g = 2, generate y with mean beta_2*x and variance sigma_2
            x_2 = x[g == 2]
            n_2 = len(x_2)
            eps_2 = np.random.normal(0, np.sqrt(self.sigma_0), n_2)
            y_2 = (x_2 * self.beta_2) + self.a_2 + eps_2
            y[g == 2] = y_2

        # expand the dim of x, y, g
        x, y, g=  np.expand_dims(x, -1), np.expand_dims(y, -1), np.expand_dims(g, -1)
       
        return x, y, g