import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math
import torch.jit
from utils.diffusion_utils import edm_sampler, trigflow_sampler
from typing import Optional


#----------------------------------------------------------------------------
# Magnitude-preserving SiLU (Equation 81).
@torch.jit.script
def mp_silu(x: torch.Tensor) -> torch.Tensor:
    """
    Magnitude-preserving SiLU (Equation 81).

    Applies ``torch.nn.functional.silu`` to ``x`` and divides the result by
    the constant 0.596.

    :param x: input tensor of any shape.
    :return: tensor of the same shape as ``x``.
    """
    activated = F.silu(x)
    return activated / 0.596

class MPSiLU(nn.Module):
    """Module form of the magnitude-preserving SiLU: ``silu(x) / 0.596``."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``F.silu(x)`` divided by the constant 0.596."""
        activated = F.silu(x)
        return activated / 0.596

class TimestepEmbedder(nn.Module):
    """
    Embeds scalar timesteps into vector representations.

    Each timestep is first expanded into a fixed sinusoidal feature vector of
    size ``frequency_embedding_size`` and then mapped to ``hidden_size``
    features by a two-layer MLP (Linear -> SiLU -> Linear).
    """
    def __init__(self, hidden_size: int, frequency_embedding_size: int = 256):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True),
        )
        self.frequency_embedding_size = frequency_embedding_size

    @staticmethod
    def timestep_embedding(t: torch.Tensor, dim: int, max_period: int = 10000) -> torch.Tensor:
        """
        Create sinusoidal timestep embeddings.
        :param t: a 1-D Tensor of N indices, one per batch element. These may be fractional.
        :param dim: the dimension of the output.
        :param max_period: controls the minimum frequency of the embeddings.
        :return: an (N, D) float32 Tensor of positional embeddings.
        """
        # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
        half = dim // 2
        freqs = torch.exp(
            -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
        ).to(device=t.device)
        args = t[:, None].float() * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            # Odd output size: pad with one zero column so the width equals ``dim``.
            embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
        return embedding

    def forward(self, t: torch.Tensor) -> torch.Tensor:
        """
        Embed a batch of timesteps.
        :param t: a 1-D Tensor of N timesteps (floating point or integer).
        :return: an (N, hidden_size) Tensor.
        """
        t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
        if t.is_floating_point():
            # Follow the caller's precision (e.g. half-precision timesteps).
            feature_dtype = t.dtype
        else:
            # Integer timesteps: casting the sin/cos features to an integer
            # dtype would truncate them and break the Linear layer, so use
            # the dtype of the MLP weights instead.
            feature_dtype = self.mlp[0].weight.dtype
        return self.mlp(t_freq.to(dtype=feature_dtype))



#----------------------------------------------------------------------------
# Magnitude-preserving sum (Equation 88).

@torch.jit.script
def mp_sum(a: torch.Tensor, b: torch.Tensor, t: float = 0.5) -> torch.Tensor:
    """
    Magnitude-preserving sum of two tensors (Equation 88).

    Blends ``a`` and ``b`` with ``a.lerp(b, t)`` (weight ``1 - t`` on ``a``,
    ``t`` on ``b``) and rescales the blend by ``1 / sqrt((1 - t)^2 + t^2)``.

    :param a: first tensor.
    :param b: second tensor, broadcastable with ``a``.
    :param t: interpolation factor; 0 returns ``a``, 1 returns ``b``.
    :return: the rescaled blend.
    """
    keep = 1 - t
    scale = math.sqrt(keep ** 2 + t ** 2)
    blended = a.lerp(b, t)
    return blended / scale

class ECTLoss(nn.Module):
    """
    ECT Loss for MAR.

    Wraps a ``DiffusionHead_v2`` and trains it by comparing its output at a
    noise level ``sigma`` (with gradients) against its output at a smaller
    noise level ``r = t_to_r(sigma)`` (without gradients), using the same
    noise direction for both inputs. The mapping from ``sigma`` to ``r``
    depends on the current ``stage`` (see ``update_schedule``).

    :param target_channels: channel count of the target / generated vectors.
    :param z_channels: channel count of the conditioning vectors.
    :param num_res_blocks: number of residual blocks in the head.
    :param model_channels: hidden width of the head.
    :param seq_len_per_frame: forwarded to the head.
    :param P_mean: mean of ``log(sigma)`` for the training noise distribution.
    :param P_std: standard deviation of ``log(sigma)``.
    :param sigma_data: data scale used in the loss weight and the head.
    :param label_drop_prob, initializer_range, label_balance, naive_mar_mlp,
        condition_merge, head_dropout_p: forwarded to the head unchanged.
    :param adj_map_func: ``'const'`` or ``'sigmoid'``; selects ``t_to_r``.
    :param ect_q: base of the per-stage decay ``1 / ect_q ** ceil(stage)``.
    :param ect_c: constant of the square-root loss; ``<= 0`` disables it.
    :param ect_k: amplitude of the sigmoid adjustment.
    :param ect_b: slope of the sigmoid adjustment.
    :raises ValueError: if ``adj_map_func`` is not a known schedule.
    """
    def __init__(
        self,
        target_channels: int,
        z_channels: int,
        num_res_blocks: int,
        model_channels: int,
        seq_len_per_frame: int = 1,
        P_mean: float = -0.4,
        P_std: float = 1.0,
        sigma_data: float = 0.5,
        label_drop_prob: float = 0.0,
        initializer_range: float = 0.02,
        label_balance: float = 0.5,
        naive_mar_mlp: bool = False,
        condition_merge: bool = False,
        head_dropout_p: float = 0.1,
        adj_map_func: str = 'sigmoid',
        ect_q: float = 4.0,
        ect_c: float = 0.06,
        ect_k: float = 8.0,
        ect_b: float = 1.0,
    ):
        super().__init__()
        self.P_mean = P_mean
        self.P_std = P_std
        self.sigma_data = sigma_data
        self.in_channels = target_channels

        if adj_map_func == 'const':
            self.t_to_r = self.t_to_r_const
        elif adj_map_func == 'sigmoid':
            self.t_to_r = self.t_to_r_sigmoid
        else:
            raise ValueError(f'Unknow schedule type {adj_map_func}!')

        self.ect_q = ect_q
        self.stage: float = 0.0  # updated externally through update_schedule()
        self.ratio = 0.

        self.ect_k = ect_k
        self.ect_b = ect_b

        self.ect_c = ect_c

        self.diffusion_head = DiffusionHead_v2(
            in_channels=target_channels,
            model_channels=model_channels,
            out_channels=target_channels,
            z_channels=z_channels,
            seq_len_per_frame=seq_len_per_frame,
            num_res_blocks=num_res_blocks,
            sigma_data=sigma_data,
            label_drop_prob=label_drop_prob,
            initializer_range=initializer_range,
            label_balance=label_balance,
            naive_mar_mlp=naive_mar_mlp,
            condition_merge=condition_merge,
            head_dropout_p=head_dropout_p,
        )

    def t_to_r_const(self, t: torch.Tensor) -> torch.Tensor:
        """
        Map noise levels ``t`` to ``r = t * (1 - 1 / ect_q ** ceil(stage))``.

        At ``stage == 0`` the ratio is 0, so ``r`` is 0 everywhere.

        :param t: tensor of noise levels.
        :return: tensor of the same shape, clamped to be non-negative.
        """
        stage_tensor = torch.tensor(self.stage, device=t.device)
        decay = 1 / self.ect_q ** torch.ceil(stage_tensor)
        ratio = 1 - decay
        r = t * ratio
        return torch.clamp(r, min=0)

    def t_to_r_sigmoid(self, t: torch.Tensor) -> torch.Tensor:
        """
        Map noise levels ``t`` to ``r`` with a noise-dependent adjustment.

        The per-stage decay ``1 / ect_q ** ceil(stage)`` is multiplied by
        ``1 + ect_k * sigmoid(-ect_b * t)`` before being subtracted from 1.
        The resulting ratio can be negative, hence the final clamp.

        :param t: tensor of noise levels.
        :return: tensor of the same shape, clamped to be non-negative.
        """
        adj = 1 + self.ect_k * torch.sigmoid(-self.ect_b * t)
        stage_tensor = torch.tensor(self.stage, device=t.device)
        decay = 1 / (self.ect_q ** torch.ceil(stage_tensor))
        ratio = 1 - decay * adj
        r = t * ratio
        return torch.clamp(r, min=0)

    def update_schedule(self, stage):
        """Set the training stage used by the ``t_to_r`` mappings."""
        self.stage = stage

    @staticmethod
    def _snapshot_rng(device: torch.device) -> torch.Tensor:
        """Return the state of the random generator that serves ``device``."""
        if device.type == 'cuda':
            return torch.cuda.get_rng_state(device)
        # Non-CUDA tensors: use the default (CPU) generator.
        return torch.get_rng_state()

    @staticmethod
    def _restore_rng(state: torch.Tensor, device: torch.device) -> None:
        """Restore a generator state captured by ``_snapshot_rng``."""
        if device.type == 'cuda':
            torch.cuda.set_rng_state(state, device)
        else:
            torch.set_rng_state(state)

    def forward(self, target: torch.Tensor, z: torch.Tensor):
        """
        Compute the training loss for a batch.

        :param target: clean targets; the first dimension is the batch.
        :param z: conditioning passed to the diffusion head.
        :return: ``(weighted_loss, unweighted_loss)``, both scalar tensors.
            The first is the per-sample loss times the sigma-dependent
            weight, averaged; the second is the plain per-sample mean.
        """
        # Log-normal noise levels: sigma = exp(P_mean + P_std * N(0, 1)).
        rnd_normal = torch.randn([target.shape[0], 1], device=target.device, dtype=target.dtype)
        cond = z

        sigma = (rnd_normal * self.P_std + self.P_mean).exp()  # t
        r = self.t_to_r(sigma)

        # The same noise direction is used at both noise levels.
        noise = torch.randn_like(target)
        noise_t = noise * sigma
        noise_r = noise * r

        weight = (sigma ** 2 + self.sigma_data ** 2) / (sigma * self.sigma_data) ** 2

        # Snapshot the generator of the device the head runs on, so the second
        # head evaluation below draws the same random numbers (e.g. dropout
        # masks) as the first. Using the tensor's own device keeps this
        # working on CPU and when the target is not on the current CUDA device.
        rng_state = self._snapshot_rng(target.device)
        Dx_t = self.diffusion_head(target + noise_t, sigma, cond, return_logvar=False)

        if r.max() > 0:
            self._restore_rng(rng_state, target.device)
            with torch.no_grad():
                Dx_r = self.diffusion_head(target + noise_r, r, cond, return_logvar=False)

            # Rows with r == 0 have no noise; the head output there may be
            # non-finite, so it is sanitised and replaced by the clean target.
            mask = r > 0
            Dx_r = torch.nan_to_num(Dx_r)
            Dx_r = mask * Dx_r + (~mask) * target
        else:
            Dx_r = target

        loss = (Dx_t - Dx_r) ** 2
        loss = torch.sum(loss.reshape(loss.shape[0], -1), dim=-1)

        # Producing Adaptive Weighting (p=0.5) through Huber Loss
        if self.ect_c > 0:
            loss = torch.sqrt(loss + self.ect_c ** 2) - self.ect_c
        else:
            loss = torch.sqrt(loss)

        # Weighting fn
        return (loss * weight.flatten()).mean(), loss.mean()

    def sample(
        self,
        z: torch.Tensor,
        temperature: float = 1.0,
        cfg: float = 1.0,
        device: torch.device = 'cuda',
        net_autoguidance: Optional[nn.Module] = None,
        mid_t: Optional[torch.Tensor] = None,
        num_steps: int = 20,
        sigma_min: float = 0.002,
        sigma_max: float = 80,
        rho: float = 7,
        S_churn: int = 0,
        S_min: int = 0,
        S_max: float = float('inf'),
        S_noise: int = 1,
    ) -> torch.Tensor:
        """
        Generate samples with one head evaluation per noise level.

        Starts from Gaussian noise scaled by ``sigma_max``, then for every
        level in ``[sigma_max, *mid_t]`` denoises with the head and, unless it
        was the last level, re-noises to the next level.

        :param z: conditioning; ``z.shape[0]`` is the number of samples and
            ``z.dtype`` the working dtype.
        :param device: device on which sampling runs.
        :param mid_t: optional intermediate noise levels visited after
            ``sigma_max`` (e.g. ``[2.5]`` for two steps or
            ``[5.0, 1.1, 0.08]`` for four steps); ``None`` gives one step.
        :param sigma_max: initial noise level.
        :return: tensor of shape ``(z.shape[0], in_channels)``.

        ``temperature``, ``cfg``, ``net_autoguidance``, ``num_steps``,
        ``sigma_min``, ``rho`` and the ``S_*`` arguments are accepted but
        currently unused.
        """
        dtype = z.dtype
        noise = torch.randn(z.shape[0], self.in_channels, dtype=dtype).to(device)

        mid_t = [] if mid_t is None else mid_t

        # t_0 = sigma_max, ..., t_N = 0
        t_steps = torch.tensor([sigma_max] + list(mid_t), dtype=dtype, device=device)
        t_steps = torch.cat([t_steps, torch.zeros_like(t_steps[:1])])

        # Sampling steps
        x_next = noise * t_steps[0]
        for t_cur, t_next in zip(t_steps[:-1], t_steps[1:]):
            x_next = self.diffusion_head.inference(x_next, t_cur, z)
            if t_next > 0:
                x_next = x_next + t_next * torch.randn_like(x_next)
        return x_next
    



def modulate(x, shift, scale):
    """Scale ``x`` by ``1 + scale`` and then add ``shift``."""
    gain = 1 + scale
    return x * gain + shift


class ResBlockv2(nn.Module):
    """
    A residual block that can optionally change the number of channels.
    :param channels: the number of input channels.
    """

    def __init__(
        self,
        channels: int,
        head_dropout_p: float = 0.0,
        adaDN: bool = False
    ):
        super().__init__()
        self.channels = channels

        self.in_rms = nn.RMSNorm(channels, eps=1e-6)
        self.mlp = nn.Sequential(
            nn.Linear(channels, channels, bias=True),
            nn.SiLU(),
            nn.Dropout(head_dropout_p),
            nn.Linear(channels, channels, bias=True),
        )
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(channels, 3 * channels, bias=True)
        )
        self.pixel_norm = PixelNorm() if adaDN else None

    def forward(self, x, y):
        shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(y).chunk(3, dim=-1)
        if self.pixel_norm is not None:
            shift_mlp = self.pixel_norm(shift_mlp)
            scale_mlp = self.pixel_norm(scale_mlp)
        h = modulate(self.in_rms(x), shift_mlp, scale_mlp)
        h = self.mlp(h)
        return x + gate_mlp * h

class PixelNorm(nn.Module):
    """
    Rescale each feature vector to (approximately) unit root-mean-square.

    The normalisation runs over the last dimension: every vector is
    multiplied by ``1 / sqrt(mean(x ** 2) + eps)``.
    """

    def __init__(self, eps: float = 1e-8):
        """:param eps: value added to the mean square before the square root."""
        super().__init__()
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` normalised over its last dimension."""
        power = torch.mean(torch.pow(x, 2), dim=-1, keepdim=True)
        inv_rms = torch.rsqrt(power + self.eps)
        return x * inv_rms

    
class FinalLayerv2(nn.Module):
    """
    Final projection layer with optional adaptive double nomalization (from TrigFlow).
    - Applies non-affine LayerNorm
    - Computes scale & shift via a small MLP (SiLU + Linear)
    - Optionally PixelNorms the modulation vectors for stability
    - Applies modulation and a final linear map
    """
    def __init__(self, model_channels: int, out_channels: int, adaDN: bool = False, head_dropout_p: float = 0.0):
        super().__init__()
        self.norm_final = nn.RMSNorm(model_channels, elementwise_affine=False, eps=1e-6)
        self.linear = nn.Linear(model_channels, out_channels, bias=True)
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(model_channels, 2 * model_channels, bias=True)
        )
        self.pixel_norm = PixelNorm() if adaDN else None
        self.dropout = nn.Dropout(head_dropout_p) 
    def forward(self, x, c):
        shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
        if self.pixel_norm is not None:
            shift = self.pixel_norm(shift)
            scale = self.pixel_norm(scale)
        x = modulate(self.norm_final(x), shift, scale)
        x = self.dropout(x)
        return self.linear(x)
    
#----------------------------------------------------------------------------
# Normalize given tensor to unit magnitude with respect to the given
# dimensions. Default = all dimensions except the first.

def normalize(x, dim=None, eps=1e-4):
    if dim is None:
        dim = list(range(1, x.ndim))
    norm = torch.linalg.vector_norm(x, dim=dim, keepdim=True, dtype=torch.float32)
    norm = torch.add(eps, norm, alpha=np.sqrt(norm.numel() / x.numel()))
    return x / norm.to(x.dtype)

class MPFourier(nn.Module):
    """
    Random Fourier features ``sqrt(2) * cos(x * freqs + phases)``.

    ``freqs`` (normal samples times ``2 * pi * bandwidth``) and ``phases``
    (uniform samples times ``2 * pi``) are drawn once at construction and
    stored as buffers.

    :param num_channels: number of output features per input value.
    :param bandwidth: multiplier applied to the random frequencies.
    """
    def __init__(self, num_channels: int, bandwidth: float = 1.0):
        super().__init__()
        two_pi = 2 * np.pi
        self.register_buffer('freqs', two_pi * torch.randn(num_channels, dtype=torch.float32) * bandwidth)
        self.register_buffer('phases', two_pi * torch.rand(num_channels, dtype=torch.float32))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        :param x: input values; a trailing feature dimension is appended.
        :return: tensor of shape ``x.shape + (num_channels,)`` in ``x.dtype``;
            the computation itself runs in float32.
        """
        x32 = x.to(torch.float32)
        angles = x32.unsqueeze(-1) * self.freqs.view(1, -1)  # [..., num_channels]
        angles = angles + self.phases.view(1, -1)
        features = torch.cos(angles) * np.sqrt(2)
        return features.to(x.dtype)

#----------------------------------------------------------------------------
# Magnitude-preserving convolution or fully-connected layer (Equation 47)
# with force weight normalization (Equation 66).
class MPConv(nn.Module):
    """
    Magnitude-preserving fully-connected / convolution layer (Equation 47)
    with forced weight normalization (Equation 66).

    The weight has shape ``(out_channels, in_channels, *kernel)``: an empty
    ``kernel`` gives a linear layer, one entry a 1-D convolution and two
    entries a 2-D convolution.

    :param in_channels: number of input channels.
    :param out_channels: number of output channels.
    :param kernel: sequence of kernel sizes (possibly empty).
    """
    def __init__(self, in_channels: int, out_channels: int, kernel: tuple):
        super().__init__()
        self.out_channels = out_channels
        self.kernel = kernel
        weight_shape = (out_channels, in_channels, *kernel)
        self.weight = torch.nn.Parameter(torch.randn(*weight_shape, dtype=torch.float32))

    def forward(self, x: torch.Tensor, gain: float = 1.0) -> torch.Tensor:
        """
        :param x: input; cast target for the weights is ``x.dtype``.
        :param gain: extra multiplier applied to the normalized weights.
        :return: ``x`` transformed by the normalized, rescaled weights.
        """
        weight = self.weight
        if self.training:
            # Forced weight normalization: overwrite the stored parameter in
            # place, outside of autograd.
            with torch.no_grad():
                self.weight.copy_(normalize(weight))
        # Traditional weight normalization (differentiable).
        weight = normalize(weight)
        # Magnitude-preserving scaling by 1 / sqrt(fan-in).
        fan_in = weight[0].numel()
        weight = weight * (gain / np.sqrt(fan_in))
        weight = weight.to(x.dtype)

        rank = weight.ndim
        if rank == 2:
            return x @ weight.t()
        if rank == 3:
            half_width = weight.shape[-1] // 2
            return F.conv1d(x, weight, padding=half_width)
        assert rank == 4
        return F.conv2d(x, weight, padding=(weight.shape[-1] // 2,))

class DiffusionHead_v2(nn.Module):
    """
    The MLP for Diffusion Loss, wrapped in EDM preconditioning.

    The raw network ``F`` (input projection, residual blocks, final layer) is
    conditioned on the sum of a noise-level embedding and an embedding of the
    external condition. The denoised output is
    ``D(x) = c_skip * x + c_out * F(c_in * x)``.

    :param in_channels: channels in the input Tensor.
    :param model_channels: base channel count for the model.
    :param out_channels: channels in the output Tensor.
    :param z_channels: channels in the condition.
    :param num_res_blocks: number of residual blocks.
    :param seq_len_per_frame: must be 1.
    :param logvar_channels: width of the embedding used for the log-variance
        estimate.
    :param sigma_data: data scale used by the preconditioning.
    :param label_drop_prob: probability of replacing the embedded condition
        by the learned ``fake_latent`` during training; the parameter is only
        created when this is > 0.
    :param initializer_range: stored on the module; not used by
        ``initialize_weights``.
    :param label_balance: interpolation factor passed to ``mp_sum`` when
        merging the noise and condition embeddings.
    :param naive_mar_mlp: accepted but currently unused.
    :param condition_merge: if True the condition embedding is a two-layer
        MLP instead of a single Linear.
    :param head_dropout_p: dropout probability used throughout the head.
    """

    def __init__(
        self,
        in_channels: int,
        model_channels: int,
        out_channels: int,
        z_channels: int,
        num_res_blocks: int,
        seq_len_per_frame: int = 1,
        logvar_channels: int = 128,
        sigma_data: float = 0.5,
        label_drop_prob: float = 0.0,
        initializer_range: float = 0.02,
        label_balance: float = 0.5,
        naive_mar_mlp: bool = False,
        condition_merge: bool = False,
        head_dropout_p: float = 0.0,
    ):
        super().__init__()
        assert seq_len_per_frame == 1, "Only seq_len_per_frame=1 is supported at this point."

        self.in_channels = in_channels
        self.model_channels = model_channels
        self.out_channels = out_channels
        self.num_res_blocks = num_res_blocks
        self.initializer_range = initializer_range

        # Fake class embedding for CFG's unconditional generation
        self.label_drop_prob = label_drop_prob
        if label_drop_prob > 0:
            self.fake_latent = nn.Parameter(torch.zeros(1, model_channels))

        # NOTE: the construction order below is kept stable because it
        # determines how the random generator is consumed at initialisation.
        self.input_proj = nn.Sequential(
            nn.Linear(in_channels, model_channels),
            nn.Dropout(head_dropout_p)
        )
        self.res_blocks = nn.ModuleList([
            ResBlockv2(model_channels, head_dropout_p, adaDN=False)
            for _ in range(num_res_blocks)
        ])
        self.final_layer = FinalLayerv2(
            model_channels,
            out_channels,
            adaDN=False,
            head_dropout_p=head_dropout_p
        )

        self.time_embed = TimestepEmbedder(model_channels)
        if not condition_merge:
            self.cond_embed = nn.Linear(z_channels, model_channels)
        else:
            self.cond_embed = nn.Sequential(
                nn.Linear(z_channels, model_channels),
                nn.SiLU(),
                nn.Linear(model_channels, model_channels),
            )
        self.sigma_data = sigma_data
        self.register_buffer('sigma_data_tensor', torch.tensor(sigma_data), persistent=False)
        self.logvar_pe = TimestepEmbedder(logvar_channels)
        self.logvar_linear = MPConv(logvar_channels, 1, kernel=[])
        self.label_balance = label_balance
        self.initialize_weights()

    def initialize_weights(self):
        """
        Initialise all Linear layers.

        Linear weights get Kaiming-normal initialisation (fan-in, ReLU gain)
        with zero biases. The last Linear of every ``adaLN_modulation`` is
        then zeroed, so the predicted shift / scale / gate start at zero.
        """
        def _init_linear(module):
            if isinstance(module, nn.Linear):
                torch.nn.init.kaiming_normal_(module.weight, mode='fan_in', nonlinearity='relu')
                if module.bias is not None:
                    nn.init.zeros_(module.bias)

        self.input_proj.apply(_init_linear)
        self.cond_embed.apply(_init_linear)
        self.time_embed.mlp.apply(_init_linear)
        self.logvar_pe.mlp.apply(_init_linear)

        for block in self.res_blocks:
            block.apply(_init_linear)
            nn.init.zeros_(block.adaLN_modulation[-1].weight)
            nn.init.zeros_(block.adaLN_modulation[-1].bias)

        nn.init.zeros_(self.final_layer.adaLN_modulation[-1].weight)
        nn.init.zeros_(self.final_layer.adaLN_modulation[-1].bias)
        self.final_layer.linear.apply(_init_linear)

    def _edm_precond(self, sigma: torch.Tensor):
        """
        Compute the EDM preconditioning coefficients for ``sigma``.

        :param sigma: noise levels; reshaped to ``[-1, 1]``.
        :return: ``(c_skip, c_out, c_in, c_noise)``; the first three have
            shape ``[N, 1]`` and ``c_noise = log(sigma) / 4`` has shape ``[N]``.
        """
        sigma = sigma.reshape(-1, 1)
        sigma_data = self.sigma_data
        denom = sigma ** 2 + sigma_data ** 2
        c_skip = sigma_data ** 2 / denom
        c_out = sigma * sigma_data / denom.sqrt()
        c_in = 1 / denom.sqrt()
        c_noise = sigma.flatten().log() / 4.0
        return c_skip, c_out, c_in, c_noise

    def _run_head(self, x: torch.Tensor, x_in: torch.Tensor, y: torch.Tensor,
                  c_skip: torch.Tensor, c_out: torch.Tensor) -> torch.Tensor:
        """
        Run the raw network on the scaled input and apply the skip connection.

        :param x: noisy input.
        :param x_in: ``c_in * x``.
        :param y: merged conditioning vector.
        :return: ``c_skip * x + c_out * F(x_in, y)``.
        """
        h = self.input_proj(x_in)  # [N, model_channels]
        for block in self.res_blocks:
            h = block(h, y)
        F_x = self.final_layer(h, y)
        return c_skip * x + c_out * F_x

    def _cfg_denoise(self, x: torch.Tensor, sigma: torch.Tensor, cond_emb: torch.Tensor,
                     uncond_emb: torch.Tensor, guidance: float) -> torch.Tensor:
        """
        Classifier-free-guided denoising from already embedded conditions.

        The batch is duplicated, the first half is paired with ``cond_emb``
        and the second with ``uncond_emb``, and the two outputs are combined
        as ``guidance * cond + (1 - guidance) * uncond``.

        :param sigma: a single noise level, or one per batch element.
        """
        bsz = x.size(0)
        # A single sigma is broadcast to the batch; per-sample sigma passes through.
        sigma = sigma.reshape(-1, 1).expand(bsz, 1)
        x_cfg = torch.cat([x, x], dim=0)  # [2*bsz, latent_dim]
        sigma_cfg = torch.cat([sigma, sigma], dim=0)  # [2*bsz, 1]
        c_skip, c_out, c_in, c_noise = self._edm_precond(sigma_cfg)

        t = self.time_embed(c_noise)  # [2*bsz, model_channels]
        cond = torch.cat([cond_emb, uncond_emb], dim=0)  # [2*bsz, model_channels]
        y = mp_sum(t, cond, t=self.label_balance)

        D_x = self._run_head(x_cfg, c_in * x_cfg, y, c_skip, c_out)
        D_x_cond, D_x_uncond = torch.chunk(D_x, 2, dim=0)
        return guidance * D_x_cond + (1 - guidance) * D_x_uncond

    def forward(self, x: torch.Tensor, sigma: torch.Tensor, cond: torch.Tensor, return_logvar: bool=False):
        """
        Apply the model to an input batch.
        :param x: [(bsz x seq), latent_dim] Tensor of noisy inputs.
        :param sigma: noise levels, one per batch element.
        :param cond: conditioning from AR transformer. [(bsz x seq), latent_dim]
        :param return_logvar: also return the log-variance estimate.
        :return: [(bsz x seq), latent_dim] Tensor of outputs, or a tuple
            ``(outputs, logvar)`` with ``logvar`` of shape [(bsz x seq), 1].
        """
        # EDM precond
        c_skip, c_out, c_in, c_noise = self._edm_precond(sigma)

        t = self.time_embed(c_noise)
        cond = self.cond_embed(cond)
        if self.training and self.label_drop_prob > 0:
            # Randomly swap the embedded condition for the learned fake latent.
            drop_latent_mask = torch.rand(cond.shape[0], dtype=cond.dtype, device=cond.device) < self.label_drop_prob
            drop_latent_mask = drop_latent_mask.unsqueeze(-1)
            cond = drop_latent_mask * self.fake_latent.to(cond.dtype) + (~drop_latent_mask) * cond
        y = mp_sum(t, cond, t=self.label_balance)

        D_x = self._run_head(x, c_in * x, y, c_skip, c_out)

        # Estimate uncertainty if requested.
        if return_logvar:
            logvar = self.logvar_linear(self.logvar_pe(c_noise)).reshape(-1, 1)
            return D_x, logvar # u(sigma) in Equation 21

        return D_x

    def inference(self, x: torch.Tensor, sigma: torch.Tensor, cond: torch.Tensor):
        """
        Apply the model to an input batch without label dropping.
        :param x: [(bsz x seq), latent_dim] Tensor of noisy inputs.
        :param sigma: a single noise level or one per batch element.
        :param cond: conditioning from AR transformer; viewed as
            [(bsz x seq), -1].
        :return: [(bsz x seq), latent_dim] Tensor of outputs.
        """
        cond = cond.view(x.size(0), -1)
        # EDM precond
        c_skip, c_out, c_in, c_noise = self._edm_precond(sigma)

        # conditioning part
        t = self.time_embed(c_noise)
        cond = self.cond_embed(cond)
        y = mp_sum(t, cond, t=self.label_balance)

        return self._run_head(x, c_in * x, y, c_skip, c_out)

    def inference_uncond(self, x: torch.Tensor, sigma: torch.Tensor, cond: torch.Tensor, guidance: float):
        """
        Classifier-free-guided inference using the learned fake latent as the
        unconditional embedding.
        :param x: [(bsz x seq), latent_dim] Tensor of noisy inputs.
        :param sigma: a single noise level or one per batch element.
        :param cond: conditioning from AR transformer; viewed as
            [(bsz x seq), -1].
        :param guidance: weight of the conditional output; the unconditional
            output gets ``1 - guidance``.
        :return: [(bsz x seq), latent_dim] Tensor of outputs.
        :raises RuntimeError: if the head was built with
            ``label_drop_prob == 0`` and therefore has no fake latent.
        """
        fake_latent = getattr(self, 'fake_latent', None)
        if fake_latent is None:
            raise RuntimeError(
                "inference_uncond requires a head built with label_drop_prob > 0 "
                "(no fake_latent parameter is available)."
            )

        cond = cond.view(x.size(0), -1)
        cond_emb = self.cond_embed(cond)  # [bsz, model_channels]
        # Same substitution as training-time label dropping in forward().
        uncond_emb = fake_latent.to(cond_emb.dtype).repeat(cond_emb.shape[0], 1)  # [bsz, model_channels]
        return self._cfg_denoise(x, sigma, cond_emb, uncond_emb, guidance)

    def inference_uncond_2(self, x: torch.Tensor, sigma: torch.Tensor, cond: torch.Tensor, uncond: torch.Tensor, guidance: float):
        """
        Classifier-free-guided inference with an explicit unconditional input.
        :param x: [(bsz x seq), latent_dim] Tensor of noisy inputs.
        :param sigma: a single noise level or one per batch element.
        :param cond: conditioning from AR transformer; viewed as
            [(bsz x seq), -1].
        :param uncond: unconditional counterpart of ``cond``; same layout,
            embedded with the same ``cond_embed``.
        :param guidance: weight of the conditional output; the unconditional
            output gets ``1 - guidance``.
        :return: [(bsz x seq), latent_dim] Tensor of outputs.
        """
        cond = cond.view(x.size(0), -1)
        uncond = uncond.view(x.size(0), -1)
        cond_emb = self.cond_embed(cond)  # [bsz, model_channels]
        uncond_emb = self.cond_embed(uncond)  # [bsz, model_channels]
        return self._cfg_denoise(x, sigma, cond_emb, uncond_emb, guidance)
    