{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4db8159c",
   "metadata": {},
   "outputs": [],
   "source": [
    "import ast\n",
    "import time\n",
    "import torch\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import random, inspect\n",
    "from typing import List, Sequence, Union\n",
    "from transformers import AutoTokenizer, AutoConfig, models, AutoModelForSequenceClassification\n",
    "\n",
    "Array = Union[np.ndarray]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c17f375f",
   "metadata": {},
   "outputs": [],
   "source": [
    "def attn_and_indices(model, tokenizer, scenario_text, option_text, max_length=512):\n",
    "    \"\"\"\n",
    "    Returns:\n",
    "      A: (L, H, T, T) attention (np.float32, on CPU)\n",
    "      S_idx: list[int] of scenario token indices\n",
    "      O_idx: list[int] of option token indices\n",
    "    \"\"\"\n",
    "    device = next(model.parameters()).device\n",
    "    enc = tokenizer(\n",
    "        scenario_text, option_text,\n",
    "        return_tensors=\"pt\", truncation=True, max_length=max_length, add_special_tokens=True\n",
    "    )\n",
    "    seq_ids = enc.sequence_ids(0)  # 0=scenario, 1=option, None=special\n",
    "    S_idx = [i for i, sid in enumerate(seq_ids) if sid == 0]\n",
    "    O_idx = [i for i, sid in enumerate(seq_ids) if sid == 1]\n",
    "\n",
    "    enc = {k: v.to(device) for k, v in enc.items()}\n",
    "\n",
    "    model.eval()\n",
    "    with torch.inference_mode():\n",
    "        if isinstance(type(model), models.bart.modeling_bart.BartModel):\n",
    "            out = model.model.encoder(**enc, output_attentions=True)\n",
    "        else:\n",
    "            out = model(**enc, output_attentions=True)\n",
    "\n",
    "    A = torch.stack([a[0] for a in out.attentions], dim=0).detach().cpu().float().numpy()  # (L,H,T,T)\n",
    "    return A, S_idx, O_idx\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c8365d94",
   "metadata": {},
   "outputs": [],
   "source": [
    "def _symmetrize(U: Array) -> Array:\n",
    "    return 0.5 * (U + U.T)\n",
    "\n",
    "def _score_option_metric(U_so: Array, eps: float = 1e-6, normalize: bool = True) -> float:\n",
    "    # Score0 = mean_o  -log( max_s U[s,o] + eps )\n",
    "    if U_so.size == 0:\n",
    "        return 0.0\n",
    "    u_max_per_o = U_so.max(axis=0)              # (o_len,)\n",
    "    score = -np.log(np.clip(u_max_per_o, eps, 1.0)).sum()\n",
    "    if normalize and U_so.shape[1] > 0:\n",
    "        score /= float(U_so.shape[1])\n",
    "    return float(score)\n",
    "\n",
    "def _score_option_rank_invariant(U_so: Array, normalize: bool = True) -> float:\n",
    "    # Rank-only: smaller is better\n",
    "    s_len, o_len = U_so.shape\n",
    "    if o_len == 0:\n",
    "        return 0.0\n",
    "    all_edges = U_so.reshape(-1)       # (s_len*o_len,)\n",
    "    u_star = U_so.max(axis=0)          # (o_len,)\n",
    "    greater_counts = (all_edges[:, None] > u_star[None, :]).sum(axis=0)\n",
    "    ranks = 1 + greater_counts\n",
    "    score = ranks.sum()\n",
    "    if normalize:\n",
    "        score /= float(o_len)\n",
    "    return float(score)\n",
    "\n",
    "def compute_scores_for_case_indices(\n",
    "    attn_case: List[Array],\n",
    "    S_indices_list: List[Sequence[int]],\n",
    "    O_indices_list: List[Sequence[int]],\n",
    "    use_rank_invariant: bool = False,\n",
    "    eps: float = 1e-6,\n",
    "    normalize: bool = True,\n",
    ") -> Array:\n",
    "    \"\"\"\n",
    "    attn_case: [ (L,H,T,T) ] * num_options\n",
    "    S_indices_list / O_indices_list: list per option of token indices for scenario/option\n",
    "    returns: scores (num_options, L, H)  (lower is better)\n",
    "    \"\"\"\n",
    "    num_options = len(attn_case)\n",
    "    assert num_options == len(S_indices_list) == len(O_indices_list)\n",
    "    L, H = attn_case[0].shape[:2]\n",
    "    scores = np.zeros((num_options, L, H), dtype=np.float32)\n",
    "\n",
    "    for i, A_lh in enumerate(attn_case):\n",
    "        S_idx = np.array(S_indices_list[i], dtype=np.int64)\n",
    "        O_idx = np.array(O_indices_list[i], dtype=np.int64)\n",
    "        for l in range(L):\n",
    "            for h in range(H):\n",
    "                U = _symmetrize(attn_case[i][l, h])            # (T,T)\n",
    "               \n",
    "                U_so = U[np.ix_(S_idx, O_idx)]                 # (|S|, |O|)\n",
    "                if use_rank_invariant:\n",
    "                    s = _score_option_rank_invariant(U_so, normalize)\n",
    "                else:\n",
    "                    s = _score_option_metric(U_so, eps, normalize)\n",
    "                scores[i, l, h] = s\n",
    "    return scores\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1fb67047",
   "metadata": {},
   "outputs": [],
   "source": [
    "def flatten_and_best(acc_map: np.ndarray):\n",
    "    assert acc_map.ndim == 2, \"acc_map should be (L, H)\"\n",
    "    L, H = acc_map.shape\n",
    "    acc_flat = acc_map.reshape(-1)             # (L*H,)\n",
    "    best_idx = int(acc_flat.argmax())          # 0..L*H-1\n",
    "    best_layer = best_idx // H                 \n",
    "    best_head  = best_idx %  H                 \n",
    "    best_acc   = float(acc_flat[best_idx])\n",
    "    return acc_flat, best_layer, best_head, best_acc\n",
    "\n",
    "def topk_layers_heads(acc_map: np.ndarray, k: int = 5):\n",
    "    L, H = acc_map.shape\n",
    "    acc_flat = acc_map.reshape(-1)\n",
    "    order = np.argsort(-acc_flat)[:k]\n",
    "    return [(idx // H, idx % H, float(acc_flat[idx])) for idx in order]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ca6aedb3",
   "metadata": {},
   "outputs": [],
   "source": [
    "def topk_pool(U_so, k=3):\n",
    "    s_len, o_len = U_so.shape\n",
    "    k = min(k, s_len)\n",
    "    idx = np.argpartition(-U_so, kth=k-1, axis=0)[:k, :]\n",
    "    vals = U_so[idx, np.arange(o_len)]\n",
    "    return vals.mean(axis=0)\n",
    "\n",
    "def softmax_pool(U_so, tau=15.0):\n",
    "    W = np.exp(tau * (U_so - U_so.max(axis=0, keepdims=True)))\n",
    "    W /= (W.sum(axis=0, keepdims=True) + 1e-12)\n",
    "    return (W * U_so).sum(axis=0)  # (o_len,)\n",
    "\n",
    "def katz_flow(U, beta=0.6, K=3):\n",
    "    T = U.shape[0]\n",
    "    out = np.zeros_like(U)\n",
    "    P = U.copy()\n",
    "    for k in range(1, K+1):\n",
    "        out += (beta**(k-1)) * P\n",
    "        P = P @ U\n",
    "        colsum = P.sum(axis=0, keepdims=True) + 1e-12\n",
    "        P = P / colsum\n",
    "    colsum = out.sum(axis=0, keepdims=True) + 1e-12\n",
    "    return out / colsum"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8b962b01",
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_scores_for_case_indices(\n",
    "    attn_case: List[Array],\n",
    "    S_indices_list: List[Sequence[int]],\n",
    "    O_indices_list: List[Sequence[int]],\n",
    "    use_rank_invariant: bool = False,\n",
    "    eps: float = 1e-6,\n",
    "    normalize: bool = True,\n",
    "    agg_mode: str = \"max\",          # {\"max\",\"topk\",\"softmax\"}\n",
    "    topk_k: int = 3,\n",
    "    softmax_tau: float = 15.0,\n",
    "    use_multihop: bool = False,\n",
    "    multihop_beta: float = 0.6,\n",
    "    multihop_K: int = 3,\n",
    "    blend_multihop: float = 0.7     # lam in [0,1]; U_blend = lam*U + (1-lam)*U_multihop\n",
    ") -> Array:\n",
    "    \"\"\"\n",
    "    attn_case: [ (L,H,T,T) ] * num_options\n",
    "    S_indices_list / O_indices_list: per-option token indices\n",
    "    returns: scores (num_options, L, H)  (lower is better)\n",
    "    \"\"\"\n",
    "    num_options = len(attn_case)\n",
    "    assert num_options == len(S_indices_list) == len(O_indices_list)\n",
    "    L, H = attn_case[0].shape[:2]\n",
    "    scores = np.zeros((num_options, L, H), dtype=np.float32)\n",
    "\n",
    "    for i, A_lh in enumerate(attn_case):\n",
    "        S_idx = np.array(S_indices_list[i], dtype=np.int64)\n",
    "        O_idx = np.array(O_indices_list[i], dtype=np.int64)\n",
    "\n",
    "        for l in range(L):\n",
    "            for h in range(H):\n",
    "                U = _symmetrize(A_lh[l, h])                 # (T,T)\n",
    "                if use_multihop:\n",
    "                    Uk = katz_flow(U, beta=multihop_beta, K=multihop_K)\n",
    "                    lam = float(blend_multihop)\n",
    "                    U = lam * U + (1.0 - lam) * Uk\n",
    "                U_so = U[np.ix_(S_idx, O_idx)]              # (|S|, |O|)\n",
    "\n",
    "                if use_rank_invariant:\n",
    "                    s = _score_option_rank_invariant(U_so, normalize)\n",
    "                else:\n",
    "                    if agg_mode == \"max\":\n",
    "                        s = _score_option_metric(U_so, eps, normalize)\n",
    "                    elif agg_mode == \"topk\":\n",
    "                        u = topk_pool(U_so, k=topk_k)\n",
    "                        s = -np.log(np.clip(u, eps, 1.0)).sum()\n",
    "                        if normalize and U_so.shape[1] > 0:\n",
    "                            s /= float(U_so.shape[1])\n",
    "                    elif agg_mode == \"softmax\":\n",
    "                        u = softmax_pool(U_so, tau=softmax_tau)\n",
    "                        s = -np.log(np.clip(u, eps, 1.0)).sum()\n",
    "                        if normalize and U_so.shape[1] > 0:\n",
    "                            s /= float(U_so.shape[1])\n",
    "                    else:\n",
    "                        raise ValueError(\"agg_mode must be one of {'max','topk','softmax'}\")\n",
    "\n",
    "                scores[i, l, h] = s\n",
    "\n",
    "    return scores"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "65be511b",
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_all_indices_multi_hop(\n",
    "    cases_attn: List[List[Array]],\n",
    "    S_indices_cases: List[List[Sequence[int]]],\n",
    "    O_indices_cases: List[List[Sequence[int]]],\n",
    "    gold_labels: Sequence[int],\n",
    "    use_rank_invariant: bool = False,\n",
    "    agg_mode: str = 'topk',\n",
    "    topk_k: int = 3, \n",
    "    multihop_K: int = 3,\n",
    "    blend_multihop: float = 0.7,\n",
    "    multihop_beta: float = 0.6,\n",
    "    eps: float = 1e-6,\n",
    "    normalize: bool = True,\n",
    "    use_multihop: bool = True,\n",
    "):\n",
    "    \"\"\"\n",
    "    Parameters mirror the per-case lists:\n",
    "      cases_attn[c][i]          -> (L,H,T,T) for case c, option i\n",
    "      S_indices_cases[c][i]     -> list[int] scenario token indices\n",
    "      O_indices_cases[c][i]     -> list[int] option token indices\n",
    "    Returns:\n",
    "      all_scores: (C, O, L, H)\n",
    "      full_matrix: (C, L, H)  # 1 if argmin(option) == gold, else 0\n",
    "      acc_map: (L, H)         # mean over cases\n",
    "    \"\"\"\n",
    "    C = len(cases_attn)\n",
    "    O = len(cases_attn[0])\n",
    "    L, H = cases_attn[0][0].shape[:2]\n",
    "\n",
    "    all_scores = np.zeros((C, O, L, H), dtype=np.float32)\n",
    "    full_matrix = np.zeros((C, L, H), dtype=np.float32)\n",
    "\n",
    "    for c in range(C):\n",
    "        scores_case = compute_scores_for_case_indices(\n",
    "            cases_attn[c],\n",
    "            S_indices_cases[c],\n",
    "            O_indices_cases[c],\n",
    "            use_rank_invariant=use_rank_invariant,\n",
    "            agg_mode=agg_mode, topk_k=topk_k,\n",
    "            use_multihop=use_multihop, multihop_beta=multihop_beta, multihop_K=multihop_K, blend_multihop=blend_multihop,\n",
    "        )  # (O,L,H)\n",
    "        all_scores[c] = scores_case\n",
    "        winner = scores_case.argmin(axis=0)              # (L,H)\n",
    "        full_matrix[c] = (winner == int(gold_labels[c])).astype(np.float32)\n",
    "\n",
    "    acc_map = full_matrix.mean(axis=0)                   # (L,H)\n",
    "    return all_scores, full_matrix, acc_map\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "814710e9",
   "metadata": {},
   "outputs": [],
   "source": [
    "def _filter_kwargs(func, kwargs):\n",
    "    sig = inspect.signature(func)\n",
    "    return {k: v for k, v in kwargs.items() if k in sig.parameters}\n",
    "\n",
    "def _evaluate_cfg(\n",
    "    cfg,\n",
    "    cases_attn, S_indices_cases, O_indices_cases, gold_labels,\n",
    "    budget=\"coarse\",   # {\"coarse\",\"mid\",\"final\"}\n",
    "):\n",
    "    budget_kwargs = {\n",
    "        \"coarse\": {\"sample_ratio\": 0.3, \"head_stride\": 2},\n",
    "        \"mid\":    {\"sample_ratio\": 0.6, \"head_stride\": 1},\n",
    "        \"final\":  {\"sample_ratio\": 1.0, \"head_stride\": 1},\n",
    "    }[budget]\n",
    "\n",
    "    call_kwargs = {\n",
    "        \"agg_mode\": \"topk\",\n",
    "\n",
    "        \"topk_k\": cfg.get(\"topk_k\", 3),\n",
    "        \"use_multihop\": True,\n",
    "        \"multihop_beta\": cfg.get(\"multihop_beta\", 0.6),\n",
    "        \"multihop_K\": cfg.get(\"multihop_K\", 3),\n",
    "        \"blend_multihop\": cfg.get(\"blend_multihop\", 0.7),\n",
    "    }\n",
    "    call_kwargs.update(budget_kwargs)\n",
    "    call_kwargs = _filter_kwargs(compute_all_indices_multi_hop, call_kwargs)\n",
    "\n",
    "    all_scores, full_matrix, acc_map = compute_all_indices_multi_hop(\n",
    "        cases_attn, S_indices_cases, O_indices_cases, gold_labels, **call_kwargs\n",
    "    )\n",
    "    acc_top1 = float(acc_map.max())\n",
    "\n",
    "    flat = acc_map.reshape(-1)\n",
    "    k = min(5, flat.size)\n",
    "    acc_top5_mean = float(np.sort(flat)[-k:].mean()) if k > 0 else acc_top1\n",
    "\n",
    "    simplicity = -0.01 * max(cfg.get(\"K\", 3) - 3, 0) - 0.005 * max(cfg.get(\"beta\", 0.6) - 0.7, 0)\n",
    "\n",
    "    score = 0.7 * acc_top1 + 0.3 * acc_top5_mean + simplicity\n",
    "    return {\"score\": score, \"acc_top1\": acc_top1, \"acc_top5_mean\": acc_top5_mean, \"acc_map\": acc_map}\n",
    "\n",
    "K_space    = [i+1 for i in range(7)]\n",
    "beta_space = [0.3, 0.4 ,0.5, 0.6, 0.7, 0.8, 0.9]\n",
    "lam_space  = [0.3, 0.4 ,0.5, 0.6, 0.7, 0.8, 0.9]\n",
    "topk_space = [i+1 for i in range(7)]\n",
    "\n",
    "def _rand_cfg():\n",
    "    return {\n",
    "        \"multihop_K\": random.choice(K_space),\n",
    "        \"multihop_beta\": random.choice(beta_space),\n",
    "        \"blend_multihop\": random.choice(lam_space),\n",
    "        \"topk_k\": random.choice(topk_space),\n",
    "    }\n",
    "\n",
    "def tune_hparams(\n",
    "    cases_attn, S_indices_cases, O_indices_cases, gold_labels,\n",
    "    n_coarse=60, n_mid_keep=12, n_final_keep=5, seed=42\n",
    "):\n",
    "    random.seed(seed)\n",
    "    coarse = []\n",
    "    for _ in range(n_coarse):\n",
    "        cfg = _rand_cfg()\n",
    "        if cfg[\"multihop_K\"] >= 4 and cfg[\"multihop_beta\"] >= 0.8:\n",
    "            continue\n",
    "        res = _evaluate_cfg(cfg, cases_attn, S_indices_cases, O_indices_cases, gold_labels, budget=\"coarse\")\n",
    "        coarse.append((res[\"score\"], cfg, res))\n",
    "    coarse.sort(key=lambda x: x[0], reverse=True)\n",
    "    survivors = coarse[:max(3, min(n_mid_keep, len(coarse)))]\n",
    "\n",
    "    mid = []\n",
    "    for _, cfg, _ in survivors:\n",
    "        res = _evaluate_cfg(cfg, cases_attn, S_indices_cases, O_indices_cases, gold_labels, budget=\"mid\")\n",
    "        mid.append((res[\"score\"], cfg, res))\n",
    "    mid.sort(key=lambda x: x[0], reverse=True)\n",
    "    finalists = mid[:max(3, min(n_final_keep, len(mid)))]\n",
    "\n",
    "    final = []\n",
    "    for _, cfg, _ in finalists:\n",
    "        res = _evaluate_cfg(cfg, cases_attn, S_indices_cases, O_indices_cases, gold_labels, budget=\"final\")\n",
    "        final.append((res[\"score\"], cfg, res))\n",
    "    final.sort(key=lambda x: x[0], reverse=True)\n",
    "\n",
    "    best_score, best_cfg, best_res = final[0]\n",
    "    return {\n",
    "        \"best_cfg\": best_cfg,\n",
    "        \"best_metrics\": {\n",
    "            \"acc_top1\": best_res[\"acc_top1\"],\n",
    "            \"acc_top5_mean\": best_res[\"acc_top5_mean\"],\n",
    "            \"score\": best_score,\n",
    "        },\n",
    "        \"finalists\": final,\n",
    "    }"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8020d5ad",
   "metadata": {},
   "outputs": [],
   "source": [
    "MAX_LEN = 256\n",
    "model_path = tokenizer_path = \"roberta-large-mnli\" # microsoft/deberta-v3-large roberta-large-mnli\n",
    "\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "\n",
    "config = AutoConfig.from_pretrained(model_path, output_hidden_states=False, output_attentions=True)\n",
    "model = AutoModelForSequenceClassification.from_pretrained(model_path, config=config).to(device)\n",
    "tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, do_lower_case=False)\n",
    "\n",
    "num_layers = getattr(config, \"num_hidden_layers\", None)\n",
    "num_heads  = getattr(config, \"num_attention_heads\", None)\n",
    "\n",
    "pad_id = None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "941c3336",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load dataset\n",
    "idx = 0\n",
    "dataset_list = ['Irony', 'Metaphor']\n",
    "\n",
    "df = pd.read_csv(f'./data/pragmatic/{dataset_list[idx]}.csv')\n",
    "df['options'] = [ast.literal_eval(data) for data in df['options']]\n",
    "df['answer_keys'] = [ast.literal_eval(data) for data in df['answer_keys']]  # Rename to avoid conflict\n",
    "df['answer'] = [data.index('correct') for data in df['answer_keys']]\n",
    "\n",
    "print(f\"Dataset loaded: {len(df)} rows\")\n",
    "df.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c4848c8f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load dataset\n",
    "idx = 0\n",
    "dataset_list = ['simile', 'metaphor', 'idiom']\n",
    "\n",
    "df = pd.read_csv(f'./data/pragmatic/{dataset_list[idx]}_test_for_new_method.csv')\n",
    "df['options'] = [ast.literal_eval(data) for data in df['options']]\n",
    "\n",
    "print(f\"Dataset loaded: {len(df)} rows\")\n",
    "df.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c91d1a82",
   "metadata": {},
   "outputs": [],
   "source": [
    "cases_attn = []\n",
    "S_indices_cases = []\n",
    "O_indices_cases = []\n",
    "gold_labels = []\n",
    "\n",
    "for c in df.itertuples():\n",
    "    S = c.scenarios\n",
    "    opts = c.options\n",
    "    attn_case, S_idxs, O_idxs = [], [], []\n",
    "    for opt in opts:\n",
    "        A, S_idx, O_idx = attn_and_indices(model, tokenizer, S, opt)\n",
    "        attn_case.append(A) ; S_idxs.append(S_idx) ; O_idxs.append(O_idx)\n",
    "    cases_attn.append(attn_case)\n",
    "    S_indices_cases.append(S_idxs)\n",
    "    O_indices_cases.append(O_idxs)\n",
    "    gold_labels.append(c.answer)\n",
    "    \n",
    "result = tune_hparams(cases_attn, S_indices_cases, O_indices_cases, gold_labels)\n",
    "print(\"BEST:\", result[\"best_cfg\"], result[\"best_metrics\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ea46426c",
   "metadata": {},
   "outputs": [],
   "source": [
    "multihop_K = result[\"best_cfg\"]['multihop_K']\n",
    "multihop_beta = result[\"best_cfg\"]['multihop_beta']\n",
    "blend_multihop = result[\"best_cfg\"]['blend_multihop']\n",
    "topk_k = result[\"best_cfg\"]['topk_k']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d929992f",
   "metadata": {},
   "outputs": [],
   "source": [
    "start = time.time()\n",
    "all_scores, full_matrix, acc_map = compute_all_indices_multi_hop(\n",
    "    cases_attn, S_indices_cases, O_indices_cases, gold_labels,\n",
    "    use_rank_invariant=False, eps=1e-6, normalize=True, agg_mode='topk', # for ablation study, change agg_mode = {topk, max, softmax}\n",
    "    multihop_K= multihop_K, multihop_beta = multihop_beta, blend_multihop = blend_multihop, topk_k = topk_k, use_multihop=True) # for ablation study, use_multihop = True/False\n",
    "end = time.time()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9a52d0f6",
   "metadata": {},
   "outputs": [],
   "source": [
    "acc_flat, best_L, best_H, best_score = flatten_and_best(acc_map)\n",
    "print(f\"[Default] Best layer={best_L}, head={best_H}, acc={best_score:.4f}\")\n",
    "print(f'Latency : {end - start}')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b1505609",
   "metadata": {},
   "source": [
    "# Total Experiment"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6a06c609",
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset_list = ['simile', 'metaphor', 'idiom', 'Irony', 'Metaphor']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e3c73bc4",
   "metadata": {},
   "outputs": [],
   "source": [
    "for didx, dataset in enumerate(dataset_list):\n",
    "    print('Ablation Study')\n",
    "    print(dataset)\n",
    "\n",
    "    if dataset in ['Irony', 'Metaphor']:\n",
    "        df = pd.read_csv(f'./data/pragmatic/{dataset}.csv')\n",
    "        df['options'] = [ast.literal_eval(data) for data in df['options']]\n",
    "        df['answer_keys'] = [ast.literal_eval(data) for data in df['answer_keys']]  # Rename to avoid conflict\n",
    "        df['answer'] = [data.index('correct') for data in df['answer_keys']]\n",
    "\n",
    "    else:\n",
    "        df = pd.read_csv(f'./data/pragmatic/{dataset}_test_for_new_method.csv')\n",
    "        df['options'] = [ast.literal_eval(data) for data in df['options']]\n",
    "\n",
    "    cases_attn = []\n",
    "    S_indices_cases = []\n",
    "    O_indices_cases = []\n",
    "    gold_labels = []\n",
    "\n",
    "    for c in df.itertuples():\n",
    "        S = c.scenarios\n",
    "        opts = c.options\n",
    "        attn_case, S_idxs, O_idxs = [], [], []\n",
    "        for opt in opts:\n",
    "            A, S_idx, O_idx = attn_and_indices(model, tokenizer, S, opt)\n",
    "            attn_case.append(A) ; S_idxs.append(S_idx) ; O_idxs.append(O_idx)\n",
    "        cases_attn.append(attn_case)\n",
    "        S_indices_cases.append(S_idxs)\n",
    "        O_indices_cases.append(O_idxs)\n",
    "        gold_labels.append(c.answer)\n",
    "    \n",
    "    result = tune_hparams(cases_attn, S_indices_cases, O_indices_cases, gold_labels)\n",
    "    multihop_K = result[\"best_cfg\"]['multihop_K']\n",
    "    multihop_beta = result[\"best_cfg\"]['multihop_beta']\n",
    "    blend_multihop = result[\"best_cfg\"]['blend_multihop']\n",
    "    topk_k = result[\"best_cfg\"]['topk_k']\n",
    "\n",
    "    start = time.time()\n",
    "    all_scores, full_matrix, acc_map = compute_all_indices_multi_hop(\n",
    "        cases_attn, S_indices_cases, O_indices_cases, gold_labels,\n",
    "        use_rank_invariant=False, eps=1e-6, normalize=True, agg_mode='topk',\n",
    "        multihop_K= multihop_K, multihop_beta = multihop_beta, blend_multihop = blend_multihop, topk_k = topk_k, use_multihop=True)\n",
    "    end = time.time()\n",
    "    acc_flat, best_L, best_H, best_score = flatten_and_best(acc_map)\n",
    "    print(f\"[Default] Best layer={best_L}, head={best_H}, acc={best_score:.4f}\")\n",
    "    for l, h, s in topk_layers_heads(acc_map, k=4):\n",
    "        print(f\"layer={l}, head={h}, acc={s:.3f}\")\n",
    "        if l==best_L and h==best_H:\n",
    "            continue\n",
    "    print(f'Latency : {end - start}')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "topo",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
