import argparse
import os
from os.path import join
import logging
import random

from guacamol.assess_distribution_learning import assess_distribution_learning, _assess_distribution_learning
from guacamol.utils.helpers import setup_default_logger

from guacamol_evaluation.ldm_generator import LDMSmilesSampler
from guacamol_evaluation.generator import EDMSmilesSampler

import numpy as np
import torch

# To run from the terminal, go to the main (repository root) directory and run:
# PYTHONPATH="${PYTHONPATH}:." python guacamol_evaluation/distribution_learning.py \
#     --exp_folder outputs/edm_qm9_sc_rdkit_no_charges_resume --batch_size 100
# Step count above which `--diffusion_steps` leaves the model untouched
# (it is also the flag's default).
DEFAULT_DIFFUSION_STEPS = 1000


class SizeSamplerExtrapolator:
    """Node-count sampler that always returns one fixed size.

    Used for size-extrapolation experiments, where it replaces the
    generator's ``nodes_dist`` so that every sampled molecule has ``n`` nodes.
    """

    def __init__(self, n) -> None:
        self.n = n

    def sample(self, n_samples=1):
        """Return an int tensor of shape ``(n_samples,)`` filled with ``self.n``."""
        return torch.ones((n_samples,)).int() * self.n


def parse_args(argv=None):
    """Parse the command-line options of the benchmark script.

    Args:
        argv: Optional list of argument strings; ``None`` makes argparse
            read ``sys.argv[1:]``.

    Returns:
        The populated ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser(description='Molecule distribution learning benchmark for EDM smiles sampler',
                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument("--exp_folder", default="outputs/edm_zinc250k_without_h")
    parser.add_argument('--dist_file', default='data/zinc250k/smiles/train.txt')
    parser.add_argument('--suite', default='v2')
    parser.add_argument("--batch_size", type=int, default=100)
    parser.add_argument("--number_samples", type=int, default=10000)
    parser.add_argument('--use_cached_3d_mols', action='store_true', default=False, help='whether to use pre-computed 3D mols')
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument('--disable_tf32', action='store_true', default=False, help='whether to disable TF32 for better precision for inference')
    parser.add_argument('--ckpt_prefix', type=str, default="", help=" '' or 'best_fcd_' or 'last_' ")
    parser.add_argument('--save_generated_3d_data', action='store_true', default=False, help='whether to cache the generated molecules in their 3D format.')
    parser.add_argument("--diffusion_steps", type=int, default=DEFAULT_DIFFUSION_STEPS,
                        help='number of diffusion steps; only values below 1000 override the model setting')
    parser.add_argument("--size_extrapolation", type=int, default=-1,
                        help="if not -1: grow max_n_nodes by this amount, sample only molecules "
                             "of that size and force suite 'v3'")
    return parser.parse_args(argv)


def seed_everything(seed):
    """Seed the Python, NumPy and PyTorch (CPU and CUDA) random generators."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)


def build_experiment_name(exp_folder, ckpt_prefix, disable_tf32, diffusion_steps, size_extrapolation):
    """Build the name of the results sub-directory for one evaluation setup.

    Args:
        exp_folder: Path of the experiment folder; only its last component is
            used. A trailing path separator is tolerated.
        ckpt_prefix: Checkpoint prefix, appended after ``_ckpt_prefix_``.
        disable_tf32: Whether TF32 was disabled for this run.
        diffusion_steps: Requested diffusion steps; only encoded in the name
            when below ``DEFAULT_DIFFUSION_STEPS``.
        size_extrapolation: Size-extrapolation offset; ``-1`` means disabled.

    Returns:
        The experiment name as a string.
    """
    # normpath drops a trailing separator, so 'outputs/run/' gives 'run'
    # rather than the empty string a plain split('/')[-1] would produce.
    exp_name = os.path.basename(os.path.normpath(exp_folder))
    exp_name += '_ckpt_prefix_' + ckpt_prefix
    if disable_tf32:
        exp_name += '_disable_tf32'
    # NOTE: the two suffixes below have no leading '_' separator. They are
    # kept as-is so that already existing results folders still resolve.
    if diffusion_steps < DEFAULT_DIFFUSION_STEPS:
        exp_name += f'diffusion_steps_{diffusion_steps}'
    if size_extrapolation != -1:
        exp_name += f'size_extrapolation_{size_extrapolation}'
    return exp_name


def main(argv=None):
    """Sample molecules with the LDM sampler and run the distribution-learning benchmark.

    Writes the benchmark JSON, the generated SMILES and (optionally) the
    generated 3D molecules to
    ``guacamol_evaluation/results/<exp_name>/seed_<seed>/``.

    Args:
        argv: Optional list of argument strings; ``None`` uses ``sys.argv[1:]``.

    Raises:
        FileNotFoundError: If ``--dist_file`` does not point to an existing file.
    """
    setup_default_logger()
    args = parse_args(argv)

    if args.disable_tf32:
        torch.backends.cuda.matmul.allow_tf32 = False
        torch.backends.cudnn.allow_tf32 = False

    seed_everything(args.seed)

    # Fail fast before the sampler is constructed. The benchmark call below
    # receives the file path itself, so its contents are not needed here.
    if not os.path.isfile(args.dist_file):
        raise FileNotFoundError(f'Training SMILES file not found: {args.dist_file}')

    generator = LDMSmilesSampler(exp_folder=args.exp_folder,
                                 batch_size=args.batch_size,
                                 total_number_samples=args.number_samples,
                                 use_cached_3d_mols=args.use_cached_3d_mols,
                                 ckpt_prefix=args.ckpt_prefix)

    # Size-extrapolation experiments: raise the maximum node count and sample
    # only molecules of exactly that size.
    if args.size_extrapolation != -1:
        generator.dataset_info['max_n_nodes'] += args.size_extrapolation
        generator.nodes_dist = SizeSamplerExtrapolator(generator.dataset_info['max_n_nodes'])
        # only compute validity
        args.suite = 'v3'

    if args.diffusion_steps < DEFAULT_DIFFUSION_STEPS:
        # update diffusion steps
        generator.model.T = args.diffusion_steps

    exp_name = build_experiment_name(exp_folder=args.exp_folder,
                                     ckpt_prefix=args.ckpt_prefix,
                                     disable_tf32=args.disable_tf32,
                                     diffusion_steps=args.diffusion_steps,
                                     size_extrapolation=args.size_extrapolation)

    results_folder = join(f'guacamol_evaluation/results/{exp_name}', f'seed_{args.seed}')
    # makedirs creates every missing parent directory in one call.
    os.makedirs(results_folder, exist_ok=True)

    json_file_path = join(results_folder, 'distribution_learning_results.json')
    smiles_file_path = join(results_folder, 'distribution_learning_results_smiles.txt')

    # The underscore-prefixed entry point is called because the requested
    # number of samples is passed to it explicitly.
    _assess_distribution_learning(model=generator,
                                  chembl_training_file=args.dist_file,
                                  json_output_file=json_file_path,
                                  benchmark_version=args.suite,
                                  number_samples=args.number_samples)

    if args.save_generated_3d_data:
        logging.info('Saving generated 3D data:')
        output_npz = join(results_folder, 'generated_3d_mols.npz')
        np.savez_compressed(output_npz, **generator.generated_3d_mols)
        logging.info('Processing/saving complete!')

    # One SMILES per line, no trailing newline.
    with open(smiles_file_path, 'w') as f:
        f.write('\n'.join(generator.generated_smiles))


if __name__ == '__main__':
    main()
