import torch
from torch.utils.data import DataLoader, Dataset
from glob import glob
import os
import numpy as np

class BuildModel(Dataset):
    """Dataset over per-file mu / activation / (optional) risk-label tensors.

    Files are matched across directories purely by sorted filename order:
    the k-th ``*_mu.pth`` is paired with the k-th ``*_activate.pth`` (and the
    k-th ``*_risk_label.pth`` when labels are used). Each group of tensors is
    stacked along a NEW LAST dimension, so that last dimension indexes the
    file (called ``n_id`` below). Dataset items are taken along dim 0 of the
    stacked activation (and risk-label) tensors.
    """

    def __init__(self, mu_root, activate_root, risk_label_root):
        """Load and stack every matching tensor file under the given roots.

        Args:
            mu_root: directory containing ``*_mu.pth`` files.
            activate_root: directory containing ``*_activate.pth`` files.
            risk_label_root: directory containing ``*_risk_label.pth`` files,
                or ``''`` / ``None`` to build an unlabeled dataset.

        Raises:
            ValueError: if the directories contain different numbers of
                matching files. Pairing is by sorted name, so a count mismatch
                would otherwise silently misalign the tensors.
            RuntimeError: if no matching files are found.
        """
        super().__init__()

        mu_paths = sorted(glob(os.path.join(mu_root, '*_mu.pth')))
        activate_paths = sorted(glob(os.path.join(activate_root, '*_activate.pth')))
        groups = {'mu': mu_paths, 'activate': activate_paths}

        # Any falsy root ('' or None) means "no labels".
        if risk_label_root:
            groups['risk_label'] = sorted(
                glob(os.path.join(risk_label_root, '*_risk_label.pth')))

        counts = {name: len(paths) for name, paths in groups.items()}
        if len(set(counts.values())) != 1:
            raise ValueError(f'mismatched number of tensor files per directory: {counts}')
        if not mu_paths:
            raise RuntimeError(
                f'no *_mu.pth / *_activate.pth files found under {mu_root!r} / {activate_root!r}')

        self.mu = self._load_stacked(mu_paths)
        self.activate_vectors = self._load_stacked(activate_paths)
        self.risk_label = (self._load_stacked(groups['risk_label'])
                           if 'risk_label' in groups else None)

    @staticmethod
    def _load_stacked(paths):
        """Load each tensor file in ``paths`` and stack them along a new last dim."""
        # NOTE: torch.load unpickles its input; only point this at trusted files.
        return torch.stack([torch.load(p) for p in paths], dim=-1)

    def __getitem__(self, index):
        """Return the sample at ``index``.

        Returns:
            ``(mu, activate_vectors[index])``, or
            ``(mu, activate_vectors[index], risk_label[index])`` when risk
            labels were loaded. ``mu`` is the whole stacked mean tensor and is
            returned unchanged for every index, so a DataLoader batch of size
            B contains B copies of it. The batched shapes are presumably
            mu: (B, n, n_id) and activate vectors: (B, n, n_id), where n_id is
            the number of files. This assumes per-file mu tensors of shape (n,)
            and activation tensors of shape (N, n); TODO confirm.
        """
        activate_vectors = self.activate_vectors[index]
        if self.risk_label is None:
            return self.mu, activate_vectors
        return self.mu, activate_vectors, self.risk_label[index]

    def __len__(self):
        """Number of samples: size of dim 0 of the stacked activation tensor."""
        return len(self.activate_vectors)

    def get_dim(self):
        """Return ``(activate_vectors.shape[1], activate_vectors.shape[-1])``.

        The last entry is the number of stacked files (n_id).
        """
        return self.activate_vectors.shape[1], self.activate_vectors.shape[-1]

def building_up(mu_root, activate_root, risk_label_root, batch_size, *,
                shuffle=False, **loader_kwargs):
    """Build a ``BuildModel`` dataset and wrap it in a ``DataLoader``.

    Args:
        mu_root: directory with ``*_mu.pth`` files.
        activate_root: directory with ``*_activate.pth`` files.
        risk_label_root: directory with ``*_risk_label.pth`` files, or ``''``
            for an unlabeled dataset.
        batch_size: DataLoader batch size.
        shuffle: whether the DataLoader shuffles samples (default ``False``,
            matching the previous hard-coded behavior).
        **loader_kwargs: extra keyword arguments forwarded to ``DataLoader``
            (e.g. ``num_workers``, ``drop_last``).

    Returns:
        ``(loader, dims)`` where ``dims`` is ``dataset.get_dim()``.
    """
    dataset = BuildModel(mu_root, activate_root, risk_label_root)
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, **loader_kwargs)
    return loader, dataset.get_dim()


if __name__ == '__main__':
    # Smoke test: load the CIFAR-10 tensors and print the shape of every batch.
    # building_up returns (loader, dims); iterate the loader, not the tuple.
    # NOTE(review): the mu path is absolute while the other two are relative;
    # confirm this is intended.
    loader, dims = building_up('/mu/mu_cifar10', 'activate_vector/cifar10_train',
                               'risk_label/cifar10', 32)
    print('get_dim():', dims)
    for batch in loader:
        # default_collate turns each (mu, activate, label) sample tuple into a
        # list of batched tensors.
        print([t.shape for t in batch])