import argparse
import os
from os.path import join
import logging
import random

from guacamol.assess_distribution_learning import assess_distribution_learning, _assess_distribution_learning
from guacamol.utils.helpers import setup_default_logger

from guacamol_evaluation.ldm_scaffolder import LDMSmilesScaffolder
from guacamol_evaluation.generator import EDMSmilesSampler
from synthetic_coordinates.rdkit_helpers import smiles_to_mol
from qm9.models import DistributionNodes

import numpy as np
import torch

# To run from a terminal, cd to the repository root and run (one command):
# PYTHONPATH="${PYTHONPATH}:." python guacamol_evaluation/<this_script>.py \
#   --exp_folder outputs/edm_qm9_sc_rdkit_no_charges_resume/ --scaffold_smiles CC --batch_size 100
# Results are written under guacamol_evaluation/results_scaffolding/ (there is no --output_dir flag).
if __name__ == '__main__':
    setup_default_logger()

    parser = argparse.ArgumentParser(description='Molecule distribution learning benchmark for EDM smiles sampler',
                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument("--exp_folder", default="outputs/edm_zinc250k_without_h")
    parser.add_argument('--dist_file', default='data/zinc250k/smiles/train.txt')
    parser.add_argument('--suite', default='v2')
    parser.add_argument("--batch_size", type=int, default=100)
    parser.add_argument("--number_samples", type=int, default=10000)
    parser.add_argument('--use_cached_3d_mols', action='store_true', default=False, help='whether to use pre-computed 3D mols')
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument('--disable_tf32', action='store_true', default=False, help='whether to disable TF32 for better precision for inference')
    parser.add_argument('--ckpt_prefix', type=str, default="", help=" '' or 'best_fcd_' or 'last_' ")
    parser.add_argument('--save_generated_3d_data', action='store_true', default=False, help='whether to cache the generated molecules in their 3D format.')

    # new args
    parser.add_argument('--scaffold_smiles', type=str, default="CC", help="SMILES string of the scaffold")
    parser.add_argument("--diffusion_steps", type=int, default=200)
    parser.add_argument("--jump_len", type=int, default=10)
    parser.add_argument("--jump_n_sample", type=int, default=5)

    args = parser.parse_args()

    if args.disable_tf32:
        torch.backends.cuda.matmul.allow_tf32 = False
        torch.backends.cudnn.allow_tf32 = False

    # seed everything
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

    # if args.output_dir is None:
    #     args.output_dir = os.path.dirname(os.path.realpath(__file__))
    
    # create results folder
    exp_name = args.exp_folder.split('/')[-1]
    exp_name += '_ckpt_prefix_' + args.ckpt_prefix
    if args.disable_tf32:
        exp_name += '_disable_tf32'

    os.makedirs('guacamol_evaluation/results_scaffolding', exist_ok=True)
    results_folder = join('guacamol_evaluation/results_scaffolding', args.scaffold_smiles+f'_T={args.diffusion_steps}_r={args.jump_n_sample}_j={args.jump_len}')
    os.makedirs(results_folder, exist_ok=True)
    json_file_path = join(results_folder, 'distribution_learning_results.json')

    with open(args.dist_file, 'r') as smiles_file:
        smiles_list = [line.strip() for line in smiles_file.readlines()]

    # extract subset of training data that contains the scaffold
    # compute distribution of number of nodes on this subset at the same time
    scaffold_mol = smiles_to_mol(args.scaffold_smiles, only_explicit_H=True)
    train_subset_with_scaffold = []
    n_nodes_histogram = {}
    for s in smiles_list:
        m = smiles_to_mol(s, only_explicit_H=True)
        if m.HasSubstructMatch(scaffold_mol):
            train_subset_with_scaffold.append(s)
            n_nodes = m.GetNumAtoms()
            if n_nodes in n_nodes_histogram:
                n_nodes_histogram[n_nodes] += 1
            else:
                n_nodes_histogram[n_nodes] = 1
    print(f'Found {len(train_subset_with_scaffold)} molecules in the training dataset that contain the scaffold {args.scaffold_smiles}')

    # save them on disk for guacamol evaluator to load them
    subset_dist_file = join(results_folder, 'subset_with_scaffold.txt')
    with open(subset_dist_file, 'w') as f:
        f.writelines('\n'.join(train_subset_with_scaffold))
    args.dist_file = subset_dist_file

    # compute new nodes_distribution that only considers subset with scaffold
    n_nodes_histogram = dict(sorted(n_nodes_histogram.items())) # sort it
    nodes_dist = DistributionNodes(n_nodes_histogram)

    generator = LDMSmilesScaffolder(exp_folder=args.exp_folder,
                                 batch_size=args.batch_size,
                                 updated_nodes_dist=nodes_dist, 
                                 scaffold_smiles=args.scaffold_smiles, 
                                 diffusion_steps=args.diffusion_steps, 
                                 jump_len=args.jump_len, 
                                 jump_n_sample=args.jump_n_sample,
                                 total_number_samples=args.number_samples,
                                 use_cached_3d_mols=args.use_cached_3d_mols,
                                 ckpt_prefix=args.ckpt_prefix)

    #generator.generated_smiles = [s.strip() for s in open('guacamol_evaluation/distribution_learning_results_eval_zinc_no_h_10k_save_smiles_smiles.txt', 'r').readlines()]

    # assess_distribution_learning(generator,
    #                              chembl_training_file=args.dist_file,
    #                              json_output_file=json_file_path,
    #                              benchmark_version=args.suite)

    _assess_distribution_learning(model=generator,
                                  chembl_training_file=args.dist_file,
                                  json_output_file=json_file_path,
                                  benchmark_version=args.suite,
                                  number_samples=args.number_samples)

    with open(json_file_path.replace('.json', '_smiles.txt'), 'w') as f:
        f.writelines('\n'.join(generator.generated_smiles))
