import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import model_utils.Transformer_TabGenDDPM

def swish(x):
    """Apply the swish activation, ``x * sigmoid(x)``, element-wise."""
    gate = torch.sigmoid(x)
    return torch.mul(x, gate)

def calc_diffusion_step_embedding(device, diffusion_steps, diffusion_step_embed_dim_in):
    """
    Embed diffusion steps $t$ into a higher dimensional sinusoidal space.

    With ``half = diffusion_step_embed_dim_in // 2`` and frequencies
    ``f_i = 10000^(-i / (half - 1))`` for ``i = 0 .. half - 1`` the result is
    ``[sin(t * f_0), ..., sin(t * f_{half-1}), cos(t * f_0), ..., cos(t * f_{half-1})]``.
    E.g. for a 128-dimensional embedding ``f_i = 10^(-4 * i / 63)``.

    Parameters:
    device:                     device on which the frequency table is created;
                                ``diffusion_steps`` must live on the same device
    diffusion_steps (tensor, shape=(batchsize, 1)):
                                diffusion steps for batch data
    diffusion_step_embed_dim_in (int):
                                dimensionality of the embedding space, must be even

    Returns:
    the embedding vectors (torch.tensor, shape=(batchsize, diffusion_step_embed_dim_in))

    Raises:
    ValueError: if ``diffusion_step_embed_dim_in`` is odd.
    """
    if diffusion_step_embed_dim_in % 2 != 0:
        raise ValueError(
            "diffusion_step_embed_dim_in must be even, got "
            f"{diffusion_step_embed_dim_in}"
        )

    half_dim = diffusion_step_embed_dim_in // 2
    # Guard the denominator: for half_dim == 1 the original division by zero
    # produced inf and the whole embedding became NaN. A single frequency of 1
    # is used in that case instead.
    log_step = float(np.log(10000)) / max(half_dim - 1, 1)
    # Build the table directly on the target device instead of CPU + copy.
    positions = torch.arange(
        half_dim, device=device, dtype=torch.get_default_dtype()
    )
    frequencies = torch.exp(positions * -log_step)
    # (batchsize, 1) * (half_dim,) broadcasts to (batchsize, half_dim).
    angles = diffusion_steps * frequencies

    return torch.cat((torch.sin(angles), torch.cos(angles)), 1)

# Pad with zeros (zero padding); with 4 heads, pad up to the next size divisible by 4.
#
class TransformerBackbone2(nn.Module):
    """Transformer backbone for a diffusion model over tabular columns.

    ``forward`` embeds every input column as a token of width
    ``hidden_size`` (128), runs the tokens through an encoder/decoder
    ``torch.nn.Transformer``, adds a diffusion-step embedding, applies layer
    normalisation and maps each column token to one scalar with its own
    output head.

    The embedding is configured with ``cat_features_num=0``, so only
    continuous columns are handled. ``old_forward`` is a legacy path that
    works on a ``dim_t``-wide projection of the input.
    """

    def __init__(
        self,
        device,
        dim: int,
        num_heads: int,
        num_layers: int,
        dim_t: int = 2048,
        diffusion_step_embed_dim_in: int = 128,
        diffusion_step_embed_dim_mid: int = 512,
        mlp_ratio: float = 4.0,
        dropout: float = 0.1,
    ):
        """
        Build the backbone.

        :param device: device on which the sinusoidal step table is created
        :param dim: number of (continuous) input columns; also the output width
        :param num_heads: number of attention heads; ``torch.nn.Transformer``
            requires the hidden size (128) to be divisible by it
        :param num_layers: currently unused; the transformer depth is fixed at
            3 encoder and 3 decoder layers
        :param dim_t: width of the projection used by ``old_forward``
        :param diffusion_step_embed_dim_in: size of the sinusoidal step embedding
        :param diffusion_step_embed_dim_mid: hidden size of the step-embedding MLP
        :param mlp_ratio: currently unused; the feed-forward width is fixed at 256
        :param dropout: dropout probability of the transformer
        """
        super().__init__()
        self.device = device
        self.diffusion_step_embed_dim_in = diffusion_step_embed_dim_in
        self.hidden_size = 128

        # Column layout used by forward(): every column is continuous, which
        # matches cat_features_num=0 in the embedding below.
        self.num_cont = dim
        self.num_classes = []

        # Layers of the legacy path (old_forward).
        self.proj = nn.Linear(dim, dim_t)
        self.fc_t2 = nn.Linear(diffusion_step_embed_dim_mid, dim_t)
        self.proj_back = nn.Linear(dim_t, dim)
        self.norm = nn.LayerNorm(dim_t)

        # Diffusion-step MLP: fc_t1 is shared by both paths, fc_t_hidden
        # projects to the token width used by forward().
        self.fc_t1 = nn.Linear(diffusion_step_embed_dim_in, diffusion_step_embed_dim_mid)
        self.fc_t_hidden = nn.Linear(diffusion_step_embed_dim_mid, self.hidden_size)

        self.embedding = model_utils.Transformer_TabGenDDPM.ColumnarEmbeddingForGraph(
            con_features_num=dim,
            cat_features_num=0,
            cat_features_degrees=0,
            latent_space_size=self.hidden_size,
            null_in_categorical_embedding=False,
            global_cls_num=0,
        )
        self.transformer_model = torch.nn.Transformer(
            d_model=self.hidden_size,
            nhead=num_heads,
            num_encoder_layers=3,
            num_decoder_layers=3,
            dim_feedforward=256,
            dropout=dropout,
            batch_first=True,
        )

        # One output head per column; forward() applies head i to token i.
        self.outputs = nn.ModuleList(
            [self._build_output_head() for _ in range(self.num_cont)]
        )

        # Normalises over the (column, hidden) axes of the token tensor.
        self.norm_layer = torch.nn.LayerNorm(
            [self.num_cont + len(self.num_classes), self.hidden_size]
        )

    def _build_output_head(self) -> nn.Module:
        """Return a fresh MLP mapping one column token to a single scalar."""
        return torch.nn.Sequential(
            torch.nn.Linear(self.hidden_size, self.hidden_size),
            torch.nn.LayerNorm(self.hidden_size),
            torch.nn.ReLU(),
            torch.nn.Linear(self.hidden_size, self.hidden_size),
            torch.nn.LayerNorm(self.hidden_size),
            torch.nn.ReLU(),
            torch.nn.Linear(self.hidden_size, 1),
        )

    def time_embedding(self, t: torch.Tensor) -> torch.Tensor:
        """Embed diffusion steps ``t`` as vectors of width ``hidden_size``.

        :param t: one diffusion step per batch element
        :return: tensor of shape ``(batch, hidden_size)``
        """
        steps = t.view(t.shape[0], 1)
        embed = calc_diffusion_step_embedding(
            self.device, steps, self.diffusion_step_embed_dim_in
        )
        embed = swish(self.fc_t1(embed))
        return swish(self.fc_t_hidden(embed))

    def forward(self,
                x: torch.Tensor,
                t: torch.Tensor,
                mask: torch.Tensor = None) -> torch.Tensor:
        """Predict one value per column for the noisy rows ``x``.

        :param x: input rows, indexed as ``(batch, column)``
        :param t: diffusion step per batch element
        :param mask: optional per-column mask, ``(batch, column)``. Columns
            with ``mask < 1`` are ignored by the encoder and columns with
            ``mask > 0`` are ignored by the decoder. ``None`` disables both
            padding masks.
        :return: tensor of shape ``(batch, number of columns)``
        """
        # NOTE(review): the (continuous, categorical) call convention and the
        # (batch, column, hidden_size) result are assumed from the original
        # forward code - confirm against ColumnarEmbeddingForGraph.
        x = self.embedding(x[:, :self.num_cont],
                           x[:, self.num_cont:].to(torch.long))

        if mask is None:
            src_padding = None
            tgt_padding = None
        else:
            # True marks a position the attention must ignore.
            src_padding = mask < 1
            tgt_padding = mask > 0

        x = self.transformer_model(src=x,
                                   tgt=x,
                                   src_key_padding_mask=src_padding,
                                   tgt_key_padding_mask=tgt_padding)

        # Broadcast the step embedding over the column axis.
        x = x + self.time_embedding(t).unsqueeze(dim=1)
        x = self.norm_layer(x)

        per_column = [head(x[:, i, :]) for i, head in enumerate(self.outputs)]
        return torch.cat(per_column, dim=1)

    def old_forward(self, x, noise_labels):
        """Legacy forward pass on a ``dim_t``-wide projection of ``x``.

        This path iterates ``self.layers``, which the constructor does not
        create; it has to be assigned before calling this method.

        :param x: input tensor whose last axis has size ``dim``
        :param noise_labels: diffusion step per batch element
        :return: tensor with last axis of size ``dim``
        :raises AttributeError: if ``self.layers`` has not been set
        """
        layers = getattr(self, "layers", None)
        if layers is None:
            raise AttributeError(
                "TransformerBackbone2.old_forward requires 'self.layers', "
                "which is not created by __init__"
            )

        batch_size = noise_labels.shape[0]
        steps = noise_labels.view(batch_size, 1)
        step_embed = calc_diffusion_step_embedding(
            self.device, steps, self.diffusion_step_embed_dim_in
        )
        step_embed = swish(self.fc_t1(step_embed))
        step_embed = swish(self.fc_t2(step_embed))
        # Add a sequence axis so the embedding broadcasts over the tokens.
        step_embed = step_embed.unsqueeze(1)

        # Conditioning by addition, taken from a paper on generating tabular data.
        x = self.proj(x) + step_embed
        # A 4-D result is reduced to its first entry along the leading axis.
        if len(x.shape) == 4:
            x = x[0]
        for layer in layers:
            x = layer(x)
        x = self.norm(x)
        return self.proj_back(x)
