from jax import numpy as jnp
import jax
import copy
import time
import functools

from ProtLig_GPCRclassA.utils import create_line_graph, Sequence, Label, pad_graph
from ProtLig_GPCRclassA.amino_GNN.element import AminoElementPrecompute, AminoElementPrecomputeMasked

class ConcentrationSampler_legacy:
    """Legacy, non-jittable concentration and label sampler.

    Loops over the batch in Python and, for every example, resamples the
    concentration value according to its parameter type:

    * ``screening``: value and label are returned unchanged.
    * ``ec50_nd``: value drawn from a normal distribution with mean
      ``mean_ec50`` and standard deviation ``std_ec50 * ec50_std_multiplier``;
      the label is kept.
    * ``ec50``: a fair coin chooses non-responsive (label 0, value in
      ``[val - ec50_lower_extreme, val - ec50_lower_margin]``) or responsive
      (label 1, value in ``[val + ec50_upper_margin, val + ec50_upper_extreme]``).
    * ``ec50_greater_than``: label 0, value in
      ``[val - ec50_greater_than_lower_extreme, val - ec50_greater_than_lower_margin]``.

    ``mean_ec50`` / ``std_ec50`` are only required when ``ec50_nd`` examples
    are actually sampled.
    """
    def __init__(self,
                ec50_seed = None,
                ec50_std_multiplier = 2.0,
                ec50_lower_margin = 0.0,
                ec50_upper_margin = 0.0,
                ec50_lower_extreme = 3.0,
                ec50_upper_extreme = 3.0,
                ec50_greater_than_lower_margin = 0.0,
                ec50_greater_than_lower_extreme = 3.0,
                mean_ec50 = None,
                std_ec50 = None,
                conc_parameter_id_map = None):

        # Fall back to a time-based seed so independent runs differ by default.
        if ec50_seed is not None:
            self.ec50_seed = ec50_seed
        else:
            self.ec50_seed = int(time.time())
        self.prng_key = jax.random.PRNGKey(self.ec50_seed)

        # Margins are non-negative offsets and must not exceed their extremes,
        # otherwise the uniform sampling ranges below would be inverted.
        if ec50_lower_margin < 0.0 or ec50_upper_margin < 0.0 or ec50_greater_than_lower_margin < 0.0:
            raise ValueError('EC50 margins must be positive.') 
        if (ec50_lower_margin > ec50_lower_extreme
                or ec50_upper_margin > ec50_upper_extreme
                or ec50_greater_than_lower_margin > ec50_greater_than_lower_extreme):
            raise ValueError('Margins chosen above extremes.') 

        self.ec50_std_multiplier = ec50_std_multiplier
        self.ec50_lower_margin = ec50_lower_margin # NOTE: Implemented as an addition to EC50
        self.ec50_upper_margin = ec50_upper_margin # NOTE: Implemented as an addition to EC50
        self.ec50_lower_extreme = ec50_lower_extreme # NOTE: Implemented as an addition to EC50
        self.ec50_upper_extreme = ec50_upper_extreme # NOTE: Implemented as an addition to EC50
        self.ec50_greater_than_lower_margin = ec50_greater_than_lower_margin # NOTE: Implemented as an addition to EC50
        self.ec50_greater_than_lower_extreme = ec50_greater_than_lower_extreme

        self.mean_ec50 = mean_ec50
        # std_ec50 is optional: it is only needed when sampling ec50_nd examples.
        self.ec50_sigma = None if std_ec50 is None else std_ec50 * self.ec50_std_multiplier

        if conc_parameter_id_map is None:
            self.conc_parameter_id_map = {'ec50_nd' : 0,
                                'ec50_greater_than' : 1,
                                'ec50' : 2,
                                'screening' : 3}
        else:
            self.conc_parameter_id_map = conc_parameter_id_map
        self.conc_parameter_id_inverse_map = {val : key for key,val in self.conc_parameter_id_map.items()}

        # Bind all configuration so per-example calls only pass the data and keys.
        self.sample_func = functools.partial(self.sample_concentration_and_label,
                                            ec50_lower_margin = self.ec50_lower_margin,
                                            ec50_upper_margin = self.ec50_upper_margin,
                                            ec50_lower_extreme = self.ec50_lower_extreme,
                                            ec50_upper_extreme = self.ec50_upper_extreme,
                                            ec50_greater_than_lower_margin = self.ec50_greater_than_lower_margin,
                                            ec50_greater_than_lower_extreme = self.ec50_greater_than_lower_extreme,
                                            mean_ec50 = self.mean_ec50,
                                            ec50_sigma = self.ec50_sigma,
                                            conc_parameter_id_map = self.conc_parameter_id_map)

    @staticmethod
    def sample_concentration_and_label(val, param, label, ec50_prng_key, _label_key, 
                                       ec50_lower_margin, ec50_upper_margin, 
                                       ec50_lower_extreme, ec50_upper_extreme,
                                       ec50_greater_than_lower_margin, ec50_greater_than_lower_extreme,
                                       mean_ec50, ec50_sigma,
                                       conc_parameter_id_map):
        """Resample one example's concentration and label (Python control flow).

        Returns:
            ``(value, label)`` as 1-element ``jnp`` arrays.

        Raises:
            Exception: if an ``ec50`` / ``ec50_greater_than`` example has label 0.
            ValueError: if ``param`` is not in ``conc_parameter_id_map``, or an
                ``ec50_nd`` example is sampled without ``mean_ec50``/``std_ec50``.
        """
        _label = copy.deepcopy(label)
        if param == conc_parameter_id_map['screening']:
            _value = val
        elif param == conc_parameter_id_map['ec50_nd']:
            if mean_ec50 is None or ec50_sigma is None:
                raise ValueError('mean_ec50 and std_ec50 must be provided to sample n.d. EC50 values.')
            _value = mean_ec50 + ec50_sigma * jax.random.normal(ec50_prng_key, shape=()) # NOTE: ec50_sigma is std after multiping by ec50_std_multiplier
        elif param == conc_parameter_id_map['ec50']:
            if label == 0: # NOTE: Label: 0, Param: EC50 can not exist
                raise Exception('There is an example with EC50 response 0 and value non-NaN.')
            # Sample non-responsive
            if jax.random.uniform(_label_key, shape=(), minval=0.0, maxval=1.0) < 0.5: 
                minval = val - ec50_lower_extreme
                maxval = val - ec50_lower_margin
                _label = 0
            # Sample responsive
            else:
                minval = val + ec50_upper_margin
                maxval = val + ec50_upper_extreme
                _label = 1
            _value = jax.random.uniform(ec50_prng_key, shape=(), minval=minval, maxval=maxval)
        elif param == conc_parameter_id_map['ec50_greater_than']:
            if label == 0: # NOTE: Label: 0, Param: EC50_greater_than can not exist
                raise Exception('There is an example with EC50 response 0 and value non-NaN.')
            # Sample non-responsive
            minval = val - ec50_greater_than_lower_extreme
            maxval = val - ec50_greater_than_lower_margin
            _label = 0
            _value = jax.random.uniform(ec50_prng_key, shape=(), minval=minval, maxval=maxval)
        else:
            # Previously this fell through to an UnboundLocalError on `_value`.
            raise ValueError(f'Unknown concentration parameter id: {param}')
        return jnp.array(_value, ndmin = 1), jnp.array(_label, ndmin = 1)


    def __call__(self, conc_dict, label_dict):
        """Resample every example of a batch, one at a time.

        Reads ``conc_dict['value']``, ``conc_dict['parameter']`` and
        ``label_dict['_main_label']``. Writes the new labels back into
        ``label_dict['_main_label']`` (in place) and returns
        ``(new_values, label_dict)``. Advances ``self.prng_key`` once per example.
        """
        new_vals = []
        new_labels = []
        for i in range(len(label_dict['_main_label'])):
            # Fresh keys per example; the first split output carries the state forward.
            new_prng_key, ec50_prng_key, _label_key = jax.random.split(self.prng_key, 3)
            self.prng_key = new_prng_key

            val = conc_dict['value'][i]
            param = conc_dict['parameter'][i]
            label = label_dict['_main_label'][i]
        
            new_val, new_label = self.sample_func(val, param, label, ec50_prng_key, _label_key)
            new_vals.append(new_val)
            new_labels.append(new_label)

        label_dict['_main_label'] = jnp.concatenate(new_labels)
        return jnp.concatenate(new_vals), label_dict
    
# -----------------
# JITtable version:
# -----------------
class ConcentrationSampler:
    """Vectorised (``jax.vmap``-based) concentration and label sampler.

    Each call resamples, per example, the concentration ``value`` according to
    its ``parameter`` type and rewrites ``label_dict['_main_label']``:

    * ``screening``: value drawn from
      ``[val + screening_upper_margin, val + screening_upper_extreme]`` if the
      label is non-zero, else from
      ``[val - screening_lower_extreme, val - screening_lower_margin]``;
      the label is kept. Optionally, non-responsive screening values are first
      shifted up by ``screening_lower_perturbation_shift`` with probability
      ``screening_lower_perturbation_prob``.
    * ``ec50_nd``: value drawn uniformly from
      ``[mean_ec50 - ec50_nd_lower_extreme, mean_ec50 + ec50_nd_upper_extreme]``;
      the label is kept.
    * ``ec50``: a fair coin chooses responsive (label 1, value above the given
      value) or non-responsive (label 0, value below it).
    * ``ec50_greater_than``: label 0, value drawn below the given value.

    Examples whose parameter id matches none of the above get value 0.0 and
    label 0 (JAX tracing cannot raise on data-dependent conditions).

    ``mean_ec50`` must be a number before the sampler is called, because the
    ``ec50_nd`` branch is traced on every call. ``std_ec50`` is only used by
    ``sample_concentration_and_label_LEGACY`` and may be omitted.
    Extra ``**kwargs`` are accepted and ignored.
    """
    def __init__(self,
                ec50_seed = None,
                screening_lower_margin = None,
                screening_upper_margin = None,
                screening_lower_extreme = None,
                screening_upper_extreme = None,
                ec50_nd_lower_extreme = None,
                ec50_nd_upper_extreme = None,
                ec50_std_multiplier = 2.0,
                ec50_lower_margin = 0.0,
                ec50_upper_margin = 0.0,
                ec50_lower_extreme = 3.0,
                ec50_upper_extreme = 3.0,
                ec50_greater_than_lower_margin = 0.0,
                ec50_greater_than_lower_extreme = 3.0,
                mean_ec50 = None,
                std_ec50 = None,
                screening_lower_perturbation_prob = 0.0,
                screening_lower_perturbation_shift = 0.0,
                conc_parameter_id_map = None,
                **kwargs):

        # Fall back to a time-based seed so independent runs differ by default.
        if ec50_seed is not None:
            self.ec50_seed = ec50_seed
        else:
            self.ec50_seed = int(time.time())
        self.prng_key = jax.random.PRNGKey(self.ec50_seed)
        
        # Screening and n.d. ranges default to the corresponding EC50 settings.
        if screening_lower_margin is None:
            screening_lower_margin = ec50_lower_margin
        if screening_upper_margin is None:
            screening_upper_margin = ec50_upper_margin
        if screening_lower_extreme is None:
            screening_lower_extreme = ec50_lower_extreme
        if screening_upper_extreme is None:
            screening_upper_extreme = ec50_upper_extreme
        
        if ec50_nd_lower_extreme is None:
            ec50_nd_lower_extreme = ec50_lower_extreme
        if ec50_nd_upper_extreme is None:
            ec50_nd_upper_extreme = ec50_upper_extreme

        self.screening_lower_margin = screening_lower_margin # NOTE: Implemented as an addition to EC50
        self.screening_upper_margin = screening_upper_margin # NOTE: Implemented as an addition to EC50
        self.screening_lower_extreme = screening_lower_extreme # NOTE: Implemented as an addition to EC50
        self.screening_upper_extreme = screening_upper_extreme # NOTE: Implemented as an addition to EC50
        self.ec50_nd_lower_extreme = ec50_nd_lower_extreme # NOTE: Implemented as an addition to EC50
        self.ec50_nd_upper_extreme = ec50_nd_upper_extreme # NOTE: Implemented as an addition to EC50
        self.ec50_std_multiplier = ec50_std_multiplier
        self.ec50_lower_margin = ec50_lower_margin # NOTE: Implemented as an addition to EC50
        self.ec50_upper_margin = ec50_upper_margin # NOTE: Implemented as an addition to EC50
        self.ec50_lower_extreme = ec50_lower_extreme # NOTE: Implemented as an addition to EC50
        self.ec50_upper_extreme = ec50_upper_extreme # NOTE: Implemented as an addition to EC50
        self.ec50_greater_than_lower_margin = ec50_greater_than_lower_margin # NOTE: Implemented as an addition to EC50
        self.ec50_greater_than_lower_extreme = ec50_greater_than_lower_extreme

        self.mean_ec50 = mean_ec50
        # Only sample_concentration_and_label_LEGACY uses ec50_sigma; the current
        # sampler draws n.d. values uniformly, so std_ec50 may be omitted.
        self.ec50_sigma = None if std_ec50 is None else std_ec50 * self.ec50_std_multiplier

        self.screening_lower_perturbation_prob = screening_lower_perturbation_prob
        self.screening_lower_perturbation_shift = screening_lower_perturbation_shift

        self._check_input_integrity()

        if conc_parameter_id_map is None:
            self.conc_parameter_id_map = {'ec50_nd' : 0,
                                'ec50_greater_than' : 1,
                                'ec50' : 2,
                                'screening' : 3}
        else:
            self.conc_parameter_id_map = conc_parameter_id_map
        self.conc_parameter_id_inverse_map = {val : key for key,val in self.conc_parameter_id_map.items()}

        # Bind all configuration so the traced per-example call only passes data and keys.
        self.sample_func = functools.partial(self.sample_concentration_and_label,
                                            screening_lower_margin = self.screening_lower_margin,
                                            screening_upper_margin = self.screening_upper_margin,
                                            screening_lower_extreme = self.screening_lower_extreme,
                                            screening_upper_extreme = self.screening_upper_extreme,
                                            ec50_nd_lower_extreme = self.ec50_nd_lower_extreme,
                                            ec50_nd_upper_extreme = self.ec50_nd_upper_extreme,
                                            ec50_lower_margin = self.ec50_lower_margin,
                                            ec50_upper_margin = self.ec50_upper_margin,
                                            ec50_lower_extreme = self.ec50_lower_extreme,
                                            ec50_upper_extreme = self.ec50_upper_extreme,
                                            ec50_greater_than_lower_margin = self.ec50_greater_than_lower_margin,
                                            ec50_greater_than_lower_extreme = self.ec50_greater_than_lower_extreme,
                                            mean_ec50 = self.mean_ec50,
                                            ec50_sigma = self.ec50_sigma,
                                            conc_parameter_id_map = self.conc_parameter_id_map)
        self.call = self.make_call()

    def _check_input_integrity(self):
        """Validate margins, extremes and perturbation settings.

        Raises:
            ValueError: on negative margins/extremes, a margin larger than its
                extreme (which would invert a uniform sampling range), or
                negative perturbation settings.
        """
        if self.ec50_lower_margin < 0.0 or self.ec50_upper_margin < 0.0 or self.ec50_lower_extreme < 0.0 or self.ec50_upper_extreme < 0.0:
            raise ValueError('EC50 margins and extremes must be positive.') 
        if self.ec50_lower_margin > self.ec50_lower_extreme or self.ec50_upper_margin > self.ec50_upper_extreme:
            raise ValueError('Margins chosen above extremes.')

        if self.ec50_greater_than_lower_margin < 0.0 or self.ec50_greater_than_lower_extreme < 0.0:
            raise ValueError('EC50 greater-than margin and extreme must be positive.')
        if self.ec50_greater_than_lower_margin > self.ec50_greater_than_lower_extreme:
            raise ValueError('EC50 greater-than margin chosen above extreme.')
        
        if self.screening_lower_margin < 0.0 or self.screening_upper_margin < 0.0 or self.screening_lower_extreme < 0.0 or self.screening_upper_extreme < 0.0:
            raise ValueError('EC50 margins and extremes must be positive.') 
        if self.screening_lower_margin > self.screening_lower_extreme or self.screening_upper_margin > self.screening_upper_extreme:
            raise ValueError('Margins chosen above extremes.')

        if self.ec50_nd_lower_extreme < 0.0 or self.ec50_nd_upper_extreme < 0.0:
            raise ValueError('n.d. EC50 extremes must be positive.') 
        
        if self.screening_lower_perturbation_prob < 0.0:
            raise ValueError('Negative screening perturbation probability.')
        if self.screening_lower_perturbation_shift < 0.0:
            raise ValueError('Negative screening perturbation shift.')
        return

    def _perturb_screening_value(self, val, param, label, screening_perturbation_key):
        """Shift a non-responsive screening value up, with the configured probability.

        The value is shifted only when the example is a screening example, its
        label is 0, and a uniform draw is below ``screening_lower_perturbation_prob``.
        """
        param_mask = (param == self.conc_parameter_id_map['screening']).astype(jnp.int32)
        label_mask = (label == 0).astype(jnp.int32)
        screening_perturbation_mask = jax.random.uniform(screening_perturbation_key, shape=(), minval=0.0, maxval=1.0) < self.screening_lower_perturbation_prob
        screening_perturbation_mask = screening_perturbation_mask.astype(jnp.int32)
        return val + param_mask * label_mask * screening_perturbation_mask * self.screening_lower_perturbation_shift

    def _make_core_call(self):
        """Build the per-example sampling function that ``make_call`` vmaps.

        The returned function takes ``((conc_dict, label_dict), prng_key)`` for a
        single example and returns ``(new_value, label_dict)`` with
        ``label_dict['_main_label']`` replaced by the sampled label.
        """
        # Decided once, in Python, when the function is built: with perturbation
        # disabled the traced function contains no perturbation ops and splits
        # the key into 2 instead of 3.
        apply_perturbation = self.screening_lower_perturbation_prob > 0.0 and self.screening_lower_perturbation_shift > 0.0

        def _call(inputs):
            _batch_inputs, prng_key = inputs
            conc_dict, label_dict = _batch_inputs
            if apply_perturbation:
                prng_key, _label_key, screening_perturbation_key = jax.random.split(prng_key, 3)
            else:
                prng_key, _label_key = jax.random.split(prng_key, 2)

            val = conc_dict['value']
            param = conc_dict['parameter']
            label = label_dict['_main_label']

            if apply_perturbation:
                val = self._perturb_screening_value(val, param, label, screening_perturbation_key)

            new_val, new_label = self.sample_func(val, param, label, prng_key, _label_key)

            # Under vmap, label_dict is a fresh pytree, so the caller's dict is not mutated.
            label_dict['_main_label'] = new_label
            return new_val, label_dict
        return _call
    
    def make_call(self):
        """Return the batched sampler ``((conc_dict, label_dict), key) -> (values, label_dict)``.

        The key is split into one sub-key per example (batch size read from
        ``label_dict['_main_label'].shape[0]``) and the per-example function
        from ``_make_core_call`` is vmapped over the leading axis.
        """
        _vmapped_call = jax.vmap(self._make_core_call())
        def call(inputs):
            _batch_inputs, conc_sampler_key = inputs
            conc_sampler_rngs = jax.random.split(conc_sampler_key, _batch_inputs[-1]['_main_label'].shape[0])
            # Keyword arguments to a vmapped function are mapped over axis 0.
            return _vmapped_call(inputs = (_batch_inputs, conc_sampler_rngs))
        return call

    def __call__(self, inputs):
        """Sample a batch; ``inputs`` is ``((conc_dict, label_dict), prng_key)``."""
        return self.call(inputs)

    @staticmethod
    def sample_concentration_and_label(val, param, label, prng_key, _label_key,
                                       screening_lower_margin, screening_upper_margin,
                                       screening_lower_extreme, screening_upper_extreme,
                                       ec50_nd_lower_extreme, ec50_nd_upper_extreme,
                                       ec50_lower_margin, ec50_upper_margin, 
                                       ec50_lower_extreme, ec50_upper_extreme,
                                       ec50_greater_than_lower_margin, ec50_greater_than_lower_extreme,
                                       mean_ec50, ec50_sigma,
                                       conc_parameter_id_map):
        """Resample one example's value and label using traceable control flow.

        Every parameter type is evaluated with ``jax.lax.cond``; the branch that
        does not match returns ``(0, 0.0)``, so summing the four results selects
        the matching branch. ``ec50_sigma`` is accepted for signature
        compatibility but unused.

        Notes:
        ------
        Difference to previous versions:
        1. EC50 n.d sampled uniformly in a range instead of normal distribution around mean EC50. 
        """


        def false_fun(x):
            return 0, 0.0

        def screening_value_fun(x):
            def non_responsive_fun(val):
                minval = val - screening_lower_extreme
                maxval = val - screening_lower_margin
                return minval, maxval
            def responsive_fun(val):
                minval = val + screening_upper_margin
                maxval = val + screening_upper_extreme
                return minval, maxval
            
            minval, maxval = jax.lax.cond(label, responsive_fun, non_responsive_fun, x)

            _label = label
            _value = jax.random.uniform(prng_key, shape=(), minval=minval, maxval=maxval)
            return _label, _value
        
        def ec50_nd_fun(x):
            # _value = mean_ec50 + ec50_sigma * jax.random.normal(prng_key, shape=()) # NOTE: ec50_sigma is std after multiping by ec50_std_multiplier
            minval = mean_ec50 - ec50_nd_lower_extreme
            maxval = mean_ec50 + ec50_nd_upper_extreme

            _label = label
            _value = jax.random.uniform(prng_key, shape=(), minval=minval, maxval=maxval)
            return _label, _value
        
        def ec50_fun(x):
            def non_responsive_fun(val):
                minval = val - ec50_lower_extreme
                maxval = val - ec50_lower_margin
                return minval, maxval
            def responsive_fun(val):
                minval = val + ec50_upper_margin
                maxval = val + ec50_upper_extreme
                return minval, maxval
            
            # Fair coin: responsive vs non-responsive.
            response = jax.random.uniform(_label_key, shape=(), minval=0.0, maxval=1.0) >= 0.5
            minval, maxval = jax.lax.cond(response, responsive_fun, non_responsive_fun, x)

            _label = response.astype(jnp.int32)
            _value = jax.random.uniform(prng_key, shape=(), minval=minval, maxval=maxval)
            return _label, _value
        
        def ec50_greater_than_fun(val):
            # Sample non-responsive
            minval = val - ec50_greater_than_lower_extreme
            maxval = val - ec50_greater_than_lower_margin
            _label = 0
            _value = jax.random.uniform(prng_key, shape=(), minval=minval, maxval=maxval)
            return _label, _value

        _label_screening, _value_screening = jax.lax.cond(param == conc_parameter_id_map['screening'], screening_value_fun, false_fun, val)
        _label_ec50_nd, _value_ec50_nd = jax.lax.cond(param == conc_parameter_id_map['ec50_nd'], ec50_nd_fun, false_fun, val)
        _label_ec50, _value_ec50 = jax.lax.cond(param == conc_parameter_id_map['ec50'], ec50_fun, false_fun, val)
        _label_ec50_greater_than, _value_ec50_greater_than = jax.lax.cond(param == conc_parameter_id_map['ec50_greater_than'], ec50_greater_than_fun, false_fun, val)

        _label = _label_screening + _label_ec50_nd + _label_ec50 + _label_ec50_greater_than
        _value = _value_screening + _value_ec50_nd + _value_ec50 + _value_ec50_greater_than

        return jnp.array(_value, ndmin = 0), jnp.array(_label, ndmin = 0)

    # ------------------------
    # LEGACY VERSION:
    # ------------------------
    @staticmethod
    def sample_concentration_and_label_LEGACY(val, param, label, prng_key, _label_key,
                                       screening_lower_margin, screening_upper_margin,
                                       screening_lower_extreme, screening_upper_extreme,
                                       ec50_nd_lower_extreme, ec50_nd_upper_extreme,
                                       ec50_lower_margin, ec50_upper_margin, 
                                       ec50_lower_extreme, ec50_upper_extreme,
                                       ec50_greater_than_lower_margin, ec50_greater_than_lower_extreme,
                                       mean_ec50, ec50_sigma,
                                       conc_parameter_id_map):
        """Older traceable sampler: screening values are passed through unchanged
        and n.d. EC50 values are drawn from ``mean_ec50 + ec50_sigma * N(0, 1)``
        (requires ``ec50_sigma`` to be set). Not used by ``make_call``.
        """

        def false_fun(x):
            return 0, 0.0

        def screening_value_fun(x):
            _value = x
            _label = label
            return _label, _value
        
        def ec50_nd_fun(x):
            _value = mean_ec50 + ec50_sigma * jax.random.normal(prng_key, shape=()) # NOTE: ec50_sigma is std after multiping by ec50_std_multiplier
            _label = label
            return _label, _value
        
        def ec50_fun(x):
            def non_responsive_fun(val):
                minval = val - ec50_lower_extreme
                maxval = val - ec50_lower_margin
                return minval, maxval
            def responsive_fun(val):
                minval = val + ec50_upper_margin
                maxval = val + ec50_upper_extreme
                return minval, maxval
            
            response = jax.random.uniform(_label_key, shape=(), minval=0.0, maxval=1.0) >= 0.5
            minval, maxval = jax.lax.cond(response, responsive_fun, non_responsive_fun, x)

            _label = response.astype(jnp.int32)
            _value = jax.random.uniform(prng_key, shape=(), minval=minval, maxval=maxval)
            return _label, _value
        
        def ec50_greater_than_fun(val):
            # Sample non-responsive
            minval = val - ec50_greater_than_lower_extreme
            maxval = val - ec50_greater_than_lower_margin
            _label = 0
            _value = jax.random.uniform(prng_key, shape=(), minval=minval, maxval=maxval)
            return _label, _value

        _label_screening, _value_screening = jax.lax.cond(param == conc_parameter_id_map['screening'], screening_value_fun, false_fun, val)
        _label_ec50_nd, _value_ec50_nd = jax.lax.cond(param == conc_parameter_id_map['ec50_nd'], ec50_nd_fun, false_fun, val)
        _label_ec50, _value_ec50 = jax.lax.cond(param == conc_parameter_id_map['ec50'], ec50_fun, false_fun, val)
        _label_ec50_greater_than, _value_ec50_greater_than = jax.lax.cond(param == conc_parameter_id_map['ec50_greater_than'], ec50_greater_than_fun, false_fun, val)

        _label = _label_screening + _label_ec50_nd + _label_ec50 + _label_ec50_greater_than
        _value = _value_screening + _value_ec50_nd + _value_ec50 + _value_ec50_greater_than

        return jnp.array(_value, ndmin = 0), jnp.array(_label, ndmin = 0)