{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "431cc272",
   "metadata": {},
   "source": [
    "# ResNet-18 SI + CIFAR10"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a99431b6",
   "metadata": {},
   "source": [
    "## usual training"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e48a85c5",
   "metadata": {},
   "source": [
    "### Run training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "397db5b7",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-05-06T06:34:36.221603Z",
     "start_time": "2024-05-06T06:34:36.217470Z"
    }
   },
   "outputs": [],
   "source": [
    "# Values passed as --init_elr / --drop_elr by the command generators below.\n",
    "USUAL_ELRS = [\n",
    "    1e-6, 2e-6, 5e-6,\n",
    "    1e-5, 1.4e-5, 2e-5, 3e-5, 5e-5, 7e-5,\n",
    "    1e-4, 1.4e-4, 2e-4, 3e-4, 5e-4, 7e-4,\n",
    "    1e-3, 1.4e-3, 2e-3, 3e-3, 5e-3, 7e-3,\n",
    "    1e-2, 1.4e-2, 2e-2, 3e-2, 5e-2, 7e-2,\n",
    "    1e-1, 2e-1, 5e-1,\n",
    "    1e+0, 2e+0,\n",
    "]\n",
    "\n",
    "# One seed per USUAL_ELRS entry, paired by position: 2000, 2001, ..., 2031.\n",
    "USUAL_ESEEDS = list(range(2000, 2032))\n",
    "\n",
    "# Values passed as --drop_elr when a run is continued from a checkpoint.\n",
    "EDLRS = [\n",
    "    1e-5, 3e-5, 5e-5, 7e-5,\n",
    "    1e-4, 1.4e-4, 2e-4, 2.5e-4, 3e-4,\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bee2a5fa",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-05-06T06:34:37.021834Z",
     "start_time": "2024-05-06T06:34:37.019207Z"
    }
   },
   "outputs": [],
   "source": [
    "# Map each ELR to the seed of its run (the two lists are paired by position).\n",
    "# zip() stops at the shorter input, so check the lengths first instead of\n",
    "# silently dropping unpaired entries.\n",
    "assert len(USUAL_ELRS) == len(USUAL_ESEEDS), 'every ELR needs exactly one seed'\n",
    "ELR2SEED = dict(zip(USUAL_ELRS, USUAL_ESEEDS))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6a68526b",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-01-08T20:10:59.776300Z",
     "start_time": "2024-01-08T20:10:59.772143Z"
    },
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Print the shell commands for the constant-ELR training runs, grouped so that\n",
    "# each group shares one --gpu index and can be pasted as a single && chain.\n",
    "split = 5  # runs per group\n",
    "runs = list(zip(USUAL_ELRS, USUAL_ESEEDS))\n",
    "txt = \"\"\"python train_drop_resnet18si_cifar10_clean.py \\\\\n",
    "    --gpu {} \\\\\n",
    "    --init_elr {} --drop_elr {} \\\\\n",
    "    --drop_epoch {} \\\\\n",
    "    --seed {}\"\"\"\n",
    "\n",
    "for i, (elr, seed) in enumerate(runs):\n",
    "    gp = i // split\n",
    "    # --init_elr and --drop_elr receive the same value here.\n",
    "    cmd = txt.format(gp, elr, elr, 1000, seed)\n",
    "    # The last command of a group (and of the whole sweep) must not end with\n",
    "    # the && continuation: a dangling && makes the shell chain it to whatever\n",
    "    # comes after the blank separator lines.\n",
    "    closes_group = (i % split == split - 1) or (i == len(runs) - 1)\n",
    "    if closes_group:\n",
    "        print(cmd)\n",
    "        print()\n",
    "        print()\n",
    "    else:\n",
    "        print(cmd + \" && \\\\\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e4c791f2",
   "metadata": {},
   "source": [
    "### Calculate gradients"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c97e2574",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-04-04T07:46:10.070674Z",
     "start_time": "2024-04-04T07:46:10.064774Z"
    },
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Print the shell commands that compute gradient norms for the checkpoints of\n",
    "# every constant-ELR run, grouped so that each group shares one --gpu index.\n",
    "split = 6  # runs per group\n",
    "runs = list(zip(USUAL_ELRS, USUAL_ESEEDS))\n",
    "txt = \"\"\"python calc_grad_norms_resnet18si_cifar_clean.py \\\\\n",
    "    --gpu {} \\\\\n",
    "    --directory_with_checkpoints ./Experiments/ResNet18SI_CIFAR10_elri_{}_elrd_{}_dropepoch_1000_wd_0.0_seed_{}_noaug_True/ \\\\\n",
    "    --loader train \\\\\n",
    "    --aug 0 \\\\\n",
    "    --train_mode 1\"\"\"\n",
    "\n",
    "for i, (elr, seed) in enumerate(runs):\n",
    "    gp = i // split + 1  # GPU indices start at 1 in this sweep\n",
    "    cmd = txt.format(gp, elr, elr, seed)\n",
    "    # The last command of a group (and of the whole sweep) must not end with\n",
    "    # the && continuation: a dangling && makes the shell chain it to whatever\n",
    "    # comes after the blank separator lines.\n",
    "    closes_group = (i % split == split - 1) or (i == len(runs) - 1)\n",
    "    if closes_group:\n",
    "        print(cmd)\n",
    "        print()\n",
    "        print()\n",
    "    else:\n",
    "        print(cmd + \" && \\\\\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8f1ba487",
   "metadata": {},
   "source": [
    "## Drops"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "66fd97ed",
   "metadata": {},
   "source": [
    "### Run training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7a443194",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-01-09T19:58:55.228288Z",
     "start_time": "2024-01-09T19:58:55.217184Z"
    }
   },
   "outputs": [],
   "source": [
    "# Print the shell commands that continue every constant-ELR run from its\n",
    "# checkpoint-200.pt with each value of EDLRS as --drop_elr. Base runs are\n",
    "# grouped so that each group shares one --gpu index.\n",
    "split = 5  # base runs per group\n",
    "runs = list(zip(USUAL_ELRS, USUAL_ESEEDS))\n",
    "txt = \"\"\"python train_drop_resnet18si_cifar10_clean_from_starting_point.py \\\\\n",
    "        --gpu {} \\\\\n",
    "        --init_checkpoint ./Experiments/ResNet18SI_CIFAR10_elri_{}_elrd_{}_dropepoch_1000_wd_0.0_seed_{}_noaug_True/checkpoint-200.pt \\\\\n",
    "        --init_elr {} --drop_elr {} \\\\\n",
    "        --drop_epoch {} \\\\\n",
    "        --k_epoch {} \\\\\n",
    "        --seed {}\"\"\"\n",
    "\n",
    "for i, (elr, seed) in enumerate(runs):\n",
    "    gp = i // split\n",
    "    closes_group = (i % split == split - 1) or (i == len(runs) - 1)\n",
    "    for j, edlr in enumerate(EDLRS):\n",
    "        cmd = txt.format(\n",
    "            gp,\n",
    "            elr, elr, seed,  # directory of the run that is continued\n",
    "            elr, edlr,\n",
    "            200,             # --drop_epoch\n",
    "            250,             # --k_epoch\n",
    "            seed,\n",
    "        )\n",
    "        # Only the very last command of a group is printed without the &&\n",
    "        # continuation, so nothing chains across the blank separator lines.\n",
    "        if closes_group and j == len(EDLRS) - 1:\n",
    "            print(cmd)\n",
    "        else:\n",
    "            print(cmd + \" && \\\\\")\n",
    "    if closes_group:\n",
    "        print()\n",
    "        print()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "69f4dfbb",
   "metadata": {},
   "source": [
    "## SWA"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8160ade0",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-05-17T20:06:23.803622Z",
     "start_time": "2024-05-17T20:06:23.796706Z"
    },
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Print the shell commands for the SWA runs, one && chain per --gpu index.\n",
    "# NOTE(review): the GPU index used to advance every 5 runs while the chain was\n",
    "# broken every 6; both now use `split`. 5 keeps the printed --gpu values\n",
    "# unchanged - confirm this is the intended grouping.\n",
    "split = 5\n",
    "k_epoch = 100\n",
    "start_swa_epoch = 200\n",
    "runs = list(zip(USUAL_ELRS, USUAL_ESEEDS))\n",
    "\n",
    "for i, (blr, sd) in enumerate(runs):\n",
    "    gp = i // split\n",
    "    txt = (\n",
    "        \"python custom_swa_resnet18si_cifar_starting_point_clean.py --gpu {} --elr {} \\\\\\n\".format(gp, blr)\n",
    "        + \"   --k_epoch {} --seed {} \\\\\\n\".format(k_epoch, sd)\n",
    "        + \"   --stride 1 --start_swa_epoch {}\".format(start_swa_epoch)\n",
    "    )\n",
    "    # Close the chain where the GPU index changes and after the final run, so\n",
    "    # no command is left with a dangling && continuation.\n",
    "    if i % split == split - 1 or i == len(runs) - 1:\n",
    "        print(txt)\n",
    "        print()\n",
    "    else:\n",
    "        print(txt + \" && \\\\\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9b2b5354",
   "metadata": {},
   "source": [
    "# readings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "353551cd",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-05-17T04:17:08.519122Z",
     "start_time": "2024-05-17T04:17:07.108342Z"
    }
   },
   "outputs": [],
   "source": [
    "import os\n",
    "import pickle\n",
    "from glob import glob\n",
    "\n",
    "import torch\n",
    "import numpy as np\n",
    "\n",
    "from tqdm import tqdm\n",
    "from matplotlib import pyplot as plt\n",
    "from matplotlib.pyplot import cm\n",
    "from matplotlib.lines import Line2D"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1b0e75db",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-05-17T04:17:10.707756Z",
     "start_time": "2024-05-17T04:17:10.703186Z"
    }
   },
   "outputs": [],
   "source": [
    "def check_key_name(key):\n",
    "    \"\"\"Return True when `key` contains none of the excluded substrings.\n",
    "\n",
    "    Excluded substrings: '.running_var', '.num_batches_tracked',\n",
    "    '.running_mean', 'linear.weight' and 'n_averaged'.\n",
    "    \"\"\"\n",
    "    excluded = (\n",
    "        '.running_var',\n",
    "        '.num_batches_tracked',\n",
    "        '.running_mean',\n",
    "        'linear.weight',\n",
    "        'n_averaged',\n",
    "    )\n",
    "    return not any(part in key for part in excluded)\n",
    "\n",
    "\n",
    "def make_flatten_vec(state_dict, layer=None):\n",
    "    \"\"\"Concatenate entries of `state_dict` into one 1-D float64 tensor.\n",
    "\n",
    "    With `layer=None`, every entry whose key passes `check_key_name` is\n",
    "    flattened and concatenated in the iteration order of `state_dict`.\n",
    "    Otherwise only `state_dict[layer]` is flattened and returned.\n",
    "    \"\"\"\n",
    "    if layer is not None:\n",
    "        pieces = [torch.flatten(state_dict[layer])]\n",
    "    else:\n",
    "        pieces = [\n",
    "            torch.flatten(tensor)\n",
    "            for name, tensor in state_dict.items()\n",
    "            if check_key_name(name)\n",
    "        ]\n",
    "    return torch.cat(pieces, 0).to(torch.float64)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "617a7739",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-01-11T08:06:51.224134Z",
     "start_time": "2024-01-11T07:42:28.764821Z"
    }
   },
   "outputs": [],
   "source": [
    "# Collect one metrics record per checkpoint of every constant-ELR run.\n",
    "pth = './Experiments/ResNet18SI_CIFAR10_elri_{}_elrd_{}_dropepoch_1000_wd_0.0_seed_{}_noaug_True/checkpoint-{}.pt'\n",
    "\n",
    "# checkpoint key -> record key; copied only when the checkpoint has the key\n",
    "optional_keys = {\n",
    "    'gnorm_trainmode_m_train': 'gnorm_trainmode',\n",
    "    'loss_trainmode_train': 'loss_trainmode_train',\n",
    "    'acc_trainmode_train': 'acc_trainmode_train',\n",
    "    'gnorm_evalmode_m_train': 'gnorm_evalmode',\n",
    "    'loss_evalmode_train': 'loss_evalmode_train',\n",
    "    'acc_evalmode_train': 'acc_evalmode_train',\n",
    "}\n",
    "\n",
    "usual_tracks = dict()\n",
    "for elr, seed in zip(USUAL_ELRS, USUAL_ESEEDS):\n",
    "    usual_tracks[elr] = []\n",
    "\n",
    "    for ckpt in tqdm(range(1001)):  # checkpoint-0.pt ... checkpoint-1000.pt\n",
    "        ckptpth = pth.format(elr, elr, seed, ckpt)\n",
    "        # Only metrics are read, so map every tensor to the CPU: loading then\n",
    "        # works without a GPU and does not occupy GPU memory.\n",
    "        data = torch.load(ckptpth, map_location='cpu')\n",
    "\n",
    "        record = {\n",
    "            'ep': ckpt,\n",
    "            'train_loss': data['train_res']['loss'],\n",
    "            'train_accuracy': data['train_res']['accuracy'],\n",
    "            'test_loss': data['test_res']['loss'],\n",
    "            'test_accuracy': data['test_res']['accuracy'],\n",
    "            'elr': elr,\n",
    "            # L2 norm of the flattened entries selected by make_flatten_vec\n",
    "            'pnorm': np.linalg.norm(make_flatten_vec(data['state_dict']).cpu()),\n",
    "        }\n",
    "        for src_key, dst_key in optional_keys.items():\n",
    "            if src_key in data:\n",
    "                record[dst_key] = data[src_key]\n",
    "\n",
    "        usual_tracks[elr].append(record)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f3fac764",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-01-11T08:53:04.085432Z",
     "start_time": "2024-01-11T08:16:14.617307Z"
    }
   },
   "outputs": [],
   "source": [
    "# Collect one metrics record per checkpoint of every drop run, stored as\n",
    "# drop_checkpoints[elr][drop_start][edlr] -> list of records.\n",
    "pthmsk = './Experiments/FIXEDINIT_DROP_ResNet18SI_CIFAR10_elri_{}_elrd_{}_dropepochfrom_{}_wd_0.0_seed_{}_noaug_True/'\n",
    "\n",
    "# checkpoint key -> record key; copied only when the checkpoint has the key\n",
    "optional_keys = {\n",
    "    'gnorm_trainmode_m_train': 'gnorm_trainmode',\n",
    "    'loss_trainmode_train': 'loss_trainmode_train',\n",
    "    'acc_trainmode_train': 'acc_trainmode_train',\n",
    "    'gnorm_evalmode_m_train': 'gnorm_evalmode',\n",
    "    'loss_evalmode_train': 'loss_evalmode_train',\n",
    "    'acc_evalmode_train': 'acc_evalmode_train',\n",
    "}\n",
    "\n",
    "drop_checkpoints = dict()\n",
    "for elr, seed in zip(USUAL_ELRS, USUAL_ESEEDS):\n",
    "    print('-'*80)\n",
    "    print(elr, seed)\n",
    "    print('-'*80)\n",
    "\n",
    "    per_elr = drop_checkpoints.setdefault(elr, dict())\n",
    "\n",
    "    for drop_start in [200]:\n",
    "        per_start = per_elr.setdefault(drop_start, dict())\n",
    "\n",
    "        for edlr in EDLRS:\n",
    "            per_start[edlr] = []\n",
    "\n",
    "            # The seed part of the directory name is matched with a wildcard.\n",
    "            # glob() gives no ordering guarantee, so sort the matches to make\n",
    "            # the choice of the last one reproducible.\n",
    "            run_dirs = sorted(\n",
    "                x for x in glob(pthmsk.format(elr, edlr, drop_start, '*'))\n",
    "                if 'noaug_False' not in x\n",
    "            )\n",
    "            if not run_dirs:\n",
    "                # No run found for this combination: keep the empty list.\n",
    "                continue\n",
    "            run_dir = run_dirs[-1]\n",
    "            print(run_dir)\n",
    "\n",
    "            for ckpt in tqdm(range(drop_start + 1, drop_start + 201)):\n",
    "                # run_dir ends with '/' because the glob pattern does.\n",
    "                ckptpth = run_dir + 'checkpoint-{}.pt'.format(ckpt)\n",
    "                # Only metrics are read, so map every tensor to the CPU.\n",
    "                data = torch.load(ckptpth, map_location='cpu')\n",
    "\n",
    "                record = {\n",
    "                    'ep': ckpt,\n",
    "                    'train_loss': data['train_res']['loss'],\n",
    "                    'train_accuracy': data['train_res']['accuracy'],\n",
    "                    'test_loss': data['test_res']['loss'],\n",
    "                    'test_accuracy': data['test_res']['accuracy'],\n",
    "                    'elr': elr,\n",
    "                    'pnorm': np.linalg.norm(make_flatten_vec(data['state_dict']).cpu()),\n",
    "                }\n",
    "                for src_key, dst_key in optional_keys.items():\n",
    "                    if src_key in data:\n",
    "                        record[dst_key] = data[src_key]\n",
    "\n",
    "                per_start[edlr].append(record)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "42db0587",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-05-17T21:24:35.313415Z",
     "start_time": "2024-05-17T21:24:31.874539Z"
    }
   },
   "outputs": [],
   "source": [
    "# Collect one metrics record per SWA checkpoint, stored as\n",
    "# swa_checkpoints[elr][start_epoch][k] -> record.\n",
    "base = './Experiments/SWA_K_100_stride_1_ResNet18SI_CIFAR10_elri_{}_elrd_{}_dropepoch_1001_wd_0.0/swa_start_{:03d}_k_001/checkpoint-{}.pt'\n",
    "\n",
    "# checkpoint key -> record key; copied only when the checkpoint has the key\n",
    "optional_keys = {\n",
    "    'gnorm_trainmode_m_train': 'gnorm_trainmode',\n",
    "    'loss_trainmode_train': 'loss_trainmode_train',\n",
    "    'acc_trainmode_train': 'acc_trainmode_train',\n",
    "    'gnorm_evalmode_m_train': 'gnorm_evalmode',\n",
    "    'loss_evalmode_train': 'loss_evalmode_train',\n",
    "    'acc_evalmode_train': 'acc_evalmode_train',\n",
    "}\n",
    "\n",
    "swa_checkpoints = dict()\n",
    "\n",
    "for elr in tqdm(USUAL_ELRS):\n",
    "    swa_checkpoints[elr] = dict()\n",
    "    for start_epoch in [200]:\n",
    "        swa_checkpoints[elr][start_epoch] = dict()\n",
    "        for k in [2, 5, 10, 50, 100]:\n",
    "            # The file index is start_epoch + k - 1, while the record below\n",
    "            # stores start_epoch + k as its epoch.\n",
    "            path = base.format(elr, elr, start_epoch, start_epoch + k - 1)\n",
    "            # Only metrics are read, so map every tensor to the CPU.\n",
    "            data = torch.load(path, map_location='cpu')\n",
    "\n",
    "            record = {\n",
    "                'ep': start_epoch + k,\n",
    "                'train_loss': data['train_res']['loss'],\n",
    "                'train_accuracy': data['train_res']['accuracy'],\n",
    "                'test_loss': data['test_res']['loss'],\n",
    "                'test_accuracy': data['test_res']['accuracy'],\n",
    "                'pnorm': np.linalg.norm(make_flatten_vec(data['state_dict']).cpu()),\n",
    "            }\n",
    "            for src_key, dst_key in optional_keys.items():\n",
    "                if src_key in data:\n",
    "                    record[dst_key] = data[src_key]\n",
    "\n",
    "            swa_checkpoints[elr][start_epoch][k] = record"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9f756a2b",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-05-17T21:24:36.538361Z",
     "start_time": "2024-05-17T21:24:36.532428Z"
    }
   },
   "outputs": [],
   "source": [
    "# Write the collected SWA records to a pickle file in the working directory.\n",
    "swa_metrics_path = 'resnet18si_cifar10_swa_metrics.pkl'\n",
    "with open(swa_metrics_path, 'wb') as sink:\n",
    "    pickle.dump(swa_checkpoints, sink)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "67ab0976",
   "metadata": {},
   "source": [
    "# saving to pickle"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "06c6c06a",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-01-11T08:54:32.239842Z",
     "start_time": "2024-01-11T08:54:31.954792Z"
    }
   },
   "outputs": [],
   "source": [
    "# Write the constant-ELR tracks and the drop-run records to pickle files.\n",
    "for pkl_path, payload in (\n",
    "    ('./resnet18si_cifar10_usual_metrics.pkl', usual_tracks),\n",
    "    ('./resnet18si_cifar10_drop_metrics.pkl', drop_checkpoints),\n",
    "):\n",
    "    with open(pkl_path, 'wb') as sink:\n",
    "        pickle.dump(payload, sink)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fcf82d83",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.5"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": true,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
