{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import numpy as np\n",
    "import sklearn\n",
    "import matplotlib.pyplot as plt\n",
    "from tqdm import tqdm\n",
    "import copy\n",
    "import seaborn as sns\n",
    "%matplotlib inline\n",
    "%config InlineBackend.figure_format='retina'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Font sizes used across all figures.\n",
    "SMALL_SIZE = 15\n",
    "MEDIUM_SIZE = 15\n",
    "BIGGER_SIZE = 22\n",
    "\n",
    "# Typeset all figure text with LaTeX (requires a LaTeX install) in a serif font.\n",
    "# NOTE(review): \"font.sans-serif\" only applies to the sans-serif family, so with\n",
    "# \"font.family\" set to \"serif\" the Times entry is presumably ignored -- confirm.\n",
    "plt.rcParams.update(\n",
    "    {\n",
    "        \"text.usetex\": True,\n",
    "        \"font.family\": \"serif\",\n",
    "        \"font.sans-serif\": \"Times\",\n",
    "        \"text.latex.preamble\": r\"\\usepackage{amsfonts}\",\n",
    "    }\n",
    ")\n",
    "\n",
    "# Per-element font sizes, as {rc group: {parameter: size}}.\n",
    "for group, sizes in {\n",
    "    \"font\": {\"size\": MEDIUM_SIZE},\n",
    "    \"axes\": {\"titlesize\": SMALL_SIZE, \"labelsize\": MEDIUM_SIZE},\n",
    "    \"xtick\": {\"labelsize\": SMALL_SIZE},\n",
    "    \"ytick\": {\"labelsize\": SMALL_SIZE},\n",
    "    \"legend\": {\"fontsize\": SMALL_SIZE},\n",
    "    \"figure\": {\"titlesize\": BIGGER_SIZE},\n",
    "}.items():\n",
    "    plt.rc(group, **sizes)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def make_gaussians(dist, samples_per_class):\n",
    "    \"\"\"Sample two isotropic 2-D Gaussian clusters centred at (-dist, 0) and (dist, 0).\n",
    "\n",
    "    Both clusters use the identity covariance and draw from NumPy's global RNG\n",
    "    (left cluster first, then right).\n",
    "\n",
    "    Args:\n",
    "        dist: Horizontal offset of each cluster centre from the origin.\n",
    "        samples_per_class: Number of points drawn for each cluster.\n",
    "\n",
    "    Returns:\n",
    "        (X, y): X has shape (2 * samples_per_class, 2). The first\n",
    "        samples_per_class rows come from the (-dist, 0) cluster and are\n",
    "        labelled 1.0; the remaining rows come from (dist, 0) and are labelled 0.0.\n",
    "    \"\"\"\n",
    "    cov = np.identity(2)\n",
    "    left = np.random.multivariate_normal([-dist, 0], cov, size=samples_per_class)\n",
    "    right = np.random.multivariate_normal([dist, 0], cov, size=samples_per_class)\n",
    "\n",
    "    features = np.concatenate([left, right])\n",
    "    labels = np.concatenate([np.ones(len(left)), np.zeros(len(right))])\n",
    "    return features, labels\n",
    "\n",
    "def twospirals(samples_per_class, noise=.5):\n",
    "    \"\"\"Generate two interleaved 2-D spirals, one per class.\n",
    "\n",
    "    Angles are sqrt-transformed uniform draws scaled to at most 780 degrees\n",
    "    (in radians). Noise is added per coordinate from U[0, noise), so it is not\n",
    "    zero-mean. Draws from NumPy's global RNG.\n",
    "\n",
    "    Returns:\n",
    "        (X, y): X has shape (2 * samples_per_class, 2). The first half is one\n",
    "        spiral, labelled 0.0; the second half is its reflection through the\n",
    "        origin, labelled 1.0.\n",
    "    \"\"\"\n",
    "    shape = (samples_per_class, 1)\n",
    "    theta = np.sqrt(np.random.rand(*shape)) * 780 * (2*np.pi)/360\n",
    "    xs = -np.cos(theta)*theta + np.random.rand(*shape) * noise\n",
    "    ys = np.sin(theta)*theta + np.random.rand(*shape) * noise\n",
    "    spiral = np.hstack((xs, ys))\n",
    "    points = np.vstack((spiral, -spiral))\n",
    "    labels = np.hstack((np.zeros(samples_per_class), np.ones(samples_per_class)))\n",
    "    return points, labels"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "a = 5  # Gaussian class means at (-a, 0) and (a, 0)\n",
    "n = 500  # samples per class\n",
    "bs = 64\n",
    "seed = 0  # fixes data sampling, weight init and shuffling\n",
    "\n",
    "dataset = \"gaussian\"  # or \"spiral\"\n",
    "\n",
    "np.random.seed(seed)\n",
    "torch.manual_seed(seed)\n",
    "\n",
    "if dataset == \"gaussian\":\n",
    "    X_tr, y_tr = make_gaussians(a, n)\n",
    "    X_te, y_te = make_gaussians(a, n)\n",
    "elif dataset == \"spiral\":\n",
    "    X_tr, y_tr = twospirals(n)\n",
    "    X_te, y_te = twospirals(n)\n",
    "else:\n",
    "    # Fail here: continuing would raise a confusing NameError on X_tr below.\n",
    "    raise ValueError(f\"Dataset not found: {dataset!r}\")\n",
    "\n",
    "train_set = torch.utils.data.TensorDataset(torch.tensor(X_tr), torch.tensor(y_tr))\n",
    "test_set = torch.utils.data.TensorDataset(torch.tensor(X_te), torch.tensor(y_te))\n",
    "\n",
    "train_loader = torch.utils.data.DataLoader(train_set, batch_size=bs, shuffle=True)\n",
    "# One batch containing the whole test set.\n",
    "test_loader = torch.utils.data.DataLoader(test_set, batch_size=len(test_set), shuffle=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Scatter the training set, one colour per label.\n",
    "for label_value, legend_name in [(0, 'class 1'), (1, 'class 2')]:\n",
    "    mask = y_tr == label_value\n",
    "    plt.plot(X_tr[mask, 0], X_tr[mask, 1], '.', label=legend_name)\n",
    "plt.legend()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class LogisticRegression(torch.nn.Module):\n",
    "    \"\"\"A single linear layer followed by an element-wise sigmoid.\n",
    "\n",
    "    Maps ``input_dim`` features to ``output_dim`` probabilities, so the output\n",
    "    can be fed directly to ``torch.nn.BCELoss``.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, input_dim, output_dim):\n",
    "        super().__init__()\n",
    "        self.linear = torch.nn.Linear(input_dim, output_dim)\n",
    "\n",
    "    def forward(self, x):\n",
    "        logits = self.linear(x)\n",
    "        return torch.sigmoid(logits)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "\n",
    "class SimpleNet(nn.Module):\n",
    "    \"\"\"Fully connected ReLU network with an optional SelectiveNet head.\n",
    "\n",
    "    Args:\n",
    "        dims: Layer widths ``[in, hidden_1, ..., out]`` (at least two entries).\n",
    "            Each consecutive pair except the last becomes a ``Linear`` layer\n",
    "            followed by ReLU; the last pair defines the output head(s).\n",
    "        mnist_like: If True, inputs are flattened with ``x.view(-1, 784)``.\n",
    "        is_selectivenet: If True, build the SelectiveNet heads (classifier f,\n",
    "            selector g, auxiliary classifier h) instead of a single ``fc`` head.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, dims, mnist_like=False, is_selectivenet=False):\n",
    "        super(SimpleNet, self).__init__()\n",
    "        dims_tup = list(zip(dims, dims[1:]))\n",
    "        self.dims = dims\n",
    "        self.mnist_like = mnist_like\n",
    "        self.layers = nn.ModuleList()\n",
    "        self.is_selectivenet = is_selectivenet\n",
    "        for i, o in dims_tup[:-1]:\n",
    "            self.layers.append(nn.Linear(i, o))\n",
    "\n",
    "        i, o = dims_tup[-1]\n",
    "        if is_selectivenet:\n",
    "            # represented as f() in the original paper\n",
    "            self.classifier = nn.Sequential(nn.Linear(i, o))\n",
    "            # represented as g() in the original paper\n",
    "            self.selector = nn.Sequential(\n",
    "                nn.Linear(i, i),\n",
    "                nn.ReLU(True),\n",
    "                nn.BatchNorm1d(i),\n",
    "                nn.Linear(i, 1),\n",
    "                nn.Sigmoid(),\n",
    "            )\n",
    "            # represented as h() in the original paper\n",
    "            self.aux_classifier = nn.Sequential(nn.Linear(i, o))\n",
    "        else:\n",
    "            self.fc = nn.Linear(i, o)\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Run the network.\n",
    "\n",
    "        Returns:\n",
    "            SelectiveNet mode: ``(log_softmax(f(z)), g(z), h(z))``.\n",
    "            Otherwise: ``sigmoid(fc(z))`` when the output width is 1, else\n",
    "            ``log_softmax(fc(z))``.\n",
    "            Here ``z`` is the last hidden activation, or the (possibly\n",
    "            flattened) input when there are no hidden layers.\n",
    "        \"\"\"\n",
    "        if self.mnist_like:\n",
    "            x = x.view(-1, 784)\n",
    "        features = x\n",
    "        for layer in self.layers:\n",
    "            features = F.relu(layer(features))\n",
    "\n",
    "        if self.is_selectivenet:\n",
    "            prediction_out = self.classifier(features)\n",
    "            selection_out = self.selector(features)\n",
    "            auxiliary_out = self.aux_classifier(features)\n",
    "            return F.log_softmax(prediction_out, -1), selection_out, auxiliary_out\n",
    "\n",
    "        logits = self.fc(features)\n",
    "        if self.dims[-1] == 1:\n",
    "            # log_softmax over a single unit is identically 0, so a width-1\n",
    "            # head always returns a sigmoid probability -- with or without\n",
    "            # hidden layers.\n",
    "            return torch.sigmoid(logits)\n",
    "        return F.log_softmax(logits, -1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Training configuration: number of epochs, SGD step size and model sizes.\n",
    "epochs, learning_rate = 1000, 5\n",
    "input_dim, output_dim = 2, 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Alternative: SimpleNet([2, 10, 10, 1]), which also ends in a sigmoid.\n",
    "model = LogisticRegression(input_dim, output_dim)\n",
    "optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)\n",
    "criterion = torch.nn.BCELoss()  # expects probabilities, not logits"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Train for `epochs` epochs, keeping a deep copy of the model after each one;\n",
    "# these per-epoch checkpoints are evaluated in the cells below.\n",
    "losses = []\n",
    "models = []\n",
    "pbar = tqdm(range(epochs))\n",
    "for epoch in pbar:\n",
    "    epoch_loss = 0.0\n",
    "    n_batches = 0\n",
    "    for x, y in train_loader:\n",
    "        x = x.float()\n",
    "        y = y.float()\n",
    "        optimizer.zero_grad()\n",
    "        output = model(x)\n",
    "        # squeeze(-1) rather than squeeze(): a final batch of size 1 must stay\n",
    "        # 1-D so its shape still matches `y`.\n",
    "        l = criterion(output.squeeze(-1), y)\n",
    "        l.backward()\n",
    "        optimizer.step()\n",
    "        # .item() keeps only the number; summing the tensors would chain every\n",
    "        # batch into one autograd graph that stays alive for the whole epoch.\n",
    "        epoch_loss += l.item()\n",
    "        n_batches += 1\n",
    "    epoch_loss /= n_batches\n",
    "    pbar.set_postfix({'loss': epoch_loss})\n",
    "    models.append(copy.deepcopy(model))\n",
    "    losses.append(epoch_loss)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.plot(losses)\n",
    "plt.xlabel(\"Epoch\")\n",
    "plt.ylabel(\"Mean batch loss\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def calculate_nntd_sum_score_noargs(predictions, start, step, k):\n",
    "    \"\"\"Weighted count of checkpoints whose prediction differs from the last one.\n",
    "\n",
    "    Args:\n",
    "        predictions: 2-D tensor with one row per example and one column per\n",
    "            checkpoint; the last column is treated as the final prediction.\n",
    "        start, step: Slice ``[start::step]`` applied to the checkpoint axis\n",
    "            before summing.\n",
    "        k: Exponent of the weights ``linspace(0, 1, T) ** k`` over the T\n",
    "            checkpoints; larger k emphasises late checkpoints.\n",
    "\n",
    "    Returns:\n",
    "        1-D float tensor with one score per row; 0 means every selected\n",
    "        checkpoint agrees with the final column.\n",
    "    \"\"\"\n",
    "    weights = torch.linspace(0, 1, predictions.shape[-1]) ** k\n",
    "    # Broadcast the last column against every column to flag disagreements.\n",
    "    disagree = (predictions != predictions[:, -1:]).int()\n",
    "    weighted = (disagree * weights)[:, start::step]\n",
    "    return weighted.sum(dim=1)\n",
    "\n",
    "def weighted_varaince(values, k):\n",
    "    \"\"\"Weighted mean of squared values along axis 1, measured from 0.\n",
    "\n",
    "    Despite the name, values are not centred on their mean, so this is a\n",
    "    weighted second moment. Column weights are ``linspace(0, 1, T) ** k``\n",
    "    over the T columns, so larger k emphasises later columns.\n",
    "\n",
    "    Args:\n",
    "        values: 2-D array-like (anything ``np.array`` accepts).\n",
    "        k: Weight exponent.\n",
    "\n",
    "    Returns:\n",
    "        1-D ``np.ndarray`` with one value per row.\n",
    "\n",
    "    Raises:\n",
    "        ZeroDivisionError: from ``np.average`` when the weights sum to zero,\n",
    "            e.g. a single column with k > 0.\n",
    "    \"\"\"\n",
    "    arr = np.array(values)\n",
    "    col_weights = np.linspace(0, 1, arr.shape[1]) ** k\n",
    "    return np.average(np.square(arr), weights=col_weights, axis=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Evaluate every checkpoint on the test set. Rows are test examples, columns\n",
    "# are checkpoints: `confidences` holds the sigmoid outputs, `predictions` the\n",
    "# rounded 0/1 labels.\n",
    "predictions = torch.full((len(y_te), len(models)), -1.0)\n",
    "confidences = torch.ones_like(predictions)\n",
    "\n",
    "with torch.no_grad():\n",
    "    offset = 0\n",
    "    # Write each batch into its own row range, so the result is correct even\n",
    "    # if test_loader yields more than one batch.\n",
    "    for x, _ in test_loader:\n",
    "        x = x.float()  # loop-invariant across checkpoints\n",
    "        rows = slice(offset, offset + len(x))\n",
    "        for i, m in enumerate(models):\n",
    "            output = m(x).flatten()\n",
    "            confidences[rows, i] = output\n",
    "            predictions[rows, i] = output.round()\n",
    "        offset += len(x)\n",
    "\n",
    "# Distance of each checkpoint's confidence from the final checkpoint's.\n",
    "confidences_norm = torch.abs(confidences - confidences[:, -1:])\n",
    "\n",
    "acc = torch.eq(torch.tensor(y_te), predictions[:,-1]).sum() / len(y_te)\n",
    "\n",
    "print(f\"Accuracy {acc}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One line per test example: each checkpoint's rounded prediction.\n",
    "for row in predictions:\n",
    "    plt.plot(row, color=\"blue\", alpha=0.25)\n",
    "plt.xlabel(\"Checkpoint (epoch)\")\n",
    "plt.ylabel(\"Predicted label\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One line per test example: distance of each checkpoint's confidence from\n",
    "# the final checkpoint's confidence.\n",
    "for row in confidences_norm:\n",
    "    plt.plot(row, color=\"blue\", alpha=0.25)\n",
    "plt.xlabel(\"Checkpoint (epoch)\")\n",
    "plt.ylabel(\"Distance to final confidence\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-example checkpoint-disagreement score over all checkpoints (k = 3).\n",
    "calculate_nntd_sum_score_noargs(predictions, start=0, step=1, k=3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-example weighted second moment of the confidence distances (k = 3).\n",
    "weighted_varaince(confidences_norm, k=3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Final model's predicted probability on a res x res grid that covers the\n",
    "# test data plus `padding` on each side, with the test points on top.\n",
    "padding = 1\n",
    "res = 100\n",
    "\n",
    "lo = X_te.min(axis=0) - padding\n",
    "hi = X_te.max(axis=0) + padding\n",
    "x_min, x_max = lo[0], hi[0]\n",
    "y_min, y_max = lo[1], hi[1]\n",
    "xx, yy = np.meshgrid(np.linspace(x_min, x_max, res), np.linspace(y_min, y_max, res))\n",
    "grid = torch.tensor(np.c_[xx.ravel(), yy.ravel()]).float()\n",
    "\n",
    "ax = plt.gca()\n",
    "ax.set_xlim(xx.min(), xx.max())\n",
    "ax.set_ylim(yy.min(), yy.max())\n",
    "\n",
    "with torch.no_grad():\n",
    "    output = model(grid).reshape(xx.shape)\n",
    "cs = ax.contourf(xx, yy, output, cmap='RdBu', alpha=.5)\n",
    "plt.colorbar(cs)\n",
    "for cls, color, legend_name in [(0, \"tab:red\", 'Class 1'), (1, \"tab:blue\", 'Class 2')]:\n",
    "    ax.scatter(X_te[y_te == cls, 0], X_te[y_te == cls, 1], color=color, label=legend_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Checkpoint-disagreement score of every grid point, using the grid `xx`, `yy`\n",
    "# built in the decision-surface cell above.\n",
    "grid = torch.tensor(np.c_[xx.ravel(), yy.ravel()]).float()  # built once for all checkpoints\n",
    "# Separate name so the test-set `predictions` are not overwritten.\n",
    "grid_predictions = torch.full((xx.size, len(models)), -1.0)\n",
    "\n",
    "ax = plt.gca()\n",
    "ax.set_xlim(xx.min(), xx.max())\n",
    "ax.set_ylim(yy.min(), yy.max())\n",
    "\n",
    "with torch.no_grad():\n",
    "    for i, m in enumerate(models):\n",
    "        grid_predictions[:, i] = m(grid).round().flatten()\n",
    "scores = calculate_nntd_sum_score_noargs(grid_predictions, 0, 1, 3)\n",
    "cs = ax.contourf(xx, yy, scores.reshape(xx.shape), cmap='binary', alpha=.75)\n",
    "plt.colorbar(cs)\n",
    "ax.scatter(X_te[y_te == 0, 0], X_te[y_te == 0, 1], color=\"tab:red\", label='Class 1')\n",
    "ax.scatter(X_te[y_te == 1, 0], X_te[y_te == 1, 1], color=\"tab:blue\", label='Class 2')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Weighted second moment of each grid point's confidence distance to the final\n",
    "# checkpoint, on the grid `xx`, `yy` from the decision-surface cell above.\n",
    "grid = torch.tensor(np.c_[xx.ravel(), yy.ravel()]).float()  # built once for all checkpoints\n",
    "# Separate names so the test-set `confidences` / `confidences_norm` survive.\n",
    "grid_confidences = torch.full((xx.size, len(models)), -1.0)\n",
    "\n",
    "ax = plt.gca()\n",
    "ax.set_xlim(xx.min(), xx.max())\n",
    "ax.set_ylim(yy.min(), yy.max())\n",
    "\n",
    "with torch.no_grad():\n",
    "    for i, m in enumerate(models):\n",
    "        grid_confidences[:, i] = m(grid).flatten()\n",
    "grid_confidences_norm = torch.abs(grid_confidences - grid_confidences[:, -1:])\n",
    "scores = weighted_varaince(grid_confidences_norm, 3)\n",
    "cs = ax.contourf(xx, yy, scores.reshape(xx.shape), cmap='binary', alpha=.5)\n",
    "plt.colorbar(cs)\n",
    "ax.scatter(X_te[y_te == 0, 0], X_te[y_te == 0, 1], color=\"tab:red\", label='Class 1')\n",
    "ax.scatter(X_te[y_te == 1, 0], X_te[y_te == 1, 1], color=\"tab:blue\", label='Class 2')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
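    "## Complete Gaussian Experiment\n",
    "\n",
    "For each class separation $a$ in `alist`, train logistic regression on two unit-covariance Gaussian clusters centred at $(-a, 0)$ and $(a, 0)$, keeping a checkpoint after every epoch. The top row shows the final model's predicted probability, the middle row the checkpoint-disagreement (SPTD) score on a grid, and the bottom row the test-set densities of the final-model score $1 - 2|p - 0.5|$ (SR) and of the SPTD score."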
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "alist = [5, 1, 0.5, 0]  # class separations: means at (-a, 0) and (a, 0)\n",
    "n = 500  # samples per class\n",
    "bs = 64\n",
    "\n",
    "epochs = 200\n",
    "input_dim = 2\n",
    "output_dim = 1\n",
    "learning_rate = 20\n",
    "\n",
    "padding = 1\n",
    "res = 100\n",
    "seed = 0  # fixes data sampling, weight init and shuffling\n",
    "\n",
    "nntd_levels = np.linspace(0, 1, 11)\n",
    "conf_levels = np.linspace(0, 1, 11)\n",
    "ticks = np.linspace(0, 1, 5)\n",
    "\n",
    "\n",
    "def _shift_clusters(X, a, n_first):\n",
    "    \"\"\"Copy of X with the first n_first rows moved by -a along x and the rest by +a.\"\"\"\n",
    "    X_a = np.copy(X)\n",
    "    X_a[:n_first, 0] -= a\n",
    "    X_a[n_first:, 0] += a\n",
    "    return X_a\n",
    "\n",
    "\n",
    "def _train_with_checkpoints(model, loader, optimizer, criterion, epochs):\n",
    "    \"\"\"Train `model`; return (deep copy after every epoch, mean batch loss per epoch).\"\"\"\n",
    "    checkpoints, epoch_losses = [], []\n",
    "    pbar = tqdm(range(epochs))\n",
    "    for _ in pbar:\n",
    "        total, n_batches = 0.0, 0\n",
    "        for x, y in loader:\n",
    "            optimizer.zero_grad()\n",
    "            # squeeze(-1) keeps a size-1 final batch 1-D, matching `y`.\n",
    "            batch_loss = criterion(model(x.float()).squeeze(-1), y.float())\n",
    "            batch_loss.backward()\n",
    "            optimizer.step()\n",
    "            # .item() keeps only the number; summing tensors would chain every\n",
    "            # batch into one autograd graph for the whole epoch.\n",
    "            total += batch_loss.item()\n",
    "            n_batches += 1\n",
    "        pbar.set_postfix({'loss': total / n_batches})\n",
    "        checkpoints.append(copy.deepcopy(model))\n",
    "        epoch_losses.append(total / n_batches)\n",
    "    return checkpoints, epoch_losses\n",
    "\n",
    "\n",
    "def _checkpoint_outputs(checkpoints, inputs):\n",
    "    \"\"\"Outputs of every checkpoint on `inputs`, flattened, one column per checkpoint.\"\"\"\n",
    "    with torch.no_grad():\n",
    "        return torch.stack([m(inputs).flatten() for m in checkpoints], dim=1)\n",
    "\n",
    "\n",
    "def _scatter_test_points(ax, X, y):\n",
    "    \"\"\"Overlay the two test classes on `ax` (red: label 0, blue: label 1).\"\"\"\n",
    "    ax.scatter(X[y == 0, 0], X[y == 0, 1], color=\"tab:red\", label='Class 1', alpha=0.33)\n",
    "    ax.scatter(X[y == 1, 0], X[y == 1, 1], color=\"tab:blue\", label='Class 2', alpha=0.33)\n",
    "\n",
    "\n",
    "np.random.seed(seed)\n",
    "torch.manual_seed(seed)\n",
    "\n",
    "fig, axes = plt.subplots(3, len(alist), figsize=(10, 4.75))\n",
    "# NOTE(review): plt.tight_layout() below recomputes spacing and likely overrides these.\n",
    "plt.subplots_adjust(wspace=0.33)\n",
    "plt.subplots_adjust(hspace=0.15)\n",
    "\n",
    "# Sample both clusters at the origin once; each `a` only shifts them apart,\n",
    "# so every column of the figure uses the same noise.\n",
    "X_tr, y_tr = make_gaussians(0, n)\n",
    "X_te, y_te = make_gaussians(0, n)\n",
    "\n",
    "for a_num, a in enumerate(alist):\n",
    "    X_tr_a = _shift_clusters(X_tr, a, n)\n",
    "    X_te_a = _shift_clusters(X_te, a, n)\n",
    "    test_inputs = torch.tensor(X_te_a).float()\n",
    "\n",
    "    train_set = torch.utils.data.TensorDataset(torch.tensor(X_tr_a), torch.tensor(y_tr))\n",
    "    train_loader = torch.utils.data.DataLoader(train_set, batch_size=bs, shuffle=True)\n",
    "\n",
    "    model = LogisticRegression(input_dim, output_dim)\n",
    "    criterion = torch.nn.BCELoss()\n",
    "    optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)\n",
    "    models, losses = _train_with_checkpoints(model, train_loader, optimizer, criterion, epochs)\n",
    "\n",
    "    # Row 2: test-set distributions of the two (rescaled) scores.\n",
    "    confidences = _checkpoint_outputs(models, test_inputs)\n",
    "    predictions = confidences.round()\n",
    "    nntd_scores = calculate_nntd_sum_score_noargs(predictions, 0, 1, 3)\n",
    "\n",
    "    # SR: 1 where the final model outputs p = 0.5, 0 where p is 0 or 1.\n",
    "    confidences_rescaled = 1 - (2 * torch.abs(confidences[:, -1] - 0.5))\n",
    "    # Scale to [0, 1]; nan_to_num maps the all-zero case (0 / 0) to 0.\n",
    "    nntd_scores_rescaled = torch.nan_to_num(nntd_scores / torch.max(nntd_scores), 0)\n",
    "\n",
    "    ax_kde = axes[2, a_num]\n",
    "    for values, label, color in [\n",
    "        (confidences_rescaled, \"\\\\texttt{SR}\", \"tab:blue\"),\n",
    "        (nntd_scores_rescaled, \"\\\\texttt{SPTD}\", \"tab:orange\"),\n",
    "    ]:\n",
    "        # A KDE of a constant sample is degenerate; draw a vertical line instead.\n",
    "        if len(torch.unique(values)) == 1:\n",
    "            ax_kde.axvline(values[0], label=label, color=color)\n",
    "        else:\n",
    "            sns.kdeplot(values, fill=True, label=label, ax=ax_kde, color=color)\n",
    "    ax_kde.set_ylabel(\"\")\n",
    "    ax_kde.set_xlim(-0.2, 1.2)\n",
    "    if a_num == 0:\n",
    "        ax_kde.legend()\n",
    "\n",
    "    acc = torch.eq(torch.tensor(y_te), predictions[:, -1]).sum() / len(y_te)\n",
    "    print(f\"Accuracy {acc}\")\n",
    "\n",
    "    # Rows 0 and 1 are drawn on a res x res grid around the shifted test data.\n",
    "    x_min, x_max = X_te_a[:, 0].min() - padding, X_te_a[:, 0].max() + padding\n",
    "    y_min, y_max = X_te_a[:, 1].min() - padding, X_te_a[:, 1].max() + padding\n",
    "    xx, yy = np.meshgrid(np.linspace(x_min, x_max, res),\n",
    "                         np.linspace(y_min, y_max, res))\n",
    "    grid = torch.tensor(np.c_[xx.ravel(), yy.ravel()]).float()  # reused for every checkpoint\n",
    "\n",
    "    # Row 0: final model's predicted probability.\n",
    "    axes[0, a_num].set_xlim(xx.min(), xx.max())\n",
    "    axes[0, a_num].set_ylim(yy.min(), yy.max())\n",
    "    with torch.no_grad():\n",
    "        surface = model(grid).reshape(xx.shape)\n",
    "    cs = axes[0, a_num].contourf(xx, yy, surface, cmap='RdBu', alpha=.5, levels=conf_levels)\n",
    "    if a_num == len(alist) - 1:\n",
    "        fig.colorbar(cs, ticks=ticks, ax=axes[0, a_num])\n",
    "    _scatter_test_points(axes[0, a_num], X_te_a, y_te)\n",
    "    axes[0, a_num].set_title(f\"$a = {a}$\")\n",
    "\n",
    "    # Row 1: rescaled checkpoint-disagreement score on the grid.\n",
    "    grid_scores = calculate_nntd_sum_score_noargs(_checkpoint_outputs(models, grid).round(), 0, 1, 3)\n",
    "    grid_scores_rescaled = torch.nan_to_num(grid_scores / torch.max(grid_scores), 0)\n",
    "    cs = axes[1, a_num].contourf(xx, yy, grid_scores_rescaled.reshape(xx.shape), cmap='Greens', alpha=.5, levels=nntd_levels)\n",
    "    if a_num == len(alist) - 1:\n",
    "        fig.colorbar(cs, ticks=ticks, ax=axes[1, a_num])\n",
    "    _scatter_test_points(axes[1, a_num], X_te_a, y_te)\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.savefig(\"gaussians.png\")\n",
    "plt.savefig(\"gaussians.pdf\")\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "rl_nntd",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.10"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
