{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "10bfd766",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pennylane as qml\n",
    "from pennylane import numpy as np\n",
    "from pennylane.templates import RandomLayers\n",
    "from pennylane.transforms import commute_controlled, cancel_inverses"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "df880d38",
   "metadata": {},
   "outputs": [],
   "source": [
    "class qcenv():\n",
    "    '''Environment that grows a circuit step by step to approximate a target state.\n",
    "\n",
    "    Each call to step() appends gates chosen from an action pair, simulates the\n",
    "    circuit with PennyLane and compares the resulting state with the\n",
    "    L2-normalized target data.\n",
    "\n",
    "    Args:\n",
    "        top_type: allowed [control, target] wire pairs for CNOT gates,\n",
    "            e.g. [[0, 2], [0, 3], [1, 3]].\n",
    "        data: target vector; stored L2-normalized. It should have as many\n",
    "            entries as the simulated state vector.\n",
    "        epsilon: stop threshold on the L2 distance between data and state.\n",
    "        weight: rotation scale factors. W[0..N-1] scale the first RY layer and\n",
    "            W[cur_ind + N - 1] scales the RY of each later step, so a full\n",
    "            episode needs 2**N + N - 1 entries.\n",
    "        n_qubits: number of qubits N.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if data has zero norm and cannot be normalized.\n",
    "    '''\n",
    "\n",
    "    def __init__(self, top_type, data, epsilon, weight, n_qubits):\n",
    "        data_norm = np.linalg.norm(data)\n",
    "        if data_norm == 0:\n",
    "            raise ValueError('data must have a non-zero norm')\n",
    "        self.gates = []\n",
    "        self.top_type = top_type\n",
    "        self.data = data / data_norm\n",
    "        self._pro_data()\n",
    "        self._build_state_embedding()\n",
    "        self.cur_ind = 0  # step being generated; step 0 is the first RY layer\n",
    "        self.W = weight\n",
    "        self.N = n_qubits\n",
    "        self.epsilon = epsilon  # minimum error: the episode ends once dist <= epsilon\n",
    "\n",
    "    def _pro_data(self):\n",
    "        '''Cache the mean and standard deviation of the normalized data.'''\n",
    "        self.mean = np.mean(self.data)\n",
    "        self.std = np.std(self.data)\n",
    "\n",
    "    def _select_cnot_wires(self, ctr_act):\n",
    "        '''Pick the top_type pair that touches the qubit with the largest control score.\n",
    "\n",
    "        If several pairs touch that qubit, the last one listed wins.\n",
    "        TODO: decide how to choose between candidates such as [0, 2] and [0, 3].\n",
    "\n",
    "        Raises:\n",
    "            ValueError: if no pair in top_type touches that qubit.\n",
    "        '''\n",
    "        best_qubit = int(np.argmax(ctr_act))\n",
    "        candidates = [wires for wires in self.top_type if best_qubit in (wires[0], wires[1])]\n",
    "        if not candidates:\n",
    "            raise ValueError(f'no pair in top_type {self.top_type} touches qubit {best_qubit}')\n",
    "        return candidates[-1]\n",
    "\n",
    "    def _gen_gates(self, actions):\n",
    "        '''Append the gates for the current step to self.gates.\n",
    "\n",
    "        actions = [ctr_act, rot_act], each of length N, e.g.\n",
    "        [[0.2, 0.1, 0.9, 0.5], [0.2, 0.1, 0.9, 0.5]].\n",
    "        Values go through arcsin, so they must lie in [-1, 1].\n",
    "\n",
    "        Step 0 applies one RY per qubit. Every later step applies one CNOT\n",
    "        from top_type followed by an RY on its target wire.\n",
    "\n",
    "        Returns:\n",
    "            (this_ctr_layer, this_rot_layer): length-N lists. The control layer\n",
    "            marks both CNOT wires with 1; the rotation layer holds the RY angle\n",
    "            on each wire (0 where no RY was applied).\n",
    "        '''\n",
    "        ctr_act, rot_act = actions[0], actions[1]\n",
    "        assert len(ctr_act) == self.N\n",
    "        assert len(rot_act) == self.N\n",
    "        this_ctr_layer = [0] * len(ctr_act)\n",
    "        this_rot_layer = [0] * len(rot_act)\n",
    "\n",
    "        if self.cur_ind == 0:\n",
    "            for i, a in enumerate(rot_act):\n",
    "                angle = (np.arcsin(a) * self.std + self.mean) * self.W[i]\n",
    "                self.gates.append(qml.RY(angle, wires=i))\n",
    "                this_rot_layer[i] = angle\n",
    "        else:\n",
    "            gate = self._select_cnot_wires(ctr_act)\n",
    "            # W[0..N-1] belong to the first layer, so step k uses W[k + N - 1].\n",
    "            scale = self.W[self.cur_ind + self.N - 1]\n",
    "            angle = (np.arcsin(rot_act[gate[1]]) * self.std + self.mean) * scale\n",
    "            self.gates.append(qml.CNOT(wires=gate))\n",
    "            self.gates.append(qml.RY(angle, wires=gate[1]))\n",
    "            this_ctr_layer[gate[0]] = 1\n",
    "            this_ctr_layer[gate[1]] = 1\n",
    "            this_rot_layer[gate[1]] = angle\n",
    "        return this_ctr_layer, this_rot_layer\n",
    "\n",
    "    def _circuit(self):\n",
    "        '''Quantum function for the QNode: apply every recorded gate and return the state.'''\n",
    "        for op in self.gates:\n",
    "            qml.apply(op)\n",
    "        return qml.state()\n",
    "\n",
    "    def step(self, dev, actions):\n",
    "        '''Add the gates for one step, simulate on dev and score the result.\n",
    "\n",
    "        Args:\n",
    "            dev: PennyLane device used to build the QNode.\n",
    "            actions: [ctr_act, rot_act], see _gen_gates.\n",
    "\n",
    "        Returns:\n",
    "            (qcircuit, reward, done, obs): the QNode of the current circuit;\n",
    "            the reward, 0 unless the episode ended and otherwise minus the mean\n",
    "            of |data - state| / data (assumes data has no zero entries);\n",
    "            whether the episode ended; and the observation list\n",
    "            [data, mean, std, layer_1, ...], which is the same list object on\n",
    "            every call and grows in place.\n",
    "        '''\n",
    "        reward = 0\n",
    "        ctr_layer, rot_layer = self._gen_gates(actions)\n",
    "        qcircuit = qml.QNode(self._circuit, dev)\n",
    "        res = qcircuit()\n",
    "        # Compare with the normalized target stored on the instance, never a global.\n",
    "        dist = self._norm(self.data, res)\n",
    "        self.layer_embedding.append([ctr_layer, rot_layer])\n",
    "        obs = self.layer_embedding\n",
    "        if self._is_final_layer(dist):\n",
    "            reward = -np.mean(np.abs(self.data - res) / self.data)\n",
    "            done = True\n",
    "            return qcircuit, reward, done, obs\n",
    "\n",
    "        self.cur_ind += 1\n",
    "        done = False\n",
    "        return qcircuit, reward, done, obs\n",
    "\n",
    "    def reset(self):\n",
    "        '''Start a new episode: drop all gates and return the initial observation.'''\n",
    "        self.gates = []\n",
    "        self.cur_ind = 0\n",
    "        self._build_state_embedding()\n",
    "        return self.layer_embedding\n",
    "\n",
    "    def _norm(self, org, res, norm=2):\n",
    "        '''Return the order-`norm` vector norm of org - res (Euclidean by default).'''\n",
    "        return np.linalg.norm(org - res, ord=norm)\n",
    "\n",
    "    def _is_final_layer(self, dist):\n",
    "        '''True on the last allowed step (cur_ind == 2**N - 1) or once dist <= epsilon.'''\n",
    "        return (self.cur_ind == 2**self.N - 1) or (dist <= self.epsilon)\n",
    "\n",
    "    def _build_state_embedding(self):\n",
    "        '''Reset the observation list to [data, mean, std]; step() appends one entry per step.'''\n",
    "        self.layer_embedding = [self.data, self.mean, self.std]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e8d5ef7a",
   "metadata": {},
   "outputs": [],
   "source": [
    "np.random.seed(42)  # fix the random target and weights so the run is reproducible\n",
    "\n",
    "n_qubits = 4\n",
    "top_type = [[0,2],[0,3],[1,3]]  # allowed [control, target] CNOT pairs\n",
    "data = np.random.rand(2**n_qubits)  # target vector, one entry per basis state\n",
    "weight = np.random.rand(2**n_qubits+n_qubits-1)  # n_qubits first-layer weights + one per later step\n",
    "epsilon = 1e-4  # stop once the L2 distance between data and state is <= epsilon\n",
    "env = qcenv(top_type,data,epsilon,weight,n_qubits)\n",
    "actions = [[0.2,0.1,0.9,0.5],[0.2,0.1,0.9,0.5]]  # [control scores, rotation values]\n",
    "dev = qml.device('default.qubit', wires=n_qubits)\n",
    "done = False\n",
    "while not done:\n",
    "    qc, reward, done, obs = env.step(dev, actions)\n",
    "print(reward)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "93c80d7a",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0: ──RY(0.20)─╭●───────────╭●───────────╭●───────────╭●───────────╭●───────────╭●───────────╭●\n",
      "1: ──RY(0.06)─│────────────│────────────│────────────│────────────│────────────│────────────│─\n",
      "2: ──RY(0.17)─╰X──RY(0.01)─╰X──RY(0.25)─╰X──RY(0.26)─╰X──RY(0.04)─╰X──RY(0.09)─╰X──RY(0.28)─╰X\n",
      "3: ──RY(0.02)─────────────────────────────────────────────────────────────────────────────────\n",
      "\n",
      "────────────╭●───────────╭●───────────╭●───────────╭●───────────╭●───────────╭●───────────╭●\n",
      "────────────│────────────│────────────│────────────│────────────│────────────│────────────│─\n",
      "───RY(0.25)─╰X──RY(0.17)─╰X──RY(0.23)─╰X──RY(0.16)─╰X──RY(0.19)─╰X──RY(0.07)─╰X──RY(0.19)─╰X\n",
      "────────────────────────────────────────────────────────────────────────────────────────────\n",
      "\n",
      "────────────╭●───────────┤  State\n",
      "────────────│────────────┤  State\n",
      "───RY(0.07)─╰X──RY(0.20)─┤  State\n",
      "─────────────────────────┤  State\n"
     ]
    }
   ],
   "source": [
    "# Text drawing of the circuit `qc` from the last env.step call above\n",
    "circuit_diagram = qml.draw(qc)()\n",
    "print(circuit_diagram)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0584cb32",
   "metadata": {},
   "outputs": [],
   "source": [
    "def train(num_episode, anget, env, output, warmup=0, dev=None):\n",
    "    '''Run a reinforcement-learning loop for num_episode episodes (draft).\n",
    "\n",
    "    Each step the agent proposes [ctr_action, rot_action], the environment\n",
    "    appends the corresponding gates, and the transition is recorded in T.\n",
    "\n",
    "    Args:\n",
    "        num_episode: number of episodes to run.\n",
    "        anget: the agent (parameter name kept as-is for existing callers);\n",
    "            must provide reset(obs), random_action() and\n",
    "            select_action(obs, episode=...).\n",
    "        env: environment with reset() and step(dev, actions) returning\n",
    "            (qcircuit, reward, done, obs), such as qcenv.\n",
    "        output: not used yet (TODO: presumably a checkpoint/log location).\n",
    "        warmup: episodes whose index is <= warmup use random actions.\n",
    "        dev: PennyLane device passed to env.step; defaults to\n",
    "            default.qubit with env.N wires.\n",
    "    '''\n",
    "    from copy import deepcopy\n",
    "\n",
    "    agent = anget\n",
    "    if dev is None:\n",
    "        dev = qml.device('default.qubit', wires=env.N)\n",
    "    episode = 0\n",
    "    observation = None\n",
    "    T = []  # transitions of the current episode: [reward, obs, next_obs, actions, done]\n",
    "    while episode < num_episode:\n",
    "        if observation is None:\n",
    "            observation = deepcopy(env.reset())\n",
    "            agent.reset(observation)\n",
    "        if episode <= warmup:\n",
    "            ctr_action = agent.random_action()\n",
    "            rot_action = agent.random_action()\n",
    "        else:\n",
    "            ctr_action = agent.select_action(observation, episode=episode)\n",
    "            rot_action = agent.select_action(observation, episode=episode)\n",
    "        actions = [ctr_action, rot_action]\n",
    "\n",
    "        _, reward, done, observation2 = env.step(dev, actions)\n",
    "        # qcenv.step returns its internal observation list, which it keeps\n",
    "        # mutating, so take a snapshot.\n",
    "        observation2 = deepcopy(observation2)\n",
    "        T.append([reward, observation, observation2, actions, done])\n",
    "        # TODO: pass the transitions to the agent (e.g. observe / update the policy).\n",
    "        if done:\n",
    "            # TODO: end-of-episode bookkeeping (e.g. logging, saving to output).\n",
    "            observation = None\n",
    "            T = []\n",
    "            episode += 1\n",
    "        else:\n",
    "            observation = observation2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "33035061",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Same target, weights, device and actions as above, but a star topology:\n",
    "# every allowed CNOT targets qubit 3.\n",
    "star_top_type = [[0,3],[1,3],[2,3]]\n",
    "star_env = qcenv(star_top_type, data, epsilon, weight, n_qubits)\n",
    "for _ in range(6):  # build only the first 6 steps\n",
    "    qc, star_reward, star_done, star_obs = star_env.step(dev, actions)\n",
    "    if star_done:\n",
    "        break\n",
    "print(qc())  # state vector of the circuit after these steps"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 78,
   "id": "a88b1d2c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0: ──RY(0.20)──────────────────────────────────────────────────────────────────┤  State\n",
      "1: ──RY(0.10)──────────────────────────────────────────────────────────────────┤  State\n",
      "2: ──RY(1.12)─╭●───────────╭●───────────╭●───────────╭●───────────╭●───────────┤  State\n",
      "3: ──RY(0.52)─╰X──RY(0.50)─╰X──RY(0.50)─╰X──RY(0.50)─╰X──RY(0.50)─╰X──RY(0.50)─┤  State\n"
     ]
    }
   ],
   "source": [
    "# Text drawing of the circuit `qc` built in the previous cell\n",
    "star_diagram = qml.draw(qc)()\n",
    "print(star_diagram)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "03b2871e",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
