{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn.functional as F\n",
    "import itertools\n",
    "import numpy as np\n",
    "import os\n",
    "import torch_scatter\n",
    "\n",
    "\n",
    "# NOTE(review): torch is already imported above, so this only takes effect\n",
    "# as long as CUDA has not been initialised yet in this kernel\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\"\n",
    "\n",
    "# seed both RNGs used below (torch: noise features and model initialisation,\n",
    "# numpy: feature permutations) so that Restart & Run All is reproducible\n",
    "SEED = 0\n",
    "torch.manual_seed(SEED)\n",
    "np.random.seed(SEED)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "def oddness(vectors):\n",
    "    # Parity label of each row. Inputs are +/-1 vectors: map -1 -> 0 and\n",
    "    # +1 -> 1, count the ones per row and reduce modulo 2 (i.e. XOR).\n",
    "    #   vectors: [n,d]  ->  labels: [n], 0 for even, 1 for odd\n",
    "    ones_per_row = vectors.clamp(min=0).sum(dim=1)\n",
    "    return ones_per_row % 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def xor_noise(n,a,i):\n",
    "    # returns two tensors, features and labels\n",
    "    #   features: [n*2^a, a+i], entries in {-1, +1}\n",
    "    #   labels: [n*2^a], 1 where the active part has odd parity, else 0\n",
    "    # the first a components of the features are 'active' and determine the label\n",
    "    # through XOR (oddness); each of the 2^a active patterns appears n times, in\n",
    "    # the same order on every call. The next i are drawn at random and are\n",
    "    # 'inactive', i.e. they do not determine the label\n",
    "    \n",
    "    patterns = list(itertools.product([0, 1], repeat=a))\n",
    "    noiseless = 2*torch.Tensor(patterns)-1\n",
    "    labels = oddness(noiseless)\n",
    "    # repeat n times\n",
    "    noiseless = noiseless.repeat(n,1)\n",
    "    labels = labels.repeat(n)\n",
    "\n",
    "    # draw the inactive +/-1 features in the same (floating) dtype as the active\n",
    "    # ones, so torch.cat never has to mix integer and floating tensors\n",
    "    noise = 2*torch.randint(0,2,(noiseless.size(0),i),dtype=noiseless.dtype)-1\n",
    "    \n",
    "    return torch.cat((noiseless,noise),1), labels"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def partitioned_xor(n,a,i,s,train=True):\n",
    "    \"\"\"\n",
    "    XOR over a active elements with i inactive ones, with the active elements\n",
    "    among the first s positions if train, else among the positions after s.\n",
    "    \n",
    "    Returns:\n",
    "     - features: [n*2^a, a+i]\n",
    "     - labels: [n*2^a]\n",
    "    \n",
    "    Generates using xor_noise, where the active elements are the first a,\n",
    "    and then permutes as appropriate. For the test split the feature order is\n",
    "    reversed first, which moves the active elements to positions i..a+i-1;\n",
    "    these all lie at or after position s only when s <= i.\n",
    "    \n",
    "    Raises ValueError unless a <= s <= i.\n",
    "    \"\"\"\n",
    "    \n",
    "    if s < a:\n",
    "        raise ValueError('s needs to be larger than a')\n",
    "    if s > a+i:\n",
    "        raise ValueError('s needs to be smaller than a+i')\n",
    "    if a > i:\n",
    "        raise ValueError('a needs to be smaller than i')\n",
    "    if s > i:\n",
    "        # otherwise some active elements of the test split would remain in the\n",
    "        # first s positions, i.e. inside the training partition\n",
    "        raise ValueError('s needs to be at most i')\n",
    "        \n",
    "    inputs, labels = xor_noise(n,a,i)\n",
    "    \n",
    "    if train:\n",
    "        # permute the elements up to s (the active ones stay among them)\n",
    "        perm = torch.tensor(np.random.permutation(s))\n",
    "        inputs[:,:s] = inputs[:,perm]\n",
    "    else:\n",
    "        # flip so that the active elements are last\n",
    "        inputs = torch.flip(inputs, dims=(1,))\n",
    "        # permute the elements after s\n",
    "        perm = torch.tensor(np.random.permutation(np.arange(s,a+i)))\n",
    "        inputs[:,s:] = inputs[:,perm]\n",
    "        \n",
    "    return inputs, labels"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def hot_attn(Q,K,V,temp):\n",
    "    # Softmax attention with a temperature: weights = softmax(Q K^T / temp)\n",
    "    # over the keys (last axis), output = weights @ V\n",
    "    scores = Q @ K.T / temp\n",
    "    return scores.softmax(dim=-1) @ V"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "def mad(x):\n",
    "    # Mean absolute deviation of every column about that column's mean\n",
    "    deviations = x - x.mean(dim=0)\n",
    "    return deviations.abs().mean(dim=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "def rescaled_attn_test(support, support_labels, query, query_labels, temp, iterations, scale):\n",
    "    \"\"\"\n",
    "    Attention classifier with per-feature rescaling.\n",
    "    \n",
    "    Support and query are standardised together; each support class is then\n",
    "    passed `iterations` times through self-attention at temperature `temp`.\n",
    "    The mean absolute deviation of every feature of the result is min-max\n",
    "    mapped to [0, scale] and used as a per-feature weight, after which each\n",
    "    query attends over the reweighted support set and reads out a weighted\n",
    "    average of the support labels.\n",
    "    \n",
    "    Returns a numpy bool array [n_query]: True where (prediction > 0.5)\n",
    "    matches the query label.\n",
    "    \"\"\"\n",
    "    \n",
    "    # standardise the combined set\n",
    "    standard = F.batch_norm(torch.cat((support,query),0),None,None,training=True)\n",
    "    # split back up\n",
    "    support, query = standard[:support.size(0)], standard[support.size(0):]\n",
    "    \n",
    "    s0, s1 = support[support_labels==0], support[support_labels==1]\n",
    "    \n",
    "    for _ in range(iterations):\n",
    "        s0 = hot_attn(s0,s0,s0,temp)\n",
    "        s1 = hot_attn(s1,s1,s1,temp)\n",
    "        \n",
    "    combined = torch.cat((s0,s1),0)\n",
    "    rescale = mad(combined)\n",
    "    spread = rescale.max() - rescale.min()\n",
    "    if spread > 0:\n",
    "        rescale = scale * (rescale - rescale.min()) / spread\n",
    "    else:\n",
    "        # every feature has the same spread: weight them uniformly instead of\n",
    "        # producing 0/0 = NaN weights\n",
    "        rescale = torch.full_like(rescale, scale)\n",
    "    \n",
    "    # the queries are Q and the support set is K (attention matrix\n",
    "    # [n_query, n_support]), so every query gets a softmax-weighted average of\n",
    "    # the support labels; swapping them only type-checks when both sets happen\n",
    "    # to have the same size\n",
    "    predictions = hot_attn(rescale*query,rescale*support,support_labels,1.)\n",
    "    \n",
    "    accuracy = ((predictions > 0.5) == query_labels)\n",
    "    \n",
    "    return accuracy.cpu().numpy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "def no_scale_attn_test(support, support_labels, query, query_labels, standardise=True):\n",
    "    \"\"\"\n",
    "    Plain attention classifier: every query attends over the support set at\n",
    "    temperature 1 and reads out a weighted average of the support labels.\n",
    "    With standardise=True both sets are first standardised together.\n",
    "    Returns the fraction of queries whose (prediction > 0.5) matches the label.\n",
    "    \"\"\"\n",
    "    n_support = support.size(0)\n",
    "    if standardise:\n",
    "        joint = F.batch_norm(torch.cat((support,query),0),None,None,training=True)\n",
    "        support = joint[:n_support]\n",
    "        query = joint[n_support:]\n",
    "\n",
    "    predictions = hot_attn(query,support,support_labels,1.)\n",
    "    correct = (predictions > 0.5) == query_labels\n",
    "    return correct.sum()/query_labels.size(0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "def feature_permute(support, query):\n",
    "    \"\"\"\n",
    "    Apply one random permutation of the feature columns (drawn from numpy's\n",
    "    global RNG) to both the support and the query set, and return the\n",
    "    permuted (support, query).\n",
    "    \"\"\"\n",
    "    n_support = support.size(0)\n",
    "    stacked = torch.cat((support,query),0)\n",
    "    column_order = torch.tensor(np.random.permutation(stacked.size(1)))\n",
    "    shuffled = stacked[:,column_order]\n",
    "    return shuffled[:n_support], shuffled[n_support:]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "# one support and one query task: 5 copies of each of the 2^3 XOR patterns\n",
    "# plus 3 inactive noise features, with the features shuffled identically\n",
    "support, support_labels = xor_noise(5,3,3)\n",
    "query, query_labels = xor_noise(5,3,3)\n",
    "support, query = feature_permute(support, query)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([ True,  True,  True,  True,  True,  True,  True,  True,  True,\n",
       "        True,  True,  True,  True,  True,  True,  True,  True,  True,\n",
       "        True,  True,  True,  True,  True,  True,  True,  True,  True,\n",
       "        True,  True,  True,  True,  True,  True,  True,  True,  True,\n",
       "        True,  True,  True,  True])"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# temp=0.5, iterations=5, scale=2; returns one correctness flag per query\n",
    "rescaled_attn_test(support, support_labels, query, query_labels, 0.5, 5, 2)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Protonets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Feedforward(torch.nn.Module):\n",
    "    \"\"\"Two-layer MLP: Linear(input_size, hidden_size) -> ReLU -> Linear(hidden_size, out_dim).\"\"\"\n",
    "\n",
    "    def __init__(self, input_size, hidden_size, out_dim):\n",
    "        super().__init__()\n",
    "        self.input_size = input_size\n",
    "        self.hidden_size = hidden_size\n",
    "        # submodules are created in this order (fc1 before fc2), which fixes the\n",
    "        # order in which their weights are initialised and registered\n",
    "        self.fc1 = torch.nn.Linear(input_size, hidden_size)\n",
    "        self.relu = torch.nn.ReLU()\n",
    "        self.fc2 = torch.nn.Linear(hidden_size, out_dim)\n",
    "\n",
    "    def forward(self, x):\n",
    "        return self.fc2(self.relu(self.fc1(x)))\n",
    "    \n",
    "class Averager:\n",
    "    \"\"\"Running mean of the values passed to add().\"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        self.n = 0  # number of values seen so far\n",
    "        self.v = 0  # current mean (0 before any value is added)\n",
    "\n",
    "    def add(self, x):\n",
    "        total = self.v * self.n + x\n",
    "        self.n += 1\n",
    "        self.v = total / self.n\n",
    "\n",
    "    def item(self):\n",
    "        return self.v"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "def euclidean_metric(a, b):\n",
    "    # Negative squared Euclidean distance between every row of a [n,d] and\n",
    "    # every row of b [m,d]; returns logits [n,m] (closer -> larger).\n",
    "    diff = a.unsqueeze(1) - b.unsqueeze(0)\n",
    "    return -diff.pow(2).sum(dim=2)\n",
    "\n",
    "def classify_proto(train, test, train_labels, **kwargs):\n",
    "    \"\"\"\n",
    "    Prototypical-network logits for a binary task. Each class prototype is the\n",
    "    mean of its training embeddings (class 0 first, then class 1), and the\n",
    "    logits are the negative squared Euclidean distances from every test\n",
    "    embedding to the two prototypes, shape [n_test, 2]. Extra keyword\n",
    "    arguments are ignored.\n",
    "    \"\"\"\n",
    "    proto = torch.stack([train[train_labels == label].mean(0) for label in (0, 1)])\n",
    "    return euclidean_metric(test, proto)\n",
    "\n",
    "def train(n, a, i, out_dim, max_epoch=1000, verbose = False, val_tasks = 1000):\n",
    "    \"\"\"\n",
    "    Meta-train a prototypical network on XOR tasks and evaluate it.\n",
    "    \n",
    "    Every episode draws a support and a query set with xor_noise(n, a, i),\n",
    "    shuffles their features with one shared permutation (feature_permute),\n",
    "    embeds both with a Feedforward(a+i -> 100 -> out_dim) network and scores\n",
    "    the queries by distance to the two class prototypes (classify_proto).\n",
    "    The network is trained with Adam (lr=0.001) for max_epoch episodes and\n",
    "    then evaluated, without gradients, on val_tasks fresh episodes.\n",
    "    \n",
    "    Returns a list of val_tasks numpy scalars: the query accuracy per episode.\n",
    "    \"\"\"\n",
    "    seq_len = a + i\n",
    "    \n",
    "    # Set up model\n",
    "    device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
    "    model = Feedforward(input_size=seq_len, hidden_size=100, out_dim=out_dim)\n",
    "    model = model.to(device)\n",
    "    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)\n",
    "\n",
    "    def run_episode():\n",
    "        # Sample one task, embed it and return (loss, accuracy) on its queries\n",
    "        support, support_labels = xor_noise(n, a, i)\n",
    "        query, query_labels = xor_noise(n, a, i)\n",
    "        support, query = feature_permute(support, query)\n",
    "        support = support.to(device)\n",
    "        query = query.to(device)\n",
    "        support_labels = support_labels.to(device)\n",
    "        query_labels = query_labels.to(device)\n",
    "\n",
    "        logits = classify_proto(model(support), model(query), support_labels)\n",
    "        loss = F.cross_entropy(logits, query_labels.long())\n",
    "        acc = (logits.argmax(-1) == query_labels).sum()/query_labels.size(0)\n",
    "        return loss, acc\n",
    "\n",
    "    for epoch in range(1, max_epoch + 1):\n",
    "        optimizer.zero_grad()\n",
    "        model.train()\n",
    "        loss, acc = run_episode()\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        if verbose:\n",
    "            print('epoch {}, loss={:.4f} acc={:.4f}'.format(epoch, loss.item(), acc))\n",
    "    \n",
    "    # Validate on exactly val_tasks episodes; no gradients are needed here, so\n",
    "    # skip building the autograd graph\n",
    "    model.eval()\n",
    "    val_accs = []\n",
    "    with torch.no_grad():\n",
    "        for epoch in range(1, val_tasks + 1):\n",
    "            loss, acc = run_episode()\n",
    "            val_accs.append(acc.cpu().numpy())\n",
    "            if verbose:\n",
    "                print('epoch {}, loss={:.4f} acc={:.4f}'.format(epoch, loss.item(), acc))\n",
    "    \n",
    "    return val_accs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "L: 5, a: 2, out dim: 1, acc: 0.564965009689331\n",
      "L: 5, a: 3, out dim: 1, acc: 0.5538288354873657\n",
      "L: 5, a: 4, out dim: 1, acc: 0.5982983112335205\n",
      "$56.5 \\pm 0.5$ & $55.4 \\pm 0.6$ & $59.8 \\pm 0.6$\n",
      "\n",
      "L: 5, a: 2, out dim: 5, acc: 0.7391892671585083\n",
      "L: 5, a: 3, out dim: 5, acc: 0.6600350141525269\n",
      "L: 5, a: 4, out dim: 5, acc: 0.9122872948646545\n",
      "$73.9 \\pm 0.5$ & $66.0 \\pm 0.7$ & $91.2 \\pm 0.6$\n",
      "\n",
      "L: 5, a: 2, out dim: 10, acc: 0.857357382774353\n",
      "L: 5, a: 3, out dim: 10, acc: 0.7481982111930847\n",
      "L: 5, a: 4, out dim: 10, acc: 0.9999499917030334\n",
      "$85.7 \\pm 0.4$ & $74.8 \\pm 0.6$ & $100.0 \\pm 0.0$\n",
      "\n",
      "L: 5, a: 2, out dim: 25, acc: 0.9043042659759521\n",
      "L: 5, a: 3, out dim: 25, acc: 0.804729700088501\n",
      "L: 5, a: 4, out dim: 25, acc: 1.0\n",
      "$90.4 \\pm 0.3$ & $80.5 \\pm 0.6$ & $100.0 \\pm 0.0$\n",
      "\n",
      "L: 10, a: 2, out dim: 1, acc: 0.5189689993858337\n",
      "L: 10, a: 3, out dim: 1, acc: 0.5014263987541199\n",
      "L: 10, a: 4, out dim: 1, acc: 0.5024024248123169\n",
      "$51.9 \\pm 0.3$ & $50.1 \\pm 0.2$ & $50.2 \\pm 0.2$\n",
      "\n",
      "L: 10, a: 2, out dim: 10, acc: 0.5709208846092224\n",
      "L: 10, a: 3, out dim: 10, acc: 0.49997493624687195\n",
      "L: 10, a: 4, out dim: 10, acc: 0.5027903318405151\n",
      "$57.1 \\pm 0.4$ & $50.0 \\pm 0.2$ & $50.3 \\pm 0.2$\n",
      "\n",
      "L: 10, a: 2, out dim: 20, acc: 0.5794795155525208\n",
      "L: 10, a: 3, out dim: 20, acc: 0.5048548579216003\n",
      "L: 10, a: 4, out dim: 20, acc: 0.5031657218933105\n",
      "$57.9 \\pm 0.4$ & $50.5 \\pm 0.2$ & $50.3 \\pm 0.2$\n",
      "\n",
      "L: 10, a: 2, out dim: 100, acc: 0.6314814686775208\n",
      "L: 10, a: 3, out dim: 100, acc: 0.5018518567085266\n",
      "L: 10, a: 4, out dim: 100, acc: 0.5052427053451538\n",
      "$63.1 \\pm 0.4$ & $50.2 \\pm 0.2$ & $50.5 \\pm 0.2$\n",
      "\n"
     ]
    }
   ],
   "source": [
    "n = 5\n",
    "for L in [5, 10]:\n",
    "    for out_dim in [1, L, 2*L, L**2]:\n",
    "        accs = []\n",
    "        for a in [2, 3, 4]:\n",
    "            i = L - a\n",
    "\n",
    "            # Protonet; val_acc holds one accuracy per validation task\n",
    "            val_acc = train(n, a, i, out_dim)\n",
    "            print('L: {}, a: {}, out dim: {}, acc: {}'.format(L, a, out_dim, np.mean(val_acc)))\n",
    "            accs.append(val_acc)\n",
    "        # LaTeX table row: mean +/- standard error over the validation tasks, in %\n",
    "        print(' & '.join(r'${:.1f} \\pm {:.1f}$'.format(100*np.mean(acc), 100*np.std(acc)/np.sqrt(len(acc)))\n",
    "                         for acc in accs))\n",
    "        print()\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Attention model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "L: 5, a: 2, out dim: 5, acc: 0.9954954954954955\n",
      "L: 5, a: 3, out dim: 5, acc: 1.0\n",
      "L: 5, a: 4, out dim: 5, acc: 1.0\n",
      "$99.550 \\pm 0.212$ & $100.000 \\pm 0.000$ & $100.000 \\pm 0.000$\n",
      "\n",
      "L: 10, a: 2, out dim: 10, acc: 0.763913913913914\n",
      "L: 10, a: 3, out dim: 10, acc: 0.8317317317317318\n",
      "L: 10, a: 4, out dim: 10, acc: 0.961498998998999\n",
      "$76.391 \\pm 1.343$ & $83.173 \\pm 1.183$ & $96.150 \\pm 0.608$\n",
      "\n"
     ]
    }
   ],
   "source": [
    "n = 5\n",
    "N_TASKS = 1000\n",
    "\n",
    "for L in [5, 10]:\n",
    "    accs = []\n",
    "    for a in [2, 3, 4]:\n",
    "        i = L - a\n",
    "        \n",
    "        vas = []\n",
    "        for epoch in range(N_TASKS):\n",
    "            # fresh task: support and query share one random feature permutation\n",
    "            support, support_labels = xor_noise(n, a, i)\n",
    "            query, query_labels = xor_noise(n, a, i)\n",
    "            support, query = feature_permute(support, query)\n",
    "\n",
    "            # Attn: reduce the per-query correctness flags to this task's accuracy\n",
    "            acc = rescaled_attn_test(support, support_labels, query, query_labels, 0.5, 5, 2)\n",
    "            vas.append(acc.mean())\n",
    "        \n",
    "        print('L: {}, a: {}, out dim: {}, acc: {}'.format(L, a, L, np.mean(vas)))\n",
    "        accs.append(vas)\n",
    "    # LaTeX table row: mean +/- standard error of the per-task accuracy, in %\n",
    "    # (the same statistic as the Protonet table)\n",
    "    print(' & '.join(r'${:.3f} \\pm {:.3f}$'.format(100*np.mean(acc), 100*np.std(acc)/np.sqrt(len(acc)))\n",
    "                     for acc in accs))\n",
    "    print()\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Protonet on partitioned XOR"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train_partitioned(n, a, i, s, out_dim, max_epoch=1000, verbose = False, val_tasks = 1000):\n",
    "    \"\"\"\n",
    "    Meta-train a prototypical network on partitioned XOR and evaluate it.\n",
    "    \n",
    "    Training episodes come from partitioned_xor(n, a, i, s, train=True), whose\n",
    "    active features lie among the first s positions; evaluation episodes come\n",
    "    from partitioned_xor(n, a, i, s, train=False), whose active features lie\n",
    "    after position s. Support and query are embedded with a\n",
    "    Feedforward(a+i -> 100 -> out_dim) network and the queries are scored by\n",
    "    distance to the two class prototypes (classify_proto). The network is\n",
    "    trained with Adam (lr=0.001) for max_epoch episodes and then evaluated,\n",
    "    without gradients, on val_tasks episodes.\n",
    "    \n",
    "    Returns a list of val_tasks numpy scalars: the query accuracy per episode.\n",
    "    \"\"\"\n",
    "    seq_len = a + i\n",
    "    \n",
    "    # Set up model\n",
    "    device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
    "    model = Feedforward(input_size=seq_len, hidden_size=100, out_dim=out_dim)\n",
    "    model = model.to(device)\n",
    "    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)\n",
    "\n",
    "    def run_episode(train_split):\n",
    "        # Sample one task from the given split, embed it and return (loss, accuracy)\n",
    "        support, support_labels = partitioned_xor(n, a, i, s, train=train_split)\n",
    "        query, query_labels = partitioned_xor(n, a, i, s, train=train_split)\n",
    "        support = support.to(device)\n",
    "        query = query.to(device)\n",
    "        support_labels = support_labels.to(device)\n",
    "        query_labels = query_labels.to(device)\n",
    "\n",
    "        logits = classify_proto(model(support), model(query), support_labels)\n",
    "        loss = F.cross_entropy(logits, query_labels.long())\n",
    "        acc = (logits.argmax(-1) == query_labels).sum()/query_labels.size(0)\n",
    "        return loss, acc\n",
    "\n",
    "    for epoch in range(1, max_epoch + 1):\n",
    "        optimizer.zero_grad()\n",
    "        model.train()\n",
    "        loss, acc = run_episode(True)\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        if verbose:\n",
    "            print('epoch {}, loss={:.4f} acc={:.4f}'.format(epoch, loss.item(), acc))\n",
    "    \n",
    "    # Validate on exactly val_tasks test-split episodes; no gradients are needed\n",
    "    # here, so skip building the autograd graph\n",
    "    model.eval()\n",
    "    val_accs = []\n",
    "    with torch.no_grad():\n",
    "        for epoch in range(1, val_tasks + 1):\n",
    "            loss, acc = run_episode(False)\n",
    "            val_accs.append(acc.cpu().numpy())\n",
    "            if verbose:\n",
    "                print('epoch {}, loss={:.4f} acc={:.4f}'.format(epoch, loss.item(), acc))\n",
    "    \n",
    "    return val_accs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "# single comparison at L = a + i = 5: standard protonet training vs. training\n",
    "# on active features in the first s positions and testing on active features\n",
    "# after position s (partitioned XOR)\n",
    "n = 5\n",
    "a = 2\n",
    "i = 3\n",
    "s = 3\n",
    "out_dim = (a + i) ** 2\n",
    "val_acc = train(n, a, i, out_dim)\n",
    "val_acc_partitioned = train_partitioned(n, a, i, s, out_dim)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "L: 10, a: 2, out dim: 1, acc: 0.5089098215103149\n",
      "L: 10, a: 3, out dim: 1, acc: 0.5016274452209473\n",
      "L: 10, a: 4, out dim: 1, acc: 0.5018521547317505\n",
      "$0.509$ & $0.502$ & $0.502$\n",
      "\n",
      "L: 10, a: 2, out dim: 10, acc: 0.5063568949699402\n",
      "L: 10, a: 3, out dim: 10, acc: 0.49984949827194214\n",
      "L: 10, a: 4, out dim: 10, acc: 0.5019532442092896\n",
      "$0.506$ & $0.500$ & $0.502$\n",
      "\n",
      "L: 10, a: 2, out dim: 20, acc: 0.5112118124961853\n",
      "L: 10, a: 3, out dim: 20, acc: 0.4973980486392975\n",
      "L: 10, a: 4, out dim: 20, acc: 0.5030782222747803\n",
      "$0.511$ & $0.497$ & $0.503$\n",
      "\n",
      "L: 10, a: 2, out dim: 100, acc: 0.5072084069252014\n",
      "L: 10, a: 3, out dim: 100, acc: 0.4991496801376343\n",
      "L: 10, a: 4, out dim: 100, acc: 0.4998376965522766\n",
      "$0.507$ & $0.499$ & $0.500$\n",
      "\n"
     ]
    }
   ],
   "source": [
    "n = 5\n",
    "s = 4\n",
    "\n",
    "for L in [10]:\n",
    "    for out_dim in [1, L, 2*L, L**2]:\n",
    "        accs = []\n",
    "        for a in [2, 3, 4]:\n",
    "            i = L - a\n",
    "\n",
    "            # Protonet; val_acc holds one accuracy per validation task, so it\n",
    "            # has to be reduced before it can be printed as a number\n",
    "            val_acc = train_partitioned(n, a, i, s, out_dim)\n",
    "            print('L: {}, a: {}, out dim: {}, acc: {}'.format(L, a, out_dim, np.mean(val_acc)))\n",
    "            accs.append(val_acc)\n",
    "        # LaTeX table row: mean +/- standard error over the validation tasks, in %\n",
    "        print(' & '.join(r'${:.1f} \\pm {:.1f}$'.format(100*np.mean(acc), 100*np.std(acc)/np.sqrt(len(acc)))\n",
    "                         for acc in accs))\n",
    "        print()\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "pytorch38",
   "language": "python",
   "name": "pytorch38"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
