{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Standard Library Imports\n",
    "import os\n",
    "import sys\n",
    "\n",
    "# Third-Party Library Imports\n",
    "import torch\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "from tqdm import tqdm\n",
    "\n",
    "# Local Imports\n",
    "# `clf` and `sim_dataset` are imported after adding the parent of the working\n",
    "# directory to the import path (assumed to be the project root -- confirm).\n",
    "# The membership check keeps re-runs of this cell from appending duplicate\n",
    "# entries to sys.path.\n",
    "root_dir = \"../\"\n",
    "if root_dir not in sys.path:\n",
    "    sys.path.append(root_dir)\n",
    "from clf import classifier\n",
    "import sim_dataset\n",
    "\n",
    "# Apply seaborn's default theme to the figures drawn in this notebook\n",
    "sns.set_theme()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [
    {
     "ename": "ValueError",
     "evalue": "The following variable cannot be assigned with wide-form data: `hue`",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mValueError\u001b[0m                                Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[34], line 8\u001b[0m\n\u001b[1;32m      6\u001b[0m plt\u001b[38;5;241m.\u001b[39mfigure(figsize\u001b[38;5;241m=\u001b[39m(\u001b[38;5;241m8\u001b[39m, \u001b[38;5;241m6\u001b[39m))\n\u001b[1;32m      7\u001b[0m high_scores \u001b[38;5;241m=\u001b[39m suff_scores[suff_scores \u001b[38;5;241m>\u001b[39m \u001b[38;5;241m0.1\u001b[39m]\n\u001b[0;32m----> 8\u001b[0m \u001b[43msns\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mstripplot\u001b[49m\u001b[43m(\u001b[49m\u001b[43mhigh_scores\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcolor\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43mskyblue\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mhue\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43msize\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m)\u001b[49m\n\u001b[1;32m     10\u001b[0m \u001b[38;5;66;03m# Overlay a violin plot for scores <= 0.05\u001b[39;00m\n\u001b[1;32m     11\u001b[0m sns\u001b[38;5;241m.\u001b[39mswarmplot(y\u001b[38;5;241m=\u001b[39msuff_scores[suff_scores \u001b[38;5;241m<\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;241m0.1\u001b[39m], color\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124morange\u001b[39m\u001b[38;5;124m'\u001b[39m)\n",
      "File \u001b[0;32m~/conda/envs/cuda121/lib/python3.12/site-packages/seaborn/categorical.py:2082\u001b[0m, in \u001b[0;36mstripplot\u001b[0;34m(data, x, y, hue, order, hue_order, jitter, dodge, orient, color, palette, size, edgecolor, linewidth, hue_norm, log_scale, native_scale, formatter, legend, ax, **kwargs)\u001b[0m\n\u001b[1;32m   2074\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mstripplot\u001b[39m(\n\u001b[1;32m   2075\u001b[0m     data\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mNone\u001b[39;00m, \u001b[38;5;241m*\u001b[39m, x\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mNone\u001b[39;00m, y\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mNone\u001b[39;00m, hue\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mNone\u001b[39;00m, order\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mNone\u001b[39;00m, hue_order\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mNone\u001b[39;00m,\n\u001b[1;32m   2076\u001b[0m     jitter\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m, dodge\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mFalse\u001b[39;00m, orient\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mNone\u001b[39;00m, color\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mNone\u001b[39;00m, palette\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mNone\u001b[39;00m,\n\u001b[0;32m   (...)\u001b[0m\n\u001b[1;32m   2079\u001b[0m     ax\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mNone\u001b[39;00m, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs\n\u001b[1;32m   2080\u001b[0m ):\n\u001b[0;32m-> 2082\u001b[0m     p \u001b[38;5;241m=\u001b[39m \u001b[43m_CategoricalPlotter\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m   2083\u001b[0m \u001b[43m        \u001b[49m\u001b[43mdata\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdata\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   2084\u001b[0m \u001b[43m        
\u001b[49m\u001b[43mvariables\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mdict\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mx\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43my\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43my\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mhue\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mhue\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   2085\u001b[0m \u001b[43m        \u001b[49m\u001b[43morder\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43morder\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   2086\u001b[0m \u001b[43m        \u001b[49m\u001b[43morient\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43morient\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   2087\u001b[0m \u001b[43m        \u001b[49m\u001b[43mcolor\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcolor\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   2088\u001b[0m \u001b[43m        \u001b[49m\u001b[43mlegend\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mlegend\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   2089\u001b[0m \u001b[43m    \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   2091\u001b[0m     \u001b[38;5;28;01mif\u001b[39;00m ax \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m   2092\u001b[0m         ax \u001b[38;5;241m=\u001b[39m plt\u001b[38;5;241m.\u001b[39mgca()\n",
      "File \u001b[0;32m~/conda/envs/cuda121/lib/python3.12/site-packages/seaborn/categorical.py:67\u001b[0m, in \u001b[0;36m_CategoricalPlotter.__init__\u001b[0;34m(self, data, variables, order, orient, require_numeric, color, legend)\u001b[0m\n\u001b[1;32m     56\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m__init__\u001b[39m(\n\u001b[1;32m     57\u001b[0m     \u001b[38;5;28mself\u001b[39m,\n\u001b[1;32m     58\u001b[0m     data\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mNone\u001b[39;00m,\n\u001b[0;32m   (...)\u001b[0m\n\u001b[1;32m     64\u001b[0m     legend\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mauto\u001b[39m\u001b[38;5;124m\"\u001b[39m,\n\u001b[1;32m     65\u001b[0m ):\n\u001b[0;32m---> 67\u001b[0m     \u001b[38;5;28;43msuper\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[38;5;21;43m__init__\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mdata\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdata\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mvariables\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mvariables\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m     69\u001b[0m     \u001b[38;5;66;03m# This method takes care of some bookkeeping that is necessary because the\u001b[39;00m\n\u001b[1;32m     70\u001b[0m     \u001b[38;5;66;03m# original categorical plots (prior to the 2021 refactor) had some rules that\u001b[39;00m\n\u001b[1;32m     71\u001b[0m     \u001b[38;5;66;03m# don't fit exactly into VectorPlotter logic. It may be wise to have a second\u001b[39;00m\n\u001b[0;32m   (...)\u001b[0m\n\u001b[1;32m     76\u001b[0m     \u001b[38;5;66;03m# default VectorPlotter rules. 
If we do decide to make orient part of the\u001b[39;00m\n\u001b[1;32m     77\u001b[0m     \u001b[38;5;66;03m# _base variable assignment, we'll want to figure out how to express that.\u001b[39;00m\n\u001b[1;32m     78\u001b[0m     \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39minput_format \u001b[38;5;241m==\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mwide\u001b[39m\u001b[38;5;124m\"\u001b[39m \u001b[38;5;129;01mand\u001b[39;00m orient \u001b[38;5;129;01min\u001b[39;00m [\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mh\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124my\u001b[39m\u001b[38;5;124m\"\u001b[39m]:\n",
      "File \u001b[0;32m~/conda/envs/cuda121/lib/python3.12/site-packages/seaborn/_base.py:634\u001b[0m, in \u001b[0;36mVectorPlotter.__init__\u001b[0;34m(self, data, variables)\u001b[0m\n\u001b[1;32m    629\u001b[0m \u001b[38;5;66;03m# var_ordered is relevant only for categorical axis variables, and may\u001b[39;00m\n\u001b[1;32m    630\u001b[0m \u001b[38;5;66;03m# be better handled by an internal axis information object that tracks\u001b[39;00m\n\u001b[1;32m    631\u001b[0m \u001b[38;5;66;03m# such information and is set up by the scale_* methods. The analogous\u001b[39;00m\n\u001b[1;32m    632\u001b[0m \u001b[38;5;66;03m# information for numeric axes would be information about log scales.\u001b[39;00m\n\u001b[1;32m    633\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_var_ordered \u001b[38;5;241m=\u001b[39m {\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mx\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;28;01mFalse\u001b[39;00m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124my\u001b[39m\u001b[38;5;124m\"\u001b[39m: \u001b[38;5;28;01mFalse\u001b[39;00m}  \u001b[38;5;66;03m# alt., used DefaultDict\u001b[39;00m\n\u001b[0;32m--> 634\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43massign_variables\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdata\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mvariables\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    636\u001b[0m \u001b[38;5;66;03m# TODO Lots of tests assume that these are called to initialize the\u001b[39;00m\n\u001b[1;32m    637\u001b[0m \u001b[38;5;66;03m# mappings to default values on class initialization. 
I'd prefer to\u001b[39;00m\n\u001b[1;32m    638\u001b[0m \u001b[38;5;66;03m# move away from that and only have a mapping when explicitly called.\u001b[39;00m\n\u001b[1;32m    639\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m var \u001b[38;5;129;01min\u001b[39;00m [\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mhue\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124msize\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mstyle\u001b[39m\u001b[38;5;124m\"\u001b[39m]:\n",
      "File \u001b[0;32m~/conda/envs/cuda121/lib/python3.12/site-packages/seaborn/_base.py:673\u001b[0m, in \u001b[0;36mVectorPlotter.assign_variables\u001b[0;34m(self, data, variables)\u001b[0m\n\u001b[1;32m    671\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m x \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;129;01mand\u001b[39;00m y \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m    672\u001b[0m     \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39minput_format \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mwide\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m--> 673\u001b[0m     frame, names \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_assign_variables_wideform\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdata\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mvariables\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    674\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m    675\u001b[0m     \u001b[38;5;66;03m# When dealing with long-form input, use the newer PlotData\u001b[39;00m\n\u001b[1;32m    676\u001b[0m     \u001b[38;5;66;03m# object (internal but introduced for the objects interface)\u001b[39;00m\n\u001b[1;32m    677\u001b[0m     \u001b[38;5;66;03m# to centralize / standardize data consumption logic.\u001b[39;00m\n\u001b[1;32m    678\u001b[0m     \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39minput_format \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mlong\u001b[39m\u001b[38;5;124m\"\u001b[39m\n",
      "File \u001b[0;32m~/conda/envs/cuda121/lib/python3.12/site-packages/seaborn/_base.py:723\u001b[0m, in \u001b[0;36mVectorPlotter._assign_variables_wideform\u001b[0;34m(self, data, **kwargs)\u001b[0m\n\u001b[1;32m    721\u001b[0m     err \u001b[38;5;241m=\u001b[39m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mThe following variable\u001b[39m\u001b[38;5;132;01m{\u001b[39;00ms\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m cannot be assigned with wide-form data: \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m    722\u001b[0m     err \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m, \u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;241m.\u001b[39mjoin(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m`\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mv\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m`\u001b[39m\u001b[38;5;124m\"\u001b[39m \u001b[38;5;28;01mfor\u001b[39;00m v \u001b[38;5;129;01min\u001b[39;00m assigned)\n\u001b[0;32m--> 723\u001b[0m     \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(err)\n\u001b[1;32m    725\u001b[0m \u001b[38;5;66;03m# Determine if the data object actually has any data in it\u001b[39;00m\n\u001b[1;32m    726\u001b[0m empty \u001b[38;5;241m=\u001b[39m data \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28mlen\u001b[39m(data)\n",
      "\u001b[0;31mValueError\u001b[0m: The following variable cannot be assigned with wide-form data: `hue`"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<Figure size 800x600 with 0 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Example data (replace with your actual data)\n",
    "np.random.seed(0)\n",
    "suff_scores = np.random.rand(100)  # Example array of scores between 0 and 1\n",
    "\n",
    "# Scores are split at this threshold and drawn as two overlaid point clouds\n",
    "score_threshold = 0.1\n",
    "high_scores = suff_scores[suff_scores > score_threshold]\n",
    "low_scores = suff_scores[suff_scores <= score_threshold]\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(8, 6))\n",
    "\n",
    "# Strip plot of the scores above the threshold. The values are passed as `y`:\n",
    "# a positional array is treated as wide-form data, which seaborn refuses to\n",
    "# combine with `hue` (and there is no \"size\" variable to map to a hue here).\n",
    "sns.stripplot(y=high_scores, color='skyblue', label=f'> {score_threshold}', ax=ax)\n",
    "\n",
    "# Overlay a swarm plot of the scores at or below the threshold\n",
    "sns.swarmplot(y=low_scores, color='orange', label=f'<= {score_threshold}', ax=ax)\n",
    "\n",
    "# Customize the plot\n",
    "ax.set_xlabel('Score Distribution')\n",
    "ax.set_ylabel('Scores')\n",
    "ax.set_title('Strip and Swarm Plot of suff_scores')\n",
    "ax.legend()  # entries come from the `label=` given to each plot call\n",
    "\n",
    "# Show the plot\n",
    "plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([0.5488135 , 0.71518937, 0.60276338, 0.54488318, 0.4236548 ,\n",
       "       0.64589411, 0.43758721, 0.891773  , 0.96366276, 0.38344152,\n",
       "       0.79172504, 0.52889492, 0.56804456, 0.92559664, 0.07103606,\n",
       "       0.0871293 , 0.0202184 , 0.83261985, 0.77815675, 0.87001215,\n",
       "       0.97861834, 0.79915856, 0.46147936, 0.78052918, 0.11827443,\n",
       "       0.63992102, 0.14335329, 0.94466892, 0.52184832, 0.41466194,\n",
       "       0.26455561, 0.77423369, 0.45615033, 0.56843395, 0.0187898 ,\n",
       "       0.6176355 , 0.61209572, 0.616934  , 0.94374808, 0.6818203 ,\n",
       "       0.3595079 , 0.43703195, 0.6976312 , 0.06022547, 0.66676672,\n",
       "       0.67063787, 0.21038256, 0.1289263 , 0.31542835, 0.36371077,\n",
       "       0.57019677, 0.43860151, 0.98837384, 0.10204481, 0.20887676,\n",
       "       0.16130952, 0.65310833, 0.2532916 , 0.46631077, 0.24442559,\n",
       "       0.15896958, 0.11037514, 0.65632959, 0.13818295, 0.19658236,\n",
       "       0.36872517, 0.82099323, 0.09710128, 0.83794491, 0.09609841,\n",
       "       0.97645947, 0.4686512 , 0.97676109, 0.60484552, 0.73926358,\n",
       "       0.03918779, 0.28280696, 0.12019656, 0.2961402 , 0.11872772,\n",
       "       0.31798318, 0.41426299, 0.0641475 , 0.69247212, 0.56660145,\n",
       "       0.26538949, 0.52324805, 0.09394051, 0.5759465 , 0.9292962 ,\n",
       "       0.31856895, 0.66741038, 0.13179786, 0.7163272 , 0.28940609,\n",
       "       0.18319136, 0.58651293, 0.02010755, 0.82894003, 0.00469548])"
      ]
     },
     "execution_count": 35,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Display the whole suff_scores array (the bare last expression is the cell output)\n",
    "suff_scores"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([0.5488135 , 0.71518937, 0.60276338, 0.54488318, 0.64589411,\n",
       "       0.891773  , 0.96366276, 0.79172504, 0.52889492, 0.56804456,\n",
       "       0.92559664, 0.83261985, 0.77815675, 0.87001215, 0.97861834,\n",
       "       0.79915856, 0.78052918, 0.63992102, 0.94466892, 0.52184832,\n",
       "       0.77423369, 0.56843395, 0.6176355 , 0.61209572, 0.616934  ,\n",
       "       0.94374808, 0.6818203 , 0.6976312 , 0.66676672, 0.67063787,\n",
       "       0.57019677, 0.98837384, 0.65310833, 0.65632959, 0.82099323,\n",
       "       0.83794491, 0.97645947, 0.97676109, 0.60484552, 0.73926358,\n",
       "       0.69247212, 0.56660145, 0.52324805, 0.5759465 , 0.9292962 ,\n",
       "       0.66741038, 0.7163272 , 0.58651293, 0.82894003])"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Boolean-mask selection: keep only the scores strictly greater than 0.5\n",
    "suff_scores[suff_scores > 0.5]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Classifier loaded successfully\n"
     ]
    }
   ],
   "source": [
    "# Use the GPU when torch reports one as available; otherwise fall back to the CPU\n",
    "device = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\")\n",
    "\n",
    "# Name of the experiment (also used as the directory prefix for its data files)\n",
    "exp_name = \"cooperative\"\n",
    "\n",
    "# Build the classifier with its default constructor and move it to the chosen device\n",
    "clf = classifier().to(device)\n",
    "print(\"Classifier loaded successfully\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 97,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optimization settings\n",
    "learning_rate = 1e-3\n",
    "num_epochs = 10\n",
    "optimizer = torch.optim.Adam(clf.parameters(), lr=learning_rate)\n",
    "\n",
    "# Data-simulation config files, located under <exp_name>/data\n",
    "sim_pos_config_path = os.path.join(exp_name, \"data\", \"spi1_ctcf_exp_config.json\")\n",
    "sim_neg_config_path = os.path.join(exp_name, \"data\", \"spi1_ctcf_exp_neg_config.json\")\n",
    "\n",
    "# Loader sizing\n",
    "batch_size = 4\n",
    "num_train_batches = 100\n",
    "num_val_batches = 100\n",
    "\n",
    "# The train and validation loaders take identical arguments except for the\n",
    "# number of batches. The generator is consumed in order, so the training\n",
    "# loader is created first and the validation loader second.\n",
    "clf_train_loader, clf_val_loader = (\n",
    "    sim_dataset.create_data_loader(\n",
    "        sim_pos_config_path,\n",
    "        neg_motif_config_path=sim_neg_config_path,\n",
    "        batch_size=batch_size,\n",
    "        num_batches=loader_num_batches,\n",
    "        return_configs=True,\n",
    "    )\n",
    "    for loader_num_batches in (num_train_batches, num_val_batches)\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 98,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 100/100 [00:00<00:00, 142.31it/s]\n"
     ]
    }
   ],
   "source": [
    "# Iterate the training loader once and keep the third element of every batch.\n",
    "# NOTE: x, y and configs are plain loop variables, so after this cell they stay\n",
    "# bound in the notebook namespace to the values from the last batch.\n",
    "train_configs = []\n",
    "for x, y, configs in tqdm(clf_train_loader):\n",
    "    train_configs.append(configs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 82,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 100/100 [00:00<00:00, 143.71it/s]\n"
     ]
    }
   ],
   "source": [
    "# Iterate the validation loader once and keep the third element of every batch.\n",
    "# NOTE: x, y and configs are plain loop variables, so after this cell they stay\n",
    "# bound in the notebook namespace to the values from the last batch.\n",
    "val_configs = []\n",
    "for x, y, configs in tqdm(clf_val_loader):\n",
    "    val_configs.append(configs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 99,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[(195, ['SPI1', 150, 'CTCF'], False),\n",
       " (258, [], False),\n",
       " (195, ['SPI1', 150, 'CTCF'], True),\n",
       " (258, [], True)]"
      ]
     },
     "execution_count": 99,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Show the configs collected for the first training batch\n",
    "train_configs[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 84,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[(212, ['SPI1', 150, 'CTCF'], False),\n",
       " (227, [], False),\n",
       " (212, ['SPI1', 150, 'CTCF'], True),\n",
       " (227, [], True)]"
      ]
     },
     "execution_count": 84,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Configs of the last validation batch. Indexing val_configs explicitly avoids\n",
    "# depending on the `configs` loop variable, which holds whatever the most\n",
    "# recently executed collection loop left behind.\n",
    "val_configs[-1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "cuda121",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
