{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "666c038a",
   "metadata": {},
   "source": [
    "# imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b76b5379",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-05-17T21:31:00.430902Z",
     "start_time": "2024-05-17T21:30:59.073902Z"
    },
    "deletable": false,
    "editable": false
   },
   "outputs": [],
   "source": [
    "import os\n",
    "import pickle\n",
    "\n",
    "import numpy as np\n",
    "from tqdm import tqdm\n",
    "\n",
    "from matplotlib import pyplot as plt\n",
    "\n",
    "import torch\n",
    "import numpy as np\n",
    "from glob import glob\n",
    "\n",
    "from matplotlib.pyplot import cm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "68f59991",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-05-17T21:31:00.437327Z",
     "start_time": "2024-05-17T21:31:00.433137Z"
    },
    "deletable": false,
    "editable": false
   },
   "outputs": [],
   "source": [
    "def check_key_name(key):\n",
    "    \"\"\"Tell whether `key` is free of every excluded name fragment.\n",
    "\n",
    "    Returns True only if none of the substrings listed in `excluded` occurs\n",
    "    in `key`; `make_flatten_vec` uses this to pick the entries it flattens.\n",
    "    \"\"\"\n",
    "    excluded = (\n",
    "        '.running_var',\n",
    "        '.num_batches_tracked',\n",
    "        '.running_mean',\n",
    "        'linear.weight',\n",
    "        'n_averaged',\n",
    "    )\n",
    "    return not any(fragment in key for fragment in excluded)\n",
    "\n",
    "def make_flatten_vec(state_dict, layer=None):\n",
    "    \"\"\"Flatten entries of `state_dict` into one 1-D float64 tensor.\n",
    "\n",
    "    With `layer=None`, every entry whose key passes `check_key_name` is\n",
    "    flattened and the pieces are concatenated in the dict's iteration order.\n",
    "    Otherwise only `state_dict[layer]` is flattened.\n",
    "    \"\"\"\n",
    "    if layer is None:\n",
    "        pieces = [\n",
    "            torch.flatten(tensor)\n",
    "            for name, tensor in state_dict.items()\n",
    "            if check_key_name(name)\n",
    "        ]\n",
    "    else:\n",
    "        pieces = [torch.flatten(state_dict[layer])]\n",
    "    return torch.cat(pieces, 0).to(torch.float64)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dcaf4548",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-05-17T21:31:00.445422Z",
     "start_time": "2024-05-17T21:31:00.438613Z"
    },
    "deletable": false,
    "editable": false
   },
   "outputs": [],
   "source": [
    "def unit_vector(vector):\n",
    "    \"\"\" Divide `vector` by its norm as computed by `np.linalg.norm`.  \"\"\"\n",
    "    norm = np.linalg.norm(vector)\n",
    "    return vector / norm\n",
    "\n",
    "def angle_between(v1, v2):\n",
    "    \"\"\" Angle, in radians, between the vectors 'v1' and 'v2'.\n",
    "\n",
    "    Both inputs are normalised with `unit_vector`; their dot product is\n",
    "    clipped to [-1, 1] before `np.arccos` is applied::\n",
    "\n",
    "            >>> angle_between((1, 0, 0), (0, 1, 0))\n",
    "            1.5707963267948966\n",
    "            >>> angle_between((1, 0, 0), (1, 0, 0))\n",
    "            0.0\n",
    "            >>> angle_between((1, 0, 0), (-1, 0, 0))\n",
    "            3.141592653589793\n",
    "    \"\"\"\n",
    "    cosine = np.dot(unit_vector(v1), unit_vector(v2))\n",
    "    return np.arccos(np.clip(cosine, -1.0, 1.0))\n",
    "\n",
    "def get_init_angle_dist(point_a:str, point_b:str):\n",
    "    \"\"\"Return the angle (radians) between the weights of two checkpoints.\n",
    "\n",
    "    Each path is read with `torch.load` and must hold a 'state_dict' entry.\n",
    "    The state dicts are flattened with `make_flatten_vec` and the two\n",
    "    resulting vectors are compared with `angle_between`.\n",
    "\n",
    "    Args:\n",
    "        point_a: path to the first checkpoint file.\n",
    "        point_b: path to the second checkpoint file.\n",
    "    \"\"\"\n",
    "    vectors = []\n",
    "    for path in (point_a, point_b):\n",
    "        # map_location='cpu' lets a checkpoint saved from a GPU run be read on\n",
    "        # a machine without that device, and keeps this CPU-only computation\n",
    "        # from allocating GPU memory.\n",
    "        state_dict = torch.load(path, map_location='cpu')['state_dict']\n",
    "        vectors.append(make_flatten_vec(state_dict).detach().cpu())\n",
    "    return angle_between(vectors[0], vectors[1])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c0e0ebfa",
   "metadata": {},
   "source": [
    "# LI with barrier"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "47a7ed49",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-05-17T21:31:00.450134Z",
     "start_time": "2024-05-17T21:31:00.446985Z"
    },
    "deletable": false,
    "editable": false
   },
   "outputs": [],
   "source": [
    "def read_angle_dist_from_track(checkpoint_dir: str, n_steps: int=20):\n",
    "    \"\"\"Angle between the two endpoint checkpoints stored in `checkpoint_dir`.\n",
    "\n",
    "    The endpoints are the files whose names are produced by the\n",
    "    'interp_result_{:5.4f}-{}.pt' pattern for 0.0 and 1.0. `n_steps` is\n",
    "    accepted but not used.\n",
    "    \"\"\"\n",
    "    pattern = os.path.join(checkpoint_dir, 'interp_result_{:5.4f}-{}.pt')\n",
    "    first, last = (pattern.format(end, int(end)) for end in (0.0, 1.0))\n",
    "    return get_init_angle_dist(point_a=first, point_b=last)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "25af25f0",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-05-17T21:31:00.745947Z",
     "start_time": "2024-05-17T21:31:00.742233Z"
    },
    "deletable": false,
    "editable": false
   },
   "outputs": [],
   "source": [
    "def read_li_track(checkpoint_dir: str, n_steps: int=20):\n",
    "    \"\"\"Load every interpolation checkpoint of a run into a list of dicts.\n",
    "\n",
    "    For each alpha in `np.linspace(0.0, 1.0, n_steps + 1)` the matching\n",
    "    'interp_result_*.pt' file is loaded. Its 'state_dict' is replaced by\n",
    "    'pnorm' (the norm of the flattened weights) and 'alpha' is added.\n",
    "\n",
    "    Args:\n",
    "        checkpoint_dir: directory holding the interpolation checkpoints.\n",
    "        n_steps: number of interpolation intervals (n_steps + 1 files are read).\n",
    "\n",
    "    Returns:\n",
    "        list of the loaded dicts, ordered by increasing alpha.\n",
    "    \"\"\"\n",
    "    # The file-name pattern does not depend on alpha, so build it once.\n",
    "    base = os.path.join(checkpoint_dir, 'interp_result_{:5.4f}-{}.pt')\n",
    "    track = []\n",
    "    for alpha in np.linspace(0.0, 1.0, n_steps + 1):\n",
    "        pt_path = base.format(alpha, int(alpha))\n",
    "        # map_location='cpu' lets checkpoints saved from a GPU run be read on\n",
    "        # a machine without that device; everything below runs on the CPU.\n",
    "        data = torch.load(pt_path, map_location='cpu')\n",
    "\n",
    "        data['pnorm'] = np.linalg.norm(make_flatten_vec(data['state_dict']).cpu())\n",
    "\n",
    "        # Drop the weights so the returned track only keeps the summary values.\n",
    "        del data['state_dict']\n",
    "\n",
    "        data['alpha'] = alpha\n",
    "\n",
    "        track.append(data)\n",
    "    return track"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fccdd0db",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-05-17T21:31:01.167324Z",
     "start_time": "2024-05-17T21:31:01.163869Z"
    },
    "deletable": false,
    "editable": false
   },
   "outputs": [],
   "source": [
    "def track_to_barrier(track_values: list, barrier_is_higher: bool=True):\n",
    "    \"\"\"Largest one-sided deviation of a track from the line joining its ends.\n",
    "\n",
    "    The first and last values are connected by a straight line sampled at\n",
    "    `len(track_values)` evenly spaced points. With `barrier_is_higher=True`\n",
    "    the result is the maximum amount by which the track lies above that\n",
    "    line; otherwise the maximum amount by which it lies below. Deviations\n",
    "    on the other side are clipped to 0.0.\n",
    "    \"\"\"\n",
    "    values = np.array(track_values)\n",
    "    start, end = values[0], values[-1]\n",
    "\n",
    "    weights = np.linspace(0.0, 1.0, len(values))\n",
    "    chord = (1.0 - weights) * start + weights * end\n",
    "    gap = values - chord if barrier_is_higher else chord - values\n",
    "    return gap.clip(min=0.0).max()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3f23b724",
   "metadata": {},
   "source": [
    "# general"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "15d04d47",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-05-17T21:31:22.014096Z",
     "start_time": "2024-05-17T21:31:22.008291Z"
    },
    "deletable": false,
    "editable": false
   },
   "outputs": [],
   "source": [
    "# USUAL_ELRS and USUAL_ESEEDS are consumed pairwise (via zip) by the cells\n",
    "# below: the i-th ELR value goes with the i-th seed.\n",
    "USUAL_ELRS = [\n",
    "    1e-6, 2e-6, 5e-6,\n",
    "    1e-5, 1.4e-5, 2e-5, 3e-5, 5e-5, 7e-5,\n",
    "    1e-4, 1.4e-4, 2e-4, 3e-4, 5e-4, 7e-4,\n",
    "    1e-3, 1.4e-3, 2e-3, 3e-3, 5e-3, 7e-3,\n",
    "    1e-2, 1.4e-2, 2e-2, 3e-2, 5e-2, 7e-2,\n",
    "    1e-1, 2e-1, 5e-1,\n",
    "    1e+0, 2e+0,\n",
    "]\n",
    "\n",
    "# Consecutive seeds 2000 .. 2031 (one per entry of USUAL_ELRS).\n",
    "USUAL_ESEEDS = list(range(2000, 2032))\n",
    "\n",
    "# Not referenced by the other cells of this notebook.\n",
    "EDLRS = [1e-5, 3e-5, 5e-5, 7e-5, 1e-4, 1.4e-4, 2e-4, 2.5e-4, 3e-4]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "dd12021e",
   "metadata": {},
   "source": [
    "# DROP(HIGH) -> DROP(LOW)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8e986581",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-05-17T21:31:24.618538Z",
     "start_time": "2024-05-17T21:31:24.612044Z"
    },
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Print, for every (elr, seed) pair, the two shell commands of the\n",
    "# DROP(HIGH) -> DROP(LOW) setup: the linear-interpolation run followed by the\n",
    "# grad-norm pass over the directory it saves to.\n",
    "interp_cmd = \"\"\"python ./../drops_with_clean_cifar10/linear_interpolation_resnet18si_cifar10_clean.py \\\\\n",
    "    --gpu {} --elr {} --n_interp 20 --recalc_bn 1 \\\\\n",
    "    --point_a ./../drops_with_clean_cifar10/Experiments/FIXEDINIT_DROP_ResNet18SI_CIFAR10_elri_{}_elrd_0.0003_dropepochfrom_200_wd_0.0_seed_{}_noaug_True/checkpoint-400.pt \\\\\n",
    "    --point_b ./../drops_with_clean_cifar10/Experiments/FIXEDINIT_DROP_ResNet18SI_CIFAR10_elri_{}_elrd_1e-05_dropepochfrom_200_wd_0.0_seed_{}_noaug_True/checkpoint-400.pt \\\\\n",
    "    --save ./Experiments/CONNECTIVITY_RN18C10_lri_{}_from_drop_0.0003_to_drop_1e-05_seed_{}/ && \\\\\"\"\"\n",
    "grad_cmd = \"\"\"python ./../drops_with_clean_cifar10/calc_grad_norms_resnet18si_cifar_clean.py \\\\\n",
    "    --gpu {} \\\\\n",
    "    --directory_with_checkpoints ./Experiments/CONNECTIVITY_RN18C10_lri_{}_from_drop_0.0003_to_drop_1e-05_seed_{}/ \\\\\n",
    "    --train_mode 1 && \\\\\"\"\"\n",
    "\n",
    "for idx, (elr, seed) in enumerate(zip(USUAL_ELRS, USUAL_ESEEDS)):\n",
    "    # Six consecutive pairs share one --gpu index.\n",
    "    gpu_id = idx // 6\n",
    "\n",
    "    print(interp_cmd.format(gpu_id, elr,\n",
    "                            elr, seed,\n",
    "                            elr, seed,\n",
    "                            elr, seed))\n",
    "    print(grad_cmd.format(gpu_id,\n",
    "                          elr, seed))\n",
    "    # Blank lines separate each group of six in the printed output.\n",
    "    if idx % 6 == 5:\n",
    "        print('\\n\\n')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "086ca6f5",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-05-18T08:19:26.905492Z",
     "start_time": "2024-05-18T08:19:06.283965Z"
    }
   },
   "outputs": [],
   "source": [
    "# Gather, for each (elr, seed) run of the DROP(HIGH) -> DROP(LOW) setup, the\n",
    "# endpoint angle, the interpolation track and the barrier values derived from it.\n",
    "interp_drophigh_droplow = dict()\n",
    "base = './Experiments/CONNECTIVITY_RN18C10_lri_{}_from_drop_0.0003_to_drop_1e-05_seed_{}'\n",
    "# zip() has no length, so tqdm is given the total explicitly; without it the\n",
    "# bar can only show a running count.\n",
    "n_runs = min(len(USUAL_ELRS), len(USUAL_ESEEDS))\n",
    "for elr, seed in tqdm(zip(USUAL_ELRS, USUAL_ESEEDS), total=n_runs):\n",
    "    pt_path = base.format(elr, seed)\n",
    "    result = dict()\n",
    "\n",
    "    result['angle'] = read_angle_dist_from_track(pt_path)\n",
    "\n",
    "    track = read_li_track(pt_path)\n",
    "    result['track'] = track\n",
    "\n",
    "    # Losses: barrier_is_higher=True measures how far the track rises above\n",
    "    # the straight line between its endpoints.\n",
    "    loss_barrier = track_to_barrier([x['loss_trainmode_train'] for x in track], barrier_is_higher=True)\n",
    "    result['loss_barrier'] = loss_barrier\n",
    "\n",
    "    lossts_barrier = track_to_barrier([x['test_res']['loss'] for x in track], barrier_is_higher=True)\n",
    "    result['lossts_barrier'] = lossts_barrier\n",
    "\n",
    "    # Accuracies: barrier_is_higher=False measures how far the track dips\n",
    "    # below that line.\n",
    "    testacc_barrier = track_to_barrier([x['test_res']['accuracy'] for x in track], barrier_is_higher=False)\n",
    "    result['testacc_barrier'] = testacc_barrier\n",
    "\n",
    "    trainacc_barrier = track_to_barrier([x['acc_trainmode_train'] for x in track], barrier_is_higher=False)\n",
    "    result['trainacc_barrier'] = trainacc_barrier\n",
    "\n",
    "    # Store the entry only once every field has been computed.\n",
    "    interp_drophigh_droplow[elr] = result"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "df28e6cd",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-05-17T15:11:07.841246Z",
     "start_time": "2024-05-17T15:11:07.832352Z"
    }
   },
   "outputs": [],
   "source": [
    "# interp_drophigh_droplow[1e-6]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3c2af237",
   "metadata": {},
   "source": [
    "# SWA(5) -> DROP(HIGH)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "76c1cb89",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-05-17T21:48:13.217842Z",
     "start_time": "2024-05-17T21:48:13.210781Z"
    },
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Print, for every (elr, seed) pair, the two shell commands of the\n",
    "# SWA(5) -> DROP(HIGH) setup: the linear-interpolation run followed by the\n",
    "# grad-norm pass over the directory it saves to.\n",
    "interp_cmd = \"\"\"python ./../drops_with_clean_cifar10/linear_interpolation_resnet18si_cifar10_clean.py \\\\\n",
    "    --gpu {} --elr {} --n_interp 20 --recalc_bn 1 \\\\\n",
    "    --point_a ./../drops_with_clean_cifar10/Experiments/SWA_K_100_stride_1_ResNet18SI_CIFAR10_elri_{}_elrd_{}_dropepoch_1001_wd_0.0/swa_start_200_k_001/checkpoint-204.pt \\\\\n",
    "    --point_b ./../drops_with_clean_cifar10/Experiments/FIXEDINIT_DROP_ResNet18SI_CIFAR10_elri_{}_elrd_0.0003_dropepochfrom_200_wd_0.0_seed_{}_noaug_True/checkpoint-400.pt \\\\\n",
    "    --save ./Experiments/CONNECTIVITY_RN18C10_lri_{}_from_swa_5_to_drop_0.0003_seed_{}/ && \\\\\"\"\"\n",
    "grad_cmd = \"\"\"python ./../drops_with_clean_cifar10/calc_grad_norms_resnet18si_cifar_clean.py \\\\\n",
    "    --gpu {} \\\\\n",
    "    --directory_with_checkpoints ./Experiments/CONNECTIVITY_RN18C10_lri_{}_from_swa_5_to_drop_0.0003_seed_{}/ \\\\\n",
    "    --train_mode 1 && \\\\\"\"\"\n",
    "\n",
    "for idx, (elr, seed) in enumerate(zip(USUAL_ELRS, USUAL_ESEEDS)):\n",
    "    # Six consecutive pairs share one --gpu index.\n",
    "    gpu_id = idx // 6\n",
    "\n",
    "    print(interp_cmd.format(gpu_id, elr,\n",
    "                            elr, elr,\n",
    "                            elr, seed,\n",
    "                            elr, seed))\n",
    "    print(grad_cmd.format(gpu_id,\n",
    "                          elr, seed))\n",
    "    # Blank lines separate each group of six in the printed output.\n",
    "    if idx % 6 == 5:\n",
    "        print('\\n\\n')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aa63b2ab",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-05-18T08:19:53.195551Z",
     "start_time": "2024-05-18T08:19:34.159224Z"
    }
   },
   "outputs": [],
   "source": [
    "# Gather, for each (elr, seed) run of the SWA(5) -> DROP(HIGH) setup, the\n",
    "# endpoint angle, the interpolation track and the barrier values derived from it.\n",
    "interp_swa5_drophigh = dict()\n",
    "base = './Experiments/CONNECTIVITY_RN18C10_lri_{}_from_swa_5_to_drop_0.0003_seed_{}'\n",
    "# zip() has no length, so tqdm is given the total explicitly; without it the\n",
    "# bar can only show a running count.\n",
    "n_runs = min(len(USUAL_ELRS), len(USUAL_ESEEDS))\n",
    "for elr, seed in tqdm(zip(USUAL_ELRS, USUAL_ESEEDS), total=n_runs):\n",
    "    pt_path = base.format(elr, seed)\n",
    "    result = dict()\n",
    "\n",
    "    result['angle'] = read_angle_dist_from_track(pt_path)\n",
    "\n",
    "    track = read_li_track(pt_path)\n",
    "    result['track'] = track\n",
    "\n",
    "    # Losses: barrier_is_higher=True measures how far the track rises above\n",
    "    # the straight line between its endpoints.\n",
    "    loss_barrier = track_to_barrier([x['loss_trainmode_train'] for x in track], barrier_is_higher=True)\n",
    "    result['loss_barrier'] = loss_barrier\n",
    "\n",
    "    lossts_barrier = track_to_barrier([x['test_res']['loss'] for x in track], barrier_is_higher=True)\n",
    "    result['lossts_barrier'] = lossts_barrier\n",
    "\n",
    "    # Accuracies: barrier_is_higher=False measures how far the track dips\n",
    "    # below that line.\n",
    "    testacc_barrier = track_to_barrier([x['test_res']['accuracy'] for x in track], barrier_is_higher=False)\n",
    "    result['testacc_barrier'] = testacc_barrier\n",
    "\n",
    "    trainacc_barrier = track_to_barrier([x['acc_trainmode_train'] for x in track], barrier_is_higher=False)\n",
    "    result['trainacc_barrier'] = trainacc_barrier\n",
    "\n",
    "    # Store the entry only once every field has been computed.\n",
    "    interp_swa5_drophigh[elr] = result"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f14c964d",
   "metadata": {},
   "source": [
    "# SWA(5) -> DROP(LOW)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1c345483",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-05-17T22:25:22.240414Z",
     "start_time": "2024-05-17T22:25:22.232209Z"
    },
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Print, for every (elr, seed) pair, the two shell commands of the\n",
    "# SWA(5) -> DROP(LOW) setup: the linear-interpolation run followed by the\n",
    "# grad-norm pass over the directory it saves to.\n",
    "interp_cmd = \"\"\"python ./../drops_with_clean_cifar10/linear_interpolation_resnet18si_cifar10_clean.py \\\\\n",
    "    --gpu {} --elr {} --n_interp 20 --recalc_bn 1 \\\\\n",
    "    --point_a ./../drops_with_clean_cifar10/Experiments/SWA_K_100_stride_1_ResNet18SI_CIFAR10_elri_{}_elrd_{}_dropepoch_1001_wd_0.0/swa_start_200_k_001/checkpoint-204.pt \\\\\n",
    "    --point_b ./../drops_with_clean_cifar10/Experiments/FIXEDINIT_DROP_ResNet18SI_CIFAR10_elri_{}_elrd_1e-05_dropepochfrom_200_wd_0.0_seed_{}_noaug_True/checkpoint-400.pt \\\\\n",
    "    --save ./Experiments/CONNECTIVITY_RN18C10_lri_{}_from_swa_5_to_drop_1e-05_seed_{}/ && \\\\\"\"\"\n",
    "grad_cmd = \"\"\"python ./../drops_with_clean_cifar10/calc_grad_norms_resnet18si_cifar_clean.py \\\\\n",
    "    --gpu {} \\\\\n",
    "    --directory_with_checkpoints ./Experiments/CONNECTIVITY_RN18C10_lri_{}_from_swa_5_to_drop_1e-05_seed_{}/ \\\\\n",
    "    --train_mode 1 && \\\\\"\"\"\n",
    "\n",
    "for idx, (elr, seed) in enumerate(zip(USUAL_ELRS, USUAL_ESEEDS)):\n",
    "    # Six consecutive pairs share one --gpu index.\n",
    "    gpu_id = idx // 6\n",
    "\n",
    "    print(interp_cmd.format(gpu_id, elr,\n",
    "                            elr, elr,\n",
    "                            elr, seed,\n",
    "                            elr, seed))\n",
    "    print(grad_cmd.format(gpu_id,\n",
    "                          elr, seed))\n",
    "    # Blank lines separate each group of six in the printed output.\n",
    "    if idx % 6 == 5:\n",
    "        print('\\n\\n')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f9d7efe3",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-05-18T08:20:19.639013Z",
     "start_time": "2024-05-18T08:20:00.393643Z"
    }
   },
   "outputs": [],
   "source": [
    "# Gather, for each (elr, seed) run of the SWA(5) -> DROP(LOW) setup, the\n",
    "# endpoint angle, the interpolation track and the barrier values derived from it.\n",
    "interp_swa5_droplow = dict()\n",
    "base = './Experiments/CONNECTIVITY_RN18C10_lri_{}_from_swa_5_to_drop_1e-05_seed_{}'\n",
    "# zip() has no length, so tqdm is given the total explicitly; without it the\n",
    "# bar can only show a running count.\n",
    "n_runs = min(len(USUAL_ELRS), len(USUAL_ESEEDS))\n",
    "for elr, seed in tqdm(zip(USUAL_ELRS, USUAL_ESEEDS), total=n_runs):\n",
    "    pt_path = base.format(elr, seed)\n",
    "    result = dict()\n",
    "\n",
    "    result['angle'] = read_angle_dist_from_track(pt_path)\n",
    "\n",
    "    track = read_li_track(pt_path)\n",
    "    result['track'] = track\n",
    "\n",
    "    # Losses: barrier_is_higher=True measures how far the track rises above\n",
    "    # the straight line between its endpoints.\n",
    "    loss_barrier = track_to_barrier([x['loss_trainmode_train'] for x in track], barrier_is_higher=True)\n",
    "    result['loss_barrier'] = loss_barrier\n",
    "\n",
    "    lossts_barrier = track_to_barrier([x['test_res']['loss'] for x in track], barrier_is_higher=True)\n",
    "    result['lossts_barrier'] = lossts_barrier\n",
    "\n",
    "    # Accuracies: barrier_is_higher=False measures how far the track dips\n",
    "    # below that line.\n",
    "    testacc_barrier = track_to_barrier([x['test_res']['accuracy'] for x in track], barrier_is_higher=False)\n",
    "    result['testacc_barrier'] = testacc_barrier\n",
    "\n",
    "    trainacc_barrier = track_to_barrier([x['acc_trainmode_train'] for x in track], barrier_is_higher=False)\n",
    "    result['trainacc_barrier'] = trainacc_barrier\n",
    "\n",
    "    # Store the entry only once every field has been computed.\n",
    "    interp_swa5_droplow[elr] = result"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f286e85e",
   "metadata": {},
   "source": [
    "# dump to the disk"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6c4bdbec",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-05-18T08:20:21.561615Z",
     "start_time": "2024-05-18T08:20:21.554507Z"
    }
   },
   "outputs": [],
   "source": [
    "# Bundle the three per-setup result dicts under their own names.\n",
    "barrier_setups = {\n",
    "    'interp_drophigh_droplow': interp_drophigh_droplow,\n",
    "    'interp_swa5_drophigh': interp_swa5_drophigh,\n",
    "    'interp_swa5_droplow': interp_swa5_droplow,\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "39cb9ae9",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-05-18T08:20:21.947634Z",
     "start_time": "2024-05-18T08:20:21.915481Z"
    }
   },
   "outputs": [],
   "source": [
    "# Write the bundled results next to the notebook as a pickle file.\n",
    "with open('./resnet18si_cifar10_barrier_setups.pkl', 'wb') as out_file:\n",
    "    pickle.dump(barrier_setups, out_file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cdbb4d8e",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c814db2b",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python (venv_gpu)",
   "language": "python",
   "name": "venv_gpu"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.0"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": true,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
