#coding=utf-8
import numpy as np
from numpy.ma import indices
from numpy.random import shuffle
import sklearn.model_selection as ms
from torch.utils.data import DataLoader

import datautil.actdata.util as actutil
from datautil.util import combindataset,subdataset

import datautil.actdata.cross_dataset as cross_dataset
import datautil.actdata.cross_people as cross_people
import datautil.actdata.cross_position as cross_position

# Dispatch table: maps each supported ``args.task`` value to the module whose
# ``ActList`` class is used to build the per-domain datasets for that task.
task_act={'cross_dataset':cross_dataset,'cross_people':cross_people,'cross_position':cross_position}

def get_dataloader(args, tr, val, tar, type='common'):
    """Wrap the train/validation/target datasets in ``DataLoader`` objects.

    Args:
        args: namespace providing ``batch_size`` and ``N_WORKERS``; when
            ``type`` is not ``'common'`` it must also provide
            ``latent_domain_num``.
        tr: training dataset (used for both training loaders).
        val: validation dataset.
        tar: target dataset.
        type: ``'common'`` keeps the shuffled training batch size at
            ``args.batch_size``; any other value multiplies it by
            ``args.latent_domain_num``.

    Returns:
        Tuple ``(train_loader, train_loader_noshuffle, valid_loader,
        target_loader)``. Only ``train_loader`` shuffles; no loader drops the
        last partial batch.
    """
    def _loader(dataset, batch_size, shuffled):
        # Every loader shares the worker count and keeps the final partial batch.
        return DataLoader(
            dataset=dataset,
            batch_size=batch_size,
            num_workers=args.N_WORKERS,
            drop_last=False,
            shuffle=shuffled,
        )

    # ``latent_domain_num`` is only read when the enlarged batch is requested.
    shuffled_batch = (
        args.batch_size
        if type == 'common'
        else args.batch_size * args.latent_domain_num
    )

    train_loader = _loader(tr, shuffled_batch, True)
    train_loader_noshuffle = _loader(tr, args.batch_size, False)
    valid_loader = _loader(val, args.batch_size, False)
    target_loader = _loader(tar, args.batch_size, False)
    return train_loader, train_loader_noshuffle, valid_loader, target_loader

def get_act_dataloader(args):
    """Build the datasets and loaders for the task named by ``args.task``.

    One ``ActList`` dataset is created per domain. Domains whose index is in
    ``args.test_envs`` are pooled into the target set; all others are pooled
    into the source set, which is shuffled with ``args.seed`` and split so
    that the first 20% of the shuffled indices form the validation subset
    and the remainder the training subset.

    Side effects:
        * ``args.domain_num`` is set to the number of domains.
        * ``args.steps_per_epoch`` is lowered to the smallest
          ``len(source) / args.batch_size`` if that is below its current
          value, then scaled by 0.8 and truncated to ``int``.
        * NumPy's global RNG is reseeded with ``args.seed``.

    Returns:
        Tuple ``(train_loader, train_loader_noshuffle, valid_loader,
        target_loader, tr, val, targetdata)``.
    """
    act_module = task_act[args.task]
    sources = []
    targets = []

    def _assign(domain_idx, dataset):
        # Held-out domains go to the target pool and never limit steps_per_epoch.
        if domain_idx in args.test_envs:
            targets.append(dataset)
            return
        sources.append(dataset)
        batches = len(dataset) / args.batch_size
        if batches < args.steps_per_epoch:
            args.steps_per_epoch = batches

    if args.task == 'cross_dataset':
        # Each entry of ``args.act_dataset`` is its own domain.
        args.domain_num = len(args.act_dataset)
        for idx, dataset_name in enumerate(args.act_dataset):
            _assign(
                idx,
                act_module.ActList(
                    args, dataset_name, args.data_dir, idx,
                    transform=actutil.act_train(),
                ),
            )
    else:
        # Domains are the groups configured for ``args.dataset``.
        if args.task == 'cross_people':
            groups = args.act_people[args.dataset]
        elif args.task == 'cross_position':
            groups = args.act_positon[args.dataset]
        args.domain_num = len(groups)
        for idx, group in enumerate(groups):
            _assign(
                idx,
                act_module.ActList(
                    args, args.dataset, args.data_dir, group, idx,
                    transform=actutil.act_train(),
                ),
            )

    # Fraction of the pooled source data held out for validation.
    holdout = 0.2
    args.steps_per_epoch = int(args.steps_per_epoch * (1 - holdout))

    pooled = combindataset(args, sources)
    total = len(pooled.labels)
    order = np.arange(total)
    np.random.seed(args.seed)
    np.random.shuffle(order)

    n_val = int(total * holdout)
    tr = subdataset(args, pooled, order[n_val:])
    val = subdataset(args, pooled, order[:n_val])
    targetdata = combindataset(args, targets)

    loaders = get_dataloader(args, tr, val, targetdata)
    return loaders + (tr, val, targetdata)
