# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import math
import time
import typing as tp
import warnings
from pathlib import Path

import flashy
import torch
import torch.nn.functional as F
from omegaconf import OmegaConf
from torch import nn

from .. import models, optim
from ..data.audio_dataset import AudioDataset
from ..models.builders import get_lm_model
from ..models.loaders import _delete_param, load_lm_model_ckpt
from ..modules.conditioners import (JointEmbedCondition, SegmentWithAttributes,
                                    WavCondition)
from ..optim import fsdp
from ..utils import checkpoint
from ..utils.best_state import BestStateDictManager
from ..utils.deadlock import DeadlockDetect
from ..utils.profiler import Profiler
from ..utils.samples.manager import SampleManager
from ..utils.utils import ( get_dataset_from_loader, is_jsonable,
                           model_hash)
from . import builders
from .musicgen import MusicGenSolver
from .valle_ar import SpeechGenSolver

# Import-time side effect: every UserWarning is ignored process-wide from now on,
# including UserWarnings emitted by other modules and third-party libraries.
warnings.filterwarnings("ignore", category=UserWarning)


class HalleARSolver(MusicGenSolver):
    """Solver for the "long" autoregressive LM of a hierarchical compression model.

    The LM is trained on the first stage (index 0) of a hierarchical compression model:
    audio is encoded with ``encode(..., stage=0)``, the LM vocabulary must match
    ``cardinality[0]`` and the LM's number of codebooks overrides ``num_codebooks[0]``.
    The model is registered under the stateful name ``long_model`` (together with
    ``long_optimizer`` / ``long_lr_scheduler``) and its metrics are prefixed with
    ``long_``. Helpers such as ``_compute_cross_entropy``, ``autocast`` and
    ``load_checkpoints`` come from the parent solver classes.
    """

    DATASET_TYPE: builders.DatasetType = builders.DatasetType.SPEECH

    def __init__(self, cfg: OmegaConf, **kwargs):
        """Set up logging, datasets, the long LM, EMA, profiling and generation params.

        ``MusicGenSolver.__init__`` is intentionally bypassed (only
        ``flashy.BaseSolver.__init__`` runs), presumably so that the model can be
        registered as ``long_model`` instead of under the parent's names; the parent
        setup is therefore re-implemented here.

        Args:
            cfg: full solver configuration.
            **kwargs: accepted for signature compatibility; unused.
        """
        flashy.BaseSolver.__init__(self)
        self.logger.info(
            f"Instantiating solver {self.__class__.__name__} for XP {self.xp.sig}"
        )
        self.logger.info(f"All XP logs are stored in {self.xp.folder}")
        self.cfg = cfg
        self.device = cfg.device
        self.long_model: nn.Module
        self._continue_best_source_keys = ["best_state", "fsdp_best_state"]
        self._fsdp_modules: tp.List[fsdp.FSDP] = []
        self._ema_sources: nn.ModuleDict = nn.ModuleDict()
        self.ema: tp.Optional[optim.ModuleDictEMA] = None
        self.dataloaders: tp.Dict[str, torch.utils.data.DataLoader] = dict()
        self._log_updates = self.cfg.logging.get("log_updates", 10)
        if self.cfg.logging.log_tensorboard:
            self.init_tensorboard(**self.cfg.get("tensorboard"))
        # NOTE(review): `and self` relies on the solver object's truthiness — confirm intent.
        if self.cfg.logging.log_wandb and self:
            self.init_wandb(**self.cfg.get("wandb"))
        # keep a copy of the best performing state for stateful objects
        # used for evaluation and generation stages
        dtype_best: tp.Optional[torch.dtype] = None
        if self.cfg.fsdp.use:
            dtype_best = getattr(torch, self.cfg.fsdp.param_dtype)  # type: ignore
            assert isinstance(dtype_best, torch.dtype)
        elif self.cfg.autocast:
            dtype_best = getattr(torch, self.cfg.autocast_dtype)  # type: ignore
            assert isinstance(dtype_best, torch.dtype)
        self.best_state: BestStateDictManager = BestStateDictManager(dtype=dtype_best)
        # Hacky support for keeping a copy of the full best state in rank0.
        self.fsdp_best_state: tp.Dict[str, tp.Any] = {}
        # register best_state object to keep it in state_dict
        self.register_stateful("best_state", "fsdp_best_state")
        self._new_best_state: bool = False  # should save a new checkpoint
        # instantiate datasets and appropriate number of updates per epoch
        self.build_dataloaders()
        if self.cfg.execute_only is None:
            assert (
                "train" in self.dataloaders
            ), "The train dataset split must be provided."
            assert (
                "valid" in self.dataloaders
            ), "The valid dataset split must be provided."
        self.train_updates_per_epoch = (
            len(self.dataloaders["train"]) if "train" in self.dataloaders else 0
        )
        if self.cfg.optim.updates_per_epoch:
            self.train_updates_per_epoch = self.cfg.optim.updates_per_epoch
        self.total_updates = self.train_updates_per_epoch * self.cfg.optim.epochs
        # instantiate model & exponential moving average on the model
        self.build_model()
        self.logger.info("Long Model hash: %s", model_hash(self.long_model))
        assert (
            "long_model" in self.stateful.sources
        ), "Please register the model to stateful with self.register_stateful('long_model') in build_model."
        self.long_profiler = Profiler(self.long_model, **self.cfg.profiler)
        self.initialize_ema()
        self.register_stateful("ema")
        assert (
            self.ema is None or "ema" in self.stateful.sources
        ), "Please register the ema to stateful with self.register_stateful('ema') in build_model."
        self.deadlock_detect = DeadlockDetect(**self.cfg.deadlock)
        # basic statistics on the trained model
        long_model_size = (
            sum(p.numel() for p in self.long_model.parameters() if p.requires_grad)
            / 1e6
        )
        # one copy of grad, one copy of momentum, one copy of denominator and model weights.
        # and 4 bytes for each float!
        mem_usage = long_model_size * 4 * 4 / 1000
        self.logger.info("Long Model size: %.2f M params", long_model_size)
        self.logger.info(
            "Base memory usage, with model, grad and optim: %.2f GB", mem_usage
        )
        # sampling parameters forwarded to `long_model.generate`
        self.generation_params = {
            "use_sampling": self.cfg.generate.lm.use_sampling,
            "temp": self.cfg.generate.lm.temp,
            "top_k": self.cfg.generate.lm.top_k,
            "top_p": self.cfg.generate.lm.top_p,
            "repetition_penalty": self.cfg.generate.lm.repetition_penalty,
            "remove_prompts": self.cfg.generate.lm.remove_prompts,
        }
        self.rng = torch.Generator()
        self.rng.manual_seed(self.cfg.seed)
        self._best_metric_name: tp.Optional[str] = "ce"

        self._cached_batch_writer = None
        self._cached_batch_loader = None

    def build_model(self) -> None:
        """Instantiate the compression model, the long LM, its optimizer and LR scheduler.

        Also optionally warm-starts the LM from a VALL-E checkpoint
        (``cfg.init_from_valle``), wraps it with FSDP when enabled, registers the
        EMA / stateful / best-state sources, and creates a gradient scaler when float16
        training requires one.
        """
        # we can potentially not use all quantizers with which the EnCodec model was trained
        # (e.g. we trained the model with quantizers dropout)
        # NOTE(review): `CompressionHierarchicalSolver4` and `HierarchicalEncodecModel4` are
        # not imported in this file, so the call below raises NameError unless they are
        # brought into scope elsewhere — TODO add the missing import.
        self.compression_model: HierarchicalEncodecModel4 = (
            CompressionHierarchicalSolver4.wrapped_model_from_checkpoint(
                self.cfg, self.cfg.compression_model_checkpoint, device=self.device
            )
        )
        assert self.compression_model.sample_rate == self.cfg.sample_rate, (
            f"Compression model sample rate is {self.compression_model.sample_rate} but "
            f"Solver sample rate is {self.cfg.sample_rate}."
        )
        # ensure we have matching configuration between LM and compression model
        long_card = self.cfg.transformer_lm_hier_long.card
        # The message must be a single string (comma-separated parts would form a tuple).
        assert long_card == self.compression_model.cardinality[0], (
            "Long cardinalities of the LM and compression model don't match: "
            f"LM cardinality is {long_card} vs "
            f"compression model cardinality is {self.compression_model.cardinality[0]}"
        )
        long_n_q = self.cfg.transformer_lm_hier_long.n_q
        if long_n_q != self.compression_model.num_codebooks[0]:
            self.logger.info(
                "Numbers of long codebooks of the LM and compression models don't match: "
                f"LM number of codebooks is {long_n_q} vs compression model number of "
                f"codebooks is {self.compression_model.num_codebooks[0]}",
            )
            self.logger.info(
                "Changing compression model's number of codebooks to match LM model."
            )
        # Copy before editing so the compression model's own container is not mutated
        # behind `set_num_codebooks`' back (and so a tuple would also work).
        num_codebooks = list(self.compression_model.num_codebooks)
        num_codebooks[0] = long_n_q
        self.compression_model.set_num_codebooks(num_codebooks)
        self.logger.info(
            f"Compression model has {self.compression_model.num_codebooks} codebooks with {self.compression_model.cardinality} cardinality, and a framerate of {self.compression_model.frame_rates}",
        )
        # instantiate LM model
        self.long_model = models.builders.get_lm_model(self.cfg)
        self.long_model.to(self.device)

        if self.cfg.init_from_valle is not None:
            self._init_long_model_from_valle(self.cfg.init_from_valle)

        if self.cfg.fsdp.use:
            assert not self.cfg.autocast, "Cannot use autocast with fsdp"
            self.long_model = self.wrap_with_fsdp(self.long_model)
        self.register_ema("long_model")
        # initialize optimization
        self.long_optimizer = builders.get_optimizer(
            builders.get_optim_parameter_groups(self.long_model), self.cfg.optim
        )
        self.long_lr_scheduler = builders.get_lr_scheduler(
            self.long_optimizer, self.cfg.schedule, self.total_updates
        )
        self.register_stateful(
            "long_model",
            "long_optimizer",
            "long_lr_scheduler",
        )
        self.register_best_state("long_model")
        self.autocast_dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16}[
            self.cfg.autocast_dtype
        ]
        self.scaler: tp.Optional[torch.cuda.amp.GradScaler] = None
        # a loss scaler is only needed when training in float16
        if self.cfg.fsdp.use:
            need_scaler = self.cfg.fsdp.param_dtype == "float16"
        else:
            need_scaler = self.cfg.autocast and self.autocast_dtype is torch.float16
        if need_scaler:
            if self.cfg.fsdp.use:
                from torch.distributed.fsdp.sharded_grad_scaler import \
                    ShardedGradScaler

                self.scaler = ShardedGradScaler()  # type: ignore
            else:
                self.scaler = torch.cuda.amp.GradScaler()
            self.register_stateful("scaler")

    def _init_long_model_from_valle(self, valle_checkpoint: tp.Any) -> None:
        """Copy every key shared with a VALL-E checkpoint's best state into ``long_model``.

        Parameters absent from the checkpoint keep their freshly initialized values.
        A warning is logged when no key matched, since that means nothing was loaded.
        """
        _checkpoint_path = checkpoint.resolve_checkpoint_path(
            str(valle_checkpoint), use_fsdp=False
        )
        self.logger.info(f"Loading VALL-E checkpoint from {_checkpoint_path}.")
        state = load_lm_model_ckpt(_checkpoint_path)["best_state"]
        # Solver checkpoints key `best_state` by stateful source name (this class reads
        # `best_state["long_model"]` in `model_from_checkpoint`). Unwrap a nested "model"
        # entry so parameter names can match; a flat state dict holds tensors, not dicts.
        # NOTE(review): "model" is assumed to be the VALL-E solver's source name — confirm.
        nested_state = state.get("model")
        if isinstance(nested_state, dict):
            state = nested_state
        new_state = self.long_model.state_dict()
        num_loaded = 0
        for k, v in state.items():
            # only keys present in the long model (e.g. encoder & decoder) are copied
            if k in new_state:
                new_state[k] = v
                num_loaded += 1
                self.logger.debug(f"Success: Key {k} loaded from the VALL-E checkpoint.")
            else:
                self.logger.debug(f"Fail: Key {k} not found in the model state dict.")
        self.long_model.load_state_dict(new_state)
        self.logger.info(
            "Loaded %d/%d keys from the VALL-E checkpoint.", num_loaded, len(state)
        )
        if num_loaded == 0:
            self.logger.warning(
                "No key of the VALL-E checkpoint matched the long model; "
                "the model keeps its random initialization."
            )
        self.logger.info("VALL-E checkpoint loaded.")

    def load_state_dict(self, state: dict) -> None:
        """Load a solver state, migrating legacy entries before delegating to the parent.

        * A separately stored ``condition_provider_long`` state is merged into
          ``state["long_model"]`` under the ``condition_provider.`` prefix, i.e. the
          attribute name under which the LM exposes its condition provider.
        * A stored ``compression_model`` state is loaded only to verify it matches the
          model read from ``cfg.compression_model_checkpoint``.

        Raises:
            RuntimeError: if the stored compression model differs from the configured one.
        """
        if "condition_provider_long" in state:
            model_state = state["long_model"]
            condition_provider_state = state.pop("condition_provider_long")
            # The LM accesses its provider as `self.long_model.condition_provider`, so its
            # parameters live under this prefix in the LM state dict.
            # NOTE(review): previously "condition_provider_long.", which matched no LM key.
            prefix = "condition_provider."
            for key, value in condition_provider_state.items():
                key = prefix + key
                assert key not in model_state
                model_state[key] = value
        if "compression_model" in state:
            # We used to store the `compression_model` state in the checkpoint, however
            # this is in general not needed, as the compression model should always be readable
            # from the original `cfg.compression_model_checkpoint` location.
            compression_model_state = state.pop("compression_model")
            before_hash = model_hash(self.compression_model)
            self.compression_model.load_state_dict(compression_model_state)
            after_hash = model_hash(self.compression_model)
            if before_hash != after_hash:
                raise RuntimeError(
                    "The compression model state inside the checkpoint is different"
                    " from the one obtained from compression_model_checkpoint..."
                    "We do not support altering the compression model inside the LM "
                    "checkpoint as parts of the code, in particular for running eval post-training "
                    "will use the compression_model_checkpoint as the source of truth."
                )

        super().load_state_dict(state)

    def save_checkpoints(self):
        """Save checkpoint, optionally keeping a copy for a given epoch.

        Without FSDP only rank 0 writes; with FSDP every rank saves its shard.
        """
        is_sharded = self.cfg.fsdp.use
        if not flashy.distrib.is_rank_zero() and not is_sharded:
            return
        self.logger.info("Long Model hash: %s", model_hash(self.long_model))
        state = self.state_dict()
        # pushing metrics will increase the epoch in Flashy, so we do -1 here
        epoch = self.epoch - 1

        # save minimal state_dict as new checkpoint every X epoch
        if self.cfg.checkpoint.save_every:
            if epoch % self.cfg.checkpoint.save_every == 0:
                minimal_state = state
                if (
                    self.cfg.checkpoint.keep_every_states is not None
                    and len(self.cfg.checkpoint.keep_every_states) > 0
                ):
                    minimal_state = {
                        name: source
                        for name, source in state.items()
                        if name in self.cfg.checkpoint.keep_every_states
                    }
                epoch_checkpoint_path = self.epoch_checkpoint_path(epoch)
                checkpoint.save_checkpoint(
                    minimal_state, epoch_checkpoint_path, is_sharded
                )

        # save checkpoint as latest checkpoint
        if self.cfg.checkpoint.save_last:
            last_checkpoint_path = self.checkpoint_path()
            checkpoint.save_checkpoint(state, last_checkpoint_path, is_sharded)

        # flush any stale checkpoint to reduce disk footprint
        checkpoint.flush_stale_checkpoints(self.checkpoint_path())

    def restore(
        self,
        load_best: bool = False,
        replay_metrics: bool = False,
        ignore_state_keys: tp.Optional[tp.List[str]] = None,
    ) -> bool:
        """Restore the status of a solver for a given xp.

        Args:
            load_best (bool): if `True`, load the best state from the checkpoint.
            replay_metrics (bool): if `True`, logs all the metrics from past epochs.
            ignore_state_keys (list of str, optional): list of sources to ignore when
                loading the state, e.g. `optimizer`. Defaults to no ignored source.
        Returns:
            bool: whether a checkpoint was restored.
        """
        # `None` default instead of a shared mutable `[]` default.
        if ignore_state_keys is None:
            ignore_state_keys = []
        self.logger.info("Restoring weights and history.")
        restored_checkpoints = self.load_checkpoints(load_best, ignore_state_keys)

        self.logger.info("Long Model hash: %s", model_hash(self.long_model))

        if replay_metrics and len(self.history) > 0:
            self.logger.info("Replaying past metrics...")
            for epoch, stages in enumerate(self.history):
                for stage_name, metrics in stages.items():
                    # We manually log the metrics summary to the result logger
                    # as we don't want to add them to the pending metrics
                    self.result_logger._log_summary(
                        stage_name,
                        metrics,
                        step=epoch + 1,
                        step_name="epoch",
                        formatter=self.get_formatter(stage_name),
                    )
        return restored_checkpoints is not None

    def _prepare_tokens_and_attributes(
        self,
        batch: tp.Tuple[torch.Tensor, tp.List[SegmentWithAttributes]],
        check_synchronization_points: bool = False,
    ) -> tp.Tuple[
        tp.Tuple[torch.Tensor, torch.Tensor], tp.Tuple[torch.Tensor, torch.Tensor]
    ]:
        """Turn a dataset batch into LM inputs for the long (stage 0) token stream.

        The audio is right-padded to a multiple of the stage-0 hop length and encoded with
        the compression model. Each sequence is then framed as SOS at index 0, the codes,
        EOS at index ``valid + 1`` and PAD from ``valid + 2`` on, where
        ``valid = ceil(infos[i].n_frames / hop_length)``.

        Args:
            batch: ``(audio, infos)`` where ``infos[i]`` describes ``audio[i]``.
            check_synchronization_points: when True and the device is ``"cuda"``,
                ``torch.cuda.set_sync_debug_mode("warn")`` is active after tokenization.

        Returns:
            ``((text_tokens, long_tokens), (text_mask, long_token_mask))``. ``long_tokens``
            is rank 3 with two extra time steps; ``long_token_mask`` has the same shape and
            is False exactly on PAD positions.
        """
        if self.long_model.training:
            warnings.warn(
                "Up to version 1.0.1, the _prepare_tokens_and_attributes was evaluated with `torch.no_grad()`. "
                "This is inconsistent with how model were trained in the MusicGen paper. We removed the "
                "`torch.no_grad()` in version 1.1.0. Small changes to the final performance are expected. "
                "Really sorry about that."
            )
        audio, infos = batch
        audio = audio.to(self.device)
        # The message must be a single string (comma-separated parts would form a tuple).
        assert audio.size(0) == len(infos), (
            f"Mismatch between number of items in audio batch ({audio.size(0)})"
            f" and in metadata ({len(infos)})"
        )
        # prepare attributes: wav conditions are deleted, only text conditioning is kept
        attributes = []
        for info in infos:
            attr = info.to_condition_attributes()
            attr.wav = {}  # delete wav condition
            attributes.append(attr)
        long_tokenized = self.long_model.condition_provider.tokenize(attributes)

        # Now we should be synchronization free.
        if self.device == "cuda" and check_synchronization_points:
            torch.cuda.set_sync_debug_mode("warn")

        hop_length = self.compression_model.hop_lengths[0]
        with torch.no_grad():
            # right-pad the audio up to the next multiple of the stage-0 hop length
            audio_length = audio.shape[-1]
            padded_length = (audio_length + hop_length - 1) // hop_length * hop_length
            audio = F.pad(audio, (0, padded_length - audio_length), value=0)
            tokens, scale = self.compression_model.encode(audio, stage=0)
            long_tokens = tokens[0]
            assert scale is None, "Scaled compression model not supported with LM."

        # only the text conditions are consumed by the long LM
        text_tokens, text_mask = long_tokenized["text"]

        # create a padding mask to hold valid vs invalid positions
        long_token_mask = torch.ones_like(
            long_tokens, dtype=torch.bool, device=long_tokens.device
        )
        # reserve one slot before (SOS) and one after (EOS) the codes;
        # F.pad allocates new tensors, so the in-place writes below are safe
        long_tokens = F.pad(long_tokens, (1, 1), value=0)
        long_token_mask = F.pad(long_token_mask, (1, 1), value=1)
        for i, info in enumerate(infos):
            # number of codes produced by the real (non-padded) audio of this item
            valid_tokens = math.ceil(info.n_frames / hop_length)
            long_tokens[i, :, 0] = self.long_model.audio_sos_token_id
            # a single EOS right after the last valid code, then PAD (masked out)
            long_tokens[i, :, valid_tokens + 1 :] = self.long_model.audio_eos_token_id
            long_tokens[i, :, valid_tokens + 2 :] = self.long_model.audio_pad_token_id
            long_token_mask[i, :, valid_tokens + 2 :] = 0

        if self.device == "cuda" and check_synchronization_points:
            torch.cuda.set_sync_debug_mode("default")
        return (
            (text_tokens, long_tokens),
            (text_mask, long_token_mask),
        )

    def run_step(
        self,
        idx: int,
        batch: tp.Tuple[torch.Tensor, tp.List[SegmentWithAttributes]],
        metrics: dict,
    ) -> dict:
        """Perform one training or valid step on a given batch.

        Computes the masked cross-entropy of the long LM and, when training, runs the
        backward pass and optimizer / scheduler / scaler updates. The non-finite check
        counter is averaged across workers, so a non-finite metric on any worker skips
        the update everywhere; a skipped step adds no entry to ``metrics``, which the
        caller uses to detect it.

        Returns:
            ``metrics`` updated with ``ce``, ``ppl`` and the ``long_*`` entries.
        """
        check_synchronization_points = idx == 1 and self.device == "cuda"
        long_lm_tokens, long_lm_masks = self._prepare_tokens_and_attributes(
            batch, check_synchronization_points
        )

        self.deadlock_detect.update("tokens_and_conditions")

        if check_synchronization_points:
            torch.cuda.set_sync_debug_mode("warn")

        # targets and mask of the long token stream (text tokens are not scored)
        _, l_token = long_lm_tokens
        _, l_mask = long_lm_masks

        long_metrics = {}
        with self.autocast:
            l_output = self.long_model.compute_predictions(
                long_lm_tokens,
            )  # type: ignore
            # long ce
            l_logits = l_output.logits
            # score only positions valid for both the data and the model output
            l_mask = l_mask & l_output.mask
            long_ce, long_ce_per_codebook = self._compute_cross_entropy(
                l_logits,
                l_token,
                l_mask,
            )
            loss = long_ce

        self.deadlock_detect.update("loss")

        if check_synchronization_points:
            torch.cuda.set_sync_debug_mode("default")

        long_metrics["long_ce"] = long_ce
        long_metrics["long_ppl"] = torch.exp(long_ce)
        for k, ce_q in enumerate(long_ce_per_codebook):
            long_metrics[f"long_ce_q{k + 1}"] = ce_q
            long_metrics[f"long_ppl_q{k + 1}"] = torch.exp(ce_q)

        if self.is_training:
            long_metrics["long_lr"] = self.long_optimizer.param_groups[0]["lr"]

            skip_update = torch.tensor([0], dtype=torch.float, device=self.device)
            for key, value in long_metrics.items():
                if isinstance(value, torch.Tensor) and not value.isfinite().all():
                    self.logger.warning(
                        f"Value of {key} is not finite. worldsize: {flashy.distrib.world_size()}, rank: {flashy.distrib.rank()}"
                    )
                    skip_update += 1
            self.deadlock_detect.update("update_check")
            # make the skip decision identical on all workers
            flashy.distrib.average_tensors(skip_update)
            if skip_update.item() > 0:
                self.logger.warning(
                    "long: skip update because of non-finite values in the metrics."
                )
                long_metrics = {}
                torch.cuda.empty_cache()
            else:
                if self.scaler is not None:
                    loss = self.scaler.scale(loss)
                self.deadlock_detect.update("scale")
                if self.cfg.fsdp.use:
                    loss.backward()
                    flashy.distrib.average_tensors(self.long_model.buffers())
                elif self.cfg.optim.eager_sync:
                    with flashy.distrib.eager_sync_model(self.long_model):
                        loss.backward()
                else:
                    # this should always be slower but can be useful
                    # for weird use cases like multiple backwards.
                    loss.backward()
                    flashy.distrib.sync_model(self.long_model)
                self.deadlock_detect.update("backward")
                if self.scaler is not None:
                    # gradients must be unscaled before clipping
                    self.scaler.unscale_(self.long_optimizer)
                if self.cfg.optim.max_norm:
                    if self.cfg.fsdp.use:
                        long_metrics["long_grad_norm"] = self.long_model.clip_grad_norm_(self.cfg.optim.max_norm)  # type: ignore
                    else:
                        long_metrics["long_grad_norm"] = torch.nn.utils.clip_grad_norm_(
                            self.long_model.parameters(), self.cfg.optim.max_norm
                        )
                if self.scaler is None:
                    self.long_optimizer.step()
                else:
                    self.scaler.step(self.long_optimizer)
                    self.scaler.update()
                if self.long_lr_scheduler:
                    self.long_lr_scheduler.step()
                self.long_optimizer.zero_grad()
                self.deadlock_detect.update("optim")
                if self.scaler is not None:
                    scale = self.scaler.get_scale()
                    long_metrics["long_grad_scale"] = scale

        if len(long_metrics) > 0:
            metrics["ce"] = long_ce
            metrics["ppl"] = torch.exp(long_ce)
        metrics.update(long_metrics)
        return metrics

    def common_train_valid(self, dataset_split: str, **kwargs: tp.Any):
        """Common logic for train and valid stages.

        Iterates over ``dataset_split``, runs ``run_step`` on each batch, steps the EMA
        only after non-skipped training steps, and returns the epoch-averaged metrics.
        """
        self.long_model.train(self.is_training)

        loader = self.dataloaders[dataset_split]
        # get a different order for distributed training, otherwise this will get ignored
        if flashy.distrib.world_size() > 1 and isinstance(
            loader.sampler, torch.utils.data.distributed.DistributedSampler
        ):
            loader.sampler.set_epoch(self.epoch)
        updates_per_epoch = (
            self.train_updates_per_epoch if self.is_training else len(loader)
        )
        if self.cfg.benchmark_no_load:
            self.logger.warning("Fake loading for benchmarking: re-using first batch")
            batch = next(iter(loader))
            loader = [batch] * updates_per_epoch  # type: ignore
        lp = self.log_progress(
            self.current_stage,
            loader,
            total=updates_per_epoch,
            updates=self.log_updates,
        )
        average = flashy.averager()  # epoch wise average
        instant_average = flashy.averager()  # average between two logging
        metrics: dict = {}

        with self.long_profiler, self.deadlock_detect:  # profiler will only run for the first 20 updates.
            for idx, batch in enumerate(lp):
                self.deadlock_detect.update("batch")
                if idx >= updates_per_epoch:
                    break
                metrics = {}
                metrics = self.run_step(idx, batch, metrics)
                self.deadlock_detect.update("step")
                # run EMA step
                if (
                    self.ema is not None
                    and self.is_training
                    and (idx + 1) % self.cfg.optim.ema.updates == 0
                    and len(metrics) > 0  # if the batch is not skipped
                ):
                    self.logger.debug("EMA model step")
                    self.ema.step()
                self.deadlock_detect.update("ema")
                self.long_profiler.step()
                instant_metrics = instant_average(metrics)
                if lp.update(**instant_metrics):
                    # reset averager between two logging
                    instant_average = flashy.averager()
                metrics = average(metrics)  # epoch wise average
                self.deadlock_detect.update("end_batch")

        metrics = flashy.distrib.average_metrics(metrics, updates_per_epoch)
        return metrics

    @torch.no_grad()
    def run_generate_step(
        self,
        batch: tp.Tuple[torch.Tensor, tp.List[SegmentWithAttributes]],
        gen_duration: float,
        prompt_duration: tp.Optional[float] = None,
        **generation_params,
    ) -> dict:
        """Generate audio continuing the first ``prompt_duration`` seconds of each item.

        Args:
            batch: ``(audio, meta)`` pair from the generate dataloader.
            gen_duration: duration in seconds used to size the generation and the RTF.
            prompt_duration: prompt length in seconds; required.
            **generation_params: sampling parameters forwarded to ``long_model.generate``.

        Returns:
            dict with ``rtf`` (wall-clock time divided by ``gen_duration``), ``ref_audio``,
            ``gen_audio``, ``gen_long_tokens``, ``prompt_audio`` and ``prompt_long_tokens``.

        Raises:
            ValueError: if ``prompt_duration`` is None.
        """
        if prompt_duration is None:
            raise ValueError("Prompt duration must be provided for audio generation.")
        # monotonic, high-resolution clock for measuring the real-time factor
        bench_start = time.perf_counter()
        audio, meta = batch
        # The message must be a single string (comma-separated parts would form a tuple).
        assert audio.size(0) == len(meta), (
            f"Mismatch between number of items in audio batch ({audio.size(0)})"
            f" and in metadata ({len(meta)})"
        )
        # prepare attributes
        attributes = [x.to_condition_attributes() for x in meta]

        # get audio tokens of the prompt from the compression model
        prompt_audio_frames = int(prompt_duration * self.compression_model.sample_rate)
        prompt_audio = audio[..., :prompt_audio_frames]

        num_samples = None
        prompt_audio = prompt_audio.to(self.device)
        prompt_tokens, scale = self.compression_model.encode(
            prompt_audio, stage=0
        )
        prompt_long_tokens = prompt_tokens[0]
        assert (
            scale is None
        ), "Compression model in MusicGen should not require rescaling."

        # generate by sampling from the LM
        with self.autocast:
            total_gen_len = math.ceil(
                gen_duration * self.compression_model.frame_rates[0]
            )
            # gen_long_tokens: [B, K, T_s]
            gen_long_tokens = self.long_model.generate(
                prompt_long_tokens,
                attributes,
                max_gen_len=total_gen_len,
                num_samples=num_samples,
                **generation_params,
            )
        assert gen_long_tokens.dim() == 3
        gen_long_tokens, valid_lengths = SpeechGenSolver._postprocess_codes(
            gen_long_tokens,
            self.long_model.audio_sos_token_id,
            self.long_model.audio_eos_token_id,
            self.compression_model.silent_tokens[0],
        )

        # generate audio from tokens
        gen_audio = self.compression_model.decode(
            [gen_long_tokens],
            None
        )
        gen_audio = SpeechGenSolver._postprocess_audios(
            gen_audio,
            valid_lengths,
            self.compression_model.frame_rates[0],
            self.compression_model.sample_rate,
        )

        bench_end = time.perf_counter()
        gen_outputs = {
            "rtf": (bench_end - bench_start) / gen_duration,
            "ref_audio": audio,
            "gen_audio": gen_audio,
            "gen_long_tokens": gen_long_tokens,
            "prompt_audio": prompt_audio,
            "prompt_long_tokens": prompt_long_tokens,
        }
        return gen_outputs

    def generate_audio(self) -> dict:
        """Audio generation stage.

        Runs prompted generation on every batch of the ``generate`` split, stores samples
        through ``SampleManager`` and returns the averaged real-time factor (``rtf``).

        Raises:
            ValueError: if the target or prompt duration is not configured.
            NotImplementedError: if unprompted samples are requested.
        """
        generate_stage_name = f"{self.current_stage}"
        sample_manager = SampleManager(self.xp)
        self.logger.info(f"Generating samples in {sample_manager.base_folder}")
        loader = self.dataloaders["generate"]
        updates = len(loader)
        lp = self.log_progress(
            generate_stage_name, loader, total=updates, updates=self.log_updates
        )

        dataset = get_dataset_from_loader(loader)
        assert isinstance(dataset, AudioDataset)
        target_duration = self.cfg.generate.lm.gen_duration
        prompt_duration = self.cfg.generate.lm.prompt_duration
        if target_duration is None:
            raise ValueError("Target duration must be provided for audio generation.")
        if prompt_duration is None:
            raise ValueError("Prompt duration must be provided for audio generation.")
        # configuration-only check: fail before iterating over the data
        if self.cfg.generate.lm.unprompted_samples:
            raise NotImplementedError("Unprompted samples are not supported.")

        def get_hydrated_conditions(meta: tp.List[SegmentWithAttributes]):
            """Return, per sample, a loggable dict of the conditions the LM consumes."""
            hydrated_conditions = []
            conditioner_keys = self.long_model.condition_provider.conditioners.keys()
            for sample in [x.to_condition_attributes() for x in meta]:
                cond_dict = {}
                for cond_type in sample.__annotations__.keys():
                    for cond_key, cond_val in getattr(sample, cond_type).items():
                        if cond_key not in conditioner_keys:
                            continue
                        if is_jsonable(cond_val):
                            cond_dict[cond_key] = cond_val
                        elif isinstance(cond_val, WavCondition):
                            cond_dict[cond_key] = cond_val.path
                        elif isinstance(cond_val, JointEmbedCondition):
                            # only support text at inference for now
                            cond_dict[cond_key] = cond_val.text
                        else:
                            # if we reached this point, it is not clear how to log the condition
                            # so we just log the type.
                            cond_dict[cond_key] = str(type(cond_val))
                hydrated_conditions.append(cond_dict)
            return hydrated_conditions

        metrics: dict = {}
        average = flashy.averager()
        for batch in lp:
            audio, meta = batch
            # metadata for sample manager
            hydrated_conditions = get_hydrated_conditions(meta)

            if self.cfg.generate.lm.prompted_samples:
                gen_outputs = self.run_generate_step(
                    batch,
                    gen_duration=target_duration,
                    prompt_duration=prompt_duration,
                    **self.generation_params,
                )
                gen_audio = gen_outputs["gen_audio"].cpu()
                prompt_audio = gen_outputs["prompt_audio"].cpu()
                sample_manager.add_samples(
                    gen_audio,
                    self.epoch,
                    hydrated_conditions,
                    prompt_wavs=prompt_audio,
                    ground_truth_wavs=audio,
                    generation_args=self.generation_params,
                )
                metrics["rtf"] = gen_outputs["rtf"]

            metrics = average(metrics)

        flashy.distrib.barrier()
        return metrics

    def evaluate_audio_generation(self) -> dict:
        """Evaluate audio generation with off-the-shelf metrics (not implemented: returns {})."""
        # TODO: Implement audio generation evaluation
        return {}

    @staticmethod
    def model_from_checkpoint(
        checkpoint_path: tp.Union[Path, str],
        device: tp.Union[torch.device, str] = "cpu",
    ) -> models.LMModel:
        """Rebuild the long LM from a solver checkpoint and load its best weights.

        The model is built from the checkpoint's ``xp.cfg`` (dtype float32 on CPU,
        float16 otherwise) and loaded from ``fsdp_best_state["long_model"]`` when that
        entry is present and non-empty, else from ``best_state["long_model"]``.
        The returned model is in eval mode and carries its config as ``.cfg``.
        """
        _checkpoint_path = checkpoint.resolve_checkpoint_path(
            str(checkpoint_path), use_fsdp=False
        )
        pkg = load_lm_model_ckpt(_checkpoint_path)
        cfg = OmegaConf.create(pkg["xp.cfg"])
        cfg.device = str(device)
        if cfg.device == "cpu":
            cfg.dtype = "float32"
        else:
            cfg.dtype = "float16"
        _delete_param(cfg, "conditioners.self_wav.chroma_stem.cache_path")
        _delete_param(cfg, "conditioners.args.merge_text_conditions_p")
        _delete_param(cfg, "conditioners.args.drop_desc_p")
        long_model = get_lm_model(cfg)
        # `.get` so that checkpoints without an FSDP best state do not raise KeyError
        fsdp_best_state = pkg.get("fsdp_best_state")
        if fsdp_best_state:
            long_model.load_state_dict(fsdp_best_state["long_model"])
        else:
            long_model.load_state_dict(pkg["best_state"]["long_model"])
        long_model.eval()
        long_model.cfg = cfg
        return long_model

    def evaluate(self) -> dict:
        """Evaluate stage: optional base metrics plus audio-generation metrics."""
        self.long_model.eval()
        with torch.no_grad():
            metrics: dict = {}
            if self.cfg.evaluate.metrics.base:
                metrics.update(self.common_train_valid("evaluate"))
            gen_metrics = self.evaluate_audio_generation()
            return {**metrics, **gen_metrics}

    def generate(self) -> dict:
        """Generate stage: run `generate_audio` in eval mode without gradients."""
        self.long_model.eval()
        with torch.no_grad():
            return self.generate_audio()