# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging

import numpy as np
import torch

from fairseq.data import Dictionary, FairseqDataset
from fairseq.tasks import LegacyFairseqTask, register_task

logger = logging.getLogger(__name__)


@register_task("dummy_mt")
class DummyMTTask(LegacyFairseqTask):
    """Translation task that serves one fixed, synthetic batch.

    Nothing is read from disk: the dictionary is filled with placeholder
    symbols and every split is backed by a ``DummyDataset`` whose collater
    always returns the same pre-built batch.
    """

    @staticmethod
    def add_args(parser):
        """Add task-specific arguments to the parser."""
        # Number of placeholder symbols added to the dictionary.
        parser.add_argument("--dict-size", default=49996, type=int)
        # Number of items the dummy dataset reports.
        parser.add_argument("--dataset-size", default=100000, type=int)
        # The dummy sentences are built with ``len + 1`` tokens each.
        parser.add_argument("--src-len", default=30, type=int)
        parser.add_argument("--tgt-len", default=30, type=int)

    def __init__(self, args, dictionary):
        """Build the constant source/target token sequences.

        Args:
            args: parsed arguments; ``seed``, ``src_len`` and ``tgt_len``
                are read here.
            dictionary: shared source/target dictionary. It is padded in
                place to a multiple of 8 entries.

        Raises:
            ValueError: if the dummy token ids would not fit inside the
                dictionary.
        """
        super().__init__(args)
        self.dictionary = dictionary
        self.seed = args.seed

        dictionary.pad_to_multiple_(8)  # often faster if divisible by 8

        # Dummy token ids are consecutive integers starting right after the
        # pad index, so no dummy token is ever the pad symbol.
        first_token = dictionary.pad() + 1
        largest_token = first_token + max(args.src_len, args.tgt_len)
        if largest_token >= len(dictionary):
            # Fail early with a readable message instead of handing
            # out-of-vocabulary ids to the model.
            raise ValueError(
                "dummy token ids reach {} but the dictionary only has {} types; "
                "increase --dict-size or reduce --src-len/--tgt-len".format(
                    largest_token, len(dictionary)
                )
            )

        self.dummy_src = torch.arange(args.src_len + 1) + first_token
        self.dummy_tgt = torch.arange(args.tgt_len + 1) + first_token

    @classmethod
    def setup_task(cls, args, **kwargs):
        """Create the placeholder dictionary and instantiate the task.

        Also overwrites ``args.max_source_positions`` and
        ``args.max_target_positions`` based on the requested lengths.
        """
        dictionary = Dictionary()
        for i in range(args.dict_size):
            dictionary.add_symbol("word{}".format(i))
        logger.info("dictionary: %d types", len(dictionary))

        args.max_source_positions = args.src_len + dictionary.pad() + 2
        args.max_target_positions = args.tgt_len + dictionary.pad() + 2

        return cls(args, dictionary)

    def load_dataset(self, split, epoch=1, combine=False, **kwargs):
        """Register a constant-batch dataset under ``split``.

        The batch size is ``args.batch_size`` when given, otherwise it is
        derived from ``args.max_tokens`` (at least 1). ``epoch``,
        ``combine`` and extra keyword arguments are accepted but unused.

        Args:
            split (str): name of the split (e.g., train, valid, test)
        """
        item_size = max(self.args.src_len, self.args.tgt_len)
        if self.args.batch_size is not None:
            bsz = self.args.batch_size
        else:
            bsz = max(1, self.args.max_tokens // item_size)

        src = torch.stack([self.dummy_src] * bsz)
        tgt = torch.stack([self.dummy_tgt] * bsz)

        self.datasets[split] = DummyDataset(
            {
                "id": 1,
                "net_input": {
                    "src_tokens": src,
                    # Every row is fully populated with non-pad tokens, so the
                    # reported length is the real row width.
                    "src_lengths": torch.full(
                        (bsz,), src.size(1), dtype=torch.long
                    ),
                    "prev_output_tokens": tgt.clone(),
                },
                "target": tgt,
                "nsentences": bsz,
                # Count the target tokens that are actually in the batch.
                "ntokens": tgt.numel(),
            },
            num_items=self.args.dataset_size,
            item_size=item_size,
        )

    @property
    def source_dictionary(self):
        """Dictionary used for the source side (shared with the target)."""
        return self.dictionary

    @property
    def target_dictionary(self):
        """Dictionary used for the target side (shared with the source)."""
        return self.dictionary


class DummyDataset(FairseqDataset):
    """Dataset that reports ``num_items`` items but always collates to one batch.

    Args:
        batch: the pre-built batch returned by every ``collater`` call.
        num_items: number of items the dataset claims to contain.
        item_size: size reported for every item.
    """

    def __init__(self, batch, num_items, item_size):
        super().__init__()
        self.batch = batch
        self.num_items = num_items
        self.item_size = item_size

    def __getitem__(self, index):
        """Return ``index`` itself; items carry no data of their own."""
        return index

    def __len__(self):
        """Return the number of items the dataset claims to contain."""
        return self.num_items

    def collater(self, samples):
        """Ignore ``samples`` and return the stored batch."""
        return self.batch

    @property
    def sizes(self):
        """Array of length ``num_items`` filled with ``item_size``."""
        # np.full builds the array directly, without a temporary Python
        # list, and keeps the fill value's dtype even when num_items is 0.
        return np.full(self.num_items, self.item_size)

    def num_tokens(self, index):
        """Return ``item_size`` regardless of ``index``."""
        return self.item_size

    def size(self, index):
        """Return ``item_size`` regardless of ``index``."""
        return self.item_size

    def ordered_indices(self):
        """Return the indices ``0 .. num_items - 1`` in ascending order."""
        return np.arange(self.num_items)

    @property
    def supports_prefetch(self):
        """Prefetching is not supported."""
        return False
