{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "67a05e5e",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import torch\n",
    "from torch import nn, optim\n",
    "from torch.autograd import Variable\n",
    "import torch.nn.functional as F\n",
    "import pandas as pd\n",
    "import time\n",
    "import datetime\n",
    "from datetime import timedelta\n",
    "import math\n",
    "import matplotlib.pyplot as plt\n",
    "from torch.utils.data import Dataset, DataLoader\n",
    "from sklearn.preprocessing import StandardScaler\n",
    "from sklearn.preprocessing import MinMaxScaler\n",
    "import joblib\n",
    "from scipy.stats import norm\n",
    "import warnings\n",
    "import random\n",
    "import os\n",
    "\n",
    "from datetime import datetime, timedelta\n",
    "warnings.filterwarnings(\"ignore\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "bfb80175",
   "metadata": {},
   "outputs": [],
   "source": [
    "def continuous_ranked_probability_score(obs, fx, fx_prob):\n",
    "    \"\"\"Continuous Ranked Probability Score (CRPS).\n",
    "\n",
    "    .. math::\n",
    "\n",
    "        \\\\text{CRPS} = \\\\frac{1}{n} \\\\sum_{i=1}^n \\\\int_{-\\\\infty}^{\\\\infty}\n",
    "        (F_i(x) - \\\\mathbf{1} \\\\{x \\\\geq y_i \\\\})^2 dx\n",
    "\n",
    "    where :math:`F_i(x)` is the CDF of the forecast at time :math:`i`,\n",
    "    :math:`y_i` is the observation at time :math:`i`, and :math:`\\\\mathbf{1}`\n",
    "    is the indicator function that transforms the observation into a step\n",
    "    function (1 if :math:`x \\\\geq y`, 0 if :math:`x < y`). In other words, the\n",
    "    CRPS measures the difference between the forecast CDF and the empirical CDF\n",
    "    of the observation. The CRPS has the same units as the observation. Lower\n",
    "    CRPS values indicate more accurate forecasts, where a CRPS of 0 indicates a\n",
    "    perfect forecast. [1]_ [2]_ [3]_\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    obs : (n,) array_like\n",
    "        Observations (physical unit); flattened to shape (n,).\n",
    "    fx : (n, d) array_like\n",
    "        Forecasts (physical units) of the right-hand-side of a CDF with d\n",
    "        intervals (d >= 2), e.g., fx = [10 MW, 20 MW, 30 MW] is interpreted as\n",
    "        <= 10 MW, <= 20 MW, <= 30 MW.\n",
    "    fx_prob : (n, d) array_like\n",
    "        Probability [%] associated with the forecasts.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    crps : float\n",
    "        The Continuous Ranked Probability Score, with the same units as the\n",
    "        observation.\n",
    "\n",
    "    Raises\n",
    "    ------\n",
    "    ValueError\n",
    "        If the forecasts have incorrect dimensions; either a) the forecasts are\n",
    "        for a single sample (n=1) with d CDF intervals but are given as a 1D\n",
    "        array with d values or b) the forecasts are given as 2D arrays (n,d)\n",
    "        but do not contain at least 2 CDF intervals (i.e. d < 2).\n",
    "\n",
    "    Notes\n",
    "    -----\n",
    "    The CRPS can be calculated analytically when the forecast CDF is of a\n",
    "    continuous parametric distribution, e.g., Gaussian distribution. However,\n",
    "    since the Solar Forecast Arbiter makes no assumptions regarding how a\n",
    "    probabilistic forecast was generated, the CRPS is instead calculated using\n",
    "    numerical integration of the discretized forecast CDF. Therefore, the\n",
    "    accuracy of the CRPS calculation is limited by the precision of the\n",
    "    forecast CDF. In practice, this means the forecast CDF should 1) consist of\n",
    "    at least 10 intervals and 2) cover probabilities from 0% to 100%.\n",
    "\n",
    "    References\n",
    "    ----------\n",
    "    .. [1] Matheson and Winkler (1976) \"Scoring rules for continuous\n",
    "           probability distributions.\" Management Science, vol. 22, pp.\n",
    "           1087-1096. doi: 10.1287/mnsc.22.10.1087\n",
    "    .. [2] Hersbach (2000) \"Decomposition of the continuous ranked probability\n",
    "           score for ensemble prediction systems.\" Weather Forecast, vol. 15,\n",
    "           pp. 559-570. doi: 10.1175/1520-0434(2000)015<0559:DOTCRP>2.0.CO;2\n",
    "    .. [3] Wilks (2019) \"Statistical Methods in the Atmospheric Sciences\", 4th\n",
    "           ed. Oxford; Waltham, MA; Academic Press.\n",
    "\n",
    "    \"\"\"\n",
    "\n",
    "    # coerce inputs to float ndarrays so plain lists and object-dtype values\n",
    "    # (e.g. pulled out of a mixed-dtype DataFrame row) are accepted;\n",
    "    # obs is flattened to shape (n,)\n",
    "    obs = np.asarray(obs, dtype=float).reshape(-1)\n",
    "    fx = np.asarray(fx, dtype=float)\n",
    "    fx_prob = np.asarray(fx_prob, dtype=float)\n",
    "\n",
    "    # forecasts must be (n, d) with d >= 2 CDF intervals\n",
    "    if fx.ndim < 2:\n",
    "        raise ValueError(\"forecasts must be 2D arrays (expected (n,d), got \"\n",
    "                         f\"{fx.shape})\")\n",
    "    elif fx.shape[1] < 2:\n",
    "        raise ValueError(\"forecasts must have d >= 2 CDF intervals \"\n",
    "                         f\"(expected >= 2, got {fx.shape[1]})\")\n",
    "\n",
    "    n = len(fx)\n",
    "\n",
    "    # extend CDF min to ensure obs within forecast support\n",
    "    # fx.shape = (n, d) ==> (n, d + 1)\n",
    "    fx_min = np.minimum(obs, fx[:, 0])\n",
    "    fx = np.hstack([fx_min[:, np.newaxis], fx])\n",
    "    fx_prob = np.hstack([np.zeros([n, 1]), fx_prob])\n",
    "\n",
    "    # extend CDF max to ensure obs within forecast support\n",
    "    # fx.shape = (n, d + 1) ==> (n, d + 2)\n",
    "    idx = (fx[:, -1] < obs)\n",
    "    fx_max = np.maximum(obs, fx[:, -1])\n",
    "    fx = np.hstack([fx, fx_max[:, np.newaxis]])\n",
    "    fx_prob = np.hstack([fx_prob, np.full([n, 1], 100)])\n",
    "\n",
    "    # indicator function:\n",
    "    # - left of the obs is 0.0\n",
    "    # - obs and right of the obs is 1.0\n",
    "    o = np.where(fx >= obs[:, np.newaxis], 1.0, 0.0)\n",
    "\n",
    "    # correct behavior when obs > max fx:\n",
    "    # - should be 0 over range: max fx < x < obs\n",
    "    o[idx, -1] = 0.0\n",
    "\n",
    "    # forecast probabilities [unitless]\n",
    "    f = fx_prob / 100.0\n",
    "\n",
    "    # integrate along each sample, then average all samples\n",
    "    # (np.trapz was renamed np.trapezoid in NumPy 2.0; the old name is deprecated)\n",
    "    trapezoid = getattr(np, \"trapezoid\", None) or np.trapz\n",
    "    crps = np.mean(trapezoid((f - o) ** 2, x=fx, axis=1))\n",
    "\n",
    "    return crps"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "c82de078",
   "metadata": {},
   "outputs": [],
   "source": [
    "def _row_crps(truth, quantiles, prob_levels):\n",
    "    \"\"\"Return one CRPS value per row of `quantiles` (NaN where the observation is NaN).\n",
    "\n",
    "    truth       : (n,) float array of observations\n",
    "    quantiles   : (n, d) float array of forecast quantiles, one row per observation\n",
    "    prob_levels : (1, d) array of the matching probability levels in %\n",
    "    \"\"\"\n",
    "    scores = []\n",
    "    for obs, row in zip(truth, quantiles):\n",
    "        if np.isnan(obs):\n",
    "            scores.append(np.nan)\n",
    "            continue\n",
    "        scores.append(continuous_ranked_probability_score(\n",
    "            np.array([obs]), row.reshape(1, -1), prob_levels))\n",
    "    return scores\n",
    "\n",
    "\n",
    "def crps(data, path):\n",
    "    \"\"\"Add per-row CRPS columns for the model and reference forecasts and save to CSV.\n",
    "\n",
    "    Columns 2..12 of `data` are read as the model's 11 quantiles and columns\n",
    "    13..23 as the reference forecast's 11 quantiles; \"grid\" holds the observation.\n",
    "    Adds \"pred_crps\" and \"ref_crps\" to `data` in place (NaN where \"grid\" is\n",
    "    NaN), writes the frame to `path` without the index and returns it.\n",
    "    \"\"\"\n",
    "    # probability levels (in %) of the 11 quantile columns\n",
    "    prob_levels = np.arange(0, 101, 10).reshape(1, 11)\n",
    "    truth = data[\"grid\"].to_numpy(dtype=float)\n",
    "\n",
    "    # take both positional slices before any score column is appended\n",
    "    pred_quantiles = data.iloc[:, 2:13].to_numpy(dtype=float)\n",
    "    ref_quantiles = data.iloc[:, 13:24].to_numpy(dtype=float)\n",
    "\n",
    "    data[\"pred_crps\"] = _row_crps(truth, pred_quantiles, prob_levels)\n",
    "    data[\"ref_crps\"] = _row_crps(truth, ref_quantiles, prob_levels)\n",
    "\n",
    "    data.to_csv(path, index=False)\n",
    "    return data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "0e81f594",
   "metadata": {},
   "outputs": [],
   "source": [
    "class LSTM(nn.Module):\n",
    "    \"\"\"LSTM over a history window plus an MLP over future covariates -> 11 ordered quantiles.\n",
    "\n",
    "    forward(x, z) reshapes x to (batch, HISTORY_LEN, 6) and z to (batch, HORIZON, 14)\n",
    "    and returns a (batch, HORIZON, 1, 11) tensor that is non-decreasing along the\n",
    "    last axis.\n",
    "    \"\"\"\n",
    "\n",
    "    # window lengths used to reshape the flat inputs in forward()\n",
    "    HISTORY_LEN = 336\n",
    "    HORIZON = 24\n",
    "\n",
    "    def __init__(self, stateful=False):\n",
    "        super(LSTM, self).__init__()\n",
    "        # stateful=True keeps the legacy behaviour of seeding each call with the\n",
    "        # final (h, c) of the previous call. The default starts every window from\n",
    "        # a zero state, so outputs do not depend on call history and a change of\n",
    "        # batch size cannot hit a stale, wrongly-shaped state.\n",
    "        self.stateful = stateful\n",
    "\n",
    "        # keep this submodule layout: saved state_dicts use the keys\n",
    "        # lstm.*, ann2.0.*, ann2.2.* and out.0.*\n",
    "        self.lstm = nn.LSTM(\n",
    "            input_size=6,\n",
    "            hidden_size=60,\n",
    "            num_layers=1,\n",
    "            batch_first=True,\n",
    "        )\n",
    "        self.ann2 = nn.Sequential(\n",
    "            nn.Linear(14, 14),\n",
    "            nn.ReLU(),\n",
    "            nn.Linear(14, 1),\n",
    "        )\n",
    "        self.out = nn.Sequential(\n",
    "            nn.Linear(60, 11),\n",
    "        )\n",
    "        self.state = None\n",
    "\n",
    "    def forward(self, x, z):\n",
    "        batch = x.shape[0]\n",
    "        tx = torch.reshape(x, (batch, self.HISTORY_LEN, -1))\n",
    "        init_state = self.state if self.stateful else None\n",
    "        tx, new_state = self.lstm(tx, init_state)\n",
    "        if self.stateful:\n",
    "            self.state = new_state\n",
    "\n",
    "        tz = torch.reshape(z, (batch, self.HORIZON, -1))\n",
    "        # ann2 maps every future step to one scalar that is broadcast-added to all\n",
    "        # 60 hidden units of the matching step among the last HORIZON LSTM outputs\n",
    "        temp = self.ann2(tz)\n",
    "        out = self.out(tx[:, -self.HORIZON:, :] + temp)  # (batch, HORIZON, 11)\n",
    "\n",
    "        out = torch.unsqueeze(out, 2)  # (batch, HORIZON, 1, 11)\n",
    "        # the first value is a free base level; softplus turns the other 10 into\n",
    "        # positive increments, so the cumulative sum is non-decreasing\n",
    "        base = out[..., :1]\n",
    "        increments = F.softplus(out[..., 1:])\n",
    "        out = torch.cumsum(torch.cat((base, increments), dim=-1), dim=-1)\n",
    "\n",
    "        return out\n",
    "\n",
    "def load_model(model_path):\n",
    "    \"\"\"Rebuild an LSTM from the state_dict saved at `model_path` (mapped to CPU),\n",
    "    cast it to float64 and put it in eval mode.\"\"\"\n",
    "    # NOTE: torch.load relies on pickle under the hood -- only load trusted checkpoints\n",
    "    weights = torch.load(model_path, map_location=torch.device('cpu'))\n",
    "    net = LSTM()\n",
    "    net.load_state_dict(weights)\n",
    "    return net.double().eval()\n",
    "\n",
    "def read_data_cont(df):\n",
    "    \"\"\"Return the six continuous input features of `df` as a 2-D NumPy array.\"\"\"\n",
    "    cont_columns = ['consumption', 'solar', 'DNI', 'DHI', 'temperature', 'relativehumidity']\n",
    "    return df[cont_columns].values\n",
    "def read_data_target(df):\n",
    "    \"\"\"Return the `total_grid` column of `df` as an (n, 1) NumPy array.\"\"\"\n",
    "    target_columns = ['total_grid']\n",
    "    return df[target_columns].values\n",
    "def read_data_time(df):\n",
    "    \"\"\"Return the calendar features of `df` as an (n, 5) NumPy array, ordered\n",
    "    dayofweek, timeofday, month, dayofyear, year.\"\"\"\n",
    "    time_columns = ['dayofweek', 'timeofday', 'month', 'dayofyear', 'year']\n",
    "    return df[time_columns].values\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "de0defa5",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_pred(location, model_year, time_now, scaler_path, seed_list,\n",
    "             data_path=None, model_path=None):\n",
    "    \"\"\"Run one day-ahead quantile forecast for `location`.\n",
    "\n",
    "    Reads `{location}_history.csv` and `{location}_future.csv` from `data_path`,\n",
    "    perturbs the future weather columns with `simulate_forecast(..., seed_list)`,\n",
    "    scales the continuous features with the joblib scaler at `scaler_path` and\n",
    "    feeds the last 336 history rows plus the future covariates to the model\n",
    "    loaded from `model_path`.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    location : str\n",
    "        Site code used in the CSV file names (e.g. \"TX\").\n",
    "    model_year, time_now\n",
    "        Not used inside the function; kept for call compatibility.\n",
    "    scaler_path : str\n",
    "        Path of the fitted scaler, loaded with joblib.\n",
    "    seed_list : sequence of int\n",
    "        Seeds forwarded to `simulate_forecast`.\n",
    "    data_path, model_path : str, optional\n",
    "        Input directory and model state_dict path. When omitted, the notebook\n",
    "        globals of the same name are used (legacy calling convention).\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    numpy.ndarray\n",
    "        Predicted quantiles with one row per forecast step (the model emits\n",
    "        24 steps of 11 quantiles).\n",
    "\n",
    "    Raises\n",
    "    ------\n",
    "    FileNotFoundError\n",
    "        If the model or scaler file does not exist.\n",
    "    \"\"\"\n",
    "    # legacy convention: the evaluation loop sets these as notebook globals\n",
    "    if data_path is None:\n",
    "        data_path = globals()[\"data_path\"]\n",
    "    if model_path is None:\n",
    "        model_path = globals()[\"model_path\"]\n",
    "\n",
    "    # fail fast instead of printing and crashing later inside torch/joblib\n",
    "    for label, path in ((\"Model\", model_path), (\"Scaler\", scaler_path)):\n",
    "        if not os.path.exists(path):\n",
    "            raise FileNotFoundError(f\"{label} file {path} does not exist.\")\n",
    "\n",
    "    df_history = pd.read_csv(os.path.join(data_path, f\"{location}_history.csv\"))\n",
    "    df_future = pd.read_csv(os.path.join(data_path, f\"{location}_future.csv\"))\n",
    "    df_future = simulate_forecast(df_future.copy(), location, seed_list)\n",
    "\n",
    "    # make sure a total_grid column exists (filled with zeros)\n",
    "    if \"total_grid\" not in df_future.columns:\n",
    "        df_future.insert(0, 'total_grid', np.zeros(df_future.shape[0]))\n",
    "\n",
    "    model = load_model(model_path)\n",
    "    scaler = joblib.load(scaler_path)\n",
    "\n",
    "    def cyclical(values, period):\n",
    "        # (2, n) float32 [sin; cos] encoding of a periodic calendar feature\n",
    "        angle = values / period * 2 * np.pi\n",
    "        return np.stack((np.sin(angle), np.cos(angle))).astype(np.float32)\n",
    "\n",
    "    def read_test(df):\n",
    "        # 6 scaled continuous features followed by 9 time features: sin/cos of\n",
    "        # dayofweek, timeofday, month (shifted to 0-based) and dayofyear, then year - 2022\n",
    "        cont = scaler.transform(read_data_cont(df))\n",
    "        time_raw = read_data_time(df)\n",
    "        time_feats = np.concatenate((\n",
    "            cyclical(time_raw[:, 0], 7),\n",
    "            cyclical(time_raw[:, 1], 24),\n",
    "            cyclical(time_raw[:, 2] - 1, 12),\n",
    "            cyclical(time_raw[:, 3], 365),\n",
    "            (time_raw[:, 4] - 2022)[np.newaxis, :],\n",
    "        ), 0).T\n",
    "        return np.concatenate((cont, time_feats), 1)\n",
    "\n",
    "    # last 336 rows of the 6 continuous features\n",
    "    history_context = read_test(df_history)[-336:, :6]\n",
    "    # future covariates: drop column 0 (consumption), keep the other 5 + 9 time features\n",
    "    future_context = read_test(df_future)[:, 1:]\n",
    "\n",
    "    test_X = torch.unsqueeze(torch.tensor(history_context, dtype=torch.double), 0)\n",
    "    test_Z = torch.unsqueeze(torch.tensor(future_context, dtype=torch.double), 0)\n",
    "\n",
    "    with torch.no_grad():\n",
    "        preds = torch.squeeze(model(test_X, test_Z))\n",
    "\n",
    "    return preds.numpy()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "ee24e049",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-hour std of weather-forecast errors (old forecast minus the later \"truth\"\n",
    "# file) for every location/feature: std_summary gets one row per hour of day\n",
    "# (column \"hour\") and one column per \"{location}_{feature}\".\n",
    "# NOTE(review): this 2023-06-27 .. 2023-07-15 window overlaps the evaluation\n",
    "# window (2023-06-18 + 28 days), so the noise level is partly estimated in-sample.\n",
    "start_date = datetime.strptime(\"2023-06-27\", \"%Y-%m-%d\")\n",
    "date_strings = [(start_date + timedelta(days=i)).strftime(\"%Y-%m-%d\") for i in range(19)]\n",
    "std_summary = pd.DataFrame()\n",
    "\n",
    "for feature in [\"DNI\",\"DHI\",\"GHI\",\"temperature\",\"relativehumidity\"]:\n",
    "    for location in [\"GA\",\"HI\",\"OR\",\"TX\"]:\n",
    "        hour, prediction, ground_truth = [], [], []\n",
    "        for date in date_strings:\n",
    "            forecast = pd.read_csv(f\"eval/old_eval/{location}_{date}/{location}_future.csv\")\n",
    "            truth = pd.read_csv(f\"eval/1_day_ahead/{location}_{date}/{location}_future.csv\")\n",
    "            prediction.extend(forecast[feature])\n",
    "            ground_truth.extend(truth[feature])\n",
    "            hour.extend(truth[\"timeofday\"])\n",
    "\n",
    "        df = pd.DataFrame({\"hour\": hour, \"truth\": ground_truth, \"prediction\": prediction})\n",
    "        df[\"std(error)\"] = df[\"prediction\"] - df[\"truth\"]\n",
    "\n",
    "        # group once per (feature, location) instead of twice\n",
    "        hourly_std = df.groupby(\"hour\", as_index=False).std()\n",
    "        std_summary[\"hour\"] = hourly_std[\"hour\"]\n",
    "        std_summary[f\"{location}_{feature}\"] = hourly_std[\"std(error)\"]\n",
    "\n",
    "def generate_error(data, location, feature, seed, noise_scale=1.0, summary=None):\n",
    "    \"\"\"Draw zero-mean Gaussian noise for every row of `data`.\n",
    "\n",
    "    The std for a row is the value of column `{location}_{feature}` in\n",
    "    `summary` (default: the notebook global `std_summary`) at the entry whose\n",
    "    \"hour\" equals the row's `timeofday`, multiplied by `noise_scale`.\n",
    "    Draws come from a private `random.Random(seed)`, which yields the same\n",
    "    numbers as `random.seed(seed)` followed by `random.gauss`, but leaves the\n",
    "    global `random` state untouched.\n",
    "\n",
    "    Returns a list with one float per row of `data`.\n",
    "    \"\"\"\n",
    "    if summary is None:\n",
    "        summary = std_summary\n",
    "    # look hours up by label rather than by row position\n",
    "    hourly_std = summary.set_index(\"hour\")[f\"{location}_{feature}\"]\n",
    "    sigmas = hourly_std.loc[data[\"timeofday\"].to_numpy()].to_numpy()\n",
    "    rng = random.Random(seed)\n",
    "    return [rng.gauss(mu=0, sigma=sigma * noise_scale) for sigma in sigmas]\n",
    "\n",
    "\n",
    "def simulate_forecast(data, location, seed_list):\n",
    "    \"\"\"Turn the weather columns of `data` into a synthetic forecast by adding noise.\n",
    "\n",
    "    Adds noise from `generate_error` to DNI, DHI, temperature and\n",
    "    relativehumidity (seeded by seed_list[0], [1], [2], [3] respectively),\n",
    "    clips DNI/DHI to the per-location bounds below and relativehumidity to\n",
    "    [0, 100], and sets total_grid to 0. `data` is modified in place and\n",
    "    returned.\n",
    "\n",
    "    Raises\n",
    "    ------\n",
    "    ValueError\n",
    "        If no DNI/DHI bounds are configured for `location`.\n",
    "    \"\"\"\n",
    "    # location -> ((DNI_min, DNI_max), (DHI_min, DHI_max))\n",
    "    bounds = {\n",
    "        \"GA\": ((0, 1040), (0, 551)),\n",
    "        \"HI\": ((0, 989), (0, 567)),\n",
    "        \"OR\": ((0, 981), (0, 541)),\n",
    "        \"TX\": ((0, 1052), (0, 560)),\n",
    "    }\n",
    "    if location not in bounds:\n",
    "        raise ValueError(f\"no DNI/DHI bounds configured for location {location!r}\")\n",
    "    (dni_min, dni_max), (dhi_min, dhi_max) = bounds[location]\n",
    "\n",
    "    data[\"total_grid\"] = 0\n",
    "\n",
    "    data[\"DNI\"] = data[\"DNI\"] + generate_error(data, location, \"DNI\", seed_list[0])\n",
    "    data[\"DHI\"] = data[\"DHI\"] + generate_error(data, location, \"DHI\", seed_list[1])\n",
    "    data[\"DHI\"] = np.clip(data[\"DHI\"], dhi_min, dhi_max)\n",
    "    data[\"DNI\"] = np.clip(data[\"DNI\"], dni_min, dni_max)\n",
    "\n",
    "    data[\"temperature\"] = data[\"temperature\"] + generate_error(data, location, \"temperature\", seed_list[2])\n",
    "    data[\"relativehumidity\"] = data[\"relativehumidity\"] + generate_error(data, location, \"relativehumidity\", seed_list[3])\n",
    "    data[\"relativehumidity\"] = np.clip(data[\"relativehumidity\"], 0, 100)\n",
    "\n",
    "    return data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "59ac3fa9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Evaluation days 2023-06-18 .. 2023-07-15 as \"MM-DD\" strings; the year is\n",
    "# prepended again where paths are built. Replaces the earlier date_strings.\n",
    "start_date = datetime.strptime(\"2023-06-18\", \"%Y-%m-%d\")\n",
    "date_strings = [(start_date + timedelta(days=offset)).strftime(\"%m-%d\") for offset in range(28)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "95c64e7b",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2005\n",
      "[0.162, 0.311, 0.158, 0.29, 0.174, 0.096, 0.27, 0.202, 0.153, 0.117]\n",
      "       TX    seed\n",
      "0  0.1933  2005.0\n",
      "2006\n",
      "[0.398, 0.402, 0.39, 0.324, 0.432, 0.38, 0.427, 0.445, 0.47, 0.411]\n",
      "       TX    seed\n",
      "0  0.1933  2005.0\n",
      "1  0.4079  2006.0\n",
      "2007\n",
      "[0.505, 0.544, 0.459, 0.527, 0.525, 0.446, 0.527, 0.507, 0.46, 0.473]\n",
      "       TX    seed\n",
      "0  0.1933  2005.0\n",
      "1  0.4079  2006.0\n",
      "2  0.4973  2007.0\n",
      "2008\n",
      "[0.45, 0.494, 0.406, 0.488, 0.423, 0.383, 0.491, 0.445, 0.42, 0.388]\n",
      "       TX    seed\n",
      "0  0.1933  2005.0\n",
      "1  0.4079  2006.0\n",
      "2  0.4973  2007.0\n",
      "3  0.4388  2008.0\n",
      "2009\n",
      "[0.408, 0.374, 0.401, 0.344, 0.457, 0.381, 0.377, 0.427, 0.424, 0.428]\n",
      "       TX    seed\n",
      "0  0.1933  2005.0\n",
      "1  0.4079  2006.0\n",
      "2  0.4973  2007.0\n",
      "3  0.4388  2008.0\n",
      "4  0.4021  2009.0\n",
      "2010\n",
      "[0.41, 0.429, 0.447, 0.375, 0.508, 0.408, 0.458, 0.506, 0.471, 0.441]\n",
      "       TX    seed\n",
      "0  0.1933  2005.0\n",
      "1  0.4079  2006.0\n",
      "2  0.4973  2007.0\n",
      "3  0.4388  2008.0\n",
      "4  0.4021  2009.0\n",
      "5  0.4453  2010.0\n",
      "2011\n",
      "[0.445, 0.441, 0.396, 0.408, 0.486, 0.412, 0.432, 0.474, 0.459, 0.434]\n",
      "       TX    seed\n",
      "0  0.1933  2005.0\n",
      "1  0.4079  2006.0\n",
      "2  0.4973  2007.0\n",
      "3  0.4388  2008.0\n",
      "4  0.4021  2009.0\n",
      "5  0.4453  2010.0\n",
      "6  0.4387  2011.0\n",
      "2012\n",
      "[0.364, 0.327, 0.357, 0.251, 0.398, 0.34, 0.301, 0.34, 0.388, 0.357]\n",
      "       TX    seed\n",
      "0  0.1933  2005.0\n",
      "1  0.4079  2006.0\n",
      "2  0.4973  2007.0\n",
      "3  0.4388  2008.0\n",
      "4  0.4021  2009.0\n",
      "5  0.4453  2010.0\n",
      "6  0.4387  2011.0\n",
      "7  0.3423  2012.0\n",
      "2013\n",
      "[0.203, 0.172, 0.196, 0.007, 0.315, 0.195, 0.12, 0.293, 0.314, 0.206]\n",
      "       TX    seed\n",
      "0  0.1933  2005.0\n",
      "1  0.4079  2006.0\n",
      "2  0.4973  2007.0\n",
      "3  0.4388  2008.0\n",
      "4  0.4021  2009.0\n",
      "5  0.4453  2010.0\n",
      "6  0.4387  2011.0\n",
      "7  0.3423  2012.0\n",
      "8  0.2021  2013.0\n",
      "2014\n",
      "[0.247, 0.361, 0.197, 0.357, 0.214, 0.196, 0.326, 0.207, 0.174, 0.22]\n",
      "       TX    seed\n",
      "0  0.1933  2005.0\n",
      "1  0.4079  2006.0\n",
      "2  0.4973  2007.0\n",
      "3  0.4388  2008.0\n",
      "4  0.4021  2009.0\n",
      "5  0.4453  2010.0\n",
      "6  0.4387  2011.0\n",
      "7  0.3423  2012.0\n",
      "8  0.2021  2013.0\n",
      "9  0.2499  2014.0\n",
      "2015\n",
      "[0.464, 0.467, 0.447, 0.488, 0.456, 0.42, 0.475, 0.462, 0.427, 0.458]\n",
      "        TX    seed\n",
      "0   0.1933  2005.0\n",
      "1   0.4079  2006.0\n",
      "2   0.4973  2007.0\n",
      "3   0.4388  2008.0\n",
      "4   0.4021  2009.0\n",
      "5   0.4453  2010.0\n",
      "6   0.4387  2011.0\n",
      "7   0.3423  2012.0\n",
      "8   0.2021  2013.0\n",
      "9   0.2499  2014.0\n",
      "10  0.4564  2015.0\n",
      "2016\n",
      "[0.139, 0.152, 0.121, 0.044, 0.238, 0.199, 0.203, 0.264, 0.297, 0.171]\n",
      "        TX    seed\n",
      "0   0.1933  2005.0\n",
      "1   0.4079  2006.0\n",
      "2   0.4973  2007.0\n",
      "3   0.4388  2008.0\n",
      "4   0.4021  2009.0\n",
      "5   0.4453  2010.0\n",
      "6   0.4387  2011.0\n",
      "7   0.3423  2012.0\n",
      "8   0.2021  2013.0\n",
      "9   0.2499  2014.0\n",
      "10  0.4564  2015.0\n",
      "11  0.1828  2016.0\n",
      "2017\n",
      "[0.275, 0.215, 0.321, 0.104, 0.352, 0.265, 0.246, 0.355, 0.326, 0.298]\n",
      "        TX    seed\n",
      "0   0.1933  2005.0\n",
      "1   0.4079  2006.0\n",
      "2   0.4973  2007.0\n",
      "3   0.4388  2008.0\n",
      "4   0.4021  2009.0\n",
      "5   0.4453  2010.0\n",
      "6   0.4387  2011.0\n",
      "7   0.3423  2012.0\n",
      "8   0.2021  2013.0\n",
      "9   0.2499  2014.0\n",
      "10  0.4564  2015.0\n",
      "11  0.1828  2016.0\n",
      "12  0.2757  2017.0\n",
      "2018\n",
      "[0.502, 0.543, 0.448, 0.468, 0.521, 0.436, 0.52, 0.518, 0.513, 0.481]\n",
      "        TX    seed\n",
      "0   0.1933  2005.0\n",
      "1   0.4079  2006.0\n",
      "2   0.4973  2007.0\n",
      "3   0.4388  2008.0\n",
      "4   0.4021  2009.0\n",
      "5   0.4453  2010.0\n",
      "6   0.4387  2011.0\n",
      "7   0.3423  2012.0\n",
      "8   0.2021  2013.0\n",
      "9   0.2499  2014.0\n",
      "10  0.4564  2015.0\n",
      "11  0.1828  2016.0\n",
      "12  0.2757  2017.0\n",
      "13  0.4950  2018.0\n",
      "2019\n",
      "[0.458, 0.474, 0.429, 0.433, 0.474, 0.425, 0.473, 0.477, 0.44, 0.438]\n",
      "        TX    seed\n",
      "0   0.1933  2005.0\n",
      "1   0.4079  2006.0\n",
      "2   0.4973  2007.0\n",
      "3   0.4388  2008.0\n",
      "4   0.4021  2009.0\n",
      "5   0.4453  2010.0\n",
      "6   0.4387  2011.0\n",
      "7   0.3423  2012.0\n",
      "8   0.2021  2013.0\n",
      "9   0.2499  2014.0\n",
      "10  0.4564  2015.0\n",
      "11  0.1828  2016.0\n",
      "12  0.2757  2017.0\n",
      "13  0.4950  2018.0\n",
      "14  0.4521  2019.0\n",
      "2020\n",
      "[0.232, 0.248, 0.269, 0.143, 0.324, 0.255, 0.232, 0.296, 0.356, 0.269]\n",
      "        TX    seed\n",
      "0   0.1933  2005.0\n",
      "1   0.4079  2006.0\n",
      "2   0.4973  2007.0\n",
      "3   0.4388  2008.0\n",
      "4   0.4021  2009.0\n",
      "5   0.4453  2010.0\n",
      "6   0.4387  2011.0\n",
      "7   0.3423  2012.0\n",
      "8   0.2021  2013.0\n",
      "9   0.2499  2014.0\n",
      "10  0.4564  2015.0\n",
      "11  0.1828  2016.0\n",
      "12  0.2757  2017.0\n",
      "13  0.4950  2018.0\n",
      "14  0.4521  2019.0\n",
      "15  0.2624  2020.0\n",
      "2021\n",
      "[0.472, 0.471, 0.466, 0.456, 0.512, 0.445, 0.46, 0.486, 0.516, 0.485]\n",
      "        TX    seed\n",
      "0   0.1933  2005.0\n",
      "1   0.4079  2006.0\n",
      "2   0.4973  2007.0\n",
      "3   0.4388  2008.0\n",
      "4   0.4021  2009.0\n",
      "5   0.4453  2010.0\n",
      "6   0.4387  2011.0\n",
      "7   0.3423  2012.0\n",
      "8   0.2021  2013.0\n",
      "9   0.2499  2014.0\n",
      "10  0.4564  2015.0\n",
      "11  0.1828  2016.0\n",
      "12  0.2757  2017.0\n",
      "13  0.4950  2018.0\n",
      "14  0.4521  2019.0\n",
      "15  0.2624  2020.0\n",
      "16  0.4769  2021.0\n",
      "2022\n",
      "[0.398, 0.343, 0.391, 0.285, 0.467, 0.385, 0.39, 0.433, 0.44, 0.391]\n",
      "        TX    seed\n",
      "0   0.1933  2005.0\n",
      "1   0.4079  2006.0\n",
      "2   0.4973  2007.0\n",
      "3   0.4388  2008.0\n",
      "4   0.4021  2009.0\n",
      "5   0.4453  2010.0\n",
      "6   0.4387  2011.0\n",
      "7   0.3423  2012.0\n",
      "8   0.2021  2013.0\n",
      "9   0.2499  2014.0\n",
      "10  0.4564  2015.0\n",
      "11  0.1828  2016.0\n",
      "12  0.2757  2017.0\n",
      "13  0.4950  2018.0\n",
      "14  0.4521  2019.0\n",
      "15  0.2624  2020.0\n",
      "16  0.4769  2021.0\n",
      "17  0.3923  2022.0\n",
      "2023\n",
      "[0.381, 0.395, 0.388, 0.267, 0.437, 0.378, 0.39, 0.433, 0.452, 0.375]\n",
      "        TX    seed\n",
      "0   0.1933  2005.0\n",
      "1   0.4079  2006.0\n",
      "2   0.4973  2007.0\n",
      "3   0.4388  2008.0\n",
      "4   0.4021  2009.0\n",
      "5   0.4453  2010.0\n",
      "6   0.4387  2011.0\n",
      "7   0.3423  2012.0\n",
      "8   0.2021  2013.0\n",
      "9   0.2499  2014.0\n",
      "10  0.4564  2015.0\n",
      "11  0.1828  2016.0\n",
      "12  0.2757  2017.0\n",
      "13  0.4950  2018.0\n",
      "14  0.4521  2019.0\n",
      "15  0.2624  2020.0\n",
      "16  0.4769  2021.0\n",
      "17  0.3923  2022.0\n",
      "18  0.3896  2023.0\n",
      "2024\n",
      "[0.331, 0.298, 0.352, 0.164, 0.447, 0.337, 0.332, 0.446, 0.424, 0.303]\n",
      "        TX    seed\n",
      "0   0.1933  2005.0\n",
      "1   0.4079  2006.0\n",
      "2   0.4973  2007.0\n",
      "3   0.4388  2008.0\n",
      "4   0.4021  2009.0\n",
      "5   0.4453  2010.0\n",
      "6   0.4387  2011.0\n",
      "7   0.3423  2012.0\n",
      "8   0.2021  2013.0\n",
      "9   0.2499  2014.0\n",
      "10  0.4564  2015.0\n",
      "11  0.1828  2016.0\n",
      "12  0.2757  2017.0\n",
      "13  0.4950  2018.0\n",
      "14  0.4521  2019.0\n",
      "15  0.2624  2020.0\n",
      "16  0.4769  2021.0\n",
      "17  0.3923  2022.0\n",
      "18  0.3896  2023.0\n",
      "19  0.3434  2024.0\n"
     ]
    }
   ],
   "source": [
    "# Ablation over model vintages: for every model year, run the 28 simulated\n",
    "# day-ahead forecasts for 10 noise seeds and record the mean CRPSS against the\n",
    "# reference forecast columns.\n",
    "# NOTE(review): absolute local output path -- make it configurable before sharing.\n",
    "SUMMARY_PATH = \"D:/Jupyter/Research/Competition/ICLR/model/forecast_evaluation/Ablation_LSTM_TX.csv\"\n",
    "\n",
    "# (leading, trailing) rows dropped from the stacked daily forecasts so that they\n",
    "# line up with the rows of eval/{location}_1_day_ahead.csv\n",
    "TRIM_ROWS = {\"OR\": (5, 3), \"HI\": (6, 5), \"GA\": (7, 3), \"TX\": (7, 3)}\n",
    "\n",
    "summary_rows = []\n",
    "\n",
    "for model_year in [str(m) for m in range(2005, 2025)]:\n",
    "    print(model_year)\n",
    "    row = {}\n",
    "    for location in [\"TX\"]:\n",
    "        save_to = f\"eval/{location}_1_day_ahead.csv\"\n",
    "        model_path = f'./newest models and scalers/{location}_{model_year}.pt'\n",
    "        scaler_path = f'./newest models and scalers/{location}_{model_year}_scaler.gz'\n",
    "        head, tail = TRIM_ROWS[location]\n",
    "\n",
    "        crpss_result = []\n",
    "        for seed in range(2025, 2035):\n",
    "            # seeds for the simulated weather-forecast noise of this run\n",
    "            random.seed(seed)\n",
    "            seed_list = [random.randint(1, 1000) for _ in range(1200)]\n",
    "\n",
    "            result = []\n",
    "            for time_now in date_strings:\n",
    "                # get_pred reads the notebook globals `data_path` and `model_path`\n",
    "                data_path = f'./eval/1_day_ahead/{location}_2023-{time_now}/'\n",
    "                result.append(get_pred(location, model_year, time_now, scaler_path, seed_list))\n",
    "\n",
    "            stacked = np.concatenate(result, axis=0)\n",
    "            df_save = pd.read_csv(save_to)\n",
    "            df_save.iloc[:, 2:13] = pd.DataFrame(stacked[head:len(stacked) - tail])\n",
    "            df_save.to_csv(save_to, index=False)\n",
    "\n",
    "            data = crps(pd.read_csv(save_to), save_to)\n",
    "            data_group = data.groupby(\"date\", as_index=False).mean(numeric_only=True)\n",
    "            # CRPSS = 1 - mean daily CRPS(model) / mean daily CRPS(reference)\n",
    "            crpss = 1 - np.mean(data_group[\"pred_crps\"]) / np.mean(data_group[\"ref_crps\"])\n",
    "            crpss_result.append(np.round(crpss, 3))\n",
    "\n",
    "        print(crpss_result)\n",
    "        # NaN-skipping pandas mean over the seeds\n",
    "        row[location] = pd.Series(crpss_result, dtype=float).mean()\n",
    "    # this column holds the model year; the name \"seed\" is kept for CSV compatibility\n",
    "    row[\"seed\"] = float(model_year)\n",
    "    summary_rows.append(row)\n",
    "\n",
    "summary = pd.DataFrame(summary_rows)\n",
    "summary.to_csv(SUMMARY_PATH, index=False)\n",
    "summary"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "569c7465",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
