{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append(\"..\")\n",
    "\n",
    "import os\n",
    "import omegaconf\n",
    "import json\n",
    "import time\n",
    "\n",
    "from torch.utils.data import DataLoader\n",
    "\n",
    "import pennylane as qml \n",
    "\n",
    "from data_utils.aae_dataset import MNIST_AAE_Dataset\n",
    "from models.state_generators import AM_StateGenerator\n",
    "from utils import resize_and_norm, visual_compare, seed_everything, resize, norm_image\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "version: AM_encoder\n",
      "device: cpu\n",
      "seed: 42\n",
      "state_generator:\n",
      "  am_encoder:\n",
      "    q_device: default.qubit\n",
      "    n_qubits: 4\n",
      "    noisy: false\n",
      "    AmplitudeDamping: 0.0\n",
      "    DepolarizingChannel: 0.0\n",
      "dataset:\n",
      "  root: ./FractalDB/fractaldb_cat60_ins1000\n",
      "  transform: ToTensor\n",
      "dataloader:\n",
      "  batch_size: 1\n",
      "  num_workers: 0\n",
      "  pin_memory: false\n",
      "\n"
     ]
    }
   ],
   "source": [
    "config_path = rf\"../configs/AM_encoder.yaml\"\n",
    "n_qubits_to_eval = [4, 6, 8, 10, 12]\n",
    "config = omegaconf.OmegaConf.load(config_path)\n",
    "print(omegaconf.OmegaConf.to_yaml(config))\n",
    "\n",
    "seed_everything(config.seed)\n",
    "\n",
    "# get a MNIST sample\n",
    "mnist_dir = r\"../mnist/processed\"\n",
    "test_ds = MNIST_AAE_Dataset(os.path.join(mnist_dir, \"mnist_test.pt\"))\n",
    "test_loader = DataLoader(test_ds, batch_size=1, shuffle=True)\n",
    "\n",
    "samples = next(iter(test_loader))\n",
    "\n",
    "\n",
    "def eval_ae_depth(target_state, ae_config):\n",
    "    am_state_generator = AM_StateGenerator(ae_config)\n",
    "    start = time.perf_counter()\n",
    "    specs = qml.specs(am_state_generator.am_encoder)(target_state)\n",
    "    duration = time.perf_counter() - start\n",
    "    resources = specs[\"resources\"]\n",
    "    num_gates = resources.num_gates\n",
    "    depth = resources.depth\n",
    "\n",
    "    return depth, num_gates, duration\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/root/miniconda3/envs/qenc/lib/python3.10/site-packages/torchvision/transforms/functional.py:1603: UserWarning: The default value of the antialias parameter of all the resizing transforms (Resize(), RandomResizedCrop(), etc.) will change from None to True in v0.17, in order to be consistent across the PIL and Tensor backends. To suppress this warning, directly pass antialias=True (recommended, future default), antialias=None (current default, which means False for Tensors and True for PIL), or antialias=False (only works on Tensors - PIL will still use antialiasing). This also applies if you are using the inference transforms from the models weights: update the call to weights.transforms(antialias=True).\n",
      "  warnings.warn(\n"
     ]
    }
   ],
   "source": [
    "results = []\n",
    "for n_qubits in n_qubits_to_eval:\n",
    "    config.state_generator.am_encoder.n_qubits = n_qubits \n",
    "    target_state = resize_and_norm(samples[\"images\"], config.state_generator.am_encoder.n_qubits).to(config.device)\n",
    "    assert target_state.shape[-1] == 2**n_qubits\n",
    "    # print(target_state.shape)\n",
    "    depth, num_gates, duration = eval_ae_depth(target_state, config)\n",
    "    results.append(\n",
    "        {\n",
    "            \"n_qubits\": n_qubits,\n",
    "            \"depth\": depth,\n",
    "            \"num_gates\": num_gates,\n",
    "            \"duration\": duration\n",
    "        }\n",
    "    )\n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Results save to ../logs/eval/ae_depth.json\n"
     ]
    }
   ],
   "source": [
    "log_dir = r\"../logs/eval/\"\n",
    "file_path = os.path.join(log_dir, f\"ae_depth.json\")\n",
    "os.makedirs(log_dir, exist_ok=True)\n",
    "with open(file_path, \"w\") as f:\n",
    "    json.dump(results, f, indent=4)\n",
    "\n",
    "print(f\"Results save to {file_path}\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "qenc",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
