import functools
import jax
from jax import numpy as jnp
import optax

from sklearn.metrics import confusion_matrix
from scipy.stats import spearmanr

# from ProtLig_GPCRclassA.metrics import confusion_matrix_binary

def make_compute_ec50_metrics(min_conc_sample, max_conc_sample, conc_parameter_id_map, use_jit = False, suffix = ''):
    """
    Build a function that pairs EC50 predictions with their targets and a validity mask.

    Parameters
    ----------
    min_conc_sample, max_conc_sample : float
        Inclusive bounds on the target EC50 value; targets outside [min, max] are masked out.
    conc_parameter_id_map : dict
        Must contain the key 'ec50_greater_than' (looked up on every call); entries whose
        'parameter' equals that id are masked out.
    use_jit : bool
        If True, the returned function is wrapped with ``jax.jit``.
    suffix : str
        Appended to every key of the returned dictionary.

    Returns
    -------
    callable
        ``compute_ec50_metrics(ec50_pred, ec50_label) -> dict`` with keys
        ``'ec50_predictions' + suffix``, ``'ec50_targets' + suffix`` and ``'ec50_final_mask' + suffix``.
        ``ec50_label`` must provide the entries 'value', 'label_mask' and 'parameter'.

    Notes
    -----
    EC50 targets below or above the thresholds are ignored, because they cannot be fit.
    Predictions of -100.0 (set manually when curve fitting does not converge) are ignored too.
    """
    def compute_ec50_metrics(ec50_pred, ec50_label):
        targets = ec50_label['value']
        predictions = ec50_pred

        above_min = targets >= min_conc_sample
        below_max = targets <= max_conc_sample
        # fit_curve writes -100.0 when it does not converge; anything above -99.9 was actually fitted.
        was_fitted = predictions > -99.9

        has_label = ec50_label['label_mask'].astype(bool)
        # 'ec50_greater_than' entries are presumably lower-bound-only values — excluded from the metrics.
        not_greater_than = ec50_label['parameter'] != conc_parameter_id_map['ec50_greater_than']

        # Keep the boolean array first in the product so the result stays a boolean array.
        valid = has_label * not_greater_than * above_min * below_max * was_fitted

        return {
            'ec50_predictions' + suffix: predictions,
            'ec50_targets' + suffix: targets,
            'ec50_final_mask' + suffix: valid,
        }

    return jax.jit(compute_ec50_metrics) if use_jit else compute_ec50_metrics


def make_compute_metrics_concentration(num_thresholds = 200, use_jit = False, num_classes = 2, suffix = ''):
    """
    Build a function that computes binary confusion matrices from predicted probabilities.

    Parameters
    ----------
    num_thresholds : int
        Controls the threshold grid: ``num_thresholds - 2`` evenly spaced values strictly between
        0 and 1, i.e. ``(i + 1) / (num_thresholds - 1)`` for ``i in range(num_thresholds - 2)``.
    use_jit : bool
        If True, the returned function is wrapped with ``jax.jit`` and uses ``confusion_matrix_binary``;
        otherwise it uses ``sklearn.metrics.confusion_matrix``.
    num_classes : int
        Only 2 is supported.
    suffix : str
        Appended to every key of the returned dictionary.

    Returns
    -------
    callable
        ``compute_metrics(pred_probs, labels) -> dict`` with
        ``'confusion_matrix' + suffix`` (predictions rounded to 0/1) and
        ``'confusion_matrix_per_threshold' + suffix`` (dict mapping each threshold to its confusion matrix).
        ``labels`` must contain a ``'_main_label'`` entry; predictions and labels are flattened.

    Raises
    ------
    NotImplementedError
        If ``num_classes != 2``.

    Notes
    -----
    If you want to jit compute_metrics, set use_jit = True. This is necessary because the jitted and
    non-jitted versions use different implementations of confusion_matrix. The non-jitted version uses
    sklearn.metrics.confusion_matrix and the jitted version uses a simplified version of it with all
    checks discarded and with jax.numpy instead of numpy.
    """
    if num_classes != 2:
        raise NotImplementedError('In concentration there could be only 2 classes.')

    thresholds = [(i + 1) * 1.0 / (num_thresholds - 1) for i in range(num_thresholds - 2)]
    class_labels = jnp.array([0, 1])

    if use_jit:
        # `confusion_matrix_binary` is resolved at call time. Its explicit import is commented out at the
        # top of this module, so it presumably comes from the later
        # `from ProtLig_GPCRclassA.metrics import *` — confirm that module exports it.
        _confusion_matrix = functools.partial(confusion_matrix_binary, labels = class_labels, sample_weight = None)
    else:
        _confusion_matrix = functools.partial(confusion_matrix, labels = class_labels, sample_weight = None)

    def compute_metrics(pred_probs, labels):
        # Shape is passed positionally: the `newshape=` keyword of jnp.reshape is deprecated in recent JAX.
        pred_probs = jnp.reshape(pred_probs, (-1,))
        true_labels = jnp.reshape(labels['_main_label'], (-1,))

        confusion_matrix_per_threshold = {}
        for threshold in thresholds:
            pred_labels = (pred_probs > threshold).astype(jnp.int32)
            confusion_matrix_per_threshold[threshold] = _confusion_matrix(y_true = true_labels, y_pred = pred_labels)

        # Threshold 0.5 via rounding; jnp.round rounds half to even, so exactly 0.5 maps to class 0.
        pred_labels = jnp.round(pred_probs).astype(jnp.int32)
        conf_matrix = _confusion_matrix(y_true = true_labels, y_pred = pred_labels)

        return {'confusion_matrix' + suffix: conf_matrix,
                'confusion_matrix_per_threshold' + suffix: confusion_matrix_per_threshold}

    return jax.jit(compute_metrics) if use_jit else compute_metrics
        

import numpy as np
from ProtLig_GPCRclassA.metrics import *

def _log_activity_metrics(name, suffix, metrics_np, logger, summary, log_curves):
    """Aggregate, log and summarize the activity confusion matrices for one suffix; return the summed matrix."""
    # NOTE(review): roc_curve, precision_recall_curve, average_precision_score, log_roc_curve, log_pr_curve,
    # roc_auc, log_confusion_matrix, precision_recall_fscore, true_negative_rate and MCC are not defined in the
    # visible code; presumably provided by `from ProtLig_GPCRclassA.metrics import *` — confirm.

    # Sum the per-batch confusion matrices; popping removes them from the caller's dicts.
    CM = sum([metrics.pop('confusion_matrix' + suffix) for metrics in metrics_np])
    # Diagonal divided by row sums; the small epsilon (10e-9 == 1e-8) guards against empty rows.
    logger.info(name + ' accuracy: {}'.format(list(np.diag(CM)/(np.sum(CM, axis = 1) + 10e-9))))

    if summary is None:
        return CM

    if log_curves:
        # Leaf-wise sum of the per-threshold confusion matrices across batches.
        # jax.tree_util.tree_map replaces the deprecated top-level alias jax.tree_map.
        CM_per_threshold = jax.tree_util.tree_map(
            lambda *x: sum(x),
            *[metrics.pop('confusion_matrix_per_threshold' + suffix) for metrics in metrics_np])
        roc_curve_values = roc_curve(CM_per_threshold, drop_intermediate=True)
        pr_curve_values = precision_recall_curve(CM_per_threshold)
        AveP = average_precision_score(pr_curve_values)
        summary['epoch_ROC_curve' + suffix] = log_roc_curve(roc_curve_values)
        summary['epoch_PR_curve' + suffix] = log_pr_curve(pr_curve_values, average_precision = AveP)
        summary.scalar('AUC_ROC' + suffix, roc_auc(roc_curve_values))
        summary.scalar('AveP' + suffix, AveP)

    summary['epoch_confusion_matrix' + suffix] = log_confusion_matrix(CM, class_names=['0', '1'])
    precision, recall, f_score = precision_recall_fscore(CM)
    summary.scalar('true_negative_rate' + suffix, true_negative_rate(CM))
    summary.scalar('precision' + suffix, precision)
    summary.scalar('recall' + suffix, recall)
    summary.scalar('f_score' + suffix, f_score)
    summary.scalar('MCC' + suffix, MCC(CM))
    return CM


def _log_ec50_metrics(suffix, metrics_np, summary):
    """Write EC50 regression metrics (MSE, RMSE, error histograms, Spearman correlation) for one suffix."""
    # NOTE(review): log_ec50_error_histogram is presumably provided by the star import from
    # ProtLig_GPCRclassA.metrics — confirm.
    predictions, targets = [], []
    for metrics in metrics_np:
        # Compute each batch's mask once and apply it to predictions and targets alike.
        mask = metrics['ec50_final_mask' + suffix].astype(bool)
        predictions.append(metrics['ec50_predictions' + suffix][mask])
        targets.append(metrics['ec50_targets' + suffix][mask])
    ec50_predictions = np.concatenate(predictions)
    ec50_targets = np.concatenate(targets)
    has_samples = ec50_predictions.shape[0] > 0  # the mask may exclude every example

    squared_error = np.power(ec50_predictions - ec50_targets, 2)
    # Same NaN that np.mean would return on an empty array, without the "Mean of empty slice" warning.
    mean_squared_error = np.mean(squared_error) if has_samples else np.nan
    root_mean_squared_error = np.sqrt(mean_squared_error)
    summary.scalar('mean_squared_error' + suffix, mean_squared_error)
    summary.scalar('root_mean_squared_error' + suffix, root_mean_squared_error)
    summary['ec50_error_distribution' + suffix] = log_ec50_error_histogram(squared_error)
    summary['ec50_error_distribution_zoom_4' + suffix] = log_ec50_error_histogram(squared_error[squared_error <= 4.0])

    # spearmanr returns (statistic, p-value); indexing works for both old and new SciPy result types,
    # unlike the `.statistic` attribute, which only newer SciPy provides.
    spearman_corr = spearmanr(ec50_predictions, ec50_targets)[0] if has_samples else np.nan
    summary.scalar('spearman_corr' + suffix, spearman_corr)


def log_metrics_concentration_from_epoch(name, epoch, hparams, metrics_np, logger, summary = None, activity_suffixes = ('',), ec50_suffixes = ('',), log_curves = True):
    """
    Log summaries generated by a train/eval epoch.

    Parameters
    ----------
    name : str
        Prefix of the logger messages. When 'valid', the confusion matrix of the last activity
        suffix is also written to the logger.
    epoch, hparams :
        Unused here; kept for interface compatibility.
    metrics_np : list of dict
        One metrics dict per batch. The entries 'confusion_matrix<suffix>',
        'confusion_matrix_per_threshold<suffix>' (only when curves are logged to a summary) and
        'auxiliary_loss' are popped, so these dicts are mutated.
    logger :
        Object with an ``info(str)`` method.
    summary : optional
        Object supporting ``summary[key] = value`` and ``summary.scalar(key, value)``. When None,
        only logger output is produced and EC50 metrics are skipped.
    activity_suffixes, ec50_suffixes : iterable of str
        Key suffixes to process for the activity (classification) and EC50 (regression) metrics.
    log_curves : bool
        Whether to build ROC/PR curves from the per-threshold confusion matrices.

    Returns
    -------
    The ``summary`` argument (possibly None).
    """
    # The auxiliary loss does not depend on the suffix, so pop it exactly once. Handling it inside the
    # per-suffix loop let a second suffix reset it to None, and it was unbound with no activity suffixes.
    aux_loss = None
    if metrics_np and 'auxiliary_loss' in metrics_np[0]:
        aux_loss = np.mean([metrics.pop('auxiliary_loss') for metrics in metrics_np])
        logger.info(name + ' auxiliary loss: {}'.format(aux_loss))

    CM = None  # stays None when activity_suffixes is empty
    for suffix in activity_suffixes:
        CM = _log_activity_metrics(name, suffix, metrics_np, logger, summary, log_curves)

    if summary is not None:
        for suffix in ec50_suffixes:
            _log_ec50_metrics(suffix, metrics_np, summary)
        if aux_loss is not None:
            summary.scalar('epoch_auxiliary_loss', aux_loss)

    if name == 'valid' and CM is not None:
        logger.info('Confusion matrix:\n{}'.format(CM))
        logger.info('--------')

    return summary