{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Import necessary modules\n",
    "\n",
    "import os\n",
    "import sys\n",
    "\n",
    "# Set the working directory to the project root (the parent of the notebook folder).\n",
    "# The presence of the `src` package (imported below) marks the root, so re-running\n",
    "# this cell, or starting the kernel in the root, does not climb one level too far.\n",
    "if not os.path.isdir('src'):\n",
    "    os.chdir(os.path.dirname(os.getcwd()))\n",
    "\n",
    "# Add the root folder to the import path, without duplicating the entry on re-runs\n",
    "project_root = os.getcwd()\n",
    "if project_root not in sys.path:\n",
    "    sys.path.append(project_root)\n",
    "\n",
    "import numpy as np\n",
    "import torch\n",
    "\n",
    "from src.utils.loading import get_mazes, load_model\n",
    "from src.utils.tda import get_diagram, get_betti_nums\n",
    "from src.utils.plotting import plot_mazes, plot_residuals, plot_latents, plot_diagram"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load model and mazes\n",
    "\n",
    "model = load_model('dt_net')\n",
    "maze_size = 19\n",
    "maze_idx = 25\n",
    "start_idx = 3001\n",
    "end_idx = 3400\n",
    "# Inclusive iteration range start_idx..end_idx, forwarded to model.latent_forward\n",
    "iters = list(range(start_idx, end_idx+1))\n",
    "\n",
    "inputs, solutions = get_mazes(\n",
    "    dataset='maze-dataset',\n",
    "    maze_size=maze_size,\n",
    "    num_mazes=100,\n",
    "    percolation=0.0,\n",
    "    deadend_start=True)\n",
    "\n",
    "# Slicing past the end would silently produce an empty batch, so fail loudly instead\n",
    "assert 0 <= maze_idx < len(inputs), f'maze_idx={maze_idx} is outside the {len(inputs)} loaded mazes'\n",
    "\n",
    "# Get latent series\n",
    "# Slices (not plain indexing) keep the leading batch dimension of size 1.\n",
    "# Named `maze_input` rather than `input` so the built-in input() is not shadowed.\n",
    "maze_input = inputs[maze_idx:maze_idx+1]\n",
    "solution = solutions[maze_idx:maze_idx+1]\n",
    "latents = model.input_to_latent(maze_input)\n",
    "latents = model.latent_forward(latents, maze_input, iters=iters)\n",
    "# The prediction is decoded from the last entry of the latent series\n",
    "output = model.latent_to_output(latents[-1])\n",
    "prediction = model.output_to_prediction(output, maze_input)\n",
    "\n",
    "plot_mazes(maze_input, prediction, solution)\n",
    "\n",
    "print(f'{latents.shape=}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Residual plot of the latent series; file_name is forwarded to plot_residuals as-is\n",
    "plot_residuals(\n",
    "    latents,\n",
    "    start_idx=start_idx,\n",
    "    fig_size=(6,6),\n",
    "    font_size=16,\n",
    "    file_name=f'outputs/residuals/{model.name()}_size-{maze_size}_idx-{maze_idx}',\n",
    ")\n",
    "\n",
    "# Original author's note: latents.shape = (400, 1, 128, 44, 44)\n",
    "# Collapse everything after the first axis so each row is one flattened latent\n",
    "flat_latents = latents.reshape(latents.shape[0], -1)\n",
    "# Largest per-row 2-norm, and largest absolute value of any single entry\n",
    "max_norm = torch.norm(flat_latents, dim=1).max().item()\n",
    "max_entry = latents.abs().max().item()\n",
    "print(f'{max_norm = :.2f}')\n",
    "print(f'{max_entry = :.2f}')\n",
    "# First ten flattened values of the first and of the last latent in the series\n",
    "print(f'{latents.reshape(latents.shape[0], -1)[0,:10] = }')\n",
    "print(f'{latents.reshape(latents.shape[0], -1)[-1,:10] = }')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib notebook\n",
    "# Animated variant of the call below, kept for reference:\n",
    "#plot_latents(latents, animate=True, duration=10, skip_frames=8, fig_size=(8, 8), file_name=f'{model.name()}_size-{maze_size}_idx-{maze_idx}')\n",
    "\n",
    "# Static (non-animated) plot of the latent series\n",
    "plot_latents(\n",
    "    latents,\n",
    "    animate=False,\n",
    "    fig_size=(8, 8),\n",
    "    font_size=12,\n",
    "    file_name=f'outputs/pca/{model.name()}_size-{maze_size}_idx-{maze_idx}',\n",
    ")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Shared threshold passed to get_betti_nums and plot_diagram in the cells below\n",
    "threshold = 0.2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Persistence diagram of the latent series with explicit embedding settings;\n",
    "# the second value returned by get_diagram is not used here\n",
    "diagram, _ = get_diagram(latents, embed_dim=0, max_homo=1, delay=1)\n",
    "\n",
    "# Betti numbers and diagram plot, both evaluated at the shared `threshold`\n",
    "betti_nums = get_betti_nums(diagram, threshold=threshold)\n",
    "print(f'{betti_nums = }')\n",
    "plot_diagram(diagram, threshold=threshold)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Same analysis, but only dtype=np.float32 is passed to get_diagram\n",
    "# (all other arguments are left at get_diagram's defaults); the second return value is unused\n",
    "diagram, _ = get_diagram(latents, dtype=np.float32)\n",
    "\n",
    "# Betti numbers and diagram plot, both evaluated at the shared `threshold`\n",
    "betti_nums = get_betti_nums(diagram, threshold=threshold)\n",
    "print(f'{betti_nums = }')\n",
    "plot_diagram(diagram, threshold=threshold)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "env",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
