from typing import Sequence
import random
from typing import Any

from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
import diffusers.schedulers as noise_schedulers
from diffusers.schedulers.scheduling_utils import SchedulerMixin
from diffusers.utils.torch_utils import randn_tensor

from models.autoencoder.autoencoder_base import AutoEncoderBase
from models.content_encoder.content_encoder import ContentEncoder
from models.content_adapter import ContentAdapterBase, ContentEncoderAdapterMixin
from models.common import (
    LoadPretrainedBase, CountParamsBase, SaveTrainableParamsBase,
    DurationAdapterMixin
)
from utils.torch_utilities import (
    create_alignment_path, create_mask_from_length, loss_with_mask,
    trim_or_pad_length
)


class DiffusionMixin:
    def __init__(
        self,
        noise_scheduler_name: str = "stabilityai/stable-diffusion-2-1",
        snr_gamma: float = None,
        cfg_drop_ratio: float = 0.2
    ) -> None:
        self.noise_scheduler_name = noise_scheduler_name
        self.snr_gamma = snr_gamma
        self.classifier_free_guidance = cfg_drop_ratio > 0.0
        self.cfg_drop_ratio = cfg_drop_ratio
        self.noise_scheduler = noise_schedulers.DDPMScheduler.from_pretrained(
            self.noise_scheduler_name, subfolder="scheduler"
        )

    def compute_snr(self, timesteps) -> torch.Tensor:
        """
        Computes SNR as per https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849
        """
        alphas_cumprod = self.noise_scheduler.alphas_cumprod
        sqrt_alphas_cumprod = alphas_cumprod**0.5
        sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod)**0.5

        # Expand the tensors.
        # Adapted from https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026
        sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device
                                                    )[timesteps].float()
        while len(sqrt_alphas_cumprod.shape) < len(timesteps.shape):
            sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None]
        alpha = sqrt_alphas_cumprod.expand(timesteps.shape)

        sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(
            device=timesteps.device
        )[timesteps].float()
        while len(sqrt_one_minus_alphas_cumprod.shape) < len(timesteps.shape):
            sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[...,
                                                                          None]
        sigma = sqrt_one_minus_alphas_cumprod.expand(timesteps.shape)

        # Compute SNR.
        snr = (alpha / sigma)**2
        return snr

    def get_timesteps(
        self,
        batch_size: int,
        device: torch.device,
        training: bool = True
    ) -> torch.Tensor:
        if training:
            timesteps = torch.randint(
                0,
                self.noise_scheduler.config.num_train_timesteps,
                (batch_size, ),
                device=device
            )
        else:
            # validation on half of the total timesteps
            timesteps = (self.noise_scheduler.config.num_train_timesteps //
                         2) * torch.ones((batch_size, ),
                                         dtype=torch.int64,
                                         device=device)

        timesteps = timesteps.long()
        return timesteps

    def get_input_target_and_timesteps(
        self,
        latent: torch.Tensor,
        training: bool,
    ):
        batch_size = latent.shape[0]
        device = latent.device
        num_train_timesteps = self.noise_scheduler.config.num_train_timesteps
        self.noise_scheduler.set_timesteps(num_train_timesteps, device=device)
        timesteps = self.get_timesteps(batch_size, device, training=training)
        noise = torch.randn_like(latent)
        noisy_latent = self.noise_scheduler.add_noise(latent, noise, timesteps)
        target = self.get_target(latent, noise, timesteps)
        return noisy_latent, target, timesteps

    def get_target(
        self, latent: torch.Tensor, noise: torch.Tensor,
        timesteps: torch.Tensor
    ) -> torch.Tensor:
        """
        Get the target for loss depending on the prediction type
        """
        if self.noise_scheduler.config.prediction_type == "epsilon":
            target = noise
        elif self.noise_scheduler.config.prediction_type == "v_prediction":
            target = self.noise_scheduler.get_velocity(
                latent, noise, timesteps
            )
        else:
            raise ValueError(
                f"Unknown prediction type {self.noise_scheduler.config.prediction_type}"
            )
        return target

    def loss_with_snr(
        self,
        pred: torch.Tensor,
        target: torch.Tensor,
        timesteps: torch.Tensor,
        mask: torch.Tensor,
        reduce: bool = True
    ) -> torch.Tensor:
        if self.snr_gamma is None:
            loss = F.mse_loss(pred.float(), target.float(), reduction="none")
            loss = loss_with_mask(loss, mask, reduce=reduce)
        else:
            # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556.
            # Adapted from https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image.py#L1006
            snr = self.compute_snr(timesteps)
            mse_loss_weights = torch.stack(
                [
                    snr,
                    self.snr_gamma * torch.ones_like(timesteps),
                ],
                dim=1,
            ).min(dim=1)[0]
            # division by (snr + 1) does not work well, not clear about the reason
            mse_loss_weights = mse_loss_weights / snr
            loss = F.mse_loss(pred.float(), target.float(), reduction="none")
            loss = loss_with_mask(loss, mask, reduce=False) * mse_loss_weights
            if reduce:
                loss = loss.mean()
        return loss

    def rescale_cfg(
        self, pred_cond: torch.Tensor, pred_cfg: torch.Tensor,
        guidance_rescale: float
    ):
        """
        Rescale `pred_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
        Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
        """
        std_cond = pred_cond.std(
            dim=list(range(1, pred_cond.ndim)), keepdim=True
        )
        std_cfg = pred_cfg.std(dim=list(range(1, pred_cfg.ndim)), keepdim=True)

        pred_rescaled = pred_cfg * (std_cond / std_cfg)
        pred_cfg = guidance_rescale * pred_rescaled + (
            1 - guidance_rescale
        ) * pred_cfg
        return pred_cfg


class SingleTaskCrossAttentionAudioDiffusion(
    LoadPretrainedBase, CountParamsBase, SaveTrainableParamsBase,
    DiffusionMixin, ContentEncoderAdapterMixin
):
    """
    Latent diffusion model whose backbone is conditioned on encoded content
    through its ``context`` / ``context_mask`` inputs.

    Waveforms are encoded to latents by a frozen autoencoder; the backbone is
    trained to predict the scheduler's target (noise or velocity) from the
    noisy latent.
    """

    def __init__(
        self,
        autoencoder: AutoEncoderBase,
        content_encoder: ContentEncoder,
        backbone: nn.Module,
        noise_scheduler_name: str = "stabilityai/stable-diffusion-2-1",
        snr_gamma: float | None = None,
        cfg_drop_ratio: float = 0.2,
    ):
        """
        Args:
            autoencoder: Waveform <-> latent autoencoder; all its parameters
                are set to ``requires_grad=False``.
            content_encoder: Encoder producing the conditioning content.
            backbone: Denoising network.
            noise_scheduler_name: Model id/path of the DDPM scheduler config.
            snr_gamma: Min-SNR clipping value, or ``None`` for plain MSE.
            cfg_drop_ratio: Per-sample condition-drop probability in training.
        """
        nn.Module.__init__(self)
        DiffusionMixin.__init__(
            self, noise_scheduler_name, snr_gamma, cfg_drop_ratio
        )
        ContentEncoderAdapterMixin.__init__(
            self, content_encoder=content_encoder
        )

        self.autoencoder = autoencoder
        for param in self.autoencoder.parameters():
            param.requires_grad = False

        # An audio content encoder, if present, reuses this frozen autoencoder.
        if hasattr(self.content_encoder, "audio_encoder"):
            self.content_encoder.audio_encoder.model = self.autoencoder

        self.backbone = backbone
        # Empty parameter used to query the module's current device.
        self.dummy_param = nn.Parameter(torch.empty(0))

    def forward(
        self, content: list[Any], condition: list[Any], task: list[str],
        waveform: torch.Tensor, waveform_lengths: torch.Tensor, **kwargs
    ):
        """
        Compute the diffusion training loss for a batch.

        Args:
            content: Raw content items passed to ``encode_content``.
            condition: Unused by this method.
            task: Task name per sample.
            waveform: Batched waveforms; a channel dim is inserted at dim 1
                before encoding (presumably input is (batch, samples)).
            waveform_lengths: Valid length of each waveform.

        Returns:
            The masked (optionally min-SNR weighted) diffusion loss.
        """
        device = self.dummy_param.device

        self.autoencoder.eval()
        with torch.no_grad():
            latent, latent_mask = self.autoencoder.encode(
                waveform.unsqueeze(1), waveform_lengths
            )

        content_dict = self.encode_content(content, task, device)
        content, content_mask = content_dict["content"], content_dict[
            "content_mask"]

        # Zero the content of randomly chosen samples so the model also learns
        # the unconditional prediction used by classifier-free guidance.
        if self.training and self.classifier_free_guidance:
            mask_indices = [
                k for k in range(len(waveform))
                if random.random() < self.cfg_drop_ratio
            ]
            if len(mask_indices) > 0:
                content[mask_indices] = 0

        noisy_latent, target, timesteps = self.get_input_target_and_timesteps(
            latent, self.training
        )

        pred: torch.Tensor = self.backbone(
            x=noisy_latent,
            timesteps=timesteps,
            context=content,
            x_mask=latent_mask,
            context_mask=content_mask
        )

        # Move the autoencoder's time axis to dim 1 before masking the loss.
        pred = pred.transpose(1, self.autoencoder.time_dim)
        target = target.transpose(1, self.autoencoder.time_dim)
        loss = self.loss_with_snr(pred, target, timesteps, latent_mask)

        return loss

    def prepare_latent(
        self, batch_size: int, scheduler: SchedulerMixin,
        latent_shape: Sequence[int], dtype: torch.dtype, device: str
    ):
        """Draw Gaussian noise of shape ``(batch_size, *latent_shape)``."""
        shape = (batch_size, *latent_shape)
        latent = randn_tensor(
            shape, generator=None, device=device, dtype=dtype
        )
        # scale the initial noise by the standard deviation required by the scheduler
        latent = latent * scheduler.init_noise_sigma
        return latent

    def iterative_denoise(
        self,
        latent: torch.Tensor,
        scheduler: SchedulerMixin,
        verbose: bool,
        cfg: bool,
        cfg_scale: float,
        cfg_rescale: float,
        backbone_input: dict,
    ):
        """
        Run the reverse diffusion loop over ``scheduler.timesteps``.

        Args:
            latent: Initial noisy latent.
            scheduler: Scheduler whose ``set_timesteps`` was already called.
            verbose: Show a tqdm progress bar.
            cfg: Apply classifier-free guidance. The latent is duplicated and
                the prediction split with ``chunk(2)``, so tensors in
                ``backbone_input`` must stack the unconditional half first.
            cfg_scale: Guidance scale.
            cfg_rescale: Guidance-rescale factor; 0 disables rescaling.
            backbone_input: Extra keyword arguments for the backbone.

        Returns:
            The denoised latent.
        """
        timesteps = scheduler.timesteps
        # One tick per model evaluation so the bar's total and its updates
        # always agree, including for higher-order schedulers.
        with tqdm(total=len(timesteps), disable=not verbose) as progress_bar:
            for timestep in timesteps:
                # expand the latent if we are doing classifier free guidance
                latent_input = torch.cat([latent, latent]) if cfg else latent
                latent_input = scheduler.scale_model_input(
                    latent_input, timestep
                )

                noise_pred = self.backbone(
                    x=latent_input, timesteps=timestep, **backbone_input
                )

                # perform guidance
                if cfg:
                    noise_pred_uncond, noise_pred_content = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + cfg_scale * (
                        noise_pred_content - noise_pred_uncond
                    )
                    if cfg_rescale != 0.0:
                        noise_pred = self.rescale_cfg(
                            noise_pred_content, noise_pred, cfg_rescale
                        )

                # compute the previous noisy sample x_t -> x_t-1
                latent = scheduler.step(
                    noise_pred, timestep, latent
                ).prev_sample
                progress_bar.update(1)

        return latent

    @torch.no_grad()
    def inference(
        self,
        content: list[Any],
        condition: list[Any],
        task: list[str],
        latent_shape: Sequence[int],
        scheduler: SchedulerMixin,
        num_steps: int = 50,
        guidance_scale: float = 3.0,
        guidance_rescale: float = 0.0,
        disable_progress: bool = True,
        **kwargs
    ):
        """
        Generate waveforms conditioned on ``content``.

        Classifier-free guidance is used when ``guidance_scale > 1``; the
        unconditional branch uses all-zero content with the same mask.

        Returns:
            The output of ``self.autoencoder.decode`` on the final latent.
        """
        device = self.dummy_param.device
        classifier_free_guidance = guidance_scale > 1.0
        batch_size = len(content)

        content_output: dict[str, torch.Tensor] = self.encode_content(
            content, task, device
        )
        content, content_mask = content_output["content"], content_output[
            "content_mask"]

        if classifier_free_guidance:
            uncond_content = torch.zeros_like(content)
            uncond_content_mask = content_mask.detach().clone()
            content = torch.cat([uncond_content, content])
            content_mask = torch.cat([uncond_content_mask, content_mask])

        scheduler.set_timesteps(num_steps, device=device)

        latent = self.prepare_latent(
            batch_size, scheduler, latent_shape, content.dtype, device
        )
        latent = self.iterative_denoise(
            latent=latent,
            scheduler=scheduler,
            verbose=not disable_progress,
            cfg=classifier_free_guidance,
            cfg_scale=guidance_scale,
            cfg_rescale=guidance_rescale,
            backbone_input={
                "context": content,
                "context_mask": content_mask
            },
        )

        waveform = self.autoencoder.decode(latent)

        return waveform


class CrossAttentionAudioDiffusion(
    SingleTaskCrossAttentionAudioDiffusion, DurationAdapterMixin
):
    """
    Multi-task cross-attention diffusion model.

    The content is fused with a task instruction by a content adapter, and a
    global (utterance-level) duration is predicted. At inference, that
    duration sets the latent length.
    """

    def __init__(
        self,
        autoencoder: AutoEncoderBase,
        content_encoder: ContentEncoder,
        content_adapter: ContentAdapterBase,
        backbone: nn.Module,
        content_dim: int | None = None,
        frame_resolution: float | None = None,
        duration_offset: float = 1.0,
        noise_scheduler_name: str = "stabilityai/stable-diffusion-2-1",
        snr_gamma: float | None = None,
        cfg_drop_ratio: float = 0.2,
    ):
        """
        Args:
            autoencoder: Frozen waveform <-> latent autoencoder; its
                ``latent_token_rate`` configures the duration adapter.
            content_encoder: Encoder producing the conditioning content.
            content_adapter: Adapter fusing instruction and content and
                predicting durations.
            backbone: Denoising network.
            content_dim: Accepted for interface compatibility; unused here.
            frame_resolution: Accepted for interface compatibility; unused here.
            duration_offset: Offset passed to ``DurationAdapterMixin``.
            noise_scheduler_name: Model id/path of the DDPM scheduler config.
            snr_gamma: Min-SNR clipping value, or ``None`` for plain MSE.
            cfg_drop_ratio: Per-sample condition-drop probability in training.
        """
        super().__init__(
            autoencoder=autoencoder,
            content_encoder=content_encoder,
            backbone=backbone,
            noise_scheduler_name=noise_scheduler_name,
            snr_gamma=snr_gamma,
            cfg_drop_ratio=cfg_drop_ratio
        )
        # Re-initialize the adapter mixin, this time with the content adapter.
        ContentEncoderAdapterMixin.__init__(
            self,
            content_encoder=content_encoder,
            content_adapter=content_adapter,
        )
        DurationAdapterMixin.__init__(
            self,
            latent_token_rate=autoencoder.latent_token_rate,
            offset=duration_offset,
        )

    def encode_content_with_instruction(
        self,
        content: list[Any],
        task: list[str],
        device: str | torch.device,
        instruction: torch.Tensor,
        instruction_lengths: torch.Tensor,
    ):
        """
        Encode content together with the task instruction.

        Returns:
            ``(content, content_mask, global_duration_pred,
            local_duration_pred, length_aligned_content)`` taken from the
            dict returned by ``encode_content``.
        """
        content_dict = self.encode_content(
            content, task, device, instruction, instruction_lengths
        )
        return (
            content_dict["content"],
            content_dict["content_mask"],
            content_dict["global_duration_pred"],
            content_dict["local_duration_pred"],
            content_dict["length_aligned_content"],
        )

    def forward(
        self,
        content: list[Any],
        task: list[str],
        waveform: torch.Tensor,
        waveform_lengths: torch.Tensor,
        instruction: torch.Tensor,
        instruction_lengths: Sequence[int],
        loss_reduce: bool = True,
        **kwargs
    ):
        """
        Compute the diffusion loss and the global duration loss.

        Args:
            loss_reduce: Only honoured in eval mode. Training always reduces.
                When False, both losses are returned unreduced.

        Returns:
            ``{"diff_loss": ..., "global_duration_loss": ...}``.
        """
        device = self.dummy_param.device
        # Training always reduces; eval mode may request per-sample losses.
        loss_reduce = bool(self.training or loss_reduce)

        self.autoencoder.eval()
        with torch.no_grad():
            latent, latent_mask = self.autoencoder.encode(
                waveform.unsqueeze(1), waveform_lengths
            )

        content, content_mask, global_duration_pred, _, _ = \
            self.encode_content_with_instruction(
                content, task, device, instruction, instruction_lengths
            )
        global_duration_loss = self.get_global_duration_loss(
            global_duration_pred, latent_mask, reduce=loss_reduce
        )

        # Zero the content of randomly chosen samples (classifier-free guidance).
        if self.training and self.classifier_free_guidance:
            mask_indices = [
                k for k in range(len(waveform))
                if random.random() < self.cfg_drop_ratio
            ]
            if len(mask_indices) > 0:
                content[mask_indices] = 0

        noisy_latent, target, timesteps = self.get_input_target_and_timesteps(
            latent, training=self.training
        )

        pred: torch.Tensor = self.backbone(
            x=noisy_latent,
            timesteps=timesteps,
            context=content,
            x_mask=latent_mask,
            context_mask=content_mask
        )

        pred = pred.transpose(1, self.autoencoder.time_dim)
        target = target.transpose(1, self.autoencoder.time_dim)
        # Honour `loss_reduce` here too, so both returned losses agree in shape.
        diff_loss = self.loss_with_snr(
            pred, target, timesteps, latent_mask, reduce=loss_reduce
        )

        return {
            "diff_loss": diff_loss,
            "global_duration_loss": global_duration_loss,
        }

    @torch.no_grad()
    def inference(
        self,
        content: list[Any],
        condition: list[Any],
        task: list[str],
        is_time_aligned: Sequence[bool],
        instruction: torch.Tensor,
        instruction_lengths: Sequence[int],
        scheduler: SchedulerMixin,
        num_steps: int = 50,
        guidance_scale: float = 3.0,
        guidance_rescale: float = 0.0,
        disable_progress: bool = True,
        use_gt_duration: bool = False,
        **kwargs
    ):
        """
        Generate waveforms whose latent length comes from the predicted
        global duration.

        Raises:
            NotImplementedError: if ``use_gt_duration`` is True.

        Returns:
            The output of ``self.autoencoder.decode`` on the final latent.
        """
        device = self.dummy_param.device
        classifier_free_guidance = guidance_scale > 1.0

        (
            content,
            content_mask,
            global_duration_pred,
            local_duration_pred,
            _,
        ) = self.encode_content_with_instruction(
            content, task, device, instruction, instruction_lengths
        )

        if use_gt_duration:
            raise NotImplementedError(
                "Using ground truth global duration only is not implemented yet"
            )

        # prepare global duration
        global_duration = self.prepare_global_duration(
            global_duration_pred,
            local_duration_pred,
            is_time_aligned,
            use_local=False
        )
        latent_length = torch.round(global_duration * self.latent_token_rate)
        latent_mask = create_mask_from_length(latent_length).to(device)
        max_latent_length = latent_mask.sum(1).max().item()

        # Build backbone conditioning. With CFG, every batched input (including
        # x_mask) is doubled as [uncond, cond] to match the duplicated latent
        # inside `iterative_denoise`.
        if classifier_free_guidance:
            uncond_content = torch.zeros_like(content)
            uncond_content_mask = content_mask.detach().clone()
            context = torch.cat([uncond_content, content])
            context_mask = torch.cat([uncond_content_mask, content_mask])
            x_mask = torch.cat([latent_mask, latent_mask.detach().clone()])
        else:
            context = content
            context_mask = content_mask
            x_mask = latent_mask

        batch_size = content.size(0)
        latent_shape = tuple(
            max_latent_length if dim is None else dim
            for dim in self.autoencoder.latent_shape
        )

        # Set the inference schedule before drawing the initial noise: for
        # some schedulers `init_noise_sigma` depends on it.
        scheduler.set_timesteps(num_steps, device=device)
        latent = self.prepare_latent(
            batch_size, scheduler, latent_shape, content.dtype, device
        )
        latent = self.iterative_denoise(
            latent=latent,
            scheduler=scheduler,
            verbose=not disable_progress,
            cfg=classifier_free_guidance,
            cfg_scale=guidance_scale,
            cfg_rescale=guidance_rescale,
            backbone_input={
                "x_mask": x_mask,
                "context": context,
                "context_mask": context_mask,
            }
        )

        waveform = self.autoencoder.decode(latent)

        return waveform


class DummyContentAudioDiffusion(CrossAttentionAudioDiffusion):
    """
    Diffusion model with two conditioning streams: cross-attention context
    for non-time-aligned tasks, and a duration-expanded time-aligned context
    for time-aligned tasks.

    For each sample, the stream its task does not use is filled with a
    learnable dummy embedding.
    """

    def __init__(
        self,
        autoencoder: AutoEncoderBase,
        content_encoder: ContentEncoder,
        content_adapter: ContentAdapterBase,
        backbone: nn.Module,
        content_dim: int,
        frame_resolution: float,
        duration_offset: float = 1.0,
        noise_scheduler_name: str = "stabilityai/stable-diffusion-2-1",
        snr_gamma: float | None = None,
        cfg_drop_ratio: float = 0.2,
    ):
        """
        Args:
            autoencoder:
                Pretrained audio autoencoder that encodes raw waveforms into latent
                space and decodes latents back to waveforms.
            content_encoder:
                Module that produces content embeddings (e.g., from text, MIDI, or
                other modalities) used to guide the diffusion.
            content_adapter (ContentAdapterBase):
                Adapter module that fuses task instruction embeddings and content embeddings,
                and performs duration prediction for time-aligned tasks.
            backbone:
                U-Net or Transformer backbone that performs the core denoising
                operations in latent space.
            content_dim:
                Dimension of the content embeddings produced by the `content_encoder`
                and `content_adapter`.
            frame_resolution:
                Time resolution, in seconds, of each content frame when predicting
                duration alignment. Used when calculating duration loss.
            duration_offset:
                A small positive offset (frame number) added to predicted durations
                to ensure numerical stability of log-scaled duration prediction.
            noise_scheduler_name:
                Identifier of the pretrained noise scheduler to use.
            snr_gamma:
                Clipping value in min-SNR diffusion loss weighting strategy.
            cfg_drop_ratio:
                Probability of dropping the content conditioning during training
                to support CFG.
        """
        super().__init__(
            autoencoder=autoencoder,
            content_encoder=content_encoder,
            content_adapter=content_adapter,
            backbone=backbone,
            duration_offset=duration_offset,
            noise_scheduler_name=noise_scheduler_name,
            snr_gamma=snr_gamma,
            cfg_drop_ratio=cfg_drop_ratio,
        )
        self.frame_resolution = frame_resolution
        # Learnable placeholders. `dummy_nta_embed` replaces the cross-attention
        # context of time-aligned samples. `dummy_ta_embed` replaces the
        # time-aligned context of non-time-aligned samples.
        self.dummy_nta_embed = nn.Parameter(torch.zeros(content_dim))
        self.dummy_ta_embed = nn.Parameter(torch.zeros(content_dim))

    def get_backbone_input(
        self,
        target_length: int,
        content: torch.Tensor,
        content_mask: torch.Tensor,
        time_aligned_content: torch.Tensor,
        length_aligned_content: torch.Tensor,
        is_time_aligned: torch.Tensor,
    ):
        """
        Build the backbone's cross-attention context and time-aligned context.

        Args:
            target_length: Length (along dim 1) that the time-aligned streams
                are trimmed or padded to.
            content: Content embeddings. Not modified; a copy is edited.
            content_mask: Validity mask for ``content``.
            time_aligned_content: Duration-expanded content.
            length_aligned_content: Content from already-aligned input; added
                to ``time_aligned_content``.
            is_time_aligned: Boolean mask over the batch dimension.

        Returns:
            ``(context, context_mask, time_aligned_content)``. Time-aligned rows
            keep only the first (dummy) context token. The context is cut to
            the longest valid non-time-aligned content when any such row exists.
        """
        # TODO compatibility for 2D spectrogram VAE
        time_aligned_content = trim_or_pad_length(
            time_aligned_content, target_length, 1
        )
        length_aligned_content = trim_or_pad_length(
            length_aligned_content, target_length, 1
        )
        # time_aligned_content: from monotonic aligned input, without frame expansion (phoneme)
        # length_aligned_content: from aligned input (f0/energy)
        time_aligned_content = time_aligned_content + length_aligned_content
        time_aligned_content[~is_time_aligned] = self.dummy_ta_embed.to(
            time_aligned_content.dtype
        )

        # Clone so the caller's `content` tensor is not overwritten in place.
        context = content.clone()
        context[is_time_aligned] = self.dummy_nta_embed.to(context.dtype)
        # only use the first dummy non time aligned embedding
        context_mask = content_mask.detach().clone()
        context_mask[is_time_aligned, 1:] = False

        # truncate dummy non time aligned context
        if is_time_aligned.sum().item() < content.size(0):
            trunc_nta_length = content_mask[~is_time_aligned].sum(1).max()
        else:
            trunc_nta_length = content.size(1)
        context = context[:, :trunc_nta_length]
        context_mask = context_mask[:, :trunc_nta_length]

        return context, context_mask, time_aligned_content

    def forward(
        self,
        content: list[Any],
        task: list[str],
        is_time_aligned: torch.Tensor,
        duration: torch.Tensor,
        waveform: torch.Tensor,
        waveform_lengths: torch.Tensor,
        instruction: torch.Tensor,
        instruction_lengths: Sequence[int],
        loss_reduce: bool = True,
        **kwargs
    ):
        """
        Compute the diffusion, local-duration and global-duration losses.

        Args:
            is_time_aligned: Boolean tensor mask over the batch.
            duration: Per-token duration tensor. It is right-padded to the
                content length for batches with no time-aligned sample.
            loss_reduce: Only honoured in eval mode. Training always reduces.

        Returns:
            ``{"diff_loss", "local_duration_loss", "global_duration_loss"}``.
        """
        device = self.dummy_param.device
        # Training always reduces; eval mode may request per-sample losses.
        loss_reduce = bool(self.training or loss_reduce)

        self.autoencoder.eval()
        with torch.no_grad():
            latent, latent_mask = self.autoencoder.encode(
                waveform.unsqueeze(1), waveform_lengths
            )

        (
            content, content_mask, global_duration_pred, local_duration_pred,
            length_aligned_content
        ) = self.encode_content_with_instruction(
            content, task, device, instruction, instruction_lengths
        )

        # keep only the content positions that time-aligned samples can use
        if is_time_aligned.sum() > 0:
            trunc_ta_length = content_mask[is_time_aligned].sum(1).max()
        else:
            trunc_ta_length = content.size(1)

        # duration loss
        local_duration_pred = local_duration_pred[:, :trunc_ta_length]
        ta_content_mask = content_mask[:, :trunc_ta_length]
        local_duration_loss = self.get_local_duration_loss(
            duration,
            local_duration_pred,
            ta_content_mask,
            is_time_aligned,
            reduce=loss_reduce
        )
        global_duration_loss = self.get_global_duration_loss(
            global_duration_pred, latent_mask, reduce=loss_reduce
        )

        # --------------------------------------------------------------------
        # prepare latent and diffusion-related noise
        # --------------------------------------------------------------------
        noisy_latent, target, timesteps = self.get_input_target_and_timesteps(
            latent, training=self.training
        )

        # --------------------------------------------------------------------
        # duration adapter
        # --------------------------------------------------------------------
        if is_time_aligned.sum() == 0 and \
            duration.size(1) < content_mask.size(1):
            # for non time-aligned tasks like TTA, `duration` is dummy one
            duration = F.pad(
                duration, (0, content_mask.size(1) - duration.size(1))
            )
        time_aligned_content, _ = self.expand_by_duration(
            x=content[:, :trunc_ta_length],
            content_mask=ta_content_mask,
            local_duration=duration,
        )

        # --------------------------------------------------------------------
        # prepare input to the backbone
        # --------------------------------------------------------------------
        # TODO compatibility for 2D spectrogram VAE
        latent_length = noisy_latent.size(self.autoencoder.time_dim)
        context, context_mask, time_aligned_content = self.get_backbone_input(
            latent_length, content, content_mask, time_aligned_content,
            length_aligned_content, is_time_aligned
        )

        # --------------------------------------------------------------------
        # classifier free guidance
        # --------------------------------------------------------------------
        if self.training and self.classifier_free_guidance:
            mask_indices = [
                k for k in range(len(waveform))
                if random.random() < self.cfg_drop_ratio
            ]
            if len(mask_indices) > 0:
                context[mask_indices] = 0
                time_aligned_content[mask_indices] = 0

        pred: torch.Tensor = self.backbone(
            x=noisy_latent,
            x_mask=latent_mask,
            timesteps=timesteps,
            context=context,
            context_mask=context_mask,
            time_aligned_context=time_aligned_content,
        )
        pred = pred.transpose(1, self.autoencoder.time_dim)
        target = target.transpose(1, self.autoencoder.time_dim)
        diff_loss = self.loss_with_snr(
            pred, target, timesteps, latent_mask, reduce=loss_reduce
        )
        return {
            "diff_loss": diff_loss,
            "local_duration_loss": local_duration_loss,
            "global_duration_loss": global_duration_loss
        }

    @torch.no_grad()
    def inference(
        self,
        content: list[Any],
        condition: list[Any],
        task: list[str],
        is_time_aligned: list[bool],
        instruction: torch.Tensor,
        instruction_lengths: Sequence[int],
        scheduler: SchedulerMixin,
        num_steps: int = 20,
        guidance_scale: float = 3.0,
        guidance_rescale: float = 0.0,
        disable_progress: bool = True,
        use_gt_duration: bool = False,
        **kwargs
    ):
        """
        Generate waveforms using predicted (or, if ``use_gt_duration`` and
        ``kwargs["duration"]`` are given, ground-truth) local durations.

        Returns:
            The output of ``self.autoencoder.decode`` on the final latent.
        """
        device = self.dummy_param.device
        classifier_free_guidance = guidance_scale > 1.0

        (
            content, content_mask, global_duration_pred, local_duration_pred,
            length_aligned_content
        ) = self.encode_content_with_instruction(
            content, task, device, instruction, instruction_lengths
        )

        batch_size = content.size(0)

        # truncate dummy time aligned duration prediction
        is_time_aligned = torch.as_tensor(is_time_aligned)
        if is_time_aligned.sum() > 0:
            trunc_ta_length = content_mask[is_time_aligned].sum(1).max()
        else:
            trunc_ta_length = content.size(1)

        # prepare local duration
        local_duration = self.prepare_local_duration(
            local_duration_pred, content_mask
        )
        local_duration = local_duration[:, :trunc_ta_length]
        # use ground truth duration
        if use_gt_duration and "duration" in kwargs:
            local_duration = torch.as_tensor(kwargs["duration"]).to(device)

        # prepare global duration
        global_duration = self.prepare_global_duration(
            global_duration_pred, local_duration, is_time_aligned
        )

        # --------------------------------------------------------------------
        # duration adapter
        # --------------------------------------------------------------------
        time_aligned_content, latent_mask = self.expand_by_duration(
            x=content[:, :trunc_ta_length],
            content_mask=content_mask[:, :trunc_ta_length],
            local_duration=local_duration,
            global_duration=global_duration,
        )

        context, context_mask, time_aligned_content = self.get_backbone_input(
            target_length=time_aligned_content.size(1),
            content=content,
            content_mask=content_mask,
            time_aligned_content=time_aligned_content,
            length_aligned_content=length_aligned_content,
            is_time_aligned=is_time_aligned
        )

        # --------------------------------------------------------------------
        # prepare unconditional input ([uncond, cond] along the batch)
        # --------------------------------------------------------------------
        if classifier_free_guidance:
            uncond_time_aligned_content = torch.zeros_like(
                time_aligned_content
            )
            uncond_context = torch.zeros_like(context)
            uncond_context_mask = context_mask.detach().clone()
            time_aligned_content = torch.cat([
                uncond_time_aligned_content, time_aligned_content
            ])
            context = torch.cat([uncond_context, context])
            context_mask = torch.cat([uncond_context_mask, context_mask])
            latent_mask = torch.cat([
                latent_mask, latent_mask.detach().clone()
            ])

        # --------------------------------------------------------------------
        # prepare input to the backbone
        # --------------------------------------------------------------------
        latent_length = latent_mask.sum(1).max().item()
        latent_shape = tuple(
            latent_length if dim is None else dim
            for dim in self.autoencoder.latent_shape
        )

        # Set the inference schedule before drawing the initial noise: for
        # some schedulers `init_noise_sigma` depends on it.
        scheduler.set_timesteps(num_steps, device=device)
        latent = self.prepare_latent(
            batch_size, scheduler, latent_shape, content.dtype, device
        )
        latent = self.iterative_denoise(
            latent=latent,
            scheduler=scheduler,
            verbose=not disable_progress,
            cfg=classifier_free_guidance,
            cfg_scale=guidance_scale,
            cfg_rescale=guidance_rescale,
            backbone_input={
                "x_mask": latent_mask,
                "context": context,
                "context_mask": context_mask,
                "time_aligned_context": time_aligned_content,
            }
        )
        # TODO variable length decoding, using `latent_mask`
        waveform = self.autoencoder.decode(latent)
        return waveform


class DoubleContentAudioDiffusion(DummyContentAudioDiffusion):
    """
    Variant that keeps the real content as cross-attention context for all
    samples. For non-time-aligned samples, it also writes the content into
    the leading frames of the time-aligned stream.
    """

    def get_backbone_input(
        self,
        target_length: int,
        content: torch.Tensor,
        content_mask: torch.Tensor,
        time_aligned_content: torch.Tensor,
        length_aligned_content: torch.Tensor,
        is_time_aligned: torch.Tensor,
    ):
        """
        Returns:
            ``(content, copy of content_mask, time-aligned stream)``. The
            time-aligned stream is trimmed or padded to ``target_length``.
            Its non-time-aligned rows start with the raw content, and the
            length-aligned content is added to the whole stream.
        """
        aligned = trim_or_pad_length(time_aligned_content, target_length, 1)
        overlap = min(content.size(1), aligned.size(1))
        non_aligned_rows = ~is_time_aligned
        # rows whose task is not time aligned receive the raw content up front
        aligned[non_aligned_rows, :overlap] = content[
            non_aligned_rows, :overlap]
        extra = trim_or_pad_length(length_aligned_content, target_length, 1)
        return content, content_mask.detach().clone(), aligned + extra


class HybridContentAudioDiffusion(DummyContentAudioDiffusion):
    """
    Variant that keeps the real content (and its full mask) as cross-attention
    context for all samples. Only the time-aligned stream of
    non-time-aligned samples is replaced by ``dummy_ta_embed``.
    """

    def get_backbone_input(
        self,
        target_length: int,
        content: torch.Tensor,
        content_mask: torch.Tensor,
        time_aligned_content: torch.Tensor,
        length_aligned_content: torch.Tensor,
        is_time_aligned: torch.Tensor,
    ):
        """
        Returns:
            ``(content, copy of content_mask, time-aligned stream)``. The
            stream is the sum of both aligned inputs, each trimmed or padded
            to ``target_length``, with non-time-aligned rows set to the dummy
            embedding.
        """
        # TODO compatibility for 2D spectrogram VAE
        # phoneme-like expanded stream + already-aligned (f0/energy) stream
        fused = trim_or_pad_length(
            time_aligned_content, target_length, 1
        ) + trim_or_pad_length(length_aligned_content, target_length, 1)
        fused[~is_time_aligned] = self.dummy_ta_embed.to(fused.dtype)
        return content, content_mask.detach().clone(), fused
