{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "/home/pan/SuperEncoder\n"
     ]
    }
   ],
   "source": [
    "%cd ../..\n",
    "\n",
    "import os\n",
    "from omegaconf import OmegaConf\n",
    "import numpy as np\n",
    "\n",
    "import torch \n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from torch.utils.data import DataLoader\n",
    "\n",
    "import pennylane as qml \n",
    "\n",
    "from data_utils.aae_dataset import MNIST_AAE_Dataset, FractalDB_Dataset\n",
    "from models.state_generators import StateGenerator\n",
    "from utils import resize_and_norm, visual_compare, seed_everything, resize, norm_image"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "version: v0.0.7\n",
      "device: cpu\n",
      "seed: 42\n",
      "n_epochs: 10\n",
      "noise_factor: 0\n",
      "noisy_probability: 0\n",
      "state_generator:\n",
      "  loss: DotProd\n",
      "  aae_encoder:\n",
      "    q_device: default.qubit\n",
      "    n_qubits: 4\n",
      "    n_encoder_layers: 8\n",
      "  super_encoder:\n",
      "    arch: MLP\n",
      "    input_size: ${eval:\"int(int(2**${..aae_encoder.n_qubits})**0.5)\"}\n",
      "    in_dim: ${eval:\"${.input_size}*${.input_size}\"}\n",
      "    out_dim: ${eval:\"${..aae_encoder.n_qubits} * ${..aae_encoder.n_encoder_layers}\"}\n",
      "dataset:\n",
      "  root: ./FractalDB/fractaldb_cat60_ins1000\n",
      "  transform: ToTensor\n",
      "checkpoint:\n",
      "  logs: ./logs/superencoder/${version}\n",
      "  save_path: ./trained_models/superencoder_${version}_${state_generator.loss}.pt\n",
      "dataloader:\n",
      "  batch_size: 32\n",
      "  num_workers: 0\n",
      "  pin_memory: false\n",
      "optimizer:\n",
      "  lr: 0.003\n",
      "  weight_decay: 1.0e-05\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Experiment configuration. The \"eval\" resolver lets the YAML derive sizes\n",
    "# (input_size, in_dim, out_dim) from n_qubits / n_encoder_layers. It runs Python's\n",
    "# eval() on config strings, so only load trusted config files.\n",
    "config_dir = \"./configs/\"\n",
    "config_file = \"v0.0.7.yaml\"\n",
    "\n",
    "# replace=True makes this cell idempotent: re-running it re-registers the resolver\n",
    "# instead of raising ValueError, so no try/except is needed.\n",
    "OmegaConf.register_new_resolver(\"eval\", eval, replace=True)\n",
    "config = OmegaConf.load(os.path.join(config_dir, config_file))\n",
    "print(OmegaConf.to_yaml(config))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load MNIST evaluation inputs (target states produced by `resize_and_norm`).\n",
    "# They are cached on disk so the simulator and real-device runs use identical inputs.\n",
    "# If the cache is missing it is regenerated from the MNIST test split and saved to\n",
    "# the SAME path it is loaded from. \"samles\" (sic) is the existing file name.\n",
    "samples_path = \"scripts/eval_on_real_devices/test_samles_for_real_device.npy\"\n",
    "\n",
    "if os.path.exists(samples_path):\n",
    "    target_state = np.load(samples_path)\n",
    "else:\n",
    "    # NOTE(review): the saved file appears to hold 3 states (see outputs below); confirm n_samples.\n",
    "    n_samples = 10\n",
    "    mnist_dir = \"./mnist/processed\"\n",
    "    test_ds = MNIST_AAE_Dataset(os.path.join(mnist_dir, \"mnist_test.pt\"))\n",
    "    samples = next(iter(DataLoader(test_ds, batch_size=n_samples)))\n",
    "    target_state = (\n",
    "        resize_and_norm(samples[\"images\"].to(torch.float64), config.state_generator.aae_encoder.n_qubits)\n",
    "        .cpu()  # .numpy() requires a CPU tensor\n",
    "        .numpy()\n",
    "    )\n",
    "    np.save(samples_path, target_state)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
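    "## SuperEncoder on a Simulator\n",
    "\n",
    "Predict AAE circuit parameters with the pretrained SuperEncoder, then compare the simulated output with the targets (state fidelity) and with exact `AmplitudeEmbedding` (KL divergence of the measurement distributions)."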
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pretrained SuperEncoder: predicts AAE circuit parameters for each target state.\n",
    "# NOTE(review): this file differs from config.checkpoint.save_path (..._DotProd.pt); confirm it is the intended checkpoint.\n",
    "checkpoint_path = \"trained_models/superencoder_v0.0.7_state.pt\"\n",
    "\n",
    "superEncoder = StateGenerator(config=config).to(config.device)\n",
    "# NOTE(review): strict=False presumably tolerates missing/unexpected keys; confirm the checkpoint matches.\n",
    "superEncoder.load(checkpoint_path, config.device, strict=False)\n",
    "\n",
    "weights = []   # predicted circuit parameters, one tensor per sample\n",
    "fidelity = []  # fidelity between the simulated circuit state and its target\n",
    "with torch.no_grad():  # inference only: no autograd graph needed\n",
    "    for state in target_state:\n",
    "        # Batch of one, float32, on the same device as the model.\n",
    "        x = torch.as_tensor(state, dtype=torch.float32).unsqueeze(0).to(config.device)\n",
    "        encoder_params = superEncoder(x).clone()\n",
    "        result_state = superEncoder.qc(encoder_params)\n",
    "        fidelity.append(qml.math.fidelity_statevector(result_state, [state]).detach().cpu().numpy())\n",
    "        weights.append(encoder_params)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.9747084901341632"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Mean fidelity of the simulated SuperEncoder states to their targets\n",
    "np.mean(fidelity)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "from models.batch_encoders import aae_encoder_for_train\n",
    "\n",
    "n_encoder_layers = config.state_generator.aae_encoder.n_encoder_layers\n",
    "n_qubits = config.state_generator.aae_encoder.n_qubits\n",
    "\n",
    "# Simulator backend taken from the config (aae_encoder.q_device = default.qubit for v0.0.7).\n",
    "sim_dev = qml.device(config.state_generator.aae_encoder.q_device, wires=n_qubits)\n",
    "\n",
    "@qml.qnode(sim_dev, interface=\"torch\")\n",
    "@qml.simplify\n",
    "def aae_encoder(inputs):\n",
    "    \"\"\"Run the AAE ansatz with parameters `inputs` and return basis-state probabilities.\n",
    "\n",
    "    `inputs` are circuit parameters, e.g. the SuperEncoder outputs stored in `weights`.\n",
    "    \"\"\"\n",
    "    aae_encoder_for_train(inputs, n_encoder_layers, n_qubits)\n",
    "    return qml.probs(wires=range(n_qubits))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[tensor([[0.0036, 0.0006, 0.0117, 0.0013, 0.0109, 0.1567, 0.1504, 0.0042, 0.0018,\n",
       "          0.3152, 0.3156, 0.0044, 0.0014, 0.0065, 0.0026, 0.0132]],\n",
       "        dtype=torch.float64, grad_fn=<ViewBackward0>),\n",
       " tensor([[5.2699e-03, 5.0977e-05, 1.0225e-03, 7.8186e-03, 1.5440e-02, 2.9733e-01,\n",
       "          2.4351e-01, 8.5500e-03, 2.2028e-03, 2.2828e-01, 1.6771e-01, 2.7273e-04,\n",
       "          4.1257e-03, 1.2680e-04, 1.3095e-03, 1.6980e-02]], dtype=torch.float64,\n",
       "        grad_fn=<ViewBackward0>),\n",
       " tensor([[9.6804e-05, 8.2020e-03, 6.6553e-04, 5.9737e-03, 9.2412e-04, 1.8384e-01,\n",
       "          2.8206e-01, 3.4214e-03, 7.0129e-03, 5.9205e-02, 4.1102e-01, 6.2853e-03,\n",
       "          4.6801e-03, 1.1461e-02, 1.2752e-02, 2.4066e-03]], dtype=torch.float64,\n",
       "        grad_fn=<ViewBackward0>)]"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Simulated output distribution of the SuperEncoder circuit for each sample\n",
    "superEncoder_simulator = [aae_encoder(params) for params in weights]\n",
    "superEncoder_simulator"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[tensor([0.00794577, 0.00121687, 0.01618678, 0.00794577, 0.0074761 ,\n",
       "         0.15285872, 0.09997803, 0.00279222, 0.00655146, 0.34300753,\n",
       "         0.33285421, 0.00254862, 0.00762789, 0.00228051, 0.00078376,\n",
       "         0.00794577], requires_grad=True),\n",
       " tensor([0.01244814, 0.00045991, 0.00337941, 0.01244814, 0.01120051,\n",
       "         0.26640262, 0.17361377, 0.00602436, 0.0019064 , 0.30350881,\n",
       "         0.16920141, 0.01169471, 0.01244814, 0.00096245, 0.00185308,\n",
       "         0.01244814], requires_grad=True),\n",
       " tensor([0.00980394, 0.00725754, 0.00183695, 0.00754967, 0.00253792,\n",
       "         0.20402254, 0.28285545, 0.00610869, 0.00980394, 0.09545952,\n",
       "         0.32101587, 0.00210707, 0.00980394, 0.00893337, 0.02120743,\n",
       "         0.00969618], requires_grad=True)]"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Reference: exact AmplitudeEmbedding of each target state on the default.qubit simulator\n",
    "ideal_dev = qml.device(\"default.qubit\", wires=n_qubits)\n",
    "\n",
    "@qml.qnode(ideal_dev)\n",
    "def ideal_circuit(f=None):\n",
    "    \"\"\"Amplitude-encode features `f` on n_qubits wires; return basis-state probabilities.\"\"\"\n",
    "    qml.AmplitudeEmbedding(features=f, wires=range(n_qubits))\n",
    "    return qml.probs(wires=range(n_qubits))\n",
    "\n",
    "ame_simulator = [ideal_circuit(f=state) for state in target_state]\n",
    "ame_simulator"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.06084239432185898"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from scipy import stats\n",
    "\n",
    "# KL divergence D(ideal || simulated) per sample: stats.entropy(pk, qk) returns the\n",
    "# relative entropy of pk w.r.t. qk (both normalized). [0] drops the batch dimension.\n",
    "entropy = [\n",
    "    stats.entropy(ideal_probs, sim_probs[0].detach())\n",
    "    for ideal_probs, sim_probs in zip(ame_simulator, superEncoder_simulator)\n",
    "]\n",
    "np.mean(entropy)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
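    "## SuperEncoder on a Real Device\n",
    "\n",
    "Run the same SuperEncoder-predicted circuits on IBM Quantum hardware (`ibm_osaka`) and compare the measured distributions with the ideal `AmplitudeEmbedding` ones."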
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 61,
   "metadata": {},
   "outputs": [],
   "source": [
    "from models.batch_encoders import aae_encoder_for_train\n",
    "\n",
    "n_encoder_layers = config.state_generator.aae_encoder.n_encoder_layers\n",
    "n_qubits = config.state_generator.aae_encoder.n_qubits\n",
    "\n",
    "# IBM Quantum hardware backend. No credentials are passed here; NOTE(review):\n",
    "# presumably the plugin uses a saved IBM Quantum account. Never hard-code tokens.\n",
    "ibmq_dev = qml.device(name='qiskit.ibmq', backend='ibm_osaka', wires=n_qubits)\n",
    "\n",
    "\n",
    "# NOTE: this rebinds `aae_encoder`, shadowing the simulator QNode defined earlier.\n",
    "# Re-run that earlier cell before re-running any simulator cell, or those cells\n",
    "# will execute on the real device instead.\n",
    "@qml.qnode(ibmq_dev)\n",
    "@qml.simplify\n",
    "def aae_encoder(inputs):\n",
    "    \"\"\"Same AAE circuit as the simulator QNode, executed on `ibmq_dev`; returns probabilities.\"\"\"\n",
    "    aae_encoder_for_train(inputs, n_encoder_layers, n_qubits)\n",
    "    return qml.probs(wires=range(n_qubits))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 62,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "json_decoder.target_from_server_data:WARNING:2024-05-11 13:14:01,897: Definition of instruction switch_case is not found in the Qiskit namespace and GateConfig is not provided by the BackendConfiguration payload. Qiskit Gate model cannot be instantiated for this instruction and this instruction is silently excluded from the Target. Please add new gate class to Qiskit or provide GateConfig for this name.\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "[tensor([0.0225, 0.0361, 0.0312, 0.0293, 0.0508, 0.1094, 0.1201, 0.0215, 0.0273,\n",
       "         0.2510, 0.1836, 0.0303, 0.0195, 0.0254, 0.0205, 0.0215],\n",
       "        dtype=torch.float64, grad_fn=<SqueezeBackward0>),\n",
       " tensor([0.0615, 0.0439, 0.0352, 0.0518, 0.0596, 0.1328, 0.1270, 0.0527, 0.0498,\n",
       "         0.1367, 0.0850, 0.0459, 0.0244, 0.0244, 0.0303, 0.0391],\n",
       "        dtype=torch.float64, grad_fn=<SqueezeBackward0>),\n",
       " tensor([0.0723, 0.0479, 0.0596, 0.0469, 0.0908, 0.1025, 0.1152, 0.0566, 0.0469,\n",
       "         0.0479, 0.1182, 0.0498, 0.0439, 0.0430, 0.0303, 0.0283],\n",
       "        dtype=torch.float64, grad_fn=<SqueezeBackward0>)]"
      ]
     },
     "execution_count": 62,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Measured output distribution on the IBM device for each sample\n",
    "superEncoder_ibmq = [aae_encoder(params) for params in weights]\n",
    "superEncoder_ibmq"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 65,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.44930970830751216"
      ]
     },
     "execution_count": 65,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# KL divergence D(ideal || hardware) per sample, averaged over samples\n",
    "entropy = [\n",
    "    stats.entropy(ideal_probs, hw_probs.detach().numpy())\n",
    "    for ideal_probs, hw_probs in zip(ame_simulator, superEncoder_ibmq)\n",
    "]\n",
    "np.mean(entropy)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "pennylane",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.19"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
