# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

from woods.actdata.util import mydataset, loaddata_from_numpy, Nmax
import numpy as np
import torch


class ActList(mydataset):
    def __init__(self, target_domain, domain_num, dataset, root_dir, people_group, group_num, transform=None, target_transform=None, pclabels=None, pdlabels=None, shuffle_grid=True):
        super(ActList, self).__init__(domain_num)
        self.domain_num = 0
        self.dataset = dataset
        self.task = 'cross_people'
        self.transform = transform
        self.target_transform = target_transform
        x, cy, py, sy = loaddata_from_numpy(self.dataset, self.task, root_dir)
        self.people_group = people_group
        self.position = np.sort(np.unique(sy))
        self.comb_position(x, cy, py, sy)
        self.x = self.x[:, :, np.newaxis, :]
        self.transform = None
        self.x = torch.tensor(self.x).float()
        if pclabels is not None:
            self.pclabels = pclabels
        else:
            self.pclabels = np.ones(self.labels.shape)*(-1)
        if pdlabels is not None:
            self.pdlabels = pdlabels
        else:
            self.pdlabels = np.ones(self.labels.shape)*(0)
        self.tdlabels = np.ones(self.labels.shape)*group_num
        self.dlabels = np.ones(self.labels.shape) * \
            (group_num-Nmax(domain_num, target_domain, group_num))

    def comb_position(self, x, cy, py, sy):
        for i, peo in enumerate(self.people_group):
            index = np.where(py == peo)[0]
            tx, tcy, tsy = x[index], cy[index], sy[index]
            for j, sen in enumerate(self.position):
                index = np.where(tsy == sen)[0]
                if j == 0:
                    ttx, ttcy = tx[index], tcy[index]
                else:
                    ttx = np.hstack((ttx, tx[index]))
            if i == 0:
                self.x, self.labels = ttx, ttcy
            else:
                self.x, self.labels = np.vstack(
                    (self.x, ttx)), np.hstack((self.labels, ttcy))

    def set_x(self, x):
        self.x = x
