{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "dcpc-celeba-md-intro",
   "metadata": {},
   "source": [
    "# DCPC on CelebA: evaluation\n",
    "\n",
    "Restores the model described by `experiments/dcpc_celeba_config.json` from a saved checkpoint, then:\n",
    "\n",
    "- re-runs inference over the training set and runs 300 inference steps on one validation batch;\n",
    "- plots original, reconstructed, posterior-predictive and prior images (saved as PDFs);\n",
    "- reports validation log-likelihood and squared error across several seeds, and test-set FID."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "596d192c-de00-40b4-a8b3-25e6b01e1e4b",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# All relative paths used below (experiments/, saved/) resolve from the directory one level\n",
    "# above this notebook. Only change directory if we are not there yet, so that re-running\n",
    "# this cell does not keep climbing the directory tree.\n",
    "if not os.path.isdir(\"experiments\"):\n",
    "    os.chdir(\"..\")\n",
    "os.getcwd()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "57c6b861-1675-4dda-8682-90dce2bdd6ad",
   "metadata": {},
   "outputs": [],
   "source": [
    "import argparse\n",
    "import collections\n",
    "\n",
    "import lightning\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import pyro\n",
    "import torch\n",
    "import tqdm\n",
    "\n",
    "# Project-local modules, presumably resolved from the working directory set above.\n",
    "# NOTE: the name `logger` is rebound later to `config.get_logger('valid')`, shadowing this module.\n",
    "import logger, train"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "677b3650-cfa3-4e86-b90d-76fcc76171c3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Debugging switches, off by default. Uncomment to enable Pyro's argument validation and\n",
    "# PyTorch autograd anomaly detection; both add noticeable overhead, so use them only to debug.\n",
    "# pyro.enable_validation(True)\n",
    "# torch.autograd.set_detect_anomaly(True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "57f1cbc3-5002-44f4-8b3b-f73a4d9d06a2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# `train.from_file` returns a config object together with a (data, model, trainer) tuple.\n",
    "config, (data, model, trainer) = train.from_file(\"experiments/dcpc_celeba_config.json\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a5fb836e-6722-4e0e-8938-ccafdea5fb7c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Logger used for the progress and metric messages below. This rebinds the name `logger`,\n",
    "# shadowing the imported `logger` module.\n",
    "LOGGER_NAME = 'valid'\n",
    "logger = config.get_logger(LOGGER_NAME)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f2da1153-66a1-4f95-99d6-0f30f353edde",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Restore model and trainer state from this checkpoint. `fit` resumes training from it;\n",
    "# NOTE(review): whether any further epochs run depends on the trainer's configured limits - confirm.\n",
    "CHECKPOINT_PATH = \"saved/models/Heteroskedastic_CelebA_Dcpc/0903_161845/checkpoint_149.ckpt\"\n",
    "trainer.fit(model, data, ckpt_path=CHECKPOINT_PATH)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "faa5b900-0121-4cf6-91b6-2c9b0b12716f",
   "metadata": {},
   "outputs": [],
   "source": [
    "model.graph.clear()\n",
    "model.eval()\n",
    "# Prefer the GPU, but fall back to the CPU instead of failing on machines without CUDA.\n",
    "# The trailing semicolon suppresses the (long) module repr.\n",
    "model.to(\"cuda\" if torch.cuda.is_available() else \"cpu\");"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8a66471e-8f5b-449c-ae44-ee647b939df5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run one conditioned inference pass per training batch.\n",
    "# NOTE(review): `_load_particles(indices, True)` presumably loads the stored training-set\n",
    "# particles for these example indices - confirm against the model class.\n",
    "# Targets are bound to `_targets` rather than `_` so IPython's last-output variable is untouched.\n",
    "for xs, _targets, indices in tqdm.tqdm(data.train_dataloader(), desc=\"Train set inference\"):\n",
    "    xs = xs.to(model.device)\n",
    "    model._load_particles(indices, True)\n",
    "    with model.graph.condition(X=xs) as joint:\n",
    "        trace, log_weight = joint(B=len(xs), lr=1e-3, P=model.num_particles)\n",
    "    # Drop this batch's references before loading the next one so device memory can be reused.\n",
    "    del trace, log_weight, xs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3b2fdb21-b49c-44dd-a557-e249e83ea183",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Take only the first validation batch. `next(iter(...))` stops after one batch, whereas\n",
    "# `list(loader)[0]` would load every validation batch into memory first.\n",
    "xs, _targets, indices = next(iter(data.val_dataloader()))\n",
    "xs = xs.to(model.device)\n",
    "model._load_particles(indices, False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a0ba74ac-4973-49b2-8f08-89364422027a",
   "metadata": {},
   "outputs": [],
   "source": [
    "NUM_INFERENCE_STEPS = 300\n",
    "\n",
    "# Repeated conditioned inference on the validation batch, logging the free energy\n",
    "# (the negated mean log-weight) after every step.\n",
    "with model.graph.condition(X=xs) as joint:\n",
    "    for step in range(1, NUM_INFERENCE_STEPS + 1):\n",
    "        trace, log_weight = joint(B=len(xs), lr=1e-3, P=model.num_particles)\n",
    "        logger.info(\"Free energy at evaluation %d: %f\" % (step, -log_weight.mean()))\n",
    "        # Release this step's outputs before the next call to bound memory use.\n",
    "        del trace, log_weight"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "61390e08-d200-4dd8-968a-01aee2129baf",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Decode the current values of node `z` in prior mode and average over dim 0\n",
    "# (presumably the particle dimension) to get one reconstruction per validation image.\n",
    "with model.graph.condition(z=model.graph.nodes['z']['value']) as predictive:\n",
    "    x_hats = predictive(B=len(xs), mode=\"prior\", P=model.num_particles).mean(dim=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "53c76257-aef8-40e3-934a-a1e534924256",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "\n",
    "GRID_SIZE = 8  # figures below show GRID_SIZE x GRID_SIZE images\n",
    "\n",
    "\n",
    "def plot_image_grid(images, path, reverse_transform, *, clamp=True, squeeze=False, grid_size=GRID_SIZE):\n",
    "    \"\"\"Show the first ``grid_size ** 2`` images in a square grid, save it to ``path`` and return the figure.\n",
    "\n",
    "    Args:\n",
    "        images: indexable batch of image tensors (``images[i]`` is one image).\n",
    "        path: file the figure is saved to; matplotlib infers the format from the extension.\n",
    "        reverse_transform: maps a CPU image tensor back to displayable values (e.g. ``data.reverse_transform``).\n",
    "        clamp: clamp displayed values to [0, 1]; enabled for model outputs.\n",
    "        squeeze: drop singleton dimensions of each image before transforming it.\n",
    "        grid_size: number of rows and of columns.\n",
    "\n",
    "    Grid cells beyond ``len(images)`` are left empty instead of raising ``IndexError``.\n",
    "    \"\"\"\n",
    "    # squeeze=False keeps `axes` a 2-D array even for a 1 x 1 grid.\n",
    "    fig, axes = plt.subplots(nrows=grid_size, ncols=grid_size, sharex=\"all\", sharey=\"all\",\n",
    "                             layout=\"compressed\", squeeze=False)\n",
    "    # axes.flat walks the grid row by row, so cell i shows images[i] (i = row * grid_size + col).\n",
    "    for i, ax in enumerate(axes.flat):\n",
    "        if i < len(images):\n",
    "            image = images[i]\n",
    "            if squeeze:\n",
    "                image = image.squeeze()\n",
    "            # NOTE(review): for a (C, H, W) tensor, transpose(0, -1) gives (W, H, C), i.e. a transposed\n",
    "            # image; permute(1, 2, 0) would keep the orientation - confirm what reverse_transform returns.\n",
    "            image = reverse_transform(image.detach().cpu()).transpose(0, -1)\n",
    "            if clamp:\n",
    "                image = image.clamp(0, 1)\n",
    "            ax.imshow(image)\n",
    "        ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])\n",
    "\n",
    "    fig.savefig(path, bbox_inches=\"tight\")\n",
    "    plt.show()\n",
    "    return fig"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f750bfb2-80aa-4fad-885e-752772706a9d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Original validation images (not clamped).\n",
    "plot_image_grid(xs, \"dcpc_celeba_orgs.pdf\", data.reverse_transform, clamp=False);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a19e9354-dd82-4066-a828-95b79cb30c76",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Particle-averaged reconstructions of the same validation images.\n",
    "plot_image_grid(x_hats, \"dcpc_celeba_recons.pdf\", data.reverse_transform);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "34f8050d-f8c0-440a-9bbd-d7fb8a904249",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Free the validation batch and its reconstructions before sampling.\n",
    "del xs, x_hats"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "14d0da47-0636-4f1a-9a06-5a104c5e9497",
   "metadata": {},
   "outputs": [],
   "source": [
    "model.graph.clear()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cb8ca57f-dcff-437d-b489-8a66781cd14c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Predict from the stored train and validation particles, concatenated along dim 1.\n",
    "# Requesting GRID_SIZE ** 2 // num_particles batch entries with num_particles particles each\n",
    "# yields (up to) GRID_SIZE ** 2 images once the first two dims are flattened.\n",
    "posterior = {k: torch.cat((v.detach(), model.particles[\"valid\"][k].detach()), dim=1)\n",
    "             for k, v in model.particles[\"train\"].items()}\n",
    "x_hats = model.graph.predict(B=GRID_SIZE ** 2 // model.num_particles, P=model.num_particles, **posterior)\n",
    "x_hats = torch.flatten(x_hats, 0, 1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a956c010-f8c2-46bf-a125-598d4a8c639f",
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_image_grid(x_hats, \"dcpc_celeba_predictive.pdf\", data.reverse_transform, squeeze=True);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "58a2ab4d-11c3-489d-afcf-22ce2f7e6162",
   "metadata": {},
   "outputs": [],
   "source": [
    "del x_hats"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fc0cbde7-57f0-48fb-956e-7ab6016f34f4",
   "metadata": {},
   "outputs": [],
   "source": [
    "model.graph.clear()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3c48c964-fc69-4cad-bfe5-e456ac545d73",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Unconditioned samples: one prior-mode pass, 299 default-mode passes, then a final\n",
    "# prior-mode pass whose output is plotted.\n",
    "num_batch = GRID_SIZE ** 2 // model.num_particles\n",
    "model.graph(B=num_batch, lr=1e-3, mode=\"prior\", P=model.num_particles)\n",
    "for _step in range(299):\n",
    "    model.graph(B=num_batch, lr=1e-3, P=model.num_particles)\n",
    "x_hats = model.graph(B=num_batch, lr=1e-3, mode=\"prior\", P=model.num_particles)\n",
    "x_hats = x_hats.flatten(0, 1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e8d52c91-326b-469a-954d-ca5b4e4c67c6",
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_image_grid(x_hats, \"dcpc_celeba_priors.pdf\", data.reverse_transform, squeeze=True);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4a8e8aef-77f3-483a-a4b8-b36b01fdffaa",
   "metadata": {},
   "outputs": [],
   "source": [
    "del x_hats"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7657733f-5780-47fc-a10e-a01fbac6b4a8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Seeds for repeating the validation-set evaluation; results are reported as mean/std across seeds.\n",
    "SEEDS = [123, 456, 789, 101112, 131415]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "94d66fc2-7539-4ef9-ad74-1cc77a0e8306",
   "metadata": {},
   "outputs": [],
   "source": [
    "# One accumulator slot per seed (plain tensors: requires_grad defaults to False).\n",
    "LOG_LIKELIHOODS = torch.zeros(len(SEEDS))\n",
    "MEAN_SQUARED_ERROR = torch.zeros(len(SEEDS))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "089f3a60-5009-4238-9c1b-f1a2e7797546",
   "metadata": {},
   "outputs": [],
   "source": [
    "with torch.no_grad():\n",
    "    # Computed once, before seeding, instead of building a new dataloader per division.\n",
    "    num_val_examples = len(data.val_dataloader().dataset)\n",
    "    for s, SEED in enumerate(SEEDS):\n",
    "        torch.manual_seed(SEED)\n",
    "        np.random.seed(SEED)\n",
    "        for b, (xs, target, indices) in enumerate(data.val_dataloader()):\n",
    "            xs = xs.to(model.device)\n",
    "            model._load_particles(indices, False)\n",
    "            # Prior-mode prediction from the loaded particles, taken before the conditioned pass.\n",
    "            x_hats = model.graph(B=len(xs), mode=\"prior\", P=model.num_particles)\n",
    "            with model.graph.condition(X=xs) as predictive:\n",
    "                trace, _log_weight = predictive(B=len(xs), P=model.num_particles)\n",
    "            LOG_LIKELIHOODS[s] += trace.nodes['X']['fn'].log_prob(xs).sum().cpu()\n",
    "            # NOTE(review): if `x_hats` carries a leading particle dimension (elsewhere in this notebook\n",
    "            # the same kind of call is averaged with `.mean(dim=0)` or flattened over dims 0-1), then\n",
    "            # `.sum(dim=0)` sums over particles rather than over the batch, and the division by the\n",
    "            # dataset size below does not yield a per-example error - confirm the intended reduction.\n",
    "            MEAN_SQUARED_ERROR[s] += ((xs - x_hats) ** 2).sum(dim=0).mean().cpu()\n",
    "\n",
    "            del xs, x_hats, trace, _log_weight, target, indices\n",
    "            # Report the actual seed value, not its position in SEEDS.\n",
    "            logger.info(\"Evaluated likelihood for valid batch %d under seed %s\" % (b, SEED))\n",
    "\n",
    "        # Normalise the accumulated sums by the number of validation examples.\n",
    "        LOG_LIKELIHOODS[s] /= num_val_examples\n",
    "        MEAN_SQUARED_ERROR[s] /= num_val_examples"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "23203c5d-e7f0-4f5d-868b-a0ba77d1b9e6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Mean and (Bessel-corrected) standard deviation across seeds.\n",
    "LOG_LIKELIHOODS.mean(), LOG_LIKELIHOODS.std()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ce608df3-f41d-4637-862c-7a825403f72f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Mean and (Bessel-corrected) standard deviation across seeds.\n",
    "MEAN_SQUARED_ERROR.mean(), MEAN_SQUARED_ERROR.std()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d2c7a10b-86aa-4759-a7c6-d1659743b9b0",
   "metadata": {},
   "outputs": [],
   "source": [
    "model.graph.clear()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "598506e3-0420-4eba-89e7-1d943467a939",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inspect the likelihood module's `scale` attribute.\n",
    "model.graph.likelihood.scale"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0a2ad342-70ca-4653-9f49-3cbad48fbf61",
   "metadata": {},
   "outputs": [],
   "source": [
    "NUM_FID_REPEATS = 10\n",
    "\n",
    "fids = []\n",
    "# Values returned by `model.test_step`, one list entry per test batch, accumulated over all repeats.\n",
    "metrics = collections.defaultdict(list)\n",
    "data.setup(\"test\")\n",
    "\n",
    "for repeat in range(NUM_FID_REPEATS):\n",
    "    test_batches = tqdm.tqdm(data.test_dataloader(), desc=f\"Test set FIDs ({repeat + 1}/{NUM_FID_REPEATS})\")\n",
    "    for b, batch in enumerate(test_batches):\n",
    "        for k, v in model.test_step(batch, b).items():\n",
    "            metrics[k].append(v)\n",
    "    # Record this repeat's FID, then reset the metric so the next repeat starts fresh.\n",
    "    fids.append(model.metrics['fid'].compute())\n",
    "    model.metrics['fid'].reset()\n",
    "    # NOTE(review): clearing `gmm` presumably forces it to be rebuilt on the next repeat - confirm.\n",
    "    model.graph.gmm = None\n",
    "\n",
    "fids = torch.stack(fids, dim=0)\n",
    "fids.mean(), fids.std()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f5fcf70b-e1de-4b86-9341-e4caf9358934",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Replace each metric's list of per-batch values with a tensor built from that list (in place).\n",
    "for name in list(metrics):\n",
    "    metrics[name] = torch.tensor(metrics[name])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d7b1bc00-5a6f-4e67-8401-24f33ec6e4ad",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-metric mean over the last dimension (the per-batch axis when each value is a scalar).\n",
    "{name: values.mean(dim=-1) for name, values in metrics.items()}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "48947c28-c034-4b29-99f3-f09729cdd00c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-metric standard deviation over the same dimension.\n",
    "{name: values.std(dim=-1) for name, values in metrics.items()}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "35e8a0e3-7313-48f9-b8ee-affc4a355e1e",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "56169065-be1d-4f9f-ab87-ac44c244c6e6",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:ppc] *",
   "language": "python",
   "name": "conda-env-ppc-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
