{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "a1ff22e8-98da-477f-86dc-29eca6225baf",
   "metadata": {},
   "source": [
    "# gp2Scale III -- MNIST"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3473f81d-4935-4e9c-ab5c-7679e2f29cd8",
   "metadata": {},
   "source": [
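    "Before running this notebook:\n",
    "\n",
    "- create a new Python environment\n",
    "\n",
    "- activate it\n",
    "\n",
    "- `pip install ipykernel`\n",
    "\n",
    "- `python3 -m ipykernel install --user --name env --display-name MyEnvironment` (replace `env` and `MyEnvironment` with your own kernel name and display name)\n",
    "\n",
    "- `pip install` the remaining packages imported below (e.g. `gpcam`, `dask[distributed]`, `torch`, `loguru`, `matplotlib`, `scipy`), and make sure the local `mnist_kernel` module is importable\n",
    "\n",
    "- select that kernel for this notebook\n",
    "\n"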
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "af1e0909-65dd-4c1f-8f13-e6b5aed09fc9",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from gpcam import GPOptimizer\n",
    "import matplotlib.pyplot as plt\n",
    "import random\n",
    "import mnist_kernel\n",
    "import torch"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "45f0be8d-8625-4974-9b03-afbc83b0049f",
   "metadata": {},
   "source": [
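    "Run this in a terminal on Perlmutter before executing the next cell, which waits for the Dask scheduler file `$SCRATCH/scheduler_fileMNIST.json` and then for 16 workers:\n",
    "\n",
    "```bash\n",
    "./launch-dask-module.sh\n",
    "```"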
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4f6ad26f-6d0b-4282-a380-485780a559c1",
   "metadata": {},
   "outputs": [],
   "source": [
    "import dask\n",
    "from dask.distributed import Client\n",
    "import os\n",
    "import time\n",
    "\n",
    "# Connect to the Dask cluster started from the terminal; presumably ./launch-dask-module.sh writes this file.\n",
    "scheduler_file = os.path.join(os.environ[\"SCRATCH\"], \"scheduler_fileMNIST.json\")\n",
    "n_workers = 16              # number of Dask workers to wait for before continuing\n",
    "scheduler_timeout_s = 3600  # stop waiting for the scheduler file after this many seconds\n",
    "\n",
    "# Route the dashboard link through the JupyterHub proxy.\n",
    "dask.config.set({\"distributed.dashboard.link\": \"{JUPYTERHUB_SERVICE_PREFIX}proxy/{host}:{port}/status\"})\n",
    "\n",
    "# Poll for the scheduler file, but fail loudly instead of hanging forever if the cluster never comes up.\n",
    "deadline = time.monotonic() + scheduler_timeout_s\n",
    "while not os.path.isfile(scheduler_file):\n",
    "    if time.monotonic() > deadline:\n",
    "        raise TimeoutError(f\"scheduler file {scheduler_file} did not appear within {scheduler_timeout_s} s\")\n",
    "    time.sleep(2)\n",
    "print(\"file found\")\n",
    "time.sleep(2)  # short grace period before connecting, as in the original polling loop\n",
    "client = Client(scheduler_file=scheduler_file)\n",
    "\n",
    "print(\"waiting for workers\")\n",
    "client.wait_for_workers(n_workers)\n",
    "print(client)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "39c501bf-bd24-4e5a-ab22-deea98f32503",
   "metadata": {},
   "outputs": [],
   "source": [
    "client"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5b09030f-713e-4da0-a372-74f9966c1247",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle\n",
    "\n",
    "def _load_pickle(path):\n",
    "    # NOTE(review): pickle.load can execute arbitrary code -- only load these files from a trusted source.\n",
    "    with open(path, 'rb') as fh:\n",
    "        return pickle.load(fh)\n",
    "\n",
    "x_train = _load_pickle(\"./data/x_train_MNIST.pkl\")\n",
    "y_train = _load_pickle(\"./data/y_train_MNIST.pkl\")\n",
    "x_test = _load_pickle(\"./data/x_test_MNIST.pkl\")\n",
    "y_test = _load_pickle(\"./data/y_test_MNIST.pkl\")\n",
    "\n",
    "# Binary float32 target: 1.0 where the digit is a 5, else 0.0 (regressed as the probability of y == 5).\n",
    "y_train = (np.asarray(y_train, dtype=np.float32) == 5.).astype(np.float32)\n",
    "y_test = (np.asarray(y_test, dtype=np.float32) == 5.).astype(np.float32)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c492bded-eb95-486c-887e-12899bca5159",
   "metadata": {},
   "outputs": [],
   "source": [
    "y_train"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e2755913",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Quick sanity check: shape of one image, range of the 0/1 labels, number of training points.\n",
    "first_image_shape = x_train[0].shape\n",
    "label_min, label_max = y_train.min(), y_train.max()\n",
    "print(first_image_shape)\n",
    "print(\" \")\n",
    "print(label_min, label_max)\n",
    "print(y_train.shape[0])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a669f550-e7c7-4dba-8d75-4793c0f859ea",
   "metadata": {},
   "source": [
    "### Wendland, no bumps"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e1ca48fd-444c-4aee-856e-a3317c2add73",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hyperparameter bounds, one [low, high] row per hyperparameter.\n",
    "hps_bounds = np.array([\n",
    "    [0.05, 2.],   # delta radius\n",
    "    [0.1, 10.],   # signal var\n",
    "    [0.25, 10.],  # kernel length scale\n",
    "    [10., 700.],  # delta weight\n",
    "])\n",
    "\n",
    "from loguru import logger\n",
    "logger.enable(\"fvgp\")\n",
    "\n",
    "# Start from hyperparameters found through training; set use_trained_hps = False\n",
    "# for a fresh run from a uniform random draw within hps_bounds.\n",
    "use_trained_hps = True\n",
    "if use_trained_hps:\n",
    "    init_hps = np.array([1.52415759,  3.75230001,  2.16698608, 76.85252806])\n",
    "else:\n",
    "    init_hps = np.random.uniform(size = len(hps_bounds), low = hps_bounds[:,0], high = hps_bounds[:,1])\n",
    "\n",
    "# Training subset: each image is reshaped to a flat 28 * 28 = 784 feature vector.\n",
    "n_train = 60000\n",
    "x_fit = np.asarray(x_train[:n_train])\n",
    "x_fit = x_fit.reshape(len(x_fit), 28 * 28)\n",
    "y_fit = y_train[:n_train]\n",
    "# One noise variance per point actually fitted (sized from y_fit, not from the full y_train).\n",
    "noise_variances = np.full(y_fit.shape, 0.0001)\n",
    "\n",
    "st = time.time()\n",
    "my_gp2S = GPOptimizer(x_fit, y_fit, init_hyperparameters=init_hps, gp_kernel_function = mnist_kernel.kernelL1,\n",
    "                      gp2Scale = True, gp2Scale_batch_size = 15000, gp2Scale_dask_client = client,\n",
    "                      compute_device=\"gpu\", noise_variances=noise_variances, gp2Scale_linalg_mode=\"sparseMINRES\",\n",
    "                      )\n",
    "print(\"Likelihood: \", my_gp2S.log_likelihood())\n",
    "print(\"exec time: \",time.time() - st)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ae32825c-cf68-4d51-b054-56b5cb57f6ce",
   "metadata": {},
   "outputs": [],
   "source": [
    "def in_bounds(v,bounds):\n",
    "    \"\"\"Check whether every entry of v lies within its [low, high] row of bounds.\n",
    "\n",
    "    Returns (True, None) when all entries are in bounds, otherwise (False, idx), where idx lists\n",
    "    the offending indices: those above the upper bound first, then those below the lower bound.\n",
    "    \"\"\"\n",
    "    above = np.where(v > bounds[:,1])[0]\n",
    "    below = np.where(v < bounds[:,0])[0]\n",
    "    if len(above) or len(below):\n",
    "        return False, list(above) + list(below)\n",
    "    return True, None\n",
    "\n",
    "\n",
    "def prior_function(theta,args):\n",
    "    \"\"\"Log-prior for the MCMC: -theta[0]**2 / 0.5 inside args[\\\"bounds\\\"], -inf outside.\"\"\"\n",
    "    ok, bad_idx = in_bounds(theta, args[\"bounds\"])\n",
    "    if not ok:\n",
    "        print(\"                    PRIOR eval out of bounds\", bad_idx, theta[bad_idx], flush = True)\n",
    "        return -np.inf\n",
    "    # Gaussian-shaped penalty (variance 0.25) on theta[0], the delta radius, favouring small radii.\n",
    "    return -(theta[0]**2)/0.5\n",
    "\n",
    "def proposal_distribution_normal(x0, hps, obj):\n",
    "    \"\"\"Draw one Gaussian proposal centred at x0 with covariance obj.prop_args[\\\"prop_Sigma\\\"].\n",
    "\n",
    "    hps is not used; it is presumably required by gpMCMC's proposal-function signature.\n",
    "    \"\"\"\n",
    "    cov = obj.prop_args[\"prop_Sigma\"]\n",
    "    return np.random.multivariate_normal(mean = x0, cov = cov, size = 1).reshape(len(x0))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ad340ebd",
   "metadata": {},
   "outputs": [],
   "source": [
    "from gpcam.gpMCMC import gpMCMC\n",
    "from gpcam.gpMCMC import ProposalDistribution\n",
    "logger.disable(\"fvgp\")\n",
    "\n",
    "def func(hps, args):\n",
    "    \"\"\"MCMC target: log-likelihood of my_gp2S at hps (args unused). Saves hps to last_hps_backup.npy first.\"\"\"\n",
    "    np.save(\"last_hps_backup\", hps)  # presumably a backup of the last evaluated point in case the run dies\n",
    "    result = my_gp2S.log_likelihood(hyperparameters=hps)\n",
    "    print(result ,\" @ \", hps)\n",
    "    return result\n",
    "\n",
    "def write_results(obj):\n",
    "    \"\"\"Checkpoint the current MCMC trace to current_trace.npy (passed as run_in_every_iteration below).\"\"\"\n",
    "    np.save(\"current_trace\", obj.trace)\n",
    "\n",
    "# All hyperparameters (every row of hps_bounds) are sampled by one proposal distribution.\n",
    "lengthscale_ind = list(range(len(hps_bounds)))\n",
    "\n",
    "# Initial proposal covariance: diagonal, standard deviation = 1% of each hyperparameter's bound width.\n",
    "axis_std_lengthscale = (hps_bounds[lengthscale_ind, 1] - hps_bounds[lengthscale_ind,0])/100.\n",
    "init_s_ls = np.diag(axis_std_lengthscale**2)\n",
    "\n",
    "# Normal proposal distribution over all hyperparameters (Wendland kernel, no bumps).\n",
    "pd1 = ProposalDistribution(lengthscale_ind, proposal_dist = proposal_distribution_normal,\n",
    "                        init_prop_Sigma = init_s_ls, adapt_callable=\"normal\", K=10, ID = \"core\")\n",
    "\n",
    "n_updates = 100  # number of MCMC updates\n",
    "my_mcmc = gpMCMC(func, prior_function, [pd1], args={\"bounds\":hps_bounds})\n",
    "print(init_hps)\n",
    "mcmc_result = my_mcmc.run_mcmc(x0=init_hps, info=True, n_updates=n_updates, run_in_every_iteration=write_results)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3a48356d-d97d-4b9e-ac87-c811142da05b",
   "metadata": {},
   "outputs": [],
   "source": [
    "from scipy import sparse\n",
    "# Drop the first 91 MCMC samples as burn-in (the run above used n_updates=100) and\n",
    "# take the per-hyperparameter median of the remaining samples as the final estimate.\n",
    "burn_in = 91\n",
    "hps1 = np.median(mcmc_result[\"x\"][burn_in:], axis=0)\n",
    "print(\"FINAL HYPERPARAMETERS\")\n",
    "print(hps1)\n",
    "my_gp2S.set_hyperparameters(hps1)\n",
    "print(my_gp2S.log_likelihood())\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "477b9a8d-d964-4f34-9f0c-a0a3979e547a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fraction of stored nonzero entries in my_gp2S.prior.K (0 would mean K is entirely zero).\n",
    "K_prior = my_gp2S.prior.K\n",
    "n_rows = K_prior.shape[0]\n",
    "sparsity = float(K_prior.nnz) / float(n_rows * n_rows)\n",
    "print(\"sparsity: \", sparsity)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3b482b15-8260-4ca7-90cf-437529a23765",
   "metadata": {},
   "outputs": [],
   "source": [
    "#Brier score\n",
    "def brier_score_gp_regression(y_pred, y_true):\n",
    "    \"\"\"Brier score: mean squared error between predictions clipped to [0, 1] and the 0/1 targets.\n",
    "\n",
    "    Clipping is done on a copy, so the caller's y_pred array is not modified.\n",
    "    \"\"\"\n",
    "    y_clipped = np.clip(np.asarray(y_pred), 0., 1.)\n",
    "    return np.mean((y_clipped - y_true) ** 2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d3b2603d-6831-4b0f-a065-429e3b9e5975",
   "metadata": {},
   "outputs": [],
   "source": [
    "x_test_flat = np.asarray(x_test).reshape(len(x_test), 28 * 28)  # one 784-feature row per test image\n",
    "mean = my_gp2S.posterior_mean(x_test_flat)[\"f(x)\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "979483ae-9bf2-4a33-a0ac-b491afe5bd67",
   "metadata": {},
   "outputs": [],
   "source": [
    "brier_score_gp_regression(mean, y_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e4ba7444-5442-4a52-b410-bd88ecbc2c4b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Posterior mean vs. true 0/1 label, with the identity line y = x drawn across the full data range.\n",
    "fig, ax = plt.subplots()\n",
    "lims = [min(0., float(np.min(mean))), max(1., float(np.max(mean)))]\n",
    "ax.plot(lims, lims)\n",
    "ax.scatter(mean, y_test)\n",
    "ax.set_xlabel(\"prediction\")\n",
    "ax.set_ylabel(\"y test\")\n",
    "ax.set_title(\"Posterior mean vs. test labels (is the digit a 5?)\");"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "MNISTenv",
   "language": "python",
   "name": "mnistenv"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
