{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "/home/pan/SuperEncoder\n"
     ]
    }
   ],
   "source": [
    "%cd ../..\n",
    "\n",
    "import os\n",
    "from omegaconf import OmegaConf\n",
    "import numpy as np\n",
    "\n",
    "import torch \n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from torch.utils.data import DataLoader\n",
    "\n",
    "import pennylane as qml \n",
    "\n",
    "from data_utils.aae_dataset import MNIST_AAE_Dataset, FractalDB_Dataset\n",
    "from models.state_generators import StateGenerator\n",
    "from utils import resize_and_norm, visual_compare, seed_everything, resize, norm_image\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "version: AAE_encoder\n",
      "device: cpu\n",
      "seed: 42\n",
      "n_epochs: 10\n",
      "noise_factor: 0\n",
      "noisy_probability: 0\n",
      "state_generator:\n",
      "  loss: DotProd\n",
      "  n_train_step: 100\n",
      "  aae_encoder:\n",
      "    q_device: default.qubit\n",
      "    n_qubits: 4\n",
      "    n_encoder_layers: 8\n",
      "    noisy: false\n",
      "    AmplitudeDamping: 0\n",
      "    DepolarizingChannel: 0\n",
      "  optimizer:\n",
      "    name: Adam\n",
      "    args:\n",
      "      lr: 0.01\n",
      "dataset:\n",
      "  root: ./FractalDB/fractaldb_cat60_ins1000\n",
      "  transform: ToTensor\n",
      "checkpoint:\n",
      "  logs: ./logs/superencoder/${version}\n",
      "  save_path: ./trained_models/superencoder_${version}_${state_generator.loss}.pt\n",
      "dataloader:\n",
      "  batch_size: 32\n",
      "  num_workers: 0\n",
      "  pin_memory: false\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Experiment configuration. CONFIG_NAME is the YAML file name; the `version`\n",
    "# field inside the YAML (printed below) is a separate value.\n",
    "CONFIG_NAME = \"AAE_encoder.yaml\"\n",
    "CONFIG_DIR = \"./configs/\"\n",
    "\n",
    "# Register Python's `eval` as the \"eval\" interpolation resolver.\n",
    "# NOTE(review): this lets a config file execute arbitrary Python - only load trusted configs.\n",
    "# Re-registering an existing resolver raises ValueError, so guard to keep the cell re-runnable.\n",
    "if not OmegaConf.has_resolver(\"eval\"):\n",
    "    OmegaConf.register_new_resolver(\"eval\", eval)\n",
    "\n",
    "config = OmegaConf.load(os.path.join(CONFIG_DIR, CONFIG_NAME))\n",
    "print(OmegaConf.to_yaml(config))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the MNIST test samples used on both the simulator and the real device.\n",
    "#\n",
    "# The samples are stored as a .npy file so every run (and every hardware job) uses\n",
    "# exactly the same inputs. Set REGENERATE_SAMPLES = True to rebuild the file from\n",
    "# the processed MNIST test set.\n",
    "REGENERATE_SAMPLES = False\n",
    "# File name spelling ('samles') matches the committed file.\n",
    "SAMPLES_PATH = \"scripts/eval_on_real_devices/test_samles_for_real_device.npy\"\n",
    "\n",
    "if REGENERATE_SAMPLES:\n",
    "    n_samples = 10  # NOTE(review): the committed .npy appears to hold 3 samples (see outputs below) - confirm\n",
    "    mnist_dir = r\"./mnist/processed\"\n",
    "    test_ds = MNIST_AAE_Dataset(os.path.join(mnist_dir, \"mnist_test.pt\"))\n",
    "    test_loader = DataLoader(test_ds, batch_size=n_samples)\n",
    "    samples = next(iter(test_loader))\n",
    "    regenerated_state = resize_and_norm(\n",
    "        samples[\"images\"].to(torch.float64), config.state_generator.aae_encoder.n_qubits\n",
    "    )\n",
    "    # Write to SAMPLES_PATH so the file that is saved is the one loaded below.\n",
    "    np.save(SAMPLES_PATH, regenerated_state.cpu().numpy())\n",
    "\n",
    "target_state = np.load(SAMPLES_PATH)\n",
    "\n",
    "# Sanity checks: one row of 2**n_qubits amplitudes per sample, each with unit L2 norm\n",
    "# (AmplitudeEmbedding below requires normalised features).\n",
    "state_dim = 2 ** config.state_generator.aae_encoder.n_qubits\n",
    "assert target_state.ndim == 2 and target_state.shape[1] == state_dim, target_state.shape\n",
    "assert np.allclose(np.linalg.norm(target_state, axis=1), 1.0, atol=1e-6), \"target states are not normalised\"\n",
    "target_state.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## AAE on simulator"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/pan/anaconda3/envs/pennylane/lib/python3.9/site-packages/pennylane/qnn/torch.py:432: UserWarning: Casting complex values to real discards the imaginary part (Triggered internally at /opt/conda/conda-bld/pytorch_1695392035629/work/aten/src/ATen/native/Copy.cpp:299.)\n",
      "  return res.type(x.dtype)\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "0.9999231364280262"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Compute fidelity on test samples\n",
    "#\n",
    "# AAE_StateGenerator presumably fits its encoder circuit to each target state and\n",
    "# returns the prepared state; report the mean fidelity |<target|prepared>|^2.\n",
    "from models.state_generators import AAE_StateGenerator\n",
    "\n",
    "state_dim = 2 ** config.state_generator.aae_encoder.n_qubits\n",
    "\n",
    "# The encoder weights are presumably initialised at random (the adjusted copy below\n",
    "# uses nn.init.uniform_), so seed torch to make this cell reproducible.\n",
    "torch.manual_seed(config.seed)\n",
    "aae_state_generator = AAE_StateGenerator(config)\n",
    "\n",
    "fidelity = []\n",
    "for sample in target_state:\n",
    "    result_state = aae_state_generator(torch.tensor(sample.reshape(1, state_dim)).float())\n",
    "    fidelity.append(qml.math.fidelity_statevector(result_state[0].detach(), sample).numpy())\n",
    "np.array(fidelity).mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "from loss import FidLossDotProd, FidLossDotProdAAE\n",
    "from models.batch_encoders import aae_encoder_for_train\n",
    "from models.superencoders import MLP, mat_fn\n",
    "\n",
    "class AAE_StateGenerator_Adjusted(torch.nn.Module):\n",
    "    \"\"\"Copy of AAE_StateGenerator whose forward pass returns the trained weights.\n",
    "\n",
    "    ``forward`` fits a PennyLane ``TorchLayer`` wrapping the AAE encoder circuit to a\n",
    "    target state and returns the layer's ``state_dict()``, i.e. ``{\"weights\": tensor}``\n",
    "    with shape ``(1, n_encoder_layers, n_qubits)``. The same layer is reused across\n",
    "    calls, so each call starts from the weights left by the previous one.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, config) -> None:\n",
    "        # Initialise nn.Module before assigning any attributes on it.\n",
    "        super().__init__()\n",
    "        self.config = config\n",
    "        aae_cfg = config.state_generator.aae_encoder\n",
    "        self.n_encoder_layers = aae_cfg.n_encoder_layers\n",
    "        self.n_qubits = aae_cfg.n_qubits\n",
    "        self.is_noisy = aae_cfg.get(\"noisy\", False)\n",
    "        amplitude_damping_prob = aae_cfg.AmplitudeDamping\n",
    "        depolarizing_prob = aae_cfg.DepolarizingChannel\n",
    "\n",
    "        # Both circuits ignore `inputs`; TorchLayer requires the argument to exist.\n",
    "        @qml.qnode(\n",
    "            qml.device(aae_cfg.q_device, wires=self.n_qubits),\n",
    "            interface=\"torch\",\n",
    "            diff_method=\"backprop\",\n",
    "        )\n",
    "        @qml.simplify\n",
    "        def aae_encoder(inputs, weights):\n",
    "            aae_encoder_for_train(weights, self.n_encoder_layers, self.n_qubits)\n",
    "            return qml.state()\n",
    "\n",
    "        # Same circuit with AmplitudeDamping and DepolarizingChannel inserted at every\n",
    "        # position (position=\"all\"), using the probabilities from the config.\n",
    "        @qml.qnode(\n",
    "            qml.device(aae_cfg.q_device, wires=self.n_qubits),\n",
    "            interface=\"torch\",\n",
    "            diff_method=\"backprop\",\n",
    "        )\n",
    "        @qml.simplify\n",
    "        @qml.transforms.insert(qml.AmplitudeDamping, amplitude_damping_prob, position=\"all\")\n",
    "        @qml.transforms.insert(qml.DepolarizingChannel, depolarizing_prob, position=\"all\")\n",
    "        def aae_encoder_noisy(inputs, weights):\n",
    "            aae_encoder_for_train(weights, self.n_encoder_layers, self.n_qubits)\n",
    "            return qml.state()\n",
    "\n",
    "        weight_shapes = {\"weights\": (1, self.n_encoder_layers, self.n_qubits)}\n",
    "        self.criterion = self.get_criterion(is_noisy=self.is_noisy)\n",
    "\n",
    "        circuit = aae_encoder_noisy if self.is_noisy else aae_encoder\n",
    "        # FIXME: is matrix_fn still in use? maybe delete it\n",
    "        self.matrix_fn = mat_fn(circuit)\n",
    "        self.aae_encoder = qml.qnn.TorchLayer(\n",
    "            circuit, weight_shapes, init_method=nn.init.uniform_\n",
    "        ).to(self.config.device)\n",
    "\n",
    "    def forward(self, target_state, verbose=False):\n",
    "        \"\"\"Fit the encoder to ``target_state`` and return its ``state_dict()``.\n",
    "\n",
    "        The returned tensors share storage with the live parameters, which later calls\n",
    "        keep training; clone them if you need to keep them.\n",
    "        \"\"\"\n",
    "        self.train_for_state(target_state, verbose=verbose)\n",
    "        return self.compute_state()\n",
    "\n",
    "    def compute_state(self):\n",
    "        \"\"\"Return the encoder layer's ``state_dict()`` (its weights), not a quantum state.\"\"\"\n",
    "        return self.aae_encoder.state_dict()\n",
    "\n",
    "    def get_criterion(self, is_noisy=False):\n",
    "        \"\"\"Build the loss selected by ``config.state_generator.loss``.\n",
    "\n",
    "        Raises:\n",
    "            ValueError: if the configured loss is neither \"DotProd\" nor \"MSE\".\n",
    "        \"\"\"\n",
    "        loss_name = self.config.state_generator.loss\n",
    "        if loss_name == \"DotProd\":\n",
    "            return FidLossDotProdAAE(is_noisy)\n",
    "        if loss_name == \"MSE\":\n",
    "            return nn.MSELoss()\n",
    "        raise ValueError(f\"Unsupported state_generator.loss {loss_name!r}; expected 'DotProd' or 'MSE'\")\n",
    "\n",
    "    def get_optimizer(self):\n",
    "        \"\"\"Create a fresh optimiser over the encoder weights from ``config.state_generator.optimizer``.\"\"\"\n",
    "        optimizer_cls = getattr(torch.optim, self.config.state_generator.optimizer.name)\n",
    "        return optimizer_cls(\n",
    "            self.aae_encoder.parameters(),\n",
    "            **self.config.state_generator.optimizer.args,\n",
    "        )\n",
    "\n",
    "    def train_for_state(self, target_state, verbose=False):\n",
    "        \"\"\"Run ``n_train_step`` optimiser steps fitting the encoder output to ``target_state``.\n",
    "\n",
    "        Args:\n",
    "            target_state: tensor of shape ``(1, 2**n_qubits)``.\n",
    "            verbose: if True, print the loss every 10 steps on a single updating line.\n",
    "        \"\"\"\n",
    "        expected_shape = (1, 2**self.n_qubits)\n",
    "        assert target_state.shape == expected_shape, (\n",
    "            f\"{target_state.shape} != {torch.Size(expected_shape)}\"\n",
    "        )\n",
    "        n_step = self.config.state_generator.n_train_step\n",
    "        target_state = target_state.to(device=self.config.device)\n",
    "        # The circuits ignore their inputs, but TorchLayer needs an input tensor.\n",
    "        # NOTE(review): TorchLayer appears to cast the circuit output to this tensor's dtype\n",
    "        # (float32); the UserWarning raised by AAE_StateGenerator above suggests the\n",
    "        # imaginary part of the state is discarded - harmless only if the encoder\n",
    "        # yields real amplitudes. TODO confirm.\n",
    "        dummy_inputs = torch.zeros((1,), dtype=torch.float32, device=self.config.device)\n",
    "\n",
    "        optimizer = self.get_optimizer()  # fresh optimiser state for every target\n",
    "        for step in range(n_step):\n",
    "            result_state = self.aae_encoder(dummy_inputs)\n",
    "            loss = self.criterion(result_state, target_state)\n",
    "            optimizer.zero_grad()\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "\n",
    "            if verbose and step % 10 == 0:\n",
    "                print(f\"loss: {loss.item()}\", end=\"\\r\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fit an encoder to each sample with the adjusted generator and keep only the weights\n",
    "state_dim = 2 ** config.state_generator.aae_encoder.n_qubits\n",
    "\n",
    "torch.manual_seed(config.seed)  # encoder weights start from nn.init.uniform_\n",
    "aae_state_generator = AAE_StateGenerator_Adjusted(config)\n",
    "\n",
    "weights = []\n",
    "for sample in target_state:\n",
    "    trained = aae_state_generator(torch.tensor(sample.reshape(1, state_dim)).float())[\"weights\"]\n",
    "    # forward() returns tensors that share storage with the parameters the next\n",
    "    # iteration keeps training, so store a copy.\n",
    "    weights.append(trained.clone())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[array([0.00715834, 0.00129875, 0.0165496 , 0.0074651 , 0.00785723,\n",
       "        0.15574598, 0.10029541, 0.0029763 , 0.0061979 , 0.3423506 ,\n",
       "        0.32948362, 0.00249901, 0.00858439, 0.00279158, 0.00087184,\n",
       "        0.00787475]),\n",
       " array([0.01242017, 0.00046098, 0.00340079, 0.01253895, 0.01129234,\n",
       "        0.26683494, 0.17390626, 0.00605853, 0.00192419, 0.30329919,\n",
       "        0.16859539, 0.01165452, 0.01246599, 0.0009506 , 0.00180619,\n",
       "        0.01239101]),\n",
       " array([0.00979425, 0.00724828, 0.00185094, 0.00754049, 0.00250771,\n",
       "        0.20381874, 0.28313112, 0.00614196, 0.00991925, 0.09474538,\n",
       "        0.32150681, 0.00209351, 0.00980056, 0.00915428, 0.02103416,\n",
       "        0.0097129 ])]"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Output distribution of the trained AAE circuits on an ideal simulator\n",
    "from models.batch_encoders import aae_encoder_for_train\n",
    "\n",
    "n_encoder_layers = config.state_generator.aae_encoder.n_encoder_layers\n",
    "n_qubits = config.state_generator.aae_encoder.n_qubits\n",
    "\n",
    "@qml.qnode(qml.device(\"default.qubit\", wires=n_qubits), interface=\"torch\")\n",
    "@qml.simplify\n",
    "def aae_probs_simulator(inputs):\n",
    "    \"\"\"Basis-state probabilities of the AAE circuit built from encoder weights ``inputs``.\"\"\"\n",
    "    aae_encoder_for_train(inputs, n_encoder_layers, n_qubits)\n",
    "    return qml.probs(wires=range(n_qubits))\n",
    "\n",
    "# Named differently from the hardware QNode `aae_encoder` defined further down, so the\n",
    "# two are never confused after out-of-order execution. Each weight tensor has a leading\n",
    "# batch dimension of 1, hence the [0].\n",
    "aae_simulator = [aae_probs_simulator(w)[0].numpy() for w in weights]\n",
    "aae_simulator"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[tensor([0.00794577, 0.00121687, 0.01618678, 0.00794577, 0.0074761 ,\n",
       "         0.15285872, 0.09997803, 0.00279222, 0.00655146, 0.34300753,\n",
       "         0.33285421, 0.00254862, 0.00762789, 0.00228051, 0.00078376,\n",
       "         0.00794577], requires_grad=True),\n",
       " tensor([0.01244814, 0.00045991, 0.00337941, 0.01244814, 0.01120051,\n",
       "         0.26640262, 0.17361377, 0.00602436, 0.0019064 , 0.30350881,\n",
       "         0.16920141, 0.01169471, 0.01244814, 0.00096245, 0.00185308,\n",
       "         0.01244814], requires_grad=True),\n",
       " tensor([0.00980394, 0.00725754, 0.00183695, 0.00754967, 0.00253792,\n",
       "         0.20402254, 0.28285545, 0.00610869, 0.00980394, 0.09545952,\n",
       "         0.32101587, 0.00210707, 0.00980394, 0.00893337, 0.02120743,\n",
       "         0.00969618], requires_grad=True)]"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Ideal reference distributions from exact AmplitudeEmbedding of each target state\n",
    "ideal_dev = qml.device(\"default.qubit\", wires=n_qubits)\n",
    "\n",
    "@qml.qnode(ideal_dev)\n",
    "def ideal_circuit(f=None):\n",
    "    \"\"\"Probabilities of every computational basis state after amplitude-embedding ``f``.\"\"\"\n",
    "    qml.AmplitudeEmbedding(features=f, wires=range(n_qubits))\n",
    "    return qml.probs(wires=range(n_qubits))\n",
    "\n",
    "ame_simulator = [ideal_circuit(f=target_state[idx]) for idx in range(len(target_state))]\n",
    "ame_simulator"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "8.536040554366184e-05"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from scipy import stats\n",
    "\n",
    "# stats.entropy(p, q) is the KL divergence D_KL(p || q): here the ideal AmplitudeEmbedding\n",
    "# distribution vs. the trained AAE circuit on the simulator (0 means identical).\n",
    "assert len(ame_simulator) == len(aae_simulator), (len(ame_simulator), len(aae_simulator))\n",
    "kl_simulator = [stats.entropy(p_ideal, p_aae) for p_ideal, p_aae in zip(ame_simulator, aae_simulator)]\n",
    "np.array(kl_simulator).mean()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## AAE on Real Device"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "# AAE circuit executed on IBM Quantum hardware\n",
    "from models.batch_encoders import aae_encoder_for_train\n",
    "\n",
    "n_encoder_layers = config.state_generator.aae_encoder.n_encoder_layers\n",
    "n_qubits = config.state_generator.aae_encoder.n_qubits\n",
    "\n",
    "IBMQ_BACKEND = \"ibm_osaka\"\n",
    "# NOTE(review): assumes IBM Quantum credentials are already stored in the local qiskit\n",
    "# account - never hard-code a token here. The device's default shot count is used; the\n",
    "# recorded probabilities appear to be multiples of 1/1024 - confirm before comparing runs.\n",
    "ibmq_dev = qml.device(name=\"qiskit.ibmq\", backend=IBMQ_BACKEND, wires=n_qubits)\n",
    "\n",
    "@qml.qnode(ibmq_dev)\n",
    "@qml.simplify\n",
    "def aae_encoder(inputs):\n",
    "    \"\"\"Estimated basis-state probabilities on IBMQ_BACKEND for AAE encoder weights ``inputs``.\"\"\"\n",
    "    aae_encoder_for_train(inputs, n_encoder_layers, n_qubits)\n",
    "    return qml.probs(wires=range(n_qubits))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "json_decoder.target_from_server_data:WARNING:2024-05-11 13:23:30,398: Definition of instruction switch_case is not found in the Qiskit namespace and GateConfig is not provided by the BackendConfiguration payload. Qiskit Gate model cannot be instantiated for this instruction and this instruction is silently excluded from the Target. Please add new gate class to Qiskit or provide GateConfig for this name.\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "[tensor([0.0283, 0.0469, 0.0439, 0.0234, 0.0234, 0.0889, 0.0723, 0.0488, 0.0352,\n",
       "         0.2578, 0.2207, 0.0439, 0.0244, 0.0156, 0.0107, 0.0156],\n",
       "        dtype=torch.float64),\n",
       " tensor([0.0303, 0.0664, 0.0508, 0.0303, 0.0508, 0.1152, 0.1045, 0.0850, 0.0361,\n",
       "         0.1797, 0.0762, 0.0312, 0.0410, 0.0400, 0.0273, 0.0352],\n",
       "        dtype=torch.float64),\n",
       " tensor([0.0352, 0.0547, 0.0645, 0.0518, 0.0410, 0.0830, 0.1338, 0.0566, 0.0488,\n",
       "         0.0791, 0.1543, 0.0400, 0.0430, 0.0391, 0.0439, 0.0312],\n",
       "        dtype=torch.float64)]"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Execute the trained circuits on the IBMQ backend, one call per trained weight set\n",
    "aae_ibmq = [aae_encoder(weights[idx]) for idx in range(len(target_state))]\n",
    "aae_ibmq"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.4146552901265916"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# KL divergence D_KL(ideal || hardware), averaged over every sample.\n",
    "# Hardware probabilities are finite-shot estimates: a zero count where the ideal\n",
    "# probability is non-zero makes the divergence infinite.\n",
    "assert len(aae_ibmq) == len(ame_simulator), (len(aae_ibmq), len(ame_simulator))\n",
    "kl_ibmq = [stats.entropy(p_ideal, p_hw) for p_ideal, p_hw in zip(ame_simulator, aae_ibmq)]\n",
    "np.array(kl_ibmq).mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "pennylane",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.19"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
