"""
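Run a benchmark that, for each target molecule used for conditioning during inference, randomly
selects molecules from the training set (xtb), aligns them to the target, and scores them with
surf, esp, and pharm.

For the `gdb` partition, 20 random molecules are drawn per target and a separate alignment is
optimized for each of the three scores. For the `np` partition, 2500 random molecules are drawn
per target, only the esp alignment is optimized, and the surf and pharm scores are evaluated at
that esp-aligned pose.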
"""

from typing import Union, List, Tuple, Optional
from pathlib import Path
from tqdm import tqdm
import argparse
import pickle

import numpy as np
from rdkit import Chem
import pandas as pd

from shepherd_score.score.constants import ALPHA

from shepherd_score.alignment_utils.se3_np import apply_SE3_transform_np, apply_SO3_transform_np
from shepherd_score.container import Molecule, MoleculePair, update_mol_coordinates
from shepherd_score.score.gaussian_volume_overlap_np import get_ROCS_np
from shepherd_score.score.electrostatic_scoring_np import get_ROCS_esp_np
from shepherd_score.score.pharmacophore_scoring_np import get_pharm_score_np

RNG = np.random.default_rng()

def run_sim_scoring_benchmark(ls_rand_molblock_charges: List[Tuple],
                              target_molec: Molecule,
                              ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Align every random molecule to the target three separate times -- by surface, by ESP,
    and by pharmacophores -- and collect the similarity score reached by each alignment.

    Parameters
    ----------
    ls_rand_molblock_charges : List[Tuple]
        One entry per random molecule. Index 0 is parsed as a mol block (hydrogens are
        kept) and index 1 is converted to an array and used as the partial charges.
    target_molec : Molecule
        Reference molecule of every pair. Its `num_surf_points`, `probe_radius`, and
        `pharm_multi_vector` settings are reused when building each random molecule so
        that both sides of a pair are featurized the same way.

    Returns
    -------
    Tuple[np.ndarray, np.ndarray, np.ndarray]
        Surface, ESP, and pharmacophore scores (in that order). Each array holds one value
        per input entry, in input order. A value is NaN when the corresponding
        `sim_aligned_*` attribute of the pair is None after alignment.
    """
    num_molecs = len(ls_rand_molblock_charges)
    sim_surf_scores = np.empty(num_molecs)
    sim_esp_scores = np.empty(num_molecs)
    sim_pharm_scores = np.empty(num_molecs)

    for i, rand_molblock_charges in enumerate(ls_rand_molblock_charges):
        rand_mol = Chem.MolFromMolBlock(rand_molblock_charges[0], removeHs=False)
        rand_mol_partial_charges = np.array(rand_molblock_charges[1])

        # Mirror the target's pharmacophore setting (as run_sim_scoring_benchmark_np does)
        # instead of hard-coding it, so the two molecules are always comparable.
        rand_molec = Molecule(mol=rand_mol,
                              num_surf_points=target_molec.num_surf_points,
                              probe_radius=target_molec.probe_radius,
                              partial_charges=rand_mol_partial_charges,
                              pharm_multi_vector=target_molec.pharm_multi_vector)

        mp = MoleculePair(ref_mol=target_molec,
                          fit_mol=rand_molec,
                          num_surf_points=target_molec.num_surf_points)

        # ALPHA is looked up from the number of surface points only, so evaluate it once
        # per pair and share it between the surface and ESP alignments.
        alpha = ALPHA(mp.num_surf_points)

        # Surface
        mp.align_with_surf(alpha=alpha,
                           num_repeats=50,
                           trans_init=False,
                           use_jax=False,
                           verbose=False)
        sim_surf_scores[i] = mp.sim_aligned_surf if mp.sim_aligned_surf is not None else np.nan

        # ESP
        mp.align_with_esp(alpha=alpha,
                          lam=0.3,
                          num_repeats=50,
                          trans_init=False,
                          use_jax=False)
        sim_esp_scores[i] = mp.sim_aligned_esp if mp.sim_aligned_esp is not None else np.nan

        # Pharm
        mp.align_with_pharm(num_repeats=50,
                            trans_init=False,
                            use_jax=False)
        sim_pharm_scores[i] = mp.sim_aligned_pharm if mp.sim_aligned_pharm is not None else np.nan

    return sim_surf_scores, sim_esp_scores, sim_pharm_scores


def run_sim_scoring_benchmark_np(ls_rand_molblock_charges: List[Tuple],
                                 target_molec: Molecule,
                                 ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Optimize only the ESP alignment of each random molecule onto the target, then evaluate
    the surface and pharmacophore scores at that same ESP-aligned pose.

    Parameters
    ----------
    ls_rand_molblock_charges : List[Tuple]
        One entry per random molecule. Index 0 is parsed as a mol block (hydrogens are
        kept) and index 1 is converted to an array and used as the partial charges.
    target_molec : Molecule
        Reference molecule of every pair. Its surface and pharmacophore settings are
        reused when building each random molecule.

    Returns
    -------
    Tuple[np.ndarray, np.ndarray, np.ndarray]
        Surface, ESP, and pharmacophore scores (in that order), one value per input entry.
        A value is NaN when the respective score is None; the surface and pharmacophore
        values are also NaN when the pair has no ESP transform after alignment.
    """
    print('Only align with esp, then score the others with that alignment...')
    num_molecs = len(ls_rand_molblock_charges)
    sim_surf_scores = np.empty(num_molecs)
    sim_esp_scores = np.empty(num_molecs)
    sim_pharm_scores = np.empty(num_molecs)
    alpha = ALPHA(target_molec.num_surf_points)

    progress = tqdm(enumerate(ls_rand_molblock_charges), total=num_molecs)
    for idx, entry in progress:
        fit_mol = Chem.MolFromMolBlock(entry[0], removeHs=False)
        fit_charges = np.array(entry[1])

        fit_molec = Molecule(mol=fit_mol,
                             num_surf_points=target_molec.num_surf_points,
                             probe_radius=target_molec.probe_radius,
                             partial_charges=fit_charges,
                             pharm_multi_vector=target_molec.pharm_multi_vector)

        pair = MoleculePair(ref_mol=target_molec,
                            fit_mol=fit_molec,
                            num_surf_points=target_molec.num_surf_points)

        # ESP is the only alignment that gets optimized.
        pair.align_with_esp(alpha=alpha,
                            lam=0.3,
                            num_repeats=50,
                            trans_init=False,
                            use_jax=False)
        esp_score = pair.sim_aligned_esp
        sim_esp_scores[idx] = np.nan if esp_score is None else esp_score

        esp_transform = pair.transform_esp
        if esp_transform is None:
            # Without a transform there is no pose to score the other two profiles at.
            sim_surf_scores[idx] = np.nan
            sim_pharm_scores[idx] = np.nan
            continue

        # Surface score at the ESP-aligned pose.
        aligned_surf_pos = apply_SE3_transform_np(fit_molec.surf_pos, esp_transform)
        surf_score = get_ROCS_np(target_molec.surf_pos, aligned_surf_pos, alpha=alpha)

        # Pharmacophore score at the ESP-aligned pose: anchors go through the SE(3) helper,
        # vectors through the SO(3) helper.
        aligned_pharm_ancs = apply_SE3_transform_np(fit_molec.pharm_ancs, esp_transform)
        aligned_pharm_vecs = apply_SO3_transform_np(fit_molec.pharm_vecs, esp_transform)
        pharm_score = get_pharm_score_np(
            target_molec.pharm_types, fit_molec.pharm_types,
            target_molec.pharm_ancs, aligned_pharm_ancs,
            target_molec.pharm_vecs, aligned_pharm_vecs,
            )

        sim_surf_scores[idx] = np.nan if surf_score is None else surf_score
        sim_pharm_scores[idx] = np.nan if pharm_score is None else pharm_score

    return sim_surf_scores, sim_esp_scores, sim_pharm_scores


def _load_pickle(file_path: Path):
    """ Return the object stored in the pickle file at `file_path`. """
    # NOTE(security): unpickling can execute arbitrary code. Only point this script at
    # files that you generated yourself or otherwise trust.
    with open(file_path, 'rb') as f:
        return pickle.load(f)


def _scores_to_dataframe(all_surf_scores: List[np.ndarray],
                         all_esp_scores: List[np.ndarray],
                         all_pharm_scores: List[np.ndarray],
                         ) -> pd.DataFrame:
    """
    Flatten the per-target score arrays into one long-format table.

    Each row holds the three scores of one (target, random molecule) pair; `ref_mol_num`
    is the position of the target in the input lists.
    """
    data = []
    per_target_scores = zip(all_surf_scores, all_esp_scores, all_pharm_scores)
    for ref_mol_num, (surf, esp, pharm) in enumerate(per_target_scores):
        for surf_score, esp_score, pharm_score in zip(surf, esp, pharm):
            data.append({
                'ref_mol_num': ref_mol_num,
                'Surf_Score': surf_score,
                'ESP_Score': esp_score,
                'Pharm_Score': pharm_score
            })
    return pd.DataFrame(data)


def main(argv: Optional[List[str]] = None) -> None:
    """
    Command-line entry point.

    For every `samples_{i}.pickle` in `--load-dir`, build the target molecule from the
    first two items of the pickled object (mol block, partial charges), draw random
    molecules from the training data without replacement, score them against the target,
    and finally save all scores to `rand_molec_sim_scores.pkl` inside `--load-dir`.

    Parameters
    ----------
    argv : Optional[List[str]]
        Command-line arguments without the program name. None (default) lets argparse
        read them from `sys.argv`.

    Raises
    ------
    SystemExit
        Raised by argparse (exit status 2) when an argument is missing or invalid, or
        when `--load-dir` / `--training-data` do not exist.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--partition', type=str, required=True, choices=('gdb', 'np'),
                        help="gdb or np")
    parser.add_argument('--load-dir', type=str, required=True,
                        help='Path to directory holding generated samples as `samples_{i}.pickle`')
    parser.add_argument('--training-data', type=str, required=True,
                        help='Path to training data to sample random molecules from.')
    # Kept so existing command lines keep working; the value is not read by this script.
    parser.add_argument('--folder', type=str, help='if gdb must include either x2, x3, or x4')
    args = parser.parse_args(argv)
    print(args)

    partition = args.partition

    # Validate paths explicitly: `assert` statements are stripped when Python runs with -O.
    load_dir_path = Path(args.load_dir)
    if not load_dir_path.is_dir():
        parser.error(f'--load-dir is not an existing directory: {load_dir_path}')
    training_data_file = Path(args.training_data)
    if not training_data_file.is_file():
        parser.error(f'--training-data is not an existing file: {training_data_file}')

    # The partitions differ only in how many random molecules are drawn per target and
    # in which benchmark routine scores them.
    if partition == 'gdb':
        print('Loading from GDB_conditional')
        num_rand_molecs = 20
        benchmark_fn = run_sim_scoring_benchmark
    else:
        print('Loading from NP analogues')
        num_rand_molecs = 2500
        benchmark_fn = run_sim_scoring_benchmark_np

    # load training data for random alignments / evaluations
    print(f'Loading training molblock_charges from {training_data_file}')
    training_molblock_charges = _load_pickle(training_data_file)

    # settings for ShEPhERD generated molecules
    num_surf_points = 400 # number used to sample
    probe_radius = 0.6

    all_surf_scores = []
    all_esp_scores = []
    all_pharm_scores = []

    # Count only the sample files so that other pickles in the directory do not inflate
    # the number of targets; files are expected to be numbered 0..num_files-1.
    num_files = sum(1 for _ in load_dir_path.glob('samples_*.pickle'))
    for file_idx in tqdm(range(num_files), desc=partition):
        samples = _load_pickle(load_dir_path / f'samples_{file_idx}.pickle')

        target_molec = Molecule(mol=Chem.MolFromMolBlock(samples[0], removeHs=False),
                                num_surf_points=num_surf_points,
                                probe_radius=probe_radius,
                                partial_charges=samples[1],
                                pharm_multi_vector=False
                            )

        # Get random molblocks (a fresh draw without replacement for every target)
        rand_inds = RNG.choice(len(training_molblock_charges), num_rand_molecs, replace=False)
        ls_rand_molblock_charges = [training_molblock_charges[ind] for ind in rand_inds]

        sim_surf_scores, sim_esp_scores, sim_pharm_scores = benchmark_fn(
            ls_rand_molblock_charges,
            target_molec=target_molec
            )
        all_surf_scores.append(sim_surf_scores)
        all_esp_scores.append(sim_esp_scores)
        all_pharm_scores.append(sim_pharm_scores)

    df = _scores_to_dataframe(all_surf_scores, all_esp_scores, all_pharm_scores)
    save_path = load_dir_path / 'rand_molec_sim_scores.pkl'
    df.to_pickle(save_path)
    print(f'Saved to {save_path}!')


if __name__ == '__main__':
    main()
