"""
Benchmark of UnconditionalEvalPipeline / ConsistencyEvalPipeline for training set structures.
"""

import os
from typing import List, Optional
from pathlib import Path
from tqdm import tqdm
from copy import deepcopy
import pickle
import argparse

import numpy as np
import rdkit
from rdkit import Chem
from rdkit.Chem import AllChem
import rdkit.Chem.rdDetermineBonds

from shepherd_score.score.constants import ALPHA

from shepherd_score.conformer_generation import embed_conformer_from_smiles, single_point_xtb_from_xyz
from shepherd_score.container import Molecule

from evaluate import ConsistencyEvalPipeline, resample_surf_scores, get_mol_from_atom_pos


# Scratch directory passed to the xtb single-point calls (``temp_dir=TMPDIR``).
# Honours the $TMPDIR environment variable and otherwise falls back to the CWD.
TMPDIR = Path(os.environ.get('TMPDIR', './'))


def _index_set_files(dir_path: Path) -> dict:
    """Map the trailing integer of each ``*.pkl`` stem (text before the first '.') to its path."""
    index = {}
    for p in dir_path.glob('*.pkl'):
        stem = p.stem.split('.')[0]
        # Use the whole trailing run of digits so that e.g. set 10 does not collide with set 0.
        digits = stem[len(stem.rstrip('0123456789')):]
        if digits:
            index[int(digits)] = p
    return index


def _dump_pickle(obj, path: Path, description: str) -> None:
    """Pickle ``obj`` to ``path`` with the highest protocol and report where it went."""
    with open(path, 'wb') as f:
        pickle.dump(obj, f, protocol=pickle.HIGHEST_PROTOCOL)
    print(f'Saved {description} to {path}')


def generate_benchmark_set(training_set_dir: str,
                           relevant_sets: List[int],
                           save_dir: str,
                           data_name: str,
                           num_molecs: int = 1000,
                           num_surf_points: int = 75,
                           probe_radius: float = 0.6,
                           seed: Optional[int] = None):
    """
    Pretend that MMFF structures are "generated" and save them for ConsistencyEvalPipeline.

    Molblocks are loaded from the pickled training shards selected by ``relevant_sets``, and
    ``num_molecs`` of them are sampled at random. Each one is re-embedded from its SMILES with
    MMFF optimization. Its surface points, surface ESP and pharmacophore features are then
    computed via ``Molecule``. The results are pickled into ``<save_dir>/<data_name>/``.

    Arguments
    ---------
    training_set_dir : str
        Directory containing the pickled training shards (``*.pkl``). A shard is matched to a
        set index by the trailing digits of its file stem (e.g. ``train_12.pkl`` -> 12). Each
        item of a shard is indexed with ``[0]`` to get its molblock (presumably
        ``(molblock, charges)`` pairs -- confirm against the data producer).
    relevant_sets : List[int]
        Indices of the shards to sample from.
    save_dir : str
        Existing directory under which ``data_name/`` is created.
    data_name : str
        Name of the output sub-directory.
    num_molecs : int (default = 1000)
        Number of molblocks to sample. It is capped at the number available, and a notice is
        printed when capping happens.
    num_surf_points : int (default = 75)
        Number of surface points per molecule.
    probe_radius : float (default = 0.6)
        Probe radius used for the molecular surface.
    seed : Optional[int] (default = None)
        Seed for the sampling RNG. ``None`` draws fresh OS entropy (the previous behavior).

    Returns
    -------
    (ls_mmff_molblocks, ls_atoms_pos, ls_surf_points, ls_surf_esp, ls_pharm_feats)
        One entry per successfully processed molecule. ``ls_atoms_pos`` holds
        ``(atomic_numbers, positions)`` tuples. ``ls_pharm_feats`` holds
        ``(pharm_types, pharm_ancs, pharm_vecs)`` tuples.

    Raises
    ------
    ValueError
        If ``training_set_dir`` or ``save_dir`` is not a directory.
    KeyError
        If no shard file matches one of ``relevant_sets``.
    """
    dir_path = Path(training_set_dir)
    save_dir_path = Path(save_dir)
    if not dir_path.is_dir():
        raise ValueError(f'Provided path is not a directory: {dir_path}')
    if not save_dir_path.is_dir():
        raise ValueError(f'Provided path is not a directory: {save_dir_path}')

    # Resolve every requested shard up front so a bad index fails before any heavy work.
    file_paths = _index_set_files(dir_path)
    missing = [s for s in relevant_sets if int(s) not in file_paths]
    if missing:
        raise KeyError(f'No training-set file found in {dir_path} for set(s): {missing}')

    molblocks = []
    for set_ind in tqdm(relevant_sets, total=len(relevant_sets), desc='Loading data'):
        with open(file_paths[int(set_ind)], 'rb') as f:
            molblocks_and_charges_single = pickle.load(f)
        molblocks.extend(item[0] for item in molblocks_and_charges_single)
        del molblocks_and_charges_single  # shards can be large; drop before the next load

    # Choose random indices (sampling without replacement cannot exceed the population).
    num_total_molecs = len(molblocks)
    num_to_sample = min(num_molecs, num_total_molecs)
    if num_to_sample < num_molecs:
        print(f'Only {num_total_molecs} molecules available; '
              f'sampling all of them instead of {num_molecs}.')
    rng = np.random.default_rng(seed)
    rand_inds = rng.choice(a=num_total_molecs, size=num_to_sample, replace=False)
    molblocks = [molblocks[i] for i in rand_inds]

    ls_mmff_molblocks = []
    ls_atoms_pos = []
    ls_surf_points = []
    ls_surf_esp = []
    ls_pharm_feats = []
    num_failed = 0
    required_arrays = ('surf_pos', 'surf_esp', 'pharm_ancs', 'pharm_types', 'pharm_vecs')
    for molblock in tqdm(molblocks, total=len(molblocks), desc='Generating Molecule Objects',
                         miniters=50, maxinterval=10000):
        try:
            smiles = Chem.MolToSmiles(Chem.MolFromMolBlock(molblock, removeHs=False))
            mol = embed_conformer_from_smiles(smiles, MMFF_optimize=True)
            molec = Molecule(mol=mol,
                             num_surf_points=num_surf_points,
                             probe_radius=probe_radius,
                             pharm_multi_vector=False)
        except Exception:
            # Unparsable molblocks / failed embeddings are expected for some inputs: skip them.
            num_failed += 1
            continue

        # Only keep molecules for which every representation was actually produced.
        if molec.mol is None or not all(
                isinstance(getattr(molec, attr), np.ndarray) for attr in required_arrays):
            num_failed += 1
            continue

        ls_mmff_molblocks.append(Chem.MolToMolBlock(molec.mol))
        ls_atoms_pos.append((
            np.array([a.GetAtomicNum() for a in molec.mol.GetAtoms()]),
            np.array(molec.mol.GetConformer().GetPositions()),
        ))
        ls_surf_points.append(molec.surf_pos)
        ls_surf_esp.append(molec.surf_esp)
        ls_pharm_feats.append((molec.pharm_types, molec.pharm_ancs, molec.pharm_vecs))

    print(f'{num_failed} failed.')

    # Save representations
    out_dir = save_dir_path / data_name
    out_dir.mkdir(parents=True, exist_ok=True)
    _dump_pickle(ls_mmff_molblocks, out_dir / 'mmff_molblocks.pkl', 'MMFF molblocks')
    _dump_pickle(ls_atoms_pos, out_dir / 'atom_pos.pkl', 'MMFF atom types and positions')
    _dump_pickle(ls_surf_points, out_dir / 'surfpos.pkl', 'MMFF surface points')
    _dump_pickle(ls_surf_esp, out_dir / 'surfesp.pkl', 'MMFF surface esp')
    _dump_pickle(ls_pharm_feats, out_dir / 'pharmfeats.pkl', 'MMFF pharmacophore features')

    return ls_mmff_molblocks, ls_atoms_pos, ls_surf_points, ls_surf_esp, ls_pharm_feats


def run_consistency_benchmark(save_dir: str,
                              data_name: str,
                              solvent: Optional[str] = None,
                              num_processes: int = 1,
                              probe_radius: float = 0.6,
                              ) -> None:
    """
    Run ConsistencyEvalPipeline on a saved benchmark set and store the resulting metrics.

    Reads ``atom_pos.pkl``, ``surfpos.pkl``, ``surfesp.pkl`` and ``pharmfeats.pkl`` from
    ``<save_dir>/<data_name>/`` and feeds them to ``ConsistencyEvalPipeline``. After
    evaluation, the pipeline's metric attributes are written to
    ``<save_dir>/<data_name>/benchmark_output.npz``.

    Arguments
    ---------
    save_dir : str
        Existing directory that contains the ``data_name`` sub-directory.
    data_name : str
        Name of the benchmark set (sub-directory of ``save_dir``).
    solvent : Optional[str] (default = None)
        Solvent forwarded to ConsistencyEvalPipeline.
    num_processes : int (default = 1)
        Forwarded to ``ConsistencyEvalPipeline.evaluate``.
    probe_radius : float (default = 0.6)
        Forwarded to ConsistencyEvalPipeline.

    Raises
    ------
    ValueError
        If ``save_dir`` is not a directory.
    """
    root = Path(save_dir)
    if not root.is_dir():
        raise ValueError(f'Provided path is not a directory: {root}')
    data_dir = root / data_name

    # Load the pickled representations written by the set-generation step.
    loaded = {}
    for file_name in ('atom_pos.pkl', 'surfpos.pkl', 'surfesp.pkl', 'pharmfeats.pkl'):
        with open(str(data_dir / file_name), 'rb') as handle:
            loaded[file_name] = pickle.load(handle)

    pipeline = ConsistencyEvalPipeline(
        generated_mols=loaded['atom_pos.pkl'],
        generated_surf_points=loaded['surfpos.pkl'],
        generated_surf_esp=loaded['surfesp.pkl'],
        generated_pharm_feats=loaded['pharmfeats.pkl'],
        probe_radius=probe_radius,
        pharm_multi_vector=False,
        solvent=solvent,
    )
    pipeline.evaluate(num_processes=num_processes, verbose=True)

    # Every metric is stored in the .npz under the same name as the pipeline attribute.
    metric_names = (
        'num_valid', 'num_valid_post_opt', 'num_consistent_graph',
        'strain_energies', 'rmsds', 'SA_scores', 'logPs', 'QEDs', 'fsp3s',
        'frac_valid', 'frac_valid_post_opt', 'frac_consistent',
        'frac_unique', 'frac_unique_post_opt', 'avg_graph_diversity',
        'sims_surf_consistent', 'sims_esp_consistent', 'sims_pharm_consistent',
        'sims_surf_upper_bound_75', 'sims_esp_upper_bound_75',
        'sims_surf_upper_bound_400', 'sims_esp_upper_bound_400',
        'sims_surf_consistent_relax', 'sims_esp_consistent_relax',
        'sims_pharm_consistent_relax', 'graph_similarity_matrix',
    )
    output_path = data_dir / 'benchmark_output.npz'
    np.savez(str(output_path), **{name: getattr(pipeline, name) for name in metric_names})
    print(f'Finished {data_name} benchmark!\nSaved to {output_path}')


def run_resample(save_dir: str,
                 data_name: str,
                 solvent: Optional[str] = None,
                 num_processes: int = 1,
                 probe_radius: float = 0.6,
                 num_samples: int = 5,
                 lam: float = 0.3,
                 ) -> None:
    """
    Compute resampled surface / ESP similarity upper bounds for a saved benchmark set.

    For every molecule stored in ``<save_dir>/<data_name>/``:

    1. an RDKit mol is rebuilt from atom types and positions;
    2. xtb single-point partial charges are computed;
    3. a ``Molecule`` is built that reuses the stored surface, ESP and pharmacophores;
    4. ``resample_surf_scores`` is run on it.

    The maximum of the resampled scores is recorded. Molecules that cannot be rebuilt, or
    whose processing raises, are recorded as NaN. The results go to
    ``<save_dir>/<data_name>/upper_bounds.npz`` under ``sims_surf_upper_bound`` and
    ``sims_esp_upper_bound``.

    Arguments
    ---------
    save_dir : str
        Existing directory that contains the ``data_name`` sub-directory.
    data_name : str
        Name of the benchmark set (sub-directory of ``save_dir``).
    solvent : Optional[str] (default = None)
        Solvent for the xtb single-point calculation.
    num_processes : int (default = 1)
        Number of cores given to xtb.
    probe_radius : float (default = 0.6)
        Probe radius used when constructing the reference ``Molecule``.
    num_samples : int (default = 5)
        Forwarded to ``resample_surf_scores``.
    lam : float (default = 0.3)
        Forwarded to ``resample_surf_scores``.

    Raises
    ------
    ValueError
        If ``save_dir`` is not a directory.
    """
    save_dir_path = Path(save_dir)
    if not save_dir_path.is_dir():
        raise ValueError(f'Provided path is not a directory: {save_dir_path}')
    data_dir = save_dir_path / data_name

    # Load files
    with open(str(data_dir / 'atom_pos.pkl'), 'rb') as f:
        ls_atoms_pos = pickle.load(f)
    with open(str(data_dir / 'surfpos.pkl'), 'rb') as f:
        ls_surf_points = pickle.load(f)
    with open(str(data_dir / 'surfesp.pkl'), 'rb') as f:
        ls_surf_esp = pickle.load(f)
    with open(str(data_dir / 'pharmfeats.pkl'), 'rb') as f:
        ls_pharm_feats = pickle.load(f)

    num_mols = len(ls_atoms_pos)
    # NaN-initialised (not np.empty) so skipped molecules never leave uninitialised memory
    # in the saved results.
    sims_surf_upper_bound = np.full(num_mols, np.nan)
    sims_esp_upper_bound = np.full(num_mols, np.nan)

    pbar = tqdm(enumerate(ls_atoms_pos), desc='Resampling', total=num_mols)
    for i, (atoms, positions) in pbar:
        surf_points = ls_surf_points[i]
        surf_esp = ls_surf_esp[i]
        pharm_feats = ls_pharm_feats[i]

        try:
            # Converts coords + atom_ids -> xyz block -> mol
            mol, charge, xyz_block = get_mol_from_atom_pos(atoms=atoms, positions=positions)
            if mol is None:
                # Invalid structure: keep NaN and skip the (expensive) xtb call.
                continue

            # xtb partial charges of the stored conformation
            _, partial_charges = single_point_xtb_from_xyz(
                xyz_block=xyz_block,
                solvent=solvent,
                charge=charge,
                num_cores=num_processes,
                temp_dir=TMPDIR
            )

            ref_molec = Molecule(
                mol=mol,
                partial_charges=np.array(partial_charges),
                num_surf_points=len(surf_points),
                probe_radius=probe_radius,
                surface_points=surf_points,
                electrostatics=surf_esp,
                pharm_multi_vector=False,
                pharm_types=pharm_feats[0],
                pharm_ancs=pharm_feats[1],
                pharm_vecs=pharm_feats[2]
            )
            surf_scores, esp_scores = resample_surf_scores(
                ref_molec=ref_molec,
                num_samples=num_samples,
                eval_surf=True,
                eval_esp=True,
                lam=lam
            )
            if surf_scores is not None:
                sims_surf_upper_bound[i] = max(surf_scores)
            if esp_scores is not None:
                sims_esp_upper_bound[i] = max(esp_scores)
        except Exception as e:
            # Best effort per molecule: report why it failed and record NaN for both scores.
            print(f'Failed on #{i}: {e}')
            sims_surf_upper_bound[i] = np.nan
            sims_esp_upper_bound[i] = np.nan

    upperbound_path = data_dir / 'upper_bounds.npz'
    np.savez(file=upperbound_path,
             sims_surf_upper_bound=sims_surf_upper_bound,
             sims_esp_upper_bound=sims_esp_upper_bound)

    print(f'Finished {data_name} benchmark!\nSaved to {upperbound_path}')


# Solvent used for each known dataset; any other --data-name falls back to --solvent.
_DATASET_SOLVENTS = {'gdb': None, 'moses_aq': 'water'}


def main(argv: Optional[List[str]] = None) -> None:
    """
    Command-line entry point for building and evaluating training-set benchmarks.

    Modes
    -----
    * ``--resample`` non-zero: run ``run_resample`` for both ``gdb`` (no solvent) and
      ``moses_aq`` (water). ``--data-name`` is not used in this mode.
    * ``--training-set-dir`` given: build a benchmark set with ``generate_benchmark_set``.
      ``--relevant-sets`` is required in this mode.
    * otherwise: run ``run_consistency_benchmark`` on ``--data-name``. ``gdb`` runs without
      solvent and ``moses_aq`` in water. Any other dataset uses ``--solvent`` (empty -> None).

    Arguments
    ---------
    argv : Optional[List[str]]
        Arguments to parse; ``None`` means ``sys.argv[1:]``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--solvent', type=str, default='',
                        help='Solvent for xtb optimization (used for datasets other than '
                             'gdb / moses_aq).')
    parser.add_argument('--num-processes', type=int, default=1,
                        help='Number of processes to use for xtb optimization.')
    parser.add_argument('--training-set-dir', type=str, default=None,
                        help='Path to directory containing pickled training data.')
    # '--save_dir' is kept as an alias for backward compatibility.
    parser.add_argument('--save-dir', '--save_dir', dest='save_dir', type=str, required=True,
                        help='Path to directory to save files to.')
    parser.add_argument('--relevant-sets', type=str, default=None,
                        help='Which training sets to sample from. E.g., "1,2,3". '
                             'Required with --training-set-dir.')
    parser.add_argument('--data-name', required=True, type=str,
                        help='Name of dataset (i.e., gdb, moses_aq)')
    parser.add_argument('--resample', type=int, default=0,
                        help='Whether to only run resampling script.')

    args = parser.parse_args(argv)
    print(args)

    build_set = not args.resample and args.training_set_dir is not None
    relevant_sets = None
    if build_set:
        if args.relevant_sets is None:
            parser.error('--relevant-sets is required when --training-set-dir is given.')
        relevant_sets = [int(s) for s in args.relevant_sets.split(',') if s.strip()]
        if not relevant_sets:
            parser.error('--relevant-sets must list at least one set index.')

    default_solvent = args.solvent or None
    num_processes = args.num_processes
    save_dir = str(args.save_dir)
    num_molecs = 1000
    # settings for ShEPhERD generated molecules
    num_surf_points = 75
    probe_radius = 0.6

    if args.resample:
        for data_name in ('gdb', 'moses_aq'):
            print(f'Running resample on {data_name}...')
            run_resample(save_dir=save_dir,
                         data_name=data_name,
                         solvent=_DATASET_SOLVENTS[data_name],
                         num_processes=num_processes,
                         probe_radius=probe_radius)
    elif build_set:
        generate_benchmark_set(
            training_set_dir=args.training_set_dir,
            save_dir=save_dir,
            data_name=args.data_name,
            relevant_sets=relevant_sets,
            num_molecs=num_molecs,
            num_surf_points=num_surf_points,
            probe_radius=probe_radius
        )
    else:
        data_name = args.data_name
        # Known datasets keep their fixed solvent; anything else honours --solvent instead
        # of being silently skipped.
        solvent = _DATASET_SOLVENTS.get(data_name, default_solvent)
        print(f'Running {data_name}...')
        run_consistency_benchmark(
            save_dir=save_dir,
            data_name=data_name,
            solvent=solvent,
            num_processes=num_processes,
            probe_radius=probe_radius
        )


if __name__ == '__main__':
    main()
