{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "f632efbe",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "os.environ['CUDA_VISIBLE_DEVICES'] = '2'\n",
    "\n",
    "from pytorch_lightning import utilities as pl_utils\n",
    "from pytorch_lightning.trainer.trainer import Trainer\n",
    "from pytorch_lightning.plugins import DDPPlugin\n",
    "import torch\n",
    "import pathlib, itertools\n",
    "from functools import partial\n",
    "import sys, glob, copy\n",
    "import numpy as np\n",
    "\n",
    "sys.path.append('..')\n",
    "sys.path.append('../deep-learning-base')\n",
    "sys.path.append('../deep-learning-base/training')\n",
    "sys.path.append('../deep-learning-base/datasets')\n",
    "sys.path.append('../partially_inverted_reps')\n",
    "\n",
    "import plot_helper as plt_hp\n",
    "import output as out\n",
    "from training import LitProgressBar, NicerModelCheckpointing\n",
    "import training.finetuning as ft\n",
    "import architectures as arch\n",
    "from architectures.callbacks import LightningWrapper, LinearEvalWrapper\n",
    "from attack.callbacks import AdvCallback\n",
    "from data_modules import DATA_MODULES\n",
    "import dataset_metadata as dsmd\n",
    "from partial_loss import PartialInversionLoss, PartialInversionRegularizedLoss\n",
    "from __init__ import DATA_PATH_IMAGENET, DATA_PATH, SERVER_PROJECT_PATH\n",
    "from human_nn_alignment.save_inverted_reps import save_batched_images, get_classes_names"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "d23c47af",
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_linear_weights(model, state_dict):\n",
    "    \"\"\"Copy a saved linear-probe checkpoint into the model's last submodule.\n",
    "\n",
    "    Every key of ``state_dict['state_dict']`` is reduced to its final dotted\n",
    "    component (e.g. ``'model.fc.weight'`` -> ``'weight'``) and loaded strictly,\n",
    "    so the checkpoint must hold exactly the parameters of that last module.\n",
    "    When both the module and the checkpoint carry ``neuron_indices`` they are\n",
    "    asserted to be identical.\n",
    "    \"\"\"\n",
    "    head = list(model.modules())[-1]\n",
    "    remapped = {key.split('.')[-1]: tensor\n",
    "                for key, tensor in state_dict['state_dict'].items()}\n",
    "    head.load_state_dict(remapped, strict=True)\n",
    "    # sanity check: the head's neuron_indices must equal those saved with the checkpoint\n",
    "    if 'neuron_indices' in state_dict and hasattr(head, 'neuron_indices'):\n",
    "        assert torch.all(head.neuron_indices == state_dict['neuron_indices'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "77d54f94",
   "metadata": {},
   "outputs": [],
   "source": [
    "## Source: https://github.com/oliviaguest/gini/blob/master/gini.py\n",
    "def gini(array):\n",
    "    \"\"\"Calculate the Gini coefficient of an array-like of values.\n",
    "\n",
    "    The input is flattened and copied to float64, so integer arrays and plain\n",
    "    lists are accepted and the caller's data is never modified. Negative values\n",
    "    are shifted so the minimum becomes 0, and a tiny epsilon keeps every value\n",
    "    strictly positive. NaNs are not filtered out and propagate into the result.\n",
    "    \"\"\"\n",
    "    # based on bottom eq:\n",
    "    # http://www.statsdirect.com/help/generatedimages/equations/equation154.svg\n",
    "    # from:\n",
    "    # http://www.statsdirect.com/help/default.htm#nonparametric_methods/gini.htm\n",
    "    # All values are treated equally, arrays must be 1d. Copy as float64: the\n",
    "    # in-place `+=` below raises on integer dtypes otherwise.\n",
    "    array = np.asarray(array, dtype=np.float64).flatten()\n",
    "    if np.amin(array) < 0:\n",
    "        # Values cannot be negative:\n",
    "        array -= np.amin(array)\n",
    "    # Values cannot be 0:\n",
    "    array += 0.0000001\n",
    "    # Values must be sorted:\n",
    "    array = np.sort(array)\n",
    "    # Index per array element:\n",
    "    index = np.arange(1, array.shape[0] + 1)\n",
    "    # Number of array elements:\n",
    "    n = array.shape[0]\n",
    "    # Gini coefficient:\n",
    "    return np.sum((2 * index - n - 1) * array) / (n * np.sum(array))\n",
    "\n",
    "def maxmin(array):\n",
    "    \"\"\"Range of the values: largest minus smallest (NaN if any value is NaN).\"\"\"\n",
    "    lowest, highest = np.min(array), np.max(array)\n",
    "    return highest - lowest\n",
    "\n",
    "def tv(array):\n",
    "    \"\"\"Coefficient of variation: population std (ddof=0) divided by the mean.\"\"\"\n",
    "    spread = np.var(array) ** 0.5\n",
    "    centre = np.mean(array)\n",
    "    return spread / centre"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "3b309f3c",
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_predictions(model_path):\n",
    "    \"\"\"Return the cached predictions saved next to a checkpoint, or None if absent.\n",
    "\n",
    "    The cache file is the checkpoint path cut at its first '.ckpt' plus '.pred'.\n",
    "    torch.load unpickles the file, so only use it on locally produced caches.\n",
    "    \"\"\"\n",
    "    stem = model_path.split(\".ckpt\")[0]\n",
    "    preds_path = f'{stem}.pred'\n",
    "    if not os.path.exists(preds_path):\n",
    "        return None\n",
    "    return torch.load(preds_path)\n",
    "\n",
    "def calc_accs(m1, dm, checkpoints, frac, finetuning_dataset):\n",
    "    \"\"\"Evaluate linear-probe checkpoints and collect overall and per-class accuracy.\n",
    "\n",
    "    For each checkpoint, cached predictions (see ``load_predictions``) are used when\n",
    "    present; otherwise the probe is rebuilt on ``m1.model`` and run over\n",
    "    ``dm.test_dataloader()``.\n",
    "\n",
    "    Args:\n",
    "        m1: Lightning wrapper whose ``.model`` is the pretrained backbone.\n",
    "        dm: data module providing ``test_dataloader()``.\n",
    "        checkpoints: iterable of checkpoint paths; paths without cached predictions\n",
    "            must contain '-seed-<int>-'.\n",
    "        frac: fraction passed to ``ft.setup_model_for_finetuning``.\n",
    "        finetuning_dataset: key into ``dsmd.DATASET_PARAMS``.\n",
    "\n",
    "    Returns:\n",
    "        (class_to_acc, overall): ``class_to_acc`` maps an int class index to a list\n",
    "        of float accuracies (one per checkpoint in which the class occurs);\n",
    "        ``overall`` holds one float accuracy per checkpoint.\n",
    "\n",
    "    NOTE(review): reads the globals FINETUNE_MODE, SEED, DEVICES and NUM_NODES;\n",
    "    only FINETUNE_MODE is assigned anywhere in this notebook.\n",
    "    \"\"\"\n",
    "    class_to_acc, overall = {}, []\n",
    "    base_model = None\n",
    "    for model_path in checkpoints:\n",
    "        loaded_preds = load_predictions(model_path)\n",
    "        if loaded_preds is not None:\n",
    "            pred, y = torch.argmax(loaded_preds['pred'], 1), loaded_preds['gt']\n",
    "        else:\n",
    "            seed = int(model_path.split('-seed-')[1].split('-')[0])\n",
    "            sd = torch.load(model_path)\n",
    "            # setup_model_for_finetuning runs with inplace=True, so start every\n",
    "            # checkpoint from the same pristine backbone instead of stacking edits.\n",
    "            if base_model is None:\n",
    "                base_model = copy.deepcopy(m1.model)\n",
    "            m1.model = copy.deepcopy(base_model)\n",
    "            ft.setup_model_for_finetuning(\n",
    "                m1.model,\n",
    "                dsmd.DATASET_PARAMS[finetuning_dataset]['num_classes'],\n",
    "                FINETUNE_MODE, frac, seed, inplace=True)\n",
    "            load_linear_weights(m1.model, sd)\n",
    "            pl_utils.seed.seed_everything(SEED, workers=True)\n",
    "            trainer = Trainer(accelerator='gpu',\n",
    "                              devices=DEVICES,\n",
    "                              num_nodes=NUM_NODES,\n",
    "                              log_every_n_steps=1,\n",
    "                              auto_select_gpus=True,\n",
    "                              deterministic=True,\n",
    "                              check_val_every_n_epoch=1,\n",
    "                              num_sanity_val_steps=0,\n",
    "                              callbacks=[LitProgressBar(['loss', 'running_test_acc'])])\n",
    "            # Named `predictions` so the imported `out` module is not shadowed.\n",
    "            # NOTE(review): treats predictions[0] as logits and predictions[1] as labels\n",
    "            # for the whole test set; confirm LightningWrapper's predict hooks return that.\n",
    "            predictions = trainer.predict(m1, dataloaders=[dm.test_dataloader()])\n",
    "            pred, y = torch.argmax(predictions[0], 1), predictions[1]\n",
    "\n",
    "        overall.append(torch.sum(pred == y).item() / len(y))\n",
    "        # torch.unique works on any device (y.numpy() needs a CPU tensor); .tolist()\n",
    "        # gives hashable Python ints to use as dict keys.\n",
    "        for c in torch.unique(y).tolist():\n",
    "            mask = y == c\n",
    "            class_acc = torch.sum(pred[mask] == y[mask]).item() / torch.sum(mask).item()\n",
    "            class_to_acc.setdefault(c, []).append(class_acc)\n",
    "    return class_to_acc, overall\n",
    "\n",
    "def _is_final_layer_checkpoint(path):\n",
    "    \"\"\"Return False for files whose name marks a layer-, pool- or full-feature probe.\"\"\"\n",
    "    file_name = path.split('/')[-1]\n",
    "    return not any(tag in file_name for tag in ('layer', 'pool', 'full-feature'))\n",
    "\n",
    "def get_classwise_errors(model, source_dataset, finetuning_dataset, checkpoint_path,\n",
    "                         append, FRACTIONS_OR_NUMBERS=None):\n",
    "    \"\"\"Per-class accuracy heatmap of final-layer linear probes for one fine-tuning dataset.\n",
    "\n",
    "    Args:\n",
    "        model: architecture name passed to ``arch.create_model``.\n",
    "        source_dataset: dataset the backbone was pretrained on.\n",
    "        finetuning_dataset: dataset whose test split the probes are evaluated on.\n",
    "        checkpoint_path: pretrained backbone checkpoint.\n",
    "        append: run-name suffix used in the checkpoint directory names.\n",
    "        FRACTIONS_OR_NUMBERS: floats are feature fractions; ints are neuron counts,\n",
    "            converted to fractions of the largest count. When None, fractions are\n",
    "            discovered from directory names under ./checkpoints.\n",
    "\n",
    "    Returns:\n",
    "        (classwise_heatmap, overall_accs, wiki_link): a\n",
    "        (len(class_names), len(FRACTIONS_OR_NUMBERS)) array of mean per-class accuracy\n",
    "        (NaN where a class has no result), a dict\n",
    "        {fraction_or_number: [overall accuracy per checkpoint]}, and the value of\n",
    "        plt_hp.get_wiki_link for the saved heatmap plot.\n",
    "\n",
    "    NOTE(review): reads the globals BASE_DIR, EVAL_BATCH_SIZE, SEED and FINETUNE_MODE;\n",
    "    only FINETUNE_MODE is assigned anywhere in this notebook.\n",
    "    \"\"\"\n",
    "    data_dir = DATA_PATH_IMAGENET if 'imagenet' in finetuning_dataset else DATA_PATH\n",
    "    dm = DATA_MODULES[finetuning_dataset](\n",
    "        data_dir=data_dir,\n",
    "        transform_train=dsmd.TRAIN_TRANSFORMS_TRANSFER_DEFAULT(224),\n",
    "        transform_test=dsmd.TEST_TRANSFORMS_DEFAULT(224),\n",
    "        batch_size=EVAL_BATCH_SIZE)\n",
    "    dm.init_remaining_attrs(source_dataset)\n",
    "\n",
    "    ## assign mean and std from source dataset\n",
    "    m1 = arch.create_model(model, source_dataset, pretrained=True,\n",
    "                           checkpoint_path=checkpoint_path, seed=SEED,\n",
    "                           num_classes=dsmd.DATASET_PARAMS[source_dataset]['num_classes'],\n",
    "                           callback=partial(LightningWrapper,\n",
    "                                            dataset_name=source_dataset),\n",
    "                           loading_function_kwargs={'strict': False})\n",
    "    og_model = copy.deepcopy(m1.model)\n",
    "\n",
    "    if FRACTIONS_OR_NUMBERS is None:\n",
    "        # Run directories are named 'frac-<x>-...'; keep only 7-character values\n",
    "        # (the '0.xxxxx' form produced by the ':.5f' format used below).\n",
    "        # NOTE(review): discovery scans ./checkpoints and only bs-256 runs, while the\n",
    "        # evaluation below globs BASE_DIR with any batch size - confirm these agree.\n",
    "        run_files = glob.glob(f'./checkpoints/{model}-base-'\n",
    "                              f'{source_dataset}-ft-{finetuning_dataset}/'\n",
    "                              f'*-bs-256-{append}/*')\n",
    "        frac_tokens = [(path.split('/frac-')[1].split('-')[0], path) for path in run_files]\n",
    "        FRACTIONS_OR_NUMBERS = sorted({float(token) for token, path in frac_tokens\n",
    "                                       if len(token) == 7 and _is_final_layer_checkpoint(path)})\n",
    "\n",
    "    classwise_accs = {}  # frac -> {class_index: [accuracy per checkpoint]}\n",
    "    overall_accs = {}    # frac -> [overall accuracy per checkpoint]\n",
    "    for frac in FRACTIONS_OR_NUMBERS:\n",
    "        if isinstance(frac, (int, np.integer)):\n",
    "            # neuron count -> fraction of the largest count (not just the last entry,\n",
    "            # so an unsorted list is handled too)\n",
    "            actual_fraction = frac / max(FRACTIONS_OR_NUMBERS)\n",
    "        else:\n",
    "            actual_fraction = frac\n",
    "        checkpoint_paths = [path for path in glob.glob(\n",
    "            f'{BASE_DIR}/{model}-base-{source_dataset}-ft-{finetuning_dataset}/'\n",
    "            f'frac-{actual_fraction:.5f}-mode-{FINETUNE_MODE}-seed-*-'\n",
    "            f'ftmode-linear-lr-*-bs*-{append}/*-topk=1.ckpt')\n",
    "            if _is_final_layer_checkpoint(path)]\n",
    "        # calc_accs rebuilds the probe on m1.model; start each fraction from the backbone\n",
    "        m1.model = copy.deepcopy(og_model)\n",
    "        classwise, overall = calc_accs(m1, dm, checkpoint_paths, actual_fraction, finetuning_dataset)\n",
    "        classwise_accs[frac] = classwise\n",
    "        overall_accs[frac] = overall\n",
    "\n",
    "    class_names = get_classes_names(finetuning_dataset, data_dir)\n",
    "    classwise_heatmap = np.full((len(class_names), len(FRACTIONS_OR_NUMBERS)), np.nan)\n",
    "    for col, frac in enumerate(FRACTIONS_OR_NUMBERS):\n",
    "        for c, accs in classwise_accs[frac].items():\n",
    "            classwise_heatmap[c, col] = np.nanmean(accs)\n",
    "\n",
    "    return classwise_heatmap, overall_accs, plt_hp.get_wiki_link(plt_hp.plot_heatmaps(\n",
    "        [classwise_heatmap], x_labels=FRACTIONS_OR_NUMBERS,\n",
    "        y_labels=[class_names[i] for i in range(len(class_names))],\n",
    "        plot_title=finetuning_dataset, subplot_titles=None, subfolder=source_dataset,\n",
    "        filename=f'{model}-{finetuning_dataset}-{append}-{FINETUNE_MODE}', file_format='png', vmin=0, vmax=1,\n",
    "        show_fig=True, cols=None, x_title='Fraction/number of neurons', y_title='Classwise Accuracies (mean)',\n",
    "        annotate=True, types=None, paper_friendly_plots=False, annotation_fontsize=7,\n",
    "        root_dir='.', figsize=(10,10) if len(class_names) < 50 else (15,30),\n",
    "        results_subfolder_name='class_wise_errors'), SERVER_PROJECT_PATH, size=700)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "27acb660",
   "metadata": {},
   "outputs": [],
   "source": [
    "def results(model, source_dataset, FRACTIONS_OR_NUMBERS=None):\n",
    "    \"\"\"Run get_classwise_errors for every checkpoint variant of `model` and dataset.\n",
    "\n",
    "    Writes the concatenated wiki markup to\n",
    "    ./results/class_wise_errors/<source_dataset>/wiki_results-<model>-<FINETUNE_MODE>.txt\n",
    "    and uploads the .png files of that results subfolder.\n",
    "\n",
    "    Returns:\n",
    "        (model_to_heatmap, model_to_overall), both keyed by '<model>-<append>-<ft_ds>':\n",
    "        the per-class accuracy heatmap and the {fraction: [overall accuracies]} dict.\n",
    "\n",
    "    NOTE(review): reads the globals CHECKPOINT_PATHS, FINETUNING_DATASETS and\n",
    "    FINETUNE_MODE; only FINETUNE_MODE is assigned anywhere in this notebook.\n",
    "    \"\"\"\n",
    "    wiki_parts = ['== Classwise Errors on Final Layer ==\\n\\n']\n",
    "    model_to_heatmap, model_to_overall = {}, {}\n",
    "    for append, backbone_ckpt in CHECKPOINT_PATHS[model].items():\n",
    "        wiki_parts.append(f'=== {model} - {append} - {FINETUNE_MODE} ===\\n\\n')\n",
    "        for ft_ds in FINETUNING_DATASETS:\n",
    "            wiki_parts.append(f'==== {ft_ds} ====\\n\\n')\n",
    "            classwise_heatmap, overall_accs, error_plt = get_classwise_errors(\n",
    "                model, source_dataset, ft_ds, backbone_ckpt, append, FRACTIONS_OR_NUMBERS)\n",
    "            wiki_parts.append(f'{error_plt}\\n\\n')\n",
    "            result_key = f'{model}-{append}-{ft_ds}'\n",
    "            model_to_heatmap[result_key] = classwise_heatmap\n",
    "            model_to_overall[result_key] = overall_accs\n",
    "    wiki_dir = f'./results/class_wise_errors/{source_dataset}'\n",
    "    # the folder may not exist yet (e.g. when nothing was plotted), so create it\n",
    "    os.makedirs(wiki_dir, exist_ok=True)\n",
    "    with open(f'{wiki_dir}/wiki_results-{model}-{FINETUNE_MODE}.txt', 'w') as fp:\n",
    "        fp.write(''.join(wiki_parts))\n",
    "    out.upload_results(['{}/{}/{}'.format(plt_hp.RESULTS_FOLDER_NAME, 'class_wise_errors', source_dataset)],\n",
    "        'results', SERVER_PROJECT_PATH, '.png')\n",
    "    return model_to_heatmap, model_to_overall"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aa8df371",
   "metadata": {},
   "outputs": [],
   "source": [
    "# FINETUNE_MODE is read as a global by results() and the helpers it calls\n",
    "MODEL, SOURCE_DATASET, FINETUNE_MODE = 'resnet50', 'imagenet', 'random'\n",
    "resnet50_classwise_accs, resnet50_overall = results(MODEL, SOURCE_DATASET)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1761eb6d",
   "metadata": {},
   "outputs": [],
   "source": [
    "MODEL = 'vit_small_patch32_224'\n",
    "SOURCE_DATASET = 'imagenet21k'\n",
    "FINETUNE_MODE = 'random'\n",
    "# keep the results for later comparison instead of dumping the dicts as cell output\n",
    "vits32_classwise_accs, vits32_overall = results(MODEL, SOURCE_DATASET)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ea14d5f8",
   "metadata": {},
   "outputs": [],
   "source": [
    "MODEL = 'vit_small_patch16_224'\n",
    "SOURCE_DATASET = 'imagenet21k'\n",
    "FINETUNE_MODE = 'random'\n",
    "# keep the results for later comparison instead of dumping the dicts as cell output\n",
    "vits16_classwise_accs, vits16_overall = results(MODEL, SOURCE_DATASET)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5d9bd14f",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "MODEL, SOURCE_DATASET = 'resnet50_mrl', 'imagenet'\n",
    "# Neuron counts; ints are converted to fractions relative to the largest (last) entry.\n",
    "NUMBERS = [8, 16, 32, 64, 128, 256, 512, 1024, 2048]\n",
    "\n",
    "# FINETUNE_MODE is a global read by results(): evaluate the 'first' then 'random' probes\n",
    "FINETUNE_MODE = 'first'\n",
    "resnet50mrl_classwise_accs, resnet50mrl_overall = results(\n",
    "    MODEL, SOURCE_DATASET, FRACTIONS_OR_NUMBERS=NUMBERS)\n",
    "\n",
    "FINETUNE_MODE = 'random'\n",
    "(resnet50mrl_rand_classwise_accs,\n",
    " resnet50mrl_rand_overall) = results(MODEL, SOURCE_DATASET, FRACTIONS_OR_NUMBERS=NUMBERS)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 52,
   "id": "5d1ad396",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<module 'plot_helper' from '../plot_helper.py'>"
      ]
     },
     "execution_count": 52,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# dev convenience: pick up edits to ../plot_helper.py without restarting the kernel\n",
    "from importlib import reload\n",
    "reload(plt_hp)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0f979197",
   "metadata": {},
   "outputs": [],
   "source": [
    "models_to_compare = ['resnet50-nonrob', 'resnet50-robustl2eps3', 'resnet50_mrl-nonrob']\n",
    "# Both resnet50 variants are looked up in the same results dicts, whose keys are\n",
    "# '<model>-<append>-<ft_ds>'.\n",
    "resnet50_results = (resnet50_classwise_accs, resnet50_overall)\n",
    "model_results = [resnet50_results,\n",
    "                 resnet50_results,\n",
    "                 (resnet50mrl_classwise_accs, resnet50mrl_overall)]\n",
    "# optional extra entry: (resnet50mrl_rand_classwise_accs, resnet50mrl_rand_overall)\n",
    "# together with the legend suffix '-rand' in `appends`\n",
    "appends = ['', '', '-first']  # legend suffix per entry of models_to_compare\n",
    "model_str = 'resnet50_vs_mrl'\n",
    "\n",
    "# zoomed-inset settings, indexed by position in FINETUNING_DATASETS\n",
    "inset_vals = [\n",
    "    dict(xlim=(0.4, 0.8), loc='center right', loc1=2, loc2=4, zoom=1.),\n",
    "    dict(xlim=(0.15, 0.35), loc='lower left', loc1=2, loc2=4, zoom=1.25),\n",
    "    dict(xlim=(0.05, 0.2), loc='lower left', loc1=1, loc2=2, zoom=1.5),\n",
    "    dict(xlim=(0.7, 0.85), loc='center', loc1=1, loc2=3, zoom=1.5),\n",
    "]\n",
    "\n",
    "for ds_idx, ft_ds in enumerate(FINETUNING_DATASETS):\n",
    "    model_wise_x, model_wise_x_frac = [], []\n",
    "    model_wise_gini, model_wise_maxmin, model_wise_tv = [], [], []\n",
    "    legend_vals = []\n",
    "    for model, append, (model_results_classwise, model_results_overall) in zip(\n",
    "            models_to_compare, appends, model_results):\n",
    "        result_key = f'{model}-{ft_ds}'\n",
    "        if result_key not in model_results_overall:\n",
    "            continue\n",
    "        heatmap = model_results_classwise[result_key]\n",
    "        x_vals, x_vals_frac = [], []\n",
    "        ginis, maxmins, tvs = [], [], []\n",
    "        # Walk the keys in insertion order: get_classwise_errors fills heatmap column i\n",
    "        # from the i-th fraction/number it evaluated, in the same order it inserts them\n",
    "        # into the overall-accuracy dict. Sorting the keys would misalign accuracies\n",
    "        # and heatmap columns whenever that order is not ascending.\n",
    "        for idx, (frac_key, accs) in enumerate(model_results_overall[result_key].items()):\n",
    "            column = heatmap[:, idx]\n",
    "            x_vals.append(np.nanmean(accs))\n",
    "            # int keys are neuron counts; 2048 is the largest entry of NUMBERS\n",
    "            x_vals_frac.append(frac_key if isinstance(frac_key, float) else frac_key / 2048.)\n",
    "            ginis.append(gini(column))\n",
    "            maxmins.append(maxmin(column))\n",
    "            tvs.append(tv(column))\n",
    "        model_wise_x.append(x_vals)\n",
    "        model_wise_x_frac.append(x_vals_frac)\n",
    "        model_wise_gini.append(ginis)\n",
    "        model_wise_tv.append(tvs)\n",
    "        model_wise_maxmin.append(maxmins)\n",
    "        legend_vals.append(f'{model}{append}')\n",
    "    print(ft_ds)\n",
    "    plt_hp.line_plot(\n",
    "        model_wise_gini,\n",
    "        'Accuracy', 'Gini', '',\n",
    "        subfolder=SOURCE_DATASET, filename=f'gini_vs_acc_{model_str}-{ft_ds}',\n",
    "        extension='png', x_vals=model_wise_x, legend_vals=legend_vals, vertical_line=None,\n",
    "        linestyles=['']*len(model_wise_gini),\n",
    "        y_lims=(0.,1.02), root_dir='.', paper_friendly_plots=True,\n",
    "        plot_inside=True, legend_location='best', savefig=True, figsize=(10,6),\n",
    "        marker=['*', '+', '^', 'o'], results_subfolder_name='class_wise_errors',\n",
    "        grid_spacing=None, y_err=None, legend_ncol=None, inset=inset_vals[ds_idx])\n",
    "    \n",
    "#     plt_hp.line_plot(\n",
    "#         model_wise_tv, \n",
    "#         'Accuracy', 'Coeff. of Variation', '', \n",
    "#         subfolder=SOURCE_DATASET, filename=f'tv_vs_acc_{model_str}-{ft_ds}', \n",
    "#         extension='png', x_vals=model_wise_x, legend_vals=legend_vals, vertical_line=None, \n",
    "#         linestyles=['']*len(model_wise_tv), \n",
    "#         y_lims=(0.,1.02), root_dir='.', paper_friendly_plots=True, \n",
    "#         plot_inside=True, legend_location='best', savefig=True, figsize=(10,6), \n",
    "#         marker=['*', '+', '^', 'o'], results_subfolder_name='class_wise_errors', \n",
    "#         grid_spacing=None, y_err=None, legend_ncol=None)\n",
    "    \n",
    "#     plt_hp.line_plot(\n",
    "#         model_wise_tv, \n",
    "#         'Accuracy', 'Coeff. of Variation', '', \n",
    "#         subfolder=SOURCE_DATASET, filename=f'tv_vs_frac_{model_str}-{ft_ds}', \n",
    "#         extension='png', x_vals=model_wise_x_frac, legend_vals=legend_vals, vertical_line=None, \n",
    "#         linestyles=['']*len(model_wise_tv), \n",
    "#         y_lims=(0.,1.02), root_dir='.', paper_friendly_plots=True, \n",
    "#         plot_inside=True, legend_location='best', savefig=True, figsize=(10,6), \n",
    "#         marker=['*', '+', '^', 'o'], results_subfolder_name='class_wise_errors', \n",
    "#         grid_spacing=None, y_err=None, legend_ncol=None)\n",
    "\n",
    "\n",
    "#     plt_hp.line_plot(\n",
    "#         model_wise_gini, \n",
    "#         'Fraction', 'Gini', '', \n",
    "#         subfolder=SOURCE_DATASET, filename=f'gini_vs_frac_{model_str}-{ft_ds}', \n",
    "#         extension='png', x_vals=model_wise_x_frac, legend_vals=legend_vals, vertical_line=None, \n",
    "#         linestyles=['']*len(model_wise_gini), \n",
    "#         y_lims=(0.,1.02), root_dir='.', paper_friendly_plots=True, \n",
    "#         plot_inside=True, legend_location='best', savefig=True, figsize=(10,6), \n",
    "#         marker=['*', '+', '^', 'o'], results_subfolder_name='class_wise_errors', \n",
    "#         grid_spacing=None, y_err=None, legend_ncol=None)\n",
    "\n",
    "#     plt_hp.line_plot(\n",
    "#         model_wise_maxmin, \n",
    "#         'Accuracy', 'Maxmin', '', \n",
    "#         subfolder=SOURCE_DATASET, filename=f'maxmin_vs_acc_{model_str}-{ft_ds}', \n",
    "#         extension='png', x_vals=model_wise_x, legend_vals=legend_vals, vertical_line=None, \n",
    "#         linestyles=['']*len(model_wise_maxmin), \n",
    "#         y_lims=(0.,1.02), root_dir='.', paper_friendly_plots=True, \n",
    "#         plot_inside=True, legend_location='best', savefig=True, figsize=(10,6), \n",
    "#         marker=['*', '+', '^', 'o'], results_subfolder_name='class_wise_errors', \n",
    "#         grid_spacing=None, y_err=None, legend_ncol=None)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "00e70984",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
