import torch
from torch.distributions import Categorical, Dirichlet


def fgsm_attack(data, data_grad, eps=0.05):
    """Return ``data`` shifted by ``eps`` along the sign of ``data_grad``.

    This is the Fast Gradient Sign Method step. Every element moves by exactly
    ``eps`` in the direction of its gradient's sign, and elements whose
    gradient is zero are left unchanged. The result is not clipped to any
    valid input range.

    Args:
        data: input tensor to perturb.
        data_grad: gradient of the loss w.r.t. ``data`` (same shape).
        eps: step size of the perturbation.

    Returns:
        A new tensor with the perturbed input.
    """
    step = eps * data_grad.sign()
    return data + step


def mutual_information(alpha):
    """Mutual information between the label and the categorical under Dir(alpha).

    The result is the entropy of the mean categorical (total uncertainty)
    minus the expected entropy of a categorical drawn from the Dirichlet. The
    latter is computed in closed form with digamma functions. Both terms are
    reduced over the last dimension.

    Args:
        alpha: Dirichlet concentration parameters, classes on the last dim.

    Returns:
        Tensor of mutual-information values with the last dim removed.
    """
    precision = alpha.sum(dim=-1, keepdim=True)
    mean = alpha / precision
    total_entropy = -(mean * mean.log()).sum(dim=-1)
    digamma_gap = (alpha + 1).digamma() - (precision + 1).digamma()
    expected_entropy = -(mean * digamma_gap).sum(dim=-1)
    return total_entropy - expected_entropy


def acc(alphas, labels):
    """Top-1 accuracy, in percent, of the Dirichlet mean predictions.

    Args:
        alphas: Dirichlet concentration parameters, classes on the last dim.
        labels: tensor of true class indices, one per prediction.

    Returns:
        0-dim float tensor with the accuracy in [0, 100].
    """
    predicted = Dirichlet(alphas).mean.argmax(dim=-1)
    n_correct = (predicted == labels).sum().float()
    return 100. * n_correct / len(labels)


def llh(alphas, labels):
    """Mean log-likelihood of ``labels`` under the Dirichlet mean predictive.

    Args:
        alphas: Dirichlet concentration parameters, classes on the last dim.
        labels: tensor of true class indices.

    Returns:
        0-dim tensor with the average log-probability of the true labels.
    """
    predictive = Categorical(probs=Dirichlet(alphas).mean)
    log_probs = predictive.log_prob(labels)
    return log_probs.mean()


def ece(alphas, labels, n_bins=10):
    """Compute the expected calibration error (ECE) of the Dirichlet mean predictions.

    Each sample's confidence is the largest entry of the Dirichlet mean, and
    its prediction is the corresponding class. Samples are put into
    ``n_bins`` equal-width confidence bins over [0, 1]. The ECE is the
    sample-weighted average, over bins, of |bin accuracy - bin mean confidence|.

    Args:
        alphas: Dirichlet concentration parameters with classes on the last
            dimension. Presumably shape (N, K): the binning needs a 1-D
            vector of per-sample confidences.
        labels: true class index for each sample (tensor or sequence of ints).
        n_bins: number of equal-width confidence bins.

    Returns:
        0-dim tensor with the ECE as a fraction, not a percentage. It is NaN
        when ``labels`` is empty.
    """
    pred = Dirichlet(alphas).mean
    pred_probs, pred_labels = pred.max(dim=-1)
    labels = torch.as_tensor(labels, device=pred_labels.device)
    n_samples = len(labels)

    # floor(conf * n_bins) equals n_bins when conf == 1.0. Float rounding
    # produces that for very peaked alphas (e.g. [1e8, 1] in float32), so
    # clamp into the last bin instead of indexing past the end.
    bin_ids = (pred_probs.detach() * n_bins).long().clamp(max=n_bins - 1)
    correct = (pred_labels == labels).to(pred_probs.dtype)

    # Per-bin sums on the same device/dtype as the predictions.
    zeros = pred_probs.new_zeros(n_bins)
    acc_sum = zeros.index_add(0, bin_ids, correct)
    conf_sum = zeros.index_add(0, bin_ids, pred_probs)

    # (count_b / N) * |acc_sum_b / count_b - conf_sum_b / count_b|
    # simplifies to |acc_sum_b - conf_sum_b| / N. Empty bins have both
    # sums equal to 0 and contribute nothing.
    return (acc_sum - conf_sum).abs().sum() / n_samples
