# Copyright 2022 The OFA-Sys Team. 
# All rights reserved.
# This source code is licensed under the Apache 2.0 license 
# found in the LICENSE file in the root directory.

import logging
import warnings
import torch
import numpy as np

from data import data_utils
from data.ofa_dataset import OFADataset

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Ignore UserWarnings whose message starts with "corrupt EXIF data" or
# "Possibly corrupt EXIF data". NOTE(review): presumably raised by an image
# library elsewhere in the project -- nothing in this module loads images.
warnings.filterwarnings("ignore", "(Possibly )?corrupt EXIF data", UserWarning)


def collate(samples, pad_idx, eos_idx):
    """Assemble per-example dicts into a single mini-batch dict.

    Args:
        samples: sequence of example dicts. Every example must carry a
            "source" tensor. "target", "prev_output_tokens", "ref_dict" and
            "constraint_mask" are optional; their presence is decided by
            looking at the first example only.
        pad_idx: padding index, handed to ``data_utils.collate_tokens`` and
            used to count the non-pad tokens of each example.
        eos_idx: end-of-sentence index, handed to
            ``data_utils.collate_tokens``.

    Returns:
        ``{}`` when ``samples`` is empty, otherwise a dict with the keys
        "nsentences", "ntokens", "net_input" (holding "src_tokens",
        "src_lengths" and "prev_output_tokens"), "ref_dict",
        "constraint_masks" and "target". Optional entries are ``None`` when
        the first example does not provide them.
    """
    if not samples:
        return {}

    head = samples[0]

    def present(field):
        # Optional fields are probed on the first example only.
        return head.get(field, None) is not None

    def batchify(field):
        # Stack one field of every example into a single padded tensor.
        items = [sample[field] for sample in samples]
        return data_utils.collate_tokens(items, pad_idx, eos_idx=eos_idx)

    def non_pad_counts(field):
        # Number of non-pad positions in `field`, one entry per example.
        return torch.LongTensor(
            [sample[field].ne(pad_idx).long().sum() for sample in samples]
        )

    src_tokens = batchify("source")
    src_lengths = non_pad_counts("source")

    ref_dict = (
        np.array([sample['ref_dict'] for sample in samples])
        if present("ref_dict")
        else None
    )
    constraint_masks = (
        batchify("constraint_mask") if present("constraint_mask") else None
    )

    target = None
    prev_output_tokens = None
    if present("target"):
        target = batchify("target")
        # With targets available, the token count refers to target tokens.
        ntokens = non_pad_counts("target").sum().item()
        if present("prev_output_tokens"):
            prev_output_tokens = batchify("prev_output_tokens")
    else:
        # No targets: fall back to counting source tokens.
        ntokens = src_lengths.sum().item()

    net_input = {
        "src_tokens": src_tokens,
        "src_lengths": src_lengths,
        "prev_output_tokens": prev_output_tokens
    }
    return {
        "nsentences": len(samples),
        "ntokens": ntokens,
        "net_input": net_input,
        "ref_dict": ref_dict,
        "constraint_masks": constraint_masks,
        "target": target,
    }


class MNLIDataset(OFADataset):
    """Sentence-pair dataset rendered as a text prompt with a one-token answer.

    Every underlying record is unpacked as ``(sentence1, sentence2, label)``.
    The label strings '0', '1' and '2' are turned into the answer words
    'maybe', 'yes' and 'no' respectively.

    ``encode_text``, ``bos_item``, ``eos_item``, ``dataset``, ``tgt_dict``,
    ``pad`` and ``eos`` are used but not defined in this class -- presumably
    they are provided by ``OFADataset`` (not visible here).
    """

    def __init__(
        self,
        split,
        dataset,
        bpe,
        src_dict,
        tgt_dict=None,
        max_src_length=512,
        max_tgt_length=30,
        constraint_trie=None,
        prompt_type="none"
    ):
        """Store the task options; the first five arguments go to the base class.

        Args:
            split, dataset, bpe, src_dict, tgt_dict: forwarded unchanged to
                ``OFADataset.__init__``.
            max_src_length: maximum number of whitespace-separated words kept
                from each of the two sentences.
            max_tgt_length: stored on the instance; not read in this class.
            constraint_trie: optional object exposing ``get_next_layer``; when
                given, each example also carries a "constraint_mask".
            prompt_type: one of 'none', 'src' or 'prev_output'; selects how the
                decoder input and target are built in ``__getitem__``.
        """
        super().__init__(split, dataset, bpe, src_dict, tgt_dict)
        self.prompt_type = prompt_type
        self.constraint_trie = constraint_trie
        self.max_tgt_length = max_tgt_length
        self.max_src_length = max_src_length

    def __getitem__(self, index):
        """Build the example dict for record ``index``.

        Returns:
            dict with "source", "target", "prev_output_tokens" and "ref_dict",
            plus "constraint_mask" when a constraint trie was supplied.

        Raises:
            NotImplementedError: for a label other than '0'/'1'/'2' or an
                unknown ``prompt_type``.
        """
        sentence1, sentence2, label = self.dataset[index]

        answer_words = {'0': 'maybe', '1': 'yes', '2': 'no'}
        if label not in answer_words:
            raise NotImplementedError
        answer = answer_words[label]

        def normalize(text):
            # Lower-case, collapse whitespace, keep at most max_src_length words.
            words = text.lower().strip().split()
            return ' '.join(words[:self.max_src_length])

        sentence1 = normalize(sentence1)
        sentence2 = normalize(sentence2)
        prompt = ' can text1 " {} " imply text2 " {} "?'.format(sentence1, sentence2)
        src_item = self.encode_text(prompt)
        tgt_item = self.encode_text(" {}".format(answer))
        # The answer word is expected to encode to exactly one token.
        assert tgt_item.size(0) == 1
        ref_dict = {answer: 1.0}

        src_item = torch.cat([self.bos_item, src_item, self.eos_item])

        mode = self.prompt_type
        if mode == 'none':
            # Decoder starts from BOS only; the target is just the answer token.
            prev_output_item, target_item = self.bos_item, tgt_item
        elif mode == 'src' or mode == 'prev_output':
            # Decoder input is a copy of the source ('src') or of the source
            # without its last token ('prev_output'); the target is that copy
            # minus its first token, followed by the answer token.
            prompt_tokens = src_item if mode == 'src' else src_item[:-1]
            prev_output_item = prompt_tokens.clone()
            target_item = torch.cat([prev_output_item[1:], tgt_item])
        else:
            raise NotImplementedError
        # Overwrite everything except the last position with pad, leaving the
        # final (answer) token as the only non-pad target entry.
        target_item[:-1] = self.tgt_dict.pad()

        example = {
            "source": src_item,
            "target": target_item,
            "prev_output_tokens": prev_output_item,
            "ref_dict": ref_dict,
        }

        trie = self.constraint_trie
        if trie is not None:
            num_steps = len(prev_output_item)
            vocab_size = len(self.tgt_dict)
            constraint_mask = torch.zeros((num_steps, vocab_size)).bool()
            allowed = trie.get_next_layer(self.bos_item.tolist())
            # Only the last decoder position is constrained; the row below is a
            # view, so writing to it updates constraint_mask in place.
            final_step = constraint_mask[-1]
            final_step[allowed] = True
            example["constraint_mask"] = constraint_mask
        return example

    def collater(self, samples, pad_to_length=None):
        """Combine a list of examples into one mini-batch.

        Args:
            samples (List[dict]): examples produced by ``__getitem__``
            pad_to_length: accepted but not used by this implementation
        Returns:
            dict: the mini-batch built by the module-level ``collate``
        """
        return collate(samples, pad_idx=self.pad, eos_idx=self.eos)
