{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "4b953a58",
   "metadata": {},
   "source": [
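    "## Set random seeds\n",
    "\n",
    "Seed Python, NumPy and PyTorch (CPU and CUDA) and make cuDNN deterministic so runs are reproducible."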
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2042d3f7",
   "metadata": {},
   "outputs": [],
   "source": [
    "import random\n",
    "\n",
    "import numpy as np\n",
    "import torch\n",
    "\n",
    "\n",
    "def set_seed(seed=72):\n",
    "    \"\"\"Seed every RNG used in this notebook and pin cuDNN to deterministic kernels.\n",
    "\n",
    "    Args:\n",
    "        seed: integer seed applied to Python's `random`, NumPy's global RNG,\n",
    "            and PyTorch's CPU and CUDA generators.\n",
    "    \"\"\"\n",
    "    # Configure cuDNN first: deterministic kernels, no benchmark-mode algorithm search.\n",
    "    cudnn = torch.backends.cudnn\n",
    "    cudnn.deterministic = True\n",
    "    cudnn.benchmark = False\n",
    "\n",
    "    seeders = (\n",
    "        random.seed,\n",
    "        np.random.seed,\n",
    "        torch.manual_seed,\n",
    "        torch.cuda.manual_seed,\n",
    "        torch.cuda.manual_seed_all,\n",
    "    )\n",
    "    for seeder in seeders:\n",
    "        seeder(seed)\n",
    "\n",
    "\n",
    "set_seed(108)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "da7ae922",
   "metadata": {},
   "source": [
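    "## Initialize the model and generate the induction dataset\n",
    "\n",
    "Each sample is a random sequence of `SEQ_LEN + 1` tokens in which a random pattern of `PATTERN_LEN` tokens is written once at a random position and again at the end. The model sees the first `SEQ_LEN` tokens and must predict the last one, i.e. complete the repeated pattern."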
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "10531d75",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "from torch.utils.data import DataLoader\n",
    "import math\n",
    "\n",
    "# --- Configuration ---\n",
    "SEQ_LEN = 20                      # tokens fed to the model per sample\n",
    "VOCAB_SIZE = 10                   # token ids are drawn from [0, VOCAB_SIZE)\n",
    "BATCH_SIZE = 256\n",
    "NUM_SAMPLES = 10000               # NOTE(review): not referenced anywhere else in this notebook\n",
    "\n",
    "PATTERN_LEN = 4                   # length of the pattern repeated in each sample\n",
    "FINETUNE_NUM_SAMPLES = 1_500_000  # number of generated induction samples\n",
    "\n",
    "\n",
    "class TransformerBlock(nn.Module):\n",
    "    \"\"\"Post-LayerNorm Transformer encoder block: self-attention followed by an MLP.\n",
    "\n",
    "    Args:\n",
    "        dim: model width; must be divisible by `heads`.\n",
    "        heads: number of attention heads.\n",
    "        mlp_ratio: MLP hidden width as a multiple of `dim`.\n",
    "        dropout: dropout probability used in the attention and the MLP.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, dim, heads=4, mlp_ratio=4.0, dropout=0.0):\n",
    "        super().__init__()\n",
    "        self.attn = nn.MultiheadAttention(embed_dim=dim, num_heads=heads, dropout=dropout, batch_first=True)\n",
    "        self.norm1 = nn.LayerNorm(dim)\n",
    "        self.norm2 = nn.LayerNorm(dim)\n",
    "\n",
    "        hidden_dim = int(dim * mlp_ratio)\n",
    "        self.mlp = nn.Sequential(\n",
    "            nn.Linear(dim, hidden_dim),\n",
    "            nn.GELU(),\n",
    "            nn.Dropout(dropout),\n",
    "            nn.Linear(hidden_dim, dim),\n",
    "            nn.Dropout(dropout)\n",
    "        )\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Apply the block to x of shape (batch, seq_len, dim); returns the same shape.\"\"\"\n",
    "        # Unmasked self-attention. The averaged attention weights are discarded, so skip computing them.\n",
    "        attn_out, _ = self.attn(x, x, x, need_weights=False)\n",
    "        x = self.norm1(x + attn_out)\n",
    "        # Residual + LayerNorm around the MLP as well. norm2 used to be created but never\n",
    "        # applied, leaving the MLP branch un-normalized and norm2's parameters untrained.\n",
    "        x = self.norm2(x + self.mlp(x))\n",
    "        return x\n",
    "\n",
    "class model(nn.Module):\n",
    "    \"\"\"Predict the final token of a sequence from its first `seq_len` tokens.\n",
    "\n",
    "    Each token is one-hot encoded and concatenated with a one-hot position code,\n",
    "    linearly embedded, and passed through a stack of TransformerBlocks. The result is\n",
    "    projected to vocabulary logits at every position and summed over positions.\n",
    "\n",
    "    Args:\n",
    "        vocab_size: number of token ids; inputs must lie in [0, vocab_size).\n",
    "        seq_len: length every input sequence must have.\n",
    "        hidden_dim: model width; must be divisible by 4 (the blocks' default head count).\n",
    "        num_layers: number of TransformerBlocks (default 6).\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, vocab_size=VOCAB_SIZE, seq_len=SEQ_LEN, hidden_dim=400, num_layers=6):\n",
    "        super().__init__()\n",
    "        assert hidden_dim % 4 == 0, \"hidden_dim must be divisible by 4\"\n",
    "        # Keep the sizes on the instance so forward() does not read the notebook globals.\n",
    "        self.vocab_size = vocab_size\n",
    "        self.seq_len = seq_len\n",
    "        self.embedding = nn.Sequential(\n",
    "            nn.Linear(vocab_size + seq_len, hidden_dim)\n",
    "        )\n",
    "        self.layers = nn.ModuleList([\n",
    "            TransformerBlock(dim=hidden_dim)\n",
    "            for _ in range(num_layers)\n",
    "        ])\n",
    "        self.fcout = nn.Sequential(\n",
    "            nn.Linear(hidden_dim, hidden_dim),\n",
    "            nn.GELU(),\n",
    "            nn.Linear(hidden_dim, vocab_size)\n",
    "        )\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Map token ids x of shape (batch, seq_len) to logits of shape (batch, vocab_size).\"\"\"\n",
    "        tokens = torch.nn.functional.one_hot(x, num_classes=self.vocab_size).float()\n",
    "        # One-hot position code, shared across the batch via expand().\n",
    "        positions = torch.eye(self.seq_len, device=x.device).unsqueeze(0).expand(x.size(0), -1, -1)\n",
    "        h = self.embedding(torch.cat([tokens, positions], dim=-1))\n",
    "        for layer in self.layers:\n",
    "            # Extra residual around each block, on top of the block's internal residuals.\n",
    "            h = layer(h) + h\n",
    "        # Per-position logits are summed over the sequence into one prediction per sample.\n",
    "        return self.fcout(h).sum(dim=1)\n",
    "\n",
    "def init_weights(m):\n",
    "    \"\"\"Re-initialize `m` in place if it is an nn.Linear; any other module is left as is.\n",
    "\n",
    "    Weights are drawn from N(mean=0.01, std=0.02) and biases (when present) are zeroed.\n",
    "    \"\"\"\n",
    "    if not isinstance(m, nn.Linear):\n",
    "        return\n",
    "    # NOTE(review): a non-zero mean (0.01) shifts every weight positive; 0.0 is the usual\n",
    "    # choice for std=0.02 init. Confirm this is intentional.\n",
    "    nn.init.normal_(m.weight, mean=0.01, std=0.02)\n",
    "    if m.bias is not None:\n",
    "        nn.init.zeros_(m.bias)\n",
    "\n",
    "\n",
    "mymodel = model().to('cuda')\n",
    "# apply() calls init_weights on the model and, recursively, on every submodule.\n",
    "mymodel.apply(init_weights)\n",
    "\n",
    "\n",
    "\n",
    "finetune_samples = []\n",
    "\n",
    "# The earlier copy of the pattern must end before the final copy, which fills the last\n",
    "# PATTERN_LEN of the SEQ_LEN + 1 positions. If it starts any later, the final copy\n",
    "# overwrites it and the target can no longer be inferred from the context.\n",
    "MAX_PATTERN_START = SEQ_LEN + 1 - 2 * PATTERN_LEN\n",
    "assert MAX_PATTERN_START >= 0, \"SEQ_LEN + 1 must fit two non-overlapping patterns\"\n",
    "\n",
    "for _ in range(FINETUNE_NUM_SAMPLES):\n",
    "    pattern = torch.randint(0, VOCAB_SIZE, size=(PATTERN_LEN,))\n",
    "    # randint's upper bound is exclusive, so starts are drawn from [0, MAX_PATTERN_START].\n",
    "    starting_place = torch.randint(0, MAX_PATTERN_START + 1, size=(1,)).item()\n",
    "    sequence = torch.randint(0, VOCAB_SIZE, size=(SEQ_LEN + 1,))\n",
    "    sequence[starting_place:starting_place + PATTERN_LEN] = pattern\n",
    "    sequence[-PATTERN_LEN:] = pattern\n",
    "    # Input: the first SEQ_LEN tokens; target: the last token (the pattern's final element).\n",
    "    finetune_samples.append((sequence[:-1], sequence[-1]))\n",
    "\n",
    "class InductionDataset(torch.utils.data.Dataset):\n",
    "    \"\"\"In-memory dataset of (input_sequence, target_token) pairs for the induction task.\n",
    "\n",
    "    Args:\n",
    "        num_samples, seq_len, vocab_size: accepted for API compatibility but unused;\n",
    "            samples are supplied, not generated here.\n",
    "        data: list of (x, y) tensor pairs. When omitted, the notebook-level\n",
    "            `finetune_samples` list is used (the previous behavior).\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, num_samples=FINETUNE_NUM_SAMPLES, seq_len=SEQ_LEN, vocab_size=VOCAB_SIZE, data=None):\n",
    "        # Only the selected branch is evaluated, so the global is not needed when data is given.\n",
    "        self.data = finetune_samples if data is None else data\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.data)\n",
    "\n",
    "    def __getitem__(self, index):\n",
    "        x, y = self.data[index]\n",
    "        return x, y\n",
    "\n",
    "    def save(self, filepath):\n",
    "        \"\"\"Save the dataset to `filepath` as stacked 'data' and 'targets' tensors.\"\"\"\n",
    "        data_tensor = torch.stack([x for x, _ in self.data])\n",
    "        targets_tensor = torch.stack([y for _, y in self.data])\n",
    "        torch.save({\n",
    "            'data': data_tensor,\n",
    "            'targets': targets_tensor\n",
    "        }, filepath)\n",
    "\n",
    "    @staticmethod\n",
    "    def load(filepath):\n",
    "        \"\"\"Load a dataset written by `save`.\n",
    "\n",
    "        The file holds only tensors, so it is read with weights_only=True, which refuses\n",
    "        to unpickle arbitrary objects.\n",
    "        \"\"\"\n",
    "        loaded = torch.load(filepath, weights_only=True)\n",
    "        # Iterating a tensor yields its rows, restoring the list-of-(x, y) format.\n",
    "        return InductionDataset(data=list(zip(loaded['data'], loaded['targets'])))\n",
    "\n",
    "\n",
    "DATASET_PATH = 'inductionDataset_1.pt'\n",
    "\n",
    "# Write the generated samples to disk and continue from the reloaded copy, so the\n",
    "# rest of the notebook trains on exactly what was saved.\n",
    "InductionDataset(num_samples=FINETUNE_NUM_SAMPLES).save(DATASET_PATH)\n",
    "inductionDataset = InductionDataset.load(DATASET_PATH)\n",
    "\n",
    "# 80% train / 20% val; any remainder from int() truncation becomes the (unused) test split.\n",
    "total_samples = len(inductionDataset)\n",
    "train_size, val_size = (int(fraction * total_samples) for fraction in (0.8, 0.2))\n",
    "test_size = total_samples - (train_size + val_size)\n",
    "\n",
    "train_dataset, val_dataset, test_dataset = torch.utils.data.random_split(\n",
    "    inductionDataset, [train_size, val_size, test_size]\n",
    ")\n",
    "\n",
    "inductiondataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)\n",
    "val_dataloader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e516d3a9",
   "metadata": {},
   "source": [
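    "## Choose trainable layers\n",
    "\n",
    "Freeze the model except for the selected Transformer blocks, the input embedding and the output head."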
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b93af2a5",
   "metadata": {},
   "outputs": [],
   "source": [
    "def set_trainable_layers(model, trainable_layer_ids, train_embedding=True, train_head=True):\n",
    "    \"\"\"Freeze all parameters except the chosen Transformer blocks and, by default,\n",
    "    the input embedding and output head.\n",
    "\n",
    "    Args:\n",
    "        model: module exposing `layers` (indexable blocks), `embedding` and `fcout`.\n",
    "        trainable_layer_ids: 0-based block indices to train, e.g. [0, 2] trains only\n",
    "            model.layers[0] and model.layers[2] among the blocks.\n",
    "        train_embedding: also train model.embedding (default True, as before).\n",
    "        train_head: also train the output head model.fcout (default True, as before).\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if any id is not a valid index into model.layers. Such ids used\n",
    "            to be ignored silently, which hid typos.\n",
    "    \"\"\"\n",
    "    trainable_ids = set(trainable_layer_ids)\n",
    "    valid_ids = range(len(model.layers))\n",
    "    invalid = sorted(i for i in trainable_ids if i not in valid_ids)\n",
    "    if invalid:\n",
    "        raise ValueError(f\"trainable_layer_ids {invalid} are out of range for {len(model.layers)} layers\")\n",
    "\n",
    "    # Freeze everything, then re-enable only the selected parts.\n",
    "    model.requires_grad_(False)\n",
    "    for i, layer in enumerate(model.layers):\n",
    "        if i in trainable_ids:\n",
    "            layer.requires_grad_(True)\n",
    "    if train_head:\n",
    "        model.fcout.requires_grad_(True)\n",
    "    if train_embedding:\n",
    "        model.embedding.requires_grad_(True)\n",
    "\n",
    "\n",
    "# Fine-tune the 1st and 4th Transformer blocks (0-based ids 0 and 3), plus the embedding and head.\n",
    "set_trainable_layers(mymodel, trainable_layer_ids=[0,3])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cc23c581",
   "metadata": {},
   "source": [
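    "## Training\n",
    "\n",
    "Fine-tune the unfrozen parameters with AdamW, evaluate on the validation split after every epoch, and checkpoint the model whenever validation accuracy improves."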
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "71b33bf3",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "import tqdm\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "\n",
    "FINETUNE_EPOCHS = 15\n",
    "\n",
    "\n",
    "def _evaluate(net, loader, loss_fn):\n",
    "    \"\"\"Return (mean per-sample loss, accuracy) of `net` over `loader`, with gradients disabled.\"\"\"\n",
    "    net.eval()\n",
    "    loss_sum, n_correct, n_seen = 0.0, 0, 0\n",
    "    with torch.no_grad():\n",
    "        for inputs, targets in loader:\n",
    "            inputs, targets = inputs.to('cuda'), targets.to('cuda')\n",
    "            scores = net(inputs)  # (batch, vocab_size)\n",
    "            loss_sum += loss_fn(scores, targets).item() * inputs.size(0)\n",
    "            n_correct += (scores.argmax(dim=-1) == targets).sum().item()\n",
    "            n_seen += targets.size(0)\n",
    "    return loss_sum / n_seen, n_correct / n_seen\n",
    "\n",
    "\n",
    "# Only the parameters left trainable by set_trainable_layers are passed to the optimizer.\n",
    "optimizer = torch.optim.AdamW(\n",
    "    [p for p in mymodel.parameters() if p.requires_grad],\n",
    "    lr=1e-4\n",
    ")\n",
    "criterion = nn.CrossEntropyLoss()\n",
    "\n",
    "best_acc = 0.0  # best validation accuracy so far; a checkpoint is written when it improves\n",
    "best_loss = float(\"inf\")\n",
    "best_model_path = \"best_model.pt\"\n",
    "\n",
    "losses = []  # [epoch, train loss], every 100th training step\n",
    "accs = []    # [epoch, validation accuracy], once per epoch\n",
    "\n",
    "for epoch in range(FINETUNE_EPOCHS):\n",
    "    # --- Training ---\n",
    "    mymodel.train()\n",
    "    for step, (x, y) in enumerate(inductiondataloader):\n",
    "        x, y = x.to('cuda'), y.to('cuda')\n",
    "        loss = criterion(mymodel(x), y)  # model output: (batch, vocab_size)\n",
    "\n",
    "        optimizer.zero_grad()\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "        if step % 100 == 0:\n",
    "            step_loss = loss.item()\n",
    "            tqdm.tqdm.write(f\"[Train] Epoch {epoch}, Loss: {step_loss:.4f}\")\n",
    "            losses.append([epoch, step_loss])\n",
    "\n",
    "    # --- Evaluation ---\n",
    "    avg_loss, acc = _evaluate(mymodel, val_dataloader, criterion)\n",
    "    tqdm.tqdm.write(f\"[Eval] Epoch {epoch}, Loss: {avg_loss:.4f}, Acc: {acc:.4f}\")\n",
    "    accs.append([epoch, acc])\n",
    "\n",
    "    # --- Save best model ---\n",
    "    if acc > best_acc:\n",
    "        best_acc, best_loss = acc, avg_loss\n",
    "        checkpoint = {\n",
    "            \"epoch\": epoch,\n",
    "            \"model_state_dict\": mymodel.state_dict(),\n",
    "            \"optimizer_state_dict\": optimizer.state_dict(),\n",
    "            \"loss\": avg_loss,\n",
    "            \"accuracy\": acc,\n",
    "        }\n",
    "        torch.save(checkpoint, best_model_path)\n",
    "        tqdm.tqdm.write(f\"--> Saved new best model at epoch {epoch} with Acc: {acc:.4f}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7623d894",
   "metadata": {},
   "source": [
    "## Save and plot the training loss and validation accuracy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "id": "fc28ab98",
   "metadata": {},
   "outputs": [],
   "source": [
    "import collections\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# --- Save logs after training ---\n",
    "torch.save(losses, \"train_losses_4_6.pt\")\n",
    "torch.save(accs, \"val_accs_4_6.pt\")\n",
    "\n",
    "\n",
    "def _spread_within_epochs(records):\n",
    "    \"\"\"Split [epoch, value] records into x positions and values for plotting.\n",
    "\n",
    "    Records sharing an epoch are spaced evenly over [epoch, epoch + 1) in logging order,\n",
    "    so several points logged in one epoch no longer stack on a single x value.\n",
    "    \"\"\"\n",
    "    per_epoch = collections.Counter(epoch for epoch, _ in records)\n",
    "    seen = collections.Counter()\n",
    "    xs, values = [], []\n",
    "    for epoch, value in records:\n",
    "        xs.append(epoch + seen[epoch] / per_epoch[epoch])\n",
    "        seen[epoch] += 1\n",
    "        values.append(value)\n",
    "    return xs, values\n",
    "\n",
    "\n",
    "def _plot_curve(xs, ys, label, ylabel, title, path):\n",
    "    \"\"\"Draw one labelled line against epoch, save it to `path`, and close the figure.\"\"\"\n",
    "    fig, ax = plt.subplots(figsize=(8, 5))\n",
    "    ax.plot(xs, ys, label=label)\n",
    "    ax.set(xlabel=\"Epoch\", ylabel=ylabel, title=title)\n",
    "    ax.legend()\n",
    "    ax.grid(True)\n",
    "    fig.savefig(path)\n",
    "    plt.close(fig)\n",
    "\n",
    "\n",
    "# --- Plot Training Loss (logged every 100 steps, so there are many points per epoch) ---\n",
    "_plot_curve(*_spread_within_epochs(losses), \"Train Loss\", \"Loss\", \"Training Loss\", \"train_loss.png\")\n",
    "\n",
    "# --- Plot Validation Accuracy (one point per epoch, so x is the epoch number) ---\n",
    "_plot_curve(*_spread_within_epochs(accs), \"Val Accuracy\", \"Accuracy\", \"Validation Accuracy\", \"val_acc.png\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
