import argparse
from typing import List
import logging

import numpy as np

from guacamol.utils.helpers import setup_default_logger
from guacamol_evaluation.load_model import load_model
from guacamol_evaluation.sample_smiles import sample_3d_molecules


class EDM3DMolGenerator:
    """
    Generator that samples 3D molecules using an EDM model.

    The model and the utilities needed for sampling are loaded once when the
    generator is constructed and are reused by every call to `generate`.
    """

    def __init__(self, generator_model_path: str, batch_size: int) -> None:
        """
        Args:
            generator_model_path (str): path to folder containing generator model checkpoints. e.g. outputs/edm_zinc/
            batch_size (int): batch size used to sample from the model
        """
        self.model, utilities_dict = load_model(generator_model_path)
        # NOTE(review): this positional unpacking relies on the dict returned by load_model
        # holding exactly four entries in the order (nodes_dist, args, device, dataset_info).
        # Confirm against load_model before reordering or extending that dict.
        self.nodes_dist, self.args, self.device, self.dataset_info = utilities_dict.values()
        self.batch_size = batch_size

    def generate(self, number_samples: int) -> dict:
        """
        Sample molecules from the loaded model.

        Args:
            number_samples (int): number of molecules to sample

        Returns:
            dict: the generated features, exactly as returned by sample_3d_molecules
        """
        # Sampling is delegated entirely to sample_3d_molecules; its result is passed through untouched.
        return sample_3d_molecules(
            self.model,
            self.nodes_dist,
            self.args,
            self.device,
            self.dataset_info,
            n_samples=number_samples,
            batch_size=self.batch_size,
        )


# To run from the terminal, go to the main directory and run:
# PYTHONPATH="${PYTHONPATH}:." python guacamol_evaluation/generate_3d_molecules.py \
#   --output_file guacamol_evaluation/generated_3d_mols.npz \
#   --model_path outputs/edm_qm9_sc_rdkit_no_charges_resume/ --batch_size 100
if __name__ == '__main__':
    # Script entry point: sample molecules from a trained model and store them in a compressed .npz archive.
    from pathlib import Path

    setup_default_logger()

    # ArgumentDefaultsHelpFormatter only prints a default for arguments that have a help string,
    # so every argument is given one.
    parser = argparse.ArgumentParser(description='Generating and serializing 3D molecules for future processing',
                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--output_file', default='guacamol_evaluation/generated_3d_mols_edm_zinc250k_without_h.npz',
                        help='Path of the .npz file the generated molecules are written to')
    parser.add_argument("--model_path", default="outputs/edm_zinc250k_without_h",
                        help='Folder containing the generator model checkpoints')
    parser.add_argument("--batch_size", type=int, default=100,
                        help='Batch size used to sample from the model')
    parser.add_argument("--number_samples", type=int, default=20000,
                        help='Number of molecules to sample')

    args = parser.parse_args()

    # Create the destination folder before sampling, so an unusable output path is reported
    # up front rather than after all molecules have been generated.
    Path(args.output_file).parent.mkdir(parents=True, exist_ok=True)

    generator = EDM3DMolGenerator(generator_model_path=args.model_path, batch_size=args.batch_size)
    generated_molecules = generator.generate(args.number_samples)

    logging.info('Saving generated data: %s', args.output_file)
    # Each entry of the generated mapping is stored as a separate named array in the archive.
    np.savez_compressed(args.output_file, **generated_molecules)
    logging.info('Processing/saving complete!')
