from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union, cast

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim
from torch import Tensor
from ..utils import GaussianFourierProjection

ModuleType = Union[str, Callable[..., nn.Module]]

class SiLU(nn.Module):
    """Sigmoid-weighted linear unit, applied elementwise: ``x * sigmoid(x)``."""

    def forward(self, x):
        gate = torch.sigmoid(x)
        return x * gate
    
    

    


def reglu(x: Tensor) -> Tensor:
    """Apply the ReGLU gate from [1] along the last axis.

    The last axis is split into equal halves ``a`` (first) and ``b`` (second)
    and the result is ``a * relu(b)``, so the last axis is halved.

    Raises:
        AssertionError: if the last dimension has odd size.

    References:
        [1] Noam Shazeer, "GLU Variants Improve Transformer", 2020
    """
    width = x.shape[-1]
    assert width % 2 == 0
    half = width // 2
    return x[..., :half] * F.relu(x[..., half:])


def geglu(x: Tensor) -> Tensor:
    """Apply the GEGLU gate from [1] along the last axis.

    The last axis is split into equal halves ``a`` (first) and ``b`` (second)
    and the result is ``a * gelu(b)``, so the last axis is halved.

    Raises:
        AssertionError: if the last dimension has odd size.

    References:
        [1] Noam Shazeer, "GLU Variants Improve Transformer", 2020
    """
    width = x.shape[-1]
    assert width % 2 == 0
    half = width // 2
    return x[..., :half] * F.gelu(x[..., half:])

class ReGLU(nn.Module):
    """Module form of :func:`reglu` (ReGLU from [shazeer2020glu]).

    Halves the size of the last dimension of its input.

    Examples:
        .. testcode::

            module = ReGLU()
            x = torch.randn(3, 4)
            assert module(x).shape == (3, 2)

    References:
        * [shazeer2020glu] Noam Shazeer, "GLU Variants Improve Transformer", 2020
    """

    def forward(self, x: Tensor) -> Tensor:
        gated = reglu(x)
        return gated


class GEGLU(nn.Module):
    """Module form of :func:`geglu` (GEGLU from [shazeer2020glu]).

    Halves the size of the last dimension of its input.

    Examples:
        .. testcode::

            module = GEGLU()
            x = torch.randn(3, 4)
            assert module(x).shape == (3, 2)

    References:
        * [shazeer2020glu] Noam Shazeer, "GLU Variants Improve Transformer", 2020
    """

    def forward(self, x: Tensor) -> Tensor:
        gated = geglu(x)
        return gated
    

class MLP2048(nn.Module):
    """Noise-conditioned MLP denoiser (default width 2048).

    ``x`` (last dim ``d_in``) is lifted to ``dim_t`` features, an embedding of
    ``noise_labels`` is added, and a SiLU MLP with hidden widths
    ``2*dim_t, 2*dim_t, dim_t`` maps the sum back to ``d_in``.

    Args:
        d_in: Feature size of ``x`` and of the returned tensor.
        dim_t: Hidden width; also the size of the noise embedding.
    """

    def __init__(self, d_in=6, dim_t=2048):
        super().__init__()
        self.dim_t = dim_t

        self.proj = nn.Linear(d_in, dim_t)

        # Linear layers between consecutive widths, SiLU between them,
        # no activation after the last one.
        widths = [dim_t, 2 * dim_t, 2 * dim_t, dim_t, d_in]
        layers = [nn.Linear(widths[0], widths[1])]
        for fan_in, fan_out in zip(widths[1:-1], widths[2:]):
            layers += [nn.SiLU(), nn.Linear(fan_in, fan_out)]
        self.mlp = nn.Sequential(*layers)

        self.map_noise = GaussianFourierProjection(dim_t)
        self.time_embed = nn.Sequential(
            nn.Linear(dim_t, dim_t), nn.SiLU(), nn.Linear(dim_t, dim_t)
        )

    def forward(self, x, noise_labels, class_labels=None):
        """Map ``x`` to a ``d_in``-wide output conditioned on ``noise_labels``.

        Args:
            x: Tensor whose last dimension is ``d_in``.
            noise_labels: Exactly one value per batch element; viewed as ``(B, 1)``.
            class_labels: Ignored.
        """
        batch = noise_labels.shape[0]
        feats = self.map_noise(noise_labels.view(batch, 1))
        # Exchange the two halves of each sample's features (sin/cos swap).
        feats = feats.reshape(feats.shape[0], 2, -1).flip(1).reshape(*feats.shape)
        t_emb = self.time_embed(feats)
        return self.mlp(self.proj(x) + t_emb)


class MLPDiffusion(nn.Module):
    """Noise-conditioned MLP denoiser (default width 256).

    ``x`` (last dim ``d_in``) is lifted to ``dim_t`` features, an embedding of
    ``noise_labels`` is added, and a SiLU MLP with hidden widths
    ``2*dim_t, 2*dim_t, dim_t`` maps the sum back to ``d_in``. The noise
    embedding uses a fixed 256-wide ``GaussianFourierProjection`` followed by
    a two-layer MLP into ``dim_t``.

    Args:
        d_in: Feature size of ``x`` and of the returned tensor.
        dim_t: Hidden width; also the size of the noise embedding.
    """

    def __init__(self, d_in=6, dim_t=256):
        super().__init__()
        self.dim_t = dim_t

        self.proj = nn.Linear(d_in, dim_t)

        # Linear layers between consecutive widths, SiLU between them,
        # no activation after the last one.
        widths = [dim_t, 2 * dim_t, 2 * dim_t, dim_t, d_in]
        layers = [nn.Linear(widths[0], widths[1])]
        for fan_in, fan_out in zip(widths[1:-1], widths[2:]):
            layers += [nn.SiLU(), nn.Linear(fan_in, fan_out)]
        self.mlp = nn.Sequential(*layers)

        self.map_noise = GaussianFourierProjection(256)
        self.time_embed = nn.Sequential(
            nn.Linear(256, dim_t), nn.SiLU(), nn.Linear(dim_t, dim_t)
        )

    def forward(self, x, noise_labels, class_labels=None):
        """Map ``x`` to a ``d_in``-wide output conditioned on ``noise_labels``.

        Args:
            x: Tensor whose last dimension is ``d_in``.
            noise_labels: Exactly one value per batch element; viewed as ``(B, 1)``.
            class_labels: Ignored.
        """
        batch = noise_labels.shape[0]
        feats = self.map_noise(noise_labels.view(batch, 1))
        # Exchange the two halves of each sample's features (sin/cos swap).
        feats = feats.reshape(feats.shape[0], 2, -1).flip(1).reshape(*feats.shape)
        t_emb = self.time_embed(feats)
        return self.mlp(self.proj(x) + t_emb)

class MLPDiffusionTabSyn(nn.Module):
    """Noise-conditioned MLP denoiser (default width 512).

    ``x`` (last dim ``d_in``) is lifted to ``dim_t`` features, an embedding of
    ``noise_labels`` is added, and a SiLU MLP with hidden widths
    ``2*dim_t, 2*dim_t, dim_t`` maps the sum back to ``d_in``. The noise
    embedding uses a fixed 256-wide ``GaussianFourierProjection`` followed by
    a two-layer MLP into ``dim_t``.

    Args:
        d_in: Feature size of ``x`` and of the returned tensor.
        dim_t: Hidden width; also the size of the noise embedding.
    """

    def __init__(self, d_in=6, dim_t=512):
        super().__init__()
        self.dim_t = dim_t

        self.proj = nn.Linear(d_in, dim_t)

        # Linear layers between consecutive widths, SiLU between them,
        # no activation after the last one.
        widths = [dim_t, 2 * dim_t, 2 * dim_t, dim_t, d_in]
        layers = [nn.Linear(widths[0], widths[1])]
        for fan_in, fan_out in zip(widths[1:-1], widths[2:]):
            layers += [nn.SiLU(), nn.Linear(fan_in, fan_out)]
        self.mlp = nn.Sequential(*layers)

        self.map_noise = GaussianFourierProjection(256)
        self.time_embed = nn.Sequential(
            nn.Linear(256, dim_t), nn.SiLU(), nn.Linear(dim_t, dim_t)
        )

    def forward(self, x, noise_labels, class_labels=None):
        """Map ``x`` to a ``d_in``-wide output conditioned on ``noise_labels``.

        Args:
            x: Tensor whose last dimension is ``d_in``.
            noise_labels: Exactly one value per batch element; viewed as ``(B, 1)``.
            class_labels: Ignored.
        """
        batch = noise_labels.shape[0]
        feats = self.map_noise(noise_labels.view(batch, 1))
        # Exchange the two halves of each sample's features (sin/cos swap).
        feats = feats.reshape(feats.shape[0], 2, -1).flip(1).reshape(*feats.shape)
        t_emb = self.time_embed(feats)
        return self.mlp(self.proj(x) + t_emb)

class MLPDiffusionTabSyn1024(nn.Module):
    """Noise-conditioned MLP denoiser (default width 1024).

    ``x`` (last dim ``d_in``) is lifted to ``dim_t`` features, an embedding of
    ``noise_labels`` is added, and a SiLU MLP with hidden widths
    ``2*dim_t, 2*dim_t, dim_t`` maps the sum back to ``d_in``.

    Args:
        d_in: Feature size of ``x`` and of the returned tensor.
        dim_t: Hidden width; also the size of the noise embedding.
    """

    def __init__(self, d_in=6, dim_t=1024):
        super().__init__()
        self.dim_t = dim_t

        self.proj = nn.Linear(d_in, dim_t)

        # Linear layers between consecutive widths, SiLU between them,
        # no activation after the last one.
        widths = [dim_t, 2 * dim_t, 2 * dim_t, dim_t, d_in]
        layers = [nn.Linear(widths[0], widths[1])]
        for fan_in, fan_out in zip(widths[1:-1], widths[2:]):
            layers += [nn.SiLU(), nn.Linear(fan_in, fan_out)]
        self.mlp = nn.Sequential(*layers)

        self.map_noise = GaussianFourierProjection(dim_t)
        self.time_embed = nn.Sequential(
            nn.Linear(dim_t, dim_t), nn.SiLU(), nn.Linear(dim_t, dim_t)
        )

    def forward(self, x, noise_labels, class_labels=None):
        """Map ``x`` to a ``d_in``-wide output conditioned on ``noise_labels``.

        Args:
            x: Tensor whose last dimension is ``d_in``.
            noise_labels: Exactly one value per batch element; viewed as ``(B, 1)``.
            class_labels: Ignored.
        """
        batch = noise_labels.shape[0]
        feats = self.map_noise(noise_labels.view(batch, 1))
        # Exchange the two halves of each sample's features (sin/cos swap).
        feats = feats.reshape(feats.shape[0], 2, -1).flip(1).reshape(*feats.shape)
        t_emb = self.time_embed(feats)
        return self.mlp(self.proj(x) + t_emb)
    

class MLPDiffusionTabSynBig1024(nn.Module):
    """Deeper noise-conditioned MLP denoiser (default width 1024).

    ``x`` (last dim ``d_in``) is lifted to ``dim_t`` features, an embedding of
    ``noise_labels`` is added, and a SiLU MLP with hidden widths
    ``2, 4, 4, 2, 1`` x ``dim_t`` maps the sum back to ``d_in``. The noise
    embedding uses a fixed 256-wide ``GaussianFourierProjection`` followed by
    a two-layer MLP into ``dim_t``.

    Args:
        d_in: Feature size of ``x`` and of the returned tensor.
        dim_t: Base hidden width; also the size of the noise embedding.
    """

    def __init__(self, d_in=6, dim_t=1024):
        super().__init__()
        self.dim_t = dim_t

        self.proj = nn.Linear(d_in, dim_t)

        # Linear layers between consecutive widths, SiLU between them,
        # no activation after the last one.
        widths = [dim_t, 2 * dim_t, 4 * dim_t, 4 * dim_t, 2 * dim_t, dim_t, d_in]
        layers = [nn.Linear(widths[0], widths[1])]
        for fan_in, fan_out in zip(widths[1:-1], widths[2:]):
            layers += [nn.SiLU(), nn.Linear(fan_in, fan_out)]
        self.mlp = nn.Sequential(*layers)

        self.map_noise = GaussianFourierProjection(256)
        self.time_embed = nn.Sequential(
            nn.Linear(256, dim_t), nn.SiLU(), nn.Linear(dim_t, dim_t)
        )

    def forward(self, x, noise_labels, class_labels=None):
        """Map ``x`` to a ``d_in``-wide output conditioned on ``noise_labels``.

        Args:
            x: Tensor whose last dimension is ``d_in``.
            noise_labels: Exactly one value per batch element; viewed as ``(B, 1)``.
            class_labels: Ignored.
        """
        batch = noise_labels.shape[0]
        feats = self.map_noise(noise_labels.view(batch, 1))
        # Exchange the two halves of each sample's features (sin/cos swap).
        feats = feats.reshape(feats.shape[0], 2, -1).flip(1).reshape(*feats.shape)
        t_emb = self.time_embed(feats)
        return self.mlp(self.proj(x) + t_emb)
    
class MLPDiffusionICL(nn.Module):
    """Noise-conditioned MLP denoiser (default width 200).

    ``x`` (last dim ``d_in``) is lifted to ``dim_t`` features, an embedding of
    ``noise_labels`` is added, and a SiLU MLP with hidden widths
    ``2*dim_t, 2*dim_t, dim_t`` maps the sum back to ``d_in``. The noise
    embedding uses a fixed 200-wide ``GaussianFourierProjection`` followed by
    a two-layer MLP into ``dim_t``.

    Args:
        d_in: Feature size of ``x`` and of the returned tensor.
        dim_t: Hidden width; also the size of the noise embedding.
    """

    def __init__(self, d_in = 6, dim_t = 200):
        super().__init__()
        self.dim_t = dim_t

        self.proj = nn.Linear(d_in, dim_t)

        self.mlp = nn.Sequential(
            nn.Linear(dim_t, dim_t * 2),
            nn.SiLU(),
            nn.Linear(dim_t * 2, dim_t * 2),
            nn.SiLU(),
            nn.Linear(dim_t * 2, dim_t),
            nn.SiLU(),
            nn.Linear(dim_t, d_in),
        )

        # The projection width is fixed, so time_embed must read fourier_dim
        # features (not dim_t). Reading dim_t broke every dim_t != 200.
        # Parameter shapes at the default dim_t=200 are unchanged.
        fourier_dim = 200
        self.map_noise = GaussianFourierProjection(fourier_dim)
        self.time_embed = nn.Sequential(
            nn.Linear(fourier_dim, dim_t),
            nn.SiLU(),
            nn.Linear(dim_t, dim_t)
        )

    def forward(self, x, noise_labels, class_labels=None):
        """Map ``x`` to a ``d_in``-wide output conditioned on ``noise_labels``.

        Args:
            x: Tensor whose last dimension is ``d_in``.
            noise_labels: Exactly one value per batch element; viewed as ``(B, 1)``.
            class_labels: Ignored.
        """
        B = noise_labels.shape[0]
        noise_labels = noise_labels.view(B, 1)
        emb = self.map_noise(noise_labels)
        emb = emb.reshape(emb.shape[0], 2, -1).flip(1).reshape(*emb.shape) # swap sin/cos
        emb = self.time_embed(emb)

        x = self.proj(x) + emb
        return self.mlp(x)

class MLPDiffusionVAE(nn.Module):
    """Noise-conditioned MLP denoiser with a narrow, ``d_in``-scaled bottleneck.

    ``x`` (last dim ``d_in``) is lifted to ``dim_t`` features and an embedding
    of ``noise_labels`` is added. The sum then passes through SiLU-separated
    linear layers of widths
    ``dim_t -> 2*d_in -> d_in -> 2*d_in -> 4*d_in -> d_in``.

    Args:
        d_in: Feature size of ``x`` and of the returned tensor.
        dim_t: Width of the input projection and of the noise embedding.
    """

    def __init__(self, d_in = 6, dim_t = 512):
        super().__init__()
        self.dim_t = dim_t

        self.proj = nn.Linear(d_in, dim_t)

        # The first layer consumes the dim_t-wide projection and the last
        # layer consumes the 4*d_in-wide hidden state. They used to be
        # Linear(4*d_in, ...) and Linear(dim_t, ...), which only ran when
        # dim_t == 4*d_in (and failed with the defaults). Parameter shapes
        # are unchanged in that previously-working case.
        self.mlp = nn.Sequential(
            nn.Linear(dim_t, 2 * d_in),
            nn.SiLU(),
            nn.Linear(2 * d_in, d_in),
            nn.SiLU(),
            nn.Linear(d_in, d_in * 2),
            nn.SiLU(),
            nn.Linear(2 * d_in, 4 * d_in),
            nn.SiLU(),
            nn.Linear(4 * d_in, d_in),
        )

        self.map_noise = GaussianFourierProjection(dim_t)
        self.time_embed = nn.Sequential(
            nn.Linear(dim_t, dim_t),
            nn.SiLU(),
            nn.Linear(dim_t, dim_t)
        )

    def forward(self, x, noise_labels, class_labels=None):
        """Map ``x`` to a ``d_in``-wide output conditioned on ``noise_labels``.

        Args:
            x: Tensor whose last dimension is ``d_in``.
            noise_labels: Exactly one value per batch element; viewed as ``(B, 1)``.
            class_labels: Ignored.
        """
        B = noise_labels.shape[0]
        noise_labels = noise_labels.view(B, 1)
        emb = self.map_noise(noise_labels)
        emb = emb.reshape(emb.shape[0], 2, -1).flip(1).reshape(*emb.shape) # swap sin/cos
        emb = self.time_embed(emb)

        x = self.proj(x) + emb
        return self.mlp(x)
    
class MLPDiffusionVAE1024(nn.Module):
    """Noise-conditioned MLP denoiser with a ``d_in``-scaled bottleneck (width 1024).

    ``x`` (last dim ``d_in``) is lifted to ``dim_t`` features and an embedding
    of ``noise_labels`` is added. The sum then passes through SiLU-separated
    linear layers of widths
    ``dim_t -> 4*d_in -> 2*d_in -> d_in -> 2*d_in -> 4*d_in -> d_in``.

    Args:
        d_in: Feature size of ``x`` and of the returned tensor.
        dim_t: Width of the input projection and of the noise embedding.
    """

    def __init__(self, d_in=6, dim_t=1024):
        super().__init__()
        self.dim_t = dim_t

        self.proj = nn.Linear(d_in, dim_t)

        # Linear layers between consecutive widths, SiLU between them,
        # no activation after the last one.
        widths = [dim_t, 4 * d_in, 2 * d_in, d_in, 2 * d_in, 4 * d_in, d_in]
        layers = [nn.Linear(widths[0], widths[1])]
        for fan_in, fan_out in zip(widths[1:-1], widths[2:]):
            layers += [nn.SiLU(), nn.Linear(fan_in, fan_out)]
        self.mlp = nn.Sequential(*layers)

        self.map_noise = GaussianFourierProjection(dim_t)
        self.time_embed = nn.Sequential(
            nn.Linear(dim_t, dim_t), nn.SiLU(), nn.Linear(dim_t, dim_t)
        )

    def forward(self, x, noise_labels, class_labels=None):
        """Map ``x`` to a ``d_in``-wide output conditioned on ``noise_labels``.

        Args:
            x: Tensor whose last dimension is ``d_in``.
            noise_labels: Exactly one value per batch element; viewed as ``(B, 1)``.
            class_labels: Ignored.
        """
        batch = noise_labels.shape[0]
        feats = self.map_noise(noise_labels.view(batch, 1))
        # Exchange the two halves of each sample's features (sin/cos swap).
        feats = feats.reshape(feats.shape[0], 2, -1).flip(1).reshape(*feats.shape)
        t_emb = self.time_embed(feats)
        return self.mlp(self.proj(x) + t_emb)
    
class MLPDiffusionPara(nn.Module):
    """Shallow noise-conditioned MLP denoiser (default width 2048).

    ``x`` (last dim ``d_in``) is lifted to ``dim_t`` features, an embedding of
    ``noise_labels`` is added, and a SiLU MLP with hidden widths
    ``2*dim_t, dim_t`` maps the sum back to ``d_in``.

    Args:
        d_in: Feature size of ``x`` and of the returned tensor.
        dim_t: Hidden width; also the size of the noise embedding.
    """

    def __init__(self, d_in=6, dim_t=2048):
        super().__init__()
        self.dim_t = dim_t

        self.proj = nn.Linear(d_in, dim_t)

        # Linear layers between consecutive widths, SiLU between them,
        # no activation after the last one.
        widths = [dim_t, 2 * dim_t, dim_t, d_in]
        layers = [nn.Linear(widths[0], widths[1])]
        for fan_in, fan_out in zip(widths[1:-1], widths[2:]):
            layers += [nn.SiLU(), nn.Linear(fan_in, fan_out)]
        self.mlp = nn.Sequential(*layers)

        self.map_noise = GaussianFourierProjection(dim_t)
        self.time_embed = nn.Sequential(
            nn.Linear(dim_t, dim_t), nn.SiLU(), nn.Linear(dim_t, dim_t)
        )

    def forward(self, x, noise_labels, class_labels=None):
        """Map ``x`` to a ``d_in``-wide output conditioned on ``noise_labels``.

        Args:
            x: Tensor whose last dimension is ``d_in``.
            noise_labels: Exactly one value per batch element; viewed as ``(B, 1)``.
            class_labels: Ignored.
        """
        batch = noise_labels.shape[0]
        feats = self.map_noise(noise_labels.view(batch, 1))
        # Exchange the two halves of each sample's features (sin/cos swap).
        feats = feats.reshape(feats.shape[0], 2, -1).flip(1).reshape(*feats.shape)
        t_emb = self.time_embed(feats)
        return self.mlp(self.proj(x) + t_emb)
    
class MLPDiffusionFlow(nn.Module):
    """Constant-width noise-conditioned MLP (default width 784).

    ``x`` (last dim ``d_in``) is lifted to ``dim_t`` features, an embedding of
    ``noise_labels`` is added, and a SiLU MLP with two ``dim_t``-wide hidden
    layers maps the sum back to ``d_in``.

    Args:
        d_in: Feature size of ``x`` and of the returned tensor.
        dim_t: Hidden width; also the size of the noise embedding.
    """

    def __init__(self, d_in=6, dim_t=784):
        super().__init__()
        self.dim_t = dim_t

        self.proj = nn.Linear(d_in, dim_t)

        # Linear layers between consecutive widths, SiLU between them,
        # no activation after the last one.
        widths = [dim_t, dim_t, dim_t, d_in]
        layers = [nn.Linear(widths[0], widths[1])]
        for fan_in, fan_out in zip(widths[1:-1], widths[2:]):
            layers += [nn.SiLU(), nn.Linear(fan_in, fan_out)]
        self.mlp = nn.Sequential(*layers)

        self.map_noise = GaussianFourierProjection(dim_t)
        self.time_embed = nn.Sequential(
            nn.Linear(dim_t, dim_t), nn.SiLU(), nn.Linear(dim_t, dim_t)
        )

    def forward(self, x, noise_labels, class_labels=None):
        """Map ``x`` to a ``d_in``-wide output conditioned on ``noise_labels``.

        Args:
            x: Tensor whose last dimension is ``d_in``.
            noise_labels: Exactly one value per batch element; viewed as ``(B, 1)``.
            class_labels: Ignored.
        """
        batch = noise_labels.shape[0]
        feats = self.map_noise(noise_labels.view(batch, 1))
        # Exchange the two halves of each sample's features (sin/cos swap).
        feats = feats.reshape(feats.shape[0], 2, -1).flip(1).reshape(*feats.shape)
        t_emb = self.time_embed(feats)
        return self.mlp(self.proj(x) + t_emb)
    
class MLPDiffusionBig(nn.Module):
    """Deep noise-conditioned MLP denoiser (default width 256).

    ``x`` (last dim ``d_in``) is lifted to ``dim_t`` features, an embedding of
    ``noise_labels`` is added, and a SiLU MLP with hidden widths
    ``2, 2, 4, 4, 2, 2, 1`` x ``dim_t`` maps the sum back to ``d_in``. The
    noise embedding uses a fixed 256-wide ``GaussianFourierProjection``
    followed by a two-layer MLP into ``dim_t``.

    Args:
        d_in: Feature size of ``x`` and of the returned tensor.
        dim_t: Base hidden width; also the size of the noise embedding.
    """

    def __init__(self, d_in=6, dim_t=256):
        super().__init__()
        self.dim_t = dim_t

        self.proj = nn.Linear(d_in, dim_t)

        # Linear layers between consecutive widths, SiLU between them,
        # no activation after the last one.
        widths = [
            dim_t, 2 * dim_t, 2 * dim_t, 4 * dim_t, 4 * dim_t,
            2 * dim_t, 2 * dim_t, dim_t, d_in,
        ]
        layers = [nn.Linear(widths[0], widths[1])]
        for fan_in, fan_out in zip(widths[1:-1], widths[2:]):
            layers += [nn.SiLU(), nn.Linear(fan_in, fan_out)]
        self.mlp = nn.Sequential(*layers)

        self.map_noise = GaussianFourierProjection(256)
        self.time_embed = nn.Sequential(
            nn.Linear(256, dim_t), nn.SiLU(), nn.Linear(dim_t, dim_t)
        )

    def forward(self, x, noise_labels, class_labels=None):
        """Map ``x`` to a ``d_in``-wide output conditioned on ``noise_labels``.

        Args:
            x: Tensor whose last dimension is ``d_in``.
            noise_labels: Exactly one value per batch element; viewed as ``(B, 1)``.
            class_labels: Ignored.
        """
        batch = noise_labels.shape[0]
        feats = self.map_noise(noise_labels.view(batch, 1))
        # Exchange the two halves of each sample's features (sin/cos swap).
        feats = feats.reshape(feats.shape[0], 2, -1).flip(1).reshape(*feats.shape)
        t_emb = self.time_embed(feats)
        return self.mlp(self.proj(x) + t_emb)
    
class MLPDiffusionBig512(nn.Module):
    """Deep noise-conditioned MLP denoiser (default width 512).

    ``x`` (last dim ``d_in``) is lifted to ``dim_t`` features, an embedding of
    ``noise_labels`` is added, and a SiLU MLP with hidden widths
    ``2, 2, 4, 4, 2, 2, 1`` x ``dim_t`` maps the sum back to ``d_in``. With
    the default this is 1024, 1024, 2048, 2048, 1024, 1024, 512.

    Args:
        d_in: Feature size of ``x`` and of the returned tensor.
        dim_t: Base hidden width; also the size of the noise embedding.
    """

    def __init__(self, d_in=6, dim_t=512):
        super().__init__()
        self.dim_t = dim_t

        self.proj = nn.Linear(d_in, dim_t)

        # Linear layers between consecutive widths, SiLU between them,
        # no activation after the last one.
        widths = [
            dim_t, 2 * dim_t, 2 * dim_t, 4 * dim_t, 4 * dim_t,
            2 * dim_t, 2 * dim_t, dim_t, d_in,
        ]
        layers = [nn.Linear(widths[0], widths[1])]
        for fan_in, fan_out in zip(widths[1:-1], widths[2:]):
            layers += [nn.SiLU(), nn.Linear(fan_in, fan_out)]
        self.mlp = nn.Sequential(*layers)

        self.map_noise = GaussianFourierProjection(dim_t)
        self.time_embed = nn.Sequential(
            nn.Linear(dim_t, dim_t), nn.SiLU(), nn.Linear(dim_t, dim_t)
        )

    def forward(self, x, noise_labels, class_labels=None):
        """Map ``x`` to a ``d_in``-wide output conditioned on ``noise_labels``.

        Args:
            x: Tensor whose last dimension is ``d_in``.
            noise_labels: Exactly one value per batch element; viewed as ``(B, 1)``.
            class_labels: Ignored.
        """
        batch = noise_labels.shape[0]
        feats = self.map_noise(noise_labels.view(batch, 1))
        # Exchange the two halves of each sample's features (sin/cos swap).
        feats = feats.reshape(feats.shape[0], 2, -1).flip(1).reshape(*feats.shape)
        t_emb = self.time_embed(feats)
        return self.mlp(self.proj(x) + t_emb)

class MLPDiffusionBig1024(nn.Module):
    """Deep noise-conditioned MLP denoiser (default width 1024).

    ``x`` (last dim ``d_in``) is lifted to ``dim_t`` features, an embedding of
    ``noise_labels`` is added, and a SiLU MLP with hidden widths
    ``2, 2, 4, 4, 2, 2, 1`` x ``dim_t`` maps the sum back to ``d_in``. With
    the default this is 2048, 2048, 4096, 4096, 2048, 2048, 1024.

    Args:
        d_in: Feature size of ``x`` and of the returned tensor.
        dim_t: Base hidden width; also the size of the noise embedding.
    """

    def __init__(self, d_in=6, dim_t=1024):
        super().__init__()
        self.dim_t = dim_t

        self.proj = nn.Linear(d_in, dim_t)

        # Linear layers between consecutive widths, SiLU between them,
        # no activation after the last one.
        widths = [
            dim_t, 2 * dim_t, 2 * dim_t, 4 * dim_t, 4 * dim_t,
            2 * dim_t, 2 * dim_t, dim_t, d_in,
        ]
        layers = [nn.Linear(widths[0], widths[1])]
        for fan_in, fan_out in zip(widths[1:-1], widths[2:]):
            layers += [nn.SiLU(), nn.Linear(fan_in, fan_out)]
        self.mlp = nn.Sequential(*layers)

        self.map_noise = GaussianFourierProjection(dim_t)
        self.time_embed = nn.Sequential(
            nn.Linear(dim_t, dim_t), nn.SiLU(), nn.Linear(dim_t, dim_t)
        )

    def forward(self, x, noise_labels, class_labels=None):
        """Map ``x`` to a ``d_in``-wide output conditioned on ``noise_labels``.

        Args:
            x: Tensor whose last dimension is ``d_in``.
            noise_labels: Exactly one value per batch element; viewed as ``(B, 1)``.
            class_labels: Ignored.
        """
        batch = noise_labels.shape[0]
        feats = self.map_noise(noise_labels.view(batch, 1))
        # Exchange the two halves of each sample's features (sin/cos swap).
        feats = feats.reshape(feats.shape[0], 2, -1).flip(1).reshape(*feats.shape)
        t_emb = self.time_embed(feats)
        return self.mlp(self.proj(x) + t_emb)

class MLPDiffusionBiger(nn.Module):
    """Deeper noise-conditioned MLP denoiser (default width 256).

    ``x`` (last dim ``d_in``) is lifted to ``dim_t`` features, an embedding of
    ``noise_labels`` is added, and a SiLU MLP with hidden widths
    ``2, 2, 4, 8, 4, 2, 2, 1`` x ``dim_t`` maps the sum back to ``d_in``. The
    noise embedding uses a fixed 256-wide ``GaussianFourierProjection``
    followed by a two-layer MLP into ``dim_t``.

    Args:
        d_in: Feature size of ``x`` and of the returned tensor.
        dim_t: Base hidden width; also the size of the noise embedding.
    """

    def __init__(self, d_in=6, dim_t=256):
        super().__init__()
        self.dim_t = dim_t

        self.proj = nn.Linear(d_in, dim_t)

        # Linear layers between consecutive widths, SiLU between them,
        # no activation after the last one.
        widths = [
            dim_t, 2 * dim_t, 2 * dim_t, 4 * dim_t, 8 * dim_t,
            4 * dim_t, 2 * dim_t, 2 * dim_t, dim_t, d_in,
        ]
        layers = [nn.Linear(widths[0], widths[1])]
        for fan_in, fan_out in zip(widths[1:-1], widths[2:]):
            layers += [nn.SiLU(), nn.Linear(fan_in, fan_out)]
        self.mlp = nn.Sequential(*layers)

        self.map_noise = GaussianFourierProjection(256)
        self.time_embed = nn.Sequential(
            nn.Linear(256, dim_t), nn.SiLU(), nn.Linear(dim_t, dim_t)
        )

    def forward(self, x, noise_labels, class_labels=None):
        """Map ``x`` to a ``d_in``-wide output conditioned on ``noise_labels``.

        Args:
            x: Tensor whose last dimension is ``d_in``.
            noise_labels: Exactly one value per batch element; viewed as ``(B, 1)``.
            class_labels: Ignored.
        """
        batch = noise_labels.shape[0]
        feats = self.map_noise(noise_labels.view(batch, 1))
        # Exchange the two halves of each sample's features (sin/cos swap).
        feats = feats.reshape(feats.shape[0], 2, -1).flip(1).reshape(*feats.shape)
        t_emb = self.time_embed(feats)
        return self.mlp(self.proj(x) + t_emb)
    
class MLPDiffusionBiger512(nn.Module):
    """Deeper noise-conditioned MLP denoiser (default width 512).

    ``x`` (last dim ``d_in``) is lifted to ``dim_t`` features, an embedding of
    ``noise_labels`` is added, and a SiLU MLP with hidden widths
    ``2, 2, 4, 8, 4, 2, 2, 1`` x ``dim_t`` maps the sum back to ``d_in``.

    Args:
        d_in: Feature size of ``x`` and of the returned tensor.
        dim_t: Base hidden width; also the size of the noise embedding.
    """

    def __init__(self, d_in=6, dim_t=512):
        super().__init__()
        self.dim_t = dim_t

        self.proj = nn.Linear(d_in, dim_t)

        # Linear layers between consecutive widths, SiLU between them,
        # no activation after the last one.
        widths = [
            dim_t, 2 * dim_t, 2 * dim_t, 4 * dim_t, 8 * dim_t,
            4 * dim_t, 2 * dim_t, 2 * dim_t, dim_t, d_in,
        ]
        layers = [nn.Linear(widths[0], widths[1])]
        for fan_in, fan_out in zip(widths[1:-1], widths[2:]):
            layers += [nn.SiLU(), nn.Linear(fan_in, fan_out)]
        self.mlp = nn.Sequential(*layers)

        self.map_noise = GaussianFourierProjection(dim_t)
        self.time_embed = nn.Sequential(
            nn.Linear(dim_t, dim_t), nn.SiLU(), nn.Linear(dim_t, dim_t)
        )

    def forward(self, x, noise_labels, class_labels=None):
        """Map ``x`` to a ``d_in``-wide output conditioned on ``noise_labels``.

        Args:
            x: Tensor whose last dimension is ``d_in``.
            noise_labels: Exactly one value per batch element; viewed as ``(B, 1)``.
            class_labels: Ignored.
        """
        batch = noise_labels.shape[0]
        feats = self.map_noise(noise_labels.view(batch, 1))
        # Exchange the two halves of each sample's features (sin/cos swap).
        feats = feats.reshape(feats.shape[0], 2, -1).flip(1).reshape(*feats.shape)
        t_emb = self.time_embed(feats)
        return self.mlp(self.proj(x) + t_emb)
    
class MLPDiffusionBiger1024(nn.Module):
    """Deeper noise-conditioned MLP denoiser (default width 1024).

    ``x`` (last dim ``d_in``) is lifted to ``dim_t`` features, an embedding of
    ``noise_labels`` is added, and a SiLU MLP with hidden widths
    ``2, 2, 4, 8, 4, 2, 2, 1`` x ``dim_t`` maps the sum back to ``d_in``.

    Args:
        d_in: Feature size of ``x`` and of the returned tensor.
        dim_t: Base hidden width; also the size of the noise embedding.
    """

    def __init__(self, d_in=6, dim_t=1024):
        super().__init__()
        self.dim_t = dim_t

        self.proj = nn.Linear(d_in, dim_t)

        # Linear layers between consecutive widths, SiLU between them,
        # no activation after the last one.
        widths = [
            dim_t, 2 * dim_t, 2 * dim_t, 4 * dim_t, 8 * dim_t,
            4 * dim_t, 2 * dim_t, 2 * dim_t, dim_t, d_in,
        ]
        layers = [nn.Linear(widths[0], widths[1])]
        for fan_in, fan_out in zip(widths[1:-1], widths[2:]):
            layers += [nn.SiLU(), nn.Linear(fan_in, fan_out)]
        self.mlp = nn.Sequential(*layers)

        self.map_noise = GaussianFourierProjection(dim_t)
        self.time_embed = nn.Sequential(
            nn.Linear(dim_t, dim_t), nn.SiLU(), nn.Linear(dim_t, dim_t)
        )

    def forward(self, x, noise_labels, class_labels=None):
        """Map ``x`` to a ``d_in``-wide output conditioned on ``noise_labels``.

        Args:
            x: Tensor whose last dimension is ``d_in``.
            noise_labels: Exactly one value per batch element; viewed as ``(B, 1)``.
            class_labels: Ignored.
        """
        batch = noise_labels.shape[0]
        feats = self.map_noise(noise_labels.view(batch, 1))
        # Exchange the two halves of each sample's features (sin/cos swap).
        feats = feats.reshape(feats.shape[0], 2, -1).flip(1).reshape(*feats.shape)
        t_emb = self.time_embed(feats)
        return self.mlp(self.proj(x) + t_emb)
    
class MLPDiffusionBigger(nn.Module):
    """Deepest noise-conditioned MLP denoiser (default width 256).

    ``x`` (last dim ``d_in``) is lifted to ``dim_t // 2`` features, and an
    embedding of ``noise_labels`` of the same width is added. A SiLU MLP with
    hidden widths ``1, 2, 2, 4, 8, 8, 4, 2, 2, 1`` x ``dim_t`` then maps the
    sum back to ``d_in``, passing through ``dim_t // 2`` just before the
    output layer.

    Args:
        d_in: Feature size of ``x`` and of the returned tensor.
        dim_t: Base hidden width; the projection and noise embedding use ``dim_t // 2``.
    """

    def __init__(self, d_in=6, dim_t=256):
        super().__init__()
        self.dim_t = dim_t
        half = dim_t // 2

        self.proj = nn.Linear(d_in, half)

        # Linear layers between consecutive widths, SiLU between them,
        # no activation after the last one.
        widths = [
            half, dim_t, 2 * dim_t, 2 * dim_t, 4 * dim_t, 8 * dim_t, 8 * dim_t,
            4 * dim_t, 2 * dim_t, 2 * dim_t, dim_t, half, d_in,
        ]
        layers = [nn.Linear(widths[0], widths[1])]
        for fan_in, fan_out in zip(widths[1:-1], widths[2:]):
            layers += [nn.SiLU(), nn.Linear(fan_in, fan_out)]
        self.mlp = nn.Sequential(*layers)

        self.map_noise = GaussianFourierProjection(half)
        self.time_embed = nn.Sequential(
            nn.Linear(half, half), nn.SiLU(), nn.Linear(half, half)
        )

    def forward(self, x, noise_labels, class_labels=None):
        """Map ``x`` to a ``d_in``-wide output conditioned on ``noise_labels``.

        Args:
            x: Tensor whose last dimension is ``d_in``.
            noise_labels: Exactly one value per batch element; viewed as ``(B, 1)``.
            class_labels: Ignored.
        """
        batch = noise_labels.shape[0]
        feats = self.map_noise(noise_labels.view(batch, 1))
        # Exchange the two halves of each sample's features (sin/cos swap).
        feats = feats.reshape(feats.shape[0], 2, -1).flip(1).reshape(*feats.shape)
        t_emb = self.time_embed(feats)
        return self.mlp(self.proj(x) + t_emb)
    

class MLPDiffusionFlex(nn.Module):
    """Noise-conditioned MLP whose hidden widths scale with ``d_in``.

    ``x`` (last dim ``d_in``) is lifted to ``2*d_in`` features, and a noise
    embedding of the same width is added. A SiLU MLP with hidden widths
    ``4, 8, 16, 16, 8, 4, 2`` x ``d_in`` then maps the sum back to ``d_in``.
    ``dim_t`` only sets the hidden width of the noise-embedding MLP, which
    reads a fixed 256-wide ``GaussianFourierProjection``.

    Args:
        d_in: Feature size of ``x`` and of the returned tensor.
        dim_t: Hidden width of the noise-embedding MLP.
    """

    def __init__(self, d_in=6, dim_t=256):
        super().__init__()
        self.dim_t = dim_t

        self.proj = nn.Linear(d_in, 2 * d_in)

        # Linear layers between consecutive widths, SiLU between them,
        # no activation after the last one.
        widths = [
            2 * d_in, 4 * d_in, 8 * d_in, 16 * d_in, 16 * d_in,
            8 * d_in, 4 * d_in, 2 * d_in, d_in,
        ]
        layers = [nn.Linear(widths[0], widths[1])]
        for fan_in, fan_out in zip(widths[1:-1], widths[2:]):
            layers += [nn.SiLU(), nn.Linear(fan_in, fan_out)]
        self.mlp = nn.Sequential(*layers)

        self.map_noise = GaussianFourierProjection(256)
        self.time_embed = nn.Sequential(
            nn.Linear(256, dim_t), nn.SiLU(), nn.Linear(dim_t, 2 * d_in)
        )

    def forward(self, x, noise_labels, class_labels=None):
        """Map ``x`` to a ``d_in``-wide output conditioned on ``noise_labels``.

        Args:
            x: Tensor whose last dimension is ``d_in``.
            noise_labels: Exactly one value per batch element; viewed as ``(B, 1)``.
            class_labels: Ignored.
        """
        batch = noise_labels.shape[0]
        feats = self.map_noise(noise_labels.view(batch, 1))
        # Exchange the two halves of each sample's features (sin/cos swap).
        feats = feats.reshape(feats.shape[0], 2, -1).flip(1).reshape(*feats.shape)
        t_emb = self.time_embed(feats)
        return self.mlp(self.proj(x) + t_emb)
    
    
class MLPDiffusionOnly(nn.Module):
    """Unconditioned MLP: the same backbone as the diffusion MLPs, with no noise input.

    ``x`` (last dim ``d_in``) is lifted to ``dim_t`` features, and a SiLU MLP
    with hidden widths ``2*dim_t, 2*dim_t, dim_t`` maps it back to ``d_in``.

    Args:
        d_in: Feature size of ``x`` and of the returned tensor.
        dim_t: Hidden width.
    """

    def __init__(self, d_in=512, dim_t=512):
        super().__init__()
        self.dim_t = dim_t

        self.proj = nn.Linear(d_in, dim_t)

        # Linear layers between consecutive widths, SiLU between them,
        # no activation after the last one.
        widths = [dim_t, 2 * dim_t, 2 * dim_t, dim_t, d_in]
        layers = [nn.Linear(widths[0], widths[1])]
        for fan_in, fan_out in zip(widths[1:-1], widths[2:]):
            layers += [nn.SiLU(), nn.Linear(fan_in, fan_out)]
        self.mlp = nn.Sequential(*layers)

    def forward(self, x, class_labels=None):
        """Return ``mlp(proj(x))``; ``class_labels`` is ignored."""
        hidden = self.proj(x)
        return self.mlp(hidden)
    


