{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "import time\n",
    "\n",
    "sys.path.append(\"..\")\n",
    "\n",
    "from omegaconf import OmegaConf\n",
    "from tqdm import tqdm\n",
    "import numpy as np\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from torch.utils.data import DataLoader\n",
    "\n",
    "import pennylane as qml\n",
    "\n",
    "from data_utils.aae_dataset import MNIST_AAE_Dataset, FractalDB_Dataset\n",
    "from models.state_generators import StateGenerator, AAE_StateGenerator\n",
    "from utils import resize_and_norm, visual_compare, seed_everything, resize, norm_image"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "version: AAE_encoder\n",
      "device: cpu\n",
      "seed: 42\n",
      "n_epochs: 10\n",
      "noise_factor: 0\n",
      "noisy_probability: 0\n",
      "state_generator:\n",
      "  loss: DotProd\n",
      "  n_train_step: 100\n",
      "  aae_encoder:\n",
      "    q_device: default.qubit\n",
      "    n_qubits: 4\n",
      "    n_encoder_layers: 8\n",
      "  optimizer:\n",
      "    name: Adam\n",
      "    args:\n",
      "      lr: 0.01\n",
      "dataset:\n",
      "  root: ./FractalDB/fractaldb_cat60_ins1000\n",
      "  transform: ToTensor\n",
      "checkpoint:\n",
      "  logs: ./logs/superencoder/${version}\n",
      "  save_path: ./trained_models/superencoder_${version}_${state_generator.loss}.pt\n",
      "dataloader:\n",
      "  batch_size: 32\n",
      "  num_workers: 0\n",
      "  pin_memory: false\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Locations are relative to this notebook's directory.\n",
    "folder_path = \"../eval_datasets/\"\n",
    "version = \"AAE_encoder.yaml\"\n",
    "config_dir = r\"../configs/\"\n",
    "\n",
    "# Register the \"eval\" resolver only if it is missing, so the cell can be re-run in the\n",
    "# same kernel without swallowing unrelated ValueErrors.\n",
    "# NOTE(review): this resolver hands config strings to Python's eval(); only load trusted configs.\n",
    "if not OmegaConf.has_resolver(\"eval\"):\n",
    "    OmegaConf.register_new_resolver(\"eval\", eval)\n",
    "\n",
    "config = OmegaConf.load(os.path.join(config_dir, version))\n",
    "print(OmegaConf.to_yaml(config))\n",
    "\n",
    "# Apply the configured seed before any data is loaded or shuffled.\n",
    "seed_everything(config.seed)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load Synthetic Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'beta_a1b1': <torch.utils.data.dataloader.DataLoader at 0x79356fc8f9d0>,\n",
       " 'exponential_rate1': <torch.utils.data.dataloader.DataLoader at 0x79356fcddbd0>,\n",
       " 'uniform_low-1high1': <torch.utils.data.dataloader.DataLoader at 0x79356fcdd720>,\n",
       " 'uniform_low0high1': <torch.utils.data.dataloader.DataLoader at 0x79356fcdf940>,\n",
       " 'normal_mean0.3std0.5': <torch.utils.data.dataloader.DataLoader at 0x79356fcdf790>,\n",
       " 'lognormal_mean0std1': <torch.utils.data.dataloader.DataLoader at 0x79356fcdcfd0>}"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def get_file_names_in_folder(folder_path):\n",
    "    \"\"\"Return the bare names of every file under `folder_path`, subfolders included.\n",
    "\n",
    "    Only the file names are returned (no directory component), in the order\n",
    "    in which `os.walk` yields them.\n",
    "    \"\"\"\n",
    "    return [\n",
    "        name\n",
    "        for _dirpath, _dirnames, filenames in os.walk(folder_path)\n",
    "        for name in filenames\n",
    "    ]\n",
    "\n",
    "def norm_data(images):\n",
    "    \"\"\"Flatten each sample and scale it to unit L2 norm.\n",
    "\n",
    "    Args:\n",
    "        images: Tensor whose first dimension indexes samples; all remaining\n",
    "            dimensions are flattened into one.\n",
    "\n",
    "    Returns:\n",
    "        Tensor of shape (images.shape[0], -1) whose rows have unit L2 norm.\n",
    "        An all-zero row is returned as zeros instead of NaN.\n",
    "    \"\"\"\n",
    "    flat = images.reshape(images.shape[0], -1)\n",
    "    # F.normalize divides by max(norm, eps), so a zero row no longer produces 0/0 = NaN.\n",
    "    return F.normalize(flat, p=2, dim=1)\n",
    "\n",
    "file_names = get_file_names_in_folder(folder_path)\n",
    "\n",
    "# Map each distribution name (the file stem) to the path of its saved \".pt\" file.\n",
    "# Other entries (e.g. notebook checkpoints or OS metadata files) are skipped; they\n",
    "# previously made the name extraction raise IndexError.\n",
    "distribution_files = {}\n",
    "for file_name in file_names:\n",
    "    distribution_name, extension = os.path.splitext(file_name)\n",
    "    if extension != \".pt\":\n",
    "        continue\n",
    "    distribution_files[distribution_name] = os.path.join(folder_path, file_name)\n",
    "\n",
    "# Batch size of every loader, so one batch holds all samples evaluated per distribution.\n",
    "n_samples = 256\n",
    "test_loaders = {}\n",
    "for name, file_path in distribution_files.items():\n",
    "    test_dataset = MNIST_AAE_Dataset(file_path)\n",
    "    test_loader = DataLoader(test_dataset, shuffle=True, batch_size=n_samples, num_workers=0, pin_memory=True)\n",
    "    test_loaders[name] = test_loader\n",
    "\n",
    "test_loaders"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([256, 1, 4, 4])"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Sanity check: shape of one image batch drawn from the \"beta_a1b1\" loader.\n",
    "preview_batch = next(iter(test_loaders[\"beta_a1b1\"]))\n",
    "preview_batch[\"images\"].shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Evaluate AAE Fidelity"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "AAE_StateGenerator(\n",
      "  (criterion): FidLossDotProd()\n",
      "  (aae_encoder): <Quantum Torch Layer: func=aae_encoder>\n",
      ")\n",
      "Testing Dataset: beta_a1b1\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 0/256 [00:00<?, ?it/s]/root/miniconda3/envs/qenc/lib/python3.10/site-packages/pennylane/qnn/torch.py:432: UserWarning: Casting complex values to real discards the imaginary part (Triggered internally at ../aten/src/ATen/native/Copy.cpp:299.)\n",
      "  return res.type(x.dtype)\n",
      "100%|██████████| 256/256 [21:23<00:00,  5.01s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Testing Dataset: exponential_rate1\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 256/256 [21:19<00:00,  5.00s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Testing Dataset: uniform_low-1high1\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 256/256 [21:22<00:00,  5.01s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Testing Dataset: uniform_low0high1\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 256/256 [21:21<00:00,  5.01s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Testing Dataset: normal_mean0.3std0.5\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 256/256 [21:21<00:00,  5.01s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Testing Dataset: lognormal_mean0std1\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 256/256 [21:21<00:00,  5.01s/it]\n"
     ]
    }
   ],
   "source": [
    "aae_state_generator = AAE_StateGenerator(config)\n",
    "print(aae_state_generator)\n",
    "\n",
    "# Per-dataset stacks of generated states and of the normalized targets they should match.\n",
    "results = {}\n",
    "targets = {}\n",
    "for key, loader in test_loaders.items():\n",
    "    print(f\"Testing Dataset: {key}\")\n",
    "    # Only the first batch of each loader is evaluated.\n",
    "    samples = next(iter(loader))\n",
    "    target_states = samples[\"images\"]\n",
    "    dataset_results = []\n",
    "    dataset_targets = []\n",
    "    for target_state in tqdm(target_states):\n",
    "        target_state = norm_data(target_state).to(config.device)\n",
    "        result_state = aae_state_generator(target_state)\n",
    "        # Move to host memory first: .numpy() raises for tensors on a non-CPU device.\n",
    "        dataset_results.append(result_state.detach().cpu().numpy())\n",
    "        dataset_targets.append(target_state.detach().cpu().numpy())\n",
    "    results[key] = np.vstack(dataset_results)\n",
    "    targets[key] = np.vstack(dataset_targets)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'beta_a1b1': 0.9995194295199359,\n",
       " 'exponential_rate1': 0.9995656161350166,\n",
       " 'uniform_low-1high1': 0.999313114479679,\n",
       " 'uniform_low0high1': 0.9995835958027853,\n",
       " 'normal_mean0.3std0.5': 0.9992234132799052,\n",
       " 'lognormal_mean0std1': 0.9993504915472622}"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Mean fidelity between generated and target state vectors, per dataset.\n",
    "fidelities = {\n",
    "    name: qml.math.fidelity_statevector(results[name], targets[name]).mean()\n",
    "    for name in results\n",
    "}\n",
    "\n",
    "fidelities"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# exist_ok=True keeps this cell re-runnable once the output directory exists.\n",
    "os.makedirs(r\"./eval_result/\", exist_ok=True)\n",
    "# `results` and `targets` are dicts, so np.save stores them as pickled objects;\n",
    "# read them back with np.load(path, allow_pickle=True).item().\n",
    "np.save(r\"./eval_result/aae_result.npy\", results)\n",
    "np.save(r\"./eval_result/aae_targets.npy\", targets)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
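    "## Evaluate AAE Runtime\n",
    "\n",
    "The first timing cell measures the full `aae_state_generator(target_state)` call; the second measures only `aae_state_generator.compute_state()`."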
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "AAE_StateGenerator(\n",
      "  (criterion): FidLossDotProdAAE()\n",
      "  (aae_encoder): <Quantum Torch Layer: func=aae_encoder>\n",
      ")\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "{'resources': Resources(num_wires=4, num_gates=56, gate_types=defaultdict(<class 'int'>, {'RY': 32, 'CNOT': 24}), gate_sizes=defaultdict(<class 'int'>, {1: 32, 2: 24}), depth=24, shots=Shots(total_shots=None, shot_vector=())),\n",
       " 'num_observables': 1,\n",
       " 'num_diagonalizing_gates': 0,\n",
       " 'num_trainable_params': 0,\n",
       " 'num_device_wires': 4,\n",
       " 'device_name': 'default.qubit.torch',\n",
       " 'expansion_strategy': 'gradient',\n",
       " 'gradient_options': {},\n",
       " 'interface': 'torch',\n",
       " 'diff_method': 'backprop',\n",
       " 'gradient_fn': 'backprop'}"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "print(aae_state_generator)\n",
    "# Report the encoder circuit's specs, using the first image of the current `samples` batch as input.\n",
    "example_input = samples[\"images\"][0]\n",
    "qml.specs(aae_state_generator.aae_encoder)(example_input)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "AAE_StateGenerator(\n",
      "  (criterion): FidLossDotProdAAE()\n",
      "  (aae_encoder): <Quantum Torch Layer: func=aae_encoder>\n",
      ")\n",
      "Testing Dataset: beta_a1b1\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 8/8 [00:40<00:00,  5.01s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "5.008619200118119\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "aae_state_generator = AAE_StateGenerator(config)\n",
    "print(aae_state_generator)\n",
    "\n",
    "# Only the first dataset is timed. `results` / `targets` from the fidelity evaluation\n",
    "# are deliberately left untouched so they are not replaced by empty lists.\n",
    "n_timed_samples = 8\n",
    "timing_key, timing_loader = next(iter(test_loaders.items()))\n",
    "print(f\"Testing Dataset: {timing_key}\")\n",
    "samples = next(iter(timing_loader))\n",
    "\n",
    "durations = []\n",
    "for target_state in tqdm(samples[\"images\"][:n_timed_samples]):\n",
    "    # Normalization and device transfer stay outside the timed region.\n",
    "    target_state = norm_data(target_state).to(config.device)\n",
    "    start_time = time.perf_counter()\n",
    "    result_state = aae_state_generator(target_state)\n",
    "    end_time = time.perf_counter()\n",
    "    durations.append(end_time - start_time)\n",
    "\n",
    "# Mean wall-clock seconds per call.\n",
    "print(sum(durations) / len(durations))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "AAE_StateGenerator(\n",
      "  (criterion): FidLossDotProdAAE()\n",
      "  (aae_encoder): <Quantum Torch Layer: func=aae_encoder>\n",
      ")\n",
      "Testing Dataset: beta_a1b1\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 8/8 [00:00<00:00, 25.38it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.03894654815667309\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "aae_state_generator = AAE_StateGenerator(config)\n",
    "print(aae_state_generator)\n",
    "\n",
    "# Only the first dataset is timed. `results` / `targets` from the fidelity evaluation\n",
    "# are deliberately left untouched so they are not replaced by empty lists.\n",
    "n_timed_samples = 8\n",
    "timing_key, timing_loader = next(iter(test_loaders.items()))\n",
    "print(f\"Testing Dataset: {timing_key}\")\n",
    "samples = next(iter(timing_loader))\n",
    "\n",
    "durations = []\n",
    "# compute_state() takes no input, so the batch only determines how many calls are timed.\n",
    "for _ in tqdm(samples[\"images\"][:n_timed_samples]):\n",
    "    start_time = time.perf_counter()\n",
    "    result_state = aae_state_generator.compute_state()\n",
    "    end_time = time.perf_counter()\n",
    "    durations.append(end_time - start_time)\n",
    "\n",
    "# Mean wall-clock seconds per call.\n",
    "print(sum(durations) / len(durations))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "qenc",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
