{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "import numpy as np\n",
    "import time\n",
    "import os\n",
    "from omniglot import task_generator\n",
    "from proto_attn_train import train, validate, count_acc\n",
    "import torch.nn.functional as F\n",
    "from collections import defaultdict\n",
    "import torch_scatter\n",
    "from torch.distributions import Normal\n",
    "import torchvision.datasets\n",
    "\n",
    "\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "class AttrDict(dict):\n",
    "    def __init__(self, *args, **kwargs):\n",
    "        super(AttrDict, self).__init__(*args, **kwargs)\n",
    "        self.__dict__ = self\n",
    "        \n",
    "config = {'dataset': 'omniglot', 'split_name': 'default', 'shot': 5, 'query': 15, 'classes_per_task': [20], 'examples_per_class': 20,\n",
    "          'groups_per_class': 2, 'examples_per_group': 32, 'max_epoch': 10000,\n",
    "          'nb_val_tasks': 1000, 'train_way': 2, 'test_way': 2, 'prob_xor': None, 'beta': 0,\n",
    "          'out_dim': 64, 'iterations': 0, 'temp': 0.5, 'scale': 1, 'verbose': False}\n",
    "config = AttrDict(config)\n",
    "accs = defaultdict(list)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def hot_attn(Q, K, V, temp):\n",
    "    return torch.softmax(Q@K.T/(temp),-1)@V  # * np.sqrt(K.shape[-1])\n",
    "\n",
    "def euclidean_metric(a, b):\n",
    "    n = a.shape[0]\n",
    "    m = b.shape[0]\n",
    "    a = a.unsqueeze(1).expand(n, m, -1)\n",
    "    b = b.unsqueeze(0).expand(n, m, -1)\n",
    "    logits = -((a - b)**2).sum(dim=2)\n",
    "    return logits\n",
    "\n",
    "def z_norm(x, h=1e-7):\n",
    "    return (x - x.mean(0))/(x.std(0, unbiased=True) + h)\n",
    "\n",
    "def conv_block(in_channels, out_channels):\n",
    "    # bn = CustomBatchNorm()\n",
    "    bn = nn.BatchNorm2d(out_channels, momentum=0.01, track_running_stats = False)\n",
    "    # nn.init.uniform_(bn.weight) # for pytorch 1.2 or later\n",
    "    return nn.Sequential(\n",
    "        nn.Conv2d(in_channels, out_channels, 3, padding=1),\n",
    "        bn,\n",
    "        nn.ReLU(),\n",
    "        nn.MaxPool2d(2)\n",
    "    )\n",
    "\n",
    "class Convnet(nn.Module):\n",
    "    def __init__(self, x_dim=3, hid_dim=64, z_dim=64):\n",
    "        super().__init__()\n",
    "        self.encoder = nn.Sequential(\n",
    "            conv_block(x_dim, hid_dim),\n",
    "            conv_block(hid_dim, hid_dim),\n",
    "            conv_block(hid_dim, hid_dim),\n",
    "            conv_block(hid_dim, hid_dim),\n",
    "        )\n",
    "        self.embeddings = nn.Linear(hid_dim, z_dim)\n",
    "        self.mean = nn.Linear(z_dim, z_dim)\n",
    "        self.logvar = nn.Linear(z_dim, z_dim)\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = self.encoder(x)\n",
    "        h = x.view(x.size(0), -1)\n",
    "        h = self.embeddings(h)\n",
    "        h = nn.ReLU()(h)\n",
    "        mean = self.mean(h)\n",
    "        logvar = self.logvar(h)\n",
    "        std = torch.exp(0.5 * logvar)\n",
    "        return Normal(mean, std)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def z_norm(x, h=1e-7):\n",
    "    return (x - x.mean(0))/(x.std(0, unbiased=True) + h)\n",
    "\n",
    "def forward_euclid(train, test, train_labels, config, attn_fn=hot_attn):\n",
    "    iterations=config.iterations\n",
    "    temp=config.temp\n",
    "    scale=config.scale\n",
    "    train, test = z_norm(train), z_norm(test)\n",
    "    \n",
    "    num_classes = train_labels.max() + 1\n",
    "    \n",
    "    # Self-attention feature selection\n",
    "    for _ in range(iterations):\n",
    "        for c in range(num_classes):\n",
    "            t = train[train_labels==c]\n",
    "            train[train_labels==c] = hot_attn(t, t, t, temp)  \n",
    "    rescale = train.abs().mean(0)\n",
    "    rescale = scale * (rescale - rescale.min()) / (rescale.max() - rescale.min() + 1e-7)\n",
    "    \n",
    "    # Compute predictions and accuracy\n",
    "    distances = euclidean_metric(rescale*test, rescale*train)  # Shape=(nb_test, nb_train)\n",
    "    weights = torch.softmax(distances, axis=-1) # Shape=(nb_test, nb_train)\n",
    "    predictions = weights @ F.one_hot(train_labels, num_classes=train_labels.max()+1).float()\n",
    "    predictions = torch.clip(predictions, 0.01, 0.99)\n",
    "    \n",
    "    return predictions"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Protonets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def classify_proto(train, test, train_labels, **kwargs):\n",
    "    proto = torch_scatter.scatter_mean(train, train_labels.type(torch.int64), dim=0)\n",
    "\n",
    "    # Compute predictions and accuracy\n",
    "    logits = euclidean_metric(test, proto)\n",
    "    # predictions = torch.softmax(logits, axis=-1)\n",
    "    # return predictions\n",
    "    return logits"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 50,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Files already downloaded and verified\n",
      "epoch 100, loss=2.2504, kl=0.0000, acc=0.6212\n",
      "epoch 200, loss=1.2627, kl=0.0000, acc=0.7672\n",
      "epoch 300, loss=0.9046, kl=0.0000, acc=0.8251\n",
      "epoch 400, loss=0.7205, kl=0.0000, acc=0.8558\n",
      "epoch 500, loss=0.6028, kl=0.0000, acc=0.8764\n",
      "epoch 600, loss=0.5217, kl=0.0000, acc=0.8910\n",
      "epoch 700, loss=0.4622, kl=0.0000, acc=0.9018\n",
      "epoch 800, loss=0.4155, kl=0.0000, acc=0.9105\n",
      "epoch 900, loss=0.3793, kl=0.0000, acc=0.9172\n",
      "epoch 1000, loss=0.3497, kl=0.0000, acc=0.9229\n",
      "epoch 1100, loss=0.3255, kl=0.0000, acc=0.9276\n",
      "epoch 1200, loss=0.3052, kl=0.0000, acc=0.9315\n",
      "epoch 1300, loss=0.2876, kl=0.0000, acc=0.9349\n",
      "epoch 1400, loss=0.2721, kl=0.0000, acc=0.9379\n",
      "epoch 1500, loss=0.2589, kl=0.0000, acc=0.9404\n",
      "epoch 1600, loss=0.2465, kl=0.0000, acc=0.9430\n",
      "epoch 1700, loss=0.2360, kl=0.0000, acc=0.9450\n",
      "epoch 1800, loss=0.2267, kl=0.0000, acc=0.9470\n",
      "epoch 1900, loss=0.2181, kl=0.0000, acc=0.9487\n",
      "epoch 2000, loss=0.2099, kl=0.0000, acc=0.9505\n",
      "epoch 2100, loss=0.2028, kl=0.0000, acc=0.9519\n",
      "epoch 2200, loss=0.1961, kl=0.0000, acc=0.9533\n",
      "epoch 2300, loss=0.1899, kl=0.0000, acc=0.9547\n",
      "epoch 2400, loss=0.1845, kl=0.0000, acc=0.9558\n",
      "epoch 2500, loss=0.1793, kl=0.0000, acc=0.9568\n",
      "epoch 2600, loss=0.1744, kl=0.0000, acc=0.9578\n",
      "epoch 2700, loss=0.1700, kl=0.0000, acc=0.9588\n",
      "epoch 2800, loss=0.1660, kl=0.0000, acc=0.9596\n",
      "epoch 2900, loss=0.1619, kl=0.0000, acc=0.9605\n",
      "epoch 3000, loss=0.1582, kl=0.0000, acc=0.9613\n",
      "epoch 3100, loss=0.1547, kl=0.0000, acc=0.9620\n",
      "epoch 3200, loss=0.1513, kl=0.0000, acc=0.9627\n",
      "epoch 3300, loss=0.1483, kl=0.0000, acc=0.9633\n",
      "epoch 3400, loss=0.1453, kl=0.0000, acc=0.9639\n",
      "epoch 3500, loss=0.1424, kl=0.0000, acc=0.9646\n",
      "epoch 3600, loss=0.1396, kl=0.0000, acc=0.9652\n",
      "epoch 3700, loss=0.1372, kl=0.0000, acc=0.9657\n",
      "epoch 3800, loss=0.1346, kl=0.0000, acc=0.9663\n",
      "epoch 3900, loss=0.1324, kl=0.0000, acc=0.9667\n",
      "epoch 4000, loss=0.1302, kl=0.0000, acc=0.9672\n",
      "epoch 4100, loss=0.1279, kl=0.0000, acc=0.9677\n",
      "epoch 4200, loss=0.1260, kl=0.0000, acc=0.9681\n",
      "epoch 4300, loss=0.1241, kl=0.0000, acc=0.9685\n",
      "epoch 4400, loss=0.1222, kl=0.0000, acc=0.9689\n",
      "epoch 4500, loss=0.1204, kl=0.0000, acc=0.9693\n",
      "epoch 4600, loss=0.1188, kl=0.0000, acc=0.9697\n",
      "epoch 4700, loss=0.1171, kl=0.0000, acc=0.9700\n",
      "epoch 4800, loss=0.1154, kl=0.0000, acc=0.9704\n",
      "epoch 4900, loss=0.1138, kl=0.0000, acc=0.9708\n",
      "epoch 5000, loss=0.1124, kl=0.0000, acc=0.9711\n",
      "epoch 5100, loss=0.1110, kl=0.0000, acc=0.9714\n",
      "epoch 5200, loss=0.1097, kl=0.0000, acc=0.9716\n",
      "epoch 5300, loss=0.1083, kl=0.0000, acc=0.9719\n",
      "epoch 5400, loss=0.1070, kl=0.0000, acc=0.9722\n",
      "epoch 5500, loss=0.1057, kl=0.0000, acc=0.9725\n",
      "epoch 5600, loss=0.1044, kl=0.0000, acc=0.9728\n",
      "epoch 5700, loss=0.1032, kl=0.0000, acc=0.9731\n",
      "epoch 5800, loss=0.1021, kl=0.0000, acc=0.9734\n",
      "epoch 5900, loss=0.1010, kl=0.0000, acc=0.9736\n",
      "epoch 6000, loss=0.1000, kl=0.0000, acc=0.9738\n",
      "epoch 6100, loss=0.0990, kl=0.0000, acc=0.9741\n",
      "epoch 6200, loss=0.0980, kl=0.0000, acc=0.9743\n",
      "epoch 6300, loss=0.0970, kl=0.0000, acc=0.9745\n",
      "epoch 6400, loss=0.0960, kl=0.0000, acc=0.9747\n",
      "epoch 6500, loss=0.0951, kl=0.0000, acc=0.9749\n",
      "epoch 6600, loss=0.0942, kl=0.0000, acc=0.9751\n",
      "epoch 6700, loss=0.0933, kl=0.0000, acc=0.9753\n",
      "epoch 6800, loss=0.0925, kl=0.0000, acc=0.9755\n",
      "epoch 6900, loss=0.0916, kl=0.0000, acc=0.9757\n",
      "epoch 7000, loss=0.0909, kl=0.0000, acc=0.9759\n",
      "epoch 7100, loss=0.0901, kl=0.0000, acc=0.9760\n",
      "epoch 7200, loss=0.0893, kl=0.0000, acc=0.9762\n",
      "epoch 7300, loss=0.0885, kl=0.0000, acc=0.9764\n",
      "epoch 7400, loss=0.0878, kl=0.0000, acc=0.9765\n",
      "epoch 7500, loss=0.0871, kl=0.0000, acc=0.9767\n",
      "epoch 7600, loss=0.0863, kl=0.0000, acc=0.9769\n",
      "epoch 7700, loss=0.0856, kl=0.0000, acc=0.9770\n",
      "epoch 7800, loss=0.0850, kl=0.0000, acc=0.9772\n",
      "epoch 7900, loss=0.0843, kl=0.0000, acc=0.9773\n",
      "epoch 8000, loss=0.0837, kl=0.0000, acc=0.9775\n",
      "epoch 8100, loss=0.0830, kl=0.0000, acc=0.9776\n",
      "epoch 8200, loss=0.0824, kl=0.0000, acc=0.9778\n",
      "epoch 8300, loss=0.0818, kl=0.0000, acc=0.9779\n",
      "epoch 8400, loss=0.0813, kl=0.0000, acc=0.9780\n",
      "epoch 8500, loss=0.0807, kl=0.0000, acc=0.9782\n",
      "epoch 8600, loss=0.0801, kl=0.0000, acc=0.9783\n",
      "epoch 8700, loss=0.0796, kl=0.0000, acc=0.9784\n",
      "epoch 8800, loss=0.0791, kl=0.0000, acc=0.9785\n",
      "epoch 8900, loss=0.0785, kl=0.0000, acc=0.9786\n",
      "epoch 9000, loss=0.0780, kl=0.0000, acc=0.9787\n",
      "epoch 9100, loss=0.0775, kl=0.0000, acc=0.9789\n",
      "epoch 9200, loss=0.0770, kl=0.0000, acc=0.9790\n",
      "epoch 9300, loss=0.0765, kl=0.0000, acc=0.9791\n",
      "epoch 9400, loss=0.0760, kl=0.0000, acc=0.9792\n",
      "epoch 9500, loss=0.0755, kl=0.0000, acc=0.9793\n",
      "epoch 9600, loss=0.0750, kl=0.0000, acc=0.9794\n",
      "epoch 9700, loss=0.0746, kl=0.0000, acc=0.9795\n",
      "epoch 9800, loss=0.0741, kl=0.0000, acc=0.9797\n",
      "epoch 9900, loss=0.0736, kl=0.0000, acc=0.9797\n",
      "epoch 10000, loss=0.0732, kl=0.0000, acc=0.9798\n"
     ]
    }
   ],
   "source": [
    "config.verbose = True\n",
    "model =  Convnet(x_dim=1)\n",
    "model, g = train(task_generator=task_generator,\n",
    "                 model=model,\n",
    "                 forward_fn=classify_proto,\n",
    "                 config=config,\n",
    "                 loss_fn=nn.CrossEntropyLoss()  # F.cross_entropy\n",
    "                )\n",
    "torch.save(model.state_dict(), 'proto_omniglot_10k')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 107,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Files already downloaded and verified\n"
     ]
    }
   ],
   "source": [
    "accs_proto, losses_proto = validate(task_generator=task_generator,\n",
    "                        forward_fn=classify_proto,\n",
    "                        config=config,\n",
    "                        model=model,\n",
    "                        loss_fn=nn.CrossEntropyLoss())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 108,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(0.9855400081276894, 0.014472103031131597)"
      ]
     },
     "execution_count": 108,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "np.mean(accs_proto), np.std(accs_proto)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 109,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Files already downloaded and verified\n"
     ]
    }
   ],
   "source": [
    "config.iterations = 0\n",
    "accs_attn_test, losses_attn_test = validate(task_generator=task_generator,\n",
    "                        forward_fn=forward_euclid,\n",
    "                        config=config,\n",
    "                        model=model,\n",
    "                        loss_fn=nn.CrossEntropyLoss())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 110,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(0.9816266762018204, 0.014782069532884237)"
      ]
     },
     "execution_count": 110,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "np.mean(accs_attn_test), np.std(accs_attn_test)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Feature selection + euclid attention"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 64,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Files already downloaded and verified\n",
      "epoch 100, loss=2.7803, kl=0.0000, acc=0.3025\n",
      "epoch 200, loss=2.6343, kl=0.0000, acc=0.4568\n",
      "epoch 300, loss=2.5462, kl=0.0000, acc=0.5473\n",
      "epoch 400, loss=2.4917, kl=0.0000, acc=0.6032\n",
      "epoch 500, loss=2.4519, kl=0.0000, acc=0.6435\n",
      "epoch 600, loss=2.4229, kl=0.0000, acc=0.6729\n",
      "epoch 700, loss=2.4004, kl=0.0000, acc=0.6958\n",
      "epoch 800, loss=2.3816, kl=0.0000, acc=0.7149\n",
      "epoch 900, loss=2.3666, kl=0.0000, acc=0.7301\n",
      "epoch 1000, loss=2.3534, kl=0.0000, acc=0.7434\n",
      "epoch 1100, loss=2.3406, kl=0.0000, acc=0.7563\n",
      "epoch 1200, loss=2.3294, kl=0.0000, acc=0.7676\n",
      "epoch 1300, loss=2.3199, kl=0.0000, acc=0.7773\n",
      "epoch 1400, loss=2.3113, kl=0.0000, acc=0.7859\n",
      "epoch 1500, loss=2.3042, kl=0.0000, acc=0.7930\n",
      "epoch 1600, loss=2.2975, kl=0.0000, acc=0.7997\n",
      "epoch 1700, loss=2.2931, kl=0.0000, acc=0.8041\n",
      "epoch 1800, loss=2.2883, kl=0.0000, acc=0.8090\n",
      "epoch 1900, loss=2.2838, kl=0.0000, acc=0.8134\n",
      "epoch 2000, loss=2.2794, kl=0.0000, acc=0.8179\n",
      "epoch 2100, loss=2.2748, kl=0.0000, acc=0.8225\n",
      "epoch 2200, loss=2.2708, kl=0.0000, acc=0.8266\n",
      "epoch 2300, loss=2.2670, kl=0.0000, acc=0.8304\n",
      "epoch 2400, loss=2.2631, kl=0.0000, acc=0.8343\n",
      "epoch 2500, loss=2.2597, kl=0.0000, acc=0.8377\n",
      "epoch 2600, loss=2.2564, kl=0.0000, acc=0.8410\n",
      "epoch 2700, loss=2.2538, kl=0.0000, acc=0.8436\n",
      "epoch 2800, loss=2.2507, kl=0.0000, acc=0.8467\n",
      "epoch 2900, loss=2.2478, kl=0.0000, acc=0.8496\n",
      "epoch 3000, loss=2.2451, kl=0.0000, acc=0.8523\n",
      "epoch 3100, loss=2.2425, kl=0.0000, acc=0.8550\n",
      "epoch 3200, loss=2.2402, kl=0.0000, acc=0.8573\n",
      "epoch 3300, loss=2.2378, kl=0.0000, acc=0.8597\n",
      "epoch 3400, loss=2.2356, kl=0.0000, acc=0.8619\n",
      "epoch 3500, loss=2.2334, kl=0.0000, acc=0.8641\n",
      "epoch 3600, loss=2.2313, kl=0.0000, acc=0.8662\n",
      "epoch 3700, loss=2.2294, kl=0.0000, acc=0.8680\n",
      "epoch 3800, loss=2.2275, kl=0.0000, acc=0.8700\n",
      "epoch 3900, loss=2.2258, kl=0.0000, acc=0.8717\n",
      "epoch 4000, loss=2.2241, kl=0.0000, acc=0.8734\n",
      "epoch 4100, loss=2.2225, kl=0.0000, acc=0.8750\n",
      "epoch 4200, loss=2.2211, kl=0.0000, acc=0.8764\n",
      "epoch 4300, loss=2.2197, kl=0.0000, acc=0.8778\n",
      "epoch 4400, loss=2.2182, kl=0.0000, acc=0.8793\n",
      "epoch 4500, loss=2.2168, kl=0.0000, acc=0.8808\n",
      "epoch 4600, loss=2.2152, kl=0.0000, acc=0.8824\n",
      "epoch 4700, loss=2.2139, kl=0.0000, acc=0.8836\n",
      "epoch 4800, loss=2.2127, kl=0.0000, acc=0.8849\n",
      "epoch 4900, loss=2.2113, kl=0.0000, acc=0.8863\n",
      "epoch 5000, loss=2.2102, kl=0.0000, acc=0.8873\n",
      "epoch 5100, loss=2.2094, kl=0.0000, acc=0.8881\n",
      "epoch 5200, loss=2.2083, kl=0.0000, acc=0.8892\n",
      "epoch 5300, loss=2.2071, kl=0.0000, acc=0.8904\n",
      "epoch 5400, loss=2.2060, kl=0.0000, acc=0.8915\n",
      "epoch 5500, loss=2.2048, kl=0.0000, acc=0.8927\n",
      "epoch 5600, loss=2.2037, kl=0.0000, acc=0.8938\n",
      "epoch 5700, loss=2.2025, kl=0.0000, acc=0.8950\n",
      "epoch 6400, loss=2.1964, kl=0.0000, acc=0.9011\n",
      "epoch 6500, loss=2.1958, kl=0.0000, acc=0.9017\n",
      "epoch 6600, loss=2.1950, kl=0.0000, acc=0.9024\n",
      "epoch 6700, loss=2.1944, kl=0.0000, acc=0.9030\n",
      "epoch 6800, loss=2.1937, kl=0.0000, acc=0.9037\n",
      "epoch 6900, loss=2.1930, kl=0.0000, acc=0.9044\n",
      "epoch 7000, loss=2.1923, kl=0.0000, acc=0.9051\n",
      "epoch 7100, loss=2.1917, kl=0.0000, acc=0.9057\n",
      "epoch 7200, loss=2.1910, kl=0.0000, acc=0.9064\n",
      "epoch 7300, loss=2.1903, kl=0.0000, acc=0.9071\n",
      "epoch 7400, loss=2.1897, kl=0.0000, acc=0.9077\n",
      "epoch 7500, loss=2.1891, kl=0.0000, acc=0.9083\n",
      "epoch 7600, loss=2.1884, kl=0.0000, acc=0.9090\n",
      "epoch 7700, loss=2.1879, kl=0.0000, acc=0.9095\n",
      "epoch 7800, loss=2.1873, kl=0.0000, acc=0.9101\n",
      "epoch 7900, loss=2.1869, kl=0.0000, acc=0.9105\n",
      "epoch 8000, loss=2.1863, kl=0.0000, acc=0.9110\n",
      "epoch 8100, loss=2.1859, kl=0.0000, acc=0.9114\n",
      "epoch 8200, loss=2.1854, kl=0.0000, acc=0.9120\n",
      "epoch 8300, loss=2.1848, kl=0.0000, acc=0.9125\n",
      "epoch 8400, loss=2.1843, kl=0.0000, acc=0.9130\n",
      "epoch 8500, loss=2.1839, kl=0.0000, acc=0.9135\n",
      "epoch 8600, loss=2.1833, kl=0.0000, acc=0.9141\n",
      "epoch 8700, loss=2.1827, kl=0.0000, acc=0.9146\n",
      "epoch 8800, loss=2.1822, kl=0.0000, acc=0.9151\n",
      "epoch 8900, loss=2.1817, kl=0.0000, acc=0.9156\n",
      "epoch 9000, loss=2.1812, kl=0.0000, acc=0.9161\n",
      "epoch 9100, loss=2.1807, kl=0.0000, acc=0.9166\n",
      "epoch 9200, loss=2.1804, kl=0.0000, acc=0.9169\n",
      "epoch 9300, loss=2.1800, kl=0.0000, acc=0.9173\n",
      "epoch 9400, loss=2.1795, kl=0.0000, acc=0.9178\n",
      "epoch 9500, loss=2.1791, kl=0.0000, acc=0.9182\n",
      "epoch 9600, loss=2.1786, kl=0.0000, acc=0.9187\n",
      "epoch 9700, loss=2.1781, kl=0.0000, acc=0.9192\n",
      "epoch 9800, loss=2.1776, kl=0.0000, acc=0.9197\n",
      "epoch 9900, loss=2.1771, kl=0.0000, acc=0.9201\n",
      "epoch 10000, loss=2.1767, kl=0.0000, acc=0.9206\n"
     ]
    }
   ],
   "source": [
    "config.verbose = True\n",
    "model =  Convnet(x_dim=1)\n",
    "model, g = train(task_generator=task_generator,\n",
    "                 model=model,\n",
    "                 forward_fn=forward_euclid,\n",
    "                 config=config,\n",
    "                 loss_fn=F.cross_entropy)\n",
    "torch.save(model.state_dict(), 'attn_omniglot_10k')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 65,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Files already downloaded and verified\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "0.9594400005936623"
      ]
     },
     "execution_count": 65,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "accs, losses = validate(task_generator=task_generator,\n",
    "                    forward_fn=forward_euclid,\n",
    "                    config=config,\n",
    "                    model=model,\n",
    "                    loss_fn=F.cross_entropy)\n",
    "np.mean(accs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
    "model =  Convnet(x_dim=1)\n",
    "model.load_state_dict(torch.load('attn_omniglot_10k'))\n",
    "model = model.to(device)\n",
    "model.eval();"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Files already downloaded and verified\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "(0.9624566689729691, 0.06424073820264695)"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "accs, losses = validate(task_generator=task_generator,\n",
    "                    forward_fn=forward_euclid,\n",
    "                    config=config,\n",
    "                    model=model,\n",
    "                    loss_fn=F.cross_entropy)\n",
    "np.mean(accs), 100 * np.std(accs)/np.sqrt(1000)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Test performance on alphabets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "class alphabetOmniglot(torch.utils.data.Dataset):\n",
    "    \"\"\"\n",
    "    A (possibly dumb) way to wrap the base Omniglot dataset to get it to\n",
    "    behave like the rest of the pytorch datasets, i.e. having the data\n",
    "    and targets being attributes such that\n",
    "        dataset.data[i] is the i-th data\n",
    "        dataset.targets[i] is the i-th target\n",
    "    \"\"\"\n",
    "    def __init__(\n",
    "        self,\n",
    "        root,\n",
    "        download,\n",
    "        background,\n",
    "    ):\n",
    "        base = torchvision.datasets.Omniglot(root=root,\n",
    "                                             download=download,\n",
    "                                             background=background)\n",
    "        \n",
    "        transform = torchvision.transforms.Compose([\n",
    "            torchvision.transforms.Resize((28,28)),\n",
    "            torchvision.transforms.ToTensor()\n",
    "        ])\n",
    "        data, targets = [], []\n",
    "        for ex in base:\n",
    "            data.append(transform(ex[0]))\n",
    "        self.data = torch.cat(data, 0)\n",
    "    \n",
    "    def __len__(self):\n",
    "        return len(self.targets)\n",
    "    \n",
    "    def __getitem__(self, idx):\n",
    "        \n",
    "        return self.data[idx]\n",
    "    \n",
    "class alphabet_generator:\n",
    "    \n",
    "    def __init__(self, config):\n",
    "        dataset = alphabetOmniglot(root='./data/',download=True,background=False)\n",
    "        test = torchvision.datasets.Omniglot(root='DATA_FOLDER',\n",
    "                                             download=True,\n",
    "                                             background=False)\n",
    "        self.dataset = dataset\n",
    "        self.test = test\n",
    "        \n",
    "        # get character indices\n",
    "        indices = []\n",
    "        for alphabet in test._alphabets:\n",
    "            alphabet_index = []\n",
    "            for idx,character in enumerate(test._characters):\n",
    "                if character.startswith(alphabet):\n",
    "                    alphabet_index.append(idx)\n",
    "            indices.append(alphabet_index)\n",
    "        self.indices = indices\n",
    "    \n",
    "    def get_shot_query(self, config, device, **task_kwargs):\n",
    "        # select n alphabets, m characters, s support from each character and the rest are queries (20-s)\n",
    "        n, m, s = 3, 13, 7\n",
    "        \n",
    "        # choose alphabets\n",
    "        dataset = self.dataset\n",
    "        indices = self.indices\n",
    "        alphabets = np.random.choice(len(indices),n)\n",
    "        support_ids = []\n",
    "        query_ids = []\n",
    "        support_labels = []\n",
    "        query_labels = []\n",
    "        for label,a in enumerate(alphabets):\n",
    "            # permute to get random sample of characters\n",
    "            characters = np.random.permutation(indices[a])[:m]\n",
    "            for c in characters:\n",
    "                # there are 20 of each character, with dataset ids starting at 20c\n",
    "                shuffled_ids = np.random.permutation(list(range(c*20,(c+1)*20)))\n",
    "                support_ids += list(shuffled_ids[:s])\n",
    "                query_ids += list(shuffled_ids[s:])\n",
    "            support_labels += [label]*m*s\n",
    "            query_labels += [label]*m*(20-s)\n",
    "        support_labels, query_labels = torch.tensor(support_labels), torch.tensor(query_labels)\n",
    "        support, queries = dataset.data[support_ids][:, None, ...], dataset.data[query_ids][:, None, ...]\n",
    "        \n",
    "        return support.to(device), support_labels.long().to(device), queries.to(device), query_labels.long().to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 97,
   "metadata": {},
   "outputs": [],
   "source": [
    "model.load_state_dict(torch.load('proto_omniglot_10k'))\n",
    "model.eval();"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 104,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "(0.8341558179855346, 0.08456723601631734)"
      ]
     },
     "execution_count": 104,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "accs, losses = validate(task_generator=alphabet_generator,\n",
    "                    forward_fn=classify_proto,\n",
    "                    config=config,\n",
    "                    model=model,\n",
    "                    loss_fn=nn.CrossEntropyLoss())\n",
    "np.mean(accs), np.std(accs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 105,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "(0.948380671530962, 0.0822310599780685)"
      ]
     },
     "execution_count": 105,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "accs, losses = validate(task_generator=alphabet_generator,\n",
    "                    forward_fn=forward_euclid,\n",
    "                    config=config,\n",
    "                    model=model,\n",
    "                    loss_fn=nn.CrossEntropyLoss())\n",
    "np.mean(accs), np.std(accs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 106,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "(0.9605266293287277, 0.0738353866045988)"
      ]
     },
     "execution_count": 106,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "config.iterations = 0\n",
    "accs, losses = validate(task_generator=alphabet_generator,\n",
    "                    forward_fn=forward_euclid,\n",
    "                    config=config,\n",
    "                    model=model,\n",
    "                    loss_fn=nn.CrossEntropyLoss())\n",
    "np.mean(accs), np.std(accs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Files already downloaded and verified\n",
      "Files already downloaded and verified\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "(0.9423057215809822, 0.26429950746671643)"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.load_state_dict(torch.load('attn_omniglot_10k'))\n",
    "model = model.to(device)\n",
    "model.eval()\n",
    "\n",
    "accs, losses = validate(task_generator=alphabet_generator,\n",
    "                    forward_fn=forward_euclid,\n",
    "                    config=config,\n",
    "                    model=model,\n",
    "                    loss_fn=F.cross_entropy)\n",
    "np.mean(accs), 100 * np.std(accs)/np.sqrt(1000)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "pytorch38",
   "language": "python",
   "name": "pytorch38"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
