{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 49,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "b785f8537c434359b2bd654f2b67d2f7",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loaded pretrained model meta-llama/Meta-Llama-3-8B into HookedTransformer\n",
      "Moving model to device:  cuda\n"
     ]
    }
   ],
   "source": [
    "from transformer_lens import HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache\n",
    "from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer\n",
    "import numpy as np\n",
    "from skopt import gp_minimize\n",
    "from skopt.space import Real\n",
    "import torch\n",
    "import transformer_lens.utils as utils\n",
    "from functools import partial\n",
    "import torch\n",
    "from botorch.models import SingleTaskGP\n",
    "# from botorch.fit import fit_gpytorch_model\n",
    "from botorch.optim import optimize_acqf\n",
    "from botorch.acquisition import ExpectedImprovement, LogExpectedImprovement, qLogExpectedImprovement\n",
    "from botorch.utils.sampling import draw_sobol_samples\n",
    "from gpytorch.mlls import ExactMarginalLogLikelihood\n",
    "import torch\n",
    "from botorch.models import SingleTaskGP\n",
    "from botorch.models.transforms import Normalize, Standardize\n",
    "from botorch.fit import fit_gpytorch_mll\n",
    "from gpytorch.mlls import ExactMarginalLogLikelihood\n",
    "# %%\n",
    "def loadTransformerLensModel(modelPath):\n",
    "    \"\"\"Load `modelPath` from Hugging Face and wrap it as a TransformerLens HookedTransformer.\n",
    "\n",
    "    The hooked model is built on the CPU (callers move it to a device) with LayerNorm\n",
    "    folding and writing-weight / unembedding centering disabled.\n",
    "    Returns (hooked_model, tokenizer).\n",
    "    \"\"\"\n",
    "    tok = AutoTokenizer.from_pretrained(modelPath)\n",
    "    hooked = HookedTransformer.from_pretrained(\n",
    "        modelPath,\n",
    "        hf_model=AutoModelForCausalLM.from_pretrained(modelPath, low_cpu_mem_usage=True),\n",
    "        device='cpu',\n",
    "        fold_ln=False,\n",
    "        center_writing_weights=False,\n",
    "        center_unembed=False,\n",
    "        tokenizer=tok,\n",
    "    )\n",
    "    return hooked, tok\n",
    "\n",
    "# %%\n",
    "MODEL_PATH = 'meta-llama/Meta-Llama-3-8B'\n",
    "# Alternative checkpoint: meta-llama/Llama-2-7b-hf\n",
    "model, tokenizer = loadTransformerLensModel(MODEL_PATH)\n",
    "# float64 dtype for the BoTorch tensors built further down (bounds).\n",
    "dtype = torch.float64\n",
    "\n",
    "# %%\n",
    "# Use the GPU when one is visible; fall back to CPU instead of failing in model.to('cuda').\n",
    "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
    "model = model.to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 50,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run suffix (presumably layer range + angle range) shared by the saved prompt/label pickles.\n",
    "fileName = '_layer_0_32_angle_-0.39269908169872414_0.39269908169872414'\n",
    "few_shot_prompts, few_shot_labels = (f'{kind}/{kind}{fileName}.pkl' for kind in ('prompts', 'labels'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 51,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pickled BoTorch state for this run (read by load_checkpoint: train_X, train_Y, gp_state_dict, iteration).\n",
    "checkpoint_path = f'checkpoint/botorch_checkpoint_{fileName}.pkl'\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 52,
   "metadata": {},
   "outputs": [],
   "source": [
    "L = [0, 32]  # rotate heads in layers L[0] .. L[-1] - 1\n",
    "# Head geometry is read from the loaded model instead of hard-coded, so the angle layout\n",
    "# (num_L, H, N) always matches the attention `z` activations being rotated.\n",
    "# For the checkpoint in MODEL_PATH this should give H = 32 and N = 64 (compare D.shape below).\n",
    "H = model.cfg.n_heads\n",
    "N = model.cfg.d_head // 2  # one angle per pair of head dimensions\n",
    "num_L = L[-1] - L[0]\n",
    "assert model.cfg.d_head % 2 == 0, 'd_head must be even to pair dimensions'\n",
    "assert 0 <= L[0] < L[-1] <= model.cfg.n_layers, f'layer range {L} is outside the model'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 53,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import pickle\n",
    "\n",
    "def load_checkpoint(filepath):\n",
    "    \"\"\"Read a pickled BoTorch checkpoint dict from `filepath`.\n",
    "\n",
    "    Returns (train_X, train_Y, gp_state_dict, iteration), looked up in that order;\n",
    "    a missing key raises KeyError.\n",
    "    Security: pickle.load can execute arbitrary code -- only load checkpoints you wrote.\n",
    "    \"\"\"\n",
    "    with open(filepath, 'rb') as handle:\n",
    "        state = pickle.load(handle)\n",
    "    wanted = ('train_X', 'train_Y', 'gp_state_dict', 'iteration')\n",
    "    return tuple(state[key] for key in wanted)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 54,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "# Load the training split: one JSON record per line.\n",
    "with open('dataset/modified_arithmetic/train.jsonl', 'r') as f:\n",
    "    train = [json.loads(line) for line in f]\n",
    "\n",
    "# %%\n",
    "# Prompt text and the first target string of every record.\n",
    "# NOTE: `input` shadows the Python builtin; later cells rely on this name.\n",
    "input = [d['inputs'] for d in train]\n",
    "output = [d['targets'][0] for d in train]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 55,
   "metadata": {},
   "outputs": [],
   "source": [
    "def create_rope_rotation_matrix(angles):\n",
    "    \"\"\"\n",
    "    Create a RoPE-style block-diagonal rotation matrix for a vector of angles.\n",
    "\n",
    "    Dimension pair i (rows/cols 2i and 2i+1) is rotated by angles[i] with the 2x2 block\n",
    "        [[cos(a), -sin(a)],\n",
    "         [sin(a),  cos(a)]]\n",
    "    and every entry outside these diagonal blocks is 0.\n",
    "\n",
    "    Parameters:\n",
    "    angles (torch.Tensor or sequence of float): 1-D angles, one per dimension pair.\n",
    "\n",
    "    Returns:\n",
    "    torch.Tensor: (2n, 2n) matrix on the CPU in torch's default dtype, whatever the\n",
    "    dtype/device of `angles` (callers move it with `.to(device)`).\n",
    "    \"\"\"\n",
    "    angles = torch.as_tensor(angles)\n",
    "    n = angles.shape[0]  # Number of angles\n",
    "    d = 2 * n  # Dimension of the square matrix (embedding dimension must be even)\n",
    "    rope_matrix = torch.zeros((d, d))\n",
    "\n",
    "    # Vectorised fill. A per-pair Python loop issues 4n scalar writes per call, and this\n",
    "    # function runs num_L * H times per prompt, which made it the evaluation bottleneck.\n",
    "    # Cast explicitly: advanced-index assignment requires matching dtypes, and the result\n",
    "    # stays a CPU matrix in the default dtype regardless of the angles' dtype/device.\n",
    "    cos_angles = torch.cos(angles).to(device='cpu', dtype=rope_matrix.dtype)\n",
    "    sin_angles = torch.sin(angles).to(device='cpu', dtype=rope_matrix.dtype)\n",
    "    even = torch.arange(0, d, 2)\n",
    "    odd = even + 1\n",
    "    rope_matrix[even, even] = cos_angles\n",
    "    rope_matrix[even, odd] = -sin_angles\n",
    "    rope_matrix[odd, even] = sin_angles\n",
    "    rope_matrix[odd, odd] = cos_angles\n",
    "\n",
    "    return rope_matrix\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 56,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Evaluation prompts are drawn from train examples at index >= nShots (see the prompt-building loop).\n",
    "nShots = 100\n",
    "# nShots_example = [(i, o) for i, o in zip(input[:nShots], output[:nShots])]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 57,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle\n",
    "\n",
    "# Saved few-shot demonstrations (pickle: only load files you produced).\n",
    "# NOTE: few_shot_prompts is a path here, but the prompt-building cell rebinds it to a\n",
    "# list of prompts, so re-running this cell after that one fails.\n",
    "with open(few_shot_prompts, 'rb') as f:\n",
    "    nShots_example = pickle.load(f)\n",
    "with open(few_shot_labels, 'rb') as f:\n",
    "    nShots_labels = pickle.load(f)\n",
    "\n",
    "# Preamble: the first 6 demonstrations, each rendered as prompt + label + blank line.\n",
    "Template = ''.join(f\"{nShots_example[k]}{nShots_labels[k]}\\n\\n\" for k in range(6))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 58,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Input: In the following lines, the symbol -> represents a simple mathematical operation.\n",
      "74 * 3 -> 222\n",
      "78 * 5 -> 390\n",
      "93 * 50 -> 4650\n",
      "61 * 56 -> 3416\n",
      "65 * 78 -> 5070\n",
      "74 * 7 ->\n",
      "output: 518\n",
      "\n",
      "Input: In the following lines, the symbol -> represents a simple mathematical operation.\n",
      "960 + 208 -> 1168\n",
      "762 + 896 -> 1658\n",
      "116 + 627 -> 743\n",
      "748 + 477 -> 1225\n",
      "104 + 494 -> 598\n",
      "770 + 45 ->\n",
      "output: 815\n",
      "\n",
      "Input: In the following lines, the symbol -> represents a simple mathematical operation.\n",
      "98 * 13 -> 1274\n",
      "21 * 50 -> 1050\n",
      "22 * 47 -> 1034\n",
      "70 * 49 -> 3430\n",
      "88 * 12 -> 1056\n",
      "84 * 9 ->\n",
      "output: 756\n",
      "\n",
      "Input: In the following lines, the symbol -> represents a simple mathematical operation.\n",
      "503 - 685 -> -181\n",
      "386 - 969 -> -582\n",
      "869 - 832 -> 38\n",
      "887 - 602 -> 286\n",
      "749 - 775 -> -25\n",
      "690 - 493 ->\n",
      "output: 198\n",
      "\n",
      "Input: In the following lines, the symbol -> represents a simple mathematical operation.\n",
      "60 * 70 -> 4201\n",
      "45 * 31 -> 1396\n",
      "14 * 43 -> 603\n",
      "16 * 40 -> 641\n",
      "53 * 43 -> 2280\n",
      "97 * 4 ->\n",
      "output: 389\n",
      "\n",
      "Input: In the following lines, the symbol -> represents a simple mathematical operation.\n",
      "710 - 142 -> 569\n",
      "916 - 914 -> 3\n",
      "827 - 604 -> 224\n",
      "472 - 416 -> 57\n",
      "304 - 735 -> -430\n",
      "496 - 290 ->\n",
      "output: 207\n",
      "\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Show the few-shot preamble that is prepended to every entry of few_shot_prompts.\n",
    "print(Template)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 59,
   "metadata": {},
   "outputs": [],
   "source": [
    "number_of_examples = 200\n",
    "prompts, few_shot_prompts, labels = [], [], []\n",
    "\n",
    "# Walk the examples after the first nShots and keep those whose target is a single\n",
    "# token, so one greedy next-token prediction can be compared with the label.\n",
    "# Stops once number_of_examples prompts have been collected.\n",
    "# NOTE: this rebinds few_shot_prompts (previously the pickle path) to a list.\n",
    "for question, answer in zip(input[nShots:], output[nShots:]):\n",
    "    answer_ids = tokenizer.encode(answer, return_tensors='pt', add_special_tokens=False)[0]\n",
    "    if len(answer_ids) == 1:\n",
    "        query = f\"Input: {question}\\noutput: \"\n",
    "        prompts.append(query)\n",
    "        few_shot_prompts.append(Template + query)\n",
    "        labels.append(answer)\n",
    "    if len(prompts) == number_of_examples:\n",
    "        break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 60,
   "metadata": {},
   "outputs": [],
   "source": [
    "def runDefaultModel(model, tokenizer, prompt, add_special_tokens=False):\n",
    "    \"\"\"Greedy next-token prediction from the unmodified model (all hooks cleared).\n",
    "\n",
    "    Args:\n",
    "        prompt: text to continue.\n",
    "        add_special_tokens: forwarded to tokenizer.encode. Defaults to False so the\n",
    "            baseline is tokenized exactly like runRotatedModel (which encodes with\n",
    "            add_special_tokens=False); otherwise only the baseline would receive any\n",
    "            BOS/special tokens the tokenizer adds, and the accuracies would not be comparable.\n",
    "\n",
    "    Returns:\n",
    "        (decoded token text, token id) of the argmax next token.\n",
    "    \"\"\"\n",
    "    input_id = tokenizer.encode(prompt, add_special_tokens=add_special_tokens, return_tensors='pt').to(device)\n",
    "    model.reset_hooks()\n",
    "    # Inference only: no autograd graph is needed for the forward pass.\n",
    "    with torch.no_grad():\n",
    "        logits = model(input_id)\n",
    "    predicted_token = torch.argmax(logits[:, -1, :], dim=1)[0]\n",
    "    return str(tokenizer.decode(predicted_token, skip_special_tokens=True)), predicted_token.item()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 61,
   "metadata": {},
   "outputs": [],
   "source": [
    "def getOriginalAccuracy(model, tokenizer, prompts, labels):\n",
    "    \"\"\"Greedy next-token accuracy of the unmodified model.\n",
    "\n",
    "    Args:\n",
    "        model, tokenizer: passed through to runDefaultModel.\n",
    "        prompts: prompt strings.\n",
    "        labels: expected answer strings aligned with prompts; the prompt-building cell\n",
    "            keeps only single-token answers, which the .item() call below relies on.\n",
    "\n",
    "    Returns:\n",
    "        (accuracy, answer_tokens). answer_tokens[i] is the predicted token id when its\n",
    "        decoding equals labels[i], otherwise the token id obtained by encoding labels[i].\n",
    "        accuracy is 0.0 for an empty prompt list instead of raising ZeroDivisionError.\n",
    "    \"\"\"\n",
    "    from tqdm import tqdm\n",
    "\n",
    "    correct = 0\n",
    "    answer_tokens = []\n",
    "    # A single bar that is both iterated and relabelled (a second, separately updated\n",
    "    # bar would print duplicate progress lines).\n",
    "    with tqdm(zip(prompts, labels), desc=\"Original Accuracy\", total=len(prompts)) as pbar:\n",
    "        for count, (prompt, label) in enumerate(pbar, start=1):\n",
    "            output, answer_token = runDefaultModel(model, tokenizer, prompt)\n",
    "            if output == label:\n",
    "                correct += 1\n",
    "            else:\n",
    "                answer_token = tokenizer.encode(label, return_tensors='pt', add_special_tokens=False)[0].item()\n",
    "            answer_tokens.append(answer_token)\n",
    "            pbar.set_description(f\"Original Accuracy: {correct / count}\")\n",
    "    accuracy = correct / len(prompts) if prompts else 0.0\n",
    "    return accuracy, answer_tokens"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 62,
   "metadata": {},
   "outputs": [],
   "source": [
    "from jaxtyping import Float, Int\n",
    "\n",
    "def rotateMatrix(\n",
    "        clean_head_vector: Float[torch.Tensor, \"batch pos head_index d_head\"],\n",
    "        hook,\n",
    "        head_index,\n",
    "        rotaryMatrix):\n",
    "    \"\"\"TransformerLens forward hook: rotate a single head's activations in place.\n",
    "\n",
    "    clean_head_vector[:, :, head_index, :] is replaced by itself right-multiplied by\n",
    "    rotaryMatrix; the other heads are left as they are. The mutated tensor is returned.\n",
    "    `hook` is the hook object supplied by TransformerLens and is not used.\n",
    "    \"\"\"\n",
    "    head_acts = clean_head_vector[:, :, head_index, :]\n",
    "    rotated = torch.matmul(head_acts, rotaryMatrix)\n",
    "    clean_head_vector[:, :, head_index, :] = rotated\n",
    "    return clean_head_vector"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 63,
   "metadata": {},
   "outputs": [],
   "source": [
    "def runRotatedModel(model, tokenizer, prompt, D, answer_token, add_special_tokens=False):\n",
    "    \"\"\"Run one prompt with every head in layers L[0] .. L[-1]-1 rotated, and score it.\n",
    "\n",
    "    D[layer - L[0], head] is a 1-D tensor of angles that create_rope_rotation_matrix turns\n",
    "    into a block rotation matrix; rotateMatrix right-multiplies that head's `z`\n",
    "    activations by it during the forward pass.\n",
    "\n",
    "    Args:\n",
    "        prompt: text to continue.\n",
    "        D: angle tensor indexed as D[layer - L[0], head].\n",
    "        answer_token: vocabulary id whose next-token probability is returned.\n",
    "        add_special_tokens: forwarded to tokenizer.encode; the default False adds no\n",
    "            special tokens, which is this function's original behaviour.\n",
    "\n",
    "    Returns:\n",
    "        (decoded greedy next token, softmax probability of answer_token).\n",
    "    \"\"\"\n",
    "    encoded_prompt = tokenizer.encode(prompt, add_special_tokens=add_special_tokens, return_tensors='pt')\n",
    "    encoded_prompt = encoded_prompt.to(device)\n",
    "\n",
    "    list_fwd_hooks = []\n",
    "    for layer in range(L[0], L[-1]):\n",
    "        hook_name = utils.get_act_name(\"z\", layer, \"attn\")  # same hook point for every head of this layer\n",
    "        for head in range(H):\n",
    "            rotaryMatrix = create_rope_rotation_matrix(D[layer - L[0], head]).to(device)\n",
    "            list_fwd_hooks.append((hook_name, partial(rotateMatrix, head_index=head, rotaryMatrix=rotaryMatrix)))\n",
    "\n",
    "    # Inference only: skip autograd bookkeeping for the forward pass.\n",
    "    with torch.no_grad():\n",
    "        rotated_logits = model.run_with_hooks(encoded_prompt, return_type=\"logits\", fwd_hooks=list_fwd_hooks)\n",
    "\n",
    "    last_logits = rotated_logits[:, -1, :]\n",
    "    predicted_token = torch.argmax(last_logits, dim=1)[0]\n",
    "    answer_token_prob = torch.nn.functional.softmax(last_logits, dim=1)[0, answer_token].item()\n",
    "    token = tokenizer.decode(predicted_token, skip_special_tokens=True)\n",
    "    return token, answer_token_prob"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 64,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n",
      "\u001b[A\n",
      "\u001b[A\n",
      "\u001b[A\n",
      "\u001b[A\n",
      "\u001b[A\n",
      "\u001b[A\n",
      "\u001b[A\n",
      "\u001b[A\n",
      "\u001b[A\n",
      "\u001b[A\n",
      "\u001b[A\n",
      "\u001b[A\n",
      "\u001b[A\n",
      "\u001b[A\n",
      "\u001b[A\n",
      "\u001b[A\n",
      "\u001b[A\n",
      "\u001b[A\n",
      "\u001b[A\n",
      "\u001b[A\n",
      "\u001b[A\n",
      "\u001b[A\n",
      "\u001b[A\n",
      "\u001b[A\n",
      "\u001b[A\n",
      "\u001b[A\n",
      "\u001b[A\n",
      "\u001b[A\n",
      "\u001b[A\n",
      "\u001b[A\n",
      "\u001b[A\n",
      "\u001b[A\n",
      "\u001b[A\n",
      "\u001b[A\n",
      "\u001b[A\n",
      "\u001b[A\n",
      "\u001b[A\n",
      "\u001b[A\n",
      "\u001b[A\n",
      "\u001b[A\n",
      "\u001b[A\n",
      "\u001b[A\n",
      "\u001b[A\n",
      "\u001b[A\n",
      "\u001b[A\n",
      "\u001b[A\n",
      "\u001b[A\n",
      "\u001b[A\n",
      "\u001b[A\n",
      "\u001b[A\n",
      "\u001b[A\n",
      "\u001b[A\n",
      "\u001b[A\n",
      "\u001b[A\n",
      "\u001b[A\n",
      "\u001b[A\n",
      "\u001b[A\n",
      "\u001b[A\n",
      "\u001b[A\n",
      "\u001b[A\n",
      "\u001b[A\n",
      "\u001b[A\n",
      "\u001b[A\n",
      "\u001b[A\n",
      "\u001b[A\n",
      "\u001b[A\n",
      "\u001b[A\n",
      "\u001b[A\n",
      "\u001b[A\n",
      "\u001b[A\n",
      "\u001b[A\n",
      "\u001b[A\n",
      "\u001b[A\n",
      "\u001b[A\n",
      "\u001b[A\n",
      "\u001b[A\n",
      "\u001b[A\n",
      "\u001b[A\n",
      "\u001b[A\n",
      "\u001b[A\n",
      "\u001b[A\n",
      "\u001b[A\n",
      "\u001b[A\n",
      "\u001b[A\n",
      "\u001b[A\n",
      "\u001b[A\n",
      "\u001b[A\n",
      "\u001b[A\n",
      "\u001b[A\n",
      "\u001b[A\n",
      "\u001b[A\n",
      "\u001b[A\n",
      "\u001b[A\n",
      "\u001b[A\n",
      "\u001b[A\n",
      "\u001b[A\n",
      "\u001b[A\n",
      "\u001b[A\n",
      "\u001b[A\n",
      "\u001b[A\n",
      "\u001b[A\n",
      "\u001b[A\n",
      "\u001b[A\n",
      "\u001b[A\n",
      "\u001b[A\n",
      "\u001b[A\n",
      "\u001b[A\n",
      "\u001b[A\n",
      "\u001b[A\n",
      "\u001b[A\n",
      "\u001b[A\n",
      "\u001b[A\n",
      "\u001b[A\n",
      "\u001b[A\n",
      "\u001b[A\n",
      "\u001b[A\n",
      "\u001b[A\n",
      "\u001b[A\n",
      "\u001b[A\n",
      "\u001b[A\n",
      "\u001b[A\n",
      "\u001b[A\n",
      "\u001b[A\n",
      "\u001b[A\n",
      "\u001b[A\n",
      "\u001b[A\n",
      "\u001b[A\n",
      "\u001b[A\n",
      "\u001b[A\n",
      "\u001b[A\n",
      "\u001b[A\n",
      "\u001b[A\n",
      "\u001b[A\n",
      "\u001b[A\n",
      "\u001b[A\n",
      "\u001b[A\n",
      "\u001b[A\n",
      "\u001b[A\n",
      "\u001b[A\n",
      "\u001b[A\n",
      "\u001b[A\n",
      "\u001b[A\n",
      "\u001b[A\n",
      "\u001b[A\n",
      "\u001b[A\n",
      "\u001b[A\n",
      "\u001b[A\n",
      "\u001b[A\n",
      "\u001b[A\n",
      "\u001b[A\n",
      "\u001b[A\n",
      "\u001b[A\n",
      "\u001b[A\n",
      "\u001b[A\n",
      "\u001b[A\n",
      "\u001b[A\n",
      "\u001b[A\n",
      "\u001b[A\n",
      "\u001b[A\n",
      "\u001b[A\n",
      "\u001b[A\n",
      "\u001b[A\n",
      "\u001b[A\n",
      "\u001b[A\n",
      "\u001b[A\n",
      "\u001b[A\n",
      "\u001b[A\n",
      "\u001b[A\n",
      "\u001b[A\n",
      "\u001b[A\n",
      "\u001b[A\n",
      "\u001b[A\n",
      "\u001b[A\n",
      "\u001b[A\n",
      "\u001b[A\n",
      "\u001b[A\n",
      "\u001b[A\n",
      "\u001b[A\n",
      "\u001b[A\n",
      "\u001b[A\n",
      "\u001b[A\n",
      "\u001b[A\n",
      "\u001b[A\n",
      "\u001b[A\n",
      "\u001b[A\n",
      "\u001b[A\n",
      "\u001b[A\n",
      "\u001b[A\n",
      "\u001b[A\n",
      "\u001b[A\n",
      "\u001b[A\n",
      "\u001b[A\n",
      "\u001b[A\n",
      "\u001b[A\n",
      "\u001b[A\n",
      "\u001b[A\n",
      "\u001b[A\n",
      "\u001b[A\n",
      "\u001b[A\n",
      "\u001b[A\n",
      "Original Accuracy: 100%|██████████| 200/200 [00:26<00:00,  7.48it/s]\n",
      "Original Accuracy: 0.215: 100%|██████████| 200/200 [00:26<00:00,  7.48it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Normal Accuracy: 0.215\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "# Baseline: accuracy of the unrotated model. answer_tokens[i] is the label's token id\n",
    "# (the predicted id when it decoded to the label, otherwise the encoded label).\n",
    "normalAccuracy, answer_tokens = getOriginalAccuracy(model, tokenizer, prompts, labels)\n",
    "print(f\"Normal Accuracy: {normalAccuracy}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "200"
      ]
     },
     "execution_count": 38,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Sanity check: equals number_of_examples unless the dataset ran out of single-token answers.\n",
    "len(prompts)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Input: In the following lines, the symbol -> represents a simple mathematical operation.\n",
      "74 * 3 -> 222\n",
      "78 * 5 -> 390\n",
      "93 * 50 -> 4650\n",
      "61 * 56 -> 3416\n",
      "65 * 78 -> 5070\n",
      "74 * 7 ->\n",
      "output: 518\n",
      "\n",
      "Input: In the following lines, the symbol -> represents a simple mathematical operation.\n",
      "960 + 208 -> 1168\n",
      "762 + 896 -> 1658\n",
      "116 + 627 -> 743\n",
      "748 + 477 -> 1225\n",
      "104 + 494 -> 598\n",
      "770 + 45 ->\n",
      "output: 815\n",
      "\n",
      "Input: In the following lines, the symbol -> represents a simple mathematical operation.\n",
      "98 * 13 -> 1274\n",
      "21 * 50 -> 1050\n",
      "22 * 47 -> 1034\n",
      "70 * 49 -> 3430\n",
      "88 * 12 -> 1056\n",
      "84 * 9 ->\n",
      "output: 756\n",
      "\n",
      "Input: In the following lines, the symbol -> represents a simple mathematical operation.\n",
      "503 - 685 -> -181\n",
      "386 - 969 -> -582\n",
      "869 - 832 -> 38\n",
      "887 - 602 -> 286\n",
      "749 - 775 -> -25\n",
      "690 - 493 ->\n",
      "output: 198\n",
      "\n",
      "Input: In the following lines, the symbol -> represents a simple mathematical operation.\n",
      "60 * 70 -> 4201\n",
      "45 * 31 -> 1396\n",
      "14 * 43 -> 603\n",
      "16 * 40 -> 641\n",
      "53 * 43 -> 2280\n",
      "97 * 4 ->\n",
      "output: 389\n",
      "\n",
      "Input: In the following lines, the symbol -> represents a simple mathematical operation.\n",
      "710 - 142 -> 569\n",
      "916 - 914 -> 3\n",
      "827 - 604 -> 224\n",
      "472 - 416 -> 57\n",
      "304 - 735 -> -430\n",
      "496 - 290 ->\n",
      "output: 207\n",
      "\n",
      "Input: In the following lines, the symbol -> represents a simple mathematical operation.\n",
      "93 * 35 -> 3255\n",
      "17 * 89 -> 1513\n",
      "52 * 70 -> 3640\n",
      "0 * 57 -> 0\n",
      "89 * 86 -> 7654\n",
      "89 * 5 ->\n",
      "output: \n"
     ]
    }
   ],
   "source": [
    "# Inspect one complete few-shot prompt (Template preamble + query).\n",
    "print(few_shot_prompts[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {},
   "outputs": [],
   "source": [
    "def objective(params, prompt_list, label_list=None, answer_token_list=None):\n",
    "    \"\"\"Greedy next-token accuracy of the rotated model over a list of prompts.\n",
    "\n",
    "    Args:\n",
    "        params: num_L * H * N rotation angles (flat or already shaped); reshaped to (num_L, H, N).\n",
    "        prompt_list: prompts to evaluate, e.g. `prompts` or `few_shot_prompts`.\n",
    "        label_list: expected answers aligned with prompt_list; defaults to the global `labels`.\n",
    "        answer_token_list: per-prompt token id whose probability is averaged; defaults to the\n",
    "            first token id of each label (labels were filtered to single-token answers).\n",
    "\n",
    "    Returns:\n",
    "        (accuracy, avg_prob); (0.0, 0.0) when there is nothing to evaluate.\n",
    "    \"\"\"\n",
    "    from tqdm import tqdm\n",
    "\n",
    "    if label_list is None:\n",
    "        label_list = labels\n",
    "    if answer_token_list is None:\n",
    "        # Score each prompt's own answer token. A single fixed id (formerly 12) measured the\n",
    "        # same vocabulary entry for every prompt, so avg_prob did not track the answers.\n",
    "        answer_token_list = [tokenizer.encode(label, add_special_tokens=False)[0] for label in label_list]\n",
    "\n",
    "    D = params.reshape(num_L, H, N)  # reshape also accepts non-contiguous inputs\n",
    "    total = min(len(prompt_list), len(label_list), len(answer_token_list))\n",
    "    if total == 0:\n",
    "        return 0.0, 0.0\n",
    "\n",
    "    correct = 0\n",
    "    prob = 0.0\n",
    "    # One progress bar, sized by the prompts actually evaluated (not the global `prompts`).\n",
    "    with tqdm(zip(prompt_list, label_list, answer_token_list), desc=\"Rotated Accuracy\", total=total) as pbar:\n",
    "        for count, (prompt, label, answer_token) in enumerate(pbar, start=1):\n",
    "            predicted_output, answer_token_prob = runRotatedModel(model, tokenizer, prompt, D, answer_token)\n",
    "            if predicted_output == label:\n",
    "                correct += 1\n",
    "            prob += answer_token_prob\n",
    "            pbar.set_description(f\"Rotated Accuracy: {correct / count}\")\n",
    "\n",
    "    accuracy = correct / total\n",
    "    print(f\"Accuracy: {accuracy}\")\n",
    "    return accuracy, prob / total"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "110\n"
     ]
    }
   ],
   "source": [
    "# Restore the saved BO history; the GP state dict is not needed for re-evaluation.\n",
    "train_X, train_Y, _, iteration = load_checkpoint(checkpoint_path)\n",
    "print(iteration)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pair each evaluated angle vector (row of train_X) with its stored objective value.\n",
    "combined_array = list(zip(train_X, train_Y))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Ascending by objective value: the smallest train_Y comes first.\n",
    "# NOTE(review): BoTorch maximises by convention; confirm the smallest value is really the best before calling these 'Top k'.\n",
    "sorted_array = sorted(combined_array, key=lambda x: x[1], reverse=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "15339"
      ]
     },
     "execution_count": 44,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Scratch check: encode(...)[0].item() yields a single token id for a one-token word.\n",
    "tokenizer.encode(\"hello\", return_tensors='pt', add_special_tokens=False)[0].item()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Angles of the first entry of sorted_array, reshaped to (layer, head, angle pair).\n",
    "D = sorted_array[0][0].view(num_L, H, N)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([32, 32, 64])"
      ]
     },
     "execution_count": 46,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Expect (num_L, H, N).\n",
    "D.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(0.0076)"
      ]
     },
     "execution_count": 47,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Spot check: mean entry of the rotation matrix for layer L[0], head 0.\n",
    "create_rope_rotation_matrix(D[0,0]).mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Top 1 accuracy: tensor([-0.4074], dtype=torch.float64)\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Rotated Accuracy: 100%|██████████| 200/200 [08:52<00:00,  2.66s/it]65s/it]             \n",
      "Rotated Accuracy: 0.755: 100%|██████████| 200/200 [08:52<00:00,  2.66s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy: 0.755\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Rotated Accuracy: 100%|██████████| 200/200 [07:47<00:00,  2.34s/it]2s/it]               \n",
      "Rotated Accuracy: 0.45: 100%|██████████| 200/200 [07:47<00:00,  2.34s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy: 0.45\n",
      "Top 2 accuracy: tensor([-0.3536], dtype=torch.float64)\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Rotated Accuracy: 100%|██████████| 200/200 [08:56<00:00,  2.68s/it]6s/it]              \n",
      "Rotated Accuracy: 0.63: 100%|██████████| 200/200 [08:56<00:00,  2.68s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy: 0.63\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Rotated Accuracy:  14%|█▍        | 29/200 [01:07<06:38,  2.33s/it]07<06:35,  2.31s/it] \n"
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[48], line 5\u001b[0m\n\u001b[1;32m      3\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mTop \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mi\u001b[38;5;250m \u001b[39m\u001b[38;5;241m+\u001b[39m\u001b[38;5;250m \u001b[39m\u001b[38;5;241m1\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m accuracy: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00msorted_array[i][\u001b[38;5;241m1\u001b[39m]\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m      4\u001b[0m accuracy, prob \u001b[38;5;241m=\u001b[39m objective(D, few_shot_prompts)\n\u001b[0;32m----> 5\u001b[0m accuracy, prob \u001b[38;5;241m=\u001b[39m \u001b[43mobjective\u001b[49m\u001b[43m(\u001b[49m\u001b[43mD\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mprompts\u001b[49m\u001b[43m)\u001b[49m\n",
      "Cell \u001b[0;32mIn[40], line 9\u001b[0m, in \u001b[0;36mobjective\u001b[0;34m(params, prompt_list)\u001b[0m\n\u001b[1;32m      7\u001b[0m pbar \u001b[38;5;241m=\u001b[39m tqdm(total\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mlen\u001b[39m(prompts))\n\u001b[1;32m      8\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m prompt, label \u001b[38;5;129;01min\u001b[39;00m tqdm(\u001b[38;5;28mzip\u001b[39m(prompt_list, labels), desc\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mRotated Accuracy\u001b[39m\u001b[38;5;124m\"\u001b[39m, total\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mlen\u001b[39m(prompts)):\n\u001b[0;32m----> 9\u001b[0m     predicted_output, answer_token_prob \u001b[38;5;241m=\u001b[39m \u001b[43mrunRotatedModel\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtokenizer\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mprompt\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mD\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m12\u001b[39;49m\u001b[43m)\u001b[49m\n\u001b[1;32m     10\u001b[0m     \u001b[38;5;66;03m# print(predicted_output, label)\u001b[39;00m\n\u001b[1;32m     11\u001b[0m     \u001b[38;5;28;01mif\u001b[39;00m predicted_output \u001b[38;5;241m==\u001b[39m label:\n",
      "Cell \u001b[0;32mIn[36], line 11\u001b[0m, in \u001b[0;36mrunRotatedModel\u001b[0;34m(model, tokenizer, prompt, D, answer_token)\u001b[0m\n\u001b[1;32m      9\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m layer \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mrange\u001b[39m(L[\u001b[38;5;241m0\u001b[39m], L[\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m]):\n\u001b[1;32m     10\u001b[0m     \u001b[38;5;28;01mfor\u001b[39;00m head \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mrange\u001b[39m(H):\n\u001b[0;32m---> 11\u001b[0m         rotaryMatrix \u001b[38;5;241m=\u001b[39m \u001b[43mcreate_rope_rotation_matrix\u001b[49m\u001b[43m(\u001b[49m\u001b[43mD\u001b[49m\u001b[43m[\u001b[49m\u001b[43mlayer\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m-\u001b[39;49m\u001b[43m \u001b[49m\u001b[43mL\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;241;43m0\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mhead\u001b[49m\u001b[43m]\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241m.\u001b[39mto(device)\n\u001b[1;32m     12\u001b[0m         list_fwd_hooks\u001b[38;5;241m.\u001b[39mappend((utils\u001b[38;5;241m.\u001b[39mget_act_name(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mz\u001b[39m\u001b[38;5;124m\"\u001b[39m, layer, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mattn\u001b[39m\u001b[38;5;124m\"\u001b[39m), partial(rotateMatrix, head_index\u001b[38;5;241m=\u001b[39mhead, rotaryMatrix\u001b[38;5;241m=\u001b[39mrotaryMatrix)))\n\u001b[1;32m     14\u001b[0m rotated_logits \u001b[38;5;241m=\u001b[39m model\u001b[38;5;241m.\u001b[39mrun_with_hooks(encoded_prompt, return_type\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mlogits\u001b[39m\u001b[38;5;124m\"\u001b[39m, fwd_hooks\u001b[38;5;241m=\u001b[39mlist_fwd_hooks)\n",
      "Cell \u001b[0;32mIn[30], line 23\u001b[0m, in \u001b[0;36mcreate_rope_rotation_matrix\u001b[0;34m(angles)\u001b[0m\n\u001b[1;32m     21\u001b[0m     rope_matrix[\u001b[38;5;241m2\u001b[39m \u001b[38;5;241m*\u001b[39m i, \u001b[38;5;241m2\u001b[39m \u001b[38;5;241m*\u001b[39m i \u001b[38;5;241m+\u001b[39m \u001b[38;5;241m1\u001b[39m] \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m-\u001b[39msin_angle\n\u001b[1;32m     22\u001b[0m     rope_matrix[\u001b[38;5;241m2\u001b[39m \u001b[38;5;241m*\u001b[39m i \u001b[38;5;241m+\u001b[39m \u001b[38;5;241m1\u001b[39m, \u001b[38;5;241m2\u001b[39m \u001b[38;5;241m*\u001b[39m i] \u001b[38;5;241m=\u001b[39m sin_angle\n\u001b[0;32m---> 23\u001b[0m     rope_matrix[\u001b[38;5;241m2\u001b[39m \u001b[38;5;241m*\u001b[39m i \u001b[38;5;241m+\u001b[39m \u001b[38;5;241m1\u001b[39m, \u001b[38;5;241m2\u001b[39m \u001b[38;5;241m*\u001b[39m i \u001b[38;5;241m+\u001b[39m \u001b[38;5;241m1\u001b[39m] \u001b[38;5;241m=\u001b[39m cos_angle\n\u001b[1;32m     25\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m rope_matrix\n",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "# Re-evaluate the first TOP_K entries of sorted_array on few-shot and zero-shot prompts.\n",
    "# sorted_array[k][1] is the stored BO objective value (train_Y), not a measured accuracy.\n",
    "TOP_K = 5\n",
    "top_k_results = []\n",
    "for rank, (D, objective_value) in enumerate(sorted_array[:TOP_K], start=1):\n",
    "    print(f\"Top {rank} objective value: {objective_value}\")\n",
    "    few_shot_accuracy, few_shot_prob = objective(D, few_shot_prompts)\n",
    "    zero_shot_accuracy, zero_shot_prob = objective(D, prompts)\n",
    "    top_k_results.append({\n",
    "        'rank': rank,\n",
    "        'objective_value': objective_value,\n",
    "        'few_shot_accuracy': few_shot_accuracy,\n",
    "        'few_shot_prob': few_shot_prob,\n",
    "        'zero_shot_accuracy': zero_shot_accuracy,\n",
    "        'zero_shot_prob': zero_shot_prob,\n",
    "    })"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# [low, high] angle range for the BoTorch search box.\n",
    "# NOTE(review): differs from the +/-pi/8 range encoded in fileName -- confirm which range is intended.\n",
    "angles = [0, 5]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Search box for the angles: row 0 holds the lower bound, row 1 the upper bound,\n",
    "# repeated for each of the num_L * H * N parameters.\n",
    "bounds = torch.tensor(angles, dtype=dtype).unsqueeze(-1).repeat(1, num_L * H * N).to(device)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Expect (2, num_L * H * N): lower-bound row and upper-bound row.\n",
    "bounds.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 10 uniform random starting points inside `bounds`, drawn in bounds' dtype from a seeded\n",
    "# generator so reruns reproduce the same initial design.\n",
    "init_rng = torch.Generator(device=bounds.device).manual_seed(0)\n",
    "initial_points = torch.rand((10, bounds.size(1)), generator=init_rng, dtype=bounds.dtype, device=bounds.device) * (bounds[1] - bounds[0]) + bounds[0]\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inspect the first initial point.\n",
    "initial_points[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import random\n",
    "\n",
    "with open('dataset/modified_arithmetic/train.jsonl', 'r') as f:\n",
    "    train = [json.loads(line) for line in f]\n",
    "# Fixed-seed shuffle: the permutation is reproducible for a given input file, and the\n",
    "# global `random` state is left untouched.\n",
    "random.Random(0).shuffle(train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import random\n",
    "\n",
    "with open('dataset/modified_arithmetic/validation.jsonl', 'r') as f:\n",
    "    valid = [json.loads(line) for line in f]\n",
    "# Fixed-seed shuffle: reproducible for a given input file; global `random` state untouched.\n",
    "random.Random(0).shuffle(valid)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Writing back rewrites the source dataset in place. The prompt-building cell selects\n",
    "# evaluation items by file position, so this silently changes later runs -- opt in explicitly.\n",
    "OVERWRITE_DATASET = False\n",
    "if OVERWRITE_DATASET:\n",
    "    with open('dataset/modified_arithmetic/train.jsonl', 'w') as f:\n",
    "        for line in train:\n",
    "            f.write(json.dumps(line) + '\\n')\n",
    "else:\n",
    "    print('OVERWRITE_DATASET is False: dataset/modified_arithmetic/train.jsonl left unchanged')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Writing back rewrites the source validation file in place; opt in explicitly so that\n",
    "# running all cells does not mutate the notebook's inputs.\n",
    "OVERWRITE_DATASET = False\n",
    "if OVERWRITE_DATASET:\n",
    "    with open('dataset/modified_arithmetic/validation.jsonl', 'w') as f:\n",
    "        for line in valid:\n",
    "            f.write(json.dumps(line) + '\\n')\n",
    "else:\n",
    "    print('OVERWRITE_DATASET is False: dataset/modified_arithmetic/validation.jsonl left unchanged')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "ALTI",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
