{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "46a4e812",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "from sklearn.metrics import accuracy_score, precision_recall_fscore_support, adjusted_rand_score, mutual_info_score\n",
    "from collections import Counter\n",
    "from pathlib import Path\n",
    "# Move to the parent of the directory the kernel started in before the project\n",
    "# imports below run. The starting directory is remembered the first time this\n",
    "# cell executes, so re-running the cell does not climb one more level each time.\n",
    "if \"_NOTEBOOK_START_DIR\" not in globals():\n",
    "    _NOTEBOOK_START_DIR = Path.cwd()\n",
    "os.chdir(_NOTEBOOK_START_DIR.parents[0])\n",
    "from bin.learn import *\n",
    "import torch\n",
    "import json\n",
    "from types import SimpleNamespace\n",
    "from foldingdiff.potential_model import *\n",
    "from foldingdiff.modelling import *\n",
    "from foldingdiff.tokenizer import Tokenizer\n",
    "from foldingdiff.datasets import *\n",
    "import torch\n",
    "# Reset CUDA peak-memory counters so later readings only reflect this session.\n",
    "if torch.cuda.is_available():\n",
    "    torch.cuda.reset_peak_memory_stats()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fdcca647",
   "metadata": {},
   "outputs": [],
   "source": [
    "# d = '1749783864.896021'\n",
    "# bad_key = '3lca_A'\n",
    "# fnames = glob(os.path.join(REPEAT_DIR, \"*.pdb\"))\n",
    "# j = next((i for i, v in enumerate(fnames) if bad_key in v), None)\n",
    "# i = j//100\n",
    "# path = f\"{d}/feats_100_{i}.pkl\"\n",
    "# print(f\"loading {i}\")\n",
    "# stuff = pickle.load(open(path, \"rb\"))\n",
    "# # if bad_key in stuff:\n",
    "# #     print(f\"dumping {bad_key}\")\n",
    "# #     pickle.dump(stuff[bad_key], open(f'ckpts/{d}/{bad_key}.pkl', 'wb+'))\n",
    "# stuff.keys()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "37123804",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rebuild the saved run arguments, the dataset and the model, then restore the\n",
    "# latest checkpoint found under args.save_dir.\n",
    "ckpt_dir = \"ckpts/1749783864.878101\"\n",
    "args_path = os.path.join(ckpt_dir, \"args.json\")\n",
    "with open(args_path, \"r\") as f:\n",
    "    arg_dict = json.load(f)\n",
    "# Turn the dict into an object whose keys become attributes\n",
    "args = SimpleNamespace(**arg_dict)\n",
    "raw_ds = FullCathCanonicalCoordsDataset(\n",
    "    args.data_dir, use_cache=False, debug=args.debug,\n",
    "    zero_center=False, toy=30, pad=args.pad, secondary=False,\n",
    "    trim_strategy=\"discard\"\n",
    ")\n",
    "# Keep the tokenized structures under their own name; `dataset` is reserved for\n",
    "# the FeatDataset that later cells index into.\n",
    "tokenized = [Tokenizer(x) for x in raw_ds.structures]\n",
    "# maximum token length for Transformer positional encodings\n",
    "max_len = max(3 * (3 * t.n - 1) - 2 for t in tokenized)\n",
    "device = args.cuda\n",
    "# ---------------- build model ------------------------------------\n",
    "model = get_model(args, device, max_len=max_len)           # returns SemiCRFModel\n",
    "model.to(device)\n",
    "# NOTE(review): the model is left in whatever train/eval mode get_model returns;\n",
    "# confirm whether model.eval() is wanted for this inference-only notebook.\n",
    "# Always bind `config` (None when no config file was recorded) and close the\n",
    "# file handle as soon as it has been parsed.\n",
    "config = None\n",
    "if args.config:\n",
    "    with open(args.config) as config_file:\n",
    "        config = json.load(config_file)\n",
    "# compute feats in batches\n",
    "dataset = FeatDataset(tokenized, args.save_dir)\n",
    "checkpoint_path, epoch = find_latest_checkpoint(args.save_dir)\n",
    "print(checkpoint_path)\n",
    "# Fail with a clear message instead of an obscure torch.load error.\n",
    "if not checkpoint_path:\n",
    "    raise FileNotFoundError(f\"no checkpoint found in {args.save_dir}\")\n",
    "ckpt = torch.load(checkpoint_path, map_location=model.device)\n",
    "# The weights are either stored under 'model_state' or are the checkpoint itself.\n",
    "state_dict = ckpt['model_state'] if 'model_state' in ckpt else ckpt\n",
    "model.load_state_dict(state_dict)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1411cd0b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Segment one structure with the semi-CRF, plot its MAP segmentation, then greedily\n",
    "# refine the token hierarchy: first by splitting leaves (down), then by merging\n",
    "# neighbouring tokens (up). Every figure is written under args.plot_dir.\n",
    "n_refine_steps = 20  # greedy steps attempted in each direction\n",
    "for idx in [25]:\n",
    "    (_, t, feats) = dataset[idx]\n",
    "    N = t.n\n",
    "    assert t.n == len(t.aa), \"number of residues != length of amino acid sequence\"\n",
    "    coords = t.compute_coords()\n",
    "    out, attn_scores = model.precompute(\n",
    "        feats         = feats,\n",
    "        aa_seq        = t.aa,              # Tokenizer stores AA sequence\n",
    "        coords_tensor = coords\n",
    "    )                                       # out[i][l] ready for DP\n",
    "    log_a, map_a, best_lens = semi_crf_dp_and_map(out, N, gamma=args.gamma)\n",
    "    best_seg = backtrace_map_segmentation(best_lens, N)\n",
    "    # Average the attention scores over the segments of the backtraced segmentation.\n",
    "    attn_stack = torch.stack([attn_scores[start][end-start] for start, end in best_seg], dim=0)\n",
    "    attn_agg = attn_stack.mean(dim=0)\n",
    "    # One 3-tuple per segment, keyed by 3*start; the last field is clipped to\n",
    "    # 3*t.n-1-3*start. The tuples are unpacked below as (i1, b1, l1).\n",
    "    t.bond_to_token = {3*start: (3*start, 3*seg_idx, min(3*(end-start), 3*t.n-1-3*start))\n",
    "                    for seg_idx, (start, end) in enumerate(best_seg)}\n",
    "    loss   = -log_a[N]                       # negative log-partition\n",
    "    prob = torch.exp(map_a[N] - log_a[N]).item()\n",
    "    print(idx, prob)\n",
    "    epoch = -1\n",
    "    # Common file-name prefix shared by every plot of this structure.\n",
    "    plot_prefix = os.path.join(args.plot_dir, f\"epoch={epoch}_idx={idx}_p={prob:.3f}\")\n",
    "    path = Path(plot_prefix + \".png\")\n",
    "    attn_path = path.with_name(path.stem + \"_attn\" + path.suffix)\n",
    "    t.visualize(path, vis_dihedral=False)\n",
    "    plot_feature_importance(attn_agg.detach().cpu().numpy(), model.aggregator.per_res_labels, attn_path)\n",
    "    t.bond_to_token.tree.visualize(plot_prefix + \"_tree.png\", horizontal_gap=1.0, font_size=6)\n",
    "\n",
    "    # ---- build hierarchy (down): greedily split the best-scoring leaf ----\n",
    "    for _step in range(n_refine_steps):\n",
    "        vals = [leaf.value for leaf in t.bond_to_token.tree.leaves.values()]\n",
    "        max_bt = max(val[1] for val in vals)\n",
    "        best_split, best = None, float(\"-inf\")\n",
    "        for i, (i1, _, l1) in enumerate(vals):\n",
    "            # candidate split offsets: multiples of 3 in [3, l1-2]\n",
    "            for j in range(3, l1 - 1, 3):\n",
    "                assert (l1-j+1)%3 != 2\n",
    "                expr = out[i1//3][j//3] + out[(i1+j)//3][(l1-j+1)//3]\n",
    "                if expr > best:\n",
    "                    best = expr\n",
    "                    best_split = (i, j)\n",
    "        if best_split is None:\n",
    "            # no leaf is long enough to be split any further\n",
    "            break\n",
    "        # Use the offset recorded with the best score, not the inner loop's last\n",
    "        # `j`, which may belong to a different leaf.\n",
    "        leaf_idx, split_at = best_split\n",
    "        (i1, b1, l1) = vals[leaf_idx]\n",
    "        t.bond_to_token.tree.split((i1, b1, l1), (i1, max_bt+1, split_at), (i1+split_at, max_bt+2, l1-split_at))\n",
    "    t.bond_to_token.tree.visualize(plot_prefix + \"_down.png\", horizontal_gap=0.5, font_size=6)\n",
    "\n",
    "    # ---- build hierarchy (up): greedily merge the best pair of neighbours ----\n",
    "    for _step in range(n_refine_steps):\n",
    "        vals = list(t.bond_to_token.values())\n",
    "        max_bt = max(val[1] for val in vals)\n",
    "        best_i, best = -1, float(\"-inf\")\n",
    "        for i in range(len(vals) - 1):\n",
    "            (i1, _, l1) = vals[i]\n",
    "            (i2, _, l2) = vals[i+1]\n",
    "            assert i1+l1 == i2\n",
    "            try:\n",
    "                merged_score = out[i1//3][(l1+l2)//3]\n",
    "            except Exception:\n",
    "                # show the offending indices, then let the error propagate\n",
    "                print(i1, l1, l2)\n",
    "                raise\n",
    "            if merged_score > best:\n",
    "                best = merged_score\n",
    "                best_i = i\n",
    "        if best_i < 0:\n",
    "            break\n",
    "        (i1, _, l1) = vals[best_i]\n",
    "        (i2, _, l2) = vals[best_i+1]\n",
    "        # replace the two neighbours by one merged entry with a fresh index\n",
    "        t.bond_to_token.pop(i2)\n",
    "        t.bond_to_token[i1] = (i1, max_bt+1, l1+l2)\n",
    "\n",
    "    t.bond_to_token.tree.visualize(plot_prefix + \"_up.png\", horizontal_gap=0.5, font_size=6)\n",
    "\n",
    "    # Re-plot the structure and the attention summary after refinement.\n",
    "    path = Path(plot_prefix + \"_after.png\")\n",
    "    attn_path = path.with_name(path.stem + \"_attn\" + path.suffix)\n",
    "    t.visualize(path, vis_dihedral=False)\n",
    "    plot_feature_importance(attn_agg.detach().cpu().numpy(), model.aggregator.per_res_labels, attn_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "81bd3e97",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
