import logging

import torch
from torch import nn

# Lower bound used when clamping kernel-density denominators in this module,
# so the subsequent divisions never divide by zero.
eps = 1e-10


def get_sharpness_reg(f, y, bandwidth, device):
    """Return the sharpness regulariser for single-column (binary) predictions.

    Args:
        f: 2-D tensor of predictions with values in [0, 1] (validated by
            ``check_input``).
        y: labels; entries equal to 1 are treated as positives downstream.
        bandwidth: positive kernel bandwidth.
        device: device on which the diagonal mask of the kernel is placed.

    Returns:
        Whatever ``get_sharpness_ratio`` returns for the masked log-kernel.

    Raises:
        NotImplementedError: if ``f`` does not have exactly one column.
    """
    check_input(f, bandwidth)
    # Only the binary case is supported; bail out early for anything else.
    if f.shape[1] != 1:
        raise NotImplementedError

    masked_log_kernel = get_kernel(f, bandwidth, device)
    return get_sharpness_ratio(y, masked_log_kernel, len(f))


def get_ece_reg(f, y, bandwidth, p, mc_type, device):
    """Return the kernel-based calibration-error regulariser.

    Args:
        f: 2-D tensor of predictions with values in [0, 1] (validated by
            ``check_input``).
        y: labels (assumed to be class indices for the multiclass case and
            0/1 for the binary case -- confirm against callers).
        bandwidth: positive kernel bandwidth.
        p: exponent applied to the absolute calibration error.
        mc_type: multiclass variant, one of ``'marginal'``, ``'canonical'``
            or ``'top_label'``. Ignored when ``f`` has a single column.
        device: device on which the diagonal mask of the kernel is placed.

    Returns:
        For a single-column ``f``: twice the binary ratio. Otherwise the
        value of the estimator selected by ``mc_type``.

    Raises:
        ValueError: if ``f`` has several columns and ``mc_type`` is not one
            of the supported names (previously this silently returned
            ``None``).
    """
    check_input(f, bandwidth)
    if f.shape[1] == 1:
        return 2 * get_ratio_binary(f, y, bandwidth, p, device)

    if mc_type == 'marginal':
        return get_ratio_marginal_vect(f, y, bandwidth, p, device)
    if mc_type == 'canonical':
        return get_ratio_canonical(f, y, bandwidth, p, device)
    if mc_type == 'top_label':
        return get_ratio_toplabel(f, y, bandwidth, p, device)

    # Fail loudly instead of handing ``None`` back to the training loop.
    raise ValueError(
        f"Unknown mc_type {mc_type!r}; expected 'marginal', 'canonical' or 'top_label'"
    )


def get_kernel(f, bandwidth, device):
    """Build the pairwise log-kernel matrix of ``f`` with a masked diagonal.

    A Beta kernel is used when ``f`` has one column, a Dirichlet kernel
    otherwise. The most negative float is added on the diagonal so that the
    diagonal entries act like -inf in log space.

    Args:
        f: 2-D tensor of predictions.
        bandwidth: kernel bandwidth.
        device: device the diagonal mask is moved to before the addition.

    Returns:
        The squeezed log-kernel plus the diagonal mask.
    """
    single_column = f.shape[1] == 1
    if single_column:
        raw_log_kernel = beta_kernel(f, f, bandwidth)
    else:
        raw_log_kernel = dirichlet_kernel(f, bandwidth)
    raw_log_kernel = raw_log_kernel.squeeze()

    # Trick: the most negative float on the diagonal stands in for -inf.
    most_negative = torch.finfo(torch.float).min
    diagonal_mask = torch.diag(torch.full((len(f),), most_negative)).to(device)
    return raw_log_kernel + diagonal_mask


def get_sharpness_ratio(y, log_kern, N):
    """Average a per-row ratio of pairwise kernel masses (positives vs. all).

    For every row of ``log_kern`` the numerator is
    ``(sum_j k_j)^2 - sum_j k_j^2`` over the columns where ``y == 1`` and the
    denominator is the same expression over all columns, with
    ``k = exp(log_kern)``. The denominator is clamped from below by ``eps``.

    Args:
        y: labels; entries equal to 1 select the positive columns.
        log_kern: 2-D log-kernel matrix.
        N: sample count; decremented in the single-positive case but not
            otherwise read here.

    Returns:
        The int ``0`` when there are no positives, otherwise the mean ratio
        as a tensor. NaN/inf ratios are logged and replaced.
    """
    positives = torch.where(y == 1)[0]
    n_positives = positives.numel()
    if n_positives == 0:
        return 0

    if n_positives == 1:
        # Drop the lone positive's own row: presumably its only positive
        # column is the masked diagonal entry (see original "-inf" note).
        log_kern = torch.cat((log_kern[:positives], log_kern[positives + 1:]))
        N -= 1

    log_kern_pos = torch.index_select(log_kern, 1, positives)

    # log of sum(k) and of sum(k^2), for the positive columns and for all.
    log_sum_pos = torch.logsumexp(log_kern_pos, dim=1)
    log_sumsq_pos = torch.logsumexp(2 * log_kern_pos, dim=1)
    log_sum_all = torch.logsumexp(log_kern, dim=1)
    log_sumsq_all = torch.logsumexp(2 * log_kern, dim=1)

    num = torch.exp(2 * log_sum_pos) - torch.exp(log_sumsq_pos)
    den = torch.exp(2 * log_sum_all) - torch.exp(log_sumsq_all)
    den = torch.clamp(den, min=eps)

    intermediates = (log_sum_pos, log_sumsq_pos, log_sum_all, log_sumsq_all, num, den)
    if any(isnan(t) for t in intermediates):
        logging.warning("NAN!")

    ratio = num / den
    if isnan(ratio):
        logging.warning("NAN in ratio!")
        ratio = torch.nan_to_num(ratio)

    if isinf(ratio):
        logging.warning("INF in ratio!")
        ratio[torch.isinf(ratio)] = 0

    return ratio.mean()


def get_ratio_binary(f, y, bandwidth, p, device):
    """Calibration-error estimate for single-column predictions.

    Args:
        f: 2-D tensor with exactly one column.
        y: labels; entries equal to 1 are the positives.
        bandwidth: kernel bandwidth.
        p: exponent applied to the absolute error.
        device: device on which the diagonal mask of the kernel is placed.

    Returns:
        The result of ``get_kde_for_ece`` on the masked log-kernel of ``f``.

    Raises:
        AssertionError: if ``f`` does not have exactly one column.
    """
    assert f.shape[1] == 1

    return get_kde_for_ece(f, y, get_kernel(f, bandwidth, device), p)


def get_ratio_marginal(f, y, bandwidth, p, device):
    """Sum the binary calibration estimate over every column of ``f``.

    Column ``c`` is scored one-vs-rest: its predictions ``f[:, c]`` are paired
    with the 0/1 labels ``y == c``.

    Args:
        f: 2-D tensor of predictions.
        y: class labels compared against each column index.
        bandwidth: kernel bandwidth.
        p: exponent applied to the absolute error.
        device: device on which the diagonal mask of the kernel is placed.

    Returns:
        The sum of the per-column estimates (the int ``0`` if ``f`` has no
        columns).
    """
    return sum(
        get_ratio_binary(
            f[:, cls].unsqueeze(-1), (y == cls).long(), bandwidth, p, device
        ).squeeze()
        for cls in range(f.shape[1])
    )


def get_ratio_marginal_vect(f, y, bandwidth, p, device):
    """Vectorised marginal calibration estimate over all columns at once.

    Args:
        f: 2-D tensor of predictions.
        y: integer class labels (required by ``one_hot``).
        bandwidth: kernel bandwidth.
        p: exponent applied to the absolute error.
        device: device on which the diagonal mask is placed.

    Returns:
        The result of ``get_kde_for_ece_vect`` on the per-column Beta
        log-kernels with masked diagonals.
    """
    n_classes = f.shape[1]
    labels_onehot = nn.functional.one_hot(y, num_classes=n_classes).to(torch.float32)
    per_class_log_kern = beta_kernel(f, f, bandwidth).squeeze()

    # Same diagonal mask for every class, stacked along a trailing class axis.
    most_negative = torch.finfo(torch.float).min
    diag_mask = torch.diag(most_negative * torch.ones(len(f))).to(device)
    stacked_mask = torch.stack([diag_mask for _ in range(n_classes)], dim=2)

    return get_kde_for_ece_vect(f, labels_onehot, per_class_log_kern + stacked_mask, p)


def get_kde_for_ece_vect(f, y, log_kern, p):
    """Kernel estimate of the per-column calibration error, summed over columns.

    Args:
        f: predictions, subtracted from the kernel ratio elementwise.
        y: 0/1 label mask broadcastable against ``log_kern`` (the caller in
            this file passes a one-hot float tensor).
        log_kern: log-kernel values; reduced over ``dim=1``.
        p: exponent applied to the absolute error.

    Returns:
        Sum over columns of the mean (over ``dim=0``) of
        ``|num / den - f| ** p``, where ``num`` sums the kernel over entries
        with a non-zero label and ``den`` sums it over all entries (clamped
        from below by ``eps``).
    """
    # Select entries by the label mask itself. The previous
    # multiply-then-compare-to-zero approach also discarded genuine log-kernel
    # values of exactly 0 and produced NaN for -inf * 0.
    # Trick: the most negative float stands in for -inf in log space.
    masked_out = torch.full_like(log_kern, torch.finfo(torch.float).min)
    log_kern_y = torch.where(y != 0, log_kern, masked_out)

    log_num = torch.logsumexp(log_kern_y, dim=1)
    log_den = torch.logsumexp(log_kern, dim=1)

    num = torch.exp(log_num)
    den = torch.exp(log_den)
    # to avoid division by 0
    den = torch.clamp(den, min=eps)

    ratio = num / den
    ratio = torch.abs(ratio - f) ** p

    return torch.sum(torch.mean(ratio, dim=0))


def get_ratio_toplabel(f, y, bandwidth, p, device):
    """Top-label calibration estimate.

    Reduces the problem to the binary case: the prediction is the row-wise
    maximum of ``f`` and the label is 1 where ``y`` equals the arg-max index.

    Args:
        f: 2-D tensor of predictions.
        y: class labels compared against the arg-max index.
        bandwidth: kernel bandwidth.
        p: exponent applied to the absolute error.
        device: device on which the diagonal mask of the kernel is placed.

    Returns:
        The result of ``get_ratio_binary`` on the reduced problem.
    """
    top_conf, top_class = f.max(dim=1)
    is_correct = (y == top_class).to(torch.int)
    return get_ratio_binary(top_conf.unsqueeze(-1), is_correct, bandwidth, p, device)


def get_ratio_canonical(f, y, bandwidth, p, device):
    """Canonical calibration estimate using the full prediction vector.

    Args:
        f: 2-D tensor of predictions.
        y: integer class labels (required by ``one_hot``).
        bandwidth: kernel bandwidth.
        p: exponent applied to the absolute error.
        device: device on which the diagonal mask of the kernel is placed.

    Returns:
        Mean over rows of ``sum_c |kern @ onehot(y) / sum(kern) - f| ** p``.
    """
    kernel = torch.exp(get_kernel(f, bandwidth, device))
    labels_onehot = nn.functional.one_hot(y, num_classes=f.shape[1]).to(torch.float32)

    # Kernel-weighted label sums per row, normalised by the row's kernel mass.
    weighted_labels = kernel @ labels_onehot
    # Clamp to avoid division by 0.
    row_mass = torch.clamp(kernel.sum(dim=1), min=eps)

    estimate = weighted_labels / row_mass.unsqueeze(-1)
    per_row_error = (torch.abs(estimate - f) ** p).sum(dim=1)
    return per_row_error.mean()


def get_kde_for_ece(f, y, log_kern, p):
    """Kernel estimate of the binary calibration error.

    Args:
        f: predictions; squeezed before use.
        y: labels; entries equal to 1 are the positives.
        log_kern: 2-D log-kernel matrix.
        p: exponent applied to the absolute error.

    Returns:
        * no positives: ``sum(|f| ** p) / N``;
        * exactly one positive: that sample's row is removed, the remaining
          errors are summed, ``f_pos ** p`` is added for the removed sample
          and the total is divided by ``N``;
        * otherwise: the mean of ``|num / den - f| ** p``, where ``num`` sums
          the kernel over positive columns and ``den`` over all columns
          (clamped from below by ``eps``).
    """
    f = f.squeeze()
    n_samples = len(f)
    positives = torch.where(y == 1)[0]
    n_positives = positives.numel()

    if n_positives == 0:
        return torch.sum(torch.abs(-f) ** p) / n_samples

    single_positive = n_positives == 1
    if single_positive:
        # Drop the lone positive's own row (original note: "because of -inf
        # in the vector") and remember its prediction for the final sum.
        log_kern = torch.cat((log_kern[:positives], log_kern[positives + 1:]))
        lone_prediction = f[positives]
        f = torch.cat((f[:positives], f[positives + 1:]))

    log_kern_pos = torch.index_select(log_kern, 1, positives)
    numerator = torch.exp(torch.logsumexp(log_kern_pos, dim=1))
    denominator = torch.clamp(torch.exp(torch.logsumexp(log_kern, dim=1)), min=eps)

    error = torch.abs(numerator / denominator - f) ** p

    if single_positive:
        return (error.sum() + lone_prediction ** p) / n_samples
    return torch.mean(error)


def beta_kernel(z, zi, bandwidth=0.1):
    """Evaluate log Beta densities, parameterised by ``zi``, at the points ``z``.

    The Beta parameters are ``p = zi / bandwidth + 1`` and
    ``q = (1 - zi) / bandwidth + 1``. ``z`` gets an extra axis at position -2
    so that it broadcasts against ``zi``.

    Args:
        z: tensor of evaluation points in [0, 1].
        zi: tensor of kernel centres in [0, 1].
        bandwidth: positive kernel bandwidth.

    Returns:
        Tensor of log-densities with the broadcast shape of
        ``z.unsqueeze(-2)`` and ``zi``. Points on the boundary (0 or 1) give
        ``-inf`` when the matching exponent is positive and a finite value
        when it is zero.
    """
    p = zi / bandwidth + 1
    q = (1 - zi) / bandwidth + 1
    z = z.unsqueeze(-2)

    log_beta = torch.lgamma(p) + torch.lgamma(q) - torch.lgamma(p + q)
    # xlogy(a, x) equals a * log(x) but is defined as 0 when a == 0. A plain
    # product evaluates 0 * log(0) = 0 * -inf = NaN whenever a prediction is
    # exactly 0 or 1, which then poisons every downstream logsumexp.
    log_num = torch.xlogy(p - 1, z) + torch.xlogy(q - 1, 1 - z)

    return log_num - log_beta


def dirichlet_kernel(z, bandwidth=0.1):
    """Pairwise log Dirichlet densities between the rows of ``z``.

    Row ``j`` of ``z`` defines the concentration ``z[j] / bandwidth + 1``;
    entry ``[i, j]`` of the result is the log-density of that Dirichlet
    evaluated at ``z[i]``.

    Args:
        z: 2-D tensor whose rows are the points / kernel centres.
        bandwidth: positive kernel bandwidth.

    Returns:
        Square tensor of log-densities with one row and column per row of
        ``z``.
    """
    concentration = z / bandwidth + 1

    # Log normalising constant of each row's Dirichlet.
    log_norm = torch.lgamma(concentration).sum(dim=1) - torch.lgamma(concentration.sum(dim=1))
    # sum_k (alpha_jk - 1) * log z_ik for every pair (i, j).
    log_unnorm = torch.log(z) @ (concentration - 1).T

    return log_unnorm - log_norm


def get_bandwidth(b, f, device):
    """Resolve a bandwidth setting to a usable value.

    Args:
        b: either the string ``'auto'`` or something convertible by ``float``.
        f: predictions, only used for automatic selection.
        device: device, only used for automatic selection.

    Returns:
        The result of ``select_bandwidth(f, device)`` when ``b == 'auto'``,
        otherwise ``float(b)``.
    """
    return select_bandwidth(f, device) if b == 'auto' else float(b)


def select_bandwidth(f, device):
    """Pick the bandwidth that maximises the summed leave-one-out log-density.

    Candidates are 30 log-spaced values from 1e-5 to 1e-1 followed by 5
    linearly spaced values from 0.2 to 1. For each candidate ``b`` the score
    is ``sum_i [logsumexp_j log_kern[i, j] - log((n - 1) * b)]``, with
    ``log_kern`` from ``get_kernel`` (diagonal masked out).

    Args:
        f: 2-D tensor of predictions.
        device: device on which the diagonal mask of the kernel is placed.

    Returns:
        The best-scoring candidate as a 0-dim tensor, or ``-1`` if no
        candidate produced a score greater than ``-inf`` (e.g. all NaN).
    """
    bandwidths = torch.cat((torch.logspace(start=-5, end=-1, steps=30), torch.linspace(0.2, 1, steps=5)))
    best_b = -1
    # Log-likelihoods are frequently negative, so the running best must start
    # at -inf; starting at 0 rejected every candidate in that case and
    # returned the invalid bandwidth -1.
    best_log_lik = float("-inf")
    n = len(f)
    for b in bandwidths:
        log_kern = get_kernel(f, b, device)
        log_fhat = torch.logsumexp(log_kern, 1) - torch.log((n - 1) * b)
        log_lik = torch.sum(log_fhat)
        if log_lik > best_log_lik:
            best_log_lik = log_lik
            best_b = b

    return best_b


def check_input(f, bandwidth):
    """Assert that ``f`` and ``bandwidth`` are usable by the kernel estimators.

    Checks, in order: ``f`` contains no NaN, ``f`` is 2-D, ``bandwidth`` is
    positive, and all values of ``f`` lie in [0, 1].

    Raises:
        AssertionError: if any of the checks fails.
    """
    contains_nan = isnan(f)
    assert not contains_nan
    assert f.dim() == 2
    assert bandwidth > 0
    smallest = torch.min(f)
    assert smallest >= 0
    largest = torch.max(f)
    assert largest <= 1


def isnan(a):
    """Return a 0-dim bool tensor that is True if any element of ``a`` is NaN."""
    return torch.isnan(a).any()


def isinf(a):
    """Return a 0-dim bool tensor that is True if any element of ``a`` is infinite."""
    return torch.isinf(a).any()
