# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import typing as tp

import flashy
import julius
import omegaconf
import torch
import torch.nn.functional as F

from .. import models
from ..metrics import RelativeVolumeMel
from ..models.builders import get_processor
from ..modules.diffusion_schedule import NoiseSchedule
from ..solvers.compression import CompressionSolver
from ..utils.samples.manager import SampleManager
from . import base, builders


class PerStageMetrics:
    """Report losses per range ("stage") of diffusion steps.

    The ``num_steps`` diffusion steps are split into ``num_stages`` equal
    ranges and each loss is reported for the range its step falls in,
    e.g. the average loss when t is in [250, 500) for ``num_steps=1000``
    and ``num_stages=4``.

    Args:
        num_steps (int): Total number of diffusion steps.
        num_stages (int): Number of step ranges to report metrics for.
    """

    def __init__(self, num_steps: int, num_stages: int = 4):
        self.num_steps = num_steps
        self.num_stages = num_stages

    def __call__(self, losses: dict, step: tp.Union[int, torch.Tensor]):
        """Re-key (and, for per-item steps, average) ``losses`` by stage.

        Args:
            losses (dict): Mapping name -> loss. When ``step`` is a tensor, each
                loss is presumably a per-item tensor aligned with ``step``
                (TODO confirm against the schedule).
            step (int or torch.Tensor): One diffusion step for the whole batch,
                or a tensor holding one step per item.
        Returns:
            dict: For an int ``step``, ``{f"{name}_{stage}": loss}``. For a tensor
            ``step``, the mean loss over the items falling in each stage, keyed
            ``f"{name}_{stage_idx}"``; stages containing no item are omitted.
        Raises:
            TypeError: If ``step`` is neither an int nor a torch.Tensor.
        """
        # isinstance (not `type(...) is`) so subclasses such as nn.Parameter
        # or int subclasses are handled instead of silently returning None.
        if isinstance(step, torch.Tensor):
            stage_tensor = ((step / self.num_steps) * self.num_stages).long()
            out: tp.Dict[str, torch.Tensor] = {}
            for stage_idx in range(self.num_stages):
                mask = stage_tensor == stage_idx
                count = mask.sum()
                if count == 0:  # no item in this stage: nothing to report
                    continue
                for name, loss in losses.items():
                    # Masked mean: only items whose step is in this stage count.
                    out[f"{name}_{stage_idx}"] = (mask * loss).sum() / count
            return out
        if isinstance(step, int):
            stage = int((step / self.num_steps) * self.num_stages)
            return {f"{name}_{stage}": loss for name, loss in losses.items()}
        raise TypeError(
            f"step must be an int or a torch.Tensor, got {type(step).__name__}"
        )


class DataProcess:
    """Apply optional boosting, band filtering and resampling to waveforms.

    Args:
        initial_sr (int): Sample rate of the dataset (input waveforms).
        target_sr (int): Sample rate after resampling.
        use_resampling (bool): Whether to resample from ``initial_sr`` to ``target_sr``.
        use_filter (bool): When True, filter the data to keep only one frequency band.
        n_bands (int): Number of bands used when ``cutoffs`` is None.
        idx_band (int): Index of the kept frequency band: 0 is the lowest,
            ``n_bands - 1`` the highest.
        device (torch.device or str): Device the band-splitting filter is moved to.
        cutoffs (list or None): Cutoff frequencies of the band filtering.
            If None, ``n_bands`` mel-scale bands are used.
        boost (bool): Rescale each item to a standard deviation of 0.22 to
            match the scale of our music dataset.
    """

    def __init__(
        self,
        initial_sr: int = 24000,
        target_sr: int = 16000,
        use_resampling: bool = False,
        use_filter: bool = False,
        n_bands: int = 4,
        idx_band: int = 0,
        device: torch.device = torch.device("cpu"),
        cutoffs=None,
        boost=False,
    ):
        assert idx_band < n_bands
        self.idx_band = idx_band
        if use_filter:
            if cutoffs is not None:
                self.filter = julius.SplitBands(
                    sample_rate=initial_sr, cutoffs=cutoffs
                ).to(device)
            else:
                self.filter = julius.SplitBands(
                    sample_rate=initial_sr, n_bands=n_bands
                ).to(device)
        self.use_filter = use_filter
        self.use_resampling = use_resampling
        self.target_sr = target_sr
        self.initial_sr = initial_sr
        self.boost = boost

    def process_data(self, x, metric=False):
        """Boost, band-filter and resample ``x`` according to the configuration.

        Args:
            x (torch.Tensor or None): Waveform batch; presumably [B, C, T]
                (boost normalizes over dims 1 and 2 per item).
            metric (bool): If True, skip the band filtering step.
        Returns:
            The processed tensor, or None when ``x`` is None. The caller's
            tensor is never modified in place.
        """
        if x is None:
            return None
        if self.boost:
            # Out-of-place: the caller's tensor must not be rescaled behind its back.
            x = x / torch.clamp(x.std(dim=(1, 2), keepdim=True), min=1e-4)
            x = x * 0.22
        if self.use_filter and not metric:
            x = self.filter(x)[self.idx_band]
        if self.use_resampling:
            x = julius.resample_frac(x, old_sr=self.initial_sr, new_sr=self.target_sr)
        return x

    def inverse_process(self, x):
        """Undo the resampling of ``process_data`` (upsampling only).

        Band filtering and boosting are not inverted.
        """
        if self.use_resampling:
            # Go back from the processing rate to the dataset rate.
            x = julius.resample_frac(x, old_sr=self.target_sr, new_sr=self.initial_sr)
        return x


class DiffusionSolver(base.StandardSolver):
    """Solver for the diffusion task.

    The diffusion task allows for MultiBand diffusion model training: the
    diffusion model regenerates waveforms conditioned on the latent embeddings
    of a compression model loaded from ``cfg.compression_model_checkpoint``.

    Args:
        cfg (DictConfig): Configuration.
    """

    def __init__(self, cfg: omegaconf.DictConfig):
        super().__init__(cfg)
        self.cfg = cfg
        self.device = cfg.device
        self.sample_rate: int = self.cfg.sample_rate
        # Codec providing the conditioning latents (see `get_condition`).
        self.codec_model = CompressionSolver.model_from_checkpoint(
            cfg.compression_model_checkpoint, device=self.device
        )

        self.codec_model.set_num_codebooks(cfg.n_q)
        # `self.sample_rate` is `cfg.sample_rate`, so one check covers both.
        assert self.codec_model.sample_rate == self.cfg.sample_rate, (
            f"Codec model sample rate is {self.codec_model.sample_rate} but "
            f"Solver sample rate is {self.cfg.sample_rate}."
        )

        self.sample_processor = get_processor(
            cfg.processor, sample_rate=self.sample_rate
        )
        self.register_stateful("sample_processor")
        self.sample_processor.to(self.device)

        self.schedule = NoiseSchedule(
            **cfg.schedule, device=self.device, sample_processor=self.sample_processor
        )

        self.eval_metric: tp.Optional[torch.nn.Module] = None

        self.rvm = RelativeVolumeMel()
        self.data_processor = DataProcess(
            initial_sr=self.sample_rate,
            target_sr=cfg.resampling.target_sr,
            use_resampling=cfg.resampling.use,
            cutoffs=cfg.filter.cutoffs,
            use_filter=cfg.filter.use,
            n_bands=cfg.filter.n_bands,
            idx_band=cfg.filter.idx_band,
            device=self.device,
        )

    @property
    def best_metric_name(self) -> tp.Optional[str]:
        """Metric used to select the best state: "rvm" when evaluating, else "loss"."""
        if self._current_stage == "evaluate":
            return "rvm"
        else:
            return "loss"

    @torch.no_grad()
    def get_condition(self, wav: torch.Tensor) -> torch.Tensor:
        """Encode ``wav`` with the codec and return its decoded latent embedding.

        Raises:
            AssertionError: If the codec returns a scale (scaled models unsupported).
        """
        codes, scale = self.codec_model.encode(wav)
        assert scale is None, "Scaled compression models not supported."
        emb = self.codec_model.decode_latent(codes)
        return emb

    def build_model(self):
        """Build model and optimizer as well as optional Exponential Moving Average of the model."""
        # Model and optimizer
        self.model = models.builders.get_diffusion_model(self.cfg).to(self.device)
        self.optimizer = builders.get_optimizer(self.model.parameters(), self.cfg.optim)
        self.register_stateful("model", "optimizer")
        self.register_best_state("model")
        self.register_ema("model")

    def build_dataloaders(self):
        """Build audio dataloaders for each stage."""
        self.dataloaders = builders.get_audio_datasets(self.cfg)

    def show(self):
        # TODO
        raise NotImplementedError()

    def run_step(self, idx: int, batch: torch.Tensor, metrics: dict):
        """Perform one training or valid step on a given batch.

        The model receives ``input_`` and ``step`` from
        ``schedule.get_training_item`` (plus the codec conditioning) and its
        output is compared with ``target``. The per-item loss is divided by the
        loss of ``input_`` against ``target`` raised to ``cfg.loss.norm_power``.

        Args:
            idx (int): Index of the batch in the epoch (unused here).
            batch (torch.Tensor): Batch of waveforms.
            metrics (dict): Incoming metrics (unused; a new dict is returned).
        Returns:
            dict: Mean losses, per-stage losses (from ``self.per_stage``, set in
            ``run_epoch``) and the standard deviations of model input and output.
        """
        x = batch.to(self.device)
        loss_fun = F.mse_loss if self.cfg.loss.kind == "mse" else F.l1_loss

        condition = self.get_condition(x)  # [bs, 128, T/hop, n_emb]
        sample = self.data_processor.process_data(x)

        input_, target, step = self.schedule.get_training_item(
            sample, tensor_step=self.cfg.schedule.variable_step_batch
        )
        out = self.model(input_, step, condition=condition).sample

        # Per-item losses: averaged over dims 1 and 2, batch dim kept.
        base_loss = loss_fun(out, target, reduction="none").mean(dim=(1, 2))
        reference_loss = loss_fun(input_, target, reduction="none").mean(dim=(1, 2))
        normed_loss = base_loss / reference_loss
        loss = base_loss / reference_loss**self.cfg.loss.norm_power

        if self.is_training:
            loss.mean().backward()
            flashy.distrib.sync_model(self.model)
            self.optimizer.step()
            self.optimizer.zero_grad()
        step_metrics = {
            "loss": loss.mean(),
            "normed_loss": normed_loss.mean(),
        }
        step_metrics.update(
            self.per_stage({"loss": loss, "normed_loss": normed_loss}, step)
        )
        step_metrics.update({"std_in": input_.std(), "std_out": out.std()})
        return step_metrics

    def run_epoch(self):
        """Reset the epoch RNG and per-stage metric helper, then run the epoch."""
        # reset random seed at the beginning of the epoch
        self.rng = torch.Generator()
        self.rng.manual_seed(1234 + self.epoch)
        self.per_stage = PerStageMetrics(
            self.schedule.num_steps, self.cfg.metrics.num_stage
        )
        # run epoch
        super().run_epoch()

    def evaluate(self):
        """Evaluate stage.

        Runs audio reconstruction evaluation: every batch of the "evaluate"
        loader is regenerated and scored with RelativeVolumeMel. The returned
        metrics are the mean over batches, averaged across workers.
        """
        self.model.eval()
        evaluate_stage_name = f"{self.current_stage}"
        loader = self.dataloaders["evaluate"]
        updates = len(loader)
        lp = self.log_progress(
            f"{evaluate_stage_name} estimate",
            loader,
            total=updates,
            updates=self.log_updates,
        )

        metrics: dict = {}
        n = 0  # number of batches already averaged into `metrics`
        for batch in lp:
            x = batch.to(self.device)
            with torch.no_grad():
                y_pred = self.regenerate(x)

            y_pred = y_pred.cpu()
            y = batch.cpu()  # should already be on CPU but just in case
            rvm = self.rvm(y_pred, y)
            lp.update(**rvm)
            if not metrics:
                metrics = dict(rvm)
            else:
                # Incremental mean: fold this batch into the mean of the previous n.
                for key, value in rvm.items():
                    metrics[key] = (metrics[key] * n + value) / (n + 1)
            n += 1
        metrics = flashy.distrib.average_metrics(metrics)
        return metrics

    @torch.no_grad()
    def regenerate(self, wav: torch.Tensor, step_list: tp.Optional[list] = None):
        """Regenerate the given waveform.

        The codec conditioning is computed from ``wav`` at the solver sample
        rate, generation runs on the processed (possibly filtered/resampled)
        signal, and the result is passed through ``data_processor.inverse_process``.
        """
        condition = self.get_condition(wav)
        initial = self.schedule.get_initial_noise(
            self.data_processor.process_data(wav)
        )  # sampling rate changes.
        result = self.schedule.generate_subsampled(
            self.model, initial=initial, condition=condition, step_list=step_list
        )
        result = self.data_processor.inverse_process(result)
        return result

    def generate(self):
        """Generate stage.

        Regenerates each batch of the "generate" loader and stores the estimates
        with their references through a SampleManager.
        """
        sample_manager = SampleManager(self.xp)
        self.model.eval()
        generate_stage_name = f"{self.current_stage}"

        loader = self.dataloaders["generate"]
        updates = len(loader)
        lp = self.log_progress(
            generate_stage_name, loader, total=updates, updates=self.log_updates
        )

        for batch in lp:
            # Batches of this loader are pairs; only the waveform is used here.
            reference, _ = batch
            reference = reference.to(self.device)
            estimate = self.regenerate(reference)
            reference = reference.cpu()
            estimate = estimate.cpu()
            sample_manager.add_samples(
                estimate, self.epoch, ground_truth_wavs=reference
            )
        flashy.distrib.barrier()