import torch.nn as nn
import torch
import numpy as np
from mamba_ssm import Mamba
# from ..mamba.mamba_ssm.modules import mamba_simple as Mamba
# from ..mamba.mamba_ssm.modules.mamba_simple import Mamba
# from mamba.mamba_ssm.modules.mamba_simple import Mamba

# def swish(x):
#     return x * torch.sigmoid(x)

# def calc_diffusion_step_embedding(device, diffusion_steps, diffusion_step_embed_dim_in):
#     """
#     Embed a diffusion step $t$ into a higher dimensional space
#     E.g. the embedding vector in the 128-dimensional space is
#     [sin(t * 10^(0*4/63)), ... , sin(t * 10^(63*4/63)), cos(t * 10^(0*4/63)), ... , cos(t * 10^(63*4/63))]

#     Parameters:
#     diffusion_steps (torch.long tensor, shape=(batchsize, 1)):     
#                                 diffusion steps for batch data
#     diffusion_step_embed_dim_in (int, default=128):  
#                                 dimensionality of the embedding space for discrete diffusion steps
    
#     Returns:
#     the embedding vectors (torch.tensor, shape=(batchsize, diffusion_step_embed_dim_in)):
#     """

#     assert diffusion_step_embed_dim_in % 2 == 0

#     half_dim = diffusion_step_embed_dim_in // 2
#     _embed = np.log(10000) / (half_dim - 1)
#     _embed = torch.exp(torch.arange(half_dim) * -_embed).to(device)
#     _embed = diffusion_steps * _embed
#     diffusion_step_embed = torch.cat((torch.sin(_embed),
#                                       torch.cos(_embed)), 1)

#     return diffusion_step_embed

class MambaTab(torch.nn.Module):
    """MambaTab network: Linear -> LayerNorm -> ReLU -> Mamba -> Linear.

    The input features are projected from ``d_in`` to
    ``intermediate_representation`` dimensions, normalised, passed through a
    ReLU and a single Mamba block, and finally projected back to ``d_in``.

    Attributes:
        add_noise: ``False`` when ``config["model_type"] == "CDTD"``, ``True``
            otherwise. It is only recorded here; ``forward`` performs the same
            computation for both values (no noise / diffusion-step embedding
            is added to the input).
        device: The ``device`` argument, stored as given.
    """

    def __init__(self, device, config, d_in: int,
                 intermediate_representation: int = 2048,
                 *, d_state: int = 32, d_conv: int = 4, expand: int = 2):
        """Build the layers.

        Args:
            device: Device handed to the Mamba block's constructor.
                NOTE(review): the Linear/LayerNorm layers are created on the
                default device, so callers presumably move the whole module
                with ``.to(device)`` afterwards -- confirm at the call site.
            config: Mapping that must contain the key ``"model_type"``.
            d_in: Size of the last dimension of the input (and of the output).
            intermediate_representation: Width of the hidden representation
                fed to the Mamba block (its ``d_model``).
            d_state: Mamba ``d_state`` hyperparameter.
            d_conv: Mamba ``d_conv`` hyperparameter.
            expand: Mamba ``expand`` hyperparameter.

        Raises:
            KeyError: If ``config`` has no ``"model_type"`` entry.
        """
        super().__init__()

        # Every model type except "CDTD" is flagged as add_noise=True.
        self.add_noise = config["model_type"] != "CDTD"
        self.device = device

        # Sub-modules are registered in the same order as before so that
        # state_dict keys and parameter-initialisation order are unchanged.
        self.linear_layer = torch.nn.Linear(d_in, intermediate_representation)
        self.relu = torch.nn.ReLU()
        self.layer_norm = torch.nn.LayerNorm(intermediate_representation)

        # The Mamba hyperparameters are exposed as keyword arguments so that
        # different configurations can be used without editing this class.
        self.mamba = Mamba(
            d_model=intermediate_representation,
            d_state=d_state,
            d_conv=d_conv,
            expand=expand,
            device=device,
        )
        self.output_layer = torch.nn.Linear(intermediate_representation, d_in)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the network on ``x``.

        Args:
            x: Tensor whose last dimension has size ``d_in``.
                Assumes a (batch, sequence, d_in) layout as expected by the
                Mamba block -- TODO confirm against the caller.

        Returns:
            Tensor with the same leading dimensions as ``x`` and a last
            dimension of size ``d_in``.
        """
        # The embedding is identical whether or not ``self.add_noise`` is set:
        # no noise-level embedding is added to the projected input.
        hidden = self.linear_layer(x)
        hidden = self.layer_norm(hidden)
        hidden = self.relu(hidden)
        hidden = self.mamba(hidden)
        return self.output_layer(hidden)