{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "\n",
    "# Read the JSONL training split: one JSON record per line.\n",
    "with open(\"dataset/modified_arithmetic/train.jsonl\") as train_file:\n",
    "    train = list(map(json.loads, train_file))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Split each record into its prompt text and its first target string.\n",
    "# NOTE: these names shadow the built-in `input`; later cells rely on them, so they are kept.\n",
    "input = [record[\"inputs\"] for record in train]\n",
    "output = [record[\"targets\"][0] for record in train]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformer_lens import HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache\n",
    "from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer\n",
    "import numpy as np\n",
    "from skopt import gp_minimize\n",
    "from skopt.space import Real\n",
    "import torch\n",
    "import transformer_lens.utils as utils\n",
    "from functools import partial\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def loadTransformerLensModel(modelPath, device='cpu'):\n",
    "    \"\"\"\n",
    "    Load a Hugging Face causal LM and wrap it in a TransformerLens HookedTransformer.\n",
    "\n",
    "    Parameters:\n",
    "    modelPath (str): model id or local path, passed to both the tokenizer and the model loaders.\n",
    "    device (str): device handed to HookedTransformer.from_pretrained; defaults to 'cpu',\n",
    "        the value that used to be hard-coded.\n",
    "\n",
    "    Returns:\n",
    "    tuple: (HookedTransformer model, tokenizer).\n",
    "    \"\"\"\n",
    "    tokenizer = AutoTokenizer.from_pretrained(modelPath)\n",
    "    hf_model = AutoModelForCausalLM.from_pretrained(modelPath, low_cpu_mem_usage=True)\n",
    "    # fold_ln / center_writing_weights / center_unembed are switched off, presumably so the\n",
    "    # hooked weights stay as close as possible to the original checkpoint.\n",
    "    model = HookedTransformer.from_pretrained(\n",
    "        modelPath,\n",
    "        hf_model=hf_model,\n",
    "        device=device,\n",
    "        fold_ln=False,\n",
    "        center_writing_weights=False,\n",
    "        center_unembed=False,\n",
    "        tokenizer=tokenizer,\n",
    "    )\n",
    "    return model, tokenizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "74a4dd99e84e41b188472e1629459e27",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loaded pretrained model meta-llama/Meta-Llama-3-8B into HookedTransformer\n"
     ]
    }
   ],
   "source": [
    "# Hugging Face id of the checkpoint under study.\n",
    "MODEL_PATH = \"meta-llama/Meta-Llama-3-8B\"\n",
    "model, tokenizer = loadTransformerLensModel(modelPath=MODEL_PATH)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Moving model to device:  cuda\n"
     ]
    }
   ],
   "source": [
    "# Prefer the GPU, but fall back to CPU instead of failing on machines without CUDA.\n",
    "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
    "model = model.to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def create_rope_rotation_matrix(angles, dtype=None, device=None):\n",
    "    \"\"\"\n",
    "    Create a block-diagonal RoPE-style rotation matrix from one angle per dimension pair.\n",
    "\n",
    "    Angle k rotates coordinates (2k, 2k + 1) by the 2x2 block\n",
    "    [[cos a_k, -sin a_k], [sin a_k, cos a_k]], so n angles give a (2n, 2n) matrix.\n",
    "\n",
    "    Parameters:\n",
    "    angles (array-like): 1D tensor, NumPy array or sequence of rotation angles, one per dimension pair.\n",
    "    dtype (torch.dtype, optional): dtype of the result; defaults to torch.get_default_dtype().\n",
    "    device (torch.device or str, optional): device of the result; defaults to the device of `angles`.\n",
    "\n",
    "    Returns:\n",
    "    torch.Tensor: a (2n, 2n) rotation matrix for n angles.\n",
    "\n",
    "    Raises:\n",
    "    ValueError: if `angles` is not one-dimensional.\n",
    "    \"\"\"\n",
    "    if dtype is None:\n",
    "        dtype = torch.get_default_dtype()\n",
    "    # as_tensor also accepts NumPy arrays (e.g. slices of the optimizer's parameter array).\n",
    "    angles = torch.as_tensor(angles, dtype=dtype, device=device)\n",
    "    if angles.dim() != 1:\n",
    "        raise ValueError(\"The angles tensor should be one-dimensional (one angle per dimension pair).\")\n",
    "\n",
    "    n_pairs = angles.shape[0]\n",
    "    cos_angles = torch.cos(angles)\n",
    "    sin_angles = torch.sin(angles)\n",
    "    even = torch.arange(0, 2 * n_pairs, 2, device=angles.device)\n",
    "    odd = even + 1\n",
    "\n",
    "    # Fill every 2x2 block at once with advanced indexing instead of a Python loop.\n",
    "    rope_matrix = torch.zeros((2 * n_pairs, 2 * n_pairs), dtype=dtype, device=angles.device)\n",
    "    rope_matrix[even, even] = cos_angles\n",
    "    rope_matrix[even, odd] = -sin_angles\n",
    "    rope_matrix[odd, even] = sin_angles\n",
    "    rope_matrix[odd, odd] = cos_angles\n",
    "    return rope_matrix\n",
    "\n",
    "\n",
    "# Example usage: two angles -> a 4x4 matrix made of two 2x2 rotation blocks.\n",
    "angles = torch.tensor([torch.pi / 2, torch.pi / 2])\n",
    "rope_matrix = create_rope_rotation_matrix(angles)\n",
    "print(rope_matrix)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Search-space sizes, read from the loaded model instead of being hard-coded\n",
    "# (they were previously hard-coded as L = 32, H = 32, N = 64 for Llama-3-8B).\n",
    "L = model.cfg.n_layers   # transformer layers\n",
    "H = model.cfg.n_heads    # attention heads per layer\n",
    "assert model.cfg.d_head % 2 == 0, \"rotations act on pairs of head dimensions\"\n",
    "N = model.cfg.d_head // 2  # one rotation angle per pair of head dimensions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One angle in [-pi, pi] per (layer, head, dimension pair): L * H * N dimensions in total.\n",
    "# NOTE(review): for Llama-3-8B this is 65,536 dimensions; a Gaussian-process search with\n",
    "# a few dozen evaluations is unlikely to explore it meaningfully -- consider tying angles\n",
    "# across heads/layers to shrink the space.\n",
    "search_space = [Real(-np.pi, np.pi) for _ in range(L * H * N)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "import random\n",
    "\n",
    "nShots = 10              # examples [0, nShots) are reserved for the few-shot prefix\n",
    "N_TEMPLATE_SHOTS = 4     # how many of the reserved examples actually go into the prefix\n",
    "TEMPLATE_SEED = 0        # fixed seed so the prefix is identical on every run\n",
    "\n",
    "nShots_example = list(zip(input[:nShots], output[:nShots]))\n",
    "random.Random(TEMPLATE_SEED).shuffle(nShots_example)\n",
    "Template = \"\".join(\n",
    "    f\"Input: {example_in}\\noutput: {example_out}\\n\\n\"\n",
    "    for example_in, example_out in nShots_example[:N_TEMPLATE_SHOTS]\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Input: In the following lines, the symbol -> represents a simple mathematical operation.\n",
      "583 + 721 -> 1304\n",
      "731 + 591 -> 1322\n",
      "499 + 594 -> 1093\n",
      "553 + 272 -> 825\n",
      "702 + 172 -> 874\n",
      "972 + 6 ->\n",
      "output: 978\n",
      "\n",
      "Input: In the following lines, the symbol -> represents a simple mathematical operation.\n",
      "297 + 434 -> 731\n",
      "695 + 690 -> 1385\n",
      "197 + 127 -> 324\n",
      "531 + 983 -> 1514\n",
      "410 + 933 -> 1343\n",
      "839 + 64 ->\n",
      "output: 903\n",
      "\n",
      "Input: In the following lines, the symbol -> represents a simple mathematical operation.\n",
      "253 + 226 -> 479\n",
      "111 + 509 -> 620\n",
      "472 + 98 -> 570\n",
      "152 + 860 -> 1012\n",
      "913 + 895 -> 1808\n",
      "877 + 337 ->\n",
      "output: 1214\n",
      "\n",
      "Input: In the following lines, the symbol -> represents a simple mathematical operation.\n",
      "605 + 661 -> 1266\n",
      "757 + 288 -> 1045\n",
      "413 + 720 -> 1133\n",
      "637 + 234 -> 871\n",
      "415 + 937 -> 1352\n",
      "498 + 973 ->\n",
      "output: 1471\n",
      "\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Show the few-shot prefix that is prepended to every evaluation prompt.\n",
    "print(Template)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "number_of_examples = 100\n",
    "# Evaluate on the records right after the few-shot pool, so evaluation records are\n",
    "# disjoint from the records used in the prefix.\n",
    "eval_range = slice(nShots, nShots + number_of_examples)\n",
    "prompts = [Template + f\"Input: {example}\\noutput: \" for example in input[eval_range]]\n",
    "labels = output[eval_range]\n",
    "assert len(prompts) == len(labels) == number_of_examples, \"dataset has fewer examples than requested\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "# First evaluation prompt, kept for manual inspection; the global `prompt` does not\n",
    "# appear to be read by later cells (the helpers take prompts as arguments).\n",
    "prompt = prompts[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "def runCache(model, input_id):\n",
    "    \"\"\"\n",
    "    Run `model` on `input_id` and capture the attention-key activation of each layer.\n",
    "\n",
    "    Parameters:\n",
    "    model: TransformerLens HookedTransformer; its hooks are reset before the run.\n",
    "    input_id (torch.Tensor): token ids passed to `model.run_with_hooks`.\n",
    "\n",
    "    Returns:\n",
    "    dict: hook name -> detached CPU copy of the activation at\n",
    "    `utils.get_act_name(\"k\", layer, \"attn\")` for every layer in range(L).\n",
    "    \"\"\"\n",
    "    cache = {}\n",
    "\n",
    "    def storeHookCache(value, hook):\n",
    "        # Copy straight to CPU with clone(); the former NumPy round-trip fails for\n",
    "        # dtypes NumPy does not support (e.g. bfloat16).\n",
    "        cache[hook.name] = value.detach().cpu().clone()\n",
    "\n",
    "    fwd_hooks_list = [(utils.get_act_name(\"k\", layer, \"attn\"), storeHookCache) for layer in range(L)]\n",
    "    model.reset_hooks()\n",
    "    # Activations are only recorded, so no autograd graph is needed.\n",
    "    with torch.no_grad():\n",
    "        model.run_with_hooks(input_id, return_type=None, fwd_hooks=fwd_hooks_list)\n",
    "    return cache"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [],
   "source": [
    "def runDefaultModel(model, tokenizer, prompt):\n",
    "    \"\"\"\n",
    "    Greedily predict the next token for `prompt` with no hooks attached.\n",
    "\n",
    "    Parameters:\n",
    "    model: TransformerLens HookedTransformer (hooks are reset before the forward pass).\n",
    "    tokenizer: tokenizer matching `model`.\n",
    "    prompt (str): text to complete; encoded with the tokenizer's default special tokens.\n",
    "\n",
    "    Returns:\n",
    "    str: the decoded argmax token at the last position, special tokens skipped.\n",
    "\n",
    "    NOTE(review): only one token is generated, so a label the tokenizer splits into\n",
    "    several tokens (possibly longer numbers -- TODO confirm) can never match exactly.\n",
    "    \"\"\"\n",
    "    input_id = tokenizer.encode(prompt, return_tensors='pt').to(device)\n",
    "    model.reset_hooks()\n",
    "    # Pure inference: avoid building an autograd graph over the whole forward pass.\n",
    "    with torch.no_grad():\n",
    "        logits = model(input_id)\n",
    "    predicted_token = torch.argmax(logits[:, -1, :], dim=-1)[0]\n",
    "    return str(tokenizer.decode(predicted_token, skip_special_tokens=True))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "def rotateMatrix(clean_head_vector, hook, head_index, dimension_array):\n",
    "    \"\"\"\n",
    "    TransformerLens hook: rotate one head's vectors pairwise by the given angles, in place.\n",
    "\n",
    "    Coordinates (2k, 2k + 1) of head `head_index` are rotated by angle dimension_array[k],\n",
    "    i.e. every vector v becomes R @ v where R is block-diagonal with 2x2 blocks\n",
    "    [[cos a_k, -sin a_k], [sin a_k, cos a_k]]. The rotation is applied elementwise\n",
    "    (no d_head x d_head matrix is built) on the activation's own device and dtype.\n",
    "\n",
    "    Parameters:\n",
    "    clean_head_vector (torch.Tensor): activation indexed as [batch, pos, head_index, d_head].\n",
    "    hook: the TransformerLens hook point (unused).\n",
    "    head_index (int): which head to rotate.\n",
    "    dimension_array (array-like): d_head // 2 angles (tensor, NumPy array or sequence).\n",
    "\n",
    "    Returns:\n",
    "    torch.Tensor: `clean_head_vector`, modified in place.\n",
    "\n",
    "    Raises:\n",
    "    ValueError: if the number of angles is not d_head // 2.\n",
    "    \"\"\"\n",
    "    head = clean_head_vector[:, :, head_index, :]\n",
    "    d_head = head.shape[-1]\n",
    "    # Build the angles on the activation's device (no CPU/GPU mismatch); float32 keeps\n",
    "    # the trigonometry accurate even for half-precision activations.\n",
    "    angles = torch.as_tensor(dimension_array, dtype=torch.float32, device=head.device).reshape(-1)\n",
    "    if 2 * angles.numel() != d_head:\n",
    "        raise ValueError(f\"Expected {d_head // 2} angles for d_head={d_head}, got {angles.numel()}.\")\n",
    "    cos_angles = torch.cos(angles)\n",
    "    sin_angles = torch.sin(angles)\n",
    "\n",
    "    even = head[..., 0::2].float()\n",
    "    odd = head[..., 1::2].float()\n",
    "    # Compute into a fresh tensor first so the in-place write below cannot clobber inputs.\n",
    "    rotated = torch.empty(head.shape, dtype=torch.float32, device=head.device)\n",
    "    rotated[..., 0::2] = cos_angles * even - sin_angles * odd\n",
    "    rotated[..., 1::2] = sin_angles * even + cos_angles * odd\n",
    "    clean_head_vector[:, :, head_index, :] = rotated.to(head.dtype)\n",
    "    return clean_head_vector"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [],
   "source": [
    "def runRotatedModel(model, tokenizer, prompt, D):\n",
    "    \"\"\"\n",
    "    Greedily predict the next token for `prompt` with every head's `z` activation rotated.\n",
    "\n",
    "    Parameters:\n",
    "    model: TransformerLens HookedTransformer.\n",
    "    tokenizer: tokenizer matching `model`.\n",
    "    prompt (str): text to complete.\n",
    "    D (np.ndarray): angles indexed as D[layer, head] for layer < L and head < H; each\n",
    "        entry is handed to `rotateMatrix` as `dimension_array`.\n",
    "\n",
    "    Returns:\n",
    "    str: the decoded argmax token at the last position, special tokens skipped.\n",
    "    \"\"\"\n",
    "    # Tokenize exactly like runDefaultModel (default special tokens) so baseline and\n",
    "    # rotated accuracies are measured on identical inputs.\n",
    "    encoded_prompt = tokenizer.encode(prompt, return_tensors='pt').to(device)\n",
    "\n",
    "    list_fwd_hooks = [\n",
    "        (utils.get_act_name(\"z\", layer, \"attn\"),\n",
    "         partial(rotateMatrix, head_index=head, dimension_array=D[layer, head]))\n",
    "        for layer in range(L)\n",
    "        for head in range(H)\n",
    "    ]\n",
    "\n",
    "    with torch.no_grad():\n",
    "        rotated_logits = model.run_with_hooks(encoded_prompt, return_type=\"logits\", fwd_hooks=list_fwd_hooks)\n",
    "\n",
    "    # Reduce over the vocabulary (last dim) and take batch element 0; reducing over\n",
    "    # dim=0 would collapse the batch axis and yield one id per vocabulary entry.\n",
    "    predicted_token = torch.argmax(rotated_logits[:, -1, :], dim=-1)[0]\n",
    "    return tokenizer.decode(predicted_token, skip_special_tokens=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [],
   "source": [
    "def getOriginalAccuracy(model, tokenizer, prompts, labels):\n",
    "    \"\"\"\n",
    "    Exact-match accuracy of the unhooked model on `prompts`.\n",
    "\n",
    "    A prediction is correct when it equals its label after both are converted with\n",
    "    str() and stripped of surrounding whitespace -- the same rule `objective` applies\n",
    "    to the rotated model, so the two accuracies are comparable.\n",
    "\n",
    "    Returns:\n",
    "    float: fraction of prompts predicted correctly (0.0 for an empty prompt list).\n",
    "    \"\"\"\n",
    "    from tqdm import tqdm\n",
    "\n",
    "    if len(prompts) == 0:\n",
    "        return 0.0\n",
    "    correct = 0\n",
    "    for prompt, label in tqdm(zip(prompts, labels), desc=\"Original Accuracy\", total=len(prompts)):\n",
    "        prediction = runDefaultModel(model, tokenizer, prompt)\n",
    "        if str(prediction).strip() == str(label).strip():\n",
    "            correct += 1\n",
    "    return correct / len(prompts)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "def objective(params):\n",
    "    \"\"\"\n",
    "    Objective for gp_minimize: negative exact-match accuracy of the rotated model.\n",
    "\n",
    "    Parameters:\n",
    "    params (sequence of float): flat vector of L * H * N angles, one per search-space dimension.\n",
    "\n",
    "    Returns:\n",
    "    float: -(fraction of prompts whose rotated prediction equals the label after\n",
    "    str() and strip()); 0.0 when there is nothing to evaluate.\n",
    "    \"\"\"\n",
    "    from tqdm import tqdm\n",
    "\n",
    "    # D[layer, head] is the length-N angle vector for that head.\n",
    "    D = np.asarray(params).reshape((L, H, N))\n",
    "    matches = [\n",
    "        str(runRotatedModel(model, tokenizer, prompt, D)).strip() == str(label).strip()\n",
    "        for prompt, label in tqdm(zip(prompts, labels), desc=\"Rotated Accuracy\", total=len(prompts))\n",
    "    ]\n",
    "    if not matches:\n",
    "        return 0.0\n",
    "    # Negated because gp_minimize minimizes and we want to maximize accuracy.\n",
    "    return -sum(matches) / len(matches)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Original Accuracy: 100%|██████████| 100/100 [00:37<00:00,  2.66it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Normal Accuracy: 0.54\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "# Baseline: accuracy of the unmodified model on the evaluation prompts.\n",
    "normalAccuracy = getOriginalAccuracy(model, tokenizer, prompts=prompts, labels=labels)\n",
    "print(\"Normal Accuracy: \" + str(normalAccuracy))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): every call runs all evaluation prompts with L * H hooks on the full model,\n",
    "# and the GP works in L * H * N dimensions, so this cell is very slow; consider shrinking\n",
    "# n_calls or the search space before re-running.\n",
    "result = gp_minimize(\n",
    "    func=objective,          # negative rotated accuracy\n",
    "    dimensions=search_space, # one Real(-pi, pi) per rotation angle\n",
    "    n_calls=50,              # number of objective evaluations\n",
    "    n_initial_points=10,     # random evaluations before the GP is fitted\n",
    "    acq_func='EI',           # Expected Improvement acquisition\n",
    "    random_state=0,          # fixed seed so the random initial points are reproducible\n",
    "    verbose=True,            # print progress for each call\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Save the optimizer's result.x reshaped like D: (layers, heads, angles per head).\n",
    "optimized_params = np.asarray(result.x).reshape(L, H, N)\n",
    "\n",
    "import pickle\n",
    "with open('optimized_params.pkl', mode='wb') as params_file:\n",
    "    pickle.dump(optimized_params, params_file)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "ALTI",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
