import os
import pandas
import json
from shutil import copy
import numpy

from matplotlib import pyplot as plt

from mol2graph.utils import get_num_atoms_and_bonds

from ProtLig_GPCRclassA.base_cross_validation import BaseCVPostProcess
from ProtLig_GPCRclassA.datasets.ORligand.utils import *

# ---------------
# Mixtures graph:
# ---------------
class CVPP_Mixture_ConcatGraph(BaseCVPostProcess):
    """
    Post-processing step that adds a 'SMILES_concatGraph' column to the molecules table.

    For every row of the molecules csv a single SMILES string is built according to
    the value of the 'mixture' column ('mono', 'sum of isomers' or 'mixture'); when
    several isomers/components are involved their SMILES are joined with '.'.

    Notes:
    ------
    self.data_dir, self.working_dir, self.sep and self.save_hparams are not defined in
    this class; they are presumably provided by BaseCVPostProcess -- TODO confirm.
    """
    def __init__(self, data_dir, mols_csv = None, mol_id_col = 'mol_id', auxiliary_data_path = None):
        """
        Parameters:
        -----------
        data_dir : directory handed over to BaseCVPostProcess.
        mols_csv : path to the molecules csv. Defaults to <self.data_dir>/mols.csv.
        mol_id_col : name of the csv column used as index of the molecules table.
        auxiliary_data_path : dict mapping an auxiliary-map name to the path of its
            JSON file (or to None for an empty map). None means no auxiliary files.
        """
        name = None
        super(CVPP_Mixture_ConcatGraph, self).__init__(name, data_dir)

        self.prefix = 'concatGraph_'
        self.mol_id_col = mol_id_col

        if mols_csv is None:
            mols_csv = os.path.join(self.data_dir, 'mols.csv')

        # The output csv is written under the same file name into self.working_dir.
        _, self.mols_csv_name = os.path.split(mols_csv)
        self.mols_csv = mols_csv

        self.auxiliary_data_path = auxiliary_data_path

    def serialize_hparams(self):
        """
        returns dictionary with all hyperparameters that will be saved. self.working_dir will be added
        to the dict in self.save_hparams.
        """
        return {'data_dir' : self.data_dir,
                'auxiliary_data_path' : self.auxiliary_data_path}

    def load_mols(self):
        """
        Read the molecules csv into a DataFrame indexed by self.mol_id_col.
        """
        return pandas.read_csv(self.mols_csv, sep = self.sep, index_col = self.mol_id_col, header = 0)

    def load_auxiliary(self):
        """
        Load every auxiliary JSON file listed in self.auxiliary_data_path.

        Returns:
        --------
        dict with one entry per key of self.auxiliary_data_path. A key whose path is
        None is mapped to an empty dict. If self.auxiliary_data_path itself is None
        (the constructor default) an empty dict is returned.
        """
        paths = self.auxiliary_data_path or {}
        auxiliary = {}
        for key, path in paths.items():
            if path is None:
                auxiliary[key] = {}
                continue
            with open(path, 'r') as jsonfile:
                auxiliary[key] = json.load(jsonfile)
        return auxiliary

    @staticmethod
    def _find_new_inchikeys(inchikeys_series, known_map):
        """
        Return the sorted list of InChI keys present in inchikeys_series but missing in known_map.

        Each entry of inchikeys_series may hold several keys separated by a single
        space; missing entries are ignored.
        """
        candidates = inchikeys_series.dropna()
        if candidates.empty:
            # Nothing to look up. This also avoids calling the .str accessor on an
            # all-NaN column, which pandas rejects for non-string dtypes.
            return []
        candidates = candidates.str.split(' ').explode() # TODO: pandas FutureWarning for this row.
        candidate_idx = pandas.Index(candidates.unique())
        known_idx = pandas.Index(list(known_map.keys()), name = 'inchi_key')
        return candidate_idx.difference(known_idx).tolist()

    @staticmethod
    def _update_map_inchikey_to_isomericSMILES(inchikeys_series, map_inchikey_to_isomericSMILES):
        """
        Extend map_inchikey_to_isomericSMILES (in place) with the keys of inchikeys_series it lacks.

        Missing keys are handed to get_map_inchikey_to_isomericSMILES and the mapping
        it returns is merged into the given dict, which is also returned.
        """
        new_keys = CVPP_Mixture_ConcatGraph._find_new_inchikeys(inchikeys_series, map_inchikey_to_isomericSMILES)
        if new_keys:
            print('Updating map_inchikey_to_isomericSMILES...')
            map_inchikey_to_isomericSMILES.update(get_map_inchikey_to_isomericSMILES(new_keys))
        return map_inchikey_to_isomericSMILES

    @staticmethod
    def _update_map_inchikey_to_canonicalSMILES(inchikeys_series, map_inchikey_to_canonicalSMILES):
        """
        Extend map_inchikey_to_canonicalSMILES (in place) with the keys of inchikeys_series it lacks.

        Missing keys are handed to get_map_inchikey_to_canonicalSMILES and the mapping
        it returns is merged into the given dict, which is also returned.
        """
        new_keys = CVPP_Mixture_ConcatGraph._find_new_inchikeys(inchikeys_series, map_inchikey_to_canonicalSMILES)
        if new_keys:
            print('Updating map_inchikey_to_canonicalSMILES...')
            map_inchikey_to_canonicalSMILES.update(get_map_inchikey_to_canonicalSMILES(new_keys))
        return map_inchikey_to_canonicalSMILES

    def update_auxiliary(self, auxiliary, data):
        """
        Complete both InChI-key -> SMILES maps for data['inchi_key'] and dump them to JSON.

        Each map is written to <self.working_dir>/<map name>.json. A map that is absent
        from auxiliary is treated as empty. The updated auxiliary dict is returned.
        """
        updaters = (
            ('map_inchikey_to_isomericSMILES', self._update_map_inchikey_to_isomericSMILES),
            ('map_inchikey_to_canonicalSMILES', self._update_map_inchikey_to_canonicalSMILES),
        )
        for key, updater in updaters:
            updated = updater(data['inchi_key'], auxiliary.get(key, {}))
            with open(os.path.join(self.working_dir, key + '.json'), 'w') as jsonfile:
                json.dump(updated, jsonfile)
            auxiliary[key] = updated

        return auxiliary

    @staticmethod
    def _get_SMILES_mono(row, _map_to_isomeric):
        """
        Return the isomeric SMILES of a 'mono' row.

        Notes:
        ------
        With an InChI key the SMILES is looked up in _map_to_isomeric. Without one,
        enumerate_isomers is applied to row['smiles'] and must yield exactly one isomer.

        Raises:
        -------
        ValueError if row['smiles'] does not unroll to exactly one isomer.
        """
        if pandas.notna(row['inchi_key']):
            return _map_to_isomeric[row['inchi_key']]

        isomers = enumerate_isomers(row['smiles'])
        if len(isomers) > 1:
            raise ValueError('{} is not mono'.format(row['smiles']))
        if len(isomers) < 1:
            raise ValueError('No isomer found for {}'.format(row['smiles']))
        return isomers[0]

    @staticmethod
    def _get_SMILES_sum_of_isomers(row, _map_to_canonical):
        """
        Unroll a 'sum of isomers' row to its set of isomers separated by '.'.
        Using enumerate_isomers to go from canonicalSMILES to list of isomericSMILES.

        Notes:
        ------
        I. Ensemble approach

        Raises:
        -------
        ValueError if enumerate_isomers returns no isomer at all.
        """
        # TODO: How to process unrolled isomers?
        if pandas.notna(row['inchi_key']):
            source = row['inchi_key']
            isomers = enumerate_isomers(_map_to_canonical[source])
        else:
            source = row['smiles']
            isomers = enumerate_isomers(source)
        # NOTE(review): the check only rejects an empty result; the previous message
        # talked about a single isomer, so the intended condition may have been
        # `<= 1` -- confirm before tightening it.
        if len(isomers) < 1:
            raise ValueError('No isomer found for sum of isomers {}'.format(source))
        return '.'.join(isomers)

    @staticmethod
    def _get_SMILES_mixture(row, _map_to_canonical):
        """
        Unroll every component of a 'mixture' row to its isomers, all separated by '.'.
        Using enumerate_isomers to go from canonicalSMILES to list of isomericSMILES.

        Notes:
        ------
        I. Ensemble approach

        Components are separated by a single space, either in row['inchi_key'] or,
        when no InChI key is available, in row['smiles'].

        WARNING: This is neglecting concentration differences if they mix racemic + others in the same concentration.
        """
        # TODO: How to process unrolled isomers?
        if pandas.notna(row['inchi_key']):
            components = [_map_to_canonical[key] for key in row['inchi_key'].split(' ')]
        else:
            components = row['smiles'].split(' ')

        isomers = []
        for component in components:
            isomers += enumerate_isomers(component)
        return '.'.join(isomers)

    def get_SMILES(self, mols, auxiliary):
        """
        Build the 'SMILES_concatGraph' column for mols.

        Rows are handled according to mols['mixture']: 'mono', 'sum of isomers' and
        'mixture'. Rows with any other value are not part of the result.

        Returns:
        --------
        pandas.Series named 'SMILES_concatGraph', indexed like the handled rows. It is
        empty when no row belongs to a handled category.
        """
        converters = (
            ('mono',
             lambda row: self._get_SMILES_mono(row, _map_to_isomeric = auxiliary['map_inchikey_to_isomericSMILES'])),
            ('sum of isomers',
             lambda row: self._get_SMILES_sum_of_isomers(row, _map_to_canonical = auxiliary['map_inchikey_to_canonicalSMILES'])),
            ('mixture',
             lambda row: self._get_SMILES_mixture(row, _map_to_canonical = auxiliary['map_inchikey_to_canonicalSMILES'])),
        )

        parts = []
        for category, convert in converters:
            subset = mols[mols['mixture'] == category].copy()
            if not subset.empty:
                subset['_SMILES'] = subset.apply(convert, axis = 1)
            parts.append(subset)

        combined = pandas.concat(parts)
        if '_SMILES' not in combined.columns:
            # No row matched any category, so no subset ever received the column.
            return pandas.Series(index = mols.index[:0], dtype = object, name = 'SMILES_concatGraph')

        return combined['_SMILES'].rename('SMILES_concatGraph')

    def postprocess(self):
        """
        Load the molecules, complete the auxiliary maps, attach the SMILES column and
        write the resulting table to <self.working_dir>/<self.mols_csv_name>.
        """
        mols = self.load_mols()
        auxiliary = self.load_auxiliary()
        auxiliary = self.update_auxiliary(auxiliary, mols)

        smiles_col = self.get_SMILES(mols, auxiliary)

        # Left join: rows without a computed SMILES keep a missing value.
        mols = mols.join(smiles_col, how = 'left')

        # NOTE(review): the output separator is hard-coded to ';' while load_mols
        # reads with self.sep -- confirm the two are meant to agree.
        mols.to_csv(os.path.join(self.working_dir, self.mols_csv_name), sep=';')

        self.save_hparams(prefix = self.prefix)
        return None




class CVPP_Mixture_Racemic(CVPP_Mixture_ConcatGraph):
    """
    ('Discard chirality')
    Discard mixture and treat sum of isomers as racemic (using canonical smiles).

    Produces a 'SMILES_racemic' column; rows whose 'mixture' value is neither 'mono'
    nor 'sum of isomers' are not part of it.
    """
    def __init__(self, data_dir, mols_csv = None, mol_id_col = 'mol_id', auxiliary_data_path = None):
        """
        Same parameters as CVPP_Mixture_ConcatGraph; only the hparams prefix differs.
        """
        # Delegate to the direct parent instead of skipping it and duplicating its body.
        super(CVPP_Mixture_Racemic, self).__init__(data_dir,
                                                   mols_csv = mols_csv,
                                                   mol_id_col = mol_id_col,
                                                   auxiliary_data_path = auxiliary_data_path)
        self.prefix = 'racemicGraph_'

    @staticmethod
    def _get_SMILES_sum_of_isomers(row, _map_to_canonical):
        """
        set SMILES of sum of isomers to canonical SMILES.

        With an InChI key the canonical SMILES is looked up in _map_to_canonical,
        otherwise row['smiles'] is returned as is.
        """
        if pandas.notna(row['inchi_key']):
            return _map_to_canonical[row['inchi_key']]
        return row['smiles']

    def get_SMILES(self, mols, auxiliary):
        """
        Build the 'SMILES_racemic' column for mols.

        Parameters:
        -----------
        mols : molecules DataFrame. If None, it is loaded with self.load_mols().
        auxiliary : dict holding 'map_inchikey_to_isomericSMILES' and
            'map_inchikey_to_canonicalSMILES'. If None, it is loaded and completed
            with self.load_auxiliary() / self.update_auxiliary().

        Returns:
        --------
        pandas.Series named 'SMILES_racemic' covering the 'mono' and 'sum of isomers'
        rows. It is empty when no such row exists.
        """
        # Use what the caller already prepared; reloading here would repeat the
        # auxiliary update (and its JSON dumps) on every call.
        if mols is None:
            mols = self.load_mols()
        if auxiliary is None:
            auxiliary = self.update_auxiliary(self.load_auxiliary(), mols)

        mols_mono = mols[mols['mixture'] == 'mono'].copy()
        if not mols_mono.empty:
            mols_mono['_SMILES'] = mols_mono.apply(
                lambda row: self._get_SMILES_mono(row, _map_to_isomeric = auxiliary['map_inchikey_to_isomericSMILES']),
                axis = 1)

        mols_sum_of_isomers = mols[mols['mixture'] == 'sum of isomers'].copy()
        if not mols_sum_of_isomers.empty:
            mols_sum_of_isomers['_SMILES'] = mols_sum_of_isomers.apply(
                lambda row: self._get_SMILES_sum_of_isomers(row, _map_to_canonical = auxiliary['map_inchikey_to_canonicalSMILES']),
                axis = 1)

        combined = pandas.concat([mols_mono, mols_sum_of_isomers])
        if '_SMILES' not in combined.columns:
            # Neither subset had rows, so the column was never created.
            return pandas.Series(index = mols.index[:0], dtype = object, name = 'SMILES_racemic')

        return combined['_SMILES'].rename('SMILES_racemic')