import torch.nn as nn
import torch
import numpy as np
import model_utils.TTVAE

def swish(x):
    """Apply the Swish activation element-wise: ``x * sigmoid(x)``."""
    gate = torch.sigmoid(x)
    return x * gate

def calc_diffusion_step_embedding(device, diffusion_steps, diffusion_step_embed_dim_in):
    """
    Embed a diffusion step $t$ into a higher dimensional space using sinusoids.

    With half_dim = diffusion_step_embed_dim_in // 2, the i-th frequency is
    10000^(-i / (half_dim - 1)) = 10^(-4 * i / (half_dim - 1)).
    E.g. the embedding vector in the 128-dimensional space is
    [sin(t * 10^(-0*4/63)), ... , sin(t * 10^(-63*4/63)),
     cos(t * 10^(-0*4/63)), ... , cos(t * 10^(-63*4/63))]

    Parameters:
    device (torch.device or str):
                                device on which the frequency table is created
    diffusion_steps (torch tensor, shape=(batchsize, 1)):
                                diffusion steps for batch data; it is multiplied with the
                                frequency table, so it is expected to live on `device`
    diffusion_step_embed_dim_in (int):
                                dimensionality of the embedding space for diffusion steps;
                                must be even

    Returns:
    the embedding vectors (torch.tensor, shape=(batchsize, diffusion_step_embed_dim_in)):
    """

    assert diffusion_step_embed_dim_in % 2 == 0, \
        "diffusion_step_embed_dim_in must be even"

    half_dim = diffusion_step_embed_dim_in // 2
    # Guard the denominator: with half_dim == 1 the original division by
    # (half_dim - 1) produced inf, and 0 * -inf then yielded a NaN embedding.
    log_step = np.log(10000) / max(half_dim - 1, 1)
    # Build the frequency table directly on the target device (no CPU round trip).
    frequencies = torch.exp(torch.arange(half_dim, device=device) * -log_step)
    # (batchsize, 1) * (half_dim,) broadcasts to (batchsize, half_dim).
    angles = diffusion_steps * frequencies

    return torch.cat((torch.sin(angles), torch.cos(angles)), 1)

def reparameterize(mu, log_var):
    """Draw a sample ``mu + eps * std`` with ``eps ~ N(0, I)``.

    ``std`` is recovered from the log-variance as ``exp(0.5 * log_var)``;
    the noise has the same shape, dtype and device as ``std``.
    """
    sigma = (0.5 * log_var).exp()
    noise = torch.randn_like(sigma)
    return mu + noise * sigma


class TTVAE(nn.Module):
    """Transformer-based VAE over transformed tabular data, conditioned on a noise label.

    The model fits a ``model_utils.TTVAE.DataTransformer`` on the training data,
    encodes (optionally noise-conditioned) inputs with ``Encoder_T``, samples a
    latent via the reparameterization trick and reconstructs with ``Decoder_T``.
    """

    def __init__(
        self,
        device,
        config,
        train_data,
        d_in=6,
        dim_t=2048,
        diffusion_step_embed_dim_in=64,
        diffusion_step_embed_dim_mid=96,
        latent_dim=32,  # Example latent dimension
        embedding_dim=128,  # Transformer embedding dimension
        nhead=8,  # Number of attention heads
        dim_feedforward=1028,  # Feedforward layer dimension
        dropout=0.1,
    ):
        """Build the data transformer, encoder/decoder and noise-embedding layers.

        Parameters:
        device: device the sub-modules are moved to and on which tensors are created.
        config: mapping; only ``config["model_type"]`` is read here.
        train_data: data used to fit the data transformer (``squeeze()`` is applied first).
        d_in, dim_t: unused in this class; kept for interface compatibility.
        diffusion_step_embed_dim_in: size of the sinusoidal noise-label embedding.
        diffusion_step_embed_dim_mid: hidden size of the noise-embedding MLP.
        latent_dim: size of the VAE latent space.
        embedding_dim: Transformer model dimension.
        nhead: number of attention heads.
        dim_feedforward: Transformer feed-forward dimension.
            NOTE(review): the default 1028 may be a typo for 1024; left unchanged
            so existing checkpoints keep loading.
        dropout: dropout rate used in the Transformer layers.
        """
        super().__init__()
        self.device = device
        self.latent_dim = latent_dim
        self.embedding_dim = embedding_dim
        self.nhead = nhead
        self.dim_feedforward = dim_feedforward
        self.dropout = dropout
        self.diffusion_step_embed_dim_in = diffusion_step_embed_dim_in

        # The transformer defines the width of the rows the model works on.
        self.datatransformer = model_utils.TTVAE.DataTransformer()
        self.datatransformer.fit(train_data.squeeze())
        data_dim = self.datatransformer.output_dimensions

        # The noise-label embedding is added to the input for every model type
        # except "CDTD" (presumably because CDTD inputs already carry noise).
        self.add_noise = config["model_type"] != "CDTD"

        self.encoder = Encoder_T(
            data_dim, self.latent_dim, self.embedding_dim, self.nhead, self.dim_feedforward, self.dropout
        ).to(self.device)
        self.decoder = Decoder_T(
            data_dim, self.latent_dim, self.embedding_dim, self.nhead, self.dim_feedforward, self.dropout
        ).to(self.device)

        # Noise-label MLP. Moved to the same device as encoder/decoder so that the
        # embedding created on ``self.device`` never meets CPU-resident weights.
        self.fc_t1 = nn.Linear(diffusion_step_embed_dim_in, diffusion_step_embed_dim_mid).to(self.device)
        self.fc_t2 = nn.Linear(diffusion_step_embed_dim_mid, data_dim).to(self.device)

    def transform_data(self, train_data):
        """Transform raw data with the fitted transformer and return a float32 tensor.

        The input is squeezed, transformed, cast to float32, moved to
        ``self.device`` and given an extra dimension at index 1.
        """
        transformed = self.datatransformer.transform(train_data.squeeze())
        tensor = torch.from_numpy(transformed.astype('float32'))
        return tensor.to(self.device).unsqueeze(dim=1)

    def _noise_embedding(self, noise_labels):
        """Map noise labels to an embedding of shape (B, 1, data_dim)."""
        steps = noise_labels.view(noise_labels.shape[0], 1)
        embed = calc_diffusion_step_embedding(self.device, steps, self.diffusion_step_embed_dim_in)
        embed = swish(self.fc_t1(embed))
        embed = swish(self.fc_t2(embed))
        # Extra axis at index 1 so the embedding can be added to x.
        return embed.unsqueeze(1)

    def forward(self, x, noise_labels):
        """Reconstruct ``x``, optionally conditioned on ``noise_labels``.

        Parameters:
        x: input tensor; assumed to be (B, 1, data_dim) as produced by
            ``transform_data`` -- TODO confirm against callers.
        noise_labels: tensor with B elements along its first dimension; only
            used when ``self.add_noise`` is true.

        Returns:
        The decoder reconstruction, still in the transformed data space
        (``datatransformer.inverse_transform`` is not applied here).

        NOTE(review): Encoder_T/Decoder_T build their Transformer layers without
        ``batch_first``, so a (B, 1, D) input would be read as sequence length B
        with batch size 1 -- confirm this is intended.
        """
        if self.add_noise:
            # Additive conditioning on the noise level (approach taken from a
            # paper on generating tabular data).
            encoder_input = x + self._noise_embedding(noise_labels)
        else:
            # Input is passed through unchanged; the embedding is not needed.
            encoder_input = x

        mean, _std, logvar, enc_output = self.encoder(encoder_input)
        z = reparameterize(mean, logvar)
        recon_x, _sigmas = self.decoder(z, enc_output)
        return recon_x


class Encoder_T(nn.Module):
    """Transformer encoder producing the Gaussian latent parameters of the VAE."""

    def __init__(self, input_dim, latent_dim, embedding_dim, nhead, dim_feedforward=2048, dropout=0.1):
        """Create the input projection, a 2-layer Transformer encoder and the latent heads."""
        super().__init__()
        # Projection from data features to the Transformer embedding size.
        self.linear = nn.Linear(in_features=input_dim, out_features=embedding_dim)
        # Transformer encoder stack (two layers built from the layer below).
        self.transformerencoder_layer = nn.TransformerEncoderLayer(
            d_model=embedding_dim,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
        )
        self.encoder = nn.TransformerEncoder(self.transformerencoder_layer, num_layers=2)
        # Heads mapping the encoder output to mean and log-variance.
        self.fc_mu = nn.Linear(in_features=embedding_dim, out_features=latent_dim)
        self.fc_log_var = nn.Linear(in_features=embedding_dim, out_features=latent_dim)

    def forward(self, x):
        """Encode ``x`` and return ``(mu, std, logvar, enc_output)``.

        ``std`` is ``exp(0.5 * logvar)``; ``enc_output`` is the Transformer
        encoder output that both latent heads are applied to.
        """
        embedded = self.linear(x)
        enc_output = self.encoder(embedded)
        mu = self.fc_mu(enc_output)
        logvar = self.fc_log_var(enc_output)
        return mu, (0.5 * logvar).exp(), logvar, enc_output


class Decoder_T(nn.Module):
    """Transformer decoder mapping latent samples back to the data dimension."""

    def __init__(self, input_dim, latent_dim, embedding_dim, nhead, dim_feedforward=2048, dropout=0.1):
        """Create the latent projection, a 2-layer Transformer decoder and the output head."""
        super().__init__()
        # Lifts latent vectors to the size the Transformer decoder expects.
        self.latent_to_decoder_input = nn.Linear(in_features=latent_dim, out_features=embedding_dim)
        # Transformer decoder stack (two layers built from the layer below).
        self.transformerdecoder_layer = nn.TransformerDecoderLayer(
            d_model=embedding_dim,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
        )
        self.decoder = nn.TransformerDecoder(self.transformerdecoder_layer, num_layers=2)
        # Projection from the embedding size back to the data features.
        self.linear = nn.Linear(in_features=embedding_dim, out_features=input_dim)
        # Learnable per-feature parameter, initialised to 0.1 everywhere.
        self.sigma = nn.Parameter(0.1 * torch.ones(input_dim))

    def forward(self, z, enc_output):
        """Decode latent ``z`` using ``enc_output`` as the decoder memory.

        Returns a tuple ``(reconstruction, sigma)`` where ``sigma`` is the
        module's learnable parameter, returned as is.
        """
        target = self.latent_to_decoder_input(z)
        # The encoder output is passed as the memory argument of the decoder.
        decoded = self.decoder(target, enc_output)
        reconstruction = self.linear(decoded)
        return reconstruction, self.sigma
