import torch

from equivariant_diffusion import utils as diffusion_utils
from geo_ldm.regressor import EGNN_regressor
from geo_ldm.latent_diffuser import EnLatentDiffusion


class DiffusionGuidanceModel(torch.nn.Module):
    """
    The Diffusion Guidance Model.

    Trains ``regressor`` to predict a target property from *noised* latents.
    A molecule ``(x, h)`` is encoded by ``ldm.vae``, the latent is diffused to a
    random timestep with the LDM's noise schedule, and the regressor is applied
    to ``(t, z_t)``.

    The LDM is stored through ``LDMWrapper`` so that it is not registered as a
    submodule; its parameters are therefore not part of ``self.parameters()``.
    """
    def __init__(
            self,
            regressor: EGNN_regressor,
            ldm: EnLatentDiffusion,
            n_dims: int,
            include_atomic_numbers=False,
            max_step: int = 1000,
            ):
        super().__init__()

        self.regressor = regressor
        # Wrap the LDM so it is not included in this module's registered submodules.
        self.ldm_wrapper = LDMWrapper(ldm, regressor.device)
        # Number of positional channels at the front of the latent.
        self.n_dims = n_dims
        self.include_atomic_numbers = include_atomic_numbers  # stored; not read in this class
        # Largest timestep sampled by forward() (inclusive).
        self.max_step = max_step
        # Per-label positive-class weights for the classifier head.
        # Set by compute_pos_weight(); None means unweighted BCE.
        self.pos_weight = None

        self.loss_l1 = torch.nn.L1Loss()

    def compute_pred(self, x, h, node_mask, edge_mask, context, t_lower=None, t_upper=None):
        """
        Encode ``(x, h)``, diffuse the latent to a random timestep and return
        the regressor's prediction on the noised latent.

        Args:
            x: node positions; ``x.size(0)`` is the batch size and
                ``x.size(1)`` the number of nodes.
            h: node features passed to the VAE encoder.
            node_mask: node mask forwarded to the encoder, sampler and regressor.
            edge_mask: edge mask forwarded to the encoder and regressor.
            context: conditioning forwarded to the VAE encoder only.
            t_lower: smallest integer timestep to sample (default 0).
            t_upper: largest integer timestep to sample, inclusive
                (default ``self.max_step``).

        Returns:
            The output of ``self.regressor._forward``. Per the original
            comment, presumably (bs,) or (bs, 512) for a classifier head.
        """
        if t_lower is None:
            t_lower = 0
        if t_upper is None:
            t_upper = self.max_step
        ldm = self.ldm_wrapper.ldm
        batch_size, n_nodes = x.size(0), x.size(1)

        # Encoder output: means and std-devs for positional and feature latents.
        z_x_mu, z_x_sigma, z_h_mu, z_h_sigma = ldm.vae.encode(x, h, node_mask, edge_mask, context=context)
        z_xh_mean = torch.cat([z_x_mu, z_h_mu], dim=2)
        diffusion_utils.assert_correctly_masked(z_xh_mean, node_mask)
        # Broadcast the positional sigma over all n_dims coordinate channels
        # (assumes z_x_sigma has a singleton last dim -- TODO confirm).
        z_xh_sigma = torch.cat([z_x_sigma.expand(-1, -1, self.n_dims), z_h_sigma], dim=2)
        z_xh = ldm.vae.sample_normal(z_xh_mean, z_xh_sigma, node_mask)
        diffusion_utils.assert_correctly_masked(z_xh, node_mask)
        diffusion_utils.assert_mean_zero_with_mask(z_xh[:, :, :self.n_dims], node_mask)

        # Sample one integer timestep per sample in [t_lower, t_upper] (inclusive).
        if t_lower == t_upper:
            t_int = torch.full((batch_size, 1), fill_value=t_lower, device=x.device).float()
        else:
            t_int = torch.randint(
                t_lower, t_upper + 1, size=(batch_size, 1), device=x.device).float()
        # Scale by the LDM's step count (maps [0, ldm.T] onto [0, 1]).
        t = t_int / ldm.T

        # Noise schedule at t.
        gamma_t = ldm.inflate_batch_array(ldm.gamma(t), x)
        alpha_t = ldm.alpha(gamma_t, x)
        sigma_t = ldm.sigma(gamma_t, x)

        # Sample z_t = alpha_t * z_xh + sigma_t * eps.
        eps = ldm.sample_combined_position_feature_noise(
            n_samples=batch_size, n_nodes=n_nodes, node_mask=node_mask)
        z_t = alpha_t * z_xh + sigma_t * eps
        diffusion_utils.assert_correctly_masked(z_t, node_mask)
        diffusion_utils.assert_mean_zero_with_mask(z_t[:, :, :self.n_dims], node_mask)

        # The normalized timestep is given to the regressor as an extra input;
        # context is not forwarded to the regressor.
        return self.regressor._forward(t, z_t, node_mask, edge_mask, context=None)

    def forward(self, x, h, adj_gt, node_mask=None, edge_mask=None, context=None, regressor_target=None):
        """
        Compute the guidance training loss for a batch.

        A timestep is sampled per sample from ``[0, self.max_step]``. The
        prediction on the noised latent is scored with multi-label BCE
        (weighted by ``self.pos_weight``) when the regressor has a classifier
        head, and with L1 loss otherwise.

        ``adj_gt`` is accepted for interface compatibility and is not used.

        Returns:
            Scalar loss tensor (both losses use mean reduction).
        """
        pred = self.compute_pred(x, h, node_mask, edge_mask, context=context, t_lower=0, t_upper=self.max_step)
        if self.regressor.classifier_head:
            return binary_cross_entropy_multihead(pred, regressor_target, pos_weight=self.pos_weight)
        # TODO: target normalization is not handled here.
        return self.loss_l1(pred, regressor_target)

    def compute_pos_weight(self, dataloader):
        """
        Set ``self.pos_weight`` to the per-bit ratio ``n_negative / n_positive``
        over ``dataloader.dataset.data['morgan_fingerprint']``.

        Rows are treated as samples and columns as fingerprint bits. A bit with
        no positives is counted as having one positive. Without this, its
        weight would be ``inf``, which makes BCEWithLogitsLoss return NaN even
        for all-zero targets.
        """
        morgan_fingerprint_all = dataloader.dataset.data['morgan_fingerprint']
        n_samples = morgan_fingerprint_all.size(0)
        n_pos = morgan_fingerprint_all.sum(0)
        n_neg = n_samples - n_pos
        safe_n_pos = torch.where(n_pos > 0, n_pos, torch.ones_like(n_pos))

        self.pos_weight = (n_neg / safe_n_pos).to(self.regressor.device)

class LDMWrapper:
    """
    Plain holder for a latent diffusion model, moved onto a given device.

    It intentionally does NOT inherit from ``torch.nn.Module``. Assigning an
    instance to an attribute of a module therefore does not register the
    wrapped model as a submodule. Other modules can still reach it through
    ``.ldm`` without it appearing among their registered children.
    """

    def __init__(self, ldm: EnLatentDiffusion, device):
        ldm_on_device = ldm.to(device)
        self.ldm = ldm_on_device

def binary_cross_entropy_multihead(pred, y, pos_weight=None):
    """
    Mean multi-label binary cross-entropy computed on logits.

    Args:
        pred: raw logits with the same shape as ``y``, e.g. (Bs, 512) for a
            512-bit fingerprint head.
        y: binary targets as a float tensor.
        pos_weight: optional per-label weight applied to positive targets,
            broadcast against ``y`` (see
            ``DiffusionGuidanceModel.compute_pos_weight``). ``None`` gives
            unweighted BCE.

    Returns:
        Scalar tensor: the loss averaged over all elements.
    """
    return torch.nn.functional.binary_cross_entropy_with_logits(
        pred, y, pos_weight=pos_weight, reduction='mean')
