from omegaconf import OmegaConf
import torch
import pytorch_lightning as pl
import time
import os
import wandb
from itertools import compress

####################################################
# extra imports
from didigress.digress import DiGress
from torchmetrics import MeanAbsoluteError

from didigress import utils
from didigress.datasets import qm9_dataset

from rdkit.Chem.rdDistGeom import ETKDGv3, EmbedMolecule
from rdkit.Chem.rdForceFieldHelpers import MMFFOptimizeMolecule
from rdkit.Chem import Crippen, AllChem
from rdkit import Chem, DataStructs
from rdkit.Chem.Fingerprints import FingerprintMols

import math
try:
    import psi4
except ModuleNotFoundError:
    print("PSI4 not found")
from didigress.utils import TimeoutException, time_limit

from didigress.metrics.properties import mw, penalized_logp, qed, drd2, mw
from didigress.metrics.sascorer import calculateScore

from didigress.utils import clean_mol

from didigress.analysis.rdkit_functions import mol2smiles
from didigress.metrics.abstract_metrics import MeanConfidence

from didigress.metrics.properties import similarity

import pickle

class FreeGress(DiGress):
    model_dtype = torch.float32
    best_val_nll = 1e8
    val_counter = 0
    start_epoch_time = None
    train_iterations = None
    val_iterations = None

    def __init__(self, cfg, dataset_infos, losses, noise_model, visualizer=None):
        """Build the model and the bookkeeping state used by the test experiments.

        Args:
            cfg: OmegaConf configuration; the ``guidance.*`` entries read below
                select the experiment and its targets (missing keys yield
                ``None`` through ``OmegaConf.select``).
            dataset_infos, losses, noise_model, visualizer: forwarded unchanged
                to the parent constructor.
        """
        super().__init__(cfg=cfg, dataset_infos=dataset_infos, losses=losses,
                         noise_model=noise_model, visualizer=visualizer)

        # The guidance settings have to be read BEFORE they are branched on:
        # the metric set-up below depends on ``self.experiment_type``.
        self.experiment_type = OmegaConf.select(cfg, "guidance.experiment_type")
        self.improvement_type = OmegaConf.select(cfg, "guidance.improvement_type")
        self.improvement_limits = OmegaConf.select(cfg, "guidance.improvement_limits")
        self.improvement_target = OmegaConf.select(cfg, "guidance.improvement_target")
        self.fixed_sampling_target = OmegaConf.select(cfg, "guidance.fixed_sampling_target")

        # specific properties to generate molecules
        if self.experiment_type == "mae":
            self.cond_val = MeanAbsoluteError()
            self.num_valid_molecules = 0
        else:
            # improvement / success / diversity accumulators
            self.cond_val = MeanConfidence()
            self.succ_val = MeanConfidence()
            self.diff_val = MeanConfidence()
            self.total_similarity = 0
        self.num_total = 0

        # wish I had a more elegant way to do this
        self.node_model = None

        # stores the generated smiles on the test step
        self.generated_smiles = []

    #effectively returns p(G_0|G_t, y)
    def forward(self, z_t, extra_data, model_ref, train_step = False):
        """Run ``model_ref`` on the noisy graph ``z_t`` conditioned on its guidance.

        The guidance vector is appended to the node (``X``), edge (``E``)
        and/or global (``y``) features, depending on which of these letters
        appear in ``cfg.guidance.guidance_medium``. During training each
        sample's guidance is replaced by the null token with probability
        ``cfg.guidance.p_uncond``. Outside of training, when
        ``cfg.guidance.s > 0``, a second pass using the null token is run and
        the two predictions are combined as ``out_null + s * (out - out_null)``.

        Args:
            z_t: noisy graph placeholder; must carry a ``node_mask``. Its
                ``guidance`` may be ``None``, in which case the null token is used.
            extra_data: extra features concatenated to ``X``, ``E`` and ``y``.
            model_ref: callable network applied to the assembled placeholder.
            train_step: enables the random guidance dropout and disables the
                second (null-token) pass.

        Returns:
            A ``utils.PlaceHolder`` with the (possibly combined) predictions and
            the guidance that was actually used.
        """
        assert z_t.node_mask is not None
        model_input = z_t.copy()

        guidance = model_input.guidance
        node_mask = model_input.node_mask
        bs = extra_data.X.size(0)

        # replicates the null guidance token for the whole batch size
        cf_null_token = self.cf_null_token.repeat((bs, 1))

        # if we don't pass a guidance vector, we use the null guidance token
        if guidance is None:
            guidance = cf_null_token
        elif train_step and self.cfg.guidance.p_uncond > 0:
            # Randomly mask the guidance (only meaningful when the guidance is
            # not ENTIRELY equal to cf_null_token).
            # TODO: this likely will have to be fixed for text-based guidance
            # where the null token is implanted in the query itself (it's the
            # "pad" token) and hence all of this is already done implicitly.

            # True = substitute with null token. The mask is sampled directly
            # on the guidance's device so this also works without CUDA.
            guidance_mask = torch.rand((bs, 1), device=guidance.device) < self.cfg.guidance.p_uncond

            # now it has the same size as cf_null_token
            guidance_mask = guidance_mask.repeat((1, cf_null_token.size(-1)))

            guidance = torch.where(guidance_mask, cf_null_token, guidance)

        # This is used during test when we sample n_samples_per_test_molecule
        # at once using the same guidance (=> the guidance has shape
        # (1, guidance_size)) and it has to be repeated for every sample.
        # The "not train_step" is not necessary but better have it.
        if bs > guidance.size(0) and not train_step:
            guidance = guidance.repeat((bs, 1))

        # The production of X, E, y lives in a helper because the very same
        # code is needed again with the null token in place of the guidance.
        def produce_XEy(z_t, extra_data, guidance = None):
            X = z_t.X.clone().float()
            E = z_t.E.clone().float()
            y = z_t.y.clone().float()

            if guidance is not None:
                medium = self.cfg.guidance.guidance_medium

                if 'y' in medium:
                    # adds the time step and the guidance to the global features
                    y = torch.hstack((y, z_t.t.clone(), guidance)).float()

                n = extra_data.X.size(1)  # number of nodes

                # spreads the guidance on all nodes
                if 'X' in medium:
                    # g_X must have size (bs, n, features)
                    g_X = torch.reshape(guidance, shape = (bs, 1, -1)).repeat((1, n, 1))
                    X = torch.cat((X, g_X), dim=-1)

                # spreads the guidance on all edges
                if 'E' in medium:
                    # g_E must have size (bs, n, n, features)
                    g_E = torch.reshape(guidance, shape = (bs, 1, 1, -1)).repeat((1, n, n, 1))
                    E = torch.cat((E, g_E), dim=-1)

            # we finally add the extra data as requested
            X = torch.cat((X, extra_data.X), dim=2).float()
            E = torch.cat((E, extra_data.E), dim=3).float()
            y = torch.hstack((y, extra_data.y)).float()

            return utils.PlaceHolder(X=X, E=E, y=y, charges=z_t.charges, guidance=guidance,
                                     pos=z_t.pos, node_mask=z_t.node_mask)

        z_t_processed = produce_XEy(z_t, extra_data, guidance)

        # p(x_0|x_t, guidance)
        out = model_ref(z_t_processed)

        if not train_step and self.cfg.guidance.s > 0:
            s = self.cfg.guidance.s

            # first of all, we need to calculate p(x_0|x_t, None) as well
            z_null = produce_XEy(z_t, extra_data, cf_null_token)
            out_null = model_ref(z_null)

            if(model_ref.predict_delt == False):
                out.X = out_null.X + s*(out.X - out_null.X)
                out.E = out_null.E + s*(out.E - out_null.E)

                if self.use_charges:
                    out.charges = out_null.charges + s*(out.charges - out_null.charges)
                if self.use_3d:
                    out.pos = out_null.pos + s*(out.pos - out_null.pos)
                if self.use_ins_del:
                    out.insert_time = out_null.insert_time + \
                        s*(out.insert_time - out_null.insert_time)
            else:
                # only the global output is kept in this mode
                out.X       = None
                out.E       = None
                out.pos     = None
                out.charges = None

                out.y       = out_null.y + s*(out.y - out_null.y)

        # the raw outputs are the only thing we need
        return utils.PlaceHolder(X = out.X, E = out.E, y = out.y, guidance=guidance,
                                charges=out.charges, pos=out.pos, insert_time=out.insert_time,
                                node_mask=node_mask)
        #else:
        #    raise NotImplementedError("ERROR: unimplemented loss")

    def validation_step_preops(self, data, i):
        if i == 0:
            self.validation_guidance_vectors = data.guidance[0]

    def perform_validation_sampling(self):
        """Sample graphs with the stored validation guidance and score them.

        The requested number of samples is split across the configured GPUs;
        chains and samples are only saved on local rank 0. When the experiment
        type is ``"mae"`` the MAE against the validation guidance is logged too.
        """
        self.print("Sampling start")
        tic = time.time()
        general_cfg = self.cfg.general
        n_devices = max(len(general_cfg.gpus), 1)
        per_device = math.ceil(general_cfg.samples_to_generate / n_devices)
        samples = self.sample_n_graphs(
            samples_to_generate=per_device,
            chains_to_save=general_cfg.chains_to_save if self.local_rank == 0 else 0,
            samples_to_save=general_cfg.samples_to_save if self.local_rank == 0 else 0,
            test=False,
            guidance=self.validation_guidance_vectors,
        )
        print(f'Done on {self.local_rank}. Sampling took {time.time() - tic:.2f} seconds\n')
        print(f"Computing sampling metrics on {self.local_rank}...")
        self.val_sampling_metrics(samples, self.name, self.current_epoch, self.local_rank)

        if self.experiment_type != "mae":
            return

        mae = self.mae_test(samples, self.validation_guidance_vectors)
        self.log_dict({'validation_mae': mae}, on_epoch=True, on_step=False, sync_dist=True)

    def test_step(self, data, i):
        if self.experiment_type == "mae":
            return self.test_step_mae(data, i)      
        elif self.experiment_type == "optimization":
            return self.test_step_optimization(data, i)
        elif self.experiment_type == "fixed_sampling":
            return None
        
    def test_step_mae(self, data, i):
        if(self.cfg.train.batch_size > 1):
            print("WARNING: batch size > 1. You may not have enough batches to run this test.",
                  "Try relaunching this experiment with train.batch_size=1")
            print("batch size is:", self.cfg.train.batch_size)
            data.guidance = data.guidance[0,:]
        print(f'Select No.{i+1} test molecule')
        # Extract properties
        target_properties = data.guidance.clone()
        print("TARGET PROPERTIES", target_properties)

        start = time.time()
        samples = self.sample_n_graphs(samples_to_generate=self.cfg.guidance.n_samples_per_test_molecule,
                                       chains_to_save=1,
                                       samples_to_save=10,
                                       test=True,
                                       guidance=target_properties)
        print(f'Sampling took {time.time() - start:.2f} seconds\n')
        _ = self.save_cond_samples(samples, target_properties, file_path=os.path.join(os.getcwd(), f'cond_smiles{i}.pkl'))

        # save conditional generated samples
        mae = self.mae_test(samples, target_properties)
        print("==============================================================")
        
        return {'mae': mae}

    def on_test_epoch_end(self) -> None:
        if self.experiment_type == "mae":
            self.on_test_epoch_end_mae()
        elif self.experiment_type == "optimization":
            self.on_test_epoch_end_optimization()
        elif self.experiment_type == "fixed_sampling":
            self.on_test_epoch_end_fixed_sampling()

    def on_test_epoch_end_mae(self):
        """Print and log the final MAE, validity, uniqueness and novelty.

        Validity is ``num_valid_molecules / num_total``; uniqueness is the
        fraction of distinct SMILES among ``self.generated_smiles``; novelty is
        the fraction of those distinct SMILES absent from the training set.
        All ratios fall back to ``0`` when their denominator is zero. Results
        are sent to wandb only when a run is active.
        """
        final_mae = self.cond_val.compute()
        # guard against a test run that produced no samples at all
        final_validity = self.num_valid_molecules / self.num_total if self.num_total else 0
        print("Final MAE", final_mae)
        print("Final validity", final_validity * 100)

        #######################################################################
        unique_generated_smiles = set(self.generated_smiles)
        if self.generated_smiles:
            final_uniqueness = len(unique_generated_smiles)/len(self.generated_smiles)
            print("final_uniqueness = ", final_uniqueness)
        else:
            final_uniqueness = 0
            print("final_uniqueness = 0 due to no uniques")

        if self.cfg.dataset.name in ['qm9', 'zinc250k']:
            train_dataset_smiles = self.val_sampling_metrics.train_smiles
        else:
            print("TODO: implement get_train_smiles for other datasets")
            # No reference set is available: fall back to an empty one instead
            # of failing on an unbound name. Novelty is then trivially 1 for
            # every generated molecule and should not be trusted.
            train_dataset_smiles = []

        train_dataset_smiles_set = set(train_dataset_smiles)
        print("There are ", len(train_dataset_smiles_set), " smiles in the training set")

        final_novelty_smiles = unique_generated_smiles.difference(train_dataset_smiles_set)

        if unique_generated_smiles:
            final_novelty = len(final_novelty_smiles)/len(unique_generated_smiles)
            print("final_novelty", final_novelty)
        else:
            final_novelty = 0
            print("Final Novelty = 0 due to no smiles generated")
        #######################################################################

        # same guard as the per-step logging: wandb.run is None without a run
        if wandb.run:
            wandb.run.summary['final_MAE'] = final_mae
            wandb.run.summary['final_validity'] = final_validity
            wandb.run.summary['final_uniqueness'] = final_uniqueness
            wandb.run.summary['final_novelty'] = final_novelty
            wandb.log({'final mae': final_mae,
                       'final validity': final_validity,
                       'final uniqueness': final_uniqueness,
                       'final novelty': final_novelty})
            # NOTE(review): split_check is not set anywhere in this class;
            # presumably provided by the parent class - confirm.
            wandb.log({'final split_check': self.split_check})

    def test_step_optimization(self, data, i):
        """Try to improve the property of one test molecule and update the metrics.

        The test molecule is duplicated ``n_samples_per_test_molecule`` times,
        corrupted up to ``self.corruption_step`` and re-generated under a target
        guidance (original value + ``improvement_threshold`` for the ``'free'``
        mode, ``improvement_target`` for the ``'fixed'`` mode). A generated
        molecule is kept when it is valid, similar enough to the original
        (``similarity_threshold``) and, in ``'fixed'`` mode, inside
        ``self.improvement_limits`` for every guided property.

        Updates ``succ_val`` (1 if at least one molecule was kept, else 0),
        ``cond_val`` (best improvement) and ``diff_val`` (mean pairwise
        distance between the distinct kept SMILES).

        Returns:
            dict with the running ``'impr'``, ``'diff'`` and ``'succ'`` values
            (``'impr'``/``'diff'`` are ``None`` when nothing was kept).

        Raises:
            Exception: if ``cfg.guidance.improvement_type`` is not supported.
        """
        print("==============================================================")
        if self.cfg.train.batch_size > 1:
            print("WARNING: batch size > 1. You may not have enough batches to run this test.",
                  "Try relaunching this experiment with train.batch_size=1")
            print("batch size is:", self.cfg.train.batch_size)
            data.guidance = data.guidance[0,:]
        print(f'Select No.{i+1} test molecule')

        original_smile = data.smiles[0]
        dense_data = utils.to_dense(data, self.dataset_infos)

        device = dense_data.X.device

        # Creates n_samples_per_test_molecule copies
        n_per_test = self.cfg.guidance.n_samples_per_test_molecule
        dense_data = dense_data.duplicate(n_per_test)

        z_t, _, _, _ = self.corrupt_data(dense_data, train_step=False,
                                         corruption_step=self.corruption_step)

        improvement_type = self.cfg.guidance.improvement_type
        if improvement_type == 'free':
            target_properties = dense_data.guidance + self.cfg.guidance.improvement_threshold
        elif improvement_type == 'fixed':
            target_properties = torch.full_like(dense_data.guidance, self.cfg.guidance.improvement_target)
        else:
            raise Exception("Improvement type not supported")

        # NOTE(review): .item() only works with a single guided property.
        print(f"original property: ", dense_data.guidance[0,:].item())
        print(f"target_properties:\n{target_properties.reshape(-1)}")

        start = time.time()
        samples = self.sample_n_graphs(samples_to_generate=n_per_test,
                                       chains_to_save=self.cfg.general.final_model_chains_to_save,
                                       samples_to_save=self.cfg.general.final_model_samples_to_save,
                                       test=True,
                                       guidance=target_properties,
                                       z_t=z_t)
        print(f'Sampling took {time.time() - start:.2f} seconds\n')
        _ = self.save_cond_samples(samples, target_properties, file_path=os.path.join(os.getcwd(), f'cond_smiles{i}.pkl'))

        valid_properties = self.cfg.guidance.guidance_target

        if 'mu' in valid_properties or 'homo' in valid_properties:
            self.initialize_psi4()

        # results_dict only contains the results for the VALID molecules. Thus,
        # all its lists have one entry per valid molecule, with the exception of
        # valid_mask which has one entry per generated sample.
        results_dict            = self.compute_properties(samples, valid_properties)
        numerical_results_dict  = results_dict['numerical_results_dict']
        sample_smiles           = results_dict['sample_smiles']
        valid_mask              = results_dict['valid_mask']

        print("numerical_results_dict:\n", numerical_results_dict)
        print("sample_smiles:\n", sample_smiles)

        # drops the original properties that belong to invalid samples
        # (outputs and sample_smiles are already filtered)
        print("valid_mask:", valid_mask.int().reshape(-1))
        original_properties = dense_data.guidance[valid_mask].to(device)

        _, outputs, _ = self.make_log_dict(original_properties,
                                           numerical_results_dict, 0)

        print(f"outputs (size:{outputs.size()})\n", outputs.reshape(-1))

        # one row per valid molecule, one column per guided property
        keep_mask = torch.ones((len(sample_smiles), len(valid_properties)), dtype=torch.bool)
        # Next we mark the elements that should be removed if they are out of range
        if improvement_type == 'fixed':
            curr_idx = 0
            # same canonical order used by make_log_dict to build `outputs`
            for prop_name in ['mu', 'homo', 'penalizedlogp', 'qed', 'mw', 'sas', 'logp', 'drd2',
                              'bertz', 'TPSA', 'isomer', 'fprint']:
                if prop_name in valid_properties:
                    lower = self.improvement_limits[prop_name][0]
                    upper = self.improvement_limits[prop_name][1]
                    keep_mask[:, curr_idx] = torch.logical_and(outputs[:, curr_idx] >= lower,
                                                               outputs[:, curr_idx] <= upper)
                    curr_idx += 1

        similarities = []
        for row, smile in enumerate(sample_smiles):
            sim = similarity(smile, original_smile)
            similarities.append(sim)
            keep_mask[row] &= (sim >= self.cfg.guidance.similarity_threshold)

        print("original smile:", original_smile)
        print("similarities against original mol:", [ round(elem, 2) for elem in similarities])

        print("FINAL keep_mask:", keep_mask.reshape(-1).int())
        # A molecule is kept only when ALL of its properties pass. Reducing to
        # one flag per molecule keeps the mask aligned with the per-molecule
        # lists below even when several properties are guided at once.
        row_mask        = keep_mask.all(dim=1)
        keep_mask_list  = row_mask.tolist()
        # Removes from sample_smiles, similarities, outputs and
        # original_properties the molecules that were not kept
        print("sample_smiles BEFORE:", sample_smiles)
        print("keep_mask_list:", keep_mask_list)
        sample_smiles       = list(set(compress(sample_smiles, keep_mask_list)))
        print("sample_smiles AFTER:", sample_smiles)
        similarities        = list(compress(similarities, keep_mask_list))
        outputs             = outputs[row_mask].to(device)
        original_properties = original_properties[row_mask]

        success = bool(row_mask.any())
        if not success:
            print("no molecule satisfies the constraints")
            self.succ_val.update(torch.tensor([0], device=device))

            tmp_succ, tmp_succ_conf = self.succ_val.compute()
            print(f"Temporary success: {tmp_succ} +- {tmp_succ_conf}")

            log_dict = {'impr': None,
                    'diff': None,
                    'succ': tmp_succ}

            if wandb.run:
                wandb.log(log_dict)

            return log_dict

        # If we cross the above check it means that we had a successful
        # test (at least 1 mol satisfying the constraint) and we can move on
        self.succ_val.update(torch.tensor([1], device=device))

        improvements        = outputs-original_properties
        best_improvement    = torch.max(improvements)

        print(f"improvements (max: {best_improvement.item()}):\n{improvements.reshape(-1)}")

        print("best_improvement: ", best_improvement)
        self.cond_val.update(best_improvement)

        if len(sample_smiles) >= 2:
            print("FINAL sample_smiles: ", sample_smiles)
            # distance = 1 - similarity over every unordered pair of kept SMILES
            dist = [1 - similarity(first, second)
                    for idx, first in enumerate(sample_smiles)
                    for second in sample_smiles[idx + 1:]]
            average_dist = sum(dist) / len(dist)
            print("average_dist: ", average_dist)
            self.diff_val.update(torch.tensor([average_dist], device=device))
        else:
            self.diff_val.update(torch.tensor([0], device=device))
            print("skipping pairwise dist computation since we had less than 2 valid mols")

        tmp_impr, tmp_impr_conf = self.cond_val.compute()
        tmp_diff, tmp_diff_conf = self.diff_val.compute()
        tmp_succ, tmp_succ_conf = self.succ_val.compute()
        print(f"Temporary improvement: {tmp_impr} +- {tmp_impr_conf}")
        print(f"Temporary difference: {tmp_diff} +- {tmp_diff_conf}")
        print(f"Temporary success: {tmp_succ} +- {tmp_succ_conf}")

        log_dict = {'impr': tmp_impr,
                'diff': tmp_diff,
                'succ': tmp_succ}

        if wandb.run:
            wandb.log(log_dict)

        return log_dict
    
    def on_test_epoch_end_optimization(self):
        """Print the final improvement, difference and success and log them.

        Each metric's ``compute()`` is unpacked into a value and its
        confidence. The values are sent to wandb only when a run is active,
        matching the guard used by the per-step logging.
        """
        final_impr, final_impr_conf = self.cond_val.compute()
        final_diff, final_diff_conf = self.diff_val.compute()
        final_succ, final_succ_conf = self.succ_val.compute()
        print(f"Final improvement: {final_impr} +- {final_impr_conf}")
        print(f"Final difference: {final_diff} +- {final_diff_conf}")
        print(f"Final success: {final_succ} +- {final_succ_conf}")

        # wandb.run is None when no run was initialised
        if not wandb.run:
            return

        wandb.run.summary['final_improvement'] = final_impr
        wandb.run.summary['final_difference'] = final_diff
        wandb.run.summary['final_success'] = final_succ
        wandb.log({'final_improvement': final_impr,
                   'final_difference': final_diff,
                   'final_success': final_succ})

    def on_test_epoch_end_fixed_sampling(self):
        """Sample molecules for one fixed target and print the best estimates.

        ``n_samples_per_test_molecule`` graphs are generated with every
        guidance entry set to ``cfg.guidance.fixed_sampling_target``; for each
        estimated property the (up to) three largest values are printed.
        """
        target              = self.cfg.guidance.fixed_sampling_target

        n_graphs_to_gen     = self.cfg.guidance.n_samples_per_test_molecule
        target_properties   = torch.full(size=(n_graphs_to_gen,1), fill_value=target, device=self.device)

        samples = self.sample_n_graphs(samples_to_generate=n_graphs_to_gen,
                                       chains_to_save=0,
                                       samples_to_save=10,
                                       test=True,
                                       guidance=target_properties)
        to_estimate  = ['penalizedlogp', 'qed', 'mw','logp']
        results_dict = self.compute_properties(samples, to_estimate).get('numerical_results_dict')

        print("results_dict:\n", results_dict)
        for p in to_estimate:
            curr_results = torch.tensor([results_dict.get(p)])
            # topk raises when k exceeds the number of available values, which
            # happens whenever fewer than three generated molecules are valid
            k = min(3, curr_results.size(-1))
            top_3 = torch.topk(curr_results, k=k)
            print(f"{p} top 3: ", top_3)


    # This just returns 
    # the log_dict (isolated MAEs), 
    # outputs (the estimations of the target values) 
    # and num_valid_molecules updated
    def make_log_dict(self, input_properties, numerical_results_dict, num_valid_molecules):
        log_dict = {}
        i = 0
        outputs = None
        valid_properties = self.cfg.guidance.guidance_target
        for tgt in ['mu', 'homo', 'penalizedlogp', 'qed', 'mw', 'sas', 'logp', 'drd2',
                    'bertz', 'TPSA', 'isomer', 'fprint']:
            if tgt in valid_properties:
                num_valid_molecules = max(num_valid_molecules, len(numerical_results_dict[tgt]))

                curr_numerical = torch.FloatTensor(numerical_results_dict[tgt]).unsqueeze(1)

                if(outputs == None):
                    outputs        = curr_numerical
                else:
                    print(tgt, " outputs", outputs)
                    #print("binary_outputs", binary_outputs)
                    outputs        = torch.hstack((outputs, curr_numerical))

                if(len(valid_properties) > 1):
                    curr_tgt_key = "val_epoch/" + tgt + "_mae"
                    curr_tgt_vec = input_properties[..., i].repeat(curr_numerical.size(0), 1).cpu()
                    print(tgt, " tgt vec: ", curr_tgt_vec)

                    curr_tgt_improvement = curr_numerical - curr_tgt_vec
                    log_dict.update({curr_tgt_key: curr_tgt_improvement})
                
                i = i + 1
        return log_dict, outputs, num_valid_molecules

    def mae_test(self, samples, input_properties):
        """Estimate the guided properties of ``samples`` and score them against the target.

        Updates ``self.num_valid_molecules`` and ``self.num_total``, feeds
        ``self.cond_val`` with the estimates and the repeated target, and logs
        the resulting dictionary to wandb when a run is active.

        Args:
            samples: generated molecules, as consumed by ``compute_properties``.
            input_properties: target property values for the whole set.

        Returns:
            The value returned by ``self.cond_val`` for this set of samples.
        """
        valid_properties = self.cfg.guidance.guidance_target

        needs_psi4 = 'mu' in valid_properties or 'homo' in valid_properties
        if needs_psi4:
            self.initialize_psi4()

        print("valid_properties", valid_properties)

        results = self.compute_properties(samples, valid_properties)
        estimates = results['numerical_results_dict']
        sample_smiles = results['sample_smiles']
        split_molecules = results['split_molecules']

        log_dict, outputs, num_valid_molecules = self.make_log_dict(input_properties, estimates, 0)

        print("Number of valid samples", num_valid_molecules)
        self.num_valid_molecules += num_valid_molecules
        self.num_total += len(samples)

        # one copy of the target per row of estimates
        target_tensor = input_properties.repeat(outputs.size(0), 1).cpu()

        print("outputs", outputs)

        mae = self.cond_val(outputs, target_tensor)

        n_unique_smiles = len(set(sample_smiles))
        if sample_smiles:
            n_unique_smiles_percentage = n_unique_smiles/len(sample_smiles)
        else:
            n_unique_smiles_percentage = 0
        print("percentage of unique_samples: ", n_unique_smiles_percentage)

        print("target_tensor  =", target_tensor)

        print('Conditional generation metric:')
        print(f'Epoch {self.current_epoch}: MAE: {mae}')

        log_dict["val_epoch/conditional generation mae"] = mae
        log_dict['Valid molecules'] = num_valid_molecules
        log_dict['Valid molecules splitted'] = split_molecules
        log_dict["val_epoch/n_unique_smiles"] = n_unique_smiles
        log_dict["val_epoch/n_unique_smiles_percentage"] = n_unique_smiles_percentage

        if wandb.run:
            wandb.log(log_dict)

        return mae

    def initialize_psi4(self):
        try:
            import psi4
            # Hardware side settings (CPU thread number and memory settings used for calculation)
            psi4.set_num_threads(nthread=4)
            psi4.set_memory("5GB")
            psi4.core.set_output_file('psi4_output.dat', False)
        except ModuleNotFoundError:
            print("PSI4 not found")

    def compute_properties(self, samples, valid_properties):
        """Estimate the requested properties for every chemically valid sample.

        A sample is discarded (its ``valid_mask`` entry set to ``False``) when
        sanitization, 3D embedding, MMFF optimization, the Psi4 calculation or
        the final clean-up fails, or when it is disconnected and
        ``cfg.guidance.include_split`` is ``False``. The SMILES of every kept
        sample is also appended to ``self.generated_smiles``.

        Args:
            samples: iterable of objects exposing an ``rdkit_mol`` attribute.
            valid_properties: names of the properties to estimate. Values are
                computed for ``mu``, ``homo``, ``penalizedlogp``, ``qed``,
                ``mw``, ``sas``, ``logp`` and ``drd2``; any other name gets an
                empty list.

        Returns:
            dict with ``numerical_results_dict`` (property name -> list with one
            value per kept sample), ``sample_smiles`` (SMILES of the kept
            samples), ``split_molecules`` (number of disconnected samples seen)
            and ``valid_mask`` (bool tensor, one entry per input sample).
        """
        sample_smiles = []
        split_molecules = 0

        valid_mask = torch.ones((len(samples),), dtype=torch.bool)

        numerical_results_dict = {tgt: [] for tgt in valid_properties}

        needs_psi4 = 'mu' in valid_properties or 'homo' in valid_properties

        for i, sample in enumerate(samples):
            raw_mol = sample.rdkit_mol

            mol = Chem.rdchem.RWMol(raw_mol)

            # `except Exception` (instead of a bare except) keeps the
            # best-effort behaviour but lets KeyboardInterrupt/SystemExit through.
            try:
                Chem.SanitizeMol(mol)
            except Exception:
                print('invalid chemistry')
                valid_mask[i] = False
                continue

            # Coarse 3D structure optimization by generating 3D structure from SMILES
            mol = Chem.AddHs(mol)
            params = ETKDGv3()
            params.randomSeed = 1
            try:
                EmbedMolecule(mol, params)
            except Chem.rdchem.AtomValenceException:
                print('invalid chemistry')
                valid_mask[i] = False
                continue

            # Structural optimization with MMFF (Merck Molecular Force Field)
            try:
                s = MMFFOptimizeMolecule(mol)
                print(s)
            except Exception:
                print('Bad conformer ID')
                valid_mask[i] = False
                continue

            try:
                conf = mol.GetConformer()
            except Exception:
                print("GetConformer failed")
                valid_mask[i] = False
                continue

            mol_frags = Chem.rdmolops.GetMolFrags(mol, asMols=True, sanitizeFrags=False)
            if len(mol_frags) > 1:
                split_molecules = split_molecules + 1
                if(self.cfg.guidance.include_split == False):
                    print("Ignoring a split molecule")
                    valid_mask[i] = False
                    continue

            ##########################################################
            ##########################################################

            if needs_psi4:
                # Convert to a format that can be input to Psi4:
                # first line is "<charge> <spin multiplicity>".

                # Get the formal charge
                fc = 'FormalCharge'
                mol_FormalCharge = int(mol.GetProp(fc)) if mol.HasProp(fc) else Chem.GetFormalCharge(mol)

                sm = 'SpinMultiplicity'
                if mol.HasProp(sm):
                    mol_spin_multiplicity = int(mol.GetProp(sm))
                else:
                    # Calculate spin multiplicity using Hund's rule of maximum multiplicity...
                    NumRadicalElectrons = 0
                    for Atom in mol.GetAtoms():
                        NumRadicalElectrons += Atom.GetNumRadicalElectrons()
                    TotalElectronicSpin = NumRadicalElectrons / 2
                    SpinMultiplicity = 2 * TotalElectronicSpin + 1
                    mol_spin_multiplicity = int(SpinMultiplicity)

                mol_input = "%s %s" % (mol_FormalCharge, mol_spin_multiplicity)
                print(mol_input)

                # Describe the coordinates of each atom in XYZ format
                for atom in mol.GetAtoms():
                    mol_input += "\n " + atom.GetSymbol() + " " + str(conf.GetAtomPosition(atom.GetIdx()).x) \
                                + " " + str(conf.GetAtomPosition(atom.GetIdx()).y) \
                                + " " + str(conf.GetAtomPosition(atom.GetIdx()).z)

                # This also covers the case where psi4 could not be imported
                # (the resulting NameError is an Exception).
                try:
                    molecule = psi4.geometry(mol_input)
                except Exception:
                    print('Can not calculate psi4 geometry')
                    valid_mask[i] = False
                    continue

                # Set calculation method (functional) and basis set
                level = "b3lyp/6-31G*"

                # Calculation method (functional), example of basis set
                # theory = ['hf', 'b3lyp']
                # basis_set = ['sto-3g', '3-21G', '6-31G(d)', '6-31+G(d,p)', '6-311++G(2d,p)']

                # Single-point energy calculation, limited to one hour
                print('Psi4 calculation starts!!!')
                try:
                    with time_limit(3600):
                        energy, wave_function = psi4.energy(level, molecule=molecule, return_wfn=True)
                except (Exception, TimeoutException):
                    print("Psi4 did not converge")
                    valid_mask[i] = False
                    continue

                print('Chemistry information check!!!')

            ##########################################################
            ##########################################################
            # the remaining properties are computed on the cleaned raw molecule
            try:
                mol = raw_mol
                mol = clean_mol(mol)

                smile = Chem.MolToSmiles(mol)
                print("Generated SMILES ", smile)
            except Exception:
                print("clean_mol failed")
                valid_mask[i] = False
                continue

            sample_smiles.append(smile)
            self.generated_smiles.append(smile)

            if 'mu' in valid_properties:
                dipole = wave_function.variable('SCF DIPOLE')
                dip_x, dip_y, dip_z = dipole[0], dipole[1], dipole[2]
                # norm of the dipole vector scaled by 2.5417464519
                # (presumably the a.u. -> Debye factor - confirm)
                dipole_moment = math.sqrt(dip_x**2 + dip_y**2 + dip_z**2) * 2.5417464519
                print("Dipole moment", dipole_moment)
                numerical_results_dict['mu'].append(dipole_moment)

            if 'homo' in valid_properties:
                # Compute HOMO (Unit: au = Hartree)
                LUMO_idx = wave_function.nalpha()
                HOMO_idx = LUMO_idx - 1
                homo = wave_function.epsilon_a_subset("AO", "ALL").np[HOMO_idx]

                # convert unit from a.u. to ev
                homo = homo * 27.211324570273
                numerical_results_dict['homo'].append(homo)

            if 'penalizedlogp' in valid_properties:
                plogp_estimate = penalized_logp(mol)
                numerical_results_dict['penalizedlogp'].append(plogp_estimate)

            if 'qed' in valid_properties:
                qed_estimate = qed(mol)
                numerical_results_dict['qed'].append(qed_estimate)

            if 'mw' in valid_properties:
                # molecular weight is stored divided by 100
                mw_estimate = mw(mol) / 100
                numerical_results_dict['mw'].append(mw_estimate)

            if 'sas' in valid_properties:
                sas_estimate = calculateScore(mol)
                numerical_results_dict['sas'].append(sas_estimate)

            if 'logp' in valid_properties:
                logp_estimate = Crippen.MolLogP(mol)
                numerical_results_dict['logp'].append(logp_estimate)

            if 'drd2' in valid_properties:
                # drd2 is computed from the SMILES string, not the molecule
                drd2_estimate = drd2(smile)
                numerical_results_dict['drd2'].append(drd2_estimate)

        return {"numerical_results_dict": numerical_results_dict,
                "sample_smiles": sample_smiles,
                "split_molecules": split_molecules,
                "valid_mask": valid_mask}

    

    def save_cond_samples(self, samples, target, file_path):
        cond_results = {'smiles': [], 'input_targets': target}
        invalid = 0
        disconnected = 0

        print("\tConverting conditionally generated molecules to SMILES ...")
        for sample in samples:
            mol = sample.rdkit_mol
            smile = mol2smiles(mol)
            if smile is not None:
                cond_results['smiles'].append(smile)
                mol_frags = Chem.rdmolops.GetMolFrags(mol, asMols=True, sanitizeFrags=False)
                if len(mol_frags) > 1:
                    print("Disconnected molecule", mol, mol_frags)
                    disconnected += 1
            else:
                print("Invalid molecule obtained.")
                invalid += 1

        print("Number of invalid molecules", invalid)
        print("Number of disconnected molecules", disconnected)

        # save samples
        with open(file_path, 'wb') as f:
            pickle.dump(cond_results, f)

        return cond_results