#coding=utf-8

from collections import Counter
from datautil.actdata.util import *
from datautil.util import mydataset,Nmax
import numpy as np
# Dataset built from the samples of a selected group of people (people_group).
class ActList(mydataset):
    """Dataset holding the samples of a selected group of people.

    Data is loaded with ``loaddata_from_numpy`` for the ``'cross_people'``
    task as ``(x, cy, py, sy)``. Only samples whose ``py`` value is in
    ``people_group`` are kept. For every person, the samples of each distinct
    ``sy`` value (stored in ``self.position``; presumably sensor positions --
    confirm) are joined with ``np.hstack``. ``cy`` of the first position
    provides ``self.labels``.

    Args:
        args: Experiment config. Forwarded to ``mydataset`` and ``Nmax``, and
            must provide ``grid_size`` (used for ``load_shuffle``).
        dataset: Dataset name passed to ``loaddata_from_numpy``.
        root_dir: Data root directory passed to ``loaddata_from_numpy``.
        people_group: Iterable of person ids to keep (compared against ``py``).
        group_num: Group index. Stored in ``tdlabels`` and combined with
            ``Nmax(args, group_num)`` to build ``dlabels``.
        transform: Accepted but currently overridden with ``None``
            (see the review note in ``__init__``).
        target_transform: Stored as ``self.target_transform``.
        pclabels: Optional per-sample label array; defaults to all ``-1``.
        pdlabels: Optional per-sample label array; defaults to all ``0``.
        shuffle_grid: Forwarded to ``load_shuffle`` when building ``x1``/``x2``.
    """

    def __init__(self, args, dataset, root_dir, people_group, group_num,
                 transform=None, target_transform=None, pclabels=None,
                 pdlabels=None, shuffle_grid=True):
        super(ActList, self).__init__(args)
        self.domain_num = 0
        self.dataset = dataset
        self.task = 'cross_people'
        self.transform = transform
        self.target_transform = target_transform
        x, cy, py, sy = loaddata_from_numpy(self.dataset, self.task, root_dir)
        self.people_group = people_group
        # np.unique returns the distinct values already sorted, which fixes the
        # order in which positions are concatenated.
        self.position = np.unique(sy)
        self.comb_position(x, cy, py, sy)
        # Insert a singleton axis at index 2 (e.g. (N, C, L) -> (N, C, 1, L)
        # if x is 3-D -- TODO confirm the expected input rank).
        self.x = self.x[:, :, np.newaxis, :]
        # NOTE(review): this discards the ``transform`` argument stored above.
        # Kept for backward compatibility; confirm it is intentional.
        self.transform = None
        self.x = torch.tensor(self.x).float()

        label_shape = self.labels.shape
        if pclabels is not None:
            self.pclabels = pclabels
        else:
            self.pclabels = np.full(label_shape, -1.0, dtype=np.float64)
        if pdlabels is not None:
            self.pdlabels = pdlabels
        else:
            self.pdlabels = np.full(label_shape, 0.0, dtype=np.float64)
        self.tdlabels = np.full(label_shape, group_num, dtype=np.float64)
        self.dlabels = np.full(label_shape, group_num - Nmax(args, group_num),
                               dtype=np.float64)
        # Two independently shuffled views of the same samples.
        self.x1 = load_shuffle(self.x, shuffle_grid, args.grid_size)
        self.x2 = load_shuffle(self.x, shuffle_grid, args.grid_size)

    def comb_position(self, x, cy, py, sy):
        """Build ``self.x`` and ``self.labels`` from the people in ``people_group``.

        Persons are processed in ``people_group`` order. For each person, the
        samples of every value in ``self.position`` are joined with
        ``np.hstack``. The per-person blocks are then stacked with
        ``np.vstack`` (samples) and ``np.hstack`` (labels).

        Labels are taken from the first position only. Samples of all
        positions are therefore assumed to be aligned row-for-row per person.
        NOTE(review): confirm the loader guarantees this alignment.

        Raises:
            ValueError: If ``people_group`` yields no person ids.
        """
        x_blocks = []
        label_blocks = []
        for person in self.people_group:
            person_idx = np.where(py == person)[0]
            px, pcy, psy = x[person_idx], cy[person_idx], sy[person_idx]
            position_chunks = []
            person_labels = None
            for pos in self.position:
                pos_idx = np.where(psy == pos)[0]
                if person_labels is None:
                    person_labels = pcy[pos_idx]
                position_chunks.append(px[pos_idx])
            x_blocks.append(np.hstack(position_chunks))
            label_blocks.append(person_labels)
        if not x_blocks:
            raise ValueError("people_group is empty: no samples to load")
        # Stack once at the end. Growing the arrays inside the loop re-copied
        # everything accumulated so far on each iteration (quadratic).
        self.x = np.vstack(x_blocks)
        self.labels = np.hstack(label_blocks)

    def set_x(self, x):
        """Replace ``self.x`` with ``x`` as-is (no copy, conversion, or checks)."""
        self.x = x
