# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging
from argparse import Namespace
from copy import deepcopy
from pathlib import Path
from typing import Dict, Optional

from fairseq.data import Dictionary

logger = logging.getLogger(__name__)


def get_config_from_yaml(yaml_path: Path):
    """Load and parse a YAML config file.

    Args:
        yaml_path: path to the YAML file.

    Returns:
        The parsed YAML document (normally a dict). An empty file yields an
        empty dict rather than ``None`` so callers can always use ``.get``.

    Raises:
        FileNotFoundError: if ``yaml_path`` is not an existing file.
        ImportError: if PyYAML is not installed.
        Exception: if the file cannot be read or parsed (original error chained).
    """
    if not yaml_path.is_file():
        raise FileNotFoundError(f"{yaml_path.as_posix()} not found")

    try:
        import yaml
    except ImportError as e:
        # Fail here with an actionable message instead of printing and then
        # crashing later with "name 'yaml' is not defined".
        raise ImportError("Please install PyYAML: pip install PyYAML") from e

    try:
        with open(yaml_path, encoding="utf-8") as f:
            # NOTE(review): FullLoader can construct some Python objects; only
            # load config files from trusted sources.
            config = yaml.load(f, Loader=yaml.FullLoader)
    except Exception as e:
        raise Exception(
            f"Failed to load config from {yaml_path.as_posix()}: {e}"
        ) from e

    # An empty YAML document parses to None.
    return {} if config is None else config


class S2TDataConfig(object):
    """Wrapper class for data config YAML"""

    def __init__(self, yaml_path: Path):
        # An empty YAML file parses to None; normalise to a dict so that the
        # ``self.config.get`` accessors below always work.
        self.config = get_config_from_yaml(yaml_path) or {}
        # Relative paths found in the config are resolved against this dir.
        self.root = yaml_path.parent

    def _auto_convert_to_abs_path(self, x):
        """Resolve config path values against the config file's directory.

        A string is rewritten to ``(self.root / x).as_posix()`` only when it
        does not exist as given but does exist under ``self.root``. Dicts are
        processed recursively (values only); anything else is returned as-is.
        NOTE(review): the result is only absolute if ``self.root`` is.
        """
        if isinstance(x, str):
            if not Path(x).exists() and (self.root / x).exists():
                return (self.root / x).as_posix()
        elif isinstance(x, dict):
            return {k: self._auto_convert_to_abs_path(v) for k, v in x.items()}
        return x

    @property
    def vocab_filename(self):
        """fairseq vocabulary file under data root"""
        return self.config.get("vocab_filename", "dict.txt")

    @property
    def speaker_set_filename(self):
        """speaker set file under data root"""
        return self.config.get("speaker_set_filename", None)

    @property
    def shuffle(self) -> bool:
        """Shuffle dataset samples before batching"""
        return self.config.get("shuffle", False)

    @property
    def pre_tokenizer(self) -> Dict:
        """Pre-tokenizer to apply before subword tokenization. Returning
        a dictionary with `tokenizer` providing the tokenizer name and
        the other items providing the tokenizer-specific arguments.
        Tokenizers are defined in `fairseq.data.encoders.*`"""
        tokenizer = self.config.get("pre_tokenizer", {"tokenizer": None})
        return self._auto_convert_to_abs_path(tokenizer)

    @property
    def bpe_tokenizer(self) -> Dict:
        """Subword tokenizer to apply after pre-tokenization. Returning
        a dictionary with `bpe` providing the tokenizer name and
        the other items providing the tokenizer-specific arguments.
        Tokenizers are defined in `fairseq.data.encoders.*`"""
        tokenizer = self.config.get("bpe_tokenizer", {"bpe": None})
        return self._auto_convert_to_abs_path(tokenizer)

    @property
    def prepend_tgt_lang_tag(self) -> bool:
        """Prepend target lang ID token as the target BOS (e.g. for to-many
        multilingual setting). During inference, this requires `--prefix-size 1`
        to force BOS to be lang ID token."""
        return self.config.get("prepend_tgt_lang_tag", False)

    @property
    def prepend_bos_and_append_tgt_lang_tag(self) -> bool:
        """Prepend BOS and append target lang ID token to the target (e.g. mBART with language token pretraining)."""
        return self.config.get("prepend_bos_and_append_tgt_lang_tag", False)

    @property
    def input_feat_per_channel(self):
        """The dimension of input features (per audio channel)"""
        return self.config.get("input_feat_per_channel", 80)

    @property
    def input_channels(self):
        """The number of channels in the input audio"""
        return self.config.get("input_channels", 1)

    @property
    def sample_rate(self):
        """The ``sample_rate`` config value (default 16000)."""
        return self.config.get("sample_rate", 16_000)

    @property
    def sampling_alpha(self):
        """Hyper-parameter alpha = 1/T for temperature-based resampling.
        (alpha = 1 for no resampling)"""
        return self.config.get("sampling_alpha", 1.0)

    @property
    def use_audio_input(self):
        """Needed by the dataset loader to see if the model requires
        raw audio as inputs."""
        return self.config.get("use_audio_input", False)

    def standardize_audio(self) -> bool:
        """Whether raw-audio input should be standardized.

        True only when ``use_audio_input`` is enabled and the config sets
        ``standardize_audio``.

        NOTE: unlike the neighbouring accessors this is a plain method, not a
        property (kept for caller compatibility). It must be *called*: a check
        like ``if cfg.standardize_audio:`` tests a bound method and is always
        truthy.
        """
        return self.use_audio_input and self.config.get("standardize_audio", False)

    @property
    def use_sample_rate(self):
        """Needed by the dataset loader to see if the model requires
        raw audio with specific sample rate as inputs."""
        return self.config.get("use_sample_rate", 16000)

    @property
    def audio_root(self):
        """Audio paths in the manifest TSV can be relative and this provides
        the root path. Set this to empty string when using absolute paths."""
        return self.config.get("audio_root", "")

    def get_transforms(self, transform_type, split, is_train):
        """Split-specific feature transforms. Allowing train set
        wildcard `_train`, evaluation set wildcard `_eval` and general
        wildcard `*` for matching.

        Looks up ``config[f"{transform_type}transforms"]`` and returns a deep
        copy of the first match among ``split``, ``_train``/``_eval``
        (depending on ``is_train``) and ``*``; ``None`` if nothing matches.
        The copy lets callers mutate the result without touching the config.
        """
        # ``or {}`` also covers a section explicitly set to null in the YAML.
        transforms = self.config.get(f"{transform_type}transforms") or {}
        cur = transforms.get(split)
        if cur is None:
            cur = transforms.get("_train" if is_train else "_eval")
        if cur is None:
            cur = transforms.get("*")
        return deepcopy(cur)

    def get_feature_transforms(self, split, is_train):
        """Return a copy of the config with ``feature_transforms`` resolved
        for ``split``.

        The deprecated ``transforms`` section, when it matches, takes
        precedence, and any ``feature_transforms`` for the split are appended
        to it.
        """
        cfg = deepcopy(self.config)
        # TODO: deprecate transforms
        cur = self.get_transforms("", split, is_train)
        if cur is not None:
            logger.warning(
                "Auto converting transforms into feature_transforms, "
                "but transforms will be deprecated in the future. Please "
                "update this in the config."
            )
            ft_transforms = self.get_transforms("feature_", split, is_train)
            if ft_transforms:
                # Safe to mutate: get_transforms returns a deep copy.
                cur.extend(ft_transforms)
        else:
            cur = self.get_transforms("feature_", split, is_train)
        cfg["feature_transforms"] = cur
        return cfg

    def get_waveform_transforms(self, split, is_train):
        """Return a copy of the config with ``waveform_transforms`` resolved
        for ``split``."""
        cfg = deepcopy(self.config)
        cfg["waveform_transforms"] = self.get_transforms("waveform_", split, is_train)
        return cfg

    def get_dataset_transforms(self, split, is_train):
        """Return a copy of the config with ``dataset_transforms`` resolved
        for ``split``."""
        cfg = deepcopy(self.config)
        cfg["dataset_transforms"] = self.get_transforms("dataset_", split, is_train)
        return cfg

    @property
    def global_cmvn_stats_npz(self) -> Optional[str]:
        """``global_cmvn.stats_npz_path`` resolved against the config
        directory when relative; ``None`` if unset."""
        # ``or {}`` also covers ``global_cmvn: null`` in the YAML.
        path = (self.config.get("global_cmvn") or {}).get("stats_npz_path", None)
        return self._auto_convert_to_abs_path(path)

    @property
    def vocoder(self) -> Dict[str, str]:
        """Vocoder config (default ``{"type": "griffin_lim"}``) with path
        values resolved against the config directory."""
        vocoder = self.config.get("vocoder", {"type": "griffin_lim"})
        return self._auto_convert_to_abs_path(vocoder)

    @property
    def hub(self) -> Dict[str, str]:
        """The ``hub`` section of the config, or an empty dict."""
        return self.config.get("hub", {})


class S2SDataConfig(S2TDataConfig):
    """Wrapper class for data config YAML (speech-to-speech variant)."""

    @property
    def vocab_filename(self):
        """fairseq vocabulary file under data root"""
        return self.config.get("vocab_filename", None)

    @property
    def pre_tokenizer(self) -> Optional[Dict]:
        """Always ``None``: this config defines no pre-tokenizer."""
        return None

    @property
    def bpe_tokenizer(self) -> Optional[Dict]:
        """Always ``None``: this config defines no BPE tokenizer."""
        return None

    @property
    def input_transformed_channels(self):
        """The number of channels in the audio after feature transforms.

        Multiplies ``input_channels`` by 3 when ``delta_deltas`` is among the
        ``_train`` transforms of either the deprecated ``transforms`` section
        or ``feature_transforms``. Both lists are considered together, the
        same way ``get_feature_transforms`` concatenates them. The config
        itself is never modified.
        """
        # TODO: move this into individual transforms
        # TODO: deprecate transforms
        legacy = (self.config.get("transforms") or {}).get("_train") or []
        feature = (self.config.get("feature_transforms") or {}).get("_train") or []

        _channels = self.input_channels
        if "delta_deltas" in legacy or "delta_deltas" in feature:
            _channels *= 3

        return _channels

    @property
    def output_sample_rate(self):
        """The audio sample rate of output target speech"""
        return self.config.get("output_sample_rate", 22050)

    @property
    def target_speaker_embed(self):
        """Target speaker embedding file (one line per target audio sample)"""
        return self.config.get("target_speaker_embed", None)

    @property
    def prepend_tgt_lang_tag_as_bos(self) -> bool:
        """Prepend target lang ID token as the target BOS."""
        return self.config.get("prepend_tgt_lang_tag_as_bos", False)


class MultitaskConfig(object):
    """Wrapper class for the multitask config YAML.

    Every top-level key of the YAML is a task name whose value is wrapped in
    a ``SingleTaskConfig``; ``self.config`` maps task name -> that wrapper.
    """

    def __init__(self, yaml_path: Path):
        # An empty YAML file parses to None; treat it as "no tasks".
        config = get_config_from_yaml(yaml_path) or {}
        self.config = {
            name: SingleTaskConfig(name, task_cfg) for name, task_cfg in config.items()
        }

    def get_all_tasks(self):
        """Return the mapping of task name -> ``SingleTaskConfig``."""
        return self.config

    def get_single_task(self, name):
        """Return the ``SingleTaskConfig`` for ``name`` (AssertionError if absent)."""
        assert name in self.config, f"multitask '{name}' does not exist!"
        return self.config[name]

    @property
    def first_pass_decoder_task_index(self):
        """Return the task index of the first-pass text decoder.

        The index is the task's position in ``self.config``.
        If one or more tasks set 'is_first_pass_decoder: True', the last of
            them is used.
        Otherwise, the last task whose name starts with 'target' and whose
            decoder_type is 'transformer' is used.
        Returns -1 if no task qualifies.
        """
        idx = -1
        for i, task in enumerate(self.config.values()):
            if task.is_first_pass_decoder:
                idx = i
        if idx < 0:
            for i, (name, task) in enumerate(self.config.items()):
                if name.startswith("target") and task.decoder_type == "transformer":
                    idx = i
        return idx


class SingleTaskConfig(object):
    """Accessors for the config of one task in a multitask YAML."""

    def __init__(self, name, config):
        self.task_name = name
        # A task declared with an empty body in YAML yields None.
        self.config = config if config is not None else {}
        dict_path = self.config.get("dict", "")
        # Path("") is the current directory, which always exists, so an empty
        # or missing "dict" entry must be rejected before the filesystem check;
        # otherwise Dictionary.load("") would be attempted and fail.
        self.tgt_dict = (
            Dictionary.load(dict_path)
            if dict_path and Path(dict_path).is_file()
            else None
        )

    @property
    def data(self):
        """The ``data`` entry of the task config (default "")."""
        return self.config.get("data", "")

    @property
    def decoder_type(self):
        """Decoder type of the task (default "transformer")."""
        return self.config.get("decoder_type", "transformer")

    @property
    def decoder_args(self):
        """Decoder arch related args"""
        args = self.config.get("decoder_args") or {}
        return Namespace(**args)

    @property
    def criterion_cfg(self):
        """cfg for the multitask criterion.

        A fresh config instance is returned on every call. Per-task settings
        (``zero_infinity`` for CTC, ``label_smoothing`` otherwise) are set on
        that instance, not on the shared dataclass type, so they cannot leak
        between tasks.
        """
        if self.decoder_type == "ctc":
            from fairseq.criterions.ctc import CtcCriterionConfig

            cfg = CtcCriterionConfig()
            cfg.zero_infinity = self.config.get("zero_infinity", True)
        else:
            from fairseq.criterions.label_smoothed_cross_entropy import (
                LabelSmoothedCrossEntropyCriterionConfig,
            )

            cfg = LabelSmoothedCrossEntropyCriterionConfig()
            cfg.label_smoothing = self.config.get("label_smoothing", 0.2)
        return cfg

    @property
    def input_from(self):
        """Condition on encoder/decoder of the main model"""
        return "decoder" if "decoder_layer" in self.config else "encoder"

    @property
    def input_layer(self):
        """Index of the main-model layer feeding this task.

        The configured ``decoder_layer`` / ``encoder_layer`` value minus one.
        """
        if self.input_from == "decoder":
            return self.config["decoder_layer"] - 1
        else:
            # default using the output from the last encoder layer (-1)
            return self.config.get("encoder_layer", 0) - 1

    @property
    def loss_weight_schedule(self):
        """"decay" if both loss_weight_max and loss_weight_decay_steps are set,
        else "fixed"."""
        return (
            "decay"
            if "loss_weight_max" in self.config
            and "loss_weight_decay_steps" in self.config
            else "fixed"
        )

    def get_loss_weight(self, num_updates):
        """Loss weight for this task after ``num_updates`` updates.

        "fixed": ``loss_weight`` (default 1.0).
        "decay": decreases linearly from ``loss_weight_max`` to
        ``loss_weight_min`` (default 0.0001) over ``loss_weight_decay_steps``
        updates, then stays at ``loss_weight_min``.
        """
        if self.loss_weight_schedule == "fixed":
            weight = self.config.get("loss_weight", 1.0)
        else:  # "decay"
            assert (
                self.config.get("loss_weight_decay_steps", 0) > 0
            ), "loss_weight_decay_steps must be greater than 0 for a decay schedule"
            loss_weight_min = self.config.get("loss_weight_min", 0.0001)
            loss_weight_decay_stepsize = (
                self.config["loss_weight_max"] - loss_weight_min
            ) / self.config["loss_weight_decay_steps"]
            weight = max(
                self.config["loss_weight_max"]
                - loss_weight_decay_stepsize * num_updates,
                loss_weight_min,
            )
        return weight

    @property
    def prepend_bos_and_append_tgt_lang_tag(self) -> bool:
        """Prepend BOS and append target lang ID token to the target (e.g. mBART with language token pretraining)."""
        return self.config.get("prepend_bos_and_append_tgt_lang_tag", False)

    @property
    def eos_token(self):
        """EOS token during generation"""
        return self.config.get("eos_token", "<eos>")

    @property
    def rdrop_alpha(self):
        """The ``rdrop_alpha`` config value (default 0.0)."""
        return self.config.get("rdrop_alpha", 0.0)

    @property
    def is_first_pass_decoder(self):
        """Whether this task is flagged as the first-pass decoder.

        Raises:
            ValueError: if flagged and the decoder type is CTC.
        Warns (UserWarning) if flagged and the task name lacks "target".
        """
        flag = self.config.get("is_first_pass_decoder", False)
        if flag:
            if self.decoder_type == "ctc":
                raise ValueError(
                    "First-pass decoder in the multi-decoder model must not be CTC."
                )
            if "target" not in self.task_name:
                import warnings

                # A naming mismatch is not fatal: emit a real warning instead
                # of ``raise Warning``, which aborted the caller.
                warnings.warn(
                    'The name of the first-pass decoder does not include "target".'
                )
        return flag

    @property
    def get_lang_tag_mapping(self):
        """The ``lang_tag_mapping`` config value (default {})."""
        return self.config.get("lang_tag_mapping", {})
