{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "from collections import defaultdict, OrderedDict\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import random\n",
    "import re\n",
    "from tqdm import tqdm\n",
    "import torch\n",
    "from transformers import LlamaTokenizer, LlamaForCausalLM\n",
    "import os\n",
    "import pickle\n",
    "\n",
    "from demonstrations import demonstrations\n",
    "\n",
    "os.environ[\"REQUESTS_CA_BUNDLE\"] = \"...\"\n",
    "os.environ[\"SSL_CERT_FILE\"] = \"...\"\n",
    "\n",
    "import random\n",
    "import json\n",
    "\n",
    "\n",
    "device = \"cuda:0\"\n",
    "STORAGE_DIR = \"cache/\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "#MY_TOKEN = "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "55f92fe8fdf44068ae4c014c5bead6fd",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "model = LlamaForCausalLM.from_pretrained(\"meta-llama/Llama-2-7b-hf\",\n",
    "                                         cache_dir=\"....\",\n",
    "                                         torch_dtype=torch.float16#,\n",
    "                                       #  token=MY_TOKEN\n",
    "                                        )\n",
    "tokenizer = LlamaTokenizer.from_pretrained(\"meta-llama/Llama-2-7b-hf\",\n",
    "                                            cache_dir=\"....\"\n",
    "                                     #    token=MY_TOKEN\n",
    "                                        )\n",
    "device = \"cuda:0\"\n",
    "model = model.to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 55,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "LlamaConfig {\n",
       "  \"_name_or_path\": \"meta-llama/Llama-2-7b-hf\",\n",
       "  \"architectures\": [\n",
       "    \"LlamaForCausalLM\"\n",
       "  ],\n",
       "  \"attention_bias\": false,\n",
       "  \"attention_dropout\": 0.0,\n",
       "  \"bos_token_id\": 1,\n",
       "  \"eos_token_id\": 2,\n",
       "  \"hidden_act\": \"silu\",\n",
       "  \"hidden_size\": 4096,\n",
       "  \"initializer_range\": 0.02,\n",
       "  \"intermediate_size\": 11008,\n",
       "  \"max_position_embeddings\": 4096,\n",
       "  \"model_type\": \"llama\",\n",
       "  \"num_attention_heads\": 32,\n",
       "  \"num_hidden_layers\": 32,\n",
       "  \"num_key_value_heads\": 32,\n",
       "  \"pretraining_tp\": 1,\n",
       "  \"rms_norm_eps\": 1e-05,\n",
       "  \"rope_scaling\": null,\n",
       "  \"rope_theta\": 10000.0,\n",
       "  \"tie_word_embeddings\": false,\n",
       "  \"torch_dtype\": \"float16\",\n",
       "  \"transformers_version\": \"4.39.1\",\n",
       "  \"use_cache\": true,\n",
       "  \"vocab_size\": 32000\n",
       "}"
      ]
     },
     "execution_count": 55,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.config"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "LAYER_NUM = model.config.num_hidden_layers\n",
    "HEAD_NUM = model.config.num_attention_heads\n",
    "\n",
    "HEAD_SPAN = model.config.hidden_size // model.config.num_attention_heads\n",
    "\n",
    "# !! CHANGE the following line if you use this code with models implementing Grouped query attention\n",
    "QUERIES_PER_KEY = model.config.num_attention_heads // model.config.num_key_value_heads\n",
    "\n",
    "# LLaMA2 tokens for letters A-F. If you use different tokenizer, change next line approprietely\n",
    "AF_TOKENS = np.array([319, 350, 315, 360, 382, 383]) "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "random.seed(42)\n",
    "np.random.seed(42)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "def permute_options(options, label, seed):\n",
    "    \"\"\"\n",
    "    A pseudo-random shuffle of the answer <<options>> that also returns new correct answer\n",
    "        parameters:\n",
    "                options --- list of options\n",
    "                label   --- correct answer\n",
    "                seed    --- random seed\n",
    "                \n",
    "    Options 'E' and 'F' are special and must remain at their places\n",
    "    \"\"\"\n",
    "    aaa = np.arange(len(options) - 2)\n",
    "    np.random.seed(seed)\n",
    "    np.random.shuffle(aaa)\n",
    "    \n",
    "    new_label = chr(ord('A') + aaa[ord(label) - ord('A')])\n",
    "    new_options = {}\n",
    "    for option in ['A', 'B', 'C', 'D']: \n",
    "        new_options[chr(aaa[ord(option) - ord('A')] + ord('A'))] = options[option]\n",
    "   \n",
    "    new_options['E'] = options['E']  \n",
    "    new_options['F'] = options['F']\n",
    "\n",
    "    return dict(sorted(new_options.items())), new_label\n",
    "\n",
    "def angular_dist(vec_a, vec_b):\n",
    "    # Without normalization it works much faster and yields better results for some unknown reason. We used this variant\n",
    "    return torch.sum(vec_a * vec_b) #/ (vec_a.norm(p=2) * vec_b.norm(p=2))\n",
    "\n",
    "def angular_dist_matrix(mat_a, mat_b):\n",
    "    # Was never properely used\n",
    "    return (mat_a @ mat_b.T) #/ (mat_a.norm(p=2, dim=-1).reshape((mat_a.shape[0], 1)) @ mat_b.norm(p=2, dim=-1).reshape((1, mat_b.shape[0])))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Wrapper(torch.nn.Module):\n",
    "    '''\n",
    "    This is a wrapper over Huggingface LLaMA Transformer that allows access to Query and Key vectors\n",
    "    '''\n",
    "    def __init__(self, model, *args):\n",
    "        super().__init__(*args)\n",
    "        self.selected_out = OrderedDict()\n",
    "\n",
    "        self.pretrained = model\n",
    "        self.fhooks = []\n",
    "\n",
    "        for i in range(LAYER_NUM):\n",
    "            self.fhooks.append(self.pretrained.model.layers[i].self_attn.q_proj\n",
    "                .register_forward_hook(self.forward_hook(\"query_vec_\" + str(i))))\n",
    "            self.fhooks.append(self.pretrained.model.layers[i].self_attn.k_proj\n",
    "                .register_forward_hook(self.forward_hook(\"key_vec_\" + str(i))))\n",
    "        \n",
    "        #    Removed to lower memory consumption and computational time\n",
    "        #    self.fhooks.append(self.pretrained.model.layers[i].self_attn.v_proj\n",
    "        #        .register_forward_hook(self.forward_hook(\"value_vec_\" + str(i))))\n",
    "    \n",
    "    def forward_hook(self, layer_name):\n",
    "        def hook(module, input, output):\n",
    "            self.selected_out[layer_name] = output.cpu()\n",
    "        return hook\n",
    "\n",
    "    def forward(self, x):        \n",
    "        out = self.pretrained(**x)\n",
    "        return out, self.selected_out\n",
    "    \n",
    "newmodel = Wrapper(model)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### A simplified example of usage"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### Main function"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "def do_calc(data, prompt='', samples_range=range(10000), permute=False, return_logits=False, single_head=None):\n",
    "    \"\"\"\n",
    "    Parameters:\n",
    "        data      ------- self-explanatory\n",
    "        prompt      ----- Here go examples in case of the Few-Shot prompting. For Zero-shot leave it empty.\n",
    "        samples_range --- Container with numbers of samples to be considered \n",
    "        permute     ----- Specifies if a permutation of answer options is required\n",
    "        return_logits  -- Set True if you need not only the predicted labels, but also row similarities and baseline logits\n",
    "                                    for every possible option A-F\n",
    "        single_head  ---- None or Tuple(#Layer, #Head) if computations are to be performed on a single head only (at evaluation)\n",
    "    \n",
    "    \"\"\"\n",
    "    true_labels = []\n",
    "    predicted_labels, predicted_raw = [], []\n",
    "    next_token_labels, logits_options = [], []\n",
    "    \n",
    "    if single_head is None:\n",
    "        layers_range = range(LAYER_NUM)\n",
    "        heads_range = range(HEAD_NUM)\n",
    "    else:\n",
    "        if len(single_head) != 2 :\n",
    "            raise Exception('\"Single_head\" if not None must contain single head number --- pair of two integer numbers (#Layer, #Head)') \n",
    "        else:\n",
    "            layers_range = [single_head[0]]\n",
    "            heads_range = [single_head[1]]\n",
    "\n",
    "    for LAYER in range(LAYER_NUM):\n",
    "        predicted_labels.append([])\n",
    "        predicted_raw.append([])\n",
    "\n",
    "        for HEAD in range(HEAD_NUM):\n",
    "            predicted_labels[LAYER].append([])\n",
    "            predicted_raw[LAYER].append([])\n",
    "\n",
    "    for EXMPL in tqdm(samples_range):\n",
    "        \"\"\"\n",
    "        Assembling the prompt from different parts: Examples (if any) + Context + Question + Options + Finisher\n",
    "        \"\"\"    \n",
    "        if 'context' in data[EXMPL].keys():    # Some quesions are given without context\n",
    "            encodinds_context_q = tokenizer(prompt + \"Context: \" + data[EXMPL]['context'] + \"\\nQuestion: \" + \\\n",
    "                                            data[EXMPL]['question'] + \"\\nOptions:\\n\", return_tensors=\"pt\")\n",
    "        else:\n",
    "            encodinds_context_q = tokenizer(prompt + \"Question: \" + data[EXMPL]['question'] + \"\\nOptions:\\n\", \n",
    "                                            return_tensors=\"pt\")\n",
    "            \n",
    "        num_q = encodinds_context_q[\"input_ids\"].shape[-1] - 1\n",
    "\n",
    "        encodings_answ, options_answ = [], []\n",
    "        \n",
    "        \"\"\" \n",
    "        For some experiments we need to permute answer options\n",
    "        \"\"\"\n",
    "        options_raw, answer_raw = data[EXMPL]['choices'], data[EXMPL]['answer']\n",
    "        if permute:\n",
    "            options_raw, answer_raw = permute_options(options_raw, answer_raw, EXMPL)\n",
    "        \n",
    "        for option in ['A', 'B', 'C', 'D', 'E', 'F']:\n",
    "            options_raw[option] = str(options_raw[option])            \n",
    "            encodings_answ.append(tokenizer(option + \". \" + options_raw[option] + \"\\n\", return_tensors=\"pt\"))\n",
    "            if len(options_answ) == 0:\n",
    "                options_answ.append(int(num_q + encodings_answ[-1][\"input_ids\"].shape[-1] - 1))\n",
    "            else:\n",
    "                options_answ.append(int(options_answ[-1] + encodings_answ[-1][\"input_ids\"].shape[-1] - 1))\n",
    "\n",
    "        encodings_answ.append(tokenizer(\"Answer:\", return_tensors=\"pt\"))\n",
    "        inputs = {\n",
    "            \"input_ids\" : torch.cat([encodinds_context_q[\"input_ids\"]] + [x[\"input_ids\"][..., 1:] for x in encodings_answ], 1).to(device),\n",
    "            \"attention_mask\" : torch.cat([encodinds_context_q[\"attention_mask\"]] + [x[\"attention_mask\"][..., 1:] for x in encodings_answ], 1).to(device)\n",
    "        }\n",
    "        \n",
    "        \"\"\"\n",
    "        #  If you change the promt format use this to debug:\n",
    "        \n",
    "        print(inputs)\n",
    "        for i in range(len(options_answ)):\n",
    "            print(inputs[\"input_ids\"][..., options_answ[i]]) # <<<<<< This all must be aligned\n",
    "        print(\"\\n\\n\", inputs[\"input_ids\"][..., -1])          \n",
    "        \"\"\"\n",
    "        with torch.no_grad():\n",
    "            outputs = newmodel(inputs)\n",
    "\n",
    "        true_labels.append(answer_raw)\n",
    "            \n",
    "        for LAYER in layers_range:\n",
    "            for HEAD in heads_range:\n",
    "                predicts = np.zeros(len(options_answ))\n",
    "                for i in range(len(options_answ)):\n",
    "                    predicts[i] = angular_dist(outputs[1][\"query_vec_\" + str(LAYER)][0][-1][HEAD * HEAD_SPAN:(HEAD + 1) * HEAD_SPAN], \n",
    "                                                                                    outputs[1][\"key_vec_\" + str(LAYER)][0][options_answ[i]][(HEAD // QUERIES_PER_KEY)  * HEAD_SPAN:((HEAD // QUERIES_PER_KEY) + 1) * HEAD_SPAN])\n",
    "\n",
    "                predicted_labels[LAYER][HEAD].append(chr(np.argmax(predicts) + ord('A')))\n",
    "                if return_logits:\n",
    "                    predicted_raw[LAYER][HEAD].append(predicts)\n",
    "\n",
    "        \"\"\"\n",
    "        This part is for the baseline prediction --- that is an A-F letter which is the most probable next token after prompt.\n",
    "        \"\"\"\n",
    "        logits = outputs[0].logits.detach().cpu()\n",
    "        logits = logits[:, -1, :]\n",
    "        logits_full = logits.squeeze(0)\n",
    "        ## LLAMA2 tokens for letters 'A'-'F'; Change the next line for your model approprietely\n",
    "        logits_reduced = logits_full[AF_TOKENS].numpy() \n",
    "        \n",
    "        if return_logits:\n",
    "            logits_options.append(logits_reduced)\n",
    "        \n",
    "        next_token_labels.append(chr(ord('A') + np.argmax(logits_reduced)))\n",
    "        \"\"\"\n",
    "        end of baseline part\n",
    "        \"\"\"\n",
    "\n",
    "    if single_head is None:\n",
    "        for LAYER in range(HEAD_NUM):\n",
    "            predicted_labels[LAYER] = np.array(predicted_labels[LAYER])\n",
    "            if return_logits:\n",
    "                predicted_raw[LAYER] = np.array(predicted_raw[LAYER])\n",
    "\n",
    "        predicted_labels = np.stack(predicted_labels)\n",
    "    \n",
    "        if return_logits:\n",
    "            predicted_raw = np.stack(predicted_raw)\n",
    "            logits_options = np.stack(logits_options)\n",
    "\n",
    "            return true_labels, (predicted_labels, predicted_raw), (next_token_labels, logits_options)\n",
    "        else:\n",
    "            return true_labels, predicted_labels, next_token_labels\n",
    "    else:\n",
    "        return np.array(true_labels), np.array(predicted_labels[single_head[0]][single_head[1]]), np.array(next_token_labels)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##### Auxillary"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_OUR_metric(target, predictions, target_permuted, predictions_permuted):\n",
    "    our_metric_value = np.mean((target == predictions) * (target_permuted == predictions_permuted))\n",
    "    return our_metric_value\n",
    "\n",
    "def print_results_train(true_labels, qk_predicts, baseline):\n",
    "    mn = np.zeros((32,32))\n",
    "    for i in range(32):\n",
    "        for j in range(32):\n",
    "            mn[i][j] = np.mean(qk_predicts[i][j] == true_labels)\n",
    "\n",
    "    print(\"QK results\", np.max(mn), \"Best head: \", np.argmax(mn) // 32, np.argmax(mn) % 32)\n",
    "    print(\"Baseline score\", np.mean(np.array(true_labels) == baseline))\n",
    "    \n",
    "    return (np.argmax(mn) // 32, np.argmax(mn) % 32)\n",
    "\n",
    "def print_results_test(true_labels, qk_predicts, baseline):\n",
    "    print(\"QK results: \", np.mean(qk_predicts == true_labels))\n",
    "    print(\"Baseline score: \", np.mean(np.array(true_labels) == baseline))\n",
    "    \n",
    "def log_results_train(dataset, shots, baseline_label, baseline_raw, qk_label, qk_raw, true_label):\n",
    "    dict1 = {\n",
    "        'true_labels' : true_label,\n",
    "        'baseline_predictions' : baseline_label,\n",
    "        'baseline_logits' : baseline_raw,\n",
    "        'qk_predictions' : qk_label,\n",
    "        'qk_raw_scores' : qk_raw\n",
    "    }\n",
    "    with open(STORAGE_DIR + '{}_10k_llama2_{}-shot_train.pckl'.format(dataset, shots), 'wb') as handle:\n",
    "        pickle.dump(dict1, handle, protocol=pickle.HIGHEST_PROTOCOL)\n",
    "    \n",
    "def log_results_test(dataset, shots, baseline_label, qk_label, true_label):\n",
    "    dict1 = {\n",
    "        'true_labels' : true_label,\n",
    "        'baseline_predictions' : baseline_label,\n",
    "        'qk_predictions' : qk_label,\n",
    "    }\n",
    "    with open(STORAGE_DIR + '{}_10k_llama2_{}-shot_test.pckl'.format(dataset, shots), 'wb') as handle:\n",
    "        pickle.dump(dict1, handle, protocol=pickle.HIGHEST_PROTOCOL)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Experiments (example)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "10000\n"
     ]
    }
   ],
   "source": [
    "json_file = 'data/cosmosqa_10k_new.json'\n",
    "\n",
    "with open(json_file) as json_data:\n",
    "    data = json.load(json_data)\n",
    "\n",
    "print(len(data))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### n-shot prompting"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "dict_keys(['Cosmos', 'MMLU', 'Hellaswag', 'Halu Dialogue'])\n",
      "dict_keys(['0-shot', '1-shot', '2-shot', '3-shot', '4-shot', '5-shot'])\n"
     ]
    }
   ],
   "source": [
    "print(demonstrations.keys())\n",
    "print(demonstrations[\"Cosmos\"].keys())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0-shot example\n",
      "\n"
     ]
    }
   ],
   "source": [
    "print(\"0-shot example\")\n",
    "print(demonstrations[\"Cosmos\"][\"0-shot\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2-shot example\n",
      "Context: It 's got character . You know what they say : ' If it ai n't too broke , do n't fix it ' . \" Howard stood up in his pyjamas that smelled of Vince , and looked out of the window . He repeated Vince 's aphorism over and over until his heartbeat settled a little .\n",
      "Question: What was Howard doing when he walked to the window ?\n",
      "Options:\n",
      "A. He was sneezing .\n",
      "B. He was chewing gum .\n",
      "C. He was scratching his lap .\n",
      "D. He was thinking .\n",
      "E. I don't know .\n",
      "F. None of the above .\n",
      "Answer: D\n",
      "Context: Maybe it 's because right now other than feeling like shit , I do n't look pregnant . If some random person looked at me , they would see a slightly chubby me . Not to mention I do n't know if I like tons of attention .\n",
      "Question: Why might someone think the speaker is chubby ?\n",
      "Options:\n",
      "A. They do n't like the way they look .\n",
      "B. They are pregnant .\n",
      "C. They are embarrassed about their appearance .\n",
      "D. They are overweight .\n",
      "E. I don't know .\n",
      "F. None of the above .\n",
      "Answer: B\n",
      "\n"
     ]
    }
   ],
   "source": [
    "print(\"2-shot example\")\n",
    "print(demonstrations[\"Cosmos\"][\"2-shot\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1-shot example\n",
      "Context: It 's got character . You know what they say : ' If it ai n't too broke , do n't fix it ' . \" Howard stood up in his pyjamas that smelled of Vince , and looked out of the window . He repeated Vince 's aphorism over and over until his heartbeat settled a little .\n",
      "Question: What was Howard doing when he walked to the window ?\n",
      "Options:\n",
      "A. He was sneezing .\n",
      "B. He was chewing gum .\n",
      "C. He was scratching his lap .\n",
      "D. He was thinking .\n",
      "E. I don't know .\n",
      "F. None of the above .\n",
      "Answer: D\n",
      "\n"
     ]
    }
   ],
   "source": [
    "promt_dialogue_1 = demonstrations[\"Cosmos\"][\"1-shot\"]\n",
    "print(\"1-shot example\")\n",
    "print(promt_dialogue_1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### A simplified example of our experiments\n",
    "\n",
    "(numbers below will be different because these computations were performed on smaller subsets)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "###### Calibration --- selecting the best heads"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 62,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 100/100 [00:34<00:00,  2.92it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train performace\n",
      "Cosmos 1-shot, no permutes\n",
      "QK results 0.38 Best head:  27 17\n",
      "Baseline score 0.26\n",
      "Cosmos 1-shot, A-D permutation\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 100/100 [00:33<00:00,  2.96it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "QK results 0.38 Best head:  13 0\n",
      "Baseline score 0.23\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "true_labels, (predicted_labels, predicted_raw), (next_token_labels, logits_options) = do_calc(data, promt_dialogue_1, range(500), return_logits=True)\n",
    "\n",
    "print(\"Train performace\")\n",
    "print(\"Cosmos 1-shot, no permutes\")\n",
    "\n",
    "best_head = print_results_train(true_labels, predicted_labels, next_token_labels)\n",
    "#log_results_train('halu-dialogue', '1', next_token_labels, logits_options, predicted_labels, predicted_raw, true_labels)\n",
    "\n",
    "print(\"Cosmos 1-shot, A-D permutation\")\n",
    "true_labels2, (predicted_labels2, predicted_raw2), (next_token_labels2, logits_options2) = do_calc(data, promt_dialogue_1, range(500), return_logits=True, permute=True)\n",
    "best_head2 = print_results_train(true_labels2, predicted_labels2, next_token_labels2)\n",
    "#log_results_train('halu-dialogue_shuffled', '1', next_token_labels2, logits_options2, predicted_labels2, predicted_raw2, true_labels2)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "###### Evaluation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(\"EVal performace\")\n",
    "print(\"Cosmos 1-shot, no permutes\")\n",
    "\n",
    "true_labels, predicted_labels, next_token_labels = do_calc(data, promt_dialogue_1, range(500, 10000), single_head=best_head)\n",
    "print_results_test(true_labels, predicted_labels, next_token_labels)\n",
    "#log_results_test('halu-dialogue', '1', next_token_labels, predicted_labels, true_labels)\n",
    "\n",
    "true_labels2, predicted_labels2, next_token_labels2 = do_calc(data, promt_dialogue_1, range(500, 10000), permute=True, single_head=best_head2)\n",
    "print(\"Cosmos 1-shot, A-D permutation\")\n",
    "print_results_test(true_labels2, predicted_labels2, next_token_labels2)\n",
    "#log_results_test('halu-dialogue_shuffled', '1', next_token_labels2, predicted_labels2, true_labels2)\n",
    "\n",
    "ours = compute_OUR_metric(true_labels, predicted_labels, true_labels2, predicted_labels2)\n",
    "ours_baseline = compute_OUR_metric(true_labels, next_token_labels, true_labels2, next_token_labels2)\n",
    "\n",
    "print(\"Permutation Accuracy:\")\n",
    "print(\"baseline: {:5.3f}%,   QK:  {:5.3f}%\".format(100 * ours_baseline, ours*100))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
