{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e3e94090",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os, glob\n",
    "os.environ['CUDA_VISIBLE_DEVICES'] = '2'\n",
    "\n",
    "from pytorch_lightning import utilities as pl_utils\n",
    "from pytorch_lightning.trainer.trainer import Trainer\n",
    "from pytorch_lightning.plugins import DDPPlugin\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import pathlib\n",
    "import sys\n",
    "import numpy as np\n",
    "\n",
    "sys.path.append('..')\n",
    "sys.path.append('../deep-learning-base')\n",
    "sys.path.append('../deep-learning-base/training')\n",
    "sys.path.append('../partially_inverted_reps')\n",
    "\n",
    "import plot_helper as plt_hp\n",
    "import output as out\n",
    "import architectures as arch\n",
    "from architectures.callbacks import LightningWrapper\n",
    "from datasets.data_modules import DATA_MODULES\n",
    "import datasets.dataset_metadata as dsmd\n",
    "from partial_loss import PartialInversionLoss, PartialInversionRegularizedLoss\n",
    "from __init__ import DATA_PATH_IMAGENET, DATA_PATH, SERVER_PROJECT_PATH\n",
    "from functools import partial\n",
    "import stir.model.tools.helpers as helpers\n",
    "import stir\n",
    "import itertools"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "fdb551e1",
   "metadata": {},
   "outputs": [],
   "source": [
    "class MaskedModel(nn.Module):\n",
    "    def __init__(self, model, mask):\n",
    "        super().__init__()\n",
    "        self.mask = mask\n",
    "        self.model = model\n",
    "    \n",
    "    def forward(self, x, *args, **kwargs):\n",
    "        out, latent = self.model(x, *args, **kwargs)\n",
    "        return out, latent[:,self.mask]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "4e9619c8",
   "metadata": {},
   "outputs": [],
   "source": [
    "def find_chosen_neurons(m1):\n",
    "    frac_to_chosen_neurons = {}\n",
    "    for partial_seed in PARTIAL_CHOICE_SEEDS:\n",
    "        for frac in PARTIAL_FRACTIONS:\n",
    "            name, param = list(m1.model.named_modules())[-1]\n",
    "            in_fts = param.in_features\n",
    "            num_neurons = int(frac * in_fts)\n",
    "            linear = nn.Linear(num_neurons, dsmd.DATASET_PARAMS[SOURCE_DATASET]['num_classes'])\n",
    "            torch.manual_seed(partial_seed)\n",
    "            chosen_neurons = torch.randperm(in_fts)[:num_neurons]\n",
    "\n",
    "            if frac in frac_to_chosen_neurons:\n",
    "                frac_to_chosen_neurons[frac].append(chosen_neurons)\n",
    "            else:\n",
    "                frac_to_chosen_neurons[frac] = [chosen_neurons]\n",
    "    return frac_to_chosen_neurons"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "61f9908b",
   "metadata": {},
   "outputs": [],
   "source": [
    "def results(SOURCE_DATASET, MODEL, APPENDS, PARTIAL_CHOICE_SEEDS, PARTIAL_FRACTIONS):\n",
    "    results = []\n",
    "    for eval_ds in EVAL_DATASETS:\n",
    "        print (eval_ds)\n",
    "        append_to_frac_ckas = {}\n",
    "        dm = DATA_MODULES[eval_ds](\n",
    "            data_dir=DATA_PATH_IMAGENET if 'imagenet' in eval_ds else DATA_PATH,\n",
    "            transform_train=dsmd.TEST_TRANSFORMS_DEFAULT(224),\n",
    "            transform_test=dsmd.TEST_TRANSFORMS_DEFAULT(224),\n",
    "            batch_size=BATCH_SIZE)\n",
    "        dm.init_remaining_attrs(eval_ds)\n",
    "        for append in APPENDS:\n",
    "            m1 = arch.create_model(MODEL, SOURCE_DATASET, pretrained=True,\n",
    "                                   checkpoint_path=CHECKPOINT_PATHS[MODEL][append], seed=SEED, \n",
    "                                   callback=partial(LightningWrapper, \n",
    "                                                    dataset_name=SOURCE_DATASET,\n",
    "                                                    inference_kwargs={'with_latent': True}))\n",
    "            frac_to_chosen_neurons = find_chosen_neurons(m1)\n",
    "            frac_to_ckas = {}\n",
    "            for frac in PARTIAL_FRACTIONS:\n",
    "                for mask1, mask2 in itertools.combinations(frac_to_chosen_neurons[frac], 2):\n",
    "                    stir_score = stir.STIR(MaskedModel(m1, mask1), MaskedModel(m1, mask2), \n",
    "                        helpers.InputNormalize(dsmd.STANDARD_MEAN, dsmd.STANDARD_STD), \n",
    "                        helpers.InputNormalize(dsmd.STANDARD_MEAN, dsmd.STANDARD_STD),\n",
    "                        (dm.test_dataloader(), 1000), verbose=False, layer1_num=None, \n",
    "                        layer2_num=None, no_opt=True, cka_only=True)\n",
    "                    if frac in frac_to_ckas:\n",
    "                        frac_to_ckas[frac].append(stir_score.rsm)\n",
    "                    else:\n",
    "                        frac_to_ckas[frac] = [stir_score.rsm]\n",
    "\n",
    "            append_to_frac_ckas[append] = frac_to_ckas\n",
    "        results.append(append_to_frac_ckas)\n",
    "    \n",
    "    plt_str = '== CKA Analysis ==\\n\\n'\n",
    "    for idx, eval_ds in enumerate(EVAL_DATASETS):\n",
    "        plt_str += f'=== {eval_ds} ===\\n\\n'\n",
    "        for append in APPENDS:\n",
    "            frac_to_ckas = results[idx][append]\n",
    "            full_cka = frac_to_ckas[1.]\n",
    "            remaining_cka = {k:v for k,v in frac_to_ckas.items() if k!=1}\n",
    "            x_vals, y_vals = list(zip(*sorted(remaining_cka.items(), key=lambda t: t[0])))\n",
    "            plt_str += '== {} ==\\n\\n{}\\n\\n'.format(\n",
    "                append,\n",
    "                plt_hp.get_wiki_link(plt_hp.line_plot(\n",
    "                    [[np.mean(_y) for _y in y_vals]], 'Fraction of neurons', 'CKA', f'Eval On {eval_ds}', \n",
    "                    subfolder=SOURCE_DATASET, filename=f'{MODEL}-{append}-{eval_ds}', extension='png', \n",
    "                    x_vals=x_vals, \n",
    "                    legend_vals=['', 'Full Layer'], vertical_line=None, \n",
    "                    horizontal_lines=[np.mean(full_cka)], horizontal_lines_err=[np.std(full_cka)], \n",
    "                    colors=None, linestyles=['-', '--'],\n",
    "                    y_lims=(0.,1.1), root_dir='.', paper_friendly_plots=False, \n",
    "                    plot_inside=True, legend_location='best', savefig=True, figsize=(10,6), \n",
    "                    marker=[True], results_subfolder_name='cka_analysis', \n",
    "                    grid_spacing=None, y_err=[[np.std(_y) for _y in y_vals]], legend_ncol=None), \n",
    "                                     SERVER_PROJECT_PATH, size=1000))\n",
    "    \n",
    "    with open(f'./results/cka_analysis/{SOURCE_DATASET}/wiki_results-{MODEL}.txt', 'w') as fp:\n",
    "        fp.write(plt_str)\n",
    "\n",
    "    out.upload_results(['{}/{}/{}'.format(plt_hp.RESULTS_FOLDER_NAME, 'cka_analysis', SOURCE_DATASET)], \n",
    "            'results', SERVER_PROJECT_PATH, '.png')\n",
    "    return results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0323d8f5",
   "metadata": {},
   "outputs": [],
   "source": [
    "SOURCE_DATASET = 'imagenet'\n",
    "MODEL = 'resnet50'\n",
    "APPENDS = list(CHECKPOINT_PATHS[MODEL].keys())\n",
    "\n",
    "PARTIAL_CHOICE_SEEDS = list(range(1,6))\n",
    "PARTIAL_FRACTIONS = sorted(\n",
    "        list(set(\n",
    "            [float(x.split('/frac-')[1].split('-')[0]) for x in \\\n",
    "                glob.glob(f'./checkpoints/{MODEL}-base-'\n",
    "                          f'{SOURCE_DATASET}-ft-{EVAL_DATASETS[-1]}/'\n",
    "                          f'*-bs-256-{APPENDS[0]}')]\n",
    "            )))\n",
    "NUMBERS = [8,16,32,64,128,256,512,1024,2048]\n",
    "\n",
    "results_fracs = results(SOURCE_DATASET, MODEL, APPENDS, PARTIAL_CHOICE_SEEDS, PARTIAL_FRACTIONS)\n",
    "results_numbers = results(SOURCE_DATASET, MODEL, APPENDS, PARTIAL_CHOICE_SEEDS, \n",
    "                          [x/NUMBERS[-1] for x in NUMBERS])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d0d73145",
   "metadata": {},
   "outputs": [],
   "source": [
    "results_"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "83afb735",
   "metadata": {},
   "outputs": [],
   "source": [
    "for append in CHECKPOINT_PATHS['resnet50'].keys():\n",
    "    all_x_vals, all_y_vals= [], []\n",
    "    for idx, ft_ds in enumerate(EVAL_DATASETS):\n",
    "        x, y = list(zip(*sorted(\n",
    "            results_fracs[idx][append].items(), key=lambda i:i[0])))\n",
    "        all_x_vals.append(x[:-1])\n",
    "        all_y_vals.append(y[:-1])\n",
    "    plt_hp.line_plot(\n",
    "        [[np.nanmean(_y) for _y in y_vals] for y_vals in all_y_vals], \n",
    "        'Fraction of Neurons (Total=2048)', 'CKA', '', \n",
    "        subfolder=SOURCE_DATASET, filename=f'{MODEL}-{append}-allds-random', \n",
    "        extension='png', x_vals=all_x_vals, legend_vals=EVAL_DATASETS, vertical_line=None, \n",
    "        colors=plt_hp.COLORS[:len(EVAL_DATASETS)], \n",
    "        linestyles=['']*len(EVAL_DATASETS), \n",
    "        y_lims=(0.,1.02), root_dir='.', paper_friendly_plots=True, \n",
    "        plot_inside=True, legend_location='best', savefig=True, figsize=(10,6), \n",
    "        marker=['o'] * len(EVAL_DATASETS), \n",
    "        results_subfolder_name='cka_analysis', grid_spacing=None, \n",
    "        y_err=[[np.nanstd(_y) for _y in y_vals] for y_vals in all_y_vals], legend_ncol=None)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1604cb6b",
   "metadata": {},
   "outputs": [],
   "source": [
    "for append in CHECKPOINT_PATHS['resnet50'].keys():\n",
    "    all_x_vals, all_y_vals= [], []\n",
    "    for idx, ft_ds in enumerate(EVAL_DATASETS):\n",
    "        x, y = list(zip(*sorted(\n",
    "            results_fracs[idx+1][append].items(), key=lambda i:i[0])))\n",
    "        all_x_vals.append(x[:-1])\n",
    "        all_y_vals.append(y[:-1])\n",
    "    plt_hp.line_plot(\n",
    "        [[np.nanmean(_y) for _y in y_vals] for y_vals in all_y_vals], \n",
    "        'Fraction of Neurons (Total=2048)', 'CKA', '', \n",
    "        subfolder=SOURCE_DATASET, filename=f'{MODEL}-{append}-allds-random', \n",
    "        extension='png', x_vals=all_x_vals, legend_vals=EVAL_DATASETS, vertical_line=None, \n",
    "        colors=plt_hp.COLORS[:len(EVAL_DATASETS)], \n",
    "        linestyles=['']*len(EVAL_DATASETS), \n",
    "        y_lims=(0.,1.02), root_dir='.', paper_friendly_plots=True, \n",
    "        plot_inside=True, legend_location='best', savefig=True, figsize=(10,6), \n",
    "        marker=['*'] * len(EVAL_DATASETS), \n",
    "        results_subfolder_name='cka_analysis', grid_spacing=None, \n",
    "        y_err=[[np.nanstd(_y) for _y in y_vals] for y_vals in all_y_vals], legend_ncol=None)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "75ddf09b",
   "metadata": {},
   "outputs": [],
   "source": [
    "SOURCE_DATASET = 'imagenet21k'\n",
    "MODEL = 'vit_small_patch16_224'\n",
    "APPENDS = list(CHECKPOINT_PATHS[MODEL].keys())\n",
    "\n",
    "PARTIAL_CHOICE_SEEDS = list(range(1,6))\n",
    "PARTIAL_FRACTIONS = sorted(\n",
    "        list(set(\n",
    "            [float(x.split('/frac-')[1].split('-')[0]) for x in \\\n",
    "                glob.glob(f'./checkpoints/{MODEL}-base-'\n",
    "                          f'{SOURCE_DATASET}-ft-{EVAL_DATASETS[-1]}/'\n",
    "                          f'*-bs-256-{APPENDS[0]}')]\n",
    "            )))\n",
    "print (PARTIAL_FRACTIONS)\n",
    "results(SOURCE_DATASET, MODEL, APPENDS, PARTIAL_CHOICE_SEEDS, PARTIAL_FRACTIONS)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "40662688",
   "metadata": {},
   "outputs": [],
   "source": [
    "SOURCE_DATASET = 'imagenet21k'\n",
    "MODEL = 'vit_small_patch32_224'\n",
    "APPENDS = list(CHECKPOINT_PATHS[MODEL].keys())\n",
    "\n",
    "PARTIAL_CHOICE_SEEDS = list(range(1,6))\n",
    "PARTIAL_FRACTIONS = sorted(\n",
    "        list(set(\n",
    "            [float(x.split('/frac-')[1].split('-')[0]) for x in \\\n",
    "                glob.glob(f'./checkpoints/{MODEL}-base-'\n",
    "                          f'{SOURCE_DATASET}-ft-{EVAL_DATASETS[-1]}/'\n",
    "                          f'*-bs-256-{APPENDS[0]}')]\n",
    "            )))\n",
    "print (PARTIAL_FRACTIONS)\n",
    "results(SOURCE_DATASET, MODEL, APPENDS, PARTIAL_CHOICE_SEEDS, PARTIAL_FRACTIONS)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e9f47464",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
