"""
Metrics to measure classification performance
"""

import torch
from torch import nn
from torch.nn import functional as F

from utils.credal_ensemble_utils import credal_ensemble_forward_pass

from sklearn.metrics import accuracy_score
from sklearn.metrics import confusion_matrix


def get_logits_labels(model, data_loader, device):
    """
    Run ``model`` over every batch of ``data_loader`` and collect the outputs.

    The model is switched to evaluation mode (and left in that mode) and the
    forward passes run with gradient tracking disabled. Each batch and its
    labels are moved to ``device`` before the forward pass.

    Returns a ``(logits, labels)`` pair in which the per-batch model outputs
    and the per-batch labels have each been concatenated along dimension 0.
    """
    model.eval()
    batch_logits, batch_labels = [], []

    with torch.no_grad():
        for inputs, targets in data_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            batch_logits.append(model(inputs))
            batch_labels.append(targets)

    # Stitch the per-batch tensors back together along the batch axis.
    return torch.cat(batch_logits, dim=0), torch.cat(batch_labels, dim=0)


def test_classification_net_softmax(softmax_prob, labels):
    """
    Report classification accuracy and the confusion matrix for a set of
    softmax vectors and their labels.

    The predicted class of each row is the index of its largest entry along
    dimension 1, and that largest entry is recorded as the confidence.

    Returns a 5-tuple:
    ``(confusion_matrix, accuracy, labels_list, predictions_list,
    confidence_vals_list)``, where the three lists hold the per-sample
    labels, predicted classes and confidences taken from CPU numpy arrays.
    """
    # Row-wise maximum gives both the winning probability and its class index.
    top_prob, top_class = torch.max(softmax_prob, dim=1)

    labels_list = list(labels.cpu().numpy())
    predictions_list = list(top_class.cpu().numpy())
    confidence_vals_list = list(top_prob.cpu().numpy())

    accuracy = accuracy_score(labels_list, predictions_list)
    conf_matrix = confusion_matrix(labels_list, predictions_list)
    return conf_matrix, accuracy, labels_list, predictions_list, confidence_vals_list


def test_classification_net_logits(logits, labels):
    """
    Report classification accuracy and the confusion matrix for raw logits.

    The logits are turned into probabilities with a softmax over dimension 1
    and the result is handed to ``test_classification_net_softmax``, whose
    return value is passed through unchanged.
    """
    return test_classification_net_softmax(F.softmax(logits, dim=1), labels)


def test_classification_net(model, data_loader, device):
    """
    Report classification accuracy and the confusion matrix of ``model`` over
    every batch in ``data_loader``.

    Logits and labels are gathered with ``get_logits_labels`` and evaluated by
    ``test_classification_net_logits``; that function's return value is passed
    through unchanged.
    """
    collected = get_logits_labels(model, data_loader, device)
    return test_classification_net_logits(*collected)


def test_classification_net_ensemble(model_ensemble, data_loader, device):
    """
    Report classification accuracy and the confusion matrix over a dataset
    for an ensemble evaluated through ``credal_ensemble_forward_pass``.

    Every model in ``model_ensemble`` is switched to evaluation mode (and left
    in that mode) and the forward passes run with gradient tracking disabled.
    Each batch and its labels are moved to ``device`` first.

    Only the second half of the columns of the first value returned by
    ``credal_ensemble_forward_pass`` is scored; this file treats that half as
    the upper probabilities. The first half (treated as the lower
    probabilities) is not used for the reported metrics.
    NOTE(review): the lower/upper column layout is assumed from how this
    function slices the output -- confirm against credal_ensemble_forward_pass.

    Returns the same 5-tuple as ``test_classification_net_softmax``:
    ``(confusion_matrix, accuracy, labels_list, predictions_list,
    confidence_vals_list)``.
    """
    for model in model_ensemble:
        model.eval()

    upper_prob_batches = []
    label_batches = []
    with torch.no_grad():
        for data, label in data_loader:
            data = data.to(device)
            label = label.to(device)
            probs, _, _, _ = credal_ensemble_forward_pass(model_ensemble, data)
            # Keep only the half that is scored; accumulating the unused lower
            # half would hold a second dataset-sized tensor for no benefit.
            half = probs.shape[-1] // 2
            upper_prob_batches.append(probs[:, half:])
            label_batches.append(label)

    softmax_prob_upper = torch.cat(upper_prob_batches, dim=0)
    labels = torch.cat(label_batches, dim=0)
    return test_classification_net_softmax(softmax_prob_upper, labels)