{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "dc9f687e",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os, glob\n",
    "os.environ['CUDA_VISIBLE_DEVICES'] = '1'\n",
    "\n",
    "\n",
    "from pytorch_lightning import utilities as pl_utils\n",
    "from pytorch_lightning.trainer.trainer import Trainer\n",
    "from pytorch_lightning.plugins import DDPPlugin\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import pathlib\n",
    "import sys\n",
    "import numpy as np\n",
    "\n",
    "sys.path.append('..')\n",
    "sys.path.append('../deep-learning-base')\n",
    "sys.path.append('../deep-learning-base/training')\n",
    "sys.path.append('../deep-learning-base/datasets')\n",
    "sys.path.append('../partially_inverted_reps')\n",
    "\n",
    "import plot_helper as plt_hp\n",
    "import output as out\n",
    "import architectures as arch\n",
    "from architectures.callbacks import LightningWrapper\n",
    "from data_modules import DATA_MODULES\n",
    "import dataset_metadata as dsmd\n",
    "from partial_loss import PartialInversionLoss, PartialInversionRegularizedLoss\n",
    "from __init__ import DATA_PATH_IMAGENET, DATA_PATH, SERVER_PROJECT_PATH\n",
    "from functools import partial\n",
    "import stir.model.tools.helpers as helpers\n",
    "import stir\n",
    "import itertools"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "73d30c85",
   "metadata": {},
   "outputs": [],
   "source": [
    "class MaskedModel(nn.Module):\n",
    "    def __init__(self, model, mask):\n",
    "        super().__init__()\n",
    "        self.mask = mask\n",
    "        self.model = model\n",
    "    \n",
    "    def forward(self, x, *args, **kwargs):\n",
    "        out, latent = self.model(x, *args, **kwargs)\n",
    "        return out, latent[:,self.mask]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "1c84c9e2",
   "metadata": {},
   "outputs": [],
   "source": [
    "def find_chosen_neurons(m1, mode, seeds, numbers):\n",
    "    frac_to_chosen_neurons = {}\n",
    "    name, param = list(m1.model.named_modules())[-1]\n",
    "    in_fts = param.in_features\n",
    "    if mode == 'random':\n",
    "        for partial_seed in seeds:\n",
    "            for num_neurons in numbers:\n",
    "                linear = nn.Linear(num_neurons, dsmd.DATASET_PARAMS[SOURCE_DATASET]['num_classes'])\n",
    "                torch.manual_seed(partial_seed)\n",
    "                chosen_neurons = torch.randperm(in_fts)[:num_neurons]\n",
    "                if num_neurons in frac_to_chosen_neurons:\n",
    "                    frac_to_chosen_neurons[num_neurons].append(chosen_neurons)\n",
    "                else:\n",
    "                    frac_to_chosen_neurons[num_neurons] = [chosen_neurons]\n",
    "    elif mode == 'first':\n",
    "        for num_neurons in numbers:\n",
    "            chosen_neurons = torch.arange(in_fts)[:num_neurons]\n",
    "            if num_neurons in frac_to_chosen_neurons:\n",
    "                frac_to_chosen_neurons[num_neurons].append(chosen_neurons)\n",
    "            else:\n",
    "                frac_to_chosen_neurons[num_neurons] = [chosen_neurons]\n",
    "    return frac_to_chosen_neurons"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "id": "f35c2422",
   "metadata": {},
   "outputs": [],
   "source": [
    "def _load_latent_model(model_name, source_dataset, append):\n",
    "    \"\"\"Load the pretrained checkpoint ``CHECKPOINT_PATHS[model_name][append]``.\n",
    "\n",
    "    The model is wrapped in ``LightningWrapper`` with\n",
    "    ``inference_kwargs={'with_latent': True}``; ``MaskedModel`` expects the\n",
    "    resulting forward to return an ``(output, latent)`` pair.\n",
    "    \"\"\"\n",
    "    # MRL and 'ff' ResNet-50 checkpoints are loaded with strict=False.\n",
    "    non_strict = 'resnet50_mrl' in model_name or 'resnet50_ff' in model_name\n",
    "    return arch.create_model(\n",
    "        model_name, source_dataset, pretrained=True,\n",
    "        checkpoint_path=CHECKPOINT_PATHS[model_name][append], seed=SEED,\n",
    "        callback=partial(LightningWrapper,\n",
    "                         dataset_name=source_dataset,\n",
    "                         inference_kwargs={'with_latent': True}),\n",
    "        loading_function_kwargs={'strict': False} if non_strict else {})\n",
    "\n",
    "\n",
    "def _part_whole_ckas(m1, dm, frac_to_chosen_neurons, numbers):\n",
    "    \"\"\"Compare every partial-size mask with every full-layer mask via ``stir.STIR``.\n",
    "\n",
    "    ``numbers[-1]`` is the full-layer size. Returns a dict mapping each size in\n",
    "    ``numbers[:-1]`` to its list of ``stir_score.rsm`` values (computed with\n",
    "    ``cka_only=True``), one per (partial mask, full mask) pair.\n",
    "    \"\"\"\n",
    "    # Identical for every partial size, so look it up once.\n",
    "    full_masks = frac_to_chosen_neurons[numbers[-1]]\n",
    "    frac_to_ckas = {}\n",
    "    for num_neurons in numbers[:-1]:\n",
    "        # NOTE(review): in 'random' mode with numbers[-1] equal to the layer width,\n",
    "        # every full mask is a permutation of the same indices; if STIR's CKA ignores\n",
    "        # feature order these pairs repeat work -- confirm before trimming.\n",
    "        for mask1, mask2 in itertools.product(frac_to_chosen_neurons[num_neurons], full_masks):\n",
    "            stir_score = stir.STIR(\n",
    "                MaskedModel(m1, mask1), MaskedModel(m1, mask2),\n",
    "                helpers.InputNormalize(dsmd.STANDARD_MEAN, dsmd.STANDARD_STD),\n",
    "                helpers.InputNormalize(dsmd.STANDARD_MEAN, dsmd.STANDARD_STD),\n",
    "                # NOTE(review): the meaning of 1000 (sample or batch cap?) is defined by stir -- confirm.\n",
    "                (dm.test_dataloader(), 1000), verbose=False, layer1_num=None,\n",
    "                layer2_num=None, no_opt=True, cka_only=True)\n",
    "            frac_to_ckas.setdefault(num_neurons, []).append(stir_score.rsm)\n",
    "    return frac_to_ckas\n",
    "\n",
    "\n",
    "def _wiki_report(all_results, eval_datasets, source_dataset, model_name, appends, mode):\n",
    "    \"\"\"Save one CKA-vs-size line plot per (eval dataset, append); return wiki markup linking them.\"\"\"\n",
    "    plt_str = '== CKA Analysis between part and whole layer ==\\n\\n'\n",
    "    for eval_ds, append_to_frac_ckas in zip(eval_datasets, all_results):\n",
    "        plt_str += f'=== {eval_ds} ===\\n\\n'\n",
    "        for append in appends:\n",
    "            x_vals, y_vals = zip(*sorted(append_to_frac_ckas[append].items(),\n",
    "                                         key=lambda item: item[0]))\n",
    "            plot_result = plt_hp.line_plot(\n",
    "                [[np.mean(_y) for _y in y_vals]], 'Number of Neurons',\n",
    "                'CKA between fraction and full', f'Eval On {eval_ds}, mode = {mode}',\n",
    "                subfolder=source_dataset,\n",
    "                filename=f'part-whole-{model_name}-{append}-{eval_ds}-{mode}',\n",
    "                extension='png', x_vals=x_vals,\n",
    "                legend_vals=['CKA (fraction, full)'], vertical_line=None,\n",
    "                colors=None, linestyles=['-', '--'],\n",
    "                y_lims=(0., 1.1), root_dir='.', paper_friendly_plots=False,\n",
    "                plot_inside=False, legend_location='best', savefig=True, figsize=(10, 6),\n",
    "                marker=[True], results_subfolder_name='cka_analysis_part_whole',\n",
    "                grid_spacing=None, y_err=[[np.std(_y) for _y in y_vals]], legend_ncol=None)\n",
    "            plt_str += '== {} ==\\n\\n{}\\n\\n'.format(\n",
    "                append, plt_hp.get_wiki_link(plot_result, SERVER_PROJECT_PATH, size=700))\n",
    "    return plt_str\n",
    "\n",
    "\n",
    "def results(SOURCE_DATASET, MODEL, APPENDS, PARTIAL_CHOICE_SEEDS, NUMBERS, MODE,\n",
    "            eval_datasets=None):\n",
    "    \"\"\"Part-vs-whole CKA analysis of a model's latent layer, saved as plots plus a wiki page.\n",
    "\n",
    "    For every eval dataset and every checkpoint variant in ``APPENDS``, loads\n",
    "    ``CHECKPOINT_PATHS[MODEL][append]``, selects neuron subsets with\n",
    "    ``find_chosen_neurons`` and compares each size in ``NUMBERS[:-1]`` with the\n",
    "    full-layer size ``NUMBERS[-1]``. Plots are saved and uploaded; the wiki markup\n",
    "    is written to\n",
    "    ``./results/cka_analysis_part_whole/<SOURCE_DATASET>/wiki_results-<MODEL>-<MODE>.txt``.\n",
    "\n",
    "    Args:\n",
    "        SOURCE_DATASET: dataset the checkpoints were trained on.\n",
    "        MODEL: architecture name (a key of ``CHECKPOINT_PATHS``).\n",
    "        APPENDS: checkpoint variants (keys of ``CHECKPOINT_PATHS[MODEL]``).\n",
    "        PARTIAL_CHOICE_SEEDS: seeds used when ``MODE == 'random'``.\n",
    "        NUMBERS: subset sizes; the last entry is treated as the full layer.\n",
    "        MODE: neuron-selection mode passed to ``find_chosen_neurons``.\n",
    "        eval_datasets: datasets to evaluate on; defaults to the global ``EVAL_DATASETS``.\n",
    "\n",
    "    Returns:\n",
    "        list aligned with the eval datasets, each entry ``{append: {size: [cka, ...]}}``.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if ``NUMBERS`` has fewer than two entries.\n",
    "    \"\"\"\n",
    "    # NOTE(review): this function also reads the globals EVAL_DATASETS (by default),\n",
    "    # BATCH_SIZE, CHECKPOINT_PATHS and SEED, none of which is assigned in this\n",
    "    # notebook -- define them in a config cell so Restart & Run All works.\n",
    "    if len(NUMBERS) < 2:\n",
    "        raise ValueError('NUMBERS needs at least one partial size followed by the full-layer size')\n",
    "    if eval_datasets is None:\n",
    "        eval_datasets = EVAL_DATASETS\n",
    "\n",
    "    all_results = []  # not `results`, which would shadow this function\n",
    "    for eval_ds in eval_datasets:\n",
    "        print(eval_ds)\n",
    "        dm = DATA_MODULES[eval_ds](\n",
    "            data_dir=DATA_PATH_IMAGENET if 'imagenet' in eval_ds else DATA_PATH,\n",
    "            transform_train=dsmd.TEST_TRANSFORMS_DEFAULT(224),\n",
    "            transform_test=dsmd.TEST_TRANSFORMS_DEFAULT(224),\n",
    "            batch_size=BATCH_SIZE)\n",
    "        dm.init_remaining_attrs(eval_ds)\n",
    "        append_to_frac_ckas = {}\n",
    "        for append in APPENDS:\n",
    "            m1 = _load_latent_model(MODEL, SOURCE_DATASET, append)\n",
    "            frac_to_chosen_neurons = find_chosen_neurons(m1, MODE, PARTIAL_CHOICE_SEEDS, NUMBERS)\n",
    "            append_to_frac_ckas[append] = _part_whole_ckas(m1, dm, frac_to_chosen_neurons, NUMBERS)\n",
    "        all_results.append(append_to_frac_ckas)\n",
    "\n",
    "    plt_str = _wiki_report(all_results, eval_datasets, SOURCE_DATASET, MODEL, APPENDS, MODE)\n",
    "\n",
    "    report_dir = f'./results/cka_analysis_part_whole/{SOURCE_DATASET}'\n",
    "    # Don't rely on the plotting helper having created this folder.\n",
    "    os.makedirs(report_dir, exist_ok=True)\n",
    "    with open(f'{report_dir}/wiki_results-{MODEL}-{MODE}.txt', 'w') as fp:\n",
    "        fp.write(plt_str)\n",
    "\n",
    "    out.upload_results(['{}/{}/{}'.format(plt_hp.RESULTS_FOLDER_NAME,\n",
    "                                          'cka_analysis_part_whole', SOURCE_DATASET)],\n",
    "                       'results', SERVER_PROJECT_PATH, '.png')\n",
    "    return all_results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b508cc2a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ResNet-50 MRL: prefix ('first') and random neuron subsets, each compared with\n",
    "# the largest size in NUMBERS (the full layer).\n",
    "MODEL = 'resnet50_mrl'\n",
    "SOURCE_DATASET = 'imagenet'\n",
    "APPENDS = list(CHECKPOINT_PATHS[MODEL].keys())\n",
    "NUMBERS = [2 ** k for k in range(3, 12)]  # 8, 16, ..., 2048\n",
    "SEEDS = [1, 2, 3, 4, 5]\n",
    "\n",
    "# 'first' mode ignores seeds, so none are passed.\n",
    "resnet50_mrl_first = results(SOURCE_DATASET, MODEL, APPENDS, [], NUMBERS, 'first')\n",
    "resnet50_mrl_random = results(SOURCE_DATASET, MODEL, APPENDS, SEEDS, NUMBERS, 'random')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "40dfdabf",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Standard ResNet-50: random neuron subsets, by absolute count and by fraction\n",
    "# of the largest size in NUMBERS.\n",
    "MODEL = 'resnet50'\n",
    "SOURCE_DATASET = 'imagenet'\n",
    "APPENDS = list(CHECKPOINT_PATHS[MODEL].keys())\n",
    "NUMBERS = [2 ** k for k in range(3, 12)]  # 8, 16, ..., 2048\n",
    "FRACTIONS = [0.0005, 0.001, 0.005, 0.01, 0.05, 0.1, 0.2, 0.3, 0.4, 0.5, 0.8, 0.9, 1.]\n",
    "SEEDS = [1, 2, 3, 4, 5]\n",
    "\n",
    "resnet50_numbers = results(SOURCE_DATASET, MODEL, APPENDS, SEEDS, NUMBERS, 'random')\n",
    "# int() truncates (e.g. 0.0005 * 2048 -> 1 neuron); the final 1.0 gives the full layer.\n",
    "fraction_sizes = [int(frac * NUMBERS[-1]) for frac in FRACTIONS]\n",
    "resnet50_fracs = results(SOURCE_DATASET, MODEL, APPENDS, SEEDS, fraction_sizes, 'random')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "79cc9325",
   "metadata": {},
   "outputs": [],
   "source": [
    "MODELS = ['resnet50_ff1024', 'resnet50_ff512', 'resnet50_ff256', 'resnet50_ff128',\n",
    "          'resnet50_ff64', 'resnet50_ff32', 'resnet50_ff16', 'resnet50_ff8']\n",
    "# One list of subset sizes per model (same order as MODELS); the last size is the\n",
    "# full layer. Kept under its own name: other cells read NUMBERS as a flat list of\n",
    "# sizes, so it must not be overwritten with a list of lists.\n",
    "FF_NUMBERS = [\n",
    "    [16,32,48,64,128,256,512,700,1024],\n",
    "    [16,32,48,64,128,256,400,512],\n",
    "    [5,20,50,100,150,200,256],\n",
    "    [4,10,16,32,40,75,100,128],\n",
    "    [4,10,16,32,45,50,58,64],\n",
    "    [2,6,8,10,16,24,32],\n",
    "    [2,6,8,10,14,16],\n",
    "    [1,2,4,6,8]\n",
    "]\n",
    "assert len(FF_NUMBERS) == len(MODELS), 'need exactly one size list per model'\n",
    "SEEDS = list(range(1,6))\n",
    "APPENDS = ['nonrob']\n",
    "SOURCE_DATASET = 'imagenet'\n",
    "\n",
    "resnet50_width_results = []\n",
    "for model_name, nums in zip(MODELS, FF_NUMBERS):\n",
    "    resnet50_width_results.append(\n",
    "        results(SOURCE_DATASET, model_name, APPENDS, SEEDS, nums, 'random'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e0fe3729",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fraction sweep for the standard ResNet-50 (`resnet50_fracs`): one line per eval\n",
    "# dataset, skipping EVAL_DATASETS[0]; x is the subset size as a fraction of the layer.\n",
    "# Width, model name and output folder are spelled out instead of reading NUMBERS[-1],\n",
    "# MODEL and SOURCE_DATASET, because other cells may reassign those globals.\n",
    "RESNET50_TOTAL = 2048\n",
    "plot_datasets = EVAL_DATASETS[1:]\n",
    "for append in CHECKPOINT_PATHS['resnet50'].keys():\n",
    "    all_x_vals, all_y_vals = [], []\n",
    "    # results are aligned with EVAL_DATASETS, hence the index starting at 1\n",
    "    for ds_idx in range(1, len(EVAL_DATASETS)):\n",
    "        sizes, ckas = zip(*sorted(\n",
    "            resnet50_fracs[ds_idx][append].items(), key=lambda i: i[0]))\n",
    "        all_x_vals.append([size / RESNET50_TOTAL for size in sizes])\n",
    "        all_y_vals.append(ckas)\n",
    "    # One color/linestyle/marker per plotted line.\n",
    "    num_lines = len(all_y_vals)\n",
    "    plt_hp.line_plot(\n",
    "        [[np.mean(_y) for _y in y_vals] for y_vals in all_y_vals],\n",
    "        'Fraction of Neurons (Total=2048)', 'CKA', '',\n",
    "        subfolder='imagenet', filename=f'part-whole-resnet50-{append}-allds-random_test',\n",
    "        extension='png', x_vals=all_x_vals, legend_vals=plot_datasets, vertical_line=None,\n",
    "        colors=plt_hp.COLORS[:num_lines],\n",
    "        linestyles=['-'] * num_lines,\n",
    "        y_lims=(0., 1.02), root_dir='.', paper_friendly_plots=True,\n",
    "        plot_inside=True, legend_location='lower right', savefig=True, figsize=(10, 6),\n",
    "        marker=[True] * num_lines,\n",
    "        results_subfolder_name='cka_analysis_part_whole', grid_spacing=None,\n",
    "        y_err=[[np.std(_y) for _y in y_vals] for y_vals in all_y_vals], legend_ncol=None)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b8f0afe7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-dataset comparison of ResNet-50 (random subsets) with ResNet-50 MRL (prefix and\n",
    "# random subsets); x is the subset size as a fraction of each model's layer width.\n",
    "# NOTE(review): MODEL_TO_TOTAL is not defined anywhere in this notebook; the widths\n",
    "# below follow the 2048 in the axis label -- confirm for both models.\n",
    "model_to_total = {'resnet50': 2048, 'resnet50_mrl': 2048}\n",
    "comparisons = [\n",
    "    ('resnet50', 'random', resnet50_numbers),\n",
    "    ('resnet50_mrl', 'first', resnet50_mrl_first),\n",
    "    ('resnet50_mrl', 'random', resnet50_mrl_random),\n",
    "]\n",
    "for idx, ft_ds in enumerate(EVAL_DATASETS):\n",
    "    all_x_vals, all_y_vals, legend_vals = [], [], []\n",
    "    for model_name, model_info, model_res in comparisons:\n",
    "        for append in CHECKPOINT_PATHS[model_name].keys():\n",
    "            sizes, ckas = zip(*sorted(\n",
    "                model_res[idx][append].items(), key=lambda i: i[0]))\n",
    "            all_x_vals.append([size / model_to_total[model_name] for size in sizes])\n",
    "            all_y_vals.append(ckas)\n",
    "            legend_vals.append(f'{model_name}-{append}-{model_info}')\n",
    "    # One color/linestyle/marker per plotted line (not per eval dataset).\n",
    "    num_lines = len(all_y_vals)\n",
    "    plt_hp.line_plot(\n",
    "        [[np.mean(_y) for _y in y_vals] for y_vals in all_y_vals],\n",
    "        'Fraction of Neurons (Total=2048)', 'CKA', '',\n",
    "        # Named after what the figure shows, not the leftover `append` loop variable.\n",
    "        subfolder='imagenet', filename=f'part-whole-resnet50-vs-resnet50_mrl-{ft_ds}',\n",
    "        extension='png', x_vals=all_x_vals, legend_vals=legend_vals, vertical_line=None,\n",
    "        colors=plt_hp.COLORS[:num_lines],\n",
    "        linestyles=['-'] * num_lines,\n",
    "        y_lims=(0., 1.02), root_dir='.', paper_friendly_plots=True,\n",
    "        plot_inside=True, legend_location='best', savefig=True, figsize=(10, 6),\n",
    "        marker=[True] * num_lines,\n",
    "        results_subfolder_name='cka_analysis_part_whole', grid_spacing=None,\n",
    "        y_err=[[np.std(_y) for _y in y_vals] for y_vals in all_y_vals], legend_ncol=None)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e79d0c57",
   "metadata": {},
   "outputs": [],
   "source": [
    "SOURCE_DATASET = 'imagenet21k'\n",
    "MODEL = 'vit_small_patch16_224'\n",
    "APPENDS = list(CHECKPOINT_PATHS[MODEL].keys())\n",
    "# NOTE(review): find_chosen_neurons draws indices from the in_features of the\n",
    "# model's last module, and slicing caps larger sizes at that width -- confirm\n",
    "# that 764 (the full-layer entry) matches it.\n",
    "NUMBERS = [2,5,11,23,47,95,191,382,764]\n",
    "SEEDS = list(range(1,6))\n",
    "\n",
    "# Keep the results instead of echoing the large nested dict as cell output.\n",
    "vit_small_patch16_results = results(SOURCE_DATASET, MODEL, APPENDS, SEEDS, NUMBERS, 'random')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f1349577",
   "metadata": {},
   "outputs": [],
   "source": [
    "SOURCE_DATASET = 'imagenet21k'\n",
    "MODEL = 'vit_small_patch32_224'\n",
    "APPENDS = list(CHECKPOINT_PATHS[MODEL].keys())\n",
    "# NOTE(review): find_chosen_neurons draws indices from the in_features of the\n",
    "# model's last module, and slicing caps larger sizes at that width -- confirm\n",
    "# that 764 (the full-layer entry) matches it.\n",
    "NUMBERS = [2,5,11,23,47,95,191,382,764]\n",
    "SEEDS = list(range(1,6))\n",
    "\n",
    "# Keep the results instead of echoing the large nested dict as cell output.\n",
    "vit_small_patch32_results = results(SOURCE_DATASET, MODEL, APPENDS, SEEDS, NUMBERS, 'random')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c19f6bd1",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
