{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 51,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append(\"../\")\n",
    "\n",
    "import os\n",
    "import logging\n",
    "import time\n",
    "from pprint import pprint\n",
    "from tqdm import tqdm\n",
    "\n",
    "import omegaconf\n",
    "import matplotlib.pylab as plt\n",
    "import torch\n",
    "from torch.utils.data import DataLoader\n",
    "\n",
    "import pennylane as qml\n",
    "\n",
    "\n",
    "from data_utils.aae_dataset import MNIST_AAE_Dataset\n",
    "from models.state_generators import AAE_StateGenerator\n",
    "from utils import resize_and_norm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 52,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "version: AAE_encoder_10qubits\n",
      "device: cpu\n",
      "seed: 42\n",
      "n_epochs: 10\n",
      "noise_factor: 0\n",
      "noisy_probability: 0\n",
      "state_generator:\n",
      "  loss: DotProd\n",
      "  n_train_step: 100\n",
      "  aae_encoder:\n",
      "    q_device: default.qubit\n",
      "    n_qubits: 10\n",
      "    n_encoder_layers: 60\n",
      "    noisy: false\n",
      "    AmplitudeDamping: 0\n",
      "    DepolarizingChannel: 0\n",
      "  optimizer:\n",
      "    name: Adam\n",
      "    args:\n",
      "      lr: 0.01\n",
      "dataset:\n",
      "  root: ./FractalDB/fractaldb_cat60_ins1000\n",
      "  transform: ToTensor\n",
      "checkpoint:\n",
      "  logs: ./logs/superencoder/${version}\n",
      "  save_path: ./trained_models/superencoder_${version}_${state_generator.loss}.pt\n",
      "dataloader:\n",
      "  batch_size: 32\n",
      "  num_workers: 0\n",
      "  pin_memory: false\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Experiment parameters: qubit count and the encoder depths to sweep.\n",
    "n_qubits = 10\n",
    "encoder_layers_range = [65, 80]  # [first, last] depth, both inclusive (the sweep loop uses last + 1)\n",
    "\n",
    "# The YAML config is selected by qubit count.\n",
    "config_path = f\"../configs/AAE_encoder_{n_qubits}qubits.yaml\"\n",
    "config = omegaconf.OmegaConf.load(config_path)\n",
    "\n",
    "# Echo the loaded settings so they are recorded in the notebook output.\n",
    "print(omegaconf.OmegaConf.to_yaml(config))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 53,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([1, 1024])\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "Text(0.5, 1.0, '0')"
      ]
     },
     "execution_count": 53,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAaAAAAGzCAYAAABpdMNsAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8g+/7EAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAjqElEQVR4nO3de3SV9Z3v8c9OSDa3ZMcQcpOAAeSiEJxSiTkqQyUlxDMuLOgC9ayC9cBAg2cgdWozx7udFcVzrJdDcXragq6KFxyBo8tiEUwcNGAJMIiXlNBYwuSCMpIdAtkJye/84bg7KaDPL9nJjx3er7WetcjzfPnu78OT8MmTvfPbPmOMEQAAfSzG9QAAgAsTAQQAcIIAAgA4QQABAJwggAAAThBAAAAnCCAAgBMEEADACQIIAOAEAQQAcIIAAvpIKBTS3XffrczMTA0aNEi5ubnaunWr67EAZwggoI8sWrRIjz/+uG677TY9+eSTio2N1fXXX68dO3a4Hg1wwsdipEDve//995Wbm6vHHntMd911lySptbVVkyZNUmpqqt577z3HEwJ9jzsgoA+88sorio2N1ZIlS8L7Bg4cqDvuuEMVFRWqra11OB3gBgEE9IG9e/dq3LhxSkxM7LJ/2rRpkqR9+/Y5mApwiwAC+kB9fb0yMjLO2P/Vvrq6ur4eCXCOAAL6wKlTp+T3+8/YP3DgwPBx4EJDAAF9YNCgQQqFQmfsb21tDR8HLjQEENAHMjIyVF9ff8b+r/ZlZmb29UiAcwQQ0AeuuOIK/eEPf1AwGOyyf9euXeHjwIWGAAL6wE033aSOjg794he/CO8LhUJau3atcnNzlZWV5XA6wI0BrgcALgS5ubm6+eabVVJSoqNHj2rs2LF69tln9emnn+pXv/qV6/EAJ1gJAegjra2tuvfee/Wb3/xGX3zxhXJycvTwww+roKDA9WiAEwQQAMAJngMCADhBAAEAnCCAAABOEEAAACcIIACAEwQQAMCJ8+4XUTs7O1VXV6eEhAT5fD7X4wAALBlj1NzcrMzMTMXEnPs+57wLoLq6OpYlAYB+oLa2ViNGjDjn8fMugBISEiRJ1+h6DVCc42kAALZOq1079Eb4//Nz6bUAWr16tR577DE1NDRoypQpevrpp8NvP/x1vvqx2wDFaYCPAAKAqPMf6+t809MovfIihJdeeknFxcW6//77tWfPHk2ZMkUFBQU6evRobzwcACAK9UoAPf7441q8eLFuv/12XXbZZXrmmWc0ePBg/frXvz6jNhQKKRgMdtkAAP1fxAOora1NlZWVys/P//ODxMQoPz9fFRUVZ9SXlpYqEAiEN16AAAAXhogH0Oeff66Ojg6lpaV12Z+WlqaGhoYz6ktKStTU1BTeamtrIz0SAOA85PxVcH6/X36/3/UYAIA+FvE7oJSUFMXGxqqxsbHL/sbGRqWnp0f64QAAUSriARQfH6+pU6dq27Zt4X2dnZ3atm2b8vLyIv1wAIAo1Ss/gisuLtbChQv17W9/W9OmTdMTTzyhlpYW3X777b3xcACAKNQrATR//nx99tlnuu+++9TQ0KArrrhCW7ZsOeOFCQCAC5fPGGNcD/GfBYNBBQIBzdAcVkIAgCh02rSrTJvV1NSkxMTEc9bxdgwAACcIIACAEwQQAMAJAggA4AQBBABwggACADhBAAEAnCCAAABOEEAAACcIIACAEwQQAMAJAggA4AQBBABwggACADhBAAEAnCCAAABOEEAAACcIIACAEwQQAMAJAggA4AQBBABwYoDrAdD/xQwZ4rk2dPVEq96f5cR7rm1NNVa9Tyd2eK4dWGf3pZT8SadVfVLFEc+1HfUNVr3N6dNW9UCkcAcEAHCCAAIAOEEAAQCcIIAAAE4QQAAAJwggAIATBBAAwAkCCADgBAEEAHCCAAIAOMFSPLDn81mVt+Rf7rm2bn6bVe+/zXnbc23u4ENWvdNjWzzX7mkdYdX7ubo8
q/pPx1ziuXbUP/utenf8we7fBYgU7oAAAE4QQAAAJwggAIATBBAAwAkCCADgBAEEAHCCAAIAOEEAAQCcIIAAAE4QQAAAJwggAIATrAUHazGTx1vV11us7/bsVb+26p0Q4733M5/NsOpdUXeJ59qbs/da9d44brNV/VPDJ3iu3VSdb9U74Y9/8lxrTp+26g18He6AAABORDyAHnjgAfl8vi7bhAnev3sDAFwYeuVHcJdffrneeuutPz/IAH7SBwDoqleSYcCAAUpPT++N1gCAfqJXngM6ePCgMjMzNXr0aN122206fPjwOWtDoZCCwWCXDQDQ/0U8gHJzc7Vu3Tpt2bJFa9asUU1Nja699lo1Nzeftb60tFSBQCC8ZWVlRXokAMB5KOIBVFhYqJtvvlk5OTkqKCjQG2+8oePHj+vll18+a31JSYmamprCW21tbaRHAgCch3r91QFJSUkaN26cqqurz3rc7/fL77d7D3sAQPTr9d8DOnHihA4dOqSMjIzefigAQBSJeADdddddKi8v16effqr33ntP3/ve9xQbG6tbbrkl0g8FAIhiEf8R3JEjR3TLLbfo2LFjGj58uK655hrt3LlTw4cPj/RDwZHDf5NsVb/iitc81w7xtVv1XvbJrZ5rO59Nteqd8dYhz7Uv3TbTqvdNK/dY1RdfdNBz7XOXFFj1DgQSPdd2HPt3q97A14l4AL344ouRbgkA6IdYCw4A4AQBBABwggACADhBAAEAnCCAAABOEEAAACcIIACAEwQQAMAJAggA4AQBBABwotffjgH9kLEr3xMc5bn2f71nt47ZpWtPe671vbvTqnfnAO9fHhe/+ZlV7zcWX25V/98Dn3iuPZHt/d9EksyINO/FrAWHCOIOCADgBAEEAHCCAAIAOEEAAQCcIIAAAE4QQAAAJwggAIATBBAAwAkCCADgBAEEAHCCpXhgbdTqA1b19ZsyPNdeFqyz6t3ReNRzreUKQjKnLZa0+cxuiZqPW7z/m0hSa+KHnmsHDj9l1bv9okGea2OtOgNfjzsgAIATBBAAwAkCCADgBAEEAHCCAAIAOEEAAQCcIIAAAE4QQAAAJwggAIATBBAAwAkCCADgBGvBwVpHMGhV7zt50nOt6bRcsa2zw67eQuzw4Z5r6xZcatV76UXrrOoTYuI91/p8lv+GPrtyIFK4AwIAOEEAAQCcIIAAAE4QQAAAJwggAIATBBAAwAkCCADgBAEEAHCCAAIAOEEAAQCcIIAAAE6wFhx6nTl9utd6+wZ4/xTunHa5Ve+PF3lff21p3lar3tcO/Nyq3u8b7Ln24Zz/Z9X70bsKPNe2T/wvVr0zN/7Rc+3p+gar3oh+3AEBAJywDqB33nlHN9xwgzIzM+Xz+bRp06Yux40xuu+++5SRkaFBgwYpPz9fBw8ejNS8AIB+wjqAWlpaNGXKFK1evfqsx1etWqWnnnpKzzzzjHbt2qUhQ4aooKBAra2tPR4WANB/WD8HVFhYqMLCwrMeM8boiSee0D333KM5c+ZIkp577jmlpaVp06ZNWrBgQc+mBQD0GxF9DqimpkYNDQ3Kz88P7wsEAsrNzVVFRcVZ/04oFFIwGOyyAQD6v4gGUEPDl69iSUtL67I/LS0tfOwvlZaWKhAIhLesrKxIjgQAOE85fxVcSUmJmpqawlttba3rkQAAfSCiAZSeni5Jamxs7LK/sbExfOwv+f1+JSYmdtkAAP1fRAMoOztb6enp2rZtW3hfMBjUrl27lJeXF8mHAgBEOetXwZ04cULV1dXhj2tqarRv3z4lJydr5MiRWrFihX7605/q0ksvVXZ2tu69915lZmbqxhtvjOTcAIAoZx1Au3fv1ne+853wx8XFxZKkhQsXat26dfrxj3+slpYWLVmyRMePH9c111yjLVu2aODAgZGbGvgPvnjvy+V8McH7cjaS9MuZv/BcO3NQh1XvzzuMVf3vTsZ5rr3CX2fV+9nLn/VcW5o826p3ZWCS59pR/zzE
qnfHHw5Z1eP8Yx1AM2bMkDHn/uLx+Xx66KGH9NBDD/VoMABA/+b8VXAAgAsTAQQAcIIAAgA4QQABAJwggAAAThBAAAAnCCAAgBMEEADACQIIAOAEAQQAcMJ6KR4gWplYu/p24/3L49Fjo616/1PldKv6uH/zvuZdW3q7Ve/CnAOeaxemvmvVO+3mZs+1b8huxfwRpawFF+24AwIAOEEAAQCcIIAAAE4QQAAAJwggAIATBBAAwAkCCADgBAEEAHCCAAIAOEEAAQCcYCkeRDXTftpzbVJ1yKr3D7cs8lw7+IjdOj8T1lZb1Xcc/cxz7YCLM616vzvvW55rTyzwviSQJP304jc8134+Z6hV7/otl3uuNf/6iVVvdXbY1aNbuAMCADhBAAEAnCCAAABOEEAAACcIIACAEwQQAMAJAggA4AQBBABwggACADhBAAEAnCCAAABOsBYcopppb/NcG1u216r3uH/xvr6bOe19TTpJ6s2Vxk4f+Ter+hGbvZ/nrhTv669J0p4FezzXLkvbbtV7/g9/6Ll2QvFgq96dzc1W9ege7oAAAE4QQAAAJwggAIATBBAAwAkCCADgBAEEAHCCAAIAOEEAAQCcIIAAAE4QQAAAJ1iKBxcOY+zKLZfXiVanPz3suTZ7U4JV75Wpt3iurf6bf7Lq/Ur+as+19w25wao3S/H0De6AAABOEEAAACesA+idd97RDTfcoMzMTPl8Pm3atKnL8UWLFsnn83XZZs+eHal5AQD9hHUAtbS0aMqUKVq9+tw/f509e7bq6+vD2wsvvNCjIQEA/Y/1ixAKCwtVWFj4tTV+v1/p6endHgoA0P/1ynNAZWVlSk1N1fjx47Vs2TIdO3bsnLWhUEjBYLDLBgDo/yIeQLNnz9Zzzz2nbdu26dFHH1V5ebkKCwvV0XH294AsLS1VIBAIb1lZWZEeCQBwHor47wEtWLAg/OfJkycrJydHY8aMUVlZmWbOnHlGfUlJiYqLi8MfB4NBQggALgC9/jLs0aNHKyUlRdXV1Wc97vf7lZiY2GUDAPR/vR5AR44c0bFjx5SRkdHbDwUAiCLWP4I7ceJEl7uZmpoa7du3T8nJyUpOTtaDDz6oefPmKT09XYcOHdKPf/xjjR07VgUFBREdHAAQ3awDaPfu3frOd74T/vir528WLlyoNWvWaP/+/Xr22Wd1/PhxZWZmatasWXr44Yfl9/sjNzUAJ2L+WGdVn/jRRM+1h2eftOod5/P+31dHVqpVb9+xL6zqTXubVT2+ZB1AM2bMkPmaRR3ffPPNHg0EALgwsBYcAMAJAggA4AQBBABwggACADhBAAEAnCCAAABOEEAAACcIIACAEwQQAMAJAggA4ETE3w8I5wdfXLzn2pYb/sqq95HvnnspprMZN877+mGfvWL3XlCpv6z0XMt6XT1nTp2yqo876f1zpdXYfT8cp07Pte2J3r8eJCk+zu6/Rj63uoc7IACAEwQQAMAJAggA4AQBBABwggACADhBAAEAnCCAAABOEEAAACcIIACAEwQQAMAJluKJErEpw6zqG+eO81ybdsufrHo/evG7VvWx8r4cywND/5tVbxnvy7Gg52LSU63qm0d6r80aYPf9cHW7z3Ot//AXVr07WkNW9ege7oAAAE4QQAAAJwggAIATBBAAwAkCCADgBAEEAHCCAAIAOEEAAQCcIIAAAE4QQAAAJwggAIATrAUXLYZdZFUemh30XPv06Jetev/vo/lW9Vu3/ZXn2tH/csKqt+nosKpHz7Rl2a1JaMae9Fwb54u16v1hW7r34i+arHqrk8+rvsAdEADACQIIAOAEAQQAcIIAAgA4QQABAJwggAAAThBAAAAnCCAAgBMEEADACQIIAOAES/FECTMwzqr+yszDnmszYuOtem/74zir+tGvWiyv8/4HVr3Rc7Ep3pfXaZg6yKr39y/f7rn2j+3tVr3/8YNCz7WXtHn/ekDf4Q4IAOAEAQQAcMIqgEpLS3XllVcqISFB
qampuvHGG1VVVdWlprW1VUVFRRo2bJiGDh2qefPmqbGxMaJDAwCin1UAlZeXq6ioSDt37tTWrVvV3t6uWbNmqaWlJVyzcuVKvfbaa9qwYYPKy8tVV1enuXPnRnxwAEB0s3oRwpYtW7p8vG7dOqWmpqqyslLTp09XU1OTfvWrX2n9+vW67rrrJElr167VxIkTtXPnTl111VVn9AyFQgqFQuGPg0Hv72MDAIhePXoOqKnpyzd5Sk5OliRVVlaqvb1d+fl/fsOyCRMmaOTIkaqoqDhrj9LSUgUCgfCWlZXVk5EAAFGi2wHU2dmpFStW6Oqrr9akSZMkSQ0NDYqPj1dSUlKX2rS0NDU0NJy1T0lJiZqamsJbbW1td0cCAESRbv8eUFFRkQ4cOKAdO3b0aAC/3y+/39+jHgCA6NOtO6Dly5fr9ddf19tvv60RI0aE96enp6utrU3Hjx/vUt/Y2Kj0dIv3bwcA9HtWAWSM0fLly7Vx40Zt375d2dnZXY5PnTpVcXFx2rZtW3hfVVWVDh8+rLy8vMhMDADoF6x+BFdUVKT169dr8+bNSkhICD+vEwgENGjQIAUCAd1xxx0qLi5WcnKyEhMTdeeddyovL++sr4ADAFy4rAJozZo1kqQZM2Z02b927VotWrRIkvSzn/1MMTExmjdvnkKhkAoKCvTzn/88IsPi/DAy5Qur+pMXZ3quHZqQYNW7s7nZqj5q+XyeSwdcMtKq9efXer8+gwvsfqn8zuS9nmufa5po1Ttpw1DPtZ2nWq16o29YBZAx5htrBg4cqNWrV2v16tXdHgoA0P+xFhwAwAkCCADgBAEEAHCCAAIAOEEAAQCcIIAAAE4QQAAAJwggAIATBBAAwIluvx0D+pavtd2qvrJhxDcX/Yd/z2yz6n1v9mtW9QsLlniuHdU6war3kA/qPNealpNWva3EeF8qR5J0UcCqvCPF+xJFNd/1vkSNJF0x+2PPtf/z4jesen/cFu+5dvWH0616Z7/xoefazna7z3H0De6AAABOEEAAACcIIACAEwQQAMAJAggA4AQBBABwggACADhBAAEAnCCAAABOEEAAACcIIACAE6wFFyU6P621qk9/1Puaai/83ylWvZckHbCqf6Xg/3iufXjiDVa9P6oY7bk2UGXV2oqx/Er692tDVvU/+vZWz7XzEz6x6h2IGei5trrdWPX+Hx8v8Fx7ySq73p3NzVb1OP9wBwQAcIIAAgA4QQABAJwggAAAThBAAAAnCCAAgBMEEADACQIIAOAEAQQAcIIAAgA4wVI8UcKE7JZuGfBhjefaNdvzrXpf+V+995akqwe2e65dP+Y1q94nRnvv3WbslnrpTUN8dt/7DY3xe66N8w2x6v188zDPtfdsn2fV+9LftHmuNXs+sOqN6McdEADACQIIAOAEAQQAcIIAAgA4QQABAJwggAAAThBAAAAnCCAAgBMEEADACQIIAOAEAQQAcIK14PqpjuZmz7UTH6+36n3XB39rVX98Rqvn2pysI1a956Xt8Vx7W8Ixq9696WSn9zXSJOn2P830XFuxc4JV77T3vddO/Fe7f0NTU+u5trOzw6o3oh93QAAAJ6wCqLS0VFdeeaUSEhKUmpqqG2+8UVVVVV1qZsyYIZ/P12VbunRpRIcGAEQ/qwAqLy9XUVGRdu7cqa1bt6q9vV2zZs1SS0tLl7rFixervr4+vK1atSqiQwMAop/Vc0Bbtmzp8vG6deuUmpqqyspKTZ8+Pbx/8ODBSk9Pj8yEAIB+qUfPATU1NUmSkpOTu+x//vnnlZKSokmTJqmkpEQnT548Z49QKKRgMNhlAwD0f91+FVxnZ6dWrFihq6++WpMmTQrvv/XWWzVq1ChlZmZq//79uvvuu1VVVaVXX331rH1KS0v14IMPdncMAECU6nYAFRUV6cCBA9qxY0eX/UuWLAn/efLkycrIyNDMmTN16NAhjRkz5ow+JSUlKi4uDn8cDAaVlZXV3bEAAFGiWwG0fPly
vf7663rnnXc0YsSIr63Nzc2VJFVXV581gPx+v/x+7+93DwDoH6wCyBijO++8Uxs3blRZWZmys7O/8e/s27dPkpSRkdGtAQEA/ZNVABUVFWn9+vXavHmzEhIS1NDQIEkKBAIaNGiQDh06pPXr1+v666/XsGHDtH//fq1cuVLTp09XTk5Or5wAACA6WQXQmjVrJH35y6b/2dq1a7Vo0SLFx8frrbfe0hNPPKGWlhZlZWVp3rx5uueeeyI2MACgf/AZY4zrIf6zYDCoQCCgGZqjAb441+PgLGKHD7eq77zE+++EtSfGW/UOJXn/HiqU2HsrT/lsv4os64fWt3uuHVjbZNe8/qjn0o7gCbverO92QTpt2lWmzWpqalJiYuI561gLDgDgBAEEAHCCAAIAOEEAAQCcIIAAAE4QQAAAJwggAIATBBAAwAkCCADgBAEEAHCi2+8HhAtXx2ef2f0Fi3rbT0ib+iGWvaMVi98gWnAHBABwggACADhBAAEAnCCAAABOEEAAACcIIACAEwQQAMAJAggA4AQBBABwggACADhBAAEAnCCAAABOEEAAACcIIACAEwQQAMAJAggA4AQBBABwggACADhBAAEAnCCAAABOEEAAACcIIACAEwQQAMAJAggA4AQBBABwggACADhBAAEAnCCAAABOEEAAACcIIACAEwQQAMAJAggA4AQBBABwggACADhBAAEAnCCAAABOWAXQmjVrlJOTo8TERCUmJiovL0+//e1vw8dbW1tVVFSkYcOGaejQoZo3b54aGxsjPjQAIPpZBdCIESP0yCOPqLKyUrt379Z1112nOXPm6MMPP5QkrVy5Uq+99po2bNig8vJy1dXVae7cub0yOAAguvmMMaYnDZKTk/XYY4/ppptu0vDhw7V+/XrddNNNkqRPPvlEEydOVEVFha666ipP/YLBoAKBgGZojgb44noyGgDAgdOmXWXarKamJiUmJp6zrtvPAXV0dOjFF19US0uL8vLyVFlZqfb2duXn54drJkyYoJEjR6qiouKcfUKhkILBYJcNAND/WQfQBx98oKFDh8rv92vp0qXauHGjLrvsMjU0NCg+Pl5JSUld6tPS0tTQ0HDOfqWlpQoEAuEtKyvL+iQAANHHOoDGjx+vffv2adeuXVq2bJkWLlyojz76qNsDlJSUqKmpKbzV1tZ2uxcAIHoMsP0L8fHxGjt2rCRp6tSp+v3vf68nn3xS8+fPV1tbm44fP97lLqixsVHp6enn7Of3++X3++0nBwBEtR7/HlBnZ6dCoZCmTp2quLg4bdu2LXysqqpKhw8fVl5eXk8fBgDQz1jdAZWUlKiwsFAjR45Uc3Oz1q9fr7KyMr355psKBAK64447VFxcrOTkZCUmJurOO+9UXl6e51fAAQAuHFYBdPToUX3/+99XfX29AoGAcnJy9Oabb+q73/2uJOlnP/uZYmJiNG/ePIVCIRUUFOjnP/95rwwOAIhuPf49oEjj94AAILr1+u8BAQDQEwQQAMAJAggA4AQBBABwggACADhBAAEAnCCAAABOEEAAACcIIACAE9arYfe2rxZmOK126bxaowEA4MVptUv68//n53LeBVBzc7MkaYfecDwJAKAnmpubFQgEznn8vFsLrrOzU3V1dUpISJDP5wvvDwaDysrKUm1t7deuLRTtOM/+40I4R4nz7G8icZ7GGDU3NyszM1MxMed+pue8uwOKiYnRiBEjznk8MTGxX1/8r3Ce/ceFcI4S59nf9PQ8v+7O5yu8CAEA4AQBBABwImoCyO/36/7775ff73c9Sq/iPPuPC+EcJc6zv+nL8zzvXoQAALgwRM0dEACgfyGAAABOEEAAACcIIACAEwQQAMCJqAmg1atX65JLLtHAgQOVm5ur999/3/VIEfXAAw/I5/N12SZMmOB6rB555513dMMNNygzM1M+n0+bNm3qctwYo/vuu08ZGRkaNGiQ8vPzdfDgQTfD9sA3neeiRYvOuLazZ892M2w3
lZaW6sorr1RCQoJSU1N14403qqqqqktNa2urioqKNGzYMA0dOlTz5s1TY2Ojo4m7x8t5zpgx44zruXTpUkcTd8+aNWuUk5MTXu0gLy9Pv/3tb8PH++paRkUAvfTSSyouLtb999+vPXv2aMqUKSooKNDRo0ddjxZRl19+uerr68Pbjh07XI/UIy0tLZoyZYpWr1591uOrVq3SU089pWeeeUa7du3SkCFDVFBQoNbW1j6etGe+6Twlafbs2V2u7QsvvNCHE/ZceXm5ioqKtHPnTm3dulXt7e2aNWuWWlpawjUrV67Ua6+9pg0bNqi8vFx1dXWaO3euw6nteTlPSVq8eHGX67lq1SpHE3fPiBEj9Mgjj6iyslK7d+/Wddddpzlz5ujDDz+U1IfX0kSBadOmmaKiovDHHR0dJjMz05SWljqcKrLuv/9+M2XKFNdj9BpJZuPGjeGPOzs7TXp6unnsscfC+44fP278fr954YUXHEwYGX95nsYYs3DhQjNnzhwn8/SWo0ePGkmmvLzcGPPltYuLizMbNmwI13z88cdGkqmoqHA1Zo/95XkaY8xf//Vfm7/7u79zN1Qvueiii8wvf/nLPr2W5/0dUFtbmyorK5Wfnx/eFxMTo/z8fFVUVDicLPIOHjyozMxMjR49WrfddpsOHz7seqReU1NTo4aGhi7XNRAIKDc3t99dV0kqKytTamqqxo8fr2XLlunYsWOuR+qRpqYmSVJycrIkqbKyUu3t7V2u54QJEzRy5Miovp5/eZ5fef7555WSkqJJkyappKREJ0+edDFeRHR0dOjFF19US0uL8vLy+vRannerYf+lzz//XB0dHUpLS+uyPy0tTZ988omjqSIvNzdX69at0/jx41VfX68HH3xQ1157rQ4cOKCEhATX40VcQ0ODJJ31un51rL+YPXu25s6dq+zsbB06dEj/8A//oMLCQlVUVCg2Ntb1eNY6Ozu1YsUKXX311Zo0aZKkL69nfHy8kpKSutRG8/U823lK0q233qpRo0YpMzNT+/fv1913362qqiq9+uqrDqe198EHHygvL0+tra0aOnSoNm7cqMsuu0z79u3rs2t53gfQhaKwsDD855ycHOXm5mrUqFF6+eWXdccddzicDD21YMGC8J8nT56snJwcjRkzRmVlZZo5c6bDybqnqKhIBw4ciPrnKL/Juc5zyZIl4T9PnjxZGRkZmjlzpg4dOqQxY8b09ZjdNn78eO3bt09NTU165ZVXtHDhQpWXl/fpDOf9j+BSUlIUGxt7xiswGhsblZ6e7miq3peUlKRx48apurra9Si94qtrd6FdV0kaPXq0UlJSovLaLl++XK+//rrefvvtLu/blZ6erra2Nh0/frxLfbRez3Od59nk5uZKUtRdz/j4eI0dO1ZTp05VaWmppkyZoieffLJPr+V5H0Dx8fGaOnWqtm3bFt7X2dmpbdu2KS8vz+FkvevEiRM6dOiQMjIyXI/SK7Kzs5Went7lugaDQe3atatfX1dJOnLkiI4dOxZV19YYo+XLl2vjxo3avn27srOzuxyfOnWq4uLiulzPqqoqHT58OKqu5zed59ns27dPkqLqep5NZ2enQqFQ317LiL6koZe8+OKLxu/3m3Xr1pmPPvrILFmyxCQlJZmGhgbXo0XMj370I1NWVmZqamrMu+++a/Lz801KSoo5evSo69G6rbm52ezdu9fs3bvXSDKPP/642bt3r/nTn/5kjDHmkUceMUlJSWbz5s1m//79Zs6cOSY7O9ucOnXK8eR2vu48m5ubzV133WUqKipMTU2Neeutt8y3vvUtc+mll5rW1lbXo3u2bNkyEwgETFlZmamvrw9vJ0+eDNcsXbrUjBw50mzfvt3s3r3b5OXlmby8PIdT2/um86yurjYPPfSQ2b17t6mpqTGbN282o0ePNtOnT3c8uZ2f/OQnpry83NTU1Jj9+/ebn/zkJ8bn85nf/e53xpi+u5ZREUDGGPP000+bkSNHmvj4
eDNt2jSzc+dO1yNF1Pz5801GRoaJj483F198sZk/f76prq52PVaPvP3220bSGdvChQuNMV++FPvee+81aWlpxu/3m5kzZ5qqqiq3Q3fD153nyZMnzaxZs8zw4cNNXFycGTVqlFm8eHHUffN0tvOTZNauXRuuOXXqlPnhD39oLrroIjN48GDzve99z9TX17sbuhu+6TwPHz5spk+fbpKTk43f7zdjx441f//3f2+amprcDm7pBz/4gRk1apSJj483w4cPNzNnzgyHjzF9dy15PyAAgBPn/XNAAID+iQACADhBAAEAnCCAAABOEEAAACcIIACAEwQQAMAJAggA4AQBBABwggACADhBAAEAnPj/M+wUU49hNyQAAAAASUVORK5CYII=",
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# MNIST samples: draw one test image. The shuffle is seeded from config.seed so\n",
    "# Restart & Run All evaluates the same image every time.\n",
    "mnist_dir = r\"../mnist/processed\"\n",
    "test_ds = MNIST_AAE_Dataset(os.path.join(mnist_dir, \"mnist_test.pt\"))\n",
    "test_loader = DataLoader(\n",
    "    test_ds,\n",
    "    batch_size=1,\n",
    "    shuffle=True,\n",
    "    generator=torch.Generator().manual_seed(config.seed),\n",
    ")\n",
    "\n",
    "samples = next(iter(test_loader))\n",
    "\n",
    "target_state = resize_and_norm(samples[\"images\"], config.state_generator.aae_encoder.n_qubits).to(config.device)\n",
    "print(target_state.shape)\n",
    "# resize_and_norm is given the config's qubit count, while later cells record the\n",
    "# notebook's `n_qubits`; this check catches a mismatch between the two.\n",
    "assert target_state.shape[-1] == 2**n_qubits\n",
    "\n",
    "size = int((2**n_qubits)**0.5)\n",
    "# Plot a detached CPU copy: target_state lives on config.device, which may be a GPU.\n",
    "plt.imshow(target_state.detach().cpu().view(size, size))\n",
    "plt.title(samples[\"digits\"].item());"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Evaluation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 54,
   "metadata": {},
   "outputs": [],
   "source": [
    "def eval_aae_fidelity(target_state, aae_config):\n",
    "    \"\"\"Run a freshly built AAE state generator on ``target_state`` and score it.\n",
    "\n",
    "    A new ``AAE_StateGenerator`` is constructed from ``aae_config`` on every call and\n",
    "    then applied to ``target_state``. Only that application is timed; construction\n",
    "    is not. NOTE(review): the timed call appears to include the generator's own\n",
    "    optimisation (the config carries ``n_train_step``) -- confirm in the model code.\n",
    "\n",
    "    Returns:\n",
    "        tuple: ``(fidelity, duration)`` -- ``qml.math.fidelity_statevector`` between\n",
    "        the generated and target states converted with ``.item()``, and the\n",
    "        wall-clock seconds spent in the generator call.\n",
    "    \"\"\"\n",
    "    generator = AAE_StateGenerator(aae_config)\n",
    "\n",
    "    t0 = time.perf_counter()\n",
    "    generated_state = generator(target_state)\n",
    "    elapsed = time.perf_counter() - t0\n",
    "\n",
    "    score = qml.math.fidelity_statevector(generated_state, target_state).item()\n",
    "    return score, elapsed"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 55,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 1/1 [02:34<00:00, 154.68s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[{'duration': 154.66575022300003,\n",
      "  'fidelity': 0.8934515715476462,\n",
      "  'n_qubits': 10,\n",
      "  'num_encoder_layers': 60}]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "# Sweep encoder depth over the inclusive range and record fidelity and runtime per depth.\n",
    "results = []\n",
    "first_depth, last_depth = encoder_layers_range\n",
    "for num_encoder_layers in tqdm(range(first_depth, last_depth + 1)):\n",
    "    # Derive a per-depth config instead of mutating the shared `config` in place,\n",
    "    # so `config` still matches the YAML file after this cell runs (and on re-runs).\n",
    "    depth_config = omegaconf.OmegaConf.merge(\n",
    "        config,\n",
    "        {\"state_generator\": {\"aae_encoder\": {\"n_encoder_layers\": num_encoder_layers}}},\n",
    "    )\n",
    "    # Reseed before every depth so each result is reproducible on its own and does not\n",
    "    # depend on which depths ran before it (assumes the generator draws from torch's RNG).\n",
    "    torch.manual_seed(config.seed)\n",
    "    fidelity, duration = eval_aae_fidelity(target_state, depth_config)\n",
    "    results.append(\n",
    "        {\n",
    "            \"n_qubits\": n_qubits,\n",
    "            \"num_encoder_layers\": num_encoder_layers,\n",
    "            \"fidelity\": fidelity,\n",
    "            \"duration\": duration\n",
    "        }\n",
    "    )\n",
    "\n",
    "pprint(results)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 56,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "\n",
    "# Persist the sweep as <log_dir>/<n_qubits>qubits.json; an existing file is overwritten.\n",
    "log_dir = r\"../logs/eval/aae_depth/\"\n",
    "os.makedirs(log_dir, exist_ok=True)\n",
    "file_path = os.path.join(log_dir, f\"{n_qubits}qubits.json\")\n",
    "with open(file_path, \"w\") as out_file:\n",
    "    json.dump(results, fp=out_file, indent=4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "qenc",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
