{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6474a485",
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "import numpy as np\n",
    "from pathlib import Path\n",
    "import matplotlib.pyplot as plt\n",
    "from simulators.ice_simulator.ice_plots import plot_posterior_nice\n",
    "from simulators.ice_simulator.modelling_utils import regrid\n",
    "import torch\n",
    "import pandas as pd\n",
    "from utils.misc import get_data_dir,get_output_dir,read_pickle"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "edbe092b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Notebook-wide setup: load the ice-shelf geometry and the observed layers,\n",
    "# then build the boolean masks used by the posterior-predictive plot.\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "data_dir = get_data_dir()\n",
    "out_dir = get_output_dir()\n",
    "ice_data_dir = Path(data_dir) / \"sbi_ice\"\n",
    "\n",
    "# Index of the observed layer to compare against. It drives the layer mask\n",
    "# and both versions of the true layer, so they always stay in sync.\n",
    "real_layer_idx = 1\n",
    "# Number of points in the simulation grid.\n",
    "config_nx_iso = 500\n",
    "\n",
    "\n",
    "# Load data\n",
    "\n",
    "setup_df = pd.read_csv(ice_data_dir / \"setup_file.csv\")\n",
    "x_coords = setup_df[\"x_coord\"].to_numpy()\n",
    "surface = setup_df[\"surface\"].to_numpy()\n",
    "tmb = setup_df[\"tmb\"].to_numpy()\n",
    "base = setup_df[\"base\"].to_numpy()\n",
    "\n",
    "print(\"x_coords\", x_coords.shape)\n",
    "print(\"surface\", surface.shape)\n",
    "layer_bounds = read_pickle(ice_data_dir / \"layer_bounds.pkl\")\n",
    "sim_grid = np.linspace(x_coords[0], x_coords[-1], config_nx_iso)\n",
    "torch_sim_grid = torch.from_numpy(sim_grid).to(device).to(torch.float32)\n",
    "\n",
    "\n",
    "# One boolean mask per entry of layer_bounds: True where the grid exceeds the bound.\n",
    "masks = np.array([sim_grid > bound for bound in layer_bounds])\n",
    "mask = masks[real_layer_idx]\n",
    "\n",
    "layer_mask = torch.from_numpy(mask)\n",
    "layer_sparsity = 1\n",
    "smb_sparsity = 1\n",
    "layer_mask_slice = torch.zeros(config_nx_iso, dtype=torch.bool)\n",
    "layer_mask_slice[::layer_sparsity] = True\n",
    "\n",
    "# Keep only grid points that satisfy both the layer-bound mask and the sparsity stride.\n",
    "layer_all_mask = layer_mask & layer_mask_slice\n",
    "\n",
    "# Load real layer\n",
    "layers_df = pd.read_csv(ice_data_dir / \"real_layers.csv\")\n",
    "# presumably two of the columns are not layers (e.g. x_coord plus an index) - confirm\n",
    "n_real_layers = len(layers_df.columns) - 2\n",
    "real_layers = np.zeros(shape=(n_real_layers, config_nx_iso))\n",
    "# Regrid the real layers to the simulation grid (e.g. real data is defined on a different grid)\n",
    "for i in range(n_real_layers):\n",
    "    layer_depths = regrid(\n",
    "        layers_df[\"x_coord\"],\n",
    "        layers_df[\"layer {}\".format(i + 1)],\n",
    "        sim_grid,\n",
    "        kind=\"linear\",\n",
    "    )\n",
    "    # NOTE(review): assumes `surface` from the setup file is sampled on the same\n",
    "    # config_nx_iso-point grid as sim_grid - confirm.\n",
    "    real_layers[i, :] = surface - layer_depths\n",
    "true_layer_unmasked = torch.tensor(real_layers[real_layer_idx]).float()\n",
    "true_layer = torch.tensor(real_layers[real_layer_idx][layer_all_mask]).float()\n",
    "\n",
    "\n",
    "smb_mask_slice = torch.zeros(config_nx_iso, dtype=torch.bool)\n",
    "smb_mask_slice[::smb_sparsity] = True\n",
    "smb_mask = smb_mask_slice  # infer all smb values\n",
    "smb_x = sim_grid[smb_mask]\n",
    "layer_x = sim_grid[layer_all_mask]\n",
    "\n",
    "\n",
    "print(layer_all_mask)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "28affdc9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Read the summary of the posterior-predictive simulations from the output directory.\n",
    "summary_file = \"ice_experiment/FNOPE_real_layers_predictive_simulations_summary.pkl\"\n",
    "predictive_summary_path = out_dir / summary_file\n",
    "predictive_summary = read_pickle(predictive_summary_path)\n",
    "\n",
    "# Show which entries the summary provides.\n",
    "print(predictive_summary.keys())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f440452d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pull the quantities needed for the figure out of the predictive summary.\n",
    "predictive_layers = predictive_summary[\"best_layers\"]\n",
    "all_bmbs = predictive_summary[\"bmbs\"]\n",
    "print(all_bmbs.shape)\n",
    "print(tmb.shape)\n",
    "\n",
    "# Only index 0 of the last axis is used.\n",
    "predictive_bmbs = all_bmbs[:, :, 0]\n",
    "\n",
    "# smb = tmb - bmb, with tmb given a leading axis so it broadcasts.\n",
    "tmb_row = torch.from_numpy(tmb)[None, :]\n",
    "predictive_smbs = -predictive_bmbs + tmb_row"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c054f566",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sanity check: size of the masked observed layer.\n",
    "print(true_layer.size())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ca0a155d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Posterior-predictive figure: predictive layers and SMB samples plotted with\n",
    "# the observed layer and the shelf geometry, then saved as an SVG.\n",
    "\n",
    "# savefig does not create missing directories, so create the output folder first.\n",
    "plot_path = Path(\"ice_plots\") / \"post_predictives.svg\"\n",
    "plot_path.parent.mkdir(parents=True, exist_ok=True)\n",
    "\n",
    "with plt.rc_context(fname=\"matplotlibrc\"):\n",
    "\n",
    "    plot_posterior_nice(\n",
    "        x=x_coords,\n",
    "        mb_mask=smb_mask,\n",
    "        tmb=tmb,\n",
    "        posterior_smb_samples=predictive_smbs,\n",
    "        # first x coordinate selected by the layer mask\n",
    "        LMI_boundary=x_coords[layer_all_mask][0],\n",
    "        layer_mask=layer_all_mask,\n",
    "        posterior_layer_samples=predictive_layers.cpu().numpy()[:, layer_all_mask],\n",
    "        posterior_layer_ages=predictive_summary[\"ages\"],\n",
    "        true_layer=true_layer_unmasked.cpu().numpy(),\n",
    "        shelf_base=base,\n",
    "        shelf_surface=surface,\n",
    "        plot_only_predictive=True,\n",
    "        figsize=(3.5, 1.3),\n",
    "    )\n",
    "\n",
    "    # Save inside the rc context so the figure is written with the same style settings.\n",
    "    plt.savefig(\n",
    "        plot_path,\n",
    "        bbox_inches=\"tight\",\n",
    "        dpi=300,\n",
    "    )\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8f137aed",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "fourier_nets",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
