{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "====== vgg16 =====\n",
      "All loss: 0.61\n",
      "Top5k loss: 1.00\n",
      "All loss_grad: 0.60\n",
      "Top5k loss_grad: 0.98\n",
      "====== mobilenetv2 =====\n",
      "All loss: 0.73\n",
      "Top5k loss: 0.95\n",
      "All loss_grad: 0.65\n",
      "Top5k loss_grad: 0.94\n",
      "====== fz_inception =====\n",
      "All loss: 0.70\n",
      "Top5k loss: 0.97\n",
      "All loss_grad: 0.64\n",
      "Top5k loss_grad: 0.95\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "import torch\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "import json\n",
    "from minio_obj_storage import get_numpy_from_cloud\n",
    "\n",
    "# Convert dataset order from tensorflow to pytorch. Needed because FZ use tensorflow, while this repo uses Pytorch\n",
    "tf_2_torch_idx = pd.read_pickle(\"./mislabelled_exps/tf_id_2_torch_id2.pkl\")\n",
    "tf_2_torch_idx = tf_2_torch_idx['torch_id'].to_numpy()\n",
    "\n",
    "container_name = 'learning-dynamics-scores'\n",
    "dataset = 'cifar100'\n",
    "container_dir = dataset\n",
    "for arch in ['vgg16', 'mobilenetv2', 'fz_inception']:\n",
    "    loss_grad = []\n",
    "    loss = []\n",
    "    loss_curvature = []\n",
    "    for epoch in range(10):\n",
    "        loss_grad_4_eph = get_numpy_from_cloud(container_name, container_dir, f\"loss_grad_{dataset}_{arch}_seed_1_epoch_{epoch}_tid0.pt\")\n",
    "        loss_4_eph = get_numpy_from_cloud(container_name, container_dir, f\"losses_{dataset}_{arch}_seed_1_epoch_{epoch}_tid0.pt\")\n",
    "        loss_curvature_4_eph = get_numpy_from_cloud(container_name, container_dir, f\"loss_curvature_{dataset}_{arch}_seed_1_epoch_{epoch}_h0.001_tid0.pt\")\n",
    "        loss_grad.append(loss_grad_4_eph)\n",
    "        loss.append(loss_4_eph)\n",
    "        loss_curvature.append(loss_curvature_4_eph)\n",
    "\n",
    "    loss_grads = np.array(loss_grad)\n",
    "    losses = np.array(loss)\n",
    "    loss_curvatures = np.array(loss_curvature)\n",
    "    npz = np.load('./analysis_checkpoints/cifar100/cifar100_infl_matrix.npz', allow_pickle=True)\n",
    "    fz_scores = pd.DataFrame.from_dict({item: npz[item] for item in ['tr_labels', 'tr_mem']})\n",
    "    fz_scores['loss'] = losses.mean(0)\n",
    "    fz_scores['loss_grad'] = loss_grads.mean(0)\n",
    "    fz_scores = fz_scores.sort_values(by='tr_mem', ascending=False)\n",
    "\n",
    "    print(f\"====== {arch} =====\")\n",
    "    for metric in ['loss', 'loss_grad']:\n",
    "        all_cs_loss = np.dot(fz_scores['tr_mem'].values, fz_scores[metric].values) / (np.linalg.norm(fz_scores['tr_mem'].values) * np.linalg.norm(fz_scores[metric].values))\n",
    "        top5k_cs_loss = np.dot(fz_scores['tr_mem'].values[:5000], fz_scores[metric].values[:5000]) / (np.linalg.norm(fz_scores['tr_mem'].values[:5000]) * np.linalg.norm(fz_scores[metric].values[:5000]))\n",
    "        print(f\"All {metric}: {all_cs_loss:.2f}\")\n",
    "        print(f\"Top5k {metric}: {top5k_cs_loss:.2f}\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "py3.11_tf",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.5"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
