{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "58e87f88-cd51-451a-ba6d-451bf0cf8e30",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import torch\n",
    "from torch import nn, optim\n",
    "from torch.autograd import Variable\n",
    "import torch.nn.functional as F\n",
    "import time\n",
    "import math\n",
    "import matplotlib.pyplot as plt\n",
    "from torch.utils.data import Dataset, DataLoader\n",
    "from sklearn.preprocessing import MinMaxScaler\n",
    "from sklearn.preprocessing import StandardScaler\n",
    "import csv\n",
    "import pandas as pd\n",
    "import os\n",
    "import joblib\n",
    "from torch.optim.lr_scheduler import CosineAnnealingLR\n",
    "\n",
    "data_dir = '/root/data_GA'\n",
    "model_dir = '/root'\n",
    "scaler_dir = '/root/scaler_GA'\n",
    "\n",
    "start_year = 2005\n",
    "end_year = 2024\n",
    "def train_lstm(file_path,file_path_sup,year):\n",
    "    \"\"\"Train the quantile LSTM on one pair of CSV files and keep the best checkpoint.\n",
    "\n",
    "    Args:\n",
    "        file_path: CSV with the main series; provides training and validation rows.\n",
    "        file_path_sup: CSV with supplementary series; provides training rows only.\n",
    "        year: tag inserted into the scaler and checkpoint file names.\n",
    "    \"\"\"\n",
    "    frame_main = pd.read_csv(file_path)\n",
    "    frame_sup = pd.read_csv(file_path_sup)\n",
    "    print(frame_main.columns)\n",
    "\n",
    "    if torch.cuda.is_available():\n",
    "        device = torch.device(\"cuda\")\n",
    "    else:\n",
    "        device = torch.device(\"cpu\")\n",
    "    print(\"Using device:\", device)\n",
    "\n",
    "    def _training_rows(frame):\n",
    "        # All of 2022 plus 2023 up to and including day-of-year 134.\n",
    "        in_2022 = frame[\"year\"] == 2022\n",
    "        early_2023 = (frame[\"year\"] == 2023) & (frame[\"dayofyear\"] <= 134)\n",
    "        return frame[in_2022 | early_2023]\n",
    "\n",
    "    raw_main = frame_main.copy()\n",
    "    df = _training_rows(raw_main)\n",
    "    # Validation rows start at day 121 of 2023, so they overlap the last training days\n",
    "    # (presumably to provide look-back history for the first validation windows).\n",
    "    late_2023 = (raw_main[\"year\"] == 2023) & (raw_main[\"dayofyear\"] >= 121)\n",
    "    df_valid = raw_main[late_2023]\n",
    "\n",
    "    raw_sup = frame_sup.copy()\n",
    "    df_sup = _training_rows(raw_sup)\n",
    "\n",
    "    print(df.shape,df_valid.shape,df_sup.shape)\n",
    "    \n",
    "    def read_data_cont(df):\n",
    "        data_cont = df[['consumption','solar','DNI','DHI','temperature','relativehumidity']]   \n",
    "        data_cont = data_cont.values\n",
    "        return data_cont\n",
    "    def read_data_target(df):\n",
    "        data_target = df[['total_grid']]\n",
    "        data_target = data_target.values\n",
    "        return data_target\n",
    "    def read_data_time(df):\n",
    "        data_time = df[['dayofweek','timeofday', 'month']]\n",
    "        data_time = data_time.values\n",
    "        return data_time\n",
    "    \n",
    "    gap = 24\n",
    "    look_back = 336\n",
    "    pred_len = 24\n",
    "    total_len = gap+look_back+pred_len\n",
    "    scaler = MinMaxScaler(feature_range=(0, 1))\n",
    "    # scaler = StandardScaler()\n",
    "    cycl_ = lambda x,period : torch.tensor((np.sin(x / period * 2 * np.pi),np.cos(x / period * 2 * np.pi))).type(torch.float32)\n",
    "\n",
    "\n",
    "    def Create_dataset(df):\n",
    "        \"\"\"Turn the rows of one group into sliding-window samples.\n",
    "\n",
    "        Every sample pairs `look_back` rows of the scaled continuous features (the\n",
    "        history) with the covariates of the `pred_len` rows that follow after `gap`\n",
    "        rows, and uses the `total_grid` values of those rows as target.\n",
    "\n",
    "        Args:\n",
    "            df: rows of a single group; presumably ordered in time with one row per\n",
    "                step - TODO confirm against the CSV layout.\n",
    "\n",
    "        Returns:\n",
    "            (data_X, data_Y). data_X has shape (n, look_back*c + pred_len*k, 1): the\n",
    "            flattened history followed by the flattened future covariates, where c is\n",
    "            the number of continuous columns and k counts every context column except\n",
    "            the first one (consumption). data_Y has shape (n, pred_len, 1).\n",
    "            n is 0 when the group is shorter than one full window.\n",
    "        \"\"\"\n",
    "        data_cont = read_data_cont(df)\n",
    "        data_target = read_data_target(df)\n",
    "        data_time = read_data_time(df)\n",
    "        n_cont = data_cont.shape[1]\n",
    "\n",
    "        # +1 because the window that ends exactly on the last row is valid as well.\n",
    "        n_windows = len(data_cont) - total_len + 1\n",
    "        if n_windows <= 0:\n",
    "            # Too few rows for a single window: return correctly shaped empty arrays so\n",
    "            # the caller can still concatenate this group with the others.\n",
    "            n_future = (n_cont - 1) + 2 * data_time.shape[1]\n",
    "            empty_X = np.empty((0, look_back * n_cont + pred_len * n_future, 1))\n",
    "            empty_Y = np.empty((0, pred_len, 1))\n",
    "            return empty_X, empty_Y\n",
    "\n",
    "        data_cont = scaler.transform(data_cont)\n",
    "\n",
    "        W = cycl_(data_time[:,0],7)    # day of week\n",
    "        H = cycl_(data_time[:,1],24)   # time slot of the day\n",
    "        M = cycl_(data_time[:,2]-1,12)   # month of year, shifted to start at 0\n",
    "        data_time = np.concatenate((W,H,M),0).T\n",
    "\n",
    "        data_context = np.concatenate((data_cont,data_time),1)\n",
    "\n",
    "        data_V, data_W, data_Y = [], [], []\n",
    "        for i in range(n_windows):\n",
    "            future_rows = slice(i+look_back+gap, i+total_len)\n",
    "            # History: continuous columns only, flattened row by row.\n",
    "            data_V.append(data_context[i:i+look_back,:n_cont].reshape(-1,1))\n",
    "            # Future covariates: everything except column 0 (consumption).\n",
    "            data_W.append(data_context[future_rows,1:].reshape(-1,1))\n",
    "            data_Y.append(data_target[future_rows,0])\n",
    "\n",
    "        data_X = np.concatenate((np.array(data_V),np.array(data_W)),1)\n",
    "        data_Y = np.expand_dims(np.array(data_Y),-1)\n",
    "\n",
    "        return data_X, data_Y\n",
    "\n",
    "    def _split_by(frame, column):\n",
    "        # One sub-frame per distinct value of `column`, in order of first appearance.\n",
    "        # Iterating a set() gives no guaranteed order, so the sample order could vary\n",
    "        # between runs; missing labels are skipped instead of yielding empty frames.\n",
    "        return [frame[frame[column] == key] for key in pd.unique(frame[column].dropna())]\n",
    "\n",
    "    # Training groups: main file first, then the supplementary file.\n",
    "    dfs = _split_by(df, \"train_group\") + _split_by(df_sup, \"train_group\")\n",
    "    dfs_valid = _split_by(df_valid, \"valid_group\")\n",
    "\n",
    "    # Fit the scaler on the continuous features of all training groups\n",
    "    # (validation rows are not included) and persist it.\n",
    "    data_conts = [read_data_cont(df_sub) for df_sub in dfs]\n",
    "    data_cont_cat = np.concatenate(data_conts, axis=0)\n",
    "    scaler.fit(data_cont_cat)\n",
    "    os.makedirs(scaler_dir, exist_ok=True)   # joblib.dump does not create missing directories\n",
    "    joblib.dump(scaler, os.path.join(scaler_dir, 'GA_'+ str(year) + '_scaler.gz'))\n",
    "    \n",
    "    \n",
    "    def _windows_for(frames):\n",
    "        \"\"\"Apply Create_dataset to every frame and join the samples along axis 0.\"\"\"\n",
    "        xs, ys = [], []\n",
    "        # `frame` instead of `df`, so the training frame bound to `df` is not overwritten.\n",
    "        for frame in frames:\n",
    "            data_X, data_Y = Create_dataset(frame)\n",
    "            xs.append(data_X)\n",
    "            ys.append(data_Y)\n",
    "        return np.concatenate(xs, axis=0), np.concatenate(ys, axis=0)\n",
    "\n",
    "    t1 = time.time()\n",
    "    input_X, input_Y = _windows_for(dfs)\n",
    "    t2 = time.time()\n",
    "    print(t2-t1)\n",
    "    print(input_X.shape,input_Y.shape)\n",
    "\n",
    "    t1 = time.time()\n",
    "    input_X_valid, input_Y_valid = _windows_for(dfs_valid)\n",
    "    t2 = time.time()\n",
    "    print(t2-t1)\n",
    "    print(input_X_valid.shape,input_Y_valid.shape)\n",
    "\n",
    "    class Train(Dataset):\n",
    "        def __init__(self, data):\n",
    "            self.hist, self.future, self.label = data[:,:2016,:].float(), data[:,2016:2016+264,:].float(), data[:,-24:,:].float()\n",
    "\n",
    "        def __getitem__(self, index):\n",
    "            return self.hist[index], self.future[index], self.label[index]\n",
    "\n",
    "        def __len__(self):\n",
    "            return len(self.hist)\n",
    "    train_loader = DataLoader(Train(torch.cat((torch.tensor(input_X),torch.tensor(input_Y)),1)), batch_size=32,shuffle=True)\n",
    "    valid_loader = DataLoader(Train(torch.cat((torch.tensor(input_X_valid),torch.tensor(input_Y_valid)),1)), batch_size=len(input_X_valid),shuffle=False)\n",
    "\n",
    "    class LSTM(nn.Module):\n",
    "\n",
    "        def __init__(self):\n",
    "            super(LSTM, self).__init__()\n",
    "\n",
    "            self.lstm = nn.LSTM(\n",
    "                input_size=6,   \n",
    "                hidden_size=60,\n",
    "                num_layers=1, \n",
    "                batch_first=True,\n",
    "                # dropout=0.3\n",
    "            )\n",
    "\n",
    "            self.ann2 = nn.Sequential(\n",
    "                nn.Linear(11,11),\n",
    "                nn.ReLU(),\n",
    "                nn.Linear(11,1)\n",
    "            )\n",
    "            self.out = nn.Sequential(\n",
    "                nn.Linear(60,11),\n",
    "            )\n",
    "\n",
    "        def forward(self, x, z):\n",
    "\n",
    "            batch = x.shape[0]\n",
    "            tx = torch.reshape(x,(batch,336,-1))\n",
    "            tx, _ = self.lstm(tx)\n",
    "            tz = torch.reshape(z,(batch,24,-1))\n",
    "            temp = self.ann2(tz)\n",
    "            out = self.out(tx[:,-24:,:]+temp)\n",
    "\n",
    "            out = torch.unsqueeze(out,2)\n",
    "            out1, out2 = out[:,:,:,0], out[:,:,:,1:]\n",
    "            out1 = torch.unsqueeze(out1, -1)\n",
    "            out2 = F.softplus(out2)\n",
    "            out = torch.cat((out1,out2),dim=-1)\n",
    "            out = torch.cumsum(out, dim=-1)\n",
    "            \n",
    "            return out\n",
    "\n",
    "\n",
    "    class CRPSLoss(nn.Module):\n",
    "        def __init__(self, quantiles=[0,0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9,1], adjusted=True):\n",
    "            super().__init__()\n",
    "            self.adjusted = adjusted\n",
    "            if self.adjusted:\n",
    "                self.quantiles = torch.tensor([0]+quantiles+[1])\n",
    "            else:\n",
    "                self.quantiles = torch.tensor(quantiles)\n",
    "\n",
    "        def forward(self, preds, target): #preds:[B,N,T,Q] target:[B,N,T]\n",
    "            assert not target.requires_grad\n",
    "            assert preds.size(0) == target.size(0)\n",
    "\n",
    "            # this is adjusted to avoid some extrme values that larger than the max or smaller than the min in quantiles, so a max bound and min bound are provided\n",
    "            if self.adjusted:\n",
    "                B = preds.shape[0]\n",
    "                N = preds.shape[1]\n",
    "                T = preds.shape[2]\n",
    "                max_bound = 100\n",
    "                min_bound = -100\n",
    "                preds = torch.cat((min_bound*torch.ones((B, N, T, 1), device=preds.device), preds), dim=-1)\n",
    "                preds = torch.cat((preds, max_bound*torch.ones((B, N, T, 1), device=preds.device)), dim=-1)\n",
    "\n",
    "            q_i1 = self.quantiles[1:].to(preds.device)\n",
    "            q_i = self.quantiles[:-1].to(preds.device)\n",
    "            X_i1 = preds[:,:,:,1:]\n",
    "            X_i = preds[:,:,:,:-1]\n",
    "            X_t = target.unsqueeze(3).repeat(1, 1, 1, X_i.shape[-1])\n",
    "\n",
    "            # this index indicates which case this interval is in. 0 -> term0, 1 -> term1, 2 -> term2\n",
    "            index = torch.full_like(X_i1, 2)\n",
    "            index[X_t > X_i1] = 0\n",
    "            index[X_t < X_i] = 1\n",
    "            index = F.one_hot(index.to(torch.int64), num_classes=3) #ntqd\n",
    "\n",
    "            # term0: F(x_i+1)>O(x_i+1)\n",
    "            # term1: F(x_i)<O(x_i) \n",
    "            # term2: F(x_i)>O(x_i) and F(x_i+1)<O(x_i+1)\n",
    "            # All the terms are calculated with precise integration function\n",
    "            term0 = 1/3*torch.einsum('bntq,q->bntq', X_i1-X_i, (q_i1**2+q_i1*q_i+q_i**2))\n",
    "            term1 = 1/3*torch.einsum('bntq,q->bntq', X_i1-X_i, ((q_i1-1)**2+(q_i1-1)*(q_i-1)+(q_i-1)**2))\n",
    "            #term2 = torch.einsum('bntq,q->bntq', X_i1-X_t, 1-2*q_i) + torch.einsum('bntq,q->bntq', ((X_i1-X_i)**2-(X_t-X_i)**2)/(X_i1-X_i), -(q_i1-q_i)) + term0\n",
    "            # this is the same as the line below\n",
    "            term2 = torch.einsum('bntq,q->bntq', X_t-X_i, 2*q_i-1) + torch.einsum('bntq,q->bntq', (X_t-X_i)**2/(X_i1-X_i), q_i1-q_i) + term1\n",
    "            terms = torch.stack((term0,term1,term2),dim=-1)\n",
    "\n",
    "            loss = torch.einsum('bntqd,bntqd->bntq', index.to(torch.float), terms.to(torch.float))\n",
    "            return torch.mean(torch.sum(loss,dim=-1))\n",
    "\n",
    "    LR = 0.002\n",
    "    EPOCH = 30\n",
    "    loss_func = CRPSLoss()\n",
    "\n",
    "    lstm = LSTM().to(device)\n",
    "    optimizer = torch.optim.Adam(lstm.parameters(), lr=LR)\n",
    "\n",
    "    best_loss = float('inf')\n",
    "    # The scaler is written to scaler_dir; write the checkpoint to model_dir as well\n",
    "    # instead of whatever the current working directory happens to be.\n",
    "    os.makedirs(model_dir, exist_ok=True)\n",
    "    checkpoint_path = os.path.join(model_dir, f'GA_{year}.pt')\n",
    "\n",
    "    for epoch in range(EPOCH):\n",
    "        lstm.train()\n",
    "        train_sum, train_seen = 0.0, 0\n",
    "        for x, z, y in train_loader:\n",
    "            x, z, y = x.to(device), z.to(device), y.to(device)\n",
    "\n",
    "            optimizer.zero_grad()\n",
    "            load_out = lstm(x,z)\n",
    "            loss = loss_func(load_out,y)\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "\n",
    "            # Weight by batch size so a smaller final batch does not skew the epoch mean.\n",
    "            train_sum += loss.item() * len(x)\n",
    "            train_seen += len(x)\n",
    "\n",
    "        lstm.eval()\n",
    "        val_sum, val_seen = 0.0, 0\n",
    "        with torch.no_grad():\n",
    "            for tx, tz, ty in valid_loader:\n",
    "                tx, tz, ty = tx.to(device), tz.to(device), ty.to(device)\n",
    "                load_valout = lstm(tx,tz)\n",
    "                val_sum += loss_func(load_valout,ty).item() * len(tx)\n",
    "                val_seen += len(tx)\n",
    "\n",
    "        # Report epoch means rather than the loss of the last batch. NaN marks an epoch\n",
    "        # without batches; it never compares smaller than best_loss, so nothing is saved.\n",
    "        train_loss = train_sum / train_seen if train_seen else float('nan')\n",
    "        val_loss = val_sum / val_seen if val_seen else float('nan')\n",
    "        print(f'Epoch {epoch+1}, Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}')\n",
    "\n",
    "        if val_loss < best_loss:\n",
    "            best_loss = val_loss\n",
    "            torch.save(lstm.state_dict(), checkpoint_path)\n",
    "            print(f'New model saved at epoch {epoch+1} with val_loss {best_loss:.4f}')\n",
    "\n",
    "    print(f\"Trained lstm model from {file_path}\")\n",
    "    \n",
    "for year in range(start_year, end_year + 1):\n",
    "    filename = f\"GA_{year}_1.csv\"\n",
    "    filename_sup = f\"GA_{year}.csv\"\n",
    "    file_path = os.path.join(data_dir, filename)\n",
    "    file_path_sup = os.path.join(data_dir, filename_sup)\n",
    "\n",
    "    # train_lstm reads both files unconditionally, so a year is only processed when\n",
    "    # both exist; otherwise every missing file is reported and the year is skipped.\n",
    "    missing = [p for p in (file_path, file_path_sup) if not os.path.exists(p)]\n",
    "    if missing:\n",
    "        for path in missing:\n",
    "            print(f\"File not found: {path}\")\n",
    "        continue\n",
    "    train_lstm(file_path,file_path_sup,year)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "06186319-2b57-4e10-96db-f8f569540c1d",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d5548fc5-c0c2-4c12-ad84-a560d8e9f09b",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
