{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 175,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Import libraries\n",
    "import numpy as np\n",
    "import paddle\n",
    "import paddle_quantum\n",
    "from paddle_quantum.ansatz import Circuit\n",
    "from paddle_quantum.state import zero_state\n",
    "from numpy import pi as PI\n",
    "from paddle import matmul, transpose, reshape\n",
    "from paddle_quantum.qinfo import pauli_str_to_matrix\n",
    "from paddle_quantum.linalg import dagger\n",
    "from paddle_quantum.dataset import *\n",
    "from paddle_quantum.loss import ExpecVal\n",
    "\n",
    "from matplotlib import pyplot as plt\n",
    "import time"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 176,
   "metadata": {},
   "outputs": [],
   "source": [
    "def Observable(n):\n",
    "    r\"\"\"\n",
    "    :param n: bit number\n",
    "    :return: local observable: Z \\otimes I \\otimes ...\\otimes I\n",
    "    \"\"\"\n",
    "    Ob = pauli_str_to_matrix([[1.0, 'z0']], n)\n",
    "\n",
    "    return Ob"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 177,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Opt_Classifier(paddle_quantum.Operator):\n",
    "    def __init__(self, n, depth,  seed_paras=1):\n",
    "\n",
    "        super(Opt_Classifier, self).__init__()\n",
    "        self.n = n\n",
    "        self.depth = depth\n",
    "        paddle.seed(seed_paras)\n",
    "        \n",
    "        self.para = self.create_parameter(\n",
    "            shape=[(depth)*n*2],\n",
    "            default_initializer=paddle.nn.initializer.Uniform(0, 2*np.pi),\n",
    "            dtype='float32')        \n",
    "\n",
    "\n",
    "    def forward(self, data, label):\n",
    "        \"\"\"\n",
    "        input: data: input data unitary, shape: [BATCH, 1, 2^n]\n",
    "               label: shape: [BATCH, 1]\n",
    "        \"\"\"\n",
    "        Ob = paddle.to_tensor(Observable(self.n))\n",
    "        label_pp = reshape(paddle.to_tensor(label), [-1, 1])\n",
    "        All_data = paddle.concat(data, axis=0)\n",
    "        state_in = reshape(zero_state(num_qubits=self.n).data, (-1, 1, 2**self.n))\n",
    "        state_out = state_in\n",
    "        count = 0\n",
    "        for _ in range(2):\n",
    "            for t in range(self.depth):\n",
    "                circuit = Circuit(self.n)    \n",
    "                circuit.ry([k for k in range(0, self.n)], param=self.para[count: count + self.n])\n",
    "                circuit.cnot()  \n",
    "                count += self.n\n",
    "\n",
    "            state_out = matmul(state_out, matmul(circuit.unitary_matrix().unsqueeze(0), All_data))\n",
    "\n",
    "        E_Z = matmul(matmul(state_out, Ob), transpose(paddle.conj(state_out), perm=[0, 2, 1]))      \n",
    "        state_predict = paddle.real(E_Z)[:, 0] * 0.5 + 0.5\n",
    "\n",
    "        loss = paddle.mean((state_predict - label_pp) ** 2)\n",
    "        is_correct = (paddle.abs(state_predict - label_pp) < 0.5).nonzero().shape[0]\n",
    "        acc = is_correct / label.shape[0]\n",
    "\n",
    "        return loss, acc, state_predict.numpy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 178,
   "metadata": {},
   "outputs": [],
   "source": [
    "def QClassifier(quantum_train_x, train_y, quantum_test_x, test_y, N, DEPTH, EPOCH, LR, seed):\n",
    "\n",
    "    net = Opt_Classifier(n=N, depth=DEPTH, seed_paras=seed)\n",
    "    summary_iter, summary_test_acc, summary_train_acc = [], [], []\n",
    "    summary_test_loss, summary_train_loss = [], []\n",
    "\n",
    "    # SGD \n",
    "    opt = paddle.optimizer.SGD(learning_rate=LR, parameters=net.parameters())\n",
    "    \n",
    "    # optimize\n",
    "    for ep in range(EPOCH):        \n",
    "\n",
    "        loss, train_acc, state_predict_useless= net(data=quantum_train_x ,label=train_y)\n",
    "        loss_useless, test_acc, state_predict_useless = net(data=quantum_test_x ,label=test_y)\n",
    "        loss.backward()\n",
    "        opt.minimize(loss)\n",
    "        opt.clear_grad()\n",
    "\n",
    "        if ep % 50 == 0:\n",
    "            print(\"epoch:\", ep,\n",
    "                    \"loss: %.4f\" % loss.numpy(),\n",
    "                    \"train acc: %.4f\" % train_acc,\n",
    "                    \"test acc: %.4f\" % test_acc)\n",
    "        \n",
    "        summary_train_loss.append(loss[0].item())  \n",
    "        summary_test_loss.append(loss_useless[0].item())\n",
    "        \n",
    "        summary_train_acc.append(train_acc)  \n",
    "        summary_test_acc.append(test_acc)       \n",
    "\n",
    "    return summary_train_loss, summary_test_loss, summary_train_acc, summary_test_acc"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 179,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[15918 67086 33505 58941 44123]\n"
     ]
    }
   ],
   "source": [
    "v = 5\n",
    "ita = 0.01\n",
    "ep = 1000\n",
    "seed = np.random.randint(0, high=1e5, size=[v], dtype=int)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 180,
   "metadata": {},
   "outputs": [],
   "source": [
    "def test_acc(testing_data_num, classes):\n",
    "    num_qubit = 4\n",
    "\n",
    "    val_dataset = MNIST(mode='test', encoding='angle_encoding', num_qubits=num_qubit, classes=classes,\n",
    "                        data_num=testing_data_num,need_cropping=True,\n",
    "                        downscaling_method='resize', target_dimension=16,return_state=True, seed=0)\n",
    "\n",
    "    quantum_test_x, test_circuit, test_y = val_dataset.quantum_image_states, val_dataset.quantum_image_circuits,val_dataset.labels\n",
    "    \n",
    "    return quantum_test_x, test_circuit, test_y\n",
    "\n",
    "\n",
    "def train_acc(training_data_num, classes):\n",
    "    num_qubit = 4\n",
    "\n",
    "\n",
    "    train_accset = MNIST(mode='train', encoding='angle_encoding', num_qubits=num_qubit, classes=classes,\n",
    "                        data_num=training_data_num,need_cropping=True,\n",
    "                        downscaling_method='resize', target_dimension=16, return_state=True, seed=0)\n",
    "\n",
    "    quantum_train_x, train_circuit, train_y = train_accset.quantum_image_states, train_accset.quantum_image_circuits, train_accset.labels\n",
    "\n",
    "    \n",
    "    return quantum_train_x, train_circuit, train_y"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 181,
   "metadata": {},
   "outputs": [],
   "source": [
    "quantum_test_x, test_circuit, test_y = test_acc(2000, [0, 1]) \n",
    "quantum_train_x, train_circuit, train_y = train_acc(500, [0, 1]) "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 182,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_x = []\n",
    "for i in range(0, len(train_y)):\n",
    "    cir_train = Circuit(4)\n",
    "    cir_train.extend(train_circuit[i][0])\n",
    "    train_x.append(cir_train.unitary_matrix().unsqueeze(0))\n",
    "\n",
    "test_x = []\n",
    "for i in range(len(test_y)):\n",
    "    cir_test = Circuit(4)\n",
    "    cir_test.extend(test_circuit[i][0])\n",
    "    test_x.append(cir_test.unitary_matrix().unsqueeze(0))  \n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 183,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 0 loss: 0.2141 train acc: 0.7180 test acc: 0.7170\n",
      "epoch: 50 loss: 0.2092 train acc: 0.7160 test acc: 0.7170\n",
      "epoch: 100 loss: 0.2056 train acc: 0.7120 test acc: 0.7140\n",
      "epoch: 150 loss: 0.2029 train acc: 0.7040 test acc: 0.7115\n",
      "epoch: 200 loss: 0.2009 train acc: 0.7120 test acc: 0.7065\n",
      "epoch: 250 loss: 0.1994 train acc: 0.7140 test acc: 0.7070\n",
      "epoch: 300 loss: 0.1982 train acc: 0.7140 test acc: 0.7100\n",
      "epoch: 350 loss: 0.1972 train acc: 0.7080 test acc: 0.7085\n",
      "epoch: 400 loss: 0.1965 train acc: 0.7080 test acc: 0.7070\n",
      "epoch: 450 loss: 0.1959 train acc: 0.7060 test acc: 0.7070\n",
      "epoch: 500 loss: 0.1954 train acc: 0.7060 test acc: 0.7065\n",
      "epoch: 550 loss: 0.1950 train acc: 0.7060 test acc: 0.7080\n",
      "epoch: 600 loss: 0.1947 train acc: 0.7060 test acc: 0.7070\n",
      "epoch: 650 loss: 0.1944 train acc: 0.7060 test acc: 0.7070\n",
      "epoch: 700 loss: 0.1941 train acc: 0.7060 test acc: 0.7075\n",
      "epoch: 750 loss: 0.1939 train acc: 0.7120 test acc: 0.7075\n",
      "epoch: 800 loss: 0.1937 train acc: 0.7140 test acc: 0.7075\n",
      "epoch: 850 loss: 0.1935 train acc: 0.7140 test acc: 0.7075\n",
      "epoch: 900 loss: 0.1934 train acc: 0.7140 test acc: 0.7060\n",
      "epoch: 950 loss: 0.1932 train acc: 0.7160 test acc: 0.7060\n",
      "epoch: 0 loss: 0.2368 train acc: 0.5680 test acc: 0.5940\n",
      "epoch: 50 loss: 0.2315 train acc: 0.5760 test acc: 0.6060\n",
      "epoch: 100 loss: 0.2268 train acc: 0.5780 test acc: 0.6160\n",
      "epoch: 150 loss: 0.2224 train acc: 0.6380 test acc: 0.6570\n",
      "epoch: 200 loss: 0.2185 train acc: 0.6660 test acc: 0.6595\n",
      "epoch: 250 loss: 0.2151 train acc: 0.6740 test acc: 0.6640\n",
      "epoch: 300 loss: 0.2120 train acc: 0.6740 test acc: 0.6675\n",
      "epoch: 350 loss: 0.2094 train acc: 0.6720 test acc: 0.6725\n",
      "epoch: 400 loss: 0.2071 train acc: 0.6780 test acc: 0.6795\n",
      "epoch: 450 loss: 0.2051 train acc: 0.6900 test acc: 0.6860\n",
      "epoch: 500 loss: 0.2034 train acc: 0.6940 test acc: 0.6935\n",
      "epoch: 550 loss: 0.2020 train acc: 0.7140 test acc: 0.7105\n",
      "epoch: 600 loss: 0.2008 train acc: 0.7180 test acc: 0.7105\n",
      "epoch: 650 loss: 0.1997 train acc: 0.7180 test acc: 0.7105\n",
      "epoch: 700 loss: 0.1988 train acc: 0.7200 test acc: 0.7100\n",
      "epoch: 750 loss: 0.1981 train acc: 0.7220 test acc: 0.7100\n",
      "epoch: 800 loss: 0.1974 train acc: 0.7180 test acc: 0.7105\n",
      "epoch: 850 loss: 0.1968 train acc: 0.7180 test acc: 0.7090\n",
      "epoch: 900 loss: 0.1964 train acc: 0.7200 test acc: 0.7075\n",
      "epoch: 950 loss: 0.1959 train acc: 0.7180 test acc: 0.7075\n",
      "epoch: 0 loss: 0.1923 train acc: 0.7200 test acc: 0.7000\n",
      "epoch: 50 loss: 0.1922 train acc: 0.7220 test acc: 0.7000\n",
      "epoch: 100 loss: 0.1921 train acc: 0.7220 test acc: 0.7000\n",
      "epoch: 150 loss: 0.1920 train acc: 0.7220 test acc: 0.7010\n",
      "epoch: 200 loss: 0.1920 train acc: 0.7220 test acc: 0.7015\n",
      "epoch: 250 loss: 0.1919 train acc: 0.7220 test acc: 0.7015\n",
      "epoch: 300 loss: 0.1918 train acc: 0.7240 test acc: 0.7015\n",
      "epoch: 350 loss: 0.1918 train acc: 0.7240 test acc: 0.7015\n",
      "epoch: 400 loss: 0.1917 train acc: 0.7260 test acc: 0.7015\n",
      "epoch: 450 loss: 0.1917 train acc: 0.7260 test acc: 0.7015\n",
      "epoch: 500 loss: 0.1916 train acc: 0.7260 test acc: 0.7005\n",
      "epoch: 550 loss: 0.1915 train acc: 0.7260 test acc: 0.7005\n",
      "epoch: 600 loss: 0.1915 train acc: 0.7260 test acc: 0.7015\n",
      "epoch: 650 loss: 0.1914 train acc: 0.7260 test acc: 0.7025\n",
      "epoch: 700 loss: 0.1914 train acc: 0.7260 test acc: 0.7025\n",
      "epoch: 750 loss: 0.1913 train acc: 0.7260 test acc: 0.7025\n",
      "epoch: 800 loss: 0.1913 train acc: 0.7260 test acc: 0.7020\n",
      "epoch: 850 loss: 0.1912 train acc: 0.7260 test acc: 0.7020\n",
      "epoch: 900 loss: 0.1912 train acc: 0.7260 test acc: 0.7020\n",
      "epoch: 950 loss: 0.1911 train acc: 0.7260 test acc: 0.7025\n",
      "epoch: 0 loss: 0.3362 train acc: 0.3120 test acc: 0.3185\n",
      "epoch: 50 loss: 0.3100 train acc: 0.3120 test acc: 0.3155\n",
      "epoch: 100 loss: 0.2884 train acc: 0.3560 test acc: 0.3525\n",
      "epoch: 150 loss: 0.2709 train acc: 0.3840 test acc: 0.4070\n",
      "epoch: 200 loss: 0.2568 train acc: 0.5680 test acc: 0.5715\n",
      "epoch: 250 loss: 0.2455 train acc: 0.6480 test acc: 0.6295\n",
      "epoch: 300 loss: 0.2365 train acc: 0.7220 test acc: 0.7135\n",
      "epoch: 350 loss: 0.2293 train acc: 0.7160 test acc: 0.7090\n",
      "epoch: 400 loss: 0.2236 train acc: 0.7200 test acc: 0.7075\n",
      "epoch: 450 loss: 0.2190 train acc: 0.7240 test acc: 0.7025\n",
      "epoch: 500 loss: 0.2154 train acc: 0.7220 test acc: 0.7020\n",
      "epoch: 550 loss: 0.2124 train acc: 0.7160 test acc: 0.7005\n",
      "epoch: 600 loss: 0.2100 train acc: 0.7100 test acc: 0.7000\n",
      "epoch: 650 loss: 0.2081 train acc: 0.7080 test acc: 0.6985\n",
      "epoch: 700 loss: 0.2065 train acc: 0.7080 test acc: 0.7000\n",
      "epoch: 750 loss: 0.2052 train acc: 0.7120 test acc: 0.7005\n",
      "epoch: 800 loss: 0.2041 train acc: 0.7120 test acc: 0.6995\n",
      "epoch: 850 loss: 0.2032 train acc: 0.7120 test acc: 0.7000\n",
      "epoch: 900 loss: 0.2025 train acc: 0.7140 test acc: 0.6990\n",
      "epoch: 950 loss: 0.2018 train acc: 0.7160 test acc: 0.7000\n",
      "epoch: 0 loss: 0.2610 train acc: 0.5000 test acc: 0.4960\n",
      "epoch: 50 loss: 0.2466 train acc: 0.6560 test acc: 0.6780\n",
      "epoch: 100 loss: 0.2353 train acc: 0.7140 test acc: 0.7260\n",
      "epoch: 150 loss: 0.2266 train acc: 0.7160 test acc: 0.7230\n",
      "epoch: 200 loss: 0.2199 train acc: 0.7220 test acc: 0.7215\n",
      "epoch: 250 loss: 0.2147 train acc: 0.7280 test acc: 0.7195\n",
      "epoch: 300 loss: 0.2107 train acc: 0.7320 test acc: 0.7185\n",
      "epoch: 350 loss: 0.2076 train acc: 0.7320 test acc: 0.7155\n",
      "epoch: 400 loss: 0.2052 train acc: 0.7340 test acc: 0.7145\n",
      "epoch: 450 loss: 0.2034 train acc: 0.7360 test acc: 0.7120\n",
      "epoch: 500 loss: 0.2019 train acc: 0.7320 test acc: 0.7110\n",
      "epoch: 550 loss: 0.2007 train acc: 0.7280 test acc: 0.7085\n",
      "epoch: 600 loss: 0.1998 train acc: 0.7160 test acc: 0.7095\n",
      "epoch: 650 loss: 0.1991 train acc: 0.7140 test acc: 0.7090\n",
      "epoch: 700 loss: 0.1985 train acc: 0.7120 test acc: 0.7085\n",
      "epoch: 750 loss: 0.1980 train acc: 0.7120 test acc: 0.7075\n",
      "epoch: 800 loss: 0.1976 train acc: 0.7100 test acc: 0.7070\n",
      "epoch: 850 loss: 0.1973 train acc: 0.7080 test acc: 0.7065\n",
      "epoch: 900 loss: 0.1970 train acc: 0.7060 test acc: 0.7060\n",
      "epoch: 950 loss: 0.1968 train acc: 0.7060 test acc: 0.7045\n",
      "epoch: 0 loss: 0.2494 train acc: 0.5680 test acc: 0.5740\n",
      "epoch: 50 loss: 0.2464 train acc: 0.5820 test acc: 0.5820\n",
      "epoch: 100 loss: 0.2433 train acc: 0.5820 test acc: 0.5875\n",
      "epoch: 150 loss: 0.2402 train acc: 0.5820 test acc: 0.5910\n",
      "epoch: 200 loss: 0.2371 train acc: 0.5960 test acc: 0.5950\n",
      "epoch: 250 loss: 0.2339 train acc: 0.6320 test acc: 0.6250\n",
      "epoch: 300 loss: 0.2308 train acc: 0.6440 test acc: 0.6335\n",
      "epoch: 350 loss: 0.2276 train acc: 0.6520 test acc: 0.6375\n",
      "epoch: 400 loss: 0.2244 train acc: 0.6720 test acc: 0.6450\n",
      "epoch: 450 loss: 0.2213 train acc: 0.6840 test acc: 0.6605\n",
      "epoch: 500 loss: 0.2183 train acc: 0.6880 test acc: 0.6665\n",
      "epoch: 550 loss: 0.2154 train acc: 0.6840 test acc: 0.6755\n",
      "epoch: 600 loss: 0.2127 train acc: 0.6920 test acc: 0.6815\n",
      "epoch: 650 loss: 0.2101 train acc: 0.6940 test acc: 0.6890\n",
      "epoch: 700 loss: 0.2077 train acc: 0.7120 test acc: 0.6975\n",
      "epoch: 750 loss: 0.2054 train acc: 0.7280 test acc: 0.7110\n",
      "epoch: 800 loss: 0.2034 train acc: 0.7240 test acc: 0.7120\n",
      "epoch: 850 loss: 0.2016 train acc: 0.7280 test acc: 0.7145\n",
      "epoch: 900 loss: 0.2000 train acc: 0.7200 test acc: 0.7165\n",
      "epoch: 950 loss: 0.1985 train acc: 0.7280 test acc: 0.7150\n",
      "epoch: 0 loss: 0.2396 train acc: 0.6320 test acc: 0.6420\n",
      "epoch: 50 loss: 0.2352 train acc: 0.6300 test acc: 0.6480\n",
      "epoch: 100 loss: 0.2312 train acc: 0.6340 test acc: 0.6565\n",
      "epoch: 150 loss: 0.2277 train acc: 0.6400 test acc: 0.6620\n",
      "epoch: 200 loss: 0.2245 train acc: 0.6520 test acc: 0.6715\n",
      "epoch: 250 loss: 0.2217 train acc: 0.6600 test acc: 0.6755\n",
      "epoch: 300 loss: 0.2193 train acc: 0.6740 test acc: 0.6825\n",
      "epoch: 350 loss: 0.2171 train acc: 0.6820 test acc: 0.6880\n",
      "epoch: 400 loss: 0.2152 train acc: 0.6880 test acc: 0.6955\n",
      "epoch: 450 loss: 0.2136 train acc: 0.6880 test acc: 0.6960\n",
      "epoch: 500 loss: 0.2121 train acc: 0.6960 test acc: 0.7015\n",
      "epoch: 550 loss: 0.2108 train acc: 0.6960 test acc: 0.6995\n",
      "epoch: 600 loss: 0.2096 train acc: 0.6980 test acc: 0.6985\n",
      "epoch: 650 loss: 0.2085 train acc: 0.6980 test acc: 0.6995\n",
      "epoch: 700 loss: 0.2076 train acc: 0.7000 test acc: 0.7020\n",
      "epoch: 750 loss: 0.2067 train acc: 0.6980 test acc: 0.7020\n",
      "epoch: 800 loss: 0.2059 train acc: 0.7000 test acc: 0.7025\n",
      "epoch: 850 loss: 0.2051 train acc: 0.7000 test acc: 0.7030\n",
      "epoch: 900 loss: 0.2044 train acc: 0.7000 test acc: 0.7025\n",
      "epoch: 950 loss: 0.2038 train acc: 0.7000 test acc: 0.7025\n",
      "epoch: 0 loss: 0.2257 train acc: 0.7000 test acc: 0.7160\n",
      "epoch: 50 loss: 0.2203 train acc: 0.6980 test acc: 0.7175\n",
      "epoch: 100 loss: 0.2161 train acc: 0.7020 test acc: 0.7170\n",
      "epoch: 150 loss: 0.2128 train acc: 0.6980 test acc: 0.7160\n",
      "epoch: 200 loss: 0.2101 train acc: 0.7080 test acc: 0.7150\n",
      "epoch: 250 loss: 0.2080 train acc: 0.7120 test acc: 0.7115\n",
      "epoch: 300 loss: 0.2063 train acc: 0.7100 test acc: 0.7100\n",
      "epoch: 350 loss: 0.2049 train acc: 0.7100 test acc: 0.7105\n",
      "epoch: 400 loss: 0.2037 train acc: 0.7120 test acc: 0.7115\n",
      "epoch: 450 loss: 0.2026 train acc: 0.7120 test acc: 0.7100\n",
      "epoch: 500 loss: 0.2017 train acc: 0.7120 test acc: 0.7085\n",
      "epoch: 550 loss: 0.2010 train acc: 0.7100 test acc: 0.7070\n",
      "epoch: 600 loss: 0.2003 train acc: 0.7100 test acc: 0.7075\n",
      "epoch: 650 loss: 0.1996 train acc: 0.7100 test acc: 0.7065\n",
      "epoch: 700 loss: 0.1991 train acc: 0.7060 test acc: 0.7065\n",
      "epoch: 750 loss: 0.1986 train acc: 0.7080 test acc: 0.7055\n",
      "epoch: 800 loss: 0.1981 train acc: 0.7080 test acc: 0.7055\n",
      "epoch: 850 loss: 0.1977 train acc: 0.7080 test acc: 0.7070\n",
      "epoch: 900 loss: 0.1972 train acc: 0.7100 test acc: 0.7050\n",
      "epoch: 950 loss: 0.1969 train acc: 0.7080 test acc: 0.7040\n",
      "epoch: 0 loss: 0.2234 train acc: 0.7280 test acc: 0.7160\n",
      "epoch: 50 loss: 0.2198 train acc: 0.7220 test acc: 0.7135\n",
      "epoch: 100 loss: 0.2170 train acc: 0.7220 test acc: 0.7135\n",
      "epoch: 150 loss: 0.2149 train acc: 0.7200 test acc: 0.7090\n",
      "epoch: 200 loss: 0.2133 train acc: 0.7220 test acc: 0.7080\n",
      "epoch: 250 loss: 0.2120 train acc: 0.7200 test acc: 0.7080\n",
      "epoch: 300 loss: 0.2110 train acc: 0.7180 test acc: 0.7085\n",
      "epoch: 350 loss: 0.2102 train acc: 0.7200 test acc: 0.7090\n",
      "epoch: 400 loss: 0.2095 train acc: 0.7220 test acc: 0.7075\n",
      "epoch: 450 loss: 0.2090 train acc: 0.7220 test acc: 0.7080\n",
      "epoch: 500 loss: 0.2086 train acc: 0.7200 test acc: 0.7055\n",
      "epoch: 550 loss: 0.2082 train acc: 0.7180 test acc: 0.7060\n",
      "epoch: 600 loss: 0.2079 train acc: 0.7180 test acc: 0.7055\n",
      "epoch: 650 loss: 0.2077 train acc: 0.7200 test acc: 0.7050\n",
      "epoch: 700 loss: 0.2075 train acc: 0.7200 test acc: 0.7055\n",
      "epoch: 750 loss: 0.2073 train acc: 0.7200 test acc: 0.7050\n",
      "epoch: 800 loss: 0.2071 train acc: 0.7200 test acc: 0.7050\n",
      "epoch: 850 loss: 0.2069 train acc: 0.7200 test acc: 0.7045\n",
      "epoch: 900 loss: 0.2068 train acc: 0.7200 test acc: 0.7035\n",
      "epoch: 950 loss: 0.2067 train acc: 0.7180 test acc: 0.7030\n",
      "epoch: 0 loss: 0.5851 train acc: 0.3020 test acc: 0.3150\n",
      "epoch: 50 loss: 0.5671 train acc: 0.2940 test acc: 0.3180\n",
      "epoch: 100 loss: 0.5426 train acc: 0.2900 test acc: 0.3165\n",
      "epoch: 150 loss: 0.5120 train acc: 0.2860 test acc: 0.3170\n",
      "epoch: 200 loss: 0.4768 train acc: 0.2800 test acc: 0.3105\n",
      "epoch: 250 loss: 0.4401 train acc: 0.2840 test acc: 0.3080\n",
      "epoch: 300 loss: 0.4048 train acc: 0.2840 test acc: 0.3045\n",
      "epoch: 350 loss: 0.3732 train acc: 0.2820 test acc: 0.3005\n",
      "epoch: 400 loss: 0.3463 train acc: 0.2800 test acc: 0.2915\n",
      "epoch: 450 loss: 0.3241 train acc: 0.2800 test acc: 0.2885\n",
      "epoch: 500 loss: 0.3060 train acc: 0.2920 test acc: 0.2955\n",
      "epoch: 550 loss: 0.2914 train acc: 0.3840 test acc: 0.4060\n",
      "epoch: 600 loss: 0.2794 train acc: 0.6220 test acc: 0.6280\n",
      "epoch: 650 loss: 0.2697 train acc: 0.6580 test acc: 0.6630\n",
      "epoch: 700 loss: 0.2617 train acc: 0.6640 test acc: 0.6735\n",
      "epoch: 750 loss: 0.2551 train acc: 0.6620 test acc: 0.6785\n",
      "epoch: 800 loss: 0.2495 train acc: 0.6580 test acc: 0.6795\n",
      "epoch: 850 loss: 0.2448 train acc: 0.6620 test acc: 0.6795\n",
      "epoch: 900 loss: 0.2408 train acc: 0.6620 test acc: 0.6775\n",
      "epoch: 950 loss: 0.2374 train acc: 0.6620 test acc: 0.6780\n",
      "epoch: 0 loss: 0.5274 train acc: 0.2880 test acc: 0.3120\n",
      "epoch: 50 loss: 0.4890 train acc: 0.2760 test acc: 0.3110\n",
      "epoch: 100 loss: 0.4449 train acc: 0.2820 test acc: 0.3025\n",
      "epoch: 150 loss: 0.3991 train acc: 0.2920 test acc: 0.2980\n",
      "epoch: 200 loss: 0.3560 train acc: 0.2960 test acc: 0.2895\n",
      "epoch: 250 loss: 0.3186 train acc: 0.2800 test acc: 0.2895\n",
      "epoch: 300 loss: 0.2885 train acc: 0.3040 test acc: 0.3220\n",
      "epoch: 350 loss: 0.2652 train acc: 0.4920 test acc: 0.5140\n",
      "epoch: 400 loss: 0.2479 train acc: 0.6620 test acc: 0.6495\n",
      "epoch: 450 loss: 0.2352 train acc: 0.7000 test acc: 0.6855\n",
      "epoch: 500 loss: 0.2259 train acc: 0.7100 test acc: 0.6970\n",
      "epoch: 550 loss: 0.2191 train acc: 0.7100 test acc: 0.6945\n",
      "epoch: 600 loss: 0.2141 train acc: 0.7160 test acc: 0.6985\n",
      "epoch: 650 loss: 0.2104 train acc: 0.7160 test acc: 0.6945\n",
      "epoch: 700 loss: 0.2076 train acc: 0.7140 test acc: 0.6950\n",
      "epoch: 750 loss: 0.2055 train acc: 0.7140 test acc: 0.6975\n",
      "epoch: 800 loss: 0.2039 train acc: 0.7080 test acc: 0.6970\n",
      "epoch: 850 loss: 0.2026 train acc: 0.7100 test acc: 0.6975\n",
      "epoch: 900 loss: 0.2016 train acc: 0.7100 test acc: 0.6960\n",
      "epoch: 950 loss: 0.2008 train acc: 0.7080 test acc: 0.6960\n",
      "epoch: 0 loss: 0.2915 train acc: 0.4840 test acc: 0.4355\n",
      "epoch: 50 loss: 0.2850 train acc: 0.5020 test acc: 0.4490\n",
      "epoch: 100 loss: 0.2789 train acc: 0.5180 test acc: 0.4675\n",
      "epoch: 150 loss: 0.2731 train acc: 0.5280 test acc: 0.4815\n",
      "epoch: 200 loss: 0.2676 train acc: 0.5460 test acc: 0.4930\n",
      "epoch: 250 loss: 0.2623 train acc: 0.5620 test acc: 0.5085\n",
      "epoch: 300 loss: 0.2572 train acc: 0.5720 test acc: 0.5290\n",
      "epoch: 350 loss: 0.2524 train acc: 0.5820 test acc: 0.5475\n",
      "epoch: 400 loss: 0.2477 train acc: 0.6000 test acc: 0.5680\n",
      "epoch: 450 loss: 0.2434 train acc: 0.6080 test acc: 0.5970\n",
      "epoch: 500 loss: 0.2393 train acc: 0.6360 test acc: 0.6295\n",
      "epoch: 550 loss: 0.2355 train acc: 0.6780 test acc: 0.6610\n",
      "epoch: 600 loss: 0.2321 train acc: 0.6840 test acc: 0.6670\n",
      "epoch: 650 loss: 0.2289 train acc: 0.6900 test acc: 0.6695\n",
      "epoch: 700 loss: 0.2260 train acc: 0.6980 test acc: 0.6730\n",
      "epoch: 750 loss: 0.2235 train acc: 0.7000 test acc: 0.6790\n",
      "epoch: 800 loss: 0.2212 train acc: 0.7000 test acc: 0.6820\n",
      "epoch: 850 loss: 0.2191 train acc: 0.7000 test acc: 0.6825\n",
      "epoch: 900 loss: 0.2173 train acc: 0.7060 test acc: 0.6850\n",
      "epoch: 950 loss: 0.2157 train acc: 0.6980 test acc: 0.6875\n",
      "epoch: 0 loss: 0.3053 train acc: 0.4700 test acc: 0.4470\n",
      "epoch: 50 loss: 0.2822 train acc: 0.4940 test acc: 0.4940\n",
      "epoch: 100 loss: 0.2638 train acc: 0.5320 test acc: 0.5495\n",
      "epoch: 150 loss: 0.2495 train acc: 0.5640 test acc: 0.5895\n",
      "epoch: 200 loss: 0.2386 train acc: 0.6020 test acc: 0.6300\n",
      "epoch: 250 loss: 0.2303 train acc: 0.7060 test acc: 0.7040\n",
      "epoch: 300 loss: 0.2240 train acc: 0.7180 test acc: 0.7095\n",
      "epoch: 350 loss: 0.2193 train acc: 0.7260 test acc: 0.7075\n",
      "epoch: 400 loss: 0.2157 train acc: 0.7280 test acc: 0.7060\n",
      "epoch: 450 loss: 0.2130 train acc: 0.7280 test acc: 0.7065\n",
      "epoch: 500 loss: 0.2109 train acc: 0.7280 test acc: 0.7065\n",
      "epoch: 550 loss: 0.2093 train acc: 0.7300 test acc: 0.7075\n",
      "epoch: 600 loss: 0.2080 train acc: 0.7300 test acc: 0.7065\n",
      "epoch: 650 loss: 0.2070 train acc: 0.7320 test acc: 0.7040\n",
      "epoch: 700 loss: 0.2063 train acc: 0.7320 test acc: 0.7035\n",
      "epoch: 750 loss: 0.2057 train acc: 0.7320 test acc: 0.7020\n",
      "epoch: 800 loss: 0.2052 train acc: 0.7340 test acc: 0.7020\n",
      "epoch: 850 loss: 0.2048 train acc: 0.7300 test acc: 0.7015\n",
      "epoch: 900 loss: 0.2044 train acc: 0.7260 test acc: 0.7015\n",
      "epoch: 950 loss: 0.2041 train acc: 0.7260 test acc: 0.7005\n",
      "epoch: 0 loss: 0.2124 train acc: 0.7180 test acc: 0.7030\n",
      "epoch: 50 loss: 0.2108 train acc: 0.7140 test acc: 0.7005\n",
      "epoch: 100 loss: 0.2095 train acc: 0.7140 test acc: 0.7000\n",
      "epoch: 150 loss: 0.2084 train acc: 0.7100 test acc: 0.7000\n",
      "epoch: 200 loss: 0.2075 train acc: 0.7100 test acc: 0.7000\n",
      "epoch: 250 loss: 0.2067 train acc: 0.7120 test acc: 0.6990\n",
      "epoch: 300 loss: 0.2060 train acc: 0.7120 test acc: 0.6985\n",
      "epoch: 350 loss: 0.2053 train acc: 0.7160 test acc: 0.6990\n",
      "epoch: 400 loss: 0.2048 train acc: 0.7160 test acc: 0.6980\n",
      "epoch: 450 loss: 0.2043 train acc: 0.7160 test acc: 0.6985\n",
      "epoch: 500 loss: 0.2039 train acc: 0.7160 test acc: 0.7000\n",
      "epoch: 550 loss: 0.2035 train acc: 0.7140 test acc: 0.7015\n",
      "epoch: 600 loss: 0.2032 train acc: 0.7140 test acc: 0.7015\n",
      "epoch: 650 loss: 0.2029 train acc: 0.7160 test acc: 0.7015\n",
      "epoch: 700 loss: 0.2026 train acc: 0.7160 test acc: 0.7015\n",
      "epoch: 750 loss: 0.2023 train acc: 0.7160 test acc: 0.7015\n",
      "epoch: 800 loss: 0.2020 train acc: 0.7200 test acc: 0.7020\n",
      "epoch: 850 loss: 0.2018 train acc: 0.7180 test acc: 0.7015\n",
      "epoch: 900 loss: 0.2015 train acc: 0.7180 test acc: 0.7005\n",
      "epoch: 950 loss: 0.2013 train acc: 0.7180 test acc: 0.7005\n",
      "epoch: 0 loss: 0.2235 train acc: 0.6660 test acc: 0.6785\n",
      "epoch: 50 loss: 0.2228 train acc: 0.6680 test acc: 0.6780\n",
      "epoch: 100 loss: 0.2220 train acc: 0.6680 test acc: 0.6800\n",
      "epoch: 150 loss: 0.2213 train acc: 0.6700 test acc: 0.6805\n",
      "epoch: 200 loss: 0.2204 train acc: 0.6740 test acc: 0.6805\n",
      "epoch: 250 loss: 0.2196 train acc: 0.6760 test acc: 0.6815\n",
      "epoch: 300 loss: 0.2188 train acc: 0.6740 test acc: 0.6815\n",
      "epoch: 350 loss: 0.2179 train acc: 0.6700 test acc: 0.6815\n",
      "epoch: 400 loss: 0.2170 train acc: 0.6700 test acc: 0.6810\n",
      "epoch: 450 loss: 0.2162 train acc: 0.6700 test acc: 0.6810\n",
      "epoch: 500 loss: 0.2153 train acc: 0.6740 test acc: 0.6820\n",
      "epoch: 550 loss: 0.2144 train acc: 0.6740 test acc: 0.6845\n",
      "epoch: 600 loss: 0.2136 train acc: 0.6700 test acc: 0.6845\n",
      "epoch: 650 loss: 0.2127 train acc: 0.6700 test acc: 0.6855\n",
      "epoch: 700 loss: 0.2119 train acc: 0.6700 test acc: 0.6845\n",
      "epoch: 750 loss: 0.2110 train acc: 0.6700 test acc: 0.6855\n",
      "epoch: 800 loss: 0.2102 train acc: 0.6720 test acc: 0.6855\n",
      "epoch: 850 loss: 0.2094 train acc: 0.6720 test acc: 0.6855\n",
      "epoch: 900 loss: 0.2087 train acc: 0.6740 test acc: 0.6865\n",
      "epoch: 950 loss: 0.2079 train acc: 0.6740 test acc: 0.6865\n"
     ]
    }
   ],
   "source": [
    "train_acc_K2, test_acc_K2 = [], []\n",
    "train_acc_K8, test_acc_K8 = [], []\n",
    "train_acc_K16, test_acc_K16 = [], []\n",
    "train_acc_K4, test_acc_K4 = [], []\n",
    "\n",
    "train_loss_K2, test_loss_K2 = [], []\n",
    "train_loss_K8, test_loss_K8 = [], []\n",
    "train_loss_K16, test_loss_K16 = [], []\n",
    "train_loss_K4, test_loss_K4 = [], []\n",
    "\n",
    "for trial in range(v):\n",
    "\n",
    "    tr_loss, te_loss, tr_acc, te_acc = QClassifier(\n",
    "      train_x,\n",
    "      train_y,\n",
    "      test_x,\n",
    "      test_y,\n",
    "      N = 4,\n",
    "      DEPTH = 2,\n",
    "      EPOCH = ep,\n",
    "      LR = ita,\n",
    "      seed=seed[trial]\n",
    "    )\n",
    "\n",
    "    train_acc_K2.append(tr_acc)\n",
    "    test_acc_K2.append(te_acc)\n",
    "    \n",
    "    train_loss_K2.append(tr_loss)\n",
    "    test_loss_K2.append(te_loss)\n",
    "  \n",
    "\n",
    "    tr_loss, te_loss, tr_acc, te_acc = QClassifier(\n",
    "      train_x,\n",
    "      train_y,\n",
    "      test_x,\n",
    "      test_y,\n",
    "      N = 4,\n",
    "      DEPTH = 8,\n",
    "      EPOCH = ep,\n",
    "      LR = ita,\n",
    "      seed=seed[trial]\n",
    "    )\n",
    "\n",
    "    train_acc_K8.append(tr_acc)\n",
    "    test_acc_K8.append(te_acc)\n",
    "    \n",
    "    train_loss_K8.append(tr_loss)\n",
    "    test_loss_K8.append(te_loss)\n",
    "    \n",
    "    tr_loss, te_loss, tr_acc, te_acc = QClassifier(\n",
    "      train_x,\n",
    "      train_y,\n",
    "      test_x,\n",
    "      test_y,\n",
    "      N = 4,\n",
    "      DEPTH = 16,\n",
    "      EPOCH = ep,\n",
    "      LR = ita,\n",
    "      seed=seed[trial]\n",
    "    )\n",
    "\n",
    "    train_acc_K16.append(tr_acc)\n",
    "    test_acc_K16.append(te_acc)\n",
    "    \n",
    "    train_loss_K16.append(tr_loss)\n",
    "    test_loss_K16.append(te_loss)\n",
    "    "
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "paddle_quantum_env",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.17"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
