{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"\n",
    "Baseline to handle the regression models within the spatio-temporal settings\n",
    "\"\"\"\n",
    "\n",
    "import os\n",
    "os.environ[\"CURL_CA_BUNDLE\"]=\"\" \n",
    "# os.environ[\"REQUESTS_CA_BUNDLE\"]=\"\"\n",
    "import sys\n",
    "sys.path.extend([\"../\", \"./\"])\n",
    "import random\n",
    "import glob\n",
    "import warnings\n",
    "import pickle\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import dgl\n",
    "import scipy.sparse as sp\n",
    "import torch\n",
    "from tqdm import tqdm\n",
    "from dgl.dataloading import DataLoader, NeighborSampler, MultiLayerFullNeighborSampler\n",
    "\n",
    "\n",
    "from data_utils import split_dataset, load_LargeST\n",
    "from utils import LargeSTLossWrapper\n",
    "from model import MLP, SAGE\n",
    "from data_utils import PlainLoader\n",
    "\n",
    "warnings.filterwarnings('ignore')\n",
    "# NOTE(review): an empty CURL_CA_BUNDLE is presumably a workaround to skip TLS\n",
    "# certificate verification in HTTP clients -- a security trade-off; confirm it\n",
    "# is still required before sharing this notebook.\n",
    "os.environ[\"CURL_CA_BUNDLE\"] = \"\"\n",
    "# Prefer GPU 6 (the original hard-coded choice) but fall back to the first GPU,\n",
    "# or to the CPU, so the notebook still starts on machines with fewer devices.\n",
    "if torch.cuda.is_available():\n",
    "    DEVICE = 'cuda:6' if torch.cuda.device_count() > 6 else 'cuda:0'\n",
    "else:\n",
    "    DEVICE = 'cpu'\n",
    "\n",
    "# Notebook extensions: line_profiler provides the %lprun magic; autoreload in\n",
    "# mode 2 picks up edits to imported project modules without a kernel restart.\n",
    "%load_ext line_profiler\n",
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "def disp_metrics(g, preds, labels, max_ticks=10):\n",
    "    \"\"\"Print aggregate and per-tick regression metrics (MAE / MAPE / RMSE).\n",
    "\n",
    "    Args:\n",
    "        g: graph object exposing ``scale_stats``, which is handed to ``LargeSTLossWrapper``.\n",
    "        preds: sequence of prediction tensors, one entry per tick.\n",
    "        labels: sequence of label tensors aligned with ``preds``.\n",
    "        max_ticks: upper bound on the number of per-tick rows printed\n",
    "            (defaults to 10, the previously hard-coded value).\n",
    "    \"\"\"\n",
    "    loss_evaluator = LargeSTLossWrapper(g.scale_stats)\n",
    "    all_res = loss_evaluator(torch.concat(preds), torch.concat(labels))\n",
    "    mae, mape, rmse = all_res[0]\n",
    "    print(f\"all time metrics-- mae: {mae:.2f} | mape: {mape:.2f} | rmse: {rmse:.2f}\")\n",
    "    h_mae, h_mape, h_rmse = loss_evaluator.all_horizon(preds, labels)\n",
    "    # Clamp to the number of supplied ticks so short inputs are not indexed out of range.\n",
    "    # Assumes all_horizon returns one entry per element of preds -- TODO confirm.\n",
    "    for i in range(min(max_ticks, len(preds))):\n",
    "        print(f\"Tick {i} metrics-- mae: {h_mae[i]:.2f} | mape: {h_mape[i]:.2f} | rmse: {h_rmse[i]:.2f}\")\n",
    "\n",
    "\n",
    "# Root folder of the LargeST artifacts; the saved graph and its stats share one prefix\n",
    "data_root_path = \"../dataset/LargeST\"\n",
    "data_save_path_prefix = os.path.join(\n",
    "    data_root_path,\n",
    "    f\"dgl_day_{100}_seq_{12}\",\n",
    ")\n",
    "\n",
    "# Graph used by every experiment below\n",
    "g = load_LargeST(data_save_path_prefix)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# GraphSAGE"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Utility functions\n",
    "def layerwise_infer(device, graph, nid, model, batch_size, is_sample=False, sample_size=2):\n",
    "    \"\"\"Score the nodes ``nid`` with ``torch_f1`` on the output of ``model.inference``.\n",
    "\n",
    "    NOTE(review): ``torch_f1`` is neither defined nor imported in this notebook,\n",
    "    so calling this helper looks like it would raise ``NameError`` -- confirm\n",
    "    whether the helper is still needed.\n",
    "    \"\"\"\n",
    "    model.eval()\n",
    "    with torch.no_grad():\n",
    "        # predictions stay on whichever device model.inference buffers them on\n",
    "        all_pred = model.inference(graph, device, batch_size,\n",
    "                                   is_sample=is_sample, sample_size=sample_size)\n",
    "        node_pred = all_pred[nid]\n",
    "        node_label = graph.ndata['label'][nid].to(node_pred.device)\n",
    "        return torch_f1(node_pred, node_label)\n",
    "\n",
    "def train(device, g, model, train_conf, is_sample=False):\n",
    "    \"\"\"Train the GraphSAGE model and return the best checkpoint as pickled bytes.\n",
    "\n",
    "    The model is optimised with Adam on ``rescale_loss[0]`` from\n",
    "    ``LargeSTLossWrapper`` and validated after every epoch; the state with the\n",
    "    lowest validation MAE (``val_metrics[0]``) is kept.\n",
    "\n",
    "    Args:\n",
    "        device: device on which the mini-batches are produced.\n",
    "        g: graph exposing ``train_idx`` / ``val_idx`` / ``scale_stats`` and the\n",
    "            node data ``feat`` / ``label``.\n",
    "        model: network exposing ``GNN_layer`` and callable as ``model(blocks, x)``.\n",
    "        train_conf: dict with the keys ``batch_size``, ``epoch``, ``lr``,\n",
    "            ``weight_decay``, ``step_size``, ``gamma`` and, when sampling,\n",
    "            ``train_neighbor_size``.\n",
    "        is_sample: sample ``train_neighbor_size`` neighbours per layer instead\n",
    "            of using the full neighbourhood.\n",
    "\n",
    "    Returns:\n",
    "        bytes: pickled ``state_dict`` of the best epoch, or of the final state\n",
    "        when no epoch improved on the initial sentinel (e.g. zero epochs).\n",
    "    \"\"\"\n",
    "    train_idx = g.train_idx.to(device)\n",
    "    val_idx = g.val_idx.to(device)\n",
    "\n",
    "    # one fan-out / one full-neighbour hop per GNN layer\n",
    "    num_layers = model.GNN_layer\n",
    "    if is_sample:\n",
    "        sampler = NeighborSampler([train_conf[\"train_neighbor_size\"]] * num_layers,\n",
    "                                  prefetch_node_feats=['feat'],\n",
    "                                  prefetch_labels=['label'])\n",
    "    else:\n",
    "        sampler = MultiLayerFullNeighborSampler(num_layers,\n",
    "                                                prefetch_node_feats=['feat'],\n",
    "                                                prefetch_labels=['label'])\n",
    "    train_dataloader = DataLoader(g, train_idx, sampler, device=device,\n",
    "                                  batch_size=train_conf[\"batch_size\"], shuffle=True,\n",
    "                                  drop_last=False, num_workers=0,\n",
    "                                  use_uva=False)\n",
    "\n",
    "    opt = torch.optim.Adam(model.parameters(), lr=train_conf[\"lr\"], weight_decay=train_conf[\"weight_decay\"])\n",
    "    scheduler = torch.optim.lr_scheduler.StepLR(opt, step_size=train_conf['step_size'], gamma=train_conf['gamma'])\n",
    "\n",
    "    # inf sentinel: any finite validation MAE counts as an improvement\n",
    "    best_state, best_val, best_epoch = None, [float(\"inf\")] * 3, 0\n",
    "    loss_evaluator = LargeSTLossWrapper(g.scale_stats)\n",
    "\n",
    "    for epoch in range(train_conf[\"epoch\"]):\n",
    "        model.train()\n",
    "        for input_nodes, output_nodes, blocks in train_dataloader:\n",
    "            x = blocks[0].srcdata['feat']\n",
    "            y = blocks[-1].dstdata['label']\n",
    "            y_hat = model(blocks, x)\n",
    "\n",
    "            _, ori_loss, rescale_loss = loss_evaluator(y_hat, y)\n",
    "            loss = rescale_loss[0]\n",
    "            opt.zero_grad()\n",
    "            loss.backward()\n",
    "            opt.step()\n",
    "\n",
    "        # read the learning rate used during this epoch before stepping the schedule\n",
    "        cur_lr = scheduler.get_last_lr()[0]\n",
    "        scheduler.step()\n",
    "        val_preds, val_labels = collect_pred_labels(g, val_idx, model, device)\n",
    "        val_metrics, _, _ = loss_evaluator(val_preds, val_labels)\n",
    "\n",
    "        if val_metrics[0] < best_val[0]:\n",
    "            best_val = val_metrics\n",
    "            # pickling takes a snapshot; later optimiser steps cannot alter it\n",
    "            best_state = pickle.dumps(model.state_dict())\n",
    "            best_epoch = epoch\n",
    "\n",
    "        if epoch % 10 == 0:\n",
    "            print(f\"Epoch {epoch:03d} val mae {val_metrics[0]:.4f} | cur lr: {cur_lr:e}\")\n",
    "\n",
    "    if best_state is None:\n",
    "        # No epoch ran or none improved (e.g. NaN metrics): still hand back a\n",
    "        # loadable checkpoint instead of None, which callers write to disk.\n",
    "        best_state = pickle.dumps(model.state_dict())\n",
    "\n",
    "    print(\"Epoch {:03d} hist best mae {:.4f} \"\n",
    "           .format(best_epoch, best_val[0]))\n",
    "    print(best_val)\n",
    "\n",
    "    return best_state\n",
    "\n",
    "def collect_pred_labels(g, nid, model, device,\n",
    "                        batch_size = 512,\n",
    "                        feature_key='feat', label_key='label'):\n",
    "    \"\"\"Run full-neighbour inference and return ``(pred, label)`` for the nodes ``nid``.\n",
    "\n",
    "    Args:\n",
    "        g: graph whose node data holds ``feat``, ``feature_key`` and ``label_key``.\n",
    "        nid: indices of the nodes to keep.\n",
    "        model: network exposing ``inference(g, device, batch_size, is_sample=...)``.\n",
    "        device: device forwarded to ``model.inference``.\n",
    "        batch_size: batch size forwarded to ``model.inference``.\n",
    "        feature_key: node-data key of the input features to predict from.\n",
    "        label_key: node-data key of the labels to return.\n",
    "\n",
    "    Returns:\n",
    "        (pred, label): predictions and labels restricted to ``nid``; the labels\n",
    "        are moved to the device of the predictions.\n",
    "    \"\"\"\n",
    "    model.eval()\n",
    "    # feature_key used to be ignored, so every call predicted from the default\n",
    "    # features. Assumes model.inference reads g.ndata['feat'] -- TODO confirm.\n",
    "    # The requested features are swapped in temporarily and restored afterwards.\n",
    "    swap_features = feature_key != 'feat'\n",
    "    if swap_features:\n",
    "        default_feat = g.ndata['feat']\n",
    "        g.ndata['feat'] = g.ndata[feature_key]\n",
    "    try:\n",
    "        with torch.no_grad():\n",
    "            pred = model.inference(g, device, batch_size, is_sample=False) # pred in buffer_device\n",
    "            pred = pred[nid]\n",
    "            label = g.ndata[label_key][nid].to(pred.device)\n",
    "    finally:\n",
    "        if swap_features:\n",
    "            g.ndata['feat'] = default_feat\n",
    "\n",
    "    return pred, label\n",
    "\n",
    "def SAGE_get_all_pred(model, g, device):\n",
    "    \"\"\"Collect test-node predictions and labels for every day stored on ``g``.\n",
    "\n",
    "    Returns:\n",
    "        (preds, labels): two lists with ``g.num_of_days`` entries each.\n",
    "    \"\"\"\n",
    "    test_idx = g.test_idx.to(device)\n",
    "    per_day = [\n",
    "        collect_pred_labels(g, test_idx, model, device,\n",
    "                            feature_key=f'feat_{day}',\n",
    "                            label_key=f'label_{day}')\n",
    "        for day in range(g.num_of_days)\n",
    "    ]\n",
    "    preds = [day_pred for day_pred, _ in per_day]\n",
    "    labels = [day_label for _, day_label in per_day]\n",
    "    return preds, labels\n",
    "\n",
    "\n",
    "###########################\n",
    "\n",
    "# Hyper-parameters of the GraphSAGE run\n",
    "train_conf = {\n",
    "    \"batch_size\": 10000,\n",
    "    \"epoch\": 500,\n",
    "    \"lr\": 0.1,\n",
    "    \"train_neighbor_size\": 10,\n",
    "    \"hidden_layer\": 1,\n",
    "    \"hidden_size\": 64,\n",
    "    \"weight_decay\": 0.0000,\n",
    "    \"dropout\": 0.1,\n",
    "    'step_size': 100,\n",
    "    'gamma': 1,\n",
    "    'batch_norm': False,\n",
    "    'bias': False\n",
    "}\n",
    "name = \"LargeST\"\n",
    "\n",
    "print(\"*\"*21)\n",
    "print(name)\n",
    "\n",
    "# exist_ok replaces the check-then-create pair and keeps the cell re-runnable\n",
    "model_folder = os.path.join(\"../result\", name , \"SAGE\")\n",
    "os.makedirs(model_folder, exist_ok=True)\n",
    "\n",
    "in_size = g.ndata['feat'].shape[1]\n",
    "out_size = 1\n",
    "hidden_size = train_conf[\"hidden_size\"]\n",
    "\n",
    "model = SAGE(in_size, hidden_size, out_size,\n",
    "             GNN_layer=train_conf[\"hidden_layer\"],\n",
    "             dropout=train_conf[\"dropout\"],\n",
    "             is_batch_norm=train_conf['batch_norm'],\n",
    "             bias=train_conf['bias']).to(DEVICE)\n",
    "\n",
    "# NOTE: rebinds the notebook-wide graph to DEVICE; later cells see the moved graph\n",
    "g = g.to(DEVICE)\n",
    "best_model_state = train(DEVICE, g, model, train_conf)\n",
    "\n",
    "# Persist the best checkpoint (pickled state_dict bytes) and restore it for testing\n",
    "with open(os.path.join(model_folder, \"state_dict\"), \"wb\") as f:\n",
    "    f.write(best_model_state)\n",
    "\n",
    "model.load_state_dict(pickle.loads(best_model_state))\n",
    "\n",
    "# Test the results for all horizons\n",
    "\n",
    "print(\"*\"*21)\n",
    "print(\"All horizon list\")\n",
    "\n",
    "preds, labels = SAGE_get_all_pred(model, g, DEVICE)\n",
    "disp_metrics(g, preds, labels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Exploratory check: display what g.adj() returns for the loaded graph\n",
    "g.adj()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# MLP "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train(device, g, model, train_conf):\n",
    "    \"\"\"Train the MLP baseline and return the best checkpoint as pickled bytes.\n",
    "\n",
    "    NOTE: this shares its name with the ``train`` defined in the GraphSAGE cell\n",
    "    and shadows it once this cell has run.\n",
    "\n",
    "    The model is optimised with SGD on ``rescale_loss[0]`` from\n",
    "    ``LargeSTLossWrapper``; the state with the lowest validation MAE\n",
    "    (``val_metrics[0]``) is kept. Test metrics are only reported, never used\n",
    "    for model selection.\n",
    "\n",
    "    Args:\n",
    "        device: device the features, labels and split indices are moved to.\n",
    "        g: graph exposing ``train_idx`` / ``val_idx`` / ``test_idx`` /\n",
    "            ``scale_stats`` and the node data ``feat`` / ``label``.\n",
    "        model: network callable as ``model(x)``.\n",
    "        train_conf: dict with the keys ``batch_size``, ``epoch``, ``lr``,\n",
    "            ``weight_decay``, ``step_size`` and ``gamma``.\n",
    "\n",
    "    Returns:\n",
    "        bytes: pickled ``state_dict`` of the best epoch, or of the final state\n",
    "        when no epoch improved on the initial sentinel (e.g. zero epochs).\n",
    "    \"\"\"\n",
    "    train_idx = g.train_idx.to(device)\n",
    "    val_idx = g.val_idx.to(device)\n",
    "    test_idx = g.test_idx.to(device)\n",
    "    features = g.ndata['feat'].to(device)\n",
    "    labels = g.ndata['label'].to(device)\n",
    "\n",
    "    train_dataloader = PlainLoader(features, labels, train_conf[\"batch_size\"], train_idx)\n",
    "    val_dataloader = PlainLoader(features, labels, train_conf[\"batch_size\"], val_idx)\n",
    "    test_dataloader = PlainLoader(features, labels, train_conf[\"batch_size\"], test_idx)\n",
    "\n",
    "    opt = torch.optim.SGD(model.parameters(), lr=train_conf[\"lr\"], weight_decay=train_conf['weight_decay'])\n",
    "    scheduler = torch.optim.lr_scheduler.StepLR(opt, step_size=train_conf['step_size'], gamma=train_conf['gamma'])\n",
    "\n",
    "    # inf sentinel: any finite validation MAE counts as an improvement\n",
    "    best_state, best_val, best_epoch = None, [float(\"inf\")] * 3, 0\n",
    "    loss_evaluator = LargeSTLossWrapper(g.scale_stats)\n",
    "\n",
    "    for epoch in range(train_conf[\"epoch\"]):\n",
    "        model.train()\n",
    "        for x, y in train_dataloader:\n",
    "            y_hat = model(x)\n",
    "            _, ori_loss, rescale_loss = loss_evaluator(y_hat, y)\n",
    "            loss = rescale_loss[0]\n",
    "            opt.zero_grad()\n",
    "            loss.backward()\n",
    "            opt.step()\n",
    "\n",
    "        # read the learning rate used during this epoch before stepping the schedule\n",
    "        cur_lr = scheduler.get_last_lr()[0]\n",
    "        scheduler.step()\n",
    "        val_preds, val_labels = collect_pred_labels(model, val_dataloader)\n",
    "        val_metrics, _, _ = loss_evaluator(val_preds, val_labels)\n",
    "\n",
    "        if val_metrics[0] < best_val[0]:\n",
    "            best_val = val_metrics\n",
    "            # pickling takes a snapshot; later optimiser steps cannot alter it\n",
    "            best_state = pickle.dumps(model.state_dict())\n",
    "            best_epoch = epoch\n",
    "\n",
    "        if epoch % 10 == 0:\n",
    "            # The test pass only feeds this log line, so it is skipped on the\n",
    "            # epochs that do not print.\n",
    "            test_preds, test_labels = collect_pred_labels(model, test_dataloader)\n",
    "            test_metrics, _, _ = loss_evaluator(test_preds, test_labels)\n",
    "            print(f\"Epoch {epoch:03d} val mae {val_metrics[0]:.4f} | test mae {test_metrics[0]:.4f} | cur lr: {cur_lr:e}\")\n",
    "\n",
    "    if best_state is None:\n",
    "        # No epoch ran or none improved (e.g. NaN metrics): still hand back a\n",
    "        # loadable checkpoint instead of None, which callers write to disk.\n",
    "        best_state = pickle.dumps(model.state_dict())\n",
    "\n",
    "    print(\"Epoch {:03d} hist best mae {:.4f} \"\n",
    "           .format(best_epoch, best_val[0]))\n",
    "    print(best_val)\n",
    "\n",
    "    return best_state\n",
    "\n",
    "def collect_pred_labels(model, dataloader):\n",
    "    \"\"\"Run ``model`` in eval mode over ``dataloader`` and gather its outputs.\n",
    "\n",
    "    Returns:\n",
    "        (predictions, labels): both concatenated over all batches; each batch\n",
    "        of predictions is flattened to 1-D before concatenation.\n",
    "    \"\"\"\n",
    "    model.eval()\n",
    "    batch_preds, batch_labels = [], []\n",
    "    with torch.no_grad():\n",
    "        for x, y in dataloader:\n",
    "            batch_labels.append(y)\n",
    "            batch_preds.append(model(x).reshape(-1))\n",
    "    return torch.concat(batch_preds), torch.concat(batch_labels)\n",
    "\n",
    "def MLP_get_all_pred(model, g, device):\n",
    "    \"\"\"Collect test-node predictions and labels of the MLP for every day on ``g``.\n",
    "\n",
    "    Returns:\n",
    "        (preds, labels): two lists with ``g.num_of_days`` entries each.\n",
    "    \"\"\"\n",
    "    test_idx = g.test_idx.to(device)\n",
    "    preds, labels = [], []\n",
    "    for day in range(g.num_of_days):\n",
    "        # one loader per day over that day's features and labels (batch size 512)\n",
    "        day_loader = PlainLoader(g.ndata[f'feat_{day}'].to(device),\n",
    "                                 g.ndata[f'label_{day}'].to(device),\n",
    "                                 512, test_idx)\n",
    "        day_preds, day_labels = collect_pred_labels(model, day_loader)\n",
    "        preds.append(day_preds)\n",
    "        labels.append(day_labels)\n",
    "    return preds, labels\n",
    "\n",
    "\n",
    "###########################\n",
    "\n",
    "# Hyper-parameters of the MLP run\n",
    "train_conf = {\n",
    "    'batch_size': 10000,\n",
    "    'epoch': 500,\n",
    "    'lr': 0.1,\n",
    "    'hidden_layer': 1,\n",
    "    'IS_ONE_LAYER': False,\n",
    "    'hidden_size': 64,\n",
    "    'weight_decay': 0.0000,\n",
    "    'dropout': 0.1,\n",
    "    'step_size': 100,\n",
    "    'gamma': 1,\n",
    "    'batch_norm': False,\n",
    "    'bias': False\n",
    "}\n",
    "name = \"LargeST\"\n",
    "\n",
    "print(\"*\"*21)\n",
    "print(name)\n",
    "\n",
    "# exist_ok replaces the check-then-create pair and keeps the cell re-runnable\n",
    "model_folder = os.path.join(\"../result\", name , \"MLP\")\n",
    "os.makedirs(model_folder, exist_ok=True)\n",
    "\n",
    "in_size = g.ndata['feat'].shape[1]\n",
    "out_size = 1\n",
    "\n",
    "# None requests the single-layer variant; otherwise one width per hidden layer\n",
    "hidden_size = None if train_conf[\"IS_ONE_LAYER\"] else [train_conf[\"hidden_size\"] ] * train_conf['hidden_layer']\n",
    "\n",
    "model = MLP(in_size, hidden_size, out_size, \n",
    "            dropout=train_conf[\"dropout\"], dtype=g.ndata['feat'].dtype,\n",
    "            is_batch_norm=train_conf['batch_norm'],\n",
    "            bias=train_conf['bias']).to(DEVICE)\n",
    "\n",
    "best_model_state = train(DEVICE, g, model, train_conf)\n",
    "\n",
    "# Persist the best checkpoint (pickled state_dict bytes) and restore it for testing\n",
    "with open(os.path.join(model_folder, \"state_dict\"), \"wb\") as f:\n",
    "    f.write(best_model_state)\n",
    "\n",
    "model.load_state_dict(pickle.loads(best_model_state))\n",
    "\n",
    "\n",
    "# Test the results for all horizons\n",
    "\n",
    "print(\"*\"*21)\n",
    "print(\"All horizon list\")\n",
    "\n",
    "preds, labels = MLP_get_all_pred(model, g, DEVICE)\n",
    "disp_metrics(g, preds, labels)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Heuristic-last baseline evaluation (predict the last observed value)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def hl(g):\n",
    "    \"\"\"Heuristic baseline: predict each test node's label with its last feature value.\n",
    "\n",
    "    Returns:\n",
    "        (preds, labels): two lists with one tensor per day, restricted to ``g.test_idx``.\n",
    "    \"\"\"\n",
    "    days = range(g.num_of_days)\n",
    "    labels = [g.ndata[f\"label_{day}\"][g.test_idx] for day in days]\n",
    "    # last column of the feature window = most recent observation\n",
    "    preds = [g.ndata[f\"feat_{day}\"][g.test_idx][:, -1] for day in days]\n",
    "    return preds, labels\n",
    "\n",
    "# Score the heuristic-last baseline on the test nodes of every day\n",
    "preds, labels = hl(g)\n",
    "disp_metrics(g, preds, labels)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Data generation process"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"\n",
    "Following https://github.com/liuxu77/LargeST\n",
    "\"\"\"\n",
    "\n",
    "# deal with the sequences and features\n",
    "class StandardScaler:\n",
    "    \"\"\"Z-score normaliser with a fixed mean and standard deviation.\"\"\"\n",
    "\n",
    "    def __init__(self, mean, std):\n",
    "        self.mean, self.std = mean, std\n",
    "\n",
    "    def transform(self, data):\n",
    "        \"\"\"Standardise ``data``: ``(data - mean) / std``.\"\"\"\n",
    "        centered = data - self.mean\n",
    "        return centered / self.std\n",
    "\n",
    "    def inverse_transform(self, data):\n",
    "        \"\"\"Undo ``transform``: ``data * std + mean``.\"\"\"\n",
    "        rescaled = data * self.std\n",
    "        return rescaled + self.mean\n",
    "\n",
    "\n",
    "def generate_data_and_idx(df, x_offsets, y_offsets, add_time_of_day, add_day_of_week):\n",
    "    \"\"\"Stack the readings with optional calendar features and list the usable time indices.\n",
    "\n",
    "    Args:\n",
    "        df: frame with a datetime index (rows = time steps, columns = nodes).\n",
    "        x_offsets: input-window offsets; their minimum bounds the first usable index.\n",
    "        y_offsets: target offsets; their maximum bounds the last usable index.\n",
    "        add_time_of_day: append the elapsed fraction of the day as a feature.\n",
    "        add_day_of_week: append ``dayofweek / 7`` as a feature.\n",
    "\n",
    "    Returns:\n",
    "        (data, idx): ``data`` has shape (num_samples, num_nodes, num_features)\n",
    "        and ``idx`` is ``np.arange(min_t, max_t)``.\n",
    "    \"\"\"\n",
    "    num_samples, num_nodes = df.shape\n",
    "\n",
    "    def per_node(series):\n",
    "        # replicate a per-timestamp series for every node -> (num_samples, num_nodes, 1)\n",
    "        return np.tile(series, [1, num_nodes, 1]).transpose((2, 1, 0))\n",
    "\n",
    "    channels = [np.expand_dims(df.values, axis=-1)]\n",
    "    if add_time_of_day:\n",
    "        stamps = df.index.values\n",
    "        day_fraction = (stamps - stamps.astype('datetime64[D]')) / np.timedelta64(1, 'D')\n",
    "        channels.append(per_node(day_fraction))\n",
    "    if add_day_of_week:\n",
    "        channels.append(per_node(df.index.dayofweek) / 7)\n",
    "\n",
    "    data = np.concatenate(channels, axis=-1)\n",
    "\n",
    "    first_idx = abs(min(x_offsets))\n",
    "    last_idx = abs(num_samples - abs(max(y_offsets)))  # exclusive upper bound\n",
    "    print('idx min & max:', first_idx, last_idx)\n",
    "    return data, np.arange(first_idx, last_idx, 1)\n",
    "\n",
    "\n",
    "def generate_preprocessed(input_df,\n",
    "                          seq_length_x,\n",
    "                          seq_length_y,\n",
    "                          is_add_time_of_day,\n",
    "                          is_add_day_of_week,\n",
    "                          num_train_ticks=None):\n",
    "    \"\"\"Build the feature array for ``input_df`` and z-score its first channel.\n",
    "\n",
    "    Args:\n",
    "        input_df: frame with a datetime index (rows = time steps, columns = nodes).\n",
    "        seq_length_x: length of the input window (defines the x offsets).\n",
    "        seq_length_y: length of the target window (defines the y offsets).\n",
    "        is_add_time_of_day: forwarded to ``generate_data_and_idx``.\n",
    "        is_add_day_of_week: forwarded to ``generate_data_and_idx``.\n",
    "        num_train_ticks: number of leading time steps used to fit the\n",
    "            normalisation statistics. ``None`` keeps the previous behaviour of\n",
    "            reading the notebook-level ``ticks_per_day``.\n",
    "\n",
    "    Returns:\n",
    "        (data, x_stats): the feature array with channel 0 normalised, and the\n",
    "        ``[mean, std]`` used for the normalisation.\n",
    "    \"\"\"\n",
    "    if num_train_ticks is None:\n",
    "        # backward-compatible default: fall back to the notebook-level constant\n",
    "        num_train_ticks = ticks_per_day\n",
    "\n",
    "    df = input_df\n",
    "\n",
    "    x_offsets = np.arange(-(seq_length_x - 1), 1, 1)\n",
    "    y_offsets = np.arange(1, (seq_length_y + 1), 1)\n",
    "\n",
    "    data, idx = generate_data_and_idx(df, x_offsets, y_offsets, is_add_time_of_day, is_add_day_of_week)\n",
    "    print('final data shape:', data.shape, 'idx shape:', idx.shape)\n",
    "\n",
    "    # normalize: statistics come from the leading ticks only (the first day by\n",
    "    # default), the remaining data is treated as unseen\n",
    "    x_train = data[:num_train_ticks, :, 0]\n",
    "    x_stats = [x_train.mean(), x_train.std()]\n",
    "    scaler = StandardScaler(mean=x_stats[0], std=x_stats[1])\n",
    "    data[..., 0] = scaler.transform(data[..., 0])\n",
    "\n",
    "    return data, x_stats\n",
    "\n",
    "year = '2019'\n",
    "num_of_days = 100\n",
    "ticks_per_day = 96\n",
    "# feature-window lengths for which one graph file is written each\n",
    "sequence_lengths = [1, 2, 5, 12, 95]\n",
    "\n",
    "# Data resampling and cleaning\n",
    "path_raw_trace = os.path.join(data_root_path, \"ca_his_raw_\"+year+\".h5\")\n",
    "path_cleaned_trace = os.path.join(data_root_path, \"ca_his_\"+year+\".h5\")\n",
    "\n",
    "ca_his = pd.read_hdf(path_raw_trace)\n",
    "ca_his = ca_his.resample('15T').mean().round(0)\n",
    "ca_his = ca_his.fillna(0)\n",
    "print('check null value number', ca_his.isnull().any().sum())\n",
    "\n",
    "ca_his.to_hdf(path_cleaned_trace, key='t', mode='w')\n",
    "\n",
    "path_cleaned_trace_n_day = os.path.join(data_root_path,\n",
    "                                        f\"ca_his_{year}_day{num_of_days}.h5\")\n",
    "ca_his.iloc[:ticks_per_day*num_of_days, :].to_hdf(path_cleaned_trace_n_day, key='t', mode='w')\n",
    "\n",
    "# deal with the sequences and features\n",
    "cur_df = pd.read_hdf(path_cleaned_trace_n_day)\n",
    "normed_vals, scale_stats = generate_preprocessed(cur_df, 1, 1, False, False)\n",
    "normed_vals = np.squeeze(normed_vals, axis=-1)\n",
    "# fail early with a readable message instead of an opaque reshape error below\n",
    "assert normed_vals.shape[0] == num_of_days * ticks_per_day, \\\n",
    "    \"trace does not contain num_of_days full days of ticks\"\n",
    "\n",
    "# The (day, node, tick) view and the labels do not depend on the window length,\n",
    "# so they are built once. The label is the last tick of each day.\n",
    "all_traces = normed_vals.reshape([num_of_days, ticks_per_day, normed_vals.shape[1]]).transpose(0, 2, 1)\n",
    "torch_label = torch.from_numpy(all_traces[:, :, -1].transpose(1, 0)).float()\n",
    "\n",
    "# deal with the adj (identical for every window length, so loaded once)\n",
    "def load_adj_from_numpy(numpy_file):\n",
    "    return np.load(numpy_file)\n",
    "\n",
    "path_np_adj = os.path.join(data_root_path, \"ca_rn_adj.npy\")\n",
    "adj = load_adj_from_numpy(path_np_adj)\n",
    "coo_adj = sp.csr_matrix(adj).tocoo()\n",
    "num_nodes = adj.shape[0]\n",
    "\n",
    "for sequence_length in sequence_lengths:\n",
    "    # features: the sequence_length ticks that precede the label tick of each day\n",
    "    torch_feat = torch.from_numpy(all_traces[:, :, -sequence_length-1:-1].transpose(1, 0, 2)).float()\n",
    "\n",
    "    # construct the dgl graph and save\n",
    "    graph = dgl.graph((coo_adj.row, coo_adj.col), num_nodes=num_nodes)\n",
    "    graph.scale_stats = scale_stats\n",
    "    graph.edata['weight'] = torch.from_numpy(coo_adj.data)\n",
    "    graph.num_of_days = num_of_days\n",
    "    graph.ndata['all_label'] = torch_label\n",
    "    graph.ndata['all_feat'] = torch_feat\n",
    "    # the un-suffixed keys hold day 0\n",
    "    graph.ndata['label'] = torch_label[:, 0]\n",
    "    graph.ndata['feat'] = torch_feat[:, 0, :]\n",
    "\n",
    "    for i in range(num_of_days):\n",
    "        graph.ndata[f'label_{i}'] = torch_label[:, i]\n",
    "        graph.ndata[f'feat_{i}'] = torch_feat[:, i, :]\n",
    "\n",
    "    # create the split index\n",
    "    split_dataset(graph, seed=666)\n",
    "\n",
    "    # save the final results: the graph itself plus a side-car pickle with the\n",
    "    # attributes set on the graph object\n",
    "    graph_save_path = os.path.join(data_root_path, f\"dgl_day_{num_of_days}_seq_{sequence_length}.bin\")\n",
    "    dgl.save_graphs(graph_save_path, graph)\n",
    "    graph_stats_save_path = os.path.join(data_root_path, f\"dgl_day_{num_of_days}_seq_{sequence_length}_stats.pkl\")\n",
    "    with open(graph_stats_save_path, 'wb') as fout:\n",
    "        pickle.dump([graph.num_of_days, graph.scale_stats, graph.train_idx, graph.val_idx, graph.test_idx], fout)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "gnn",
   "language": "python",
   "name": "gnn"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
