{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "e63702b3",
   "metadata": {
    "vscode": {
     "languageId": "markdown"
    }
   },
   "outputs": [],
   "source": [
    "# This notebook provides a toy example to demonstrate how LifelongLearningEmbedding (LLE) works.\n",
    "# The example is intentionally simple and lightweight, so it can be run on machines with limited hardware resources. The goal is to make it easy for anyone to test and understand the core concepts of LLE without requiring specialized hardware.\n",
    "\n",
    "# This notebook is meant to be supplemental material to the ICLR paper submission titled:\n",
    "# LIFELONG-LEARNING EMBEDDINGS: INCREMENTAL AND CONTINUAL REPRESENTATION LEARNING FOR DYNAMIC E-COMMERCE TRENDS."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "64dc7f83",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "\n",
    "import experiments.lle as lle"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "9bb0c934",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 1. Toy Example - id matcher \n",
    "# In a first toy example we want to train an id matcher. The goal is to have a matcher that matches 0 -> 1 and 1 -> 0."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "3426a086",
   "metadata": {},
   "outputs": [],
   "source": [
    "# This is a simple model that uses the LifelongLearningEmbedding\n",
    "# to demonstrate its functionality\n",
    "# The model consists of an embedding layer, a linear layer and an output layer\n",
    "# Just think of the LifelongLearningEmbedding as a drop-in replacement for nn.Embedding\n",
    "class SimpleMatcher(nn.Module):\n",
    "    def __init__(self, vocab_map, embedding_dim):\n",
    "        super().__init__()\n",
    "\n",
    "        self.embedding = lle.LifelongLearningEmbedding(vocab_map, embedding_dim)\n",
    "        # instead of nn.Embedding use the LifelongLearningEmbedding class\n",
    "        # you can use all the argumentes of nn.Embedding as kwargs\n",
    "\n",
    "        self.activation = nn.LeakyReLU()\n",
    "        self.linear = nn.Linear(self.embedding.embedding_dim, 5) # Intermediate layer\n",
    "        self.out = nn.Linear(5, 2) # Output layer to match to 2 classes (\"a\", \"b\")\n",
    "    \n",
    "    def forward(self, x):\n",
    "        x = self.embedding(x).squeeze(1)\n",
    "        x = self.activation(x)\n",
    "        x = self.linear(x)\n",
    "        x = self.activation(x)\n",
    "        x = self.out(x)\n",
    "        return x\n",
    "    \n",
    "    @torch.no_grad()\n",
    "    def extend_embedding(self, new_dim: int):\n",
    "        # Here we need to adjust the linear layer after the embedding layer\n",
    "        # to match the new embedding dimension\n",
    "        new_linear = nn.Linear(new_dim, 5)\n",
    "        new_linear.weight[:, :self.embedding.embedding_dim] = self.linear.weight \n",
    "        new_linear.bias = self.linear.bias\n",
    "        self.embedding.extend_embedding_dim(new_dim) # call the extend method of the embedding\n",
    "        self.linear = new_linear\n",
    "\n",
    "# initialize the model with a vocab of size 2 (0 and 1) and embedding dimension of 5\n",
    "model = SimpleMatcher(vocab_map={0, 1}, embedding_dim=5)\n",
    "            "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "c68aec29",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "SimpleMatcher(\n",
       "  (embedding): LifelongLearningEmbedding(\n",
       "    (embedding): Embedding(2, 5)\n",
       "  )\n",
       "  (activation): LeakyReLU(negative_slope=0.01)\n",
       "  (linear): Linear(in_features=5, out_features=5, bias=True)\n",
       "  (out): Linear(in_features=5, out_features=2, bias=True)\n",
       ")"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# lets take a look at the model\n",
    "model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "67b81768",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Output for 0: 1\n",
      "Output for 1: 0\n"
     ]
    }
   ],
   "source": [
    "# let's train the model to learn the mapping 0 -> 1 and 1 -> 0\n",
    "# everthing is the same here as with any other PyTorch model\n",
    "optimizer = torch.optim.SGD(model.parameters(), lr=0.001)\n",
    "criterion = nn.CrossEntropyLoss()\n",
    "\n",
    "for _ in range(100):\n",
    "    model.zero_grad\n",
    "\n",
    "    inputs = torch.tensor([[0], [1], [0], [1]], dtype=torch.long)\n",
    "    labels = torch.tensor([1, 0, 1, 0], dtype=torch.long)  # 0 -> 1 and 1 -> 0\n",
    "\n",
    "    outputs = model(inputs).squeeze()\n",
    "    loss = criterion(outputs, labels)\n",
    "    loss.backward()\n",
    "    optimizer.step()\n",
    "\n",
    "print(f\"Output for 0: {model(torch.tensor([[0]], dtype=torch.long)).squeeze().detach().argmax()}\")\n",
    "print(f\"Output for 1: {model(torch.tensor([[1]], dtype=torch.long)).squeeze().detach().argmax()}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "1d107a0d",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "OrderedDict([('embedding.embedding.weight',\n",
       "              tensor([[-0.8109, -1.0397, -1.1221,  2.9629,  0.8672],\n",
       "                      [-0.1864, -1.3836, -1.3590,  1.6475, -0.6617]])),\n",
       "             ('linear.weight',\n",
       "              tensor([[-0.0400,  0.0478,  0.0571, -0.0684,  0.1449],\n",
       "                      [ 0.4311,  0.4138,  0.3981,  0.1471, -0.1634],\n",
       "                      [-0.2830,  0.3371,  0.2720,  0.4930,  0.4958],\n",
       "                      [-0.3038, -0.1703,  0.2491,  0.3964,  0.3716],\n",
       "                      [ 0.4148,  0.0207,  0.3414,  0.2908, -0.4975]])),\n",
       "             ('linear.bias',\n",
       "              tensor([ 0.1973,  0.0560, -0.1734, -0.4439, -0.2062])),\n",
       "             ('out.weight',\n",
       "              tensor([[-0.2380,  0.2319, -0.7353, -0.3855,  0.2107],\n",
       "                      [-0.2020,  0.0881,  0.1348,  0.3854, -0.0684]])),\n",
       "             ('out.bias', tensor([ 0.6491, -0.0806]))])"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# that worked as expected, now lets add a new vocab and increase the embedding dimension\n",
    "# but first lets take a look at the weigths\n",
    "model.state_dict()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "f3d3c069",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[-0.9212,  0.5475]], grad_fn=<AddmmBackward0>)"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# and lets have a look at the output layer.\n",
    "model(torch.tensor([0]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "2bf0194b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# now we extend the embedding dimension to 7\n",
    "model.extend_embedding(new_dim=7)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "3bc318e5",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "OrderedDict([('embedding.embedding.weight',\n",
       "              tensor([[-0.8109, -1.0397, -1.1221,  2.9629,  0.8672,  0.0000,  0.0000],\n",
       "                      [-0.1864, -1.3836, -1.3590,  1.6475, -0.6617,  0.0000,  0.0000]])),\n",
       "             ('linear.weight',\n",
       "              tensor([[-0.0400,  0.0478,  0.0571, -0.0684,  0.1449,  0.1418,  0.1775],\n",
       "                      [ 0.4311,  0.4138,  0.3981,  0.1471, -0.1634, -0.0549, -0.2763],\n",
       "                      [-0.2830,  0.3371,  0.2720,  0.4930,  0.4958,  0.0189, -0.1231],\n",
       "                      [-0.3038, -0.1703,  0.2491,  0.3964,  0.3716,  0.0023,  0.0165],\n",
       "                      [ 0.4148,  0.0207,  0.3414,  0.2908, -0.4975, -0.3589,  0.2156]])),\n",
       "             ('linear.bias',\n",
       "              tensor([ 0.1973,  0.0560, -0.1734, -0.4439, -0.2062])),\n",
       "             ('out.weight',\n",
       "              tensor([[-0.2380,  0.2319, -0.7353, -0.3855,  0.2107],\n",
       "                      [-0.2020,  0.0881,  0.1348,  0.3854, -0.0684]])),\n",
       "             ('out.bias', tensor([ 0.6491, -0.0806]))])"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# the state dict looks the same except for the embedding weights which have been extended with 0s for each new dimension.\n",
    "# and the additional linear layer weights which will not have any effect yet because they are connected to the new embedding dimensions\n",
    "model.state_dict()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "3a6dd944",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Output for 0: 1\n",
      "Output for 1: 0\n"
     ]
    }
   ],
   "source": [
    "# as we can see the output layer is the same as before\n",
    "print(f\"Output for 0: {model(torch.tensor([[0]], dtype=torch.long)).squeeze().detach().argmax()}\")\n",
    "print(f\"Output for 1: {model(torch.tensor([[1]], dtype=torch.long)).squeeze().detach().argmax()}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "d1d3651e",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[-0.9212,  0.5475]], grad_fn=<AddmmBackward0>)"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#even the output layer outputs the same\n",
    "model(torch.tensor([0]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "85376b38",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "OrderedDict([('embedding.embedding.weight',\n",
       "              tensor([[-0.8109, -1.0397, -1.1221,  2.9629,  0.8672,  0.0000,  0.0000],\n",
       "                      [-0.1864, -1.3836, -1.3590,  1.6475, -0.6617,  0.0000,  0.0000],\n",
       "                      [-0.4987, -1.2117, -1.2406,  2.3052,  0.1028,  0.0000,  0.0000],\n",
       "                      [-0.4987, -1.2117, -1.2406,  2.3052,  0.1028,  0.0000,  0.0000]])),\n",
       "             ('linear.weight',\n",
       "              tensor([[-0.0400,  0.0478,  0.0571, -0.0684,  0.1449,  0.1418,  0.1775],\n",
       "                      [ 0.4311,  0.4138,  0.3981,  0.1471, -0.1634, -0.0549, -0.2763],\n",
       "                      [-0.2830,  0.3371,  0.2720,  0.4930,  0.4958,  0.0189, -0.1231],\n",
       "                      [-0.3038, -0.1703,  0.2491,  0.3964,  0.3716,  0.0023,  0.0165],\n",
       "                      [ 0.4148,  0.0207,  0.3414,  0.2908, -0.4975, -0.3589,  0.2156]])),\n",
       "             ('linear.bias',\n",
       "              tensor([ 0.1973,  0.0560, -0.1734, -0.4439, -0.2062])),\n",
       "             ('out.weight',\n",
       "              tensor([[-0.2380,  0.2319, -0.7353, -0.3855,  0.2107],\n",
       "                      [-0.2020,  0.0881,  0.1348,  0.3854, -0.0684]])),\n",
       "             ('out.bias', tensor([ 0.6491, -0.0806]))])"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# now lets extend the vocab\n",
    "# now we want to learn a new mapping 2-> 0, 3 ->2, 1 -> 3, and 0 -> 1\n",
    "# first we need to update the vocab map of the embedding\n",
    "# and lets use an average strategy for the new embeddings\n",
    "# this means the new weigths are initialised through the average of the existing embeddings\n",
    "# just call the update_embedding method of the embedding\n",
    "model.embedding.update_embedding(new_vocab_map={0, 1, 2, 3}, strategy=lle.strat_avg_all)\n",
    "# now the embedding has been updated\n",
    "# lets take a look at the state dict\n",
    "model.state_dict()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "072b7585",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Output for 0: 1\n",
      "Output for 1: 0\n"
     ]
    }
   ],
   "source": [
    "# we can see that the first two rows are the same as before\n",
    "# the new rows are the average of the first two rows\n",
    "# the new dimensions are still 0s\n",
    "# the linear layer is the same as before\n",
    "# and if we make a prediction with the old inputs we get the same result as before\n",
    "print(f\"Output for 0: {model(torch.tensor([[0]], dtype=torch.long)).squeeze().detach().argmax()}\")\n",
    "print(f\"Output for 1: {model(torch.tensor([[1]], dtype=torch.long)).squeeze().detach().argmax()}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "83cbaecd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# so we have successfully extended the embedding and the model still works as before\n",
    "# now we need to change the heads to learn the new mappings\n",
    "with torch.no_grad():\n",
    "    new_out = nn.Linear(5, 4) # Output layer to match to 4 classes (\"0\", \"1\", \"2\", \"3\")\n",
    "    new_out.weight[:2] = model.out.weight # keep the weights for the old classes\n",
    "    new_out.bias[:2] = model.out.bias # keep the bias for the old classes\n",
    "    model.out = new_out"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "67af21ce",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Output for 0: 1\n",
      "Output for 1: 3\n",
      "Output for 2: 0\n",
      "Output for 3: 2\n"
     ]
    }
   ],
   "source": [
    "# now we can train the model on the new vocab\n",
    "optimizer = torch.optim.SGD(model.parameters(), lr=0.001) # Important! Reinitialize the optimizer to include the new parameters\n",
    "\n",
    "for _ in range(500):\n",
    "    model.zero_grad\n",
    "\n",
    "    inputs = torch.tensor([[0], [1], [2], [3]], dtype=torch.long)\n",
    "    labels = torch.tensor([1, 3, 0, 2], dtype=torch.long)  # 0 -> 1, 1 -> 0, 2 -> 3, 3 -> 2\n",
    "\n",
    "    outputs = model(inputs).squeeze()\n",
    "    loss = criterion(outputs, labels)\n",
    "    loss.backward()\n",
    "    optimizer.step()\n",
    "\n",
    "print(f\"Output for 0: {model(torch.tensor([[0]], dtype=torch.long)).squeeze().detach().argmax()}\")\n",
    "print(f\"Output for 1: {model(torch.tensor([[1]], dtype=torch.long)).squeeze().detach().argmax()}\")\n",
    "print(f\"Output for 2: {model(torch.tensor([[2]], dtype=torch.long)).squeeze().detach().argmax()}\")\n",
    "print(f\"Output for 3: {model(torch.tensor([[3]], dtype=torch.long)).squeeze().detach().argmax()}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "e6c84589",
   "metadata": {},
   "outputs": [],
   "source": [
    "# good that it worked :)\n",
    "# as a last step lets change the vocab again this time only remove the 0 from the vocabulary\n",
    "# its the same fucntion as before\n",
    "model.embedding.update_embedding(new_vocab_map={1, 2, 3})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "2ae621c4",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{1: 0, 2: 1, 3: 2}"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# until now we had a direct mapping from input to embedding index, with the removal of 0 the mapping is now:\n",
    "model.embedding.vocab_map"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "ed2c18e8",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Output for 1: 3\n",
      "Output for 2: 0\n",
      "Output for 3: 2\n"
     ]
    }
   ],
   "source": [
    "# however the model still works for the remaining inputs\n",
    "# just make sure the mapping is correct\n",
    "vocab_map = model.embedding.vocab_map\n",
    "print(f\"Output for 1: {model(torch.tensor([[vocab_map.get(1)]], dtype=torch.long)).squeeze().detach().argmax()}\")\n",
    "print(f\"Output for 2: {model(torch.tensor([[vocab_map.get(2)]], dtype=torch.long)).squeeze().detach().argmax()}\")\n",
    "print(f\"Output for 3: {model(torch.tensor([[vocab_map.get(3)]], dtype=torch.long)).squeeze().detach().argmax()}\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
