"""
Evaluation using UnconditionalEvalPipeline / ConsistencyEvalPipeline for training set structures.
"""

import os
from typing import List, Optional, Tuple, Union
from pathlib import Path
from tqdm import tqdm
from copy import deepcopy
import pickle
import argparse

import numpy as np
import rdkit
from rdkit import Chem
from rdkit.Chem import AllChem
import rdkit.Chem.rdDetermineBonds

from shepherd_score.score.constants import ALPHA

from shepherd_score.conformer_generation import embed_conformer_from_smiles, single_point_xtb_from_xyz
from shepherd_score.container import Molecule

from evaluate import ConsistencyEvalPipeline, resample_surf_scores, get_mol_from_atom_pos


# Scratch directory: taken from the TMPDIR environment variable when it is set,
# otherwise the current working directory.
TMPDIR = Path(os.environ.get('TMPDIR', './'))


def run_consistency_evaluation(load_path: str,
                               condition: str,
                               solvent: Optional[str] = None,
                               training_molblock_charges: Union[List[Tuple], None] = None,
                               num_processes: int = 4,
                               probe_radius: float = 0.6,
                               ) -> None:
    """
    Run ``ConsistencyEvalPipeline`` on a pickled set of samples and save the
    resulting metrics to ``<parent of load_path>/consis_evals/consis_eval.npz``.

    Each sample is indexed as ``samples[i][<key>][<field>]``:
    ``x1`` -> ``atoms``, ``positions``; ``x2`` -> ``positions``;
    ``x3`` -> ``positions``, ``charges``; ``x4`` -> ``types``, ``positions``,
    ``directions``. Only the keys required by ``condition`` are read.

    Arguments
    ---------
    load_path : path to the pickled samples file.
    condition : str, one of 'x1x2', 'x1x3', 'x1x4', 'x1x3x4'. Selects which
        representations are passed to the pipeline alongside the x1 atoms.
    solvent : forwarded to ``ConsistencyEvalPipeline`` as ``solvent``.
    training_molblock_charges : forwarded to ``ConsistencyEvalPipeline`` as
        ``random_molblock_charges`` (presumably (molblock, charges) tuples drawn
        from the training set -- confirm against the pipeline).
    num_processes : forwarded to ``ConsistencyEvalPipeline.evaluate``.
    probe_radius : forwarded to ``ConsistencyEvalPipeline`` as ``probe_radius``.

    Raises
    ------
    ValueError : if ``load_path`` is not a file or ``condition`` is unsupported.
    """
    supported_conditions = ('x1x2', 'x1x3', 'x1x4', 'x1x3x4')

    load_file_path = Path(load_path)
    if not load_file_path.is_file():
        raise ValueError(f'Provided path is not a file: {load_file_path}')
    # Fail fast, before touching the filesystem or loading data, instead of
    # hitting a NameError on an undefined list further down.
    if condition not in supported_conditions:
        raise ValueError(
            f'Unsupported condition {condition!r}; expected one of {supported_conditions}'
        )

    save_file_dir = load_file_path.parent / 'consis_evals'
    save_file_dir.mkdir(parents=True, exist_ok=True)

    # np.savez appends ".npz" when it is missing; spelling it out here keeps the
    # logged path equal to the file that is actually written.
    save_file_path = save_file_dir / 'consis_eval.npz'
    print(f'Saving to {save_file_path}')

    # NOTE: pickle.load executes arbitrary code from the file; only load trusted samples.
    with open(load_file_path, 'rb') as f:
        samples = pickle.load(f)
    print(f'Samples loaded from {load_file_path}')

    # Which representations accompany x1 for this condition.
    use_surf = condition in ('x1x2', 'x1x3', 'x1x3x4')
    use_esp = condition in ('x1x3', 'x1x3x4')
    use_pharm = condition in ('x1x4', 'x1x3x4')
    # Surface points come from x2 for the shape-only model, from x3 otherwise.
    surf_key = 'x2' if condition == 'x1x2' else 'x3'

    # Integer indexing is kept on purpose: it only requires `samples` to support
    # len() and samples[i], exactly like the data this script has always accepted.
    indices = range(len(samples))

    ls_atoms_pos = [
        (samples[i]['x1']['atoms'], samples[i]['x1']['positions']) for i in indices
    ]
    ls_surf_points = (
        [samples[i][surf_key]['positions'] for i in indices] if use_surf else None
    )
    ls_surf_esp = (
        [samples[i]['x3']['charges'] for i in indices] if use_esp else None
    )
    ls_pharm_feats = (
        [
            (samples[i]['x4']['types'],
             samples[i]['x4']['positions'],
             samples[i]['x4']['directions'])
            for i in indices
        ]
        if use_pharm else None
    )

    print(f'Initializing ConsistencyEvalPipeline for {condition}.')

    # Initialize evaluation
    consis_pipeline = ConsistencyEvalPipeline(
        generated_mols=ls_atoms_pos,
        generated_surf_points=ls_surf_points,
        generated_surf_esp=ls_surf_esp,
        generated_pharm_feats=ls_pharm_feats,
        probe_radius=probe_radius,
        pharm_multi_vector=False,
        solvent=solvent,
        random_molblock_charges=training_molblock_charges
    )

    # Run evaluation
    print('Running evaluation...')
    consis_pipeline.evaluate(num_processes=num_processes, verbose=True)
    print('Finished evaluation.')

    # Save values to numpy object
    np.savez(
        file=str(save_file_path),
        molblocks=consis_pipeline.molblocks,
        molblocks_post_opt=consis_pipeline.molblocks_post_opt,
        num_valid=consis_pipeline.num_valid,
        num_valid_post_opt=consis_pipeline.num_valid_post_opt,
        num_consistent_graph=consis_pipeline.num_consistent_graph,
        strain_energies=consis_pipeline.strain_energies,
        rmsds=consis_pipeline.rmsds,
        SA_scores=consis_pipeline.SA_scores,
        logPs=consis_pipeline.logPs,
        QEDs=consis_pipeline.QEDs,
        fsp3s=consis_pipeline.fsp3s,
        frac_valid=consis_pipeline.frac_valid,
        frac_valid_post_opt=consis_pipeline.frac_valid_post_opt,
        frac_consistent=consis_pipeline.frac_consistent,
        frac_unique=consis_pipeline.frac_unique,
        frac_unique_post_opt=consis_pipeline.frac_unique_post_opt,
        avg_graph_diversity=consis_pipeline.avg_graph_diversity,
        sims_surf_consistent=consis_pipeline.sims_surf_consistent,
        sims_esp_consistent=consis_pipeline.sims_esp_consistent,
        sims_pharm_consistent=consis_pipeline.sims_pharm_consistent,
        sims_surf_upper_bound_75=consis_pipeline.sims_surf_upper_bound_75,
        sims_esp_upper_bound_75=consis_pipeline.sims_esp_upper_bound_75,
        sims_surf_upper_bound_400=consis_pipeline.sims_surf_upper_bound_400,
        sims_esp_upper_bound_400=consis_pipeline.sims_esp_upper_bound_400,
        sims_surf_lower_bound=consis_pipeline.sims_surf_lower_bound,
        sims_esp_lower_bound=consis_pipeline.sims_esp_lower_bound,
        sims_pharm_lower_bound=consis_pipeline.sims_pharm_lower_bound,
        sims_surf_consistent_relax=consis_pipeline.sims_surf_consistent_relax,
        sims_esp_consistent_relax=consis_pipeline.sims_esp_consistent_relax,
        sims_pharm_consistent_relax=consis_pipeline.sims_pharm_consistent_relax,
        # Saved keys end in "_align" while the pipeline attributes end in
        # "_aligned"; the key names are kept so existing readers still work.
        sims_surf_consistent_relax_align=consis_pipeline.sims_surf_consistent_relax_aligned,
        sims_esp_consistent_relax_align=consis_pipeline.sims_esp_consistent_relax_aligned,
        sims_pharm_consistent_relax_align=consis_pipeline.sims_pharm_consistent_relax_aligned,
        graph_similarity_matrix=consis_pipeline.graph_similarity_matrix
    )
    print(f'Finished benchmark!\nSaved to {save_file_path}')


if __name__=='__main__':
    # Command-line entry point: pick the evaluation settings from the samples
    # path (MOSES-aq vs. GDB-17) and run the consistency evaluation.
    parser = argparse.ArgumentParser()
    parser.add_argument('--load-file', type=str, required=True,
                        help='Path to samples file.')
    parser.add_argument('--training-data', type=str, required=True,
                        help="Path to GDB or MOSES-aq training data file to randomly sample and compare scoring functions")
    parser.add_argument('--task-id', type=str, help='Task ID.')
    args = parser.parse_args()
    print(args)

    load_file = str(args.load_file)
    training_data_file = Path(str(args.training_data))
    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if not training_data_file.is_file():
        parser.error(f'--training-data is not a file: {training_data_file}')

    # settings for ShEPhERD generated molecules
    # (the original script notes that 75 surface points were used to sample)
    probe_radius = 0.6
    num_processes = 16

    if 'moses' in load_file:
        # --task-id is not used for this branch, so it is neither required nor parsed.
        solvent = 'water'
        model = 'x1x3x4'
        training_molblock_charges = None
        print('Doing Moses - aq')
    else:
        solvent = None # evaluating on GDB unconditional
        print('Doing GDB-17')

        # The task id selects which conditional model's samples are evaluated.
        model_types = ('x1x2', 'x1x3', 'x1x4')
        try:
            my_task_id = int(args.task_id)
        except (TypeError, ValueError):
            # TypeError: --task-id omitted (None); ValueError: empty or non-integer.
            parser.error(
                f'--task-id must be an integer index into {model_types} for GDB runs, '
                f'got {args.task_id!r}'
            )
        try:
            model = model_types[my_task_id]
        except IndexError:
            parser.error(
                f'--task-id {my_task_id} is out of range for model types {model_types}'
            )

        # load training data for random alignments / evaluations
        # NOTE: pickle.load executes arbitrary code from the file; only load trusted data.
        with open(training_data_file, 'rb') as f:
            training_molblock_charges = pickle.load(f)

    run_consistency_evaluation(
        load_path=load_file,
        condition=model,
        solvent=solvent,
        training_molblock_charges=training_molblock_charges,
        num_processes=num_processes,
        probe_radius=probe_radius
    )
