
import torch
import torch.nn as nn
import math
import numpy as np

class SiLU(nn.Module):
    """Sigmoid Linear Unit activation: ``x * sigmoid(x)``, element-wise."""

    def forward(self, x):
        """Return ``x`` multiplied element-wise by its own sigmoid."""
        gate = torch.sigmoid(x)
        return x * gate

# ----- Time embedding (disabled: the classes below are commented out) -----
# class SinusoidalTimeEmbedding(nn.Module):
#     def __init__(self, dim):
#         super().__init__()
#         self.dim = dim

#     def forward(self, t):
#         half_dim = self.dim // 2
#         emb_scale = math.log(10000) / (half_dim - 1)
#         emb = torch.exp(torch.arange(half_dim, device=t.device) * -emb_scale)
#         emb = t[:, None] * emb[None, :]
#         emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
#         return emb

# class GaussianFourierProjection(nn.Module):
#     def __init__(self, embed_dim, scale=30.):
#         super().__init__()
#         self.W = nn.Parameter(torch.randn(embed_dim // 2) * scale, requires_grad=False)
#     def forward(self, x):
#         x_proj = x[:, None] * self.W[None, :] * 2 * np.pi
#         return torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1)


# ----- Low-Rank Attention -----
class LowRankAttention(nn.Module):
    """Single-head self-attention with low-dimensional queries and keys.

    Queries and keys are projected from ``dim`` down to ``rank`` channels,
    while values and the output keep the full ``dim`` width. According to the
    original author's note, the layer is aimed at numeric (tabular) features.

    Args:
        dim: Size of the last axis of both the input and the output.
        rank: Width of the query/key projections; also sets the score scaling.
    """

    def __init__(self, dim, rank=32):
        super().__init__()
        self.rank = rank
        # Only the output projection carries a bias term.
        self.q_proj = nn.Linear(dim, rank, bias=False)
        self.k_proj = nn.Linear(dim, rank, bias=False)
        self.v_proj = nn.Linear(dim, dim, bias=False)
        self.out_proj = nn.Linear(dim, dim)

    def forward(self, x):
        """Attend over the second-to-last axis of ``x``.

        Args:
            x: Tensor whose last axis has size ``dim``; the original comments
                describe it as ``(batch, features, dim)``.

        Returns:
            Tensor with the same shape as ``x``.

        Note:
            The scores are built with ``transpose(-1, -2)``, so for a 2-D
            input ``(N, dim)`` the attention weights are ``(N, N)`` and mix
            the rows of the first axis.
        """
        queries = self.q_proj(x)  # (..., F, rank)
        keys = self.k_proj(x)     # (..., F, rank)
        values = self.v_proj(x)   # (..., F, dim)

        # Scaled dot-product scores, normalised over the key axis.
        scale = math.sqrt(self.rank)
        scores = (queries @ keys.transpose(-1, -2)) / scale
        weights = torch.softmax(scores, dim=-1)  # (..., F, F)

        mixed = weights @ values  # (..., F, dim)
        return self.out_proj(mixed)


# ----- ResMLP + Low-Rank Attention Block -----
class HybridBlock(nn.Module):
    """Pre-norm residual block: low-rank attention followed by an MLP.

    Each sub-layer is applied to a LayerNorm of its input and the result is
    added back onto that input.

    Args:
        dim: Size of the last axis of the input and output.
        mlp_hidden_dim: Width of the two hidden layers of the MLP.
        rank: Query/key width passed on to ``LowRankAttention``.
    """

    def __init__(self, dim, mlp_hidden_dim, rank=32):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = LowRankAttention(dim, rank)
        self.norm2 = nn.LayerNorm(dim)

        # dim -> hidden -> hidden -> dim, with GELU between the linear layers.
        mlp_layers = [
            nn.Linear(dim, mlp_hidden_dim),
            nn.GELU(),
            nn.Linear(mlp_hidden_dim, mlp_hidden_dim),
            nn.GELU(),
            nn.Linear(mlp_hidden_dim, dim),
        ]
        self.mlp = nn.Sequential(*mlp_layers)

    def forward(self, x):
        """Apply the attention and MLP sub-layers, each with a residual add.

        Args:
            x: Tensor whose last axis has size ``dim``; the original comment
                describes it as ``(B, F, dim)``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        attended = x + self.attn(self.norm1(x))
        return attended + self.mlp(self.norm2(attended))


# ----- Hybrid-Denoiser-Backbone -----
class HybridResMLP_Denoiser(nn.Module):
    """Denoiser backbone: input projection, ``HybridBlock`` stack, output head.

    The input is projected from ``data_dim`` to ``hidden_dim``, passed through
    ``depth`` hybrid blocks, layer-normalised and projected back to
    ``data_dim``.

    Args:
        data_dim: Size of the last axis of the input and of the output.
        hidden_dim: Working width inside the block stack.
        depth: Number of ``HybridBlock`` layers.
        mlp_hidden_dim: Hidden width of the MLP inside each block.
        rank: Query/key width of the attention inside each block.
        time_embed_dim: Currently unused; time conditioning is disabled. Kept
            so existing call sites remain valid.
    """

    def __init__(self, data_dim, hidden_dim=2048, depth=4, mlp_hidden_dim=2048, rank=32, time_embed_dim=2048):
        super().__init__()
        self.input_proj = nn.Linear(data_dim, hidden_dim)

        # Time conditioning is disabled. The earlier design embedded ``t``
        # (sinusoidal embedding -> Linear -> GELU -> Linear to hidden_dim) and
        # added the result to the projected input.

        self.blocks = nn.ModuleList()
        for _ in range(depth):
            self.blocks.append(HybridBlock(hidden_dim, mlp_hidden_dim, rank))

        self.norm = nn.LayerNorm(hidden_dim)
        self.output_proj = nn.Linear(hidden_dim, data_dim)

    def forward(self, x):
        """Run the denoiser on ``x``.

        Args:
            x: Tensor whose last axis has size ``data_dim``; the original
                notes describe it as ``(B, data_dim)``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        # NOTE(review): the projection is not expanded to (B, 1, hidden_dim)
        # (the unsqueeze is disabled), so for a (B, data_dim) input the blocks
        # receive a 2-D tensor. The attention inside the blocks then appears
        # to mix rows across the batch axis - confirm this is intended.
        hidden = self.input_proj(x)

        for layer in self.blocks:
            hidden = layer(hidden)

        # Final normalisation, then map back to the data dimension.
        return self.output_proj(self.norm(hidden))