# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import logging
import multiprocessing
import typing as tp
from pathlib import Path

import flashy
import omegaconf
import torch
from torch import nn

from .. import models, quantization
from ..utils import checkpoint
from ..utils.samples.manager import SampleManager
from ..utils.utils import get_pool_executor
from . import builders, base
from .compression import CompressionSolver, evaluate_audio_reconstruction

logger = logging.getLogger(__name__)


class MReQSolver(CompressionSolver):
    def __init__(self, cfg: omegaconf.DictConfig):
        """Set up losses, the loss balancer and the weights taken from the base codec.

        Args:
            cfg (omegaconf.DictConfig): Solver configuration. This method reads
                `fsdp.use`, `losses`, `balancer`, `base_encodec_path` and
                `init_autoencoder_from_base` from it.
        """
        # Initialize through base.StandardSolver directly (not via super().__init__);
        # the loss setup is done below.
        base.StandardSolver.__init__(self, cfg)
        self.rng: torch.Generator  # set at each epoch
        self.adv_losses = nn.ModuleDict()
        self.aux_losses = nn.ModuleDict()
        self.info_losses = nn.ModuleDict()
        assert not cfg.fsdp.use, "FSDP not supported by CompressionSolver."
        loss_weights = dict()
        for loss_name, weight in self.cfg.losses.items():
            if loss_name in ["adv", "feat"]:
                if weight > 0:
                    # Adversarial losses are built (and registered) only once,
                    # whichever of "adv" / "feat" is seen first.
                    if len(self.adv_losses) == 0:
                        self.adv_losses = builders.get_adversarial_losses(self.cfg)
                        self.register_stateful("adv_losses")
                    for adv_name in self.adv_losses.keys():
                        loss_weights[f"{loss_name}_{adv_name}"] = weight
            elif weight > 0:
                # Positive weight: the loss is trained on through the balancer.
                self.aux_losses[loss_name] = builders.get_loss(loss_name, self.cfg)
                loss_weights[loss_name] = weight
            else:
                # Non-positive weight: the loss is only computed for reporting.
                self.info_losses[loss_name] = builders.get_loss(loss_name, self.cfg)
        self.balancer = builders.get_balancer(loss_weights, self.cfg.balancer)

        # Load the base codec checkpoint once and reuse it for both the base
        # codec weights and the optional encoder/decoder initialization below.
        state = self.load_base_codec_checkpoint_state(self.cfg.base_encodec_path)
        base_model_state = state["best_state"]["model"]
        self.model.base_encodec.load_state_dict(base_model_state)

        self.l1_hidden_loss = nn.L1Loss()

        if self.cfg.init_autoencoder_from_base is True:
            assert self.model.encoder is not None
            new_state = self.model.state_dict()
            for k, v in base_model_state.items():
                # only use encoder & decoder
                if k.startswith(("encoder", "decoder")):
                    new_state[k] = v
            self.model.load_state_dict(new_state)
            self.logger.info("Loaded base codec checkpoint for new encoder & decoder.")

    def run_step(self, idx: int, batch: torch.Tensor, metrics: dict):
        """Run one training or validation step on a batch and fill `metrics`.

        Args:
            idx (int): Index of the step; not used by the computation below.
            batch (torch.Tensor): Input batch, moved to `self.device`.
            metrics (dict): Dictionary updated in place with losses and statistics.
        Returns:
            dict: The same `metrics` dictionary that was passed in.
        """
        self.model.base_encodec.quantizer.eval()
        x = batch.to(self.device)

        # Targets are obtained from the model without tracking gradients.
        with torch.no_grad():
            target_hiddens, targets, scale = self.model.get_target(x)
            target = targets[-1]

        (
            ae_inputs, ae_outputs,
            main_qreses, pre_qreses, post_qreses,
            hidden_preds, prediction,
        ) = self.model(x)
        all_qreses = main_qreses + pre_qreses + post_qreses
        for qres in all_qreses:
            assert isinstance(qres, quantization.QuantizedResult)

        # Log bandwidth in kb/s
        for level, qres in enumerate(main_qreses):
            metrics[f"bandwidth_{level}"] = qres.bandwidth.mean()
        for level, qres in enumerate(pre_qreses):
            metrics[f"sub_pre_bandwidth_{level}"] = qres.bandwidth.mean()
        for level, qres in enumerate(post_qreses):
            metrics[f"sub_post_bandwidth_{level}"] = qres.bandwidth.mean()

        def squared_grad_norm():
            # Sum of squared L2 norms of all gradients currently set on the model.
            return sum(
                param.grad.data.norm(p=2).pow(2)
                for param in self.model.parameters()
                if param.grad is not None
            )

        if self.is_training:
            disc_losses: dict = {}
            if len(self.adv_losses) > 0:
                # The random draw only happens when adversaries exist, so the
                # generator state advances exactly as often as before.
                draw = torch.rand(1, generator=self.rng).item()
                if draw <= 1 / self.cfg.adversarial.every:
                    for adv_name, adversary in self.adv_losses.items():
                        disc_losses[f"d_{adv_name}"] = adversary.train_adv(
                            prediction, target
                        )
                    metrics["d_loss"] = torch.sum(
                        torch.stack(list(disc_losses.values()))
                    )
            metrics.update(disc_losses)

        balanced: dict = {}
        unbalanced: dict = {}

        # penalty from quantization
        for level, qres in enumerate(main_qreses):
            if qres.penalty is not None and qres.penalty.requires_grad:
                unbalanced[f"penalty_{level}"] = qres.penalty
        for level, qres in enumerate(pre_qreses):
            if qres.penalty is not None and qres.penalty.requires_grad:
                unbalanced[f"sub_pre_penalty_{level}"] = qres.penalty
        for level, qres in enumerate(post_qreses):
            if qres.penalty is not None and qres.penalty.requires_grad:
                unbalanced[f"sub_post_penalty_{level}"] = qres.penalty

        # hidden teacher forcing loss
        for level, (hidden_pred, hidden_target) in enumerate(
            zip(hidden_preds, target_hiddens)
        ):
            unbalanced[f"l1_hidden_tf_{level}"] = self.l1_hidden_loss(
                hidden_pred, hidden_target
            )

        # hidden reconstruction loss (skipped where the sub-encoder is an Identity)
        for level, (hidden_in, hidden_out) in enumerate(zip(ae_inputs, ae_outputs)):
            if not isinstance(self.model.sub_encoders[level], nn.Identity):
                unbalanced[f"l1_hidden_rec_{level}"] = self.l1_hidden_loss(
                    hidden_out, hidden_in.detach()
                )

        # adversarial losses
        for adv_name, adversary in self.adv_losses.items():
            adv_term, feat_term = adversary(prediction, target)
            balanced[f"adv_{adv_name}"] = adv_term
            balanced[f"feat_{adv_name}"] = feat_term

        # auxiliary losses
        for loss_name, criterion in self.aux_losses.items():
            balanced[loss_name] = criterion(prediction, target)

        metrics.update(balanced)
        metrics.update(unbalanced)
        for qres in all_qreses:
            metrics.update(qres.metrics)

        if self.is_training:
            # backprop losses that are not handled by balancer
            other_loss = torch.tensor(0.0, device=self.device)
            for name, value in unbalanced.items():
                coeff = 1.0
                # A configured key applies when it is a substring of the loss
                # name; if several match, the last one wins.
                for pattern in self.cfg.losses_for_others.keys():
                    if pattern in name:
                        coeff = self.cfg.losses_for_others.get(pattern, 1.0)
                other_loss += coeff * value
            if other_loss.requires_grad:
                # The graph is retained because the balanced losses are
                # back-propagated right after this.
                other_loss.backward(retain_graph=True)
                ratio1 = squared_grad_norm()
                assert isinstance(ratio1, torch.Tensor)
                metrics["ratio1"] = ratio1.sqrt()

            if len(balanced) > 0:
                # balancer losses backward, returns effective training loss
                # with effective weights at the current batch.
                metrics["g_loss"] = self.balancer.backward(balanced, prediction)
                # add metrics corresponding to weight ratios
                metrics.update(self.balancer.metrics)
                ratio2 = squared_grad_norm()
                assert isinstance(ratio2, torch.Tensor)
                metrics["ratio2"] = ratio2.sqrt()

            # optim
            flashy.distrib.sync_model(self.model)
            if self.cfg.optim.max_norm:
                torch.nn.utils.clip_grad_norm_(
                    self.model.parameters(), self.cfg.optim.max_norm
                )
            self.optimizer.step()
            if self.lr_scheduler:
                self.lr_scheduler.step()
                metrics["lr"] = self.lr_scheduler.get_last_lr()[0]
            self.optimizer.zero_grad()

        # informative losses only
        info: dict = {}
        with torch.no_grad():
            for loss_name, criterion in self.info_losses.items():
                for level, (hidden_pred, level_target) in enumerate(
                    zip(hidden_preds, targets)
                ):
                    level_pred = self.model._decode_from_hidden(
                        hidden_pred, level_target.shape[-1], scale
                    )
                    info[f"{loss_name}_{level}"] = criterion(level_pred, level_target)
            for level, (hidden_in, post_qres) in enumerate(zip(ae_inputs, post_qreses)):
                if isinstance(self.model.sub_encoders[level], nn.Identity):
                    continue
                info[f"l1_hidden_rec_after_post_qres_{level}"] = self.l1_hidden_loss(
                    post_qres.x, hidden_in
                )

        metrics.update(info)

        # aggregated GAN losses: this is useful to report adv and feat across different adversarial loss setups
        adv_terms = [value for key, value in metrics.items() if key.startswith("adv")]
        if adv_terms:
            metrics["adv"] = torch.sum(torch.stack(adv_terms))
        feat_terms = [value for key, value in metrics.items() if key.startswith("feat")]
        if feat_terms:
            metrics["feat"] = torch.sum(torch.stack(feat_terms))

        return metrics

    def evaluate(self):
        """Evaluate stage. Runs audio reconstruction evaluation.

        For each batch of the "evaluate" dataloader and each stage index in
        `range(len(self.model.n_qs))`, the model reconstruction is compared to the
        base reconstruction by `evaluate_audio_reconstruction`, which is submitted
        to a pool of worker processes.

        Returns:
            dict: Metrics averaged over all submitted jobs and then passed through
                `flashy.distrib.average_metrics`.
        """
        self.model.eval()
        evaluate_stage_name = str(self.current_stage)

        loader = self.dataloaders["evaluate"]
        updates = len(loader)
        lp = self.log_progress(
            f"{evaluate_stage_name} inference",
            loader,
            total=updates,
            updates=self.log_updates,
        )
        average = flashy.averager()

        # Start from an empty result so that an empty loader does not leave
        # `metrics` unbound when it is used after the pool block.
        metrics: dict = {}
        pendings = []
        num_stages = len(self.model.n_qs)
        ctx = multiprocessing.get_context("spawn")
        with get_pool_executor(self.cfg.evaluate.num_workers, mp_context=ctx) as pool:
            for batch in lp:
                x = batch.to(self.device)
                with torch.no_grad():
                    for stage in range(num_stages):
                        # Tensors are moved to CPU before being handed to the workers.
                        y_pred = self.model.reconstruction(x, stage).cpu()
                        y = self.model.reconstruction_base(x, stage).cpu()

                        pendings.append(
                            pool.submit(evaluate_audio_reconstruction, y_pred, y, stage, self.cfg)
                        )

            metrics_lp = self.log_progress(
                f"{evaluate_stage_name} metrics", pendings, updates=self.log_updates
            )
            for pending in metrics_lp:
                # The averager returns the running average over everything seen so far.
                metrics = average(pending.result())

        metrics = flashy.distrib.average_metrics(metrics, updates)
        return metrics

    def generate(self):
        """Generate stage.

        For each batch of the "generate" dataloader, reconstructs the audio at
        every stage index in `range(len(self.model.n_qs))` with both the model
        and the base reconstruction, concatenates the results along dim 0 and
        hands them to the sample manager, using the base reconstructions as
        ground truth.
        """
        self.model.eval()
        sample_manager = SampleManager(self.xp, map_reference_to_sample_id=True)
        generate_stage_name = str(self.current_stage)

        loader = self.dataloaders["generate"]
        progress = self.log_progress(
            generate_stage_name, loader, total=len(loader), updates=self.log_updates
        )

        for batch in progress:
            # Each batch is a pair; only its first element is used.
            reference, _ = batch
            reference = reference.to(self.device)
            with torch.no_grad():
                # Each tuple is evaluated left to right, so the two
                # reconstructions stay interleaved per stage.
                per_stage = [
                    (
                        self.model.reconstruction(reference, stage).cpu(),
                        self.model.reconstruction_base(reference, stage).cpu(),
                    )
                    for stage in range(len(self.model.n_qs))
                ]
                estimates = torch.cat([pair[0] for pair in per_stage], dim=0)
                references = torch.cat([pair[1] for pair in per_stage], dim=0)

                sample_manager.add_samples(
                    estimates, self.epoch, ground_truth_wavs=references
                )

        flashy.distrib.barrier()

    def load_from_pretrained(self, name: str) -> dict:
        raise NotImplementedError()

    @staticmethod
    def wrapped_model_from_checkpoint(
        cfg: omegaconf.DictConfig,
        checkpoint_path: tp.Union[Path, str],
        device: tp.Union[torch.device, str] = "cpu",
    ) -> models.CompressionModel:
        """Instantiate a wrapped CompressionModel from a given checkpoint path or dora sig.

        Args:
            cfg (omegaconf.DictConfig): Configuration to read from for wrapped mode.
            checkpoint_path (Path or str): Path to checkpoint or dora sig from where the checkpoint is resolved.
            device (torch.device or str): Device on which the model is loaded.
        Returns:
            models.CompressionModel: The loaded model, wrapped by
                `models.builders.get_wrapped_compression_model`.
        """
        # Resolve `model_from_checkpoint` through this class; the previously
        # referenced `CompressionHierarchicalSolver2` is not defined in this module.
        # NOTE(review): assumes `model_from_checkpoint` is inherited from
        # CompressionSolver - confirm.
        compression_model = MReQSolver.model_from_checkpoint(checkpoint_path, device)
        compression_model = models.builders.get_wrapped_compression_model(
            compression_model, cfg
        )
        return compression_model

    def load_base_codec_checkpoint_state(
        self,
        checkpoint_path: tp.Union[Path, str],
    ) -> tp.Dict[str, tp.Any]:
        """Resolve a checkpoint path and load the checkpoint state stored there.

        Args:
            checkpoint_path (Path or str): Checkpoint location, resolved with
                `checkpoint.resolve_checkpoint_path` (FSDP disabled).
        Returns:
            dict: The loaded checkpoint state. It is a nested mapping (it
                contains at least the "xp.cfg" entry), not a flat dict of tensors.
        Raises:
            AssertionError: If the path cannot be resolved, or if the loaded
                state is None or has no "xp.cfg" entry.
        """
        checkpoint_path = str(checkpoint_path)
        # Uses the module-level logger rather than fetching the same logger again.
        logger.info(f"Loading compression model from checkpoint: {checkpoint_path}")
        _checkpoint_path = checkpoint.resolve_checkpoint_path(
            checkpoint_path, use_fsdp=False
        )
        assert (
            _checkpoint_path is not None
        ), f"Could not resolve compression model checkpoint path: {checkpoint_path}"
        state = checkpoint.load_checkpoint(_checkpoint_path)
        assert (
            state is not None and "xp.cfg" in state
        ), f"Could not load compression model from ckpt: {checkpoint_path}"
        return state

def evaluate_audio_reconstruction(
    y_pred: torch.Tensor, y: torch.Tensor, stage: int, cfg: omegaconf.DictConfig
) -> dict:
    """Audio reconstruction evaluation method that can be conveniently pickled.

    Args:
        y_pred (torch.Tensor): Audio to evaluate.
        y (torch.Tensor): Audio used as the comparison target.
        stage (int): Stage index, appended to each metric name.
        cfg (omegaconf.DictConfig): Configuration; `evaluate.metrics.visqol`
            decides whether ViSQOL is computed in addition to SI-SNR.
    Returns:
        dict: `sisnr_{stage}` and, when enabled, `visqol_{stage}`.
    """
    results = {}
    # ViSQOL is optional and only built when requested by the configuration.
    if cfg.evaluate.metrics.visqol:
        visqol_metric = builders.get_visqol(cfg.metrics.visqol)
        results[f"visqol_{stage}"] = visqol_metric(y_pred, y, cfg.sample_rate)
    sisnr_metric = builders.get_loss("sisnr", cfg)
    results[f"sisnr_{stage}"] = sisnr_metric(y_pred, y)
    return results