# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import math
import time
import typing as tp
import warnings

import flashy
import numpy as np
import torch
import torch.nn.functional as F

from .. import models
from ..models.encodec import MReQ
from ..modules.conditioners import SegmentWithAttributes
from . import builders
from .valle_nar import ValleNARSolver

# Install a process-wide filter that ignores every UserWarning as soon as this
# module is imported. This also silences the `warnings.warn(...)` call made in
# `HalleNARSolver._prepare_tokens_and_attributes`, which uses the default
# UserWarning category.
warnings.simplefilter("ignore", UserWarning)


class HalleNARSolver(ValleNARSolver):
    DATASET_TYPE: builders.DatasetType = builders.DatasetType.SPEECH

    def build_model(self) -> None:
        """Instantiate the compression model, the LM and the optimization state.

        Steps, in order:
            1. Load the compression model from ``cfg.compression_model_checkpoint``
               and check that its sample rate, cardinality and number of codebooks
               agree with the solver / ``cfg.transformer_lm`` configuration.
            2. Build the LM and, when ``cfg.init_from_valle`` is set, copy into it
               every parameter of that checkpoint's ``best_state`` whose key also
               exists in the freshly built model.
            3. Optionally wrap the LM with FSDP, register the EMA, then create the
               optimizer and LR scheduler and register the stateful attributes.
            4. Create a gradient scaler when float16 is in use (FSDP param dtype or
               autocast dtype) and register it as stateful.

        Raises:
            AssertionError: If the compression model disagrees with the
                configuration, or if autocast is enabled together with FSDP.
        """
        # NOTE(review): `CompressionHierarchicalSolver5`, `checkpoint` and
        # `load_lm_model_ckpt` are not among the imports visible at the top of this
        # file -- confirm where they are expected to come from.
        # we can potentially not use all quantizers with which the EnCodec model was trained
        # (e.g. we trained the model with quantizers dropout)
        self.compression_model: HierarchicalEncodecModel5 = CompressionHierarchicalSolver5.wrapped_model_from_checkpoint(
            self.cfg, self.cfg.compression_model_checkpoint, device=self.device
        )
        assert self.compression_model.sample_rate == self.cfg.sample_rate, (
            f"Compression model sample rate is {self.compression_model.sample_rate} but "
            f"Solver sample rate is {self.cfg.sample_rate}."
        )
        # ensure we have matching configuration between LM and compression model
        # (adjacent literals are concatenated on purpose: separating them with commas
        # would turn the assertion message into a tuple)
        assert self.cfg.transformer_lm.card == self.compression_model.sub_cardinality, (
            "Cardinalities of the LM and compression model don't match: "
            f"LM cardinality is {self.cfg.transformer_lm.card} vs "
            f"compression model cardinality is {self.compression_model.sub_cardinality}"
        )
        assert self.cfg.transformer_lm.n_q == self.compression_model.sub_num_codebooks, (
            "Number of quantizers of the LM and compression model don't match: "
            f"LM has {self.cfg.transformer_lm.n_q} quantizers vs "
            f"compression model has {self.compression_model.sub_num_codebooks}"
        )
        self.logger.info(
            "Compression model has %d codebooks with %d cardinality, and a framerate of %d",
            self.compression_model.sub_num_codebooks,
            self.compression_model.sub_cardinality,
            self.compression_model.frame_rate,
        )
        # instantiate LM model
        self.model: models.LMModel = models.builders.get_lm_model(self.cfg).to(
            self.device
        )
        if self.cfg.init_from_valle is not None:
            _checkpoint_path = checkpoint.resolve_checkpoint_path(
                str(self.cfg.init_from_valle), use_fsdp=False
            )
            self.logger.info("Loading VALL-E checkpoint from %s.", _checkpoint_path)
            state = load_lm_model_ckpt(_checkpoint_path)["best_state"]
            new_state = self.model.state_dict()
            for k, v in state.items():
                # only copy the entries whose key also exists in this model
                if k in new_state:
                    new_state[k] = v
                    self.logger.debug("Key %s loaded from the VALL-E checkpoint.", k)
                else:
                    self.logger.debug("Key %s not found in the model state dict.", k)
            self.model.load_state_dict(new_state)
            self.logger.info("VALL-E checkpoint loaded.")

        if self.cfg.fsdp.use:
            assert not self.cfg.autocast, "Cannot use autocast with fsdp"
            self.model = self.wrap_with_fsdp(self.model)
        self.register_ema("model")
        # initialize optimization
        self.optimizer = builders.get_optimizer(
            builders.get_optim_parameter_groups(self.model), self.cfg.optim
        )
        self.lr_scheduler = builders.get_lr_scheduler(
            self.optimizer, self.cfg.schedule, self.total_updates
        )
        self.register_stateful("model", "optimizer", "lr_scheduler")
        self.register_best_state("model")
        self.autocast_dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16}[
            self.cfg.autocast_dtype
        ]
        # a gradient scaler is only created when float16 is actually in use
        self.scaler: tp.Optional[torch.cuda.amp.GradScaler] = None
        if self.cfg.fsdp.use:
            need_scaler = self.cfg.fsdp.param_dtype == "float16"
        else:
            need_scaler = self.cfg.autocast and self.autocast_dtype is torch.float16
        if need_scaler:
            if self.cfg.fsdp.use:
                from torch.distributed.fsdp.sharded_grad_scaler import \
                    ShardedGradScaler

                self.scaler = ShardedGradScaler()  # type: ignore
            else:
                self.scaler = torch.cuda.amp.GradScaler()
            self.register_stateful("scaler")

    def _prepare_tokens_and_attributes(  # type: ignore
        self,
        batch: tp.Tuple[torch.Tensor, tp.List[SegmentWithAttributes]],
        check_synchronization_points: bool = False,
    ) -> tp.Tuple[
        int,
        tp.Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
        tp.Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
        tp.List[int],
    ]:
        """Prepare input batches for language model training.

        Args:
            batch (tuple[torch.Tensor, list[SegmentWithAttributes]]): Input batch with audio tensor of shape [B, C, T]
                and corresponding metadata as SegmentWithAttributes (with B items).
            check_synchronization_points (bool): Whether to check for synchronization points slowing down training.
        Returns:
            Target stage (int): Codebook index to predict, sampled with the solver RNG
                from [1, sub_num_codebooks).
            Tokens (tuple[torch.Tensor, torch.Tensor, torch.Tensor]): ``(text_tokens, pre_tokens, post_tokens)``.
                ``pre_tokens`` is the concatenation along dim 1 of ``codes_pre_outputs[n_q_target:]`` and
                ``post_tokens`` the concatenation of ``codes_post_outputs[:n_q_target]``, where ``n_q_target``
                is the first entry of the cumulative ``sub_num_codebooks_list`` exceeding the target stage.
                Positions past the valid audio are filled with ``audio_pad_token_id``.
            Masks (tuple[torch.Tensor, torch.Tensor, torch.Tensor]): ``(text_mask, pre_mask, post_mask)``.
                The audio masks are boolean, shaped like their tokens, and False on padded positions
                and on the prompt prefix.
            Prompt lengths (list[int]): Prompt length, in token frames, sampled for each batch item.
        """
        if self.model.training:
            # NOTE(review): the module-level `warnings.simplefilter("ignore", UserWarning)`
            # silences this message, and the encoding below still runs under
            # `torch.no_grad()` -- confirm whether this warning is still relevant.
            warnings.warn(
                "Up to version 1.0.1, the _prepare_tokens_and_attributes was evaluated with `torch.no_grad()`. "
                "This is inconsistent with how model were trained in the MusicGen paper. We removed the "
                "`torch.no_grad()` in version 1.1.0. Small changes to the final performance are expected. "
                "Really sorry about that."
            )
        audio, infos = batch
        audio = audio.to(self.device)
        # adjacent literals are concatenated; commas would make the message a tuple
        assert audio.size(0) == len(infos), (
            f"Mismatch between number of items in audio batch ({audio.size(0)})"
            f" and in metadata ({len(infos)})"
        )
        # prepare attributes
        attributes = []
        for info in infos:
            attr = info.to_condition_attributes()
            attr.wav = {}  # delete wav condition
            attributes.append(attr)
        tokenized = self.model.condition_provider.tokenize(attributes)

        # Now we should be synchronization free.
        if self.device == "cuda" and check_synchronization_points:
            torch.cuda.set_sync_debug_mode("warn")

        target_stage = int(
            torch.randint(
                low=1,
                high=self.compression_model.sub_num_codebooks,
                size=(1,),
                generator=self.rng,
            ).item()
        )
        # index of the first hierarchy level whose cumulative codebook count
        # exceeds the sampled target stage
        n_q_cumsum = np.cumsum(self.compression_model.sub_num_codebooks_list)
        n_q_target = int(np.where(target_stage < n_q_cumsum)[0][0])

        long_token_hop_length = self.compression_model.hop_lengths[0]
        with torch.no_grad():
            # right-pad with zeros so the length is a multiple of the first hop length
            old_audio_length = audio.shape[-1]
            new_audio_length = math.ceil(old_audio_length / long_token_hop_length) * long_token_hop_length
            audio = F.pad(audio, (0, new_audio_length - old_audio_length), value=0)
            _, codes_pre_outputs, codes_post_outputs, scale = self.compression_model.encode(
                audio, main_code_only=False
            )
            pre_tokens = torch.cat(codes_pre_outputs[n_q_target:], dim=1)
            post_tokens = torch.cat(codes_post_outputs[:n_q_target], dim=1)
            assert scale is None, "Scaled compression model not supported with LM."
        text_tokens, text_mask = tokenized["text"]

        # create a padding mask to hold valid vs invalid positions
        pre_mask = torch.ones_like(
            pre_tokens, dtype=torch.bool, device=pre_tokens.device
        )
        post_mask = torch.ones_like(
            post_tokens, dtype=torch.bool, device=post_tokens.device
        )
        token_frame_rate = self.compression_model.frame_rate
        token_hop_length = (
            self.compression_model.sample_rate // self.compression_model.frame_rate
        )
        # absolute cap on the prompt length, in token frames (same for every item)
        prompt_frames_cap = math.ceil(
            self.cfg.prompt_cond.max_length * token_frame_rate
        )
        B, _, _ = pre_tokens.shape
        prompt_lengths = []
        for i in range(B):
            n_samples = infos[i].n_frames
            assert infos[i].sample_rate == self.compression_model.sample_rate
            # take the last token generated from actual audio frames (non-padded audio)
            valid_tokens = math.ceil(n_samples / token_hop_length)
            pre_tokens[i, :, valid_tokens:] = self.model.audio_pad_token_id
            post_tokens[i, :, valid_tokens:] = self.model.audio_pad_token_id
            pre_mask[i, :, valid_tokens:] = 0
            post_mask[i, :, valid_tokens:] = 0

            # calc a mask for prompting
            min_prompt_length = int(valid_tokens * self.cfg.prompt_cond.min_ratio)
            max_prompt_length = int(valid_tokens * self.cfg.prompt_cond.max_ratio)
            if max_prompt_length > min_prompt_length:
                prompt_length = int(
                    torch.randint(
                        min_prompt_length, max_prompt_length, (1,), generator=self.rng
                    ).item()
                )
            else:
                # torch.randint requires low < high; for very short items (or equal
                # ratios) both bounds truncate to the same integer, so use it directly
                prompt_length = min_prompt_length
            prompt_length = min(prompt_length, prompt_frames_cap)
            pre_mask[i, :, :prompt_length] = 0
            post_mask[i, :, :prompt_length] = 0
            prompt_lengths.append(prompt_length)

        if self.device == "cuda" and check_synchronization_points:
            torch.cuda.set_sync_debug_mode("default")
        return (
            target_stage,
            (text_tokens, pre_tokens, post_tokens),
            (text_mask, pre_mask, post_mask),
            prompt_lengths,
        )

    def run_step(
        self,
        idx: int,
        batch: tp.Tuple[torch.Tensor, tp.List[SegmentWithAttributes]],
        metrics: dict,
    ) -> dict:
        """Run a single training or validation step on one batch.

        Computes the cross entropy of the LM on the sampled target codebook and
        stores it (and its exponential) in ``metrics``. In training mode the step
        then backpropagates and updates the optimizer, unless any tensor metric is
        non-finite on some worker, in which case the update is skipped and an empty
        dict is returned.

        Args:
            idx (int): Step index; synchronization points are only checked when it
                equals 1 and the device is "cuda".
            batch (tuple[torch.Tensor, list[SegmentWithAttributes]]): Audio and metadata.
            metrics (dict): Metrics dictionary to fill in.
        Returns:
            dict: The updated metrics, or an empty dict when the update was skipped.
        """
        sync_check = idx == 1 and self.device == "cuda"
        target_stage, tokens, masks, prompt_lengths = self._prepare_tokens_and_attributes(
            batch, sync_check
        )

        self.deadlock_detect.update("tokens_and_conditions")

        if sync_check:
            torch.cuda.set_sync_debug_mode("warn")

        with self.autocast:
            model_output = self.model.compute_predictions(
                tokens, target_stage, prompt_lengths
            )
            _text_tokens, pre_tokens, post_tokens = tokens
            _text_mask, pre_mask, _post_mask = masks
            num_post_codebooks = post_tokens.size(1)
            # audio ce: select the single codebook matching the target stage
            # inside the "pre" tokens
            start = target_stage - num_post_codebooks
            stage = slice(start, start + 1)
            audio_ce, _ = self._compute_cross_entropy(
                model_output.logits, pre_tokens[:, stage, :], pre_mask[:, stage, :]
            )
            loss = audio_ce
        self.deadlock_detect.update("loss")

        if sync_check:
            torch.cuda.set_sync_debug_mode("default")

        metrics.update(
            {
                "audio_ce": audio_ce,
                "audio_ppl": torch.exp(audio_ce),
                "ce": audio_ce,
                "ppl": torch.exp(audio_ce),
                f"audio_ce_q{target_stage+1}": audio_ce,
                f"audio_ppl_q{target_stage + 1}": torch.exp(audio_ce),
            }
        )

        # validation: nothing more to do
        if not self.is_training:
            return metrics

        metrics["lr"] = self.optimizer.param_groups[0]["lr"]

        # count the non-finite tensor metrics on this worker, then average the
        # count across workers
        skip_update = torch.tensor([0], dtype=torch.float, device=self.device)
        for key, value in metrics.items():
            if not isinstance(value, torch.Tensor):
                continue
            if value.isfinite().all():
                continue
            self.logger.warning(
                f"Value of {key} is not finite. worldsize: {flashy.distrib.world_size()}, rank: {flashy.distrib.rank()}"
            )
            skip_update += 1
        self.deadlock_detect.update("update_check")
        flashy.distrib.average_tensors(skip_update)

        if skip_update.item() > 0:
            self.logger.warning(
                "skip update because of non-finite values in the metrics."
            )
            torch.cuda.empty_cache()
            return {}

        has_scaler = self.scaler is not None
        if has_scaler:
            loss = self.scaler.scale(loss)
        self.deadlock_detect.update("scale")

        if self.cfg.fsdp.use:
            loss.backward()
            flashy.distrib.average_tensors(self.model.buffers())
        elif self.cfg.optim.eager_sync:
            with flashy.distrib.eager_sync_model(self.model):
                loss.backward()
        else:
            # this should always be slower but can be useful
            # for weird use cases like multiple backwards.
            loss.backward()
            flashy.distrib.sync_model(self.model)
        self.deadlock_detect.update("backward")

        if has_scaler:
            self.scaler.unscale_(self.optimizer)
        if self.cfg.optim.max_norm:
            if self.cfg.fsdp.use:
                grad_norm = self.model.clip_grad_norm_(self.cfg.optim.max_norm)  # type: ignore
            else:
                grad_norm = torch.nn.utils.clip_grad_norm_(
                    self.model.parameters(), self.cfg.optim.max_norm
                )
            metrics["grad_norm"] = grad_norm

        if has_scaler:
            self.scaler.step(self.optimizer)
            self.scaler.update()
        else:
            self.optimizer.step()
        if self.lr_scheduler:
            self.lr_scheduler.step()
        self.optimizer.zero_grad()
        self.deadlock_detect.update("optim")

        if has_scaler:
            metrics["grad_scale"] = self.scaler.get_scale()

        return metrics

    @torch.no_grad()
    def run_generate_step(  # type: ignore
        self,
        batch: tp.Tuple[torch.Tensor, tp.List[SegmentWithAttributes]],
        gen_duration: float,
        prompt_duration: tp.Optional[float] = None,
        **generation_params,
    ) -> dict:
        """Generate audio for a batch from an audio prompt and reference tokens.

        The audio is zero-padded to a multiple of the first hop length of the
        compression model, encoded, and the LM generates tokens from the prompt
        tokens while using the first "post" codebook as reference. The generated
        tokens are decoded back to audio with ``decode_from_nar``.

        Args:
            batch (tuple[torch.Tensor, list[SegmentWithAttributes]]): Audio and metadata.
            gen_duration (float): Target duration; multiplied by the compression model
                sample rate to obtain the number of samples to encode.
            prompt_duration (float, optional): Prompt duration; multiplied by the
                compression model frame rate to obtain the number of prompt frames.
            generation_params: Extra keyword arguments forwarded to ``self.model.generate``.
        Returns:
            dict: ``rtf`` (elapsed wall time divided by ``gen_duration``), ``ref_audio``
                (the zero-padded input audio), ``gen_audio``, ``gen_tokens``,
                ``prompt_audio`` and ``prompt_tokens``.
        Raises:
            ValueError: If ``prompt_duration`` is None.
        """
        bench_start = time.time()
        audio, meta = batch
        # adjacent literals are concatenated; commas would make the message a tuple
        assert audio.size(0) == len(meta), (
            f"Mismatch between number of items in audio batch ({audio.size(0)})"
            f" and in metadata ({len(meta)})"
        )
        # prepare attributes
        attributes = [x.to_condition_attributes() for x in meta]

        # prepare audio prompt
        if prompt_duration is None:
            raise ValueError("Prompt duration must be provided for audio generation.")

        prompt_token_frames = math.ceil(
            prompt_duration * self.compression_model.frame_rate
        )
        self.logger.debug("prompt_token_frames: %d", prompt_token_frames)

        # right-pad with zeros so the length is a multiple of the first hop length
        long_token_hop_length = self.compression_model.hop_lengths[0]
        old_audio_length = audio.shape[-1]
        new_audio_length = math.ceil(old_audio_length / long_token_hop_length) * long_token_hop_length
        audio = F.pad(audio, (0, new_audio_length - old_audio_length), value=0)

        # number of samples to encode, rounded up to a multiple of the same hop length
        total_gen_len = math.ceil(gen_duration * self.compression_model.sample_rate)
        total_gen_len = math.ceil(total_gen_len / long_token_hop_length) * long_token_hop_length

        # gradients are already disabled by the `torch.no_grad()` decorator
        _, codes_pre_outputs, codes_post_outputs, _ = self.compression_model.encode(
            audio[..., :total_gen_len].to(self.device), main_code_only=False
        )
        audio_tokens = torch.cat(codes_post_outputs, dim=1)[:, :1]  # only use the first codebook
        prompt_tokens = torch.cat(codes_pre_outputs, dim=1)[..., :prompt_token_frames]

        # generate by sampling from the LM
        with self.autocast:
            gen_tokens = self.model.generate(
                self.compression_model,
                prompt_tokens,
                attributes,
                tokens_for_reference=audio_tokens,
                num_samples=None,
                **generation_params,
            )

        # generate audio from tokens
        assert gen_tokens.dim() == 3
        gen_audio = self.compression_model.decode_from_nar(
            gen_tokens, None
        )

        bench_end = time.time()
        gen_outputs = {
            "rtf": (bench_end - bench_start) / gen_duration,
            "ref_audio": audio,
            "gen_audio": gen_audio,
            "gen_tokens": gen_tokens,
            "prompt_audio": audio[
                ..., : int(prompt_duration * self.compression_model.sample_rate)
            ],
            "prompt_tokens": prompt_tokens,
        }
        return gen_outputs
