#coding=utf-8
import os,sys
from datautil.actdata.util import *
from datautil.util import mydataset,Nmax
import numpy as np

class ActList(mydataset):
    """Windowed activity dataset for the ``'cross_dataset'`` task.

    Construction runs the following pipeline on the arrays returned by
    ``loaddata_from_numpy``:

    1. keep only the sensor positions / channels configured for ``dataset``
       (``args.select_position`` / ``args.select_channel``);
    2. remap raw class ids through ``args.label_cor[dataset]`` and drop the
       samples whose id is not covered by the mapping;
    3. pass the signal through ``seq_downsapmle`` with the source rate
       ``args.hz_list[dataset]`` and the target rate ``hz``;
    4. cut it with ``split_via_wind_1`` using ``win*hz`` and
       ``int(win*hz/2)`` as the two size arguments (presumably window
       length and step -- confirm against the helper);
    5. insert a singleton axis at position 2 and convert to a float tensor.

    Args:
        args: configuration object; must expose ``select_position``,
            ``select_channel``, ``label_cor``, ``hz_list`` (all indexed by
            ``dataset``) and ``grid_size``.
        dataset: key identifying the dataset inside the ``args`` tables.
        root_dir: location handed to ``loaddata_from_numpy``.
        domain_num: value used to fill ``tdlabels`` and, after subtracting
            ``Nmax(args, domain_num)``, ``dlabels``.  Note that
            ``self.domain_num`` itself is always set to 0.
        hz: target sampling rate.
        win: multiplied by ``hz`` to obtain the window size argument.
        transform: stored, then reset to ``None`` before the tensor
            conversion, so it has no effect once construction finishes.
        target_transform: stored as-is on the instance.
        pclabels: optional replacement for the default ``pclabels``
            (an array of -1 shaped like ``labels``).
        pdlabels: optional replacement for the default ``pdlabels``
            (an array of 0 shaped like ``labels``).
        shuffle_grid: forwarded to ``load_shuffle`` when building ``x1``
            and ``x2``.
    """

    def __init__(
        self,
        args,
        dataset,
        root_dir,
        domain_num,
        hz=25,
        win=2,
        transform=None,
        target_transform=None,
        pclabels=None,
        pdlabels=None,
        shuffle_grid=True,
    ):
        super(ActList, self).__init__(args)
        self.domain_num = 0
        self.hz = hz
        self.win = win
        self.dataset = dataset
        self.task = 'cross_dataset'
        self.transform = transform
        self.target_transform = target_transform

        # Load the raw arrays; the third returned value is not needed here.
        signals, class_ids, _, sensor_ids = loaddata_from_numpy(
            self.dataset, self.task, root_dir
        )
        signals, class_ids = self.select_position_channel(
            args, signals, class_ids, sensor_ids
        )

        # Unmapped classes come back as -1 from map_label and are discarded.
        class_ids = self.map_label(args, class_ids)
        keep = np.where(class_ids >= 0)[0]
        signals = signals[keep]
        class_ids = class_ids[keep]

        signals = seq_downsapmle(signals, args.hz_list[self.dataset], hz)
        window_size = win * hz
        self.x, self.labels = split_via_wind_1(
            signals, class_ids, window_size, int(window_size / 2)
        )

        # Add a singleton axis at position 2, then switch to a float tensor.
        self.x = self.x[:, :, np.newaxis, :]
        # The data is converted to a tensor right below, so the transform
        # supplied by the caller is deliberately dropped here.
        self.transform = None
        self.x = torch.tensor(self.x).float()

        label_shape = self.labels.shape
        self.pclabels = (
            pclabels if pclabels is not None else np.full(label_shape, -1.0)
        )
        self.pdlabels = (
            pdlabels if pdlabels is not None else np.zeros(label_shape)
        )
        self.tdlabels = np.ones(label_shape) * domain_num
        self.dlabels = np.ones(label_shape) * (
            domain_num - Nmax(args, domain_num)
        )

        # Two independent results of load_shuffle over the same data.
        self.x1 = load_shuffle(self.x, shuffle_grid, args.grid_size)
        self.x2 = load_shuffle(self.x, shuffle_grid, args.grid_size)

    # NOTE(review): a __getitem__ override returning
    # (x, ctarget, dtarget, pctarget, pdtarget, index, x1, x2) used to live
    # here and was commented out; presumably item access is now provided by
    # `mydataset` -- confirm before relying on it.

    def select_position_channel(self, args, x, cy, sy):
        """Restrict the data to the configured sensor positions and channels.

        Rows are gathered position by position, in the order listed in
        ``args.select_position[self.dataset]``; axis 1 of ``x`` is then
        indexed with ``args.select_channel[self.dataset]``.

        Returns:
            Tuple ``(x_selected, cy_selected)``.
        """
        positions = args.select_position[self.dataset]
        channels = np.array(args.select_channel[self.dataset])
        rows = np.hstack([np.where(sy == pos)[0] for pos in positions])
        return x[rows][:, channels, :], cy[rows]

    def map_label(self, args, cy):
        """Translate raw class ids into the shared label space.

        ``args.label_cor[self.dataset]`` is a sequence of groups; every raw
        id found in group ``i`` is mapped to ``i``.  Ids that belong to no
        group are mapped to -1.

        Returns:
            Float array with the same shape as ``cy``.
        """
        groups = args.label_cor[self.dataset]
        mapped = np.full(cy.shape, -1.0)
        for new_label, raw_ids in enumerate(groups):
            for raw_id in raw_ids:
                hits = np.where(cy == raw_id)[0]
                mapped[hits] = new_label
        return mapped

    def set_x(self, x):
        """Replace the stored data ``self.x`` with ``x``."""
        self.x = x


# root_dir='./data/act/'
# dataset=ActList(args,'dsads',root_dir,0,transform=act_train())
# print(dataset.x.shape)