{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "24a86aca",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.insert(0, '../Results')\n",
    "sys.path.insert(0, '../../../src')\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import pickle\n",
    "import os\n",
    "from visualization import perc, SetPlotRC, ApplyFont"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "78f3b3aa",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create the Figures output directory if needed. exist_ok=True keeps the cell idempotent on\n",
    "# re-runs and avoids the check-then-create race of os.path.exists + os.mkdir; unlike the\n",
    "# exists-check, it also raises if a regular file named Figures is in the way.\n",
    "os.makedirs(\"Figures\", exist_ok=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6f3f81ab",
   "metadata": {},
   "source": [
    "# MNIST Hidden Layer Size"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "ccf9ac97",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(16, 7)\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "Index(['setting_number', 'seed', 'Model', 'Hyperparams', 'Trn_ACC_list',\n",
       "       'Tst_ACC_list', 'forward_backward_weight_angle_list'],\n",
       "      dtype='object')"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Load the pickled MNIST hidden-layer-size ablation results.\n",
    "# NOTE: unpickling can execute arbitrary code - only load trusted result files.\n",
    "df_results = pd.read_pickle(r\"../Results/simulation_results_CorInfoMax_MNIST_HiddenLayerSize_Ablation.pkl\")\n",
    "# Fail fast if the columns the next cell reads are missing from the pickle.\n",
    "assert {'Hyperparams', 'Trn_ACC_list', 'Tst_ACC_list'}.issubset(df_results.columns), \"results pickle is missing expected columns\"\n",
    "print(df_results.shape)\n",
    "df_results.columns"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "8b0dcace",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>setting_number</th>\n",
       "      <th>seed</th>\n",
       "      <th>Model</th>\n",
       "      <th>Hyperparams</th>\n",
       "      <th>Trn_ACC_list</th>\n",
       "      <th>Tst_ACC_list</th>\n",
       "      <th>forward_backward_weight_angle_list</th>\n",
       "      <th>Trn_ACC</th>\n",
       "      <th>Tst_ACC</th>\n",
       "      <th>hidden_layer_size</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>CorInfoMax</td>\n",
       "      <td>{'lr_start': {'ff': [1.1, 0.75, 0.6], 'fb': [n...</td>\n",
       "      <td>[0.9295, 0.94345, 0.9562833333333334, 0.961533...</td>\n",
       "      <td>[0.93, 0.9413, 0.953, 0.9558, 0.9539, 0.961, 0...</td>\n",
       "      <td>[[89.69414520263672, 89.41947174072266], [88.1...</td>\n",
       "      <td>0.986383</td>\n",
       "      <td>0.9732</td>\n",
       "      <td>256</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>1</td>\n",
       "      <td>10</td>\n",
       "      <td>CorInfoMax</td>\n",
       "      <td>{'lr_start': {'ff': [1.1, 0.75, 0.6], 'fb': [n...</td>\n",
       "      <td>[0.92725, 0.9481166666666667, 0.95445, 0.9669,...</td>\n",
       "      <td>[0.9273, 0.9472, 0.9508, 0.9611, 0.9601, 0.963...</td>\n",
       "      <td>[[90.03373718261719, 89.5990219116211], [88.27...</td>\n",
       "      <td>0.988100</td>\n",
       "      <td>0.9737</td>\n",
       "      <td>256</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>2</td>\n",
       "      <td>0</td>\n",
       "      <td>CorInfoMax</td>\n",
       "      <td>{'lr_start': {'ff': [1.1, 0.75, 0.6], 'fb': [n...</td>\n",
       "      <td>[0.9250166666666667, 0.9432, 0.9572, 0.9647, 0...</td>\n",
       "      <td>[0.9273, 0.9417, 0.9532, 0.9591, 0.9637, 0.962...</td>\n",
       "      <td>[[90.01834106445312, 89.76544189453125], [89.2...</td>\n",
       "      <td>0.990217</td>\n",
       "      <td>0.9756</td>\n",
       "      <td>512</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>2</td>\n",
       "      <td>10</td>\n",
       "      <td>CorInfoMax</td>\n",
       "      <td>{'lr_start': {'ff': [1.1, 0.75, 0.6], 'fb': [n...</td>\n",
       "      <td>[0.91335, 0.9494, 0.9568, 0.9659, 0.9667, 0.97...</td>\n",
       "      <td>[0.9127, 0.946, 0.9514, 0.9598, 0.9602, 0.9652...</td>\n",
       "      <td>[[90.15558624267578, 87.85167694091797], [89.3...</td>\n",
       "      <td>0.989850</td>\n",
       "      <td>0.9741</td>\n",
       "      <td>512</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>3</td>\n",
       "      <td>0</td>\n",
       "      <td>CorInfoMax</td>\n",
       "      <td>{'lr_start': {'ff': [1.1, 0.75, 0.6], 'fb': [n...</td>\n",
       "      <td>[0.9267666666666666, 0.94815, 0.95826666666666...</td>\n",
       "      <td>[0.9264, 0.9448, 0.9548, 0.9607, 0.961, 0.9673...</td>\n",
       "      <td>[[89.92867279052734, 90.13568115234375], [89.5...</td>\n",
       "      <td>0.991300</td>\n",
       "      <td>0.9786</td>\n",
       "      <td>1024</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>3</td>\n",
       "      <td>10</td>\n",
       "      <td>CorInfoMax</td>\n",
       "      <td>{'lr_start': {'ff': [1.1, 0.75, 0.6], 'fb': [n...</td>\n",
       "      <td>[0.9185333333333333, 0.9467166666666667, 0.959...</td>\n",
       "      <td>[0.922, 0.9445, 0.9526, 0.9617, 0.9614, 0.9663...</td>\n",
       "      <td>[[90.042724609375, 88.67874908447266], [89.657...</td>\n",
       "      <td>0.991200</td>\n",
       "      <td>0.9786</td>\n",
       "      <td>1024</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>4</td>\n",
       "      <td>0</td>\n",
       "      <td>CorInfoMax</td>\n",
       "      <td>{'lr_start': {'ff': [1.1, 0.75, 0.6], 'fb': [n...</td>\n",
       "      <td>[0.928, 0.9471833333333334, 0.9569333333333333...</td>\n",
       "      <td>[0.9304, 0.9426, 0.9518, 0.9617, 0.9634, 0.967...</td>\n",
       "      <td>[[89.99951934814453, 90.03453826904297], [89.8...</td>\n",
       "      <td>0.992433</td>\n",
       "      <td>0.9787</td>\n",
       "      <td>2048</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>4</td>\n",
       "      <td>10</td>\n",
       "      <td>CorInfoMax</td>\n",
       "      <td>{'lr_start': {'ff': [1.1, 0.75, 0.6], 'fb': [n...</td>\n",
       "      <td>[0.9283666666666667, 0.94775, 0.95603333333333...</td>\n",
       "      <td>[0.9296, 0.944, 0.9517, 0.9622, 0.9649, 0.9636...</td>\n",
       "      <td>[[90.03175354003906, 89.64662170410156], [89.8...</td>\n",
       "      <td>0.992050</td>\n",
       "      <td>0.9794</td>\n",
       "      <td>2048</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>5</td>\n",
       "      <td>0</td>\n",
       "      <td>CorInfoMax</td>\n",
       "      <td>{'lr_start': {'ff': [1.1, 0.75, 0.6], 'fb': [n...</td>\n",
       "      <td>[0.9272333333333334, 0.9417166666666666, 0.955...</td>\n",
       "      <td>[0.9296, 0.9395, 0.9525, 0.9548, 0.9542, 0.961...</td>\n",
       "      <td>[[89.69414520263672, 89.4194564819336], [88.14...</td>\n",
       "      <td>0.986000</td>\n",
       "      <td>0.9725</td>\n",
       "      <td>256</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9</th>\n",
       "      <td>5</td>\n",
       "      <td>10</td>\n",
       "      <td>CorInfoMax</td>\n",
       "      <td>{'lr_start': {'ff': [1.1, 0.75, 0.6], 'fb': [n...</td>\n",
       "      <td>[0.9273333333333333, 0.94745, 0.95261666666666...</td>\n",
       "      <td>[0.9273, 0.9476, 0.9494, 0.9599, 0.9594, 0.964...</td>\n",
       "      <td>[[90.03373718261719, 89.59906768798828], [88.2...</td>\n",
       "      <td>0.986950</td>\n",
       "      <td>0.9740</td>\n",
       "      <td>256</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>10</th>\n",
       "      <td>6</td>\n",
       "      <td>0</td>\n",
       "      <td>CorInfoMax</td>\n",
       "      <td>{'lr_start': {'ff': [1.1, 0.75, 0.6], 'fb': [n...</td>\n",
       "      <td>[0.9249, 0.9439666666666666, 0.95745, 0.963333...</td>\n",
       "      <td>[0.9273, 0.9425, 0.9535, 0.9576, 0.9635, 0.961...</td>\n",
       "      <td>[[90.01834106445312, 89.76543426513672], [89.2...</td>\n",
       "      <td>0.989583</td>\n",
       "      <td>0.9741</td>\n",
       "      <td>512</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>11</th>\n",
       "      <td>6</td>\n",
       "      <td>10</td>\n",
       "      <td>CorInfoMax</td>\n",
       "      <td>{'lr_start': {'ff': [1.1, 0.75, 0.6], 'fb': [n...</td>\n",
       "      <td>[0.9138, 0.9487666666666666, 0.954283333333333...</td>\n",
       "      <td>[0.9147, 0.9446, 0.9483, 0.9602, 0.9596, 0.963...</td>\n",
       "      <td>[[90.15558624267578, 87.85172271728516], [89.3...</td>\n",
       "      <td>0.989367</td>\n",
       "      <td>0.9741</td>\n",
       "      <td>512</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>12</th>\n",
       "      <td>7</td>\n",
       "      <td>0</td>\n",
       "      <td>CorInfoMax</td>\n",
       "      <td>{'lr_start': {'ff': [1.1, 0.75, 0.6], 'fb': [n...</td>\n",
       "      <td>[0.9259666666666667, 0.9470833333333334, 0.956...</td>\n",
       "      <td>[0.926, 0.943, 0.9521, 0.9611, 0.9603, 0.9669,...</td>\n",
       "      <td>[[89.92867279052734, 90.13567352294922], [89.5...</td>\n",
       "      <td>0.990617</td>\n",
       "      <td>0.9774</td>\n",
       "      <td>1024</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>13</th>\n",
       "      <td>7</td>\n",
       "      <td>10</td>\n",
       "      <td>CorInfoMax</td>\n",
       "      <td>{'lr_start': {'ff': [1.1, 0.75, 0.6], 'fb': [n...</td>\n",
       "      <td>[0.91825, 0.9462166666666667, 0.95956666666666...</td>\n",
       "      <td>[0.921, 0.9444, 0.9532, 0.9609, 0.9603, 0.9663...</td>\n",
       "      <td>[[90.042724609375, 88.67878723144531], [89.661...</td>\n",
       "      <td>0.990767</td>\n",
       "      <td>0.9780</td>\n",
       "      <td>1024</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>14</th>\n",
       "      <td>8</td>\n",
       "      <td>0</td>\n",
       "      <td>CorInfoMax</td>\n",
       "      <td>{'lr_start': {'ff': [1.1, 0.75, 0.6], 'fb': [n...</td>\n",
       "      <td>[0.9275666666666667, 0.94495, 0.95506666666666...</td>\n",
       "      <td>[0.9292, 0.9415, 0.9511, 0.9612, 0.9627, 0.967...</td>\n",
       "      <td>[[89.99951171875, 90.03453063964844], [89.8161...</td>\n",
       "      <td>0.991750</td>\n",
       "      <td>0.9780</td>\n",
       "      <td>2048</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>15</th>\n",
       "      <td>8</td>\n",
       "      <td>10</td>\n",
       "      <td>CorInfoMax</td>\n",
       "      <td>{'lr_start': {'ff': [1.1, 0.75, 0.6], 'fb': [n...</td>\n",
       "      <td>[0.9284666666666667, 0.9469833333333333, 0.954...</td>\n",
       "      <td>[0.9294, 0.9436, 0.9509, 0.9612, 0.9634, 0.963...</td>\n",
       "      <td>[[90.03175354003906, 89.64667510986328], [89.8...</td>\n",
       "      <td>0.991767</td>\n",
       "      <td>0.9788</td>\n",
       "      <td>2048</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   setting_number seed       Model  \\\n",
       "0               1    0  CorInfoMax   \n",
       "1               1   10  CorInfoMax   \n",
       "2               2    0  CorInfoMax   \n",
       "3               2   10  CorInfoMax   \n",
       "4               3    0  CorInfoMax   \n",
       "5               3   10  CorInfoMax   \n",
       "6               4    0  CorInfoMax   \n",
       "7               4   10  CorInfoMax   \n",
       "8               5    0  CorInfoMax   \n",
       "9               5   10  CorInfoMax   \n",
       "10              6    0  CorInfoMax   \n",
       "11              6   10  CorInfoMax   \n",
       "12              7    0  CorInfoMax   \n",
       "13              7   10  CorInfoMax   \n",
       "14              8    0  CorInfoMax   \n",
       "15              8   10  CorInfoMax   \n",
       "\n",
       "                                          Hyperparams  \\\n",
       "0   {'lr_start': {'ff': [1.1, 0.75, 0.6], 'fb': [n...   \n",
       "1   {'lr_start': {'ff': [1.1, 0.75, 0.6], 'fb': [n...   \n",
       "2   {'lr_start': {'ff': [1.1, 0.75, 0.6], 'fb': [n...   \n",
       "3   {'lr_start': {'ff': [1.1, 0.75, 0.6], 'fb': [n...   \n",
       "4   {'lr_start': {'ff': [1.1, 0.75, 0.6], 'fb': [n...   \n",
       "5   {'lr_start': {'ff': [1.1, 0.75, 0.6], 'fb': [n...   \n",
       "6   {'lr_start': {'ff': [1.1, 0.75, 0.6], 'fb': [n...   \n",
       "7   {'lr_start': {'ff': [1.1, 0.75, 0.6], 'fb': [n...   \n",
       "8   {'lr_start': {'ff': [1.1, 0.75, 0.6], 'fb': [n...   \n",
       "9   {'lr_start': {'ff': [1.1, 0.75, 0.6], 'fb': [n...   \n",
       "10  {'lr_start': {'ff': [1.1, 0.75, 0.6], 'fb': [n...   \n",
       "11  {'lr_start': {'ff': [1.1, 0.75, 0.6], 'fb': [n...   \n",
       "12  {'lr_start': {'ff': [1.1, 0.75, 0.6], 'fb': [n...   \n",
       "13  {'lr_start': {'ff': [1.1, 0.75, 0.6], 'fb': [n...   \n",
       "14  {'lr_start': {'ff': [1.1, 0.75, 0.6], 'fb': [n...   \n",
       "15  {'lr_start': {'ff': [1.1, 0.75, 0.6], 'fb': [n...   \n",
       "\n",
       "                                         Trn_ACC_list  \\\n",
       "0   [0.9295, 0.94345, 0.9562833333333334, 0.961533...   \n",
       "1   [0.92725, 0.9481166666666667, 0.95445, 0.9669,...   \n",
       "2   [0.9250166666666667, 0.9432, 0.9572, 0.9647, 0...   \n",
       "3   [0.91335, 0.9494, 0.9568, 0.9659, 0.9667, 0.97...   \n",
       "4   [0.9267666666666666, 0.94815, 0.95826666666666...   \n",
       "5   [0.9185333333333333, 0.9467166666666667, 0.959...   \n",
       "6   [0.928, 0.9471833333333334, 0.9569333333333333...   \n",
       "7   [0.9283666666666667, 0.94775, 0.95603333333333...   \n",
       "8   [0.9272333333333334, 0.9417166666666666, 0.955...   \n",
       "9   [0.9273333333333333, 0.94745, 0.95261666666666...   \n",
       "10  [0.9249, 0.9439666666666666, 0.95745, 0.963333...   \n",
       "11  [0.9138, 0.9487666666666666, 0.954283333333333...   \n",
       "12  [0.9259666666666667, 0.9470833333333334, 0.956...   \n",
       "13  [0.91825, 0.9462166666666667, 0.95956666666666...   \n",
       "14  [0.9275666666666667, 0.94495, 0.95506666666666...   \n",
       "15  [0.9284666666666667, 0.9469833333333333, 0.954...   \n",
       "\n",
       "                                         Tst_ACC_list  \\\n",
       "0   [0.93, 0.9413, 0.953, 0.9558, 0.9539, 0.961, 0...   \n",
       "1   [0.9273, 0.9472, 0.9508, 0.9611, 0.9601, 0.963...   \n",
       "2   [0.9273, 0.9417, 0.9532, 0.9591, 0.9637, 0.962...   \n",
       "3   [0.9127, 0.946, 0.9514, 0.9598, 0.9602, 0.9652...   \n",
       "4   [0.9264, 0.9448, 0.9548, 0.9607, 0.961, 0.9673...   \n",
       "5   [0.922, 0.9445, 0.9526, 0.9617, 0.9614, 0.9663...   \n",
       "6   [0.9304, 0.9426, 0.9518, 0.9617, 0.9634, 0.967...   \n",
       "7   [0.9296, 0.944, 0.9517, 0.9622, 0.9649, 0.9636...   \n",
       "8   [0.9296, 0.9395, 0.9525, 0.9548, 0.9542, 0.961...   \n",
       "9   [0.9273, 0.9476, 0.9494, 0.9599, 0.9594, 0.964...   \n",
       "10  [0.9273, 0.9425, 0.9535, 0.9576, 0.9635, 0.961...   \n",
       "11  [0.9147, 0.9446, 0.9483, 0.9602, 0.9596, 0.963...   \n",
       "12  [0.926, 0.943, 0.9521, 0.9611, 0.9603, 0.9669,...   \n",
       "13  [0.921, 0.9444, 0.9532, 0.9609, 0.9603, 0.9663...   \n",
       "14  [0.9292, 0.9415, 0.9511, 0.9612, 0.9627, 0.967...   \n",
       "15  [0.9294, 0.9436, 0.9509, 0.9612, 0.9634, 0.963...   \n",
       "\n",
       "                   forward_backward_weight_angle_list   Trn_ACC  Tst_ACC  \\\n",
       "0   [[89.69414520263672, 89.41947174072266], [88.1...  0.986383   0.9732   \n",
       "1   [[90.03373718261719, 89.5990219116211], [88.27...  0.988100   0.9737   \n",
       "2   [[90.01834106445312, 89.76544189453125], [89.2...  0.990217   0.9756   \n",
       "3   [[90.15558624267578, 87.85167694091797], [89.3...  0.989850   0.9741   \n",
       "4   [[89.92867279052734, 90.13568115234375], [89.5...  0.991300   0.9786   \n",
       "5   [[90.042724609375, 88.67874908447266], [89.657...  0.991200   0.9786   \n",
       "6   [[89.99951934814453, 90.03453826904297], [89.8...  0.992433   0.9787   \n",
       "7   [[90.03175354003906, 89.64662170410156], [89.8...  0.992050   0.9794   \n",
       "8   [[89.69414520263672, 89.4194564819336], [88.14...  0.986000   0.9725   \n",
       "9   [[90.03373718261719, 89.59906768798828], [88.2...  0.986950   0.9740   \n",
       "10  [[90.01834106445312, 89.76543426513672], [89.2...  0.989583   0.9741   \n",
       "11  [[90.15558624267578, 87.85172271728516], [89.3...  0.989367   0.9741   \n",
       "12  [[89.92867279052734, 90.13567352294922], [89.5...  0.990617   0.9774   \n",
       "13  [[90.042724609375, 88.67878723144531], [89.661...  0.990767   0.9780   \n",
       "14  [[89.99951171875, 90.03453063964844], [89.8161...  0.991750   0.9780   \n",
       "15  [[90.03175354003906, 89.64667510986328], [89.8...  0.991767   0.9788   \n",
       "\n",
       "    hidden_layer_size  \n",
       "0                 256  \n",
       "1                 256  \n",
       "2                 512  \n",
       "3                 512  \n",
       "4                1024  \n",
       "5                1024  \n",
       "6                2048  \n",
       "7                2048  \n",
       "8                 256  \n",
       "9                 256  \n",
       "10                512  \n",
       "11                512  \n",
       "12               1024  \n",
       "13               1024  \n",
       "14               2048  \n",
       "15               2048  "
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Summarize each run by the last recorded train/test accuracy in its accuracy list, and pull\n",
    "# the hidden layer size out of the Hyperparams dict. Column-wise .map reads only the needed\n",
    "# column instead of materializing every row as a Series the way apply(axis=1) does.\n",
    "df_results['Trn_ACC'] = df_results['Trn_ACC_list'].map(lambda accs: accs[-1])\n",
    "df_results['Tst_ACC'] = df_results['Tst_ACC_list'].map(lambda accs: accs[-1])\n",
    "df_results[\"hidden_layer_size\"] = df_results[\"Hyperparams\"].map(lambda hp: hp[\"hidden_layer_size\"])\n",
    "df_results"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e04020fb",
   "metadata": {},
   "source": [
    "# CIFAR10 Hidden Layer Size"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "dad692db",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(37, 7)\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "Index(['setting_number', 'seed', 'Model', 'Hyperparams', 'Trn_ACC_list',\n",
       "       'Tst_ACC_list', 'forward_backward_weight_angle_list'],\n",
       "      dtype='object')"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Load the pickled CIFAR10 hidden-layer-size ablation results. This rebinds df_results,\n",
    "# replacing the MNIST frame loaded earlier in the notebook.\n",
    "# NOTE: unpickling can execute arbitrary code - only load trusted result files.\n",
    "df_results = pd.read_pickle(r\"../Results/simulation_results_CorInfoMax_CIFAR10_HiddenLayerSize_Ablation.pkl\")\n",
    "# Fail fast if the columns needed to derive Trn_ACC / Tst_ACC / hidden_layer_size are missing.\n",
    "assert {'Hyperparams', 'Trn_ACC_list', 'Tst_ACC_list'}.issubset(df_results.columns), \"results pickle is missing expected columns\"\n",
    "print(df_results.shape)\n",
    "df_results.columns"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "d54ffb54",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>setting_number</th>\n",
       "      <th>seed</th>\n",
       "      <th>Model</th>\n",
       "      <th>Hyperparams</th>\n",
       "      <th>Trn_ACC_list</th>\n",
       "      <th>Tst_ACC_list</th>\n",
       "      <th>forward_backward_weight_angle_list</th>\n",
       "      <th>Trn_ACC</th>\n",
       "      <th>Tst_ACC</th>\n",
       "      <th>hidden_layer_size</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>CorInfoMax</td>\n",
       "      <td>{'lr_start': {'ff': [0.08, 0.04], 'fb': [nan, ...</td>\n",
       "      <td>[0.30378, 0.3602, 0.4009, 0.42444, 0.43846, 0....</td>\n",
       "      <td>[0.3055, 0.3531, 0.3921, 0.4091, 0.4186, 0.431...</td>\n",
       "      <td>[[90.66891479492188], [89.83344268798828], [88...</td>\n",
       "      <td>0.57240</td>\n",
       "      <td>0.4851</td>\n",
       "      <td>256</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>1</td>\n",
       "      <td>10</td>\n",
       "      <td>CorInfoMax</td>\n",
       "      <td>{'lr_start': {'ff': [0.08, 0.04], 'fb': [nan, ...</td>\n",
       "      <td>[0.2945, 0.36132, 0.39258, 0.41726, 0.43206, 0...</td>\n",
       "      <td>[0.2932, 0.3612, 0.3854, 0.408, 0.4159, 0.4289...</td>\n",
       "      <td>[[91.4767837524414], [90.19556427001953], [89....</td>\n",
       "      <td>0.57090</td>\n",
       "      <td>0.4804</td>\n",
       "      <td>256</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>1</td>\n",
       "      <td>20</td>\n",
       "      <td>CorInfoMax</td>\n",
       "      <td>{'lr_start': {'ff': [0.08, 0.04], 'fb': [nan, ...</td>\n",
       "      <td>[0.2737, 0.33824, 0.38382, 0.41152, 0.43064, 0...</td>\n",
       "      <td>[0.2751, 0.3304, 0.3752, 0.3999, 0.4158, 0.427...</td>\n",
       "      <td>[[90.09364318847656], [89.2579345703125], [88....</td>\n",
       "      <td>0.57776</td>\n",
       "      <td>0.4851</td>\n",
       "      <td>256</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>1</td>\n",
       "      <td>30</td>\n",
       "      <td>CorInfoMax</td>\n",
       "      <td>{'lr_start': {'ff': [0.08, 0.04], 'fb': [nan, ...</td>\n",
       "      <td>[0.32736, 0.37874, 0.40892, 0.43212, 0.44806, ...</td>\n",
       "      <td>[0.3217, 0.3749, 0.4, 0.4155, 0.4245, 0.4405, ...</td>\n",
       "      <td>[[88.9329605102539], [87.81392669677734], [86....</td>\n",
       "      <td>0.58536</td>\n",
       "      <td>0.4893</td>\n",
       "      <td>256</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>1</td>\n",
       "      <td>40</td>\n",
       "      <td>CorInfoMax</td>\n",
       "      <td>{'lr_start': {'ff': [0.08, 0.04], 'fb': [nan, ...</td>\n",
       "      <td>[0.2848, 0.3606, 0.39442, 0.42408, 0.44216, 0....</td>\n",
       "      <td>[0.2786, 0.3563, 0.3856, 0.4127, 0.4257, 0.434...</td>\n",
       "      <td>[[89.21763610839844], [88.2616195678711], [87....</td>\n",
       "      <td>0.57676</td>\n",
       "      <td>0.4846</td>\n",
       "      <td>256</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>1</td>\n",
       "      <td>50</td>\n",
       "      <td>CorInfoMax</td>\n",
       "      <td>{'lr_start': {'ff': [0.08, 0.04], 'fb': [nan, ...</td>\n",
       "      <td>[0.30248, 0.3607, 0.39364, 0.41522, 0.43698, 0...</td>\n",
       "      <td>[0.3046, 0.3658, 0.3841, 0.4049, 0.415, 0.4234...</td>\n",
       "      <td>[[89.66824340820312], [88.52716064453125], [87...</td>\n",
       "      <td>0.57188</td>\n",
       "      <td>0.4836</td>\n",
       "      <td>256</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>1</td>\n",
       "      <td>60</td>\n",
       "      <td>CorInfoMax</td>\n",
       "      <td>{'lr_start': {'ff': [0.08, 0.04], 'fb': [nan, ...</td>\n",
       "      <td>[0.30636, 0.35824, 0.3885, 0.408, 0.42856, 0.4...</td>\n",
       "      <td>[0.303, 0.3493, 0.38, 0.3945, 0.414, 0.4264, 0...</td>\n",
       "      <td>[[90.69612121582031], [89.47639465332031], [88...</td>\n",
       "      <td>0.57640</td>\n",
       "      <td>0.4894</td>\n",
       "      <td>256</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>1</td>\n",
       "      <td>70</td>\n",
       "      <td>CorInfoMax</td>\n",
       "      <td>{'lr_start': {'ff': [0.08, 0.04], 'fb': [nan, ...</td>\n",
       "      <td>[0.28724, 0.34404, 0.38786, 0.41426, 0.43632, ...</td>\n",
       "      <td>[0.2966, 0.3421, 0.3853, 0.4048, 0.4193, 0.43,...</td>\n",
       "      <td>[[91.867919921875], [90.5802001953125], [89.42...</td>\n",
       "      <td>0.58434</td>\n",
       "      <td>0.4823</td>\n",
       "      <td>256</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>1</td>\n",
       "      <td>80</td>\n",
       "      <td>CorInfoMax</td>\n",
       "      <td>{'lr_start': {'ff': [0.08, 0.04], 'fb': [nan, ...</td>\n",
       "      <td>[0.2998, 0.37126, 0.40296, 0.4225, 0.44524, 0....</td>\n",
       "      <td>[0.2942, 0.3674, 0.3936, 0.4092, 0.4219, 0.433...</td>\n",
       "      <td>[[87.8289794921875], [87.10165405273438], [86....</td>\n",
       "      <td>0.58418</td>\n",
       "      <td>0.4852</td>\n",
       "      <td>256</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9</th>\n",
       "      <td>1</td>\n",
       "      <td>90</td>\n",
       "      <td>CorInfoMax</td>\n",
       "      <td>{'lr_start': {'ff': [0.08, 0.04], 'fb': [nan, ...</td>\n",
       "      <td>[0.29176, 0.37032, 0.40556, 0.42488, 0.44386, ...</td>\n",
       "      <td>[0.288, 0.3682, 0.3972, 0.4125, 0.4245, 0.4335...</td>\n",
       "      <td>[[90.79267883300781], [90.06278991699219], [89...</td>\n",
       "      <td>0.57946</td>\n",
       "      <td>0.4814</td>\n",
       "      <td>256</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>10</th>\n",
       "      <td>2</td>\n",
       "      <td>0</td>\n",
       "      <td>CorInfoMax</td>\n",
       "      <td>{'lr_start': {'ff': [0.08, 0.04], 'fb': [nan, ...</td>\n",
       "      <td>[0.29002, 0.37746, 0.41332, 0.44086, 0.4619, 0...</td>\n",
       "      <td>[0.2912, 0.3788, 0.4018, 0.4235, 0.4467, 0.457...</td>\n",
       "      <td>[[90.7146224975586], [89.9310302734375], [89.0...</td>\n",
       "      <td>0.62030</td>\n",
       "      <td>0.5022</td>\n",
       "      <td>512</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>11</th>\n",
       "      <td>2</td>\n",
       "      <td>10</td>\n",
       "      <td>CorInfoMax</td>\n",
       "      <td>{'lr_start': {'ff': [0.08, 0.04], 'fb': [nan, ...</td>\n",
       "      <td>[0.33788, 0.38504, 0.42292, 0.45098, 0.46914, ...</td>\n",
       "      <td>[0.344, 0.3771, 0.4102, 0.4339, 0.4491, 0.4569...</td>\n",
       "      <td>[[89.55988311767578], [88.61637115478516], [87...</td>\n",
       "      <td>0.61694</td>\n",
       "      <td>0.5007</td>\n",
       "      <td>512</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>12</th>\n",
       "      <td>2</td>\n",
       "      <td>20</td>\n",
       "      <td>CorInfoMax</td>\n",
       "      <td>{'lr_start': {'ff': [0.08, 0.04], 'fb': [nan, ...</td>\n",
       "      <td>[0.3364, 0.40234, 0.43416, 0.45748, 0.47716, 0...</td>\n",
       "      <td>[0.3318, 0.3898, 0.416, 0.4337, 0.4515, 0.4557...</td>\n",
       "      <td>[[88.46021270751953], [87.66310119628906], [86...</td>\n",
       "      <td>0.62080</td>\n",
       "      <td>0.5061</td>\n",
       "      <td>512</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>13</th>\n",
       "      <td>2</td>\n",
       "      <td>30</td>\n",
       "      <td>CorInfoMax</td>\n",
       "      <td>{'lr_start': {'ff': [0.08, 0.04], 'fb': [nan, ...</td>\n",
       "      <td>[0.33342, 0.38922, 0.42622, 0.4472, 0.4677, 0....</td>\n",
       "      <td>[0.3271, 0.3886, 0.4201, 0.4307, 0.4486, 0.458...</td>\n",
       "      <td>[[90.3956069946289], [89.62022399902344], [88....</td>\n",
       "      <td>0.61636</td>\n",
       "      <td>0.5065</td>\n",
       "      <td>512</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>14</th>\n",
       "      <td>2</td>\n",
       "      <td>40</td>\n",
       "      <td>CorInfoMax</td>\n",
       "      <td>{'lr_start': {'ff': [0.08, 0.04], 'fb': [nan, ...</td>\n",
       "      <td>[0.3289, 0.37998, 0.41484, 0.4423, 0.45954, 0....</td>\n",
       "      <td>[0.3291, 0.3796, 0.4089, 0.4275, 0.4431, 0.455...</td>\n",
       "      <td>[[89.07485961914062], [88.2500228881836], [87....</td>\n",
       "      <td>0.61682</td>\n",
       "      <td>0.5013</td>\n",
       "      <td>512</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>15</th>\n",
       "      <td>2</td>\n",
       "      <td>50</td>\n",
       "      <td>CorInfoMax</td>\n",
       "      <td>{'lr_start': {'ff': [0.08, 0.04], 'fb': [nan, ...</td>\n",
       "      <td>[0.32892, 0.39038, 0.42296, 0.44928, 0.46794, ...</td>\n",
       "      <td>[0.3272, 0.3816, 0.4139, 0.4275, 0.4411, 0.453...</td>\n",
       "      <td>[[89.21969604492188], [88.49434661865234], [87...</td>\n",
       "      <td>0.62106</td>\n",
       "      <td>0.4981</td>\n",
       "      <td>512</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>16</th>\n",
       "      <td>2</td>\n",
       "      <td>60</td>\n",
       "      <td>CorInfoMax</td>\n",
       "      <td>{'lr_start': {'ff': [0.08, 0.04], 'fb': [nan, ...</td>\n",
       "      <td>[0.34012, 0.39586, 0.42742, 0.45234, 0.4694, 0...</td>\n",
       "      <td>[0.3435, 0.3978, 0.4203, 0.4397, 0.4476, 0.455...</td>\n",
       "      <td>[[90.03176879882812], [89.20450592041016], [88...</td>\n",
       "      <td>0.62006</td>\n",
       "      <td>0.5017</td>\n",
       "      <td>512</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>17</th>\n",
       "      <td>2</td>\n",
       "      <td>70</td>\n",
       "      <td>CorInfoMax</td>\n",
       "      <td>{'lr_start': {'ff': [0.08, 0.04], 'fb': [nan, ...</td>\n",
       "      <td>[0.31864, 0.38314, 0.41818, 0.44316, 0.4622, 0...</td>\n",
       "      <td>[0.315, 0.3797, 0.4054, 0.4272, 0.4421, 0.4488...</td>\n",
       "      <td>[[90.89095306396484], [89.91641998291016], [88...</td>\n",
       "      <td>0.61602</td>\n",
       "      <td>0.5095</td>\n",
       "      <td>512</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>18</th>\n",
       "      <td>2</td>\n",
       "      <td>80</td>\n",
       "      <td>CorInfoMax</td>\n",
       "      <td>{'lr_start': {'ff': [0.08, 0.04], 'fb': [nan, ...</td>\n",
       "      <td>[0.2743, 0.33772, 0.39076, 0.42308, 0.44982, 0...</td>\n",
       "      <td>[0.2702, 0.3394, 0.3843, 0.4062, 0.43, 0.4449,...</td>\n",
       "      <td>[[91.49378204345703], [90.82295227050781], [90...</td>\n",
       "      <td>0.61492</td>\n",
       "      <td>0.4998</td>\n",
       "      <td>512</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>19</th>\n",
       "      <td>2</td>\n",
       "      <td>90</td>\n",
       "      <td>CorInfoMax</td>\n",
       "      <td>{'lr_start': {'ff': [0.08, 0.04], 'fb': [nan, ...</td>\n",
       "      <td>[0.33272, 0.38792, 0.4242, 0.44582, 0.46166, 0...</td>\n",
       "      <td>[0.3312, 0.3852, 0.4181, 0.4327, 0.4451, 0.455...</td>\n",
       "      <td>[[90.84580993652344], [90.0239028930664], [89....</td>\n",
       "      <td>0.61896</td>\n",
       "      <td>0.5009</td>\n",
       "      <td>512</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>20</th>\n",
       "      <td>3</td>\n",
       "      <td>0</td>\n",
       "      <td>CorInfoMax</td>\n",
       "      <td>{'lr_start': {'ff': [0.08, 0.04], 'fb': [nan, ...</td>\n",
       "      <td>[0.3456, 0.40536, 0.44106, 0.46498, 0.48564, 0...</td>\n",
       "      <td>[0.3478, 0.3998, 0.426, 0.4437, 0.4566, 0.466,...</td>\n",
       "      <td>[[90.40131378173828], [89.73482513427734], [88...</td>\n",
       "      <td>0.65268</td>\n",
       "      <td>0.5163</td>\n",
       "      <td>1024</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>21</th>\n",
       "      <td>3</td>\n",
       "      <td>10</td>\n",
       "      <td>CorInfoMax</td>\n",
       "      <td>{'lr_start': {'ff': [0.08, 0.04], 'fb': [nan, ...</td>\n",
       "      <td>[0.35144, 0.40898, 0.4435, 0.46472, 0.48524, 0...</td>\n",
       "      <td>[0.349, 0.4029, 0.4335, 0.4508, 0.4587, 0.4698...</td>\n",
       "      <td>[[89.93730163574219], [89.17709350585938], [88...</td>\n",
       "      <td>0.65106</td>\n",
       "      <td>0.5192</td>\n",
       "      <td>1024</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>22</th>\n",
       "      <td>3</td>\n",
       "      <td>20</td>\n",
       "      <td>CorInfoMax</td>\n",
       "      <td>{'lr_start': {'ff': [0.08, 0.04], 'fb': [nan, ...</td>\n",
       "      <td>[0.35048, 0.41064, 0.44528, 0.4731, 0.49088, 0...</td>\n",
       "      <td>[0.3519, 0.4036, 0.4312, 0.449, 0.46, 0.4698, ...</td>\n",
       "      <td>[[89.20269012451172], [88.5719985961914], [87....</td>\n",
       "      <td>0.65094</td>\n",
       "      <td>0.5190</td>\n",
       "      <td>1024</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>23</th>\n",
       "      <td>3</td>\n",
       "      <td>30</td>\n",
       "      <td>CorInfoMax</td>\n",
       "      <td>{'lr_start': {'ff': [0.08, 0.04], 'fb': [nan, ...</td>\n",
       "      <td>[0.35204, 0.4209, 0.4492, 0.4702, 0.49394, 0.5...</td>\n",
       "      <td>[0.3452, 0.4076, 0.4344, 0.4507, 0.463, 0.4746...</td>\n",
       "      <td>[[88.78587341308594], [88.04069519042969], [87...</td>\n",
       "      <td>0.64556</td>\n",
       "      <td>0.5203</td>\n",
       "      <td>1024</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>24</th>\n",
       "      <td>3</td>\n",
       "      <td>40</td>\n",
       "      <td>CorInfoMax</td>\n",
       "      <td>{'lr_start': {'ff': [0.08, 0.04], 'fb': [nan, ...</td>\n",
       "      <td>[0.36244, 0.41818, 0.45122, 0.47152, 0.49294, ...</td>\n",
       "      <td>[0.3657, 0.4056, 0.433, 0.4451, 0.458, 0.4682,...</td>\n",
       "      <td>[[90.7271957397461], [89.76773071289062], [88....</td>\n",
       "      <td>0.65122</td>\n",
       "      <td>0.5200</td>\n",
       "      <td>1024</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>25</th>\n",
       "      <td>3</td>\n",
       "      <td>50</td>\n",
       "      <td>CorInfoMax</td>\n",
       "      <td>{'lr_start': {'ff': [0.08, 0.04], 'fb': [nan, ...</td>\n",
       "      <td>[0.35784, 0.41364, 0.44548, 0.46842, 0.48804, ...</td>\n",
       "      <td>[0.361, 0.4057, 0.4325, 0.4486, 0.457, 0.4668,...</td>\n",
       "      <td>[[90.33991241455078], [89.550048828125], [88.6...</td>\n",
       "      <td>0.64872</td>\n",
       "      <td>0.5198</td>\n",
       "      <td>1024</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>26</th>\n",
       "      <td>3</td>\n",
       "      <td>60</td>\n",
       "      <td>CorInfoMax</td>\n",
       "      <td>{'lr_start': {'ff': [0.08, 0.04], 'fb': [nan, ...</td>\n",
       "      <td>[0.35242, 0.4127, 0.44822, 0.4695, 0.49044, 0....</td>\n",
       "      <td>[0.3489, 0.3982, 0.4292, 0.4508, 0.4635, 0.474...</td>\n",
       "      <td>[[90.36842346191406], [89.26473236083984], [87...</td>\n",
       "      <td>0.65042</td>\n",
       "      <td>0.5143</td>\n",
       "      <td>1024</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>27</th>\n",
       "      <td>3</td>\n",
       "      <td>70</td>\n",
       "      <td>CorInfoMax</td>\n",
       "      <td>{'lr_start': {'ff': [0.08, 0.04], 'fb': [nan, ...</td>\n",
       "      <td>[0.3512, 0.41206, 0.44786, 0.46918, 0.49012, 0...</td>\n",
       "      <td>[0.3551, 0.4099, 0.4409, 0.4505, 0.4679, 0.473...</td>\n",
       "      <td>[[89.11036682128906], [88.51934814453125], [87...</td>\n",
       "      <td>0.65160</td>\n",
       "      <td>0.5216</td>\n",
       "      <td>1024</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>28</th>\n",
       "      <td>3</td>\n",
       "      <td>80</td>\n",
       "      <td>CorInfoMax</td>\n",
       "      <td>{'lr_start': {'ff': [0.08, 0.04], 'fb': [nan, ...</td>\n",
       "      <td>[0.35246, 0.4134, 0.4452, 0.47106, 0.49046, 0....</td>\n",
       "      <td>[0.3596, 0.4073, 0.4367, 0.4492, 0.4647, 0.471...</td>\n",
       "      <td>[[90.0738754272461], [89.43641662597656], [88....</td>\n",
       "      <td>0.64834</td>\n",
       "      <td>0.5193</td>\n",
       "      <td>1024</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>29</th>\n",
       "      <td>3</td>\n",
       "      <td>90</td>\n",
       "      <td>CorInfoMax</td>\n",
       "      <td>{'lr_start': {'ff': [0.08, 0.04], 'fb': [nan, ...</td>\n",
       "      <td>[0.35454, 0.41416, 0.44186, 0.46896, 0.48726, ...</td>\n",
       "      <td>[0.3488, 0.4068, 0.4293, 0.4489, 0.4633, 0.47,...</td>\n",
       "      <td>[[90.54320526123047], [89.8916244506836], [88....</td>\n",
       "      <td>0.65062</td>\n",
       "      <td>0.5197</td>\n",
       "      <td>1024</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>30</th>\n",
       "      <td>4</td>\n",
       "      <td>0</td>\n",
       "      <td>CorInfoMax</td>\n",
       "      <td>{'lr_start': {'ff': [0.08, 0.04], 'fb': [nan, ...</td>\n",
       "      <td>[0.37978, 0.43686, 0.46406, 0.48776, 0.51076, ...</td>\n",
       "      <td>[0.3782, 0.4258, 0.4477, 0.4622, 0.4774, 0.487...</td>\n",
       "      <td>[[90.22587585449219], [89.52689361572266], [88...</td>\n",
       "      <td>0.67238</td>\n",
       "      <td>0.5308</td>\n",
       "      <td>2048</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>31</th>\n",
       "      <td>4</td>\n",
       "      <td>10</td>\n",
       "      <td>CorInfoMax</td>\n",
       "      <td>{'lr_start': {'ff': [0.08, 0.04], 'fb': [nan, ...</td>\n",
       "      <td>[0.37874, 0.43044, 0.4653, 0.48942, 0.5049, 0....</td>\n",
       "      <td>[0.3743, 0.4211, 0.4525, 0.4609, 0.4715, 0.482...</td>\n",
       "      <td>[[90.28846740722656], [89.57877349853516], [88...</td>\n",
       "      <td>0.67266</td>\n",
       "      <td>0.5308</td>\n",
       "      <td>2048</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>32</th>\n",
       "      <td>4</td>\n",
       "      <td>20</td>\n",
       "      <td>CorInfoMax</td>\n",
       "      <td>{'lr_start': {'ff': [0.08, 0.04], 'fb': [nan, ...</td>\n",
       "      <td>[0.37806, 0.43278, 0.46106, 0.48768, 0.509, 0....</td>\n",
       "      <td>[0.3751, 0.4167, 0.4476, 0.4627, 0.4756, 0.484...</td>\n",
       "      <td>[[89.6426773071289], [89.13346099853516], [88....</td>\n",
       "      <td>0.67702</td>\n",
       "      <td>0.5340</td>\n",
       "      <td>2048</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>33</th>\n",
       "      <td>4</td>\n",
       "      <td>30</td>\n",
       "      <td>CorInfoMax</td>\n",
       "      <td>{'lr_start': {'ff': [0.08, 0.04], 'fb': [nan, ...</td>\n",
       "      <td>[0.36928, 0.4304, 0.45744, 0.481, 0.4999, 0.52...</td>\n",
       "      <td>[0.369, 0.4269, 0.4424, 0.4636, 0.4663, 0.4839...</td>\n",
       "      <td>[[89.55570220947266], [88.91480255126953], [87...</td>\n",
       "      <td>0.67570</td>\n",
       "      <td>0.5345</td>\n",
       "      <td>2048</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>34</th>\n",
       "      <td>4</td>\n",
       "      <td>40</td>\n",
       "      <td>CorInfoMax</td>\n",
       "      <td>{'lr_start': {'ff': [0.08, 0.04], 'fb': [nan, ...</td>\n",
       "      <td>[0.378, 0.4336, 0.46306, 0.48956, 0.50692, 0.5...</td>\n",
       "      <td>[0.3787, 0.4205, 0.4445, 0.4624, 0.4754, 0.486...</td>\n",
       "      <td>[[90.21064758300781], [89.49546813964844], [88...</td>\n",
       "      <td>0.67576</td>\n",
       "      <td>0.5261</td>\n",
       "      <td>2048</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>35</th>\n",
       "      <td>4</td>\n",
       "      <td>50</td>\n",
       "      <td>CorInfoMax</td>\n",
       "      <td>{'lr_start': {'ff': [0.08, 0.04], 'fb': [nan, ...</td>\n",
       "      <td>[0.38522, 0.44088, 0.4661, 0.48796, 0.51096, 0...</td>\n",
       "      <td>[0.3861, 0.434, 0.4473, 0.4647, 0.4802, 0.4883...</td>\n",
       "      <td>[[90.0096435546875], [89.27079772949219], [88....</td>\n",
       "      <td>0.67346</td>\n",
       "      <td>0.5308</td>\n",
       "      <td>2048</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>36</th>\n",
       "      <td>4</td>\n",
       "      <td>60</td>\n",
       "      <td>CorInfoMax</td>\n",
       "      <td>{'lr_start': {'ff': [0.08, 0.04], 'fb': [nan, ...</td>\n",
       "      <td>[0.38328, 0.43312, 0.46752, 0.4863, 0.50358, 0...</td>\n",
       "      <td>[0.3779, 0.4228, 0.4534, 0.4642, 0.4754, 0.486...</td>\n",
       "      <td>[[89.85319519042969], [89.20379638671875], [88...</td>\n",
       "      <td>0.67394</td>\n",
       "      <td>0.5215</td>\n",
       "      <td>2048</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   setting_number seed       Model  \\\n",
       "0               1    0  CorInfoMax   \n",
       "1               1   10  CorInfoMax   \n",
       "2               1   20  CorInfoMax   \n",
       "3               1   30  CorInfoMax   \n",
       "4               1   40  CorInfoMax   \n",
       "5               1   50  CorInfoMax   \n",
       "6               1   60  CorInfoMax   \n",
       "7               1   70  CorInfoMax   \n",
       "8               1   80  CorInfoMax   \n",
       "9               1   90  CorInfoMax   \n",
       "10              2    0  CorInfoMax   \n",
       "11              2   10  CorInfoMax   \n",
       "12              2   20  CorInfoMax   \n",
       "13              2   30  CorInfoMax   \n",
       "14              2   40  CorInfoMax   \n",
       "15              2   50  CorInfoMax   \n",
       "16              2   60  CorInfoMax   \n",
       "17              2   70  CorInfoMax   \n",
       "18              2   80  CorInfoMax   \n",
       "19              2   90  CorInfoMax   \n",
       "20              3    0  CorInfoMax   \n",
       "21              3   10  CorInfoMax   \n",
       "22              3   20  CorInfoMax   \n",
       "23              3   30  CorInfoMax   \n",
       "24              3   40  CorInfoMax   \n",
       "25              3   50  CorInfoMax   \n",
       "26              3   60  CorInfoMax   \n",
       "27              3   70  CorInfoMax   \n",
       "28              3   80  CorInfoMax   \n",
       "29              3   90  CorInfoMax   \n",
       "30              4    0  CorInfoMax   \n",
       "31              4   10  CorInfoMax   \n",
       "32              4   20  CorInfoMax   \n",
       "33              4   30  CorInfoMax   \n",
       "34              4   40  CorInfoMax   \n",
       "35              4   50  CorInfoMax   \n",
       "36              4   60  CorInfoMax   \n",
       "\n",
       "                                          Hyperparams  \\\n",
       "0   {'lr_start': {'ff': [0.08, 0.04], 'fb': [nan, ...   \n",
       "1   {'lr_start': {'ff': [0.08, 0.04], 'fb': [nan, ...   \n",
       "2   {'lr_start': {'ff': [0.08, 0.04], 'fb': [nan, ...   \n",
       "3   {'lr_start': {'ff': [0.08, 0.04], 'fb': [nan, ...   \n",
       "4   {'lr_start': {'ff': [0.08, 0.04], 'fb': [nan, ...   \n",
       "5   {'lr_start': {'ff': [0.08, 0.04], 'fb': [nan, ...   \n",
       "6   {'lr_start': {'ff': [0.08, 0.04], 'fb': [nan, ...   \n",
       "7   {'lr_start': {'ff': [0.08, 0.04], 'fb': [nan, ...   \n",
       "8   {'lr_start': {'ff': [0.08, 0.04], 'fb': [nan, ...   \n",
       "9   {'lr_start': {'ff': [0.08, 0.04], 'fb': [nan, ...   \n",
       "10  {'lr_start': {'ff': [0.08, 0.04], 'fb': [nan, ...   \n",
       "11  {'lr_start': {'ff': [0.08, 0.04], 'fb': [nan, ...   \n",
       "12  {'lr_start': {'ff': [0.08, 0.04], 'fb': [nan, ...   \n",
       "13  {'lr_start': {'ff': [0.08, 0.04], 'fb': [nan, ...   \n",
       "14  {'lr_start': {'ff': [0.08, 0.04], 'fb': [nan, ...   \n",
       "15  {'lr_start': {'ff': [0.08, 0.04], 'fb': [nan, ...   \n",
       "16  {'lr_start': {'ff': [0.08, 0.04], 'fb': [nan, ...   \n",
       "17  {'lr_start': {'ff': [0.08, 0.04], 'fb': [nan, ...   \n",
       "18  {'lr_start': {'ff': [0.08, 0.04], 'fb': [nan, ...   \n",
       "19  {'lr_start': {'ff': [0.08, 0.04], 'fb': [nan, ...   \n",
       "20  {'lr_start': {'ff': [0.08, 0.04], 'fb': [nan, ...   \n",
       "21  {'lr_start': {'ff': [0.08, 0.04], 'fb': [nan, ...   \n",
       "22  {'lr_start': {'ff': [0.08, 0.04], 'fb': [nan, ...   \n",
       "23  {'lr_start': {'ff': [0.08, 0.04], 'fb': [nan, ...   \n",
       "24  {'lr_start': {'ff': [0.08, 0.04], 'fb': [nan, ...   \n",
       "25  {'lr_start': {'ff': [0.08, 0.04], 'fb': [nan, ...   \n",
       "26  {'lr_start': {'ff': [0.08, 0.04], 'fb': [nan, ...   \n",
       "27  {'lr_start': {'ff': [0.08, 0.04], 'fb': [nan, ...   \n",
       "28  {'lr_start': {'ff': [0.08, 0.04], 'fb': [nan, ...   \n",
       "29  {'lr_start': {'ff': [0.08, 0.04], 'fb': [nan, ...   \n",
       "30  {'lr_start': {'ff': [0.08, 0.04], 'fb': [nan, ...   \n",
       "31  {'lr_start': {'ff': [0.08, 0.04], 'fb': [nan, ...   \n",
       "32  {'lr_start': {'ff': [0.08, 0.04], 'fb': [nan, ...   \n",
       "33  {'lr_start': {'ff': [0.08, 0.04], 'fb': [nan, ...   \n",
       "34  {'lr_start': {'ff': [0.08, 0.04], 'fb': [nan, ...   \n",
       "35  {'lr_start': {'ff': [0.08, 0.04], 'fb': [nan, ...   \n",
       "36  {'lr_start': {'ff': [0.08, 0.04], 'fb': [nan, ...   \n",
       "\n",
       "                                         Trn_ACC_list  \\\n",
       "0   [0.30378, 0.3602, 0.4009, 0.42444, 0.43846, 0....   \n",
       "1   [0.2945, 0.36132, 0.39258, 0.41726, 0.43206, 0...   \n",
       "2   [0.2737, 0.33824, 0.38382, 0.41152, 0.43064, 0...   \n",
       "3   [0.32736, 0.37874, 0.40892, 0.43212, 0.44806, ...   \n",
       "4   [0.2848, 0.3606, 0.39442, 0.42408, 0.44216, 0....   \n",
       "5   [0.30248, 0.3607, 0.39364, 0.41522, 0.43698, 0...   \n",
       "6   [0.30636, 0.35824, 0.3885, 0.408, 0.42856, 0.4...   \n",
       "7   [0.28724, 0.34404, 0.38786, 0.41426, 0.43632, ...   \n",
       "8   [0.2998, 0.37126, 0.40296, 0.4225, 0.44524, 0....   \n",
       "9   [0.29176, 0.37032, 0.40556, 0.42488, 0.44386, ...   \n",
       "10  [0.29002, 0.37746, 0.41332, 0.44086, 0.4619, 0...   \n",
       "11  [0.33788, 0.38504, 0.42292, 0.45098, 0.46914, ...   \n",
       "12  [0.3364, 0.40234, 0.43416, 0.45748, 0.47716, 0...   \n",
       "13  [0.33342, 0.38922, 0.42622, 0.4472, 0.4677, 0....   \n",
       "14  [0.3289, 0.37998, 0.41484, 0.4423, 0.45954, 0....   \n",
       "15  [0.32892, 0.39038, 0.42296, 0.44928, 0.46794, ...   \n",
       "16  [0.34012, 0.39586, 0.42742, 0.45234, 0.4694, 0...   \n",
       "17  [0.31864, 0.38314, 0.41818, 0.44316, 0.4622, 0...   \n",
       "18  [0.2743, 0.33772, 0.39076, 0.42308, 0.44982, 0...   \n",
       "19  [0.33272, 0.38792, 0.4242, 0.44582, 0.46166, 0...   \n",
       "20  [0.3456, 0.40536, 0.44106, 0.46498, 0.48564, 0...   \n",
       "21  [0.35144, 0.40898, 0.4435, 0.46472, 0.48524, 0...   \n",
       "22  [0.35048, 0.41064, 0.44528, 0.4731, 0.49088, 0...   \n",
       "23  [0.35204, 0.4209, 0.4492, 0.4702, 0.49394, 0.5...   \n",
       "24  [0.36244, 0.41818, 0.45122, 0.47152, 0.49294, ...   \n",
       "25  [0.35784, 0.41364, 0.44548, 0.46842, 0.48804, ...   \n",
       "26  [0.35242, 0.4127, 0.44822, 0.4695, 0.49044, 0....   \n",
       "27  [0.3512, 0.41206, 0.44786, 0.46918, 0.49012, 0...   \n",
       "28  [0.35246, 0.4134, 0.4452, 0.47106, 0.49046, 0....   \n",
       "29  [0.35454, 0.41416, 0.44186, 0.46896, 0.48726, ...   \n",
       "30  [0.37978, 0.43686, 0.46406, 0.48776, 0.51076, ...   \n",
       "31  [0.37874, 0.43044, 0.4653, 0.48942, 0.5049, 0....   \n",
       "32  [0.37806, 0.43278, 0.46106, 0.48768, 0.509, 0....   \n",
       "33  [0.36928, 0.4304, 0.45744, 0.481, 0.4999, 0.52...   \n",
       "34  [0.378, 0.4336, 0.46306, 0.48956, 0.50692, 0.5...   \n",
       "35  [0.38522, 0.44088, 0.4661, 0.48796, 0.51096, 0...   \n",
       "36  [0.38328, 0.43312, 0.46752, 0.4863, 0.50358, 0...   \n",
       "\n",
       "                                         Tst_ACC_list  \\\n",
       "0   [0.3055, 0.3531, 0.3921, 0.4091, 0.4186, 0.431...   \n",
       "1   [0.2932, 0.3612, 0.3854, 0.408, 0.4159, 0.4289...   \n",
       "2   [0.2751, 0.3304, 0.3752, 0.3999, 0.4158, 0.427...   \n",
       "3   [0.3217, 0.3749, 0.4, 0.4155, 0.4245, 0.4405, ...   \n",
       "4   [0.2786, 0.3563, 0.3856, 0.4127, 0.4257, 0.434...   \n",
       "5   [0.3046, 0.3658, 0.3841, 0.4049, 0.415, 0.4234...   \n",
       "6   [0.303, 0.3493, 0.38, 0.3945, 0.414, 0.4264, 0...   \n",
       "7   [0.2966, 0.3421, 0.3853, 0.4048, 0.4193, 0.43,...   \n",
       "8   [0.2942, 0.3674, 0.3936, 0.4092, 0.4219, 0.433...   \n",
       "9   [0.288, 0.3682, 0.3972, 0.4125, 0.4245, 0.4335...   \n",
       "10  [0.2912, 0.3788, 0.4018, 0.4235, 0.4467, 0.457...   \n",
       "11  [0.344, 0.3771, 0.4102, 0.4339, 0.4491, 0.4569...   \n",
       "12  [0.3318, 0.3898, 0.416, 0.4337, 0.4515, 0.4557...   \n",
       "13  [0.3271, 0.3886, 0.4201, 0.4307, 0.4486, 0.458...   \n",
       "14  [0.3291, 0.3796, 0.4089, 0.4275, 0.4431, 0.455...   \n",
       "15  [0.3272, 0.3816, 0.4139, 0.4275, 0.4411, 0.453...   \n",
       "16  [0.3435, 0.3978, 0.4203, 0.4397, 0.4476, 0.455...   \n",
       "17  [0.315, 0.3797, 0.4054, 0.4272, 0.4421, 0.4488...   \n",
       "18  [0.2702, 0.3394, 0.3843, 0.4062, 0.43, 0.4449,...   \n",
       "19  [0.3312, 0.3852, 0.4181, 0.4327, 0.4451, 0.455...   \n",
       "20  [0.3478, 0.3998, 0.426, 0.4437, 0.4566, 0.466,...   \n",
       "21  [0.349, 0.4029, 0.4335, 0.4508, 0.4587, 0.4698...   \n",
       "22  [0.3519, 0.4036, 0.4312, 0.449, 0.46, 0.4698, ...   \n",
       "23  [0.3452, 0.4076, 0.4344, 0.4507, 0.463, 0.4746...   \n",
       "24  [0.3657, 0.4056, 0.433, 0.4451, 0.458, 0.4682,...   \n",
       "25  [0.361, 0.4057, 0.4325, 0.4486, 0.457, 0.4668,...   \n",
       "26  [0.3489, 0.3982, 0.4292, 0.4508, 0.4635, 0.474...   \n",
       "27  [0.3551, 0.4099, 0.4409, 0.4505, 0.4679, 0.473...   \n",
       "28  [0.3596, 0.4073, 0.4367, 0.4492, 0.4647, 0.471...   \n",
       "29  [0.3488, 0.4068, 0.4293, 0.4489, 0.4633, 0.47,...   \n",
       "30  [0.3782, 0.4258, 0.4477, 0.4622, 0.4774, 0.487...   \n",
       "31  [0.3743, 0.4211, 0.4525, 0.4609, 0.4715, 0.482...   \n",
       "32  [0.3751, 0.4167, 0.4476, 0.4627, 0.4756, 0.484...   \n",
       "33  [0.369, 0.4269, 0.4424, 0.4636, 0.4663, 0.4839...   \n",
       "34  [0.3787, 0.4205, 0.4445, 0.4624, 0.4754, 0.486...   \n",
       "35  [0.3861, 0.434, 0.4473, 0.4647, 0.4802, 0.4883...   \n",
       "36  [0.3779, 0.4228, 0.4534, 0.4642, 0.4754, 0.486...   \n",
       "\n",
       "                   forward_backward_weight_angle_list  Trn_ACC  Tst_ACC  \\\n",
       "0   [[90.66891479492188], [89.83344268798828], [88...  0.57240   0.4851   \n",
       "1   [[91.4767837524414], [90.19556427001953], [89....  0.57090   0.4804   \n",
       "2   [[90.09364318847656], [89.2579345703125], [88....  0.57776   0.4851   \n",
       "3   [[88.9329605102539], [87.81392669677734], [86....  0.58536   0.4893   \n",
       "4   [[89.21763610839844], [88.2616195678711], [87....  0.57676   0.4846   \n",
       "5   [[89.66824340820312], [88.52716064453125], [87...  0.57188   0.4836   \n",
       "6   [[90.69612121582031], [89.47639465332031], [88...  0.57640   0.4894   \n",
       "7   [[91.867919921875], [90.5802001953125], [89.42...  0.58434   0.4823   \n",
       "8   [[87.8289794921875], [87.10165405273438], [86....  0.58418   0.4852   \n",
       "9   [[90.79267883300781], [90.06278991699219], [89...  0.57946   0.4814   \n",
       "10  [[90.7146224975586], [89.9310302734375], [89.0...  0.62030   0.5022   \n",
       "11  [[89.55988311767578], [88.61637115478516], [87...  0.61694   0.5007   \n",
       "12  [[88.46021270751953], [87.66310119628906], [86...  0.62080   0.5061   \n",
       "13  [[90.3956069946289], [89.62022399902344], [88....  0.61636   0.5065   \n",
       "14  [[89.07485961914062], [88.2500228881836], [87....  0.61682   0.5013   \n",
       "15  [[89.21969604492188], [88.49434661865234], [87...  0.62106   0.4981   \n",
       "16  [[90.03176879882812], [89.20450592041016], [88...  0.62006   0.5017   \n",
       "17  [[90.89095306396484], [89.91641998291016], [88...  0.61602   0.5095   \n",
       "18  [[91.49378204345703], [90.82295227050781], [90...  0.61492   0.4998   \n",
       "19  [[90.84580993652344], [90.0239028930664], [89....  0.61896   0.5009   \n",
       "20  [[90.40131378173828], [89.73482513427734], [88...  0.65268   0.5163   \n",
       "21  [[89.93730163574219], [89.17709350585938], [88...  0.65106   0.5192   \n",
       "22  [[89.20269012451172], [88.5719985961914], [87....  0.65094   0.5190   \n",
       "23  [[88.78587341308594], [88.04069519042969], [87...  0.64556   0.5203   \n",
       "24  [[90.7271957397461], [89.76773071289062], [88....  0.65122   0.5200   \n",
       "25  [[90.33991241455078], [89.550048828125], [88.6...  0.64872   0.5198   \n",
       "26  [[90.36842346191406], [89.26473236083984], [87...  0.65042   0.5143   \n",
       "27  [[89.11036682128906], [88.51934814453125], [87...  0.65160   0.5216   \n",
       "28  [[90.0738754272461], [89.43641662597656], [88....  0.64834   0.5193   \n",
       "29  [[90.54320526123047], [89.8916244506836], [88....  0.65062   0.5197   \n",
       "30  [[90.22587585449219], [89.52689361572266], [88...  0.67238   0.5308   \n",
       "31  [[90.28846740722656], [89.57877349853516], [88...  0.67266   0.5308   \n",
       "32  [[89.6426773071289], [89.13346099853516], [88....  0.67702   0.5340   \n",
       "33  [[89.55570220947266], [88.91480255126953], [87...  0.67570   0.5345   \n",
       "34  [[90.21064758300781], [89.49546813964844], [88...  0.67576   0.5261   \n",
       "35  [[90.0096435546875], [89.27079772949219], [88....  0.67346   0.5308   \n",
       "36  [[89.85319519042969], [89.20379638671875], [88...  0.67394   0.5215   \n",
       "\n",
       "    hidden_layer_size  \n",
       "0                 256  \n",
       "1                 256  \n",
       "2                 256  \n",
       "3                 256  \n",
       "4                 256  \n",
       "5                 256  \n",
       "6                 256  \n",
       "7                 256  \n",
       "8                 256  \n",
       "9                 256  \n",
       "10                512  \n",
       "11                512  \n",
       "12                512  \n",
       "13                512  \n",
       "14                512  \n",
       "15                512  \n",
       "16                512  \n",
       "17                512  \n",
       "18                512  \n",
       "19                512  \n",
       "20               1024  \n",
       "21               1024  \n",
       "22               1024  \n",
       "23               1024  \n",
       "24               1024  \n",
       "25               1024  \n",
       "26               1024  \n",
       "27               1024  \n",
       "28               1024  \n",
       "29               1024  \n",
       "30               2048  \n",
       "31               2048  \n",
       "32               2048  \n",
       "33               2048  \n",
       "34               2048  \n",
       "35               2048  \n",
       "36               2048  "
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df_results['Trn_ACC'] = df_results.apply(lambda row: row['Trn_ACC_list'][-1], axis = 1)\n",
    "df_results['Tst_ACC'] = df_results.apply(lambda row: row['Tst_ACC_list'][-1], axis = 1)\n",
    "df_results[\"hidden_layer_size\"] = df_results.apply(lambda row: row[\"Hyperparams\"][\"hidden_layer_size\"], axis = 1).values\n",
    "df_results"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "eb71247c",
   "metadata": {},
   "source": [
    "# CIFAR10 4 Layers Experiment"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "0a3eb293",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(11, 7)\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "Index(['setting_number', 'seed', 'Model', 'Hyperparams', 'Trn_ACC_list',\n",
       "       'Tst_ACC_list', 'forward_backward_weight_angle_list'],\n",
       "      dtype='object')"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df_results = pd.read_pickle(r\"../Results/simulation_results_CorInfoMax_CIFAR10_4Layers_V1.pkl\")\n",
    "print(df_results.shape)\n",
    "df_results.columns"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "becd6bdb",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>setting_number</th>\n",
       "      <th>seed</th>\n",
       "      <th>Model</th>\n",
       "      <th>Hyperparams</th>\n",
       "      <th>Trn_ACC_list</th>\n",
       "      <th>Tst_ACC_list</th>\n",
       "      <th>forward_backward_weight_angle_list</th>\n",
       "      <th>Trn_ACC</th>\n",
       "      <th>Tst_ACC</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>CorInfoMax</td>\n",
       "      <td>{'lr_start': {'ff': [0.25, 0.14, 0.075, 0.045]...</td>\n",
       "      <td>[0.29148, 0.3688, 0.41754, 0.4493, 0.47024, 0....</td>\n",
       "      <td>[0.2916, 0.3657, 0.4159, 0.4448, 0.4542, 0.464...</td>\n",
       "      <td>[[89.92984771728516, 89.90174102783203, 90.729...</td>\n",
       "      <td>0.10000</td>\n",
       "      <td>0.1000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>1</td>\n",
       "      <td>10</td>\n",
       "      <td>CorInfoMax</td>\n",
       "      <td>{'lr_start': {'ff': [0.25, 0.14, 0.075, 0.045]...</td>\n",
       "      <td>[0.26546, 0.35018, 0.39964, 0.43704, 0.45686, ...</td>\n",
       "      <td>[0.2652, 0.3514, 0.4026, 0.4359, 0.447, 0.3046...</td>\n",
       "      <td>[[89.94451904296875, 90.0085220336914, 91.0055...</td>\n",
       "      <td>0.46806</td>\n",
       "      <td>0.4463</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>2</td>\n",
       "      <td>0</td>\n",
       "      <td>CorInfoMax</td>\n",
       "      <td>{'lr_start': {'ff': [0.25, 0.14, 0.075, 0.045]...</td>\n",
       "      <td>[0.28902, 0.36752, 0.4172, 0.44906, 0.46802, 0...</td>\n",
       "      <td>[0.2895, 0.3663, 0.4141, 0.4426, 0.4558, 0.462...</td>\n",
       "      <td>[[89.92984771728516, 89.90174102783203, 90.730...</td>\n",
       "      <td>0.44152</td>\n",
       "      <td>0.4305</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>2</td>\n",
       "      <td>10</td>\n",
       "      <td>CorInfoMax</td>\n",
       "      <td>{'lr_start': {'ff': [0.25, 0.14, 0.075, 0.045]...</td>\n",
       "      <td>[0.25866, 0.36692, 0.4133, 0.44262, 0.46476, 0...</td>\n",
       "      <td>[0.2621, 0.3711, 0.4101, 0.4347, 0.4518, 0.469...</td>\n",
       "      <td>[[89.94451904296875, 90.0085220336914, 91.0056...</td>\n",
       "      <td>0.47554</td>\n",
       "      <td>0.4586</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>3</td>\n",
       "      <td>0</td>\n",
       "      <td>CorInfoMax</td>\n",
       "      <td>{'lr_start': {'ff': [0.25, 0.14, 0.075, 0.045]...</td>\n",
       "      <td>[0.27942, 0.35734, 0.40456, 0.43656, 0.45484, ...</td>\n",
       "      <td>[0.2791, 0.3562, 0.401, 0.4305, 0.4448, 0.4548...</td>\n",
       "      <td>[[89.92984771728516, 89.90174102783203, 90.729...</td>\n",
       "      <td>0.52732</td>\n",
       "      <td>0.4899</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>3</td>\n",
       "      <td>10</td>\n",
       "      <td>CorInfoMax</td>\n",
       "      <td>{'lr_start': {'ff': [0.25, 0.14, 0.075, 0.045]...</td>\n",
       "      <td>[0.25752, 0.34186, 0.39188, 0.42218, 0.4455, 0...</td>\n",
       "      <td>[0.2596, 0.3414, 0.3871, 0.4225, 0.437, 0.455,...</td>\n",
       "      <td>[[89.94451904296875, 90.00852966308594, 91.005...</td>\n",
       "      <td>0.45694</td>\n",
       "      <td>0.4447</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>4</td>\n",
       "      <td>0</td>\n",
       "      <td>CorInfoMax</td>\n",
       "      <td>{'lr_start': {'ff': [0.25, 0.14, 0.075, 0.045]...</td>\n",
       "      <td>[0.2764, 0.3565, 0.4017, 0.43422, 0.45158, 0.4...</td>\n",
       "      <td>[0.2786, 0.3577, 0.4023, 0.4286, 0.444, 0.4543...</td>\n",
       "      <td>[[89.92984771728516, 89.90174102783203, 90.729...</td>\n",
       "      <td>0.52016</td>\n",
       "      <td>0.4874</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>4</td>\n",
       "      <td>10</td>\n",
       "      <td>CorInfoMax</td>\n",
       "      <td>{'lr_start': {'ff': [0.25, 0.14, 0.075, 0.045]...</td>\n",
       "      <td>[0.24864, 0.35578, 0.4025, 0.42846, 0.44986, 0...</td>\n",
       "      <td>[0.2502, 0.3547, 0.401, 0.4262, 0.4395, 0.4549...</td>\n",
       "      <td>[[89.94451904296875, 90.0085220336914, 91.0053...</td>\n",
       "      <td>0.53012</td>\n",
       "      <td>0.4980</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>5</td>\n",
       "      <td>0</td>\n",
       "      <td>CorInfoMax</td>\n",
       "      <td>{'lr_start': {'ff': [0.25, 0.14, 0.075, 0.045]...</td>\n",
       "      <td>[0.29148, 0.36722, 0.41864, 0.44966, 0.46876, ...</td>\n",
       "      <td>[0.2921, 0.3651, 0.418, 0.4431, 0.4539, 0.4635...</td>\n",
       "      <td>[[89.92984771728516, 89.90174102783203, 90.729...</td>\n",
       "      <td>0.30580</td>\n",
       "      <td>0.2997</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9</th>\n",
       "      <td>5</td>\n",
       "      <td>10</td>\n",
       "      <td>CorInfoMax</td>\n",
       "      <td>{'lr_start': {'ff': [0.25, 0.14, 0.075, 0.045]...</td>\n",
       "      <td>[0.26568, 0.3509, 0.4044, 0.43848, 0.4585, 0.4...</td>\n",
       "      <td>[0.2673, 0.3526, 0.4058, 0.4333, 0.4503, 0.466...</td>\n",
       "      <td>[[89.94451904296875, 90.0085220336914, 91.0055...</td>\n",
       "      <td>0.50048</td>\n",
       "      <td>0.4708</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>10</th>\n",
       "      <td>6</td>\n",
       "      <td>0</td>\n",
       "      <td>CorInfoMax</td>\n",
       "      <td>{'lr_start': {'ff': [0.25, 0.14, 0.075, 0.045]...</td>\n",
       "      <td>[0.28862, 0.36776, 0.41822, 0.44908, 0.4682, 0...</td>\n",
       "      <td>[0.2885, 0.3669, 0.4163, 0.4424, 0.4532, 0.461...</td>\n",
       "      <td>[[89.92984771728516, 89.90174102783203, 90.730...</td>\n",
       "      <td>0.52332</td>\n",
       "      <td>0.4861</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   setting_number seed       Model  \\\n",
       "0               1    0  CorInfoMax   \n",
       "1               1   10  CorInfoMax   \n",
       "2               2    0  CorInfoMax   \n",
       "3               2   10  CorInfoMax   \n",
       "4               3    0  CorInfoMax   \n",
       "5               3   10  CorInfoMax   \n",
       "6               4    0  CorInfoMax   \n",
       "7               4   10  CorInfoMax   \n",
       "8               5    0  CorInfoMax   \n",
       "9               5   10  CorInfoMax   \n",
       "10              6    0  CorInfoMax   \n",
       "\n",
       "                                          Hyperparams  \\\n",
       "0   {'lr_start': {'ff': [0.25, 0.14, 0.075, 0.045]...   \n",
       "1   {'lr_start': {'ff': [0.25, 0.14, 0.075, 0.045]...   \n",
       "2   {'lr_start': {'ff': [0.25, 0.14, 0.075, 0.045]...   \n",
       "3   {'lr_start': {'ff': [0.25, 0.14, 0.075, 0.045]...   \n",
       "4   {'lr_start': {'ff': [0.25, 0.14, 0.075, 0.045]...   \n",
       "5   {'lr_start': {'ff': [0.25, 0.14, 0.075, 0.045]...   \n",
       "6   {'lr_start': {'ff': [0.25, 0.14, 0.075, 0.045]...   \n",
       "7   {'lr_start': {'ff': [0.25, 0.14, 0.075, 0.045]...   \n",
       "8   {'lr_start': {'ff': [0.25, 0.14, 0.075, 0.045]...   \n",
       "9   {'lr_start': {'ff': [0.25, 0.14, 0.075, 0.045]...   \n",
       "10  {'lr_start': {'ff': [0.25, 0.14, 0.075, 0.045]...   \n",
       "\n",
       "                                         Trn_ACC_list  \\\n",
       "0   [0.29148, 0.3688, 0.41754, 0.4493, 0.47024, 0....   \n",
       "1   [0.26546, 0.35018, 0.39964, 0.43704, 0.45686, ...   \n",
       "2   [0.28902, 0.36752, 0.4172, 0.44906, 0.46802, 0...   \n",
       "3   [0.25866, 0.36692, 0.4133, 0.44262, 0.46476, 0...   \n",
       "4   [0.27942, 0.35734, 0.40456, 0.43656, 0.45484, ...   \n",
       "5   [0.25752, 0.34186, 0.39188, 0.42218, 0.4455, 0...   \n",
       "6   [0.2764, 0.3565, 0.4017, 0.43422, 0.45158, 0.4...   \n",
       "7   [0.24864, 0.35578, 0.4025, 0.42846, 0.44986, 0...   \n",
       "8   [0.29148, 0.36722, 0.41864, 0.44966, 0.46876, ...   \n",
       "9   [0.26568, 0.3509, 0.4044, 0.43848, 0.4585, 0.4...   \n",
       "10  [0.28862, 0.36776, 0.41822, 0.44908, 0.4682, 0...   \n",
       "\n",
       "                                         Tst_ACC_list  \\\n",
       "0   [0.2916, 0.3657, 0.4159, 0.4448, 0.4542, 0.464...   \n",
       "1   [0.2652, 0.3514, 0.4026, 0.4359, 0.447, 0.3046...   \n",
       "2   [0.2895, 0.3663, 0.4141, 0.4426, 0.4558, 0.462...   \n",
       "3   [0.2621, 0.3711, 0.4101, 0.4347, 0.4518, 0.469...   \n",
       "4   [0.2791, 0.3562, 0.401, 0.4305, 0.4448, 0.4548...   \n",
       "5   [0.2596, 0.3414, 0.3871, 0.4225, 0.437, 0.455,...   \n",
       "6   [0.2786, 0.3577, 0.4023, 0.4286, 0.444, 0.4543...   \n",
       "7   [0.2502, 0.3547, 0.401, 0.4262, 0.4395, 0.4549...   \n",
       "8   [0.2921, 0.3651, 0.418, 0.4431, 0.4539, 0.4635...   \n",
       "9   [0.2673, 0.3526, 0.4058, 0.4333, 0.4503, 0.466...   \n",
       "10  [0.2885, 0.3669, 0.4163, 0.4424, 0.4532, 0.461...   \n",
       "\n",
       "                   forward_backward_weight_angle_list  Trn_ACC  Tst_ACC  \n",
       "0   [[89.92984771728516, 89.90174102783203, 90.729...  0.10000   0.1000  \n",
       "1   [[89.94451904296875, 90.0085220336914, 91.0055...  0.46806   0.4463  \n",
       "2   [[89.92984771728516, 89.90174102783203, 90.730...  0.44152   0.4305  \n",
       "3   [[89.94451904296875, 90.0085220336914, 91.0056...  0.47554   0.4586  \n",
       "4   [[89.92984771728516, 89.90174102783203, 90.729...  0.52732   0.4899  \n",
       "5   [[89.94451904296875, 90.00852966308594, 91.005...  0.45694   0.4447  \n",
       "6   [[89.92984771728516, 89.90174102783203, 90.729...  0.52016   0.4874  \n",
       "7   [[89.94451904296875, 90.0085220336914, 91.0053...  0.53012   0.4980  \n",
       "8   [[89.92984771728516, 89.90174102783203, 90.729...  0.30580   0.2997  \n",
       "9   [[89.94451904296875, 90.0085220336914, 91.0055...  0.50048   0.4708  \n",
       "10  [[89.92984771728516, 89.90174102783203, 90.730...  0.52332   0.4861  "
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Final accuracy of each run = last entry of its accuracy list (the last recorded\n",
    "# value, not the best one), so a run that collapses late is reported as collapsed.\n",
    "df_results['Trn_ACC'] = df_results['Trn_ACC_list'].map(lambda accs: accs[-1])\n",
    "df_results['Tst_ACC'] = df_results['Tst_ACC_list'].map(lambda accs: accs[-1])\n",
    "df_results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "c9d2e1d3",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'lr_start': {'ff': array([0.25 , 0.14 , 0.075, 0.045]),\n",
       "  'fb': array([  nan, 0.095, 0.075, 0.04 ])},\n",
       " 'lr_decay_multiplier': 0.95,\n",
       " 'neural_dynamic_iterations_free': 80,\n",
       " 'neural_dynamic_iterations_nudged': 15,\n",
       " 'neural_lr_rule': 'constant',\n",
       " 'neural_lr': 0.03,\n",
       " 'epsilon': 0.15,\n",
       " 'lambda': 0.99999,\n",
       " 'beta': 0.5,\n",
       " 'architecture': [3072, 1000, 1000, 500, 10],\n",
       " 'three_phase': False}"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Hyperparameter dict of the last run listed in df_results\n",
    "df_results['Hyperparams'].iloc[-1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a81c070f",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "a4965df5",
   "metadata": {},
   "source": [
    "# CIFAR100 4 Layers Experiment"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "a70ecf37",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(3, 7)\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "Index(['setting_number', 'seed', 'Model', 'Hyperparams', 'Trn_ACC_list',\n",
       "       'Tst_ACC_list', 'forward_backward_weight_angle_list'],\n",
       "      dtype='object')"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Load the CIFAR-100 4-layer CorInfoMax results. This reassigns df_results, replacing\n",
    "# the previous section's results; re-run that section's cells to inspect them again.\n",
    "df_results = pd.read_pickle('../Results/simulation_results_CorInfoMax_CIFAR100_4Layers_V1.pkl')\n",
    "print(df_results.shape)\n",
    "df_results.columns"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "37e93b12",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>setting_number</th>\n",
       "      <th>seed</th>\n",
       "      <th>Model</th>\n",
       "      <th>Hyperparams</th>\n",
       "      <th>Trn_ACC_list</th>\n",
       "      <th>Tst_ACC_list</th>\n",
       "      <th>forward_backward_weight_angle_list</th>\n",
       "      <th>Trn_ACC_list_top1</th>\n",
       "      <th>Trn_ACC_list_top5</th>\n",
       "      <th>Tst_ACC_list_top1</th>\n",
       "      <th>Tst_ACC_list_top5</th>\n",
       "      <th>Trn_ACC_top1</th>\n",
       "      <th>Trn_ACC_top5</th>\n",
       "      <th>Tst_ACC_top1</th>\n",
       "      <th>Tst_ACC_top5</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>CorInfoMax</td>\n",
       "      <td>{'lr_start': {'ff': [0.21, 0.17, 0.12, 0.07], ...</td>\n",
       "      <td>[[0.02176, 0.07814], [0.0407, 0.13102], [0.060...</td>\n",
       "      <td>[[0.0209, 0.0804], [0.0399, 0.13], [0.0578, 0....</td>\n",
       "      <td>[[89.99908447265625, 89.99053955078125, 89.826...</td>\n",
       "      <td>[0.02176, 0.0407, 0.06074, 0.07514, 0.08744, 0...</td>\n",
       "      <td>[0.07814, 0.13102, 0.1734, 0.19794, 0.21902, 0...</td>\n",
       "      <td>[0.0209, 0.0399, 0.0578, 0.0705, 0.0844, 0.094...</td>\n",
       "      <td>[0.0804, 0.13, 0.1684, 0.1913, 0.2154, 0.2348,...</td>\n",
       "      <td>0.01738</td>\n",
       "      <td>0.05646</td>\n",
       "      <td>0.0174</td>\n",
       "      <td>0.0565</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>1</td>\n",
       "      <td>10</td>\n",
       "      <td>CorInfoMax</td>\n",
       "      <td>{'lr_start': {'ff': [0.21, 0.17, 0.12, 0.07], ...</td>\n",
       "      <td>[[0.02112, 0.08534], [0.04178, 0.12574], [0.06...</td>\n",
       "      <td>[[0.0192, 0.0815], [0.0419, 0.126], [0.0609, 0...</td>\n",
       "      <td>[[89.93399047851562, 90.04959106445312, 90.144...</td>\n",
       "      <td>[0.02112, 0.04178, 0.06238, 0.07584, 0.09112, ...</td>\n",
       "      <td>[0.08534, 0.12574, 0.18242, 0.2114, 0.23568, 0...</td>\n",
       "      <td>[0.0192, 0.0419, 0.0609, 0.075, 0.0868, 0.0957...</td>\n",
       "      <td>[0.0815, 0.126, 0.1768, 0.2084, 0.2351, 0.2457...</td>\n",
       "      <td>0.02314</td>\n",
       "      <td>0.06272</td>\n",
       "      <td>0.0223</td>\n",
       "      <td>0.0620</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>2</td>\n",
       "      <td>0</td>\n",
       "      <td>CorInfoMax</td>\n",
       "      <td>{'lr_start': {'ff': [0.21, 0.17, 0.12, 0.07], ...</td>\n",
       "      <td>[[0.01948, 0.07388], [0.03424, 0.11044], [0.04...</td>\n",
       "      <td>[[0.0205, 0.0764], [0.0354, 0.1085], [0.0475, ...</td>\n",
       "      <td>[[89.99908447265625, 89.99053955078125, 89.826...</td>\n",
       "      <td>[0.01948, 0.03424, 0.04836, 0.0671, 0.07854, 0...</td>\n",
       "      <td>[0.07388, 0.11044, 0.14786, 0.16868, 0.18238, ...</td>\n",
       "      <td>[0.0205, 0.0354, 0.0475, 0.0649, 0.0749, 0.083...</td>\n",
       "      <td>[0.0764, 0.1085, 0.1437, 0.1638, 0.1776, 0.177...</td>\n",
       "      <td>0.02076</td>\n",
       "      <td>0.05970</td>\n",
       "      <td>0.0224</td>\n",
       "      <td>0.0596</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "  setting_number seed       Model  \\\n",
       "0              1    0  CorInfoMax   \n",
       "1              1   10  CorInfoMax   \n",
       "2              2    0  CorInfoMax   \n",
       "\n",
       "                                         Hyperparams  \\\n",
       "0  {'lr_start': {'ff': [0.21, 0.17, 0.12, 0.07], ...   \n",
       "1  {'lr_start': {'ff': [0.21, 0.17, 0.12, 0.07], ...   \n",
       "2  {'lr_start': {'ff': [0.21, 0.17, 0.12, 0.07], ...   \n",
       "\n",
       "                                        Trn_ACC_list  \\\n",
       "0  [[0.02176, 0.07814], [0.0407, 0.13102], [0.060...   \n",
       "1  [[0.02112, 0.08534], [0.04178, 0.12574], [0.06...   \n",
       "2  [[0.01948, 0.07388], [0.03424, 0.11044], [0.04...   \n",
       "\n",
       "                                        Tst_ACC_list  \\\n",
       "0  [[0.0209, 0.0804], [0.0399, 0.13], [0.0578, 0....   \n",
       "1  [[0.0192, 0.0815], [0.0419, 0.126], [0.0609, 0...   \n",
       "2  [[0.0205, 0.0764], [0.0354, 0.1085], [0.0475, ...   \n",
       "\n",
       "                  forward_backward_weight_angle_list  \\\n",
       "0  [[89.99908447265625, 89.99053955078125, 89.826...   \n",
       "1  [[89.93399047851562, 90.04959106445312, 90.144...   \n",
       "2  [[89.99908447265625, 89.99053955078125, 89.826...   \n",
       "\n",
       "                                   Trn_ACC_list_top1  \\\n",
       "0  [0.02176, 0.0407, 0.06074, 0.07514, 0.08744, 0...   \n",
       "1  [0.02112, 0.04178, 0.06238, 0.07584, 0.09112, ...   \n",
       "2  [0.01948, 0.03424, 0.04836, 0.0671, 0.07854, 0...   \n",
       "\n",
       "                                   Trn_ACC_list_top5  \\\n",
       "0  [0.07814, 0.13102, 0.1734, 0.19794, 0.21902, 0...   \n",
       "1  [0.08534, 0.12574, 0.18242, 0.2114, 0.23568, 0...   \n",
       "2  [0.07388, 0.11044, 0.14786, 0.16868, 0.18238, ...   \n",
       "\n",
       "                                   Tst_ACC_list_top1  \\\n",
       "0  [0.0209, 0.0399, 0.0578, 0.0705, 0.0844, 0.094...   \n",
       "1  [0.0192, 0.0419, 0.0609, 0.075, 0.0868, 0.0957...   \n",
       "2  [0.0205, 0.0354, 0.0475, 0.0649, 0.0749, 0.083...   \n",
       "\n",
       "                                   Tst_ACC_list_top5  Trn_ACC_top1  \\\n",
       "0  [0.0804, 0.13, 0.1684, 0.1913, 0.2154, 0.2348,...       0.01738   \n",
       "1  [0.0815, 0.126, 0.1768, 0.2084, 0.2351, 0.2457...       0.02314   \n",
       "2  [0.0764, 0.1085, 0.1437, 0.1638, 0.1776, 0.177...       0.02076   \n",
       "\n",
       "   Trn_ACC_top5  Tst_ACC_top1  Tst_ACC_top5  \n",
       "0       0.05646        0.0174        0.0565  \n",
       "1       0.06272        0.0223        0.0620  \n",
       "2       0.05970        0.0224        0.0596  "
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Each entry of Trn_ACC_list / Tst_ACC_list is a [top-1, top-5] accuracy pair.\n",
    "# Split them into separate top-1 / top-5 curves, then take the last entry of each curve.\n",
    "# Two passes so all *_ACC_list_* columns are created before the final-value columns.\n",
    "TOPK_COLUMN = {'top1': 0, 'top5': 1}\n",
    "for split in ('Trn', 'Tst'):\n",
    "    pairs = df_results[f'{split}_ACC_list'].map(np.asarray)  # convert once per split\n",
    "    for k, col in TOPK_COLUMN.items():\n",
    "        df_results[f'{split}_ACC_list_{k}'] = pairs.map(lambda arr: arr[:, col])\n",
    "for split in ('Trn', 'Tst'):\n",
    "    for k in TOPK_COLUMN:\n",
    "        df_results[f'{split}_ACC_{k}'] = df_results[f'{split}_ACC_list_{k}'].map(lambda curve: curve[-1])\n",
    "df_results"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7c429a4a",
   "metadata": {},
   "source": [
    "# CIFAR10 5 Layers Experiment"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "78fdf2c4",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7d082450",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "4e0e8033",
   "metadata": {},
   "source": [
    "# CIFAR100 5 Layers Experiment"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "92fb245a",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fee7bce5",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
