{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%cd ..\n",
    "\n",
    "import argparse\n",
    "import os\n",
    "import math\n",
    "\n",
    "import pennylane as qml\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "from torch.utils.data import DataLoader\n",
    "from tqdm import tqdm\n",
    "\n",
    "from data_utils.aae_dataset import MNIST_AAE_Dataset\n",
    "from data_utils.plot import plot_2d\n",
    "from loss import dot_product_loss, fidelity_loss\n",
    "from models.batch_encoders import batch_single_layer_aae_raw_circuit, aae_encoder_for_train\n",
    "from models.encoders import get_aae_encoder\n",
    "from models.state_generators import MLPStateGenerator\n",
    "from models.superencoders import (MLP, MLPV2, AutoChainMLP, ConvEncDec,\n",
    "                                  EncoderDecoderRNN, ManualChainMLP, mat_fn)\n",
    "from utils import append_log, resize_and_norm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def parse_args():\n",
    "    parser = argparse.ArgumentParser()\n",
    "    \n",
    "    # training args\n",
    "    parser.add_argument(\"dataset_path\", type=str)\n",
    "    parser.add_argument(\"--mode\", type=str, default=\"train\", help=\"train | test\")\n",
    "    parser.add_argument(\"--bsz\", type=int, default=32, help=\"Batch size\")\n",
    "    parser.add_argument(\"--epoch\", type=int, default=10, help=\"Number of epoches\")\n",
    "    parser.add_argument(\n",
    "        \"--loss\",\n",
    "        type=str,\n",
    "        default=\"mse\",\n",
    "        help=\"Type of loss function: mse | fidelity\",\n",
    "    )\n",
    "    parser.add_argument(\"--learning_rate\", type=float, default=3e-3)\n",
    "    parser.add_argument(\"--weight_decay\", type=float, default=1e-5)\n",
    "\n",
    "    # model args\n",
    "    parser.add_argument(\"--n_qubits\", type=int, default=4, help=\"Number of qubits\")\n",
    "    parser.add_argument(\"--n_encoder_layers\", type=int, default=8, help=\"Number of layers in AAE encoder\")\n",
    "    parser.add_argument(\"--n_ansatz_layers\", type=int, default=5, help=\"Number of layers in ansatz\")\n",
    "    parser.add_argument(\n",
    "        \"--model\", type=str, default=\"mlp\", help=\"Model type, mlp | conv | rnn | state\"\n",
    "    )\n",
    "    parser.add_argument(\"--image_channels\", type=int, default=1, help=\"Number of channels of the image, 1 for mnist, 3 for RGB image\")\n",
    "\n",
    "\n",
    "\n",
    "    # checkpoint args\n",
    "    parser.add_argument(\n",
    "        \"--save-path\",\n",
    "        type=str,\n",
    "        default=\"mnist/models/superencoder.pt\",\n",
    "        help=\"File path for saving trained model\",\n",
    "    )\n",
    "    parser.add_argument(\"--logs\", type=str, default=\"logs/superencoder/v0.0.1\")\n",
    "    parser.add_argument(\n",
    "        \"--save-figures\",\n",
    "        type=int,\n",
    "        default=0,\n",
    "        help=\"when using `--mode test` this option dertermines whether to save some figures for comparison\",\n",
    "    )\n",
    "\n",
    "    args = parser.parse_args()\n",
    "\n",
    "    # additional args\n",
    "    args.device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "    args.q_device = qml.device(\"default.qubit\", wires=args.n_qubits)\n",
    "\n",
    "\n",
    "    return args"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## utils "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def seed_everything(seed):\n",
    "    import random\n",
    "    import numpy as np\n",
    "    import torch\n",
    "\n",
    "    random.seed(seed)\n",
    "    np.random.seed(seed)\n",
    "    torch.manual_seed(seed)\n",
    "\n",
    "\n",
    "def get_super_encoder(args):\n",
    "    \"\"\"return super encoder instance according to args\"\"\"\n",
    "\n",
    "    image_size = math.floor((2**args.n_qubits) ** 0.5)\n",
    "    image_channels = args.image_channels\n",
    "    \n",
    "    \n",
    "    if args.model == \"mlp\":\n",
    "        superencoder = MLP(image_size*image_size*image_channels, args.n_qubits*args.n_encoder_layers).to(args.device)\n",
    "    elif args.model == \"mlpv2\":\n",
    "        superencoder = MLPV2(image_size*image_size*image_channels, args.n_qubits*args.n_encoder_layers).to(args.device)\n",
    "    # elif args.model == \"conv\":  \n",
    "    #     superencoder = ConvEncDec().to(DEVICE)  # class itself is fixed, refactor this class later\n",
    "    elif args.model.startswith(\"chain\"):\n",
    "        superencoder = _init_chain_model(args.model)\n",
    "    elif args.model == \"rnn\":\n",
    "        superencoder = EncoderDecoderRNN(image_size*image_size*image_channels, hidden_size=256, num_aae_layers=args.n_encoder_layers).to(args.device)\n",
    "    elif args.model == \"state\":\n",
    "        superencoder = _init_state_model(args)\n",
    "    else:\n",
    "        raise NotImplementedError(f\"Unsupported model: {args.model}\")\n",
    "    \n",
    "    def _init_chain_model(model_name):\n",
    "        if model_name == \"chain\":\n",
    "            return ManualChainMLP(2**args.n_qubits, batch_single_layer_aae_raw_circuit, args.n_encoder_layers).to(args.device)\n",
    "        if model_name == \"chain-non-uniform\":\n",
    "            return ManualChainMLP(\n",
    "                2**args.n_qubits, batch_single_layer_aae_raw_circuit, args.n_encoder_layers, is_uniform=False\n",
    "            ).to(args.device)\n",
    "\n",
    "        raise ValueError(f\"Invalid model name: {model_name}\")\n",
    "\n",
    "\n",
    "    def _init_state_model(args):\n",
    "        from models.batch_encoders import aae_encoder_for_train\n",
    "\n",
    "        @qml.qnode(args.q_device, interface=\"torch\", diff_method=\"backprop\")\n",
    "        @qml.simplify\n",
    "        def batch_aae_encoder(weights):\n",
    "            aae_encoder_for_train(weights, args.n_encoder_layers, args.n_qubits)\n",
    "            return qml.state()\n",
    "\n",
    "        superencoder = MLPStateGenerator(in_dim=image_size*image_size*image_channels, out_dim=args.n_qubits*args.n_encoder_layers, quantum_circuit=batch_aae_encoder).to(args.device)\n",
    "        return superencoder\n",
    "    \n",
    "    return superencoder\n",
    "\n",
    "\n",
    "def get_loss_function(args):\n",
    "    \"\"\"return a callable to be used as loss function\"\"\"\n",
    "\n",
    "    if args.model == \"state\":\n",
    "        loss_fn = dot_product_loss\n",
    "    else:\n",
    "        from models.batch_encoders import aae_encoder_for_train\n",
    "\n",
    "        @qml.qnode(args.q_device, interface=\"torch\", diff_method=\"backprop\")\n",
    "        @qml.simplify\n",
    "        def batch_aae_encoder(weights):\n",
    "            aae_encoder_for_train(weights, args.n_encoder_layers, args.n_qubits)\n",
    "            return qml.state()\n",
    "        \n",
    "        if args.loss == \"mse\":\n",
    "            loss_fn = torch.nn.MSELoss()\n",
    "        elif args.loss == \"fidelity\":\n",
    "            loss_fn = fidelity_loss(batch_aae_encoder, \"matrix\")\n",
    "        # elif args.loss == \"hybrid\":     # support this later\n",
    "        #     loss = loss_fn(pred, encod_params) + F_loss(\n",
    "        #         images, encod_params, pred\n",
    "        #     )\n",
    "        else:\n",
    "            raise NotImplementedError(f\"Unsupported loss: {args.loss}\")\n",
    "        \n",
    "    return loss_fn"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train(args, loader: DataLoader):\n",
    "    aae_encoder = get_aae_encoder(args.n_qubits)\n",
    "\n",
    "    loss_path = os.path.join(args.logs, \"loss.txt\")\n",
    "    # FIXME: rewrite log if exist, too ugly\n",
    "    \n",
    "    if os.path.exists(loss_path):\n",
    "        f = open(loss_path, \"w\")\n",
    "        f.close()\n",
    "\n",
    "    superencoder = get_super_encoder(args)\n",
    "\n",
    "    loss_fn = get_loss_function(args)\n",
    "    optimizer = torch.optim.Adam(superencoder.parameters(), lr=args.learning_rate, weight_decay=args.weight_decay)\n",
    "    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.n_epochs)\n",
    "\n",
    "    for epoch in range(args.epoch):\n",
    "        batch_idx = 0\n",
    "        epoch_loss_sum = 0\n",
    "        for batch in tqdm(loader, leave=False):\n",
    "            images = batch[\"images\"]\n",
    "            images = resize_and_norm(images, args.n_qubits).to(args.device)\n",
    "            # plot_2d(images[0], figname=f\"input{batch_idx}.pdf\")\n",
    "\n",
    "            try:  # in some setting encoder params doesn't exist \n",
    "                encod_params = batch[\"encoder_params\"][\"weights\"].to(args.device)\n",
    "                encod_params = encod_params.reshape(encod_params.size(0), -1)\n",
    "                # plot_2d(encod_params[0], figname=f\"encod_parames{batch_idx}.pdf\")\n",
    "            except:\n",
    "                pass\n",
    "\n",
    "            if args.model == \"conv\":\n",
    "                images = images.view(-1, 1, 4, 4)\n",
    "                encod_params = encod_params.view(-1, 1, 4, 8)\n",
    "\n",
    "            pred = superencoder(images)\n",
    "\n",
    "            # fidelity loss need 3 args, better unify input formate of loss functions\n",
    "            if args.model == \"state\":\n",
    "                loss = loss_fn(pred[0], images)\n",
    "            else:\n",
    "                # TODO: support other loss, unify the interface of loss functions\n",
    "                raise NotImplementedError(f\"Unsupported loss function: {args.loss}\")\n",
    "                \n",
    "            optimizer.zero_grad()\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "            batch_idx += 1\n",
    "            epoch_loss_sum += loss.item()\n",
    "\n",
    "            append_log(loss_path, loss.item())\n",
    "        \n",
    "        scheduler.step()\n",
    "        \n",
    "        print(\n",
    "            f\"Epoch [{epoch+1}/{args.epoch}], Loss: {epoch_loss_sum/(batch_idx+1):.4f}\"\n",
    "        )\n",
    "\n",
    "    return superencoder"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def test(args, model, loader: DataLoader, dataset: MNIST_AAE_Dataset):\n",
    "    batch_idx = 0\n",
    "\n",
    "    def _save_fig(\n",
    "        save_figures: bool, fig, name_prefix, batch_idx, format_suffix: str = \"pdf\"\n",
    "    ):\n",
    "        if save_figures and batch_idx < 10:\n",
    "            plot_2d(fig, figname=f\"{name_prefix}{batch_idx}.{format_suffix}\")\n",
    "\n",
    "    for batch in tqdm(loader, leave=False):\n",
    "        images = batch[\"images\"]\n",
    "        _save_fig(args.save_figures, images[0], \"image\", batch_idx, format_suffix=\"svg\")\n",
    "        images = resize_and_norm(images).to(args.device)\n",
    "        _save_fig(args.save_figures, images[0], \"input\", batch_idx)\n",
    "\n",
    "        try:  # in e2e training, encoder_params doesn't exist\n",
    "            encod_params = batch[\"encoder_params\"][\"weights\"].to(args.device)\n",
    "            encod_params = encod_params.reshape(encod_params.size(0), -1)\n",
    "        except:\n",
    "            pass  # it should be fine to pass, no need for encod_params to pred\n",
    "\n",
    "        # TODO: make these shapes configurable\n",
    "        if args.model == \"conv\":\n",
    "            images = images.view(-1, 1, 4, 4)\n",
    "            encod_params = encod_params.view(-1, 1, 4, 8)\n",
    "\n",
    "        pred = model(images)\n",
    "        ######## uncomment this line for comparison ############\n",
    "        # loss = dot_product_loss(pred, images)\n",
    "        # print(loss)\n",
    "        ######## uncomment this line for comparison ############\n",
    "\n",
    "        # Currently the state model outputs both state vector and parameters\n",
    "        sample_pred = pred[0][0] if len(pred) == 2 else pred[0]\n",
    "        _save_fig(args.save_figures, sample_pred, \"pred\", batch_idx)\n",
    "\n",
    "        # assert len(pred) == len(batch[\"index\"])\n",
    "        for i, idx in enumerate(batch[\"index\"]):\n",
    "            if args.model == \"state\":\n",
    "                assert len(pred) == 2\n",
    "                p = pred[1][i].view(32)\n",
    "            else:\n",
    "                p = pred[i].view(32)\n",
    "            dataset.save_pred_params(idx, p)\n",
    "\n",
    "        # plot_2d(pred[0], figname=f\"pred{batch_idx}.pdf\")\n",
    "        batch_idx += 1\n",
    "\n",
    "    dataset.save_dataset_to_disk()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def main():\n",
    "    args = parse_args()\n",
    "\n",
    "    dataset = MNIST_AAE_Dataset(args.dataset_path)\n",
    "    loader = DataLoader(dataset, batch_size=args.bsz, shuffle=True)\n",
    "\n",
    "    if not os.path.exists(args.logs):\n",
    "        os.makedirs(args.logs)\n",
    "\n",
    "    if args.mode == \"train\":\n",
    "        model = train(args, loader)\n",
    "        if args.model.startswith(\"chain\") or args.model == \"state\":\n",
    "            # Fix:\n",
    "            #  Can't pickle <function batch_single_layer_aae_raw_circuit at 0x7f19cf2f3250>:\n",
    "            #  it's not the same object as models.batch_encoders.batch_single_layer_aae_raw_circuit\n",
    "            # torch.save(model.state_dict, args.save_path)\n",
    "            model.save(args.save_path)\n",
    "        else:\n",
    "            torch.save(model, args.save_path)\n",
    "    elif args.mode == \"test\":\n",
    "        model = get_super_encoder(args)\n",
    "        if args.model.startswith(\"chain\"):\n",
    "            model.load(args.save_path)\n",
    "        elif args.model == \"state\":\n",
    "            model.load(args.save_path)\n",
    "        else:\n",
    "            model = torch.load(args.save_path)\n",
    "        test(args, model, loader, dataset)\n",
    "    else:\n",
    "        raise NotImplementedError(\"Please set --mode <train | test>\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "qenc",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "name": "python",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
