{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "6df5c715",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "os.environ['CUDA_VISIBLE_DEVICES'] = '4'\n",
    "\n",
    "from pytorch_lightning import utilities as pl_utils\n",
    "from pytorch_lightning.trainer.trainer import Trainer\n",
    "from pytorch_lightning.plugins import DDPPlugin\n",
    "import torch\n",
    "import pathlib\n",
    "from functools import partial\n",
    "import sys, glob\n",
    "import numpy as np\n",
    "\n",
    "sys.path.append('..')\n",
    "sys.path.append('../deep-learning-base')\n",
    "sys.path.append('../deep-learning-base/training')\n",
    "sys.path.append('../deep-learning-base/datasets')\n",
    "sys.path.append('../partially_inverted_reps')\n",
    "\n",
    "import plot_helper as plt_hp\n",
    "import output as out\n",
    "from training import LitProgressBar, NicerModelCheckpointing\n",
    "import training.finetuning as ft\n",
    "import architectures as arch\n",
    "from architectures.callbacks import LightningWrapper, LinearEvalWrapper\n",
    "from attack.callbacks import AdvCallback\n",
    "from data_modules import DATA_MODULES\n",
    "import dataset_metadata as dsmd\n",
    "from partial_loss import PartialInversionLoss, PartialInversionRegularizedLoss\n",
    "from __init__ import DATA_PATH_IMAGENET, DATA_PATH, SERVER_PROJECT_PATH"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "11882de5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Seed handed to arch.create_model when the source model is rebuilt in get_test_acc.\n",
    "SEED = 2\n",
    "# Hardware layout given to the Lightning Trainer used for evaluation.\n",
    "NUM_NODES = 1\n",
    "DEVICES = 1\n",
    "# Absolute path of ./checkpoints. Path('.').parent is Path('.') itself, so\n",
    "# resolving '.' directly yields the same directory.\n",
    "BASE_DIR = f\"{pathlib.Path('.').resolve()}/checkpoints\"\n",
    "\n",
    "FINETUNE_BS = 256  # not referenced by the other cells visible in this notebook\n",
    "EVAL_BATCH_SIZE = 100  # batch size of the data module built in get_test_acc"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "47a7f7a7",
   "metadata": {},
   "outputs": [],
   "source": [
    "def save_predictions(model_path, out):\n",
    "    '''\n",
    "    Store predictions and labels in a .pred file next to a checkpoint.\n",
    "\n",
    "    model_path: checkpoint path; everything from the first .ckpt onwards is replaced by .pred\n",
    "    out: indexable object; out[0] is saved under 'pred' and out[2] under 'gt',\n",
    "         both detached and moved to the CPU first\n",
    "    '''\n",
    "    stem = model_path.split(\".ckpt\")[0]\n",
    "    payload = {\n",
    "        'pred': out[0].detach().cpu(),\n",
    "        'gt': out[2].detach().cpu(),\n",
    "    }\n",
    "    torch.save(payload, f'{stem}.pred')\n",
    "\n",
    "def load_predictions(model_path):\n",
    "    '''\n",
    "    Read back the .pred file that belongs to a checkpoint.\n",
    "\n",
    "    Returns whatever torch.load yields for that file, or None when the file\n",
    "    does not exist.\n",
    "    '''\n",
    "    preds_path = model_path.split(\".ckpt\")[0] + '.pred'\n",
    "    if not os.path.exists(preds_path):\n",
    "        return None\n",
    "    return torch.load(preds_path)\n",
    "\n",
    "def accuracy(gt, pred):\n",
    "    '''\n",
    "    Share of entries whose argmax over dim 1 of pred equals gt.\n",
    "\n",
    "    gt: tensor of labels\n",
    "    pred: tensor of scores; the argmax is taken over dim 1\n",
    "    Returns a tensor (callers in this notebook read it with .item()).\n",
    "    '''\n",
    "    predicted_labels = pred.argmax(dim=1)\n",
    "    n_correct = (gt == predicted_labels).sum()\n",
    "    return n_correct / len(gt)\n",
    "\n",
    "def get_test_acc(model, source_dataset, finetuning_dataset, checkpoint_path, \n",
    "                 model_path, seed, fraction, finetune_mode):\n",
    "    '''\n",
    "    Evaluate one finetuned checkpoint on the test dataloader of finetuning_dataset.\n",
    "\n",
    "    Rebuilds the source model, attaches the finetuning layer, loads that layer's weights\n",
    "    from model_path, runs trainer.predict, stores the output with save_predictions and\n",
    "    returns accuracy(out[2], out[0]).\n",
    "\n",
    "    model: architecture name given to arch.create_model\n",
    "    source_dataset: dataset name used for create_model, LightningWrapper and\n",
    "        dm.init_remaining_attrs\n",
    "    finetuning_dataset: key into DATA_MODULES and dsmd.DATASET_PARAMS\n",
    "    checkpoint_path: forwarded to arch.create_model as checkpoint_path\n",
    "    model_path: finetuned checkpoint read with torch.load; needs a 'state_dict' entry\n",
    "    seed: forwarded to ft.setup_model_for_finetuning and to seed_everything\n",
    "    fraction, finetune_mode: forwarded to ft.setup_model_for_finetuning\n",
    "    '''\n",
    "    # NOTE(review): no map_location is passed, so torch.load restores tensors to the device\n",
    "    # they were saved from - confirm this works with the single GPU made visible above.\n",
    "    state_dict = torch.load(model_path)\n",
    "    dm = DATA_MODULES[finetuning_dataset](\n",
    "        data_dir=DATA_PATH_IMAGENET if 'imagenet' in finetuning_dataset else DATA_PATH,\n",
    "        transform_train=dsmd.TRAIN_TRANSFORMS_TRANSFER_DEFAULT(224),\n",
    "        transform_test=dsmd.TEST_TRANSFORMS_DEFAULT(224),\n",
    "        batch_size=EVAL_BATCH_SIZE)\n",
    "    dm.init_remaining_attrs(source_dataset)\n",
    "\n",
    "    ## assign mean and std from source dataset\n",
    "    # NOTE(review): create_model receives the notebook-level SEED, not the `seed`\n",
    "    # argument of this function - confirm that this is intended.\n",
    "    m1 = arch.create_model(model, source_dataset, pretrained=True,\n",
    "                           checkpoint_path=checkpoint_path, seed=SEED, \n",
    "                           num_classes=dsmd.DATASET_PARAMS[source_dataset]['num_classes'],\n",
    "                           callback=partial(LightningWrapper, \n",
    "                                            dataset_name=source_dataset,\n",
    "                                            inference_kwargs={'with_latent': True}),\n",
    "                           loading_function_kwargs={'strict': False}) # some final layers are strange\n",
    "    new_layer = ft.setup_model_for_finetuning(\n",
    "        m1.model, \n",
    "        dsmd.DATASET_PARAMS[finetuning_dataset]['num_classes'],\n",
    "        finetune_mode, fraction, seed)\n",
    "    # Every checkpoint key is reduced to its last dotted component before the strict\n",
    "    # load into new_layer.linear.\n",
    "    new_layer.linear.load_state_dict({k.split('.')[-1]:v \\\n",
    "                                  for k,v in state_dict['state_dict'].items()}, strict=True)\n",
    "    # When both sides carry neuron indices they must match exactly.\n",
    "    if hasattr(new_layer, 'neuron_indices') and 'neuron_indices' in state_dict:\n",
    "        assert torch.all(new_layer.neuron_indices == state_dict['neuron_indices'])\n",
    "    pl_utils.seed.seed_everything(seed, workers=True)\n",
    "\n",
    "    trainer = Trainer(accelerator='gpu', \n",
    "                      devices=DEVICES,\n",
    "                      num_nodes=NUM_NODES,\n",
    "                      log_every_n_steps=1,\n",
    "                      auto_select_gpus=True, \n",
    "                      deterministic=True,\n",
    "                      check_val_every_n_epoch=1,\n",
    "                      num_sanity_val_steps=0,\n",
    "                      callbacks=[\n",
    "                        LitProgressBar(['loss', \n",
    "                                        'running_test_acc'])])\n",
    "\n",
    "    # `out` is local from here on and hides the `output` module that the import cell\n",
    "    # binds to the same name.\n",
    "    out = trainer.predict(m1, dataloaders=[dm.test_dataloader()])\n",
    "    save_predictions(model_path, out)\n",
    "    # presumably out[0] holds the model outputs and out[2] the labels - verify against\n",
    "    # what LightningWrapper returns from predict\n",
    "    gt, pred = out[2], out[0]\n",
    "    return accuracy(gt, pred)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "6c14ddd0",
   "metadata": {},
   "outputs": [],
   "source": [
    "def results(models_to_compare, source_datasets, finetune_modes, ft_ds, FRACTIONS_OR_NUMBERS, analysis_type):\n",
    "    '''\n",
    "    Gather transfer accuracies of finetuned checkpoints, plot them and upload the plot.\n",
    "\n",
    "    models_to_compare: architecture names (first-level keys of CHECKPOINT_PATHS)\n",
    "    source_datasets: upstream dataset names (second-level keys of CHECKPOINT_PATHS)\n",
    "    finetune_modes: mode names that appear in the checkpoint folder names\n",
    "    ft_ds: downstream (finetuning) dataset name\n",
    "    FRACTIONS_OR_NUMBERS: one entry per model. None means the fractions are discovered\n",
    "        from the checkpoint folders; a list of floats is used as fractions directly; ints\n",
    "        are divided by the last entry of the list to obtain the fraction.\n",
    "        NOTE: None entries are overwritten in place with the discovered fractions, so a\n",
    "        list that is passed again reuses the fractions found by the first call.\n",
    "    analysis_type: subfolder name used for the plot and for the wiki text file\n",
    "\n",
    "    Returns {'<model>-<append>-<source_dataset>-<finetune_mode>': {fraction: [acc, ...]}}.\n",
    "    Raises ValueError when no checkpoint variant was iterated at all (the plot file name\n",
    "    is derived from the last variant).\n",
    "\n",
    "    NOTE(review): CHECKPOINT_PATHS is neither defined nor imported in the cells visible\n",
    "    in this notebook; it has to exist in the kernel before this function is called.\n",
    "    '''\n",
    "    def keep_checkpoint(path):\n",
    "        # checkpoint files whose name carries one of these tags are left out\n",
    "        fname = path.split('/')[-1]\n",
    "        return not any(tag in fname for tag in ('layer', 'pool', 'full-feature'))\n",
    "\n",
    "    def discover_fractions(model, source_dataset, append):\n",
    "        # Same root directory as the evaluation glob below (BASE_DIR).\n",
    "        found = set()\n",
    "        for path in glob.glob(f'{BASE_DIR}/{model}-base-'\n",
    "                              f'{source_dataset}-ft-{ft_ds}/'\n",
    "                              f'*-bs-256-{append}/*-topk=1.ckpt'):\n",
    "            frac_str = path.split('/frac-')[1].split('-')[0]\n",
    "            # folder names carry the fraction formatted with .5f, i.e. 7 characters\n",
    "            if len(frac_str) == 7 and keep_checkpoint(path):\n",
    "                found.add(float(frac_str))\n",
    "        return sorted(found)\n",
    "\n",
    "    plt_str = '== Finetuning results ==\\n\\n'\n",
    "    model_to_fracwise_accs = {}\n",
    "    model_keys = []\n",
    "    appends_seen = []  # the last one names the plot file\n",
    "    for idx, model in enumerate(models_to_compare):\n",
    "        for source_dataset in source_datasets:\n",
    "            for append, base_checkpoint in CHECKPOINT_PATHS[model][source_dataset].items():\n",
    "                appends_seen.append(append)\n",
    "                for finetune_mode in finetune_modes:\n",
    "                    model_key = f'{model}-{append}-{source_dataset}-{finetune_mode}'\n",
    "                    print (model_key)\n",
    "                    if FRACTIONS_OR_NUMBERS[idx] is None:\n",
    "                        # written back into the caller's list, see the docstring\n",
    "                        FRACTIONS_OR_NUMBERS[idx] = discover_fractions(model, source_dataset, append)\n",
    "                    partial_fractions = FRACTIONS_OR_NUMBERS[idx]\n",
    "\n",
    "                    frac_wise_test_accs = {}\n",
    "                    for frac in partial_fractions:\n",
    "                        # ints are taken relative to the last entry of the list\n",
    "                        if isinstance(frac, int):\n",
    "                            actual_fraction = frac / partial_fractions[-1]\n",
    "                        else:\n",
    "                            actual_fraction = frac\n",
    "                        model_paths = [x for x in glob.glob(\n",
    "                            f'{BASE_DIR}/{model}-base-{source_dataset}-ft-{ft_ds}/'\n",
    "                            f'frac-{actual_fraction:.5f}-mode-{finetune_mode}-seed-*-'\n",
    "                            f'ftmode-linear-lr-*-bs*-{append}/*-topk=1.ckpt')\n",
    "                                       if keep_checkpoint(x)]\n",
    "                        print (f'{frac}: {len(model_paths)}')\n",
    "                        for path in model_paths:\n",
    "                            # reuse cached predictions when a .pred file exists\n",
    "                            pickled_preds = load_predictions(path)\n",
    "                            if pickled_preds is not None:\n",
    "                                acc = accuracy(pickled_preds['gt'], pickled_preds['pred']).item()\n",
    "                            else:\n",
    "                                # the seed is parsed from the '-seed-<n>-' part of the path\n",
    "                                acc = get_test_acc(model, source_dataset, ft_ds,\n",
    "                                                   base_checkpoint, path,\n",
    "                                                   int(path.split('-seed-')[1].split('-')[0]),\n",
    "                                                   actual_fraction, finetune_mode).item()\n",
    "                            frac_wise_test_accs.setdefault(actual_fraction, []).append(acc)\n",
    "                    if len(frac_wise_test_accs) > 0:\n",
    "                        model_keys.append(model_key)\n",
    "                        model_to_fracwise_accs[model_key] = frac_wise_test_accs\n",
    "\n",
    "    if not appends_seen:\n",
    "        raise ValueError('CHECKPOINT_PATHS has no checkpoint variant for '\n",
    "                         f'models={models_to_compare}, source_datasets={source_datasets}')\n",
    "\n",
    "    x_vals_fracs = [sorted(model_to_fracwise_accs[mkey].keys()) for mkey in model_keys]\n",
    "    mean_accs = [[np.nanmean(model_to_fracwise_accs[mkey][f]) for f in fracs]\n",
    "                 for mkey, fracs in zip(model_keys, x_vals_fracs)]\n",
    "    std_accs = [[np.nanstd(model_to_fracwise_accs[mkey][f]) for f in fracs]\n",
    "                for mkey, fracs in zip(model_keys, x_vals_fracs)]\n",
    "    # legend shows model, append and finetune mode (the source dataset part is dropped)\n",
    "    legend_vals = []\n",
    "    for mkey in model_keys:\n",
    "        parts = mkey.split('-')\n",
    "        legend_vals.append(f'{parts[0]}-{parts[1]}-{parts[3]}')\n",
    "    joined_models = '-'.join(models_to_compare)\n",
    "\n",
    "    plt_str += '{}\\n\\n'.format(\n",
    "        plt_hp.get_wiki_link(plt_hp.line_plot(\n",
    "            mean_accs,\n",
    "            'Fraction of neurons', 'Transfer Accuracy', ft_ds,\n",
    "            subfolder=analysis_type,\n",
    "            filename=f'{joined_models}_{ft_ds}_{appends_seen[-1]}',\n",
    "            extension='png', x_vals=x_vals_fracs,\n",
    "            legend_vals=legend_vals,\n",
    "            vertical_line=None, colors=plt_hp.COLORS,\n",
    "            linestyles=['-'] * len(model_keys),\n",
    "            y_lims=(0.,1.), root_dir='.', paper_friendly_plots=True,\n",
    "            plot_inside=False, legend_location='best', savefig=True,\n",
    "            figsize=(10,6), marker=[True] * len(model_keys),\n",
    "            results_subfolder_name='archs_loss_datasets', grid_spacing=None,\n",
    "            y_err=std_accs,\n",
    "            legend_ncol=4),\n",
    "                             SERVER_PROJECT_PATH, size=700)\n",
    "    )\n",
    "\n",
    "    # make sure the target folder exists before the wiki text is written\n",
    "    results_dir = f'./results/archs_loss_datasets/{analysis_type}'\n",
    "    os.makedirs(results_dir, exist_ok=True)\n",
    "    joined_sources = '-'.join(source_datasets)\n",
    "    with open(f'{results_dir}/wiki_results-{joined_sources}-{ft_ds}.txt', 'w') as fp:\n",
    "        fp.write(plt_str)\n",
    "    out.upload_results(['{}/{}/{}'.format(plt_hp.RESULTS_FOLDER_NAME, 'archs_loss_datasets', analysis_type)],\n",
    "        'results', SERVER_PROJECT_PATH, '.png')\n",
    "\n",
    "    return model_to_fracwise_accs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "01168f7b",
   "metadata": {},
   "outputs": [],
   "source": [
    "def estimate_dr(fractions, corresponding_accs, delta):\n",
    "    '''\n",
    "    Find the fraction whose mean accuracy, divided by the mean accuracy of the largest\n",
    "    fraction, lies closest to delta and return one minus that fraction.\n",
    "\n",
    "    fractions: fractions in any order (pairs are sorted by fraction first)\n",
    "    corresponding_accs: one list of accuracies per fraction, e.g. over several random seeds\n",
    "    delta: target ratio with respect to the mean accuracy of the largest fraction\n",
    "    '''\n",
    "    pairs = sorted(zip(fractions, corresponding_accs), key=lambda pair: pair[0])\n",
    "    sorted_fracs, sorted_accs = zip(*pairs)\n",
    "    reference_acc = np.mean(sorted_accs[-1])\n",
    "    ratios = np.array([np.mean(accs) / reference_acc for accs in sorted_accs])\n",
    "    closest = np.argmin(np.abs(ratios - delta))\n",
    "    return 1 - sorted_fracs[closest]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "04f8b0cd",
   "metadata": {},
   "source": [
    "### Compare different architectures with both non-robust (nonrob) and robust (rob) losses"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7dded8bf",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# Architecture comparison with 'imagenet' as the only source dataset.\n",
    "SOURCE_DATASETS = ['imagenet']\n",
    "MODELS = [\n",
    "    'vgg16_bn',\n",
    "    'resnet18',\n",
    "    'resnet50',\n",
    "    'wide_resnet50_2',\n",
    "    # 'wide_resnet50_4',\n",
    "    'vit_small_patch16_224',\n",
    "    'vit_small_patch32_224',\n",
    "]\n",
    "FINETUNE_MODES = ['random']\n",
    "# None entries let results() discover the fractions from the checkpoint folders\n",
    "FRACTIONS_OR_NUMBERS = [None for _ in MODELS]\n",
    "ANALYSIS_TYPE = 'architectures'\n",
    "\n",
    "# downstream dataset -> return value of results(); read again by the Table 1 cell\n",
    "datasetwise_model_to_accs = {}\n",
    "for FT_DATASET in ('cifar10', 'cifar100', 'flowers', 'oxford-iiit-pets'):\n",
    "    datasetwise_model_to_accs[FT_DATASET] = results(\n",
    "        MODELS, SOURCE_DATASETS, FINETUNE_MODES,\n",
    "        FT_DATASET, FRACTIONS_OR_NUMBERS, ANALYSIS_TYPE)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f242e61b",
   "metadata": {},
   "source": [
    "### Table 1 measures"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4fbabbb6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# For every downstream dataset and model, print estimate_dr at delta = 0.9 for the\n",
    "# '<model>-nonrob-imagenet-random' entry collected in datasetwise_model_to_accs.\n",
    "for ft_ds in ('cifar10', 'cifar100', 'flowers', 'oxford-iiit-pets'):\n",
    "    for model in MODELS:\n",
    "        fracwise_accs = datasetwise_model_to_accs[ft_ds][f'{model}-nonrob-imagenet-random']\n",
    "        dr = estimate_dr(tuple(fracwise_accs.keys()), tuple(fracwise_accs.values()), 0.9)\n",
    "        print (f'{ft_ds}, {model}, {dr}')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9d149e46",
   "metadata": {},
   "source": [
    "### Compare different upstream training datasets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "76fd90ab",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Same two ViT models, compared across two source datasets.\n",
    "SOURCE_DATASETS = ['imagenet', 'imagenet21k']\n",
    "MODELS = ['vit_small_patch16_224', 'vit_small_patch32_224']\n",
    "FINETUNE_MODES = ['random']\n",
    "# None entries let results() discover the fractions from the checkpoint folders\n",
    "FRACTIONS_OR_NUMBERS = [None for _ in MODELS]\n",
    "ANALYSIS_TYPE = 'upstream_datasets'\n",
    "\n",
    "for FT_DATASET in ('cifar10', 'cifar100', 'flowers', 'oxford-iiit-pets'):\n",
    "    results(\n",
    "        MODELS, SOURCE_DATASETS, FINETUNE_MODES,\n",
    "        FT_DATASET, FRACTIONS_OR_NUMBERS, ANALYSIS_TYPE)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "75918528",
   "metadata": {},
   "source": [
    "### Compare ResNet50-MRL to ResNet50-nonrob to ResNet50-rob"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ad2d962e",
   "metadata": {},
   "outputs": [],
   "source": [
    "SOURCE_DATASETS = ['imagenet']\n",
    "MODELS = ['resnet50', 'resnet50_mrl']\n",
    "FINETUNE_MODES = ['random', 'first']\n",
    "# resnet50: fractions are discovered by results(); resnet50_mrl: integer entries,\n",
    "# which results() divides by the last one (2048) to obtain the fraction\n",
    "FRACTIONS_OR_NUMBERS = [\n",
    "    None,\n",
    "    [8, 16, 32, 64, 128, 256, 512, 1024, 2048],\n",
    "]\n",
    "ANALYSIS_TYPE = 'efficient_representations'\n",
    "\n",
    "for FT_DATASET in ('cifar10', 'cifar100', 'flowers', 'oxford-iiit-pets'):\n",
    "    results(\n",
    "        MODELS, SOURCE_DATASETS, FINETUNE_MODES,\n",
    "        FT_DATASET, FRACTIONS_OR_NUMBERS, ANALYSIS_TYPE)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8c062b60",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
