# %% [markdown]
# # UCE pitfall 2
#
# A binary classifier that outputs `[0.5, 0.5]` for every sample of a balanced
# data set is evaluated with two calibration metrics. The expected calibration
# error (ECE) comes out as 0 %, while the uncertainty calibration error (UCE)
# comes out as 50 %.

# %%
import numpy as np
import torch
from scipy.special import softmax
from scipy.stats import entropy


# %%
def nentr(p, base=None):
    """
    Calculates the entropy of p to the base b. If base is None (or otherwise
    falsy), the natural logarithm is used.

    :param p: batches of class label probability distributions (softmax output);
        the entropy is reduced over dim 1
    :param base: base b of the logarithm
    :return: tensor of entropies, one value per entry along dim 0 of p
    """
    # eps keeps log() finite for probabilities that are exactly zero.
    eps = torch.tensor([1e-16], device=p.device)
    log_p = p.add(eps).log()
    if base:
        base = torch.tensor([base], device=p.device, dtype=torch.float32)
        log_p = log_p.div(base.log())
    # sum(p * log p) is non-positive up to rounding; abs() turns it into the
    # (non-negative) entropy.
    return p.mul(log_p).sum(dim=1).abs()


# %%
def _bin_masks(values, n_bins):
    """
    Yields one boolean mask over ``values`` for each of ``n_bins`` equal-width
    bins on [0, 1].

    Bins are (lower, upper], except the first one, which is closed at 0 so that
    values of exactly 0 are assigned to a bin instead of being dropped.
    """
    edges = torch.linspace(0, 1, n_bins + 1, device=values.device)
    for i, (lower, upper) in enumerate(zip(edges[:-1], edges[1:])):
        if i == 0:
            above_lower = values.ge(lower.item())
        else:
            above_lower = values.gt(lower.item())
        yield above_lower & values.le(upper.item())


def _binned_gap(scores, outcomes, n_bins):
    """
    Shared binning loop of ``uceloss`` and ``eceloss``.

    :param scores: per-sample values that are binned on [0, 1]
    :param outcomes: per-sample boolean outcomes whose rate is compared with the
        mean score of each bin
    :param n_bins: number of equal-width bins
    :return: (weighted sum over bins of |mean score - outcome rate|,
              outcome rate of every non-empty bin,
              mean score of every non-empty bin)
    """
    d = scores.device
    outcome_rates = []
    mean_scores = []

    gap = torch.zeros(1, device=d)
    for in_bin in _bin_masks(scores, n_bins):
        prop_in_bin = in_bin.float().mean()  # |Bm| / n
        # Empty bins are skipped; for empty input the mean is NaN and the
        # comparison is False as well.
        if prop_in_bin.item() > 0.0:
            outcome_rate = outcomes[in_bin].float().mean()
            mean_score = scores[in_bin].mean()
            gap += torch.abs(mean_score - outcome_rate) * prop_in_bin

            outcome_rates.append(outcome_rate)
            mean_scores.append(mean_score)

    return (
        gap,
        torch.tensor(outcome_rates, device=d),
        torch.tensor(mean_scores, device=d),
    )


def _classwise(loss_fn, num_classes, softmaxes, labels, n_bins):
    """Applies ``loss_fn`` separately to the samples of each class label."""
    losses = []
    outcome_rates_c = []
    mean_scores_c = []
    for c in range(num_classes):
        of_class = torch.where(labels == c)
        loss, outcome_rates, mean_scores = loss_fn(softmaxes[of_class], labels[of_class], n_bins)
        losses.append(loss)
        outcome_rates_c.append(outcome_rates)
        mean_scores_c.append(mean_scores)
    return losses, outcome_rates_c, mean_scores_c


def classwise_uce(num_classes, softmaxes, labels, n_bins=15):
    """
    Computes ``uceloss`` separately for each class c in range(num_classes),
    using only the samples whose label equals c.

    :return: three lists with one entry per class: UCE, error per bin,
        average entropy per bin
    """
    return _classwise(uceloss, num_classes, softmaxes, labels, n_bins)


def classwise_ece(num_classes, softmaxes, labels, n_bins=15):
    """
    Computes ``eceloss`` separately for each class c in range(num_classes),
    using only the samples whose label equals c.

    :return: three lists with one entry per class: ECE, accuracy per bin,
        average confidence per bin
    """
    return _classwise(eceloss, num_classes, softmaxes, labels, n_bins)


def uceloss(softmaxes, labels, n_bins=15):
    """
    Uncertainty calibration error: samples are binned by their entropy
    (normalized with the number of classes as logarithm base) and the absolute
    difference between mean uncertainty and error rate of each bin is summed,
    weighted by the fraction of samples in the bin.

    Samples with an uncertainty of exactly 0 (saturated softmax) are part of
    the first bin, so confident wrong predictions are penalized.

    :param softmaxes: batches of class label probability distributions
    :param labels: ground truth class label per sample
    :param n_bins: number of equal-width uncertainty bins on [0, 1]
    :return: (UCE as tensor with one element,
              error rate of every non-empty bin,
              average uncertainty of every non-empty bin)
    """
    _, predictions = torch.max(softmaxes, 1)
    errors = predictions.ne(labels)
    # Rounding can push the normalized entropy marginally above 1, which would
    # place the sample outside of the last bin; clamp it back onto [0, 1].
    uncertainties = nentr(softmaxes, base=softmaxes.size(1)).clamp(max=1.0)
    return _binned_gap(uncertainties, errors, n_bins)


def eceloss(softmaxes, labels, n_bins=15):
    """
    Expected calibration error: samples are binned by the confidence (maximum
    softmax value) of the prediction and the absolute difference between mean
    confidence and accuracy of each bin is summed, weighted by the fraction of
    samples in the bin.

    Modified from https://github.com/gpleiss/temperature_scaling/blob/master/temperature_scaling.py

    :param softmaxes: batches of class label probability distributions
    :param labels: ground truth class label per sample
    :param n_bins: number of equal-width confidence bins on [0, 1]
    :return: (ECE as tensor with one element,
              accuracy of every non-empty bin,
              average confidence of every non-empty bin)
    """
    confidences, predictions = torch.max(softmaxes, 1)
    accuracies = predictions.eq(labels)
    return _binned_gap(confidences, accuracies, n_bins)


# %% [markdown]
# ## Demo: constant `[0.5, 0.5]` predictions on a balanced binary data set

# %%
labels = np.concatenate([np.zeros(50), np.ones(50)])
predictions = np.ones((100, 2)) * 0.5  # softmax(np.random.uniform(size=(1000,2)), axis=1)

# %%
ece, _, _ = eceloss(torch.tensor(predictions), torch.tensor(labels))
print(ece * 100)

# %%
uce, _, _ = uceloss(torch.tensor(predictions), torch.tensor(labels))
print(uce * 100)