from typing import Any, Dict, List, Optional

import wandb
import torch
from omegaconf import DictConfig, OmegaConf

from pado.core import PadoAccelerator, PadoTrainer, PadoWrapper, set_logger
from pado.data.datasets import build_dataset, build_datasets
from pado.data.transforms import build_transforms
from pado.data.datasets.speech.librispeech import PadoLibriSpeech
from pado.data.dataloader import PadoDataLoader
from pado.data.transforms.text import SentencePieceTokenizer
from pado.losses import CTCLoss
from pado.metrics import CharErrorRate, WordErrorRate
from pado.optim import build_optimizer
from pado.optim.lr_scheduler import build_scheduler
from pado.models.utils import freeze_bn, set_weight_variational_noise
from pado.utils import (default_parser, load_config, override_config_by_cli, set_wandb,
                        split_params_for_weight_decay)
from model import ReuseAttnConformerCTC


class ReuseAttnConformerWrapper(PadoWrapper):
    """Wrap a ``ReuseAttnConformerCTC`` network with its loss and CER/WER metrics.

    ``forward`` returns a dictionary holding the per-sample loss together with
    the raw edit-distance statistics (distance and reference length) so that
    callers can aggregate CER/WER across batches.
    """

    def __init__(self,
                 network: ReuseAttnConformerCTC,
                 logger,
                 tokenizer: SentencePieceTokenizer,
                 loss_type: str,
                 loss_cfg: DictConfig,
                 metric_cfg: DictConfig) -> None:
        """Store the tokenizer and build the loss and metric objects.

        Args:
            network: Network forwarded to ``PadoWrapper``.
            logger: Logger forwarded to ``PadoWrapper``.
            tokenizer: Used to turn decoded index sequences back into text.
            loss_type: Selects the forward path. ``"ctc"`` and ``"rnnt"`` are
                handled; any other value makes ``forward`` raise.
            loss_cfg: Config from which the ``CTCLoss`` is built.
            metric_cfg: Config shared by the CER and WER metric objects.
        """
        super().__init__(network, logger)
        self.tokenizer = tokenizer

        self.loss_type = loss_type
        self.loss = CTCLoss.from_config(loss_cfg)

        # ``forward`` divides the loss by the batch size itself, so the loss
        # object must hand back a plain sum over the batch.
        if self.loss.reduction != "sum":
            self.logger.warning(f"Loss reduction is {self.loss.reduction}, changed to sum.")
            self.loss.reduction = "sum"  # force set

        self.cer = CharErrorRate.from_config(metric_cfg)
        self.wer = WordErrorRate.from_config(metric_cfg)

    def forward(self,
                feature: torch.Tensor,
                feature_lengths: torch.Tensor,
                target: torch.Tensor,
                target_lengths: torch.Tensor,
                target_scripts: List[str],
                audio_paths: List[str],
                compute_metric: bool = True,
                verbose_metric: bool = False,
                verbose_info: str = "",
                is_train: bool = True,
                beam_width: int = 1) -> Dict[str, Any]:
        """Compute the loss for one batch and optionally its CER/WER statistics.

        Args:
            feature: Input features passed to the network.
            feature_lengths: Lengths passed to the network alongside ``feature``.
            target: Target indices passed to the loss.
            target_lengths: Target lengths; its first dimension is used as the batch size.
            target_scripts: Reference transcripts compared against the decoded text.
            audio_paths: Accepted for interface compatibility; not used here.
            compute_metric: When true, decode the encoder output and compute CER/WER.
            verbose_metric: When true (and ``compute_metric`` is true), log the
                decode result of the first sample in the batch.
            verbose_info: Free-form text included in the verbose log line.
            is_train: Accepted for interface compatibility; not used by the
                currently implemented loss types.
            beam_width: Beam width forwarded to ``network.decode``.

        Returns:
            Dict with ``num_samples``, ``loss`` (summed loss divided by the batch
            size), ``cer_distance``, ``cer_length``, ``wer_distance`` and
            ``wer_length``. The four metric entries are 0 when
            ``compute_metric`` is false.

        Raises:
            NotImplementedError: If ``loss_type`` is neither ``"ctc"`` nor ``"rnnt"``.
            RuntimeError: If the loss is NaN or infinite.
        """
        batch_size = target_lengths.shape[0]

        if self.loss_type == "ctc":
            enc, enc_lengths, _, _, _ = self.network(feature, feature_lengths)
            loss = self.loss(enc, target, enc_lengths, target_lengths)
        elif self.loss_type == "rnnt":
            # NOTE(review): ``self.loss`` is always a CTCLoss (see __init__) — confirm
            # that feeding it the RNN-T logits is intended.
            logits, enc, enc_lengths, _, _, _ = self.network(feature, feature_lengths, target, target_lengths)
            loss = self.loss(logits, target, enc_lengths, target_lengths)
        else:  # las
            # logits, enc, enc_lengths, _, _ = self.network(feature, feature_lengths, target, target_lengths)
            # smoothed_loss, nll_loss = self.loss(logits, target)
            # loss = smoothed_loss if is_train else nll_loss
            raise NotImplementedError(f"Unsupported loss_type: {self.loss_type!r}")

        # Fail fast before anything downstream consumes a broken loss; a single
        # isfinite check covers both NaN and +/-Inf.
        if not torch.isfinite(loss):
            raise RuntimeError("Loss is NaN or Inf")

        output = dict()
        output["num_samples"] = batch_size
        output["loss"] = loss / batch_size  # per-sample

        if compute_metric:
            # Decoding only feeds logging and metrics, never the backward pass, so
            # keep it out of the autograd graph.
            with torch.no_grad():
                decode_indices, decode_logp = self.network.decode(enc, enc_lengths, beam_width)
            decode_scripts = [self.tokenizer.decode(di) for di in decode_indices]
            assert len(decode_scripts) == len(target_scripts) == batch_size

            if verbose_metric:
                # Only the first sample of the batch is printed.
                s = f"Decode output: {verbose_info}\n" \
                    f"... log_prob: {decode_logp[0]:.6f} " \
                    f"(normalized: {decode_logp[0] / max(len(decode_indices[0]), 1):.6f})\n" \
                    f"... decode: {decode_scripts[0]}\n" \
                    f"... target: {target_scripts[0]}\n" \
                    f"... CER: {self.cer([decode_scripts[0]], [target_scripts[0]])[0]:.6f}\n" \
                    f"... WER: {self.wer([decode_scripts[0]], [target_scripts[0]])[0]:.6f}"
                self.logger.info(s)

            # The first element of each metric result is dropped; the remaining two
            # are kept as distance and length so callers can aggregate them.
            _, cer_distance, cer_length = self.cer(decode_scripts, target_scripts)
            _, wer_distance, wer_length = self.wer(decode_scripts, target_scripts)
        else:
            cer_distance = wer_distance = 0
            cer_length = wer_length = 0

        output["cer_distance"] = cer_distance
        output["cer_length"] = cer_length
        output["wer_distance"] = wer_distance
        output["wer_length"] = wer_length
        return output


class ReuseAttnConformerTrainer(PadoTrainer):
    """``PadoTrainer`` specialisation that tracks CER/WER for ``ReuseAttnConformerWrapper``.

    Adds SortaGrad-style shuffle control, optional batch-norm freezing and
    weight variational noise, and uses the epoch-level validation WER as the
    value handed to the scheduler's best-score tracking.
    """

    def __init__(self,
                 model: ReuseAttnConformerWrapper,
                 accelerator: PadoAccelerator,
                 trainer_cfg: dict,
                 beam_width: dict,
                 verbose_metric: dict,
                 save_dir: Optional[str] = None,
                 logger: Optional[Any] = None,
                 train_dataloader: Optional[PadoDataLoader] = None,
                 valid_dataloader: Optional[PadoDataLoader] = None,
                 test_dataloader: Optional[PadoDataLoader] = None,
                 optimizer: Optional[Any] = None,
                 scheduler: Optional[Any] = None,
                 *, sorta_grad: bool = False,
                 bn_freeze: bool = False,
                 variational_noise_start_iter: int = -1,
                 variational_noise_std: float = 0.0
                 ) -> None:
        """Forward the common arguments to ``PadoTrainer`` and keep the extras.

        Args:
            model: Wrapped model called in the ``*_iter_body`` hooks.
            accelerator: Forwarded to ``PadoTrainer``.
            trainer_cfg: Extra keyword arguments expanded into ``PadoTrainer.__init__``.
            beam_width: Mapping with keys ``"train"``, ``"valid"``, ``"test"``.
            verbose_metric: Mapping with keys ``"train"``, ``"valid"``, ``"test"``
                enabling per-batch decode logging in the model.
            save_dir, logger, train_dataloader, valid_dataloader, test_dataloader,
            optimizer, scheduler: Forwarded to ``PadoTrainer``.
            sorta_grad: Disable shuffling for epoch 0 and enable it from epoch 1.
            bn_freeze: Call ``freeze_bn`` on the model at every training iteration.
            variational_noise_start_iter: Iteration after which weight noise is
                applied; values <= 0 disable it.
            variational_noise_std: Noise value passed to ``set_weight_variational_noise``.
        """
        super().__init__(model, accelerator, save_dir, logger, train_dataloader, valid_dataloader, test_dataloader,
                         optimizer, scheduler, **trainer_cfg)

        self.beam_width = beam_width  # {"train", "valid", "test"}
        self.verbose_metric = verbose_metric  # {"train", "valid", "test"}
        self.sorta_grad = sorta_grad
        self.bn_freeze = bn_freeze
        self.variational_noise_start_iter = variational_noise_start_iter
        self.variational_noise_std = variational_noise_std

    # ------------------------------------------------------------------ #
    # helpers
    # ------------------------------------------------------------------ #
    @staticmethod
    def _error_rate(distance, length):
        """Return ``distance / max(length, 1)``; the clamp guards a zero denominator."""
        return distance / max(length, 1)

    def _epoch_error_rates(self):
        """Return ``(cer, wer)`` computed from the totals held by the epoch tracker."""
        cer_avg = self._error_rate(self.epoch_tracker.value("cer_distance"),
                                   self.epoch_tracker.value("cer_length"))
        wer_avg = self._error_rate(self.epoch_tracker.value("wer_distance"),
                                   self.epoch_tracker.value("wer_length"))
        return cer_avg, wer_avg

    # ------------------------------------------------------------------ #
    # hooks
    # ------------------------------------------------------------------ #
    def init_additional(self):
        """Register the CER/WER accumulators on the epoch tracker."""
        self.epoch_tracker.add_keys(["cer_distance", "cer_length", "wer_distance", "wer_length"])

    def on_train_epoch_start(self, dataloader):
        """Apply the SortaGrad shuffle policy and return the dataloader to use.

        With ``sorta_grad`` enabled, the dataloader is rebuilt without shuffling
        before epoch 1 and rebuilt with shuffling at epoch 1; later epochs keep
        whatever dataloader the parent hook returned.
        """
        dataloader = super().on_train_epoch_start(dataloader)
        if not self.sorta_grad:
            return dataloader

        if self.current_epoch < 1:
            self.logger.info(f"SortaGrad {self.sorta_grad} and epoch is {self.current_epoch}, "
                             f"shuffle is False only for this single epoch.")
            dataloader = PadoDataLoader.from_other(dataloader, override_kwargs={"shuffle": False})
            # The collate function is set again on the rebuilt dataloader.
            dataloader.set_collate_fn(PadoLibriSpeech.collate_fn)
        elif self.current_epoch == 1:
            self.logger.info(f"SortaGrad {self.sorta_grad} and epoch is {self.current_epoch}, "
                             f"shuffle is set to True.")
            dataloader = PadoDataLoader.from_other(dataloader, override_kwargs={"shuffle": True})
            dataloader.set_collate_fn(PadoLibriSpeech.collate_fn)
        else:
            self.logger.info(f"SortaGrad {self.sorta_grad} and epoch is {self.current_epoch}, "
                             f"shuffle is set to {dataloader.init_kwargs['shuffle']}.")
        return dataloader

    def on_train_iter_start(self) -> None:
        """Re-apply batch-norm freezing and, once past the start iteration, weight noise."""
        if self.bn_freeze:
            freeze_bn(self.model)
        if (self.variational_noise_start_iter > 0) and (self.current_iter > self.variational_noise_start_iter):
            set_weight_variational_noise(self.model, noise=self.variational_noise_std)

    def train_iter_body(self, batch: Any,
                        num_iter: int, dataloader_length: int) -> Dict[str, Any]:
        """Run the model on one training batch.

        Decoding is comparatively expensive, so metrics are only computed when
        verbose output is enabled or the iteration is going to be printed.
        """
        feat, feat_len, tgt, tgt_len, scripts, paths = batch
        output = self.model(feat, feat_len, tgt, tgt_len, scripts, paths,
                            compute_metric=self.verbose_metric["train"] or self.is_print_iter(num_iter),
                            verbose_metric=self.verbose_metric["train"],
                            verbose_info=f"train {num_iter} /{dataloader_length}",
                            is_train=True,
                            beam_width=self.beam_width["train"])
        return output

    def train_iter_track(self, output: Dict[str, Any]) -> None:
        """Accumulate the sample count, loss and CER/WER statistics of a training batch."""
        num_samples = output["num_samples"]
        self.static_tracker.update_add({"num_samples": num_samples})
        self.epoch_tracker.update_add({
            # The per-sample loss is scaled back by the batch size and passed with its count.
            "loss": (output["loss"] * num_samples, num_samples),
            "cer_distance": output["cer_distance"],
            "cer_length": output["cer_length"],
            "wer_distance": output["wer_distance"],
            "wer_length": output["wer_length"],
        })

    def valid_iter_body(self, batch: Any,
                        num_iter: int, dataloader_length: int) -> Dict[str, Any]:
        """Run the model on one validation batch, always computing CER/WER.

        The epoch-level validation WER drives ``valid_update_best``, so metrics
        must be computed regardless of the verbose flag; otherwise the tracked
        WER would stay at 0 for the whole epoch.
        """
        feat, feat_len, tgt, tgt_len, scripts, paths = batch
        output = self.model(feat, feat_len, tgt, tgt_len, scripts, paths,
                            compute_metric=True,
                            verbose_metric=self.verbose_metric["valid"],
                            verbose_info=f"valid {num_iter} /{dataloader_length}",
                            is_train=False,
                            beam_width=self.beam_width["valid"])
        return output

    def valid_iter_track(self, output: Dict[str, Any]) -> None:
        """Accumulate the loss and CER/WER statistics of a validation batch."""
        num_samples = output["num_samples"]
        self.epoch_tracker.update_add({
            "loss": (output["loss"] * num_samples, num_samples),
            "cer_distance": output["cer_distance"],
            "cer_length": output["cer_length"],
            "wer_distance": output["wer_distance"],
            "wer_length": output["wer_length"],
        })

    def test_iter_body(self, batch: Any,
                       num_iter: int, dataloader_length: int) -> Dict[str, Any]:
        """Run the model on one test batch, always computing CER/WER.

        The reported test CER/WER are aggregated from every batch, so metrics
        are computed independently of the verbose flag.
        """
        feat, feat_len, tgt, tgt_len, scripts, paths = batch
        output = self.model(feat, feat_len, tgt, tgt_len, scripts, paths,
                            compute_metric=True,
                            verbose_metric=self.verbose_metric["test"],
                            verbose_info=f"test {num_iter} /{dataloader_length}",
                            is_train=False,
                            beam_width=self.beam_width["test"])
        return output

    def test_iter_track(self, output: Dict[str, Any]) -> None:
        """Track a test batch exactly like a validation batch."""
        return self.valid_iter_track(output)

    def train_iter_log_and_print(self, output: Dict[str, Any],
                                 num_iter: int, dataloader_length: int,
                                 param_norm: torch.Tensor, grad_norm: torch.Tensor) -> None:
        """Print a training progress report and/or log scalars to wandb."""
        cer = self._error_rate(output["cer_distance"], output["cer_length"])
        wer = self._error_rate(output["wer_distance"], output["wer_length"])
        cer_avg, wer_avg = self._epoch_error_rates()

        if self.is_print_iter(num_iter):
            s = f"... train iter {num_iter} / {dataloader_length} ({100 * num_iter / dataloader_length:.2f} %) "
            s += f"(epoch: {self.current_epoch} / {self.max_epochs}, iters: {self.current_iter} / {self.max_iters})\n"
            s += f"... loss(batch/avg): {output['loss'].item():.6f} / {self.epoch_tracker.avg('loss'):.6f}\n"
            s += f"... CER(batch/avg): {cer:.4f} / {cer_avg:.4f}\n"
            s += f"... WER(batch/avg): {wer:.4f} / {wer_avg:.4f}\n"
            s += f"... grad/param norm: {grad_norm.item():.6f} / {param_norm.item():.6f}\n"
            s += f"... data/fwd/bwd/step: " \
                 f"{self.time_tracker.get('data'):.4f} / {self.time_tracker.get('forward'):.4f} / " \
                 f"{self.time_tracker.get('backward'):.4f} / {self.time_tracker.get('step'):.4f}\n"
            s += f"...... LR: {self.current_lrs()[0]:.6f}\n"
            s += f"...... batch_size: {output['num_samples']} " \
                 f"(per-device: {output['num_samples'] // self.world_size}, acc-batches: {self.acc_num_batches})"
            if self.fp16:
                s += f"\n...... (gradient scaler: {self.accelerator.scaler.get_scale()})"
            self.logger.info(s)

        if self.is_log_iter(num_iter):
            # NOTE(review): train metrics are only decoded on verbose/print iterations
            # (see train_iter_body). On a log iteration where they were skipped,
            # train_cer/train_wer are logged as 0 — confirm the log and print
            # intervals are aligned.
            wandb.log({
                "train_loss": output['loss'].item(),
                "train_cer": cer,
                "train_wer": wer,
                "grad_norm": grad_norm.item(),
                "param_norm": param_norm.item(),
                "lr": self.current_lrs()[0],
                "iterations": self.current_iter,
                "num_samples": self.static_tracker.value('num_samples')
            })

    def valid_epoch_log_and_print(self) -> None:
        """Log the epoch-level validation loss, CER and WER to the logger and wandb."""
        loss_avg = self.epoch_tracker.avg("loss")
        cer_avg, wer_avg = self._epoch_error_rates()

        s = f"... valid loss(avg): {loss_avg:.6f}\n"
        s += f"... valid CER(avg): {cer_avg:.4f}\n"
        s += f"... valid WER(avg): {wer_avg:.4f}"
        self.logger.info(s)

        wandb.log({
            "valid_loss": loss_avg,
            "valid_cer": cer_avg,
            "valid_wer": wer_avg,
            "epoch": self.current_epoch,
            "iterations": self.current_iter,
        })

    def valid_iter_print(self, output: Dict[str, Any],
                         num_iter: int, dataloader_length: int) -> None:
        """Print a per-batch progress report; shared by validation and test."""
        cer = self._error_rate(output["cer_distance"], output["cer_length"])
        wer = self._error_rate(output["wer_distance"], output["wer_length"])
        cer_avg, wer_avg = self._epoch_error_rates()

        s = f"... valid/test iter {num_iter} / {dataloader_length} ({100 * num_iter / dataloader_length:.2f} %) "
        s += f"(epoch: {self.current_epoch} / {self.max_epochs}, iters: {self.current_iter} / {self.max_iters})\n"
        s += f"... loss(batch/avg): {output['loss'].item():.6f} / {self.epoch_tracker.avg('loss'):.6f}\n"
        s += f"... CER(batch/avg): {cer:.4f} / {cer_avg:.4f}\n"
        s += f"... WER(batch/avg): {wer:.4f} / {wer_avg:.4f}\n"
        s += f"...... batch_size: {output['num_samples']} (per-device: {output['num_samples'] // self.world_size})"
        self.logger.info(s)

    def valid_update_best(self) -> bool:
        """Offer the epoch validation WER to the scheduler and report whether it is a new best.

        The scheduler must be in ``"min"`` mode because a lower WER is better.
        """
        wer_avg = self._error_rate(self.epoch_tracker.value("wer_distance"),
                                   self.epoch_tracker.value("wer_length"))
        assert self.scheduler.mode == "min"
        is_updated = self.scheduler.update_best(wer_avg)
        if is_updated and self.is_master:
            wandb.run.summary["best_valid_wer"] = wer_avg
        return is_updated

    def test_epoch_log_and_print(self) -> None:
        """Log the epoch-level test loss, CER and WER to the logger."""
        loss_avg = self.epoch_tracker.avg("loss")
        cer_avg, wer_avg = self._epoch_error_rates()

        s = f"... test loss(avg): {loss_avg:.6f}\n"
        s += f"... test CER(avg): {cer_avg:.4f}\n"
        s += f"... test WER(avg): {wer_avg:.4f}"
        self.logger.info(s)

    def test_iter_print(self, output: Dict[str, Any],
                        num_iter: int, dataloader_length: int) -> None:
        """Print a test batch report using the validation formatter."""
        return self.valid_iter_print(output, num_iter, dataloader_length)


def run(cfg: DictConfig):
    """Assemble data, model, optimizer and trainer from ``cfg``, then train or test.

    When ``cfg["run_type"]`` is ``"train"`` the train/valid pipelines, the
    optimizer and the scheduler are built and ``trainer.train`` is called; for
    any other value only the test pipeline is built and ``trainer.test`` runs.
    """
    # ---------------------------------- common ---------------------------------- #
    accelerator = PadoAccelerator(**OmegaConf.to_container(cfg["accelerator"], resolve=True))

    run_type = cfg["run_type"]
    is_train = run_type == "train"
    save_dir = set_wandb(cfg)

    logging_kwargs = OmegaConf.to_container(cfg["logging"], resolve=True)
    logger = set_logger(save_dir, accelerator.local_rank, accelerator.world_size, **logging_kwargs)

    # Dump the fully resolved config so the run can be reproduced from the log.
    logger.info(OmegaConf.to_yaml(cfg, resolve=True))

    # ----------------------------------- data ----------------------------------- #
    # The test-time transforms and tokenizer are shared with validation.
    eval_transform = build_transforms(cfg["transforms"]["test"])
    eval_tokenizer = SentencePieceTokenizer.from_config(cfg["target_transforms"]["test"])

    test_set = build_dataset(cfg["datasets"]["test"], eval_transform, eval_tokenizer)
    test_loader = PadoDataLoader.from_config(cfg["dataloaders"]["test"], test_set)
    test_loader.set_collate_fn(PadoLibriSpeech.collate_fn)  # manual set collate_fn for ConcatDataset

    train_loader = valid_loader = None
    if is_train:
        train_transform = build_transforms(cfg["transforms"]["train"])
        train_tokenizer = SentencePieceTokenizer.from_config(cfg["target_transforms"]["train"])

        train_set = build_datasets(cfg["datasets"]["train"], train_transform, train_tokenizer)
        valid_set = build_dataset(cfg["datasets"]["valid"], eval_transform, eval_tokenizer)

        train_loader = PadoDataLoader.from_config(cfg["dataloaders"]["train"], train_set)
        # NOTE(review): validation reuses the "test" dataloader config — confirm this is intended.
        valid_loader = PadoDataLoader.from_config(cfg["dataloaders"]["test"], valid_set)

        for loader in (train_loader, valid_loader):
            loader.set_collate_fn(PadoLibriSpeech.collate_fn)

    # ---------------------------------- network --------------------------------- #
    loss_type = cfg["loss_type"].lower()
    network = ReuseAttnConformerCTC(cfg["model"])

    model = ReuseAttnConformerWrapper(network, logger, eval_tokenizer, loss_type, cfg["loss"], cfg["metric"])
    model.to(accelerator.device)

    additional = cfg["additional"]
    bn_frozen = additional["freeze_bn"]
    if bn_frozen:
        freeze_bn(model)  # before pass to accelerator

    # --------------------------- optimizer / scheduler --------------------------- #
    optimizer = scheduler = None
    if is_train:
        if bn_frozen:
            # Parameters whose name contains "bn" are left out of the optimizer.
            # NOTE(review): this is a substring match on the parameter name — verify
            # it matches exactly the batch-norm parameters of the network.
            trainable = [param for name, param in network.named_parameters() if "bn" not in name]
            param_groups = split_params_for_weight_decay(trainable)
        else:
            param_groups = split_params_for_weight_decay(network.parameters())

        optimizer = build_optimizer(cfg["optimizer"], param_groups)
        scheduler = build_scheduler(cfg["scheduler"], optimizer)

    # ---------------------------------- trainer --------------------------------- #
    trainer = ReuseAttnConformerTrainer(
        model=model,
        accelerator=accelerator,
        trainer_cfg=OmegaConf.to_container(cfg["trainer"], resolve=True),
        beam_width=additional["beam_width"],
        verbose_metric=additional["verbose_metric"],
        save_dir=save_dir,
        logger=logger,
        train_dataloader=train_loader,
        valid_dataloader=valid_loader,
        test_dataloader=test_loader,
        optimizer=optimizer,
        scheduler=scheduler,
        sorta_grad=additional["sorta_grad"],
        bn_freeze=bn_frozen,
        variational_noise_start_iter=additional["variational_noise"],
        variational_noise_std=additional["variational_noise_std"],
    )
    trainer.resume(cfg["resume"])  # return value intentionally ignored

    # ------------------------------------ run ----------------------------------- #
    if is_train:
        trainer.train(train_loader, valid_loader, test_loader)
    else:
        trainer.test(test_loader)

    # ---------------------------------- finish ---------------------------------- #
    accelerator.wait_for_everyone()
    if accelerator.is_master:
        wandb.run.finish()


def main():
    """Parse the command line, load the config, apply CLI overrides and call ``run``."""
    cli_args = default_parser().parse_args()

    # The config named by ``--config`` is loaded first; script arguments then override it.
    config = override_config_by_cli(load_config(cli_args.config), cli_args.script_args)
    run(config)


# Run the CLI entry point only when this file is executed as a script.
if __name__ == '__main__':
    main()
