{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "1ab9f854",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd \n",
    "import numpy as np\n",
    "import torch\n",
    "from torch import nn, optim\n",
    "from torch.autograd import Variable\n",
    "import torch.nn.functional as F\n",
    "import pandas as pd\n",
    "import time\n",
    "import math\n",
    "import matplotlib.pyplot as plt\n",
    "from torch.utils.data import Dataset, DataLoader\n",
    "from sklearn.preprocessing import StandardScaler\n",
    "from torch.utils.data import DataLoader, TensorDataset\n",
    "import os\n",
    "from os import listdir\n",
    "from os.path import isfile, join\n",
    "import joblib\n",
    "from sklearn.preprocessing import MinMaxScaler\n",
    "from torch.optim.lr_scheduler import CosineAnnealingLR\n",
    "import datetime"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "824e15f3",
   "metadata": {},
   "outputs": [],
   "source": [
    " class CRPSLoss(nn.Module):\n",
    "        def __init__(self, quantiles=[0,0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9,1], adjusted=True):\n",
    "            super().__init__()\n",
    "            self.adjusted = adjusted\n",
    "            if self.adjusted:\n",
    "                self.quantiles = torch.tensor([0]+quantiles+[1])\n",
    "            else:\n",
    "                self.quantiles = torch.tensor(quantiles)\n",
    "\n",
    "        def forward(self, preds, target): #preds:[B,N,T,Q] target:[B,N,T]\n",
    "            assert not target.requires_grad\n",
    "            assert preds.size(0) == target.size(0)\n",
    "\n",
    "            # this is adjusted to avoid some extrme values that larger than the max or smaller than the min in quantiles, so a max bound and min bound are provided\n",
    "            if self.adjusted:\n",
    "                B = preds.shape[0]\n",
    "                N = preds.shape[1]\n",
    "                T = preds.shape[2]\n",
    "                max_bound = 100\n",
    "                min_bound = -100\n",
    "                preds = torch.cat((min_bound*torch.ones((B, N, T, 1), device=preds.device), preds), dim=-1)\n",
    "                preds = torch.cat((preds, max_bound*torch.ones((B, N, T, 1), device=preds.device)), dim=-1)\n",
    "\n",
    "            q_i1 = self.quantiles[1:].to(preds.device)\n",
    "            q_i = self.quantiles[:-1].to(preds.device)\n",
    "            X_i1 = preds[:,:,:,1:]\n",
    "            X_i = preds[:,:,:,:-1]\n",
    "            X_t = target.unsqueeze(3).repeat(1, 1, 1, X_i.shape[-1])\n",
    "\n",
    "            # this index indicates which case this interval is in. 0 -> term0, 1 -> term1, 2 -> term2\n",
    "            index = torch.full_like(X_i1, 2)\n",
    "            index[X_t > X_i1] = 0\n",
    "            index[X_t < X_i] = 1\n",
    "            index = F.one_hot(index.to(torch.int64), num_classes=3) #ntqd\n",
    "\n",
    "            # term0: F(x_i+1)>O(x_i+1)\n",
    "            # term1: F(x_i)<O(x_i) \n",
    "            # term2: F(x_i)>O(x_i) and F(x_i+1)<O(x_i+1)\n",
    "            # All the terms are calculated with precise integration function\n",
    "            term0 = 1/3*torch.einsum('bntq,q->bntq', X_i1-X_i, (q_i1**2+q_i1*q_i+q_i**2))\n",
    "            term1 = 1/3*torch.einsum('bntq,q->bntq', X_i1-X_i, ((q_i1-1)**2+(q_i1-1)*(q_i-1)+(q_i-1)**2))\n",
    "            #term2 = torch.einsum('bntq,q->bntq', X_i1-X_t, 1-2*q_i) + torch.einsum('bntq,q->bntq', ((X_i1-X_i)**2-(X_t-X_i)**2)/(X_i1-X_i), -(q_i1-q_i)) + term0\n",
    "            # this is the same as the line below\n",
    "            term2 = torch.einsum('bntq,q->bntq', X_t-X_i, 2*q_i-1) + torch.einsum('bntq,q->bntq', (X_t-X_i)**2/(X_i1-X_i), q_i1-q_i) + term1\n",
    "            terms = torch.stack((term0,term1,term2),dim=-1)\n",
    "\n",
    "            loss = torch.einsum('bntqd,bntqd->bntq', index.to(torch.float), terms.to(torch.float))\n",
    "            return torch.mean(torch.sum(loss,dim=-1))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "9616773b",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "\n",
    "class Attention(nn.Module):\n",
    "    def __init__(self, d=64):\n",
    "        super(Attention, self).__init__()\n",
    "        self.query = nn.Linear(32, d)\n",
    "        self.key = nn.Linear(32, d)\n",
    "        self.value = nn.Linear(32, d)\n",
    "        self.softmax = nn.Softmax(dim=1)\n",
    "        self.d = d\n",
    "\n",
    "    def forward(self, future, history):\n",
    "        query = self.query(future)\n",
    "        key = self.key(history)\n",
    "        value = self.value(history)\n",
    "        \n",
    "        attention_scores = torch.matmul(query, key.transpose(-2, -1)) / (self.d ** 0.5)\n",
    "        attention_weights = self.softmax(attention_scores)\n",
    "        attention_output = torch.matmul(attention_weights, value)\n",
    "\n",
    "        return attention_output\n",
    "\n",
    "class MLPAttention(nn.Module):\n",
    "    def __init__(self, dropout=0.2):\n",
    "        super(MLPAttention, self).__init__()\n",
    "        self.future_net = nn.Sequential(\n",
    "            nn.Linear(11, 64),\n",
    "            nn.ReLU(),\n",
    "            nn.Linear(64, 32),\n",
    "            nn.ReLU(),\n",
    "            nn.Dropout(dropout) \n",
    "        )\n",
    "        self.history_net = nn.Sequential(\n",
    "            nn.Linear(84, 128),\n",
    "            nn.ReLU(),\n",
    "            nn.Linear(128, 32),\n",
    "            nn.ReLU(),\n",
    "            nn.Dropout(dropout) \n",
    "        )\n",
    "        self.attention = Attention(64)\n",
    "        self.combined_net = nn.Sequential(\n",
    "            nn.Linear(64 + 32, 64),\n",
    "            nn.ReLU(),\n",
    "            nn.Linear(64, 11)\n",
    "        )\n",
    "\n",
    "    def forward(self, x):\n",
    "        future = x[:, :11]\n",
    "        history = x[:, 11:]\n",
    "        future_out = self.future_net(future)\n",
    "        history_out = self.history_net(history)\n",
    "        attention_out = self.attention(future_out, history_out)\n",
    "        combined = torch.cat((attention_out, future_out), dim=1)\n",
    "        out = self.combined_net(combined)\n",
    "\n",
    "        out = out.reshape(-1, 1, 1, 11)\n",
    "        x1, x2 = out[:, :, :, 0], out[:, :, :, 1:]\n",
    "        x1 = torch.unsqueeze(x1, -1)\n",
    "        x2 = F.softplus(x2)\n",
    "        out = torch.cat((x1, x2), dim=-1)\n",
    "        out = torch.cumsum(out, dim=-1)\n",
    "        \n",
    "        return out\n",
    "\n",
    "\n",
    "# 定义 MLP 模型\n",
    "class MLPAttentionTX(nn.Module):\n",
    "    def __init__(self, dropout=0.2):\n",
    "        super(MLPAttentionTX, self).__init__()\n",
    "        self.future_net = nn.Sequential(\n",
    "            nn.Linear(14, 64),\n",
    "            nn.ReLU(),\n",
    "            nn.Linear(64, 32),\n",
    "            nn.ReLU(),\n",
    "            nn.Dropout(dropout)  \n",
    "        )\n",
    "        self.history_net = nn.Sequential(\n",
    "            nn.Linear(84, 128),\n",
    "            nn.ReLU(),\n",
    "            nn.Linear(128, 32),\n",
    "            nn.ReLU(),\n",
    "            nn.Dropout(dropout)\n",
    "        )\n",
    "        self.attention = Attention(64)\n",
    "        self.combined_net = nn.Sequential(\n",
    "            nn.Linear(64 + 32, 64),\n",
    "            nn.ReLU(),\n",
    "            nn.Linear(64, 11)\n",
    "        )\n",
    "\n",
    "\n",
    "    def forward(self, x):\n",
    "        future = x[:, :14]\n",
    "        history = x[:, 14:]\n",
    "        future_out = self.future_net(future)\n",
    "        history_out = self.history_net(history)\n",
    "        attention_out = self.attention(future_out, history_out)\n",
    "        combined = torch.cat((attention_out, future_out), dim=1)\n",
    "        out = self.combined_net(combined)\n",
    "        \n",
    "        print(out.shape)\n",
    "        out = out.reshape(-1, 1, 1, 11)\n",
    "        print(out.shape)\n",
    "        x1, x2 = out[:, :, :, 0], out[:, :, :, 1:]\n",
    "        print(x1.shape)\n",
    "        print(x2.shape)\n",
    "        stop\n",
    "        x1 = torch.unsqueeze(x1, -1)\n",
    "        x2 = F.softplus(x2)\n",
    "        out = torch.cat((x1, x2), dim=-1)\n",
    "        out = torch.cumsum(out, dim=-1)\n",
    "        \n",
    "        return out\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "8ee7d6da",
   "metadata": {},
   "outputs": [],
   "source": [
    "def read_data_cont(task_type, df):\n",
    "    \n",
    "    if task_type == 'MLPAttention':\n",
    "        data_cont = df[['consumption','solar','DNI','DHI','temperature','relativehumidity']]    \n",
    "        \n",
    "    data_cont = data_cont.values\n",
    "    return data_cont\n",
    "\n",
    "def create_datasets_onechunk(configs, df, scaler):\n",
    "    \n",
    "    location = configs.location\n",
    "    gap = configs.gap\n",
    "    look_back = configs.look_back\n",
    "    pred_len = configs.pred_len\n",
    "    total_len = gap+look_back+pred_len\n",
    "    task_type = configs.task_type\n",
    "        \n",
    "    data_cont = read_data_cont(task_type, df)\n",
    "    df.rename(columns = {\"timeofday\": \"hour\"}, inplace = True)\n",
    "    data_time = df[['weekday','dayofyear', 'hour', 'month','year']]\n",
    "     \n",
    "    data_time = data_time.values\n",
    "\n",
    "    data_target = df[['total_grid']]\n",
    "    data_target = data_target.values\n",
    "    \n",
    "    data_cont = scaler.transform(data_cont)\n",
    "        \n",
    "    if location != \"TX\":\n",
    "        cycl_ = lambda x,period : torch.tensor((np.sin(x / period * 2 * np.pi),np.cos(x / period * 2 * np.pi))).type(torch.float32)\n",
    "        W = cycl_(data_time[:,0],7)    # week of day\n",
    "        H = cycl_(data_time[:,2],24)   # timeslot of the day\n",
    "        M = cycl_(data_time[:,3],12)   # month of year\n",
    "        data_time = np.concatenate((W,H,M),0).T\n",
    "        data_context = np.concatenate((data_cont,data_time),1)\n",
    "    else:\n",
    "        cycl_ = lambda x,period : torch.tensor((np.sin(x / period * 2 * np.pi),np.cos(x / period * 2 * np.pi))).type(torch.float32)\n",
    "        W = cycl_(data_time[:,0],7)    # week of day\n",
    "        D = cycl_(data_time[:,1],365)  # day of year\n",
    "        H = cycl_(data_time[:,2],24)   # timeslot of the day\n",
    "        M = cycl_(data_time[:,3],12)   # month of year\n",
    "        Y = torch.tensor(data_time[:,4] - 2022).unsqueeze(0) # year\n",
    "        data_time = np.concatenate((W,D,H,M,Y),0).T\n",
    "        data_context = np.concatenate((data_cont,data_time),1)\n",
    "        \n",
    "    def helper(x_train_hist, x_train_future, target):\n",
    "\n",
    "        x_train = []\n",
    "        y_train = []\n",
    "        \n",
    "        \n",
    "        for k in range(24):\n",
    "            x_train_sub = []\n",
    "            y_train_sub= []\n",
    "            y_train_sub.extend(target[k])\n",
    "            x_train_sub.extend(x_train_future[k])\n",
    "            for d in range(14):\n",
    "                x_train_sub.extend(x_train_hist[d*24 + k,:6])\n",
    "\n",
    "            x_train.append(np.array(x_train_sub).reshape(1,-1))\n",
    "            y_train.append(np.array(y_train_sub))\n",
    "        \n",
    "\n",
    "        x_train, y_train = np.concatenate(x_train), np.concatenate(y_train)\n",
    "        \n",
    "        return x_train, y_train\n",
    "\n",
    "    def Create_dataset(data_context,data_target):\n",
    "        x_train_total = []\n",
    "        y_train_total = []\n",
    "        \n",
    "        for i in range(0, len(data_context)-total_len, 24): \n",
    "            tempx = data_context[i:i+look_back,:]\n",
    "            tempy = data_context[i+look_back+gap:i+total_len,configs.masked:]\n",
    "            tempy_target = data_target[i+look_back+gap:i+total_len,0]\n",
    "            \n",
    "            tempx = np.array(tempx)\n",
    "            tempy = np.array(tempy)\n",
    "            tempy_target = np.array(tempy_target).reshape(-1,1)\n",
    "\n",
    "            x_train, y_train = helper(tempx, tempy, tempy_target)\n",
    "            x_train_total.append(x_train)\n",
    "            y_train_total.append(y_train)\n",
    "        \n",
    "        return np.concatenate(x_train_total), np.concatenate(y_train_total)\n",
    "\n",
    "    x_train, y_train = Create_dataset(data_context,data_target)\n",
    "    return x_train, y_train\n",
    "\n",
    "\n",
    "def create_datasets_location(df, configs, status = 1, valid = 1, use_scaler = 0):\n",
    "    location = configs.location\n",
    "    gap = configs.gap\n",
    "    look_back = configs.look_back\n",
    "    pred_len = configs.pred_len\n",
    "    total_len = gap+look_back+pred_len\n",
    "    task_type = configs.task_type\n",
    "    \n",
    "    data_conts = []\n",
    "    dfs = []\n",
    "    dfs_valid = []\n",
    "\n",
    "    df['date'] = pd.to_datetime(df['year'].astype(str) + df['dayofyear'].astype(str), format='%Y%j')\n",
    "    df['weekday'] = df['date'].dt.weekday\n",
    "    \n",
    "    data = df.copy()\n",
    "    \n",
    "    if valid:\n",
    "        df = data[(data[\"year\"] == 2022) | ((data[\"year\"] == 2023) & (data[\"dayofyear\"] <= 134))]\n",
    "        df_valid = data[(data[\"year\"] == 2023) & (data[\"dayofyear\"] >= 121)]\n",
    "        \n",
    "        for i in set(df_valid[\"valid_group\"]):\n",
    "            df_sub = df_valid[df_valid[\"valid_group\"] == i]\n",
    "            dfs_valid.append(df_sub)\n",
    "    else:\n",
    "        df = data.copy()\n",
    "\n",
    "    for i in set(df[\"train_group\"]):\n",
    "        df_sub = df[df[\"train_group\"] == i]\n",
    "        dfs.append(df_sub)\n",
    "        data_cont = read_data_cont(task_type, df_sub)\n",
    "        data_conts.append(data_cont)  \n",
    "    \n",
    "    if use_scaler:\n",
    "        scaler = use_scaler\n",
    "    else:\n",
    "        data_cont_cat = np.concatenate(data_conts, axis=0)\n",
    "        scaler = StandardScaler()\n",
    "        scaler.fit(data_cont_cat)\n",
    "    \n",
    "    os.makedirs(configs.path, exist_ok=True)\n",
    "    joblib.dump(scaler, configs.path+location+'_L='+str(look_back)+'_gap='+str(gap)+'_'+task_type+'_scaler.gz')\n",
    "\n",
    "    if status == 1:\n",
    "        x_train, y_train = [], []\n",
    "        for df in dfs:\n",
    "            if df.shape[0] > total_len:\n",
    "                x_train_sub, y_train_sub= create_datasets_onechunk(configs, df, scaler)\n",
    "            \n",
    "                x_train.append(x_train_sub)\n",
    "                y_train.append(y_train_sub)\n",
    "\n",
    "        x_valid, y_valid= [], []\n",
    "        for df in dfs_valid:\n",
    "            if df.shape[0] > total_len:\n",
    "                x_valid_sub, y_valid_sub= create_datasets_onechunk(configs, df, scaler)\n",
    "\n",
    "                x_valid.append(x_valid_sub)\n",
    "                y_valid.append(y_valid_sub)\n",
    "\n",
    "\n",
    "        input_datas = np.concatenate(x_train, axis=0)\n",
    "        output_datas = np.concatenate(y_train, axis=0)\n",
    "\n",
    "\n",
    "        input_datas_valid = np.concatenate(x_valid, axis=0)\n",
    "        output_datas_valid = np.concatenate(y_valid, axis=0)\n",
    "\n",
    "        return input_datas, output_datas, input_datas_valid, output_datas_valid\n",
    "    elif status == 0 or valid == 0:\n",
    "        \n",
    "        x_train, y_train = [], []\n",
    "        for df in dfs:\n",
    "            if df.shape[0] > total_len:\n",
    "                x_train_sub, y_train_sub= create_datasets_onechunk(configs, df, scaler)\n",
    "\n",
    "                x_train.append(x_train_sub)\n",
    "                y_train.append(y_train_sub)\n",
    "    \n",
    "        input_datas = np.concatenate(x_train, axis=0)\n",
    "        output_datas = np.concatenate(y_train, axis=0)\n",
    "        \n",
    "                        \n",
    "    return input_datas, output_datas"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "142582a6",
   "metadata": {},
   "outputs": [],
   "source": [
    "def GA_scaler(df, df2, configs, valid = 1):\n",
    "    \n",
    "    location = configs.location\n",
    "    gap = configs.gap\n",
    "    look_back = configs.look_back\n",
    "    pred_len = configs.pred_len\n",
    "    total_len = gap+look_back+pred_len\n",
    "    task_type = configs.task_type\n",
    "    \n",
    "    data_conts = []\n",
    "    dfs = []\n",
    "    dfs_valid = []\n",
    "\n",
    "    df['date'] = pd.to_datetime(df['year'].astype(str) + df['dayofyear'].astype(str), format='%Y%j')\n",
    "    df['weekday'] = df['date'].dt.weekday\n",
    "    \n",
    "    df2['date'] = pd.to_datetime(df2['year'].astype(str) + df2['dayofyear'].astype(str), format='%Y%j')\n",
    "    df2['weekday'] = df2['date'].dt.weekday\n",
    "    \n",
    "    data = df.copy()\n",
    "    data2 = df2.copy()\n",
    "    \n",
    "    if valid:\n",
    "        df = data[(data[\"year\"] == 2022) | ((data[\"year\"] == 2023) & (data[\"dayofyear\"] <= 134))]\n",
    "        df2 = data2[(data2[\"year\"] == 2022) | ((data2[\"year\"] == 2023) & (data2[\"dayofyear\"] <= 134))]   \n",
    "    else:\n",
    "        df = data.copy()\n",
    "        df2 = data2.copy()\n",
    "\n",
    "    for i in set(df[\"train_group\"]):\n",
    "        df_sub = df[df[\"train_group\"] == i]\n",
    "        dfs.append(df_sub)\n",
    "        data_cont = read_data_cont(task_type, df_sub)\n",
    "        data_conts.append(data_cont)  \n",
    "    \n",
    "    for i in set(df2[\"train_group\"]):\n",
    "        df_sub = df2[df2[\"train_group\"] == i]\n",
    "        dfs.append(df_sub)\n",
    "        data_cont = read_data_cont(task_type, df_sub)\n",
    "        data_conts.append(data_cont)  \n",
    "    \n",
    "\n",
    "    data_cont_cat = np.concatenate(data_conts, axis=0)\n",
    "    scaler = StandardScaler()\n",
    "    scaler.fit(data_cont_cat)\n",
    "    \n",
    "    return scaler"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "9981de08",
   "metadata": {},
   "outputs": [],
   "source": [
    "def MLP_train(location, data_seed, valid = 1, num_epochs = 20):\n",
    "    \n",
    "    class ModelParams:\n",
    "        context_window = 336\n",
    "        target_window = 24\n",
    "        look_back = 336\n",
    "        pred_len = 24\n",
    "        gap = 24\n",
    "        masked = 1\n",
    "        location = location\n",
    "        task_type = \"MLPAttention\"\n",
    "        path = f'models_and_scalers/{data_seed}/'\n",
    "\n",
    "    params = ModelParams  \n",
    "\n",
    "    path = f\"../../../data/preprocess/Final/{data_seed}/\"\n",
    "\n",
    "    \n",
    "    if valid:\n",
    "        if location != \"GA\":\n",
    "            df = pd.read_csv(path + f\"{location}_{data_seed}.csv\")\n",
    "            df[\"localTime\"] = pd.to_datetime(df[\"localTime\"])\n",
    "            df[\"date\"] = df[\"localTime\"].dt.date\n",
    "            df_sub = df.groupby([\"year\", \"date\"], as_index = False).count()\n",
    "            exclude = list(df_sub[df_sub[\"localTime\"]!=72][\"date\"])\n",
    "            df = df[~df[\"date\"].isin(exclude)].reset_index()\n",
    "            x_train, y_train, x_valid, y_valid = create_datasets_location(df, params)\n",
    "\n",
    "        if location == \"GA\":\n",
    "            df = pd.read_csv(path + f\"{location}_{data_seed}.csv\")\n",
    "            df2 = pd.read_csv(path + f\"{location}_{data_seed}_1.csv\")\n",
    "            \n",
    "            df[\"localTime\"] = pd.to_datetime(df[\"localTime\"])\n",
    "            df[\"date\"] = df[\"localTime\"].dt.date\n",
    "            df_sub = df.groupby([\"year\", \"date\"], as_index = False).count()\n",
    "            exclude = list(df_sub[df_sub[\"localTime\"]!=72][\"date\"])\n",
    "            df = df[~df[\"date\"].isin(exclude)].reset_index()\n",
    "            ga_scaler = GA_scaler(df, df2, params, valid)\n",
    "            x_train, y_train = create_datasets_location(df, params, status = 0, valid = 0, use_scaler = ga_scaler)\n",
    "\n",
    "            \n",
    "            df2[\"localTime\"] = pd.to_datetime(df2[\"localTime\"])\n",
    "            df2[\"date\"] = df2[\"localTime\"].dt.date\n",
    "            df_sub2 = df2.groupby([\"year\", \"date\"], as_index = False).count()\n",
    "            exclude2 = list(df_sub2[df_sub2[\"localTime\"]!=72][\"date\"])\n",
    "            df2 = df2[~df2[\"date\"].isin(exclude2)].reset_index()\n",
    "            \n",
    "            x_train2, y_train2, x_valid, y_valid = create_datasets_location(df2, params, use_scaler = ga_scaler)\n",
    "\n",
    "            x_train = np.concatenate([x_train, x_train2])\n",
    "            y_train = np.concatenate([y_train, y_train2])\n",
    "    else:\n",
    "        if location != \"GA\":\n",
    "            df = pd.read_csv(path + f\"{location}_{data_seed}.csv\")\n",
    "            df[\"localTime\"] = pd.to_datetime(df[\"localTime\"])\n",
    "            df[\"date\"] = df[\"localTime\"].dt.date\n",
    "            df_sub = df.groupby([\"year\", \"date\"], as_index = False).count()\n",
    "            exclude = list(df_sub[df_sub[\"localTime\"]!=72][\"date\"])\n",
    "            df = df[~df[\"date\"].isin(exclude)].reset_index()\n",
    "            x_train, y_train = create_datasets_location(df, params, status = 0, valid = 0)\n",
    "\n",
    "        if location == \"GA\":\n",
    "            df = pd.read_csv(path + f\"{location}_{data_seed}.csv\")\n",
    "            df2 = pd.read_csv(path + f\"{location}_{data_seed}_1.csv\")\n",
    "            \n",
    "            df[\"localTime\"] = pd.to_datetime(df[\"localTime\"])\n",
    "            df[\"date\"] = df[\"localTime\"].dt.date\n",
    "            df_sub = df.groupby([\"year\", \"date\"], as_index = False).count()\n",
    "            exclude = list(df_sub[df_sub[\"localTime\"]!=72][\"date\"])\n",
    "            df = df[~df[\"date\"].isin(exclude)].reset_index()\n",
    "            \n",
    "            ga_scaler = GA_scaler(df, df2, params, valid)\n",
    "           \n",
    "            x_train, y_train = create_datasets_location(df, params, status = 0, valid = 0, use_scaler = ga_scaler)\n",
    "            \n",
    "            df2[\"localTime\"] = pd.to_datetime(df2[\"localTime\"])\n",
    "            df2[\"date\"] = df2[\"localTime\"].dt.date\n",
    "            df_sub2 = df2.groupby([\"year\", \"date\"], as_index = False).count()\n",
    "            exclude2 = list(df_sub2[df_sub2[\"localTime\"]!=72][\"date\"])\n",
    "            df2 = df2[~df2[\"date\"].isin(exclude2)].reset_index()\n",
    "            \n",
    "            x_train2, y_train2 = create_datasets_location(df2, params, status = 0, valid = 0, use_scaler = ga_scaler)\n",
    "\n",
    "            x_train = np.concatenate([x_train, x_train2])\n",
    "            y_train = np.concatenate([y_train, y_train2])\n",
    "\n",
    "    x_train = torch.tensor(x_train, dtype=torch.float32)\n",
    "    y_train = torch.tensor(y_train, dtype=torch.float32)\n",
    "    train_dataset = TensorDataset(x_train, y_train)\n",
    "    train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True)\n",
    "    \n",
    "    if valid:\n",
    "        x_valid = torch.tensor(x_valid, dtype=torch.float32)\n",
    "        y_valid = torch.tensor(y_valid, dtype=torch.float32)\n",
    "        valid_dataset = TensorDataset(x_valid, y_valid)\n",
    "        valid_dataloader = DataLoader(valid_dataset, batch_size=32, shuffle=False)\n",
    "\n",
    "\n",
    "    if location == \"TX\":\n",
    "        model = MLPAttentionTX()\n",
    "    else:\n",
    "        model = MLPAttention()\n",
    "    \n",
    "    loss_func = CRPSLoss()\n",
    "\n",
    "    if location in [\"OR\",\"HI\"]:\n",
    "        optimizer = optim.Adam(model.parameters(), lr=0.005)\n",
    "    elif location in [\"TX\",\"GA\"]:\n",
    "        optimizer = optim.Adam(model.parameters(), lr=0.0025)\n",
    "    \n",
    "    scheduler = CosineAnnealingLR(optimizer, T_max=20)\n",
    "\n",
    "    best_valid_loss = float('inf')\n",
    "    best_epochs = 0\n",
    "    best_model_path = params.path + f'{location}_MLPAttention.pth'\n",
    "    \n",
    "    for epoch in range(num_epochs):\n",
    "        model.train()\n",
    "        train_loss = 0.0\n",
    "        for batch_x, batch_y in train_dataloader:\n",
    "            optimizer.zero_grad()\n",
    "            outputs = model(batch_x)\n",
    "            batch_y = batch_y.unsqueeze(1).unsqueeze(1)\n",
    "            loss = loss_func(outputs, batch_y)\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "            train_loss += loss.item() * batch_x.size(0)\n",
    "        train_loss /= len(train_dataloader.dataset)\n",
    "        \n",
    "        scheduler.step()\n",
    "        \n",
    "        if not valid:\n",
    "            torch.save(model.state_dict(), best_model_path)\n",
    "            print(f'Epoch [{epoch+1}/{num_epochs}, Train Loss: {train_loss:.4f}]')\n",
    "        \n",
    "        if valid:\n",
    "            model.eval()\n",
    "            valid_loss = 0.0\n",
    "            with torch.no_grad():\n",
    "                for batch_x, batch_y in valid_dataloader:\n",
    "                    outputs = model(batch_x)\n",
    "                    batch_y = batch_y.unsqueeze(1).unsqueeze(1)\n",
    "                    loss = loss_func(outputs, batch_y)\n",
    "                    valid_loss += loss.item() * batch_x.size(0)\n",
    "            valid_loss /= len(valid_dataloader.dataset)\n",
    "\n",
    "            print(f'Epoch [{epoch+1}/{num_epochs}], Train Loss: {train_loss:.4f}, Valid Loss: {valid_loss:.4f}')\n",
    "\n",
    "            if valid_loss < best_valid_loss:\n",
    "                best_valid_loss = valid_loss\n",
    "                best_epochs = epoch\n",
    "                torch.save(model.state_dict(), best_model_path)\n",
    "    \n",
    "    if valid:\n",
    "        print(f'Best Valid Loss: {best_valid_loss:.4f}')\n",
    "        return best_epochs\n",
    "    else:\n",
    "        return"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "id": "29cead72",
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "####################################################################\n",
      "2005\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "C:\\Users\\scukp\\AppData\\Local\\Temp\\ipykernel_31540\\1323600634.py:19: SettingWithCopyWarning: \n",
      "A value is trying to be set on a copy of a slice from a DataFrame\n",
      "\n",
      "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n",
      "  df.rename(columns = {\"timeofday\": \"hour\"}, inplace = True)\n",
      "C:\\Users\\scukp\\AppData\\Local\\Temp\\ipykernel_31540\\1323600634.py:19: SettingWithCopyWarning: \n",
      "A value is trying to be set on a copy of a slice from a DataFrame\n",
      "\n",
      "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n",
      "  df.rename(columns = {\"timeofday\": \"hour\"}, inplace = True)\n",
      "C:\\Users\\scukp\\AppData\\Local\\Temp\\ipykernel_31540\\1323600634.py:19: SettingWithCopyWarning: \n",
      "A value is trying to be set on a copy of a slice from a DataFrame\n",
      "\n",
      "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n",
      "  df.rename(columns = {\"timeofday\": \"hour\"}, inplace = True)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch [1/20], Train Loss: 0.0973, Valid Loss: 0.0321\n",
      "Epoch [2/20], Train Loss: 0.0369, Valid Loss: 0.0275\n",
      "Epoch [3/20], Train Loss: 0.0320, Valid Loss: 0.0296\n",
      "Epoch [4/20], Train Loss: 0.0298, Valid Loss: 0.0286\n",
      "Epoch [5/20], Train Loss: 0.0284, Valid Loss: 0.0317\n",
      "Epoch [6/20], Train Loss: 0.0281, Valid Loss: 0.0276\n",
      "Epoch [7/20], Train Loss: 0.0270, Valid Loss: 0.0299\n",
      "Epoch [8/20], Train Loss: 0.0262, Valid Loss: 0.0256\n",
      "Epoch [9/20], Train Loss: 0.0259, Valid Loss: 0.0293\n",
      "Epoch [10/20], Train Loss: 0.0254, Valid Loss: 0.0249\n",
      "Epoch [11/20], Train Loss: 0.0247, Valid Loss: 0.0291\n",
      "Epoch [12/20], Train Loss: 0.0242, Valid Loss: 0.0273\n",
      "Epoch [13/20], Train Loss: 0.0239, Valid Loss: 0.0291\n",
      "Epoch [14/20], Train Loss: 0.0236, Valid Loss: 0.0306\n",
      "Epoch [15/20], Train Loss: 0.0232, Valid Loss: 0.0290\n",
      "Epoch [16/20], Train Loss: 0.0228, Valid Loss: 0.0292\n",
      "Epoch [17/20], Train Loss: 0.0225, Valid Loss: 0.0305\n",
      "Epoch [18/20], Train Loss: 0.0224, Valid Loss: 0.0314\n",
      "Epoch [19/20], Train Loss: 0.0223, Valid Loss: 0.0291\n",
      "Epoch [20/20], Train Loss: 0.0222, Valid Loss: 0.0299\n",
      "Best Valid Loss: 0.0249\n",
      "Epoch [1/9, Train Loss: 0.0868]\n",
      "Epoch [2/9, Train Loss: 0.0326]\n",
      "Epoch [3/9, Train Loss: 0.0304]\n",
      "Epoch [4/9, Train Loss: 0.0288]\n",
      "Epoch [5/9, Train Loss: 0.0278]\n",
      "Epoch [6/9, Train Loss: 0.0267]\n",
      "Epoch [7/9, Train Loss: 0.0263]\n",
      "Epoch [8/9, Train Loss: 0.0254]\n",
      "Epoch [9/9, Train Loss: 0.0250]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "C:\\Users\\scukp\\AppData\\Local\\Temp\\ipykernel_31540\\1323600634.py:19: SettingWithCopyWarning: \n",
      "A value is trying to be set on a copy of a slice from a DataFrame\n",
      "\n",
      "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n",
      "  df.rename(columns = {\"timeofday\": \"hour\"}, inplace = True)\n",
      "C:\\Users\\scukp\\AppData\\Local\\Temp\\ipykernel_31540\\1323600634.py:19: SettingWithCopyWarning: \n",
      "A value is trying to be set on a copy of a slice from a DataFrame\n",
      "\n",
      "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n",
      "  df.rename(columns = {\"timeofday\": \"hour\"}, inplace = True)\n",
      "C:\\Users\\scukp\\AppData\\Local\\Temp\\ipykernel_31540\\1323600634.py:19: SettingWithCopyWarning: \n",
      "A value is trying to be set on a copy of a slice from a DataFrame\n",
      "\n",
      "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n",
      "  df.rename(columns = {\"timeofday\": \"hour\"}, inplace = True)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch [1/20], Train Loss: 0.0562, Valid Loss: 0.0589\n",
      "Epoch [2/20], Train Loss: 0.0270, Valid Loss: 0.0392\n",
      "Epoch [3/20], Train Loss: 0.0243, Valid Loss: 0.0261\n",
      "Epoch [4/20], Train Loss: 0.0223, Valid Loss: 0.0287\n",
      "Epoch [5/20], Train Loss: 0.0214, Valid Loss: 0.0303\n",
      "Epoch [6/20], Train Loss: 0.0212, Valid Loss: 0.0384\n",
      "Epoch [7/20], Train Loss: 0.0207, Valid Loss: 0.0324\n",
      "Epoch [8/20], Train Loss: 0.0201, Valid Loss: 0.0305\n",
      "Epoch [9/20], Train Loss: 0.0197, Valid Loss: 0.0273\n",
      "Epoch [10/20], Train Loss: 0.0194, Valid Loss: 0.0356\n",
      "Epoch [11/20], Train Loss: 0.0191, Valid Loss: 0.0299\n",
      "Epoch [12/20], Train Loss: 0.0186, Valid Loss: 0.0369\n",
      "Epoch [13/20], Train Loss: 0.0182, Valid Loss: 0.0324\n",
      "Epoch [14/20], Train Loss: 0.0180, Valid Loss: 0.0348\n",
      "Epoch [15/20], Train Loss: 0.0177, Valid Loss: 0.0327\n",
      "Epoch [16/20], Train Loss: 0.0175, Valid Loss: 0.0333\n",
      "Epoch [17/20], Train Loss: 0.0173, Valid Loss: 0.0329\n",
      "Epoch [18/20], Train Loss: 0.0170, Valid Loss: 0.0346\n",
      "Epoch [19/20], Train Loss: 0.0170, Valid Loss: 0.0352\n",
      "Epoch [20/20], Train Loss: 0.0169, Valid Loss: 0.0335\n",
      "Best Valid Loss: 0.0261\n",
      "Epoch [1/2, Train Loss: 0.0570]\n",
      "Epoch [2/2, Train Loss: 0.0268]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "C:\\Users\\scukp\\AppData\\Local\\Temp\\ipykernel_31540\\1323600634.py:19: SettingWithCopyWarning: \n",
      "A value is trying to be set on a copy of a slice from a DataFrame\n",
      "\n",
      "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n",
      "  df.rename(columns = {\"timeofday\": \"hour\"}, inplace = True)\n",
      "C:\\Users\\scukp\\AppData\\Local\\Temp\\ipykernel_31540\\1323600634.py:19: SettingWithCopyWarning: \n",
      "A value is trying to be set on a copy of a slice from a DataFrame\n",
      "\n",
      "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n",
      "  df.rename(columns = {\"timeofday\": \"hour\"}, inplace = True)\n",
      "C:\\Users\\scukp\\AppData\\Local\\Temp\\ipykernel_31540\\1323600634.py:19: SettingWithCopyWarning: \n",
      "A value is trying to be set on a copy of a slice from a DataFrame\n",
      "\n",
      "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n",
      "  df.rename(columns = {\"timeofday\": \"hour\"}, inplace = True)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch [1/20], Train Loss: 0.0559, Valid Loss: 0.0278\n",
      "Epoch [2/20], Train Loss: 0.0301, Valid Loss: 0.0257\n",
      "Epoch [3/20], Train Loss: 0.0273, Valid Loss: 0.0264\n",
      "Epoch [4/20], Train Loss: 0.0258, Valid Loss: 0.0252\n",
      "Epoch [5/20], Train Loss: 0.0251, Valid Loss: 0.0269\n",
      "Epoch [6/20], Train Loss: 0.0245, Valid Loss: 0.0297\n",
      "Epoch [7/20], Train Loss: 0.0240, Valid Loss: 0.0268\n",
      "Epoch [8/20], Train Loss: 0.0236, Valid Loss: 0.0254\n",
      "Epoch [9/20], Train Loss: 0.0232, Valid Loss: 0.0247\n",
      "Epoch [10/20], Train Loss: 0.0226, Valid Loss: 0.0231\n",
      "Epoch [11/20], Train Loss: 0.0224, Valid Loss: 0.0276\n",
      "Epoch [12/20], Train Loss: 0.0218, Valid Loss: 0.0253\n",
      "Epoch [13/20], Train Loss: 0.0215, Valid Loss: 0.0226\n",
      "Epoch [14/20], Train Loss: 0.0212, Valid Loss: 0.0231\n",
      "Epoch [15/20], Train Loss: 0.0209, Valid Loss: 0.0229\n",
      "Epoch [16/20], Train Loss: 0.0207, Valid Loss: 0.0225\n",
      "Epoch [17/20], Train Loss: 0.0204, Valid Loss: 0.0243\n",
      "Epoch [18/20], Train Loss: 0.0202, Valid Loss: 0.0249\n",
      "Epoch [19/20], Train Loss: 0.0201, Valid Loss: 0.0243\n",
      "Epoch [20/20], Train Loss: 0.0200, Valid Loss: 0.0244\n",
      "Best Valid Loss: 0.0225\n",
      "Epoch [1/15, Train Loss: 0.0476]\n",
      "Epoch [2/15, Train Loss: 0.0284]\n",
      "Epoch [3/15, Train Loss: 0.0266]\n",
      "Epoch [4/15, Train Loss: 0.0262]\n",
      "Epoch [5/15, Train Loss: 0.0253]\n",
      "Epoch [6/15, Train Loss: 0.0249]\n",
      "Epoch [7/15, Train Loss: 0.0242]\n",
      "Epoch [8/15, Train Loss: 0.0237]\n",
      "Epoch [9/15, Train Loss: 0.0231]\n",
      "Epoch [10/15, Train Loss: 0.0227]\n",
      "Epoch [11/15, Train Loss: 0.0223]\n",
      "Epoch [12/15, Train Loss: 0.0219]\n",
      "Epoch [13/15, Train Loss: 0.0215]\n",
      "Epoch [14/15, Train Loss: 0.0213]\n",
      "Epoch [15/15, Train Loss: 0.0208]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "C:\\Users\\scukp\\AppData\\Local\\Temp\\ipykernel_31540\\1323600634.py:19: SettingWithCopyWarning: \n",
      "A value is trying to be set on a copy of a slice from a DataFrame\n",
      "\n",
      "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n",
      "  df.rename(columns = {\"timeofday\": \"hour\"}, inplace = True)\n",
      "C:\\Users\\scukp\\AppData\\Local\\Temp\\ipykernel_31540\\1323600634.py:19: SettingWithCopyWarning: \n",
      "A value is trying to be set on a copy of a slice from a DataFrame\n",
      "\n",
      "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n",
      "  df.rename(columns = {\"timeofday\": \"hour\"}, inplace = True)\n",
      "C:\\Users\\scukp\\AppData\\Local\\Temp\\ipykernel_31540\\1323600634.py:19: SettingWithCopyWarning: \n",
      "A value is trying to be set on a copy of a slice from a DataFrame\n",
      "\n",
      "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n",
      "  df.rename(columns = {\"timeofday\": \"hour\"}, inplace = True)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch [1/20], Train Loss: 0.0881, Valid Loss: 0.0555\n",
      "Epoch [2/20], Train Loss: 0.0580, Valid Loss: 0.0668\n",
      "Epoch [3/20], Train Loss: 0.0552, Valid Loss: 0.0622\n",
      "Epoch [4/20], Train Loss: 0.0535, Valid Loss: 0.0567\n",
      "Epoch [5/20], Train Loss: 0.0533, Valid Loss: 0.0603\n",
      "Epoch [6/20], Train Loss: 0.0521, Valid Loss: 0.0815\n",
      "Epoch [7/20], Train Loss: 0.0513, Valid Loss: 0.0762\n",
      "Epoch [8/20], Train Loss: 0.0506, Valid Loss: 0.0605\n",
      "Epoch [9/20], Train Loss: 0.0498, Valid Loss: 0.0710\n",
      "Epoch [10/20], Train Loss: 0.0491, Valid Loss: 0.0563\n",
      "Epoch [11/20], Train Loss: 0.0489, Valid Loss: 0.0595\n",
      "Epoch [12/20], Train Loss: 0.0481, Valid Loss: 0.0628\n",
      "Epoch [13/20], Train Loss: 0.0476, Valid Loss: 0.0683\n",
      "Epoch [14/20], Train Loss: 0.0469, Valid Loss: 0.0692\n",
      "Epoch [15/20], Train Loss: 0.0466, Valid Loss: 0.0661\n",
      "Epoch [16/20], Train Loss: 0.0461, Valid Loss: 0.0634\n",
      "Epoch [17/20], Train Loss: 0.0457, Valid Loss: 0.0588\n",
      "Epoch [18/20], Train Loss: 0.0454, Valid Loss: 0.0634\n",
      "Epoch [19/20], Train Loss: 0.0452, Valid Loss: 0.0629\n",
      "Epoch [20/20], Train Loss: 0.0451, Valid Loss: 0.0652\n",
      "Best Valid Loss: 0.0555\n"
     ]
    }
   ],
   "source": [
    "# Driver: for every data seed and location, run MLP_train twice.\n",
    "#   1. A validation run (default arguments) whose return value identifies the best epoch.\n",
    "#   2. A final run with valid = 0 for that many epochs.\n",
    "for data_seed in range(2005, 2006):\n",
    "    print(\"####################################################################\")\n",
    "    print(data_seed)\n",
    "    for location in [\"GA\",\"TX\",\"OR\",\"HI\"]:\n",
    "        best_epoch_index = MLP_train(location, data_seed)\n",
    "        # NOTE(review): the value returned above appears to be the 0-based index of the\n",
    "        # best validation epoch, not an epoch count - in a recorded run a best epoch of 16\n",
    "        # led to a 15-epoch retrain, and a best epoch of 1 led to a retrain with no epochs\n",
    "        # at all. Add 1 so the final model trains through the best epoch. Confirm against\n",
    "        # MLP_train's return statement; drop the + 1 if it is changed to return a count.\n",
    "        MLP_train(location, data_seed, valid = 0, num_epochs = best_epoch_index + 1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "76b9a925",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
