import functools
import jax
from jax import numpy as jnp

from sklearn.metrics import confusion_matrix

# from ProtLig_GPCRclassA.metrics import confusion_matrix_binary

# from ProtLig_GPCRclassA.amino_GNN.make_loss_func import make_loss_func


def make_compute_metrics(loss_func, num_thresholds = 200, use_jit = False, num_classes = 3, aux_loss_func = None):
    """
    Build a ``compute_metrics(logits, labels)`` function for train/eval steps.

    Parameters:
    -----------
    loss_func : callable
        ``loss_func(logits, labels)``; its result is reported under 'loss'.
    num_thresholds : int
        Size of the decision-threshold grid used for per-threshold confusion matrices (binary case only).
        ``num_thresholds - 2`` thresholds evenly spaced strictly inside (0, 1) are used (0 and 1 are excluded).
    use_jit : bool
        If True, the returned function is wrapped in ``jax.jit`` and the jittable ``confusion_matrix_binary``
        is used instead of ``sklearn.metrics.confusion_matrix``.
    num_classes : int
        If 2, ``logits['_main_label']`` is flattened and passed through a sigmoid (binary classification).
        Otherwise the prediction is the argmax over the last axis of ``logits['_main_label']``.
    aux_loss_func : callable or None
        If given, ``aux_loss_func(logits, labels)`` is reported under 'auxiliary_loss'.

    Returns:
    --------
    compute_metrics : callable
        ``compute_metrics(logits, labels)`` returns a dict with the keys 'loss', 'confusion_matrix',
        'confusion_matrix_per_threshold' (dict threshold -> confusion matrix; empty for the multi-class case)
        and, only when aux_loss_func is given, 'auxiliary_loss'.

    Raises:
    -------
    NotImplementedError
        If use_jit is True and num_classes != 2.

    Notes:
    ------
    If you want to jit compute_metrics set use_jit = True. This is necessary because jitted and non jitted versions use different
    implementations of confusion_matrix. Non-jitted version uses sklearn.confusion_matrix and jitted version uses simplified version
    of sklearn.confusion_matrix with all of the checks discarded and with jax.numpy instead of numpy.
    """
    # Evenly spaced thresholds strictly between 0 and 1.
    thresholds = [(i + 1) * 1.0 / (num_thresholds - 1) for i in range(num_thresholds - 2)]

    if num_classes == 2:
        # NOTE(review): confusion_matrix_binary is not imported explicitly in this file; it is presumably
        # provided by the module-level star import from ProtLig_GPCRclassA.metrics — confirm.
        base_confusion_matrix = confusion_matrix_binary if use_jit else confusion_matrix
        _confusion_matrix = functools.partial(base_confusion_matrix, labels = jnp.array([0,1]), sample_weight = None)

        def _compute_metrics(logits, labels):
            # Flatten to 1-D; assumes a single logit per example. TODO: confirm against the model output.
            # The shape is passed positionally because the `newshape` keyword is deprecated in jax.numpy.reshape.
            _logits = jnp.reshape(logits['_main_label'], (-1, ))
            _labels = jnp.reshape(labels['_main_label'], (-1, ))
            pred_probs = jax.nn.sigmoid(_logits)
            # One confusion matrix per threshold (consumed downstream to build ROC / PR curves).
            _C = {}
            for threshold in thresholds:
                _pred_labels = jnp.asarray(pred_probs > threshold).astype(jnp.int32)
                _C[threshold] = _confusion_matrix(y_true = _labels, y_pred = _pred_labels)
            # Default decision rule: round the probability (round-half-to-even, so p == 0.5 maps to 0).
            pred_labels = jnp.round(pred_probs).astype(jnp.int32)
            conf_matrix = _confusion_matrix(y_true = _labels, y_pred = pred_labels)
            return conf_matrix, _C
    else:
        if use_jit:
            raise NotImplementedError('jittable confusion matrix is not implemented for multiple classes yet.')
        # Pin the label set so the matrix is always num_classes x num_classes, even when a batch lacks some classes.
        _confusion_matrix = functools.partial(confusion_matrix, labels = list(range(num_classes)), sample_weight = None)

        def _compute_metrics(logits, labels):
            pred_labels = jnp.argmax(logits['_main_label'], -1)
            conf_matrix = _confusion_matrix(y_true = labels['_main_label'], y_pred = pred_labels)
            # No threshold sweep for multi-class; keep the same return structure as the binary case.
            return conf_matrix, {}

    def compute_metrics(logits, labels):
        # Call order (loss, auxiliary loss, confusion matrices) is kept as in the original implementation.
        loss = loss_func(logits, labels)
        aux_loss = aux_loss_func(logits, labels) if aux_loss_func is not None else None
        conf_matrix, confusion_matrix_per_threshold = _compute_metrics(logits, labels)
        metrics = {'loss' : loss,
                   'confusion_matrix' : conf_matrix,
                   'confusion_matrix_per_threshold' : confusion_matrix_per_threshold,
                   }
        # `aux_loss_func is not None` is a Python-level (static) check, so this is safe under jax.jit.
        if aux_loss_func is not None:
            metrics['auxiliary_loss'] = aux_loss
        return metrics

    if use_jit:
        return jax.jit(compute_metrics)
    return compute_metrics
        

import numpy as np
from ProtLig_GPCRclassA.metrics import *

def log_metrics_from_epoch(name, epoch, hparams, metrics_np, logger, summary = None):
    """
    Log summaries generated by train/eval epoch.

    Parameters:
    -----------
    name : str
        Prefix for log messages. When equal to 'valid', the full confusion matrix is also logged.
    epoch : int
        Epoch index (currently not used by this function).
    hparams : dict
        Hyperparameters; only ``hparams['OUT_FEATURES']`` is read, and only when ``summary`` is given.
    metrics_np : list of dict
        Per-batch metric dicts with keys 'loss', 'confusion_matrix', optionally 'auxiliary_loss' and,
        when a binary summary is written, 'confusion_matrix_per_threshold'. They are presumably the outputs
        of compute_metrics converted to numpy (confirm against the caller).
        NOTE: consumed keys are popped from these dicts in place.
    logger : logging.Logger-like
        Object with an ``info`` method.
    summary : optional
        Summary writer supporting item assignment and ``scalar(tag, value)``. If None, only the logger is used.

    Returns:
    --------
    summary : the (possibly updated) summary object, or None if none was given.

    Raises:
    -------
    ValueError
        If ``metrics_np`` is empty.
    """
    if len(metrics_np) == 0:
        raise ValueError('log_metrics_from_epoch: metrics_np is empty, nothing to aggregate for {}.'.format(name))

    # Aggregate over batches (keys are popped from the input dicts).
    CM = sum([metrics.pop('confusion_matrix') for metrics in metrics_np])
    loss = np.mean([metrics.pop('loss') for metrics in metrics_np])
    logger.info(name + ' loss: {}'.format(loss))
    # Per-class fraction of correctly classified samples: diagonal over row sums (rows are true labels in
    # sklearn's convention). The epsilon (10e-9 == 1e-8) avoids division by zero for absent classes.
    logger.info(name + ' accuracy: {}'.format(list(np.diag(CM)/(np.sum(CM, axis = 1) + 10e-9))))

    if 'auxiliary_loss' in metrics_np[0]:
        aux_loss = np.mean([metrics.pop('auxiliary_loss') for metrics in metrics_np])
        logger.info(name + ' auxiliary loss: {}'.format(aux_loss))
    else:
        aux_loss = None

    if summary is not None:
        if hparams['OUT_FEATURES'] == 2:
            # Element-wise sum of the per-threshold confusion matrices across batches.
            # jax.tree_util.tree_map replaces the deprecated (and later removed) jax.tree_map alias.
            CM_per_threshold = jax.tree_util.tree_map(lambda *x: sum(x), *[metrics.pop('confusion_matrix_per_threshold') for metrics in metrics_np])
            roc_curve_values = roc_curve(CM_per_threshold, drop_intermediate=True)
            pr_curve_values = precision_recall_curve(CM_per_threshold)
            AveP = average_precision_score(pr_curve_values)
            summary['epoch_confusion_matrix'] = log_confusion_matrix(CM, class_names=['0', '1'])
            summary['epoch_ROC_curve'] = log_roc_curve(roc_curve_values)
            summary['epoch_PR_curve'] = log_pr_curve(pr_curve_values, average_precision = AveP)
            summary.scalar('AUC_ROC', roc_auc(roc_curve_values))
            summary.scalar('AveP', AveP)
            precision, recall, f_score = precision_recall_fscore(CM)
            summary.scalar('true_negative_rate', true_negative_rate(CM))
            summary.scalar('precision', precision)
            summary.scalar('recall', recall)
            summary.scalar('f_score', f_score)
            summary.scalar('MCC', MCC(CM))
            summary.scalar('epoch_loss', loss)
        else:
            summary['epoch_confusion_matrix'] = log_confusion_matrix(CM, class_names=[str(i) for i in range(hparams['OUT_FEATURES'])])
            summary.scalar('epoch_loss', loss)
            summary.scalar('MCC', MCC(CM))

        if aux_loss is not None:
            summary.scalar('epoch_auxiliary_loss', aux_loss)

    if name == 'valid':
        logger.info('Confusion matrix:\n{}'.format(CM))
        logger.info('--------')

    return summary