{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.gridspec as gridspec\n",
    "import matplotlib.pyplot as plt\n",
    "from matplotlib import rc\n",
    "import jax.numpy as jnp\n",
    "import matplotlib as mpl\n",
    "\n",
    "import sklearn\n",
    "from sklearn import svm\n",
    "from sklearn.neighbors import KNeighborsClassifier\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.metrics import accuracy_score\n",
    "\n",
    "import seaborn as sns\n",
    "import itertools\n",
    "\n",
    "# Plain-text (non-LaTeX) rendering and a base font size for all figures.\n",
    "plt.rcParams.update({'text.usetex': False, 'font.size': 18})\n",
    "# Font sizes for axis labels, tick labels and legends in the figures below.\n",
    "axis_label_fontsize, axis_tick_fontsize, axis_legend_fontsize = 26, 22, 24"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Figure 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def quantum_kernel_matrix(X, Xp=None, c=1):\n",
    "    \n",
    "    if Xp is not None:\n",
    "        diff = X[:,np.newaxis] - Xp[np.newaxis,:]\n",
    "    else:\n",
    "        diff = X[:,np.newaxis] - X[np.newaxis,:]\n",
    "    \n",
    "    K = 1\n",
    "    for i in range(diff.shape[-1]):\n",
    "        K *= np.cos(c*diff[:,:,i]/2.)**2\n",
    "    \n",
    "    return K \n",
    "\n",
    "    \n",
    "def kt_alignment(K, y):\n",
    "    y = y.reshape(len(y), 1)\n",
    "    K_yy = y @ y.T\n",
    "    \n",
    "    return np.trace(K@K_yy)/ np.sqrt(np.trace(K@K)*np.trace(K_yy@K_yy))\n",
    "\n",
    "\n",
    "# Plot kernel cross-sections k(x, x_ref) along the diagonal sweep\n",
    "# x(t) = -pi/2 + pi * t * (1, ..., 1), t in [0, 1], for two bandwidths c.\n",
    "plt.close('all')\n",
    "n_qubits = 100\n",
    "N = 1000\n",
    "delta = np.pi * np.ones(n_qubits)  # this gives a much wider picture of nonlocalness of kernel\n",
    "xspread = np.array([-np.pi/2 + delta * t for t in np.linspace(0, 1, N)])  # shape (N, n_qubits)\n",
    "# Every coordinate is equal along the sweep, so one column serves as the plot's x-axis\n",
    "# (passing the 2-D xspread to plot() drew n_qubits overlapping copies of each curve).\n",
    "x_axis = xspread[:, 0]\n",
    "i_ref = 499  # reference point near the middle of the sweep (x close to 0)\n",
    "\n",
    "fig, axes = plt.subplots(2, 1, figsize=(4, 6), constrained_layout=True, sharex=True, sharey=True)\n",
    "cvals = [1, 0.25]\n",
    "colors = ['orangered', 'forestgreen']\n",
    "for ax, c, color in zip(axes, cvals, colors):\n",
    "    # Only row i_ref of the Gram matrix is plotted, so compute just that row.\n",
    "    K = quantum_kernel_matrix(xspread[i_ref:i_ref + 1], xspread, c=c)[0]\n",
    "    ax.plot(x_axis, K, c=color, lw=2)\n",
    "    ax.fill_between(x_axis, np.zeros_like(x_axis), K, color=color, alpha=0.1)\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Experiment with $\\bar{f}(x) = \\cos(c\\, x^{(50)})$ on $x \\in [0, \\pi]^{n-1} \\times \\{0, \\pi\\}$ (only the last coordinate is binary)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def label(x, c):\n",
    "    \"\"\"Function to assign a label to any member of the dataset.\"\"\"\n",
    "    return np.cos(c * x[:, -1])\n",
    "\n",
    "n_qubits = 50\n",
    "n_data = 1000\n",
    "test_fraction = 0.25\n",
    "lam = 0  # ridge regularizer; with lam = 0 the fit interpolates the training data\n",
    "cvals = [1, 0.25]\n",
    "rng = np.random.default_rng(0)  # seeded so the figure is reproducible\n",
    "\n",
    "# First n_qubits - 1 features are uniform in [0, pi]; the last one is binary in {0, pi}.\n",
    "X = np.column_stack((rng.uniform(0, np.pi, size=(n_data, n_qubits - 1)),\n",
    "                     rng.choice([0, np.pi], n_data)))\n",
    "y = label(X, 1)\n",
    "\n",
    "# The target depends only on the binary last feature, so it takes values at x^(50) in {0, pi}.\n",
    "x_continuous = np.array([0, np.pi])\n",
    "y_continuous = label(x_continuous[:, np.newaxis], 1)\n",
    "\n",
    "# One fixed train/test split shared by every bandwidth c.\n",
    "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=test_fraction, random_state=123)\n",
    "n_train = len(X_train)\n",
    "\n",
    "all_eigs = []\n",
    "all_weights = []\n",
    "all_cum_power = []\n",
    "all_kt = []\n",
    "\n",
    "fig, axes = plt.subplots(2, 1, figsize=(4, 6), constrained_layout=True, sharex=True, sharey=True)\n",
    "colors = ['orangered', 'forestgreen']\n",
    "for ax, c, color in zip(axes, cvals, colors):\n",
    "    K_train = quantum_kernel_matrix(X_train, c=c)\n",
    "    K_test = quantum_kernel_matrix(X_train, X_test, c=c)\n",
    "\n",
    "    # Kernel ridge regression: this is a regression problem, predictions are floats.\n",
    "    # Solving the linear system in float64 is more stable than an explicit inverse.\n",
    "    alpha = np.linalg.solve(K_train + lam * np.eye(n_train), y_train)\n",
    "    test_preds = alpha @ K_test\n",
    "\n",
    "    # Mean squared error over the test points (dividing by n_data made it 4x too small).\n",
    "    err = np.mean((test_preds - y_test) ** 2)\n",
    "    print(f\"dataset err: {err}\")\n",
    "\n",
    "    ax.scatter(X_test[:, -1], test_preds, c=color, s=40, alpha=1, label=r'$f^*(x)$')\n",
    "    ax.plot(x_continuous, y_continuous, marker='o', linestyle='dashed', c='k', lw=1, label=r'$\\bar f(x)$')\n",
    "    ax.set_xlim(-.1, np.pi + .1)\n",
    "    ax.set_xticks([0, np.pi])\n",
    "    ax.set_xticklabels(['0', r'$\\pi$'], fontsize=axis_tick_fontsize)\n",
    "    ax.legend()\n",
    "\n",
    "    # Spectrum of K_train / n_train (descending) and the target's power in each eigenmode;\n",
    "    # eigenvectors are rescaled so that vecs.T @ vecs = n_train * I.\n",
    "    eigs, vecs = np.linalg.eigh(K_train / n_train)\n",
    "    eigs = eigs[::-1]\n",
    "    vecs = np.sqrt(n_train) * vecs[:, ::-1]\n",
    "    weight_sq = (vecs.T @ y_train / n_train) ** 2\n",
    "    cum_power = np.cumsum(weight_sq) / np.sum(weight_sq)\n",
    "\n",
    "    all_eigs.append(eigs)\n",
    "    all_weights.append(weight_sq)\n",
    "    all_cum_power.append(cum_power)\n",
    "    all_kt.append(kt_alignment(K_train, y_train))\n",
    "\n",
    "axes[-1].set_xlabel('$x^{(50)}$', fontsize=axis_label_fontsize)\n",
    "# No tight_layout(): the figure already uses constrained_layout.\n",
    "plt.show()\n",
    "\n",
    "all_eigs = np.array(all_eigs)\n",
    "all_weights = np.array(all_weights)\n",
    "all_cum_power = np.array(all_cum_power)\n",
    "all_kt = np.array(all_kt)\n",
    "\n",
    "rho = np.arange(1, all_eigs.shape[1] + 1)\n",
    "\n",
    "fig, axs = plt.subplots(1, 2, figsize=(13, 5))\n",
    "palette = sns.color_palette('viridis', len(cvals))\n",
    "for c, eig, cum_power, color in zip(cvals, all_eigs, all_cum_power, palette):\n",
    "    axs[0].loglog(rho, eig, label=f'c = ${c}$', linewidth=3, color=color)\n",
    "    axs[1].plot(rho, cum_power, label=f'c = ${c}$', linewidth=3, color=color)\n",
    "\n",
    "axs[0].set_xlabel('$k$', fontsize=axis_label_fontsize)\n",
    "axs[0].set_ylabel(r'$\\eta_k$', fontsize=axis_label_fontsize)\n",
    "axs[1].set_xlabel('$k$', fontsize=axis_label_fontsize)\n",
    "axs[1].set_ylabel('$C(k)$', fontsize=axis_label_fontsize)\n",
    "axs[0].legend()\n",
    "axs[1].legend()\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Figure 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# %pip (not !pip) installs into the environment of the running kernel.\n",
    "%pip install -q git+https://github.com/Pehlevan-Group/kernel-generalization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import time\n",
    "from functools import partial\n",
    "\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from matplotlib import rc\n",
    "import seaborn as sns\n",
    "\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "from jax import jit, vmap, config\n",
    "\n",
    "from kernel_generalization import kernel_simulation as ker_sim\n",
    "\n",
    "# Section-specific plot styling (overrides the sizes set in the first cell).\n",
    "plt.rcParams.update({'text.usetex': False, 'font.size': 16})\n",
    "axis_label_fontsize, axis_tick_fontsize, axis_legend_fontsize = 36, 22, 22\n",
    "\n",
    "# Index into jax.devices() used by every device_put below; enable float64 in JAX.\n",
    "gpu_id = 0\n",
    "config.update(\"jax_enable_x64\", True)\n",
    "\n",
    "def quantum_kernel(X, X_p=None, sigma=1):\n",
    "    \"\"\"Gram matrix K[a, b] = prod_i cos(sigma * (X[a, i] - X_p[b, i]) / 2) ** 2.\n",
    "\n",
    "    Built one feature at a time, so only (n, m) arrays are allocated.\n",
    "\n",
    "    Args:\n",
    "        X: array of shape (n, d).\n",
    "        X_p: optional array of shape (m, d); defaults to X.\n",
    "        sigma: bandwidth multiplying the coordinate differences.\n",
    "\n",
    "    Returns:\n",
    "        (n, m) array placed on jax.devices()[gpu_id].\n",
    "    \"\"\"\n",
    "    device = jax.devices()[gpu_id]\n",
    "    # `is None`, not `== None`: for an array X_p, `==` compares elementwise (or raises),\n",
    "    # so the `if` could not be evaluated and passing X_p used to crash.\n",
    "    if X_p is None:\n",
    "        X_p = X\n",
    "\n",
    "    K = 1\n",
    "    for x, x_p in zip(X.T, X_p.T):\n",
    "        diff = jax.device_put(x[:, jnp.newaxis] - x_p[jnp.newaxis, :], device)\n",
    "        K = K * jnp.cos(sigma * diff / 2.) ** 2\n",
    "\n",
    "    return jax.device_put(K, device)\n",
    "\n",
    "@jit\n",
    "def kernel_spectra_eigvalsh(K, y):\n",
    "    \"\"\"Eigenvalues of K / P in descending order, where P = len(K).\n",
    "\n",
    "    `y` is unused (apparently kept to mirror kernel_spectra's signature); the second\n",
    "    returned value is a 0 placeholder where kernel_spectra returns eigenvectors.\n",
    "    \"\"\"\n",
    "    n_points = len(K)\n",
    "    ascending = jax.device_put(jnp.linalg.eigvalsh(K / n_points), jax.devices()[gpu_id])\n",
    "    return ascending[::-1], 0\n",
    "\n",
    "@jit\n",
    "def kernel_spectra(K, y):\n",
    "    \"\"\"Eigen-decomposition of K / P (P = len(K)), sorted by decreasing eigenvalue.\n",
    "\n",
    "    Returns:\n",
    "        eigs: eigenvalues in descending order.\n",
    "        vecs: matching eigenvectors (columns) scaled by sqrt(P), so vecs.T @ vecs = P * I.\n",
    "\n",
    "    `y` is unused.\n",
    "    \"\"\"\n",
    "    n_points = len(K)\n",
    "    device = jax.devices()[gpu_id]\n",
    "    vals_asc, vecs_asc = jax.device_put(jnp.linalg.eigh(K / n_points), device)\n",
    "    eigs = jax.device_put(vals_asc[::-1], device)\n",
    "    vecs = jax.device_put(vecs_asc[:, ::-1] * np.sqrt(n_points), device)\n",
    "    return eigs, vecs\n",
    "\n",
    "\n",
    "def degens(D, k):\n",
    "    \n",
    "    if k == 1:\n",
    "        degen = D\n",
    "    elif k == 2:\n",
    "        degen = D*(D+1)/2\n",
    "    elif k == 3:\n",
    "        degen = D*(D-1)*(D-2)/6 + D*(D-1)\n",
    "    else:\n",
    "        raise Exception('Degeneracy not implemented - Just multinomial coefficients')\n",
    "        \n",
    "    return int(degen)\n",
    "\n",
    "def compute_gen_err(pvals, reg, eigs, weights):\n",
    "    \"\"\"Theoretical noise-free generalization error for each training-set size in pvals.\n",
    "\n",
    "    With kappa_p from ker_sim.solve_kappa and gamma_p = sum(ker_sim.gamma_fn(p, kappa_p, eigs)):\n",
    "        E_g(p) = kappa_p**2 / (1 - gamma_p) * sum_k weights_k / (p * eigs_k + kappa_p)**2\n",
    "    \"\"\"\n",
    "    eigs = np.array(eigs)\n",
    "    weights = np.array(weights)\n",
    "\n",
    "    kappa_vals = ker_sim.solve_kappa(pvals, reg, eigs)\n",
    "    gamma_vals = np.array([np.sum(ker_sim.gamma_fn(p, kappa, eigs))\n",
    "                           for p, kappa in zip(pvals, kappa_vals)])\n",
    "    prefactor = kappa_vals ** 2 / (1 - gamma_vals)\n",
    "\n",
    "    terms = (pref * np.sum(weights / (p * eigs + kappa) ** 2)\n",
    "             for p, kappa, pref in zip(pvals, kappa_vals, prefactor))\n",
    "    return np.fromiter(terms, dtype=float, count=len(np.array(pvals)))\n",
    "\n",
    "def perform_ker_regression(pvals, reg, K, y):\n",
    "    \"\"\"Empirical generalization error of kernel ridge regression vs. training-set size.\n",
    "\n",
    "    For each p in pvals, fits on the first p points and reports the mean squared error\n",
    "    on the remaining len(K) - p held-out points. (Scoring on all points, including the\n",
    "    p training points that a (nearly) interpolating fit reproduces, biased the error low.)\n",
    "\n",
    "    Args:\n",
    "        pvals: training-set sizes, each in [1, len(K) - 1].\n",
    "        reg: ridge regularizer added to the diagonal of the training Gram matrix.\n",
    "        K: (P, P) Gram matrix over all points.\n",
    "        y: length-P targets.\n",
    "\n",
    "    Returns:\n",
    "        np.ndarray with one held-out MSE per entry of pvals.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if some p leaves no training or no held-out points.\n",
    "    \"\"\"\n",
    "    device = jax.devices()[gpu_id]\n",
    "    K = np.asarray(K)  # convert once instead of copying the P x P matrix for every p\n",
    "    y = np.asarray(y)\n",
    "    n_total = len(K)\n",
    "\n",
    "    errs = []\n",
    "    for p in pvals:\n",
    "        p = int(p)\n",
    "        if not 0 < p < n_total:\n",
    "            raise ValueError(f'p must be in [1, {n_total - 1}], got {p}')\n",
    "        K_tr = jax.device_put(K[:p, :p] + reg * np.eye(p), device)\n",
    "        K_held = jax.device_put(K[p:, :p], device)\n",
    "        y_tr = jax.device_put(y[:p], device)\n",
    "        y_held = jax.device_put(y[p:], device)\n",
    "\n",
    "        # Solve the system rather than forming an explicit inverse.\n",
    "        alpha = jnp.linalg.solve(K_tr, y_tr)\n",
    "        y_pred = K_held @ alpha\n",
    "        errs.append(float(jnp.mean((y_pred - y_held) ** 2)))\n",
    "\n",
    "    return np.array(errs)\n",
    "    \n",
    "\n",
    "def compute_quantum_kernel(sigma, X, Xp=None):\n",
    "    \"\"\"Gram matrix K[a, b] = prod_i cos(sigma * (X[a, i] - Xp[b, i]) / 2) ** 2 (Xp defaults to X).\n",
    "\n",
    "    Materializes the full (n, m, d) difference tensor; quantum_kernel computes the\n",
    "    same matrix feature by feature with less memory.\n",
    "    \"\"\"\n",
    "    device = jax.devices()[gpu_id]\n",
    "    # `is None`, not `== None`: with an array Xp, `==` is elementwise and the `if` fails.\n",
    "    if Xp is None:\n",
    "        Xp = X\n",
    "    diff = jax.device_put(X[:, jnp.newaxis] - Xp[jnp.newaxis, :], device)\n",
    "\n",
    "    K = 1\n",
    "    for i in range(diff.shape[-1]):\n",
    "        K *= jax.device_put(jnp.cos(sigma * diff[:, :, i] / 2.) ** 2, device)\n",
    "\n",
    "    return K\n",
    "\n",
    "def get_target_and_weights(X, vecs, target='exp', target_mode=2):\n",
    "    \"\"\"Build a target y on the P rows of X and its power in each kernel eigenmode.\n",
    "\n",
    "    Assumes vecs are eigenvectors scaled so that vecs.T @ vecs = P * I (as returned by\n",
    "    kernel_spectra); the power in mode k is then (vecs[:, k] @ y / P) ** 2.\n",
    "\n",
    "    Targets:\n",
    "        'exp':           y = mean_i exp(-x_i ** 2)\n",
    "        'gaussian':      y = exp(-||x||^2 / D^2)\n",
    "        'flat':          power 1/N in each of the first N = degens(D, target_mode) modes\n",
    "        'random_binary': y drawn uniformly from {0, 1}\n",
    "\n",
    "    Returns:\n",
    "        (y, weights), both placed on jax.devices()[gpu_id].\n",
    "\n",
    "    Raises:\n",
    "        ValueError: for an unknown target (previously an UnboundLocalError).\n",
    "    \"\"\"\n",
    "    device = jax.devices()[gpu_id]\n",
    "    P, D = X.shape\n",
    "\n",
    "    if target == 'flat':\n",
    "        N = degens(D, target_mode)\n",
    "        weights = np.zeros(P)\n",
    "        weights[:N] = 1/N ## Target power adds up to 1\n",
    "        # Mode coefficients are sqrt(power): with y = vecs @ a the power in mode k is a_k ** 2,\n",
    "        # so using the powers themselves as coefficients gave power 1/N**2 per mode.\n",
    "        y = vecs @ jax.device_put(np.sqrt(weights), device)\n",
    "        return jax.device_put(y, device), jax.device_put(weights, device)\n",
    "\n",
    "    if target == 'exp':\n",
    "        y = jnp.exp(-X**2).mean(1)\n",
    "    elif target == 'gaussian':\n",
    "        y = jnp.exp(-(X**2).sum(axis=1) / D**2)\n",
    "    elif target == 'random_binary':\n",
    "        y = jnp.asarray(np.random.choice([0, 1], size=(P,)))  # poorly aligned with the kernel\n",
    "    else:\n",
    "        raise ValueError(f'Unknown target {target!r}')\n",
    "\n",
    "    y = jax.device_put(y, device)\n",
    "    weights = (vecs.T @ y / P) ** 2\n",
    "    return y, jax.device_put(weights, device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ker_type = 'quantum'\n",
    "target = 'gaussian'\n",
    "\n",
    "P = 5000\n",
    "N_tr = P\n",
    "pvals_th = np.logspace(0.5, np.log10(N_tr) - 0.2, 100).astype('int')\n",
    "pvals_exp = np.logspace(0.5, np.log10(N_tr) - 0.2, 10).astype('int')\n",
    "reg = 1e-10\n",
    "target_mode = 2\n",
    "\n",
    "d_vals = np.array([20, 40, 80, 200])\n",
    "c_vals = np.append([1], 2/d_vals)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "SEED = 0\n",
    "rng = np.random.default_rng(SEED)  # seeded so the saved results are reproducible\n",
    "device = jax.devices()[gpu_id]\n",
    "\n",
    "all_eigs = []\n",
    "all_vecs = []  # stays a list of empty lists: eigenvectors (P x P each) are not stored\n",
    "all_weights = []\n",
    "all_gen_err_th = []\n",
    "all_gen_err_exp = []\n",
    "\n",
    "for D in d_vals:\n",
    "    X = jax.device_put(rng.uniform(-np.pi, np.pi, size=(P, D)), device)\n",
    "\n",
    "    c_eigs, c_vecs, c_weights = [], [], []\n",
    "    c_gen_err_th, c_gen_err_exp = [], []\n",
    "    for c in c_vals:\n",
    "        K = jax.device_put(quantum_kernel(X, sigma=c), device)\n",
    "        print(f'D = {D} -- c = {c:.3f} ---------', end=\"\\r\")\n",
    "\n",
    "        eigs, vecs = kernel_spectra(K, 0)\n",
    "        y, weights = get_target_and_weights(X, vecs, target=target, target_mode=target_mode)\n",
    "        del vecs  # release the P x P eigenvector matrix before the regressions\n",
    "        eg_th = compute_gen_err(pvals_th, reg, eigs, weights)\n",
    "        eg_exp = perform_ker_regression(pvals_exp, reg, K, y)\n",
    "\n",
    "        c_gen_err_th.append(np.array(eg_th))\n",
    "        c_gen_err_exp.append(np.array(eg_exp))\n",
    "        c_eigs.append(np.array(eigs))\n",
    "        c_weights.append(np.array([weights]))\n",
    "\n",
    "        del K, eigs\n",
    "\n",
    "    all_eigs.append(c_eigs)\n",
    "    all_vecs.append(c_vecs)\n",
    "    all_weights.append(c_weights)\n",
    "    all_gen_err_th.append(c_gen_err_th)\n",
    "    all_gen_err_exp.append(c_gen_err_exp)\n",
    "\n",
    "all_eigs = np.array(all_eigs)\n",
    "all_vecs = np.array(all_vecs)\n",
    "all_weights = np.array(all_weights).squeeze()\n",
    "all_gen_err_th = np.array(all_gen_err_th)\n",
    "all_gen_err_exp = np.array(all_gen_err_exp)\n",
    "\n",
    "data = dict(P=P, pvals_th=pvals_th, pvals_exp=pvals_exp, reg=reg,\n",
    "            d_vals=d_vals, c_vals=c_vals, target_mode=target_mode, target=target,\n",
    "            all_eigs=all_eigs, all_vecs=all_vecs, all_weights=all_weights,\n",
    "            all_gen_err_th=all_gen_err_th, all_gen_err_exp=all_gen_err_exp)\n",
    "\n",
    "# `data` is a dict, so np.savez stores it as a pickled object array; load it with\n",
    "# np.load(path, allow_pickle=True)['data'].item().\n",
    "np.savez(f'synthetic_exp_target_{target}.npz', data=data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.rcParams.update({'font.size': 12})\n",
    "\n",
    "fig, axs = plt.subplots(2, 2, figsize=(10, 8))\n",
    "\n",
    "# One panel per input dimension n (row-major over the 2 x 2 grid);\n",
    "# lines are theory, dots are kernel-regression simulations.\n",
    "for ax, D, gen_err_th, gen_err_exp in zip(axs.flat, d_vals, all_gen_err_th, all_gen_err_exp):\n",
    "    for i, (c, eg_th, eg_exp) in enumerate(zip(c_vals, gen_err_th, gen_err_exp)):\n",
    "        ax.loglog(pvals_th, eg_th, label='c = %.0e' % c, color=f'C{i}')\n",
    "        ax.loglog(pvals_exp, eg_exp, 'o', color=f'C{i}')\n",
    "\n",
    "    ax.set_xlabel('P', fontsize=18)\n",
    "    ax.set_ylabel('$E_g(P)$', fontsize=18)\n",
    "    ax.legend()\n",
    "    ax.set_title('n = %d' % D)\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.savefig('gen_err_sim_appendix.pdf')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "## Main-text figure for input dimension n = D\n",
    "D = 40\n",
    "i_D = int(np.flatnonzero(d_vals == D)[0])  # results row for this D (was hard-coded as 1)\n",
    "gen_err_th = all_gen_err_th[i_D]\n",
    "gen_err_exp = all_gen_err_exp[i_D]\n",
    "c_star = 2 / D  # the dimension-scaled bandwidth 2/n from the c_vals grid\n",
    "\n",
    "plt.figure(figsize=(7, 5))\n",
    "\n",
    "for i, (c, eg_th, eg_exp) in enumerate(zip(c_vals, gen_err_th, gen_err_exp)):\n",
    "    # `curve_label`, not `label`: assigning `label` here shadowed the label() function.\n",
    "    curve_label = ('$c^* = %.2f$' if np.isclose(c, c_star) else '$c = %.2f$') % c\n",
    "    plt.loglog(pvals_th, eg_th, label=curve_label, color=f'C{i}', linewidth=3)\n",
    "    plt.loglog(pvals_exp, eg_exp, 'o', color=f'C{i}')\n",
    "\n",
    "plt.xlabel('$P$', fontsize=axis_label_fontsize)\n",
    "plt.ylabel('$E_g(P)$', fontsize=axis_label_fontsize)\n",
    "plt.xticks(fontsize=axis_tick_fontsize)\n",
    "plt.yticks(fontsize=axis_tick_fontsize)\n",
    "plt.legend(loc='best', fontsize=axis_legend_fontsize)\n",
    "plt.tight_layout()\n",
    "plt.savefig('sim_quantum_kernel_gen_err_sim.pdf')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Figure 3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from functools import reduce\n",
    "from qiskit.quantum_info import DensityMatrix\n",
    "from qiskit.circuit.library import ZZFeatureMap\n",
    "from sklearn.preprocessing import StandardScaler\n",
    "\n",
    "def self_product(x: np.ndarray) -> float:\n",
    "    \"\"\"\n",
    "    Example from:\n",
    "    https://qiskit.org/documentation/_modules/qiskit/circuit/library/data_preparation/pauli_feature_map.html#PauliFeatureMap\n",
    "     \n",
    "    Define a function map from R^n to R.\n",
    "     \n",
    "    Args:\n",
    "        x: data\n",
    "     \n",
    "    Returns:\n",
    "        float: the mapped value\n",
    "    \"\"\"\n",
    "    coeff = x[0] if len(x) == 1 else reduce(lambda m, n: m * n, x)\n",
    "    return coeff\n",
    "\n",
    "def IQPStyleFeatureMap(dataset_dim):\n",
    "    \"\"\"ZZFeatureMap on `dataset_dim` qubits, 2 repetitions, with self_product as data map.\"\"\"\n",
    "    feature_map = ZZFeatureMap(dataset_dim, reps=2, data_map_func=self_product)\n",
    "    return feature_map\n",
    "\n",
    "def get_quantum_kernel(FeatureMap, simulation_method='statevector', shots=1, batch_size=500):                                  \n",
    "    \"\"\"Builds Qiskit QuantumKernel object \n",
    "    with parameters passed directly to HamiltonianEvolutionFeatureMap\n",
    "    \"\"\"\n",
    "    from qiskit.providers.aer import AerSimulator\n",
    "    from qiskit.utils import QuantumInstance\n",
    "    from qiskit_machine_learning.kernels import QuantumKernel\n",
    "    if simulation_method == 'statevector' and shots != 1:\n",
    "        raise ValueError(f'With simulation method {simulation_method} no shots are allowed')\n",
    "    quantum_instance_sv = QuantumInstance(AerSimulator(method=simulation_method, shots=shots))\n",
    "    return QuantumKernel(feature_map=FeatureMap, quantum_instance=quantum_instance_sv, batch_size=batch_size)\n",
    "\n",
    "N = 301  # number of points along each kernel cross-section\n",
    "d_vals = [2, 3, 5, 10]  # check how dimension influences how the kernel \"looks\"\n",
    "c_vals = [0.01, 0.1, 0.3, .5, 1]  # bandwidths to check\n",
    "rng = np.random.default_rng(0)  # seeded reference points for reproducibility\n",
    "\n",
    "K_all = []\n",
    "for d in d_vals:\n",
    "    x0 = rng.normal(size=d)\n",
    "    delta = np.ones(d)  # this gives a much wider picture of nonlocalness of kernel\n",
    "    # Sweep from x0 - delta to x0 + delta, so row j sits at offset -1 + 2j/(N-1); this matches\n",
    "    # the np.linspace(-1, 1, N) x-axis used for plotting (the reverse order mirrored the curves).\n",
    "    X = np.linspace(x0 - delta, x0 + delta, N)\n",
    "\n",
    "    # The circuit and kernel object depend only on d, so build them once per dimension.\n",
    "    qkern = get_quantum_kernel(IQPStyleFeatureMap(d))\n",
    "    K_c = []\n",
    "    for c in c_vals:\n",
    "        print(f\"...d={d}, c={c}...\", end='\\r')\n",
    "        K_c.append(qkern.evaluate(x_vec=c * X, y_vec=np.atleast_2d(c * x0)))\n",
    "    K_all.append(K_c)\n",
    "\n",
    "K_all = np.array(K_all).squeeze()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "colors = sns.color_palette('viridis', len(c_vals))[::-1]\n",
    "\n",
    "# Cross-sections for the largest dimension, d_vals[-1] (last entry of K_all).\n",
    "offsets = np.linspace(-1, 1, N)  # position along the sweep, in units of delta\n",
    "fig, ax = plt.subplots()\n",
    "for c, color, K in zip(c_vals, colors, K_all[-1]):\n",
    "    ax.plot(offsets, K, color=color, label=f'$c={c}$', linewidth=3)\n",
    "\n",
    "ax.set_xlabel('offset along the sweep')\n",
    "ax.set_ylabel('$k(x, x_0)$')\n",
    "ax.set_title(f'd = {d_vals[-1]}')\n",
    "ax.legend()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python (ffcv)",
   "language": "python",
   "name": "ffcv"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
