import math

import numpy as np
import torch
import torch.nn as nn
from torchdiffeq import odeint
import torch.nn.functional as F
import math
from typing import Tuple, Union, Optional, List


def modulate(x, shift, scale):
    """Apply an adaLN-style affine modulation: ``x * (1 + scale) + shift``.

    ``scale`` is expressed as an offset from 1, so ``scale == 0`` together with
    ``shift == 0`` leaves ``x`` unchanged.
    """
    scaled = x * (1 + scale)
    return scaled + shift


class TimestepEmbedder(nn.Module):
    """
    Maps scalar timesteps to ``hidden_size``-dimensional vectors.

    A fixed sinusoidal encoding of width ``frequency_embedding_size`` is
    computed first and then passed through a two-layer SiLU MLP.
    """

    def __init__(self, hidden_size, frequency_embedding_size=256):
        super().__init__()
        layers = [
            nn.Linear(frequency_embedding_size, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True),
        ]
        self.mlp = nn.Sequential(*layers)
        self.frequency_embedding_size = frequency_embedding_size

    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):
        """
        Build sinusoidal timestep embeddings (cosines first, then sines).

        :param t: a 1-D Tensor of N indices, one per batch element.
                          These may be fractional.
        :param dim: the dimension of the output.
        :param max_period: controls the minimum frequency of the embeddings.
        :return: an (N, D) Tensor of positional embeddings.
        """
        # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
        half = dim // 2
        # Geometric frequency ladder running from 1 down towards 1 / max_period.
        steps = torch.arange(start=0, end=half, dtype=torch.float32)
        freqs = torch.exp(-math.log(max_period) * steps / half).to(device=t.device)
        if t.dim() == 1:
            # (N,) -> (N, 1) so it broadcasts against the (1, half) frequencies.
            t = t.unsqueeze(-1)
        angles = t.float() * freqs[None]
        embedding = torch.cat([angles.cos(), angles.sin()], dim=-1)
        if dim % 2 == 0:
            return embedding
        # Odd widths: append one zero column to reach exactly ``dim`` features.
        padding = torch.zeros_like(embedding[:, :1])
        return torch.cat([embedding, padding], dim=-1)

    def forward(self, t):
        """Return the MLP-projected sinusoidal embedding of ``t``."""
        return self.mlp(self.timestep_embedding(t, self.frequency_embedding_size))


class ODEfunc(nn.Module):
    """
    Conditioned vector field ``f(t, x, y)`` for use as the right-hand side of an ODE.

    The solver time ``t`` is embedded, concatenated after the conditioning
    vector ``y`` and turned into adaLN shift / scale / gate terms that modulate
    a LayerNorm + MLP applied to ``x``. The gated MLP output is returned as is
    (no residual connection is added here).

    :param channels: feature width of the state ``x``.
    :param time_channels: width of the embedding of the solver time.
    :param condition_channels: width of the conditioning vector ``y``.
    """

    def __init__(
            self,
            channels,
            time_channels,
            condition_channels
    ):
        super().__init__()
        self.channels = channels
        self.time_embed = TimestepEmbedder(time_channels)
        self.in_ln = nn.LayerNorm(channels, eps=1e-6)
        self.mlp = nn.Sequential(
            nn.Linear(channels, channels, bias=True),
            nn.SiLU(),
            nn.Linear(channels, channels, bias=True)
        )

        modulation_in = time_channels + condition_channels
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(modulation_in, 3 * channels, bias=True)
        )
        # Number of function evaluations; incremented on every forward call.
        self.nfe = 0

    def forward(self, t, x, y):
        """Evaluate the vector field at time ``t`` and state ``x`` given ``y``."""
        self.nfe += 1
        if t.dim() == 0:
            # Promote a 0-d time to shape (1,) before embedding it.
            t = t[None]
        t_emb = self.time_embed(t)
        n_rows = len(y)
        if len(t_emb) != n_rows:
            # Repeat the time embedding so it has one row per row of ``y``.
            t_emb = t_emb.tile([n_rows, 1])
        cond = torch.cat([y, t_emb], dim=-1)
        shift, scale, gate = self.adaLN_modulation(cond).chunk(3, dim=-1)
        hidden = self.mlp(modulate(self.in_ln(x), shift, scale))
        return gate * hidden


class ODEBlock(nn.Module):
    """
    Integrates an :class:`ODEfunc` over the time grid ``[0, 1]`` with ``odeint``.

    ``forward(x, y)`` uses ``x`` as the initial state and ``y`` as a fixed
    condition, returning the solution at the second grid point.
    """

    def __init__(self, channels, time_channels, condition_channels, step_size=1, method='heun3'):
        super().__init__()
        self.odefunc = ODEfunc(channels, time_channels, condition_channels)
        # NOTE(review): ``step_size`` is stored but never handed to ``odeint``
        # inside this class.
        self.step_size = step_size
        self.method = method
        time_grid = torch.Tensor([0, 1]).float()
        self.register_buffer('integration_time', time_grid)

    def give_condition(self, y):
        """Return a ``(t, x)`` callable with the condition ``y`` bound in."""
        return lambda t, x: self.odefunc(t, x, y)

    def forward(self, x, y):
        """Solve the ODE starting at ``x`` and return the state at the end time."""
        rhs = self.give_condition(y)
        trajectory = odeint(rhs, x, self.integration_time, method=self.method)
        return trajectory[1]

    @property
    def nfe(self):
        """Function-evaluation counter of the wrapped ``ODEfunc``."""
        return self.odefunc.nfe

    @nfe.setter
    def nfe(self, value):
        self.odefunc.nfe = value


class ResBlock(nn.Module):
    """
    Residual MLP block with adaLN conditioning.

    The condition ``y`` yields shift / scale / gate vectors; the block returns
    ``x + gate * MLP(modulate(LayerNorm(x), shift, scale))``.
    """

    def __init__(
            self,
            channels,
            condition_channels
    ):
        super().__init__()
        self.channels = channels

        self.in_ln = nn.LayerNorm(channels, eps=1e-6)
        self.mlp = nn.Sequential(
            nn.Linear(channels, channels, bias=True),
            nn.SiLU(),
            nn.Linear(channels, channels, bias=True),
        )

        # One linear layer emits shift, scale and gate at once (3 * channels).
        to_modulation = nn.Linear(condition_channels, 3 * channels, bias=True)
        self.adaLN_modulation = nn.Sequential(nn.SiLU(), to_modulation)

    def forward(self, x, y):
        """Return ``x`` plus the gated, condition-modulated MLP update."""
        shift, scale, gate = self.adaLN_modulation(y).chunk(3, dim=-1)
        update = self.mlp(modulate(self.in_ln(x), shift, scale))
        return x + gate * update


class FinalLayer(nn.Module):
    """
    Output head: parameter-free LayerNorm, adaLN shift/scale from the condition,
    then a linear projection to ``out_channels``.
    """

    def __init__(self, model_channels, condition_channels, out_channels):
        super().__init__()
        self.norm_final = nn.LayerNorm(model_channels, elementwise_affine=False, eps=1e-6)
        self.linear = nn.Linear(model_channels, out_channels, bias=True)
        # Shift and scale only (no gate), hence 2 * model_channels outputs.
        to_modulation = nn.Linear(condition_channels, 2 * model_channels, bias=True)
        self.adaLN_modulation = nn.Sequential(nn.SiLU(), to_modulation)

    def forward(self, x, c):
        """Normalize ``x``, modulate it with condition ``c`` and project it."""
        shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
        return self.linear(modulate(self.norm_final(x), shift, scale))


class ConditionAdaLNODE(nn.Module):
    """
    adaLN-conditioned network whose trunk is a single :class:`ODEBlock`.

    The input is projected to ``ode_channels``, evolved by the ODE block under
    the concatenated (time embedding, condition embedding) vector, and mapped
    back to ``in_channels`` by a :class:`FinalLayer`. ``forward`` optionally
    applies classifier-free guidance.
    """

    def __init__(
            self,
            in_channels,
            ode_channels,
            time_channels,
            ode_time_channels,
            condition_in_channels,
            condition_channels,
    ):
        super().__init__()
        self.time_embed = TimestepEmbedder(time_channels)
        # No condition projection is created when condition_channels <= 0.
        self.cond_embed = (
            nn.Linear(condition_in_channels, condition_channels)
            if condition_channels > 0 else None
        )

        modulation_channels = time_channels + condition_channels
        self.input_proj = nn.Linear(in_channels, ode_channels)
        self.res_blocks = nn.ModuleList([ODEBlock(ode_channels, ode_time_channels, modulation_channels)])
        self.final_layer = FinalLayer(ode_channels, modulation_channels, in_channels)
        self.initialize_weights()

    def initialize_weights(self):
        """Xavier-init every Linear, then apply the special-case initializations."""

        def _init_linear(module):
            if not isinstance(module, nn.Linear):
                return
            torch.nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                nn.init.constant_(module.bias, 0)

        self.apply(_init_linear)

        # Timestep embedding MLP: small normal weights.
        for idx in (0, 2):
            nn.init.normal_(self.time_embed.mlp[idx].weight, std=0.02)

        # Zero adaLN modulation => shift, scale and gate are 0, so the ODE
        # vector field is identically zero at initialization.
        for block in self.res_blocks:
            modulation = block.odefunc.adaLN_modulation[-1]
            nn.init.zeros_(modulation.weight)
            nn.init.zeros_(modulation.bias)

        # Zero output layers => the network outputs zeros at initialization.
        final_modulation = self.final_layer.adaLN_modulation[-1]
        nn.init.zeros_(final_modulation.weight)
        nn.init.zeros_(final_modulation.bias)
        nn.init.zeros_(self.final_layer.linear.weight)
        nn.init.zeros_(self.final_layer.linear.bias)

    def _forward(self, x, t, condition=None):
        """Single conditional pass; resets the NFE counters before running."""
        self.nfe = 0
        h = self.input_proj(x)
        t_emb = self.time_embed(t)
        if self.cond_embed is None:
            y = t_emb
        else:
            c = self.cond_embed(condition)
            if len(t_emb) != len(c):
                # Repeat the time embedding along dim 0 to one row per condition row.
                repeats = [len(c)] + [1] * (t_emb.dim() - 1)
                t_emb = t_emb.tile(repeats)
            y = torch.cat([t_emb, c], dim=-1)

        for block in self.res_blocks:
            h = block(h, y)

        return self.final_layer(h, y)

    def forward(self, x, t,
                condition=None,
                condition_u=None,
                cfg=1.,
                **kwargs):
        """
        Run the model, optionally with classifier-free guidance.

        With ``cfg == 1`` only the conditional pass is evaluated. Otherwise the
        result is ``x_c + (x_c - x_uc) * (cfg - 1)`` where ``x_uc`` is the pass
        under ``condition_u``. Extra keyword arguments are accepted and ignored.
        """
        guided = self._forward(x, t, condition)
        if cfg == 1:
            return guided
        unguided = self._forward(x, t, condition_u)
        return guided + (guided - unguided) * (cfg - 1.)

    @property
    def nfe(self):
        """Mean function-evaluation count over the ODE blocks."""
        counts = [block.nfe for block in self.res_blocks]
        return np.mean(counts)

    @nfe.setter
    def nfe(self, value):
        for block in self.res_blocks:
            block.nfe = value


class ConditionAdaLN(nn.Module):
    """
    adaLN-conditioned MLP: input projection, three :class:`ResBlock` layers and
    a :class:`FinalLayer` mapping back to ``in_channels``.

    Every block is modulated by the concatenated (time embedding, condition
    embedding) vector. ``forward`` optionally applies classifier-free guidance.
    """

    def __init__(
            self,
            in_channels,
            ode_channels,
            time_channels,
            condition_in_channels,
            condition_channels
    ):
        super().__init__()
        self.time_embed = TimestepEmbedder(time_channels)
        # No condition projection is created when condition_channels <= 0.
        self.cond_embed = (
            nn.Linear(condition_in_channels, condition_channels)
            if condition_channels > 0 else None
        )

        modulation_channels = time_channels + condition_channels
        self.input_proj = nn.Linear(in_channels, ode_channels)
        self.res_blocks = nn.ModuleList([ResBlock(ode_channels, modulation_channels) for _ in range(3)])
        self.final_layer = FinalLayer(ode_channels, modulation_channels, in_channels)
        self.initialize_weights()

    def initialize_weights(self):
        """Xavier-init every Linear, then apply the special-case initializations."""

        def _init_linear(module):
            if not isinstance(module, nn.Linear):
                return
            torch.nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                nn.init.constant_(module.bias, 0)

        self.apply(_init_linear)

        # Timestep embedding MLP: small normal weights.
        for idx in (0, 2):
            nn.init.normal_(self.time_embed.mlp[idx].weight, std=0.02)

        # Zero adaLN modulation => gate is 0, so each ResBlock starts as identity.
        for block in self.res_blocks:
            modulation = block.adaLN_modulation[-1]
            nn.init.zeros_(modulation.weight)
            nn.init.zeros_(modulation.bias)

        # Zero output layers => the network outputs zeros at initialization.
        final_modulation = self.final_layer.adaLN_modulation[-1]
        nn.init.zeros_(final_modulation.weight)
        nn.init.zeros_(final_modulation.bias)
        nn.init.zeros_(self.final_layer.linear.weight)
        nn.init.zeros_(self.final_layer.linear.bias)

    def _forward(self, x, t, condition=None):
        """Single conditional pass through the backbone."""
        # Plain attribute reset on every pass; nothing else in this class writes it.
        self.nfe = 0
        h = self.input_proj(x)
        t_emb = self.time_embed(t)
        if self.cond_embed is None:
            y = t_emb
        else:
            c = self.cond_embed(condition)
            if len(t_emb) != len(c):
                # Repeat the time embedding along dim 0 to one row per condition row.
                repeats = [len(c)] + [1] * (t_emb.dim() - 1)
                t_emb = t_emb.tile(repeats)
            y = torch.cat([t_emb, c], dim=-1)

        for block in self.res_blocks:
            h = block(h, y)

        return self.final_layer(h, y)

    def forward(self, x, t,
                condition=None,
                condition_u=None,
                cfg=1.):
        """
        Run the model, optionally with classifier-free guidance.

        With ``cfg == 1`` only the conditional pass is evaluated. Otherwise the
        result is ``x_c + (x_c - x_uc) * (cfg - 1)`` where ``x_uc`` is the pass
        under ``condition_u``.
        """
        guided = self._forward(x, t, condition)
        if cfg == 1:
            return guided
        unguided = self._forward(x, t, condition_u)
        return guided + (guided - unguided) * (cfg - 1.)


class ConditionAdaLNv2(nn.Module):
    """
    adaLN-conditioned MLP backbone with column-wise output heads.

    The backbone (input projection, three :class:`ResBlock` layers,
    :class:`FinalLayer`) produces ``hidden_per_col`` features for each of the
    ``numerical_dim + len(n_categories_per_columns)`` columns. Numerical columns
    are reduced to one value each by a :class:`PositionalVectorEncoding`; each
    categorical column gets its own ``Linear`` emitting ``n_cls`` logits. The
    outputs are concatenated as ``[numerical values | categorical logits]``, so
    the output width is ``numerical_dim + sum(n_categories_per_columns)``
    (presumably equal to ``in_channels``; this is not checked here).
    """

    def __init__(
            self,
            in_channels: int,
            ode_channels: int,
            time_channels: int,
            condition_in_channels: int,
            condition_channels: int,
            *,
            numerical_dim: int,
            n_categories_per_columns: List[int],
            hidden_per_col: int = 4
    ):
        super().__init__()
        # schema
        self.numerical_dim = int(numerical_dim)
        self.n_categories_per_columns = list(n_categories_per_columns)
        self.num_categorical = len(self.n_categories_per_columns)
        self.n_columns = self.numerical_dim + self.num_categorical
        self.hidden_per_col = int(hidden_per_col)

        # time / condition embeddings
        self.time_embed = TimestepEmbedder(time_channels)
        self.cond_embed = nn.Linear(condition_in_channels, condition_channels) if condition_channels > 0 else None

        # backbone
        self.input_proj = nn.Linear(in_channels, ode_channels)
        self.res_blocks = nn.ModuleList([
            ResBlock(ode_channels, time_channels + condition_channels) for _ in range(3)
        ])
        # project to per-column hidden states
        self.final_layer = FinalLayer(ode_channels, time_channels + condition_channels,
                                      self.n_columns * self.hidden_per_col)

        # column-wise heads
        # continuous: PositionalVectorEncoding over columns (hidden_per_col -> 1)
        self.num_head = PositionalVectorEncoding(self.numerical_dim, self.hidden_per_col, 1)
        # categorical: per-column Linear(hidden_per_col -> n_cls)
        self.cat_heads = nn.ModuleList([
            nn.Linear(self.hidden_per_col, n_cls) for n_cls in self.n_categories_per_columns
        ])

        self.initialize_weights()

    def initialize_weights(self):
        """Xavier-init every Linear, then apply the special-case initializations."""

        def _basic_init(module):
            if isinstance(module, nn.Linear):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)

        self.apply(_basic_init)

        # Initialize timestep embedding MLP
        nn.init.normal_(self.time_embed.mlp[0].weight, std=0.02)
        nn.init.normal_(self.time_embed.mlp[2].weight, std=0.02)

        # Zero-out adaLN modulation layers so every ResBlock starts as identity
        for block in self.res_blocks:
            nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
            nn.init.constant_(block.adaLN_modulation[-1].bias, 0)

        # Zero-out output layers
        nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
        nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
        nn.init.constant_(self.final_layer.linear.weight, 0)
        nn.init.constant_(self.final_layer.linear.bias, 0)

        # Heads: the categorical Linear heads were covered by _basic_init;
        # num_head keeps the init from its own reset_parameters().

    def _forward(self, x: torch.Tensor, t: torch.Tensor, condition: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Single pass through backbone and column heads.

        :param x: input of shape (B, in_channels).
        :param t: timesteps accepted by :class:`TimestepEmbedder`.
        :param condition: (B, condition_in_channels) tensor, or ``None`` for the
            null condition (a zero condition embedding).
        :return: (B, numerical_dim + sum(n_categories_per_columns)) tensor.
        """
        # backbone forward
        x = self.input_proj(x)
        t = self.time_embed(t)
        if self.cond_embed is None:
            y = t
        else:
            if condition is None:
                # Null condition: the adaLN layers were built for
                # time_channels + condition_channels inputs, so the missing
                # condition is replaced by zeros instead of being dropped
                # (dropping it makes every adaLN Linear fail on shape).
                c = t.new_zeros((*t.shape[:-1], self.cond_embed.out_features))
            else:
                c = self.cond_embed(condition)
                if len(t) != len(c):
                    # Repeat the time embedding along dim 0 to one row per condition row.
                    t = t.tile([len(c)] + [1] * (t.dim() - 1))
            y = torch.cat([t, c], dim=-1)

        for block in self.res_blocks:
            x = block(x, y)

        # per-column hidden states
        h = self.final_layer(x, y)  # (B, n_columns * hidden_per_col)
        h = h.view(len(h), self.n_columns, self.hidden_per_col)

        outputs = []
        # numerical projection via PositionalVectorEncoding -> (B, n_num)
        if self.numerical_dim > 0:
            outputs.append(self.num_head(h[:, :self.numerical_dim]).squeeze(-1))
        # categorical projection via per-column Linear -> (B, n_cls_i) each
        for i, head in enumerate(self.cat_heads):
            outputs.append(head(h[:, self.numerical_dim + i]))

        if not outputs:
            # No columns at all: keep returning an empty (B, 0) tensor.
            return h.new_zeros(len(h), 0)
        return torch.cat(outputs, dim=1)

    def forward(self, x: torch.Tensor, t: torch.Tensor,
                condition: Optional[torch.Tensor] = None,
                condition_u: Optional[torch.Tensor] = None,
                cfg: float = 1.) -> torch.Tensor:
        """
        Run the model, optionally with classifier-free guidance.

        With ``cfg == 1`` only the conditional pass is evaluated. Otherwise the
        result is ``x_c + (x_c - x_uc) * (cfg - 1)`` where ``x_uc`` is the pass
        under ``condition_u`` (``None`` selects the null condition).
        """
        x_c = self._forward(x, t, condition)
        if cfg == 1:
            return x_c
        x_uc = self._forward(x, t, condition_u)
        return x_c + (x_c - x_uc) * (cfg - 1.)


class ConditionAdaLNTransformer(nn.Module):
    """
    adaLN-conditioned MLP backbone whose condition is a Transformer encoding of
    a partially observed table row.

    Each column of the row becomes one token: numerical columns through a
    :class:`PositionalVectorEncoding`, categorical columns (given as one-hot
    slices) through per-column embeddings. A CLS token is prepended, the
    sequence is encoded by ``nn.TransformerEncoder`` and pooled to a single
    ``condition_channels``-wide vector, which is concatenated after the time
    embedding and fed to three :class:`ResBlock` layers and a
    :class:`FinalLayer`.

    Unobserved columns are either replaced by a shared learned mask token
    (``use_mask_token=True``) or excluded from attention via a key padding
    mask (``use_mask_token=False``).
    """

    def __init__(
            self,
            in_channels,
            d_model,
            time_channels,
            condition_channels,
            *,
            # schema for Transformer-based conditional encoding (required)
            numerical_dim: int,
            n_categories_per_columns: List[int],
            transformer_layers: int = 2,
            transformer_nheads: Optional[int] = None,
            # options
            use_mask_token: bool = True,
            cond_pool: str = 'mean'  # 'mean' | 'mean_observed' | 'mean_missing' | 'cls'
    ):
        super().__init__()
        self.time_embed = TimestepEmbedder(time_channels)

        # Build Transformer-based conditional encoder (schema is required)
        self.num_numerical = int(numerical_dim)
        self.n_categories_per_columns = list(n_categories_per_columns)
        if condition_channels <= 0:
            raise ValueError("condition_channels must be > 0 for Transformer-based encoder")
        self.cond_pool = cond_pool
        self.condition_channels = condition_channels

        d_token = condition_channels
        # CLS token
        self.cls_tkn = nn.Embedding(1, d_token)

        # Categorical per-column embeddings (vocab size used exactly; missing
        # values are handled via the mask token or the padding mask)
        self.categorical_embeddings = nn.ModuleList([
            nn.Embedding(n_vocab, d_token) for n_vocab in self.n_categories_per_columns
        ])
        # Numerical projection: one (1 -> d_token) affine map per numerical column
        self.numerical_embeddings = PositionalVectorEncoding(self.num_numerical, 1, d_token)
        # Start / end offsets of every categorical column inside the one-hot block
        starts = [0]
        for n in self.n_categories_per_columns[:-1]:
            starts.append(starts[-1] + n)
        self._cat_starts: List[int] = starts
        self._cat_ends: List[int] = [s + n for s, n in zip(self._cat_starts, self.n_categories_per_columns)]
        # Per-column rows [0, 1, ..., n-1] used to turn a one-hot slice into a class id
        for j, n in enumerate(self.n_categories_per_columns):
            self.register_buffer(f"_cat_id_row_{j}", torch.arange(n, dtype=torch.float32))

        # Transformer encoder
        if transformer_nheads is None:
            # Largest head count <= 8 that divides d_token
            transformer_nheads = max(1, min(8, d_token))
            while d_token % transformer_nheads != 0 and transformer_nheads > 1:
                transformer_nheads -= 1
        self.cond_encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_token, transformer_nheads, int(d_token * 4),
                activation=nn.GELU(), batch_first=True, norm_first=True, dropout=0.0
            ),
            num_layers=transformer_layers
        )

        # Options
        self.use_mask_token = use_mask_token
        if self.use_mask_token:
            # Shared mask token for missing numerics and categoricals
            self.mask_token = nn.Parameter(torch.zeros(1, 1, d_token))
        else:
            self.register_parameter('mask_token', None)
        # Precompute shapes
        self.num_categorical = len(self.n_categories_per_columns)
        self.n_columns = self.num_numerical + self.num_categorical
        self.categorical_dim_onehot = sum(self.n_categories_per_columns)
        # values dim = numerical + one-hot categorical
        self.condition_values_dim = self.num_numerical + self.categorical_dim_onehot

        self.input_proj = nn.Linear(in_channels, d_model)
        self.res_blocks = nn.ModuleList([ResBlock(d_model, time_channels + condition_channels) for _ in range(3)])
        self.final_layer = FinalLayer(d_model, time_channels + condition_channels, in_channels)
        self.initialize_weights()

    def initialize_weights(self):
        """Xavier-init every Linear, then apply the special-case initializations."""

        def _basic_init(module):
            if isinstance(module, nn.Linear):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)

        self.apply(_basic_init)

        # Initialize timestep embedding MLP
        nn.init.normal_(self.time_embed.mlp[0].weight, std=0.02)
        nn.init.normal_(self.time_embed.mlp[2].weight, std=0.02)

        # Init Transformer condition modules
        nn.init.kaiming_uniform_(self.cls_tkn.weight, a=math.sqrt(5))
        for embedding in self.categorical_embeddings:
            nn.init.kaiming_uniform_(embedding.weight, a=math.sqrt(5))
        if getattr(self, 'mask_token', None) is not None:
            nn.init.normal_(self.mask_token, mean=0.0, std=0.02)

        # Zero-out adaLN modulation layers so every ResBlock starts as identity
        for block in self.res_blocks:
            nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
            nn.init.constant_(block.adaLN_modulation[-1].bias, 0)

        # Zero-out output layers
        nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
        nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
        nn.init.constant_(self.final_layer.linear.weight, 0)
        nn.init.constant_(self.final_layer.linear.bias, 0)

    def _build_condition_tokens(self, values: torch.Tensor, observed_column: torch.Tensor) -> Tuple[
        torch.Tensor, Optional[torch.Tensor], torch.Tensor]:
        """
        Turn a row of condition values into a token sequence.

        ``values`` is laid out as ``[numerical values | categorical one-hots]``
        and ``observed_column`` holds one entry per column (> 0.5 = observed).

        Returns
        -------
        tokens : (B, 1+n_columns, d)
        padding_mask : (B, 1+n_columns) bool or None
        obs_mask : (B, n_columns) float in {0,1}
        """
        B = len(values)
        device = values.device

        obs_mask = observed_column
        observed = obs_mask > 0.5  # (B, n_columns) bool

        # The mask token is broadcast directly inside torch.where. Earlier
        # revisions cached ``mask_token.expand(...)`` per batch size; those
        # cached views went stale after ``.to(device)``, lost their gradient
        # path when first created under ``torch.no_grad()``, and (being
        # non-leaf tensors) made ``copy.deepcopy`` of the module fail.
        use_mask = self.use_mask_token and self.mask_token is not None

        tokens_list = [self.cls_tkn(torch.zeros(B, 1, device=device, dtype=torch.int64))]

        # Numerical tokens
        offs = 0
        if self.num_numerical > 0:
            num_vals = values[:, offs: offs + self.num_numerical]
            offs += self.num_numerical
            num_tokens = self.numerical_embeddings(num_vals)  # (B, n_num, d)
            if use_mask:
                obs_num = observed[:, : self.num_numerical]
                # mask_token is (1, 1, d) and broadcasts over batch and columns
                num_tokens = torch.where(obs_num.unsqueeze(-1), num_tokens, self.mask_token)
            tokens_list.append(num_tokens)

        # Categorical tokens
        if self.num_categorical > 0:
            cat_onehot = values[:, offs: offs + self.categorical_dim_onehot]
            mask_tok_row = self.mask_token[0] if use_mask else None  # (1, d)
            embeds = []
            for j, (s, e) in enumerate(zip(self._cat_starts, self._cat_ends)):
                sl = cat_onehot[:, s:e]
                id_row = getattr(self, f"_cat_id_row_{j}").to(sl)
                # Weighted sum of class indices recovers the id of a one-hot row
                # (an all-zero row maps to id 0).
                ids = (sl * id_row).sum(dim=-1).round().long()
                emb_j = self.categorical_embeddings[j](ids)
                if mask_tok_row is not None:
                    observed_j = observed[:, self.num_numerical + j]
                    emb_j = torch.where(observed_j.unsqueeze(-1), emb_j, mask_tok_row)
                embeds.append(emb_j)
            tokens_list.append(torch.stack(embeds, dim=1))  # (B, n_cat, d)

        tokens = torch.cat(tokens_list, dim=1)  # (B, 1+n_columns, d)

        # Without mask tokens, unobserved positions are hidden from attention;
        # the CLS position is always kept.
        if self.use_mask_token:
            padding_mask = None
        else:
            cls_keep = torch.zeros(B, 1, device=device, dtype=torch.bool)
            padding_mask = torch.cat([cls_keep, ~observed], dim=1)

        return tokens, padding_mask, obs_mask

    def _encode_condition(self, values: torch.Tensor, observed_column: torch.Tensor) -> torch.Tensor:
        """
        Encode the condition row and pool it to one vector per batch element.

        ``cond_pool`` selects the pooling: ``'cls'`` takes the CLS output,
        ``'mean_observed'`` / ``'mean_missing'`` average over observed /
        unobserved columns, and any other value averages over all columns.
        """
        tokens, padding_mask, obs_mask = self._build_condition_tokens(values, observed_column)
        encoded = self.cond_encoder(tokens, src_key_padding_mask=padding_mask)

        # Pooling
        if self.cond_pool == 'cls':
            return encoded[:, 0]

        seq = encoded[:, 1:]
        if self.cond_pool == 'mean_observed':
            sel = (obs_mask > 0.5)
        elif self.cond_pool == 'mean_missing':
            sel = ~(obs_mask > 0.5)
        else:
            sel = torch.ones_like(obs_mask, dtype=torch.bool)
        # clamp avoids dividing by zero when no column is selected for a row
        denom = sel.sum(dim=1).clamp(min=1).unsqueeze(-1)
        return (seq * sel.unsqueeze(-1)).sum(dim=1) / denom

    def _forward(self, x, t, condition=None, observed_column=None):
        """
        Single pass through the backbone.

        :param x: input of shape (B, in_channels).
        :param t: timesteps accepted by :class:`TimestepEmbedder`.
        :param condition: (B, condition_values_dim) condition values, or ``None``.
        :param observed_column: (B, n_columns) observation mask, or ``None``.
        :return: (B, in_channels) tensor.

        If either ``condition`` or ``observed_column`` is ``None`` the null
        condition (a zero condition vector) is used.
        """
        # Plain attribute reset on every pass; nothing else in this class writes it.
        self.nfe = 0
        x = self.input_proj(x)
        t = self.time_embed(t)
        if condition is None or observed_column is None:
            # The adaLN layers were built for time_channels + condition_channels
            # inputs, so the missing condition is replaced by zeros instead of
            # being dropped (dropping it makes every adaLN Linear fail on shape).
            c = t.new_zeros((*t.shape[:-1], self.condition_channels))
        else:
            c = self._encode_condition(condition, observed_column)
            if len(t) != len(c):
                # Repeat the time embedding along dim 0 to one row per condition row.
                t = t.tile([len(c)] + [1] * (t.dim() - 1))
        y = torch.cat([t, c], dim=-1)

        for block in self.res_blocks:
            x = block(x, y)

        x = self.final_layer(x, y)
        return x

    def forward(self, x, t,
                condition: Optional[torch.Tensor] = None,
                observed_column: Optional[torch.Tensor] = None,
                condition_u: Optional[torch.Tensor] = None,
                observed_column_u: Optional[torch.Tensor] = None,
                cfg: float = 1.):
        """
        Run the model, optionally with classifier-free guidance.

        With ``cfg == 1`` only the conditional pass is evaluated. Otherwise the
        result is ``x_c + (x_c - x_uc) * (cfg - 1)`` where ``x_uc`` is the pass
        under ``condition_u`` / ``observed_column_u``.
        """
        x_c = self._forward(x, t, condition, observed_column)
        if cfg == 1:
            return x_c
        x_uc = self._forward(x, t, condition_u, observed_column_u)
        return x_c + (x_c - x_uc) * (cfg - 1.)


class PositionalVectorEncoding(nn.Module):
    """
    Bank of independent per-position affine maps.

    Position ``n`` owns its own ``(in_features, out_features)`` weight matrix
    and ``(out_features,)`` bias, so an input of shape ``(B, N, in_features)``
    is mapped to ``(B, N, out_features)`` with no parameter sharing across
    positions. Inputs with fewer than ``n_inputs`` positions use the first
    ``N`` parameter slices.

    :param n_inputs: number of positions (parameter slices).
    :param in_features: input width per position.
    :param out_features: output width per position.
    :param weight: create the weight tensor; without it the input is only shifted by the bias.
    :param bias: create the bias tensor.
    :param device: forwarded to the parameter factory.
    :param dtype: forwarded to the parameter factory.
    """

    def __init__(self, n_inputs, in_features, out_features, weight=True, bias=True, device=None, dtype=None):
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        self.n_inputs = n_inputs
        self.in_features = in_features
        self.out_features = out_features
        if weight:
            self.weight = nn.Parameter(torch.empty((n_inputs, in_features, out_features), **factory_kwargs))
        else:
            self.weight = None
        if bias:
            self.bias = nn.Parameter(torch.empty((n_inputs, out_features), **factory_kwargs))
        else:
            self.bias = None
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Initialize weight and bias uniformly in about +-1/sqrt(in_features)."""
        # Setting a=sqrt(5) in kaiming_uniform is the same as initializing with
        # uniform(-1/sqrt(in_features), 1/sqrt(in_features)). For details, see
        # https://github.com/pytorch/pytorch/issues/57109
        fan_in = self.in_features
        if self.weight is not None:
            if fan_in > 0:
                gain = math.sqrt(1 / 3)
                std = gain / math.sqrt(fan_in)
                bound = math.sqrt(3.0) * std  # Calculate uniform bounds from standard deviation
            else:
                # Degenerate width: avoid dividing by zero (same guard as the bias).
                bound = 0
            nn.init.uniform_(self.weight, -bound, bound)
        if self.bias is not None:
            bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
            nn.init.uniform_(self.bias, -bound, bound)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """
        Apply the per-position maps.

        With a weight, ``input`` is reshaped to ``(B, -1, in_features)`` first,
        so a flat ``(B, N * in_features)`` tensor is accepted as well. Without
        weight and bias the input is returned unchanged.
        """
        if self.weight is not None:
            input = input.reshape(len(input), -1, self.in_features)
            out = torch.einsum('bnc, nco -> bno', input, self.weight[:input.shape[1]])
        else:
            out = input
        # The bias is optional in both modes; indexing it unconditionally
        # crashed for ``bias=False`` whenever a weight was present.
        if self.bias is not None:
            out = out + self.bias[None, :input.shape[1]]
        return out

    def extra_repr(self) -> str:
        return 'in_features={}, out_features={}, bias={}'.format(
            self.in_features, self.out_features, self.bias is not None
        )


class MLP(nn.Module):
    """
    Plain feed-forward network of ``num_layers`` Linear layers.

    Every layer except the last is followed by dropout and then SiLU; the last
    layer's output is optionally passed through a sigmoid.
    """

    def __init__(
            self,
            input_dim: int,
            hidden_dim: int,
            output_dim: int,
            num_layers: int,
            dropout: float = 0.,
            sigmoid_output: bool = False,
    ) -> None:
        super().__init__()
        self.num_layers = num_layers
        # Layer widths from input to output; consecutive pairs define the layers.
        widths = [input_dim] + [hidden_dim] * (num_layers - 1) + [output_dim]
        self.layers = nn.ModuleList(
            nn.Linear(fan_in, fan_out) for fan_in, fan_out in zip(widths[:-1], widths[1:])
        )
        self.dropout = nn.Dropout(dropout)
        self.sigmoid_output = sigmoid_output

    def forward(self, x):
        """Run ``x`` through all layers and return the (optionally squashed) output."""
        last = self.num_layers - 1
        for idx, layer in enumerate(self.layers):
            x = layer(x)
            if idx < last:
                x = F.silu(self.dropout(x))
        if not self.sigmoid_output:
            return x
        return F.sigmoid(x)


class PositionalMLP(nn.Module):
    """
    Feed-forward network built from :class:`PositionalVectorEncoding` layers,
    i.e. each of the ``n_inputs`` positions has its own layer parameters.

    Every layer except the last is followed by dropout and then SiLU; the last
    layer's output is optionally passed through a sigmoid.
    """

    def __init__(
            self,
            n_inputs: int,
            input_dim: int,
            hidden_dim: int,
            output_dim: int,
            num_layers: int,
            dropout: float = 0.,
            sigmoid_output: bool = False,
    ) -> None:
        super().__init__()
        self.num_layers = num_layers
        # Layer widths from input to output; consecutive pairs define the layers.
        widths = [input_dim] + [hidden_dim] * (num_layers - 1) + [output_dim]
        self.layers = nn.ModuleList(
            PositionalVectorEncoding(n_inputs, fan_in, fan_out)
            for fan_in, fan_out in zip(widths[:-1], widths[1:])
        )
        self.dropout = nn.Dropout(dropout)
        self.sigmoid_output = sigmoid_output

    def forward(self, x):
        """Run ``x`` through all layers and return the (optionally squashed) output."""
        last = self.num_layers - 1
        for idx, layer in enumerate(self.layers):
            x = layer(x)
            if idx < last:
                x = F.silu(self.dropout(x))
        if not self.sigmoid_output:
            return x
        return F.sigmoid(x)


class ConditionAdaLNv3(nn.Module):
    """AdaLN backbone whose condition is FiLM-modulated by the observation mask.

    - The condition value vector is projected to ``condition_channels``.
    - The ``observed_column`` mask (optionally concatenated with the time
      embedding) goes through a small MLP that produces ``(gamma, beta)``.
    - The condition is modulated as ``c = c_val * (1 + gamma) + beta``, then
      concatenated with the time embedding and passed to every ``ResBlock``
      and to the ``FinalLayer``.
    """

    def __init__(
            self,
            in_channels: int,
            ode_channels: int,
            time_channels: int,
            condition_val_in_channels: int,
            condition_channels: int,
            mask_in_channels: int,
            mask_mod_with_time: bool = True,
    ) -> None:
        """Create the embedders, the mask MLP and the backbone.

        :param in_channels: feature size of the input (and of the output).
        :param ode_channels: feature size used inside the backbone.
        :param time_channels: size of the timestep embedding.
        :param condition_val_in_channels: feature size of the raw condition values.
        :param condition_channels: size of the condition embedding; ``0`` (or
            less) disables the condition embedding layer.
        :param mask_in_channels: feature size of the observation mask.
        :param mask_mod_with_time: feed the time embedding to the mask MLP
            together with the mask.
        """
        super().__init__()
        self.time_embed = TimestepEmbedder(time_channels)
        self.mask_mod_with_time = mask_mod_with_time

        # Condition values -> embedding (absent when there are no condition channels).
        self.cond_val_embed = (
            nn.Linear(condition_val_in_channels, condition_channels)
            if condition_channels > 0 else None
        )

        # Mask (optionally with the time embedding appended) -> (gamma, beta).
        mask_mlp_in = mask_in_channels
        if self.mask_mod_with_time:
            mask_mlp_in += self.time_embed.mlp[-1].out_features
        mask_mlp_hidden = max(128, 2 * condition_channels)
        self.mask_mlp = nn.Sequential(
            nn.Linear(mask_mlp_in, mask_mlp_hidden, bias=True),
            nn.SiLU(),
            nn.Linear(mask_mlp_hidden, 2 * condition_channels, bias=True),
        )

        # Backbone: conditioned on the time embedding concatenated with the condition.
        conditioning_width = time_channels + condition_channels
        self.input_proj = nn.Linear(in_channels, ode_channels)
        self.res_blocks = nn.ModuleList(
            [ResBlock(ode_channels, conditioning_width) for _ in range(3)]
        )
        self.final_layer = FinalLayer(ode_channels, conditioning_width, in_channels)

        self.initialize_weights()

    def initialize_weights(self) -> None:
        """Re-initialise all weights.

        Every ``nn.Linear`` gets Xavier-uniform weights and zero bias; the
        timestep MLP weights are then drawn from ``N(0, 0.02)``; finally the
        last mask-MLP layer, the last ``adaLN_modulation`` layer of every block
        and of the final layer, and the final linear layer are set to zero.
        """
        def _init_linear(module):
            if not isinstance(module, nn.Linear):
                return
            torch.nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                nn.init.constant_(module.bias, 0)

        self.apply(_init_linear)

        # Timestep embedding MLP.
        for index in (0, 2):
            nn.init.normal_(self.time_embed.mlp[index].weight, std=0.02)

        zero_layers = []
        # Zero-init the last mask-MLP layer so there is no modulation at the start.
        mask_out = self.mask_mlp[-1]
        if isinstance(mask_out, nn.Linear):
            zero_layers.append(mask_out)
        # Zero-out the adaLN modulation layers and the output projection.
        zero_layers.extend(block.adaLN_modulation[-1] for block in self.res_blocks)
        zero_layers.append(self.final_layer.adaLN_modulation[-1])
        zero_layers.append(self.final_layer.linear)
        for layer in zero_layers:
            nn.init.constant_(layer.weight, 0)
            nn.init.constant_(layer.bias, 0)

    def _encode_condition(self, condition: Optional[torch.Tensor], observed_column: Optional[torch.Tensor], t_emb: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
        """Embed ``condition`` and modulate it with ``(gamma, beta)`` from the mask MLP.

        Returns ``None`` when the condition embedding is disabled or when
        ``condition`` or ``observed_column`` is ``None``.
        """
        inputs_missing = condition is None or observed_column is None
        if self.cond_val_embed is None or inputs_missing:
            return None

        c_val = self.cond_val_embed(condition)
        mask_in = observed_column
        if self.mask_mod_with_time and t_emb is not None:
            mask_in = torch.cat([observed_column, t_emb], dim=-1)
        gamma, beta = self.mask_mlp(mask_in).chunk(2, dim=-1)
        return modulate(c_val, beta, gamma)

    def _forward(self, x: torch.Tensor, t: torch.Tensor,
                 condition: Optional[torch.Tensor] = None,
                 observed_column: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Single pass through the backbone for one (condition, mask) pair."""
        # reset nfe
        self.nfe = 0

        hidden = self.input_proj(x)
        t_emb = self.time_embed(t)
        # Repeat the time embedding when its leading size differs from the input's.
        if len(t_emb) != len(hidden):
            t_emb = t_emb.tile([len(hidden), 1])

        c = self._encode_condition(condition, observed_column, t_emb)
        if c is None:
            y = t_emb
        else:
            if len(t_emb) != len(c):
                t_emb = t_emb.tile([len(c)] + [1] * (t_emb.dim() - 1))
            y = torch.cat([t_emb, c], dim=-1)

        for block in self.res_blocks:
            hidden = block(hidden, y)
        return self.final_layer(hidden, y)

    def forward(self, x: torch.Tensor, t: torch.Tensor,
                condition: Optional[torch.Tensor] = None,
                observed_column: Optional[torch.Tensor] = None,
                condition_u: Optional[torch.Tensor] = None,
                observed_column_u: Optional[torch.Tensor] = None,
                cfg: float = 1.) -> torch.Tensor:
        """Evaluate the model, optionally with guidance.

        With ``cfg == 1`` only the ``(condition, observed_column)`` pass is run.
        Otherwise a second pass with ``(condition_u, observed_column_u)`` is run
        and the result is ``x_c + (x_c - x_uc) * (cfg - 1)``.
        """
        x_c = self._forward(x, t, condition, observed_column)
        if cfg == 1:
            return x_c
        x_uc = self._forward(x, t, condition_u, observed_column_u)
        return x_c + (x_c - x_uc) * (cfg - 1.)


class ConditionAdaLNv4(nn.Module):
    """AdaLN backbone where the observation mask (+ time) modulates both the
    condition embedding and the time embedding.

    - The condition values are projected to ``condition_channels`` and then
      FiLM-modulated with ``(gamma_c, beta_c)``.
    - When ``mod_time`` is set, the time embedding is FiLM-modulated as well,
      with ``(gamma_t, beta_t)``, to adapt it to the mask / step.
    - The last mask-MLP layer is zero-initialised, so training starts from the
      unmodulated state.
    """

    def __init__(
            self,
            in_channels: int,
            ode_channels: int,
            time_channels: int,
            condition_val_in_channels: int,
            condition_channels: int,
            mask_in_channels: int,
            mask_mod_with_time: bool = True,
            mod_time: bool = True,
    ) -> None:
        """Create the embedders, the mask MLP and the backbone.

        :param in_channels: feature size of the input (and of the output).
        :param ode_channels: feature size used inside the backbone.
        :param time_channels: size of the timestep embedding.
        :param condition_val_in_channels: feature size of the raw condition values.
        :param condition_channels: size of the condition embedding; ``0`` (or
            less) disables the condition embedding layer.
        :param mask_in_channels: feature size of the observation mask.
        :param mask_mod_with_time: feed the time embedding to the mask MLP
            together with the mask.
        :param mod_time: let the mask MLP also emit ``(gamma_t, beta_t)`` for
            the time embedding.
        """
        super().__init__()
        self.time_embed = TimestepEmbedder(time_channels)
        self.mask_mod_with_time = mask_mod_with_time
        self.mod_time = mod_time

        # Condition values -> embedding (absent when there are no condition channels).
        self.cond_val_embed = (
            nn.Linear(condition_val_in_channels, condition_channels)
            if condition_channels > 0 else None
        )

        # Mask (+ time) -> (gamma_c, beta_c[, gamma_t, beta_t]).
        mask_mlp_in = mask_in_channels
        if self.mask_mod_with_time:
            mask_mlp_in += self.time_embed.mlp[-1].out_features
        film_width = 2 * condition_channels
        if mod_time:
            film_width += 2 * time_channels
        mask_mlp_hidden = max(128, film_width)
        self.mask_mlp = nn.Sequential(
            nn.Linear(mask_mlp_in, mask_mlp_hidden, bias=True),
            nn.SiLU(),
            nn.Linear(mask_mlp_hidden, film_width, bias=True),
        )

        # Backbone: conditioned on the time embedding concatenated with the condition.
        conditioning_width = time_channels + condition_channels
        self.input_proj = nn.Linear(in_channels, ode_channels)
        self.res_blocks = nn.ModuleList(
            [ResBlock(ode_channels, conditioning_width) for _ in range(3)]
        )
        self.final_layer = FinalLayer(ode_channels, conditioning_width, in_channels)

        self.initialize_weights()

    def initialize_weights(self) -> None:
        """Re-initialise all weights.

        Every ``nn.Linear`` gets Xavier-uniform weights and zero bias; the
        timestep MLP weights are then drawn from ``N(0, 0.02)``; finally the
        last mask-MLP layer, the last ``adaLN_modulation`` layer of every block
        and of the final layer, and the final linear layer are set to zero.
        """
        def _init_linear(module):
            if not isinstance(module, nn.Linear):
                return
            torch.nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                nn.init.constant_(module.bias, 0)

        self.apply(_init_linear)

        # Timestep embedding MLP.
        for index in (0, 2):
            nn.init.normal_(self.time_embed.mlp[index].weight, std=0.02)

        zero_layers = []
        # Zero-init the last mask-MLP layer so there is no modulation at the start.
        mask_out = self.mask_mlp[-1]
        if isinstance(mask_out, nn.Linear):
            zero_layers.append(mask_out)
        # Zero-out the adaLN modulation layers and the output projection.
        zero_layers.extend(block.adaLN_modulation[-1] for block in self.res_blocks)
        zero_layers.append(self.final_layer.adaLN_modulation[-1])
        zero_layers.append(self.final_layer.linear)
        for layer in zero_layers:
            nn.init.constant_(layer.weight, 0)
            nn.init.constant_(layer.bias, 0)

    def _encode(self, condition: Optional[torch.Tensor], observed_column: Optional[torch.Tensor], t_emb: Optional[torch.Tensor]):
        """Return the modulated condition and the time FiLM pair.

        The result is ``(c, (gamma_t, beta_t))``; the pair holds ``None`` twice
        when ``mod_time`` is off. ``(None, None)`` is returned when the
        condition embedding is disabled or when ``condition`` or
        ``observed_column`` is ``None``.
        """
        inputs_missing = condition is None or observed_column is None
        if self.cond_val_embed is None or inputs_missing:
            return None, None

        c_val = self.cond_val_embed(condition)
        mask_in = observed_column
        if self.mask_mod_with_time and t_emb is not None:
            # Repeat the time embedding when its leading size differs from the mask's.
            n_rows = len(observed_column)
            t_in = t_emb if len(t_emb) == n_rows else t_emb.tile([n_rows, 1])
            mask_in = torch.cat([observed_column, t_in], dim=-1)

        vec = self.mask_mlp(mask_in)
        if self.mod_time:
            c_dim, t_dim = c_val.size(-1), t_emb.size(-1)
            gc, bc, gt, bt = torch.split(vec, [c_dim, c_dim, t_dim, t_dim], dim=-1)
        else:
            gc, bc = torch.chunk(vec, 2, dim=-1)
            gt = bt = None
        return modulate(c_val, bc, gc), (gt, bt)

    def _forward(self, x: torch.Tensor, t: torch.Tensor,
                 condition: Optional[torch.Tensor] = None,
                 observed_column: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Single pass through the backbone for one (condition, mask) pair."""
        self.nfe = 0
        hidden = self.input_proj(x)
        t_emb = self.time_embed(t)

        c, time_film = self._encode(condition, observed_column, t_emb)
        use_time_film = self.mod_time and time_film is not None and time_film[0] is not None
        if use_time_film:
            gt, bt = time_film
            t_mod = modulate(t_emb, bt, gt)
        else:
            t_mod = t_emb

        if c is None:
            y = t_mod
        else:
            if len(t_mod) != len(c):
                t_mod = t_mod.tile([len(c)] + [1] * (t_mod.dim() - 1))
            y = torch.cat([t_mod, c], dim=-1)

        for block in self.res_blocks:
            hidden = block(hidden, y)
        return self.final_layer(hidden, y)

    def forward(self, x: torch.Tensor, t: torch.Tensor,
                condition: Optional[torch.Tensor] = None,
                observed_column: Optional[torch.Tensor] = None,
                condition_u: Optional[torch.Tensor] = None,
                observed_column_u: Optional[torch.Tensor] = None,
                cfg: float = 1.) -> torch.Tensor:
        """Evaluate the model, optionally with guidance.

        With ``cfg == 1`` only the ``(condition, observed_column)`` pass is run.
        Otherwise a second pass with ``(condition_u, observed_column_u)`` is run
        and the result is ``x_c + (x_c - x_uc) * (cfg - 1)``.
        """
        x_c = self._forward(x, t, condition, observed_column)
        if cfg == 1:
            return x_c
        x_uc = self._forward(x, t, condition_u, observed_column_u)
        return x_c + (x_c - x_uc) * (cfg - 1.)


class ImpuGenVAEModel(nn.Module):

    def __init__(self, d_model, d_transformer, nhead, feedforward_ratio, dropout, n_layers,
                 n_categories_per_columns, n_cat_columns, n_num_columns, **kwargs):
        """Build the embeddings, the two encoders, the decoder and the heads.

        :param d_model: latent feature size.
        :param d_transformer: feature size inside the transformers.
        :param nhead: number of attention heads.
        :param feedforward_ratio: feed-forward width as a multiple of ``d_transformer``.
        :param dropout: dropout probability of the transformer layers.
        :param n_layers: number of layers in each transformer stack.
        :param n_categories_per_columns: vocabulary size of each categorical column.
        :param n_cat_columns: number of categorical columns.
        :param n_num_columns: number of numerical columns.
        :param kwargs: accepted and ignored.
        """
        super().__init__()
        self.d_model = d_model
        self.d_transformer = d_transformer
        self.n_cat_columns = n_cat_columns
        self.n_num_columns = n_num_columns
        self.seq_len = n_cat_columns + n_num_columns

        def build_transformer():
            # All three stacks (mu encoder, logvar encoder, decoder) share this configuration.
            layer = nn.TransformerEncoderLayer(
                d_transformer, nhead, int(d_transformer * feedforward_ratio),
                activation=nn.GELU(), batch_first=True, norm_first=True,
                dropout=dropout,
            )
            return nn.TransformerEncoder(layer, num_layers=n_layers)

        # token embeddings
        self.cls_tkn = nn.Embedding(1, d_transformer)
        self.categorical_embeddings = nn.ModuleList(
            [nn.Embedding(n_vocab, d_transformer) for n_vocab in n_categories_per_columns]
        )
        self.numerical_embeddings = PositionalVectorEncoding(n_num_columns, 1, d_transformer)
        for table in (self.cls_tkn, *self.categorical_embeddings):
            nn.init.kaiming_uniform_(table.weight, a=math.sqrt(5))

        # encoders
        self.encoder_mu = build_transformer()
        self.encoder_logvar = build_transformer()
        needs_resize = d_transformer != d_model
        self.encoder_mu_head = nn.Linear(d_transformer, d_model) if needs_resize else nn.Identity()
        self.encoder_logvar_head = nn.Linear(d_transformer, d_model) if needs_resize else nn.Identity()

        # decoder
        self.decoder_proj = nn.Linear(d_model, d_transformer) if needs_resize else nn.Identity()
        self.decoder = build_transformer()

        # prediction heads
        head_hidden = d_transformer * feedforward_ratio
        self.categorical_heads = nn.ModuleList(
            [MLP(d_transformer, head_hidden, n_vocab, num_layers=1)
             for n_vocab in n_categories_per_columns]
        )
        self.numerical_heads = PositionalMLP(n_num_columns, d_transformer, head_hidden, 1,
                                             num_layers=1)

    @staticmethod
    def reparameterize(mu: torch.Tensor, logvar: torch.Tensor, noise=1., norm_noise=False) -> torch.Tensor:
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        if norm_noise:
            eps = torch.nn.functional.normalize(eps.view(len(eps), -1), dim=-1, eps=1e-5).view_as(eps)
        return mu + eps * std * noise

    def projection(self, x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]):
        if isinstance(x, tuple):
            x = torch.cat(x, dim=1)
        nan_mask = x.isnan()
        x = x.nan_to_num()
        cat = torch.cat([torch.zeros(len(x), 0, self.d_transformer, device=self.device, dtype=self.dtype)] +
                        [embedding(col) for embedding, col in zip(
                            self.categorical_embeddings, x[:, :self.n_cat_columns].to(torch.int32).split(1, dim=1))],
                        dim=1)
        x = torch.cat([cat, self.numerical_embeddings(x[:, self.n_cat_columns:])], dim=1)
        x = x * ~nan_mask.unsqueeze(-1) + torch.zeros_like(x) * nan_mask.unsqueeze(-1)  # zero padding
        cls_tkn = self.cls_tkn(torch.zeros(len(x), 1, device=self.device, dtype=torch.int32))
        x = torch.cat([cls_tkn, x], dim=1)
        bsz, sz, _ = x.shape
        mask = torch.zeros(bsz, sz, device=self.device, dtype=torch.bool)
        mask[:, 1:] = nan_mask
        return x, mask

    def head(self, x: torch.Tensor) -> Tuple[List[torch.Tensor], torch.Tensor]:
        cat = []
        for head, col in zip(self.categorical_heads, x[:, :self.n_cat_columns].nan_to_num().split(1, dim=1)):
            pred = head(col.squeeze(1))
            cat.append(pred)
        num = self.numerical_heads(x[:, self.n_cat_columns:]).squeeze(-1)
        return cat, num

    def forward(self, x: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None, *,
                z: Optional[torch.Tensor] = None, encode_only: bool = False, logvar_only: bool = False,
                mu_and_logvar=False,
                noise=1., norm_noise=False):
        assert x is not None or z is not None
        mask = None
        if x is not None:
            x, mask = self.projection(x)
            mu_z = self.encoder_mu_head(self.encoder_mu(x, src_key_padding_mask=mask))
            std_z = self.encoder_logvar_head(self.encoder_logvar(x, src_key_padding_mask=mask))
            z = self.reparameterize(mu_z, std_z, noise=noise, norm_noise=norm_noise)
            z = z[:, 1:]
            mask = mask[:, 1:]

        if mu_and_logvar:
            return torch.cat([mu_z[:, 1:], std_z[:, 1:]], dim=-1)
        if encode_only:
            return z
        if logvar_only:
            return std_z[:, 1:]

        h = self.decoder(self.decoder_proj(z), src_key_padding_mask=mask)
        cat, num = self.head(h)
        if x is not None:
            return cat, num, h, mu_z, std_z
        return cat, num
