{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bbcf6f7a-76e2-4d51-aacd-51073028ce18",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import os\n",
    "from collections import OrderedDict\n",
    "from types import SimpleNamespace\n",
    "\n",
    "import datasets\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "import torch.utils.data as data\n",
    "import torchvision\n",
    "import torchvision.transforms as transforms\n",
    "import wandb\n",
    "from tqdm import tqdm"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "43466bdb-516b-4e7f-831e-d09186d9c069",
   "metadata": {},
   "source": [
    "## ImageNet Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b75060a1-426c-4a83-b976-2c7fad746f99",
   "metadata": {},
   "outputs": [],
   "source": [
    "class ImageNetDataset(data.Dataset):\n",
    "    \"\"\"Map-style dataset over one split of the Hugging Face `ILSVRC/imagenet-1k` dataset.\n",
    "\n",
    "    Each item is a dict {'img': normalized 3x224x224 tensor, 'label': label}.\n",
    "    The 'train' split is augmented with a random resized crop and a random\n",
    "    horizontal flip; every other split is resized to 256 and center-cropped to 224.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, split = 'train'):\n",
    "        self.data = datasets.load_dataset('ILSVRC/imagenet-1k')[split]\n",
    "        # Standard ImageNet per-channel mean / std.\n",
    "        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])\n",
    "        if split == 'train':\n",
    "            geometry = [transforms.RandomResizedCrop(224), transforms.RandomHorizontalFlip()]\n",
    "        else:\n",
    "            geometry = [transforms.Resize(256), transforms.CenterCrop(224)]\n",
    "        self.transform = transforms.Compose(geometry + [transforms.ToTensor(), normalize])\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.data)\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        record = self.data[idx]\n",
    "        # Force 3 channels: source images are not guaranteed to be RGB.\n",
    "        image = record['image'].convert('RGB')\n",
    "        return {'img': self.transform(image), 'label': record['label']}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6d41e7f5-5513-41a0-b03e-f9ede3279803",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_dataloaders(batch_size = 64, num_workers = 4):\n",
    "    \"\"\"Build the train / validation / test DataLoaders over ImageNetDataset.\n",
    "\n",
    "    Args:\n",
    "        batch_size: batch size of the train and validation loaders (the test\n",
    "            loader always yields one sample per batch).\n",
    "        num_workers: number of worker processes used by every loader.\n",
    "\n",
    "    Returns:\n",
    "        (train_loader, val_loader, test_loader); only train_loader shuffles.\n",
    "\n",
    "    NOTE(review): the HF imagenet-1k test split may ship without real labels;\n",
    "    confirm before computing accuracy on test_loader.\n",
    "    \"\"\"\n",
    "    train_dataset = ImageNetDataset('train')\n",
    "    val_dataset = ImageNetDataset('validation')\n",
    "    test_dataset = ImageNetDataset('test')\n",
    "    train_loader = data.DataLoader(train_dataset, batch_size = batch_size, shuffle = True, num_workers = num_workers)\n",
    "    val_loader = data.DataLoader(val_dataset, batch_size = batch_size, shuffle = False, num_workers = num_workers)\n",
    "    # Honour num_workers here too, so test-time image decoding is not single-process.\n",
    "    test_loader = data.DataLoader(test_dataset, batch_size = 1, shuffle = False, num_workers = num_workers)\n",
    "    return train_loader, val_loader, test_loader"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "dfdbaf41-00b4-42d3-87fe-671e2c6e67a1",
   "metadata": {},
   "source": [
    "## Architectures"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "82f4ed24-45a4-45b2-82d3-494a9bde0fe5",
   "metadata": {},
   "source": [
    "### MLP"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0a38fc70-192c-4ad5-b895-fb458e477c34",
   "metadata": {},
   "outputs": [],
   "source": [
    "class ImageNetNarrowMLP(nn.Module):\n",
    "    \"\"\"Deep, narrow fully-connected classifier over flattened images.\n",
    "\n",
    "    Layout:\n",
    "        initial_layer:       Linear(in_features, first_width) -> BatchNorm1d -> ReLU\n",
    "        intermediate_layers: `depth` blocks of Linear -> BatchNorm1d -> ReLU of width\n",
    "                             `width` (block 0 takes first_width inputs)\n",
    "        output_layer:        Linear(width, num_classes)\n",
    "\n",
    "    The defaults reproduce the original network for 3x224x224 inputs and 1000\n",
    "    classes. Submodule names such as 'initial_layer.0', 'intermediate_layers.N.0'\n",
    "    and 'output_layer.0' are used by map_rn_mlp, so the structure is kept as-is.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, in_features = 3 * 224 * 224, first_width = 2048, width = 1024, depth = 47, num_classes = 1000):\n",
    "        super(ImageNetNarrowMLP, self).__init__()\n",
    "        self.flatten = nn.Flatten()\n",
    "        self.initial_layer = self._block(in_features, first_width)\n",
    "        self.intermediate_layers = nn.Sequential(\n",
    "            *[self._block(first_width if i == 0 else width, width) for i in range(depth)]\n",
    "        )\n",
    "        # With depth == 0 the output layer reads the initial layer directly.\n",
    "        head_in = width if depth > 0 else first_width\n",
    "        self.output_layer = nn.Sequential(\n",
    "            nn.Linear(head_in, num_classes),\n",
    "        )\n",
    "\n",
    "    @staticmethod\n",
    "    def _block(in_dim, out_dim):\n",
    "        \"\"\"Linear -> BatchNorm1d -> ReLU as an nn.Sequential (indices 0, 1, 2).\"\"\"\n",
    "        return nn.Sequential(nn.Linear(in_dim, out_dim), nn.BatchNorm1d(out_dim), nn.ReLU())\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = self.flatten(x)\n",
    "        x = self.initial_layer(x)\n",
    "        x = self.intermediate_layers(x)\n",
    "        return self.output_layer(x)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "37368a67-9e72-4fdb-8c1c-64c4cff31298",
   "metadata": {},
   "source": [
    "## Representational Similarity"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "52ab3309-301b-4530-8904-9390171057c3",
   "metadata": {},
   "source": [
    "### Metrics"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3e8d0f06-ad3a-4302-a614-d97c600405ca",
   "metadata": {},
   "outputs": [],
   "source": [
    "class CKA(object):\n",
    "    \"\"\"Linear Centered Kernel Alignment (CKA) between two representations.\n",
    "\n",
    "    X and Y are (n_samples, n_features) matrices computed on the same n_samples\n",
    "    inputs; their feature counts may differ. linear_CKA returns a differentiable\n",
    "    scalar tensor equal to 1 when X and Y match up to an orthogonal transform and\n",
    "    isotropic scaling, so 1 - CKA can be used as a training loss.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, device, eps = 1e-12):\n",
    "        # Kept for API compatibility; centering now follows the inputs' device/dtype.\n",
    "        self.device = device\n",
    "        # Floor for the self-similarity terms. Without it a batch-constant\n",
    "        # representation (e.g. a dead ReLU layer) gives 0 / 0 and an infinite\n",
    "        # sqrt gradient, i.e. a NaN loss.\n",
    "        self.eps = eps\n",
    "\n",
    "    def centering(self, K):\n",
    "        \"\"\"Return H K H with H = I - 1 1^T / n, without materialising H.\n",
    "\n",
    "        Subtracting column means and row means and adding back the grand mean\n",
    "        is algebraically identical, costs O(n^2) instead of two O(n^3) matmuls,\n",
    "        and works for any device/dtype of K.\n",
    "        \"\"\"\n",
    "        return K - K.mean(dim=0, keepdim=True) - K.mean(dim=1, keepdim=True) + K.mean()\n",
    "\n",
    "    def linear_HSIC(self, X, Y):\n",
    "        \"\"\"Unnormalised linear-kernel HSIC: Frobenius inner product of H X X^T H and H Y Y^T H.\"\"\"\n",
    "        return torch.sum(self.centering(X @ X.T) * self.centering(Y @ Y.T))\n",
    "\n",
    "    def linear_CKA(self, X, Y):\n",
    "        \"\"\"HSIC(X, Y) / sqrt(HSIC(X, X) * HSIC(Y, Y)).\"\"\"\n",
    "        # Center each Gram matrix once and reuse it for all three HSIC terms.\n",
    "        K = self.centering(X @ X.T)\n",
    "        L = self.centering(Y @ Y.T)\n",
    "        hsic = torch.sum(K * L)\n",
    "        var1 = torch.sqrt(torch.sum(K * K).clamp_min(self.eps))\n",
    "        var2 = torch.sqrt(torch.sum(L * L).clamp_min(self.eps))\n",
    "        return hsic / (var1 * var2)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7be1bb9e-736c-4e9c-86d3-b2f927943f05",
   "metadata": {},
   "source": [
    "### Layer-wise Computation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6d980b7b-267c-40f2-9162-cfd33fbfb45b",
   "metadata": {},
   "outputs": [],
   "source": [
    "def hook_fn(module, input, output, layer_outputs, layer_name):\n",
    "    \"\"\"Forward-hook body: record `output` in `layer_outputs[layer_name]`.\"\"\"\n",
    "    layer_outputs[layer_name] = output\n",
    "\n",
    "# Module types whose outputs are captured.\n",
    "# NOTE: just an example; the full code hooks more layer types.\n",
    "HOOKED_LAYER_TYPES = (nn.Conv2d, nn.Linear, nn.AdaptiveAvgPool2d, nn.LSTM, nn.RNN)\n",
    "\n",
    "def register_hooks(model, layer_outputs):\n",
    "    \"\"\"Attach a forward hook to every HOOKED_LAYER_TYPES submodule of `model`.\n",
    "\n",
    "    Each hook stores its module's output in `layer_outputs`, keyed by the module's\n",
    "    dotted name from model.named_modules(). Returns the hook handles so the caller\n",
    "    can remove them.\n",
    "    \"\"\"\n",
    "    hooks = []\n",
    "    for name, module in model.named_modules():\n",
    "        if isinstance(module, HOOKED_LAYER_TYPES):\n",
    "            # `name=name` binds the current name; a plain closure would only see the last one.\n",
    "            hooks.append(module.register_forward_hook(lambda m, i, o, name=name: hook_fn(m, i, o, layer_outputs, name)))\n",
    "    return hooks\n",
    "\n",
    "def get_layer_outputs(model, inputs, eval = False):\n",
    "    \"\"\"Run `model(inputs)` once and collect the outputs of its hooked layers.\n",
    "\n",
    "    Args:\n",
    "        model: module to probe.\n",
    "        inputs: passed unchanged to model(...).\n",
    "        eval: if True, put the model in eval mode (this persists after the call)\n",
    "            and run under torch.no_grad(); otherwise run in the model's current\n",
    "            mode with autograd on, so the outputs can be back-propagated.\n",
    "\n",
    "    Returns:\n",
    "        OrderedDict {module name: output}, in the order the hooks first fired.\n",
    "    \"\"\"\n",
    "    layer_outputs = OrderedDict()\n",
    "    hooks = register_hooks(model, layer_outputs)\n",
    "    try:\n",
    "        if eval:\n",
    "            model = model.eval()\n",
    "            with torch.no_grad():\n",
    "                model(inputs)\n",
    "        else:\n",
    "            model(inputs)\n",
    "    finally:\n",
    "        # Remove the hooks even if the forward pass raises; otherwise they stay\n",
    "        # attached and keep firing on every later forward of this model.\n",
    "        for hook in hooks:\n",
    "            hook.remove()\n",
    "    return layer_outputs"
   ]
  },
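  {
   "cell_type": "markdown",
   "id": "2b149e43-35f0-4d33-94f3-38e70d2df5a7",
   "metadata": {},
   "source": [
    "To compare two architectures layer by layer, each layer of one model is paired with a layer of the other, with the pairs spread across the deeper model. The functions below are hand-written examples of such mappings (ResNet-18 to ResNet-50, and ResNet-18 to the deep MLP)."
   ]
  },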
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "43f37d9d-10e6-4fee-9a2e-5aa0038ccd78",
   "metadata": {},
   "outputs": [],
   "source": [
    "def map_layers():\n",
    "    \"\"\"Hand-written ResNet-18 to ResNet-50 layer mapping.\n",
    "\n",
    "    Keys are ResNet-18 module names, values ResNet-50 module names. In every\n",
    "    stage, ResNet-18 block 0 maps to ResNet-50 block 0 and ResNet-18 block 1 to a\n",
    "    later ResNet-50 block; each block's conv1 maps to conv1 and its conv2 to\n",
    "    conv3. Stages 2-4 also pair their downsample convolutions.\n",
    "    \"\"\"\n",
    "    # ResNet-50 block that receives ResNet-18's second block, per stage.\n",
    "    rn50_block_for_second = {1: 2, 2: 2, 3: 3, 4: 2}\n",
    "    rn18_rn50_mapping = {'conv1': 'conv1'}\n",
    "    for stage, rn50_block in rn50_block_for_second.items():\n",
    "        layer = f'layer{stage}'\n",
    "        rn18_rn50_mapping[f'{layer}.0.conv1'] = f'{layer}.0.conv1'\n",
    "        rn18_rn50_mapping[f'{layer}.0.conv2'] = f'{layer}.0.conv3'\n",
    "        if stage > 1:\n",
    "            rn18_rn50_mapping[f'{layer}.0.downsample.0'] = f'{layer}.0.downsample.0'\n",
    "        rn18_rn50_mapping[f'{layer}.1.conv1'] = f'{layer}.{rn50_block}.conv1'\n",
    "        rn18_rn50_mapping[f'{layer}.1.conv2'] = f'{layer}.{rn50_block}.conv3'\n",
    "    return rn18_rn50_mapping"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "27174106-b510-4bb9-bad6-8183bea05baa",
   "metadata": {},
   "outputs": [],
   "source": [
    "def map_rn_mlp():\n",
    "    \"\"\"Hand-written ResNet-18 to ImageNetNarrowMLP layer mapping.\n",
    "\n",
    "    Keys are ResNet-18 module names, values MLP module names: conv1 goes to the\n",
    "    initial Linear, fc to the output Linear, and every other listed layer to the\n",
    "    Linear ('.0') of an intermediate block, spread across the MLP's depth.\n",
    "    \"\"\"\n",
    "    rn18_layers = [\n",
    "        'layer1.0.conv1', 'layer1.0.conv2', 'layer1.1.conv1', 'layer1.1.conv2',\n",
    "        'layer2.0.conv1', 'layer2.0.conv2', 'layer2.0.downsample.0', 'layer2.1.conv1', 'layer2.1.conv2',\n",
    "        'layer3.0.conv1', 'layer3.0.conv2', 'layer3.0.downsample.0', 'layer3.1.conv1', 'layer3.1.conv2',\n",
    "        'layer4.0.conv1', 'layer4.0.conv2', 'layer4.0.downsample.0', 'layer4.1.conv1', 'layer4.1.conv2',\n",
    "        'avgpool',\n",
    "    ]\n",
    "    # Intermediate MLP block paired with each entry of rn18_layers, in order.\n",
    "    mlp_blocks = [1, 3, 5, 7, 9, 11, 13, 15, 17, 20, 22, 24, 26, 28, 31, 33, 35, 37, 39, 42]\n",
    "    rn18_mlp_mapping = {'conv1': 'initial_layer.0'}\n",
    "    rn18_mlp_mapping.update((layer, f'intermediate_layers.{block}.0') for layer, block in zip(rn18_layers, mlp_blocks))\n",
    "    rn18_mlp_mapping['fc'] = 'output_layer.0'\n",
    "    return rn18_mlp_mapping"
   ]
  },
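  {
   "cell_type": "markdown",
   "id": "d388f3ac-780e-44b9-bb5d-2090182046b0",
   "metadata": {},
   "source": [
    "Here is a generic layer mapping (`layer_supervision`): when no hand-written mapping is available, it spreads the shorter model's layers evenly across the longer model's layers, aligning the first and last layers."
   ]
  },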
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "71b51ae6-9d13-4a1b-924a-dc5775e04fbc",
   "metadata": {},
   "outputs": [],
   "source": [
    "def layer_supervision(target_model_layers, student_model_layers):\n",
    "    \"\"\"Spread the layers of the first list evenly over the layers of the second.\n",
    "\n",
    "    Despite the parameter names, `target_model_layers` is the list mapped FROM and\n",
    "    `student_model_layers` the list mapped TO. Source layer i goes to destination\n",
    "    index round(i * (len(dst) - 1) / (len(src) - 1)), so the first and last layers\n",
    "    line up. Python's round() sends .5 to the nearest even integer.\n",
    "\n",
    "    Returns:\n",
    "        dict {source layer: destination layer}, in source order.\n",
    "    \"\"\"\n",
    "    n_source = len(target_model_layers)\n",
    "    last_index = len(student_model_layers) - 1\n",
    "    stride = last_index / (n_source - 1) if n_source > 1 else 1\n",
    "    return {\n",
    "        layer: student_model_layers[min(round(position * stride), last_index)]\n",
    "        for position, layer in enumerate(target_model_layers)\n",
    "    }"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b9d9e9cb-c867-4ac6-8143-e20569e811a3",
   "metadata": {},
   "outputs": [],
   "source": [
    "def _first_tensor(output):\n",
    "    \"\"\"Return output[0] for tuple outputs (e.g. recurrent layers), else output.\"\"\"\n",
    "    return output[0] if isinstance(output, tuple) else output\n",
    "\n",
    "def layermap_sim(train_model, target_model, student_model, rep_sim, inputs, device):\n",
    "    \"\"\"Per-layer representational distance between a student and a target model.\n",
    "\n",
    "    Args:\n",
    "        train_model: student; run with autograd enabled so the scores are trainable.\n",
    "        target_model: reference model; run in eval mode under torch.no_grad().\n",
    "        student_model: 'ResNet-50' or 'DeepMLP' select a hand-written mapping from\n",
    "            target layer names to student layer names (map_layers / map_rn_mlp);\n",
    "            any other value spreads the shorter model's hooked layers evenly over\n",
    "            the longer model's (layer_supervision).\n",
    "        rep_sim: similarity measure; only 'CKA' is implemented.\n",
    "        inputs: batch fed to both models.\n",
    "        device: passed to CKA.\n",
    "\n",
    "    Returns:\n",
    "        dict {student layer name: 1 - linear CKA (scalar tensor)}.\n",
    "\n",
    "    Raises:\n",
    "        NotImplementedError: if rep_sim is not 'CKA'.\n",
    "    \"\"\"\n",
    "    # Reject unsupported measures before paying for two forward passes (and instead\n",
    "    # of silently returning {} when the mapping happens to be empty).\n",
    "    if rep_sim != 'CKA':\n",
    "        raise NotImplementedError()\n",
    "    cka = CKA(device)\n",
    "    pretrained_outputs = get_layer_outputs(target_model, inputs, eval = True)\n",
    "    training_outputs = get_layer_outputs(train_model, inputs)\n",
    "    if student_model == 'ResNet-50':\n",
    "        model_mapping = map_layers()\n",
    "    elif student_model == 'DeepMLP':\n",
    "        model_mapping = map_rn_mlp()\n",
    "    else:\n",
    "        teacher_layers = list(pretrained_outputs.keys())\n",
    "        student_layers = list(training_outputs.keys())\n",
    "        if len(teacher_layers) <= len(student_layers):\n",
    "            model_mapping = layer_supervision(teacher_layers, student_layers)\n",
    "        else:\n",
    "            # Map each student layer onto a teacher layer, then invert to teacher -> student.\n",
    "            model_mapping = {v : k for k, v in layer_supervision(student_layers, teacher_layers).items()}\n",
    "\n",
    "    batch_size = inputs.size(0)\n",
    "    sim_scores = {}\n",
    "    for layer, tr_layer in model_mapping.items():\n",
    "        assert layer in pretrained_outputs, f'Layer {layer} is not in the target model'\n",
    "        assert tr_layer in training_outputs, f'Layer {tr_layer} is not in {student_model}'\n",
    "        # Flatten every activation to (batch, features) before comparing.\n",
    "        pretrained_output = _first_tensor(pretrained_outputs[layer]).reshape(batch_size, -1)\n",
    "        training_output = _first_tensor(training_outputs[tr_layer]).reshape(batch_size, -1)\n",
    "        sim_scores[tr_layer] = 1 - cka.linear_CKA(training_output.to(torch.float32), pretrained_output.to(torch.float32))\n",
    "    return sim_scores"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "68add029-97de-4fb2-8da3-0e778e0c38be",
   "metadata": {},
   "source": [
    "### Final wrapper to get loss!"
   ]
  },
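  {
   "cell_type": "markdown",
   "id": "72017972-967e-4c96-a3c0-0a4c66501d5e",
   "metadata": {},
   "source": [
    "`nn.LSTM` and `nn.RNN` return an `(output, hidden_state)` tuple instead of a single tensor, so they cannot be chained directly inside an `nn.Sequential`. The wrapper below keeps only `output`."
   ]
  },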
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fd0c399e-858d-4399-bda8-bcd696e6d5ab",
   "metadata": {},
   "outputs": [],
   "source": [
    "class LSTMWrapper(nn.Module):\n",
    "    \"\"\"Adapter that lets a recurrent layer be chained inside nn.Sequential.\n",
    "\n",
    "    Calls the wrapped layer on `x`, unpacks the (output, hidden_state) pair it\n",
    "    returns, and passes on only `output`.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, lstm):\n",
    "        super().__init__()\n",
    "        # Attribute name kept as-is: it prefixes the state_dict keys ('lstm.*').\n",
    "        self.lstm = lstm\n",
    "\n",
    "    def forward(self, x):\n",
    "        sequence_out, _state = self.lstm(x)\n",
    "        return sequence_out"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "778df9dd-41d1-431e-919d-90cd096c7411",
   "metadata": {},
   "outputs": [],
   "source": [
    "def wrap_and_flatten_model(model):\n",
    "    \"\"\"Rebuild `model` as an nn.Sequential of its children, minus the last one.\n",
    "\n",
    "    Direct children are taken in registration order; a child that is an\n",
    "    nn.Sequential or nn.ModuleList is expanded one level. Recurrent layers (any\n",
    "    nn.RNNBase: LSTM, GRU, RNN) are wrapped in LSTMWrapper so they emit a tensor.\n",
    "    The final module (typically the classification head) is dropped.\n",
    "\n",
    "    NOTE(review): this reproduces model.forward only when forward is a plain chain\n",
    "    of its children (no reshapes, residual adds or extra tokens in between);\n",
    "    confirm for each architecture it is used with.\n",
    "    \"\"\"\n",
    "    def adapt(module):\n",
    "        # Recurrent layers return (output, hidden); wrap them so they yield a tensor.\n",
    "        return LSTMWrapper(module) if isinstance(module, nn.RNNBase) else module\n",
    "\n",
    "    new_modules = []\n",
    "    for module in model.children():\n",
    "        if isinstance(module, (nn.ModuleList, nn.Sequential)):\n",
    "            new_modules.extend(adapt(sub_module) for sub_module in module)\n",
    "        else:\n",
    "            new_modules.append(adapt(module))\n",
    "    return nn.Sequential(*new_modules[:-1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "68b50375-0fab-4ba4-9413-0e7a36201c9a",
   "metadata": {},
   "outputs": [],
   "source": [
    "def _final_layer_features(train_model_fe, target_model_fe, inputs, hf_base):\n",
    "    \"\"\"Final-layer features of the student (with grad) and target (no grad), as float32.\"\"\"\n",
    "    with torch.no_grad():\n",
    "        if hf_base:\n",
    "            target_features = target_model_fe(inputs, output_hidden_states = True).last_hidden_state\n",
    "        else:\n",
    "            target_features = target_model_fe(inputs).squeeze()\n",
    "    student_features = train_model_fe(inputs).squeeze()\n",
    "    return student_features.to(torch.float32), target_features.to(torch.float32)\n",
    "\n",
    "def _sum_scores(scores):\n",
    "    \"\"\"Sum the scalar-tensor values of a {layer name: score} dict into one tensor.\"\"\"\n",
    "    return torch.sum(torch.stack(list(scores.values())))\n",
    "\n",
    "def rep_similarity_loss(exp_name, train_model, target_model, rep_sim, inputs, device, layerwise, student_model = 'ResNet-50', hf_base = False, one_layer = False, batch_idx = None):\n",
    "    \"\"\"Representational-similarity loss between `train_model` and `target_model`.\n",
    "\n",
    "    Args:\n",
    "        exp_name: experiment name; Ridge checkpoints live in saved_models/{exp_name}/.\n",
    "        train_model: student being trained; gradients flow through its features.\n",
    "        target_model: reference model; its final features are computed under no_grad.\n",
    "        rep_sim: 'CKA', 'Procrustes', 'Ridge', 'SVCCA' or 'PWCCA'.\n",
    "        inputs: batch fed to both models.\n",
    "        device: device handed to the metric objects.\n",
    "        layerwise: CKA only; True compares layers with layerwise_sim, False with the\n",
    "            name-based mapping of layermap_sim.\n",
    "        student_model: forwarded to layermap_sim to pick the layer mapping.\n",
    "        hf_base: target is a Hugging Face model; use its last_hidden_state.\n",
    "        one_layer: compare only final-layer features instead of many layers.\n",
    "        batch_idx: Ridge only; names the per-batch checkpoint file.\n",
    "\n",
    "    Returns:\n",
    "        The loss produced by the selected measure (for CKA, a sum of 1 - CKA terms).\n",
    "\n",
    "    Raises:\n",
    "        NotImplementedError: for an unrecognised `rep_sim`.\n",
    "\n",
    "    NOTE(review): layerwise_sim, Procrustes, CCA, train_ridge, fit_ridge and\n",
    "    ridge_layerwise_sim are not defined in this notebook.\n",
    "    \"\"\"\n",
    "    target_model_fe = target_model if hf_base else wrap_and_flatten_model(target_model)\n",
    "    train_model_fe = wrap_and_flatten_model(train_model)\n",
    "\n",
    "    def final_features():\n",
    "        # Only metrics that use final-layer features pay for these forward passes; an\n",
    "        # unused train-mode forward wastes memory and updates BatchNorm running stats.\n",
    "        return _final_layer_features(train_model_fe, target_model_fe, inputs, hf_base)\n",
    "\n",
    "    if rep_sim == 'CKA':\n",
    "        if one_layer:\n",
    "            student_features, target_features = final_features()\n",
    "            sim = 1 - CKA(device).linear_CKA(student_features, target_features)\n",
    "        elif layerwise:\n",
    "            sim = _sum_scores(layerwise_sim(train_model_fe, target_model_fe, rep_sim, inputs, device))\n",
    "        else:\n",
    "            # map_layers / map_rn_mlp use the original module names ('conv1',\n",
    "            # 'layer1.0.conv1', 'initial_layer.0', ...). wrap_and_flatten_model renames\n",
    "            # modules to their Sequential positions ('0', '4.conv1', ...) and drops the\n",
    "            # head, so layermap_sim gets the unwrapped models.\n",
    "            sim = _sum_scores(layermap_sim(train_model, target_model, student_model, rep_sim, inputs, device))\n",
    "    elif rep_sim == 'Procrustes':\n",
    "        if one_layer:\n",
    "            student_features, target_features = final_features()\n",
    "            sim = Procrustes(device).orthogonal_procrustes_distance(student_features, target_features, normalize = True)\n",
    "        else:\n",
    "            sim = _sum_scores(layerwise_sim(train_model_fe, target_model_fe, rep_sim, inputs, device))\n",
    "    elif rep_sim == 'Ridge':\n",
    "        fit = False\n",
    "        if not fit:\n",
    "            if one_layer:\n",
    "                student_features, target_features = final_features()\n",
    "                # Check, load and save one and the same per-batch checkpoint path.\n",
    "                ckpt_dir = f'saved_models/{exp_name}'\n",
    "                os.makedirs(ckpt_dir, exist_ok = True)\n",
    "                ckpt_path = f'{ckpt_dir}/prev_state_dict_{batch_idx}.pt'\n",
    "                prev_state_dict = torch.load(ckpt_path) if os.path.exists(ckpt_path) else None\n",
    "                sim, state_dict = train_ridge(student_features.detach(), target_features.detach(), device, prev_state_dict = prev_state_dict)\n",
    "                torch.save(state_dict, ckpt_path)\n",
    "            else:\n",
    "                sim = ridge_layerwise_sim(train_model, target_model, inputs, 200, device)\n",
    "        else:\n",
    "            if one_layer:\n",
    "                student_features, target_features = final_features()\n",
    "                # NOTE(review): regresses student features onto the target's; the earlier\n",
    "                # call passed the student's features as both arguments.\n",
    "                sim = fit_ridge(student_features, target_features, device)\n",
    "            else:\n",
    "                sim = _sum_scores(layerwise_sim(train_model_fe, target_model_fe, rep_sim, inputs, device))\n",
    "    elif rep_sim in ('SVCCA', 'PWCCA'):\n",
    "        student_features, target_features = final_features()\n",
    "        cca = CCA(device)\n",
    "        if rep_sim == 'SVCCA':\n",
    "            sim = cca.svcca_distance(student_features, target_features)\n",
    "        else:\n",
    "            sim = cca.pwcca_distance(student_features, target_features)\n",
    "    else:\n",
    "        # Any other name (including other '*CCA*' strings) is unsupported.\n",
    "        raise NotImplementedError\n",
    "    return sim"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "16d34dae-d787-4863-89b7-1090c79c75cf",
   "metadata": {},
   "source": [
    "## Training!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6df2e38a-e3c0-4558-88b9-683d36fb930d",
   "metadata": {},
   "outputs": [],
   "source": [
    "def total_loss(exp_name, train_model, target_model, rep_sim, loss_fn, preds, imgs, labels, rep_sim_alpha, device, layerwise, student_model = 'ResNet-50', one_layer = False, batch_idx = None):\n",
    "    \"\"\"Cross-entropy plus a weighted representational-similarity term.\n",
    "\n",
    "    `rep_sim` names the similarity measure and is forwarded to rep_similarity_loss.\n",
    "\n",
    "    Returns:\n",
    "        (ce_loss + rep_sim_alpha * rep_term, rep_term, ce_loss)\n",
    "    \"\"\"\n",
    "    rep_term = rep_similarity_loss(\n",
    "        exp_name, train_model, target_model, rep_sim, imgs, device, layerwise,\n",
    "        student_model = student_model, one_layer = one_layer, batch_idx = batch_idx,\n",
    "    )\n",
    "    ce_term = loss_fn(preds, labels)\n",
    "    weighted = ce_term + rep_sim_alpha * rep_term\n",
    "    return weighted, rep_term, ce_term"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "035a6f25-39c4-46f9-8ff7-d839f24569f1",
   "metadata": {},
   "outputs": [],
   "source": [
    "def adjust_learning_rate(lr, optimizer, epoch, decay_every = 30, gamma = 0.1):\n",
    "    \"\"\"Step decay: set every param group's lr to lr * gamma ** (epoch // decay_every).\n",
    "\n",
    "    The defaults give the x0.1-every-30-epochs schedule. `lr` is the base rate, so\n",
    "    the result does not depend on what the optimizer currently holds.\n",
    "\n",
    "    Returns:\n",
    "        The learning rate that was applied.\n",
    "    \"\"\"\n",
    "    new_lr = lr * (gamma ** (epoch // decay_every))\n",
    "    for param_group in optimizer.param_groups:\n",
    "        param_group['lr'] = new_lr\n",
    "    return new_lr"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "58f5cb81-a2d1-424f-9fa8-02cfec4f24af",
   "metadata": {},
   "outputs": [],
   "source": [
    "def validate(model, val_loader, loss_fn, device):\n",
    "    \"\"\"Mean of `loss_fn` over the batches of `val_loader`, with `model` in eval mode.\n",
    "\n",
    "    Batches are dicts with 'img' and 'label' tensors. The model is left in eval\n",
    "    mode; the caller switches it back with .train() when needed.\n",
    "\n",
    "    Returns:\n",
    "        Average per-batch loss as a float.\n",
    "    \"\"\"\n",
    "    model = model.eval()\n",
    "    val_loss = 0.0\n",
    "    # Evaluation needs no gradients; without no_grad every batch builds an autograd\n",
    "    # graph and retains its activations for a backward pass that never happens.\n",
    "    with torch.no_grad():\n",
    "        for batch in tqdm(val_loader, desc = 'Iterating over validation batches...'):\n",
    "            imgs, labels = batch['img'].to(device), batch['label'].to(device)\n",
    "            preds = model(imgs)\n",
    "            val_loss += loss_fn(preds, labels).item()\n",
    "    return val_loss / len(val_loader)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3b6d7bd0-63d6-4971-a480-11d5f5518dfa",
   "metadata": {},
   "outputs": [],
   "source": [
    "def avg_step_size(model, before_state_dict):\n",
    "    \"\"\"Mean absolute change of `model`'s tensors since `before_state_dict` was taken.\n",
    "\n",
    "    For every key of `before_state_dict`, mean(|after - before|) is computed against\n",
    "    model.state_dict(); the per-tensor means are then averaged, so each tensor counts\n",
    "    equally regardless of its size.\n",
    "\n",
    "    Args:\n",
    "        model: module after the update.\n",
    "        before_state_dict: {name: tensor} snapshot taken before the update; every key\n",
    "            must be present in model.state_dict().\n",
    "\n",
    "    Returns:\n",
    "        float; 0.0 for an empty snapshot.\n",
    "    \"\"\"\n",
    "    if not before_state_dict:\n",
    "        return 0.0\n",
    "    changes = []\n",
    "    with torch.no_grad():\n",
    "        after_state_dict = model.state_dict()\n",
    "        for key, before in before_state_dict.items():\n",
    "            delta = (after_state_dict[key] - before).abs()\n",
    "            # mean() is undefined for integer tensors, e.g. BatchNorm's\n",
    "            # num_batches_tracked buffer in a full state_dict snapshot.\n",
    "            if not delta.is_floating_point():\n",
    "                delta = delta.float()\n",
    "            changes.append(delta.mean().item())\n",
    "    return sum(changes) / len(changes)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c216235b-40fa-4232-ac07-0ddf37f385f3",
   "metadata": {},
   "outputs": [],
   "source": [
    "def train_resnet(args, exp_name, rep_sim, num_epochs, student_model = 'ResNet-50', target_model = 'rn50', lr = 1e-3, batch_size = 64, num_workers = 4, pretrained = False, rep_dist = None, rep_sim_alpha = 1.0, one_layer = False):\n",
    "    \"\"\"Train a student model on ImageNet, optionally with a representational-similarity loss.\n",
    "\n",
    "    Args:\n",
    "        args: run configuration; args.logging is the log root, and the whole\n",
    "            namespace is dumped to {args.logging}/{exp_name}/args.json.\n",
    "        exp_name: wandb project name and file stem for the saved weights and logs.\n",
    "        rep_sim: truthy to add the representational-similarity loss.\n",
    "        num_epochs: number of training epochs.\n",
    "        student_model: 'ResNet-50' (torchvision) or anything else for ImageNetNarrowMLP.\n",
    "        target_model: 'rn50', 'rn18' or 'vitb'; the reference model for the extra loss.\n",
    "        lr, batch_size, num_workers: optimiser and data-loading settings.\n",
    "        pretrained: load pretrained weights into a ResNet-50 student.\n",
    "        rep_dist: similarity measure forwarded to total_loss (e.g. 'CKA').\n",
    "        rep_sim_alpha: weight of the representational-similarity term.\n",
    "        one_layer: compare only final-layer features.\n",
    "\n",
    "    Returns:\n",
    "        (model, step_train_losses, step_sizes, val_losses, epoch_train_losses,\n",
    "         step_ce_loss, step_rep_sim_loss)\n",
    "    \"\"\"\n",
    "    wandb.init(\n",
    "        project = exp_name,\n",
    "        config = {\n",
    "            'model': student_model,\n",
    "            'target_model': target_model,\n",
    "            'rep-sim': rep_sim,\n",
    "            'dist-func': rep_dist,\n",
    "            'lr': lr,\n",
    "            'batch_size': batch_size,\n",
    "            'epochs': num_epochs\n",
    "        }\n",
    "    )\n",
    "    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')\n",
    "    train_loader, val_loader, _ = get_dataloaders(batch_size, num_workers)\n",
    "\n",
    "    if student_model == 'ResNet-50':\n",
    "        model = torchvision.models.resnet50(pretrained = pretrained).to(device)\n",
    "    else:\n",
    "        model = ImageNetNarrowMLP().to(device)\n",
    "\n",
    "    if rep_sim:\n",
    "        if target_model == 'rn50':\n",
    "            diff = not (student_model == 'NoResNet-50')\n",
    "            # NOTE(review): unlike the other targets this one is not pretrained - confirm intended.\n",
    "            target_model = torchvision.models.resnet50(pretrained = False).to(device)\n",
    "        elif target_model == 'rn18':\n",
    "            diff = True\n",
    "            target_model = torchvision.models.resnet18(pretrained = True).to(device)\n",
    "        elif target_model == 'vitb':\n",
    "            diff = True\n",
    "            target_model = torchvision.models.vit_b_16(pretrained = True).to(device)\n",
    "        else:\n",
    "            raise NotImplementedError\n",
    "        # The target is a fixed reference: eval mode keeps its BatchNorm layers on\n",
    "        # their running statistics (a train-mode forward would overwrite them), and\n",
    "        # it never needs gradients of its own.\n",
    "        target_model.eval()\n",
    "        target_model.requires_grad_(False)\n",
    "\n",
    "    loss_fn = nn.CrossEntropyLoss()\n",
    "    optimizer = optim.Adam(model.parameters(), lr = lr)\n",
    "\n",
    "    epoch_train_losses = []\n",
    "    step_train_losses = []\n",
    "    step_sizes = []\n",
    "    val_losses = []\n",
    "    step_ce_loss = []\n",
    "    step_rep_sim_loss = []\n",
    "\n",
    "    total_steps = len(train_loader) * num_epochs\n",
    "\n",
    "    for epoch in range(num_epochs):\n",
    "        adjust_learning_rate(lr, optimizer, epoch)\n",
    "        avg_val_loss = validate(model, val_loader, loss_fn, device)\n",
    "        wandb.log({'val_loss': avg_val_loss})\n",
    "        print(f'Epoch {epoch}, Validation Loss: {avg_val_loss}')\n",
    "        val_losses.append(avg_val_loss)\n",
    "\n",
    "        model = model.train()\n",
    "        train_loss = 0.0\n",
    "        for i, batch in enumerate(tqdm(train_loader, desc = 'Iterating over training batches...')):\n",
    "            imgs, labels = batch['img'].to(device), batch['label'].to(device)\n",
    "            preds = model(imgs)\n",
    "\n",
    "            if not rep_sim:\n",
    "                loss = loss_fn(preds, labels)\n",
    "                ce_loss = None\n",
    "            else:\n",
    "                # Keep the step's loss under its own name: rebinding `rep_sim` here would\n",
    "                # replace the on/off flag with a tensor, and a step whose loss is exactly\n",
    "                # 0 would silently switch every later step to plain cross-entropy.\n",
    "                loss, rep_sim_loss, ce_loss = total_loss(exp_name, model, target_model, rep_dist, loss_fn, preds, imgs, labels, rep_sim_alpha, device, layerwise = (not diff), student_model = student_model, one_layer = one_layer, batch_idx = i)\n",
    "                step_ce_loss.append(ce_loss.item())\n",
    "                step_rep_sim_loss.append(rep_sim_loss.item())\n",
    "\n",
    "                if i % 20 == 0:\n",
    "                    avg_ce_loss = np.mean(step_ce_loss[-20:])\n",
    "                    avg_rep_sim_loss = np.mean(step_rep_sim_loss[-20:])\n",
    "                    wandb.log({'ce_loss': avg_ce_loss, 'rep_sim_loss': avg_rep_sim_loss})\n",
    "\n",
    "            # Detached copies: only the values are needed, not an autograd node per parameter.\n",
    "            before_update_params = {name: param.detach().clone() for name, param in model.named_parameters()}\n",
    "\n",
    "            optimizer.zero_grad()\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "\n",
    "            # The epoch training loss tracks cross-entropy only, even with the extra term.\n",
    "            train_loss += (loss if ce_loss is None else ce_loss).item()\n",
    "            step_train_losses.append(loss.item())\n",
    "\n",
    "            step_size = avg_step_size(model, before_update_params)\n",
    "            step_sizes.append(step_size)\n",
    "            if i % 20 == 0:\n",
    "                avg_train_loss = np.mean(step_train_losses[-20:])\n",
    "                wandb.log({'train_loss': avg_train_loss, 'step_size': step_size})\n",
    "\n",
    "        avg_train_loss = train_loss/len(train_loader)\n",
    "        print(f'Epoch {epoch + 1}, Training Loss: {avg_train_loss}')\n",
    "        epoch_train_losses.append(avg_train_loss)\n",
    "\n",
    "    final_avg_val_loss = validate(model, val_loader, loss_fn, device)\n",
    "    # num_epochs is the post-loop `epoch + 1` and stays defined when num_epochs == 0.\n",
    "    print(f'Epoch {num_epochs}, Validation Loss: {final_avg_val_loss}')\n",
    "\n",
    "    assert len(step_train_losses) == len(step_sizes) == total_steps\n",
    "    # torch.save and open() do not create missing parent directories.\n",
    "    os.makedirs('saved_models', exist_ok = True)\n",
    "    torch.save(model.state_dict(), f'saved_models/{exp_name}.pt')\n",
    "\n",
    "    # Create and write into the same directory, both derived from `exp_name`.\n",
    "    log_dir = f'{args.logging}/{exp_name}'\n",
    "    os.makedirs(log_dir, exist_ok = True)\n",
    "\n",
    "    with open(f'{log_dir}/args.json', 'w') as f:\n",
    "        json.dump(args.__dict__, f, indent=2)\n",
    "\n",
    "    loss_info = {'step_train_losses': step_train_losses, 'step_sizes': step_sizes, 'val_losses': val_losses, 'epoch_train_losses': epoch_train_losses, 'step_ce_loss': step_ce_loss, 'step_rep_sim_loss': step_rep_sim_loss}\n",
    "    # Drop metric lists that were never filled (e.g. the rep-sim lists when rep_sim is off).\n",
    "    loss_info = {key: value for key, value in loss_info.items() if value != []}\n",
    "    with open(f'{log_dir}/info.json', 'w') as f:\n",
    "        json.dump(loss_info, f)\n",
    "    wandb.finish()\n",
    "    return model, step_train_losses, step_sizes, val_losses, epoch_train_losses, step_ce_loss, step_rep_sim_loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f437904f-13ec-4e7f-88d0-4135e632cecf",
   "metadata": {},
   "outputs": [],
   "source": [
    "def accuracy(output, target, topk=(1,)):\n",
    "    \"\"\"Top-k accuracy, in percent, for each k in `topk`.\n",
    "\n",
    "    Args:\n",
    "        output: (batch, num_classes) scores.\n",
    "        target: (batch,) integer class labels.\n",
    "        topk: tuple of k values.\n",
    "\n",
    "    Returns:\n",
    "        list with one shape-(1,) float tensor per k, in `topk` order.\n",
    "    \"\"\"\n",
    "    with torch.no_grad():\n",
    "        n_samples = target.size(0)\n",
    "        ranked = output.topk(max(topk), dim = 1, largest = True, sorted = True).indices.t()\n",
    "        hits = ranked.eq(target.view(1, -1).expand_as(ranked))\n",
    "        return [hits[:k].reshape(-1).float().sum(0, keepdim=True) * (100.0 / n_samples) for k in topk]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "295c3e87-4e58-4854-9ee7-ed78c528fa79",
   "metadata": {},
   "outputs": [],
   "source": [
    "def eval_loop(model, device, batch_size = 1, num_workers = 4):\n",
    "    \"\"\"Top-1 / top-5 accuracy (percent) of `model` on the ImageNet validation split.\n",
    "\n",
    "    Only the validation dataset is built (get_dataloaders would also construct the\n",
    "    train and test datasets, which are not needed here).\n",
    "\n",
    "    Returns:\n",
    "        (top1_acc, top5_acc) as floats.\n",
    "    \"\"\"\n",
    "    val_loader = data.DataLoader(ImageNetDataset('validation'), batch_size = batch_size, num_workers = num_workers, shuffle = False)\n",
    "    model = model.eval()\n",
    "    top1_acc = 0\n",
    "    top5_acc = 0\n",
    "    total_samples = 0\n",
    "    with torch.no_grad():\n",
    "        for batch in tqdm(val_loader, desc = 'Iterating over validation batches...'):\n",
    "            img, label = batch['img'].to(device), batch['label'].to(device)\n",
    "            outputs = model(img)\n",
    "            acc1, acc5 = accuracy(outputs, label, topk=(1, 5))\n",
    "            # accuracy() returns per-batch percentages; weight them by batch size so a\n",
    "            # smaller final batch is averaged correctly.\n",
    "            top1_acc += acc1.item() * img.size(0)\n",
    "            top5_acc += acc5.item() * img.size(0)\n",
    "            total_samples += img.size(0)\n",
    "    top1_acc /= total_samples\n",
    "    top5_acc /= total_samples\n",
    "    return top1_acc, top5_acc\n",
    "\n",
    "def eval_resnet(exp_name, pretrained = False):\n",
    "    \"\"\"Print the top-1 / top-5 validation accuracy of a ResNet-50.\n",
    "\n",
    "    pretrained=True evaluates torchvision's pretrained weights; otherwise the\n",
    "    state dict saved at ../saved_models/{exp_name}.pt is loaded.\n",
    "    \"\"\"\n",
    "    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')\n",
    "    model = torchvision.models.resnet50(pretrained = bool(pretrained))\n",
    "    if not pretrained:\n",
    "        # map_location lets a checkpoint saved from a GPU load on a CPU-only machine.\n",
    "        # NOTE(review): train_resnet saves to saved_models/ (no '../'); confirm the\n",
    "        # working directory each function is run from.\n",
    "        model.load_state_dict(torch.load(f'../saved_models/{exp_name}.pt', map_location = device))\n",
    "    model = model.to(device)\n",
    "\n",
    "    top1_acc, top5_acc = eval_loop(model, device)\n",
    "    print(f'Top-1 Accuracy: {top1_acc:.2f}%')\n",
    "    print(f'Top-5 Accuracy: {top5_acc:.2f}%')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "50025f6b-3341-40b5-a287-e59ca1332bc4",
   "metadata": {},
   "source": [
    "## Main Running Cell"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fb3cff31-2fdf-424b-a351-9a4094803c41",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Seed torch (CPU and every CUDA device) and NumPy's global RNG.\n",
    "SEED = 0\n",
    "torch.manual_seed(SEED)\n",
    "np.random.seed(SEED)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9989e72c-b632-4c09-ab4f-da795818e9e4",
   "metadata": {},
   "outputs": [],
   "source": [
    "args = SimpleNamespace(\n",
    "    exp_name = 'DeepMLP',\n",
    "    eval = False,\n",
    "    student_model = 'DeepMLP',\n",
    "\n",
    "    rep_sim = False,  # set to True to add the representational-similarity loss\n",
    "    repdist = 'CKA',  # similarity measure used by that loss\n",
    "    target_model = 'rn18',\n",
    "\n",
    "    num_workers = 4,\n",
    "    batch_size = 256,\n",
    "    lr = 1e-3,\n",
    "    num_epochs = 2,\n",
    "    alpha = 1.0,  # weight of the representational-similarity term in the loss\n",
    "    one_layer = False,\n",
    "    pretrained = False,  # whether the student starts from pretrained weights (keep False)\n",
    "    logging = '../logs',\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d34068ac-893a-41d9-a661-324664596485",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Unpack the results so later cells can use them, and so the notebook does not\n",
    "# render the returned tuple (model repr plus every per-step loss list) as output.\n",
    "(model, step_train_losses, step_sizes, val_losses,\n",
    " epoch_train_losses, step_ce_loss, step_rep_sim_loss) = train_resnet(\n",
    "    args = args,\n",
    "    exp_name = args.exp_name,\n",
    "    rep_sim = args.rep_sim,\n",
    "    num_epochs = args.num_epochs,\n",
    "    student_model = args.student_model,\n",
    "    target_model = args.target_model,\n",
    "    lr = args.lr,\n",
    "    batch_size = args.batch_size,\n",
    "    num_workers = args.num_workers,\n",
    "    pretrained = args.pretrained,\n",
    "    rep_dist = args.repdist,\n",
    "    rep_sim_alpha = args.alpha,\n",
    "    one_layer = args.one_layer,\n",
    ")"
   ]
  },
  {
   "cell_type": "raw",
   "id": "cba32758-2d9a-4127-9f77-c391cd814314",
   "metadata": {},
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.19"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
