from torch.utils.data import Dataset
import torch
import numpy as np

class MultiLabelDataset(Dataset):
    """Map-style dataset of tokenized documents paired with multi-label targets.

    Each item is a 4-tuple ``(input_ids, input_mask, segment_ids, label_ids)``.
    The first three are ``torch.long`` tensors built from the feature mapping's
    ``'input_ids'``, ``'attention_mask'`` and ``'token_type_ids'`` entries.
    ``label_ids`` is encoded in one of two ways:

    * ``aplc=False``: a multi-hot ``torch.float64`` vector of length
      ``num_labels`` with ones at the sample's label indices.
    * ``aplc=True``: a ``torch.long`` vector of exactly ``pos_label`` label
      indices, padded (with the sample's first label, or ``-1`` if it has no
      labels) or truncated to that length.

    Arguments:
        features: indexable sequence of mappings that provide the keys
            ``'input_ids'``, ``'attention_mask'`` and ``'token_type_ids'``.
        label_ids: indexable sequence, aligned with ``features``, of
            per-sample sequences of integer label indices.
        num_labels: total number of labels (length of the multi-hot vector).
        pos_label: fixed length of the label-index vector in APLC mode.
        aplc: if true, emit padded label indices instead of multi-hot vectors.
    """

    def __init__(self, features, label_ids, num_labels, pos_label, aplc=False):
        self.features = features
        self.label_ids = label_ids
        self.num_labels = num_labels
        self.pos_label = pos_label
        # Length is captured once here; NOTE(review): assumes `features` is
        # not resized after construction.
        self.n = len(features)
        self.aplc = aplc

    def __getitem__(self, index):
        f = self.features[index]
        input_ids = torch.as_tensor(f['input_ids'], dtype=torch.long)
        input_mask = torch.as_tensor(f['attention_mask'], dtype=torch.long)
        segment_ids = torch.as_tensor(f['token_type_ids'], dtype=torch.long)
        doc_labels = self.label_ids[index]
        if self.aplc:
            label_ids = torch.as_tensor(self._get_pad_label(doc_labels), dtype=torch.long)
        else:
            label_ids = self._get_multi_hot_label(doc_labels)
        return input_ids, input_mask, segment_ids, label_ids

    def __len__(self):
        return self.n

    def _get_multi_hot_label(self, doc_labels, top=None):
        """Return a float64 multi-hot tensor of length ``top`` (default ``num_labels``).

        Positions listed in ``doc_labels`` are set to 1; all others are 0.
        """
        top = self.num_labels if top is None else top
        multi_hot = np.zeros(top)
        # An explicit int64 array makes a tuple index a list of positions
        # (rather than a multi-dimensional index) and keeps empty lists valid.
        multi_hot[np.asarray(doc_labels, dtype=np.int64)] = 1
        # NOTE(review): dtype is float64 (np.zeros default); losses computed
        # against float32 logits may need `.float()` on the caller side.
        return torch.from_numpy(multi_hot)

    def _get_pad_label(self, doc_labels):
        """Return exactly ``pos_label`` label indices as a list of ints.

        Shorter label lists are padded with their first label (or ``-1`` when
        empty). Longer ones are truncated, so every sample has the same length
        and samples can be stacked into one batch tensor.
        """
        # Copy into a plain list: with a numpy array, `+` would add
        # element-wise instead of concatenating.
        labels = [int(label) for label in doc_labels][:self.pos_label]
        pad_value = labels[0] if labels else -1
        return labels + [pad_value] * (self.pos_label - len(labels))