{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 42,
   "id": "95db6b99-da75-48a5-bca7-fe31da604f7c",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from torch import optim, nn, Tensor\n",
    "from torch.nn import functional as F\n",
    "import torch\n",
    "import wandb\n",
    "from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer\n",
    "import transformers\n",
    "import lightning as L\n",
    "from inspect import signature, _ParameterKind\n",
    "import copy\n",
    "import gc\n",
    "import datasets\n",
    "from torch.utils.data import DataLoader\n",
    "from matplotlib import pyplot as plt\n",
    "import pandas as pd\n",
    "from country_list import countries_for_language"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "9e24a301-b965-4d2f-a2d0-0544b0e518f6",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "1393843b157f45ec818332a963308b99",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "tokenizer_config.json:   0%|          | 0.00/967 [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "9e05e2f1f2914203bc096af62ed06738",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "tokenizer.model:   0%|          | 0.00/493k [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "a120dd460d37415e927966d7faa118e6",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "tokenizer.json:   0%|          | 0.00/1.80M [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "6aad144a54bc48eea0a0ae43ee7cff8a",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "special_tokens_map.json:   0%|          | 0.00/72.0 [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "fc35b5e79ee940cc9379df7a7d785315",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "config.json:   0%|          | 0.00/571 [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "f9b39202ba26429c99c5b975761a1d6e",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "model.safetensors.index.json:   0%|          | 0.00/25.1k [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "64908c94284e4ab8898f85cb0e611340",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading shards:   0%|          | 0/2 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "ed8ec0a32b914128baa5ff97da766dd6",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "model-00001-of-00002.safetensors:   0%|          | 0.00/9.94G [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "3744e62a517244988ecd02805cdcf919",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "model-00002-of-00002.safetensors:   0%|          | 0.00/4.54G [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "f0302c8f11324d55881f2d4575f20404",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "dc3eb5029f2549039b7ff279e9a86f73",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "generation_config.json:   0%|          | 0.00/116 [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Hugging Face checkpoint used for both the tokenizer and the model.\n",
    "model_name = 'mistralai/Mistral-7B-v0.1'\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
    "\n",
    "# Reverse vocabulary lookup: token id -> token string.\n",
    "Token = {token_id: piece for piece, token_id in tokenizer.get_vocab().items()}\n",
    "\n",
    "# Load the causal LM weights, then move the model to the GPU.\n",
    "model = AutoModelForCausalLM.from_pretrained(model_name)\n",
    "model = model.to('cuda')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "d7684779-3e82-4a11-a82b-2981fb65534a",
   "metadata": {},
   "outputs": [],
   "source": [
    "def topk(v, k=10):\n",
    "    '''Return the `k` largest entries of `v` as a (token, logit) table.\n",
    "\n",
    "    Args:\n",
    "        v: Logits as a torch.Tensor, numpy array or nested sequence. The\n",
    "            input is flattened before ranking, and flat index i is looked\n",
    "            up as Token[i].\n",
    "        k: Number of entries to keep (default 10). k <= 0 gives an empty\n",
    "            table.\n",
    "\n",
    "    Returns:\n",
    "        pandas.DataFrame with columns 'token' and 'logit', largest logit\n",
    "        first.\n",
    "    '''\n",
    "    if isinstance(v, torch.Tensor):\n",
    "        v = v.detach().cpu().numpy()\n",
    "    # asarray is a no-op for numpy input and lets plain lists through too.\n",
    "    v = np.asarray(v).flatten()\n",
    "    # Guard k <= 0 explicitly: `argsort()[-0:]` would select *every* index\n",
    "    # instead of none.\n",
    "    idxs = v.argsort()[-k:][::-1] if k > 0 else []\n",
    "    ret = [(Token[i], v[i]) for i in idxs]\n",
    "    return pd.DataFrame(ret, columns=['token', 'logit'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 73,
   "id": "5bbabc45-7d2b-47b8-82ed-08b4cc3052d7",
   "metadata": {},
   "outputs": [],
   "source": [
    "def print_tokens(s):\n",
    "    '''Print the tokenization of `s` as token strings joined by a pipe (|).'''\n",
    "    token_ids = tokenizer(s)['input_ids']\n",
    "    pieces = [Token[token_id] for token_id in token_ids]\n",
    "    print('|'.join(pieces))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "805ea47e-9432-4794-8212-b5fab8ceac22",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<torch.autograd.grad_mode.set_grad_enabled at 0x7fe3adf48220>"
      ]
     },
     "execution_count": 23,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Turn off autograd globally; the cells shown here only run forward passes.\n",
    "torch.set_grad_enabled(False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 68,
   "id": "13ea3d1d-efb7-49ae-86cb-b1a3784102cc",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Keep only the English country names that tokenize to exactly two ids\n",
    "# (presumably so every filled-in prompt has the same token length and the\n",
    "# batch can be built without padding -- TODO confirm).\n",
    "countries = [\n",
    "    name\n",
    "    for _, name in countries_for_language('en')\n",
    "    if len(tokenizer(name)['input_ids']) == 2\n",
    "    and name != 'Singapore'  # kind of an edge case...\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 70,
   "id": "6384a2b5-4eeb-46f6-8432-c25b2d688031",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Prompt skeleton; the country name is substituted for the {} placeholder.\n",
    "template = 'The nation of {}, whose capital city is'\n",
    "prompts = [template.format(country) for country in countries]\n",
    "\n",
    "# Tokenize all prompts as one batch of PyTorch tensors on the GPU.\n",
    "# NOTE: `input` shadows the Python builtin; the name is kept because a later\n",
    "# cell unpacks it with `**input`.\n",
    "input = tokenizer(prompts, return_tensors='pt').to('cuda')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 74,
   "id": "c4685e9d-c8dd-4f72-9b62-9e80928b9cce",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<s>|▁The|▁nation|▁of|▁France|,|▁whose|▁capital|▁city|▁is\n"
     ]
    }
   ],
   "source": [
    "# Sanity check: show how one filled-in prompt is split into tokens.\n",
    "print_tokens('The nation of France, whose capital city is')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 120,
   "id": "6dda508d-6a08-41e3-a57d-916d4955ff10",
   "metadata": {},
   "outputs": [],
   "source": [
    "# One forward pass over the whole prompt batch, also requesting the hidden\n",
    "# states and attention weights (read later via out.hidden_states and\n",
    "# out.attentions).\n",
    "out = model(**input, output_hidden_states=True, output_attentions=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 102,
   "id": "8540df96-8593-41ab-ba74-05117a5f0937",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([51, 32, 10, 10])"
      ]
     },
     "execution_count": 102,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Shape of the first entry of out.attentions; presumably laid out as\n",
    "# (batch, heads, query position, key position) -- TODO confirm.\n",
    "out.attentions[0].shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 111,
   "id": "2d678007-350d-4d19-8ba1-07763013ca32",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,\n",
       "         0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n",
       "        [9.7842e-01, 2.1582e-02, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,\n",
       "         0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n",
       "        [9.7448e-01, 2.2072e-02, 3.4500e-03, 0.0000e+00, 0.0000e+00, 0.0000e+00,\n",
       "         0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n",
       "        [9.4036e-01, 1.5855e-02, 2.3344e-03, 4.1448e-02, 0.0000e+00, 0.0000e+00,\n",
       "         0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n",
       "        [9.6279e-01, 1.2634e-02, 1.9308e-03, 1.7299e-02, 5.3518e-03, 0.0000e+00,\n",
       "         0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n",
       "        [9.0677e-01, 8.3474e-03, 8.1267e-04, 1.2414e-02, 3.7718e-03, 6.7887e-02,\n",
       "         0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n",
       "        [9.2191e-01, 5.7350e-03, 4.3037e-04, 5.1827e-03, 2.5756e-03, 4.7140e-02,\n",
       "         1.7025e-02, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n",
       "        [9.5568e-01, 9.2763e-03, 9.4642e-04, 7.0913e-03, 1.4423e-03, 2.0458e-02,\n",
       "         3.3106e-03, 1.7978e-03, 0.0000e+00, 0.0000e+00],\n",
       "        [9.5250e-01, 9.6856e-03, 8.6801e-04, 5.4756e-03, 1.2053e-03, 1.9831e-02,\n",
       "         3.8943e-03, 2.4613e-03, 4.0811e-03, 0.0000e+00],\n",
       "        [9.4406e-01, 9.6041e-03, 1.0153e-03, 6.3139e-03, 8.6331e-04, 1.3003e-02,\n",
       "         1.9344e-03, 1.5433e-03, 5.1147e-03, 1.6545e-02]], device='cuda:0')"
      ]
     },
     "execution_count": 111,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Average the first out.attentions entry over dim 0 (presumably the prompt\n",
    "# batch), then show index 0 of the result (presumably the first head).\n",
    "out.attentions[0].mean(dim=0)[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eb8fe19b-20ff-4397-9b6d-99aaf5daebcb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Axis labels for the prompt positions that are plotted; 'C' stands in for\n",
    "# the country slot of the template.\n",
    "ticks = ['The', 'nation', 'of', 'C', ',', 'whose', 'capital', 'city', 'is']\n",
    "\n",
    "# One heatmap per (attentions entry i, index j along its second axis) --\n",
    "# presumably (layer, head). The counts come from the data instead of being\n",
    "# hard-coded.\n",
    "for i, layer_attn in enumerate(out.attentions):\n",
    "    # Average over dim 0 and move to the CPU once per entry rather than once\n",
    "    # per inner iteration; [:, 1:, 1:] drops the first row and column of every\n",
    "    # map so the remaining positions line up with `ticks`.\n",
    "    mean_attn = layer_attn.mean(dim=0)[:, 1:, 1:].detach().cpu().numpy()\n",
    "    for j, head_attn in enumerate(mean_attn):\n",
    "        print(i, j)\n",
    "        plt.imshow(head_attn)\n",
    "        plt.xticks(range(len(ticks)), ticks)\n",
    "        plt.yticks(range(len(ticks)), ticks)\n",
    "        plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 113,
   "id": "7382c37c-c772-4b6b-a090-eace013bf3fe",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([51, 10, 4096])"
      ]
     },
     "execution_count": 113,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Shape of the first entry of out.hidden_states; presumably laid out as\n",
    "# (batch, position, hidden size) -- TODO confirm.\n",
    "out.hidden_states[0].shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 125,
   "id": "1d216f34-ed7b-4ead-889f-f77b4dd861aa",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.000 "
     ]
    }
   ],
   "source": [
    "# Scratch check of the fixed-point float formatting used for printing.\n",
    "# NOTE(review): the f-string formats the literal 0, not `ret` -- confirm\n",
    "# whether `ret` was meant here.\n",
    "ret = 0\n",
    "print(f'{0:.3f}', end=' ')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 131,
   "id": "6c753005-1b30-4957-96aa-025850520ad8",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 \n",
      "0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 \n",
      "0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 \n",
      "0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 \n",
      "1.6 1.8 1.8 2.0 2.1 2.0 2.0 2.0 2.1 2.1 2.2 2.1 2.1 2.1 2.1 2.2 2.2 2.1 2.0 1.9 2.0 2.0 2.0 2.0 2.0 2.0 2.1 2.1 2.1 2.1 2.1 2.2 \n",
      "0.0 2.5 2.4 2.7 3.5 3.2 2.7 2.8 2.9 2.9 3.0 3.1 2.9 3.0 3.1 3.1 2.9 2.9 2.9 2.6 2.4 2.3 2.3 2.3 2.3 2.3 2.3 2.3 2.4 2.4 2.5 2.6 \n",
      "0.0 3.2 2.4 2.7 2.8 2.9 2.8 2.9 3.0 3.0 3.0 3.0 2.9 2.9 3.0 2.9 2.8 2.7 2.7 2.5 2.4 2.4 2.3 2.3 2.2 2.2 2.2 2.2 2.2 2.3 2.3 2.4 \n",
      "0.0 3.9 3.7 3.0 3.2 3.1 2.3 2.2 2.2 2.3 2.4 2.5 2.6 2.7 2.8 2.9 2.9 3.0 3.3 3.2 2.7 2.7 2.6 2.5 2.4 2.1 2.1 2.1 2.1 2.1 2.1 2.2 \n",
      "0.0 4.3 4.0 3.0 3.0 3.2 2.7 2.7 2.8 2.9 2.9 2.9 2.9 2.9 2.9 3.1 2.9 3.1 3.3 3.2 2.7 2.6 2.5 2.5 2.5 2.2 2.2 2.1 2.1 2.1 2.1 2.2 \n",
      "0.0 4.2 3.6 3.3 3.1 3.2 3.1 3.2 3.2 3.3 3.3 3.3 3.2 3.2 3.1 3.2 2.8 3.0 3.3 2.8 2.4 2.3 2.2 2.2 2.1 1.9 2.0 1.9 1.9 1.8 1.8 1.8 \n"
     ]
    }
   ],
   "source": [
    "# For every token position i (rows) and hidden-state entry j (columns), print\n",
    "# the matrix 2-norm of the batch-centered hidden states divided by the mean\n",
    "# per-prompt vector 2-norm.\n",
    "# NOTE(review): only the first 32 entries of out.hidden_states are visited;\n",
    "# confirm whether the remaining entry should be included.\n",
    "for i in range(10):\n",
    "    for j in range(32):\n",
    "        h = out.hidden_states[j][:,i,:]\n",
    "        # Center out of place: `h` is a view into out.hidden_states[j], so an\n",
    "        # in-place `h -= ...` would overwrite the stored activations and make\n",
    "        # this cell (and every later use of `out`) depend on re-run history.\n",
    "        h = h - h.mean(dim=0)\n",
    "        # ord=2 on the 2-D input is the matrix 2-norm (largest singular value).\n",
    "        norm = torch.linalg.norm(h, ord=2).item()\n",
    "        # With dim=1 it is the ordinary vector 2-norm of each row.\n",
    "        avg_norm = torch.linalg.norm(h, dim=1, ord=2).mean().item()\n",
    "        # Rows are identical when the norm is 0; avoid dividing 0 by 0.\n",
    "        ret = 0 if norm == 0 else norm / avg_norm\n",
    "        print(f'{ret:.1f}', end=' ')\n",
    "    print()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 79,
   "id": "43f29d44-5a89-4d44-a0ee-80c4bbb45768",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "33"
      ]
     },
     "execution_count": 79,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Number of entries in the out.hidden_states tuple.\n",
    "len(out.hidden_states)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 80,
   "id": "34d50394-72df-47ff-9421-195ac244079b",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([51, 4096])"
      ]
     },
     "execution_count": 80,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Shape of the index-0 slice along dim 1 of the first hidden-state entry.\n",
    "out.hidden_states[0][:,0,:].shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eb84a8dd-e873-4de0-8476-69bea2ceb807",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
