{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "42db30e8-5a77-4ff1-bc9b-0607d77d49e0",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# import numpy as np\n",
    "# import torch\n",
    "# from torch import nn, optim\n",
    "# from torch.autograd import Variable\n",
    "# import torch.nn.functional as F\n",
    "# import time\n",
    "# import math\n",
    "# import matplotlib.pyplot as plt\n",
    "# from torch.utils.data import Dataset, DataLoader\n",
    "# from sklearn.preprocessing import MinMaxScaler\n",
    "# import csv\n",
    "# import pandas as pd\n",
    "# import os\n",
    "# from torch.optim.lr_scheduler import CosineAnnealingLR\n",
    "\n",
    "\n",
    "# base_dir = '/root/data'\n",
    "\n",
    "# start_year = 2010\n",
    "# end_year = 2010\n",
    "\n",
    "# def train_lstm(file_path,year):\n",
    "#     df = pd.read_csv(file_path)\n",
    "#     print(df.columns)\n",
    "    \n",
    "#     device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "#     print(\"Using device:\", device)\n",
    "#     data = df.copy()\n",
    "#     df = data[(data[\"year\"] == 2022) | ((data[\"year\"] == 2023) & (data[\"dayofyear\"] <= 134))]\n",
    "#     df_valid = data[(data[\"year\"] == 2023) & (data[\"dayofyear\"] >= 121)]\n",
    "    \n",
    "\n",
    "#     def read_data_cont(df):\n",
    "#         data_cont = df[['consumption','solar','DNI','DHI','temperature','relativehumidity']]   \n",
    "#         data_cont = data_cont.values\n",
    "#         return data_cont\n",
    "#     def read_data_target(df):\n",
    "#         data_target = df[['total_grid']]\n",
    "#         data_target = data_target.values\n",
    "#         return data_target\n",
    "#     def read_data_time(df):\n",
    "#         data_time = df[['dayofweek','timeofday', 'month']]\n",
    "#         data_time = data_time.values\n",
    "#         return data_time\n",
    "    \n",
    "#     gap = 24\n",
    "#     look_back = 336\n",
    "#     pred_len = 24\n",
    "#     total_len = gap+look_back+pred_len\n",
    "#     scaler = MinMaxScaler(feature_range=(0, 1))\n",
    "#     cycl_ = lambda x,period : torch.tensor((np.sin(x / period * 2 * np.pi),np.cos(x / period * 2 * np.pi))).type(torch.float32)\n",
    "\n",
    "\n",
    "#     def Create_dataset(df):\n",
    "        \n",
    "#         data_cont = read_data_cont(df)\n",
    "#         data_target = read_data_target(df)\n",
    "#         data_time = read_data_time(df)\n",
    "\n",
    "#         scaler.fit(data_cont)\n",
    "#         data_cont = scaler.transform(data_cont)\n",
    "\n",
    "#         W = cycl_(data_time[:,0],7)    # day of week\n",
    "#         H = cycl_(data_time[:,1],24)   # timeslot of the day\n",
    "#         M = cycl_(data_time[:,2]-1,12)   # month of year\n",
    "#         data_time = np.concatenate((W,H,M),0).T\n",
    "        \n",
    "#         data_context = np.concatenate((data_cont,data_time),1)\n",
    "\n",
    "#         data_X, data_Z, data_Y = [], [], []\n",
    "#         for i in range(len(data_context)-total_len): \n",
    "#             tempx = data_context[i:i+look_back,:].reshape(-1,1)\n",
    "#             tempz = data_context[i+look_back+gap:i+total_len,1:].reshape(-1,1)\n",
    "#             tempy = data_target[i+look_back+gap:i+total_len,0]\n",
    "\n",
    "\n",
    "#             data_X.append(tempx)\n",
    "#             data_Z.append(tempz)\n",
    "#             data_Y.append(tempy)\n",
    "\n",
    "#         data_X = np.array(data_X)\n",
    "#         data_Z = np.array(data_Z)\n",
    "#         data_Y = np.expand_dims(np.array(data_Y),-1)\n",
    "\n",
    "#         return data_X, data_Z, data_Y\n",
    "\n",
    "#     t1 = time.time()\n",
    "#     data_X, data_Z, data_Y = Create_dataset(df)\n",
    "#     t2 = time.time()\n",
    "#     print(t2-t1)\n",
    "#     print(data_X.shape,data_Z.shape)\n",
    "#     data_X = np.concatenate((data_X,data_Z),1)\n",
    "#     print(data_X.shape,data_Y.shape,data_Z.shape)\n",
    "\n",
    "#     t1 = time.time()\n",
    "#     data_X_valid, data_Z_valid, data_Y_valid = Create_dataset(df_valid)\n",
    "#     t2 = time.time()\n",
    "#     print(t2-t1)\n",
    "#     print(data_X_valid.shape)\n",
    "#     data_X_valid = np.concatenate((data_X_valid,data_Z_valid),1)\n",
    "#     print(data_X_valid.shape,data_Y_valid.shape,data_Z_valid.shape)\n",
    "    \n",
    "\n",
    "#     class Train(Dataset):\n",
    "#         def __init__(self, data):\n",
    "#             self.hist, self.future, self.label = data[:,:4032,:].float(), data[:,4032:4032+264,:].float(), data[:,-24:,:].float()\n",
    "\n",
    "#         def __getitem__(self, index):\n",
    "#             return self.hist[index], self.future[index], self.label[index]\n",
    "\n",
    "#         def __len__(self):\n",
    "#             return len(self.hist)\n",
    "\n",
    "#     train_loader = DataLoader(Train(torch.cat((torch.tensor(data_X),torch.tensor(data_Y)),1)), batch_size=64,shuffle=True)\n",
    "#     valid_loader = DataLoader(Train(torch.cat((torch.tensor(data_X_valid),torch.tensor(data_Y_valid)),1)), batch_size=100000,shuffle=False)\n",
    "\n",
    "#     class LSTM(nn.Module):\n",
    "\n",
    "#         def __init__(self):\n",
    "#             super(LSTM, self).__init__()\n",
    "\n",
    "#             self.lstm = nn.LSTM(\n",
    "#                 input_size=12,   \n",
    "#                 hidden_size=360,\n",
    "#                 num_layers=3, \n",
    "#                 batch_first=True,\n",
    "#                 dropout=0.3\n",
    "#             )\n",
    "\n",
    "#             self.ann2 = nn.Sequential(\n",
    "#                 nn.Linear(11,1),\n",
    "#                 # nn.Dropout(0.3)\n",
    "#             )\n",
    "#             self.out = nn.Sequential(\n",
    "#                 nn.Linear(360,11)\n",
    "#             )\n",
    "\n",
    "#         def forward(self, x, z):\n",
    "\n",
    "#             batch = x.shape[0]\n",
    "#             tx = torch.reshape(x,(batch,336,-1))\n",
    "#             tx, _ = self.lstm(tx)\n",
    "#             tz = torch.reshape(z,(batch,24,-1))\n",
    "#             temp = self.ann2(tz)\n",
    "#             out = self.out(tx[:,-24:,:]+temp)\n",
    "\n",
    "#             out = torch.unsqueeze(out,2)\n",
    "#             out1, out2 = out[:,:,:,0], out[:,:,:,1:]\n",
    "#             out1 = torch.unsqueeze(out1, -1)\n",
    "#             out2 = F.softplus(out2)\n",
    "#             out = torch.cat((out1,out2),dim=-1)\n",
    "#             out = torch.cumsum(out, dim=-1)\n",
    "            \n",
    "#             return out\n",
    "\n",
    "\n",
    "#     class CRPSLoss(nn.Module):\n",
    "#         def __init__(self, quantiles=[0,0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9,1], adjusted=True):\n",
    "#             super().__init__()\n",
    "#             self.adjusted = adjusted\n",
    "#             if self.adjusted:\n",
    "#                 self.quantiles = torch.tensor([0]+quantiles+[1])\n",
    "#             else:\n",
    "#                 self.quantiles = torch.tensor(quantiles)\n",
    "\n",
    "#         def forward(self, preds, target): #preds:[B,N,T,Q] target:[B,N,T]\n",
    "#             assert not target.requires_grad\n",
    "#             assert preds.size(0) == target.size(0)\n",
    "\n",
    "#             # this is adjusted to avoid extreme values that are larger than the max or smaller than the min of the quantiles, so a max bound and a min bound are provided\n",
    "#             if self.adjusted:\n",
    "#                 B = preds.shape[0]\n",
    "#                 N = preds.shape[1]\n",
    "#                 T = preds.shape[2]\n",
    "#                 max_bound = 100\n",
    "#                 min_bound = -100\n",
    "#                 preds = torch.cat((min_bound*torch.ones((B, N, T, 1), device=preds.device), preds), dim=-1)\n",
    "#                 preds = torch.cat((preds, max_bound*torch.ones((B, N, T, 1), device=preds.device)), dim=-1)\n",
    "\n",
    "#             q_i1 = self.quantiles[1:].to(preds.device)\n",
    "#             q_i = self.quantiles[:-1].to(preds.device)\n",
    "#             X_i1 = preds[:,:,:,1:]\n",
    "#             X_i = preds[:,:,:,:-1]\n",
    "#             X_t = target.unsqueeze(3).repeat(1, 1, 1, X_i.shape[-1])\n",
    "\n",
    "#             # this index indicates which case this interval is in. 0 -> term0, 1 -> term1, 2 -> term2\n",
    "#             index = torch.full_like(X_i1, 2)\n",
    "#             index[X_t > X_i1] = 0\n",
    "#             index[X_t < X_i] = 1\n",
    "#             index = F.one_hot(index.to(torch.int64), num_classes=3) #ntqd\n",
    "\n",
    "#             # term0: F(x_i+1)>O(x_i+1)\n",
    "#             # term1: F(x_i)<O(x_i) \n",
    "#             # term2: F(x_i)>O(x_i) and F(x_i+1)<O(x_i+1)\n",
    "#             # All the terms are calculated with precise integration function\n",
    "#             term0 = 1/3*torch.einsum('bntq,q->bntq', X_i1-X_i, (q_i1**2+q_i1*q_i+q_i**2))\n",
    "#             term1 = 1/3*torch.einsum('bntq,q->bntq', X_i1-X_i, ((q_i1-1)**2+(q_i1-1)*(q_i-1)+(q_i-1)**2))\n",
    "#             #term2 = torch.einsum('bntq,q->bntq', X_i1-X_t, 1-2*q_i) + torch.einsum('bntq,q->bntq', ((X_i1-X_i)**2-(X_t-X_i)**2)/(X_i1-X_i), -(q_i1-q_i)) + term0\n",
    "#             # this is the same as the line below\n",
    "#             term2 = torch.einsum('bntq,q->bntq', X_t-X_i, 2*q_i-1) + torch.einsum('bntq,q->bntq', (X_t-X_i)**2/(X_i1-X_i), q_i1-q_i) + term1\n",
    "#             terms = torch.stack((term0,term1,term2),dim=-1)\n",
    "\n",
    "#             loss = torch.einsum('bntqd,bntqd->bntq', index.to(torch.float), terms.to(torch.float))\n",
    "#             return torch.mean(torch.sum(loss,dim=-1))\n",
    "\n",
    "#     LR = 0.001\n",
    "#     EPOCH = 100\n",
    "#     loss_func = CRPSLoss()\n",
    "\n",
    "\n",
    "#     lstm = LSTM().to(device)\n",
    "#     optimizer = torch.optim.Adam(lstm.parameters(), lr=LR)\n",
    "#     scheduler = CosineAnnealingLR(optimizer, T_max=20)\n",
    "\n",
    "#     best_loss = 1\n",
    "#     state = None\n",
    "#     Loss = []\n",
    "#     Val = []\n",
    "\n",
    "#     for epoch in range(EPOCH):\n",
    "#         for x, z, y in train_loader:\n",
    "#             x, z, y = x.to(device), z.to(device), y.to(device)\n",
    "\n",
    "#             lstm.train()\n",
    "#             optimizer.zero_grad()\n",
    "#             load_out = lstm(x,z)\n",
    "#             loss = loss_func(load_out,y)\n",
    "#             loss.backward()  \n",
    "#             optimizer.step()\n",
    "\n",
    "#         # scheduler.step()\n",
    "#         lstm.eval()\n",
    "#         with torch.no_grad():\n",
    "#             for tx, tz, ty in valid_loader:\n",
    "#                 tx, tz, ty = tx.to(device), tz.to(device), ty.to(device)\n",
    "#                 load_valout = lstm(tx,tz)\n",
    "#                 val_loss = loss_func(load_valout,ty)\n",
    "#             print(f'Epoch {epoch+1}, Loss: {loss.item():.4f}, Val Loss: {val_loss.item():.4f}')\n",
    "\n",
    "#         if val_loss.cpu().item() < best_loss:\n",
    "#             best_loss = val_loss.cpu().item()\n",
    "#             torch.save(lstm.state_dict(), f'OR_{year}.pt')\n",
    "#             print(f'New model saved at epoch {epoch+1} with val_loss {best_loss:.4f}')\n",
    "\n",
    "#     print(f\"Trained lstm model from {file_path}\")\n",
    "    \n",
    "# for year in range(start_year, end_year + 1):\n",
    "#     filename = f\"OR_{year}.csv\"\n",
    "#     file_path = os.path.join(base_dir, filename)\n",
    "    \n",
    "#     # Check if the file exists before processing\n",
    "#     if os.path.exists(file_path):\n",
    "#         train_lstm(file_path,year)\n",
    "#     else:\n",
    "#         print(f\"File not found: {file_path}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "579f4f70-978f-49ac-965f-273c0de305d4",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import torch\n",
    "from torch import nn, optim\n",
    "from torch.autograd import Variable\n",
    "import torch.nn.functional as F\n",
    "import time\n",
    "import math\n",
    "import matplotlib.pyplot as plt\n",
    "from torch.utils.data import Dataset, DataLoader\n",
    "from sklearn.preprocessing import MinMaxScaler\n",
    "from sklearn.preprocessing import StandardScaler\n",
    "import csv\n",
    "import pandas as pd\n",
    "import os\n",
    "import joblib\n",
    "from torch.optim.lr_scheduler import CosineAnnealingLR\n",
    "\n",
    "# Directory that holds the per-year input files (OR_<year>.csv).\n",
    "base_dir = '/root/data'\n",
    "# NOTE(review): root_dir is not referenced anywhere else in this cell.\n",
    "root_dir = '/root'\n",
    "# Directory into which the fitted feature scaler of each year is dumped.\n",
    "scaler_dir = '/root/scaler_OR'\n",
    "\n",
    "# Inclusive range of years for which a model is trained.\n",
    "start_year = 2005\n",
    "end_year = 2024\n",
    "def train_lstm(file_path,year):\n",
    "    \"\"\"Train a quantile LSTM on one CSV file and save the best checkpoint.\n",
    "\n",
    "    Rows of year 2022 plus 2023 up to day 134 form the training frame; rows of\n",
    "    2023 from day 121 onwards form the validation frame. Each frame is cut into\n",
    "    sliding windows per ``train_group`` / ``valid_group``. A MinMaxScaler is\n",
    "    fitted on the training features and dumped into ``scaler_dir``. The model\n",
    "    is trained with a CRPS loss and its weights are written to ``OR_{year}.pt``\n",
    "    (relative to the working directory) whenever the validation loss improves.\n",
    "\n",
    "    Args:\n",
    "        file_path: Path of the CSV file to read.\n",
    "        year: Label used in the scaler and checkpoint file names.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If a split has no group or does not yield a single window.\n",
    "    \"\"\"\n",
    "    data = pd.read_csv(file_path)\n",
    "    print(data.columns)\n",
    "\n",
    "    device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "    print(\"Using device:\", device)\n",
    "\n",
    "    # NOTE(review): the split years are hard-coded and independent of `year`;\n",
    "    # confirm that every input file really carries 2022/2023 rows.\n",
    "    df_train = data[(data[\"year\"] == 2022) | ((data[\"year\"] == 2023) & (data[\"dayofyear\"] <= 134))]\n",
    "    df_valid = data[(data[\"year\"] == 2023) & (data[\"dayofyear\"] >= 121)]\n",
    "\n",
    "    cont_cols = ['consumption','solar','DNI','DHI','temperature','relativehumidity']\n",
    "    time_cols = ['dayofweek','timeofday', 'month']\n",
    "\n",
    "    def read_data_cont(df):\n",
    "        \"\"\"Return the continuous feature columns of ``df`` as a 2-D array.\"\"\"\n",
    "        return df[cont_cols].values\n",
    "\n",
    "    def read_data_target(df):\n",
    "        \"\"\"Return the ``total_grid`` column of ``df`` as an (n, 1) array.\"\"\"\n",
    "        return df[['total_grid']].values\n",
    "\n",
    "    def read_data_time(df):\n",
    "        \"\"\"Return the calendar columns of ``df`` as a 2-D array.\"\"\"\n",
    "        return df[time_cols].values\n",
    "\n",
    "    gap = 24          # rows skipped between the history and the predicted span\n",
    "    look_back = 336   # rows of history per window\n",
    "    pred_len = 24     # rows predicted per window\n",
    "    total_len = gap+look_back+pred_len\n",
    "    n_quantiles = 11                        # one model output per default CRPS quantile level\n",
    "    n_cont = len(cont_cols)                 # 6 history features per row\n",
    "    n_future = n_cont+2*len(time_cols)-1    # 11 future covariates: all context columns except column 0\n",
    "    hist_len = look_back*n_cont             # 2016 flattened history values\n",
    "    future_len = pred_len*n_future          # 264 flattened future values\n",
    "    scaler = MinMaxScaler(feature_range=(0, 1))\n",
    "\n",
    "    def cycl_(x, period):\n",
    "        \"\"\"Encode a periodic quantity as a (2, n) float32 array of sin and cos.\"\"\"\n",
    "        angle = x / period * 2 * np.pi\n",
    "        # Pure numpy: the result is only ever concatenated with numpy arrays, and\n",
    "        # building a torch tensor from a tuple of ndarrays is slow and warns.\n",
    "        return np.stack((np.sin(angle), np.cos(angle))).astype(np.float32)\n",
    "\n",
    "    def Create_dataset(df):\n",
    "        \"\"\"Cut one group into flattened sliding windows.\n",
    "\n",
    "        Returns:\n",
    "            data_X of shape (n, hist_len+future_len, 1): the scaled history\n",
    "            features followed by the future covariates, and data_Y of shape\n",
    "            (n, pred_len, 1). n is 0 when the group has fewer than total_len rows.\n",
    "        \"\"\"\n",
    "        data_cont = scaler.transform(read_data_cont(df))\n",
    "        data_target = read_data_target(df)\n",
    "        data_time = read_data_time(df)\n",
    "\n",
    "        W = cycl_(data_time[:,0],7)      # day of week\n",
    "        H = cycl_(data_time[:,1],24)     # time slot of the day\n",
    "        M = cycl_(data_time[:,2]-1,12)   # month of year\n",
    "        data_time = np.concatenate((W,H,M),0).T\n",
    "\n",
    "        data_context = np.concatenate((data_cont,data_time),1)\n",
    "\n",
    "        # +1 keeps the window that ends on the last row; max() turns a group that\n",
    "        # is too short into zero windows instead of a concatenate error.\n",
    "        n_windows = max(len(data_context)-total_len+1, 0)\n",
    "        data_X = np.empty((n_windows, hist_len+future_len, 1), dtype=data_context.dtype)\n",
    "        data_Y = np.empty((n_windows, pred_len, 1), dtype=data_target.dtype)\n",
    "        for i in range(n_windows):\n",
    "            future_rows = slice(i+look_back+gap, i+total_len)\n",
    "            data_X[i,:hist_len,0] = data_context[i:i+look_back,:n_cont].reshape(-1)\n",
    "            data_X[i,hist_len:,0] = data_context[future_rows,1:].reshape(-1)\n",
    "            data_Y[i,:,0] = data_target[future_rows,0]\n",
    "\n",
    "        return data_X, data_Y\n",
    "\n",
    "    def build_split(groups):\n",
    "        \"\"\"Window every group and stack the results along the sample axis.\"\"\"\n",
    "        parts = [Create_dataset(group) for group in groups]\n",
    "        stacked_X = np.concatenate([part[0] for part in parts], axis=0)\n",
    "        stacked_Y = np.concatenate([part[1] for part in parts], axis=0)\n",
    "        return stacked_X, stacked_Y\n",
    "\n",
    "    # groupby iterates the groups in sorted key order (rows with a missing key are dropped).\n",
    "    train_groups = [group for _, group in df_train.groupby(\"train_group\")]\n",
    "    valid_groups = [group for _, group in df_valid.groupby(\"valid_group\")]\n",
    "    if not train_groups or not valid_groups:\n",
    "        raise ValueError(f'No training or validation group found in {file_path}')\n",
    "\n",
    "    scaler.fit(np.concatenate([read_data_cont(group) for group in train_groups], axis=0))\n",
    "    os.makedirs(scaler_dir, exist_ok=True)   # joblib.dump does not create the directory\n",
    "    joblib.dump(scaler, os.path.join(scaler_dir, 'OR_'+ str(year) + '_scaler.gz'))\n",
    "\n",
    "    t1 = time.time()\n",
    "    input_X, input_Y = build_split(train_groups)\n",
    "    t2 = time.time()\n",
    "    print(t2-t1)\n",
    "    print(input_X.shape,input_Y.shape)\n",
    "\n",
    "    t1 = time.time()\n",
    "    input_X_valid, input_Y_valid = build_split(valid_groups)\n",
    "    t2 = time.time()\n",
    "    print(t2-t1)\n",
    "    print(input_X_valid.shape,input_Y_valid.shape)\n",
    "\n",
    "    if len(input_X) == 0 or len(input_X_valid) == 0:\n",
    "        raise ValueError(f'Every group in {file_path} is shorter than one window ({total_len} rows)')\n",
    "\n",
    "    class Train(Dataset):\n",
    "        \"\"\"Dataset over rows laid out as [history | future covariates | target].\"\"\"\n",
    "\n",
    "        def __init__(self, data):\n",
    "            self.hist = data[:,:hist_len,:].float()\n",
    "            self.future = data[:,hist_len:hist_len+future_len,:].float()\n",
    "            self.label = data[:,-pred_len:,:].float()\n",
    "\n",
    "        def __getitem__(self, index):\n",
    "            return self.hist[index], self.future[index], self.label[index]\n",
    "\n",
    "        def __len__(self):\n",
    "            return len(self.hist)\n",
    "\n",
    "    train_loader = DataLoader(Train(torch.cat((torch.tensor(input_X),torch.tensor(input_Y)),1)), batch_size=16,shuffle=True)\n",
    "    # The whole validation set is evaluated as a single batch.\n",
    "    valid_loader = DataLoader(Train(torch.cat((torch.tensor(input_X_valid),torch.tensor(input_Y_valid)),1)), batch_size=len(input_X_valid),shuffle=False)\n",
    "\n",
    "    class LSTM(nn.Module):\n",
    "        \"\"\"LSTM over the history plus a small MLP over the future covariates.\"\"\"\n",
    "\n",
    "        def __init__(self):\n",
    "            super(LSTM, self).__init__()\n",
    "\n",
    "            self.lstm = nn.LSTM(\n",
    "                input_size=n_cont,\n",
    "                hidden_size=60,\n",
    "                num_layers=1,\n",
    "                batch_first=True,\n",
    "                # dropout=0.3\n",
    "            )\n",
    "\n",
    "            self.ann2 = nn.Sequential(\n",
    "                nn.Linear(n_future,n_future),\n",
    "                nn.ReLU(),\n",
    "                nn.Linear(n_future,1),\n",
    "                # nn.Dropout(0.2)\n",
    "            )\n",
    "            self.out = nn.Sequential(\n",
    "                nn.Linear(60,n_quantiles),\n",
    "            )\n",
    "\n",
    "        def forward(self, x, z):\n",
    "            \"\"\"Map flattened history ``x`` and future covariates ``z`` to quantiles.\n",
    "\n",
    "            Returns a tensor of shape (batch, pred_len, 1, n_quantiles) that is\n",
    "            non-decreasing along the last axis.\n",
    "            \"\"\"\n",
    "            batch = x.shape[0]\n",
    "            hist = torch.reshape(x,(batch,look_back,-1))\n",
    "            hidden, _ = self.lstm(hist)\n",
    "            future = torch.reshape(z,(batch,pred_len,-1))\n",
    "            shift = self.ann2(future)   # (batch, pred_len, 1), broadcast over the hidden units\n",
    "            out = self.out(hidden[:,-pred_len:,:]+shift)\n",
    "\n",
    "            out = torch.unsqueeze(out,2)\n",
    "            # The first output is the lowest quantile; the others are softplus\n",
    "            # increments, so the cumulative sum cannot decrease.\n",
    "            lowest = out[:,:,:,:1]\n",
    "            increments = F.softplus(out[:,:,:,1:])\n",
    "            return torch.cumsum(torch.cat((lowest,increments),dim=-1), dim=-1)\n",
    "\n",
    "\n",
    "    class CRPSLoss(nn.Module):\n",
    "        \"\"\"CRPS of a piecewise-linear quantile function against point targets.\"\"\"\n",
    "\n",
    "        def __init__(self, quantiles=(0,0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9,1), adjusted=True):\n",
    "            super().__init__()\n",
    "            self.adjusted = adjusted\n",
    "            if self.adjusted:\n",
    "                self.quantiles = torch.tensor([0]+list(quantiles)+[1])\n",
    "            else:\n",
    "                self.quantiles = torch.tensor(list(quantiles))\n",
    "\n",
    "        def forward(self, preds, target): #preds:[B,N,T,Q] target:[B,N,T]\n",
    "            assert not target.requires_grad\n",
    "            assert preds.size(0) == target.size(0)\n",
    "\n",
    "            # Pad the predictions with a fixed lower and upper bound so that targets\n",
    "            # outside the predicted quantile range still fall inside an interval.\n",
    "            if self.adjusted:\n",
    "                B = preds.shape[0]\n",
    "                N = preds.shape[1]\n",
    "                T = preds.shape[2]\n",
    "                max_bound = 100\n",
    "                min_bound = -100\n",
    "                preds = torch.cat((min_bound*torch.ones((B, N, T, 1), device=preds.device), preds), dim=-1)\n",
    "                preds = torch.cat((preds, max_bound*torch.ones((B, N, T, 1), device=preds.device)), dim=-1)\n",
    "\n",
    "            q_i1 = self.quantiles[1:].to(preds.device)\n",
    "            q_i = self.quantiles[:-1].to(preds.device)\n",
    "            X_i1 = preds[:,:,:,1:]\n",
    "            X_i = preds[:,:,:,:-1]\n",
    "            X_t = target.unsqueeze(3).repeat(1, 1, 1, X_i.shape[-1])\n",
    "\n",
    "            # this index indicates which case this interval is in. 0 -> term0, 1 -> term1, 2 -> term2\n",
    "            index = torch.full_like(X_i1, 2)\n",
    "            index[X_t > X_i1] = 0\n",
    "            index[X_t < X_i] = 1\n",
    "            index = F.one_hot(index.to(torch.int64), num_classes=3) #ntqd\n",
    "\n",
    "            # term0: F(x_i+1)>O(x_i+1)\n",
    "            # term1: F(x_i)<O(x_i) \n",
    "            # term2: F(x_i)>O(x_i) and F(x_i+1)<O(x_i+1)\n",
    "            # All the terms are calculated with precise integration function\n",
    "            term0 = 1/3*torch.einsum('bntq,q->bntq', X_i1-X_i, (q_i1**2+q_i1*q_i+q_i**2))\n",
    "            term1 = 1/3*torch.einsum('bntq,q->bntq', X_i1-X_i, ((q_i1-1)**2+(q_i1-1)*(q_i-1)+(q_i-1)**2))\n",
    "            # Guard the division: when two neighbouring quantiles collapse to the same\n",
    "            # value the quotient becomes inf/NaN, and the masked sum below would spread\n",
    "            # it into the loss through 0*inf even for intervals that do not use term2.\n",
    "            width = torch.clamp(X_i1-X_i, min=1e-8)\n",
    "            #term2 = torch.einsum('bntq,q->bntq', X_i1-X_t, 1-2*q_i) + torch.einsum('bntq,q->bntq', ((X_i1-X_i)**2-(X_t-X_i)**2)/(X_i1-X_i), -(q_i1-q_i)) + term0\n",
    "            # this is the same as the line below\n",
    "            term2 = torch.einsum('bntq,q->bntq', X_t-X_i, 2*q_i-1) + torch.einsum('bntq,q->bntq', (X_t-X_i)**2/width, q_i1-q_i) + term1\n",
    "            terms = torch.stack((term0,term1,term2),dim=-1)\n",
    "\n",
    "            loss = torch.einsum('bntqd,bntqd->bntq', index.to(torch.float), terms.to(torch.float))\n",
    "            return torch.mean(torch.sum(loss,dim=-1))\n",
    "\n",
    "    # LR = 0.00045\n",
    "    LR = 0.004\n",
    "    EPOCH = 20\n",
    "    loss_func = CRPSLoss()\n",
    "\n",
    "\n",
    "    lstm = LSTM().to(device)\n",
    "    optimizer = torch.optim.Adam(lstm.parameters(), lr=LR) \n",
    "    # NOTE(review): the scheduler is built but never stepped, so the learning\n",
    "    # rate stays at LR for the whole run -- confirm that this is intended.\n",
    "    scheduler = CosineAnnealingLR(optimizer, T_max=EPOCH)\n",
    "\n",
    "    best_loss = float('inf')\n",
    "\n",
    "    for epoch in range(EPOCH):\n",
    "        lstm.train()\n",
    "        for x, z, y in train_loader:\n",
    "            x, z, y = x.to(device), z.to(device), y.to(device)\n",
    "\n",
    "            optimizer.zero_grad()\n",
    "            load_out = lstm(x,z)\n",
    "            loss = loss_func(load_out,y)\n",
    "            loss.backward()  \n",
    "            optimizer.step()\n",
    "            # scheduler.step()\n",
    "\n",
    "        lstm.eval()\n",
    "        with torch.no_grad():\n",
    "            for tx, tz, ty in valid_loader:\n",
    "                tx, tz, ty = tx.to(device), tz.to(device), ty.to(device)\n",
    "                load_valout = lstm(tx,tz)\n",
    "                val_loss = loss_func(load_valout,ty)\n",
    "            # `loss` is the loss of the last training batch of this epoch.\n",
    "            print(f'Epoch {epoch+1}, Loss: {loss.item():.4f}, Val Loss: {val_loss.item():.4f}')\n",
    "\n",
    "        # Keep only the checkpoint with the lowest validation loss seen so far.\n",
    "        if val_loss.item() < best_loss:\n",
    "            best_loss = val_loss.item()\n",
    "            torch.save(lstm.state_dict(), f'OR_{year}.pt')\n",
    "            print(f'New model saved at epoch {epoch+1} with val_loss {best_loss:.4f}')\n",
    "\n",
    "    print(f\"Trained lstm model from {file_path}\")\n",
    "    \n",
    "# Train one model per year; a year without an input CSV is reported and skipped.\n",
    "for year in range(start_year, end_year + 1):\n",
    "    filename = f\"OR_{year}.csv\"\n",
    "    file_path = os.path.join(base_dir, filename)\n",
    "\n",
    "    if not os.path.exists(file_path):\n",
    "        print(f\"File not found: {file_path}\")\n",
    "        continue\n",
    "\n",
    "    train_lstm(file_path,year)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "75bfff35-07c2-4316-b981-d689b234eef4",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "faf6ce8c-0215-489e-b390-03c1ad8bb4fe",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5c81aa1e-3e16-4687-93bb-9fed58fd841d",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9edd5f22-9542-42be-8edc-13b6c98edc92",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "220444c5-0745-4ce8-82f5-b608606696f8",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
