{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "6df5c715",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "os.environ['CUDA_VISIBLE_DEVICES'] = '3'\n",
    "\n",
    "from pytorch_lightning import utilities as pl_utils\n",
    "from pytorch_lightning.trainer.trainer import Trainer\n",
    "from pytorch_lightning.plugins import DDPPlugin\n",
    "import torch\n",
    "import pathlib\n",
    "from functools import partial\n",
    "import sys, glob\n",
    "import numpy as np\n",
    "\n",
    "sys.path.append('..')\n",
    "sys.path.append('../deep-learning-base')\n",
    "sys.path.append('../deep-learning-base/training')\n",
    "sys.path.append('../partially_inverted_reps')\n",
    "\n",
    "import plot_helper as plt_hp\n",
    "import output as out\n",
    "from training import LitProgressBar, NicerModelCheckpointing\n",
    "import training.finetuning as ft\n",
    "import architectures as arch\n",
    "from architectures.callbacks import LightningWrapper, LinearEvalWrapper\n",
    "from attack.callbacks import AdvCallback\n",
    "from datasets.data_modules import DATA_MODULES\n",
    "import datasets.dataset_metadata as dsmd\n",
    "from partial_loss import PartialInversionLoss, PartialInversionRegularizedLoss\n",
    "from __init__ import DATA_PATH_IMAGENET, DATA_PATH, SERVER_PROJECT_PATH"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "11882de5",
   "metadata": {},
   "outputs": [],
   "source": [
    "SEED = 2\n",
    "NUM_NODES = 1\n",
    "DEVICES = 1\n",
    "BASE_DIR = f\"{pathlib.Path('.').parent.resolve()}/checkpoints\"\n",
    "\n",
    "FINETUNE_MODE = 'random'\n",
    "FINETUNE_BS = 256\n",
    "EVAL_BATCH_SIZE = 100"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "47a7f7a7",
   "metadata": {},
   "outputs": [],
   "source": [
    "def save_predictions(model_path, out):\n",
    "    preds_path = f'{model_path.split(\".ckpt\")[0]}.pred'\n",
    "    torch.save({'pred': out[0].detach().cpu(), 'gt': out[2].detach().cpu()}, preds_path)\n",
    "\n",
    "def load_predictions(model_path):\n",
    "    preds_path = f'{model_path.split(\".ckpt\")[0]}.pred'\n",
    "    if os.path.exists(preds_path):\n",
    "        return torch.load(preds_path)\n",
    "\n",
    "def accuracy(gt, pred):\n",
    "    pred = torch.argmax(pred, 1)\n",
    "    return torch.sum(gt == pred) / len(gt)\n",
    "\n",
    "\n",
    "def get_test_acc(model, source_dataset, finetuning_dataset, checkpoint_path, model_path, seed, fraction):\n",
    "    \"\"\"Evaluate a finetuned linear head on the test split of ``finetuning_dataset``.\n",
    "\n",
    "    Rebuilds the base model from ``checkpoint_path``, prepares it with\n",
    "    ``ft.setup_model_for_finetuning``, loads the head weights stored in\n",
    "    ``model_path``, runs ``Trainer.predict`` on the test dataloader, caches the\n",
    "    predictions via ``save_predictions`` and returns the accuracy.\n",
    "\n",
    "    Args:\n",
    "        model: architecture name handed to ``arch.create_model``.\n",
    "        source_dataset: dataset of the base model; selects its number of classes\n",
    "            and is passed as ``dataset_name`` to the Lightning wrapper.\n",
    "        finetuning_dataset: key into ``DATA_MODULES`` and ``dsmd.DATASET_PARAMS``.\n",
    "        checkpoint_path: checkpoint of the base model.\n",
    "        model_path: checkpoint of the finetuned head; must contain 'state_dict'\n",
    "            and may contain 'neuron_indices'.\n",
    "        seed: forwarded to ``setup_model_for_finetuning`` and ``seed_everything``.\n",
    "        fraction: forwarded to ``setup_model_for_finetuning``.\n",
    "\n",
    "    Returns:\n",
    "        The tensor returned by ``accuracy`` for the collected predictions.\n",
    "    \"\"\"\n",
    "    # Deserialize on the CPU: without map_location torch.load restores tensors to the\n",
    "    # device they were saved from, which fails when that GPU index is not visible here.\n",
    "    state_dict = torch.load(model_path, map_location='cpu')\n",
    "    dm = DATA_MODULES[finetuning_dataset](\n",
    "        data_dir=DATA_PATH_IMAGENET if 'imagenet' in finetuning_dataset else DATA_PATH,\n",
    "        transform_train=dsmd.TRAIN_TRANSFORMS_TRANSFER_DEFAULT(224),\n",
    "        transform_test=dsmd.TEST_TRANSFORMS_DEFAULT(224),\n",
    "        batch_size=EVAL_BATCH_SIZE)\n",
    "    dm.init_remaining_attrs(source_dataset)\n",
    "\n",
    "    ## assign mean and std from source dataset\n",
    "    wrapper_factory = partial(LightningWrapper,\n",
    "                              dataset_name=source_dataset,\n",
    "                              inference_kwargs={'with_latent': True})\n",
    "    wrapped_model = arch.create_model(\n",
    "        model, source_dataset, pretrained=True,\n",
    "        checkpoint_path=checkpoint_path, seed=SEED,\n",
    "        num_classes=dsmd.DATASET_PARAMS[source_dataset]['num_classes'],\n",
    "        callback=wrapper_factory,\n",
    "        loading_function_kwargs={'strict': False})  # some final layers are strange\n",
    "    new_layer = ft.setup_model_for_finetuning(\n",
    "        wrapped_model.model,\n",
    "        dsmd.DATASET_PARAMS[finetuning_dataset]['num_classes'],\n",
    "        FINETUNE_MODE, fraction, seed)\n",
    "    print(new_layer.__dict__)\n",
    "\n",
    "    # The last registered module is treated as the finetuned head. Checkpoint keys are\n",
    "    # reduced to their last dotted component before loading them into that module.\n",
    "    linear_layer = list(wrapped_model.model.named_modules())[-1][1]\n",
    "    head_weights = {key.split('.')[-1]: value\n",
    "                    for key, value in state_dict['state_dict'].items()}\n",
    "    linear_layer.load_state_dict(head_weights, strict=True)\n",
    "\n",
    "    # When both sides record the selected neurons they have to agree.\n",
    "    if hasattr(new_layer, 'neuron_indices') and 'neuron_indices' in state_dict:\n",
    "        assert torch.all(new_layer.neuron_indices == state_dict['neuron_indices'])\n",
    "    pl_utils.seed.seed_everything(seed, workers=True)\n",
    "\n",
    "    trainer = Trainer(accelerator='gpu',\n",
    "                      devices=DEVICES,\n",
    "                      num_nodes=NUM_NODES,\n",
    "                      log_every_n_steps=1,\n",
    "                      auto_select_gpus=True,\n",
    "                      deterministic=True,\n",
    "                      check_val_every_n_epoch=1,\n",
    "                      num_sanity_val_steps=0,\n",
    "                      callbacks=[\n",
    "                        LitProgressBar(['loss',\n",
    "                                        'running_test_acc'])])\n",
    "\n",
    "    # Named 'predictions' so the imported 'out' module is not shadowed inside this function.\n",
    "    # Index 0 is used as the predictions and index 2 as the ground truth.\n",
    "    predictions = trainer.predict(wrapped_model, dataloaders=[dm.test_dataloader()])\n",
    "    save_predictions(model_path, predictions)\n",
    "    return accuracy(predictions[2], predictions[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "6c14ddd0",
   "metadata": {},
   "outputs": [],
   "source": [
    "def _is_plain_linear_ckpt(ckpt_path):\n",
    "    \"\"\"True when the checkpoint file name contains none of 'layer', 'pool', 'full-feature'.\"\"\"\n",
    "    file_name = ckpt_path.split('/')[-1]\n",
    "    return not any(tag in file_name for tag in ('layer', 'pool', 'full-feature'))\n",
    "\n",
    "\n",
    "def results(models_to_compare, source_dataset, ft_ds, append, FRACTIONS_OR_NUMBERS):\n",
    "    \"\"\"Collect test accuracies per model and fraction, plot them and write the wiki text.\n",
    "\n",
    "    For every model, each matching finetuned checkpoint is scored either from its\n",
    "    cached ``.pred`` file or, if that is missing, by running ``get_test_acc``.\n",
    "    Mean and standard deviation per fraction are plotted with ``plt_hp.line_plot``;\n",
    "    the wiki snippet is written to ./results/width_comparison/<source_dataset>/ and\n",
    "    the PNGs are uploaded with ``out.upload_results``.\n",
    "\n",
    "    Args:\n",
    "        models_to_compare: sequence of model names.\n",
    "        source_dataset: dataset name used in the checkpoint folder and results subfolder.\n",
    "        ft_ds: finetuning dataset name.\n",
    "        append: suffix of the run folders (also the key into CHECKPOINT_PATHS[model]).\n",
    "        FRACTIONS_OR_NUMBERS: one entry per model. None means the fractions are\n",
    "            discovered from the run folder names; otherwise a list whose float\n",
    "            entries are used as fractions and whose int entries are divided by the\n",
    "            list's last element to obtain the fraction.\n",
    "    \"\"\"\n",
    "    plt_str = '== Finetuning results ==\\n\\n'\n",
    "    model_to_fracwise_accs = {}\n",
    "    for idx, model in enumerate(models_to_compare):\n",
    "        plt_str += f'=== {model} - {append} - {ft_ds} ===\\n\\n'\n",
    "        frac_wise_test_accs = {}\n",
    "        # Same root for discovery and evaluation, so both look at the same folder.\n",
    "        run_root = f'{BASE_DIR}/{model}-base-{source_dataset}-ft-{ft_ds}'\n",
    "        if FRACTIONS_OR_NUMBERS[idx] is None:\n",
    "            fractions_found = set()\n",
    "            for ckpt in glob.glob(f'{run_root}/*-bs-{FINETUNE_BS}-{append}/*-topk=1.ckpt'):\n",
    "                run_dir = ckpt.split('/')[-2]\n",
    "                if not run_dir.startswith('frac-'):\n",
    "                    continue  # not a 'frac-...' run folder: skip it instead of failing\n",
    "                frac_token = run_dir[len('frac-'):].split('-')[0]\n",
    "                # Only 7-character tokens are kept, i.e. the width the '.5f' format\n",
    "                # used for the folder names below produces for fractions < 10.\n",
    "                if len(frac_token) == 7 and _is_plain_linear_ckpt(ckpt):\n",
    "                    fractions_found.add(float(frac_token))\n",
    "            PARTIAL_FRACTIONS = sorted(fractions_found)\n",
    "        else:\n",
    "            PARTIAL_FRACTIONS = FRACTIONS_OR_NUMBERS[idx]\n",
    "\n",
    "        for frac in PARTIAL_FRACTIONS:\n",
    "            # Integer entries are normalised by the last entry of the list.\n",
    "            actual_fraction = frac / PARTIAL_FRACTIONS[-1] if isinstance(frac, int) else frac\n",
    "            pattern = (f'{run_root}/'\n",
    "                       f'frac-{actual_fraction:.5f}-mode-{FINETUNE_MODE}-seed-*-'\n",
    "                       f'ftmode-linear-lr-*-bs*-{append}/'\n",
    "                       '*-topk=1.ckpt')\n",
    "            for path in filter(_is_plain_linear_ckpt, glob.glob(pattern)):\n",
    "                pickled_preds = load_predictions(path)\n",
    "                if pickled_preds is None:\n",
    "                    # NOTE(review): CHECKPOINT_PATHS is neither defined nor imported in this\n",
    "                    # notebook; it has to exist in the kernel before this branch runs.\n",
    "                    # TODO: define or import it in the configuration cell.\n",
    "                    run_seed = int(path.split('-seed-')[1].split('-')[0])\n",
    "                    acc = get_test_acc(model, source_dataset, ft_ds,\n",
    "                                       CHECKPOINT_PATHS[model][append], path,\n",
    "                                       run_seed, actual_fraction)\n",
    "                else:\n",
    "                    acc = accuracy(pickled_preds['gt'], pickled_preds['pred'])\n",
    "                # Store plain floats in both branches (accuracy returns a 0-dim tensor).\n",
    "                frac_wise_test_accs.setdefault(actual_fraction, []).append(float(acc))\n",
    "        model_to_fracwise_accs[model] = frac_wise_test_accs\n",
    "\n",
    "    # Sort the fractions once per model and reuse them for x values, means and errors.\n",
    "    x_vals = [sorted(model_to_fracwise_accs[m].keys()) for m in models_to_compare]\n",
    "    mean_accs = [[np.nanmean(model_to_fracwise_accs[m][f]) for f in fracs]\n",
    "                 for m, fracs in zip(models_to_compare, x_vals)]\n",
    "    std_accs = [[np.nanstd(model_to_fracwise_accs[m][f]) for f in fracs]\n",
    "                for m, fracs in zip(models_to_compare, x_vals)]\n",
    "    plot_path = plt_hp.line_plot(\n",
    "        mean_accs,\n",
    "        'Fraction of neurons', 'Transfer Accuracy', ft_ds,\n",
    "        subfolder=source_dataset,\n",
    "        filename=f'{\"-\".join(models_to_compare)}_{ft_ds}_bs_{FINETUNE_BS}_{append}',\n",
    "        extension='png', x_vals=x_vals,\n",
    "        legend_vals=list(models_to_compare),\n",
    "        vertical_line=None, colors=plt_hp.COLORS,\n",
    "        linestyles=['-'] * len(models_to_compare),\n",
    "        y_lims=(0.,1.), root_dir='.', paper_friendly_plots=True,\n",
    "        plot_inside=False, legend_location='best', savefig=True,\n",
    "        figsize=(10,6), marker=[True] * len(models_to_compare),\n",
    "        results_subfolder_name='width_comparison', grid_spacing=None,\n",
    "        y_err=std_accs,\n",
    "        legend_ncol=3)\n",
    "    plt_str += '{}\\n\\n'.format(plt_hp.get_wiki_link(plot_path, SERVER_PROJECT_PATH, size=700))\n",
    "\n",
    "    results_dir = f'./results/width_comparison/{source_dataset}'\n",
    "    os.makedirs(results_dir, exist_ok=True)\n",
    "    # NOTE(review): the file is named after the last model only and does not include\n",
    "    # ft_ds, so calls for different finetuning datasets overwrite each other.\n",
    "    # The path is kept as-is so existing consumers of this file keep working.\n",
    "    last_model = models_to_compare[-1]\n",
    "    with open(f'{results_dir}/wiki_results-{last_model}.txt', 'w') as fp:\n",
    "        fp.write(plt_str)\n",
    "    out.upload_results(['{}/{}/{}'.format(plt_hp.RESULTS_FOLDER_NAME, 'width_comparison', source_dataset)],\n",
    "        'results', SERVER_PROJECT_PATH, '.png')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7dded8bf",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "SOURCE_DATASET = 'imagenet'\n",
    "APPEND = 'nonrob'\n",
    "\n",
    "# MODELS and FRACTIONS_OR_NUMBERS are index-aligned: entry i of the second list\n",
    "# belongs to model i. None lets results() discover the fractions from the\n",
    "# checkpoint folders; integer lists are divided by their last entry inside results().\n",
    "MODELS = [\n",
    "    'resnet50',\n",
    "    'resnet50_ff1024',\n",
    "    'resnet50_ff512',\n",
    "    'resnet50_ff256',\n",
    "    'resnet50_ff128',\n",
    "    'resnet50_ff64',\n",
    "    'resnet50_ff32',\n",
    "    'resnet50_ff16',\n",
    "    'resnet50_ff8',\n",
    "]\n",
    "FRACTIONS_OR_NUMBERS = [\n",
    "    None,\n",
    "    [4, 8, 16, 32, 48, 64, 90, 128, 200, 256, 300, 400, 512, 600, 700, 900, 1024],\n",
    "    [8, 16, 32, 48, 64, 90, 128, 200, 256, 300, 350, 400, 450, 500, 512],\n",
    "    [5, 10, 20, 30, 50, 100, 128, 150, 200, 220, 240, 256],\n",
    "    [4, 8, 10, 16, 20, 32, 40, 60, 65, 75, 90, 100, 120, 128],\n",
    "    [4, 8, 10, 16, 20, 28, 32, 40, 45, 50, 58, 64],\n",
    "    [2, 4, 6, 8, 10, 12, 16, 20, 24, 28, 32],\n",
    "    [2, 4, 6, 8, 10, 12, 14, 16],\n",
    "    [1, 2, 4, 6, 8],\n",
    "]\n",
    "\n",
    "# Same report for every finetuning dataset, in this order.\n",
    "for FT_DATASET in ('cifar10', 'cifar100', 'flowers', 'oxford-iiit-pets'):\n",
    "    results(MODELS, SOURCE_DATASET, FT_DATASET, APPEND, FRACTIONS_OR_NUMBERS)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "76fd90ab",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
