"""
Script to run conditional evaluation of generated molecules.
"""

import os
from typing import Optional
from pathlib import Path
import pickle
import argparse
import open3d

import numpy as np
from rdkit import Chem

from shepherd_score.score.constants import COULOMB_SCALING
from shepherd_score.container import Molecule, MoleculePair
from shepherd_score.generate_point_cloud import get_atomic_vdw_radii, get_molecular_surface
from shepherd_score.conformer_generation import charges_from_single_point_conformer_with_xtb

from evaluate import ConditionalEvalPipeline

# Scratch directory handed to helpers that need a temp dir; taken from the TMPDIR
# environment variable when it is set, otherwise the current working directory.
TMPDIR = os.environ.get('TMPDIR', './')


def run_conditional_eval(sample_id,
                         job_id,
                         num_tasks,
                         load_dir,
                         solvent: Optional[str] = None,
                         num_processes: int = 1,
                         probe_radius: float = 0.6
                         ) -> None:
    """
    Run the conditional evaluation on one share of a samples file and save the results.

    The generated molecules stored in ``<load_dir>/samples_<sample_id>.pickle`` are split
    across ``num_tasks`` jobs by striding: this job evaluates entries ``job_id``,
    ``job_id + num_tasks``, ``job_id + 2 * num_tasks``, ... The results are written with
    ``np.savez`` to ``<load_dir>/cond_eval_<sample_id>/cond_eval_<job_id>``.

    Parameters
    ----------
    sample_id
        Index used to build the name of the samples pickle and of the output directory.
    job_id
        Zero-based index of this job; must satisfy ``0 <= job_id < num_tasks``.
    num_tasks
        Total number of jobs the generated molecules are split across (>= 1).
    load_dir
        Directory that contains the samples pickle; outputs are written below it.
    solvent
        Forwarded unchanged to ``ConditionalEvalPipeline``.
    num_processes
        Forwarded to ``ConditionalEvalPipeline.evaluate``.
    probe_radius
        Probe radius used when building the reference ``Molecule``.

    Raises
    ------
    ValueError
        If ``job_id`` / ``num_tasks`` do not describe a valid split, or if ``load_dir``
        is not a directory.
    """
    # Fail fast, before anything is created on disk. A zero step would otherwise only
    # blow up after loading the pickle, and a negative or out-of-range job id would
    # silently select the wrong share of molecules.
    if num_tasks < 1:
        raise ValueError(f'num_tasks must be a positive integer, got {num_tasks}')
    if not 0 <= job_id < num_tasks:
        raise ValueError(
            f'job_id must satisfy 0 <= job_id < num_tasks, got job_id={job_id} with num_tasks={num_tasks}'
        )

    load_dir_path = Path(load_dir)
    if not load_dir_path.is_dir():
        raise ValueError(f'Provided path is not a directory: {load_dir_path}')
    save_file_dir = load_dir_path / f'cond_eval_{sample_id}'
    save_file_dir.mkdir(parents=True, exist_ok=True)

    load_file_path = load_dir_path / f'samples_{sample_id}.pickle'
    # NOTE: unpickling executes arbitrary code -- only load sample files you trust.
    with open(load_file_path, 'rb') as f:
        samples = pickle.load(f)

    # Layout of the pickle as it is read here: [0] reference molblock, [1] partial charges,
    # [4]-[6] pharmacophore types / anchors / vectors, [-1] the generated samples.
    # Entries [2] and [3] are not needed: the Molecule is asked for 400 surface points itself.
    ref_mol = Chem.MolFromMolBlock(samples[0], removeHs=False)
    ref_molec = Molecule(ref_mol,
                         probe_radius=probe_radius,
                         partial_charges=np.array(samples[1]),
                         num_surf_points=400,
                         pharm_multi_vector=False,
                         pharm_types=samples[4],
                         pharm_ancs=samples[5],
                         pharm_vecs=samples[6])

    # Each generated sample contributes an (atoms, positions) pair taken from its 'x1' entry.
    generated_mols = [(entry['x1']['atoms'], entry['x1']['positions']) for entry in samples[-1]]

    # Strided split: this job handles entries job_id, job_id + num_tasks, ...
    subselected_gen_mols = generated_mols[job_id::num_tasks]

    print(f'Starting Conditional Eval Pipeline on sample {sample_id} and job {job_id}')
    cond_pipe = ConditionalEvalPipeline(
        ref_molec=ref_molec,
        generated_mols=subselected_gen_mols,
        condition='all',
        num_surf_points=400,
        pharm_multi_vector=False,
        solvent=solvent
    )

    # Run evaluation
    cond_pipe.evaluate(num_processes=num_processes, verbose=True)

    # Pipeline attributes persisted to the archive, each stored under its own name.
    result_fields = (
        'ref_molblock',
        'ref_mol_SA_score',
        'ref_mol_QED',
        'ref_mol_logP',
        'ref_mol_fsp3',
        'ref_mol_morgan_fp',
        'ref_surf_resampling_scores',
        'ref_surf_esp_resampling_scores',
        'sims_surf_upper_bound',
        'sims_esp_upper_bound',
        'molblocks',
        'molblocks_post_opt',
        'num_valid',
        'num_valid_post_opt',
        'num_consistent_graph',
        'strain_energies',
        'rmsds',
        'SA_scores',
        'logPs',
        'QEDs',
        'fsp3s',
        'frac_valid',
        'frac_valid_post_opt',
        'frac_consistent',
        'frac_unique',
        'frac_unique_post_opt',
        'avg_graph_diversity',
        'sims_surf_target',
        'sims_esp_target',
        'sims_pharm_target',
        'sims_surf_target_relax',
        'sims_esp_target_relax',
        'sims_pharm_target_relax',
        'graph_similarities',
        'sims_surf_target_relax_esp_aligned',
        'sims_esp_target_relax_esp_aligned',
        'sims_pharm_target_relax_esp_aligned',
        'molblocks_post_opt_esp_aligned',
    )
    save_file = save_file_dir / f'cond_eval_{job_id}'
    np.savez(save_file, **{field: getattr(cond_pipe, field) for field in result_fields})
    print(f'Finished {job_id} evaluation!\nSaved to {save_file}')


def run_conditional_eval_by_sample_only(
        sample_id,
        load_dir,
        conditioning_type: str,
        solvent: Optional[str] = None,
        num_processes: int = 1,
        probe_radius: float = 0.6
        ) -> None:
    """
    Run the conditional evaluation on a whole samples file (no per-job split) and save it.

    Samples are read from ``<load_dir>/<conditioning_type>/samples_<sample_id>.pickle`` and
    the results are written with ``np.savez`` to
    ``<load_dir>/<conditioning_type>/cond_evals/cond_eval_<sample_id>``.

    Parameters
    ----------
    sample_id
        Index used to build the name of the samples pickle and of the output file.
    load_dir
        Directory that contains one sub-directory per conditioning type.
    conditioning_type
        One of ``'x2'``, ``'x3'`` or ``'x4'``; mapped to the pipeline conditions
        ``'surf'``, ``'esp'`` and ``'pharm'`` respectively.
    solvent
        Forwarded unchanged to ``ConditionalEvalPipeline``.
    num_processes
        Forwarded to ``ConditionalEvalPipeline.evaluate``.
    probe_radius
        Probe radius used when building the reference ``Molecule``.

    Raises
    ------
    ValueError
        If ``conditioning_type`` is not recognised, or if
        ``<load_dir>/<conditioning_type>`` is not a directory.
    """
    # Sub-directory name -> condition understood by the evaluation pipeline.
    condition_by_type = {'x2': 'surf', 'x3': 'esp', 'x4': 'pharm'}
    if conditioning_type not in condition_by_type:
        # Without this check an unknown type left `condition` unbound and only failed
        # much later with an UnboundLocalError.
        raise ValueError(
            f'Unknown conditioning_type {conditioning_type!r}; expected one of {sorted(condition_by_type)}'
        )
    condition = condition_by_type[conditioning_type]

    load_dir_path = Path(load_dir) / conditioning_type
    if not load_dir_path.is_dir():
        raise ValueError(f'Provided path is not a directory: {load_dir_path}')

    save_file_dir = load_dir_path / 'cond_evals'
    save_file_dir.mkdir(parents=True, exist_ok=True)

    load_file_path = load_dir_path / f'samples_{sample_id}.pickle'
    # NOTE: unpickling executes arbitrary code -- only load sample files you trust.
    with open(load_file_path, 'rb') as f:
        samples = pickle.load(f)

    # Layout of the pickle as it is read here: [0] reference molblock, [1] partial charges,
    # [4]-[6] pharmacophore types / anchors / vectors, [-1] the generated samples.
    # Entries [2] and [3] are not needed: the Molecule is asked for 400 surface points itself.
    ref_mol = Chem.MolFromMolBlock(samples[0], removeHs=False)
    ref_molec = Molecule(ref_mol,
                         probe_radius=probe_radius,
                         partial_charges=np.array(samples[1]),
                         num_surf_points=400,
                         pharm_multi_vector=False,
                         pharm_types=samples[4],
                         pharm_ancs=samples[5],
                         pharm_vecs=samples[6])

    # Each generated sample contributes an (atoms, positions) pair taken from its 'x1' entry.
    generated_mols = [(entry['x1']['atoms'], entry['x1']['positions']) for entry in samples[-1]]

    print(f'Starting Conditional Eval Pipeline on sample {sample_id}.')
    cond_pipe = ConditionalEvalPipeline(
        ref_molec=ref_molec,
        generated_mols=generated_mols,
        condition=condition,
        num_surf_points=400,
        pharm_multi_vector=False,
        solvent=solvent
    )

    # Run evaluation
    cond_pipe.evaluate(num_processes=num_processes, verbose=True)

    # Pipeline attributes persisted to the archive, each stored under its own name.
    result_fields = (
        'ref_molblock',
        'ref_mol_SA_score',
        'ref_mol_QED',
        'ref_mol_logP',
        'ref_mol_fsp3',
        'ref_mol_morgan_fp',
        'ref_surf_resampling_scores',
        'ref_surf_esp_resampling_scores',
        'sims_surf_upper_bound',
        'sims_esp_upper_bound',
        'molblocks',
        'molblocks_post_opt',
        'num_valid',
        'num_valid_post_opt',
        'num_consistent_graph',
        'strain_energies',
        'rmsds',
        'SA_scores',
        'logPs',
        'QEDs',
        'fsp3s',
        'frac_valid',
        'frac_valid_post_opt',
        'frac_consistent',
        'frac_unique',
        'frac_unique_post_opt',
        'avg_graph_diversity',
        'sims_surf_target',
        'sims_esp_target',
        'sims_pharm_target',
        'sims_surf_target_relax',
        'sims_esp_target_relax',
        'sims_pharm_target_relax',
        'graph_similarities',
        'sims_surf_target_relax_esp_aligned',
        'sims_esp_target_relax_esp_aligned',
        'sims_pharm_target_relax_esp_aligned',
        'molblocks_post_opt_esp_aligned',
    )
    save_file = save_file_dir / f'cond_eval_{sample_id}'
    np.savez(save_file, **{field: getattr(cond_pipe, field) for field in result_fields})
    print(f'Finished sample {sample_id} evaluation!\nSaved to {save_file}')


def get_electrostatic_potential(partial_charges, centers, surf_pos) -> np.ndarray:
    """
    Get the electrostatic potential at each surface point.

    For every surface point the Coulomb contributions ``q_i / r_i`` of all point charges
    are summed and scaled by ``COULOMB_SCALING``.

    Parameters
    ----------
    partial_charges
        Array-like of shape (N,) with one charge per center.
    centers
        Array-like of shape (N, D) with the positions of the charges.
    surf_pos
        Array-like of shape (S, D) with the positions of the surface points.

    Returns
    -------
    np.ndarray
        ``float32`` array of shape (S,). A surface point that coincides with a center
        (distance 0) has no finite potential; its value is set to 0.
    """
    partial_charges = np.asarray(partial_charges)
    centers = np.asarray(centers)
    surf_pos = np.asarray(surf_pos)

    # Pairwise distances between surface points and centers, shape (S, N).
    distances = np.linalg.norm(surf_pos[:, np.newaxis] - centers, axis=2)
    # A zero distance is handled explicitly below, so silence the resulting
    # divide-by-zero / invalid-value floating point warnings.
    with np.errstate(divide='ignore', invalid='ignore'):
        E_pot = np.dot(partial_charges, 1 / distances.T) * COULOMB_SCALING
    # A zero distance gives +/-inf, but also NaN when the matching charge is 0 (0 * inf)
    # or when infinities of opposite sign cancel (inf - inf); zero out every such point.
    E_pot = np.where(np.isfinite(E_pot), E_pot, 0)
    return E_pot.astype(np.float32)


def run_conditional_eval_frag(job_id,
                              num_tasks,
                              load_dir,
                              solvent: Optional[str] = None,
                              num_processes: int = 1,
                              probe_radius: float = 0.6
                              ) -> None:
    """
    Run the conditional evaluation for fragment merging on one share of the samples and save it.

    ``<load_dir>/samples.pickle`` holds several reference fragments. A single reference is
    built from them: one molecular surface over the atoms of all fragments, and an ESP that is
    the average over the fragments of each fragment's potential at those surface points.
    The generated molecules are split across ``num_tasks`` jobs by striding (entries
    ``job_id``, ``job_id + num_tasks``, ...) and the results are written with ``np.savez`` to
    ``<load_dir>/cond_evals/cond_eval_<job_id>``.

    Parameters
    ----------
    job_id
        Zero-based index of this job; must satisfy ``0 <= job_id < num_tasks``.
    num_tasks
        Total number of jobs the generated molecules are split across (>= 1).
    load_dir
        Directory that contains ``samples.pickle``; outputs are written below it.
    solvent
        Forwarded unchanged to ``ConditionalEvalPipeline``. It is NOT used for the
        reference partial charges (see the note in the body).
    num_processes
        Forwarded to ``ConditionalEvalPipeline.evaluate``.
    probe_radius
        Probe radius used for the merged surface and the reference ``Molecule``.

    Raises
    ------
    ValueError
        If ``job_id`` / ``num_tasks`` do not describe a valid split, if ``load_dir`` is not
        a directory, or if a reference fragment molblock cannot be parsed.
    """
    print('Starting Frag')
    # Fail fast, before anything is created on disk or any xtb calculation is started.
    # A zero step would otherwise only blow up at the very end, and a negative or
    # out-of-range job id would silently select the wrong share of molecules.
    if num_tasks < 1:
        raise ValueError(f'num_tasks must be a positive integer, got {num_tasks}')
    if not 0 <= job_id < num_tasks:
        raise ValueError(
            f'job_id must satisfy 0 <= job_id < num_tasks, got job_id={job_id} with num_tasks={num_tasks}'
        )

    load_dir_path = Path(load_dir)
    if not load_dir_path.is_dir():
        raise ValueError(f'Provided path is not a directory: {load_dir_path}')
    save_file_dir = load_dir_path / 'cond_evals'
    save_file_dir.mkdir(parents=True, exist_ok=True)

    load_file_path = load_dir_path / 'samples.pickle'
    # NOTE: unpickling executes arbitrary code -- only load sample files you trust.
    with open(load_file_path, 'rb') as f:
        samples = pickle.load(f)

    # Create reference mol from all fragments
    mols = [Chem.MolFromMolBlock(s, removeHs=False) for s in samples[0]]
    if any(mol is None for mol in mols):
        # MolFromMolBlock returns None for an unparsable molblock; report that directly
        # instead of failing below with an AttributeError.
        raise ValueError(f'Could not parse every reference fragment molblock in {load_file_path}')
    centers = [mol.GetConformer().GetPositions() for mol in mols]
    radii = [get_atomic_vdw_radii(mol) for mol in mols]
    # One surface over the atoms of all fragments together.
    surface_points = get_molecular_surface(np.concatenate(centers),
                                           np.concatenate(radii),
                                           num_points=400,
                                           probe_radius=probe_radius)

    partial_charges = []
    esps = []
    for mol, mol_centers in zip(mols, centers):
        # NOTE(review): the solvent is hard-coded to 'water' here and ignores the `solvent`
        # argument, which only reaches the pipeline below -- confirm this is intended.
        charges = charges_from_single_point_conformer_with_xtb(mol, solvent='water', num_cores=1, temp_dir=TMPDIR)
        partial_charges.append(charges)
        # Potential of this fragment alone, evaluated on the shared surface.
        esps.append(get_electrostatic_potential(charges, mol_centers, surface_points))
    avg_esp = np.stack(esps).mean(axis=0)

    print('Finished generating the merged reference Molecule object.')

    # just choose the first molblock, it doesn't affect anything once you make the Molecule object
    ref_mol = mols[0]
    ref_partial_charges = partial_charges[0]
    # The surface and ESP are generated above rather than taken from the pickle (the original
    # author notes the stored ones only have 75 points).
    # NOTE(review): pharmacophores are read from indices 3-5 here, whereas the other sample
    # loaders in this file read them from 4-6 -- confirm the layout of the fragment pickle.
    pharm_types = samples[3]
    pharm_ancs = samples[4]
    pharm_vecs = samples[5]
    ref_molec = Molecule(ref_mol,
                         probe_radius=probe_radius,
                         partial_charges=np.array(ref_partial_charges),
                         surface_points=surface_points,  # We generate the surface
                         electrostatics=avg_esp,  # We generate the esp
                         pharm_multi_vector=False,
                         pharm_types=pharm_types,
                         pharm_ancs=pharm_ancs,
                         pharm_vecs=pharm_vecs)

    # Each generated sample contributes an (atoms, positions) pair taken from its 'x1' entry.
    generated_mols = [(entry['x1']['atoms'], entry['x1']['positions']) for entry in samples[-1]]

    # Strided split: this job handles entries job_id, job_id + num_tasks, ...
    subselected_gen_mols = generated_mols[job_id::num_tasks]

    print(f'Starting Conditional Eval Pipeline on Fragment mergining and job {job_id}')
    cond_pipe = ConditionalEvalPipeline(
        ref_molec=ref_molec,
        generated_mols=subselected_gen_mols,
        condition='all',
        num_surf_points=400,
        pharm_multi_vector=False,
        solvent=solvent
    )

    # Run evaluation
    cond_pipe.evaluate(num_processes=num_processes, verbose=True)

    # Pipeline attributes persisted to the archive, each stored under its own name.
    result_fields = (
        'ref_molblock',
        'ref_mol_SA_score',
        'ref_mol_QED',
        'ref_mol_logP',
        'ref_mol_fsp3',
        'ref_mol_morgan_fp',
        'ref_surf_resampling_scores',
        'ref_surf_esp_resampling_scores',
        'sims_surf_upper_bound',
        'sims_esp_upper_bound',
        'molblocks',
        'molblocks_post_opt',
        'num_valid',
        'num_valid_post_opt',
        'num_consistent_graph',
        'strain_energies',
        'rmsds',
        'SA_scores',
        'logPs',
        'QEDs',
        'fsp3s',
        'frac_valid',
        'frac_valid_post_opt',
        'frac_consistent',
        'frac_unique',
        'frac_unique_post_opt',
        'avg_graph_diversity',
        'sims_surf_target',
        'sims_esp_target',
        'sims_pharm_target',
        'sims_surf_target_relax',
        'sims_esp_target_relax',
        'sims_pharm_target_relax',
        'graph_similarities',
        'sims_surf_target_relax_esp_aligned',
        'sims_esp_target_relax_esp_aligned',
        'sims_pharm_target_relax_esp_aligned',
        'molblocks_post_opt_esp_aligned',
    )
    save_file = save_file_dir / f'cond_eval_{job_id}'
    np.savez(save_file, **{field: getattr(cond_pipe, field) for field in result_fields})
    print(f'Finished {job_id} evaluation!\nSaved to {save_file}')


def main(argv=None) -> None:
    """
    Command-line entry point: pick the evaluation mode from the load directory name.

    The mode is chosen by substring match on ``--load-dir-path``:
    ``NP_analogues`` -> ``run_conditional_eval``, ``fragment_merging`` ->
    ``run_conditional_eval_frag``, ``GDB_conditional`` -> ``run_conditional_eval_by_sample_only``
    for every sample assigned to this task and every conditioning type. The checks are
    independent, so a path containing several of these names runs several modes.

    Parameters
    ----------
    argv
        Argument list to parse; ``None`` makes argparse read ``sys.argv``.

    Raises
    ------
    ValueError
        If ``--load-dir-path`` is not a directory.
    SystemExit
        Raised by argparse when an option is missing or the task split is invalid.
    """
    parser = argparse.ArgumentParser()
    # All options are required: omitting one used to surface later as a TypeError.
    parser.add_argument('--load-dir-path', type=str, required=True,
                        help='Path to directory to load files that are structured as `samples_{i}.pickle`.')
    parser.add_argument('--task-id', type=int, required=True, help='Task ID.')
    parser.add_argument('--num-tasks', type=int, required=True, help='Number of tasks.')
    parser.add_argument('--sample-id', type=int, required=True, help='Index used to load the file. [0,4]')
    args = parser.parse_args(argv)
    print(args)

    load_dir = Path(args.load_dir_path)
    if not load_dir.is_dir():
        raise ValueError('Provided --load-dir-path is not a directory.')
    # argparse already converted these to int.
    my_task_id = args.task_id
    num_tasks = args.num_tasks
    sample_id = args.sample_id
    # The work is split by striding (task_id, task_id + num_tasks, ...), which is only a
    # valid partition for a zero-based task id below the number of tasks.
    if num_tasks < 1 or not 0 <= my_task_id < num_tasks:
        parser.error(f'--task-id must satisfy 0 <= task-id < num-tasks (got {my_task_id} and {num_tasks}).')

    print(f'Loading from {load_dir}')
    load_dir_name = str(load_dir)
    ran_any_mode = False

    if 'NP_analogues' in load_dir_name:
        ran_any_mode = True
        run_conditional_eval(
            sample_id=sample_id,
            job_id=my_task_id,
            num_tasks=num_tasks,
            load_dir=load_dir,
            solvent='water',
            num_processes=4,
            probe_radius=0.6
        )

    if 'fragment_merging' in load_dir_name:
        ran_any_mode = True
        run_conditional_eval_frag(
            job_id=my_task_id,
            num_tasks=num_tasks,
            load_dir=load_dir,
            solvent='water',
            num_processes=4,
            probe_radius=0.6
        )

    # total evals: 60-120
    if 'GDB_conditional' in load_dir_name:
        ran_any_mode = True
        samples_numbers = np.arange(100)
        samples_to_eval = samples_numbers[my_task_id::num_tasks]
        print(f'Running GDB_conditional for these samples:\n{samples_to_eval}')
        # Go through every assigned sample id (1-2)
        for sample_idx in samples_to_eval:
            # Go through every representation (20*3)
            for conditioning in ('x2', 'x3', 'x4'):
                # The sample index comes from the task split above, not from --sample-id
                print(f'Running sample {sample_idx} for {conditioning} condition.')
                run_conditional_eval_by_sample_only(sample_id=sample_idx,
                                                    load_dir=load_dir,
                                                    conditioning_type=conditioning,
                                                    solvent=None,
                                                    num_processes=4,
                                                    probe_radius=0.6)

    if not ran_any_mode:
        # Previously the script exited silently without evaluating anything.
        print(f'Nothing to run: {load_dir} matches none of '
              'NP_analogues, fragment_merging, GDB_conditional.')


if __name__ == '__main__':
    main()
