# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import math
import time
import typing as tp
import warnings

import flashy
import torch
import torch.nn.functional as F

from ..modules.conditioners import SegmentWithAttributes
from . import builders
from .halle_nar import HalleNARSolver


# NOTE(review): this filter is installed process-wide at import time and hides every
# UserWarning, including the one emitted by `warnings.warn` in
# Halle2NARSolver._prepare_tokens_and_attributes (UserWarning is warn()'s default category).
warnings.simplefilter(action="ignore", category=UserWarning)


class Halle2NARSolver(HalleNARSolver):
    """Solver for the Halle2 non-autoregressive (NAR) model, built on ``HalleNARSolver``.

    Each training step encodes the audio with ``main_code_only=False`` (pre/post code lists),
    samples one codebook stage to predict, and masks a random audio prefix used as the prompt.
    """

    DATASET_TYPE: builders.DatasetType = builders.DatasetType.SPEECH

    def _prepare_tokens_and_attributes(  # type: ignore
        self,
        batch: tp.Tuple[torch.Tensor, tp.List[SegmentWithAttributes]],
        check_synchronization_points: bool = False,
    ) -> tp.Tuple[
        int,
        tp.Tuple[torch.Tensor, torch.Tensor],
        tp.Tuple[torch.Tensor, torch.Tensor],
        tp.List[int],
    ]:
        """Prepare an input batch for NAR language model training.

        Args:
            batch (tuple[torch.Tensor, list[SegmentWithAttributes]]): Input batch with audio tensor of shape [B, C, T]
                and corresponding metadata as SegmentWithAttributes (with B items).
            check_synchronization_points (bool): Whether to check for synchronization points slowing down training.
        Returns:
            target_stage (int): Codebook index to predict this step, drawn uniformly from
                [1, sub_num_codebooks - 1] (``torch.randint`` excludes ``high``, and ``low`` is 1).
            tokens (tuple[torch.Tensor, torch.Tensor]): ``(text_tokens, audio_tokens)``. ``audio_tokens`` is
                [B, K, T_s], the encoder's code tensors concatenated along dim 1; positions past each item's
                valid length are set to ``self.model.audio_pad_token_id``.
            masks (tuple[torch.Tensor, torch.Tensor]): ``(text_mask, audio_mask)``. ``audio_mask`` is a bool
                tensor shaped like ``audio_tokens``; it is False on padding and on each item's prompt prefix.
            prompt_lengths (list[int]): Per-item prompt length, in token frames.
        """
        if self.model.training:
            warnings.warn(
                "Up to version 1.0.1, the _prepare_tokens_and_attributes was evaluated with `torch.no_grad()`. "
                "This is inconsistent with how model were trained in the MusicGen paper. We removed the "
                "`torch.no_grad()` in version 1.1.0. Small changes to the final performance are expected. "
                "Really sorry about that."
            )
        audio, infos = batch
        audio = audio.to(self.device)
        # Implicit string concatenation (no comma) so the message is a single string, not a tuple.
        assert audio.size(0) == len(infos), (
            f"Mismatch between number of items in audio batch ({audio.size(0)})"
            f" and in metadata ({len(infos)})"
        )
        # prepare attributes
        attributes = []
        for info in infos:
            attr = info.to_condition_attributes()
            attr.wav = {}  # delete wav condition
            attributes.append(attr)
        tokenized = self.model.condition_provider.tokenize(attributes)

        # Now we should be synchronization free.
        if self.device == "cuda" and check_synchronization_points:
            torch.cuda.set_sync_debug_mode("warn")

        # Stage 0 is never selected: `high` is exclusive and `low` is 1.
        target_stage = int(
            torch.randint(
                low=1,
                high=self.compression_model.sub_num_codebooks,
                size=(1,),
                generator=self.rng,
            ).item()
        )
        long_token_hop_length = self.compression_model.hop_lengths[0]
        with torch.no_grad():
            # Right-pad the waveform to a whole multiple of the long-token hop length.
            old_audio_length = audio.shape[-1]
            new_audio_length = math.ceil(old_audio_length / long_token_hop_length) * long_token_hop_length
            audio = F.pad(audio, (0, new_audio_length - old_audio_length), value=0)
            _, codes_pre_outputs, codes_post_outputs, scale = self.compression_model.encode(
                audio, main_code_only=False
            )
            if self.cfg.pre_post_mode == "pre":
                # The first stream keeps its post codes even in "pre" mode.
                codes_pre_outputs[0] = codes_post_outputs[0]
                audio_tokens = torch.cat(codes_pre_outputs, dim=1)
            else:
                audio_tokens = torch.cat(codes_post_outputs, dim=1)
            assert scale is None, "Scaled compression model not supported with LM."
        text_tokens, text_mask = tokenized["text"]

        # create a padding mask to hold valid vs invalid positions
        audio_mask = torch.ones_like(
            audio_tokens, dtype=torch.bool, device=audio_tokens.device
        )
        token_frame_rate = self.compression_model.frame_rate
        token_hop_length = (
            self.compression_model.sample_rate // self.compression_model.frame_rate
        )
        # Absolute cap on prompt length (in token frames); identical for every item.
        prompt_length_cap = math.ceil(
            self.cfg.prompt_cond.max_length * token_frame_rate
        )
        B, _, _ = audio_tokens.shape
        prompt_lengths = []
        for i in range(B):
            n_samples = infos[i].n_frames
            assert infos[i].sample_rate == self.compression_model.sample_rate
            # take the last token generated from actual audio frames (non-padded audio)
            valid_tokens = math.ceil(n_samples / token_hop_length)
            audio_tokens[i, :, valid_tokens:] = self.model.audio_pad_token_id
            audio_mask[i, :, valid_tokens:] = 0

            # calc a mask for prompting
            min_prompt_length = int(valid_tokens * self.cfg.prompt_cond.min_ratio)
            max_prompt_length = int(valid_tokens * self.cfg.prompt_cond.max_ratio)
            # torch.randint requires low < high. For short items both ratios can truncate
            # to the same value; widen the range by one so the draw yields min_prompt_length.
            prompt_high = max(max_prompt_length, min_prompt_length + 1)
            prompt_length = int(
                torch.randint(
                    min_prompt_length, prompt_high, (1,), generator=self.rng
                ).item()
            )
            prompt_length = min(prompt_length, prompt_length_cap)
            # Prompt positions are excluded from the loss.
            audio_mask[i, :, :prompt_length] = 0
            prompt_lengths.append(prompt_length)

        if self.device == "cuda" and check_synchronization_points:
            torch.cuda.set_sync_debug_mode("default")
        return (
            target_stage,
            (text_tokens, audio_tokens),
            (text_mask, audio_mask),
            prompt_lengths,
        )

    def run_step(
        self,
        idx: int,
        batch: tp.Tuple[torch.Tensor, tp.List[SegmentWithAttributes]],
        metrics: dict,
    ) -> dict:
        """Perform one training or valid step on a given batch.

        Computes cross-entropy on the sampled ``target_stage`` only. The stage's slice of the audio mask
        (False on padding and prompt positions) is passed to ``_compute_cross_entropy``. When training,
        runs backward and the optimizer update, unless any tensor metric is non-finite. The skip flag is
        averaged across ranks with ``flashy.distrib.average_tensors``, so all ranks skip together.

        Returns:
            dict: Step metrics. Empty if the update was skipped because of non-finite values.
        """
        check_synchronization_points = idx == 1 and self.device == "cuda"
        target_stage, tokens, masks, prompt_lengths = self._prepare_tokens_and_attributes(
            batch, check_synchronization_points
        )

        self.deadlock_detect.update("tokens_and_conditions")

        if check_synchronization_points:
            torch.cuda.set_sync_debug_mode("warn")

        with self.autocast:
            model_output = self.model.compute_predictions(
                tokens, target_stage, prompt_lengths
            )
            _, s_token = tokens
            # audio ce, restricted to the single target stage
            s_logits = model_output.logits
            s_mask = masks[1][:, target_stage : target_stage + 1, :]
            s_target_token = s_token[:, target_stage : target_stage + 1, :]
            audio_ce, _ = self._compute_cross_entropy(s_logits, s_target_token, s_mask)
            loss = audio_ce
        self.deadlock_detect.update("loss")

        if check_synchronization_points:
            torch.cuda.set_sync_debug_mode("default")

        metrics["audio_ce"] = audio_ce
        metrics["audio_ppl"] = torch.exp(audio_ce)
        metrics["ce"] = audio_ce
        metrics["ppl"] = torch.exp(audio_ce)
        metrics[f"audio_ce_q{target_stage + 1}"] = audio_ce
        metrics[f"audio_ppl_q{target_stage + 1}"] = torch.exp(audio_ce)

        if self.is_training:
            metrics["lr"] = self.optimizer.param_groups[0]["lr"]

            skip_update = torch.tensor([0], dtype=torch.float, device=self.device)
            for key, value in metrics.items():
                if isinstance(value, torch.Tensor) and not value.isfinite().all():
                    self.logger.warning(
                        f"Value of {key} is not finite. worldsize: {flashy.distrib.world_size()}, rank: {flashy.distrib.rank()}"
                    )
                    skip_update += 1
            self.deadlock_detect.update("update_check")
            # Average across ranks so a non-finite value on any rank makes every rank skip.
            flashy.distrib.average_tensors(skip_update)
            if skip_update.item() > 0:
                self.logger.warning(
                    "skip update because of non-finite values in the metrics."
                )
                metrics = {}
                torch.cuda.empty_cache()
            else:
                if self.scaler is not None:
                    loss = self.scaler.scale(loss)
                self.deadlock_detect.update("scale")
                if self.cfg.fsdp.use:
                    loss.backward()
                    flashy.distrib.average_tensors(self.model.buffers())
                elif self.cfg.optim.eager_sync:
                    with flashy.distrib.eager_sync_model(self.model):
                        loss.backward()
                else:
                    # this should always be slower but can be useful
                    # for weird use cases like multiple backwards.
                    loss.backward()
                    flashy.distrib.sync_model(self.model)
                self.deadlock_detect.update("backward")
                if self.scaler is not None:
                    self.scaler.unscale_(self.optimizer)
                if self.cfg.optim.max_norm:
                    if self.cfg.fsdp.use:
                        metrics["grad_norm"] = self.model.clip_grad_norm_(self.cfg.optim.max_norm)  # type: ignore
                    else:
                        metrics["grad_norm"] = torch.nn.utils.clip_grad_norm_(
                            self.model.parameters(), self.cfg.optim.max_norm
                        )
                if self.scaler is None:
                    self.optimizer.step()
                else:
                    self.scaler.step(self.optimizer)
                    self.scaler.update()
                if self.lr_scheduler:
                    self.lr_scheduler.step()
                self.optimizer.zero_grad()
                self.deadlock_detect.update("optim")
                if self.scaler is not None:
                    scale = self.scaler.get_scale()
                    metrics["grad_scale"] = scale

        return metrics

    @torch.no_grad()
    def run_generate_step(  # type: ignore
        self,
        batch: tp.Tuple[torch.Tensor, tp.List[SegmentWithAttributes]],
        gen_duration: float,
        prompt_duration: tp.Optional[float] = None,
        **generation_params,
    ) -> dict:
        """Generate audio for one batch with the NAR model.

        The reference audio is encoded. Its first ``prompt_duration`` seconds of tokens (all codebooks)
        are the prompt, and its first-codebook tokens are passed as ``tokens_for_reference``.

        Args:
            batch (tuple[torch.Tensor, list[SegmentWithAttributes]]): Audio tensor and its metadata items.
            gen_duration (float): Seconds of reference audio to encode; also the denominator of ``rtf``.
            prompt_duration (float, optional): Prompt length in seconds. Must be provided.
            **generation_params: Forwarded to ``self.model.generate``.
        Returns:
            dict: ``rtf``, ``ref_audio`` (the right-padded input audio), ``gen_audio``, ``gen_tokens``,
            ``prompt_audio`` and ``prompt_tokens``.
        Raises:
            ValueError: If ``prompt_duration`` is None.
        """
        bench_start = time.time()
        audio, meta = batch
        assert audio.size(0) == len(meta), (
            f"Mismatch between number of items in audio batch ({audio.size(0)})"
            f" and in metadata ({len(meta)})"
        )
        # prepare attributes
        attributes = [x.to_condition_attributes() for x in meta]

        # prepare audio prompt
        if prompt_duration is None:
            raise ValueError("Prompt duration must be provided for audio generation.")

        prompt_token_frames = math.ceil(
            prompt_duration * self.compression_model.frame_rate
        )
        self.logger.debug("prompt_token_frames: %s", prompt_token_frames)

        # Right-pad to a whole number of long-token hops, and round the encoded span up the same way.
        long_token_hop_length = self.compression_model.hop_lengths[0]
        old_audio_length = audio.shape[-1]
        new_audio_length = math.ceil(old_audio_length / long_token_hop_length) * long_token_hop_length
        audio = F.pad(audio, (0, new_audio_length - old_audio_length), value=0)

        total_gen_len = math.ceil(gen_duration * self.compression_model.sample_rate)
        total_gen_len = math.ceil(total_gen_len / long_token_hop_length) * long_token_hop_length
        # Already under @torch.no_grad(); no inner context manager needed.
        _, codes_pre_outputs, codes_post_outputs, _ = self.compression_model.encode(
            audio[..., :total_gen_len].to(self.device), main_code_only=False
        )
        if self.cfg.pre_post_mode == "pre":
            # Same substitution as in training: the first stream keeps its post codes.
            codes_pre_outputs[0] = codes_post_outputs[0]
            codes = codes_pre_outputs
        else:
            codes = codes_post_outputs
        all_tokens = torch.cat(codes, dim=1)
        audio_tokens = all_tokens[:, :1]  # only use the first codebook
        prompt_tokens = all_tokens[..., :prompt_token_frames]

        # generate by sampling from the LM
        with self.autocast:
            gen_tokens = self.model.generate(
                prompt_tokens,
                attributes,
                tokens_for_reference=audio_tokens,
                num_samples=None,
                **generation_params,
            )

        # generate audio from tokens
        assert gen_tokens.dim() == 3
        if self.cfg.pre_post_mode == "pre":
            gen_tokens = self.compression_model.pre2post_from_nar(
                gen_tokens
            )
            gen_tokens[:, 0] = audio_tokens[:, 0]
        gen_audio = self.compression_model.decode_from_nar(
            gen_tokens, None  # for post tokens
        )

        bench_end = time.time()
        gen_outputs = {
            "rtf": (bench_end - bench_start) / gen_duration,
            "ref_audio": audio,
            "gen_audio": gen_audio,
            "gen_tokens": gen_tokens,
            "prompt_audio": audio[
                ..., : int(prompt_duration * self.compression_model.sample_rate)
            ],
            "prompt_tokens": prompt_tokens,
        }
        return gen_outputs
