{
    "cells": [
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "import os\n",
                "\n",
                "import torch\n",
                "from rdkit import Chem, RDLogger\n",
                "from rdkit.Chem import Draw\n",
                "\n",
                "from src.model.flow_vae import FlowMAGNet\n",
                "from src.model.load_utils import load_model_from_id\n",
                "from src.utils import ROOT_DIR, WB_COLLECTION, smiles_from_file, DATA_PATH\n",
                "\n",
                "# Silence all RDKit logging channels (e.g. the errors printed when a SMILES string cannot be parsed).\n",
                "RDLogger.DisableLog(\"rdApp.*\")\n",
                "# Run from the repository root; presumably the project's relative paths assume this - confirm.\n",
                "os.chdir(ROOT_DIR)\n",
                "\n",
                "# Seed torch globally and early so torch randomness in every later cell (not only sampling)\n",
                "# is reproducible under Restart & Run All. Trailing ';' suppresses the Generator repr.\n",
                "SEED = 0\n",
                "torch.manual_seed(SEED);"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "# Identifier of the trained FlowMAGNet run to load (looked up in WB_COLLECTION).\n",
                "RUN_ID = \"2op8w2pw\"\n",
                "model = load_model_from_id(collection=WB_COLLECTION, run_id=RUN_ID, model_class=FlowMAGNet)"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "# Number of molecules from the ZINC validation file to reconstruct below.\n",
                "N_RECONSTRUCT = 10\n",
                "gt_smiles = smiles_from_file(DATA_PATH / \"zinc\" / \"val.txt\")[:N_RECONSTRUCT]"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "# Reconstruct each input molecule and show it side by side with its reconstruction.\n",
                "output_smiles = list(model.reconstruct_from_smiles(gt_smiles))\n",
                "# zip() would silently drop unmatched items, so verify the pairing first.\n",
                "assert len(output_smiles) == len(gt_smiles), \"expected one reconstruction per input SMILES\"\n",
                "\n",
                "for gt, out in zip(gt_smiles, output_smiles):\n",
                "    gt_mol, out_mol = Chem.MolFromSmiles(gt), Chem.MolFromSmiles(out)\n",
                "    # MolFromSmiles returns None for unparsable SMILES; label the outcome explicitly.\n",
                "    if out_mol is None:\n",
                "        status = \"invalid\"\n",
                "    elif gt_mol is not None and Chem.MolToSmiles(gt_mol) == Chem.MolToSmiles(out_mol):\n",
                "        status = \"exact match\"  # canonical SMILES agree\n",
                "    else:\n",
                "        status = \"different\"\n",
                "    display(Draw.MolsToGridImage([gt_mol, out_mol], molsPerRow=2, legends=[\"input\", f\"reconstruction ({status})\"]))"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "# Sample new molecules with a fixed seed so the grid below is reproducible.\n",
                "N_SAMPLES = 10\n",
                "torch.manual_seed(0)\n",
                "sampled_smiles = model.sample_molecules(N_SAMPLES)\n",
                "sampled_mols = [Chem.MolFromSmiles(s) for s in sampled_smiles]\n",
                "# MolFromSmiles returns None for unparsable SMILES; flag those visibly in the grid.\n",
                "sample_legends = [\"invalid\" if mol is None else \"\" for mol in sampled_mols]\n",
                "Draw.MolsToGridImage(sampled_mols, molsPerRow=5, legends=sample_legends)"
            ]
        }
    ],
    "metadata": {
        "kernelspec": {
            "display_name": "molgen-old",
            "language": "python",
            "name": "python3"
        },
        "language_info": {
            "codemirror_mode": {
                "name": "ipython",
                "version": 3
            },
            "file_extension": ".py",
            "mimetype": "text/x-python",
            "name": "python",
            "nbconvert_exporter": "python",
            "pygments_lexer": "ipython3",
            "version": "3.7.13"
        },
        "orig_nbformat": 4,
        "vscode": {
            "interpreter": {
                "hash": "0f6652a2d5cdbf98f2836c9e6456d23091f31ef28960ab0e66bfaade4a9ae0bb"
            }
        }
    },
    "nbformat": 4,
    "nbformat_minor": 2
}
