{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "73295118",
   "metadata": {
    "is_executing": true
   },
   "outputs": [],
   "source": [
    "import itertools\n",
    "import os\n",
    "import warnings\n",
    "\n",
    "import numpy as np\n",
    "import scipy.sparse as sparse\n",
    "import scipy.sparse.linalg as arp\n",
    "from tqdm import tqdm"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "34ea445e",
   "metadata": {},
   "source": [
    "## Observables"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "f52655a1",
   "metadata": {},
   "outputs": [],
   "source": [
    "Pauli_I = np.array([[1,0],[0,1]],dtype=np.complex64)\n",
    "Pauli_x = np.array([[0,1],[1,0]],dtype=np.complex64)\n",
    "Pauli_y = np.array([[0,-1j],[1j,0]],dtype=np.complex64)\n",
    "Pauli_z = np.array([[1,0],[0,-1]],dtype=np.complex64)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "56e48bd2",
   "metadata": {},
   "outputs": [],
   "source": [
    "def generate_basis_observables(num_bit):\n",
    "    \"\"\"Build and save every `num_bit`-qubit Pauli string over {X, Y, Z}.\n",
    "\n",
    "    For each of the 3**num_bit label tuples produced by\n",
    "    `itertools.product(range(1, 4), repeat=num_bit)` (1 -> X, 2 -> Y, 3 -> Z),\n",
    "    the Kronecker product of the corresponding Pauli matrices is written to\n",
    "    `<num_bit>qubit/observable<num_bit><k>.npy`. A real-valued version (the\n",
    "    flattened real parts followed by the flattened imaginary parts) is written\n",
    "    to `<num_bit>qubit/float_observable<num_bit><k>.npy`. `k` is the position\n",
    "    of the tuple in the product order.\n",
    "\n",
    "    The output directory is created if it does not exist. The number of\n",
    "    observables is printed before the loop starts.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    num_bit : int\n",
    "        Number of qubits; must be at least 1.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    int\n",
    "        Always 0.\n",
    "\n",
    "    Raises\n",
    "    ------\n",
    "    ValueError\n",
    "        If `num_bit` is smaller than 1.\n",
    "    \"\"\"\n",
    "    if num_bit < 1:\n",
    "        raise ValueError('num_bit must be at least 1')\n",
    "\n",
    "    # The identity never appears in a label tuple, so only X, Y, Z are mapped.\n",
    "    pauli_by_label = {1: Pauli_x, 2: Pauli_y, 3: Pauli_z}\n",
    "    permutation_list = list(itertools.product(range(1, 4), repeat=num_bit))\n",
    "    print(len(permutation_list))\n",
    "\n",
    "    # np.save does not create missing directories.\n",
    "    out_dir = str(num_bit) + 'qubit'\n",
    "    os.makedirs(out_dir, exist_ok=True)\n",
    "\n",
    "    for k, labels in enumerate(tqdm(permutation_list)):\n",
    "        observable = pauli_by_label[labels[0]]\n",
    "        for label in labels[1:]:\n",
    "            observable = np.kron(observable, pauli_by_label[label])\n",
    "        np.save(out_dir + '/observable' + str(num_bit) + str(k), observable)\n",
    "\n",
    "        # Real-valued encoding: all real parts first, then all imaginary parts.\n",
    "        flat = observable.reshape(observable.shape[0] * observable.shape[1])\n",
    "        float_observable = np.concatenate((flat.real, flat.imag))\n",
    "        np.save(out_dir + '/float_observable' + str(num_bit) + str(k), float_observable)\n",
    "\n",
    "    return 0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "90949bca",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "729\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 729/729 [00:01<00:00, 461.21it/s]\n"
     ]
    }
   ],
   "source": [
    "# Number of qubits; later cells read this module-level name.\n",
    "num_bit = 6\n",
    "_ = generate_basis_observables(num_bit=num_bit)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fd3f93cd",
   "metadata": {},
   "source": [
    "## State Generation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "64d88916",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Alternative names bound to the same Pauli arrays (no copies are made).\n",
    "sxx, syy, szz = Pauli_x, Pauli_y, Pauli_z"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "88f125f9",
   "metadata": {},
   "outputs": [],
   "source": [
    "def exact_E_rand_Js(L, Js, h):\n",
    "    \"\"\"For comparison: obtain ground state energy from exact diagonalization.\n",
    "    Exponentially expensive in L, only works for small enough `L` <~ 20.\n",
    "    \"\"\"\n",
    "    if L >= 20:\n",
    "        warnings.warn(\"Large L: Exact diagonalization might take a long time!\")\n",
    "    # get single site operaors\n",
    "    sx = sparse.csr_matrix(np.array([[0., 1.], [1., 0.]]))\n",
    "    sz = sparse.csr_matrix(np.array([[1., 0.], [0., -1.]]))\n",
    "    id = sparse.csr_matrix(np.eye(2))\n",
    "    sx_list = []  # sx_list[i] = kron([id, id, ..., id, sx, id, .... id])\n",
    "    sz_list = []\n",
    "    for i_site in range(L):\n",
    "        x_ops = [id] * L\n",
    "        z_ops = [id] * L\n",
    "        x_ops[i_site] = sx\n",
    "        z_ops[i_site] = sz\n",
    "        X = x_ops[0]\n",
    "        Z = z_ops[0]\n",
    "        for j in range(1, L):\n",
    "            X = sparse.kron(X, x_ops[j], 'csr')\n",
    "            Z = sparse.kron(Z, z_ops[j], 'csr')\n",
    "        sx_list.append(X)\n",
    "        sz_list.append(Z)\n",
    "    H_x = sparse.csr_matrix((2**L, 2**L))\n",
    "    H_zz = sparse.csr_matrix((2**L, 2**L))\n",
    "    for i in range(L - 1):\n",
    "        rand_J = Js[i]\n",
    "        H_zz = H_zz + rand_J*sz_list[i] * sz_list[(i + 1) % L]\n",
    "    for i in range(L):\n",
    "        H_x = H_x + sx_list[i]\n",
    "    H = -H_zz - h * H_x\n",
    "    E, V = arp.eigsh(H, k=2, which='SA', return_eigenvectors=True, ncv=20)\n",
    "    return V[:,0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "a05e683f",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "e846028c",
   "metadata": {},
   "source": [
    "## Probability Generation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "db999657",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Eigenbasis (as columns) of every saved Pauli-string observable.\n",
    "# The observables are Hermitian, so np.linalg.eigh is used: it returns an\n",
    "# orthonormal set of eigenvectors even for degenerate eigenvalues, which\n",
    "# np.linalg.eig does not guarantee.\n",
    "# NOTE(review): the eigenvalues +1/-1 are highly degenerate, so the basis inside\n",
    "# each eigenspace is not unique and need not be the tensor product of the\n",
    "# single-qubit eigenbases - confirm that this is the intended measurement basis.\n",
    "vs = []\n",
    "for j in range(0, 3**num_bit):\n",
    "    observable = np.load(str(num_bit)+'qubit/observable'+str(num_bit)+str(j) + '.npy')\n",
    "    vs.append(np.linalg.eigh(observable)[1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "091eb44c",
   "metadata": {},
   "outputs": [],
   "source": [
    "def exact_E_rand_Js_probs(state,L,vs):\n",
    "    values = []\n",
    "    for j in range(0,3**num_bit):\n",
    "        tmp = []\n",
    "        for k in range(0,2**num_bit):\n",
    "            tmp.append(np.abs(np.inner(state.conj().T,vs[j][:,k]))**2)\n",
    "        values.append(tmp)\n",
    "    return values"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "645b2196",
   "metadata": {},
   "source": [
    "## Dataset Generation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "id": "83027f0d",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2000/2000 [09:37<00:00,  3.46it/s]\n"
     ]
    }
   ],
   "source": [
    "L = num_bit\n",
    "h = 1\n",
    "num_states = 2000\n",
    "# Fixed seed so that Restart & Run All regenerates the same random couplings.\n",
    "SEED = 0\n",
    "rng = np.random.default_rng(SEED)\n",
    "for i in tqdm(range(num_states)):\n",
    "    # Couplings drawn uniformly from [-1, 1).\n",
    "    Js = rng.random(L) * 2 - 1\n",
    "    state = exact_E_rand_Js(L, Js, h)\n",
    "    probs = exact_E_rand_Js_probs(state, L, vs)\n",
    "    np.save(str(L)+'qubit/Ising_ground_state_'+str(L)+'qubit_probs_random_'+ str(i)+'.npy', probs)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
