{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load & plot the data\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import seaborn as sns\n",
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE: torch.load unpickles the file, so only load data.pt from a trusted source.\n",
    "# map_location=\"cpu\" keeps the loaded tensors in host memory, so the notebook\n",
    "# also runs on machines without a GPU even if the file was saved from one.\n",
    "data = torch.load(\"data.pt\", map_location=\"cpu\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Global matplotlib settings applied to the figures created after this cell.\n",
    "# NOTE(review): text.usetex hands text rendering to LaTeX, so a LaTeX install\n",
    "# is presumably required and the serif font listed here may not be the one\n",
    "# actually used -- confirm.\n",
    "plt.rcParams.update(\n",
    "    {\n",
    "        \"axes.axisbelow\": True,\n",
    "        \"font.family\": \"serif\",\n",
    "        \"font.serif\": [\"Cambria Math\"],\n",
    "        \"text.usetex\": True,\n",
    "    }\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create the figure together with its axes so every call below targets `ax`.\n",
    "fig, ax = plt.subplots(figsize=(8, 8))\n",
    "\n",
    "# Visuals\n",
    "# despine() only acts on axes that already exist, so it has to run after the\n",
    "# axes are created; passing ax makes the target explicit.\n",
    "sns.despine(ax=ax, bottom=True, left=True)\n",
    "ax.tick_params(axis=\"both\", labelsize=20)\n",
    "ax.grid(True, linestyle=\"--\", alpha=0.7)\n",
    "ax.set_xticks([-6, -3, 0, 3, 6])\n",
    "ax.set_yticks([-6, -3, 0, 3, 6])\n",
    "# Limits sit slightly outside the outermost ticks.\n",
    "ax.set_xlim(-6.02, 6.02)\n",
    "ax.set_ylim(-6.02, 6.02)\n",
    "ax.set_xlabel(\"Feature X\", fontsize=30)\n",
    "ax.set_ylabel(\"Feature Y\", fontsize=30)\n",
    "ax.set_aspect(\"equal\")\n",
    "\n",
    "# Plot the data: one scatter layer per class, sharing every style setting\n",
    "# except label and color.\n",
    "# Assumes data is indexed as (class, sample, feature) -- TODO confirm.\n",
    "class_styles = [(\"Class A\", \"#c96004\"), (\"Class B\", \"#e9bf9a\")]\n",
    "for class_idx, (class_label, class_color) in enumerate(class_styles):\n",
    "    sns.scatterplot(\n",
    "        x=data[class_idx, :, 0],\n",
    "        y=data[class_idx, :, 1],\n",
    "        alpha=0.9,\n",
    "        s=50,\n",
    "        color=class_color,\n",
    "        label=class_label,\n",
    "        linewidth=0.0,\n",
    "        edgecolor=\"grey\",\n",
    "        ax=ax,\n",
    "    )\n",
    "\n",
    "ax.legend(fontsize=30)\n",
    "fig.tight_layout()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "uncertainty",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
