{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import torch\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "import json\n",
    "from minio_obj_storage import get_numpy_from_cloud\n",
    "from scipy import stats\n",
    "\n",
    "# Experiment configuration: which storage container/directory to read from and which run to analyse.\n",
    "container_name = 'learning-dynamics-scores'\n",
    "dataset = 'cifar100'\n",
    "container_dir = dataset\n",
    "seed = 1\n",
    "arch = 'fz_inception'\n",
    "\n",
    "# The targets file is addressed through the epoch-0 artifact name.\n",
    "# NOTE(review): presumably these are the ground-truth labels that the per-epoch predictions are compared against - confirm.\n",
    "model_name = f\"{dataset}_{arch}_seed_{seed}_epoch_{0}\"\n",
    "targets = get_numpy_from_cloud(container_name, container_dir, f\"targets_{model_name}.npy\")\n",
    "\n",
    "# Per-epoch accumulators. Each list receives exactly one entry per epoch, in epoch order,\n",
    "# so that stacking them later yields arrays whose first axis indexes the epoch.\n",
    "loss_grad = []\n",
    "loss = []\n",
    "loss_curvature = []\n",
    "preds = []\n",
    "for epoch in range(1, 199):  # epochs 1..198 inclusive\n",
    "    model_name = f\"{dataset}_{arch}_seed_{seed}_epoch_{epoch}\"\n",
    "\n",
    "    # NOTE(review): the original loop additionally appended f\"loss_grad_{model_name}.npy\" to `loss_grad`,\n",
    "    # giving that list two entries per epoch (and one extra download per epoch). Only the `_tid0.pt`\n",
    "    # artifact is kept here, matching the naming of the curvature artifact; switch the file name below\n",
    "    # if the `.npy` variant was the intended source.\n",
    "    loss_grad_4_eph = get_numpy_from_cloud(container_name, container_dir, f\"loss_grad_{model_name}_tid0.pt\")\n",
    "    loss_4_eph = get_numpy_from_cloud(container_name, container_dir, f\"loss_{model_name}.npy\")\n",
    "    loss_curvature_4_eph = get_numpy_from_cloud(container_name, container_dir, f\"loss_curvature_{model_name}_h0.001_tid0.pt\")\n",
    "    pred_4_eph = get_numpy_from_cloud(container_name, container_dir, f\"pred_{model_name}.npy\")\n",
    "\n",
    "    # Element-wise mask: True where this epoch's prediction equals the target.\n",
    "    preds.append(pred_4_eph == targets)\n",
    "    loss_grad.append(loss_grad_4_eph)\n",
    "    loss.append(loss_4_eph)\n",
    "    loss_curvature.append(loss_curvature_4_eph)\n",
    "\n",
    "# Stack the per-epoch correctness masks (axis 0 indexes the epoch) and convert them to 0/1 integers.\n",
    "pred = np.array(preds).astype(np.int32)\n",
    "\n",
    "# A forgetting event is a 1 -> 0 transition between two *consecutive* epochs.\n",
    "# Slicing compares epoch t with epoch t+1 only. np.roll(pred, -1, axis=0) wraps around, so it would\n",
    "# also compare the final epoch against the first one and count a spurious event for every sample\n",
    "# that is correct at the end of training but was wrong at the start.\n",
    "forget_freq = ((pred[:-1] - pred[1:]) > 0).sum(0)\n",
    "\n",
    "# Stack the remaining per-epoch lists into arrays whose first axis indexes the epoch.\n",
    "loss_grads = np.array(loss_grad)\n",
    "losses = np.array(loss)\n",
    "loss_curvatures = np.array(loss_curvature)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build one table with the reference columns stored in the checkpoint ('tr_labels', 'tr_mem')\n",
    "# next to the learning-dynamics statistics aggregated over the epoch axis.\n",
    "# NOTE(review): this assumes the checkpoint rows are in the same sample order as the cloud arrays - confirm.\n",
    "# NOTE(review): allow_pickle=True lets np.load unpickle object arrays, which can execute arbitrary code\n",
    "# when the file is untrusted; drop it if the archive only stores plain numeric arrays.\n",
    "# The `with` block closes the underlying .npz file handle once the two arrays have been read.\n",
    "with np.load('./analysis_checkpoints/cifar100/cifar100_infl_matrix.npz', allow_pickle=True) as npz:\n",
    "    fz_scores = pd.DataFrame.from_dict({item: npz[item] for item in ['tr_labels', 'tr_mem']})\n",
    "\n",
    "fz_scores = (\n",
    "    fz_scores\n",
    "    .assign(\n",
    "        loss_curv=loss_curvatures.mean(0),   # mean over epochs\n",
    "        loss_sen=losses.std(0),              # standard deviation over epochs\n",
    "        loss_last=losses[-1, :],             # value at the last loaded epoch\n",
    "        forget_freq=forget_freq,\n",
    "        loss_grad=loss_grads.mean(0),        # mean over epochs\n",
    "    )\n",
    "    # Highest 'tr_mem' first, so positional slices of the table select the top-scoring rows.\n",
    "    .sort_values(by='tr_mem', ascending=False)\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Method</th>\n",
       "      <th>Is Top 5k</th>\n",
       "      <th>PC</th>\n",
       "      <th>CS</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>loss_last</td>\n",
       "      <td>False</td>\n",
       "      <td>0.17</td>\n",
       "      <td>0.24</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>loss_sen</td>\n",
       "      <td>False</td>\n",
       "      <td>0.76</td>\n",
       "      <td>0.81</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>forget_freq</td>\n",
       "      <td>False</td>\n",
       "      <td>0.59</td>\n",
       "      <td>0.76</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>loss_curv</td>\n",
       "      <td>False</td>\n",
       "      <td>0.49</td>\n",
       "      <td>0.69</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>loss_grad</td>\n",
       "      <td>False</td>\n",
       "      <td>0.72</td>\n",
       "      <td>0.84</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "        Method  Is Top 5k   PC   CS\n",
       "0    loss_last      False 0.17 0.24\n",
       "2     loss_sen      False 0.76 0.81\n",
       "4  forget_freq      False 0.59 0.76\n",
       "6    loss_curv      False 0.49 0.69\n",
       "8    loss_grad      False 0.72 0.84"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def _cosine_similarity(a, b):\n",
    "    \"\"\"Return the cosine similarity of the 1-D arrays `a` and `b`.\"\"\"\n",
    "    return np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))\n",
    "\n",
    "\n",
    "# fz_scores is sorted by descending 'tr_mem', so the first TOP_K rows are the highest 'tr_mem' samples.\n",
    "TOP_K = 5000\n",
    "\n",
    "mem_scores = fz_scores['tr_mem'].values\n",
    "rows = []\n",
    "for metric in ['loss_last', 'loss_sen', 'forget_freq', 'loss_curv', 'loss_grad']:\n",
    "    metric_scores = fz_scores[metric].values\n",
    "    # Two rows per metric: first over every sample (stop=None), then over the top-k subset.\n",
    "    for is_top_k, stop in ((False, None), (True, TOP_K)):\n",
    "        reference, candidate = mem_scores[:stop], metric_scores[:stop]\n",
    "        rows.append({\n",
    "            'Method': metric,\n",
    "            'Is Top 5k': is_top_k,\n",
    "            # Absolute Pearson correlation: only the strength of the relation is reported, not its sign.\n",
    "            'PC': abs(stats.pearsonr(reference, candidate).statistic),\n",
    "            'CS': _cosine_similarity(reference, candidate),\n",
    "        })\n",
    "\n",
    "simData = pd.DataFrame(rows)\n",
    "\n",
    "pd.options.display.float_format = '{:.2f}'.format\n",
    "# Show only the rows computed over all samples.\n",
    "simData[~simData['Is Top 5k']]"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.5"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
