{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "WtMbwLcp41AG"
   },
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "from tensorflow.compat.v1 import to_float, to_int64, to_int32\n",
    "import numpy as np\n",
    "import torch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "x4h02Pr944et"
   },
   "outputs": [],
   "source": [
    "def tf_kernel(matrix):\n",
    "    \"\"\"Laplacian kernel exp(-|a - b| / 0.4) on the pairs in matrix[:, :, 0] and matrix[:, :, 1].\"\"\"\n",
    "    left, right = matrix[:, :, 0], matrix[:, :, 1]\n",
    "    return tf.exp(-1.0*tf.abs(left - right)/(2*0.2))\n",
    "\n",
    "def get_out_tensor(tensor1, tensor2):\n",
    "    \"\"\"Mean of the element-wise product of tensor1 and tensor2.\"\"\"\n",
    "    product = tensor1*tensor2\n",
    "    return tf.reduce_mean(product)\n",
    "\n",
    "def calibration_unbiased_loss(logits, correct_labels):\n",
    "    \"\"\"Compute the MMCE_m calibration loss (TensorFlow reference).\n",
    "\n",
    "    Despite the parameter name, `logits` is used directly as class\n",
    "    probabilities (no softmax is applied here), so pass softmax outputs.\n",
    "    Returns sum_ij (c_i - r_i)(c_j - r_j) k(r_i, r_j) / N^2, where r is the\n",
    "    max probability of each row, c the 0/1 correctness indicator and k = tf_kernel.\n",
    "    \"\"\"\n",
    "    confidences = tf.reduce_max(logits, 1)\n",
    "    pred_labels = tf.argmax(logits, 1)\n",
    "    correct_mask = tf.where(tf.equal(pred_labels, correct_labels),\n",
    "                            tf.ones(tf.shape(pred_labels)),\n",
    "                            tf.zeros(tf.shape(pred_labels)))\n",
    "    # Column vector (c - r); its outer product with itself is the (N, N) weight matrix.\n",
    "    residual = tf.expand_dims(to_float(correct_mask) - confidences, 1)\n",
    "    dot_product = tf.matmul(residual, tf.transpose(residual))\n",
    "\n",
    "    # conf_pairs[i, j] = (r_i, r_j), the layout tf_kernel expects.\n",
    "    batch = tf.shape(confidences)[0]\n",
    "    conf_tiled = tf.expand_dims(tf.tile(tf.expand_dims(confidences, 1), [1, batch]), 2)\n",
    "    conf_pairs = tf.concat([conf_tiled, tf.transpose(conf_tiled, [1, 0, 2])], axis=2)\n",
    "\n",
    "    weighted = dot_product*tf_kernel(conf_pairs)\n",
    "    return tf.reduce_sum(weighted)/tf.square(to_float(tf.shape(correct_mask)[0]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def calibration_mmce_w_loss(softmaxes, correct_labels):\n",
    "    \"\"\"Function to compute the MMCE_w loss.\n",
    "\n",
    "    TensorFlow reference of the weighted MMCE. The confidence of a sample is its\n",
    "    max softmax probability; correct samples are weighted by (1 - confidence)\n",
    "    and incorrect ones by confidence.\n",
    "\n",
    "    Args:\n",
    "        softmaxes: class probabilities, assumed shape (N, C) -- TODO confirm.\n",
    "        correct_labels: labels compared against tf.argmax(softmaxes, 1) (int64).\n",
    "    Returns:\n",
    "        Scalar tensor; 0 when every sample is correct or every sample is wrong.\n",
    "    \"\"\"\n",
    "    predicted_probs = softmaxes\n",
    "    predicted_labels = tf.argmax(predicted_probs, axis=1)\n",
    "    predicted_probs = tf.reduce_max(predicted_probs, 1)\n",
    "\n",
    "    correct_mask = tf.where(tf.equal(correct_labels, predicted_labels),\n",
    "                          tf.ones(tf.shape(correct_labels)),\n",
    "                          tf.zeros(tf.shape(correct_labels)))\n",
    "\n",
    "    k = to_int32(tf.reduce_sum(correct_mask))\n",
    "    k_p = to_int32(tf.reduce_sum(1.0 - correct_mask))\n",
    "    cond_k = tf.where(tf.equal(k, 0), 0, 1)\n",
    "    cond_k_p = tf.where(tf.equal(k_p, 0), 0, 1)\n",
    "    # If either group is empty, fall back to placeholder sizes 2 and N - 2 so the\n",
    "    # graph stays computable; the result is multiplied by cond_k*cond_k_p = 0 below.\n",
    "    k = tf.maximum(k, 1)*cond_k*cond_k_p + (1 - cond_k*cond_k_p)*2\n",
    "    k_p = tf.maximum(k_p, 1)*cond_k_p*cond_k + ((1 - cond_k_p*cond_k)*\n",
    "                                            (tf.shape(correct_mask)[0] - 2))\n",
    "    # Masked-out entries are 0, so for non-negative probabilities the top-k values\n",
    "    # are exactly the confidences of the correct (resp. incorrect) samples.\n",
    "    correct_prob, _ = tf.nn.top_k(predicted_probs*correct_mask, k)\n",
    "    incorrect_prob, _ = tf.nn.top_k(predicted_probs*(1 - correct_mask), k_p)\n",
    "\n",
    "    def get_pairs(tensor1, tensor2):\n",
    "        correct_prob_tiled = tf.expand_dims(tf.tile(tf.expand_dims(tensor1, 1),\n",
    "                          [1, tf.shape(tensor1)[0]]), 2)\n",
    "        incorrect_prob_tiled = tf.expand_dims(tf.tile(tf.expand_dims(tensor2, 1),\n",
    "                          [1, tf.shape(tensor2)[0]]), 2)\n",
    "        correct_prob_pairs = tf.concat([correct_prob_tiled,\n",
    "                         tf.transpose(correct_prob_tiled, [1, 0, 2])],\n",
    "                         axis=2)\n",
    "        incorrect_prob_pairs = tf.concat([incorrect_prob_tiled,\n",
    "                       tf.transpose(incorrect_prob_tiled, [1, 0, 2])],\n",
    "                       axis=2)\n",
    "        correct_prob_tiled_1 = tf.expand_dims(tf.tile(tf.expand_dims(tensor1, 1),\n",
    "                            [1, tf.shape(tensor2)[0]]), 2)\n",
    "        incorrect_prob_tiled_1 = tf.expand_dims(tf.tile(tf.expand_dims(tensor2, 1),\n",
    "                            [1, tf.shape(tensor1)[0]]), 2)\n",
    "        correct_incorrect_pairs = tf.concat([correct_prob_tiled_1,\n",
    "                      tf.transpose(incorrect_prob_tiled_1, [1, 0, 2])],\n",
    "                      axis=2)\n",
    "        return correct_prob_pairs, incorrect_prob_pairs, correct_incorrect_pairs\n",
    "\n",
    "    correct_prob_pairs, incorrect_prob_pairs, correct_incorrect_pairs = (\n",
    "        get_pairs(correct_prob, incorrect_prob))\n",
    "    correct_kernel = tf_kernel(correct_prob_pairs)\n",
    "    incorrect_kernel = tf_kernel(incorrect_prob_pairs)\n",
    "    correct_incorrect_kernel = tf_kernel(correct_incorrect_pairs)\n",
    "    sampling_weights_correct = tf.matmul(tf.expand_dims(1.0 - correct_prob, 1),\n",
    "                           tf.transpose(tf.expand_dims(1.0 - correct_prob, 1)))\n",
    "    correct_correct_vals = get_out_tensor(correct_kernel,\n",
    "                                                    sampling_weights_correct)\n",
    "    sampling_weights_incorrect = tf.matmul(tf.expand_dims(incorrect_prob, 1),\n",
    "                           tf.transpose(tf.expand_dims(incorrect_prob, 1)))\n",
    "    incorrect_incorrect_vals = get_out_tensor(incorrect_kernel,\n",
    "                                                    sampling_weights_incorrect)\n",
    "    sampling_correct_incorrect = tf.matmul(tf.expand_dims(1.0 - correct_prob, 1),\n",
    "                           tf.transpose(tf.expand_dims(incorrect_prob, 1)))\n",
    "    correct_incorrect_vals = get_out_tensor(correct_incorrect_kernel,\n",
    "                                                    sampling_correct_incorrect)\n",
    "    m = tf.reduce_sum(correct_mask)\n",
    "    n = tf.reduce_sum(1.0 - correct_mask)\n",
    "    mmd_error = 1.0/(m*m + 1e-5) * tf.reduce_sum(correct_correct_vals)\n",
    "    mmd_error += 1.0/(n*n + 1e-5) * tf.reduce_sum(incorrect_incorrect_vals)\n",
    "    mmd_error -= 2.0/(m*n + 1e-5) * tf.reduce_sum(correct_incorrect_vals)\n",
    "    # The estimate can dip slightly below 0 (rounding, and the 1e-5 terms in the\n",
    "    # normalisers); clamp so tf.sqrt never returns NaN.\n",
    "    mmd_error = tf.maximum(mmd_error, 0.0)\n",
    "    return tf.maximum(tf.stop_gradient(to_float(cond_k*cond_k_p))*\n",
    "                      tf.sqrt(mmd_error + 1e-10), 0.0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def nentr(p, base=None):\n",
    "    \"\"\"\n",
    "    Calculates entropy of p to the base b. If base is None, the natural logarithm is used.\n",
    "    :param p: batches of class label probability distributions (softmax output)\n",
    "    :param base: base b\n",
    "    :return:\n",
    "    \"\"\"\n",
    "    eps = torch.tensor([1e-16], device=p.device)\n",
    "    if base:\n",
    "        base = torch.tensor([base], device=p.device, dtype=torch.float32)\n",
    "        return (p.mul(p.add(eps).log().div(base.log()))).sum(dim=1).abs()\n",
    "    else:\n",
    "        return (p.mul(p.add(eps).log())).sum(dim=1).abs()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "gg8GlZwf-nwd"
   },
   "outputs": [],
   "source": [
    "def torch_kernel(matrix):  # Laplacian kernel\n",
    "    return torch.exp(-1.0*torch.abs(matrix[:, :, 0] - matrix[:, :, 1])/(2*0.2))\n",
    "\n",
    "def get_out_tensor_torch(tensor1, tensor2):\n",
    "    return torch.mean(tensor1*tensor2)\n",
    "\n",
    "def calibration_unbiased_loss_torch(softmaxes, correct_labels):\n",
    "    \"\"\"Function to compute MMCE_m loss.\"\"\"  \n",
    "    predicted_probs = softmaxes\n",
    "    pred_labels = torch.argmax(predicted_probs, 1).detach()\n",
    "    predicted_probs, _ = torch.max(predicted_probs, 1)\n",
    "\n",
    "    correct_mask = torch.where(pred_labels == correct_labels,\n",
    "                             torch.ones_like(pred_labels),\n",
    "                             torch.zeros_like(pred_labels))\n",
    "    c_minus_r = correct_mask.float() - predicted_probs\n",
    "    dot_product = torch.matmul(c_minus_r.view(-1, 1),\n",
    "                             torch.transpose(c_minus_r.view(-1, 1), 1, 0))\n",
    "    tensor1 = predicted_probs\n",
    "    prob_tiled = tensor1.view(-1, 1).repeat(1, tensor1.size(0)).unsqueeze(2)\n",
    "    prob_pairs = torch.cat([prob_tiled, prob_tiled.permute(1, 0, 2)], axis=2)\n",
    "\n",
    "    kernel_prob_pairs = torch_kernel(prob_pairs)\n",
    "    numerator = dot_product*kernel_prob_pairs\n",
    "    return torch.sum(numerator)/torch.square(torch.tensor(correct_mask.size(0)))\n",
    "\n",
    "def calibration_unbiased_loss_torch2(softmaxes, correct_labels):\n",
    "    \"\"\"Function to compute MMCE_m loss with normalized entropy.\"\"\"\n",
    "    d = softmaxes.device\n",
    "    predicted_probs = softmaxes\n",
    "    pred_labels = torch.argmax(predicted_probs, dim=1).detach()\n",
    "    predicted_probs = nentr(predicted_probs, base=softmaxes.size(1))\n",
    "    predicted_probs = torch.ones_like(predicted_probs) - predicted_probs\n",
    "\n",
    "    correct_mask = torch.where(pred_labels == correct_labels,\n",
    "                               torch.ones(pred_labels.size(), device=d),\n",
    "                               torch.zeros(pred_labels.size(), device=d))\n",
    "    c_minus_r = correct_mask.float() - predicted_probs\n",
    "    dot_product = torch.matmul(c_minus_r.view(-1, 1),\n",
    "                               torch.transpose(c_minus_r.view(-1, 1), 1, 0))\n",
    "    tensor1 = predicted_probs\n",
    "    prob_tiled = tensor1.view(-1, 1).repeat(1, tensor1.size(0)).unsqueeze(2)\n",
    "    prob_pairs = torch.cat([prob_tiled, prob_tiled.permute(1, 0, 2)], axis=2)\n",
    "\n",
    "    kernel_prob_pairs = torch_kernel(prob_pairs)\n",
    "    numerator = dot_product*kernel_prob_pairs\n",
    "    return torch.sum(numerator)/torch.square(torch.tensor(correct_mask.size(0)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "def calibration_mmce_w_loss_torch(softmaxes, correct_labels):\n",
    "    \"\"\"Function to compute the MMCE_w loss.\"\"\"\n",
    "    d = softmaxes.device\n",
    "    predicted_probs = softmaxes\n",
    "    range_index = torch.arange(0, predicted_probs.size(0)).view(-1, 1).long()\n",
    "    predicted_labels = torch.argmax(predicted_probs, 1).detach()\n",
    "    \n",
    "    gather_index = torch.cat([range_index, predicted_labels.view(-1, 1)], axis=1)\n",
    "    predicted_probs, _ = torch.max(predicted_probs, 1)\n",
    "\n",
    "    correct_mask = torch.where(predicted_labels == correct_labels,\n",
    "                               torch.ones(correct_labels.size()),\n",
    "                               torch.zeros(correct_labels.size()))\n",
    "\n",
    "    k = torch.sum(correct_mask)\n",
    "    k_p = torch.sum(1.0 - correct_mask)\n",
    "    \n",
    "    cond_k = torch.tensor(0, device=d) if k == 0 else torch.tensor(1, device=d)\n",
    "    cond_k_p = torch.tensor(0, device=d) if k_p == 0 else torch.tensor(1, device=d)\n",
    "    \n",
    "    k = torch.max(torch.tensor([k, 1], device=d))*cond_k*cond_k_p + (1 - cond_k*cond_k_p)*2\n",
    "    k_p = torch.max(torch.tensor([k_p, 1], device=d))*cond_k_p*cond_k + ((1 - cond_k_p*cond_k)*(correct_mask.size(0) - 2))\n",
    "    correct_prob, _ = torch.topk(predicted_probs*correct_mask, k.long().item())\n",
    "    incorrect_prob, _ = torch.topk(predicted_probs*(1 - correct_mask), k_p.long().item())\n",
    "\n",
    "    def get_pairs_t(tensor1, tensor2):\n",
    "        correct_prob_tiled = tensor1.view(-1, 1).repeat([1, tensor1.size(0)]).unsqueeze(2)\n",
    "        incorrect_prob_tiled = tensor2.view(-1, 1).repeat([1, tensor2.size(0)]).unsqueeze(2)\n",
    "        correct_prob_pairs = torch.cat([correct_prob_tiled, correct_prob_tiled.permute(1, 0, 2)], dim=2)\n",
    "        incorrect_prob_pairs = torch.cat([incorrect_prob_tiled, incorrect_prob_tiled.permute(1, 0, 2)], dim=2)\n",
    "        correct_prob_tiled_1 = tensor1.view(-1, 1).repeat([1, tensor2.size(0)]).unsqueeze(2)\n",
    "        incorrect_prob_tiled_1 = tensor2.view(-1, 1).repeat([1, tensor1.size(0)]).unsqueeze(2)\n",
    "        correct_incorrect_pairs = torch.cat([correct_prob_tiled_1, incorrect_prob_tiled_1.permute(1, 0, 2)],\n",
    "                                            dim=2)\n",
    "        return correct_prob_pairs, incorrect_prob_pairs, correct_incorrect_pairs\n",
    "\n",
    "    correct_prob_pairs, incorrect_prob_pairs,\\\n",
    "               correct_incorrect_pairs = get_pairs_t(correct_prob, incorrect_prob)\n",
    "    correct_kernel = torch_kernel(correct_prob_pairs)\n",
    "    incorrect_kernel = torch_kernel(incorrect_prob_pairs)\n",
    "    correct_incorrect_kernel = torch_kernel(correct_incorrect_pairs)  \n",
    "    sampling_weights_correct = torch.matmul((1.0 - correct_prob).view(-1, 1),\n",
    "                                            (1.0 - correct_prob).view(-1, 1).permute(1, 0))\n",
    "    correct_correct_vals = get_out_tensor_torch(correct_kernel,\n",
    "                                                sampling_weights_correct)\n",
    "    sampling_weights_incorrect = torch.matmul(incorrect_prob.view(-1, 1),\n",
    "                                              incorrect_prob.view(-1, 1).permute(1, 0))\n",
    "    incorrect_incorrect_vals = get_out_tensor_torch(incorrect_kernel,\n",
    "                                                    sampling_weights_incorrect)\n",
    "    sampling_correct_incorrect = torch.matmul((1.0 - correct_prob).view(-1, 1),\n",
    "                                              incorrect_prob.view(-1, 1).permute(1, 0))\n",
    "    correct_incorrect_vals = get_out_tensor_torch(correct_incorrect_kernel,\n",
    "                                                  sampling_correct_incorrect)\n",
    "    correct_denom = torch.sum(1.0 - correct_prob)\n",
    "    incorrect_denom = torch.sum(incorrect_prob)\n",
    "    m = torch.sum(correct_mask)\n",
    "    n = torch.sum(1.0 - correct_mask)\n",
    "    mmd_error = 1.0/(m*m + 1e-5) * torch.sum(correct_correct_vals) \n",
    "    mmd_error += 1.0/(n*n + 1e-5) * torch.sum(incorrect_incorrect_vals)\n",
    "    mmd_error -= 2.0/(m*n + 1e-5) * torch.sum(correct_incorrect_vals)\n",
    "    return torch.max(torch.tensor([(cond_k*cond_k_p).float().detach()*torch.sqrt(mmd_error + 1e-10), 0.0], device=d))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "def calibration_mmce_w_loss_torch2(softmaxes, correct_labels):\n",
    "    \"\"\"Function to compute the MMCE_w loss.\"\"\"\n",
    "    d = softmaxes.device\n",
    "    predicted_probs = softmaxes\n",
    "    range_index = torch.arange(0, predicted_probs.size(0)).view(-1, 1).long()\n",
    "    predicted_labels = torch.argmax(predicted_probs, 1).detach()\n",
    "    \n",
    "    gather_index = torch.cat([range_index, predicted_labels.view(-1, 1)], axis=1)\n",
    "    predicted_probs = nentr(predicted_probs, base=softmaxes.size(1))\n",
    "    predicted_probs = torch.ones_like(predicted_probs) - predicted_probs\n",
    "\n",
    "    correct_mask = torch.where(predicted_labels == correct_labels,\n",
    "                               torch.ones(correct_labels.size()),\n",
    "                               torch.zeros(correct_labels.size()))\n",
    "\n",
    "    k = torch.sum(correct_mask)\n",
    "    k_p = torch.sum(1.0 - correct_mask)\n",
    "    \n",
    "    cond_k = torch.tensor(0, device=d) if k == 0 else torch.tensor(1, device=d)\n",
    "    cond_k_p = torch.tensor(0, device=d) if k_p == 0 else torch.tensor(1, device=d)\n",
    "    \n",
    "    k = torch.max(torch.tensor([k, 1], device=d))*cond_k*cond_k_p + (1 - cond_k*cond_k_p)*2\n",
    "    k_p = torch.max(torch.tensor([k_p, 1], device=d))*cond_k_p*cond_k + ((1 - cond_k_p*cond_k)*(correct_mask.size(0) - 2))\n",
    "    correct_prob, _ = torch.topk(predicted_probs*correct_mask, k.long().item())\n",
    "    incorrect_prob, _ = torch.topk(predicted_probs*(1 - correct_mask), k_p.long().item())\n",
    "\n",
    "    def get_pairs_t(tensor1, tensor2):\n",
    "        correct_prob_tiled = tensor1.view(-1, 1).repeat([1, tensor1.size(0)]).unsqueeze(2)\n",
    "        incorrect_prob_tiled = tensor2.view(-1, 1).repeat([1, tensor2.size(0)]).unsqueeze(2)\n",
    "        correct_prob_pairs = torch.cat([correct_prob_tiled, correct_prob_tiled.permute(1, 0, 2)], dim=2)\n",
    "        incorrect_prob_pairs = torch.cat([incorrect_prob_tiled, incorrect_prob_tiled.permute(1, 0, 2)], dim=2)\n",
    "        correct_prob_tiled_1 = tensor1.view(-1, 1).repeat([1, tensor2.size(0)]).unsqueeze(2)\n",
    "        incorrect_prob_tiled_1 = tensor2.view(-1, 1).repeat([1, tensor1.size(0)]).unsqueeze(2)\n",
    "        correct_incorrect_pairs = torch.cat([correct_prob_tiled_1, incorrect_prob_tiled_1.permute(1, 0, 2)],\n",
    "                                            dim=2)\n",
    "        return correct_prob_pairs, incorrect_prob_pairs, correct_incorrect_pairs\n",
    "\n",
    "    correct_prob_pairs, incorrect_prob_pairs,\\\n",
    "               correct_incorrect_pairs = get_pairs_t(correct_prob, incorrect_prob)\n",
    "    correct_kernel = torch_kernel(correct_prob_pairs)\n",
    "    incorrect_kernel = torch_kernel(incorrect_prob_pairs)\n",
    "    correct_incorrect_kernel = torch_kernel(correct_incorrect_pairs)  \n",
    "    sampling_weights_correct = torch.matmul((1.0 - correct_prob).view(-1, 1),\n",
    "                                            (1.0 - correct_prob).view(-1, 1).permute(1, 0))\n",
    "    correct_correct_vals = get_out_tensor_torch(correct_kernel,\n",
    "                                                sampling_weights_correct)\n",
    "    sampling_weights_incorrect = torch.matmul(incorrect_prob.view(-1, 1),\n",
    "                                              incorrect_prob.view(-1, 1).permute(1, 0))\n",
    "    incorrect_incorrect_vals = get_out_tensor_torch(incorrect_kernel,\n",
    "                                                    sampling_weights_incorrect)\n",
    "    sampling_correct_incorrect = torch.matmul((1.0 - correct_prob).view(-1, 1),\n",
    "                                              incorrect_prob.view(-1, 1).permute(1, 0))\n",
    "    correct_incorrect_vals = get_out_tensor_torch(correct_incorrect_kernel,\n",
    "                                                  sampling_correct_incorrect)\n",
    "    correct_denom = torch.sum(1.0 - correct_prob)\n",
    "    incorrect_denom = torch.sum(incorrect_prob)\n",
    "    m = torch.sum(correct_mask)\n",
    "    n = torch.sum(1.0 - correct_mask)\n",
    "    mmd_error = 1.0/(m*m + 1e-5) * torch.sum(correct_correct_vals) \n",
    "    mmd_error += 1.0/(n*n + 1e-5) * torch.sum(incorrect_incorrect_vals)\n",
    "    mmd_error -= 2.0/(m*n + 1e-5) * torch.sum(correct_incorrect_vals)\n",
    "    return torch.max(torch.tensor([(cond_k*cond_k_p).float().detach()*torch.sqrt(mmd_error + 1e-10), 0.0], device=d))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "7Y9gQx6V-L9_"
   },
   "outputs": [],
   "source": [
    "def uceloss(softmaxes, labels, n_bins=15):\n",
    "    d = softmaxes.device\n",
    "    bin_boundaries = torch.linspace(0, 1, n_bins + 1, device=d)\n",
    "    bin_lowers = bin_boundaries[:-1]\n",
    "    bin_uppers = bin_boundaries[1:]\n",
    "\n",
    "    _, predictions = torch.max(softmaxes, 1)\n",
    "    errors = predictions.ne(labels)\n",
    "    uncertainties = nentr(softmaxes, base=softmaxes.size(1))\n",
    "    errors_in_bin_list = []\n",
    "    avg_entropy_in_bin_list = []\n",
    "\n",
    "    uce = torch.zeros(1, device=d)\n",
    "    for bin_lower, bin_upper in zip(bin_lowers, bin_uppers):\n",
    "        # Calculate |uncert - err| in each bin\n",
    "        in_bin = uncertainties.gt(bin_lower.item()) * uncertainties.le(bin_upper.item())\n",
    "        prop_in_bin = in_bin.float().mean()  # |Bm| / n\n",
    "        if prop_in_bin.item() > 0.0:\n",
    "            errors_in_bin = errors[in_bin].float().mean()  # err()\n",
    "            avg_entropy_in_bin = uncertainties[in_bin].mean()  # uncert()\n",
    "            uce += torch.abs(avg_entropy_in_bin - errors_in_bin) * prop_in_bin\n",
    "\n",
    "            errors_in_bin_list.append(errors_in_bin)\n",
    "            avg_entropy_in_bin_list.append(avg_entropy_in_bin)\n",
    "\n",
    "    err_in_bin = torch.tensor(errors_in_bin_list, device=d)\n",
    "    avg_entropy_in_bin = torch.tensor(avg_entropy_in_bin_list, device=d)\n",
    "\n",
    "    return uce, err_in_bin, avg_entropy_in_bin"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "def uceloss_weighted(softmaxes, labels, n_bins=15):\n",
    "    d = softmaxes.device\n",
    "    bin_boundaries = torch.linspace(0, 1, n_bins + 1, device=d)\n",
    "    bin_lowers = bin_boundaries[:-1]\n",
    "    bin_uppers = bin_boundaries[1:]\n",
    "\n",
    "    _, predictions = torch.max(softmaxes, 1)\n",
    "    errors = predictions.ne(labels)\n",
    "    uncertainties = nentr(softmaxes, base=softmaxes.size(1))\n",
    "    errors_in_bin_list = []\n",
    "    avg_entropy_in_bin_list = []\n",
    "\n",
    "    uce = torch.zeros(1, device=d)\n",
    "    for bin_lower, bin_upper in zip(bin_lowers, bin_uppers):\n",
    "        # Calculate |uncert - err| in each bin\n",
    "        in_bin = uncertainties.gt(bin_lower.item()) * uncertainties.le(bin_upper.item())\n",
    "        prop_in_bin = in_bin.float().mean()  # |Bm| / n\n",
    "        if prop_in_bin.item() > 0.0:\n",
    "            errors_in_bin = errors[in_bin].float().mean()  # err()\n",
    "            avg_entropy_in_bin = uncertainties[in_bin].mean()  # uncert()\n",
    "            uce += torch.abs(avg_entropy_in_bin - errors_in_bin) * prop_in_bin\n",
    "\n",
    "            errors_in_bin_list.append(errors_in_bin)\n",
    "            avg_entropy_in_bin_list.append(avg_entropy_in_bin)\n",
    "\n",
    "    err_in_bin = torch.tensor(errors_in_bin_list, device=d)\n",
    "    avg_entropy_in_bin = torch.tensor(avg_entropy_in_bin_list, device=d)\n",
    "\n",
    "    return uce, err_in_bin, avg_entropy_in_bin"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "LW1t5pxp4583"
   },
   "outputs": [],
   "source": [
    "# Toy batch: every row is the uniform distribution [0.5, 0.5]; the first 50\n",
    "# labels are class 0 and the last 50 are class 1.\n",
    "# NOTE(review): every row is an exact max/argmax tie, where those ops are not\n",
    "# differentiable, so finite-difference gradchecks on this batch can disagree\n",
    "# with autograd (as in the Jacobian mismatch reported by the next cell).\n",
    "labels_np = np.repeat([0., 1.], 50)\n",
    "softmaxes_np = np.full((100, 2), 0.5)\n",
    "\n",
    "labels = tf.constant(labels_np, tf.int64)\n",
    "softmaxes = tf.constant(softmaxes_np, tf.float32)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "ename": "RuntimeError",
     "evalue": "Jacobian mismatch for output 0 with respect to input 0,\nnumerical:tensor([[ 49.9998],\n        [-49.9999],\n        [ 49.9998],\n        [-49.9999],\n        [ 49.9998],\n        [-49.9999],\n        [ 49.9998],\n        [-49.9999],\n        [ 49.9998],\n        [-49.9999],\n        [ 49.9998],\n        [-49.9999],\n        [ 49.9998],\n        [-49.9999],\n        [ 49.9998],\n        [-49.9999],\n        [ 49.9998],\n        [-49.9999],\n        [ 49.9998],\n        [-49.9999],\n        [ 49.9998],\n        [-49.9999],\n        [ 49.9998],\n        [-49.9999],\n        [ 49.9998],\n        [-49.9999],\n        [ 49.9998],\n        [-49.9999],\n        [ 49.9998],\n        [-49.9999],\n        [ 49.9998],\n        [-49.9999],\n        [ 49.9998],\n        [-49.9999],\n        [ 49.9998],\n        [-49.9999],\n        [ 49.9998],\n        [-49.9999],\n        [ 49.9998],\n        [-49.9999],\n        [ 49.9998],\n        [-49.9999],\n        [ 49.9998],\n        [-49.9999],\n        [ 49.9998],\n        [-49.9999],\n        [ 49.9998],\n        [-49.9999],\n        [ 49.9998],\n        [-49.9999],\n        [ 49.9998],\n        [-49.9999],\n        [ 49.9998],\n        [-49.9999],\n        [ 49.9998],\n        [-49.9999],\n        [ 49.9998],\n        [-49.9999],\n        [ 49.9998],\n        [-49.9999],\n        [ 49.9998],\n        [-49.9999],\n        [ 49.9998],\n        [-49.9999],\n        [ 49.9998],\n        [-49.9999],\n        [ 49.9998],\n        [-49.9999],\n        [ 49.9998],\n        [-49.9999],\n        [ 49.9998],\n        [-49.9999],\n        [ 49.9998],\n        [-49.9999],\n        [ 49.9998],\n        [-49.9999],\n        [ 49.9998],\n        [-49.9999],\n        [ 49.9998],\n        [-49.9999],\n        [ 49.9998],\n        [-49.9999],\n        [ 49.9998],\n        [-49.9999],\n        [ 49.9998],\n        [-49.9999],\n        [ 49.9998],\n        [-49.9999],\n        [ 49.9998],\n        [-49.9999],\n        [ 49.9998],\n      
  [-49.9999],\n        [ 49.9998],\n        [-49.9999],\n        [ 49.9998],\n        [-49.9999],\n        [ 49.9998],\n        [-49.9999],\n        [ 49.9998],\n        [-49.9999],\n        [ 50.0000],\n        [-49.9999],\n        [ 50.0000],\n        [-49.9999],\n        [ 50.0000],\n        [-49.9999],\n        [ 50.0000],\n        [-49.9999],\n        [ 50.0000],\n        [-49.9999],\n        [ 50.0000],\n        [-49.9999],\n        [ 50.0000],\n        [-49.9999],\n        [ 50.0000],\n        [-49.9999],\n        [ 50.0000],\n        [-49.9999],\n        [ 50.0000],\n        [-49.9999],\n        [ 50.0000],\n        [-49.9999],\n        [ 50.0000],\n        [-49.9999],\n        [ 50.0000],\n        [-49.9999],\n        [ 50.0000],\n        [-49.9999],\n        [ 50.0000],\n        [-49.9999],\n        [ 50.0000],\n        [-49.9999],\n        [ 50.0000],\n        [-49.9999],\n        [ 50.0000],\n        [-49.9999],\n        [ 50.0000],\n        [-49.9999],\n        [ 50.0000],\n        [-49.9999],\n        [ 50.0000],\n        [-49.9999],\n        [ 50.0000],\n        [-49.9999],\n        [ 50.0000],\n        [-49.9999],\n        [ 50.0000],\n        [-49.9999],\n        [ 50.0000],\n        [-49.9999],\n        [ 50.0000],\n        [-49.9999],\n        [ 50.0000],\n        [-49.9999],\n        [ 50.0000],\n        [-49.9999],\n        [ 50.0000],\n        [-49.9999],\n        [ 50.0000],\n        [-49.9999],\n        [ 50.0000],\n        [-49.9999],\n        [ 50.0000],\n        [-49.9999],\n        [ 50.0000],\n        [-49.9999],\n        [ 50.0000],\n        [-49.9999],\n        [ 50.0000],\n        [-49.9999],\n        [ 50.0000],\n        [-49.9999],\n        [ 50.0000],\n        [-49.9999],\n        [ 50.0000],\n        [-49.9999],\n        [ 50.0000],\n        [-49.9999],\n        [ 50.0000],\n        [-49.9999],\n        [ 50.0000],\n        [-49.9999],\n        [ 50.0000],\n        [-49.9999],\n        [ 50.0000],\n        [-49.9999],\n        [ 
50.0000],\n        [-49.9999],\n        [ 50.0000],\n        [-49.9999],\n        [ 50.0000],\n        [-49.9999],\n        [ 50.0000],\n        [-49.9999],\n        [ 50.0000],\n        [-49.9999],\n        [ 50.0000],\n        [-49.9999],\n        [ 50.0000],\n        [-49.9999]], dtype=torch.float64)\nanalytical:tensor([[0.0000e+00],\n        [2.3039e-19],\n        [0.0000e+00],\n        [2.3039e-19],\n        [0.0000e+00],\n        [2.3039e-19],\n        [0.0000e+00],\n        [2.3039e-19],\n        [0.0000e+00],\n        [2.3039e-19],\n        [0.0000e+00],\n        [2.3039e-19],\n        [0.0000e+00],\n        [2.3039e-19],\n        [0.0000e+00],\n        [2.3039e-19],\n        [0.0000e+00],\n        [2.3039e-19],\n        [0.0000e+00],\n        [2.3039e-19],\n        [0.0000e+00],\n        [2.3039e-19],\n        [0.0000e+00],\n        [2.3039e-19],\n        [0.0000e+00],\n        [2.3039e-19],\n        [0.0000e+00],\n        [2.3039e-19],\n        [0.0000e+00],\n        [2.3039e-19],\n        [0.0000e+00],\n        [2.3039e-19],\n        [0.0000e+00],\n        [2.3039e-19],\n        [0.0000e+00],\n        [2.3039e-19],\n        [0.0000e+00],\n        [2.3039e-19],\n        [0.0000e+00],\n        [2.3039e-19],\n        [0.0000e+00],\n        [2.3039e-19],\n        [0.0000e+00],\n        [2.3039e-19],\n        [0.0000e+00],\n        [2.3039e-19],\n        [0.0000e+00],\n        [2.3039e-19],\n        [0.0000e+00],\n        [2.3039e-19],\n        [0.0000e+00],\n        [2.3039e-19],\n        [0.0000e+00],\n        [2.3039e-19],\n        [0.0000e+00],\n        [2.3039e-19],\n        [0.0000e+00],\n        [2.3039e-19],\n        [0.0000e+00],\n        [2.3039e-19],\n        [0.0000e+00],\n        [2.3039e-19],\n        [0.0000e+00],\n        [2.3039e-19],\n        [0.0000e+00],\n        [2.3039e-19],\n        [0.0000e+00],\n        [2.3039e-19],\n        [0.0000e+00],\n        [2.3039e-19],\n        [0.0000e+00],\n        [2.3039e-19],\n        [0.0000e+00],\n    
    [2.3039e-19],\n        [0.0000e+00],\n        [2.3039e-19],\n        [0.0000e+00],\n        [2.3039e-19],\n        [0.0000e+00],\n        [2.3039e-19],\n        [0.0000e+00],\n        [2.3039e-19],\n        [0.0000e+00],\n        [2.3039e-19],\n        [0.0000e+00],\n        [2.3039e-19],\n        [0.0000e+00],\n        [2.3039e-19],\n        [0.0000e+00],\n        [2.3039e-19],\n        [0.0000e+00],\n        [2.3039e-19],\n        [0.0000e+00],\n        [2.3039e-19],\n        [0.0000e+00],\n        [2.3039e-19],\n        [0.0000e+00],\n        [2.3039e-19],\n        [0.0000e+00],\n        [2.3039e-19],\n        [0.0000e+00],\n        [2.3039e-19],\n        [0.0000e+00],\n        [2.3039e-19],\n        [0.0000e+00],\n        [2.3039e-19],\n        [0.0000e+00],\n        [2.3039e-19],\n        [0.0000e+00],\n        [2.3039e-19],\n        [0.0000e+00],\n        [2.3039e-19],\n        [0.0000e+00],\n        [2.3039e-19],\n        [0.0000e+00],\n        [2.3039e-19],\n        [0.0000e+00],\n        [2.3039e-19],\n        [0.0000e+00],\n        [2.3039e-19],\n        [0.0000e+00],\n        [2.3039e-19],\n        [0.0000e+00],\n        [2.3039e-19],\n        [0.0000e+00],\n        [2.3039e-19],\n        [0.0000e+00],\n        [2.3039e-19],\n        [0.0000e+00],\n        [2.3039e-19],\n        [0.0000e+00],\n        [2.3039e-19],\n        [0.0000e+00],\n        [2.3039e-19],\n        [0.0000e+00],\n        [2.3039e-19],\n        [0.0000e+00],\n        [2.3039e-19],\n        [0.0000e+00],\n        [2.3039e-19],\n        [0.0000e+00],\n        [2.3039e-19],\n        [0.0000e+00],\n        [2.3039e-19],\n        [0.0000e+00],\n        [2.3039e-19],\n        [0.0000e+00],\n        [2.3039e-19],\n        [0.0000e+00],\n        [2.3039e-19],\n        [0.0000e+00],\n        [2.3039e-19],\n        [0.0000e+00],\n        [2.3039e-19],\n        [0.0000e+00],\n        [2.3039e-19],\n        [0.0000e+00],\n        [2.3039e-19],\n        [0.0000e+00],\n        [2.3039e-19],\n   
     [0.0000e+00],\n        [2.3039e-19],\n        [0.0000e+00],\n        [2.3039e-19],\n        [0.0000e+00],\n        [2.3039e-19],\n        [0.0000e+00],\n        [2.3039e-19],\n        [0.0000e+00],\n        [2.3039e-19],\n        [0.0000e+00],\n        [2.3039e-19],\n        [0.0000e+00],\n        [2.3039e-19],\n        [0.0000e+00],\n        [2.3039e-19],\n        [0.0000e+00],\n        [2.3039e-19],\n        [0.0000e+00],\n        [2.3039e-19],\n        [0.0000e+00],\n        [2.3039e-19],\n        [0.0000e+00],\n        [2.3039e-19],\n        [0.0000e+00],\n        [2.3039e-19],\n        [0.0000e+00],\n        [2.3039e-19],\n        [0.0000e+00],\n        [2.3039e-19],\n        [0.0000e+00],\n        [2.3039e-19],\n        [0.0000e+00],\n        [2.3039e-19],\n        [0.0000e+00],\n        [2.3039e-19],\n        [0.0000e+00],\n        [2.3039e-19],\n        [0.0000e+00],\n        [2.3039e-19]], dtype=torch.float64)\n",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mRuntimeError\u001b[0m                              Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-20-0abd1a5a6e34>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m      3\u001b[0m gradcheck(calibration_unbiased_loss_torch,\n\u001b[1;32m      4\u001b[0m           (torch.tensor(softmaxes_np, requires_grad=True),\n\u001b[0;32m----> 5\u001b[0;31m            torch.tensor(labels_np, requires_grad=True)))\n\u001b[0m",
      "\u001b[0;32m~/anaconda3/lib/python3.7/site-packages/torch/autograd/gradcheck.py\u001b[0m in \u001b[0;36mgradcheck\u001b[0;34m(func, inputs, eps, atol, rtol, raise_exception, check_sparse_nnz, nondet_tol)\u001b[0m\n\u001b[1;32m    289\u001b[0m                 \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mallclose\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0ma\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mn\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mrtol\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0matol\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    290\u001b[0m                     return fail_test('Jacobian mismatch for output %d with respect to input %d,\\n'\n\u001b[0;32m--> 291\u001b[0;31m                                      'numerical:%s\\nanalytical:%s\\n' % (i, j, n, a))\n\u001b[0m\u001b[1;32m    292\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    293\u001b[0m         \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mreentrant\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/anaconda3/lib/python3.7/site-packages/torch/autograd/gradcheck.py\u001b[0m in \u001b[0;36mfail_test\u001b[0;34m(msg)\u001b[0m\n\u001b[1;32m    227\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0mfail_test\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmsg\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    228\u001b[0m         \u001b[0;32mif\u001b[0m \u001b[0mraise_exception\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 229\u001b[0;31m             \u001b[0;32mraise\u001b[0m \u001b[0mRuntimeError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmsg\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    230\u001b[0m         \u001b[0;32mreturn\u001b[0m \u001b[0;32mFalse\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    231\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mRuntimeError\u001b[0m: Jacobian mismatch for output 0 with respect to input 0,\nnumerical:tensor([[ 49.9998],\n        [-49.9999],\n        [ 49.9998],\n        [-49.9999],\n        [ 49.9998],\n        [-49.9999],\n        [ 49.9998],\n        [-49.9999],\n        [ 49.9998],\n        [-49.9999],\n        [ 49.9998],\n        [-49.9999],\n        [ 49.9998],\n        [-49.9999],\n        [ 49.9998],\n        [-49.9999],\n        [ 49.9998],\n        [-49.9999],\n        [ 49.9998],\n        [-49.9999],\n        [ 49.9998],\n        [-49.9999],\n        [ 49.9998],\n        [-49.9999],\n        [ 49.9998],\n        [-49.9999],\n        [ 49.9998],\n        [-49.9999],\n        [ 49.9998],\n        [-49.9999],\n        [ 49.9998],\n        [-49.9999],\n        [ 49.9998],\n        [-49.9999],\n        [ 49.9998],\n        [-49.9999],\n        [ 49.9998],\n        [-49.9999],\n        [ 49.9998],\n        [-49.9999],\n        [ 49.9998],\n        [-49.9999],\n        [ 49.9998],\n        [-49.9999],\n        [ 49.9998],\n        [-49.9999],\n        [ 49.9998],\n        [-49.9999],\n        [ 49.9998],\n        [-49.9999],\n        [ 49.9998],\n        [-49.9999],\n        [ 49.9998],\n        [-49.9999],\n        [ 49.9998],\n        [-49.9999],\n        [ 49.9998],\n        [-49.9999],\n        [ 49.9998],\n        [-49.9999],\n        [ 49.9998],\n        [-49.9999],\n        [ 49.9998],\n        [-49.9999],\n        [ 49.9998],\n        [-49.9999],\n        [ 49.9998],\n        [-49.9999],\n        [ 49.9998],\n        [-49.9999],\n        [ 49.9998],\n        [-49.9999],\n        [ 49.9998],\n        [-49.9999],\n        [ 49.9998],\n        [-49.9999],\n        [ 49.9998],\n        [-49.9999],\n        [ 49.9998],\n        [-49.9999],\n        [ 49.9998],\n        [-49.9999],\n        [ 49.9998],\n        [-49.9999],\n        [ 49.9998],\n        [-49.9999],\n        [ 49.9998],\n        [-49.9999],\n        [ 49.9998],\n        [-49.9999],\n 
       [ 49.9998],\n        [-49.9999],\n        [ 49.9998],\n        [-49.9999],\n        [ 49.9998],\n        [-49.9999],\n        [ 49.9998],\n        [-49.9999],\n        [ 49.9998],\n        [-49.9999],\n        [ 50.0000],\n        [-49.9999],\n        [ 50.0000],\n        [-49.9999],\n        [ 50.0000],\n        [-49.9999],\n        [ 50.0000],\n        [-49.9999],\n        [ 50.0000],\n        [-49.9999],\n        [ 50.0000],\n        [-49.9999],\n        [ 50.0000],\n        [-49.9999],\n        [ 50.0000],\n        [-49.9999],\n        [ 50.0000],\n        [-49.9999],\n        [ 50.0000],\n        [-49.9999],\n        [ 50.0000],\n        [-49.9999],\n        [ 50.0000],\n        [-49.9999],\n        [ 50.0000],\n        [-49.9999],\n        [ 50.0000],\n        [-49.9999],\n        [ 50.0000],\n        [-49.9999],\n        [ 50.0000],\n        [-49.9999],\n        [ 50.0000],\n        [-49.9999],\n        [ 50.0000],\n        [-49.9999],\n        [ 50.0000],\n        [-49.9999],\n        [ 50.0000],\n        [-49.9999],\n        [ 50.0000],\n        [-49.9999],\n        [ 50.0000],\n        [-49.9999],\n        [ 50.0000],\n        [-49.9999],\n        [ 50.0000],\n        [-49.9999],\n        [ 50.0000],\n        [-49.9999],\n        [ 50.0000],\n        [-49.9999],\n        [ 50.0000],\n        [-49.9999],\n        [ 50.0000],\n        [-49.9999],\n        [ 50.0000],\n        [-49.9999],\n        [ 50.0000],\n        [-49.9999],\n        [ 50.0000],\n        [-49.9999],\n        [ 50.0000],\n        [-49.9999],\n        [ 50.0000],\n        [-49.9999],\n        [ 50.0000],\n        [-49.9999],\n        [ 50.0000],\n        [-49.9999],\n        [ 50.0000],\n        [-49.9999],\n        [ 50.0000],\n        [-49.9999],\n        [ 50.0000],\n        [-49.9999],\n        [ 50.0000],\n        [-49.9999],\n        [ 50.0000],\n        [-49.9999],\n        [ 50.0000],\n        [-49.9999],\n        [ 50.0000],\n        [-49.9999],\n        [ 50.0000],\n      
  [-49.9999],\n        [ 50.0000],\n        [-49.9999],\n        [ 50.0000],\n        [-49.9999],\n        [ 50.0000],\n        [-49.9999],\n        [ 50.0000],\n        [-49.9999],\n        [ 50.0000],\n        [-49.9999],\n        [ 50.0000],\n        [-49.9999],\n        [ 50.0000],\n        [-49.9999]], dtype=torch.float64)\nanalytical:tensor([[0.0000e+00],\n        [2.3039e-19],\n        [0.0000e+00],\n        [2.3039e-19],\n        [0.0000e+00],\n        [2.3039e-19],\n        [0.0000e+00],\n        [2.3039e-19],\n        [0.0000e+00],\n        [2.3039e-19],\n        [0.0000e+00],\n        [2.3039e-19],\n        [0.0000e+00],\n        [2.3039e-19],\n        [0.0000e+00],\n        [2.3039e-19],\n        [0.0000e+00],\n        [2.3039e-19],\n        [0.0000e+00],\n        [2.3039e-19],\n        [0.0000e+00],\n        [2.3039e-19],\n        [0.0000e+00],\n        [2.3039e-19],\n        [0.0000e+00],\n        [2.3039e-19],\n        [0.0000e+00],\n        [2.3039e-19],\n        [0.0000e+00],\n        [2.3039e-19],\n        [0.0000e+00],\n        [2.3039e-19],\n        [0.0000e+00],\n        [2.3039e-19],\n        [0.0000e+00],\n        [2.3039e-19],\n        [0.0000e+00],\n        [2.3039e-19],\n        [0.0000e+00],\n        [2.3039e-19],\n        [0.0000e+00],\n        [2.3039e-19],\n        [0.0000e+00],\n        [2.3039e-19],\n        [0.0000e+00],\n        [2.3039e-19],\n        [0.0000e+00],\n        [2.3039e-19],\n        [0.0000e+00],\n        [2.3039e-19],\n        [0.0000e+00],\n        [2.3039e-19],\n        [0.0000e+00],\n        [2.3039e-19],\n        [0.0000e+00],\n        [2.3039e-19],\n        [0.0000e+00],\n        [2.3039e-19],\n        [0.0000e+00],\n        [2.3039e-19],\n        [0.0000e+00],\n        [2.3039e-19],\n        [0.0000e+00],\n        [2.3039e-19],\n        [0.0000e+00],\n        [2.3039e-19],\n        [0.0000e+00],\n        [2.3039e-19],\n        [0.0000e+00],\n        [2.3039e-19],\n        [0.0000e+00],\n        [2.3039e-19],\n  
      [0.0000e+00],\n        [2.3039e-19],\n        [0.0000e+00],\n        [2.3039e-19],\n        [0.0000e+00],\n        [2.3039e-19],\n        [0.0000e+00],\n        [2.3039e-19],\n        [0.0000e+00],\n        [2.3039e-19],\n        [0.0000e+00],\n        [2.3039e-19],\n        [0.0000e+00],\n        [2.3039e-19],\n        [0.0000e+00],\n        [2.3039e-19],\n        [0.0000e+00],\n        [2.3039e-19],\n        [0.0000e+00],\n        [2.3039e-19],\n        [0.0000e+00],\n        [2.3039e-19],\n        [0.0000e+00],\n        [2.3039e-19],\n        [0.0000e+00],\n        [2.3039e-19],\n        [0.0000e+00],\n        [2.3039e-19],\n        [0.0000e+00],\n        [2.3039e-19],\n        [0.0000e+00],\n        [2.3039e-19],\n        [0.0000e+00],\n        [2.3039e-19],\n        [0.0000e+00],\n        [2.3039e-19],\n        [0.0000e+00],\n        [2.3039e-19],\n        [0.0000e+00],\n        [2.3039e-19],\n        [0.0000e+00],\n        [2.3039e-19],\n        [0.0000e+00],\n        [2.3039e-19],\n        [0.0000e+00],\n        [2.3039e-19],\n        [0.0000e+00],\n        [2.3039e-19],\n        [0.0000e+00],\n        [2.3039e-19],\n        [0.0000e+00],\n        [2.3039e-19],\n        [0.0000e+00],\n        [2.3039e-19],\n        [0.0000e+00],\n        [2.3039e-19],\n        [0.0000e+00],\n        [2.3039e-19],\n        [0.0000e+00],\n        [2.3039e-19],\n        [0.0000e+00],\n        [2.3039e-19],\n        [0.0000e+00],\n        [2.3039e-19],\n        [0.0000e+00],\n        [2.3039e-19],\n        [0.0000e+00],\n        [2.3039e-19],\n        [0.0000e+00],\n        [2.3039e-19],\n        [0.0000e+00],\n        [2.3039e-19],\n        [0.0000e+00],\n        [2.3039e-19],\n        [0.0000e+00],\n        [2.3039e-19],\n        [0.0000e+00],\n        [2.3039e-19],\n        [0.0000e+00],\n        [2.3039e-19],\n        [0.0000e+00],\n        [2.3039e-19],\n        [0.0000e+00],\n        [2.3039e-19],\n        [0.0000e+00],\n        [2.3039e-19],\n        [0.0000e+00],\n 
       [2.3039e-19],\n        [0.0000e+00],\n        [2.3039e-19],\n        [0.0000e+00],\n        [2.3039e-19],\n        [0.0000e+00],\n        [2.3039e-19],\n        [0.0000e+00],\n        [2.3039e-19],\n        [0.0000e+00],\n        [2.3039e-19],\n        [0.0000e+00],\n        [2.3039e-19],\n        [0.0000e+00],\n        [2.3039e-19],\n        [0.0000e+00],\n        [2.3039e-19],\n        [0.0000e+00],\n        [2.3039e-19],\n        [0.0000e+00],\n        [2.3039e-19],\n        [0.0000e+00],\n        [2.3039e-19],\n        [0.0000e+00],\n        [2.3039e-19],\n        [0.0000e+00],\n        [2.3039e-19],\n        [0.0000e+00],\n        [2.3039e-19],\n        [0.0000e+00],\n        [2.3039e-19],\n        [0.0000e+00],\n        [2.3039e-19],\n        [0.0000e+00],\n        [2.3039e-19],\n        [0.0000e+00],\n        [2.3039e-19],\n        [0.0000e+00],\n        [2.3039e-19],\n        [0.0000e+00],\n        [2.3039e-19]], dtype=torch.float64)\n"
     ]
    }
   ],
   "source": [
    "from torch.autograd import gradcheck\n",
    "\n",
    "gradcheck(calibration_unbiased_loss_torch,\n",
    "          (torch.tensor(softmaxes_np, requires_grad=True),\n",
    "           torch.tensor(labels_np, requires_grad=True)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 88
    },
    "colab_type": "code",
    "id": "g7-IMvPh5Ir8",
    "outputId": "c8fecce2-70ce-425c-c451-38a1ca3bc018"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.0\n",
      "0.0\n",
      "0.24999999725216435\n"
     ]
    }
   ],
   "source": [
    "a = calibration_unbiased_loss(softmaxes, labels)\n",
    "print(a.numpy())\n",
    "b = calibration_unbiased_loss_torch(torch.tensor(softmaxes_np), torch.tensor(labels_np)).numpy()\n",
    "print(b)\n",
    "c = calibration_unbiased_loss_torch2(torch.tensor(softmaxes_np), torch.tensor(labels_np)).numpy()\n",
    "print(c)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 74,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 34
    },
    "colab_type": "code",
    "id": "ewDR_Cab6XeQ",
    "outputId": "e4182dfb-3879-42ce-cdd7-ba133fd538fc"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([0.0114])"
      ]
     },
     "execution_count": 74,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "uceloss(torch.tensor(softmaxes_np), torch.tensor(labels_np))[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 75,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 34
    },
    "colab_type": "code",
    "id": "Ut3vxrgX-58C",
    "outputId": "5b6c082a-f90d-4a0d-cecc-35e189c9ba13"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "True\n"
     ]
    }
   ],
   "source": [
    "print(np.allclose(a, b))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 76,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "True\n",
      "0.0\n"
     ]
    }
   ],
   "source": [
    "a = calibration_mmce_w_loss(softmaxes, labels).numpy()\n",
    "b = calibration_mmce_w_loss_torch(torch.tensor(softmaxes_np), torch.tensor(labels_np)).numpy()\n",
    "print(np.allclose(a, b))\n",
    "print(b)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 77,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.0\n",
      "0.0\n"
     ]
    }
   ],
   "source": [
    "a = calibration_mmce_w_loss_torch(torch.tensor(softmaxes_np), torch.tensor(labels_np)).numpy()\n",
    "b = calibration_mmce_w_loss_torch2(torch.tensor(softmaxes_np), torch.tensor(labels_np)).numpy()\n",
    "print(a)\n",
    "print(b)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 1000
    },
    "colab_type": "code",
    "id": "xnQUauRvDkQf",
    "outputId": "35f7d38b-1298-4bc7-935f-c47a164de015"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "True\n",
      "True\n",
      "True\n",
      "True\n",
      "True\n",
      "True\n",
      "True\n",
      "True\n",
      "True\n",
      "True\n",
      "True\n",
      "True\n",
      "True\n",
      "True\n",
      "True\n",
      "True\n",
      "True\n",
      "True\n",
      "True\n",
      "True\n",
      "True\n",
      "True\n",
      "True\n",
      "True\n",
      "True\n",
      "True\n",
      "True\n",
      "True\n",
      "True\n",
      "True\n",
      "True\n",
      "True\n",
      "True\n",
      "True\n",
      "True\n",
      "True\n",
      "True\n",
      "True\n",
      "True\n",
      "True\n",
      "True\n",
      "True\n",
      "True\n",
      "True\n",
      "True\n",
      "True\n",
      "True\n",
      "True\n",
      "True\n",
      "True\n",
      "True\n",
      "True\n",
      "True\n",
      "True\n",
      "True\n",
      "True\n",
      "True\n",
      "True\n",
      "True\n",
      "True\n",
      "True\n",
      "True\n",
      "True\n",
      "True\n",
      "True\n",
      "True\n",
      "True\n",
      "True\n",
      "True\n",
      "True\n",
      "True\n",
      "True\n",
      "True\n",
      "True\n",
      "True\n",
      "True\n",
      "True\n",
      "True\n",
      "True\n",
      "True\n",
      "True\n",
      "True\n",
      "True\n",
      "True\n",
      "True\n",
      "True\n",
      "True\n",
      "True\n",
      "True\n",
      "True\n",
      "True\n",
      "True\n",
      "True\n",
      "True\n",
      "True\n",
      "True\n",
      "True\n",
      "True\n"
     ]
    }
   ],
   "source": [
    "for i in np.arange(1, 99, step=1):\n",
    "    labels_np = np.array([0]*int(i) + [1]*(100-i))\n",
    "    logits_np = np.array([i/100, 1-i/100]*100).reshape(100,2)\n",
    "\n",
    "    labels = tf.constant(labels_np, tf.int64)\n",
    "    logits = tf.constant(logits_np, tf.float32)\n",
    "\n",
    "    a = calibration_unbiased_loss(logits, labels).numpy()\n",
    "    b = calibration_unbiased_loss_torch(torch.tensor(logits_np), torch.tensor(labels_np)).numpy()\n",
    "    print(np.allclose(a, b, rtol=1e-5, atol=1e-6))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 34
    },
    "colab_type": "code",
    "id": "GhyJJSxlGcDp",
    "outputId": "0ea95245-5129-4c5a-c424-f1a49e19c40a"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array(0.)"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "labels_np = np.array([0]*75 + [1]*25)\n",
    "logits_np = np.array([0.75, 0.25]*100).reshape(100,2)\n",
    "\n",
    "labels = tf.constant(labels_np, tf.int64)\n",
    "logits = tf.constant(logits_np, tf.float32)\n",
    "\n",
    "calibration_unbiased_loss_torch(torch.tensor(logits_np), torch.tensor(labels_np)).numpy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "b2BP3rl5azpo"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.9899 0.0708\n",
      "0.4898 0.1214\n",
      "0.3230 0.1644\n",
      "0.2396 0.2023\n",
      "0.1895 0.2364\n",
      "0.1560 0.2674\n",
      "0.1321 0.2959\n",
      "0.1141 0.3222\n",
      "0.1001 0.3465\n",
      "0.0889 0.3690\n",
      "0.0797 0.3899\n",
      "0.0720 0.4094\n",
      "0.0654 0.4274\n",
      "0.0598 0.4442\n",
      "0.0549 0.4598\n",
      "0.0506 0.4743\n",
      "0.0468 0.4877\n",
      "0.0434 0.5001\n",
      "0.0403 0.5115\n",
      "0.0375 0.5219\n",
      "0.0350 0.5315\n",
      "0.0326 0.5402\n",
      "0.0305 0.5480\n",
      "0.0285 0.5550\n",
      "0.0267 0.5613\n",
      "0.0249 0.5667\n",
      "0.0233 0.5715\n",
      "0.0218 0.5755\n",
      "0.0204 0.5787\n",
      "0.0190 0.5813\n",
      "0.0178 0.5832\n",
      "0.0165 0.5844\n",
      "0.0154 0.5849\n",
      "0.0143 0.5848\n",
      "0.0132 0.5841\n",
      "0.0122 0.5827\n",
      "0.0112 0.5807\n",
      "0.0102 0.5780\n",
      "0.0092 0.5748\n",
      "0.0083 0.5710\n",
      "0.0074 0.5665\n",
      "0.0066 0.5615\n",
      "0.0057 0.5558\n",
      "0.0049 0.5496\n",
      "0.0040 0.5428\n",
      "0.0032 0.5354\n",
      "0.0024 0.5274\n",
      "0.0016 0.5188\n",
      "0.0008 0.5097\n",
      "0.0000 0.5000\n",
      "0.0008 0.5097\n",
      "0.0016 0.5188\n",
      "0.0024 0.5274\n",
      "0.0032 0.5354\n",
      "0.0040 0.5428\n",
      "0.0049 0.5496\n",
      "0.0057 0.5558\n",
      "0.0066 0.5615\n",
      "0.0074 0.5665\n",
      "0.0083 0.5710\n",
      "0.0092 0.5748\n",
      "0.0102 0.5780\n",
      "0.0112 0.5807\n",
      "0.0122 0.5827\n",
      "0.0132 0.5841\n",
      "0.0143 0.5848\n",
      "0.0154 0.5849\n",
      "0.0165 0.5844\n",
      "0.0178 0.5832\n",
      "0.0190 0.5813\n",
      "0.0204 0.5787\n",
      "0.0218 0.5755\n",
      "0.0233 0.5715\n",
      "0.0249 0.5667\n",
      "0.0267 0.5613\n",
      "0.0285 0.5550\n",
      "0.0305 0.5480\n",
      "0.0326 0.5402\n",
      "0.0350 0.5315\n",
      "0.0375 0.5219\n",
      "0.0403 0.5115\n",
      "0.0434 0.5001\n",
      "0.0468 0.4877\n",
      "0.0506 0.4743\n",
      "0.0549 0.4598\n",
      "0.0598 0.4442\n",
      "0.0654 0.4274\n",
      "0.0720 0.4094\n",
      "0.0797 0.3899\n",
      "0.0889 0.3690\n",
      "0.1001 0.3465\n",
      "0.1141 0.3222\n",
      "0.1321 0.2959\n",
      "0.1560 0.2674\n",
      "0.1895 0.2364\n",
      "0.2396 0.2023\n",
      "0.3230 0.1644\n",
      "0.4898 0.1214\n"
     ]
    }
   ],
   "source": [
    "for i in np.arange(1, 99, step=1):\n",
    "    labels_np = np.array([0]*int(i) + [1]*(100-i))\n",
    "    logits_np = np.array([i/100, 1-i/100]*100).reshape(100,2)\n",
    "\n",
    "    labels = tf.constant(labels_np, tf.int64)\n",
    "    logits = tf.constant(logits_np, tf.float32)\n",
    "\n",
    "    a = calibration_mmce_w_loss(logits, labels).numpy()\n",
    "    b = uceloss(torch.tensor(logits_np), torch.tensor(labels_np))[0].numpy()[0]\n",
    "    #print(np.allclose(a, b, rtol=1e-5, atol=1e-6))\n",
    "    print(f\"{a:.4f} {b:.4f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "colab": {
   "collapsed_sections": [],
   "name": "MMCE_pytorch.ipynb",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
