import argparse
import os
from os.path import join
import logging

from guacamol.assess_distribution_learning import assess_distribution_learning, _assess_distribution_learning
from guacamol.utils.helpers import setup_default_logger

from guacamol_evaluation.ldm_generator import LDMSmilesSampler
from guacamol_evaluation.generator import EDMSmilesSampler

import numpy as np

# Usage -- from the repository root run, e.g.:
#   PYTHONPATH="${PYTHONPATH}:." python guacamol_evaluation/distribution_learning.py \
#       --exp_folder outputs/edm_qm9_sc_rdkit_no_charges_resume/ --batch_size 100
#
# Outputs (GuacaMol JSON report, generated SMILES, generated 3D molecules) are
# written to <results_dir>/<last path component of exp_folder>/.


def experiment_name(exp_folder):
    """Return the final path component of *exp_folder*.

    Trailing separators are ignored, so ``'outputs/run/'`` and ``'outputs/run'``
    both give ``'run'``. A plain ``exp_folder.split('/')[-1]`` returns ``''``
    for the trailing-slash form, which would write the results directly into
    the shared results root and clobber other experiments.
    """
    return os.path.basename(os.path.normpath(exp_folder))


def smiles_output_path(json_file_path):
    """Return the path of the SMILES text file saved next to *json_file_path*.

    Only the final extension is swapped: ``'dir/results.json'`` becomes
    ``'dir/results_smiles.txt'``. Unlike ``str.replace('.json', ...)``, any
    ``'.json'`` appearing earlier in the path is left untouched.
    """
    root, _ext = os.path.splitext(json_file_path)
    return root + '_smiles.txt'


if __name__ == '__main__':
    setup_default_logger()

    parser = argparse.ArgumentParser(description='Molecule distribution learning benchmark for EDM smiles sampler',
                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument("--exp_folder", default="outputs/edm_zinc250k_without_h",
                        help='experiment folder handed to LDMSmilesSampler; its last path '
                             'component names the results subfolder')
    parser.add_argument('--dist_file', default='data/zinc250k/smiles/train.txt',
                        help='training-set SMILES file used by GuacaMol as the reference distribution')
    parser.add_argument('--suite', default='v2',
                        help='GuacaMol benchmark version')
    parser.add_argument("--batch_size", type=int, default=100,
                        help='sampling batch size passed to the sampler')
    parser.add_argument("--number_samples", type=int, default=10000,
                        help='number of molecules to generate and assess')
    parser.add_argument('--use_cached_3d_mols', action='store_true', default=False, help='whether to use pre-computed 3D mols')
    parser.add_argument('--results_dir', default='guacamol_evaluation/results',
                        help='root directory for outputs; a per-experiment subfolder is created inside it')

    args = parser.parse_args()

    # Fail fast on a bad reference file before the sampler is built; the file
    # itself is consumed by _assess_distribution_learning below.
    if not os.path.isfile(args.dist_file):
        parser.error('--dist_file not found: %s' % args.dist_file)

    generator = LDMSmilesSampler(exp_folder=args.exp_folder,
                                 batch_size=args.batch_size,
                                 total_number_samples=args.number_samples,
                                 use_cached_3d_mols=args.use_cached_3d_mols)

    # makedirs also creates any missing parents, including results_dir itself.
    results_folder = join(args.results_dir, experiment_name(args.exp_folder))
    os.makedirs(results_folder, exist_ok=True)

    json_file_path = join(results_folder, 'distribution_learning_results.json')

    _assess_distribution_learning(model=generator,
                                  chembl_training_file=args.dist_file,
                                  json_output_file=json_file_path,
                                  benchmark_version=args.suite,
                                  number_samples=args.number_samples)

    # Save the SMILES first: it is cheap, and they are not lost if saving the
    # 3D arrays below fails. One SMILES per line, no trailing newline.
    with open(smiles_output_path(json_file_path), 'w') as f:
        f.write('\n'.join(generator.generated_smiles))

    logging.info('Saving generated 3D data:')
    output_npz = join(results_folder, 'generated_3d_mols.npz')
    # generated_3d_mols is unpacked as keyword arguments, so each of its keys
    # becomes an array name inside the .npz archive.
    np.savez_compressed(output_npz, **generator.generated_3d_mols)
    logging.info('Processing/saving complete!')
