{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np \n",
    "from minio_obj_storage import get_numpy_from_cloud\n",
    "\n",
    "container_name = 'learning-dynamics-scores'\n",
    "container_dir = 'imagenet'\n",
    "loss_curv = []\n",
    "loss_grad = []\n",
    "loss = []\n",
    "preds = []\n",
    "\n",
    "epochs = 90\n",
    "seed = 3\n",
    "model_name = f\"resnet50_wd1_seed_{seed}_epoch\"\n",
    "\n",
    "targets = get_numpy_from_cloud(container_name, container_dir, f\"targets_{model_name}_{0}.npy\")\n",
    "\n",
    "for epoch in range(0, epochs):\n",
    "    try:\n",
    "        pred = get_numpy_from_cloud(container_name, container_dir, f\"pred_{model_name}_{epoch}.npy\")\n",
    "        preds.append(pred == targets)\n",
    "\n",
    "        loss_curv_4_eph = get_numpy_from_cloud(container_name, container_dir, f\"curvature_{model_name}_{epoch}.npy\")\n",
    "        loss_curv.append(loss_curv_4_eph)\n",
    "\n",
    "        loss_grad_4_eph = get_numpy_from_cloud(container_name, container_dir, f\"loss_grad_{model_name}_{epoch}.npy\")\n",
    "        loss_grad.append(loss_grad_4_eph)\n",
    "\n",
    "        loss_4_eph = get_numpy_from_cloud(container_name, container_dir, f\"loss_{model_name}_{epoch}.npy\")\n",
    "        loss.append(loss_4_eph)\n",
    "    except Exception as e:\n",
    "        print(f\"Epoch {epoch} not found, {e.args}\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [],
   "source": [
    "pred = np.array(preds)\n",
    "pred = pred.astype(np.int32)\n",
    "shifted = np.roll(pred, -1, axis=0)\n",
    "forget_freq = ((pred - shifted) > 0).sum(0)\n",
    "loss_curv_score = np.array(loss_curv).mean(0)\n",
    "loss = np.array(loss)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "from scipy import stats\n",
    "\n",
    "# Convert dataset order from tensorflow to pytorch. Needed because FZ use tensorflow, while this repo uses Pytorch\n",
    "npz = np.load('./analysis_checkpoints/imagenet/imagenet_index.npz', allow_pickle=True)\n",
    "fz_scores = pd.DataFrame.from_dict({item: npz[item] for item in ['tr_labels', 'tr_filenames', 'tr_mem']})\n",
    "fz_scores['tr_filenames'] = fz_scores['tr_filenames'].astype(str)\n",
    "fz_scores['loss_last'] = loss[-1, :]\n",
    "fz_scores['loss_sen'] = loss.std(0)\n",
    "fz_scores['forget_freq'] = forget_freq\n",
    "fz_scores['curv'] = loss_curv_score\n",
    "fz_scores['csg'] = np.array(loss_grad).mean(0)\n",
    "fz_scores = fz_scores.sort_values(by='tr_mem', ascending=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Method</th>\n",
       "      <th>Is Top 50k</th>\n",
       "      <th>PC</th>\n",
       "      <th>CS</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>loss_last</td>\n",
       "      <td>False</td>\n",
       "      <td>0.49</td>\n",
       "      <td>0.63</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>loss_sen</td>\n",
       "      <td>False</td>\n",
       "      <td>0.17</td>\n",
       "      <td>0.49</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>forget_freq</td>\n",
       "      <td>False</td>\n",
       "      <td>0.04</td>\n",
       "      <td>0.49</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>curv</td>\n",
       "      <td>False</td>\n",
       "      <td>0.33</td>\n",
       "      <td>0.62</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>csg</td>\n",
       "      <td>False</td>\n",
       "      <td>0.52</td>\n",
       "      <td>0.72</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "        Method  Is Top 50k   PC   CS\n",
       "0    loss_last       False 0.49 0.63\n",
       "2     loss_sen       False 0.17 0.49\n",
       "4  forget_freq       False 0.04 0.49\n",
       "6         curv       False 0.33 0.62\n",
       "8          csg       False 0.52 0.72"
      ]
     },
     "execution_count": 32,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "methods = []\n",
    "top_k = []\n",
    "cs_values = []\n",
    "pc_values = []\n",
    "\n",
    "for idx, metric in enumerate(['loss_last', 'loss_sen', 'forget_freq', 'curv', 'csg']):\n",
    "    all_cs_loss = np.dot(fz_scores['tr_mem'].values, fz_scores[metric].values) / (np.linalg.norm(fz_scores['tr_mem'].values) * np.linalg.norm(fz_scores[metric].values))\n",
    "    top50k_cs_loss = np.dot(fz_scores['tr_mem'].values[:50000], fz_scores[metric].values[:50000]) / (np.linalg.norm(fz_scores['tr_mem'].values[:50000]) * np.linalg.norm(fz_scores[metric].values[:50000]))\n",
    "    all_pc = stats.pearsonr(fz_scores['tr_mem'].values, fz_scores[metric].values)\n",
    "    top50k_pc = stats.pearsonr(fz_scores['tr_mem'].values[:50000], fz_scores[metric].values[:50000])\n",
    "    methods.append(metric)\n",
    "    cs_values.append(all_cs_loss)\n",
    "    top_k.append(False)\n",
    "    pc_values.append(abs(all_pc.statistic))\n",
    "    methods.append(metric)\n",
    "    cs_values.append(top50k_cs_loss)\n",
    "    top_k.append(True)\n",
    "    pc_values.append(abs(top50k_pc.statistic))\n",
    "\n",
    "simData = pd.DataFrame({\n",
    "    'Method': methods,\n",
    "    'Is Top 50k': top_k,\n",
    "    'PC': pc_values,\n",
    "    'CS': cs_values\n",
    "})\n",
    "\n",
    "pd.options.display.float_format = '{:.2f}'.format\n",
    "simData[simData['Is Top 50k'] == False]"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.5"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
