{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import paddle\n",
    "import paddle_quantum\n",
    "from paddle_quantum.ansatz import Circuit\n",
    "from paddle_quantum.state import zero_state\n",
    "from numpy import pi as PI\n",
    "from paddle import matmul, transpose, reshape\n",
    "from paddle_quantum.qinfo import pauli_str_to_matrix\n",
    "from paddle_quantum.linalg import dagger\n",
    "from paddle_quantum.dataset import *\n",
    "from paddle_quantum.loss import ExpecVal\n",
    "\n",
    "from matplotlib import pyplot as plt\n",
    "import time"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "metadata": {},
   "outputs": [],
   "source": [
    "def Observable(n):\n",
    "    r\"\"\"\n",
    "    Build the local observable Z \\otimes I \\otimes ... \\otimes I on an n-qubit register.\n",
    "\n",
    "    :param n: number of qubits in the register\n",
    "    :return: matrix form of the Pauli string 'Z0' on n qubits, as produced by pauli_str_to_matrix\n",
    "    \"\"\"\n",
    "    z_on_qubit0 = [[1.0, 'Z0']]  # one Pauli term: coefficient 1.0, string 'Z0'\n",
    "    return pauli_str_to_matrix(z_on_qubit0, n)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Opt_Classifier(paddle_quantum.Operator):\n",
    "    r\"\"\"\n",
    "    Variational classifier built from ``depth`` trainable layers.\n",
    "\n",
    "    Each layer is an RY rotation on every qubit followed by ``circuit.cnot()``; the layer\n",
    "    unitary is multiplied with the batch of data-encoding unitaries at every layer. The\n",
    "    prediction is (<Z_0> + 1) / 2 and the loss is the mean squared error to the labels.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, n, depth, seed_paras=1):\n",
    "        r\"\"\"\n",
    "        :param n: number of qubits\n",
    "        :param depth: number of trainable layers\n",
    "        :param seed_paras: seed passed to ``paddle.seed`` before the parameters are initialised\n",
    "        \"\"\"\n",
    "        super(Opt_Classifier, self).__init__()\n",
    "        self.n = n\n",
    "        self.depth = depth\n",
    "        paddle.seed(seed_paras)\n",
    "\n",
    "        # NOTE(review): forward() only reads the first depth * n angles; the extra 3 * n are\n",
    "        # unused but kept, since changing the shape could change the seeded initial values.\n",
    "        self.para = self.create_parameter(\n",
    "            shape=[(depth+3)*n],\n",
    "            default_initializer=paddle.nn.initializer.Uniform(0, 2*np.pi),\n",
    "            dtype='float32')\n",
    "\n",
    "    def forward(self, data, label):\n",
    "        r\"\"\"\n",
    "        Run the classifier on a full batch and score it against ``label``.\n",
    "\n",
    "        :param data: sequence of data-encoding unitaries, each of shape [1, 2^n, 2^n]\n",
    "                     (e.g. ``Circuit.unitary_matrix().unsqueeze(0)``); they are\n",
    "                     concatenated along axis 0 into a [BATCH, 2^n, 2^n] batch\n",
    "        :param label: labels with BATCH elements in total (reshaped to [BATCH, 1]);\n",
    "                      presumably 0/1, given the [0, 1] prediction range and 0.5 threshold\n",
    "        :return: (loss, acc, predictions) -- the MSE loss tensor, the fraction of samples\n",
    "                 whose prediction lies strictly within 0.5 of its label, and the\n",
    "                 predictions as a numpy array of shape [BATCH, 1]\n",
    "        \"\"\"\n",
    "        Ob = paddle.to_tensor(Observable(self.n))\n",
    "        label_pp = reshape(paddle.to_tensor(label), [-1, 1])\n",
    "        All_data = paddle.concat(data, axis=0)\n",
    "        # |0...0> as a row vector of shape [1, 1, 2^n]; broadcasts against the batch below\n",
    "        state_in = reshape(zero_state(num_qubits=self.n).data, (-1, 1, 2**self.n))\n",
    "        state_out = state_in\n",
    "        count = 0\n",
    "\n",
    "        # set up circuit: each layer consumes the next n angles of self.para\n",
    "        for _ in range(self.depth):\n",
    "            circuit = Circuit(self.n)\n",
    "            circuit.ry(list(range(self.n)), param=self.para[count: count + self.n])\n",
    "            circuit.cnot()\n",
    "            count += self.n\n",
    "            # row-vector convention: state <- state @ (U_layer @ D_batch)\n",
    "            state_out = matmul(state_out, matmul(circuit.unitary_matrix().unsqueeze(0), All_data))\n",
    "\n",
    "        # <Z_0> per sample as psi @ Ob @ psi^dagger (psi is a row vector) -> [BATCH, 1, 1]\n",
    "        E_Z = matmul(matmul(state_out, Ob), transpose(paddle.conj(state_out), perm=[0, 2, 1]))\n",
    "        # map <Z_0> in [-1, 1] to a prediction in [0, 1]\n",
    "        state_predict = paddle.real(E_Z)[:, 0] * 0.5 + 0.5\n",
    "\n",
    "        loss = paddle.mean((state_predict - label_pp) ** 2)\n",
    "        is_correct = (paddle.abs(state_predict - label_pp) < 0.5).nonzero().shape[0]\n",
    "        # count samples from the reshaped tensor so list-like labels (no .shape) also work\n",
    "        acc = is_correct / label_pp.shape[0]\n",
    "\n",
    "        return loss, acc, state_predict.numpy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "metadata": {},
   "outputs": [],
   "source": [
    "def QClassifier(quantum_train_x, train_y, quantum_test_x, test_y, N, DEPTH, EPOCH, LR, seed, log_every=50):\n",
    "    r\"\"\"\n",
    "    Train an Opt_Classifier with full-batch SGD and record per-epoch metrics.\n",
    "\n",
    "    :param quantum_train_x: training data unitaries, as accepted by Opt_Classifier.forward\n",
    "    :param train_y: training labels\n",
    "    :param quantum_test_x: test data unitaries, as accepted by Opt_Classifier.forward\n",
    "    :param test_y: test labels\n",
    "    :param N: number of qubits\n",
    "    :param DEPTH: number of classifier layers\n",
    "    :param EPOCH: number of optimisation steps (one full-batch step per epoch)\n",
    "    :param LR: SGD learning rate\n",
    "    :param seed: parameter-initialisation seed, forwarded as ``seed_paras``\n",
    "    :param log_every: print a progress line every ``log_every`` epochs (positive int, default 50)\n",
    "    :return: (test_acc, train_acc, test_loss, train_loss), one entry per epoch each; every\n",
    "             entry is measured with the parameters *before* that epoch's update\n",
    "    \"\"\"\n",
    "    net = Opt_Classifier(n=N, depth=DEPTH, seed_paras=seed)\n",
    "    summary_test_acc, summary_train_acc = [], []\n",
    "    summary_test_loss, summary_train_loss = [], []\n",
    "\n",
    "    # SGD\n",
    "    opt = paddle.optimizer.SGD(learning_rate=LR, parameters=net.parameters())\n",
    "\n",
    "    # optimize\n",
    "    for ep in range(EPOCH):\n",
    "        loss, train_acc, _ = net(data=quantum_train_x, label=train_y)\n",
    "        # the test pass is only monitored, never differentiated: don't build its autograd graph\n",
    "        with paddle.no_grad():\n",
    "            test_loss, test_acc, _ = net(data=quantum_test_x, label=test_y)\n",
    "\n",
    "        # .item() works for both shape-[1] and 0-D loss tensors (loss[0] fails on 0-D)\n",
    "        train_loss_value = loss.item()\n",
    "        test_loss_value = test_loss.item()\n",
    "\n",
    "        # dygraph pattern: backward() fills the gradients, minimize() applies the SGD update\n",
    "        loss.backward()\n",
    "        opt.minimize(loss)\n",
    "        opt.clear_grad()\n",
    "\n",
    "        if ep % log_every == 0:\n",
    "            print(\"epoch:\", ep,\n",
    "                  \"loss: %.4f\" % train_loss_value,\n",
    "                  \"train acc: %.4f\" % train_acc,\n",
    "                  \"test acc: %.4f\" % test_acc)\n",
    "\n",
    "        summary_train_loss.append(train_loss_value)\n",
    "        summary_test_loss.append(test_loss_value)\n",
    "        summary_train_acc.append(train_acc)\n",
    "        summary_test_acc.append(test_acc)\n",
    "\n",
    "    return summary_test_acc, summary_train_acc, summary_test_loss, summary_train_loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[52434 11598 92753 17886 15039 65757 50175 47641 27634 31686]\n"
     ]
    }
   ],
   "source": [
    "# Experiment configuration (presumably consumed as seeds / EPOCH / DEPTH / LR by QClassifier)\n",
    "v = 10       # number of independent runs = number of initialisation seeds drawn below\n",
    "ep = 1000    # training epochs per run\n",
    "D = 16       # classifier depth\n",
    "eta = 0.1    # SGD learning rate\n",
    "\n",
    "# Draw the per-run seeds from a fixed master seed so Restart & Run All reproduces them.\n",
    "SEED_MASTER = 2022\n",
    "seed = np.random.default_rng(SEED_MASTER).integers(0, 100000, size=v)\n",
    "print(seed)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _fashion_mnist_split(mode, data_num, classes, seed, num_qubit=4):\n",
    "    r\"\"\"\n",
    "    Load one FashionMNIST split with the settings shared by the train and test loaders.\n",
    "\n",
    "    :param mode: 'train' or 'test', forwarded to FashionMNIST\n",
    "    :param data_num: number of samples, forwarded as ``data_num``\n",
    "    :param classes: classes to keep, forwarded as ``classes``\n",
    "    :param seed: forwarded as FashionMNIST's ``seed``\n",
    "    :param num_qubit: number of encoding qubits (default 4)\n",
    "    :return: (quantum_image_states, quantum_image_circuits, labels) of the loaded dataset\n",
    "    \"\"\"\n",
    "    dataset = FashionMNIST(mode=mode, encoding='angle_encoding', num_qubits=num_qubit, classes=classes,\n",
    "                           data_num=data_num,\n",
    "                           downscaling_method='resize', target_dimension=16, return_state=True, seed=seed)\n",
    "    return dataset.quantum_image_states, dataset.quantum_image_circuits, dataset.labels\n",
    "\n",
    "\n",
    "def test_data(testing_data_num, classes, seed=788):\n",
    "    r\"\"\"\n",
    "    Load the FashionMNIST test split.\n",
    "\n",
    "    :param testing_data_num: number of test samples\n",
    "    :param classes: classes to keep, e.g. [0, 1]\n",
    "    :param seed: dataset seed (default 788, the previously hard-coded value)\n",
    "    :return: (quantum_test_x, test_circuit, test_y)\n",
    "    \"\"\"\n",
    "    return _fashion_mnist_split('test', testing_data_num, classes, seed)\n",
    "\n",
    "\n",
    "def train_data(training_data_num, classes, seed=6):\n",
    "    r\"\"\"\n",
    "    Load the FashionMNIST training split.\n",
    "\n",
    "    :param training_data_num: number of training samples\n",
    "    :param classes: classes to keep, e.g. [0, 1]\n",
    "    :param seed: dataset seed (default 6, the previously hard-coded value)\n",
    "    :return: (quantum_train_x, train_circuit, train_y)\n",
    "    \"\"\"\n",
    "    return _fashion_mnist_split('train', training_data_num, classes, seed)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Test split: 2000 samples from classes [0, 1]\n",
    "quantum_test_x, test_circuit, test_y = test_data(testing_data_num=2000, classes=[0, 1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _data_unitary(encoding_circuit, num_qubits=4):\n",
    "    \"\"\"Return the unitary of a fresh ``num_qubits`` circuit extended by ``encoding_circuit``, with a leading axis of size 1.\"\"\"\n",
    "    encoder = Circuit(num_qubits)\n",
    "    encoder.extend(encoding_circuit)\n",
    "    return encoder.unitary_matrix().unsqueeze(0)\n",
    "\n",
    "\n",
    "# One data unitary per test sample; Opt_Classifier.forward concatenates them along axis 0.\n",
    "test_x = [_data_unitary(test_circuit[idx][0]) for idx in range(len(test_y))]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch: 0 loss: 0.3327 train acc: 0.4600 test acc: 0.4840\n",
      "epoch: 50 loss: 0.2344 train acc: 0.6500 test acc: 0.5030\n",
      "epoch: 100 loss: 0.2009 train acc: 0.6900 test acc: 0.5255\n",
      "epoch: 150 loss: 0.1877 train acc: 0.7000 test acc: 0.5190\n",
      "epoch: 200 loss: 0.1799 train acc: 0.7300 test acc: 0.5220\n",
      "epoch: 250 loss: 0.1742 train acc: 0.7400 test acc: 0.5220\n",
      "epoch: 300 loss: 0.1697 train acc: 0.7300 test acc: 0.5190\n",
      "epoch: 350 loss: 0.1660 train acc: 0.7400 test acc: 0.5200\n",
      "epoch: 400 loss: 0.1629 train acc: 0.7300 test acc: 0.5285\n",
      "epoch: 450 loss: 0.1604 train acc: 0.7300 test acc: 0.5315\n",
      "epoch: 500 loss: 0.1583 train acc: 0.7300 test acc: 0.5350\n",
      "epoch: 550 loss: 0.1566 train acc: 0.7400 test acc: 0.5355\n",
      "epoch: 600 loss: 0.1551 train acc: 0.7600 test acc: 0.5380\n",
      "epoch: 650 loss: 0.1539 train acc: 0.7600 test acc: 0.5410\n",
      "epoch: 700 loss: 0.1528 train acc: 0.7700 test acc: 0.5400\n",
      "epoch: 750 loss: 0.1518 train acc: 0.7700 test acc: 0.5440\n",
      "epoch: 800 loss: 0.1510 train acc: 0.7800 test acc: 0.5450\n",
      "epoch: 850 loss: 0.1502 train acc: 0.7800 test acc: 0.5460\n",
      "epoch: 900 loss: 0.1495 train acc: 0.7800 test acc: 0.5460\n",
      "epoch: 950 loss: 0.1488 train acc: 0.7900 test acc: 0.5465\n",
      "epoch: 0 loss: 0.3165 train acc: 0.4460 test acc: 0.4840\n",
      "epoch: 50 loss: 0.2719 train acc: 0.5420 test acc: 0.5325\n",
      "epoch: 100 loss: 0.2498 train acc: 0.6020 test acc: 0.5610\n",
      "epoch: 150 loss: 0.2381 train acc: 0.6280 test acc: 0.5880\n",
      "epoch: 200 loss: 0.2295 train acc: 0.6380 test acc: 0.5945\n",
      "epoch: 250 loss: 0.2218 train acc: 0.6560 test acc: 0.6115\n",
      "epoch: 300 loss: 0.2149 train acc: 0.6580 test acc: 0.6185\n",
      "epoch: 350 loss: 0.2093 train acc: 0.6640 test acc: 0.6255\n",
      "epoch: 400 loss: 0.2048 train acc: 0.6840 test acc: 0.6325\n",
      "epoch: 450 loss: 0.2014 train acc: 0.6920 test acc: 0.6415\n",
      "epoch: 500 loss: 0.1986 train acc: 0.7020 test acc: 0.6420\n",
      "epoch: 550 loss: 0.1964 train acc: 0.7180 test acc: 0.6410\n",
      "epoch: 600 loss: 0.1945 train acc: 0.7240 test acc: 0.6400\n",
      "epoch: 650 loss: 0.1928 train acc: 0.7140 test acc: 0.6365\n",
      "epoch: 700 loss: 0.1913 train acc: 0.7060 test acc: 0.6375\n",
      "epoch: 750 loss: 0.1899 train acc: 0.7000 test acc: 0.6390\n",
      "epoch: 800 loss: 0.1886 train acc: 0.7080 test acc: 0.6410\n",
      "epoch: 850 loss: 0.1874 train acc: 0.7020 test acc: 0.6455\n",
      "epoch: 900 loss: 0.1863 train acc: 0.6960 test acc: 0.6420\n",
      "epoch: 950 loss: 0.1853 train acc: 0.6960 test acc: 0.6475\n",
      "epoch: 0 loss: 0.3144 train acc: 0.4630 test acc: 0.4840\n",
      "epoch: 50 loss: 0.2728 train acc: 0.5400 test acc: 0.5325\n",
      "epoch: 100 loss: 0.2505 train acc: 0.5840 test acc: 0.5675\n",
      "epoch: 150 loss: 0.2390 train acc: 0.5950 test acc: 0.5860\n",
      "epoch: 200 loss: 0.2320 train acc: 0.6170 test acc: 0.6000\n",
      "epoch: 250 loss: 0.2268 train acc: 0.6280 test acc: 0.6065\n",
      "epoch: 300 loss: 0.2226 train acc: 0.6440 test acc: 0.6140\n",
      "epoch: 350 loss: 0.2190 train acc: 0.6520 test acc: 0.6205\n",
      "epoch: 400 loss: 0.2159 train acc: 0.6550 test acc: 0.6250\n",
      "epoch: 450 loss: 0.2132 train acc: 0.6650 test acc: 0.6310\n",
      "epoch: 500 loss: 0.2108 train acc: 0.6670 test acc: 0.6340\n",
      "epoch: 550 loss: 0.2087 train acc: 0.6770 test acc: 0.6400\n",
      "epoch: 600 loss: 0.2070 train acc: 0.6830 test acc: 0.6430\n",
      "epoch: 650 loss: 0.2054 train acc: 0.6920 test acc: 0.6450\n",
      "epoch: 700 loss: 0.2041 train acc: 0.6940 test acc: 0.6530\n",
      "epoch: 750 loss: 0.2030 train acc: 0.6940 test acc: 0.6555\n",
      "epoch: 800 loss: 0.2019 train acc: 0.6990 test acc: 0.6615\n",
      "epoch: 850 loss: 0.2010 train acc: 0.6970 test acc: 0.6615\n",
      "epoch: 900 loss: 0.2002 train acc: 0.6960 test acc: 0.6625\n",
      "epoch: 950 loss: 0.1995 train acc: 0.6950 test acc: 0.6650\n",
      "epoch: 0 loss: 0.3438 train acc: 0.4000 test acc: 0.4725\n",
      "epoch: 50 loss: 0.2842 train acc: 0.4900 test acc: 0.5165\n",
      "epoch: 100 loss: 0.2439 train acc: 0.5800 test acc: 0.5410\n",
      "epoch: 150 loss: 0.2214 train acc: 0.6400 test acc: 0.5465\n",
      "epoch: 200 loss: 0.2065 train acc: 0.6700 test acc: 0.5590\n",
      "epoch: 250 loss: 0.1958 train acc: 0.7400 test acc: 0.5535\n",
      "epoch: 300 loss: 0.1879 train acc: 0.7400 test acc: 0.5550\n",
      "epoch: 350 loss: 0.1817 train acc: 0.7400 test acc: 0.5520\n",
      "epoch: 400 loss: 0.1768 train acc: 0.7500 test acc: 0.5500\n",
      "epoch: 450 loss: 0.1726 train acc: 0.7400 test acc: 0.5510\n",
      "epoch: 500 loss: 0.1689 train acc: 0.7500 test acc: 0.5560\n",
      "epoch: 550 loss: 0.1657 train acc: 0.7600 test acc: 0.5555\n",
      "epoch: 600 loss: 0.1629 train acc: 0.7800 test acc: 0.5545\n",
      "epoch: 650 loss: 0.1605 train acc: 0.7700 test acc: 0.5560\n",
      "epoch: 700 loss: 0.1584 train acc: 0.7900 test acc: 0.5530\n",
      "epoch: 750 loss: 0.1566 train acc: 0.8100 test acc: 0.5540\n",
      "epoch: 800 loss: 0.1550 train acc: 0.8200 test acc: 0.5555\n",
      "epoch: 850 loss: 0.1536 train acc: 0.8300 test acc: 0.5550\n",
      "epoch: 900 loss: 0.1524 train acc: 0.8200 test acc: 0.5530\n",
      "epoch: 950 loss: 0.1514 train acc: 0.8200 test acc: 0.5560\n",
      "epoch: 0 loss: 0.2939 train acc: 0.4980 test acc: 0.4725\n",
      "epoch: 50 loss: 0.2609 train acc: 0.5700 test acc: 0.5300\n",
      "epoch: 100 loss: 0.2451 train acc: 0.5780 test acc: 0.5460\n",
      "epoch: 150 loss: 0.2343 train acc: 0.6100 test acc: 0.5540\n",
      "epoch: 200 loss: 0.2263 train acc: 0.6360 test acc: 0.5635\n",
      "epoch: 250 loss: 0.2200 train acc: 0.6460 test acc: 0.5670\n",
      "epoch: 300 loss: 0.2148 train acc: 0.6500 test acc: 0.5715\n",
      "epoch: 350 loss: 0.2102 train acc: 0.6660 test acc: 0.5760\n",
      "epoch: 400 loss: 0.2062 train acc: 0.6600 test acc: 0.5860\n",
      "epoch: 450 loss: 0.2027 train acc: 0.6720 test acc: 0.5965\n",
      "epoch: 500 loss: 0.1998 train acc: 0.6920 test acc: 0.5985\n",
      "epoch: 550 loss: 0.1973 train acc: 0.6960 test acc: 0.6055\n",
      "epoch: 600 loss: 0.1953 train acc: 0.6960 test acc: 0.6120\n",
      "epoch: 650 loss: 0.1937 train acc: 0.7040 test acc: 0.6130\n",
      "epoch: 700 loss: 0.1924 train acc: 0.7080 test acc: 0.6145\n",
      "epoch: 750 loss: 0.1914 train acc: 0.7120 test acc: 0.6165\n",
      "epoch: 800 loss: 0.1906 train acc: 0.7020 test acc: 0.6160\n",
      "epoch: 850 loss: 0.1900 train acc: 0.7040 test acc: 0.6150\n",
      "epoch: 900 loss: 0.1894 train acc: 0.7020 test acc: 0.6160\n",
      "epoch: 950 loss: 0.1889 train acc: 0.7020 test acc: 0.6175\n",
      "epoch: 0 loss: 0.2894 train acc: 0.5010 test acc: 0.4725\n",
      "epoch: 50 loss: 0.2633 train acc: 0.5610 test acc: 0.5230\n",
      "epoch: 100 loss: 0.2495 train acc: 0.5900 test acc: 0.5530\n",
      "epoch: 150 loss: 0.2416 train acc: 0.6100 test acc: 0.5545\n",
      "epoch: 200 loss: 0.2369 train acc: 0.6170 test acc: 0.5610\n",
      "epoch: 250 loss: 0.2338 train acc: 0.6300 test acc: 0.5655\n",
      "epoch: 300 loss: 0.2313 train acc: 0.6270 test acc: 0.5695\n",
      "epoch: 350 loss: 0.2292 train acc: 0.6360 test acc: 0.5755\n",
      "epoch: 400 loss: 0.2272 train acc: 0.6440 test acc: 0.5725\n",
      "epoch: 450 loss: 0.2252 train acc: 0.6530 test acc: 0.5760\n",
      "epoch: 500 loss: 0.2232 train acc: 0.6550 test acc: 0.5785\n",
      "epoch: 550 loss: 0.2213 train acc: 0.6530 test acc: 0.5780\n",
      "epoch: 600 loss: 0.2196 train acc: 0.6520 test acc: 0.5800\n",
      "epoch: 650 loss: 0.2179 train acc: 0.6540 test acc: 0.5855\n",
      "epoch: 700 loss: 0.2164 train acc: 0.6560 test acc: 0.5865\n",
      "epoch: 750 loss: 0.2150 train acc: 0.6590 test acc: 0.5855\n",
      "epoch: 800 loss: 0.2137 train acc: 0.6640 test acc: 0.5920\n",
      "epoch: 850 loss: 0.2124 train acc: 0.6690 test acc: 0.6005\n",
      "epoch: 900 loss: 0.2113 train acc: 0.6750 test acc: 0.6045\n",
      "epoch: 950 loss: 0.2101 train acc: 0.6820 test acc: 0.6085\n",
      "epoch: 0 loss: 0.3139 train acc: 0.4600 test acc: 0.5190\n",
      "epoch: 50 loss: 0.2390 train acc: 0.6100 test acc: 0.5600\n",
      "epoch: 100 loss: 0.2058 train acc: 0.6600 test acc: 0.5795\n",
      "epoch: 150 loss: 0.1876 train acc: 0.6800 test acc: 0.6090\n",
      "epoch: 200 loss: 0.1769 train acc: 0.7100 test acc: 0.6120\n",
      "epoch: 250 loss: 0.1700 train acc: 0.7500 test acc: 0.6110\n",
      "epoch: 300 loss: 0.1652 train acc: 0.7700 test acc: 0.6070\n",
      "epoch: 350 loss: 0.1615 train acc: 0.7700 test acc: 0.6090\n",
      "epoch: 400 loss: 0.1586 train acc: 0.7800 test acc: 0.6055\n",
      "epoch: 450 loss: 0.1561 train acc: 0.7700 test acc: 0.6000\n",
      "epoch: 500 loss: 0.1538 train acc: 0.7700 test acc: 0.5985\n",
      "epoch: 550 loss: 0.1517 train acc: 0.7700 test acc: 0.5930\n",
      "epoch: 600 loss: 0.1497 train acc: 0.7800 test acc: 0.5940\n",
      "epoch: 650 loss: 0.1478 train acc: 0.7900 test acc: 0.5885\n",
      "epoch: 700 loss: 0.1459 train acc: 0.7900 test acc: 0.5845\n",
      "epoch: 750 loss: 0.1442 train acc: 0.7900 test acc: 0.5835\n",
      "epoch: 800 loss: 0.1425 train acc: 0.8100 test acc: 0.5825\n",
      "epoch: 850 loss: 0.1410 train acc: 0.8200 test acc: 0.5780\n",
      "epoch: 900 loss: 0.1396 train acc: 0.8300 test acc: 0.5755\n",
      "epoch: 950 loss: 0.1383 train acc: 0.8200 test acc: 0.5745\n",
      "epoch: 0 loss: 0.3142 train acc: 0.4660 test acc: 0.5190\n",
      "epoch: 50 loss: 0.2582 train acc: 0.5640 test acc: 0.5620\n",
      "epoch: 100 loss: 0.2314 train acc: 0.6040 test acc: 0.5930\n",
      "epoch: 150 loss: 0.2192 train acc: 0.6280 test acc: 0.6035\n",
      "epoch: 200 loss: 0.2120 train acc: 0.6520 test acc: 0.6120\n",
      "epoch: 250 loss: 0.2069 train acc: 0.6640 test acc: 0.6160\n",
      "epoch: 300 loss: 0.2029 train acc: 0.6740 test acc: 0.6190\n",
      "epoch: 350 loss: 0.1995 train acc: 0.6840 test acc: 0.6265\n",
      "epoch: 400 loss: 0.1967 train acc: 0.6880 test acc: 0.6325\n",
      "epoch: 450 loss: 0.1942 train acc: 0.6980 test acc: 0.6320\n",
      "epoch: 500 loss: 0.1920 train acc: 0.6980 test acc: 0.6360\n",
      "epoch: 550 loss: 0.1900 train acc: 0.6960 test acc: 0.6380\n",
      "epoch: 600 loss: 0.1883 train acc: 0.6880 test acc: 0.6415\n",
      "epoch: 650 loss: 0.1868 train acc: 0.6920 test acc: 0.6435\n",
      "epoch: 700 loss: 0.1855 train acc: 0.6940 test acc: 0.6435\n",
      "epoch: 750 loss: 0.1844 train acc: 0.7020 test acc: 0.6415\n",
      "epoch: 800 loss: 0.1834 train acc: 0.7120 test acc: 0.6430\n",
      "epoch: 850 loss: 0.1825 train acc: 0.7200 test acc: 0.6455\n",
      "epoch: 900 loss: 0.1817 train acc: 0.7260 test acc: 0.6465\n",
      "epoch: 950 loss: 0.1810 train acc: 0.7240 test acc: 0.6455\n",
      "epoch: 0 loss: 0.2963 train acc: 0.4950 test acc: 0.5190\n",
      "epoch: 50 loss: 0.2629 train acc: 0.5400 test acc: 0.5450\n",
      "epoch: 100 loss: 0.2434 train acc: 0.5990 test acc: 0.5915\n",
      "epoch: 150 loss: 0.2321 train acc: 0.6150 test acc: 0.6025\n",
      "epoch: 200 loss: 0.2251 train acc: 0.6360 test acc: 0.6130\n",
      "epoch: 250 loss: 0.2202 train acc: 0.6440 test acc: 0.6130\n",
      "epoch: 300 loss: 0.2165 train acc: 0.6450 test acc: 0.6200\n",
      "epoch: 350 loss: 0.2136 train acc: 0.6620 test acc: 0.6235\n",
      "epoch: 400 loss: 0.2111 train acc: 0.6750 test acc: 0.6330\n",
      "epoch: 450 loss: 0.2090 train acc: 0.6750 test acc: 0.6335\n",
      "epoch: 500 loss: 0.2070 train acc: 0.6750 test acc: 0.6340\n",
      "epoch: 550 loss: 0.2053 train acc: 0.6830 test acc: 0.6345\n",
      "epoch: 600 loss: 0.2036 train acc: 0.6870 test acc: 0.6300\n",
      "epoch: 650 loss: 0.2022 train acc: 0.6930 test acc: 0.6320\n",
      "epoch: 700 loss: 0.2008 train acc: 0.6950 test acc: 0.6370\n",
      "epoch: 750 loss: 0.1996 train acc: 0.6940 test acc: 0.6390\n",
      "epoch: 800 loss: 0.1985 train acc: 0.6980 test acc: 0.6390\n",
      "epoch: 850 loss: 0.1975 train acc: 0.6990 test acc: 0.6410\n",
      "epoch: 900 loss: 0.1966 train acc: 0.7030 test acc: 0.6400\n",
      "epoch: 950 loss: 0.1957 train acc: 0.7070 test acc: 0.6460\n",
      "epoch: 0 loss: 0.3173 train acc: 0.5200 test acc: 0.5125\n",
      "epoch: 50 loss: 0.2601 train acc: 0.5800 test acc: 0.5185\n",
      "epoch: 100 loss: 0.2144 train acc: 0.6400 test acc: 0.5085\n",
      "epoch: 150 loss: 0.1909 train acc: 0.6900 test acc: 0.5040\n",
      "epoch: 200 loss: 0.1791 train acc: 0.7500 test acc: 0.5035\n",
      "epoch: 250 loss: 0.1722 train acc: 0.7600 test acc: 0.5065\n",
      "epoch: 300 loss: 0.1675 train acc: 0.7600 test acc: 0.5125\n",
      "epoch: 350 loss: 0.1640 train acc: 0.7700 test acc: 0.5180\n",
      "epoch: 400 loss: 0.1612 train acc: 0.7800 test acc: 0.5230\n",
      "epoch: 450 loss: 0.1588 train acc: 0.7900 test acc: 0.5215\n",
      "epoch: 500 loss: 0.1567 train acc: 0.7900 test acc: 0.5280\n",
      "epoch: 550 loss: 0.1548 train acc: 0.8000 test acc: 0.5315\n",
      "epoch: 600 loss: 0.1531 train acc: 0.8100 test acc: 0.5335\n",
      "epoch: 650 loss: 0.1515 train acc: 0.8100 test acc: 0.5350\n",
      "epoch: 700 loss: 0.1500 train acc: 0.8100 test acc: 0.5395\n",
      "epoch: 750 loss: 0.1486 train acc: 0.8100 test acc: 0.5410\n",
      "epoch: 800 loss: 0.1472 train acc: 0.8100 test acc: 0.5485\n",
      "epoch: 850 loss: 0.1459 train acc: 0.8100 test acc: 0.5480\n",
      "epoch: 900 loss: 0.1447 train acc: 0.8100 test acc: 0.5485\n",
      "epoch: 950 loss: 0.1436 train acc: 0.8100 test acc: 0.5515\n",
      "epoch: 0 loss: 0.2936 train acc: 0.5380 test acc: 0.5125\n",
      "epoch: 50 loss: 0.2740 train acc: 0.5660 test acc: 0.5425\n",
      "epoch: 100 loss: 0.2581 train acc: 0.5880 test acc: 0.5665\n",
      "epoch: 150 loss: 0.2461 train acc: 0.6220 test acc: 0.5810\n",
      "epoch: 200 loss: 0.2374 train acc: 0.6340 test acc: 0.5910\n",
      "epoch: 250 loss: 0.2312 train acc: 0.6480 test acc: 0.6000\n",
      "epoch: 300 loss: 0.2264 train acc: 0.6560 test acc: 0.6050\n",
      "epoch: 350 loss: 0.2225 train acc: 0.6600 test acc: 0.6040\n",
      "epoch: 400 loss: 0.2190 train acc: 0.6680 test acc: 0.6040\n",
      "epoch: 450 loss: 0.2158 train acc: 0.6680 test acc: 0.6030\n",
      "epoch: 500 loss: 0.2128 train acc: 0.6840 test acc: 0.6015\n",
      "epoch: 550 loss: 0.2100 train acc: 0.6820 test acc: 0.5995\n",
      "epoch: 600 loss: 0.2074 train acc: 0.6660 test acc: 0.5970\n",
      "epoch: 650 loss: 0.2050 train acc: 0.6700 test acc: 0.5960\n",
      "epoch: 700 loss: 0.2029 train acc: 0.6700 test acc: 0.5945\n",
      "epoch: 750 loss: 0.2009 train acc: 0.6740 test acc: 0.5975\n",
      "epoch: 800 loss: 0.1990 train acc: 0.6820 test acc: 0.5990\n",
      "epoch: 850 loss: 0.1972 train acc: 0.6820 test acc: 0.6035\n",
      "epoch: 900 loss: 0.1955 train acc: 0.6900 test acc: 0.6070\n",
      "epoch: 950 loss: 0.1939 train acc: 0.6900 test acc: 0.6130\n",
      "epoch: 0 loss: 0.3052 train acc: 0.5130 test acc: 0.5125\n",
      "epoch: 50 loss: 0.2805 train acc: 0.5390 test acc: 0.5420\n",
      "epoch: 100 loss: 0.2612 train acc: 0.5720 test acc: 0.5675\n",
      "epoch: 150 loss: 0.2498 train acc: 0.5990 test acc: 0.5910\n",
      "epoch: 200 loss: 0.2424 train acc: 0.6120 test acc: 0.5995\n",
      "epoch: 250 loss: 0.2371 train acc: 0.6190 test acc: 0.5975\n",
      "epoch: 300 loss: 0.2330 train acc: 0.6230 test acc: 0.5965\n",
      "epoch: 350 loss: 0.2297 train acc: 0.6310 test acc: 0.6015\n",
      "epoch: 400 loss: 0.2269 train acc: 0.6380 test acc: 0.6070\n",
      "epoch: 450 loss: 0.2246 train acc: 0.6480 test acc: 0.6095\n",
      "epoch: 500 loss: 0.2225 train acc: 0.6540 test acc: 0.6135\n",
      "epoch: 550 loss: 0.2206 train acc: 0.6570 test acc: 0.6155\n",
      "epoch: 600 loss: 0.2188 train acc: 0.6630 test acc: 0.6175\n",
      "epoch: 650 loss: 0.2171 train acc: 0.6650 test acc: 0.6165\n",
      "epoch: 700 loss: 0.2155 train acc: 0.6670 test acc: 0.6190\n",
      "epoch: 750 loss: 0.2138 train acc: 0.6730 test acc: 0.6205\n",
      "epoch: 800 loss: 0.2122 train acc: 0.6740 test acc: 0.6260\n",
      "epoch: 850 loss: 0.2106 train acc: 0.6790 test acc: 0.6275\n",
      "epoch: 900 loss: 0.2091 train acc: 0.6790 test acc: 0.6290\n",
      "epoch: 950 loss: 0.2077 train acc: 0.6760 test acc: 0.6370\n",
      "epoch: 0 loss: 0.2959 train acc: 0.5000 test acc: 0.4835\n",
      "epoch: 50 loss: 0.2593 train acc: 0.5900 test acc: 0.4925\n",
      "epoch: 100 loss: 0.2291 train acc: 0.6200 test acc: 0.4905\n",
      "epoch: 150 loss: 0.2064 train acc: 0.6800 test acc: 0.5125\n",
      "epoch: 200 loss: 0.1918 train acc: 0.7200 test acc: 0.5165\n",
      "epoch: 250 loss: 0.1831 train acc: 0.7500 test acc: 0.5185\n",
      "epoch: 300 loss: 0.1776 train acc: 0.7600 test acc: 0.5220\n",
      "epoch: 350 loss: 0.1738 train acc: 0.7600 test acc: 0.5205\n",
      "epoch: 400 loss: 0.1709 train acc: 0.7600 test acc: 0.5240\n",
      "epoch: 450 loss: 0.1684 train acc: 0.7600 test acc: 0.5230\n",
      "epoch: 500 loss: 0.1661 train acc: 0.7800 test acc: 0.5265\n",
      "epoch: 550 loss: 0.1640 train acc: 0.7800 test acc: 0.5250\n",
      "epoch: 600 loss: 0.1620 train acc: 0.8000 test acc: 0.5295\n",
      "epoch: 650 loss: 0.1601 train acc: 0.8200 test acc: 0.5325\n",
      "epoch: 700 loss: 0.1583 train acc: 0.8200 test acc: 0.5380\n",
      "epoch: 750 loss: 0.1567 train acc: 0.8300 test acc: 0.5375\n",
      "epoch: 800 loss: 0.1552 train acc: 0.8300 test acc: 0.5390\n",
      "epoch: 850 loss: 0.1538 train acc: 0.8200 test acc: 0.5395\n",
      "epoch: 900 loss: 0.1525 train acc: 0.8200 test acc: 0.5410\n",
      "epoch: 950 loss: 0.1514 train acc: 0.8100 test acc: 0.5420\n",
      "epoch: 0 loss: 0.2931 train acc: 0.5060 test acc: 0.4835\n",
      "epoch: 50 loss: 0.2706 train acc: 0.5340 test acc: 0.5110\n",
      "epoch: 100 loss: 0.2534 train acc: 0.5860 test acc: 0.5330\n",
      "epoch: 150 loss: 0.2401 train acc: 0.6000 test acc: 0.5360\n",
      "epoch: 200 loss: 0.2306 train acc: 0.6220 test acc: 0.5445\n",
      "epoch: 250 loss: 0.2238 train acc: 0.6440 test acc: 0.5475\n",
      "epoch: 300 loss: 0.2188 train acc: 0.6440 test acc: 0.5585\n",
      "epoch: 350 loss: 0.2149 train acc: 0.6520 test acc: 0.5625\n",
      "epoch: 400 loss: 0.2119 train acc: 0.6640 test acc: 0.5715\n",
      "epoch: 450 loss: 0.2095 train acc: 0.6740 test acc: 0.5700\n",
      "epoch: 500 loss: 0.2076 train acc: 0.6680 test acc: 0.5760\n",
      "epoch: 550 loss: 0.2061 train acc: 0.6680 test acc: 0.5725\n",
      "epoch: 600 loss: 0.2047 train acc: 0.6660 test acc: 0.5715\n",
      "epoch: 650 loss: 0.2035 train acc: 0.6660 test acc: 0.5735\n",
      "epoch: 700 loss: 0.2023 train acc: 0.6740 test acc: 0.5715\n",
      "epoch: 750 loss: 0.2011 train acc: 0.6780 test acc: 0.5765\n",
      "epoch: 800 loss: 0.1999 train acc: 0.6860 test acc: 0.5735\n",
      "epoch: 850 loss: 0.1987 train acc: 0.6920 test acc: 0.5745\n",
      "epoch: 900 loss: 0.1974 train acc: 0.6900 test acc: 0.5735\n",
      "epoch: 950 loss: 0.1962 train acc: 0.6860 test acc: 0.5725\n",
      "epoch: 0 loss: 0.2895 train acc: 0.5100 test acc: 0.4835\n",
      "epoch: 50 loss: 0.2739 train acc: 0.5350 test acc: 0.5110\n",
      "epoch: 100 loss: 0.2644 train acc: 0.5620 test acc: 0.5215\n",
      "epoch: 150 loss: 0.2564 train acc: 0.5800 test acc: 0.5330\n",
      "epoch: 200 loss: 0.2495 train acc: 0.5840 test acc: 0.5465\n",
      "epoch: 250 loss: 0.2437 train acc: 0.5950 test acc: 0.5530\n",
      "epoch: 300 loss: 0.2386 train acc: 0.6040 test acc: 0.5555\n",
      "epoch: 350 loss: 0.2340 train acc: 0.6150 test acc: 0.5590\n",
      "epoch: 400 loss: 0.2299 train acc: 0.6220 test acc: 0.5620\n",
      "epoch: 450 loss: 0.2262 train acc: 0.6330 test acc: 0.5715\n",
      "epoch: 500 loss: 0.2228 train acc: 0.6440 test acc: 0.5810\n",
      "epoch: 550 loss: 0.2195 train acc: 0.6500 test acc: 0.5905\n",
      "epoch: 600 loss: 0.2161 train acc: 0.6620 test acc: 0.5965\n",
      "epoch: 650 loss: 0.2127 train acc: 0.6680 test acc: 0.5985\n",
      "epoch: 700 loss: 0.2093 train acc: 0.6750 test acc: 0.6100\n",
      "epoch: 750 loss: 0.2062 train acc: 0.6810 test acc: 0.6180\n",
      "epoch: 800 loss: 0.2034 train acc: 0.6810 test acc: 0.6240\n",
      "epoch: 850 loss: 0.2011 train acc: 0.6820 test acc: 0.6275\n",
      "epoch: 900 loss: 0.1992 train acc: 0.6860 test acc: 0.6290\n",
      "epoch: 950 loss: 0.1977 train acc: 0.6890 test acc: 0.6355\n",
      "epoch: 0 loss: 0.2934 train acc: 0.4800 test acc: 0.5125\n",
      "epoch: 50 loss: 0.2355 train acc: 0.5800 test acc: 0.5200\n",
      "epoch: 100 loss: 0.2082 train acc: 0.6500 test acc: 0.5180\n",
      "epoch: 150 loss: 0.1945 train acc: 0.7100 test acc: 0.5025\n",
      "epoch: 200 loss: 0.1853 train acc: 0.7600 test acc: 0.5020\n",
      "epoch: 250 loss: 0.1780 train acc: 0.7500 test acc: 0.4945\n",
      "epoch: 300 loss: 0.1716 train acc: 0.7400 test acc: 0.4940\n",
      "epoch: 350 loss: 0.1662 train acc: 0.7400 test acc: 0.4970\n",
      "epoch: 400 loss: 0.1616 train acc: 0.7600 test acc: 0.4950\n",
      "epoch: 450 loss: 0.1576 train acc: 0.7800 test acc: 0.4965\n",
      "epoch: 500 loss: 0.1543 train acc: 0.7900 test acc: 0.4945\n",
      "epoch: 550 loss: 0.1514 train acc: 0.8100 test acc: 0.4960\n",
      "epoch: 600 loss: 0.1489 train acc: 0.8000 test acc: 0.4905\n",
      "epoch: 650 loss: 0.1468 train acc: 0.7800 test acc: 0.4880\n",
      "epoch: 700 loss: 0.1450 train acc: 0.7900 test acc: 0.4865\n",
      "epoch: 750 loss: 0.1435 train acc: 0.8000 test acc: 0.4865\n",
      "epoch: 800 loss: 0.1422 train acc: 0.8100 test acc: 0.4890\n",
      "epoch: 850 loss: 0.1410 train acc: 0.8300 test acc: 0.4940\n",
      "epoch: 900 loss: 0.1401 train acc: 0.8400 test acc: 0.4930\n",
      "epoch: 950 loss: 0.1393 train acc: 0.8400 test acc: 0.4915\n",
      "epoch: 0 loss: 0.2893 train acc: 0.5260 test acc: 0.5125\n",
      "epoch: 50 loss: 0.2621 train acc: 0.5600 test acc: 0.5500\n",
      "epoch: 100 loss: 0.2460 train acc: 0.6000 test acc: 0.5635\n",
      "epoch: 150 loss: 0.2355 train acc: 0.6120 test acc: 0.5750\n",
      "epoch: 200 loss: 0.2279 train acc: 0.6280 test acc: 0.5845\n",
      "epoch: 250 loss: 0.2219 train acc: 0.6380 test acc: 0.5930\n",
      "epoch: 300 loss: 0.2172 train acc: 0.6400 test acc: 0.5930\n",
      "epoch: 350 loss: 0.2135 train acc: 0.6400 test acc: 0.5955\n",
      "epoch: 400 loss: 0.2105 train acc: 0.6540 test acc: 0.6030\n",
      "epoch: 450 loss: 0.2081 train acc: 0.6640 test acc: 0.6070\n",
      "epoch: 500 loss: 0.2061 train acc: 0.6700 test acc: 0.6055\n",
      "epoch: 550 loss: 0.2044 train acc: 0.6780 test acc: 0.6090\n",
      "epoch: 600 loss: 0.2028 train acc: 0.6800 test acc: 0.6055\n",
      "epoch: 650 loss: 0.2014 train acc: 0.6860 test acc: 0.6050\n",
      "epoch: 700 loss: 0.2001 train acc: 0.6940 test acc: 0.6045\n",
      "epoch: 750 loss: 0.1989 train acc: 0.6920 test acc: 0.6040\n",
      "epoch: 800 loss: 0.1978 train acc: 0.6980 test acc: 0.6030\n",
      "epoch: 850 loss: 0.1966 train acc: 0.6980 test acc: 0.6050\n",
      "epoch: 900 loss: 0.1956 train acc: 0.7060 test acc: 0.6050\n",
      "epoch: 950 loss: 0.1945 train acc: 0.7160 test acc: 0.6055\n",
      "epoch: 0 loss: 0.2865 train acc: 0.5250 test acc: 0.5125\n",
      "epoch: 50 loss: 0.2666 train acc: 0.5570 test acc: 0.5415\n",
      "epoch: 100 loss: 0.2523 train acc: 0.5710 test acc: 0.5665\n",
      "epoch: 150 loss: 0.2415 train acc: 0.5870 test acc: 0.5825\n",
      "epoch: 200 loss: 0.2339 train acc: 0.6010 test acc: 0.6010\n",
      "epoch: 250 loss: 0.2283 train acc: 0.6100 test acc: 0.6130\n",
      "epoch: 300 loss: 0.2242 train acc: 0.6250 test acc: 0.6210\n",
      "epoch: 350 loss: 0.2210 train acc: 0.6310 test acc: 0.6290\n",
      "epoch: 400 loss: 0.2183 train acc: 0.6370 test acc: 0.6355\n",
      "epoch: 450 loss: 0.2161 train acc: 0.6420 test acc: 0.6365\n",
      "epoch: 500 loss: 0.2141 train acc: 0.6510 test acc: 0.6355\n",
      "epoch: 550 loss: 0.2123 train acc: 0.6580 test acc: 0.6360\n",
      "epoch: 600 loss: 0.2108 train acc: 0.6680 test acc: 0.6390\n",
      "epoch: 650 loss: 0.2093 train acc: 0.6690 test acc: 0.6410\n",
      "epoch: 700 loss: 0.2080 train acc: 0.6700 test acc: 0.6445\n",
      "epoch: 750 loss: 0.2069 train acc: 0.6760 test acc: 0.6460\n",
      "epoch: 800 loss: 0.2058 train acc: 0.6820 test acc: 0.6490\n",
      "epoch: 850 loss: 0.2048 train acc: 0.6830 test acc: 0.6500\n",
      "epoch: 900 loss: 0.2039 train acc: 0.6860 test acc: 0.6505\n",
      "epoch: 950 loss: 0.2030 train acc: 0.6870 test acc: 0.6515\n",
      "epoch: 0 loss: 0.3244 train acc: 0.4600 test acc: 0.4965\n",
      "epoch: 50 loss: 0.2558 train acc: 0.5300 test acc: 0.5200\n",
      "epoch: 100 loss: 0.2170 train acc: 0.6400 test acc: 0.5465\n",
      "epoch: 150 loss: 0.1973 train acc: 0.7000 test acc: 0.5475\n",
      "epoch: 200 loss: 0.1865 train acc: 0.7500 test acc: 0.5420\n",
      "epoch: 250 loss: 0.1799 train acc: 0.7600 test acc: 0.5360\n",
      "epoch: 300 loss: 0.1754 train acc: 0.7600 test acc: 0.5290\n",
      "epoch: 350 loss: 0.1719 train acc: 0.7700 test acc: 0.5275\n",
      "epoch: 400 loss: 0.1690 train acc: 0.7800 test acc: 0.5310\n",
      "epoch: 450 loss: 0.1666 train acc: 0.7800 test acc: 0.5315\n",
      "epoch: 500 loss: 0.1645 train acc: 0.7800 test acc: 0.5355\n",
      "epoch: 550 loss: 0.1626 train acc: 0.7700 test acc: 0.5365\n",
      "epoch: 600 loss: 0.1608 train acc: 0.7800 test acc: 0.5375\n",
      "epoch: 650 loss: 0.1592 train acc: 0.7900 test acc: 0.5425\n",
      "epoch: 700 loss: 0.1577 train acc: 0.8000 test acc: 0.5420\n",
      "epoch: 750 loss: 0.1563 train acc: 0.8100 test acc: 0.5440\n",
      "epoch: 800 loss: 0.1550 train acc: 0.8000 test acc: 0.5450\n",
      "epoch: 850 loss: 0.1538 train acc: 0.8000 test acc: 0.5420\n",
      "epoch: 900 loss: 0.1528 train acc: 0.7900 test acc: 0.5445\n",
      "epoch: 950 loss: 0.1518 train acc: 0.7900 test acc: 0.5485\n",
      "epoch: 0 loss: 0.3034 train acc: 0.4920 test acc: 0.4965\n",
      "epoch: 50 loss: 0.2832 train acc: 0.5020 test acc: 0.5140\n",
      "epoch: 100 loss: 0.2619 train acc: 0.5580 test acc: 0.5385\n",
      "epoch: 150 loss: 0.2475 train acc: 0.5940 test acc: 0.5605\n",
      "epoch: 200 loss: 0.2369 train acc: 0.6260 test acc: 0.5735\n",
      "epoch: 250 loss: 0.2278 train acc: 0.6600 test acc: 0.5820\n",
      "epoch: 300 loss: 0.2198 train acc: 0.6800 test acc: 0.5920\n",
      "epoch: 350 loss: 0.2128 train acc: 0.6940 test acc: 0.6025\n",
      "epoch: 400 loss: 0.2071 train acc: 0.7120 test acc: 0.6085\n",
      "epoch: 450 loss: 0.2026 train acc: 0.7180 test acc: 0.6135\n",
      "epoch: 500 loss: 0.1989 train acc: 0.7160 test acc: 0.6155\n",
      "epoch: 550 loss: 0.1959 train acc: 0.7060 test acc: 0.6125\n",
      "epoch: 600 loss: 0.1935 train acc: 0.7020 test acc: 0.6130\n",
      "epoch: 650 loss: 0.1915 train acc: 0.7120 test acc: 0.6180\n",
      "epoch: 700 loss: 0.1899 train acc: 0.7100 test acc: 0.6185\n",
      "epoch: 750 loss: 0.1885 train acc: 0.7220 test acc: 0.6195\n",
      "epoch: 800 loss: 0.1874 train acc: 0.7220 test acc: 0.6225\n",
      "epoch: 850 loss: 0.1864 train acc: 0.7240 test acc: 0.6270\n",
      "epoch: 900 loss: 0.1855 train acc: 0.7360 test acc: 0.6265\n",
      "epoch: 950 loss: 0.1848 train acc: 0.7360 test acc: 0.6295\n",
      "epoch: 0 loss: 0.2973 train acc: 0.5070 test acc: 0.4965\n",
      "epoch: 50 loss: 0.2800 train acc: 0.5230 test acc: 0.5265\n",
      "epoch: 100 loss: 0.2616 train acc: 0.5510 test acc: 0.5330\n",
      "epoch: 150 loss: 0.2505 train acc: 0.5780 test acc: 0.5550\n",
      "epoch: 200 loss: 0.2442 train acc: 0.6050 test acc: 0.5680\n",
      "epoch: 250 loss: 0.2397 train acc: 0.6260 test acc: 0.5765\n",
      "epoch: 300 loss: 0.2358 train acc: 0.6340 test acc: 0.5835\n",
      "epoch: 350 loss: 0.2318 train acc: 0.6430 test acc: 0.5870\n",
      "epoch: 400 loss: 0.2277 train acc: 0.6480 test acc: 0.5920\n",
      "epoch: 450 loss: 0.2234 train acc: 0.6590 test acc: 0.6015\n",
      "epoch: 500 loss: 0.2191 train acc: 0.6650 test acc: 0.6130\n",
      "epoch: 550 loss: 0.2150 train acc: 0.6680 test acc: 0.6205\n",
      "epoch: 600 loss: 0.2115 train acc: 0.6770 test acc: 0.6185\n",
      "epoch: 650 loss: 0.2086 train acc: 0.6890 test acc: 0.6200\n",
      "epoch: 700 loss: 0.2064 train acc: 0.6860 test acc: 0.6235\n",
      "epoch: 750 loss: 0.2048 train acc: 0.6910 test acc: 0.6250\n",
      "epoch: 800 loss: 0.2036 train acc: 0.6940 test acc: 0.6250\n",
      "epoch: 850 loss: 0.2027 train acc: 0.6990 test acc: 0.6305\n",
      "epoch: 900 loss: 0.2019 train acc: 0.6970 test acc: 0.6315\n",
      "epoch: 950 loss: 0.2012 train acc: 0.7050 test acc: 0.6345\n",
      "epoch: 0 loss: 0.3478 train acc: 0.4300 test acc: 0.4645\n",
      "epoch: 50 loss: 0.2904 train acc: 0.5100 test acc: 0.4765\n",
      "epoch: 100 loss: 0.2300 train acc: 0.6300 test acc: 0.5155\n",
      "epoch: 150 loss: 0.1990 train acc: 0.6900 test acc: 0.5315\n",
      "epoch: 200 loss: 0.1840 train acc: 0.7300 test acc: 0.5330\n",
      "epoch: 250 loss: 0.1763 train acc: 0.7800 test acc: 0.5310\n",
      "epoch: 300 loss: 0.1718 train acc: 0.8100 test acc: 0.5345\n",
      "epoch: 350 loss: 0.1689 train acc: 0.8100 test acc: 0.5350\n",
      "epoch: 400 loss: 0.1668 train acc: 0.8100 test acc: 0.5350\n",
      "epoch: 450 loss: 0.1651 train acc: 0.8100 test acc: 0.5370\n",
      "epoch: 500 loss: 0.1636 train acc: 0.8300 test acc: 0.5340\n",
      "epoch: 550 loss: 0.1622 train acc: 0.8300 test acc: 0.5330\n",
      "epoch: 600 loss: 0.1609 train acc: 0.8000 test acc: 0.5340\n",
      "epoch: 650 loss: 0.1595 train acc: 0.8000 test acc: 0.5340\n",
      "epoch: 700 loss: 0.1582 train acc: 0.8100 test acc: 0.5370\n",
      "epoch: 750 loss: 0.1567 train acc: 0.8100 test acc: 0.5405\n",
      "epoch: 800 loss: 0.1553 train acc: 0.8100 test acc: 0.5370\n",
      "epoch: 850 loss: 0.1538 train acc: 0.8200 test acc: 0.5360\n",
      "epoch: 900 loss: 0.1524 train acc: 0.8200 test acc: 0.5390\n",
      "epoch: 950 loss: 0.1509 train acc: 0.8200 test acc: 0.5415\n",
      "epoch: 0 loss: 0.3209 train acc: 0.4360 test acc: 0.4645\n",
      "epoch: 50 loss: 0.2829 train acc: 0.5280 test acc: 0.4925\n",
      "epoch: 100 loss: 0.2578 train acc: 0.5680 test acc: 0.5260\n",
      "epoch: 150 loss: 0.2395 train acc: 0.5940 test acc: 0.5495\n",
      "epoch: 200 loss: 0.2272 train acc: 0.6500 test acc: 0.5625\n",
      "epoch: 250 loss: 0.2195 train acc: 0.6640 test acc: 0.5650\n",
      "epoch: 300 loss: 0.2142 train acc: 0.6800 test acc: 0.5685\n",
      "epoch: 350 loss: 0.2099 train acc: 0.6820 test acc: 0.5735\n",
      "epoch: 400 loss: 0.2062 train acc: 0.6840 test acc: 0.5825\n",
      "epoch: 450 loss: 0.2030 train acc: 0.6860 test acc: 0.5850\n",
      "epoch: 500 loss: 0.2003 train acc: 0.6940 test acc: 0.5895\n",
      "epoch: 550 loss: 0.1979 train acc: 0.7020 test acc: 0.5915\n",
      "epoch: 600 loss: 0.1959 train acc: 0.7020 test acc: 0.5945\n",
      "epoch: 650 loss: 0.1942 train acc: 0.7100 test acc: 0.5960\n",
      "epoch: 700 loss: 0.1926 train acc: 0.7120 test acc: 0.5980\n",
      "epoch: 750 loss: 0.1913 train acc: 0.7140 test acc: 0.6005\n",
      "epoch: 800 loss: 0.1900 train acc: 0.7140 test acc: 0.6035\n",
      "epoch: 850 loss: 0.1889 train acc: 0.7180 test acc: 0.6025\n",
      "epoch: 900 loss: 0.1879 train acc: 0.7100 test acc: 0.6025\n",
      "epoch: 950 loss: 0.1869 train acc: 0.7140 test acc: 0.6020\n",
      "epoch: 0 loss: 0.3295 train acc: 0.4370 test acc: 0.4645\n",
      "epoch: 50 loss: 0.3004 train acc: 0.4830 test acc: 0.4900\n",
      "epoch: 100 loss: 0.2802 train acc: 0.5270 test acc: 0.5310\n",
      "epoch: 150 loss: 0.2650 train acc: 0.5540 test acc: 0.5375\n",
      "epoch: 200 loss: 0.2526 train acc: 0.5800 test acc: 0.5580\n",
      "epoch: 250 loss: 0.2427 train acc: 0.5930 test acc: 0.5780\n",
      "epoch: 300 loss: 0.2352 train acc: 0.6100 test acc: 0.5840\n",
      "epoch: 350 loss: 0.2299 train acc: 0.6190 test acc: 0.5905\n",
      "epoch: 400 loss: 0.2262 train acc: 0.6250 test acc: 0.5915\n",
      "epoch: 450 loss: 0.2235 train acc: 0.6360 test acc: 0.5955\n",
      "epoch: 500 loss: 0.2213 train acc: 0.6480 test acc: 0.5945\n",
      "epoch: 550 loss: 0.2195 train acc: 0.6530 test acc: 0.5950\n",
      "epoch: 600 loss: 0.2178 train acc: 0.6550 test acc: 0.5960\n",
      "epoch: 650 loss: 0.2162 train acc: 0.6610 test acc: 0.5975\n",
      "epoch: 700 loss: 0.2148 train acc: 0.6560 test acc: 0.6025\n",
      "epoch: 750 loss: 0.2134 train acc: 0.6570 test acc: 0.5975\n",
      "epoch: 800 loss: 0.2122 train acc: 0.6620 test acc: 0.6030\n",
      "epoch: 850 loss: 0.2112 train acc: 0.6690 test acc: 0.6060\n",
      "epoch: 900 loss: 0.2102 train acc: 0.6710 test acc: 0.6075\n",
      "epoch: 950 loss: 0.2094 train acc: 0.6740 test acc: 0.6080\n",
      "epoch: 0 loss: 0.2778 train acc: 0.4800 test acc: 0.5220\n",
      "epoch: 50 loss: 0.2376 train acc: 0.5600 test acc: 0.5305\n",
      "epoch: 100 loss: 0.2179 train acc: 0.6400 test acc: 0.5220\n",
      "epoch: 150 loss: 0.2048 train acc: 0.6800 test acc: 0.5245\n",
      "epoch: 200 loss: 0.1960 train acc: 0.6800 test acc: 0.5325\n",
      "epoch: 250 loss: 0.1898 train acc: 0.6900 test acc: 0.5320\n",
      "epoch: 300 loss: 0.1851 train acc: 0.7000 test acc: 0.5345\n",
      "epoch: 350 loss: 0.1813 train acc: 0.7100 test acc: 0.5370\n",
      "epoch: 400 loss: 0.1782 train acc: 0.7000 test acc: 0.5385\n",
      "epoch: 450 loss: 0.1755 train acc: 0.7100 test acc: 0.5380\n",
      "epoch: 500 loss: 0.1731 train acc: 0.7200 test acc: 0.5370\n",
      "epoch: 550 loss: 0.1710 train acc: 0.7200 test acc: 0.5365\n",
      "epoch: 600 loss: 0.1692 train acc: 0.7500 test acc: 0.5375\n",
      "epoch: 650 loss: 0.1676 train acc: 0.7500 test acc: 0.5395\n",
      "epoch: 700 loss: 0.1662 train acc: 0.7600 test acc: 0.5395\n",
      "epoch: 750 loss: 0.1650 train acc: 0.7800 test acc: 0.5355\n",
      "epoch: 800 loss: 0.1639 train acc: 0.7800 test acc: 0.5345\n",
      "epoch: 850 loss: 0.1629 train acc: 0.7800 test acc: 0.5355\n",
      "epoch: 900 loss: 0.1620 train acc: 0.7900 test acc: 0.5340\n",
      "epoch: 950 loss: 0.1611 train acc: 0.8100 test acc: 0.5380\n",
      "epoch: 0 loss: 0.2752 train acc: 0.5320 test acc: 0.5220\n",
      "epoch: 50 loss: 0.2528 train acc: 0.5760 test acc: 0.5560\n",
      "epoch: 100 loss: 0.2415 train acc: 0.6080 test acc: 0.5660\n",
      "epoch: 150 loss: 0.2333 train acc: 0.6360 test acc: 0.5695\n",
      "epoch: 200 loss: 0.2264 train acc: 0.6540 test acc: 0.5745\n",
      "epoch: 250 loss: 0.2202 train acc: 0.6640 test acc: 0.5845\n",
      "epoch: 300 loss: 0.2147 train acc: 0.6620 test acc: 0.5845\n",
      "epoch: 350 loss: 0.2098 train acc: 0.6620 test acc: 0.5945\n",
      "epoch: 400 loss: 0.2056 train acc: 0.6720 test acc: 0.6070\n",
      "epoch: 450 loss: 0.2020 train acc: 0.6760 test acc: 0.6165\n",
      "epoch: 500 loss: 0.1991 train acc: 0.6880 test acc: 0.6105\n",
      "epoch: 550 loss: 0.1966 train acc: 0.6980 test acc: 0.6110\n",
      "epoch: 600 loss: 0.1945 train acc: 0.6940 test acc: 0.6130\n",
      "epoch: 650 loss: 0.1927 train acc: 0.6920 test acc: 0.6125\n",
      "epoch: 700 loss: 0.1912 train acc: 0.6980 test acc: 0.6195\n",
      "epoch: 750 loss: 0.1898 train acc: 0.6960 test acc: 0.6200\n",
      "epoch: 800 loss: 0.1887 train acc: 0.7040 test acc: 0.6195\n",
      "epoch: 850 loss: 0.1876 train acc: 0.7160 test acc: 0.6200\n",
      "epoch: 900 loss: 0.1867 train acc: 0.7220 test acc: 0.6200\n",
      "epoch: 950 loss: 0.1859 train acc: 0.7240 test acc: 0.6225\n",
      "epoch: 0 loss: 0.2841 train acc: 0.5200 test acc: 0.5220\n",
      "epoch: 50 loss: 0.2672 train acc: 0.5560 test acc: 0.5420\n",
      "epoch: 100 loss: 0.2543 train acc: 0.5770 test acc: 0.5545\n",
      "epoch: 150 loss: 0.2431 train acc: 0.5940 test acc: 0.5615\n",
      "epoch: 200 loss: 0.2350 train acc: 0.6170 test acc: 0.5760\n",
      "epoch: 250 loss: 0.2292 train acc: 0.6360 test acc: 0.5860\n",
      "epoch: 300 loss: 0.2248 train acc: 0.6460 test acc: 0.5930\n",
      "epoch: 350 loss: 0.2212 train acc: 0.6510 test acc: 0.5965\n",
      "epoch: 400 loss: 0.2184 train acc: 0.6570 test acc: 0.5995\n",
      "epoch: 450 loss: 0.2162 train acc: 0.6620 test acc: 0.6035\n",
      "epoch: 500 loss: 0.2143 train acc: 0.6720 test acc: 0.5985\n",
      "epoch: 550 loss: 0.2128 train acc: 0.6750 test acc: 0.5935\n",
      "epoch: 600 loss: 0.2115 train acc: 0.6800 test acc: 0.5930\n",
      "epoch: 650 loss: 0.2104 train acc: 0.6830 test acc: 0.5900\n",
      "epoch: 700 loss: 0.2093 train acc: 0.6830 test acc: 0.5920\n",
      "epoch: 750 loss: 0.2084 train acc: 0.6830 test acc: 0.5915\n",
      "epoch: 800 loss: 0.2075 train acc: 0.6790 test acc: 0.5930\n",
      "epoch: 850 loss: 0.2066 train acc: 0.6670 test acc: 0.5975\n",
      "epoch: 900 loss: 0.2058 train acc: 0.6670 test acc: 0.6000\n",
      "epoch: 950 loss: 0.2051 train acc: 0.6690 test acc: 0.6010\n",
      "epoch: 0 loss: 0.2965 train acc: 0.5500 test acc: 0.5300\n",
      "epoch: 50 loss: 0.2544 train acc: 0.5700 test acc: 0.5310\n",
      "epoch: 100 loss: 0.2252 train acc: 0.5700 test acc: 0.5220\n",
      "epoch: 150 loss: 0.2070 train acc: 0.6400 test acc: 0.5180\n",
      "epoch: 200 loss: 0.1936 train acc: 0.6800 test acc: 0.5175\n",
      "epoch: 250 loss: 0.1819 train acc: 0.7300 test acc: 0.5185\n",
      "epoch: 300 loss: 0.1724 train acc: 0.7400 test acc: 0.5245\n",
      "epoch: 350 loss: 0.1651 train acc: 0.7600 test acc: 0.5355\n",
      "epoch: 400 loss: 0.1597 train acc: 0.7600 test acc: 0.5425\n",
      "epoch: 450 loss: 0.1558 train acc: 0.7800 test acc: 0.5415\n",
      "epoch: 500 loss: 0.1528 train acc: 0.7700 test acc: 0.5430\n",
      "epoch: 550 loss: 0.1506 train acc: 0.7800 test acc: 0.5475\n",
      "epoch: 600 loss: 0.1489 train acc: 0.7800 test acc: 0.5490\n",
      "epoch: 650 loss: 0.1475 train acc: 0.8000 test acc: 0.5480\n",
      "epoch: 700 loss: 0.1464 train acc: 0.8000 test acc: 0.5490\n",
      "epoch: 750 loss: 0.1455 train acc: 0.8000 test acc: 0.5500\n",
      "epoch: 800 loss: 0.1447 train acc: 0.8000 test acc: 0.5505\n",
      "epoch: 850 loss: 0.1440 train acc: 0.8100 test acc: 0.5485\n",
      "epoch: 900 loss: 0.1434 train acc: 0.8200 test acc: 0.5470\n",
      "epoch: 950 loss: 0.1429 train acc: 0.8200 test acc: 0.5455\n",
      "epoch: 0 loss: 0.3055 train acc: 0.5040 test acc: 0.5300\n",
      "epoch: 50 loss: 0.2783 train acc: 0.5320 test acc: 0.5570\n",
      "epoch: 100 loss: 0.2562 train acc: 0.5680 test acc: 0.5700\n",
      "epoch: 150 loss: 0.2415 train acc: 0.6020 test acc: 0.5885\n",
      "epoch: 200 loss: 0.2315 train acc: 0.6160 test acc: 0.5910\n",
      "epoch: 250 loss: 0.2243 train acc: 0.6400 test acc: 0.5955\n",
      "epoch: 300 loss: 0.2190 train acc: 0.6620 test acc: 0.5970\n",
      "epoch: 350 loss: 0.2148 train acc: 0.6740 test acc: 0.5990\n",
      "epoch: 400 loss: 0.2115 train acc: 0.6700 test acc: 0.6020\n",
      "epoch: 450 loss: 0.2088 train acc: 0.6760 test acc: 0.6015\n",
      "epoch: 500 loss: 0.2065 train acc: 0.6860 test acc: 0.6015\n",
      "epoch: 550 loss: 0.2045 train acc: 0.6880 test acc: 0.6025\n",
      "epoch: 600 loss: 0.2027 train acc: 0.6920 test acc: 0.6025\n",
      "epoch: 650 loss: 0.2010 train acc: 0.6920 test acc: 0.6010\n",
      "epoch: 700 loss: 0.1995 train acc: 0.6980 test acc: 0.6000\n",
      "epoch: 750 loss: 0.1981 train acc: 0.7040 test acc: 0.5995\n",
      "epoch: 800 loss: 0.1969 train acc: 0.7040 test acc: 0.5995\n",
      "epoch: 850 loss: 0.1958 train acc: 0.7040 test acc: 0.6000\n",
      "epoch: 900 loss: 0.1948 train acc: 0.7020 test acc: 0.6000\n",
      "epoch: 950 loss: 0.1939 train acc: 0.7020 test acc: 0.6015\n",
      "epoch: 0 loss: 0.3054 train acc: 0.5050 test acc: 0.5300\n",
      "epoch: 50 loss: 0.2808 train acc: 0.5300 test acc: 0.5560\n",
      "epoch: 100 loss: 0.2624 train acc: 0.5700 test acc: 0.5675\n",
      "epoch: 150 loss: 0.2510 train acc: 0.5880 test acc: 0.5955\n",
      "epoch: 200 loss: 0.2431 train acc: 0.6050 test acc: 0.6010\n",
      "epoch: 250 loss: 0.2370 train acc: 0.6180 test acc: 0.6025\n",
      "epoch: 300 loss: 0.2321 train acc: 0.6330 test acc: 0.6080\n",
      "epoch: 350 loss: 0.2279 train acc: 0.6470 test acc: 0.6120\n",
      "epoch: 400 loss: 0.2244 train acc: 0.6430 test acc: 0.6165\n",
      "epoch: 450 loss: 0.2213 train acc: 0.6420 test acc: 0.6150\n",
      "epoch: 500 loss: 0.2186 train acc: 0.6430 test acc: 0.6210\n",
      "epoch: 550 loss: 0.2161 train acc: 0.6470 test acc: 0.6225\n",
      "epoch: 600 loss: 0.2138 train acc: 0.6440 test acc: 0.6245\n",
      "epoch: 650 loss: 0.2118 train acc: 0.6490 test acc: 0.6255\n",
      "epoch: 700 loss: 0.2099 train acc: 0.6570 test acc: 0.6275\n",
      "epoch: 750 loss: 0.2082 train acc: 0.6680 test acc: 0.6315\n",
      "epoch: 800 loss: 0.2066 train acc: 0.6720 test acc: 0.6370\n",
      "epoch: 850 loss: 0.2052 train acc: 0.6770 test acc: 0.6350\n",
      "epoch: 900 loss: 0.2040 train acc: 0.6840 test acc: 0.6380\n",
      "epoch: 950 loss: 0.2030 train acc: 0.6830 test acc: 0.6400\n"
     ]
    }
   ],
   "source": [
    "# Compare training-set sizes m = 100, 500 and 1000. In every trial each size is trained\n",
    "# with the same seed (seed[trial]) and scored on the test_x / test_y from an earlier cell.\n",
    "N_QUBITS = 4\n",
    "\n",
    "train_data_m100, test_data_m100 = [], []\n",
    "train_data_m500, test_data_m500 = [], []\n",
    "train_data_m1000, test_data_m1000 = [], []\n",
    "\n",
    "train_data_m100_acc, test_data_m100_acc = [], []\n",
    "train_data_m500_acc, test_data_m500_acc = [], []\n",
    "train_data_m1000_acc, test_data_m1000_acc = [], []\n",
    "\n",
    "# Result lists per training-set size: (train loss, test loss, train acc, test acc).\n",
    "# Dict insertion order fixes the visiting order 100 -> 500 -> 1000 within each trial,\n",
    "# matching the order of the three formerly copy-pasted sections.\n",
    "results_by_size = {\n",
    "    100: (train_data_m100, test_data_m100, train_data_m100_acc, test_data_m100_acc),\n",
    "    500: (train_data_m500, test_data_m500, train_data_m500_acc, test_data_m500_acc),\n",
    "    1000: (train_data_m1000, test_data_m1000, train_data_m1000_acc, test_data_m1000_acc),\n",
    "}\n",
    "\n",
    "for trial in range(v):\n",
    "    for m, (tr_loss_runs, te_loss_runs, tr_acc_runs, te_acc_runs) in results_by_size.items():\n",
    "        # Only the circuits and labels from train_data are used; the first return value\n",
    "        # (presumably the encoded states) is discarded. TODO confirm against train_data.\n",
    "        _, train_circuit, train_y = train_data(m, [0, 1])\n",
    "\n",
    "        # One unitary per sample, each given a leading axis via unsqueeze(0).\n",
    "        train_x = []\n",
    "        for i in range(len(train_y)):\n",
    "            cir_train = Circuit(N_QUBITS)\n",
    "            cir_train.extend(train_circuit[i][0])\n",
    "            train_x.append(cir_train.unitary_matrix().unsqueeze(0))\n",
    "\n",
    "        te_acc, tr_acc, te_loss, tr_loss = QClassifier(\n",
    "            train_x,\n",
    "            train_y,\n",
    "            test_x,\n",
    "            test_y,\n",
    "            N=N_QUBITS,\n",
    "            DEPTH=D,\n",
    "            EPOCH=ep,\n",
    "            LR=eta,\n",
    "            seed=seed[trial],\n",
    "        )\n",
    "\n",
    "        tr_loss_runs.append(tr_loss)\n",
    "        te_loss_runs.append(te_loss)\n",
    "        tr_acc_runs.append(tr_acc)\n",
    "        te_acc_runs.append(te_acc)\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "paddle_quantum_env",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
