{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "930a8c78",
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, LlamaTokenizer\n",
    "import random\n",
    "from tqdm import tqdm\n",
    "import re, torch\n",
    "from sklearn.metrics import classification_report\n",
    "import matplotlib.pyplot as plt\n",
    "import sys\n",
    "sys.path.append('..')\n",
    "from models.modelings_alignable_llama import *\n",
    "from utils.train_utils import *\n",
    "import pickle\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a9bce26e",
   "metadata": {},
   "source": [
    "#### Loading Alpaca-test"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1951f7e4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Intervention site handed to the alignable LLaMA wrapper: a layer index plus a\n",
    "# two-element token range.\n",
    "alignment_config = dict(layer=15, token_range=[81, 82])\n",
    "\n",
    "# Load the local alpaca_test checkpoint in bfloat16 and move it to the GPU.\n",
    "# NOTE(review): a later cell rebinds `model` to the alpaca_7b checkpoint; presumably\n",
    "# only one of the two loading cells is meant to be run - confirm.\n",
    "model = AlignableLlamaForCausalLM.from_pretrained(\n",
    "    \"../../alpaca_test/\",\n",
    "    torch_dtype=torch.bfloat16,\n",
    "    alignment_config=alignment_config,\n",
    ")\n",
    "_ = model.to(\"cuda\")  # bind the return value so the cell displays nothing"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "eaa457bb",
   "metadata": {},
   "source": [
    "#### Loading Alpaca-7B"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2c6ba345",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Tokenizer stored next to the local alpaca_7b checkpoint.\n",
    "# NOTE(review): CACHE_DIR is not defined in this notebook; presumably it arrives via\n",
    "# one of the star imports at the top - confirm.\n",
    "tokenizer = AutoTokenizer.from_pretrained(\"../../alpaca_7b/\", cache_dir=CACHE_DIR)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0ee06501",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Intervention site: layer 15, token range [81, 82].\n",
    "alignment_config = dict(layer=15, token_range=[81, 82])\n",
    "\n",
    "# Load the local alpaca_7b checkpoint in bfloat16 and move it to the GPU.\n",
    "model = AlignableLlamaForCausalLM.from_pretrained(\n",
    "    \"../../alpaca_7b/\",\n",
    "    torch_dtype=torch.bfloat16,\n",
    "    alignment_config=alignment_config,\n",
    ")\n",
    "_ = model.to(\"cuda\")  # bind the return value so the cell displays nothing"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4f8251b7",
   "metadata": {},
   "outputs": [],
   "source": [
    "alpaca_prompt_template = f\"\"\"Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n",
    "\n",
    "### Instruction:\n",
    "%s\n",
    "\n",
    "### Input:\n",
    "%s\n",
    "\n",
    "### Response:\n",
    "\"\"\"\n",
    "\n",
    "alpaca_prompt_template_no_inputs = f\"\"\"Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n",
    "\n",
    "### Instruction:\n",
    "%s\n",
    "\n",
    "### Response:\n",
    "\"\"\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "04c38394",
   "metadata": {},
   "outputs": [],
   "source": [
    "alpaca_instruction = \"\"\"Please say yes only if Sam is heavier than John, otherwise no.\"\"\"\n",
    "prompt = alpaca_prompt_template % (alpaca_instruction, \"Sam weigh 112 lbs, John weigh 163 lbs\")\n",
    "\n",
    "# alpaca_instruction = \"\"\"Does Donald Trump use computer?\"\"\"\n",
    "# prompt = alpaca_prompt_template_no_inputs % (alpaca_instruction)\n",
    "\n",
    "# Tokenize once and reuse the encoding for both tensors.\n",
    "encoding = tokenizer(prompt, return_tensors=\"pt\")\n",
    "input_ids = encoding.input_ids.to(\"cuda\")\n",
    "attention_mask = encoding.attention_mask.to(\"cuda\")\n",
    "\n",
    "model.eval()\n",
    "# Inference only: skip autograd bookkeeping for the forward pass.\n",
    "with torch.no_grad():\n",
    "    outputs = model(\n",
    "        input_ids,\n",
    "        attention_mask=attention_mask\n",
    "    )\n",
    "\n",
    "# Highest-scoring token at the last sequence position.\n",
    "last_position_logits = outputs.logits[:, -1]\n",
    "pred_labels = torch.argmax(last_position_logits, dim=-1)\n",
    "generated_tokens = tokenizer.decode(pred_labels[0])\n",
    "\n",
    "# Forced choice between the \"Yes\" and \"No\" tokens.\n",
    "# NOTE: despite the `_prob` suffix these are raw logits; no softmax is applied.\n",
    "afc_1 = tokenizer.convert_tokens_to_ids(\"Yes\")\n",
    "afc_2 = tokenizer.convert_tokens_to_ids(\"No\")\n",
    "afc_1_prob = last_position_logits[0][afc_1]\n",
    "afc_2_prob = last_position_logits[0][afc_2]\n",
    "afc = \"Yes\" if afc_1_prob > afc_2_prob else \"No\"\n",
    "print(f\"afc label = {afc} ({afc_1_prob}/{afc_2_prob}) ; pred label = {generated_tokens}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "84cac3d8",
   "metadata": {},
   "source": [
    "#### Factual: Pricing Tag Game"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5a8659bf",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Draw 5000 factual pricing-tag examples; element 0 is wired in as the input ids\n",
    "# and element 1 as the labels.\n",
    "raw_prealign = factual_sampler(tokenizer, 5000, game=\"pricing_tag\")\n",
    "\n",
    "prealign_dataset = Dataset.from_dict(\n",
    "    dict(input_ids=raw_prealign[0], labels=raw_prealign[1])\n",
    ").with_format(\"torch\")\n",
    "\n",
    "# Mini-batches of 8.\n",
    "prealign_dataloader = DataLoader(prealign_dataset, batch_size=8)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f6ad2d4c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Accuracy of the model without any intervention, scored on the last position only.\n",
    "total_count = 0\n",
    "correct_count = 0\n",
    "model.eval()\n",
    "with torch.no_grad():\n",
    "    for inputs in tqdm(prealign_dataloader):\n",
    "        # Move every tensor field of the batch onto the model's device.\n",
    "        inputs = {\n",
    "            key: value.to(model.device) if isinstance(value, torch.Tensor) else value\n",
    "            for key, value in inputs.items()\n",
    "        }\n",
    "\n",
    "        outputs = model(input_ids=inputs['input_ids'], labels=inputs['labels'])\n",
    "\n",
    "        actual_test_labels = inputs['labels'][:, -1]\n",
    "        pred_test_labels = outputs.logits[:, -1].argmax(dim=-1)\n",
    "        correct_labels = actual_test_labels == pred_test_labels\n",
    "\n",
    "        total_count += correct_labels.shape[0]\n",
    "        correct_count += int(correct_labels.sum())\n",
    "\n",
    "current_acc = round(correct_count / total_count, 2)\n",
    "print(f\"[WARNING: THIS NEEDS TO BE GOOD!] prealign task accuracy: {current_acc}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "779f2df4",
   "metadata": {},
   "source": [
    "#### (Experimental) Factual: Pricing Tag Game Counterfactual"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e7920a12",
   "metadata": {},
   "outputs": [],
   "source": [
    "checkpoint_state_dict = torch.load(\"../results_alpaca-7b/alpaca-7B.task.pricing_tag_lub.seed.42.intl.15.intr.81.82/pytorch-rotate-best.bin\")\n",
    "\n",
    "# True  -> baseline: random orthogonal rotation combined with the learned boundaries.\n",
    "# False -> load the rotation stored in the checkpoint.\n",
    "baselining = True\n",
    "if baselining:\n",
    "    print(\"Baselining with a random rotation + learned boundary\")\n",
    "    n = model.model.rotate_layer.parametrizations.weight.original.shape[0]\n",
    "    # Build the orthogonal matrix in float32 and cast afterwards:\n",
    "    # torch.nn.init.orthogonal_ relies on a QR decomposition, which PyTorch documents\n",
    "    # for float/double/complex inputs only, not bfloat16.\n",
    "    rand_weight = torch.empty(n, n, dtype=torch.float32, device=\"cuda\")\n",
    "    torch.nn.init.orthogonal_(rand_weight)\n",
    "    rand_weight = rand_weight.to(torch.bfloat16)\n",
    "    model.model.rotate_layer.parametrizations.weight.original.data = rand_weight.data\n",
    "else:\n",
    "    model.model.rotate_layer.load_state_dict(\n",
    "        checkpoint_state_dict['rotate_layer']\n",
    "    )\n",
    "# The learned boundaries are used in both modes.\n",
    "model.model.intervention_boundaries.data = checkpoint_state_dict['intervention_boundaries'].data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f24c8f9e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Counterfactual test set: 1000 examples drawn with the lower-bound sampler only,\n",
    "# with lower_bound fixed to 2.51 and bound_width fixed to 2.00.\n",
    "raw_data = bound_alignment_sampler(\n",
    "    tokenizer,\n",
    "    1000,\n",
    "    [lower_bound_alignment_example_sampler],\n",
    "    amount=None,\n",
    "    lower_bound=2.51,\n",
    "    bound_width=2.00,\n",
    ")\n",
    "# Keep the first four fields, in order.\n",
    "raw_test = tuple(raw_data[field] for field in range(4))\n",
    "\n",
    "test_dataset = Dataset.from_dict(\n",
    "    dict(\n",
    "        zip(\n",
    "            (\"input_ids\", \"source_input_ids\", \"labels\", \"intervention_ids\"),\n",
    "            raw_test,\n",
    "        )\n",
    "    )\n",
    ").with_format(\"torch\")\n",
    "test_dataloader = DataLoader(test_dataset, batch_size=8)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e4e39a46",
   "metadata": {},
   "outputs": [],
   "source": [
    "total_count = 0\n",
    "correct_count = 0\n",
    "# Put the model in inference mode here as well, so the cell does not depend on an\n",
    "# earlier cell having done it.\n",
    "model.eval()\n",
    "with torch.no_grad():\n",
    "    for step, inputs in enumerate(tqdm(test_dataloader)):\n",
    "        for k, v in inputs.items():\n",
    "            if v is not None and isinstance(v, torch.Tensor):\n",
    "                inputs[k] = v.to(model.device)\n",
    "\n",
    "        # First pass: rotated hidden states of the source inputs.\n",
    "        source_hidden_states = model(\n",
    "            input_ids=inputs['source_input_ids'],\n",
    "            output_rotated_hidden_states_only=True\n",
    "        ).rotated_hidden_states\n",
    "        # Second pass: base inputs, with the source states supplied for the intervention.\n",
    "        outputs = model(\n",
    "            input_ids=inputs['input_ids'],\n",
    "            source_hidden_states=source_hidden_states,\n",
    "            intervention_ids=inputs['intervention_ids'],\n",
    "            labels=inputs['labels']\n",
    "        )\n",
    "\n",
    "        # Score the last position only.\n",
    "        actual_test_labels = inputs['labels'][:, -1]\n",
    "        pred_test_labels = torch.argmax(outputs.logits[:, -1], dim=-1)\n",
    "        correct_labels = (actual_test_labels==pred_test_labels)\n",
    "\n",
    "        total_count += len(correct_labels)\n",
    "        correct_count += correct_labels.sum().tolist()\n",
    "\n",
    "current_acc = round(correct_count/total_count, 2)\n",
    "# This loop scores intervened outputs against counterfactual labels, so report it as\n",
    "# IIA accuracy (the label the grid-search cell uses) rather than prealign accuracy.\n",
    "print(f\"IIA accuracy: {current_acc}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "24c377fe",
   "metadata": {},
   "source": [
    "#### Goal: we need to find two good bracket settings to study 0-shot transfer!\n",
    "2.51 to 5.51 seems to be working with 0.95\n",
    "\n",
    "5.49 to 8.49 seems to be working with 0.94"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ee4f50ea",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 500 factual pricing-tag examples with lower_bound fixed to 2.51 and\n",
    "# bound_width fixed to 4.00; the amount is left as None.\n",
    "raw_prealign = factual_sampler(\n",
    "    tokenizer, 500, game=\"pricing_tag\", amount=None, lower_bound=2.51, bound_width=4.00\n",
    ")\n",
    "\n",
    "prealign_dataset = Dataset.from_dict(\n",
    "    dict(input_ids=raw_prealign[0], labels=raw_prealign[1])\n",
    ").with_format(\"torch\")\n",
    "prealign_dataloader = DataLoader(prealign_dataset, batch_size=8)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "397cf8f1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Task accuracy without intervention on the current prealign loader, scored on the\n",
    "# last position only.\n",
    "total_count, correct_count = 0, 0\n",
    "model.eval()\n",
    "with torch.no_grad():\n",
    "    for inputs in tqdm(prealign_dataloader):\n",
    "        # Send tensor fields to the model's device; leave anything else as it is.\n",
    "        tensor_fields = [name for name, field in inputs.items() if isinstance(field, torch.Tensor)]\n",
    "        for name in tensor_fields:\n",
    "            inputs[name] = inputs[name].to(model.device)\n",
    "\n",
    "        outputs = model(input_ids=inputs['input_ids'], labels=inputs['labels'])\n",
    "\n",
    "        actual_test_labels = inputs['labels'][:, -1]\n",
    "        pred_test_labels = outputs.logits[:, -1].argmax(dim=-1)\n",
    "        correct_labels = pred_test_labels == actual_test_labels\n",
    "\n",
    "        total_count += correct_labels.shape[0]\n",
    "        correct_count += int(correct_labels.sum())\n",
    "\n",
    "current_acc = round(correct_count / total_count, 2)\n",
    "print(f\"[WARNING: THIS NEEDS TO BE GOOD!] prealign task accuracy: {current_acc}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a2faa990",
   "metadata": {},
   "source": [
    "#### Tokenization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "105ff83c",
   "metadata": {},
   "outputs": [],
   "source": [
    "prompt = f\"\"\"Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n",
    "\n",
    "### Instruction:\n",
    "Please say yes only if it costs between 1.23 and 4.56 dollars, otherwise no.\n",
    "\n",
    "### Input:\n",
    "7.89 dollars\n",
    "\n",
    "### Response:\n",
    "\"\"\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3079c3b3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Display the prompt's tokens from index 68 onwards.\n",
    "prompt_tokens = tokenizer.tokenize(prompt)\n",
    "prompt_tokens[68:]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1ec90595",
   "metadata": {},
   "source": [
    "#### Get two evaluation sets, each with 1K examples\n",
    "- one with examples the model gets correct\n",
    "- one with examples the model gets wrong"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7c57afa9",
   "metadata": {},
   "outputs": [],
   "source": [
    "incorrect_triples = set([])\n",
    "correct_triples = set([])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "43c7e654",
   "metadata": {},
   "outputs": [],
   "source": [
    "model.eval()\n",
    "with torch.no_grad():\n",
    "    pbar = tqdm(range(500))\n",
    "    for i in pbar:\n",
    "        # Draw one example together with the triple describing it. The three game\n",
    "        # settings are passed as None (presumably so the sampler picks them - confirm).\n",
    "        input_ids, output_ids, triple = pricing_tag_game_example_sampler_with_info(\n",
    "            tokenizer, None, None, None\n",
    "        )\n",
    "        # Add a leading axis of size 1 and move to the model's device.\n",
    "        input_ids = torch.tensor(input_ids)[None].to(model.device)\n",
    "        output_ids = torch.tensor(output_ids)[None].to(model.device)\n",
    "\n",
    "        outputs = model(input_ids=input_ids, labels=output_ids)\n",
    "\n",
    "        actual_test_labels = output_ids[0, -1]\n",
    "        pred_test_labels = outputs.logits[0, -1].argmax(dim=-1)\n",
    "        # File the triple under the set that matches the model's verdict.\n",
    "        bucket = correct_triples if actual_test_labels == pred_test_labels else incorrect_triples\n",
    "        bucket.add(triple)\n",
    "        pbar.set_description(\n",
    "            f\"#correct {len(correct_triples)} ; #incorrect {len(incorrect_triples)}\"\n",
    "        )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e9cf7e75",
   "metadata": {},
   "outputs": [],
   "source": [
    "# random.sample raises ValueError when k exceeds the population size, and one pass\n",
    "# of the sampling loop (500 draws) cannot fill either set with 1000 triples.\n",
    "# Cap the request at what is available and say so when a set comes up short.\n",
    "EVAL_SET_SIZE = 1000\n",
    "incorrect_triples_sampled = random.sample(\n",
    "    list(incorrect_triples), k=min(EVAL_SET_SIZE, len(incorrect_triples))\n",
    ")\n",
    "correct_triples_sampled = random.sample(\n",
    "    list(correct_triples), k=min(EVAL_SET_SIZE, len(correct_triples))\n",
    ")\n",
    "if min(len(incorrect_triples_sampled), len(correct_triples_sampled)) < EVAL_SET_SIZE:\n",
    "    print(\n",
    "        f\"[WARNING] fewer than {EVAL_SET_SIZE} triples available: \"\n",
    "        f\"{len(correct_triples_sampled)} correct / {len(incorrect_triples_sampled)} incorrect; \"\n",
    "        \"re-run the sampling cell to collect more.\"\n",
    "    )"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7652b5f5",
   "metadata": {},
   "source": [
    "#### Re-eval with correct and incorrect triples as well as learned DAS"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cc9ae518",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the previously saved triples.\n",
    "# NOTE: pickle.load can execute arbitrary code; only open files you produced yourself.\n",
    "with open('../logs/consistency_triples.pickle', 'rb') as file:\n",
    "    loaded_object = pickle.load(file)\n",
    "# These assignments rebind `incorrect_triples` / `correct_triples`, replacing whatever\n",
    "# the sampling cell earlier in the notebook collected in memory.\n",
    "incorrect_triples = loaded_object[\"incorrect_triples\"]\n",
    "correct_triples = loaded_object[\"correct_triples\"]\n",
    "\n",
    "\n",
    "def bucket_triples_by_region(triples):\n",
    "    \"\"\"Group triples into regions 1, 2 and 3 by comparing the last element with the first two.\n",
    "\n",
    "    Region 1: t[-1] < t[0]; region 3: t[-1] > t[1]; region 2: everything else.\n",
    "    Presumably each triple is (lower bound, upper bound, amount) - confirm against the\n",
    "    code that wrote the pickle.\n",
    "\n",
    "    Returns:\n",
    "        dict mapping 1, 2 and 3 to the set of triples falling in that region.\n",
    "    \"\"\"\n",
    "    regions = {1: set(), 2: set(), 3: set()}\n",
    "    for t in triples:\n",
    "        if t[-1] < t[0]:\n",
    "            regions[1].add(t)\n",
    "        elif t[-1] > t[1]:\n",
    "            regions[3].add(t)\n",
    "        else:\n",
    "            regions[2].add(t)\n",
    "    return regions\n",
    "\n",
    "\n",
    "# One helper replaces the two copies of the bucketing loop.\n",
    "incorrect_triples_by_regions = bucket_triples_by_region(incorrect_triples)\n",
    "correct_triples_by_regions = bucket_triples_by_region(correct_triples)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a4b57376",
   "metadata": {},
   "outputs": [],
   "source": [
    "correct_results = {}\n",
    "incorrect_results = {}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cf813df0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Interchange test set built from `correct_triples_by_regions`: 1000 examples, each\n",
    "# produced by one of the two bound samplers listed below.\n",
    "# NOTE(review): the grid search that would consume this loader is commented out\n",
    "# further down, and `raw_data`, `raw_test`, `test_dataset` and `test_dataloader` are\n",
    "# rebuilt from `incorrect_triples_by_regions` later in this cell, so this result is\n",
    "# currently unused.\n",
    "raw_data = bound_alignment_sampler_with_triples(\n",
    "    tokenizer,\n",
    "    1000,\n",
    "    [\n",
    "        lower_bound_alignment_example_sampler_with_triples,\n",
    "        upper_bound_alignment_example_sampler_with_triples\n",
    "    ],\n",
    "    correct_triples_by_regions\n",
    ")\n",
    "# Keep the first four fields, in order.\n",
    "raw_test = (\n",
    "    raw_data[0], \n",
    "    raw_data[1], \n",
    "    raw_data[2],\n",
    "    raw_data[3]\n",
    ")\n",
    "test_dataset = Dataset.from_dict(\n",
    "    {\n",
    "        \"input_ids\": raw_test[0], \n",
    "        \"source_input_ids\": raw_test[1],\n",
    "        \"labels\": raw_test[2],\n",
    "        \"intervention_ids\": raw_test[3],\n",
    "    }\n",
    ").with_format(\"torch\")\n",
    "test_dataloader = DataLoader(\n",
    "    test_dataset, batch_size=128,\n",
    ")\n",
    "\n",
    "# grid search\n",
    "# [70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81]\n",
    "# [0, 5, 10, 15, 20, 25, 30]\n",
    "# for seed in [42]:\n",
    "#     for token_idx in [69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81]:\n",
    "#         for layer_idx in [0, 5, 10, 15, 20, 25, 30]:\n",
    "#             end_idx = token_idx + 1\n",
    "#             print(f\"eval with token_idx = {token_idx} and layer_idx = {layer_idx}\")\n",
    "#             checkpoint_state_dict = torch.load(f\"../results_alpaca-7b/alpaca-7B.task.pricing_tag_lub.seed.{seed}.intl.{layer_idx}.intr.{token_idx}.{end_idx}/pytorch-rotate-best.bin\")\n",
    "#             model.model.rotate_layer.load_state_dict(\n",
    "#                 checkpoint_state_dict['rotate_layer']\n",
    "#             )\n",
    "#             model.model.intervention_boundaries.data = checkpoint_state_dict['intervention_boundaries'].data\n",
    "#             model.alignment_config = {\n",
    "#                 'layer': layer_idx,\n",
    "#                 \"token_range\" : [token_idx, end_idx]\n",
    "#             }\n",
    "#             model.model.alignment_config = {\n",
    "#                 'layer': layer_idx,\n",
    "#                 \"token_range\" : [token_idx, end_idx]\n",
    "#             }\n",
    "            \n",
    "#             total_count = 0\n",
    "#             correct_count = 0\n",
    "#             with torch.no_grad():\n",
    "#                 for step, inputs in enumerate(tqdm(test_dataloader)):\n",
    "#                     for k, v in inputs.items():\n",
    "#                         if v is not None and isinstance(v, torch.Tensor):\n",
    "#                             inputs[k] = v.to(model.device)\n",
    "\n",
    "#                     # aligning forward!\n",
    "#                     source_hidden_states = model(\n",
    "#                         input_ids=inputs['source_input_ids'],\n",
    "#                         output_rotated_hidden_states_only=True\n",
    "#                     ).rotated_hidden_states\n",
    "#                     outputs = model(\n",
    "#                         input_ids=inputs['input_ids'],\n",
    "#                         source_hidden_states=source_hidden_states,\n",
    "#                         intervention_ids=inputs['intervention_ids'],\n",
    "#                         labels=inputs['labels']\n",
    "#                     )\n",
    "\n",
    "#                     actual_test_labels = inputs['labels'][:, -1]\n",
    "#                     pred_test_labels = torch.argmax(outputs.logits[:, -1], dim=-1)\n",
    "#                     correct_labels = (actual_test_labels==pred_test_labels)\n",
    "\n",
    "#                     total_count += len(correct_labels)\n",
    "#                     correct_count += correct_labels.sum().tolist()\n",
    "\n",
    "#             current_acc = round(correct_count/total_count, 2)\n",
    "#             correct_results[(layer_idx, token_idx)] = current_acc\n",
    "#             print(f\"IIA accuracy: {current_acc}\")\n",
    "\n",
    "# Interchange test set built from `incorrect_triples_by_regions`: 1000 examples,\n",
    "# each produced by one of the two bound samplers.\n",
    "raw_data = bound_alignment_sampler_with_triples(\n",
    "    tokenizer,\n",
    "    1000,\n",
    "    [\n",
    "        lower_bound_alignment_example_sampler_with_triples,\n",
    "        upper_bound_alignment_example_sampler_with_triples,\n",
    "    ],\n",
    "    incorrect_triples_by_regions,\n",
    ")\n",
    "# Keep the first four fields, in order.\n",
    "raw_test = tuple(raw_data[field] for field in range(4))\n",
    "\n",
    "test_dataset = Dataset.from_dict(\n",
    "    dict(\n",
    "        zip(\n",
    "            (\"input_ids\", \"source_input_ids\", \"labels\", \"intervention_ids\"),\n",
    "            raw_test,\n",
    "        )\n",
    "    )\n",
    ").with_format(\"torch\")\n",
    "test_dataloader = DataLoader(test_dataset, batch_size=128)\n",
    "\n",
    "# grid search\n",
    "# [70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81]\n",
    "# [0, 5, 10, 15, 20, 25, 30]\n",
    "for seed in [42]:\n",
    "    for token_idx in [77, 78, 79, 80, 81]:\n",
    "        for layer_idx in [0, 5, 10, 15, 20, 25, 30]:\n",
    "            end_idx = token_idx + 1\n",
    "            print(f\"eval with token_idx = {token_idx} and layer_idx = {layer_idx}\")\n",
    "\n",
    "            # Load the rotation and boundaries stored for this (seed, layer, token) site.\n",
    "            checkpoint_path = (\n",
    "                f\"../results_alpaca-7b/alpaca-7B.task.pricing_tag_lub.seed.{seed}\"\n",
    "                f\".intl.{layer_idx}.intr.{token_idx}.{end_idx}/pytorch-rotate-best.bin\"\n",
    "            )\n",
    "            checkpoint_state_dict = torch.load(checkpoint_path)\n",
    "            model.model.rotate_layer.load_state_dict(checkpoint_state_dict['rotate_layer'])\n",
    "            model.model.intervention_boundaries.data = checkpoint_state_dict['intervention_boundaries'].data\n",
    "\n",
    "            # Point the wrapper and the inner model at the same site; each receives\n",
    "            # its own dict object.\n",
    "            for target in (model, model.model):\n",
    "                target.alignment_config = {\n",
    "                    'layer': layer_idx,\n",
    "                    \"token_range\": [token_idx, end_idx],\n",
    "                }\n",
    "\n",
    "            total_count, correct_count = 0, 0\n",
    "            with torch.no_grad():\n",
    "                for inputs in tqdm(test_dataloader):\n",
    "                    # Move every tensor field of the batch onto the model's device.\n",
    "                    inputs = {\n",
    "                        key: value.to(model.device) if isinstance(value, torch.Tensor) else value\n",
    "                        for key, value in inputs.items()\n",
    "                    }\n",
    "\n",
    "                    # First pass: rotated hidden states of the source inputs.\n",
    "                    source_hidden_states = model(\n",
    "                        input_ids=inputs['source_input_ids'],\n",
    "                        output_rotated_hidden_states_only=True,\n",
    "                    ).rotated_hidden_states\n",
    "                    # Second pass: base inputs, with the source states supplied.\n",
    "                    outputs = model(\n",
    "                        input_ids=inputs['input_ids'],\n",
    "                        source_hidden_states=source_hidden_states,\n",
    "                        intervention_ids=inputs['intervention_ids'],\n",
    "                        labels=inputs['labels'],\n",
    "                    )\n",
    "\n",
    "                    actual_test_labels = inputs['labels'][:, -1]\n",
    "                    pred_test_labels = outputs.logits[:, -1].argmax(dim=-1)\n",
    "                    correct_labels = actual_test_labels == pred_test_labels\n",
    "\n",
    "                    total_count += correct_labels.shape[0]\n",
    "                    correct_count += int(correct_labels.sum())\n",
    "\n",
    "            current_acc = round(correct_count / total_count, 2)\n",
    "            incorrect_results[(layer_idx, token_idx)] = current_acc\n",
    "            print(f\"IIA accuracy: {current_acc}\")\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9973fe71",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Persist both result dictionaries in a single pickle file.\n",
    "eval_results = {\n",
    "    \"correct_results\": correct_results,\n",
    "    \"incorrect_results\": incorrect_results,\n",
    "}\n",
    "with open('../logs/eval_results_consistency_seed_42.pkl', 'wb') as file:\n",
    "    pickle.dump(eval_results, file)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a66426a2",
   "metadata": {},
   "source": [
    "#### Finding irrelevant contexts under which the model can still perform the task\n",
    "\n",
    "\"Pricing tag game!\" (84%)\n",
    "\n",
    "\"LLaMA is not conscious.\" (84%)\n",
    "\n",
    "\"Fruitarian Frogs May Be Doing Flowers a Favor\" (85%; citation: NYT America 04/28 First Story)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c02253c3",
   "metadata": {},
   "outputs": [],
   "source": [
    "def pricing_tag_game_example_sampler_with_context(\n",
    "    tokenizer,\n",
    "    context,\n",
    "    amount,\n",
    "    lower_bound,\n",
    "    bound_width,\n",
    "):\n",
    "    \"\"\"Build one pricing-tag example whose instruction is prefixed with `context`.\n",
    "\n",
    "    Args:\n",
    "        tokenizer: callable returning an object with `.input_ids` when invoked with\n",
    "            `return_tensors=\"pt\"`, and exposing `convert_tokens_to_ids`.\n",
    "        context: text placed in front of the instruction, followed by one space.\n",
    "        amount, lower_bound, bound_width: forwarded unchanged to\n",
    "            `pricing_tag_game_config_sampler`.\n",
    "\n",
    "    Returns:\n",
    "        (input_ids, output_ids): two equally long lists. `output_ids` is -100\n",
    "        everywhere except the last position, which holds the id of the \"Yes\" or\n",
    "        \"No\" token.\n",
    "    \"\"\"\n",
    "    lower_bound_sample, upper_bound_sample, amount_sample = pricing_tag_game_config_sampler(\n",
    "        amount,\n",
    "        lower_bound,\n",
    "        bound_width\n",
    "    )\n",
    "    lower_bound_str = \"%.2f\" % lower_bound_sample\n",
    "    upper_bound_str = \"%.2f\" % upper_bound_sample\n",
    "    amount_value_str = \"%.2f\" % amount_sample\n",
    "\n",
    "    # Decide the label from the two-decimal values that are written into the prompt.\n",
    "    # Comparing the unrounded amount could produce a label that contradicts the\n",
    "    # amount shown in the prompt (e.g. 5.514 against an upper bound of 5.51).\n",
    "    if float(lower_bound_str) <= float(amount_value_str) <= float(upper_bound_str):\n",
    "        label = tokenizer.convert_tokens_to_ids(\"Yes\")\n",
    "    else:\n",
    "        label = tokenizer.convert_tokens_to_ids(\"No\")\n",
    "\n",
    "    amount_str = \"%.2f dollars\" % amount_sample\n",
    "    instruction = f\"{context} Please say yes only if it costs between {lower_bound_str} and {upper_bound_str} dollars, otherwise no.\"\n",
    "    alpaca_prompt = alpaca_prompt_template % (instruction, amount_str)\n",
    "\n",
    "    input_ids = tokenizer(alpaca_prompt, return_tensors=\"pt\").input_ids[0].tolist()\n",
    "    # Only the final position carries a real label; -100 marks the rest.\n",
    "    output_ids = [-100] * len(input_ids)\n",
    "    output_ids[-1] = label\n",
    "\n",
    "    return input_ids, output_ids\n",
    "\n",
    "def factual_sampler_with_context(\n",
    "    tokenizer,\n",
    "    context,\n",
    "    max_n_training_examples,\n",
    "    game=\"pricing_tag\",\n",
    "    amount=None,\n",
    "    lower_bound=None,\n",
    "    bound_width=None,\n",
    "):\n",
    "    \n",
    "    all_input_ids = []\n",
    "    all_output_ids = [] # this one does not have input ids, etc..\n",
    "    for _ in range(max_n_training_examples):\n",
    "        if \"pricing_tag\" in game:\n",
    "            input_ids, output_ids = pricing_tag_game_example_sampler_with_context(\n",
    "                tokenizer,\n",
    "                context,\n",
    "                amount,\n",
    "                lower_bound,\n",
    "                bound_width\n",
    "            )\n",
    "        elif game == \"continent_retrieval\":\n",
    "            pass\n",
    "        all_input_ids += [input_ids]\n",
    "        all_output_ids += [output_ids]\n",
    "        \n",
    "    return all_input_ids, all_output_ids\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e1a89547",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 500 factual pricing-tag examples whose instruction is prefixed with a fixed context.\n",
    "raw_prealign = factual_sampler_with_context(\n",
    "    tokenizer,\n",
    "    \"Fruitarian Frogs May Be Doing Flowers a Favor\",\n",
    "    500,\n",
    "    game=\"pricing_tag\",\n",
    ")\n",
    "prealign_dataset = Dataset.from_dict(\n",
    "    dict(input_ids=raw_prealign[0], labels=raw_prealign[1])\n",
    ").with_format(\"torch\")\n",
    "prealign_dataloader = DataLoader(prealign_dataset, batch_size=8)\n",
    "\n",
    "# Task accuracy without intervention, scored on the last position only.\n",
    "total_count, correct_count = 0, 0\n",
    "model.eval()\n",
    "with torch.no_grad():\n",
    "    for inputs in tqdm(prealign_dataloader):\n",
    "        # Move every tensor field of the batch onto the model's device.\n",
    "        inputs = {\n",
    "            key: value.to(model.device) if isinstance(value, torch.Tensor) else value\n",
    "            for key, value in inputs.items()\n",
    "        }\n",
    "\n",
    "        outputs = model(input_ids=inputs['input_ids'], labels=inputs['labels'])\n",
    "\n",
    "        actual_test_labels = inputs['labels'][:, -1]\n",
    "        pred_test_labels = outputs.logits[:, -1].argmax(dim=-1)\n",
    "        correct_labels = actual_test_labels == pred_test_labels\n",
    "\n",
    "        total_count += correct_labels.shape[0]\n",
    "        correct_count += int(correct_labels.sum())\n",
    "\n",
    "current_acc = round(correct_count / total_count, 2)\n",
    "print(f\"[WARNING: THIS NEEDS TO BE GOOD!] prealign task accuracy: {current_acc}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3a85e259",
   "metadata": {},
   "outputs": [],
   "source": [
    "def bound_alignment_sampler_with_context(\n",
    "    tokenizer,\n",
    "    base_context,\n",
    "    source_context,\n",
    "    max_n_training_examples,\n",
    "    bound_functors,\n",
    "    amount=None,\n",
    "    lower_bound=None,\n",
    "    bound_width=None,\n",
    "):\n",
    "    all_base_input_ids = []\n",
    "    all_source_input_ids = []\n",
    "    all_ctf_output_ids = [] # this one does not have input ids, etc..\n",
    "    all_intervention_ids = []\n",
    "    \n",
    "    for _ in range(max_n_training_examples):\n",
    "        bound_functor = random.choice(bound_functors)\n",
    "        base_lower_bound_sample, base_upper_bound_sample, \\\n",
    "            source_lower_bound_sample, source_upper_bound_sample, \\\n",
    "            base_amount_sample, source_amount_sample, \\\n",
    "            ctf_label, ctf_label_str = bound_functor(\n",
    "                tokenizer,\n",
    "                amount,\n",
    "                lower_bound,\n",
    "                bound_width,\n",
    "            )\n",
    "\n",
    "        base_amount_str = \"%.2f dollars\" % base_amount_sample\n",
    "        source_amount_str = \"%.2f dollars\" % source_amount_sample\n",
    "        base_lower_bound_str = \"%.2f\" % base_lower_bound_sample\n",
    "        base_upper_bound_str = \"%.2f\" % base_upper_bound_sample\n",
    "        source_lower_bound_str = \"%.2f\" % source_lower_bound_sample\n",
    "        source_upper_bound_str = \"%.2f\" % source_upper_bound_sample\n",
    "        \n",
    "        # print(f\"base: [{base_lower_bound_str}, {base_upper_bound_str}], {base_amount_str}\")\n",
    "        # print(f\"source: [{source_lower_bound_str}, {source_upper_bound_str}], {source_amount_str}\")\n",
    "        # print(f\"ctf label: {ctf_label_str}\")\n",
    "        \n",
    "        base_instruction = f\"{base_context} Please say yes only if it costs between {base_lower_bound_str} and {base_upper_bound_str} dollars, otherwise no.\"\n",
    "        source_instruction = f\"{source_context} Please say yes only if it costs between {source_lower_bound_str} and {source_upper_bound_str} dollars, otherwise no.\"\n",
    "        \n",
    "        base_alpaca_prompt = alpaca_prompt_template % (base_instruction, base_amount_str)\n",
    "        source_alpaca_prompt = alpaca_prompt_template % (source_instruction, source_amount_str)\n",
    "        \n",
    "        base_input_ids = tokenizer(base_alpaca_prompt, return_tensors=\"pt\").input_ids[0]\n",
    "        source_input_ids = tokenizer(source_alpaca_prompt, return_tensors=\"pt\").input_ids[0]\n",
    "        base_input_ids = base_input_ids.tolist()\n",
    "        source_input_ids = source_input_ids.tolist()\n",
    "        ctf_output_ids = (torch.ones(len(base_input_ids))*-100).long().tolist()\n",
    "        ctf_output_ids[-1] = ctf_label\n",
    "        intervention_id = 0 if bound_functor == bound_functors[0] else 1\n",
    "        \n",
    "        all_base_input_ids += [base_input_ids]\n",
    "        all_source_input_ids += [source_input_ids]\n",
    "        \n",
    "        all_ctf_output_ids += [ctf_output_ids]\n",
    "        all_intervention_ids += [intervention_id]\n",
    "        \n",
    "    return all_base_input_ids, all_source_input_ids, all_ctf_output_ids, all_intervention_ids"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "588f5737",
   "metadata": {},
   "outputs": [],
   "source": [
    "irrelevant_context_results = {\n",
    "    \"Pricing tag game!\" : dict(),\n",
    "    \"Fruitarian Frogs May Be Doing Flowers a Favor\" : dict(),\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6e518526",
   "metadata": {},
   "outputs": [],
   "source": [
    "context = \"Pricing tag game!\" # +6\n",
    "\n",
    "# context = \"Fruitarian Frogs May Be Doing Flowers a Favor\" # +15\n",
    "\n",
    "raw_data = bound_alignment_sampler_with_context(\n",
    "    tokenizer,\n",
    "    context,\n",
    "    context,\n",
    "    1000,\n",
    "    [\n",
    "        lower_bound_alignment_example_sampler,\n",
    "        upper_bound_alignment_example_sampler,\n",
    "    ],\n",
    "    amount=None,\n",
    "    lower_bound=None,\n",
    "    bound_width=None,\n",
    ")\n",
    "raw_test = (\n",
    "    raw_data[0], \n",
    "    raw_data[1], \n",
    "    raw_data[2],\n",
    "    raw_data[3]\n",
    ")\n",
    "test_dataset = Dataset.from_dict(\n",
    "    {\n",
    "        \"input_ids\": raw_test[0], \n",
    "        \"source_input_ids\": raw_test[1],\n",
    "        \"labels\": raw_test[2],\n",
    "        \"intervention_ids\": raw_test[3],\n",
    "    }\n",
    ").with_format(\"torch\")\n",
    "test_dataloader = DataLoader(\n",
    "    test_dataset, batch_size=128,\n",
    ")\n",
    "\n",
    "# grid search\n",
    "# [69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81]\n",
    "# [0, 5, 10, 15, 20, 25, 30]\n",
    "if context == \"Pricing tag game!\":\n",
    "    offset = 6\n",
    "elif context == \"Fruitarian Frogs May Be Doing Flowers a Favor\":\n",
    "    offset = 15\n",
    "print(f\"Adding prefix: {context}\")\n",
    "for seed in [77]:\n",
    "    for token_idx in [69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81]:\n",
    "        for layer_idx in [0, 5, 10, 15, 20, 25, 30]:\n",
    "            token_idx_re = token_idx + offset\n",
    "            end_idx_re = token_idx_re + 1\n",
    "            end_idx = token_idx + 1\n",
    "            print(f\"eval with reindexed token_idx = {token_idx_re} ({token_idx}) and layer_idx = {layer_idx}\")\n",
    "            checkpoint_state_dict = torch.load(f\"../results_alpaca-7b/alpaca-7B.task.pricing_tag_lub.seed.{seed}.intl.{layer_idx}.intr.{token_idx}.{end_idx}/pytorch-rotate-best.bin\")\n",
    "            model.model.rotate_layer.load_state_dict(\n",
    "                checkpoint_state_dict['rotate_layer']\n",
    "            )\n",
    "            model.model.intervention_boundaries.data = checkpoint_state_dict['intervention_boundaries'].data\n",
    "            model.alignment_config = {\n",
    "                'layer': layer_idx,\n",
    "                \"token_range\" : [token_idx_re, end_idx_re]\n",
    "            }\n",
    "            model.model.alignment_config = {\n",
    "                'layer': layer_idx,\n",
    "                \"token_range\" : [token_idx_re, end_idx_re]\n",
    "            }\n",
    "            \n",
    "            total_count = 0\n",
    "            correct_count = 0\n",
    "            with torch.no_grad():\n",
    "                for step, inputs in enumerate(tqdm(test_dataloader)):\n",
    "                    for k, v in inputs.items():\n",
    "                        if v is not None and isinstance(v, torch.Tensor):\n",
    "                            inputs[k] = v.to(model.device)\n",
    "\n",
    "                    # aligning forward!\n",
    "                    source_hidden_states = model(\n",
    "                        input_ids=inputs['source_input_ids'],\n",
    "                        output_rotated_hidden_states_only=True\n",
    "                    ).rotated_hidden_states\n",
    "                    outputs = model(\n",
    "                        input_ids=inputs['input_ids'],\n",
    "                        source_hidden_states=source_hidden_states,\n",
    "                        intervention_ids=inputs['intervention_ids'],\n",
    "                        labels=inputs['labels']\n",
    "                    )\n",
    "\n",
    "                    actual_test_labels = inputs['labels'][:, -1]\n",
    "                    pred_test_labels = torch.argmax(outputs.logits[:, -1], dim=-1)\n",
    "                    correct_labels = (actual_test_labels==pred_test_labels)\n",
    "\n",
    "                    total_count += len(correct_labels)\n",
    "                    correct_count += correct_labels.sum().tolist()\n",
    "\n",
    "            current_acc = round(correct_count/total_count, 2)\n",
    "            irrelevant_context_results[context][(layer_idx, token_idx_re)] = current_acc\n",
    "            print(f\"IIA accuracy: {current_acc}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9e919900",
   "metadata": {},
   "outputs": [],
   "source": [
    "# context = \"Pricing tag game!\" # +6\n",
    "\n",
    "context = \"Fruitarian Frogs May Be Doing Flowers a Favor\" # +15\n",
    "\n",
    "raw_data = bound_alignment_sampler_with_context(\n",
    "    tokenizer,\n",
    "    context,\n",
    "    context,\n",
    "    1000,\n",
    "    [\n",
    "        lower_bound_alignment_example_sampler,\n",
    "        upper_bound_alignment_example_sampler,\n",
    "    ],\n",
    "    amount=None,\n",
    "    lower_bound=None,\n",
    "    bound_width=None,\n",
    ")\n",
    "raw_test = (\n",
    "    raw_data[0], \n",
    "    raw_data[1], \n",
    "    raw_data[2],\n",
    "    raw_data[3]\n",
    ")\n",
    "test_dataset = Dataset.from_dict(\n",
    "    {\n",
    "        \"input_ids\": raw_test[0], \n",
    "        \"source_input_ids\": raw_test[1],\n",
    "        \"labels\": raw_test[2],\n",
    "        \"intervention_ids\": raw_test[3],\n",
    "    }\n",
    ").with_format(\"torch\")\n",
    "test_dataloader = DataLoader(\n",
    "    test_dataset, batch_size=128,\n",
    ")\n",
    "\n",
    "# grid search\n",
    "# [69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81]\n",
    "# [0, 5, 10, 15, 20, 25, 30]\n",
    "if context == \"Pricing tag game!\":\n",
    "    offset = 6\n",
    "elif context == \"Fruitarian Frogs May Be Doing Flowers a Favor\":\n",
    "    offset = 15\n",
    "print(f\"Adding prefix: {context}\")\n",
    "for seed in [77]:\n",
    "    for token_idx in [69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81]:\n",
    "        for layer_idx in [0, 5, 10, 15, 20, 25, 30]:\n",
    "            token_idx_re = token_idx + offset\n",
    "            end_idx_re = token_idx_re + 1\n",
    "            end_idx = token_idx + 1\n",
    "            print(f\"eval with reindexed token_idx = {token_idx_re} ({token_idx}) and layer_idx = {layer_idx}\")\n",
    "            checkpoint_state_dict = torch.load(f\"../results_alpaca-7b/alpaca-7B.task.pricing_tag_lub.seed.{seed}.intl.{layer_idx}.intr.{token_idx}.{end_idx}/pytorch-rotate-best.bin\")\n",
    "            model.model.rotate_layer.load_state_dict(\n",
    "                checkpoint_state_dict['rotate_layer']\n",
    "            )\n",
    "            model.model.intervention_boundaries.data = checkpoint_state_dict['intervention_boundaries'].data\n",
    "            model.alignment_config = {\n",
    "                'layer': layer_idx,\n",
    "                \"token_range\" : [token_idx_re, end_idx_re]\n",
    "            }\n",
    "            model.model.alignment_config = {\n",
    "                'layer': layer_idx,\n",
    "                \"token_range\" : [token_idx_re, end_idx_re]\n",
    "            }\n",
    "            \n",
    "            total_count = 0\n",
    "            correct_count = 0\n",
    "            with torch.no_grad():\n",
    "                for step, inputs in enumerate(tqdm(test_dataloader)):\n",
    "                    for k, v in inputs.items():\n",
    "                        if v is not None and isinstance(v, torch.Tensor):\n",
    "                            inputs[k] = v.to(model.device)\n",
    "\n",
    "                    # aligning forward!\n",
    "                    source_hidden_states = model(\n",
    "                        input_ids=inputs['source_input_ids'],\n",
    "                        output_rotated_hidden_states_only=True\n",
    "                    ).rotated_hidden_states\n",
    "                    outputs = model(\n",
    "                        input_ids=inputs['input_ids'],\n",
    "                        source_hidden_states=source_hidden_states,\n",
    "                        intervention_ids=inputs['intervention_ids'],\n",
    "                        labels=inputs['labels']\n",
    "                    )\n",
    "\n",
    "                    actual_test_labels = inputs['labels'][:, -1]\n",
    "                    pred_test_labels = torch.argmax(outputs.logits[:, -1], dim=-1)\n",
    "                    correct_labels = (actual_test_labels==pred_test_labels)\n",
    "\n",
    "                    total_count += len(correct_labels)\n",
    "                    correct_count += correct_labels.sum().tolist()\n",
    "\n",
    "            current_acc = round(correct_count/total_count, 2)\n",
    "            irrelevant_context_results[context][(layer_idx, token_idx_re)] = current_acc\n",
    "            print(f\"IIA accuracy: {current_acc}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7662d7b9",
   "metadata": {},
   "outputs": [],
   "source": [
    "with open('../logs/eval_results_irrelevant_context_seed_77.pkl', 'wb') as file:\n",
    "    pickle.dump(\n",
    "        irrelevant_context_results, file)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d64ecf27",
   "metadata": {},
   "source": [
    "#### Eval test the main experiments"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f41390aa",
   "metadata": {},
   "outputs": [],
   "source": [
    "main_eval_results = {\n",
    "    \"lower_bound_alignment\" : dict(),\n",
    "    \"both_bound_alignment\" : dict(),\n",
    "    \"midpoint_alignment\" : dict(),\n",
    "    \"bracket_alignment\" : dict(),\n",
    "}\n",
    "control = False\n",
    "if control:\n",
    "    original_weight = torch.empty(4096, 4096).to(torch.bfloat16).to(model.device)\n",
    "    orth_weights = torch.nn.init.orthogonal_(weight)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2476028c",
   "metadata": {},
   "outputs": [],
   "source": [
    "raw_data = bound_alignment_sampler(\n",
    "    tokenizer,\n",
    "    1000,\n",
    "    [\n",
    "        lower_bound_alignment_example_sampler,\n",
    "    ],\n",
    "    amount=None,\n",
    "    lower_bound=None,\n",
    "    bound_width=None,\n",
    ")\n",
    "raw_test = (\n",
    "    raw_data[0], \n",
    "    raw_data[1], \n",
    "    raw_data[2],\n",
    "    raw_data[3]\n",
    ")\n",
    "test_dataset = Dataset.from_dict(\n",
    "    {\n",
    "        \"input_ids\": raw_test[0], \n",
    "        \"source_input_ids\": raw_test[1],\n",
    "        \"labels\": raw_test[2],\n",
    "        \"intervention_ids\": raw_test[3],\n",
    "    }\n",
    ").with_format(\"torch\")\n",
    "test_dataloader = DataLoader(\n",
    "    test_dataset, batch_size=128,\n",
    ")\n",
    "\n",
    "# grid search\n",
    "# [69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81]\n",
    "# [0, 5, 10, 15, 20, 25, 30]\n",
    "for seed in [77]:\n",
    "    for token_idx in [69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81]:\n",
    "        for layer_idx in [0, 5, 10, 15, 20, 25, 30]:\n",
    "            end_idx = token_idx + 1\n",
    "            print(f\"eval with token_idx = {token_idx} and layer_idx = {layer_idx}\")\n",
    "            checkpoint_state_dict = torch.load(f\"../results_alpaca-7b/alpaca-7B.task.pricing_tag_lb.seed.{seed}.intl.{layer_idx}.intr.{token_idx}.{end_idx}/pytorch-rotate-best.bin\")\n",
    "            if control:\n",
    "                model.model.rotate_layer.weight.data = orth_weights\n",
    "                model.model.rotate_layer.parametrizations.weight.original.data = original_weight\n",
    "            else:\n",
    "                model.model.rotate_layer.load_state_dict(\n",
    "                    checkpoint_state_dict['rotate_layer']\n",
    "                )\n",
    "            model.model.intervention_boundaries.data = checkpoint_state_dict['intervention_boundaries'].data\n",
    "            model.alignment_config = {\n",
    "                'layer': layer_idx,\n",
    "                \"token_range\" : [token_idx, end_idx]\n",
    "            }\n",
    "            model.model.alignment_config = {\n",
    "                'layer': layer_idx,\n",
    "                \"token_range\" : [token_idx, end_idx]\n",
    "            }\n",
    "            \n",
    "            total_count = 0\n",
    "            correct_count = 0\n",
    "            with torch.no_grad():\n",
    "                for step, inputs in enumerate(tqdm(test_dataloader)):\n",
    "                    for k, v in inputs.items():\n",
    "                        if v is not None and isinstance(v, torch.Tensor):\n",
    "                            inputs[k] = v.to(model.device)\n",
    "\n",
    "                    # aligning forward!\n",
    "                    source_hidden_states = model(\n",
    "                        input_ids=inputs['source_input_ids'],\n",
    "                        output_rotated_hidden_states_only=True\n",
    "                    ).rotated_hidden_states\n",
    "                    outputs = model(\n",
    "                        input_ids=inputs['input_ids'],\n",
    "                        source_hidden_states=source_hidden_states,\n",
    "                        intervention_ids=inputs['intervention_ids'],\n",
    "                        labels=inputs['labels']\n",
    "                    )\n",
    "\n",
    "                    actual_test_labels = inputs['labels'][:, -1]\n",
    "                    pred_test_labels = torch.argmax(outputs.logits[:, -1], dim=-1)\n",
    "                    correct_labels = (actual_test_labels==pred_test_labels)\n",
    "\n",
    "                    total_count += len(correct_labels)\n",
    "                    correct_count += correct_labels.sum().tolist()\n",
    "\n",
    "            current_acc = round(correct_count/total_count, 2)\n",
    "            main_eval_results['lower_bound_alignment'][(layer_idx, token_idx)] = current_acc\n",
    "            print(f\"IIA accuracy: {current_acc}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "25060958",
   "metadata": {},
   "outputs": [],
   "source": [
    "raw_data = bound_alignment_sampler(\n",
    "    tokenizer,\n",
    "    1000,\n",
    "    [\n",
    "        lower_bound_alignment_example_sampler,\n",
    "        upper_bound_alignment_example_sampler,\n",
    "    ],\n",
    "    amount=None,\n",
    "    lower_bound=None,\n",
    "    bound_width=None,\n",
    ")\n",
    "raw_test = (\n",
    "    raw_data[0], \n",
    "    raw_data[1], \n",
    "    raw_data[2],\n",
    "    raw_data[3]\n",
    ")\n",
    "test_dataset = Dataset.from_dict(\n",
    "    {\n",
    "        \"input_ids\": raw_test[0], \n",
    "        \"source_input_ids\": raw_test[1],\n",
    "        \"labels\": raw_test[2],\n",
    "        \"intervention_ids\": raw_test[3],\n",
    "    }\n",
    ").with_format(\"torch\")\n",
    "test_dataloader = DataLoader(\n",
    "    test_dataset, batch_size=128,\n",
    ")\n",
    "\n",
    "# grid search\n",
    "# [69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81]\n",
    "# [0, 5, 10, 15, 20, 25, 30]\n",
    "for seed in [77]:\n",
    "    for token_idx in [69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81]:\n",
    "        for layer_idx in [0, 5, 10, 15, 20, 25, 30]:\n",
    "            end_idx = token_idx + 1\n",
    "            print(f\"eval with token_idx = {token_idx} and layer_idx = {layer_idx}\")\n",
    "            checkpoint_state_dict = torch.load(f\"../results_alpaca-7b/alpaca-7B.task.pricing_tag_lub.seed.{seed}.intl.{layer_idx}.intr.{token_idx}.{end_idx}/pytorch-rotate-last.bin\")\n",
    "            if control:\n",
    "                model.model.rotate_layer.weight.data = orth_weights\n",
    "                model.model.rotate_layer.parametrizations.weight.original.data = original_weight\n",
    "            else:\n",
    "                model.model.rotate_layer.load_state_dict(\n",
    "                    checkpoint_state_dict['rotate_layer']\n",
    "                )\n",
    "            model.model.intervention_boundaries.data = checkpoint_state_dict['intervention_boundaries'].data\n",
    "            model.alignment_config = {\n",
    "                'layer': layer_idx,\n",
    "                \"token_range\" : [token_idx, end_idx]\n",
    "            }\n",
    "            model.model.alignment_config = {\n",
    "                'layer': layer_idx,\n",
    "                \"token_range\" : [token_idx, end_idx]\n",
    "            }\n",
    "            \n",
    "            total_count = 0\n",
    "            correct_count = 0\n",
    "            with torch.no_grad():\n",
    "                for step, inputs in enumerate(tqdm(test_dataloader)):\n",
    "                    for k, v in inputs.items():\n",
    "                        if v is not None and isinstance(v, torch.Tensor):\n",
    "                            inputs[k] = v.to(model.device)\n",
    "\n",
    "                    # aligning forward!\n",
    "                    source_hidden_states = model(\n",
    "                        input_ids=inputs['source_input_ids'],\n",
    "                        output_rotated_hidden_states_only=True\n",
    "                    ).rotated_hidden_states\n",
    "                    outputs = model(\n",
    "                        input_ids=inputs['input_ids'],\n",
    "                        source_hidden_states=source_hidden_states,\n",
    "                        intervention_ids=inputs['intervention_ids'],\n",
    "                        labels=inputs['labels']\n",
    "                    )\n",
    "\n",
    "                    actual_test_labels = inputs['labels'][:, -1]\n",
    "                    pred_test_labels = torch.argmax(outputs.logits[:, -1], dim=-1)\n",
    "                    correct_labels = (actual_test_labels==pred_test_labels)\n",
    "\n",
    "                    total_count += len(correct_labels)\n",
    "                    correct_count += correct_labels.sum().tolist()\n",
    "\n",
    "            current_acc = round(correct_count/total_count, 2)\n",
    "            main_eval_results['both_bound_alignment'][(layer_idx, token_idx)] = current_acc\n",
    "            print(f\"IIA accuracy: {current_acc}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d43e8dd9",
   "metadata": {},
   "outputs": [],
   "source": [
    "raw_data = midpoint_alignment_sampler(\n",
    "    tokenizer,\n",
    "    1000,\n",
    ")\n",
    "raw_test = (\n",
    "    raw_data[0], \n",
    "    raw_data[1], \n",
    "    raw_data[2],\n",
    "    raw_data[3]\n",
    ")\n",
    "test_dataset = Dataset.from_dict(\n",
    "    {\n",
    "        \"input_ids\": raw_test[0], \n",
    "        \"source_input_ids\": raw_test[1],\n",
    "        \"labels\": raw_test[2],\n",
    "        \"intervention_ids\": raw_test[3],\n",
    "    }\n",
    ").with_format(\"torch\")\n",
    "test_dataloader = DataLoader(\n",
    "    test_dataset, batch_size=128,\n",
    ")\n",
    "\n",
    "# grid search\n",
    "# [69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81]\n",
    "# [0, 5, 10, 15, 20, 25, 30]\n",
    "for seed in [77]:\n",
    "    for token_idx in [69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81]:\n",
    "        for layer_idx in [0, 5, 10, 15, 20, 25, 30]:\n",
    "            end_idx = token_idx + 1\n",
    "            print(f\"eval with token_idx = {token_idx} and layer_idx = {layer_idx}\")\n",
    "            checkpoint_state_dict = torch.load(f\"../results_alpaca-7b/alpaca-7B.task.pricing_tag_mid_diff.seed.{seed}.intl.{layer_idx}.intr.{token_idx}.{end_idx}/pytorch-rotate-best.bin\")\n",
    "            if control:\n",
    "                model.model.rotate_layer.weight.data = orth_weights\n",
    "                model.model.rotate_layer.parametrizations.weight.original.data = original_weight\n",
    "            else:\n",
    "                model.model.rotate_layer.load_state_dict(\n",
    "                    checkpoint_state_dict['rotate_layer']\n",
    "                )\n",
    "            model.model.intervention_boundaries.data = checkpoint_state_dict['intervention_boundaries'].data\n",
    "            model.alignment_config = {\n",
    "                'layer': layer_idx,\n",
    "                \"token_range\" : [token_idx, end_idx]\n",
    "            }\n",
    "            model.model.alignment_config = {\n",
    "                'layer': layer_idx,\n",
    "                \"token_range\" : [token_idx, end_idx]\n",
    "            }\n",
    "            \n",
    "            total_count = 0\n",
    "            correct_count = 0\n",
    "            with torch.no_grad():\n",
    "                for step, inputs in enumerate(tqdm(test_dataloader)):\n",
    "                    for k, v in inputs.items():\n",
    "                        if v is not None and isinstance(v, torch.Tensor):\n",
    "                            inputs[k] = v.to(model.device)\n",
    "\n",
    "                    # aligning forward!\n",
    "                    source_hidden_states = model(\n",
    "                        input_ids=inputs['source_input_ids'],\n",
    "                        output_rotated_hidden_states_only=True\n",
    "                    ).rotated_hidden_states\n",
    "                    outputs = model(\n",
    "                        input_ids=inputs['input_ids'],\n",
    "                        source_hidden_states=source_hidden_states,\n",
    "                        intervention_ids=inputs['intervention_ids'],\n",
    "                        labels=inputs['labels']\n",
    "                    )\n",
    "\n",
    "                    actual_test_labels = inputs['labels'][:, -1]\n",
    "                    pred_test_labels = torch.argmax(outputs.logits[:, -1], dim=-1)\n",
    "                    correct_labels = (actual_test_labels==pred_test_labels)\n",
    "\n",
    "                    total_count += len(correct_labels)\n",
    "                    correct_count += correct_labels.sum().tolist()\n",
    "\n",
    "            current_acc = round(correct_count/total_count, 2)\n",
    "            main_eval_results['midpoint_alignment'][(layer_idx, token_idx)] = current_acc\n",
    "            print(f\"IIA accuracy: {current_acc}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e8b7033c",
   "metadata": {},
   "outputs": [],
   "source": [
    "raw_data = bracket_alignment_sampler(\n",
    "    tokenizer,\n",
    "    1000,\n",
    ")\n",
    "raw_test = (\n",
    "    raw_data[0], \n",
    "    raw_data[1], \n",
    "    raw_data[2],\n",
    "    raw_data[3]\n",
    ")\n",
    "test_dataset = Dataset.from_dict(\n",
    "    {\n",
    "        \"input_ids\": raw_test[0], \n",
    "        \"source_input_ids\": raw_test[1],\n",
    "        \"labels\": raw_test[2],\n",
    "        \"intervention_ids\": raw_test[3],\n",
    "    }\n",
    ").with_format(\"torch\")\n",
    "test_dataloader = DataLoader(\n",
    "    test_dataset, batch_size=128,\n",
    ")\n",
    "\n",
    "# grid search\n",
    "# [69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81]\n",
    "# [0, 5, 10, 15, 20, 25, 30]\n",
    "for seed in [77]:\n",
    "    for token_idx in [69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81]:\n",
    "        for layer_idx in [0, 5, 10, 15, 20, 25, 30]:\n",
    "            end_idx = token_idx + 1\n",
    "            print(f\"eval with token_idx = {token_idx} and layer_idx = {layer_idx}\")\n",
    "            checkpoint_state_dict = torch.load(f\"../results_alpaca-7b/alpaca-7B.task.pricing_tag_bracket.seed.{seed}.intl.{layer_idx}.intr.{token_idx}.{end_idx}/pytorch-rotate-best.bin\")\n",
    "            if control:\n",
    "                model.model.rotate_layer.weight.data = orth_weights\n",
    "                model.model.rotate_layer.parametrizations.weight.original.data = original_weight\n",
    "            else:\n",
    "                model.model.rotate_layer.load_state_dict(\n",
    "                    checkpoint_state_dict['rotate_layer']\n",
    "                )\n",
    "            model.model.intervention_boundaries.data = checkpoint_state_dict['intervention_boundaries'].data\n",
    "            model.alignment_config = {\n",
    "                'layer': layer_idx,\n",
    "                \"token_range\" : [token_idx, end_idx]\n",
    "            }\n",
    "            model.model.alignment_config = {\n",
    "                'layer': layer_idx,\n",
    "                \"token_range\" : [token_idx, end_idx]\n",
    "            }\n",
    "            \n",
    "            total_count = 0\n",
    "            correct_count = 0\n",
    "            with torch.no_grad():\n",
    "                for step, inputs in enumerate(tqdm(test_dataloader)):\n",
    "                    for k, v in inputs.items():\n",
    "                        if v is not None and isinstance(v, torch.Tensor):\n",
    "                            inputs[k] = v.to(model.device)\n",
    "\n",
    "                    # aligning forward!\n",
    "                    source_hidden_states = model(\n",
    "                        input_ids=inputs['source_input_ids'],\n",
    "                        output_rotated_hidden_states_only=True\n",
    "                    ).rotated_hidden_states\n",
    "                    outputs = model(\n",
    "                        input_ids=inputs['input_ids'],\n",
    "                        source_hidden_states=source_hidden_states,\n",
    "                        intervention_ids=inputs['intervention_ids'],\n",
    "                        labels=inputs['labels']\n",
    "                    )\n",
    "\n",
    "                    actual_test_labels = inputs['labels'][:, -1]\n",
    "                    pred_test_labels = torch.argmax(outputs.logits[:, -1], dim=-1)\n",
    "                    correct_labels = (actual_test_labels==pred_test_labels)\n",
    "\n",
    "                    total_count += len(correct_labels)\n",
    "                    correct_count += correct_labels.sum().tolist()\n",
    "\n",
    "            current_acc = round(correct_count/total_count, 2)\n",
    "            main_eval_results['bracket_alignment'][(layer_idx, token_idx)] = current_acc\n",
    "            print(f\"IIA accuracy: {current_acc}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "96420dc2",
   "metadata": {},
   "outputs": [],
   "source": [
    "with open('../logs/eval_main_results_seed_77.pkl', 'wb') as file:\n",
    "    pickle.dump(\n",
    "        main_eval_results, file)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ab7be009",
   "metadata": {},
   "source": [
    "#### Zero-shot Transfer between two fixed bounds DAS learning\n",
    "\n",
    "2.51 to 5.51 seems to be working with 0.95\n",
    "\n",
    "5.49 to 8.49 seems to be working with 0.94"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0d4f7b24",
   "metadata": {},
   "outputs": [],
   "source": [
    "zero_shot_results = {\n",
    "    \"original\" : dict(),\n",
    "    \"transfer\" : dict(),\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9ce610e5",
   "metadata": {},
   "outputs": [],
   "source": [
    "raw_data = bound_alignment_sampler(\n",
    "    tokenizer,\n",
    "    1000,\n",
    "    [\n",
    "        lower_bound_alignment_example_sampler,\n",
    "    ],\n",
    "    amount=None,\n",
    "    lower_bound=2.51,\n",
    "    bound_width=3.00,\n",
    ")\n",
    "raw_test = (\n",
    "    raw_data[0], \n",
    "    raw_data[1], \n",
    "    raw_data[2],\n",
    "    raw_data[3]\n",
    ")\n",
    "test_dataset = Dataset.from_dict(\n",
    "    {\n",
    "        \"input_ids\": raw_test[0], \n",
    "        \"source_input_ids\": raw_test[1],\n",
    "        \"labels\": raw_test[2],\n",
    "        \"intervention_ids\": raw_test[3],\n",
    "    }\n",
    ").with_format(\"torch\")\n",
    "test_dataloader = DataLoader(\n",
    "    test_dataset, batch_size=128,\n",
    ")\n",
    "\n",
    "# grid search\n",
    "# [69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81]\n",
    "# [0, 5, 10, 15, 20, 25, 30]\n",
    "for seed in [77]:\n",
    "    for token_idx in [69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81]:\n",
    "        for layer_idx in [0, 5, 10, 15, 20, 25, 30]:\n",
    "            end_idx = token_idx + 1\n",
    "            print(f\"eval with token_idx = {token_idx} and layer_idx = {layer_idx}\")\n",
    "            checkpoint_state_dict = torch.load(f\"../results_alpaca-7b/alpaca-7B.task.pricing_tag_fixed.seed.{seed}.intl.{layer_idx}.intr.{token_idx}.{end_idx}/pytorch-rotate-best.bin\")\n",
    "            model.model.rotate_layer.load_state_dict(\n",
    "                checkpoint_state_dict['rotate_layer']\n",
    "            )\n",
    "            model.model.intervention_boundaries.data = checkpoint_state_dict['intervention_boundaries'].data\n",
    "            model.alignment_config = {\n",
    "                'layer': layer_idx,\n",
    "                \"token_range\" : [token_idx, end_idx]\n",
    "            }\n",
    "            model.model.alignment_config = {\n",
    "                'layer': layer_idx,\n",
    "                \"token_range\" : [token_idx, end_idx]\n",
    "            }\n",
    "            \n",
    "            total_count = 0\n",
    "            correct_count = 0\n",
    "            with torch.no_grad():\n",
    "                for step, inputs in enumerate(tqdm(test_dataloader)):\n",
    "                    for k, v in inputs.items():\n",
    "                        if v is not None and isinstance(v, torch.Tensor):\n",
    "                            inputs[k] = v.to(model.device)\n",
    "\n",
    "                    # aligning forward!\n",
    "                    source_hidden_states = model(\n",
    "                        input_ids=inputs['source_input_ids'],\n",
    "                        output_rotated_hidden_states_only=True\n",
    "                    ).rotated_hidden_states\n",
    "                    outputs = model(\n",
    "                        input_ids=inputs['input_ids'],\n",
    "                        source_hidden_states=source_hidden_states,\n",
    "                        intervention_ids=inputs['intervention_ids'],\n",
    "                        labels=inputs['labels']\n",
    "                    )\n",
    "\n",
    "                    actual_test_labels = inputs['labels'][:, -1]\n",
    "                    pred_test_labels = torch.argmax(outputs.logits[:, -1], dim=-1)\n",
    "                    correct_labels = (actual_test_labels==pred_test_labels)\n",
    "\n",
    "                    total_count += len(correct_labels)\n",
    "                    correct_count += correct_labels.sum().tolist()\n",
    "\n",
    "            current_acc = round(correct_count/total_count, 2)\n",
    "            zero_shot_results['transfer'][(layer_idx, token_idx)] = current_acc\n",
    "            print(f\"IIA accuracy: {current_acc}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "32285ee8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Evaluation set: 1000 examples from bound_alignment_sampler using only the\n",
    "# lower-bound functor, with lower_bound=5.49 and bound_width=3.00 passed through.\n",
    "raw_data = bound_alignment_sampler(\n",
    "    tokenizer,\n",
    "    1000,\n",
    "    [\n",
    "        lower_bound_alignment_example_sampler,\n",
    "    ],\n",
    "    amount=None,\n",
    "    lower_bound=5.49,\n",
    "    bound_width=3.00,\n",
    ")\n",
    "raw_test = (\n",
    "    raw_data[0], \n",
    "    raw_data[1], \n",
    "    raw_data[2],\n",
    "    raw_data[3]\n",
    ")\n",
    "# Map the four sampler outputs onto the column names the eval loop reads below.\n",
    "test_dataset = Dataset.from_dict(\n",
    "    {\n",
    "        \"input_ids\": raw_test[0], \n",
    "        \"source_input_ids\": raw_test[1],\n",
    "        \"labels\": raw_test[2],\n",
    "        \"intervention_ids\": raw_test[3],\n",
    "    }\n",
    ").with_format(\"torch\")\n",
    "test_dataloader = DataLoader(\n",
    "    test_dataset, batch_size=128,\n",
    ")\n",
    "\n",
    "# grid search\n",
    "# [69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81]\n",
    "# [0, 5, 10, 15, 20, 25, 30]\n",
    "# One saved checkpoint is evaluated per (seed, token position, layer) combination.\n",
    "for seed in [77]:\n",
    "    for token_idx in [69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81]:\n",
    "        for layer_idx in [0, 5, 10, 15, 20, 25, 30]:\n",
    "            end_idx = token_idx + 1\n",
    "            print(f\"eval with token_idx = {token_idx} and layer_idx = {layer_idx}\")\n",
    "            # Restore the rotation weights and intervention boundaries saved for this grid cell.\n",
    "            checkpoint_state_dict = torch.load(f\"../results_alpaca-7b/alpaca-7B.task.pricing_tag_fixed.seed.{seed}.intl.{layer_idx}.intr.{token_idx}.{end_idx}/pytorch-rotate-best.bin\")\n",
    "            model.model.rotate_layer.load_state_dict(\n",
    "                checkpoint_state_dict['rotate_layer']\n",
    "            )\n",
    "            model.model.intervention_boundaries.data = checkpoint_state_dict['intervention_boundaries'].data\n",
    "            # The same layer / token range is written to both the wrapper and the inner model.\n",
    "            model.alignment_config = {\n",
    "                'layer': layer_idx,\n",
    "                \"token_range\" : [token_idx, end_idx]\n",
    "            }\n",
    "            model.model.alignment_config = {\n",
    "                'layer': layer_idx,\n",
    "                \"token_range\" : [token_idx, end_idx]\n",
    "            }\n",
    "            \n",
    "            total_count = 0\n",
    "            correct_count = 0\n",
    "            with torch.no_grad():\n",
    "                for step, inputs in enumerate(tqdm(test_dataloader)):\n",
    "                    # Move every tensor in the batch onto the model's device.\n",
    "                    for k, v in inputs.items():\n",
    "                        if v is not None and isinstance(v, torch.Tensor):\n",
    "                            inputs[k] = v.to(model.device)\n",
    "\n",
    "                    # aligning forward!\n",
    "                    # Pass 1 only collects rotated hidden states from the source inputs;\n",
    "                    # pass 2 runs the base inputs with those states supplied to the model.\n",
    "                    source_hidden_states = model(\n",
    "                        input_ids=inputs['source_input_ids'],\n",
    "                        output_rotated_hidden_states_only=True\n",
    "                    ).rotated_hidden_states\n",
    "                    outputs = model(\n",
    "                        input_ids=inputs['input_ids'],\n",
    "                        source_hidden_states=source_hidden_states,\n",
    "                        intervention_ids=inputs['intervention_ids'],\n",
    "                        labels=inputs['labels']\n",
    "                    )\n",
    "\n",
    "                    # Only the last position is scored: label id vs. argmax of the logits there.\n",
    "                    actual_test_labels = inputs['labels'][:, -1]\n",
    "                    pred_test_labels = torch.argmax(outputs.logits[:, -1], dim=-1)\n",
    "                    correct_labels = (actual_test_labels==pred_test_labels)\n",
    "\n",
    "                    total_count += len(correct_labels)\n",
    "                    correct_count += correct_labels.sum().tolist()\n",
    "\n",
    "            # Accuracy is rounded to two decimals and stored under (layer_idx, token_idx);\n",
    "            # zero_shot_results must already exist with an 'original' entry.\n",
    "            current_acc = round(correct_count/total_count, 2)\n",
    "            zero_shot_results['original'][(layer_idx, token_idx)] = current_acc\n",
    "            print(f\"IIA accuracy: {current_acc}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f076b861",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Persist the zero-shot accuracy grids so they can be reloaded without re-running the sweep.\n",
    "with open('../logs/eval_zero_shot_results_seed_77.pkl', 'wb') as file:\n",
    "    pickle.dump(zero_shot_results, file)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1a625b62",
   "metadata": {},
   "source": [
    "#### Zero-shot sibling instruction transfer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f6f9e04b",
   "metadata": {},
   "outputs": [],
   "source": [
    "def pricing_tag_game_example_sampler_with_different_return(\n",
    "    tokenizer,\n",
    "    new_labels,\n",
    "    amount,\n",
    "    lower_bound,\n",
    "    bound_width,\n",
    "):\n",
    "    \"\"\"Sample one pricing-tag prompt that answers with `new_labels` instead of Yes/No.\n",
    "\n",
    "    Bounds and amount come from `pricing_tag_game_config_sampler`. The label is\n",
    "    `new_labels[0]` when the sampled amount lies within the bounds as printed\n",
    "    (two decimals), otherwise `new_labels[1]`. Note that the amount itself is\n",
    "    compared before it is formatted to two decimals.\n",
    "\n",
    "    Returns `(input_ids, output_ids)` as Python lists; `output_ids` is -100\n",
    "    everywhere except the last position, which holds the label token id.\n",
    "    \"\"\"\n",
    "    sampled_lower, sampled_upper, amount_sample = pricing_tag_game_config_sampler(\n",
    "        amount, lower_bound, bound_width\n",
    "    )\n",
    "    lower_bound_str = \"%.2f\" % sampled_lower\n",
    "    upper_bound_str = \"%.2f\" % sampled_upper\n",
    "\n",
    "    # Membership is decided against the bounds exactly as they appear in the prompt.\n",
    "    inside_bracket = (\n",
    "        amount_sample >= float(lower_bound_str)\n",
    "        and amount_sample <= float(upper_bound_str)\n",
    "    )\n",
    "    answer_word = new_labels[0] if inside_bracket else new_labels[1]\n",
    "    label = tokenizer.convert_tokens_to_ids(answer_word)\n",
    "\n",
    "    amount_str = \"%.2f dollars\" % amount_sample\n",
    "    instruction = f\"Please say {new_labels[0]} only if it costs between {lower_bound_str} and {upper_bound_str} dollars, otherwise {new_labels[1]}.\"\n",
    "    alpaca_prompt = alpaca_prompt_template % (instruction, amount_str)\n",
    "\n",
    "    input_ids = tokenizer(alpaca_prompt, return_tensors=\"pt\").input_ids[0].tolist()\n",
    "    # Every position is ignored (-100) except the final one, which carries the label.\n",
    "    output_ids = [-100] * len(input_ids)\n",
    "    output_ids[-1] = label\n",
    "\n",
    "    return input_ids, output_ids\n",
    "\n",
    "def factual_sampler_with_different_return(\n",
    "    tokenizer,\n",
    "    new_labels,\n",
    "    max_n_training_examples,\n",
    "    game=\"pricing_tag\",\n",
    "    amount=None,\n",
    "    lower_bound=None,\n",
    "    bound_width=None,\n",
    "):\n",
    "    \"\"\"Sample factual (no-intervention) examples answered with `new_labels`.\n",
    "\n",
    "    Args:\n",
    "        tokenizer: tokenizer forwarded to the per-example sampler.\n",
    "        new_labels: two answer words; index 0 is the in-range answer, index 1 the other.\n",
    "        max_n_training_examples: number of examples to draw.\n",
    "        game: task name; only names containing \"pricing_tag\" are implemented.\n",
    "        amount, lower_bound, bound_width: forwarded unchanged to the per-example sampler.\n",
    "\n",
    "    Returns:\n",
    "        `(all_input_ids, all_output_ids)`, two parallel lists of token-id lists.\n",
    "\n",
    "    Raises:\n",
    "        NotImplementedError: if `game` is not a pricing-tag game. The previous\n",
    "            code fell through (`pass`) and then failed on an unbound `input_ids`.\n",
    "    \"\"\"\n",
    "    if \"pricing_tag\" not in game:\n",
    "        # Fail fast with a readable message instead of an UnboundLocalError mid-loop.\n",
    "        raise NotImplementedError(\n",
    "            f\"game {game!r} is not supported; only pricing_tag games are implemented\"\n",
    "        )\n",
    "\n",
    "    all_input_ids = []\n",
    "    all_output_ids = [] # this one does not have input ids, etc..\n",
    "    for _ in range(max_n_training_examples):\n",
    "        input_ids, output_ids = pricing_tag_game_example_sampler_with_different_return(\n",
    "            tokenizer,\n",
    "            new_labels,\n",
    "            amount,\n",
    "            lower_bound,\n",
    "            bound_width\n",
    "        )\n",
    "        all_input_ids.append(input_ids)\n",
    "        all_output_ids.append(output_ids)\n",
    "\n",
    "    return all_input_ids, all_output_ids\n",
    "\n",
    "def bound_alignment_sampler_with_different_return(\n",
    "    tokenizer,\n",
    "    new_labels,\n",
    "    max_n_training_examples,\n",
    "    bound_functors,\n",
    "    amount=None,\n",
    "    lower_bound=None,\n",
    "    bound_width=None,\n",
    "):\n",
    "    \"\"\"Sample base/source prompt pairs whose answer words are taken from `new_labels`.\n",
    "\n",
    "    Each iteration draws one functor from `bound_functors` at random, asks it for\n",
    "    bounds, amounts and a counterfactual label, and replaces a Yes/No label with\n",
    "    `new_labels[0]`/`new_labels[1]`. Both prompts are rendered through\n",
    "    `alpaca_prompt_template` and tokenized.\n",
    "\n",
    "    Returns four parallel lists: base input ids, source input ids, label rows\n",
    "    (-100 everywhere except the last position, which holds the label id), and\n",
    "    intervention ids (0 when the first functor was drawn, 1 otherwise).\n",
    "    \"\"\"\n",
    "    all_base_input_ids, all_source_input_ids = [], []\n",
    "    all_ctf_output_ids, all_intervention_ids = [], []\n",
    "\n",
    "    for _ in range(max_n_training_examples):\n",
    "        bound_functor = random.choice(bound_functors)\n",
    "        sampled = bound_functor(tokenizer, amount, lower_bound, bound_width)\n",
    "        (\n",
    "            base_lower_bound_sample, base_upper_bound_sample,\n",
    "            source_lower_bound_sample, source_upper_bound_sample,\n",
    "            base_amount_sample, source_amount_sample,\n",
    "            ctf_label, ctf_label_str,\n",
    "        ) = sampled\n",
    "\n",
    "        # Re-express the functor's Yes/No verdict in the caller's vocabulary;\n",
    "        # any other label string is left untouched.\n",
    "        if ctf_label_str in (\"Yes\", \"No\"):\n",
    "            ctf_label_str = new_labels[0] if ctf_label_str == \"Yes\" else new_labels[1]\n",
    "            ctf_label = tokenizer.convert_tokens_to_ids(ctf_label_str)\n",
    "\n",
    "        base_lower_bound_str = \"%.2f\" % base_lower_bound_sample\n",
    "        base_upper_bound_str = \"%.2f\" % base_upper_bound_sample\n",
    "        source_lower_bound_str = \"%.2f\" % source_lower_bound_sample\n",
    "        source_upper_bound_str = \"%.2f\" % source_upper_bound_sample\n",
    "        base_amount_str = \"%.2f dollars\" % base_amount_sample\n",
    "        source_amount_str = \"%.2f dollars\" % source_amount_sample\n",
    "\n",
    "        base_instruction = f\"Please say {new_labels[0]} only if it costs between {base_lower_bound_str} and {base_upper_bound_str} dollars, otherwise {new_labels[1]}.\"\n",
    "        source_instruction = f\"Please say {new_labels[0]} only if it costs between {source_lower_bound_str} and {source_upper_bound_str} dollars, otherwise {new_labels[1]}.\"\n",
    "\n",
    "        base_alpaca_prompt = alpaca_prompt_template % (base_instruction, base_amount_str)\n",
    "        source_alpaca_prompt = alpaca_prompt_template % (source_instruction, source_amount_str)\n",
    "\n",
    "        base_input_ids = tokenizer(base_alpaca_prompt, return_tensors=\"pt\").input_ids[0].tolist()\n",
    "        source_input_ids = tokenizer(source_alpaca_prompt, return_tensors=\"pt\").input_ids[0].tolist()\n",
    "\n",
    "        # Only the final base position is supervised; everything else is -100.\n",
    "        ctf_output_ids = [-100] * len(base_input_ids)\n",
    "        ctf_output_ids[-1] = ctf_label\n",
    "\n",
    "        # 0 marks a draw of the first functor, 1 any other functor in the list.\n",
    "        intervention_id = 0 if bound_functor == bound_functors[0] else 1\n",
    "\n",
    "        all_base_input_ids.append(base_input_ids)\n",
    "        all_source_input_ids.append(source_input_ids)\n",
    "        all_ctf_output_ids.append(ctf_output_ids)\n",
    "        all_intervention_ids.append(intervention_id)\n",
    "\n",
    "    return all_base_input_ids, all_source_input_ids, all_ctf_output_ids, all_intervention_ids"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b1c7ae46",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 500 factual pricing-tag prompts answered with True/False, batched 8 at a time.\n",
    "raw_prealign = factual_sampler_with_different_return(\n",
    "    tokenizer, ['True', 'False'], 500, game=\"pricing_tag\",\n",
    ")\n",
    "prealign_dataset = Dataset.from_dict(\n",
    "    dict(input_ids=raw_prealign[0], labels=raw_prealign[1])\n",
    ").with_format(\"torch\")\n",
    "prealign_dataloader = DataLoader(prealign_dataset, batch_size=8)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8ae63116",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Task accuracy of the model on the True/False prompts, with no intervention applied.\n",
    "total_count, correct_count = 0, 0\n",
    "model.eval()\n",
    "with torch.no_grad():\n",
    "    for step, inputs in enumerate(tqdm(prealign_dataloader)):\n",
    "        # Move every tensor in the batch onto the model's device.\n",
    "        for k, v in inputs.items():\n",
    "            if isinstance(v, torch.Tensor):\n",
    "                inputs[k] = v.to(model.device)\n",
    "\n",
    "        outputs = model(input_ids=inputs['input_ids'], labels=inputs['labels'])\n",
    "\n",
    "        # Score the last position only: label id vs. argmax of the logits there.\n",
    "        actual_test_labels = inputs['labels'][:, -1]\n",
    "        pred_test_labels = outputs.logits[:, -1].argmax(dim=-1)\n",
    "        correct_labels = (actual_test_labels == pred_test_labels)\n",
    "\n",
    "        total_count += correct_labels.shape[0]\n",
    "        correct_count += correct_labels.sum().item()\n",
    "current_acc = round(correct_count/total_count, 2)\n",
    "print(f\"[WARNING: THIS NEEDS TO BE GOOD!] prealign task accuracy: {current_acc}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6c79204b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Accuracy per (layer_idx, token_idx), filled in by the sweep in the next cell.\n",
    "different_return_results = dict()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "39187b32",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Evaluation set: 1000 base/source pairs answered with True/False, drawn from both the\n",
    "# lower-bound and upper-bound functors with no amount or bound constraints.\n",
    "raw_data = bound_alignment_sampler_with_different_return(\n",
    "    tokenizer,\n",
    "    [\"True\", \"False\"],\n",
    "    1000,\n",
    "    [\n",
    "        lower_bound_alignment_example_sampler,\n",
    "        upper_bound_alignment_example_sampler,\n",
    "    ],\n",
    "    amount=None,\n",
    "    lower_bound=None,\n",
    "    bound_width=None,\n",
    ")\n",
    "raw_test = (\n",
    "    raw_data[0], \n",
    "    raw_data[1], \n",
    "    raw_data[2],\n",
    "    raw_data[3]\n",
    ")\n",
    "# Sampler outputs in order: base input ids, source input ids, label rows, intervention ids.\n",
    "test_dataset = Dataset.from_dict(\n",
    "    {\n",
    "        \"input_ids\": raw_test[0], \n",
    "        \"source_input_ids\": raw_test[1],\n",
    "        \"labels\": raw_test[2],\n",
    "        \"intervention_ids\": raw_test[3],\n",
    "    }\n",
    ").with_format(\"torch\")\n",
    "test_dataloader = DataLoader(\n",
    "    test_dataset, batch_size=128,\n",
    ")\n",
    "\n",
    "# grid search\n",
    "# [69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81]\n",
    "# [0, 5, 10, 15, 20, 25, 30]\n",
    "# One saved pricing_tag_lub checkpoint is evaluated per (seed, token position, layer).\n",
    "for seed in [77]:\n",
    "    for token_idx in [69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81]:\n",
    "        for layer_idx in [0, 5, 10, 15, 20, 25, 30]:\n",
    "            end_idx = token_idx + 1\n",
    "            print(f\"eval with token_idx = {token_idx} and layer_idx = {layer_idx}\")\n",
    "            # Restore the rotation weights and intervention boundaries saved for this grid cell.\n",
    "            checkpoint_state_dict = torch.load(f\"../results_alpaca-7b/alpaca-7B.task.pricing_tag_lub.seed.{seed}.intl.{layer_idx}.intr.{token_idx}.{end_idx}/pytorch-rotate-best.bin\")\n",
    "            model.model.rotate_layer.load_state_dict(\n",
    "                checkpoint_state_dict['rotate_layer']\n",
    "            )\n",
    "            model.model.intervention_boundaries.data = checkpoint_state_dict['intervention_boundaries'].data\n",
    "            # The same layer / token range is written to both the wrapper and the inner model.\n",
    "            model.alignment_config = {\n",
    "                'layer': layer_idx,\n",
    "                \"token_range\" : [token_idx, end_idx]\n",
    "            }\n",
    "            model.model.alignment_config = {\n",
    "                'layer': layer_idx,\n",
    "                \"token_range\" : [token_idx, end_idx]\n",
    "            }\n",
    "            \n",
    "            total_count = 0\n",
    "            correct_count = 0\n",
    "            with torch.no_grad():\n",
    "                for step, inputs in enumerate(tqdm(test_dataloader)):\n",
    "                    # Move every tensor in the batch onto the model's device.\n",
    "                    for k, v in inputs.items():\n",
    "                        if v is not None and isinstance(v, torch.Tensor):\n",
    "                            inputs[k] = v.to(model.device)\n",
    "\n",
    "                    # aligning forward!\n",
    "                    # Pass 1 only collects rotated hidden states from the source inputs;\n",
    "                    # pass 2 runs the base inputs with those states supplied to the model.\n",
    "                    source_hidden_states = model(\n",
    "                        input_ids=inputs['source_input_ids'],\n",
    "                        output_rotated_hidden_states_only=True\n",
    "                    ).rotated_hidden_states\n",
    "                    outputs = model(\n",
    "                        input_ids=inputs['input_ids'],\n",
    "                        source_hidden_states=source_hidden_states,\n",
    "                        intervention_ids=inputs['intervention_ids'],\n",
    "                        labels=inputs['labels']\n",
    "                    )\n",
    "\n",
    "                    # Only the last position is scored: label id vs. argmax of the logits there.\n",
    "                    actual_test_labels = inputs['labels'][:, -1]\n",
    "                    pred_test_labels = torch.argmax(outputs.logits[:, -1], dim=-1)\n",
    "                    correct_labels = (actual_test_labels==pred_test_labels)\n",
    "\n",
    "                    total_count += len(correct_labels)\n",
    "                    correct_count += correct_labels.sum().tolist()\n",
    "\n",
    "            # Accuracy is rounded to two decimals and stored under (layer_idx, token_idx).\n",
    "            current_acc = round(correct_count/total_count, 2)\n",
    "            different_return_results[(layer_idx, token_idx)] = current_acc\n",
    "            print(f\"IIA accuracy: {current_acc}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a5e53799",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Persist the True/False transfer accuracies for later plotting or analysis.\n",
    "with open('../logs/eval_different_return_results_seed_77.pkl', 'wb') as file:\n",
    "    pickle.dump(different_return_results, file)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "45b23175",
   "metadata": {},
   "source": [
    "#### Zero-shot between type DAS"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cd6d0981",
   "metadata": {},
   "outputs": [],
   "source": [
    "def lower_bound_alignment_example_reversed_sampler(\n",
    "    tokenizer,\n",
    "    amount=None,\n",
    "    lower_bound=None,\n",
    "    bound_width=None\n",
    "):\n",
    "    \"\"\"Sample one base/source configuration for the reversed lower-bound setting.\n",
    "\n",
    "    Two independent bound pairs come from `pricing_tag_game_config_sampler`.\n",
    "    A Yes/No counterfactual label is chosen uniformly, then a (base, source)\n",
    "    region pair consistent with that label is chosen and both amounts are drawn\n",
    "    with `sample_with_region`.\n",
    "    NOTE(review): region codes 1/2/3 are defined by `sample_with_region`; their\n",
    "    meaning is not visible here - confirm before editing the tables.\n",
    "\n",
    "    Returns an 8-tuple: base lower/upper bound, source lower/upper bound,\n",
    "    base amount, source amount, label token id, label string.\n",
    "    \"\"\"\n",
    "    base_lower_bound_sample, base_upper_bound_sample, _ = pricing_tag_game_config_sampler(\n",
    "        amount, lower_bound, bound_width\n",
    "    )\n",
    "    source_lower_bound_sample, source_upper_bound_sample, _ = pricing_tag_game_config_sampler(\n",
    "        amount, lower_bound, bound_width\n",
    "    )\n",
    "\n",
    "    # (base_region, source_region) pairs that produce each counterfactual label.\n",
    "    region_pairs_by_label = {\n",
    "        \"Yes\": [[1,1], [1,2], [2,2]],\n",
    "        \"No\": [[1,3], [2,1], [2,3], [3,1], [3,2], [3,3]],\n",
    "    }\n",
    "    ctf_label_str = random.choice([\"Yes\", \"No\"])\n",
    "    ctf_label = tokenizer.convert_tokens_to_ids(ctf_label_str)\n",
    "    base_region, source_region = random.choice(region_pairs_by_label[ctf_label_str])\n",
    "\n",
    "    base_amount_sample = sample_with_region(\n",
    "        base_region, base_lower_bound_sample, base_upper_bound_sample)\n",
    "    source_amount_sample = sample_with_region(\n",
    "        source_region, source_lower_bound_sample, source_upper_bound_sample)\n",
    "\n",
    "    return (\n",
    "        base_lower_bound_sample, base_upper_bound_sample,\n",
    "        source_lower_bound_sample, source_upper_bound_sample,\n",
    "        base_amount_sample, source_amount_sample, ctf_label, ctf_label_str,\n",
    "    )\n",
    "    \n",
    "def upper_bound_alignment_example_reversed_sampler(\n",
    "    tokenizer,\n",
    "    amount=None,\n",
    "    lower_bound=None,\n",
    "    bound_width=None\n",
    "):\n",
    "    \"\"\"Sample one base/source configuration for the reversed upper-bound setting.\n",
    "\n",
    "    Two independent bound pairs come from `pricing_tag_game_config_sampler`; the\n",
    "    amounts it returns are discarded because both amounts are re-drawn with\n",
    "    `sample_with_region` from a (base, source) region pair that matches a\n",
    "    uniformly chosen Yes/No counterfactual label.\n",
    "    NOTE(review): region codes 1/2/3 are defined by `sample_with_region`; their\n",
    "    meaning is not visible here - confirm before editing the tables.\n",
    "\n",
    "    Returns an 8-tuple: base lower/upper bound, source lower/upper bound,\n",
    "    base amount, source amount, label token id, label string.\n",
    "    \"\"\"\n",
    "    base_lower_bound_sample, base_upper_bound_sample, _ = pricing_tag_game_config_sampler(\n",
    "        amount, lower_bound, bound_width\n",
    "    )\n",
    "    source_lower_bound_sample, source_upper_bound_sample, _ = pricing_tag_game_config_sampler(\n",
    "        amount, lower_bound, bound_width\n",
    "    )\n",
    "\n",
    "    # (base_region, source_region) pairs that produce each counterfactual label.\n",
    "    region_pairs_by_label = {\n",
    "        \"Yes\": [[3,3], [3,2], [2,2]],\n",
    "        \"No\": [[1,1], [1,2], [1,3], [2,1], [2,3], [3,1]],\n",
    "    }\n",
    "    ctf_label_str = random.choice([\"Yes\", \"No\"])\n",
    "    ctf_label = tokenizer.convert_tokens_to_ids(ctf_label_str)\n",
    "    base_region, source_region = random.choice(region_pairs_by_label[ctf_label_str])\n",
    "\n",
    "    base_amount_sample = sample_with_region(\n",
    "        base_region, base_lower_bound_sample, base_upper_bound_sample)\n",
    "    source_amount_sample = sample_with_region(\n",
    "        source_region, source_lower_bound_sample, source_upper_bound_sample)\n",
    "\n",
    "    return (\n",
    "        base_lower_bound_sample, base_upper_bound_sample,\n",
    "        source_lower_bound_sample, source_upper_bound_sample,\n",
    "        base_amount_sample, source_amount_sample, ctf_label, ctf_label_str,\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "19dff34a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Accuracy per (layer_idx, token_idx), filled in by the sweep in the next cell.\n",
    "type_transfer_results = dict()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8f2b963b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Evaluation set: 1000 examples drawn from the two *reversed* functors defined above,\n",
    "# with no amount or bound constraints.\n",
    "raw_data = bound_alignment_sampler(\n",
    "    tokenizer,\n",
    "    1000,\n",
    "    [\n",
    "        lower_bound_alignment_example_reversed_sampler,\n",
    "        upper_bound_alignment_example_reversed_sampler,\n",
    "    ],\n",
    "    amount=None,\n",
    "    lower_bound=None,\n",
    "    bound_width=None,\n",
    ")\n",
    "raw_test = (\n",
    "    raw_data[0], \n",
    "    raw_data[1], \n",
    "    raw_data[2],\n",
    "    raw_data[3]\n",
    ")\n",
    "# Map the four sampler outputs onto the column names the eval loop reads below.\n",
    "test_dataset = Dataset.from_dict(\n",
    "    {\n",
    "        \"input_ids\": raw_test[0], \n",
    "        \"source_input_ids\": raw_test[1],\n",
    "        \"labels\": raw_test[2],\n",
    "        \"intervention_ids\": raw_test[3],\n",
    "    }\n",
    ").with_format(\"torch\")\n",
    "test_dataloader = DataLoader(\n",
    "    test_dataset, batch_size=8,\n",
    ")\n",
    "\n",
    "# grid search\n",
    "# [69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81]\n",
    "# [0, 5, 10, 15, 20, 25, 30]\n",
    "# Only a subset of the grid is run here: seed 42, token 81, layers 15-30, and the\n",
    "# *last* (not best) pricing_tag_lub checkpoint is loaded.\n",
    "for seed in [42]:\n",
    "    for token_idx in [81]:\n",
    "        for layer_idx in [15, 20, 25, 30]:\n",
    "            end_idx = token_idx + 1\n",
    "            print(f\"eval with token_idx = {token_idx} and layer_idx = {layer_idx}\")\n",
    "            checkpoint_state_dict = torch.load(f\"../results_alpaca-7b/alpaca-7B.task.pricing_tag_lub.seed.{seed}.intl.{layer_idx}.intr.{token_idx}.{end_idx}/pytorch-rotate-last.bin\")\n",
    "            model.model.rotate_layer.load_state_dict(\n",
    "                checkpoint_state_dict['rotate_layer']\n",
    "            )\n",
    "            model.model.intervention_boundaries.data = checkpoint_state_dict['intervention_boundaries'].data\n",
    "            # Turn the first boundary (clamped to [1e-3, 1]) into a slice width: a fraction of\n",
    "            # half the searchable embedding size. The second slice has the same width and\n",
    "            # starts where the first one ends.\n",
    "            # NOTE(review): presumably these are the two aligned subspaces - confirm against the model code.\n",
    "            intervention_boundaries = model.model.intervention_boundaries.data\n",
    "            intervention_boundaries = torch.clamp(intervention_boundaries, 1e-3, 1)\n",
    "            start_idx_1 = 0\n",
    "            end_idx_1 = int((model.model.searchable_n_embd//2) * intervention_boundaries[0])\n",
    "            start_idx_2 = end_idx_1\n",
    "            end_idx_2 = end_idx_1 * 2\n",
    "            # The same layer / token range is written to both the wrapper and the inner model.\n",
    "            model.alignment_config = {\n",
    "                'layer': layer_idx,\n",
    "                \"token_range\" : [token_idx, end_idx]\n",
    "            }\n",
    "            model.model.alignment_config = {\n",
    "                'layer': layer_idx,\n",
    "                \"token_range\" : [token_idx, end_idx]\n",
    "            }\n",
    "            \n",
    "            total_count = 0\n",
    "            correct_count = 0\n",
    "            with torch.no_grad():\n",
    "                for step, inputs in enumerate(tqdm(test_dataloader)):\n",
    "                    # Move every tensor in the batch onto the model's device.\n",
    "                    for k, v in inputs.items():\n",
    "                        if v is not None and isinstance(v, torch.Tensor):\n",
    "                            inputs[k] = v.to(model.device)\n",
    "\n",
    "                    # aligning forward!\n",
    "                    source_hidden_states = model(\n",
    "                        input_ids=inputs['source_input_ids'],\n",
    "                        output_rotated_hidden_states_only=True\n",
    "                    ).rotated_hidden_states\n",
    "                    \n",
    "                    # do a hard snapped swap of aligned reprs.\n",
    "                    # Both slices are cloned first so the in-place writes do not read\n",
    "                    # already-overwritten values; the swap is along dim 1.\n",
    "                    chunk_1 = source_hidden_states[:,start_idx_1:end_idx_1].clone()\n",
    "                    chunk_2 = source_hidden_states[:,start_idx_2:end_idx_2].clone()\n",
    "                    source_hidden_states[:,start_idx_2:end_idx_2] = chunk_1\n",
    "                    source_hidden_states[:,start_idx_1:end_idx_1] = chunk_2\n",
    "                    \n",
    "                    outputs = model(\n",
    "                        input_ids=inputs['input_ids'],\n",
    "                        source_hidden_states=source_hidden_states,\n",
    "                        intervention_ids=inputs['intervention_ids'],\n",
    "                        labels=inputs['labels']\n",
    "                    )\n",
    "\n",
    "                    # Only the last position is scored: label id vs. argmax of the logits there.\n",
    "                    actual_test_labels = inputs['labels'][:, -1]\n",
    "                    pred_test_labels = torch.argmax(outputs.logits[:, -1], dim=-1)\n",
    "                    correct_labels = (actual_test_labels==pred_test_labels)\n",
    "\n",
    "                    total_count += len(correct_labels)\n",
    "                    correct_count += correct_labels.sum().tolist()\n",
    "\n",
    "            # Accuracy is rounded to two decimals and stored under (layer_idx, token_idx).\n",
    "            current_acc = round(correct_count/total_count, 2)\n",
    "            type_transfer_results[(layer_idx, token_idx)] = current_acc\n",
    "            print(f\"IIA accuracy: {current_acc}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a8438869",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Persist the between-type transfer accuracies for later plotting or analysis.\n",
    "with open('../logs/eval_type_transfer_results.pkl', 'wb') as file:\n",
    "    pickle.dump(type_transfer_results, file)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4c9dbe50",
   "metadata": {},
   "source": [
    "#### Dummy CLS"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8c3b71a2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Label balance for 10000 lower-bound examples with unconstrained bounds.\n",
    "raw_data = bound_alignment_sampler(\n",
    "    tokenizer,\n",
    "    10000,\n",
    "    [lower_bound_alignment_example_sampler],\n",
    "    amount=None,\n",
    "    lower_bound=None,\n",
    "    bound_width=None,\n",
    ")\n",
    "raw_test = (raw_data[0], raw_data[1], raw_data[2], raw_data[3])\n",
    "# Count the value stored at the last position of each row in raw_data[2].\n",
    "labels = [example[-1] for example in raw_data[2]]\n",
    "from collections import Counter\n",
    "counter = Counter(labels)\n",
    "print(counter)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0e198129",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Label balance for 10000 examples mixing the lower- and upper-bound functors.\n",
    "raw_data = bound_alignment_sampler(\n",
    "    tokenizer,\n",
    "    10000,\n",
    "    [lower_bound_alignment_example_sampler, upper_bound_alignment_example_sampler],\n",
    "    amount=None,\n",
    "    lower_bound=None,\n",
    "    bound_width=None,\n",
    ")\n",
    "raw_test = (raw_data[0], raw_data[1], raw_data[2], raw_data[3])\n",
    "# Count the value stored at the last position of each row in raw_data[2].\n",
    "labels = [example[-1] for example in raw_data[2]]\n",
    "counter = Counter(labels)\n",
    "print(counter)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ca876158",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Label balance for 10000 examples from midpoint_alignment_sampler.\n",
    "raw_data = midpoint_alignment_sampler(tokenizer, 10000)\n",
    "raw_test = (raw_data[0], raw_data[1], raw_data[2], raw_data[3])\n",
    "# Count the value stored at the last position of each row in raw_data[2].\n",
    "labels = [example[-1] for example in raw_data[2]]\n",
    "counter = Counter(labels)\n",
    "print(counter)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8e8d6676",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Label balance for 10000 examples from bracket_alignment_sampler.\n",
    "raw_data = bracket_alignment_sampler(tokenizer, 10000)\n",
    "raw_test = (raw_data[0], raw_data[1], raw_data[2], raw_data[3])\n",
    "# Count the value stored at the last position of each row in raw_data[2].\n",
    "labels = [example[-1] for example in raw_data[2]]\n",
    "counter = Counter(labels)\n",
    "print(counter)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f9181455",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Label balance for 10000 lower-bound examples with lower_bound=2.51, bound_width=3.00.\n",
    "raw_data = bound_alignment_sampler(\n",
    "    tokenizer,\n",
    "    10000,\n",
    "    [lower_bound_alignment_example_sampler],\n",
    "    amount=None,\n",
    "    lower_bound=2.51,\n",
    "    bound_width=3.00,\n",
    ")\n",
    "# Count the value stored at the last position of each row in raw_data[2].\n",
    "labels = [example[-1] for example in raw_data[2]]\n",
    "counter = Counter(labels)\n",
    "print(counter)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "abd83825",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Label balance for 10000 lower-bound examples with lower_bound=5.49, bound_width=3.00.\n",
    "raw_data = bound_alignment_sampler(\n",
    "    tokenizer,\n",
    "    10000,\n",
    "    [lower_bound_alignment_example_sampler],\n",
    "    amount=None,\n",
    "    lower_bound=5.49,\n",
    "    bound_width=3.00,\n",
    ")\n",
    "# Count the value stored at the last position of each row in raw_data[2].\n",
    "labels = [example[-1] for example in raw_data[2]]\n",
    "counter = Counter(labels)\n",
    "print(counter)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "846ad200",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
