{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Generated-molecule gallery\n",
    "\n",
    "Renders generated samples as 3D grids of 5 rows x 4 columns (20 molecules per page) and writes each page to a standalone HTML file via `make_html`. A panel's background is green when `check_validity_smiles` accepts the molecule and red otherwise."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "\n",
    "sys.path.append(\"../..\")  # make the project-local `draw` and `metrics` modules importable\n",
    "\n",
    "import numpy as np\n",
    "import py3Dmol\n",
    "import torch\n",
    "\n",
    "from draw import make_html\n",
    "from metrics import check_validity_smiles\n",
    "\n",
    "RES = 560  # pixel size of one grid panel\n",
    "N_ROWS, N_COLS = 5, 4\n",
    "PER_PAGE = N_ROWS * N_COLS  # molecules rendered per HTML page\n",
    "N_PAGES = 4\n",
    "VALID_BG = \"#cffbcf\"  # green: check_validity_smiles accepted the molecule\n",
    "INVALID_BG = \"#f4ccca\"  # red: check_validity_smiles rejected the molecule"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def strip_padding(M):\n",
    "    \"\"\"Truncate `M` (in place) to the atoms before the first atom type 0 and return it.\n",
    "\n",
    "    Atom type 0 is treated as padding. If no atom has type 0, every atom is kept\n",
    "    (the previous `enumerate`/`break` loop dropped the last atom in that case).\n",
    "    \"\"\"\n",
    "    n_atoms = next((idx for idx, a in enumerate(M.atoms) if a == 0), len(M.atoms))\n",
    "    M.atoms = M.atoms[:n_atoms]\n",
    "    M.coords = M.coords[:n_atoms]\n",
    "    return M\n",
    "\n",
    "\n",
    "def align_to_principal_axes(M):\n",
    "    \"\"\"Center `M.coords` (in place) and rotate them onto their principal axes; return `M`.\n",
    "\n",
    "    This is only a display normalization so molecules are viewed from comparable angles.\n",
    "    Molecules with no atoms are returned unchanged (mean/SVD are undefined for them).\n",
    "    \"\"\"\n",
    "    if len(M.coords) == 0:\n",
    "        return M\n",
    "    M.coords = M.coords - M.coords.mean(dim=0, keepdim=True)\n",
    "    U, _, _ = np.linalg.svd(M.coords.T.numpy())\n",
    "    if np.linalg.det(U) < 0:\n",
    "        U[:, -1] *= -1  # flip one axis so U is a rotation, not a reflection (keeps chirality)\n",
    "    M.coords = M.coords @ torch.tensor(U)\n",
    "    return M\n",
    "\n",
    "\n",
    "def render_page(molecules, start, path, zoom=None):\n",
    "    \"\"\"Render `PER_PAGE` molecules beginning at `molecules[start]` into one HTML grid page.\n",
    "\n",
    "    Panels are filled column by column (down a column of `N_ROWS`, then the next column).\n",
    "    Each molecule is `.clone()`d before it is truncated and re-oriented.\n",
    "\n",
    "    Args:\n",
    "        molecules: indexable collection of molecule objects exposing `.clone()`,\n",
    "            `.atoms`, `.coords` and `.show(view=..., viewer=...)`.\n",
    "        start: index of the first molecule on this page.\n",
    "        path: output HTML file path passed to `make_html`.\n",
    "        zoom: optional py3Dmol zoom factor applied to the whole view before saving.\n",
    "    \"\"\"\n",
    "    view = py3Dmol.view(width=RES * N_COLS, height=RES * N_ROWS, viewergrid=(N_ROWS, N_COLS))\n",
    "    for k in range(PER_PAGE):\n",
    "        cell = (k % N_ROWS, k // N_ROWS)  # (row, col)\n",
    "        M = align_to_principal_axes(strip_padding(molecules[start + k].clone()))\n",
    "        view = M.show(view=view, viewer=cell)\n",
    "        is_valid, _ = check_validity_smiles(M)\n",
    "        view.setBackgroundColor(VALID_BG if is_valid else INVALID_BG, viewer=cell)\n",
    "    if zoom is not None:\n",
    "        view.zoom(zoom)\n",
    "    make_html(view, path)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## QM9 samples"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE: weights_only=False unpickles arbitrary objects; only load trusted sample files.\n",
    "qm9_samples = torch.load(\"../../samples/gen/ds_qm9_original_rep0/gen.pt\", weights_only=False)\n",
    "QM9_FIRST_INDEX = 160  # first sample shown; why the first 160 are skipped is not recorded here (TODO)\n",
    "\n",
    "for page in range(N_PAGES):\n",
    "    # assumes qm9_samples[0][0] is the indexable list of molecules (gen.pt layout not shown here)\n",
    "    render_page(qm9_samples[0][0], QM9_FIRST_INDEX + page * PER_PAGE, f\"qm9_chunk{page}.html\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## GEOM samples"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE: weights_only=False unpickles arbitrary objects; only load trusted sample files.\n",
    "geom_samples = torch.load(\"../../samples/gen/ds_geom_original_test201_rep0/gen.pt\", weights_only=False)\n",
    "GEOM_ZOOM = 0.94  # slight zoom-out applied to each GEOM page\n",
    "\n",
    "for page in range(N_PAGES):\n",
    "    # assumes geom_samples[0][0] is the indexable list of molecules (gen.pt layout not shown here)\n",
    "    render_page(geom_samples[0][0], page * PER_PAGE, f\"geom_chunk{page}.html\", zoom=GEOM_ZOOM)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "dar2",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
