{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append(\"../..\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import tqdm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "ckpts = \"small_w512 small_w1024 small_w1536 base_w512 base_w1024 base_w1536 big_w512 big_w1024 original\".split()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 66%|██████▌   | 21/32 [00:02<00:01,  8.79it/s]\n",
      " 66%|██████▌   | 21/32 [00:02<00:01,  9.51it/s]\n",
      " 72%|███████▏  | 23/32 [00:02<00:00,  9.38it/s]\n",
      " 66%|██████▌   | 21/32 [00:02<00:01,  9.33it/s]\n",
      " 66%|██████▌   | 21/32 [00:02<00:01,  9.48it/s]\n",
      " 66%|██████▌   | 21/32 [00:02<00:01,  9.32it/s]\n",
      " 72%|███████▏  | 23/32 [00:02<00:00,  9.33it/s]\n",
      " 66%|██████▌   | 21/32 [00:02<00:01,  9.46it/s]\n",
      " 72%|███████▏  | 23/32 [00:02<00:00,  9.29it/s]\n"
     ]
    }
   ],
   "source": [
    "DIFF_STEPS = 30\n",
    "DEVICE = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "from train import Config, LitQuetzal\n",
    "\n",
    "def gen(ckpt, bsz=8, seed=0):\n",
    "    ckpt_path = f\"../../checkpoints/{ckpt}.ckpt\"\n",
    "\n",
    "    lit = LitQuetzal.load_from_checkpoint(ckpt_path, map_location=DEVICE)\n",
    "    model = lit.ema.module\n",
    "    model.eval();\n",
    "    kwargs = {\n",
    "        \"device\": DEVICE,\n",
    "        \"num_steps\": DIFF_STEPS,\n",
    "        \"pbar\": True,\n",
    "        \"max_len\": 32,\n",
    "    }\n",
    "    torch.manual_seed(seed)\n",
    "    return model.generate(bsz, **kwargs)\n",
    "\n",
    "outs = [gen(ckpt) for ckpt in ckpts]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.save(outs, \"gen_30_midrotate.pt\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "from draw import make_html\n",
    "for b_idx in range(8):\n",
    "\n",
    "    import py3Dmol\n",
    "    ncols = 3\n",
    "    nrows = 3\n",
    "\n",
    "    view = py3Dmol.view(width=2880, height=2880, viewergrid=(nrows, ncols))\n",
    "\n",
    "    ref = outs[-1][0][b_idx]\n",
    "    for i, a in enumerate(ref.atoms):\n",
    "        if a == 0:\n",
    "            break\n",
    "    c = ref.coords[:i]\n",
    "    mean = c.mean(dim=0, keepdim=True)\n",
    "\n",
    "    c = c - mean\n",
    "    U, _, _ = np.linalg.svd(c.T.numpy())\n",
    "    if np.linalg.det(U) < 0:\n",
    "        U[:, -1] *= -1\n",
    "    U = torch.tensor(U)\n",
    "\n",
    "    for i in range(nrows*ncols):\n",
    "        row = i // ncols\n",
    "        col = i % ncols\n",
    "        # print(row, col, ckpts[i])\n",
    "\n",
    "        M = outs[i][0][b_idx]\n",
    "        M.coords = (M.coords - mean) @ U\n",
    "        view = M.show(view=view, viewer=(row, col), zoom=True)\n",
    "\n",
    "    path = f\"seeded_30_{b_idx}.html\"\n",
    "    make_html(view, path)\n",
    "    # view.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "dar2",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
