{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "HHqA2AiZPsUS"
      },
      "outputs": [],
      "source": [
        "import torch\n",
        "import torch.nn as nn\n",
        "from torch.utils.data import Dataset, DataLoader\n",
        "import torch.optim as optim\n",
        "import matplotlib.pyplot as plt\n",
        "from scipy.stats import linregress\n",
        "import torch.optim.lr_scheduler as lr_scheduler\n",
        "\n",
        "# Seed torch's global RNG (CPU and CUDA) so dataset sampling, weight\n",
        "# initialisation and DataLoader shuffling are reproducible on Restart & Run All.\n",
        "SEED = 0\n",
        "torch.manual_seed(SEED)\n",
        "\n",
        "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "5IkReBAVP61t"
      },
      "outputs": [],
      "source": [
        "def create_causal_mask(seq_len):\n",
        "    \"\"\"Return a causal mask of shape (1, seq_len, seq_len).\n",
        "\n",
        "    The mask is lower-triangular ones: entry [0, i, j] is 1 when query position i\n",
        "    may attend to key position j (j <= i) and 0 otherwise. The leading singleton\n",
        "    dimension broadcasts over a batch of (B, seq_len, seq_len) attention scores.\n",
        "    \"\"\"\n",
        "    return torch.tril(torch.ones(seq_len, seq_len)).unsqueeze(0)\n",
        "\n",
        "\n",
        "class TwoLayerAttention(nn.Module):\n",
        "    \"\"\"Two causal single-head attention layers over 2-d tokens [value, tag].\n",
        "\n",
        "    A learned [SEP] embedding is inserted into every sequence at a per-example\n",
        "    index, and the prediction is the layer-2 output at the last position. In this\n",
        "    notebook it is trained to regress the mean of the values that follow [SEP].\n",
        "    \"\"\"\n",
        "\n",
        "    def __init__(self):\n",
        "        super().__init__()\n",
        "        # Layer 1 uses the raw embeddings as queries and keys; only values are projected.\n",
        "        self.Wv1 = nn.Linear(2, 2, bias=False)\n",
        "        # Layer 2 scores are H_1 @ Wqk2.weight @ H_1^T: one matrix stands in for W_Q W_K^T.\n",
        "        self.Wqk2 = nn.Linear(2, 2, bias=False)\n",
        "        # Layer 2 values are projected straight to the scalar output.\n",
        "        self.Wv2 = nn.Linear(2, 1, bias=False)\n",
        "        # Learned [SEP] embedding, initialised to [10, -10].\n",
        "        self.sep_embedding = nn.Parameter(torch.tensor([[10.0, -10.0]]))\n",
        "\n",
        "        # Start training from element-wise non-negative weights.\n",
        "        with torch.no_grad():\n",
        "            for layer in (self.Wv1, self.Wqk2, self.Wv2):\n",
        "                layer.weight.abs_()\n",
        "\n",
        "    def _insert_sep(self, x, indices):\n",
        "        \"\"\"Insert the [SEP] embedding into every sequence of the batch.\n",
        "\n",
        "        Args:\n",
        "            x: (B, S, D) token embeddings.\n",
        "            indices: (B,) integer insertion positions in [0, S].\n",
        "\n",
        "        Returns:\n",
        "            (B, S + 1, D) tensor equal to cat([x[b, :i], sep, x[b, i:]]) for each b,\n",
        "            built with one gather instead of a Python loop over the batch.\n",
        "        \"\"\"\n",
        "        B, S, D = x.shape\n",
        "        if S == 0:\n",
        "            return self.sep_embedding.expand(B, 1, D)\n",
        "        idx = indices.to(x.device).long().reshape(B, 1)\n",
        "        pos = torch.arange(S + 1, device=x.device).reshape(1, S + 1)\n",
        "        # Output position p copies x[p] before the insertion point and x[p - 1] after it.\n",
        "        # The value read at p == idx is discarded, so clamp it into range (idx may equal S).\n",
        "        src = (pos - (pos > idx).long()).clamp(max=S - 1)\n",
        "        gathered = torch.gather(x, 1, src.unsqueeze(-1).expand(B, S + 1, D))\n",
        "        is_sep = (pos == idx).unsqueeze(-1)  # (B, S + 1, 1)\n",
        "        return torch.where(is_sep, self.sep_embedding, gathered)\n",
        "\n",
        "    def forward(self, x, indices):\n",
        "        \"\"\"Run both attention layers.\n",
        "\n",
        "        Args:\n",
        "            x: (B, S, 2) token embeddings.\n",
        "            indices: (B,) position at which [SEP] is inserted in each sequence.\n",
        "\n",
        "        Returns:\n",
        "            (prediction, Z_1, H_1, attn_weights_1, attn_weights_2): prediction is the\n",
        "            (B,) layer-2 output at the last position; Z_1 (layer-1 attention output)\n",
        "            and H_1 (Z_1 plus residual) are (B, S + 1, 2); both attention maps are\n",
        "            (B, S + 1, S + 1).\n",
        "        \"\"\"\n",
        "        x = self._insert_sep(x, indices)  # (B, S + 1, 2)\n",
        "        seq_len = x.shape[1]\n",
        "        causal_mask = create_causal_mask(seq_len).to(x.device)\n",
        "\n",
        "        # Layer 1: scores from the raw embeddings, values via Wv1, then a residual add.\n",
        "        V_1 = self.Wv1(x)\n",
        "        attn_scores_1 = torch.matmul(x, x.transpose(-2, -1))\n",
        "        attn_scores_1 = attn_scores_1.masked_fill(causal_mask == 0, float('-inf'))\n",
        "        attn_weights_1 = torch.softmax(attn_scores_1, dim=-1)  # (B, S + 1, S + 1)\n",
        "        Z_1 = torch.matmul(attn_weights_1, V_1)\n",
        "        H_1 = Z_1 + x\n",
        "\n",
        "        # Layer 2: bilinear scores through Wqk2, values projected to one dim, no residual.\n",
        "        QK_2 = self.Wqk2(H_1)\n",
        "        V_2 = self.Wv2(H_1)\n",
        "        attn_scores_2 = torch.matmul(H_1, QK_2.transpose(-2, -1))\n",
        "        attn_scores_2 = attn_scores_2.masked_fill(causal_mask == 0, float('-inf'))\n",
        "        attn_weights_2 = torch.softmax(attn_scores_2, dim=-1)  # (B, S + 1, S + 1)\n",
        "        H_2 = torch.matmul(attn_weights_2, V_2).squeeze(-1)  # (B, S + 1)\n",
        "\n",
        "        return H_2[:, -1], Z_1, H_1, attn_weights_1, attn_weights_2"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "kOgMjjLMbmks"
      },
      "outputs": [],
      "source": [
        "def generate_dataset(num_samples, seq_len, verbose=True):\n",
        "    \"\"\"Sample digit sequences whose tail averages are spread roughly evenly over bins.\n",
        "\n",
        "    Each candidate is a sequence of seq_len - 1 integers drawn uniformly from 0..9\n",
        "    plus a split index idx in [0, seq_len - 2]; its target is seq[idx:].mean().\n",
        "    Candidates are assigned to the nearest of the bins 2, 3, ..., 7 and kept only\n",
        "    while that bin holds fewer than num_samples // 6 + 1 samples (rejection sampling).\n",
        "\n",
        "    Args:\n",
        "        num_samples: number of samples to return.\n",
        "        seq_len: length of a sequence once the [SEP] token is inserted, so each\n",
        "            generated sequence holds seq_len - 1 numbers.\n",
        "        verbose: if True, print the final per-bin counts once.\n",
        "\n",
        "    Returns:\n",
        "        sequences: (num_samples, seq_len - 1) float tensor.\n",
        "        random_indices: (num_samples,) int64 tensor of split indices.\n",
        "        averages: (num_samples,) float tensor of tail averages.\n",
        "    \"\"\"\n",
        "    target_avgs = torch.arange(2, 8).float()  # bin centres 2, 3, ..., 7\n",
        "    bin_capacity = num_samples // len(target_avgs) + 1\n",
        "    bin_counts = [0] * len(target_avgs)\n",
        "\n",
        "    sequences, random_indices, averages = [], [], []\n",
        "    # The bins hold more than num_samples in total, so some bin always has room.\n",
        "    while len(sequences) < num_samples:\n",
        "        seq = torch.randint(0, 10, (seq_len - 1,)).float()\n",
        "        idx = torch.randint(0, seq_len - 1, (1,)).item()\n",
        "        avg = seq[idx:].mean()\n",
        "\n",
        "        closest_bin = int(torch.argmin((target_avgs - avg).abs()))\n",
        "        if bin_counts[closest_bin] < bin_capacity:\n",
        "            sequences.append(seq)\n",
        "            random_indices.append(idx)\n",
        "            averages.append(avg)\n",
        "            bin_counts[closest_bin] += 1\n",
        "\n",
        "    # Report once at the end; printing per candidate would flood the notebook output.\n",
        "    if verbose:\n",
        "        print(\"Bin counts:\", torch.tensor(bin_counts))\n",
        "\n",
        "    if not sequences:\n",
        "        return torch.empty(0, seq_len - 1), torch.empty(0, dtype=torch.long), torch.empty(0)\n",
        "    return torch.stack(sequences), torch.tensor(random_indices), torch.stack(averages)\n",
        "\n",
        "\n",
        "class AveragingDataset(Dataset):\n",
        "    \"\"\"Map-style dataset of (embeddings, sep_idx, avg) triples indexed along dim 0.\n",
        "\n",
        "    Shapes as used in this notebook:\n",
        "        embeddings: (N, S, 2) token embeddings [value, tag].\n",
        "        sep_idx: (N,) position at which the [SEP] token is inserted.\n",
        "        avg: (N,) regression target.\n",
        "\n",
        "    Raises:\n",
        "        ValueError: if the three inputs do not have the same length N.\n",
        "    \"\"\"\n",
        "\n",
        "    def __init__(self, embeddings, sep_idx, avg):\n",
        "        if not (len(embeddings) == len(sep_idx) == len(avg)):\n",
        "            raise ValueError(\n",
        "                f\"length mismatch: embeddings={len(embeddings)}, \"\n",
        "                f\"sep_idx={len(sep_idx)}, avg={len(avg)}\"\n",
        "            )\n",
        "        self.embeddings = embeddings  # (N, S, 2)\n",
        "        self.sep_idx = sep_idx  # (N,)\n",
        "        self.avg = avg  # (N,)\n",
        "\n",
        "    def __len__(self):\n",
        "        return len(self.embeddings)\n",
        "\n",
        "    def __getitem__(self, idx):\n",
        "        return self.embeddings[idx], self.sep_idx[idx], self.avg[idx]"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "DbItE4gwUXAf",
        "outputId": "e44c773d-b73c-46de-a0fd-c34acf20d1cc"
      },
      "outputs": [],
      "source": [
        "num_samples = 8192\n",
        "seq_len = 16  # includes the [SEP] token the model inserts; raw sequences hold seq_len - 1 numbers\n",
        "batch_size = 1024\n",
        "\n",
        "for phase in ['Train', 'Test']:\n",
        "    numbers, rand_idx, avg = generate_dataset(num_samples, seq_len)  # (N, seq_len - 1), (N,), (N,)\n",
        "    # Build 2-d token embeddings [value, tag]; every number token carries tag -1.\n",
        "    tags = torch.full_like(numbers, -1.0)\n",
        "    sequences = torch.stack([numbers, tags], dim=-1).to(device)  # (N, seq_len - 1, 2)\n",
        "    rand_idx = rand_idx.to(device)\n",
        "    avg = avg.to(device)\n",
        "\n",
        "    dataset = AveragingDataset(sequences, rand_idx, avg)\n",
        "    # Shuffle both splits: generate_dataset's per-bin rejection makes later samples skew\n",
        "    # toward the bins that fill last, and later cells inspect only the first batch.\n",
        "    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)\n",
        "    if phase == 'Train':\n",
        "        dataloader = loader\n",
        "    else:\n",
        "        test_dataloader = loader"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "brBYYXdqdzwL",
        "outputId": "0dfa763e-260f-452b-fe9e-8a46a629b5f8"
      },
      "outputs": [],
      "source": [
        "model = TwoLayerAttention().to(device)\n",
        "criterion = nn.MSELoss()\n",
        "optimizer = optim.AdamW(model.parameters(), lr=5e-2, weight_decay=0.001)\n",
        "\n",
        "num_epochs = 50\n",
        "# One cosine period over the whole run; the scheduler is stepped once per epoch.\n",
        "scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)\n",
        "\n",
        "losses = []  # training loss of every iteration\n",
        "sep = []  # [SEP] embedding snapshot after every iteration\n",
        "model.train()\n",
        "for epoch in range(num_epochs):\n",
        "    # Read the LR before scheduler.step() so the log shows the LR used in this epoch.\n",
        "    epoch_lr = scheduler.get_last_lr()[0]\n",
        "    epoch_loss_sum, epoch_count = 0.0, 0\n",
        "    for sequences, sep_idx, target_avg in dataloader:\n",
        "        optimizer.zero_grad()\n",
        "        outputs, _, _, _, _ = model(sequences, sep_idx)\n",
        "        loss = criterion(outputs, target_avg)\n",
        "        loss.backward()\n",
        "        optimizer.step()\n",
        "        losses.append(loss.item())\n",
        "        sep.append(model.sep_embedding.clone().detach().cpu())\n",
        "        # Weight by batch size so a smaller final batch does not skew the epoch mean.\n",
        "        epoch_loss_sum += loss.item() * len(target_avg)\n",
        "        epoch_count += len(target_avg)\n",
        "\n",
        "    scheduler.step()\n",
        "    print(f\"Epoch {epoch+1}, LR: {epoch_lr:.5f}, Mean train loss: {epoch_loss_sum / epoch_count:.4f}\")\n",
        "\n",
        "sep = torch.cat(sep)  # (num_iterations, 2)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 1000
        },
        "id": "90nS7g3TgjNT",
        "outputId": "9edcc139-52f9-4bb2-ae86-4bf710b6140e"
      },
      "outputs": [],
      "source": [
        "model.eval()\n",
        "\n",
        "# prediction vs target on one (shuffled) batch of each split\n",
        "for phase, loader in zip(['Train', 'Test'], [dataloader, test_dataloader]):\n",
        "    sequences, sep_idx, target_avg = next(iter(loader))\n",
        "    # Inference only: skip building the autograd graph.\n",
        "    with torch.no_grad():\n",
        "        outputs, _, _, _, _ = model(sequences, sep_idx)\n",
        "    outputs = outputs.cpu().numpy()\n",
        "    target_avg = target_avg.cpu().numpy()\n",
        "\n",
        "    fit = linregress(outputs, target_avg)\n",
        "    line = fit.slope * outputs + fit.intercept\n",
        "    r_squared = fit.rvalue ** 2\n",
        "\n",
        "    plt.figure()\n",
        "    plt.scatter(outputs, target_avg, alpha=0.7, edgecolors='k')\n",
        "    plt.plot(outputs, line, 'r--', label=f'Linear fit ($R^2$ = {r_squared:.2f})')\n",
        "    plt.xlabel('Model Outputs', fontsize=12)\n",
        "    plt.ylabel('Target Averages', fontsize=12)\n",
        "    plt.title('Model Prediction vs Actual Target on ' + phase, fontsize=14)\n",
        "    plt.grid(True)\n",
        "    plt.legend()\n",
        "    plt.tight_layout()\n",
        "    plt.show()\n",
        "\n",
        "# loss\n",
        "plt.figure()\n",
        "plt.semilogy(losses, label=\"Loss\", linewidth=2, marker='o', markersize=4)\n",
        "plt.title(\"Training Loss Over Iterations\", fontsize=14)\n",
        "plt.xlabel(\"Iteration\", fontsize=12)\n",
        "plt.ylabel(\"Loss\", fontsize=12)\n",
        "plt.grid(True, which=\"both\", linestyle='--', linewidth=0.5)\n",
        "plt.legend()\n",
        "plt.tight_layout()\n",
        "plt.show()\n",
        "\n",
        "# attention maps and activations of a single test example\n",
        "sequences, sep_idx, target_avg = next(iter(test_dataloader))\n",
        "idx = 0\n",
        "print('sep_idx:', sep_idx[idx].item())\n",
        "with torch.no_grad():\n",
        "    outputs, Z_1, H_1, attn_weights_1, attn_weights_2 = model(sequences, sep_idx)\n",
        "\n",
        "\n",
        "def select_example(t):\n",
        "    \"\"\"Return example `idx` of a batched tensor as a NumPy array.\"\"\"\n",
        "    return t[idx].detach().cpu().numpy()\n",
        "\n",
        "\n",
        "target_avg = select_example(target_avg)\n",
        "outputs = select_example(outputs)\n",
        "Z_1 = select_example(Z_1)\n",
        "H_1 = select_example(H_1)\n",
        "attn_weights_1 = select_example(attn_weights_1)\n",
        "attn_weights_2 = select_example(attn_weights_2)\n",
        "\n",
        "print('output:', outputs.item())\n",
        "print('target:', target_avg.item())\n",
        "\n",
        "# the [SEP] forms an attention sink\n",
        "fig, axs = plt.subplots(1, 2, figsize=(10, 4.5))\n",
        "for ax, weights, title in zip(axs, [attn_weights_1, attn_weights_2], ['Layer 1 attention', 'Layer 2 attention']):\n",
        "    im = ax.imshow(weights, aspect='auto')\n",
        "    ax.tick_params(axis='both', which='major', labelsize=12)\n",
        "    cbar = fig.colorbar(im, ax=ax)\n",
        "    cbar.ax.tick_params(labelsize=13)\n",
        "    ax.set_title(title, fontsize=15)\n",
        "    ax.set_xlabel(\"Key index\", fontsize=15)\n",
        "    ax.set_ylabel(\"Query index\", fontsize=15)\n",
        "plt.tight_layout()\n",
        "plt.show()\n",
        "\n",
        "plt.figure()\n",
        "plt.plot(attn_weights_2[-1, :].flatten(), '-*')\n",
        "plt.title('Last Row of Attention Weights 2')\n",
        "plt.xlabel('Key index')\n",
        "plt.ylabel('Attention weight')\n",
        "plt.show()\n",
        "\n",
        "# the tokens after the [SEP] have activation outliers in their second coordinate,\n",
        "# i.e. the second coordinate acts as an outlier feature dimension\n",
        "plt.matshow(Z_1.T)\n",
        "ax = plt.gca()\n",
        "ax.tick_params(axis='both', which='major', labelsize=13, length=6, width=1.5)\n",
        "ax.set_xlabel('Token index', fontsize=13)\n",
        "ax.set_ylabel('Coordinate of $Z_1$', fontsize=13)\n",
        "cbar = plt.colorbar()\n",
        "cbar.ax.tick_params(labelsize=13)\n",
        "plt.show()\n",
        "\n",
        "# the learned matrices Wv1, Wqk2 (merged query-key) and Wv2 are low rank,\n",
        "# projecting either onto the tag or onto the number subspace\n",
        "def to_cpu(t):\n",
        "    \"\"\"Detach a tensor and move it to the CPU for printing.\"\"\"\n",
        "    return t.detach().cpu()\n",
        "\n",
        "\n",
        "print('SEP:', to_cpu(model.sep_embedding))\n",
        "print('Wv1:', to_cpu(model.Wv1.weight))\n",
        "print('Wqk2:', to_cpu(model.Wqk2.weight))\n",
        "print('Wv2:', to_cpu(model.Wv2.weight))\n",
        "\n",
        "# sep embedding\n",
        "plt.figure()\n",
        "plt.plot(sep)\n",
        "plt.title('Embedding of [SEP]')\n",
        "plt.xlabel('Iteration')\n",
        "plt.ylabel('Embedding')\n",
        "plt.legend(['First Coordinate (Number)', 'Second Coordinate (Tag)'])\n",
        "plt.grid(True)\n",
        "plt.show()"
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "gpuType": "T4",
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
