{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "099a3c05",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import torch\n",
    "from torch import nn, optim\n",
    "from torch.autograd import Variable\n",
    "import torch.nn.functional as F\n",
    "import pandas as pd\n",
    "import time\n",
    "import datetime\n",
    "from datetime import timedelta\n",
    "import math\n",
    "import matplotlib.pyplot as plt\n",
    "from torch.utils.data import Dataset, DataLoader\n",
    "from sklearn.preprocessing import StandardScaler\n",
    "from sklearn.preprocessing import MinMaxScaler\n",
    "import joblib\n",
    "from scipy.stats import norm\n",
    "import warnings\n",
    "import random\n",
    "import os\n",
    "\n",
    "from datetime import datetime, timedelta\n",
    "warnings.filterwarnings(\"ignore\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "de052b5f",
   "metadata": {},
   "outputs": [],
   "source": [
    "def continuous_ranked_probability_score(obs, fx, fx_prob):\n",
    "    \"\"\"Continuous Ranked Probability Score (CRPS).\n",
    "\n",
    "    .. math::\n",
    "\n",
    "        \\\\text{CRPS} = \\\\frac{1}{n} \\\\sum_{i=1}^n \\\\int_{-\\\\infty}^{\\\\infty}\n",
    "        (F_i(x) - \\\\mathbf{1} \\\\{x \\\\geq y_i \\\\})^2 dx\n",
    "\n",
    "    where :math:`F_i(x)` is the CDF of the forecast at time :math:`i`,\n",
    "    :math:`y_i` is the observation at time :math:`i`, and :math:`\\\\mathbf{1}`\n",
    "    is the indicator function that transforms the observation into a step\n",
    "    function (1 if :math:`x \\\\geq y`, 0 if :math:`x < y`). In other words, the\n",
    "    CRPS measures the difference between the forecast CDF and the empirical CDF\n",
    "    of the observation. The CRPS has the same units as the observation. Lower\n",
    "    CRPS values indicate more accurate forecasts, where a CRPS of 0 indicates a\n",
    "    perfect forecast. [1]_ [2]_ [3]_\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    obs : (n,) array_like\n",
    "        Observations (physical unit).\n",
    "    fx : (n, d) array_like\n",
    "        Forecasts (physical units) of the right-hand-side of a CDF with d\n",
    "        intervals (d >= 2), e.g., fx = [10 MW, 20 MW, 30 MW] is interpreted as\n",
    "        <= 10 MW, <= 20 MW, <= 30 MW.\n",
    "    fx_prob : (n, d) array_like\n",
    "        Probability [%] associated with the forecasts.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    crps : float\n",
    "        The Continuous Ranked Probability Score, with the same units as the\n",
    "        observation.\n",
    "\n",
    "    Raises\n",
    "    ------\n",
    "    ValueError\n",
    "        If the forecasts have incorrect dimensions; either a) the forecasts are\n",
    "        for a single sample (n=1) with d CDF intervals but are given as a 1D\n",
    "        array with d values or b) the forecasts are given as 2D arrays (n,d)\n",
    "        but do not contain at least 2 CDF intervals (i.e. d < 2).\n",
    "\n",
    "    Notes\n",
    "    -----\n",
    "    The CRPS can be calculated analytically when the forecast CDF is of a\n",
    "    continuous parametric distribution, e.g., Gaussian distribution. However,\n",
    "    since the Solar Forecast Arbiter makes no assumptions regarding how a\n",
    "    probabilistic forecast was generated, the CRPS is instead calculated using\n",
    "    numerical integration of the discretized forecast CDF. Therefore, the\n",
    "    accuracy of the CRPS calculation is limited by the precision of the\n",
    "    forecast CDF. In practice, this means the forecast CDF should 1) consist of\n",
    "    at least 10 intervals and 2) cover probabilities from 0% to 100%.\n",
    "\n",
    "    References\n",
    "    ----------\n",
    "    .. [1] Matheson and Winkler (1976) \"Scoring rules for continuous\n",
    "           probability distributions.\" Management Science, vol. 22, pp.\n",
    "           1087-1096. doi: 10.1287/mnsc.22.10.1087\n",
    "    .. [2] Hersbach (2000) \"Decomposition of the continuous ranked probability\n",
    "           score for ensemble prediction systems.\" Weather Forecast, vol. 15,\n",
    "           pp. 559-570. doi: 10.1175/1520-0434(2000)015<0559:DOTCRP>2.0.CO;2\n",
    "    .. [3] Wilks (2019) \"Statistical Methods in the Atmospheric Sciences\", 4th\n",
    "           ed. Oxford; Waltham, MA; Academic Press.\n",
    "\n",
    "    \"\"\"\n",
    "\n",
    "    # Coerce the inputs up front so that plain lists, pandas rows and\n",
    "    # object-dtype arrays behave like the documented array_like arguments.\n",
    "    obs = np.atleast_1d(np.asarray(obs, dtype=float))\n",
    "    fx = np.asarray(fx, dtype=float)\n",
    "    fx_prob = np.asarray(fx_prob, dtype=float)\n",
    "\n",
    "    # forecasts must be (n, d) with at least two CDF intervals\n",
    "    if np.ndim(fx) < 2:\n",
    "        raise ValueError(\"forecasts must be 2D arrays (expected (n,d), got \"\n",
    "                         f\"{np.shape(fx)})\")\n",
    "    elif np.shape(fx)[1] < 2:\n",
    "        raise ValueError(\"forecasts must have d >= 2 CDF intervals \"\n",
    "                         f\"(expected >= 2, got {np.shape(fx)[1]})\")\n",
    "\n",
    "    n = len(fx)\n",
    "\n",
    "    # extend CDF min to ensure obs within forecast support\n",
    "    # fx.shape = (n, d) ==> (n, d + 1)\n",
    "    fx_min = np.minimum(obs, fx[:, 0])\n",
    "    fx = np.hstack([fx_min[:, np.newaxis], fx])\n",
    "    fx_prob = np.hstack([np.zeros([n, 1]), fx_prob])\n",
    "\n",
    "    # extend CDF max to ensure obs within forecast support\n",
    "    # fx.shape = (n, d + 1) ==> (n, d + 2)\n",
    "    idx = (fx[:, -1] < obs)\n",
    "    fx_max = np.maximum(obs, fx[:, -1])\n",
    "    fx = np.hstack([fx, fx_max[:, np.newaxis]])\n",
    "    fx_prob = np.hstack([fx_prob, np.full([n, 1], 100)])\n",
    "\n",
    "    # indicator function:\n",
    "    # - left of the obs is 0.0\n",
    "    # - obs and right of the obs is 1.0\n",
    "    o = np.where(fx >= obs[:, np.newaxis], 1.0, 0.0)\n",
    "\n",
    "    # correct behavior when obs > max fx:\n",
    "    # - should be 0 over range: max fx < x < obs\n",
    "    o[idx, -1] = 0.0\n",
    "\n",
    "    # forecast probabilities [unitless]\n",
    "    f = fx_prob / 100.0\n",
    "\n",
    "    # np.trapz is deprecated in NumPy 2 in favour of np.trapezoid; use\n",
    "    # whichever one this NumPy provides.\n",
    "    integrate = getattr(np, \"trapezoid\", None) or np.trapz\n",
    "\n",
    "    # integrate along each sample, then average all samples\n",
    "    crps = np.mean(integrate((f - o) ** 2, x=fx, axis=1))\n",
    "\n",
    "    return crps"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "1c2e8a2e",
   "metadata": {},
   "outputs": [],
   "source": [
    "def _crps_for_columns(data, first_col, last_col, prob):\n",
    "    \"\"\"Score each row's forecast in columns [first_col, last_col) against ``data[\"grid\"]``.\n",
    "\n",
    "    Returns a list with one CRPS value per row; rows whose observation is\n",
    "    NaN get NaN instead of a score.\n",
    "    \"\"\"\n",
    "    # Pull both blocks out once as float arrays: avoids the per-row\n",
    "    # object-dtype Series that data.iloc[i] produces on mixed-type frames.\n",
    "    truth_all = data[\"grid\"].to_numpy(dtype=float)\n",
    "    fx_all = data.iloc[:, first_col:last_col].to_numpy(dtype=float)\n",
    "\n",
    "    scores = []\n",
    "    for truth, fx in zip(truth_all, fx_all):\n",
    "        if np.isnan(truth):\n",
    "            scores.append(np.nan)\n",
    "            continue\n",
    "        scores.append(continuous_ranked_probability_score(\n",
    "            np.array([truth]), fx.reshape(1, -1), prob))\n",
    "    return scores\n",
    "\n",
    "\n",
    "def crps(data,path):\n",
    "    \"\"\"Add per-row CRPS columns for both forecasts and write the frame to ``path``.\n",
    "\n",
    "    Columns 2..12 (by position) are scored into ``pred_crps`` and columns\n",
    "    13..23 into ``ref_crps``; each block is treated as the 0%, 10%, ..., 100%\n",
    "    quantiles of the forecast CDF and compared with the ``grid`` column.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    data : pandas.DataFrame\n",
    "        Must contain a ``grid`` column and the two forecast blocks at the\n",
    "        positions above. Modified in place.\n",
    "    path : str\n",
    "        CSV file the augmented frame is written to (without the index).\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    pandas.DataFrame\n",
    "        The same ``data`` object, now holding ``pred_crps`` and ``ref_crps``.\n",
    "    \"\"\"\n",
    "    prob = np.array([0,10,20,30,40,50,60,70,80,90,100]).reshape(1,11)\n",
    "\n",
    "    data[\"pred_crps\"] = _crps_for_columns(data, 2, 13, prob)\n",
    "    data[\"ref_crps\"] = _crps_for_columns(data, 13, 24, prob)\n",
    "\n",
    "    data.to_csv(path, index = False)\n",
    "\n",
    "    return data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "b459cc24",
   "metadata": {},
   "outputs": [],
   "source": [
    "# class LSTM(nn.Module):\n",
    "\n",
    "#     def __init__(self):\n",
    "#         super(LSTM, self).__init__()\n",
    "\n",
    "#         self.lstm = nn.LSTM(\n",
    "#             input_size=12,   \n",
    "#             hidden_size=240,\n",
    "#             num_layers=1, \n",
    "#             batch_first=True,\n",
    "# #             dropout=0.3\n",
    "#         )\n",
    "#         self.ann2 = nn.Sequential(\n",
    "#             nn.Linear(11,1),\n",
    "# #             nn.ReLU(),\n",
    "# #             nn.Linear(11,1)\n",
    "#         )\n",
    "#         self.out = nn.Sequential(\n",
    "#             nn.Linear(240,11),\n",
    "#         )\n",
    "#         self.state = None\n",
    "\n",
    "#     def forward(self, x, z):\n",
    "\n",
    "#         batch = x.shape[0]\n",
    "#         tx = torch.reshape(x,(batch,336,-1))\n",
    "#         tx, self.state = self.lstm(tx, self.state)\n",
    "#         tz = torch.reshape(z,(batch,24,-1))\n",
    "#         temp = self.ann2(tz)\n",
    "#         out = self.out(tx[:,-24:,:]+temp)\n",
    "\n",
    "#         out = torch.unsqueeze(out,2)\n",
    "#         out1, out2 = out[:,:,:,0], out[:,:,:,1:]\n",
    "#         out1 = torch.unsqueeze(out1, -1)\n",
    "#         out2 = F.softplus(out2)\n",
    "#         out = torch.cat((out1,out2),dim=-1)\n",
    "#         out = torch.cumsum(out, dim=-1)\n",
    "        \n",
    "#         return out\n",
    "\n",
    "# def load_model(model_path):\n",
    "#     model = LSTM()\n",
    "#     model.load_state_dict(torch.load(model_path,map_location=torch.device('cpu')))\n",
    "#     model.double()\n",
    "#     model.eval()\n",
    "#     return model\n",
    "\n",
    "# def read_data_cont(df):\n",
    "#     data_cont = df[['consumption','solar','DNI','DHI','temperature','relativehumidity']]   \n",
    "#     data_cont = data_cont.values\n",
    "#     return data_cont\n",
    "# def read_data_target(df):\n",
    "#     data_target = df[['total_grid']]\n",
    "#     data_target = data_target.values\n",
    "#     return data_target\n",
    "# def read_data_time(df):\n",
    "#     data_time = df[['dayofweek','timeofday', 'month']]\n",
    "#     data_time = data_time.values\n",
    "#     return data_time\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "01cf61a5",
   "metadata": {},
   "outputs": [],
   "source": [
    "class LSTM(nn.Module):\n",
    "    \"\"\"LSTM over the history window plus a small MLP over the future covariates.\n",
    "\n",
    "    ``forward`` returns 11 values per forecast step that are non-decreasing\n",
    "    along the last axis (first value free, the rest are cumulative softplus\n",
    "    increments).\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "\n",
    "        # Attribute names and the Sequential layouts determine the state_dict\n",
    "        # keys of saved checkpoints, so they are kept exactly as they were.\n",
    "        self.lstm = nn.LSTM(input_size=6, hidden_size=60, num_layers=1, batch_first=True)\n",
    "        self.ann2 = nn.Sequential(\n",
    "            nn.Linear(11, 11),\n",
    "            nn.ReLU(),\n",
    "            nn.Linear(11, 1),\n",
    "        )\n",
    "        self.out = nn.Sequential(nn.Linear(60, 11))\n",
    "\n",
    "        # NOTE(review): the recurrent state is stored on the module and fed\n",
    "        # back in on the next forward() call; a freshly built model starts\n",
    "        # from None (zero state).\n",
    "        self.state = None\n",
    "\n",
    "    def forward(self, x, z):\n",
    "        \"\"\"Map history ``x`` (reshaped to 336 steps) and future covariates ``z`` (24 steps) to quantiles.\"\"\"\n",
    "        n_batch = x.shape[0]\n",
    "\n",
    "        history = x.reshape(n_batch, 336, -1)\n",
    "        encoded, self.state = self.lstm(history, self.state)\n",
    "\n",
    "        future = z.reshape(n_batch, 24, -1)\n",
    "        offset = self.ann2(future)  # (batch, 24, 1), broadcast over the hidden units\n",
    "\n",
    "        # only the last 24 encoded steps feed the output layer\n",
    "        raw = self.out(encoded[:, -24:, :] + offset).unsqueeze(2)  # (batch, 24, 1, 11)\n",
    "\n",
    "        # first value is unconstrained, the remaining ten are positive increments\n",
    "        base = raw[..., :1]\n",
    "        increments = F.softplus(raw[..., 1:])\n",
    "\n",
    "        return torch.cumsum(torch.cat((base, increments), dim=-1), dim=-1)\n",
    "\n",
    "def load_model(model_path):\n",
    "    \"\"\"Build an ``LSTM``, load the weights stored at ``model_path`` onto the CPU and return it in float64 eval mode.\"\"\"\n",
    "    net = LSTM()\n",
    "    # NOTE(review): torch.load unpickles the file -- only load checkpoints from a trusted source.\n",
    "    weights = torch.load(model_path,map_location=torch.device('cpu'))\n",
    "    net.load_state_dict(weights)\n",
    "    return net.double().eval()\n",
    "\n",
    "def read_data_cont(df):\n",
    "    \"\"\"Return the six continuous feature columns of ``df`` as a 2-D NumPy array.\"\"\"\n",
    "    continuous_columns = ['consumption','solar','DNI','DHI','temperature','relativehumidity']\n",
    "    return df[continuous_columns].values\n",
    "def read_data_target(df):\n",
    "    \"\"\"Return the ``total_grid`` column of ``df`` as an (n, 1) NumPy array.\"\"\"\n",
    "    target_columns = ['total_grid']\n",
    "    return df[target_columns].values\n",
    "def read_data_time(df):\n",
    "    \"\"\"Return the calendar columns (day of week, time of day, month) of ``df`` as a 2-D NumPy array.\"\"\"\n",
    "    calendar_columns = ['dayofweek','timeofday', 'month']\n",
    "    return df[calendar_columns].values\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "4a499d75",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_pred(location, model_year, time_now, scaler_path, seed_list,\n",
    "             model_path=None, data_path=None):\n",
    "    \"\"\"Forecast one day of quantiles with the trained LSTM.\n",
    "\n",
    "    Reads ``{location}_history.csv`` and ``{location}_future.csv`` from\n",
    "    ``data_path``, perturbs the future weather with ``simulate_forecast``,\n",
    "    scales the continuous features, appends sin/cos calendar encodings and\n",
    "    runs the model on the last 336 history rows plus the future rows.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    location : str\n",
    "        Site code used in the CSV file names and passed to ``simulate_forecast``.\n",
    "    model_year, time_now\n",
    "        Not used inside the function; kept so existing call sites keep working.\n",
    "    scaler_path : str\n",
    "        File holding the fitted scaler (loaded with ``joblib.load``).\n",
    "    seed_list : sequence of int\n",
    "        Seeds passed on to ``simulate_forecast``.\n",
    "    model_path, data_path : str, optional\n",
    "        Checkpoint file and data directory. When left as ``None`` the\n",
    "        notebook-level globals of the same names are used, which is what the\n",
    "        function always did before these parameters existed.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    numpy.ndarray\n",
    "        Array of shape (len(future rows), 11): columns are the p00, p10, ...,\n",
    "        p100 outputs of the model.\n",
    "\n",
    "    Raises\n",
    "    ------\n",
    "    FileNotFoundError\n",
    "        If the model or scaler file does not exist.\n",
    "    ValueError\n",
    "        If the model output does not have one row per future row and 11 columns.\n",
    "    \"\"\"\n",
    "    # Fall back to the notebook globals so callers that only set\n",
    "    # model_path / data_path at module level keep working.\n",
    "    if model_path is None:\n",
    "        model_path = globals()[\"model_path\"]\n",
    "    if data_path is None:\n",
    "        data_path = globals()[\"data_path\"]\n",
    "\n",
    "    df_history = pd.read_csv(os.path.join(data_path, f\"{location}_history.csv\"))\n",
    "    df_future = pd.read_csv(os.path.join(data_path, f\"{location}_future.csv\"))\n",
    "\n",
    "    df_future = simulate_forecast(df_future.copy(),location,seed_list)\n",
    "\n",
    "    if \"total_grid\" not in df_future.columns:\n",
    "        df_future.insert(0, 'total_grid', np.zeros(df_future.shape[0]))\n",
    "\n",
    "    # Fail with a clear message instead of printing and crashing in the loader.\n",
    "    if not os.path.exists(model_path):\n",
    "        raise FileNotFoundError(f\"Model file {model_path} does not exist.\")\n",
    "    if not os.path.exists(scaler_path):\n",
    "        raise FileNotFoundError(f\"Scaler file {scaler_path} does not exist.\")\n",
    "\n",
    "    model = load_model(model_path)\n",
    "    scaler = joblib.load(scaler_path)\n",
    "\n",
    "    def cyclical(values, period):\n",
    "        # (2, n) array holding sin and cos of the position within the period.\n",
    "        # Rounded to float32 because the original encoding went through a\n",
    "        # torch.float32 tensor; keeping that keeps the model inputs unchanged.\n",
    "        angle = values / period * 2 * np.pi\n",
    "        return np.stack((np.sin(angle), np.cos(angle))).astype(np.float32)\n",
    "\n",
    "    def read_test(df):\n",
    "        # scaled continuous features followed by six calendar encodings\n",
    "        test_cont = scaler.transform(read_data_cont(df))\n",
    "\n",
    "        test_time = read_data_time(df)\n",
    "        W = cyclical(test_time[:,0],7)\n",
    "        H = cyclical(test_time[:,1],24)\n",
    "        M = cyclical(test_time[:,2]-1,12)\n",
    "        encoded_time = np.concatenate((W,H,M),0).T\n",
    "\n",
    "        return np.concatenate((test_cont,encoded_time),1)\n",
    "\n",
    "    # history: last 336 rows, continuous features only\n",
    "    history_context = read_test(df_history)[-336:,:6]\n",
    "    # future: every column except the first continuous feature\n",
    "    future_context = read_test(df_future)[:,1:]\n",
    "\n",
    "    test_X = torch.tensor(history_context, dtype=torch.double).unsqueeze(0)\n",
    "    test_Z = torch.tensor(future_context, dtype=torch.double).unsqueeze(0)\n",
    "\n",
    "    with torch.no_grad():\n",
    "        preds = torch.squeeze(model(test_X,test_Z))\n",
    "\n",
    "    # Same guarantee the old column / index assignment gave implicitly.\n",
    "    if preds.ndim != 2 or preds.shape[0] != len(df_future) or preds.shape[1] != 11:\n",
    "        raise ValueError(\n",
    "            f\"unexpected prediction shape {tuple(preds.shape)}; \"\n",
    "            f\"expected ({len(df_future)}, 11)\")\n",
    "\n",
    "    return preds.numpy()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "78f23472",
   "metadata": {},
   "outputs": [],
   "source": [
    "start_date = datetime.strptime(\"2023-06-27\", \"%Y-%m-%d\")\n",
    "date_strings = [(start_date + timedelta(days=i)).strftime(\"%Y-%m-%d\") for i in range(19)]\n",
    "std_summary = pd.DataFrame()\n",
    "\n",
    "# Parse every CSV once per location (rows stacked in date order) instead of\n",
    "# re-reading the same files for each of the five features.\n",
    "forecast_by_location = {}\n",
    "truth_by_location = {}\n",
    "for location in [\"GA\",\"HI\",\"OR\",\"TX\"]:\n",
    "    forecast_by_location[location] = pd.concat(\n",
    "        [pd.read_csv(f\"eval/old_eval/{location}_{date}/{location}_future.csv\") for date in date_strings],\n",
    "        ignore_index=True)\n",
    "    truth_by_location[location] = pd.concat(\n",
    "        [pd.read_csv(f\"eval/1_day_ahead/{location}_{date}/{location}_future.csv\") for date in date_strings],\n",
    "        ignore_index=True)\n",
    "\n",
    "# One column per (location, feature): standard deviation of the forecast\n",
    "# error (old forecast minus 1-day-ahead value) grouped by hour of day.\n",
    "for feature in [\"DNI\",\"DHI\",\"GHI\",\"temperature\",\"relativehumidity\"]:\n",
    "    for location in [\"GA\",\"HI\",\"OR\",\"TX\"]:\n",
    "        forecast = forecast_by_location[location]\n",
    "        truth = truth_by_location[location]\n",
    "\n",
    "        # plain arrays: rows are paired by position, not by index label\n",
    "        df = pd.DataFrame({\n",
    "            \"hour\": truth[\"timeofday\"].to_numpy(),\n",
    "            \"truth\": truth[feature].to_numpy(),\n",
    "            \"prediction\": forecast[feature].to_numpy(),\n",
    "        })\n",
    "        # this column holds the raw error; its std is taken by the groupby below\n",
    "        df[\"std(error)\"] = df[\"prediction\"] - df[\"truth\"]\n",
    "\n",
    "        hourly_std = df.groupby(\"hour\", as_index = False).std()\n",
    "        std_summary[\"hour\"] = hourly_std[\"hour\"]\n",
    "        std_summary[f\"{location}_{feature}\"] = hourly_std[\"std(error)\"]\n",
    "\n",
    "def generate_error(data, location, feature, seed):\n",
    "    \"\"\"Draw one zero-mean Gaussian error per row of ``data``.\n",
    "\n",
    "    The standard deviation of each draw is read from the module-level\n",
    "    ``std_summary`` table, column ``f\"{location}_{feature}\"``, at the row\n",
    "    given by that row's ``timeofday`` value.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    data : pandas.DataFrame\n",
    "        Must contain a ``timeofday`` column.\n",
    "    location, feature : str\n",
    "        Select the ``std_summary`` column.\n",
    "    seed : int\n",
    "        Seed of the generator; the same seed always gives the same errors.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    list of float\n",
    "        One error per row of ``data``.\n",
    "    \"\"\"\n",
    "    # A private generator yields the same draws as random.seed(seed) followed\n",
    "    # by random.gauss, but leaves the global random state of the caller alone.\n",
    "    rng = random.Random(seed)\n",
    "\n",
    "    # NOTE(review): .iloc is positional, so this assumes row i of std_summary\n",
    "    # belongs to hour i -- confirm every hour is present in the table.\n",
    "    hourly_std = std_summary.iloc[data[\"timeofday\"]][f\"{location}_{feature}\"]\n",
    "\n",
    "    return [rng.gauss(mu=0, sigma=std) for std in hourly_std]\n",
    "\n",
    "\n",
    "def simulate_forecast(data,location,seed_list):\n",
    "    \"\"\"Turn observed weather into a simulated forecast by adding seeded Gaussian errors.\n",
    "\n",
    "    ``data`` is modified in place and returned: ``total_grid`` is set to 0,\n",
    "    errors from ``generate_error`` are added to DNI, DHI, temperature and\n",
    "    relative humidity, and DNI, DHI and relative humidity are clipped to\n",
    "    their valid ranges.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    data : pandas.DataFrame\n",
    "        Needs the columns ``DNI``, ``DHI``, ``temperature``,\n",
    "        ``relativehumidity`` and ``timeofday``.\n",
    "    location : str\n",
    "        One of \"GA\", \"HI\", \"OR\", \"TX\"; selects the irradiance upper limits.\n",
    "    seed_list : sequence of int\n",
    "        The first four entries seed the DNI, DHI, temperature and humidity\n",
    "        errors, in that order.\n",
    "\n",
    "    Raises\n",
    "    ------\n",
    "    ValueError\n",
    "        If ``location`` is not one of the four known sites.\n",
    "    \"\"\"\n",
    "    # (DNI_max, DHI_max) per site; both lower limits are 0\n",
    "    irradiance_max = {\n",
    "        \"GA\": (1040, 551),\n",
    "        \"HI\": (989, 567),\n",
    "        \"OR\": (981, 541),\n",
    "        \"TX\": (1052, 560),\n",
    "    }\n",
    "    # Reject unknown sites before touching data (this used to surface as an\n",
    "    # UnboundLocalError after the frame had already been modified).\n",
    "    if location not in irradiance_max:\n",
    "        raise ValueError(\n",
    "            f\"unknown location {location!r}; expected one of {sorted(irradiance_max)}\")\n",
    "    DNI_max, DHI_max = irradiance_max[location]\n",
    "    DNI_min = DHI_min = 0\n",
    "\n",
    "    data[\"total_grid\"] = 0\n",
    "\n",
    "    data[\"DNI\"] = data[\"DNI\"] + generate_error(data, location, \"DNI\", seed_list[0])\n",
    "    data[\"DHI\"] = data[\"DHI\"] + generate_error(data, location, \"DHI\", seed_list[1])\n",
    "\n",
    "    data[\"DHI\"] = np.clip(data[\"DHI\"], DHI_min, DHI_max)\n",
    "    data[\"DNI\"] = np.clip(data[\"DNI\"], DNI_min, DNI_max)\n",
    "\n",
    "    data[\"temperature\"] = data[\"temperature\"] + generate_error(data, location, \"temperature\", seed_list[2])\n",
    "    data[\"relativehumidity\"] = data[\"relativehumidity\"] + generate_error(data, location, \"relativehumidity\", seed_list[3])\n",
    "    data[\"relativehumidity\"] = np.clip(data[\"relativehumidity\"], 0, 100)\n",
    "\n",
    "    return data\n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "be210116",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Evaluation window: 28 consecutive days starting 2023-06-18, kept as \"MM-DD\" strings.\n",
    "start_date = datetime.strptime(\"2023-06-18\", \"%Y-%m-%d\")\n",
    "date_strings = [\n",
    "    (start_date + timedelta(days=offset)).strftime(\"%m-%d\")\n",
    "    for offset in range(28)\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "67dc914c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2005\n",
      "[0.375, 0.432, 0.424, 0.407, 0.45, 0.438, 0.425, 0.413, 0.429, 0.43]\n",
      "       OR    seed\n",
      "0  0.4223  2005.0\n",
      "2006\n",
      "[0.349, 0.434, 0.394, 0.427, 0.414, 0.411, 0.406, 0.389, 0.408, 0.372]\n",
      "       OR    seed\n",
      "0  0.4223  2005.0\n",
      "1  0.4004  2006.0\n",
      "2007\n",
      "[0.428, 0.473, 0.431, 0.466, 0.461, 0.443, 0.441, 0.401, 0.439, 0.451]\n",
      "       OR    seed\n",
      "0  0.4223  2005.0\n",
      "1  0.4004  2006.0\n",
      "2  0.4434  2007.0\n",
      "2008\n",
      "[0.346, 0.407, 0.412, 0.408, 0.445, 0.425, 0.419, 0.398, 0.422, 0.402]\n",
      "       OR    seed\n",
      "0  0.4223  2005.0\n",
      "1  0.4004  2006.0\n",
      "2  0.4434  2007.0\n",
      "3  0.4084  2008.0\n",
      "2009\n",
      "[0.403, 0.457, 0.432, 0.432, 0.481, 0.429, 0.427, 0.409, 0.449, 0.428]\n",
      "       OR    seed\n",
      "0  0.4223  2005.0\n",
      "1  0.4004  2006.0\n",
      "2  0.4434  2007.0\n",
      "3  0.4084  2008.0\n",
      "4  0.4347  2009.0\n",
      "2010\n",
      "[0.398, 0.467, 0.431, 0.439, 0.476, 0.443, 0.423, 0.404, 0.442, 0.424]\n",
      "       OR    seed\n",
      "0  0.4223  2005.0\n",
      "1  0.4004  2006.0\n",
      "2  0.4434  2007.0\n",
      "3  0.4084  2008.0\n",
      "4  0.4347  2009.0\n",
      "5  0.4347  2010.0\n",
      "2011\n",
      "[0.412, 0.471, 0.433, 0.452, 0.458, 0.468, 0.456, 0.378, 0.441, 0.435]\n",
      "       OR    seed\n",
      "0  0.4223  2005.0\n",
      "1  0.4004  2006.0\n",
      "2  0.4434  2007.0\n",
      "3  0.4084  2008.0\n",
      "4  0.4347  2009.0\n",
      "5  0.4347  2010.0\n",
      "6  0.4404  2011.0\n",
      "2012\n",
      "[0.39, 0.435, 0.405, 0.446, 0.457, 0.437, 0.391, 0.416, 0.439, 0.43]\n",
      "       OR    seed\n",
      "0  0.4223  2005.0\n",
      "1  0.4004  2006.0\n",
      "2  0.4434  2007.0\n",
      "3  0.4084  2008.0\n",
      "4  0.4347  2009.0\n",
      "5  0.4347  2010.0\n",
      "6  0.4404  2011.0\n",
      "7  0.4246  2012.0\n",
      "2013\n",
      "[0.345, 0.414, 0.359, 0.409, 0.395, 0.406, 0.401, 0.361, 0.378, 0.373]\n",
      "       OR    seed\n",
      "0  0.4223  2005.0\n",
      "1  0.4004  2006.0\n",
      "2  0.4434  2007.0\n",
      "3  0.4084  2008.0\n",
      "4  0.4347  2009.0\n",
      "5  0.4347  2010.0\n",
      "6  0.4404  2011.0\n",
      "7  0.4246  2012.0\n",
      "8  0.3841  2013.0\n",
      "2014\n",
      "[0.406, 0.441, 0.445, 0.466, 0.465, 0.429, 0.418, 0.405, 0.431, 0.432]\n",
      "       OR    seed\n",
      "0  0.4223  2005.0\n",
      "1  0.4004  2006.0\n",
      "2  0.4434  2007.0\n",
      "3  0.4084  2008.0\n",
      "4  0.4347  2009.0\n",
      "5  0.4347  2010.0\n",
      "6  0.4404  2011.0\n",
      "7  0.4246  2012.0\n",
      "8  0.3841  2013.0\n",
      "9  0.4338  2014.0\n",
      "2015\n",
      "[0.389, 0.41, 0.362, 0.391, 0.409, 0.377, 0.371, 0.324, 0.354, 0.359]\n",
      "        OR    seed\n",
      "0   0.4223  2005.0\n",
      "1   0.4004  2006.0\n",
      "2   0.4434  2007.0\n",
      "3   0.4084  2008.0\n",
      "4   0.4347  2009.0\n",
      "5   0.4347  2010.0\n",
      "6   0.4404  2011.0\n",
      "7   0.4246  2012.0\n",
      "8   0.3841  2013.0\n",
      "9   0.4338  2014.0\n",
      "10  0.3746  2015.0\n",
      "2016\n",
      "[0.371, 0.412, 0.42, 0.39, 0.434, 0.433, 0.364, 0.385, 0.404, 0.419]\n",
      "        OR    seed\n",
      "0   0.4223  2005.0\n",
      "1   0.4004  2006.0\n",
      "2   0.4434  2007.0\n",
      "3   0.4084  2008.0\n",
      "4   0.4347  2009.0\n",
      "5   0.4347  2010.0\n",
      "6   0.4404  2011.0\n",
      "7   0.4246  2012.0\n",
      "8   0.3841  2013.0\n",
      "9   0.4338  2014.0\n",
      "10  0.3746  2015.0\n",
      "11  0.4032  2016.0\n",
      "2017\n",
      "[0.413, 0.483, 0.441, 0.448, 0.452, 0.442, 0.456, 0.384, 0.434, 0.411]\n",
      "        OR    seed\n",
      "0   0.4223  2005.0\n",
      "1   0.4004  2006.0\n",
      "2   0.4434  2007.0\n",
      "3   0.4084  2008.0\n",
      "4   0.4347  2009.0\n",
      "5   0.4347  2010.0\n",
      "6   0.4404  2011.0\n",
      "7   0.4246  2012.0\n",
      "8   0.3841  2013.0\n",
      "9   0.4338  2014.0\n",
      "10  0.3746  2015.0\n",
      "11  0.4032  2016.0\n",
      "12  0.4364  2017.0\n",
      "2018\n",
      "[0.374, 0.43, 0.397, 0.403, 0.455, 0.442, 0.337, 0.397, 0.43, 0.422]\n",
      "        OR    seed\n",
      "0   0.4223  2005.0\n",
      "1   0.4004  2006.0\n",
      "2   0.4434  2007.0\n",
      "3   0.4084  2008.0\n",
      "4   0.4347  2009.0\n",
      "5   0.4347  2010.0\n",
      "6   0.4404  2011.0\n",
      "7   0.4246  2012.0\n",
      "8   0.3841  2013.0\n",
      "9   0.4338  2014.0\n",
      "10  0.3746  2015.0\n",
      "11  0.4032  2016.0\n",
      "12  0.4364  2017.0\n",
      "13  0.4087  2018.0\n",
      "2019\n",
      "[0.398, 0.475, 0.443, 0.451, 0.478, 0.454, 0.453, 0.443, 0.441, 0.42]\n",
      "        OR    seed\n",
      "0   0.4223  2005.0\n",
      "1   0.4004  2006.0\n",
      "2   0.4434  2007.0\n",
      "3   0.4084  2008.0\n",
      "4   0.4347  2009.0\n",
      "5   0.4347  2010.0\n",
      "6   0.4404  2011.0\n",
      "7   0.4246  2012.0\n",
      "8   0.3841  2013.0\n",
      "9   0.4338  2014.0\n",
      "10  0.3746  2015.0\n",
      "11  0.4032  2016.0\n",
      "12  0.4364  2017.0\n",
      "13  0.4087  2018.0\n",
      "14  0.4456  2019.0\n",
      "2020\n",
      "[0.348, 0.411, 0.407, 0.365, 0.445, 0.408, 0.346, 0.427, 0.434, 0.441]\n",
      "        OR    seed\n",
      "0   0.4223  2005.0\n",
      "1   0.4004  2006.0\n",
      "2   0.4434  2007.0\n",
      "3   0.4084  2008.0\n",
      "4   0.4347  2009.0\n",
      "5   0.4347  2010.0\n",
      "6   0.4404  2011.0\n",
      "7   0.4246  2012.0\n",
      "8   0.3841  2013.0\n",
      "9   0.4338  2014.0\n",
      "10  0.3746  2015.0\n",
      "11  0.4032  2016.0\n",
      "12  0.4364  2017.0\n",
      "13  0.4087  2018.0\n",
      "14  0.4456  2019.0\n",
      "15  0.4032  2020.0\n",
      "2021\n",
      "[0.345, 0.416, 0.396, 0.367, 0.446, 0.39, 0.358, 0.364, 0.427, 0.399]\n",
      "        OR    seed\n",
      "0   0.4223  2005.0\n",
      "1   0.4004  2006.0\n",
      "2   0.4434  2007.0\n",
      "3   0.4084  2008.0\n",
      "4   0.4347  2009.0\n",
      "5   0.4347  2010.0\n",
      "6   0.4404  2011.0\n",
      "7   0.4246  2012.0\n",
      "8   0.3841  2013.0\n",
      "9   0.4338  2014.0\n",
      "10  0.3746  2015.0\n",
      "11  0.4032  2016.0\n",
      "12  0.4364  2017.0\n",
      "13  0.4087  2018.0\n",
      "14  0.4456  2019.0\n",
      "15  0.4032  2020.0\n",
      "16  0.3908  2021.0\n",
      "2022\n",
      "[0.383, 0.338, 0.4, 0.391, 0.348, 0.366, 0.36, 0.385, 0.358, 0.387]\n",
      "        OR    seed\n",
      "0   0.4223  2005.0\n",
      "1   0.4004  2006.0\n",
      "2   0.4434  2007.0\n",
      "3   0.4084  2008.0\n",
      "4   0.4347  2009.0\n",
      "5   0.4347  2010.0\n",
      "6   0.4404  2011.0\n",
      "7   0.4246  2012.0\n",
      "8   0.3841  2013.0\n",
      "9   0.4338  2014.0\n",
      "10  0.3746  2015.0\n",
      "11  0.4032  2016.0\n",
      "12  0.4364  2017.0\n",
      "13  0.4087  2018.0\n",
      "14  0.4456  2019.0\n",
      "15  0.4032  2020.0\n",
      "16  0.3908  2021.0\n",
      "17  0.3716  2022.0\n",
      "2023\n",
      "[0.42, 0.473, 0.447, 0.462, 0.485, 0.462, 0.435, 0.388, 0.466, 0.449]\n",
      "        OR    seed\n",
      "0   0.4223  2005.0\n",
      "1   0.4004  2006.0\n",
      "2   0.4434  2007.0\n",
      "3   0.4084  2008.0\n",
      "4   0.4347  2009.0\n",
      "5   0.4347  2010.0\n",
      "6   0.4404  2011.0\n",
      "7   0.4246  2012.0\n",
      "8   0.3841  2013.0\n",
      "9   0.4338  2014.0\n",
      "10  0.3746  2015.0\n",
      "11  0.4032  2016.0\n",
      "12  0.4364  2017.0\n",
      "13  0.4087  2018.0\n",
      "14  0.4456  2019.0\n",
      "15  0.4032  2020.0\n",
      "16  0.3908  2021.0\n",
      "17  0.3716  2022.0\n",
      "18  0.4487  2023.0\n",
      "2024\n",
      "[0.396, 0.376, 0.368, 0.404, 0.395, 0.392, 0.404, 0.375, 0.383, 0.395]\n",
      "        OR    seed\n",
      "0   0.4223  2005.0\n",
      "1   0.4004  2006.0\n",
      "2   0.4434  2007.0\n",
      "3   0.4084  2008.0\n",
      "4   0.4347  2009.0\n",
      "5   0.4347  2010.0\n",
      "6   0.4404  2011.0\n",
      "7   0.4246  2012.0\n",
      "8   0.3841  2013.0\n",
      "9   0.4338  2014.0\n",
      "10  0.3746  2015.0\n",
      "11  0.4032  2016.0\n",
      "12  0.4364  2017.0\n",
      "13  0.4087  2018.0\n",
      "14  0.4456  2019.0\n",
      "15  0.4032  2020.0\n",
      "16  0.3908  2021.0\n",
      "17  0.3716  2022.0\n",
      "18  0.4487  2023.0\n",
      "19  0.3888  2024.0\n"
     ]
    }
   ],
   "source": [
    "# Ablation over training years: for every model year, run the forecast with\n",
    "# ten different noise seeds and record the mean CRPS skill score.\n",
    "location = \"OR\"\n",
    "\n",
    "# (start, stop) slice applied to the stacked daily forecasts before they are\n",
    "# written into the evaluation CSV of each site.\n",
    "trim_by_location = {\"OR\": (5, -3), \"HI\": (6, -5), \"GA\": (7, -3), \"TX\": (7, -3)}\n",
    "trim_start, trim_stop = trim_by_location[location]\n",
    "\n",
    "save_to = f\"eval/{location}_1_day_ahead.csv\"\n",
    "summary_path = f\"D:/Jupyter/Research/Competition/ICLR/model/forecast_evaluation/Ablation_LSTM_{location}.csv\"\n",
    "\n",
    "summary_rows = []\n",
    "summary = pd.DataFrame()\n",
    "\n",
    "for model_year in [str(m) for m in range(2005,2025)]:\n",
    "    print(model_year)\n",
    "\n",
    "    # get_pred() reads model_path / data_path from the notebook globals, so\n",
    "    # these must stay module-level names.\n",
    "    model_path = f'./newest models and scalers/{location}_{model_year}.pt'\n",
    "    scaler_path = f'./newest models and scalers/{location}_{model_year}_scaler.gz'\n",
    "\n",
    "    crpss_result = []\n",
    "    for seed in range(2025,2035):\n",
    "\n",
    "        random.seed(seed)\n",
    "        # NOTE(review): simulate_forecast only uses the first four entries, and\n",
    "        # the same list is passed for every date -- confirm that is intended.\n",
    "        seed_list = [random.randint(1, 1000) for _ in range(1200)]\n",
    "\n",
    "        # read once per seed; the file does not change while the days are predicted\n",
    "        df_save = pd.read_csv(save_to)\n",
    "\n",
    "        result = []\n",
    "        for time_now in date_strings:\n",
    "            data_path = f'./eval/1_day_ahead/{location}_2023-{time_now}/'\n",
    "            result.append(get_pred(location, model_year, time_now, scaler_path, seed_list))\n",
    "\n",
    "        df_save.iloc[:,2:13] = pd.DataFrame(np.concatenate(result, axis = 0)[trim_start:trim_stop])\n",
    "        df_save.to_csv(save_to, index = False)\n",
    "\n",
    "        data = pd.read_csv(f\"{save_to}\")\n",
    "        data = crps(data,f\"{save_to}\")\n",
    "        data_group = data.groupby(\"date\", as_index = False).mean(numeric_only = True)\n",
    "        data_group[\"crpss\"] = 1 - data_group[\"pred_crps\"]/data_group[\"ref_crps\"]\n",
    "\n",
    "        # skill score of the whole window: ratio of the mean daily CRPS values\n",
    "        crpss =  1 - np.mean(data_group[\"pred_crps\"])/np.mean(data_group[\"ref_crps\"])\n",
    "        crpss_result.append(np.round(crpss,3))\n",
    "\n",
    "    print(crpss_result)\n",
    "\n",
    "    # The \"seed\" column actually holds the model year; the name is kept so the\n",
    "    # output CSV keeps its existing header.\n",
    "    summary_rows.append({location: pd.Series(crpss_result).mean(), \"seed\": float(model_year)})\n",
    "    summary = pd.DataFrame(summary_rows)\n",
    "    print(summary)\n",
    "\n",
    "summary.to_csv(summary_path, index = False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "a1280cd0",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2005\n",
      "[0.114, 0.189, 0.244, 0.187, 0.2, 0.175, 0.168, 0.221, 0.175, 0.168]\n",
      "       GA    seed\n",
      "0  0.1841  2005.0\n",
      "2006\n",
      "[0.244, 0.3, 0.303, 0.283, 0.29, 0.279, 0.287, 0.313, 0.276, 0.29]\n",
      "       GA    seed\n",
      "0  0.1841  2005.0\n",
      "1  0.2865  2006.0\n",
      "2007\n",
      "[0.122, 0.151, 0.241, 0.175, 0.194, 0.144, 0.185, 0.19, 0.11, 0.186]\n",
      "       GA    seed\n",
      "0  0.1841  2005.0\n",
      "1  0.2865  2006.0\n",
      "2  0.1698  2007.0\n",
      "2008\n",
      "[0.156, 0.188, 0.227, 0.206, 0.193, 0.181, 0.16, 0.214, 0.144, 0.147]\n",
      "       GA    seed\n",
      "0  0.1841  2005.0\n",
      "1  0.2865  2006.0\n",
      "2  0.1698  2007.0\n",
      "3  0.1816  2008.0\n",
      "2009\n",
      "[0.189, 0.24, 0.262, 0.251, 0.252, 0.216, 0.233, 0.26, 0.222, 0.226]\n",
      "       GA    seed\n",
      "0  0.1841  2005.0\n",
      "1  0.2865  2006.0\n",
      "2  0.1698  2007.0\n",
      "3  0.1816  2008.0\n",
      "4  0.2351  2009.0\n",
      "2010\n",
      "[0.058, 0.139, 0.177, 0.128, 0.166, 0.095, 0.143, 0.196, 0.108, 0.101]\n",
      "       GA    seed\n",
      "0  0.1841  2005.0\n",
      "1  0.2865  2006.0\n",
      "2  0.1698  2007.0\n",
      "3  0.1816  2008.0\n",
      "4  0.2351  2009.0\n",
      "5  0.1311  2010.0\n",
      "2011\n",
      "[0.145, 0.206, 0.214, 0.143, 0.201, 0.172, 0.151, 0.217, 0.162, 0.171]\n",
      "       GA    seed\n",
      "0  0.1841  2005.0\n",
      "1  0.2865  2006.0\n",
      "2  0.1698  2007.0\n",
      "3  0.1816  2008.0\n",
      "4  0.2351  2009.0\n",
      "5  0.1311  2010.0\n",
      "6  0.1782  2011.0\n",
      "2012\n",
      "[0.191, 0.224, 0.267, 0.264, 0.238, 0.226, 0.218, 0.237, 0.197, 0.245]\n",
      "       GA    seed\n",
      "0  0.1841  2005.0\n",
      "1  0.2865  2006.0\n",
      "2  0.1698  2007.0\n",
      "3  0.1816  2008.0\n",
      "4  0.2351  2009.0\n",
      "5  0.1311  2010.0\n",
      "6  0.1782  2011.0\n",
      "7  0.2307  2012.0\n",
      "2013\n",
      "[0.124, 0.207, 0.239, 0.201, 0.214, 0.163, 0.188, 0.203, 0.153, 0.174]\n",
      "       GA    seed\n",
      "0  0.1841  2005.0\n",
      "1  0.2865  2006.0\n",
      "2  0.1698  2007.0\n",
      "3  0.1816  2008.0\n",
      "4  0.2351  2009.0\n",
      "5  0.1311  2010.0\n",
      "6  0.1782  2011.0\n",
      "7  0.2307  2012.0\n",
      "8  0.1866  2013.0\n",
      "2014\n",
      "[0.184, 0.268, 0.254, 0.229, 0.241, 0.215, 0.234, 0.238, 0.245, 0.234]\n",
      "       GA    seed\n",
      "0  0.1841  2005.0\n",
      "1  0.2865  2006.0\n",
      "2  0.1698  2007.0\n",
      "3  0.1816  2008.0\n",
      "4  0.2351  2009.0\n",
      "5  0.1311  2010.0\n",
      "6  0.1782  2011.0\n",
      "7  0.2307  2012.0\n",
      "8  0.1866  2013.0\n",
      "9  0.2342  2014.0\n",
      "2015\n",
      "[0.158, 0.226, 0.255, 0.204, 0.261, 0.171, 0.243, 0.228, 0.175, 0.21]\n",
      "        GA    seed\n",
      "0   0.1841  2005.0\n",
      "1   0.2865  2006.0\n",
      "2   0.1698  2007.0\n",
      "3   0.1816  2008.0\n",
      "4   0.2351  2009.0\n",
      "5   0.1311  2010.0\n",
      "6   0.1782  2011.0\n",
      "7   0.2307  2012.0\n",
      "8   0.1866  2013.0\n",
      "9   0.2342  2014.0\n",
      "10  0.2131  2015.0\n",
      "2016\n",
      "[0.221, 0.27, 0.258, 0.26, 0.265, 0.246, 0.223, 0.263, 0.242, 0.253]\n",
      "        GA    seed\n",
      "0   0.1841  2005.0\n",
      "1   0.2865  2006.0\n",
      "2   0.1698  2007.0\n",
      "3   0.1816  2008.0\n",
      "4   0.2351  2009.0\n",
      "5   0.1311  2010.0\n",
      "6   0.1782  2011.0\n",
      "7   0.2307  2012.0\n",
      "8   0.1866  2013.0\n",
      "9   0.2342  2014.0\n",
      "10  0.2131  2015.0\n",
      "11  0.2501  2016.0\n",
      "2017\n",
      "[0.142, 0.202, 0.189, 0.184, 0.202, 0.167, 0.207, 0.182, 0.188, 0.161]\n",
      "        GA    seed\n",
      "0   0.1841  2005.0\n",
      "1   0.2865  2006.0\n",
      "2   0.1698  2007.0\n",
      "3   0.1816  2008.0\n",
      "4   0.2351  2009.0\n",
      "5   0.1311  2010.0\n",
      "6   0.1782  2011.0\n",
      "7   0.2307  2012.0\n",
      "8   0.1866  2013.0\n",
      "9   0.2342  2014.0\n",
      "10  0.2131  2015.0\n",
      "11  0.2501  2016.0\n",
      "12  0.1824  2017.0\n",
      "2018\n",
      "[0.252, 0.292, 0.29, 0.268, 0.272, 0.261, 0.307, 0.298, 0.253, 0.289]\n",
      "        GA    seed\n",
      "0   0.1841  2005.0\n",
      "1   0.2865  2006.0\n",
      "2   0.1698  2007.0\n",
      "3   0.1816  2008.0\n",
      "4   0.2351  2009.0\n",
      "5   0.1311  2010.0\n",
      "6   0.1782  2011.0\n",
      "7   0.2307  2012.0\n",
      "8   0.1866  2013.0\n",
      "9   0.2342  2014.0\n",
      "10  0.2131  2015.0\n",
      "11  0.2501  2016.0\n",
      "12  0.1824  2017.0\n",
      "13  0.2782  2018.0\n",
      "2019\n",
      "[0.094, 0.162, 0.197, 0.125, 0.133, 0.13, 0.113, 0.138, 0.103, 0.126]\n",
      "        GA    seed\n",
      "0   0.1841  2005.0\n",
      "1   0.2865  2006.0\n",
      "2   0.1698  2007.0\n",
      "3   0.1816  2008.0\n",
      "4   0.2351  2009.0\n",
      "5   0.1311  2010.0\n",
      "6   0.1782  2011.0\n",
      "7   0.2307  2012.0\n",
      "8   0.1866  2013.0\n",
      "9   0.2342  2014.0\n",
      "10  0.2131  2015.0\n",
      "11  0.2501  2016.0\n",
      "12  0.1824  2017.0\n",
      "13  0.2782  2018.0\n",
      "14  0.1321  2019.0\n",
      "2020\n",
      "[0.178, 0.235, 0.24, 0.233, 0.246, 0.211, 0.216, 0.231, 0.215, 0.205]\n",
      "        GA    seed\n",
      "0   0.1841  2005.0\n",
      "1   0.2865  2006.0\n",
      "2   0.1698  2007.0\n",
      "3   0.1816  2008.0\n",
      "4   0.2351  2009.0\n",
      "5   0.1311  2010.0\n",
      "6   0.1782  2011.0\n",
      "7   0.2307  2012.0\n",
      "8   0.1866  2013.0\n",
      "9   0.2342  2014.0\n",
      "10  0.2131  2015.0\n",
      "11  0.2501  2016.0\n",
      "12  0.1824  2017.0\n",
      "13  0.2782  2018.0\n",
      "14  0.1321  2019.0\n",
      "15  0.2210  2020.0\n",
      "2021\n",
      "[0.147, 0.112, 0.189, 0.176, 0.201, 0.162, 0.107, 0.264, 0.114, 0.135]\n",
      "        GA    seed\n",
      "0   0.1841  2005.0\n",
      "1   0.2865  2006.0\n",
      "2   0.1698  2007.0\n",
      "3   0.1816  2008.0\n",
      "4   0.2351  2009.0\n",
      "5   0.1311  2010.0\n",
      "6   0.1782  2011.0\n",
      "7   0.2307  2012.0\n",
      "8   0.1866  2013.0\n",
      "9   0.2342  2014.0\n",
      "10  0.2131  2015.0\n",
      "11  0.2501  2016.0\n",
      "12  0.1824  2017.0\n",
      "13  0.2782  2018.0\n",
      "14  0.1321  2019.0\n",
      "15  0.2210  2020.0\n",
      "16  0.1607  2021.0\n",
      "2022\n",
      "[0.065, 0.163, 0.167, 0.151, 0.162, 0.117, 0.138, 0.136, 0.076, 0.119]\n",
      "        GA    seed\n",
      "0   0.1841  2005.0\n",
      "1   0.2865  2006.0\n",
      "2   0.1698  2007.0\n",
      "3   0.1816  2008.0\n",
      "4   0.2351  2009.0\n",
      "5   0.1311  2010.0\n",
      "6   0.1782  2011.0\n",
      "7   0.2307  2012.0\n",
      "8   0.1866  2013.0\n",
      "9   0.2342  2014.0\n",
      "10  0.2131  2015.0\n",
      "11  0.2501  2016.0\n",
      "12  0.1824  2017.0\n",
      "13  0.2782  2018.0\n",
      "14  0.1321  2019.0\n",
      "15  0.2210  2020.0\n",
      "16  0.1607  2021.0\n",
      "17  0.1294  2022.0\n",
      "2023\n",
      "[0.188, 0.2, 0.274, 0.222, 0.215, 0.216, 0.204, 0.237, 0.143, 0.216]\n",
      "        GA    seed\n",
      "0   0.1841  2005.0\n",
      "1   0.2865  2006.0\n",
      "2   0.1698  2007.0\n",
      "3   0.1816  2008.0\n",
      "4   0.2351  2009.0\n",
      "5   0.1311  2010.0\n",
      "6   0.1782  2011.0\n",
      "7   0.2307  2012.0\n",
      "8   0.1866  2013.0\n",
      "9   0.2342  2014.0\n",
      "10  0.2131  2015.0\n",
      "11  0.2501  2016.0\n",
      "12  0.1824  2017.0\n",
      "13  0.2782  2018.0\n",
      "14  0.1321  2019.0\n",
      "15  0.2210  2020.0\n",
      "16  0.1607  2021.0\n",
      "17  0.1294  2022.0\n",
      "18  0.2115  2023.0\n",
      "2024\n",
      "[0.108, 0.184, 0.217, 0.221, 0.205, 0.177, 0.183, 0.178, 0.132, 0.157]\n",
      "        GA    seed\n",
      "0   0.1841  2005.0\n",
      "1   0.2865  2006.0\n",
      "2   0.1698  2007.0\n",
      "3   0.1816  2008.0\n",
      "4   0.2351  2009.0\n",
      "5   0.1311  2010.0\n",
      "6   0.1782  2011.0\n",
      "7   0.2307  2012.0\n",
      "8   0.1866  2013.0\n",
      "9   0.2342  2014.0\n",
      "10  0.2131  2015.0\n",
      "11  0.2501  2016.0\n",
      "12  0.1824  2017.0\n",
      "13  0.2782  2018.0\n",
      "14  0.1321  2019.0\n",
      "15  0.2210  2020.0\n",
      "16  0.1607  2021.0\n",
      "17  0.1294  2022.0\n",
      "18  0.2115  2023.0\n",
      "19  0.1762  2024.0\n"
     ]
    }
   ],
   "source": [
    "# Ablation study for GA: for every model year 2005-2024, score the 1-day-ahead\n",
    "# forecasts with 10 seeds (2025-2034) and record the mean CRPSS per model year.\n",
    "# Depends on names this cell does not define: date_strings, get_pred, crps.\n",
    "# Side effects: rewrites eval/GA_1_day_ahead.csv once per seed and writes the\n",
    "# final summary table to an absolute D:/ path at the end of the cell.\n",
    "summary = pd.DataFrame()\n",
    "\n",
    "for model_year in [str(m) for m in range(2005,2025)]:\n",
    "    print(model_year)\n",
    "    # crpss_result is read after this loop ends, so only the last (here: only) location is summarized.\n",
    "    for location in [\"GA\"]:\n",
    "        crpss_result = []\n",
    "        for seed in range(2025,2035): \n",
    "            \n",
    "            # Derive a reproducible list of 1200 integers in [1, 1000] from the outer seed;\n",
    "            # it is handed to get_pred (NOTE(review): presumably as per-sample RNG seeds - confirm).\n",
    "            random.seed(seed)\n",
    "            seed_list = [random.randint(1, 1000) for _ in range(1200)]\n",
    "\n",
    "            result = []\n",
    "            for time_now in date_strings:\n",
    "                # save_to / df_save do not depend on time_now; they are re-read on every iteration.\n",
    "                save_to = f\"eval/{location}_1_day_ahead.csv\"\n",
    "                df_save = pd.read_csv(save_to)\n",
    "\n",
    "                # model_path and data_path are not passed to get_pred below.\n",
    "                # NOTE(review): presumably get_pred reads them as notebook globals - confirm before removing.\n",
    "                model_path = f'./newest models and scalers/{location}_{model_year}.pt'\n",
    "                data_path = f'./eval/1_day_ahead/{location}_2023-{time_now}/'\n",
    "#                 traindata_path = f'./training/{model_year}_seed_60/{location}_0.csv'\n",
    "#                 traindata_path = f'./new_training/{model_year}/{location}_{model_year}.csv'\n",
    "                scaler_path = f'./newest models and scalers/{location}_{model_year}_scaler.gz'\n",
    "\n",
    "                result.append(get_pred(location, model_year, time_now, scaler_path, seed_list))\n",
    "\n",
    "            # Overwrite columns 2..12 of the evaluation CSV with the stacked predictions.\n",
    "            # The row trim differs per location ([5:-3] OR, [6:-5] HI, [7:-3] GA/TX); only the GA/TX branch runs here.\n",
    "            # NOTE(review): the trims presumably align the stacked rows with the CSV rows - confirm.\n",
    "            if location == \"OR\":\n",
    "                df_save.iloc[:,2:13] = pd.DataFrame(np.concatenate(result, axis = 0)[5:-3])\n",
    "                df_save.to_csv(save_to, index = False)\n",
    "            elif location == \"HI\":\n",
    "                df_save.iloc[:,2:13] = pd.DataFrame(np.concatenate(result, axis = 0)[6:-5])\n",
    "                df_save.to_csv(save_to, index = False)\n",
    "            elif location == \"GA\" or location == \"TX\":\n",
    "                df_save.iloc[:,2:13] = pd.DataFrame(np.concatenate(result, axis = 0)[7:-3])\n",
    "                df_save.to_csv(save_to, index = False)\n",
    "\n",
    "            # crps() receives the frame and the CSV path; the pred_crps / ref_crps columns read below\n",
    "            # are presumably added by it - TODO confirm against its definition.\n",
    "            data = pd.read_csv(f\"{save_to}\")\n",
    "            data = crps(data,f\"{save_to}\")\n",
    "            data_group = data.groupby(\"date\", as_index = False).mean(numeric_only = True)\n",
    "            # Per-date skill score; not read again in this cell.\n",
    "            data_group[\"crpss\"] = 1 - data_group[\"pred_crps\"]/data_group[\"ref_crps\"]\n",
    "\n",
    "            # Overall skill for this seed: 1 - (mean of per-date pred CRPS) / (mean of per-date ref CRPS).\n",
    "            crpss =  1 - np.mean(data_group[\"pred_crps\"])/np.mean(data_group[\"ref_crps\"])\n",
    "#             print(np.mean(data_group[\"pred_crps\"]),np.mean(data_group[\"ref_crps\"]))\n",
    "            # Rounded to 3 decimals before the per-year averaging below.\n",
    "            crpss_result.append(np.round(crpss,3))\n",
    "        \n",
    "        print(crpss_result)\n",
    "    # One row per seed; NOTE: the column named 'seed' actually stores the model year.\n",
    "    crpss_result = np.array(crpss_result).reshape(-1,1)\n",
    "    result_sum = pd.DataFrame(crpss_result)\n",
    "    result_sum[\"seed\"] = int(model_year)\n",
    "    result_sum.columns = [\"GA\",\"seed\"]\n",
    "    \n",
    "    # Column means collapse the seeds into one row: mean CRPSS plus the (constant) model year as a float.\n",
    "    result_sub = np.array(result_sum.mean(axis = 0)).reshape(1,-1)\n",
    "    result_sub = pd.DataFrame(result_sub)\n",
    "    result_sub.columns = [\"GA\",\"seed\"]\n",
    "    summary = pd.concat([summary,result_sub], axis = 0, ignore_index=True)\n",
    "    print(summary)\n",
    "\n",
    "# NOTE: absolute, machine-specific output path.\n",
    "summary.to_csv(f\"D:/Jupyter/Research/Competition/ICLR/model/forecast_evaluation/Ablation_LSTM_GA.csv\", index = False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "13889c03",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2005\n",
      "[0.059, -0.032, 0.086, 0.013, 0.076, 0.045, 0.045, 0.08, 0.054, 0.049]\n",
      "       HI    seed\n",
      "0  0.0475  2005.0\n",
      "2006\n",
      "[0.016, 0.045, 0.123, 0.07, 0.107, 0.064, 0.098, 0.095, 0.156, 0.083]\n",
      "       HI    seed\n",
      "0  0.0475  2005.0\n",
      "1  0.0857  2006.0\n",
      "2007\n",
      "[0.048, -0.027, 0.099, 0.017, 0.093, 0.035, 0.012, 0.04, 0.091, 0.08]\n",
      "       HI    seed\n",
      "0  0.0475  2005.0\n",
      "1  0.0857  2006.0\n",
      "2  0.0488  2007.0\n",
      "2008\n",
      "[0.036, -0.033, 0.091, 0.0, 0.087, 0.043, 0.068, 0.078, 0.112, 0.045]\n",
      "       HI    seed\n",
      "0  0.0475  2005.0\n",
      "1  0.0857  2006.0\n",
      "2  0.0488  2007.0\n",
      "3  0.0527  2008.0\n",
      "2009\n",
      "[0.039, 0.09, 0.144, 0.106, 0.13, 0.12, 0.12, 0.153, 0.153, 0.096]\n",
      "       HI    seed\n",
      "0  0.0475  2005.0\n",
      "1  0.0857  2006.0\n",
      "2  0.0488  2007.0\n",
      "3  0.0527  2008.0\n",
      "4  0.1151  2009.0\n",
      "2010\n",
      "[0.133, 0.052, 0.159, 0.106, 0.149, 0.12, 0.134, 0.172, 0.148, 0.147]\n",
      "       HI    seed\n",
      "0  0.0475  2005.0\n",
      "1  0.0857  2006.0\n",
      "2  0.0488  2007.0\n",
      "3  0.0527  2008.0\n",
      "4  0.1151  2009.0\n",
      "5  0.1320  2010.0\n",
      "2011\n",
      "[0.079, 0.043, 0.108, 0.049, 0.096, 0.089, 0.068, 0.131, 0.103, 0.092]\n",
      "       HI    seed\n",
      "0  0.0475  2005.0\n",
      "1  0.0857  2006.0\n",
      "2  0.0488  2007.0\n",
      "3  0.0527  2008.0\n",
      "4  0.1151  2009.0\n",
      "5  0.1320  2010.0\n",
      "6  0.0858  2011.0\n",
      "2012\n",
      "[0.07, -0.018, 0.082, 0.017, 0.086, 0.066, 0.086, 0.125, 0.072, 0.066]\n",
      "       HI    seed\n",
      "0  0.0475  2005.0\n",
      "1  0.0857  2006.0\n",
      "2  0.0488  2007.0\n",
      "3  0.0527  2008.0\n",
      "4  0.1151  2009.0\n",
      "5  0.1320  2010.0\n",
      "6  0.0858  2011.0\n",
      "7  0.0652  2012.0\n",
      "2013\n",
      "[0.035, -0.017, 0.053, 0.004, 0.07, 0.027, 0.044, 0.117, 0.065, 0.071]\n",
      "       HI    seed\n",
      "0  0.0475  2005.0\n",
      "1  0.0857  2006.0\n",
      "2  0.0488  2007.0\n",
      "3  0.0527  2008.0\n",
      "4  0.1151  2009.0\n",
      "5  0.1320  2010.0\n",
      "6  0.0858  2011.0\n",
      "7  0.0652  2012.0\n",
      "8  0.0469  2013.0\n",
      "2014\n",
      "[0.078, -0.012, 0.092, 0.025, 0.117, 0.082, 0.06, 0.089, 0.115, 0.042]\n",
      "       HI    seed\n",
      "0  0.0475  2005.0\n",
      "1  0.0857  2006.0\n",
      "2  0.0488  2007.0\n",
      "3  0.0527  2008.0\n",
      "4  0.1151  2009.0\n",
      "5  0.1320  2010.0\n",
      "6  0.0858  2011.0\n",
      "7  0.0652  2012.0\n",
      "8  0.0469  2013.0\n",
      "9  0.0688  2014.0\n",
      "2015\n",
      "[0.019, 0.058, 0.09, 0.067, 0.107, 0.054, 0.092, 0.105, 0.13, 0.064]\n",
      "        HI    seed\n",
      "0   0.0475  2005.0\n",
      "1   0.0857  2006.0\n",
      "2   0.0488  2007.0\n",
      "3   0.0527  2008.0\n",
      "4   0.1151  2009.0\n",
      "5   0.1320  2010.0\n",
      "6   0.0858  2011.0\n",
      "7   0.0652  2012.0\n",
      "8   0.0469  2013.0\n",
      "9   0.0688  2014.0\n",
      "10  0.0786  2015.0\n",
      "2016\n",
      "[0.004, -0.117, 0.108, 0.058, 0.059, 0.028, 0.018, -0.0, -0.002, 0.039]\n",
      "        HI    seed\n",
      "0   0.0475  2005.0\n",
      "1   0.0857  2006.0\n",
      "2   0.0488  2007.0\n",
      "3   0.0527  2008.0\n",
      "4   0.1151  2009.0\n",
      "5   0.1320  2010.0\n",
      "6   0.0858  2011.0\n",
      "7   0.0652  2012.0\n",
      "8   0.0469  2013.0\n",
      "9   0.0688  2014.0\n",
      "10  0.0786  2015.0\n",
      "11  0.0195  2016.0\n",
      "2017\n",
      "[0.073, -0.089, 0.043, -0.071, 0.067, 0.037, 0.014, 0.116, 0.006, -0.012]\n",
      "        HI    seed\n",
      "0   0.0475  2005.0\n",
      "1   0.0857  2006.0\n",
      "2   0.0488  2007.0\n",
      "3   0.0527  2008.0\n",
      "4   0.1151  2009.0\n",
      "5   0.1320  2010.0\n",
      "6   0.0858  2011.0\n",
      "7   0.0652  2012.0\n",
      "8   0.0469  2013.0\n",
      "9   0.0688  2014.0\n",
      "10  0.0786  2015.0\n",
      "11  0.0195  2016.0\n",
      "12  0.0184  2017.0\n",
      "2018\n",
      "[0.056, -0.052, 0.041, -0.045, 0.068, 0.022, 0.024, 0.101, 0.059, 0.028]\n",
      "        HI    seed\n",
      "0   0.0475  2005.0\n",
      "1   0.0857  2006.0\n",
      "2   0.0488  2007.0\n",
      "3   0.0527  2008.0\n",
      "4   0.1151  2009.0\n",
      "5   0.1320  2010.0\n",
      "6   0.0858  2011.0\n",
      "7   0.0652  2012.0\n",
      "8   0.0469  2013.0\n",
      "9   0.0688  2014.0\n",
      "10  0.0786  2015.0\n",
      "11  0.0195  2016.0\n",
      "12  0.0184  2017.0\n",
      "13  0.0302  2018.0\n",
      "2019\n",
      "[0.087, 0.069, 0.143, 0.084, 0.139, 0.107, 0.148, 0.163, 0.153, 0.146]\n",
      "        HI    seed\n",
      "0   0.0475  2005.0\n",
      "1   0.0857  2006.0\n",
      "2   0.0488  2007.0\n",
      "3   0.0527  2008.0\n",
      "4   0.1151  2009.0\n",
      "5   0.1320  2010.0\n",
      "6   0.0858  2011.0\n",
      "7   0.0652  2012.0\n",
      "8   0.0469  2013.0\n",
      "9   0.0688  2014.0\n",
      "10  0.0786  2015.0\n",
      "11  0.0195  2016.0\n",
      "12  0.0184  2017.0\n",
      "13  0.0302  2018.0\n",
      "14  0.1239  2019.0\n",
      "2020\n",
      "[0.071, 0.028, 0.115, 0.049, 0.117, 0.085, 0.096, 0.15, 0.132, 0.102]\n",
      "        HI    seed\n",
      "0   0.0475  2005.0\n",
      "1   0.0857  2006.0\n",
      "2   0.0488  2007.0\n",
      "3   0.0527  2008.0\n",
      "4   0.1151  2009.0\n",
      "5   0.1320  2010.0\n",
      "6   0.0858  2011.0\n",
      "7   0.0652  2012.0\n",
      "8   0.0469  2013.0\n",
      "9   0.0688  2014.0\n",
      "10  0.0786  2015.0\n",
      "11  0.0195  2016.0\n",
      "12  0.0184  2017.0\n",
      "13  0.0302  2018.0\n",
      "14  0.1239  2019.0\n",
      "15  0.0945  2020.0\n",
      "2021\n",
      "[0.096, -0.022, 0.107, 0.095, 0.126, 0.09, 0.079, 0.096, 0.094, 0.087]\n",
      "        HI    seed\n",
      "0   0.0475  2005.0\n",
      "1   0.0857  2006.0\n",
      "2   0.0488  2007.0\n",
      "3   0.0527  2008.0\n",
      "4   0.1151  2009.0\n",
      "5   0.1320  2010.0\n",
      "6   0.0858  2011.0\n",
      "7   0.0652  2012.0\n",
      "8   0.0469  2013.0\n",
      "9   0.0688  2014.0\n",
      "10  0.0786  2015.0\n",
      "11  0.0195  2016.0\n",
      "12  0.0184  2017.0\n",
      "13  0.0302  2018.0\n",
      "14  0.1239  2019.0\n",
      "15  0.0945  2020.0\n",
      "16  0.0848  2021.0\n",
      "2022\n",
      "[0.076, -0.052, 0.096, 0.03, 0.089, 0.065, 0.053, 0.091, 0.06, 0.072]\n",
      "        HI    seed\n",
      "0   0.0475  2005.0\n",
      "1   0.0857  2006.0\n",
      "2   0.0488  2007.0\n",
      "3   0.0527  2008.0\n",
      "4   0.1151  2009.0\n",
      "5   0.1320  2010.0\n",
      "6   0.0858  2011.0\n",
      "7   0.0652  2012.0\n",
      "8   0.0469  2013.0\n",
      "9   0.0688  2014.0\n",
      "10  0.0786  2015.0\n",
      "11  0.0195  2016.0\n",
      "12  0.0184  2017.0\n",
      "13  0.0302  2018.0\n",
      "14  0.1239  2019.0\n",
      "15  0.0945  2020.0\n",
      "16  0.0848  2021.0\n",
      "17  0.0580  2022.0\n",
      "2023\n",
      "[0.071, 0.043, 0.129, 0.095, 0.132, 0.088, 0.109, 0.135, 0.116, 0.12]\n",
      "        HI    seed\n",
      "0   0.0475  2005.0\n",
      "1   0.0857  2006.0\n",
      "2   0.0488  2007.0\n",
      "3   0.0527  2008.0\n",
      "4   0.1151  2009.0\n",
      "5   0.1320  2010.0\n",
      "6   0.0858  2011.0\n",
      "7   0.0652  2012.0\n",
      "8   0.0469  2013.0\n",
      "9   0.0688  2014.0\n",
      "10  0.0786  2015.0\n",
      "11  0.0195  2016.0\n",
      "12  0.0184  2017.0\n",
      "13  0.0302  2018.0\n",
      "14  0.1239  2019.0\n",
      "15  0.0945  2020.0\n",
      "16  0.0848  2021.0\n",
      "17  0.0580  2022.0\n",
      "18  0.1038  2023.0\n",
      "2024\n",
      "[0.108, 0.045, 0.147, 0.095, 0.15, 0.115, 0.123, 0.14, 0.14, 0.113]\n",
      "        HI    seed\n",
      "0   0.0475  2005.0\n",
      "1   0.0857  2006.0\n",
      "2   0.0488  2007.0\n",
      "3   0.0527  2008.0\n",
      "4   0.1151  2009.0\n",
      "5   0.1320  2010.0\n",
      "6   0.0858  2011.0\n",
      "7   0.0652  2012.0\n",
      "8   0.0469  2013.0\n",
      "9   0.0688  2014.0\n",
      "10  0.0786  2015.0\n",
      "11  0.0195  2016.0\n",
      "12  0.0184  2017.0\n",
      "13  0.0302  2018.0\n",
      "14  0.1239  2019.0\n",
      "15  0.0945  2020.0\n",
      "16  0.0848  2021.0\n",
      "17  0.0580  2022.0\n",
      "18  0.1038  2023.0\n",
      "19  0.1176  2024.0\n"
     ]
    }
   ],
   "source": [
    "# Ablation study for HI: for every model year 2005-2024, score the 1-day-ahead\n",
    "# forecasts with 10 seeds (2025-2034) and record the mean CRPSS per model year.\n",
    "# Depends on names this cell does not define: date_strings, get_pred, crps.\n",
    "# Side effects: rewrites eval/HI_1_day_ahead.csv once per seed and writes the\n",
    "# final summary table to an absolute D:/ path at the end of the cell.\n",
    "summary = pd.DataFrame()\n",
    "\n",
    "for model_year in [str(m) for m in range(2005,2025)]:\n",
    "    print(model_year)\n",
    "    # crpss_result is read after this loop ends, so only the last (here: only) location is summarized.\n",
    "    for location in [\"HI\"]:\n",
    "        crpss_result = []\n",
    "        for seed in range(2025,2035): \n",
    "            \n",
    "            # Derive a reproducible list of 1200 integers in [1, 1000] from the outer seed;\n",
    "            # it is handed to get_pred (NOTE(review): presumably as per-sample RNG seeds - confirm).\n",
    "            random.seed(seed)\n",
    "            seed_list = [random.randint(1, 1000) for _ in range(1200)]\n",
    "\n",
    "            result = []\n",
    "            for time_now in date_strings:\n",
    "                # save_to / df_save do not depend on time_now; they are re-read on every iteration.\n",
    "                save_to = f\"eval/{location}_1_day_ahead.csv\"\n",
    "                df_save = pd.read_csv(save_to)\n",
    "\n",
    "                # model_path and data_path are not passed to get_pred below.\n",
    "                # NOTE(review): presumably get_pred reads them as notebook globals - confirm before removing.\n",
    "                model_path = f'./newest models and scalers/{location}_{model_year}.pt'\n",
    "                data_path = f'./eval/1_day_ahead/{location}_2023-{time_now}/'\n",
    "#                 traindata_path = f'./training/{model_year}_seed_60/{location}_0.csv'\n",
    "#                 traindata_path = f'./new_training/{model_year}/{location}_{model_year}.csv'\n",
    "                scaler_path = f'./newest models and scalers/{location}_{model_year}_scaler.gz'\n",
    "\n",
    "                result.append(get_pred(location, model_year, time_now, scaler_path, seed_list))\n",
    "\n",
    "            # Overwrite columns 2..12 of the evaluation CSV with the stacked predictions.\n",
    "            # The row trim differs per location ([5:-3] OR, [6:-5] HI, [7:-3] GA/TX); only the HI branch runs here.\n",
    "            # NOTE(review): the trims presumably align the stacked rows with the CSV rows - confirm.\n",
    "            if location == \"OR\":\n",
    "                df_save.iloc[:,2:13] = pd.DataFrame(np.concatenate(result, axis = 0)[5:-3])\n",
    "                df_save.to_csv(save_to, index = False)\n",
    "            elif location == \"HI\":\n",
    "                df_save.iloc[:,2:13] = pd.DataFrame(np.concatenate(result, axis = 0)[6:-5])\n",
    "                df_save.to_csv(save_to, index = False)\n",
    "            elif location == \"GA\" or location == \"TX\":\n",
    "                df_save.iloc[:,2:13] = pd.DataFrame(np.concatenate(result, axis = 0)[7:-3])\n",
    "                df_save.to_csv(save_to, index = False)\n",
    "\n",
    "            # crps() receives the frame and the CSV path; the pred_crps / ref_crps columns read below\n",
    "            # are presumably added by it - TODO confirm against its definition.\n",
    "            data = pd.read_csv(f\"{save_to}\")\n",
    "            data = crps(data,f\"{save_to}\")\n",
    "            data_group = data.groupby(\"date\", as_index = False).mean(numeric_only = True)\n",
    "            # Per-date skill score; not read again in this cell.\n",
    "            data_group[\"crpss\"] = 1 - data_group[\"pred_crps\"]/data_group[\"ref_crps\"]\n",
    "\n",
    "            # Overall skill for this seed: 1 - (mean of per-date pred CRPS) / (mean of per-date ref CRPS).\n",
    "            crpss =  1 - np.mean(data_group[\"pred_crps\"])/np.mean(data_group[\"ref_crps\"])\n",
    "#             print(np.mean(data_group[\"pred_crps\"]),np.mean(data_group[\"ref_crps\"]))\n",
    "            # Rounded to 3 decimals before the per-year averaging below.\n",
    "            crpss_result.append(np.round(crpss,3))\n",
    "        \n",
    "        print(crpss_result)\n",
    "    # One row per seed; NOTE: the column named 'seed' actually stores the model year.\n",
    "    crpss_result = np.array(crpss_result).reshape(-1,1)\n",
    "    result_sum = pd.DataFrame(crpss_result)\n",
    "    result_sum[\"seed\"] = int(model_year)\n",
    "    result_sum.columns = [\"HI\",\"seed\"]\n",
    "    \n",
    "    # Column means collapse the seeds into one row: mean CRPSS plus the (constant) model year as a float.\n",
    "    result_sub = np.array(result_sum.mean(axis = 0)).reshape(1,-1)\n",
    "    result_sub = pd.DataFrame(result_sub)\n",
    "    result_sub.columns = [\"HI\",\"seed\"]\n",
    "    summary = pd.concat([summary,result_sub], axis = 0, ignore_index=True)\n",
    "    print(summary)\n",
    "\n",
    "# NOTE: absolute, machine-specific output path.\n",
    "summary.to_csv(f\"D:/Jupyter/Research/Competition/ICLR/model/forecast_evaluation/Ablation_LSTM_HI.csv\", index = False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "81ea7541",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d4e7e49b",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "358bb8cf",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "9e8c4e1e",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2005\n",
      "[0.375, 0.432, 0.424, 0.407, 0.45, 0.438, 0.425, 0.413, 0.429, 0.43]\n"
     ]
    },
    {
     "ename": "ValueError",
     "evalue": "cannot reshape array of size 10 into shape (3)",
     "output_type": "error",
     "traceback": [
      "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[1;31mValueError\u001b[0m                                Traceback (most recent call last)",
      "Cell \u001b[1;32mIn[10], line 46\u001b[0m\n\u001b[0;32m     42\u001b[0m     crpss_result\u001b[38;5;241m.\u001b[39mappend(np\u001b[38;5;241m.\u001b[39mround(crpss,\u001b[38;5;241m3\u001b[39m))\n\u001b[0;32m     44\u001b[0m \u001b[38;5;28mprint\u001b[39m(crpss_result)\n\u001b[1;32m---> 46\u001b[0m crpss_result \u001b[38;5;241m=\u001b[39m np\u001b[38;5;241m.\u001b[39marray(crpss_result)\u001b[38;5;241m.\u001b[39mreshape(\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m,\u001b[38;5;241m3\u001b[39m)\n\u001b[0;32m     47\u001b[0m result_sum \u001b[38;5;241m=\u001b[39m pd\u001b[38;5;241m.\u001b[39mDataFrame(crpss_result)\n\u001b[0;32m     48\u001b[0m result_sum[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mseed\u001b[39m\u001b[38;5;124m\"\u001b[39m] \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mint\u001b[39m(model_year)\n",
      "\u001b[1;31mValueError\u001b[0m: cannot reshape array of size 10 into shape (3)"
     ]
    }
   ],
   "source": [
    "# DISABLED: combined OR/GA/HI variant of the per-location ablation cells; every line is commented out.\n",
    "# The saved output of this cell ends in: ValueError: cannot reshape array of size 10 into shape (3).\n",
    "# Cause: crpss_result is reset for each location and collects one value per seed (10 seeds),\n",
    "# so the reshape(-1,3) inside the location loop is applied to 10 values and cannot succeed.\n",
    "# The per-location cells above use reshape(-1,1) with a single location column instead.\n",
    "# summary = pd.DataFrame()\n",
    "\n",
    "# for model_year in [str(m) for m in range(2005,2025)]:\n",
    "#     print(model_year)\n",
    "#     for location in [\"OR\",\"GA\",\"HI\"]:\n",
    "#         crpss_result = []\n",
    "#         for seed in range(2025,2035): \n",
    "            \n",
    "#             random.seed(seed)\n",
    "#             seed_list = [random.randint(1, 1000) for _ in range(1200)]\n",
    "\n",
    "#             result = []\n",
    "#             for time_now in date_strings:\n",
    "#                 save_to = f\"eval/{location}_1_day_ahead.csv\"\n",
    "#                 df_save = pd.read_csv(save_to)\n",
    "\n",
    "#                 model_path = f'./newest models and scalers/{location}_{model_year}.pt'\n",
    "# #                 data_path = f'./testing/{location}_2023-{time_now}/'\n",
    "#                 data_path = f'./eval/1_day_ahead/{location}_2023-{time_now}/'\n",
    "# #                 traindata_path = f'./training/{model_year}_seed_60/{location}_0.csv'\n",
    "# #                 traindata_path = f'./new_training/{model_year}/{location}_{model_year}.csv'\n",
    "#                 scaler_path = f'./newest models and scalers/{location}_{model_year}_scaler.gz'\n",
    "\n",
    "#                 result.append(get_pred(location, model_year, time_now, scaler_path, seed_list))\n",
    "\n",
    "#             if location == \"OR\":\n",
    "#                 df_save.iloc[:,2:13] = pd.DataFrame(np.concatenate(result, axis = 0)[5:-3])\n",
    "#                 df_save.to_csv(save_to, index = False)\n",
    "#             elif location == \"HI\":\n",
    "#                 df_save.iloc[:,2:13] = pd.DataFrame(np.concatenate(result, axis = 0)[6:-5])\n",
    "#                 df_save.to_csv(save_to, index = False)\n",
    "#             elif location == \"GA\" or location == \"TX\":\n",
    "#                 df_save.iloc[:,2:13] = pd.DataFrame(np.concatenate(result, axis = 0)[7:-3])\n",
    "#                 df_save.to_csv(save_to, index = False)\n",
    "\n",
    "#             data = pd.read_csv(f\"{save_to}\")\n",
    "#             data = crps(data,f\"{save_to}\")\n",
    "#             data_group = data.groupby(\"date\", as_index = False).mean(numeric_only = True)\n",
    "#             data_group[\"crpss\"] = 1 - data_group[\"pred_crps\"]/data_group[\"ref_crps\"]\n",
    "\n",
    "#             crpss =  1 - np.mean(data_group[\"pred_crps\"])/np.mean(data_group[\"ref_crps\"])\n",
    "#             crpss_result.append(np.round(crpss,3))\n",
    "        \n",
    "#         print(crpss_result)\n",
    "\n",
    "#         crpss_result = np.array(crpss_result).reshape(-1,3)\n",
    "#         result_sum = pd.DataFrame(crpss_result)\n",
    "#         result_sum[\"seed\"] = int(model_year)\n",
    "#         result_sum.columns = [\"OR\",\"GA\",\"HI\",\"seed\"]\n",
    "\n",
    "#         result_sub = np.array(result_sum.mean(axis = 0)).reshape(1,-1)\n",
    "#         result_sub = pd.DataFrame(result_sub)\n",
    "#         result_sub.columns = [\"OR\",\"GA\",\"HI\",\"seed\"]\n",
    "#         summary = pd.concat([summary,result_sub], axis = 0, ignore_index=True)\n",
    "#         print(summary)\n",
    "\n",
    "# summary.to_csv(f\"D:/Jupyter/Research/Competition/ICLR/model/forecast_evaluation/Ablation_LSTM_other.csv\", index = False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "84386a20",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
