{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "296b01d4",
   "metadata": {},
   "source": [
    "# Results from Shaydulin and Wild (2021)\n",
    "\n",
    "#### Code taken from: https://github.com/rsln-s/Importance-of-Kernel-Bandwidth-in-Quantum-Machine-Learning/\n",
    "\n",
    "#### Requires downloading the reduced_data.npz file (~8.6 GB) (Derived from Shaydulin and Wild (2022)):\n",
    "\n",
    "https://drive.google.com/file/d/1a-rDDY-sPI8cigccL44PsfoelN6hyhJL/view?usp=sharing"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e619dc57",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from matplotlib import rc\n",
    "import pandas as pd\n",
    "import seaborn as sns\n",
    "\n",
    "rc('text', usetex=False)\n",
    "plt.rcParams.update({'font.size': 16})\n",
    "axis_label_fontsize=36\n",
    "axis_tick_fontsize=22\n",
    "axis_legend_fontsize=24\n",
    "\n",
    "import utils\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2879088b",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_kernel(df, c):\n",
    "    ## Assumes a df for a specific model, dataset and num_qubits\n",
    "    \n",
    "    df_c = df[(df[c_key] == c)]\n",
    "    assert len(df_c) == 1, \"The dataset has more than one item!\"\n",
    "    \n",
    "    test_score_svm = df_c.test_score.to_numpy()[0]   ## Test score for the SVM on the full kernel matrices\n",
    "    K_tr_tr = df_c.qkern_matrix_train.to_numpy()[0]  ## 800 x 800 kernel Gram matrix on full training set\n",
    "    K_tr_test = df_c.qkern_matrix_test.to_numpy()[0] ## 200 x 800 kernel Gram matrix on full train-test set\n",
    "    \n",
    "    ## Perform eigendecomposition on the full training kernel Gram matrix\n",
    "    P = len(K_tr_tr)\n",
    "    eigs_tr, vecs_tr = np.linalg.eigh(K_tr_tr / P)\n",
    "    eigs_tr = eigs_tr[::-1]\n",
    "    vecs_tr = np.sqrt(P) * vecs_tr[:,::-1]\n",
    "    \n",
    "    return {'c': c,\n",
    "            'test_score_svm': test_score_svm,\n",
    "            'K_tr_tr': K_tr_tr,\n",
    "            'K_tr_test': K_tr_test,\n",
    "            'eigs_tr': eigs_tr,\n",
    "            'vecs_tr': vecs_tr,\n",
    "           }\n",
    "\n",
    "def kt_alignment(K, y):\n",
    "    y = y.reshape(len(y), 1)\n",
    "    K_yy = y @ y.T\n",
    "    \n",
    "    return np.trace(K@K_yy)/ np.sqrt(np.trace(K@K)*np.trace(K_yy@K_yy))\n",
    "\n",
    "def get_dataset(ds, num_qubits, P_tr=800, P_test=200):\n",
    "    \n",
    "    return utils.get_dataset(ds, num_qubits, P_tr, P_test)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5af4404e",
   "metadata": {},
   "source": [
    "# Process all quantum simulation data."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "83d789d7",
   "metadata": {},
   "outputs": [],
   "source": [
    "## Lists for the available models and datasets.\n",
    "model_list = ['IQP', 'EVO']\n",
    "ds_list = ['fmnist', 'kmnist', 'plasticc']\n",
    "rho = np.arange(1, 801, 1)\n",
    "\n",
    "try:\n",
    "    print('Retreiving the Tables...')\n",
    "    all_data = np.load('./reduced_data.npz', allow_pickle=True)\n",
    "    data_table = all_data['data_table'].tolist()\n",
    "    test_table = all_data['test_table'].tolist()\n",
    "    print('Done!')\n",
    "except:\n",
    "    print('Please download reduced_data.npz')\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a4e959e2",
   "metadata": {},
   "source": [
    "# Create the Table 2 in Main Text"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e41ed93a",
   "metadata": {},
   "outputs": [],
   "source": [
    "model = 'IQP'\n",
    "num_qubits = 22\n",
    "\n",
    "for ds in ds_list:\n",
    "    model_ds = model + '_' + ds\n",
    "    \n",
    "    c_list = test_table[model_ds][num_qubits]['c_list']\n",
    "    test_list = test_table[model_ds][num_qubits]['test_list']\n",
    "    \n",
    "    plt.loglog(c_list, test_list)\n",
    "    print(ds, 'MIN - MAX', test_table[model_ds][num_qubits]['max'], test_table[model_ds][num_qubits]['min'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6c977db4",
   "metadata": {},
   "outputs": [],
   "source": [
    "model = 'EVO'\n",
    "num_qubits = 23 ## EVO model requires dataset_dim + 1 qubits\n",
    "\n",
    "for ds in ds_list:\n",
    "    model_ds = model + '_' + ds\n",
    "    \n",
    "    c_list = test_table[model_ds][num_qubits]['c_list']\n",
    "    test_list = test_table[model_ds][num_qubits]['test_list']\n",
    "    \n",
    "    plt.loglog(c_list, test_list)\n",
    "    print(ds, 'MIN - MAX', test_table[model_ds][num_qubits]['max'], test_table[model_ds][num_qubits]['min'])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "681371d9",
   "metadata": {},
   "source": [
    "# Figure 4 B, C - Kernel Eigenvalues and Cumulative Power for IQP Model with KMNIST Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "44344933",
   "metadata": {},
   "outputs": [],
   "source": [
    "model = 'IQP'\n",
    "ds = 'fmnist'\n",
    "dataset_dim = 22\n",
    "\n",
    "## Get the dataset\n",
    "x_train, x_test, y_train, y_test = get_dataset(ds, dataset_dim, 800, 200)\n",
    "\n",
    "## Get the model data\n",
    "model_ds = model + '_' + ds\n",
    "num_qubits = dataset_dim if model == 'IQP' else dataset_dim+1\n",
    "\n",
    "data = data_table[model_ds][num_qubits]\n",
    "c_list = test_table[model_ds][num_qubits]['c_list']\n",
    "\n",
    "print(c_list)\n",
    "\n",
    "c_list_plot = [0.001, 0.01,  0.05,  0.1,   0.5,   1.,]\n",
    "\n",
    "colors = iter(sns.color_palette('viridis', len(c_list_plot)))\n",
    "\n",
    "fig, axs = plt.subplots(1,2, figsize=(14,5))\n",
    "\n",
    "largest_eig = []\n",
    "for data_c in data:\n",
    "    c = data_c['c']\n",
    "    \n",
    "    if c in c_list_plot:\n",
    "        K_tr = data_c['K_tr_tr']\n",
    "        eigs_tr = data_c['eigs_tr']\n",
    "        vecs_tr = data_c['vecs_tr']\n",
    "        \n",
    "        weight_sq = (vecs_tr.T @ y_train / len(y_train))**2\n",
    "        cum_power = np.cumsum(weight_sq)/np.sum(weight_sq)\n",
    "        \n",
    "        largest_eig += [eigs_tr[0]]\n",
    "        \n",
    "        color = next(colors)\n",
    "        axs[0].loglog(rho, eigs_tr, label=f'$c = {c}$', linewidth=3, color=color)\n",
    "        axs[0].set_xlabel('$k$', fontsize=axis_label_fontsize)\n",
    "        axs[0].set_ylabel('$\\eta_k$', fontsize=axis_label_fontsize)\n",
    "        axs[1].semilogx(rho, cum_power, '-',  linewidth=3, color=color)\n",
    "        axs[1].set_xlabel('$k$', fontsize=axis_label_fontsize)\n",
    "        axs[1].set_ylabel('$C(k)$', fontsize=axis_label_fontsize)\n",
    "        \n",
    "axs[0].legend(loc='best', fontsize=axis_legend_fontsize)\n",
    "axs[0].tick_params(axis='both', which='major', labelsize=axis_tick_fontsize)\n",
    "axs[1].tick_params(axis='both', which='major', labelsize=axis_tick_fontsize)\n",
    "plt.tight_layout()\n",
    "plt.show()\n",
    "\n",
    "\n",
    "plt.figure(figsize=(7,5))\n",
    "plt.title('Largest Eigenvalue vs. $c$ (Bandwidth)')\n",
    "plt.plot(c_list_plot, largest_eig, linewidth=3, label=f'$n = {num_qubits}$')\n",
    "plt.legend(fontsize=axis_legend_fontsize)\n",
    "plt.xlabel('$c$', fontsize=axis_label_fontsize)\n",
    "plt.ylabel('$\\eta_{max}$',fontsize=axis_label_fontsize)\n",
    "plt.gca().tick_params(axis='both', which='major', labelsize=axis_tick_fontsize)\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f439c3f0",
   "metadata": {},
   "source": [
    "## IQP Scaling of the Largest Eigenvalue with Number of Qubits"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2116afe3",
   "metadata": {},
   "outputs": [],
   "source": [
    "model = 'IQP'\n",
    "ds = 'fmnist'\n",
    "\n",
    "## Get the dataset\n",
    "x_train, x_test, y_train, y_test = get_dataset(ds, dataset_dim, 800, 200)\n",
    "\n",
    "## Get the model data\n",
    "model_ds = model + '_' + ds\n",
    "num_qubits = dataset_dim if model == 'IQP' else dataset_dim+1\n",
    "\n",
    "data = data_table[model_ds]\n",
    "dim_list = data_table[model_ds]['num_qubits_list']\n",
    "c_list = test_table[model_ds][num_qubits]['c_list']\n",
    "\n",
    "dim_list_plot = dim_list[1:-10] ## Exclude dim = 23\n",
    "c_list_plot = [0.001, 0.005, 0.01, 0.05, 0.1]#, 0.5, 1.]\n",
    "\n",
    "\n",
    "fig, axs = plt.subplots(1,2, figsize=(14,5))\n",
    "colors = iter(sns.color_palette('viridis', len(dim_list_plot)))\n",
    "\n",
    "largest_eig = []\n",
    "for dim, data_dim in zip(dim_list_plot, list(data_table[model_ds].values())[1:]):\n",
    "\n",
    "    colors = iter(sns.color_palette('viridis', len(dim_list_plot)))\n",
    "    largest_eig_c = []\n",
    "    for data_c in data_dim:\n",
    "\n",
    "        c = data_c['c']\n",
    "\n",
    "        if c in c_list_plot:\n",
    "\n",
    "            K_tr = data_c['K_tr_tr']\n",
    "            eigs_tr = data_c['eigs_tr']\n",
    "            vecs_tr = data_c['vecs_tr']\n",
    "\n",
    "            weight_sq = (vecs_tr.T @ y_train / len(y_train))**2\n",
    "            cum_power = np.cumsum(weight_sq)/np.sum(weight_sq)\n",
    "\n",
    "            largest_eig_c += [eigs_tr[0]]\n",
    "\n",
    "            color = next(colors)\n",
    "            axs[0].loglog(rho, eigs_tr, label=f'$n = {dim}$', linewidth=3, color=color)\n",
    "            axs[0].set_xlabel('$k$', fontsize=axis_label_fontsize)\n",
    "            axs[0].set_ylabel('$\\eta_k$', fontsize=axis_label_fontsize)\n",
    "            axs[1].semilogx(rho, cum_power, '-',  linewidth=3, color=color)\n",
    "            axs[1].set_xlabel('$k$', fontsize=axis_label_fontsize)\n",
    "            axs[1].set_ylabel('$C(k)$', fontsize=axis_label_fontsize)\n",
    "\n",
    "    largest_eig += [np.array(largest_eig_c)]\n",
    "largest_eig = np.array(largest_eig).T\n",
    "        \n",
    "# axs[0].legend(loc='best', fontsize=axis_legend_fontsize)\n",
    "axs[0].tick_params(axis='both', which='major', labelsize=axis_tick_fontsize)\n",
    "axs[1].tick_params(axis='both', which='major', labelsize=axis_tick_fontsize)\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e861de9a",
   "metadata": {},
   "source": [
    "Note the eigenvalues scale exponentially with $n$: $\\eta_{max}(n) = a  s^{-n}$. We devide each with $\\eta_{max}(4)$, which is the element dim_list_plot[0].\n",
    "\n",
    "Hence in the fit we use $\\tilde \\eta_{max}(n) = s^{-n+ 4}$.\n",
    "\n",
    "We fit an exponential on $\\log \\tilde\\eta_{max}(n) =  (-n + 4) \\log s$. Therefore:\n",
    "\n",
    "$$\n",
    "\\log s = params[1]/4 \\; \\text{and} \\; \\log s = -params[0]\n",
    "$$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "52c124e4",
   "metadata": {},
   "outputs": [],
   "source": [
    "largest_eig = largest_eig/np.max(largest_eig, axis=1, keepdims=True)\n",
    "\n",
    "\n",
    "dim_list_fit = np.logspace(.5, 7.0, 300000)\n",
    "\n",
    "# dim_list_fit = dim_list_plot\n",
    "s_list = []\n",
    "largest_eig_fit = []\n",
    "largest_eig_fit_anlyt = []\n",
    "largest_eig_fit_params = []\n",
    "for eig in largest_eig:\n",
    "    x = (dim_list_plot)\n",
    "    y = np.log(eig)\n",
    "    \n",
    "    params = np.polyfit(x, y, 1)\n",
    "    eta_max_fit = np.poly1d(params)((dim_list_fit))\n",
    "    \n",
    "    s1 = np.exp(params[1]/4)\n",
    "    s2 = np.exp(-params[0])\n",
    "    \n",
    "    s_list += [s1]\n",
    "    largest_eig_fit_anlyt += [s1**(-dim_list_fit+4)]\n",
    "    \n",
    "    \n",
    "    largest_eig_fit += [np.exp(eta_max_fit)]\n",
    "    largest_eig_fit_params += [params]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cf6078ff",
   "metadata": {},
   "outputs": [],
   "source": [
    "largest_eig = largest_eig/np.max(largest_eig, axis=1, keepdims=True)\n",
    "\n",
    "\n",
    "dim_list_fit = np.logspace(.5, 5.6, 300000)\n",
    "\n",
    "# dim_list_fit = dim_list_plot\n",
    "largest_eig_fit = []\n",
    "largest_eig_fit_params = []\n",
    "for eig in largest_eig:\n",
    "    x = (dim_list_plot)\n",
    "    y = np.log(eig)\n",
    "    \n",
    "    params = np.polyfit(x, y, 1)\n",
    "    eta_max_fit = np.poly1d(params)((dim_list_fit))\n",
    "    \n",
    "    largest_eig_fit += [np.exp(eta_max_fit)]\n",
    "    largest_eig_fit_params += [params]\n",
    "\n",
    "\n",
    "colors = iter(sns.color_palette('viridis', len(c_list_plot)))\n",
    "plt.figure(figsize=(7,5))\n",
    "\n",
    "for c, eta_max, eta_max_fit in zip(c_list_plot, largest_eig, largest_eig_fit):\n",
    "\n",
    "    color = next(colors)\n",
    "    plt.loglog(dim_list_plot, eta_max, 'o', linewidth=3, label=f'$c = {c}$', color=color)\n",
    "    plt.loglog(dim_list_fit, eta_max_fit, '--', linewidth=3, color=color)\n",
    "\n",
    "plt.title('Largest Eigenvalue vs. $n$ (# Qubits)')\n",
    "plt.xlabel('$n$', fontsize=axis_label_fontsize)\n",
    "plt.ylabel('$\\\\tilde\\eta_{max}(n)$',fontsize=axis_label_fontsize-7)\n",
    "plt.gca().tick_params(axis='both', which='major', labelsize=axis_tick_fontsize)\n",
    "\n",
    "h_line = 9.9e-1\n",
    "plt.title('$\\eta_0 = %.2f$'%h_line, fontsize=25)\n",
    "\n",
    "plt.axhline(h_line, linestyle='--', color='k', linewidth=2)\n",
    "\n",
    "plt.ylim([h_line - .1 if h_line != .1 else h_line -.01, 1.03])\n",
    "\n",
    "plt.legend(loc='upper right', fontsize=axis_legend_fontsize-6)   \n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2c81dcb4",
   "metadata": {},
   "outputs": [],
   "source": [
    "colors = iter(sns.color_palette('viridis', len(c_list_plot)))\n",
    "plt.figure(figsize=(7,5))\n",
    "\n",
    "opt_dim = []\n",
    "for i, c in enumerate(c_list_plot):\n",
    "\n",
    "    plt.loglog(dim_list_fit, abs(largest_eig_fit[i] - h_line), linewidth=3, label=f'c={c}', color=next(colors))\n",
    "    opt_dim += [dim_list_fit[np.argmin(abs(largest_eig_fit[i] - h_line))]]\n",
    "    \n",
    "plt.legend(fontsize=axis_legend_fontsize-7, loc='upper center') \n",
    "plt.xlabel('$n$', fontsize=axis_label_fontsize-7)\n",
    "# plt.ylabel('$\\eta_{max}(n)/\\eta_{max}(n_0)$',fontsize=axis_label_fontsize-14)\n",
    "plt.ylabel('$\\\\tilde\\eta_{max}(n) - \\eta_0$', fontsize=axis_label_fontsize-7)\n",
    "plt.title('$\\eta_0 = %.2f$'%h_line, fontsize=25)\n",
    "plt.tight_layout()\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8613e90f",
   "metadata": {},
   "outputs": [],
   "source": [
    "fit_params = np.polyfit(np.log(opt_dim), np.log(c_list_plot), 1)\n",
    "fits = np.exp(np.poly1d(fit_params)(np.log(opt_dim)))\n",
    "\n",
    "plt.figure(figsize=(7,5))\n",
    "\n",
    "plt.loglog(opt_dim, fits, '--', linewidth=3, color = 'k', \n",
    "           label='($\\eta_0 = %.2f$) $c^* \\propto n^{%.3f}$'%(h_line, fit_params[0]))\n",
    "plt.loglog(opt_dim, c_list_plot, 'o', linewidth=3, markersize=12)\n",
    "plt.xlabel('$n$', fontsize=axis_label_fontsize-7)\n",
    "plt.ylabel('$c^*$',fontsize=axis_label_fontsize-7)\n",
    "plt.legend(fontsize=axis_legend_fontsize-5)\n",
    "plt.tight_layout()\n",
    "\n",
    "print(fit_params)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1b1517db",
   "metadata": {},
   "source": [
    "## EVO Scaling of the Largest Eigenvalue with Number of Qubits"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0e3098c2",
   "metadata": {},
   "outputs": [],
   "source": [
    "model = 'EVO'\n",
    "ds = 'fmnist'\n",
    "\n",
    "## Get the dataset\n",
    "x_train, x_test, y_train, y_test = get_dataset(ds, dataset_dim, 800, 200)\n",
    "\n",
    "## Get the model data\n",
    "model_ds = model + '_' + ds\n",
    "num_qubits = dataset_dim if model == 'IQP' else dataset_dim+1\n",
    "\n",
    "data = data_table[model_ds]\n",
    "dim_list = data_table[model_ds]['num_qubits_list']\n",
    "c_list = test_table[model_ds][num_qubits]['c_list']\n",
    "\n",
    "dim_list_plot = dim_list[:] ## Exclude dim = 23\n",
    "c_list_plot = [0.001, 0.01, 0.05, 0.1]#, 0.5]#, 1.e+0, 5.e+0]\n",
    "\n",
    "\n",
    "fig, axs = plt.subplots(1,2, figsize=(14,5))\n",
    "colors = iter(sns.color_palette('viridis', len(dim_list_plot)))\n",
    "\n",
    "largest_eig = []\n",
    "for dim, data_dim in zip(dim_list_plot, list(data_table[model_ds].values())[1:]):\n",
    "\n",
    "    colors = iter(sns.color_palette('viridis', len(data_dim)))\n",
    "    largest_eig_c = []\n",
    "    for data_c in data_dim:\n",
    "\n",
    "        c = data_c['c']\n",
    "\n",
    "        if c in c_list_plot:\n",
    "\n",
    "            K_tr = data_c['K_tr_tr']\n",
    "            eigs_tr = data_c['eigs_tr']\n",
    "            vecs_tr = data_c['vecs_tr']\n",
    "\n",
    "            weight_sq = (vecs_tr.T @ y_train / len(y_train))**2\n",
    "            cum_power = np.cumsum(weight_sq)/np.sum(weight_sq)\n",
    "\n",
    "            largest_eig_c += [eigs_tr[0]]\n",
    "\n",
    "            color = next(colors)\n",
    "            axs[0].loglog(rho, eigs_tr, label=f'$n = {dim}$', linewidth=3, color=color)\n",
    "            axs[0].set_xlabel('$k$', fontsize=axis_label_fontsize)\n",
    "            axs[0].set_ylabel('$\\eta_k$', fontsize=axis_label_fontsize)\n",
    "            axs[1].semilogx(rho, cum_power, '-',  linewidth=3, color=color)\n",
    "            axs[1].set_xlabel('$k$', fontsize=axis_label_fontsize)\n",
    "            axs[1].set_ylabel('$C(k)$', fontsize=axis_label_fontsize)\n",
    "\n",
    "    largest_eig += [np.array(largest_eig_c)]\n",
    "largest_eig = np.array(largest_eig).T\n",
    "        \n",
    "# axs[0].legend(loc='best', fontsize=axis_legend_fontsize)\n",
    "axs[0].tick_params(axis='both', which='major', labelsize=axis_tick_fontsize)\n",
    "axs[1].tick_params(axis='both', which='major', labelsize=axis_tick_fontsize)\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6fd0afb5",
   "metadata": {},
   "outputs": [],
   "source": [
    "largest_eig = largest_eig/np.max(largest_eig, axis=1, keepdims=True)\n",
    "\n",
    "\n",
    "dim_list_fit = np.logspace(.5, 7.0, 300000)\n",
    "\n",
    "# dim_list_fit = dim_list_plot\n",
    "s_list = []\n",
    "largest_eig_fit = []\n",
    "largest_eig_fit_anlyt = []\n",
    "largest_eig_fit_params = []\n",
    "for eig in largest_eig:\n",
    "    x = (dim_list_plot)\n",
    "    y = np.log(eig)\n",
    "    \n",
    "    params = np.polyfit(x, y, 1)\n",
    "    eta_max_fit = np.poly1d(params)((dim_list_fit))\n",
    "    \n",
    "    s1 = np.exp(params[1]/4)\n",
    "    s2 = np.exp(-params[0])\n",
    "    \n",
    "    s_list += [s1]\n",
    "    \n",
    "    largest_eig_fit_anlyt += [s1**(-dim_list_fit+4)]\n",
    "    \n",
    "    \n",
    "    largest_eig_fit += [np.exp(eta_max_fit)]\n",
    "    largest_eig_fit_params += [params]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d7581d46",
   "metadata": {},
   "outputs": [],
   "source": [
    "largest_eig = largest_eig/np.max(largest_eig, axis=1, keepdims=True)\n",
    "\n",
    "\n",
    "dim_list_fit = np.logspace(.5, 5.6, 300000)\n",
    "\n",
    "# dim_list_fit = dim_list_plot\n",
    "largest_eig_fit = []\n",
    "largest_eig_fit_params = []\n",
    "for eig in largest_eig:\n",
    "    x = (dim_list_plot)\n",
    "    y = np.log(eig)\n",
    "    \n",
    "    params = np.polyfit(x, y, 1)\n",
    "    eta_max_fit = np.poly1d(params)((dim_list_fit))\n",
    "    \n",
    "    largest_eig_fit += [np.exp(eta_max_fit)]\n",
    "    largest_eig_fit_params += [params]\n",
    "\n",
    "\n",
    "colors = iter(sns.color_palette('viridis', len(c_list_plot)))\n",
    "plt.figure(figsize=(7,5))\n",
    "\n",
    "for c, eta_max, eta_max_fit in zip(c_list_plot, largest_eig, largest_eig_fit):\n",
    "\n",
    "    color = next(colors)\n",
    "    plt.loglog(dim_list_plot, eta_max, 'o', linewidth=3, label=f'$c = {c}$', color=color)\n",
    "    plt.loglog(dim_list_fit, eta_max_fit, '--', linewidth=3, color=color)\n",
    "\n",
    "plt.title('Largest Eigenvalue vs. $n$ (# Qubits)')\n",
    "plt.xlabel('$n$', fontsize=axis_label_fontsize)\n",
    "plt.ylabel('$\\\\tilde\\eta_{max}(n)$',fontsize=axis_label_fontsize-7)\n",
    "plt.gca().tick_params(axis='both', which='major', labelsize=axis_tick_fontsize)\n",
    "\n",
    "h_line = 9e-1\n",
    "plt.title('$\\eta_0 = %.2f$'%h_line, fontsize=25)\n",
    "\n",
    "plt.axhline(h_line, linestyle='--', color='k', linewidth=2)\n",
    "\n",
    "plt.ylim([h_line - .1 if h_line != .1 else h_line -.01, 1.03])\n",
    "\n",
    "plt.legend(loc='upper right', fontsize=axis_legend_fontsize-6)   \n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7995d89b",
   "metadata": {},
   "outputs": [],
   "source": [
    "colors = iter(sns.color_palette('viridis', len(c_list_plot)))\n",
    "plt.figure(figsize=(7,5))\n",
    "\n",
    "opt_dim = []\n",
    "for i, c in enumerate(c_list_plot):\n",
    "\n",
    "    plt.loglog(dim_list_fit, abs(largest_eig_fit[i] - h_line), linewidth=3, label=f'c={c}', color=next(colors))\n",
    "    opt_dim += [dim_list_fit[np.argmin(abs(largest_eig_fit[i] - h_line))]]\n",
    "    \n",
    "plt.legend(fontsize=axis_legend_fontsize-7, loc='upper center') \n",
    "plt.xlabel('$n$', fontsize=axis_label_fontsize-7)\n",
    "# plt.ylabel('$\\eta_{max}(n)/\\eta_{max}(n_0)$',fontsize=axis_label_fontsize-14)\n",
    "plt.ylabel('$\\\\tilde\\eta_{max}(n) - \\eta_0$', fontsize=axis_label_fontsize-7)\n",
    "plt.title('$\\eta_0 = %.2f$'%h_line, fontsize=25)\n",
    "plt.tight_layout()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6765e280",
   "metadata": {},
   "outputs": [],
   "source": [
    "fit_params = np.polyfit(np.log(opt_dim), np.log(c_list_plot), 1)\n",
    "fits = np.exp(np.poly1d(fit_params)(np.log(opt_dim)))\n",
    "\n",
    "plt.figure(figsize=(7,5))\n",
    "\n",
    "plt.loglog(opt_dim, fits, '--', linewidth=3, color = 'k', \n",
    "           label='($\\eta_0 = %.2f$) $c^* \\propto n^{%.3f}$'%(h_line, fit_params[0]))\n",
    "plt.loglog(opt_dim, c_list_plot, 'o', linewidth=3, markersize=12)\n",
    "plt.xlabel('$n$', fontsize=axis_label_fontsize-7)\n",
    "plt.ylabel('$c^*$',fontsize=axis_label_fontsize-7)\n",
    "plt.legend(fontsize=axis_legend_fontsize-5)\n",
    "plt.tight_layout()\n",
    "\n",
    "print(fit_params)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d1f8b314",
   "metadata": {},
   "source": [
    "## SVM with Cross-Validation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "96eef3ec",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.svm import SVC\n",
    "from sklearn.model_selection import cross_val_score\n",
    "\n",
    "## Get the dataset\n",
    "dataset_dim = 22\n",
    "\n",
    "svm_results = {}\n",
    "for j, ds in enumerate(ds_list):\n",
    "    x_train, x_test, y_train, y_test = get_dataset(ds, dataset_dim, 800, 200)\n",
    "    \n",
    "    for i, model in enumerate(model_list):\n",
    "        \n",
    "        model_ds = model + '_' + ds\n",
    "        num_qubits = dataset_dim if model == 'IQP' else dataset_dim+1\n",
    "        \n",
    "        data = data_table[model_ds][num_qubits]\n",
    "        c_list = test_table[model_ds][num_qubits]['c_list']\n",
    "        \n",
    "        svm_results[model_ds] = {'c_list': c_list}\n",
    "        \n",
    "        cv_scores = []\n",
    "        test_scores = []\n",
    "        eigs = []\n",
    "        weights = []\n",
    "        cum_powers = []\n",
    "        for c, data_c in zip(c_list, data):\n",
    "            assert c == data_c['c']\n",
    "            \n",
    "            K_tr_tr = data_c['K_tr_tr']\n",
    "            K_tr_test = data_c['K_tr_test']\n",
    "            eigs_tr = data_c['eigs_tr']\n",
    "            vecs_tr = data_c['vecs_tr']\n",
    "            \n",
    "            weight_sq = (vecs_tr.T @ y_train / len(y_train))**2\n",
    "            cum_power = np.cumsum(weight_sq)/np.sum(weight_sq)\n",
    "            \n",
    "            qsvc = SVC(kernel='precomputed', random_state=41)\n",
    "            qsvc.fit(K_tr_tr, y_train)\n",
    "            \n",
    "            cv_score_2 = (cross_val_score(qsvc, K_tr_tr, y_train, cv=2)).mean()\n",
    "            cv_score_5 = (cross_val_score(qsvc, K_tr_tr, y_train, cv=5)).mean()\n",
    "            cv_score_10 = (cross_val_score(qsvc, K_tr_tr, y_train, cv=10)).mean()\n",
    "            \n",
    "            cv_scores += [[cv_score_2, cv_score_5, cv_score_10]]\n",
    "            test_scores += [qsvc.score(K_tr_test, y_test)]\n",
    "            eigs += [eigs_tr]\n",
    "            weights += [weight_sq]\n",
    "            cum_powers += [cum_power]\n",
    "        \n",
    "        cv_scores = np.array(cv_scores).T\n",
    "        eigs = np.array(eigs)\n",
    "        weights = np.array(weights)\n",
    "        cum_powers = np.array(cum_powers)\n",
    "        \n",
    "        svm_results[model_ds]['cv_scores'] = cv_scores\n",
    "        svm_results[model_ds]['test_scores'] = test_scores\n",
    "        svm_results[model_ds]['eigs'] = eigs\n",
    "        svm_results[model_ds]['weights'] = weights\n",
    "        svm_results[model_ds]['cum_powers'] = cum_powers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1ab19b15",
   "metadata": {},
   "outputs": [],
   "source": [
    "c_list_plot = [0.01, 0.05, 0.1, 0.5, 1.]\n",
    "\n",
    "fig, axs = plt.subplots(len(model_list), len(ds_list), figsize=(19,9))\n",
    "for j, ds in enumerate(ds_list): \n",
    "    for i, model in enumerate(model_list):\n",
    "        model_ds = model + '_' + ds\n",
    "        \n",
    "        (c_list, cv_scores, test_scores,\n",
    "         eigs, weights, cum_powers) = svm_results[model_ds].values()\n",
    "        \n",
    "        axs[i,j].loglog(c_list, cv_scores[0], linewidth=3, label='2-fold-cv')\n",
    "        axs[i,j].loglog(c_list, cv_scores[1], linewidth=3, label='5-fold-cv')\n",
    "        axs[i,j].loglog(c_list, cv_scores[2], linewidth=3, label='10-fold-cv')\n",
    "        axs[i,j].loglog(c_list, test_scores, linewidth=3, label='test_scores')\n",
    "        axs[i,j].set_xlabel('Bandwidth $c$', fontsize=25)\n",
    "        axs[i,j].set_ylabel('Test Scores', fontsize=25)\n",
    "        axs[i,j].set_title(f'{model_ds}')\n",
    "        \n",
    "        axs[i,j].legend(loc='lower left')\n",
    "        \n",
    "plt.tight_layout()\n",
    "plt.show()\n",
    "\n",
    "\n",
    "fig, axs = plt.subplots(len(model_list), len(ds_list), figsize=(19,9))\n",
    "for j, ds in enumerate(ds_list): \n",
    "    for i, model in enumerate(model_list):\n",
    "        model_ds = model + '_' + ds\n",
    "        \n",
    "        (c_list, cv_scores, test_scores,\n",
    "         eigs, weights, cum_powers) = svm_results[model_ds].values()\n",
    "        \n",
    "        for c, eig in zip(c_list, eigs):\n",
    "            if c in c_list_plot:\n",
    "                axs[i,j].loglog(rho, eig, linewidth=3, label=f'$c={c}$')\n",
    "                axs[i,j].set_xlabel('$k$', fontsize=25)\n",
    "                axs[i,j].set_ylabel('Eigenvalue $\\eta_k$', fontsize=25)\n",
    "                axs[i,j].set_title(f'{model_ds}')\n",
    "        \n",
    "        axs[i,j].legend(loc='lower left')\n",
    "            \n",
    "plt.tight_layout()\n",
    "plt.show()\n",
    "\n",
    "fig, axs = plt.subplots(len(model_list), len(ds_list), figsize=(19,9))\n",
    "for j, ds in enumerate(ds_list): \n",
    "    for i, model in enumerate(model_list):\n",
    "        model_ds = model + '_' + ds\n",
    "        \n",
    "        (c_list, cv_scores, test_scores,\n",
    "         eigs, weights, cum_powers) = svm_results[model_ds].values()\n",
    "        \n",
    "        for c, cum_power in zip(c_list, cum_powers):\n",
    "            if c in c_list_plot:\n",
    "                axs[i,j].semilogx(rho, cum_power, linewidth=3, label=f'$c={c}$')\n",
    "                axs[i,j].set_xlabel('$k$', fontsize=25)\n",
    "                axs[i,j].set_ylabel('$C(k)$', fontsize=25)\n",
    "                axs[i,j].set_title(f'{model_ds}')\n",
    "        \n",
    "        axs[i,j].legend(loc='lower right')\n",
    "            \n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7a013e34-6438-4094-aaa1-f0f7da06e676",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ca0c43f8-f6c6-4ba6-8696-21ac23eef92e",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "95ddcfcf-5d85-43c5-9219-0ff4daacf46d",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python (ffcv)",
   "language": "python",
   "name": "ffcv"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
