{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "a1ff22e8-98da-477f-86dc-29eca6225baf",
   "metadata": {},
   "source": [
    "# gp2Scale III -- CA Housing"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3473f81d-4935-4e9c-ab5c-7679e2f29cd8",
   "metadata": {},
   "source": [
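    "Before running this notebook, make sure you:\n",
    "\n",
    "- create a new Python environment\n",
    "\n",
    "- activate it\n",
    "\n",
    "- `pip install ipykernel`\n",
    "\n",
    "- register it as a Jupyter kernel: `python3 -m ipykernel install --user --name env --display-name MyEnvironment`\n",
    "\n",
    "- `pip install` everything else this notebook imports (`numpy`, `matplotlib`, `torch`, `gpcam`, `dask[distributed]`, `loguru`, `scipy`), and make sure the `housing_kernel` module is importable\n",
    "\n",
    "- make sure the notebook uses that kernel\n",
    "\n"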
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "af1e0909-65dd-4c1f-8f13-e6b5aed09fc9",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from gpcam import GPOptimizer\n",
    "import matplotlib.pyplot as plt\n",
    "import random\n",
    "import housing_kernel\n",
    "import torch"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "45f0be8d-8625-4974-9b03-afbc83b0049f",
   "metadata": {},
   "source": [
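    "Run this in a terminal on Perlmutter before executing the next cell. The launch script is expected to start the Dask scheduler and workers and to write `$SCRATCH/scheduler_fileCA.json`, which the next cell waits for:\n",
    "\n",
    "```bash\n",
    "./launch-dask-module.sh\n",
    "```"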
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4f6ad26f-6d0b-4282-a380-485780a559c1",
   "metadata": {},
   "outputs": [],
   "source": [
    "import dask\n",
    "from dask.distributed import Client\n",
    "import os\n",
    "import time\n",
    "\n",
    "# The Dask cluster is launched separately (see the markdown cell above) and is expected\n",
    "# to write this scheduler file once the scheduler is up.\n",
    "scheduler_file = os.path.join(os.environ[\"SCRATCH\"], \"scheduler_fileCA.json\")\n",
    "N_WORKERS = 16                 # number of Dask workers to wait for before continuing\n",
    "POLL_INTERVAL = 2.0            # seconds between checks for the scheduler file\n",
    "SCHEDULER_WAIT_TIMEOUT = None  # seconds; None waits indefinitely (e.g. while the batch job is queued)\n",
    "\n",
    "# Route dashboard links through the JupyterHub proxy so they open from this notebook.\n",
    "dask.config.set({\"distributed.dashboard.link\": \"{JUPYTERHUB_SERVICE_PREFIX}proxy/{host}:{port}/status\"})\n",
    "\n",
    "# Poll for the scheduler file; optionally give up after SCHEDULER_WAIT_TIMEOUT seconds.\n",
    "wait_start = time.monotonic()\n",
    "while not os.path.isfile(scheduler_file):\n",
    "    if SCHEDULER_WAIT_TIMEOUT is not None and time.monotonic() - wait_start > SCHEDULER_WAIT_TIMEOUT:\n",
    "        raise TimeoutError(f\"scheduler file {scheduler_file} not found after {SCHEDULER_WAIT_TIMEOUT} s\")\n",
    "    time.sleep(POLL_INTERVAL)\n",
    "print(\"file found\")\n",
    "time.sleep(1)  # presumably gives the scheduler time to finish writing the file\n",
    "client = Client(scheduler_file=scheduler_file)\n",
    "\n",
    "print(\"waiting for workers\")\n",
    "client.wait_for_workers(N_WORKERS)\n",
    "print(client)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "39c501bf-bd24-4e5a-ab22-deea98f32503",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rich summary of the connected Dask client/cluster\n",
    "client"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5b09030f-713e-4da0-a372-74f9966c1247",
   "metadata": {},
   "outputs": [],
   "source": [
    "DATA_DIR = \"./data\"\n",
    "\n",
    "x_train = np.genfromtxt(f\"{DATA_DIR}/x_train_CAhousing.csv\", delimiter=\" \")\n",
    "y_train = np.genfromtxt(f\"{DATA_DIR}/y_train_CAhousing.csv\", delimiter=\" \")\n",
    "\n",
    "x_test = np.genfromtxt(f\"{DATA_DIR}/x_test_CAhousing.csv\", delimiter=\" \")\n",
    "y_test = np.genfromtxt(f\"{DATA_DIR}/y_test_CAhousing.csv\", delimiter=\" \")\n",
    "\n",
    "# Fail early on misaligned or incomplete data (genfromtxt turns missing/unparsable fields into NaN).\n",
    "assert x_train.shape[0] == y_train.shape[0], \"x_train and y_train have different lengths\"\n",
    "assert x_test.shape[0] == y_test.shape[0], \"x_test and y_test have different lengths\"\n",
    "assert x_train.shape[1] == x_test.shape[1], \"train and test inputs have different dimensions\"\n",
    "for name, arr in ((\"x_train\", x_train), (\"y_train\", y_train), (\"x_test\", x_test), (\"y_test\", y_test)):\n",
    "    assert np.isfinite(arr).all(), f\"{name} contains NaN/inf values\"\n",
    "\n",
    "print(len(x_train))\n",
    "print(len(x_test))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "369b540a-d0ed-4473-8941-b3210bf8da51",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Value range (min, max) of each input dimension of the training data.\n",
    "for dim in range(x_train.shape[1]):\n",
    "    column = x_train[:, dim]\n",
    "    print(np.min(column), np.max(column))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e1ca48fd-444c-4aee-856e-a3317c2add73",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "hps_bounds = np.array([\n",
    "                    #[0.0001,10], ##offset\n",
    "                    #[-2.,2.], ##slope\n",
    "                    #[0.0001,10], ##offset\n",
    "                    #[-2.,2.], ##slope\n",
    "                    [0.1, 50.], ##constant signal var RBF\n",
    "                    [0.1,50.], #Wendland length scale\n",
    "                    [0.1,200.],\n",
    "                    [1.,140.],\n",
    "                    [0.1,36.],\n",
    "                    [100.,20000.],\n",
    "                    [1.,500.],\n",
    "                    [0.05,2.],\n",
    "                    [0.05,2.]\n",
    "                    ])\n",
    "\n",
    "# True: start from hyperparameters found in an earlier training run.\n",
    "# False: fresh run, starting from a uniform random draw inside hps_bounds.\n",
    "USE_TRAINED_HPS = True\n",
    "if USE_TRAINED_HPS:\n",
    "    init_hps = np.array([1.39464662e+01, 3.52326980e+01, 1.20312766e+02, 9.88228186e+01,\n",
    "                         3.59970585e+01, 1.67134131e+04, 1.91118884e+02, 1.92620327e+00,\n",
    "                         6.62251837e-01])\n",
    "else:\n",
    "    init_hps = np.random.uniform(size = len(hps_bounds), low = hps_bounds[:,0], high = hps_bounds[:,1])\n",
    "assert init_hps.shape == (len(hps_bounds),), \"need exactly one initial value per row of hps_bounds\"\n",
    "assert np.all((init_hps >= hps_bounds[:,0]) & (init_hps <= hps_bounds[:,1])), \"init_hps lies outside hps_bounds\"\n",
    "\n",
    "NOISE_VARIANCE = 1e-4  # fixed per-point noise variance passed to the GP\n",
    "\n",
    "from loguru import logger\n",
    "logger.disable(\"fvgp\")  # silence loguru output coming from the fvgp package\n",
    "my_gp2S = GPOptimizer(x_train,y_train,init_hyperparameters=init_hps,\n",
    "                      #gp_kernel_function = housing_kernel.kernel,\n",
    "                      gp2Scale = True, gp2Scale_batch_size = 4000, gp2Scale_dask_client = client,\n",
    "                      compute_device=\"gpu\", noise_variances=np.full(y_train.shape, NOISE_VARIANCE), gp2Scale_linalg_mode=\"sparseMINRES\"\n",
    "                      )\n",
    "print(\"Likelihood: \", my_gp2S.log_likelihood())\n",
    "\n",
    "my_gp2S.train(hyperparameter_bounds=hps_bounds, method=\"mcmc\", info = True, max_iter=1000)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fdff1510-321c-4cd8-b9e5-ae6943c8422b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Time one log-likelihood evaluation at init_hps.\n",
    "start = time.perf_counter()\n",
    "print(\"Likelihood: \", my_gp2S.log_likelihood(hyperparameters = init_hps))\n",
    "print(\"exec time: \", time.perf_counter() - start)\n",
    "\n",
    "# nnz / (rows * cols) is the *density* of the (sparse) prior matrix K: 0 = all zeros,\n",
    "# 1 = fully dense. Sparsity is its complement.\n",
    "K_prior = my_gp2S.prior.K\n",
    "density = float(K_prior.nnz) / float(K_prior.shape[0] * K_prior.shape[1])\n",
    "sparsity = 1.0 - density\n",
    "\n",
    "print(\"density: \", density)\n",
    "print(\"sparsity: \", sparsity)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "acef7ede-f79e-48f8-a9c8-dc421f67c0e0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Test-set RMSE as computed by GPOptimizer.rmse\n",
    "rmse = my_gp2S.rmse(x_test, y_test)\n",
    "print(rmse)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e9da69b7-f75b-453d-99f4-b444c536fff5",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Test-set CRPS. Slow: noted as taking about 4 hours.\n",
    "crps = my_gp2S.crps(x_test, y_test)\n",
    "print(crps)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "65050794-711e-4d4b-aa7a-9672d1225f7f",
   "metadata": {},
   "source": [
    "## Earlier test: manual MCMC sampling with gpMCMC"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ae32825c-cf68-4d51-b054-56b5cb57f6ce",
   "metadata": {},
   "outputs": [],
   "source": [
    "def in_bounds(v,bounds):\n",
    "    \"\"\"Check that every entry of v lies within its [lower, upper] row of bounds.\n",
    "\n",
    "    Returns (True, None) if all entries are inside; otherwise (False, idx), where idx\n",
    "    lists the indices above their upper bound followed by those below their lower bound.\n",
    "    \"\"\"\n",
    "    lower, upper = bounds[:, 0], bounds[:, 1]\n",
    "    too_high = v > upper\n",
    "    too_low = v < lower\n",
    "    if not (any(too_low) or any(too_high)):\n",
    "        return True, None\n",
    "    return False, list(np.where(too_high)[0]) + list(np.where(too_low)[0])\n",
    "\n",
    "\n",
    "def prior_function(theta,args):\n",
    "    \"\"\"Flat log-prior: 0 inside the box args[\"bounds\"], -inf outside.\n",
    "\n",
    "    Out-of-bounds evaluations are reported with the offending indices and values.\n",
    "    \"\"\"\n",
    "    inside, bad_idx = in_bounds(theta, args[\"bounds\"])\n",
    "    if not inside:\n",
    "        print(\"                 PRIOR eval out of bounds\", bad_idx, theta[bad_idx], flush = True)\n",
    "        return -np.inf\n",
    "    return 0.\n",
    "\n",
    "def proposal_distribution_normal(x0, hps, obj):\n",
    "    \"\"\"Draw a Gaussian random-walk proposal centred on x0.\n",
    "\n",
    "    x0  : current chain position (1-D array)\n",
    "    hps : not used here; kept to match the signature gpMCMC calls proposal functions with (TODO confirm)\n",
    "    obj : proposal-distribution object; obj.prop_args[\"prop_Sigma\"] is the proposal covariance\n",
    "\n",
    "    Returns a 1-D array of length len(x0) sampled from N(x0, prop_Sigma) using numpy's global RNG.\n",
    "    \"\"\"\n",
    "    cov = obj.prop_args[\"prop_Sigma\"]\n",
    "    return np.random.multivariate_normal(mean = x0, cov = cov, size = 1).reshape(len(x0))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ad340ebd",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "from gpcam.gpMCMC import gpMCMC, ProposalDistribution\n",
    "\n",
    "def func(hps, args):\n",
    "    \"\"\"MCMC target: log-likelihood of my_gp2S at hps (args is unused).\n",
    "\n",
    "    The point being evaluated is first saved to last_hps_backup.npy as a backup.\n",
    "    \"\"\"\n",
    "    np.save(\"last_hps_backup\", hps)\n",
    "    result = my_gp2S.log_likelihood(hyperparameters=hps)\n",
    "    if isinstance(result, np.ndarray): result = result.item()  # unwrap array results to a Python scalar\n",
    "    print(\"                   f(x): \", result)\n",
    "    return result\n",
    "\n",
    "def write_results(obj):\n",
    "    \"\"\"Save the sampler's current trace to current_trace.npy (passed as run_in_every_iteration).\"\"\"\n",
    "    np.save(\"current_trace\", obj.trace)\n",
    "\n",
    "# One proposal block over all hyperparameters. The indices are derived from hps_bounds so\n",
    "# they always match the number of hyperparameters (a hard-coded range can run past it).\n",
    "core_ind = list(range(len(hps_bounds)))\n",
    "\n",
    "# Initial proposal covariance: diagonal, with std = 1/10 of each bound's width.\n",
    "axis_std_core = (hps_bounds[core_ind, 1] - hps_bounds[core_ind, 0]) / 10.\n",
    "init_s_core = np.diag(axis_std_core**2)\n",
    "\n",
    "pd1 = ProposalDistribution(core_ind, proposal_dist = proposal_distribution_normal,\n",
    "                        init_prop_Sigma = init_s_core, adapt_callable=\"normal\", K=10, ID = \"core\")\n",
    "\n",
    "my_mcmc = gpMCMC(func, prior_function, [pd1], args={\"bounds\":hps_bounds})\n",
    "\n",
    "# Start the chain from the GP's current hyperparameters and print that actual starting point.\n",
    "x0 = my_gp2S.hyperparameters\n",
    "print(x0)\n",
    "mcmc_result = my_mcmc.run_mcmc(x0=x0, info=True, n_updates=1000, run_in_every_iteration=write_results)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "78559e6d-6a0f-466d-90bf-77bd4f30d6ea",
   "metadata": {},
   "outputs": [],
   "source": [
    "# The trace was saved via np.save of a Python object, so loading needs allow_pickle=True;\n",
    "# only load trusted files this way.\n",
    "trace_array = np.load(\"current_trace.npy\", allow_pickle=True)\n",
    "current_trace = trace_array.item()  # unwrap the 0-d object array to the stored trace object"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e23fe582-66e0-43e2-8449-ed2a64cc1240",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Most recent sample in the MCMC trace\n",
    "last_sample = current_trace[\"x\"][-1]\n",
    "last_sample"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "33a96ad1-5cc4-4a6b-9132-f7b194bd4bb4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Use the final MCMC sample as the hyperparameters from here on.\n",
    "trace_samples = current_trace[\"x\"]\n",
    "hps = trace_samples[-1]\n",
    "print(hps)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3a48356d-d97d-4b9e-ac87-c811142da05b",
   "metadata": {},
   "outputs": [],
   "source": [
    "from scipy import sparse\n",
    "\n",
    "# The final hyperparameters are the last MCMC sample, `hps`, taken from the trace in the cell above.\n",
    "final_hps = hps\n",
    "print(\"FINAL HYPERPARAMETERS\")\n",
    "print(final_hps)\n",
    "\n",
    "my_gp2S.set_hyperparameters(final_hps)\n",
    "print(my_gp2S.log_likelihood(my_gp2S.get_hyperparameters()))\n",
    "np.save(\"full_gp2ScaleHPS\", my_gp2S.get_hyperparameters())\n",
    "#np.save(\"mcmc_result\", mcmc_result)\n",
    "\n",
    "# nnz / (rows * cols) is the *density* of the (sparse) prior matrix K: 0 = all zeros,\n",
    "# 1 = fully dense. Sparsity is its complement.\n",
    "K_prior = my_gp2S.prior.K\n",
    "density = float(K_prior.nnz) / float(K_prior.shape[0] * K_prior.shape[1])\n",
    "sparsity = 1.0 - density\n",
    "#sparse.save_npz(\"sparse_matrix1Mill\", K_prior)\n",
    "\n",
    "print(\"density: \", density)\n",
    "print(\"sparsity: \", sparsity)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d3a9266e-b483-43d2-945c-9ef9eeb2a8e5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Test-set RMSE with the final hyperparameters (as computed by GPOptimizer.rmse)\n",
    "rmse = my_gp2S.rmse(x_test, y_test)\n",
    "print(rmse)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3fa600ab-c9b8-4fc2-b848-63a65be129d0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Posterior mean at the test inputs; np.save writes it to y_predicted.npy\n",
    "mean = my_gp2S.posterior_mean(x_test)\n",
    "np.save(\"y_predicted\", mean)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c7053159-d1ed-42eb-8945-4bd7c3097a5d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Test-set CRPS with the final hyperparameters (slow; see the earlier CRPS cell)\n",
    "crps = my_gp2S.crps(x_test, y_test)\n",
    "print(crps)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "88d06c1c-bdb9-405f-9bc7-21f6b495a6b0",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "CAenv",
   "language": "python",
   "name": "caenv"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
