#coding=utf-8

from datautil.actdata.util import *
from datautil.util import mydataset,Nmax
import numpy as np
class ActList(mydataset):
    """Cross-position activity dataset restricted to one group of positions.

    Loads arrays with ``loaddata_from_numpy(dataset, 'cross-position', root_dir)``,
    keeps only the samples whose position id (the ``sy`` array) is listed in
    ``position_group``, converts the inputs to a float tensor with an extra
    singleton axis, builds per-sample label arrays, and precomputes two
    ``load_shuffle`` views of the inputs (``x1`` and ``x2``).

    Args:
        args: Run configuration. ``args.grid_size`` is read here; ``args`` is
            also forwarded to ``mydataset.__init__`` and ``Nmax``.
        dataset: Dataset name passed to ``loaddata_from_numpy``.
        root_dir: Data directory passed to ``loaddata_from_numpy``.
        position_group: Position ids to keep. Samples are concatenated in the
            order the ids appear here. Must be non-empty.
        group_num: Index of this position group. It fills ``tdlabels`` and is
            used (together with ``Nmax``) to derive ``dlabels``.
        transform: Accepted for interface compatibility but ignored;
            ``self.transform`` is always ``None``.
        target_transform: Stored as ``self.target_transform``.
        pclabels: Optional per-sample labels; defaults to ``-1.0`` everywhere.
        pdlabels: Optional per-sample labels; defaults to ``0.0`` everywhere.
        shuffle_grid: Forwarded to ``load_shuffle``.
    """

    def __init__(self, args, dataset, root_dir, position_group, group_num,
                 transform=None, target_transform=None, pclabels=None,
                 pdlabels=None, shuffle_grid=True):
        super(ActList, self).__init__(args)
        self.domain_num = 0
        self.dataset = dataset
        self.task = 'cross-position'
        # NOTE(review): the `transform` argument has always been discarded here
        # (it was assigned and then overwritten with None). Confirm this is intended.
        self.transform = None
        self.target_transform = target_transform
        # The third array returned by the loader is not used by this dataset.
        x, cy, _, sy = loaddata_from_numpy(self.dataset, self.task, root_dir)
        self.position_group = position_group
        self.split_via_position(x, cy, sy)
        # Insert a singleton axis at position 2 (requires x.ndim >= 3), then
        # convert to a float32 tensor.
        self.x = torch.tensor(self.x[:, :, np.newaxis, :]).float()

        label_shape = self.labels.shape
        # Defaults are float64 arrays shaped like `labels`.
        self.pclabels = pclabels if pclabels is not None else np.full(label_shape, -1.0)
        self.pdlabels = pdlabels if pdlabels is not None else np.zeros(label_shape)
        self.tdlabels = np.ones(label_shape) * group_num
        self.dlabels = np.ones(label_shape) * (group_num - Nmax(args, group_num))

        # Two separate load_shuffle calls on the same input. These are presumably
        # two independently shuffled views if load_shuffle is randomized (confirm).
        self.x1 = load_shuffle(self.x, shuffle_grid, args.grid_size)
        self.x2 = load_shuffle(self.x, shuffle_grid, args.grid_size)

    # Disabled per-sample accessor, kept for reference.
    # def __getitem__(self,index):
    #     x = self.input_trans(self.x[index])
    #     ctarget=self.target_trans(self.labels[index])
    #     dtarget=self.target_trans(self.dlabels[index])
    #     pctarget=self.target_trans(self.pclabels[index])
    #     pdtarget=self.target_trans(self.pdlabels[index])
    #     x1=self.input_trans(self.x1[index])
    #     x2=self.input_trans(self.x2[index])
    #     return x,ctarget,dtarget,pctarget,pdtarget,index,x1,x2

    def split_via_position(self, x, cy, sy):
        """Select the samples whose position id is in ``self.position_group``.

        Sets ``self.x`` (rows of ``x``) and ``self.labels`` (entries of ``cy``).
        Samples are grouped by position, in the order of ``self.position_group``;
        within a position they keep their original order.

        Args:
            x: Per-sample inputs, indexed along axis 0.
            cy: Per-sample labels, aligned with ``x`` along axis 0.
            sy: Per-sample position ids, aligned with ``x`` along axis 0.

        Raises:
            ValueError: If ``self.position_group`` is empty.
        """
        x_parts, y_parts = [], []
        for sen in self.position_group:
            index = np.where(sy == sen)[0]
            x_parts.append(x[index])
            y_parts.append(cy[index])
        if not x_parts:
            raise ValueError('position_group must contain at least one position')
        # Stack once rather than re-stacking the growing arrays on every
        # iteration, which copied all accumulated data each time.
        self.x = np.vstack(x_parts)
        self.labels = np.hstack(y_parts)

    def set_x(self, x):
        """Replace the stored inputs with ``x``."""
        self.x = x
