# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.


import numpy as np
import torch
import os
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
import datautil.sliding_window as sliding_window

def Nmax(args, d):
    """Return the position of the first test environment that is greater than ``d``.

    :param args: object exposing a ``test_envs`` sequence.
    :param d: value compared against each entry of ``args.test_envs``.
    :return: index of the first entry ``e`` with ``d < e``; ``len(args.test_envs)``
        when no entry satisfies that.
    """
    envs = args.test_envs
    return next((pos for pos, env in enumerate(envs) if d < env), len(envs))

def random_pairs_of_minibatches_by_domainperm(minibatches):
    """Pair every minibatch with its successor in a random cyclic ordering.

    A random permutation of the minibatch indices is drawn; the minibatch at
    each permuted position is paired with the one at the following position,
    and the last position wraps around to the first.  Both members of a pair
    are truncated to the shorter of the two lengths.

    :param minibatches: sequence whose items are indexable as ``(x, y, ...)``.
    :return: list of ``((xi, yi), (xj, yj))`` tuples, one per minibatch.
    """
    count = len(minibatches)
    order = torch.randperm(count).tolist()

    paired = []
    for pos, first in enumerate(order):
        # Successor in the permutation; the final position wraps to index 0.
        second = order[(pos + 1) % count]

        x_a, y_a = minibatches[first][0], minibatches[first][1]
        x_b, y_b = minibatches[second][0], minibatches[second][1]

        keep = min(len(x_a), len(x_b))
        left = (x_a[:keep], y_a[:keep])
        right = (x_b[:keep], y_b[:keep])
        paired.append((left, right))

    return paired


def random_pairs_of_minibatches(args, minibatches):
    """Build randomly mixed pairs of batches drawn across domains.

    One pair is produced per entry of ``minibatches``.  Every pair is filled
    row by row: for each of the ``args.batch_size`` rows two *different*
    minibatch indices are drawn, together with two sample indices (drawn with
    replacement, so they may coincide).  The selected samples are appended to
    the first and second member of the pair respectively.

    :param args: object exposing ``batch_size``; sample indices are drawn from
        ``range(args.batch_size)``, so each minibatch must hold at least that
        many samples.
    :param minibatches: sequence whose items are indexable as ``(x, y, d)``.
    :return: list of ``((xi, yi, di), (xj, yj, dj))`` tuples.
    """
    num_domains = len(minibatches)
    domain_ids = np.arange(num_domains)
    sample_ids = np.arange(args.batch_size)

    pairs = []
    for _ in range(num_domains):
        for row in range(args.batch_size):
            # Domain draw first, sample draw second: keeps the RNG stream order.
            dom_a, dom_b = np.random.choice(domain_ids, 2, replace=False)
            idx_a, idx_b = np.random.choice(sample_ids, 2, replace=True)

            batch_a, batch_b = minibatches[dom_a], minibatches[dom_b]
            row_a = torch.unsqueeze(batch_a[0][idx_a], dim=0)
            row_b = torch.unsqueeze(batch_b[0][idx_b], dim=0)

            if row == 0:
                # The first row seeds the accumulators.
                xi, yi, di = row_a, batch_a[1][idx_a], batch_a[2][idx_a]
                xj, yj, dj = row_b, batch_b[1][idx_b], batch_b[2][idx_b]
                continue

            xi = torch.vstack((xi, row_a))
            yi = torch.hstack((yi, batch_a[1][idx_a]))
            di = torch.hstack((di, batch_a[2][idx_a]))

            xj = torch.vstack((xj, row_b))
            yj = torch.hstack((yj, batch_b[1][idx_b]))
            dj = torch.hstack((dj, batch_b[2][idx_b]))
        pairs.append(((xi, yi, di), (xj, yj, dj)))
    return pairs


class basedataset(object):
    """Minimal container pairing inputs ``x`` with targets ``y`` by index."""

    def __init__(self, x, y):
        self.x, self.y = x, y

    def __len__(self):
        # The length is defined by the inputs only.
        return len(self.x)

    def __getitem__(self, index):
        features = self.x[index]
        target = self.y[index]
        return features, target


class mydataset(object):
    """Base dataset holding inputs plus four parallel label arrays.

    Items are returned as
    ``(x, class_label, domain_label, pclabel, pdlabel, index)``.  The optional
    ``transform`` is applied to the input and ``target_transform`` to each of
    the four labels.
    """

    # ``label_type`` keyword -> name of the attribute storing those labels.
    _LABEL_FIELDS = (
        ('pclabel', 'pclabels'),
        ('pdlabel', 'pdlabels'),
        ('domain_label', 'dlabels'),
        ('class_label', 'labels'),
    )

    def __init__(self, args):
        # Every data field starts empty; subclasses fill them in.
        for field in ('x', 'labels', 'dlabels', 'pclabels', 'pdlabels', 'task',
                      'dataset', 'transform', 'target_transform', 'loader'):
            setattr(self, field, None)
        self.args = args

    def _label_field(self, label_type):
        """Return the attribute name for ``label_type`` or ``None`` if unknown."""
        for keyword, field in self._LABEL_FIELDS:
            if label_type == keyword:
                return field
        return None

    def set_labels(self, tlabels=None, label_type='domain_label'):
        """Replace a whole label array; unknown ``label_type`` values are ignored.

        ``tlabels`` must have the same length as ``self.x``.
        """
        assert len(tlabels) == len(self.x)
        field = self._label_field(label_type)
        if field is not None:
            setattr(self, field, tlabels)

    def set_labels_by_index(self, tlabels=None, tindex=None, label_type='domain_label'):
        """Overwrite the entries at ``tindex`` of one label array in place.

        Unknown ``label_type`` values are ignored.
        """
        field = self._label_field(label_type)
        if field is not None:
            getattr(self, field)[tindex] = tlabels

    def target_trans(self, y):
        """Apply ``target_transform`` to ``y`` when one is configured."""
        if self.target_transform is None:
            return y
        return self.target_transform(y)

    def input_trans(self, x):
        """Apply ``transform`` to ``x`` when one is configured."""
        if self.transform is None:
            return x
        return self.transform(x)

    def __getitem__(self, index):
        sample = self.input_trans(self.x[index])
        class_target = self.target_trans(self.labels[index])
        domain_target = self.target_trans(self.dlabels[index])
        pseudo_class_target = self.target_trans(self.pclabels[index])
        pseudo_domain_target = self.target_trans(self.pdlabels[index])
        return (sample, class_target, domain_target,
                pseudo_class_target, pseudo_domain_target, index)

    def __len__(self):
        return len(self.x)


class subdataset(mydataset):
    """View of another dataset restricted to the samples selected by ``indices``."""

    def __init__(self, args, dataset, indices):
        super().__init__(args)

        def _take(values):
            # Optional label arrays stay None when the parent never set them.
            if values is None:
                return None
            return values[indices]

        self.x = dataset.x[indices]
        self.loader = dataset.loader
        self.labels = dataset.labels[indices]
        self.dlabels = _take(dataset.dlabels)
        self.pclabels = _take(dataset.pclabels)
        self.pdlabels = _take(dataset.pdlabels)
        # Metadata and transforms are shared with the parent unchanged.
        for shared in ('task', 'dataset', 'transform', 'target_transform'):
            setattr(self, shared, getattr(dataset, shared))


class combindataset(mydataset):
    """Concatenation of several datasets into a single one.

    Inputs are stacked with ``torch.vstack`` and label arrays with
    ``np.hstack``.  Loader, task, dataset name and transforms are taken from
    the first member of ``datalist``.
    """

    def __init__(self, args, datalist):
        super().__init__(args)
        self.domain_num = len(datalist)

        def _gather(attr):
            return [getattr(member, attr) for member in datalist]

        head = datalist[0]
        self.loader = head.loader
        inputs = _gather('x')
        class_labels = _gather('labels')
        domain_labels = _gather('dlabels')
        pseudo_class = _gather('pclabels')
        pseudo_domain = _gather('pdlabels')
        self.dataset = head.dataset
        self.task = head.task
        self.transform = head.transform
        self.target_transform = head.target_transform

        self.x = torch.vstack(inputs)
        self.labels = np.hstack(class_labels)
        self.dlabels = np.hstack(domain_labels)
        # Pseudo labels are optional: only the first member decides whether
        # they are stacked or left unset.
        self.pclabels = None if pseudo_class[0] is None else np.hstack(pseudo_class)
        self.pdlabels = None if pseudo_domain[0] is None else np.hstack(pseudo_domain)


class UCIHARDataset(mydataset):
    """UCIHAR data for one leave-one-domain-out split.

    Five pickled domain files (``ucihar_domain_0_wd.data`` ...
    ``ucihar_domain_4_wd.data``) are read from ``root_path + args.dataset``.
    The domain selected by ``args.test_envs[0]`` is the target; with
    ``flag == "TRAIN"`` the remaining domains are concatenated instead.
    """

    def __init__(self, args, root_path, flag):
        """
        :param args: object providing ``dataset`` (data directory name) and
            ``test_envs`` (its first entry is the target-domain index).
        :param root_path: prefix concatenated directly with ``args.dataset``,
            so it must already end with a path separator.
        :param flag: ``"TRAIN"`` selects every non-target domain; any other
            value selects the target domain only.
        """
        # Initialise the base-class fields (args, loader, task, dataset, ...)
        # so that subdataset/combindataset can read them from this object.
        super(UCIHARDataset, self).__init__(args)
        data_dir = root_path + args.dataset
        # NOTE(security): allow_pickle=True unpickles arbitrary objects; only
        # point this at trusted data files.
        domains = [
            np.load(os.path.join(data_dir, 'ucihar_domain_%d_wd.data' % idx), allow_pickle=True)
            for idx in range(5)
        ]
        target_domain = args.test_envs[0]
        test = domains[target_domain][0]

        self.feature_df = test[0]
        self.labels_df = test[1]
        self.domain_df = test[2]

        # Both values come from the target domain, even when flag == "TRAIN".
        self.max_seq_len = self.feature_df.shape[1]
        self.class_names = np.unique(self.labels_df)
        if flag == "TRAIN":
            train_parts = [domains[i][0] for i in range(len(domains)) if i != target_domain]
            self.feature_df = np.concatenate([part[0] for part in train_parts], axis=0)
            self.labels_df = np.concatenate([part[1] for part in train_parts], axis=0)
            self.domain_df = np.concatenate([part[2] for part in train_parts], axis=0)

        # Swap the last two axes and insert a singleton axis at position 2.
        self.x = torch.from_numpy(self.feature_df).permute(0, 2, 1).unsqueeze(2)
        self.labels = torch.from_numpy(self.labels_df)
        self.dlabels = torch.from_numpy(self.domain_df)
        # Pseudo class labels start at -1 and pseudo domain labels at 0.
        self.pclabels = np.ones(self.labels.shape)*(-1)
        self.pdlabels = np.ones(self.labels.shape)*(0)
        self.target_transform = None
        self.transform = None

    def get_sample_weights(self):
        """Return one weight per sample, inversely proportional to class frequency.

        :return: list with one single-element double tensor per entry of
            ``labels_df``; the value is ``100 / count`` of the sample's class.
        """
        # labels_df is a NumPy array (it is fed to torch.from_numpy above), so
        # it has no ``.numpy()`` method; np.asarray accepts arrays and tensors.
        y = np.asarray(self.labels_df)
        label_unique, counts_y = np.unique(y, return_counts=True)
        weights = (100.0 / torch.Tensor(counts_y)).double()
        return [weights[np.where(label_unique == val)] for val in y]

    def normalize(self, x):
        """Standardise ``x`` along dimension 1 (zero mean, unit std).

        NOTE: a zero standard deviation along that dimension yields inf/nan.
        """
        mean = x.mean(dim=1, keepdim=True)
        std = x.std(dim=1, keepdim=True)
        x = (x - mean) / std
        return x


class SHARDataset(mydataset):
    """SHAR data for one leave-one-domain-out split.

    Four pickled domain files (``shar_domain_{1,2,3,5}_wd.data``) are read from
    ``root_path + args.dataset``.  The domain selected by ``args.test_envs[0]``
    (an index into that list of four) is the target; with ``flag == "TRAIN"``
    the remaining domains are concatenated instead.
    """

    def __init__(self, args, root_path, flag):
        """
        :param args: object providing ``dataset`` (data directory name) and
            ``test_envs`` (its first entry is the target-domain index).
        :param root_path: prefix concatenated directly with ``args.dataset``,
            so it must already end with a path separator.
        :param flag: ``"TRAIN"`` selects every non-target domain; any other
            value selects the target domain only.
        """
        # Initialise the base-class fields (args, loader, task, dataset, ...)
        # so that subdataset/combindataset can read them from this object.
        super(SHARDataset, self).__init__(args)
        data_dir = root_path + args.dataset
        # NOTE(security): allow_pickle=True unpickles arbitrary objects; only
        # point this at trusted data files.
        domains = [
            np.load(os.path.join(data_dir, 'shar_domain_%s_wd.data' % name), allow_pickle=True)
            for name in ('1', '2', '3', '5')
        ]
        target_domain = args.test_envs[0]
        test = domains[target_domain][0]

        self.feature_df = test[0]
        self.labels_df = test[1]
        self.domain_df = test[2]

        if flag == "TRAIN":
            train_parts = [domains[i][0] for i in range(len(domains)) if i != target_domain]
            self.feature_df = np.concatenate([part[0] for part in train_parts], axis=0)
            self.labels_df = np.concatenate([part[1] for part in train_parts], axis=0)
            self.domain_df = np.concatenate([part[2] for part in train_parts], axis=0)
        # Every sample is viewed as 151 steps x 3 values before the axis swap.
        self.feature_df = self.feature_df.reshape(-1, 151, 3)
        self.x = torch.from_numpy(self.feature_df).permute(0, 2, 1).unsqueeze(2)
        self.labels = torch.from_numpy(self.labels_df)
        self.dlabels = torch.from_numpy(self.domain_df)
        # Pseudo class labels start at -1 and pseudo domain labels at 0.
        self.pclabels = np.ones(self.labels.shape)*(-1)
        self.pdlabels = np.ones(self.labels.shape)*(0)
        self.target_transform = None
        self.transform = None

    def get_sample_weights(self):
        """Return one weight per sample, inversely proportional to class frequency.

        :return: list with one single-element double tensor per entry of
            ``labels_df``; the value is ``1000 / count`` of the sample's class.
        """
        # labels_df is a NumPy array (it is fed to torch.from_numpy above), so
        # it has no ``.numpy()`` method; np.asarray accepts arrays and tensors.
        y = np.asarray(self.labels_df)
        label_unique, counts_y = np.unique(y, return_counts=True)
        weights = (1000.0 / torch.Tensor(counts_y)).double()
        return [weights[np.where(label_unique == val)] for val in y]

class OPPORTUNITYDataset(mydataset):
    """OPPORTUNITY data for one leave-one-domain-out split.

    Four pickled domain files (``oppor_domain_S1_wd.data`` ...
    ``oppor_domain_S4_wd.data``) are read from ``root_path + args.dataset`` and
    cut into sliding windows.  The domain selected by ``args.test_envs[0]`` is
    the target; with ``flag == "TRAIN"`` the windows of the remaining domains
    are concatenated instead.
    """

    def __init__(self, args, root_path, flag):
        """
        :param args: object providing ``dataset`` (data directory name) and
            ``test_envs`` (its first entry is the target-domain index).
        :param root_path: prefix concatenated directly with ``args.dataset``,
            so it must already end with a path separator.
        :param flag: ``"TRAIN"`` selects every non-target domain; any other
            value selects the target domain only.
        """
        # Initialise the base-class fields (args, loader, task, dataset, ...)
        # so that subdataset/combindataset can read them from this object.
        super(OPPORTUNITYDataset, self).__init__(args)
        self.SLIDING_WINDOW_LEN = 30
        self.SLIDING_WINDOW_STEP = 15
        # Attribute name (misspelling included) kept for backward compatibility.
        self.NUM_FEAFRUES = 77
        data_dir = root_path + args.dataset
        # NOTE(security): allow_pickle=True unpickles arbitrary objects; only
        # point this at trusted data files.
        domains = [
            np.load(os.path.join(data_dir, 'oppor_domain_%s_wd.data' % name), allow_pickle=True)
            for name in ('S1', 'S2', 'S3', 'S4')
        ]
        target_domain = args.test_envs[0]
        test = domains[target_domain][0]
        x_win, y_win, d_win = self.opp_sliding_window_w_d(
            test[0], test[1], test[2], self.SLIDING_WINDOW_LEN, self.SLIDING_WINDOW_STEP)
        self.feature_df = x_win
        self.labels_df = y_win
        self.domain_df = d_win

        if flag == "TRAIN":
            train_list = []
            train_label = []
            train_domain = []
            for i in range(len(domains)):
                if i == target_domain:
                    continue
                x, y, d = domains[i][0][0], domains[i][0][1], domains[i][0][2]
                x_win, y_win, d_win = self.opp_sliding_window_w_d(
                    x, y, d, self.SLIDING_WINDOW_LEN, self.SLIDING_WINDOW_STEP)
                train_list.append(x_win)
                train_label.append(y_win)
                train_domain.append(d_win)
            self.feature_df = np.concatenate(train_list, axis=0)
            self.labels_df = np.concatenate(train_label, axis=0)
            self.domain_df = np.concatenate(train_domain, axis=0)

        # Swap the last two axes and insert a singleton axis at position 2.
        self.x = torch.from_numpy(self.feature_df).permute(0, 2, 1).unsqueeze(2)
        self.labels = torch.from_numpy(self.labels_df)
        self.dlabels = torch.from_numpy(self.domain_df)
        # Pseudo class labels start at -1 and pseudo domain labels at 0.
        self.pclabels = np.ones(self.labels.shape)*(-1)
        self.pdlabels = np.ones(self.labels.shape)*(0)
        self.target_transform = None
        self.transform = None

    def opp_sliding_window_w_d(self, data_x, data_y, d, ws, ss): # window size, step size
        """Window the features, class labels and domain labels.

        Each label window is reduced to its last element.

        :param data_x: 2-D feature array; windows span all of its columns.
        :param data_y: class labels aligned with ``data_x``.
        :param d: domain labels aligned with ``data_x``.
        :param ws: window size.
        :param ss: step size between consecutive windows.
        :return: ``(x, y, d)`` as float32 / uint8 / uint8 arrays.
        """
        data_x = sliding_window.sliding_window(data_x,(ws,data_x.shape[1]),(ss,1))
        data_y = np.asarray([[i[-1]] for i in sliding_window.sliding_window(data_y,ws,ss)])
        data_d = np.asarray([[i[-1]] for i in sliding_window.sliding_window(d, ws, ss)])
        return data_x.astype(np.float32), data_y.reshape(len(data_y)).astype(np.uint8), data_d.reshape(len(data_d)).astype(np.uint8)

    def get_sample_weights(self):
        """Return one weight per sample, inversely proportional to class frequency.

        :return: list with one single-element double tensor per entry of
            ``labels_df``; the value is ``1000 / count`` of the sample's class.
        """
        # labels_df is a NumPy array (it is fed to torch.from_numpy above), so
        # it has no ``.numpy()`` method; np.asarray accepts arrays and tensors.
        y = np.asarray(self.labels_df)
        label_unique, counts_y = np.unique(y, return_counts=True)
        weights = (1000.0 / torch.Tensor(counts_y)).double()
        return [weights[np.where(label_unique == val)] for val in y]
    

class OneDomainDataset(Dataset):
    """Dataset over the samples, labels and domain ids of a single domain."""

    def __init__(self, samples, labels, domains):
        self.samples, self.labels, self.domains = samples, labels, domains

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index):
        # Returned in (sample, target, domain) order.
        columns = (self.samples, self.labels, self.domains)
        return tuple(column[index] for column in columns)
    
class UCIHAROneDomainDataset(Dataset):
    """Single-domain dataset that optionally transforms each sample.

    The returned sample always has its last two axes swapped.
    """

    def __init__(self, sample, labels, domains, t):
        self.samples, self.labels, self.domains = sample, labels, domains
        self.t = t

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        item, target, domain_id = self.samples[idx], self.labels[idx], self.domains[idx]
        # The optional transform runs before the axis swap.
        transformed = item if self.t is None else self.t(item)
        swapped = np.transpose(transformed, (0, 2, 1))
        return swapped, target, domain_id



def load_domain_data(dataset_name, domain_idx):
    """Load all the data of one domain from its pre-processed ``*_wd.data`` file.

    :param dataset_name: ``'UCIHAR'``, ``'OPP'`` or ``'SHAR'``.
    :param domain_idx: domain identifier inserted into the file name (for
        example ``'0'`` or ``'S1'``); non-string values such as ints are
        converted with ``str``.
    :return: ``(X, y, d)`` read from ``data[0][0]``, ``data[0][1]`` and
        ``data[0][2]`` of the loaded file, or ``(None, None, None)`` when the
        dataset name is unknown or the file does not exist.
    """
    # dataset name -> (directory relative to the working directory, file prefix)
    locations = {
        'UCIHAR': ('../../data/UCIHAR/', 'ucihar_domain_'),
        'OPP': ('../../data/OPP/', 'oppor_domain_'),
        'SHAR': ('../../data/SHAR/', 'shar_domain_'),
    }
    location = locations.get(dataset_name)
    if location is None:
        print('dataset not found')
        return None, None, None

    data_dir, prefix = location
    # str() lets callers pass the domain either as '3' or as 3.
    data_path = data_dir + prefix + str(domain_idx) + '_wd.data'
    if not os.path.isfile(data_path):
        print('file not found')
        return None, None, None

    # NOTE(security): allow_pickle=True unpickles arbitrary objects; only load
    # files from a trusted source.
    data = np.load(data_path, allow_pickle=True)
    X, y, d = data[0][0], data[0][1], data[0][2]
    return X, y, d

def get_sample_weights(y, weights):
    '''
    Look up, for every label in ``y``, the weight of its class.

    ``weights`` is indexed by the position of each label inside the sorted
    unique values of ``y``; every returned entry is the result of that
    single-position index (one element).
    '''
    classes = np.unique(y)
    return [weights[np.where(classes == label)] for label in y]

def prep_domains_shar(args, SLIDING_WINDOW_LEN=0, SLIDING_WINDOW_STEP=0):
    """Build per-domain SHAR datasets for a leave-one-domain-out split.

    :param args: object providing ``dataset`` and ``test_envs``; the first
        entry of ``test_envs`` indexes the target among domains 1, 2, 3, 5.
    :param SLIDING_WINDOW_LEN: window length used to reshape each sample to
        ``(-1, 1, SLIDING_WINDOW_LEN, 3)`` before moving the last axis to
        position 1.
    :param SLIDING_WINDOW_STEP: not used by this function.
    :return: ``(source_datasets, target_dataset, sample_weights)`` where
        ``sample_weights`` holds one per-sample weight list per source domain.
    """
    domain_names = ['1', '2', '3', '5']
    target_domain = args.test_envs[0]
    target_domain_name = domain_names[target_domain]
    source_names = [name for name in domain_names if name != target_domain_name]

    def _to_channels_first(raw):
        windows = raw.reshape((-1, 1, SLIDING_WINDOW_LEN, 3))
        return np.transpose(windows, (0, 3, 1, 2))

    # source domain data prep
    source_datasets, sample_weights = [], []
    for name in source_names:
        print('source_domain:', name)
        x, y, d = load_domain_data(args.dataset, name)
        x = _to_channels_first(x)

        # Class weight = 1000 / class count, then expanded to one per sample.
        _, class_counts = np.unique(y, return_counts=True)
        class_weights = (1000.0 / torch.Tensor(class_counts)).double()
        sample_weights.append(get_sample_weights(y, class_weights))
        source_datasets.append(OneDomainDataset(x, y, d))

    # target domain data prep
    print('target_domain:', target_domain)
    x, y, d = load_domain_data(args.dataset, target_domain_name)
    target_dataset = OneDomainDataset(_to_channels_first(x), y, d)
    return source_datasets, target_dataset, sample_weights


def prep_domains_ucihar(args, SLIDING_WINDOW_LEN=0, SLIDING_WINDOW_STEP=0):
    """Build per-domain UCIHAR datasets for a leave-one-domain-out split.

    :param args: object providing ``dataset`` and ``test_envs``; the first
        entry of ``test_envs`` indexes the target among domains 0-4.
    :param SLIDING_WINDOW_LEN: window length used to reshape each sample to
        ``(-1, 1, SLIDING_WINDOW_LEN, 9)`` before swapping axes 1 and 2.
    :param SLIDING_WINDOW_STEP: not used by this function.
    :return: ``(source_datasets, target_dataset, sample_weights)`` where
        ``sample_weights`` holds one per-sample weight list per source domain.
    """
    domain_names = ['0', '1', '2', '3', '4']
    target_domain = args.test_envs[0]
    target_domain_name = domain_names[target_domain]
    source_names = [name for name in domain_names if name != target_domain_name]

    def _reshape(raw):
        windows = raw.reshape((-1, 1, SLIDING_WINDOW_LEN, 9))
        return np.transpose(windows, (0, 2, 1, 3))

    def _build_transform():
        # Identity normalisation (mean 0, std 1) over the nine channels.
        return transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=(0, 0, 0, 0, 0, 0, 0, 0, 0), std=(1, 1, 1, 1, 1, 1, 1, 1, 1))
        ])

    # source domain data prep
    source_datasets, sample_weights = [], []
    for name in source_names:
        print('source_domain:', name)
        x, y, d = load_domain_data(args.dataset, name)
        x = _reshape(x)
        transform = _build_transform()

        # Class weight = 100 / class count, then expanded to one per sample.
        _, class_counts = np.unique(y, return_counts=True)
        class_weights = (100.0 / torch.Tensor(class_counts)).double()
        sample_weights.append(get_sample_weights(y, class_weights))
        source_datasets.append(UCIHAROneDomainDataset(x, y, d, t=transform))

    # target domain data prep
    print('target_domain:', target_domain)
    x, y, d = load_domain_data(args.dataset, target_domain_name)
    # The target reuses the transform created for the last source domain.
    target_dataset = UCIHAROneDomainDataset(_reshape(x), y, d, t=transform)
    return source_datasets, target_dataset, sample_weights

def opp_sliding_window_w_d(data_x, data_y, d, ws, ss): # window size, step size
    """Window the features, class labels and domain labels.

    Each label window is reduced to its last element.

    :param data_x: 2-D feature array; windows span all of its columns.
    :param data_y: class labels aligned with ``data_x``.
    :param d: domain labels aligned with ``data_x``.
    :param ws: window size.
    :param ss: step size between consecutive windows.
    :return: ``(x, y, d)`` as float32 / uint8 / uint8 arrays.
    """
    def _last_per_window(series):
        tails = np.asarray([[window[-1]] for window in sliding_window.sliding_window(series, ws, ss)])
        return tails.reshape(len(tails)).astype(np.uint8)

    feature_windows = sliding_window.sliding_window(data_x, (ws, data_x.shape[1]), (ss, 1))
    label_windows = _last_per_window(data_y)
    domain_windows = _last_per_window(d)
    return feature_windows.astype(np.float32), label_windows, domain_windows


def prep_domains_oppor(args, SLIDING_WINDOW_LEN=0, SLIDING_WINDOW_STEP=0):
    """Build per-domain OPPORTUNITY datasets for a leave-one-domain-out split.

    :param args: object providing ``dataset`` and ``test_envs``; the first
        entry of ``test_envs`` indexes the target among S1-S4.
    :param SLIDING_WINDOW_LEN: window size passed to the sliding-window helper
        and used to reshape windows to ``(-1, 1, SLIDING_WINDOW_LEN, 77)``
        before moving the last axis to position 1.
    :param SLIDING_WINDOW_STEP: step size passed to the sliding-window helper.
    :return: ``(source_datasets, target_dataset, sample_weights)`` where
        ``sample_weights`` holds one per-sample weight list per source domain.
    """
    domain_names = ['S1', 'S2', 'S3', 'S4']
    target_domain = args.test_envs[0]
    target_domain_name = domain_names[target_domain]
    source_names = [name for name in domain_names if name != target_domain_name]

    def _windows_for(name):
        raw_x, raw_y, raw_d = load_domain_data(args.dataset, name)
        return opp_sliding_window_w_d(raw_x, raw_y, raw_d, SLIDING_WINDOW_LEN, SLIDING_WINDOW_STEP)

    def _to_channels_first(windows):
        return np.transpose(windows.reshape((-1, 1, SLIDING_WINDOW_LEN, 77)), (0, 3, 1, 2))

    # source domain data prep
    source_datasets, sample_weights = [], []
    for name in source_names:
        print('source_domain:', name)
        x_win, y_win, d_win = _windows_for(name)
        # Class weight = 100 / class count over the windowed labels.
        _, class_counts = np.unique(y_win, return_counts=True)
        class_weights = (100.0 / torch.Tensor(class_counts)).double()
        sample_weights.append(get_sample_weights(y_win, class_weights))
        source_datasets.append(OneDomainDataset(_to_channels_first(x_win), y_win, d_win))

    # target domain data prep
    print('target_domain:', target_domain)
    x_win, y_win, d_win = _windows_for(target_domain_name)
    target_dataset = OneDomainDataset(_to_channels_first(x_win), y_win, d_win)

    return source_datasets, target_dataset, sample_weights