{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8adc6d63",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "\n",
    "id_e2_to_queries = {}  # {\"<e_581>\": [\"<e_1525><r_123>\", \"<e_98><r_158>\", ...],...}\n",
    "ood_e2_to_queries = {}  # {\"<e_581>\": [\"<e_1915><r_25>\", ...], ...}\n",
    "\n",
    "with open('../data/base_configuration.2000.200.7.2/train.json', 'r') as f:\n",
    "    data = json.load(f)\n",
    "\n",
    "# Split the atomic triples into ID and OOD by ID/OOD ratio\n",
    "id_data = data[:38000]         # 2000 * 20 * 0.95\n",
    "ood_data = data[38000:40000]   # 2000 * 20 * 0.05\n",
    "\n",
    "for data in id_data:\n",
    "    e1, r, e2 = data[\"target_text\"].strip('<>').split('><')[:-1]\n",
    "    if f\"<{e2}>\" not in id_e2_to_queries:\n",
    "        id_e2_to_queries[f\"<{e2}>\"] = []\n",
    "    id_e2_to_queries[f\"<{e2}>\"].append(f\"<{e1}><{r}>\")\n",
    "\n",
    "for data in ood_data:\n",
    "    e1, r, e2 = data[\"target_text\"].strip('<>').split('><')[:-1]\n",
    "    if f\"<{e2}>\" not in ood_e2_to_queries:\n",
    "            ood_e2_to_queries[f\"<{e2}>\"] = []\n",
    "    ood_e2_to_queries[f\"<{e2}>\"].append(f\"<{e1}><{r}>\")"
   ]
  },
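  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3f9ac210",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Quick sanity check (illustrative). The parsing above assumes target_text is\n",
    "# \"<e1><r><e2>\" plus one trailing token dropped by [:-1]; the <eos> marker used\n",
    "# here is a hypothetical stand-in, not necessarily the dataset's actual token.\n",
    "sample = \"<e_1525><r_123><e_581><eos>\"  # hypothetical example string\n",
    "assert tuple(sample.strip('<>').split('><')[:-1]) == (\"e_1525\", \"r_123\", \"e_581\")\n",
    "\n",
    "some_e2 = next(iter(ood_e2_to_queries))\n",
    "print(f\"example OOD e2 {some_e2}: {len(ood_e2_to_queries[some_e2])} OOD queries, \"\n",
    "      f\"{len(id_e2_to_queries.get(some_e2, []))} ID queries\")\n",
    "print(f\"ID e2 keys: {len(id_e2_to_queries)}, OOD e2 keys: {len(ood_e2_to_queries)}\")"
   ]
  },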
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "22dec9fb",
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import GPT2LMHeadModel\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "\n",
    "def compute_similarity_metrics(model, tokenizer, target_layer, id_e2_to_queries, ood_e2_to_queries, device):\n",
    "    id_similarities = []  # ID Cohesion\n",
    "    ood_similarities = [] # OOD Alignment\n",
    "\n",
    "    target_index = 1  # r1 position\n",
    "\n",
    "    model.to(device)\n",
    "    model.eval()\n",
    "\n",
    "    # ID-derived e2 in id_e2_to_queries.keys() probably not in ood_e2_to_queries.keys()\n",
    "    for e2 in ood_e2_to_queries.keys():\n",
    "\n",
    "        id_queries = id_e2_to_queries.get(e2, [])\n",
    "        ood_queries = ood_e2_to_queries.get(e2, [])\n",
    "\n",
    "        if len(id_queries) == 0 or len(ood_queries) == 0:\n",
    "            continue\n",
    "\n",
    "        id_hiddens = []\n",
    "        for query in id_queries:\n",
    "            inputs = tokenizer([query], return_tensors=\"pt\", padding=True).to(device)\n",
    "            \n",
    "            with torch.no_grad():\n",
    "                outputs = model(**inputs, output_hidden_states=True)\n",
    "            hidden = outputs.hidden_states[target_layer][0, target_index, :]\n",
    "\n",
    "            id_hiddens.append(hidden)\n",
    "\n",
    "        id_hiddens = torch.stack(id_hiddens)\n",
    "        id_centroid = id_hiddens.mean(dim=0)\n",
    "\n",
    "        id_cos_sims = F.cosine_similarity(id_hiddens, id_centroid.unsqueeze(0), dim=1)\n",
    "        id_avg_sim = id_cos_sims.mean().item()\n",
    "        id_similarities.append(id_avg_sim)\n",
    "\n",
    "        ood_hiddens = []\n",
    "        for query in ood_queries:\n",
    "            inputs = tokenizer([query], return_tensors=\"pt\", padding=True).to(device)\n",
    "            with torch.no_grad():\n",
    "                outputs = model(**inputs, output_hidden_states=True)\n",
    "            hidden = outputs.hidden_states[target_layer][0, target_index, :]\n",
    "            ood_hiddens.append(hidden)\n",
    "\n",
    "        ood_hiddens = torch.stack(ood_hiddens)\n",
    "        ood_cos_sims = F.cosine_similarity(ood_hiddens, id_centroid.unsqueeze(0), dim=1)\n",
    "        ood_avg_sim = ood_cos_sims.mean().item()\n",
    "        ood_similarities.append(ood_avg_sim)\n",
    "\n",
    "    mean_id = sum(id_similarities) / len(id_similarities)\n",
    "    mean_ood = sum(ood_similarities) / len(ood_similarities)\n",
    "    return mean_id, mean_ood\n"
   ]
  },
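  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b7e41d5c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Minimal sketch of the metric itself on synthetic vectors (no model needed):\n",
    "# ID cohesion  = mean cosine similarity of ID hidden states to their own centroid;\n",
    "# OOD alignment = mean cosine similarity of OOD hidden states to that ID centroid.\n",
    "# The shapes below are placeholders, not tied to any particular checkpoint.\n",
    "torch.manual_seed(0)\n",
    "fake_id = torch.randn(8, 768)   # stand-in for 8 ID query hidden states\n",
    "fake_ood = torch.randn(3, 768)  # stand-in for 3 OOD query hidden states\n",
    "centroid = fake_id.mean(dim=0)\n",
    "toy_cohesion = F.cosine_similarity(fake_id, centroid.unsqueeze(0), dim=1).mean().item()\n",
    "toy_alignment = F.cosine_similarity(fake_ood, centroid.unsqueeze(0), dim=1).mean().item()\n",
    "print(f\"toy ID cohesion: {toy_cohesion:.3f}, toy OOD alignment: {toy_alignment:.3f}\")"
   ]
  },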
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5b5b03e3",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import re\n",
    "\n",
    "base_dir = \"/your/checkpoints/directory\"  # replace with your directory\n",
    "checkpoint_prefix = \"checkpoint-\"\n",
    "\n",
    "all_ckpts = [\n",
    "    os.path.join(base_dir, d) for d in os.listdir(base_dir)\n",
    "    if d.startswith(checkpoint_prefix) and os.path.isdir(os.path.join(base_dir, d))\n",
    "]\n",
    "\n",
    "min_step_interval = 2000\n",
    "start_step = 0         \n",
    "\n",
    "# extract step numbers from checkpoint paths\n",
    "ckpt_tuples = []\n",
    "for path in all_ckpts:\n",
    "    match = re.search(r\"checkpoint-(\\d+)\", path)\n",
    "    if match:\n",
    "        step = int(match.group(1))\n",
    "        if step >= start_step:\n",
    "            ckpt_tuples.append((step, path))\n",
    "\n",
    "ckpt_tuples.sort()\n",
    "\n",
    "# sample checkpoints based on the minimum step interval\n",
    "selected_ckpts = []\n",
    "last_step = -min_step_interval\n",
    "for step, path in ckpt_tuples:\n",
    "    if step - last_step >= min_step_interval:\n",
    "        selected_ckpts.append((step, path))\n",
    "        last_step = step\n"
   ]
  },
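  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a90d7f3e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Illustrative check of the interval-based sampling above, on made-up step numbers.\n",
    "# With min_step_interval = 2000, steps [500, 1000, 2500, 3000, 5000] should keep\n",
    "# 500, 2500, and 5000.\n",
    "demo_steps = [500, 1000, 2500, 3000, 5000]\n",
    "kept, last = [], -min_step_interval\n",
    "for step in demo_steps:\n",
    "    if step - last >= min_step_interval:\n",
    "        kept.append(step)\n",
    "        last = step\n",
    "print(kept)  # expected: [500, 2500, 5000]\n",
    "print(f\"selected {len(selected_ckpts)} checkpoints: {[s for s, _ in selected_ckpts]}\")"
   ]
  },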
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "44bc43e8",
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import GPT2Tokenizer\n",
    "from tqdm import tqdm\n",
    "\n",
    "device = \"cuda:0\"\n",
    "\n",
    "target_layer = 5  # determined by cross-query semantic patching\n",
    "\n",
    "# Load the tokenizer\n",
    "tokenizer = GPT2Tokenizer.from_pretrained(selected_ckpts[0][1])\n",
    "tokenizer.padding_side = \"left\"\n",
    "tokenizer.pad_token = tokenizer.eos_token\n",
    "tokenizer.pad_token_id = tokenizer.eos_token_id\n",
    "\n",
    "results = []\n",
    "\n",
    "for label, path in tqdm(selected_ckpts):\n",
    "    print(f\"Processing {label}...\")\n",
    "    model = GPT2LMHeadModel.from_pretrained(path).to(device)\n",
    "    id_sim, ood_sim = compute_similarity_metrics(\n",
    "        model, tokenizer, target_layer,\n",
    "        id_e2_to_queries, ood_e2_to_queries, device\n",
    "    )\n",
    "    results.append((label, (id_sim, ood_sim)))\n",
    "    del model\n",
    "    torch.cuda.empty_cache()\n",
    "\n",
    "print(results)\n"
   ]
  }
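  ,
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6c2e8b1a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional: visualize how ID cohesion and OOD alignment evolve over training.\n",
    "# A minimal matplotlib sketch, assuming `results` holds (step, (id_sim, ood_sim))\n",
    "# tuples as built above.\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "steps = [step for step, _ in results]\n",
    "id_sims = [sims[0] for _, sims in results]\n",
    "ood_sims = [sims[1] for _, sims in results]\n",
    "\n",
    "plt.figure(figsize=(6, 4))\n",
    "plt.plot(steps, id_sims, marker=\"o\", label=\"ID Cohesion\")\n",
    "plt.plot(steps, ood_sims, marker=\"o\", label=\"OOD Alignment\")\n",
    "plt.xlabel(\"Training step\")\n",
    "plt.ylabel(f\"Mean cosine similarity (layer {target_layer})\")\n",
    "plt.legend()\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  }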
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "latent",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.19"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
