{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "26043d37",
   "metadata": {},
   "outputs": [],
   "source": [
    "import  torch, os\n",
    "import  numpy as np\n",
    "from    MiniImagenet import MiniImagenet\n",
    "from    torch.utils.data import DataLoader\n",
    "import  random\n",
    "from    mini_meta_feature import Meta_mini\n",
    "from    mini_utils import get_config, save_model, name_path, load_model\n",
    "import matplotlib.pyplot as plt\n",
    "import pickle\n",
    "\n",
    "from visualization_utils import metric2_cos, get_cross_covariance, shuffle, get_averaged_matrix, get_map"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b431a9da",
   "metadata": {},
   "outputs": [],
   "source": [
    "n_way = 5\n",
    "k_shot = 20\n",
    "k_qry = 20\n",
    "maml_order = \"first\"\n",
    "init_var = 1\n",
    "seed_start = 222\n",
    "seed_end = 232\n",
    "device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')\n",
    "num_epoch = 500\n",
    "task_num = 1\n",
    "batchsz = 10\n",
    "outer_lr, inner_lr = 0.001, 0.01\n",
    "train_update_steps, test_update_steps = 5, 10\n",
    "\n",
    "root = \"./results/\"\n",
    "\"\"\"Please set the data_root here\"\"\"\n",
    "data_root = \"./data/miniimagenet/\"\n",
    "\n",
    "mini = MiniImagenet(data_root, mode='test', n_way=n_way, k_shot=k_shot, k_query=k_qry, batchsz=400, resize=84)\n",
    "maml = Meta_mini(n_way, k_shot, k_qry, task_num, \n",
    "                train_update_steps, test_update_steps, \n",
    "                inner_lr, outer_lr, get_config(n_way), device).to(device)\n",
    "\n",
    "maml.set_last_layer_variance(init_var)\n",
    "if init_var == 0:\n",
    "    maml.set_last_layer_to_zero()\n",
    "\n",
    "db = DataLoader(mini, task_num, shuffle=True, num_workers=8, pin_memory=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2db7c0ad",
   "metadata": {},
   "source": [
    "# Original FOMAML"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e5b1dc44",
   "metadata": {},
   "outputs": [],
   "source": [
    "for seed in range(seed_start, seed_end):\n",
    "    \"\"\"\n",
    "    memory_all: used to collect the computed cosine similarity along training\n",
    "    get_cross_covariance: used to compute the cosine similarity between features.\n",
    "    \"\"\"\n",
    "    # Set the random seed\n",
    "    torch.manual_seed(seed)\n",
    "    torch.cuda.manual_seed_all(seed)\n",
    "    np.random.seed(seed)\n",
    "    \n",
    "    # Initialize the meta-learning model\n",
    "    maml = Meta_mini(n_way, k_shot, k_qry, task_num, \n",
    "                    train_update_steps, test_update_steps, \n",
    "                    inner_lr, outer_lr, get_config(n_way), device).to(device)\n",
    "    \n",
    "    # Initialize the data loader\n",
    "    db = DataLoader(mini, task_num, shuffle=True, num_workers=8, pin_memory=True)\n",
    "    \n",
    "    memory_all = list()\n",
    "    for step, (x_spt, y_spt, x_qry, y_qry) in enumerate(db):\n",
    "        # We consider a simplified scenario where there is only one task. \n",
    "        # So this for-loop does not iterate but break quickly\n",
    "        x_spt, y_spt, x_qry, y_qry = x_spt.to(device), y_spt.to(device), x_qry.to(device), y_qry.to(device)\n",
    "        # The model undergoes 101 outer loop updates\n",
    "        for i in range(101):\n",
    "            # To avoid channel-memorization problem. we explicitly perform channel shuffling.\n",
    "            # Please refer to visualization_utils.py for more details\n",
    "            y_spt_s, y_qry_s =shuffle(y_spt, y_qry)\n",
    "            # Forward the data\n",
    "            accs = maml.forward_FOMAML(x_spt, y_spt_s, x_qry, y_qry_s)\n",
    "            if i % 10 == 0:\n",
    "                # The cosine similarity is computed. \n",
    "                # Please refer to visualization_utils.py for more details\n",
    "                memory = get_cross_covariance(maml, x_spt, x_qry, y_spt, y_qry)\n",
    "                \n",
    "                memory_all.append(memory)   \n",
    "        break\n",
    "    with open('./pickles/RandInit_{}.pickle'.format(seed), 'wb') as handle:\n",
    "        pickle.dump(memory_all, handle)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3032e808",
   "metadata": {},
   "source": [
    "# FOMAML with zero-initialization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "374710c8",
   "metadata": {},
   "outputs": [],
   "source": [
    "for seed in range(seed_start, seed_end):\n",
    "    torch.manual_seed(seed)\n",
    "    torch.cuda.manual_seed_all(seed)\n",
    "    np.random.seed(seed)\n",
    "    maml = Meta_mini(n_way, k_shot, k_qry, task_num, \n",
    "                    train_update_steps, test_update_steps, \n",
    "                    inner_lr, outer_lr, get_config(n_way), device).to(device)\n",
    "    db = DataLoader(mini, task_num, shuffle=True, num_workers=8, pin_memory=True)\n",
    "    \n",
    "    maml.set_last_layer_to_zero()\n",
    "    memory_all = list()\n",
    "    for step, (x_spt, y_spt, x_qry, y_qry) in enumerate(db):\n",
    "        x_spt, y_spt, x_qry, y_qry = x_spt.to(device), y_spt.to(device), x_qry.to(device), y_qry.to(device)\n",
    "        for i in range(101):\n",
    "            y_spt_s, y_qry_s =shuffle(y_spt, y_qry)\n",
    "            accs = maml.forward_FOMAML(x_spt, y_spt_s, x_qry, y_qry_s)\n",
    "            if i % 10 == 0:\n",
    "                memory = get_cross_covariance(maml, x_spt, x_qry, y_spt, y_qry)\n",
    "                memory_all.append(memory)   \n",
    "        break\n",
    "    with open('./pickles/ZeroInit_{}.pickle'.format(seed), 'wb') as handle:\n",
    "        pickle.dump(memory_all, handle)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1be83ff5",
   "metadata": {},
   "source": [
    "# Zeroing trick: zeroing the final linear layer every outer loop"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5b73c04c",
   "metadata": {},
   "outputs": [],
   "source": [
    "for seed in range(seed_start, seed_end):\n",
    "    torch.manual_seed(seed)\n",
    "    torch.cuda.manual_seed_all(seed)\n",
    "    np.random.seed(seed)\n",
    "    maml = Meta_mini(n_way, k_shot, k_qry, task_num, \n",
    "                    train_update_steps, test_update_steps, \n",
    "                    inner_lr, outer_lr, get_config(n_way), device).to(device)\n",
    "    db = DataLoader(mini, task_num, shuffle=True, num_workers=8, pin_memory=True)\n",
    "\n",
    "    memory_all = list()\n",
    "    maml.set_last_layer_to_zero()\n",
    "    for step, (x_spt, y_spt, x_qry, y_qry) in enumerate(db):\n",
    "        x_spt, y_spt, x_qry, y_qry = x_spt.to(device), y_spt.to(device), x_qry.to(device), y_qry.to(device)\n",
    "        for i in range(101):\n",
    "            \n",
    "            y_spt_s, y_qry_s =shuffle(y_spt, y_qry)\n",
    "            accs = maml.forward_FOMAML(x_spt, y_spt_s, x_qry, y_qry_s)\n",
    "            if i % 10 == 0:\n",
    "                memory = get_cross_covariance(maml, x_spt, x_qry, y_spt, y_qry)\n",
    "                memory_all.append(memory)   \n",
    "            if i % 1 == 0:\n",
    "                maml.set_last_layer_to_zero()\n",
    "        break\n",
    "    with open('./pickles/ZeroIter1_{}.pickle'.format(seed), 'wb') as handle:\n",
    "        pickle.dump(memory_all, handle)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ef91ed58",
   "metadata": {},
   "outputs": [],
   "source": [
    "cmap = \"Spectral_r\"\n",
    "vmin, vmax = 0, 0.6\n",
    "\n",
    "c = 3\n",
    "print(\"Show the main results\")\n",
    "print(\"Note that the mid column is used to seperate the results. Please ignore it\")\n",
    "for task in [\"RandInit\", \"ZeroInit\",\"ZeroIter1\"]:\n",
    "    fig, axes = plt.subplots(1,c, figsize=(c*8+(c-1),4))\n",
    "    \n",
    "    matrix = np.abs(get_averaged_matrix(task))\n",
    "    ax = axes[0]\n",
    "    ax.set_title(\"\\n1 outer loop updates\", fontsize=(20))\n",
    "    ax.pcolormesh(get_map(matrix[0]), vmin=vmin, vmax=vmax, cmap=cmap, edgecolors=\"white\",linewidth=1.5)\n",
    "    ax.axis('off')\n",
    "\n",
    "    matrix = get_averaged_matrix(task)\n",
    "    ax = axes[1]\n",
    "    ax.set_title(\"\\n10 outer loop updates\", fontsize=(20))\n",
    "    ax.pcolormesh(get_map(matrix[1]), vmin=vmin, vmax=vmax, cmap=cmap, edgecolors=\"white\",linewidth=1.5)\n",
    "    ax.axis('off')\n",
    "\n",
    "    matrix = get_averaged_matrix(task)\n",
    "    ax = axes[2]\n",
    "    ax.set_title(\"\\n100 outer loop updates\", fontsize=(20))\n",
    "    ax.pcolormesh(get_map(matrix[10]), vmin=vmin, vmax=vmax, cmap=cmap, edgecolors=\"white\",linewidth=1.5)\n",
    "    ax.axis('off')\n",
    "        \n",
    "    plt.pause(0.1)\n",
    "\n",
    "print(\"Show the color bar\")\n",
    "a = np.array([[0,0.6]])\n",
    "plt.figure(figsize=(8, 0.4))\n",
    "img = plt.imshow(a, cmap=\"Spectral_r\")\n",
    "plt.gca().set_visible(False)\n",
    "cax = plt.axes([0.1, 0.2, 0.8, 0.6])\n",
    "plt.colorbar(orientation=\"horizontal\", cax=cax)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "88db4117",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "PT17",
   "language": "python",
   "name": "pt17"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
