

import torch.nn as nn
import torch
import numpy as np
import model_utils.Denoiser
import rtdl_num_embeddings as rtdl_embed


def swish(x):
    """Apply the Swish (a.k.a. SiLU) activation element-wise: ``x * sigmoid(x)``."""
    gate = torch.sigmoid(x)
    return x * gate

def calc_diffusion_step_embedding(device, diffusion_steps, diffusion_step_embed_dim_in):
    """Compute a sinusoidal embedding of diffusion steps.

    Frequencies decay geometrically from 1 to 1/10000 over ``half_dim`` entries.
    Each step ``t`` maps to ``[sin(t * f_0..f_{h-1}), cos(t * f_0..f_{h-1})]``.

    Args:
        device: Device on which the embedding is computed and returned. (The
            previous implementation ignored this and always used "cuda:0".)
        diffusion_steps: Tensor of steps. The result is concatenated along
            dim 1, so this must be at least 2-D; presumably shape (B, 1) so it
            broadcasts against the frequency vector -- confirm with callers.
        diffusion_step_embed_dim_in: Requested embedding width. The output has
            ``2 * (diffusion_step_embed_dim_in // 2)`` columns.

    Returns:
        Tensor on ``device`` with the sin half followed by the cos half along dim 1.

    Raises:
        ValueError: if ``diffusion_step_embed_dim_in // 2 < 2``. With
            ``half_dim == 1`` the frequency spacing divides by zero and yields NaNs.
    """
    half_dim = diffusion_step_embed_dim_in // 2
    if half_dim < 2:
        raise ValueError(
            "diffusion_step_embed_dim_in must be >= 4, got "
            f"{diffusion_step_embed_dim_in}"
        )

    diffusion_steps = diffusion_steps.to(device)

    # Log-spaced frequencies: exp(-k * log(10000) / (half_dim - 1)), k = 0..half_dim-1.
    _embed = np.log(10000) / (half_dim - 1)
    _embed = torch.exp(torch.arange(half_dim, device=device) * -_embed)
    _embed = diffusion_steps * _embed

    return torch.cat((torch.sin(_embed), torch.cos(_embed)), 1)

class MLP2048(nn.Module):
    """MLP denoiser: a linear projection ``d_in -> dim_t`` followed by a SiLU MLP back to ``d_in``.

    If ``config["model_type"] == "SDE_EDM"`` and ``config["sigma_per_feature"]`` is
    truthy, the input first goes through a learnable per-feature noise schedule.
    The network output is then combined EDM-style as ``c_skip * x + c_out * F(x)``.

    Args:
        device: Forwarded to the EDM noise schedule; stored on ``self.device``.
        config: Dict. Always reads ``"model_type"``. For ``"SDE_EDM"`` it also reads
            ``"sigma_per_feature"`` and, when that is truthy, ``"sigma_data"``.
        bins, d_embedding: Accepted for interface compatibility; unused here.
        d_in: Number of input/output features.
        dim_t: Hidden width of the MLP.
        diffusion_step_embed_dim_in, diffusion_step_embed_dim_mid: Sizes of
            ``fc_t1``/``fc_t2``. Those layers are not used by ``forward``.
    """

    def __init__(self, device, config, bins, d_embedding, d_in = 6, dim_t = 2048, diffusion_step_embed_dim_in=128, diffusion_step_embed_dim_mid=512):
        super().__init__()
        self.dim_t = dim_t
        self.diffusion_step_embed_dim_in = diffusion_step_embed_dim_in
        self.device = device
        self.config = config

        # edm_schedule: whether forward() applies the learnable EDM noise schedule.
        if config["model_type"] == "SDE_EDM" and config["sigma_per_feature"]:
            # Import the submodule explicitly. The file only imports
            # model_utils.Denoiser, so the attribute model_utils.SDE_EDM would
            # otherwise exist only if some other module had imported it.
            import model_utils.SDE_EDM
            self.edm_schedule = True
            self.noise_schedule = model_utils.SDE_EDM.LernableNoiseLoss(device=device, sigma_data=config["sigma_data"], num_features=d_in)
        else:
            self.edm_schedule = False

        # add_noise: False only for CDTD. Kept for callers that read it; it does
        # not change what forward() computes.
        self.add_noise = config["model_type"] != "CDTD"

        self.mlp = nn.Sequential(
            nn.Linear(dim_t, dim_t * 2),
            nn.SiLU(),
            nn.Linear(dim_t * 2, dim_t * 2),
            nn.SiLU(),
            nn.Linear(dim_t * 2, dim_t),
            nn.SiLU(),
            nn.Linear(dim_t, d_in),
        )
        self.proj = nn.Linear(d_in, dim_t)

        # Diffusion-step embedding layers. They are unused in forward(), but
        # kept so existing checkpoints' state_dicts still load.
        self.fc_t1 = nn.Linear(diffusion_step_embed_dim_in, diffusion_step_embed_dim_mid)
        self.fc_t2 = nn.Linear(diffusion_step_embed_dim_mid, dim_t)

    def forward(self, x):
        """Run the denoiser.

        Args:
            x: Input tensor whose last dimension is ``d_in``.

        Returns:
            If the EDM schedule is active: a tuple ``(D_x, weight, x_orig, sigma)``,
            where ``D_x = c_skip * x_noisy + c_out * F(x_noisy)`` and the other
            three values come from the noise schedule. Otherwise: the raw MLP
            output ``F(x)``.
        """
        if self.edm_schedule:
            x, _noise_labels, c_skip, c_out, weight, x_orig, sigma = self.noise_schedule(x)

        # The add_noise and CDTD paths were identical (proj then MLP). A
        # noise-label embedding term was once intended here but is disabled.
        hidden = self.proj(x)
        F2_x = self.mlp(hidden)

        if self.edm_schedule:
            D_x = c_skip * x + c_out * F2_x.to(torch.float32)
            return D_x, weight, x_orig, sigma
        return F2_x