import argparse
import os
from os.path import join
import json
import logging
import random
import math
import time
import itertools
from rdkit import Chem
import copy
from tqdm import tqdm
import pickle
import torch.nn.functional as F

from guacamol.assess_distribution_learning import assess_distribution_learning, _assess_distribution_learning
from guacamol.utils.helpers import setup_default_logger

from guacamol_evaluation.ldm_generator import LDMSmilesSampler
from guacamol_evaluation.generator import EDMSmilesSampler
from guacamol_evaluation.load_model import load_ldm_model, load_regressor
from qm9.data.molecule_class import Molecule, MoleculeBatch
from qm9.models import DistributionAddNodes
from qm9.analyze_joint_training import is_valid, get_largest_connected_component, canon_smiles, compute_diversity
from optimization.props.properties import get_morgan_fingerprint, similarity, penalized_logp, drd2, qed, tpsa
from synthetic_coordinates.rdkit_helpers import smiles_to_mol

import numpy as np
import torch

def select_indices(z_x, z_h, node_mask, indices):
    """Pick the entries addressed by ``indices`` from a latent batch.

    ``z_x`` and ``node_mask`` are indexed as ``obj[indices]``; every value
    stored in the ``z_h`` dict is indexed the same way.

    Returns:
        A ``(z_x, z_h, node_mask)`` tuple with the selected entries. The
        returned ``z_h`` is a new dict with the same keys as the input.
    """
    picked_h = {}
    for name, tensor in z_h.items():
        picked_h[name] = tensor[indices]
    return z_x[indices], picked_h, node_mask[indices]

def repeat(z_x, z_h, node_mask, n_repeat):
    """Tile a latent batch ``n_repeat`` times along its first dimension.

    Each tensor (``z_x``, ``node_mask`` and every value of the ``z_h`` dict)
    is expanded with ``tensor.repeat(n_repeat, 1, 1)``, i.e. the remaining
    two dimensions are left unchanged.

    Returns:
        A ``(z_x, z_h, node_mask)`` tuple of the tiled tensors; ``z_h`` is a
        new dict with the same keys as the input.
    """
    tiling = (n_repeat, 1, 1)
    tiled_h = {name: tensor.repeat(*tiling) for name, tensor in z_h.items()}
    return z_x.repeat(*tiling), tiled_h, node_mask.repeat(*tiling)

def concatenate(z_x1, z_h1, node_mask1, z_x2, z_h2, node_mask2):
    """Stack two latent batches along dimension 0.

    The keys of ``z_h1`` decide which entries of the ``z_h`` dicts are
    merged; ``z_h2`` must provide every one of those keys.

    Returns:
        A ``(z_x, z_h, node_mask)`` tuple in which the first batch is
        followed by the second one.
    """
    merged_h = {}
    for name in z_h1:
        merged_h[name] = torch.cat((z_h1[name], z_h2[name]), dim=0)
    merged_x = torch.cat((z_x1, z_x2), dim=0)
    merged_mask = torch.cat((node_mask1, node_mask2), dim=0)
    return merged_x, merged_h, merged_mask

# To run from a terminal, go to the repository root and launch this script with the root on PYTHONPATH, e.g.
# PYTHONPATH="${PYTHONPATH}:." python <path/to/this_script>.py
# --exp_name <ldm_experiment> --guidance_exp_name <regressor_experiment> --optimization_prop qed --batch_size 100
if __name__ == '__main__':
    # Script entry point: parse the command line, then load the latent diffusion model and the
    # property regressor whose gradients are used as guidance during editing.
    setup_default_logger()

    parser = argparse.ArgumentParser(description='Molecule Optimization and scaffolding evaluation',
                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument("--exp_name", default="ldm_training_explicitH_simple_dec_early_enc")
    parser.add_argument('--scaffolding', action='store_true', default=False, help='whether to evaluate scaffolding instead of optimization')
    parser.add_argument('--smiles_folder', default='data/zinc250k_explicitH_extra_features/smiles')
    parser.add_argument("--batch_size", type=int, default=100)
    parser.add_argument("--t_add_nodes", type=int, default=400)
    parser.add_argument("--t_optimize", type=int, default=200)
    parser.add_argument("--step_size", type=float, default=0.5)
    parser.add_argument('--guidance_exp_name', type=str, default=None)
    parser.add_argument("--guidance_scale", type=float, default=1.0)
    parser.add_argument("--schedule_id", type=int, default=-1)

    parser.add_argument('--ckpt_prefix', type=str, default="", help=" '' or 'best_fcd_' or 'last_' ")
    parser.add_argument('--optimization_prop', type=str, default="penalized_logP", help="one of 'penalized_logP', 'qed', 'drd2', 'tpsa' ")
    parser.add_argument("--similarity_threshold", type=float, default=0.4)
    parser.add_argument('--guidance_linear_schedule', action='store_true', default=False, help='whether to scale the guidance with the time_step')
    parser.add_argument("--guidance_schedule_bias", type=float, default=0.0)
    parser.add_argument("--guidance_scale_t_add_nodes", type=float, default=1.0)
    parser.add_argument('--hp_search', action='store_true', default=False, help='are we doing hyperparameter search')

    args = parser.parse_args()

    # set up optimization prop-related stuff
    # all_props and all_prop_fns are parallel lists: the position of a property name selects both
    # its scoring function and the regressor output column used for guidance.
    all_props = ['penalized_logP', 'qed', 'drd2', 'tpsa']
    all_prop_fns = [penalized_logp, qed, drd2, tpsa]

    # Fail fast on unusable arguments, before any (expensive) model is loaded.
    if args.optimization_prop not in all_props:
        parser.error(f"--optimization_prop must be one of {all_props}, got '{args.optimization_prop}'")
    if args.guidance_exp_name is None:
        # The regressor checkpoint folder is built from this name below; None cannot be concatenated.
        parser.error("--guidance_exp_name is required: it names the regressor experiment folder under 'outputs/'")

    prop_idx = all_props.index(args.optimization_prop)
    compute_prop_fn = all_prop_fns[prop_idx]

    # load generative model
    exp_folder = 'outputs/' + args.exp_name
    model, utilities_dict = load_ldm_model(exp_folder, mol_optimizer=True, ckpt_prefix=args.ckpt_prefix)
    device = utilities_dict['device']
    dtype = torch.float32
    model.set_device_and_dtype(device, dtype)

    # load regressor model
    regressor_prop, prop_dist = load_regressor(exp_folder='outputs/'+args.guidance_exp_name, regression_target=all_props)
    #regressor_structure = load_regressor(exp_folder='outputs/train_regressor_morgan_fingerprint_1024_with_cls_weight', regression_target='morgan_fingerprint')


    ##########
    # BEGIN HP optimization for editing with regressor guidance
    # currently optimizing:
    # t_optim and scale factor :green_mark:
    # currently testing evolutionary algorithm :green_mark:
    # currently testing evolutionary algorithm with different schedules
    ##########

    # optimization
    # t_optimize_schedules = [[350,300,250], [350,250,200], [300,250,200]]
    # t_to_scale = {200: [3,4], 250: [2,3], 300: [1,2], 350: [0.5,1]}
    # schedules = []
    # for t in t_to_scale:
    #     for s in t_to_scale[t]:
    #         schedules.append([[t,t,t],[s,s,s]])
    # for t1, t2, t3 in t_optimize_schedules:
    #     for s1, s2, s3 in itertools.product(t_to_scale[t1], t_to_scale[t2], t_to_scale[t3]):
    #         schedules.append([[t1, t2, t3], [s1, s2, s3]])
    # schedule = schedules[args.schedule_id]

    #schedule = [[350, 250, 200], [0.5, 2.0, 3.0]]

    #schedule = [[250, 250, 250, 250], [3.0, 3.0, 3.0, 3.0]]


    # schedule[0] holds the diffusion time step and schedule[1] the guidance scale of each of the
    # four optimization rounds; every round currently uses the same command-line values.
    schedule = [[args.t_optimize for _ in range(4)], [args.guidance_scale for _ in range(4)]]

    # Hyperparameter-search runs and final runs are written below different roots. The run folder
    # name encodes the optimized property, the time step, the guidance scale and the schedule bias.
    if args.hp_search:
        results_root = 'optimization/results_optimization_hp_search'
    else:
        results_root = 'optimization/results_optimization_final'
    os.makedirs(results_root, exist_ok=True)
    run_tag = f'results_{args.optimization_prop}_t={args.t_optimize}_s={args.guidance_scale}_b={args.guidance_schedule_bias}'
    folder_name = f'{results_root}/{run_tag}'
    if args.guidance_linear_schedule:
        folder_name = folder_name + '_guidance_linear_schedule'
    os.makedirs(folder_name, exist_ok=True)

    # load data to optimize
    # The data folder is named after the optimized property; for penalized logP it additionally
    # carries a suffix for the similarity threshold (0.4 -> '_4', 0.6 -> '_6').
    data_folder = f'data/zinc250k_constrained_optimization/{args.optimization_prop}'
    if args.optimization_prop == 'penalized_logP':
        if args.similarity_threshold == 0.4:
            data_folder += '_4'
        elif args.similarity_threshold == 0.6:
            data_folder += '_6'
    with open(join(data_folder, 'test.txt'), 'r') as smiles_file:
        # Blank lines (e.g. a trailing empty line) are not molecules and are skipped.
        input_smiles_list = [line.strip() for line in smiles_file if line.strip()]

    ### --> add nodes
    # compute conditional distribution of number of nodes
    # n_atoms_dist[n_source][n_target] counts the training pairs whose source molecule has n_source
    # atoms and whose target molecule has n_target atoms. The counts are cached next to the data.
    print('Computing distribution of number of nodes')
    dist_file_path = join(data_folder, 'n_atoms_dist.p')
    if os.path.exists(dist_file_path):
        # NOTE: pickle.load can execute arbitrary code from a crafted file; this cache is expected to
        # have been written by this script (see the dump below) and must not come from untrusted sources.
        with open(dist_file_path, 'rb') as fp:
            n_atoms_dist = pickle.load(fp)
    else:
        with open(join(data_folder, 'train_pairs.txt'), 'r') as f:
            training_pairs = [line.strip() for line in f if line.strip()]

        def get_n_atoms_from_smiles(smiles):
            """Return the atom count of the molecule built from ``smiles`` with only explicit hydrogens."""
            mol = smiles_to_mol(smiles, only_explicit_H=True)
            return mol.GetNumAtoms()

        n_atoms_dist = {}
        for pair in training_pairs:
            # Each line holds a whitespace-separated "<source SMILES> <target SMILES>" pair.
            s1, s2 = pair.split()

            n_atoms1 = get_n_atoms_from_smiles(s1)
            # if target mol is smaller, we treat it as if we keep the same number of atoms and then
            # the model can decide to remove atoms
            n_atoms2 = max(get_n_atoms_from_smiles(s2), n_atoms1)

            # Plain nested dicts are kept (no defaultdict) so the pickled cache stays a simple dict.
            target_counts = n_atoms_dist.setdefault(n_atoms1, {})
            target_counts[n_atoms2] = target_counts.get(n_atoms2, 0) + 1
        # save for next time
        with open(dist_file_path, 'wb') as fp:
            pickle.dump(n_atoms_dist, fp)
    dist_add_nodes = DistributionAddNodes(n_atoms_dist)
    ### <-- add nodes
    
    # Limit the workload: a hyperparameter search only uses the first 20 molecules, while a final
    # run can be sharded over several jobs through --schedule_id (0..3). Any other schedule id
    # (including the default) keeps the complete list.
    if args.hp_search:
        input_smiles_list = input_smiles_list[:20]
    else:
        shard_bounds = {0: (None, 200), 1: (200, 400), 2: (400, 600), 3: (600, None)}
        if args.schedule_id in shard_bounds:
            shard_start, shard_stop = shard_bounds[args.schedule_id]
            input_smiles_list = input_smiles_list[shard_start:shard_stop]

    # Per-molecule optimization loop. For every input SMILES: encode it into a latent batch, run four
    # guided edit rounds (keeping the 10 best candidates after each round), then score all generated
    # SMILES against the input and write the result to '<folder_name>/<canonical SMILES>.json'.
    dataset_improvements = []
    dataset_diversities = []
    for input_smiles in input_smiles_list:
        # The canonical SMILES is the reference for all property/similarity comparisons and also the
        # name of the result file.
        # NOTE(review): SMILES may contain '/' or '\' (bond stereo), which would act as path
        # separators in the file name below - confirm the dataset is free of them.
        input_smiles_canonlized = (Chem.MolToSmiles(Chem.MolFromSmiles(input_smiles)))
        # Resume support: an existing result file marks the molecule as done.
        # NOTE(review): skipped molecules are not added to dataset_improvements / dataset_diversities,
        # so the averages computed after the loop only cover molecules processed in this run.
        if os.path.exists(f'{folder_name}/{input_smiles_canonlized}.json'):
            print(f'SMILES {input_smiles_canonlized} already processed. Skipping it.')
            continue
        print(f'Processing SMILES: {input_smiles_canonlized}')

        # read input smiles, compute synthetic coordinates and prepare for further computation
        input_mol = Molecule(device, dtype, smiles=input_smiles)
        input_graph = input_mol.get_graph()

        # construct a batch out of the molecule to sample different completions
        input_batch = MoleculeBatch(batch_size=args.batch_size, molecule=input_mol).graph

        ### --> add nodes
        # add_nodes holds, per batch entry, how many atoms to add on top of the input size.
        input_size = input_graph['num_atoms'].int().item()
        if input_size in n_atoms_dist:
            # Sample target sizes from the training-pair statistics and convert them to an offset.
            print('Sampling new n_atoms')
            add_nodes = dist_add_nodes.sample(n_initial_mol=input_size, n_samples=args.batch_size) - input_size
        else:
            # naive way to get n_atoms: cycle through 0..max_add additional atoms across the batch
            # NOTE(review): 40 is presumably the maximum supported molecule size - confirm. For
            # input_size > 40 the list is empty and the division below raises ZeroDivisionError.
            print('Just enumerating new n_atoms')
            max_add = 40 - input_size
            step_size = 1
            add_nodes = list(range(0, max_add+1, step_size))
            add_nodes = torch.Tensor(add_nodes).repeat(math.ceil(args.batch_size/len(add_nodes)))[:args.batch_size]
        assert len(add_nodes) == args.batch_size
        ### <-- add nodes

        x, h, node_mask, edge_mask, context = model.prepare_batch(input_batch, property_norms=None, conditioning=None)
        z_x, z_h = model.encode(x, h, node_mask, edge_mask, context)

        # if mol is not correctly reconstructed with our autoencoder, slightly optimize the representation
        z_x, z_h = model.optimize_z_xh(input_batch, z_x, z_h, h, node_mask, edge_mask, context)

        # pad zeros to make it compatible
        # The pad spec (0,0,0,n_pad) leaves the last dimension untouched and appends n_pad zero rows to
        # the second-to-last one, making room for the largest number of added nodes in the batch.
        # assumes z_x / z_h / node_mask are (batch, n_nodes, features) - TODO confirm
        n_nodes = z_x.size(1)
        n_pad = add_nodes.max().int().item()
        z_x = F.pad(z_x, (0,0,0,n_pad), "constant", 0)
        z_h = F.pad(z_h, (0,0,0,n_pad), "constant", 0)
        node_mask = F.pad(node_mask, (0,0,0,n_pad), "constant", 0)
        # The flat edge mask is viewed as (batch, n_nodes, n_nodes), padded on both node axes and
        # flattened back into a single column.
        edge_mask = F.pad(edge_mask.reshape(args.batch_size, n_nodes, n_nodes), (0,n_pad,0,n_pad), "constant", 0).reshape(-1,1)

        # n_nodes now includes the padding
        n_nodes = z_x.size(1)

        # z_x, z_h = model.edit(z_x, z_h, args.t_optimize, args.batch_size, n_nodes, node_mask, edge_mask, context=None)

        # config = f't_optimize_{args.t_optimize}_guidance_scale_{args.guidance_scale}'
        # smiles_dict = {config: model.get_smiles_from_x_h(z_x, z_h, node_mask, utilities_dict['dataset_info'])}

        # smiles_dict maps the round index to the SMILES generated in that round.
        smiles_dict = {}
        # Initial "top 10": the first 10 entries of the (padded) input latents, passed through the
        # decoder. They are appended to the candidates of the first round below.
        # NOTE(review): z_h_top10 is used as a dict by concatenate() later, so decode_from_z0
        # presumably returns a dict for it - confirm. Slicing [:10] assumes batch_size >= 10.
        z_x_top10 = z_x[:10]
        z_h_top10 = z_h[:10]
        node_mask_top10 = node_mask[:10]
        edge_mask_top10 = edge_mask.reshape(args.batch_size,-1)[:10].reshape(-1,1)
        z_x_top10, z_h_top10 = model.decode_from_z0(torch.cat([z_x_top10, z_h_top10], dim=2), node_mask_top10, edge_mask_top10)

        try:
            for optimization_round in range(4):
                print(f'Optimization Round: {optimization_round}. Time_step: {schedule[0][optimization_round]}. Guidance scale: {schedule[1][optimization_round]}')
                ### Guidance fn
                # Returns the gradient of the regressor prediction for the selected property with
                # respect to the noisy latent zt, multiplied by the guidance scale of the current
                # round. With --guidance_linear_schedule the factor is scale * t + bias instead.
                # The closure reads optimization_round at call time; it is re-created every round.
                def guidance_fn(zt, t, node_mask, edge_mask, context):
                    zt.requires_grad = True
                    pred = regressor_prop._forward(t, zt, node_mask, edge_mask)[:, prop_idx]
                    d_zt = torch.autograd.grad(outputs=pred, inputs=zt, grad_outputs=torch.ones_like(pred))[0]
                    if args.guidance_linear_schedule:
                        assert t.min() == t.max(), "t contains different values ?!?"
                        d_zt = (schedule[1][optimization_round] * t.min() + args.guidance_schedule_bias) * d_zt
                    else:
                        d_zt = schedule[1][optimization_round] * d_zt
                    return d_zt.detach()
                model.guidance_fn = guidance_fn
                ### Guidance fn

                # only add nodes in the first round
                if optimization_round > 0 and add_nodes is not None:
                    add_nodes = None
                # recompute edge mask
                # Product of the node mask with itself over the two node axes, with the diagonal
                # (self edges) zeroed, flattened to one column.
                # assumes node_mask squeezes to (batch, n_nodes) - TODO confirm
                edge_mask = node_mask.squeeze().unsqueeze(1) * node_mask.squeeze().unsqueeze(2)
                diag_mask = ~torch.eye(edge_mask.size(1), dtype=torch.bool).unsqueeze(0).to(device)
                edge_mask *= diag_mask
                edge_mask = edge_mask.view(-1, 1).to(device, dtype)

                z_x, z_h, node_mask, edge_mask = model.edit(z_x, z_h, schedule[0][optimization_round], args.batch_size, n_nodes, node_mask, edge_mask, None, add_nodes=add_nodes)
                # Append the previous top 10 so they compete with the newly edited candidates; they
                # end up as the last 10 entries of the batch.
                z_x, z_h, node_mask = concatenate(z_x, z_h, node_mask, z_x_top10, z_h_top10, node_mask_top10)
                gen_smiles = model.get_smiles_from_x_h(z_x, z_h, node_mask, utilities_dict['dataset_info'])
                smiles_dict[optimization_round] = gen_smiles

                # TODO: think of how to only get largest connected components for next generation
                # Each entry is (batch index, property improvement over the input, similarity to the input).
                smiles_props = [(i, compute_prop_fn(s)-compute_prop_fn(input_smiles_canonlized), similarity(s, input_smiles_canonlized)) for i, s in enumerate(gen_smiles)]
                # Keep candidates above the similarity threshold and rank them by improvement.
                smiles_props_filtered = [prop for prop in smiles_props if prop[2] >= args.similarity_threshold]
                smiles_props_ranked = sorted(smiles_props_filtered, key=lambda x: x[1], reverse=True)
                top_10 = smiles_props_ranked[:10]
                if len(top_10) < 10:
                    # when the input molecule is not correctly reconstructed by our autoencoder
                    # we use it anyways for further optimization
                    # The filler is taken from the end of the batch, i.e. from the appended previous top 10.
                    print('found less than 10 top mols. Filling with the input molecule even though it might be not correctly reconstructed by the autoencoder')
                    n_missing_smiles = 10 - len(top_10)
                    top_10.extend(smiles_props[-n_missing_smiles:])
                print(f'Optimization Round: {optimization_round}. Best score: {top_10[0][1]}. Similarity: {top_10[0][2]}')
                indices_top10 = [prop[0] for prop in top_10]
                z_x_top10, z_h_top10, node_mask_top10 = select_indices(z_x, z_h, node_mask, indices_top10)
                # The next round starts from batch_size//10 copies of the selected top 10.
                # NOTE(review): if batch_size is not a multiple of 10 the tiled batch is smaller than
                # the args.batch_size passed to model.edit - confirm this is never the case.
                z_x, z_h, node_mask = repeat(z_x_top10, z_h_top10, node_mask_top10, args.batch_size//10)
                # z_h is a dict at this point; the next round uses its 'z_h' entry
                z_h = z_h['z_h']
        except RuntimeError as e:
            # Best effort: a CUDA out-of-memory error only skips the current molecule; every other
            # RuntimeError is re-raised.
            if 'CUDA out of memory' in str(e):
                print('Found an Out of Memory error. Skipping molecule.')
                torch.cuda.empty_cache()
                continue
            else:
                raise e

        # compute improvement value / success flag for the current molecule
        # get largest fragments from all rounds
        smiles_list = [get_largest_connected_component(smiles) for state_smiles in smiles_dict.values() for smiles in state_smiles]
        # get valid smiles
        smiles_list = [s for s in smiles_list if is_valid(s)]
        # canonicalize
        smiles_list = [canon_smiles(smiles) for smiles in smiles_list]
        # remove duplicates
        smiles_list = list(set(smiles_list))
        # filter molecules that satisfy the similarity threshold
        smiles_list = [smiles for smiles in smiles_list \
            if similarity(smiles, input_smiles_canonlized) >= args.similarity_threshold]
        # compute max improvement
        #TODO: for drd2 and qed compute success rate instead of improvement
        #TODO: compute diversity
        # Defaults when nothing better than the input is found: improvement 0 and the input itself.
        # NOTE(review): 'tpsa' has no branch below, so for that property max_improvement stays 0 and
        # successful_smiles stays empty.
        successful_smiles = []
        max_improvement = 0.
        max_smiles = input_smiles_canonlized
        if args.optimization_prop == 'penalized_logP':
            # Improvement is the property difference to the input; any gain above 1e-6 counts as a success.
            for smiles in smiles_list:
                current_improvement = compute_prop_fn(smiles) - compute_prop_fn(input_smiles_canonlized)
                if current_improvement > max_improvement:
                    max_improvement = current_improvement
                    max_smiles = smiles
                if current_improvement > 1e-6:
                    successful_smiles.append(smiles)
        elif args.optimization_prop == 'qed':
            # Success flag: 1.0 as soon as one candidate reaches qed >= 0.9 (the last such candidate is kept).
            # NOTE(review): the assertion checks qed <= 1.0 although its message mentions ">= 1.0".
            for smiles in smiles_list:
                assert compute_prop_fn(smiles) <= 1.0, "found qed value >= 1.0. Not expected"
                if compute_prop_fn(smiles) >= 0.9:
                    max_improvement = 1.
                    max_smiles = smiles
                    successful_smiles.append(smiles)
        elif args.optimization_prop == 'drd2':
            # Success flag: 1.0 as soon as one candidate reaches drd2 >= 0.5 (the last such candidate is kept).
            for smiles in smiles_list:
                if compute_prop_fn(smiles) >= 0.5:
                    max_improvement = 1.
                    max_smiles = smiles
                    successful_smiles.append(smiles)

        # compute diversity
        diversity = compute_diversity(successful_smiles)
        # save current results
        dataset_improvements.append(max_improvement)
        dataset_diversities.append(diversity)
        smiles_dict['best_generated_smiles'] = max_smiles
        smiles_dict['max_improvement'] = max_improvement
        smiles_dict['diversity'] = diversity

        # json.dump writes the integer round keys of smiles_dict as strings.
        with open(f'{folder_name}/{input_smiles_canonlized}.json', 'w') as fp:
            json.dump(smiles_dict, fp)

    # Aggregate the per-molecule results of this run. Molecules skipped because their result file
    # already existed are not part of these lists.
    if dataset_improvements:
        average_improvement = np.mean(dataset_improvements)
        average_diversity = np.mean(dataset_diversities)
        with open(f'{folder_name}/average_improvement_{args.schedule_id}.txt', 'w') as f:
            f.write(f'average_improvement = {average_improvement} \n average_diversity = {average_diversity} \n')
    else:
        # np.mean of an empty list is nan (with a RuntimeWarning); writing it would overwrite the
        # averages of an earlier, completed run when every molecule was skipped on resume.
        average_improvement = float('nan')
        average_diversity = float('nan')
        print('No molecule was processed in this run. Leaving the averages file untouched.')


    ##########
    # END HP optimization
    ##########

    # The three branches below are disabled with a literal `False` and never run. They are kept as a
    # record of earlier experiments.
    # Branch 1 (disabled): iterative optimization that alternates structure guidance (Morgan
    # fingerprint regressor) and property guidance.
    # NOTE(review): `regressor_structure` and `intermediate_states` are only assigned in commented-out
    # lines of this file, so enabling this branch as-is would raise NameError.
    if False and args.guidance_exp_name is not None:

        # optimization
        # create folder
        timestr = time.strftime("%Y%m%d-%H%M%S")
        folder_name = f'optimization/optimization_guidance_t_optimize_{args.t_optimize}_{timestr}'
        os.makedirs(folder_name, exist_ok=True)
        with open('data/logp04/valid.txt', 'r') as smiles_file:
            input_smiles_list = [line.strip() for line in smiles_file.readlines()]

        for input_smiles in input_smiles_list:
            input_smiles_canonlized = (Chem.MolToSmiles(Chem.MolFromSmiles(input_smiles)))

            ### Guidance fn
            # Target fingerprint of the input molecule; the structure guidance pulls samples towards it.
            fp = get_morgan_fingerprint(input_smiles_canonlized, n_bits=1024)
            fp = torch.Tensor(fp).to(device, dtype)
            # Per-sample binary cross entropy over all fingerprint bits, positives weighted by 25.
            def binary_cross_entropy_multihead(pred, y):
                pos_weight = torch.Tensor([25.]).to(pred.device)
                loss_f = torch.nn.BCEWithLogitsLoss(reduction='none', pos_weight=pos_weight)
                loss = loss_f(pred, y).mean(dim=1)
                return loss
            # Gradient of the negative fingerprint loss with respect to zt, scaled by 40.
            def guidance_fn_structure(zt, t, node_mask, edge_mask, context):
                zt.requires_grad = True
                pred = regressor_structure._forward(t, zt, node_mask, edge_mask)
                energy = -1. * binary_cross_entropy_multihead(pred, fp.unsqueeze(0).repeat(pred.size(0), 1))
                d_zt = torch.autograd.grad(outputs=energy, inputs=zt, grad_outputs=torch.ones_like(energy))[0]
                d_zt = 40.0 * d_zt
                return d_zt.detach()
            # Gradient of the property regressor output with respect to zt. No property column is
            # selected here, unlike in the active optimization loop.
            def guidance_fn(zt, t, node_mask, edge_mask, context):
                zt.requires_grad = True
                pred = regressor_prop._forward(t, zt, node_mask, edge_mask)
                d_zt = torch.autograd.grad(outputs=pred, inputs=zt, grad_outputs=torch.ones_like(pred))[0]
                return d_zt.detach()
            # Structure guidance on every time step divisible by 5, property guidance otherwise.
            def guidance_fn_both(zt, t, node_mask, edge_mask, context):
                t_int = (t[0,0] * model.T).int().item()
                if t_int % 5 == 0:
                    return guidance_fn_structure(zt, t, node_mask, edge_mask, context)
                else:
                    return guidance_fn(zt, t, node_mask, edge_mask, context)

            model.guidance_fn = guidance_fn_both
            ### Guidance fn

            # intermediate_states_0 = torch.load(f'optimization/optimization_t_add_nodes_200_20231110-175756/{input_smiles_canonlized}.pt')
            # z_x, z_h, node_mask = intermediate_states_0
            # z_h = z_h['z_h']

            # intermediate_states = {}
            # intermediate_states[0] = intermediate_states_0

            # # iteratively optimize
            # n_samples, n_nodes = z_x.size(0), z_x.size(1)
            # # Compute edge_mask

            # edge_mask = node_mask.squeeze().unsqueeze(1) * node_mask.squeeze().unsqueeze(2)
            # diag_mask = ~torch.eye(edge_mask.size(1), dtype=torch.bool).unsqueeze(0).to(device)
            # edge_mask *= diag_mask
            # edge_mask = edge_mask.view(n_samples * n_nodes * n_nodes, 1).to(device, dtype)

            # context = None
            # n_steps = 2
            # for step in tqdm(range(n_steps)):
            #     z_x, z_h = model.edit(z_x, z_h, args.t_optimize, n_samples, n_nodes, node_mask, edge_mask, context)
            #     intermediate_states[step+1] = (copy.deepcopy(z_x), copy.deepcopy(z_h), copy.deepcopy(node_mask))
            #     z_h = z_h['z_h']

            # Decode every stored state into SMILES and write them to one JSON file per molecule.
            smiles_dict = {state: model.get_smiles_from_x_h(intermediate_states[state][0],
                                                        intermediate_states[state][1],
                                                        intermediate_states[state][2],
                                                        utilities_dict['dataset_info']) for state in intermediate_states}

            with open(f'{folder_name}/{input_smiles_canonlized}.json', 'w') as fp:
                json.dump(smiles_dict, fp)

    # Branch 2 (disabled): scaffold completion for every SMILES listed in '<smiles_folder>/vocab.txt'.
    elif False and args.scaffolding:
        # create folder
        timestr = time.strftime("%Y%m%d-%H%M%S")
        folder_name = f'optimization/scaffolding_{timestr}'
        os.makedirs(folder_name, exist_ok=True)

        with open(join(args.smiles_folder, 'vocab.txt'), 'r') as smiles_file:
            scaffold_smiles_list = [line.strip() for line in smiles_file.readlines()]

        for scaffold_smiles in scaffold_smiles_list:
            # read scaffold smiles, compute synthetic coordinates and prepare for further computation
            scaffold_mol = Molecule(device, dtype, smiles=scaffold_smiles)
            scaffold_graph = scaffold_mol.get_graph()

            # construct a batch out of the molecule to sample different completions
            scaffold_batch = MoleculeBatch(batch_size=args.batch_size, molecule=scaffold_mol).graph

            # Check if our decoder is able to reconstruct the scaffold. If not it's useless to try to complete it
            x, h, node_mask = model.reconstruct(scaffold_batch)
            reconstructed_scaffold_smiles = model.get_smiles_from_x_h(x, h, node_mask, utilities_dict['dataset_info'])[0]

            reconstructed_scaffold_smiles = Chem.MolToSmiles(Chem.MolFromSmiles(reconstructed_scaffold_smiles))
            scaffold_smiles_canonlized = (Chem.MolToSmiles(Chem.MolFromSmiles(scaffold_smiles)))
            if reconstructed_scaffold_smiles != scaffold_smiles_canonlized:
                # we are not able to reconstruct the canonicalized smiles
                print('Skipping scaffold because model cannot reconstruct it.')
                continue

            # Cycle through 1..max_add additional atoms across the batch (at least one atom is added).
            scaffold_size = scaffold_graph['num_atoms'].int().item()
            max_add = 40 - scaffold_size
            step_size = 1
            add_nodes = list(range(1, max_add+1, step_size))
            add_nodes = torch.Tensor(add_nodes).repeat(math.ceil(args.batch_size/len(add_nodes)))[:args.batch_size]
            assert len(add_nodes) == args.batch_size

            x_new, h_new, node_mask_new = model.complete_scaffold(scaffold_batch, add_nodes=add_nodes, use_jumps=True,
                                                                resampling_times=10, jump_len=40, jump_n_sample=10)

            # One text file per scaffold, one generated SMILES per line.
            new_smiles = model.get_smiles_from_x_h(x_new, h_new, node_mask_new, utilities_dict['dataset_info'])
            with open(f'{folder_name}/{scaffold_smiles_canonlized}.txt', 'w') as f:
                f.writelines('\n'.join(new_smiles))

    # Branch 3 (disabled): optimization through model.optimize with property guidance only, on the
    # first 50 molecules of 'data/logp04/valid.txt'.
    elif False:
        # optimization
        # create folder
        timestr = time.strftime("%Y%m%d-%H%M%S")
        folder_name = f'optimization/optimization_with_guidance_only_property_t_add_nodes_{args.t_add_nodes}_{timestr}'
        os.makedirs(folder_name, exist_ok=True)

        # compute conditional distribution of number of nodes
        print('Computing distribution of number of nodes')
        with open('data/logp04/train_pairs.txt', 'r') as f:
            training_pairs = [line.strip() for line in f.readlines()]
        # Atom count after removing stereochemistry, re-parsing the SMILES and adding explicit hydrogens only.
        def get_n_atoms_from_smiles(smiles):
            mol = Chem.MolFromSmiles(smiles)
            Chem.RemoveStereochemistry(mol)
            mol = Chem.MolFromSmiles(Chem.MolToSmiles(mol))
            mol = Chem.AddHs(mol, explicitOnly=True)
            return mol.GetNumAtoms()

        # n_atoms_dist[n_source][n_target] counts training pairs; unlike the active code path, pairs
        # whose target is smaller than the source are dropped here.
        n_atoms_dist = {}
        for n in range(len(training_pairs)):
            s1, s2 = training_pairs[n].split(' ')

            n_atoms1, n_atoms2 = get_n_atoms_from_smiles(s1), get_n_atoms_from_smiles(s2)
            if n_atoms2 < n_atoms1:
                continue

            if n_atoms1 not in n_atoms_dist:
                n_atoms_dist[n_atoms1] = {}
            if n_atoms2 not in n_atoms_dist[n_atoms1]:
                n_atoms_dist[n_atoms1][n_atoms2] = 0

            n_atoms_dist[n_atoms1][n_atoms2] += 1
        dist_add_nodes = DistributionAddNodes(n_atoms_dist)

        with open('data/logp04/valid.txt', 'r') as smiles_file:
            input_smiles_list = [line.strip() for line in smiles_file.readlines()]

        input_smiles_list = input_smiles_list[:50]

        for input_smiles in input_smiles_list:
            input_smiles_canonlized = (Chem.MolToSmiles(Chem.MolFromSmiles(input_smiles)))
            ### Guidance fn
            # Same guidance helpers as in the first disabled branch; only guidance_fn (property
            # guidance) is installed on the model below.
            # NOTE(review): guidance_fn_structure references `regressor_structure`, which is only
            # assigned in a commented-out line of this file.
            fp = get_morgan_fingerprint(input_smiles_canonlized, n_bits=1024)
            fp = torch.Tensor(fp).to(device, dtype)
            def binary_cross_entropy_multihead(pred, y):
                pos_weight = torch.Tensor([25.]).to(pred.device)
                loss_f = torch.nn.BCEWithLogitsLoss(reduction='none', pos_weight=pos_weight)
                loss = loss_f(pred, y).mean(dim=1)
                return loss
            def guidance_fn_structure(zt, t, node_mask, edge_mask, context):
                zt.requires_grad = True
                pred = regressor_structure._forward(t, zt, node_mask, edge_mask)
                energy = -1. * binary_cross_entropy_multihead(pred, fp.unsqueeze(0).repeat(pred.size(0), 1))
                d_zt = torch.autograd.grad(outputs=energy, inputs=zt, grad_outputs=torch.ones_like(energy))[0]
                d_zt = 40.0 * d_zt
                return d_zt.detach()
            def guidance_fn(zt, t, node_mask, edge_mask, context):
                zt.requires_grad = True
                pred = regressor_prop._forward(t, zt, node_mask, edge_mask)
                d_zt = torch.autograd.grad(outputs=pred, inputs=zt, grad_outputs=torch.ones_like(pred))[0]
                return d_zt.detach()
            def guidance_fn_both(zt, t, node_mask, edge_mask, context):
                t_int = (t[0,0] * model.T).int().item()
                if t_int % 5 == 0:
                    return guidance_fn_structure(zt, t, node_mask, edge_mask, context)
                else:
                    return guidance_fn(zt, t, node_mask, edge_mask, context)

            model.guidance_fn = guidance_fn
            ### Guidance fn

            # read input smiles, compute synthetic coordinates and prepare for further computation
            input_mol = Molecule(device, dtype, smiles=input_smiles)
            input_graph = input_mol.get_graph()

            # construct a batch out of the molecule to sample different completions
            input_batch = MoleculeBatch(batch_size=args.batch_size, molecule=input_mol).graph

            # Check if our decoder is able to reconstruct the input molecule; a failure only produces a warning
            x, h, node_mask = model.reconstruct(input_batch, property_norms=utilities_dict['property_norms'], conditioning=utilities_dict['conditioning'])
            reconstructed_input_smiles = model.get_smiles_from_x_h(x, h, node_mask, utilities_dict['dataset_info'])[0]

            reconstructed_input_smiles = Chem.MolToSmiles(Chem.MolFromSmiles(reconstructed_input_smiles))
            input_smiles_canonlized = (Chem.MolToSmiles(Chem.MolFromSmiles(input_smiles)))
            if reconstructed_input_smiles != input_smiles_canonlized:
                # we are not able to reconstruct the canonicalized smiles
                print('Warning: current molecule cannot be reconstructed by the VAE model.')
                #continue

            # Number of atoms to add per batch entry: sampled from the training statistics when the
            # input size is known, otherwise cycling through 0..max_add.
            input_size = input_graph['num_atoms'].int().item()
            if input_size in n_atoms_dist:
                print('Sampling new n_atoms')
                add_nodes = dist_add_nodes.sample(n_initial_mol=input_size, n_samples=args.batch_size) - input_size
            else:
                # naive way to get n_atoms
                max_add = 40 - input_size
                step_size = 1
                add_nodes = list(range(0, max_add+1, step_size))
                add_nodes = torch.Tensor(add_nodes).repeat(math.ceil(args.batch_size/len(add_nodes)))[:args.batch_size]
            assert len(add_nodes) == args.batch_size

            intermediate_states = model.optimize(input_batch, t_add_nodes=args.t_add_nodes, t_optimize=args.t_optimize,
                    n_steps=20, step_size=args.step_size, add_nodes=add_nodes,
                    property_norms=utilities_dict['property_norms'], conditioning=utilities_dict['conditioning'],
                    use_jumps=True, resampling_times=10, jump_len=40, jump_n_sample=10)

            # With t_optimize == 0 the first intermediate state is additionally saved as a .pt file.
            if args.t_optimize == 0:
                torch.save(intermediate_states[0], f'{folder_name}/{input_smiles_canonlized}.pt')

            # Decode every intermediate state into SMILES and write them to one JSON file per molecule.
            smiles_dict = {state: model.get_smiles_from_x_h(intermediate_states[state][0],
                                                        intermediate_states[state][1],
                                                        intermediate_states[state][2],
                                                        utilities_dict['dataset_info']) for state in intermediate_states}

            with open(f'{folder_name}/{input_smiles_canonlized}.json', 'w') as fp:
                json.dump(smiles_dict, fp)
