import copy

import torch
#from torchdyn.core import NeuralODE
import numpy as np
from tqdm import tqdm
from rdkit import Chem

from equivariant_diffusion.en_diffusion import EnVariationalDiffusion
from equivariant_diffusion import utils as diffusion_utils
from geo_ldm.vae import EnHierarchicalVAE
from optimization.ode_solver import torchdyn_wrapper
from qm9.utils import prepare_context
from visualization.kabsch import align, align_com
from geo_ldm.latent_diffuser import EnLatentDiffusion
from qm9.analyze_joint_training import build_2D_mols, smiles_from_2d_mols_list


class MoleculeOptimizer(EnLatentDiffusion):
    """
    The Molecule Optimizer Module.

    Adds molecule-editing utilities on top of ``EnLatentDiffusion``: scaffold
    completion by inpainting new atoms in latent space (RePaint with or without
    jumps), noise-and-denoise editing, guided editing and reconstruction.

    Methods that create or move tensors read ``self.device`` and ``self.dtype``,
    which are assigned by ``set_device_and_dtype``
    (NOTE(review): presumably not set by the base class -- call it before use).
    """

    def __init__(self, **kwargs):
        # All configuration is forwarded unchanged to EnLatentDiffusion.
        super().__init__(**kwargs)

    def set_device_and_dtype(self, device, dtype):
        """Store the device and dtype that input batches and sampled tensors are cast to."""
        self.device = device
        self.dtype = dtype

    def prepare_batch(self, batch, property_norms=None, conditioning=None):
        """
        Convert a dataloader batch into the inputs expected by the encoder.

        Args:
            batch: dict with (at least) the keys 'positions', 'atomic_numbers_one_hot',
                'formal_charges_one_hot', 'atom_mask' and 'edge_mask'.
            property_norms: normalisation statistics forwarded to ``prepare_context``.
            conditioning: properties to condition on; when None no context is built.

        Returns:
            (x, h, node_mask, edge_mask, context): ``x`` has its masked mean removed,
            ``h`` is a dict holding the two one-hot feature tensors, ``node_mask`` gets
            a trailing singleton dimension and ``context`` is None when
            ``conditioning`` is None.
        """
        x = batch['positions'].to(self.device, self.dtype)
        h = {
            'atomic_numbers_one_hot': batch['atomic_numbers_one_hot'].to(self.device, self.dtype),
            'formal_charges_one_hot': batch['formal_charges_one_hot'].to(self.device, self.dtype),
        }
        node_mask = batch['atom_mask'].to(self.device, self.dtype).unsqueeze(2)
        edge_mask = batch['edge_mask'].to(self.device, self.dtype)

        context = None
        if conditioning is not None:
            context = prepare_context(conditioning, batch, property_norms).to(self.device, self.dtype)
            diffusion_utils.assert_correctly_masked(context, node_mask)

        # Positions are centred on the mean of the real (unmasked) atoms.
        x = diffusion_utils.remove_mean_with_mask(x, node_mask)

        return x, h, node_mask, edge_mask, context

    #@torch.no_grad()
    def encode(self, x, h, node_mask, edge_mask, context=None, sample_encoding_with_gamma=True):
        """
        Encode molecules into the diffusion latent space.

        Args:
            sample_encoding_with_gamma: if True, sample around the encoder mean with
                the diffusion model's fixed sigma at t = 0; otherwise use the sigmas
                predicted by the encoder.

        Returns:
            (z_x, z_h): detached latents. ``z_x`` holds the first ``self.n_dims``
            channels and is mean-zero under ``node_mask``; ``z_h`` holds the rest.
        """
        # TODO: z_x_sigma / z_h_sigma are only used when sample_encoding_with_gamma=False.
        z_x_mu, z_x_sigma, z_h_mu, z_h_sigma = self.vae.encode(x, h, node_mask, edge_mask, context)
        z_xh_mean = torch.cat([z_x_mu, z_h_mu], dim=2)
        diffusion_utils.assert_correctly_masked(z_xh_mean, node_mask)

        if sample_encoding_with_gamma:
            # Fixed sigma taken from the noise schedule at t = 0.
            t_zeros = torch.zeros(size=(x.size(0), 1), device=x.device)
            gamma_0 = self.inflate_batch_array(self.gamma(t_zeros), x)
            z_xh_sigma = self.sigma(gamma_0, x)
            z_xh = self.vae.sample_normal(z_xh_mean, z_xh_sigma, node_mask)
        else:
            # Encoder-predicted sigmas; the position sigma is expanded over the spatial axes.
            z_xh_sigma = torch.cat([z_x_sigma.expand(-1, -1, self.n_dims), z_h_sigma], dim=2)
            z_xh = self.sample_normal(z_xh_mean, z_xh_sigma, node_mask)

        z_xh = z_xh.detach()  # Always keep the encoder fixed.
        diffusion_utils.assert_correctly_masked(z_xh, node_mask)

        z_x = z_xh[:, :, :self.n_dims]
        z_h = z_xh[:, :, self.n_dims:]
        diffusion_utils.assert_mean_zero_with_mask(z_x, node_mask)

        return z_x, z_h

    #@torch.no_grad()
    def sample_noisy_batch(self, nodesxsample, fix_noise=False):
        """
        Sample pure-noise latents for a batch of molecules of the given sizes.

        Args:
            nodesxsample: 1-D integer tensor with the number of atoms of each sample.
            fix_noise: if True, a single noise sample is broadcast over the batch
                (useful for visualisations).

        Returns:
            (z, node_mask, edge_mask) with ``node_mask`` of shape
            (batch, max_nodes, 1) and ``edge_mask`` flattened to
            (batch * max_nodes * max_nodes, 1).
        """
        max_n_nodes = int(nodesxsample.max())
        batch_size = len(nodesxsample)

        node_mask = torch.zeros(batch_size, max_n_nodes)
        for i, n_nodes in enumerate(nodesxsample):
            node_mask[i, :int(n_nodes)] = 1

        # Edges connect every pair of distinct real atoms (no self-loops).
        edge_mask = node_mask.unsqueeze(1) * node_mask.unsqueeze(2)
        diag_mask = ~torch.eye(max_n_nodes, dtype=torch.bool).unsqueeze(0)
        edge_mask *= diag_mask
        edge_mask = edge_mask.view(batch_size * max_n_nodes * max_n_nodes, 1).to(self.device, self.dtype)
        node_mask = node_mask.unsqueeze(2).to(self.device, self.dtype)

        # TODO: figure out how to deal with context

        noise_batch_size = 1 if fix_noise else batch_size
        z = self.sample_combined_position_feature_noise(noise_batch_size, max_n_nodes, node_mask)

        diffusion_utils.assert_correctly_masked(z, node_mask)
        diffusion_utils.assert_mean_zero_with_mask(z[:, :, :self.n_dims], node_mask)

        return z, node_mask, edge_mask

    #@torch.no_grad()
    def overwrite_with_mask(self, z, node_mask, z_scaff, node_mask_scaff):
        """
        Overwrite the leading scaffold atoms of every sample in ``z`` with ``z_scaff``.

        For sample i with n_i scaffold atoms (counted from ``node_mask_scaff``), rows
        ``:n_i`` of ``z`` are replaced by the scaffold latents. The scaffold positions
        are first passed through ``align_com`` together with the current positions.
        Afterwards all positions are re-centred under ``node_mask``.

        Note: ``z`` is modified in place before the re-centring.
        """
        # reshape(-1) rather than squeeze() keeps one entry per sample for batch size 1.
        n_atoms_scaff = node_mask_scaff.sum(1).reshape(-1).int()

        # TODO: can parallelize across samples (?)
        # TODO: use full Kabsch alignment (visualization.kabsch.align) on GPU tensors.
        for i in range(z.size(0)):
            n = int(n_atoms_scaff[i])
            positions_aligned = align_com(z[i, :n, :self.n_dims], z_scaff[i, :n, :self.n_dims])
            z[i, :n, :] = torch.cat([positions_aligned, z_scaff[i, :n, self.n_dims:]], dim=1)

        positions = diffusion_utils.remove_mean_with_mask(z[:, :, :self.n_dims], node_mask)
        z = torch.cat([positions, z[:, :, self.n_dims:]], dim=2)
        diffusion_utils.assert_correctly_masked(z, node_mask)
        diffusion_utils.assert_mean_zero_with_mask(z[:, :, :self.n_dims], node_mask)
        return z

    def _s_t_arrays(self, s, n_samples, device):
        """Return the normalised (s, t = s + 1) time arrays of shape (n_samples, 1) for integer step ``s``."""
        s_array = torch.full((n_samples, 1), fill_value=s, device=device)
        t_array = s_array + 1
        return s_array / self.T, t_array / self.T

    def _diffuse_to_t(self, x, h, t, n_samples, node_mask):
        """Sample z_t ~ Normal(alpha_t [x, h], sigma_t) for integer step ``t`` (forward diffusion)."""
        t_batch = torch.full((n_samples, 1), fill_value=t, device=x.device).float() / self.T
        gamma_t = self.inflate_batch_array(self.gamma(t_batch), x)
        alpha_t = self.alpha(gamma_t, x)
        sigma_t = self.sigma(gamma_t, x)

        eps = self.sample_combined_position_feature_noise(
            n_samples=x.size(0), n_nodes=x.size(1), node_mask=node_mask)

        xh = torch.cat([x, h], dim=2)
        z = alpha_t * xh + sigma_t * eps

        diffusion_utils.assert_mean_zero_with_mask(z[:, :, :self.n_dims], node_mask)
        return z

    def _finish_repaint(self, z, node_mask, edge_mask, context, scaff_dict, fix_noise=False):
        """Sample from z_0 one last time and paste the clean scaffold latents over the result."""
        z = self.sample_xh_from_z0(z, node_mask, edge_mask, context=context, fix_noise=fix_noise)

        # Finally overwrite with the scaffold to make sure it is there.
        print('Finally, overwriting with the scaffold')
        z_s_scaff = torch.cat([scaff_dict['x_scaff'], scaff_dict['h_scaff']], dim=2)
        return self.overwrite_with_mask(z, node_mask, z_s_scaff, scaff_dict['node_mask_scaff'])

    #@torch.no_grad()
    def scaffold(self, x_scaff, h_scaff, node_mask_scaff, edge_mask_scaff, context_scaff,
                 add_nodes, t_intermediate=0, resampling_times=10, use_jumps=True,
                 jump_len=40, jump_n_sample=10):
        """
        Grow each scaffold by ``add_nodes`` atoms by inpainting in latent space.

        Pure noise is sampled for the enlarged molecules and the leading atoms are
        replaced by the scaffold noised to time T. RePaint then runs from T down to
        ``t_intermediate``, using ``repaint_with_jumps`` or ``repaint_without_jumps``.

        Args:
            x_scaff, h_scaff: clean scaffold latents (positions / features).
            node_mask_scaff: scaffold node mask.
            edge_mask_scaff: unused; kept for interface compatibility.
            context_scaff: per-node scaffold context or None. New nodes reuse the
                context of the first scaffold node.
            add_nodes: integer tensor with the number of atoms to add per sample;
                None adds no atoms.
            t_intermediate: step at which inpainting stops (0 runs it to the end).

        Returns:
            (z, node_mask, edge_mask, context) for the enlarged molecules.
        """
        n_samples = x_scaff.size(0)

        # reshape(-1) rather than squeeze() keeps one entry per sample for batch size 1.
        n_atoms_scaff = node_mask_scaff.sum(1).reshape(-1).int()
        if add_nodes is None:
            add_nodes = torch.zeros_like(n_atoms_scaff)

        # The new molecules have more atoms; sample initial noise for all of them.
        nodesxsample = n_atoms_scaff + add_nodes
        z, node_mask, edge_mask = self.sample_noisy_batch(nodesxsample)

        # We start at t = T: overwrite the first atoms with the scaffold noised to T.
        z_T_scaff = self.sample_z_t_given_z_0(x_scaff, h_scaff, self.T, node_mask_scaff, n_samples)
        z = self.overwrite_with_mask(z, node_mask, z_T_scaff, node_mask_scaff)

        context = None
        if context_scaff is not None:
            max_extra_nodes = int(add_nodes.max())
            extra_context = context_scaff[:, 0:1, :].repeat(1, max_extra_nodes, 1)
            context = torch.cat((context_scaff, extra_context), dim=1)
            assert context.size(0) == z.size(0)
            assert context.size(1) == z.size(1)

        z_dict = {'z': z, 'node_mask': node_mask, 'edge_mask': edge_mask, 'context': context}
        scaff_dict = {'x_scaff': x_scaff, 'h_scaff': h_scaff, 'node_mask_scaff': node_mask_scaff}
        if use_jumps:
            z = self.repaint_with_jumps(z_dict, scaff_dict, t_intermediate,
                                        jump_len=jump_len, jump_n_sample=jump_n_sample)
        else:
            z = self.repaint_without_jumps(z_dict, scaff_dict, t_intermediate,
                                           resampling_times=resampling_times)
        return z, node_mask, edge_mask, context

    #@torch.no_grad()
    def repaint_without_jumps(self, z_dict, scaff_dict, t_intermediate, resampling_times=10, fix_noise=False):
        """
        RePaint inpainting with ``resampling_times`` rounds per time step.

        Each step s = T-1 ... t_intermediate runs three operations per round.
        First it denoises z_{s+1} -> z_s. Next it pastes the scaffold noised to s
        over the leading atoms. Finally, except on the last round, it re-noises
        z_s -> z_{s+1}.

        Returns:
            z at step ``t_intermediate``. When ``t_intermediate == 0`` it instead
            returns the sample from z_0 with the clean scaffold pasted in.
        """
        z = z_dict['z']
        node_mask = z_dict['node_mask']
        edge_mask = z_dict['edge_mask']
        context = z_dict['context']
        n_samples = z.size(0)

        x_scaff = scaff_dict['x_scaff']
        h_scaff = scaff_dict['h_scaff']
        node_mask_scaff = scaff_dict['node_mask_scaff']

        # s = T-1 ... t_intermediate, t = s + 1
        for s in tqdm(reversed(range(t_intermediate, self.T))):
            s_array, t_array = self._s_t_arrays(s, n_samples, z.device)
            for u in range(resampling_times):
                z = self.sample_p_zs_given_zt(s_array, t_array, z, node_mask, edge_mask, context, fix_noise=fix_noise)

                # z is now z_s; pin the scaffold atoms to their noised version at s.
                z_s_scaff = self.sample_z_t_given_z_0(x_scaff, h_scaff, s, node_mask_scaff, n_samples)
                z = self.overwrite_with_mask(z, node_mask, z_s_scaff, node_mask_scaff)

                if u < resampling_times - 1:
                    z = self.sample_p_zt_given_zs(s_array, t_array, z, node_mask, edge_mask, context, fix_noise=fix_noise)
                    diffusion_utils.assert_correctly_masked(z, node_mask)
                    diffusion_utils.assert_mean_zero_with_mask(z[:, :, :self.n_dims], node_mask)

        if t_intermediate == 0:
            z = self._finish_repaint(z, node_mask, edge_mask, context, scaff_dict, fix_noise=fix_noise)

        return z

    #@torch.no_grad()
    def repaint_with_jumps(self, z_dict, scaff_dict, t_intermediate, jump_len=40, jump_n_sample=10, fix_noise=False):
        """
        RePaint inpainting following the schedule from ``get_scaffolding_schedule``.

        Consecutive schedule entries (t_last, t_cur) trigger one of two moves.
        A backward move (t_last > t_cur) denoises to t_cur and pastes the scaffold
        noised to t_cur. A forward move re-noises from t_last to t_last + 1.
        The schedule is shifted by ``t_intermediate``.

        Returns:
            z at step ``t_intermediate``. When ``t_intermediate == 0`` it instead
            returns the sample from z_0 with the clean scaffold pasted in.
        """
        z = z_dict['z']
        node_mask = z_dict['node_mask']
        edge_mask = z_dict['edge_mask']
        context = z_dict['context']
        n_samples = z.size(0)

        x_scaff = scaff_dict['x_scaff']
        h_scaff = scaff_dict['h_scaff']
        node_mask_scaff = scaff_dict['node_mask_scaff']

        # Adapted for t_intermediate; if it is 0, nothing changes.
        times = self.get_scaffolding_schedule(self.T - t_intermediate, jump_len=jump_len, jump_n_sample=jump_n_sample)
        times = [t + t_intermediate for t in times]
        time_pairs = list(zip(times[:-1], times[1:]))
        n_backward_calls = sum(1 for t_last, t_cur in time_pairs if t_last > t_cur)
        print(f'number of steps in repaint with jumps: {len(times)}')
        print(f'number of backward diffusion calls in repaint with jumps: {n_backward_calls}')

        for t_last, t_cur in tqdm(time_pairs):
            if t_last > t_cur:
                # Reverse diffusion t_cur + 1 -> t_cur, then pin the scaffold atoms.
                s_array, t_array = self._s_t_arrays(t_cur, n_samples, z.device)
                z = self.sample_p_zs_given_zt(s_array, t_array, z, node_mask, edge_mask, context, fix_noise=fix_noise)

                z_s_scaff = self.sample_z_t_given_z_0(x_scaff, h_scaff, t_cur, node_mask_scaff, n_samples)
                z = self.overwrite_with_mask(z, node_mask, z_s_scaff, node_mask_scaff)
            else:
                # Forward diffusion t_last -> t_last + 1.
                s_array, t_array = self._s_t_arrays(t_last, n_samples, z.device)
                z = self.sample_p_zt_given_zs(s_array, t_array, z, node_mask, edge_mask, context, fix_noise=fix_noise)
                diffusion_utils.assert_correctly_masked(z, node_mask)
                diffusion_utils.assert_mean_zero_with_mask(z[:, :, :self.n_dims], node_mask)

        if t_intermediate == 0:
            z = self._finish_repaint(z, node_mask, edge_mask, context, scaff_dict, fix_noise=fix_noise)

        return z

    def optimize_z_xh(self, scaffold_graph, z_x, z_h, h, node_mask, edge_mask, context, early_stopping=True,
                      n_iterations=500, lr=0.1):
        """
        Refine the latents by gradient descent so the decoder reproduces the 2-D graph.

        Adam minimises ``self.vae.compute_2d_reconstruction_error`` between the decoder
        output and the ground-truth adjacency matrix, atom types and formal charges.
        With ``early_stopping``, the loop stops as soon as the argmax of every decoded
        quantity matches the ground truth.

        Args:
            scaffold_graph: batch dict; only 'adj_matrix' is read (it is not modified).
            z_x, z_h: initial latents, e.g. from ``encode``.
            h: dict with 'atomic_numbers_one_hot' and 'formal_charges_one_hot'.
            n_iterations: maximum number of optimisation steps.
            lr: Adam learning rate.

        Returns:
            The optimised (z_x, z_h), detached, with z_x re-centred under ``node_mask``.
        """
        # Bond value 1.5 (presumably aromatic) is mapped to class 4. Clone first so the
        # caller's batch tensor is not modified in place.
        adj_gt = scaffold_graph['adj_matrix'].clone()
        adj_gt[adj_gt == 1.5] = 4
        adj_gt = adj_gt.to(self.device, torch.long)

        # Early-stopping targets do not change across iterations.
        n_atom_types = h['atomic_numbers_one_hot'].size(-1)
        atom_types_gt = h['atomic_numbers_one_hot'].argmax(-1)
        charges_gt = h['formal_charges_one_hot'].argmax(-1)

        # The optimised variable must be a leaf that requires grad.
        z0 = torch.cat([z_x, z_h], dim=2).detach().requires_grad_()
        optimizer = torch.optim.Adam([z0], lr=lr)

        n_steps = 0
        for _ in range(n_iterations):
            adj_recon, h_recon = self.vae.decoder._forward(z0, node_mask, edge_mask, context)
            if early_stopping and torch.all(adj_recon.argmax(-1) == adj_gt) and \
                    torch.all(h_recon[:, :, :n_atom_types].argmax(-1) == atom_types_gt) and \
                    torch.all(h_recon[:, :, n_atom_types:].argmax(-1) == charges_gt):
                break

            loss = self.vae.compute_2d_reconstruction_error(adj_recon, h_recon, adj_gt, h, node_mask, edge_mask)
            # Differentiate w.r.t. z0 only so no gradients accumulate in the VAE parameters.
            z0.grad = torch.autograd.grad(loss, z0)[0]
            optimizer.step()
            n_steps += 1
        print(f'Stopping z_xh optimization after {n_steps} iterations')

        z_x = diffusion_utils.remove_mean_with_mask(z0[:, :, :self.n_dims], node_mask)
        z_h = z0[:, :, self.n_dims:]
        return z_x.detach(), z_h.detach()

    # this should be the interface to the user
    #@torch.no_grad()
    def complete_scaffold(self, scaffold_graph, fix_noise=False, add_nodes=None, resampling_times=10, use_jumps=True,
                          jump_len=40, jump_n_sample=10, dataset_info=None):
        """
        Complete scaffold molecules by adding ``add_nodes`` atoms to each of them.

        The scaffold is encoded and its latent refined with ``optimize_z_xh`` so it
        decodes back to the same graph. New atoms are then inpainted with ``scaffold``.

        Args:
            scaffold_graph: batch dict accepted by ``prepare_batch``; must also contain
                'adj_matrix'.
            fix_noise: currently unused.
            add_nodes: integer tensor with the number of atoms to add per sample
                (None adds no atoms).
            dataset_info: forwarded to ``build_2D_mols`` for the SMILES printout.

        Returns:
            (x, h, node_mask) of the completed molecules, decoded with ``decode_from_z_xh``.
        """
        if add_nodes is not None:
            add_nodes = add_nodes.to(self.device, torch.int)
        x, h, node_mask, edge_mask, context = self.prepare_batch(scaffold_graph, property_norms=None)
        z_x, z_h = self.encode(x, h, node_mask, edge_mask, context, sample_encoding_with_gamma=False)

        z_x, z_h = self.optimize_z_xh(scaffold_graph, z_x, z_h, h, node_mask, edge_mask, context, early_stopping=True)
        x_rec, h_rec = self.decode_from_z_xh(torch.cat([z_x, z_h], dim=2), node_mask, edge_mask, context=None, fix_noise=False)
        smiles_rec = self.get_smiles_from_x_h(x_rec, h_rec, node_mask, dataset_info)
        print(f'smiles_rec: {smiles_rec}')
        # MolFromSmiles returns None for unparsable input; fall back to the raw string
        # rather than crashing in MolToSmiles on a purely informational print.
        mol_rec = Chem.MolFromSmiles(smiles_rec[0]) if smiles_rec[0] else None
        scaffold_smiles = Chem.MolToSmiles(mol_rec) if mol_rec is not None else smiles_rec[0]
        print(f'Scaffolding on {scaffold_smiles}')

        z, node_mask, edge_mask, context = self.scaffold(z_x, z_h, node_mask, edge_mask, context, add_nodes,
                                                         t_intermediate=0,
                                                         resampling_times=resampling_times, use_jumps=use_jumps,
                                                         jump_len=jump_len, jump_n_sample=jump_n_sample)

        # TODO: deal with context
        x, h = self.decode_from_z_xh(z, node_mask, edge_mask, context)

        return x, h, node_mask

    # this should be the interface to the user
    #@torch.no_grad()
    def optimize(self, scaffold_graph, t_add_nodes, t_optimize, n_steps, step_size, fix_noise=False, add_nodes=None,
                 resampling_times=10, use_jumps=True,
                 jump_len=40, jump_n_sample=10, property_norms=None, conditioning=None):
        """
        Iteratively edit molecules while shifting their conditioning context.

        If ``add_nodes`` is given, the molecules are first grown with ``self.add_nodes``
        (inpainting down to ``t_add_nodes``). Then ``n_steps`` rounds run. Each round
        adds ``step_size`` to the context, plus Gaussian noise scaled by
        ``step_size / 10`` that is shared by all nodes of a sample. It then
        regenerates the molecules with ``edit`` from step ``t_optimize``.
        The loop needs a context, i.e. ``conditioning`` must be provided when
        ``t_optimize != 0``.

        Returns:
            dict mapping round index to deep copies of (z_x, h, node_mask), where ``h``
            is the decoded dict. Index 0 holds the molecules after adding nodes (only
            present when ``add_nodes`` is given). Returns early when ``t_optimize == 0``.
        """
        if add_nodes is not None:
            add_nodes = add_nodes.to(self.device, torch.int)
        x, h, node_mask, edge_mask, context = self.prepare_batch(scaffold_graph, property_norms=property_norms, conditioning=conditioning)
        z_x, z_h = self.encode(x, h, node_mask, edge_mask, context)
        n_samples, n_nodes = z_x.size(0), z_x.size(1)

        intermediate_states = {}

        # First add atoms and harmonize them.
        if add_nodes is not None:
            z_x, z_h, node_mask, edge_mask, context = self.add_nodes(z_x, z_h, t_add_nodes, node_mask, edge_mask, context,
                                                                     add_nodes=add_nodes, resampling_times=resampling_times,
                                                                     use_jumps=use_jumps, jump_len=jump_len,
                                                                     jump_n_sample=jump_n_sample)

            intermediate_states[0] = (copy.deepcopy(z_x), copy.deepcopy(z_h), copy.deepcopy(node_mask))
            z_h = z_h['z_h']

        if t_optimize == 0:
            # We're just adding nodes.
            return intermediate_states

        # Iteratively optimize.
        for step in tqdm(range(n_steps)):
            # One noise value per sample, repeated over all (possibly added) nodes.
            noise = torch.randn(n_samples, 1).unsqueeze(1).repeat(1, node_mask.size(1), 1).to(self.device)
            context = context + step_size + (step_size / 10) * noise
            context = context * node_mask
            # edit returns (x, h, node_mask, edge_mask); h is a dict holding the latent 'z_h'.
            z_x, z_h, node_mask, edge_mask = self.edit(z_x, z_h, t_optimize, n_samples, n_nodes,
                                                       node_mask, edge_mask, context)
            intermediate_states[step + 1] = (copy.deepcopy(z_x), copy.deepcopy(z_h), copy.deepcopy(node_mask))
            z_h = z_h['z_h']

        return intermediate_states

    #@torch.no_grad()
    def add_nodes(self, z_x, z_h, t_intermediate, node_mask, edge_mask, context, add_nodes,
                  resampling_times=10, use_jumps=True, jump_len=40, jump_n_sample=10, fix_noise=False):
        """
        Add atoms to existing molecules and let the whole molecule adapt to them.

        New atoms are inpainted around the input from T down to ``t_intermediate``
        with ``scaffold``. The complete molecule is then denoised freely from
        ``t_intermediate`` to 0 and decoded with ``decode_from_z0``.

        Returns:
            (x, h, node_mask, edge_mask, context) for the enlarged molecules, where
            ``h`` is the dict produced by ``decode_from_z0``.
        """
        z, node_mask, edge_mask, context = self.scaffold(z_x, z_h, node_mask, edge_mask, context, add_nodes,
                                                         t_intermediate=t_intermediate,
                                                         resampling_times=resampling_times, use_jumps=use_jumps,
                                                         jump_len=jump_len, jump_n_sample=jump_n_sample)

        # Denoise the whole molecule to harmonize it: t_intermediate ... 0.
        n_samples = z.size(0)
        for s in tqdm(reversed(range(0, t_intermediate))):
            s_array, t_array = self._s_t_arrays(s, n_samples, z.device)
            z = self.sample_p_zs_given_zt(s_array, t_array, z, node_mask, edge_mask, context, fix_noise=fix_noise)

        x, h = self.decode_from_z0(z, node_mask, edge_mask, context)

        return x, h, node_mask, edge_mask, context

    #@torch.no_grad()
    def edit(self, x, h, t, n_samples, n_nodes, node_mask, edge_mask, context, fix_noise=False, add_nodes=None, visualize=False):
        """
        Edit molecules by taking their latents to step ``t`` and denoising back to 0.

        Without ``add_nodes``, the latents (x, h) are noised to step ``t``.
        With ``add_nodes``, new atoms are first inpainted with ``scaffold``. This
        runs without guidance and on a 10x shorter schedule, down to step ``t // 10``.
        ``n_nodes`` is unused.

        Returns:
            (x, h, node_mask, edge_mask) from ``decode_from_z0``. When ``visualize`` is
            set, a dict of intermediate latents is appended.
        """
        if visualize:
            states = {0: torch.cat([x, h], dim=2)}

        if add_nodes is not None:
            # If we're adding nodes, we fix the input molecule at time t and denoise the
            # additional nodes from T ... t. This runs without guidance and with fewer
            # steps (allowing resampling without extra cost). Both attributes are
            # restored even if scaffolding fails.
            add_nodes = add_nodes.to(self.device, torch.int)
            guidance_fn, T = self.guidance_fn, self.T
            self.guidance_fn = None
            self.T = T // 10
            try:
                z, node_mask, edge_mask, context = self.scaffold(x, h, node_mask, edge_mask, context, add_nodes,
                                                                 t_intermediate=t // 10,
                                                                 resampling_times=10, use_jumps=True,
                                                                 jump_len=5, jump_n_sample=10)
            finally:
                self.guidance_fn = guidance_fn
                self.T = T
            if visualize:
                states[t] = z
            # z now contains the noised version of the input mol + new atoms.
        else:
            # Instead of sampling random noise, add noise to the input mols for editing.
            z = self._diffuse_to_t(x, h, t, n_samples, node_mask)

        for s in tqdm(reversed(range(0, t))):
            s_array, t_array = self._s_t_arrays(s, n_samples, z.device)
            z = self.sample_p_zs_given_zt(s_array, t_array, z, node_mask, edge_mask, context, fix_noise=fix_noise)

        if visualize:
            states[-1] = z
        x, h = self.decode_from_z0(z, node_mask, edge_mask, context)

        if visualize:
            return x, h, node_mask, edge_mask, states
        return x, h, node_mask, edge_mask

    def edit_with_guidance(self, x, h, t, n_samples, n_nodes, node_mask, edge_mask, context, fix_noise=False, prop_guidance=None, structure_guidance=None):
        """
        Edit molecules like ``edit`` (without adding nodes) using two guidance functions.

        The latents are noised to step ``t`` and then denoised step by step. Each step
        makes one denoising move with ``structure_guidance``, followed by 5 resampling
        rounds with ``prop_guidance``. Each round re-noises z_s to z_{s+1} and then
        denoises back to z_s, so z sits at step s when the step ends.
        ``self.guidance_fn`` is restored to its previous value on return.
        ``n_nodes`` is unused.

        Returns:
            (x, h) from ``decode_from_z0``.
        """
        z = self._diffuse_to_t(x, h, t, n_samples, node_mask)

        previous_guidance_fn = getattr(self, 'guidance_fn', None)
        try:
            for s in tqdm(reversed(range(0, t))):
                s_array, t_array = self._s_t_arrays(s, n_samples, z.device)

                self.guidance_fn = structure_guidance
                z = self.sample_p_zs_given_zt(s_array, t_array, z, node_mask, edge_mask, context, fix_noise=fix_noise)

                # z is at step s: re-noise to s + 1 first, then denoise back to s.
                self.guidance_fn = prop_guidance
                for _ in range(5):
                    z = self.sample_p_zt_given_zs(s_array, t_array, z, node_mask, edge_mask, context, fix_noise=fix_noise)
                    z = self.sample_p_zs_given_zt(s_array, t_array, z, node_mask, edge_mask, context, fix_noise=fix_noise)
        finally:
            self.guidance_fn = previous_guidance_fn

        x, h = self.decode_from_z0(z, node_mask, edge_mask, context)

        return x, h

    # this should be the interface to the user
    #@torch.no_grad()
    def reconstruct(self, scaffold_graph, property_norms=None, conditioning=None, fix_noise=False, add_nodes=None, resampling_times=1, use_jumps=False):
        """
        Encode molecules and decode them straight back (no diffusion steps).

        ``fix_noise``, ``add_nodes``, ``resampling_times`` and ``use_jumps`` are unused.

        Returns:
            (x, h, node_mask) with (x, h) from ``decode_from_z0``.
        """
        x, h, node_mask, edge_mask, context = self.prepare_batch(scaffold_graph, property_norms=property_norms, conditioning=conditioning)
        z_x, z_h = self.encode(x, h, node_mask, edge_mask, context)
        z = torch.cat([z_x, z_h], dim=2)

        # TODO: deal with context
        x, h = self.decode_from_z0(z, node_mask, edge_mask, context)

        return x, h, node_mask

    def get_smiles_from_x_h(self, x, h, node_mask, dataset_info):
        """
        Build RDKit 2-D molecules from decoder output and return their SMILES.

        Args:
            x: positions passed through as 'positions'.
            h: dict with 'atom_types', 'formal_charges', 'adjacency_matrices' and 'z_h'
                (as produced by ``decode_from_z0`` / ``decode_from_z_xh``).
        """
        molecules = {
            'atom_types': h['atom_types'],
            'formal_charges': h['formal_charges'],
            'positions': x,
            'adjacency_matrices': h['adjacency_matrices'],
            'node_mask': node_mask,
            'z_h': h['z_h'],
        }

        rdkit_mols = build_2D_mols(molecules, dataset_info, use_ghost_nodes=False)
        return smiles_from_2d_mols_list(rdkit_mols)

    def decode_from_z0(self, z, node_mask, edge_mask, context=None, fix_noise=False):
        """
        Sample (x, h) from p(x, h | z_0) and decode them into a 2-D molecular graph.

        Returns:
            (x, h): ``x`` are the sampled latent positions, and ``h`` is a dict with the
            VAE decoder output ('adjacency_matrices', 'atom_types', 'formal_charges')
            plus 'z_h'. 'z_h' holds the sampled latent features before unnormalisation.
        """
        z_xh = self.sample_xh_from_z0(z, node_mask, edge_mask, context=context, fix_noise=fix_noise)
        z_x = z_xh[:, :, :self.n_dims]
        z_h = z_xh[:, :, self.n_dims:]

        z_xh_decoder = z_xh
        if self.norm_values is not None:
            z_xh_decoder = self.unnormalize_z(z_xh, node_mask)
        adj_recon, atom_types_recon, formal_charges_recon = self.vae.decode(
            z_xh_decoder, node_mask, edge_mask, context, valency_check=False)

        h = {'adjacency_matrices': adj_recon, 'atom_types': atom_types_recon,
             'formal_charges': formal_charges_recon, 'z_h': z_h}
        return z_x, h

    def sample_xh_from_z0(self, z, node_mask, edge_mask, context=None, fix_noise=False):
        """
        Sample p(x, h | z_0) and return the concatenated latent [x, h].

        If the sampled positions drift from zero mean by more than 5e-2, a warning is
        printed and they are re-centred under ``node_mask``.
        """
        # TODO: check if necessary!!
        x, h = self.sample_p_xh_given_z0(z, node_mask, edge_mask, context, fix_noise=fix_noise)

        diffusion_utils.assert_mean_zero_with_mask(x, node_mask)

        max_cog = torch.sum(x, dim=1, keepdim=True).abs().max().item()
        if max_cog > 5e-2:
            print(f'Warning cog drift with error {max_cog:.3f}. Projecting '
                  f'the positions down.')
            x = diffusion_utils.remove_mean_with_mask(x, node_mask)

        z_xh = torch.cat([x, h], dim=2)
        diffusion_utils.assert_correctly_masked(z_xh, node_mask)
        return z_xh

    def decode_from_z_xh(self, z_xh, node_mask, edge_mask, context=None, fix_noise=False):
        """
        Decode a clean latent [x, h] into a 2-D molecular graph (no sampling step).

        NOTE(review): unlike ``decode_from_z0``, the returned ``x`` and 'z_h' are taken
        AFTER ``unnormalize_z`` when ``self.norm_values`` is set. ``fix_noise`` is
        unused.

        Returns:
            (x, h) with ``h`` holding 'adjacency_matrices', 'atom_types',
            'formal_charges' and 'z_h'.
        """
        diffusion_utils.assert_correctly_masked(z_xh, node_mask)

        if self.norm_values is not None:
            z_xh = self.unnormalize_z(z_xh, node_mask)
        adj_recon, atom_types_recon, formal_charges_recon = self.vae.decode(z_xh, node_mask, edge_mask, context, valency_check=False)

        h = {'adjacency_matrices': adj_recon, 'atom_types': atom_types_recon,
             'formal_charges': formal_charges_recon, 'z_h': z_xh[:, :, self.n_dims:]}
        x = z_xh[:, :, :self.n_dims]

        return x, h
