{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch.nn import Linear\n",
    "from torch_geometric.nn import GCNConv,GATConv,SAGEConv\n",
    "from torch_geometric.data import Data,Dataset\n",
    "from torch_geometric.loader  import DataLoader\n",
    "from torch_geometric.nn import MessagePassing\n",
    "from torch_geometric.utils import add_self_loops, degree\n",
    "import numpy as np\n",
    "from scipy.sparse import coo_matrix\n",
    "import pandas as pd\n",
    "import tqdm\n",
    "import random\n",
    "# from  torch.utils.data import TensorDataset,DataLoader"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.manual_seed(12345)\n",
    "torch.cuda.manual_seed(12345)\n",
    "np.random.seed(12345)\n",
    "random.seed(12345)\n",
    "torch.backends.cudnn.deterministic = True"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 参数"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "BATCH_SIZE = 64\n",
    "time_len = 12\n",
    "learning_rate = 1e-2\n",
    "EPOCH = 5\n",
    "\n",
    "batch_first=True\n",
    "bidirectional=False \n",
    "input_size= 1 #经纬度\n",
    "hidden_size = 64\n",
    "num_layers=2"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 工具包"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "#归一化\n",
    "class StandardScaler():\n",
    "    \"\"\"\n",
    "    Standard the input\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, mean, std):\n",
    "        self.mean = mean\n",
    "        self.std = std\n",
    "\n",
    "    def transform(self, data):\n",
    "        return (data - self.mean) / self.std\n",
    "\n",
    "    def inverse_transform(self, data):\n",
    "        return (data * self.std) + self.mean"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "#评价指标\n",
    "def masked_mse(preds, labels, null_val=np.nan):\n",
    "    if np.isnan(null_val):\n",
    "        mask = ~torch.isnan(labels)\n",
    "    else:\n",
    "        mask = (labels!=null_val)\n",
    "    mask = mask.float()\n",
    "    mask /= torch.mean((mask))\n",
    "    mask = torch.where(torch.isnan(mask), torch.zeros_like(mask), mask)\n",
    "    loss = (preds-labels)**2\n",
    "    loss = loss * mask\n",
    "    loss = torch.where(torch.isnan(loss), torch.zeros_like(loss), loss)\n",
    "    return torch.mean(loss)\n",
    "\n",
    "def masked_rmse(preds, labels, null_val=np.nan):\n",
    "    return torch.sqrt(masked_mse(preds=preds, labels=labels, null_val=null_val))\n",
    "\n",
    "\n",
    "def masked_mae(preds, labels, null_val=np.nan):\n",
    "    if np.isnan(null_val):\n",
    "        mask = ~torch.isnan(labels)\n",
    "    else:\n",
    "        mask = (labels!=null_val)\n",
    "    mask = mask.float()\n",
    "    mask /=  torch.mean((mask))\n",
    "    mask = torch.where(torch.isnan(mask), torch.zeros_like(mask), mask)\n",
    "    loss = torch.abs(preds-labels)\n",
    "    loss = loss * mask\n",
    "    loss = torch.where(torch.isnan(loss), torch.zeros_like(loss), loss)\n",
    "    return torch.mean(loss)\n",
    "\n",
    "\n",
    "def masked_mape(preds, labels, null_val=np.nan):\n",
    "    if np.isnan(null_val):\n",
    "        mask = ~torch.isnan(labels)\n",
    "    else:\n",
    "        mask = (labels!=null_val)\n",
    "    \n",
    "    mask = mask.float()\n",
    "    mask /=  torch.mean((mask))\n",
    "    mask = torch.where(torch.isnan(mask), torch.zeros_like(mask), mask)\n",
    "    loss = torch.abs(preds-labels)/labels\n",
    "    loss = loss * mask\n",
    "    loss = torch.where(torch.isnan(loss), torch.zeros_like(loss), loss)\n",
    "    return torch.mean(loss)\n",
    "\n",
    "\n",
    "def metric(pred, real):\n",
    "    mae = masked_mae(pred,real,0.0).item()\n",
    "    mape = masked_mape(pred,real,0.0).item()\n",
    "    rmse = masked_rmse(pred,real,0.0).item()\n",
    "    return mae,mape,rmse"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# test(test_dataloader)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 流量数据"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<class 'numpy.ndarray'>\n",
      "(207, 34272)\n",
      "0\n"
     ]
    }
   ],
   "source": [
    "# Convention for all inputs: spatial axis first, time axis last.\n",
    "traffic = np.load('/home/hqh/DataSetFile/metrla/STdata.npy').T\n",
    "# Alternative source left by the author:\n",
    "#traffic = np.load('/home/hqh/GWN/data/metr-la.h5').transpose()\n",
    "# Report the container type, the array shape and the number of NaN readings.\n",
    "for summary in (type(traffic), traffic.shape, np.isnan(traffic).sum()):\n",
    "    print(summary)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([34259, 207, 12]) torch.Size([34259, 207])\n"
     ]
    }
   ],
   "source": [
    "# Slide a window along the time axis: sample i pairs the readings at steps\n",
    "# [i, i + time_len) with the reading at step i + time_len as its target.\n",
    "# NOTE(review): the loop starts at i = 1, so the window beginning at step 0 is never used;\n",
    "# presumably this keeps x aligned with the precomputed ND_t / M tensors - confirm before changing.\n",
    "x = []\n",
    "y = []\n",
    "for i in range(1, traffic.shape[1] - time_len):\n",
    "    x.append(traffic[:, i:i + time_len])\n",
    "    y.append(traffic[:, i + time_len])\n",
    "# Stack into a single ndarray first: building a tensor directly from a Python list of\n",
    "# ndarrays takes a very slow conversion path.\n",
    "x = torch.from_numpy(np.stack(x)).float()\n",
    "y = torch.from_numpy(np.stack(y)).float()\n",
    "print(x.shape,y.shape)  # batch, space, time"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### ND_t"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Precomputed ND_t tensor; the next cell appends it to x along the last axis.\n",
    "# NOTE: torch.load unpickles the file, so only load files from a trusted source.\n",
    "nd_t_file = '/home/hqh/DataSetFile/metrla/ND_t.pt'\n",
    "ND_t = torch.load(nd_t_file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([34259, 207, 15]) torch.Size([34259, 207])\n"
     ]
    }
   ],
   "source": [
    "# Append the ND_t features after the time_len traffic readings of every node.\n",
    "# Slicing x back to its first time_len columns keeps this cell idempotent:\n",
    "# running it a second time does not stack ND_t onto x again.\n",
    "x = torch.cat([x[..., :time_len], ND_t], dim=-1)\n",
    "print(x.shape,y.shape)  # batch, space, time (+ ND_t columns)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# dataset = TensorDataset(x,y)\n",
    "# dataloader = DataLoader(dataset=data_set,batch_size=BATCH_SIZE,shuffle=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "# features, targets = next(iter(dataloader))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 邻接矩阵"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<class 'numpy.ndarray'>\n"
     ]
    }
   ],
   "source": [
    "# Adjacency matrix of the graph, stored on disk as a NumPy array.\n",
    "adjacency_file = '/home/hqh/DataSetFile/metrla/adj01.npy'\n",
    "A = np.load(adjacency_file)\n",
    "print(type(A))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[1. 0. 0. ... 0. 0. 0.]\n",
      " [0. 1. 1. ... 0. 0. 0.]\n",
      " [0. 1. 1. ... 0. 0. 0.]\n",
      " ...\n",
      " [0. 0. 0. ... 1. 0. 0.]\n",
      " [0. 0. 0. ... 0. 1. 0.]\n",
      " [0. 0. 0. ... 0. 0. 1.]]\n",
      "1722 / 42849\n"
     ]
    }
   ],
   "source": [
    "print(A)\n",
    "# Count of positive entries versus the total number of matrix cells.\n",
    "print(np.count_nonzero(A > 0), '/', len(A) ** 2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[  0,   0,   0,  ..., 206, 206, 206],\n",
      "        [  0,  13,  37,  ..., 187, 198, 206]])\n"
     ]
    }
   ],
   "source": [
    "# Turn the dense adjacency matrix into the 2 x num_edges index tensor used as edge_index.\n",
    "adj_coo = coo_matrix(A)\n",
    "values = adj_coo.data  # stored (non-zero) entries of the sparse matrix\n",
    "# Row 0 holds the source node of every edge, row 1 the target node: the COO layout we need.\n",
    "indices = np.stack([adj_coo.row, adj_coo.col], axis=0)\n",
    "adj = torch.LongTensor(indices)  # COO layout expected by PyG\n",
    "print(adj)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### M 正向 负向"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Precomputed edge tensors for the positive and the negative direction.\n",
    "# NOTE: torch.load unpickles the files, so only load files from a trusted source.\n",
    "M_pos, M_neg = (\n",
    "    torch.load(m_file)\n",
    "    for m_file in (\n",
    "        '/home/hqh/DataSetFile/metrla/M_pos.pt',\n",
    "        '/home/hqh/DataSetFile/metrla/M_neg.pt',\n",
    "    )\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([34259, 1722, 2])\n"
     ]
    }
   ],
   "source": [
    "# Put the two tensors side by side on a new trailing axis so that each sample can be stored\n",
    "# in a PyG Data object as a 2-column edge_attr (column 0 = positive, column 1 = negative).\n",
    "M = torch.stack((M_pos, M_neg), dim=-1)\n",
    "print(M.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 构图dataloader:train/valid/test 分别是7 1 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "def construct_dataloader(x,y,M,loader_type):\n",
    "    \"\"\"Wrap every sample in a PyG `Data` graph and return a `DataLoader` over them.\n",
    "\n",
    "    Args:\n",
    "        x: per-sample node features; x[i] becomes `Data.x`.\n",
    "        y: per-sample targets; y[i] becomes `Data.y`.\n",
    "        M: per-sample edge attributes; M[i] becomes `Data.edge_attr`.\n",
    "        loader_type: 'train' enables shuffling; any other value keeps the sample order.\n",
    "\n",
    "    All graphs share the module-level `adj` as edge_index and are batched\n",
    "    BATCH_SIZE at a time.\n",
    "    \"\"\"\n",
    "    graphs = [\n",
    "        Data(x=x[i], edge_index=adj, y=y[i], edge_attr=M[i])\n",
    "        for i in tqdm.tqdm(range(len(x)))\n",
    "    ]\n",
    "    # The loader joins the graphs of a batch block-diagonally into one large graph.\n",
    "    return DataLoader(graphs, batch_size=BATCH_SIZE, shuffle=(loader_type == 'train'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████| 23981/23981 [00:00<00:00, 40017.83it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████| 3426/3426 [00:00<00:00, 66474.93it/s]\n",
      "100%|████████████████████████████████████████████████████████████████████████████| 6852/6852 [00:00<00:00, 29390.82it/s]\n"
     ]
    }
   ],
   "source": [
    "num_samples = len(x)\n",
    "# Chronological 7 : 1 : 2 split into train / validation / test.\n",
    "# The two ratios below drive the split sizes; validation takes whatever is left.\n",
    "train_ratio = 0.7\n",
    "test_ratio = 0.2\n",
    "\n",
    "num_test = round(num_samples * test_ratio)\n",
    "num_train = round(num_samples * train_ratio)\n",
    "num_val = num_samples - num_test - num_train\n",
    "x_train, y_train,M_train = x[:num_train], y[:num_train],M[:num_train]\n",
    "x_val, y_val,M_val =  x[num_train: num_train + num_val],y[num_train: num_train + num_val],M[num_train: num_train + num_val]\n",
    "x_test, y_test,M_test = x[-num_test:], y[-num_test:],M[-num_test:]\n",
    "\n",
    "\n",
    "# Statistics come from the training inputs only, so nothing from validation/test leaks in.\n",
    "# NOTE(review): mean/std are taken over every column of x_train, i.e. the traffic readings\n",
    "# and the appended ND_t columns together, and the same scaler later un-normalises the\n",
    "# predictions - confirm that mixing the two feature groups is intended.\n",
    "scaler = StandardScaler(mean=x_train.mean(), std=x_train.std())\n",
    "x_train = scaler.transform(x_train)\n",
    "x_val = scaler.transform(x_val)\n",
    "x_test = scaler.transform(x_test)\n",
    "\n",
    "train_dataloader = construct_dataloader(x_train,y_train,M_train,'train')\n",
    "val_dataloader = construct_dataloader(x_val,y_val,M_val,'valid')\n",
    "test_dataloader = construct_dataloader(x_test,y_test,M_test,'test')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "23981\n",
      "3426\n",
      "6852\n"
     ]
    }
   ],
   "source": [
    "# Number of samples in the train, validation and test splits.\n",
    "for loader in (train_dataloader, val_dataloader, test_dataloader):\n",
    "    print(len(loader.dataset))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 模型&优化器"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "# class GCN(torch.nn.Module):\n",
    "#     def __init__(self):\n",
    "#         super(GCN, self).__init__()\n",
    "#         self.lstm1 = torch.nn.LSTM(num_layers = 1,batch_first=batch_first,\n",
    "#                                bidirectional=bidirectional,input_size=input_size,hidden_size = hidden_size)\n",
    "#         self.conv1 = GCNConv(hidden_size, 64)\n",
    "#         self.lstm2 = torch.nn.LSTM(num_layers = 1,batch_first=batch_first,\n",
    "#                                bidirectional=bidirectional,input_size=input_size,hidden_size = hidden_size)\n",
    "#         self.mlp = Linear(hidden_size,1)\n",
    "\n",
    "#     def forward(self, batch):\n",
    "#         x, edge_index=batch.x,batch.edge_index\n",
    "#         out,(hn,cn) = self.lstm1(x.unsqueeze(-1))\n",
    "#         out = out.relu()\n",
    "#         out = self.conv1(out[:,-1,:], edge_index)\n",
    "#         out = out.relu()\n",
    "#         out,(hn,cn) = self.lstm2(x.unsqueeze(-1))\n",
    "#         out = out[:,-1,:].relu()\n",
    "#         out = self.mlp(out)\n",
    "#         return out\n",
    "    \n",
    "# class GAT(torch.nn.Module):\n",
    "#     def __init__(self):\n",
    "#         super(GAT, self).__init__()\n",
    "#         self.lstm1 = torch.nn.LSTM(num_layers = 1,batch_first=batch_first,\n",
    "#                                bidirectional=bidirectional,input_size=input_size,hidden_size = hidden_size)\n",
    "#         self.conv1 = GATConv(hidden_size, 64)\n",
    "#         self.lstm2 = torch.nn.LSTM(num_layers = 1,batch_first=batch_first,\n",
    "#                                bidirectional=bidirectional,input_size=input_size,hidden_size = hidden_size)\n",
    "#         self.mlp = Linear(hidden_size,1)\n",
    "\n",
    "#     def forward(self, batch):\n",
    "#         x, edge_index=batch.x,batch.edge_index\n",
    "#         out,(hn,cn) = self.lstm1(x.unsqueeze(-1))\n",
    "#         out = out.relu()\n",
    "#         out = self.conv1(out[:,-1,:], edge_index)\n",
    "#         out = out.relu()\n",
    "#         out,(hn,cn) = self.lstm2(x.unsqueeze(-1))\n",
    "#         out = out[:,-1,:].relu()\n",
    "#         out = self.mlp(out)\n",
    "#         return out\n",
    "    \n",
    "# class GraphSAGE(torch.nn.Module):\n",
    "#     def __init__(self):\n",
    "#         super(GraphSAGE, self).__init__()\n",
    "#         self.lstm1 = torch.nn.LSTM(num_layers = 1,batch_first=batch_first,\n",
    "#                                bidirectional=bidirectional,input_size=input_size,hidden_size = hidden_size)\n",
    "#         self.conv1 = SAGEConv(hidden_size, 64)\n",
    "#         self.lstm2 = torch.nn.LSTM(num_layers = 1,batch_first=batch_first,\n",
    "#                                bidirectional=bidirectional,input_size=input_size,hidden_size = hidden_size)\n",
    "#         self.mlp = Linear(hidden_size,1)\n",
    "\n",
    "#     def forward(self, batch):\n",
    "#         x, edge_index=batch.x,batch.edge_index\n",
    "#         out,(hn,cn) = self.lstm1(x.unsqueeze(-1))\n",
    "#         out = out.relu()\n",
    "#         out = self.conv1(out[:,-1,:], edge_index)\n",
    "#         out = out.relu()\n",
    "#         out,(hn,cn) = self.lstm2(x.unsqueeze(-1))\n",
    "#         out = out[:,-1,:].relu()\n",
    "#         out = self.mlp(out)\n",
    "#         return out\n",
    "    \n",
    "# class GraphSAGE(torch.nn.Module):\n",
    "#     def __init__(self):\n",
    "#         super(GraphSAGE, self).__init__()\n",
    "#         self.lstm1 = torch.nn.LSTM(num_layers = 1,batch_first=batch_first,\n",
    "#                                bidirectional=bidirectional,input_size=input_size,hidden_size = hidden_size)\n",
    "#         self.conv1 = SAGEConv(hidden_size, 64)\n",
    "#         self.lstm2 = torch.nn.LSTM(num_layers = 1,batch_first=batch_first,\n",
    "#                                bidirectional=bidirectional,input_size=input_size,hidden_size = hidden_size)\n",
    "#         self.mlp = Linear(hidden_size,1)\n",
    "\n",
    "#     def forward(self, batch):\n",
    "#         x, edge_index=batch.x,batch.edge_index\n",
    "#         out,(hn,cn) = self.lstm1(x.unsqueeze(-1))\n",
    "#         out = out.relu()\n",
    "#         out = self.conv1(out[:,-1,:], edge_index)\n",
    "#         out = out.relu()\n",
    "#         out,(hn,cn) = self.lstm2(x.unsqueeze(-1))\n",
    "#         out = out[:,-1,:].relu()\n",
    "#         out = self.mlp(out)\n",
    "#         return out\n",
    "\n",
    "#         outneg = self.convneg (x=x, edge_index=edge_index,edge_weight=(edge_attr[:,0]+1))\n",
    "#         outpos = self.convpos (x=x, edge_index=edge_index,edge_weight=(edge_attr[:,1]+1))\n",
    "\n",
    "#         out = torch.cat((outpos,outneg),dim=-1)    \n",
    "\n",
    "# class WWW2023(torch.nn.Module):\n",
    "#     def __init__(self):\n",
    "#         super(WWW2023, self).__init__()\n",
    "#         self.lstm1 = torch.nn.LSTM(num_layers = 1,batch_first=batch_first,\n",
    "#                                bidirectional=bidirectional,input_size=input_size,hidden_size = hidden_size)\n",
    "#         self.convneg = GCNConv(12, 32)\n",
    "#         self.convpos = GCNConv(12, 32)\n",
    "# #         self.mlp = Linear(hidden_size,1)\n",
    "#         self.mlp = Linear(64,1)\n",
    "\n",
    "#     def forward(self, batch):\n",
    "#         x, edge_index,edge_attr=batch.x,batch.edge_index,batch.edge_attr\n",
    "#         outneg = self.convneg (x=x, edge_index=edge_index,edge_weight=(edge_attr[:,0]+1))\n",
    "#         outpos = self.convpos (x=x, edge_index=edge_index,edge_weight=(edge_attr[:,1]+1))\n",
    "#         out = torch.cat((outpos,outneg),dim=-1)\n",
    "#         out = out.relu()\n",
    "#         out,(hn,cn) = self.lstm1(out.unsqueeze(-1))\n",
    "#         out = out[:,-1,:].relu()\n",
    "\n",
    "#         out = self.mlp(out)\n",
    "#         return out\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "# enable_padding 否 ：减少(k-1)*d 是：不减少\n",
    "class CausalConv1d(torch.nn.Conv1d):\n",
    "    def __init__(self, in_channels, out_channels, kernel_size, stride=1, enable_padding=False, dilation=1, groups=1, bias=True):\n",
    "        if enable_padding == True:\n",
    "            self.__padding = (kernel_size - 1) * dilation\n",
    "        else:\n",
    "            self.__padding = 0\n",
    "        super(CausalConv1d, self).__init__(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=self.__padding, dilation=dilation, groups=groups, bias=bias)\n",
    "\n",
    "    def forward(self, input):\n",
    "        result = super(CausalConv1d, self).forward(input)\n",
    "        if self.__padding != 0:\n",
    "            return result[: , : , : -self.__padding]\n",
    "        return result"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [],
   "source": [
    "class First_Layer(torch.nn.Module):\n",
    "    \"\"\"Input projection: one kernel-3 convolution without padding.\n",
    "\n",
    "    The time axis loses 2 steps (12 -> 10 for the window length used here).\n",
    "    \"\"\"\n",
    "    def __init__(self,in_channels=1,out_channels=64):\n",
    "        super(First_Layer, self).__init__()\n",
    "        self.conv = CausalConv1d(in_channels=in_channels,out_channels=out_channels,kernel_size=3)\n",
    "    def forward(self, x):\n",
    "        # x: (batch * nodes, channels, time)\n",
    "        # NOTE(review): this method used to run `x = x.relu()` after the convolution, which\n",
    "        # rectified the already-consumed input and discarded the result. That dead statement\n",
    "        # is gone; the output is still the raw convolution, exactly as before. If an\n",
    "        # activation on the output was intended, add `.relu()` here and retrain.\n",
    "        return self.conv(x)\n",
    "class Out_Layer(torch.nn.Module):\n",
    "    # 使其剪掉1个维度的time 2->1\n",
    "    def __init__(self,in_channels=64,out_channels=1):\n",
    "        super(Out_Layer, self).__init__()\n",
    "        self.conv = torch.nn.Conv1d(in_channels=in_channels,out_channels=in_channels*2,kernel_size=2,padding=0)\n",
    "        self.mlp=Linear(in_channels*2,out_channels)\n",
    "    def forward(self, x):\n",
    "        # x:bat*n  channels time\n",
    "        out=self.conv(x).squeeze(-1).relu()\n",
    "        out=self.mlp(out)\n",
    "        return out  \n",
    "class Time_Conv(torch.nn.Module):\n",
    "    \"\"\"Gated temporal unit (GTU): tanh(conv1(x)) * sigmoid(conv2(x)).\n",
    "\n",
    "    Both convolutions are unpadded, so with kernel_size=3 the time axis loses 2 steps.\n",
    "    \"\"\"\n",
    "    def __init__(self,in_channels=16 ,out_channels=64,kernel_size=3):\n",
    "        super(Time_Conv,self).__init__()\n",
    "        conv_args = dict(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size)\n",
    "        self.conv1 = CausalConv1d(**conv_args)\n",
    "        self.conv2 = CausalConv1d(**conv_args)\n",
    "    def forward(self,x):\n",
    "        signal = self.conv1(x).tanh()\n",
    "        gate = self.conv2(x).sigmoid()\n",
    "        return signal * gate\n",
    "        \n",
    "class ST_Block(torch.nn.Module):\n",
    "    \"\"\"Spatio-temporal block: gated temporal conv -> graph conv -> gated temporal conv.\n",
    "\n",
    "    Args:\n",
    "        tlen_spconv: time length seen by the graph convolution, i.e. the block's input\n",
    "            time length minus 2 (the first temporal convolution removes 2 steps).\n",
    "        in_channels: channels of the block input.\n",
    "        hidden_channels: channels between the two temporal convolutions.\n",
    "        out_channels: channels of the block output.\n",
    "        gnn_type: 'GCN', 'GAT' or 'WWW2023'.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: for an unknown gnn_type, or when the GAT feature width cannot be\n",
    "            split evenly across the attention heads.\n",
    "    \"\"\"\n",
    "    def __init__(self,tlen_spconv,in_channels=64,hidden_channels=16,out_channels=64,gnn_type='GCN'):\n",
    "        super(ST_Block, self).__init__()\n",
    "        self.gnn_type=gnn_type\n",
    "        # Each temporal convolution (kernel 3, no padding) shortens the time axis by 2.\n",
    "        self.timeconv1 = Time_Conv(in_channels=in_channels,out_channels=hidden_channels,kernel_size=3)\n",
    "        # The graph layers see each node's (channels x time) map flattened into one vector.\n",
    "        flat_dim = hidden_channels*tlen_spconv\n",
    "        if gnn_type=='GCN':\n",
    "            self.spatialconv = GCNConv(in_channels=flat_dim, out_channels=flat_dim)\n",
    "        elif gnn_type=='GAT':\n",
    "            heads = 8\n",
    "            if flat_dim % heads != 0:\n",
    "                raise ValueError('hidden_channels*tlen_spconv must be divisible by %d for GAT' % heads)\n",
    "            # The concatenated head outputs add up to flat_dim again, so forward() can\n",
    "            # reshape the result back to (nodes, channels, time).\n",
    "            self.spatialconv = GATConv(flat_dim, flat_dim//heads, heads=heads)\n",
    "        elif gnn_type=='WWW2023':\n",
    "            self.spatialconv_pos = GCNConv(in_channels=flat_dim, out_channels=flat_dim)\n",
    "            self.spatialconv_neg = GCNConv(in_channels=flat_dim, out_channels=flat_dim)\n",
    "            self.wwwmlp = Linear(hidden_channels*2,hidden_channels)\n",
    "            # Number of chained propagation steps on the positive graph.\n",
    "            self.k = 3\n",
    "            # Maps the 3 ND_t features of a node to one mixing weight per propagation step.\n",
    "            self.psimlp = torch.nn.Sequential(Linear(3,64),torch.nn.ReLU(),Linear(64,self.k))\n",
    "        else:\n",
    "            raise ValueError('unsupported gnn_type: %r' % (gnn_type,))\n",
    "        self.timeconv2 = Time_Conv(in_channels=hidden_channels,out_channels=out_channels,kernel_size=3)\n",
    "\n",
    "    def forward(self, x,edge_index,edge_attr,ND_t):\n",
    "        \"\"\"Run the block.\n",
    "\n",
    "        Args:\n",
    "            x: input laid out as (batch * nodes, channels, time).\n",
    "            edge_index: graph connectivity handed to the graph layers.\n",
    "            edge_attr: per-edge values; column 0 weights the positive graph and column 1\n",
    "                the negative graph (only read when gnn_type is 'WWW2023').\n",
    "            ND_t: per-node features fed to the mixing MLP (only read for 'WWW2023').\n",
    "        \"\"\"\n",
    "        out1 = self.timeconv1(x)\n",
    "        d0,d1,d2 = out1.shape\n",
    "        flat = out1.reshape(d0,-1)\n",
    "        if self.gnn_type=='WWW2023':\n",
    "            # One mixing weight per node and propagation step: shape (nodes, k).\n",
    "            psi = self.psimlp(ND_t)\n",
    "            pos_weight = edge_attr[:,0]+1\n",
    "            neg_weight = edge_attr[:,1]+1\n",
    "            hop_input = flat\n",
    "            out_pos_psi = None\n",
    "            for i in range(self.k):\n",
    "                # Step i propagates the result of step i-1 (step 0 starts from the block features).\n",
    "                out_pos = self.spatialconv_pos(edge_index=edge_index,x=hop_input,edge_weight=pos_weight).reshape(d0,d1,d2).relu()\n",
    "                # Broadcasting the per-node weight over channels and time works for any channel width.\n",
    "                term = psi[:, i, None, None]*out_pos\n",
    "                out_pos_psi = term if out_pos_psi is None else out_pos_psi + term\n",
    "                hop_input = out_pos.reshape(d0,-1)\n",
    "            out_neg = self.spatialconv_neg(edge_index=edge_index,x=flat,edge_weight=neg_weight).reshape(d0,d1,d2).relu()\n",
    "            # Stack both branches along the channel axis, then mix back to hidden_channels.\n",
    "            out2 = torch.cat([out_pos_psi,out_neg],dim=1)\n",
    "            out2 = self.wwwmlp(out2.transpose(1,2)).transpose(1,2).relu()\n",
    "        else:\n",
    "            out2 = self.spatialconv(edge_index=edge_index,x=flat).reshape(d0,d1,d2).relu()\n",
    "        out3 = self.timeconv2(out2)\n",
    "        return out3\n",
    "\n",
    "class WWW2023(torch.nn.Module):\n",
    "    \"\"\"Full network: input convolution, two spatio-temporal blocks, prediction head.\n",
    "\n",
    "    Time length per stage: the first layer removes 2 steps, each ST block removes 4 and\n",
    "    the output layer removes 1. The fixed `tlen_spconv` values (8 and 4) therefore\n",
    "    match an input window of 12 steps.\n",
    "\n",
    "    Args:\n",
    "        in_channels: channels produced by the first layer and consumed by block 1.\n",
    "        hidden_channels: channels inside each ST block.\n",
    "        out_channels: channels produced by each ST block.\n",
    "        gnn_type: graph layer used inside the ST blocks.\n",
    "    \"\"\"\n",
    "    def __init__(self,in_channels=64,hidden_channels=16,out_channels=64,gnn_type='GCN'):\n",
    "        super(WWW2023, self).__init__()\n",
    "        self.firstlayer = First_Layer(in_channels=1,out_channels=in_channels)\n",
    "        self.stblock1 = ST_Block(tlen_spconv=8,in_channels=in_channels,hidden_channels=hidden_channels,\n",
    "                                 out_channels=out_channels,gnn_type=gnn_type)\n",
    "        # Block 2 consumes the output of block 1, so its input width is out_channels\n",
    "        # (identical to in_channels only when the two are configured equal).\n",
    "        self.stblock2 = ST_Block(tlen_spconv=4,in_channels=out_channels,hidden_channels=hidden_channels,\n",
    "                                 out_channels=out_channels,gnn_type=gnn_type)\n",
    "        self.outlayer = Out_Layer(in_channels=out_channels,out_channels=1)\n",
    "\n",
    "    def forward(self, batch):\n",
    "        \"\"\"Predict one value per node of the batched graph.\n",
    "\n",
    "        `batch.x` holds the `time_len` readings first and the ND_t columns after them.\n",
    "        \"\"\"\n",
    "        x,ND_t, edge_index,edge_attr=batch.x[:,:time_len],batch.x[:,time_len:],batch.edge_index,batch.edge_attr\n",
    "        # Add a channel axis: (nodes, time) -> (nodes, 1, time).\n",
    "        out = self.firstlayer(x.unsqueeze(1))\n",
    "        out=self.stblock1(x=out,edge_index=edge_index,edge_attr=edge_attr,ND_t=ND_t)\n",
    "        out=self.stblock2(x=out,edge_index=edge_index,edge_attr=edge_attr,ND_t=ND_t)\n",
    "        out=self.outlayer(out)\n",
    "        return out"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WWW2023(\n",
      "  (firstlayer): First_Layer(\n",
      "    (conv): CausalConv1d(1, 64, kernel_size=(3,), stride=(1,))\n",
      "  )\n",
      "  (stblock1): ST_Block(\n",
      "    (timeconv1): Time_Conv(\n",
      "      (conv1): CausalConv1d(64, 64, kernel_size=(3,), stride=(1,))\n",
      "      (conv2): CausalConv1d(64, 64, kernel_size=(3,), stride=(1,))\n",
      "    )\n",
      "    (spatialconv_pos): GCNConv(512, 512)\n",
      "    (spatialconv_neg): GCNConv(512, 512)\n",
      "    (wwwmlp): Linear(in_features=128, out_features=64, bias=True)\n",
      "    (psimlp): Sequential(\n",
      "      (0): Linear(in_features=3, out_features=64, bias=True)\n",
      "      (1): ReLU()\n",
      "      (2): Linear(in_features=64, out_features=3, bias=True)\n",
      "    )\n",
      "    (timeconv2): Time_Conv(\n",
      "      (conv1): CausalConv1d(64, 64, kernel_size=(3,), stride=(1,))\n",
      "      (conv2): CausalConv1d(64, 64, kernel_size=(3,), stride=(1,))\n",
      "    )\n",
      "  )\n",
      "  (stblock2): ST_Block(\n",
      "    (timeconv1): Time_Conv(\n",
      "      (conv1): CausalConv1d(64, 64, kernel_size=(3,), stride=(1,))\n",
      "      (conv2): CausalConv1d(64, 64, kernel_size=(3,), stride=(1,))\n",
      "    )\n",
      "    (spatialconv_pos): GCNConv(256, 256)\n",
      "    (spatialconv_neg): GCNConv(256, 256)\n",
      "    (wwwmlp): Linear(in_features=128, out_features=64, bias=True)\n",
      "    (psimlp): Sequential(\n",
      "      (0): Linear(in_features=3, out_features=64, bias=True)\n",
      "      (1): ReLU()\n",
      "      (2): Linear(in_features=64, out_features=3, bias=True)\n",
      "    )\n",
      "    (timeconv2): Time_Conv(\n",
      "      (conv1): CausalConv1d(64, 64, kernel_size=(3,), stride=(1,))\n",
      "      (conv2): CausalConv1d(64, 64, kernel_size=(3,), stride=(1,))\n",
      "    )\n",
      "  )\n",
      "  (outlayer): Out_Layer(\n",
      "    (conv): Conv1d(64, 128, kernel_size=(2,), stride=(1,))\n",
      "    (mlp): Linear(in_features=128, out_features=1, bias=True)\n",
      "  )\n",
      ")\n"
     ]
    }
   ],
   "source": [
    "# Build the network with the 'WWW2023' graph layer and 64 channels at every stage.\n",
    "model_config = dict(in_channels=64, hidden_channels=64, out_channels=64, gnn_type='WWW2023')\n",
    "model = WWW2023(**model_config)\n",
    "print(model)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Adam with weight decay.\n",
    "# NOTE(review): lr is fixed to 1e-3 here; the `learning_rate` constant from the config\n",
    "# cell (1e-2) is not used - confirm which value is intended.\n",
    "adam_config = dict(lr=1e-3, betas=(0.9, 0.999), eps=1e-08, weight_decay=1e-3)\n",
    "optimizer = torch.optim.Adam(model.parameters(), **adam_config)\n",
    "# Multiply the learning rate by 0.9 after every 10 scheduler steps.\n",
    "scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.9)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 训练"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Prefer the second GPU (the original setting) but fall back to the first GPU or the CPU,\n",
    "# so the notebook also runs on machines with fewer than two GPUs.\n",
    "if torch.cuda.device_count() > 1:\n",
    "    device = \"cuda:1\"\n",
    "elif torch.cuda.is_available():\n",
    "    device = \"cuda:0\"\n",
    "else:\n",
    "    device = \"cpu\"\n",
    "model = model.to(device)\n",
    "# Training and validation both optimise / report the mean absolute error.\n",
    "loss_fn = torch.nn.L1Loss()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train_one_epoch(dataloader,optimizer,scheduler):\n",
    "    \"\"\"Train the module-level `model` for one pass over `dataloader`.\n",
    "\n",
    "    Predictions are mapped back to the original scale with `scaler` before the loss\n",
    "    is taken. The scheduler is stepped once at the end of the pass.\n",
    "\n",
    "    Returns:\n",
    "        The mean of the per-batch losses.\n",
    "    \"\"\"\n",
    "    model.train()\n",
    "    running_loss, n_batches = 0, 0\n",
    "    for n_batches, batch in enumerate(tqdm.tqdm(dataloader), start=1):\n",
    "        batch = batch.to(device)\n",
    "        predict = scaler.inverse_transform(model(batch))\n",
    "        loss = loss_fn(batch.y, predict.reshape(-1))\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        # Clear the gradients right away so the next batch starts from zero.\n",
    "        optimizer.zero_grad()\n",
    "        running_loss += loss.detach().cpu()\n",
    "    scheduler.step()\n",
    "    return running_loss / n_batches"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [],
   "source": [
    "def valid(dataloader):\n",
    "    \"\"\"Evaluate the module-level `model` on `dataloader` without gradient tracking.\n",
    "\n",
    "    Returns:\n",
    "        The mean of the per-batch losses, computed on predictions mapped back to\n",
    "        the original scale with `scaler`.\n",
    "    \"\"\"\n",
    "    model.eval()\n",
    "    batch_losses = []\n",
    "    with torch.no_grad():\n",
    "        for batch in dataloader:\n",
    "            batch = batch.to(device)\n",
    "            predict = scaler.inverse_transform(model(batch))\n",
    "            loss = loss_fn(batch.y, predict.reshape(-1))\n",
    "            batch_losses.append(loss.detach().cpu())\n",
    "    return sum(batch_losses) / len(batch_losses)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [],
   "source": [
    "def test(dataloader):\n",
    "    \"\"\"Evaluate the module-level `model` on `dataloader` and print MAE, MAPE and RMSE.\n",
    "\n",
    "    Predictions are mapped back to the original scale with `scaler`, gathered over all\n",
    "    batches and scored once with `metric`.\n",
    "\n",
    "    Returns:\n",
    "        The (mae, mape, rmse) tuple produced by `metric`, so callers can log or compare it.\n",
    "    \"\"\"\n",
    "    model.eval()\n",
    "    pred_total = []\n",
    "    label_total = []\n",
    "    with torch.no_grad():\n",
    "        for batch in dataloader:\n",
    "            batch = batch.to(device)\n",
    "            out = model(batch)\n",
    "            predict = scaler.inverse_transform(out).reshape(-1)\n",
    "            pred_total.append(predict)\n",
    "            label_total.append(batch.y)\n",
    "    # Score all samples at once rather than averaging per-batch metrics.\n",
    "    pred_total = torch.cat(pred_total,dim=0)\n",
    "    label_total = torch.cat(label_total,dim=0)\n",
    "    metrics = metric(pred_total,label_total)\n",
    "    log = 'Evaluate best model on test data, Test MAE: {:.4f}, Test MAPE: {:.4f}, Test RMSE: {:.4f}'\n",
    "    print(log.format( metrics[0], metrics[1], metrics[2]))\n",
    "    return metrics"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████| 375/375 [00:52<00:00,  7.15it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1-th epoch | train_loss=6.587403774261475 | valid_loss=4.454043865203857 | best_valid_loss=4.454043865203857\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████| 375/375 [00:52<00:00,  7.17it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2-th epoch | train_loss=4.165935039520264 | valid_loss=3.7061917781829834 | best_valid_loss=3.7061917781829834\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████| 375/375 [00:52<00:00,  7.20it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "3-th epoch | train_loss=3.7590579986572266 | valid_loss=3.3841521739959717 | best_valid_loss=3.3841521739959717\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████| 375/375 [00:52<00:00,  7.19it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "4-th epoch | train_loss=3.5064468383789062 | valid_loss=3.243170738220215 | best_valid_loss=3.243170738220215\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████| 375/375 [00:52<00:00,  7.21it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "5-th epoch | train_loss=3.423447370529175 | valid_loss=3.1559369564056396 | best_valid_loss=3.1559369564056396\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████| 375/375 [00:51<00:00,  7.21it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "6-th epoch | train_loss=3.3015687465667725 | valid_loss=3.17211651802063 | best_valid_loss=3.1559369564056396\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████| 375/375 [00:51<00:00,  7.25it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "7-th epoch | train_loss=3.180893898010254 | valid_loss=3.0025744438171387 | best_valid_loss=3.0025744438171387\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████| 375/375 [00:52<00:00,  7.17it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "8-th epoch | train_loss=3.093505859375 | valid_loss=2.992825746536255 | best_valid_loss=2.992825746536255\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████| 375/375 [00:51<00:00,  7.22it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "9-th epoch | train_loss=3.065075635910034 | valid_loss=2.941331148147583 | best_valid_loss=2.941331148147583\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████| 375/375 [00:52<00:00,  7.20it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "10-th epoch | train_loss=3.0558300018310547 | valid_loss=2.8944852352142334 | best_valid_loss=2.8944852352142334\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████| 375/375 [00:51<00:00,  7.24it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "11-th epoch | train_loss=2.9132025241851807 | valid_loss=2.798243999481201 | best_valid_loss=2.798243999481201\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████| 375/375 [00:51<00:00,  7.23it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "12-th epoch | train_loss=2.9189908504486084 | valid_loss=2.8401050567626953 | best_valid_loss=2.798243999481201\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████| 375/375 [00:51<00:00,  7.23it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "13-th epoch | train_loss=2.9319117069244385 | valid_loss=2.9353597164154053 | best_valid_loss=2.798243999481201\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████| 375/375 [00:51<00:00,  7.24it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "14-th epoch | train_loss=2.8031253814697266 | valid_loss=2.6675703525543213 | best_valid_loss=2.6675703525543213\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████| 375/375 [00:52<00:00,  7.21it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "15-th epoch | train_loss=2.83851957321167 | valid_loss=2.7119412422180176 | best_valid_loss=2.6675703525543213\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████| 375/375 [00:51<00:00,  7.24it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "16-th epoch | train_loss=2.735198974609375 | valid_loss=2.5450968742370605 | best_valid_loss=2.5450968742370605\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████| 375/375 [00:51<00:00,  7.26it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "17-th epoch | train_loss=2.778977632522583 | valid_loss=3.047562599182129 | best_valid_loss=2.5450968742370605\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████| 375/375 [00:51<00:00,  7.24it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "18-th epoch | train_loss=2.9332809448242188 | valid_loss=2.868563652038574 | best_valid_loss=2.5450968742370605\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████| 375/375 [00:52<00:00,  7.20it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "19-th epoch | train_loss=2.910163402557373 | valid_loss=2.73077654838562 | best_valid_loss=2.5450968742370605\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████| 375/375 [00:51<00:00,  7.22it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "20-th epoch | train_loss=2.8255786895751953 | valid_loss=2.957097053527832 | best_valid_loss=2.5450968742370605\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████| 375/375 [00:51<00:00,  7.26it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "21-th epoch | train_loss=2.7406363487243652 | valid_loss=2.5086824893951416 | best_valid_loss=2.5086824893951416\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████| 375/375 [00:51<00:00,  7.26it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "22-th epoch | train_loss=2.6796677112579346 | valid_loss=2.669632911682129 | best_valid_loss=2.5086824893951416\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████| 375/375 [00:51<00:00,  7.24it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "23-th epoch | train_loss=2.690859794616699 | valid_loss=2.4165332317352295 | best_valid_loss=2.4165332317352295\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████| 375/375 [00:51<00:00,  7.24it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "24-th epoch | train_loss=2.7854270935058594 | valid_loss=2.643892526626587 | best_valid_loss=2.4165332317352295\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████| 375/375 [00:51<00:00,  7.23it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "25-th epoch | train_loss=2.65273118019104 | valid_loss=2.499760866165161 | best_valid_loss=2.4165332317352295\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████| 375/375 [00:51<00:00,  7.25it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "26-th epoch | train_loss=2.5873968601226807 | valid_loss=2.5249576568603516 | best_valid_loss=2.4165332317352295\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████| 375/375 [00:51<00:00,  7.25it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "27-th epoch | train_loss=2.5016374588012695 | valid_loss=2.521864891052246 | best_valid_loss=2.4165332317352295\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████| 375/375 [00:51<00:00,  7.23it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "28-th epoch | train_loss=2.526870012283325 | valid_loss=2.4695591926574707 | best_valid_loss=2.4165332317352295\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████| 375/375 [00:51<00:00,  7.26it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "29-th epoch | train_loss=2.9197938442230225 | valid_loss=2.661842107772827 | best_valid_loss=2.4165332317352295\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████| 375/375 [00:51<00:00,  7.22it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "30-th epoch | train_loss=2.555002450942993 | valid_loss=2.486650228500366 | best_valid_loss=2.4165332317352295\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████| 375/375 [00:52<00:00,  7.21it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "31-th epoch | train_loss=2.6769771575927734 | valid_loss=2.4545445442199707 | best_valid_loss=2.4165332317352295\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████| 375/375 [00:51<00:00,  7.23it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "32-th epoch | train_loss=2.628239870071411 | valid_loss=2.371797800064087 | best_valid_loss=2.371797800064087\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████| 375/375 [00:51<00:00,  7.23it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "33-th epoch | train_loss=2.6393637657165527 | valid_loss=2.433464288711548 | best_valid_loss=2.371797800064087\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████| 375/375 [00:51<00:00,  7.23it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "34-th epoch | train_loss=2.4444870948791504 | valid_loss=2.2975409030914307 | best_valid_loss=2.2975409030914307\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████| 375/375 [00:51<00:00,  7.22it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "35-th epoch | train_loss=2.391197919845581 | valid_loss=2.3694217205047607 | best_valid_loss=2.2975409030914307\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████| 375/375 [00:51<00:00,  7.26it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "36-th epoch | train_loss=2.3863255977630615 | valid_loss=2.2401931285858154 | best_valid_loss=2.2401931285858154\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████| 375/375 [00:51<00:00,  7.27it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "37-th epoch | train_loss=2.3519864082336426 | valid_loss=2.181816339492798 | best_valid_loss=2.181816339492798\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████| 375/375 [00:51<00:00,  7.22it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "38-th epoch | train_loss=2.5498695373535156 | valid_loss=4.59777307510376 | best_valid_loss=2.181816339492798\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████| 375/375 [00:51<00:00,  7.27it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "39-th epoch | train_loss=2.870999336242676 | valid_loss=2.3704466819763184 | best_valid_loss=2.181816339492798\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████| 375/375 [00:51<00:00,  7.26it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "40-th epoch | train_loss=2.4043850898742676 | valid_loss=2.247041702270508 | best_valid_loss=2.181816339492798\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████| 375/375 [00:51<00:00,  7.25it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "41-th epoch | train_loss=2.3361377716064453 | valid_loss=2.283573627471924 | best_valid_loss=2.181816339492798\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████| 375/375 [00:51<00:00,  7.29it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "42-th epoch | train_loss=2.4012773036956787 | valid_loss=2.2408063411712646 | best_valid_loss=2.181816339492798\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████| 375/375 [00:51<00:00,  7.29it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "43-th epoch | train_loss=2.2800636291503906 | valid_loss=2.224782705307007 | best_valid_loss=2.181816339492798\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████| 375/375 [00:51<00:00,  7.29it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "44-th epoch | train_loss=2.345763683319092 | valid_loss=2.4968621730804443 | best_valid_loss=2.181816339492798\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████| 375/375 [00:51<00:00,  7.29it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "45-th epoch | train_loss=2.419935703277588 | valid_loss=2.1799070835113525 | best_valid_loss=2.1799070835113525\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████| 375/375 [00:51<00:00,  7.33it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "46-th epoch | train_loss=2.2402846813201904 | valid_loss=2.1698083877563477 | best_valid_loss=2.1698083877563477\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████| 375/375 [00:51<00:00,  7.31it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "47-th epoch | train_loss=2.2978355884552 | valid_loss=2.176867723464966 | best_valid_loss=2.1698083877563477\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████| 375/375 [00:51<00:00,  7.25it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "48-th epoch | train_loss=2.268221616744995 | valid_loss=2.281744956970215 | best_valid_loss=2.1698083877563477\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████| 375/375 [00:51<00:00,  7.26it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "49-th epoch | train_loss=2.2110421657562256 | valid_loss=2.2807223796844482 | best_valid_loss=2.1698083877563477\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████| 375/375 [00:51<00:00,  7.27it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "50-th epoch | train_loss=2.217005729675293 | valid_loss=2.1161532402038574 | best_valid_loss=2.1161532402038574\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████| 375/375 [00:51<00:00,  7.29it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "51-th epoch | train_loss=2.158503532409668 | valid_loss=2.0922021865844727 | best_valid_loss=2.0922021865844727\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████| 375/375 [00:51<00:00,  7.28it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "52-th epoch | train_loss=2.228959321975708 | valid_loss=2.2327635288238525 | best_valid_loss=2.0922021865844727\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████| 375/375 [00:51<00:00,  7.27it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "53-th epoch | train_loss=2.43325138092041 | valid_loss=2.489725351333618 | best_valid_loss=2.0922021865844727\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████| 375/375 [00:51<00:00,  7.26it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "54-th epoch | train_loss=2.641669750213623 | valid_loss=2.3241662979125977 | best_valid_loss=2.0922021865844727\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████| 375/375 [00:51<00:00,  7.23it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "55-th epoch | train_loss=2.493788957595825 | valid_loss=2.310985565185547 | best_valid_loss=2.0922021865844727\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████| 375/375 [00:51<00:00,  7.25it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "56-th epoch | train_loss=2.283374071121216 | valid_loss=2.199505090713501 | best_valid_loss=2.0922021865844727\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████| 375/375 [00:51<00:00,  7.25it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "57-th epoch | train_loss=2.209789276123047 | valid_loss=2.087540864944458 | best_valid_loss=2.087540864944458\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████| 375/375 [00:51<00:00,  7.25it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "58-th epoch | train_loss=2.1594839096069336 | valid_loss=2.111748218536377 | best_valid_loss=2.087540864944458\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████| 375/375 [00:51<00:00,  7.25it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "59-th epoch | train_loss=2.1372032165527344 | valid_loss=2.1208035945892334 | best_valid_loss=2.087540864944458\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████| 375/375 [00:51<00:00,  7.27it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "60-th epoch | train_loss=2.296907901763916 | valid_loss=2.3246853351593018 | best_valid_loss=2.087540864944458\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████| 375/375 [00:51<00:00,  7.25it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "61-th epoch | train_loss=2.2030091285705566 | valid_loss=2.0484683513641357 | best_valid_loss=2.0484683513641357\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████| 375/375 [00:51<00:00,  7.30it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "62-th epoch | train_loss=2.1240217685699463 | valid_loss=2.13633131980896 | best_valid_loss=2.0484683513641357\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████| 375/375 [00:51<00:00,  7.24it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "63-th epoch | train_loss=2.190052032470703 | valid_loss=2.120518922805786 | best_valid_loss=2.0484683513641357\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████| 375/375 [00:51<00:00,  7.25it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "64-th epoch | train_loss=2.1157541275024414 | valid_loss=2.045600652694702 | best_valid_loss=2.045600652694702\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████| 375/375 [00:51<00:00,  7.25it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "65-th epoch | train_loss=2.1933624744415283 | valid_loss=2.2638890743255615 | best_valid_loss=2.045600652694702\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████| 375/375 [00:51<00:00,  7.28it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "66-th epoch | train_loss=2.386270761489868 | valid_loss=2.093568801879883 | best_valid_loss=2.045600652694702\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████| 375/375 [00:51<00:00,  7.27it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "67-th epoch | train_loss=2.1606335639953613 | valid_loss=2.0510294437408447 | best_valid_loss=2.045600652694702\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████| 375/375 [00:51<00:00,  7.31it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "68-th epoch | train_loss=2.1409554481506348 | valid_loss=2.040952444076538 | best_valid_loss=2.040952444076538\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████| 375/375 [00:51<00:00,  7.29it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "69-th epoch | train_loss=2.1200215816497803 | valid_loss=2.0125372409820557 | best_valid_loss=2.0125372409820557\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████| 375/375 [00:51<00:00,  7.26it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "70-th epoch | train_loss=2.1016573905944824 | valid_loss=2.0322184562683105 | best_valid_loss=2.0125372409820557\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████| 375/375 [00:51<00:00,  7.26it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "71-th epoch | train_loss=2.104844808578491 | valid_loss=2.019066572189331 | best_valid_loss=2.0125372409820557\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████| 375/375 [00:51<00:00,  7.31it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "72-th epoch | train_loss=2.040015697479248 | valid_loss=2.160895347595215 | best_valid_loss=2.0125372409820557\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████| 375/375 [00:51<00:00,  7.24it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "73-th epoch | train_loss=2.0823593139648438 | valid_loss=2.1638386249542236 | best_valid_loss=2.0125372409820557\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████| 375/375 [00:51<00:00,  7.28it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "74-th epoch | train_loss=2.064944267272949 | valid_loss=1.9887280464172363 | best_valid_loss=1.9887280464172363\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████| 375/375 [00:51<00:00,  7.27it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "75-th epoch | train_loss=2.04988431930542 | valid_loss=1.989915370941162 | best_valid_loss=1.9887280464172363\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████| 375/375 [00:51<00:00,  7.25it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "76-th epoch | train_loss=2.0450072288513184 | valid_loss=2.145530939102173 | best_valid_loss=1.9887280464172363\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████| 375/375 [00:51<00:00,  7.26it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "77-th epoch | train_loss=2.1173110008239746 | valid_loss=1.983356237411499 | best_valid_loss=1.983356237411499\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████| 375/375 [00:52<00:00,  7.20it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "78-th epoch | train_loss=2.0121712684631348 | valid_loss=1.975450873374939 | best_valid_loss=1.975450873374939\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████| 375/375 [00:51<00:00,  7.23it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "79-th epoch | train_loss=2.006862163543701 | valid_loss=1.9482325315475464 | best_valid_loss=1.9482325315475464\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████| 375/375 [00:51<00:00,  7.23it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "80-th epoch | train_loss=2.0811591148376465 | valid_loss=2.05861234664917 | best_valid_loss=1.9482325315475464\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████| 375/375 [00:51<00:00,  7.27it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "81-th epoch | train_loss=2.3689308166503906 | valid_loss=2.0628647804260254 | best_valid_loss=1.9482325315475464\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████| 375/375 [00:51<00:00,  7.25it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "82-th epoch | train_loss=2.0683982372283936 | valid_loss=2.0020406246185303 | best_valid_loss=1.9482325315475464\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████| 375/375 [00:51<00:00,  7.23it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "83-th epoch | train_loss=2.061028242111206 | valid_loss=2.092353582382202 | best_valid_loss=1.9482325315475464\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████| 375/375 [00:51<00:00,  7.25it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "84-th epoch | train_loss=2.0418221950531006 | valid_loss=1.9639782905578613 | best_valid_loss=1.9482325315475464\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████| 375/375 [00:51<00:00,  7.22it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "85-th epoch | train_loss=1.9915306568145752 | valid_loss=1.9434294700622559 | best_valid_loss=1.9434294700622559\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████| 375/375 [00:52<00:00,  7.21it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "86-th epoch | train_loss=2.095334053039551 | valid_loss=2.1435775756835938 | best_valid_loss=1.9434294700622559\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████| 375/375 [00:51<00:00,  7.24it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "87-th epoch | train_loss=2.136335611343384 | valid_loss=2.0337533950805664 | best_valid_loss=1.9434294700622559\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████| 375/375 [00:52<00:00,  7.21it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "88-th epoch | train_loss=1.9971386194229126 | valid_loss=1.928186297416687 | best_valid_loss=1.928186297416687\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████| 375/375 [00:51<00:00,  7.21it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "89-th epoch | train_loss=1.9796664714813232 | valid_loss=1.9309693574905396 | best_valid_loss=1.928186297416687\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████| 375/375 [00:51<00:00,  7.23it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "90-th epoch | train_loss=1.9742441177368164 | valid_loss=1.9409186840057373 | best_valid_loss=1.928186297416687\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████| 375/375 [00:51<00:00,  7.22it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "91-th epoch | train_loss=1.9808274507522583 | valid_loss=1.9836572408676147 | best_valid_loss=1.928186297416687\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████| 375/375 [00:51<00:00,  7.26it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "92-th epoch | train_loss=1.9586164951324463 | valid_loss=1.9097955226898193 | best_valid_loss=1.9097955226898193\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████| 375/375 [00:51<00:00,  7.23it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "93-th epoch | train_loss=1.9941598176956177 | valid_loss=1.9140194654464722 | best_valid_loss=1.9097955226898193\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████| 375/375 [00:51<00:00,  7.24it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "94-th epoch | train_loss=1.9704875946044922 | valid_loss=1.9427833557128906 | best_valid_loss=1.9097955226898193\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████| 375/375 [00:51<00:00,  7.26it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "95-th epoch | train_loss=1.995579719543457 | valid_loss=1.9346563816070557 | best_valid_loss=1.9097955226898193\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████| 375/375 [00:52<00:00,  7.21it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "96-th epoch | train_loss=2.0146970748901367 | valid_loss=1.922863245010376 | best_valid_loss=1.9097955226898193\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████| 375/375 [00:51<00:00,  7.24it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "97-th epoch | train_loss=1.9434672594070435 | valid_loss=1.9080612659454346 | best_valid_loss=1.9080612659454346\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████| 375/375 [00:51<00:00,  7.21it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "98-th epoch | train_loss=1.9433554410934448 | valid_loss=1.9692524671554565 | best_valid_loss=1.9080612659454346\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████| 375/375 [00:52<00:00,  7.18it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "99-th epoch | train_loss=1.951716423034668 | valid_loss=1.9014830589294434 | best_valid_loss=1.9014830589294434\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████| 375/375 [00:51<00:00,  7.22it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "100-th epoch | train_loss=1.9397987127304077 | valid_loss=1.9225741624832153 | best_valid_loss=1.9014830589294434\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████| 375/375 [00:51<00:00,  7.24it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "101-th epoch | train_loss=1.9143660068511963 | valid_loss=1.8836767673492432 | best_valid_loss=1.8836767673492432\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████| 375/375 [00:51<00:00,  7.24it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "102-th epoch | train_loss=1.9767038822174072 | valid_loss=1.930023431777954 | best_valid_loss=1.8836767673492432\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████| 375/375 [00:51<00:00,  7.26it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "103-th epoch | train_loss=1.9304578304290771 | valid_loss=1.8816907405853271 | best_valid_loss=1.8816907405853271\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████| 375/375 [00:51<00:00,  7.24it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "104-th epoch | train_loss=1.9048681259155273 | valid_loss=1.86787748336792 | best_valid_loss=1.86787748336792\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████| 375/375 [00:52<00:00,  7.21it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "105-th epoch | train_loss=1.9048810005187988 | valid_loss=1.8710594177246094 | best_valid_loss=1.86787748336792\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████| 375/375 [00:52<00:00,  7.18it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "106-th epoch | train_loss=1.9538757801055908 | valid_loss=1.93082594871521 | best_valid_loss=1.86787748336792\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████| 375/375 [00:51<00:00,  7.25it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "107-th epoch | train_loss=1.9254205226898193 | valid_loss=1.8616522550582886 | best_valid_loss=1.8616522550582886\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████| 375/375 [00:51<00:00,  7.25it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "108-th epoch | train_loss=1.8994799852371216 | valid_loss=1.861077070236206 | best_valid_loss=1.861077070236206\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████| 375/375 [00:52<00:00,  7.20it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "109-th epoch | train_loss=1.8914321660995483 | valid_loss=1.8783040046691895 | best_valid_loss=1.861077070236206\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████| 375/375 [00:52<00:00,  7.21it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "110-th epoch | train_loss=1.8887871503829956 | valid_loss=1.8706456422805786 | best_valid_loss=1.861077070236206\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████| 375/375 [00:52<00:00,  7.20it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "111-th epoch | train_loss=1.949856162071228 | valid_loss=1.9353631734848022 | best_valid_loss=1.861077070236206\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████| 375/375 [00:51<00:00,  7.22it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "112-th epoch | train_loss=1.9213489294052124 | valid_loss=1.9614577293395996 | best_valid_loss=1.861077070236206\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████| 375/375 [00:51<00:00,  7.23it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "113-th epoch | train_loss=1.898008942604065 | valid_loss=1.8547742366790771 | best_valid_loss=1.8547742366790771\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████| 375/375 [00:51<00:00,  7.23it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "114-th epoch | train_loss=1.8709912300109863 | valid_loss=1.8433927297592163 | best_valid_loss=1.8433927297592163\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████| 375/375 [00:52<00:00,  7.21it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "115-th epoch | train_loss=1.879082202911377 | valid_loss=1.8582333326339722 | best_valid_loss=1.8433927297592163\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████| 375/375 [00:52<00:00,  7.21it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "116-th epoch | train_loss=1.881123661994934 | valid_loss=1.8629189729690552 | best_valid_loss=1.8433927297592163\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████| 375/375 [00:52<00:00,  7.16it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "117-th epoch | train_loss=1.8824235200881958 | valid_loss=1.849517583847046 | best_valid_loss=1.8433927297592163\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████| 375/375 [00:52<00:00,  7.21it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "118-th epoch | train_loss=1.9049509763717651 | valid_loss=2.0335545539855957 | best_valid_loss=1.8433927297592163\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████| 375/375 [00:51<00:00,  7.25it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "119-th epoch | train_loss=1.9582881927490234 | valid_loss=1.8606247901916504 | best_valid_loss=1.8433927297592163\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████| 375/375 [00:51<00:00,  7.23it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "120-th epoch | train_loss=1.8770939111709595 | valid_loss=1.8643860816955566 | best_valid_loss=1.8433927297592163\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████| 375/375 [00:52<00:00,  7.20it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "121-th epoch | train_loss=1.8560751676559448 | valid_loss=1.8387689590454102 | best_valid_loss=1.8387689590454102\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████| 375/375 [00:52<00:00,  7.21it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "122-th epoch | train_loss=1.8514890670776367 | valid_loss=2.127256155014038 | best_valid_loss=1.8387689590454102\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████| 375/375 [00:51<00:00,  7.23it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "123-th epoch | train_loss=1.8789328336715698 | valid_loss=1.822907567024231 | best_valid_loss=1.822907567024231\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████| 375/375 [00:51<00:00,  7.24it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "124-th epoch | train_loss=1.8457096815109253 | valid_loss=1.8388139009475708 | best_valid_loss=1.822907567024231\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████| 375/375 [00:51<00:00,  7.23it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "125-th epoch | train_loss=1.8410207033157349 | valid_loss=1.8240540027618408 | best_valid_loss=1.822907567024231\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████| 375/375 [00:51<00:00,  7.24it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "126-th epoch | train_loss=1.8370598554611206 | valid_loss=1.8219224214553833 | best_valid_loss=1.8219224214553833\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████| 375/375 [00:51<00:00,  7.24it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "127-th epoch | train_loss=1.8408398628234863 | valid_loss=1.8654423952102661 | best_valid_loss=1.8219224214553833\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████| 375/375 [00:51<00:00,  7.21it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "128-th epoch | train_loss=1.8402010202407837 | valid_loss=1.8267693519592285 | best_valid_loss=1.8219224214553833\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████| 375/375 [00:51<00:00,  7.22it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "129-th epoch | train_loss=1.8325128555297852 | valid_loss=1.851879358291626 | best_valid_loss=1.8219224214553833\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████| 375/375 [00:52<00:00,  7.18it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "130-th epoch | train_loss=1.828574538230896 | valid_loss=1.8263825178146362 | best_valid_loss=1.8219224214553833\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████| 375/375 [00:51<00:00,  7.23it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "131-th epoch | train_loss=1.8204188346862793 | valid_loss=1.852317214012146 | best_valid_loss=1.8219224214553833\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████| 375/375 [00:51<00:00,  7.24it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "132-th epoch | train_loss=1.824099063873291 | valid_loss=1.8552498817443848 | best_valid_loss=1.8219224214553833\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████| 375/375 [00:52<00:00,  7.18it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "133-th epoch | train_loss=1.8364150524139404 | valid_loss=1.8505200147628784 | best_valid_loss=1.8219224214553833\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████| 375/375 [00:51<00:00,  7.21it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "134-th epoch | train_loss=1.9409568309783936 | valid_loss=1.8588286638259888 | best_valid_loss=1.8219224214553833\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████| 375/375 [00:51<00:00,  7.24it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "135-th epoch | train_loss=1.8573386669158936 | valid_loss=1.857934594154358 | best_valid_loss=1.8219224214553833\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████| 375/375 [00:51<00:00,  7.26it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "136-th epoch | train_loss=1.8411750793457031 | valid_loss=1.8360353708267212 | best_valid_loss=1.8219224214553833\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████| 375/375 [00:51<00:00,  7.25it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "137-th epoch | train_loss=1.823671817779541 | valid_loss=1.815527319908142 | best_valid_loss=1.815527319908142\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████| 375/375 [00:51<00:00,  7.21it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "138-th epoch | train_loss=1.8176014423370361 | valid_loss=1.8068883419036865 | best_valid_loss=1.8068883419036865\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████| 375/375 [00:51<00:00,  7.24it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "139-th epoch | train_loss=1.8109700679779053 | valid_loss=1.798858404159546 | best_valid_loss=1.798858404159546\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████| 375/375 [00:51<00:00,  7.22it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "140-th epoch | train_loss=1.809069037437439 | valid_loss=1.812983512878418 | best_valid_loss=1.798858404159546\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████| 375/375 [00:52<00:00,  7.20it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "141-th epoch | train_loss=1.8425747156143188 | valid_loss=1.9425275325775146 | best_valid_loss=1.798858404159546\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████| 375/375 [00:51<00:00,  7.23it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "142-th epoch | train_loss=1.8207316398620605 | valid_loss=1.7986478805541992 | best_valid_loss=1.7986478805541992\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████| 375/375 [00:52<00:00,  7.17it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "143-th epoch | train_loss=1.8084644079208374 | valid_loss=1.8736532926559448 | best_valid_loss=1.7986478805541992\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████| 375/375 [00:52<00:00,  7.12it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "144-th epoch | train_loss=1.8246017694473267 | valid_loss=1.9492993354797363 | best_valid_loss=1.7986478805541992\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████| 375/375 [00:52<00:00,  7.17it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "145-th epoch | train_loss=1.8142247200012207 | valid_loss=1.816421389579773 | best_valid_loss=1.7986478805541992\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████| 375/375 [00:52<00:00,  7.17it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "146-th epoch | train_loss=1.7968599796295166 | valid_loss=1.7984442710876465 | best_valid_loss=1.7984442710876465\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████| 375/375 [00:52<00:00,  7.16it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "147-th epoch | train_loss=1.803816556930542 | valid_loss=1.8003578186035156 | best_valid_loss=1.7984442710876465\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████| 375/375 [00:52<00:00,  7.20it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "148-th epoch | train_loss=1.8129031658172607 | valid_loss=1.8361223936080933 | best_valid_loss=1.7984442710876465\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████| 375/375 [00:52<00:00,  7.18it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "149-th epoch | train_loss=1.8112093210220337 | valid_loss=1.8135336637496948 | best_valid_loss=1.7984442710876465\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████| 375/375 [00:52<00:00,  7.16it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "150-th epoch | train_loss=1.809723138809204 | valid_loss=1.804538607597351 | best_valid_loss=1.7984442710876465\n",
      "Evaluate best model on test data, Test MAE: 2.0633, Test MAPE: 0.0507, Test RMSE: 3.6912\n"
     ]
    }
   ],
   "source": [
    "min_valid_loss = 1e9\n",
    "MODEL_PATH='./model/best_valid_wwwk=3.pt'\n",
    "for epoch in range(150):\n",
    "    train_loss = train_one_epoch(train_dataloader,optimizer,scheduler)\n",
    "    valid_loss = valid(val_dataloader)\n",
    "    if min_valid_loss>valid_loss:\n",
    "        torch.save(model,MODEL_PATH)\n",
    "        min_valid_loss = valid_loss\n",
    "    print('{}-th epoch | train_loss={} | valid_loss={} | best_valid_loss={}'\n",
    "          .format(epoch+1,train_loss,valid_loss,min_valid_loss))\n",
    "model = torch.load(MODEL_PATH)\n",
    "test(test_dataloader)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Evaluate best model on test data, Test MAE: 2.0633, Test MAPE: 0.0507, Test RMSE: 3.6912\n"
     ]
    }
   ],
   "source": [
    "model = torch.load(MODEL_PATH)\n",
    "test(test_dataloader)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [],
   "source": [
    "# GReTo\n",
    "# with 错的M\n",
    "# 100-th epoch | train_loss=2.380671977996826 | valid_loss=2.3669283390045166 | best_valid_loss=2.335008144378662\n",
    "# Evaluate best model on test data, Test MAE: 2.6420, Test MAPE: 0.0630, Test RMSE: 5.2794\n",
    "\n",
    "# with the correct M\n",
    "# 100-th epoch | train_loss=2.1417477130889893 | valid_loss=2.127549648284912 | best_valid_loss=2.118959665298462\n",
    "# Evaluate best model on test data, Test MAE: 2.4024, Test MAPE: 0.0594, Test RMSE: 4.2679\n",
    "\n",
    "# 100-th epoch | train_loss=2.0668952465057373 | valid_loss=1.9732215404510498 | best_valid_loss=1.9636491537094116\n",
    "# Evaluate best model on test data, Test MAE: 2.2462, Test MAPE: 0.0548, Test RMSE: 3.9049\n",
    "\n",
    "# 3 10 0.9\n",
    "# 100-th epoch | train_loss=1.9602874517440796 | valid_loss=1.9380944967269897 | best_valid_loss=1.9380944967269897\n",
    "# Evaluate best model on test data, Test MAE: 2.2124, Test MAPE: 0.0543, Test RMSE: 3.8694\n",
    "\n",
    "# 3 10 0.9 150\n",
    "# 150-th epoch | train_loss=1.809723138809204 | valid_loss=1.804538607597351 | best_valid_loss=1.7984442710876465\n",
    "# Evaluate best model on test data, Test MAE: 2.0633, Test MAPE: 0.0507, Test RMSE: 3.6912"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Batch(x=[13248, 15], edge_index=[2, 110208], edge_attr=[110208, 2], y=[13248], batch=[13248], ptr=[65])\n",
      "torch.Size([110208])\n"
     ]
    }
   ],
   "source": [
    "# Sanity check: pull a single mini-batch from the training loader and print its layout.\n",
    "batch = next(iter(train_dataloader))\n",
    "print(batch)\n",
    "# Shape of the first edge-attribute column (one entry per edge in the batch).\n",
    "print(batch.edge_attr[:,0].shape)\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
