{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import glob\n",
    "import json\n",
    "import torch\n",
    "import joblib\n",
    "import numpy as np\n",
    "from types import SimpleNamespace\n",
    "from torch.nn import functional as F\n",
    "from scripts import launch_pretraining, launch_finetuning\n",
    "\n",
    "import seaborn as sns\n",
    "import matplotlib.pyplot as plt\n",
    "from tqdm.notebook import tqdm\n",
    "from IPython.display import clear_output\n",
    "from scipy.special import xlogy\n",
    "from utils import load_stats\n",
    "from model import *\n",
    "from dataset import *\n",
    "\n",
    "from matplotlib.colors import LogNorm, LinearSegmentedColormap\n",
    "from matplotlib.cm import ScalarMappable"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Collect the PT-* run directories of the experiment.\n",
    "dirs = sorted(glob.glob('experiments-final/f_all=tick/PT-*'))\n",
    "\n",
    "# Read each run's config.json into an attribute-style namespace.\n",
    "configs = []\n",
    "for run_dir in dirs:\n",
    "    config_path = os.path.join(run_dir, 'config.json')\n",
    "    with open(config_path) as handle:\n",
    "        run_settings = json.load(handle)\n",
    "    configs.append(SimpleNamespace(**run_settings))\n",
    "\n",
    "# Order the runs by (data_seed, pt_seed).\n",
    "configs.sort(key=lambda cfg: (cfg.data_seed, cfg.pt_seed))\n",
    "\n",
    "# The learning-rate grid is read from the first run after sorting.\n",
    "lrs = configs[0].lrs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "config = configs[0]\n",
    "num_groups = config.num_features // 2\n",
    "samples_per_group = 1000\n",
    "\n",
    "# One generated batch per group; batch i uses a one-hot multiview_probs vector\n",
    "# (1.0 at position i, 0 elsewhere) and the seed config.data_seed + i.\n",
    "X_chunks, y_chunks = [], []\n",
    "for i in range(num_groups):\n",
    "    multiview_probs = [1.0 if g == i else 0 for g in range(num_groups)]\n",
    "\n",
    "    # Only the second and fourth of the five returned items are kept.\n",
    "    _, X, _, y, _ = generate_data(\n",
    "        config.data_protocol, multiview_probs,\n",
    "        config.data_seed + i, config.num_features,\n",
    "        config.train_samples, samples_per_group, 'cpu'\n",
    "    )\n",
    "\n",
    "    X_chunks.append(X)\n",
    "    y_chunks.append(y)\n",
    "\n",
    "# Stack the per-group batches along the sample axis, in group order.\n",
    "X_group = torch.cat(X_chunks, dim=0)\n",
    "y_group = torch.cat(y_chunks, dim=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_metrics(grad):\n",
    "    '''Summarise the structure of a 2-D gradient matrix.\n",
    "\n",
    "    grad is expected on the CPU (the last result is converted with .numpy());\n",
    "    presumably it holds one row per sample and one column per input feature.\n",
    "\n",
    "    Returns a tuple (erank, ent, group_grad):\n",
    "      * erank: exp of the entropy of the singular values of grad after\n",
    "        normalising them to sum to one.\n",
    "      * ent: entropy of the squared-gradient shares (normalised within each row,\n",
    "        averaged over rows) summed over adjacent column pairs (0, 1), (2, 3), ...\n",
    "      * group_grad: numpy array with the row-averaged squared gradient summed\n",
    "        over the same adjacent column pairs.\n",
    "    '''\n",
    "    singular_values = torch.linalg.svdvals(grad)\n",
    "    spectrum = singular_values / singular_values.sum()\n",
    "    erank = torch.exp(-torch.xlogy(spectrum, spectrum).sum()).item()\n",
    "\n",
    "    grad_sq = grad ** 2\n",
    "    # The small constant keeps all-zero rows from dividing by zero.\n",
    "    row_totals = grad_sq.sum(dim=1, keepdims=True) + 1e-7\n",
    "    column_share = (grad_sq / row_totals).mean(0)\n",
    "    # x + roll(x, 1) read at odd positions gives x[2k] + x[2k + 1].\n",
    "    pair_share = (column_share + torch.roll(column_share, 1))[1::2]\n",
    "    ent = -torch.xlogy(pair_share, pair_share).sum().item()\n",
    "\n",
    "    column_energy = grad_sq.mean(0)\n",
    "    group_grad = (column_energy + torch.roll(column_energy, 1))[1::2].numpy()\n",
    "    return erank, ent, group_grad\n",
    "\n",
    "def get_erank_and_ent(model, X, y):\n",
    "    '''Gradient-structure metrics of the model output and of the BCE loss w.r.t. X.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    model : callable returning a 2-D tensor; column 0 is used as the logit.\n",
    "    X : input tensor with requires_grad=True (torch.autograd.grad raises otherwise).\n",
    "    y : binary targets, cast to float for the loss.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    (output_erank, output_ent, loss_erank, loss_ent,\n",
    "     output_group_grad, loss_group_grad), each produced by get_metrics.\n",
    "\n",
    "    Notes\n",
    "    -----\n",
    "    Gradients are taken with torch.autograd.grad, so neither X.grad nor the\n",
    "    .grad of any model parameter is written. The previous implementation called\n",
    "    backward on list(preds), which built one autograd root per sample and\n",
    "    accumulated into every parameter gradient on each call.\n",
    "    '''\n",
    "    preds = model(X)[:, 0]\n",
    "\n",
    "    # Backward from the sum of the logits seeds every logit with gradient 1,\n",
    "    # exactly like backward over the list of individual logits did.\n",
    "    (output_grad,) = torch.autograd.grad(preds.sum(), X, retain_graph=True)\n",
    "    output_erank, output_ent, output_group_grad = get_metrics(output_grad)\n",
    "\n",
    "    loss = F.binary_cross_entropy_with_logits(preds, y.to(torch.float))\n",
    "    (loss_grad,) = torch.autograd.grad(loss, X)\n",
    "    loss_erank, loss_ent, loss_group_grad = get_metrics(loss_grad)\n",
    "\n",
    "    return (\n",
    "        output_erank, output_ent,\n",
    "        loss_erank, loss_ent,\n",
    "        output_group_grad, loss_group_grad\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "cbf8fbeaf6e4479abf39cfd7d6adc584",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/50 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Result buffers indexed [config, pretraining lr, split]; split 0 = train, 1 = test.\n",
    "loss_eranks, loss_ents, output_eranks, output_ents = [\n",
    "    np.zeros((len(configs), len(lrs), 2)) for _ in range(4)\n",
    "]\n",
    "# Group-wise buffers are sized with num_groups (set where X_group is built)\n",
    "# rather than a literal 16, so they always match the per-group loop below.\n",
    "group_output_grads, group_loss_grads = [\n",
    "    np.zeros((len(configs), len(lrs), 2, num_groups)) for _ in range(2)\n",
    "]\n",
    "group_accs = np.zeros((len(configs), len(lrs), num_groups))\n",
    "\n",
    "flr = None #lrs[10]\n",
    "num_swa = None #5\n",
    "# The input-gradient metrics used to be disabled by wrapping them in a string\n",
    "# literal. This switch keeps them off by default but runnable by flipping one flag.\n",
    "compute_grad_metrics = False\n",
    "\n",
    "for i, config in enumerate(tqdm(configs)):\n",
    "    model = init_model(\n",
    "        config.num_layers, config.num_hidden, config.num_features,\n",
    "        config.last_layer_norm, config.activation\n",
    "    )\n",
    "    if compute_grad_metrics:\n",
    "        # The train/test tensors are only consumed by the gradient metrics.\n",
    "        X_train, X_test, y_train, y_test = torch.load(f'{config.savedir}/data.pt')[:4]\n",
    "        X_train.requires_grad = True\n",
    "        X_test.requires_grad = True\n",
    "\n",
    "    for j, lr in enumerate(lrs):\n",
    "        # Without flr the pretrained checkpoint is evaluated, otherwise the fine-tuned one.\n",
    "        if flr is None:\n",
    "            ckpt_path = f'{config.savedir}/pt_lr={lr:.3e}.pt'\n",
    "        else:\n",
    "            ckpt_path = f'{config.ft_savedir}/pt_lr={lr:.3e}-ft_lr={flr:.3e}.pt'\n",
    "        ckpt = torch.load(ckpt_path)\n",
    "        model.load_state_dict(ckpt['model'])\n",
    "\n",
    "        if num_swa is not None:\n",
    "            # Use the mean of the last num_swa traced weight vectors instead.\n",
    "            w_swa = torch.stack(ckpt['trace']['weight'][-num_swa:], dim=0).mean(0)\n",
    "            set_weights(model, w_swa)\n",
    "\n",
    "        if compute_grad_metrics:\n",
    "            for split, (X_split, y_split) in enumerate([(X_train, y_train), (X_test, y_test)]):\n",
    "                (\n",
    "                    output_eranks[i, j, split], output_ents[i, j, split],\n",
    "                    loss_eranks[i, j, split], loss_ents[i, j, split],\n",
    "                    group_output_grads[i, j, split], group_loss_grads[i, j, split],\n",
    "                ) = get_erank_and_ent(model, X_split, y_split)\n",
    "\n",
    "        with torch.no_grad():\n",
    "            group_preds = (model(X_group)[:, 0] > 0).to(torch.long)\n",
    "        correct = (group_preds == y_group).to(torch.float)\n",
    "\n",
    "        # X_group is a concatenation of num_groups consecutive chunks; each chunk\n",
    "        # is addressed as samples_per_group rows.\n",
    "        for k in range(num_groups):\n",
    "            chunk = correct[k * samples_per_group: (k + 1) * samples_per_group]\n",
    "            group_accs[i, j, k] = chunk.mean().item()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['experiments-final/f_multi=tick@0/swa_group_accs.pickle']"
      ]
     },
     "execution_count": 25,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Persist every buffer of the sweep as one tuple, in this fixed order.\n",
    "sweep_results = (\n",
    "    loss_eranks, loss_ents, output_eranks, output_ents,\n",
    "    group_output_grads, group_loss_grads, group_accs,\n",
    ")\n",
    "joblib.dump(sweep_results, 'experiments-final/f_all=tick/pt_group_accs.pickle')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Interpolation coefficients: 11 evenly spaced values from 0 to 1 inclusive.\n",
    "alphas = np.linspace(0, 1, num=11)\n",
    "\n",
    "@torch.no_grad()\n",
    "def get_error(model, X, y):\n",
    "    '''Return the misclassification rate of model on (X, y), in percent.\n",
    "\n",
    "    A sample is assigned class 1 when column 0 of the model output is positive,\n",
    "    class 0 otherwise.\n",
    "    '''\n",
    "    predicted = (model(X)[:, 0] > 0).to(torch.long)\n",
    "    mistakes = (predicted != y).to(torch.float)\n",
    "    return 100 * mistakes.mean().item()\n",
    "\n",
    "@torch.no_grad()\n",
    "def get_accuracy(model, X, y):\n",
    "    '''Return the classification accuracy of model on (X, y), in percent.\n",
    "\n",
    "    A sample is assigned class 1 when column 0 of the model output is positive,\n",
    "    class 0 otherwise.\n",
    "    '''\n",
    "    predicted = (model(X)[:, 0] > 0).to(torch.long)\n",
    "    hits = (predicted == y).to(torch.float)\n",
    "    return 100 * hits.mean().item()\n",
    "\n",
    "def get_barrier(model, w1, w2, X, y, alphas):\n",
    "    '''Error barrier along the straight line between weight vectors w1 and w2.\n",
    "\n",
    "    For every coefficient a in alphas the model weights are set to\n",
    "    (1 - a) * w1 + a * w2 and the error in percent is measured on (X, y).\n",
    "    The barrier is the largest amount by which that error curve exceeds the\n",
    "    chord joining its first and last values.\n",
    "\n",
    "    The model is left holding the weights of the last interpolation point.\n",
    "    '''\n",
    "    error_curve = np.zeros_like(alphas)\n",
    "    for idx, coeff in enumerate(alphas):\n",
    "        # Linear interpolation; a cos/sin variant was tried and left disabled.\n",
    "        blended = (1 - coeff) * w1 + coeff * w2\n",
    "        set_weights(model, blended)\n",
    "        error_curve[idx] = get_error(model, X, y)\n",
    "\n",
    "    first, last = error_curve[0], error_curve[-1]\n",
    "    excess = error_curve - (1 - alphas) * first - alphas * last\n",
    "    return np.max(excess)\n",
    "\n",
    "def get_angle(w1, w2, eps=1e-12):\n",
    "    '''Angle in radians between two 1-D weight vectors.\n",
    "\n",
    "    The dot product is divided by the product of the two norms before arccos.\n",
    "    The previous version used the raw dot product as the cosine, which is only\n",
    "    valid for unit-norm inputs; an average of several weight vectors (as used\n",
    "    for SWA) generally has a smaller norm and produced an inflated angle.\n",
    "    For unit-norm inputs the result is unchanged up to round-off.\n",
    "\n",
    "    eps bounds the denominator away from zero, so a zero vector gives a cosine\n",
    "    of 0 (angle pi / 2) instead of NaN.\n",
    "    '''\n",
    "    norm_product = (w1.norm() * w2.norm()).clamp_min(eps)\n",
    "    cosine = (w1 @ w2) / norm_product\n",
    "    # Clip so round-off just outside [-1, 1] cannot make arccos return NaN.\n",
    "    return torch.clip(cosine, -1, 1).arccos().item()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "b4c5232192c04fa09e425c6c71108b0f",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/50 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "from joblib import Parallel, delayed\n",
    "\n",
    "config = configs[0]\n",
    "flr1, flr2 = lrs[0], lrs[10]\n",
    "num_swa = 5\n",
    "alphas = np.linspace(0, 1, 11)\n",
    "# Axis 0 indexes the weight pair: 0 = (low, high), 1 = (high, swa), 2 = (low, swa).\n",
    "angles = np.zeros((3, len(configs), len(lrs)))\n",
    "train_barriers = np.zeros((3, len(configs), len(lrs)))\n",
    "train_errors = np.zeros((3, len(configs), len(lrs), len(alphas)))\n",
    "test_barriers = np.zeros((3, len(configs), len(lrs)))\n",
    "test_errors = np.zeros((3, len(configs), len(lrs), len(alphas)))\n",
    "\n",
    "model = init_model(\n",
    "    config.num_layers, config.num_hidden, config.num_features,\n",
    "    config.last_layer_norm, config.activation\n",
    ")\n",
    "\n",
    "def _interpolation_errors(model, w1, w2, X, y, alphas):\n",
    "    '''Error in percent at every interpolation point (1 - a) * w1 + a * w2.'''\n",
    "    errors = np.zeros_like(alphas)\n",
    "    for idx, alpha in enumerate(alphas):\n",
    "        set_weights(model, (1 - alpha) * w1 + alpha * w2)\n",
    "        errors[idx] = get_error(model, X, y)\n",
    "    return errors\n",
    "\n",
    "def _barrier_from_errors(errors, alphas):\n",
    "    '''Largest excess of the error curve over the chord joining its endpoints.'''\n",
    "    return np.max(errors - (1 - alphas) * errors[0] - alphas * errors[-1])\n",
    "\n",
    "def process_plr(config, plr):\n",
    "    '''Angles, barriers and error curves for one pretraining learning rate.\n",
    "\n",
    "    Loads the pretrained checkpoint and the two fine-tuned checkpoints (flr1, flr2)\n",
    "    and compares three weight pairs: (low, high), (high, swa), (low, swa).\n",
    "    Reads the notebook-level model, X_train, y_train, X_test, y_test, alphas,\n",
    "    flr1, flr2 and num_swa.\n",
    "\n",
    "    Returns (angles, train_barriers, train_errors, test_barriers, test_errors);\n",
    "    angles and barriers have shape (3,), error curves have shape (3, len(alphas)).\n",
    "    The error curves are returned so the train_errors / test_errors buffers are\n",
    "    actually filled; previously they were saved as all zeros.\n",
    "    '''\n",
    "    ckpt = torch.load(f'{config.savedir}/pt_lr={plr:.3e}.pt')\n",
    "    ckpt1 = torch.load(f'{config.ft_savedir}/pt_lr={plr:.3e}-ft_lr={flr1:.3e}.pt')\n",
    "    ckpt2 = torch.load(f'{config.ft_savedir}/pt_lr={plr:.3e}-ft_lr={flr2:.3e}.pt')\n",
    "    model.load_state_dict(ckpt['model'])\n",
    "\n",
    "    w_swa = torch.stack(ckpt['trace']['weight'][-num_swa:], dim=0).mean(0)\n",
    "    w_low = ckpt1['trace']['weight'][-1]\n",
    "    w_high = ckpt2['trace']['weight'][-1]\n",
    "\n",
    "    # Local names differ from the global result arrays to avoid shadowing them.\n",
    "    pair_angles = np.zeros(3)\n",
    "    pair_train_barriers, pair_test_barriers = np.zeros(3), np.zeros(3)\n",
    "    pair_train_errors = np.zeros((3, len(alphas)))\n",
    "    pair_test_errors = np.zeros((3, len(alphas)))\n",
    "    for k, (w1, w2) in enumerate([(w_low, w_high), (w_high, w_swa), (w_low, w_swa)]):\n",
    "        pair_angles[k] = get_angle(w1, w2)\n",
    "\n",
    "        pair_train_errors[k] = _interpolation_errors(model, w1, w2, X_train, y_train, alphas)\n",
    "        pair_train_barriers[k] = _barrier_from_errors(pair_train_errors[k], alphas)\n",
    "        pair_test_errors[k] = _interpolation_errors(model, w1, w2, X_test, y_test, alphas)\n",
    "        pair_test_barriers[k] = _barrier_from_errors(pair_test_errors[k], alphas)\n",
    "\n",
    "    return (\n",
    "        pair_angles,\n",
    "        pair_train_barriers, pair_train_errors,\n",
    "        pair_test_barriers, pair_test_errors,\n",
    "    )\n",
    "\n",
    "for i, config in enumerate(tqdm(configs)):\n",
    "    X_train, X_test, y_train, y_test = torch.load(f'{config.savedir}/data.pt')[:4]\n",
    "    results = Parallel(n_jobs=8)(\n",
    "        delayed(process_plr)(config, plr) for plr in lrs\n",
    "    )\n",
    "    for j, (ang, tr_bar, tr_err, te_bar, te_err) in enumerate(results):\n",
    "        angles[:, i, j] = ang\n",
    "        train_barriers[:, i, j] = tr_bar\n",
    "        train_errors[:, i, j] = tr_err\n",
    "        test_barriers[:, i, j] = te_bar\n",
    "        test_errors[:, i, j] = te_err"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['experiments-final/f_multi=tick@0/barriers-low_flr=3.162e-05-high_flr=7.499e-03.pickle']"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Written to the same experiment directory the runs are loaded from\n",
    "# ('experiments-final/f_all=tick'), matching the other dumps in this notebook.\n",
    "# The previous target 'f_multi=tick@0' belongs to a different experiment and\n",
    "# would be overwritten on a full re-run.\n",
    "barrier_results = (angles, train_barriers, train_errors, test_barriers, test_errors)\n",
    "joblib.dump(\n",
    "    barrier_results,\n",
    "    f'experiments-final/f_all=tick/barriers-low_flr={flr1:.3e}-high_flr={flr2:.3e}.pickle'\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "09a3ccef2dc74b28b75d7d7058d53e9c",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/50 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "config = configs[0]\n",
    "# Numbers of trailing trace entries to average.\n",
    "num_swa_list = [2, 5, 10, 20, 50]\n",
    "# Test accuracy indexed [config, pretraining lr, position in num_swa_list].\n",
    "swa_accs = np.zeros((len(configs), len(lrs), len(num_swa_list)))\n",
    "# One model instance, built from the first config, is reused for every run.\n",
    "model = init_model(\n",
    "    config.num_layers, config.num_hidden, config.num_features,\n",
    "    config.last_layer_norm, config.activation\n",
    ")\n",
    "\n",
    "for i, config in enumerate(tqdm(configs)):\n",
    "    data_path = f'{config.savedir}/data.pt'\n",
    "    X_train, X_test, y_train, y_test = torch.load(data_path)[:4]\n",
    "\n",
    "    for j, plr in enumerate(lrs):\n",
    "        ckpt_path = f'{config.savedir}/pt_lr={plr:.3e}.pt'\n",
    "        ckpt = torch.load(ckpt_path)\n",
    "        model.load_state_dict(ckpt['model'])\n",
    "        weight_trace = ckpt['trace']['weight']\n",
    "\n",
    "        for k, num_swa in enumerate(num_swa_list):\n",
    "            # Average the last num_swa traced weight vectors and evaluate on test data.\n",
    "            tail = weight_trace[-num_swa:]\n",
    "            w_swa = torch.stack(tail, dim=0).mean(0)\n",
    "            set_weights(model, w_swa)\n",
    "            swa_accs[i, j, k] = get_accuracy(model, X_test, y_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['experiments-final/f_multi=tick@0/swa.pickle']"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Store the SWA test-accuracy grid computed in the previous cell.\n",
    "swa_path = 'experiments-final/f_all=tick/swa.pickle'\n",
    "joblib.dump(swa_accs, swa_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
