{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append(\"..\")\n",
    "\n",
    "import os\n",
    "from omegaconf import OmegaConf\n",
    "from tqdm import tqdm\n",
    "import numpy as np\n",
    "\n",
    "import torch \n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from torch.utils.data import DataLoader\n",
    "\n",
    "import pennylane as qml \n",
    "\n",
    "from data_utils.aae_dataset import MNIST_AAE_Dataset, FractalDB_Dataset\n",
    "from models.state_generators import StateGenerator, AAE_StateGenerator, AM_StateGenerator\n",
    "from utils import resize_and_norm, visual_compare, seed_everything, resize, norm_image\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Eval AAE fidelity on noisy simulator first to determine how noisy it is."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 57,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "version: AAE_encoder_noisy\n",
      "device: cpu\n",
      "seed: 42\n",
      "n_epochs: 10\n",
      "noise_factor: 0\n",
      "noisy_probability: 0\n",
      "state_generator:\n",
      "  loss: DotProd\n",
      "  n_train_step: 100\n",
      "  loss_args:\n",
      "    avg: true\n",
      "    noisy: true\n",
      "  aae_encoder:\n",
      "    q_device: default.mixed\n",
      "    n_qubits: 4\n",
      "    n_encoder_layers: 16\n",
      "    noisy: true\n",
      "    AmplitudeDamping: 0.005\n",
      "    DepolarizingChannel: 0.005\n",
      "  optimizer:\n",
      "    name: Adam\n",
      "    args:\n",
      "      lr: 0.01\n",
      "dataset:\n",
      "  root: ./FractalDB/fractaldb_cat60_ins1000\n",
      "  transform: ToTensor\n",
      "checkpoint:\n",
      "  logs: ./logs/superencoder/${version}\n",
      "  save_path: ./trained_models/superencoder_${version}_${state_generator.loss}.pt\n",
      "dataloader:\n",
      "  batch_size: 32\n",
      "  num_workers: 0\n",
      "  pin_memory: false\n",
      "\n"
     ]
    }
   ],
   "source": [
    "folder_path = \"../eval_datasets/\"\n",
    "version = \"AAE_encoder_noisy.yaml\"\n",
    "config_dir = r\"../configs/\"\n",
    "try:\n",
    "    OmegaConf.register_new_resolver(\"eval\", eval)\n",
    "except ValueError:\n",
    "    pass\n",
    "config = OmegaConf.load(os.path.join(config_dir, version))\n",
    "print(OmegaConf.to_yaml(config))\n",
    "\n",
    "seed_everything(config.seed)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 58,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'beta_a1b1': <torch.utils.data.dataloader.DataLoader at 0x74fad04f81c0>,\n",
       " 'exponential_rate1': <torch.utils.data.dataloader.DataLoader at 0x74facbe01450>,\n",
       " 'uniform_low-1high1': <torch.utils.data.dataloader.DataLoader at 0x74facb44e050>,\n",
       " 'uniform_low0high1': <torch.utils.data.dataloader.DataLoader at 0x74facb44e4a0>,\n",
       " 'normal_mean0.3std0.5': <torch.utils.data.dataloader.DataLoader at 0x74facb44de10>,\n",
       " 'lognormal_mean0std1': <torch.utils.data.dataloader.DataLoader at 0x74facb44f1c0>}"
      ]
     },
     "execution_count": 58,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def get_file_names_in_folder(folder_path):\n",
    "    file_names = []\n",
    "    for root, dirs, files in os.walk(folder_path):\n",
    "        for file in files:\n",
    "            file_names.append(file)\n",
    "    return file_names\n",
    "\n",
    "def norm_data(images):\n",
    "    images = images.reshape(images.shape[0], -1)\n",
    "    norms = torch.norm(images, p=2, dim=1, keepdim=True)\n",
    "    images = images / norms\n",
    "    return images\n",
    "\n",
    "file_names = get_file_names_in_folder(folder_path)\n",
    "distribution_files = {}\n",
    "\n",
    "for file_name in file_names:\n",
    "    distribution_name = file_name.split(\".pt\")[-2]\n",
    "    distribution_files[distribution_name] = folder_path +file_name\n",
    "\n",
    "\n",
    "n_samples = 1\n",
    "# normed_data = np.zeros((n_samples*16,len(distribution_files)))\n",
    "test_loaders = {}\n",
    "for idx, name in enumerate(distribution_files.keys()):\n",
    "    test_dataset = MNIST_AAE_Dataset(distribution_files[name])\n",
    "    test_loader = DataLoader(test_dataset, shuffle=True, batch_size=n_samples, num_workers=0, pin_memory=True)\n",
    "    test_loaders[name] = test_loader\n",
    "\n",
    "test_loaders"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 59,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([1, 1, 4, 4])"
      ]
     },
     "execution_count": 59,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "next(iter(test_loaders[\"beta_a1b1\"]))[\"images\"].shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 60,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "AAE_StateGenerator(\n",
      "  (criterion): FidLossDotProdAAE()\n",
      "  (aae_encoder): <Quantum Torch Layer: func=aae_encoder_noisy>\n",
      ")\n",
      "Testing Dataset: beta_a1b1\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 1/1 [00:57<00:00, 57.49s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Testing Dataset: exponential_rate1\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 1/1 [00:57<00:00, 57.10s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Testing Dataset: uniform_low-1high1\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 1/1 [00:57<00:00, 57.03s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Testing Dataset: uniform_low0high1\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 1/1 [00:57<00:00, 57.35s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Testing Dataset: normal_mean0.3std0.5\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 1/1 [00:57<00:00, 57.69s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Testing Dataset: lognormal_mean0std1\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 1/1 [00:57<00:00, 57.13s/it]\n"
     ]
    }
   ],
   "source": [
    "aae_state_generator = AAE_StateGenerator(config)\n",
    "print(aae_state_generator)\n",
    "results = {}\n",
    "targets = {}\n",
    "for key, loader in test_loaders.items():\n",
    "    results[key] = []\n",
    "    targets[key] = []\n",
    "    print(f\"Testing Dataset: {key}\")\n",
    "    samples = next(iter(loader))\n",
    "    target_states = samples[\"images\"]\n",
    "    for target_state in tqdm(target_states):\n",
    "        target_state = norm_data(target_state).to(config.device)\n",
    "        result_state = aae_state_generator(target_state)\n",
    "        results[key].append(result_state.detach().numpy())\n",
    "        targets[key].append(target_state.detach().numpy())\n",
    "    if aae_state_generator.is_noisy:\n",
    "        results[key] = np.stack(results[key])\n",
    "    else:\n",
    "        results[key] = np.vstack(results[key])\n",
    "    targets[key] = np.vstack(targets[key])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 61,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'beta_a1b1': 0.41255587228208584, 'exponential_rate1': 0.46088000080203395, 'uniform_low-1high1': 0.39403896817987794, 'uniform_low0high1': 0.4256582863028516, 'normal_mean0.3std0.5': 0.4168016048307988, 'lognormal_mean0std1': 0.45828239570339485}\n",
      "0.42803618801684057\n"
     ]
    }
   ],
   "source": [
    "fidelities = {}\n",
    "for key in results.keys():\n",
    "    if aae_state_generator.is_noisy:\n",
    "        target_mat = qml.math.dm_from_state_vector(targets[key])\n",
    "        fidelities[key] = qml.math.fidelity(results[key], target_mat).mean()\n",
    "    else:\n",
    "        fidelities[key] = qml.math.fidelity_statevector(results[key], targets[key]).mean()\n",
    "\n",
    "print(fidelities)\n",
    "print(np.mean([fidelity for key, fidelity in fidelities.items()]))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### log: AAE\n",
    "\n",
    "prob=0.05: fidelity ~= 0.15\n",
    "\n",
    "prob=0.02: fidelity ~= 0.3\n",
    "\n",
    "prob=0.01: fidelity ~= 0.45\n",
    "\n",
    "prob=0.005: fidelity ~= 0.64\n",
    "\n",
    "prob=0.001: fidelity ~= 0.89\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## (Optional) Eval AM fidelity in noisy condition\n",
    "\n",
    "<!-- **!!! not working, pennglane can't directly insert into StatePrep !!!** -->"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "version: AM_encoder\n",
      "device: cpu\n",
      "seed: 42\n",
      "state_generator:\n",
      "  am_encoder:\n",
      "    q_device: default.mixed\n",
      "    n_qubits: 4\n",
      "    noisy: true\n",
      "    AmplitudeDamping: 0.001\n",
      "    DepolarizingChannel: 0.001\n",
      "dataset:\n",
      "  root: ./FractalDB/fractaldb_cat60_ins1000\n",
      "  transform: ToTensor\n",
      "dataloader:\n",
      "  batch_size: 32\n",
      "  num_workers: 0\n",
      "  pin_memory: false\n",
      "\n"
     ]
    }
   ],
   "source": [
    "folder_path = \"../eval_datasets/\"\n",
    "version = \"AM_encoder.yaml\"\n",
    "config_dir = r\"../configs/\"\n",
    "try:\n",
    "    OmegaConf.register_new_resolver(\"eval\", eval)\n",
    "except ValueError:\n",
    "    pass\n",
    "config = OmegaConf.load(os.path.join(config_dir, version))\n",
    "print(OmegaConf.to_yaml(config))\n",
    "\n",
    "seed_everything(config.seed)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "AM_StateGenerator()\n",
      "Testing Dataset: beta_a1b1\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 1/1 [00:00<00:00,  7.15it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Testing Dataset: exponential_rate1\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 1/1 [00:00<00:00,  9.63it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Testing Dataset: uniform_low-1high1\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 1/1 [00:00<00:00,  5.49it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Testing Dataset: uniform_low0high1\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 1/1 [00:00<00:00,  9.71it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Testing Dataset: normal_mean0.3std0.5\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 1/1 [00:00<00:00,  5.58it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Testing Dataset: lognormal_mean0std1\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 1/1 [00:00<00:00,  9.77it/s]\n"
     ]
    }
   ],
   "source": [
    "am_state_generator = AM_StateGenerator(config)\n",
    "print(am_state_generator)\n",
    "results = {}\n",
    "targets = {}\n",
    "for key, loader in test_loaders.items():\n",
    "    results[key] = []\n",
    "    targets[key] = []\n",
    "    print(f\"Testing Dataset: {key}\")\n",
    "    samples = next(iter(loader))\n",
    "    target_states = samples[\"images\"]\n",
    "    for target_state in tqdm(target_states):\n",
    "        target_state = norm_data(target_state).to(config.device)\n",
    "        result_state = am_state_generator(target_state.view(-1))\n",
    "        results[key].append(result_state.detach().numpy())\n",
    "        targets[key].append(target_state.detach().numpy())\n",
    "    if am_state_generator.is_noisy:\n",
    "        results[key] = np.stack(results[key])\n",
    "    else:\n",
    "        results[key] = np.vstack(results[key])\n",
    "    targets[key] = np.vstack(targets[key])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'beta_a1b1': 0.9431416025148134, 'exponential_rate1': 0.9466213886518494, 'uniform_low-1high1': 0.8881742211677074, 'uniform_low0high1': 0.9456365519120868, 'normal_mean0.3std0.5': 0.8899342402302953, 'lognormal_mean0std1': 0.9443397672878207}\n",
      "0.9263079619607621\n"
     ]
    }
   ],
   "source": [
    "fidelities = {}\n",
    "for key in results.keys():\n",
    "    if am_state_generator.is_noisy:\n",
    "        target_mat = qml.math.dm_from_state_vector(targets[key])\n",
    "        fidelities[key] = qml.math.fidelity(results[key], target_mat).mean()\n",
    "    else:\n",
    "        fidelities[key] = qml.math.fidelity_statevector(results[key], targets[key]).mean()\n",
    "\n",
    "print(fidelities)\n",
    "print(np.mean([fidelity for key, fidelity in fidelities.items()]))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### log: AM\n",
    "\n",
    "prob=0.05: fidelity ~= 0.09\n",
    "\n",
    "prob=0.02: fidelity ~= 0.29\n",
    "\n",
    "prob=0.01: fidelity ~= 0.50\n",
    "\n",
    "prob=0.005: fidelity ~= 0.69\n",
    "\n",
    "prob=0.001: fidelity ~= 0.92\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "qenc",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
