# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.


from collections import OrderedDict, defaultdict
import json
import os
import logging

from fairseq import options, models
from fairseq.data import (
    data_utils,
    Dictionary,
    LanguagePairDataset,
    IndexedDataset,
    FairseqDataset,
)
from .multitask_data_utils import (
    MultitaskDatasetWrapper,
    MultidatasetEpochBatchIterator,
)


from fairseq.tasks import LegacyFairseqTask, register_task

logger = logging.getLogger(__name__)


@register_task("laser")
class LaserTask(LegacyFairseqTask):
    @staticmethod
    def add_args(parser):
        """Register the LASER task's command-line options on ``parser``.

        Adds the positional ``configfile`` argument followed by the optional
        weighting, raw-text, padding-side and maximum-length options.
        """
        parser.add_argument(
            "configfile", metavar="PATH", help="dataset configuration file in json"
        )
        parser.add_argument(
            "--weighting-alpha",
            type=float,
            default=None,
            help="alpha for automatic weighting",
        )
        parser.add_argument(
            "--raw-text", action="store_true", help="load raw text dataset"
        )

        # Padding-side flags are kept as strings here; setup_task converts
        # them to booleans.
        for side, default in (("source", "True"), ("target", "False")):
            parser.add_argument(
                f"--left-pad-{side}",
                default=default,
                type=str,
                metavar="BOOL",
                help=f"pad the {side} on the left (default: {default})",
            )

        # Maximum sequence lengths, one option per side.
        for side in ("source", "target"):
            parser.add_argument(
                f"--max-{side}-positions",
                default=1024,
                type=int,
                metavar="N",
                help=f"max number of tokens in the {side} sequence",
            )

    def __init__(self, args, config, src_dictionary, tgt_dictionary, num_tasks):
        """Store the parsed configuration, both dictionaries and the task count.

        Args:
            args: parsed command-line arguments, forwarded to the base class.
            config: dataset configuration loaded from the JSON config file.
            src_dictionary: dictionary used for the source side.
            tgt_dictionary: dictionary used for the target side.
            num_tasks: number of tasks derived from the configuration.
        """
        super().__init__(args)
        self.src_dictionary, self.tgt_dictionary = src_dictionary, tgt_dictionary
        self.config, self.num_tasks = config, num_tasks

    @classmethod
    def setup_task(cls, args, **kwargs):
        """Create a task instance from the JSON file named by ``args.configfile``.

        Reads the configuration, derives the number of tasks from the largest
        ``id`` among the ``train`` entries, converts the padding-side flags to
        booleans in place on ``args`` and loads both vocabularies.

        Returns:
            A new instance of ``cls``.
        """
        with open(args.configfile, "r") as config_file:
            config = json.load(config_file)

        # The count is the largest id plus one (presumably ids are zero-based).
        highest_id = max(entry["id"] for entry in config["train"])
        num_tasks = highest_id + 1

        # The CLI delivers these flags as strings; turn them into booleans.
        args.left_pad_source = options.eval_bool(args.left_pad_source)
        args.left_pad_target = options.eval_bool(args.left_pad_target)

        src_vocab_path = config["src_vocab"]
        src_dictionary = Dictionary.load(src_vocab_path)
        tgt_vocab_path = config["tgt_vocab"]
        tgt_dictionary = Dictionary.load(tgt_vocab_path)

        logger.info(
            f"| src Dictionary {src_vocab_path} : {len(src_dictionary)} types"
        )
        logger.info(
            f"| tgt Dictionary {tgt_vocab_path} : {len(tgt_dictionary)} types"
        )

        return cls(args, config, src_dictionary, tgt_dictionary, num_tasks)

    # Experimental overriding for backtranslation
    def build_model(self, args):
        """Build and return the model by delegating to ``models.build_model``."""
        return models.build_model(args, self)

    def dataset(self, split):
        """Return whatever was stored for ``split`` by ``load_dataset``.

        Raises:
            KeyError: if ``split`` has not been loaded.
        """
        if split in self.datasets:
            return self.datasets[split]
        raise KeyError("Dataset not loaded: " + split)

    def load_dataset(self, split, epoch=1, **kwargs):
        """Load a dataset split.

        Builds one ``LanguagePairDataset`` per entry of ``self.config[split]``,
        wraps each in a ``MultitaskDatasetWrapper`` and stores the resulting
        ``OrderedDict`` in ``self.datasets[split]``.

        Each configuration entry is a dict read with these keys:

        * ``src`` (required): path of the indexed source dataset. The last
          two directory components of this path form the dataset key
          ``<corpus>-<language pair>``.
        * ``tgt`` (optional): path of the indexed target dataset.
        * ``sample`` (optional): explicit sampling factor. It is rounded to an
          integer and the dataset is registered that many times (extra copies
          get ``-up`` suffixes); a value that rounds to 0 registers nothing.
        * ``id`` (optional): task id handed to the wrapper; defaults to 0.

        When ``--weighting-alpha`` is set, datasets without an explicit
        ``sample`` get a factor derived from their relative size raised to
        the power alpha, normalised so the largest weight maps to 1.

        The ``valid`` split is always registered as an empty ``OrderedDict``.
        ``epoch`` and ``kwargs`` are accepted but not used here.

        Raises:
            FileNotFoundError: if ``split`` is not present in the config.
            KeyError: if a configuration entry has no ``src`` path.
            Exception: if ``--raw-text`` was requested (not supported).
        """

        def indexed_dataset(path, dictionary):
            # ``dictionary`` is unused; it is kept to leave the helper's
            # signature unchanged.
            if self.args.raw_text:
                raise Exception("Unable to handle raw text.")
            dataset = IndexedDataset(path, fix_lua_indexing=True)

            return dataset

        pair_datasets = OrderedDict()

        if split == "valid":
            self.datasets[split] = pair_datasets
            return

        if split not in self.config:
            raise FileNotFoundError(
                "Dataset not found in config file: {}".format(split)
            )

        size_by_corpus = defaultdict(int)
        size_sum = 0
        size_sum_with_subsampling = 0
        init_pair_datasets = {}

        for dataset_config in self.config[split]:
            # The source path is mandatory: it names the dataset and is the
            # source side of the pair.
            if "src" not in dataset_config:
                raise KeyError(
                    f"Dataset entry without 'src' in split {split}: {dataset_config}"
                )

            src_path = os.path.dirname(dataset_config["src"])
            corpus_name = src_path.split("/")[-2]
            language_pair_name = src_path.split("/")[-1]
            pair_datasets_key = corpus_name + "-" + language_pair_name

            # Reject duplicates before paying the cost of loading them.
            if pair_datasets_key in init_pair_datasets:
                logger.warning(
                    f"Ignoring already added {pair_datasets_key}. "
                    f"Consider using `sample` key in order to upsample."
                )
                continue

            logger.info(f"loading... {pair_datasets_key}")
            src_dataset = indexed_dataset(dataset_config["src"], self.src_dictionary)

            tgt_dataset = None
            if "tgt" in dataset_config:
                tgt_dataset = indexed_dataset(
                    dataset_config["tgt"], self.tgt_dictionary
                )

            dataset = LanguagePairDataset(
                src_dataset,
                src_dataset.sizes,
                self.src_dictionary,
                tgt_dataset,
                # A source-only entry has no target sizes to pass along.
                tgt_dataset.sizes if tgt_dataset is not None else None,
                self.tgt_dictionary,
                left_pad_source=self.args.left_pad_source,
                left_pad_target=self.args.left_pad_target,
            )

            init_pair_datasets[pair_datasets_key] = {
                "dataset": dataset,
                "sample": dataset_config.get("sample", None),
                "id": dataset_config.get("id", None),
                "len": len(dataset),
                # Remembered per entry so the statistics below are attributed
                # to the right corpus.
                "corpus": corpus_name,
            }

        length_sum = 0
        weighted_freqs_sum = 0
        freq_per_dataset = {}
        vmax = 0
        vmin = 1
        weighted_freq_per_dataset = {}

        if self.args.weighting_alpha:
            # Only datasets without an explicit sampling factor take part in
            # the automatic weighting.
            for key in init_pair_datasets:
                if init_pair_datasets[key]["sample"] is None:
                    length_sum += init_pair_datasets[key]["len"]

            for key in init_pair_datasets:
                if init_pair_datasets[key]["sample"] is None:
                    val = float(init_pair_datasets[key]["len"]) / length_sum
                    freq_per_dataset[key] = val
                    weighted_freqs_sum += val ** self.args.weighting_alpha

            for key in freq_per_dataset:
                val = (
                    freq_per_dataset[key] ** self.args.weighting_alpha
                    / weighted_freqs_sum
                )
                vmin = min(vmin, val)
                vmax = max(vmax, val)
                weighted_freq_per_dataset[key] = val

        for pair_datasets_key, entry in init_pair_datasets.items():
            dataset = entry["dataset"]
            sample = entry["sample"]
            if sample is None:
                sample = 1.0

            if pair_datasets_key in weighted_freq_per_dataset:
                # The dataset with the largest weight gets factor 1; smaller
                # ones are upsampled relative to it.
                sample = vmax / weighted_freq_per_dataset[pair_datasets_key]

            sample = round(sample)

            initial_sample = sample
            initial_pair_datasets_key = pair_datasets_key

            # The entry always carries an "id" key (possibly None), so the
            # default has to be applied explicitly.
            task_id = entry["id"] if entry["id"] is not None else 0

            # Raw size: each configured dataset is counted exactly once.
            size_sum += len(dataset)

            # Upsampling is done by registering the same dataset several
            # times under distinct keys.
            while sample >= 1.0:
                assert (
                    pair_datasets_key not in pair_datasets
                ), f"{pair_datasets_key} already in"
                size_sum_with_subsampling += len(dataset)
                pair_datasets[pair_datasets_key] = MultitaskDatasetWrapper(
                    dataset, task_id, 1.0, name=pair_datasets_key
                )
                sample -= 1.0
                pair_datasets_key += "-up"

            assert sample < 1e-6, f"sample remains > 0 {pair_datasets_key}"

            logger.info(
                f"added pair {initial_pair_datasets_key} length {len(dataset)} new_length = {len(dataset)*initial_sample}"
            )
            size_by_corpus[entry["corpus"]] += len(dataset)

        self.datasets[split] = pair_datasets
        logger.info(f"Sizes by corpus = {dict(size_by_corpus)}")
        logger.info(
            f"Datasets number = {len(self.datasets[split])} size = {size_sum} size_sum_with_subsampling = {size_sum_with_subsampling}"
        )

    @property
    def source_dictionary(self):
        """Return the source-side dictionary stored on this task."""
        return self.src_dictionary

    @property
    def target_dictionary(self):
        """Return the target-side dictionary stored on this task."""
        return self.tgt_dictionary

    def get_batch_iterator(
        self,
        dataset,
        max_tokens=None,
        max_sentences=None,
        max_positions=None,
        ignore_invalid_inputs=False,
        required_batch_size_multiple=1,
        seed=1,
        num_shards=1,
        shard_id=0,
        num_workers=0,
        epoch=1,
        data_buffer_size=0,
        disable_iterator_cache=False,
    ):
        """Build a ``MultidatasetEpochBatchIterator`` over several datasets.

        For every dataset in ``dataset`` the indices are ordered (under a
        numpy seed of ``seed + epoch``), optionally filtered by
        ``max_positions`` and grouped into batches with
        ``data_utils.batch_by_size``.

        Args:
            dataset: non-empty ``OrderedDict`` mapping a name to a
                ``FairseqDataset``.
            max_tokens: maximum number of tokens per batch.
            max_sentences: maximum number of sentences per batch.
            max_positions: size limit used to filter examples; no filtering
                is done when it is ``None``.
            ignore_invalid_inputs: accepted for interface compatibility but
                not consulted; oversized examples are always dropped, with a
                warning.
            required_batch_size_multiple: forwarded to ``batch_by_size``.
            seed: base random seed.
            num_shards: forwarded to the iterator.
            shard_id: forwarded to the iterator.
            num_workers: forwarded to the iterator.
            epoch: epoch set on every dataset and passed to the iterator.
            data_buffer_size: accepted but not used here.
            disable_iterator_cache: accepted but not used here.

        Returns:
            The ``MultidatasetEpochBatchIterator`` built from the per-dataset
            batch samplers.
        """

        assert isinstance(dataset, OrderedDict)
        assert len(dataset)
        assert isinstance(dataset[next(iter(dataset))], FairseqDataset)

        # initialize the dataset with the correct starting epoch
        for dt in dataset.values():
            dt.set_epoch(epoch)

        indices = OrderedDict()
        batch_sampler = OrderedDict()

        with data_utils.numpy_seed(seed + epoch):
            for key, dt in dataset.items():
                logger.info(f"\t ordered_indices {key}")
                indices[key] = dt.ordered_indices()

        # filter examples that are too large
        if max_positions is not None:
            for key, dt in dataset.items():
                logger.info(f"\t filter_by_size {key}")
                indices[key], ignored = dt.filter_indices_by_size(
                    indices[key], max_positions
                )
                # Report dropped examples instead of discarding them silently.
                # This does not raise, so existing callers keep working.
                if len(ignored) > 0:
                    logger.warning(
                        f"{len(ignored):,} samples of {key} have invalid sizes "
                        f"and will be skipped, max_positions={max_positions}, "
                        f"first few sample ids={ignored[:10]}"
                    )

        for key, dt in dataset.items():
            logger.info(f"\t batch_by_size {key}")
            batch_sampler[key] = data_utils.batch_by_size(
                indices[key],
                dt.num_tokens,
                max_tokens=max_tokens,
                max_sentences=max_sentences,
                required_batch_size_multiple=required_batch_size_multiple,
            )

        epoch_iter = MultidatasetEpochBatchIterator(
            dataset=dataset,
            batch_sampler=batch_sampler,
            seed=seed,
            num_shards=num_shards,
            shard_id=shard_id,
            num_workers=num_workers,
            epoch=epoch,
        )

        return epoch_iter
