{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Standard library\n",
    "import argparse\n",
    "import random\n",
    "\n",
    "# Third-party\n",
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "from torch.optim.lr_scheduler import CosineAnnealingLR\n",
    "from torchpack.datasets.dataset import Dataset\n",
    "\n",
    "import torchquantum as tq\n",
    "from torchquantum.plugins import (\n",
    "    qiskit_assemble_circs,\n",
    "    tq2qiskit,\n",
    "    tq2qiskit_initialize,\n",
    "    tq2qiskit_measurement,\n",
    ")\n",
    "\n",
    "# Data (see gen_data): cos(theta)|0...0> + e^(j * phi) sin(theta)|1...1>\n",
    "\n",
    "# Use the GPU when torch can see one, otherwise fall back to the CPU.\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "def gen_data(L, N):\n",
    "    omega_0 = np.zeros([2**L], dtype=\"complex_\")\n",
    "    omega_0[0] = 1 + 0j\n",
    "\n",
    "    omega_1 = np.zeros([2**L], dtype=\"complex_\")\n",
    "    omega_1[-1] = 1 + 0j\n",
    "\n",
    "    states = np.zeros([N, 2**L], dtype=\"complex_\")\n",
    "\n",
    "    thetas = 2 * np.pi * np.random.rand(N)\n",
    "    phis = 2 * np.pi * np.random.rand(N)\n",
    "\n",
    "    for i in range(N):\n",
    "        states[i] = (\n",
    "            np.cos(thetas[i]) * omega_0\n",
    "            + np.exp(1j * phis[i]) * np.sin(thetas[i]) * omega_1\n",
    "        )\n",
    "\n",
    "    X = np.sin(2 * thetas) * np.cos(phis)\n",
    "    # X = np.sin(thetas + phis)\n",
    "    # X = 2* thetas + phis\n",
    "\n",
    "    return states, X"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "class RegressionDataset:\n",
    "    \"\"\"One split of synthetic regression data, generated eagerly by gen_data.\n",
    "\n",
    "    Samples are drawn from numpy's global RNG at construction time. Indexing returns\n",
    "    {\"states\": state vector, \"Xlabel\": regression label} for one sample.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, split, n_samples, n_wires):\n",
    "        self.split = split\n",
    "        self.n_samples = n_samples\n",
    "        self.n_wires = n_wires\n",
    "        states, labels = gen_data(self.n_wires, self.n_samples)\n",
    "        self.states = states\n",
    "        self.Xlabel = labels\n",
    "\n",
    "    def __getitem__(self, index: int):\n",
    "        return {\"states\": self.states[index], \"Xlabel\": self.Xlabel[index]}\n",
    "\n",
    "    def __len__(self) -> int:\n",
    "        return self.n_samples\n",
    "\n",
    "\n",
    "class Regression(Dataset):\n",
    "    \"\"\"torchpack Dataset with a \"train\" and a \"valid\" RegressionDataset split.\"\"\"\n",
    "\n",
    "    def __init__(self, n_train, n_valid, n_wires):\n",
    "        sizes = {\"train\": n_train, \"valid\": n_valid}\n",
    "        splits = {}\n",
    "        # \"train\" is generated before \"valid\", so it consumes the RNG first.\n",
    "        for name, count in sizes.items():\n",
    "            splits[name] = RegressionDataset(split=name, n_samples=count, n_wires=n_wires)\n",
    "        super().__init__(splits)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "class SuperLayer(tq.QuantumModule):\n",
    "    \"\"\"Trainable ansatz of n_layers blocks: a U3 on every wire, then a ring of CU3 gates.\n",
    "\n",
    "    self.queue[i] is the i-th gate in application order and self.ind[i] lists its wires.\n",
    "    Both are built in lock-step, so each has n_layers * 2 * n_wires entries.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, n_wires, n_layers):\n",
    "        super().__init__()\n",
    "        self.n_wires = n_wires\n",
    "        self.n_layers = n_layers\n",
    "        self.queue = tq.QuantumModuleList()\n",
    "        self.ind = []\n",
    "        for _ in range(self.n_layers):\n",
    "            # One single-qubit U3 per wire.\n",
    "            self.ind.extend([j] for j in range(self.n_wires))\n",
    "            # CU3 on (j, j+1), wrapping the last wire back to wire 0.\n",
    "            self.ind.extend([j, (j + 1) % self.n_wires] for j in range(self.n_wires))\n",
    "            self.queue += self.init_qnn()\n",
    "\n",
    "    def forward(self, q_device, fault_dict=None):\n",
    "        \"\"\"Apply all gates to q_device, optionally injecting faults.\n",
    "\n",
    "        Args:\n",
    "            q_device: quantum device the gates act on.\n",
    "            fault_dict: optional {gate_index: [fault_op, wires]}; fault_op(q_device, wires)\n",
    "                runs right after gate gate_index. None (the default) means no faults.\n",
    "        \"\"\"\n",
    "        # None instead of a shared mutable dict() default.\n",
    "        if fault_dict is None:\n",
    "            fault_dict = {}\n",
    "        for i, (op, wires) in enumerate(zip(self.queue, self.ind)):\n",
    "            op(q_device, wires)\n",
    "            if i in fault_dict:\n",
    "                fault_op, fault_wires = fault_dict[i]\n",
    "                fault_op(q_device, fault_wires)\n",
    "\n",
    "    def init_qnn(self):\n",
    "        \"\"\"Build one block: n_wires trainable U3 gates followed by n_wires trainable CU3 gates.\"\"\"\n",
    "        queue = tq.QuantumModuleList()\n",
    "        for _ in range(self.n_wires):\n",
    "            queue.append(tq.U3(has_params=True, trainable=True))\n",
    "        for _ in range(self.n_wires):\n",
    "            queue.append(tq.CU3(has_params=True, trainable=True))\n",
    "        return queue\n",
    "\n",
    "\n",
    "class QModel(tq.QuantumModule):\n",
    "    \"\"\"Regression QNN: load input states, run the ansatz, sum the per-wire PauliZ readouts.\"\"\"\n",
    "\n",
    "    def __init__(self, n_wires, n_blocks, qnn=None):\n",
    "        super().__init__()\n",
    "        self.n_wires = n_wires\n",
    "        self.n_blocks = n_blocks\n",
    "        # Each block is one U3 layer (one gate per qubit) plus one CU3 layer with ring\n",
    "        # connectivity; see SuperLayer. An explicitly passed qnn is always used\n",
    "        # (checked with `is not None`, not with module truthiness).\n",
    "        self.qnn = qnn if qnn is not None else SuperLayer(self.n_wires, self.n_blocks)\n",
    "        self.measure = tq.MeasureAll(tq.PauliZ)\n",
    "\n",
    "    def forward(self, input_states, fault_dict=None):\n",
    "        \"\"\"Return one prediction per batch element.\n",
    "\n",
    "        Args:\n",
    "            input_states: batch of state vectors loaded into a fresh QuantumDevice;\n",
    "                presumably shape (bsz, 2**n_wires) -- TODO confirm.\n",
    "            fault_dict: optional fault-injection spec forwarded to self.qnn\n",
    "                (see SuperLayer.forward); None means no faults.\n",
    "        \"\"\"\n",
    "        if fault_dict is None:\n",
    "            fault_dict = {}\n",
    "        qdev = tq.QuantumDevice(\n",
    "            n_wires=self.n_wires, bsz=input_states.shape[0], device=input_states.device\n",
    "        )\n",
    "        qdev.set_states(input_states)\n",
    "        self.qnn(qdev, fault_dict)\n",
    "\n",
    "        # Sum the measured values over the last dim (one entry per wire).\n",
    "        res = self.measure(qdev)\n",
    "        return res.sum(dim=-1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "seed = 0\n",
    "random.seed(seed)\n",
    "np.random.seed(seed)\n",
    "torch.manual_seed(seed)\n",
    "\n",
    "class ARG():\n",
    "    def __init__(self):\n",
    "        pass\n",
    "\n",
    "args = ARG()\n",
    "args.n_train = 256\n",
    "args.n_valid = 100\n",
    "args.n_wires = 2\n",
    "args.n_layers = 3\n",
    "\n",
    "args.bsz = 32\n",
    "\n",
    "args.epochs = 50"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
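    "# Baseline\n",
    "\n",
    "Train the QNN on the clean circuit (no injected gate errors). This run is the reference point for the fault-aware variants below."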
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train(dataflow, model, device, optimizer):\n",
    "    for feed_dict in dataflow[\"train\"]:\n",
    "        inputs = feed_dict[\"states\"].to(device).to(torch.complex64)\n",
    "        targets = feed_dict[\"Xlabel\"].to(device).to(torch.float)\n",
    "\n",
    "        outputs = model(inputs)\n",
    "\n",
    "        loss = F.mse_loss(outputs, targets)\n",
    "        optimizer.zero_grad()\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "    print(f\"loss: {loss.item()}\")\n",
    "\n",
    "\n",
    "def valid_test(dataflow, split, model, device):\n",
    "    target_all = []\n",
    "    output_all = []\n",
    "    with torch.no_grad():\n",
    "        for feed_dict in dataflow[split]:\n",
    "            inputs = feed_dict[\"states\"].to(device).to(torch.complex64)\n",
    "            targets = feed_dict[\"Xlabel\"].to(device).to(torch.float)\n",
    "\n",
    "            outputs = model(inputs)\n",
    "\n",
    "            target_all.append(targets)\n",
    "            output_all.append(outputs)\n",
    "        target_all = torch.cat(target_all, dim=0)\n",
    "        output_all = torch.cat(output_all, dim=0)\n",
    "\n",
    "    loss = F.mse_loss(output_all, target_all)\n",
    "\n",
    "    print(f\"{split} set loss: {loss}\")\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1, RL: 0.005\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "loss: 0.21301651000976562\n",
      "valid set loss: 0.11517993360757828\n",
      "Epoch 2, RL: 0.004995066821070679\n",
      "loss: 0.02429831773042679\n",
      "valid set loss: 0.02202836237847805\n",
      "Epoch 3, RL: 0.004980286753286194\n",
      "loss: 0.006668418180197477\n",
      "valid set loss: 0.006509586237370968\n",
      "Epoch 4, RL: 0.004955718126821721\n",
      "loss: 0.008516826666891575\n",
      "valid set loss: 0.008134914562106133\n",
      "Epoch 5, RL: 0.004921457902821577\n",
      "loss: 0.007079693488776684\n",
      "valid set loss: 0.006443488411605358\n",
      "Epoch 6, RL: 0.004877641290737884\n",
      "loss: 0.001857138006016612\n",
      "valid set loss: 0.002263097558170557\n",
      "Epoch 7, RL: 0.004824441214720628\n",
      "loss: 0.00020700700406450778\n",
      "valid set loss: 0.00027339838561601937\n",
      "Epoch 8, RL: 0.004762067631165049\n",
      "loss: 0.00029745115898549557\n",
      "valid set loss: 0.00030080866417847574\n",
      "Epoch 9, RL: 0.0046907667001096585\n",
      "loss: 0.0003364051226526499\n",
      "valid set loss: 0.00017937282973434776\n",
      "Epoch 10, RL: 0.004610819813755038\n",
      "loss: 5.174580655875616e-05\n",
      "valid set loss: 2.8714559448417276e-05\n",
      "Epoch 11, RL: 0.0045225424859373685\n",
      "loss: 4.316771992307622e-06\n",
      "valid set loss: 5.426956704468466e-06\n",
      "Epoch 12, RL: 0.004426283106939473\n",
      "loss: 1.1959653420490213e-05\n",
      "valid set loss: 1.1345224265824072e-05\n",
      "Epoch 13, RL: 0.004322421568553529\n",
      "loss: 7.368143087660428e-06\n",
      "valid set loss: 7.10365657141665e-06\n",
      "Epoch 14, RL: 0.004211367764821722\n",
      "loss: 9.22184824503347e-07\n",
      "valid set loss: 6.102184215706075e-07\n",
      "Epoch 15, RL: 0.004093559974371725\n",
      "loss: 5.809403091916465e-07\n",
      "valid set loss: 5.351951131160604e-07\n",
      "Epoch 16, RL: 0.003969463130731184\n",
      "loss: 3.0121606187094585e-07\n",
      "valid set loss: 3.494255906844046e-07\n",
      "Epoch 17, RL: 0.003839566987447492\n",
      "loss: 6.716312839216698e-08\n",
      "valid set loss: 7.792342415768871e-08\n",
      "Epoch 18, RL: 0.0037043841852542887\n",
      "loss: 1.2186197295704915e-07\n",
      "valid set loss: 1.33090679810266e-07\n",
      "Epoch 19, RL: 0.0035644482289126822\n",
      "loss: 1.9595472622313537e-07\n",
      "valid set loss: 2.3727797326955624e-07\n",
      "Epoch 20, RL: 0.0034203113817116958\n",
      "loss: 1.1970325886068167e-07\n",
      "valid set loss: 1.8314091221327544e-07\n",
      "Epoch 21, RL: 0.003272542485937369\n",
      "loss: 1.0653639748170463e-07\n",
      "valid set loss: 1.276723793353085e-07\n",
      "Epoch 22, RL: 0.003121724717912137\n",
      "loss: 7.260409518039523e-08\n",
      "valid set loss: 1.1725681048346814e-07\n",
      "Epoch 23, RL: 0.0029684532864643126\n",
      "loss: 9.578177895264162e-08\n",
      "valid set loss: 1.1267536592640681e-07\n",
      "Epoch 24, RL: 0.002813333083910762\n",
      "loss: 9.314275217775503e-08\n",
      "valid set loss: 1.0827797325418942e-07\n",
      "Epoch 25, RL: 0.0026569762988232844\n",
      "loss: 6.365583260503627e-08\n",
      "valid set loss: 1.0911400494251211e-07\n",
      "Epoch 26, RL: 0.002500000000000001\n",
      "loss: 8.350335889417693e-08\n",
      "valid set loss: 1.0736378897036047e-07\n",
      "Epoch 27, RL: 0.0023430237011767175\n",
      "loss: 1.0498930436142473e-07\n",
      "valid set loss: 1.0714678921885934e-07\n",
      "Epoch 28, RL: 0.0021866669160892404\n",
      "loss: 9.725977179186884e-08\n",
      "valid set loss: 1.071429238663768e-07\n",
      "Epoch 29, RL: 0.0020315467135356893\n",
      "loss: 1.011315404753077e-07\n",
      "valid set loss: 1.0485103985047317e-07\n",
      "Epoch 30, RL: 0.0018782752820878637\n",
      "loss: 1.2005016003513447e-07\n",
      "valid set loss: 1.0576934528216952e-07\n",
      "Epoch 31, RL: 0.0017274575140626331\n",
      "loss: 8.15454370695079e-08\n",
      "valid set loss: 1.0655359261591002e-07\n",
      "Epoch 32, RL: 0.0015796886182883067\n",
      "loss: 6.99946554050257e-08\n",
      "valid set loss: 1.0380959736266959e-07\n",
      "Epoch 33, RL: 0.0014355517710873192\n",
      "loss: 1.0501969427423319e-07\n",
      "valid set loss: 1.0633490177269778e-07\n",
      "Epoch 34, RL: 0.001295615814745712\n",
      "loss: 1.2218342249070702e-07\n",
      "valid set loss: 1.0588555454660309e-07\n",
      "Epoch 35, RL: 0.0011604330125525083\n",
      "loss: 1.0206284173364111e-07\n",
      "valid set loss: 1.0568604835725637e-07\n",
      "Epoch 36, RL: 0.0010305368692688178\n",
      "loss: 1.250998025170702e-07\n",
      "valid set loss: 1.0612613010607674e-07\n",
      "Epoch 37, RL: 0.000906440025628276\n",
      "loss: 1.1294706325770676e-07\n",
      "valid set loss: 1.0578631304269948e-07\n",
      "Epoch 38, RL: 0.0007886322351782785\n",
      "loss: 9.335940376331564e-08\n",
      "valid set loss: 1.0539930883624038e-07\n",
      "Epoch 39, RL: 0.0006775784314464719\n",
      "loss: 9.897124897406684e-08\n",
      "valid set loss: 1.0599053013038429e-07\n",
      "Epoch 40, RL: 0.0005737168930605274\n",
      "loss: 9.814363011173555e-08\n",
      "valid set loss: 1.0590466814619504e-07\n",
      "Epoch 41, RL: 0.0004774575140626318\n",
      "loss: 1.0702699171361019e-07\n",
      "valid set loss: 1.0661906912901031e-07\n",
      "Epoch 42, RL: 0.000389180186244963\n",
      "loss: 6.840924982043362e-08\n",
      "valid set loss: 1.0581361209460738e-07\n",
      "Epoch 43, RL: 0.00030923329989034114\n",
      "loss: 9.859267890988122e-08\n",
      "valid set loss: 1.0573776165756499e-07\n",
      "Epoch 44, RL: 0.0002379323688349517\n",
      "loss: 5.651473600210011e-08\n",
      "valid set loss: 1.0539780248564057e-07\n",
      "Epoch 45, RL: 0.0001755587852793717\n",
      "loss: 1.0163601871227002e-07\n",
      "valid set loss: 1.0580340870092186e-07\n",
      "Epoch 46, RL: 0.00012235870926211623\n",
      "loss: 1.1877786931790979e-07\n",
      "valid set loss: 1.0582743925624527e-07\n",
      "Epoch 47, RL: 7.854209717842261e-05\n",
      "loss: 8.847102606068802e-08\n",
      "valid set loss: 1.0613774747980642e-07\n",
      "Epoch 48, RL: 4.428187317827821e-05\n",
      "loss: 9.538926803998038e-08\n",
      "valid set loss: 1.059143102111193e-07\n",
      "Epoch 49, RL: 1.9713246713805593e-05\n",
      "loss: 9.49906819869284e-08\n",
      "valid set loss: 1.0584840737237755e-07\n",
      "Epoch 50, RL: 4.933178929321104e-06\n",
      "loss: 1.1751624384714887e-07\n",
      "valid set loss: 1.0595640986821309e-07\n",
      "valid set loss: 1.0595640986821309e-07\n"
     ]
    }
   ],
   "source": [
    "# The dataflow file is a pickled Python object (presumably DataLoaders), not a plain\n",
    "# tensor checkpoint, so torch>=2.6's default weights_only=True would reject it.\n",
    "# Only load it from this trusted local path.\n",
    "dataflow = torch.load('dataflow_sin2x1cosx2.pt', weights_only=False)\n",
    "\n",
    "use_cuda = torch.cuda.is_available()\n",
    "device = torch.device(\"cuda\" if use_cuda else \"cpu\")\n",
    "\n",
    "# n_layers is the number of U3+CU3 blocks set in the config cell.\n",
    "model = QModel(n_wires=args.n_wires, n_blocks=args.n_layers).to(device)\n",
    "\n",
    "n_epochs = args.epochs\n",
    "optimizer = optim.Adam(model.parameters(), lr=5e-3, weight_decay=1e-4)\n",
    "scheduler = CosineAnnealingLR(optimizer, T_max=n_epochs)\n",
    "\n",
    "# NOTE: later cells redefine train/valid_test with other signatures; re-run the\n",
    "# baseline definitions cell before re-running this one.\n",
    "for epoch in range(1, n_epochs + 1):\n",
    "    print(f\"Epoch {epoch}, LR: {optimizer.param_groups[0]['lr']}\")\n",
    "    train(dataflow, model, device, optimizer)\n",
    "    # The validation after the last epoch is the final result; the model is not\n",
    "    # updated again afterwards.\n",
    "    valid_test(dataflow, \"valid\", model, device)\n",
    "    scheduler.step()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Save the whole trained ansatz module (pickled), not just its state_dict; loading it\n",
    "# later presumably needs the SuperLayer class to be importable.\n",
    "baseline_ckpt_path = 'sincosRegression_baseline.pt'\n",
    "torch.save(model.qnn, baseline_ckpt_path)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
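    "# NAT (noise-aware training)\n",
    "\n",
    "Same setup as the baseline, but every training batch samples a fresh fault pattern: after each gate, a Pauli X, Y or Z error is inserted on the gate's last wire, each with probability `errorRate` (`er` below). Validation runs the clean circuit."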
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train(dataflow, model, device, optimizer, errorRate=0):\n",
    "    def randomSample(model, errorRate):\n",
    "        fault_dict = dict()\n",
    "        fault = np.random.choice([None, tq.PauliX(), tq.PauliY(), tq.PauliZ()], size=(len(model.qnn.ind), ), p=[1-errorRate*3, *([errorRate]*3)])\n",
    "        for i in range(len(fault)):\n",
    "            if fault[i]:\n",
    "                fault_dict.update({i: [fault[i], [model.qnn.ind[i][-1]]]})\n",
    "        return fault_dict\n",
    "    \n",
    "    for feed_dict in dataflow[\"train\"]:\n",
    "        inputs = feed_dict[\"states\"].to(device).to(torch.complex64)\n",
    "        targets = feed_dict[\"Xlabel\"].to(device).to(torch.float)\n",
    "\n",
    "        fd = randomSample(model, errorRate)\n",
    "        outputs = model(inputs, fd)\n",
    "\n",
    "        loss = F.mse_loss(outputs, targets)\n",
    "        optimizer.zero_grad()\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "    print(f\"loss: {loss.item()}\")\n",
    "\n",
    "\n",
    "def valid_test(dataflow, split, model, device):\n",
    "    target_all = []\n",
    "    output_all = []\n",
    "    with torch.no_grad():\n",
    "        for feed_dict in dataflow[split]:\n",
    "            inputs = feed_dict[\"states\"].to(device).to(torch.complex64)\n",
    "            targets = feed_dict[\"Xlabel\"].to(device).to(torch.float)\n",
    "\n",
    "            outputs = model(inputs)\n",
    "\n",
    "            target_all.append(targets)\n",
    "            output_all.append(outputs)\n",
    "        target_all = torch.cat(target_all, dim=0)\n",
    "        output_all = torch.cat(output_all, dim=0)\n",
    "\n",
    "    loss = F.mse_loss(output_all, target_all)\n",
    "\n",
    "    print(f\"{split} set loss: {loss}\")\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1, RL: 0.005\n",
      "loss: 0.3434562087059021\n",
      "valid set loss: 0.3838525712490082\n",
      "Epoch 2, RL: 0.004995066821070679\n",
      "loss: 0.345772385597229\n",
      "valid set loss: 0.3082601726055145\n",
      "Epoch 3, RL: 0.004980286753286194\n",
      "loss: 0.3059511184692383\n",
      "valid set loss: 0.254973828792572\n",
      "Epoch 4, RL: 0.004955718126821721\n",
      "loss: 0.1378580778837204\n",
      "valid set loss: 0.23403722047805786\n",
      "Epoch 5, RL: 0.004921457902821577\n",
      "loss: 0.1685188114643097\n",
      "valid set loss: 0.20811191201210022\n",
      "Epoch 6, RL: 0.004877641290737884\n",
      "loss: 0.18736401200294495\n",
      "valid set loss: 0.16812658309936523\n",
      "Epoch 7, RL: 0.004824441214720628\n",
      "loss: 0.39771580696105957\n",
      "valid set loss: 0.12492452561855316\n",
      "Epoch 8, RL: 0.004762067631165049\n",
      "loss: 0.0820598155260086\n",
      "valid set loss: 0.08983724564313889\n",
      "Epoch 9, RL: 0.0046907667001096585\n",
      "loss: 0.06750781089067459\n",
      "valid set loss: 0.06243018060922623\n",
      "Epoch 10, RL: 0.004610819813755038\n",
      "loss: 0.03464501351118088\n",
      "valid set loss: 0.04423852637410164\n",
      "Epoch 11, RL: 0.0045225424859373685\n",
      "loss: 0.022733183577656746\n",
      "valid set loss: 0.03045322559773922\n",
      "Epoch 12, RL: 0.004426283106939473\n",
      "loss: 0.018247516825795174\n",
      "valid set loss: 0.021372636780142784\n",
      "Epoch 13, RL: 0.004322421568553529\n",
      "loss: 0.01996535062789917\n",
      "valid set loss: 0.01830790378153324\n",
      "Epoch 14, RL: 0.004211367764821722\n",
      "loss: 0.00935937650501728\n",
      "valid set loss: 0.014457300305366516\n",
      "Epoch 15, RL: 0.004093559974371725\n",
      "loss: 0.005455560516566038\n",
      "valid set loss: 0.009347736835479736\n",
      "Epoch 16, RL: 0.003969463130731184\n",
      "loss: 0.0059296428225934505\n",
      "valid set loss: 0.00607415521517396\n",
      "Epoch 17, RL: 0.003839566987447492\n",
      "loss: 0.0053586894646286964\n",
      "valid set loss: 0.005429355893284082\n",
      "Epoch 18, RL: 0.0037043841852542887\n",
      "loss: 0.004275184590369463\n",
      "valid set loss: 0.0034476660657674074\n",
      "Epoch 19, RL: 0.0035644482289126822\n",
      "loss: 0.15444159507751465\n",
      "valid set loss: 0.002790957922115922\n",
      "Epoch 20, RL: 0.0034203113817116958\n",
      "loss: 0.003328327089548111\n",
      "valid set loss: 0.0030881697311997414\n",
      "Epoch 21, RL: 0.003272542485937369\n",
      "loss: 0.2578026056289673\n",
      "valid set loss: 0.0023822789080441\n",
      "Epoch 22, RL: 0.003121724717912137\n",
      "loss: 0.9306357502937317\n",
      "valid set loss: 0.0015445708995684981\n",
      "Epoch 23, RL: 0.0029684532864643126\n",
      "loss: 0.0014330395497381687\n",
      "valid set loss: 0.0011054638307541609\n",
      "Epoch 24, RL: 0.002813333083910762\n",
      "loss: 0.0016059590270742774\n",
      "valid set loss: 0.0013607027940452099\n",
      "Epoch 25, RL: 0.0026569762988232844\n",
      "loss: 0.0011624512262642384\n",
      "valid set loss: 0.0011560601415112615\n",
      "Epoch 26, RL: 0.002500000000000001\n",
      "loss: 0.00296210334636271\n",
      "valid set loss: 0.0023334412835538387\n",
      "Epoch 27, RL: 0.0023430237011767175\n",
      "loss: 0.003383681643754244\n",
      "valid set loss: 0.0023993276990950108\n",
      "Epoch 28, RL: 0.0021866669160892404\n",
      "loss: 0.0011799097992479801\n",
      "valid set loss: 0.0013040225021541119\n",
      "Epoch 29, RL: 0.0020315467135356893\n",
      "loss: 0.0009745023562572896\n",
      "valid set loss: 0.0008011300815269351\n",
      "Epoch 30, RL: 0.0018782752820878637\n",
      "loss: 0.001587160862982273\n",
      "valid set loss: 0.0011483661364763975\n",
      "Epoch 31, RL: 0.0017274575140626331\n",
      "loss: 0.0014084625290706754\n",
      "valid set loss: 0.0010184973943978548\n",
      "Epoch 32, RL: 0.0015796886182883067\n",
      "loss: 0.0009331104811280966\n",
      "valid set loss: 0.0009916425915434957\n",
      "Epoch 33, RL: 0.0014355517710873192\n",
      "loss: 0.0013524118112400174\n",
      "valid set loss: 0.0010686872992664576\n",
      "Epoch 34, RL: 0.001295615814745712\n",
      "loss: 0.0010673736687749624\n",
      "valid set loss: 0.0010302108712494373\n",
      "Epoch 35, RL: 0.0011604330125525083\n",
      "loss: 0.0004959983052685857\n",
      "valid set loss: 0.0008601943263784051\n",
      "Epoch 36, RL: 0.0010305368692688178\n",
      "loss: 0.0006264373078010976\n",
      "valid set loss: 0.0008267730008810759\n",
      "Epoch 37, RL: 0.000906440025628276\n",
      "loss: 0.8019968271255493\n",
      "valid set loss: 0.0008156037074513733\n",
      "Epoch 38, RL: 0.0007886322351782785\n",
      "loss: 0.0010790766682475805\n",
      "valid set loss: 0.0008929299074225128\n",
      "Epoch 39, RL: 0.0006775784314464719\n",
      "loss: 0.0014636617852374911\n",
      "valid set loss: 0.0010030993726104498\n",
      "Epoch 40, RL: 0.0005737168930605274\n",
      "loss: 0.0012684534303843975\n",
      "valid set loss: 0.001031074090860784\n",
      "Epoch 41, RL: 0.0004774575140626318\n",
      "loss: 0.7711633443832397\n",
      "valid set loss: 0.0010021731723099947\n",
      "Epoch 42, RL: 0.000389180186244963\n",
      "loss: 0.0013197604566812515\n",
      "valid set loss: 0.0011971049243584275\n",
      "Epoch 43, RL: 0.00030923329989034114\n",
      "loss: 0.4076334834098816\n",
      "valid set loss: 0.0014120331034064293\n",
      "Epoch 44, RL: 0.0002379323688349517\n",
      "loss: 0.001931183971464634\n",
      "valid set loss: 0.0015430503990501165\n",
      "Epoch 45, RL: 0.0001755587852793717\n",
      "loss: 0.8634570837020874\n",
      "valid set loss: 0.0015600189799442887\n",
      "Epoch 46, RL: 0.00012235870926211623\n",
      "loss: 0.002282551256939769\n",
      "valid set loss: 0.0015809324104338884\n",
      "Epoch 47, RL: 7.854209717842261e-05\n",
      "loss: 1.5309031009674072\n",
      "valid set loss: 0.001581291202455759\n",
      "Epoch 48, RL: 4.428187317827821e-05\n",
      "loss: 0.0023176628164947033\n",
      "valid set loss: 0.0015687741106376052\n",
      "Epoch 49, RL: 1.9713246713805593e-05\n",
      "loss: 0.001644133124500513\n",
      "valid set loss: 0.0015696340706199408\n",
      "Epoch 50, RL: 4.933178929321104e-06\n",
      "loss: 0.0023471431341022253\n",
      "valid set loss: 0.0015677615301683545\n",
      "valid set loss: 0.0015677615301683545\n"
     ]
    }
   ],
   "source": [
    "# Per-Pauli error probability passed to train(errorRate=...) during NAT.\n",
    "er = 0.01\n",
    "# Pickled (presumably DataLoader) object from a trusted local file; torch>=2.6\n",
    "# defaults to weights_only=True, which would reject it.\n",
    "dataflow = torch.load('dataflow_sin2x1cosx2.pt', weights_only=False)\n",
    "\n",
    "use_cuda = torch.cuda.is_available()\n",
    "device = torch.device(\"cuda\" if use_cuda else \"cpu\")\n",
    "\n",
    "# n_layers is the number of U3+CU3 blocks set in the config cell.\n",
    "model = QModel(n_wires=args.n_wires, n_blocks=args.n_layers).to(device)\n",
    "\n",
    "n_epochs = args.epochs\n",
    "optimizer = optim.Adam(model.parameters(), lr=5e-3, weight_decay=1e-4)\n",
    "scheduler = CosineAnnealingLR(optimizer, T_max=n_epochs)\n",
    "\n",
    "for epoch in range(1, n_epochs + 1):\n",
    "    print(f\"Epoch {epoch}, LR: {optimizer.param_groups[0]['lr']}\")\n",
    "    # Training sees randomly sampled faults ...\n",
    "    train(dataflow, model, device, optimizer, er)\n",
    "    # ... while validation runs the clean circuit (valid_test passes no fault_dict).\n",
    "    # The validation after the last epoch is the final result.\n",
    "    valid_test(dataflow, \"valid\", model, device)\n",
    "    scheduler.step()\n",
    "\n",
    "torch.save(model.qnn, 'sincosRegression_NATH.pt')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
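    "# Ours: EA-searched fault training\n",
    "\n",
    "Instead of sampling faults at random, an evolutionary algorithm (geatpy, driven by `OneGen_task`) picks a fault pattern for each training batch when `args.EA` is enabled. The MSE under that pattern is added to the clean MSE with weight `lambda_`."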
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "from ea_task import OneGen_task\n",
    "import geatpy as ea\n",
    "from new_gates import PauliX, PauliY, PauliZ\n",
    "from tqdm.notebook import tqdm\n",
    "\n",
    "# Gene value -> fault gate. Gene value 0 means \"no fault\" and has no entry.\n",
    "ERROR_DICT = {1: PauliX().to(device), 2: PauliY().to(device), 3: PauliZ().to(device)}\n",
    "\n",
    "\n",
    "def gen_fd(var, model):\n",
    "    \"\"\"Decode a chromosome into a fault dict for model.qnn.\n",
    "\n",
    "    var[i] selects the fault injected after gate i (0 = none, otherwise an ERROR_DICT\n",
    "    key); the fault is applied to the last wire listed in model.qnn.ind[i].\n",
    "    Returns {gate_index: [gate, [wire]]}, the format SuperLayer.forward consumes.\n",
    "    \"\"\"\n",
    "    return {\n",
    "        i: [ERROR_DICT[v], [model.qnn.ind[i][-1]]]\n",
    "        for i, v in enumerate(var)\n",
    "        if v != 0\n",
    "    }\n",
    "\n",
    "\n",
    "def aim(var, model, inputs, targets):\n",
    "    \"\"\"Objective for one chromosome: MSE of model under the faults encoded by var.\n",
    "\n",
    "    Evaluated without gradients and handed to OneGen_task via args.aim. Whether the EA\n",
    "    maximises or minimises it is decided inside ea_task (not visible here).\n",
    "    \"\"\"\n",
    "    # Reuse gen_fd instead of duplicating the decoding loop.\n",
    "    fd = gen_fd(var, model)\n",
    "    with torch.no_grad():\n",
    "        outputs = model(inputs, fd)\n",
    "        loss = F.mse_loss(outputs, targets).item()\n",
    "    return loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train(dataflow, model, device, optimizer, args):\n",
    "    loss1_acc = 0\n",
    "    loss2_acc = 0\n",
    "    for feed_dict in dataflow[\"train\"]:\n",
    "        inputs = feed_dict[\"states\"].to(device).to(torch.complex64)\n",
    "        targets = feed_dict[\"Xlabel\"].to(device).to(torch.float)\n",
    "        outputs1 = model(inputs)\n",
    "        loss1 = F.mse_loss(outputs1, targets)\n",
    "        loss1_acc += loss1.item()\n",
    "\n",
    "        if args.EA:\n",
    "            args.Chrom, args.best_pop = OneGen_task(args.N, args.M, args.K, args.NIND, args.selS, args.recS, args.mutS, args.FieldD, \\\n",
    "                                                    model, inputs, targets, args.aim, args.Chrom, args.pc, args.Encoding)\n",
    "            fd = gen_fd(args.best_pop, model)\n",
    "            outputs2 = model(inputs, fd)\n",
    "            loss2 = F.mse_loss(outputs2, targets)\n",
    "            loss2_acc += loss2.item()\n",
    "            loss = loss1 + args.lambda_ * loss2\n",
    "        else:\n",
    "            loss = loss1\n",
    "\n",
    "        optimizer.zero_grad()\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "    return loss1_acc, loss2_acc\n",
    "\n",
    "def valid_test(dataflow, split, model, device):\n",
    "    target_all = []\n",
    "    output_all = []\n",
    "    with torch.no_grad():\n",
    "        for feed_dict in dataflow[split]:\n",
    "            inputs = feed_dict[\"states\"].to(device).to(torch.complex64)\n",
    "            targets = feed_dict[\"Xlabel\"].to(device).to(torch.float)\n",
    "\n",
    "            outputs = model(inputs)\n",
    "\n",
    "            target_all.append(targets)\n",
    "            output_all.append(outputs)\n",
    "        target_all = torch.cat(target_all, dim=0)\n",
    "        output_all = torch.cat(output_all, dim=0)\n",
    "\n",
    "    loss = F.mse_loss(output_all, target_all)\n",
    "\n",
    "    return loss.item()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Gene layout: one gene per gate of SuperLayer (n_wires U3 + n_wires CU3 per block),\n",
    "# matching len(model.qnn.ind).\n",
    "args.N = args.n_layers * args.n_wires * 2\n",
    "args.M = 1\n",
    "args.NIND = 20  # population size\n",
    "args.K = 0.1 * args.NIND  # 10% of the population for elite selection\n",
    "# geatpy operator names for selection, recombination and mutation.\n",
    "args.selS, args.recS, args.mutS = 'etour', 'xovdp', 'mutbin'\n",
    "args.Encoding = 'BG'\n",
    "args.pc = 0.8\n",
    "args.EA = False  # fault term off for now; presumably enabled later by the training loop\n",
    "args.lambda_ = 0.5  # train() optimises loss1 + lambda_ * loss2 when args.EA is set\n",
    "\n",
    "# Field-descriptor inputs: each of the N genes spans [0, 3] (varTypes 1 = integer in\n",
    "# geatpy); 0 = no fault, 1/2/3 = ERROR_DICT keys.\n",
    "ranges = np.array([[0, 3]] * args.N).T\n",
    "borders = np.ones_like(ranges)\n",
    "varTypes = np.array([1] * args.N)\n",
    "codes = [0 for _ in range(args.N)]\n",
    "precisions = [0 for _ in range(args.N)]\n",
    "scales = [0 for _ in range(args.N)]\n",
    "\n",
    "args.FieldD = ea.crtfld(args.Encoding, varTypes, ranges, borders, precisions, codes, scales)\n",
    "args.aim = aim\n",
    "args.Chrom = ea.crtpc(args.Encoding, args.NIND, args.FieldD)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "505e8515a343457ab489c88c35ba6cfd",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/50 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 1, loss1 2.3747442215681076, loss2 0, loss 0.21184714138507843\n",
      "epoch 2, loss1 1.0769352614879608, loss2 0, loss 0.07579193264245987\n",
      "epoch 3, loss1 0.41179318726062775, loss2 0, loss 0.014887610450387001\n",
      "epoch 4, loss1 0.05327105172909796, loss2 0, loss 0.017805995419621468\n",
      "epoch 5, loss1 0.08584151044487953, loss2 0, loss 0.0074418517760932446\n",
      "epoch 6, loss1 0.039262036327272654, loss2 0, loss 0.0012174543226137757\n",
      "epoch 7, loss1 0.00654327601660043, loss2 0, loss 0.0020704239141196012\n",
      "epoch 8, loss1 0.010762696503661573, loss2 0, loss 0.0006027935887686908\n",
      "epoch 9, loss1 0.0032446702534798533, loss2 0, loss 7.493961311411113e-05\n",
      "epoch 10, loss1 0.04001012273511151, loss2 10.50772511959076, loss 0.027960048988461494\n",
      "epoch 11, loss1 0.27977791521698236, loss2 9.569510340690613, loss 0.061398591846227646\n",
      "epoch 12, loss1 0.3856155499815941, loss2 7.971769332885742, loss 0.051147621124982834\n",
      "epoch 13, loss1 0.3584023080766201, loss2 6.741701364517212, loss 0.017986062914133072\n",
      "epoch 14, loss1 0.25945943407714367, loss2 7.6559595465660095, loss 0.028646286576986313\n",
      "epoch 15, loss1 0.48505886644124985, loss2 6.2096768617630005, loss 0.07318605482578278\n",
      "epoch 16, loss1 0.5072643458843231, loss2 6.169585347175598, loss 0.06346048414707184\n",
      "epoch 17, loss1 0.3531945012509823, loss2 7.130469441413879, loss 0.05356709286570549\n",
      "epoch 18, loss1 0.4965226501226425, loss2 6.659513354301453, loss 0.06721600890159607\n",
      "epoch 19, loss1 0.7809989601373672, loss2 5.59862494468689, loss 0.06786634027957916\n",
      "epoch 20, loss1 0.7206755317747593, loss2 5.244321286678314, loss 0.0710878074169159\n",
      "epoch 21, loss1 0.5343837291002274, loss2 5.993255615234375, loss 0.0428304523229599\n",
      "epoch 22, loss1 0.3809167183935642, loss2 5.548148453235626, loss 0.0426313579082489\n",
      "epoch 23, loss1 0.5558761991560459, loss2 5.350879192352295, loss 0.06852774322032928\n",
      "epoch 24, loss1 0.6620575487613678, loss2 4.490971088409424, loss 0.06274360418319702\n",
      "epoch 25, loss1 0.49922261014580727, loss2 4.70625227689743, loss 0.050466712564229965\n",
      "epoch 26, loss1 0.4592488408088684, loss2 4.634773910045624, loss 0.05763396993279457\n",
      "epoch 27, loss1 0.4671628922224045, loss2 4.809020608663559, loss 0.04776061326265335\n",
      "epoch 28, loss1 0.45179997757077217, loss2 4.670274555683136, loss 0.04310467094182968\n",
      "epoch 29, loss1 0.37543728575110435, loss2 4.995038866996765, loss 0.035497572273015976\n",
      "epoch 30, loss1 0.3583251163363457, loss2 4.837058514356613, loss 0.04442350193858147\n",
      "epoch 31, loss1 0.40092405676841736, loss2 5.207985579967499, loss 0.03756032511591911\n",
      "epoch 32, loss1 0.3260419890284538, loss2 5.055136561393738, loss 0.03774084150791168\n",
      "epoch 33, loss1 0.3217012919485569, loss2 4.820952951908112, loss 0.039066266268491745\n",
      "epoch 34, loss1 0.33290633745491505, loss2 4.583374381065369, loss 0.038892701268196106\n",
      "epoch 35, loss1 0.3340195491909981, loss2 4.581292629241943, loss 0.03930818289518356\n",
      "epoch 36, loss1 0.3349405638873577, loss2 4.717834770679474, loss 0.038310859352350235\n",
      "epoch 37, loss1 0.31881493143737316, loss2 4.723408818244934, loss 0.035707004368305206\n",
      "epoch 38, loss1 0.30825841799378395, loss2 4.4431600868701935, loss 0.03462604805827141\n",
      "epoch 39, loss1 0.32062174938619137, loss2 4.5507771372795105, loss 0.03586353361606598\n",
      "epoch 40, loss1 0.3284459263086319, loss2 4.2570092380046844, loss 0.03788779303431511\n",
      "epoch 41, loss1 0.3556923232972622, loss2 4.201562464237213, loss 0.03965554013848305\n",
      "epoch 42, loss1 0.3624202013015747, loss2 4.3685241639614105, loss 0.03906961902976036\n",
      "epoch 43, loss1 0.34579141810536385, loss2 4.2742369174957275, loss 0.038467131555080414\n",
      "epoch 44, loss1 0.3456166833639145, loss2 4.1315324902534485, loss 0.03855607286095619\n",
      "epoch 45, loss1 0.3576799277216196, loss2 4.302948534488678, loss 0.038678739219903946\n",
      "epoch 46, loss1 0.36519773304462433, loss2 4.604855448007584, loss 0.0389089472591877\n",
      "epoch 47, loss1 0.36394644156098366, loss2 4.446034163236618, loss 0.038980063050985336\n",
      "epoch 48, loss1 0.36152147874236107, loss2 4.4170438051223755, loss 0.038981206715106964\n",
      "epoch 49, loss1 0.35959088057279587, loss2 4.341779351234436, loss 0.03900817781686783\n",
      "epoch 50, loss1 0.35872136801481247, loss2 4.464401841163635, loss 0.038998108357191086\n"
     ]
    }
   ],
   "source": [
    "# lambda_ is not passed to train() explicitly; presumably train() reads it as a global -- TODO confirm.\n",
    "lambda_ = 0.3\n",
    "dataflow = torch.load('dataflow_sin2x1cosx2.pt')\n",
    "\n",
    "model = QModel(n_wires=args.n_wires, n_blocks=args.n_layers).to(device)\n",
    "\n",
    "n_epochs = 50\n",
    "# First (1-based) epoch trained with args.EA = True; earlier epochs use args.EA = False.\n",
    "# In the saved log below, loss2 is 0 until this epoch.\n",
    "ea_start_epoch = 10\n",
    "optimizer = optim.Adam(model.parameters(), lr=3e-2, weight_decay=1e-4)\n",
    "scheduler = CosineAnnealingLR(optimizer, T_max=n_epochs)\n",
    "\n",
    "loss1_trace = []\n",
    "loss2_trace = []\n",
    "itr = range(1, n_epochs + 1)\n",
    "for epoch in tqdm(itr):\n",
    "    # Derived from the epoch so the EA schedule is explicit:\n",
    "    # False before ea_start_epoch, True from then on.\n",
    "    args.EA = epoch >= ea_start_epoch\n",
    "    l1, l2 = train(dataflow, model, device, optimizer, args)\n",
    "    loss1_trace.append(l1)\n",
    "    loss2_trace.append(l2)\n",
    "    scheduler.step()\n",
    "    print(f\"epoch {epoch}, loss1 {l1}, loss2 {l2}, loss {valid_test(dataflow, 'valid', model, device)}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# Save only the trained circuit (model.qnn); the evaluation cells below reload it\n",
    "# and wrap it in a fresh QModel. torch.save does not create missing directories.\n",
    "os.makedirs('models', exist_ok=True)\n",
    "torch.save(model.qnn, 'models/sincosRegression_ours.pt')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
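    "# Single-gate error injection\n",
    "\n",
    "Inject exactly one Pauli X, Y or Z fault at each gate position of the trained circuit (on that gate's wire) and report the worst-case validation MSE."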
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [],
   "source": [
    "def valid_test(dataflow, split, model, device, fs=None):\n",
    "    \"\"\"Return the MSE of `model` on dataflow[split], optionally with injected faults.\n",
    "\n",
    "    Args:\n",
    "        dataflow: mapping from split name to an iterable of batch dicts with\n",
    "            \"states\" (input states) and \"Xlabel\" (regression targets) tensors.\n",
    "        split: key of `dataflow` to evaluate, e.g. 'valid'.\n",
    "        model: callable invoked as model(inputs, fs) for every batch.\n",
    "        device: device the batch tensors are moved to.\n",
    "        fs: fault dictionary forwarded unchanged to the model (in this notebook it\n",
    "            maps a gate's queue index to [Pauli op, [wire]]). None, the default,\n",
    "            means no faults and is replaced by a fresh empty dict.\n",
    "\n",
    "    Returns:\n",
    "        Mean-squared error over all batches of the split, as a Python float.\n",
    "    \"\"\"\n",
    "    if fs is None:\n",
    "        # A new dict per call: a shared `dict()` default would leak any mutation\n",
    "        # made by the model into every later call.\n",
    "        fs = {}\n",
    "    target_all = []\n",
    "    output_all = []\n",
    "    with torch.no_grad():\n",
    "        for feed_dict in dataflow[split]:\n",
    "            inputs = feed_dict[\"states\"].to(device).to(torch.complex64)\n",
    "            targets = feed_dict[\"Xlabel\"].to(device).to(torch.float)\n",
    "\n",
    "            outputs = model(inputs, fs)\n",
    "\n",
    "            target_all.append(targets)\n",
    "            output_all.append(outputs)\n",
    "        # Evaluate once over the whole split rather than averaging per-batch losses.\n",
    "        target_all = torch.cat(target_all, dim=0)\n",
    "        output_all = torch.cat(output_all, dim=0)\n",
    "\n",
    "    loss = F.mse_loss(output_all, target_all)\n",
    "\n",
    "    return loss.item()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.5428720116615295\n"
     ]
    }
   ],
   "source": [
    "dataflow = torch.load('dataflow_sin2x1cosx2.pt')\n",
    "# NOTE(review): torch.load unpickles arbitrary objects; only load trusted checkpoints.\n",
    "qnn = torch.load('models/sincosRegression_ours.pt')\n",
    "\n",
    "model = QModel(n_wires=args.n_wires, n_blocks=args.n_layers, qnn=qnn).to(device)\n",
    "\n",
    "# Inject exactly one Pauli fault per evaluation: for each Pauli type and each\n",
    "# queue index of the circuit, fault that gate on wire model.qnn.ind[gate][-1].\n",
    "# result_collection has one row per Pauli type (X, Y, Z) and one column per gate.\n",
    "paulis = [tq.PauliX(), tq.PauliY(), tq.PauliZ()]\n",
    "gate_ids = range(len(model.qnn.queue))\n",
    "result_collection = np.array([\n",
    "    [\n",
    "        valid_test(dataflow, 'valid', model, device, {gate: [pauli, [model.qnn.ind[gate][-1]]]})\n",
    "        for gate in gate_ids\n",
    "    ]\n",
    "    for pauli in paulis\n",
    "])\n",
    "# Worst-case validation MSE over every single-fault placement.\n",
    "print(result_collection.max())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
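    "# Validation loss under random Pauli noise\n",
    "\n",
    "Every gate independently suffers a Pauli X, Y or Z fault on its wire, each with probability $p$ (no fault with probability $1-3p$). $p$ is drawn from a normal distribution clipped to $[0, 1/3]$, and the validation MSE is summarized by its mean and standard deviation over repeated trials."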
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [],
   "source": [
    "def valid_test(dataflow, split, model, device, IndOnQubits=None, errorRates=None):\n",
    "    \"\"\"Return the MSE of `model` on dataflow[split] under randomly sampled Pauli faults.\n",
    "\n",
    "    A new fault dictionary is sampled for every batch: each gate listed in\n",
    "    IndOnQubits[q] independently gets a Pauli X, Y or Z fault on wire q, each with\n",
    "    probability errorRates[q] (no fault with probability 1 - 3 * errorRates[q]).\n",
    "\n",
    "    Args:\n",
    "        dataflow: mapping from split name to an iterable of batch dicts with\n",
    "            \"states\" and \"Xlabel\" tensors.\n",
    "        split: key of `dataflow` to evaluate, e.g. 'valid'.\n",
    "        model: callable invoked as model(inputs, fault_dict) for every batch.\n",
    "        device: device the batch tensors are moved to.\n",
    "        IndOnQubits: per-wire sequences of gate queue indices (row q lists the\n",
    "            gates on wire q); a 2-D array or a list of rows of different lengths.\n",
    "            None (the default) evaluates without faults, so the earlier\n",
    "            four-argument calls of valid_test keep working.\n",
    "        errorRates: per-wire probability of each Pauli type, with\n",
    "            0 <= 3 * errorRates[q] <= 1; required when IndOnQubits is given.\n",
    "\n",
    "    Returns:\n",
    "        Mean-squared error over all batches of the split, as a Python float.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if IndOnQubits is given without one error rate per row.\n",
    "    \"\"\"\n",
    "    if IndOnQubits is not None and (errorRates is None or len(errorRates) < len(IndOnQubits)):\n",
    "        raise ValueError('errorRates must provide one rate per row of IndOnQubits')\n",
    "\n",
    "    def randomSample(IndOnQubits, errorRates):\n",
    "        # Sample one fault dictionary {queue index: [Pauli op, [wire]]}.\n",
    "        choices = [None, tq.PauliX(), tq.PauliY(), tq.PauliZ()]\n",
    "        fault_dict = dict()\n",
    "        for wire in range(len(IndOnQubits)):\n",
    "            gates = np.asarray(IndOnQubits[wire])\n",
    "            rate = errorRates[wire]\n",
    "            sampled = np.random.choice(choices, size=gates.shape, p=[1 - rate * 3, rate, rate, rate])\n",
    "            # If a queue index appears in several rows, a fault sampled for a later\n",
    "            # row overwrites the earlier one, so each gate keeps at most one fault.\n",
    "            for gate, op in zip(gates.ravel(), sampled.ravel()):\n",
    "                if op is not None:\n",
    "                    fault_dict[int(gate)] = [op, [wire]]\n",
    "        return fault_dict\n",
    "\n",
    "    target_all = []\n",
    "    output_all = []\n",
    "    with torch.no_grad():\n",
    "        for feed_dict in dataflow[split]:\n",
    "            inputs = feed_dict[\"states\"].to(device).to(torch.complex64)\n",
    "            targets = feed_dict[\"Xlabel\"].to(device).to(torch.float)\n",
    "\n",
    "            # Fresh faults per batch; an empty dict means a noiseless pass.\n",
    "            fs = {} if IndOnQubits is None else randomSample(IndOnQubits, errorRates)\n",
    "            outputs = model(inputs, fs)\n",
    "\n",
    "            target_all.append(targets)\n",
    "            output_all.append(outputs)\n",
    "        target_all = torch.cat(target_all, dim=0)\n",
    "        output_all = torch.cat(output_all, dim=0)\n",
    "\n",
    "    loss = F.mse_loss(output_all, target_all)\n",
    "\n",
    "    return loss.item()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Queue indices of each qubit:\n",
      "[ 0  3  4  7  8 11]\n",
      "[ 1  2  5  6  9 10]\n"
     ]
    }
   ],
   "source": [
    "dataflow = torch.load('dataflow_sin2x1cosx2.pt')\n",
    "qnn = torch.load('models/sincosRegression_ours.pt')\n",
    "\n",
    "model = QModel(n_wires=args.n_wires, n_blocks=args.n_layers, qnn=qnn).to(device)\n",
    "\n",
    "# Group the circuit's queue indices by wire: row q lists, in queue order, the gates\n",
    "# whose last wire index (model.qnn.ind[i][-1]) is q. The row count is derived from\n",
    "# the circuit instead of being hard-coded, so other wire counts work too.\n",
    "n_rows = max((q[-1] for q in model.qnn.ind), default=-1) + 1\n",
    "IndOnQubits = [[] for _ in range(n_rows)]\n",
    "for i, q in enumerate(model.qnn.ind):\n",
    "    ind = q[-1]\n",
    "    IndOnQubits[ind].append(i)\n",
    "# np.array gives a 2-D (wires x gates) array only when every wire carries the same\n",
    "# number of gates, as in the circuit printed below.\n",
    "IndOnQubits = np.array(IndOnQubits)\n",
    "print('Queue indices of each qubit:', *IndOnQubits, sep='\\n')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.1345624104142189 0.08752446743015553\n"
     ]
    }
   ],
   "source": [
    "# Per-gate Pauli error rate p ~ Normal(mu, std), clipped to [0, 1/3] so that the\n",
    "# no-fault probability 1 - 3p stays valid. Settings tried earlier:\n",
    "#   (mu, std) = (0.001, 0.0005), (0.005, 0.002);\n",
    "#   ers = [0.001, 0.0005], [0.005, 0.001], [0.02, 0.01].\n",
    "mu, std = 0.01, 0.005\n",
    "n_trials = 10\n",
    "np.random.seed(0)  # reproducible rates and fault placements\n",
    "\n",
    "result_collection = []\n",
    "for _ in range(n_trials):\n",
    "    p = np.random.randn() * std + mu\n",
    "    p = min(max(p, 0), 1/3)\n",
    "    # Same rate on every wire: one entry per row of IndOnQubits.\n",
    "    result_collection.append(valid_test(dataflow, 'valid', model, device, IndOnQubits, [p] * len(IndOnQubits)))\n",
    "result_collection = np.array(result_collection)\n",
    "print(result_collection.mean(), result_collection.std())"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "TorchQuantum",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.12"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
