{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "46b9e7ed",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "graphlens imported from: /home/lym/LLM-Research/Attention/Graph_Attention/src/GraphLens/src/graphlens/__init__.py\n",
      "You are an expert chemist, your task is to predict the property of molecule using your experienced chemical property prediction knowledge. \n",
      "Please strictly follow the format, no other information can be provided. Given the SMILES string of a molecule, the task focuses on predicting molecular properties, specifically penetration/non-penetration to the brain-blood barrier, based on the SMILES string representation of each molecule. You will be provided with several examples molecules, each accompanied by a binary label indicating whether it has penetrative property (Yes) or not (No). The task is to predict the binary label for a given molecule, please answer with only Yes or No.\n",
      "\n",
      "SMILES: C1=C(Cl)C=CC2=C1C(=NCC(=O)N2CC(F)(F)F)C3=CC=CC=C3\n",
      "BBBP Penetration:\n",
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "e237dcbb114e4fb6b92b9ba4f94a945a",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "attn shape: (32, 32, 174, 174)\n",
      "saving to: /home/lym/LLM-Research/Attention/Graph_Attention/src/GraphLens/baselines/molecularNet/pp_plots/BBBP_llama3.1-8b/ROI_smiles\n",
      "done.\n"
     ]
    }
   ],
   "source": [
    "# Minimal \"plot.py-like\" visualization for ONE Property Prediction task (BACE/BBBP/ClinTox/HIV/Tox21)\n",
    "# Uses Llama-3.1-8B-Instruct and produces the same 2x2 figures (full attn + ROI + bin + denoised).\n",
    "\n",
    "import os, re, json\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"1,2,3\"\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import torch\n",
    "import matplotlib.pyplot as plt\n",
    "from transformers import AutoTokenizer, AutoModelForCausalLM\n",
    "\n",
    "import sys\n",
    "from pathlib import Path\n",
    "\n",
    "GRAPHLENS_SRC = Path(\"/home/lym/LLM-Research/Attention/Graph_Attention/src/GraphLens/src\")\n",
    "if str(GRAPHLENS_SRC) not in sys.path:\n",
    "    sys.path.insert(0, str(GRAPHLENS_SRC))\n",
    "\n",
    "# quick sanity check\n",
    "import graphlens\n",
    "print(\"graphlens imported from:\", graphlens.__file__)\n",
    "\n",
    "# GraphLens plotting/helpers (same ones used by plot.py)\n",
    "from graphlens.utils import (\n",
    "    build_sawtooth_mask,\n",
    "    extract_roi_from_attn,\n",
    "    preprocess_for_scoring,\n",
    "    normalize_minmax,\n",
    ")\n",
    "from graphlens.viz_utils import (\n",
    "    create_layer_figure,\n",
    "    create_head_figure,\n",
    ")\n",
    "\n",
    "# -----------------------\n",
    "# Config (edit these)\n",
    "# -----------------------\n",
    "# Task whose prompt and CSV are visualized.\n",
    "TASK = \"BBBP\"  # one of: BACE BBBP ClinTox HIV Tox21\n",
    "\n",
    "# Machine-specific locations. The defaults are the original hard-coded paths; set\n",
    "# GRAPHLENS_DATA_DIR / GRAPHLENS_MODEL_PATH in the environment to run on another\n",
    "# machine without editing this cell.\n",
    "DATA_DIR = os.environ.get(\n",
    "    \"GRAPHLENS_DATA_DIR\",\n",
    "    \"/home/lym/LLM-Research/Attention/Graph_Attention/src/GraphLens/baselines/repos/ChemLLMBench/data/property_prediction\",\n",
    ")\n",
    "PROMPT_PATH = os.path.join(DATA_DIR, \"property_prediction_prompt.txt\")\n",
    "\n",
    "MODEL_PATH = os.environ.get(\n",
    "    \"GRAPHLENS_MODEL_PATH\",\n",
    "    \"/home/lym/data1/LLM-model/meta-llama/Meta-Llama-3.1-8B-Instruct\",\n",
    ")\n",
    "# NOTE(review): DEVICE is not referenced again in this cell (the model is placed via device_map).\n",
    "DEVICE = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "\n",
    "# Pick which visualization.\n",
    "# NOTE(review): these three are only referenced by the commented-out plotting code at the end of the cell.\n",
    "PLOT_MODE = \"layer\"  # \"layer\" or \"head\"\n",
    "LAYER_ID = 10\n",
    "HEAD_ID = 5\n",
    "\n",
    "# scoring/denoise params (same defaults as GraphLens)\n",
    "BINARIZE_METHOD = \"topk\"\n",
    "PRE_THRESHOLD_FRAC = 0.1\n",
    "DPI = 220  # resolution of the saved PNGs\n",
    "MAX_SEQ_LEN = 1024  # upper bound on the number of prompt tokens fed to the model\n",
    "\n",
    "# -----------------------\n",
    "# Load base prompt from txt\n",
    "# -----------------------\n",
    "def load_prompt_map(prompt_txt_path: str):\n",
    "    \"\"\"Parse a prompt text file into a ``{task_name: prompt_text}`` dict.\n",
    "\n",
    "    Each entry is expected to look like ``NAME: prompt = <double-quoted text>``, where NAME\n",
    "    consists of letters, digits, ``_`` or ``-`` and the quoted text may span several lines.\n",
    "    In the captured text, literal backslash-n sequences become real newlines, escaped\n",
    "    double quotes are unescaped, and surrounding whitespace is stripped. If a name occurs\n",
    "    more than once, the last entry wins.\n",
    "\n",
    "    Args:\n",
    "        prompt_txt_path: Path to the UTF-8 encoded prompt file.\n",
    "\n",
    "    Returns:\n",
    "        dict mapping task name to prompt text; empty when no entry matches.\n",
    "    \"\"\"\n",
    "    # Context manager so the file handle is closed deterministically instead of leaking.\n",
    "    with open(prompt_txt_path, \"r\", encoding=\"utf-8\") as fh:\n",
    "        txt = fh.read()\n",
    "    pat = re.compile(r\"^\\s*([A-Za-z0-9_-]+)\\s*:\\s*prompt\\s*=\\s*\\\"([\\s\\S]*?)\\\"\\s*$\", re.MULTILINE)\n",
    "    out = {}\n",
    "    for m in pat.finditer(txt):\n",
    "        k = m.group(1).strip()\n",
    "        p = m.group(2).replace(\"\\\\n\", \"\\n\").replace(\"\\\\\\\"\", \"\\\"\").strip()\n",
    "        out[k] = p\n",
    "    return out\n",
    "\n",
    "PROMPTS = load_prompt_map(PROMPT_PATH)\n",
    "# Fail with an actionable message when the prompt file has no entry for TASK\n",
    "# (for example when the file layout changed and the parser matched nothing).\n",
    "if TASK not in PROMPTS:\n",
    "    raise KeyError(f\"No prompt for task {TASK!r} in {PROMPT_PATH}; found: {sorted(PROMPTS)}\")\n",
    "base_prompt = PROMPTS[TASK]\n",
    "\n",
    "# -----------------------\n",
    "# Load one sample (test preferred; otherwise any csv)\n",
    "# -----------------------\n",
    "def resolve_csv(task: str, data_dir=None):\n",
    "    \"\"\"Return the path of the CSV to sample from for ``task``, preferring the test split.\n",
    "\n",
    "    Candidates are ``<task>_<split>.csv`` with the task name as given and lower-cased.\n",
    "    Every test candidate is checked before any train candidate, so a test file is\n",
    "    chosen whenever one exists under either spelling.\n",
    "\n",
    "    Args:\n",
    "        task: Task name, e.g. ``BBBP``.\n",
    "        data_dir: Directory to search; defaults to the module-level ``DATA_DIR``.\n",
    "\n",
    "    Returns:\n",
    "        Path of the first existing candidate.\n",
    "\n",
    "    Raises:\n",
    "        FileNotFoundError: if no candidate file exists.\n",
    "    \"\"\"\n",
    "    root = DATA_DIR if data_dir is None else data_dir\n",
    "    stems = (task, task.lower())\n",
    "    for split in (\"test\", \"train\"):\n",
    "        for stem in stems:\n",
    "            candidate = os.path.join(root, f\"{stem}_{split}.csv\")\n",
    "            if os.path.exists(candidate):\n",
    "                return candidate\n",
    "    raise FileNotFoundError(f\"Cannot find {task}_test.csv or {task}_train.csv under {root}\")\n",
    "\n",
    "csv_path = resolve_csv(TASK)\n",
    "df = pd.read_csv(csv_path)\n",
    "\n",
    "# figure out smiles column for each task (per your headers)\n",
    "SMILES_COL = {\n",
    "    \"BACE\": \"mol\",\n",
    "    \"BBBP\": \"smiles\",\n",
    "    \"ClinTox\": \"smiles\",\n",
    "    \"HIV\": \"smiles\",\n",
    "    \"Tox21\": \"smiles\",\n",
    "}[TASK]\n",
    "\n",
    "# Pick the first row whose SMILES cell is present and non-blank.\n",
    "# Missing cells are excluded explicitly: str(NaN) is \"nan\", which would otherwise pass a\n",
    "# plain non-empty check and end up in the prompt as if it were a molecule.\n",
    "row = None\n",
    "if SMILES_COL in df.columns:\n",
    "    smiles_cells = df[SMILES_COL]\n",
    "    has_smiles = (smiles_cells.notna() & smiles_cells.astype(str).str.strip().ne(\"\")).to_numpy()\n",
    "    if has_smiles.any():\n",
    "        # positional lookup, so the result does not depend on the index labels\n",
    "        row = df.iloc[int(has_smiles.argmax())]\n",
    "assert row is not None, \"No valid SMILES found\"\n",
    "\n",
    "smiles = str(row[SMILES_COL]).strip()\n",
    "\n",
    "# -----------------------\n",
    "# Build a simple prompt (zero-shot)\n",
    "# (We only need a stable \"SMILES: ...\\n<field>:\" line for span detection)\n",
    "# -----------------------\n",
    "def build_query(task: str, smiles: str):\n",
    "    \"\"\"Return the zero-shot query block for ``task``.\n",
    "\n",
    "    The block is a ``SMILES: ...`` line followed by the answer field of the task.\n",
    "    Task names without a dedicated entry fall back to a generic ``Label:`` field.\n",
    "    \"\"\"\n",
    "    queries = {\n",
    "        \"BACE\": f\"SMILES: {smiles}\\nBACE-1 Inhibit:\",\n",
    "        \"BBBP\": f\"SMILES: {smiles}\\nBBBP Penetration:\",\n",
    "        \"ClinTox\": f\"SMILES: {smiles}\\nClinically-trial-toxic:\",\n",
    "        \"HIV\": f\"SMILES: {smiles}\\nHIV Inhibit:\",\n",
    "        # simplest: choose one assay name to visualize prompt structure\n",
    "        \"Tox21\": f\"SMILES: {smiles}\\nAssay: NR-AR\\nToxic:\",\n",
    "    }\n",
    "    return queries.get(task, f\"SMILES: {smiles}\\nLabel:\")\n",
    "\n",
    "# Full prompt: task instructions, one blank line, the query block, trailing newline.\n",
    "query_block = build_query(TASK, smiles)\n",
    "prompt = \"\".join([base_prompt.strip(), \"\\n\\n\", query_block, \"\\n\"])\n",
    "print(prompt)\n",
    "\n",
    "# -----------------------\n",
    "# Find token span that corresponds to the SMILES substring in the prompt\n",
    "# (This replaces GraphWiz edge-span logic.)\n",
    "# -----------------------\n",
    "def get_smiles_token_span(prompt: str, smiles: str, tokenizer):\n",
    "    \"\"\"Map the first occurrence of ``smiles`` in ``prompt`` to an inclusive token index range.\n",
    "\n",
    "    The prompt is tokenized with special tokens and an offset mapping. Tokens whose\n",
    "    offsets are (0, 0) are ignored when searching.\n",
    "\n",
    "    Returns:\n",
    "        ``(t_start, t_end)``: index of the first token that ends after the SMILES start and\n",
    "        index of the last token that starts before the SMILES end.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if ``smiles`` is not a substring of ``prompt`` or no token range is found.\n",
    "    \"\"\"\n",
    "    char_lo = prompt.find(smiles)\n",
    "    if char_lo < 0:\n",
    "        raise ValueError(\"SMILES not found in prompt (unexpected).\")\n",
    "    char_hi = char_lo + len(smiles)\n",
    "\n",
    "    encoded = tokenizer(prompt, return_tensors=\"pt\", return_offsets_mapping=True, add_special_tokens=True)\n",
    "    spans = encoded[\"offset_mapping\"][0].tolist()  # one (start, end) pair per token\n",
    "\n",
    "    # (token index, start, end) for every token that is not a (0, 0) placeholder\n",
    "    real_tokens = [\n",
    "        (idx, tok_lo, tok_hi)\n",
    "        for idx, (tok_lo, tok_hi) in enumerate(spans)\n",
    "        if not (tok_lo == 0 and tok_hi == 0)\n",
    "    ]\n",
    "\n",
    "    t_start = next((idx for idx, _, tok_hi in real_tokens if tok_hi > char_lo), None)\n",
    "    t_end = next((idx for idx, tok_lo, _ in reversed(real_tokens) if tok_lo < char_hi), None)\n",
    "\n",
    "    if t_start is None or t_end is None or t_end < t_start:\n",
    "        raise ValueError(\"Failed to map SMILES char-span to token-span.\")\n",
    "    return t_start, t_end\n",
    "\n",
    "# -----------------------\n",
    "# Load model and run attentions\n",
    "# -----------------------\n",
    "tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, padding_side=\"left\")\n",
    "if tokenizer.pad_token_id is None and tokenizer.eos_token_id is not None:\n",
    "    tokenizer.pad_token_id = tokenizer.eos_token_id\n",
    "\n",
    "model = AutoModelForCausalLM.from_pretrained(\n",
    "    MODEL_PATH,\n",
    "    device_map=\"auto\",\n",
    "    torch_dtype=\"auto\",\n",
    "    # NOTE(review): eager attention is presumably required so that output_attentions=True\n",
    "    # returns the attention weights - confirm against the transformers documentation.\n",
    "    attn_implementation=\"eager\",\n",
    ")\n",
    "model.eval()\n",
    "\n",
    "# Let the tokenizer cap the length at MAX_SEQ_LEN. Assigning truncated tensors to\n",
    "# attributes of the returned encoding would leave enc[\"input_ids\"] / enc[\"attention_mask\"]\n",
    "# at full length, so key access and attribute access could disagree.\n",
    "enc = tokenizer(\n",
    "    prompt,\n",
    "    return_tensors=\"pt\",\n",
    "    truncation=True,\n",
    "    max_length=MAX_SEQ_LEN,\n",
    ").to(model.device)\n",
    "\n",
    "with torch.inference_mode():\n",
    "    out = model(\n",
    "        input_ids=enc[\"input_ids\"],\n",
    "        attention_mask=enc.get(\"attention_mask\"),\n",
    "        output_attentions=True,\n",
    "        use_cache=False,\n",
    "    )\n",
    "\n",
    "# stack attentions: [L, H, S, S] (batch element 0 of every layer, cast to float32 on CPU)\n",
    "attn_np = np.stack([a[0].to(torch.float32).detach().cpu().numpy() for a in out.attentions], axis=0)\n",
    "\n",
    "num_layers, num_heads, S, _ = attn_np.shape\n",
    "print(\"attn shape:\", attn_np.shape)\n",
    "\n",
    "\n",
    "# -----------------------\n",
    "# Build ROI around the SMILES token span\n",
    "# -----------------------\n",
    "t_start, t_end = get_smiles_token_span(prompt, smiles, tokenizer)\n",
    "g_start, g_end = t_start, t_end\n",
    "# The span is computed on the untruncated prompt, while the attention maps cover at most\n",
    "# MAX_SEQ_LEN tokens; refuse to slice an ROI that reaches past the attention matrix.\n",
    "if g_end >= S:\n",
    "    raise ValueError(\n",
    "        f\"SMILES token span [{g_start},{g_end}] exceeds attention size {S}; raise MAX_SEQ_LEN or shorten the prompt.\"\n",
    "    )\n",
    "span_len = g_end - g_start + 1\n",
    "\n",
    "# one local span covering the full SMILES area -> sawtooth mask becomes a lower-triangle\n",
    "local_spans = [(0, span_len - 1)]\n",
    "ideal_mask = build_sawtooth_mask(span_len, local_spans)\n",
    "\n",
    "\n",
    "# ---- Save per-layer figures to local ----\n",
    "# os / numpy / matplotlib are already imported at the top of this cell, so they are not re-imported here.\n",
    "# NOTE: the output directory is relative to the current working directory.\n",
    "SAVE_DIR = f\"./pp_plots/{TASK}_llama3.1-8b/ROI_smiles\"\n",
    "os.makedirs(SAVE_DIR, exist_ok=True)\n",
    "\n",
    "print(\"saving to:\", os.path.abspath(SAVE_DIR))\n",
    "\n",
    "for layer_idx in range(num_layers):\n",
    "    # ROI of every head in this layer, then averaged over heads\n",
    "    head_rois = [\n",
    "        extract_roi_from_attn(attn_np, layer_idx, head_idx, g_start, g_end)\n",
    "        for head_idx in range(num_heads)\n",
    "    ]\n",
    "    avg_roi = np.mean(np.stack(head_rois, axis=0), axis=0)\n",
    "\n",
    "    bin_mask, denoised = preprocess_for_scoring(\n",
    "        avg_roi,\n",
    "        binarize_method=BINARIZE_METHOD,\n",
    "        ideal_mask=ideal_mask,\n",
    "        pre_threshold_frac=PRE_THRESHOLD_FRAC,\n",
    "    )\n",
    "\n",
    "    # head-averaged full attention map; attn_np[layer_idx] is already [H, S, S],\n",
    "    # so no per-head list or extra stacked copy is needed\n",
    "    full_attn_layer = attn_np[layer_idx].mean(axis=0)\n",
    "\n",
    "    title = f\"{TASK} | L{layer_idx} AvgHeads | ROI=SMILES tokens [{g_start},{g_end}]\"\n",
    "    fig = create_layer_figure(\n",
    "        full_attn_layer=full_attn_layer,\n",
    "        avg_roi=avg_roi,\n",
    "        bin_mask=bin_mask,\n",
    "        blurred_norm=denoised,\n",
    "        title=title,\n",
    "        local_spans=local_spans,\n",
    "        g_start=g_start,\n",
    "        g_end=g_end,\n",
    "    )\n",
    "\n",
    "    out_png = os.path.join(SAVE_DIR, f\"{TASK}_L{layer_idx:02d}_avgheads.png\")\n",
    "    try:\n",
    "        fig.savefig(out_png, dpi=DPI, bbox_inches=\"tight\")\n",
    "    finally:\n",
    "        # close even when saving fails, so figures do not pile up in memory\n",
    "        plt.close(fig)\n",
    "\n",
    "print(\"done.\")\n",
    "\n",
    "# # -----------------------\n",
    "# # Plot: \"layer\" mode (average heads) or \"head\" mode (single head)\n",
    "# # -----------------------\n",
    "# if PLOT_MODE == \"layer\":\n",
    "#     l = int(min(max(LAYER_ID, 0), num_layers - 1))\n",
    "\n",
    "#     # avg ROI over heads\n",
    "#     rois = []\n",
    "#     full_heads = []\n",
    "#     for h in range(num_heads):\n",
    "#         roi = extract_roi_from_attn(attn_np, l, h, g_start, g_end)\n",
    "#         rois.append(roi)\n",
    "#         full_heads.append(attn_np[l, h])\n",
    "\n",
    "#     avg_roi = np.mean(np.stack(rois, axis=0), axis=0)\n",
    "#     bin_mask, denoised = preprocess_for_scoring(\n",
    "#         avg_roi,\n",
    "#         binarize_method=BINARIZE_METHOD,\n",
    "#         ideal_mask=ideal_mask,\n",
    "#         pre_threshold_frac=PRE_THRESHOLD_FRAC,\n",
    "#     )\n",
    "\n",
    "#     # full attention layer-avg\n",
    "#     full_attn_layer = np.mean(np.stack(full_heads, axis=0), axis=0)\n",
    "\n",
    "#     title = f\"{TASK} | L{l} AvgHeads | ROI=SMILES tokens [{g_start},{g_end}]\"\n",
    "#     fig = create_layer_figure(\n",
    "#         full_attn_layer=full_attn_layer,\n",
    "#         avg_roi=avg_roi,\n",
    "#         bin_mask=bin_mask,\n",
    "#         blurred_norm=denoised,\n",
    "#         title=title,\n",
    "#         local_spans=local_spans,\n",
    "#         g_start=g_start,\n",
    "#         g_end=g_end,\n",
    "#     )\n",
    "#     fig.set_dpi(DPI)\n",
    "#     plt.show()\n",
    "\n",
    "# elif PLOT_MODE == \"head\":\n",
    "#     l = int(min(max(LAYER_ID, 0), num_layers - 1))\n",
    "#     h = int(min(max(HEAD_ID, 0), num_heads - 1))\n",
    "\n",
    "#     roi = extract_roi_from_attn(attn_np, l, h, g_start, g_end)\n",
    "#     bin_mask, denoised = preprocess_for_scoring(\n",
    "#         roi,\n",
    "#         binarize_method=BINARIZE_METHOD,\n",
    "#         ideal_mask=ideal_mask,\n",
    "#         pre_threshold_frac=PRE_THRESHOLD_FRAC,\n",
    "#     )\n",
    "\n",
    "#     title = f\"{TASK} | L{l} H{h} | ROI=SMILES tokens [{g_start},{g_end}]\"\n",
    "#     fig = create_head_figure(\n",
    "#         full_attn_head=attn_np[l, h],\n",
    "#         roi=roi,\n",
    "#         bin_mask=bin_mask,\n",
    "#         blurred_norm=denoised,\n",
    "#         title=title,\n",
    "#         local_spans=local_spans,\n",
    "#         g_start=g_start,\n",
    "#         g_end=g_end,\n",
    "#     )\n",
    "#     fig.set_dpi(DPI)\n",
    "#     plt.show()\n",
    "\n",
    "# else:\n",
    "#     raise ValueError(\"PLOT_MODE must be 'layer' or 'head'\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "graphllm",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
