#!/usr/bin/env python3 -u
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import copy
import logging
import os
from typing import Any, Dict, Iterator, List
from fairseq.distributed.utils import get_data_parallel_rank, get_data_parallel_world_size

import torch
from fairseq import utils
from fairseq.data import encoders
from omegaconf import open_dict
from torch import nn


logger = logging.getLogger(__name__)


def from_pretrained(
    model_name_or_path,
    checkpoint_file="model.pt",
    data_name_or_path=".",
    archive_map=None,
    **kwargs,
):
    """Load a model ensemble and its task from a model archive.

    Args:
        model_name_or_path: key into ``archive_map`` or a location accepted
            by ``file_utils.load_archive_file``. An ``archive_map`` entry may
            itself be a dict holding a ``"path"`` plus default overrides.
        checkpoint_file: checkpoint file name(s) relative to the model
            directory; several names may be joined with ``os.pathsep``.
        data_name_or_path: data location. A value starting with ``"."`` is
            resolved relative to the model directory, anything else goes
            through ``file_utils.load_archive_file``.
            NOTE(review): ``None`` is tolerated by the ``archive_map`` lookup
            but would fail on the ``startswith`` call below -- confirm whether
            ``None`` is meant to be supported.
        archive_map: optional mapping from short names to locations (or to
            dicts of per-model defaults).
        **kwargs: argument overrides passed to
            ``checkpoint_utils.load_model_ensemble_and_task``. The keys
            ``"user_dir"``, ``"suffix"`` and ``"is_moe"`` are additionally
            read here.

    Returns:
        A dict with the keys ``"args"``, ``"task"`` and ``"models"``.
    """
    from fairseq import checkpoint_utils, file_utils

    if archive_map is not None:
        if model_name_or_path in archive_map:
            model_name_or_path = archive_map[model_name_or_path]
        if data_name_or_path is not None and data_name_or_path in archive_map:
            data_name_or_path = archive_map[data_name_or_path]

        # An archive entry may be a dict carrying default arg overrides
        # (e.g., tokenizer, bpe) for the model next to its "path".
        if isinstance(model_name_or_path, dict):
            entry = model_name_or_path
            for key, value in entry.items():
                if key == "checkpoint_file":
                    checkpoint_file = value
                elif key != "path":
                    # explicit overrides from the caller take precedence
                    kwargs.setdefault(key, value)
            model_name_or_path = entry["path"]

    model_path = file_utils.load_archive_file(model_name_or_path)

    # Convenience hack: pick up data and BPE codes from the model archive.
    kwargs["data"] = (
        os.path.abspath(os.path.join(model_path, data_name_or_path))
        if data_name_or_path.startswith(".")
        else file_utils.load_archive_file(data_name_or_path)
    )

    # (file name inside the archive, override it populates); later entries
    # win when several files map to the same override.
    known_files = (
        ("code", "bpe_codes"),
        ("bpecodes", "bpe_codes"),
        ("sentencepiece.bpe.model", "sentencepiece_model"),
        ("merges.txt", "bpe_merges"),
        ("vocab.json", "bpe_vocab"),
    )
    for file_name, override_name in known_files:
        candidate = os.path.join(model_path, file_name)
        if os.path.exists(candidate):
            kwargs[override_name] = candidate

    if "user_dir" in kwargs:
        user_module_args = argparse.Namespace(user_dir=kwargs["user_dir"])
        utils.import_user_module(user_module_args)

    checkpoint_paths = [
        os.path.join(model_path, name)
        for name in checkpoint_file.split(os.pathsep)
    ]
    models, args, task = checkpoint_utils.load_model_ensemble_and_task(
        checkpoint_paths,
        arg_overrides=kwargs,
        suffix=kwargs.get("suffix", ""),
        is_moe=kwargs.get("is_moe", False),
    )

    return {"args": args, "task": task, "models": models}


class GeneratorHubInterface(nn.Module):
    """
    PyTorch Hub interface for generating sequences from a pre-trained
    translation or language model.

    Raw sentences may carry an optional ``<lang>...</lang>`` prefix; see
    :meth:`get_sentence_and_language` and :meth:`add_language_to_sentence`.
    """

    # Class-level defaults. ``__init__`` replaces them on the instance when
    # ``cfg.task`` defines ``langs`` (``lang_tokens`` then becomes a set).
    lang_tokens = {}
    langs = None
    add_lang_bos_token = False

    def to_lang_token(self, lang):
        """Return the dictionary symbol used for ``lang``, i.e. ``<lang>``."""
        return f"<{lang}>"

    def __init__(self, cfg, task, models, moe_disable_padding=True, skip_prepare_for_inference=False):
        """Wrap ``models`` and ``task`` for convenient generation.

        Args:
            cfg: configuration object; ``cfg.task``, ``cfg.generation``,
                ``cfg.tokenizer``, ``cfg.bpe`` and ``cfg.dataset`` are read.
            task: task providing the source/target dictionaries, the
                generator and the batch iterator.
            models: iterable of models forming the ensemble.
            moe_disable_padding: forwarded to each model's
                ``prepare_for_inference_``.
            skip_prepare_for_inference: when true, the models are not
                passed through ``prepare_for_inference_``.
        """
        super().__init__()
        self.cfg = cfg

        self.task = task
        self.models = nn.ModuleList(models)
        self.src_dict = task.source_dictionary
        self.tgt_dict = task.target_dictionary

        if "langs" in cfg.task:
            self.langs = self.cfg.task.langs
            lang_tokens = [
                self.to_lang_token(x.strip()) for x in self.cfg.task.langs.split(",")
            ]

            # make sure every language token is a symbol of both dictionaries
            for lang_token in lang_tokens:
                if lang_token not in self.src_dict:
                    self.src_dict.add_symbol(lang_token)

                if lang_token not in self.tgt_dict:
                    self.tgt_dict.add_symbol(lang_token)

            self.lang_tokens = set(lang_tokens)

            if "add_bos_token" in cfg.task:
                self.add_lang_bos_token = cfg.task.add_bos_token

        # optimize model for generation
        if not skip_prepare_for_inference:
            for model in self.models:
                # For moe models and eval_lm
                model.prepare_for_inference_(cfg, moe_disable_padding=moe_disable_padding)

        # Load alignment dictionary for unknown word replacement
        # (None if no unknown word replacement, empty if no path to align dictionary)
        self.align_dict = utils.load_align_dict(cfg.generation.replace_unk)

        self.tokenizer = encoders.build_tokenizer(cfg.tokenizer)
        self.bpe = encoders.build_bpe(cfg.bpe)

        self.max_positions = utils.resolve_max_positions(
            self.task.max_positions(), *[model.max_positions() for model in models]
        )

        # this is useful for determining the device
        self.register_buffer("_float_tensor", torch.tensor([0], dtype=torch.float))

    @property
    def device(self):
        """Device of the registered ``_float_tensor`` buffer."""
        return self._float_tensor.device

    def translate(
        self, sentences: List[str], beam: int = 5, verbose: bool = False, **kwargs
    ) -> List[str]:
        """Alias of :meth:`sample` with a default beam size of 5."""
        return self.sample(sentences, beam, verbose, **kwargs)

    def sample(
        self, sentences: List[str], beam: int = 1, verbose: bool = False, **kwargs
    ) -> List[str]:
        """Encode ``sentences``, generate, and decode the first hypothesis.

        A single ``str`` is accepted as well; in that case a single ``str``
        is returned instead of a list. Extra ``kwargs`` are forwarded to
        :meth:`generate`.
        """
        if isinstance(sentences, str):
            return self.sample([sentences], beam=beam, verbose=verbose, **kwargs)[0]
        tokenized_sentences = [self.encode(sentence) for sentence in sentences]
        batched_hypos = self.generate(tokenized_sentences, beam, verbose, **kwargs)
        return [self.decode(hypos[0]["tokens"]) for hypos in batched_hypos]

    def score(self, sentences: List[str], **kwargs):
        """Return the first hypothesis per sentence with ``score_reference=True``.

        A single ``str`` is accepted as well; in that case a single
        hypothesis is returned instead of a list.
        """
        if isinstance(sentences, str):
            return self.score([sentences], **kwargs)[0]
        # NOTE: this doesn't support translation tasks currently
        tokenized_sentences = [self.encode(sentence) for sentence in sentences]
        return [
            hypos[0]
            for hypos in self.generate(
                tokenized_sentences, score_reference=True, **kwargs
            )
        ]

    def generate(
        self,
        tokenized_sentences: List[torch.LongTensor],
        beam: int = 5,
        verbose: bool = False,
        skip_invalid_size_inputs=False,
        inference_step_args=None,
        batch_size=None,
        **kwargs,
    ) -> List[List[Dict[str, torch.Tensor]]]:
        """Run the task's inference step over ``tokenized_sentences``.

        Args:
            tokenized_sentences: list of 1-D token tensors, or a single 1-D
                tensor (then the hypotheses of that one sentence are
                returned instead of a list of lists).
            beam: beam size written into the generation args.
            verbose: log source, hypotheses, positional scores and
                (optionally) alignments.
            skip_invalid_size_inputs: forwarded to the batch iterator as
                ``ignore_invalid_inputs``.
            inference_step_args: extra keyword arguments for
                ``task.inference_step``.
            batch_size: maximum sentences per batch; falls back to
                ``cfg.dataset.batch_size`` when ``None``.
            **kwargs: additional generation args set on a copy of
                ``cfg.generation``.

        Returns:
            The hypotheses of each processed sentence, ordered by sample id.
        """
        if torch.is_tensor(tokenized_sentences) and tokenized_sentences.dim() == 1:
            # Forward *all* caller options; dropping them here would make the
            # single-sentence path behave differently from the batched one.
            batched = self.generate(
                tokenized_sentences.unsqueeze(0),
                beam=beam,
                verbose=verbose,
                skip_invalid_size_inputs=skip_invalid_size_inputs,
                inference_step_args=inference_step_args,
                batch_size=batch_size,
                **kwargs,
            )
            # No entry means the sentence produced no result here (e.g. it
            # was skipped as invalid); report "no hypotheses".
            return batched[0] if batched else []

        # build generator using current args as well as any kwargs
        gen_args = copy.deepcopy(self.cfg.generation)
        with open_dict(gen_args):
            gen_args.beam = beam
            for k, v in kwargs.items():
                setattr(gen_args, k, v)
        generator = self.task.build_generator(self.models, gen_args)

        inference_step_args = inference_step_args or {}
        results = []
        rank, world_size = get_data_parallel_rank(), get_data_parallel_world_size()
        batches = self._build_batches(
            tokenized_sentences, skip_invalid_size_inputs, rank=rank,
            world_size=world_size, batch_size=batch_size,
        )
        # To ensure even batch count across workers, some batches might be
        # dummy batches. We shouldn't score these.
        first_batch = None
        for batch in batches:
            is_dummy_batch = "net_input" not in batch
            if is_dummy_batch:
                if first_batch is None:
                    # nothing real seen yet that could stand in for it
                    continue
                # re-run the first real batch in place of the dummy one
                batch = first_batch
            elif first_batch is None:
                first_batch = batch
            batch = utils.apply_to_sample(lambda t: t.to(self.device), batch)
            translations = self.task.inference_step(generator, self.models, batch, **inference_step_args)
            if is_dummy_batch:  # Don't score it or add it to hypotheses
                continue
            for sample_id, hypos in zip(batch["id"].tolist(), translations):
                results.append((sample_id, hypos))

        # sort output to match input order
        outputs = [hypos for _, hypos in sorted(results, key=lambda x: x[0])]

        if verbose:

            def getarg(name, default):
                return getattr(gen_args, name, getattr(self.cfg, name, default))

            for source_tokens, target_hypotheses in zip(tokenized_sentences, outputs):
                src_str_with_unk = self.string(source_tokens)
                logger.info("S\t{}".format(src_str_with_unk))
                for hypo in target_hypotheses:
                    hypo_str = self.decode(hypo["tokens"])
                    logger.info("H\t{}\t{}".format(hypo["score"], hypo_str))
                    logger.info(
                        "P\t{}".format(
                            " ".join(
                                "{:.4f}".format(x)
                                for x in hypo["positional_scores"].tolist()
                            )
                        )
                    )
                    if hypo["alignment"] is not None and getarg(
                        "print_alignment", False
                    ):
                        logger.info(
                            "A\t{}".format(
                                " ".join(
                                    "{}-{}".format(src_idx, tgt_idx)
                                    for src_idx, tgt_idx in hypo["alignment"]
                                )
                            )
                        )
        return outputs

    def get_sentence_and_language(self, sentence: str):
        """
        If sentence is prefixed with the language, the prefix is stripped
        and both the language and the remaining sentence are returned.

        input: '<lang>en-EN</lang>Some sentence here'
        output: en-EN, 'Some sentence here'

        Without a complete ``<lang>...</lang>`` prefix the result is
        ``(None, sentence)``.
        """

        lang_begin = "<lang>"
        lang_end = "</lang>"

        lang = None
        if sentence.startswith(lang_begin):
            idx = sentence.find(lang_end)
            if idx > 0:
                lang = sentence[: idx + len(lang_end)]
                lang = lang.replace(lang_begin, "").replace(lang_end, "")
                sentence = sentence[idx + len(lang_end) :]

        return lang, sentence

    def add_language_to_sentence(self, sentence: str, lang_token):
        """Return ``sentence`` prefixed with ``<lang>{lang_token}</lang>``."""
        lang_begin = "<lang>"
        lang_end = "</lang>"

        lang_prefix = lang_begin + lang_token + lang_end
        sentence = lang_prefix + sentence

        return sentence

    def encode(self, sentence: str) -> torch.LongTensor:
        """Tokenize, BPE-encode and binarize ``sentence``.

        A ``<lang>...</lang>`` prefix is removed before tokenization and its
        content is put back in front of the processed text (separated by a
        space) before binarization.
        """
        lang, sentence = self.get_sentence_and_language(sentence)

        sentence = self.tokenize(sentence)
        sentence = self.apply_bpe(sentence)

        if lang is not None:
            sentence = f"{lang} {sentence}"

        return self.binarize(sentence)

    def decode(self, tokens: torch.LongTensor) -> str:
        """Turn ``tokens`` back into text (inverse of :meth:`encode`).

        A leading language token (one of ``self.lang_tokens``) is kept out
        of BPE removal / detokenization and re-attached as a
        ``<lang>...</lang>`` prefix.
        """
        sentence = self.string(tokens)

        # Remove the lang token. ``partition`` (unlike indexing the result of
        # ``split``) also copes with an output that is *only* a lang token.
        first_word, _, remainder = sentence.partition(" ")
        lang_token = None
        if first_word in self.lang_tokens:
            lang_token = first_word
            sentence = remainder

        sentence = self.remove_bpe(sentence)
        sentence = self.detokenize(sentence)

        if lang_token is not None:
            sentence = self.add_language_to_sentence(sentence, lang_token)

        return sentence

    def tokenize(self, sentence: str) -> str:
        """Apply the tokenizer's ``encode`` if a tokenizer is configured."""
        if self.tokenizer is not None:
            sentence = self.tokenizer.encode(sentence)
        return sentence

    def detokenize(self, sentence: str) -> str:
        """Apply the tokenizer's ``decode`` if a tokenizer is configured."""
        if self.tokenizer is not None:
            sentence = self.tokenizer.decode(sentence)
        return sentence

    def apply_bpe(self, sentence: str) -> str:
        """Apply the BPE ``encode`` if a BPE is configured."""
        if self.bpe is not None:
            sentence = self.bpe.encode(sentence)
        return sentence

    def remove_bpe(self, sentence: str) -> str:
        """Apply the BPE ``decode`` if a BPE is configured."""
        if self.bpe is not None:
            sentence = self.bpe.decode(sentence)
        return sentence

    def binarize(self, sentence: str) -> torch.LongTensor:
        """Encode ``sentence`` with the source dictionary without adding new symbols."""
        return self.src_dict.encode_line(sentence, add_if_not_exist=False).long()

    def string(self, tokens: torch.LongTensor) -> str:
        """Render ``tokens`` as a string using the target dictionary."""
        return self.tgt_dict.string(tokens)

    def _build_batches(
        self, tokens:  List[torch.LongTensor], skip_invalid_size_inputs: bool, world_size=None, rank=None, batch_size=None
    ) -> Iterator[Dict[str, Any]]:
        """Build an unshuffled batch iterator over ``tokens``.

        ``world_size`` / ``rank`` are passed to the task's batch iterator as
        ``num_shards`` / ``shard_id``; ``batch_size`` defaults to
        ``cfg.dataset.batch_size`` when ``None``.
        """
        lengths = torch.LongTensor([t.numel() for t in tokens])
        if batch_size is None:
            batch_size = self.cfg.dataset.batch_size
        batch_iterator = self.task.get_batch_iterator(
            dataset=self.task.build_dataset_for_inference(tokens, lengths),
            max_tokens=self.cfg.dataset.max_tokens,
            max_sentences=batch_size,
            max_positions=self.max_positions,
            ignore_invalid_inputs=skip_invalid_size_inputs,
            disable_iterator_cache=True,
            num_shards=world_size,
            shard_id=rank,
        ).next_epoch_itr(shuffle=False)
        return batch_iterator


class BPEHubInterface:
    """PyTorch Hub interface for Byte-Pair Encoding (BPE).

    Builds a BPE object through ``encoders.build_bpe`` and delegates
    ``encode`` / ``decode`` to it.
    """

    def __init__(self, bpe, **kwargs):
        """Build the BPE named ``bpe``; ``kwargs`` become extra namespace fields."""
        super().__init__()
        bpe_config = argparse.Namespace(bpe=bpe, **kwargs)
        self.bpe = encoders.build_bpe(bpe_config)
        assert self.bpe is not None

    def encode(self, sentence: str) -> str:
        """Return ``sentence`` encoded by the wrapped BPE."""
        encoded = self.bpe.encode(sentence)
        return encoded

    def decode(self, sentence: str) -> str:
        """Return ``sentence`` decoded by the wrapped BPE."""
        decoded = self.bpe.decode(sentence)
        return decoded


class TokenizerHubInterface:
    """PyTorch Hub interface for tokenization.

    Builds a tokenizer through ``encoders.build_tokenizer`` and delegates
    ``encode`` / ``decode`` to it.
    """

    def __init__(self, tokenizer, **kwargs):
        """Build the tokenizer named ``tokenizer``; ``kwargs`` become extra namespace fields."""
        super().__init__()
        tokenizer_config = argparse.Namespace(tokenizer=tokenizer, **kwargs)
        self.tokenizer = encoders.build_tokenizer(tokenizer_config)
        assert self.tokenizer is not None

    def encode(self, sentence: str) -> str:
        """Return ``sentence`` encoded by the wrapped tokenizer."""
        encoded = self.tokenizer.encode(sentence)
        return encoded

    def decode(self, sentence: str) -> str:
        """Return ``sentence`` decoded by the wrapped tokenizer."""
        decoded = self.tokenizer.decode(sentence)
        return decoded
