{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import copy\n",
    "import easydict\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "\n",
    "import matplotlib\n",
    "from tqdm import tqdm\n",
    "from maml.utils import load_dataset, load_model, update_parameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "matplotlib.rcParams.update({'font.size': 20})"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "---"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 1000 episodes test"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def make_sample_task(dataset):\n",
    "    sample_task = dataset.sample_task()\n",
    "    for idx, (image, label) in enumerate(sample_task['train']):\n",
    "        if idx == 0:\n",
    "            s_images = image.unsqueeze(0)\n",
    "            s_labels = [label]\n",
    "            s_real_labels = [sample_task['train'].index[label]]\n",
    "        else:\n",
    "            s_images = torch.cat([s_images, image.unsqueeze(0)], dim=0)\n",
    "            s_labels.append(label)\n",
    "            s_real_labels.append(sample_task['train'].index[label])\n",
    "    \n",
    "    for idx, (image, label) in enumerate(sample_task['test']):\n",
    "        if idx == 0:\n",
    "            q_images = image.unsqueeze(0)\n",
    "            q_labels = [label]\n",
    "            q_real_labels = [sample_task['test'].index[label]]\n",
    "        else:\n",
    "            q_images = torch.cat([q_images, image.unsqueeze(0)], dim=0)\n",
    "            q_labels.append(label)\n",
    "            q_real_labels.append(sample_task['test'].index[label])\n",
    "    \n",
    "    s_labels = torch.tensor(s_labels).type(torch.LongTensor)\n",
    "    s_real_labels = torch.tensor(s_real_labels).type(torch.LongTensor)\n",
    "    q_labels = torch.tensor(q_labels).type(torch.LongTensor)\n",
    "    q_real_labels = torch.tensor(q_real_labels).type(torch.LongTensor)\n",
    "    return [s_images, s_labels, s_real_labels, q_images, q_labels, q_real_labels]\n",
    "\n",
    "def isfloat(value):\n",
    "    try:\n",
    "        float(value)\n",
    "        return True\n",
    "    except ValueError:\n",
    "        return False\n",
    "\n",
    "def get_arguments(path, dataset, save_name):\n",
    "    filename = '{}/{}_{}/logs/arguments.txt'.format(path, dataset, save_name)\n",
    "\n",
    "    args = easydict.EasyDict()\n",
    "    with open(filename) as f:\n",
    "        for line in f:\n",
    "            key, val = line.split(\": \")\n",
    "            if '\\n' in val:\n",
    "                val = val[:-1]\n",
    "            if isfloat(val):\n",
    "                if val.isdigit():\n",
    "                    val = int(val)\n",
    "                else:\n",
    "                    val = float(val)\n",
    "            if val == 'True' or val == 'False':\n",
    "                val = val == 'True'\n",
    "            args[key] = val\n",
    "    return args\n",
    "\n",
    "def print_accuracy(args, test_dataset, sample_tasks, iteration, NIL_testing=False):\n",
    "    device = torch.device(args.device)\n",
    "    sample_number = len(sample_tasks)\n",
    "    \n",
    "    index = ['task{}'.format(str(i+1)) for i in range(sample_number)]\n",
    "    columns = []\n",
    "    columns += ['Accuracy on support set (before adaptation)', 'Accuracy on query set (before adaptation)']\n",
    "    columns += ['Accuracy on support set (after adaptation)', 'Accuracy on query set (after adaptation)']\n",
    "    \n",
    "    if NIL_testing:\n",
    "        filename_pd = '{}/{}_{}/logs/{}_nil_results_{}.csv'.format(args.output_folder, args.dataset, args.save_name, test_dataset, iteration)\n",
    "    else:\n",
    "        filename_pd = '{}/{}_{}/logs/{}_results_{}.csv'.format(args.output_folder, args.dataset, args.save_name, test_dataset, iteration)\n",
    "    test_pd = pd.DataFrame(np.zeros([sample_number, len(columns)]), index=index, columns=columns)\n",
    "    \n",
    "    model = load_model(args)\n",
    "    \n",
    "    if args.model == '4conv':\n",
    "        checkpoint = '{}/{}_{}/models/epochs_30000.pt'.format(args.output_folder, args.dataset, args.save_name)\n",
    "    elif args.model == 'resnet':\n",
    "        checkpoint = '{}/{}_{}/models/epochs_10000.pt'.format(args.output_folder, args.dataset, args.save_name)\n",
    "        \n",
    "    checkpoint = torch.load(checkpoint, map_location=device)\n",
    "    \n",
    "    for idx in tqdm(range(sample_number)):\n",
    "        task_log = []\n",
    "               \n",
    "        model.load_state_dict(checkpoint, strict=True)\n",
    "        model.to(device)\n",
    "\n",
    "        support_input = sample_tasks[idx][0].to(device)\n",
    "        support_target = sample_tasks[idx][1].to(device)\n",
    "        support_real_target = sample_tasks[idx][2]\n",
    "        query_input = sample_tasks[idx][3].to(device)\n",
    "        query_target = sample_tasks[idx][4].to(device)\n",
    "        query_real_target = sample_tasks[idx][5]\n",
    "        \n",
    "        model.train()\n",
    "        \n",
    "        # before adaptation\n",
    "        support_features, support_logit = model(support_input)\n",
    "        _, support_pred_target = torch.max(support_logit, dim=1)\n",
    "                \n",
    "        query_features, query_logit = model(query_input)\n",
    "        _, query_pred_target = torch.max(query_logit, dim=1)\n",
    "        \n",
    "        if NIL_testing:\n",
    "            cos = nn.CosineSimilarity()\n",
    "            support_features_mean = torch.zeros([args.num_ways, support_features.shape[1]]).to(device)\n",
    "            support_target_mean = torch.zeros([args.num_ways]).to(device)\n",
    "            for label in range(args.num_ways):\n",
    "                support_features_mean[label] = torch.mean(support_features[torch.where(support_target==label)], dim=0)\n",
    "                support_target_mean[label] = label\n",
    "\n",
    "            distance = torch.zeros([len(query_features), len(support_features_mean)])\n",
    "            for i, query_feature in enumerate(query_features):\n",
    "                distance[i] = cos(torch.cat([query_feature.unsqueeze(0)]*len(support_features_mean)), support_features_mean)\n",
    "            top_similar_idx = torch.argmax(distance, dim=1)\n",
    "        \n",
    "        task_log.append((sum(support_target==support_pred_target)/float(len(support_target))).item())\n",
    "        task_log.append((sum(query_target==query_pred_target)/float(len(query_target))).item())\n",
    "        \n",
    "        # after adaptation\n",
    "        inner_loss = F.cross_entropy(support_logit, support_target)\n",
    "        model.zero_grad()\n",
    "        \n",
    "        params = update_parameters(model, inner_loss, extractor_step_size=args.extractor_step_size, classifier_step_size=args.classifier_step_size, first_order=args.first_order)\n",
    "        \n",
    "        support_features, support_logit = model(support_input, params=params)\n",
    "        _, support_pred_target = torch.max(support_logit, dim=1)\n",
    "        \n",
    "        query_features, query_logit = model(query_input, params=params)\n",
    "        _, query_pred_target = torch.max(query_logit, dim=1)\n",
    "        \n",
    "        if NIL_testing:            \n",
    "            cos = nn.CosineSimilarity()\n",
    "            support_features_mean = torch.zeros([args.num_ways, support_features.shape[1]]).to(device)\n",
    "            support_target_mean = torch.zeros([args.num_ways]).to(device)\n",
    "            for label in range(args.num_ways):\n",
    "                support_features_mean[label] = torch.mean(support_features[torch.where(support_target==label)], dim=0)\n",
    "                support_target_mean[label] = label\n",
    "\n",
    "            distance = torch.zeros([len(query_features), len(support_features_mean)])\n",
    "            for i, query_feature in enumerate(query_features):\n",
    "                distance[i] = cos(torch.cat([query_feature.unsqueeze(0)]*len(support_features_mean)), support_features_mean)\n",
    "            top_similar_idx = torch.argmax(distance, dim=1)\n",
    "\n",
    "            query_pred_target = support_target_mean[top_similar_idx]\n",
    "            \n",
    "        task_log.append((sum(support_target==support_pred_target)/float(len(support_target))).item())\n",
    "        task_log.append((sum(query_target==query_pred_target)/float(len(query_target))).item())\n",
    "\n",
    "        test_pd.iloc[idx] = task_log\n",
    "    test_pd.loc[sample_number+1], test_pd.loc[sample_number+2] = test_pd.mean(axis=0), test_pd.std(axis=0)\n",
    "    test_pd.index = list(test_pd.index[:sample_number]) + ['mean', 'std']\n",
    "    test_pd.to_csv(filename_pd)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "sample_number = 1000\n",
    "dataset = [('miniimagenet', 'miniimagenet')]\n",
    "\n",
    "model = '4conv'\n",
    "path = './output'\n",
    "\n",
    "for train_dataset, test_dataset in dataset:\n",
    "    for num_shots in [1, 5]:\n",
    "        for iteration in [1,2,3,4,5]:\n",
    "            dataset_args = easydict.EasyDict({'folder': './data',\n",
    "                                              'dataset': test_dataset,\n",
    "                                              'num_ways': 5,\n",
    "                                              'num_shots': num_shots,\n",
    "                                              'download': True})\n",
    "\n",
    "            sample_tasks = [make_sample_task(load_dataset(dataset_args, 'meta_test')) for _ in tqdm(range(sample_number))]\n",
    "                        \n",
    "            for algorithm in ['MAML', 'BOIL']:\n",
    "                save_name = '{}shot_{}_{}'.format(num_shots, model, algorithm)\n",
    "                args = get_arguments(path, train_dataset, save_name)\n",
    "                \n",
    "                print_accuracy(args, test_dataset, sample_tasks, iteration=iteration, NIL_testing=False)\n",
    "                print_accuracy(args, test_dataset, sample_tasks, iteration=iteration, NIL_testing=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "---"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# feature space, logit space (Cosine Similarity)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_features_logits(args, sample_task, pos):\n",
    "    device = torch.device(args.device)\n",
    "    \n",
    "    model = load_model(args)\n",
    "    if args.model == '4conv':\n",
    "        checkpoint = '{}/{}_{}/models/epochs_30000.pt'.format(args.output_folder, args.dataset, args.save_name)\n",
    "    if args.model == 'resnet':\n",
    "        checkpoint = '{}/{}_{}/models/epochs_10000.pt'.format(args.output_folder, args.dataset, args.save_name)\n",
    "    checkpoint = torch.load(checkpoint, map_location=device)\n",
    "               \n",
    "    model.load_state_dict(checkpoint, strict=True)\n",
    "    model.to(device)\n",
    "\n",
    "    support_input = sample_task[0].to(device)\n",
    "    support_target = sample_task[1].to(device)\n",
    "    support_real_target = sample_task[2]\n",
    "    query_input = sample_task[3].to(device)\n",
    "    query_target = sample_task[4].to(device)\n",
    "    query_real_target = sample_task[5]\n",
    "    \n",
    "    model.train()\n",
    "\n",
    "    # before adaptation\n",
    "    before_support_features, before_support_logits = model(support_input)\n",
    "    if args.model == '4conv':\n",
    "        if pos == 1:\n",
    "            before_query_features = model.features[0](query_input)\n",
    "        elif pos == 2:\n",
    "            before_query_features = model.features[1](model.features[0](query_input))\n",
    "        elif pos == 3:\n",
    "            before_query_features = model.features[2](model.features[1](model.features[0](query_input)))\n",
    "        elif pos == 4:\n",
    "            before_query_features = model.features[3](model.features[2](model.features[1](model.features[0](query_input))))\n",
    "        elif pos == 5:\n",
    "            before_query_features = model.classifier(model.features[3](model.features[2](model.features[1](model.features[0](query_input)))).view(75, -1))\n",
    "    elif args.model == 'resnet':\n",
    "        if pos == 1:\n",
    "            before_query_features = model.layer1(query_input)\n",
    "        elif pos == 2:\n",
    "            before_query_features = model.layer2(model.layer1(query_input))\n",
    "        elif pos == 3:\n",
    "            before_query_features = model.layer3(model.layer2(model.layer1(query_input)))\n",
    "        elif pos == 4:\n",
    "            before_query_features = model.layer4(model.layer3(model.layer2(model.layer1(query_input))))\n",
    "        elif pos == 5:\n",
    "            before_query_features = model.classifier(F.avg_pool2d(model.layer4(model.layer3(model.layer2(model.layer1(query_input)))), 5).view(75, -1))\n",
    "    \n",
    "    before_query_features = before_query_features.view(75, -1)\n",
    "    \n",
    "    # after adaptation\n",
    "    inner_loss = F.cross_entropy(before_support_logits, support_target)\n",
    "    \n",
    "    model.zero_grad()\n",
    "    params = update_parameters(model, inner_loss, extractor_step_size=args.extractor_step_size, classifier_step_size=args.classifier_step_size, first_order=args.first_order)\n",
    "    model.load_state_dict(params, strict=True)\n",
    "    \n",
    "    if args.model == '4conv':\n",
    "        if pos == 1:\n",
    "            after_query_features = model.features[0](query_input)\n",
    "        elif pos == 2:\n",
    "            after_query_features = model.features[1](model.features[0](query_input))\n",
    "        elif pos == 3:\n",
    "            after_query_features = model.features[2](model.features[1](model.features[0](query_input)))\n",
    "        elif pos == 4:\n",
    "            after_query_features = model.features[3](model.features[2](model.features[1](model.features[0](query_input))))\n",
    "        elif pos == 5:\n",
    "            after_query_features = model.classifier(model.features[3](model.features[2](model.features[1](model.features[0](query_input)))).view(75, -1))\n",
    "    elif args.model == 'resnet':\n",
    "        if pos == 1:\n",
    "            after_query_features = model.layer1(query_input)\n",
    "        elif pos == 2:\n",
    "            after_query_features = model.layer2(model.layer1(query_input))\n",
    "        elif pos == 3:\n",
    "            after_query_features = model.layer3(model.layer2(model.layer1(query_input)))\n",
    "        elif pos == 4:\n",
    "            after_query_features = model.layer4(model.layer3(model.layer2(model.layer1(query_input))))\n",
    "        elif pos == 5:\n",
    "            after_query_features = model.classifier(F.avg_pool2d(model.layer4(model.layer3(model.layer2(model.layer1(query_input)))), 5).view(75, -1))\n",
    "    \n",
    "    after_query_features = after_query_features.view(75, -1)\n",
    "    \n",
    "    return (before_query_features.unsqueeze(0).detach().cpu(), after_query_features.unsqueeze(0).detach().cpu())\n",
    "\n",
    "def get_similarity(outputs):\n",
    "    distance = torch.tensor([])\n",
    "    cos = nn.CosineSimilarity()\n",
    "    for i in range(len(outputs)):\n",
    "        tmp_distance = torch.zeros([len(outputs[i]), len(outputs[i])])\n",
    "        for j, output in enumerate(outputs[i]):\n",
    "            tmp_distance[j] = cos(torch.cat([output.unsqueeze(0)]*len(outputs[i])), outputs[i])\n",
    "        distance = torch.cat([distance, tmp_distance.unsqueeze(0)], dim=0)\n",
    "    return distance\n",
    "\n",
    "def get_mean(similarity_matrices):\n",
    "    num_images = 15\n",
    "    same_class = []\n",
    "    different_class = []\n",
    "    \n",
    "    for i in range(len(similarity_matrices)):\n",
    "        similarity_matrices[i][range(5*num_images), range(5*num_images)] = 0\n",
    "        \n",
    "        same_class_distance = torch.zeros([5*num_images, 5*num_images])\n",
    "        same_class_distance[0*num_images:1*num_images, 0*num_images:1*num_images] = similarity_matrices[i][0*num_images:1*num_images, 0*num_images:1*num_images]\n",
    "        same_class_distance[1*num_images:2*num_images, 1*num_images:2*num_images] = similarity_matrices[i][1*num_images:2*num_images, 1*num_images:2*num_images]\n",
    "        same_class_distance[2*num_images:3*num_images, 2*num_images:3*num_images] = similarity_matrices[i][2*num_images:3*num_images, 2*num_images:3*num_images]\n",
    "        same_class_distance[3*num_images:4*num_images, 3*num_images:4*num_images] = similarity_matrices[i][3*num_images:4*num_images, 3*num_images:4*num_images]\n",
    "        same_class_distance[4*num_images:5*num_images, 4*num_images:5*num_images] = similarity_matrices[i][4*num_images:5*num_images, 4*num_images:5*num_images]\n",
    "\n",
    "        different_class_distance = similarity_matrices[i] - same_class_distance\n",
    "        \n",
    "        same_class.append((torch.sum(same_class_distance) / len(same_class_distance.nonzero())).item())\n",
    "        different_class.append((torch.sum(different_class_distance) / len(different_class_distance.nonzero())).item())\n",
    "        \n",
    "    return same_class, different_class"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset = 'miniimagenet'\n",
    "num_shots = 5\n",
    "dataset_args = easydict.EasyDict({'folder': './data',\n",
    "                                  'dataset': dataset,\n",
    "                                  'num_ways': 5,\n",
    "                                  'num_shots': num_shots,\n",
    "                                  'download': True})\n",
    "\n",
    "sample_task = make_sample_task(load_dataset(dataset_args, 'meta_train'))\n",
    "\n",
    "model = '4conv'\n",
    "path = './output'\n",
    "algorithms = ['MAML', 'BOIL']\n",
    "\n",
    "for algorithm in algorithms:\n",
    "    save_name = '{}shot_{}_{}'.format(num_shots, model, algorithm)\n",
    "    args = get_arguments(path, dataset, save_name)\n",
    "    \n",
    "    fig, axes = plt.subplots(1, 2, sharey=True, figsize=(10, 3))\n",
    "    \n",
    "    axes[0].set_title('Before inner loop adatation', fontsize=16)\n",
    "    axes[0].set_ylim([0.0-0.05, 1.0+0.05])\n",
    "    axes[0].tick_params(axis='both', which='major', labelsize=16)\n",
    "    axes[0].grid(True)\n",
    "\n",
    "    axes[1].set_title('After inner loop adatation', fontsize=16)\n",
    "    axes[1].set_ylim([0.0-0.05, 1.0+0.05])\n",
    "    axes[1].tick_params(axis='both', which='major', labelsize=16)\n",
    "    axes[1].grid(True)\n",
    "    \n",
    "    before_different_list = []\n",
    "    before_same_list = []\n",
    "    \n",
    "    after_different_list = []\n",
    "    after_same_list = []\n",
    "    \n",
    "    if model == '4conv':\n",
    "        xrange = ['conv1', 'conv2', 'conv3', 'conv4']\n",
    "        pos_list = [1,2,3,4]\n",
    "    elif model == 'resnet':\n",
    "        xrange = ['block1', 'block2', 'block3', 'block4']\n",
    "        pos_list = [1,2,3,4]\n",
    "    \n",
    "    for pos in pos_list:\n",
    "        before_f, after_f = get_features_logits(args, sample_task, pos=pos)\n",
    "\n",
    "        before = get_similarity(outputs=before_f)\n",
    "        after = get_similarity(outputs=after_f)\n",
    "\n",
    "        before_same_class, before_different_class = get_mean(before)\n",
    "        after_same_class, after_different_class = get_mean(after)\n",
    "        \n",
    "        before_different_list.append(before_different_class)\n",
    "        before_same_list.append(before_same_class)\n",
    "        \n",
    "        after_different_list.append(after_different_class)\n",
    "        after_same_list.append(after_same_class)\n",
    "            \n",
    "    axes[0].plot(xrange, before_different_list, marker='o')\n",
    "    axes[0].plot(xrange, before_same_list, marker='o')\n",
    "\n",
    "    axes[1].plot(xrange, after_different_list, marker='o')\n",
    "    axes[1].plot(xrange, after_same_list, marker='o')\n",
    "        \n",
    "    plt.show()\n",
    "    plt.subplots_adjust(wspace=0.2)\n",
    "#     plt.savefig('./src/{}_cosine.pdf'.format(algorithm), bbox_inches='tight', format='pdf')\n",
    "    plt.close()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "---"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# CCA/CKA"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def gram_linear(x):\n",
    "    \"\"\"Compute Gram (kernel) matrix for a linear kernel.\n",
    "\n",
    "    Args:\n",
    "    x: A num_examples x num_features matrix of features.\n",
    "\n",
    "    Returns:\n",
    "    A num_examples x num_examples Gram matrix of examples.\n",
    "    \"\"\"\n",
    "    return x.dot(x.T)\n",
    "\n",
    "\n",
    "def gram_rbf(x, threshold=1.0):\n",
    "    \"\"\"Compute Gram (kernel) matrix for an RBF kernel.\n",
    "\n",
    "    Args:\n",
    "    x: A num_examples x num_features matrix of features.\n",
    "    threshold: Fraction of median Euclidean distance to use as RBF kernel\n",
    "      bandwidth. (This is the heuristic we use in the paper. There are other\n",
    "      possible ways to set the bandwidth; we didn't try them.)\n",
    "\n",
    "    Returns:\n",
    "    A num_examples x num_examples Gram matrix of examples.\n",
    "    \"\"\"\n",
    "    dot_products = x.dot(x.T)\n",
    "    sq_norms = np.diag(dot_products)\n",
    "    sq_distances = -2 * dot_products + sq_norms[:, None] + sq_norms[None, :]\n",
    "    sq_median_distance = np.median(sq_distances)\n",
    "    return np.exp(-sq_distances / (2 * threshold ** 2 * sq_median_distance))\n",
    "\n",
    "\n",
    "def center_gram(gram, unbiased=False):\n",
    "    \"\"\"Center a symmetric Gram matrix.\n",
    "\n",
    "    This is equvialent to centering the (possibly infinite-dimensional) features\n",
    "    induced by the kernel before computing the Gram matrix.\n",
    "\n",
    "    Args:\n",
    "    gram: A num_examples x num_examples symmetric matrix.\n",
    "    unbiased: Whether to adjust the Gram matrix in order to compute an unbiased\n",
    "      estimate of HSIC. Note that this estimator may be negative.\n",
    "\n",
    "    Returns:\n",
    "    A symmetric matrix with centered columns and rows.\n",
    "    \"\"\"\n",
    "    if not np.allclose(gram, gram.T):\n",
    "        raise ValueError('Input must be a symmetric matrix.')\n",
    "    gram = gram.copy()\n",
    "\n",
    "    if unbiased:\n",
    "        # This formulation of the U-statistic, from Szekely, G. J., & Rizzo, M.\n",
    "        # L. (2014). Partial distance correlation with methods for dissimilarities.\n",
    "        # The Annals of Statistics, 42(6), 2382-2412, seems to be more numerically\n",
    "        # stable than the alternative from Song et al. (2007).\n",
    "        n = gram.shape[0]\n",
    "        np.fill_diagonal(gram, 0)\n",
    "        means = np.sum(gram, 0, dtype=np.float64) / (n - 2)\n",
    "        means -= np.sum(means) / (2 * (n - 1))\n",
    "        gram -= means[:, None]\n",
    "        gram -= means[None, :]\n",
    "        np.fill_diagonal(gram, 0)\n",
    "    else:\n",
    "        means = np.mean(gram, 0, dtype=np.float64)\n",
    "        means -= np.mean(means) / 2\n",
    "        gram -= means[:, None]\n",
    "        gram -= means[None, :]\n",
    "\n",
    "    return gram\n",
    "\n",
    "def cka(gram_x, gram_y, debiased=False):\n",
    "    \"\"\"Compute CKA.\n",
    "\n",
    "    Args:\n",
    "    gram_x: A num_examples x num_examples Gram matrix.\n",
    "    gram_y: A num_examples x num_examples Gram matrix.\n",
    "    debiased: Use unbiased estimator of HSIC. CKA may still be biased.\n",
    "\n",
    "    Returns:\n",
    "    The value of CKA between X and Y.\n",
    "    \"\"\"\n",
    "    gram_x = center_gram(gram_x, unbiased=debiased)\n",
    "    gram_y = center_gram(gram_y, unbiased=debiased)\n",
    "\n",
    "    # Note: To obtain HSIC, this should be divided by (n-1)**2 (biased variant) or\n",
    "    # n*(n-3) (unbiased variant), but this cancels for CKA.\n",
    "    scaled_hsic = gram_x.ravel().dot(gram_y.ravel())\n",
    "\n",
    "    normalization_x = np.linalg.norm(gram_x)\n",
    "    normalization_y = np.linalg.norm(gram_y)\n",
    "    return scaled_hsic / (normalization_x * normalization_y)\n",
    "\n",
    "\n",
    "def _debiased_dot_product_similarity_helper(xty, sum_squared_rows_x, sum_squared_rows_y, squared_norm_x, squared_norm_y, n):\n",
    "    \"\"\"Helper for computing debiased dot product similarity (i.e. linear HSIC).\"\"\"\n",
    "    # This formula can be derived by manipulating the unbiased estimator from\n",
    "    # Song et al. (2007).\n",
    "    return (\n",
    "      xty - n / (n - 2.) * sum_squared_rows_x.dot(sum_squared_rows_y)\n",
    "      + squared_norm_x * squared_norm_y / ((n - 1) * (n - 2)))\n",
    "\n",
    "\n",
    "def feature_space_linear_cka(features_x, features_y, debiased=False):\n",
    "    \"\"\"Compute CKA with a linear kernel, in feature space.\n",
    "\n",
    "    This is typically faster than computing the Gram matrix when there are fewer\n",
    "    features than examples.\n",
    "\n",
    "    Args:\n",
    "    features_x: A num_examples x num_features matrix of features.\n",
    "    features_y: A num_examples x num_features matrix of features.\n",
    "    debiased: Use unbiased estimator of dot product similarity. CKA may still be\n",
    "      biased. Note that this estimator may be negative.\n",
    "\n",
    "    Returns:\n",
    "    The value of CKA between X and Y.\n",
    "    \"\"\"\n",
    "    features_x = features_x - np.mean(features_x, 0, keepdims=True)\n",
    "    features_y = features_y - np.mean(features_y, 0, keepdims=True)\n",
    "\n",
    "    dot_product_similarity = np.linalg.norm(features_x.T.dot(features_y)) ** 2\n",
    "    normalization_x = np.linalg.norm(features_x.T.dot(features_x))\n",
    "    normalization_y = np.linalg.norm(features_y.T.dot(features_y))\n",
    "\n",
    "    if debiased:\n",
    "        n = features_x.shape[0]\n",
    "        # Equivalent to np.sum(features_x ** 2, 1) but avoids an intermediate array.\n",
    "        sum_squared_rows_x = np.einsum('ij,ij->i', features_x, features_x)\n",
    "        sum_squared_rows_y = np.einsum('ij,ij->i', features_y, features_y)\n",
    "        squared_norm_x = np.sum(sum_squared_rows_x)\n",
    "        squared_norm_y = np.sum(sum_squared_rows_y)\n",
    "\n",
    "        dot_product_similarity = _debiased_dot_product_similarity_helper(\n",
    "            dot_product_similarity, sum_squared_rows_x, sum_squared_rows_y,\n",
    "            squared_norm_x, squared_norm_y, n)\n",
    "        normalization_x = np.sqrt(_debiased_dot_product_similarity_helper(\n",
    "            normalization_x ** 2, sum_squared_rows_x, sum_squared_rows_x,\n",
    "            squared_norm_x, squared_norm_x, n))\n",
    "        normalization_y = np.sqrt(_debiased_dot_product_similarity_helper(\n",
    "            normalization_y ** 2, sum_squared_rows_y, sum_squared_rows_y,\n",
    "            squared_norm_y, squared_norm_y, n))\n",
    "\n",
    "    return dot_product_similarity / (normalization_x * normalization_y)\n",
    "\n",
    "def cca(features_x, features_y):\n",
    "    \"\"\"Compute the mean squared CCA correlation (R^2_{CCA}).\n",
    "\n",
    "    Args:\n",
    "    features_x: A num_examples x num_features matrix of features.\n",
    "    features_y: A num_examples x num_features matrix of features.\n",
    "\n",
    "    Returns:\n",
    "    The mean squared CCA correlations between X and Y.\n",
    "    \"\"\"\n",
    "    qx, _ = np.linalg.qr(features_x)  # Or use SVD with full_matrices=False.\n",
    "    qy, _ = np.linalg.qr(features_y)\n",
    "    return np.linalg.norm(qx.T.dot(qy)) ** 2 / min(\n",
    "      features_x.shape[1], features_y.shape[1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset = 'miniimagenet'\n",
    "num_shots = 5\n",
    "dataset_args = easydict.EasyDict({'folder': './data',\n",
    "                                  'dataset': dataset,\n",
    "                                  'num_ways': 5,\n",
    "                                  'num_shots': num_shots,\n",
    "                                  'download': True})\n",
    "\n",
    "sample_task = make_sample_task(load_dataset(dataset_args, 'meta_train'))\n",
    "\n",
    "model = '4conv'\n",
    "path = './output'\n",
    "algorithms = ['MAML', 'BOIL']\n",
    "fig, ax = plt.subplots(1, 1, sharey=True, figsize=(8,6))\n",
    "\n",
    "for algorithm in algorithms:\n",
    "    save_name = '{}shot_{}_{}'.format(num_shots, model, algorithm)\n",
    "    args = get_arguments(path, dataset, save_name)\n",
    "    \n",
    "    \n",
    "    ax.set_title('CKA')\n",
    "    ax.set_ylim([0.0-0.05, 1.0+0.05])\n",
    "    ax.tick_params(axis='both', which='major')\n",
    "    ax.grid(True)\n",
    "    \n",
    "    all_before_f = torch.tensor([])\n",
    "    all_after_f = torch.tensor([])\n",
    "    \n",
    "    cka_list = []\n",
    "    \n",
    "    if model == '4conv':\n",
    "        xrange = ['conv1', 'conv2', 'conv3', 'conv4', 'head']\n",
    "        pos_list = [1,2,3,4,5]\n",
    "    elif model == 'resnet':\n",
    "        xrange = ['block1', 'block2', 'block3', 'block4']\n",
    "        pos_list = [1,2,3,4]\n",
    "    \n",
    "    for pos in pos_list:\n",
    "        before_f, after_f = get_features_logits(args, sample_task, pos=pos)\n",
    "        \n",
    "        before_f = before_f.squeeze(0).numpy()\n",
    "        after_f = after_f.squeeze(0).numpy()\n",
    "        \n",
    "        cka_from_features = feature_space_linear_cka(before_f, after_f)\n",
    "        cka_list.append(cka_from_features)\n",
    "    \n",
    "    if model == '4conv':\n",
    "        if algorithm == 'MAML':\n",
    "            ax.plot(xrange, cka_list, marker='o', label='MAML', color='#4F81BD')\n",
    "        elif algorithm == 'BOIL':\n",
    "            ax.plot(xrange, cka_list, marker='D', label='BOIL', color='#C0504D')\n",
    "    elif model == 'resnet':\n",
    "        if algorithm == 'block_a_extractor':\n",
    "            ax.plot(xrange, cka_list, marker='o', label='BOIL w/ last skip connection', color='#4F81BD')\n",
    "        elif algorithm == 'block_b_extractor':\n",
    "            ax.plot(xrange, cka_list, marker='D', label='BOIL w/o last skip connection', color='#C0504D')\n",
    "    \n",
    "plt.legend()\n",
    "plt.show()\n",
    "plt.subplots_adjust(wspace=0.2)\n",
    "# plt.savefig('./src/{}_cka.pdf'.format(model), bbox_inches='tight', format='pdf')\n",
    "plt.close()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
