{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "666c038a",
   "metadata": {},
   "source": [
    "# imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b76b5379",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-05-18T08:26:39.334366Z",
     "start_time": "2024-05-18T08:26:37.904402Z"
    }
   },
   "outputs": [],
   "source": [
    "import os\n",
    "import pickle\n",
    "\n",
    "import numpy as np\n",
    "from tqdm import tqdm\n",
    "\n",
    "from matplotlib import pyplot as plt\n",
    "\n",
    "import torch\n",
    "import numpy as np\n",
    "from glob import glob\n",
    "\n",
    "from matplotlib.pyplot import cm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "68f59991",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-05-18T08:26:40.452157Z",
     "start_time": "2024-05-18T08:26:40.446725Z"
    }
   },
   "outputs": [],
   "source": [
    "def check_key_name(key):\n",
    "    return ('.running_var' not in key) and \\\n",
    "        ('.num_batches_tracked' not in key) and \\\n",
    "        ('.running_mean' not in key) and \\\n",
    "        ('linear.weight' not in key) and \\\n",
    "        ('n_averaged' not in key)\n",
    "\n",
    "def make_flatten_vec(state_dict, layer=None):\n",
    "    values = []\n",
    "    if layer is None:\n",
    "        for key, value in state_dict.items():\n",
    "            if check_key_name(key):\n",
    "                values.append(torch.flatten(value))\n",
    "    else:\n",
    "        values.append(torch.flatten(state_dict[layer]))\n",
    "#             print('adding ', value.shape)\n",
    "    vec = torch.cat(values, 0).to(torch.float64)\n",
    "    return vec"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dcaf4548",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-05-18T08:26:41.397248Z",
     "start_time": "2024-05-18T08:26:41.391963Z"
    }
   },
   "outputs": [],
   "source": [
    "def unit_vector(vector):\n",
    "    \"\"\" Returns the unit vector of the vector.  \"\"\"\n",
    "    return vector / np.linalg.norm(vector)\n",
    "\n",
    "def angle_between(v1, v2):\n",
    "    \"\"\" Returns the angle in radians between vectors 'v1' and 'v2'::\n",
    "\n",
    "            >>> angle_between((1, 0, 0), (0, 1, 0))\n",
    "            1.5707963267948966\n",
    "            >>> angle_between((1, 0, 0), (1, 0, 0))\n",
    "            0.0\n",
    "            >>> angle_between((1, 0, 0), (-1, 0, 0))\n",
    "            3.141592653589793\n",
    "    \"\"\"\n",
    "    v1_u = unit_vector(v1)\n",
    "    v2_u = unit_vector(v2)\n",
    "    return np.arccos(np.clip(np.dot(v1_u, v2_u), -1.0, 1.0))\n",
    "\n",
    "def get_init_angle_dist(point_a:str, point_b:str):\n",
    "    sd1 = torch.load(point_a)['state_dict']\n",
    "    sd2 = torch.load(point_b)['state_dict']\n",
    "    \n",
    "    vec1 = make_flatten_vec(sd1).detach().cpu()\n",
    "    vec2 = make_flatten_vec(sd2).detach().cpu()\n",
    "    cdist = angle_between(vec1, vec2)\n",
    "    return cdist"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c0e0ebfa",
   "metadata": {},
   "source": [
    "# LI with barrier"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "47a7ed49",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-05-18T08:26:44.199051Z",
     "start_time": "2024-05-18T08:26:44.194875Z"
    }
   },
   "outputs": [],
   "source": [
    "def read_angle_dist_from_track(checkpoint_dir: str, n_steps: int=20):\n",
    "    base = os.path.join(checkpoint_dir, 'interp_result_{:5.4f}-{}.pt')\n",
    "    pt_path0 = base.format(0.0, int(0.0))\n",
    "    pt_path1 = base.format(1.0, int(1.0))\n",
    "\n",
    "    angle = get_init_angle_dist(point_a=pt_path0, point_b=pt_path1)   \n",
    "       \n",
    "    return angle"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "25af25f0",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-05-18T08:26:44.566545Z",
     "start_time": "2024-05-18T08:26:44.562122Z"
    }
   },
   "outputs": [],
   "source": [
    "def read_li_track(checkpoint_dir: str, n_steps: int=20):\n",
    "    track = []\n",
    "    for alpha in np.linspace(0.0, 1.0, n_steps + 1):\n",
    "        base = os.path.join(checkpoint_dir, 'interp_result_{:5.4f}-{}.pt')\n",
    "        pt_path = base.format(alpha, int(alpha))\n",
    "        data = torch.load(pt_path)\n",
    "\n",
    "        data['pnorm'] = np.linalg.norm(make_flatten_vec(data['state_dict']).cpu())\n",
    "\n",
    "        del data['state_dict']\n",
    "\n",
    "        data['alpha'] = alpha\n",
    "\n",
    "        track.append(data)\n",
    "    return track"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fccdd0db",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-05-18T08:26:47.274873Z",
     "start_time": "2024-05-18T08:26:47.269568Z"
    }
   },
   "outputs": [],
   "source": [
    "def track_to_barrier(track_values: list, barrier_is_higher: bool=True):\n",
    "    track_values = np.array(track_values)\n",
    "    A = track_values[0]\n",
    "    B = track_values[-1]\n",
    "    \n",
    "    alpha = np.linspace(0.0, 1.0, len(track_values))\n",
    "    li = (1.0 - alpha) * A + alpha * B \n",
    "    if barrier_is_higher:\n",
    "        return (track_values - li).clip(min=0.0).max()\n",
    "    return (li - track_values).clip(min=0.0).max()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3f23b724",
   "metadata": {},
   "source": [
    "# general"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "15d04d47",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-05-18T08:26:58.617572Z",
     "start_time": "2024-05-18T08:26:58.613195Z"
    }
   },
   "outputs": [],
   "source": [
    "USUAL_ELRS = [\n",
    "    1e-5, 2e-5, 5e-5,\n",
    "    1e-4, 2e-4, 5e-4,\n",
    "    1e-3, 2e-3, 5e-3,\n",
    "    1e-2, 2e-2, 5e-2,\n",
    "    1e-1, 2e-1, 5e-1,\n",
    "    1\n",
    "]\n",
    "\n",
    "EDLRS = [1e-5, 1e-4]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "dd12021e",
   "metadata": {},
   "source": [
    "# DROP(HIGH) -> DROP(LOW)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9ddb2a63",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-05-21T01:36:44.563961Z",
     "start_time": "2024-05-21T01:36:44.559253Z"
    },
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "spt = 3\n",
    "for i, elr in enumerate(USUAL_ELRS):\n",
    "    gp = (i // spt)\n",
    "    \n",
    "    print(\"\"\"python ./linear_interpolation_resnet18si_cifar100_clean.py \\\\\n",
    "    --gpu {} --elr {} --n_interp 20 --recalc_bn 1 \\\\\n",
    "    --point_a ./../for_eduard_oct_25/ResNet18SI_CIFAR100_size32/finetune_high_dlr_1e-04/elr{}_checkpoint.pt \\\\\n",
    "    --point_b ./../for_eduard_may_19/ResNet18SI_CIFAR100_size32_finetune_low_dlr_2e-05/elr{}_checkpoint.pt \\\\\n",
    "    --save ./Experiments/CONNECTIVITY_RN18C100_lri_{}_from_drop_1e-04_to_drop_2e-05/ && \\\\\"\"\".\\\n",
    "         format(gp, elr, \n",
    "                elr, \n",
    "                elr, \n",
    "                elr))\n",
    "    print(\"\"\"python ./calc_grad_norms_resnet18si_cifar100_clean.py \\\\\n",
    "    --gpu {} \\\\\n",
    "    --directory_with_checkpoints ./Experiments/CONNECTIVITY_RN18C100_lri_{}_from_drop_1e-04_to_drop_2e-05/ \\\\\n",
    "    --train_mode 1 && \\\\\"\"\".\\\n",
    "         format(gp, \n",
    "                elr))\n",
    "    if i % spt == spt - 1:\n",
    "        print('\\n\\n')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1c31ef43",
   "metadata": {},
   "source": [
    "### readings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "086ca6f5",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-05-21T03:19:59.554781Z",
     "start_time": "2024-05-21T03:19:44.010569Z"
    }
   },
   "outputs": [],
   "source": [
    "interp_drophigh_droplow = dict()\n",
    "for elr in tqdm(USUAL_ELRS): \n",
    "    interp_drophigh_droplow [elr] = dict()\n",
    "    base = './Experiments/CONNECTIVITY_RN18C100_lri_{}_from_drop_1e-04_to_drop_2e-05'\n",
    "    \n",
    "    pt_path = base.format(elr)\n",
    "    \n",
    "    interp_drophigh_droplow[elr]['angle'] = read_angle_dist_from_track(pt_path)\n",
    "    \n",
    "    track = read_li_track(pt_path)\n",
    "    interp_drophigh_droplow[elr]['track'] = track\n",
    "    \n",
    "    loss_barrier = track_to_barrier([x['loss_trainmode_train'] for x in track], barrier_is_higher=True)\n",
    "    interp_drophigh_droplow[elr]['loss_barrier'] = loss_barrier\n",
    "    \n",
    "    lossts_barrier = track_to_barrier([x['test_res']['loss'] for x in track], barrier_is_higher=True)\n",
    "    interp_drophigh_droplow[elr]['lossts_barrier'] = lossts_barrier\n",
    "    \n",
    "    testacc_barrier = track_to_barrier([x['test_res']['accuracy'] for x in track], barrier_is_higher=False)\n",
    "    interp_drophigh_droplow[elr]['testacc_barrier'] = testacc_barrier\n",
    "    \n",
    "    loss_barrier = track_to_barrier([x['acc_trainmode_train'] for x in track], barrier_is_higher=False)\n",
    "    interp_drophigh_droplow[elr]['trainacc_barrier'] = loss_barrier"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3c2af237",
   "metadata": {},
   "source": [
    "# SWA(5 or 2) -> DROP(HIGH)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e9f231e4",
   "metadata": {},
   "source": [
    "### with old checkpoints"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "76c1cb89",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-05-18T08:27:13.850981Z",
     "start_time": "2024-05-18T08:27:13.845939Z"
    },
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "spt = 3\n",
    "for i, elr in enumerate(USUAL_ELRS):\n",
    "    gp = (i // spt)\n",
    "    \n",
    "    print(\"\"\"python ./linear_interpolation_resnet18si_cifar100_clean.py \\\\\n",
    "    --gpu {} --elr {} --n_interp 20 --recalc_bn 1 \\\\\n",
    "    --point_a ./../for_eduard_oct_25/ResNet18SI_CIFAR100_size32/swa_from_200_for_5/elr{}_checkpoint.pt \\\\\n",
    "    --point_b ./../for_eduard_oct_25/ResNet18SI_CIFAR100_size32/finetune_high_dlr_1e-04/elr{}_checkpoint.pt \\\\\n",
    "    --save ./Experiments/CONNECTIVITY_RN18C100_lri_{}_from_swa_5_to_drop_1e-04/ && \\\\\"\"\".\\\n",
    "         format(gp, elr, \n",
    "                elr, \n",
    "                elr, \n",
    "                elr))\n",
    "    print(\"\"\"python ./calc_grad_norms_resnet18si_cifar100_clean.py \\\\\n",
    "    --gpu {} \\\\\n",
    "    --directory_with_checkpoints ./Experiments/CONNECTIVITY_RN18C100_lri_{}_from_swa_5_to_drop_1e-04/ \\\\\n",
    "    --train_mode 1 && \\\\\"\"\".\\\n",
    "         format(gp, \n",
    "                elr))\n",
    "    if i % spt == spt - 1:\n",
    "        print('\\n\\n')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "20b6cc29",
   "metadata": {},
   "source": [
    "### readings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aa63b2ab",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-05-21T03:20:12.451577Z",
     "start_time": "2024-05-21T03:19:59.557186Z"
    }
   },
   "outputs": [],
   "source": [
    "interp_swa5_drophigh = dict()\n",
    "for elr in tqdm(USUAL_ELRS): \n",
    "    interp_swa5_drophigh[elr] = dict()\n",
    "    base = './Experiments/CONNECTIVITY_RN18C100_lri_{}_from_swa_5_to_drop_1e-04'\n",
    "\n",
    "    pt_path = base.format(elr)\n",
    "    \n",
    "    interp_swa5_drophigh[elr]['angle'] = read_angle_dist_from_track(pt_path)\n",
    "    \n",
    "    track = read_li_track(pt_path)\n",
    "    interp_swa5_drophigh[elr]['track'] = track\n",
    "        \n",
    "    loss_barrier = track_to_barrier([x['loss_trainmode_train'] for x in track], barrier_is_higher=True)\n",
    "    interp_swa5_drophigh[elr]['loss_barrier'] = loss_barrier\n",
    "    \n",
    "    lossts_barrier = track_to_barrier([x['test_res']['loss'] for x in track], barrier_is_higher=True)\n",
    "    interp_swa5_drophigh[elr]['lossts_barrier'] = lossts_barrier\n",
    "    \n",
    "    testacc_barrier = track_to_barrier([x['test_res']['accuracy'] for x in track], barrier_is_higher=False)\n",
    "    interp_swa5_drophigh[elr]['testacc_barrier'] = testacc_barrier\n",
    "    \n",
    "    loss_barrier = track_to_barrier([x['acc_trainmode_train'] for x in track], barrier_is_higher=False)\n",
    "    interp_swa5_drophigh[elr]['trainacc_barrier'] = loss_barrier"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f14c964d",
   "metadata": {},
   "source": [
    "# SWA(5) -> DROP(LOW)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4f127954",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-05-21T02:00:14.013968Z",
     "start_time": "2024-05-21T02:00:14.009046Z"
    },
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "spt = 3\n",
    "for i, elr in enumerate(USUAL_ELRS):\n",
    "    gp = (i // spt)\n",
    "    \n",
    "    print(\"\"\"python ./linear_interpolation_resnet18si_cifar100_clean.py \\\\\n",
    "    --gpu {} --elr {} --n_interp 20 --recalc_bn 1 \\\\\n",
    "    --point_a ./../for_eduard_oct_25/ResNet18SI_CIFAR100_size32/swa_from_200_for_5/elr{}_checkpoint.pt \\\\\n",
    "    --point_b ./../for_eduard_may_19/ResNet18SI_CIFAR100_size32_finetune_low_dlr_2e-05/elr{}_checkpoint.pt \\\\\n",
    "    --save ./Experiments/CONNECTIVITY_RN18C100_lri_{}_from_swa_5_to_drop_2e-05/ && \\\\\"\"\".\\\n",
    "         format(gp, elr, \n",
    "                elr, \n",
    "                elr, \n",
    "                elr))\n",
    "    print(\"\"\"python ./calc_grad_norms_resnet18si_cifar100_clean.py \\\\\n",
    "    --gpu {} \\\\\n",
    "    --directory_with_checkpoints ./Experiments/CONNECTIVITY_RN18C100_lri_{}_from_swa_5_to_drop_2e-05/ \\\\\n",
    "    --train_mode 1 && \\\\\"\"\".\\\n",
    "         format(gp, \n",
    "                elr))\n",
    "    if i % spt == spt - 1:\n",
    "        print('\\n\\n')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ddc6634f",
   "metadata": {},
   "source": [
    "### readings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f9d7efe3",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-05-21T03:20:26.127780Z",
     "start_time": "2024-05-21T03:20:12.453513Z"
    }
   },
   "outputs": [],
   "source": [
    "interp_swa5_droplow = dict()\n",
    "for elr in tqdm(USUAL_ELRS): \n",
    "    interp_swa5_droplow[elr] = dict()\n",
    "    base = './Experiments/CONNECTIVITY_RN18C100_lri_{}_from_swa_5_to_drop_2e-05'\n",
    "    \n",
    "    pt_path = base.format(elr)\n",
    "    \n",
    "    interp_swa5_droplow[elr]['angle'] = read_angle_dist_from_track(pt_path)\n",
    "    \n",
    "    track = read_li_track(pt_path)\n",
    "    interp_swa5_droplow[elr]['track'] = track\n",
    "    \n",
    "    loss_barrier = track_to_barrier([x['loss_trainmode_train'] for x in track], barrier_is_higher=True)\n",
    "    interp_swa5_droplow[elr]['loss_barrier'] = loss_barrier\n",
    "    \n",
    "    lossts_barrier = track_to_barrier([x['test_res']['loss'] for x in track], barrier_is_higher=True)\n",
    "    interp_swa5_droplow[elr]['lossts_barrier'] = lossts_barrier\n",
    "    \n",
    "    testacc_barrier = track_to_barrier([x['test_res']['accuracy'] for x in track], barrier_is_higher=False)\n",
    "    interp_swa5_droplow[elr]['testacc_barrier'] = testacc_barrier\n",
    "    \n",
    "    loss_barrier = track_to_barrier([x['acc_trainmode_train'] for x in track], barrier_is_higher=False)\n",
    "    interp_swa5_droplow[elr]['trainacc_barrier'] = loss_barrier"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f286e85e",
   "metadata": {},
   "source": [
    "# dump to the disk"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6c4bdbec",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-05-21T03:20:26.869757Z",
     "start_time": "2024-05-21T03:20:26.865772Z"
    }
   },
   "outputs": [],
   "source": [
    "barrier_setups = dict()\n",
    "barrier_setups['interp_drophigh_droplow'] = interp_drophigh_droplow\n",
    "barrier_setups['interp_swa5_drophigh'] = interp_swa5_drophigh\n",
    "barrier_setups['interp_swa5_droplow'] = interp_swa5_droplow"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "39cb9ae9",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-05-18T09:41:37.610475Z",
     "start_time": "2024-05-18T09:41:37.592032Z"
    }
   },
   "outputs": [],
   "source": [
    "with open('./resnet18si_cifar100_barrier_setups.pkl', 'wb') as f:\n",
    "    pickle.dump(barrier_setups, f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1a32a848",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python (venv_gpu)",
   "language": "python",
   "name": "venv_gpu"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.0"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": true,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
