# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import math
import time
import typing as tp
import warnings

import flashy
import torch

from ..modules.conditioners import SegmentWithAttributes
from . import builders
from .valle_ar import ValleARSolver

# Suppress every UserWarning process-wide as soon as this module is imported.
# NOTE(review): the filter is global, so it also hides UserWarnings raised by
# unrelated libraries and by `warnings.warn(...)` calls that rely on the default
# category (which is UserWarning).
warnings.simplefilter("ignore", UserWarning)


class ValleNARSolver(ValleARSolver):
    DATASET_TYPE: builders.DatasetType = builders.DatasetType.SPEECH

    def _prepare_tokens_and_attributes(
        self,
        batch: tp.Tuple[torch.Tensor, tp.List[SegmentWithAttributes]],
        check_synchronization_points: bool = False,
    ) -> tp.Tuple[
        tp.Tuple[tp.Any, torch.Tensor],
        tp.Tuple[tp.Any, torch.Tensor],
        tp.List[int],
    ]:
        """Prepare an input batch for language model training.

        The audio is encoded into discrete tokens by the compression model, the text
        conditions are tokenized, and for every item a random-length audio prompt is
        drawn at the beginning of the sequence. Padded positions of the audio tokens are
        overwritten in place with `self.model.audio_pad_token_id`.

        Args:
            batch (tuple[torch.Tensor, list[SegmentWithAttributes]]): Input batch with audio tensor of shape [B, C, T]
                and corresponding metadata as SegmentWithAttributes (with B items).
            check_synchronization_points (bool): Whether to check for synchronization points slowing down training.
        Returns:
            Tokens (tuple): `(text_tokens, audio_tokens)`, with `text_tokens` the first element of the
                tokenized "text" condition and `audio_tokens` the tokens from the compression model,
                of shape [B, K, T_s], with B the batch size, K the number of codebooks,
                T_s the token timesteps.
            Masks (tuple): `(text_mask, audio_mask)`, with `text_mask` the second element of the tokenized
                "text" condition and `audio_mask` a boolean tensor with the same shape as `audio_tokens`,
                False on padded positions and on the prompt positions, True elsewhere.
            Prompt lengths (list[int]): Number of prompt token frames drawn for each item of the batch.
        """
        if self.model.training:
            # NOTE(review): this message states that `torch.no_grad()` was removed, yet the
            # compression model below is still run under `torch.no_grad()`; it is also emitted
            # with the default UserWarning category, which this module filters out at import time.
            warnings.warn(
                "Up to version 1.0.1, the _prepare_tokens_and_attributes was evaluated with `torch.no_grad()`. "
                "This is inconsistent with how model were trained in the MusicGen paper. We removed the "
                "`torch.no_grad()` in version 1.1.0. Small changes to the final performance are expected. "
                "Really sorry about that."
            )
        audio, infos = batch
        audio = audio.to(self.device)
        # No comma between the two f-strings: they must concatenate into a single message
        # rather than form a tuple.
        assert audio.size(0) == len(infos), (
            f"Mismatch between number of items in audio batch ({audio.size(0)})"
            f" and in metadata ({len(infos)})"
        )
        # prepare attributes
        attributes = []
        for info in infos:
            attr = info.to_condition_attributes()
            attr.wav = {}  # delete wav condition
            attributes.append(attr)
        tokenized = self.model.condition_provider.tokenize(attributes)

        # Now we should be synchronization free.
        if self.device == "cuda" and check_synchronization_points:
            torch.cuda.set_sync_debug_mode("warn")

        with torch.no_grad():
            audio_tokens, scale = self.compression_model.encode(audio)
            assert scale is None, "Scaled compression model not supported with LM."
        text_tokens, text_mask = tokenized["text"]

        # create a padding mask to hold valid vs invalid positions
        audio_mask = torch.ones_like(
            audio_tokens, dtype=torch.bool, device=audio_tokens.device
        )
        token_frame_rate = self.compression_model.frame_rate
        token_hop_length = (
            self.compression_model.sample_rate // self.compression_model.frame_rate
        )
        # Loop-invariant prompt settings: the sampling ratios and the absolute cap
        # (a duration from the config converted to a number of token frames).
        min_ratio = self.cfg.prompt_cond.min_ratio
        max_ratio = self.cfg.prompt_cond.max_ratio
        prompt_length_cap = math.ceil(
            self.cfg.prompt_cond.max_length * token_frame_rate
        )
        B, _, _ = audio_tokens.shape
        prompt_lengths: tp.List[int] = []
        for i in range(B):
            info = infos[i]
            n_samples = info.n_frames
            assert info.sample_rate == self.compression_model.sample_rate
            # take the last token generated from actual audio frames (non-padded audio)
            valid_tokens = math.ceil(n_samples / token_hop_length)
            audio_tokens[i, :, valid_tokens:] = self.model.audio_pad_token_id
            audio_mask[i, :, valid_tokens:] = 0

            # calc a mask for prompting
            min_prompt_length = int(valid_tokens * min_ratio)
            max_prompt_length = int(valid_tokens * max_ratio)
            if max_prompt_length > min_prompt_length:
                # The upper bound of `torch.randint` is exclusive.
                prompt_length = int(
                    torch.randint(
                        min_prompt_length,
                        max_prompt_length,
                        (1,),
                        generator=self.rng,
                    ).item()
                )
            else:
                # `torch.randint` raises on an empty range, which happens for very short
                # items (both bounds truncate to the same integer) or equal ratios.
                prompt_length = min_prompt_length
            prompt_length = min(prompt_length, prompt_length_cap)
            audio_mask[i, :, :prompt_length] = 0
            prompt_lengths.append(prompt_length)

        if self.device == "cuda" and check_synchronization_points:
            torch.cuda.set_sync_debug_mode("default")
        return (text_tokens, audio_tokens), (text_mask, audio_mask), prompt_lengths

    def run_step(
        self,
        idx: int,
        batch: tp.Tuple[torch.Tensor, tp.List[SegmentWithAttributes]],
        metrics: dict,
    ) -> dict:
        """Perform one training or valid step on a given batch.

        One codebook stage is drawn at random, the model predicts the audio tokens of that
        stage, and the cross entropy against the ground-truth tokens of the same stage is the
        loss. When training, a non-finite tensor in the metrics causes the optimizer update to
        be skipped.

        Args:
            idx (int): Index of the step. Synchronization points are only checked when
                `idx == 1` and the device is "cuda".
            batch (tuple[torch.Tensor, list[SegmentWithAttributes]]): Input batch, forwarded
                to `_prepare_tokens_and_attributes`.
            metrics (dict): Dictionary updated in place with the metrics of this step.
        Returns:
            dict: The updated `metrics`, or a new empty dict when the update was skipped
                because of non-finite values.
        """
        check_synchronization_points = idx == 1 and self.device == "cuda"
        tokens, masks, prompt_lengths = self._prepare_tokens_and_attributes(
            batch, check_synchronization_points
        )

        self.deadlock_detect.update("tokens_and_conditions")

        if check_synchronization_points:
            torch.cuda.set_sync_debug_mode("warn")

        with self.autocast:
            # Codebook index to train on, drawn in [1, num_codebooks - 1]
            # (the upper bound of `torch.randint` is exclusive, index 0 is never drawn).
            target_stage = int(
                torch.randint(
                    low=1,
                    high=self.compression_model.num_codebooks,
                    size=(1,),
                    generator=self.rng,
                ).item()
            )
            t_token, s_token = tokens
            model_output = self.model.compute_predictions(
                (t_token, s_token), target_stage, prompt_lengths
            )  # type: ignore
            # audio ce, restricted to the target stage (slicing keeps the codebook axis)
            s_logits = model_output.logits
            s_mask = masks[1][:, target_stage : target_stage + 1, :]
            s_target_token = s_token[:, target_stage : target_stage + 1, :]
            audio_ce, _ = self._compute_cross_entropy(s_logits, s_target_token, s_mask)
            loss = audio_ce
        self.deadlock_detect.update("loss")

        if check_synchronization_points:
            torch.cuda.set_sync_debug_mode("default")

        # The perplexity is reported under several keys; compute it only once.
        audio_ppl = torch.exp(audio_ce)
        metrics["audio_ce"] = audio_ce
        metrics["audio_ppl"] = audio_ppl
        metrics["ce"] = audio_ce
        metrics["ppl"] = audio_ppl
        metrics[f"audio_ce_q{target_stage + 1}"] = audio_ce
        metrics[f"audio_ppl_q{target_stage + 1}"] = audio_ppl

        if self.is_training:
            metrics["lr"] = self.optimizer.param_groups[0]["lr"]

            # Count the non-finite metrics of this worker.
            skip_update = torch.tensor([0], dtype=torch.float, device=self.device)
            for key, value in metrics.items():
                if isinstance(value, torch.Tensor) and not value.isfinite().all():
                    self.logger.warning(
                        f"Value of {key} is not finite. "
                        f"worldsize: {flashy.distrib.world_size()}, rank: {flashy.distrib.rank()}"
                    )
                    skip_update += 1
            self.deadlock_detect.update("update_check")
            # `average_tensors` works on an iterable of tensors (as for the model buffers
            # below), so the counter is wrapped in a list instead of being iterated itself.
            # Averaging the counter across workers lets them share the skip decision.
            flashy.distrib.average_tensors([skip_update])
            if skip_update.item() > 0:
                self.logger.warning(
                    "skip update because of non-finite values in the metrics."
                )
                metrics = {}
                torch.cuda.empty_cache()
            else:
                if self.scaler is not None:
                    loss = self.scaler.scale(loss)
                self.deadlock_detect.update("scale")
                if self.cfg.fsdp.use:
                    loss.backward()
                    flashy.distrib.average_tensors(self.model.buffers())
                elif self.cfg.optim.eager_sync:
                    with flashy.distrib.eager_sync_model(self.model):
                        loss.backward()
                else:
                    # this should always be slower but can be useful
                    # for weird use cases like multiple backwards.
                    loss.backward()
                    flashy.distrib.sync_model(self.model)
                self.deadlock_detect.update("backward")
                if self.scaler is not None:
                    # Gradients must be unscaled before their norm is clipped.
                    self.scaler.unscale_(self.optimizer)
                if self.cfg.optim.max_norm:
                    if self.cfg.fsdp.use:
                        metrics["grad_norm"] = self.model.clip_grad_norm_(self.cfg.optim.max_norm)  # type: ignore
                    else:
                        metrics["grad_norm"] = torch.nn.utils.clip_grad_norm_(
                            self.model.parameters(), self.cfg.optim.max_norm
                        )
                if self.scaler is None:
                    self.optimizer.step()
                else:
                    self.scaler.step(self.optimizer)
                    self.scaler.update()
                if self.lr_scheduler:
                    self.lr_scheduler.step()
                self.optimizer.zero_grad()
                self.deadlock_detect.update("optim")
                if self.scaler is not None:
                    scale = self.scaler.get_scale()
                    metrics["grad_scale"] = scale

        return metrics

    @torch.no_grad()
    def run_generate_step(
        self,
        batch: tp.Tuple[torch.Tensor, tp.List[SegmentWithAttributes]],
        gen_duration: float,
        prompt_duration: tp.Optional[float] = None,
        **generation_params,
    ) -> dict:
        """Run a generation step on a batch, prompted with the beginning of the reference audio.

        The reference audio is cropped to `gen_duration`, encoded into tokens by the compression
        model, and its first `prompt_duration` seconds of tokens are given to the language model
        as a prompt. The generated tokens are decoded back to audio.

        Args:
            batch (tuple[torch.Tensor, list[SegmentWithAttributes]]): Reference audio and the
                corresponding metadata (one item per audio in the batch).
            gen_duration (float): Duration in seconds used to crop the reference audio before
                encoding, and to normalize the reported real-time factor.
            prompt_duration (float, optional): Duration in seconds of the audio prompt. Required.
            generation_params: Extra keyword arguments forwarded to `self.model.generate`.
        Returns:
            dict: With keys "rtf" (elapsed time divided by `gen_duration`), "ref_audio" (the input
                audio, not cropped), "gen_audio", "gen_tokens", "prompt_audio" and "prompt_tokens".
        Raises:
            ValueError: If `prompt_duration` is None.
        """
        # Monotonic clock: the elapsed time is not affected by system clock adjustments.
        bench_start = time.perf_counter()
        audio, meta = batch
        # No comma between the two f-strings: they must concatenate into a single message
        # rather than form a tuple.
        assert audio.size(0) == len(meta), (
            f"Mismatch between number of items in audio batch ({audio.size(0)})"
            f" and in metadata ({len(meta)})"
        )
        # The audio prompt is mandatory; fail before doing any work.
        if prompt_duration is None:
            raise ValueError("Prompt duration must be provided for audio generation.")

        # prepare attributes
        attributes = [x.to_condition_attributes() for x in meta]

        # prepare audio prompt
        prompt_token_frames = math.ceil(
            prompt_duration * self.compression_model.frame_rate
        )
        total_gen_len = math.ceil(gen_duration * self.compression_model.sample_rate)
        # Gradients are already disabled by the `torch.no_grad()` decorator.
        audio_tokens, _ = self.compression_model.encode(
            audio[..., :total_gen_len].to(self.device)
        )
        prompt_tokens = audio_tokens[..., :prompt_token_frames]

        # generate by sampling from the LM
        with self.autocast:
            gen_tokens = self.model.generate(
                prompt_tokens,
                attributes,
                tokens_for_reference=audio_tokens,
                num_samples=None,
                **generation_params,
            )

        # generate audio from tokens
        assert gen_tokens.dim() == 3
        gen_audio = self.compression_model.decode(gen_tokens, None)

        bench_end = time.perf_counter()
        prompt_samples = int(prompt_duration * self.compression_model.sample_rate)
        gen_outputs = {
            "rtf": (bench_end - bench_start) / gen_duration,
            "ref_audio": audio,
            "gen_audio": gen_audio,
            "gen_tokens": gen_tokens,
            "prompt_audio": audio[..., :prompt_samples],
            "prompt_tokens": prompt_tokens,
        }
        return gen_outputs
