{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3e5ae1a2",
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import GPT2Config, GPT2LMHeadModel, GPT2Tokenizer\n",
    "import torch\n",
    "\n",
    "# Checkpoint directories for the three training phases.\n",
    "# NOTE(review): these look like placeholders - point them at real checkpoints before running.\n",
    "model_path1 = \"phase_I_model_path\"\n",
    "model_path2 = \"phase_II_model_path\"\n",
    "model_path3 = \"phase_III_model_path\"\n",
    "\n",
    "# Use the first GPU when CUDA is available; otherwise fall back to CPU so the\n",
    "# notebook still runs on machines without a GPU.\n",
    "device = \"cuda:0\" if torch.cuda.is_available() else \"cpu\"\n",
    "\n",
    "# One tokenizer (loaded from the phase I path) is shared by all three models.\n",
    "tokenizer = GPT2Tokenizer.from_pretrained(model_path1)\n",
    "tokenizer.padding_side = \"left\"\n",
    "# Reuse the EOS token as the padding token.\n",
    "tokenizer.pad_token = tokenizer.eos_token\n",
    "tokenizer.pad_token_id = tokenizer.eos_token_id\n",
    "\n",
    "model1 = GPT2LMHeadModel.from_pretrained(model_path1).to(device)\n",
    "model2 = GPT2LMHeadModel.from_pretrained(model_path2).to(device)\n",
    "model3 = GPT2LMHeadModel.from_pretrained(model_path3).to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7ab9f3d1",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "with open('../data/base_configuration.2000.200.7.2/test.json', 'r') as f:\n",
    "    datas = json.load(f)\n",
    "\n",
    "# Atomic-fact probes, stored as (query_text, answer_token) pairs and split by\n",
    "# the record's \"type\" field.\n",
    "id_queries = []  # [(query, e2), ...]\n",
    "ood_queries = []\n",
    "\n",
    "for record in datas:\n",
    "    kind = record[\"type\"]\n",
    "    if kind == \"ood_atomic\":\n",
    "        bucket = ood_queries\n",
    "    elif kind == \"id_atomic\":\n",
    "        bucket = id_queries\n",
    "    else:\n",
    "        # Every other record type is ignored.\n",
    "        continue\n",
    "\n",
    "    # Strip the outer angle brackets, split on '><', and drop the final piece;\n",
    "    # exactly three pieces must remain: e1, r, e2.\n",
    "    pieces = record[\"target_text\"].strip('<>').split('><')[:-1]\n",
    "    e1, r, e2 = pieces\n",
    "    bucket.append((record[\"input_text\"], f\"<{e2}>\"))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d57b3d5b",
   "metadata": {},
   "source": [
    "### Immediate probing"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b5c3d760",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch.nn.functional as F\n",
    "from tqdm import tqdm\n",
    "\n",
    "# TODO: change the model and query type\n",
    "model = model3\n",
    "queries = ood_queries\n",
    "\n",
    "# Output-projection weights used to map an intermediate hidden state onto the\n",
    "# vocabulary. They do not change while probing, so they are read once here\n",
    "# (as .data, so the matmul below does not build an autograd graph).\n",
    "word_embedding = model.lm_head.weight.data\n",
    "model.config.pad_token_id = model.config.eos_token_id\n",
    "\n",
    "correct_cnt = 0\n",
    "\n",
    "# layer 5 at r1 position\n",
    "target_token_index = 1\n",
    "target_layer = 5\n",
    "\n",
    "for query, target in tqdm(queries):\n",
    "    encoded = tokenizer([query], return_tensors=\"pt\", padding=True)\n",
    "    decoder_input_ids = encoded[\"input_ids\"].to(device)\n",
    "    decoder_attention_mask = encoded[\"attention_mask\"].to(device)\n",
    "\n",
    "    with torch.no_grad():\n",
    "        outputs = model(\n",
    "            input_ids=decoder_input_ids,\n",
    "            attention_mask=decoder_attention_mask,\n",
    "            output_hidden_states=True\n",
    "        )\n",
    "\n",
    "    all_hidden_states = outputs.hidden_states\n",
    "\n",
    "    # decode: project the chosen hidden state directly through the LM head weights\n",
    "    target_hidden_state = all_hidden_states[target_layer][:, target_token_index, :]  # (batch_size, hidden_dim)\n",
    "    logits = torch.matmul(target_hidden_state, word_embedding.T)  # (batch_size, vocab_size)\n",
    "    next_token = torch.argmax(logits, dim=-1)  # (batch_size,)\n",
    "\n",
    "    # check\n",
    "    if tokenizer.decode(next_token.item()) == target:\n",
    "        correct_cnt += 1\n",
    "\n",
    "# Divide by the number of queries actually probed (not len(id_queries)), so the\n",
    "# accuracy is correct whichever query set is selected above.\n",
    "accuracy = correct_cnt / len(queries) if queries else float(\"nan\")\n",
    "print(accuracy)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8a252539",
   "metadata": {},
   "source": [
    "### Full-run probing"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1f7fbf0f",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch.nn.functional as F\n",
    "from tqdm import tqdm\n",
    "\n",
    "# TODO: change the model and query type\n",
    "model = model3\n",
    "queries = id_queries\n",
    "\n",
    "model.config.pad_token_id = model.config.eos_token_id\n",
    "\n",
    "# Each entry of `queries` is a (query_text, answer) pair. Only the text of the\n",
    "# last entry is used as the prompt that receives the patched hidden state;\n",
    "# passing the whole tuple to the tokenizer would encode it as a text pair and\n",
    "# append the answer to the prompt.\n",
    "target_query, _ = queries[-1]  # random or fixed\n",
    "source_queries = queries[:-1]\n",
    "\n",
    "correct_cnt = 0\n",
    "\n",
    "# layer 5 at r1 position\n",
    "target_token_index = 1\n",
    "target_layer = 5\n",
    "\n",
    "# The target prompt is identical for every source query, so tokenize it once.\n",
    "target_temp = tokenizer([target_query], return_tensors=\"pt\", padding=True)\n",
    "target_decoder_input_ids = target_temp[\"input_ids\"].to(device)\n",
    "target_decoder_attention_mask = target_temp[\"attention_mask\"].to(device)\n",
    "\n",
    "\n",
    "def make_patch_hook(source_hidden):\n",
    "    \"\"\"Return a forward hook that overwrites one position of a block's output.\n",
    "\n",
    "    The hook copies the first element of the block's output tuple, replaces the\n",
    "    vector at [0, target_token_index] with `source_hidden`, and returns the\n",
    "    tuple with the remaining elements untouched.\n",
    "    \"\"\"\n",
    "    def hook_fn(module, inputs, output):\n",
    "        main_output = output[0].clone()\n",
    "        main_output[0, target_token_index, :] = source_hidden\n",
    "        return (main_output,) + output[1:]\n",
    "\n",
    "    return hook_fn\n",
    "\n",
    "\n",
    "for query, target in tqdm(source_queries):\n",
    "\n",
    "    # 1) Run the source query and record its hidden states.\n",
    "    decoder_temp = tokenizer([query], return_tensors=\"pt\", padding=True)\n",
    "    decoder_input_ids = decoder_temp[\"input_ids\"].to(device)\n",
    "    decoder_attention_mask = decoder_temp[\"attention_mask\"].to(device)\n",
    "\n",
    "    with torch.no_grad():\n",
    "        outputs1 = model(\n",
    "            input_ids=decoder_input_ids,\n",
    "            attention_mask=decoder_attention_mask,\n",
    "            output_hidden_states=True\n",
    "        )\n",
    "\n",
    "    hidden_states_batch = outputs1.hidden_states  # [1+num_layers, batch_size, seq_len, hidden_size]\n",
    "    source_hidden = hidden_states_batch[target_layer][0, target_token_index, :]\n",
    "\n",
    "    # 2) Re-run the target prompt with that hidden state patched in.\n",
    "    # NOTE(review): per the Hugging Face docs, hidden_states[0] is the embedding\n",
    "    # output and hidden_states[k] the output of block k-1, whereas this hook\n",
    "    # overwrites the output of block index `target_layer`. Confirm this\n",
    "    # alignment is intended (hooking h[target_layer - 1] would patch like-for-like).\n",
    "    handle = model.transformer.h[target_layer].register_forward_hook(make_patch_hook(source_hidden))\n",
    "    try:\n",
    "        with torch.no_grad():\n",
    "            outputs = model(\n",
    "                input_ids=target_decoder_input_ids,\n",
    "                attention_mask=target_decoder_attention_mask,\n",
    "            )\n",
    "    finally:\n",
    "        # Always detach the hook, even if the forward pass raises; otherwise it\n",
    "        # would stay registered on the model for every later run.\n",
    "        handle.remove()\n",
    "\n",
    "    # decode\n",
    "    logits = outputs.logits  # [batch_size, seq_len, vocab_size]\n",
    "    predicted_token_ids = torch.argmax(logits, dim=-1)  # [batch_size, seq_len]\n",
    "    decoded_text = tokenizer.batch_decode(predicted_token_ids, skip_special_tokens=True)\n",
    "    decoded_token = decoded_text[0].split()[target_token_index]\n",
    "\n",
    "    # check\n",
    "    if decoded_token == target:\n",
    "        correct_cnt += 1\n",
    "\n",
    "# Accuracy over the source queries that were actually probed.\n",
    "accuracy = correct_cnt / len(source_queries) if source_queries else float(\"nan\")\n",
    "print(accuracy)"
   ]
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
