import math

import torch
import torch.nn.functional as F
from einops import rearrange
from torch import nn

from .train_utils import normalize_emb
from timm.models.vision_transformer import Attention, Mlp

import torch.nn.init as nn_init

from torch import Tensor

class Tokenizer(nn.Module):
    """Turn every numerical feature into a learned ``d_token``-sized token.

    Feature ``j`` is multiplied by its own learned vector ``weight[j]``; when
    ``bias`` is enabled a per-feature learned vector ``bias[j]`` is added too.
    No [CLS] token is created.

    Args:
        d_numerical: number of numerical input features.
        d_token: width of each produced token.
        bias: whether to learn the per-feature additive term.
    """

    def __init__(self, d_numerical, d_token, bias= True):
        super().__init__()
        # Placeholders kept for interface compatibility; categorical inputs
        # are not tokenized by this class.
        self.category_offsets = None
        self.category_embeddings = None

        self.weight = nn.Parameter(Tensor(d_numerical, d_token))
        if bias:
            self.bias = nn.Parameter(Tensor(d_numerical, d_token))
        else:
            self.bias = None

        # Same scheme nn.Linear uses for its weight matrix.
        nn_init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is not None:
            nn_init.kaiming_uniform_(self.bias, a=math.sqrt(5))

    def forward(self, x_num ):
        """Map ``x_num`` of shape (batch, d_numerical) to (batch, d_numerical, d_token)."""
        tokens = self.weight.unsqueeze(0) * x_num.unsqueeze(2)
        if self.bias is None:
            return tokens
        return tokens + self.bias.unsqueeze(0)

class MaskEmbed(nn.Module):
    """Embed a per-feature (mask) value as a learned ``d_token``-sized token.

    Feature ``j`` is scaled by its learned vector ``weight[j]``; an optional
    per-feature learned offset ``bias[j]`` is added when ``bias=True``
    (disabled by default).

    Args:
        num_features: number of features / mask entries per sample.
        d_token: width of each produced token.
        bias: whether to learn the per-feature additive term.
    """

    def __init__(self, num_features, d_token, bias= False):
        super().__init__()
        self.weight = nn.Parameter(Tensor(num_features, d_token))
        if bias:
            self.bias = nn.Parameter(Tensor(num_features, d_token))
        else:
            self.bias = None

        # nn.Linear-style initialisation.
        nn_init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is not None:
            nn_init.kaiming_uniform_(self.bias, a=math.sqrt(5))

    def forward(self, x_num ):
        """Map ``x_num`` of shape (batch, num_features) to (batch, num_features, d_token)."""
        tokens = self.weight.unsqueeze(0) * x_num.unsqueeze(2)
        if self.bias is None:
            return tokens
        return tokens + self.bias.unsqueeze(0)


class MultiheadAttention(nn.Module):
    """Multi-head scaled dot-product attention between query and key/value tokens.

    Args:
        d: model width; must be divisible by ``n_heads`` when ``n_heads > 1``.
        n_heads: number of heads. With a single head no output projection
            (``W_out``) is created.
        dropout: dropout rate on the attention weights (falsy disables it).
        initialization: ``'kaiming'`` keeps the default ``nn.Linear`` weight
            init; ``'xavier'`` re-initialises the projections (except ``W_v``
            when there is a single head).
    """

    def __init__(self, d, n_heads, dropout, initialization = 'kaiming'):
        if n_heads > 1:
            assert d % n_heads == 0
        assert initialization in ['xavier', 'kaiming']

        super().__init__()
        self.W_q, self.W_k, self.W_v = (nn.Linear(d, d) for _ in range(3))
        self.W_out = nn.Linear(d, d) if n_heads > 1 else None
        self.n_heads = n_heads
        self.dropout = nn.Dropout(dropout) if dropout else None

        use_xavier = initialization == 'xavier'
        for proj in (self.W_q, self.W_k, self.W_v):
            skip_value = n_heads <= 1 and proj is self.W_v
            if use_xavier and not skip_value:
                # gain compensates for W_qkv being split into 3 separate layers
                nn_init.xavier_uniform_(proj.weight, gain=1 / math.sqrt(2))
            nn_init.zeros_(proj.bias)
        if self.W_out is not None:
            nn_init.zeros_(self.W_out.bias)

    def _reshape(self, x):
        """(batch, tokens, d) -> (batch * n_heads, tokens, d // n_heads)."""
        n_batch, n_tokens, width = x.shape
        head_dim = width // self.n_heads
        split = x.reshape(n_batch, n_tokens, self.n_heads, head_dim).transpose(1, 2)
        return split.reshape(n_batch * self.n_heads, n_tokens, head_dim)

    def forward(self, x_q, x_kv, key_compression = None, value_compression = None):
        """Attend from ``x_q`` to ``x_kv``; returns a tensor shaped like ``x_q``.

        ``key_compression``/``value_compression`` must be given together; they
        are applied along the token axis of the projected keys/values.
        """
        q = self.W_q(x_q)
        k = self.W_k(x_kv)
        v = self.W_v(x_kv)
        assert all(t.shape[-1] % self.n_heads == 0 for t in (q, k, v))
        if key_compression is None:
            assert value_compression is None
        else:
            assert value_compression is not None
            k = key_compression(k.transpose(1, 2)).transpose(1, 2)
            v = value_compression(v.transpose(1, 2)).transpose(1, 2)

        n_batch, n_q_tokens = len(q), q.shape[1]
        key_head_dim = k.shape[-1] // self.n_heads
        value_head_dim = v.shape[-1] // self.n_heads

        scores = self._reshape(q) @ self._reshape(k).transpose(1, 2)
        attention = F.softmax(scores / math.sqrt(key_head_dim), dim=-1)
        if self.dropout is not None:
            attention = self.dropout(attention)

        out = attention @ self._reshape(v)
        out = out.reshape(n_batch, self.n_heads, n_q_tokens, value_head_dim).transpose(1, 2)
        out = out.reshape(n_batch, n_q_tokens, self.n_heads * value_head_dim)
        if self.W_out is not None:
            out = self.W_out(out)
        return out
        
class Transformer(nn.Module):
    """Stack of attention + feed-forward blocks over a sequence of tokens.

    Each block applies self-attention (``MultiheadAttention``) and a two-layer
    feed-forward network, each inside a residual connection with optional
    residual dropout. With ``prenormalization`` the LayerNorm runs before each
    branch (the first block has no attention pre-norm ``norm0``); otherwise it
    runs after the residual sum.

    Args:
        n_layers: number of blocks.
        d_token: token width.
        n_heads: attention heads per block.
        d_out: output width of ``head``.
        d_ffn_factor: feed-forward hidden width is ``int(d_token * d_ffn_factor)``.
        attention_dropout, ffn_dropout, residual_dropout: dropout rates.
        activation: feed-forward activation, one of ``'relu'``, ``'gelu'``,
            ``'silu'``. Any other name falls back to ReLU with a warning (GLU
            variants such as ``'reglu'`` would need a wider ``linear0`` and are
            not supported).
        prenormalization: LayerNorm before (True) or after (False) each branch.
        initialization: ``'kaiming'`` or ``'xavier'``, forwarded to attention.

    NOTE(review): ``last_activation``, ``last_normalization`` and ``head`` are
    built (so they are part of ``state_dict``) but ``forward`` never uses them;
    it returns the per-token output of the last block.
    """

    _ACTIVATIONS = {'relu': nn.ReLU, 'gelu': nn.GELU, 'silu': nn.SiLU}

    def __init__(
        self,
        n_layers: int,
        d_token: int,
        n_heads: int,
        d_out: int,
        d_ffn_factor: int,
        attention_dropout = 0.0,
        ffn_dropout = 0.0,
        residual_dropout = 0.0,
        activation = 'relu',
        prenormalization = True,
        initialization = 'kaiming',
    ):
        super().__init__()

        def make_normalization():
            return nn.LayerNorm(d_token)

        d_hidden = int(d_token * d_ffn_factor)
        self.layers = nn.ModuleList([])
        for layer_idx in range(n_layers):
            layer = nn.ModuleDict(
                {
                    'attention': MultiheadAttention(
                        d_token, n_heads, attention_dropout, initialization
                    ),
                    'linear0': nn.Linear(d_token, d_hidden),
                    'linear1': nn.Linear(d_hidden, d_token),
                    'norm1': make_normalization(),
                }
            )
            # With prenormalization the very first attention branch has no norm.
            if not prenormalization or layer_idx:
                layer['norm0'] = make_normalization()
            self.layers.append(layer)

        activation_cls = self._resolve_activation(activation)
        self.activation = activation_cls()
        self.last_activation = activation_cls()
        self.prenormalization = prenormalization
        self.last_normalization = make_normalization() if prenormalization else None
        self.ffn_dropout = ffn_dropout
        self.residual_dropout = residual_dropout
        self.head = nn.Linear(d_token, d_out)

    @classmethod
    def _resolve_activation(cls, name):
        """Return the activation module class for ``name`` (ReLU if unsupported)."""
        activation_cls = cls._ACTIVATIONS.get(name)
        if activation_cls is None:
            import warnings

            warnings.warn(
                f"Unsupported activation {name!r} for Transformer; falling back to ReLU.",
                stacklevel=3,
            )
            activation_cls = nn.ReLU
        return activation_cls

    def _start_residual(self, x, layer, norm_idx):
        """Return the branch input: ``x``, pre-normalized when applicable."""
        x_residual = x
        if self.prenormalization:
            norm_key = f'norm{norm_idx}'
            if norm_key in layer:
                x_residual = layer[norm_key](x_residual)
        return x_residual

    def _end_residual(self, x, x_residual, layer, norm_idx):
        """Add the (dropped-out) branch output to ``x``; post-normalize if configured."""
        if self.residual_dropout:
            x_residual = F.dropout(x_residual, self.residual_dropout, self.training)
        x = x + x_residual
        if not self.prenormalization:
            x = layer[f'norm{norm_idx}'](x)
        return x

    def forward(self, x):
        """Run every block on ``x`` and return the per-token hidden states.

        No final normalization, activation or ``head`` projection is applied.
        """
        for layer in self.layers:
            # Self-attention branch over all tokens.
            x_residual = self._start_residual(x, layer, 0)
            x_residual = layer['attention'](x_residual, x_residual)
            x = self._end_residual(x, x_residual, layer, 0)

            # Feed-forward branch.
            x_residual = self._start_residual(x, layer, 1)
            x_residual = self.activation(layer['linear0'](x_residual))
            if self.ffn_dropout:
                x_residual = F.dropout(x_residual, self.ffn_dropout, self.training)
            x_residual = layer['linear1'](x_residual)
            x = self._end_residual(x, x_residual, layer, 1)
        return x

def modulate(x, shift, scale):
    """Apply adaLN-style modulation ``x * (1 + scale) + shift``.

    ``shift`` and ``scale`` get a singleton axis inserted at dim 1, so they
    broadcast over that axis of ``x`` (e.g. per-sample parameters applied to
    every token of a (batch, tokens, hidden) tensor).
    """
    scale_b = scale.unsqueeze(1)
    shift_b = shift.unsqueeze(1)
    return x * (1 + scale_b) + shift_b


class DiTBlock(nn.Module):
    """
    A DiT block with adaptive layer norm (adaLN) conditioning and gated residuals.

    The conditioning vector ``c`` is mapped to six vectors: shift/scale/gate
    for the attention branch and shift/scale/gate for the MLP branch. Any
    zero-initialisation of ``adaLN_modulation`` (the "-Zero" part of
    adaLN-Zero) is not done in this class.
    """
    def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, **block_kwargs):
        super().__init__()
        self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.attn = Attention(hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs)
        self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)

        def tanh_gelu():
            # Factory used by Mlp to build its activation layer.
            return nn.GELU(approximate="tanh")

        self.mlp = Mlp(
            in_features=hidden_size,
            hidden_features=int(hidden_size * mlp_ratio),
            act_layer=tanh_gelu,
            drop=0,
        )
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_size, 6 * hidden_size, bias=True),
        )

    def forward(self, x, c):
        """Update tokens ``x`` conditioned on ``c`` (one conditioning row per sample)."""
        params = self.adaLN_modulation(c).chunk(6, dim=1)
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = params

        attn_in = modulate(self.norm1(x), shift_msa, scale_msa)
        x = x + gate_msa.unsqueeze(1) * self.attn(attn_in)

        mlp_in = modulate(self.norm2(x), shift_mlp, scale_mlp)
        x = x + gate_mlp.unsqueeze(1) * self.mlp(mlp_in)
        return x
    
    
class Penultimate(nn.Module):
    """
    The Penultimate layer of DiT: a non-affine LayerNorm followed by adaLN
    shift/scale modulation derived from the conditioning vector (no output
    projection).
    """
    def __init__(self, hidden_size):
        super().__init__()
        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_size, 2 * hidden_size, bias=True),
        )

    def forward(self, x, c):
        """Normalize ``x`` and modulate it with per-sample shift/scale from ``c``."""
        shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
        normed = self.norm_final(x)
        # Same as modulate(normed, shift, scale): broadcast over dim 1 of x.
        return normed * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
    
class TimeStepEmbedding(nn.Module):
    """
    Layer that embeds diffusion timesteps.

    Timesteps are first mapped to ``dim`` cos/sin features, using either
    deterministic geometric frequencies or (``fourier=True``) fixed random
    Fourier frequencies. These features are then passed through ``n_layers``
    linear layers with SiLU in between.

     Args:
        - dim (int): the dimension of the output; must be a positive even number.
        - max_period (int): controls the minimum frequency of the embeddings
          (deterministic frequencies only).
        - n_layers (int): number of dense layers
        - fourier (bool): whether to use random fourier features as embeddings
        - scale (float): std of the random Fourier frequencies (fourier=True only).
    """
    def __init__(
        self,
        dim: int,
        max_period: int = 10000,
        n_layers: int = 2,
        fourier: bool = False,
        scale=16,
    ):
        super().__init__()
        self.dim = dim
        self.max_period = max_period
        self.n_layers = n_layers
        self.fourier = fourier

        if dim <= 0:
            # dim == 0 would otherwise divide by zero in forward (mid == 0).
            raise ValueError(f"embedding dim must be positive, got {dim}")
        if dim % 2 != 0:
            raise ValueError(f"embedding dim must be even, got {dim}")

        if fourier:
            self.register_buffer("freqs", torch.randn(dim // 2) * scale)

        layers = []
        for _ in range(n_layers - 1):
            layers.append(nn.Linear(dim, dim))
            layers.append(nn.SiLU())
        self.fc = nn.Sequential(*layers, nn.Linear(dim, dim))

    def forward(self, timesteps):
        """Embed a 1-D tensor of timesteps into shape (len(timesteps), dim)."""
        # Integer timesteps (e.g. discrete diffusion steps) must not pull the
        # random Fourier frequencies down to an integer dtype.
        if not timesteps.is_floating_point():
            timesteps = timesteps.float()

        if not self.fourier:
            mid = self.dim // 2
            fs = torch.exp(
                -math.log(self.max_period) / mid
                * torch.arange(mid, dtype=torch.float32, device=timesteps.device)
            )
            args = timesteps[:, None].float() * fs[None]
            emb = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        else:
            x = torch.outer(timesteps, (2 * torch.pi * self.freqs).to(timesteps.dtype))
            emb = torch.cat([x.cos(), x.sin()], dim=1)

        return self.fc(emb)


class FinalLayer_org(nn.Module):
    """
    Final layer that predicts logits for each category for categorical features
    and scalars for continuous features.

    A single linear layer maps the last dimension of the input to
    ``num_cont_features + sum(c + 1 for c in categories)`` outputs. Every
    categorical feature gets one extra logit for an "unknown" (UNK) class.

    Args:
        dim_in: size of the last input dimension.
        categories: number of known categories of each categorical feature.
        num_cont_features: number of continuous features.
        bias_init: optional initial bias ordered as
            [continuous..., cat_0 (+UNK)..., cat_1 (+UNK)..., ...]. If its length
            differs from the output size it is zero-padded at the end or
            truncated (a warning is printed).
    """

    def __init__(self, dim_in, categories, num_cont_features, bias_init=None):
        super().__init__()
        self.num_cont_features = num_cont_features
        self.num_cat_features = len(categories)
        # One extra UNK logit per categorical feature.
        categories_with_unk = [c + 1 for c in categories]
        dim_out = sum(categories_with_unk) + self.num_cont_features

        self.linear = nn.Linear(dim_in, dim_out)
        nn.init.zeros_(self.linear.weight)
        if bias_init is None:
            nn.init.zeros_(self.linear.bias)
        else:
            # Ensure bias_init has the correct size
            if bias_init.shape[0] != dim_out:
                print(f'  ⚠️  WARNING: bias_init size ({bias_init.shape[0]}) != dim_out ({dim_out}). '
                      f'Padding or truncating to match.')
                if bias_init.shape[0] < dim_out:
                    padding = torch.zeros(dim_out - bias_init.shape[0], dtype=bias_init.dtype, device=bias_init.device)
                    bias_init = torch.cat([bias_init, padding])
                else:
                    bias_init = bias_init[:dim_out]
            self.linear.bias = nn.Parameter(bias_init)

        self.split_chunks = [self.num_cont_features, *categories_with_unk]
        # torch.split keeps zero-sized chunks, so chunk 0 is always the
        # (possibly empty) continuous block and categorical logits start at 1.
        # Starting at 0 when there are no continuous features would return the
        # empty continuous chunk as the first "categorical" logits.
        self.cat_idx = 1

    def forward(self, x):
        """Return ``(cat_logits, cont_logits)``.

        ``cat_logits`` is a tuple with one ``(..., categories[i] + 1)`` tensor per
        categorical feature, or None if there are none. ``cont_logits`` is
        ``(..., num_cont_features)``, or None if there are no continuous features.
        """
        out = torch.split(self.linear(x), self.split_chunks, dim=-1)
        cont_logits = out[0] if self.num_cont_features > 0 else None
        cat_logits = out[self.cat_idx:] if self.num_cat_features > 0 else None
        return cat_logits, cont_logits


class FinalLayer(nn.Module):
    """
    Final layer that predicts logits for each category for categorical features
    and scalars for continuous features, using one linear head per feature.

    The input is split along dim 1 into ``num_cat_features`` categorical tokens
    followed by ``num_cont_features`` continuous tokens. Each token goes
    through its own zero-weight-initialised ``nn.Linear(dim_in, width)``.

    Args:
        dim_in: size of the last input dimension.
        categories: number of categories of each categorical feature.
        num_cont_features: number of continuous features (one scalar each).
        bias_init: optional 1-D initial bias ordered as
            [continuous..., cat_0..., cat_1..., ...].
    """

    def __init__(self, dim_in, categories, num_cont_features, bias_init=None):
        super().__init__()
        self.num_cont_features = num_cont_features
        self.num_cat_features = len(categories)

        # Continuous heads are built first so they consume the leading
        # entries of bias_init, matching its documented layout.
        cum_sum = 0
        layers_num = []
        for width in [1] * num_cont_features:
            layers_num.append(self._make_head(dim_in, width, bias_init, cum_sum))
            cum_sum += width

        layers_cat = []
        for width in categories:
            layers_cat.append(self._make_head(dim_in, width, bias_init, cum_sum))
            cum_sum += width

        self.layers_cat = nn.ModuleList(layers_cat)
        self.layers_num = nn.ModuleList(layers_num)

        self.split_chunks = [self.num_cat_features, self.num_cont_features]

    @staticmethod
    def _make_head(dim_in, width, bias_init, start):
        """Build a zero-weight ``nn.Linear(dim_in, width)`` with zero bias or ``bias_init[start:start + width]``."""
        head = nn.Linear(dim_in, width)
        nn.init.zeros_(head.weight)
        if bias_init is None:
            nn.init.zeros_(head.bias)
        else:
            head.bias = nn.Parameter(bias_init[start: start + width])
        return head

    def forward(self, x):
        """Return ``(cat_logits, cont_logits)``.

        ``cat_logits`` is a list with one ``(batch, categories[i])`` tensor per
        categorical feature. ``cont_logits`` has shape (batch, num_cont_features),
        which is an empty (batch, 0) tensor when there are no continuous features.
        """
        out_cat, out_cont = torch.split(x, self.split_chunks, dim=1)

        cat_logits = [
            head(out_cat[:, i:i + 1]).squeeze(1)
            for i, head in enumerate(self.layers_cat)
        ]

        if self.num_cont_features == 0:
            # torch.concat of an empty list would raise.
            return cat_logits, x.new_zeros((x.shape[0], 0))

        cont_logits = [
            head(out_cont[:, i:i + 1]) for i, head in enumerate(self.layers_num)
        ]
        return cat_logits, torch.concat(cont_logits, 1).squeeze(-1)

class PositionalEmbedder(nn.Module):
    """
    Positional embedding layer for encoding continuous features.
    Adapted from https://github.com/yandex-research/rtdl-num-embeddings/blob/main/package/rtdl_num_embeddings.py#L61

    Each feature ``f`` owns ``dim // 2`` frequencies, initialised from a normal
    distribution with std ``freq_init_scale`` truncated at +/- 3 std. A scalar
    is mapped to ``[sin(2*pi*w*x), cos(2*pi*w*x)]`` over its frequencies.
    """

    def __init__(self, dim, num_features, trainable=False, freq_init_scale=0.01):
        super().__init__()
        assert (dim % 2) == 0
        self.half_dim = dim // 2
        self.sigma = freq_init_scale
        self.weights = nn.Parameter(
            torch.randn(1, num_features, self.half_dim), requires_grad=trainable
        )
        limit = self.sigma * 3
        nn.init.trunc_normal_(self.weights, 0.0, self.sigma, a=-limit, b=limit)

    def forward(self, x):
        """Map ``x`` of shape (batch, features) to (batch, features, dim)."""
        phase = rearrange(x, "b f -> b f 1") * self.weights * 2 * torch.pi
        return torch.cat((torch.sin(phase), torch.cos(phase)), dim=-1)


class ContEmbedder(nn.Module):
    """
    Embedding layer for continuous features that utilizes Fourier features.

    Each scalar feature is expanded into ``2 * dim`` sin/cos features by a
    trainable ``PositionalEmbedder``. These are projected to ``dim`` by a
    feature-specific linear map (``NLinear``) and passed through SiLU.
    """

    def __init__(self, dim, num_features, freq_init_scale=0.01):
        super().__init__()
        assert (dim % 2) == 0
        fourier_dim = 2 * dim
        self.pos_emb = PositionalEmbedder(
            fourier_dim, num_features, trainable=True, freq_init_scale=freq_init_scale
        )
        self.nlinear = NLinear(fourier_dim, dim, num_features)
        self.act = nn.SiLU()

    def forward(self, x):
        """Embed ``x`` (batch, num_features) into (batch, num_features, dim)."""
        return self.act(self.nlinear(self.pos_emb(x)))


class NLinear(nn.Module):
    """N separate linear layers for N separate features
    adapted from https://github.com/yandex-research/rtdl-num-embeddings/blob/main/package/rtdl_num_embeddings.py#L61
    x has typically 3 dimensions: (batch, features, embedding dim)

    Feature ``i`` is mapped with its own ``weight[i]`` (in_dim x out_dim) and
    ``bias[i]``, both initialised uniformly in +/- 1/sqrt(in_dim).
    """

    def __init__(self, in_dim, out_dim, n):
        super().__init__()
        self.weight = nn.Parameter(torch.empty(n, in_dim, out_dim))
        self.bias = nn.Parameter(torch.empty(n, out_dim))
        limit = 1 / math.sqrt(in_dim)
        for param in (self.weight, self.bias):
            nn.init.uniform_(param, -limit, limit)

    def forward(self, x):
        """Apply the per-feature maps: (..., n, in_dim) -> (..., n, out_dim)."""
        projected = torch.matmul(x.unsqueeze(-2), self.weight).squeeze(-2)
        return projected + self.bias


class FeatCond(nn.Module):
    """
    Feature-specific conditioning module for more complex feature-sensitive conditioning.

    Applies a separate linear map per feature (``NLinear``) followed by SiLU.
    If ``cond_dim`` is given, a SiLU-activated projection of the conditioning
    vector ``c`` is first added to every feature's input.

    Args:
        num_features: number of feature tokens.
        d_in: per-feature input width.
        d_out: per-feature output width.
        init_zero: initialise the per-feature maps (weights and biases) to zero.
        cond_dim: width of ``c``; None disables conditioning.
    """

    def __init__(self, num_features, d_in, d_out, init_zero=False, cond_dim=None):
        super().__init__()

        self.num_features = num_features
        self.condition = cond_dim is not None
        self.nlinear = NLinear(d_in, d_out, num_features)
        if init_zero:
            nn.init.zeros_(self.nlinear.weight)
            nn.init.zeros_(self.nlinear.bias)
        self.act = nn.SiLU()

        if self.condition:
            self.cond_proj = nn.Linear(cond_dim, d_in * num_features)

    def forward(self, x, c=None):
        """Map ``x`` (batch, num_features, d_in) to (batch, num_features, d_out).

        ``c`` is required when the module was built with ``cond_dim``. ``x`` is
        not modified.
        """
        if self.condition:
            cond = F.silu(self.cond_proj(c))
            cond = rearrange(cond, "b (f d) -> b f d", f=self.num_features)
            # Out-of-place: `x += cond` would mutate the caller's tensor and
            # fails under autograd when x is a leaf that requires grad.
            x = x + cond
        h = self.nlinear(x)
        return self.act(h)


class CatEmbedding(nn.Module):
    """
    Feature-specific embedding layer for categorical features.

    Feature ``f`` with ``categories[f]`` known categories owns
    ``categories[f] + 1`` consecutive rows of one shared ``nn.Embedding``.
    Indices ``0 .. categories[f] - 1`` are the known categories, and index
    ``categories[f]`` is an "unknown" (UNK) category.
    bias = True adds a learnable bias term to each feature, which is the same
    across categories.
    """

    def __init__(self, dim, categories, cat_emb_init_sigma=0.001, bias=False):
        super().__init__()

        # Explicit long dtype so an empty category list does not become float.
        self.categories = torch.tensor(categories, dtype=torch.long)
        self.categories_with_unk = self.categories + 1

        # Start row of each feature's block in the shared embedding table.
        categories_offset = torch.cat([
            torch.zeros(1, dtype=torch.long),
            self.categories_with_unk.cumsum(dim=0)[:-1],
        ])
        self.register_buffer("categories_offset", categories_offset)
        self.dim = torch.tensor(dim)

        self.cat_emb = nn.Embedding(self.categories_with_unk.sum().item(), dim)
        nn.init.normal_(self.cat_emb.weight, std=cat_emb_init_sigma)

        self.bias = bias
        if self.bias:
            self.cat_bias = nn.Parameter(torch.zeros(len(categories), dim))

    def _map_oob_to_unk(self, x):
        """Return a copy of ``x`` in which every index of feature ``f`` outside
        ``[0, categories[f] - 1]`` is replaced by the UNK index ``categories[f]``.
        """
        # self.categories is a plain (CPU) attribute, not a buffer.
        unk = self.categories.to(x.device)
        invalid = (x < 0) | (x > unk - 1)
        return torch.where(invalid, unk.to(x.dtype).expand_as(x), x)

    def forward(self, x):
        """
        x: LongTensor of shape [batch, n_cat_features] with category indices.
           Values outside [0, categories[f] - 1] are mapped to the UNK index
           categories[f] before lookup. The input tensor is not modified.

        Returns the normalized embeddings scaled by sqrt(dim), shape
        [batch, n_cat_features, dim].
        """
        x = self._map_oob_to_unk(x)

        x_with_offset = x + self.categories_offset  # [batch, n_cat_features]
        assert (x_with_offset < self.cat_emb.num_embeddings).all(), (
            f"Index out of bounds: max {x_with_offset.max().item()}, embedding size {self.cat_emb.num_embeddings}"
        )
        x_emb = self.cat_emb(x_with_offset)
        if self.bias:
            x_emb += self.cat_bias
        x_emb = normalize_emb(x_emb, dim=2) * self.dim.sqrt().item()
        return x_emb

    def get_all_feat_emb(self, feat_idx):
        """Return the embeddings of all categories (including UNK) of feature ``feat_idx``."""
        n_cats = self.categories_with_unk[feat_idx]
        emb_idx = (
            torch.arange(n_cats, device=self.cat_emb.weight.device)
            + self.categories_offset[feat_idx]
        )
        x = self.cat_emb(emb_idx)
        if self.bias:
            x += self.cat_bias[feat_idx]
        x = normalize_emb(x, dim=1) * self.dim.sqrt().item()
        return x


class MLP(nn.Module):
    """MLP score model over flattened categorical embeddings and continuous features.

    Inputs are flattened and projected to ``emb_dim``, added to the time
    embedding (plus an optional class embedding), passed through ``n_layers``
    linear+activation layers and fed to ``FinalLayer``.

    Args:
        num_cont_features: number of continuous features.
        cat_emb_dim: embedding size per categorical feature (and per continuous
            feature when Fourier features are used).
        categories: number of categories per categorical feature.
        num_y_classes: number of conditioning classes, or None.
        emb_dim: width of the conditioning / projection embedding.
        n_layers, n_units: depth and width of the MLP.
        proportions: optional per-feature category frequency tensors; their
            logs initialise the categorical biases of the final layer.
        use_fourier_features: embed continuous features with ``ContEmbedder``.
        act: "relu" for ReLU, anything else for SiLU.
        feat_spec_cond: add a feature-specific conditioning path (``FeatCond``);
            requires ``use_fourier_features``.
        time_fourier: use random Fourier features for the time embedding.

    NOTE(review): ``h`` passed to ``FinalLayer`` is (batch, n_units), while
    ``FinalLayer`` splits its input along dim 1 into per-feature tokens — verify
    this pairing is intended.
    """

    def __init__(
        self,
        num_cont_features,
        cat_emb_dim,
        categories,
        num_y_classes,
        emb_dim,
        n_layers,
        n_units,
        proportions=None,
        use_fourier_features=False,
        act="relu",
        feat_spec_cond=False,
        time_fourier=False,
    ):
        super().__init__()

        self.num_cont_features = num_cont_features
        self.num_cat_features = len(categories)
        self.num_features = num_cont_features + self.num_cat_features
        self.time_emb = TimeStepEmbedding(emb_dim, fourier=time_fourier)

        self.y_cond = num_y_classes is not None
        if self.y_cond:
            self.y_emb = nn.Embedding(num_y_classes, emb_dim)

        in_dims = [emb_dim] + [n_units] * (n_layers - 1)
        out_dims = [n_units] * n_layers
        activation = nn.ReLU if act == "relu" else nn.SiLU
        blocks = []
        for d_in, d_out in zip(in_dims, out_dims):
            blocks.extend((nn.Linear(d_in, d_out), activation()))
        self.fc = nn.Sequential(*blocks)

        self.use_fourier_features = use_fourier_features
        self.feat_spec_cond = feat_spec_cond
        if self.feat_spec_cond:
            assert self.use_fourier_features
            self.cond_feat = FeatCond(
                self.num_features,
                cat_emb_dim,
                cat_emb_dim,
                cond_dim=emb_dim,
                init_zero=True,
            )

        # feat_spec_cond implies use_fourier_features, so two cases suffice.
        if self.use_fourier_features:
            proj_dim_in = self.num_features * cat_emb_dim
            self.cont_emb = ContEmbedder(cat_emb_dim, num_cont_features)
        else:
            proj_dim_in = num_cont_features + self.num_cat_features * cat_emb_dim

        self.proj = nn.Linear(proj_dim_in, emb_dim)

        # init final layer
        bias_init = None
        if proportions is not None:
            bias_init = torch.cat(
                (torch.zeros((num_cont_features,)), torch.cat(proportions).log())
            )

        self.final_layer = FinalLayer(
            out_dims[-1], categories, num_cont_features, bias_init=bias_init
        )

    def forward(
        self,
        x_cat_emb_t,
        x_cont_t,
        time,
        c,
    ):
        """Predict per-feature outputs from noisy inputs at ``time``.

        ``x_cat_emb_t`` is (batch, n_cat, cat_emb_dim); ``c`` holds class labels
        and is only used when class conditioning is enabled.
        """
        cond_emb = self.time_emb(time)
        if self.y_cond:
            cond_emb = cond_emb + F.silu(self.y_emb(c))

        if self.use_fourier_features:
            x_cont_t = rearrange(self.cont_emb(x_cont_t), "B F D -> B (F D)")

        x = torch.concat(
            (rearrange(x_cat_emb_t, "B F D -> B (F D)"), x_cont_t), dim=-1
        )
        if self.feat_spec_cond:
            cont_tokens = rearrange(
                x_cont_t, "B (F D) -> B F D", F=self.num_cont_features
            )
            tokens = torch.concat((x_cat_emb_t, cont_tokens), dim=1)
            x_cond = self.cond_feat(tokens, cond_emb)  # feat spec encoding path
            x += rearrange(x_cond, "B F D -> B (F D)")  # add back to main path

        h = self.fc(self.proj(x) + cond_emb)
        return self.final_layer(h)


class TabDDPM_MLP(nn.Module):
    """
    TabDDPM-like architecture for both continuous and categorical features.
    Used for TabDDPM and CDTD.

    Inputs are flattened, projected to ``emb_dim``, summed with the time (and
    optional class) embedding and passed through an MLP. The head is
    ``FinalLayer_org``, which emits one extra "unknown" (UNK) logit per
    categorical feature.

    Args:
        num_cont_features: number of continuous features.
        cat_emb_dim: embedding size of each categorical feature (and of each
            continuous feature when Fourier features are used).
        categories: number of known categories per categorical feature.
        num_y_classes: number of conditioning classes, or None.
        emb_dim: width of the conditioning / projection embedding.
        n_layers, n_units: depth and width of the MLP.
        proportions: optional list of per-feature category frequency tensors;
            their logs (plus a tiny UNK probability) initialise the
            categorical logit biases.
        use_fourier_features: embed continuous features with ``ContEmbedder``.
        act: "relu" for ReLU, anything else for SiLU.
    """

    def __init__(
        self,
        num_cont_features,
        cat_emb_dim,
        categories,
        num_y_classes,
        emb_dim,
        n_layers,
        n_units,
        proportions=None,
        use_fourier_features=False,
        act="relu",
    ):
        super().__init__()

        num_cat_features = len(categories)
        self.time_emb = TimeStepEmbedding(emb_dim, fourier=False)

        self.y_cond = False
        if num_y_classes is not None:
            self.y_emb = nn.Embedding(num_y_classes, emb_dim)
            self.y_cond = True

        self.use_fourier_features = use_fourier_features
        if self.use_fourier_features:
            self.cont_emb = ContEmbedder(cat_emb_dim, num_cont_features)

        in_dims = [emb_dim] + (n_layers - 1) * [n_units]
        out_dims = n_layers * [n_units]
        layers = nn.ModuleList()
        for i in range(len(in_dims)):
            layers.append(nn.Linear(in_dims[i], out_dims[i]))
            layers.append(nn.ReLU() if act == "relu" else nn.SiLU())
        self.fc = nn.Sequential(*layers)

        if self.use_fourier_features:
            dim_in = (num_cont_features + num_cat_features) * cat_emb_dim
        else:
            dim_in = num_cont_features + num_cat_features * cat_emb_dim
        self.proj = nn.Linear(dim_in, emb_dim)

        # init final layer
        if proportions is None:
            bias_init = None
        else:
            # Append a tiny UNK probability to each feature so the bias layout
            # matches FinalLayer_org's (categories + 1) logits per feature.
            proportions_with_unk = [
                torch.cat([p, torch.tensor([1e-8], dtype=p.dtype, device=p.device)])
                for p in proportions
            ]
            cat_bias_init = torch.cat(proportions_with_unk).log()
            # Create the continuous zeros alongside the categorical biases so
            # the concatenation also works for non-CPU proportions.
            cont_bias_init = torch.zeros(
                (num_cont_features,), dtype=cat_bias_init.dtype, device=cat_bias_init.device
            )
            bias_init = torch.cat((cont_bias_init, cat_bias_init))

        self.final_layer = FinalLayer_org(
            out_dims[-1], categories, num_cont_features, bias_init=bias_init
        )

    @staticmethod
    def _expand_cont_mask(m_cont, width):
        """Repeat each column of a per-feature continuous mask to cover ``width`` columns.

        With Fourier features every continuous feature occupies
        ``width // n_features`` consecutive columns (flattened as
        "B F D -> B (F D)"), so each mask entry is repeated that many times.
        Masks that already have ``width`` columns, a single (broadcastable)
        column, or a non-divisible width are returned unchanged.
        """
        n_cols = m_cont.shape[-1]
        if n_cols in (0, 1, width) or width % n_cols != 0:
            return m_cont
        return m_cont.repeat_interleave(width // n_cols, dim=-1)

    def forward(self, x_cat_emb_t, x_cont_t, time, c, m_cat = None, m_cont = None, sample = False, obs_flag = False):
        """Return ``(cat_logits, cont_logits)`` from ``FinalLayer_org``.

        Args:
            x_cat_emb_t: categorical embeddings, (batch, n_cat, cat_emb_dim).
            x_cont_t: continuous features, one column per feature.
            time: diffusion timesteps.
            c: class labels (used only with class conditioning).
            m_cat, m_cont: optional per-feature multiplicative masks applied
                when ``obs_flag`` is truthy. ``m_cont`` may be per feature; with
                Fourier features it is repeated over each feature's embedding.
            sample: unused; kept for call-site compatibility.
            obs_flag: whether to apply the masks.
        """
        cond_emb = self.time_emb(time)

        # construct conditioning embedding if using y_cond
        if self.y_cond:
            cond_emb = cond_emb + F.silu(self.y_emb(c))

        if self.use_fourier_features:
            x_cont_t = self.cont_emb(x_cont_t)
            x_cont_t = rearrange(x_cont_t, "B F D -> B (F D)")

        if obs_flag:
            if m_cat is not None:
                x_cat_emb_t = m_cat.unsqueeze(2) * x_cat_emb_t
            if m_cont is not None:
                # A per-feature mask cannot broadcast against the flattened
                # (F * D) Fourier embedding, so it is expanded first.
                x_cont_t = self._expand_cont_mask(m_cont, x_cont_t.shape[-1]) * x_cont_t

        x = torch.concat((rearrange(x_cat_emb_t, "B F D -> B (F D)"), x_cont_t), dim=-1)

        x = self.proj(x) + cond_emb
        x = self.fc(x)

        return self.final_layer(x)


class TabDDPM_MLP_Cont(nn.Module):
    """
    TabDDPM-like architecture for continuous features only.
    This is used for TabSyn as a score model for learned latents.

    The input is projected to ``emb_dim`` and summed with the time embedding.
    The sum then goes through ``n_layers`` linear+activation layers and a final
    linear layer back to ``num_features``.
    """

    def __init__(self, num_features, emb_dim, n_layers, n_units, act="relu"):
        super().__init__()

        self.time_emb = TimeStepEmbedding(emb_dim, fourier=False)

        activation = nn.ReLU if act == "relu" else nn.SiLU
        in_dims = [emb_dim] + [n_units] * (n_layers - 1)
        out_dims = [n_units] * n_layers
        modules = []
        for d_in, d_out in zip(in_dims, out_dims):
            modules.extend((nn.Linear(d_in, d_out), activation()))
        # output projection back to the input width
        modules.append(nn.Linear(out_dims[-1], num_features))
        self.fc = nn.Sequential(*modules)
        self.proj = nn.Linear(num_features, emb_dim)

    def forward(self, x, time):
        """Return an output shaped like ``x`` for inputs ``x`` at diffusion ``time``."""
        t_emb = self.time_emb(time)
        return self.fc(self.proj(x) + t_emb)
