# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

# coding=utf-8
import numpy as np
from torch.utils.data import DataLoader

import woods.actdata.util as actutil
from woods.actdata.util import combindataset, subdataset

import woods.actdata.cross_people as cross_people

# Maps a task name to the module whose ``ActList`` class builds its datasets
# (looked up by ``get_act_data`` via ``task_act[name]``).
task_act = {'classification': cross_people}

# Four groups of integer ids per dataset; each group is handed to one
# ``ActList`` by ``get_act_data``.
# NOTE(review): presumably these are subject/person ids -- confirm against
# the dataset loaders.
act_people = {
    # 36 consecutive ids split into 4 blocks of 9: 0-8, 9-17, 18-26, 27-35.
    'EMG': [list(range(first, first + 9)) for first in range(0, 36, 9)],
    'DSADS': [[0, 1], [2, 3], [4, 5], [6, 7]],
    'PAMAP': [[2, 3, 8], [1, 5], [0, 7], [4, 6]],
    'USCHAD': [[0, 1, 2, 11], [3, 5, 6, 9], [7, 8, 10, 13], [4, 12]],
}

# Per-dataset settings, read by ``act_param_init`` as
# (input_shape, num_classes, grid_size).
tmp = {
    'EMG': ((8, 1, 200), 6, 10),
    'DSADS': ((45, 1, 125), 19, 10),
    'PAMAP': ((27, 1, 200), 18, 10),
    'USCHAD': ((6, 1, 200), 12, 10),
}

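# Per-dataset settings as (input_shape, num_classes, grid_size); mirrors `tmp`:
#   EMG:    (8, 1, 200),  6,  10
#   DSADS:  (45, 1, 125), 19, 10
#   PAMAP:  (27, 1, 200), 18, 10
#   USCHAD: (6, 1, 200),  12, 10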
def act_param_init(args):
    """Attach the activity-dataset settings to ``args`` and return it.

    Sets ``select_position``, ``select_channel`` and ``hz_list`` (each a dict
    keyed by dataset name), then copies ``num_classes``, ``input_shape`` and
    ``grid_size`` from the module-level ``tmp`` entry for ``args.dataset``.

    Args:
        args: Mutable namespace-like object exposing a ``dataset`` attribute
            that names one of the keys of ``tmp``.

    Returns:
        The same ``args`` object, mutated in place.

    Raises:
        KeyError: If ``args.dataset`` has no entry in ``tmp``. The three
            per-dataset dicts have already been set on ``args`` by then.
    """
    dataset_names = ('EMG', 'DSADS', 'PAMAP', 'USCHAD')
    channel_counts = (8, 45, 27, 6)

    # Every dataset gets its own ``[0]`` list (not one shared list object).
    args.select_position = {ds: [0] for ds in dataset_names}
    args.select_channel = {
        ds: np.arange(count) for ds, count in zip(dataset_names, channel_counts)
    }
    args.hz_list = dict.fromkeys(dataset_names, 1000)

    # ``tmp`` stores (input_shape, num_classes, grid_size) per dataset.
    input_shape, num_classes, grid_size = tmp[args.dataset]
    args.num_classes = num_classes
    args.input_shape = input_shape
    args.grid_size = grid_size

    return args


def get_dataloader(args, tr, val, tar):
    """Wrap the train, validation and target datasets in ``DataLoader``s.

    All loaders use ``args.batch_size`` and ``args.N_WORKERS`` and keep the
    final partial batch (``drop_last=False``). Only the first training loader
    shuffles; the second iterates the same training data in order.

    Args:
        args: Object exposing ``batch_size`` and ``N_WORKERS``.
        tr: Training dataset.
        val: Validation dataset.
        tar: Target dataset.

    Returns:
        Tuple ``(train_loader, train_loader_noshuffle, valid_loader,
        target_loader)``.
    """
    def _build(split, shuffle):
        # Single place for the settings shared by all four loaders.
        return DataLoader(
            dataset=split,
            batch_size=args.batch_size,
            num_workers=args.N_WORKERS,
            drop_last=False,
            shuffle=shuffle,
        )

    return (
        _build(tr, True),
        _build(tr, False),
        _build(val, False),
        _build(tar, False),
    )


def get_act_data(target_domain, name, dataset_name, root_path):
    """Build one ``ActList`` dataset per people group of ``dataset_name``.

    The dataset class is taken from the module registered under ``name`` in
    ``task_act``; the groups come from ``act_people[dataset_name]``. Every
    group is returned: nothing is held out for ``target_domain`` here, and no
    train/validation split or ``DataLoader`` is created. ``target_domain`` is
    only forwarded to ``ActList``.

    Args:
        target_domain: Forwarded unchanged as the first ``ActList`` argument.
        name: Task name; must be a key of ``task_act``.
        dataset_name: Dataset name; must be a key of ``act_people``.
        root_path: Forwarded unchanged to ``ActList``.
            NOTE(review): presumably the data root directory -- confirm.

    Returns:
        List of ``ActList`` instances, one per group, in group order.

    Raises:
        KeyError: If ``name`` or ``dataset_name`` is not registered.
    """
    dataset_module = task_act[name]
    people_groups = act_people[dataset_name]
    group_count = len(people_groups)

    # A fresh transform is created for each group, so datasets never share
    # one transform object.
    return [
        dataset_module.ActList(
            target_domain, group_count, dataset_name, root_path,
            group, group_index, transform=actutil.act_train())
        for group_index, group in enumerate(people_groups)
    ]
