{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "fe1e2ed7-725d-4c6b-a139-2cd513b7a968",
   "metadata": {},
   "source": [
    "# Softmax vs Linear attention\n",
    "\n",
    "This is the code for training one-layer SoftMax/Linear Attention to learn in-context linear regression tasks. The transformer model consists of only a single Multi-Head Attention layer. Our code is based on https://github.com/Y-Agent/ICL_linear, originally developed for the paper **In-Context Linear Regression Demystified: Training Dynamics and Mechanistic Interpretability of Multi-Head Softmax Attention** by He et al.\n",
    "\n",
    "Multiple training runs can be done using the file run_experiments.py, and the figures are plotted via the plots.ipynb notebook."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "30f223bd-364e-4f9e-a3d0-60e7bf8a3bb7",
   "metadata": {},
   "source": [
    "**Install necessary packages and check if we use GPU**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "7811d5bb-39da-473a-913e-f82369d36e3e",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "GPU is not available. Using CPU.\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "from typing import Optional\n",
    "import numpy as np\n",
    "import re\n",
    "\n",
    "if torch.cuda.is_available():\n",
    "    device_name = torch.cuda.get_device_name(0)\n",
    "    device_message = f\"GPU is available.\\nUsing device: {device_name}\"\n",
    "else:\n",
    "    device_message = \"GPU is not available. Using CPU.\"\n",
    "    \n",
    "print(device_message)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "da6fcfd5-9077-4379-b17c-40af5829b03c",
   "metadata": {},
   "source": [
    "## Data Sampling Setting\n",
    "Here we assume the input of data is structured as in-context example sequence $Z$ and query token $z_q$ where\n",
    "\n",
    "$$\n",
    "    Z = \\begin{bmatrix}x_1 & x_2 & \\dots & x_L \\\\ y_1 & y_2 & \\dots &y_L\\end{bmatrix}, z_q = \\begin{bmatrix}x_q \\\\ 0\\end{bmatrix}\n",
    "$$\n",
    "\n",
    "Where $x \\in \\mathbb{R}^d, y \\in \\mathbb{R}$.\n",
    "\n",
    "Data is sampled by following step:\n",
    "1. Sample $x_1,x_2,\\dots,x_L\\overset{\\mathrm{i.i.d}}{\\sim} \\mathcal{N}(0, I_d), x_q \\sim \\mathcal{N}(0, I_d)$\n",
    "2. Sample $\\beta \\sim \\mathcal{N}(0, I_d)$\n",
    "3. Sample $\\epsilon_1,\\epsilon_2,\\dots,\\epsilon_L\\overset{\\mathrm{i.i.d}}{\\sim} \\mathcal{N}(0, \\sigma^2)$\n",
    "4. $y_i = \\beta^\\top x_i + \\epsilon_i, i \\in [L]$\n",
    "\n",
    "The following code imports the data generation process from utils/data_utils. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "f8ab55d0-66f6-4114-b7da-3ba57bdfda0e",
   "metadata": {},
   "outputs": [],
   "source": [
    "from utils.data_utils import DataMethod, LinearReg, find_latest_checkpoint"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9e6ab93e-6f61-4e08-a551-acdd67b78a4e",
   "metadata": {},
   "source": [
    "#### `DataMethod` Class:\n",
    "- **Purpose**: Provides a base framework for data generation and transformation.\n",
    "- **Methods**:\n",
    "  - `__init__`: Initializes the class with an optional dictionary of parameters.\n",
    "  - `__generatedata__`: Generates synthetic data with a specified sequence length and dimension. Returns a tensor filled with random values.\n",
    "  - `__transform__`: Transforms input data by creating input-target pairs where the target is a shifted version of the input.\n",
    "\n",
    "#### `LinearReg` Class:\n",
    "- **Purpose**: Generates linear regression data based on specific parameters such as sequence length, noise level, and data size.\n",
    "- **Attributes**:\n",
    "  - `L`, `dx`, `dy`, `noise_std`, and `number_of_samples`: Parameters controlling data generation. (Note: Denote `noise_std` as `\\delta`, then the variance for $\\epsilon$ is $d \\delta^2$ where $d$ denotes `dx`)\n",
    "  - `G`: A matrix used for task-specific indexing.\n",
    "- **Methods**:\n",
    "  - `__init__`: Initializes parameters and creates the G matrix for task indexing.\n",
    "  - `__generatedata__`: Generates input data, regression coefficients, and target outputs with noise. Concatenates data for use in regression tasks.\n",
    "  - `__transform__`: Transforms input data, optionally zeroing out a specified index."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "beb80188-f7d0-457b-836b-b2d57e6aeed1",
   "metadata": {},
   "source": [
    "In our training setting, we have $L=8*d, \\sigma^2=0.1$. We use a default batch size of $64$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "44234d99-c80e-4b58-8563-59e053ea5391",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Sample data from the linear model:\n",
      "\n",
      "z (input and output concatenated):\n",
      "Dimensions of z: torch.Size([128, 40, 6])\n",
      "\n",
      "z_q (query data):\n",
      "tensor([[[-0.2249,  0.2860, -1.6568, -2.7158, -0.3051,  0.0000]],\n",
      "\n",
      "        [[ 0.1802,  0.9375, -1.3590, -0.8258, -3.1066,  0.0000]],\n",
      "\n",
      "        [[-0.0927, -0.7899,  0.7072,  0.0131, -0.1963,  0.0000]],\n",
      "\n",
      "        [[ 0.0475,  0.2617,  0.5637,  0.9903,  0.4291,  0.0000]],\n",
      "\n",
      "        [[-1.7596, -1.5120,  0.5712, -1.1243,  0.2512,  0.0000]]])\n",
      "Dimensions of z_q: torch.Size([128, 1, 6])\n",
      "\n",
      "y_q (query targets):\n",
      "tensor([[[ 4.8991e-01]],\n",
      "\n",
      "        [[-3.2598e+00]],\n",
      "\n",
      "        [[-2.7748e-01]],\n",
      "\n",
      "        [[ 6.5674e-01]],\n",
      "\n",
      "        [[ 1.3246e-03]]])\n",
      "Dimensions of y_q: torch.Size([128, 1, 1])\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Define parameters\n",
    "d = 5\n",
    "bsize = 128\n",
    "L = 8 * d\n",
    "\n",
    "# Initialize the LinearReg object\n",
    "train_method = LinearReg({\n",
    "    \"L\": L,\n",
    "    \"dx\": d,\n",
    "    \"dy\": 1,\n",
    "    \"number_of_samples\": bsize,\n",
    "    \"noise_std\": np.sqrt(0.1),\n",
    "    \"seed\": None\n",
    "})\n",
    "\n",
    "# Generate data samples\n",
    "z_q, z, y_q = train_method.__generatedata__()\n",
    "\n",
    "# Print some samples from the generated data along with their dimensions\n",
    "print(\"Sample data from the linear model:\\n\")\n",
    "\n",
    "print(\"z (input and output concatenated):\")\n",
    "# print(z[:5])  # Print the first 5 samples\n",
    "print(f\"Dimensions of z: {z.shape}\\n\")\n",
    "\n",
    "print(\"z_q (query data):\")\n",
    "print(z_q[:5])  # Print the first 5 query samples\n",
    "print(f\"Dimensions of z_q: {z_q.shape}\\n\")\n",
    "\n",
    "print(\"y_q (query targets):\")\n",
    "print(y_q[:5])  # Print the first 5 query targets\n",
    "print(f\"Dimensions of y_q: {y_q.shape}\\n\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bd682041-8d99-428a-9cce-d5f61807a154",
   "metadata": {},
   "source": [
    "## Model Class -- Single Layer Multi-Head Attention.\n",
    "\n",
    "The code below defines the number of heads of the model. Suppose the number of head is $H$. Then the output of the model is\n",
    "$$\n",
    "    \\mathtt{TF}_\\theta(Z,z_q) = \\sum_{h = 1}^HO^{(h)}V^{(h)}\\mathtt{SoftMax}(\\frac{Z^\\top K^{(h)^\\top}Q^{(h)}z_q}{\\sqrt{d+1}})\n",
    "$$\n",
    "\n",
    "Here we have\n",
    "$$\n",
    "    Q^{(h)},K^{(h)},V^{(h)} \\in \\mathbb{R}^{(d+1)\\times(d+1)}, O^{(h)} \\in \\mathbb{R}^{1\\times(d+1)}\n",
    "$$\n",
    "The training loss is the least-squared loss: \n",
    "$$\n",
    "   \\ell(\\theta) = \\mathbb{E} \\bigl [\\bigl( \\mathtt{TF}_\\theta(Z,z_q)  -y_q \\bigr)^2\\bigr].\n",
    "$$\n",
    "\n",
    "**Further Explanation of Our Model**\n",
    "\n",
    "The model is different from a default Multi-Head SoftMax Attention model because it only computes the output of the query token.\n",
    "\n",
    "However, it is **equivalent** to a default Multi-Head SoftMax Attention model that calculates\n",
    "$$\n",
    "\\mathtt{TF}^{'}_\\theta(Z^{'}) = \\sum_{h = 1}^H[O^{(h)}V^{(h)}\\mathtt{SoftMax}(\\frac{Z^{'^\\top} K^{(h)^\\top}Q^{(h)}Z^{'}}{\\sqrt{d+1}} + M)]_q\n",
    "$$\n",
    "\n",
    "where \n",
    "\n",
    "$$\n",
    "Z^{'} = [Z, z_q]\n",
    "$$\n",
    "\n",
    "and $M$ is mask matrix in $\\mathbb{R}^{(d+1）\\times(d+1)}$ where\n",
    "$$\n",
    "M_{ij} = \n",
    "\\begin{cases}\n",
    "    0,& \\text{if } i \\geqslant j\\\\\n",
    "    -\\infty,              & \\text{otherwise}\n",
    "\\end{cases}\n",
    "$$\n",
    "\n",
    "so that token will attend to tokens strictly preceding it.\n",
    "\n",
    "The subscript $q$ in $\\mathtt{TF}_\\theta(Z^{'})$ means we use query token $z_q$ position's output as the output of the model. And apply training loss to it, which the same as above.\n",
    "\n",
    "The two variants are equivalent, we use the **first** to train our problem since we focus on **single layer** cases, but we also provide the **second** type of model to make it easier to stack as **multi layer** model."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "91a7d804-c69e-48fc-b90b-18ee6daaae78",
   "metadata": {},
   "source": [
    "**Note:** \n",
    "- We do not need to use positional embedding because it is okay to have permutation invariance (the first $L$ $(x,y)$ pairs are i.i.d.).\n",
    "- We do not allow $z_q$ attend to itself.\n",
    "\n",
    "We define the transformer architecture as follows. "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "143e260f-7dc1-4d84-b7e1-9cd414670866",
   "metadata": {},
   "source": [
    "#### `MultiHeadAttention` Class\n",
    "\n",
    "This class implements a custom multi-head attention mechanism with flexible initialization methods and optional bias handling for linear projections.\n",
    "\n",
    "- **`__init__` Method**:\n",
    "  - Initializes the class with configurable parameters:\n",
    "    - `n_embd` (embedding dimension), `n_head` (number of heads), `init_method` (weight initialization strategy), `bias` (enable/disable bias), and `n_out` (optional output dimension).\n",
    "  - Defines linear layers (`q_proj`, `k_proj`, `v_proj`, `o_proj`) for projecting queries, keys, values, and output.\n",
    "  - Handles weight initialization based on the specified `init_method`:\n",
    "    - **\"random\"**: Standard random initialization (default behavior).\n",
    "    - **\"small_id\"**: Initializes the projection weights as small identity matrices for a more controlled initialization.\n",
    "    - **\"small_id_qk\"**: Only initializes the query and key projections with small identity matrices.\n",
    "    - **\"oppo_small_id\"**: Initializes half of the heads with a positive small identity matrix and the other half with a negative small identity matrix. This introduces diversity in the heads' initial states.\n",
    "\n",
    "- **`_attn` Method**:\n",
    "  - Implements the core attention mechanism using scaled dot-product attention:\n",
    "    - **Inputs**:\n",
    "      - `q` (queries), `k` (keys), and `v` (values) with shape `(batch_size, seq_len, n_head * n_embd)`.\n",
    "    - **Reshaping**:\n",
    "      - Splits and reshapes the inputs into separate heads.\n",
    "    - **Attention Computation**:\n",
    "      - Computes the attention scores using matrix multiplication, applies a scaling factor, and uses a softmax function to normalize the scores.\n",
    "    - **Output**:\n",
    "      - Computes the weighted sum of the values based on the attention weights.\n",
    "\n",
    "- **`forward` Method**:\n",
    "  - Defines the forward pass through the multi-head attention layer:\n",
    "    - **Inputs**:\n",
    "      - `z_q` (query input) with shape `(batch_size, 1, n_embd)` and `z` (key/value input) with shape `(batch_size, seq_len, n_embd)`.\n",
    "    - **Projection**:\n",
    "      - Projects the inputs to queries, keys, and values using the linear layers.\n",
    "    - **Attention Computation**:\n",
    "      - Calls the `_attn` method to compute attention outputs.\n",
    "    - **Output Projection**:\n",
    "      - Reshapes and projects the attention output to the desired output dimension.\n",
    "    - **Optional Output**:\n",
    "      - If `attention_out` is `True`, returns the attention weights as well.\n",
    "\n",
    "- **`extract_qk` Method**:\n",
    "  - Extracts the query and key weight matrices for each head, useful for analyzing or visualizing the learned attention patterns.\n",
    "\n",
    "- **`extract_ov` Method**:\n",
    "  - Extracts the output and value weight matrices for each head, providing insight into the learned transformations in the attention mechanism.\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9acbf564-24c0-4f0a-b80b-4a6e5beab5a4",
   "metadata": {},
   "source": [
    "### Forward Function \n",
    "\n",
    "The forward function of the `MultiHeadAttention` module implements a series of linear transformations and scaled dot-product attention. Here's a breakdown of the mathematical equations representing its behavior:\n",
    "\n",
    "Given:\n",
    "- $ z_q \\in \\mathbb{R}^{B \\times 1 \\times d} $: Query input, where $ B $ is the batch size and $ d $ is the embedding dimension.\n",
    "- $ z \\in \\mathbb{R}^{B \\times S \\times d} $: Key/Value input, where $ S $ is the sequence length.\n",
    "\n",
    "### 1. Linear Projections\n",
    "The input queries ($z_q$), keys ($ z $), and values ($ z $) are projected using linear layers:\n",
    "\\begin{align}\n",
    "Q = z_q W_Q \\in \\mathbb{R}^{B \\times 1 \\times (h \\cdot d_h)}, \\quad\n",
    "K = z W_K \\in \\mathbb{R}^{B \\times S \\times (h \\cdot d_h)}, \\quad \n",
    "V = z W_V \\in \\mathbb{R}^{B \\times S \\times (h \\cdot d_h)}\n",
    "\\end{align}\n",
    "where:\n",
    "- $ W_Q, W_K, W_V \\in \\mathbb{R}^{d \\times (h \\cdot d_h)} $ are the projection matrices.\n",
    "- $ h $ is the number of heads.\n",
    "- $d_h = \\frac{d}{h} $ is the dimension of each head.\n",
    "\n",
    "### 2. Reshape and Permute\n",
    "The projected queries, keys, and values are reshaped and permuted to split the heads:\n",
    "\\begin{align}\n",
    "Q = \\text{reshape}(Q) \\in \\mathbb{R}^{B \\times h \\times 1 \\times d_h}, \\quad \n",
    "K = \\text{reshape}(K) \\in \\mathbb{R}^{B \\times h \\times S \\times d_h}, \\quad \n",
    "V = \\text{reshape}(V) \\in \\mathbb{R}^{B \\times h \\times S \\times d_h}\n",
    "\\end{align}\n",
    "\n",
    "### 3. Scaled Dot-Product Attention\n",
    "The attention scores are computed using scaled dot-product attention:\n",
    "$$\n",
    "\\text{Attention}(Q, K, V) = \\text{softmax} \\left( \\frac{QK^T}{\\sqrt{d_h}} \\right) V\n",
    "$$\n",
    "where:\n",
    "- $ QK^T \\in \\mathbb{R}^{B \\times h \\times 1 \\times S} $ represents the dot-product between queries and keys.\n",
    "- The division by $ \\sqrt{d_h} $ scales the dot-product values.\n",
    "- The softmax operation normalizes the attention scores across the sequence dimension.\n",
    "\n",
    "### 4. Concatenation of Heads\n",
    "The outputs of the attention mechanism are concatenated:\n",
    "$$\n",
    "\\text{Concat}(head_1, \\ldots, head_h) \\in \\mathbb{R}^{B \\times 1 \\times (h \\cdot d_h)}\n",
    "$$\n",
    "\n",
    "### 5. Final Linear Projection\n",
    "The concatenated output is projected back to the output dimension:\n",
    "$$\n",
    "\\text{Output} = \\text{Concat}(head_1, \\ldots, head_h) W_O \\in \\mathbb{R}^{B \\times 1 \\times d_{\\text{out}}}\n",
    "$$\n",
    "where:\n",
    "- $ W_O \\in \\mathbb{R}^{(h \\cdot d_h) \\times d_{\\text{out}}}$ is the output projection matrix, and $ d_{\\text{out}} $ is either $ d $ or a user-defined output dimension.\n",
    "\n",
    " \n",
    "The overall equation for the forward function can be represented as:\n",
    "$$\n",
    "\\text{Output} = \\text{Concat} \\left( \\text{softmax} \\left( \\frac{QK^T}{\\sqrt{d_h}} \\right) V \\right) W_O\n",
    "$$\n",
    "where:\n",
    "- $ Q, K, V $ are the projected queries, keys, and values.\n",
    "- The attention weights are computed as $ \\text{softmax} \\left( \\frac{QK^T}{\\sqrt{d_h}} \\right)$.\n",
    "- The output is obtained by concatenating the attention outputs from all heads and applying a final linear transformation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "e07fcd6e-b7ee-40d3-b650-1383be92fc1a",
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.models import MultiHeadAttention, Config"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b4bec26e-d9ac-4d9c-b2d0-d5043067287a",
   "metadata": {},
   "source": [
    "## Load the whole pipeline "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "50baaee6-e2d2-4f00-930f-4f466431e3f3",
   "metadata": {},
   "outputs": [],
   "source": [
    "from run_experiments import *\n",
    "import random\n",
    "import os"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7fe61aad-6207-43c4-a843-28b40c477932",
   "metadata": {},
   "source": [
    "#### Run and Visualize"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "42b616e0-8bcc-4208-8326-3fc1dbb71824",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Config(\n",
      "  n_head: 1\n",
      "  d: 5\n",
      "  L: 10\n",
      "  activation: softmax\n",
      "  method: SGD\n",
      "  learning_rate: 0.001\n",
      "  batch_size: 256\n",
      "  noise_std: 0.31622776601683794\n",
      "  optimizer_params: {}\n",
      "  training_steps: 500001\n",
      "  save_log_every_step: 1000\n",
      "  loss_log_every_step: 1000\n",
      "  validation_every: 1000\n",
      "  print_loss_every: 50000\n",
      "  n_embd (d + 1): 6\n",
      "  warm: True\n",
      "  normalize: True\n",
      ")\n",
      "Using device: cpu\n",
      "Failed to load checkpoint.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Training loss 1.0694:   0%|                | 70/500001 [00:00<11:54, 699.31it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "After 0 steps, Training Loss: 1.0694, Validation Loss: 1.0707\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Training loss 0.7335:  10%|█▏          | 50125/500001 [00:46<06:56, 1080.87it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "After 50000 steps, Training Loss: 0.7335, Validation Loss: 0.7299\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Training loss 0.6810:  20%|██▏        | 100122/500001 [01:32<06:09, 1082.71it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "After 100000 steps, Training Loss: 0.6810, Validation Loss: 0.7345\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Training loss 0.6023:  30%|███▎       | 150154/500001 [02:18<05:23, 1081.96it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "After 150000 steps, Training Loss: 0.6023, Validation Loss: 0.7706\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Training loss 0.7451:  40%|████▍      | 200125/500001 [03:04<04:37, 1080.62it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "After 200000 steps, Training Loss: 0.7451, Validation Loss: 0.6952\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Training loss 0.7995:  50%|█████▌     | 250142/500001 [03:50<03:50, 1083.10it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "After 250000 steps, Training Loss: 0.7995, Validation Loss: 0.7233\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Training loss 0.6861:  60%|██████▌    | 300196/500001 [04:36<03:05, 1080.02it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "After 300000 steps, Training Loss: 0.6861, Validation Loss: 0.7784\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Training loss 0.7221:  70%|███████▋   | 350192/500001 [05:23<02:19, 1075.16it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "After 350000 steps, Training Loss: 0.7221, Validation Loss: 0.6586\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Training loss 0.6839:  80%|████████▊  | 400203/500001 [06:09<01:33, 1072.59it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "After 400000 steps, Training Loss: 0.6839, Validation Loss: 0.7138\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Training loss 0.6846:  90%|█████████▉ | 450147/500001 [06:56<00:46, 1074.38it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "After 450000 steps, Training Loss: 0.6846, Validation Loss: 0.7233\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Training loss 0.7776: 100%|███████████| 500001/500001 [07:42<00:00, 1081.46it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "After 500000 steps, Training Loss: 0.7776, Validation Loss: 0.6370\n",
      "Run ended\n",
      "Config(\n",
      "  n_head: 1\n",
      "  d: 5\n",
      "  L: 15\n",
      "  activation: softmax\n",
      "  method: SGD\n",
      "  learning_rate: 0.001\n",
      "  batch_size: 256\n",
      "  noise_std: 0.31622776601683794\n",
      "  optimizer_params: {}\n",
      "  training_steps: 500001\n",
      "  save_log_every_step: 1000\n",
      "  loss_log_every_step: 1000\n",
      "  validation_every: 1000\n",
      "  print_loss_every: 50000\n",
      "  n_embd (d + 1): 6\n",
      "  warm: True\n",
      "  normalize: True\n",
      ")\n",
      "Using device: cpu\n",
      "Failed to load checkpoint.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Training loss 0.8916:   0%|                | 89/500001 [00:00<09:21, 889.98it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "After 0 steps, Training Loss: 0.8916, Validation Loss: 1.0995\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Training loss 0.5442:  10%|█▎           | 50137/500001 [00:54<08:07, 922.07it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "After 50000 steps, Training Loss: 0.5442, Validation Loss: 0.6671\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Training loss 0.6241:  20%|██▍         | 100153/500001 [01:48<07:18, 912.19it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "After 100000 steps, Training Loss: 0.6241, Validation Loss: 0.5383\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Training loss 0.6150:  30%|███▌        | 150162/500001 [02:43<06:24, 909.80it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "After 150000 steps, Training Loss: 0.6150, Validation Loss: 0.6984\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Training loss 0.6039:  40%|████▊       | 200145/500001 [03:38<05:28, 913.77it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "After 200000 steps, Training Loss: 0.6039, Validation Loss: 0.6635\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Training loss 0.6089:  50%|██████      | 250147/500001 [04:32<04:31, 921.25it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "After 250000 steps, Training Loss: 0.6089, Validation Loss: 0.5600\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Training loss 0.5781:  60%|███████▏    | 300129/500001 [05:26<03:36, 921.43it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "After 300000 steps, Training Loss: 0.5781, Validation Loss: 0.6429\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Training loss 0.8006:  70%|████████▍   | 350186/500001 [06:20<02:42, 921.37it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "After 350000 steps, Training Loss: 0.8006, Validation Loss: 0.5665\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Training loss 0.5032:  80%|█████████▌  | 400146/500001 [07:15<01:48, 916.16it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "After 400000 steps, Training Loss: 0.5032, Validation Loss: 0.5548\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Training loss 0.6657:  90%|██████████▊ | 450175/500001 [08:09<00:54, 914.39it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "After 450000 steps, Training Loss: 0.6657, Validation Loss: 0.6456\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Training loss 0.6384: 100%|████████████| 500001/500001 [09:04<00:00, 918.77it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "After 500000 steps, Training Loss: 0.6384, Validation Loss: 0.4896\n",
      "Run ended\n",
      "Config(\n",
      "  n_head: 1\n",
      "  d: 5\n",
      "  L: 20\n",
      "  activation: softmax\n",
      "  method: SGD\n",
      "  learning_rate: 0.001\n",
      "  batch_size: 256\n",
      "  noise_std: 0.31622776601683794\n",
      "  optimizer_params: {}\n",
      "  training_steps: 500001\n",
      "  save_log_every_step: 1000\n",
      "  loss_log_every_step: 1000\n",
      "  validation_every: 1000\n",
      "  print_loss_every: 50000\n",
      "  n_embd (d + 1): 6\n",
      "  warm: True\n",
      "  normalize: True\n",
      ")\n",
      "Using device: cpu\n",
      "Failed to load checkpoint.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Training loss 1.2105:   0%|                | 79/500001 [00:00<10:33, 789.02it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "After 0 steps, Training Loss: 1.2105, Validation Loss: 1.0345\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Training loss 0.5851:  10%|█▎           | 50084/500001 [01:03<09:29, 790.16it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "After 50000 steps, Training Loss: 0.5851, Validation Loss: 0.4824\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Training loss 0.4765:  20%|██▍         | 100120/500001 [02:06<08:25, 791.31it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "After 100000 steps, Training Loss: 0.4765, Validation Loss: 0.5881\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Training loss 0.6111:  30%|███▌        | 150152/500001 [03:08<07:20, 793.56it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "After 150000 steps, Training Loss: 0.6111, Validation Loss: 0.5537\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Training loss 0.6311:  40%|████▊       | 200128/500001 [04:11<06:18, 791.75it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "After 200000 steps, Training Loss: 0.6311, Validation Loss: 0.5166\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Training loss 0.4546:  50%|██████      | 250150/500001 [05:14<05:13, 796.41it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "After 250000 steps, Training Loss: 0.4546, Validation Loss: 0.4558\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Training loss 0.5579:  60%|███████▏    | 300095/500001 [06:17<04:12, 792.96it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "After 300000 steps, Training Loss: 0.5579, Validation Loss: 0.6735\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Training loss 0.4599:  70%|████████▍   | 350108/500001 [07:20<03:08, 794.77it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "After 350000 steps, Training Loss: 0.4599, Validation Loss: 0.5165\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Training loss 0.5846:  80%|█████████▌  | 400138/500001 [08:23<02:06, 791.50it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "After 400000 steps, Training Loss: 0.5846, Validation Loss: 0.4053\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Training loss 0.6449:  90%|██████████▊ | 450144/500001 [09:26<01:02, 793.91it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "After 450000 steps, Training Loss: 0.6449, Validation Loss: 0.6151\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Training loss 0.4902: 100%|████████████| 500001/500001 [10:28<00:00, 795.20it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "After 500000 steps, Training Loss: 0.4902, Validation Loss: 0.5766\n",
      "Run ended\n",
      "Config(\n",
      "  n_head: 1\n",
      "  d: 5\n",
      "  L: 25\n",
      "  activation: softmax\n",
      "  method: SGD\n",
      "  learning_rate: 0.001\n",
      "  batch_size: 256\n",
      "  noise_std: 0.31622776601683794\n",
      "  optimizer_params: {}\n",
      "  training_steps: 500001\n",
      "  save_log_every_step: 1000\n",
      "  loss_log_every_step: 1000\n",
      "  validation_every: 1000\n",
      "  print_loss_every: 50000\n",
      "  n_embd (d + 1): 6\n",
      "  warm: True\n",
      "  normalize: True\n",
      ")\n",
      "Using device: cpu\n",
      "Failed to load checkpoint.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Training loss 1.1024:   0%|                | 65/500001 [00:00<12:59, 641.76it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "After 0 steps, Training Loss: 1.1024, Validation Loss: 1.0858\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Training loss 1.0484:  10%|█▎           | 50133/500001 [01:11<10:41, 701.75it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "After 50000 steps, Training Loss: 1.0484, Validation Loss: 1.1391\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Training loss 0.9104:  20%|██▍         | 100107/500001 [02:22<09:32, 698.27it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "After 100000 steps, Training Loss: 0.9104, Validation Loss: 0.8790\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Training loss 0.5045:  30%|███▌        | 150123/500001 [03:34<08:23, 695.48it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "After 150000 steps, Training Loss: 0.5045, Validation Loss: 0.5312\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Training loss 0.4587:  40%|████▊       | 200113/500001 [04:45<07:10, 696.52it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "After 200000 steps, Training Loss: 0.4587, Validation Loss: 0.5272\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Training loss 0.4437:  50%|██████      | 250111/500001 [05:57<06:02, 689.22it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "After 250000 steps, Training Loss: 0.4437, Validation Loss: 0.4747\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Training loss 0.4816:  60%|███████▏    | 300137/500001 [07:08<04:46, 697.18it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "After 300000 steps, Training Loss: 0.4816, Validation Loss: 0.4783\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Training loss 0.6092:  70%|████████▍   | 350088/500001 [08:19<03:35, 696.80it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "After 350000 steps, Training Loss: 0.6092, Validation Loss: 0.4808\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Training loss 0.4344:  80%|█████████▌  | 400083/500001 [09:31<02:23, 696.38it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "After 400000 steps, Training Loss: 0.4344, Validation Loss: 0.6301\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Training loss 0.4839:  90%|██████████▊ | 450090/500001 [10:42<01:12, 692.13it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "After 450000 steps, Training Loss: 0.4839, Validation Loss: 0.5336\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Training loss 0.4750: 100%|████████████| 500001/500001 [11:53<00:00, 700.43it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "After 500000 steps, Training Loss: 0.4750, Validation Loss: 0.4697\n",
      "Run ended\n",
      "Config(\n",
      "  n_head: 1\n",
      "  d: 5\n",
      "  L: 30\n",
      "  activation: softmax\n",
      "  method: SGD\n",
      "  learning_rate: 0.001\n",
      "  batch_size: 256\n",
      "  noise_std: 0.31622776601683794\n",
      "  optimizer_params: {}\n",
      "  training_steps: 500001\n",
      "  save_log_every_step: 1000\n",
      "  loss_log_every_step: 1000\n",
      "  validation_every: 1000\n",
      "  print_loss_every: 50000\n",
      "  n_embd (d + 1): 6\n",
      "  warm: True\n",
      "  normalize: True\n",
      ")\n",
      "Using device: cpu\n",
      "Failed to load checkpoint.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Training loss 0.9734:   0%|                | 63/500001 [00:00<13:13, 629.78it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "After 0 steps, Training Loss: 0.9734, Validation Loss: 0.9844\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Training loss 0.4411:  10%|█▎           | 50075/500001 [01:18<11:48, 635.43it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "After 50000 steps, Training Loss: 0.4411, Validation Loss: 0.5043\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Training loss 0.4475:  20%|██▍         | 100088/500001 [02:36<10:29, 635.78it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "After 100000 steps, Training Loss: 0.4475, Validation Loss: 0.4534\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Training loss 0.4825:  30%|███▌        | 150110/500001 [03:55<09:11, 634.95it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "After 150000 steps, Training Loss: 0.4825, Validation Loss: 0.4034\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Training loss 0.4408:  40%|████▊       | 200097/500001 [05:13<07:52, 634.40it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "After 200000 steps, Training Loss: 0.4408, Validation Loss: 0.5181\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Training loss 0.4140:  50%|██████      | 250105/500001 [06:32<06:33, 635.71it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "After 250000 steps, Training Loss: 0.4140, Validation Loss: 0.4847\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Training loss 0.5490:  60%|███████▏    | 300097/500001 [07:50<05:15, 634.20it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "After 300000 steps, Training Loss: 0.5490, Validation Loss: 0.4867\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Training loss 0.4465:  70%|████████▍   | 350094/500001 [09:09<03:57, 630.44it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "After 350000 steps, Training Loss: 0.4465, Validation Loss: 0.4883\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Training loss 0.4629:  80%|█████████▌  | 400090/500001 [10:27<02:36, 638.51it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "After 400000 steps, Training Loss: 0.4629, Validation Loss: 0.4419\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Training loss 0.3773:  90%|██████████▊ | 450092/500001 [11:46<01:18, 638.17it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "After 450000 steps, Training Loss: 0.3773, Validation Loss: 0.4447\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Training loss 0.3890: 100%|████████████| 500001/500001 [13:04<00:00, 637.29it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "After 500000 steps, Training Loss: 0.3890, Validation Loss: 0.4277\n",
      "Run ended\n",
      "Config(\n",
      "  n_head: 1\n",
      "  d: 5\n",
      "  L: 35\n",
      "  activation: softmax\n",
      "  method: SGD\n",
      "  learning_rate: 0.001\n",
      "  batch_size: 256\n",
      "  noise_std: 0.31622776601683794\n",
      "  optimizer_params: {}\n",
      "  training_steps: 500001\n",
      "  save_log_every_step: 1000\n",
      "  loss_log_every_step: 1000\n",
      "  validation_every: 1000\n",
      "  print_loss_every: 50000\n",
      "  n_embd (d + 1): 6\n",
      "  warm: True\n",
      "  normalize: True\n",
      ")\n",
      "Using device: cpu\n",
      "Failed to load checkpoint.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Training loss 1.0707:   0%|                | 58/500001 [00:00<14:32, 573.11it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "After 0 steps, Training Loss: 1.0707, Validation Loss: 1.0107\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Training loss 1.0149:   6%|▊            | 32049/500001 [00:55<13:25, 580.94it/s]\n",
      "\n",
      "KeyboardInterrupt\n",
      "\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "\n",
    "# Fix the random seed so the run is reproducible\n",
    "seed = 376\n",
    "set_seed(seed)\n",
    "\n",
    "n_head, d = 1, 5\n",
    "activation = 'softmax'\n",
    "warm = True\n",
    "\n",
    "L = 10\n",
    "# noise_std = sqrt(0.1), i.e. a noise variance of 0.1\n",
    "config = Config(n_head=n_head, n_out=1, d=d, method='SGD',\n",
    "                L=L,\n",
    "                learning_rate=0.001, training_steps=500001, batch_size=256,\n",
    "                save_log_every_step=1000, optimizer_params={}, noise_std=np.sqrt(0.1),\n",
    "                activation=activation, warm=warm, seed=seed)\n",
    "print(config)\n",
    "\n",
    "# Create the \"saved_models\" directory if it doesn't exist\n",
    "os.makedirs(\"saved_models\", exist_ok=True)\n",
    "# Path under \"saved_models\" where this run's model is saved\n",
    "model_path = f\"saved_models/Activation_{activation}_L_{L}_d_{d}\"\n",
    "\n",
    "run_experiment(config, model_path, verbose=True)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
