{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "6df5c715",
   "metadata": {},
   "outputs": [],
   "source": [
    "from pytorch_lightning import utilities as pl_utils\n",
    "from pytorch_lightning.trainer.trainer import Trainer\n",
    "from pytorch_lightning.plugins import DDPPlugin\n",
    "import torch\n",
    "import pathlib\n",
    "from functools import partial\n",
    "import sys, glob, os, copy\n",
    "import numpy as np\n",
    "\n",
    "sys.path.append('..')\n",
    "sys.path.append('../deep-learning-base')\n",
    "sys.path.append('../deep-learning-base/training')\n",
    "sys.path.append('../partially_inverted_reps')\n",
    "\n",
    "import plot_helper as plt_hp\n",
    "import output as out\n",
    "from training import LitProgressBar, NicerModelCheckpointing\n",
    "import training.finetuning as ft\n",
    "import architectures as arch\n",
    "from architectures.callbacks import LightningWrapper, LinearEvalWrapper\n",
    "from attack.callbacks import AdvCallback\n",
    "from datasets.data_modules import DATA_MODULES\n",
    "import datasets.dataset_metadata as dsmd\n",
    "from partial_loss import PartialInversionLoss, PartialInversionRegularizedLoss\n",
    "from __init__ import DATA_PATH_IMAGENET, DATA_PATH, SERVER_PROJECT_PATH"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "11882de5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# --- Run configuration ---\n",
    "SEED = 2         # seed passed to arch.create_model when rebuilding the backbone\n",
    "NUM_NODES = 1    # Trainer(num_nodes=...)\n",
    "DEVICES = 1      # Trainer(devices=...)\n",
    "# Finetuning checkpoints are looked up under <current working directory>/checkpoints.\n",
    "BASE_DIR = str(pathlib.Path('.').resolve()) + '/checkpoints'\n",
    "\n",
    "# NOTE(review): CHECKPOINT_PATHS (model -> {append: pretrained checkpoint path}) and\n",
    "# FINETUNE_MODE are read as globals by the cells below. FINETUNE_MODE is set in each\n",
    "# driver cell, but CHECKPOINT_PATHS is not defined anywhere in this notebook -- TODO define it here.\n",
    "FINETUNING_DATASETS = ['cifar10', 'cifar100', 'flowers', 'oxford-iiit-pets']\n",
    "FINETUNE_BS = 256       # batch size that appears in the plot filenames\n",
    "EVAL_BATCH_SIZE = 100   # test-time batch size for the data modules"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "47a7f7a7",
   "metadata": {},
   "outputs": [],
   "source": [
    "def save_predictions(model_path, out):\n",
    "    '''Cache predictions next to the checkpoint as <checkpoint stem>.pred.\n",
    "\n",
    "    out[0] is stored under 'pred' and out[2] under 'gt', both detached and moved to CPU.\n",
    "    The stem is everything before the first '.ckpt' in model_path.\n",
    "    '''\n",
    "    stem = model_path.partition('.ckpt')[0]\n",
    "    payload = {\n",
    "        'pred': out[0].detach().cpu(),\n",
    "        'gt': out[2].detach().cpu(),\n",
    "    }\n",
    "    torch.save(payload, f'{stem}.pred')\n",
    "\n",
    "def load_predictions(model_path):\n",
    "    '''Return the dict cached by save_predictions for this checkpoint, or None if absent.'''\n",
    "    cache_file = model_path.partition('.ckpt')[0] + '.pred'\n",
    "    if not os.path.exists(cache_file):\n",
    "        return None\n",
    "    return torch.load(cache_file)\n",
    "\n",
    "def accuracy(gt, pred):\n",
    "    '''Top-1 accuracy as a 0-d tensor: share of rows whose argmax over dim 1 of `pred` equals `gt`.'''\n",
    "    correct = gt == pred.argmax(dim=1)\n",
    "    return correct.sum() / len(gt)\n",
    "\n",
    "\n",
    "def get_test_acc(model, source_dataset, finetuning_dataset, checkpoint_path, model_path, seed, fraction,\n",
    "                 finetune_mode=None):\n",
    "    '''Evaluate one finetuned linear head on the test split of `finetuning_dataset`.\n",
    "\n",
    "    Rebuilds `model` pretrained on `source_dataset` from `checkpoint_path`, replaces its head via\n",
    "    ft.setup_model_for_finetuning, loads the head weights from the finetuning checkpoint\n",
    "    `model_path`, runs Trainer.predict on the test dataloader, caches the predictions next to\n",
    "    the checkpoint (save_predictions) and returns the accuracy as a 0-d tensor.\n",
    "\n",
    "    seed / fraction: forwarded to ft.setup_model_for_finetuning; seed also seeds everything\n",
    "        right before prediction.\n",
    "    finetune_mode: mode forwarded to ft.setup_model_for_finetuning; defaults to the\n",
    "        notebook-level FINETUNE_MODE global (kept for backward compatibility).\n",
    "    '''\n",
    "    if finetune_mode is None:\n",
    "        finetune_mode = FINETUNE_MODE\n",
    "    # map_location='cpu': load regardless of the device the checkpoint was saved from;\n",
    "    # load_state_dict below copies the tensors into the model's own parameters.\n",
    "    state_dict = torch.load(model_path, map_location='cpu')\n",
    "    dm = DATA_MODULES[finetuning_dataset](\n",
    "        data_dir=DATA_PATH_IMAGENET if 'imagenet' in finetuning_dataset else DATA_PATH,\n",
    "        transform_train=dsmd.TRAIN_TRANSFORMS_TRANSFER_DEFAULT(224),\n",
    "        transform_test=dsmd.TEST_TRANSFORMS_DEFAULT(224),\n",
    "        batch_size=EVAL_BATCH_SIZE)\n",
    "    dm.init_remaining_attrs(source_dataset)\n",
    "\n",
    "    ## assign mean and std from source dataset\n",
    "    # NOTE(review): the backbone is created with the global SEED, not `seed` -- confirm intended.\n",
    "    m1 = arch.create_model(model, source_dataset, pretrained=True,\n",
    "                           checkpoint_path=checkpoint_path, seed=SEED,\n",
    "                           num_classes=dsmd.DATASET_PARAMS[source_dataset]['num_classes'],\n",
    "                           callback=partial(LightningWrapper,\n",
    "                                            dataset_name=source_dataset,\n",
    "                                            inference_kwargs={'with_latent': True}),\n",
    "                           loading_function_kwargs={'strict': False})\n",
    "    new_layer = ft.setup_model_for_finetuning(\n",
    "        m1.model,\n",
    "        dsmd.DATASET_PARAMS[finetuning_dataset]['num_classes'],\n",
    "        finetune_mode, fraction, seed)\n",
    "    # The checkpoint's state_dict must hold exactly the parameters of the model's last\n",
    "    # module (strict=True), so keys are reduced to their final dotted component.\n",
    "    linear_layer = list(m1.model.named_modules())[-1][1]\n",
    "    linear_layer.load_state_dict({k.split('.')[-1]: v\n",
    "                                  for k, v in state_dict['state_dict'].items()}, strict=True)\n",
    "    # Sanity check: neurons selected now must match those stored with the checkpoint.\n",
    "    if hasattr(new_layer, 'neuron_indices') and 'neuron_indices' in state_dict:\n",
    "        assert torch.all(new_layer.neuron_indices == state_dict['neuron_indices'])\n",
    "    pl_utils.seed.seed_everything(seed, workers=True)\n",
    "\n",
    "    trainer = Trainer(accelerator='gpu',\n",
    "                      devices=DEVICES,\n",
    "                      num_nodes=NUM_NODES,\n",
    "                      log_every_n_steps=1,\n",
    "                      auto_select_gpus=True,\n",
    "                      deterministic=True,\n",
    "                      check_val_every_n_epoch=1,\n",
    "                      num_sanity_val_steps=0,\n",
    "                      callbacks=[\n",
    "                        LitProgressBar(['loss',\n",
    "                                        'running_test_acc'])])\n",
    "\n",
    "    # Named `predictions` (not `out`) to avoid shadowing the `output as out` module alias.\n",
    "    predictions = trainer.predict(m1, dataloaders=[dm.test_dataloader()])\n",
    "    save_predictions(model_path, predictions)\n",
    "    # Same layout save_predictions relies on: [0] scores (argmax'ed by accuracy), [2] labels.\n",
    "    gt, pred = predictions[2], predictions[0]\n",
    "    return accuracy(gt, pred)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "6c14ddd0",
   "metadata": {},
   "outputs": [],
   "source": [
    "def _frac_tag(path):\n",
    "    '''Text between '/frac-' and the following '-' in a run path (e.g. 0.10000).'''\n",
    "    return path.split('/frac-')[1].split('-')[0]\n",
    "\n",
    "\n",
    "def _is_plain_frac_run(path):\n",
    "    '''True for standard partial-layer runs: a 7-character fraction tag and no\n",
    "    'full-feature' / 'layer' / 'pool' marker in the part of the path below BASE_DIR.'''\n",
    "    if '/frac-' not in path:\n",
    "        return False\n",
    "    # Check only below BASE_DIR so parent directory names cannot filter every run out.\n",
    "    rel = os.path.relpath(path, BASE_DIR)\n",
    "    return (len(_frac_tag(path)) == 7\n",
    "            and 'full-feature' not in rel\n",
    "            and 'layer' not in rel\n",
    "            and 'pool' not in rel)\n",
    "\n",
    "\n",
    "def _run_dir(model, source_dataset, ft_ds):\n",
    "    '''Directory holding the finetuning runs of `model` (pretrained on source_dataset) on ft_ds.'''\n",
    "    return f'{BASE_DIR}/{model}-base-{source_dataset}-ft-{ft_ds}'\n",
    "\n",
    "\n",
    "def _discover_fractions(model, source_dataset, ft_ds, append):\n",
    "    '''Sorted distinct fractions that have run directories on disk (any finetune mode).'''\n",
    "    pattern = f'{_run_dir(model, source_dataset, ft_ds)}/*-bs-{FINETUNE_BS}-{append}'\n",
    "    return sorted({float(_frac_tag(p)) for p in glob.glob(pattern) if _is_plain_frac_run(p)})\n",
    "\n",
    "\n",
    "def _discover_seeds(model, source_dataset, ft_ds, append, fraction):\n",
    "    '''Sorted distinct seeds appearing in the run directory names for `fraction`.'''\n",
    "    pattern = (f'{_run_dir(model, source_dataset, ft_ds)}/'\n",
    "               f'frac-{fraction:.5f}-*-bs-{FINETUNE_BS}-{append}')\n",
    "    return sorted({int(p.split('-seed-')[1].split('-')[0])\n",
    "                   for p in glob.glob(pattern) if _is_plain_frac_run(p)})\n",
    "\n",
    "\n",
    "def _find_checkpoint(model, source_dataset, ft_ds, append, fraction, seed):\n",
    "    '''Path of the '*-topk=1.ckpt' linear-probe checkpoint of one run in FINETUNE_MODE, or None.'''\n",
    "    pattern = (f'{_run_dir(model, source_dataset, ft_ds)}/'\n",
    "               f'frac-{fraction:.5f}-mode-{FINETUNE_MODE}-seed-{seed}-'\n",
    "               f'ftmode-linear-lr-*-bs*-{append}/'\n",
    "               '*-topk=1.ckpt')\n",
    "    # sorted() makes the choice deterministic when several learning rates match.\n",
    "    matches = sorted(p for p in glob.glob(pattern) if _is_plain_frac_run(p))\n",
    "    return matches[0] if matches else None\n",
    "\n",
    "\n",
    "def results(model, source_dataset, FRACTIONS_OR_NUMBERS=None):\n",
    "    '''Collect, print and plot transfer accuracies for every checkpoint variant of `model`.\n",
    "\n",
    "    For each key `append` of CHECKPOINT_PATHS[model] and each dataset in FINETUNING_DATASETS,\n",
    "    every (fraction, seed) run found under BASE_DIR is evaluated (cached .pred files are reused),\n",
    "    mean val/test accuracy per fraction is plotted against the full-layer accuracy, the wiki\n",
    "    markup is written to ./results/transfer_analysis/<source_dataset>/ and the plots uploaded.\n",
    "\n",
    "    FRACTIONS_OR_NUMBERS: fractions (floats) or neuron counts (ints); the last entry is the\n",
    "        full layer and ints are converted to fractions of it. When None, fractions are\n",
    "        discovered on disk separately for every (append, dataset) pair.\n",
    "\n",
    "    Returns {'<model>-<append>-<ft_ds>': {fraction_or_number: [test accuracy per seed]}}.\n",
    "    '''\n",
    "    plt_str = '== Finetuning results ==\\n\\n'\n",
    "    model_to_frac_wise_test = {}\n",
    "    for append in CHECKPOINT_PATHS[model].keys():\n",
    "        plt_str += f'=== {model} - {append} ===\\n\\n'\n",
    "        for ft_ds in FINETUNING_DATASETS:\n",
    "            plt_str += f'==== {ft_ds} ====\\n\\n'\n",
    "            # Use a local name: assigning the discovered list to the argument made every\n",
    "            # later dataset reuse the fractions discovered for the first one.\n",
    "            if FRACTIONS_OR_NUMBERS is None:\n",
    "                fractions = _discover_fractions(model, source_dataset, ft_ds, append)\n",
    "            else:\n",
    "                fractions = FRACTIONS_OR_NUMBERS\n",
    "            frac_wise_val_accs, frac_wise_test_accs = {}, {}\n",
    "            for frac in fractions:\n",
    "                actual_fraction = frac / fractions[-1] if isinstance(frac, int) else frac\n",
    "                for seed in _discover_seeds(model, source_dataset, ft_ds, append, actual_fraction):\n",
    "                    model_path = _find_checkpoint(model, source_dataset, ft_ds, append,\n",
    "                                                  actual_fraction, seed)\n",
    "                    if model_path is None:\n",
    "                        continue\n",
    "\n",
    "                    pickled_preds = load_predictions(model_path)\n",
    "                    if pickled_preds is not None:\n",
    "                        test_acc = accuracy(pickled_preds['gt'], pickled_preds['pred']).item()\n",
    "                    else:\n",
    "                        # float() so cached and freshly computed accuracies have the same type.\n",
    "                        test_acc = float(get_test_acc(model, source_dataset, ft_ds,\n",
    "                                                      CHECKPOINT_PATHS[model][append],\n",
    "                                                      model_path, seed, actual_fraction))\n",
    "                    sd = torch.load(model_path, map_location='cpu')\n",
    "                    # NOTE(review): assumes the first saved callback state is the checkpoint callback.\n",
    "                    val_acc = list(sd['callbacks'].values())[0]['best_model_score'].item()\n",
    "                    frac_wise_val_accs.setdefault(frac, []).append(val_acc)\n",
    "                    frac_wise_test_accs.setdefault(frac, []).append(test_acc)\n",
    "\n",
    "            model_to_frac_wise_test[f'{model}-{append}-{ft_ds}'] = copy.deepcopy(frac_wise_test_accs)\n",
    "\n",
    "            # Plotting needs the full-layer entry plus at least one partial fraction.\n",
    "            if (not fractions or fractions[-1] not in frac_wise_val_accs\n",
    "                    or len(frac_wise_val_accs) < 2):\n",
    "                print(f'Skipping {model} {append} {ft_ds}: not enough evaluated checkpoints')\n",
    "                continue\n",
    "            full_acc_val = frac_wise_val_accs.pop(fractions[-1])\n",
    "            full_acc_test = frac_wise_test_accs.pop(fractions[-1])\n",
    "            x_vals, y_vals = list(zip(*sorted(frac_wise_val_accs.items(), key=lambda t: t[0])))\n",
    "            _, y_tests = list(zip(*sorted(frac_wise_test_accs.items(), key=lambda t: t[0])))\n",
    "\n",
    "            print (model, append, ft_ds)\n",
    "            print (f'Full Acc: {full_acc_test}')\n",
    "            for x, y in zip(x_vals, y_tests):\n",
    "                print (f'For {x}, acc: {np.nanmean(y)} +/- {np.nanstd(y)}')\n",
    "            print ()\n",
    "\n",
    "            plt_str += '{}\\n\\n'.format(plt_hp.get_wiki_link(plt_hp.line_plot(\n",
    "                [[np.mean(_y) for _y in y_vals], [np.mean(_y) for _y in y_tests]], \n",
    "                'Fraction/number of neurons', 'Transfer Accuracy', ft_ds, \n",
    "                subfolder=source_dataset, filename=f'{model}_{ft_ds}_bs_{FINETUNE_BS}_{append}_{FINETUNE_MODE}', \n",
    "                extension='png', x_vals=x_vals, \n",
    "                legend_vals=['Val', 'Test', 'Full Layer (Val)', 'Full Layer (Test)'], \n",
    "                vertical_line=None, horizontal_lines=[np.mean(full_acc_val), np.mean(full_acc_test)], \n",
    "                horizontal_lines_err=[np.std(full_acc_val), np.std(full_acc_test)], \n",
    "                colors=[plt_hp.COLORS[0], plt_hp.COLORS[1], plt_hp.COLORS[0], plt_hp.COLORS[1]], \n",
    "                linestyles=['-', '-', ':', ':'], y_lims=(0.,1.), root_dir='.', \n",
    "                paper_friendly_plots=False, plot_inside=False, legend_location='best', \n",
    "                savefig=True, figsize=(10,6), marker=[True, True, False, False], \n",
    "                results_subfolder_name='transfer_analysis', grid_spacing=None, \n",
    "                y_err=[[np.std(_y) for _y in y_vals], [np.std(_y) for _y in y_tests]], \n",
    "                legend_ncol=None), SERVER_PROJECT_PATH, size=700))\n",
    "            print (ft_ds, frac_wise_val_accs)\n",
    "    results_dir = f'./results/transfer_analysis/{source_dataset}'\n",
    "    os.makedirs(results_dir, exist_ok=True)\n",
    "    with open(f'{results_dir}/wiki_results-{model}-{FINETUNE_MODE}.txt', 'w') as fp:\n",
    "        fp.write(plt_str)\n",
    "    out.upload_results(['{}/{}/{}'.format(plt_hp.RESULTS_FOLDER_NAME, 'transfer_analysis', source_dataset)], \n",
    "        'results', SERVER_PROJECT_PATH, '.png')\n",
    "    return model_to_frac_wise_test"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "7c085e80",
   "metadata": {},
   "outputs": [],
   "source": [
    "def estimate_dr(fractions, corresponding_accs, delta):\n",
    "    '''Estimate diffused redundancy (DR) at accuracy level `delta`.\n",
    "\n",
    "    fractions: fractions of neurons, or absolute neuron counts (sorted here); the largest\n",
    "        entry is treated as the full layer.\n",
    "    corresponding_accs: one list per fraction with its accuracies over random seeds.\n",
    "    delta: target share of the full-layer mean accuracy.\n",
    "\n",
    "    Returns 1 - f / f_full, where f is the fraction whose mean accuracy relative to the full\n",
    "    layer is closest to `delta` (ties -> smallest f). When the fractions end at 1.0 this is\n",
    "    simply 1 - f; normalising by f_full also makes absolute neuron counts work.\n",
    "    Raises ValueError on empty input.\n",
    "    '''\n",
    "    pairs = sorted(zip(fractions, corresponding_accs), key=lambda p: p[0])\n",
    "    if not pairs:\n",
    "        raise ValueError('estimate_dr needs at least one fraction')\n",
    "    sorted_fracs = [frac for frac, _ in pairs]\n",
    "    mean_accs = np.array([np.mean(accs) for _, accs in pairs])\n",
    "    ratios = mean_accs / mean_accs[-1]\n",
    "    closest = sorted_fracs[int(np.argmin(np.abs(ratios - delta)))]\n",
    "    return 1 - closest / sorted_fracs[-1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7dded8bf",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "MODEL = 'resnet50'\n",
    "SOURCE_DATASET = 'imagenet'\n",
    "# FINETUNE_MODE is read as a global by results() and get_test_acc().\n",
    "FINETUNE_MODE = 'random'\n",
    "resnet50_to_fracwise_accs = results(model=MODEL, source_dataset=SOURCE_DATASET)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7be992e7",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0d97a495",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Diffused redundancy (DR) curves for the resnet50 / imagenet / random results computed above.\n",
    "# Later cells overwrite MODEL, SOURCE_DATASET and FINETUNE_MODE (used in the filenames below),\n",
    "# so fail loudly instead of mislabelling the figures when cells run out of order.\n",
    "assert (MODEL, SOURCE_DATASET, FINETUNE_MODE) == ('resnet50', 'imagenet', 'random'), (\n",
    "    'Re-run the resnet50 results cell before this one')\n",
    "# np.unique removes the 0.8 that both ranges contain (the grid stays sorted).\n",
    "deltas = np.unique(np.concatenate((np.linspace(0., 0.8, 20), np.linspace(0.8, 1., 10))))\n",
    "for append in CHECKPOINT_PATHS[MODEL].keys():\n",
    "    all_x_vals, all_y_vals, all_full_accs = [], [], []\n",
    "    all_dr_vals = []\n",
    "    for ft_ds in FINETUNING_DATASETS:\n",
    "        # Sorted by fraction: the last entry is the full layer, the rest are partial fractions.\n",
    "        x, y = list(zip(*sorted(\n",
    "            resnet50_to_fracwise_accs[f'{MODEL}-{append}-{ft_ds}'].items(), key=lambda i: i[0])))\n",
    "        all_full_accs.append(y[-1])\n",
    "        all_x_vals.append(x[:-1])\n",
    "        all_y_vals.append(y[:-1])\n",
    "        all_dr_vals.append([estimate_dr(x, y, delta) for delta in deltas])\n",
    "\n",
    "    all_y_means = [[np.mean(_y) for _y in y_vals] for y_vals in all_y_vals]\n",
    "    assert all(len(means) == len(xs) for means, xs in zip(all_y_means, all_x_vals))\n",
    "    plt_hp.line_plot(\n",
    "        all_y_means,\n",
    "        'Fraction of Neurons (Total=2048)', 'Transfer Accuracy', '',\n",
    "        subfolder=SOURCE_DATASET, filename=f'{MODEL}_allds_bs_{FINETUNE_BS}_{append}_{FINETUNE_MODE}', \n",
    "        extension='png', x_vals=all_x_vals, legend_vals=FINETUNING_DATASETS, vertical_line=None, \n",
    "        horizontal_lines=[np.mean(full_acc) for full_acc in all_full_accs], \n",
    "        horizontal_lines_err=[np.std(full_acc) for full_acc in all_full_accs], \n",
    "        colors=plt_hp.COLORS[:len(FINETUNING_DATASETS)] * 2, \n",
    "        linestyles=['-']*len(FINETUNING_DATASETS) + [':']*len(FINETUNING_DATASETS), \n",
    "        y_lims=(0.,1.), root_dir='.', paper_friendly_plots=True, \n",
    "        plot_inside=True, legend_location='best', savefig=True, figsize=(10,6), \n",
    "        marker=[True] * len(FINETUNING_DATASETS) + [False] * len(FINETUNING_DATASETS), \n",
    "        results_subfolder_name='transfer_analysis', grid_spacing=None, \n",
    "        y_err=[[np.std(_y) for _y in y_vals] for y_vals in all_y_vals], legend_ncol=None)\n",
    "    plt_hp.line_plot(\n",
    "        all_dr_vals, \n",
    "        r'$\\delta$', 'Diffused Redundancy (DR)', '', \n",
    "        subfolder=SOURCE_DATASET, filename=f'{MODEL}_allds_{append}_{FINETUNE_MODE}_diffused_redundancy', \n",
    "        extension='png', x_vals=[deltas] * len(all_dr_vals), \n",
    "        legend_vals=FINETUNING_DATASETS, vertical_line=None, \n",
    "        colors=plt_hp.COLORS[:len(FINETUNING_DATASETS)], \n",
    "        linestyles=['-']*len(FINETUNING_DATASETS), \n",
    "        y_lims=(0.,1.), root_dir='.', paper_friendly_plots=True, \n",
    "        plot_inside=True, legend_location='best', savefig=True, figsize=(10,6), \n",
    "        marker=[True] * len(FINETUNING_DATASETS) + [False] * len(FINETUNING_DATASETS), \n",
    "        results_subfolder_name='transfer_analysis', grid_spacing=None, \n",
    "        y_err=None, legend_ncol=None)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ddc21f30",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fdaa1bc7",
   "metadata": {},
   "outputs": [],
   "source": [
    "SOURCE_DATASET = 'imagenet21k'\n",
    "MODEL = 'vit_small_patch16_224'\n",
    "FINETUNE_MODE = 'random'\n",
    "# Keep the per-fraction test accuracies (as for resnet50) instead of dumping the dict as cell output.\n",
    "vit_small_patch16_to_fracwise_accs = results(MODEL, SOURCE_DATASET)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "668c0cd8",
   "metadata": {},
   "outputs": [],
   "source": [
    "SOURCE_DATASET = 'imagenet21k'\n",
    "MODEL = 'vit_small_patch32_224'\n",
    "FINETUNE_MODE = 'random'\n",
    "# Keep the per-fraction test accuracies (as for resnet50) instead of dumping the dict as cell output.\n",
    "vit_small_patch32_to_fracwise_accs = results(MODEL, SOURCE_DATASET)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4900fa93",
   "metadata": {},
   "outputs": [],
   "source": [
    "SOURCE_DATASET = 'imagenet'\n",
    "MODEL = 'resnet50_mrl'\n",
    "FINETUNE_MODE = 'first'\n",
    "# Absolute neuron counts; results() treats the last entry (2048) as the full layer.\n",
    "NUMBERS = [8, 16, 32, 64, 128, 256, 512, 1024, 2048]\n",
    "resnet50_mrl_first_to_fracwise_accs = results(MODEL, SOURCE_DATASET, NUMBERS)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6bb965f6",
   "metadata": {},
   "outputs": [],
   "source": [
    "SOURCE_DATASET = 'imagenet'\n",
    "MODEL = 'resnet50_mrl'\n",
    "FINETUNE_MODE = 'random'\n",
    "# Absolute neuron counts; results() treats the last entry (2048) as the full layer.\n",
    "NUMBERS = [8, 16, 32, 64, 128, 256, 512, 1024, 2048]\n",
    "resnet50_mrl_random_to_fracwise_accs = results(MODEL, SOURCE_DATASET, NUMBERS)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b8c1f8c6",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
