{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "wound-radar",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import torch\n",
    "from physvae import utils\n",
    "from physvae.galaxy.model import VAE"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "choice-jersey",
   "metadata": {},
   "outputs": [],
   "source": [
    "# settings: the locations and names that the cells below read from\n",
    "\n",
    "# name of the split; selects the index file idx_<dataname>.txt\n",
    "dataname = 'test'\n",
    "\n",
    "# directory that contains data_all.npy\n",
    "datadir = './data/galaxy'\n",
    "\n",
    "# directory that contains idx_<dataname>.txt, args.json and model.pt\n",
    "modeldir = './out_galaxy/'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "acoustic-sacramento",
   "metadata": {},
   "outputs": [],
   "source": [
    "# load test data: keep only the rows of the full array listed in the index file\n",
    "path_data = f'{datadir}/data_all.npy'\n",
    "path_idx = f'{modeldir}/idx_{dataname}.txt'\n",
    "\n",
    "data_all = np.load(path_data)\n",
    "\n",
    "# np.loadtxt parses values as floats by default, so cast them to an\n",
    "# integer index type before using them to index data_all\n",
    "idx_test = np.loadtxt(path_idx).astype(np.intp)\n",
    "\n",
    "data_test = data_all[idx_test]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "serious-graham",
   "metadata": {},
   "outputs": [],
   "source": [
    "# run on the CPU; to pick a GPU when one is available, use instead:\n",
    "#   device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "device = \"cpu\"\n",
    "\n",
    "# rebuild the model from the arguments saved next to the checkpoint\n",
    "# NOTE(review): this step was originally labelled 'aux only'; confirm that still applies\n",
    "args_path = f'{modeldir}/args.json'\n",
    "with open(args_path, 'r') as fp:\n",
    "    args_tr_dict = json.load(fp)\n",
    "model = VAE(args_tr_dict).to(device)\n",
    "\n",
    "# restore the saved parameters, mapping stored tensors onto `device`\n",
    "# NOTE: torch.load relies on pickle; only load checkpoint files you trust\n",
    "state_dict = torch.load(f'{modeldir}/model.pt', map_location=device)\n",
    "model.load_state_dict(state_dict)\n",
    "\n",
    "# switch to evaluation mode before running inference\n",
    "model.eval()\n",
    "print(\"model loaded\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "designed-manner",
   "metadata": {},
   "outputs": [],
   "source": [
    "# inference & reconstruction on test data\n",
    "\n",
    "# torch.Tensor(...) yields the default float tensor type; move it to `device`\n",
    "# and view it as a batch of 3x69x69 images\n",
    "data_tensor = torch.Tensor(data_test).to(device)\n",
    "data_tensor = data_tensor.view(-1, 3, 69, 69)\n",
    "\n",
    "# gradients are not needed here\n",
    "with torch.no_grad():\n",
    "    # encode the batch, draw latents from the encoder output, then decode them\n",
    "    z_phy_stat, z_aux2_stat, unmixed = model.encode(data_tensor)\n",
    "    z_phy, z_aux2 = model.draw(z_phy_stat, z_aux2_stat, hard_z=False)\n",
    "    decoded = model.decode(z_phy, z_aux2, full=True)\n",
    "    # only the first returned item is kept\n",
    "    x_full = decoded[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "above-omaha",
   "metadata": {},
   "outputs": [],
   "source": [
    "# random generation: decode randomly drawn latent variables into images\n",
    "torch.manual_seed(123456789)  # fixed seed so the drawn samples are repeatable\n",
    "n=100  # number of random samples\n",
    "with torch.no_grad():\n",
    "    # phy and aux\n",
    "    # All values are drawn on the CPU, in the same order as before, so the\n",
    "    # seeded samples do not depend on `device`. The decoder inputs are then\n",
    "    # moved to `device` (where the model lives); without that move, decode()\n",
    "    # would receive CPU tensors whenever `device` is not the CPU.\n",
    "    rand_z_aux2 = (torch.randn(n, model.dim_z_aux2)*10.0).to(device)  # normal, std 10\n",
    "    rand_I0 = torch.rand(n,1)*0.25+0.5      # uniform in [0.5, 0.75)\n",
    "    rand_A = torch.rand(n,1)*0.25+0.35      # uniform in [0.35, 0.6)\n",
    "    rand_e = torch.rand(n,1)*0.15+0.7       # uniform in [0.7, 0.85)\n",
    "    rand_theta = torch.rand(n,1)*2.3+0.5    # uniform in [0.5, 2.8)\n",
    "    # stack the four columns into an (n, 4) tensor for the decoder\n",
    "    rand_z_phy = torch.cat([rand_I0, rand_A, rand_e, rand_theta], dim=1).to(device)\n",
    "    rand_x_full, _, _, _ = model.decode(rand_z_phy, rand_z_aux2, full=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "serious-scroll",
   "metadata": {},
   "outputs": [],
   "source": [
    "# show / save\n",
    "H=4; W=8  # grid size: H rows by W columns of generated images\n",
    "k=10 # margin [px]\n",
    "\n",
    "# phy and aux\n",
    "# all-ones spacers placed between neighbouring tiles and between rows\n",
    "col_gap = torch.ones(69, k, 3)\n",
    "row_gap = torch.ones(k, 69*W+k*(W-1), 3)\n",
    "\n",
    "grid_parts = []\n",
    "for i in range(H):\n",
    "    row_parts = []\n",
    "    for j in range(W):\n",
    "        if j > 0:\n",
    "            row_parts.append(col_gap)\n",
    "        # sample i*W+j, reordered with permute(1,2,0) so the channel axis comes last\n",
    "        row_parts.append(rand_x_full[i*W+j].cpu().permute(1,2,0))\n",
    "    if i > 0:\n",
    "        grid_parts.append(row_gap)\n",
    "    # join the tiles of this row side by side\n",
    "    grid_parts.append(torch.cat(row_parts, dim=1))\n",
    "\n",
    "# stack the rows (and the gaps between them) vertically\n",
    "bigimage = torch.cat(grid_parts, dim=0)\n",
    "\n",
    "plt.imshow(bigimage)\n",
    "plt.show()\n",
    "# plt.imsave('galaxy_randgen.png', bigimage.numpy())"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
