
import numpy as np
import torch


def soft_accuracy(outputs, targets, args=None):
    """
    Compute soft classification accuracy.

    A sample counts as correct when its arg-max predicted class is any of the
    classes marked non-zero in that sample's target row. This lets a target
    carry one or more acceptable classes per sample.

    args={"output_key": key of prediction output,
          "target_key": key of prediction target}

    :param outputs: mapping holding the prediction tensor under ``output_key``
        (default ``"logits"``). It is indexed as (sample, class), because the
        arg-max is taken over axis 1.
    :param targets: mapping holding the target tensor under ``target_key``
        (default ``"actions"``). It is indexed row-per-sample like the
        predictions, and non-zero entries mark acceptable classes.
    :param args: optional dict whose entries override the default keys.
    :return: accuracy in percent (0-100) as a float, or ``nan`` for an empty
        batch, where accuracy is undefined.
    """

    # default parameters
    params = dict()
    params["output_key"] = "logits"
    params["target_key"] = "actions"
    if args:
        params.update(args)

    p, t = outputs[params["output_key"]], targets[params["target_key"]]

    # go to numpy; detach first because .numpy() raises on tensors that require grad
    p = p.detach().cpu().numpy()
    t = t.detach().cpu().numpy()

    n = p.shape[0]
    if n == 0:
        # no samples -> accuracy undefined (previously ZeroDivisionError)
        return float("nan")

    # compute soft accuracy: a hit is a non-zero target entry at the predicted class
    classes = p.argmax(axis=1)
    hits = t[np.arange(n), classes] != 0
    return 100.0 * (np.count_nonzero(hits) / n)


def entropy(outputs, targets, args=None):
    """
    Compute the mean entropy (in nats) of the softmax distribution of the logits.

    The softmax is taken over dim 1. The per-sample entropies are averaged over
    dim 0.

    args={"output_key": key of prediction output}

    :param outputs: mapping holding the logits tensor under ``output_key``
        (default ``"logits"``).
    :param targets: unused; accepted so all metrics share one signature.
    :param args: optional dict whose entries override the default key.
    :return: 0-d numpy array with the mean entropy.
    """

    # default parameters
    params = dict()
    params["output_key"] = "logits"
    if args:
        params.update(args)

    logits = outputs[params["output_key"]]

    # no_grad: this is a metric, and the result must be convertible to numpy
    # even when `logits` requires grad
    with torch.no_grad():
        probs = torch.softmax(logits, dim=1)
        log_probs = torch.log_softmax(logits, dim=1)
        # 0 * log(0) := 0; masked (-inf) logits would otherwise yield 0 * -inf = nan
        plogp = torch.where(probs > 0, probs * log_probs, torch.zeros_like(probs))
        h = -plogp.sum(dim=1).mean()

    return h.cpu().numpy()


def mean_class_rank(outputs, targets, args=None):
    """
    Compute mean classification rank of one or more potential target classes.

    For each sample, classes are ordered by descending prediction score. The
    best class has rank 1. The rank of every class marked non-zero in the
    sample's target row is collected, and the mean over all collected ranks is
    returned. Ties are ordered as ``np.argsort`` orders them.

    args={"output_key": key of prediction output,
          "target_key": key of prediction target}

    :param outputs: mapping holding the prediction tensor under ``output_key``
        (default ``"logits"``), indexed as (sample, class).
    :param targets: mapping holding the target tensor under ``target_key``
        (default ``"actions"``). It has the same (sample, class) layout, and
        non-zero entries mark target classes.
    :param args: optional dict whose entries override the default keys.
    :return: mean 1-based rank, or ``nan`` when no target class is marked.
    """

    # default parameters
    params = dict()
    params["output_key"] = "logits"
    params["target_key"] = "actions"
    if args:
        params.update(args)

    p, t = outputs[params["output_key"]], targets[params["target_key"]]

    # go to numpy; detach first because .numpy() raises on tensors that require grad
    p = p.detach().cpu().numpy()
    t = t.detach().cpu().numpy()

    # order[i] lists the class indices of sample i from highest to lowest score
    order = np.argsort(p, axis=1)[:, ::-1]
    # invert each permutation: ranks[i, c] is the 1-based rank of class c in sample i
    ranks = np.empty(order.shape, dtype=np.int64)
    ranks[np.arange(order.shape[0])[:, None], order] = np.arange(1, order.shape[1] + 1)

    # row-major boolean selection: same order as iterating samples, then target classes
    target_ranks = ranks[t != 0]
    if target_ranks.size == 0:
        # no target classes -> mean undefined; avoid numpy's empty-mean RuntimeWarning
        return float("nan")
    return np.mean(target_ranks)
