import os
import json
import pandas
import pubchempy

def get_map_inchikey_to_canonicalSMILES(inchikey):
    """
    Look up compounds on PubChem by InChIKey and map each hit to its SMILES.

    Parameters
    ----------
    inchikey
        Identifier(s) forwarded unchanged to ``pubchempy.get_compounds`` with
        the ``'inchikey'`` namespace. The caller in this file passes a list of
        InChIKey strings.

    Returns
    -------
    dict
        ``{compound.inchikey: compound.canonical_smiles}`` for every non-None
        compound returned by the lookup. Identifiers without a hit are simply
        absent from the result.
    """
    smiles_by_inchikey = {}
    for compound in pubchempy.get_compounds(inchikey, 'inchikey'):
        if compound is None:
            continue
        key = compound.inchikey
        smiles_by_inchikey[key] = compound.canonical_smiles
    return smiles_by_inchikey




class CreateMoleculeMapping:
    """
    Build and cache identifier mappings for the molecules listed in a CSV file.

    At present a single mapping is maintained: InChIKey -> canonical SMILES.
    It is persisted as ``map_inchikey_to_canonicalSMILES.json`` inside
    ``working_dir`` so that only InChIKeys missing from the cached mapping are
    looked up again on later runs.

    Parameters
    ----------
    mols_csv
        Path (or anything ``pandas.read_csv`` accepts) of the molecules table.
    mol_id_col
        Name of the molecule-id column. It is stored on the instance but not
        used by any method defined in this class.
    working_dir
        Directory holding the cached JSON mapping. Must not be None.
    inchikey_col
        Name of the column holding InChIKeys. A cell may contain several keys
        separated by whitespace. When None, no mapping is loaded or updated.
    sep
        Field delimiter of ``mols_csv``.

    Raises
    ------
    ValueError
        If ``working_dir`` is None.
    """

    # Used both as the key in the returned maps dict and as the JSON file stem.
    _INCHIKEY_MAP_NAME = 'map_inchikey_to_canonicalSMILES'

    def __init__(self, mols_csv, mol_id_col, working_dir, inchikey_col = None, sep = ';'):
        if working_dir is None:
            raise ValueError('working_dir must be explicitly specified.')

        self.working_dir = working_dir

        self.mols_csv = mols_csv
        self.mol_id_col = mol_id_col

        self.inchikey_col = inchikey_col

        self.sep = sep

    def _inchikey_map_path(self):
        """Return the path of the cached InChIKey -> SMILES JSON file."""
        return os.path.join(self.working_dir, self._INCHIKEY_MAP_NAME + '.json')

    def load_mols(self):
        """Read ``mols_csv`` into a DataFrame (first row is the header, no index column)."""
        return pandas.read_csv(self.mols_csv, sep = self.sep, index_col = None, header = 0)

    def load_maps_dict(self):
        """
        Load the cached mappings from ``working_dir``.

        Returns
        -------
        dict
            Empty when ``inchikey_col`` is None. Otherwise it has the single
            key ``'map_inchikey_to_canonicalSMILES'`` whose value is the
            content of the cached JSON file, or an empty dict when that file
            does not exist yet.
        """
        maps_dict = {}
        if self.inchikey_col is None:
            return maps_dict

        path = self._inchikey_map_path()
        if os.path.exists(path):
            with open(path, 'r', encoding = 'utf-8') as jsonfile:
                maps_dict[self._INCHIKEY_MAP_NAME] = json.load(jsonfile)
        else:
            maps_dict[self._INCHIKEY_MAP_NAME] = {}
        return maps_dict

    @staticmethod
    def _update_map_inchikey_to_canonicalSMILES(inchikeys_series, map_inchikey_to_canonicalSMILES):
        """
        Add SMILES for every InChIKey in the series that is not yet mapped.

        Parameters
        ----------
        inchikeys_series
            pandas Series of InChIKey cells; missing values are ignored and a
            cell may hold several whitespace-separated keys.
        map_inchikey_to_canonicalSMILES
            Existing mapping. It is updated in place and also returned.

        Returns
        -------
        dict
            The same mapping object, extended with whatever the lookup returned
            for the keys that were missing.
        """
        candidates = inchikeys_series.dropna()
        if candidates.empty:
            # A column with no values is read as float NaN; the .str accessor
            # would raise on it, and there is nothing to look up anyway.
            return map_inchikey_to_canonicalSMILES

        # Split on any run of whitespace so that doubled, leading or trailing
        # spaces never yield empty-string "keys". A whitespace-only cell
        # explodes to NaN, hence the second dropna().
        # TODO: an earlier form of this line triggered a pandas FutureWarning;
        # confirm it no longer appears.
        tokens = candidates.astype(str).str.split().explode().dropna()

        # Sorted for a deterministic lookup order.
        new_keys = sorted(set(tokens).difference(map_inchikey_to_canonicalSMILES))
        if new_keys:
            print('Updating map_inchikey_to_canonicalSMILES...')
            map_inchikey_to_canonicalSMILES.update(get_map_inchikey_to_canonicalSMILES(new_keys))
        return map_inchikey_to_canonicalSMILES

    def update_maps(self):
        """
        Bring the cached mappings up to date with the molecules in ``mols_csv``.

        Loads the molecules and the cached mappings; when ``inchikey_col`` is
        set, looks up the InChIKeys that are not cached yet and rewrites the
        JSON file in ``working_dir`` (the file is written even when nothing
        new was found).

        Returns
        -------
        dict
            Same structure as returned by ``load_maps_dict``, with the updated
            mapping.
        """
        mols = self.load_mols()
        maps_dict = self.load_maps_dict()

        if self.inchikey_col is not None:
            updated_map = self._update_map_inchikey_to_canonicalSMILES(
                mols[self.inchikey_col], maps_dict[self._INCHIKEY_MAP_NAME]
            )
            with open(self._inchikey_map_path(), 'w', encoding = 'utf-8') as jsonfile:
                json.dump(updated_map, jsonfile)
            maps_dict[self._INCHIKEY_MAP_NAME] = updated_map

        return maps_dict