# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

# coding=utf-8
import numpy as np
import copy
from torch.utils.data import DataLoader
import torch
import hashlib
import datautil.actdata.util as actutil
from datautil.util import combindataset, subdataset, UCIHARDataset, SHARDataset, OPPORTUNITYDataset, prep_domains_shar, prep_domains_ucihar, prep_domains_oppor
from datautil.woods_datasets import PCLDataset, HHARDataset, SpuriousFourierDataset
from datautil.DGloader import InfiniteDataLoader
import datautil.actdata.cross_people as cross_people
import sklearn.model_selection as ms

def seed_hash(*args):
    """
    Turn an arbitrary collection of arguments into a deterministic integer seed.

    The tuple of arguments is rendered with ``str``, digested with MD5 and the
    resulting 128-bit number is reduced modulo 2**31, so the seed is always in
    ``[0, 2**31)``.

    Adapted from the DomainBed repository:
        https://github.com/facebookresearch/DomainBed
    """
    digest = hashlib.md5(str(args).encode("utf-8")).hexdigest()
    return int(digest, 16) % 2 ** 31

def get_split(dataset, holdout_fraction, seed=0, sort=False):
    """Compute the index lists for a (1-holdout_fraction) / holdout_fraction split.

    The indices ``0..len(dataset)-1`` are shuffled with a
    ``np.random.RandomState(seed)``. The first ``int(len(dataset) * holdout_fraction)``
    shuffled indices form the "out" (validation) part and the rest form the
    "in" part. Only ``len(dataset)`` is used, so any sized object works.

    Args:
        dataset: object whose length determines the number of keys.
        holdout_fraction (float): fraction of the keys placed in the out split.
        seed (int, optional): seed of the shuffling RNG. Defaults to 0.
        sort (bool, optional): if ``True``, both key lists are returned in
            ascending order. Defaults to False.

    Returns:
        list: keys of the in (1-holdout_fraction) split.
        list: keys of the out (holdout_fraction) split.
    """
    n_items = len(dataset)
    n_holdout = int(n_items * holdout_fraction)

    order = list(range(n_items))
    rng = np.random.RandomState(seed)
    rng.shuffle(order)

    out_keys, in_keys = order[:n_holdout], order[n_holdout:]
    if sort:
        return sorted(in_keys), sorted(out_keys)
    return in_keys, out_keys

def make_split(dataset, holdout_fraction, seed=0, sort=False):
    """Split a ``TensorDataset`` into two new ``TensorDataset`` objects.

    The keys come from ``get_split``. Each part is built by indexing
    ``dataset`` with the key list, which yields one indexed tensor per
    underlying tensor, and wrapping those tensors in a fresh ``TensorDataset``.

    Args:
        dataset (TensorDataset): dataset holding the tensors to split
            (e.g. data and targets).
        holdout_fraction (float): fraction of samples placed in the second
            (validation) dataset.
        seed (int, optional): seed of the shuffling RNG. Defaults to 0.
        sort (bool, optional): if ``True``, samples keep their original
            relative order inside each part. Defaults to False.

    Returns:
        TensorDataset: the (1-holdout_fraction) part of the split.
        TensorDataset: the holdout_fraction part of the split.
    """
    in_keys, out_keys = get_split(dataset, holdout_fraction, seed=seed, sort=sort)
    return tuple(
        torch.utils.data.TensorDataset(*dataset[keys])
        for keys in (in_keys, out_keys)
    )


# Maps ``args.task`` to the module that provides the per-domain ``ActList`` dataset.
task_act = dict(cross_people=cross_people)

def get_dataloaders(args):
    """Build per-domain train loaders, a target loader and per-domain eval loaders.

    For ``args.dataset`` in {"UCIHAR", "SHAR", "OPP"}, the source datasets,
    the target dataset and optional per-domain sampling weights come from the
    matching ``prep_domains_*`` helper. Those helpers use window lengths
    128 / 151 / 30 with a step of half the window.

    For any other dataset, each entry of ``args.act_people[args.dataset]`` is
    loaded with the ``ActList`` class of ``task_act[args.task]``:
      * domains whose index is in ``args.test_envs`` are combined into the
        target set;
      * the remaining domains are split 80/20 into train/validation, either
        label-stratified (``args.split_style == 'strat'``) or by a random
        shuffle.

    Side effects: for the ``ActList`` path this sets ``args.domain_num``. The
    random split style also reseeds NumPy's global RNG with ``args.seed``.

    Args:
        args: experiment namespace. Reads dataset, task, act_people, data_dir,
            test_envs, split_style, seed, batch_size and N_WORKERS.

    Returns:
        tuple: ``(train_loaders, target_loaders, eval_loaders)``.
            * ``train_loaders`` is a list of ``InfiniteDataLoader``, one per
              source dataset.
            * ``target_loaders`` is a one-element list holding a non-shuffled
              ``DataLoader`` over the target data.
            * ``eval_loaders`` is a dict ``{"eval<i>_out": DataLoader}`` keyed
              by domain index, or ``None`` for the ``prep_domains_*`` datasets.
    """
    sourcesdatasets = []
    targetdataset = []
    evaldatasets = []
    eval_domains = []
    eval_loaders = None
    weights = None
    if args.dataset == "UCIHAR":
        sourcesdatasets, targetdataset, weights = prep_domains_ucihar(args, SLIDING_WINDOW_LEN=128, SLIDING_WINDOW_STEP=int(0.5*128))
    elif args.dataset == "SHAR":
        sourcesdatasets, targetdataset, weights = prep_domains_shar(args, SLIDING_WINDOW_LEN=151, SLIDING_WINDOW_STEP=int(0.5*151))
    elif args.dataset == "OPP":
        sourcesdatasets, targetdataset, weights = prep_domains_oppor(args, SLIDING_WINDOW_LEN=30, SLIDING_WINDOW_STEP=int(0.5*30))
    else:
        pcross_act = task_act[args.task]
        tmpp = args.act_people[args.dataset]
        args.domain_num = len(tmpp)
        rate = 0.2  # fraction of every source domain held out for validation
        for i, item in enumerate(tmpp):
            tdata = pcross_act.ActList(
                args, args.dataset, args.data_dir, item, i, transform=actutil.act_train())
            if i in args.test_envs:
                targetdataset.append(tdata)
                continue

            labels = tdata.labels
            n_samples = len(labels)
            if args.split_style == 'strat':
                # Label-stratified split; only the first generated split is used.
                stsplit = ms.StratifiedShuffleSplit(
                    2, test_size=rate, train_size=1-rate, random_state=args.seed)
                indextr, indexval = next(stsplit.split(np.arange(n_samples), labels))
            else:
                indexall = np.arange(n_samples)
                np.random.seed(args.seed)
                np.random.shuffle(indexall)
                # Slice at an explicit index rather than ``[:-ted]``. When
                # int(n_samples * rate) == 0 (fewer than 1/rate samples),
                # ``[:-0]`` is empty and every sample would land in validation.
                cut = n_samples - int(n_samples * rate)
                indextr, indexval = indexall[:cut], indexall[cut:]

            sourcesdatasets.append(subdataset(args, tdata, indextr))
            evaldatasets.append(subdataset(args, tdata, indexval))
            eval_domains.append(i)

        targetdataset = combindataset(args, targetdataset)

        eval_loaders = {f"eval{domain}_out": DataLoader(dataset=env, batch_size=args.batch_size,
                                                        num_workers=args.N_WORKERS, drop_last=False, shuffle=False)
                        for domain, env in zip(eval_domains, evaldatasets)}

    # ``is None``: ``== None`` is elementwise (and ambiguous) for array-like weights.
    if weights is None:
        train_loaders = [InfiniteDataLoader(
            dataset=env,
            batch_size=args.batch_size,
            num_workers=args.N_WORKERS)
            for env in sourcesdatasets]
    else:
        # Weights are paired positionally with the source datasets.
        train_loaders = [InfiniteDataLoader(
            dataset=env,
            weights=env_weights,
            batch_size=args.batch_size,
            num_workers=args.N_WORKERS)
            for env, env_weights in zip(sourcesdatasets, weights)]

    target_loaders = [DataLoader(dataset=targetdataset, batch_size=args.batch_size,
                                 num_workers=args.N_WORKERS, drop_last=False, shuffle=False)]

    return train_loaders, target_loaders, eval_loaders


def get_dataloader(args, tr, val, tar):
    """Wrap the train, validation and target datasets in ``DataLoader`` objects.

    Every loader uses ``args.batch_size`` and ``args.N_WORKERS``. The two
    training loaders set ``drop_last`` only when ``args.algorithm == "Mixup"``;
    the validation and target loaders never drop a batch.

    Returns:
        tuple: ``(train_loader, train_loader_noshuffle, valid_loader, target_loader)``.
            Only ``train_loader`` shuffles; the other three iterate in
            dataset order.
    """
    def _loader(data, shuffle, drop_last=False):
        return DataLoader(dataset=data, batch_size=args.batch_size,
                          num_workers=args.N_WORKERS, drop_last=drop_last,
                          shuffle=shuffle)

    mixup = args.algorithm == "Mixup"
    return (
        _loader(tr, True, mixup),
        _loader(tr, False, mixup),
        _loader(val, False),
        _loader(tar, False),
    )


def get_act_dataloader(args):
    """Build train/validation/target loaders for activity-recognition data.

    Always looks up ``task_act[args.task]`` and sets ``args.domain_num`` to the
    number of entries in ``args.act_people[args.dataset]``.

    Fixed-split datasets (UCIHAR, SHAR, OPP, PCL, HHAR, Spurious_Fourier) are
    loaded through their dataset class with ``flag='TRAIN'``, a validation
    flag and ``flag='TEST'``. UCIHAR, SHAR and OPP reuse the TEST split for
    validation.

    Otherwise each entry of ``args.act_people[args.dataset]`` is loaded with
    ``ActList``:
      * indices in ``args.test_envs`` are combined into the target data;
      * the remaining domains are combined and split 80/20 into train and
        validation after reseeding NumPy's global RNG with ``args.seed``.

    On this path ``args.steps_per_epoch`` is first lowered to the batch count
    of the smallest source domain, then scaled by 0.8 and truncated to int.
    The unique train/val labels are also printed.

    Returns:
        tuple: 7 elements, ``(train_loader, train_loader_noshuffle,
            valid_loader, target_loader, tr, val, targetdata)``. The last
            three are ``None`` for the fixed-split datasets.
    """
    pcross_act = task_act[args.task]
    people = args.act_people[args.dataset]
    args.domain_num = len(people)

    # Dataset class and the flag used for its validation split.
    fixed_split_datasets = {
        'UCIHAR': (UCIHARDataset, 'TEST'),
        'SHAR': (SHARDataset, 'TEST'),
        'OPP': (OPPORTUNITYDataset, 'TEST'),
        'PCL': (PCLDataset, 'VAL'),
        'HHAR': (HHARDataset, 'VAL'),
        "Spurious_Fourier": (SpuriousFourierDataset, 'VAL'),
    }
    if args.dataset in fixed_split_datasets:
        dataset_cls, val_flag = fixed_split_datasets[args.dataset]
        train_dataset = dataset_cls(args, args.data_dir, flag='TRAIN')
        valid_dataset = dataset_cls(args, args.data_dir, flag=val_flag)
        target_dataset = dataset_cls(args, args.data_dir, flag='TEST')
        train_loader, train_loader_noshuffle, valid_loader, target_loader = get_dataloader(
            args, train_dataset, valid_dataset, target_dataset)
        return train_loader, train_loader_noshuffle, valid_loader, target_loader, None, None, None

    source_datasetlist = []
    target_datalist = []
    for domain_idx, person in enumerate(people):
        tdata = pcross_act.ActList(
            args, args.dataset, args.data_dir, person, domain_idx, transform=actutil.act_train())
        if domain_idx in args.test_envs:
            target_datalist.append(tdata)
            continue
        source_datasetlist.append(tdata)
        # Cap steps_per_epoch at the batch count of the smallest source domain.
        domain_steps = len(tdata) / args.batch_size
        if domain_steps < args.steps_per_epoch:
            args.steps_per_epoch = domain_steps

    rate = 0.2
    args.steps_per_epoch = int(args.steps_per_epoch * (1 - rate))

    tdata = combindataset(args, source_datasetlist)
    labels = tdata.labels
    order = np.arange(len(labels))
    np.random.seed(args.seed)
    np.random.shuffle(order)
    n_val = int(len(labels) * rate)
    indexval, indextr = order[:n_val], order[n_val:]
    print('train label unique:', np.unique(labels[indextr]))
    print('val label unique:', np.unique(labels[indexval]))

    tr = subdataset(args, tdata, indextr)
    val = subdataset(args, tdata, indexval)
    targetdata = combindataset(args, target_datalist)
    train_loader, train_loader_noshuffle, valid_loader, target_loader = get_dataloader(
        args, tr, val, targetdata)
    return train_loader, train_loader_noshuffle, valid_loader, target_loader, tr, val, targetdata


