#Implementation of TabNet, adjusted to be used in diffusion models
#based on:  https://github.com/google-research/google-research/blob/master/tabnet
#           https://github.com/dreamquark-ai/tabnet/blob/develop/pytorch_tabnet

import torch.nn as nn
import torch
import numpy as np
from model_utils.TabNet import Sparsemax

def glu(act, n_units):
    """Gated linear unit (GLU) activation.

    Splits ``act`` along dimension 1: the first ``n_units`` columns are the
    values, and the remaining columns are the gates, which pass through a
    sigmoid. For the intended input width of ``2 * n_units`` this equals
    ``torch.nn.functional.glu(act, dim=1)``.

    Args:
        act: Tensor with at least two dimensions, expected to have width
            ``2 * n_units`` along dimension 1.
        n_units: Number of output units (the width of the value half).

    Returns:
        ``act[:, :n_units] * sigmoid(act[:, n_units:])``.
    """
    values = act[:, :n_units]
    gates = act[:, n_units:]
    return values * torch.sigmoid(gates)

# def swish(x):
#     return x * torch.sigmoid(x)

# def calc_diffusion_step_embedding(device, diffusion_steps, diffusion_step_embed_dim_in):
#     """
#     Embed a diffusion step $t$ into a higher dimensional space
#     E.g. the embedding vector in the 128-dimensional space is
#     [sin(t * 10^(-0*4/63)), ... , sin(t * 10^(-63*4/63)), cos(t * 10^(-0*4/63)), ... , cos(t * 10^(-63*4/63))]

#     Parameters:
#     diffusion_steps (torch.long tensor, shape=(batchsize, 1)):     
#                                 diffusion steps for batch data
#     diffusion_step_embed_dim_in (int, default=128):  
#                                 dimensionality of the embedding space for discrete diffusion steps
    
#     Returns:
#     the embedding vectors (torch.tensor, shape=(batchsize, diffusion_step_embed_dim_in)):
#     """

#     assert diffusion_step_embed_dim_in % 2 == 0

#     half_dim = diffusion_step_embed_dim_in // 2
#     _embed = np.log(10000) / (half_dim - 1)
#     _embed = torch.exp(torch.arange(half_dim) * -_embed).to(device)
#     _embed = diffusion_steps * _embed
#     diffusion_step_embed = torch.cat((torch.sin(_embed),
#                                       torch.cos(_embed)), 1)
#     return diffusion_step_embed

class TabNet(nn.Module):
    """TabNet encoder adapted for use in diffusion models.

    The network runs ``num_decision_steps`` sequential decision steps. Each
    step applies a shared four-layer feature transformer to the currently
    masked input. From the second step onward, each step adds a ReLU'd decision
    output to the aggregate. Every step except the last computes the feature
    mask used by the next step.

    Args:
        device: Stored as ``self.device``. Internal per-forward state is
            allocated on the device of the input instead, so the module keeps
            working after ``.to()`` / ``.cuda()``.
        config: Dict-like; only ``config["model_type"]`` is read. ``"CDTD"``
            sets ``self.add_noise = False``; any other value sets it to True.
        d_in: Number of input (and output) features.
        batch_momentum: Passed as ``momentum`` to every ``nn.BatchNorm1d``.
        feature_dim: Width of each GLU output inside the feature transformer.
        output_dim: Number of transformer outputs used as the decision output.
            The remaining ``feature_dim - output_dim`` columns drive the mask.
        num_decision_steps: Number of sequential decision steps.
        relaxation_factor: Controls feature reuse across steps. After every
            step the mask prior is multiplied by
            ``relaxation_factor - mask``.
        epsilon: Added inside the log of the mask-entropy term.

    NOTE(review): PyTorch's BatchNorm ``momentum`` weights the *new* batch
    statistic (running = (1 - m) * running + m * batch). This is the opposite
    of the TensorFlow convention used by the reference TabNet, so 0.7 here
    tracks recent batches heavily. Confirm that this is intended.

    NOTE(review): the TabNet paper uses ``relaxation_factor >= 1``. Assuming
    the mask entries lie in [0, 1], a value of 0.5 makes
    ``relaxation_factor - mask`` negative for entries above 0.5, which flips
    the sign of the prior. Confirm that this is intended.
    """

    # Residual scaling applied after GLU layers 2-4, as in the reference TabNet.
    _RESIDUAL_SCALE = np.sqrt(0.5)

    def __init__(self, device, config, d_in=6, batch_momentum=0.7, feature_dim=128, output_dim=64,
                 num_decision_steps=40, relaxation_factor=0.5, epsilon=0.00001):
        super().__init__()
        # Presumably required by the (disabled) sinusoidal noise embedding,
        # which asserts an even dimension. TODO confirm it is still needed.
        assert feature_dim % 2 == 0
        # Flag for the currently disabled noise-label conditioning; off for "CDTD".
        self.add_noise = config["model_type"] != "CDTD"

        self.num_features = d_in
        self.feature_dim = feature_dim
        self.output_dim = output_dim
        self.device = device
        self.num_decision_steps = num_decision_steps
        self.relaxation_factor = relaxation_factor
        self.epsilon = epsilon

        self.batch_norm_input = nn.BatchNorm1d(self.num_features, momentum=batch_momentum)
        self.batch_norm1 = nn.BatchNorm1d(self.feature_dim * 2, momentum=batch_momentum)
        self.batch_norm2 = nn.BatchNorm1d(self.feature_dim * 2, momentum=batch_momentum)
        self.batch_norm3 = nn.BatchNorm1d(self.feature_dim * 2, momentum=batch_momentum)
        self.batch_norm4 = nn.BatchNorm1d(self.feature_dim * 2, momentum=batch_momentum)

        # Each linear layer doubles the width so the following GLU can split it
        # into values and gates.
        self.feature_transform_linear1 = nn.Linear(self.num_features, self.feature_dim * 2, bias=False)
        self.feature_transform_linear2 = nn.Linear(self.feature_dim, self.feature_dim * 2, bias=False)
        self.feature_transform_linear3 = nn.Linear(self.feature_dim, self.feature_dim * 2, bias=False)
        self.feature_transform_linear4 = nn.Linear(self.feature_dim, self.feature_dim * 2, bias=False)

        # Mask head: maps the non-decision part of the transformer output to
        # per-feature logits.
        self.mask_linear_layer = nn.Linear(self.feature_dim - output_dim, self.num_features, bias=False)
        # Batch-norm applied to the mask logits. The name is kept for
        # state_dict compatibility.
        self.batch_norm_output = nn.BatchNorm1d(self.num_features, momentum=batch_momentum)

        self.output_layer = nn.Linear(output_dim, self.num_features)
        self.sparsemax = Sparsemax(dim=1)
        self.decision_out_activation = nn.ReLU()

    def _feature_transformer(self, masked_features):
        """Apply the shared four-layer linear -> batch-norm -> GLU stack.

        Layer 1 has no residual connection. Layers 2-4 add their input back
        and scale the sum by sqrt(0.5).

        Returns:
            Tensor of width ``self.feature_dim``.
        """
        hidden = glu(self.batch_norm1(self.feature_transform_linear1(masked_features)), self.feature_dim)
        for linear, norm in (
            (self.feature_transform_linear2, self.batch_norm2),
            (self.feature_transform_linear3, self.batch_norm3),
            (self.feature_transform_linear4, self.batch_norm4),
        ):
            hidden = (glu(norm(linear(hidden)), self.feature_dim) + hidden) * self._RESIDUAL_SCALE
        return hidden

    def encoder(self, data, return_mask=False):
        """Run all decision steps and aggregate their outputs.

        Args:
            data: Tensor indexed as ``(batch, num_features)``.
            return_mask: If True, also return the aggregated feature-importance
                mask, which is otherwise computed and discarded.

        Returns:
            ``(output_aggregated, total_entropy)``, where:

            - ``output_aggregated`` has shape ``(batch, output_dim)``.
            - ``total_entropy`` is a scalar tensor: the batch-mean mask entropy,
              averaged over the ``num_decision_steps - 1`` mask-producing steps.

            With ``return_mask=True``, a third element ``aggregated_mask_values``
            of shape ``(batch, num_features)`` is appended.
        """
        features = self.batch_norm_input(data)
        batch_size = features.shape[0]

        # Allocate the decision-step state on the same device and dtype as the
        # normalized input, rather than on the CPU followed by .to(self.device).
        output_aggregated = features.new_zeros((batch_size, self.output_dim))
        mask_values = features.new_zeros((batch_size, self.num_features))
        aggregated_mask_values = features.new_zeros((batch_size, self.num_features))
        # Prior on feature usage. It starts at 1 and shrinks as features are used.
        complementary_aggregated_mask_values = features.new_ones((batch_size, self.num_features))
        total_entropy = features.new_zeros(())
        masked_features = features

        last_step = self.num_decision_steps - 1
        # Divisor for the per-step averages. It is only used when last_step >= 1,
        # so it is never zero when used.
        num_averaged_steps = self.num_decision_steps - 1

        for step in range(self.num_decision_steps):
            transformed = self._feature_transformer(masked_features)
            # The first output_dim columns form the decision output; the rest feed the mask head.
            features_for_coef = transformed[:, self.output_dim:]

            if step > 0:
                # Step 0 only produces the first mask; its decision output is not aggregated.
                decision_out = self.decision_out_activation(transformed[:, :self.output_dim])
                output_aggregated = output_aggregated + decision_out

                # Feature-importance bookkeeping: the mask that selected this
                # step's input is weighted by this step's total decision output.
                scale_agg = decision_out.sum(dim=1, keepdim=True) / num_averaged_steps
                aggregated_mask_values = aggregated_mask_values + mask_values * scale_agg

            if step < last_step:
                # Mask for the next step: logits from the transformer output,
                # scaled by the usage prior, then projected by the sparsemax module.
                mask_values = self.mask_linear_layer(features_for_coef)
                mask_values = self.batch_norm_output(mask_values)
                mask_values = mask_values * complementary_aggregated_mask_values
                mask_values = self.sparsemax(mask_values)

                # Update the prior so later steps account for features already used.
                complementary_aggregated_mask_values = (
                    complementary_aggregated_mask_values * (self.relaxation_factor - mask_values)
                )

                # Mask entropy, intended as a sparsity penalty.
                step_entropy = torch.sum(-mask_values * torch.log(mask_values + self.epsilon), dim=1)
                total_entropy = total_entropy + torch.mean(step_entropy) / num_averaged_steps

                masked_features = mask_values * features

        if return_mask:
            return output_aggregated, total_entropy, aggregated_mask_values
        return output_aggregated, total_entropy

    def forward(self, x):
        """Encode ``x`` and project it back to ``num_features`` columns.

        The input is squeezed at dimension 1 before encoding. This presumably
        expects ``(batch, 1, num_features)``; TODO confirm with callers.

        Returns:
            Tensor of shape ``(batch, 1, num_features)``.

        NOTE(review): the mask-entropy term returned by ``encoder`` is discarded
        here, so the TabNet sparsity regularisation is not applied in training.

        NOTE: noise-label conditioning is currently disabled. Previously it
        embedded the noise labels with the module-level (commented-out)
        ``calc_diffusion_step_embedding``/``swish`` helpers and added the
        result after the first GLU of step 0 when ``self.add_noise`` was True.
        The alternative was adding the embedding directly to ``x``, as in the
        paper on generating tabular data.
        """
        x = x.squeeze(dim=1)
        encoded, _ = self.encoder(x)
        output = self.output_layer(encoded)
        return output.unsqueeze(1)
