{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "19a887d4",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import torch\n",
    "from torch import nn, optim\n",
    "from torch.autograd import Variable\n",
    "import torch.nn.functional as F\n",
    "import pandas as pd\n",
    "import time\n",
    "import datetime\n",
    "from datetime import timedelta\n",
    "import math\n",
    "import matplotlib.pyplot as plt\n",
    "from torch.utils.data import Dataset, DataLoader\n",
    "from sklearn.preprocessing import StandardScaler\n",
    "import joblib\n",
    "from scipy.stats import norm\n",
    "import warnings\n",
    "import random\n",
    "import os\n",
    "\n",
    "from datetime import datetime, timedelta\n",
    "\n",
    "from sklearn.exceptions import InconsistentVersionWarning\n",
    "\n",
    "warnings.filterwarnings(\"ignore\", category=InconsistentVersionWarning)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "c6ee0cca",
   "metadata": {},
   "outputs": [],
   "source": [
    "def continuous_ranked_probability_score(obs, fx, fx_prob):\n",
    "    \"\"\"Continuous Ranked Probability Score (CRPS).\n",
    "\n",
    "    .. math::\n",
    "\n",
    "        \\\\text{CRPS} = \\\\frac{1}{n} \\\\sum_{i=1}^n \\\\int_{-\\\\infty}^{\\\\infty}\n",
    "        (F_i(x) - \\\\mathbf{1} \\\\{x \\\\geq y_i \\\\})^2 dx\n",
    "\n",
    "    where :math:`F_i(x)` is the CDF of the forecast at time :math:`i`,\n",
    "    :math:`y_i` is the observation at time :math:`i`, and :math:`\\\\mathbf{1}`\n",
    "    is the indicator function that transforms the observation into a step\n",
    "    function (1 if :math:`x \\\\geq y`, 0 if :math:`x < y`). In other words, the\n",
    "    CRPS measures the difference between the forecast CDF and the empirical CDF\n",
    "    of the observation. The CRPS has the same units as the observation. Lower\n",
    "    CRPS values indicate more accurate forecasts, where a CRPS of 0 indicates a\n",
    "    perfect forecast. [1]_ [2]_ [3]_\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    obs : (n,) array_like\n",
    "        Observations (physical unit).\n",
    "    fx : (n, d) array_like\n",
    "        Forecasts (physical units) of the right-hand-side of a CDF with d\n",
    "        intervals (d >= 2), e.g., fx = [10 MW, 20 MW, 30 MW] is interpreted as\n",
    "        <= 10 MW, <= 20 MW, <= 30 MW.\n",
    "    fx_prob : (n, d) array_like\n",
    "        Probability [%] associated with the forecasts.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    crps : float\n",
    "        The Continuous Ranked Probability Score, with the same units as the\n",
    "        observation.\n",
    "\n",
    "    Raises\n",
    "    ------\n",
    "    ValueError\n",
    "        If the forecasts have incorrect dimensions; either a) the forecasts are\n",
    "        for a single sample (n=1) with d CDF intervals but are given as a 1D\n",
    "        array with d values or b) the forecasts are given as 2D arrays (n,d)\n",
    "        but do not contain at least 2 CDF intervals (i.e. d < 2).\n",
    "\n",
    "    Notes\n",
    "    -----\n",
    "    The CRPS can be calculated analytically when the forecast CDF is of a\n",
    "    continuous parametric distribution, e.g., Gaussian distribution. However,\n",
    "    since the Solar Forecast Arbiter makes no assumptions regarding how a\n",
    "    probabilistic forecast was generated, the CRPS is instead calculated using\n",
    "    numerical integration of the discretized forecast CDF. Therefore, the\n",
    "    accuracy of the CRPS calculation is limited by the precision of the\n",
    "    forecast CDF. In practice, this means the forecast CDF should 1) consist of\n",
    "    at least 10 intervals and 2) cover probabilities from 0% to 100%.\n",
    "\n",
    "    References\n",
    "    ----------\n",
    "    .. [1] Matheson and Winkler (1976) \"Scoring rules for continuous\n",
    "           probability distributions.\" Management Science, vol. 22, pp.\n",
    "           1087-1096. doi: 10.1287/mnsc.22.10.1087\n",
    "    .. [2] Hersbach (2000) \"Decomposition of the continuous ranked probability\n",
    "           score for ensemble prediction systems.\" Weather Forecast, vol. 15,\n",
    "           pp. 559-570. doi: 10.1175/1520-0434(2000)015<0559:DOTCRP>2.0.CO;2\n",
    "    .. [3] Wilks (2019) \"Statistical Methods in the Atmospheric Sciences\", 4th\n",
    "           ed. Oxford; Waltham, MA; Academic Press.\n",
    "\n",
    "    \"\"\"\n",
    "\n",
    "    # match observations to fx shape: (n,) => (n, d)\n",
    "    if np.ndim(fx) < 2:\n",
    "        raise ValueError(\"forecasts must be 2D arrays (expected (n,d), got\"\n",
    "                         f\"{np.shape(fx)})\")\n",
    "    elif np.shape(fx)[1] < 2:\n",
    "        raise ValueError(\"forecasts must have d >= 2 CDF intervals \"\n",
    "                         f\"(expected >= 2, got {np.shape(fx)[1]})\")\n",
    "\n",
    "    n = len(fx)\n",
    "\n",
    "    # extend CDF min to ensure obs within forecast support\n",
    "    # fx.shape = (n, d) ==> (n, d + 1)\n",
    "    fx_min = np.minimum(obs, fx[:, 0])\n",
    "    fx = np.hstack([fx_min[:, np.newaxis], fx])\n",
    "    fx_prob = np.hstack([np.zeros([n, 1]), fx_prob])\n",
    "\n",
    "    # extend CDF max to ensure obs within forecast support\n",
    "    # fx.shape = (n, d + 1) ==> (n, d + 2)\n",
    "    idx = (fx[:, -1] < obs)\n",
    "    fx_max = np.maximum(obs, fx[:, -1])\n",
    "    fx = np.hstack([fx, fx_max[:, np.newaxis]])\n",
    "    fx_prob = np.hstack([fx_prob, np.full([n, 1], 100)])\n",
    "\n",
    "    # indicator function:\n",
    "    # - left of the obs is 0.0\n",
    "    # - obs and right of the obs is 1.0\n",
    "    o = np.where(fx >= obs[:, np.newaxis], 1.0, 0.0)\n",
    "\n",
    "    # correct behavior when obs > max fx:\n",
    "    # - should be 0 over range: max fx < x < obs\n",
    "    o[idx, -1] = 0.0\n",
    "\n",
    "    # forecast probabilities [unitless]\n",
    "    f = fx_prob / 100.0\n",
    "\n",
    "    # integrate along each sample, then average all samples\n",
    "    crps = np.mean(np.trapz((f - o) ** 2, x=fx, axis=1))\n",
    "\n",
    "    return crps"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "1c792326",
   "metadata": {},
   "outputs": [],
   "source": [
    "def crps(data,path):\n",
    "    crps_loss = []\n",
    "    for i in range(len(data)):\n",
    "        truth = np.array(data.iloc[i][\"grid\"]).reshape(-1)\n",
    "        if np.isnan(truth[0]):\n",
    "            crps_loss.append(np.nan)\n",
    "            continue\n",
    "        pred = np.array(data.iloc[i][2:13]).reshape(1,11)\n",
    "        prob = np.array([0,10,20,30,40,50,60,70,80,90,100]).reshape(1,11)\n",
    "        crps_loss.append(continuous_ranked_probability_score(truth, pred, prob))\n",
    "    data[\"pred_crps\"] = crps_loss\n",
    "\n",
    "    crps_loss = []\n",
    "    for i in range(len(data)):\n",
    "        truth = np.array(data.iloc[i][\"grid\"]).reshape(-1)\n",
    "        if np.isnan(truth[0]):\n",
    "            crps_loss.append(np.nan)\n",
    "            continue\n",
    "        pred = np.array(data.iloc[i][13:24]).reshape(1,11)\n",
    "        prob = np.array([0,10,20,30,40,50,60,70,80,90,100]).reshape(1,11)\n",
    "        crps_loss.append(continuous_ranked_probability_score(truth, pred, prob))\n",
    "    data[\"ref_crps\"] = crps_loss\n",
    "    \n",
    "    data.to_csv(path, index = False)\n",
    "    \n",
    "    return data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "b541b091",
   "metadata": {},
   "outputs": [],
   "source": [
    "class Attention(nn.Module):\n",
    "    def __init__(self, d=64):\n",
    "        super(Attention, self).__init__()\n",
    "        self.query = nn.Linear(32, d)\n",
    "        self.key = nn.Linear(32, d)\n",
    "        self.value = nn.Linear(32, d)\n",
    "        self.softmax = nn.Softmax(dim=1)\n",
    "        self.d = d\n",
    "\n",
    "    def forward(self, future, history):\n",
    "        query = self.query(future)\n",
    "        key = self.key(history)\n",
    "        value = self.value(history)\n",
    "        \n",
    "\n",
    "        attention_scores = torch.matmul(query, key.transpose(-2, -1)) / (self.d ** 0.5)\n",
    "        attention_weights = self.softmax(attention_scores)\n",
    "        attention_output = torch.matmul(attention_weights, value)\n",
    "\n",
    "        return attention_output\n",
    "\n",
    "\n",
    "class MLPAttention(nn.Module):\n",
    "    def __init__(self, dropout=0.2):\n",
    "        super(MLPAttention, self).__init__()\n",
    "        self.future_net = nn.Sequential(\n",
    "            nn.Linear(11, 64),\n",
    "            nn.ReLU(),\n",
    "            nn.Linear(64, 32),\n",
    "            nn.ReLU(),\n",
    "            nn.Dropout(dropout) \n",
    "        )\n",
    "        self.history_net = nn.Sequential(\n",
    "            nn.Linear(84, 128),\n",
    "            nn.ReLU(),\n",
    "            nn.Linear(128, 32),\n",
    "            nn.ReLU(),\n",
    "            nn.Dropout(dropout) \n",
    "        )\n",
    "        self.attention = Attention(64)\n",
    "        self.combined_net = nn.Sequential(\n",
    "            nn.Linear(64 + 32, 64),\n",
    "            nn.ReLU(),\n",
    "            nn.Linear(64, 11)\n",
    "        )\n",
    "\n",
    "    def forward(self, x):\n",
    "        future = x[:, :11]\n",
    "        history = x[:, 11:]\n",
    "        future_out = self.future_net(future)\n",
    "        history_out = self.history_net(history)\n",
    "        attention_out = self.attention(future_out, history_out)\n",
    "        combined = torch.cat((attention_out, future_out), dim=1)\n",
    "        out = self.combined_net(combined)\n",
    "\n",
    "        out = out.reshape(-1, 1, 1, 11)\n",
    "        x1, x2 = out[:, :, :, 0], out[:, :, :, 1:]\n",
    "        x1 = torch.unsqueeze(x1, -1)\n",
    "        x2 = F.softplus(x2)\n",
    "        out = torch.cat((x1, x2), dim=-1)\n",
    "        out = torch.cumsum(out, dim=-1)\n",
    "        \n",
    "        return out\n",
    "\n",
    "\n",
    "class MLPAttentionTX(nn.Module):\n",
    "    def __init__(self, dropout=0.2):\n",
    "        super(MLPAttentionTX, self).__init__()\n",
    "        self.future_net = nn.Sequential(\n",
    "            nn.Linear(14, 64),\n",
    "            nn.ReLU(),\n",
    "            nn.Linear(64, 32),\n",
    "            nn.ReLU(),\n",
    "            nn.Dropout(dropout)  \n",
    "        )\n",
    "        self.history_net = nn.Sequential(\n",
    "            nn.Linear(84, 128),\n",
    "            nn.ReLU(),\n",
    "            nn.Linear(128, 32),\n",
    "            nn.ReLU(),\n",
    "            nn.Dropout(dropout)  \n",
    "        )\n",
    "        self.attention = Attention(64)\n",
    "        self.combined_net = nn.Sequential(\n",
    "            nn.Linear(64 + 32, 64),\n",
    "            nn.ReLU(),\n",
    "            nn.Linear(64, 11)\n",
    "        )\n",
    "\n",
    "\n",
    "    def forward(self, x):\n",
    "        future = x[:, :14]\n",
    "        history = x[:, 14:]\n",
    "        future_out = self.future_net(future)\n",
    "        history_out = self.history_net(history)\n",
    "        attention_out = self.attention(future_out, history_out)\n",
    "        combined = torch.cat((attention_out, future_out), dim=1)\n",
    "        out = self.combined_net(combined)\n",
    "\n",
    "        out = out.reshape(-1, 1, 1, 11)\n",
    "        x1, x2 = out[:, :, :, 0], out[:, :, :, 1:]\n",
    "        x1 = torch.unsqueeze(x1, -1)\n",
    "        x2 = F.softplus(x2)\n",
    "        out = torch.cat((x1, x2), dim=-1)\n",
    "        out = torch.cumsum(out, dim=-1)\n",
    "        \n",
    "        return out\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "4eddf251",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "C:\\Users\\scukp\\AppData\\Local\\Temp\\ipykernel_26116\\1977948528.py:12: UserWarning: Could not infer format, so each element will be parsed individually, falling back to `dateutil`. To ensure parsing is consistent and as-expected, please specify a format.\n",
      "  forecast[\"localTime\"] = pd.to_datetime(forecast[\"localTime\"])\n",
      "C:\\Users\\scukp\\AppData\\Local\\Temp\\ipykernel_26116\\1977948528.py:12: UserWarning: Could not infer format, so each element will be parsed individually, falling back to `dateutil`. To ensure parsing is consistent and as-expected, please specify a format.\n",
      "  forecast[\"localTime\"] = pd.to_datetime(forecast[\"localTime\"])\n",
      "C:\\Users\\scukp\\AppData\\Local\\Temp\\ipykernel_26116\\1977948528.py:12: UserWarning: Could not infer format, so each element will be parsed individually, falling back to `dateutil`. To ensure parsing is consistent and as-expected, please specify a format.\n",
      "  forecast[\"localTime\"] = pd.to_datetime(forecast[\"localTime\"])\n",
      "C:\\Users\\scukp\\AppData\\Local\\Temp\\ipykernel_26116\\1977948528.py:12: UserWarning: Could not infer format, so each element will be parsed individually, falling back to `dateutil`. To ensure parsing is consistent and as-expected, please specify a format.\n",
      "  forecast[\"localTime\"] = pd.to_datetime(forecast[\"localTime\"])\n",
      "C:\\Users\\scukp\\AppData\\Local\\Temp\\ipykernel_26116\\1977948528.py:12: UserWarning: Could not infer format, so each element will be parsed individually, falling back to `dateutil`. To ensure parsing is consistent and as-expected, please specify a format.\n",
      "  forecast[\"localTime\"] = pd.to_datetime(forecast[\"localTime\"])\n"
     ]
    }
   ],
   "source": [
    "start_date = datetime.strptime(\"2023-06-27\", \"%Y-%m-%d\")\n",
    "date_strings = [(start_date + timedelta(days=i)).strftime(\"%Y-%m-%d\") for i in range(19)]\n",
    "std_summary = pd.DataFrame()\n",
    "\n",
    "for feature in [\"DNI\",\"DHI\",\"GHI\",\"temperature\",\"relativehumidity\"]:\n",
    "    for location in [\"GA\",\"HI\",\"OR\",\"TX\"]:\n",
    "        hour = []\n",
    "        prediction = []\n",
    "        ground_truth = []\n",
    "        for date in date_strings:\n",
    "            forecast = pd.read_csv(f\"eval/old_eval/{location}_{date}/{location}_future.csv\") \n",
    "            forecast[\"localTime\"] = pd.to_datetime(forecast[\"localTime\"])\n",
    "            prediction.extend(forecast[feature])\n",
    "\n",
    "            truth = pd.read_csv(f\"eval/1_day_ahead/{location}_{date}/{location}_future.csv\") \n",
    "            truth[\"localTime\"] = pd.to_datetime(truth[\"localTime\"])\n",
    "            ground_truth.extend(truth[feature])\n",
    "\n",
    "            hour.extend(truth[\"timeofday\"])\n",
    "        \n",
    "        df = pd.DataFrame()\n",
    "        df[\"hour\"] = hour\n",
    "        df[\"truth\"] = ground_truth\n",
    "        df[\"prediction\"] = prediction\n",
    "        df[\"std(error)\"] = df[\"prediction\"] - df[\"truth\"]\n",
    "\n",
    "        std_summary[\"hour\"] = df.groupby(\"hour\", as_index = False).std()[\"hour\"]\n",
    "        std_summary[f\"{location}_{feature}\"] = df.groupby(\"hour\", as_index = False).std()[\"std(error)\"]\n",
    "\n",
    "def generate_error(data, location, feature, seed):\n",
    "    \n",
    "    random.seed(seed)\n",
    "    \n",
    "    error = [random.gauss(mu=0, sigma=std * 1) for std in std_summary.iloc[data[\"timeofday\"]][f\"{location}_{feature}\"]]\n",
    "    \n",
    "    return error\n",
    "\n",
    "\n",
    "def simulate_forecast(data,location,seed_list):\n",
    "    \n",
    "\n",
    "    if location == \"GA\":\n",
    "        DNI_min, DNI_max = 0, 1040\n",
    "        DHI_min, DHI_max = 0, 551\n",
    "    elif location == \"HI\":\n",
    "        DNI_min, DNI_max = 0, 989\n",
    "        DHI_min, DHI_max = 0, 567\n",
    "    elif location == \"OR\":\n",
    "        DNI_min, DNI_max = 0, 981\n",
    "        DHI_min, DHI_max = 0, 541 \n",
    "    elif location == \"TX\":\n",
    "        DNI_min, DNI_max = 0, 1052\n",
    "        DHI_min, DHI_max = 0, 560\n",
    "\n",
    "    if location == \"TX\":\n",
    "        tem_min , tem_max = -8, 40\n",
    "    elif location == \"GA\":\n",
    "        tem_min , tem_max = -4, 38\n",
    "    elif location == \"HI\":\n",
    "        tem_min, tem_max = 20, 28\n",
    "    elif location == \"OR\":\n",
    "        tem_min, tem_max = -7, 38\n",
    "        \n",
    "\n",
    "    data[\"total_grid\"] = 0\n",
    "    \n",
    "    data[\"DNI\"] = data[\"DNI\"] + generate_error(data, location, \"DNI\", seed_list[0])\n",
    "    data[\"DHI\"] = data[\"DHI\"] + generate_error(data, location, \"DHI\", seed_list[1])\n",
    "    \n",
    "    data[\"DHI\"] = np.clip(data[\"DHI\"], DHI_min, DHI_max)\n",
    "    data[\"DNI\"] = np.clip(data[\"DNI\"], DNI_min, DNI_max)\n",
    "    \n",
    "    \n",
    "    data[\"temperature\"] = data[\"temperature\"] + generate_error(data, location, \"temperature\", seed_list[2])\n",
    "    data[\"relativehumidity\"] = data[\"relativehumidity\"] + generate_error(data, location, \"relativehumidity\", seed_list[3])\n",
    "    data[\"relativehumidity\"] = np.clip(data[\"relativehumidity\"], 0, 100)\n",
    "    data[\"temperature\"] = np.clip(data[\"temperature\"], tem_min ,tem_max)\n",
    "    \n",
    "  \n",
    "    return data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "ed15103c",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_pred(location, data_path, date, data_seed, seed_list):\n",
    "    \n",
    "    look_back = 336\n",
    "    gap = 24\n",
    "    masked = 1\n",
    "    task_type = 'MLPAttention'\n",
    "    model_path = f'models_and_scalers/{data_seed}/'\n",
    "\n",
    "    df_history = pd.read_csv(os.path.join(data_path, f\"{location}_history.csv\"))\n",
    "    df_future = pd.read_csv(os.path.join(data_path, f\"{location}_future.csv\"))\n",
    "\n",
    "    ###\n",
    "    df_future = simulate_forecast(df_future.copy(),location,seed_list)\n",
    "    ###\n",
    "\n",
    "    if \"total_grid\" not in df_future.columns:\n",
    "        new_col = np.zeros(df_future.shape[0])\n",
    "        df_future.insert(0, 'total_grid', new_col)\n",
    "\n",
    "\n",
    "    df_history = df_history.iloc[-336:]\n",
    "    df_history.index = range(len(df_history))\n",
    "    tod = df_history[\"timeofday\"]\n",
    "    if tod[0] != 1:\n",
    "        stop\n",
    "\n",
    "    if \"total_grid\" not in df_future.columns:\n",
    "        df_future[\"total_grid\"] = 0\n",
    "\n",
    "    df = pd.concat([df_history, df_future])\n",
    "\n",
    "    df['date'] = pd.to_datetime(df['year'].astype(str) + df['dayofyear'].astype(str), format='%Y%j')\n",
    "    df['weekday'] = df['date'].dt.weekday\n",
    "\n",
    "\n",
    "    data_cont = df[['consumption','solar','DNI','DHI','temperature','relativehumidity']]  \n",
    "\n",
    "    data_cont = data_cont.values\n",
    "    data_time = df[['weekday','dayofyear', 'timeofday', 'month','year']]\n",
    "    data_time = data_time.values\n",
    "\n",
    "    scaler = joblib.load(model_path+location+'_L='+str(look_back)+'_gap='+str(gap)+'_'+task_type+'_scaler.gz')\n",
    "    data_cont = scaler.transform(data_cont)\n",
    "\n",
    "\n",
    "    if location != \"TX\":\n",
    "        cycl_ = lambda x,period : torch.tensor((np.sin(x / period * 2 * np.pi),np.cos(x / period * 2 * np.pi))).type(torch.float32)\n",
    "        W = cycl_(data_time[:,0],7)    # week of day\n",
    "        H = cycl_(data_time[:,2],24)   # timeslot of the day\n",
    "        M = cycl_(data_time[:,3],12)   # month of year\n",
    "        data_time = np.concatenate((W,H,M),0).T\n",
    "        data_context = np.expand_dims(np.concatenate((data_cont,data_time),1), axis=0)\n",
    "    else:\n",
    "        cycl_ = lambda x,period : torch.tensor((np.sin(x / period * 2 * np.pi),np.cos(x / period * 2 * np.pi))).type(torch.float32)\n",
    "        W = cycl_(data_time[:,0],7)    # week of day\n",
    "        D = cycl_(data_time[:,1],365)  # day of year\n",
    "        H = cycl_(data_time[:,2],24)   # timeslot of the day\n",
    "        M = cycl_(data_time[:,3],12)   # month of year\n",
    "        Y = torch.tensor(data_time[:,4] - 2022).unsqueeze(0) # year\n",
    "        data_time = np.concatenate((W,D,H,M,Y),0).T\n",
    "        data_context = np.expand_dims(np.concatenate((data_cont,data_time),1), axis=0)\n",
    "\n",
    "    hist_input = data_context[0,:look_back,:]\n",
    "    future_input = data_context[0,look_back:,masked:]\n",
    "\n",
    "    def helper_test(x_train_hist, x_train_future):\n",
    "\n",
    "            x_train = []\n",
    "            for k in range(24):\n",
    "                x_train_sub = []\n",
    "                x_train_sub.extend(x_train_future[k])\n",
    "                for d in range(14):\n",
    "                    x_train_sub.extend(x_train_hist[d*24 + k,:6])\n",
    "                x_train.append(np.array(x_train_sub).reshape(1,-1))\n",
    "            x_test = np.concatenate(x_train)\n",
    "\n",
    "            return x_test\n",
    "\n",
    "    x_test = helper_test(hist_input, future_input)\n",
    "    \n",
    "    if location != \"TX\":\n",
    "        model = MLPAttention()\n",
    "    else:\n",
    "        model = MLPAttentionTX()\n",
    "    \n",
    "    model_path = model_path + f'{location}_MLPAttention.pth'\n",
    "    model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))\n",
    "    \n",
    "    x_test = torch.tensor(x_test, dtype=torch.float32)\n",
    "    pred = torch.squeeze(model(x_test)).detach().numpy()\n",
    "    \n",
    "    return pred\n",
    "    \n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "f52f4cfd",
   "metadata": {},
   "outputs": [],
   "source": [
    "start_date = datetime.strptime(\"2023-06-18\", \"%Y-%m-%d\")\n",
    "date_strings = [(start_date + timedelta(days=i)).strftime(\"%Y-%m-%d\") for i in range(28)]\n",
    "date_strings = [k[5:] for k in date_strings]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "2cfbb325",
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "       TX      GA      OR     HI    seed\n",
      "0  0.2505  0.1715  0.3835 -0.033  2005.0\n"
     ]
    }
   ],
   "source": [
    "summary = pd.DataFrame()\n",
    "for model_year in [str(m) for m in range(2005,2025)]:\n",
    "    crpss_result = []\n",
    "    for seed in range(2025,2035): \n",
    "        for location in [\"TX\",\"GA\",\"OR\",\"HI\"]:\n",
    "            \n",
    "            random.seed(seed)\n",
    "            \n",
    "            seed_dict = {}\n",
    "            seed_dict[\"OR\"] = [random.randint(2035,3035) for _ in range(1200)]\n",
    "            seed_dict[\"HI\"] = [random.randint(2035,3035) for _ in range(1200)]\n",
    "            seed_dict[\"GA\"] = [random.randint(2035,3035) for _ in range(1200)]\n",
    "            seed_dict[\"TX\"] = [random.randint(2035,3035) for _ in range(1200)]\n",
    "            \n",
    "            seed_list = seed_dict[location]\n",
    "\n",
    "            result = []\n",
    "            for time_now in date_strings:\n",
    "                save_to = f\"eval/{location}_1_day_ahead.csv\"\n",
    "                df_save = pd.read_csv(save_to)\n",
    "                \n",
    "                data_path = f'eval/1_day_ahead/{location}_2023-{time_now}/'\n",
    "                \n",
    "                result.append(get_pred(location, data_path, time_now, model_year,seed_list))\n",
    "\n",
    "            if location == \"OR\":\n",
    "                df_save.iloc[:,2:13] = pd.DataFrame(np.concatenate(result, axis = 0)[5:-3])\n",
    "                df_save.to_csv(save_to, index = False)\n",
    "            elif location == \"HI\":\n",
    "                df_save.iloc[:,2:13] = pd.DataFrame(np.concatenate(result, axis = 0)[6:-5])\n",
    "                df_save.to_csv(save_to, index = False)\n",
    "            elif location == \"GA\" or location == \"TX\":\n",
    "                df_save.iloc[:,2:13] = pd.DataFrame(np.concatenate(result, axis = 0)[7:-3])\n",
    "                df_save.to_csv(save_to, index = False)\n",
    "\n",
    "            data = pd.read_csv(f\"{save_to}\")\n",
    "            data = crps(data,f\"{save_to}\")\n",
    "            data_group = data.groupby(\"date\", as_index = False).mean(numeric_only = True)\n",
    "            data_group[\"crpss\"] = 1 - data_group[\"pred_crps\"]/data_group[\"ref_crps\"]\n",
    "\n",
    "            crpss =  1 - np.mean(data_group[\"pred_crps\"])/np.mean(data_group[\"ref_crps\"])\n",
    "            crpss_result.append(np.round(crpss,3))\n",
    "            \n",
    "    crpss_result = np.array(crpss_result).reshape(-1,4)\n",
    "    result_sum = pd.DataFrame(crpss_result)\n",
    "    result_sum[\"seed\"] = int(model_year)\n",
    "    result_sum.columns = [\"TX\",\"GA\",\"OR\",\"HI\",\"seed\"]\n",
    "    \n",
    "    result_sub = np.array(result_sum.mean(axis = 0)).reshape(1,-1)\n",
    "    result_sub = pd.DataFrame(result_sub)\n",
    "    result_sub.columns = [\"TX\",\"GA\",\"OR\",\"HI\",\"seed\"]\n",
    "    summary = pd.concat([summary,result_sub], axis = 0, ignore_index=True)\n",
    "    \n",
    "    if model_year == \"2015\":\n",
    "        break\n",
    "\n",
    "summary.to_csv(f\"../../forecast_evaluation/Ablation_MLP_NEW.csv\", index = False)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
