{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c69d5731",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from IPython.display import Image, display\n",
    "import matplotlib.image as mpimg\n",
    "from tqdm import tqdm\n",
    "import math\n",
    "import os\n",
    "import json\n",
    "from types import SimpleNamespace\n",
    "from foldingdiff.tokenizer import Tokenizer\n",
    "from collections import defaultdict\n",
    "from foldingdiff.datasets import *\n",
    "from foldingdiff.utils import *\n",
    "from foldingdiff.plotting import plot\n",
    "from pathlib import Path\n",
    "# Move up to the repo root so `bin.encode` is importable. The guard keeps re-runs of this\n",
    "# cell from climbing further up the directory tree (the original chdir was not idempotent).\n",
    "if not Path(\"bin\").is_dir():\n",
    "    os.chdir(Path.cwd().parent)\n",
    "from bin.encode import BPE"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "810c2acc",
   "metadata": {},
   "outputs": [],
   "source": [
    "def modified(t):\n",
    "    \"\"\"Return the keys of t.bond_to_token whose value holds a tuple at index 1, in dict order.\"\"\"\n",
    "    return [bond for bond, entry in t.bond_to_token.items() if isinstance(entry[1], tuple)]\n",
    "\n",
    "\n",
    "def compare(t1, t2):\n",
    "    \"\"\"RMSD (via compute_rmsd) between t1.compute_coords() and t2.compute_coords().\"\"\"\n",
    "    coords1 = t1.compute_coords()\n",
    "    coords2 = t2.compute_coords()\n",
    "    return compute_rmsd(coords1, coords2)\n",
    "\n",
    "\n",
    "def vis_images(*paths, scale=4):\n",
    "    \"\"\"\n",
    "    Display an arbitrary number of images in a square-ish grid layout.\n",
    "\n",
    "    Parameters:\n",
    "    *paths: variable number of image file paths (str or os.PathLike)\n",
    "    scale: size in inches of each grid cell; the figure is (n_cols*scale, n_rows*scale)\n",
    "\n",
    "    Each subplot is titled with the image's file name. Prints a message and returns\n",
    "    None when called with no paths.\n",
    "    \"\"\"\n",
    "    n = len(paths)\n",
    "    if n == 0:\n",
    "        print(\"No images to display.\")\n",
    "        return\n",
    "\n",
    "    # Grid close to square: e.g. 2 -> 1x2, 5 -> 2x3, 9 -> 3x3\n",
    "    n_cols = math.ceil(math.sqrt(n))\n",
    "    n_rows = math.ceil(n / n_cols)\n",
    "\n",
    "    # squeeze=False always returns a 2-D array of Axes, so a 1x1 grid needs no special case\n",
    "    fig, axes = plt.subplots(n_rows, n_cols, figsize=(n_cols * scale, n_rows * scale), squeeze=False)\n",
    "    axes = axes.ravel()\n",
    "\n",
    "    for ax, path in zip(axes, paths):\n",
    "        img = mpimg.imread(path)\n",
    "        ax.imshow(img)\n",
    "        # basename accepts pathlib.Path and OS-native separators, unlike str.split(\"/\")\n",
    "        ax.set_title(os.path.basename(os.fspath(path)))\n",
    "        ax.axis('off')\n",
    "\n",
    "    # Hide any unused subplots\n",
    "    for ax in axes[n:]:\n",
    "        ax.axis('off')\n",
    "\n",
    "    plt.tight_layout()\n",
    "    plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f365b12e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Quick look at one saved run_iter snapshot, drawn large (scale=10).\n",
    "vis_images('ckpts/1752521366.505088/run_iter=50.png', scale=10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ab5c5f95",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Checkpoint run (directory name under ./ckpts) whose saved args drive the cells below.\n",
    "# Other runs inspected previously: 1752336941.5962725, 1752336941.5956001, 1752293573.2959695\n",
    "d = \"1752566951.7877476\"\n",
    "args_path = f\"./ckpts/{d}/args.txt\"\n",
    "args = load_args_from_txt(args_path)\n",
    "vars(args)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b36201be",
   "metadata": {},
   "outputs": [],
   "source": [
    "import re\n",
    "\n",
    "# Checkpoint folders to re-plot. Earlier runs, kept for reference:\n",
    "#   ./ckpts/1752293573.2959695, ./ckpts/1752293573.2966762,\n",
    "#   ./ckpts/1752336941.5962725, ./ckpts/1752336941.5956001\n",
    "folders = [\"./ckpts/1752566951.7877476\"]\n",
    "\n",
    "\n",
    "def iter_num(fp: Path) -> int:\n",
    "    \"\"\"Return N from a file named 'run_iter=N.png', or -1 if the name does not match.\"\"\"\n",
    "    m = re.search(r\"run_iter=(\\d+)\\.png$\", fp.name)\n",
    "    return int(m.group(1)) if m else -1\n",
    "\n",
    "\n",
    "paths = []\n",
    "for folder in folders:\n",
    "    p = Path(folder)\n",
    "    # find all run_iter PNGs\n",
    "    pngs = list(p.glob(\"run_iter=*.png\"))\n",
    "    if not pngs:\n",
    "        print(f\"No run_iter PNGs in {folder!r}, skipping.\")\n",
    "        continue\n",
    "\n",
    "    # pick the file with the max iteration\n",
    "    latest = max(pngs, key=iter_num)\n",
    "    latest_iter = iter_num(latest)\n",
    "    ref_coords = np.load(os.path.join(folder, \"ref_coords.npy\"), allow_pickle=True)\n",
    "    # Build the output name from the iteration *number*; interpolating the Path object itself\n",
    "    # produced 'run_iter=ckpts/.../run_iter=N.png.png'.\n",
    "    run_path = os.path.join(folder, f\"run_iter={latest_iter}.png\")\n",
    "    # NOTE(review): args.save_every comes from the args of run `d`; assumes every folder shares it.\n",
    "    plot(ref_coords, p.name, run_path, no_iters=latest_iter, step_iter=args.save_every, ratio=None)\n",
    "    print(f\"Latest in {folder!r} → {latest}\")\n",
    "    paths.append(str(latest))\n",
    "\n",
    "# now display them\n",
    "vis_images(*paths)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9e9f5b87",
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset = FullCathCanonicalCoordsDataset(\n",
    "    args.data_dir,\n",
    "    use_cache=False,\n",
    "    debug=False,\n",
    "    zero_center=False,\n",
    "    toy=args.toy,\n",
    "    pad=args.pad,\n",
    "    secondary=args.sec,\n",
    ")\n",
    "\n",
    "# Drop structures with more than one NaN psi angle (x == x is False only for NaN).\n",
    "# One NaN is tolerated -- presumably the undefined terminal psi; confirm.\n",
    "cleaned_structures = []\n",
    "for i, struc in enumerate(dataset.structures):\n",
    "    psi = struc['angles']['psi']\n",
    "    n_missing = len(psi) - (psi == psi).sum()\n",
    "    if n_missing > 1:\n",
    "        print(f\"skipping {i}, {struc['fname']} because of missing dihedrals\")\n",
    "        continue\n",
    "    cleaned_structures.append(struc)\n",
    "dataset.structures = cleaned_structures\n",
    "\n",
    "# Fresh BPE built from the checkpoint run's args, used as the reference (`ref`) below.\n",
    "ref = BPE(\n",
    "    dataset.structures,\n",
    "    bins=args.bins,\n",
    "    bin_strategy=args.bin_strategy,\n",
    "    save_dir=f'./ckpts/{d}',\n",
    "    rmsd_partition_min_size=args.p_min_size,\n",
    "    num_partitions=args.num_p,\n",
    "    compute_sec_structs=args.sec,\n",
    "    plot_iou_with_sec_structs=args.sec_eval,\n",
    "    res_init=args.res_init,\n",
    ")\n",
    "ref.initialize()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "47d8dd4d",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1a8a3ff4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of structures that survived the missing-dihedral filter.\n",
    "# Commented-out scratch checks; the first needs `bpe`, which is loaded in a later cell.\n",
    "# len(bpe.tokenizers), len(ref.tokenizers)\n",
    "# len(pickle.load(open(\"./ckpts/1751936564.1540673/bpe_iter=100.pkl\", \"rb\")).tokenizers)\n",
    "len(cleaned_structures)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2cdffad7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the pickled BPE saved at training iteration `bpe_iter` of run `d`.\n",
    "# Named bpe_iter (not `t`) so re-running this cell doesn't clobber the tokenizer `t` used below.\n",
    "# NOTE: pickle.load can execute arbitrary code -- only load checkpoints you produced yourself.\n",
    "bpe_iter = 0\n",
    "path = f'./ckpts/{d}/bpe_iter={bpe_iter}.pkl'\n",
    "with open(path, 'rb') as fh:\n",
    "    bpe = pickle.load(fh)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "25056d88",
   "metadata": {},
   "outputs": [],
   "source": [
    "# For every tokenizer, list the bonds reported by `modified` (bond_to_token entries with a\n",
    "# tuple at position 1).\n",
    "for index, t in enumerate(bpe.tokenizers):\n",
    "    for bond in modified(t):\n",
    "        print(index, bond, t.bond_to_token[bond])\n",
    "\n",
    "# Tokenizer selected for the closer look in the cells below.\n",
    "index = 6\n",
    "t = bpe.tokenizers[index]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d47e4d83",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Bond window inspected below: passed as (start, length) to visualize_bonds and token_geo.\n",
    "# NOTE(review): presumably the first bond index and bond count of one merged token -- confirm.\n",
    "start, length = 69, 6\n",
    "# Key into bpe._tokens compared against that window; meaning of the pair is TODO confirm.\n",
    "occur = (30, 8)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9297b2b4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Render the selected structure (full view and the start/length bond window) for both the\n",
    "# reference BPE (`ref`) and the checkpoint tokenizer (`t`), then show the bond renders side by side.\n",
    "# NOTE: '../' resolves against the cwd set in the first cell, i.e. one level above it.\n",
    "path = os.path.abspath('../test.png')\n",
    "ref_path = os.path.abspath('../ref.png')\n",
    "bond_path = os.path.abspath('../test_bonds.png')\n",
    "ref_bond_path = os.path.abspath('../ref_bonds.png')\n",
    "t.visualize(path)\n",
    "ref.tokenizers[index].visualize(ref_path)\n",
    "t.visualize_bonds(start, length, bond_path)\n",
    "ref.tokenizers[index].visualize_bonds(start, length, ref_bond_path)\n",
    "vis_images(ref_bond_path, bond_path)\n",
    "# vis_images(*([bond_path] + [f'./ckpts/{d}/key_iter=0_{i}.png' for i in range(10)]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "43f66a3c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Side by side: the image at ref_path (reference tokenizer) then the one at path (checkpoint `t`).\n",
    "vis_images(ref_path, path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "73ee0398",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Geometry of the (start, length) window in `t` next to the stored entry bpe._tokens[occur].\n",
    "t.token_geo(start, length), bpe._tokens[occur]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3704ea5b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Check that the reference and checkpoint tokenizers at `index` come from the same file.\n",
    "ref.tokenizers[index].fname, t.fname"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "70bf7c4e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Round-trip check: token_geo(0, 3*t.n-1) must equal what bpe.recover rebuilds from t.tokenize().\n",
    "# NOTE(review): 3*t.n-1 presumably = backbone bond count for n residues with 3 atoms each -- confirm.\n",
    "full = t.token_geo(0, 3*t.n-1)\n",
    "tokenized = t.tokenize()\n",
    "repl = bpe.recover(tokenized)\n",
    "# NOTE(review): if these dicts hold numpy arrays, `==` raises 'truth value is ambiguous'.\n",
    "assert full == repl\n",
    "# Return value is discarded and `tokenized` is displayed next, so quantize presumably\n",
    "# mutates it in place -- verify.\n",
    "bpe.quantize(tokenized)\n",
    "tokenized"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f873429a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Raw angles of the first cleaned structure (rebinds `struc`, earlier the filter-loop variable).\n",
    "struc = cleaned_structures[0]['angles']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "80f2b4f5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Compare .n with the lengths of two recovered bond series.\n",
    "# NOTE(review): repl was recovered from t = bpe.tokenizers[index], but .n is read from\n",
    "# tokenizers[0]; likely meant t.n -- confirm.\n",
    "bpe.tokenizers[0].n, len(repl[\"0C:1N\"]), len(repl[\"N:CA\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "074202de",
   "metadata": {},
   "outputs": [],
   "source": [
    "# `ref` was built from the cleaned structures; presumably one tokenizer per structure -- check.\n",
    "len(cleaned_structures), len(ref.tokenizers)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "02ecbf1c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Underscore-prefixed angle/distance table of the first reference tokenizer.\n",
    "ref.tokenizers[0]._angles_and_dists"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2433019e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Raw dataset angles for the first structure, to compare with the tokenizer table above.\n",
    "cleaned_structures[0]['angles']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2fabed0d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# .n of the first reference tokenizer (compare with the angle-table length above).\n",
    "ref.tokenizers[0].n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dc85f8a7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# `t.angles_and_dists` entry for key \"0C:1N\" -- compare with repl[\"0C:1N\"] from the round-trip cell.\n",
    "t.angles_and_dists[\"0C:1N\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eb6bd866",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Underscore-prefixed counterpart of `t.angles_and_dists`, for comparison.\n",
    "t._angles_and_dists"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "21c7c6d3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Re-display `tokenized` as left by the round-trip cell.\n",
    "tokenized"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f0f774bf",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "864fc27f",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
