# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

# coding=utf-8
import numpy as np
from torch.utils.data import DataLoader
import sklearn.model_selection as ms
import actdata.util as actutil
from actdata.util import combindataset, subdataset

import actdata.cross_people as cross_people

# Maps ``args.task_name`` to the module whose ``ActList`` class get_act_data
# instantiates for each domain.
task_act = dict(classification=cross_people)

# Subject ids grouped into domains, per dataset: act_people[name][i] is the
# list of subject ids passed to ActList (by get_act_data) as domain index i.
act_people = {
    'EMG': [list(range(9 * g, 9 * g + 9)) for g in range(4)],
    'DSADS': [[2 * g, 2 * g + 1] for g in range(4)],
    'PAMAP': [[2, 3, 8], [1, 5], [0, 7], [4, 6]],
    'USCHAD': [[0, 1, 2, 11], [3, 5, 6, 9], [7, 8, 10, 13], [4, 12]],
    'WESAD': [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14]],
    'EEG': [list(range(5 * g, 5 * g + 5)) for g in range(4)],
}
# Per-dataset (input_shape, num_classes, grid_size); act_param_init copies
# these onto args.input_shape, args.num_classes and args.grid_size.
# NOTE(review): input_shape[0] matches the select_channel sizes in
# act_param_init, so it is presumably the channel count — confirm.
tmp = {
    'EMG': ((8, 1, 200), 6, 10),
    'DSADS': ((45, 1, 125), 19, 10),
    'PAMAP': ((27, 1, 200), 18, 10),
    'USCHAD': ((6, 1, 200), 12, 10),
    'WESAD': ((8, 1, 200), 4, 10),
    'EEG': ((1, 1, 1000), 1, 10),
}

# Per-dataset (input_shape, num_classes, grid_size), mirroring `tmp` above:
# emg 8,1,200 6 10
# dsads 45,1,125 19 10
# pamap 27,1,200 18 10
# uschad 6,1,200 12 10
# wesad 8,1,200 4 10
# eeg 1,1,1000 1 10
def act_param_init(args):
    """Attach dataset-specific settings to ``args`` and return the same object.

    Sets ``select_position``, ``select_channel`` and ``hz_list`` (dicts keyed
    by every supported dataset name), then copies ``num_classes``,
    ``input_shape`` and ``grid_size`` for ``args.dataset`` from the
    module-level ``tmp`` table.

    Raises:
        KeyError: if ``args.dataset`` is not a key of ``tmp`` (raised after
            the three per-dataset dicts have already been set).
    """
    names = ('EMG', 'DSADS', 'PAMAP', 'USCHAD', 'WESAD', 'EEG')

    # Each dataset gets its own fresh [0] list.
    args.select_position = {name: [0] for name in names}
    args.select_channel = {
        'EMG': np.arange(8),
        'DSADS': np.arange(45),
        'PAMAP': np.arange(27),
        'USCHAD': np.arange(6),
        'WESAD': np.arange(8),
        'EEG': np.arange(1),
    }
    args.hz_list = dict.fromkeys(names, 1000)

    input_shape, num_classes, grid_size = tmp[args.dataset]
    args.num_classes = num_classes
    args.input_shape = input_shape
    args.grid_size = grid_size
    return args


def get_dataloader(args, tr, val, tar):
    """Wrap the train, validation and target datasets in ``DataLoader`` objects.

    Every loader uses ``args.batch_size`` and ``args.num_workers`` and keeps
    the final partial batch (``drop_last=False``).

    Returns:
        ``(train_loader, train_loader_noshuffle, valid_loader, target_loader)``.
        Only ``train_loader`` shuffles; ``train_loader_noshuffle`` iterates
        ``tr`` in its original order.
    """
    def make_loader(dataset, shuffle):
        return DataLoader(dataset=dataset, batch_size=args.batch_size,
                          num_workers=args.num_workers, drop_last=False,
                          shuffle=shuffle)

    return (
        make_loader(tr, True),
        make_loader(tr, False),
        make_loader(val, False),
        make_loader(tar, False),
    )


def get_act_data(args, val_rate=0.2):
    """Build train / validation / test datasets for one held-out-domain split.

    ``act_people[args.data]`` lists groups of subject ids; group ``i`` is
    loaded as domain ``i`` via ``task_act[args.task_name].ActList`` with the
    ``actutil.act_train()`` transform. The group whose index equals
    ``args.target_domain`` becomes the test set. All other groups are merged
    with ``combindataset`` and randomly split into training and validation
    subsets.

    Args:
        args: namespace read for ``task_name``, ``data``, ``root_path``,
            ``target_domain`` and ``seed``, and passed through to the dataset
            helpers.
        val_rate: fraction of the merged source samples placed in the
            validation subset (``int(n_samples * val_rate)`` of them).
            Defaults to 0.2, the value that was previously hard-coded.

    Returns:
        ``(train_dataset, val_dataset, test_dataset)``.

    Raises:
        ValueError: if ``val_rate`` is not in ``[0, 1)``.
        KeyError: if ``args.task_name`` or ``args.data`` is not a known key.
    """
    if not 0 <= val_rate < 1:
        raise ValueError(f"val_rate must be in [0, 1), got {val_rate!r}")

    pcross_act = task_act[args.task_name]
    # NOTE(review): this reads ``args.data`` while act_param_init reads
    # ``args.dataset``; presumably both name the same dataset — confirm.
    domain_groups = act_people[args.data]

    source_datasetlist = []
    target_datalist = []
    for i, item in enumerate(domain_groups):
        tdata = pcross_act.ActList(
            args, args.data, args.root_path, item, i, transform=actutil.act_train())
        if i == args.target_domain:
            target_datalist.append(tdata)
        else:
            source_datasetlist.append(tdata)

    # Merge every source domain, then carve a random validation split out of it.
    source_data = combindataset(args, source_datasetlist)
    n_samples = len(source_data.labels)
    indexall = np.arange(n_samples)
    # This seeds NumPy's *global* RNG, which is visible to callers; it is kept
    # so that downstream behaviour stays exactly as before.
    np.random.seed(args.seed)
    np.random.shuffle(indexall)
    n_val = int(n_samples * val_rate)
    indextr, indexval = indexall[n_val:], indexall[:n_val]
    train_dataset = subdataset(args, source_data, indextr)
    val_dataset = subdataset(args, source_data, indexval)
    test_dataset = combindataset(args, target_datalist)

    return train_dataset, val_dataset, test_dataset
