#---------------------------------------------------
# Variant 1: no timestep conditioning ("without t")
# import torch
# import torch.nn as nn
# import math


# class ResidualBlock(nn.Module):
#     def __init__(self, dim):
#         super().__init__()
#         self.fc1 = nn.Linear(dim, dim)
#         self.fc2 = nn.Linear(dim, dim)
#         self.act = nn.GELU()
#         self.norm1 = nn.LayerNorm(dim)
#         self.norm2 = nn.LayerNorm(dim)

#     def forward(self, x):
#         h = self.fc1(self.norm1(x))
#         h = self.act(h)
#         h = self.fc2(self.norm2(h))
#         return x + h  # Residual
        

# # ---------- ResMLP-Denoiser ----------
# class ResMLPDenoiser(nn.Module):
#     def __init__(self, device, config, data_dim, width=2048, depth=3, time_embed_dim=1024):
#         super().__init__()
#         self.input_proj = nn.Linear(data_dim, width)
#         self.blocks = nn.ModuleList([
#             ResidualBlock(width) for _ in range(depth)
#         ])
#         self.output_proj = nn.Linear(width, data_dim)

#     def forward(self, x):
#         h = self.input_proj(x)

#         for block in self.blocks:
#             h = block(h)

#         h = self.output_proj(h)

#         return h
    



#------------------------------------------------------------
## Variant 2 (active): timestep embedding added to the projected input ("t in front")


import torch
import torch.nn as nn
import math

# ---------- Time-Embedding ----------
class SinusoidalTimeEmbedding(nn.Module):
    """Parameter-free sinusoidal embedding of diffusion timesteps.

    With ``half_dim = dim // 2`` frequencies
    ``f_i = exp(-i * log(10000) / (half_dim - 1))`` the output is
    ``[sin(t * f_0), ..., sin(t * f_{h-1}), cos(t * f_0), ..., cos(t * f_{h-1})]``.
    If ``dim`` is odd, a zero column is appended so the output always has
    exactly ``dim`` features (matching any ``nn.Linear(dim, ...)`` after it).

    Args:
        dim: Number of output features.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, t):
        """Embed a batch of timesteps.

        Args:
            t: Tensor of integer or float timesteps, shape ``(batch,)``.
                A 0-dim tensor is treated as a batch of one.

        Returns:
            Tensor of shape ``(batch, dim)``.
        """
        if t.dim() == 0:
            t = t.reshape(1)
        half_dim = self.dim // 2
        # max(..., 1) avoids a ZeroDivisionError when dim is 2 or 3 (half_dim == 1).
        scale = math.log(10000) / max(half_dim - 1, 1)
        freqs = torch.exp(torch.arange(half_dim, device=t.device) * -scale)
        angles = t[..., None] * freqs
        emb = torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)
        if self.dim % 2 == 1:
            # sin/cos only fill 2 * (dim // 2) features; pad the leftover one.
            emb = torch.cat([emb, emb.new_zeros((*emb.shape[:-1], 1))], dim=-1)
        return emb


# ---------- Residual-MLP-Block ----------
class ResidualBlock(nn.Module):
    """Pre-norm two-layer MLP with an identity skip connection.

    Computes ``x + fc2(norm2(gelu(fc1(norm1(x)))))``. This variant takes no
    timestep embedding; conditioning happens outside the block.

    Args:
        dim: Feature size of the input and output (last dimension).
    """

    def __init__(self, dim):
        super().__init__()
        # Submodules are created in this exact order so parameter
        # initialization and state_dict keys stay stable.
        self.fc1 = nn.Linear(dim, dim)
        self.fc2 = nn.Linear(dim, dim)
        self.act = nn.GELU()
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)

    def forward(self, x):
        """Return ``x`` plus the MLP branch applied to ``x``."""
        skip = x
        hidden = self.act(self.fc1(self.norm1(x)))
        branch = self.fc2(self.norm2(hidden))
        return skip + branch
        

# ---------- ResMLP-Denoiser ----------
class ResMLPDenoiser(nn.Module):
    """Residual-MLP denoiser conditioned on the diffusion timestep.

    The timestep embedding is added once to the projected input ("t in
    front"). The result passes through ``depth`` residual MLP blocks and is
    projected back to ``data_dim``.

    Args:
        device: Not used by this module; kept for interface compatibility.
        config: Not used by this module; kept for interface compatibility.
        data_dim: Size of the last dimension of the input and output.
        width: Hidden width of the MLP.
        depth: Number of residual blocks.
        time_embed_dim: Size of the sinusoidal timestep embedding.
    """

    def __init__(self, device, config, data_dim, width=2048, depth=3, time_embed_dim=2048):
        super().__init__()
        self.time_embed = nn.Sequential(
            SinusoidalTimeEmbedding(time_embed_dim),
            nn.Linear(time_embed_dim, time_embed_dim),
            nn.GELU(),
            # Project to ``width`` so the embedding can be added to the hidden
            # state. This is identical to the previous layer shape when
            # time_embed_dim == width; other combinations previously crashed.
            nn.Linear(time_embed_dim, width),
        )
        self.input_proj = nn.Linear(data_dim, width)
        self.blocks = nn.ModuleList([
            ResidualBlock(width) for _ in range(depth)
        ])
        self.output_proj = nn.Linear(width, data_dim)

    def forward(self, x, t):
        """Predict the denoising target for ``x`` at timestep ``t``.

        Args:
            x: Tensor of shape ``(batch, data_dim)`` or
                ``(batch, ..., data_dim)``.
            t: Tensor of timesteps, shape ``(batch,)``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        t_emb = self.time_embed(t)  # (batch, width)

        h = self.input_proj(x)
        # Insert singleton axes so t_emb broadcasts over every middle dim of h:
        # (batch, width) for 2-D x, (batch, 1, width) for 3-D x, and so on.
        # A plain unsqueeze(1) would broadcast 2-D x to (batch, batch, width).
        t_emb = t_emb.reshape(t_emb.shape[0], *([1] * (h.dim() - 2)), t_emb.shape[-1])
        h = h + t_emb

        for block in self.blocks:
            h = block(h)

        return self.output_proj(h)


#--------------------------------------------------------------------
# Variant 3: timestep embedding injected inside every residual block ("t in the middle")

# import torch
# import torch.nn as nn
# import math

# # ---------- Time-Embedding ----------
# class SinusoidalTimeEmbedding(nn.Module):
#     def __init__(self, dim):
#         super().__init__()
#         self.dim = dim

#     def forward(self, t):
#         # t: (batch,) – integer or float timesteps
#         half_dim = self.dim // 2
#         emb = math.log(10000) / (half_dim - 1)
#         emb = torch.exp(torch.arange(half_dim, device=t.device) * -emb)
#         emb = t[:, None] * emb[None, :]
#         emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
#         return emb  # shape: (batch, dim)


# # ---------- Residual-MLP-Block ----------
# class ResidualBlock(nn.Module):
#     def __init__(self, dim, time_embed_dim):
#         super().__init__()
#         self.fc1 = nn.Linear(dim, dim)
#         self.fc2 = nn.Linear(dim, dim)
#         self.act = nn.GELU()
#         self.norm1 = nn.LayerNorm(dim)
#         self.norm2 = nn.LayerNorm(dim)
#         self.time_proj = nn.Linear(time_embed_dim, dim)

#     def forward(self, x, t_emb):
#         # t_emb is projected to the block width and added
#         h = self.fc1(self.norm1(x))
#         h = self.act(h + self.time_proj(t_emb).unsqueeze(1))
#         h = self.act(h)
#         h = self.fc2(self.norm2(h))
#         return x + h  # Residual
        

# # ---------- ResMLP-Denoiser ----------
# class ResMLPDenoiser(nn.Module):
#     def __init__(self, device, config, data_dim, width=2048, depth=3, time_embed_dim=1024):
#         super().__init__()
#         self.time_embed = nn.Sequential(
#             SinusoidalTimeEmbedding(time_embed_dim),
#             nn.Linear(time_embed_dim, time_embed_dim),
#             nn.GELU(),
#             nn.Linear(time_embed_dim, time_embed_dim),
#         )
#         self.input_proj = nn.Linear(data_dim, width)
#         self.blocks = nn.ModuleList([
#             ResidualBlock(width, time_embed_dim) for _ in range(depth)
#         ])
#         self.output_proj = nn.Linear(width, data_dim)

#     def forward(self, x, t):
#         """
#         x: (batch, data_dim)
#         t: (batch,) Timesteps (0..T)
#         """

#         t_emb = self.time_embed(t)
#         h = self.input_proj(x)

#         for block in self.blocks:
#             h = block(h, t_emb)

#         h = self.output_proj(h)

#         return h
