import copy

import torch
#from torchdyn.core import NeuralODE
import numpy as np
from tqdm import tqdm

from equivariant_diffusion.en_diffusion import EnVariationalDiffusion
from equivariant_diffusion import utils as diffusion_utils
from geo_ldm.vae import EnHierarchicalVAE
from optimization.ode_solver import torchdyn_wrapper
from qm9.utils import prepare_context
from visualization.kabsch import align


class EnLatentDiffusion(EnVariationalDiffusion):
    """
    The E(n) Latent Diffusion Module.
    """
    def __init__(self, **kwargs):
        """Build the latent diffusion model and attach its first-stage VAE.

        Keyword Args:
            vae: first-stage model; required. It is removed from ``kwargs``
                and handed to ``self.instantiate_first_stage``.
            trainable_ae: optional flag (default ``False``), stored as
                ``self.trainable_ae``.
            **kwargs: every remaining keyword is forwarded unchanged to the
                parent class constructor.
        """
        first_stage = kwargs.pop('vae')
        ae_is_trainable = kwargs.pop('trainable_ae', False)

        # The parent constructor must not see the two keys popped above.
        super().__init__(**kwargs)

        # The flag is set before the first stage is instantiated, in the
        # same order as before.
        self.trainable_ae = ae_is_trainable
        self.instantiate_first_stage(first_stage)

    def normalize(self, x, h, node_mask):
        x = x / self.norm_values[0]
        delta_log_px = -self.subspace_dimensionality(node_mask) * np.log(self.norm_values[0])

        h = (h - self.norm_biases[1]) / self.norm_values[1]
        h = h * node_mask

        return x, h, delta_log_px
    
    def unnormalize_z(self, z, node_mask):
        # Overwrite the unnormalize_z function to do nothing (for sample_chain). 

        # Parse from z
        x, h = z[:, :, 0:self.n_dims], z[:, :, self.n_dims:]

        x = x * self.norm_values[0]
        h = h * self.norm_values[1] + self.norm_biases[1]
        h = h * node_mask

        # Unnormalize
        # x, h_cat, h_int = self.unnormalize(x, h_cat, h_int, node_mask)
        output = torch.cat([x, h], dim=2)
        return output
    
    def log_constants_p_h_given_z0(self, h, node_mask):
        """Computes p(h|z0)."""
        batch_size = h.size(0)

        n_nodes = node_mask.squeeze(2).sum(1)  # N has shape [B]
        assert n_nodes.size() == (batch_size,)
        degrees_of_freedom_h = n_nodes * self.n_dims

        zeros = torch.zeros((h.size(0), 1), device=h.device)
        gamma_0 = self.gamma(zeros)

        # Recall that sigma_x = sqrt(sigma_0^2 / alpha_0^2) = SNR(-0.5 gamma_0).
        log_sigma_x = 0.5 * gamma_0.view(batch_size)

        return degrees_of_freedom_h * (- log_sigma_x - 0.5 * np.log(2 * np.pi))

    def sample_p_xh_given_z0(self, z0, node_mask, edge_mask, context, fix_noise=False):
        """Samples x ~ p(x|z0)."""
        zeros = torch.zeros(size=(z0.size(0), 1), device=z0.device)
        gamma_0 = self.gamma(zeros)
        # Computes sqrt(sigma_0^2 / alpha_0^2)
        sigma_x = self.SNR(-0.5 * gamma_0).unsqueeze(1)
        with torch.no_grad():
            net_out = self.phi(z0, zeros, node_mask, edge_mask, context)

        # Compute mu for p(zs | zt).
        mu_x = self.compute_x_pred(net_out, z0, gamma_0)
        xh = self.sample_normal(mu=mu_x, sigma=sigma_x, node_mask=node_mask, fix_noise=fix_noise)

        x = xh[:, :, :self.n_dims]
        h = xh[:, :, self.n_dims:]

        # h_int = z0[:, :, -1:] if self.include_charges else torch.zeros(0).to(z0.device)
        # x, h_cat, h_int = self.unnormalize(x, z0[:, :, self.n_dims:-1], h_int, node_mask)

        # h_cat = F.one_hot(torch.argmax(h_cat, dim=2), self.num_classes) * node_mask
        # h_int = torch.round(h_int).long() * node_mask

        # Make the data structure compatible with the EnVariationalDiffusion sample() and sample_chain().
        # h = {'integer': xh[:, :, self.n_dims:], 'categorical': torch.zeros(0).to(xh)}
        
        return x, h
    
    def log_pxh_given_z0_without_constants(
            self, x, h, z_t, gamma_0, eps, net_out, node_mask, epsilon=1e-10):

        # Computes the error for the distribution N(latent | 1 / alpha_0 z_0 + sigma_0/alpha_0 eps_0, sigma_0 / alpha_0),
        # the weighting in the epsilon parametrization is exactly '1'.
        log_pxh_given_z_without_constants = -0.5 * self.compute_error(net_out, gamma_0, eps)

        # Combine log probabilities for x and h.
        log_p_xh_given_z = log_pxh_given_z_without_constants

        return log_p_xh_given_z
    
    def forward(self, x, h, adj_gt, node_mask=None, edge_mask=None, context=None):
        """
        Computes the training / evaluation objective of the latent diffusion model.

        The input is encoded with ``self.vae``, a latent is sampled around the
        encoder mean with the fixed standard deviation sigma_0 (sigma at t = 0)
        and detached, and the diffusion loss is computed on that latent with
        ``self.compute_loss``. In training mode a single forward pass is used
        (``t0_always=False``); otherwise ``t0_always=True``.

        Args:
            x: input positions, passed to ``self.vae.encode``.
            h: input features; a mapping that must at least provide the keys
                'atomic_numbers_one_hot' and 'formal_charges_one_hot'.
            adj_gt: ground-truth adjacency, only used for the reconstruction
                loss when ``self.trainable_ae`` is set.
            node_mask: node mask forwarded to the encoder, the loss and the decoder.
            edge_mask: edge mask forwarded to the encoder, the loss and the decoder.
            context: optional conditioning forwarded to the encoder, the loss
                and the decoder.

        Returns:
            ``loss_ld.mean(0) + loss_recon + neg_log_constants.mean(0)``, where
            ``loss_recon`` is 0 unless ``self.trainable_ae`` is set and the
            constants are zeroed when training with ``loss_type == 'l2'``.
            NaNs in any of the terms are reported with ``print`` but do not
            raise.
        """

        # Encode data to latent space.
        # TODO: remove sigmas, not used
        z_x_mu, z_x_sigma, z_h_mu, z_h_sigma = self.vae.encode(x, h, node_mask, edge_mask, context)
        # Compute fixed sigma values.
        t_zeros = torch.zeros(size=(x.size(0), 1), device=x.device)
        gamma_0 = self.inflate_batch_array(self.gamma(t_zeros), x)
        sigma_0 = self.sigma(gamma_0, x)

        # Infer latent z.
        z_xh_mean = torch.cat([z_x_mu, z_h_mu], dim=2)
        diffusion_utils.assert_correctly_masked(z_xh_mean, node_mask)
        # The encoder sigmas are ignored; sigma_0 of the noise schedule is used instead.
        z_xh_sigma = sigma_0
        # z_xh_sigma = torch.cat([z_x_sigma.expand(-1, -1, 3), z_h_sigma], dim=2)
        z_xh = self.vae.sample_normal(z_xh_mean, z_xh_sigma, node_mask)
        # z_xh = z_xh_mean
        z_xh = z_xh.detach()  # Always keep the encoder fixed.
        diffusion_utils.assert_correctly_masked(z_xh, node_mask)

        # Split the latent into its first n_dims channels and the rest.
        z_x = z_xh[:, :, :self.n_dims]
        z_h = z_xh[:, :, self.n_dims:]
        diffusion_utils.assert_mean_zero_with_mask(z_x, node_mask)
        # NOTE: z_h is passed to compute_loss() as a tensor, not as the dict
        # shown below (the EnVariationalDiffusion compute_loss() was adapted).
        # z_h = {'categorical': torch.zeros(0).to(z_h), 'integer': z_h}

        # Normalize data, take into account volume change in x.
        if self.norm_values is not None:
            z_x, z_h, delta_log_px = self.normalize(z_x, z_h, node_mask)

            # Reset delta_log_px if not vlb objective.
            # NOTE(review): delta_log_px is not read again in this method, so it
            # does not contribute to the returned loss - confirm this is intended.
            if self.training and self.loss_type == 'l2':
                delta_log_px = torch.zeros_like(delta_log_px)

        if self.training:
            # Only 1 forward pass when t0_always is False.
            loss_ld, loss_dict = self.compute_loss(z_x, z_h, node_mask, edge_mask, context, t0_always=False)
        else:
            # Less variance in the estimator, costs two forward passes.
            loss_ld, loss_dict = self.compute_loss(z_x, z_h, node_mask, edge_mask, context, t0_always=True)

        # NaNs are only reported, never raised.
        if torch.any(torch.isnan(loss_ld)):
            print(f"Detected NaN in latent_diffuser.forward. loss_ld={loss_ld}")

        # The _constants_ depending on sigma_0 from the
        # cross entropy term E_q(z0 | x) [log p(x | z0)].
        neg_log_constants = -self.log_constants_p_h_given_z0(
            torch.cat([h['atomic_numbers_one_hot'], h['formal_charges_one_hot']], dim=2), node_mask)
        if torch.any(torch.isnan(neg_log_constants)):
            print(f"Detected NaN in latent_diffuser.forward. neg_log_constants={neg_log_constants}")

        # Reset constants during training with l2 loss.
        if self.training and self.loss_type == 'l2':
            neg_log_constants = torch.zeros_like(neg_log_constants)

        # Compute reconstruction loss.
        if self.trainable_ae:
            if self.joint_training:
                # Decode the latent prediction reported by compute_loss().
                xh_pred = loss_dict['xh_pred']  # not detached; append .detach() to train only the decoder
                if self.norm_values is not None:
                    xh_pred = self.unnormalize_z(xh_pred, node_mask)
            else:
                # Decode the (detached) sampled latent itself.
                xh_pred = z_xh

            # Decoder output (reconstruction).
            adj_recon, h_recon = self.vae.decoder._forward(xh_pred, node_mask, edge_mask, context)
            loss_recon = self.vae.compute_2d_reconstruction_error(adj_recon, h_recon, adj_gt, h, node_mask, edge_mask)

            if self.joint_training:
                loss_recon = self.lambda_joint_loss * loss_recon
        else:
            loss_recon = 0

        # loss_recon is already scalar
        loss_ld = loss_ld.mean(0)
        neg_log_constants = neg_log_constants.mean(0)

        neg_log_pxh = loss_ld + loss_recon + neg_log_constants

        if torch.any(torch.isnan(neg_log_pxh)):
            print(f"Detected NaN in latent_diffuser.forward. neg_log_pxh={neg_log_pxh}")

        return neg_log_pxh
    
    #@torch.no_grad()
    def sample(self, n_samples, n_nodes, node_mask, edge_mask, context, fix_noise=False):
        """
        Draw samples from the generative model and decode them with the VAE.

        The parent class sampler produces the latent pair ``(z_x, z_h)``. The
        concatenated latent is unnormalized when ``self.norm_values`` is set and
        then decoded by ``self.vae.decode`` without gradients.

        Returns:
            ``(x, h)`` where ``x`` is the sampled ``z_x`` (as returned by the
            parent sampler) and ``h`` is a dict with the keys
            'adjacency_matrices', 'atom_types', 'formal_charges' and 'z_h'.
        """
        latent_x, latent_h = super().sample(n_samples, n_nodes, node_mask, edge_mask, context, fix_noise)

        latent = torch.cat([latent_x, latent_h], dim=2)
        diffusion_utils.assert_correctly_masked(latent, node_mask)

        if self.norm_values is not None:
            latent = self.unnormalize_z(latent, node_mask)

        # Only the decoding is wrapped in no_grad; the sampler above is not.
        with torch.no_grad():
            adjacency, atom_types, formal_charges = self.vae.decode(latent, node_mask, edge_mask, context)

        decoded = {
            'adjacency_matrices': adjacency,
            'atom_types': atom_types,
            'formal_charges': formal_charges,
            'z_h': latent_h,
        }
        return latent_x, decoded

    def sample_z_t_given_z_0(self, x, h, t, node_mask, n_samples):
        """Noise the clean latent (x, h) to step ``t``: z_t ~ q(z_t | x, h).

        Instead of starting from pure noise, the given latents are noised,
        which is what editing / scaffolding needs.

        Args:
            x, h: the two latent parts; concatenated along dim 2.
            t: integer step; divided by ``self.T`` before it is given to
                ``self.gamma``.
            node_mask: mask used when drawing the noise and in the mean-zero
                assertion.
            n_samples: batch size used for the time tensor.

        Returns:
            ``alpha_t * [x, h] + sigma_t * eps``.
        """
        t_normalized = torch.full((n_samples, 1), fill_value=t, device=x.device).float() / self.T
        gamma_t = self.inflate_batch_array(self.gamma(t_normalized), x)
        alpha_t = self.alpha(gamma_t, x)
        sigma_t = self.sigma(gamma_t, x)

        # eps for zt ~ Normal(alpha_t x, sigma_t)
        noise = self.sample_combined_position_feature_noise(
            n_samples=x.size(0), n_nodes=x.size(1), node_mask=node_mask)

        clean = torch.cat([x, h], dim=2)
        noised = alpha_t * clean + sigma_t * noise

        diffusion_utils.assert_mean_zero_with_mask(noised[:, :, :self.n_dims], node_mask)
        return noised

    def sample_p_zt_given_zs(self, s, t, zs, node_mask, edge_mask, context, fix_noise=False):
        """Forward (noising) step: sample zt given zs. Only used during sampling.

        The sample is drawn with mean ``alpha_t_given_s * zs`` and standard
        deviation ``sigma_t_given_s``, both obtained from
        ``self.sigma_and_alpha_t_given_s``. ``edge_mask`` and ``context`` are
        accepted for signature symmetry with the reverse step but are not read
        here.

        Returns:
            ``zt`` whose first ``self.n_dims`` channels had their masked mean
            removed.
        """
        gamma_s = self.gamma(s)
        gamma_t = self.gamma(t)

        # The variance returned first is not needed for sampling.
        _, std_t_given_s, scale_t_given_s = \
            self.sigma_and_alpha_t_given_s(gamma_t, gamma_s, zs)

        diffusion_utils.assert_mean_zero_with_mask(zs[:, :, :self.n_dims], node_mask)

        zt = self.sample_normal(scale_t_given_s * zs, std_t_given_s, node_mask, fix_noise)

        # Project down to avoid numerical runaway of the center of gravity.
        centered = diffusion_utils.remove_mean_with_mask(zt[:, :, :self.n_dims], node_mask)
        remainder = zt[:, :, self.n_dims:]
        return torch.cat([centered, remainder], dim=2)

    def get_scaffolding_schedule(self, t_T, jump_len, jump_n_sample):
        jumps = {}
        for j in range(0, t_T, jump_len):
            jumps[j] = jump_n_sample - 1
        t = t_T
        ts = []
        while t >= 1:
            t = t-1
            ts.append(t+1)
            if jumps.get(t, 0) > 0:
                jumps[t] = jumps[t] - 1
                for _ in range(jump_len):
                    ts.append(t)
                    t = t + 1
        ts.append(0)
        return ts

    @torch.no_grad()
    def complete_scaffold(self, scaffold_graph, context=None, fix_noise=False, add_nodes=None, resampling_times=1, use_jumps=False):
        """Grow new nodes around a fixed scaffold by inpainting in latent space.

        The scaffold is encoded with ``self.vae`` and its latent is padded with
        room for the new nodes. The reverse diffusion then runs from ``t = T``
        down to 0. After every denoising step the scaffold rows of the state
        are overwritten with a freshly noised copy of the scaffold latent
        (aligned to the current state with ``align``), and the positions are
        re-centred. The final latent is decoded by the VAE.

        Args:
            scaffold_graph: mapping with the keys 'positions',
                'atomic_numbers_one_hot', 'formal_charges_one_hot', 'atom_mask'
                and 'edge_mask'. The positions must already be mean-centred
                under the atom mask (asserted).
            context: optional conditioning. When given, the context of the
                first node is repeated for the padded nodes.
            fix_noise: forwarded to the samplers; when True the initial noise
                is requested with a batch size of 1.
            add_nodes: per-sample number of nodes to add (indexable, with a
                ``.max()``), or None to add no nodes.
            resampling_times: number of denoise / re-noise rounds per step
                when ``use_jumps`` is False.
            use_jumps: when True, follow ``get_scaffolding_schedule(self.T,
                jump_len=200, jump_n_sample=10)`` instead of the plain
                descending schedule.

        Returns:
            ``(x, h, node_mask)`` when ``add_nodes`` is given, otherwise
            ``(x, h)``. ``x`` is the position part of the final latent and
            ``h`` is a dict with the keys 'adjacency_matrices', 'atom_types',
            'formal_charges' and 'z_h'.
        """
        x = scaffold_graph['positions']
        h = {'atomic_numbers_one_hot': scaffold_graph['atomic_numbers_one_hot'], 'formal_charges_one_hot': scaffold_graph['formal_charges_one_hot']}
        node_mask = scaffold_graph['atom_mask'].unsqueeze(2)
        edge_mask = scaffold_graph['edge_mask']

        diffusion_utils.assert_mean_zero_with_mask(x, node_mask)
        n_samples = x.size(0)
        n_nodes = x.size(1)

        # Encode data to latent space; the encoder sigmas are not used.
        z_x_mu, _, z_h_mu, _ = self.vae.encode(x, h, node_mask, edge_mask, context)
        # Compute fixed sigma values.
        t_zeros = torch.zeros(size=(x.size(0), 1), device=x.device)
        gamma_0 = self.inflate_batch_array(self.gamma(t_zeros), x)
        sigma_0 = self.sigma(gamma_0, x)

        # Infer latent z.
        z_xh_mean = torch.cat([z_x_mu, z_h_mu], dim=2)
        diffusion_utils.assert_correctly_masked(z_xh_mean, node_mask)
        z_xh = self.vae.sample_normal(z_xh_mean, sigma_0, node_mask)
        diffusion_utils.assert_correctly_masked(z_xh, node_mask)

        scaffold_x = z_xh[:, :, :self.n_dims]
        scaffold_h = z_xh[:, :, self.n_dims:]
        diffusion_utils.assert_mean_zero_with_mask(scaffold_x, node_mask)

        # add_nodes=None is the documented default: treat it as "add nothing"
        # instead of failing on ``None.max()``.
        if add_nodes is None:
            extra_nodes = torch.zeros(n_samples, dtype=torch.long, device=scaffold_x.device)
        else:
            extra_nodes = add_nodes

        # TODO: fix when working on a batch of different scaffolds, this will fail
        max_extra_nodes = extra_nodes.max().item()
        scaffold_x = torch.cat((scaffold_x, torch.zeros(n_samples, max_extra_nodes, self.n_dims, device=scaffold_x.device)), dim=1)
        scaffold_h = torch.cat((scaffold_h, torch.zeros(n_samples, max_extra_nodes, self.in_node_nf, device=scaffold_h.device)), dim=1)
        node_mask = torch.cat((node_mask, torch.zeros(n_samples, max_extra_nodes, 1, device=node_mask.device)), dim=1)
        n_nodes += max_extra_nodes
        if context is not None:
            one_context = context[:, 0:1, :]
            context = torch.cat((context, one_context.repeat(1, max_extra_nodes, 1)), dim=1)

        # Remember the scaffold-only mask and the number of scaffold atoms.
        # squeeze(-1) only drops the trailing mask axis, so a batch of size 1
        # keeps its batch dimension.
        old_node_mask = node_mask.clone()
        old_n_atoms = node_mask.sum(1).squeeze(-1).long()

        # Switch on the new nodes directly after the scaffold nodes.
        for i in range(n_samples):
            n = old_n_atoms[i]
            node_mask[i, n:n + extra_nodes[i]] = 1.

        # Rebuild the edge mask for the enlarged graphs (no self-edges).
        # squeeze(2) instead of squeeze() keeps batch / node axes of size 1.
        edge_mask = node_mask.squeeze(2).unsqueeze(1) * node_mask
        diag_mask = ~torch.eye(edge_mask.size(1), dtype=torch.bool).unsqueeze(0).to(scaffold_x.device)
        edge_mask *= diag_mask
        edge_mask = edge_mask.view(n_samples * n_nodes * n_nodes, 1)

        def overwrite_scaffold(state, step):
            """Replace the scaffold rows of ``state`` by the scaffold latent noised to ``step``."""
            noised = self.sample_z_t_given_z_0(scaffold_x, scaffold_h, step, old_node_mask, n_samples)
            for i in range(n_samples):
                n = old_n_atoms[i]
                # NOTE(review): align() presumably rigidly aligns the second
                # point set onto the first (imported from visualization.kabsch).
                _, aligned = align(state[i, :n, :self.n_dims].cpu(), noised[i, :n, :self.n_dims].cpu())
                state[i, :n, :] = torch.cat([
                    torch.Tensor(aligned).to(state.device),
                    noised[i, :n, self.n_dims:]
                ], dim=1)

            # Re-centre the positions over all (old and new) nodes.
            state = torch.cat(
                [diffusion_utils.remove_mean_with_mask(state[:, :, :self.n_dims], node_mask),
                 state[:, :, self.n_dims:]], dim=2
            )
            diffusion_utils.assert_mean_zero_with_mask(state[:, :, :self.n_dims], node_mask)
            return state

        def time_pair(step, device):
            """Return the normalized times (step / T, (step + 1) / T) as (n_samples, 1) tensors."""
            s_array = torch.full((n_samples, 1), fill_value=step, device=device)
            t_array = s_array + 1
            return s_array / self.T, t_array / self.T

        # sample random noise for the new nodes
        if fix_noise:
            # Noise is broadcasted over the batch axis, useful for visualizations.
            z = self.sample_combined_position_feature_noise(1, n_nodes, node_mask)
        else:
            z = self.sample_combined_position_feature_noise(n_samples, n_nodes, node_mask)

        diffusion_utils.assert_mean_zero_with_mask(z[:, :, :self.n_dims], node_mask)

        # overwrite first part with old nodes; we start with t=T
        z = overwrite_scaffold(z, self.T)

        if use_jumps:
            times = self.get_scaffolding_schedule(self.T, jump_len=200, jump_n_sample=10)
            for t_last, t_cur in tqdm(zip(times[:-1], times[1:])):
                if t_last > t_cur:
                    # reverse diffusion: t_last -> t_cur
                    s_array, t_array = time_pair(t_cur, z.device)
                    z = self.sample_p_zs_given_zt(s_array, t_array, z, node_mask, edge_mask, context, fix_noise=fix_noise)

                    # now z is = zs; overwrite the scaffold part with z_s
                    z = overwrite_scaffold(z, t_cur)
                else:  # t_last < t_cur
                    # forward diffusion: t_last -> t_cur
                    s_array, t_array = time_pair(t_last, z.device)
                    z = self.sample_p_zt_given_zs(s_array, t_array, z, node_mask, edge_mask, context, fix_noise=fix_noise)
                    diffusion_utils.assert_correctly_masked(z, node_mask)
                    diffusion_utils.assert_mean_zero_with_mask(z[:, :, :self.n_dims], node_mask)

        else:
            # s = T-1 ... 0
            # t = T   ... 1
            for s in tqdm(reversed(range(0, self.T))):
                for u in range(resampling_times):
                    s_array, t_array = time_pair(s, z.device)
                    z = self.sample_p_zs_given_zt(s_array, t_array, z, node_mask, edge_mask, context, fix_noise=fix_noise)

                    # now z is = zs; overwrite the scaffold part with z_s
                    z = overwrite_scaffold(z, s)

                    # Re-noise back to t for every round but the last one.
                    if u < resampling_times - 1:
                        z = self.sample_p_zt_given_zs(s_array, t_array, z, node_mask, edge_mask, context, fix_noise=fix_noise)
                        diffusion_utils.assert_correctly_masked(z, node_mask)
                        diffusion_utils.assert_mean_zero_with_mask(z[:, :, :self.n_dims], node_mask)

        # Finally sample p(x, h | z_0).
        z_x, z_h = self.sample_p_xh_given_z0(z, node_mask, edge_mask, context, fix_noise=fix_noise)

        diffusion_utils.assert_mean_zero_with_mask(z_x, node_mask)

        max_cog = torch.sum(z_x, dim=1, keepdim=True).abs().max().item()
        if max_cog > 5e-2:
            print(f'Warning cog drift with error {max_cog:.3f}. Projecting '
                  f'the positions down.')
            z_x = diffusion_utils.remove_mean_with_mask(z_x, node_mask)

        z_xh = torch.cat([z_x, z_h], dim=2)
        diffusion_utils.assert_correctly_masked(z_xh, node_mask)

        if self.norm_values is not None:
            z_xh = self.unnormalize_z(z_xh, node_mask)
        adj_recon, atom_types_recon, formal_charges_recon = self.vae.decode(z_xh, node_mask, edge_mask, context)

        h = {'adjacency_matrices': adj_recon, 'atom_types': atom_types_recon, 'formal_charges': formal_charges_recon, 'z_h': z_h}
        x = z_x

        if add_nodes is not None:
            return x, h, node_mask
        return x, h

    @torch.no_grad()
    def scaffold_intermediate_state(self, x, h, t_intermediate, n_samples, n_nodes, node_mask, edge_mask, context, fix_noise=False, add_nodes=None):
        """Run scaffold-conditioned sampling from ``t = T`` down to ``t_intermediate``.

        The latent scaffold ``(x, h)`` is padded with room for new nodes. The
        state starts as noise, and its scaffold rows are repeatedly overwritten
        by the scaffold latent noised to the current step (aligned to the
        state with ``align``). The schedule is
        ``get_scaffolding_schedule(self.T - t_intermediate, jump_len=40,
        jump_n_sample=10)`` shifted by ``t_intermediate``.

        Args:
            x, h: latent scaffold parts; padded along dim 1 with zeros.
            t_intermediate: step at which the schedule ends.
            n_samples: batch size.
            n_nodes: number of nodes per sample before padding.
            node_mask: scaffold node mask; it is padded along dim 1 with a
                trailing axis of size 1.
            edge_mask: ignored; the edge mask is rebuilt from the enlarged
                node mask.
            context: optional conditioning; the context of the first node is
                repeated for the padded nodes.
            fix_noise: forwarded to the samplers; when True the initial noise
                is requested with a batch size of 1.
            add_nodes: per-sample number of nodes to add (indexable, with a
                ``.max()``), or None to add no nodes.

        Returns:
            ``(z, node_mask, edge_mask, context)``: the state at
            ``t_intermediate`` together with the enlarged masks and context.
        """
        # add_nodes=None is the declared default: treat it as "add nothing"
        # instead of failing on ``None.max()``.
        if add_nodes is None:
            extra_nodes = torch.zeros(n_samples, dtype=torch.long, device=x.device)
        else:
            extra_nodes = add_nodes

        # TODO: fix when working on a batch of different scaffolds, this will fail
        max_extra_nodes = extra_nodes.max().item()
        x = torch.cat((x, torch.zeros(n_samples, max_extra_nodes, self.n_dims, device=x.device)), dim=1)
        h = torch.cat((h, torch.zeros(n_samples, max_extra_nodes, self.in_node_nf, device=h.device)), dim=1)
        node_mask = torch.cat((node_mask, torch.zeros(n_samples, max_extra_nodes, 1, device=node_mask.device)), dim=1)
        n_nodes += max_extra_nodes
        if context is not None:
            one_context = context[:, 0:1, :]
            context = torch.cat((context, one_context.repeat(1, max_extra_nodes, 1)), dim=1)

        # Remember the scaffold-only mask and the number of scaffold atoms.
        # squeeze(-1) only drops the trailing mask axis, so a batch of size 1
        # keeps its batch dimension.
        old_node_mask = node_mask.clone()
        old_n_atoms = node_mask.sum(1).squeeze(-1).long()

        # Switch on the new nodes directly after the scaffold nodes.
        for i in range(n_samples):
            n = old_n_atoms[i]
            node_mask[i, n:n + extra_nodes[i]] = 1.

        # Rebuild the edge mask for the enlarged graphs (no self-edges).
        # squeeze(2) instead of squeeze() keeps batch / node axes of size 1.
        edge_mask = node_mask.squeeze(2).unsqueeze(1) * node_mask
        diag_mask = ~torch.eye(edge_mask.size(1), dtype=torch.bool).unsqueeze(0).to(x.device)
        edge_mask *= diag_mask
        edge_mask = edge_mask.view(n_samples * n_nodes * n_nodes, 1)

        def overwrite_scaffold(state, step):
            """Replace the scaffold rows of ``state`` by the scaffold latent noised to ``step``."""
            noised = self.sample_z_t_given_z_0(x, h, step, old_node_mask, n_samples)
            for i in range(n_samples):
                n = old_n_atoms[i]
                # NOTE(review): align() presumably rigidly aligns the second
                # point set onto the first (imported from visualization.kabsch).
                _, aligned = align(state[i, :n, :self.n_dims].cpu(), noised[i, :n, :self.n_dims].cpu())
                state[i, :n, :] = torch.cat([
                    torch.Tensor(aligned).to(state.device),
                    noised[i, :n, self.n_dims:]
                ], dim=1)

            # Re-centre the positions over all (old and new) nodes.
            state = torch.cat(
                [diffusion_utils.remove_mean_with_mask(state[:, :, :self.n_dims], node_mask),
                 state[:, :, self.n_dims:]], dim=2
            )
            diffusion_utils.assert_mean_zero_with_mask(state[:, :, :self.n_dims], node_mask)
            return state

        # sample random noise for the new nodes
        if fix_noise:
            # Noise is broadcasted over the batch axis, useful for visualizations.
            z = self.sample_combined_position_feature_noise(1, n_nodes, node_mask)
        else:
            z = self.sample_combined_position_feature_noise(n_samples, n_nodes, node_mask)

        diffusion_utils.assert_mean_zero_with_mask(z[:, :, :self.n_dims], node_mask)

        # overwrite first part with old nodes; we start with t=T
        z = overwrite_scaffold(z, self.T)

        # Schedule over [t_intermediate, T]: built relative to 0, then shifted.
        times = self.get_scaffolding_schedule(self.T - t_intermediate, jump_len=40, jump_n_sample=10)
        times = [t + t_intermediate for t in times]

        for t_last, t_cur in tqdm(zip(times[:-1], times[1:])):
            # The lower of the two steps is s; t is always s + 1.
            s_array = torch.full((n_samples, 1), fill_value=min(t_last, t_cur), device=z.device)
            t_array = s_array + 1
            s_array = s_array / self.T
            t_array = t_array / self.T

            if t_last > t_cur:
                # reverse diffusion: t_last -> t_cur
                z = self.sample_p_zs_given_zt(s_array, t_array, z, node_mask, edge_mask, context, fix_noise=fix_noise)

                # now z is = zs; overwrite the scaffold part with z_s
                z = overwrite_scaffold(z, t_cur)
            else:  # t_last < t_cur
                # forward diffusion: t_last -> t_cur
                z = self.sample_p_zt_given_zs(s_array, t_array, z, node_mask, edge_mask, context, fix_noise=fix_noise)
                diffusion_utils.assert_correctly_masked(z, node_mask)
                diffusion_utils.assert_mean_zero_with_mask(z[:, :, :self.n_dims], node_mask)

        return z, node_mask, edge_mask, context


        # # s = T-1 ... t
        # # t = T   ... t+1
        # for s in tqdm(reversed(range(t_intermediate, self.T))):
        #     s_array = torch.full((n_samples, 1), fill_value=s, device=z.device)
        #     t_array = s_array + 1
        #     s_array = s_array / self.T
        #     t_array = t_array / self.T

        #     z = self.sample_p_zs_given_zt(s_array, t_array, z, node_mask, edge_mask, context, fix_noise=fix_noise)

        #     # now z is = zs
        #     # we overwrite the first part with z_s
        #     z_s = self.sample_z_t_given_z_0(x, h, s, old_node_mask, n_samples)
        #     for i in range(n_samples):
        #         n = old_n_atoms[i]
        #         z[i, :n, :] = z_s[i, :n, :]

        #     z = torch.cat(
        #         [diffusion_utils.remove_mean_with_mask(z[:, :, :self.n_dims], node_mask),
        #         z[:, :, self.n_dims:]], dim=2
        #     )
        #     diffusion_utils.assert_mean_zero_with_mask(z[:, :, :self.n_dims], node_mask)

        # z_t = z



    @torch.no_grad()
    def edit(self, x, h, t, n_samples, n_nodes, node_mask, edge_mask, context, fix_noise=False, add_nodes=None):
        """
        Edit latents by noising them to step ``t`` and denoising back to 0.

        Without ``add_nodes`` the latent ``(x, h)`` is noised directly to step
        ``t``. With ``add_nodes`` the state at step ``t`` comes from
        ``self.scaffold_intermediate_state``, which also returns enlarged
        masks and context. In both cases the reverse chain is then run from
        ``t`` down to 0 and the result is decoded by the VAE.

        Args:
            x, h: latent parts to edit.
            t: integer step to which the input is noised.
            n_samples: batch size.
            n_nodes: only forwarded to ``scaffold_intermediate_state``.
            node_mask, edge_mask, context: masks and conditioning for the
                samplers and the decoder.
            fix_noise: forwarded to the samplers.
            add_nodes: optional per-sample number of nodes to add.

        Returns:
            ``(x, h)`` without ``add_nodes``, otherwise
            ``(x, h, node_mask, edge_mask, context)``. ``h`` is a dict with
            the keys 'adjacency_matrices', 'atom_types', 'formal_charges' and
            'z_h'.
        """
        if add_nodes is not None:
            z, node_mask, edge_mask, context = self.scaffold_intermediate_state(
                x, h, t, n_samples, n_nodes, node_mask, edge_mask, context, fix_noise, add_nodes)
        else:
            # instead of sampling random noise, we add noise to the input mols for editing
            z = self.sample_z_t_given_z_0(x, h, t, node_mask, n_samples)

        # Iteratively sample p(z_s | z_t) with s = t - 1.
        # s = t-1 ... 0
        # t = t   ... 1
        for s in tqdm(reversed(range(0, t))):
            step = torch.full((n_samples, 1), fill_value=s, device=z.device)
            s_array = step / self.T
            t_array = (step + 1) / self.T

            z = self.sample_p_zs_given_zt(s_array, t_array, z, node_mask, edge_mask, context, fix_noise=fix_noise)

        # Finally sample p(x, h | z_0).
        z_x, z_h = self.sample_p_xh_given_z0(z, node_mask, edge_mask, context, fix_noise=fix_noise)

        diffusion_utils.assert_mean_zero_with_mask(z_x, node_mask)

        max_cog = torch.sum(z_x, dim=1, keepdim=True).abs().max().item()
        if max_cog > 5e-2:
            print(f'Warning cog drift with error {max_cog:.3f}. Projecting '
                  f'the positions down.')
            z_x = diffusion_utils.remove_mean_with_mask(z_x, node_mask)

        z_xh = torch.cat([z_x, z_h], dim=2)
        diffusion_utils.assert_correctly_masked(z_xh, node_mask)

        if self.norm_values is not None:
            z_xh = self.unnormalize_z(z_xh, node_mask)
        adj_recon, atom_types_recon, formal_charges_recon = self.vae.decode(z_xh, node_mask, edge_mask, context)

        decoded = {
            'adjacency_matrices': adj_recon,
            'atom_types': atom_types_recon,
            'formal_charges': formal_charges_recon,
            'z_h': z_h,
        }

        if add_nodes is None:
            return z_x, decoded
        return z_x, decoded, node_mask, edge_mask, context

    @torch.no_grad()
    def optimize(self, batch, property_norms, device, dtype, step_size=0.2, n_steps=10, t=350, add_nodes=None):
        """
        Repeatedly edit a batch in latent space while shifting the conditioning context.

        The batch is encoded with ``self.vae``, a latent is sampled around the
        encoder mean using the diffusion sigma at time 0, and ``self.edit`` is
        then applied ``n_steps`` times. Before every edit the context is
        increased by ``step_size`` plus Gaussian noise scaled by
        ``step_size / 10`` (one draw per sample, repeated over all nodes) and
        re-masked with ``node_mask``.

        Args:
            batch: mapping with the keys 'positions', 'atomic_numbers_one_hot',
                'formal_charges_one_hot', 'atom_mask' and 'edge_mask'; it is also
                handed to ``prepare_context`` for the 'penalized_logP' property.
            property_norms: forwarded to ``prepare_context``.
            device, dtype: target device / dtype for the batch tensors.
            step_size: amount added to the context at every step.
            n_steps: number of edit steps.
            t: forwarded to ``self.edit``.
            add_nodes: if not None, an initial ``self.edit`` call receives it and
                returns updated masks and context; its ``max()`` is used to size
                the per-step context noise.

        Returns:
            dict mapping a step index to deep copies of ``(z_x, z_h, node_mask)``.
            Key 0 exists only when ``add_nodes`` is given; keys 1..n_steps hold
            the result of each edit step.
        """
        positions = batch['positions'].to(device, dtype)
        features = {
            'atomic_numbers_one_hot': batch['atomic_numbers_one_hot'].to(device, dtype),
            'formal_charges_one_hot': batch['formal_charges_one_hot'].to(device, dtype),
        }
        node_mask = batch['atom_mask'].to(device, dtype).unsqueeze(2)
        edge_mask = batch['edge_mask'].to(device, dtype)
        context = prepare_context(['penalized_logP'], batch, property_norms).to(device, dtype)
        diffusion_utils.assert_correctly_masked(context, node_mask)

        positions = diffusion_utils.remove_mean_with_mask(positions, node_mask)
        n_samples, n_nodes = positions.size(0), positions.size(1)

        # Encode to the latent space; the encoder sigmas are not used.
        mu_x, _, mu_h, _ = self.vae.encode(positions, features, node_mask, edge_mask, context)

        # The sampling sigma is the diffusion sigma evaluated at time 0.
        zero_times = torch.zeros(size=(positions.size(0), 1), device=positions.device)
        gamma_0 = self.inflate_batch_array(self.gamma(zero_times), positions)
        sigma_0 = self.sigma(gamma_0, positions)

        # Sample the latent around the encoder mean.
        latent_mean = torch.cat([mu_x, mu_h], dim=2)
        diffusion_utils.assert_correctly_masked(latent_mean, node_mask)
        latent = self.vae.sample_normal(latent_mean, sigma_0, node_mask)
        latent = latent.detach()  # Always keep the encoder fixed.
        diffusion_utils.assert_correctly_masked(latent, node_mask)

        z_x = latent[:, :, :self.n_dims]
        z_h = latent[:, :, self.n_dims:]
        diffusion_utils.assert_mean_zero_with_mask(z_x, node_mask)

        states = {}
        max_extra_nodes = 0
        if add_nodes is not None:
            # With add_nodes, edit() also hands back updated masks and context.
            z_x, z_h, node_mask, edge_mask, context = self.edit(z_x, z_h, t, n_samples, n_nodes, node_mask, edge_mask, context, add_nodes=add_nodes)
            max_extra_nodes = add_nodes.max().item()
            states[0] = (copy.deepcopy(z_x), copy.deepcopy(z_h), copy.deepcopy(node_mask))
            # edit() returns h as a dict; the next call takes its 'z_h' entry.
            z_h = z_h['z_h']

        total_nodes = n_nodes + max_extra_nodes
        for step in tqdm(range(n_steps)):
            # One random value per sample, repeated over every node slot.
            noise = torch.randn(n_samples, 1).unsqueeze(1).repeat(1, total_nodes, 1).to(device)
            context = (context + step_size + (step_size/10) * noise) * node_mask
            print(f'step: {step}, context: {context}')
            z_x, z_h = self.edit(z_x, z_h, t, n_samples, n_nodes, node_mask, edge_mask, context)
            states[step+1] = (copy.deepcopy(z_x), copy.deepcopy(z_h), copy.deepcopy(node_mask))
            z_h = z_h['z_h']

        return states


    # Inverting the diffusion model...
    @torch.no_grad()
    def sample_as_ode(self, n_samples, n_nodes, node_mask, edge_mask, context, fix_noise=False):
        """
        Run the reverse chain from freshly sampled noise and return both endpoints.

        Starting from a sampled z_T, ``sample_p_zs_given_zt`` is applied for
        s = T-1, ..., 0 (with t = s + 1, both divided by T).

        NOTE(review): despite the method name, every step is invoked with
        ``deterministic=False`` -- confirm that this is intended.

        Args:
            n_samples: batch size used for the time arrays (and for the noise
                unless ``fix_noise`` is set).
            n_nodes: forwarded to ``sample_combined_position_feature_noise``.
            node_mask, edge_mask, context: forwarded to each sampling step.
            fix_noise: if True, a single noise sample (batch size 1) is drawn.

        Returns:
            Tuple ``(z_T, z)``: the initial noise and the latent after the last step.
        """
        # With fix_noise the noise is broadcasted over the batch axis, useful for visualizations.
        noise_batch = 1 if fix_noise else n_samples
        z_T = self.sample_combined_position_feature_noise(noise_batch, n_nodes, node_mask)

        diffusion_utils.assert_mean_zero_with_mask(z_T[:, :, :self.n_dims], node_mask)

        # Walk the chain backwards: s = T-1 ... 0, t = T ... 1.
        z = z_T
        for s in tqdm(reversed(range(self.T))):
            s_array = torch.full((n_samples, 1), fill_value=s, device=z.device)
            t_array = (s_array + 1) / self.T
            s_array = s_array / self.T

            z = self.sample_p_zs_given_zt(s_array, t_array, z, node_mask, edge_mask, context, fix_noise=fix_noise, deterministic=False)

        return z_T, z

    @torch.no_grad()
    def compute_latent_noise(self, z0, n_samples, n_nodes, node_mask, edge_mask, context, fix_noise=False):
        """
        Push a latent z0 through all T steps in increasing time order.

        For s = 0, ..., T-1 (with t = s + 1, both divided by T) the latent is
        updated with ``sample_p_zt_given_zs_reverse`` called with
        ``deterministic=False``.

        Args:
            z0: starting latent; its first ``self.n_dims`` channels must have
                zero mean under ``node_mask`` (asserted).
            n_samples: batch size used to build the time arrays.
            n_nodes: not used by this method.
            node_mask, edge_mask, context, fix_noise: forwarded to each step.

        Returns:
            The latent after the final step.
        """
        diffusion_utils.assert_mean_zero_with_mask(z0[:, :, :self.n_dims], node_mask)

        latent = z0
        for step in tqdm(range(self.T)):
            s_array = torch.full((n_samples, 1), fill_value=step, device=latent.device)
            t_array = (s_array + 1) / self.T
            s_array = s_array / self.T
            latent = self.sample_p_zt_given_zs_reverse(s_array, t_array, latent, node_mask, edge_mask, context, fix_noise=fix_noise, deterministic=False)

        return latent

    def sample_p_zt_given_zs_reverse(self, s, t, zs, node_mask, edge_mask, context, fix_noise=False, deterministic=False):
        """
        Take one step from the latent at time s to a latent at the later time t.

        The mean is ``alpha_{t|s} * zs + (sigma2_{t|s} / sigma_t) * eps`` where
        ``eps`` is the network output ``self.phi`` evaluated on ``zs`` at time
        ``t``. With ``deterministic=True`` the mean is returned; otherwise a
        sample is drawn around it via ``self.sample_normal``. In both cases the
        first ``self.n_dims`` channels are re-centred with ``node_mask``.

        NOTE(review): the sampling std has the form
        ``sigma_{t|s} * sigma_s / sigma_t`` -- confirm this is the intended
        scale for a step in the s -> t direction.

        Args:
            s, t: time arrays for the current and the target step.
            zs: latent at time s.
            node_mask, edge_mask, context: forwarded to ``self.phi``.
            fix_noise: forwarded to ``self.sample_normal``.
            deterministic: return the mean instead of sampling.

        Returns:
            The latent zt.
        """
        gamma_s, gamma_t = self.gamma(s), self.gamma(t)

        sigma2_t_given_s, sigma_t_given_s, alpha_t_given_s = self.sigma_and_alpha_t_given_s(gamma_t, gamma_s, zs)

        sigma_s = self.sigma(gamma_s, target_tensor=zs)
        sigma_t = self.sigma(gamma_t, target_tensor=zs)

        # Network prediction: zs is fed in together with the time t.
        eps_pred = self.phi(zs, t, node_mask, edge_mask, context)

        # Both the input and the prediction must be centred before mixing them.
        diffusion_utils.assert_mean_zero_with_mask(zs[:, :, :self.n_dims], node_mask)
        diffusion_utils.assert_mean_zero_with_mask(eps_pred[:, :, :self.n_dims], node_mask)
        mean = alpha_t_given_s * zs + (sigma2_t_given_s / sigma_t) * eps_pred

        if not deterministic:
            std = sigma_t_given_s * sigma_s / sigma_t
            zt = self.sample_normal(mean, std, node_mask, fix_noise)
        else:
            zt = mean
            diffusion_utils.assert_correctly_masked(zt, node_mask)
            diffusion_utils.assert_mean_zero_with_mask(zt[:, :, :self.n_dims], node_mask)

        # Project down to avoid numerical runaway of the center of gravity.
        centred_x = diffusion_utils.remove_mean_with_mask(zt[:, :, :self.n_dims], node_mask)
        return torch.cat([centred_x, zt[:, :, self.n_dims:]], dim=2)


    ### ODE stuff
    def evaluate_vector_field(self, t, zt, node_mask, edge_mask, context, fix_noise=False):
        """
        Computes the vector field value at time t. Used for solving ODE for optimization.

        The returned value is
        ``-alpha_prime * (zt / alpha_t - 2 * eps_t / (alpha_t * sigma_t))``
        with ``eps_t = self.phi(zt, t, ...)`` and
        ``alpha_prime = -(1 - 2 * 1e-5) * 2 * t / self.T``.

        The result is returned as computed: no center-of-gravity projection is
        applied to it.

        Args:
            t: time array handed to ``self.gamma`` and ``self.phi``.
            zt: latent at time t; its first ``self.n_dims`` channels must have
                zero mean under ``node_mask`` (asserted).
            node_mask, edge_mask, context: forwarded to ``self.phi``.
            fix_noise: not used by this method.

        Returns:
            Tensor with the vector field evaluated at ``(t, zt)``.
        """
        gamma_t = self.gamma(t)

        sigma_t = self.sigma(gamma_t, target_tensor=zt)
        alpha_t = self.alpha(gamma_t, target_tensor=zt)

        # NOTE(review): 1e-5 presumably mirrors the precision constant of the
        # noise schedule -- confirm against the schedule in use.
        s = 1e-5
        alpha_prime = - (1 - 2*s) * 2 * t / self.T
        alpha_prime = self.inflate_batch_array(alpha_prime, target=zt)

        # Neural net prediction.
        eps_t = self.phi(zt, t, node_mask, edge_mask, context)

        # Both the state and the prediction must be centred.
        diffusion_utils.assert_mean_zero_with_mask(zt[:, :, :self.n_dims], node_mask)
        diffusion_utils.assert_mean_zero_with_mask(eps_t[:, :, :self.n_dims], node_mask)
        zt_dot = alpha_prime * (zt / alpha_t - 2 * eps_t / (alpha_t * sigma_t))
        return -zt_dot

    @torch.no_grad()
    def sample_from_ode(self, n_samples, n_nodes, node_mask, edge_mask, context, fix_noise=False, forward=False):
        """
        Draw samples from the generative model by integrating an ODE.

        Noise is sampled, integrated with an Euler ``NeuralODE`` over
        ``self.T + 1`` evenly spaced time points, and the final state is mapped
        through ``sample_p_xh_given_z0`` and the VAE decoder.

        Args:
            n_samples: batch size (used for the noise unless ``fix_noise`` is
                set, and handed to ``torchdyn_wrapper``).
            n_nodes: forwarded to ``sample_combined_position_feature_noise``.
            node_mask, edge_mask, context: forwarded to the vector field
                wrapper, to ``sample_p_xh_given_z0`` and to the decoder.
            fix_noise: if True, a single noise sample (batch size 1) is drawn.
            forward: if True integrate over t = 0 -> 1, otherwise over 1 -> 0.

        Returns:
            Tuple ``(x, h)`` where ``x`` is the position part of the sampled
            latent and ``h`` is a dict with the keys 'adjacency_matrices',
            'atom_types', 'formal_charges' and 'z_h'.

        Raises:
            ImportError: if torchdyn is not installed.
        """
        # Imported here because the module-level torchdyn import is disabled;
        # without this import the name NeuralODE is undefined.
        try:
            from torchdyn.core import NeuralODE
        except ImportError as err:
            raise ImportError(
                "sample_from_ode requires torchdyn "
                "(from torchdyn.core import NeuralODE)."
            ) from err

        if fix_noise:
            # Noise is broadcasted over the batch axis, useful for visualizations.
            z = self.sample_combined_position_feature_noise(1, n_nodes, node_mask)
        else:
            z = self.sample_combined_position_feature_noise(n_samples, n_nodes, node_mask)

        diffusion_utils.assert_mean_zero_with_mask(z[:, :, :self.n_dims], node_mask)

        if forward:
            # from data to noise space
            # TODO: figure out correct range
            t_span = torch.linspace(0, 1, self.T+1).to(z.device)
        else:
            # from noise to data space
            t_span = torch.linspace(1, 0, self.T+1).to(z.device)

        node = NeuralODE(
            vector_field=torchdyn_wrapper(self, n_samples, node_mask, edge_mask, context),
            solver='euler'
        ).to(z.device)
        trajectory = node.trajectory(x=z, t_span=t_span)
        # The last trajectory entry is the state at the end of t_span.
        z_final = trajectory[-1]

        # Finally sample p(x, h | z_0).
        x, h = self.sample_p_xh_given_z0(z_final, node_mask, edge_mask, context, fix_noise=fix_noise)

        diffusion_utils.assert_mean_zero_with_mask(x, node_mask)

        max_cog = torch.sum(x, dim=1, keepdim=True).abs().max().item()
        if max_cog > 5e-2:
            print(f'Warning cog drift with error {max_cog:.3f}. Projecting '
                  f'the positions down.')
            x = diffusion_utils.remove_mean_with_mask(x, node_mask)

        z_x, z_h = x, h

        z_xh = torch.cat([z_x, z_h], dim=2)
        diffusion_utils.assert_correctly_masked(z_xh, node_mask)

        # Undo the latent normalization before decoding.
        if self.norm_values is not None:
            z_xh = self.unnormalize_z(z_xh, node_mask)
        adj_recon, atom_types_recon, formal_charges_recon = self.vae.decode(z_xh, node_mask, edge_mask, context)

        h = {'adjacency_matrices': adj_recon, 'atom_types': atom_types_recon, 'formal_charges': formal_charges_recon, 'z_h': z_h}
        x = z_x

        return x, h
    
    @torch.no_grad()
    def sample_chain(self, n_samples, n_nodes, node_mask, edge_mask, context, keep_frames=None):
        """
        Draw samples from the generative model, keep the intermediate states for visualization purposes.

        The flattened latent chain produced by the parent class is reshaped to
        ``(keep_frames, n_samples, ...)``, every frame is decoded with
        ``self.vae.decode`` and the decoded frames are flattened again.

        Args:
            n_samples, n_nodes, node_mask, edge_mask, context: forwarded to the
                parent ``sample_chain``; the masks and context are also used
                for decoding.
            keep_frames: number of frames in the chain. If None, it is derived
                from the leading size of the chain returned by the parent class.

        Returns:
            Tensor of decoded frames with leading size ``n_samples * keep_frames``.
        """
        chain_flat = super().sample_chain(n_samples, n_nodes, node_mask, edge_mask, context, keep_frames)

        if keep_frames is None:
            # The parent returns a chain flattened over (frames, samples), so the
            # frame count can be recovered from its leading size. Without this,
            # the default keep_frames=None would fail in view() and range().
            keep_frames = chain_flat.size(0) // n_samples

        chain = chain_flat.view(keep_frames, n_samples, *chain_flat.size()[1:])
        chain_decoded = torch.zeros(
            size=(*chain.size()[:-1], self.vae.in_node_nf + self.vae.n_dims), device=chain.device)

        for i in range(keep_frames):
            z_xh = chain[i]
            diffusion_utils.assert_mean_zero_with_mask(z_xh[:, :, :self.n_dims], node_mask)

            # NOTE(review): this expects decode() to return (x, h) with h holding
            # 'categorical' and 'integer', while other call sites in this class
            # unpack three values from self.vae.decode -- verify against the
            # decoder's actual return signature.
            x, h = self.vae.decode(z_xh, node_mask, edge_mask, context)
            xh = torch.cat([x, h['categorical'], h['integer']], dim=2)
            chain_decoded[i] = xh

        chain_decoded_flat = chain_decoded.view(n_samples * keep_frames, *chain_decoded.size()[2:])

        return chain_decoded_flat

    def instantiate_first_stage(self, vae: EnHierarchicalVAE):
        """
        Attach the VAE as ``self.vae`` and set whether it is trainable.

        When ``self.trainable_ae`` is true the VAE is put in train mode and all
        of its parameters require gradients. Otherwise it is put in eval mode,
        its ``train`` attribute is replaced by ``disabled_train`` so that later
        mode switches have no effect, and its parameters are frozen.
        """
        if self.trainable_ae:
            self.vae = vae.train()
            needs_grad = True
        else:
            self.vae = vae.eval()
            # Pin eval mode: train() calls become no-ops from here on.
            self.vae.train = disabled_train
            needs_grad = False

        for weight in self.vae.parameters():
            weight.requires_grad = needs_grad

def disabled_train(self, mode=True):
    """No-op replacement for ``model.train``.

    Assign this over a model's ``train`` attribute so that requests to switch
    between train and eval mode are ignored. ``mode`` is accepted and
    discarded; the first argument is returned unchanged.
    """
    return self
