# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import math
import time
import typing as tp
import warnings
from pathlib import Path

import flashy
import torch
import torch.nn.functional as F
from omegaconf import OmegaConf

from ..data.audio_dataset import AudioDataset
from ..models.builders import get_lm_model
from ..models.loaders import _delete_param, load_lm_model_ckpt
from ..modules.conditioners import (JointEmbedCondition, SegmentWithAttributes,
                                    WavCondition)
from ..utils import checkpoint
from ..utils.samples.manager import SampleManager
from ..utils.utils import collate, get_dataset_from_loader, is_jsonable
from . import builders
from .musicgen import MusicGenSolver

# Install an "ignore" filter for every UserWarning as soon as this module is imported.
# The warnings filter list is interpreter-global, so this affects the whole process.
# NOTE(review): this also hides the `warnings.warn(...)` (a UserWarning) emitted by
# `ValleARSolver._prepare_tokens_and_attributes` during training; confirm that is intended.
warnings.filterwarnings("ignore", category=UserWarning)


class ValleARSolver(MusicGenSolver):
    """Solver for an autoregressive speech language model, built on `MusicGenSolver`.

    Training batches are turned into ``(text_tokens, audio_tokens)`` pairs: text tokens
    come from the condition provider's ``"text"`` tokenizer and audio tokens from the
    compression model, framed with SOS / EOS / PAD tokens. The training loss is the
    audio cross-entropy only. Generation continues an audio prompt (the first
    ``prompt_duration`` seconds of each reference item) and strips the special tokens
    before decoding back to audio.
    """

    DATASET_TYPE: builders.DatasetType = builders.DatasetType.SPEECH

    def __init__(self, cfg: OmegaConf, **kwargs):
        super().__init__(cfg, **kwargs)
        # Sampling parameters forwarded as keyword arguments to `self.model.generate`
        # (via `run_generate_step`) and logged with the generated samples.
        self.generation_params = {
            "use_sampling": self.cfg.generate.lm.use_sampling,
            "temp": self.cfg.generate.lm.temp,
            "top_k": self.cfg.generate.lm.top_k,
            "top_p": self.cfg.generate.lm.top_p,
            "repetition_penalty": self.cfg.generate.lm.repetition_penalty,
            "remove_prompts": self.cfg.generate.lm.remove_prompts,
        }
        # Generator seeded from the config.
        # NOTE(review): not referenced by the code visible here; confirm it is still needed.
        self.rng = torch.Generator()
        self.rng.manual_seed(self.cfg.seed)

    def _prepare_tokens_and_attributes(
        self,
        batch: tp.Tuple[torch.Tensor, tp.List[SegmentWithAttributes]],
        check_synchronization_points: bool = False,
    ) -> tp.Tuple[
        tp.Tuple[torch.Tensor, torch.Tensor], tp.Tuple[torch.Tensor, torch.Tensor]
    ]:
        """Prepare input batches for language model training.

        Args:
            batch (tuple[torch.Tensor, list[SegmentWithAttributes]]): Input batch with audio tensor of shape [B, C, T]
                and corresponding metadata as SegmentWithAttributes (with B items).
            check_synchronization_points (bool): Whether to check for synchronization points slowing down training.
        Returns:
            A pair ``(tokens, masks)``:
            tokens (tuple[torch.Tensor, torch.Tensor]): ``(text_tokens, audio_tokens)``. ``text_tokens`` is
                whatever the condition provider's tokenizer returns for the "text" condition. ``audio_tokens``
                has shape [B, K, T_s + 2] (K codebooks, T_s compression-model timesteps): index 0 holds the
                SOS token, the valid frames follow, then one EOS token, then PAD tokens up to the end.
            masks (tuple[torch.Tensor, torch.Tensor]): ``(text_mask, audio_mask)``. ``audio_mask`` is a bool
                tensor shaped like ``audio_tokens``, True on SOS, the valid frames and EOS, False on PAD.
        """
        if self.model.training:
            warnings.warn(
                "Up to version 1.0.1, the _prepare_tokens_and_attributes was evaluated with `torch.no_grad()`. "
                "This is inconsistent with how model were trained in the MusicGen paper. We removed the "
                "`torch.no_grad()` in version 1.1.0. Small changes to the final performance are expected. "
                "Really sorry about that."
            )
        audio, infos = batch
        audio = audio.to(self.device)
        # Implicit string concatenation (no comma) so the message is a single string.
        assert audio.size(0) == len(infos), (
            f"Mismatch between number of items in audio batch ({audio.size(0)})"
            f" and in metadata ({len(infos)})"
        )
        # Prepare attributes; the wav condition is removed so only non-wav
        # conditions (the text) are tokenized.
        attributes = []
        for info in infos:
            attr = info.to_condition_attributes()
            attr.wav = {}
            attributes.append(attr)
        # NOTE(review): the original comment here read "this should be zero, because we need
        # text information" -- presumably about condition dropout; confirm.
        tokenized = self.model.condition_provider.tokenize(attributes)

        sync_debug = self.device == "cuda" and check_synchronization_points
        # Now we should be synchronization free.
        if sync_debug:
            torch.cuda.set_sync_debug_mode("warn")
        try:
            with torch.no_grad():
                audio_tokens, scale = self.compression_model.encode(audio)
                assert scale is None, "Scaled compression model not supported with LM."
            text_tokens, text_mask = tokenized["text"]

            # Padding mask holding valid vs invalid positions.
            audio_mask = torch.ones_like(
                audio_tokens, dtype=torch.bool, device=audio_tokens.device
            )
            # Reserve one slot on each side of the time axis for SOS / EOS
            # (F.pad returns a new tensor, so no clone is needed).
            audio_tokens = F.pad(audio_tokens, (1, 1), value=0)
            audio_mask = F.pad(audio_mask, (1, 1), value=1)
            # Number of audio samples covered by one token frame.
            token_hop_length = (
                self.compression_model.sample_rate // self.compression_model.frame_rate
            )
            B, _, _ = audio_tokens.shape
            for i in range(B):
                n_samples = infos[i].n_frames
                assert infos[i].sample_rate == self.compression_model.sample_rate, (
                    f"Item sample rate ({infos[i].sample_rate}) differs from the compression"
                    f" model sample rate ({self.compression_model.sample_rate})"
                )
                # Number of token frames produced from real (non-padded) audio.
                valid_tokens = math.ceil(n_samples / token_hop_length)
                audio_tokens[i, :, 0] = self.model.audio_sos_token_id
                # EOS goes right after the last valid frame on every codebook;
                # everything after it becomes PAD and is masked out.
                audio_tokens[i, :, valid_tokens + 1 :] = self.model.audio_eos_token_id
                audio_tokens[i, :, valid_tokens + 2 :] = self.model.audio_pad_token_id
                audio_mask[i, :, valid_tokens + 2 :] = 0
        finally:
            # Always restore the debug mode, even if encoding/tokenization raised.
            if sync_debug:
                torch.cuda.set_sync_debug_mode("default")
        return (text_tokens, audio_tokens), (text_mask, audio_mask)

    def run_step(
        self,
        idx: int,
        batch: tp.Tuple[torch.Tensor, tp.List[SegmentWithAttributes]],
        metrics: dict,
    ) -> dict:
        """Perform one training or valid step on a given batch.

        Computes the audio cross-entropy and perplexity, overall and per codebook, and
        stores them in ``metrics``. In training mode it also records the learning rate.
        It checks every tensor metric for non-finite values. A skip flag is passed
        through ``flashy.distrib.average_tensors`` (presumably a cross-rank average).
        If the flag is > 0, the update is skipped and an empty dict is returned.
        Otherwise it runs backward, optional gradient clipping, and the optimizer and
        LR-scheduler steps.
        """
        check_synchronization_points = idx == 1 and self.device == "cuda"
        tokens, masks = self._prepare_tokens_and_attributes(
            batch, check_synchronization_points
        )

        self.deadlock_detect.update("tokens_and_conditions")

        if check_synchronization_points:
            torch.cuda.set_sync_debug_mode("warn")
        try:
            with self.autocast:
                model_output = self.model.compute_predictions(tokens)  # type: ignore
                _, s_token = tokens
                _, s_mask = masks
                # Audio cross-entropy, restricted to positions valid in both the
                # input padding mask and the model's output mask.
                s_logits = model_output.logits
                s_mask = s_mask & model_output.mask
                audio_ce, audio_ce_per_codebook = self._compute_cross_entropy(
                    s_logits, s_token, s_mask
                )
                loss = audio_ce
            self.deadlock_detect.update("loss")
        finally:
            # Always restore the debug mode, even if the forward pass raised.
            if check_synchronization_points:
                torch.cuda.set_sync_debug_mode("default")

        metrics["audio_ce"] = audio_ce
        metrics["audio_ppl"] = torch.exp(audio_ce)
        metrics["ce"] = audio_ce
        metrics["ppl"] = torch.exp(audio_ce)
        for k, ce_q in enumerate(audio_ce_per_codebook):
            metrics[f"audio_ce_q{k + 1}"] = ce_q
            metrics[f"audio_ppl_q{k + 1}"] = torch.exp(ce_q)

        if self.is_training:
            metrics["lr"] = self.optimizer.param_groups[0]["lr"]

            skip_update = torch.tensor([0], dtype=torch.float, device=self.device)
            for key, value in metrics.items():
                if isinstance(value, torch.Tensor) and not value.isfinite().all():
                    self.logger.warning(
                        f"Value of {key} is not finite. worldsize: {flashy.distrib.world_size()}, rank: {flashy.distrib.rank()}"
                    )
                    skip_update += 1
            self.deadlock_detect.update("update_check")
            flashy.distrib.average_tensors(skip_update)
            if skip_update.item() > 0:
                self.logger.warning(
                    "skip update because of non-finite values in the metrics."
                )
                metrics = {}
                torch.cuda.empty_cache()
            else:
                if self.scaler is not None:
                    loss = self.scaler.scale(loss)
                self.deadlock_detect.update("scale")
                if self.cfg.fsdp.use:
                    loss.backward()
                    flashy.distrib.average_tensors(self.model.buffers())
                elif self.cfg.optim.eager_sync:
                    with flashy.distrib.eager_sync_model(self.model):
                        loss.backward()
                else:
                    # this should always be slower but can be useful
                    # for weird use cases like multiple backwards.
                    loss.backward()
                    flashy.distrib.sync_model(self.model)
                self.deadlock_detect.update("backward")
                if self.scaler is not None:
                    self.scaler.unscale_(self.optimizer)
                if self.cfg.optim.max_norm:
                    if self.cfg.fsdp.use:
                        metrics["grad_norm"] = self.model.clip_grad_norm_(self.cfg.optim.max_norm)  # type: ignore
                    else:
                        metrics["grad_norm"] = torch.nn.utils.clip_grad_norm_(
                            self.model.parameters(), self.cfg.optim.max_norm
                        )
                if self.scaler is None:
                    self.optimizer.step()
                else:
                    self.scaler.step(self.optimizer)
                    self.scaler.update()
                if self.lr_scheduler:
                    self.lr_scheduler.step()
                self.optimizer.zero_grad()
                self.deadlock_detect.update("optim")
                if self.scaler is not None:
                    scale = self.scaler.get_scale()
                    metrics["grad_scale"] = scale

        return metrics

    def common_train_valid(self, dataset_split: str, **kwargs: tp.Any):
        """Common logic for train and valid stages.

        Iterates over at most ``updates_per_epoch`` batches of ``dataset_split``,
        runs `run_step` on each, performs an EMA step every
        ``cfg.optim.ema.updates`` non-skipped training steps, and returns the
        epoch-averaged metrics aggregated with ``flashy.distrib.average_metrics``.
        """
        self.model.train(self.is_training)

        loader = self.dataloaders[dataset_split]
        # get a different order for distributed training, otherwise this will get ignored
        if flashy.distrib.world_size() > 1 and isinstance(
            loader.sampler, torch.utils.data.distributed.DistributedSampler
        ):
            loader.sampler.set_epoch(self.epoch)
        updates_per_epoch = (
            self.train_updates_per_epoch if self.is_training else len(loader)
        )
        if self.cfg.benchmark_no_load:
            self.logger.warning("Fake loading for benchmarking: re-using first batch")
            batch = next(iter(loader))
            loader = [batch] * updates_per_epoch  # type: ignore
        lp = self.log_progress(
            self.current_stage,
            loader,
            total=updates_per_epoch,
            updates=self.log_updates,
        )
        average = flashy.averager()  # epoch wise average
        instant_average = flashy.averager()  # average between two logging
        metrics: dict = {}

        with self.profiler, self.deadlock_detect:  # profiler will only run for the first 20 updates.
            for idx, batch in enumerate(lp):
                self.deadlock_detect.update("batch")
                if idx >= updates_per_epoch:
                    break
                metrics = self.run_step(idx, batch, {})
                self.deadlock_detect.update("step")
                # run EMA step; an empty metrics dict means the update was skipped
                if (
                    self.ema is not None
                    and self.is_training
                    and (idx + 1) % self.cfg.optim.ema.updates == 0
                    and len(metrics) > 0
                ):
                    self.logger.debug("EMA model step")
                    self.ema.step()
                self.deadlock_detect.update("ema")
                self.profiler.step()
                instant_metrics = instant_average(metrics)
                if lp.update(**instant_metrics):
                    # reset averager between two logging
                    instant_average = flashy.averager()
                metrics = average(metrics)  # epoch wise average
                self.deadlock_detect.update("end_batch")

        metrics = flashy.distrib.average_metrics(metrics, updates_per_epoch)
        return metrics

    @staticmethod
    def _postprocess_codes(
        gen_tokens: torch.Tensor,
        sos_token: int,
        eos_token: int,
        pad_token: tp.Union[int, tp.List[int]],
    ) -> tp.Tuple[torch.Tensor, tp.List[int]]:
        """Extract, for each batch item, the tokens between SOS and EOS.

        For item ``b``, the span starts right after the LAST SOS token in the first
        codebook ``gen_tokens[b, 0]``. It ends right before the first EOS token that
        follows that SOS, or at the end of the sequence if there is none. Spans are
        copied left-aligned into a tensor pre-filled with ``pad_token``, and the time
        axis is cropped to the longest span.

        Args:
            gen_tokens (torch.Tensor): Generated tokens of shape [B, K, T].
            sos_token (int): Start-of-sequence token id.
            eos_token (int): End-of-sequence token id.
            pad_token (int or list[int]): Fill value, either one id for all codebooks
                or a list/tuple of at least K ids (one per codebook). Any other type
                leaves the fill value at 0.
        Returns:
            tuple[torch.Tensor, list[int]]: Tokens of shape [B, K, max(valid_lengths)]
                and the number of valid timesteps per item.
        Raises:
            IndexError: if an item has no SOS token in its first codebook.
        """
        B, K, T = gen_tokens.shape
        processed_tokens = torch.zeros_like(gen_tokens)

        if isinstance(pad_token, int):
            processed_tokens += pad_token
        elif isinstance(pad_token, (list, tuple)):
            assert len(pad_token) >= K
            for k in range(K):
                processed_tokens[:, k] += pad_token[k]

        valid_lengths: tp.List[int] = []
        for b in range(B):
            token = gen_tokens[b, 0]
            sos_idx = int((token == sos_token).nonzero(as_tuple=True)[0][-1])
            # Only search for EOS after the SOS: a stray EOS earlier in the sequence
            # would otherwise give a negative length and a shape mismatch below.
            eos_after_sos = (token[sos_idx + 1 :] == eos_token).nonzero(as_tuple=True)[0]
            if eos_after_sos.numel() > 0:
                eos_idx = sos_idx + 1 + int(eos_after_sos[0])
            else:
                eos_idx = int(T)

            valid_length = eos_idx - sos_idx - 1
            valid_lengths.append(valid_length)
            processed_tokens[b, :, :valid_length] = gen_tokens[
                b, :, sos_idx + 1 : eos_idx
            ]

        return processed_tokens[:, :, : max(valid_lengths, default=0)], valid_lengths

    @staticmethod
    def _postprocess_audios(
        gen_audios: torch.Tensor,
        valid_lengths: tp.List[int],
        compression_frame_rate: int,
        compression_sample_rate: int,
        delta: int = 240,
    ) -> torch.Tensor:
        """Trim each decoded waveform to its valid length and re-collate the batch.

        Args:
            gen_audios (torch.Tensor): Decoded audio of shape [B, 1, T] (mono is asserted).
            valid_lengths (list[int]): Valid token timesteps per item (see `_postprocess_codes`).
            compression_frame_rate (int): Token frames per second of the compression model.
            compression_sample_rate (int): Audio samples per second of the compression model.
            delta (int): Extra samples kept past the computed length.
                NOTE(review): presumably a safety margin for the decoder tail; confirm.
        Returns:
            torch.Tensor: The trimmed waveforms, collated (padded) along dim 0 by `collate`.
        """
        assert gen_audios.shape[1] == 1
        processed_audios = []
        for b, valid_length in enumerate(valid_lengths):
            # Convert token timesteps to audio samples, plus the margin.
            n_samples = (
                math.floor(
                    valid_length / compression_frame_rate * compression_sample_rate
                )
                + delta
            )
            processed_audios.append(gen_audios[b, 0, :n_samples])

        return collate(processed_audios, dim=0)[0]

    @torch.no_grad()
    def run_generate_step(
        self,
        batch: tp.Tuple[torch.Tensor, tp.List[SegmentWithAttributes]],
        gen_duration: float,
        prompt_duration: tp.Optional[float] = None,
        **generation_params,
    ) -> dict:
        """Continue the first ``prompt_duration`` seconds of each item for ``gen_duration`` seconds.

        Args:
            batch: Audio tensor (batch first) and the matching metadata list.
            gen_duration (float): Target generation length in seconds; converted to
                ``ceil(gen_duration * frame_rate)`` tokens for ``max_gen_len``.
            prompt_duration (float): Length of the audio prompt in seconds. Required.
            **generation_params: Forwarded to ``self.model.generate``.
        Returns:
            dict with keys ``rtf`` (wall-clock time / ``gen_duration``), ``ref_audio``,
            ``gen_audio``, ``gen_tokens``, ``prompt_audio`` and ``prompt_tokens``.
        Raises:
            ValueError: if ``prompt_duration`` is None.
        """
        bench_start = time.time()
        audio, meta = batch
        # Implicit string concatenation (no comma) so the message is a single string.
        assert audio.size(0) == len(meta), (
            f"Mismatch between number of items in audio batch ({audio.size(0)})"
            f" and in metadata ({len(meta)})"
        )
        if prompt_duration is None:
            raise ValueError("Prompt duration must be provided for audio generation.")

        # prepare attributes
        attributes = [x.to_condition_attributes() for x in meta]

        # prepare audio prompt
        prompt_audio_frames = int(prompt_duration * self.compression_model.sample_rate)
        prompt_audio = audio[..., :prompt_audio_frames].to(self.device)
        prompt_tokens, scale = self.compression_model.encode(prompt_audio)
        assert (
            scale is None
        ), "Compression model in MusicGen should not require rescaling."

        # generate by sampling from the LM
        with self.autocast:
            total_gen_len = math.ceil(gen_duration * self.compression_model.frame_rate)
            gen_tokens = self.model.generate(
                prompt_tokens,
                attributes,
                max_gen_len=total_gen_len,
                num_samples=None,
                **generation_params,
            )

        # generate audio from tokens: strip SOS/EOS, decode, trim to the valid lengths
        assert gen_tokens.dim() == 3
        gen_tokens, valid_lengths = self._postprocess_codes(
            gen_tokens,
            self.model.audio_sos_token_id,
            self.model.audio_eos_token_id,
            self.compression_model.silent_tokens,
        )
        gen_audio = self.compression_model.decode(gen_tokens, None)
        gen_audio = self._postprocess_audios(
            gen_audio,
            valid_lengths,
            self.compression_model.frame_rate,
            self.compression_model.sample_rate,
        )

        bench_end = time.time()
        gen_outputs = {
            "rtf": (bench_end - bench_start) / gen_duration,
            "ref_audio": audio,
            "gen_audio": gen_audio,
            "gen_tokens": gen_tokens,
            "prompt_audio": prompt_audio,
            "prompt_tokens": prompt_tokens,
        }
        return gen_outputs

    def generate_audio(self) -> dict:
        """Audio generation stage.

        Runs prompted generation on every batch of the "generate" dataloader, stores
        the results with a `SampleManager`, and returns the running average of the
        real-time factor (``rtf``).

        Raises:
            ValueError: if ``generate.lm.gen_duration`` or ``generate.lm.prompt_duration`` is None.
            NotImplementedError: if ``generate.lm.unprompted_samples`` is set.
        """
        generate_stage_name = f"{self.current_stage}"
        sample_manager = SampleManager(self.xp)
        self.logger.info(f"Generating samples in {sample_manager.base_folder}")
        loader = self.dataloaders["generate"]
        updates = len(loader)
        lp = self.log_progress(
            generate_stage_name, loader, total=updates, updates=self.log_updates
        )

        dataset = get_dataset_from_loader(loader)
        assert isinstance(dataset, AudioDataset)
        target_duration = self.cfg.generate.lm.gen_duration
        prompt_duration = self.cfg.generate.lm.prompt_duration
        if target_duration is None:
            raise ValueError("Target duration must be provided for audio generation.")
        if prompt_duration is None:
            raise ValueError("Prompt duration must be provided for audio generation.")

        def get_hydrated_conditions(meta: tp.List[SegmentWithAttributes]):
            """Build one JSON-friendly dict per item, restricted to known conditioners."""
            conditioners = self.model.condition_provider.conditioners
            hydrated_conditions = []
            for sample in [x.to_condition_attributes() for x in meta]:
                cond_dict = {}
                for cond_type in sample.__annotations__.keys():
                    for cond_key, cond_val in getattr(sample, cond_type).items():
                        if cond_key not in conditioners.keys():
                            continue
                        if is_jsonable(cond_val):
                            cond_dict[cond_key] = cond_val
                        elif isinstance(cond_val, WavCondition):
                            cond_dict[cond_key] = cond_val.path
                        elif isinstance(cond_val, JointEmbedCondition):
                            # only support text at inference for now
                            cond_dict[cond_key] = cond_val.text
                        else:
                            # if we reached this point, it is not clear how to log the condition
                            # so we just log the type.
                            cond_dict[cond_key] = str(type(cond_val))
                hydrated_conditions.append(cond_dict)
            return hydrated_conditions

        metrics: dict = {}
        average = flashy.averager()
        for batch in lp:
            audio, meta = batch
            # metadata for sample manager
            hydrated_conditions = get_hydrated_conditions(meta)
            if self.cfg.generate.lm.unprompted_samples:
                raise NotImplementedError("Unprompted samples are not supported.")

            if self.cfg.generate.lm.prompted_samples:
                gen_outputs = self.run_generate_step(
                    batch,
                    gen_duration=target_duration,
                    prompt_duration=prompt_duration,
                    **self.generation_params,
                )
                gen_audio = gen_outputs["gen_audio"].cpu()
                prompt_audio = gen_outputs["prompt_audio"].cpu()
                sample_manager.add_samples(
                    gen_audio,
                    self.epoch,
                    hydrated_conditions,
                    prompt_wavs=prompt_audio,
                    ground_truth_wavs=audio,
                    generation_args=self.generation_params,
                )
                metrics["rtf"] = gen_outputs["rtf"]
            metrics = average(metrics)

        flashy.distrib.barrier()
        return metrics

    def evaluate_audio_generation(self) -> dict:
        """Evaluate audio generation with off-the-shelf metrics (not implemented; returns {})."""
        # TODO: Implement audio generation evaluation
        return {}

    @staticmethod
    def model_from_checkpoint(
        checkpoint_path: tp.Union[Path, str],
        device: tp.Union[torch.device, str] = "cpu",
    ):
        """Build a LM from a solver checkpoint for inference.

        The stored experiment config is reused, with ``device`` set to ``str(device)``
        and ``dtype`` set to float32 on CPU or float16 otherwise. The weights come from
        ``fsdp_best_state`` when that entry is present and non-empty, and from
        ``best_state`` otherwise. The model is put in eval mode and carries its config
        as ``model.cfg``.

        Args:
            checkpoint_path (Path or str): Checkpoint to resolve and load.
            device (torch.device or str): Target device written into the config.
        Returns:
            The model built by `get_lm_model`, with the checkpoint weights loaded.
        """
        _checkpoint_path = checkpoint.resolve_checkpoint_path(
            str(checkpoint_path), use_fsdp=False
        )
        pkg = load_lm_model_ckpt(_checkpoint_path)
        cfg = OmegaConf.create(pkg["xp.cfg"])
        cfg.device = str(device)
        cfg.dtype = "float32" if cfg.device == "cpu" else "float16"
        # NOTE(review): presumably training-only / obsolete keys that should not reach
        # get_lm_model; confirm.
        for key in (
            "conditioners.self_wav.chroma_stem.cache_path",
            "conditioners.args.merge_text_conditions_p",
            "conditioners.args.drop_desc_p",
        ):
            _delete_param(cfg, key)
        model = get_lm_model(cfg)
        # `.get` so checkpoints without an FSDP entry fall back to `best_state`.
        fsdp_best_state = pkg.get("fsdp_best_state")
        if fsdp_best_state:
            model.load_state_dict(fsdp_best_state["model"])
        else:
            model.load_state_dict(pkg["best_state"]["model"])
        model.eval()
        model.cfg = cfg
        return model
