import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import rtdl_num_embeddings as rtdl_embed


def swish(x):
    """Apply the Swish activation ``x * sigmoid(x)`` element-wise to ``x``."""
    gate = torch.sigmoid(x)
    return x * gate

def calc_diffusion_step_embedding(device, diffusion_steps, diffusion_step_embed_dim_in):
    """
    Embed a diffusion step $t$ into a higher dimensional space using
    sinusoidal features.

    With ``half = diffusion_step_embed_dim_in // 2`` and frequencies
    ``w_i = 10000 ** (-i / (half - 1))`` for ``i = 0, ..., half - 1``, the
    embedding of a step ``t`` is
    [sin(t * w_0), ..., sin(t * w_{half-1}), cos(t * w_0), ..., cos(t * w_{half-1})]

    Parameters:
    device (torch.device or str):
                                device the frequency vector is moved to before it is
                                multiplied with ``diffusion_steps``
    diffusion_steps (torch.tensor, shape=(batchsize, 1)):
                                diffusion steps for batch data
    diffusion_step_embed_dim_in (int):
                                dimensionality of the embedding space for diffusion steps;
                                must be even

    Returns:
    the embedding vectors (torch.tensor, shape=(batchsize, diffusion_step_embed_dim_in)):
    """

    assert diffusion_step_embed_dim_in % 2 == 0, \
        "diffusion_step_embed_dim_in must be even"

    half_dim = diffusion_step_embed_dim_in // 2
    # Guard the denominator: with half_dim == 1 the original division by
    # (half_dim - 1) == 0 produced inf and the whole embedding became NaN.
    # A single frequency then degenerates to w_0 = 1.
    log_step = np.log(10000) / max(half_dim - 1, 1)
    frequencies = torch.exp(torch.arange(half_dim) * -log_step).to(device)

    # Broadcasting (batchsize, 1) * (half_dim,) -> (batchsize, half_dim).
    angles = diffusion_steps * frequencies
    diffusion_step_embed = torch.cat((torch.sin(angles),
                                      torch.cos(angles)), 1)

    return diffusion_step_embed

class ZeroPadding4(nn.Module):
    def __init__(self, dim: int = -1):
        """
        Module to zero-pad a tensor along a specified dimension 
        so its size becomes divisible by 4.

        Args:
            dim (int): The dimension to pad (default: -1, i.e., last dimension).
        """
        super().__init__()
        self.dim = dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Append zeros at the end of dimension ``self.dim`` until its size is a
        multiple of 4.

        Args:
            x (torch.Tensor): Tensor to pad.

        Returns:
            torch.Tensor: ``x`` itself (no copy) when the size is already
            divisible by 4, otherwise a zero-padded tensor.
        """
        size = x.size(self.dim)
        remainder = size % 4

        if remainder == 0:
            return x

        pad_amount = 4 - remainder

        # Normalise a negative dim to its forward index.
        dim_index = self.dim if self.dim >= 0 else x.dim() + self.dim

        # F.pad expects (left, right) pairs starting from the LAST dimension
        # and moving backwards, so the forward index has to be mirrored before
        # it is used to address the pad list.
        pad = [0] * (2 * x.dim())
        reversed_index = x.dim() - 1 - dim_index
        pad[2 * reversed_index + 1] = pad_amount  # right-side padding

        return F.pad(x, pad, mode='constant', value=0)
    
def get_padded_dim(size: int, multiple: int = 4) -> int:
    """
    Round `size` up to the nearest value divisible by `multiple`.

    Args:
        size (int): Original size of the dimension.
        multiple (int): The number the result must be divisible by (default: 4).

    Returns:
        int: `size` unchanged if it is already divisible by `multiple`,
        otherwise the next larger value that is.
    """
    # (-size) % multiple is exactly the amount missing to the next multiple
    # (and 0 when size is already aligned).
    shortfall = -size % multiple
    return size + shortfall

class TransformerBackbone(nn.Module):
    def __init__(self, device, config, bins, d_embedding, dim: int, num_heads: int, num_layers: int, dim_t = 2048, diffusion_step_embed_dim_in=128, diffusion_step_embed_dim_mid=512, mlp_ratio: float = 4.0, dropout: float = 0.1):
        """
        Standard Transformer encoder used as the backbone of a diffusion model.

        :param device: device passed to ``calc_diffusion_step_embedding`` in ``forward``
        :param config: mapping providing the keys ``"model_type"`` and ``"embedding_type"``.
            ``"model_type" == "CDTD"`` disables adding the noise-level embedding to the input.
            ``"embedding_type"`` must be one of ``"Linear"``, ``"ZeroPadding"``,
            ``"PiecewiseLinearEmbeddings"`` or ``"PeriodicEmbeddings"``
        :param bins: forwarded to ``rtdl_embed.PiecewiseLinearEmbeddings``
            (only used for that embedding type)
        :param d_embedding: per-feature embedding size for the ``rtdl_embed`` embeddings
        :param dim: input and output dimension of the model
        :param num_heads: number of attention heads
        :param num_layers: number of Transformer layers
        :param dim_t: internal Transformer width; only honoured for the ``"Linear"``
            embedding, every other embedding type overrides it
        :param diffusion_step_embed_dim_in: size of the sinusoidal noise-level embedding
        :param diffusion_step_embed_dim_mid: hidden size of the noise-level MLP
        :param mlp_ratio: ratio of the feed-forward hidden dimension to ``dim_t``
        :param dropout: dropout value used for regularisation
        :raises ValueError: if ``config["embedding_type"]`` is not a supported value
        """
        super().__init__()
        self.add_noise = config["model_type"] != "CDTD"

        self.device = device
        self.diffusion_step_embed_dim_in = diffusion_step_embed_dim_in

        embedding_type = config["embedding_type"]
        if embedding_type == "Linear":
            self.proj = nn.Linear(dim, dim_t)
        elif embedding_type == "ZeroPadding":
            self.proj = ZeroPadding4(-1)
            dim_t = get_padded_dim(dim, 4)
        elif embedding_type == "PiecewiseLinearEmbeddings":
            dim_t = dim * d_embedding
            # NOTE(review): the Linear layer expects a last dimension of size
            # dim_t -- confirm that the embedding output actually has that shape.
            self.proj = nn.Sequential(rtdl_embed.PiecewiseLinearEmbeddings(bins, d_embedding, activation=False, version="B"),
                                      nn.Linear(dim_t, dim_t))
        elif embedding_type == "PeriodicEmbeddings":
            self.proj = rtdl_embed.PeriodicEmbeddings(dim, d_embedding, lite=False)
            dim_t = dim * d_embedding
        else:
            # Fail at construction time instead of leaving ``self.proj``
            # undefined and crashing with an AttributeError inside forward().
            raise ValueError(f"Unsupported embedding_type: {embedding_type!r}")

        # Two-layer MLP that lifts the sinusoidal noise-level embedding to dim_t.
        self.fc_t1 = nn.Linear(diffusion_step_embed_dim_in, diffusion_step_embed_dim_mid)
        self.fc_t2 = nn.Linear(diffusion_step_embed_dim_mid, dim_t)
        # Maps the Transformer output back to the input dimension.
        self.proj_back = nn.Linear(dim_t, dim)

        self.layers = nn.ModuleList([
            nn.TransformerEncoderLayer(
                d_model=dim_t,
                nhead=num_heads,
                dim_feedforward=int(dim_t * mlp_ratio),
                dropout=dropout,
                activation="gelu",
                batch_first=True
            )
            for _ in range(num_layers)
        ])
        self.norm = nn.LayerNorm(dim_t)

    def forward(self, x, noise_labels):
        """
        Run the backbone on ``x``, optionally conditioned on ``noise_labels``.

        :param x: input tensor handed to the configured input projection
            (presumably ``(batch, seq, dim)`` -- TODO confirm against the caller)
        :param noise_labels: one noise level per batch element; reshaped to
            ``(batch, 1)``. Only read when ``self.add_noise`` is True.
        :return: output of ``proj_back`` applied to the normalised Transformer output
        """
        x = self.proj(x)

        if self.add_noise:
            B = noise_labels.shape[0]
            noise_labels = noise_labels.view(B, 1)
            noise_labels_embed = calc_diffusion_step_embedding(self.device, noise_labels, self.diffusion_step_embed_dim_in)
            noise_labels_embed = swish(self.fc_t1(noise_labels_embed))
            noise_labels_embed = swish(self.fc_t2(noise_labels_embed))
            # (batch, dim_t) -> (batch, 1, dim_t) so it broadcasts over dim 1 of x.
            noise_labels_embed = noise_labels_embed.unsqueeze(1)
            # Additive conditioning; the original comment attributes this to a
            # paper on generating tabular data.
            x = x + noise_labels_embed

        # NOTE(review): for a 4-D tensor only the first slice along dim 0 is
        # kept -- presumably to drop an extra leading dimension; confirm.
        if x.dim() == 4:
            x = x[0]

        for layer in self.layers:
            x = layer(x)
        x = self.norm(x)
        return self.proj_back(x)
