"""
    Affordance dataset loader.
"""
import os, json
import numpy as np

import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F


class AffordanceDataset(Dataset):
    """Affordance dataset in voxel or point cloud form.

    Each file in the data directory is an ``.npz`` archive whose ``arr_0``
    array holds affordance label ids (0 = null affordance). ``__getitem__``
    returns that array as the geometry tensor, together with a padded one-hot
    encoding of the distinct non-null labels it contains.
    """

    def __init__(self, workspace_dir, data_type, num_slots, rm=False, cross=""):
        """Locate the data directory and store the dataset settings.

        Args:
            workspace_dir: Workspace directory; it must contain a ``dataset``
                sub-directory.
            data_type: ``'voxel'`` (reads ``dataset/pc_vox<cross>``) or
                ``'pointcloud'`` (reads ``dataset/pc``).
            num_slots: Number of rows in the padded one-hot affordance
                matrix, i.e. the maximum number of affordances expected on a
                single object.
            rm: If True, for about two thirds of the samples (selected
                deterministically from the annotation id) one affordance is
                removed from both the labels and the geometry.
            cross: Suffix appended to the voxel directory name; ignored for
                point clouds.

        Raises:
            ValueError: If ``data_type`` is not ``'voxel'`` or ``'pointcloud'``.
        """
        dataset_dir = os.path.join(workspace_dir, "dataset")

        # Directory holding the per-object geometry/label archives.
        self.data_type = data_type
        if data_type == 'voxel':
            self.data_dir = os.path.join(dataset_dir, 'pc_vox%s' % cross)
        elif data_type == 'pointcloud':
            self.data_dir = os.path.join(dataset_dir, 'pc')
        else:
            raise ValueError(
                "data_type must be 'voxel' or 'pointcloud', got %r" % (data_type,)
            )
        # Sorted so that index -> file mapping is reproducible; os.listdir
        # returns entries in arbitrary, filesystem-dependent order.
        self.filenames = sorted(os.listdir(self.data_dir))

        # 25 classes: null affordance (index 0) + 24 meaningful affordances.
        self.num_affordance = 25
        # No more than "num_slots" kinds of affordance on a single object.
        self.max_num_affordance = num_slots
        self.rm = rm

    def __len__(self):
        """Return the number of archives found in the data directory."""
        return len(self.filenames)

    def __getitem__(self, idx):
        """Load one sample.

        Returns:
            dict with
              ``"geo"``: float32 tensor of the stored ``arr_0`` label array
                  (shape as stored on disk); entries of the removed
                  affordance are zeroed when removal is active.
              ``"anno_id"``: the file name up to its first ``'.'``.
              ``"affordance"``: ``[num_slots, 25]`` int64 tensor with one
                  one-hot row per kept affordance in ascending label order;
                  remaining rows are set to the null affordance (index 0).
        """
        filename = self.filenames[idx]
        anno_id = filename.split('.')[0]
        npz_path = os.path.join(self.data_dir, filename)

        # Read the archive once and close it: np.load on an .npz returns an
        # NpzFile that keeps the underlying file open until closed.
        with np.load(npz_path) as archive:
            labels = archive['arr_0']
        afford_data = torch.tensor(labels, dtype=torch.float32)

        # Distinct non-null affordance ids (np.unique returns them sorted).
        unique_labels = np.unique(labels)
        affordance_list = unique_labels[unique_labels != 0].tolist()

        rm_afford = None
        if self.rm and affordance_list:
            # Local RandomState objects reproduce the exact draws of the
            # legacy np.random.seed(...) + np.random.randint(...) sequence
            # without clobbering the global NumPy RNG used elsewhere.
            keep_all = np.random.RandomState(int(anno_id) + 2333).randint(3) == 0
            if not keep_all:
                rm_idx = np.random.RandomState(int(anno_id)).randint(len(affordance_list))
                rm_afford = affordance_list.pop(rm_idx)
        affordances = torch.tensor(affordance_list, dtype=torch.long)

        one_hots = F.one_hot(affordances, num_classes=self.num_affordance)
        # Pad with zero rows up to [num_slots, 25].
        # NOTE(review): if an object has more than num_slots affordances the
        # pad amount is negative and F.pad crops the extra rows silently.
        padded_one_hots = F.pad(
            one_hots, (0, 0, 0, self.max_num_affordance - len(affordances))
        )
        # Mark the padding rows as the null affordance.
        padded_one_hots[len(affordances):, 0] = 1

        if rm_afford is not None:
            # Erase the removed affordance from the geometry as well.
            afford_data[afford_data == rm_afford] = 0
        return {
            "geo": afford_data,
            "anno_id": anno_id,
            "affordance": padded_one_hots,
        }


