import jax
from jax import numpy as jnp
import numpy
from scipy.optimize import curve_fit

from ProtLig_GPCRclassA.utils import tf_to_jax, tf_to_jraph_graph_reshape

from ProtLig_GPCRclassA.amino_GNN.make_loss_func import make_loss_func, make_aux_loss_func
from ProtLig_GPCRclassA.amino_GNN.concentration.make_compute_metrics import make_compute_metrics_concentration, make_compute_ec50_metrics

# from Receptor_odorant.JAX.BERT_GNN.CLS_GTransformer.profiling_configs import hparams, _datacase, _h5file, dataparams, datadir, logdir, model, get_profiler_logdir

def make_eval_step():
    """
    Build an evaluation step function.

    The returned ``eval_step(state, batch)``:
      * advances every PRNG key in ``state.rngs`` by keeping the first key of
        ``jax.random.split`` (so the next call sees fresh randomness), and
      * runs ``state.apply_fn(state.params, batch[:-1], deterministic=True)``,
        i.e. on every element of ``batch`` except the last one (presumably the
        labels — confirm with the caller).

    NOTE: state is returned because of the state.rngs update.
    The step is returned un-jitted; callers may wrap it with ``jax.jit``.

    Returns:
        ``eval_step(state, batch) -> (state, logits)``.
    """
    def eval_step(state, batch):
        # jax.tree_util.tree_map instead of the deprecated jax.tree_map alias,
        # which newer JAX versions no longer provide.
        new_rngs = jax.tree_util.tree_map(lambda key: jax.random.split(key)[0], state.rngs)
        state = state.replace(rngs = new_rngs) # update PRNGKeys
        logits = state.apply_fn(state.params, batch[:-1], deterministic = True)
        return state, logits
    return eval_step


def make_valid_conc_masked_epoch(loss_option, logger,
                                 aux_loss_option = None,
                                 loader_output_type = 'jax', num_classes = 3, mask_token_id = None,
                                 cls_token_id = None, pad_token_id = None, sep_token_id = None,
                                 unk_token_id = None, eos_token_id = None, bos_token_id = None,
                                 conc_sampler = None,
                                 min_conc_sample = -5.0, max_conc_sample = 1.0, step_conc_sample = 0.2, conc_parameter_id_map = None,
                                 fit_normalized_curve = False):
    """
    Helper function to create the valid_epoch function for the
    concentration-conditioned model.

    For each validation batch, the returned ``valid_epoch`` evaluates the model
    once per concentration on the grid
    ``jnp.arange(min_conc_sample, max_conc_sample + step_conc_sample, step_conc_sample)``.
    It then fits a sigmoid dose-response curve per sample to the predicted
    probabilities and computes metrics from the fitted parameters
    (top, ec50, slope, bottom).

    Args:
        loss_option: forwarded to ``make_loss_func`` as ``option``.
        logger: logger used for curve-fit warnings.
        aux_loss_option: forwarded to ``make_aux_loss_func``; ``None`` disables it.
        loader_output_type: ``'jax'`` (the returned epoch raises
            NotImplementedError) or ``'tf'``.
        num_classes: forwarded to ``make_loss_func``.
        mask_token_id, pad_token_id, conc_sampler: accepted for interface
            compatibility; not used here.
        cls_token_id, sep_token_id, unk_token_id, eos_token_id, bos_token_id:
            token ids that are never selected by the random token masking.
        min_conc_sample, max_conc_sample, step_conc_sample: the concentration
            grid. min/max also bound the fitted ec50.
        conc_parameter_id_map: forwarded to ``make_compute_ec50_metrics``.
        fit_normalized_curve: if True, fit a 2-parameter sigmoid with top fixed
            to 1 and bottom fixed to 0. Otherwise fit a 4-parameter sigmoid.

    Returns:
        ``valid_epoch(state, valid_loader) -> (state, batch_metrics)`` where
        ``batch_metrics`` is a list with one metrics dict per batch.

    Raises:
        ValueError: if ``loader_output_type`` is neither ``'jax'`` nor ``'tf'``.
    """
    # ---- Dose-response curve fitting ----------------------------------------
    def failed_fit():
        # Sentinel parameters returned when no curve can be fitted, so one bad
        # sample does not abort the whole validation epoch.
        return {'top' : -1.0,
                'ec50' : -100.0,
                'slope' : 0.0,
                'bottom' : -1.0} # top, ec50, slope, bottom

    if fit_normalized_curve:
        def sigmoid(x, x0, k):
            # x0 = ec50; k = 1/hill_slope; top fixed to 1, bottom fixed to 0
            return 1 / (1 + numpy.exp(-k*(x-x0)))
        lower_bound = numpy.array([min_conc_sample, 0.0])
        upper_bound = numpy.array([max_conc_sample, numpy.inf])

        def initial_guess(conc, prob):
            # x0, k (curve_fit requires an initial guess when bounds are used)
            return [numpy.median(conc), 1]

        def to_curve_params(popt):
            return {'top' : 1.0,
                    'ec50' : popt[0],
                    'slope' : popt[1],
                    'bottom' : 0.0}
    else:
        def sigmoid(x, t, x0, k, b):
            # x0 = ec50; k = 1/hill_slope; t = top; b = bottom
            return (t - b) / (1 + numpy.exp(-k*(x-x0))) + b
        # top/bottom bounded to [0, 1]: the curve is fitted to probabilities of activation.
        lower_bound = numpy.array([0.0, min_conc_sample, 0.0, 0.0])
        upper_bound = numpy.array([1.0, max_conc_sample, numpy.inf, 1.0])

        def initial_guess(conc, prob):
            # t, ec50, k, b
            return [numpy.max(prob), numpy.median(conc), 1, numpy.min(prob)]

        def to_curve_params(popt):
            return {'top' : popt[0],
                    'ec50' : popt[1],
                    'slope' : popt[2],
                    'bottom' : popt[3]}

    def fit_sigmoid_curve(conc, prob):
        """
        Fit the configured sigmoid to one sample's (concentration, probability)
        points. Returns a dict with keys 'top', 'ec50', 'slope' and 'bottom'.
        Returns the ``failed_fit()`` sentinel if scipy cannot fit the curve.
        """
        # Work on plain float64 NumPy arrays so that scipy does not dispatch
        # every residual evaluation through JAX in float32.
        conc = numpy.asarray(conc, dtype = numpy.float64)
        prob = numpy.asarray(prob, dtype = numpy.float64)
        try:
            popt, _ = curve_fit(sigmoid, conc, prob, p0 = initial_guess(conc, prob),
                                method = 'trf', bounds = (lower_bound, upper_bound))
        except RuntimeError:
            # curve_fit raises RuntimeError when the least-squares fit fails.
            return failed_fit()
        except ValueError as e:
            # e.g. non-finite data (curve_fit checks finiteness by default).
            logger.warning(e)
            logger.info('Concentrations: %s', conc)
            logger.info('Probabilities: %s', prob)
            return failed_fit()
        return to_curve_params(popt)

    # ---- Token masking ------------------------------------------------------
    def get_masked_tokens_and_labels(input_ids, attn_mask, seq_mask_rng):
        # Select ~15% of positions at random, never choosing padding or
        # special tokens. Returns (labels = original ids, boolean mask).
        seq_label = input_ids.copy()
        seq_mask = jax.random.uniform(seq_mask_rng, shape = input_ids.shape) < 0.15
        cls_mask = input_ids != cls_token_id
        pad_mask = attn_mask.astype(bool)
        sep_mask = input_ids != sep_token_id
        unk_mask = input_ids != unk_token_id
        eos_mask = input_ids != eos_token_id
        bos_mask = input_ids != bos_token_id
        input_ids_mask = seq_mask * cls_mask * pad_mask * sep_mask * unk_mask * eos_mask * bos_mask
        return seq_label, input_ids_mask

    # ---- Loss / metric functions -------------------------------------------
    # NOTE(review): loss_func / aux_loss_func are not used by the epoch
    # functions below. Their construction is kept in case it validates the
    # options. TODO confirm whether they are still needed.
    loss_func = make_loss_func(is_weighted = False, option = loss_option, num_classes = num_classes)
    if aux_loss_option is not None:
        aux_loss_func = make_aux_loss_func(option = aux_loss_option)
    else:
        aux_loss_func = None
    compute_metrics = make_compute_metrics_concentration(use_jit = False)
    compute_metrics_maxprob = make_compute_metrics_concentration(use_jit = False, suffix = '_maxprob')
    compute_ec50_metrics = make_compute_ec50_metrics(min_conc_sample, max_conc_sample, conc_parameter_id_map = conc_parameter_id_map)
    compute_ec50_metrics_true_positive_top_pred_05 = make_compute_ec50_metrics(min_conc_sample, max_conc_sample, conc_parameter_id_map = conc_parameter_id_map, suffix = '_true_positive_top_pred_05')
    eval_step = jax.jit(make_eval_step())

    concentrations = jnp.arange(min_conc_sample, max_conc_sample + step_conc_sample, step_conc_sample)

    # Case loader outputs jnp.DeviceArray:
    if loader_output_type == 'jax':
        def valid_epoch(state, valid_loader):
            # This path was not updated to the concentration-grid / curve-fit
            # evaluation implemented for the 'tf' loader below.
            raise NotImplementedError('Logic chaged. See tf version of epoch...')
    # Case loader outputs tf.Tensor:
    elif loader_output_type == 'tf':
        def valid_epoch(state, valid_loader):
            batch_metrics = []
            device = jax.devices()[0]
            for _, batch in valid_loader.enumerate():
                batch = jax.tree_util.tree_map(lambda x: jax.device_put(tf_to_jax(x), device = device), batch)
                S, G, labels = batch
                G = tf_to_jraph_graph_reshape(G)
                # Copy so that adding 'label_mask' does not mutate the graph's globals in place.
                ec50_label = G.globals['_conc'].copy()
                ec50_label['label_mask'] = labels['_main_label']
                # Get masked tokens:
                input_ids_label, input_ids_mask = get_masked_tokens_and_labels(input_ids = S[2],
                                                                               attn_mask = S[1],
                                                                               seq_mask_rng = state.rngs['seq_mask'])
                S = S[:-1] # Discard input_ids
                labels.update({'_input_ids_label' : input_ids_label})

                # Evaluate the model once per grid concentration. The '_conc'
                # global is replaced by a constant array at that concentration.
                conc_ones = jnp.ones_like(G.globals['_conc']['value'])
                results = []
                for grid_conc in concentrations:
                    new_conc = grid_conc * conc_ones
                    _globals = G.globals.copy()
                    _globals.update({'_conc' : new_conc})
                    new_G = G._replace(globals = _globals)
                    state, logits = eval_step(state, (S, new_G, labels))
                    results.append({'_conc' : jnp.expand_dims(new_conc, axis = -1),
                                    '_main_label' : logits['_main_label']})

                # Concatenate along the last axis: one column per grid
                # concentration (assumes logits['_main_label'] has a trailing
                # size-1 axis — TODO confirm).
                results = jax.tree_util.tree_map(lambda *x: jnp.concatenate(x, axis = -1), *results)
                prob = jax.nn.sigmoid(results['_main_label'])
                conc = results['_conc']

                # One curve fit per sample (row), done on the host with scipy.
                # Transfer once instead of slicing device arrays row by row.
                conc_host = numpy.asarray(conc)
                prob_host = numpy.asarray(prob)
                curve_fits = [fit_sigmoid_curve(conc_row, prob_row)
                              for conc_row, prob_row in zip(conc_host, prob_host)]

                curve_params = jax.tree_util.tree_map(lambda *x: jnp.stack(x, axis = 0), *curve_fits)
                ec50_pred = curve_params['ec50']

                ec50_metrics = compute_ec50_metrics(ec50_pred, ec50_label)

                # NOTE(review): a previous comment asked to adjust
                # ec50_metrics['squared_error_mask'] "according to the above
                # WARNING", but no such warning exists here. TODO confirm.

                # Classification metrics from the fitted top of the curve.
                top_pred = curve_params['top']
                metrics = compute_metrics(top_pred, labels = labels)

                # Classification metrics from the highest probability over the grid.
                max_prob = jnp.max(prob, axis = -1)
                metrics_maxprob = compute_metrics_maxprob(max_prob, labels = labels)

                # ec50 metrics restricted to samples whose fitted top is > 0.5.
                ec50_labels_true_positive = ec50_label.copy()
                ec50_labels_true_positive['label_mask'] = ec50_labels_true_positive['label_mask'] * jnp.squeeze(top_pred > 0.5) # Only for correctly predicted.
                ec50_metrics_true_positive_top_pred_05 = compute_ec50_metrics_true_positive_top_pred_05(ec50_pred, ec50_labels_true_positive)

                metrics.update(ec50_metrics)
                metrics.update(metrics_maxprob)
                metrics.update(ec50_metrics_true_positive_top_pred_05)
                batch_metrics.append(metrics)
            return state, batch_metrics
    else:
        raise ValueError("Unknown loader_output_type: {!r}; expected 'jax' or 'tf'.".format(loader_output_type))
    return valid_epoch