import argparse
import os
from os.path import join
import logging
import random
import traceback
from rdkit import Chem

from guacamol.assess_distribution_learning import assess_distribution_learning, _assess_distribution_learning
from guacamol.utils.helpers import setup_default_logger

from guacamol_evaluation.ldm_generator import LDMSmilesSampler
from guacamol_evaluation.generator import EDMSmilesSampler
from guacamol_evaluation.ldm_sampling import sample_from_ldm
from qm9.analyze_joint_training import is_valid
from train_test import compute_prop_mae_on_generated_mols, compute_prop_values_on_generated_mols
from guacamol_evaluation.load_model import load_regressor

import numpy as np
import torch

# To run from a terminal, go to the repository root and run (adjust the script path to this file's location):
# PYTHONPATH="${PYTHONPATH}:." python <path/to/this_script>.py \
#     --exp_folder outputs/edm_zinc250k_without_h --regressor_guidance --conditional_prop penalized_logP --batch_size 100
logger = logging.getLogger(__name__)

# Properties supported for conditioning. The script assumes the regressor output
# column for a property is its index in this list (see ``prop_idx`` in ``main``).
ALL_PROPS = ['penalized_logP', 'qed', 'drd2', 'tpsa']

# Root folder under which the per-experiment result folders are created.
RESULTS_ROOT = 'guacamol_evaluation/results_conditional_generation'


def _parse_args():
    """Parse the command-line options of the conditional-generation script."""
    parser = argparse.ArgumentParser(description='Molecule distribution learning benchmark for EDM smiles sampler',
                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument("--exp_folder", default="outputs/edm_zinc250k_without_h")
    parser.add_argument('--dist_file', default='data/zinc250k/smiles/train.txt')
    parser.add_argument('--suite', default='v2')
    parser.add_argument("--batch_size", type=int, default=100)
    parser.add_argument("--number_samples", type=int, default=10000)
    parser.add_argument('--use_cached_3d_mols', action='store_true', default=False, help='whether to use pre-computed 3D mols')
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument('--disable_tf32', action='store_true', default=False, help='whether to disable TF32 for better precision for inference')
    parser.add_argument('--ckpt_prefix', type=str, default="", help=" '' or 'best_fcd_' or 'last_' ")
    parser.add_argument('--regressor_free_guidance', action='store_true', default=False, help='whether to enable regressor free guidance for conditional generation')

    parser.add_argument('--regressor_guidance', action='store_true', default=False, help='whether to enable regressor guidance for conditional generation')
    parser.add_argument("--regressor_exp_folder", default="outputs/regressor_model_zinc250k_logp")
    # `choices` turns a typo into a clear CLI error instead of a crash deep into the run.
    parser.add_argument('--conditional_prop', type=str, default="penalized_logP", choices=ALL_PROPS,
                        help="one of 'penalized_logP', 'qed', 'drd2', 'tpsa' ")
    parser.add_argument("--guidance_scale", type=float, default=1.0)
    # Any value other than these three would leave the guidance loss undefined.
    parser.add_argument('--guidance_loss', type=str, default="l1", choices=['l1', 'l2', 'max'],
                        help="one of 'l1', 'l2' (distance to the target property) or 'max' (maximize the property)")
    return parser.parse_args()


def _seed_everything(seed):
    """Seed the Python, NumPy and PyTorch (CPU and current CUDA device) RNGs."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)


def _make_guidance_fn(regressor, prop_idx, model, guidance_loss, guidance_scale):
    """Build the regressor-guidance callback to install as ``model.guidance_fn``.

    The callback has signature ``guidance_fn(zt, t, node_mask, edge_mask, context)``
    and returns ``-guidance_scale * t * d(loss)/d(zt)`` (detached), i.e. a step that
    decreases the guidance loss computed on column ``prop_idx`` of the regressor's
    prediction. For 'l1'/'l2', ``model.target_prop`` is read at call time.

    Args:
        regressor: object exposing ``_forward(t, zt, node_mask, edge_mask)``.
        prop_idx: column of the regressor output to guide on.
        model: the generative model; only its ``target_prop`` attribute is read.
        guidance_loss: 'l1' or 'l2' (distance to ``model.target_prop``) or 'max'
            (maximize the prediction).
        guidance_scale: multiplier applied to the gradient.

    Raises:
        ValueError: if ``guidance_loss`` is not one of the supported values.
    """
    if guidance_loss not in ('l1', 'l2', 'max'):
        raise ValueError(f"unknown guidance_loss {guidance_loss!r}; expected 'l1', 'l2' or 'max'")

    def guidance_fn(zt, t, node_mask, edge_mask, context):
        # Every entry of t must hold the same timestep, so t.min() can serve as
        # the scalar scale below.
        assert t.min() == t.max(), "t contains different values ?!?"
        # Differentiate w.r.t. a detached leaf copy. Flagging the sampler's own
        # tensor as requiring grad could make later sampler ops on it build and
        # retain an autograd graph. enable_grad() keeps this working even if the
        # caller samples under torch.no_grad().
        with torch.enable_grad():
            zt = zt.detach().requires_grad_(True)
            pred = regressor._forward(t, zt, node_mask, edge_mask)[:, prop_idx]
            if guidance_loss == 'l2':
                loss = (pred - model.target_prop) ** 2
            elif guidance_loss == 'l1':
                loss = torch.abs(pred - model.target_prop)
            else:  # 'max': maximize the predicted property
                loss = -1. * pred
            d_zt = torch.autograd.grad(outputs=loss, inputs=zt, grad_outputs=torch.ones_like(loss))[0]
        # The gradient is scaled by the shared timestep value as well as guidance_scale.
        d_zt = guidance_scale * t.min() * d_zt
        return -1. * d_zt.detach()  # -1 to minimize loss

    return guidance_fn


def _sample_conditional_smiles(generator, prop_dist, args):
    """Sample molecules until ``args.number_samples`` unique valid ones are found.

    Validity is decided by ``is_valid``; uniqueness is over RDKit-canonicalized SMILES.

    Returns:
        (smiles, conditional_values): every sampled SMILES, including invalid and
        duplicate ones, in sampling order; and the ``context_global`` values of
        every sampled batch, concatenated along dim 0.
    """
    smiles = []
    conditional_chunks = []
    unique_valid = set()
    # NOTE(review): this loops forever if the model never reaches
    # number_samples unique valid molecules.
    while len(unique_valid) < args.number_samples:
        n_samples = max(2, args.number_samples - len(unique_valid))
        # Avoid a remainder batch holding a single molecule (presumably
        # unsupported by the sampler -- TODO confirm).
        if n_samples % args.batch_size == 1:
            n_samples += 1

        sampled_molecules, sampled_smiles = sample_from_ldm(
            generator.model, generator.nodes_dist, generator.args, generator.device,
            generator.dataset_info, prop_dist=prop_dist[args.conditional_prop],
            n_samples=n_samples, batch_size=generator.batch_size, enforce_unconditional_generation=False,
            regressor_guidance=args.regressor_guidance)
        smiles.extend(sampled_smiles)
        conditional_chunks.append(sampled_molecules['context_global'])

        # Canonicalize only the new batch, so that the SMILES sampled in
        # earlier rounds are not processed again.
        unique_valid.update(Chem.MolToSmiles(Chem.MolFromSmiles(s)) for s in sampled_smiles if is_valid(s))
        logger.info('%d / %d unique valid molecules', len(unique_valid), args.number_samples)

    return smiles, torch.cat(conditional_chunks, dim=0)


def _results_folder(args):
    """Create (if needed) and return the folder where this run's results are written."""
    # Unlike split('/')[-1], basename(normpath(...)) still yields the folder
    # name when --exp_folder ends with a slash.
    exp_name = os.path.basename(os.path.normpath(args.exp_folder))
    if args.disable_tf32:
        exp_name += '_disable_tf32'
    exp_dir = f'{RESULTS_ROOT}/{exp_name}'
    if args.regressor_guidance:
        # A run with exactly 1000 samples is tagged as a debug run.
        flag = 'DEBUG_' if args.number_samples == 1000 else ''
        sub_dir = f'{flag}prop={args.conditional_prop}_s={args.guidance_scale}_loss={args.guidance_loss}'
    else:
        sub_dir = f'regressor_free_guidance_{args.regressor_free_guidance}'
    results_folder = join(exp_dir, sub_dir)
    # makedirs creates all missing parents, including RESULTS_ROOT.
    os.makedirs(results_folder, exist_ok=True)
    return results_folder


def main():
    """Run conditional generation, save its outputs and evaluate them with GuacaMol."""
    setup_default_logger()
    args = _parse_args()

    if args.disable_tf32:
        torch.backends.cuda.matmul.allow_tf32 = False
        torch.backends.cudnn.allow_tf32 = False

    _seed_everything(args.seed)

    # Fail fast on a bad --dist_file before the expensive sampling. The file is
    # consumed by the distribution-learning evaluation at the end.
    if not os.path.isfile(args.dist_file):
        raise FileNotFoundError(f'--dist_file not found: {args.dist_file}')

    generator = LDMSmilesSampler(exp_folder=args.exp_folder,
                                 batch_size=args.batch_size,
                                 total_number_samples=args.number_samples,
                                 use_cached_3d_mols=args.use_cached_3d_mols,
                                 ckpt_prefix=args.ckpt_prefix)
    if args.regressor_free_guidance:
        generator.model.regressor_free_guidance = True

    # Sampling needs prop_dist in every mode, regressor-free guidance included,
    # and load_regressor is its only source here, so it is loaded unconditionally.
    # NOTE(review): without --regressor_guidance the regressor itself is unused;
    # confirm the prop_dist from --regressor_exp_folder suits regressor-free runs.
    regressor, prop_dist = load_regressor(args.regressor_exp_folder, ALL_PROPS)
    if args.regressor_guidance:
        generator.model.guidance_fn = _make_guidance_fn(
            regressor, ALL_PROPS.index(args.conditional_prop), generator.model,
            args.guidance_loss, args.guidance_scale)

    smiles, conditional_values = _sample_conditional_smiles(generator, prop_dist, args)

    mae = compute_prop_mae_on_generated_mols(smiles, conditional_values.squeeze(), prop=args.conditional_prop)
    generated_scores_sorted = compute_prop_values_on_generated_mols(smiles, prop=args.conditional_prop)

    results_folder = _results_folder(args)
    # Output files share this prefix. Paths are built from it rather than with
    # str.replace on a full path, which would also rewrite matching folder names.
    prefix = join(results_folder, 'conditional_generation_results')

    with open(prefix + '_smiles.txt', 'w', encoding='utf-8') as f:
        f.write('\n'.join(smiles))
    # NOTE: despite the .npy suffix this file is written by torch.save; read it
    # back with torch.load, not np.load.
    torch.save(conditional_values, prefix + '_conditional_values.npy')
    with open(prefix + '_final_mae.txt', 'w', encoding='utf-8') as f:
        f.write(f'mae = {mae} \n')
    with open(prefix + '_generated_scores_sorted.txt', 'w', encoding='utf-8') as f:
        f.write('\n'.join(str(x) for x in generated_scores_sorted))

    # Also evaluate distribution-learning metrics on the conditionally generated SMILES.
    # Assumes LDMSmilesSampler serves `generated_smiles` instead of sampling anew -- TODO confirm.
    generator.generated_smiles = smiles
    _assess_distribution_learning(model=generator,
                                  chembl_training_file=args.dist_file,
                                  json_output_file=join(results_folder, 'distribution_learning_results.json'),
                                  benchmark_version=args.suite,
                                  number_samples=args.number_samples)


if __name__ == '__main__':
    main()
