from rdkit import Chem

import wandb
from equivariant_diffusion.utils import assert_mean_zero_with_mask, remove_mean_with_mask,\
    assert_correctly_masked, sample_center_gravity_zero_gaussian_with_mask
from bond_type_prediction.utils import get_adj_matrix_from_batch
import numpy as np
import qm9.visualizer as vis
from qm9.analyze import analyze_stability_for_molecules
from qm9.analyze_joint_training import BasicSmilesMetrics, build_2D_mols, smiles_from_2d_mols_list, is_valid
from qm9.sampling import sample_chain, sample, sample_sweep_conditional
import utils
import qm9.utils as qm9utils
from qm9 import losses
import time
import torch
from tqdm import tqdm

from bond_type_prediction.eval import reconstruction_error
from conditional_generation.penalized_logP import compute_penalized_logP
from optimization.props.properties import get_morgan_fingerprint, similarity_tensors, penalized_logp, drd2, qed, tpsa 


def train_epoch(args, loader, epoch, model, model_dp, model_ema, ema, device, dtype, property_norms, property_norms_regression,
                optim, nodes_dist, gradnorm_queue, dataset_info, prop_dist, prop_encoder):
    """Run one training pass over `loader` and log NLL statistics to wandb.

    For every batch the positions are centred (and optionally noised and
    rotated), the node-feature dict, the optional conditioning context and the
    optional regressor target are assembled, and the loss is obtained from
    `losses.compute_loss_and_nll`. The loss is back-propagated, gradients are
    optionally clipped, the optimizer is stepped and the EMA copy is optionally
    updated.

    A batch is skipped (no optimizer step, no logging for that batch) when:
      * `args.train_diffusion` is set and the loss is NaN, or is >= 10 after
        epoch 0;
      * `loss.backward()` raises (model and batch are saved first);
      * gradient clipping raises (model and batch are saved first).

    Which model is being trained is derived from `args.train_diffusion` /
    `args.train_regressor`; the resulting `model_name` prefixes every wandb key
    and is embedded in the file names of the diagnostic dumps.

    Returns:
        None. Metrics are reported through `wandb.log` and `print`.

    Raises:
        Exception: if any model parameter contains NaN after `optim.step()`.
    """
    # Pick the wandb prefix / dump-file name for the model being trained.
    if args.train_diffusion:
        model_name = 'diffusion_model'
    elif args.train_regressor:
        model_name = 'regression_model'
        for prop in args.regression_target:
            model_name += '_' + prop
    else:
        model_name = 'vae'

    model_dp.train()
    model.train()

    # Per-batch NLL values of the batches that were actually optimised.
    nll_epoch = []
    n_iterations = len(loader)
    for i, data in enumerate(loader):
        x = data['positions'].to(device, dtype)
        # unsqueeze(2) appends a trailing axis to the atom mask.
        node_mask = data['atom_mask'].to(device, dtype).unsqueeze(2)
        edge_mask = data['edge_mask'].to(device, dtype)
        atomic_numbers_one_hot = data['atomic_numbers_one_hot'].to(device, dtype)
        formal_charges_one_hot = data['formal_charges_one_hot'].to(device, dtype)
        extra_atom_features = data['atomic_features'].to(device, dtype) if args.use_extra_atomic_features else None
        if model_name == 'vae' or args.trainable_ae:
            # converting between adj_list and adj_matrix is slow so we cached them ;)!
            if 'adj_matrix' in data:
                adj_gt = data['adj_matrix'].to(device, int)
            else:
                adj_gt = get_adj_matrix_from_batch(data).to(device)
        else:
            # unless we're training vae, we do not need the adjacency matrix, so we skip loading it
            adj_gt = None

        x = remove_mean_with_mask(x, node_mask)

        if args.augment_noise > 0:
            # Add noise eps ~ N(0, augment_noise) around points.
            eps = sample_center_gravity_zero_gaussian_with_mask(x.size(), x.device, node_mask)
            x = x + eps * args.augment_noise
            # Re-centre after adding the noise.
            x = remove_mean_with_mask(x, node_mask)

        if args.data_augmentation:
            x = utils.random_rotation(x).detach()

        # Sanity checks on masking and centring before the batch reaches the model.
        check_mask_correct([x, atomic_numbers_one_hot, formal_charges_one_hot], node_mask)
        assert_mean_zero_with_mask(x, node_mask)

        h = {'atomic_numbers_one_hot': atomic_numbers_one_hot, 'formal_charges_one_hot': formal_charges_one_hot, 'extra_atom_features': extra_atom_features}

        if len(args.conditioning) > 0:
            # Dropout probability forwarded to prepare_context together with args.condition_dropout.
            dropout_p = 0.1
            context = qm9utils.prepare_context(args.conditioning, data, property_norms, prop_encoder=prop_encoder, condition_dropout=args.condition_dropout, dropout_p=dropout_p).to(device, dtype)
            assert_correctly_masked(context, node_mask)
        else:
            context = None

        if len(args.regression_target) > 0:
            # NOTE(review): args.regression_target is iterated element-wise above when
            # building model_name but compared as a single value here — confirm its type.
            if args.regression_target in ['morgan_fingerprint']:
                regressor_target = qm9utils.prepare_classification_target(args.regression_target, data, property_norms_regression).to(device, dtype)
            else:
                regressor_target = qm9utils.prepare_regression_target(args.regression_target, data, property_norms_regression).to(device, dtype)
        else:
            regressor_target = None

        optim.zero_grad()

        # transform batch through flow
        nll, reg_term, mean_abs_z = losses.compute_loss_and_nll(args, model, nodes_dist,
                                                                x, h, node_mask, edge_mask, context, regressor_target, adj_gt)
        # standard nll from forward KL
        loss = nll + args.ode_regularization * reg_term
        # Diffusion training only: drop batches with NaN loss, or (after epoch 0) loss >= 10.
        if args.train_diffusion and ( (epoch > 0 and loss >= 10) or torch.isnan(loss)):
            print(f'Encountered high or nan loss value: {loss}. Skipping current batch')
            continue

        try:
            loss.backward()
        except Exception as e:
            # Dump the model and the offending batch for post-mortem, then move on.
            print('Cannot take gradients! Saving model and batch')
            print(f'loss: {loss}')
            print(e)
            utils.save_model(model, f'outputs/{args.exp_name}/divergent_{model_name}.npy')
            torch.save(data, f'outputs/{args.exp_name}/divergent_batch.npy')
            print('Cannot take gradients! Skipping current batch')
            continue

        if args.clip_grad:
            try:
                grad_norm = utils.gradient_clipping(model, gradnorm_queue, train_diffusion=args.train_diffusion)
            except Exception as e:
                # Same post-mortem dump as above, for failures raised by the clipping helper.
                print('Gradients are bad! Saving model and batch')
                print(f'loss: {loss}')
                print(e)
                utils.save_model(model, f'outputs/{args.exp_name}/bad_gradient_{model_name}.npy')
                torch.save(data, f'outputs/{args.exp_name}/bad_gradient_batch.npy')
                print(f'Encountered infinite gradients. Skipping current batch')
                continue

        else:
            # Clipping disabled: report a zero gradient norm in the progress line.
            grad_norm = 0.

        optim.step()
        # if after taking an optimizer step, any one of the model's parameters becomes nan, we stop the training.
        if any([torch.any(torch.isnan(p.data)) for p in model.parameters()]):
            raise Exception("Some of the model's weights are NaN. Stopping the training.")

        # Update EMA if enabled.
        if args.ema_decay > 0:
            ema.update_model_average(model_ema, model)

        # TODO: fix sampling
        if args.verbose and i % args.n_report_steps == 0:
            print(f"\rEpoch: {epoch}, iter: {i}/{n_iterations}, "
                  f"Loss {loss.item():.2f}, NLL: {nll.item():.2f}, "
                  f"RegTerm: {reg_term.item():.1f}, "
                  f"GradNorm: {grad_norm:.1f}, ")
        nll_epoch.append(nll.item())
        # TODO: fix this, this runs at the beginning of the 10th epoch
        # The leading `False` disables this mid-epoch sampling/visualisation block entirely.
        if False and (epoch % args.test_epochs == 0) and (i % args.visualize_every_batch == 0) and not (epoch == 0 and i == 0) and args.train_diffusion:
            start = time.time()
            if len(args.conditioning) > 0:
                save_and_sample_conditional(args, device, model_ema, prop_dist, dataset_info, epoch=epoch)
            save_and_sample_chain(model_ema, args, device, dataset_info, prop_dist, epoch=epoch,
                                  batch_id=str(i))
            sample_different_sizes_and_save(model_ema, nodes_dist, args, device, dataset_info,
                                            prop_dist, epoch=epoch)
            print(f'Sampling took {time.time() - start:.2f} seconds')

            vis.visualize(f"outputs/{args.exp_name}/epoch_{epoch}_{i}", dataset_info=dataset_info, wandb=wandb)
            vis.visualize_chain(f"outputs/{args.exp_name}/epoch_{epoch}_{i}/chain/", dataset_info, wandb=wandb)
            if len(args.conditioning) > 0:
                vis.visualize_chain("outputs/%s/epoch_%d/conditional/" % (args.exp_name, epoch), dataset_info,
                                    wandb=wandb, mode='conditional')
        # One committed wandb record per optimised batch.
        wandb.log({f"{model_name}/Batch NLL": nll.item()}, commit=True)
    # Epoch-level values are logged with commit=False.
    wandb.log({f"{model_name}/Epoch": epoch}, commit=False)
    # NOTE(review): nll_epoch is empty if every batch was skipped — confirm np.mean's result is acceptable then.
    wandb.log({f"{model_name}/Train Epoch NLL": np.mean(nll_epoch)}, commit=False)


def check_mask_correct(variables, node_mask):
    """Run `assert_correctly_masked` on every non-empty entry of `variables`.

    Entries whose `len()` is zero are skipped; all others are checked against
    `node_mask` in the order given.
    """
    # Lazy filter so each length check happens right before its mask check.
    non_empty = (candidate for candidate in variables if len(candidate) > 0)
    for candidate in non_empty:
        assert_correctly_masked(candidate, node_mask)


def test(args, loader, epoch, eval_model, device, dtype, property_norms, property_norms_regression, nodes_dist, partition='Test', prop_encoder=None):
    """Return the sample-weighted mean NLL of `eval_model` over `loader`.

    The model is put in eval mode and every batch is processed under
    `torch.no_grad()`. Each batch's NLL (from `losses.compute_loss_and_nll`)
    is weighted by the batch size taken from the first axis of the positions
    tensor. `partition` and `epoch` are only used in the progress message.

    Returns:
        Accumulated NLL divided by the number of samples seen.
    """
    eval_model.eval()
    with torch.no_grad():
        nll_sum = 0
        seen = 0

        total_batches = len(loader)

        for batch_idx, batch in enumerate(loader):
            positions = batch['positions'].to(device, dtype)
            batch_size = positions.size(0)
            node_mask = batch['atom_mask'].to(device, dtype).unsqueeze(2)
            edge_mask = batch['edge_mask'].to(device, dtype)
            atom_one_hot = batch['atomic_numbers_one_hot'].to(device, dtype)
            charge_one_hot = batch['formal_charges_one_hot'].to(device, dtype)
            if args.use_extra_atomic_features:
                extra_features = batch['atomic_features'].to(device, dtype)
            else:
                extra_features = None

            # The adjacency matrix is only loaded for VAE-style runs or a trainable autoencoder.
            needs_adjacency = not (args.train_diffusion or args.train_regressor) or args.trainable_ae
            if not needs_adjacency:
                adj_gt = None
            elif 'adj_matrix' in batch:
                adj_gt = batch['adj_matrix'].to(device, int)
            else:
                adj_gt = get_adj_matrix_from_batch(batch).to(device)

            if args.augment_noise > 0:
                # Add noise eps ~ N(0, augment_noise) around points.
                noise = sample_center_gravity_zero_gaussian_with_mask(positions.size(),
                                                                      positions.device,
                                                                      node_mask)
                positions = positions + noise * args.augment_noise

            positions = remove_mean_with_mask(positions, node_mask)
            check_mask_correct([positions, atom_one_hot, charge_one_hot], node_mask)
            assert_mean_zero_with_mask(positions, node_mask)

            h = {
                'atomic_numbers_one_hot': atom_one_hot,
                'formal_charges_one_hot': charge_one_hot,
                'extra_atom_features': extra_features,
            }

            context = None
            if len(args.conditioning) > 0:
                context = qm9utils.prepare_context(
                    args.conditioning, batch, property_norms, prop_encoder=prop_encoder,
                    condition_dropout=args.condition_dropout, dropout_p=0.1).to(device, dtype)
                assert_correctly_masked(context, node_mask)

            regressor_target = None
            if len(args.regression_target) > 0:
                if args.regression_target in ['morgan_fingerprint']:
                    build_target = qm9utils.prepare_classification_target
                else:
                    build_target = qm9utils.prepare_regression_target
                regressor_target = build_target(
                    args.regression_target, batch, property_norms_regression).to(device, dtype)

            # transform batch through flow
            # TODO: update for joint_training
            nll, _, _ = losses.compute_loss_and_nll(args, eval_model, nodes_dist, positions, h,
                                                    node_mask, edge_mask, context, regressor_target, adj_gt)

            # Weight by batch size so the final value is a per-sample average.
            nll_sum += nll.item() * batch_size
            seen += batch_size
            if args.verbose and batch_idx % args.n_report_steps == 0:
                print(f"\r {partition} NLL \t epoch: {epoch}, iter: {batch_idx}/{total_batches}, "
                      f"NLL: {nll_sum/seen:.2f}")

    return nll_sum/seen


def eval_vae_reconstruction(args, property_norms, vae_model, loader, device, dtype, mode='Val', log_to_wandb=True, prop_encoder=None, inject_noise=False):
    """Reconstruct every batch of `loader` with `vae_model` and aggregate accuracy.

    Per-batch counts returned by `reconstruction_error` are summed key by key
    over the loader. Four percentage accuracies (edges, atom types, formal
    charges, whole molecules) are then derived from the summed counts and
    added to the same dict. When `log_to_wandb` is true each accuracy is
    logged under `vae/{mode}_<metric>`.

    Returns:
        Dict holding the summed counts plus the four accuracy entries.
    """
    # (metric name, numerator key, denominator key) for each reported accuracy.
    accuracy_spec = (
        ('edge_accuracy', 'correct_edges', 'total_edges'),
        ('atom_types_accuracy', 'correct_atom_types', 'total_atoms'),
        ('formal_charges_accuracy', 'correct_formal_charges', 'total_atoms'),
        ('molecule_accuracy', 'correct_molecules', 'total_molecules'),
    )

    vae_model.eval()
    totals = {}
    with torch.no_grad():
        for batch in loader:
            positions = batch['positions'].to(device, dtype)
            node_mask = batch['atom_mask'].to(device, dtype).unsqueeze(2)
            edge_mask = batch['edge_mask'].to(device, dtype)
            atom_one_hot = batch['atomic_numbers_one_hot'].to(device, dtype)
            charge_one_hot = batch['formal_charges_one_hot'].to(device, dtype)
            if args.use_extra_atomic_features:
                extra_features = batch['atomic_features'].to(device, dtype)
            else:
                extra_features = None

            # Prefer the cached adjacency matrix when the batch carries one.
            if 'adj_matrix' not in batch:
                adj_gt = get_adj_matrix_from_batch(batch).to(device)
            else:
                adj_gt = batch['adj_matrix'].to(device, int)

            positions = remove_mean_with_mask(positions, node_mask)

            check_mask_correct([positions, atom_one_hot, charge_one_hot], node_mask)
            assert_mean_zero_with_mask(positions, node_mask)

            h = {
                'atomic_numbers_one_hot': atom_one_hot,
                'formal_charges_one_hot': charge_one_hot,
                'extra_atom_features': extra_features,
            }

            context = None
            if len(args.conditioning) > 0:
                context = qm9utils.prepare_context(args.conditioning, batch, property_norms, prop_encoder=prop_encoder).to(device, dtype)
                assert_correctly_masked(context, node_mask)

            adj_recon, atom_types_recon, formal_charges_recon = vae_model.reconstruct(
                positions, h, node_mask, edge_mask, context=context, inject_noise=inject_noise)

            batch_counts = reconstruction_error(adj_recon, atom_types_recon, formal_charges_recon, adj_gt, batch)
            for name in batch_counts:
                if name not in totals:
                    totals[name] = batch_counts[name]
                else:
                    totals[name] += batch_counts[name]

    for metric, hits_key, count_key in accuracy_spec:
        totals[metric] = (totals[hits_key] / totals[count_key]) * 100

    if log_to_wandb:
        for metric, _, _ in accuracy_spec:
            wandb.log({f'vae/{mode}_{metric}': totals[metric]})

    return totals

def eval_regression_model(args, property_norms, property_norms_regression, regression_model, loader, device, dtype, mode='Val', log_to_wandb=True, prop_encoder=None, t_lower=None, t_upper=None,):
    """Evaluate `regression_model` on `loader` and return one scalar score.

    For ordinary targets the score is the element-wise L1 error between the
    un-normalised prediction and the un-normalised target, averaged over the
    loader and then over properties. When `args.regression_target` matches
    'morgan_fingerprint', predictions are thresholded at 0 and a percentage
    accuracy is computed instead, along with separate accuracies on the
    target==0 and target==1 entries.

    `t_lower` / `t_upper` are forwarded to `regression_model.compute_pred`;
    when both are set and equal, the wandb keys carry a `t=<value>` suffix,
    otherwise an `all_t` suffix.

    Returns:
        `avg_mae.mean().item()` — the mean over properties of the per-property
        average metric.
    """
    # reduction='none' keeps the per-element errors so they can be averaged per property later.
    loss_l1 = torch.nn.L1Loss(reduction='none')
    regression_model.eval()
    maes = []
    # NOTE(review): args.regression_target is compared as a single value here but is
    # iterated property-by-property in the logging loop below — confirm its type.
    if args.regression_target in ['morgan_fingerprint']:
        acc_negs = []
        acc_poss = []

    with torch.no_grad():
        for i, data in enumerate(loader):
            x = data['positions'].to(device, dtype)
            node_mask = data['atom_mask'].to(device, dtype).unsqueeze(2)
            edge_mask = data['edge_mask'].to(device, dtype)
            atomic_numbers_one_hot = data['atomic_numbers_one_hot'].to(device, dtype)
            formal_charges_one_hot = data['formal_charges_one_hot'].to(device, dtype)
            extra_atom_features = data['atomic_features'].to(device, dtype) if args.use_extra_atomic_features else None

            x = remove_mean_with_mask(x, node_mask)

            # Sanity checks on masking and centring before the batch reaches the model.
            check_mask_correct([x, atomic_numbers_one_hot, formal_charges_one_hot], node_mask)
            assert_mean_zero_with_mask(x, node_mask)

            h = {'atomic_numbers_one_hot': atomic_numbers_one_hot, 'formal_charges_one_hot': formal_charges_one_hot, 'extra_atom_features': extra_atom_features}

            if len(args.conditioning) > 0:
                context = qm9utils.prepare_context(args.conditioning, data, property_norms, prop_encoder=prop_encoder).to(device, dtype)
                assert_correctly_masked(context, node_mask)
            else:
                context = None

            pred = regression_model.compute_pred(x, h, node_mask, edge_mask, context=context, t_lower=t_lower, t_upper=t_upper)
            # The target is requested without normalisation; predictions are un-normalised below to match.
            target = qm9utils.prepare_regression_target(args.regression_target, data, property_norms=None, normalize=False).to(device, dtype)

            if args.regression_target in ['morgan_fingerprint']:
                # Threshold the raw predictions at 0 to obtain binary labels.
                pred_labels = (pred>=0.).to(dtype)
                # it's actually accuracy but for consistency, we save it in variable mae
                mae = (pred_labels == target).sum() / (target.size(0) * target.size(1)) * 100
                metric_name = 'accuracy'

                # Accuracy restricted to entries whose target is 0 / 1 respectively.
                acc_neg = (pred_labels[target==0] == 0).sum() / (target == 0).sum() * 100
                acc_pos = (pred_labels[target==1] == 1).sum() / (target == 1).sum() * 100
                acc_negs.append(acc_neg)
                acc_poss.append(acc_pos)
            else:
                pred_unnormalized = qm9utils.unnormalize_regression_prediction(args.regression_target, pred, property_norms_regression)
                mae = loss_l1(pred_unnormalized, target) # (bs, n_props)
                metric_name = 'mae'

            maes.append(mae)

    # NOTE(review): in the fingerprint branch each `mae` comes from a full `.sum()` reduction;
    # confirm torch.cat accepts those entries (the pos/neg accuracies below use torch.stack instead).
    maes = torch.cat(maes, dim=0)
    if len(maes.size()) == 1:
        # Promote a 1-D result to a single-column 2-D tensor so the per-property mean below works.
        maes = maes.unsqueeze(1)
    avg_mae = maes.mean(dim=0)

    # NOTE(review): `metric_name` is only assigned inside the loader loop above.
    if log_to_wandb:
        for i, prop in enumerate(args.regression_target):
            if t_lower is not None and t_lower == t_upper:
                wandb.log({f'regression_model_{prop}/{mode}_{metric_name}_t={t_lower}': avg_mae[i].item()})
            else:
                wandb.log({f'regression_model_{prop}/{mode}_{prop}_{metric_name}_all_t': avg_mae[i].item()})

    if args.regression_target in ['morgan_fingerprint']:
        # Average the per-batch negative/positive accuracies (unweighted by batch size).
        acc_negs = torch.stack(acc_negs)
        avg_acc_neg = acc_negs.mean().item()
        acc_poss = torch.stack(acc_poss)
        avg_acc_pos = acc_poss.mean().item()
        if log_to_wandb:
            if t_lower is not None and t_lower == t_upper:
                wandb.log({f'regression_model_{args.regression_target}/{mode}_accuracy_neg_t={t_lower}': avg_acc_neg})
                wandb.log({f'regression_model_{args.regression_target}/{mode}_accuracy_pos_t={t_lower}': avg_acc_pos})
            else:
                wandb.log({f'regression_model_{args.regression_target}/{mode}_accuracy_neg_all_t': avg_acc_neg})
                wandb.log({f'regression_model_{args.regression_target}/{mode}_accuracy_pos_all_t': avg_acc_pos})

    return avg_mae.mean().item()


def save_and_sample_chain(model, args, device, dataset_info, prop_dist,
                          epoch=0, id_from=0, batch_id=''):
    """Sample one chain with `sample_chain` and write it out as xyz files.

    Files go to `outputs/<exp_name>/epoch_<epoch>_<batch_id>/chain/` via
    `vis.save_xyz_file`, numbered starting at `id_from`.

    Returns:
        The `(one_hot, charges, x)` triple produced by `sample_chain`.
    """
    one_hot, charges, x = sample_chain(
        args=args,
        device=device,
        flow=model,
        n_tries=1,
        dataset_info=dataset_info,
        prop_dist=prop_dist,
    )

    chain_dir = f'outputs/{args.exp_name}/epoch_{epoch}_{batch_id}/chain/'
    vis.save_xyz_file(chain_dir, one_hot, charges, x, dataset_info, id_from, name='chain')

    return one_hot, charges, x


def sample_different_sizes_and_save(model, nodes_dist, args, device, dataset_info, prop_dist,
                                    n_samples=5, epoch=0, batch_size=100, batch_id=''):
    """Sample molecules of sizes drawn from `nodes_dist` and save them as xyz files.

    `batch_size` is capped at `n_samples`; `int(n_samples / batch_size)` rounds
    of sampling are performed, and each round is written to
    `outputs/<exp_name>/epoch_<epoch>_<batch_id>/` with file ids offset by the
    number of molecules already produced.

    NOTE(review): this unpacks five values from `sample(...)` and passes
    `prop_dist` by keyword, whereas `analyze_and_save` in this file unpacks
    eight values from a differently-shaped call — confirm which matches the
    current `sample` signature.
    """
    batch_size = min(batch_size, n_samples)
    n_rounds = int(n_samples/batch_size)
    for round_idx in range(n_rounds):
        sizes = nodes_dist.sample(batch_size)
        one_hot, charges, x, node_mask, edge_mask = sample(
            args, device, model,
            prop_dist=prop_dist,
            nodesxsample=sizes,
            dataset_info=dataset_info,
        )
        print(f"Generated molecule: Positions {x[:-1, :, :]}")
        first_id = batch_size * round_idx
        target_dir = f'outputs/{args.exp_name}/epoch_{epoch}_{batch_id}/'
        vis.save_xyz_file(target_dir, one_hot, charges, x, dataset_info,
                          first_id, name='molecule')


def analyze_and_save(epoch, model_sample, nodes_dist, args, device, dataset_info, prop_dist,
                     n_samples=1000, batch_size=100, log_to_wandb=True, prop_encoder=None):
    """Sample `n_samples` molecules and compute SMILES-based quality metrics.

    Molecules are sampled in batches (`n_samples` must be a multiple of the
    effective batch size), moved to CPU and concatenated per field. They are
    then converted to RDKit molecules and SMILES and scored with
    `BasicSmilesMetrics`. When `args.context_node_nf > 0` the generated
    molecules are additionally compared against the conditioning values:
    fingerprint accuracy/similarity if 'morgan_fingerprint' is in
    `args.conditioning`, otherwise a property MAE.

    Returns:
        `(rdkit_tuple, unique_valid_smiles)` when `log_to_wandb` is true,
        otherwise `(rdkit_tuple, unique_valid_smiles, molecules)` with the
        concatenated sample tensors appended.
    """
    model_sample.eval()
    print(f'Analyzing molecule stability at epoch {epoch}...')
    batch_size = min(batch_size, n_samples)
    assert n_samples % batch_size == 0

    field_names = ('atom_types', 'formal_charges', 'positions', 'adjacency_matrices', 'node_mask', 'z_h')
    collected = {name: [] for name in field_names}
    if args.context_node_nf > 0:
        collected['context_global'] = []

    for _ in tqdm(range(int(n_samples/batch_size))):
        nodesxsample = nodes_dist.sample(batch_size)
        if args.use_ghost_nodes:
            # When using ghost nodes, we sample molecules to have the same size, which is the maximum size
            nodesxsample[:] = dataset_info['max_n_nodes']
        (x, atom_types, formal_charges, adjacency_matrices,
         node_mask, edge_mask, z_h, context_global) = sample(
            args, device, model_sample, dataset_info, prop_dist,
            nodesxsample=nodesxsample, prop_encoder=prop_encoder)

        # Same field order as `collected`, so tensors are detached/moved in a fixed sequence.
        batch_fields = {
            'atom_types': atom_types,
            'formal_charges': formal_charges,
            'positions': x,
            'adjacency_matrices': adjacency_matrices,
            'node_mask': node_mask,
            'z_h': z_h,
        }
        if args.context_node_nf > 0:
            batch_fields['context_global'] = context_global
        for name, tensor in batch_fields.items():
            collected[name].append(tensor.detach().cpu())

    molecules = {name: torch.cat(chunks, dim=0) for name, chunks in collected.items()}

    rdkit_mols = build_2D_mols(molecules, dataset_info, use_ghost_nodes=args.use_ghost_nodes)
    smiles = smiles_from_2d_mols_list(rdkit_mols)
    if args.context_node_nf > 0:
        conditioned = molecules['context_global'].squeeze()
        if 'morgan_fingerprint' not in args.conditioning:
            mae = compute_prop_mae_on_generated_mols(smiles, conditioned)
            if log_to_wandb:
                wandb.log({'diffusion_model/penalized_logP MAE': mae})
        else:
            accuracy, accuracy_neg, accuracy_pos, avg_similarity = compute_accuracy_on_generated_mols(smiles, conditioned)
            if log_to_wandb:
                # One wandb.log call per metric, in this order.
                fingerprint_metrics = (
                    ('accuracy', accuracy),
                    ('accuracy_neg', accuracy_neg),
                    ('accuracy_pos', accuracy_pos),
                    ('avg_similarity', avg_similarity),
                )
                for label, value in fingerprint_metrics:
                    wandb.log({f'diffusion_model/morgan_fingerprint {label}': value})

    metrics = BasicSmilesMetrics(dataset_info, n_generated=n_samples)
    rdkit_tuple, unique_valid_smiles = metrics.evaluate(smiles)

    if log_to_wandb:
        if rdkit_tuple is not None:
            wandb.log({'diffusion_model/Validity': rdkit_tuple[0], 'diffusion_model/Uniqueness': rdkit_tuple[1], 'diffusion_model/Novelty': rdkit_tuple[2]})
        return rdkit_tuple, unique_valid_smiles
    # probably debugging, then return molecules as well
    return rdkit_tuple, unique_valid_smiles, molecules


def compute_accuracy_on_generated_mols(smiles_list, conditioned_values):
    """
    Compare Morgan fingerprints of generated molecules with the conditioned ones.

    smiles_list: list of generated SMILES
    conditioned_values: contains the conditioned morgan fingerprints, indexed
        by position in smiles_list

    Only SMILES accepted by `is_valid` contribute.

    Returns: (accuracy, accuracy_neg, accuracy_pos, avg_similarity) — overall
        bit accuracy in percent, accuracy on conditioned-0 bits, accuracy on
        conditioned-1 bits, and the mean of `similarity_tensors` over the
        molecules. All four are -1. when no SMILES is valid.
    """
    print('Computing Accuracy of fingerprints on generated mols')
    # (generated fingerprint, conditioned fingerprint) for every valid SMILES.
    matched = [
        (get_morgan_fingerprint(smiles, n_bits=1024), conditioned_values[idx])
        for idx, smiles in enumerate(smiles_list)
        if is_valid(smiles)
    ]

    if len(matched) == 0:
        print('Found no valid mols. Returning -1')
        return -1., -1., -1., -1.

    print(f'Found {len(matched)} valid molecules on which we compute the fingerprints metrics.')
    generated = torch.Tensor([gen for gen, _ in matched])
    conditioned = torch.stack([cond for _, cond in matched])

    n_bits_total = conditioned.size(0) * conditioned.size(1)
    accuracy = (generated == conditioned).sum() / n_bits_total * 100

    off_bits = conditioned == 0
    accuracy_neg = (generated[off_bits] == 0).sum() / off_bits.sum() * 100
    on_bits = conditioned == 1
    accuracy_pos = (generated[on_bits] == 1).sum() / on_bits.sum() * 100

    per_molecule = [
        similarity_tensors(gen_row, cond_row, n_bits=1024)
        for gen_row, cond_row in zip(generated, conditioned)
    ]
    avg_similarity = np.mean(per_molecule)

    return accuracy, accuracy_neg, accuracy_pos, avg_similarity


# TODO: generalize this to multiple props
def compute_prop_mae_on_generated_mols(smiles_list, conditioned_values, prop='penalized_logP'):
    """Return the MAE between a property of generated molecules and its conditioned value.

    Only SMILES accepted by `is_valid` are scored, and each distinct molecule
    (by RDKit canonical SMILES) is scored once — the first occurrence wins.

    Args:
        smiles_list: generated SMILES strings.
        conditioned_values: conditioned property values, indexed by position
            in `smiles_list`.
        prop: one of 'penalized_logP', 'qed', 'drd2', 'tpsa'.

    Returns:
        The mean absolute error as a float, or -1. if no valid molecule was found.

    Raises:
        ValueError: if `prop` is not one of the supported property names.
    """
    # Fail fast with a clear message; previously an unknown name only surfaced
    # as an unbound-variable error once the first valid molecule was scored.
    supported_props = ('penalized_logP', 'qed', 'drd2', 'tpsa')
    if prop not in supported_props:
        raise ValueError(f'Unsupported property {prop!r}; expected one of {supported_props}')

    print(f'Computing MAE of {prop} on generated mols')
    generated_scores = []
    conditioned_scores = []
    processed_smiles = set()
    for i, smiles in enumerate(smiles_list):
        if not is_valid(smiles):
            continue
        # Canonicalise once and reuse it for both the duplicate check and the bookkeeping.
        canonical_smiles = Chem.MolToSmiles(Chem.MolFromSmiles(smiles))
        if canonical_smiles in processed_smiles:
            continue

        if prop == 'penalized_logP':
            generated_score = penalized_logp(smiles)
        elif prop == 'qed':
            generated_score = qed(smiles)
        elif prop == 'drd2':
            generated_score = drd2(smiles)
        else:  # 'tpsa' — guaranteed by the validation above
            generated_score = tpsa(smiles)

        generated_scores.append(generated_score)
        conditioned_scores.append(conditioned_values[i])
        processed_smiles.add(canonical_smiles)

    if len(generated_scores) == 0:
        print('Found no valid mols. Returning -1')
        return -1.

    print(f'Found {len(generated_scores)} valid and unique molecules on which we compute the MAE.')
    generated_scores = torch.Tensor(generated_scores)
    conditioned_scores = torch.Tensor(conditioned_scores)
    l1_loss = torch.nn.L1Loss()
    return l1_loss(generated_scores, conditioned_scores).item()

def compute_prop_values_on_generated_mols(smiles_list, prop='penalized_logP'):
    """Score every valid, distinct generated molecule and return the scores, best first.

    Only SMILES accepted by `is_valid` are scored, and each distinct molecule
    (by RDKit canonical SMILES) is scored once.

    Args:
        smiles_list: generated SMILES strings.
        prop: one of 'penalized_logP', 'qed', 'drd2', 'tpsa'.

    Returns:
        List of property scores sorted in descending order (empty if no
        SMILES is valid).

    Raises:
        ValueError: if `prop` is not one of the supported property names.
    """
    # Fail fast with a clear message; previously an unknown name only surfaced
    # as an unbound-variable error once the first valid molecule was scored.
    supported_props = ('penalized_logP', 'qed', 'drd2', 'tpsa')
    if prop not in supported_props:
        raise ValueError(f'Unsupported property {prop!r}; expected one of {supported_props}')

    print(f'Computing values of {prop} on generated mols')
    generated_scores = []
    processed_smiles = set()
    for smiles in smiles_list:
        if not is_valid(smiles):
            continue
        # Canonicalise once and reuse it for both the duplicate check and the bookkeeping.
        canonical_smiles = Chem.MolToSmiles(Chem.MolFromSmiles(smiles))
        if canonical_smiles in processed_smiles:
            continue

        if prop == 'penalized_logP':
            generated_score = penalized_logp(smiles)
        elif prop == 'qed':
            generated_score = qed(smiles)
        elif prop == 'drd2':
            generated_score = drd2(smiles)
        else:  # 'tpsa' — guaranteed by the validation above
            generated_score = tpsa(smiles)

        generated_scores.append(generated_score)
        processed_smiles.add(canonical_smiles)

    return sorted(generated_scores, reverse=True)

def save_and_sample_conditional(args, device, model, prop_dist, dataset_info, epoch=0, id_from=0):
    """Run `sample_sweep_conditional` and write the result out as xyz files.

    Files go to `outputs/<exp_name>/epoch_<epoch>/conditional/` via
    `vis.save_xyz_file`, numbered starting at `id_from`; the node mask from
    the sampler is forwarded to the writer.

    Returns:
        The `(one_hot, charges, x)` part of the sampler output.
    """
    sweep = sample_sweep_conditional(args, device, model, dataset_info, prop_dist)
    one_hot, charges, x, node_mask = sweep

    out_dir = 'outputs/%s/epoch_%d/conditional/' % (args.exp_name, epoch)
    vis.save_xyz_file(out_dir, one_hot, charges, x, dataset_info,
                      id_from, name='conditional', node_mask=node_mask)

    return one_hot, charges, x
