from jax import numpy as jnp
import jax
import copy
import time
import functools

from ProtLig_GPCRclassA.utils import create_line_graph, Sequence, Label, pad_graph
from ProtLig_GPCRclassA.amino_GNN.element import AminoElementPrecompute, AminoElementPrecomputeMasked

# -----------------
# JIT-compatible version:
# -----------------
class ConcentrationSampler:
    """
    Resample concentrations (and, where needed, labels) for a batch of measurements.

    Every sample carries a measured concentration ``conc_dict['value']``, a measurement type
    ``conc_dict['parameter']`` (integer id, see ``conc_parameter_id_map``) and a label
    ``label_dict['_main_label']``. For each sample a new concentration is drawn uniformly from an
    interval placed relative to the measured value (margins / extremes are offsets added to or
    subtracted from it) and the label is set to match the draw:

    * ``screening``: label kept; responsive samples are drawn above, non-responsive below the value.
    * ``ec50_nd``: label kept; drawn from ``[mean_ec50 - ec50_nd_lower_extreme, mean_ec50 + ec50_nd_upper_extreme)``.
    * ``ec50``: label drawn at random (p = 0.5); drawn above the value if responsive, below otherwise.
    * ``ec50_greater_than``: always labelled 0 and drawn below the value.

    Calling the instance with ``((conc_dict, label_dict), prng_key)`` returns
    ``(new_values, label_dict)`` with ``label_dict['_main_label']`` replaced by the new labels.
    """
    def __init__(self,
                ec50_seed = None,
                screening_lower_margin = None,
                screening_upper_margin = None,
                screening_lower_extreme = None,
                screening_upper_extreme = None,
                ec50_nd_lower_extreme = None,
                ec50_nd_upper_extreme = None,
                ec50_std_multiplier = 2.0,
                ec50_lower_margin = 0.0,
                ec50_upper_margin = 0.0,
                ec50_lower_extreme = 3.0,
                ec50_upper_extreme = 3.0,
                ec50_greater_than_lower_margin = 0.0,
                ec50_greater_than_lower_extreme = 3.0,
                mean_ec50 = None,
                std_ec50 = None,
                screening_lower_perturbation_prob = 0.0,
                screening_lower_perturbation_shift = 0.0,
                conc_parameter_id_map = None,
                **kwargs):
        """
        Parameters
        ----------
        ec50_seed : int, optional
            Seed for ``self.prng_key``; defaults to the current Unix time.
        screening_* / ec50_nd_* : float, optional
            Screening and n.d. EC50 margins/extremes; each falls back to the matching ``ec50_*`` value.
        ec50_*_margin, ec50_*_extreme, ec50_greater_than_* : float
            Non-negative offsets from the measured value that delimit the sampling intervals.
        mean_ec50 : float
            Centre of the interval used for n.d. EC50 samples.
        std_ec50 : float, optional
            Multiplied by ``ec50_std_multiplier`` into ``ec50_sigma``. May be omitted (``ec50_sigma``
            is then None) because the current uniform sampling scheme does not use it.
        screening_lower_perturbation_prob, screening_lower_perturbation_shift : float
            Probability and size of an upward shift applied to non-responsive screening values.
        conc_parameter_id_map : dict, optional
            Measurement-type name -> integer id as used in ``conc_dict['parameter']``.
        **kwargs
            Ignored.

        Raises
        ------
        ValueError
            If margins/extremes are negative, margins exceed extremes, or perturbation settings are negative.
        """
        self.ec50_seed = ec50_seed if ec50_seed is not None else int(time.time())
        self.prng_key = jax.random.PRNGKey(self.ec50_seed)

        # Screening and n.d. bounds default to the corresponding EC50 settings.
        if screening_lower_margin is None:
            screening_lower_margin = ec50_lower_margin
        if screening_upper_margin is None:
            screening_upper_margin = ec50_upper_margin
        if screening_lower_extreme is None:
            screening_lower_extreme = ec50_lower_extreme
        if screening_upper_extreme is None:
            screening_upper_extreme = ec50_upper_extreme

        if ec50_nd_lower_extreme is None:
            ec50_nd_lower_extreme = ec50_lower_extreme
        if ec50_nd_upper_extreme is None:
            ec50_nd_upper_extreme = ec50_upper_extreme

        # NOTE: all margins/extremes below are offsets applied to the measured value.
        self.screening_lower_margin = screening_lower_margin
        self.screening_upper_margin = screening_upper_margin
        self.screening_lower_extreme = screening_lower_extreme
        self.screening_upper_extreme = screening_upper_extreme
        self.ec50_nd_lower_extreme = ec50_nd_lower_extreme
        self.ec50_nd_upper_extreme = ec50_nd_upper_extreme
        self.ec50_std_multiplier = ec50_std_multiplier
        self.ec50_lower_margin = ec50_lower_margin
        self.ec50_upper_margin = ec50_upper_margin
        self.ec50_lower_extreme = ec50_lower_extreme
        self.ec50_upper_extreme = ec50_upper_extreme
        self.ec50_greater_than_lower_margin = ec50_greater_than_lower_margin
        self.ec50_greater_than_lower_extreme = ec50_greater_than_lower_extreme

        self.mean_ec50 = mean_ec50
        # ec50_sigma is not used by the uniform sampling scheme, so a missing std_ec50 must not
        # make construction fail (it used to raise TypeError on None * float).
        self.ec50_sigma = None if std_ec50 is None else std_ec50 * self.ec50_std_multiplier

        self.screening_lower_perturbation_prob = screening_lower_perturbation_prob
        self.screening_lower_perturbation_shift = screening_lower_perturbation_shift

        self._check_input_integrity()

        if conc_parameter_id_map is None:
            self.conc_parameter_id_map = {'ec50_nd' : 0,
                                'ec50_greater_than' : 1,
                                'ec50' : 2,
                                'screening' : 3}
        else:
            self.conc_parameter_id_map = conc_parameter_id_map
        self.conc_parameter_id_inverse_map = {val : key for key,val in self.conc_parameter_id_map.items()}

        # Bind all static configuration so the per-sample function only takes traced inputs.
        self.sample_func = functools.partial(self.sample_concentration_and_label,
                                            screening_lower_margin = self.screening_lower_margin,
                                            screening_upper_margin = self.screening_upper_margin,
                                            screening_lower_extreme = self.screening_lower_extreme,
                                            screening_upper_extreme = self.screening_upper_extreme,
                                            ec50_nd_lower_extreme = self.ec50_nd_lower_extreme,
                                            ec50_nd_upper_extreme = self.ec50_nd_upper_extreme,
                                            ec50_lower_margin = self.ec50_lower_margin,
                                            ec50_upper_margin = self.ec50_upper_margin,
                                            ec50_lower_extreme = self.ec50_lower_extreme,
                                            ec50_upper_extreme = self.ec50_upper_extreme,
                                            ec50_greater_than_lower_margin = self.ec50_greater_than_lower_margin,
                                            ec50_greater_than_lower_extreme = self.ec50_greater_than_lower_extreme,
                                            mean_ec50 = self.mean_ec50,
                                            ec50_sigma = self.ec50_sigma,
                                            conc_parameter_id_map = self.conc_parameter_id_map)
        self.call = self.make_call()

    def _check_input_integrity(self):
        """Validate the relative margins/extremes and the perturbation settings."""
        if self.ec50_lower_margin < 0.0 or self.ec50_upper_margin < 0.0 or self.ec50_lower_extreme < 0.0 or self.ec50_upper_extreme < 0.0:
            raise ValueError('EC50 margins and extremes must be positive.')
        if self.ec50_lower_margin > self.ec50_lower_extreme or self.ec50_upper_margin > self.ec50_upper_extreme:
            raise ValueError('Margins chosen above extremes.')

        if self.screening_lower_margin < 0.0 or self.screening_upper_margin < 0.0 or self.screening_lower_extreme < 0.0 or self.screening_upper_extreme < 0.0:
            raise ValueError('Screening margins and extremes must be positive.')
        if self.screening_lower_margin > self.screening_lower_extreme or self.screening_upper_margin > self.screening_upper_extreme:
            raise ValueError('Margins chosen above extremes.')

        if self.ec50_nd_lower_extreme < 0.0 or self.ec50_nd_upper_extreme < 0.0:
            raise ValueError('n.d. EC50 extremes must be positive.')

        if self.screening_lower_perturbation_prob < 0.0:
            raise ValueError('Negative screening perturbation probability.')
        if self.screening_lower_perturbation_shift < 0.0:
            raise ValueError('Negative screening perturbation shift.')
        return

    def _make_core_call(self):
        """
        Build the per-sample function that ``make_call`` vmaps over the batch.

        The returned function maps ``((conc_dict, label_dict), prng_key)`` for ONE sample to
        ``(new_value, label_dict)`` with ``label_dict['_main_label']`` replaced.
        """
        perturb_screening = (self.screening_lower_perturbation_prob > 0.0
                             and self.screening_lower_perturbation_shift > 0.0)

        def _call(inputs):
            _batch_inputs, prng_key = inputs
            conc_dict, label_dict = _batch_inputs
            # NOTE: the key is split 3 ways only when perturbation is enabled; changing this
            # would change the sampled random streams.
            if perturb_screening:
                prng_key, _label_key, screening_perturbation_key = jax.random.split(prng_key, 3)
            else:
                prng_key, _label_key = jax.random.split(prng_key, 2)

            val = conc_dict['value']
            param = conc_dict['parameter']
            label = label_dict['_main_label']

            if perturb_screening:
                # Occasionally shift non-responsive screening values upwards before sampling.
                param_mask = (param == self.conc_parameter_id_map['screening']).astype(jnp.int32)
                label_mask = (label == 0).astype(jnp.int32)
                screening_perturbation_mask = jax.random.uniform(screening_perturbation_key, shape=(), minval=0.0, maxval=1.0) < self.screening_lower_perturbation_prob
                screening_perturbation_mask = screening_perturbation_mask.astype(jnp.int32)
                val = val + param_mask * label_mask * screening_perturbation_mask * self.screening_lower_perturbation_shift

            new_val, new_label = self.sample_func(val, param, label, prng_key, _label_key)

            label_dict['_main_label'] = new_label
            return new_val, label_dict
        return _call

    def make_call(self):
        """Return a batched call: one PRNG key per sample, derived from the single input key."""
        _vmapped_call = jax.vmap(self._make_core_call())
        def call(inputs):
            _batch_inputs, conc_sampler_key = inputs
            conc_sampler_rngs = jax.random.split(conc_sampler_key, _batch_inputs[-1]['_main_label'].shape[0])
            return _vmapped_call(inputs = (_batch_inputs, conc_sampler_rngs))
        return call

    def __call__(self, inputs):
        return self.call(inputs)

    @staticmethod
    def sample_concentration_and_label(val, param, label, prng_key, _label_key,
                                       screening_lower_margin, screening_upper_margin,
                                       screening_lower_extreme, screening_upper_extreme,
                                       ec50_nd_lower_extreme, ec50_nd_upper_extreme,
                                       ec50_lower_margin, ec50_upper_margin,
                                       ec50_lower_extreme, ec50_upper_extreme,
                                       ec50_greater_than_lower_margin, ec50_greater_than_lower_extreme,
                                       mean_ec50, ec50_sigma,
                                       conc_parameter_id_map):
        """
        Draw a new (value, label) for one sample, dispatching on its measurement type ``param``.

        Each measurement type is handled by one ``lax.cond`` whose false branch contributes zeros,
        so summing the four results yields the output of the single matching branch.
        ``ec50_sigma`` is accepted for interface compatibility but unused.

        Notes:
        ------
        Difference to previous versions:
        1. EC50 n.d sampled uniformly in a range instead of normal distribution around mean EC50.
        """

        def false_fun(x):
            return 0, 0.0

        def screening_value_fun(x):
            def non_responsive_fun(val):
                minval = val - screening_lower_extreme
                maxval = val - screening_lower_margin
                return minval, maxval
            def responsive_fun(val):
                minval = val + screening_upper_margin
                maxval = val + screening_upper_extreme
                return minval, maxval

            minval, maxval = jax.lax.cond(label, responsive_fun, non_responsive_fun, x)

            _label = label
            _value = jax.random.uniform(prng_key, shape=(), minval=minval, maxval=maxval)
            return _label, _value

        def ec50_nd_fun(x):
            minval = mean_ec50 - ec50_nd_lower_extreme
            maxval = mean_ec50 + ec50_nd_upper_extreme

            _label = label
            _value = jax.random.uniform(prng_key, shape=(), minval=minval, maxval=maxval)
            return _label, _value

        def ec50_fun(x):
            def non_responsive_fun(val):
                minval = val - ec50_lower_extreme
                maxval = val - ec50_lower_margin
                return minval, maxval
            def responsive_fun(val):
                minval = val + ec50_upper_margin
                maxval = val + ec50_upper_extreme
                return minval, maxval

            # Fair coin decides on which side of the measured EC50 to sample.
            response = jax.random.uniform(_label_key, shape=(), minval=0.0, maxval=1.0) >= 0.5
            minval, maxval = jax.lax.cond(response, responsive_fun, non_responsive_fun, x)

            _label = response.astype(jnp.int32)
            _value = jax.random.uniform(prng_key, shape=(), minval=minval, maxval=maxval)
            return _label, _value

        def ec50_greater_than_fun(val):
            # Sample non-responsive
            minval = val - ec50_greater_than_lower_extreme
            maxval = val - ec50_greater_than_lower_margin
            _label = 0
            _value = jax.random.uniform(prng_key, shape=(), minval=minval, maxval=maxval)
            return _label, _value

        _label_screening, _value_screening = jax.lax.cond(param == conc_parameter_id_map['screening'], screening_value_fun, false_fun, val)
        _label_ec50_nd, _value_ec50_nd = jax.lax.cond(param == conc_parameter_id_map['ec50_nd'], ec50_nd_fun, false_fun, val)
        _label_ec50, _value_ec50 = jax.lax.cond(param == conc_parameter_id_map['ec50'], ec50_fun, false_fun, val)
        _label_ec50_greater_than, _value_ec50_greater_than = jax.lax.cond(param == conc_parameter_id_map['ec50_greater_than'], ec50_greater_than_fun, false_fun, val)

        _label = _label_screening + _label_ec50_nd + _label_ec50 + _label_ec50_greater_than
        _value = _value_screening + _value_ec50_nd + _value_ec50 + _value_ec50_greater_than

        return jnp.array(_value, ndmin = 0), jnp.array(_label, ndmin = 0)


class ConcentrationSamplerFixedExtremes(ConcentrationSampler):
    """
    ConcentrationSampler variant in which every ``*_extreme`` is an absolute bound.

    Margins are still offsets from the measured concentration, but the lower/upper extremes
    delimit a fixed window, so every sampled concentration lies inside
    ``[lower extreme, upper extreme]`` regardless of the measured value.
    """
    def _check_input_integrity(self):
        """Validate margins and perturbation settings, then print a label-distribution warning."""
        # Extremes are absolute concentrations here, so only the margins must be non-negative.
        ec50_margins = (self.ec50_lower_margin, self.ec50_upper_margin)
        if any(margin < 0.0 for margin in ec50_margins):
            raise ValueError('EC50 margins must be positive.')

        screening_margins = (self.screening_lower_margin, self.screening_upper_margin)
        if any(margin < 0.0 for margin in screening_margins):
            raise ValueError('EC50 margins and extremes must be positive.')

        if self.screening_lower_perturbation_prob < 0.0:
            raise ValueError('Negative screening perturbation probability.')
        if self.screening_lower_perturbation_shift < 0.0:
            raise ValueError('Negative screening perturbation shift.')

        print('\n\nWARNING: The sampling is changed here compared to ConcentrationSampler and this can affect label distribution. Check Adjusted class weights!\n\n')

    @staticmethod
    def sample_concentration_and_label(value, param, label, prng_key, _label_key,
                                       screening_lower_margin, screening_upper_margin,
                                       screening_lower_extreme, screening_upper_extreme,
                                       ec50_nd_lower_extreme, ec50_nd_upper_extreme,
                                       ec50_lower_margin, ec50_upper_margin,
                                       ec50_lower_extreme, ec50_upper_extreme,
                                       ec50_greater_than_lower_margin, ec50_greater_than_lower_extreme,
                                       mean_ec50, ec50_sigma,
                                       conc_parameter_id_map):
        """
        Draw a new (value, label) for one sample using fixed, absolute sampling windows.

        Intervals that would fall outside ``[lower extreme, upper extreme]`` are clipped back
        into it, keeping an ``eps`` gap so the interval never collapses to a point.
        ``mean_ec50`` and ``ec50_sigma`` are accepted for interface compatibility but unused.

        Notes:
        ------
        Difference to previous versions:
        1. EC50 n.d sampled uniformly in a range instead of normal distribution around mean EC50.
        2. _extreme are considered as fixed boundaries, rather than a distance from val.
        """
        eps = 1e-5
        # '>' EC50 measurements have no upper bound of their own; the n.d. upper bound is reused.
        greater_than_upper = ec50_nd_upper_extreme

        def draw(low, high):
            return jax.random.uniform(prng_key, shape=(), minval=low, maxval=high)

        def below(conc, margin, low, high):
            # Window [low, conc - margin], upper end clipped into [low + eps, high].
            return low, jnp.clip(conc - margin, low + eps, high)

        def above(conc, margin, low, high):
            # Window [conc + margin, high], lower end clipped into [low, high - eps].
            return jnp.clip(conc + margin, low, high - eps), high

        def inactive_branch(conc):
            return 0, 0.0

        def screening_branch(conc):
            low, high = jax.lax.cond(
                label,
                lambda c: above(c, screening_upper_margin, screening_lower_extreme, screening_upper_extreme),
                lambda c: below(c, screening_lower_margin, screening_lower_extreme, screening_upper_extreme),
                conc)
            return label, draw(low, high)

        def ec50_nd_branch(conc):
            return label, draw(ec50_nd_lower_extreme, ec50_nd_upper_extreme)

        def ec50_branch(conc):
            responsive = jax.random.uniform(_label_key, shape=(), minval=0.0, maxval=1.0) >= 0.5
            # At or beyond an edge only one side has room left, so the coin flip is overridden.
            responsive = jax.lax.cond(conc <= ec50_lower_extreme, lambda _: jnp.array(True), lambda r: r, responsive)
            responsive = jax.lax.cond(conc >= ec50_upper_extreme, lambda _: jnp.array(False), lambda r: r, responsive)
            low, high = jax.lax.cond(
                responsive,
                lambda c: above(c, ec50_upper_margin, ec50_lower_extreme, ec50_upper_extreme),
                lambda c: below(c, ec50_lower_margin, ec50_lower_extreme, ec50_upper_extreme),
                conc)
            return responsive.astype(jnp.int32), draw(low, high)

        def ec50_greater_than_branch(conc):
            # Always sampled as non-responsive.
            low, high = below(conc, ec50_greater_than_lower_margin,
                              ec50_greater_than_lower_extreme, greater_than_upper)
            return 0, draw(low, high)

        # Exactly one branch matches ``param``; the others contribute zeros to the sums.
        dispatch = (('screening', screening_branch),
                    ('ec50_nd', ec50_nd_branch),
                    ('ec50', ec50_branch),
                    ('ec50_greater_than', ec50_greater_than_branch))
        outcomes = [jax.lax.cond(param == conc_parameter_id_map[kind], branch, inactive_branch, value)
                    for kind, branch in dispatch]
        labels, values = zip(*outcomes)

        new_label = labels[0] + labels[1] + labels[2] + labels[3]
        new_value = values[0] + values[1] + values[2] + values[3]
        return jnp.array(new_value, ndmin = 0), jnp.array(new_label, ndmin = 0)


class ConcentrationSamplerFixedExtremesAdjustWeightsExact(ConcentrationSamplerFixedExtremes):
    """
    Fixed-extremes sampler that also attaches a class-balancing sample weight to every sample.

    The weight of a sampled (value, label) pair is computed from how many reference measurements
    in ``augmented_data`` are active / non-active at the sampled concentration (see
    ``make_get_sample_weight``). It is stored in ``label_dict['_main_sample_weight']``.
    """
    def __init__(self,
                ec50_seed = None,
                screening_lower_margin = None,
                screening_upper_margin = None,
                screening_lower_extreme = None,
                screening_upper_extreme = None,
                ec50_nd_lower_extreme = None,
                ec50_nd_upper_extreme = None,
                ec50_std_multiplier = 2.0,
                ec50_lower_margin = 0.0,
                ec50_upper_margin = 0.0,
                ec50_lower_extreme = 3.0,
                ec50_upper_extreme = 3.0,
                ec50_greater_than_lower_margin = 0.0,
                ec50_greater_than_lower_extreme = 3.0,
                mean_ec50 = None,
                std_ec50 = None,
                screening_lower_perturbation_prob = 0.0,
                screening_lower_perturbation_shift = 0.0,
                conc_parameter_id_map = None,
                augmented_data = None,
                **kwargs):
        """
        Parameters are those of ConcentrationSampler, plus:

        augmented_data : pandas.DataFrame
            Reference rows with ``row['_conc']['parameter'][0]``, ``row['_conc']['value'][0]``
            and ``row['_label']['_main_label']``; used to count active/non-active measurements.
        **kwargs
            Forwarded to the base class (which ignores them).

        Raises
        ------
        ValueError
            If ``augmented_data`` is not given, or the base-class validation fails.
        """
        if augmented_data is None:
            raise ValueError('augmented_data is required to compute the sample weights.')

        super().__init__(
                ec50_seed = ec50_seed,
                screening_lower_margin = screening_lower_margin,
                screening_upper_margin = screening_upper_margin,
                screening_lower_extreme = screening_lower_extreme,
                screening_upper_extreme = screening_upper_extreme,
                ec50_nd_lower_extreme = ec50_nd_lower_extreme,
                ec50_nd_upper_extreme = ec50_nd_upper_extreme,
                ec50_std_multiplier = ec50_std_multiplier,
                ec50_lower_margin = ec50_lower_margin,
                ec50_upper_margin = ec50_upper_margin,
                ec50_lower_extreme = ec50_lower_extreme,
                ec50_upper_extreme = ec50_upper_extreme,
                ec50_greater_than_lower_margin = ec50_greater_than_lower_margin,
                ec50_greater_than_lower_extreme = ec50_greater_than_lower_extreme,
                mean_ec50 = mean_ec50,
                std_ec50 = std_ec50,
                screening_lower_perturbation_prob = screening_lower_perturbation_prob,
                screening_lower_perturbation_shift = screening_lower_perturbation_shift,
                conc_parameter_id_map = conc_parameter_id_map,
                **kwargs
                )

        # Use the attributes resolved by the base class, not the raw arguments: the id map and the
        # screening margins default to None and are only filled in by the base constructor.
        measured = self._measured_concentrations
        self.shifted_conc_ec50_non_active = measured(augmented_data, 'ec50_nd', 0) - self.ec50_lower_margin
        self.shifted_conc_ec50_greater_than = measured(augmented_data, 'ec50_greater_than', 1) - self.ec50_greater_than_lower_margin
        self.shifted_conc_screening_non_active = measured(augmented_data, 'screening', 0) - self.screening_lower_margin
        self.shifted_conc_ec50_active = measured(augmented_data, 'ec50', 1) + self.ec50_upper_margin
        self.shifted_conc_screening_active = measured(augmented_data, 'screening', 1) + self.screening_upper_margin

        self.get_sample_weight = self.make_get_sample_weight()

    def _measured_concentrations(self, augmented_data, parameter_name, main_label):
        """Return the measured concentrations of rows with the given measurement type and label."""
        parameter_id = self.conc_parameter_id_map[parameter_name]
        row_mask = augmented_data.apply(
            lambda row: (row['_conc']['parameter'][0] == parameter_id and row['_label']['_main_label'] == main_label),
            axis = 1)
        selected = augmented_data[row_mask]
        values = selected.apply(lambda row: row['_conc']['value'][0], axis = 1)
        return jnp.array(values.values)

    def _make_core_call(self):
        """
        Wrap the base per-sample call so it also stores the sample weight.

        ``label_dict['_main_sample_weight']`` receives ``get_sample_weight(new_val, new_label)``.
        """
        base_call = super()._make_core_call()

        def _call(inputs):
            new_val, label_dict = base_call(inputs)
            # get_sample_weight is looked up at call time; it is set after the base __init__ runs.
            label_dict['_main_sample_weight'] = self.get_sample_weight(new_val, label_dict['_main_label'])
            return new_val, label_dict
        return _call

    def make_get_sample_weight(self):
        """
        Build ``get_sample_weight(new_val, new_label) -> (sample_weight, n0, n1)``.

        ``n0`` counts reference measurements that are non-active at ``new_val``: every n.d. EC50
        row, plus '>' EC50, active EC50 and non-active screening rows whose shifted concentration
        is ``>= new_val``. ``n1`` counts active EC50 and active screening rows whose shifted
        concentration is ``<= new_val``. The weight is ``(n0 + n1) / (2 * n_label)`` for the sampled
        label; it is inf/nan when that label's count is 0.
        """
        ec50_non_active = self.shifted_conc_ec50_non_active.shape[0]
        def get_sample_weight(new_val, new_label):
            n0_greater_than = jnp.sum(new_val <= self.shifted_conc_ec50_greater_than)
            n0_ec50_non_active = ec50_non_active
            n0_ec50_active = jnp.sum(new_val <= self.shifted_conc_ec50_active)
            n0_screening_non_active = jnp.sum(new_val <= self.shifted_conc_screening_non_active)

            n1_ec50_active = jnp.sum(new_val >= self.shifted_conc_ec50_active)
            n1_screening_active = jnp.sum(new_val >= self.shifted_conc_screening_active)

            n0 = n0_greater_than + n0_ec50_non_active + n0_ec50_active + n0_screening_non_active
            n1 = n1_ec50_active + n1_screening_active

            def weight_non_active(x):
                return (n0 + n1)/(2*n0)

            def weight_active(x):
                return (n0 + n1)/(2*n1)

            sample_weight = jax.lax.cond(new_label, weight_active, weight_non_active, new_label)
            return (sample_weight, n0, n1)
        return get_sample_weight
    

class LabelSampler:
    """
    Sample a concentration uniformly from a fixed region and draw a label from reference data.

    For each sample, ``conc_dict['C0']`` and ``conc_dict['C1']`` hold reference concentrations
    (constructed during Dataset preprocessing). At a sampled concentration ``c``, ``N0 = #(c <= C0)``
    and ``N1 = #(c >= C1)``; the label is drawn as Bernoulli(``N1 / (N0 + N1)``). When both counts
    are zero ("case I"), see ``sample_concentration_and_label``.

    Calling the instance with ``((conc_dict, label_dict), prng_key)`` returns ``(new_values, label_dict)``
    where ``label_dict`` is a copy holding ``'_main_label'`` and ``'_main_sample_weight'``.
    """
    def __init__(self, unknown_case_sample_weight_scale = 1.0,
                sampling_region_lower_bound = None,
                sampling_region_upper_bound = None,
                **kwargs):
        """
        Parameters
        ----------
        unknown_case_sample_weight_scale : float
            Factor in [0, 1] applied to the weights of case-I (no reference data) samples.
        sampling_region_lower_bound, sampling_region_upper_bound : float
            Bounds of the uniform concentration sampling region (required before calling).
        **kwargs
            Ignored.

        Raises
        ------
        ValueError
            If ``unknown_case_sample_weight_scale`` is outside [0, 1].
        """
        self.unknown_case_sample_weight_scale = unknown_case_sample_weight_scale
        self.sampling_region_lower_bound = sampling_region_lower_bound
        self.sampling_region_upper_bound = sampling_region_upper_bound

        self._check_input_integrity()

        self.sample_func = functools.partial(self.sample_concentration_and_label,
                                            sampling_region_lower_bound = self.sampling_region_lower_bound,
                                            sampling_region_upper_bound = self.sampling_region_upper_bound)

        self.batched_get_weights = functools.partial(self.batch_get_weights,
                                            alpha = self.unknown_case_sample_weight_scale)
        self.call = self.make_call()

    def _check_input_integrity(self):
        """Reject a case-I weight scale outside the closed interval [0, 1]."""
        if self.unknown_case_sample_weight_scale < 0.0 or self.unknown_case_sample_weight_scale > 1.0:
            raise ValueError('Sample weight scale for unknown case must be in interval <0, 1>.')
        return

    def _make_core_call(self):
        """Build the per-sample function ``((conc_dict, label_dict), key) -> (value, label, p1)``."""
        def _call(inputs):
            _batch_inputs, prng_key = inputs
            conc_dict, label_dict = _batch_inputs # NOTE: label_dict was used to construct C0 and C1 in Dataset preprocessing.
            prng_key, _label_key = jax.random.split(prng_key, 2)

            C0 = conc_dict['C0']
            C1 = conc_dict['C1']

            new_val, new_label, p1 = self.sample_func(C0, C1, prng_key, _label_key)

            return new_val, new_label, p1
        return _call

    def make_call(self):
        """Return the batched call: vmapped sampling followed by batch-level weight computation."""
        _vmapped_call = jax.vmap(self._make_core_call())
        def call(inputs):
            _batch_inputs, conc_sampler_key = inputs
            # Shallow copy: do not overwrite entries of the caller's label dict in place.
            label_dict = dict(_batch_inputs[-1])
            conc_sampler_rngs = jax.random.split(conc_sampler_key, label_dict['_main_label'].shape[0])

            new_vals, new_labels, p1s = _vmapped_call(inputs = (_batch_inputs, conc_sampler_rngs))
            sample_weights = self.batched_get_weights(new_labels, p1s)

            label_dict['_main_label'] = new_labels
            label_dict['_main_sample_weight'] = sample_weights

            return new_vals, label_dict
        return call

    def __call__(self, inputs):
        return self.call(inputs)

    @staticmethod
    def sample_concentration_and_label(C0, C1, prng_key, _label_key,
                                       sampling_region_lower_bound,
                                       sampling_region_upper_bound):
        """
        Draw ``(value, label, p1)`` for one sample.

        ``p1`` is ``N1 / (N0 + N1)``, or ``-1.0`` in case I (no reference data at the value), which
        ``batch_get_weights`` uses as the case-I marker.
        NOTE(review): in case I the returned label is a raw uniform draw in [0, 1) rather than a
        thresholded 0/1 label -- confirm a soft random label is intended.
        """
        # Sample concentration uniformly from (L, U)
        val = jax.random.uniform(prng_key, shape=(), minval=sampling_region_lower_bound, maxval=sampling_region_upper_bound)

        N0_for_a_given_conc = jnp.sum(val <= C0)
        N1_for_a_given_conc = jnp.sum(val >= C1)

        # case I. (no data):
        case_I_check = jnp.logical_and(N0_for_a_given_conc == 0, N1_for_a_given_conc == 0)
        # case II. (only non-active data), case III. (only active data) and case IV. (both active and non-active data):
        # NOTE: These three cases can be put together and when only one is available the distribution collapses to a Dirac delta.

        def case_I_fun(x):
            response = jax.random.uniform(_label_key, shape=(), minval=0.0, maxval=1.0)
            return response, -1.0

        def case_II_III_IV_fun(x):
            N0 = x[0]
            N1 = x[1]
            p1 = N1/(N0 + N1)
            response = jax.random.uniform(_label_key, shape=(), minval=0.0, maxval=1.0) < p1
            response = response.astype(jnp.float32)
            return response, p1

        new_label, p1 = jax.lax.cond(case_I_check, case_I_fun, case_II_III_IV_fun, (N0_for_a_given_conc, N1_for_a_given_conc))

        return jnp.array(val, ndmin = 0), jnp.array(new_label, ndmin = 0), jnp.array(p1, ndmin = 0)

    @staticmethod
    def batch_get_weights(labels, p1s, alpha):
        """
        Compute per-sample weights from batch statistics.

        Case-I samples (``p1 < 0``) get ``alpha * n / #case_I``; other samples get
        ``n / sum(p_label)`` over non-case-I samples. Everything is halved so a batch with an equal
        number of 0 and 1 labels gets weights of about 1. ``eps`` avoids division by zero.

        NOTE: Statistics here are drawn from batch not full data!!
        """
        eps = 1e-5
        p0s = 1 - p1s
        n = p1s.shape[0]

        # case I. (no data):
        case_I_check = p1s < 0.0
        # case II. (only non-active data), case III. (only active data) and case IV. (both active and non-active data):
        # NOTE: These three cases can be put together and when only one is available the distribution collapses to a Dirac delta.

        weights_case_I = alpha * (n/(jnp.sum(case_I_check) + eps)) # 1/p(label|c)
        weights_case_II_III_IV_label_0 = n/(jnp.sum(p0s * (1 - case_I_check)) + eps)
        weights_case_II_III_IV_label_1 = n/(jnp.sum(p1s * (1 - case_I_check)) + eps)

        _weights_unif = jnp.where(case_I_check, weights_case_I, 0.0)
        _weights_0 = jnp.where(jnp.logical_and(jnp.logical_not(case_I_check), labels == 0.0), weights_case_II_III_IV_label_0, 0.0)
        _weights_1 = jnp.where(jnp.logical_and(jnp.logical_not(case_I_check), labels == 1.0), weights_case_II_III_IV_label_1, 0.0)

        weights = 0.5*(_weights_unif + _weights_0 + _weights_1) # NOTE: 1/2 is there so that weights are 1 in case of labels [0, 0, 1, 1] (even number of class examples).

        return weights