{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "1833b961d5fb4b1db3bcd5c1b455f9cc",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loaded pretrained model meta-llama/Meta-Llama-3-8B into HookedTransformer\n",
      "Moving model to device:  cuda\n"
     ]
    }
   ],
   "source": [
    "from transformer_lens import HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache\n",
    "from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer\n",
    "import numpy as np\n",
    "from skopt import gp_minimize\n",
    "from skopt.space import Real\n",
    "import torch\n",
    "import transformer_lens.utils as utils\n",
    "from functools import partial\n",
    "import torch\n",
    "from botorch.models import SingleTaskGP\n",
    "# from botorch.fit import fit_gpytorch_model\n",
    "from botorch.optim import optimize_acqf\n",
    "from botorch.acquisition import ExpectedImprovement, LogExpectedImprovement, qLogExpectedImprovement\n",
    "from botorch.utils.sampling import draw_sobol_samples\n",
    "from gpytorch.mlls import ExactMarginalLogLikelihood\n",
    "import torch\n",
    "from botorch.models import SingleTaskGP\n",
    "from botorch.models.transforms import Normalize, Standardize\n",
    "from botorch.fit import fit_gpytorch_mll\n",
    "from gpytorch.mlls import ExactMarginalLogLikelihood\n",
    "# %%\n",
    "def loadTransformerLensModel(modelPath):\n",
    "    tokenizer = AutoTokenizer.from_pretrained(modelPath)\n",
    "    hf_model = AutoModelForCausalLM.from_pretrained(modelPath, low_cpu_mem_usage=True)\n",
    "    model = HookedTransformer.from_pretrained(modelPath, hf_model=hf_model, device='cpu', fold_ln=False, center_writing_weights=False, center_unembed=False, tokenizer=tokenizer)\n",
    "\n",
    "    return model, tokenizer\n",
    "\n",
    "# %%\n",
    "MODEL_PATH = 'meta-llama/Meta-Llama-3-8B'\n",
    "# meta-llama/Llama-2-7b-hf\n",
    "model, tokenizer = loadTransformerLensModel(MODEL_PATH)\n",
    "dtype = torch.float64\n",
    "\n",
    "# %%\n",
    "device = 'cuda' #cuda\n",
    "model = model.to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "fileName = '_scale_rotate_layer_16_32_angle_-0.2617993877991494_0.2617993877991494_scale_0.5_2'\n",
    "few_shot_prompts = f'prompts/prompts{fileName}.pkl'\n",
    "few_shot_labels = f'labels/labels{fileName}.pkl'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "checkpoint_path = f'checkpoint/botorch_checkpoint_{fileName}.pkl'\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "L = [16,32]\n",
    "H = 32\n",
    "N = 64\n",
    "num_L = L[-1] - L[0]\n",
    "# angles = [0, torch.pi / 4]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import pickle\n",
    "\n",
    "def load_checkpoint(filepath):\n",
    "    # Load the checkpoint dictionary\n",
    "    with open(filepath, 'rb') as f:\n",
    "        checkpoint = pickle.load(f)\n",
    "    \n",
    "    train_X = checkpoint['train_X']\n",
    "    train_Y = checkpoint['train_Y']\n",
    "    gp_state_dict = checkpoint['gp_state_dict']\n",
    "    iteration = checkpoint['iteration']\n",
    "    \n",
    "    return train_X, train_Y, gp_state_dict, iteration"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "with open('dataset/modified_arithmetic/train.jsonl', 'r') as f:\n",
    "    train = [json.loads(line) for line in f]\n",
    "\n",
    "# %%\n",
    "input = [d['inputs'] for d in train]\n",
    "output = [d['targets'][0] for d in train]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "def create_rope_rotation_matrix(angles):\n",
    "    \"\"\"\n",
    "    Create a RoPE (Rotary Position Embedding) rotation matrix as a tensor for a given vector of angles.\n",
    "\n",
    "    Parameters:\n",
    "    angles (torch.Tensor): A 1D PyTorch tensor of rotation angles for each dimension pair.\n",
    "\n",
    "    Returns:\n",
    "    torch.Tensor: A 2D PyTorch tensor representing the RoPE rotation matrix.\n",
    "    \"\"\"\n",
    "    n = angles.shape[0]  # Number of angles\n",
    "    d = 2 * n  # Dimension of the square matrix (embedding dimension must be even)\n",
    "    rope_matrix = torch.zeros((d, d))\n",
    "\n",
    "    # Fill the rotation matrix with 2x2 rotation matrices for each dimension pair\n",
    "    for i, angle in enumerate(angles):\n",
    "        cos_angle = torch.cos(angle)\n",
    "        sin_angle = torch.sin(angle)\n",
    "        # Set the 2x2 rotation submatrix\n",
    "        rope_matrix[2 * i, 2 * i] = cos_angle\n",
    "        rope_matrix[2 * i, 2 * i + 1] = -sin_angle\n",
    "        rope_matrix[2 * i + 1, 2 * i] = sin_angle\n",
    "        rope_matrix[2 * i + 1, 2 * i + 1] = cos_angle\n",
    "\n",
    "    return rope_matrix\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "nShots = 100\n",
    "# nShots_example = [(i, o) for i, o in zip(input[:nShots], output[:nShots])]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# import random\n",
    "# random.shuffle(nShots_example) \n",
    "import pickle\n",
    "with open(few_shot_prompts, 'rb') as f:\n",
    "    nShots_example = pickle.load(f)\n",
    "with open(few_shot_labels, 'rb') as f:\n",
    "    nShots_labels = pickle.load(f)\n",
    "Template = \"\"\"\"\"\"\n",
    "for i in range(6):\n",
    "    Template += f\"{nShots_example[i]}{nShots_labels[i]}\\n\\n\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Input: In the following lines, the symbol -> represents a simple mathematical operation.\n",
      "74 * 3 -> 222\n",
      "78 * 5 -> 390\n",
      "93 * 50 -> 4650\n",
      "61 * 56 -> 3416\n",
      "65 * 78 -> 5070\n",
      "74 * 7 ->\n",
      "output: 518\n",
      "\n",
      "Input: In the following lines, the symbol -> represents a simple mathematical operation.\n",
      "960 + 208 -> 1168\n",
      "762 + 896 -> 1658\n",
      "116 + 627 -> 743\n",
      "748 + 477 -> 1225\n",
      "104 + 494 -> 598\n",
      "770 + 45 ->\n",
      "output: 815\n",
      "\n",
      "Input: In the following lines, the symbol -> represents a simple mathematical operation.\n",
      "98 * 13 -> 1274\n",
      "21 * 50 -> 1050\n",
      "22 * 47 -> 1034\n",
      "70 * 49 -> 3430\n",
      "88 * 12 -> 1056\n",
      "84 * 9 ->\n",
      "output: 756\n",
      "\n",
      "Input: In the following lines, the symbol -> represents a simple mathematical operation.\n",
      "503 - 685 -> -181\n",
      "386 - 969 -> -582\n",
      "869 - 832 -> 38\n",
      "887 - 602 -> 286\n",
      "749 - 775 -> -25\n",
      "690 - 493 ->\n",
      "output: 198\n",
      "\n",
      "Input: In the following lines, the symbol -> represents a simple mathematical operation.\n",
      "60 * 70 -> 4201\n",
      "45 * 31 -> 1396\n",
      "14 * 43 -> 603\n",
      "16 * 40 -> 641\n",
      "53 * 43 -> 2280\n",
      "97 * 4 ->\n",
      "output: 389\n",
      "\n",
      "Input: In the following lines, the symbol -> represents a simple mathematical operation.\n",
      "710 - 142 -> 569\n",
      "916 - 914 -> 3\n",
      "827 - 604 -> 224\n",
      "472 - 416 -> 57\n",
      "304 - 735 -> -430\n",
      "496 - 290 ->\n",
      "output: 207\n",
      "\n",
      "\n"
     ]
    }
   ],
   "source": [
    "print(Template)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "number_of_examples = 200\n",
    "prompts = []\n",
    "few_shot_prompts = []\n",
    "labels = []\n",
    "for i in range(nShots, len(input)):\n",
    "    \n",
    "    if(len(tokenizer.encode(output[i], return_tensors='pt', add_special_tokens=False)[0]) == 1):\n",
    "        prompts.append(f\"Input: {input[i]}\\noutput: \")\n",
    "        few_shot_prompts.append(Template + f\"Input: {input[i]}\\noutput: \")\n",
    "        labels.append(output[i])\n",
    "        # if(len(prompts) == number_of_examples):\n",
    "            # break\n",
    "    if(len(prompts) == number_of_examples):\n",
    "        break\n",
    "    # print(prompt)\n",
    "    # break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "def runDefaultModel(model, tokenizer, prompt):\n",
    "    # print(prompt)\n",
    "    input_id = tokenizer.encode(prompt, return_tensors='pt').to(device)\n",
    "    model.reset_hooks()\n",
    "    output = model(input_id)\n",
    "    predicted_token = torch.argmax(output[:, -1, :], dim=1)[0]\n",
    "    # breakpoint()\n",
    "    return str(tokenizer.decode(predicted_token, skip_special_tokens=True)), predicted_token.item()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "def getOriginalAccuracy(model, tokenizer, prompts, labels):\n",
    "    correct = 0\n",
    "    from tqdm import tqdm\n",
    "    pbar = tqdm(total=len(prompts))\n",
    "    answer_tokens = []\n",
    "    prob = 0\n",
    "    count = 0\n",
    "    for prompt, label in tqdm(zip(prompts, labels), desc=\"Original Accuracy\", total=len(prompts)):\n",
    "        output, answer_token = runDefaultModel(model, tokenizer, prompt)\n",
    "        if output == label:\n",
    "            correct += 1\n",
    "            answer_tokens.append(answer_token)\n",
    "        else:\n",
    "            answer_tokens.append(tokenizer.encode(label, return_tensors='pt', add_special_tokens=False)[0].item())\n",
    "        count += 1\n",
    "        \n",
    "        pbar.set_description(f\"Original Accuracy: {correct / count}\")\n",
    "        pbar.update(1)\n",
    "    return correct / len(prompts), answer_tokens\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "from jaxtyping import Float, Int\n",
    "\n",
    "def rotateScaleMatrix(\n",
    "        clean_head_vector: Float[torch.Tensor, \"batch pos head_index d_head\"],\n",
    "        hook,\n",
    "        head_index,\n",
    "        rotaryMatrix,\n",
    "        scaleFactor):\n",
    "    # assert(clean_head_vector[:, :, head_index, :] == clean_head_vector[:, :, head_index, :] @ rotaryMatrix).all()\n",
    "    # breakpoint()\n",
    "    clean_head_vector[:, :, head_index, :] = torch.mul(clean_head_vector[:, :, head_index, :] @ rotaryMatrix, scaleFactor)\n",
    "    return clean_head_vector\n",
    "# encoded_prompt = tokenizer.encode(prompt, add_special_tokens=False, return_tensors='pt')\n",
    "# encoded_prompt = encoded_prompt.to(device)\n",
    "\n",
    "# cache = runCache(model, encoded_prompt)\n",
    "# breakpoint()\n",
    "# %%\n",
    "def runRotateScaleModel(model, tokenizer, prompt, D, scale_list, answer_token):\n",
    "    encoded_prompt = tokenizer.encode(prompt, add_special_tokens=False, return_tensors='pt')\n",
    "    encoded_prompt = encoded_prompt.to(device)\n",
    "    \n",
    "    # cache = runCache(model, encoded_prompt)\n",
    "    \n",
    "    \n",
    "    list_fwd_hooks = []\n",
    "    for layer in range(L[0], L[-1]):\n",
    "        for head in range(H):\n",
    "            rotaryMatrix = create_rope_rotation_matrix(D[layer - L[0], head]).to(device)\n",
    "            list_fwd_hooks.append((utils.get_act_name(\"z\", layer, \"attn\"), partial(rotateScaleMatrix, head_index=head, rotaryMatrix=rotaryMatrix, scaleFactor=scale_list[layer - L[0], head])))\n",
    "    \n",
    "    rotated_logits = model.run_with_hooks(encoded_prompt, return_type=\"logits\", fwd_hooks=list_fwd_hooks)\n",
    "    \n",
    "    predicted_token = torch.argmax(rotated_logits[:, -1, :], dim=1)[0]\n",
    "    answer_token_prob = torch.nn.functional.softmax(rotated_logits[:, -1, :], dim=1)[0, answer_token].item()\n",
    "    token = tokenizer.decode(predicted_token, skip_special_tokens=True)\n",
    "    return token, answer_token_prob"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Original Accuracy: 100%|██████████| 200/200 [01:38<00:00,  2.03it/s]0it/s]              \n",
      "Original Accuracy: 0.62: 100%|██████████| 200/200 [01:38<00:00,  2.03it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Normal Accuracy: 0.62\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "normalAccuracy, answer_tokens = getOriginalAccuracy(model, tokenizer, few_shot_prompts, labels)\n",
    "print(f\"Normal Accuracy: {normalAccuracy}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "200"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(prompts)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Input: In the following lines, the symbol -> represents a simple mathematical operation.\n",
      "74 * 3 -> 222\n",
      "78 * 5 -> 390\n",
      "93 * 50 -> 4650\n",
      "61 * 56 -> 3416\n",
      "65 * 78 -> 5070\n",
      "74 * 7 ->\n",
      "output: 518\n",
      "\n",
      "Input: In the following lines, the symbol -> represents a simple mathematical operation.\n",
      "960 + 208 -> 1168\n",
      "762 + 896 -> 1658\n",
      "116 + 627 -> 743\n",
      "748 + 477 -> 1225\n",
      "104 + 494 -> 598\n",
      "770 + 45 ->\n",
      "output: 815\n",
      "\n",
      "Input: In the following lines, the symbol -> represents a simple mathematical operation.\n",
      "98 * 13 -> 1274\n",
      "21 * 50 -> 1050\n",
      "22 * 47 -> 1034\n",
      "70 * 49 -> 3430\n",
      "88 * 12 -> 1056\n",
      "84 * 9 ->\n",
      "output: 756\n",
      "\n",
      "Input: In the following lines, the symbol -> represents a simple mathematical operation.\n",
      "503 - 685 -> -181\n",
      "386 - 969 -> -582\n",
      "869 - 832 -> 38\n",
      "887 - 602 -> 286\n",
      "749 - 775 -> -25\n",
      "690 - 493 ->\n",
      "output: 198\n",
      "\n",
      "Input: In the following lines, the symbol -> represents a simple mathematical operation.\n",
      "60 * 70 -> 4201\n",
      "45 * 31 -> 1396\n",
      "14 * 43 -> 603\n",
      "16 * 40 -> 641\n",
      "53 * 43 -> 2280\n",
      "97 * 4 ->\n",
      "output: 389\n",
      "\n",
      "Input: In the following lines, the symbol -> represents a simple mathematical operation.\n",
      "710 - 142 -> 569\n",
      "916 - 914 -> 3\n",
      "827 - 604 -> 224\n",
      "472 - 416 -> 57\n",
      "304 - 735 -> -430\n",
      "496 - 290 ->\n",
      "output: 207\n",
      "\n",
      "Input: In the following lines, the symbol -> represents a simple mathematical operation.\n",
      "93 * 35 -> 3255\n",
      "17 * 89 -> 1513\n",
      "52 * 70 -> 3640\n",
      "0 * 57 -> 0\n",
      "89 * 86 -> 7654\n",
      "89 * 5 ->\n",
      "output: \n"
     ]
    }
   ],
   "source": [
    "print(few_shot_prompts[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "def objective(params, prompt_list):\n",
    "    # print(params.shape)\n",
    "    params = params.squeeze(0)\n",
    "    # breakpoint()\n",
    "    scale_factor_list = params[-(num_L * H):].view(num_L, H)\n",
    "    D = params[:-(num_L * H)].view(num_L, H, N)\n",
    "        \n",
    "    accuracy = 0\n",
    "    count = 0\n",
    "    prob = 0\n",
    "    from tqdm import tqdm\n",
    "    pbar = tqdm(total=len(prompts))\n",
    "    for prompt, label, answer_token in tqdm(zip(prompt_list, labels, answer_tokens), desc=\"Rotated Accuracy\", total=len(prompts)):\n",
    "        predicted_output, answer_token_prob = runRotateScaleModel(model, tokenizer, prompt, D, scale_factor_list, answer_token)\n",
    "        if predicted_output == label:\n",
    "            accuracy += 1 \n",
    "        count += 1\n",
    "        prob += answer_token_prob\n",
    "        pbar.set_description(f\"Rotated Accuracy: {accuracy / count}\")\n",
    "        pbar.update(1)\n",
    "\n",
    "    avg_prob = prob / count\n",
    "    print(f\"Accuracy: {accuracy / count}\")\n",
    "    print(f\"Answer token prob: {avg_prob}\")\n",
    "    return -torch.tensor(avg_prob, dtype=dtype)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "142\n"
     ]
    }
   ],
   "source": [
    "train_X, train_Y, _, iteration = load_checkpoint(checkpoint_path)\n",
    "print(iteration)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "combined_array = [(x, y) for x, y in zip(train_X, train_Y)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "sorted_array = sorted(combined_array, key=lambda x: x[1], reverse=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "15339"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tokenizer.encode(\"hello\", return_tensors='pt', add_special_tokens=False)[0].item()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Top 1 accuracy: tensor([-0.2947], dtype=torch.float64)\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Rotated Accuracy: 100%|██████████| 200/200 [05:26<00:00,  1.63s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy: 0.42\n",
      "Answer token prob: 0.3725526937768518\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "ename": "TypeError",
     "evalue": "iteration over a 0-d tensor",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mTypeError\u001b[0m                                 Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[25], line 4\u001b[0m\n\u001b[1;32m      2\u001b[0m D \u001b[38;5;241m=\u001b[39m sorted_array[i][\u001b[38;5;241m0\u001b[39m]\n\u001b[1;32m      3\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mTop \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mi\u001b[38;5;250m \u001b[39m\u001b[38;5;241m+\u001b[39m\u001b[38;5;250m \u001b[39m\u001b[38;5;241m1\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m accuracy: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00msorted_array[i][\u001b[38;5;241m1\u001b[39m]\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m)\n\u001b[0;32m----> 4\u001b[0m accuracy, prob \u001b[38;5;241m=\u001b[39m objective(D)\n",
      "File \u001b[0;32m~/miniconda3/envs/ALTI/lib/python3.10/site-packages/torch/_tensor.py:990\u001b[0m, in \u001b[0;36mTensor.__iter__\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m    980\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m__iter__\u001b[39m(\u001b[38;5;28mself\u001b[39m):\n\u001b[1;32m    981\u001b[0m     \u001b[38;5;66;03m# NB: we use 'imap' and not 'map' here, so that in Python 2 we get a\u001b[39;00m\n\u001b[1;32m    982\u001b[0m     \u001b[38;5;66;03m# generator and don't eagerly perform all the indexes.  This could\u001b[39;00m\n\u001b[0;32m   (...)\u001b[0m\n\u001b[1;32m    987\u001b[0m     \u001b[38;5;66;03m# NB: We have intentionally skipped __torch_function__ dispatch here.\u001b[39;00m\n\u001b[1;32m    988\u001b[0m     \u001b[38;5;66;03m# See gh-54457\u001b[39;00m\n\u001b[1;32m    989\u001b[0m     \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdim() \u001b[38;5;241m==\u001b[39m \u001b[38;5;241m0\u001b[39m:\n\u001b[0;32m--> 990\u001b[0m         \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mTypeError\u001b[39;00m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124miteration over a 0-d tensor\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m    991\u001b[0m     \u001b[38;5;28;01mif\u001b[39;00m torch\u001b[38;5;241m.\u001b[39m_C\u001b[38;5;241m.\u001b[39m_get_tracing_state():\n\u001b[1;32m    992\u001b[0m         warnings\u001b[38;5;241m.\u001b[39mwarn(\n\u001b[1;32m    993\u001b[0m             \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mIterating over a tensor might cause the trace to be incorrect. \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m    994\u001b[0m             \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mPassing a tensor of different shape won\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mt change the number of \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m   (...)\u001b[0m\n\u001b[1;32m    998\u001b[0m             stacklevel\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m2\u001b[39m,\n\u001b[1;32m    999\u001b[0m         )\n",
      "\u001b[0;31mTypeError\u001b[0m: iteration over a 0-d tensor"
     ]
    }
   ],
   "source": [
    "for i in range(5):\n",
    "    D = sorted_array[i][0]\n",
    "    print(f\"Top {i + 1} accuracy: {sorted_array[i][1]}\")\n",
    "    accuracy, prob = objective(D, few_shot_prompts)\n",
    "    \n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "angles = [0, 5]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "bounds = torch.tensor([[angles[0]] * (num_L * H * N), [angles[1]] * (num_L * H * N)], dtype=dtype).to(device)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "bounds.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "initial_points = (torch.rand((10, bounds.size(1)), device=device) * (bounds[1] - bounds[0]) + bounds[0]).to(bounds.device)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "initial_points[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "with open('dataset/modified_arithmetic/train.jsonl', 'r') as f:\n",
    "    train = [json.loads(line) for line in f]\n",
    "import random\n",
    "random.shuffle(train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "with open('dataset/modified_arithmetic/validation.jsonl', 'r') as f:\n",
    "    valid = [json.loads(line) for line in f]\n",
    "import random\n",
    "random.shuffle(valid)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "with open('dataset/modified_arithmetic/train.jsonl', 'w') as f:\n",
    "    for line in train:\n",
    "        f.write(json.dumps(line) + '\\n')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "with open('dataset/modified_arithmetic/validation.jsonl', 'w') as f:\n",
    "    for line in valid:\n",
    "        f.write(json.dumps(line) + '\\n')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "ALTI",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
