{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c2f843da",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-02-24T05:47:51.395110Z",
     "iopub.status.busy": "2024-02-24T05:47:51.394900Z",
     "iopub.status.idle": "2024-02-24T05:47:54.580552Z",
     "shell.execute_reply": "2024-02-24T05:47:54.580019Z",
     "shell.execute_reply.started": "2024-02-24T05:47:51.395063Z"
    }
   },
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "\n",
    "import copy\n",
    "import os\n",
    "import random\n",
    "import time\n",
    "from functools import partial\n",
    "from typing import Dict, List\n",
    "from warnings import warn\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from mpl_toolkits import axes_grid1\n",
    "from torch import optim\n",
    "from torch.utils.data import DataLoader, RandomSampler\n",
    "from torchvision import models\n",
    "from tqdm import tqdm\n",
    "\n",
    "from network import mnist_net, res_net, alex_net, cifar_net, vit_net, reg_net, generator\n",
    "from network.modules import get_resnet, get_generator, freeze, unfreeze, freeze_, unfreeze_, LARS\n",
    "from tools.miro_utils import *\n",
    "from tools.farmer import *\n",
    "import data_loader\n",
    "from main_base import evaluate\n",
    "from main_test import evaluate_digit, evaluate_image, evaluate_pacs, evaluate_officehome, evaluate_vlcs\n",
    "\n",
    "from tools.cka import CKA"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3cdbac1e-c214-4141-a607-1013ff1fba5f",
   "metadata": {},
   "source": [
    "# Digits"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6bb6663e-6d37-44f3-b2f1-d0261d619ec6",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-02-23T08:29:25.806182Z",
     "iopub.status.busy": "2024-02-23T08:29:25.805896Z",
     "iopub.status.idle": "2024-02-23T08:29:25.891841Z",
     "shell.execute_reply": "2024-02-23T08:29:25.891375Z",
     "shell.execute_reply.started": "2024-02-23T08:29:25.806159Z"
    }
   },
   "outputs": [],
   "source": [
    "# Seed the Python, PyTorch and NumPy RNGs with one shared value for reproducibility.\n",
    "SEED = 1111\n",
    "for seed_fn in (random.seed, torch.manual_seed, np.random.seed):\n",
    "    seed_fn(SEED)\n",
    "\n",
    "# Training split from data_loader.load_mnist, iterated in a fixed order (shuffle=False);\n",
    "# this loader is what the CKA comparison below consumes.\n",
    "trset = data_loader.load_mnist(split='train')\n",
    "dataloader = DataLoader(trset, batch_size=32, num_workers=8, shuffle=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2d491fe3-adca-41d3-b6a1-a6e734bb6b01",
   "metadata": {},
   "source": [
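    "## PEER\n",
    "\n",
    "Layer-wise CKA similarity between the task model and the auxiliary model, computed on the MNIST training split."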
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a5ed8cd8-57f5-4507-99be-b72f2413c9e8",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-02-23T08:29:40.073664Z",
     "iopub.status.busy": "2024-02-23T08:29:40.073388Z",
     "iopub.status.idle": "2024-02-23T08:29:40.124807Z",
     "shell.execute_reply": "2024-02-23T08:29:40.124389Z",
     "shell.execute_reply.started": "2024-02-23T08:29:40.073642Z"
    }
   },
   "outputs": [],
   "source": [
    "# Auxiliary model (passed to CKA as \"Auxiliary Model\"): a ConvNet(128) restored from\n",
    "# the 'cls_net' entry of its checkpoint dict.\n",
    "# eval() returns the module itself; it switches dropout/batch-norm layers (if ConvNet has any)\n",
    "# to inference behaviour so extracted features are deterministic.\n",
    "src_net = mnist_net.ConvNet(128).eval()\n",
    "ckpt = \"./saved-model/peer/mnist/cnn_img_custom_mdar_False_128_lmda0.0051_oracleTrue_self_digits_2.0_run0/99-best.pkl\"\n",
    "# map_location='cpu' lets the checkpoint deserialize even on a machine without CUDA.\n",
    "# NOTE(review): torch.load unpickles the file - only load trusted checkpoints.\n",
    "saved_weight = torch.load(ckpt, map_location='cpu')\n",
    "src_net.load_state_dict(saved_weight['cls_net'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0e5ddb22-22e8-413b-b256-0f321235b76d",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-02-23T08:29:41.888940Z",
     "iopub.status.busy": "2024-02-23T08:29:41.888610Z",
     "iopub.status.idle": "2024-02-23T08:29:41.945741Z",
     "shell.execute_reply": "2024-02-23T08:29:41.945288Z",
     "shell.execute_reply.started": "2024-02-23T08:29:41.888912Z"
    }
   },
   "outputs": [],
   "source": [
    "# Task model (passed to CKA as \"Task Model\"): a ConvNet(128) restored from\n",
    "# the 'cls_net' entry of its checkpoint dict.\n",
    "oracleckpt = \"./saved-model/worship/mnist/cnn_img_custom_mdar_False_128_lmda0.0051_oracleTrue_self_digits_2.0_run0/peer-best.pkl\"\n",
    "# eval() returns the module itself; it switches dropout/batch-norm layers (if ConvNet has any)\n",
    "# to inference behaviour so extracted features are deterministic.\n",
    "oracle_net = mnist_net.ConvNet(128).eval()\n",
    "# map_location='cpu' lets the checkpoint deserialize even on a machine without CUDA.\n",
    "# NOTE(review): torch.load unpickles the file - only load trusted checkpoints.\n",
    "o_saved_weight = torch.load(oracleckpt, map_location='cpu')\n",
    "oracle_net.load_state_dict(o_saved_weight['cls_net'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6b2eda36-f44c-4f84-af26-8800ff014151",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-02-23T08:29:44.167361Z",
     "iopub.status.busy": "2024-02-23T08:29:44.167103Z",
     "iopub.status.idle": "2024-02-23T08:41:45.464668Z",
     "shell.execute_reply": "2024-02-23T08:41:45.464064Z",
     "shell.execute_reply.started": "2024-02-23T08:29:44.167340Z"
    }
   },
   "outputs": [],
   "source": [
    "# Layers probed in both networks (same names, since both are ConvNet(128)).\n",
    "# NOTE(review): '' presumably matches the root module (PyTorch's named_modules() names it ''),\n",
    "# i.e. the whole model's output - confirm that is intended.\n",
    "layers = ['', 'conv1', 'mp', 'relu1', 'fc1', 'relu3', 'cls_head']\n",
    "cka = CKA(\n",
    "    oracle_net,\n",
    "    src_net,\n",
    "    model1_name=\"Task Model\",\n",
    "    model2_name=\"Auxiliary Model\",\n",
    "    model1_layers=layers,\n",
    "    model2_layers=layers,\n",
    "    device='cuda:0',\n",
    ")\n",
    "\n",
    "# Run the comparison over the MNIST loader (about 12 minutes in the recorded run), then export the results.\n",
    "cka.compare(dataloader)\n",
    "results = cka.export()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fa6a72ee-60e3-4ba3-9b3e-80f4cc11aa2e",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-02-23T08:41:45.465882Z",
     "iopub.status.busy": "2024-02-23T08:41:45.465665Z",
     "iopub.status.idle": "2024-02-23T08:41:45.910900Z",
     "shell.execute_reply": "2024-02-23T08:41:45.910462Z",
     "shell.execute_reply.started": "2024-02-23T08:41:45.465860Z"
    }
   },
   "outputs": [],
   "source": [
    "# Create the output directory first: if plot_results writes the file directly (e.g. via\n",
    "# matplotlib's savefig), a missing ./CKA directory would make it fail on a fresh checkout.\n",
    "cka_plot_path = \"./CKA/cka_peer_task_99.pdf\"\n",
    "os.makedirs(os.path.dirname(cka_plot_path), exist_ok=True)\n",
    "cka.plot_results(save_path=cka_plot_path)\n",
    "# `results` (from cka.export()) can be persisted with torch.save; that writes a pickle, not JSON."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "peer",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
