{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "762ba8e1-d912-480d-90ef-4e35889cbaa8",
   "metadata": {},
   "source": [
    "# Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "dcc79b29-2b8e-4ecd-bed6-d639dfff2b14",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "\n",
    "\n",
    "from config import config\n",
    "from llmtelora import Lllmtelora\n",
    "from model_gpt_lr2 import ModelGPTLR2"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d7b0f5e1-79c8-4d91-8445-72f18a2aa7c6",
   "metadata": {},
   "source": [
    "# Helpers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "0c8c7cf6-c5b1-4c8c-b1c6-0107a68c1d1a",
   "metadata": {},
   "outputs": [],
   "source": [
    "def generate(manager, output_length=200):\n",
    "    # Encode the demo prompt, generate output_length new tokens, then decode and print.\n",
    "    text_inp = manager.data.encode(manager.args.prompt_demo)\n",
    "    output = manager.model.generate(text_inp, output_length)\n",
    "    text_out = manager.data.decode(output)\n",
    "\n",
    "    text = f'\\n\\n----\\nDemo      for generate method\\n'\n",
    "    text += f'Input  : {manager.args.prompt_demo}\\n'\n",
    "    text += f'Output : {text_out}\\n'\n",
    "    print(text)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d23e37c0-a768-426a-8753-6111f488dd68",
   "metadata": {},
   "source": [
    "# Base model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "92273201-4712-4035-b236-a4f1474b85c9",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "\n",
      "----\n",
      "Demo      for generate method\n",
      "Input  : There was a\n",
      "Output : There was a young boy,helf. He had a curious little girl. She decided to send a package to mail it to him inside so she could see anything.\n",
      "\n",
      "\"Your tooth,\" she said. \"A crack only one, little bee Slow\".\n",
      "\n",
      "The owl replied, \"Well, I will give you something very, it will work!\"\n",
      "\n",
      "\"I will move down, go,\" said. \"Do you like math!\"\n",
      "But it said, \"I forget these olives so we can take a break and easily.\"\n",
      "\n",
      "Jack spread his balance on the roof with a smile on his face. He held hands above and he felt the cell and moved until he was Patel.\n",
      "\n",
      "Toby knew he had a good idea. He started going back to get more andbank. It was just for help.Once there was a little girl called Mary. Susie was three years old and she was feeling a bit scared.\n",
      "\n",
      "One day, she decided not go outside\n",
      "\n"
     ]
    }
   ],
   "source": [
    "args = config()\n",
    "args.name = 'base'\n",
    "\n",
    "manager = Lllmtelora(args, dev=True)\n",
    "manager.set_data()\n",
    "manager.set_model()\n",
    "manager.set_trainer()\n",
    "manager.load()\n",
    "\n",
    "generate(manager)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a7b14fea-09e3-4f57-9650-657d307c8bd4",
   "metadata": {},
   "source": [
    "# Trained low-rank model (`d = 2`, `r = 1`)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "7979ab44-9cb0-4e4a-a388-27cfe979dcc4",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "\n",
      "----\n",
      "Demo      for generate method\n",
      "Input  : There was a\n",
      "Output : There was a little girl called staying on. She was three and she to explore dress. She had in her car. and Ben like to rowballs doorway.. She\n",
      "\n",
      "Anna opens sad and She asked not like clean.. She acted like all the fun she had noises the porch and and she felt curious.\n",
      "\n",
      "At first closet, she made at the her doll toys and\n",
      "\n",
      "Lily picked up her toys. But she was ignorant so She asked that her were so like a doll to One day, Lily got a attention to said.No, Lily. I think't belong I'm boring.\" Go away. Lily held her dolluably left and and let to stop with anyone friends. so did and did to give them a She saw her bowber. asked Lily she was happy so have her car with. The made her new doll and it smiled. She knew and swung her doll, and she could accepted her she also. And went outside to play house, mailman and didn books\n",
      "\n"
     ]
    }
   ],
   "source": [
    "args = config()\n",
    "args.name = 'demo_lr2_rank1'\n",
    "args.mode = 'lr2'\n",
    "args.d = 2\n",
    "args.r = 1\n",
    "\n",
    "manager = Lllmtelora(args, dev=True)\n",
    "manager.set_data()\n",
    "manager.load()\n",
    "manager.set_trainer()\n",
    "\n",
    "generate(manager)"
   ]
  },
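  {
   "cell_type": "markdown",
   "id": "3a9e6c52-1f0b-4d7e-8c21-6b5f0e9d4a13",
   "metadata": {},
   "source": [
    "The two low-rank runs differ only in `args.r`. Judging from the code in this notebook (the loop over `max_new_tokens // self.d`, the softmax over `lm_head_weight(x).reshape(-1, self.r)` and the factors `log_U`, `log_V`), `d` appears to be the number of tokens predicted per forward pass and `r` the number of mixture components. Assuming the head models the joint distribution of the next two tokens as\n",
    "\n",
    "$$P(y_1, y_2 \\mid x) = \\sum_{k=1}^{r} w_k \\, U_k(y_1) \\, V_k(y_2),$$\n",
    "\n",
    "the matrix $P$ has rank at most $r$, and for $r = 1$ it factorizes into $P(y_1 \\mid x) \\, P(y_2 \\mid x)$, i.e. the two tokens are predicted independently given the prefix. The cell below is a minimal sketch on random toy tensors (the vocabulary size, the seed and the helper `low_rank_joint` are hypothetical, not part of `model_gpt_lr2`) that checks both properties."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8d41b7f0-2c6e-4a95-b3d8-0e7f9a1c2b64",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "\n",
    "def low_rank_joint(w, U, V):\n",
    "    # Hypothetical helper: w has shape (r,), U and V have shape (vocab, r);\n",
    "    # column k of U (of V) is a distribution over the first (second) token.\n",
    "    # Returns P with P[i, j] = sum_k w[k] * U[i, k] * V[j, k].\n",
    "    return (U * w) @ V.T\n",
    "\n",
    "\n",
    "torch.manual_seed(0)\n",
    "vocab = 6  # toy vocabulary size (assumption)\n",
    "joints = {}\n",
    "for r in (1, 2):\n",
    "    w = torch.softmax(torch.randn(r, dtype=torch.float64), dim=0)\n",
    "    U = torch.softmax(torch.randn(vocab, r, dtype=torch.float64), dim=0)\n",
    "    V = torch.softmax(torch.randn(vocab, r, dtype=torch.float64), dim=0)\n",
    "    P = low_rank_joint(w, U, V)\n",
    "    joints[r] = P\n",
    "    rank = torch.linalg.matrix_rank(P).item()\n",
    "    print(f'r = {r}: total probability = {P.sum().item():.6f}, rank of P = {rank}')\n",
    "\n",
    "# For r = 1 the joint distribution is the outer product of its two marginals.\n",
    "P1 = joints[1]\n",
    "print(torch.allclose(P1, torch.outer(P1.sum(dim=1), P1.sum(dim=0))))"
   ]
  },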
  {
   "cell_type": "markdown",
   "id": "ef70bfdf-b7c1-4abb-9648-a0256c466247",
   "metadata": {},
   "source": [
    "# Trained low-rank model (`d = 2`, `r = 2`, customized greedy generation)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "4fdb740e-8853-492a-a93b-d945b1c3f9c7",
   "metadata": {},
   "outputs": [],
   "source": [
    "class ModelGPTLR2Dev(ModelGPTLR2):\n",
    "    @torch.no_grad()\n",
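    "    def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):\n",
    "        # Each forward pass emits d tokens, so max_new_tokens // d steps are needed.\n",
    "        # temperature and top_k are kept for signature compatibility but are not used here.\n",
    "        for _ in range(max_new_tokens // self.d):\n",
    "            # Crop the context from the left so that it holds at most block_size - d tokens:\n",
    "            if idx.size(1) <= self.block_size - self.d:\n",
    "                idx_cond = idx\n",
    "            else:\n",
    "                idx_cond = idx[:, -(self.block_size - self.d):]\n",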
    "            \n",
    "            idxs_next, _ = self(idx_cond)\n",
    "            idx = torch.cat((idx, idxs_next), dim=1)\n",
    "\n",
    "        return idx\n",
    "\n",
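    "    def _head_forward_pred(self, x):\n",
    "        # Mixture weights over the r low-rank components; keep only the last position.\n",
    "        w = self.lm_head_weight(x).reshape(-1, self.r)\n",
    "        w = nn.functional.softmax(w, dim=1)[-1, :]\n",
    "        # Greedy choice of the component (sampling alternative kept for reference).\n",
    "        # item_next = torch.multinomial(w, num_samples=1)[0]\n",
    "        item_next = torch.argmax(w)\n",
    "\n",
    "        # Token distributions of the selected component for the two next positions.\n",
    "        log_U, log_V = self._build_lr_factors(x)\n",
    "        U_curr = nn.functional.softmax(log_U[:, item_next], dim=0)\n",
    "        V_curr = nn.functional.softmax(log_V[:, item_next], dim=0)\n",
    "\n",
    "        # Greedy token choice (sampling alternative kept for reference).\n",
    "        # idx_next1 = torch.multinomial(U_curr, num_samples=1)\n",
    "        # idx_next2 = torch.multinomial(V_curr, num_samples=1)\n",
    "        idx_next1 = torch.argmax(U_curr).unsqueeze(0)\n",
    "        idx_next2 = torch.argmax(V_curr).unsqueeze(0)\n",
    "\n",
    "        # Shape (1, 2): the two predicted tokens, appended to the sequence by generate.\n",
    "        pred = torch.stack([idx_next1, idx_next2], dim=1)\n",
    "        return pred"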
   ]
  },
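  {
   "cell_type": "markdown",
   "id": "c6f2e8a1-7b34-4d09-9e5c-2a1b8d7f3e05",
   "metadata": {},
   "source": [
    "The `_head_forward_pred` override above decodes greedily in two stages: it picks the most probable mixture component (`argmax` of `w`) and then, within that component, the most probable token from each factor. Assuming the head models the next two tokens as a mixture $P(y_1, y_2 \\mid x) = \\sum_k w_k \\, U_k(y_1) \\, V_k(y_2)$, this pair need not be the most likely pair under $P$, whereas sampling a component and then one token from each factor (the commented-out `torch.multinomial` lines) draws exactly from $P$. The toy cell below sketches both on random tensors; the sizes, the seed and the helper `sample_pairs` are hypothetical and do not come from `model_gpt_lr2`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1b7d9e3c-5a62-4f18-a0c4-9d3e6b2f7a81",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "torch.manual_seed(0)\n",
    "r, vocab = 2, 5  # toy rank and vocabulary size (assumptions)\n",
    "w = torch.softmax(torch.randn(r, dtype=torch.float64), dim=0)\n",
    "U = torch.softmax(torch.randn(vocab, r, dtype=torch.float64), dim=0)\n",
    "V = torch.softmax(torch.randn(vocab, r, dtype=torch.float64), dim=0)\n",
    "P = (U * w) @ V.T  # joint P[i, j] = sum_k w[k] * U[i, k] * V[j, k]\n",
    "\n",
    "# Two-stage greedy rule, as in _head_forward_pred: best component, then best token per factor.\n",
    "k = torch.argmax(w)\n",
    "greedy = (torch.argmax(U[:, k]).item(), torch.argmax(V[:, k]).item())\n",
    "\n",
    "# Exact most likely pair, found by scanning the full vocab x vocab table (toy sizes only).\n",
    "mode = divmod(torch.argmax(P).item(), vocab)\n",
    "\n",
    "\n",
    "def sample_pairs(n):\n",
    "    # Hypothetical helper: ancestral sampling, k ~ w, then y1 ~ U[:, k] and y2 ~ V[:, k].\n",
    "    ks = torch.multinomial(w, n, replacement=True)\n",
    "    y1 = torch.multinomial(U.T[ks], 1).squeeze(1)\n",
    "    y2 = torch.multinomial(V.T[ks], 1).squeeze(1)\n",
    "    return y1, y2\n",
    "\n",
    "\n",
    "y1, y2 = sample_pairs(200_000)\n",
    "counts = torch.bincount(y1 * vocab + y2, minlength=vocab * vocab)\n",
    "freq = counts.reshape(vocab, vocab).double() / y1.numel()\n",
    "\n",
    "print('greedy pair:', greedy, 'probability:', round(P[greedy].item(), 4))\n",
    "print('joint mode: ', mode, 'probability:', round(P[mode].item(), 4))\n",
    "print('max |empirical - P|:', (freq - P).abs().max().item())"
   ]
  },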
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "494671dc-e2bb-4557-b812-6a825a8c278c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "\n",
      "----\n",
      "Demo      for generate method\n",
      "Input  : There was a\n",
      "Output : There was a little girl named She was to play with her toys. One day, she's mommy and her a were in the park. She\n",
      "\n",
      "\"Mom, can I play a game?\" she asked.\n",
      "\n",
      "\"No, it don't want to play with me.\n",
      "\n",
      "Her mommy said, \"No, it's not nice. It's not. It's not a toy. It's just a toy. It can't you. it is it's. You can have it.. It is not yours.\"\n",
      "\n",
      "Lily was sad. She wanted to play with her.. She\n",
      "\n",
      "\"No, it, it, it. I is not. It is a. It is a. It is a toy. It is a toy. It is a. It it is It is a toy. It is a. It\n",
      "\n",
      "Lily was scared. She wanted to touch it. She\n",
      "\n",
      "\"No, it. I it. It is mine\n",
      "\n"
     ]
    }
   ],
   "source": [
    "args = config()\n",
    "args.name = 'demo_lr2_rank2'\n",
    "args.mode = 'lr2'\n",
    "args.d = 2\n",
    "args.r = 2\n",
    "\n",
    "manager = Lllmtelora(args, dev=True)\n",
    "manager.set_data()\n",
    "manager.load(model_class=ModelGPTLR2Dev)\n",
    "manager.set_trainer()\n",
    "\n",
    "generate(manager)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
