{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "qa4jZeHiQDNj"
   },
   "outputs": [],
   "source": [
    "import tiktoken\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "from torch.nn import functional as F\n",
    "from einops import rearrange, reduce, repeat\n",
    "\n",
    "from datasets import load_dataset\n",
    "\n",
    "import numpy as np\n",
    "\n",
    "from dataclasses import dataclass\n",
    "\n",
    "import math\n",
    "import os\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "from tqdm import tqdm\n",
    "\n",
    "\n",
    "seed = 24\n",
    "\n",
    "# Seed the CPU-side generators of torch and numpy.\n",
    "torch.manual_seed(seed)\n",
    "np.random.seed(seed)\n",
    "\n",
    "if torch.cuda.is_available():\n",
    "    # Ask cuDNN for deterministic kernels and switch off its autotuner.\n",
    "    torch.backends.cudnn.benchmark = False\n",
    "    torch.backends.cudnn.deterministic = True\n",
    "\n",
    "    torch.cuda.manual_seed(seed)\n",
    "    torch.cuda.manual_seed_all(seed)  # covers the multi-GPU case"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "keZrq8EAOfvd",
    "outputId": "237e727f-66cf-45b4-e80a-ffb986d36117"
   },
   "outputs": [],
   "source": [
    "# Run on the GPU when one is visible to torch, otherwise on the CPU.\n",
    "if torch.cuda.is_available():\n",
    "    device = 'cuda'\n",
    "else:\n",
    "    device = 'cpu'\n",
    "\n",
    "@dataclass\n",
    "class HyperparametersConfig:\n",
    "    \"\"\"Optimisation and evaluation settings read by the training cells.\"\"\"\n",
    "\n",
    "    # micro-batch size B\n",
    "    batch_size: int = 64\n",
    "    # number of optimizer steps each training loop runs\n",
    "    max_iters: int = 19073\n",
    "    # a validation pass is triggered every `eval_interval` steps\n",
    "    eval_interval: int = 250\n",
    "    # learning rate handed to AdamW\n",
    "    learning_rate: float = 6e-3\n",
    "    # batches averaged per evaluation when not running on the whole shard\n",
    "    eval_iters: int = 20\n",
    "    # dropout probability (not read by the model classes visible in this notebook)\n",
    "    dropout: float = 0.0\n",
    "\n",
    "\n",
    "@dataclass\n",
    "class ModelConfig:\n",
    "    \"\"\"Architecture sizes consumed by the model classes and the data loaders.\"\"\"\n",
    "\n",
    "    # sequence length T; also the size of the causal mask and position table\n",
    "    context_length: int = 1024\n",
    "    # batch size B used when the data loaders are built\n",
    "    batch_size: int = 64\n",
    "    # rows in the token embedding / output projection\n",
    "    vocab_size: int = 50304\n",
    "    # attention heads per block (must divide n_embd)\n",
    "    n_head: int = 4\n",
    "    # number of decoder blocks\n",
    "    n_layer: int = 1\n",
    "    # embedding width\n",
    "    n_embd: int = 172"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "ij2T8rw8PfJH"
   },
   "outputs": [],
   "source": [
    "class SelfAttention(nn.Module):\n",
    "    \"\"\"Causal multi-head self-attention with a single fused q/k/v projection.\n",
    "\n",
    "    Reads n_embd, n_head and context_length from `config`; n_embd must be\n",
    "    divisible by n_head.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, config):\n",
    "        super().__init__()\n",
    "        assert config.n_embd % config.n_head == 0\n",
    "        self.config = config\n",
    "\n",
    "        self.n_head = config.n_head\n",
    "        self.n_embd = config.n_embd\n",
    "        self.head_size = self.n_embd // self.n_head\n",
    "\n",
    "        self.qkv_proj = nn.Linear(self.n_embd, 3 * self.n_embd, bias=False)\n",
    "        self.residual_proj = nn.Linear(self.n_embd, self.n_embd)\n",
    "\n",
    "        # Lower-triangular (causal) mask kept as a (1, 1, L, L) buffer, L = context_length.\n",
    "        max_len = config.context_length\n",
    "        causal = torch.tril(torch.ones(max_len, max_len))\n",
    "        self.register_buffer('mask', causal.view(1, 1, max_len, max_len))\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Attend over x of shape (B, T, n_embd); the result has the same shape.\"\"\"\n",
    "        B, T, C = x.shape\n",
    "\n",
    "        fused = self.qkv_proj(x) # (B, T, 3 * n_embd)\n",
    "        # Feature axis is laid out as (q|k|v, head, head_size); regroup per head.\n",
    "        per_head = rearrange(fused, 'b t (n hn hs) -> b hn t (n hs)', n = 3, hn=self.config.n_head, hs = self.head_size) # (B, n_head, T, 3 * head_size)\n",
    "        q, k, v = per_head.chunk(3, dim=-1) # three views of shape (B, n_head, T, head_size)\n",
    "\n",
    "        scale = self.head_size ** (-0.5)\n",
    "        scores = q @ k.transpose(-2, -1) * scale # (B, n_head, T, T)\n",
    "        hidden = self.mask[:, :, :T, :T] == 0 # True where a query would see a later position\n",
    "        scores = scores.masked_fill(hidden, float('-inf'))\n",
    "        weights = F.softmax(scores, dim=-1) # (B, n_head, T, T)\n",
    "\n",
    "        mixed = weights @ v # (B, n_head, T, head_size)\n",
    "        merged = rearrange(mixed, 'b hn t hs -> b t (hn hs)') # (B, T, n_embd)\n",
    "        return self.residual_proj(merged)\n",
    "\n",
    "class MLP(nn.Module):\n",
    "    \"\"\"Position-wise feed-forward net: n_embd -> 4 * n_embd -> n_embd with tanh-GELU.\"\"\"\n",
    "\n",
    "    def __init__(self, config):\n",
    "        super().__init__()\n",
    "        width = config.n_embd\n",
    "        hidden = 4 * width\n",
    "\n",
    "        self.first_layer = nn.Linear(width, hidden)\n",
    "        self.gelu = nn.GELU(approximate = 'tanh')\n",
    "        self.residual_proj = nn.Linear(hidden, width)\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Expand, apply GELU, then project back to the embedding width.\"\"\"\n",
    "        return self.residual_proj(self.gelu(self.first_layer(x)))\n",
    "\n",
    "\n",
    "class DecoderBlock(nn.Module):\n",
    "    \"\"\"Pre-norm transformer block: attention and MLP, each wrapped in a residual.\"\"\"\n",
    "\n",
    "    def __init__(self, config):\n",
    "        super().__init__()\n",
    "        width = config.n_embd\n",
    "        self.self_attention = SelfAttention(config)\n",
    "        self.layer_norm1 = nn.LayerNorm(width)\n",
    "        self.mlp = MLP(config)\n",
    "        self.layer_norm2 = nn.LayerNorm(width)\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Apply both sub-layers; each one normalises its input before use.\"\"\"\n",
    "        attended = self.self_attention(self.layer_norm1(x))\n",
    "        x = x + attended\n",
    "\n",
    "        transformed = self.mlp(self.layer_norm2(x))\n",
    "        return x + transformed\n",
    "\n",
    "\n",
    "\n",
    "class GPT(nn.Module):\n",
    "    \"\"\"Decoder-only transformer language model with tied token/output embeddings.\n",
    "\n",
    "    `config` must expose vocab_size, context_length, n_embd and n_layer, plus\n",
    "    whatever DecoderBlock reads from it.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, config):\n",
    "        super().__init__()\n",
    "        self.config = config\n",
    "\n",
    "        self.transformer = nn.ModuleDict(\n",
    "            dict(\n",
    "                emb_layer = nn.Embedding(config.vocab_size, config.n_embd),\n",
    "                pos_layer = nn.Embedding(config.context_length, config.n_embd),\n",
    "                decoder_blocks = nn.ModuleList([DecoderBlock(config) for _ in range(config.n_layer)]),\n",
    "                layer_norm = nn.LayerNorm(config.n_embd)\n",
    "            )\n",
    "        )\n",
    "\n",
    "        self.output_layer = nn.Linear(config.n_embd, config.vocab_size, bias = False)\n",
    "\n",
    "        # Weight tying: the token embedding reuses the output projection matrix.\n",
    "        self.transformer.emb_layer.weight = self.output_layer.weight\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Map token ids x of shape (B, T) to logits of shape (B, T, vocab_size).\"\"\"\n",
    "        B, T = x.shape\n",
    "        assert T <= self.config.context_length\n",
    "\n",
    "        # Positions are created on the input's device so forward() does not depend\n",
    "        # on a notebook-level `device` global matching where the tensors really are.\n",
    "        pos = torch.arange(T, dtype=torch.long, device=x.device) # (T)\n",
    "        pos_emb = self.transformer[\"pos_layer\"](pos) # (T, n_embd)\n",
    "        token_emb = self.transformer[\"emb_layer\"](x) # (B, T, n_embd)\n",
    "\n",
    "        x = token_emb + pos_emb # pos_emb broadcasts over the batch axis\n",
    "\n",
    "        for block in self.transformer[\"decoder_blocks\"]:\n",
    "            x = block(x)\n",
    "\n",
    "        x = self.transformer[\"layer_norm\"](x)\n",
    "        logits = self.output_layer(x)\n",
    "\n",
    "        return logits\n",
    "\n",
    "    def _last_step_logits(self, idx, temperature=1.0):\n",
    "        \"\"\"Return temperature-scaled logits (B, vocab_size) for the token following idx.\"\"\"\n",
    "        if temperature <= 0:\n",
    "            raise ValueError(f\"temperature must be > 0, got {temperature}\")\n",
    "        idx_cond = idx[:, -self.config.context_length:] # crop to the context window\n",
    "        logits = self(idx_cond)[:, -1, :]\n",
    "        return logits / temperature\n",
    "\n",
    "    @staticmethod\n",
    "    def _top_p_filter(logits, top_p):\n",
    "        \"\"\"Set to -inf every logit outside the smallest per-row set whose probability mass exceeds top_p.\"\"\"\n",
    "        sorted_logits, sorted_indices = torch.sort(logits, dim=-1, descending=True)\n",
    "        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)\n",
    "\n",
    "        remove_sorted = cumulative_probs > top_p\n",
    "        # Shift right by one so the token that crosses the threshold is kept;\n",
    "        # the most likely token therefore always survives.\n",
    "        remove_sorted[..., 1:] = remove_sorted[..., :-1].clone()\n",
    "        remove_sorted[..., 0] = False\n",
    "\n",
    "        # Map the per-row decisions from sorted order back to vocabulary order.\n",
    "        remove = torch.zeros_like(remove_sorted).scatter(-1, sorted_indices, remove_sorted)\n",
    "        return logits.masked_fill(remove, float('-inf'))\n",
    "\n",
    "    @torch.no_grad()\n",
    "    def generate(self, idx, max_new_tokens):\n",
    "        \"\"\"Extend idx (B, T) by max_new_tokens tokens sampled from the full softmax.\"\"\"\n",
    "        for _ in range(max_new_tokens):\n",
    "            logits = self._last_step_logits(idx) # (B, vocab_size)\n",
    "            probs = F.softmax(logits, dim=-1)\n",
    "            idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)\n",
    "            idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)\n",
    "        return idx\n",
    "\n",
    "    @torch.no_grad()\n",
    "    def generate_top(self, idx, max_new_tokens, top_k=10):\n",
    "        \"\"\"\n",
    "        Generates tokens one at a time, up to `max_new_tokens`,\n",
    "        using top-k sampling.\n",
    "\n",
    "        Arguments:\n",
    "        - idx: (B, T) tensor with the current context indices.\n",
    "        - max_new_tokens: number of new tokens to generate.\n",
    "        - top_k: how many top logits to keep for sampling; None disables the\n",
    "          filter, values above the vocabulary size are clamped to it.\n",
    "\n",
    "        Raises ValueError when top_k is given and smaller than 1.\n",
    "        \"\"\"\n",
    "        if top_k is not None and top_k < 1:\n",
    "            raise ValueError(f\"top_k must be >= 1 or None, got {top_k}\")\n",
    "\n",
    "        for _ in range(max_new_tokens):\n",
    "            logits = self._last_step_logits(idx) # (B, vocab_size)\n",
    "\n",
    "            if top_k is not None:\n",
    "                k = min(top_k, logits.size(-1))\n",
    "                values, _ = torch.topk(logits, k=k, dim=-1)\n",
    "                kth_value = values[:, -1].unsqueeze(-1) # (B, 1)\n",
    "                logits = logits.masked_fill(logits < kth_value, float('-inf'))\n",
    "\n",
    "            probs = F.softmax(logits, dim=-1) # (B, vocab_size)\n",
    "            idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)\n",
    "            idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)\n",
    "\n",
    "        return idx\n",
    "\n",
    "    @torch.no_grad()\n",
    "    def generate_nucleus(self, start_tokens, max_new_tokens, top_p = 0.95, temperature = 1.0):\n",
    "        \"\"\"\n",
    "        Generates tokens up to `max_new_tokens` using nucleus (top-p) sampling\n",
    "        with an adjustable temperature.\n",
    "\n",
    "        Args:\n",
    "            start_tokens (torch.LongTensor): (B, T) tensor of current context token indices.\n",
    "            max_new_tokens (int): Number of tokens to generate.\n",
    "            top_p (float): Probability threshold for nucleus sampling.\n",
    "            temperature (float): Temperature for scaling logits (> 0.0).\n",
    "\n",
    "        Returns:\n",
    "            (B, T + max_new_tokens) tensor: the context followed by the samples.\n",
    "\n",
    "        Raises:\n",
    "            ValueError: if temperature is not strictly positive.\n",
    "        \"\"\"\n",
    "        for _ in range(max_new_tokens):\n",
    "            logits = self._last_step_logits(start_tokens, temperature) # (B, vocab_size)\n",
    "            # The filter works row by row, so each sequence gets its own nucleus.\n",
    "            logits = self._top_p_filter(logits, top_p)\n",
    "            probs = F.softmax(logits, dim=-1)\n",
    "            next_tokens = torch.multinomial(probs, num_samples=1) # (B, 1)\n",
    "            start_tokens = torch.cat((start_tokens, next_tokens), dim=1) # (B, T+1)\n",
    "\n",
    "        return start_tokens\n",
    "\n",
    "    def generate_nucleus1(self, idx, max_new_tokens, top_p=0.95, temperature=1.0):\n",
    "        \"\"\"\n",
    "        Generates tokens up to `max_new_tokens` using nucleus (top-p) sampling\n",
    "        with an adjustable temperature. Kept as an alias of `generate_nucleus`.\n",
    "\n",
    "        Args:\n",
    "            idx (torch.LongTensor): (B, T) tensor of current context token indices.\n",
    "            max_new_tokens (int): Number of tokens to generate.\n",
    "            top_p (float): Probability threshold for nucleus sampling.\n",
    "            temperature (float): Temperature for scaling logits (> 0.0).\n",
    "                                - 1.0 is the 'base' temperature (no scaling).\n",
    "                                - < 1.0 makes the distribution sharper.\n",
    "                                - > 1.0 makes the distribution more flat/uniform.\n",
    "        \"\"\"\n",
    "        return self.generate_nucleus(idx, max_new_tokens, top_p=top_p, temperature=temperature)\n",
    "\n",
    "\n",
    "def load_tokens(filename):\n",
    "    \"\"\"Read a numpy token file and return its contents as a torch.long tensor.\"\"\"\n",
    "    # The array is cast to int32 first and widened to int64 by torch.tensor.\n",
    "    token_array = np.load(filename).astype(np.int32)\n",
    "    return torch.tensor(token_array, dtype=torch.long)\n",
    "\n",
    "class DataLoaderLite:\n",
    "    \"\"\"Sequential batcher over the token shards of one split stored in `data_root`.\n",
    "\n",
    "    Every call to next_batch() yields (x, y) of shape (B, T), where y is x shifted\n",
    "    by one token. Each process starts at its own offset and strides over the\n",
    "    chunks taken by the other processes.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, B, T, process_rank, num_processes, split,\n",
    "                 data_root=\"edu_fineweb10B\", verbose=None):\n",
    "        self.B = B\n",
    "        self.T = T\n",
    "        self.process_rank = process_rank\n",
    "        self.num_processes = num_processes\n",
    "        assert split in {'train', 'val'}\n",
    "\n",
    "        # Unless told otherwise, only rank 0 prints.\n",
    "        if verbose is None:\n",
    "            self.verbose = (process_rank == 0)\n",
    "        else:\n",
    "            self.verbose = bool(verbose)\n",
    "\n",
    "        # A file belongs to the split when the split name occurs in its file name.\n",
    "        shards = [\n",
    "            os.path.join(data_root, name)\n",
    "            for name in sorted(os.listdir(data_root))\n",
    "            if split in name\n",
    "        ]\n",
    "        assert len(shards) > 0, f\"no shards found for split {split} in {data_root}\"\n",
    "        self.shards = shards\n",
    "        if self.verbose:\n",
    "            print(f\"found {len(shards)} shards for split {split}\")\n",
    "        self.reset()\n",
    "\n",
    "    def reset(self):\n",
    "        \"\"\"Reload the first shard and rewind to this process's starting offset.\"\"\"\n",
    "        self.current_shard = 0\n",
    "        self.tokens = load_tokens(self.shards[self.current_shard])\n",
    "        self.current_position = self.B * self.T * self.process_rank\n",
    "\n",
    "    def next_batch(self):\n",
    "        \"\"\"Return the next (x, y) pair and advance, moving to the next shard when needed.\"\"\"\n",
    "        B, T = self.B, self.T\n",
    "        start = self.current_position\n",
    "        window = self.tokens[start : start + B*T + 1]  # one extra token for the shifted targets\n",
    "        x = window[:-1].view(B, T)\n",
    "        y = window[1:].view(B, T)\n",
    "\n",
    "        stride = B * T * self.num_processes\n",
    "        self.current_position += stride\n",
    "        # Not enough tokens left for another full step: go to the next shard (cyclically).\n",
    "        if self.current_position + (stride + 1) > len(self.tokens):\n",
    "            self.current_shard = (self.current_shard + 1) % len(self.shards)\n",
    "            self.tokens = load_tokens(self.shards[self.current_shard])\n",
    "            self.current_position = B * T * self.process_rank\n",
    "        return x, y\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "ExqIH5NpfUCv"
   },
   "outputs": [],
   "source": [
    "@torch.no_grad()\n",
    "def estimate_loss(data_loader, run_on_whole = False):\n",
    "    \"\"\"Average cross-entropy of the global `model` over batches drawn from data_loader.\n",
    "\n",
    "    Args:\n",
    "        data_loader: object exposing B, T, tokens and next_batch() -> (X, Y).\n",
    "        run_on_whole: if True, use len(tokens) // (B * T) batches of the shard\n",
    "            currently loaded; otherwise config_hyperparams.eval_iters batches.\n",
    "\n",
    "    Returns:\n",
    "        0-dim tensor holding the mean loss over the evaluated batches.\n",
    "\n",
    "    The model is put in eval mode for the measurement and afterwards returned to\n",
    "    whatever mode it was in before the call, also when a batch raises.\n",
    "    \"\"\"\n",
    "    if run_on_whole:\n",
    "        iters = int(len(data_loader.tokens) // (data_loader.B * data_loader.T))\n",
    "    else:\n",
    "        iters = config_hyperparams.eval_iters\n",
    "    losses = torch.zeros(iters)\n",
    "\n",
    "    was_training = model.training\n",
    "    model.eval()\n",
    "    try:\n",
    "        for k in range(iters):\n",
    "            X, Y = data_loader.next_batch()\n",
    "            X, Y = X.to(device), Y.to(device)\n",
    "            logits = model(X)\n",
    "            loss = F.cross_entropy(logits.view(-1, logits.shape[-1]), Y.view(-1))\n",
    "            losses[k] = loss.item()\n",
    "    finally:\n",
    "        # Restore the caller's mode instead of unconditionally switching to train.\n",
    "        model.train(was_training)\n",
    "    return losses.mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Default architecture and optimisation settings used by the cells below.\n",
    "config_model, config_hyperparams = ModelConfig(), HyperparametersConfig()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Re-creates the same two default configs as the previous cell (order swapped).\n",
    "config_hyperparams, config_model = HyperparametersConfig(), ModelConfig()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "MccMWdPmfEwL",
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Single-process setup: rank 0 of 1.\n",
    "process_rank, num_processes = 0, 1\n",
    "B = config_model.batch_size\n",
    "T = config_model.context_length\n",
    "\n",
    "# Both loaders share every argument except the split they read.\n",
    "loader_kwargs = dict(B=B, T=T, process_rank=process_rank, num_processes=num_processes)\n",
    "train_loader = DataLoaderLite(split=\"train\", **loader_kwargs)\n",
    "val_loader = DataLoaderLite(split=\"val\", **loader_kwargs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "ANWgE9GPdzNJ"
   },
   "outputs": [],
   "source": [
    "# Build the network and move its parameters and buffers to the selected device.\n",
    "model = GPT(config_model).to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "sF2UnmTZd2X4",
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "optimizer = torch.optim.AdamW(model.parameters(), lr=config_hyperparams.learning_rate)\n",
    "\n",
    "# Plain training loop: one optimizer step per batch, with a validation pass every\n",
    "# `eval_interval` steps and on the final step.\n",
    "# The loop variable is `step`, not `iter`, so the builtin iter() is not rebound\n",
    "# for the rest of the kernel session.\n",
    "progress = tqdm(range(config_hyperparams.max_iters))\n",
    "for step in progress:\n",
    "\n",
    "    if step % config_hyperparams.eval_interval == 0 or step == config_hyperparams.max_iters - 1:\n",
    "        val_loss = estimate_loss(val_loader, run_on_whole=True)\n",
    "        # tqdm.write prints above the bar instead of breaking it\n",
    "        tqdm.write(f\"step {step}: val loss {val_loss:.4f}\")\n",
    "\n",
    "    xb, yb = train_loader.next_batch()\n",
    "    xb, yb = xb.to(device), yb.to(device)\n",
    "\n",
    "    logits = model(xb)\n",
    "    loss = F.cross_entropy(logits.view(-1, logits.shape[-1]), yb.view(-1))\n",
    "    optimizer.zero_grad(set_to_none=True)\n",
    "    loss.backward()\n",
    "    optimizer.step()\n",
    "    # Show the scalar loss on the bar rather than printing a tensor repr every step.\n",
    "    progress.set_postfix(loss=f\"{loss.item():.4f}\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "optimizer = torch.optim.AdamW(model.parameters(), lr=config_hyperparams.learning_rate)\n",
    "\n",
    "beta = 0.99           # EMA smoothing factor (~100-step effective window)\n",
    "ema = None            # smoothed training loss; None until the first step\n",
    "last_val = None       # latest validation loss; None until the first evaluation\n",
    "\n",
    "pbar = tqdm(range(config_hyperparams.max_iters), dynamic_ncols=True, desc=\"train\")\n",
    "\n",
    "for it in pbar:\n",
    "    xb, yb = train_loader.next_batch()\n",
    "    xb, yb = xb.to(device), yb.to(device)\n",
    "\n",
    "    logits = model(xb)\n",
    "    loss = F.cross_entropy(logits.view(-1, logits.shape[-1]), yb.view(-1))\n",
    "    optimizer.zero_grad(set_to_none=True)\n",
    "    loss.backward()\n",
    "    optimizer.step()\n",
    "\n",
    "    train_loss = float(loss.item())\n",
    "    # Exponential moving average of the training loss, seeded with the first value.\n",
    "    ema = train_loss if ema is None else beta * ema + (1 - beta) * train_loss\n",
    "\n",
    "    is_last_step = it == (config_hyperparams.max_iters - 1)\n",
    "    if (it % config_hyperparams.eval_interval == 0 and it > 0) or is_last_step:\n",
    "        # estimate_loss runs under no_grad and handles eval/train switching itself.\n",
    "        v = estimate_loss(val_loader, run_on_whole=True)\n",
    "        # Accept a scalar or a dict of losses without calling the builtin iter(),\n",
    "        # which notebook state may have rebound to something else.\n",
    "        if isinstance(v, dict):\n",
    "            v = v[\"val\"] if \"val\" in v else list(v.values())[0]\n",
    "        last_val = float(v)\n",
    "\n",
    "    pbar.set_postfix(\n",
    "        loss=f\"{train_loss:.4f}\",\n",
    "        ema=f\"{ema:.4f}\",\n",
    "        val=(\"--\" if last_val is None else f\"{last_val:.4f}\"),\n",
    "    )\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Write the checkpoint to the path the loading cell reads it from, so the\n",
    "# notebook works from a fresh kernel; create the directory if it is missing.\n",
    "save_path = \"data/fineweb10B-onelayer_transformer.pt\"\n",
    "os.makedirs(os.path.dirname(save_path), exist_ok=True)\n",
    "torch.save(model.state_dict(), save_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "save_path = \"data/fineweb10B-onelayer_transformer.pt\"\n",
    "\n",
    "# Deserialize onto the CPU; load_state_dict then copies into the model's own tensors.\n",
    "state_dict = torch.load(save_path, map_location=torch.device('cpu'))\n",
    "model.load_state_dict(state_dict)\n",
    "model.eval()  # inference mode for the cells that follow\n",
    "\n",
    "# Count all parameters, and separately those that receive gradients.\n",
    "total_params = 0\n",
    "trainable_params = 0\n",
    "for p in model.parameters():\n",
    "    total_params += p.numel()\n",
    "    if p.requires_grad:\n",
    "        trainable_params += p.numel()\n",
    "\n",
    "print(f\"Total params: {total_params:,}\")\n",
    "print(f\"Trainable params: {trainable_params:,}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "enc = tiktoken.get_encoding(\"gpt2\")\n",
    "text = \"life is \"\n",
    "\n",
    "# Encode the prompt as a (1, T) batch on the target device.\n",
    "prompt_ids = enc.encode(text)\n",
    "context = torch.tensor(prompt_ids).unsqueeze(0).to(device)\n",
    "\n",
    "# Sample 100 more tokens with nucleus sampling and decode the single sequence.\n",
    "generated = model.generate_nucleus1(context, max_new_tokens=100)\n",
    "print(enc.decode(generated[0].tolist()))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "ml-torch",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}