import torch
from uce import nentr


def torch_kernel(matrix, width=0.2):  # Laplacian kernel
    """Evaluate a Laplacian kernel on the value pairs stored in ``matrix``.

    Args:
        matrix: Tensor indexed as ``matrix[i, j, c]``; ``c == 0`` and
            ``c == 1`` hold the two values of pair ``(i, j)``.
        width: Kernel width. The absolute difference of each pair is divided
            by ``2 * width`` before exponentiation. The default of 0.2 is the
            value that used to be hard-coded, so existing callers are
            unaffected.

    Returns:
        Tensor with the first two dimensions of ``matrix`` containing
        ``exp(-|a - b| / (2 * width))`` for every pair ``(a, b)``.
    """
    first = matrix[:, :, 0]
    second = matrix[:, :, 1]
    return torch.exp(-1.0 * torch.abs(first - second) / (2 * width))


def get_out_tensor_torch(tensor1, tensor2):
    """Return the mean of the element-wise product of ``tensor1`` and ``tensor2``."""
    product = tensor1 * tensor2
    return product.mean()


def get_pairs(tensor1, tensor2):
    """Build all pairwise value grids within and across two tensors.

    Each returned tensor has a trailing dimension of size 2 holding one pair.

    Args:
        tensor1: Values of the first group (flattened to a column via ``view``).
        tensor2: Values of the second group (flattened the same way).

    Returns:
        A 3-tuple ``(first_pairs, second_pairs, cross_pairs)`` where
        ``first_pairs[i, j] == (tensor1[i], tensor1[j])``,
        ``second_pairs[i, j] == (tensor2[i], tensor2[j])`` and
        ``cross_pairs[i, j] == (tensor1[i], tensor2[j])``.
    """
    first_col = tensor1.view(-1, 1)
    second_col = tensor2.view(-1, 1)
    n_first = tensor1.size(0)
    n_second = tensor2.size(0)

    # Within-group grids: a square tiling paired with its own transpose.
    first_grid = first_col.repeat([1, n_first])
    second_grid = second_col.repeat([1, n_second])
    first_pairs = torch.stack((first_grid, first_grid.t()), dim=2)
    second_pairs = torch.stack((second_grid, second_grid.t()), dim=2)

    # Cross-group grid: rows follow tensor1, columns follow tensor2.
    cross_rows = first_col.repeat([1, n_second])
    cross_cols = second_col.repeat([1, n_first]).t()
    cross_pairs = torch.stack((cross_rows, cross_cols), dim=2)

    return first_pairs, second_pairs, cross_pairs


def mmce_loss(softmaxes, correct_labels):
    """Compute the MMCE_m loss.

    Args:
        softmaxes: Tensor indexed as (sample, class); the maximum and argmax
            are taken over dimension 1. Presumably softmax probabilities --
            the code does not verify this.
        correct_labels: Ground-truth class index per sample, compared
            element-wise with the argmax of ``softmaxes``.

    Returns:
        0-dim tensor: the sum over all sample pairs ``(i, j)`` of
        ``(c_i - r_i) * (c_j - r_j) * kernel(r_i, r_j)`` divided by the
        squared number of samples, where ``r`` is the top confidence and
        ``c`` is 1 for a correct prediction and 0 otherwise.
    """
    device = softmaxes.device
    pred_labels = torch.argmax(softmaxes, 1)
    confidence = torch.max(softmaxes, 1)[0]

    # 1.0 where the predicted class matches the label, 0.0 elsewhere.
    hits = pred_labels == correct_labels
    correct_mask = torch.where(hits,
                               torch.ones(pred_labels.size(), device=device),
                               torch.zeros(pred_labels.size(), device=device))

    # Outer product of the per-sample (correctness - confidence) residuals.
    residual = (correct_mask.float() - confidence).view(-1, 1)
    residual_outer = torch.matmul(residual, residual.t())

    # Grid of all confidence pairs, shaped (n, n, 2) for the kernel helper.
    conf_grid = confidence.view(-1, 1).repeat(1, confidence.size(0))
    conf_pairs = torch.stack((conf_grid, conf_grid.t()), dim=2)

    weighted = residual_outer * torch_kernel(conf_pairs)
    sample_count = torch.tensor(correct_mask.size(0))
    return torch.sum(weighted) / torch.square(sample_count)


def calibration_mmce_w_loss(softmaxes, correct_labels):
    """Compute the MMCE_w loss.

    Samples are split into correctly and incorrectly classified groups by
    comparing ``argmax(softmaxes, 1)`` with ``correct_labels``. Kernel values
    between the top confidences of the groups are averaged with weights
    ``1 - confidence`` for correct samples and ``confidence`` for incorrect
    ones, then combined into an MMD-style error.

    Args:
        softmaxes: Tensor indexed as (sample, class); the maximum and argmax
            are taken over dimension 1. Presumably softmax probabilities --
            the code does not verify this.
        correct_labels: Ground-truth class index per sample (assumed 1-D with
            one entry per row of ``softmaxes`` -- confirm against callers).

    Returns:
        0-dim tensor ``max(gate * sqrt(mmd_error + 1e-10), 0)`` where ``gate``
        is 1 when both groups are non-empty and 0 otherwise. The result stays
        attached to the autograd graph of ``softmaxes``.
    """
    confidence, _ = torch.max(softmaxes, 1)
    predicted_labels = torch.argmax(softmaxes, 1)

    # 1 where the prediction is right, 0 otherwise (dtype of correct_labels).
    correct_mask = torch.where(predicted_labels == correct_labels,
                               torch.ones_like(correct_labels),
                               torch.zeros_like(correct_labels))

    m = torch.sum(correct_mask)        # number of correct predictions
    n = torch.sum(1.0 - correct_mask)  # number of incorrect predictions
    num_correct = int(m.item())
    num_incorrect = int(n.item())
    both_groups = num_correct > 0 and num_incorrect > 0

    if both_groups:
        k, k_p = num_correct, num_incorrect
    else:
        # One group is empty: use placeholder sizes so the tensor ops below
        # still run; the result is multiplied by a zero gate at the end.
        k, k_p = 2, correct_mask.size(0) - 2

    # Masked-out entries are zero, so top-k picks the group's confidences
    # as long as those are positive.
    correct_prob, _ = torch.topk(confidence * correct_mask, k)
    incorrect_prob, _ = torch.topk(confidence * (1 - correct_mask), k_p)

    correct_prob_pairs, incorrect_prob_pairs, correct_incorrect_pairs = \
        get_pairs(correct_prob, incorrect_prob)
    correct_kernel = torch_kernel(correct_prob_pairs)
    incorrect_kernel = torch_kernel(incorrect_prob_pairs)
    correct_incorrect_kernel = torch_kernel(correct_incorrect_pairs)

    # Per-sample weights as column vectors; outer products give pair weights.
    correct_weight = (1.0 - correct_prob).view(-1, 1)
    incorrect_weight = incorrect_prob.view(-1, 1)

    correct_correct_vals = get_out_tensor_torch(
        correct_kernel, torch.matmul(correct_weight, correct_weight.t()))
    incorrect_incorrect_vals = get_out_tensor_torch(
        incorrect_kernel, torch.matmul(incorrect_weight, incorrect_weight.t()))
    correct_incorrect_vals = get_out_tensor_torch(
        correct_incorrect_kernel, torch.matmul(correct_weight, incorrect_weight.t()))

    # The 1e-5 terms keep the divisions finite when a group is empty.
    mmd_error = 1.0 / (m * m + 1e-5) * correct_correct_vals
    mmd_error = mmd_error + 1.0 / (n * n + 1e-5) * incorrect_incorrect_vals
    mmd_error = mmd_error - 2.0 / (m * n + 1e-5) * correct_incorrect_vals

    gate = 1.0 if both_groups else 0.0
    # Clamp the graph-connected value directly. Wrapping it in
    # torch.tensor([...]) would copy it into a new tensor without autograd
    # history, so no gradient could reach softmaxes.
    return torch.clamp(gate * torch.sqrt(mmd_error + 1e-10), min=0.0)


def mmce_loss_entr(softmaxes, correct_labels):
    """Compute the MMCE_m loss with normalized entropy as the confidence source.

    The per-sample confidence is ``1 - nentr(softmaxes, base=num_classes)``
    instead of the top softmax value. ``nentr`` is imported from ``uce``;
    presumably it returns one normalized-entropy value per sample -- confirm
    against that module.

    Args:
        softmaxes: Tensor indexed as (sample, class); argmax is taken over
            dimension 1 and ``size(1)`` is passed to ``nentr`` as ``base``.
        correct_labels: Ground-truth class index per sample, compared
            element-wise with the argmax of ``softmaxes``.

    Returns:
        0-dim tensor: the kernel-weighted sum of pairwise residual products
        divided by the squared number of samples.
    """
    device = softmaxes.device
    pred_labels = torch.argmax(softmaxes, dim=1)

    entropy = nentr(softmaxes, base=softmaxes.size(1))
    confidence = torch.ones_like(entropy) - entropy

    # 1.0 where the predicted class matches the label, 0.0 elsewhere.
    hits = pred_labels == correct_labels
    correct_mask = torch.where(hits,
                               torch.ones(pred_labels.size(), device=device),
                               torch.zeros(pred_labels.size(), device=device))

    # Outer product of the per-sample (correctness - confidence) residuals.
    residual = (correct_mask.float() - confidence).view(-1, 1)
    residual_outer = torch.matmul(residual, residual.t())

    # Laplacian kernel on every confidence pair: grid[i, j] is sample i's
    # confidence and its transpose supplies sample j's.
    conf_grid = confidence.view(-1, 1).repeat(1, confidence.size(0))
    kernel_vals = torch.exp(-1.0 * torch.abs(conf_grid - conf_grid.t()) / (2 * 0.2))

    weighted = residual_outer * kernel_vals
    sample_count = torch.tensor(correct_mask.size(0))
    return torch.sum(weighted) / torch.square(sample_count)


def calibration_mmce_w_loss_entr(softmaxes, correct_labels):
    """Compute the MMCE_w loss with normalized entropy as the confidence source.

    The per-sample confidence is ``1 - nentr(softmaxes, base=num_classes)``
    instead of the top softmax value. ``nentr`` is imported from ``uce``;
    presumably it returns one normalized-entropy value per sample -- confirm
    against that module. Samples are split into correctly and incorrectly
    classified groups, and kernel values between their confidences are
    averaged with weights ``1 - confidence`` (correct) and ``confidence``
    (incorrect) to form an MMD-style error.

    Args:
        softmaxes: Tensor indexed as (sample, class); argmax is taken over
            dimension 1 and ``size(1)`` is passed to ``nentr`` as ``base``.
        correct_labels: Ground-truth class index per sample (assumed 1-D with
            one entry per row of ``softmaxes`` -- confirm against callers).

    Returns:
        0-dim tensor ``max(gate * sqrt(mmd_error + 1e-10), 0)`` where ``gate``
        is 1 when both groups are non-empty and 0 otherwise. The result stays
        attached to the autograd graph of the confidence values.
    """
    predicted_labels = torch.argmax(softmaxes, 1)

    entropy = nentr(softmaxes, base=softmaxes.size(1))
    confidence = torch.ones_like(entropy) - entropy

    # 1 where the prediction is right, 0 otherwise (dtype of correct_labels).
    correct_mask = torch.where(predicted_labels == correct_labels,
                               torch.ones_like(correct_labels),
                               torch.zeros_like(correct_labels))

    m = torch.sum(correct_mask)        # number of correct predictions
    n = torch.sum(1.0 - correct_mask)  # number of incorrect predictions
    num_correct = int(m.item())
    num_incorrect = int(n.item())
    both_groups = num_correct > 0 and num_incorrect > 0

    if both_groups:
        k, k_p = num_correct, num_incorrect
    else:
        # One group is empty: use placeholder sizes so the tensor ops below
        # still run; the result is multiplied by a zero gate at the end.
        k, k_p = 2, correct_mask.size(0) - 2

    # Masked-out entries are zero, so top-k picks the group's confidences
    # as long as those are positive.
    correct_prob, _ = torch.topk(confidence * correct_mask, k)
    incorrect_prob, _ = torch.topk(confidence * (1 - correct_mask), k_p)

    correct_prob_pairs, incorrect_prob_pairs, correct_incorrect_pairs = \
        get_pairs(correct_prob, incorrect_prob)
    correct_kernel = torch_kernel(correct_prob_pairs)
    incorrect_kernel = torch_kernel(incorrect_prob_pairs)
    correct_incorrect_kernel = torch_kernel(correct_incorrect_pairs)

    # Per-sample weights as column vectors; outer products give pair weights.
    correct_weight = (1.0 - correct_prob).view(-1, 1)
    incorrect_weight = incorrect_prob.view(-1, 1)

    correct_correct_vals = get_out_tensor_torch(
        correct_kernel, torch.matmul(correct_weight, correct_weight.t()))
    incorrect_incorrect_vals = get_out_tensor_torch(
        incorrect_kernel, torch.matmul(incorrect_weight, incorrect_weight.t()))
    correct_incorrect_vals = get_out_tensor_torch(
        correct_incorrect_kernel, torch.matmul(correct_weight, incorrect_weight.t()))

    # The 1e-5 terms keep the divisions finite when a group is empty.
    mmd_error = 1.0 / (m * m + 1e-5) * correct_correct_vals
    mmd_error = mmd_error + 1.0 / (n * n + 1e-5) * incorrect_incorrect_vals
    mmd_error = mmd_error - 2.0 / (m * n + 1e-5) * correct_incorrect_vals

    gate = 1.0 if both_groups else 0.0
    # Clamp the graph-connected value directly. Wrapping it in
    # torch.tensor([...]) would copy it into a new tensor without autograd
    # history, so no gradient could reach the model outputs.
    return torch.clamp(gate * torch.sqrt(mmd_error + 1e-10), min=0.0)
