import numpy as np
import torch
import os
import copy
from scipy import fft
import h5py
import hashlib
from torch.utils.data import Dataset, DataLoader

class mydataset(object):
    """Base container pairing inputs ``x`` with class, domain and pseudo labels.

    Subclasses fill ``x``, ``labels``, ``dlabels``, ``pclabels`` and ``pdlabels``;
    indexing returns ``(x, class, domain, pseudo-class, pseudo-domain, index)`` with the
    optional ``transform`` / ``target_transform`` applied.
    """

    # (label_type accepted by the setters, attribute that stores those labels)
    _LABEL_FIELDS = (
        ('pclabel', 'pclabels'),
        ('pdlabel', 'pdlabels'),
        ('domain_label', 'dlabels'),
        ('class_label', 'labels'),
    )

    def __init__(self, args):
        self.args = args
        self.x = None
        self.labels = None
        self.dlabels = None
        self.pclabels = None
        self.pdlabels = None
        self.task = None
        self.dataset = None
        self.transform = None
        self.target_transform = None
        self.loader = None

    def _field_for(self, label_type):
        """Return the attribute name for ``label_type``, or None if it is not recognised."""
        for name, field in self._LABEL_FIELDS:
            if label_type == name:
                return field
        return None

    def set_labels(self, tlabels=None, label_type='domain_label'):
        """Replace a whole label array; ``tlabels`` must have one entry per sample.

        Unrecognised ``label_type`` values are silently ignored.
        """
        assert len(tlabels) == len(self.x)
        field = self._field_for(label_type)
        if field is not None:
            setattr(self, field, tlabels)

    def set_labels_by_index(self, tlabels=None, tindex=None, label_type='domain_label'):
        """Write ``tlabels`` into positions ``tindex`` of the selected label array in place.

        Unrecognised ``label_type`` values are silently ignored.
        """
        field = self._field_for(label_type)
        if field is not None:
            getattr(self, field)[tindex] = tlabels

    def target_trans(self, y):
        """Apply ``target_transform`` to ``y`` when one is set."""
        return y if self.target_transform is None else self.target_transform(y)

    def input_trans(self, x):
        """Apply ``transform`` to ``x`` when one is set."""
        return x if self.transform is None else self.transform(x)

    def __getitem__(self, index):
        sample = self.input_trans(self.x[index])
        label_sources = (self.labels, self.dlabels, self.pclabels, self.pdlabels)
        targets = tuple(self.target_trans(source[index]) for source in label_sources)
        return (sample, *targets, index)

    def __len__(self):
        return len(self.x)

def seed_hash(*args):
    """Map any positional arguments to a deterministic non-negative 31-bit seed.

    The arguments' ``str`` representation is hashed with MD5 and reduced modulo 2**31,
    so equal argument tuples always give the same seed.

    Taken from the DomainBed repository:
        https://github.com/facebookresearch/DomainBed
    """
    digest = hashlib.md5(str(args).encode("utf-8")).hexdigest()
    # The digest is non-negative, so masking the low 31 bits equals `% 2**31`.
    return int(digest, 16) & 0x7FFFFFFF

def get_split(dataset, holdout_fraction, seed=0, sort=False):
    """Shuffle the indices of ``dataset`` and split them into in / out (holdout) keys.

    Args:
        dataset: any object supporting ``len()``; only its length is used.
        holdout_fraction (float): fraction of indices placed in the out (validation) set;
            the out-set size is ``int(len(dataset) * holdout_fraction)``.
        seed (int, optional): seed of the ``np.random.RandomState`` used to shuffle.
            Defaults to 0.
        sort (bool, optional): if True, both key lists are returned in ascending order.
            Defaults to False.

    Returns:
        list: in keys, the remaining (1 - holdout_fraction) share.
        list: out keys, the holdout_fraction share.
    """
    total = len(dataset)
    order = list(range(total))
    rng = np.random.RandomState(seed)
    rng.shuffle(order)

    n_out = int(total * holdout_fraction)
    out_keys, in_keys = order[:n_out], order[n_out:]
    if sort:
        return sorted(in_keys), sorted(out_keys)
    return in_keys, out_keys

def XOR(a, b):
    """Return ``a XOR b`` for 0/1-valued inputs (the 'Exclusive or' gate).

    Computed as ``|a - b|``. This works elementwise on torch tensors and NumPy arrays,
    and also on plain Python bools/ints. The previous ``(a - b).abs()`` form raised
    AttributeError for Python scalars.

    Args:
        a (bool | int | Tensor): First input (values 0/1).
        b (bool | int | Tensor): Second input (values 0/1).

    Returns:
        int or Tensor: 1 where the inputs differ, 0 where they are equal. The result is a
        tensor when either input is a tensor.
    """
    return abs(a - b)

def bernoulli(p, size):
    """Draw Bernoulli(p) samples as a float tensor of 1. (success) and 0. (failure).

    Uses the global torch RNG (``torch.rand``).

    Args:
        p (float): success probability.
        size (int...): shape of the output tensor.

    Returns:
        Tensor: float32 tensor of shape ``size`` with entries in {0., 1.}.
    """
    draws = torch.rand(size)
    return draws.lt(p).float()


def make_split(dataset, holdout_fraction, seed=0, sort=False):
    """Split a TensorDataset into (1 - holdout_fraction) / holdout_fraction parts.

    Indices come from ``get_split`` with the same ``seed`` and ``sort`` arguments. Each
    part is rebuilt as a new TensorDataset holding the selected rows of every tensor.

    Args:
        dataset (TensorDataset): dataset whose tensors are indexed with a list of keys.
        holdout_fraction (float): fraction of the samples placed in the second (holdout) part.
        seed (int, optional): shuffling seed forwarded to ``get_split``. Defaults to 0.
        sort (bool, optional): keep the selected keys in ascending order. Defaults to False.

    Returns:
        TensorDataset: the (1 - holdout_fraction) part.
        TensorDataset: the holdout_fraction part.
    """
    keys_in, keys_out = get_split(dataset, holdout_fraction, seed=seed, sort=sort)
    tensor_dataset = torch.utils.data.TensorDataset
    return tuple(tensor_dataset(*dataset[keys]) for keys in (keys_in, keys_out))

class H5_dataset(Dataset):
    """ HDF5-backed dataset for one environment (domain).

    The HDF5 file is expected to have the following nested dict structure::

        {'env0': {'data': np.array(n_samples, time_steps, input_size),
                  'labels': np.array(n_samples, len(PRED_TIME))},
         'env1': {'data': np.array(n_samples, time_steps, input_size),
                  'labels': np.array(n_samples, len(PRED_TIME))},
         ...}

    ``__init__`` reads the environment's full ``data`` and ``labels`` arrays into memory
    (``[:]``). The file handle stays open until :meth:`close` is called.

    Args:
        h5_path (str): path to the hdf5 file.
        env_id (str): environment key in the hdf5 file.
        domain_id (int): domain index stored for every sample in ``self.domains``.
        split (list, optional): indices of the samples that belong to this split. If
            ``None``, all samples are used.
    """
    def __init__(self, h5_path, env_id, domain_id, split=None):
        self.h5_path = h5_path
        self.env_id = env_id

        self.hdf = h5py.File(self.h5_path, 'r')
        env_group = self.hdf[env_id]
        self.data = env_group['data'][:]
        # `squeeze(axis=1)` requires the labels' second axis to have size 1.
        self.labels = env_group['labels'][:].squeeze(axis=1)
        self.domains = np.full(self.labels.shape[0], domain_id)

        # `is None` rather than `== None`: for an ndarray split, `== None` is elementwise
        # and its truth value would raise.
        self.split = list(range(self.data.shape[0])) if split is None else split

    def __len__(self):
        """ Number of samples in the split """
        return len(self.split)

    def __getitem__(self, idx):
        """ Return ``(seq, label, domain)`` tensors for the ``idx``-th sample of the split """
        if torch.is_tensor(idx):
            idx = idx.tolist()

        split_idx = self.split[idx]

        seq = torch.as_tensor(self.data[split_idx, ...])
        labels = torch.as_tensor(self.labels[split_idx])
        domain = torch.as_tensor(self.domains[split_idx])

        return (seq, labels, domain)

    def close(self):
        """ Close the hdf5 file link """
        self.hdf.close()

class PCLDataset(mydataset):
    """PCL EEG benchmark read from ``<root_path>/PCL/PCL.h5``.

    Each entry of ``ENVS`` is one environment (domain), stored under its own key in the
    HDF5 file. ``args.test_envs[0]`` selects the target environment:

    * ``flag == "TRAIN"``: in-split (1 - holdout_fraction) of every other environment.
    * ``flag == "VAL"``: out-split (holdout_fraction) of every other environment.
    * any other flag: the whole target environment.

    ``max_seq_len`` and ``class_names`` are always computed from the target environment.
    ``self.x`` is the stored data transposed on its last two axes with a singleton axis
    inserted, i.e. (n, A2, 1, A1) for stored (n, A1, A2). This is presumably
    (n, input_size, 1, time_steps) for the (n, time_steps, input_size) HDF5 layout;
    TODO confirm.
    """
    def __init__(self, args, root_path, flag):
        super().__init__(args)
        self.SEQ_LEN = 750
        self.PRED_TIME = [749]
        self.INPUT_SHAPE = [48]
        self.OUTPUT_SIZE = 2
        self.ENVS = ['PhysionetMI', 'Cho2017', 'Lee2019_MI']
        self.DATA_PATH = 'PCL/PCL.h5'
        self.holdout_fraction = 0.2
        self.root_path = root_path
        self.trial_seed = 0
        target_domain = args.test_envs[0]

        full, in_parts, out_parts = self._load_environments()
        sources = [i for i in range(len(self.ENVS)) if i != target_domain]

        self.feature_df, self.labels_df, self.domain_df = full[target_domain]
        self.max_seq_len = self.feature_df.shape[1]
        self.class_names = np.unique(self.labels_df)
        if flag == "TRAIN":
            self.feature_df, self.labels_df, self.domain_df = self._concat_parts(
                [in_parts[i] for i in sources])
        elif flag == "VAL":
            # Validation uses the held-out split of the source environments, not the
            # training split.
            self.feature_df, self.labels_df, self.domain_df = self._concat_parts(
                [out_parts[i] for i in sources])

        self.x = np.transpose(self.feature_df, (0, 2, 1))
        self.x = self.x.reshape(-1, self.x.shape[1], 1, self.x.shape[2])

        self.labels = torch.from_numpy(self.labels_df)
        self.dlabels = torch.from_numpy(self.domain_df)
        self.pclabels = np.ones(self.labels.shape)*(-1)
        self.pdlabels = np.ones(self.labels.shape)*(0)
        self.target_transform = None
        self.transform = None

    def _load_environments(self):
        """Read every environment once and split it with ``get_split``.

        Returns:
            tuple: ``(full, ins, outs)``. Each is a list with one
            ``(data, labels, domains)`` tuple of numpy arrays per entry of ``ENVS``:
            the whole environment, its in-split and its out-split.
        """
        h5_path = os.path.join(self.root_path, self.DATA_PATH)
        full, ins, outs = [], [], []
        for j, env in enumerate(self.ENVS):
            env_dataset = H5_dataset(h5_path, env, j)
            try:
                in_split, out_split = get_split(
                    env_dataset, self.holdout_fraction, seed=seed_hash(j, self.trial_seed))
                arrays = (env_dataset.data, env_dataset.labels, env_dataset.domains)
            finally:
                # Arrays are already in memory; release the file handle right away.
                env_dataset.close()
            full.append(arrays)
            ins.append(tuple(a[in_split] for a in arrays))
            outs.append(tuple(a[out_split] for a in arrays))
        return full, ins, outs

    @staticmethod
    def _concat_parts(parts):
        """Concatenate per-environment ``(data, labels, domains)`` tuples along axis 0."""
        return tuple(np.concatenate(column, axis=0) for column in zip(*parts))

    def get_sample_weights(self):
        """Per-sample weights inversely proportional to class frequency.

        Returns:
            list: one 1-element float64 tensor per sample, equal to
            ``1000 / count(label of that sample)``.
        """
        # labels_df is a numpy array here; np.asarray also accepts CPU tensors.
        y = np.asarray(self.labels_df)
        _, inverse, counts_y = np.unique(y, return_inverse=True, return_counts=True)
        weights = (1000.0 / torch.Tensor(counts_y)).double()
        return [weights[[k]] for k in inverse.reshape(-1).tolist()]

    def normalize(self, x):
        """Standardize ``x`` along dim 1 (zero mean, unit std).

        Assumes ``x`` is a torch tensor (uses ``dim=``/``keepdim=``). There is no epsilon,
        so a constant slice divides by zero.
        """
        mean = x.mean(dim=1, keepdim=True)
        std = x.std(dim=1, keepdim=True)
        x = (x - mean) / std
        return x


class HHARDataset(mydataset):
    """HHAR benchmark read from ``<root_path>/HHAR/HHAR.h5``.

    Each entry of ``ENVS`` is one environment (domain), stored under its own key in the
    HDF5 file. ``args.test_envs[0]`` selects the target environment:

    * ``flag == "TRAIN"``: in-split (1 - holdout_fraction) of every other environment.
    * ``flag == "VAL"``: out-split (holdout_fraction) of every other environment.
    * any other flag: the whole target environment.

    ``max_seq_len`` and ``class_names`` are always computed from the target environment.
    ``self.x`` is the stored data transposed on its last two axes with a singleton axis
    inserted, i.e. (n, A2, 1, A1) for stored (n, A1, A2). This is presumably
    (n, input_size, 1, time_steps) for the (n, time_steps, input_size) HDF5 layout;
    TODO confirm.
    """
    def __init__(self, args, root_path, flag):
        super().__init__(args)
        self.SEQ_LEN = 500
        self.PRED_TIME = [499]
        self.INPUT_SHAPE = [6]
        self.OUTPUT_SIZE = 6
        self.ENVS = ['nexus4', 's3', 's3mini', 'lgwatch', 'gear']
        self.DATA_PATH = 'HHAR/HHAR.h5'
        self.holdout_fraction = 0.2
        self.root_path = root_path
        self.trial_seed = 0
        target_domain = args.test_envs[0]

        full, in_parts, out_parts = self._load_environments()
        sources = [i for i in range(len(self.ENVS)) if i != target_domain]

        self.feature_df, self.labels_df, self.domain_df = full[target_domain]
        self.max_seq_len = self.feature_df.shape[1]
        self.class_names = np.unique(self.labels_df)
        if flag == "TRAIN":
            self.feature_df, self.labels_df, self.domain_df = self._concat_parts(
                [in_parts[i] for i in sources])
        elif flag == "VAL":
            # Validation uses the held-out split of the source environments, not the
            # training split.
            self.feature_df, self.labels_df, self.domain_df = self._concat_parts(
                [out_parts[i] for i in sources])

        self.x = np.transpose(self.feature_df, (0, 2, 1))
        self.x = self.x.reshape(-1, self.x.shape[1], 1, self.x.shape[2])

        self.labels = torch.from_numpy(self.labels_df)
        self.dlabels = torch.from_numpy(self.domain_df)
        self.pclabels = np.ones(self.labels.shape)*(-1)
        self.pdlabels = np.ones(self.labels.shape)*(0)
        self.target_transform = None
        self.transform = None

    def _load_environments(self):
        """Read every environment once and split it with ``get_split``.

        Returns:
            tuple: ``(full, ins, outs)``. Each is a list with one
            ``(data, labels, domains)`` tuple of numpy arrays per entry of ``ENVS``:
            the whole environment, its in-split and its out-split.
        """
        h5_path = os.path.join(self.root_path, self.DATA_PATH)
        full, ins, outs = [], [], []
        for j, env in enumerate(self.ENVS):
            env_dataset = H5_dataset(h5_path, env, j)
            try:
                in_split, out_split = get_split(
                    env_dataset, self.holdout_fraction, seed=seed_hash(j, self.trial_seed))
                arrays = (env_dataset.data, env_dataset.labels, env_dataset.domains)
            finally:
                # Arrays are already in memory; release the file handle right away.
                env_dataset.close()
            full.append(arrays)
            ins.append(tuple(a[in_split] for a in arrays))
            outs.append(tuple(a[out_split] for a in arrays))
        return full, ins, outs

    @staticmethod
    def _concat_parts(parts):
        """Concatenate per-environment ``(data, labels, domains)`` tuples along axis 0."""
        return tuple(np.concatenate(column, axis=0) for column in zip(*parts))

    def get_sample_weights(self):
        """Per-sample weights inversely proportional to class frequency.

        Returns:
            list: one 1-element float64 tensor per sample, equal to
            ``1000 / count(label of that sample)``.
        """
        # labels_df is a numpy array here; np.asarray also accepts CPU tensors.
        y = np.asarray(self.labels_df)
        _, inverse, counts_y = np.unique(y, return_inverse=True, return_counts=True)
        weights = (1000.0 / torch.Tensor(counts_y)).double()
        return [weights[[k]] for k in inverse.reshape(-1).tolist()]

    def normalize(self, x):
        """Standardize ``x`` along dim 1 (zero mean, unit std).

        Assumes ``x`` is a torch tensor (uses ``dim=``/``keepdim=``). There is no epsilon,
        so a constant slice divides by zero.
        """
        mean = x.mean(dim=1, keepdim=True)
        std = x.std(dim=1, keepdim=True)
        x = (x - mean) / std
        return x


class SpuriousFourierDataset(mydataset):
    """Synthetic "Spurious Fourier" benchmark built from 50-step signal windows.

    Signal construction:

    * Every signal has a label peak in its one-sided spectrum: bin 900 for label 0 and
      bin 700 for label 1.
    * It also has a weaker (0.5) spurious peak. In *direct* signals this peak is at
      bin 200 for label 0 and bin 400 for label 1; *inverse* signals swap the two.
    * For a clean sample, each value in ``ENVS`` is the probability of using the direct
      signal.
    * A fraction ``LABEL_NOISE`` of the samples use the opposite choice and have their
      label flipped.

    ``args.test_envs[0]`` selects the target environment:

    * ``flag == "TRAIN"``: in-split (1 - holdout_fraction) of every other environment.
    * ``flag == "VAL"``: out-split (holdout_fraction) of every other environment.
    * any other flag: in-split of the target environment.

    Generation uses a private ``torch.Generator`` seeded from ``trial_seed``. Every
    instance, whatever its ``flag``, therefore rebuilds the same environments and
    splits, and the global torch RNG is not consumed. ``root_path`` is accepted for
    interface compatibility but is unused. ``self.x`` has shape (n, 1, 1, 50).
    """
    def __init__(self, args, root_path, flag):
        super().__init__(args)
        self.SEQ_LEN = 50
        self.PRED_TIME = [49]
        self.INPUT_SHAPE = [1]
        self.OUTPUT_SIZE = 2
        self.ENVS = [0.1, 0.8, 0.9]
        self.LABEL_NOISE = 0.25
        self.holdout_fraction = 0.2
        self.trial_seed = 0

        generator = torch.Generator().manual_seed(seed_hash('SpuriousFourier', self.trial_seed))

        ## Define label 0 and 1 Fourier spectrum
        self.fourier_0 = np.zeros(1000)
        self.fourier_0[900] = 1
        self.fourier_1 = np.zeros(1000)
        self.fourier_1[700] = 1

        ## Spurious spectra: direct (200 -> label 0, 400 -> label 1) and inverse (swapped)
        self.direct_fourier_0 = self.fourier_0.copy()
        self.direct_fourier_1 = self.fourier_1.copy()
        self.direct_fourier_0[200] = 0.5
        self.direct_fourier_1[400] = 0.5

        self.inverse_fourier_0 = self.fourier_0.copy()
        self.inverse_fourier_1 = self.fourier_1.copy()
        self.inverse_fourier_0[400] = 0.5
        self.inverse_fourier_1[200] = 0.5

        # Shuffled window pools indexed by [kind][label]. Construction order fixes which
        # random permutation each pool gets.
        pools = {
            'direct': [self._make_signal_pool(self.direct_fourier_0, generator),
                       self._make_signal_pool(self.direct_fourier_1, generator)],
            'inverse': [self._make_signal_pool(self.inverse_fourier_0, generator),
                        self._make_signal_pool(self.inverse_fourier_1, generator)],
        }
        # Number of windows already consumed from each pool. Windows are never reused,
        # so environments do not share samples.
        used = {kind: [0, 0] for kind in pools}

        ## Create the environments with different correlations
        env_size = 4000
        target_domain = args.test_envs[0]
        domains = []
        vals = []
        for i, e in enumerate(self.ENVS):
            env_signal, env_labels = self._make_environment(e, env_size, pools, used, generator)
            env_domains = torch.full((env_labels.shape[0],), i, dtype=torch.long)
            dataset = torch.utils.data.TensorDataset(env_signal, env_labels, env_domains)
            in_dataset, out_dataset = make_split(dataset, self.holdout_fraction, seed=seed_hash(i, self.trial_seed))
            domains.append(tuple(t.numpy() for t in in_dataset.tensors))
            vals.append(tuple(t.numpy() for t in out_dataset.tensors))

        sources = [i for i in range(len(self.ENVS)) if i != target_domain]
        self.feature_df, self.labels_df, self.domain_df = domains[target_domain]
        self.max_seq_len = self.feature_df.shape[1]
        self.class_names = np.unique(self.labels_df)
        if flag == "TRAIN":
            self.feature_df, self.labels_df, self.domain_df = self._concat_parts(
                [domains[i] for i in sources])
        elif flag == "VAL":
            # Validation uses the held-out split of the source environments, not the
            # training split.
            self.feature_df, self.labels_df, self.domain_df = self._concat_parts(
                [vals[i] for i in sources])

        # (n, 50, 1) -> (n, 1, 1, 50)
        self.x = np.transpose(self.feature_df, (0, 2, 1))
        self.x = self.x.reshape(-1, self.x.shape[1], 1, self.x.shape[2])
        # Tensors, consistent with the HDF5-backed dataset classes.
        self.labels = torch.from_numpy(self.labels_df)
        self.dlabels = torch.from_numpy(self.domain_df)
        self.pclabels = np.ones(self.labels.shape)*(-1)
        self.pdlabels = np.ones(self.labels.shape)*(0)
        self.target_transform = None
        self.transform = None

    def _make_signal_pool(self, spectrum, generator):
        """Turn a one-sided spectrum into a randomly permuted pool of (K, 50, 1) windows.

        The spectrum is inverted with ``fft.irfft`` (n=10000) and scaled so its maximum
        is 1. It is then windowed by :meth:`super_sample` and permuted with
        ``generator``.
        """
        signal = torch.tensor(fft.irfft(spectrum, n=10000)).float()
        signal /= signal.max()
        windows = self.super_sample(signal)
        return windows[torch.randperm(windows.shape[0], generator=generator)]

    def _make_environment(self, correlation, env_size, pools, used, generator):
        """Draw one environment: ``env_size // 2`` samples of each label.

        Per sample, a noise flag (probability ``LABEL_NOISE``) and a correlation flag
        (probability ``correlation``) are drawn:

        * A clean sample takes the direct signal of its label when correlated, and the
          inverse signal otherwise.
        * A noisy sample does the opposite, and its label is flipped.

        Windows are taken in sample order from the front of each pool. ``used`` is
        advanced in place.

        Returns:
            tuple: (signals float tensor (n, 50, 1), labels long tensor (n,)).

        Raises:
            RuntimeError: if a pool runs out of windows.
        """
        half = env_size // 2
        labels = torch.cat((torch.zeros(half, dtype=torch.long),
                            torch.ones(half, dtype=torch.long)))
        n = labels.shape[0]
        noisy = torch.rand(n, generator=generator) < self.LABEL_NOISE
        correlated = torch.rand(n, generator=generator) < correlation
        # direct iff exactly one of (noisy, correlated) holds
        use_direct = noisy != correlated

        signals = torch.zeros((n, 50, 1))
        for kind, kind_mask in (('direct', use_direct), ('inverse', ~use_direct)):
            for label in (0, 1):
                mask = kind_mask & (labels == label)
                count = int(mask.sum())
                start = used[kind][label]
                pool = pools[kind][label]
                if start + count > pool.shape[0]:
                    raise RuntimeError(f"Not enough '{kind}' signals left for label {label}")
                signals[mask] = pool[start:start + count]
                used[kind][label] = start + count

        labels = torch.where(noisy, 1 - labels, labels)
        return signals, labels

    @staticmethod
    def _concat_parts(parts):
        """Concatenate per-environment ``(data, labels, domains)`` tuples along axis 0."""
        return tuple(np.concatenate(column, axis=0) for column in zip(*parts))

    def super_sample(self, signal):
        """ Cut ``signal`` into non-overlapping 50-step windows at several offsets.

        For each offset ``o`` in 0, 2, ..., 48 the slice ``signal[o:o-50]`` is reshaped
        to (-1, 50, 1), and all windows are concatenated. This requires
        ``len(signal) - 50`` to be a multiple of 50.

        Args:
            signal (torch.Tensor): 1-D signal to sample.

        Returns:
            torch.Tensor: float tensor of shape (K, 50, 1).
        """
        windows = [signal[offset:offset - 50].reshape(-1, 50, 1).detach().float()
                   for offset in range(0, 50, 2)]
        return torch.cat(windows, dim=0)

    def get_sample_weights(self):
        """Per-sample weights inversely proportional to class frequency.

        Returns:
            list: one 1-element float64 tensor per sample, equal to
            ``1000 / count(label of that sample)``.
        """
        # labels_df is a numpy array here; np.asarray also accepts CPU tensors.
        y = np.asarray(self.labels_df)
        _, inverse, counts_y = np.unique(y, return_inverse=True, return_counts=True)
        weights = (1000.0 / torch.Tensor(counts_y)).double()
        return [weights[[k]] for k in inverse.reshape(-1).tolist()]

    def normalize(self, x):
        """Standardize ``x`` along dim 1 (zero mean, unit std).

        Assumes ``x`` is a torch tensor (uses ``dim=``/``keepdim=``). There is no epsilon,
        so a constant slice divides by zero.
        """
        mean = x.mean(dim=1, keepdim=True)
        std = x.std(dim=1, keepdim=True)
        x = (x - mean) / std
        return x