import os
import pandas
import json
from sklearn.model_selection import train_test_split
import numpy

from ProtLig_GPCRclassA.base_cross_validation import BaseCrossValidation, BaseCVPreProcess, BaseCVSplit
from ProtLig_GPCRclassA.datasets.M2OR.utils import *
from ProtLig_GPCRclassA.datasets.M2OR.preprocess import CV_ORpairs


class CV_ORconc_DiscardMix(BaseCVPreProcess):
    def __init__(self, base_working_dir, data_path):
        """Initialise the pre-processing step.

        Parameters
        ----------
        base_working_dir :
            Passed through unchanged to ``BaseCVPreProcess.__init__``.
        data_path :
            Passed through unchanged to ``BaseCVPreProcess.__init__``.
            ``read_data`` of this class indexes it by string keys.
        """
        super().__init__(base_working_dir, data_path)

        # Names of the read / processing steps.
        # NOTE(review): presumably consumed by the base class to label its
        # outputs - confirm against BaseCVPreProcess.
        self.read_data_name = 'conc_all'
        self.func_data_name = 'm2or_conc_mixDiscard'

    def read_data(self, data_path):
        if isinstance(data_path['experiments'], pandas.DataFrame):
            raise NotImplementedError('experiments being pandas.DataFrame is not implemented.')
            pairs = data_path['pairs']
            data_path['pairs'] = pairs.__class__.__name__
            print(pairs)
        experiments = pandas.read_csv(data_path['experiments'], sep=';', index_col = 0, dtype = {'parameter' : str, 
                                                'value' : str, 'unit' : str, 'value_screen' : str, 
                                                'unit_screen' : str, 'responsive' : int, 'nbr_measurements' : float, 
                                                'species_id' : int, 'main_receptors_id' : int, 'receptors_id' : int, 
                                                'main_compounds_id' : int, 'compounds_id' : int, 'assays_id' : int,
                                                'references_id' : int, 'pairs_id' : object, 'blast_id' : int})
        main_receptors = pandas.read_csv(data_path['main_receptors'], sep=';', index_col = 0)
        main_compounds = pandas.read_csv(data_path['main_compounds'], sep=';', index_col = 0)

        # Load auxillary:
        if 'map_inchikey_to_isomericSMILES' in data_path.keys() and data_path['map_inchikey_to_isomericSMILES'] is not None:
            with open(data_path['map_inchikey_to_isomericSMILES'], 'r') as jsonfile:
                map_inchikey_to_isomericSMILES = json.load(jsonfile)
        else:
            map_inchikey_to_isomericSMILES = {}

        if 'map_inchikey_to_canonicalSMILES' in data_path.keys() and data_path['map_inchikey_to_canonicalSMILES'] is not None:
            with open(data_path['map_inchikey_to_canonicalSMILES'], 'r') as jsonfile:
                map_inchikey_to_canonicalSMILES = json.load(jsonfile)
        else:
            map_inchikey_to_canonicalSMILES = {}

        raw_data = {'experiments' : experiments[['parameter', 'value', 'unit', 'value_screen', 'unit_screen',
                                            'main_receptors_id', 'main_compounds_id',
                                            'nbr_measurements', 'responsive']],
                    'main_receptors' : main_receptors,
                    'main_compounds' : main_compounds,
                    'map_inchikey_to_isomericSMILES' : map_inchikey_to_isomericSMILES,
                    'map_inchikey_to_canonicalSMILES' : map_inchikey_to_canonicalSMILES}

        return raw_data

    @staticmethod
    def _cast_if_possible(x, dtype = float):
        try:
            return dtype(x)
        except:
            print('Following value can not be cast to {} : {} '.format(dtype, x))
            return x 

    @staticmethod
    def _perform_change_units(x, _from, _to):
        if x <= 0.0 and 'Log' in _to and 'Log' not in _from:
            raise ValueError('Value lower than 0.0 can not be changed to log scale.')
        
        if _from == 'uM' and _to == 'log(M)':
            return numpy.log10(x) - 6.0 # uM = 10^-6 M
        elif _from == 'uM' and _to == 'log(mM)':
            return numpy.log10(x) - 3.0 # uM = 10^-3 mM
        elif _from == 'log(M)' and _to == 'log(mM)':
            return x + 3.0 # M = 10^3 mM
        elif _from == 'log(M)' and _to == 'uM':
            return 10.0**(x + 6.0)
        elif _from == 'mM' and _to == 'uM':
            return x*(10.0**3)
        else:
            raise NotImplementedError('Change from {} to {} is not available.'.format(_from, _to))
    
    def _change_units(self, row, value_col, unit_col, _from, _to):
        x = row[value_col]
        if row[unit_col] == _from and x == x:
            try:
                x = float(x)
            except:
                if isinstance(x, str):
                    if (x[0] == '>' or x[0] == '<'):
                        sign = x[0]
                        y = float(x[1:])
                        y = self._perform_change_units(y, _from, _to)
                        return sign + str(y), _to
                    raise NotImplementedError('No treatment of string {} implemented'.format(x))
                raise NotImplementedError('No treatment of {} implemented'.format(x))
                # return row[value_col], row[unit_col] # Don't change anything
            x = self._perform_change_units(x, _from, _to)
            return x, _to
        else:
            return row[value_col], row[unit_col] # Don't change anything    

    @staticmethod
    def _check_ec50_consistency(data_ec50):
        pairs_ec50 = data_ec50.groupby(['mol_id', 'seq_id']).apply(lambda x: x['responsive'].mean()) # TODO: apply more complex function returning {'Responsive', 'sample_weight'}
        _ec50_inconsistent_idx = pairs_ec50[((pairs_ec50 > 0.0)&(pairs_ec50 < 1.0))].index
        print('INFO: To discard EC50 inconsistent: {}'.format(len(_ec50_inconsistent_idx)))
        return _ec50_inconsistent_idx

    @staticmethod
    def _ec50_get_pairs(x):
        def _ec50_value_func(df):
            vals = df['value']
            nds = (vals == 'n.d')
            # Treat n.d:
            if nds.all():
                return 'n.d'
            elif nds.any():
                print(df[['parameter', 'value', 'unit', 'value_screen', 'unit_screen', 'responsive', 'nbr_measurements']])
                raise ValueError('Both n.d and values in Value.')
            # Treat < signed values:
            if vals.astype(str).str.contains('<').any():
                raise ValueError('Sign < encountered for EC50.')
            # Treat > signed values:
            signed = vals.astype(str).str.contains('>')
            if signed.all():
                _val = vals.apply(lambda x: float(x[1:])).median() # Take median of "greater then" values.
                if sum(signed) > 1:
                    print('WARNING: More than one sign > in Value (see below):')
                    print(df[['parameter', 'value', 'unit', 'value_screen', 'unit_screen', 'responsive', 'nbr_measurements']])
                return '>' + str(_val)
            elif signed.any():
                vals = vals[~signed]
                print('Both value and sign > in Value. Discarding signed values (see below):')
                print(df[['parameter', 'value', 'unit', 'value_screen', 'unit_screen', 'responsive', 'nbr_measurements']])
                return vals.median()
            return vals.median()

        def _ec50_unit_func(df):
            units = df['unit']
            first_unit = units.iloc[0]
            if (units == first_unit).all():
                return units.iloc[0]
            else:
                raise ValueError('Inconsistent units: {}'.format(units))

        row = {}
        row['parameter'] = x['parameter'].iloc[0]
        row['value'] = _ec50_value_func(x)
        row['unit'] = _ec50_unit_func(x)
        row['value_screen'] = x['value_screen'].iloc[0] if x['value_screen'].isna().all() else x['value_screen'].median()
        row['unit_screen'] = x['unit_screen'].iloc[0]
        row['nbr_measurements'] = x['nbr_measurements'].mean()
        row['responsive'] = x['responsive'].mean()

        return pandas.Series(row)

    def _func_data_ec50(self, data_ec50):
        print('Preparing EC50 pairs...')
        _ec50_inconsistent_idx = self._check_ec50_consistency(data_ec50)

        # Drop inconsistent:
        _data = data_ec50.set_index(['mol_id', 'seq_id'], drop = True)
        _data = _data.loc[_data.index.difference(_ec50_inconsistent_idx)]
        _data = _data.reset_index(drop = False)

        # Fix EC50 value and unit for non-responsive pairs:
        _data['value'] = _data.apply(lambda x: 'n.d' if x['responsive'] == 0 and x['parameter'] == 'ec50' else x['value'], axis = 1)
        _data['unit'] = _data.apply(lambda x: 'log(mM)' if x['responsive'] == 0 and x['parameter'] == 'ec50' else x['unit'], axis = 1)

        # Drop responsive EC50 pairs with value n.d:
        _condition = (_data['responsive'] == 1)&(_data['parameter'] == 'ec50')&(_data['value'] == 'n.d')
        _data = _data[~_condition]

        # Aggregate to pairs:
        pairs_ec50 = _data.groupby(['mol_id', 'seq_id']).apply(self._ec50_get_pairs)
        pairs_ec50 = pairs_ec50[['parameter','value','unit','value_screen','unit_screen','responsive','nbr_measurements']]
        return pairs_ec50
    
    @staticmethod
    def _check_consistency_ordering(data_screening, exclude_idx = None):
        def _is_sorted(s):
            return all(s.iloc[i] <= s.iloc[i+1] for i in range(len(s) - 1))

        # Responsive in low concentration, non-responsive in high:    
        _screening_consistency_ordering = data_screening.groupby(['mol_id', 'seq_id']).apply(lambda x: _is_sorted(x.sort_values(['value_screen', 'responsive'])['responsive']))
        if exclude_idx is not None:
            _screening_consistency_ordering = _screening_consistency_ordering.loc[_screening_consistency_ordering.index.difference(exclude_idx)] # NOTE: Exclude indices after each groupby
        _screening_consistency_ordering.name = 'Check'
        _screening_inconsistent_ordering_idx = _screening_consistency_ordering[~_screening_consistency_ordering].index
        _screening_consistent_ordering_idx = _screening_consistency_ordering[_screening_consistency_ordering].index
        print('INFO: To discard because of screening ordering (Inconsistent through Value_Screen): {}'.format(len(_screening_inconsistent_ordering_idx)))
        return _screening_inconsistent_ordering_idx

    @staticmethod
    def _check_consistency_per_value(data_screening, exclude_idx = None):
        # Same value screen, different responsivness:
        _screening_consistency_per_value = data_screening.groupby(['mol_id', 'seq_id', 'value_screen']).apply(lambda x: ((x['responsive']==1).all() or (x['responsive']==0).all()))
        _screening_consistency_per_value.name = 'Check'
        _screening_consistency_per_value = _screening_consistency_per_value.reset_index('value_screen')
        if exclude_idx is not None:
            _screening_consistency_per_value = _screening_consistency_per_value.loc[_screening_consistency_per_value.index.difference(exclude_idx)] # NOTE: Exclude indices after each groupby
        _screening_inconsitent_per_value_idx = _screening_consistency_per_value[~_screening_consistency_per_value['Check']].index
        _screening_inconsitent_per_value_idx = _screening_inconsitent_per_value_idx.drop_duplicates()
        _screening_consitent_per_value_idx = _screening_consistency_per_value[_screening_consistency_per_value['Check']].index
        _screening_consitent_per_value_idx = _screening_consitent_per_value_idx.drop_duplicates()
        print('INFO: To discard screening inconsistent per Value_Screen: {}'.format(len(_screening_inconsitent_per_value_idx)))
        return _screening_inconsitent_per_value_idx

    @staticmethod
    def _screening_get_pairs(x):
        def _screening_unit_func(df):
            units = df['unit_screen']
            first_unit = units.iloc[0]
            if (units == first_unit).all():
                return units.iloc[0]
            else:
                raise ValueError('Inconsistent unit_screen: {}'.format(units))

        row = {}
        row['parameter'] = x['parameter'].iloc[0]
        row['value'] = x['value'].iloc[0]
        row['unit'] = x['unit'].iloc[0]
        row['unit_screen'] = _screening_unit_func(x)
        row['nbr_measurements'] = x['nbr_measurements'].mean()
        row['responsive'] = x['responsive'].mean()
        if row['responsive'] > 0.0 and row['responsive'] < 1.0:
            print(x[['parameter', 'value', 'unit', 'value_screen', 'unit_screen', 'responsive', 'nbr_measurements']])
            raise Exception('Resposive is incosistent for above example')
        return pandas.Series(row)

    def _func_data_screening(self, data_screening):
        print('Preparing Screening pairs...')
        _screening_inconsistent_ordering_idx = self._check_consistency_ordering(data_screening)
        _screening_inconsitent_per_value_idx = self._check_consistency_per_value(data_screening)
        _screening_inconsitent_idx = _screening_inconsitent_per_value_idx.union(_screening_inconsistent_ordering_idx)
        print('INFO: To discard screening inconsistent: {}'.format(len(_screening_inconsitent_idx)))

        # Drop inconsistent:
        _data = data_screening.set_index(['mol_id', 'seq_id'], drop = True)
        _data = _data.loc[_data.index.difference(_screening_inconsitent_idx)]
        _data = _data.reset_index(drop = False)

        # Discard Value and Unit information:
        _data['Value'] = float('nan')
        _data['Unit'] = float('nan')

        pairs_screening = _data.groupby(['mol_id', 'seq_id', 'value_screen']).apply(self._screening_get_pairs)
        pairs_screening = pairs_screening.reset_index(level = 'value_screen')
        pairs_screening = pairs_screening[['parameter','value','unit','value_screen','unit_screen','responsive','nbr_measurements']]

        return pairs_screening


    def _func_data(self, raw_data):
        experiments = raw_data['experiments'].copy()
        mols = raw_data['main_compounds'].copy()
        seqs = raw_data['main_receptors'].copy()
        
        # Create sequence ID:
        seqs['seq_id'] = 's_' + seqs.index.astype(str)

        # Create molecules ID:
        mols['mol_id'] = 'm_' + mols.index.astype(str)

        # Add mol_id and seq_id to experiments:
        experiments['seq_id'] = 's_' + experiments['main_receptors_id'].astype(str)
        experiments['mol_id'] = 'm_' + experiments['main_compounds_id'].astype(str)
        
        # NOTE: Discard mixtures:
        _n_w_mix = len(experiments)
        mixture_mol_ids = mols[mols['mixture'] == 'mixture']['mol_id']
        mols = mols[~mols['mol_id'].isin(mixture_mol_ids)]
        experiments = experiments[~experiments['mol_id'].isin(mixture_mol_ids)]
        print('INFO: To discard mixtures (non-unique): {}'.format(_n_w_mix - len(experiments)))

        # Cast to correct dtype if possible:
        experiments['value_screen'] = experiments['value_screen'].apply(self._cast_if_possible)
        experiments['value'] = experiments['value'].apply(self._cast_if_possible)

        # Delete basal activity rows:
        if (experiments['value_screen'] == 0.0).any():
            _n = len(experiments)
            experiments = experiments[~(experiments['value_screen'] == 0.0)]
            print('INFO: Basal activity in the data. Deleting rows with concentration 0.0 with count: {}'.format(_n - len(experiments)))

        # Delete rows with units in "Fold":
        if (experiments['unit'] == 'Fold').any():
            _n = len(experiments)
            experiments = experiments[~(experiments['unit'] == 'Fold')]
            print('INFO: Unit "Fold" in the data. Deleting corresponding rows with count: {}'.format(_n - len(experiments)))

        # Delete rows with units in "uA":
        if (experiments['unit'] == 'uA').any():
            _n = len(experiments)
            experiments = experiments[~(experiments['unit'] == 'uA')]
            print('INFO: Unit "uA" in the data. Deleting corresponding rows with count: {}'.format(_n - len(experiments)))

        # Change units:
        experiments[['value', 'unit']] = experiments.apply(lambda row: self._change_units(row, 'value', 'unit', 'mM', 'uM'), axis = 1, result_type = 'expand')
        experiments[['value', 'unit']] = experiments.apply(lambda row: self._change_units(row, 'value', 'unit', 'uM', 'log(mM)'), axis = 1, result_type = 'expand')
        experiments[['value', 'unit']] = experiments.apply(lambda row: self._change_units(row, 'value', 'unit', 'log(M)', 'log(mM)'), axis = 1, result_type = 'expand')
        experiments[['value_screen', 'unit_screen']] = experiments.apply(lambda row: self._change_units(row, 'value_screen', 'unit_screen', 'mM', 'uM'), axis = 1, result_type = 'expand')
        experiments[['value_screen', 'unit_screen']] = experiments.apply(lambda row: self._change_units(row, 'value_screen', 'unit_screen', 'uM', 'log(mM)'), axis = 1, result_type = 'expand')
        experiments[['value_screen', 'unit_screen']] = experiments.apply(lambda row: self._change_units(row, 'value_screen', 'unit_screen', 'log(M)', 'log(mM)'), axis = 1, result_type = 'expand')

        # Check acceptable units and dash in value/value_screen:
        acceptable_units = set(['log(mM)', 'n.d'])
        n_acceptable_units = len(acceptable_units)
        units = set(experiments.dropna(subset = ['value','unit'])['unit'].unique())
        units_screen = set(experiments.dropna(subset = ['value_screen', 'unit_screen'])['unit_screen'].unique())
        if len(units.union(acceptable_units)) > n_acceptable_units or len(units_screen.union(acceptable_units)) > n_acceptable_units:
            raise ValueError('Unacceptable units found. Acceptable units: {}; Elements in the unit column: {}; Elements in the unit_screen column: {}'.format(acceptable_units, units, units_screen))

        _tmp = experiments.dropna(subset = 'value').copy()
        if ((_tmp['parameter'] == 'ec50')&(_tmp['value'].astype(str).str.contains('[0-9]-'))).any():
            raise ValueError('value contains dash ("-") which is not treated properly.')

        _tmp = experiments.dropna(subset = 'value_screen').copy()
        if ((_tmp['parameter'] != 'ec50')&(_tmp['value_screen'].astype(str).str.contains('[0-9]-'))).any():
            raise ValueError('value_screen contains dash ("-") which is not treated properly.')

        print('\n\n------------\nWARNING: Were these records fixed??')
        print(experiments[experiments.apply(lambda x: x['value'] != 'n.d' and x['responsive'] == 0 and x['parameter'] == 'ec50', axis = 1)])
        print('-------------------------------------------------\n')

        # Save sequences and molecules:
        # seqs.to_csv(os.path.join(self.working_dir, 'seqs.csv'), sep = ';', index = False)
        # mols.to_csv(os.path.join(self.working_dir, 'mols.csv'), sep = ';', index = False)

        # process EC50:
        experiments_ec50 = experiments[experiments['parameter'] == 'ec50']
        pairs_ec50 = self._func_data_ec50(experiments_ec50)

        # process Screening:
        experiments_screening = experiments[experiments['parameter'] != 'ec50']

        if not experiments_screening.empty:
            pairs_screeing = self._func_data_screening(experiments_screening)
            # NOTE: Delete EC50 pairs from screening:
            pairs_screeing = pairs_screeing.loc[pairs_screeing.index.difference(pairs_ec50.index)] 
            pairs = pandas.concat([pairs_ec50, pairs_screeing])
        else:
            pairs = pairs_ec50
        print('Label distribution: Positive: {}  Negative: {}'.format(len(pairs[pairs['responsive'] == 1]), len(pairs[pairs['responsive'] == 0])))
        
        pairs['responsive'] = pairs['responsive'].astype(int)
    
        return pairs, seqs, mols

    def func_data(self, raw_data):
        """
        """
        pairs, seqs, mols = self._func_data(raw_data)
    
        # Save sequences and molecules:
        seqs.to_csv(os.path.join(self.working_dir, 'seqs_raw.csv'), sep = ';', index = True) # NOTE: This is to have unchanged seqs file.
        os.mkdir(os.path.join(self.working_dir, 'seqs'))
        seqs.to_csv(os.path.join(self.working_dir, 'seqs', 'seqs.csv'), sep = ';', index = True)

        mols.to_csv(os.path.join(self.working_dir, 'mols_raw.csv'), sep = ';', index = True) # NOTE: This is to have unchanged mols file.
        os.mkdir(os.path.join(self.working_dir, 'mols'))
        mols.to_csv(os.path.join(self.working_dir, 'mols', 'mols.csv'), sep = ';', index = True)

        return pairs