import os
import pandas
import json
import numpy


class ScreeningConfidence:
    """
    Estimate how well screening outcomes agree with EC50 outcomes.

    For (compound, receptor) pairs measured both by EC50 and by screening,
    `main` computes the fraction of EC50-responsive / EC50-non-responsive
    pairs conditioned on the screening outcome, separately for primary and
    secondary screening, and stores the result as JSON.
    """

    def __init__(self, base_working_dir, data_path):
        """
        Parameters:
        -----------
        base_working_dir : str
            working directory; the JSON file with the resulting
            probabilities is written here.

        data_path : mapping (e.g. dict)
            paths to the raw ';'-separated CSV files, indexed by the keys
            'experiments', 'references' and 'main_receptors'
            (see `read_data`).
        """
        self.base_working_dir = base_working_dir
        self.data_path = data_path

    def read_data(self, data_path):
        """
        Load the raw tables.

        Parameters:
        -----------
        data_path : mapping (e.g. dict)
            must provide the keys 'experiments', 'references' and
            'main_receptors'; each value is passed to `pandas.read_csv`
            with ``sep=';'`` and the first column used as index.

        Returns:
        --------
        raw_data : dict
            the three tables under the same keys. 'main_receptors' is
            restricted to the columns 'mutation', 'uniprot_id', 'sequence'
            and 'mutated_sequence'.
        """
        experiments = pandas.read_csv(data_path['experiments'], sep=';', index_col = 0)
        references = pandas.read_csv(data_path['references'], sep=';', index_col = 0)
        main_receptors = pandas.read_csv(data_path['main_receptors'], sep=';', index_col = 0)

        raw_data = {'experiments' : experiments,
                    'references' : references,
                    'main_receptors' : main_receptors[['mutation', 'uniprot_id', 'sequence', 'mutated_sequence']]}

        return raw_data

    @staticmethod
    def _create_seq_id(data):
        """
        Assign an identifier 's_<n>' to every distinct 'mutated_Sequence'.

        Parameters:
        -----------
        data : pandas.DataFrame
            must contain the columns 'Mutation', 'Uniprot ID', '_Sequence'
            and 'mutated_Sequence'.

        Returns:
        --------
        data : pandas.DataFrame
            input left-merged with a new 'seq_id' column.

        seqs : pandas.DataFrame
            one row per distinct 'mutated_Sequence' (first occurrence kept)
            with its 'seq_id'.
        """
        seqs = data[['Mutation', 'Uniprot ID', '_Sequence', 'mutated_Sequence']].copy()
        seqs = seqs[~seqs['mutated_Sequence'].duplicated()]
        seqs.reset_index(drop = True, inplace = True)
        # Name the index first so that it becomes the 'seq_id' column below.
        seqs.index.name = 'seq_id'
        seqs.index = 's_' + seqs.index.astype(str)
        seqs.reset_index(drop = False, inplace = True)

        data = pandas.merge(data, seqs[['seq_id', 'mutated_Sequence']], on = 'mutated_Sequence', how = 'left')
        return data, seqs

    @staticmethod
    def _create_mol_id(data):
        """
        Assign an identifier 'm_<n>' to every distinct '_MolID'.

        Parameters:
        -----------
        data : pandas.DataFrame
            must contain the column '_MolID'.

        Returns:
        --------
        data : pandas.DataFrame
            input left-merged with a new 'mol_id' column.

        mols : pandas.DataFrame
            one row per distinct '_MolID' (first occurrence kept) with its
            'mol_id'.
        """
        mols = data[['_MolID']].copy()
        mols = mols[~mols['_MolID'].duplicated()]
        mols.reset_index(drop = True, inplace = True)
        # Name the index first so that it becomes the 'mol_id' column below.
        mols.index.name = 'mol_id'
        mols.index = 'm_' + mols.index.astype(str)
        mols.reset_index(drop = False, inplace = True)

        data = pandas.merge(data, mols[['mol_id', '_MolID']], on = '_MolID', how = 'left')
        return data, mols

    @staticmethod
    def _change_units(x, _from, _to):
        """
        Convert a concentration between units.

        Supported conversions: 'uM' -> 'Log(M)', 'Log(M)' -> 'uM' and
        'mM' -> 'uM'.

        Parameters:
        -----------
        x : object
            value to convert. Anything `float()` cannot convert is returned
            unchanged.

        _from : str
            unit of `x`.

        _to : str
            target unit.

        Returns:
        --------
        float, the untouched `x` when it is not numeric, or None when the
        (_from, _to) pair is not one of the supported conversions.
        """
        try:
            x = float(x)
        except (TypeError, ValueError, OverflowError):
            # Non-numeric entries are passed through untouched.
            return x
        if _from == 'uM' and _to == 'Log(M)':
            return numpy.log10(x) - 6.0 # uM = 10^-6 M
        elif _from == 'Log(M)' and _to == 'uM':
            return 10.0**(x + 6.0)
        elif _from == 'mM' and _to == 'uM':
            return x*(10.0**3)
        # NOTE: unsupported conversions yield None (kept for backward compatibility).
        return None

    @staticmethod
    def num_unique_value_screen(_df):
        """Return the number of distinct entries in the 'value_screen' column of `_df`."""
        return len(_df['value_screen'].unique())

    @staticmethod
    def _conditional_probabilities(paired, screening_column, screening_label):
        """
        Fractions of EC50 outcomes conditioned on a screening outcome.

        Parameters:
        -----------
        paired : pandas.DataFrame
            one row per pair with the 0/1 columns 'responsive_ec50' and
            `screening_column`.

        screening_column : str
            name of the column holding the screening outcome.

        screening_label : str
            suffix used in the result keys (e.g. 'Primary').

        Returns:
        --------
        dict
            keys '<pos|neg>EC50_if_<pos|neg><screening_label>'. A value is
            None when no pair has the conditioning screening outcome, so
            that an empty subset does not raise ZeroDivisionError.
        """
        probs = {}
        for screen_name, screen_value in (('pos', 1), ('neg', 0)):
            subset = paired[paired[screening_column] == screen_value]
            n_pairs = len(subset)
            for ec50_name, ec50_value in (('pos', 1), ('neg', 0)):
                key = '{}EC50_if_{}{}'.format(ec50_name, screen_name, screening_label)
                if n_pairs == 0:
                    probs[key] = None
                else:
                    n_match = (subset['responsive_ec50'] == ec50_value).sum()
                    probs[key] = float(n_match) / n_pairs
        return probs

    def main(self):
        """
        Compute and store the screening-vs-EC50 agreement probabilities.

        Steps:
          1. read the raw tables and join references and receptors onto the
             experiments;
          2. keep only rows whose 'doi' is the Mainland 2015 data set;
          3. convert 'value' from uM to Log(M) and 'value_screen' from mM to uM;
          4. reduce EC50 rows to one 0/1 outcome per (compound, receptor)
             pair, discarding pairs with mixed outcomes;
          5. split screening pairs into primary (a single distinct
             'value_screen') and secondary (several), discarding pairs that
             are inconsistent;
          6. for pairs present in both EC50 and screening, compute the
             conditional fractions and dump them to
             'Screening_confidence_probabilities.json' in `base_working_dir`.

        Returns:
        --------
        probs : dict
            conditional fractions; an entry is None (JSON null) when there
            is no pair with the conditioning screening outcome.

        Raises:
        -------
        ValueError
            if there is no screening data after filtering.
        """
        pair_keys = ['main_compounds_id', 'main_receptors_id']

        raw_data = self.read_data(self.data_path)

        data = raw_data['experiments'].copy()
        references = raw_data['references'].copy()
        main_receptors = raw_data['main_receptors'].copy()

        data = pandas.merge(data, references, left_on = 'references_id', right_on = 'id', how = 'left')
        data = pandas.merge(data, main_receptors, left_on = 'main_receptors_id', right_on = 'id', how = 'left')

        print('''WARNING: Taking only Mainland 2015 data because there we know how to 
                distinguish between primary and secondary sreeening''')
        # Explicit copy: the frame is modified in place below, and writing
        # into a filtered slice triggers pandas' SettingWithCopyWarning.
        data = data[data['doi'] == 'https://doi.org/10.1038/sdata.2015.2'].copy()

        # Change units:
        is_um = data['unit'] == 'uM'
        data.loc[is_um, 'value'] = data.loc[is_um, 'value'].apply(lambda x: self._change_units(x, _from = 'uM', _to = 'Log(M)'))
        data.loc[is_um, 'unit'] = 'Log(M)'
        is_mm = data['unit_screen'] == 'mM'
        data.loc[is_mm, 'value_screen'] = data.loc[is_mm, 'value_screen'].apply(lambda x: self._change_units(x, _from = 'mM', _to = 'uM'))
        data.loc[is_mm, 'unit_screen'] = 'uM'

        # process EC50:
        data_ec50 = data[data['parameter'] == 'ec50']
        pairs_ec50 = data_ec50.groupby(pair_keys).apply(lambda x: x['responsive'].mean()) # TODO: apply more complex function returning {'Responsive', 'sample_weight'}

        # Check for EC50 both responsive and non-responsive:
        # a mean strictly between 0 and 1 means the pair has mixed outcomes.
        ec50_mixed = (pairs_ec50 > 0.0)&(pairs_ec50 < 1.0)
        if ec50_mixed.any():
            _ec50_inconsitent_idx = pairs_ec50[ec50_mixed].index
            print('INFO: To discard EC50 inconsistent: {}'.format(len(_ec50_inconsitent_idx)))
            pairs_ec50 = pairs_ec50.loc[pairs_ec50.index.difference(_ec50_inconsitent_idx)]

        pairs_ec50 = pairs_ec50.astype(int)
        pairs_ec50.name = 'responsive'
        pairs_ec50 = pairs_ec50.to_frame()
        pairs_ec50['_Parameter'] = 'ec50'

        # process Screening:
        data_screening = data[data['parameter'] != 'ec50']

        # Fail before any group-wise processing of an empty frame.
        if data_screening.empty:
            raise ValueError('No screening data.')

        # Primary screening: a single tested concentration; secondary: several.
        _count_unique_concentrations = data_screening.groupby(pair_keys).apply(self.num_unique_value_screen)
        pairs_primary_idx = _count_unique_concentrations[_count_unique_concentrations == 1].index
        pairs_secondary_idx = _count_unique_concentrations[_count_unique_concentrations > 1].index

        def _is_sorted(s):
            return all(s.iloc[i] <= s.iloc[i+1] for i in range(len(s) - 1))

        # A pair is ordering-consistent when 'responsive' never decreases
        # as 'value_screen' increases.
        _screening_consistency_ordering = data_screening.groupby(pair_keys).apply(lambda x: _is_sorted(x.sort_values(['value_screen', 'responsive'])['responsive']))
        _screening_consistency_ordering.name = 'Check'
        _screening_inconsistent_ordering_idx = _screening_consistency_ordering[~_screening_consistency_ordering].index
        print('INFO: To discard because of screening ordering (Inconsistent through Value_Screen): {}'.format(len(_screening_inconsistent_ordering_idx)))

        # A pair is per-value-consistent when all rows sharing one
        # 'value_screen' agree on 'responsive'.
        _screening_consistency_per_value = data_screening.groupby(pair_keys + ['value_screen']).apply(lambda x: ((x['responsive']==1).all() or (x['responsive']==0).all()))
        _screening_consistency_per_value.name = 'Check'
        _screening_consistency_per_value = _screening_consistency_per_value.reset_index('value_screen')
        _screening_inconsitent_per_value_idx = _screening_consistency_per_value[~_screening_consistency_per_value['Check']].index
        _screening_inconsitent_per_value_idx = _screening_inconsitent_per_value_idx.drop_duplicates()
        print('INFO: To discard screening inconsistent per Value_Screen: {}'.format(len(_screening_inconsitent_per_value_idx)))

        _screening_inconsitent_idx = _screening_inconsitent_per_value_idx.union(_screening_inconsistent_ordering_idx)
        print('INFO: To discard screening inconsistent: {}'.format(len(_screening_inconsitent_idx)))

        # A pair counts as responsive if any of its screening rows is responsive.
        pairs_screening = data_screening.groupby(pair_keys).apply(lambda x: (x['responsive']==1).any()).astype(int) # TODO: apply more complex function returning {'Responsive', 'sample_weight'}
        pairs_screening.name = 'responsive'

        pairs_primary_idx = pairs_primary_idx.difference(_screening_inconsitent_idx) # TODO: Discarding inconsistent here. Do we want to do it?
        pairs_secondary_idx = pairs_secondary_idx.difference(_screening_inconsitent_idx) # TODO: Discarding inconsistent here. Do we want to do it?

        pairs_primary = pairs_screening.loc[pairs_primary_idx]
        pairs_primary = pairs_primary.to_frame()
        pairs_primary['_Parameter'] = 'primaryScreening'

        pairs_secondary = pairs_screening.loc[pairs_secondary_idx]
        pairs_secondary = pairs_secondary.to_frame()
        pairs_secondary['_Parameter'] = 'secondaryScreening'

        probs = {}

        # Inner joins keep only pairs measured by both EC50 and screening.
        primary_vs_ec50 = pairs_ec50.join(pairs_primary, how = 'inner', lsuffix = '_ec50', rsuffix = '_primary')
        probs.update(self._conditional_probabilities(primary_vs_ec50, 'responsive_primary', 'Primary'))

        secondary_vs_ec50 = pairs_ec50.join(pairs_secondary, how = 'inner', lsuffix = '_ec50', rsuffix = '_secondary')
        probs.update(self._conditional_probabilities(secondary_vs_ec50, 'responsive_secondary', 'Secondary'))

        with open(os.path.join(self.base_working_dir, 'Screening_confidence_probabilities.json'), 'w') as outfile:
            json.dump(probs, outfile)

        return probs
    
