{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "a1ff22e8-98da-477f-86dc-29eca6225baf",
   "metadata": {},
   "source": [
    "# gp2Scale Climate Test"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3473f81d-4935-4e9c-ab5c-7679e2f29cd8",
   "metadata": {},
   "source": [
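    "Before running this notebook:\n",
    "\n",
    "1. Create a new Python environment and activate it.\n",
    "2. Install the Jupyter kernel package: `pip install ipykernel`.\n",
    "3. Register the environment as a Jupyter kernel: `python3 -m ipykernel install --user --name env --display-name MyEnvironment` (choose your own `--name` / `--display-name`).\n",
    "4. Install the remaining dependencies imported below (e.g. `numpy`, `scipy`, `matplotlib`, `torch`, `dask[distributed]`, `gpcam`, `loguru`) with `pip`.\n",
    "5. Select that kernel for this notebook before running any cells."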
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "af1e0909-65dd-4c1f-8f13-e6b5aed09fc9",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from gpcam import GPOptimizer\n",
    "import matplotlib.pyplot as plt\n",
    "import random\n",
    "import climate_kernel\n",
    "import torch"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "45f0be8d-8625-4974-9b03-afbc83b0049f",
   "metadata": {},
   "source": [
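    "Run the Dask launch script in a terminal on Perlmutter before executing the next cell:\n",
    "\n",
    "```bash\n",
    "./launch-dask-module.sh\n",
    "```\n",
    "\n",
    "The next cell polls for `$SCRATCH/scheduler_file3d.json`, connects to the Dask scheduler, and waits until 256 workers have joined."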
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4f6ad26f-6d0b-4282-a380-485780a559c1",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import time\n",
    "\n",
    "import dask\n",
    "from dask.distributed import Client\n",
    "\n",
    "# The Dask cluster is started outside this notebook (launch script on Perlmutter)\n",
    "# and is reached through this scheduler file in $SCRATCH.\n",
    "scheduler_file = os.path.join(os.environ[\"SCRATCH\"], \"scheduler_file3d.json\")\n",
    "\n",
    "# Route the dashboard link through the JupyterHub proxy.\n",
    "dask.config.config[\"distributed\"][\"dashboard\"][\"link\"] = \"{JUPYTERHUB_SERVICE_PREFIX}proxy/{host}:{port}/status\"\n",
    "\n",
    "# Poll every 2 s until the scheduler file exists (no timeout), then connect.\n",
    "scheduler_ready = False\n",
    "while not scheduler_ready:\n",
    "    time.sleep(2)\n",
    "    scheduler_ready = os.path.isfile(scheduler_file)\n",
    "print(\"file found\")\n",
    "time.sleep(2)  # short grace period before connecting\n",
    "client = Client(scheduler_file=scheduler_file)\n",
    "\n",
    "# Block until the expected number of workers has registered.\n",
    "number_of_workers = 256\n",
    "print(\"waiting for workers\")\n",
    "client.wait_for_workers(number_of_workers)\n",
    "print(client)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "39c501bf-bd24-4e5a-ab22-deea98f32503",
   "metadata": {},
   "outputs": [],
   "source": [
    "display(client)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5b09030f-713e-4da0-a372-74f9966c1247",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Training and test splits, stored as space-delimited text files.\n",
    "x_train = np.genfromtxt(\"./data/x_train_3dclimate.csv\", delimiter=\" \")\n",
    "y_train = np.genfromtxt(\"./data/y_train_3dclimate.csv\", delimiter=\" \")\n",
    "\n",
    "x_test = np.genfromtxt(\"./data/x_test_3dclimate.csv\", delimiter=\" \")\n",
    "y_test = np.genfromtxt(\"./data/y_test_3dclimate.csv\", delimiter=\" \")\n",
    "\n",
    "# Catch malformed inputs early: genfromtxt fills unparsable/missing fields with NaN\n",
    "# instead of raising, and mismatched lengths would otherwise only fail inside the GP.\n",
    "assert x_train.ndim == 2 and x_test.ndim == 2, \"inputs must be 2-D (points x dimensions)\"\n",
    "assert len(x_train) == len(y_train), \"x_train / y_train length mismatch\"\n",
    "assert len(x_test) == len(y_test), \"x_test / y_test length mismatch\"\n",
    "assert x_train.shape[1] == x_test.shape[1], \"train / test input dimension mismatch\"\n",
    "for name, arr in [(\"x_train\", x_train), (\"y_train\", y_train), (\"x_test\", x_test), (\"y_test\", y_test)]:\n",
    "    assert np.isfinite(arr).all(), f\"{name} contains NaN/inf (check the CSV parsing)\"\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e2755913",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(\"train shape: \", x_train.shape)\n",
    "print(\"test shape:  \", x_test.shape)\n",
    "\n",
    "# Range (min, max) of the first three input columns, then of the targets.\n",
    "for dim in range(3):\n",
    "    column = x_train[:, dim]\n",
    "    print(np.min(column), np.max(column))\n",
    "print(\" \")\n",
    "print(np.min(y_train), np.max(y_train))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "750b0c05-e775-4ee8-924c-07911276fe4a",
   "metadata": {},
   "outputs": [],
   "source": [
    "def T(D, tb, n, b):\n",
    "    \"\"\"Rough wall-time estimate for one gp2Scale covariance build.\n",
    "\n",
    "    Computes D**2 * tb / (2 * n * b**2), which can be read as\n",
    "    ((D / b)**2 / 2) blocks of size b x b (presumably one triangle of the\n",
    "    symmetric matrix), each costing tb, spread over n workers.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    D : number of data points (len(x_train) at the call site).\n",
    "    tb : time per block (assumed seconds; the result has the same unit).\n",
    "    n : number of Dask workers.\n",
    "    b : batch size (15000 at the call site, the same value as gp2Scale_batch_size).\n",
    "    \"\"\"\n",
    "    numerator = D**2 * tb\n",
    "    denominator = 2. * n * b**2\n",
    "    return numerator / denominator\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "81b5c8d3-cd6c-4234-a2b3-9e9187741bc7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Expected runtime of one covariance computation for the setup used below.\n",
    "# NOTE(review): 2.2 presumably is the measured time per batch block - confirm;\n",
    "# batch_size must match gp2Scale_batch_size in the GPOptimizer cell.\n",
    "time_per_block = 2.2\n",
    "batch_size = 15000\n",
    "T(len(x_train), time_per_block, number_of_workers, batch_size)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a669f550-e7c7-4dba-8d75-4793c0f859ea",
   "metadata": {},
   "source": [
    "### Wendland, no bumps"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e1ca48fd-444c-4aee-856e-a3317c2add73",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Search box for the 14 hyperparameters: entries 0-11 share the bounds [-20, 20];\n",
    "# entry 12 is the time length scale and entry 13 the noise parameter.\n",
    "hps_bounds = np.zeros((14, 2))\n",
    "hps_bounds[0:12] = np.array([-20, 20])\n",
    "hps_bounds[12:14] = np.array([\n",
    "    [1., 10.],  # time length scale\n",
    "    [1., 4.],   # noise\n",
    "])\n",
    "\n",
    "# This will need a large-RAM host node to complete.\n",
    "# Starting point found through training; to start from scratch instead,\n",
    "# draw each entry uniformly inside hps_bounds.\n",
    "init_hyperparameters = np.array([ 4.88359018, -0.58821421,  1.41143743,  4.14469273, -0.96621457,\n",
    "        2.6177534 ,  2.25651645, -0.04128646, -3.05101759,  2.16218129,\n",
    "        0.20730839, -1.17847209,  5.8970784 ,  1.98090409])\n",
    "assert init_hyperparameters.shape == (len(hps_bounds),), \"one starting value per bound required\"\n",
    "init_in_bounds = (hps_bounds[:, 0] <= init_hyperparameters) & (init_hyperparameters <= hps_bounds[:, 1])\n",
    "assert init_in_bounds.all(), \"initial hyperparameters lie outside hps_bounds\"\n",
    "\n",
    "print(\"starting hyperparameters: \")\n",
    "print(init_hyperparameters)\n",
    "\n",
    "# Enable loguru log output from the fvgp package.\n",
    "from loguru import logger\n",
    "logger.enable(\"fvgp\")\n",
    "\n",
    "my_gp = GPOptimizer(x_train, y_train,\n",
    "        init_hyperparameters=init_hyperparameters,\n",
    "        compute_device='cpu',\n",
    "        gp_kernel_function=climate_kernel.kernel,\n",
    "        gp_noise_function=climate_kernel.my_noise,\n",
    "        gp2Scale=True,\n",
    "        gp2Scale_dask_client=client,\n",
    "        gp2Scale_batch_size=15000,  # keep in sync with the runtime estimate above\n",
    "        gp2Scale_linalg_mode=\"sparseMINRES\",\n",
    "        )\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fdff1510-321c-4cd8-b9e5-ae6943c8422b",
   "metadata": {},
   "outputs": [],
   "source": [
    "start_time = time.time()\n",
    "log_lik = my_gp.log_likelihood()\n",
    "print(\"Likelihood: \", log_lik)\n",
    "print(\"exec time: \", time.time() - start_time)\n",
    "\n",
    "# Fraction of stored (nonzero) entries in the prior covariance K:\n",
    "# 0 would mean all zeros, 1 a fully dense matrix.\n",
    "n_stored = float(my_gp.prior.K.nnz)\n",
    "n_total = float(my_gp.prior.K.shape[0]**2)\n",
    "sparsity = n_stored / n_total\n",
    "\n",
    "print(\"sparsity: \", sparsity)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ae32825c-cf68-4d51-b054-56b5cb57f6ce",
   "metadata": {},
   "outputs": [],
   "source": [
    "def in_bounds(v, bounds):\n",
    "    \"\"\"Check that every entry of ``v`` lies within its [lower, upper] bounds (inclusive).\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    v : np.ndarray\n",
    "        Hyperparameter vector.\n",
    "    bounds : np.ndarray\n",
    "        Array of shape (len(v), 2); column 0 holds the lower, column 1 the upper bounds.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    (True, None) if all entries are inside, otherwise (False, idx) where idx\n",
    "    lists the offending indices (those above the upper bound first).\n",
    "    \"\"\"\n",
    "    above = np.where(v > bounds[:, 1])[0]\n",
    "    below = np.where(v < bounds[:, 0])[0]\n",
    "    if above.size == 0 and below.size == 0:\n",
    "        return True, None\n",
    "    return False, list(above) + list(below)\n",
    "\n",
    "\n",
    "def prior_function(theta, args):\n",
    "    \"\"\"Log-prior for gpMCMC: flat (0.0) inside args[\"bounds\"], -inf outside.\n",
    "\n",
    "    TODO(review): a previous version also rejected proposals whose\n",
    "    climate_kernel.Lambda values (l1, l2) had a ratio outside [0.01, 100] or\n",
    "    exceeded 100, but it read x1, lamda1_hps, lamda2_hps, elev and dist, none of\n",
    "    which are defined in this notebook (NameError on a fresh kernel). Re-add that\n",
    "    check with inputs derived from ``theta`` if it is still needed.\n",
    "    \"\"\"\n",
    "    inside, bad_idx = in_bounds(theta, args[\"bounds\"])\n",
    "    if inside:\n",
    "        return 0.\n",
    "    print(\"                    PRIOR eval out of bounds\", bad_idx, theta[bad_idx], flush = True)\n",
    "    return -np.inf\n",
    "\n",
    "\n",
    "def proposal_distribution_normal(x0, hps, obj):\n",
    "    \"\"\"Draw a Gaussian proposal centred at ``x0``.\n",
    "\n",
    "    The proposal covariance is read from ``obj.prop_args[\"prop_Sigma\"]``.\n",
    "    ``hps`` is unused; it is kept, presumably to match the proposal-function\n",
    "    signature gpMCMC calls with.\n",
    "    \"\"\"\n",
    "    cov = obj.prop_args[\"prop_Sigma\"]\n",
    "    return np.random.multivariate_normal(mean=x0, cov=cov, size=1).reshape(len(x0))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ad340ebd",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "from gpcam.gpMCMC import gpMCMC, ProposalDistribution\n",
    "\n",
    "\n",
    "def func(hps, args):\n",
    "    \"\"\"MCMC target: log-likelihood of ``hps``; also checkpoints them to last_hps_backup.npy.\"\"\"\n",
    "    np.save(\"last_hps_backup\", hps)\n",
    "    result = my_gp.log_likelihood(hyperparameters=hps)\n",
    "    print(\"f(x): \", result, \" @ \", hps)\n",
    "    return result\n",
    "\n",
    "\n",
    "def write_results(obj):\n",
    "    \"\"\"Per-iteration callback: save the current MCMC trace to current_trace.npy.\"\"\"\n",
    "    np.save(\"current_trace\", obj.trace)\n",
    "\n",
    "\n",
    "# One proposal distribution moves all hyperparameters jointly.\n",
    "proposal_ind = list(range(len(hps_bounds)))\n",
    "\n",
    "# Initial proposal covariance: diagonal, per-axis std = 1% of that bound's width.\n",
    "axis_std = (hps_bounds[proposal_ind, 1] - hps_bounds[proposal_ind, 0]) / 100.\n",
    "init_prop_sigma = np.diag(axis_std**2)\n",
    "\n",
    "# Normal proposal distribution for all hyperparameters.\n",
    "pd1 = ProposalDistribution(proposal_ind, proposal_dist=proposal_distribution_normal,\n",
    "                           init_prop_Sigma=init_prop_sigma, adapt_callable=\"normal\", K=10, ID=\"core\")\n",
    "\n",
    "my_mcmc = gpMCMC(func, prior_function, [pd1], args={\"bounds\": hps_bounds})\n",
    "print(init_hyperparameters)\n",
    "mcmc_result = my_mcmc.run_mcmc(x0=init_hyperparameters, info=True, n_updates=1000, run_in_every_iteration=write_results)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3a48356d-d97d-4b9e-ac87-c811142da05b",
   "metadata": {},
   "outputs": [],
   "source": [
    "from scipy import sparse\n",
    "\n",
    "# Point estimate: per-hyperparameter median of the MCMC samples after burn-in.\n",
    "burn_in = 981\n",
    "assert len(mcmc_result[\"x\"]) > burn_in, \"MCMC trace is shorter than the burn-in\"\n",
    "hps1 = np.median(mcmc_result[\"x\"][burn_in:], axis=0)\n",
    "print(\"FINAL HYPERPARAMETERS\")\n",
    "print(hps1)\n",
    "\n",
    "my_gp.set_hyperparameters(hps1)\n",
    "print(my_gp.log_likelihood(my_gp.get_hyperparameters()))\n",
    "np.save(\"full_gp2ScaleHPS\", my_gp.get_hyperparameters())\n",
    "\n",
    "# Fraction of stored (nonzero) entries in K: 0 would mean all zeros.\n",
    "sparsity = float(my_gp.prior.K.nnz) / float(my_gp.prior.K.shape[0]**2)\n",
    "sparse.save_npz(\"sparse_matrix1Mill\", my_gp.prior.K)\n",
    "\n",
    "print(\"sparsity: \", sparsity)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "18b659f4-73db-4cdf-8fba-1f0989049020",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Test RMSE, evaluated in chunks of rmse_step points.\n",
    "# The running/final value pools squared errors over all points seen so far;\n",
    "# a plain mean of per-chunk RMSEs is biased low (Jensen) and over-weights a\n",
    "# short last chunk.\n",
    "# NOTE(review): assumes my_gp.rmse returns sqrt(mean squared error) of its chunk.\n",
    "rmse_step = 10\n",
    "assert len(x_test) > 0, \"empty test set\"\n",
    "list_rmse = []\n",
    "sq_err_sum = 0.\n",
    "n_evaluated = 0\n",
    "for i in range(0, len(x_test), rmse_step):\n",
    "    print(i)\n",
    "    x_chunk = x_test[i:i + rmse_step]\n",
    "    y_chunk = y_test[i:i + rmse_step]\n",
    "    rmse = my_gp.rmse(x_chunk, y_chunk)\n",
    "    print(rmse)\n",
    "    list_rmse.append(rmse)\n",
    "    sq_err_sum += rmse**2 * len(y_chunk)\n",
    "    n_evaluated += len(y_chunk)\n",
    "    print(\"RMSE: \", np.sqrt(sq_err_sum / n_evaluated))\n",
    "final_rmse = np.sqrt(sq_err_sum / n_evaluated)\n",
    "print(\"++++++++++++++++++++++++++\")\n",
    "print(\"|FINAL RMSE: \", final_rmse)\n",
    "print(\"++++++++++++++++++++++++++\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b6365bf6-8368-4a43-a668-7ca24a9ae9ef",
   "metadata": {},
   "outputs": [],
   "source": [
    "# This might take a long time\n",
    "# CRPS on the test set in chunks of crps_step points. Chunk results are averaged\n",
    "# with weights equal to the chunk sizes so a short last chunk is not over-weighted.\n",
    "# NOTE(review): assumes my_gp.crps returns per-chunk average(s); if it returns\n",
    "# several values (e.g. mean and std) they are averaged per component.\n",
    "crps_step = 2\n",
    "assert len(x_test) > 0, \"empty test set\"\n",
    "crps_list = []\n",
    "chunk_sizes = []\n",
    "for i in range(0, len(x_test), crps_step):\n",
    "    print(\"calculating \", i)\n",
    "    crps_list.append(my_gp.crps(x_test[i:i + crps_step], y_test[i:i + crps_step]))\n",
    "    chunk_sizes.append(len(y_test[i:i + crps_step]))\n",
    "    print(\"CRPS: \", np.average(np.asarray(crps_list, dtype=float), axis=0, weights=chunk_sizes))\n",
    "final_crps = np.average(np.asarray(crps_list, dtype=float), axis=0, weights=chunk_sizes)\n",
    "print(\"++++++++++++++++++++++++++\")\n",
    "print(\"|FINAL CRPS: \", final_crps)\n",
    "print(\"++++++++++++++++++++++++++\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e1139ef8-a891-4801-a74c-783a76c3011e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): `res` was never defined in this notebook (NameError on a fresh kernel).\n",
    "# Save the test-set summaries computed above instead: the mean of the per-chunk\n",
    "# RMSE values and the mean of the per-chunk CRPS values - confirm this is the intended content.\n",
    "res = np.array([[np.mean(list_rmse), np.mean(crps_list)]])\n",
    "np.savetxt(\"gp2SCaleResultsTopo_nobumps.csv\", res, header=\"rmse crps\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3903c3f3-e618-4e81-a1e6-8fd4797b98d9",
   "metadata": {},
   "outputs": [],
   "source": [
    "hps = my_gp.get_hyperparameters()\n",
    "print(hps)\n",
    "\n",
    "\n",
    "def _scatter_field(values, title):\n",
    "    \"\"\"Scatter ``values`` over the first two input columns of x_train, with a colorbar.\"\"\"\n",
    "    fig, ax = plt.subplots()\n",
    "    points = ax.scatter(x_train[:, 0], x_train[:, 1], c=values)\n",
    "    ax.set_title(title)\n",
    "    fig.colorbar(points, ax=ax)\n",
    "    plt.show()\n",
    "\n",
    "\n",
    "_scatter_field(y_train, \"training data\")\n",
    "\n",
    "# FIXME(review): the fields below come from `topo_kernelGPU`, which this notebook\n",
    "# never imported (NameError on a fresh kernel). The hps slices [0:2], [2:4], [4:5]\n",
    "# and [5:7] presumably follow that module's layout and may not match\n",
    "# climate_kernel's 14 hyperparameters - confirm before trusting these plots.\n",
    "import topo_kernelGPU\n",
    "\n",
    "x_train_t = torch.from_numpy(x_train)\n",
    "_scatter_field(topo_kernelGPU.RBF_func_lambda(x_train_t, torch.from_numpy(hps[0:2])), \"length scales 1\")\n",
    "_scatter_field(topo_kernelGPU.RBF_func_lambda(x_train_t, torch.from_numpy(hps[2:4])), \"length scales 2\")\n",
    "_scatter_field(topo_kernelGPU.RBF_func_gamma(x_train_t, torch.from_numpy(hps[4:5])), \"angles\")\n",
    "_scatter_field(topo_kernelGPU.RBF_func_sigma(x_train_t, torch.from_numpy(hps[5:7])), \"signal standard deviation\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "3dfvgp",
   "language": "python",
   "name": "3dfvgp"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
