import math
import torch
from torch import nn
from torch.nn import functional as F



class Swish(nn.Module):
    """Elementwise Swish activation: ``x * sigmoid(x)``."""

    def forward(self, x):
        """Return ``x`` scaled elementwise by its own sigmoid gate."""
        gate = torch.sigmoid(x)
        return gate * x


class TimeEmbedding(nn.Module):
    """Embed integer timestep indices into vectors of width ``dim``.

    A sinusoidal table with ``T`` rows and ``d_model`` columns initialises an
    ``nn.Embedding`` (kept trainable through ``freeze=False``). Looked-up rows
    are then passed through Linear -> Swish -> Linear.

    Args:
        T: Number of rows in the table (valid indices are ``0 .. T - 1``).
        d_model: Width of the sinusoidal table; must be even.
        dim: Output width of the two linear layers.
    """

    def __init__(self, T, d_model, dim):
        assert d_model % 2 == 0
        super().__init__()

        # One frequency per sin/cos pair: exp(-log(10000) * k / d_model), k = 0, 2, 4, ...
        exponents = torch.arange(0, d_model, step=2) / d_model * math.log(10000)
        freqs = torch.exp(-exponents)
        steps = torch.arange(T).float()
        angles = steps.unsqueeze(1) * freqs.unsqueeze(0)  # [T, d_model // 2]

        # Stacking on a new trailing axis and flattening it interleaves the
        # columns as sin, cos, sin, cos, ... for each frequency.
        sin_cos = torch.stack((torch.sin(angles), torch.cos(angles)), dim=-1)
        table = sin_cos.view(T, d_model)

        layers = [
            nn.Embedding.from_pretrained(table, freeze=False),
            nn.Linear(d_model, dim),
            Swish(),
            nn.Linear(dim, dim),
        ]
        self.timembedding = nn.Sequential(*layers)

    def forward(self, t):
        """Look up the timestep indices ``t`` and project them to width ``dim``."""
        return self.timembedding(t)


class ConditionalEmbedding(nn.Module):
    """Embed integer label indices into vectors of width ``dim``.

    The lookup table has ``num_labels + 1`` rows of width ``d_model``; rows are
    then passed through Linear -> Swish -> Linear.

    Args:
        num_labels: The table is allocated with ``num_labels + 1`` rows.
        d_model: Width of each row in the lookup table.
        dim: Output width of the two linear layers.
    """

    def __init__(self, num_labels, d_model, dim):
        super().__init__()
        # padding_idx is left unset (the padding_idx=0 variant is disabled),
        # so every row, including index 0, is an ordinary trainable row.
        lookup = nn.Embedding(num_embeddings=num_labels + 1, embedding_dim=d_model)
        hidden = nn.Linear(d_model, dim)
        activation = Swish()
        out = nn.Linear(dim, dim)
        self.condEmbedding = nn.Sequential(lookup, hidden, activation, out)

    def forward(self, t):
        """Look up the index tensor ``t`` and project it to width ``dim``."""
        return self.condEmbedding(t)



class FiLM_MLP_Block(nn.Module):
    """Pre-norm residual MLP block with FiLM (scale/shift) conditioning.

    Args:
        width: Feature width of the block's input and output.
        cond_dim: Feature width of the conditioning vector.
        dropout: Dropout probability applied to the residual branch.
    """

    def __init__(self, width: int, cond_dim: int, dropout: float = 0.0):
        super().__init__()
        self.lin1 = nn.Linear(width, width)
        self.lin2 = nn.Linear(width, width)
        self.norm = nn.LayerNorm(width)
        self.drop = nn.Dropout(dropout)
        # A single projection yields both FiLM terms; it is split in forward().
        self.film = nn.Linear(cond_dim, 2 * width)

    def forward(self, x, cond):
        """Return ``x`` plus the conditioned residual branch.

        Per the shapes annotated in this file, ``x`` is ``[B, W]`` and
        ``cond`` is ``[B, C]``.
        """
        hidden = F.silu(self.lin1(self.norm(x)))

        # First half of the projection is the scale, second half the shift.
        scale, shift = torch.chunk(self.film(cond), 2, dim=-1)
        hidden = hidden * (1 + scale) + shift

        residual = self.drop(F.silu(self.lin2(hidden)))
        return x + residual



class EpsNet_FiLMMLP(nn.Module):
    """MLP over flat vectors, FiLM-conditioned on a timestep and a label.

    Presumably the noise (epsilon) predictor of a diffusion model -- inferred
    from the class name only; confirm against the training code.

    Args:
        T: Number of timesteps given to the time embedding.
        num_labels: Label count given to the conditional embedding.
        ch: Base embedding width; the conditioning vector has width ``4 * ch``.
        width: Hidden width of the MLP trunk.
        n_blocks: Number of ``FiLM_MLP_Block`` layers.
        dropout: Dropout probability used inside every block.
        data_dim: Size of the last dimension of the input and output.
    """

    def __init__(self, T, num_labels, ch=128, width=512, n_blocks=8, dropout=0.0, data_dim=256):
        super().__init__()
        tdim = ch * 4
        self.data_dim = data_dim

        self.time_embedding = TimeEmbedding(T, ch, tdim)
        self.cond_embedding = ConditionalEmbedding(num_labels, ch, tdim)

        self.proj_in = nn.Linear(data_dim, width)
        self.blocks = nn.ModuleList()
        for _ in range(n_blocks):
            self.blocks.append(FiLM_MLP_Block(width, tdim, dropout))
        self.proj_out = nn.Linear(width, data_dim)

        # Zeroed weight and bias make the network output all zeros at initialisation.
        for param in self.proj_out.parameters():
            nn.init.zeros_(param)

    def forward(self, x, t, labels):
        """Run the conditioned MLP and return a ``[B, 1, data_dim]`` tensor.

        ``x`` is annotated in this file as ``[B, 1, 256]``; its last dimension
        must equal ``data_dim``. ``t`` and ``labels`` are index tensors for the
        time and label embeddings.
        """
        batch = x.size(0)
        D = x.size(-1)
        assert D == self.data_dim, f"expected last dim={self.data_dim}, got {D}"

        # Time and label embeddings share width tdim and are summed into one vector.
        t_vec = self.time_embedding(t)
        label_vec = self.cond_embedding(labels)
        cond = t_vec + label_vec

        flat = x.view(batch, -1)
        hidden = F.silu(self.proj_in(flat))
        for block in self.blocks:
            hidden = block(hidden, cond)

        out = self.proj_out(hidden)
        return out.view(batch, 1, self.data_dim)





