import time

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from torch_ema import ExponentialMovingAverage
from tqdm import tqdm
from model_utils.CDTD import CatEmbedding, FourierFeatures, WeightNetwork, Timewarp_Logistic, low_discrepancy_sampler


class MixedTypeDiffusion(nn.Module):
    """Diffusion objective and anomaly scoring for continuous features.

    The module noises the inputs with per-feature standard deviations drawn
    through a learned time-warping CDF, runs the backbone ``model`` on the
    noisy inputs and computes an EDM-weighted MSE loss together with the
    auxiliary time-warping and weight-network losses.

    When ``config["backbone_model"] == "TabM"`` the backbone may return
    predictions with an extra ensemble axis, i.e.
    ``[batch, number_models, number_features]``.  All per-sample tensors are
    aligned to that axis based on the actual shape of the predictions.

    Args:
        config: Mapping read for the keys ``"backbone_model"`` and
            ``"anom_score"`` (``"simple"``, ``"precondition"`` or ``"direct"``).
        device: Device on which the noise levels ``u`` are sampled.
        dim: Stored as ``self.dim``; not used inside this class.
        model: Backbone called as ``model(x_cont_t.unsqueeze(1), c_noise)``.
        num_features: Number of continuous features.
        sigma_data_cont: Data standard deviation used by the EDM weighting
            and preconditioning.
        sigma_min_cont: Lower noise bound handed to the time-warping CDF.
        sigma_max_cont: Upper noise bound handed to the time-warping CDF.
        timewarp_type: Forwarded to ``Timewarp_Logistic``.
        timewarp_weight_low_noise: Forwarded to ``Timewarp_Logistic`` as
            ``weight_low_noise``.
    """

    def __init__(
        self,
        config,
        device,
        dim,
        model,
        num_features,
        sigma_data_cont,
        sigma_min_cont,
        sigma_max_cont,
        timewarp_type="bytype",
        timewarp_weight_low_noise=1.0,
    ):
        super().__init__()
        self.device = device
        self.config = config
        self.model = model
        self.dim = dim
        self.num_features = num_features
        self.num_cont_features = num_features

        self.register_buffer("sigma_data_cont", torch.tensor(sigma_data_cont))

        self.weight_network = WeightNetwork(1024)

        # timewarping
        self.timewarp_type = timewarp_type
        # NOTE: these two are plain attributes, not buffers, so they do not
        # follow ``.to(device)`` and are not part of the state dict.
        self.sigma_min_cont = torch.tensor(sigma_min_cont)
        self.sigma_max_cont = torch.tensor(sigma_max_cont)

        # per-feature sigma boundaries for transforming sigmas to [0,1]
        sigma_min = torch.tensor(sigma_min_cont).repeat(self.num_cont_features)
        sigma_max = torch.tensor(sigma_max_cont).repeat(self.num_cont_features)
        self.register_buffer("sigma_max", sigma_max)
        self.register_buffer("sigma_min", sigma_min)

        self.timewarp_cdf = Timewarp_Logistic(
            self.timewarp_type,
            self.num_cont_features,
            sigma_min,
            sigma_max,
            weight_low_noise=timewarp_weight_low_noise,
            decay=0.0,
        )

    def _has_ensemble_axis(self, reference, base):
        """Return True if ``reference`` carries a TabM ensemble axis that ``base`` lacks."""
        return (
            self.config["backbone_model"] == "TabM"
            and reference.dim() == base.dim() + 1
        )

    def _expand_to_ensemble(self, tensor, reference):
        """Expand ``tensor`` along a new axis 1 to the ensemble size of ``reference``.

        ``tensor`` is returned unchanged when ``reference`` has no extra
        ensemble axis (or the backbone is not TabM).
        """
        if self._has_ensemble_axis(reference, tensor):
            return tensor.unsqueeze(1).expand(-1, reference.shape[1], -1)
        return tensor

    def _anom_score_mode(self):
        """Return ``config["anom_score"]`` after checking it is a supported mode.

        Raises:
            ValueError: If the configured mode is not one of ``"simple"``,
                ``"precondition"`` or ``"direct"``.
        """
        mode = self.config["anom_score"]
        if mode not in ("simple", "precondition", "direct"):
            raise ValueError("Anomaliescore not defined for EDM Diffusionmodel")
        return mode

    def diffusion_loss(self, x_cont_0, cont_preds):
        """Return the element-wise squared error between predictions and clean data.

        Args:
            x_cont_0: Clean continuous features.
            cont_preds: Model predictions; for TabM they may carry an extra
                ensemble axis at dim 1, in which case ``x_cont_0`` is expanded
                to match it.

        Returns:
            Tensor with the shape of ``cont_preds`` holding the squared errors.
        """
        # Align on the actual prediction shape rather than on the training
        # flag, so the loss can also be evaluated when predictions have no
        # ensemble axis, and for any ensemble size.
        x_cont_0 = self._expand_to_ensemble(x_cont_0, cont_preds)

        assert cont_preds.shape == x_cont_0.shape, (
            f"prediction shape {tuple(cont_preds.shape)} does not match "
            f"target shape {tuple(x_cont_0.shape)}"
        )

        # MSE loss over numerical features
        return (cont_preds - x_cont_0) ** 2

    def add_noise(self, x_cont_0, sigma):
        """Return ``x_cont_0`` perturbed by Gaussian noise scaled by ``sigma``."""
        return x_cont_0 + torch.randn_like(x_cont_0) * sigma

    def loss_fn(self, x_cont, u=None):
        """Compute the training loss for a batch of continuous features.

        Args:
            x_cont: Clean input batch; its first dimension is the batch size.
            u: Optional noise levels; drawn with ``low_discrepancy_sampler``
                when omitted.

        Returns:
            Tuple ``(train_loss, sigma)`` where ``train_loss`` is the sum of
            the mean weighted MSE, the mean time-warping loss and the mean
            weight-network loss, and ``sigma`` holds the noise standard
            deviations used, with shape ``(batch, num_features)``.
        """
        batch = x_cont.shape[0]

        # draw u and convert to standard deviations for noise
        with torch.no_grad():
            if u is None:
                u = low_discrepancy_sampler(batch, device=self.device)  # (B,)
            sigma = self.timewarp_cdf(u, invert=True).detach().to(torch.float32)
            u = u.to(torch.float32)
            assert sigma.shape == (batch, self.num_features)

        x_cont_t = self.add_noise(x_cont, sigma)
        cont_preds = self.precondition(x_cont_t, u, sigma)
        mse_losses = self.diffusion_loss(x_cont, cont_preds)

        # compute EDM weight
        cont_weight = (sigma**2 + self.sigma_data_cont**2) / (
            (sigma * self.sigma_data_cont) ** 2 + 1e-7
        )
        c_noise = torch.log(u + 1e-8) * 0.25
        time_reweight = self.weight_network(c_noise).unsqueeze(1)

        if self._has_ensemble_axis(mse_losses, sigma):
            # add a broadcastable ensemble axis
            # (TabM [batch, number_models, number features])
            cont_weight = cont_weight.unsqueeze(1)
            time_reweight = time_reweight.unsqueeze(1)

        weighted_calibrated = cont_weight * mse_losses

        timewarp_loss = self.timewarp_cdf.loss_fn(
            sigma.detach(), mse_losses.detach()
        )
        # NOTE(review): ``time_reweight`` keeps a trailing singleton axis while
        # ``mean(1)`` drops an axis; if the weight network returns shape (B,),
        # this difference broadcasts to (B, B) - confirm this is intended.
        weightnet_loss = (
            time_reweight.exp() - weighted_calibrated.detach().mean(1)
        ) ** 2
        reweighted = weighted_calibrated / time_reweight.exp().detach()

        train_loss = (
            reweighted.mean() + timewarp_loss.mean() + weightnet_loss.mean()
        )

        return train_loss, sigma

    def calc_anom_score(self, x_cont, u=None):
        """Compute per-element anomaly scores for a batch.

        The inputs are noised exactly as in ``loss_fn`` and passed through
        ``precondition``.  The score depends on ``config["anom_score"]``:

        * ``"simple"``: ``|prediction - x| / sigma**2``
        * ``"precondition"``: ``(prediction - x) ** 2``
        * ``"direct"``: the prediction itself

        Args:
            x_cont: Clean input batch; its first dimension is the batch size.
            u: Optional noise levels; drawn with ``low_discrepancy_sampler``
                when omitted.

        Returns:
            Tensor of scores with the shape of the model predictions.

        Raises:
            ValueError: If ``config["anom_score"]`` is not a supported mode.
        """
        mode = self._anom_score_mode()
        batch = x_cont.shape[0]

        # draw u and convert to standard deviations for noise
        with torch.no_grad():
            if u is None:
                u = low_discrepancy_sampler(batch, device=self.device)  # (B,)
            sigma = self.timewarp_cdf(u, invert=True).detach().to(torch.float32)
            u = u.to(torch.float32)
            assert sigma.shape == (batch, self.num_features)

        x_cont_t = self.add_noise(x_cont, sigma)
        cont_preds = self.precondition(x_cont_t, u, sigma)

        if mode == "direct":
            return cont_preds

        # align the clean data with a possible ensemble axis of the predictions
        x_cont_0 = self._expand_to_ensemble(x_cont, cont_preds)
        if mode == "simple":
            sigma_ref = self._expand_to_ensemble(sigma, cont_preds)
            return torch.abs((cont_preds - x_cont_0) / sigma_ref**2)
        return (cont_preds - x_cont_0) ** 2

    def precondition(self, x_cont_t, u, sigma):
        """Run the backbone on noisy inputs and optionally apply EDM output scaling.

        Follows the preconditioning of "Elucidating the Design Space of
        Diffusion-Based Generative Models" (EDM) for continuous features.
        The skip/output scaling is applied only when
        ``config["anom_score"] == "precondition"``; for ``"simple"`` and
        ``"direct"`` the raw network output is returned.

        Args:
            x_cont_t: Noisy continuous features.
            u: Noise levels, turned into the conditioning signal
                ``log(u + 1e-8) * 0.25 * 1000``.
            sigma: Noise standard deviations matching ``x_cont_t``.

        Returns:
            Predictions with the shape of the (squeezed) network output.

        Raises:
            ValueError: If ``config["anom_score"]`` is not a supported mode.
        """
        # validate before the forward pass so a bad config fails fast
        mode = self._anom_score_mode()

        # NOTE(review): the EDM input scaling c_in = 1 / sqrt(sigma_data^2 +
        # sigma^2) is not applied to the network input here; the previously
        # computed but unused value was removed.
        c_noise = torch.log(u + 1e-8) * 0.25 * 1000

        cont_preds = self.model(
            torch.unsqueeze(x_cont_t, dim=1),
            c_noise,
        )
        cont_preds = torch.squeeze(cont_preds, dim=1)

        # match a possible ensemble axis
        # (TabM [batch, number_models, number features])
        x_cont_t = self._expand_to_ensemble(x_cont_t, cont_preds)
        assert cont_preds.shape == x_cont_t.shape, (
            f"prediction shape {tuple(cont_preds.shape)} does not match "
            f"input shape {tuple(x_cont_t.shape)}"
        )

        if mode != "precondition":
            return cont_preds

        # apply preconditioning to continuous features
        variance_sum = sigma**2 + self.sigma_data_cont**2
        c_skip = self.sigma_data_cont**2 / variance_sum
        c_out = sigma * self.sigma_data_cont / variance_sum.sqrt()
        if x_cont_t.dim() == sigma.dim() + 1:
            # broadcast the per-sample coefficients over the ensemble axis
            c_skip = c_skip.unsqueeze(1)
            c_out = c_out.unsqueeze(1)

        return c_skip * x_cont_t + c_out * cont_preds

