import torch
from torch.utils.data import Dataset

class LTR_BatchDataset(Dataset):
    """Map-style dataset that wraps an indexable collection of samples.

    ``data`` is stored as-is; ``len()`` and indexing are forwarded to it, so it
    must support both.
    """

    def __init__(self, data):
        # Keep the caller's container by reference (no copy is made).
        self.data = data

    def __len__(self):
        """Return the number of samples in the wrapped container."""
        samples = self.data
        return len(samples)

    def __getitem__(self, idx):
        """Return the sample stored at ``idx`` in the wrapped container."""
        samples = self.data
        return samples[idx]

def padding_tensor(sequences, labels):
    """Right-pad variable-length sequences and their labels to a common length.

    Args:
        sequences: Non-empty list of tensors whose first dimension is the
            sequence length; all trailing dimensions must match.
        labels: List of tensors, one per sequence, each assignable to the
            first ``len(sequence)`` positions of a padded row (e.g. shape
            ``(length,)``).

    Returns:
        ``(out_tensor, out_label, mask)`` where
        ``out_tensor`` has shape ``(num, max_len, *sequences[0].shape[1:])``,
        ``out_label`` has shape ``(num, max_len, *labels[0].shape[1:])`` and
        ``mask`` has shape ``(num, max_len)``, holding 1 at real positions and
        0 at padded ones. Padded positions of every output are 0.
        ``out_tensor`` and ``mask`` take the dtype/device of ``sequences[0]``;
        ``out_label`` takes those of ``labels[0]``.

    Raises:
        ValueError: If ``sequences`` is empty.
    """
    if not sequences:
        raise ValueError("padding_tensor() requires at least one sequence")
    num = len(sequences)
    max_len = max(s.shape[0] for s in sequences)
    first = sequences[0]
    out_tensor = first.new_zeros((num, max_len, *first.shape[1:]))
    # The mask marks (batch, position) pairs only, independent of feature
    # rank. Deriving it by dropping the last output dim was only correct for
    # 2-D sequences and broke for 1-D or >2-D inputs.
    mask = first.new_zeros((num, max_len))
    out_label = labels[0].new_zeros((num, max_len, *labels[0].shape[1:]))
    for i, tensor in enumerate(sequences):
        length = tensor.size(0)
        out_tensor[i, :length] = tensor
        mask[i, :length] = 1
        out_label[i, :length] = labels[i]
    return out_tensor, out_label, mask

def collate_fn(batch):
    """Collate a list of samples into padded features, labels and a mask.

    Each sample is read as ``sample[1][0]`` for its feature tensor and
    ``sample[2]`` for its label tensor. This layout is assumed; confirm it
    against the dataset that produces the samples. Padding is delegated to
    ``padding_tensor``.

    Returns:
        The ``(out_tensor, out_labels, mask)`` triple from ``padding_tensor``.
    """
    features = [sample[1][0] for sample in batch]
    targets = [sample[2] for sample in batch]
    return padding_tensor(features, targets)