import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math
import torch.jit
from utils.diffusion_utils import edm_sampler, trigflow_sampler, edm_sampler_2
from typing import Optional


#----------------------------------------------------------------------------
# Magnitude-preserving SiLU (Equation 81).
@torch.jit.script
def mp_silu(x: torch.Tensor) -> torch.Tensor:
    """
    Magnitude-preserving SiLU (Equation 81).

    Applies ``torch.nn.functional.silu`` element-wise and divides the result
    by the constant 0.596.

    :param x: input tensor of any shape.
    :return: tensor with the same shape as ``x``.
    """
    activated = F.silu(x)
    return activated / 0.596

class MPSiLU(nn.Module):
    """Module form of the magnitude-preserving SiLU: ``silu(x) / 0.596``."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``F.silu(x)`` divided by 0.596, same shape as ``x``."""
        activated = F.silu(x)
        return activated / 0.596

class TimestepEmbedder(nn.Module):
    """
    Embeds scalar timesteps into vector representations.

    A 1-D batch of timesteps is expanded into ``frequency_embedding_size``
    sinusoidal features (cosines followed by sines, zero-padded by one column
    when the size is odd) and then passed through a two-layer SiLU MLP.

    :param hidden_size: width of the MLP hidden layer and of the output.
    :param frequency_embedding_size: number of sinusoidal features fed to the MLP.
    :param max_period: controls the lowest frequency of the sinusoids.
    """
    def __init__(self, hidden_size, frequency_embedding_size=256, max_period=10000):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True),
        )
        self.frequency_embedding_size = frequency_embedding_size
        half = frequency_embedding_size // 2
        # Geometric frequency ladder: exp(-log(max_period) * k / half), k = 0..half-1.
        freqs = torch.exp(-math.log(max_period) * torch.arange(half, dtype=torch.float32) / half)
        self.register_buffer('freqs', freqs)
        if frequency_embedding_size % 2:
            # Not read by forward(); kept registered so the state_dict keys of
            # existing checkpoints continue to match.
            zero_pad = torch.zeros(1, dtype=torch.float32)
            self.register_buffer('zero_pad', zero_pad)

    def forward(self, t):
        """
        :param t: 1-D tensor of timesteps (indexed as ``t[:, None]``).
        :return: tensor of shape ``[len(t), hidden_size]``.
        """
        args = t[:, None].float() * self.freqs[None]
        t_freq = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if self.frequency_embedding_size % 2:
            # Odd feature count: append one zero column to reach the full width.
            t_freq = torch.cat([t_freq, torch.zeros_like(t_freq[:, :1])], dim=-1)
        # Only cast back to the caller's dtype when it is a floating type.
        # Casting to an integer dtype would truncate the sinusoids and make the
        # Linear layers fail on a dtype mismatch.
        if t.is_floating_point():
            t_freq = t_freq.to(dtype=t.dtype)
        t_emb = self.mlp(t_freq)
        return t_emb



#----------------------------------------------------------------------------
# Magnitude-preserving sum (Equation 88).

@torch.jit.script
def mp_sum(a: torch.Tensor, b: torch.Tensor, t: float = 0.5) -> torch.Tensor:
    """
    Magnitude-preserving sum of two tensors (Equation 88).

    Blends ``a`` and ``b`` as ``(1 - t) * a + t * b`` (a lerp with weight ``t``)
    and rescales the blend by ``1 / sqrt((1 - t)^2 + t^2)``.

    :param a: first tensor (weight ``1 - t``).
    :param b: second tensor (weight ``t``).
    :param t: interpolation factor.
    :return: the rescaled blend.
    """
    complement = 1.0 - t
    scale = math.sqrt(complement ** 2 + t ** 2)
    blended = torch.lerp(a, b, t)
    return blended / scale

def modulate(x, shift, scale):
    """Affine modulation: return ``x * (1 + scale) + shift``."""
    gain = 1 + scale
    return x * gain + shift


class ResBlock(nn.Module):
    """
    Gated residual MLP block with adaLN-style conditioning.

    The input is LayerNorm-ed, modulated by a shift/scale derived from the
    conditioning vector, passed through a two-layer SiLU MLP, multiplied by a
    conditioning-derived gate and added back to the input. The channel count
    is the same on input and output.

    :param channels: the number of input (and output) channels.
    """

    def __init__(self, channels: int):
        super().__init__()
        self.channels = channels

        # Construction order is kept (norm, MLP, modulation) so that seeded
        # initialisation and state_dict key order are unchanged.
        self.in_ln = nn.LayerNorm(channels, eps=1e-6)
        mlp_layers = [
            nn.Linear(channels, channels, bias=True),
            nn.SiLU(),
            nn.Linear(channels, channels, bias=True),
        ]
        self.mlp = nn.Sequential(*mlp_layers)
        # Produces shift, scale and gate, each `channels` wide.
        to_modulation = nn.Linear(channels, 3 * channels, bias=True)
        self.adaLN_modulation = nn.Sequential(nn.SiLU(), to_modulation)

    def forward(self, x, y):
        """
        :param x: features to refine; last dim is ``channels``.
        :param y: conditioning vector; last dim is ``channels``.
        :return: ``x + gate * mlp(modulate(layer_norm(x), shift, scale))``.
        """
        shift, scale, gate = torch.chunk(self.adaLN_modulation(y), 3, dim=-1)
        update = self.mlp(modulate(self.in_ln(x), shift, scale))
        return x + gate * update


class ResBlockv2(nn.Module):
    """
    A residual block that can optionally change the number of channels.
    :param channels: the number of input channels.
    """

    def __init__(
        self,
        channels: int,
        head_dropout_p: float = 0.0,
        adaDN: bool = False
    ):
        super().__init__()
        self.channels = channels

        self.in_rms = nn.RMSNorm(channels, eps=1e-6)
        self.mlp = nn.Sequential(
            nn.Linear(channels, channels, bias=True),
            nn.SiLU(),
            nn.Dropout(head_dropout_p),
            nn.Linear(channels, channels, bias=True),
        )
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(channels, 3 * channels, bias=True)
        )
        self.pixel_norm = PixelNorm() if adaDN else None

    def forward(self, x, y):
        shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(y).chunk(3, dim=-1)
        if self.pixel_norm is not None:
            shift_mlp = self.pixel_norm(shift_mlp)
            scale_mlp = self.pixel_norm(scale_mlp)
        h = modulate(self.in_rms(x), shift_mlp, scale_mlp)
        h = self.mlp(h)
        return x + gate_mlp * h

class PixelNorm(nn.Module):
    """
    Scale each feature vector to (approximately) unit root-mean-square.

    The last dimension is divided by ``sqrt(mean(x^2) + eps)``; no learnable
    parameters are involved.
    """

    def __init__(self, eps: float = 1e-8):
        """:param eps: constant added to the mean square before the rsqrt."""
        super().__init__()
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x * rsqrt(mean(x^2, dim=-1) + eps)``, same shape as ``x``."""
        power = torch.mean(x.pow(2), dim=-1, keepdim=True)
        inv_rms = torch.rsqrt(power + self.eps)
        return x * inv_rms

class FinalLayer(nn.Module):
    """
    Final projection layer with adaLN-style modulation.

    - Applies a non-affine LayerNorm.
    - Computes shift & scale from the conditioning via a small MLP (SiLU + Linear).
    - Applies the modulation and a final linear map to ``out_channels``.
    """
    def __init__(self, model_channels: int, out_channels: int):
        super().__init__()
        # Construction order (norm, projection, modulation) is kept so seeded
        # initialisation and state_dict key order are unchanged.
        self.norm_final = nn.LayerNorm(model_channels, elementwise_affine=False, eps=1e-6)
        self.linear = nn.Linear(model_channels, out_channels, bias=True)
        to_shift_scale = nn.Linear(model_channels, 2 * model_channels, bias=True)
        self.adaLN_modulation = nn.Sequential(nn.SiLU(), to_shift_scale)

    def forward(self, x, c):
        """
        :param x: features; last dim is ``model_channels``.
        :param c: conditioning vector; last dim is ``model_channels``.
        :return: projected features; last dim is ``out_channels``.
        """
        shift, scale = torch.chunk(self.adaLN_modulation(c), 2, dim=-1)
        modulated = modulate(self.norm_final(x), shift, scale)
        return self.linear(modulated)
    
class FinalLayerv2(nn.Module):
    """
    Final projection layer with optional adaptive double nomalization (from TrigFlow).
    - Applies non-affine LayerNorm
    - Computes scale & shift via a small MLP (SiLU + Linear)
    - Optionally PixelNorms the modulation vectors for stability
    - Applies modulation and a final linear map
    """
    def __init__(self, model_channels: int, out_channels: int, adaDN: bool = False, head_dropout_p: float = 0.0):
        super().__init__()
        self.norm_final = nn.RMSNorm(model_channels, elementwise_affine=False, eps=1e-6)
        self.linear = nn.Linear(model_channels, out_channels, bias=True)
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(model_channels, 2 * model_channels, bias=True)
        )
        self.pixel_norm = PixelNorm() if adaDN else None
        self.dropout = nn.Dropout(head_dropout_p) 
    def forward(self, x, c):
        shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
        if self.pixel_norm is not None:
            shift = self.pixel_norm(shift)
            scale = self.pixel_norm(scale)
        x = modulate(self.norm_final(x), shift, scale)
        x = self.dropout(x)
        return self.linear(x)
#----------------------------------------------------------------------------
# Normalize given tensor to unit magnitude with respect to the given
# dimensions. Default = all dimensions except the first.

# def normalize(x, dim=None, eps=1e-4):
#     if dim is None:
#         dim = list(range(1, x.ndim))
#     norm = torch.linalg.vector_norm(x, dim=dim, keepdim=True, dtype=torch.float32)
#     norm = torch.add(eps, norm, alpha=np.sqrt(norm.numel() / x.numel()))
#     return x / norm.to(x.dtype)

def normalize(x: torch.Tensor, dim=None, eps=1e-4) -> torch.Tensor:
    """
    L2-normalize a tensor along the given dims without explicit dtype casts.

    :param x: tensor to normalize.
    :param dim: dimension(s) to reduce over; defaults to every dimension
        except the first.
    :param eps: constant added to the norm before dividing.
    :return: ``x / (||x||_2 + eps)``, same shape and dtype as ``x``.
    """
    if dim is None:
        dim = list(range(1, x.ndim))
    # vector_norm accepts any number of reduction dims. torch.linalg.norm only
    # accepts 1 or 2, which made the default `dim` fail for 4-D (conv2d)
    # weights. For 1 or 2 dims the two calls compute the same 2-norm.
    norm = torch.linalg.vector_norm(x, dim=dim, keepdim=True)
    return x / (norm + eps)

class MPFourier(nn.Module):
    """
    Random Fourier feature encoder.

    Holds fixed (non-trainable) random frequencies and phases as buffers and
    maps each input value ``v`` to ``sqrt(2) * cos(v * freq + phase)`` for every
    channel, appending a new trailing dimension of size ``num_channels``.

    :param num_channels: number of Fourier features per input value.
    :param bandwidth: multiplier applied to the random frequencies.
    """
    def __init__(self, num_channels: int, bandwidth: float = 1.0):
        super().__init__()
        two_pi = 2 * np.pi
        # Frequencies are drawn before phases so the RNG stream order is kept.
        random_freqs = two_pi * torch.randn(num_channels, dtype=torch.float32) * bandwidth
        random_phases = two_pi * torch.rand(num_channels, dtype=torch.float32)
        self.register_buffer('freqs', random_freqs)
        self.register_buffer('phases', random_phases)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        :param x: tensor of values to encode.
        :return: features computed in float32 and cast back to ``x.dtype``;
            the result gains a trailing dim of size ``num_channels``.
        """
        angles = x.to(torch.float32).unsqueeze(-1) * self.freqs.view(1, -1)
        angles = angles + self.phases.view(1, -1)
        features = torch.cos(angles) * np.sqrt(2)
        return features.to(x.dtype)




#----------------------------------------------------------------------------
# Magnitude-preserving convolution or fully-connected layer (Equation 47)
# with force weight normalization (Equation 66).
class MPConv(nn.Module):
    """
    Magnitude-preserving fully-connected / conv1d / conv2d layer (Equation 47)
    with forced weight normalization (Equation 66).

    The rank of the weight decides the operation: an empty ``kernel`` gives a
    2-D weight (matrix multiply), one kernel size gives conv1d, two give conv2d.

    :param in_channels: number of input channels.
    :param out_channels: number of output channels.
    :param kernel: sequence of kernel sizes (may be empty).
    """
    def __init__(self, in_channels: int, out_channels: int, kernel: tuple):
        super().__init__()
        self.out_channels = out_channels
        self.kernel = kernel
        initial = torch.randn(out_channels, in_channels, *kernel)
        self.weight = nn.Parameter(initial)

    def forward(self, x: torch.Tensor, gain: float = 1.0) -> torch.Tensor:
        """
        :param x: input; last dim is ``in_channels`` for the linear case,
            channels-first for the conv cases.
        :param gain: extra multiplier applied to the normalized weight.
        :return: result of the matmul / conv1d / conv2d with "same"-style
            padding of ``kernel_size // 2``.
        """
        raw = self.weight
        if self.training:
            # Forced weight normalization: overwrite the stored parameter
            # in place with its normalized value, outside of autograd.
            with torch.no_grad():
                self.weight.copy_(normalize(raw))
        # Traditional weight normalization (this one is differentiated through).
        unit = normalize(raw)
        # Magnitude-preserving scaling by gain / sqrt(fan-in).
        fan_in = unit[0].numel()
        scaled = (unit * (gain / np.sqrt(fan_in))).to(x.dtype)

        rank = scaled.ndim
        if rank == 2:
            return x @ scaled.t()
        if rank == 3:
            half_width = scaled.shape[-1] // 2
            return F.conv1d(x, scaled, padding=half_width)
        assert rank == 4
        return F.conv2d(x, scaled, padding=(scaled.shape[-1] // 2,))

class DiffLoss_v2(nn.Module):
    """
    Diffusion loss (EDM2 formulation) around a ``DiffusionHead_v2`` denoiser.

    ``forward`` draws a log-normal noise level per sample, corrupts the target,
    and returns the uncertainty-weighted denoising loss. ``sample`` and
    ``sample_2`` draw Gaussian noise and delegate to the EDM samplers.

    :param target_channels: channel count of the targets (also used as the
        head's input and output width).
    :param z_channels: channel count of the conditioning vector.
    :param num_res_blocks: number of residual blocks in the head.
    :param model_channels: hidden width of the head.
    :param seq_len_per_frame: forwarded to the head.
    :param P_mean: mean of ``log(sigma)`` used when sampling noise levels.
    :param P_std: standard deviation of ``log(sigma)``.
    :param sigma_data: data standard deviation used by the EDM weighting.
    :param label_drop_prob: forwarded to the head.
    :param initializer_range: forwarded to the head.
    :param label_balance: forwarded to the head.
    :param naive_mar_mlp: forwarded to the head.
    :param condition_merge: forwarded to the head.
    :param head_dropout_p: forwarded to the head.
    """
    def __init__(
        self,
        target_channels: int,
        z_channels: int,
        num_res_blocks: int,
        model_channels: int,
        seq_len_per_frame: int = 1,
        P_mean: float = -0.4,
        P_std: float = 1.0,
        sigma_data: float = 0.5,
        label_drop_prob: float = 0.0,
        initializer_range: float = 0.02,
        label_balance: float = 0.5,
        naive_mar_mlp: bool = False,
        condition_merge: bool = False,
        head_dropout_p: float = 0.0,
    ):
        super().__init__()
        self.P_mean = P_mean
        self.P_std = P_std
        self.sigma_data = sigma_data
        self.in_channels = target_channels

        self.diffusion_head = DiffusionHead_v2(
            in_channels=target_channels,
            model_channels=model_channels,
            out_channels=target_channels,
            z_channels=z_channels,
            seq_len_per_frame=seq_len_per_frame,
            num_res_blocks=num_res_blocks,
            sigma_data=sigma_data,
            label_drop_prob=label_drop_prob,
            initializer_range=initializer_range,
            label_balance=label_balance,
            naive_mar_mlp=naive_mar_mlp,
            condition_merge=condition_merge,
            head_dropout_p=head_dropout_p,
        )

    def forward(self, target: torch.Tensor, z: torch.Tensor, mask: Optional[torch.Tensor] = None):
        """
        Compute the EDM2 training loss.

        :param target: clean targets; the first dim is the batch.
        :param z: conditioning passed to the diffusion head.
        :param mask: optional weights for the loss. A mask with fewer dims than
            the loss (e.g. one value per sample) is given trailing singleton
            dims so it lines up with the batch axis.
        :return: ``(loss, l2_loss.mean(), logvar.mean())``.
        """
        # EDM2: one log-normally distributed noise level per sample.
        rnd_normal = torch.randn([target.shape[0], 1], device=target.device, dtype=target.dtype)
        cond = z

        sigma = (rnd_normal * self.P_std + self.P_mean).exp()
        weight = (sigma ** 2 + self.sigma_data ** 2) / (sigma * self.sigma_data) ** 2
        noise = torch.randn_like(target) * sigma

        ### EDM 2 loss: weighted squared error scaled by the learned
        ### log-variance, plus the log-variance itself.
        denoised, logvar = self.diffusion_head(target + noise, sigma, cond, return_logvar=True)
        l2_loss = (denoised - target) ** 2
        loss = weight * l2_loss / logvar.exp() + logvar
        if mask is not None:
            # Broadcasting aligns trailing dims, so a lower-rank mask would be
            # applied along the channel axis (or fail). Append singleton dims
            # so that it is applied along the leading axes instead.
            missing_dims = loss.ndim - mask.ndim
            if missing_dims > 0:
                mask = mask.reshape(*mask.shape, *([1] * missing_dims))
            # NOTE(review): with a mask that is singleton on the channel axis
            # this is a sum over channels divided by the number of kept
            # samples, not a per-element mean. Confirm this scale is intended.
            loss = (loss * mask).sum() / mask.sum()
        return loss.mean(), l2_loss.mean(), logvar.mean()

    def sample(
        self,
        z: torch.Tensor,
        temperature: float = 1.0,
        cfg: float = 1.0,
        device: torch.device = 'cuda',
        net_autoguidance: Optional[nn.Module] = None,
        num_steps: int = 20,
        sigma_min: float = 0.002,
        sigma_max: float = 80,
        rho: float = 7,
        S_churn: int = 0,
        S_min: int = 0,
        S_max: float = float('inf'),
        S_noise: int = 1,
    ) -> torch.Tensor:
        """
        Draw samples with ``edm_sampler``, conditioned on ``z``.

        :param z: conditioning; one sample is drawn per row.
        :param device: device the initial noise is moved to.
        :return: whatever ``edm_sampler`` returns.

        The remaining arguments are forwarded to ``edm_sampler`` unchanged
        (``cfg`` is passed as ``guidance``).
        """
        noise = torch.randn(z.shape[0], self.in_channels).to(device)

        sampled_token_latent = edm_sampler(
            net=self.diffusion_head,
            noise=noise,
            labels=z,
            num_steps=num_steps,
            sigma_min=sigma_min,
            sigma_max=sigma_max,
            rho=rho,
            temperature=temperature,
            guidance=cfg,
            S_churn=S_churn,
            S_min=S_min,
            S_max=S_max,
            S_noise=S_noise,
            dtype=noise.dtype,
            net_autoguidance=net_autoguidance,
        )
        return sampled_token_latent

    def sample_2(
        self,
        z_cond: torch.Tensor,
        z_uncond: torch.Tensor,
        temperature: float = 1.0,
        cfg: float = 1.0,
        device: torch.device = 'cuda',
        net_autoguidance: Optional[nn.Module] = None,
        num_steps: int = 20,
        sigma_min: float = 0.002,
        sigma_max: float = 80,
        rho: float = 7,
        S_churn: int = 0,
        S_min: int = 0,
        S_max: float = float('inf'),
        S_noise: int = 1,
    ) -> torch.Tensor:
        """
        Draw samples with ``edm_sampler_2`` using explicit conditional and
        unconditional conditioning tensors.

        :param z_cond: conditional conditioning; one sample is drawn per row.
        :param z_uncond: conditioning used for the unconditional branch.
        :param device: device the initial noise is moved to.
        :return: whatever ``edm_sampler_2`` returns.

        The remaining arguments are forwarded to ``edm_sampler_2`` unchanged
        (``cfg`` is passed as ``guidance``).
        """
        noise = torch.randn(z_cond.shape[0], self.in_channels).to(device)

        sampled_token_latent = edm_sampler_2(
            net=self.diffusion_head,
            noise=noise,
            labels=z_cond,
            labels_uncond=z_uncond,
            num_steps=num_steps,
            sigma_min=sigma_min,
            sigma_max=sigma_max,
            rho=rho,
            temperature=temperature,
            guidance=cfg,
            S_churn=S_churn,
            S_min=S_min,
            S_max=S_max,
            S_noise=S_noise,
            dtype=noise.dtype,
            net_autoguidance=net_autoguidance,
        )
        return sampled_token_latent


class DiffusionHead_v2(nn.Module):
    """
    The MLP denoiser for Diffusion Loss, wrapped in EDM preconditioning.

    :param in_channels: channels in the input Tensor.
    :param model_channels: base channel count for the model.
    :param out_channels: channels in the output Tensor.
    :param z_channels: channels in the condition.
    :param num_res_blocks: number of residual blocks.
    :param seq_len_per_frame: must be 1.
    :param logvar_channels: width of the timestep embedding feeding the
        log-variance head.
    :param sigma_data: data standard deviation used by the preconditioning.
    :param label_drop_prob: probability of replacing the condition with the
        learned ``fake_latent`` during training; ``fake_latent`` only exists
        when this is > 0.
    :param initializer_range: stored on the module; not read by this class.
    :param label_balance: lerp factor used when merging the time embedding
        with the condition embedding.
    :param naive_mar_mlp: accepted for interface compatibility; not read here.
    :param condition_merge: if True, embed the condition with a two-layer MLP
        instead of a single Linear.
    :param head_dropout_p: dropout probability used throughout the head.
    """

    def __init__(
        self,
        in_channels: int,
        model_channels: int,
        out_channels: int,
        z_channels: int,
        num_res_blocks: int,
        seq_len_per_frame: int = 1,
        logvar_channels: int = 128,
        sigma_data: float = 0.5,
        label_drop_prob: float = 0.0,
        initializer_range: float = 0.02,
        label_balance: float = 0.5,
        naive_mar_mlp: bool = False,
        condition_merge: bool = False,
        head_dropout_p: float = 0.0,
    ):
        super().__init__()
        assert seq_len_per_frame == 1, "Only seq_len_per_frame=1 is supported at this point."

        self.in_channels = in_channels
        self.model_channels = model_channels
        self.out_channels = out_channels
        self.num_res_blocks = num_res_blocks
        self.initializer_range = initializer_range

        # Fake class embedding for CFG's unconditional generation
        self.label_drop_prob = label_drop_prob
        if label_drop_prob > 0:
            self.fake_latent = nn.Parameter(torch.zeros(1, model_channels))

        self.input_proj = nn.Sequential(
            nn.Linear(in_channels, model_channels),
            nn.Dropout(head_dropout_p)
        )
        self.res_blocks = nn.ModuleList(
            [
                ResBlockv2(model_channels, head_dropout_p, adaDN=False)
                for _ in range(num_res_blocks)
            ]
        )
        self.final_layer = FinalLayerv2(
            model_channels,
            out_channels,
            adaDN=False,
            head_dropout_p=head_dropout_p
        )

        self.time_embed = TimestepEmbedder(model_channels)
        if not condition_merge:
            self.cond_embed = nn.Linear(z_channels, model_channels)
        else:
            self.cond_embed = nn.Sequential(
                nn.Linear(z_channels, model_channels),
                nn.SiLU(),
                nn.Linear(model_channels, model_channels),
            )
        self.sigma_data = sigma_data
        self.register_buffer('sigma_data_tensor', torch.tensor(sigma_data), persistent=False)
        self.logvar_pe = TimestepEmbedder(logvar_channels)
        self.logvar_linear = MPConv(logvar_channels, 1, kernel=[])
        self.label_balance = label_balance
        self.initialize_weights()

    def initialize_weights(self):
        """
        Re-initialise the head: Kaiming-normal weights and zero biases for the
        Linear layers, and zeroed adaLN modulation outputs for every residual
        block and the final layer.
        """
        def _init_linear(module):
            if isinstance(module, nn.Linear):
                torch.nn.init.kaiming_normal_(module.weight, mode='fan_in', nonlinearity='relu')
                if module.bias is not None:
                    nn.init.zeros_(module.bias)

        self.input_proj.apply(_init_linear)
        self.cond_embed.apply(_init_linear)
        self.time_embed.mlp.apply(_init_linear)
        self.logvar_pe.mlp.apply(_init_linear)

        for block in self.res_blocks:
            block.apply(_init_linear)
            # Zero modulation => shift = scale = gate = 0 at init, so every
            # block starts out as the identity.
            nn.init.zeros_(block.adaLN_modulation[-1].weight)
            nn.init.zeros_(block.adaLN_modulation[-1].bias)

        nn.init.zeros_(self.final_layer.adaLN_modulation[-1].weight)
        nn.init.zeros_(self.final_layer.adaLN_modulation[-1].bias)
        self.final_layer.linear.apply(_init_linear)

    def _edm_precond(self, sigma: torch.Tensor):
        """Return the EDM coefficients ``(c_skip, c_out, c_in, c_noise)`` for ``sigma``."""
        sigma = sigma.reshape(-1, 1)  # [N, 1]
        sigma_data = self.sigma_data
        denom = sigma ** 2 + sigma_data ** 2  # [N, 1]
        c_skip = sigma_data ** 2 / denom  # [N, 1]
        c_out = sigma * sigma_data / denom.sqrt()  # [N, 1]
        c_in = 1 / denom.sqrt()  # [N, 1]
        c_noise = sigma.flatten().log() / 4.0  # [N]
        return c_skip, c_out, c_in, c_noise

    def _run_backbone(self, x, y, c_skip, c_out, c_in):
        """Run the conditioned MLP on ``c_in * x`` and combine it with the skip path."""
        h = self.input_proj(c_in * x)  # [N, model_channels]
        for block in self.res_blocks:
            h = block(h, y)
        F_x = self.final_layer(h, y)
        return c_skip * x + c_out * F_x

    def _guided_denoise(self, x, sigma, cond_emb, uncond_emb, guidance):
        """Denoise with classifier-free guidance from already-embedded conditions."""
        # `sigma` is tiled once per sample, then the batch is doubled so the
        # conditional and unconditional branches run in one pass.
        sigma = sigma.repeat(cond_emb.shape[0], 1).reshape(-1, 1)
        x_cfg = torch.cat([x, x], dim=0)
        sigma_cfg = torch.cat([sigma, sigma], dim=0)
        c_skip, c_out, c_in, c_noise = self._edm_precond(sigma_cfg)

        t = self.time_embed(c_noise)  # [2*bsz, model_channels]
        both = torch.cat([cond_emb, uncond_emb], dim=0)  # [2*bsz, model_channels]
        y = mp_sum(t, both, t=self.label_balance)

        D_x = self._run_backbone(x_cfg, y, c_skip, c_out, c_in)
        D_x_cond, D_x_uncond = torch.chunk(D_x, 2, dim=0)
        return guidance * D_x_cond + (1 - guidance) * D_x_uncond

    def forward(self, x: torch.Tensor, sigma: torch.Tensor, cond: torch.Tensor, return_logvar: bool=False):
        """
        Apply the model to an input batch.
        :param x: [(bsz x seq), latent_dim] Tensor of noisy inputs.
        :param sigma: noise level per sample; reshaped to [(bsz x seq), 1].
        :param cond: conditioning from AR transformer. [(bsz x seq), z_channels]
        :param return_logvar: if True, also return the predicted log-variance.
        :return: [(bsz x seq), latent_dim] Tensor of denoised outputs, or a
            ``(denoised, logvar)`` tuple with ``logvar`` of shape [(bsz x seq), 1].
        """
        # EDM precond
        c_skip, c_out, c_in, c_noise = self._edm_precond(sigma)

        t = self.time_embed(c_noise)
        cond = self.cond_embed(cond)
        if self.training and self.label_drop_prob > 0:
            # Randomly swap the condition embedding for the learned
            # unconditional embedding, one decision per sample.
            drop_latent_mask = torch.rand(cond.shape[0], dtype=cond.dtype, device=cond.device) < self.label_drop_prob
            drop_latent_mask = drop_latent_mask.unsqueeze(-1)
            cond = drop_latent_mask * self.fake_latent.to(cond.dtype) + (~drop_latent_mask) * cond
        y = mp_sum(t, cond, t=self.label_balance)

        ### network part
        D_x = self._run_backbone(x, y, c_skip, c_out, c_in)

        # Estimate uncertainty if requested.
        if return_logvar:
            logvar = self.logvar_linear(self.logvar_pe(c_noise)).reshape(-1, 1)
            return D_x, logvar # u(sigma) in Equation 21

        return D_x

    def inference(self, x: torch.Tensor, sigma: torch.Tensor, cond: torch.Tensor):
        """
        Apply the model to an input batch, without label dropping or logvar.
        :param x: [(bsz x seq), latent_dim] Tensor of noisy inputs.
        :param sigma: noise level per sample; reshaped to [(bsz x seq), 1].
        :param cond: conditioning from AR transformer; viewed as [(bsz x seq), -1].
        :return: [(bsz x seq), latent_dim] Tensor of denoised outputs.
        """
        cond = cond.view(x.size(0), -1)
        # EDM precond
        c_skip, c_out, c_in, c_noise = self._edm_precond(sigma)

        # conditioning part
        t = self.time_embed(c_noise)
        cond = self.cond_embed(cond)
        y = mp_sum(t, cond, t=self.label_balance)

        ### network part
        return self._run_backbone(x, y, c_skip, c_out, c_in)

    def inference_uncond(self, x: torch.Tensor, sigma: torch.Tensor, cond: torch.Tensor, guidance: float):
        """
        Denoise with classifier-free guidance, using the learned ``fake_latent``
        as the unconditional embedding.
        :param x: [bsz, latent_dim] Tensor of noisy inputs.
        :param sigma: noise level shared by the batch; it is repeated ``bsz``
            times, so a scalar tensor is expected (TODO confirm with sampler).
        :param cond: conditioning from AR transformer; viewed as [bsz, -1].
        :param guidance: guidance weight ``g`` in ``g * cond + (1 - g) * uncond``.
        :return: [bsz, latent_dim] Tensor of guided denoised outputs.
        :raises AttributeError: if the head was built with ``label_drop_prob <= 0``
            and therefore has no ``fake_latent``.
        """
        # Fail before doing any work, with a message that names the cause.
        if getattr(self, "fake_latent", None) is None:
            raise AttributeError(
                "DiffusionHead_v2.inference_uncond needs the learned unconditional "
                "embedding 'fake_latent', which is only created when the head is "
                "built with label_drop_prob > 0."
            )

        cond = cond.view(x.size(0), -1)
        cond = self.cond_embed(cond)  # [bsz, model_channels]
        uncond = self.fake_latent.to(cond.dtype).repeat(cond.shape[0], 1)  # [bsz, model_channels]
        return self._guided_denoise(x, sigma, cond, uncond, guidance)

    def inference_uncond_2(self, x: torch.Tensor, sigma: torch.Tensor, cond: torch.Tensor, uncond: torch.Tensor, guidance: float):
        """
        Denoise with classifier-free guidance, embedding an explicit
        unconditional conditioning tensor with the same ``cond_embed``.
        :param x: [bsz, latent_dim] Tensor of noisy inputs.
        :param sigma: noise level shared by the batch; it is repeated ``bsz``
            times, so a scalar tensor is expected (TODO confirm with sampler).
        :param cond: conditioning from AR transformer; viewed as [bsz, -1].
        :param uncond: conditioning for the unconditional branch; viewed as [bsz, -1].
        :param guidance: guidance weight ``g`` in ``g * cond + (1 - g) * uncond``.
        :return: [bsz, latent_dim] Tensor of guided denoised outputs.
        """
        cond = cond.view(x.size(0), -1)
        uncond = uncond.view(x.size(0), -1)
        cond = self.cond_embed(cond)  # [bsz, model_channels]
        uncond = self.cond_embed(uncond)  # [bsz, model_channels]
        return self._guided_denoise(x, sigma, cond, uncond, guidance)