{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "a7677ba0",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import torch\n",
    "from torch import nn, optim\n",
    "from torch.autograd import Variable\n",
    "import torch.nn.functional as F\n",
    "import pandas as pd\n",
    "import time\n",
    "import math\n",
    "import matplotlib.pyplot as plt\n",
    "from torch.utils.data import Dataset, DataLoader\n",
    "from sklearn.preprocessing import StandardScaler\n",
    "import joblib\n",
    "from src.models.TiTR_solar import TiTR\n",
    "from src.learner import get_model\n",
    "from src.metrics import *\n",
    "from scipy.stats import norm\n",
    "import warnings\n",
    "import os\n",
    "from datetime import datetime, timedelta\n",
    "import model\n",
    "import collections\n",
    "\n",
    "warnings.filterwarnings(\"ignore\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "e0e3ae6c",
   "metadata": {},
   "outputs": [],
   "source": [
    "import random\n",
    "def set_seed(seed):\n",
    "    torch.manual_seed(seed)\n",
    "    torch.cuda.manual_seed(seed)\n",
    "    torch.cuda.manual_seed_all(seed)\n",
    "    np.random.seed(seed)\n",
    "    random.seed(seed)\n",
    "    torch.backends.cudnn.deterministic = True\n",
    "    torch.backends.cudnn.benchmark = False"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "222a82be",
   "metadata": {},
   "outputs": [],
   "source": [
    " def create_model(c_in, params):\n",
    "            model =  TiTR(c_in=c_in, context_window=params.context_window, target_window=params.target_window, \n",
    "                          patch_len = params.patch_len, stride=params.stride, padding_patch = params.padding_patch,\n",
    "                          patch_len2 = params.patch_len2, stride2=params.stride2, padding_patch2 = params.padding_patch2,\n",
    "                                n_layers=params.n_layers,\n",
    "                                    n_heads=params.n_heads,\n",
    "                                    d_model=params.d_model,\n",
    "                                    d_ff=params.d_ff,                        \n",
    "                                    dropout=params.dropout,\n",
    "                                    fc_dropout=params.fc_dropout,\n",
    "                                    head_dropout = params.head_dropout,\n",
    "                                    act='gelu',\n",
    "                                    individual = params.individual,\n",
    "                                    revin=params.revin,\n",
    "                                    affine=params.affine,\n",
    "                                    head_type = params.head_type,\n",
    "                                    probablistic=params.probablistic,\n",
    "                                    n_quantiles=params.n_quantiles,\n",
    "                                    masked=params.masked\n",
    "                                    )\n",
    "            return model\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "fe0e011f",
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_model(file, model, opt=None, with_opt=False, device='cpu', strict=True):\n",
    "    \" load the saved model \"\n",
    "    state = torch.load(file, map_location=device)\n",
    "    if not opt: with_opt=False\n",
    "    model_state = state['model'] if with_opt else state\n",
    "    get_model(model).load_state_dict(model_state, strict=strict)\n",
    "    if with_opt: opt.load_state_dict(state['opt'])\n",
    "    model = model.to(device)\n",
    "    return model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "029deb7f",
   "metadata": {},
   "outputs": [],
   "source": [
    "def pred_func(params,input_data,place,gap,look_back):\n",
    "\n",
    "    model = create_model(input_data.shape[1], params)\n",
    "\n",
    "    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')\n",
    "    model_pred = load_model(file = model_path + place+'_L='+str(look_back)+'_gap='+str(gap) +'_'+params.task_type + \".pth\", model = model, device=device)\n",
    "    model_pred.eval()\n",
    "    \n",
    "    # Test Data\n",
    "    t = torch.tensor(input_data).float()\n",
    "    with torch.no_grad():\n",
    "        preds = model_pred.forward(t.to(device))\n",
    "    \n",
    "    return preds\n",
    "\n",
    "def scale_std_func(preds, scale):\n",
    "    for i in [0,1,2,3,4,6,7,8,9,10]:\n",
    "        preds[:,i] = scale * (preds[:,i]-preds[:,5]) + preds[:,5]\n",
    "    return preds\n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "e60f47b7",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_prediction(place,gap,look_back,task_type,scale_std,data_path,ground_truth_weather = False,noise_adding = False,seed_list=[0]*1000,positive_error=False):\n",
    "    \n",
    "    df_history = pd.read_csv(data_path+place+'_history.csv')\n",
    "    if not ground_truth_weather:\n",
    "        df_future = pd.read_csv(data_path+place+'_future.csv')\n",
    "    else:\n",
    "        df_future = pd.read_csv(data_path+place+'_future_true.csv')\n",
    "        \n",
    "    ###\n",
    "    df_future_truth = df_future.copy()\n",
    "    ###\n",
    "        \n",
    "    if noise_adding:\n",
    "        df_future = simulate_forecast(df_future.copy(),location,seed_list)\n",
    "        \n",
    "    \n",
    "    params = ModelParams\n",
    "    \n",
    "    if task_type != 'avgL':\n",
    "        new_col = np.zeros(df_future.shape[0])\n",
    "        if \"total_grid\" not in df_future.columns:\n",
    "            df_future.insert(0, 'total_grid', new_col)\n",
    "        df = pd.concat([df_history, df_future]).reset_index()\n",
    "        \n",
    "\n",
    "        df = df.iloc[-(look_back+24):].reset_index()\n",
    "        if df.shape[0] != (look_back + 24):\n",
    "            stop\n",
    "\n",
    "        df['date'] = pd.to_datetime(df['year'].astype(str) + df['dayofyear'].astype(str), format='%Y%j')\n",
    "        df['weekday'] = df['date'].dt.weekday\n",
    "\n",
    "        if task_type == 'SONNET':\n",
    "            data_cont = df[['consumption','solar','DNI','DHI','temperature','relativehumidity']]    \n",
    "        elif task_type == 'SONNET_base':\n",
    "            data_cont = df[['total_grid','DNI','DHI','temperature','relativehumidity']]\n",
    "        else:\n",
    "            raise AssertionError(\"Task type undefined\")\n",
    "\n",
    "\n",
    "        data_cont = data_cont.values\n",
    "        data_time = df[['weekday','dayofyear', 'timeofday', 'month','year']]\n",
    "        data_time = data_time.values\n",
    "\n",
    "        scaler = joblib.load(model_path+place+'_L='+str(look_back)+'_gap='+str(gap)+'_'+params.task_type+'_scaler.gz')\n",
    "        \n",
    "        data_cont = scaler.transform(data_cont)\n",
    "\n",
    "        if location != \"TX\":\n",
    "            cycl_ = lambda x,period : torch.tensor((np.sin(x / period * 2 * np.pi),np.cos(x / period * 2 * np.pi))).type(torch.float32)\n",
    "            W = cycl_(data_time[:,0],7)    # week of day\n",
    "            H = cycl_(data_time[:,2],24)   # timeslot of the day\n",
    "            M = cycl_(data_time[:,3],12)   # month of year\n",
    "            data_time = np.concatenate((W,H,M),0).T\n",
    "            data_context = np.expand_dims(np.concatenate((data_cont,data_time),1), axis=0)\n",
    "        else:\n",
    "            cycl_ = lambda x,period : torch.tensor((np.sin(x / period * 2 * np.pi),np.cos(x / period * 2 * np.pi))).type(torch.float32)\n",
    "            W = cycl_(data_time[:,0],7)    # week of day\n",
    "            D = cycl_(data_time[:,1],365)  # day of year\n",
    "            H = cycl_(data_time[:,2],24)   # timeslot of the day\n",
    "            M = cycl_(data_time[:,3],12)   # month of year\n",
    "            Y = torch.tensor(data_time[:,4] - 2022).unsqueeze(0) # year\n",
    "            data_time = np.concatenate((W,D,H,M,Y),0).T\n",
    "            data_context = np.expand_dims(np.concatenate((data_cont,data_time),1), axis=0)\n",
    "\n",
    "        hist_input = data_context[:,:look_back,:]\n",
    "        future_input = data_context[:,look_back:,masked:]\n",
    "\n",
    "        future_input = np.concatenate((np.zeros((future_input.shape[0],future_input.shape[1],masked)),future_input), axis=-1)\n",
    "        input_data = np.concatenate((hist_input, future_input), axis=1)\n",
    "        input_data = np.swapaxes(input_data,1,2)\n",
    "\n",
    "        preds = pred_func(params,input_data,place,gap,look_back)\n",
    "        preds = preds.cpu().numpy()[0,0,:,:]\n",
    "        \n",
    "\n",
    "        if 'c-c' in task_type:\n",
    "            if 'solar' in df_future.columns:           \n",
    "                solar_future = df_future['solar'].values\n",
    "                preds = preds - solar_future[:, None]\n",
    "            else:\n",
    "                raise AssertionError(\"solar not found in future.csv for task type disL-disL\")\n",
    "    \n",
    "    else:\n",
    "        if os.path.exists(model_path+place+'_L='+str(look_back)+'_gap='+str(gap)+'_'+'disL-disL'+'_scaler.gz'):\n",
    "            scaler = joblib.load(model_path+place+'_L='+str(look_back)+'_gap='+str(gap)+'_'+'disL-disL'+'_scaler.gz')\n",
    "            var = scaler.var_\n",
    "            std_dev = np.sqrt(var[0])\n",
    "        elif os.path.exists(model_path+place+'_L='+str(look_back)+'_gap='+str(gap)+'_'+'all-N'+'_scaler.gz'):\n",
    "            scaler = joblib.load(model_path+place+'_L='+str(look_back)+'_gap='+str(gap)+'_'+'all-N'+'_scaler.gz')\n",
    "            var = scaler.var_\n",
    "            std_dev = np.sqrt(var[2])\n",
    "        else:\n",
    "            raise AssertionError(\"cannot use avgL because disaggregated consumption does not exist\")\n",
    "        quantiles = [0.001,0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9,0.999]\n",
    "        gaussian_quantiles = norm.ppf(quantiles, loc=0, scale=std_dev)\n",
    "        \n",
    "        df_mean = df_history.groupby(\"timeofday\", sort=False).mean(numeric_only = True)\n",
    "        preds = df_mean.consumption.values.reshape(-1,1) + gaussian_quantiles.reshape(1,-1)\n",
    "        \n",
    "\n",
    "        solar_future = df_future['solar'].values\n",
    "        preds = preds - solar_future[:, None]\n",
    "\n",
    "    \n",
    "    if scale_std != 1.0:\n",
    "        preds = scale_std_func(preds, scale_std)\n",
    "    \n",
    "    df_pred = pd.DataFrame(preds)  \n",
    "    df_pred.columns = [[\"p00\",\"p10\",\"p20\",\"p30\",\"p40\",\"p50\",\"p60\",\"p70\",\"p80\",\"p90\",\"p100\"]]\n",
    "    df_pred.to_csv(place+'_pred.csv')\n",
    "\n",
    "\n",
    "    return np.array(preds)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "7286d350",
   "metadata": {},
   "outputs": [],
   "source": [
    "def visualization(df_history,preds,place,params,gap,look_back):\n",
    "    df_ma = df_history.groupby(\"timeofday\", sort=False).mean(numeric_only = True)\n",
    "    plt.figure()\n",
    "    plt.plot(np.mean(preds,axis=-1),label='preds') \n",
    "    plt.plot(preds[:,-1],label='lower bound') \n",
    "    plt.plot(preds[:,0],label='upper bound') \n",
    "    plt.plot(df_ma.total_grid.values,label='average from history')\n",
    "    plt.legend()\n",
    "    plt.savefig(place+'_L='+str(look_back)+'_gap='+str(gap)+'_'+params.task_type+\"_visualize.png\")\n",
    "    \n",
    "def visualization_true(df_true,preds,place,params,gap,look_back): # this is not used in real contest\n",
    "    plt.figure()\n",
    "    plt.plot(np.mean(preds,axis=-1),label='preds') \n",
    "    plt.plot(preds[:,-1],label='lower bound') \n",
    "    plt.plot(preds[:,0],label='upper bound') \n",
    "    plt.plot(df_true.total_grid.values,label='true total grid')\n",
    "    plt.legend()\n",
    "    plt.savefig(place+'_L='+str(look_back)+'_gap='+str(gap)+'_'+params.task_type+\"_visualize_true.png\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "8e614028",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
     "def continuous_ranked_probability_score(obs, fx, fx_prob):\n",
     "    \"\"\"Continuous Ranked Probability Score (CRPS).\n",
     "\n",
     "    .. math::\n",
     "\n",
     "        \\\\text{CRPS} = \\\\frac{1}{n} \\\\sum_{i=1}^n \\\\int_{-\\\\infty}^{\\\\infty}\n",
     "        (F_i(x) - \\\\mathbf{1} \\\\{x \\\\geq y_i \\\\})^2 dx\n",
     "\n",
     "    where :math:`F_i(x)` is the CDF of the forecast at time :math:`i`,\n",
     "    :math:`y_i` is the observation at time :math:`i`, and :math:`\\\\mathbf{1}`\n",
     "    is the indicator function that transforms the observation into a step\n",
     "    function (1 if :math:`x \\\\geq y`, 0 if :math:`x < y`). In other words, the\n",
     "    CRPS measures the difference between the forecast CDF and the empirical CDF\n",
     "    of the observation. The CRPS has the same units as the observation. Lower\n",
     "    CRPS values indicate more accurate forecasts, where a CRPS of 0 indicates a\n",
     "    perfect forecast. [1]_ [2]_ [3]_\n",
     "\n",
     "    Parameters\n",
     "    ----------\n",
     "    obs : (n,) array_like\n",
     "        Observations (physical unit).\n",
     "    fx : (n, d) array_like\n",
     "        Forecasts (physical units) of the right-hand-side of a CDF with d\n",
     "        intervals (d >= 2), e.g., fx = [10 MW, 20 MW, 30 MW] is interpreted as\n",
     "        <= 10 MW, <= 20 MW, <= 30 MW.\n",
     "    fx_prob : (n, d) array_like\n",
     "        Probability [%] associated with the forecasts.\n",
     "\n",
     "    Returns\n",
     "    -------\n",
     "    crps : float\n",
     "        The Continuous Ranked Probability Score, with the same units as the\n",
     "        observation.\n",
     "\n",
     "    Raises\n",
     "    ------\n",
     "    ValueError\n",
     "        If the forecasts have incorrect dimensions; either a) the forecasts are\n",
     "        for a single sample (n=1) with d CDF intervals but are given as a 1D\n",
     "        array with d values or b) the forecasts are given as 2D arrays (n,d)\n",
     "        but do not contain at least 2 CDF intervals (i.e. d < 2).\n",
     "\n",
     "    Notes\n",
     "    -----\n",
     "    The CRPS can be calculated analytically when the forecast CDF is of a\n",
     "    continuous parametric distribution, e.g., Gaussian distribution. However,\n",
     "    since the Solar Forecast Arbiter makes no assumptions regarding how a\n",
     "    probabilistic forecast was generated, the CRPS is instead calculated using\n",
     "    numerical integration of the discretized forecast CDF. Therefore, the\n",
     "    accuracy of the CRPS calculation is limited by the precision of the\n",
     "    forecast CDF. In practice, this means the forecast CDF should 1) consist of\n",
     "    at least 10 intervals and 2) cover probabilities from 0% to 100%.\n",
     "\n",
     "    References\n",
     "    ----------\n",
     "    .. [1] Matheson and Winkler (1976) \"Scoring rules for continuous\n",
     "           probability distributions.\" Management Science, vol. 22, pp.\n",
     "           1087-1096. doi: 10.1287/mnsc.22.10.1087\n",
     "    .. [2] Hersbach (2000) \"Decomposition of the continuous ranked probability\n",
     "           score for ensemble prediction systems.\" Weather Forecast, vol. 15,\n",
     "           pp. 559-570. doi: 10.1175/1520-0434(2000)015<0559:DOTCRP>2.0.CO;2\n",
     "    .. [3] Wilks (2019) \"Statistical Methods in the Atmospheric Sciences\", 4th\n",
     "           ed. Oxford; Waltham, MA; Academic Press.\n",
     "\n",
     "    \"\"\"\n",
     "\n",
     "    # validate the forecast shape: fx must be 2-D (n, d) with d >= 2\n",
     "    if np.ndim(fx) < 2:\n",
     "        raise ValueError(\"forecasts must be 2D arrays (expected (n,d), got\"\n",
     "                         f\"{np.shape(fx)})\")\n",
     "    elif np.shape(fx)[1] < 2:\n",
     "        raise ValueError(\"forecasts must have d >= 2 CDF intervals \"\n",
     "                         f\"(expected >= 2, got {np.shape(fx)[1]})\")\n",
     "\n",
     "    n = len(fx)\n",
     "\n",
     "    # NOTE: obs and fx must be NumPy arrays; fx[:, 0] and obs[:, np.newaxis]\n",
     "    # below do not accept plain Python lists.\n",
     "    # extend CDF min to ensure obs within forecast support\n",
     "    # fx.shape = (n, d) ==> (n, d + 1)\n",
     "    fx_min = np.minimum(obs, fx[:, 0])\n",
     "    fx = np.hstack([fx_min[:, np.newaxis], fx])\n",
     "    fx_prob = np.hstack([np.zeros([n, 1]), fx_prob])\n",
     "\n",
     "    # extend CDF max to ensure obs within forecast support\n",
     "    # fx.shape = (n, d + 1) ==> (n, d + 2)\n",
     "    # idx is computed before the max-extension, so fx[:, -1] is still the largest forecast\n",
     "    idx = (fx[:, -1] < obs)\n",
     "    fx_max = np.maximum(obs, fx[:, -1])\n",
     "    fx = np.hstack([fx, fx_max[:, np.newaxis]])\n",
     "    fx_prob = np.hstack([fx_prob, np.full([n, 1], 100)])\n",
     "\n",
     "    # indicator function:\n",
     "    # - left of the obs is 0.0\n",
     "    # - obs and right of the obs is 1.0\n",
     "    o = np.where(fx >= obs[:, np.newaxis], 1.0, 0.0)\n",
     "\n",
     "    # correct behavior when obs > max fx:\n",
     "    # - should be 0 over range: max fx < x < obs\n",
     "    o[idx, -1] = 0.0\n",
     "\n",
     "    # forecast probabilities [unitless]\n",
     "    f = fx_prob / 100.0\n",
     "\n",
     "    # integrate along each sample, then average all samples\n",
     "    # NOTE: np.trapz is deprecated since NumPy 2.0 in favour of np.trapezoid.\n",
     "    crps = np.mean(np.trapz((f - o) ** 2, x=fx, axis=1))\n",
     "\n",
     "    return crps"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "f7b21f31",
   "metadata": {},
   "outputs": [],
   "source": [
    "def crps(data,path):\n",
    "    crps_loss = []\n",
    "    for i in range(len(data)):\n",
    "        truth = np.array(data.iloc[i][\"grid\"]).reshape(-1)\n",
    "        if np.isnan(truth[0]):\n",
    "            crps_loss.append(np.nan)\n",
    "            continue\n",
    "        pred = np.array(data.iloc[i][2:13]).reshape(1,11)\n",
    "        prob = np.array([0,10,20,30,40,50,60,70,80,90,100]).reshape(1,11)\n",
    "        crps_loss.append(continuous_ranked_probability_score(truth, pred, prob))\n",
    "    data[\"pred_crps\"] = crps_loss\n",
    "\n",
    "    crps_loss = []\n",
    "    for i in range(len(data)):\n",
    "        truth = np.array(data.iloc[i][\"grid\"]).reshape(-1)\n",
    "        if np.isnan(truth[0]):\n",
    "            crps_loss.append(np.nan)\n",
    "            continue\n",
    "        pred = np.array(data.iloc[i][13:24]).reshape(1,11)\n",
    "        prob = np.array([0,10,20,30,40,50,60,70,80,90,100]).reshape(1,11)\n",
    "        crps_loss.append(continuous_ranked_probability_score(truth, pred, prob))\n",
    "    data[\"ref_crps\"] = crps_loss\n",
    "    \n",
    "    data.to_csv(path, index = False)\n",
    "    \n",
    "    return data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "a27e6f2d",
   "metadata": {},
   "outputs": [],
   "source": [
    "####################################################"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "b6bf02f9",
   "metadata": {},
   "outputs": [],
   "source": [
    "def model_evaluation(location, date_strings, model_path, save_to, version, scale_std, day_ahead = 0, future = False, noise_adding = False, seed_list = [0]*1000, positive_error = False):\n",
    "\n",
    "    result = []\n",
    "    for ii,date in enumerate(date_strings):\n",
    "        \n",
    "        data_path = f'../data/historical_data/{version}_eval/{location}_{date}/'\n",
    "        if day_ahead:\n",
    "            data_path = f'../data/evaluation/{day_ahead}_day_ahead/{location}_{date}/'\n",
    "    \n",
    "        history_path = data_path+location+'_history.csv'\n",
    "        if not future:\n",
    "            future_path = data_path+location+'_future.csv'\n",
    "        else:\n",
    "            future_path = data_path+location+'_future_true.csv'\n",
    "            \n",
    "        df_history = pd.read_csv(history_path)\n",
    "        df_future = pd.read_csv(future_path)\n",
    "\n",
    "        hours = int((pd.to_datetime(df_future.localTime.values[0]) - pd.to_datetime(df_history.localTime.values[-1])).total_seconds() / 3600)-1\n",
    "        if hours in [24, 48, 72, 96, 120, 144, 168, 336, 720]:\n",
    "            gap = hours\n",
    "        else:\n",
    "            raise AssertionError(\"gap between history and future is not pre-defined. please check datasets or manually choose the gap\")\n",
    "        \n",
    "        if future:\n",
    "            df_pred = get_prediction(location,gap,look_back,task_type,scale_std,data_path,ground_truth_weather=True)\n",
    "            \n",
    "        else:\n",
    "            seed_list_sub = seed_list[4*ii:4*(ii+1)]\n",
    "            df_pred = get_prediction(location,gap,look_back,task_type,scale_std,data_path, noise_adding = noise_adding, seed_list = seed_list_sub, positive_error=positive_error)\n",
    "        result.append(df_pred)\n",
    "    \n",
    "    df_save = pd.read_csv(save_to)\n",
    "\n",
    "    if not day_ahead or day_ahead in [\"X\", \"XX\"]:\n",
    "        df_save.iloc[:,2:13] = pd.DataFrame(np.concatenate(result, axis = 0)[:-2])\n",
    "        df_save.to_csv(save_to, index = False)\n",
    "    else:\n",
    "        if location == \"OR\":\n",
    "            df_save.iloc[:,2:13] = pd.DataFrame(np.concatenate(result, axis = 0)[5:-3])\n",
    "            df_save.to_csv(save_to, index = False)\n",
    "        elif location == \"HI\":\n",
    "            df_save.iloc[:,2:13] = pd.DataFrame(np.concatenate(result, axis = 0)[6:-5])\n",
    "            df_save.to_csv(save_to, index = False)\n",
    "        elif location == \"GA\" or location == \"TX\":\n",
    "            df_save.iloc[:,2:13] = pd.DataFrame(np.concatenate(result, axis = 0)[7:-3])\n",
    "            df_save.to_csv(save_to, index = False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "52ec9d70",
   "metadata": {},
   "outputs": [],
   "source": [
    "def generate_error2(irradiance = True, seed = 42, start = 8, end = 50, mean = 30, variance = 100, method = \"uniform\"):\n",
    "    \n",
    "    random.seed(seed)\n",
    "    \n",
    "    error = []\n",
    "    start = start/100 \n",
    "    end = end/100\n",
    "    for _ in range(24):\n",
    "        if method == \"uniform\":\n",
    "            error_rate = random.uniform(start, end)\n",
    "        if positive_error:\n",
    "            if positive_error>0:\n",
    "                error.append(1+error_rate)\n",
    "            elif positive_error<0:\n",
    "                error.append(1-error_rate)\n",
    "        else:\n",
    "            if random.uniform(0,1)<0.5:\n",
    "                error.append(1+error_rate)\n",
    "            else:\n",
    "                error.append(1-error_rate)\n",
    "    \n",
    "    return error\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "89230d72",
   "metadata": {},
   "outputs": [],
   "source": [
    "start_date = datetime.strptime(\"2023-06-27\", \"%Y-%m-%d\")\n",
    "date_strings = [(start_date + timedelta(days=i)).strftime(\"%Y-%m-%d\") for i in range(19)]\n",
    "std_summary = pd.DataFrame()\n",
    "\n",
    "for feature in [\"DNI\",\"DHI\",\"GHI\",\"temperature\",\"relativehumidity\"]:\n",
    "    for location in [\"GA\",\"HI\",\"OR\",\"TX\"]:\n",
    "        hour = []\n",
    "        prediction = []\n",
    "        ground_truth = []\n",
    "        for date in date_strings:\n",
    "            forecast = pd.read_csv(f\"{location}_future.csv\") \n",
    "            forecast[\"localTime\"] = pd.to_datetime(forecast[\"localTime\"])\n",
    "            prediction.extend(forecast[feature])\n",
    "\n",
    "            truth = pd.read_csv(f\"{location}_future.csv\") \n",
    "            truth[\"localTime\"] = pd.to_datetime(truth[\"localTime\"])\n",
    "            ground_truth.extend(truth[feature])\n",
    "\n",
    "            hour.extend(truth[\"timeofday\"])\n",
    "        \n",
    "        df = pd.DataFrame()\n",
    "        df[\"hour\"] = hour\n",
    "        df[\"truth\"] = ground_truth\n",
    "        df[\"prediction\"] = prediction\n",
    "        df[\"std(error)\"] = df[\"prediction\"] - df[\"truth\"]\n",
    "\n",
    "        std_summary[\"hour\"] = df.groupby(\"hour\", as_index = False).std()[\"hour\"]\n",
    "        std_summary[f\"{location}_{feature}\"] = df.groupby(\"hour\", as_index = False).std()[\"std(error)\"]\n",
    "\n",
    "\n",
    "def simulate_forecast(data,location,seed_list):\n",
    "    \n",
    "\n",
    "    param_dict = dict()\n",
    "\n",
    "    param_dict[\"GA\"] = [0.633,0.59924,3.32851]\n",
    "    param_dict[\"OR\"] = [0.35,0.39527,2.97331]\n",
    "    param_dict[\"HI\"] = [1.65,0.47095,3.0398]\n",
    "    param_dict[\"TX\"] = [0.21773893,0.87266,3.11363]\n",
    "\n",
    "    if location == \"GA\":\n",
    "        DNI_min, DNI_max = 0, 1040\n",
    "        DHI_min, DHI_max = 0, 551\n",
    "    elif location == \"HI\":\n",
    "        DNI_min, DNI_max = 0, 989\n",
    "        DHI_min, DHI_max = 0, 567\n",
    "    elif location == \"OR\":\n",
    "        DNI_min, DNI_max = 0, 981\n",
    "        DHI_min, DHI_max = 0, 541  \n",
    "    elif location == \"TX\":\n",
    "        DNI_min, DNI_max = 0, 1052\n",
    "        DHI_min, DHI_max = 0, 560\n",
    "        \n",
    "    if location == \"TX\":\n",
    "        tem_min , tem_max = -8, 40\n",
    "    elif location == \"GA\":\n",
    "        tem_min , tem_max = -4, 38\n",
    "    elif location == \"HI\":\n",
    "        tem_min, tem_max = 20, 28\n",
    "    elif location == \"OR\":\n",
    "        tem_min, tem_max = -7, 38\n",
    "\n",
    "    location_dict = dict()\n",
    "    location_dict[\"OR\"] = [45.114559,-123.204903]\n",
    "    location_dict[\"GA\"] = [31.044241,-84.879128]\n",
    "    location_dict[\"HI\"] = [21.446911,-158.188736]\n",
    "    location_dict[\"TX\"] = [29.424122, -98.493629]\n",
    "    \n",
    "    latitude, longitude = location_dict[location]\n",
    "    capacity,beta,gamma = param_dict[location]\n",
    "    \n",
    "    data[\"total_grid\"] = 0\n",
    "    \n",
    "    data[\"DNI\"] = data[\"DNI\"] + generate_error(data, location, \"DNI\", seed_list[0])\n",
    "    data[\"DHI\"] = data[\"DHI\"] + generate_error(data, location, \"DHI\", seed_list[1])\n",
    "    \n",
    "    data[\"DHI\"] = np.clip(data[\"DHI\"], DHI_min, DHI_max)\n",
    "    data[\"DNI\"] = np.clip(data[\"DNI\"], DNI_min, DNI_max)\n",
    "    \n",
    "    \n",
    "    data[\"temperature\"] = data[\"temperature\"] + generate_error(data, location, \"temperature\", seed_list[2])\n",
    "    data[\"relativehumidity\"] = data[\"relativehumidity\"] + generate_error(data, location, \"relativehumidity\", seed_list[3])\n",
    "    \n",
    "    data[\"relativehumidity\"] = np.clip(data[\"relativehumidity\"], 0, 100)\n",
    "    data[\"temperature\"] = np.clip(data[\"temperature\"], tem_min ,tem_max)\n",
    "    #####\n",
    "    data[\"sin_elevation_cal\"] = np.sin(latitude/180 * np.pi) * np.sin(data[\"declination_angle_delta\"]) + np.cos(latitude/180 * np.pi) * np.cos(data[\"declination_angle_delta\"]) * np.cos(data[\"hour_angle_omega\"])\n",
    "    data[\"elevation_alpha\"] = np.arcsin(data[\"sin_elevation_cal\"])\n",
    "    data[\"zenith\"] = np.arccos(data[\"sin_elevation_cal\"])\n",
    "    #####\n",
    "    \n",
    "\n",
    "    df = data[[\"dayofyear\",\"timeofday\",\"minute\",\"month\",\"total_grid\",\"consumption\",\"temperature\",\"zenith\",\"DHI\",\"DNI\",\"GHI\",\"hour_angle_omega\",\"declination_angle_delta\",\"elevation_alpha\",\"azimuth\",\"albedo\",\"timezone\"]]\n",
    "\n",
    "    model_predict = model.PhysicalModel(latitude)\n",
    "    model_predict.capacity = torch.nn.Parameter(torch.tensor(capacity))\n",
    "    model_predict.capacity.requires_grad_(False)\n",
    "    model_predict.beta = torch.nn.Parameter(torch.tensor(beta))\n",
    "    model_predict.beta.requires_grad_(False)\n",
    "    model_predict.gamma = torch.nn.Parameter(torch.tensor(gamma))\n",
    "    model_predict.gamma.requires_grad_(False)\n",
    "    \n",
    "    data[\"solar\"] = list(model_predict(torch.tensor(np.array(df))).detach().numpy())\n",
    "    data[\"solar\"] = np.clip(data[\"solar\"], 0, 100)\n",
    "    \n",
    "    return data\n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "d6c61f8c",
   "metadata": {},
   "outputs": [],
   "source": [
    "def generate_error(data, location, feature, seed):\n",
    "\n",
    "    random.seed(seed)\n",
    "\n",
    "    error = [random.gauss(mu=0, sigma=std * 1) for std in std_summary.iloc[data[\"timeofday\"]][f\"{location}_{feature}\"]]\n",
    "\n",
    "    return error"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "fef4fa12",
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "SONNET\n",
      "      TX     OR     HI     GA    seed\n",
      "0  0.594  0.557  0.179  0.221  2007.0\n"
     ]
    }
   ],
   "source": [
    "location_list =[\"TX\",\"OR\",\"HI\",\"GA\"]\n",
    "\n",
    "day_ahead = 1\n",
    "\n",
    "noise_adding = True\n",
    "positive_error = 0\n",
    "\n",
    "for look_back in [336]:\n",
    "    for ratio in [1]:\n",
    "\n",
    "        def generate_error(data, location, feature, seed):\n",
    "\n",
    "            random.seed(seed)\n",
    "\n",
    "            error = [random.gauss(mu=0, sigma=std * ratio) for std in std_summary.iloc[data[\"timeofday\"]][f\"{location}_{feature}\"]]\n",
    "\n",
    "            return error\n",
    "\n",
    "        for task_type in [\"SONNET\",\"SONNET_base\"]:\n",
    "            # One summary row per trained-model seed (CRPSS averaged over noise seeds).\n",
    "            seed_summaries = []\n",
    "            for seed_choice in range(2005,2025):\n",
    "                # CRPSS values appended in location_list order, one per (noise seed, location).\n",
    "                crpss_rows = []\n",
    "                # Named noise_seed (not seed) so the `seed = 41` below does not shadow it.\n",
    "                for noise_seed in range(2025, 2035):\n",
    "\n",
    "                    # Draw 1200 per-sample noise seeds per location. The key order\n",
    "                    # (OR, HI, GA, TX) fixes the order of the random draws; keep it.\n",
    "                    random.seed(noise_seed)\n",
    "                    seed_dict = {\n",
    "                        loc: [random.randint(2035,3035) for _ in range(1200)]\n",
    "                        for loc in [\"OR\", \"HI\", \"GA\", \"TX\"]\n",
    "                    }\n",
    "\n",
    "                    for location in location_list:\n",
    "                        flag = 1\n",
    "                        crpss_dict = dict()\n",
    "                        for scale_std in [1]:\n",
    "\n",
    "                            eval_on_truth = False\n",
    "\n",
    "                            # NOTE(review): gap, pred_len, seed, visualize and truncate are not\n",
    "                            # passed to model_evaluation; presumably it reads them as notebook\n",
    "                            # globals -- confirm before removing any of them.\n",
    "                            gap = 24\n",
    "                            pred_len = 24\n",
    "                            seed = 41\n",
    "                            visualize = False\n",
    "                            truncate = False\n",
    "                            set_seed(seed)\n",
    "\n",
    "                            if task_type in ['SONNET','SONNET_base']:\n",
    "                                masked = 1\n",
    "\n",
    "                            class ModelParams:\n",
    "\n",
    "                                # [dataloader]\n",
    "                                context_window = look_back\n",
    "                                target_window = pred_len\n",
    "                                batch_size = 128\n",
    "                                num_workers = 8\n",
    "                                masked = masked\n",
    "\n",
    "                                # [model params]\n",
    "                                n_layers = 1\n",
    "\n",
    "                                if location == \"GA\":\n",
    "                                    n_heads = 8\n",
    "                                    d_model = 48\n",
    "                                    d_ff = 96\n",
    "                                else:\n",
    "                                    n_heads = 8\n",
    "                                    d_model = 64\n",
    "                                    d_ff = 128\n",
    "\n",
    "                                dropout = 0.2\n",
    "                                fc_dropout = 0.2\n",
    "                                head_dropout = 0\n",
    "\n",
    "                                # [individual heads]\n",
    "                                individual = False\n",
    "\n",
    "                                # [patching]\n",
    "                                patch_len = 8\n",
    "                                stride = 4\n",
    "                                padding_patch = None\n",
    "                                patch_len2 = 1\n",
    "                                stride2 = 1\n",
    "                                padding_patch2 = None\n",
    "                                #patch_num = int((context_window - patch_len)/stride + 2) if padding_patch else int((context_window - patch_len)/stride + 1)\n",
    "\n",
    "                                # [RevIN]\n",
    "                                revin = False\n",
    "                                affine = False         # if False then apply just instance norm\n",
    "\n",
    "                                # [head type] chosen from a substring of task_type\n",
    "                                if \"linear\" in task_type:\n",
    "                                    head_type = \"linear\"\n",
    "                                elif \"mlp\" in task_type:\n",
    "                                    head_type = \"mlp\"\n",
    "                                else:\n",
    "                                    head_type = \"flatten\"\n",
    "\n",
    "                                # [optimization params]\n",
    "                                n_epochs = 200\n",
    "                                lr=1e-3\n",
    "\n",
    "                                # [probabilistic setting]\n",
    "                                prediction_type = 'probablistic'\n",
    "                                probablistic = True if prediction_type == 'probablistic' else False\n",
    "                                loss = 'crps'\n",
    "                                n_quantiles = 11\n",
    "\n",
    "                                # [task type]\n",
    "                                task_type = task_type\n",
    "\n",
    "                            # Evaluation dates: 19 days from 2023-06-27 by default, 16 when\n",
    "                            # eval_on_truth, 28 days from 2023-06-18 when day_ahead is set and is\n",
    "                            # neither X nor XX.\n",
    "                            start_date = datetime.strptime(\"2023-06-27\", \"%Y-%m-%d\")\n",
    "                            date_strings = [(start_date + timedelta(days=i)).strftime(\"%Y-%m-%d\") for i in range(19)]\n",
    "                            if eval_on_truth:\n",
    "                                date_strings = [(start_date + timedelta(days=i)).strftime(\"%Y-%m-%d\") for i in range(16)]\n",
    "\n",
    "                            if day_ahead and day_ahead != \"X\" and day_ahead != \"XX\":\n",
    "                                start_date = datetime.strptime(\"2023-06-18\", \"%Y-%m-%d\")\n",
    "                                date_strings = [(start_date + timedelta(days=i)).strftime(\"%Y-%m-%d\") for i in range(28)]\n",
    "\n",
    "                            model_path = f'./models_and_scalers/{seed_choice}/'\n",
    "\n",
    "                            version = \"None\"\n",
    "                            save_to = f\"../data/evaluation/result/{location}_{day_ahead}_day_ahead.csv\"\n",
    "                            model_evaluation(location, date_strings, model_path, save_to, version, scale_std, day_ahead, noise_adding=noise_adding,seed_list=seed_dict[location],positive_error=positive_error)\n",
    "\n",
    "                            # Score the CSV at save_to; CRPSS = 1 - mean(pred_crps) / mean(ref_crps)\n",
    "                            # taken over the per-date means.\n",
    "                            data = pd.read_csv(save_to)\n",
    "                            data = crps(data,save_to)\n",
    "                            data_group = data.groupby(\"date\", as_index = False).mean(numeric_only = True)\n",
    "                            data_group[\"crpss\"] = 1 - data_group[\"pred_crps\"]/data_group[\"ref_crps\"]\n",
    "\n",
    "                            crpss =  1 - np.mean(data_group[\"pred_crps\"])/np.mean(data_group[\"ref_crps\"])\n",
    "                            crpss_dict[str(scale_std)] = np.round(crpss,3)\n",
    "\n",
    "                        crpss_rows.append(np.round(crpss,3))\n",
    "\n",
    "                # Label columns from location_list (the order crpss_rows was filled in)\n",
    "                # rather than a hard-coded TX/OR/HI/GA list, which mislabels the\n",
    "                # per-location results whenever location_list is ordered differently.\n",
    "                location_columns = list(location_list)\n",
    "                seed_table = pd.DataFrame(\n",
    "                    np.array(crpss_rows).reshape(-1, len(location_columns)),\n",
    "                    columns=location_columns,\n",
    "                )\n",
    "                seed_table[\"seed\"] = seed_choice\n",
    "\n",
    "                # Average over noise seeds -> one summary row for this model seed.\n",
    "                seed_summaries.append(seed_table.mean(axis=0).to_frame().T)\n",
    "            # Build the summary once instead of growing it with concat inside the loop.\n",
    "            summary = pd.concat(seed_summaries, axis=0, ignore_index=True)\n",
    "            print(task_type)\n",
    "            print(summary)\n",
    "\n",
    "#             summary.to_csv(f\"forecast_evaluation/{task_type}_Final_{look_back}lookback.csv\", index = False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0be01d0d",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5c4b2e75",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f75a5429",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2e099c9f",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
