import os
import pandas
import json
from sklearn.model_selection import train_test_split
import numpy

from ProtLig_GPCRclassA.base_cross_validation import BaseCVPreProcess
from ProtLig_GPCRclassA.datasets.M2OR.utils import *

class CV_ORpairs(BaseCVPreProcess):
    """Pre-processing step for receptor/compound pair data.

    ``read_data`` loads the pairs, receptors and compounds tables (``;``-separated
    CSV files) together with two optional InChIKey -> SMILES JSON maps.
    ``func_data`` derives string identifiers (``seq_id`` / ``mol_id``) for the
    tables, writes the receptor and compound tables below ``self.working_dir``
    and returns the processed pairs table.

    NOTE(review): ``self.working_dir`` is not assigned in this class; it is
    presumably provided by ``BaseCVPreProcess`` -- confirm against the base class.
    """

    # Columns of the pairs table that are kept by ``read_data``.
    _PAIRS_COLUMNS = ['main_receptors_id', 'mutated_sequence',    # _Sequence
                      'main_compounds_id', 'inchi_key', 'smiles', 'mol_id', # _mol
                      'data_quality', 'num_unique_value_screen',    # Confidence score
                      'responsive']

    def __init__(self, base_working_dir, data_path): # , seed = None, split_kwargs = {}):
        """Initialise the base class and set the names of the processing steps.

        Parameters
        ----------
        base_working_dir :
            Forwarded unchanged to ``BaseCVPreProcess.__init__``.
        data_path :
            Forwarded unchanged to ``BaseCVPreProcess.__init__``. ``read_data``
            expects a mapping with the keys ``'pairs'``, ``'main_receptors'`` and
            ``'main_compounds'``.
        """
        super(CV_ORpairs, self).__init__(base_working_dir, data_path) # , seed, split_kwargs)

        # NOTE(review): presumably consumed by the base class to label the
        # read / processing steps -- confirm against BaseCVPreProcess.
        self.read_data_name = 'pairs_all'
        self.func_data_name = 'm2or_pairs'
        # self.func_split_data_name = 'EC50_random_data'

    @staticmethod
    def _load_optional_json(data_path, key):
        """Return the JSON content stored at ``data_path[key]``, or ``{}`` if the key is absent or ``None``."""
        if key not in data_path or data_path[key] is None:
            return {}
        with open(data_path[key], 'r') as jsonfile:
            return json.load(jsonfile)

    def read_data(self, data_path):
        """Load the raw tables and auxiliary maps referenced by ``data_path``.

        Parameters
        ----------
        data_path : mapping
            Must contain ``'pairs'``, ``'main_receptors'`` and ``'main_compounds'``
            (paths to ``;``-separated CSV files whose first column is the index).
            May contain ``'map_inchikey_to_isomericSMILES'`` and
            ``'map_inchikey_to_canonicalSMILES'`` (paths to JSON files); a missing
            or ``None`` entry yields an empty dict.

        Returns
        -------
        dict
            Keys ``'pairs'`` (restricted to the columns in ``_PAIRS_COLUMNS``),
            ``'main_receptors'``, ``'main_compounds'``,
            ``'map_inchikey_to_isomericSMILES'`` and
            ``'map_inchikey_to_canonicalSMILES'``.

        Raises
        ------
        NotImplementedError
            If ``data_path['pairs']`` is already a ``pandas.DataFrame``.
        KeyError
            If a required key is missing from ``data_path`` or a required column
            is missing from the pairs table.
        """
        if isinstance(data_path['pairs'], pandas.DataFrame):
            raise NotImplementedError('pairs being pandas.DataFrame is not implemented.')

        pairs = pandas.read_csv(data_path['pairs'], sep=';', index_col=0)
        main_receptors = pandas.read_csv(data_path['main_receptors'], sep=';', index_col=0)
        main_compounds = pandas.read_csv(data_path['main_compounds'], sep=';', index_col=0)

        # Load auxiliary maps (both are optional):
        map_inchikey_to_isomericSMILES = self._load_optional_json(
            data_path, 'map_inchikey_to_isomericSMILES')
        map_inchikey_to_canonicalSMILES = self._load_optional_json(
            data_path, 'map_inchikey_to_canonicalSMILES')

        raw_data = {'pairs' : pairs[self._PAIRS_COLUMNS],
                    'main_receptors' : main_receptors,
                    'main_compounds' : main_compounds,
                    'map_inchikey_to_isomericSMILES' : map_inchikey_to_isomericSMILES,
                    'map_inchikey_to_canonicalSMILES' : map_inchikey_to_canonicalSMILES}

        return raw_data

    def _func_data(self, raw_data):
        """Derive string identifiers and an integer label on copies of the raw tables.

        Parameters
        ----------
        raw_data : dict
            Output of ``read_data``; the entries ``'pairs'``, ``'main_compounds'``
            and ``'main_receptors'`` are used. The inputs are copied, not modified.

        Returns
        -------
        tuple
            ``(pairs, seqs, mols)`` where

            * ``seqs`` has a new column ``seq_id`` = ``'s_'`` + index,
            * ``mols`` has a new column ``mol_id`` = ``'m_'`` + index,
            * ``pairs`` has its original ``mol_id`` column renamed to ``_MolID``,
              new columns ``seq_id`` (from ``main_receptors_id``) and ``mol_id``
              (from ``main_compounds_id``), and ``responsive`` cast to ``int``.
        """
        pairs = raw_data['pairs'].copy()
        mols = raw_data['main_compounds'].copy()
        seqs = raw_data['main_receptors'].copy()

        # Create sequence ID:
        seqs['seq_id'] = 's_' + seqs.index.astype(str)

        # Create molecules ID:
        mols['mol_id'] = 'm_' + mols.index.astype(str)

        # Rename pair's mol_id column to _MolID so the name is free for the
        # identifier built below:
        pairs.rename(columns = {'mol_id' : '_MolID'}, inplace = True)

        # Add mol_id and seq_id to pairs (same prefix scheme as seqs / mols above):
        pairs['seq_id'] = 's_' + pairs['main_receptors_id'].astype(str)
        pairs['mol_id'] = 'm_' + pairs['main_compounds_id'].astype(str)

        pairs['responsive'] = pairs['responsive'].astype(int)

        n_positive = int((pairs['responsive'] == 1).sum())
        n_negative = int((pairs['responsive'] == 0).sum())
        print('Label distribution: Positive: {}  Negative: {}'.format(n_positive, n_negative))

        return pairs, seqs, mols

    def func_data(self, raw_data):
        """Process ``raw_data`` and write the sequence and molecule tables to disk.

        For each of ``seqs`` and ``mols`` two identical ``;``-separated CSV files
        are written below ``self.working_dir``: ``<name>_raw.csv`` and
        ``<name>/<name>.csv``. The sub-directory is created if it does not exist
        yet, so the step can be repeated on the same working directory.

        Parameters
        ----------
        raw_data : dict
            Output of ``read_data``.

        Returns
        -------
        pandas.DataFrame
            The processed pairs table returned by ``_func_data``.
        """
        pairs, seqs, mols = self._func_data(raw_data)

        # Save sequences and molecules:
        for name, table in (('seqs', seqs), ('mols', mols)):
            # NOTE: The *_raw.csv copy is to have an unchanged seqs / mols file.
            table.to_csv(os.path.join(self.working_dir, name + '_raw.csv'), sep = ';', index = True)
            sub_dir = os.path.join(self.working_dir, name)
            # exist_ok: do not fail when the directory is left over from a previous run.
            os.makedirs(sub_dir, exist_ok = True)
            table.to_csv(os.path.join(sub_dir, name + '.csv'), sep = ';', index = True)

        return pairs