{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "36487546",
   "metadata": {},
   "source": [
    "# Corner Visualization Experiment\n",
    "\n",
    "OK, we begin by loading a model.  GPT-J here."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b67160ce",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch, baukit\n",
    "from transformers import AutoModelForCausalLM, AutoTokenizer\n",
    "#MODEL_NAME = \"gpt2-xl\"  # gpt2-xl or EleutherAI/gpt-j-6B\n",
    "MODEL_NAME = \"EleutherAI/gpt-j-6B\"\n",
    "model, tok = (\n",
    "    AutoModelForCausalLM.from_pretrained(MODEL_NAME, low_cpu_mem_usage=False).to(\"cuda\"),\n",
    "    AutoTokenizer.from_pretrained(MODEL_NAME)\n",
    ")\n",
    "baukit.set_requires_grad(False, model)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "840dafd6",
   "metadata": {},
   "source": [
    "A world-cities CSV (`worldcities.csv`) is included for computing a corner over world cities; it is loaded further below."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "04240bfa",
   "metadata": {},
   "source": [
    "Here are a few hacks for experimenting with corner calculations in different ways. A *corner* here means a hidden-state vector that the decoder head maps to (roughly) equal probability across a chosen set of tokens.\n",
    "\n",
    "`make_corner_vector` tries to calculate a corner with linear algebra. It comes close, but it doesn't actually do what it aims to do, because the layernorm nonlinearity interferes."
   ]
  },
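  {
   "cell_type": "markdown",
   "id": "a3f1c9d0",
   "metadata": {},
   "source": [
    "As a rough sketch of what the linear-algebra step in the next cell is aiming for (the symbols $W$, $b$, $T$, $s$, $\\rho$, $d$ are notation introduced here, not names from the code): let $W$ and $b$ be the weight and bias of `lm_head`, $T$ the set of target tokens, $s$ the logit scale, $\\rho$ the pre-layernorm scale, and $d$ the hidden size. The solver looks for an $x$ that gives every target token the same logit, and then rescales it:\n",
    "\n",
    "$$W_T\\,x + b_T = s\\,\\mathbf{1}, \\qquad \\tilde{x} = \\rho\\,\\frac{x}{\\operatorname{rms}(x)}, \\qquad \\operatorname{rms}(x) = \\sqrt{\\tfrac{1}{d}\\sum_{k=1}^{d} x_k^2}$$\n",
    "\n",
    "The model, however, computes its logits as $W\\,\\mathrm{LN}(\\tilde{x}) + b$, where $\\mathrm{LN}$ is the final layernorm `ln_f`. Since $\\mathrm{LN}$ recenters and rescales its input (and applies its own learned gain and bias), the equal-logit property of $x$ is only approximately preserved, which is why the result comes close without being exact."
   ]
  },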
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f747be74",
   "metadata": {},
   "outputs": [],
   "source": [
    "generic_text = 'The quick brown fox jumped over the lazy dogs, while the rain in Spain fell mainly in the plain.'\n",
    "\n",
    "# Normalizing logits to the empirically measured scale seems to work pretty well\n",
    "def get_logit_scale(model, tok):\n",
    "    inp = {k: torch.tensor(v)[None].cuda() for k, v in tok(generic_text).items()}\n",
    "    # Trace the output head (lm_head) to read the logits\n",
    "    head_layer = [n for n, _ in model.named_modules() if 'lm_head' in n][0]\n",
    "    with baukit.Trace(model, head_layer) as t:\n",
    "        model(**inp)\n",
    "        return t.output.max(2)[0].mean() # scale is: average maximum logit\n",
    "\n",
    "# Normalizing pre-layernorm vectors to the empirically measured scale seems to overestimate\n",
    "def get_prenorm_scale(model, tok):\n",
    "    inp = {k: torch.tensor(v)[None].cuda() for k, v in tok(generic_text).items()}\n",
    "    with baukit.Trace(model, 'transformer.ln_f', retain_input=True) as t:\n",
    "        model(**inp)\n",
    "        return (t.input ** 2).mean().sqrt() # This is about 2.86 for GPT-J, which is too big.\n",
    "\n",
    "def make_corner_vector(words, model, tok, logit_scale=None, prenorm_scale=None, add_space=True):\n",
    "    if logit_scale is None:\n",
    "        logit_scale = get_logit_scale(model, tok)\n",
    "    if prenorm_scale is None:\n",
    "        prenorm_scale = get_prenorm_scale(model, tok)\n",
    "    decoding_vectors = model.lm_head.weight\n",
    "    decoding_bias = model.lm_head.bias  # None when the head has no bias (e.g. GPT-2)\n",
    "    # Form a deduplicated list of target tokens (the first token of each word)\n",
    "    token_numbers = [tok((' '  if add_space else '') + w)['input_ids'][0] for w in words]\n",
    "    token_numbers = list(set(token_numbers))\n",
    "    # Solve the linear algebra: find x so that every target token gets logit == logit_scale\n",
    "    A = decoding_vectors[token_numbers]\n",
    "    b = torch.full_like(A[:, 0], float(logit_scale))\n",
    "    if decoding_bias is not None:\n",
    "        b = b - decoding_bias[token_numbers]\n",
    "    x = torch.linalg.lstsq(A, b).solution\n",
    "    # Rescale x so that its RMS matches the measured pre-layernorm scale\n",
    "    x_scale = (x ** 2).mean().sqrt()\n",
    "    prenorm_x = x / x_scale * prenorm_scale\n",
    "    return prenorm_x"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "91afddac",
   "metadata": {},
   "source": [
    "Here is an optimization-based corner-finder.\n",
    "\n",
    "`optimize_corner_vector` initializes with linear algebra but then finishes off with an optimizer.  I found that RMSprop actually works pretty well without tuning.\n"
   ]
  },
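  {
   "cell_type": "markdown",
   "id": "b7e2d4a1",
   "metadata": {},
   "source": [
    "For reference, the loss minimized in the next cell can be written out as follows. The notation is introduced here for illustration only: $p_t$ is the probability the decoder head (layernorm, `lm_head`, softmax) assigns to target token $t$ at the vector $x$, $T$ is the deduplicated set of target tokens, and $n$ is its size.\n",
    "\n",
    "$$m = \\frac{1}{n}\\sum_{t \\in T} p_t, \\qquad \\mathcal{L}(x) = \\frac{1}{n}\\sum_{t \\in T} \\lvert p_t - m \\rvert \\; - \\; m$$\n",
    "\n",
    "The first term is zero when all target probabilities are equal, and the second rewards putting probability mass on the targets. Because the probabilities sum to at most 1, $m \\le 1/n$, so $\\mathcal{L} \\ge -1/n$, with equality exactly when every target token has probability $1/n$."
   ]
  },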
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "00442bc7",
   "metadata": {},
   "outputs": [],
   "source": [
    "def optimize_corner_vector(words, model, tok, logit_scale=None, prenorm_scale=None,\n",
    "                           lr=1e-4, iters=500, add_space=True):\n",
    "    if logit_scale is None:\n",
    "        logit_scale = get_logit_scale(model, tok)\n",
    "    if prenorm_scale is None:\n",
    "        prenorm_scale = get_prenorm_scale(model, tok)\n",
    "    decoding_vectors = model.lm_head.weight\n",
    "    decoding_bias = model.lm_head.bias  # None when the head has no bias (e.g. GPT-2)\n",
    "    # Form a deduplicated list of target tokens (the first token of each word)\n",
    "    token_numbers = [tok((' '  if add_space else '') + w)['input_ids'][0] for w in words]\n",
    "    token_numbers = list(set(token_numbers))\n",
    "    # Solve the linear algebra to get a starting point\n",
    "    A = decoding_vectors[token_numbers]\n",
    "    b = torch.full_like(A[:, 0], float(logit_scale))\n",
    "    if decoding_bias is not None:\n",
    "        b = b - decoding_bias[token_numbers]\n",
    "    x = torch.linalg.lstsq(A, b).solution\n",
    "    x_scale = (x ** 2).mean().sqrt()\n",
    "    prenorm_x = x / x_scale * prenorm_scale\n",
    "    # Now optimize to make it better.\n",
    "    decoder = torch.nn.Sequential(model.transformer.ln_f, model.lm_head, torch.nn.Softmax(dim=-1))\n",
    "    x = prenorm_x.clone().detach().requires_grad_(True)\n",
    "    best_loss, result = None, x.clone().detach()\n",
    "    optimizer = torch.optim.RMSprop([x], lr=lr)\n",
    "    for _ in range(iters):\n",
    "        optimizer.zero_grad()  # clear gradients left over from the previous step\n",
    "        p = decoder(x)[token_numbers]\n",
    "        m = p.mean()\n",
    "        # Penalize spread among the target probabilities while rewarding their mean\n",
    "        loss = (p - m).abs().mean() - m\n",
    "        if best_loss is None or loss < best_loss:\n",
    "            best_loss = loss.clone().detach()\n",
    "            result = x.clone().detach()\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "    return result\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "31b632c3",
   "metadata": {},
   "source": [
    "Here we find the corner between eight words (eight color words, for fun)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "945c019c",
   "metadata": {},
   "outputs": [],
   "source": [
    "colors = ['red', 'green', 'blue', 'orange', 'yellow', 'purple', 'gray', 'brown']\n",
    "\n",
    "token_numbers = [tok(' ' + c)['input_ids'][0] for c in colors]\n",
    "\n",
    "v = optimize_corner_vector(colors, model, tok)\n",
    "decoder = torch.nn.Sequential(model.transformer.ln_f, model.lm_head, torch.nn.Softmax(dim=-1))\n",
    "decoder(v)[token_numbers]"
   ]
  },
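  {
   "cell_type": "markdown",
   "id": "c5d8f0b2",
   "metadata": {},
   "source": [
    "The plotting cell that follows looks at two-dimensional slices through the corner. As a sketch in notation introduced here (not names from the code): with $e_j$ and $e_i$ the unit vectors along hidden dimensions $j$ and $i$, and $p_t$ the decoder probability of color token $t$, each grid point $(\\alpha, \\beta) \\in [-1, 1]^2$ is colored by the winning color word:\n",
    "\n",
    "$$x(\\alpha, \\beta) = v + 10\\,\\alpha\\,e_j + 10\\,\\beta\\,e_i, \\qquad c(\\alpha, \\beta) = \\arg\\max_{t \\in T}\\; p_t\\big(x(\\alpha, \\beta)\\big)$$\n",
    "\n",
    "A pair $(j, i)$ is plotted only when every one of the eight colors wins somewhere on the grid, so the displayed slices are the ones in which all regions meet near $v$."
   ]
  },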
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c49df724",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib\n",
    "from matplotlib import pyplot as plt\n",
    "cmap = matplotlib.colors.ListedColormap(colors)\n",
    "\n",
    "x, y = torch.meshgrid(torch.linspace(-1, 1, 100), torch.linspace(-1, 1, 100), indexing='ij')\n",
    "how_many = 3  # stop after plotting this many slices\n",
    "hidden_size = v.shape[-1]  # 4096 for GPT-J\n",
    "\n",
    "# Scan pairs of hidden dimensions (j, i), perturbing v along them by up to +/-10,\n",
    "# and plot only the slices in which every color wins somewhere.\n",
    "for j in range(hidden_size):\n",
    "    for i in range(j+1, hidden_size):\n",
    "        vv = torch.zeros_like(x)[:,:,None] + v.cpu()[None,None,:]\n",
    "        vv[:,:,j] += x * 10\n",
    "        vv[:,:,i] += y * 10\n",
    "        cindex = decoder(vv.cuda())[:,:,token_numbers].argmax(dim=2).cpu()\n",
    "        if len(cindex.unique()) == len(colors):\n",
    "            print(j, i)\n",
    "            plt.scatter(x, y, c=cindex, cmap=cmap)\n",
    "            plt.axis('square')\n",
    "            plt.show()\n",
    "            how_many -= 1\n",
    "            if how_many <= 0:\n",
    "                break\n",
    "    if how_many <= 0:\n",
    "        break\n",
    "\n",
    "print('done')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2b15ee86",
   "metadata": {},
   "source": [
    "## Corner-based attribute lens test\n",
    "\n",
    "Here is an attribute-lens style test, built on the corner vector."
   ]
  },
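  {
   "cell_type": "markdown",
   "id": "d9a6e3c4",
   "metadata": {},
   "source": [
    "In symbols (notation introduced here for illustration): if $x^{*}$ is the optimized corner vector for the chosen word list, $W$ and $b$ are the weight and bias of `lm_head`, and $\\mathrm{LN}$ is the final layernorm, then the readout defined in the next cell maps a hidden state $h$ to\n",
    "\n",
    "$$f(h) = \\operatorname{softmax}\\big(W\\,\\mathrm{LN}(h + x^{*}) + b\\big)$$\n",
    "\n",
    "With $h = 0$ this is just the decoded corner itself; a nonzero $h$ shifts the decoder input away from the corner."
   ]
  },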
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b5abba6a",
   "metadata": {},
   "outputs": [],
   "source": [
    "def make_corner_readout(model, tok, words, logit_scale=None, prenorm_scale=None):\n",
    "    # Get the corner vector\n",
    "    x = optimize_corner_vector(words, model, tok, logit_scale=logit_scale, prenorm_scale=prenorm_scale)\n",
    "\n",
    "    decoder = torch.nn.Sequential(model.transformer.ln_f, model.lm_head, torch.nn.Softmax(dim=-1))\n",
    "    def corner_readout(h):\n",
    "        # Shift the hidden state by the corner vector, then decode to probabilities\n",
    "        cuda_h = h.cuda()\n",
    "        return decoder(cuda_h + x)\n",
    "    return corner_readout"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "169545eb",
   "metadata": {},
   "outputs": [],
   "source": [
    "import csv\n",
    "with open('worldcities.csv') as w:\n",
    "    records = list(csv.DictReader(w))\n",
    "big_city_list = [r['city'] for r in records if float(r['population'] or 0) >= 1000000]\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b1fc957e",
   "metadata": {},
   "outputs": [],
   "source": [
    "short_city_list = [c for c in big_city_list if len(tok(' ' + c)['input_ids']) <= 1]\n",
    "len(short_city_list)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "24361a08",
   "metadata": {},
   "outputs": [],
   "source": [
    "f = make_corner_readout(\n",
    "    model, tok,\n",
    "    short_city_list)\n",
    "\n",
    "# Try running f on some zero vectors\n",
    "probs = f(torch.zeros(1, 5, 1, 3, 2, 4096))\n",
    "print(probs.shape)\n",
    "probs.sum(dim=-1).flatten()  # Verify that probabilities add up to 1.0"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6ab007e8",
   "metadata": {},
   "source": [
    "This function gathers the hidden states: the output of every transformer block, stacked into a single tensor."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2989295d",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_hidden_states(model, tok, prefix):\n",
    "    import re\n",
    "    from baukit import TraceDict\n",
    "    inp = {k: torch.tensor(v)[None].cuda() for k, v in tok(prefix).items()}\n",
    "    # Match exactly the transformer blocks (transformer.h.0, transformer.h.1, ...)\n",
    "    layer_names = [n for n, _ in model.named_modules()\n",
    "                   if re.match(r'^transformer\\.h\\.\\d+$', n)]\n",
    "    with TraceDict(model, layer_names) as tr:\n",
    "        model(**inp)  # only the traced block outputs are needed\n",
    "    return torch.stack([tr[layername].output[0] for layername in layer_names])\n",
    "\n",
    "prompt = 'Hello, my name is also'\n",
    "hs = get_hidden_states(model, tok, prompt)\n",
    "hs.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "53c4aa29",
   "metadata": {},
   "source": [
    "Here is the basic logit lens visualization.  Comments inline."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "10aff42e",
   "metadata": {},
   "outputs": [],
   "source": [
    "def show_logit_lens(model, tok, prefix, topk=5, color=None, hs=None, decoder=None):\n",
    "    from baukit import show\n",
    "\n",
    "    # You can pass in a function to compute the hidden states, or just the tensor of hidden states.\n",
    "    if hs is None:\n",
    "        hs = get_hidden_states\n",
    "    if callable(hs):\n",
    "        hs = hs(model, tok, prefix)\n",
    "\n",
    "    # The full decoder head normalizes hidden state and applies softmax at the end.\n",
    "    if decoder is None:\n",
    "        decoder = torch.nn.Sequential(model.transformer.ln_f, model.lm_head, torch.nn.Softmax(dim=-1))\n",
    "\n",
    "    probs = decoder(hs) # Apply the decoder head to every hidden state\n",
    "    favorite_probs, favorite_tokens = probs.topk(k=topk, dim=-1)\n",
    "    # Let's also plot hidden state magnitudes\n",
    "    magnitudes = hs.norm(dim=-1)\n",
    "    # For some reason the 0th token always has huge magnitudes, so normalize based on subsequent token max.\n",
    "    magnitudes = magnitudes / magnitudes[:,:,1:].max()\n",
    "    \n",
    "    # All the input tokens.\n",
    "    prompt_tokens = [tok.decode(t) for t in tok.encode(prefix)]\n",
    "\n",
    "    # Background color shows token probability; the text turns white once p >= 0.75.\n",
    "    # The hidden-state magnitude m is not used for color; it appears in the hover popup.\n",
    "    if color is None:\n",
    "        color = [0, 0, 255]\n",
    "    def color_fn(m, p):\n",
    "        #a = [int(255 * (1-m) + c * m) for c in color]\n",
    "        a = [int(255 * (1-p) + c * p) for c in color]\n",
    "        b = [int(196 * (1-p) + 0 * p)] * 2 + [0]\n",
    "        return show.style(background=f'rgb({a[0]}, {a[1]}, {a[2]})',\n",
    "                          #color=f'rgb({b[0]}, {b[1]}, {b[2]})' )\n",
    "                          color='black' if p < 0.75 else 'white' )\n",
    "\n",
    "    # In the hover popup, show the magnitude and all topk probabilities.\n",
    "    def hover(tok, prob, toks, m):\n",
    "        lines = [f'mag: {m:.2f}']\n",
    "        for p, t in zip(prob, toks):\n",
    "            lines.append(f'{tok.decode(t)}: prob {p:.2f}')\n",
    "        return show.attr(title='\\n'.join(lines))\n",
    "    \n",
    "    # Construct the HTML output using show.\n",
    "    header_line = [ # header line\n",
    "             [[show.style(fontWeight='bold'), 'Layer']] +\n",
    "             [\n",
    "                 [show.style(background='yellow'), show.attr(title=f'Token {i}'), t]\n",
    "                 for i, t in enumerate(prompt_tokens)\n",
    "             ]\n",
    "         ]\n",
    "    layer_logits = [\n",
    "             # first column\n",
    "             [[show.style(fontWeight='bold'), layer]] +\n",
    "             [\n",
    "                 # subsequent columns\n",
    "                 [color_fn(m, p[0]), hover(tok, p, t, m), show.style(overflowX='hidden'), tok.decode(t[0])]\n",
    "                 for m, p, t in zip(wordmags, wordprobs, words)\n",
    "             ]\n",
    "        for layer, wordmags, wordprobs, words in\n",
    "                zip(range(len(magnitudes)), magnitudes[:, 0], favorite_probs[:, 0], favorite_tokens[:,0])]\n",
    "    \n",
    "    # If you want to get the html without showing it, use show.html(...)\n",
    "    show(header_line + layer_logits + header_line)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fae7d715",
   "metadata": {},
   "source": [
    "An example."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "42bbda64",
   "metadata": {},
   "outputs": [],
   "source": [
    "# year_list = [str(y) for y in range(1800, 2023)]\n",
    "\n",
    "f = make_corner_readout(\n",
    "    model, tok,\n",
    "    short_city_list, prenorm_scale=0.75)\n",
    "\n",
    "show_logit_lens(model, tok, 'Brent Cross is a neighborhood in the city of', decoder=f)\n",
    "#show_logit_lens(model, tok, 'Prudential Center is a mall in the city of', decoder=f)\n",
    "show_logit_lens(model, tok, 'South of Houston Street is a neighborhood in the city of', decoder=f)\n",
    "#show_logit_lens(model, tok, 'Eisenhower was born in')\n",
    "#show_logit_lens(model, tok, 'Harrison Ford was born in', decoder=f)\n",
    "#show_logit_lens(model, tok, 'Harrison Ford was born in')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
