import math

from tqdm import tqdm
import torch

from qm9.analyze_joint_training import BasicSmilesMetrics, build_2D_mols, smiles_from_2d_mols_list
from qm9.sampling import sample


def sample_from_ldm(model_sample, nodes_dist, args, device, dataset_info, prop_dist=None,
                    n_samples=1000, batch_size=100, enforce_unconditional_generation=False,
                    regressor_guidance=False, *, max_retries=None):
    """Sample molecules from a latent diffusion model in batches and convert them to SMILES.

    Each batch is produced by ``sample`` (from ``qm9.sampling``). The returned
    tensors are detached, moved to CPU and concatenated along dim 0 across
    all batches.

    Args:
        model_sample: model forwarded to ``sample``. It is switched to eval mode
            here and left in eval mode on return.
        nodes_dist: object whose ``sample(n)`` result is passed to ``sample`` as
            ``nodesxsample`` for a batch of ``n`` molecules.
        args: config namespace. ``args.context_node_nf`` is read here, and
            ``args`` is forwarded to ``sample``.
        device: forwarded to ``sample``.
        dataset_info: forwarded to ``sample`` and ``build_2D_mols``.
        prop_dist: forwarded to ``sample``.
        n_samples (int): total number of molecules to generate. Must be > 0.
        batch_size (int): maximum number of molecules per call to ``sample``.
            Must be > 0 and is clamped to ``n_samples``.
        enforce_unconditional_generation (bool): only relevant when using a
            conditional model trained with condition dropout. If True, set the
            conditioning vector to (0,1) to generate unconditional molecules.
        regressor_guidance (bool): forwarded to ``sample``. When True, the
            returned dict also contains ``'context_global'``.
        max_retries (int or None): how many times a batch is re-sampled after
            ``sample`` raises an exception whose message contains the
            (case-sensitive) substring ``'nan'``. ``None`` (the default)
            retries indefinitely, matching the previous behaviour.

    Returns:
        tuple: ``(molecules, smiles)``.
            ``molecules`` maps ``'atom_types'``, ``'formal_charges'``,
            ``'positions'``, ``'adjacency_matrices'``, ``'node_mask'`` and
            ``'z_h'`` to CPU tensors concatenated over batches. It also maps
            ``'context_global'`` when ``args.context_node_nf > 0`` or
            ``regressor_guidance`` is True.
            ``smiles`` is the result of ``smiles_from_2d_mols_list`` applied to
            the molecules built by ``build_2D_mols``.

    Raises:
        ValueError: if ``n_samples`` or ``batch_size`` is not positive, or
            ``max_retries`` is negative.
        RuntimeError: if a batch still fails with a NaN error after
            ``max_retries`` retries.
    """
    # Fail loudly instead of hitting ZeroDivisionError in the batch-count
    # computation, or passing a non-positive count to nodes_dist.sample.
    if n_samples <= 0:
        raise ValueError(f"n_samples must be positive, got {n_samples}")
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")
    if max_retries is not None and max_retries < 0:
        raise ValueError(f"max_retries must be non-negative or None, got {max_retries}")

    model_sample.eval()
    batch_size = min(batch_size, n_samples)

    keep_context = args.context_node_nf > 0 or regressor_guidance
    keys = ['atom_types', 'formal_charges', 'positions', 'adjacency_matrices', 'node_mask', 'z_h']
    if keep_context:
        keys.append('context_global')
    molecules = {key: [] for key in keys}

    n_batches = math.ceil(n_samples / batch_size)
    for i in tqdm(range(n_batches)):
        # Every batch is full except possibly the last, which takes the remainder.
        n_mols = min(batch_size, n_samples - i * batch_size)
        nodesxsample = nodes_dist.sample(n_mols)

        # Some batches raise NaN-related errors. Re-sample the same batch
        # (same node counts); any other exception propagates unchanged.
        n_retries = 0
        while True:
            try:
                (x, atom_types, formal_charges, adjacency_matrices, node_mask,
                 _edge_mask, z_h, context_global) = sample(
                    args, device, model_sample, dataset_info, prop_dist,
                    nodesxsample=nodesxsample,
                    enforce_unconditional_generation=enforce_unconditional_generation,
                    regressor_guidance=regressor_guidance)
            except Exception as e:
                if 'nan' not in str(e):
                    raise
                n_retries += 1
                if max_retries is not None and n_retries > max_retries:
                    raise RuntimeError(
                        f"batch {i} still failing with NaN after {max_retries} retries") from e
                continue
            break

        batch = {
            'atom_types': atom_types,
            'formal_charges': formal_charges,
            'positions': x,
            'adjacency_matrices': adjacency_matrices,
            'node_mask': node_mask,
            'z_h': z_h,
        }
        if keep_context:
            batch['context_global'] = context_global
        for key in molecules:
            molecules[key].append(batch[key].detach().cpu())

    molecules = {key: torch.cat(chunks, dim=0) for key, chunks in molecules.items()}

    rdkit_mols = build_2D_mols(molecules, dataset_info)
    smiles = smiles_from_2d_mols_list(rdkit_mols)
    return molecules, smiles
