"""
Get the trend for number of atoms : pharmacophores of dataset
"""
from typing import List, Optional, Tuple
from pathlib import Path
from tqdm import tqdm
from copy import deepcopy
import pickle
import argparse

import numpy as np
from rdkit import Chem

from shepherd_score.pharm_utils.pharmacophore import get_pharmacophores


def get_atom_count_and_pharm_count(
    mol: Optional[Chem.Mol],
) -> Tuple[Optional[int], Optional[int]]:
    """
    Get the number of atoms and the number of pharmacophores of a Mol object.

    The atom count is taken after adding hydrogens with ``Chem.AddHs``. The
    pharmacophores are extracted from ``mol`` as it was passed in, using
    ``get_pharmacophores(mol, multi_vector=False, check_access=False)``.

    Parameters
    ----------
    mol : Chem.Mol or None
        Molecule to analyze. ``None`` is accepted and treated as a failure.

    Returns
    -------
    Tuple[Optional[int], Optional[int]]
        ``(atom_count, pharm_count)`` on success, or ``(None, None)`` if ``mol``
        is ``None`` or if any step raised an exception.
    """
    if mol is None:
        return None, None

    try:
        atom_count = Chem.AddHs(mol).GetNumAtoms()
        pharm_types, _, _ = get_pharmacophores(mol, multi_vector=False, check_access=False)
        pharm_count = len(pharm_types)
    except Exception:
        # Deliberately best-effort: a molecule that cannot be processed is
        # reported as (None, None) so the caller can skip it and keep going.
        return None, None

    return atom_count, pharm_count

if __name__ == '__main__':
    # Script entry point: tabulate (atom count, pharmacophore count) occurrences
    # for the molecules in a pickle file and save one table per dataset name.
    parser = argparse.ArgumentParser()
    parser.add_argument('--data-file', type=str, required=True,
                        help='Path to training data file pickle file with molbock_charges')
    parser.add_argument('--save-dir', type=str, required=True,
                        help='Directory for where to save data')
    args = parser.parse_args()

    data_path = Path(args.data_file)
    save_dir = Path(args.save_dir)
    data_names = ('gdb', 'moses_aq')

    # Fail fast: check every output directory before doing any expensive work,
    # and raise explicitly (an `assert` would be stripped under `python -O`).
    for data_name in data_names:
        out_dir = save_dir / data_name
        if not out_dir.is_dir():
            raise NotADirectoryError(f'Output directory does not exist: {out_dir}')

    # The input file does not depend on the dataset name, so load it only once.
    # NOTE: pickle.load can execute arbitrary code; only use trusted data files.
    print(f'Loading {data_path}')
    with open(data_path, 'rb') as f:
        molblocks_charges = pickle.load(f)

    # Counts are indexed as [atom_count, pharm_count]; both must be < max_count.
    max_count = 200

    # NOTE(review): every dataset name is processed from the same --data-file,
    # so each output directory receives the same table. Confirm this is intended.
    for data_name in data_names:
        num_occurrences = np.zeros((max_count, max_count))
        num_out_of_range = 0

        for molblock_charges in tqdm(molblocks_charges,
                                     total=len(molblocks_charges),
                                     miniters=len(molblocks_charges) // 10,
                                     maxinterval=20000,
                                     desc=data_name):
            # The first element of each entry is the mol block text.
            mol = Chem.MolFromMolBlock(molblock_charges[0])

            atom_count, pharm_count = get_atom_count_and_pharm_count(mol)
            if atom_count is None or pharm_count is None:
                continue
            if atom_count >= max_count or pharm_count >= max_count:
                # Too large for the fixed-size table: skip and report below
                # instead of crashing with an IndexError mid-run.
                num_out_of_range += 1
                continue
            num_occurrences[atom_count, pharm_count] += 1

        if num_out_of_range:
            print(f'Skipped {num_out_of_range} molecules with a count >= {max_count}')

        # np.save appends '.npy' to any file name that does not already end in
        # '.npy', so the original 'atom_pharm_counts.npz' target was written to
        # disk as 'atom_pharm_counts.npz.npy'. Spell that name out so the file
        # on disk is unchanged and the message below reports the real path.
        save_path = save_dir / data_name / 'atom_pharm_counts.npz.npy'
        np.save(save_path, num_occurrences)
        print(f'Saved to {save_path}')

