{
  "cells": [
    {
      "cell_type": "markdown",
      "source": [
        "# README\n",
        "\n",
        "This notebook demonstrates abrupt learning for the MWS task discussed in the main paper. We log loss (`train_loss`), accuracy (`train_acc`), partial solution (`idx0_acc`), repetition frequency (`model_repeat_frac`) and cosine similarity between hidden states (`mean_cosine_sim`). We also log attention progress measure (`att_prog_measure`) and attention maps.\n",
        "\n",
        "For convenience, we log these metrics / plots using Wandb. Please add your Wandb login key and project name in the last code cell of this notebook."
      ],
      "metadata": {
        "id": "p4oc5PtIhdbm"
      }
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "YiF5Vq1LGhEw"
      },
      "outputs": [],
      "source": [
        "!pip install dotmap\n",
        "\n",
        "import math\n",
        "import random\n",
        "import yaml\n",
        "import argparse\n",
        "from dotmap import DotMap\n",
        "\n",
        "import numpy as np\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "from torch.nn import functional as F\n",
        "from torch.optim import Adam\n",
        "from torch.nn.functional import cosine_similarity\n",
        "\n",
        "import matplotlib.pyplot as plt\n",
        "import wandb"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 7,
      "metadata": {
        "id": "pNvJqb60GnE3"
      },
      "outputs": [],
      "source": [
        "# Model definition\n",
        "\"\"\"\n",
        "(Modified version of) Andrej Karpathy's minGPT implementation (https://github.com/karpathy/minGPT/blob/master/mingpt/model.py)\n",
        "\n",
        "Full definition of a GPT Language Model, all of it in this single file.\n",
        "\n",
        "References:\n",
        "1) the official GPT-2 TensorFlow implementation released by OpenAI:\n",
        "https://github.com/openai/gpt-2/blob/master/src/model.py\n",
        "2) huggingface/transformers PyTorch implementation:\n",
        "https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt2/modeling_gpt2.py\n",
        "\"\"\"\n",
        "\n",
        "# from mingpt.utils import CfgNode as CN\n",
        "\n",
        "# -----------------------------------------------------------------------------\n",
        "\n",
        "\n",
        "class NewGELU(nn.Module):\n",
        "    \"\"\"\n",
        "    Implementation of the GELU activation function currently in Google BERT repo (identical to OpenAI GPT).\n",
        "    Reference: Gaussian Error Linear Units (GELU) paper: https://arxiv.org/abs/1606.08415\n",
        "    \"\"\"\n",
        "\n",
        "    def forward(self, x):\n",
        "        return (\n",
        "            0.5\n",
        "            * x\n",
        "            * (\n",
        "                1.0\n",
        "                + torch.tanh(\n",
        "                    math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))\n",
        "                )\n",
        "            )\n",
        "        )\n",
        "\n",
        "\n",
        "class CausalSelfAttention(nn.Module):\n",
        "    \"\"\"\n",
        "    A vanilla multi-head masked self-attention layer with a projection at the end.\n",
        "    It is possible to use torch.nn.MultiheadAttention here but I am including an\n",
        "    explicit implementation here to show that there is nothing too scary here.\n",
        "    \"\"\"\n",
        "\n",
        "    def __init__(self, config, return_att=False):\n",
        "        super().__init__()\n",
        "        assert config.n_embd % config.n_head == 0\n",
        "\n",
        "        # key, query, value projections for all heads\n",
        "        self.k = nn.Linear(config.n_embd, config.n_embd)\n",
        "        self.q = nn.Linear(config.n_embd, config.n_embd)\n",
        "        self.v = nn.Linear(config.n_embd, config.n_embd)\n",
        "\n",
        "        # output projection\n",
        "        self.c_proj = nn.Linear(config.n_embd, config.n_embd)\n",
        "\n",
        "        # causal mask to ensure that attention is only applied to the left in the input sequence\n",
        "        self.register_buffer(\n",
        "            \"bias\",\n",
        "            torch.tril(torch.ones(config.block_size, config.block_size)).view(\n",
        "                1, 1, config.block_size, config.block_size\n",
        "            ),\n",
        "        )\n",
        "\n",
        "        self.n_head = config.n_head\n",
        "        self.n_embd = config.n_embd\n",
        "        self.return_att = return_att\n",
        "\n",
        "    def forward(self, x):\n",
        "        B, T, C = (\n",
        "            x.size()\n",
        "        )  # batch size, sequence length, embedding dimensionality (n_embd)\n",
        "\n",
        "        # calculate query, key, values for all heads\n",
        "        q = self.q(x)\n",
        "        k = self.k(x)\n",
        "        v = self.v(x)\n",
        "\n",
        "        k = k.view(B, T, self.n_head, C // self.n_head).transpose(\n",
        "            1, 2\n",
        "        )  # (B, nh, T, hs)\n",
        "        q = q.view(B, T, self.n_head, C // self.n_head).transpose(\n",
        "            1, 2\n",
        "        )  # (B, nh, T, hs)\n",
        "        v = v.view(B, T, self.n_head, C // self.n_head).transpose(\n",
        "            1, 2\n",
        "        )  # (B, nh, T, hs)\n",
        "\n",
        "        # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)\n",
        "        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))\n",
        "        att = att.masked_fill(self.bias[:, :, :T, :T] == 0, 0.0)\n",
        "\n",
        "        att_copy = att.clone().detach()\n",
        "\n",
        "        y = att @ v  # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)\n",
        "        y = (\n",
        "            y.transpose(1, 2).contiguous().view(B, T, C)\n",
        "        )  # re-assemble all head outputs side by side\n",
        "\n",
        "        # output projection\n",
        "        y = self.c_proj(y)\n",
        "\n",
        "        if self.return_att:\n",
        "            return y, att_copy\n",
        "\n",
        "        return y\n",
        "\n",
        "\n",
        "class Block(nn.Module):\n",
        "    \"\"\"an unassuming Transformer block\"\"\"\n",
        "\n",
        "    def __init__(self, config, return_att=False):\n",
        "        super().__init__()\n",
        "        self.ln_1 = nn.LayerNorm(config.n_embd)\n",
        "        self.attn = CausalSelfAttention(config, return_att=return_att)\n",
        "        self.ln_2 = nn.LayerNorm(config.n_embd)\n",
        "        self.mlp = nn.ModuleDict(\n",
        "            dict(\n",
        "                c_fc=nn.Linear(config.n_embd, 4 * config.n_embd),\n",
        "                c_proj=nn.Linear(4 * config.n_embd, config.n_embd),\n",
        "                act=NewGELU(),\n",
        "                # dropout=nn.Dropout(config.resid_pdrop),\n",
        "            )\n",
        "        )\n",
        "\n",
        "        m = self.mlp\n",
        "        self.mlpf = lambda x: m.c_proj(m.act(m.c_fc(x)))  # MLP forward\n",
        "        self.return_att = return_att\n",
        "\n",
        "    def forward(self, x):\n",
        "        if self.return_att:\n",
        "            x_prev, att = self.attn(self.ln_1(x))\n",
        "            x = x + x_prev\n",
        "\n",
        "            x = x + self.mlpf(self.ln_2(x))\n",
        "\n",
        "            return x, att\n",
        "\n",
        "        else:\n",
        "            x = x + self.attn(self.ln_1(x))\n",
        "            x = x + self.mlpf(self.ln_2(x))\n",
        "            return x\n",
        "\n",
        "\n",
        "class GPTLinear(nn.Module):\n",
        "    \"\"\"GPT Language Model\"\"\"\n",
        "\n",
        "    def __init__(self, config, return_att=False):\n",
        "        super().__init__()\n",
        "        assert config.vocab_size is not None\n",
        "        assert config.block_size is not None\n",
        "        self.block_size = config.block_size\n",
        "        self.return_att = return_att\n",
        "\n",
        "        self.transformer = nn.ModuleDict(\n",
        "            dict(\n",
        "                wte=nn.Embedding(config.vocab_size, config.n_embd),\n",
        "                wpe=nn.Embedding(config.block_size, config.n_embd),\n",
        "                h=nn.ModuleList(\n",
        "                    [\n",
        "                        Block(config, return_att=self.return_att)\n",
        "                        for _ in range(config.n_layer)\n",
        "                    ]\n",
        "                ),\n",
        "                ln_f=nn.LayerNorm(config.n_embd),\n",
        "            )\n",
        "        )\n",
        "        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)\n",
        "\n",
        "        # init all weights, and apply a special scaled init to the residual projections, per GPT-2 paper\n",
        "        self.apply(self._init_weights)\n",
        "        for pn, p in self.named_parameters():\n",
        "            if pn.endswith(\"c_proj.weight\"):\n",
        "                torch.nn.init.normal_(\n",
        "                    p, mean=0.0, std=0.02 / math.sqrt(2 * config.n_layer)\n",
        "                )\n",
        "\n",
        "        # report number of parameters (note we don't count the decoder parameters in lm_head)\n",
        "        n_params = sum(p.numel() for p in self.transformer.parameters())\n",
        "        # print(\"number of parameters: %.2fM\" % (n_params / 1e6,))\n",
        "\n",
        "    def _init_weights(self, module):\n",
        "        if isinstance(module, nn.Linear):\n",
        "            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)\n",
        "            if module.bias is not None:\n",
        "                torch.nn.init.zeros_(module.bias)\n",
        "        elif isinstance(module, nn.Embedding):\n",
        "            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)\n",
        "        elif isinstance(module, nn.LayerNorm):\n",
        "            torch.nn.init.zeros_(module.bias)\n",
        "            torch.nn.init.ones_(module.weight)\n",
        "\n",
        "\n",
        "    # Only used for weight decay experiments -------------------------------------------\n",
        "    def configure_optimizers(self, train_config):\n",
        "        \"\"\"\n",
        "        This long function is unfortunately doing something very simple and is being very defensive:\n",
        "        We are separating out all parameters of the model into two buckets: those that will experience\n",
        "        weight decay for regularization and those that won't (biases, and layernorm/embedding weights).\n",
        "        We are then returning the PyTorch optimizer object.\n",
        "        \"\"\"\n",
        "\n",
        "        # separate out all parameters to those that will and won't experience regularizing weight decay\n",
        "        decay = set()\n",
        "        no_decay = set()\n",
        "        whitelist_weight_modules = (torch.nn.Linear, )\n",
        "        blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding)\n",
        "        for mn, m in self.named_modules():\n",
        "            for pn, p in m.named_parameters():\n",
        "                fpn = '%s.%s' % (mn, pn) if mn else pn # full param name\n",
        "                # random note: because named_modules and named_parameters are recursive\n",
        "                # we will see the same tensors p many many times. but doing it this way\n",
        "                # allows us to know which parent module any tensor p belongs to...\n",
        "                if pn.endswith('bias'):\n",
        "                    # all biases will not be decayed\n",
        "                    no_decay.add(fpn)\n",
        "                elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules):\n",
        "                    # weights of whitelist modules will be weight decayed\n",
        "                    decay.add(fpn)\n",
        "                elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules):\n",
        "                    # weights of blacklist modules will NOT be weight decayed\n",
        "                    no_decay.add(fpn)\n",
        "\n",
        "        # validate that we considered every parameter\n",
        "        param_dict = {pn: p for pn, p in self.named_parameters()}\n",
        "        inter_params = decay & no_decay\n",
        "        union_params = decay | no_decay\n",
        "        assert len(inter_params) == 0, \"parameters %s made it into both decay/no_decay sets!\" % (str(inter_params), )\n",
        "        assert len(param_dict.keys() - union_params) == 0, \"parameters %s were not separated into either decay/no_decay set!\" \\\n",
        "                                                    % (str(param_dict.keys() - union_params), )\n",
        "\n",
        "        # create the pytorch optimizer object\n",
        "        optim_groups = [\n",
        "            {\"params\": [param_dict[pn] for pn in sorted(list(decay))], \"weight_decay\": train_config.wd},\n",
        "            {\"params\": [param_dict[pn] for pn in sorted(list(no_decay))], \"weight_decay\": 0.0},\n",
        "        ]\n",
        "        optimizer = torch.optim.AdamW(optim_groups, lr=train_config.lr)\n",
        "        return optimizer\n",
        "    # -------------------------------------------\n",
        "\n",
        "\n",
        "    def forward(self, idx, targets=None):\n",
        "        device = idx.device\n",
        "        b, t = idx.size()\n",
        "        assert (\n",
        "            t <= self.block_size\n",
        "        ), f\"Cannot forward sequence of length {t}, block size is only {self.block_size}\"\n",
        "        pos = torch.arange(0, t, dtype=torch.long, device=device).unsqueeze(\n",
        "            0\n",
        "        )  # shape (1, t)\n",
        "\n",
        "        # forward the GPT model itself\n",
        "        tok_emb = self.transformer.wte(idx)  # token embeddings of shape (b, t, n_embd)\n",
        "        pos_emb = self.transformer.wpe(pos)  # position embeddings of shape (1, t, n_embd)\n",
        "        x = tok_emb + pos_emb\n",
        "\n",
        "        for block in self.transformer.h:\n",
        "            if self.return_att:\n",
        "                x, attn_map = block(x)\n",
        "            else:\n",
        "                x = block(x)\n",
        "\n",
        "\n",
        "        x = self.transformer.ln_f(x)\n",
        "\n",
        "        # Track residual state before LM head for representation collapse\n",
        "        pre_lm_h = x.clone().detach()\n",
        "\n",
        "        # Final logits\n",
        "        logits = self.lm_head(x)\n",
        "\n",
        "        # if we are given some desired targets also calculate the loss\n",
        "        loss = None\n",
        "        if targets is not None:\n",
        "            loss = F.cross_entropy(\n",
        "                logits.reshape(-1, logits.size(-1)),\n",
        "                targets.reshape(-1),\n",
        "                ignore_index=-1,\n",
        "            )\n",
        "\n",
        "        if self.return_att:\n",
        "            return attn_map, pre_lm_h, logits, loss\n",
        "\n",
        "        return pre_lm_h, logits, loss\n",
        "\n",
        "    @torch.no_grad()\n",
        "    def generate(\n",
        "        self, idx, max_new_tokens, temperature=1.0, do_sample=False, top_k=None\n",
        "    ):\n",
        "        \"\"\"\n",
        "        Take a conditioning sequence of indices idx (LongTensor of shape (b,t)) and complete\n",
        "        the sequence max_new_tokens times, feeding the predictions back into the model each time.\n",
        "        Most likely you'll want to make sure to be in model.eval() mode of operation for this.\n",
        "        \"\"\"\n",
        "\n",
        "        for _ in range(max_new_tokens):\n",
        "            # if the sequence context is growing too long we must crop it at block_size\n",
        "            idx_cond = (\n",
        "                idx if idx.size(1) <= self.block_size else idx[:, -self.block_size :]\n",
        "            )\n",
        "            # forward the model to get the logits for the index in the sequence\n",
        "            _, _, logits, _ = self(idx_cond)\n",
        "            # pluck the logits at the final step and scale by desired temperature\n",
        "            logits = logits[:, -1, :] / temperature\n",
        "            # optionally crop the logits to only the top k options\n",
        "            if top_k is not None:\n",
        "                v, _ = torch.topk(logits, top_k)\n",
        "                logits[logits < v[:, [-1]]] = -float(\"Inf\")\n",
        "            # apply softmax to convert logits to (normalized) probabilities\n",
        "            probs = F.softmax(logits, dim=-1)\n",
        "            # either sample from the distribution or take the most likely element\n",
        "            if do_sample:\n",
        "                idx_next = torch.multinomial(probs, num_samples=1)\n",
        "            else:\n",
        "                _, idx_next = torch.topk(probs, k=1, dim=-1)\n",
        "            # append sampled index to the running sequence and continue\n",
        "            idx = torch.cat((idx, idx_next), dim=1)\n",
        "\n",
        "        return idx\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 8,
      "metadata": {
        "id": "Q4VrcEUrGndW"
      },
      "outputs": [],
      "source": [
        "# Data\n",
        "\n",
        "class MovingWindowSum:\n",
        "    def __init__(self, min_num=1, max_num=16, k=2, p=17, sep=17, device=\"cuda\"):\n",
        "        self.min_num = min_num\n",
        "        self.max_num = max_num\n",
        "        self.k = k\n",
        "        self.p = p\n",
        "        self.sep = sep\n",
        "        self.device = device\n",
        "        assert self.p > self.max_num\n",
        "\n",
        "    @torch.no_grad()\n",
        "    def sample(\n",
        "        self,\n",
        "        num_samples,\n",
        "        num_tokens,\n",
        "    ):\n",
        "        random_ints = torch.randint(\n",
        "            low=self.min_num, high=self.max_num + 1, size=(num_samples, num_tokens)\n",
        "        ).to(self.device)\n",
        "\n",
        "        random_ints_np = random_ints.detach().cpu().numpy()\n",
        "        convolution = torch.stack(\n",
        "            [\n",
        "                torch.from_numpy(\n",
        "                    np.convolve(\n",
        "                        random_ints_np[i],\n",
        "                        np.ones(self.k),\n",
        "                        mode=\"valid\",\n",
        "                    )\n",
        "                )\n",
        "                for i in range(random_ints.shape[0])\n",
        "            ]\n",
        "        )\n",
        "\n",
        "        moving_sum = random_ints.clone().detach()\n",
        "        moving_sum[:, self.k - 1 :] = convolution\n",
        "\n",
        "        # for i in range(num_samples):\n",
        "        #     for j in range(0, self.k - 1):\n",
        "        #         if moving_sum[i, j] != random_ints[i, j]:\n",
        "        #             print(f\"ERROR! {i} {j}\")\n",
        "        #     for j in range(self.k - 1, num_tokens):\n",
        "        #         if moving_sum[i, j] != torch.sum(random_ints[i, j-self.k+1:j+1]):\n",
        "        #             print(f\"ERROR! {i} {j}\")\n",
        "\n",
        "        # exit()\n",
        "        samples = (\n",
        "            torch.cat(\n",
        "                [\n",
        "                    random_ints,\n",
        "                    self.sep * torch.ones(size=(num_samples, 1)).to(self.device),\n",
        "                    torch.remainder(input=moving_sum, other=self.p),\n",
        "                ],\n",
        "                axis=-1,\n",
        "            )\n",
        "            .to(int)\n",
        "            .detach()\n",
        "        )\n",
        "\n",
        "        return samples"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 9,
      "metadata": {
        "id": "GHF1-ylNGouE"
      },
      "outputs": [],
      "source": [
        "# Config\n",
        "\n",
        "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
        "\n",
        "config = {\n",
        "'model':\n",
        "  {\n",
        "    'n_layer': 1,\n",
        "    'n_head': 1,\n",
        "    'n_embd': 256,\n",
        "    'linear': True,\n",
        "  },\n",
        "\n",
        "'data':\n",
        "  {\n",
        "    'name': 'window',\n",
        "    'min_num': 1,\n",
        "    'max_num': 16,\n",
        "    'k': 2,\n",
        "    'p': 17,\n",
        "    'sep': 17,\n",
        "    'cot': False,\n",
        "    'num_tokens': 16,\n",
        "    'n_train': 256,\n",
        "    'n_test': 64,\n",
        "    'fixed_len': True,\n",
        "  },\n",
        "\n",
        "'train':\n",
        "  {\n",
        "    'lr': 0.0001,\n",
        "    'grad_clip': -1,\n",
        "    'num_steps': 400,\n",
        "    'norm_type': \"none_rank\",\n",
        "    'wandb': True,\n",
        "    'save_ckpt': False,\n",
        "    'ckpt_freq': 20,\n",
        "  }\n",
        "}"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "5HCPTLdZGqC2"
      },
      "outputs": [],
      "source": [
        "# Train\n",
        "\n",
        "def train_step(\n",
        "    model,\n",
        "    optim,\n",
        "    data_sampler,\n",
        "    step,\n",
        "    config,\n",
        "):\n",
        "    n_train, n_test, num_tokens = (\n",
        "        config.data.n_train,\n",
        "        config.data.n_test,\n",
        "        config.data.num_tokens,\n",
        "    )\n",
        "\n",
        "    data = data_sampler.sample(\n",
        "        num_samples=n_train + n_test,\n",
        "        num_tokens=num_tokens,\n",
        "    )\n",
        "\n",
        "    train_data = data[:n_train, :]\n",
        "    test_data = data[n_train:, :]\n",
        "\n",
        "    prompt_len = num_tokens + 1\n",
        "    gen_len = num_tokens\n",
        "    acc_start = num_tokens + 1\n",
        "\n",
        "    model.train()\n",
        "    optim.zero_grad(set_to_none=True)\n",
        "\n",
        "    _, _, _, loss = model(\n",
        "        train_data[:, :-1], targets=train_data[:, 1:]\n",
        "    )\n",
        "    loss.backward()\n",
        "\n",
        "    if config.train.grad_clip > 0:\n",
        "        torch.nn.utils.clip_grad_norm_(model.parameters(), config.train.grad_clip)\n",
        "\n",
        "    optim.step()\n",
        "\n",
        "    model.eval()\n",
        "    with torch.no_grad():\n",
        "        # Log train loss, train / test acc, repetition frequency\n",
        "        attn_map, pre_lm_h, _, train_loss = model(train_data[:, :-1], targets=train_data[:, 1:])\n",
        "\n",
        "        train_pred = model.generate(\n",
        "            idx=train_data[:, :prompt_len],\n",
        "            max_new_tokens=gen_len,\n",
        "        )\n",
        "        test_pred = model.generate(\n",
        "            idx=test_data[:, :prompt_len],\n",
        "            max_new_tokens=gen_len,\n",
        "        )\n",
        "\n",
        "        train_acc = torch.mean(\n",
        "            (train_pred[:, acc_start:] == train_data[:, acc_start:]).to(float)\n",
        "        ).item()\n",
        "        test_acc = torch.mean(\n",
        "            (test_pred[:, acc_start:] == test_data[:, acc_start:]).to(float)\n",
        "        ).item()\n",
        "\n",
        "        data_repeat_frac = torch.mean((test_data[:, acc_start:-1] == test_data[:, acc_start+1:]).to(float))\n",
        "        model_repeat_frac = torch.mean((test_pred[:, acc_start:-1] == test_pred[:, acc_start+1:]).to(float))\n",
        "\n",
        "        # Log attention progress measure\n",
        "        attn_map_output_seq = attn_map[:, :, acc_start-1:]\n",
        "        att_mask = torch.zeros_like(attn_map_output_seq).to(device)\n",
        "\n",
        "        att_mask[:, :, 0, 0] = 1\n",
        "        for i in range(num_tokens - 1):\n",
        "            att_mask[:, :, i + 1, i : i + 2] = 1\n",
        "\n",
        "        att_prog_measure = torch.mean(\n",
        "            torch.sum(torch.abs(attn_map_output_seq) * att_mask, dim=(-3, -2, -1)) /\n",
        "            torch.sum(torch.abs(attn_map_output_seq), dim=(-3, -2, -1)),\n",
        "            dim=0\n",
        "        )\n",
        "\n",
        "        # Log pair-wise cosine similarity between hidden states\n",
        "        embed_start = acc_start - 1\n",
        "        embed_len = gen_len\n",
        "\n",
        "        logit_cs = torch.zeros((embed_len, embed_len))\n",
        "\n",
        "        for i_1 in range(embed_start, embed_start + embed_len):\n",
        "            for i_2 in range(embed_start, i_1):\n",
        "                logit_cs[i_1 - embed_start, i_2 - embed_start] = torch.mean(\n",
        "                    (\n",
        "                        cosine_similarity(\n",
        "                            pre_lm_h[:, i_1, :], pre_lm_h[:, i_2, :], dim=-1\n",
        "                        )\n",
        "                    ), dim=0\n",
        "                )\n",
        "\n",
        "        # Log plots for cosine similarity, attention map\n",
        "        logit_fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(30, 15))\n",
        "\n",
        "        im1 = ax[0].imshow(logit_cs)\n",
        "        ax[0].set_title(\"avg pre_lm_h cosine sim\")\n",
        "        cb1 = logit_fig.colorbar(im1, location=\"right\", shrink=0.99, pad=0.02, ax=ax[0])\n",
        "\n",
        "        avg_attn_map = torch.mean(attn_map, dim=0).squeeze().detach().cpu().numpy()\n",
        "\n",
        "        im2 = ax[1].imshow(avg_attn_map)\n",
        "        ax[1].set_title(\"att map\")\n",
        "        cb4 = logit_fig.colorbar(im2, location=\"right\", shrink=0.99, pad=0.02, ax=ax[1])\n",
        "        ax[1].set_xticks(range(avg_attn_map.shape[-1]))\n",
        "        ax[1].set_yticks(range(avg_attn_map.shape[-2]))\n",
        "\n",
        "        for i1 in range(embed_len):\n",
        "            for i2 in range(embed_len):\n",
        "                text1 = ax[0].text(\n",
        "                    i2,\n",
        "                    i1,\n",
        "                    round(logit_cs[i1, i2].item(), 2),\n",
        "                    ha=\"center\",\n",
        "                    va=\"center\",\n",
        "                    color=\"w\",\n",
        "                )\n",
        "\n",
        "\n",
        "        print(\n",
        "            f\"Step {step} -- Train loss: {train_loss}, Train Acc: {train_acc} Test Acc: {test_acc}\"\n",
        "        )\n",
        "        # print(f\"input: {test_data[0]} \\n predicted:{test_pred[0]}\")\n",
        "\n",
        "        if config.train.wandb:\n",
        "\n",
        "            log_data = {\n",
        "                \"train_loss\": train_loss,\n",
        "                \"train_acc\": train_acc,\n",
        "                \"test_acc\": test_acc,\n",
        "                \"data_repeat_frac\": data_repeat_frac,\n",
        "                \"model_repeat_frac\": model_repeat_frac,\n",
        "                \"att_prog_measure\": att_prog_measure,\n",
        "                \"cosine_sim_attn_map_fig\": logit_fig,\n",
        "                \"mean_cosine_sim\": torch.sum(logit_cs[:, 1:]) / (0.5 * (gen_len-1) * (gen_len-2))\n",
        "            }\n",
        "\n",
        "            for output_pos in range(gen_len):\n",
        "                log_data.update(\n",
        "                    {\n",
        "                        f\"idx{output_pos}_acc\": torch.mean(\n",
        "                            (train_pred[:, acc_start + output_pos] == train_data[:, acc_start + output_pos]).to(float)\n",
        "                        ).item()\n",
        "                    }\n",
        "                )\n",
        "\n",
        "                if output_pos < gen_len-1:\n",
        "                    log_data.update(\n",
        "                        {\n",
        "                            f\"mean_cosine_sim_{output_pos}\": torch.sum(logit_cs[:, output_pos]) / (gen_len-1-output_pos)\n",
        "                        }\n",
        "                    )\n",
        "\n",
        "            wandb.log(log_data)\n",
        "\n",
        "        plt.close()\n",
        "        del (\n",
        "            logit_fig,\n",
        "            ax,\n",
        "            logit_cs,\n",
        "        )\n",
        "\n",
        "        if config.train.save_ckpt:\n",
        "            if (step == 0) or ((step + 1) % config.train.ckpt_freq == 0):\n",
        "                model.train()\n",
        "                torch.save(\n",
        "                    {\n",
        "                        \"epoch\": step,\n",
        "                        \"model\": model.state_dict(),\n",
        "                        \"optim\": optim.state_dict(),\n",
        "                        \"train_loss\": train_loss,\n",
        "                        \"test_acc\": test_acc,\n",
        "                    },\n",
        "                    \"./mws_k2_l1_h1_a16_n16.tar\",\n",
        "                )\n",
        "                print(f\"saved state at epoch {step} to {f'./mws_k2_l1_h1_a16_n16.tar'}\")\n",
        "\n",
        "                if config.train.wandb:\n",
        "                    model_wandb = wandb.Artifact(\n",
        "                        f\"model_step{step}\", type=\"model\"\n",
        "                    )\n",
        "                    model_wandb.add_file(f\"./mws_k2_l1_h1_a16_n16.tar\")\n",
        "                    wandb.log_artifact(model_wandb)\n",
        "                    print(\"model uploaded to wandb\")\n",
        "\n",
        "config = DotMap(config)\n",
        "\n",
        "config.model.vocab_size = max(config.data.p, config.data.max_num) + 1\n",
        "config.model.block_size = 2 * config.data.num_tokens + 1\n",
        "\n",
        "data_sampler = MovingWindowSum(\n",
        "    min_num=config.data.min_num,\n",
        "    max_num=config.data.max_num,\n",
        "    k=config.data.k,\n",
        "    p=config.data.p,\n",
        ")\n",
        "\n",
        "model = GPTLinear(config.model, return_att=True).to(device)\n",
        "optim = Adam(model.parameters(), lr=config.train.lr)\n",
        "\n",
        "if config.train.wandb:\n",
        "    wandb_run_name = 'mws'\n",
        "    # TODO: add your wandb login key here\n",
        "    wandb.login(key=\"\")\n",
        "    # TODO: specify which wandb project to log metrics / plots to\n",
        "    wandb.init(project=\"\", name=wandb_run_name, config=config)\n",
        "    wandb.watch(model)\n",
        "\n",
        "for step in range(config.train.num_steps):\n",
        "    train_step(\n",
        "        model=model,\n",
        "        optim=optim,\n",
        "        data_sampler=data_sampler,\n",
        "        step=step,\n",
        "        config=config,\n",
        "    )\n",
        "\n",
        "if config.train.wandb:\n",
        "    wandb.finish()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "FglJOzSAKJaa"
      },
      "outputs": [],
      "source": []
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "gpuType": "T4",
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}