{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "## Metrics\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import torch\n",
    "from matplotlib.transforms import Bbox\n",
    "# import and set up the typeguard\n",
    "from typeguard.importhook import install_import_hook\n",
    "\n",
    "from src.nn import ConditionalGenerativeModel, createGenerativeFCNN, InputTargetDataset, \\\n",
    "    UNet2D, DiscardWindowSizeDim, get_predictions_and_target, createGenerativeGRUNN, DiscardNumberGenerationsInOutput, \\\n",
    "    createGRUNN, createFCNN\n",
    "from src.scoring_rules import EnergyScore, KernelScore, VariogramScore, PatchedScoringRule, estimate_score_chunks\n",
    "from src.utils import load_net, estimate_bandwidth_timeseries, lorenz96_variogram, def_loader_kwargs, \\\n",
    "    weatherbench_variogram_haversine\n",
    "from src.parsers import parser_predict, define_masks, nonlinearities_dict, setup\n",
    "from src.calibration import calibration_error, R2, rmse, plot_metrics_params\n",
    "from src.weatherbench_utils import load_weatherbench_data\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "<>:241: SyntaxWarning: invalid escape sequence '\\p'\n",
      "<>:241: SyntaxWarning: invalid escape sequence '\\p'\n",
      "<>:241: SyntaxWarning: invalid escape sequence '\\p'\n",
      "<>:241: SyntaxWarning: invalid escape sequence '\\p'\n",
      "<>:241: SyntaxWarning: invalid escape sequence '\\p'\n",
      "<>:241: SyntaxWarning: invalid escape sequence '\\p'\n",
      "C:\\Users\\arche\\AppData\\Local\\Temp\\ipykernel_58736\\3755770007.py:241: SyntaxWarning: invalid escape sequence '\\p'\n",
      "  string2 += f\"\\n\\t\\t {cal_err_values.mean():.4f} $ \\pm$ {cal_err_values.std():.4f} & {rmse_values.mean():.4f}  $ \\pm$ {rmse_values.std():.4f} &  {r2_values.mean():.4f} $ \\pm$ {r2_values.std():.4f} \\\\\\\\ \\n\"\n",
      "C:\\Users\\arche\\AppData\\Local\\Temp\\ipykernel_58736\\3755770007.py:241: SyntaxWarning: invalid escape sequence '\\p'\n",
      "  string2 += f\"\\n\\t\\t {cal_err_values.mean():.4f} $ \\pm$ {cal_err_values.std():.4f} & {rmse_values.mean():.4f}  $ \\pm$ {rmse_values.std():.4f} &  {r2_values.mean():.4f} $ \\pm$ {r2_values.std():.4f} \\\\\\\\ \\n\"\n",
      "C:\\Users\\arche\\AppData\\Local\\Temp\\ipykernel_58736\\3755770007.py:241: SyntaxWarning: invalid escape sequence '\\p'\n",
      "  string2 += f\"\\n\\t\\t {cal_err_values.mean():.4f} $ \\pm$ {cal_err_values.std():.4f} & {rmse_values.mean():.4f}  $ \\pm$ {rmse_values.std():.4f} &  {r2_values.mean():.4f} $ \\pm$ {r2_values.std():.4f} \\\\\\\\ \\n\"\n",
      "c:\\Users\\arche\\Archer-4th-Year-Diss-shreya\\Archer-4th-Year-Diss-shreya\\GenerativeNetworksScoringRulesProbabilisticForecasting\\src\\utils.py:48: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n",
      "  net.load_state_dict(torch.load(path, map_location=torch.device(\"cpu\")))\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "rnn\n",
      "Test SR network for lorenz model using SignatureKernel scoring rule\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "C:\\Users\\arche\\AppData\\Local\\Temp\\ipykernel_58736\\3755770007.py:241: SyntaxWarning: invalid escape sequence '\\p'\n",
      "  string2 += f\"\\n\\t\\t {cal_err_values.mean():.4f} $ \\pm$ {cal_err_values.std():.4f} & {rmse_values.mean():.4f}  $ \\pm$ {rmse_values.std():.4f} &  {r2_values.mean():.4f} $ \\pm$ {r2_values.std():.4f} \\\\\\\\ \\n\"\n",
      "C:\\Users\\arche\\AppData\\Local\\Temp\\ipykernel_58736\\3755770007.py:241: SyntaxWarning: invalid escape sequence '\\p'\n",
      "  string2 += f\"\\n\\t\\t {cal_err_values.mean():.4f} $ \\pm$ {cal_err_values.std():.4f} & {rmse_values.mean():.4f}  $ \\pm$ {rmse_values.std():.4f} &  {r2_values.mean():.4f} $ \\pm$ {r2_values.std():.4f} \\\\\\\\ \\n\"\n",
      "C:\\Users\\arche\\AppData\\Local\\Temp\\ipykernel_58736\\3755770007.py:241: SyntaxWarning: invalid escape sequence '\\p'\n",
      "  string2 += f\"\\n\\t\\t {cal_err_values.mean():.4f} $ \\pm$ {cal_err_values.std():.4f} & {rmse_values.mean():.4f}  $ \\pm$ {rmse_values.std():.4f} &  {r2_values.mean():.4f} $ \\pm$ {r2_values.std():.4f} \\\\\\\\ \\n\"\n"
     ]
    },
    {
     "ename": "RuntimeError",
     "evalue": "Input and parameter tensors are not at the same device, found input tensor at cpu and parameter tensor at cuda:0",
     "output_type": "error",
     "traceback": [
      "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[1;31mRuntimeError\u001b[0m                              Traceback (most recent call last)",
      "Cell \u001b[1;32mIn[5], line 154\u001b[0m\n\u001b[0;32m    152\u001b[0m         target_data_test \u001b[38;5;241m=\u001b[39m target_data_test\u001b[38;5;241m.\u001b[39mflatten(\u001b[38;5;241m1\u001b[39m, \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m)\n\u001b[0;32m    153\u001b[0m     \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m--> 154\u001b[0m         predictions_val \u001b[38;5;241m=\u001b[39m \u001b[43mnet\u001b[49m\u001b[43m(\u001b[49m\u001b[43minput_data_val\u001b[49m\u001b[43m)\u001b[49m  \u001b[38;5;66;03m# shape (n_val, ensemble_size, data_size)\u001b[39;00m\n\u001b[0;32m    155\u001b[0m         predictions_test \u001b[38;5;241m=\u001b[39m net(input_data_test)  \u001b[38;5;66;03m# shape (n_test, ensemble_size, data_size)\u001b[39;00m\n\u001b[0;32m    157\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m method \u001b[38;5;241m!=\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mregression\u001b[39m\u001b[38;5;124m\"\u001b[39m:\n\u001b[0;32m    158\u001b[0m     \u001b[38;5;66;03m# --- scoring rules ---\u001b[39;00m\n",
      "File \u001b[1;32mc:\\Users\\arche\\anaconda3\\envs\\Diss\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1736\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m   1734\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)  \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[0;32m   1735\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m-> 1736\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[1;32mc:\\Users\\arche\\anaconda3\\envs\\Diss\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1747\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m   1742\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[0;32m   1743\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[0;32m   1744\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[0;32m   1745\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[0;32m   1746\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[1;32m-> 1747\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m   1749\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[0;32m   1750\u001b[0m called_always_called_hooks \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mset\u001b[39m()\n",
      "File \u001b[1;32mc:\\Users\\arche\\Archer-4th-Year-Diss-shreya\\Archer-4th-Year-Diss-shreya\\GenerativeNetworksScoringRulesProbabilisticForecasting\\src\\nn.py:731\u001b[0m, in \u001b[0;36mConditionalGenerativeModel.forward\u001b[1;34m(self, context, number_generations)\u001b[0m\n\u001b[0;32m    725\u001b[0m z \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdistribution\u001b[38;5;241m.\u001b[39msample(torch\u001b[38;5;241m.\u001b[39mSize([batch_size, number_generations]) \u001b[38;5;241m+\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39msize_auxiliary_variable)\u001b[38;5;241m.\u001b[39mto(\n\u001b[0;32m    726\u001b[0m     device\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mcuda\u001b[39m\u001b[38;5;124m\"\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mnext\u001b[39m(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mparameters())\u001b[38;5;241m.\u001b[39mis_cuda \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mcpu\u001b[39m\u001b[38;5;124m\"\u001b[39m)\u001b[38;5;241m.\u001b[39msqueeze(\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m)\n\u001b[0;32m    728\u001b[0m \u001b[38;5;66;03m#Add prediction length here\u001b[39;00m\n\u001b[0;32m    729\u001b[0m \u001b[38;5;66;03m#\u001b[39;00m\n\u001b[1;32m--> 731\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mnet\u001b[49m\u001b[43m(\u001b[49m\u001b[43mcontext\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mz\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[1;32mc:\\Users\\arche\\anaconda3\\envs\\Diss\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1736\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m   1734\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)  \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[0;32m   1735\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m-> 1736\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[1;32mc:\\Users\\arche\\anaconda3\\envs\\Diss\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1747\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m   1742\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[0;32m   1743\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[0;32m   1744\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[0;32m   1745\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[0;32m   1746\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[1;32m-> 1747\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m   1749\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[0;32m   1750\u001b[0m called_always_called_hooks \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mset\u001b[39m()\n",
      "File \u001b[1;32mc:\\Users\\arche\\Archer-4th-Year-Diss-shreya\\Archer-4th-Year-Diss-shreya\\GenerativeNetworksScoringRulesProbabilisticForecasting\\src\\nn.py:343\u001b[0m, in \u001b[0;36mcreateGenerativeGRUNN.<locals>.GenerativeGRUNN.forward\u001b[1;34m(self, context, z)\u001b[0m\n\u001b[0;32m    340\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mforward\u001b[39m(\u001b[38;5;28mself\u001b[39m, context: TensorType[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mbatch_size\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mwindow_size\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mdata_size\u001b[39m\u001b[38;5;124m\"\u001b[39m],\n\u001b[0;32m    341\u001b[0m             z: TensorType[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mbatch_size\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mnumber_generations\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124msize_auxiliary_variable\u001b[39m\u001b[38;5;124m\"\u001b[39m]) \\\n\u001b[0;32m    342\u001b[0m         \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m TensorType[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mbatch_size\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mnumber_generations\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124moutput_size\u001b[39m\u001b[38;5;124m\"\u001b[39m]:\n\u001b[1;32m--> 343\u001b[0m     gru_out, _ \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mgru\u001b[49m\u001b[43m(\u001b[49m\u001b[43mcontext\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m    345\u001b[0m     \u001b[38;5;66;03m# this has shape [batch_size, window_size, gru_hidden_size]; take only the last temporal element:\u001b[39;00m\n\u001b[0;32m    346\u001b[0m     gru_out \u001b[38;5;241m=\u001b[39m gru_out[:, 
\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m, :]\u001b[38;5;241m.\u001b[39munsqueeze(\u001b[38;5;241m1\u001b[39m)\n",
      "File \u001b[1;32mc:\\Users\\arche\\anaconda3\\envs\\Diss\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1736\u001b[0m, in \u001b[0;36mModule._wrapped_call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m   1734\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compiled_call_impl(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)  \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[0;32m   1735\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m-> 1736\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[1;32mc:\\Users\\arche\\anaconda3\\envs\\Diss\\Lib\\site-packages\\torch\\nn\\modules\\module.py:1747\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[1;34m(self, *args, **kwargs)\u001b[0m\n\u001b[0;32m   1742\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[0;32m   1743\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[0;32m   1744\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks\n\u001b[0;32m   1745\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[0;32m   1746\u001b[0m         \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[1;32m-> 1747\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m   1749\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[0;32m   1750\u001b[0m called_always_called_hooks \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mset\u001b[39m()\n",
      "File \u001b[1;32mc:\\Users\\arche\\anaconda3\\envs\\Diss\\Lib\\site-packages\\torch\\nn\\modules\\rnn.py:1392\u001b[0m, in \u001b[0;36mGRU.forward\u001b[1;34m(self, input, hx)\u001b[0m\n\u001b[0;32m   1390\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcheck_forward_args(\u001b[38;5;28minput\u001b[39m, hx, batch_sizes)\n\u001b[0;32m   1391\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m batch_sizes \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m-> 1392\u001b[0m     result \u001b[38;5;241m=\u001b[39m \u001b[43m_VF\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mgru\u001b[49m\u001b[43m(\u001b[49m\n\u001b[0;32m   1393\u001b[0m \u001b[43m        \u001b[49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[0;32m   1394\u001b[0m \u001b[43m        \u001b[49m\u001b[43mhx\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m   1395\u001b[0m \u001b[43m        \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_flat_weights\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m   1396\u001b[0m \u001b[43m        \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbias\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m   1397\u001b[0m \u001b[43m        \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mnum_layers\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m   1398\u001b[0m \u001b[43m        \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdropout\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m   1399\u001b[0m \u001b[43m        \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtraining\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m   1400\u001b[0m \u001b[43m        \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbidirectional\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m   1401\u001b[0m \u001b[43m        
\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbatch_first\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m   1402\u001b[0m \u001b[43m    \u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m   1403\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m   1404\u001b[0m     result \u001b[38;5;241m=\u001b[39m _VF\u001b[38;5;241m.\u001b[39mgru(\n\u001b[0;32m   1405\u001b[0m         \u001b[38;5;28minput\u001b[39m,\n\u001b[0;32m   1406\u001b[0m         batch_sizes,\n\u001b[1;32m   (...)\u001b[0m\n\u001b[0;32m   1413\u001b[0m         \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mbidirectional,\n\u001b[0;32m   1414\u001b[0m     )\n",
      "\u001b[1;31mRuntimeError\u001b[0m: Input and parameter tensors are not at the same device, found input tensor at cpu and parameter tensor at cuda:0"
     ]
    }
   ],
   "source": [
    "###############################\n",
    "# HARDCODED CONFIGURATION\n",
    "###############################\n",
    "# Example values, adapt to your needs:\n",
    "model = 'lorenz'\n",
    "method = 'SR'\n",
    "scoring_rule = 'SignatureKernel'\n",
    "kernel = 'RBFtest'  ##??\n",
    "patched = False\n",
    "base_measure = 'normal'\n",
    "root_folder = 'results'         # Where results are stored\n",
    "model_folder = 'nets'           # Subfolder for models\n",
    "datasets_folder = 'results/lorenz/datasets/'\n",
    "nets_folder = 'results/nets/'\n",
    "weatherbench_data_folder = None\n",
    "weatherbench_small = False\n",
    "name_postfix = '_mytrainedmodelkernelrbf' ##Change this\n",
    "\n",
    "unet_noise_method = 'dropout'  # or 'concat', etc., if relevant\n",
    "unet_large = False\n",
    "\n",
    "lr = 0.1\n",
    "lr_c = 0.0\n",
    "batch_size = 10\n",
    "no_early_stop = True\n",
    "critic_steps_every_generator_step = 1\n",
    "\n",
    "save_plots = True\n",
    "cuda = True\n",
    "load_all_data_GPU = False\n",
    "\n",
    "training_ensemble_size = 3\n",
    "prediction_ensemble_size = 3\n",
    "nonlinearity = 'leaky_relu'\n",
    "data_size = 1               # For Lorenz63, typically data_size=1 or 3\n",
    "auxiliary_var_size = 1\n",
    "seed = 0\n",
    "\n",
    "plot_start_timestep = 0\n",
    "plot_end_timestep = 100\n",
    "\n",
    "gamma = None\n",
    "gamma_patched = None\n",
    "patch_size = 16\n",
    "no_RNN = False\n",
    "hidden_size_rnn = 8\n",
    "\n",
    "save_pdf = True\n",
    "\n",
    "save_pdf = True\n",
    "\n",
    "compute_patched = model in [\"lorenz96\", ]\n",
    "\n",
    "model_is_weatherbench = model == \"WeatherBench\"\n",
    "\n",
    "nn_model = \"unet\" if model_is_weatherbench else (\"fcnn\" if no_RNN else \"rnn\")\n",
    "print(nn_model)\n",
    "\n",
    "method_is_gan = False\n",
    "\n",
    "\n",
    "\n",
    "# datasets_folder, nets_folder, data_size, auxiliary_var_size, name_postfix, unet_depths, patch_size, method_is_gan, hidden_size_rnn = \\\n",
    "#     setup(model, root_folder, model_folder, datasets_folder, data_size, method, scoring_rule, kernel, patched,\n",
    "#           patch_size, training_ensemble_size, auxiliary_var_size, critic_steps_every_generator_step, base_measure, lr,\n",
    "#           lr_c, batch_size, no_early_stop, unet_noise_method, unet_large, nn_model, hidden_size_rnn)\n",
    "\n",
    "model_name_for_plot = {\"lorenz\": \"Lorenz63\",\n",
    "                       \"lorenz96\": \"Lorenz96\",\n",
    "                       \"WeatherBench\": \"WeatherBench\"}[model]\n",
    "\n",
    "string = f\"Test {method} network for {model} model\"\n",
    "if not method_is_gan:\n",
    "    string += f\" using {scoring_rule} scoring rule\"\n",
    "print(string)\n",
    "\n",
    "# --- data handling ---\n",
    "if not model_is_weatherbench:\n",
    "    input_data_test = torch.load(datasets_folder + \"test_x.pty\", weights_only=True)\n",
    "    target_data_test = torch.load(datasets_folder + \"test_y.pty\",weights_only=True)\n",
    "    input_data_val = torch.load(datasets_folder + \"val_x.pty\",weights_only=True)\n",
    "    target_data_val = torch.load(datasets_folder + \"val_y.pty\",weights_only=True)\n",
    "\n",
    "    window_size = input_data_test.shape[1]\n",
    "\n",
    "    # create the test loaders; these are unused for the moment.\n",
    "    dataset_val = InputTargetDataset(input_data_val, target_data_val, \"cuda\" if cuda and load_all_data_GPU else \"cpu\")\n",
    "    dataset_test = InputTargetDataset(input_data_test, target_data_test,\n",
    "                                      \"cuda\" if cuda and load_all_data_GPU else \"cpu\")\n",
    "\n",
    "loader_kwargs = def_loader_kwargs(cuda, load_all_data_GPU)\n",
    "\n",
    "# loader_kwargs.update(loader_kwargs_2)  # if you want to add other loader arguments\n",
    "\n",
    "data_loader_val = torch.utils.data.DataLoader(dataset_val, batch_size=batch_size, shuffle=False, **loader_kwargs)\n",
    "data_loader_test = torch.utils.data.DataLoader(dataset_test, batch_size=batch_size, shuffle=False, **loader_kwargs)\n",
    "\n",
    "wrap_net = True\n",
    "# create generative net:\n",
    "if nn_model == \"fcnn\":\n",
    "    input_size = window_size * data_size + auxiliary_var_size\n",
    "    output_size = data_size\n",
    "    hidden_sizes_list = [int(input_size * 1.5), int(input_size * 3), int(input_size * 3),\n",
    "                        int(input_size * 0.75 + output_size * 3), int(output_size * 5)]\n",
    "    inner_net = createGenerativeFCNN(input_size=input_size, output_size=output_size, hidden_sizes=hidden_sizes_list,\n",
    "                                    nonlinearity=nonlinearities_dict[nonlinearity])()\n",
    "elif nn_model == \"rnn\":\n",
    "    output_size = data_size\n",
    "    gru_layers = 1\n",
    "    gru_hidden_size = hidden_size_rnn\n",
    "    inner_net = createGenerativeGRUNN(data_size=data_size, gru_hidden_size=gru_hidden_size,\n",
    "                                    noise_size=auxiliary_var_size,\n",
    "                                    output_size=output_size, hidden_sizes=None, gru_layers=gru_layers,\n",
    "                                    nonlinearity=nonlinearities_dict[nonlinearity])()\n",
    "elif nn_model == \"unet\":\n",
    "    # select the noise method here:\n",
    "    inner_net = UNet2D(in_channels=data_size[0], out_channels=1, noise_method=unet_noise_method,\n",
    "                    number_generations_per_forward_call=prediction_ensemble_size, conv_depths=unet_depths)\n",
    "    if unet_noise_method in [\"sum\", \"concat\"]:\n",
    "        # here we overwrite the auxiliary_var_size above, as there is a precise constraint\n",
    "        downsampling_factor, n_channels = inner_net.calculate_downsampling_factor()\n",
    "        if weatherbench_small:\n",
    "            auxiliary_var_size = torch.Size(\n",
    "                [n_channels, 16 // downsampling_factor, 16 // downsampling_factor])\n",
    "        else:\n",
    "            auxiliary_var_size = torch.Size(\n",
    "                [n_channels, data_size[1] // downsampling_factor, data_size[2] // downsampling_factor])\n",
    "    elif unet_noise_method == \"dropout\":\n",
    "        wrap_net = False  # do not wrap in the conditional generative model\n",
    "if wrap_net:\n",
    "    net = load_net(nets_folder + f\"net{name_postfix}.pth\", ConditionalGenerativeModel, inner_net,\n",
    "                size_auxiliary_variable=auxiliary_var_size, base_measure=base_measure,\n",
    "                number_generations_per_forward_call=prediction_ensemble_size, seed=seed + 1)\n",
    "else:\n",
    "    net = load_net(nets_folder + f\"net{name_postfix}.pth\", DiscardWindowSizeDim, inner_net)\n",
    "\n",
    "if cuda:\n",
    "    net.cuda()\n",
    "\n",
    "# --- predictions ---\n",
    "# predict all the different elements of the test set and create plots.\n",
    "# can directly feed through the whole test set for now; if it does not work well then, I will batch it.\n",
    "with torch.no_grad():\n",
    "    if model_is_weatherbench:\n",
    "        # shape (n_val, ensemble_size, lon, lat, n_fields)\n",
    "        predictions_val, target_data_val = get_predictions_and_target(data_loader_val, net, cuda)\n",
    "        predictions_test, target_data_test = get_predictions_and_target(data_loader_test, net, cuda)\n",
    "        # _map is with the original shape. The following instead is flattened:\n",
    "        predictions_val = predictions_val.flatten(2, -1)\n",
    "        target_data_val = target_data_val.flatten(1, -1)\n",
    "        predictions_test = predictions_test.flatten(2, -1)\n",
    "        target_data_test = target_data_test.flatten(1, -1)\n",
    "    else:\n",
    "        predictions_val = net(input_data_val)  # shape (n_val, ensemble_size, data_size)\n",
    "        predictions_test = net(input_data_test)  # shape (n_test, ensemble_size, data_size)\n",
    "\n",
    "if method != \"regression\":\n",
    "    # --- scoring rules ---\n",
    "    if compute_patched:\n",
    "        # mask for patched SRs:\n",
    "        masks = define_masks[model](data_size=data_size)\n",
    "\n",
    "    if gamma is None:\n",
    "        print(\"Compute gamma...\")\n",
    "        gamma = estimate_bandwidth_timeseries(target_data_val, return_values=[\"median\"])\n",
    "        print(f\"Estimated gamma: {gamma:.4f}\")\n",
    "    if gamma_patched is None and compute_patched:\n",
    "        # determine the gamma using the first patch only. This assumes that the values of the variables\n",
    "        # are roughly the same in the different patches.\n",
    "        gamma_patched = estimate_bandwidth_timeseries(target_data_val[:, masks[0]], return_values=[\"median\"])\n",
    "        print(f\"Estimated gamma patched: {gamma_patched:.4f}\")\n",
    "\n",
    "    # instantiate SRs; each SR takes as input: (net_output, target)\n",
    "    kernel_gaussian_sr = KernelScore(sigma=gamma)\n",
    "    kernel_rat_quad_sr = KernelScore(kernel=\"rational_quadratic\", alpha=gamma ** 2)\n",
    "    energy_sr = EnergyScore()\n",
    "\n",
    "    variogram = None\n",
    "    if model in [\"lorenz96\", ]:\n",
    "        variogram = lorenz96_variogram(data_size)\n",
    "    elif model == \"WeatherBench\":\n",
    "        # variogram = weatherbench_variogram(weatherbench_small=weatherbench_small)\n",
    "        variogram = weatherbench_variogram_haversine(weatherbench_small=weatherbench_small)\n",
    "    if variogram is not None and cuda:\n",
    "        variogram = variogram.cuda()\n",
    "\n",
    "    variogram_sr = VariogramScore(variogram=variogram)\n",
    "\n",
    "    if compute_patched:\n",
    "        # patched SRs:\n",
    "        kernel_gaussian_sr_patched = PatchedScoringRule(KernelScore(sigma=gamma_patched), masks)\n",
    "        kernel_rat_quad_sr_patched = PatchedScoringRule(\n",
    "            KernelScore(kernel=\"rational_quadratic\", alpha=gamma_patched ** 2),\n",
    "            masks)\n",
    "        energy_sr_patched = PatchedScoringRule(energy_sr, masks)\n",
    "\n",
    "    # -- out of sample score --\n",
    "    with torch.no_grad():\n",
    "        string = \"\"\n",
    "        for name, predictions, target in zip([\"VALIDATION\", \"TEST\"], [predictions_val, predictions_test],\n",
    "                                             [target_data_val, target_data_test]):\n",
    "            string += name + \"\\n\"\n",
    "            kernel_gaussian_score = estimate_score_chunks(kernel_gaussian_sr, predictions, target)\n",
    "            kernel_rat_quad_score = estimate_score_chunks(kernel_rat_quad_sr, predictions, target)\n",
    "            energy_score = estimate_score_chunks(energy_sr, predictions, target)\n",
    "            variogram_score = estimate_score_chunks(variogram_sr, predictions, target, chunk_size=8)\n",
    "\n",
    "            string += f\"Whole data scores: \\nEnergy score: {energy_score:.2f}, \" \\\n",
    "                      f\"Gaussian Kernel score {kernel_gaussian_score:.2f},\" \\\n",
    "                      f\" Rational quadratic Kernel score {kernel_rat_quad_score:.2f}, \" \\\n",
    "                      f\"Variogram score {variogram_score:.2f}\\n\"\n",
    "\n",
    "            if compute_patched:\n",
    "                kernel_gaussian_score_patched = estimate_score_chunks(kernel_gaussian_sr_patched, predictions, target)\n",
    "                kernel_rat_quad_score_patched = estimate_score_chunks(kernel_rat_quad_sr_patched, predictions, target)\n",
    "                energy_score_patched = estimate_score_chunks(energy_sr_patched, predictions, target)\n",
    "                string += f\"\\nPatched data scores: \\nEnergy score: {energy_score_patched:.2f}, \" \\\n",
    "                          f\"Gaussian Kernel score {kernel_gaussian_score_patched:.2f},\" \\\n",
    "                          f\" Rational quadratic Kernel score {kernel_rat_quad_score_patched:.2f}\\n\"\n",
    "\n",
    "        print(string)\n",
    "\n",
    "with torch.no_grad():\n",
    "    # -- calibration metrics --\n",
    "    # Per-variable calibration error, RMSE and R2 on the test set, collected in the text report `string2`.\n",
    "    # target_data_test shape (n_test, data_size)\n",
    "    # predictions_test shape (n_test, ensemble_size, data_size)\n",
    "    data_size = predictions_test.shape[-1]\n",
    "    # swap the first two axes -> (ensemble_size, n_test, data_size) before calling the metric helpers\n",
    "    predictions_for_calibration = predictions_test.transpose(1, 0).cpu().detach().numpy()\n",
    "    target_data_test_for_calibration = target_data_test.cpu().detach().numpy()\n",
    "    cal_err_values = calibration_error(predictions_for_calibration, target_data_test_for_calibration)\n",
    "    rmse_values = rmse(predictions_for_calibration, target_data_test_for_calibration)\n",
    "    r2_values = R2(predictions_for_calibration, target_data_test_for_calibration)\n",
    "\n",
    "    # summary statistics over the variables: computed once, reused by every report line below\n",
    "    cal_err_mean, cal_err_std = cal_err_values.mean(), cal_err_values.std()\n",
    "    rmse_mean, rmse_std = rmse_values.mean(), rmse_values.std()\n",
    "    r2_mean, r2_std = r2_values.mean(), r2_values.std()\n",
    "\n",
    "    string2 = \"Calibration metrics:\\n\"\n",
    "    for i in range(data_size):\n",
    "        string2 += f\"x{i}: Cal. error {cal_err_values[i]:.4f}, RMSE {rmse_values[i]:.4f}, R2 {r2_values[i]:.4f}\\n\"\n",
    "    string2 += f\"\\nAverage values: Cal. error {cal_err_mean:.4f}, RMSE {rmse_mean:.4f}, R2 {r2_mean:.4f}\\n\"\n",
    "    string2 += f\"\\nStandard deviation: Cal. error {cal_err_std:.4f}, RMSE {rmse_std:.4f}, R2 {r2_std:.4f}\\n\\n\"\n",
    "\n",
    "    # the same averages again, formatted as LaTeX table rows (with and without the +- std part)\n",
    "    string2 += \"\\nAverage values: Cal. error, RMSE, R2 \\n\"\n",
    "    # the backslash of \\pm is escaped explicitly: a bare one is an invalid escape sequence (SyntaxWarning)\n",
    "    string2 += f\"\\n\\t\\t {cal_err_mean:.4f} $ \\\\pm$ {cal_err_std:.4f} & {rmse_mean:.4f}  $ \\\\pm$ {rmse_std:.4f} &  {r2_mean:.4f} $ \\\\pm$ {r2_std:.4f} \\\\\\\\ \\n\"\n",
    "    string2 += f\"\\n\\t\\t {cal_err_mean:.4f}  & {rmse_mean:.4f}  &  {r2_mean:.4f}  \\\\\\\\ \\n\"\n",
    "    print(string2)\n",
    "\n",
    "    # -- plots --\n",
    "with torch.no_grad():\n",
    "    # Draw the test-set forecasts against the true trajectories; when save_plots is set, store the\n",
    "    # figures and the textual metrics (string, string2) under nets_folder.\n",
    "    if model_is_weatherbench:\n",
    "        # too many variables to draw them all: keep 8 of them, evenly spaced over the variable index\n",
    "        variable_list = np.linspace(0, target_data_test.shape[-1] - 1, 8, dtype=int)\n",
    "        predictions_test = predictions_test[:, :, variable_list]\n",
    "        target_data_test = target_data_test[:, variable_list]\n",
    "        predictions_for_calibration = predictions_for_calibration[:, :, variable_list]\n",
    "        target_data_test_for_calibration = target_data_test_for_calibration[:, variable_list]\n",
    "\n",
    "    predictions_test_for_plot = predictions_test.cpu()\n",
    "    target_data_test_for_plot = target_data_test.cpu()\n",
    "    time_vec = torch.arange(len(predictions_test)).cpu()\n",
    "    data_size = predictions_test_for_plot.shape[-1]\n",
    "    # time window shown on the x axis of every plot below\n",
    "    plot_window = slice(plot_start_timestep, plot_end_timestep)\n",
    "\n",
    "    if model == \"lorenz\":\n",
    "        var_names = [r\"$y$\"]\n",
    "    else:\n",
    "        # todo write here the correct lon and lat coordinates for WeatherBench!\n",
    "        # braces around the index so that two-digit subscripts render fully (\"$x_{10}$\", not \"$x_10$\")\n",
    "        var_names = [f\"$x_{{{i + 1}}}$\" for i in range(data_size)]\n",
    "\n",
    "    # predictions: mean +- std\n",
    "    label_size = 13\n",
    "    if method != \"regression\":\n",
    "        # statistics over the ensemble axis (dim 1)\n",
    "        predictions_mean = torch.mean(predictions_test_for_plot, dim=1).detach().numpy()\n",
    "        predictions_std = torch.std(predictions_test_for_plot, dim=1).detach().numpy()\n",
    "\n",
    "        fig, ax = plt.subplots(nrows=data_size, ncols=1, sharex=\"col\", figsize=(6.4, 3) if data_size == 1 else None)\n",
    "        if data_size == 1:\n",
    "            # plt.subplots returns a bare Axes for a single panel; wrap it so that indexing works below\n",
    "            ax = [ax]\n",
    "        for var in range(data_size):\n",
    "            # dashed: true trajectory; solid: ensemble mean; band: mean +- one std\n",
    "            ax[var].plot(time_vec[plot_window], target_data_test_for_plot[plot_window, var], ls=\"--\",\n",
    "                         color=f\"C{var}\")\n",
    "            ax[var].plot(time_vec[plot_window], predictions_mean[plot_window, var], ls=\"-\", color=f\"C{var}\")\n",
    "            ax[var].fill_between(\n",
    "                time_vec[plot_window], alpha=0.3, color=f\"C{var}\",\n",
    "                y1=predictions_mean[plot_window, var] - predictions_std[plot_window, var],\n",
    "                y2=predictions_mean[plot_window, var] + predictions_std[plot_window, var])\n",
    "            ax[var].set_ylabel(var_names[var], size=label_size)\n",
    "\n",
    "        ax[-1].set_xlabel(\"Integration time index\")\n",
    "        fig.suptitle(r\"Mean $\\pm$ std, \" + model)\n",
    "        # plt.show()\n",
    "        if save_plots:\n",
    "            fig.savefig(nets_folder + f\"prediction{name_postfix}.png\")\n",
    "        plt.close(fig)\n",
    "\n",
    "    # predictions: median and 99% quantile region\n",
    "    np_predictions = predictions_test_for_plot.detach().numpy()\n",
    "    size = 99\n",
    "    predictions_median = np.median(np_predictions, axis=1)\n",
    "    if method != \"regression\":\n",
    "        # central region of width `size` percent: the 0.5th and 99.5th percentiles of the ensemble\n",
    "        predictions_lower = np.percentile(np_predictions, 50 - size / 2, axis=1)\n",
    "        predictions_upper = np.percentile(np_predictions, 50 + size / 2, axis=1)\n",
    "\n",
    "    fig, ax = plt.subplots(nrows=data_size, ncols=1, sharex=\"col\", figsize=(6.4, 3) if data_size == 1 else None)\n",
    "    if data_size == 1:\n",
    "        ax = [ax]\n",
    "    for var in range(data_size):\n",
    "        ax[var].plot(time_vec[plot_window], target_data_test_for_plot[plot_window, var], ls=\"--\", color=f\"C{var}\",\n",
    "                     label=\"True\")\n",
    "        ax[var].plot(time_vec[plot_window], predictions_median[plot_window, var], ls=\"-\", color=f\"C{var}\",\n",
    "                     label=\"Median forecast\" if method != \"regression\" else \"Forecast\")\n",
    "        if method != \"regression\":\n",
    "            ax[var].fill_between(\n",
    "                time_vec[plot_window], alpha=0.3, color=f\"C{var}\",\n",
    "                y1=predictions_lower[plot_window, var],\n",
    "                y2=predictions_upper[plot_window, var], label=\"99% credible region\")\n",
    "        ax[var].set_ylabel(var_names[var], size=label_size)\n",
    "        ax[var].tick_params(axis='both', which='major', labelsize=label_size)\n",
    "\n",
    "    if data_size == 1:\n",
    "        ax[0].legend(fontsize=label_size)\n",
    "\n",
    "    ax[-1].set_xlabel(r\"$t$\", size=label_size)\n",
    "    # fig.suptitle(f\"Median and {size}% credible region, \" + model_name_for_plot, size=title_size)\n",
    "    # plt.show()\n",
    "\n",
    "    if save_plots:\n",
    "        # save the metrics in file; the context manager closes the file even if a write fails\n",
    "        # NOTE(review): `string` is built by the scoring code earlier in this cell, inside a conditional\n",
    "        # whose header is not visible here - confirm it is always defined when save_plots is set\n",
    "        with open(nets_folder + f\"test_losses{name_postfix}.txt\", \"w\") as text_file:\n",
    "            text_file.write(string + \"\\n\")\n",
    "            text_file.write(string2 + \"\\n\")\n",
    "\n",
    "        # save the plot, cropped to a fixed bounding box (in inches) that depends on the number of panels\n",
    "        if data_size == 1:\n",
    "            bbox = Bbox(np.array([[0, -0.2], [6.1, 3]]))\n",
    "        else:\n",
    "            bbox = Bbox(np.array([[0, -0.2], [6.0, 4.8]]))\n",
    "        fig.savefig(nets_folder + f\"prediction_median{name_postfix}.\" + (\"pdf\" if save_pdf else \"png\"), dpi=400,\n",
    "                    bbox_inches=bbox)\n",
    "    plt.close(fig)\n",
    "\n",
    "    if not model_is_weatherbench:\n",
    "        # metrics plots\n",
    "        plot_metrics_params(cal_err_values, rmse_values, r2_values,\n",
    "                            filename=nets_folder + f\"metrics{name_postfix}.png\" if save_plots else None)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Diss",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
