import numpy as np
import torch


class ScanFamilyDatasetWrapper(torch.utils.data.Dataset):
    """Wrap a dataset of per-sample dicts and emit fixed-size tensors.

    Each item of the wrapped dataset is read as a mapping that provides at
    least ``sentence``, ``obj_fts``, ``obj_locs``, ``obj_boxes``,
    ``obj_labels``, ``tgt_object_label`` and ``tgt_object_id``.  The wrapper
    tokenizes the sentence and pads or truncates the per-object tensors along
    their first dimension to ``max_obj_len``.

    Args:
        dataset: Indexable, sized collection of sample mappings.
        tokenizer: Callable invoked with the sentence plus ``max_length``,
            ``add_special_tokens``, ``truncation``, ``padding`` and
            ``return_tensors`` keyword arguments.  Its result must expose
            ``input_ids`` and ``attention_mask`` tensors.
            NOTE(review): this looks like the HuggingFace tokenizer call
            convention -- confirm against the caller.
        max_seq_length: Value forwarded to the tokenizer as ``max_length``.
        max_obj_len: Number of object slots every padded tensor ends up with.
    """

    def __init__(self, dataset, tokenizer, max_seq_length: int = 80, max_obj_len: int = 80):
        self.dataset = dataset
        self.tokenizer = tokenizer
        self.max_seq_length = max_seq_length
        self.max_obj_len = max_obj_len

    def __len__(self) -> int:
        """Return the number of samples in the wrapped dataset."""
        return len(self.dataset)

    def pad_tensors(self, tensors: torch.Tensor, lens=None, pad=0) -> torch.Tensor:
        """Pad or truncate ``tensors`` along dim 0 to exactly ``lens`` rows.

        Args:
            tensors: Tensor to resize.
            lens: Target size of the first dimension.  ``None`` leaves the
                tensor untouched.
            pad: Fill value for the appended rows; it is stored in the dtype
                of ``tensors``.

        Returns:
            The first ``lens`` rows when the input is longer, the input itself
            when it already has ``lens`` rows (or ``lens`` is ``None``), and
            otherwise the input followed by rows filled with ``pad``.
        """
        if lens is None:
            # The declared default used to crash on ``int > None``; treat it
            # as "no target length requested".
            return tensors

        n_rows = tensors.shape[0]
        if n_rows > lens:
            return tensors[:lens]
        if n_rows == lens:
            return tensors

        fill_shape = list(tensors.shape)
        fill_shape[0] = lens - n_rows
        # new_full allocates the filler with the input's dtype and device, so
        # the concatenation never mixes dtypes or devices.
        filler = tensors.new_full(fill_shape, pad)
        return torch.cat((tensors, filler), dim=0)

    def __getitem__(self, idx):
        """Return sample ``idx`` with tokenized text and padded object tensors.

        Keys written: ``txt_ids``, ``txt_masks``, ``obj_masks``,
        ``obj_sem_masks``, plus resized/cast versions of ``obj_fts``,
        ``obj_locs``, ``obj_boxes``, ``obj_labels``, ``tgt_object_label``,
        ``tgt_object_id`` and, when present, ``tgt_object_id_iou25``,
        ``tgt_object_id_iou50`` and ``answer_label``.  All other keys of the
        wrapped sample are passed through unchanged.
        """
        # Work on a shallow copy: writing into the mapping owned by the
        # wrapped dataset would make a second access to the same index see
        # already-padded tensors (and therefore compute a wrong ``obj_masks``)
        # whenever that dataset hands out the same object again.
        data_dict = dict(self.dataset[idx])

        encoded_input = self.tokenizer(
            data_dict["sentence"],
            max_length=self.max_seq_length,
            add_special_tokens=True,
            truncation=True,
            padding="max_length",
            return_tensors="pt",
        )
        # Drop the leading dimension of size one from the tokenizer output.
        data_dict["txt_ids"] = encoded_input["input_ids"].squeeze(0)
        data_dict["txt_masks"] = encoded_input["attention_mask"].squeeze(0)

        # The validity mask must be derived from the object count *before*
        # padding: True for real object slots, False for padded ones.
        slot_index = torch.arange(self.max_obj_len)
        data_dict["obj_masks"] = slot_index < len(data_dict["obj_locs"])

        for key, pad_value in (("obj_fts", 1.0), ("obj_locs", 0.0), ("obj_boxes", 0.0)):
            data_dict[key] = self.pad_tensors(
                data_dict[key], lens=self.max_obj_len, pad=pad_value
            ).float()
        data_dict["obj_labels"] = self.pad_tensors(
            data_dict["obj_labels"], lens=self.max_obj_len, pad=-100
        ).long()

        # NOTE(review): ``obj_locs`` is already padded to ``max_obj_len`` at
        # this point, so this mask is always all True.  Behaviour is kept
        # as-is; confirm whether an "unmasked" semantic mask is intended.
        data_dict["obj_sem_masks"] = slot_index < len(data_dict["obj_locs"])

        data_dict["tgt_object_label"] = data_dict["tgt_object_label"].long()
        tgt_ids = data_dict["tgt_object_id"].long()
        # Only multi-element targets are padded (with 0); a single-element
        # target keeps its original length.
        if len(tgt_ids) > 1:
            tgt_ids = self.pad_tensors(tgt_ids, lens=self.max_obj_len, pad=0)
        data_dict["tgt_object_id"] = tgt_ids

        # Optional per-IoU-threshold targets: padded only when present and
        # not None.
        for key in ("tgt_object_id_iou25", "tgt_object_id_iou50"):
            if data_dict.get(key) is not None:
                data_dict[key] = self.pad_tensors(
                    data_dict[key], lens=self.max_obj_len, pad=0
                ).long()

        if "answer_label" in data_dict:
            data_dict["answer_label"] = data_dict["answer_label"].long()
        return data_dict
