{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Compare random ResNets with pretrained (PT) and fine-tuned (FT) ResNets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from similarity_metrics import CKA, deltaCKA"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os, random\n",
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torchvision\n",
    "import torchvision.models as models\n",
    "import torchvision.transforms as transforms"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Found device: cpu\n"
     ]
    }
   ],
   "source": [
    "def seed_everything(seed):\n",
    "    random.seed(seed)\n",
    "    os.environ[\"PYTHONHASHSEED\"] = str(seed)\n",
    "    np.random.seed(seed)\n",
    "    torch.manual_seed(seed)\n",
    "    torch.cuda.manual_seed(seed)\n",
    "    torch.cuda.manual_seed_all(seed)\n",
    "    torch.backends.cudnn.deterministic = True\n",
    "\n",
    "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
    "print(f\"Found device: {device}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## ResNet-18 model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Resnet18_pt(nn.Module):\n",
    "    def __init__(self, num_classes):\n",
    "        super().__init__()\n",
    "        self.encoder = models.resnet18(pretrained=True)\n",
    "        self.encoder.fc = nn.Linear(512, num_classes)\n",
    "\n",
    "    def forward(self, x):\n",
    "        # return representations for the \n",
    "        # intermediate layers\n",
    "        x = x.float()\n",
    "        inter_x = self.encoder.maxpool(self.encoder.relu(\n",
    "            self.encoder.bn1(self.encoder.conv1(x))))\n",
    "        intermediate_outputs = []\n",
    "        for layer in [self.encoder.layer1, self.encoder.layer2, self.encoder.layer3, self.encoder.layer4]:\n",
    "            for layer_i in layer:\n",
    "                inter_x = layer_i(inter_x)\n",
    "                intermediate_outputs.append(inter_x)\n",
    "\n",
    "        #forward pass\n",
    "        x = self.encoder(x)\n",
    "\n",
    "        return x, intermediate_outputs\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def ResNet_Gaussian_noise(noise_std, seed, device = 'cpu'):\n",
    "    \"\"\"Build a 4-class `Resnet18_pt` and add Gaussian noise to its parameters.\n",
    "\n",
    "    The RNGs are seeded with `seed` before the model is created. Every tensor\n",
    "    returned by `model.parameters()` then receives independent noise equal to\n",
    "    `torch.randn_like(param) * noise_std`.\n",
    "\n",
    "    Args:\n",
    "        noise_std: scale applied to the standard-normal noise.\n",
    "        seed: value passed to `seed_everything`.\n",
    "        device: device the model is moved to before it is perturbed.\n",
    "\n",
    "    Returns:\n",
    "        The perturbed model.\n",
    "    \"\"\"\n",
    "    seed_everything(seed)\n",
    "    noisy_net = Resnet18_pt(num_classes=4).to(device)\n",
    "    with torch.no_grad():\n",
    "        for weight in noisy_net.parameters():\n",
    "            perturbation = torch.randn_like(weight) * noise_std\n",
    "            weight.copy_(weight + perturbation)\n",
    "    return noisy_net"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "## get all embeddings\n",
    "def get_embedding_all_conv(resnet, input_x, device):\n",
    "    all_layers = []\n",
    "    def remove_sequential(network):\n",
    "        for layer in network.children():\n",
    "            if isinstance(layer, nn.Sequential): # if sequential layer, apply recursively to layers in sequential layer\n",
    "                remove_sequential(layer)\n",
    "            if isinstance(layer, models.resnet.BasicBlock): # if block layer, apply recursively to layers in block layer\n",
    "                remove_sequential(layer)\n",
    "            if list(layer.children()) == []: # if leaf node, add it to list\n",
    "                all_layers.append(layer)\n",
    "    remove_sequential(resnet.encoder)\n",
    "    \n",
    "    all_conv = [i for i in all_layers if isinstance(i, nn.Conv2d)]\n",
    "    name_conv = ['conv'+str(i+1) for i in range(len(all_conv))]\n",
    "    \n",
    "    \n",
    "    activation = {}\n",
    "    def get_activation(name):\n",
    "        def hook(model, input, output):\n",
    "            activation[name] = output.detach()\n",
    "        return hook\n",
    "    hook_list = []\n",
    "    for i in range(len(all_conv)):\n",
    "        hook_i = all_conv[i].register_forward_hook(get_activation(name_conv[i]))\n",
    "        hook_list.append(hook_i)\n",
    "    y_pred = resnet(input_x.to(device))\n",
    "    \n",
    "    for hook_j in hook_list: ## remove the hook\n",
    "        hook_j.remove()\n",
    "    \n",
    "    activation_list = [activation[i] for i in name_conv]\n",
    "    \n",
    "    return activation_list"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load the CIFAR-10 data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Files already downloaded and verified\n"
     ]
    }
   ],
   "source": [
    "# Convert images to tensors, then shift/scale each RGB channel by 0.5.\n",
    "transform = transforms.Compose([\n",
    "    transforms.ToTensor(),\n",
    "    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),\n",
    "])\n",
    "\n",
    "batch_size = 20  # a single batch of 20 test images is used for the demonstration\n",
    "cifar_testset = torchvision.datasets.CIFAR10(\n",
    "    root='./data', train=False, download=True, transform=transform)\n",
    "\n",
    "# Seed just before the loader is built; intended to make the shuffled first\n",
    "# batch reproducible.\n",
    "seed_everything(50)\n",
    "cifar_testloader = torch.utils.data.DataLoader(\n",
    "    cifar_testset, batch_size=batch_size, shuffle=True, num_workers=1)\n",
    "\n",
    "# Only the first batch is needed: images plus their class labels.\n",
    "cka_examples_batch = next(iter(cifar_testloader))\n",
    "input_batch, cifar_sample_target = cka_examples_batch"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load ResNets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "## generate random resnets\n",
    "noise_std = 1.\n",
    "seed_1, seed_2 = 44, 200\n",
    "# Two independently perturbed networks, built in the order seed_1, seed_2.\n",
    "res_model_rnd1, res_model_rnd2 = (\n",
    "    ResNet_Gaussian_noise(noise_std, seed, device = device)\n",
    "    for seed in (seed_1, seed_2)\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<All keys matched successfully>"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "## load pretrained and fine-tuned resnets\n",
    "res_model_pt = Resnet18_pt(num_classes=10).to(device)\n",
    "res_model_ft = Resnet18_pt(num_classes=10).to(device)\n",
    "# NOTE(review): torch.load unpickles the checkpoint file; only load\n",
    "# checkpoints that come from a trusted source.\n",
    "res_model_ft.load_state_dict(\n",
    "    torch.load('./cifar_resnet18.pth', map_location=torch.device(device))\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Get the embeddings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pass the same input batch through all four networks, in the order\n",
    "# rnd1, rnd2, pt, ft, and keep the per-convolution activations of each.\n",
    "(\n",
    "    res_model_rnd1_all_embeddings,\n",
    "    res_model_rnd2_all_embeddings,\n",
    "    res_model_pt_all_embeddings,\n",
    "    res_model_ft_all_embeddings,\n",
    ") = [\n",
    "    get_embedding_all_conv(net, input_batch.to(device), device=device)\n",
    "    for net in (res_model_rnd1, res_model_rnd2, res_model_pt, res_model_ft)\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Flatten every input image into a single row (one row per sample).\n",
    "input_embedding = input_batch.reshape(len(input_batch), -1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Calculate the similarity between first-layer representations"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "PT-FT similarity: dCKA 0.710, CKA 0.959 \n",
      "rnd1-rnd2 similarity: dCKA 0.460, CKA 0.997 \n"
     ]
    }
   ],
   "source": [
    "layer = 0  # index into the per-convolution activation lists; 0 is the first conv\n",
    "dcka = deltaCKA(device=device)\n",
    "cka = CKA(device=device)\n",
    "\n",
    "\n",
    "def flat_layer(embeddings):\n",
    "    \"\"\"Return the activations of `layer` reshaped to one row per input sample.\"\"\"\n",
    "    return embeddings[layer].reshape(input_batch.shape[0], -1)\n",
    "\n",
    "\n",
    "# Pretrained vs. fine-tuned network.\n",
    "dcka_pt_ft = dcka.linear_CKA(\n",
    "    flat_layer(res_model_pt_all_embeddings),\n",
    "    flat_layer(res_model_ft_all_embeddings),\n",
    "    input_embedding,\n",
    "    input_embedding,\n",
    ")\n",
    "cka_pt_ft = cka.linear_CKA(\n",
    "    flat_layer(res_model_pt_all_embeddings),\n",
    "    flat_layer(res_model_ft_all_embeddings),\n",
    ")\n",
    "print('PT-FT similarity: dCKA %.3f, CKA %.3f ' %(dcka_pt_ft, cka_pt_ft))\n",
    "\n",
    "# The two noise-perturbed networks.\n",
    "dcka_rnd1_rnd2 = dcka.linear_CKA(\n",
    "    flat_layer(res_model_rnd1_all_embeddings),\n",
    "    flat_layer(res_model_rnd2_all_embeddings),\n",
    "    input_embedding,\n",
    "    input_embedding,\n",
    ")\n",
    "cka_rnd1_rnd2 = cka.linear_CKA(\n",
    "    flat_layer(res_model_rnd1_all_embeddings),\n",
    "    flat_layer(res_model_rnd2_all_embeddings),\n",
    ")\n",
    "print('rnd1-rnd2 similarity: dCKA %.3f, CKA %.3f ' %(dcka_rnd1_rnd2, cka_rnd1_rnd2))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
