{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "496a5517",
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from torch.optim import Adam\n",
    "from tqdm import tqdm\n",
    "from dataclasses import dataclass\n",
    "import os\n",
    "import argparse\n",
    "from typing import Dict, Optional, Tuple, List\n",
    "from torch.optim.lr_scheduler import ReduceLROnPlateau\n",
    "from scipy.stats import norm\n",
    "from torch_timeseries.nn.Transformer_EncDec import Decoder, DecoderLayer, Encoder, EncoderLayer\n",
    "from torch_timeseries.nn.SelfAttention_Family import DSAttention, AttentionLayer\n",
    "from torch_timeseries.nn.embedding import DataEmbedding\n",
    "import wandb\n",
    "import yaml\n",
    "import math\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from torch.optim import Adam\n",
    "from tqdm import tqdm\n",
    "from dataclasses import dataclass\n",
    "import os\n",
    "import argparse\n",
    "from typing import Dict, Optional, Tuple, List\n",
    "from torch.optim.lr_scheduler import ReduceLROnPlateau\n",
    "from scipy.stats import norm\n",
    "from torch_timeseries.nn.Transformer_EncDec import Decoder, DecoderLayer, Encoder, EncoderLayer\n",
    "from torch_timeseries.nn.SelfAttention_Family import DSAttention, AttentionLayer\n",
    "from torch_timeseries.nn.embedding import DataEmbedding\n",
    "import wandb\n",
    "\n",
    "from torchmetrics import Metric\n",
    "import CRPS.CRPS as pscore  # Assuming `pscore` is the function to compute CRPS\n",
    "from concurrent.futures import ProcessPoolExecutor\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "from torch_timeseries.nn.Transformer_EncDec import Decoder, DecoderLayer, Encoder, EncoderLayer\n",
    "from torch_timeseries.nn.SelfAttention_Family import DSAttention, AttentionLayer\n",
    "from torch_timeseries.nn.embedding import DataEmbedding\n",
    "\n",
    "from dataclasses import dataclass, field\n",
    "import sys\n",
    "from typing import List, Dict\n",
    "import os\n",
    "import torch\n",
    "from dataclasses import dataclass, asdict, field\n",
    "from torch_timeseries.nn.embedding import freq_map\n",
    "import argparse\n",
    "from torchmetrics import MeanAbsoluteError, MeanSquaredError, MetricCollection\n",
    "from torch.optim import *\n",
    "from tqdm import tqdm\n",
    "from torch_timeseries.utils.model_stats import count_parameters\n",
    "from torch_timeseries.utils.reproduce import reproducible\n",
    "import time\n",
    "# import multiprocessing\n",
    "import torch.multiprocessing as mp\n",
    "from torch_timeseries.utils.parse_type import parse_type\n",
    "\n",
    "from torch_timeseries.utils.early_stop import EarlyStopping\n",
    "import yaml\n",
    "import numpy as np\n",
    "import torch.distributed as dist\n",
    "import torch\n",
    "from tqdm import tqdm\n",
    "import concurrent.futures\n",
    "from types import SimpleNamespace\n",
    "\n",
    "from dataclasses import asdict, dataclass\n",
    "import datetime\n",
    "import hashlib\n",
    "import json\n",
    "import os\n",
    "import random\n",
    "import time\n",
    "from typing import Dict, List, Type, Union\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import torch\n",
    "from torchmetrics import MeanAbsoluteError, MeanSquaredError, MetricCollection\n",
    "from tqdm import tqdm\n",
    "from torch.nn import MSELoss, L1Loss\n",
    "from torch.optim import *\n",
    "from torch_timeseries.dataset import *\n",
    "from torch_timeseries.scaler import *\n",
    "\n",
    "from torch_timeseries.utils.model_stats import count_parameters\n",
    "from torch_timeseries.utils.early_stop import EarlyStopping\n",
    "from torch_timeseries.utils.parse_type import parse_type\n",
    "from torch_timeseries.utils.reproduce import reproducible\n",
    "from torch_timeseries.core import TimeSeriesDataset, BaseIrrelevant, BaseRelevant\n",
    "from torch_timeseries.dataloader import SlidingWindowTS, ETTHLoader, ETTMLoader\n",
    "from torch_timeseries.experiments import ForecastExp\n",
    "from torch_timeseries.utils import asdict_exc\n",
    "import torch.multiprocessing as mp\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import torch\n",
    "\n",
    "from torch_timeseries.utils.parse_type import parse_type\n",
    "from torch_timeseries.dataloader import ETTHLoader\n",
    "\n",
    "import torch\n",
    "import numpy as np\n",
    "from sklearn.linear_model import LinearRegression\n",
    "from sklearn.metrics import r2_score\n",
    "from xgboost import XGBRegressor\n",
    "\n",
    "\n",
    "import math\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from torch.optim import Adam\n",
    "from tqdm import tqdm\n",
    "from dataclasses import dataclass\n",
    "import os\n",
    "import argparse\n",
    "from typing import Dict, Optional, Tuple, List\n",
    "from torch.optim.lr_scheduler import ReduceLROnPlateau\n",
    "from scipy.stats import norm\n",
    "from torch_timeseries.nn.Transformer_EncDec import Decoder, DecoderLayer, Encoder, EncoderLayer\n",
    "from torch_timeseries.nn.SelfAttention_Family import DSAttention, AttentionLayer\n",
    "from torch_timeseries.nn.embedding import DataEmbedding, TokenEmbedding, TemporalEmbedding, TimeFeatureEmbedding\n",
    "import wandb\n",
    "\n",
    "from torchmetrics import Metric\n",
    "import CRPS.CRPS as pscore  # Assuming `pscore` is the function to compute CRPS\n",
    "from concurrent.futures import ProcessPoolExecutor\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "from torch_timeseries.nn.Transformer_EncDec import Decoder, DecoderLayer, Encoder, EncoderLayer\n",
    "from torch_timeseries.nn.SelfAttention_Family import DSAttention, AttentionLayer\n",
    "from torch_timeseries.nn.embedding import DataEmbedding\n",
    "\n",
    "from dataclasses import dataclass, field\n",
    "import sys\n",
    "from typing import List, Dict\n",
    "import os\n",
    "import torch\n",
    "from dataclasses import dataclass, asdict, field\n",
    "from torch_timeseries.nn.embedding import freq_map\n",
    "import argparse\n",
    "from torchmetrics import MeanAbsoluteError, MeanSquaredError, MetricCollection\n",
    "from torch.optim import *\n",
    "from tqdm import tqdm\n",
    "from torch_timeseries.utils.model_stats import count_parameters\n",
    "from torch_timeseries.utils.reproduce import reproducible\n",
    "import time\n",
    "# import multiprocessing\n",
    "import torch.multiprocessing as mp\n",
    "from torch_timeseries.utils.parse_type import parse_type\n",
    "\n",
    "from torch_timeseries.utils.early_stop import EarlyStopping\n",
    "import yaml\n",
    "import numpy as np\n",
    "import torch.distributed as dist\n",
    "import torch\n",
    "from tqdm import tqdm\n",
    "import concurrent.futures\n",
    "from types import SimpleNamespace\n",
    "\n",
    "from dataclasses import asdict, dataclass\n",
    "import datetime\n",
    "import hashlib\n",
    "import json\n",
    "import os\n",
    "import random\n",
    "import time\n",
    "from typing import Dict, List, Type, Union\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import torch\n",
    "from torchmetrics import MeanAbsoluteError, MeanSquaredError, MetricCollection\n",
    "from tqdm import tqdm\n",
    "from torch.nn import MSELoss, L1Loss\n",
    "from torch.optim import *\n",
    "from torch_timeseries.dataset import *\n",
    "from torch_timeseries.scaler import *\n",
    "\n",
    "from torch_timeseries.utils.model_stats import count_parameters\n",
    "from torch_timeseries.utils.early_stop import EarlyStopping\n",
    "from torch_timeseries.utils.parse_type import parse_type\n",
    "from torch_timeseries.utils.reproduce import reproducible\n",
    "from torch_timeseries.core import TimeSeriesDataset, BaseIrrelevant, BaseRelevant\n",
    "from torch_timeseries.dataloader import SlidingWindowTS, ETTHLoader, ETTMLoader\n",
    "from torch_timeseries.experiments import ForecastExp\n",
    "from torch_timeseries.utils import asdict_exc\n",
    "import torch.multiprocessing as mp\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import torch\n",
    "\n",
    "from torch_timeseries.utils.parse_type import parse_type\n",
    "from torch_timeseries.dataloader import ETTHLoader\n",
    "from torchmetrics import MetricCollection\n",
    "import torch\n",
    "import numpy as np\n",
    "from sklearn.linear_model import LinearRegression\n",
    "from sklearn.metrics import r2_score\n",
    "from xgboost import XGBRegressor\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ace65b4d",
   "metadata": {},
   "outputs": [],
   "source": [
    "device = \"cuda:3\"\n",
    "seed_idx = 0 # please manually try 0-9. \n",
    "def set_seed(seed):\n",
    "    random.seed(seed)\n",
    "    np.random.seed(seed)\n",
    "    torch.manual_seed(seed)\n",
    "    if device:\n",
    "        torch.cuda.manual_seed_all(seed)\n",
    "        torch.backends.cudnn.deterministic = True\n",
    "        torch.backends.cudnn.benchmark = False\n",
    "\n",
    "set_seed(114 + seed_idx)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "fd2b53a8",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>date</th>\n",
       "      <th>HUFL</th>\n",
       "      <th>HULL</th>\n",
       "      <th>MUFL</th>\n",
       "      <th>MULL</th>\n",
       "      <th>LUFL</th>\n",
       "      <th>LULL</th>\n",
       "      <th>OT</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>2016-07-01 00:00:00</td>\n",
       "      <td>5.827</td>\n",
       "      <td>2.009</td>\n",
       "      <td>1.599</td>\n",
       "      <td>0.462</td>\n",
       "      <td>4.203</td>\n",
       "      <td>1.340</td>\n",
       "      <td>30.531000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>2016-07-01 01:00:00</td>\n",
       "      <td>5.693</td>\n",
       "      <td>2.076</td>\n",
       "      <td>1.492</td>\n",
       "      <td>0.426</td>\n",
       "      <td>4.142</td>\n",
       "      <td>1.371</td>\n",
       "      <td>27.787001</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>2016-07-01 02:00:00</td>\n",
       "      <td>5.157</td>\n",
       "      <td>1.741</td>\n",
       "      <td>1.279</td>\n",
       "      <td>0.355</td>\n",
       "      <td>3.777</td>\n",
       "      <td>1.218</td>\n",
       "      <td>27.787001</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>2016-07-01 03:00:00</td>\n",
       "      <td>5.090</td>\n",
       "      <td>1.942</td>\n",
       "      <td>1.279</td>\n",
       "      <td>0.391</td>\n",
       "      <td>3.807</td>\n",
       "      <td>1.279</td>\n",
       "      <td>25.044001</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>2016-07-01 04:00:00</td>\n",
       "      <td>5.358</td>\n",
       "      <td>1.942</td>\n",
       "      <td>1.492</td>\n",
       "      <td>0.462</td>\n",
       "      <td>3.868</td>\n",
       "      <td>1.279</td>\n",
       "      <td>21.948000</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                  date   HUFL   HULL   MUFL   MULL   LUFL   LULL         OT\n",
       "0  2016-07-01 00:00:00  5.827  2.009  1.599  0.462  4.203  1.340  30.531000\n",
       "1  2016-07-01 01:00:00  5.693  2.076  1.492  0.426  4.142  1.371  27.787001\n",
       "2  2016-07-01 02:00:00  5.157  1.741  1.279  0.355  3.777  1.218  27.787001\n",
       "3  2016-07-01 03:00:00  5.090  1.942  1.279  0.391  3.807  1.279  25.044001\n",
       "4  2016-07-01 04:00:00  5.358  1.942  1.492  0.462  3.868  1.279  21.948000"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "import os\n",
    "import pandas as pd\n",
    "\n",
    "# Raw ETTh1 CSV, only used for a quick preview of the data. Kept under its own\n",
    "# name because a later cell sets `data_path` to the dataset root directory.\n",
    "raw_csv_path = \"ts_datasets/ETTh1/ETTh1.csv\"\n",
    "assert os.path.exists(raw_csv_path), f\"file not found: {raw_csv_path}\"\n",
    "\n",
    "df = pd.read_csv(raw_csv_path)\n",
    "display(df.head())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "c33b47f8",
   "metadata": {},
   "outputs": [],
   "source": [
    "config = {\n",
    "    \"teacher_force\": False,\n",
    "    \"mixup\": False,\n",
    "\n",
    "    \"matrix_norm_weight\": [(7*192)**0.5 * 1e-1, 1, 0.],\n",
    "    \"fft_weight\": [1, 0.],\n",
    "    \"eign_penalty\": 50,\n",
    "    \"eps_eign_min\": 1e-1,\n",
    "    \"penalty_method\": \"hard\",\n",
    "    'num_training_steps': 20,\n",
    "\n",
    "    \"d_model\": 512,\n",
    "    \"n_heads\": 8,\n",
    "    \"e_layers\": 2,\n",
    "    \"d_layers\": 1,\n",
    "    \"d_ff\": 1024,\n",
    "    \"factor\": 3,\n",
    "    \"dropout\": 0.1,\n",
    "\n",
    "    \"p_hidden_layers\": 2,\n",
    "    \"p_hidden_dims\": [128, 128],\n",
    "\n",
    "    \"windows\": 168,\n",
    "    \"horizon\": 1,\n",
    "    \"pred_len\": 192,\n",
    "    \"label_len\": 168 // 2,\n",
    "    'num_features': 7,\n",
    "\n",
    "    \"batch_size\": 64,\n",
    "    \"num_worker\": 0,\n",
    "    'dataset_type': \"ETTh1\",\n",
    "    'data_path': \"ts_datasets\",\n",
    "    'scaler_type': \"StandardScaler\",\n",
    "\n",
    "    'lr': 1e-4,\n",
    "    'weight_decay': 5e-4,\n",
    "\n",
    "    'window_size': 95,\n",
    "    'pad_mode': 'reflect',\n",
    "}\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "3703c759",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Unpack model hyper-parameters from `config` into module-level names;\n",
    "# the model class below uses them as default argument values.\n",
    "d_model, n_heads = config['d_model'], config['n_heads']\n",
    "e_layers, d_layers = config['e_layers'], config['d_layers']\n",
    "d_ff, factor, dropout = config['d_ff'], config['factor'], config['dropout']\n",
    "p_hidden_layers, p_hidden_dims = config['p_hidden_layers'], config['p_hidden_dims']\n",
    "num_features = config['num_features']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "ad860619",
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset_type = config['dataset_type']\n",
    "data_path = config['data_path']  # dataset root directory\n",
    "\n",
    "# parse_type presumably resolves the class name against globals() (the star\n",
    "# imports from torch_timeseries.dataset / torch_timeseries.scaler) -- verify.\n",
    "DatasetClass = parse_type(dataset_type, globals())\n",
    "dataset = DatasetClass(root=data_path)\n",
    "\n",
    "scaler_type = config['scaler_type']\n",
    "ScalerClass = parse_type(scaler_type, globals())\n",
    "scaler = ScalerClass()\n",
    "\n",
    "windows = config['windows']\n",
    "horizon = config['horizon']\n",
    "pred_len = config['pred_len']\n",
    "batch_size = config['batch_size']\n",
    "num_worker = config['num_worker']\n",
    "# Take label_len from config; it was previously recomputed as windows // 2,\n",
    "# which silently ignored config['label_len'].\n",
    "label_len = config['label_len']\n",
    "\n",
    "\n",
    "dataloader = ETTHLoader(\n",
    "    dataset,\n",
    "    scaler,\n",
    "    window=windows,\n",
    "    horizon=horizon,\n",
    "    steps=pred_len,\n",
    "    shuffle_train=True,\n",
    "    freq=dataset.freq,\n",
    "    batch_size=batch_size,\n",
    "    num_worker=num_worker,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "1085f05b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# non stationary transformer\n",
    "class Projector(nn.Module):\n",
    "    '''\n",
    "    MLP to learn the De-stationary factors\n",
    "    '''\n",
    "    def __init__(self, enc_in, seq_len, hidden_dims, hidden_layers, output_dim, kernel_size=3):\n",
    "        super(Projector, self).__init__()\n",
    "\n",
    "        padding = 1 if torch.__version__ >= '1.5.0' else 2\n",
    "        self.series_conv = nn.Conv1d(in_channels=seq_len, out_channels=1, kernel_size=kernel_size, padding=padding,\n",
    "                                     padding_mode='circular', bias=False)\n",
    "\n",
    "        layers = [nn.Linear(2 * enc_in, hidden_dims[0]), nn.ReLU()]\n",
    "        for i in range(hidden_layers - 1):\n",
    "            layers += [nn.Linear(hidden_dims[i], hidden_dims[i + 1]), nn.ReLU()]\n",
    "\n",
    "        layers += [nn.Linear(hidden_dims[-1], output_dim, bias=False)]\n",
    "        self.backbone = nn.Sequential(*layers)\n",
    "\n",
    "\n",
    "    def forward(self, x, stats):\n",
    "        batch_size = x.shape[0]\n",
    "        x = self.series_conv(x)  # B x 1 x D\n",
    "        x = torch.cat([x, stats], dim=1)  # B x 2 x D\n",
    "        x = x.view(batch_size, -1)  # B x 2D\n",
    "        y = self.backbone(x)  # B x output_dim\n",
    "\n",
    "        return y\n",
    "\n",
    "\n",
    "class ns_Transformer(nn.Module):\n",
    "    \"\"\"\n",
    "    Non-stationary Transformer with a mean + covariance output head.\n",
    "\n",
    "    The decoder emits c_out = D + D*(D+1)/2 channels per step (D = num_features):\n",
    "    the first D are the de-normalized mean, the rest are the flattened upper\n",
    "    triangle of a factor U (diagonal made positive with softplus), and the\n",
    "    predicted covariance is Sigma = U^T U.\n",
    "\n",
    "    NOTE: default argument values are read from notebook globals when the class\n",
    "    is defined (pred_len, windows, label_len, dataloader, ...).\n",
    "    \"\"\"\n",
    "    def __init__(self, \n",
    "                 pred_len=pred_len,\n",
    "                 seq_len=windows,\n",
    "                 label_len=label_len,\n",
    "                 output_attention=False,\n",
    "                 num_features = num_features,\n",
    "                 enc_in=num_features + int(num_features*(num_features+1)/2),\n",
    "                 d_model=d_model,\n",
    "                 embed='timeF',\n",
    "                 freq=dataloader.dataset.freq,\n",
    "                 dropout=dropout,\n",
    "                 dec_in=num_features + int(num_features*(num_features+1)/2),\n",
    "                 factor=factor,\n",
    "                 n_heads=n_heads,\n",
    "                 d_ff=d_ff,\n",
    "                 e_layers=e_layers,\n",
    "                 d_layers=d_layers,\n",
    "                 c_out=num_features + int(num_features*(num_features+1)/2),\n",
    "                 p_hidden_dims=p_hidden_dims,\n",
    "                 p_hidden_layers=p_hidden_layers,\n",
    "                 activation = nn.SiLU(),\n",
    "                 kernel_size = 3,\n",
    "                 ):\n",
    "        super(ns_Transformer, self).__init__()\n",
    "        self.pred_len = pred_len \n",
    "        self.seq_len = seq_len \n",
    "        self.label_len = label_len \n",
    "        self.output_attention = output_attention \n",
    "        self.num_feature = num_features\n",
    "        self.num_feature_triangle = int(num_features*(num_features+1)/2)\n",
    "\n",
    "        self.enc_embedding = DataEmbedding(enc_in, d_model, embed, freq,\n",
    "                                           dropout) \n",
    "        self.dec_embedding = DataEmbedding(dec_in, d_model, embed, freq,\n",
    "                                           dropout) \n",
    "\n",
    "        self.encoder = Encoder(\n",
    "            [\n",
    "                EncoderLayer(\n",
    "                    AttentionLayer(\n",
    "                        DSAttention(False, factor, attention_dropout=dropout,\n",
    "                                    output_attention=output_attention), d_model, n_heads),\n",
    "                    d_model,\n",
    "                    d_ff,\n",
    "                    dropout=dropout,\n",
    "                    activation=activation\n",
    "                ) for l in range(e_layers)\n",
    "            ],\n",
    "            norm_layer=torch.nn.LayerNorm(d_model)\n",
    "        )\n",
    "\n",
    "        self.decoder = Decoder(\n",
    "            [\n",
    "                DecoderLayer(\n",
    "                    AttentionLayer(\n",
    "                        DSAttention(True, factor, attention_dropout=dropout, output_attention=False),\n",
    "                        d_model, n_heads),\n",
    "                    AttentionLayer(\n",
    "                        DSAttention(False, factor, attention_dropout=dropout, output_attention=False),\n",
    "                        d_model, n_heads),\n",
    "                    d_model,\n",
    "                    d_ff,\n",
    "                    dropout=dropout,\n",
    "                    activation=activation,\n",
    "                )\n",
    "                for l in range(d_layers)\n",
    "            ],\n",
    "            norm_layer=torch.nn.LayerNorm(d_model),\n",
    "            projection=nn.Linear(d_model, c_out, bias=True)\n",
    "        )\n",
    "\n",
    "        self.tau_learner = Projector(enc_in=enc_in, seq_len=seq_len, hidden_dims=p_hidden_dims,\n",
    "                                     hidden_layers=p_hidden_layers, output_dim=1, kernel_size = kernel_size)\n",
    "        self.delta_learner = Projector(enc_in=enc_in, seq_len=seq_len,\n",
    "                                       hidden_dims=p_hidden_dims, hidden_layers=p_hidden_layers,\n",
    "                                       output_dim=seq_len, kernel_size = kernel_size)\n",
    "        # maps future targets along time (pred_len -> seq_len) for the optional mixup\n",
    "        self.future_mixup_layer = nn.Linear(self.pred_len,self.seq_len)\n",
    "\n",
    "    def unpack_cholesky_upper(self, flat_triu):\n",
    "        \"\"\"\n",
    "        Rebuild upper-triangular factors U from their flattened entries.\n",
    "\n",
    "        flat_triu: B x T x D(D+1)/2, entries in torch.triu_indices(D, D) order.\n",
    "        returns:   B x T x D x D upper-triangular U whose diagonal is passed\n",
    "                   through softplus so it is positive.\n",
    "        \"\"\"\n",
    "        B, T, _ = flat_triu.shape\n",
    "        D = self.num_feature\n",
    "        # match the input dtype: index assignment requires equal source/destination dtypes\n",
    "        U = torch.zeros(B, T, D, D, device=flat_triu.device, dtype=flat_triu.dtype)\n",
    "        triu_idx = torch.triu_indices(D, D, device=flat_triu.device)\n",
    "        U[:, :, triu_idx[0], triu_idx[1]] = flat_triu\n",
    "        U[:, :, range(D), range(D)] = F.softplus(U[:, :, range(D), range(D)])  # a small positive constant could be added here; numerical stability was fine without it\n",
    "        return U\n",
    "\n",
    "    def forward(self, x_enc, x_enc_xxT_trig, x_mark_enc, x_dec, x_mark_dec,\n",
    "                enc_self_mask=None, dec_self_mask=None, dec_enc_mask=None,\n",
    "                future_mixup_weight = 1, batch_y=None, batch_yyT_trig=None):\n",
    "        \"\"\"\n",
    "        x_enc (Tensor): input series, presumably (B, seq_len, num_features)\n",
    "        x_enc_xxT_trig (Tensor): extra features concatenated to x_enc on the last\n",
    "            dim; together they must have enc_in channels\n",
    "        x_mark_enc (Tensor): input time features of shape (B, seq_len, d)\n",
    "        x_dec (Tensor): only used through zeros_like of its last pred_len steps, so\n",
    "            its last dim must equal the concatenated encoder channel count\n",
    "        x_mark_dec (Tensor): decoder time features (assumed to cover\n",
    "            label_len + pred_len steps -- TODO confirm)\n",
    "        batch_y, batch_yyT_trig (Tensor, optional): future values of length pred_len;\n",
    "            when both are given, the encoder input is blended with a linear\n",
    "            projection of them using future_mixup_weight.\n",
    "\n",
    "        Returns (mu, Sigma, extra): mu (B, pred_len, D), Sigma (B, pred_len, D, D),\n",
    "        and the attention maps if output_attention else the full decoder output.\n",
    "        \"\"\"\n",
    "        x_enc = torch.cat([x_enc,x_enc_xxT_trig],dim=-1)\n",
    "\n",
    "        if (batch_y is not None) and (batch_yyT_trig is not None):\n",
    "            batch_y = torch.cat([batch_y,batch_yyT_trig],dim=-1)\n",
    "            x_enc = future_mixup_weight * x_enc + (1 - future_mixup_weight) * self.future_mixup_layer(batch_y.permute(0, 2, 1)).permute(0, 2, 1)\n",
    "        x_raw = x_enc.clone().detach()\n",
    "\n",
    "        # Normalization\n",
    "        mean_enc = x_enc.mean(1, keepdim=True).detach()  # B x 1 x E\n",
    "        x_enc = x_enc - mean_enc\n",
    "        std_enc = torch.sqrt(torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5).detach()  # B x 1 x E\n",
    "        x_enc = x_enc / std_enc\n",
    "\n",
    "        # decoder input: last label_len normalized encoder steps followed by pred_len zeros\n",
    "        x_dec_new = torch.cat([x_enc[:, -self.label_len:, :], torch.zeros_like(x_dec[:, -self.pred_len:, :])],\n",
    "                              dim=1).to(x_enc.device).clone()\n",
    "\n",
    "        tau = self.tau_learner(x_raw, std_enc).exp()  # B x S x E, B x 1 x E -> B x 1, positive scalar\n",
    "        delta = self.delta_learner(x_raw, mean_enc)  # B x S x E, B x 1 x E -> B x S\n",
    "\n",
    "        enc_out = self.enc_embedding(x_enc, x_mark_enc)\n",
    "        enc_out, attns = self.encoder(enc_out, attn_mask=enc_self_mask, tau=tau, delta=delta)\n",
    "\n",
    "        dec_out = self.dec_embedding(x_dec_new, x_mark_dec)\n",
    "        dec_out = self.decoder(dec_out, enc_out, x_mask=dec_self_mask, cross_mask=dec_enc_mask, tau=tau, delta=delta)\n",
    "\n",
    "        # De-normalization of the mean channels only. Use this instance's feature\n",
    "        # count, not the notebook-global `num_features`.\n",
    "        D = self.num_feature\n",
    "        dec_out[:, :, :D] = dec_out[:, :, :D] * std_enc[:, :, :D] + mean_enc[:, :, :D]\n",
    "\n",
    "        miu_pred = dec_out[:, :, :D]\n",
    "        U_flat = dec_out[:, :, D:]\n",
    "\n",
    "        # Sigma = U^T U is symmetric positive semi-definite by construction\n",
    "        U = self.unpack_cholesky_upper(U_flat)\n",
    "        Sigma_pred = U.transpose(-1, -2) @ U\n",
    "\n",
    "        if self.output_attention:\n",
    "            return miu_pred[:, -self.pred_len:, :], Sigma_pred[:, -self.pred_len:, :, :], attns\n",
    "        else:\n",
    "\n",
    "            return miu_pred[:, -self.pred_len:, :], Sigma_pred[:, -self.pred_len:, :, :], dec_out  # [B, L, D]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "c7e7e436",
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_batch_xxT(batch_x):\n",
    "    B, T, D = batch_x.shape\n",
    "    device = batch_x.device\n",
    "    batch_xxT = batch_x.unsqueeze(-1) @ batch_x.unsqueeze(-2) # [B, T, D, D]\n",
    "\n",
    "    triu_indices = torch.triu_indices(D, D)\n",
    "    #  shape [B, T, D*(D+1)//2]\n",
    "    batch_upper_triangular = batch_xxT[:, :, triu_indices[0], triu_indices[1]]\n",
    "    return batch_xxT, batch_upper_triangular\n",
    "\n",
    "\n",
    "def sliding_cov(batch_x, window_size, pad_mode='reflect'):\n",
    "    \"\"\"\n",
    "    Args:\n",
    "        batch_x: (B, T, D)\n",
    "        window_size: int\n",
    "        pad_mode: 'reflect' or 'replicate'\n",
    "    Returns:\n",
    "        moving_cov: (B, T, D, D)\n",
    "        moving_cov_tri: (B, T, D * (D + 1) // 2)\n",
    "    \"\"\"\n",
    "    B, T, D = batch_x.shape\n",
    "    pad_left = window_size // 2\n",
    "    pad_right = window_size - 1 - pad_left\n",
    "\n",
    "    # 1. Pad\n",
    "    x = batch_x.permute(0, 2, 1)  # (B, D, T)\n",
    "    x_padded = F.pad(x, (pad_left, pad_right), mode=pad_mode)  # (B, D, T + pad)\n",
    "\n",
    "    # 2. Build Conv1d kernel\n",
    "    avg_kernel = torch.ones(D, 1, window_size, device=batch_x.device) / window_size\n",
    "\n",
    "    # 3. Sliding mean: (B, D, T)\n",
    "    mean_x = F.conv1d(x_padded, avg_kernel, groups=D)  # depthwise conv\n",
    "    mean_x = mean_x.permute(0, 2, 1)  # (B, T, D)\n",
    "\n",
    "    # 4. Compute XX^T for each t: (B, T, D, D)\n",
    "    xxT = batch_x.unsqueeze(3) @ batch_x.unsqueeze(2)  # (B, T, D, D)\n",
    "    xxT_flat = xxT.reshape(B, T, D * D).permute(0, 2, 1)  # (B, D*D, T)\n",
    "\n",
    "    # 5. Pad and Conv1d for mean of XX^T\n",
    "    xxT_padded = F.pad(xxT_flat, (pad_left, pad_right), mode=pad_mode)\n",
    "    avg_kernel_xxT = torch.ones(D * D, 1, window_size, device=batch_x.device) / window_size\n",
    "    mean_xxT = F.conv1d(xxT_padded, avg_kernel_xxT, groups=D * D)  # (B, D*D, T)\n",
    "    mean_xxT = mean_xxT.permute(0, 2, 1).reshape(B, T, D, D)  # (B, T, D, D)\n",
    "\n",
    "    # 6. Covariance: E[XX^T] - E[X]E[X]^T\n",
    "    mean_x_outer = mean_x.unsqueeze(3) @ mean_x.unsqueeze(2)  # (B, T, D, D)\n",
    "    cov = mean_xxT - mean_x_outer  # (B, T, D, D)\n",
    "\n",
    "    # 7. Extract upper triangular part\n",
    "    idx = torch.triu_indices(D, D, offset=0, device=batch_x.device)\n",
    "    cov_tri = cov[:, :, idx[0], idx[1]]  # (B, T, D*(D+1)/2)\n",
    "\n",
    "    return cov, cov_tri\n",
    "\n",
    "\n",
    "def cacf_torch(x, max_lag, dim=(0, 1)):\n",
    "    def get_lower_triangular_indices(n):\n",
    "        return [list(x) for x in torch.tril_indices(n, n)]\n",
    "\n",
    "    ind = get_lower_triangular_indices(x.shape[2])\n",
    "    x = (x - x.mean(dim, keepdims=True)) / x.std(dim, keepdims=True)\n",
    "    x_l = x[..., ind[0]]\n",
    "    x_r = x[..., ind[1]]\n",
    "    cacf_list = list()\n",
    "    for i in range(max_lag):\n",
    "        y = x_l[:, i:] * x_r[:, :-i] if i > 0 else x_l * x_r\n",
    "        cacf_i = torch.mean(y, (1))\n",
    "        cacf_list.append(cacf_i)\n",
    "    cacf = torch.cat(cacf_list, 1)\n",
    "    return cacf.reshape(cacf.shape[0], -1, len(ind[0]))\n",
    "\n",
    "def lower_triangular_to_full_matrix(values, dim):\n",
    "    \"\"\"\n",
    "    values: (B, 1, num_pairs)\n",
    "    dim: D\n",
    "    return: (B, D, D)\n",
    "    \"\"\"\n",
    "    B = values.shape[0]\n",
    "    num_pairs = values.shape[2]\n",
    "    idx = torch.tril_indices(row=dim, col=dim, offset=0, device=values.device)\n",
    "\n",
    "    mat = torch.zeros(B, dim, dim, device=values.device)\n",
    "\n",
    "    mat[:, idx[0], idx[1]] = values.squeeze(1)\n",
    "\n",
    "    mat = mat + mat.transpose(1,2) - torch.diag_embed(torch.diagonal(mat, dim1=1, dim2=2))\n",
    "    return mat\n",
    "\n",
    "\n",
    "def upper_triangular_to_full_matrix(values, dim):\n",
    "\n",
    "    B = values.shape[0]\n",
    "    num_pairs = values.shape[2]\n",
    "    idx = torch.triu_indices(row=dim, col=dim, offset=0, device=values.device)\n",
    "\n",
    "    mat = torch.zeros(B, dim, dim, device=values.device)\n",
    "\n",
    "    mat[:, idx[0], idx[1]] = values.squeeze(1)\n",
    "\n",
    "    mat = mat + mat.transpose(1, 2) - torch.diag_embed(torch.diagonal(mat, dim1=1, dim2=2))\n",
    "\n",
    "    return mat\n",
    "\n",
    "def compute_corr_score(batch_y,batch_y_cw):\n",
    "    \"\"\"\n",
    "    Lag-0 cross-correlation matrices of two batches of series.\n",
    "\n",
    "    Args:\n",
    "        batch_y, batch_y_cw: (B, T, D) tensors\n",
    "    Returns:\n",
    "        (batch_y_corr_score, batch_y_cw_corr_score), each (B, D, D)\n",
    "    \"\"\"\n",
    "    # D comes from the inputs rather than the notebook-global `num_features`\n",
    "    batch_y_corr_score = lower_triangular_to_full_matrix(cacf_torch(x=batch_y, max_lag=1), dim=batch_y.shape[-1])\n",
    "    batch_y_cw_corr_score = lower_triangular_to_full_matrix(cacf_torch(x=batch_y_cw, max_lag=1), dim=batch_y_cw.shape[-1])\n",
    "    return batch_y_corr_score, batch_y_cw_corr_score\n",
    "\n",
    "\n",
    "def average_r2_correlation_metric(X, normalize = True, method = 'linear'):\n",
    "    \"\"\"\n",
    "    X: Tensor of shape (B, T, D)\n",
    "    Returns:\n",
    "        avg_r2: float\n",
    "        r2_scores: list of R² per variable\n",
    "    \"\"\"\n",
    "    B, T, D = X.shape\n",
    "    # reshape to (B*T, D)\n",
    "    data = X.reshape(B*T, D).cpu().numpy()\n",
    "\n",
    "    # standardize each variable\n",
    "    if normalize:\n",
    "        data_mean = data.mean(axis=0, keepdims=True)\n",
    "        data_std = data.std(axis=0, keepdims=True) + 1e-8\n",
    "        data = (data - data_mean) / data_std\n",
    "\n",
    "    r2_scores = []\n",
    "    if method == 'linear':\n",
    "        for target_idx in range(D):\n",
    "            X_other = np.delete(data, target_idx, axis=1)  # shape (N, D-1)\n",
    "            y_target = data[:, target_idx]                # shape (N,)\n",
    "\n",
    "            model = LinearRegression().fit(X_other, y_target)\n",
    "            y_pred = model.predict(X_other)\n",
    "            r2 = r2_score(y_target, y_pred)\n",
    "            r2_scores.append(r2)\n",
    "    elif method == 'xgb':\n",
    "        for target_idx in range(D):\n",
    "            X_other = np.delete(data, target_idx, axis=1)  # (N, D-1)\n",
    "            y_target = data[:, target_idx]                 # (N,)\n",
    "\n",
    "            model = XGBRegressor(objective='reg:squarederror', n_estimators=100, max_depth=3, verbosity=0)\n",
    "            model.fit(X_other, y_target)\n",
    "            y_pred = model.predict(X_other)\n",
    "            r2 = r2_score(y_target, y_pred)\n",
    "            r2_scores.append(r2)\n",
    "\n",
    "    avg_r2 = np.mean(r2_scores)\n",
    "    return avg_r2, r2_scores"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "2ea9b862",
   "metadata": {},
   "outputs": [],
   "source": [
    "def soft_hinge_all_mean(eigvals, eps=1e-3, beta=5.0): # never use it\n",
    "    eps_t = torch.as_tensor(eps, dtype=eigvals.dtype, device=eigvals.device)\n",
    "    gaps  = eps_t - eigvals                 # [B,T,D]\n",
    "    softg = F.softplus(beta * gaps) / beta  \n",
    "    return softg.mean()                     \n",
    "\n",
    "\n",
    "def joint_loss_fn(theta_true, \n",
    "                  theta_outer_true, \n",
    "                  miu_pred, \n",
    "                  sigma_pred, \n",
    "                  batch_xxT_std,\n",
    "                  matrix_norm_weight = [1/3, 1/3, 1/3], \n",
    "                  fft_weight= [1/2, 1/2],\n",
    "                  eign_penalty=0.1,\n",
    "                  eps_eign_min = 1e-3,\n",
    "                  penalty_method = 'hard',\n",
    "                  verbose = False,):\n",
    "    \"\"\"\n",
    "    theta_true:       [B, T, D]\n",
    "    theta_outer_true: [B, T, D, D] (theta * theta^T)\n",
    "    miu_pred:         [B, T, D]\n",
    "    sigma_pred:       [B, T, D, D]\n",
    "    \"\"\"\n",
    "    loss_miu = F.mse_loss(miu_pred, theta_true)\n",
    "    fft_loss_miu = (torch.fft.rfft(miu_pred, dim=1) - torch.fft.rfft(theta_true, dim=1)).abs().mean()  # dont use it\n",
    "\n",
    "    diff = (sigma_pred - theta_outer_true) / batch_xxT_std  # actually we never normalize it. batch_xxT_std is always 1\n",
    "    loss_fro = diff.pow(2).mean()\n",
    "    svals = torch.linalg.svdvals(diff)  # [B, T, D]\n",
    "    loss_svd = svals.mean()\n",
    "    \n",
    "\n",
    "    if eign_penalty > 0:\n",
    "        cov_consistency = sigma_pred\n",
    "        eigvals = torch.linalg.eigvalsh(cov_consistency)\n",
    "        if penalty_method == 'hard':\n",
    "            posdef_penalty = torch.relu(eps_eign_min - eigvals).mean() # the penalty\n",
    "        elif penalty_method == 'soft':\n",
    "            posdef_penalty = soft_hinge_all_mean(eigvals, eps=eps_eign_min, beta=20.0)\n",
    "    else:\n",
    "        posdef_penalty = 0\n",
    "    \n",
    "    if verbose:\n",
    "        if eign_penalty > 0:\n",
    "            print(f'l2 loss:{loss_miu.item()}, f norm loss:{loss_fro.item()}, svd norm loss:{loss_svd.item()}, penalty:{posdef_penalty.item()}')\n",
    "        else:\n",
    "            print(f'l2 loss:{loss_miu.item()}, f norm loss:{loss_fro.item()}, svd norm loss:{loss_svd.item()}, penalty: not used')\n",
    "    fft_loss_cov = (torch.fft.rfft(diff, dim=1) ).abs().mean()  # dont use it\n",
    "    return (fft_weight[0] * loss_miu + fft_weight[1] * fft_loss_miu) \\\n",
    "        + (loss_fro * matrix_norm_weight[0] + loss_svd * matrix_norm_weight[1] +fft_loss_cov * matrix_norm_weight[2]) \\\n",
    "        + eign_penalty * posdef_penalty\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "def compute_neural_cov(output_theta_theta_T,output_conditional_mean):\n",
    "    \n",
    "    return output_theta_theta_T.detach()\n",
    "\n",
    "\n",
    "def whiten_sequence(theta, conditional_mean, cov_matrix, eps=1e-3, verbose=False):\n",
    "    \"\"\"\n",
    "    theta: [B, T, D]\n",
    "    conditional_mean: [B, T, D]\n",
    "    cov_matrix: [B, T, D, D]\n",
    "\n",
    "    Returns:\n",
    "        whitened: [B, T, D]\n",
    "        inv_sqrt_all: [B, T, D, D]\n",
    "        sqrt_all: [B, T, D, D]\n",
    "    \"\"\"\n",
    "    B, T, D = theta.shape\n",
    "    residual = theta - conditional_mean  # [B, T, D]\n",
    "    whitened = torch.zeros_like(residual)\n",
    "    inv_sqrt_all = torch.zeros(B, T, D, D, device=theta.device)\n",
    "    sqrt_all = torch.zeros(B, T, D, D, device=theta.device)\n",
    "\n",
    "    for b in range(B):\n",
    "        for t in range(T):\n",
    "            cov = cov_matrix[b, t]  # [D, D]\n",
    "\n",
    "            eigvals, eigvecs = torch.linalg.eigh(cov)\n",
    "\n",
    "            if verbose and (eigvals < eps).any():\n",
    "                print(f\"eigvals before clamp: {eigvals}\")\n",
    "\n",
    "            eigvals_clamped = torch.clamp(eigvals, min=eps)\n",
    "\n",
    "            inv_sqrt = eigvecs @ torch.diag(torch.rsqrt(eigvals_clamped)) @ eigvecs.T\n",
    "            inv_sqrt_all[b, t] = inv_sqrt\n",
    "            whitened[b, t] = inv_sqrt @ residual[b, t]\n",
    "\n",
    "            sqrt = eigvecs @ torch.diag(torch.sqrt(eigvals_clamped)) @ eigvecs.T\n",
    "            sqrt_all[b, t] = sqrt\n",
    "\n",
    "    return whitened, inv_sqrt_all, sqrt_all"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "406cc7da",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "7600955\n"
     ]
    }
   ],
   "source": [
    "# Seed via the project's set_seed helper before building the model.\n",
    "set_seed(514 + seed_idx)\n",
    "model_conditional_mean = ns_Transformer().float().to(device)\n",
    "\n",
    "# Total parameter count (all parameters, trainable or not), printed as a bare integer.\n",
    "print(sum(param.numel() for param in model_conditional_mean.parameters()))\n",
    "\n",
    "optimizer = torch.optim.AdamW(model_conditional_mean.parameters(),\n",
    "                              lr=config['lr'],\n",
    "                              weight_decay=config['weight_decay'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "81077f29",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0:Train=1.3337|Val=2.4430|Test=2.0056|whiten_score_y=0.8157|whiten_score_y_cw=0.6427|whiten_score_y_cent=0.7567\n",
      "1:Train=0.8459|Val=2.4596|Test=2.0213|whiten_score_y=0.8157|whiten_score_y_cw=0.6661|whiten_score_y_cent=0.7724\n",
      "2:Train=0.6553|Val=2.3925|Test=2.0181|whiten_score_y=0.8157|whiten_score_y_cw=0.6776|whiten_score_y_cent=0.8014\n",
      "3:Train=0.5428|Val=2.3557|Test=2.0628|whiten_score_y=0.8157|whiten_score_y_cw=0.6940|whiten_score_y_cent=0.8150\n",
      "4:Train=0.4830|Val=2.3821|Test=2.0459|whiten_score_y=0.8157|whiten_score_y_cw=0.6870|whiten_score_y_cent=0.8103\n",
      "5:Train=0.4436|Val=2.4031|Test=2.0716|whiten_score_y=0.8157|whiten_score_y_cw=0.6831|whiten_score_y_cent=0.8176\n",
      "6:Train=0.4168|Val=2.4486|Test=1.9787|whiten_score_y=0.8157|whiten_score_y_cw=0.6743|whiten_score_y_cent=0.8063\n",
      "7:Train=0.3947|Val=2.4639|Test=1.9484|whiten_score_y=0.8157|whiten_score_y_cw=0.6779|whiten_score_y_cent=0.8007\n",
      "8:Train=0.3780|Val=2.4528|Test=1.9355|whiten_score_y=0.8157|whiten_score_y_cw=0.6760|whiten_score_y_cent=0.7949\n",
      "9:Train=0.3645|Val=2.4578|Test=1.9624|whiten_score_y=0.8157|whiten_score_y_cw=0.6562|whiten_score_y_cent=0.7907\n",
      "10:Train=0.3508|Val=2.4228|Test=1.8851|whiten_score_y=0.8157|whiten_score_y_cw=0.6409|whiten_score_y_cent=0.7841\n",
      "11:Train=0.3374|Val=2.3981|Test=1.9316|whiten_score_y=0.8157|whiten_score_y_cw=0.6490|whiten_score_y_cent=0.7786\n",
      "12:Train=0.3256|Val=2.3648|Test=1.9362|whiten_score_y=0.8157|whiten_score_y_cw=0.6533|whiten_score_y_cent=0.7832\n",
      "13:Train=0.3147|Val=2.3346|Test=1.9047|whiten_score_y=0.8157|whiten_score_y_cw=0.6511|whiten_score_y_cent=0.7817\n",
      "14:Train=0.3035|Val=2.2472|Test=1.9070|whiten_score_y=0.8157|whiten_score_y_cw=0.6585|whiten_score_y_cent=0.7836\n",
      "15:Train=0.2936|Val=2.3336|Test=1.9267|whiten_score_y=0.8157|whiten_score_y_cw=0.6393|whiten_score_y_cent=0.7770\n",
      "16:Train=0.2843|Val=2.3350|Test=1.9104|whiten_score_y=0.8157|whiten_score_y_cw=0.6509|whiten_score_y_cent=0.7764\n",
      "17:Train=0.2748|Val=2.3255|Test=1.9161|whiten_score_y=0.8157|whiten_score_y_cw=0.6486|whiten_score_y_cent=0.7758\n",
      "18:Train=0.2673|Val=2.3427|Test=1.9562|whiten_score_y=0.8157|whiten_score_y_cw=0.6402|whiten_score_y_cent=0.7794\n",
      "19:Train=0.2599|Val=2.2808|Test=1.9253|whiten_score_y=0.8157|whiten_score_y_cw=0.6422|whiten_score_y_cent=0.7723\n",
      "\n",
      "Best Val Loss = 2.2472 at Step 14\n"
     ]
    }
   ],
   "source": [
    "# Per-epoch validation losses; loss_func is the plain MSE used for the test diagnostics.\n",
    "val_loss = []\n",
    "num_training_steps = config['num_training_steps']\n",
    "loss_func = nn.MSELoss()\n",
    "\n",
    "# Training switches and joint-loss hyper-parameters, all taken from the run config.\n",
    "teacher_force, mixup = config['teacher_force'], config['mixup']\n",
    "matrix_norm_weight, fft_weight = config['matrix_norm_weight'], config['fft_weight']\n",
    "eign_penalty, eps_eign_min, penalty_method = (\n",
    "    config['eign_penalty'], config['eps_eign_min'], config['penalty_method'])\n",
    "window_size, pad_mode = config['window_size'], config['pad_mode']\n",
    "\n",
    "# Best checkpoint seen so far (selected on validation loss).\n",
    "best_val_loss, best_model_state, best_step = float('inf'), None, -1\n",
    "\n",
    "# One step is one epoch: a training pass, then validation and test passes under\n",
    "# torch.no_grad(). The state_dict with the lowest validation loss is kept in memory.\n",
    "for step in range(num_training_steps):\n",
    "    model_conditional_mean.train()\n",
    "    total_loss = 0\n",
    "###################################################   train   #######################################\n",
    "    for i, (batch_x,\n",
    "            batch_y,\n",
    "            origin_x,\n",
    "            origin_y,\n",
    "            batch_x_mark,\n",
    "            batch_y_mark,\n",
    "            ) in enumerate(dataloader.train_loader):\n",
    "\n",
    "        optimizer.zero_grad()\n",
    "\n",
    "        batch_x = batch_x.to(device).float()\n",
    "        batch_y = batch_y.to(device).float()\n",
    "\n",
    "        # Sliding-window covariances of the input (x) and target (y) windows: the full\n",
    "        # matrices and their triangular entries (D*(D+1)/2 per step, the same width as\n",
    "        # the covariance part of the decoder input built below).\n",
    "        batch_x_sliding_cov, batch_x_sliding_cov_trig = sliding_cov(batch_x=batch_x,\n",
    "                                                                    window_size=window_size,\n",
    "                                                                    pad_mode=pad_mode)\n",
    "        batch_y_sliding_cov, batch_y_sliding_cov_trig = sliding_cov(batch_x=batch_y,\n",
    "                                                                    window_size=window_size,\n",
    "                                                                    pad_mode=pad_mode)\n",
    "\n",
    "        batch_xxT, batch_xxT_trig = batch_x_sliding_cov.detach(), batch_x_sliding_cov_trig.detach()\n",
    "        batch_yyT, batch_yyT_trig = batch_y_sliding_cov.detach(), batch_y_sliding_cov_trig.detach()\n",
    "\n",
    "        batch_x_mark = batch_x_mark.to(device).float()\n",
    "        batch_y_mark = batch_y_mark.to(device).float()\n",
    "\n",
    "        # Covariance targets are deliberately NOT rescaled (dividing by\n",
    "        # train_set_xxT_std / train_set_xxT_up_trig_std was tried and disabled);\n",
    "        # hence batch_xxT_std=1 in joint_loss_fn below.\n",
    "\n",
    "        # === mixup (inspired by TimeDiff; only active when config['mixup'] is truthy) ===\n",
    "        if mixup:\n",
    "            alpha = 0.2\n",
    "            lam = np.random.beta(alpha, alpha)\n",
    "            idx = torch.randperm(batch_x.size(0))\n",
    "            batch_x = lam * batch_x + (1 - lam) * batch_x[idx]\n",
    "            batch_y = lam * batch_y + (1 - lam) * batch_y[idx]\n",
    "            batch_xxT, batch_xxT_trig = lam * batch_xxT + (1 - lam) * batch_xxT[idx], lam * batch_xxT_trig + (1 - lam) * batch_xxT_trig[idx]\n",
    "            batch_yyT, batch_yyT_trig = lam * batch_yyT + (1 - lam) * batch_yyT[idx], lam * batch_yyT_trig + (1 - lam) * batch_yyT_trig[idx]\n",
    "            batch_x_mark = lam * batch_x_mark + (1 - lam) * batch_x_mark[idx]\n",
    "            batch_y_mark = lam * batch_y_mark + (1 - lam) * batch_y_mark[idx]\n",
    "        # === mixup ===\n",
    "\n",
    "        batch_y_input = torch.concat([batch_x[:, -label_len:, :], batch_y], dim=1)\n",
    "        batch_yyT_input = torch.cat([batch_xxT_trig[:, -label_len:, :], batch_yyT_trig], dim=1)\n",
    "        batch_y_mark_input = torch.concat([batch_x_mark[:, -label_len:, :], batch_y_mark], dim=1)\n",
    "\n",
    "        # Decoder input: the last label_len observed steps (values + covariance entries),\n",
    "        # followed by pred_len zero placeholders to be predicted.\n",
    "        dec_inp_label = torch.cat([batch_x[:, -label_len :, :].to(device),batch_xxT_trig[:, -label_len:, :].to(device)],dim=-1)\n",
    "        dec_inp_pred = torch.zeros(\n",
    "                                    [batch_x.size(0), pred_len,\n",
    "                                     dataset.num_features + int(dataset.num_features*(dataset.num_features+1)/2)]\n",
    "                                ).to(device)\n",
    "        dec_inp = torch.cat([dec_inp_label, dec_inp_pred], dim=1)\n",
    "\n",
    "        # The random weight is always drawn (even when unused) so a seeded run consumes the\n",
    "        # RNG exactly as before; the model only receives it, and the future targets, under\n",
    "        # teacher forcing.\n",
    "        future_mixup_weight = torch.rand((batch_y.shape[0],1,1)).to(device)\n",
    "        miu_pred, sigma_pred, _ = model_conditional_mean(x_enc = batch_x,\n",
    "                                                x_enc_xxT_trig = batch_xxT_trig,\n",
    "                                                x_mark_enc = batch_x_mark,\n",
    "                                                x_dec = dec_inp,\n",
    "                                                x_mark_dec = batch_y_mark_input,\n",
    "                                                enc_self_mask=None,\n",
    "                                                dec_self_mask=None,\n",
    "                                                dec_enc_mask=None,\n",
    "                                                future_mixup_weight = future_mixup_weight if teacher_force else 1,\n",
    "                                                batch_y=batch_y if teacher_force else None,\n",
    "                                                batch_yyT_trig=batch_yyT_trig if teacher_force else None)\n",
    "\n",
    "        loss = joint_loss_fn(theta_true=batch_y,\n",
    "                             theta_outer_true=batch_yyT,\n",
    "                             miu_pred=miu_pred,\n",
    "                             sigma_pred=sigma_pred,\n",
    "                             batch_xxT_std = 1,\n",
    "                             matrix_norm_weight = matrix_norm_weight,\n",
    "                             fft_weight= fft_weight,\n",
    "                             eign_penalty=eign_penalty,\n",
    "                             eps_eign_min = eps_eign_min,\n",
    "                             penalty_method = penalty_method,\n",
    "                             verbose = False,)\n",
    "        loss.backward()\n",
    "\n",
    "        torch.nn.utils.clip_grad_norm_(\n",
    "            model_conditional_mean.parameters(), 1.)\n",
    "        optimizer.step()\n",
    "\n",
    "        total_loss += loss.item()\n",
    "    total_loss = total_loss / len(dataloader.train_loader)\n",
    "\n",
    "###################################################   val   #######################################\n",
    "    with torch.no_grad():\n",
    "        model_conditional_mean.eval()\n",
    "        val_total = 0\n",
    "        for i, (batch_x,\n",
    "            batch_y,\n",
    "            origin_x,\n",
    "            origin_y,\n",
    "            batch_x_mark,\n",
    "            batch_y_mark,\n",
    "            ) in enumerate(dataloader.val_loader):\n",
    "            batch_x = batch_x.to(device).float()\n",
    "            batch_y = batch_y.to(device).float()\n",
    "\n",
    "            batch_x_sliding_cov, batch_x_sliding_cov_trig = sliding_cov(batch_x=batch_x,\n",
    "                                                                        window_size=window_size,\n",
    "                                                                        pad_mode=pad_mode)\n",
    "            batch_y_sliding_cov, batch_y_sliding_cov_trig = sliding_cov(batch_x=batch_y,\n",
    "                                                                        window_size=window_size,\n",
    "                                                                        pad_mode=pad_mode)\n",
    "\n",
    "            batch_xxT, batch_xxT_trig = batch_x_sliding_cov.detach(), batch_x_sliding_cov_trig.detach()\n",
    "            batch_yyT, batch_yyT_trig = batch_y_sliding_cov.detach(), batch_y_sliding_cov_trig.detach()\n",
    "\n",
    "            batch_x_mark = batch_x_mark.to(device).float()\n",
    "            batch_y_mark = batch_y_mark.to(device).float()\n",
    "\n",
    "            batch_y_input = torch.concat([batch_x[:, -label_len:, :], batch_y], dim=1)\n",
    "            batch_yyT_input = torch.cat([batch_xxT_trig[:, -label_len:, :], batch_yyT_trig], dim=1)\n",
    "            batch_y_mark_input = torch.concat([batch_x_mark[:, -label_len:, :], batch_y_mark], dim=1)\n",
    "\n",
    "            dec_inp_pred = torch.zeros(\n",
    "                [batch_x.size(0), pred_len, dataset.num_features + int(dataset.num_features*(dataset.num_features+1)/2)]\n",
    "            ).to(device)\n",
    "            dec_inp_label = torch.cat([batch_x[:, -label_len :, :].to(device),batch_xxT_trig[:, -label_len:, :].to(device)],dim=-1)\n",
    "\n",
    "            dec_inp = torch.cat([dec_inp_label, dec_inp_pred], dim=1)\n",
    "\n",
    "            # Evaluation never uses teacher forcing.\n",
    "            miu_pred, sigma_pred, _ = model_conditional_mean(x_enc = batch_x,\n",
    "                                                    x_enc_xxT_trig = batch_xxT_trig,\n",
    "                                                    x_mark_enc = batch_x_mark,\n",
    "                                                    x_dec = dec_inp,\n",
    "                                                    x_mark_dec = batch_y_mark_input,\n",
    "                                                    enc_self_mask=None,\n",
    "                                                    dec_self_mask=None,\n",
    "                                                    dec_enc_mask=None,\n",
    "                                                    future_mixup_weight = 1,\n",
    "                                                    batch_y=None,\n",
    "                                                    batch_yyT_trig=None)\n",
    "\n",
    "            loss = joint_loss_fn(theta_true=batch_y,\n",
    "                                theta_outer_true=batch_yyT,\n",
    "                                miu_pred=miu_pred,\n",
    "                                sigma_pred=sigma_pred,\n",
    "                                batch_xxT_std = 1,\n",
    "                                matrix_norm_weight = matrix_norm_weight,\n",
    "                                fft_weight= fft_weight,\n",
    "                                eign_penalty=eign_penalty,\n",
    "                                eps_eign_min = eps_eign_min,\n",
    "                                penalty_method = penalty_method,\n",
    "                                verbose = False,)\n",
    "            val_total += loss.item()\n",
    "        val_avg = val_total / len(dataloader.val_loader)\n",
    "        val_loss.append(val_avg)\n",
    "\n",
    "###################################################   test   #######################################\n",
    "        test_total = 0\n",
    "        test_total_mean = 0\n",
    "        test_total_sigma = 0\n",
    "        for i, (batch_x,\n",
    "            batch_y,\n",
    "            origin_x,\n",
    "            origin_y,\n",
    "            batch_x_mark,\n",
    "            batch_y_mark,\n",
    "            ) in enumerate(dataloader.test_loader):\n",
    "            batch_x = batch_x.to(device).float()\n",
    "            batch_y = batch_y.to(device).float()\n",
    "\n",
    "            batch_x_sliding_cov, batch_x_sliding_cov_trig = sliding_cov(batch_x=batch_x,\n",
    "                                                                        window_size=window_size,\n",
    "                                                                        pad_mode=pad_mode)\n",
    "            batch_y_sliding_cov, batch_y_sliding_cov_trig = sliding_cov(batch_x=batch_y,\n",
    "                                                                        window_size=window_size,\n",
    "                                                                        pad_mode=pad_mode)\n",
    "\n",
    "            batch_xxT, batch_xxT_trig = batch_x_sliding_cov.detach(), batch_x_sliding_cov_trig.detach()\n",
    "            batch_yyT, batch_yyT_trig = batch_y_sliding_cov.detach(), batch_y_sliding_cov_trig.detach()\n",
    "\n",
    "            batch_x_mark = batch_x_mark.to(device).float()\n",
    "            batch_y_mark = batch_y_mark.to(device).float()\n",
    "\n",
    "            batch_y_input = torch.concat([batch_x[:, -label_len:, :], batch_y], dim=1)\n",
    "            batch_yyT_input = torch.cat([batch_xxT_trig[:, -label_len:, :], batch_yyT_trig], dim=1)\n",
    "            batch_y_mark_input = torch.concat([batch_x_mark[:, -label_len:, :], batch_y_mark], dim=1)\n",
    "\n",
    "            dec_inp_pred = torch.zeros(\n",
    "                [batch_x.size(0), pred_len, dataset.num_features + int(dataset.num_features*(dataset.num_features+1)/2)]\n",
    "            ).to(device)\n",
    "            dec_inp_label = torch.cat([batch_x[:, -label_len :, :].to(device),batch_xxT_trig[:, -label_len:, :].to(device)],dim=-1)\n",
    "\n",
    "            dec_inp = torch.cat([dec_inp_label, dec_inp_pred], dim=1)\n",
    "\n",
    "            miu_pred, sigma_pred, _ = model_conditional_mean(x_enc = batch_x,\n",
    "                                                    x_enc_xxT_trig = batch_xxT_trig,\n",
    "                                                    x_mark_enc = batch_x_mark,\n",
    "                                                    x_dec = dec_inp,\n",
    "                                                    x_mark_dec = batch_y_mark_input,\n",
    "                                                    enc_self_mask=None,\n",
    "                                                    dec_self_mask=None,\n",
    "                                                    dec_enc_mask=None,\n",
    "                                                    future_mixup_weight = 1,\n",
    "                                                    batch_y=None,\n",
    "                                                    batch_yyT_trig=None)\n",
    "\n",
    "            test_loss = joint_loss_fn(theta_true=batch_y,\n",
    "                                theta_outer_true=batch_yyT,\n",
    "                                miu_pred=miu_pred,\n",
    "                                sigma_pred=sigma_pred,\n",
    "                                batch_xxT_std = 1,\n",
    "                                matrix_norm_weight = matrix_norm_weight,\n",
    "                                fft_weight= fft_weight,\n",
    "                                eign_penalty=eign_penalty,\n",
    "                                eps_eign_min = eps_eign_min,\n",
    "                                penalty_method = penalty_method,\n",
    "                                verbose = False,)\n",
    "\n",
    "            test_total += test_loss.item()\n",
    "            test_total_mean += loss_func(batch_y, miu_pred).item()\n",
    "            test_total_sigma += loss_func(batch_yyT, sigma_pred).item()\n",
    "        test_avg = test_total / len(dataloader.test_loader)\n",
    "        test_mean_avg = test_total_mean / len(dataloader.test_loader)\n",
    "        test_sigma_avg = test_total_sigma / len(dataloader.test_loader)\n",
    "\n",
    "        # Whitening diagnostics on the LAST test batch only: batch_y, miu_pred and\n",
    "        # sigma_pred still hold that batch after the loop. Running this inside the loop\n",
    "        # repeated the whitening and three regression fits for every batch, only for all\n",
    "        # but the last result to be overwritten before the print below.\n",
    "        neural_cov_last_batch = compute_neural_cov(sigma_pred,miu_pred)\n",
    "\n",
    "        y_test_cw, cov_sqrt_inv_test, cov_sqrt_test = whiten_sequence(batch_y,\n",
    "                            conditional_mean=miu_pred.detach(),\n",
    "                            cov_matrix=neural_cov_last_batch.detach(), eps=eps_eign_min, verbose = False)\n",
    "\n",
    "        # Average R² of each variable regressed on the others, for the raw targets,\n",
    "        # the whitened residuals and the mean-centred residuals.\n",
    "        whiten_score_y, _ = average_r2_correlation_metric(batch_y, normalize = True, method='linear')\n",
    "        whiten_score_y_cw, _ = average_r2_correlation_metric(y_test_cw, normalize = True, method='linear')\n",
    "        whiten_score_y_centralized, _ = average_r2_correlation_metric(batch_y-miu_pred, normalize = True, method='linear')\n",
    "\n",
    "        triu_indices = torch.triu_indices(num_features, num_features)\n",
    "\n",
    "        if val_avg < best_val_loss:\n",
    "            best_val_loss = val_avg\n",
    "            best_step = step\n",
    "            best_model_state = {k: v.clone() for k, v in model_conditional_mean.state_dict().items()}\n",
    "\n",
    "    print(f\"{step}:Train={total_loss:.4f}|Val={val_avg:.4f}|Test={test_avg:.4f}|whiten_score_y={whiten_score_y:.4f}|whiten_score_y_cw={whiten_score_y_cw:.4f}|whiten_score_y_cent={whiten_score_y_centralized:.4f}\")\n",
    "\n",
    "# ========== restore the checkpoint with the lowest validation loss ==========\n",
    "if best_model_state is None:\n",
    "    print(\"No valid model state was saved.\")\n",
    "else:\n",
    "    model_conditional_mean.load_state_dict(best_model_state)\n",
    "    print(f\"\\nBest Val Loss = {best_val_loss:.4f} at Step {best_step}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "c73808dd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# === 1.  dataset for the *_cw loaders (a separate ETTh1 instance)\n",
    "dataset_cw = ETTh1(root='ts_datasets')\n",
    "\n",
    "# === 2.  scaler: parse_type presumably resolves the configured class name against globals()\n",
    "scaler_type = config['scaler_type']\n",
    "ScalerClass = parse_type(scaler_type, globals())\n",
    "scaler = ScalerClass()\n",
    "\n",
    "# === 3.  dataloader\n",
    "windows, horizon, pred_len = config['windows'], config['horizon'], config['pred_len']\n",
    "num_worker = config['num_worker']\n",
    "# Rebinds the global label_len that the training cell also reads (decoder warm-start length).\n",
    "label_len = windows // 2\n",
    "\n",
    "# batch_size is fixed at 32 here instead of config['batch_size'].\n",
    "dataloader_cw = ETTHLoader(dataset_cw, scaler,\n",
    "                           window=windows, horizon=horizon, steps=pred_len,\n",
    "                           shuffle_train=True, freq=dataset_cw.freq,\n",
    "                           batch_size=32, num_worker=num_worker)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "47d444f1",
   "metadata": {},
   "outputs": [],
   "source": [
    "def whiten_sequence_fast(theta, conditional_mean, cov_matrix, eps=1e-3, verbose=False):\n",
    "    \"\"\"\n",
    "    theta: [B, T, D]\n",
    "    conditional_mean: [B, T, D]\n",
    "    cov_matrix: [B, T, D, D]\n",
    "\n",
    "    Returns:\n",
    "        whitened: [B, T, D]\n",
    "        inv_sqrt_all: [B, T, D, D]\n",
    "        sqrt_all: [B, T, D, D]\n",
    "    \"\"\"\n",
    "    residual = theta - conditional_mean  # [B, T, D]\n",
    "\n",
    "    # --- batched eigendecomposition ---\n",
    "    eigvals, eigvecs = torch.linalg.eigh(cov_matrix)   # eigvals: [B,T,D], eigvecs: [B,T,D,D]\n",
    "\n",
    "    if verbose and (eigvals < eps).any():\n",
    "        print(\"Some eigenvalues before clamp:\", eigvals[eigvals < eps])\n",
    "\n",
    "    eigvals_clamped = eigvals.clamp_min(eps)  # [B,T,D]\n",
    "\n",
    "    # --- sqrt and inv sqrt of eigenvalues ---\n",
    "    sqrt_vals = eigvals_clamped.sqrt()\n",
    "    inv_sqrt_vals = eigvals_clamped.rsqrt()   # reciprocal sqrt\n",
    "\n",
    "    # [B,T,D,D]\n",
    "    sqrt_diag = torch.diag_embed(sqrt_vals)        # [B,T,D,D]\n",
    "    inv_sqrt_diag = torch.diag_embed(inv_sqrt_vals)\n",
    "\n",
    "    # --- (vecs @ diag @ vecs^T) ---\n",
    "    sqrt_all = eigvecs @ sqrt_diag @ eigvecs.transpose(-2, -1)       # [B,T,D,D]\n",
    "    inv_sqrt_all = eigvecs @ inv_sqrt_diag @ eigvecs.transpose(-2, -1)\n",
    "\n",
    "    # --- whiten residual ---\n",
    "    whitened = torch.einsum(\"btij,btj->bti\", inv_sqrt_all, residual)  # [B,T,D]\n",
    "\n",
    "    return whitened, inv_sqrt_all, sqrt_all"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "ad57772a",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0-th batch is whitened\n",
      "1-th batch is whitened\n",
      "2-th batch is whitened\n",
      "3-th batch is whitened\n",
      "4-th batch is whitened\n",
      "5-th batch is whitened\n",
      "6-th batch is whitened\n",
      "7-th batch is whitened\n",
      "8-th batch is whitened\n",
      "9-th batch is whitened\n",
      "10-th batch is whitened\n",
      "11-th batch is whitened\n",
      "12-th batch is whitened\n",
      "13-th batch is whitened\n",
      "14-th batch is whitened\n",
      "15-th batch is whitened\n",
      "16-th batch is whitened\n",
      "17-th batch is whitened\n",
      "18-th batch is whitened\n",
      "19-th batch is whitened\n",
      "20-th batch is whitened\n",
      "21-th batch is whitened\n",
      "22-th batch is whitened\n",
      "23-th batch is whitened\n",
      "24-th batch is whitened\n",
      "25-th batch is whitened\n",
      "26-th batch is whitened\n",
      "27-th batch is whitened\n",
      "28-th batch is whitened\n",
      "29-th batch is whitened\n",
      "30-th batch is whitened\n",
      "31-th batch is whitened\n",
      "32-th batch is whitened\n",
      "33-th batch is whitened\n",
      "34-th batch is whitened\n",
      "35-th batch is whitened\n",
      "36-th batch is whitened\n",
      "37-th batch is whitened\n",
      "38-th batch is whitened\n",
      "39-th batch is whitened\n",
      "40-th batch is whitened\n",
      "41-th batch is whitened\n",
      "42-th batch is whitened\n",
      "43-th batch is whitened\n",
      "44-th batch is whitened\n",
      "45-th batch is whitened\n",
      "46-th batch is whitened\n",
      "47-th batch is whitened\n",
      "48-th batch is whitened\n",
      "49-th batch is whitened\n",
      "50-th batch is whitened\n",
      "51-th batch is whitened\n",
      "52-th batch is whitened\n",
      "53-th batch is whitened\n",
      "54-th batch is whitened\n",
      "55-th batch is whitened\n",
      "56-th batch is whitened\n",
      "57-th batch is whitened\n",
      "58-th batch is whitened\n",
      "59-th batch is whitened\n",
      "60-th batch is whitened\n",
      "61-th batch is whitened\n",
      "62-th batch is whitened\n",
      "63-th batch is whitened\n",
      "64-th batch is whitened\n",
      "65-th batch is whitened\n",
      "66-th batch is whitened\n",
      "67-th batch is whitened\n",
      "68-th batch is whitened\n",
      "69-th batch is whitened\n",
      "70-th batch is whitened\n",
      "71-th batch is whitened\n",
      "72-th batch is whitened\n",
      "73-th batch is whitened\n",
      "74-th batch is whitened\n",
      "75-th batch is whitened\n",
      "76-th batch is whitened\n",
      "77-th batch is whitened\n",
      "78-th batch is whitened\n",
      "79-th batch is whitened\n",
      "80-th batch is whitened\n",
      "81-th batch is whitened\n",
      "82-th batch is whitened\n",
      "83-th batch is whitened\n",
      "84-th batch is whitened\n",
      "85-th batch is whitened\n",
      "86-th batch is whitened\n",
      "87-th batch is whitened\n",
      "88-th batch is whitened\n",
      "89-th batch is whitened\n",
      "90-th batch is whitened\n",
      "91-th batch is whitened\n",
      "92-th batch is whitened\n",
      "93-th batch is whitened\n",
      "94-th batch is whitened\n",
      "95-th batch is whitened\n",
      "96-th batch is whitened\n",
      "97-th batch is whitened\n",
      "98-th batch is whitened\n",
      "99-th batch is whitened\n",
      "100-th batch is whitened\n",
      "101-th batch is whitened\n",
      "102-th batch is whitened\n",
      "103-th batch is whitened\n",
      "104-th batch is whitened\n",
      "105-th batch is whitened\n",
      "106-th batch is whitened\n",
      "107-th batch is whitened\n",
      "108-th batch is whitened\n",
      "109-th batch is whitened\n",
      "110-th batch is whitened\n",
      "111-th batch is whitened\n",
      "112-th batch is whitened\n",
      "113-th batch is whitened\n",
      "114-th batch is whitened\n",
      "115-th batch is whitened\n",
      "116-th batch is whitened\n",
      "117-th batch is whitened\n",
      "118-th batch is whitened\n",
      "119-th batch is whitened\n",
      "120-th batch is whitened\n",
      "121-th batch is whitened\n",
      "122-th batch is whitened\n",
      "123-th batch is whitened\n",
      "124-th batch is whitened\n",
      "125-th batch is whitened\n",
      "126-th batch is whitened\n",
      "127-th batch is whitened\n",
      "128-th batch is whitened\n",
      "129-th batch is whitened\n",
      "130-th batch is whitened\n",
      "131-th batch is whitened\n",
      "132-th batch is whitened\n",
      "133-th batch is whitened\n",
      "134-th batch is whitened\n",
      "135-th batch is whitened\n",
      "136-th batch is whitened\n",
      "137-th batch is whitened\n",
      "138-th batch is whitened\n",
      "139-th batch is whitened\n",
      "140-th batch is whitened\n",
      "141-th batch is whitened\n",
      "142-th batch is whitened\n",
      "143-th batch is whitened\n",
      "144-th batch is whitened\n",
      "145-th batch is whitened\n",
      "146-th batch is whitened\n",
      "147-th batch is whitened\n",
      "148-th batch is whitened\n",
      "149-th batch is whitened\n",
      "150-th batch is whitened\n",
      "151-th batch is whitened\n",
      "152-th batch is whitened\n",
      "153-th batch is whitened\n",
      "154-th batch is whitened\n",
      "155-th batch is whitened\n",
      "156-th batch is whitened\n",
      "157-th batch is whitened\n",
      "158-th batch is whitened\n",
      "159-th batch is whitened\n",
      "160-th batch is whitened\n",
      "161-th batch is whitened\n",
      "162-th batch is whitened\n",
      "163-th batch is whitened\n",
      "164-th batch is whitened\n",
      "165-th batch is whitened\n",
      "166-th batch is whitened\n",
      "167-th batch is whitened\n",
      "168-th batch is whitened\n",
      "169-th batch is whitened\n",
      "170-th batch is whitened\n",
      "171-th batch is whitened\n",
      "172-th batch is whitened\n",
      "173-th batch is whitened\n",
      "174-th batch is whitened\n",
      "175-th batch is whitened\n",
      "176-th batch is whitened\n",
      "177-th batch is whitened\n",
      "178-th batch is whitened\n",
      "179-th batch is whitened\n",
      "180-th batch is whitened\n",
      "181-th batch is whitened\n",
      "182-th batch is whitened\n",
      "183-th batch is whitened\n",
      "184-th batch is whitened\n",
      "185-th batch is whitened\n",
      "186-th batch is whitened\n",
      "187-th batch is whitened\n",
      "188-th batch is whitened\n",
      "189-th batch is whitened\n",
      "190-th batch is whitened\n",
      "191-th batch is whitened\n",
      "192-th batch is whitened\n",
      "193-th batch is whitened\n",
      "194-th batch is whitened\n",
      "195-th batch is whitened\n",
      "196-th batch is whitened\n",
      "197-th batch is whitened\n",
      "198-th batch is whitened\n",
      "199-th batch is whitened\n",
      "200-th batch is whitened\n",
      "201-th batch is whitened\n",
      "202-th batch is whitened\n",
      "203-th batch is whitened\n",
      "204-th batch is whitened\n",
      "205-th batch is whitened\n",
      "206-th batch is whitened\n",
      "207-th batch is whitened\n",
      "208-th batch is whitened\n",
      "209-th batch is whitened\n",
      "210-th batch is whitened\n",
      "211-th batch is whitened\n",
      "212-th batch is whitened\n",
      "213-th batch is whitened\n",
      "214-th batch is whitened\n",
      "215-th batch is whitened\n",
      "216-th batch is whitened\n",
      "217-th batch is whitened\n",
      "218-th batch is whitened\n",
      "219-th batch is whitened\n",
      "220-th batch is whitened\n",
      "221-th batch is whitened\n",
      "222-th batch is whitened\n",
      "223-th batch is whitened\n",
      "224-th batch is whitened\n",
      "225-th batch is whitened\n",
      "226-th batch is whitened\n",
      "227-th batch is whitened\n",
      "228-th batch is whitened\n",
      "229-th batch is whitened\n",
      "230-th batch is whitened\n",
      "231-th batch is whitened\n",
      "232-th batch is whitened\n",
      "233-th batch is whitened\n",
      "234-th batch is whitened\n",
      "235-th batch is whitened\n",
      "236-th batch is whitened\n",
      "237-th batch is whitened\n",
      "238-th batch is whitened\n",
      "239-th batch is whitened\n",
      "240-th batch is whitened\n",
      "241-th batch is whitened\n",
      "242-th batch is whitened\n",
      "243-th batch is whitened\n",
      "244-th batch is whitened\n",
      "245-th batch is whitened\n",
      "246-th batch is whitened\n",
      "247-th batch is whitened\n",
      "248-th batch is whitened\n",
      "249-th batch is whitened\n",
      "250-th batch is whitened\n",
      "251-th batch is whitened\n",
      "252-th batch is whitened\n",
      "253-th batch is whitened\n",
      "254-th batch is whitened\n",
      "255-th batch is whitened\n",
      "256-th batch is whitened\n",
      "257-th batch is whitened\n",
      "258-th batch is whitened\n"
     ]
    }
   ],
   "source": [
    "# Whiten the training split with the (frozen) conditional-mean model.\n",
    "# torch.no_grad() instead of torch.inference_mode(): tensors created in inference\n",
    "# mode cannot be saved for backward, which may break later training that consumes\n",
    "# the cached tensors. No autograd graph is built for this pure forward pass.\n",
    "# NOTE(review): this loop is duplicated for the val/test splits; consider a helper.\n",
    "future_cw_train = {}\n",
    "model_conditional_mean.eval()\n",
    "with torch.no_grad():\n",
    "    for i, (batch_history,\n",
    "            batch_future,\n",
    "            origin_history,\n",
    "            origin_future,\n",
    "            batch_history_mark,\n",
    "            batch_future_mark,\n",
    "            ) in enumerate(tqdm(dataloader_cw.train_loader, desc='whitening train')):\n",
    "        batch_history = batch_history.to(device).float()\n",
    "        batch_future = batch_future.to(device).float()\n",
    "        origin_history = origin_history.to(device).float()\n",
    "        origin_future = origin_future.to(device).float()\n",
    "        batch_history_mark = batch_history_mark.to(device).float()\n",
    "        batch_future_mark = batch_future_mark.to(device).float()\n",
    "\n",
    "        # Sliding-window covariances (presumably full matrix + packed triangle);\n",
    "        # only the history triangle is fed to the model in this cell.\n",
    "        batch_history_sliding_cov, batch_history_sliding_cov_trig = sliding_cov(\n",
    "            batch_x=batch_history, window_size=window_size, pad_mode=pad_mode\n",
    "        )\n",
    "        batch_future_sliding_cov, batch_future_sliding_cov_trig = sliding_cov(\n",
    "            batch_x=batch_future, window_size=window_size, pad_mode=pad_mode\n",
    "        )\n",
    "        batch_his_xxT, batch_his_xxT_trig = batch_history_sliding_cov.detach(), batch_history_sliding_cov_trig.detach()\n",
    "        batch_fur_xxT, batch_fur_xxT_trig = batch_future_sliding_cov.detach(), batch_future_sliding_cov_trig.detach()\n",
    "\n",
    "        # Decoder input: the last label_len history steps (values + covariance\n",
    "        # triangle) followed by pred_len zero placeholders of width N + N(N+1)/2.\n",
    "        dec_inp_pred = torch.zeros(\n",
    "            [batch_history.size(0), pred_len, num_features + num_features * (num_features + 1) // 2],\n",
    "            device=device,\n",
    "        )\n",
    "        dec_inp_label = torch.cat(\n",
    "            [batch_history[:, -label_len:, :], batch_his_xxT_trig[:, -label_len:, :].to(device)], dim=-1\n",
    "        )\n",
    "        dec_inp = torch.cat([dec_inp_label, dec_inp_pred], dim=1)  # [B, label_len+pred_len, N+n_up_trig]\n",
    "        batch_future_mark_input = torch.concat([batch_history_mark[:, -label_len:, :], batch_future_mark], dim=1)\n",
    "\n",
    "        # Marks are already float tensors on `device`, so no extra conversion here.\n",
    "        miu_pred, sigma_pred, _ = model_conditional_mean(\n",
    "            x_enc=batch_history,\n",
    "            x_enc_xxT_trig=batch_his_xxT_trig,\n",
    "            x_mark_enc=batch_history_mark,\n",
    "            x_dec=dec_inp,\n",
    "            x_mark_dec=batch_future_mark_input,\n",
    "            enc_self_mask=None, dec_self_mask=None, dec_enc_mask=None,\n",
    "            future_mixup_weight=1,\n",
    "            batch_y=None, batch_yyT_trig=None,\n",
    "        )\n",
    "        neural_cov_last_batch = compute_neural_cov(sigma_pred, miu_pred)\n",
    "        future_cw, cov_sqrt_inv, cov_sqrt = whiten_sequence_fast(\n",
    "            batch_future,\n",
    "            conditional_mean=miu_pred.detach(),\n",
    "            cov_matrix=neural_cov_last_batch.detach(),\n",
    "            eps=eps_eign_min, verbose=False,\n",
    "        )\n",
    "\n",
    "        # Cache per-batch inputs, whitened targets and covariance roots (kept on `device`).\n",
    "        future_cw_train[i] = {\n",
    "            'batch_history': batch_history.detach(),\n",
    "            'batch_future': batch_future.detach(),\n",
    "            'batch_history_mark': batch_history_mark.detach(),\n",
    "            'batch_future_mark': batch_future_mark.detach(),\n",
    "            'batch_future_cw': future_cw.detach(),\n",
    "            'batch_cov_sqrt_inv': cov_sqrt_inv.detach(),\n",
    "            'cov_sqrt': cov_sqrt.detach(),\n",
    "            'miu_pred': miu_pred.detach(),\n",
    "        }\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "95b0d130",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0-th batch is whitened\n"
     ]
    }
   ],
   "source": [
    "# Whiten the validation split with the (frozen) conditional-mean model.\n",
    "# torch.no_grad() instead of torch.inference_mode(): tensors created in inference\n",
    "# mode cannot be saved for backward, which may break later training that consumes\n",
    "# the cached tensors. No autograd graph is built for this pure forward pass.\n",
    "# NOTE(review): this loop is duplicated for the train/test splits; consider a helper.\n",
    "future_cw_val = {}\n",
    "model_conditional_mean.eval()\n",
    "with torch.no_grad():\n",
    "    for i, (batch_history,\n",
    "            batch_future,\n",
    "            origin_history,\n",
    "            origin_future,\n",
    "            batch_history_mark,\n",
    "            batch_future_mark,\n",
    "            ) in enumerate(tqdm(dataloader_cw.val_loader, desc='whitening val')):\n",
    "        batch_history = batch_history.to(device).float()\n",
    "        batch_future = batch_future.to(device).float()\n",
    "        origin_history = origin_history.to(device).float()\n",
    "        origin_future = origin_future.to(device).float()\n",
    "        batch_history_mark = batch_history_mark.to(device).float()\n",
    "        batch_future_mark = batch_future_mark.to(device).float()\n",
    "\n",
    "        # Sliding-window covariances (presumably full matrix + packed triangle);\n",
    "        # only the history triangle is fed to the model in this cell.\n",
    "        batch_history_sliding_cov, batch_history_sliding_cov_trig = sliding_cov(\n",
    "            batch_x=batch_history, window_size=window_size, pad_mode=pad_mode\n",
    "        )\n",
    "        batch_future_sliding_cov, batch_future_sliding_cov_trig = sliding_cov(\n",
    "            batch_x=batch_future, window_size=window_size, pad_mode=pad_mode\n",
    "        )\n",
    "        batch_his_xxT, batch_his_xxT_trig = batch_history_sliding_cov.detach(), batch_history_sliding_cov_trig.detach()\n",
    "        batch_fur_xxT, batch_fur_xxT_trig = batch_future_sliding_cov.detach(), batch_future_sliding_cov_trig.detach()\n",
    "\n",
    "        # Decoder input: the last label_len history steps (values + covariance\n",
    "        # triangle) followed by pred_len zero placeholders of width N + N(N+1)/2.\n",
    "        dec_inp_pred = torch.zeros(\n",
    "            [batch_history.size(0), pred_len, num_features + num_features * (num_features + 1) // 2],\n",
    "            device=device,\n",
    "        )\n",
    "        dec_inp_label = torch.cat(\n",
    "            [batch_history[:, -label_len:, :], batch_his_xxT_trig[:, -label_len:, :].to(device)], dim=-1\n",
    "        )\n",
    "        dec_inp = torch.cat([dec_inp_label, dec_inp_pred], dim=1)  # [B, label_len+pred_len, N+n_up_trig]\n",
    "        batch_future_mark_input = torch.concat([batch_history_mark[:, -label_len:, :], batch_future_mark], dim=1)\n",
    "\n",
    "        # Marks are already float tensors on `device`, so no extra conversion here.\n",
    "        miu_pred, sigma_pred, _ = model_conditional_mean(\n",
    "            x_enc=batch_history,\n",
    "            x_enc_xxT_trig=batch_his_xxT_trig,\n",
    "            x_mark_enc=batch_history_mark,\n",
    "            x_dec=dec_inp,\n",
    "            x_mark_dec=batch_future_mark_input,\n",
    "            enc_self_mask=None, dec_self_mask=None, dec_enc_mask=None,\n",
    "            future_mixup_weight=1,\n",
    "            batch_y=None, batch_yyT_trig=None,\n",
    "        )\n",
    "        neural_cov_last_batch = compute_neural_cov(sigma_pred, miu_pred)\n",
    "        future_cw, cov_sqrt_inv, cov_sqrt = whiten_sequence_fast(\n",
    "            batch_future,\n",
    "            conditional_mean=miu_pred.detach(),\n",
    "            cov_matrix=neural_cov_last_batch.detach(),\n",
    "            eps=eps_eign_min, verbose=False,\n",
    "        )\n",
    "\n",
    "        # Cache per-batch inputs, whitened targets and covariance roots (kept on `device`).\n",
    "        future_cw_val[i] = {\n",
    "            'batch_history': batch_history.detach(),\n",
    "            'batch_future': batch_future.detach(),\n",
    "            'batch_history_mark': batch_history_mark.detach(),\n",
    "            'batch_future_mark': batch_future_mark.detach(),\n",
    "            'batch_future_cw': future_cw.detach(),\n",
    "            'batch_cov_sqrt_inv': cov_sqrt_inv.detach(),\n",
    "            'cov_sqrt': cov_sqrt.detach(),\n",
    "            'miu_pred': miu_pred.detach(),\n",
    "        }"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "9ef361d1",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0-th batch is whitened\n"
     ]
    }
   ],
   "source": [
    "# Whiten the test split with the (frozen) conditional-mean model.\n",
    "# torch.no_grad() instead of torch.inference_mode(): tensors created in inference\n",
    "# mode cannot be saved for backward, which may break later code that consumes\n",
    "# the cached tensors. No autograd graph is built for this pure forward pass.\n",
    "# NOTE(review): this loop is duplicated for the train/val splits; consider a helper.\n",
    "future_cw_test = {}\n",
    "model_conditional_mean.eval()\n",
    "with torch.no_grad():\n",
    "    for i, (batch_history,\n",
    "            batch_future,\n",
    "            origin_history,\n",
    "            origin_future,\n",
    "            batch_history_mark,\n",
    "            batch_future_mark,\n",
    "            ) in enumerate(tqdm(dataloader_cw.test_loader, desc='whitening test')):\n",
    "        batch_history = batch_history.to(device).float()\n",
    "        batch_future = batch_future.to(device).float()\n",
    "        origin_history = origin_history.to(device).float()\n",
    "        origin_future = origin_future.to(device).float()\n",
    "        batch_history_mark = batch_history_mark.to(device).float()\n",
    "        batch_future_mark = batch_future_mark.to(device).float()\n",
    "\n",
    "        # Sliding-window covariances (presumably full matrix + packed triangle);\n",
    "        # only the history triangle is fed to the model in this cell.\n",
    "        batch_history_sliding_cov, batch_history_sliding_cov_trig = sliding_cov(\n",
    "            batch_x=batch_history, window_size=window_size, pad_mode=pad_mode\n",
    "        )\n",
    "        batch_future_sliding_cov, batch_future_sliding_cov_trig = sliding_cov(\n",
    "            batch_x=batch_future, window_size=window_size, pad_mode=pad_mode\n",
    "        )\n",
    "        batch_his_xxT, batch_his_xxT_trig = batch_history_sliding_cov.detach(), batch_history_sliding_cov_trig.detach()\n",
    "        batch_fur_xxT, batch_fur_xxT_trig = batch_future_sliding_cov.detach(), batch_future_sliding_cov_trig.detach()\n",
    "\n",
    "        # Decoder input: the last label_len history steps (values + covariance\n",
    "        # triangle) followed by pred_len zero placeholders of width N + N(N+1)/2.\n",
    "        dec_inp_pred = torch.zeros(\n",
    "            [batch_history.size(0), pred_len, num_features + num_features * (num_features + 1) // 2],\n",
    "            device=device,\n",
    "        )\n",
    "        dec_inp_label = torch.cat(\n",
    "            [batch_history[:, -label_len:, :], batch_his_xxT_trig[:, -label_len:, :].to(device)], dim=-1\n",
    "        )\n",
    "        dec_inp = torch.cat([dec_inp_label, dec_inp_pred], dim=1)  # [B, label_len+pred_len, N+n_up_trig]\n",
    "        batch_future_mark_input = torch.concat([batch_history_mark[:, -label_len:, :], batch_future_mark], dim=1)\n",
    "\n",
    "        # Marks are already float tensors on `device`, so no extra conversion here.\n",
    "        miu_pred, sigma_pred, _ = model_conditional_mean(\n",
    "            x_enc=batch_history,\n",
    "            x_enc_xxT_trig=batch_his_xxT_trig,\n",
    "            x_mark_enc=batch_history_mark,\n",
    "            x_dec=dec_inp,\n",
    "            x_mark_dec=batch_future_mark_input,\n",
    "            enc_self_mask=None, dec_self_mask=None, dec_enc_mask=None,\n",
    "            future_mixup_weight=1,\n",
    "            batch_y=None, batch_yyT_trig=None,\n",
    "        )\n",
    "        neural_cov_last_batch = compute_neural_cov(sigma_pred, miu_pred)\n",
    "        future_cw, cov_sqrt_inv, cov_sqrt = whiten_sequence_fast(\n",
    "            batch_future,\n",
    "            conditional_mean=miu_pred.detach(),\n",
    "            cov_matrix=neural_cov_last_batch.detach(),\n",
    "            eps=eps_eign_min, verbose=False,\n",
    "        )\n",
    "\n",
    "        # Cache per-batch inputs, whitened targets and covariance roots (kept on `device`).\n",
    "        future_cw_test[i] = {\n",
    "            'batch_history': batch_history.detach(),\n",
    "            'batch_future': batch_future.detach(),\n",
    "            'batch_history_mark': batch_history_mark.detach(),\n",
    "            'batch_future_mark': batch_future_mark.detach(),\n",
    "            'batch_future_cw': future_cw.detach(),\n",
    "            'batch_cov_sqrt_inv': cov_sqrt_inv.detach(),\n",
    "            'cov_sqrt': cov_sqrt.detach(),\n",
    "            'miu_pred': miu_pred.detach(),\n",
    "        }"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "d1806170",
   "metadata": {},
   "outputs": [],
   "source": [
    "from torchmetrics import Metric\n",
    "import CRPS.CRPS as pscore  # Assuming `pscore` is the function to compute CRPS\n",
    "from concurrent.futures import ProcessPoolExecutor\n",
    "\n",
    "class CRPS(Metric):\n",
    "    \"\"\"Mean ensemble CRPS over every flattened forecast cell.\n",
    "\n",
    "    Each cell's S samples are scored against its true value with\n",
    "    `pscore(samples, obs).compute()[0]`; `compute()` returns the average score.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, dist_sync_on_step=False):\n",
    "        super().__init__(dist_sync_on_step=dist_sync_on_step)\n",
    "        # Running sum of per-cell CRPS and number of scored cells.\n",
    "        self.add_state(\"total_crps\", default=torch.tensor(0.0), dist_reduce_fx=\"sum\")\n",
    "        self.add_state(\"total_samples\", default=torch.tensor(0), dist_reduce_fx=\"sum\")\n",
    "\n",
    "    def update(self, pred: torch.Tensor, true: torch.Tensor):\n",
    "        \"\"\"\n",
    "        Args:\n",
    "            pred: Tensor of predicted samples, shape (..., S), e.g. (B, O, N, S).\n",
    "            true: Tensor of true values, shape (...), e.g. (B, O, N).\n",
    "        \"\"\"\n",
    "        # reshape (not view) also accepts non-contiguous inputs, e.g. permuted samples.\n",
    "        pred = pred.reshape(-1, pred.shape[-1])  # (B * O * N, S)\n",
    "        true = true.reshape(-1)  # (B * O * N,)\n",
    "\n",
    "        # detach() so .numpy() also works on tensors that still require grad.\n",
    "        pred_np = pred.detach().cpu().numpy()\n",
    "        true_np = true.detach().cpu().numpy()\n",
    "\n",
    "        # pscore scores a single observation, so sum the cells one at a time.\n",
    "        crps_sum = sum(\n",
    "            (pscore(pred_np[k], true_np[k]).compute()[0] for k in range(len(true_np))),\n",
    "            0.0,\n",
    "        )\n",
    "\n",
    "        self.total_crps += torch.tensor(crps_sum).to(self.device)\n",
    "        self.total_samples += pred.size(0)\n",
    "\n",
    "    def compute(self):\n",
    "        \"\"\"Average CRPS over all cells seen so far.\"\"\"\n",
    "        return self.total_crps / self.total_samples\n",
    "\n",
    "class CRPSSum(Metric):\n",
    "    \"\"\"CRPS of the ensemble after summing over the variable axis (dim 2).\n",
    "\n",
    "    Samples and targets are summed over N, each (batch, horizon) aggregate is\n",
    "    scored with `pscore(samples, obs).compute()[0]`, and the scores are averaged.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, dist_sync_on_step=False):\n",
    "        super().__init__(dist_sync_on_step=dist_sync_on_step)\n",
    "        # Running sum of per-cell CRPS and number of scored (batch, horizon) cells.\n",
    "        self.add_state(\"total_crps\", default=torch.tensor(0.0), dist_reduce_fx=\"sum\")\n",
    "        self.add_state(\"total_samples\", default=torch.tensor(0), dist_reduce_fx=\"sum\")\n",
    "\n",
    "    def update(self, pred: torch.Tensor, true: torch.Tensor):\n",
    "        \"\"\"\n",
    "        Args:\n",
    "            pred: Tensor of predicted samples, shape (B, O, N, S).\n",
    "            true: Tensor of true values, shape (B, O, N).\n",
    "        \"\"\"\n",
    "        # Aggregate over the variable axis: (B, O, N, S) -> (B, O, S), (B, O, N) -> (B, O).\n",
    "        pred = pred.sum(dim=2)\n",
    "        true = true.sum(dim=2)\n",
    "\n",
    "        pred = pred.reshape(-1, pred.shape[-1])  # (B * O, S)\n",
    "        true = true.reshape(-1)  # (B * O,)\n",
    "\n",
    "        # detach() so .numpy() also works on tensors that still require grad.\n",
    "        pred_np = pred.detach().cpu().numpy()\n",
    "        true_np = true.detach().cpu().numpy()\n",
    "\n",
    "        # pscore scores a single observation, so sum the cells one at a time.\n",
    "        crps_sum = sum(\n",
    "            (pscore(pred_np[k], true_np[k]).compute()[0] for k in range(len(true_np))),\n",
    "            0.0,\n",
    "        )\n",
    "\n",
    "        self.total_crps += torch.tensor(crps_sum).to(self.device)\n",
    "        self.total_samples += pred.size(0)\n",
    "\n",
    "    def compute(self):\n",
    "        \"\"\"Average CRPS-of-sum over all (batch, horizon) cells seen so far.\"\"\"\n",
    "        return self.total_crps / self.total_samples\n",
    "    \n",
    "\n",
    "class PICP(Metric):\n",
    "    \"\"\"Prediction-interval coverage probability.\n",
    "\n",
    "    Fraction of true values that lie inside the empirical\n",
    "    [low_percentile, high_percentile] interval of their sample ensemble.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, low_percentile: int = 5, high_percentile: int = 95, dist_sync_on_step=False):\n",
    "        super().__init__(dist_sync_on_step=dist_sync_on_step)\n",
    "        self.low_percentile = low_percentile\n",
    "        self.high_percentile = high_percentile\n",
    "        # Number of covered targets and number of targets seen.\n",
    "        self.add_state(\"coverage\", default=torch.tensor(0.0), dist_reduce_fx=\"sum\")\n",
    "        self.add_state(\"total_samples\", default=torch.tensor(0), dist_reduce_fx=\"sum\")\n",
    "\n",
    "    def update(self, all_gen_y: torch.Tensor, y_true: torch.Tensor):\n",
    "        \"\"\"\n",
    "        Args:\n",
    "            all_gen_y: generated samples, shape (..., S), e.g. (B, O, N, S).\n",
    "            y_true: true values, shape (...), e.g. (B, O, N).\n",
    "        \"\"\"\n",
    "        all_gen_y = all_gen_y.detach().reshape(-1, all_gen_y.shape[-1]).cpu()  # (B * O * N, S)\n",
    "        y_true = y_true.detach().reshape(-1).cpu()  # (B * O * N,)\n",
    "\n",
    "        # torch.quantile requires q to have the same dtype as its input.\n",
    "        # NOTE(review): torch.quantile has an input-size limit; very large\n",
    "        # ensembles may need chunking or np.percentile.\n",
    "        q = torch.tensor([self.low_percentile / 100.0, self.high_percentile / 100.0], dtype=all_gen_y.dtype)\n",
    "        CI_y_pred = torch.quantile(all_gen_y, q, dim=1)  # (2, B * O * N)\n",
    "\n",
    "        y_in_range = (y_true >= CI_y_pred[0]) & (y_true <= CI_y_pred[1])\n",
    "\n",
    "        # Accumulate the COUNT of covered targets. Adding a per-batch mean here and\n",
    "        # dividing by the element count in compute() would wrongly divide the\n",
    "        # coverage by the number of targets per update.\n",
    "        self.coverage += y_in_range.sum().to(self.coverage)\n",
    "        self.total_samples += y_true.size(0)\n",
    "\n",
    "    def compute(self):\n",
    "        \"\"\"Covered targets / total targets, a fraction in [0, 1].\"\"\"\n",
    "        return self.coverage / self.total_samples\n",
    "    \n",
    "\n",
    "class ProbMAE(Metric):\n",
    "    \"\"\"Mean absolute error of the sample-mean (point) forecast.\"\"\"\n",
    "\n",
    "    def __init__(self, dist_sync_on_step=False):\n",
    "        super().__init__(dist_sync_on_step=dist_sync_on_step)\n",
    "        # Running |error| sum and number of scored elements.\n",
    "        self.add_state(\"total_mae\", default=torch.tensor(0.0), dist_reduce_fx=\"sum\")\n",
    "        self.add_state(\"total_samples\", default=torch.tensor(0), dist_reduce_fx=\"sum\")\n",
    "\n",
    "    def update(self, pred: torch.Tensor, true: torch.Tensor):\n",
    "        \"\"\"\n",
    "        Args:\n",
    "            pred: predicted samples, shape (B, O, N, S); averaged over the last axis.\n",
    "            true: true values, shape (B, O, N).\n",
    "        \"\"\"\n",
    "        point_forecast = pred.mean(dim=-1)  # collapse samples -> (B, O, N)\n",
    "        assert true.shape == point_forecast.shape, \"Shapes of true values and pred_mean must match\"\n",
    "\n",
    "        abs_err = (point_forecast - true).abs()\n",
    "        self.total_mae += abs_err.sum()\n",
    "        self.total_samples += abs_err.numel()\n",
    "\n",
    "    def compute(self):\n",
    "        \"\"\"Accumulated absolute error divided by the number of elements.\"\"\"\n",
    "        return self.total_mae / self.total_samples\n",
    "    \n",
    "class ProbMSE(Metric):\n",
    "    \"\"\"Mean squared error of the sample-mean (point) forecast.\"\"\"\n",
    "\n",
    "    def __init__(self, dist_sync_on_step=False):\n",
    "        super().__init__(dist_sync_on_step=dist_sync_on_step)\n",
    "        # Running squared-error sum and number of scored elements.\n",
    "        self.add_state(\"total_mse\", default=torch.tensor(0.0), dist_reduce_fx=\"sum\")\n",
    "        self.add_state(\"total_samples\", default=torch.tensor(0), dist_reduce_fx=\"sum\")\n",
    "\n",
    "    def update(self, pred: torch.Tensor, true: torch.Tensor):\n",
    "        \"\"\"\n",
    "        Args:\n",
    "            pred: predicted samples, shape (B, O, N, S); averaged over the last axis.\n",
    "            true: true values, shape (B, O, N).\n",
    "        \"\"\"\n",
    "        point_forecast = pred.mean(dim=-1)  # collapse samples -> (B, O, N)\n",
    "        assert true.shape == point_forecast.shape, \"Shapes of true values and pred_mean must match\"\n",
    "\n",
    "        sq_err = (point_forecast - true).pow(2)\n",
    "        self.total_mse += sq_err.sum()\n",
    "        self.total_samples += sq_err.numel()\n",
    "\n",
    "    def compute(self):\n",
    "        \"\"\"Accumulated squared error divided by the number of elements.\"\"\"\n",
    "        return self.total_mse / self.total_samples\n",
    "\n",
    "\n",
    "class ProbRMSE(Metric):\n",
    "    \"\"\"Root of the global mean squared error of the sample-mean forecast.\"\"\"\n",
    "\n",
    "    def __init__(self, dist_sync_on_step=False):\n",
    "        super().__init__(dist_sync_on_step=dist_sync_on_step)\n",
    "        # Running squared-error sum and number of scored elements.\n",
    "        self.add_state(\"total_mse\", default=torch.tensor(0.0), dist_reduce_fx=\"sum\")\n",
    "        self.add_state(\"total_samples\", default=torch.tensor(0), dist_reduce_fx=\"sum\")\n",
    "\n",
    "    def update(self, pred: torch.Tensor, true: torch.Tensor):\n",
    "        \"\"\"\n",
    "        Args:\n",
    "            pred: predicted samples, shape (B, O, N, S); averaged over the last axis.\n",
    "            true: true values, shape (B, O, N).\n",
    "        \"\"\"\n",
    "        point_forecast = pred.mean(dim=-1)  # collapse samples -> (B, O, N)\n",
    "        assert true.shape == point_forecast.shape, \"Shapes of true values and pred_mean must match\"\n",
    "\n",
    "        sq_err = (point_forecast - true).pow(2)\n",
    "        self.total_mse += sq_err.sum()\n",
    "        self.total_samples += sq_err.numel()\n",
    "\n",
    "    def compute(self):\n",
    "        \"\"\"sqrt(accumulated squared error / number of elements).\"\"\"\n",
    "        mean_sq_err = self.total_mse / self.total_samples\n",
    "        return mean_sq_err.sqrt()\n",
    "    \n",
    "\n",
    "class QICE(Metric):\n",
    "    def __init__(self, n_bins: int = 10, dist_sync_on_step=False):\n",
    "        super().__init__(dist_sync_on_step=dist_sync_on_step)\n",
    "        self.n_bins = n_bins\n",
    "        # Add states for each quantile's coverage ratio\n",
    "        self.add_state(\"quantile_bin_counts\", default=torch.zeros(self.n_bins), dist_reduce_fx=\"sum\")\n",
    "        self.add_state(\"total_samples\", default=torch.tensor(0), dist_reduce_fx=\"sum\")\n",
    "        \n",
    "    def update(self, preds: torch.Tensor, targets: torch.Tensor):\n",
    "        \"\"\"\n",
    "        Accumulate quantile-bin membership counts for one batch.\n",
    "\n",
    "        Each target is located among the empirical percentiles of its own sample\n",
    "        ensemble, and the resulting interval index is histogrammed into n_bins bins.\n",
    "\n",
    "        Args:\n",
    "            preds: sampled predictions whose LAST dimension indexes the samples,\n",
    "                e.g. (B, O, N, S) or (N, S).\n",
    "            targets: ground truth with one value per ensemble (preds.numel() / S\n",
    "                elements), e.g. (B, O, N) or (N,).\n",
    "\n",
    "        Raises:\n",
    "            ValueError: if targets and preds disagree on the number of ensembles.\n",
    "        \"\"\"\n",
    "        n_samples = preds.size(-1)\n",
    "        # reshape (not view) so non-contiguous inputs are accepted; detach so\n",
    "        # tensors that require grad can be converted to numpy.\n",
    "        preds_np = preds.detach().reshape(-1, n_samples).cpu().numpy()  # (M, S)\n",
    "        targets_np = targets.detach().reshape(-1).cpu().numpy()  # (M,)\n",
    "        if targets_np.shape[0] != preds_np.shape[0]:\n",
    "            raise ValueError(\n",
    "                f\"targets has {targets_np.shape[0]} values but preds has {preds_np.shape[0]} ensembles\")\n",
    "\n",
    "        # Percentiles 0, 100/n_bins, ..., 100 of every ensemble\n",
    "        quantile_list = np.arange(self.n_bins + 1) * (100 / self.n_bins)\n",
    "        y_pred_quantiles = np.percentile(preds_np, q=quantile_list, axis=1)  # (n_bins+1, M)\n",
    "\n",
    "        # Interval index of each target: 0 = below the smallest sample,\n",
    "        # n_bins + 1 = above the largest sample.\n",
    "        y_true_quantile_membership = (targets_np > y_pred_quantiles).sum(axis=0)  # (M,)\n",
    "        bin_counts = np.bincount(y_true_quantile_membership, minlength=self.n_bins + 2)  # (n_bins+2,)\n",
    "\n",
    "        # Fold the two out-of-range buckets into the first and last real bins\n",
    "        bin_counts[1] += bin_counts[0]\n",
    "        bin_counts[-2] += bin_counts[-1]\n",
    "        inner_counts = bin_counts[1:-1]  # (n_bins,)\n",
    "\n",
    "        self.quantile_bin_counts += torch.as_tensor(inner_counts, device=self.device)\n",
    "        self.total_samples += preds_np.shape[0]\n",
    "        \n",
    "    def compute(self):\n",
    "        \"\"\"\n",
    "        Compute the QICE score.\n",
    "\n",
    "        QICE is the mean absolute difference between the observed fraction of\n",
    "        targets falling in each quantile bin and the ideal fraction 1 / n_bins\n",
    "        (lower is better; 0 means perfectly calibrated quantiles).\n",
    "\n",
    "        Returns:\n",
    "            A 0-d tensor holding the QICE score.\n",
    "\n",
    "        Raises:\n",
    "            AssertionError: if the bin ratios do not sum to 1, e.g. when no\n",
    "                samples have been accumulated yet.\n",
    "        \"\"\"\n",
    "        y_true_ratio_by_bin = self.quantile_bin_counts.float() / self.total_samples.item()\n",
    "        assert torch.abs(\n",
    "            torch.sum(y_true_ratio_by_bin) - 1) < 1e-5, \"Sum of quantile coverage ratios shall be 1!\"\n",
    "        # Build the ideal ratio on the same device/dtype as the observed ratios;\n",
    "        # a bare torch.ones(...) lives on the CPU and fails against GPU state.\n",
    "        ideal_ratio = torch.full_like(y_true_ratio_by_bin, 1.0 / self.n_bins)\n",
    "        qice_coverage_ratio = torch.abs(ideal_ratio - y_true_ratio_by_bin).mean()\n",
    "        return qice_coverage_ratio\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "from ts2vec.ts2vec import TS2Vec\n",
    "import scipy\n",
    "def cacf_torch(x, max_lag = 1, dim=(0, 1)):\n",
    "    def get_lower_triangular_indices(n):\n",
    "        return [list(x) for x in torch.tril_indices(n, n)]\n",
    "\n",
    "    ind = get_lower_triangular_indices(x.shape[2])\n",
    "    x = (x - x.mean(dim, keepdims=True)) / x.std(dim, keepdims=True)\n",
    "    x_l = x[..., ind[0]]\n",
    "    x_r = x[..., ind[1]]\n",
    "    cacf_list = list()\n",
    "    for i in range(max_lag):\n",
    "        y = x_l[:, i:] * x_r[:, :-i] if i > 0 else x_l * x_r\n",
    "        cacf_i = torch.mean(y, (1))\n",
    "        cacf_list.append(cacf_i)\n",
    "    cacf = torch.cat(cacf_list, 1)\n",
    "    return cacf.reshape(cacf.shape[0], -1, len(ind[0]))\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "def calculate_fid(act1, act2):\n",
    "    # calculate mean and covariance statistics\n",
    "    mu1, sigma1 = act1.mean(axis=0), np.cov(act1, rowvar=False)\n",
    "    mu2, sigma2 = act2.mean(axis=0), np.cov(act2, rowvar=False)\n",
    "    # calculate sum squared difference between means\n",
    "    ssdiff = np.sum((mu1 - mu2)**2.0)\n",
    "    # calculate sqrt of product between cov\n",
    "    covmean = scipy.linalg.sqrtm(sigma1.dot(sigma2))\n",
    "    # check and correct imaginary numbers from sqrt\n",
    "    if np.iscomplexobj(covmean):\n",
    "        covmean = covmean.real\n",
    "    # calculate score\n",
    "    fid = ssdiff + np.trace(sigma1 + sigma2 - 2.0 * covmean)\n",
    "    return fid\n",
    "\n",
    "def Context_FID(ori_data, generated_data, device=None):\n",
    "    \"\"\"Context-FID: FID between TS2Vec whole-series embeddings.\n",
    "\n",
    "    A TS2Vec encoder is fitted on ori_data; both datasets are then encoded with\n",
    "    encoding_window='full_series' and compared with calculate_fid.\n",
    "\n",
    "    Args:\n",
    "        ori_data: real series; shape[0] is the number of series and shape[-1]\n",
    "            the feature dimension (presumably (n, length, features) -- TODO confirm).\n",
    "        generated_data: generated series; the embeddings at indices < len(ori_data)\n",
    "            are used, so it must contain at least len(ori_data) series.\n",
    "        device: device passed to TS2Vec. When None, falls back to a global named\n",
    "            `device` (the previous implicit behaviour), otherwise to CUDA if\n",
    "            available, else CPU.\n",
    "\n",
    "    Returns:\n",
    "        The Context-FID score (lower is better).\n",
    "    \"\"\"\n",
    "    if device is None:\n",
    "        # Previously this function silently required a global `device`.\n",
    "        device = globals().get(\"device\", \"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "    model = TS2Vec(input_dims=ori_data.shape[-1], device=device, batch_size=512, lr=0.001, output_dims=320,\n",
    "                   max_train_length=3000)\n",
    "    model.fit(ori_data, verbose=False)\n",
    "    ori_representation = model.encode(ori_data, encoding_window='full_series')\n",
    "    gen_representation = model.encode(generated_data, encoding_window='full_series')\n",
    "    # Shuffling does not change the FID (mean/covariance ignore row order); it is kept\n",
    "    # so NumPy RNG consumption matches earlier runs.\n",
    "    idx = np.random.permutation(ori_data.shape[0])\n",
    "    ori_representation = ori_representation[idx]\n",
    "    gen_representation = gen_representation[idx]\n",
    "    return calculate_fid(ori_representation, gen_representation)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "ab635365",
   "metadata": {},
   "outputs": [],
   "source": [
    "EPS = 10e-8  # NOTE(review): 10e-8 == 1e-7, not 1e-8 -- confirm the intended epsilon\n",
    "def make_beta_schedule(schedule=\"linear\", num_timesteps=1000, start=1e-5, end=1e-2):\n",
    "    if schedule == \"linear\":\n",
    "        betas = torch.linspace(start, end, num_timesteps)\n",
    "    elif schedule == \"const\":\n",
    "        betas = end * torch.ones(num_timesteps)\n",
    "    elif schedule == \"quad\":\n",
    "        betas = torch.linspace(start ** 0.5, end ** 0.5, num_timesteps) ** 2\n",
    "    elif schedule == \"jsd\":\n",
    "        betas = 1.0 / torch.linspace(num_timesteps, 1, num_timesteps)\n",
    "    elif schedule == \"sigmoid\":\n",
    "        betas = torch.linspace(-6, 6, num_timesteps)\n",
    "        betas = torch.sigmoid(betas) * (end - start) + start\n",
    "    elif schedule == \"cosine\" or schedule == \"cosine_reverse\":\n",
    "        max_beta = 0.999\n",
    "        cosine_s = 0.008\n",
    "        betas = torch.tensor(\n",
    "            [min(1 - (math.cos(((i + 1) / num_timesteps + cosine_s) / (1 + cosine_s) * math.pi / 2) ** 2) / (\n",
    "                    math.cos((i / num_timesteps + cosine_s) / (1 + cosine_s) * math.pi / 2) ** 2), max_beta) for i in\n",
    "             range(num_timesteps)])\n",
    "        if schedule == \"cosine_reverse\":\n",
    "            betas = betas.flip(0)  # starts at max_beta then decreases fast\n",
    "    elif schedule == \"cosine_anneal\":\n",
    "        betas = torch.tensor(\n",
    "            [start + 0.5 * (end - start) * (1 - math.cos(t / (num_timesteps - 1) * math.pi)) for t in\n",
    "             range(num_timesteps)])\n",
    "    return betas\n",
    "\n",
    "def extract(input, t, x):\n",
    "    shape = x.shape\n",
    "    out = torch.gather(input, 0, t.to(input.device))\n",
    "    reshape = [t.shape[0]] + [1] * (len(shape) - 1)\n",
    "    return out.reshape(*reshape)\n",
    "\n",
    "def q_sample(y, y_0_hat, alphas_bar_sqrt, one_minus_alphas_bar_sqrt, t, noise=None):\n",
    "    \"\"\"\n",
    "    Forward-diffuse y_0 to step t, i.e. draw y_t ~ q(y_t | y_0, x):\n",
    "\n",
    "        y_t = sqrt(abar_t) * y_0 + (1 - sqrt(abar_t)) * y_0_hat + sqrt(1 - abar_t) * noise\n",
    "\n",
    "    Args:\n",
    "        y: clean target y_0, e.g. (B, T, N).\n",
    "        y_0_hat: guidance-model prediction, same shape as y; acts as the prior mean.\n",
    "        alphas_bar_sqrt: sqrt of the cumulative product of alphas, indexed by timestep.\n",
    "        one_minus_alphas_bar_sqrt: sqrt(1 - cumulative product of alphas), indexed by timestep.\n",
    "        t: tensor of timestep indices, one entry per batch element; it is passed to\n",
    "            torch.gather along dim 0, so it must have the same rank as the schedules.\n",
    "        noise: optional noise shaped like y; standard normal noise is drawn when None.\n",
    "\n",
    "    Returns:\n",
    "        y_t, same shape as y.\n",
    "    \"\"\"\n",
    "    if noise is None:\n",
    "        noise = torch.randn_like(y).to(y.device)\n",
    "    signal_scale = extract(alphas_bar_sqrt, t, y)\n",
    "    noise_scale = extract(one_minus_alphas_bar_sqrt, t, y)\n",
    "    # the guidance prediction fills the share of the signal removed by diffusion\n",
    "    mean = signal_scale * y + (1 - signal_scale) * y_0_hat\n",
    "    return mean + noise_scale * noise\n",
    "\n",
    "def dict2namespace(config):\n",
    "    namespace = argparse.Namespace()\n",
    "    for key, value in config.items():\n",
    "        if isinstance(value, dict):\n",
    "            new_value = dict2namespace(value)\n",
    "        else:\n",
    "            new_value = value\n",
    "        setattr(namespace, key, new_value)\n",
    "    return namespace\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "3f8dbf4c",
   "metadata": {},
   "outputs": [],
   "source": [
    "def p_sample(model, x, x_mark, y, y_0_hat, y_T_mean, t, alphas, one_minus_alphas_bar_sqrt):\n",
    "    \"\"\"\n",
    "    Reverse diffusion process sampling -- one time step (y_t -> y_{t-1}).\n",
    "\n",
    "    model: noise-prediction network, called as model(x, x_mark, 0, y, y_0_hat, t).\n",
    "    x, x_mark: conditioning history and its time features, passed through to model.\n",
    "    y: sampled y at time step t, y_t.\n",
    "    y_0_hat: prediction of pre-trained guidance model.\n",
    "    y_T_mean: mean of prior distribution at timestep T.\n",
    "    t: integer schedule index; expected >= 1 because index t - 1 is also looked up.\n",
    "    alphas: per-step alphas (1 - beta), indexed by timestep.\n",
    "    one_minus_alphas_bar_sqrt: sqrt(1 - cumulative product of alphas), indexed by timestep.\n",
    "    Returns a sample of y_{t-1}.\n",
    "\n",
    "    We replace y_0_hat with y_T_mean in the forward process posterior mean computation, emphasizing that \n",
    "        guidance model prediction y_0_hat = f_phi(x) is part of the input to eps_theta network, while \n",
    "        in paper we also choose to set the prior mean at timestep T y_T_mean = f_phi(x).\n",
    "    \"\"\"\n",
    "    device = next(model.parameters()).device\n",
    "    # Fresh noise is always drawn (the zero-noise variant is commented out); the final\n",
    "    # step to y_0 is done separately by p_sample_t_1to0 in p_sample_loop.\n",
    "    z = torch.randn_like(y)  # if t > 1 else torch.zeros_like(y)\n",
    "    # shape (1,): extract then yields (1, 1, ...) coefficients that broadcast over the batch\n",
    "    t = torch.tensor([t]).to(device)\n",
    "    alpha_t = extract(alphas, t, y)\n",
    "    sqrt_one_minus_alpha_bar_t = extract(one_minus_alphas_bar_sqrt, t, y)\n",
    "    sqrt_one_minus_alpha_bar_t_m_1 = extract(one_minus_alphas_bar_sqrt, t - 1, y)\n",
    "    # recover sqrt(abar) from sqrt(1 - abar)\n",
    "    sqrt_alpha_bar_t = (1 - sqrt_one_minus_alpha_bar_t.square()).sqrt()\n",
    "    sqrt_alpha_bar_t_m_1 = (1 - sqrt_one_minus_alpha_bar_t_m_1.square()).sqrt()\n",
    "    # y_t_m_1 posterior mean component coefficients\n",
    "    # gamma_0 multiplies the y_0 estimate, gamma_1 multiplies y_t, gamma_2 multiplies y_T_mean\n",
    "    gamma_0 = (1 - alpha_t) * sqrt_alpha_bar_t_m_1 / (sqrt_one_minus_alpha_bar_t.square())\n",
    "    gamma_1 = (sqrt_one_minus_alpha_bar_t_m_1.square()) * (alpha_t.sqrt()) / (sqrt_one_minus_alpha_bar_t.square())\n",
    "    gamma_2 = 1 + (sqrt_alpha_bar_t - 1) * (alpha_t.sqrt() + sqrt_alpha_bar_t_m_1) / (\n",
    "        sqrt_one_minus_alpha_bar_t.square())\n",
    "    # the literal 0 is passed in the model's third positional slot (TMDM.forward's `y`)\n",
    "    eps_theta = model(x, x_mark, 0, y, y_0_hat, t).to(device).detach()\n",
    "    # y_0 reparameterization\n",
    "    y_0_reparam = 1 / sqrt_alpha_bar_t * (\n",
    "            y - (1 - sqrt_alpha_bar_t) * y_T_mean - eps_theta * sqrt_one_minus_alpha_bar_t)\n",
    "    # posterior mean\n",
    "    y_t_m_1_hat = gamma_0 * y_0_reparam + gamma_1 * y + gamma_2 * y_T_mean\n",
    "    # posterior variance\n",
    "    beta_t_hat = (sqrt_one_minus_alpha_bar_t_m_1.square()) / (sqrt_one_minus_alpha_bar_t.square()) * (1 - alpha_t)\n",
    "    y_t_m_1 = y_t_m_1_hat.to(device) + beta_t_hat.sqrt().to(device) * z.to(device)\n",
    "    return y_t_m_1\n",
    "\n",
    "def p_sample_t_1to0(model, x, x_mark, y, y_0_hat, y_T_mean, one_minus_alphas_bar_sqrt):\n",
    "    \"\"\"Final reverse step: map y_1 to the y_0 estimate (no noise is added).\n",
    "\n",
    "    Uses schedule index 0 (diffusion step t=1) and returns\n",
    "    (y_1 - (1 - sqrt(abar_1)) * y_T_mean - sqrt(1 - abar_1) * eps_theta) / sqrt(abar_1).\n",
    "    \"\"\"\n",
    "    device = next(model.parameters()).device\n",
    "    step = torch.tensor([0]).to(device)  # schedule index 0 == diffusion step t=1\n",
    "    noise_scale = extract(one_minus_alphas_bar_sqrt, step, y)\n",
    "    signal_scale = (1 - noise_scale.square()).sqrt()\n",
    "    eps_theta = model(x, x_mark, 0, y, y_0_hat, step).to(device).detach()\n",
    "    residual = y - (1 - signal_scale) * y_T_mean - eps_theta * noise_scale\n",
    "    return (1 / signal_scale * residual).to(device)\n",
    "\n",
    "\n",
    "def p_sample_loop(model, x, x_mark, y_0_hat, y_T_mean, n_steps, alphas, one_minus_alphas_bar_sqrt):\n",
    "    \"\"\"Run the full reverse chain starting from y_T ~ N(y_T_mean, I).\n",
    "\n",
    "    Returns:\n",
    "        List of n_steps + 1 tensors: y_T, the samples produced by p_sample for\n",
    "        schedule indices n_steps - 1 down to 1, and finally the y_0 estimate.\n",
    "    \"\"\"\n",
    "    device = next(model.parameters()).device\n",
    "    y_T = torch.randn_like(y_T_mean).to(device) + y_T_mean  # sample y_T\n",
    "    trajectory = [y_T]\n",
    "    for step in range(n_steps - 1, 0, -1):\n",
    "        trajectory.append(\n",
    "            p_sample(model, x, x_mark, trajectory[-1], y_0_hat, y_T_mean, step, alphas, one_minus_alphas_bar_sqrt))\n",
    "    assert len(trajectory) == n_steps\n",
    "    trajectory.append(\n",
    "        p_sample_t_1to0(model, x, x_mark, trajectory[-1], y_0_hat, y_T_mean, one_minus_alphas_bar_sqrt))\n",
    "    return trajectory\n",
    "\n",
    "def kld(y1, y2, grid=(-20, 20), num_grid=400):\n",
    "    y1, y2 = y1.numpy().flatten(), y2.numpy().flatten()\n",
    "    p_y1, _ = np.histogram(y1, bins=num_grid, range=[grid[0], grid[1]], density=True)\n",
    "    p_y1 += 1e-7\n",
    "    p_y2, _ = np.histogram(y2, bins=num_grid, range=[grid[0], grid[1]], density=True)\n",
    "    p_y2 += 1e-7\n",
    "    return (p_y1 * np.log(p_y1 / p_y2)).sum()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "de979555",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "from torch_timeseries.nn.Transformer_EncDec import Decoder, DecoderLayer, Encoder, EncoderLayer\n",
    "from torch_timeseries.nn.SelfAttention_Family import DSAttention, AttentionLayer\n",
    "from torch_timeseries.nn.embedding import DataEmbedding\n",
    "\n",
    "\n",
    "class Projector(nn.Module):\n",
    "    '''\n",
    "    MLP to learn the De-stationary factors\n",
    "    '''\n",
    "\n",
    "    def __init__(self, enc_in, seq_len, hidden_dims, hidden_layers, output_dim, kernel_size=3):\n",
    "        super(Projector, self).__init__()\n",
    "\n",
    "        padding = 1 if torch.__version__ >= '1.5.0' else 2\n",
    "        self.series_conv = nn.Conv1d(in_channels=seq_len, out_channels=1, kernel_size=kernel_size, padding=padding,\n",
    "                                     padding_mode='circular', bias=False)\n",
    "\n",
    "        layers = [nn.Linear(2 * enc_in, hidden_dims[0]), nn.ReLU()]\n",
    "        for i in range(hidden_layers - 1):\n",
    "            layers += [nn.Linear(hidden_dims[i], hidden_dims[i + 1]), nn.ReLU()]\n",
    "\n",
    "        layers += [nn.Linear(hidden_dims[-1], output_dim, bias=False)]\n",
    "        self.backbone = nn.Sequential(*layers)\n",
    "\n",
    "    def forward(self, x, stats):\n",
    "        # x:     B x S x E\n",
    "        # stats: B x 1 x E\n",
    "        # y:     B x O\n",
    "        batch_size = x.shape[0]\n",
    "        x = self.series_conv(x)  # B x 1 x E\n",
    "        x = torch.cat([x, stats], dim=1)  # B x 2 x E\n",
    "        x = x.view(batch_size, -1)  # B x 2E\n",
    "        y = self.backbone(x)  # B x O\n",
    "\n",
    "        return y\n",
    "\n",
    "\n",
    "class Model(nn.Module):\n",
    "    \"\"\"\n",
    "    Non-stationary Transformer\n",
    "\n",
    "    Variant with a Gaussian latent bottleneck: the encoder output is mapped to a\n",
    "    mean / log-variance, a latent z is drawn (see reparameterize) and passed through\n",
    "    z_out before being used as the decoder's cross-attention memory. forward() also\n",
    "    returns the KL term of that latent.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, configs):\n",
    "        \"\"\"Build the model. configs is read for: pred_len, seq_len, label_len,\n",
    "        output_attention, enc_in, dec_in, c_out, d_model, n_heads, d_ff, e_layers,\n",
    "        d_layers, factor, dropout, activation, embed, freq, p_hidden_dims and\n",
    "        p_hidden_layers.\"\"\"\n",
    "        super(Model, self).__init__()\n",
    "        self.pred_len = configs.pred_len\n",
    "        self.seq_len = configs.seq_len\n",
    "        self.label_len = configs.label_len\n",
    "        self.output_attention = configs.output_attention\n",
    "\n",
    "        # Embedding\n",
    "        self.enc_embedding = DataEmbedding(configs.enc_in, configs.d_model, configs.embed, configs.freq,\n",
    "                                           configs.dropout)\n",
    "        self.dec_embedding = DataEmbedding(configs.dec_in, configs.d_model, configs.embed, configs.freq,\n",
    "                                           configs.dropout)\n",
    "        # Encoder\n",
    "        self.encoder = Encoder(\n",
    "            [\n",
    "                EncoderLayer(\n",
    "                    AttentionLayer(\n",
    "                        DSAttention(False, configs.factor, attention_dropout=configs.dropout,\n",
    "                                    output_attention=configs.output_attention), configs.d_model, configs.n_heads),\n",
    "                    configs.d_model,\n",
    "                    configs.d_ff,\n",
    "                    dropout=configs.dropout,\n",
    "                    activation=configs.activation\n",
    "                ) for l in range(configs.e_layers)\n",
    "            ],\n",
    "            norm_layer=torch.nn.LayerNorm(configs.d_model)\n",
    "        )\n",
    "        # Decoder\n",
    "        self.decoder = Decoder(\n",
    "            [\n",
    "                DecoderLayer(\n",
    "                    AttentionLayer(\n",
    "                        DSAttention(True, configs.factor, attention_dropout=configs.dropout, output_attention=False),\n",
    "                        configs.d_model, configs.n_heads),\n",
    "                    AttentionLayer(\n",
    "                        DSAttention(False, configs.factor, attention_dropout=configs.dropout, output_attention=False),\n",
    "                        configs.d_model, configs.n_heads),\n",
    "                    configs.d_model,\n",
    "                    configs.d_ff,\n",
    "                    dropout=configs.dropout,\n",
    "                    activation=configs.activation,\n",
    "                )\n",
    "                for l in range(configs.d_layers)\n",
    "            ],\n",
    "            norm_layer=torch.nn.LayerNorm(configs.d_model),\n",
    "            projection=nn.Linear(configs.d_model, configs.c_out, bias=True)\n",
    "        )\n",
    "\n",
    "        # De-stationary factors: tau gets one value per sample (made positive with exp()\n",
    "        # in forward), delta gets one value per input step (output_dim=seq_len).\n",
    "        self.tau_learner = Projector(enc_in=configs.enc_in, seq_len=configs.seq_len, hidden_dims=configs.p_hidden_dims,\n",
    "                                     hidden_layers=configs.p_hidden_layers, output_dim=1)\n",
    "        self.delta_learner = Projector(enc_in=configs.enc_in, seq_len=configs.seq_len,\n",
    "                                       hidden_dims=configs.p_hidden_dims, hidden_layers=configs.p_hidden_layers,\n",
    "                                       output_dim=configs.seq_len)\n",
    "\n",
    "        # Heads producing the mean and log-variance of the latent from the encoder output\n",
    "        self.z_mean = nn.Sequential(\n",
    "            nn.Linear(configs.d_model, configs.d_model),\n",
    "            nn.ReLU(),\n",
    "            nn.Linear(configs.d_model, configs.d_model)\n",
    "        )\n",
    "        self.z_logvar = nn.Sequential(\n",
    "            nn.Linear(configs.d_model, configs.d_model),\n",
    "            nn.ReLU(),\n",
    "            nn.Linear(configs.d_model, configs.d_model)\n",
    "        )\n",
    "\n",
    "        # Maps the latent sample to the memory consumed by the decoder\n",
    "        self.z_out = nn.Sequential(\n",
    "            nn.Linear(configs.d_model, configs.d_model),\n",
    "            nn.ReLU(),\n",
    "            nn.Linear(configs.d_model, configs.d_model)\n",
    "        )\n",
    "\n",
    "    def KL_loss_normal(self, posterior_mean, posterior_logvar):\n",
    "        \"\"\"KL(N(mean, exp(logvar)) || N(0, 1)) per element, averaged (not summed)\n",
    "        over dim 1 and then over all remaining elements.\"\"\"\n",
    "        KL = -0.5 * torch.mean(1 - posterior_mean ** 2 + posterior_logvar -\n",
    "                               torch.exp(posterior_logvar), dim=1)\n",
    "        return torch.mean(KL)\n",
    "\n",
    "    def reparameterize(self, posterior_mean, posterior_logvar):\n",
    "        \"\"\"Return the latent used downstream.\n",
    "\n",
    "        Training: draws 100 reparameterised samples and returns their average, so the\n",
    "        injected noise has 1/100 of the posterior variance. Eval: returns the mean.\n",
    "        NOTE(review): repeat(100, 1, 1, 1) assumes 3-D inputs (B, L, d_model) -- confirm.\n",
    "        \"\"\"\n",
    "        posterior_var = posterior_logvar.exp()\n",
    "        # take sample\n",
    "        if self.training:\n",
    "            posterior_mean = posterior_mean.repeat(100, 1, 1, 1)\n",
    "            posterior_var = posterior_var.repeat(100, 1, 1, 1)\n",
    "            eps = torch.zeros_like(posterior_var).normal_()\n",
    "            z = posterior_mean + posterior_var.sqrt() * eps  # reparameterization\n",
    "            z = z.mean(0)\n",
    "        else:\n",
    "            z = posterior_mean\n",
    "        # z = posterior_mean\n",
    "        return z\n",
    "\n",
    "    def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec,\n",
    "                enc_self_mask=None, dec_self_mask=None, dec_enc_mask=None):\n",
    "        \"\"\"\n",
    "        Returns:\n",
    "            If output_attention: (prediction, attns).\n",
    "            Otherwise: (prediction, full decoder output, KL of the latent, latent z),\n",
    "            where prediction is the last pred_len steps of the decoder output.\n",
    "            Outputs stay in the per-window normalised scale (the de-normalisation\n",
    "            line below is commented out).\n",
    "        \"\"\"\n",
    "\n",
    "        x_raw = x_enc.clone().detach()\n",
    "\n",
    "        # Normalization\n",
    "        mean_enc = x_enc.mean(1, keepdim=True).detach()  # B x 1 x E\n",
    "        x_enc = x_enc - mean_enc\n",
    "        std_enc = torch.sqrt(torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5).detach()  # B x 1 x E\n",
    "        x_enc = x_enc / std_enc\n",
    "        # Decoder input: last label_len normalised steps followed by pred_len zeros\n",
    "        x_dec_new = torch.cat([x_enc[:, -self.label_len:, :], torch.zeros_like(x_dec[:, -self.pred_len:, :])],\n",
    "                              dim=1).to(x_enc.device).clone()\n",
    "\n",
    "        # tau and delta are learned from the raw (un-normalised) input\n",
    "        tau = self.tau_learner(x_raw, std_enc).exp()  # B x S x E, B x 1 x E -> B x 1, positive scalar\n",
    "        delta = self.delta_learner(x_raw, mean_enc)  # B x S x E, B x 1 x E -> B x S\n",
    "\n",
    "        # Model Inference\n",
    "        enc_out = self.enc_embedding(x_enc, x_mark_enc)\n",
    "        enc_out, attns = self.encoder(enc_out, attn_mask=enc_self_mask, tau=tau, delta=delta)\n",
    "\n",
    "        mean = self.z_mean(enc_out)\n",
    "        logvar = self.z_logvar(enc_out)\n",
    "\n",
    "        z_sample = self.reparameterize(mean, logvar)\n",
    "\n",
    "        # dec_out = self.z_out(torch.cat([z_sample, dec_out], dim=-1))\n",
    "        enc_out = self.z_out(z_sample)\n",
    "\n",
    "        KL_z = self.KL_loss_normal(mean, logvar)\n",
    "\n",
    "        dec_out = self.dec_embedding(x_dec_new, x_mark_dec)\n",
    "        dec_out = self.decoder(dec_out, enc_out, x_mask=dec_self_mask, cross_mask=dec_enc_mask, tau=tau, delta=delta)\n",
    "\n",
    "        # De-normalization\n",
    "        # dec_out = dec_out * std_enc + mean_enc ###################################### please must # it\n",
    "        # (kept disabled on purpose: outputs remain in the normalised scale)\n",
    "\n",
    "        if self.output_attention:\n",
    "            return dec_out[:, -self.pred_len:, :], attns\n",
    "        else:\n",
    "            return dec_out[:, -self.pred_len:, :], dec_out, KL_z, z_sample  # [B, L, D]\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "5343846c",
   "metadata": {},
   "outputs": [],
   "source": [
    "class ConditionalLinear(nn.Module):\n",
    "    def __init__(self, num_in, num_out, n_steps):\n",
    "        super(ConditionalLinear, self).__init__()\n",
    "        self.num_out = num_out\n",
    "        self.lin = nn.Linear(num_in, num_out)\n",
    "        self.embed = nn.Embedding(n_steps, num_out)\n",
    "        self.embed.weight.data.uniform_()\n",
    "\n",
    "    def forward(self, x, t):\n",
    "        # x: (B, T, num_in) or (B, num_in)\n",
    "        # t: (B, ) or (B, 1)\n",
    "        out = self.lin(x)\n",
    "        gamma = self.embed(t)\n",
    "        out = gamma.view(t.size()[0], -1, self.num_out) * out\n",
    "        return out # (B, T, num_out) or (B, num_out)\n",
    "\n",
    "class ConditionalGuidedModel(nn.Module):\n",
    "    \"\"\"Noise-prediction MLP: three timestep-conditioned layers with softplus, then a Linear.\n",
    "\n",
    "    Input features are y_t, optionally concatenated with y_0_hat (cat_y_pred) or with\n",
    "    x (cat_x without cat_y_pred). NOTE(review): when cat_x and cat_y_pred are both\n",
    "    set, x is not used; and when only cat_x is set, lin1's input width (enc_in) does\n",
    "    not account for x's width -- confirm against the config in use.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, config, MTS_args):\n",
    "        super().__init__()\n",
    "        # one extra embedding row so that indices 0 .. timesteps are all valid\n",
    "        n_steps = config.diffusion.timesteps + 1\n",
    "        self.cat_x = config.model.cat_x\n",
    "        self.cat_y_pred = config.model.cat_y_pred\n",
    "\n",
    "        n_features = MTS_args.enc_in\n",
    "        in_width = 2 * n_features if self.cat_y_pred else n_features\n",
    "        hidden = 128\n",
    "\n",
    "        self.lin1 = ConditionalLinear(in_width, hidden, n_steps)\n",
    "        self.lin2 = ConditionalLinear(hidden, hidden, n_steps)\n",
    "        self.lin3 = ConditionalLinear(hidden, hidden, n_steps)\n",
    "        self.lin4 = nn.Linear(hidden, n_features)\n",
    "\n",
    "    def forward(self, x, y_t, y_0_hat, t):\n",
    "        \"\"\"\n",
    "        Args:\n",
    "            x, y_t, y_0_hat: (B, T, *) tensors.\n",
    "            t: (B,) integer timesteps.\n",
    "\n",
    "        Returns:\n",
    "            (B, T, enc_in) network output.\n",
    "        \"\"\"\n",
    "        if self.cat_y_pred:\n",
    "            features = torch.cat((y_t, y_0_hat), dim=-1 if self.cat_x else 2)\n",
    "        elif self.cat_x:\n",
    "            features = torch.cat((y_t, x), dim=2)\n",
    "        else:\n",
    "            features = y_t\n",
    "\n",
    "        on_mps = y_t.device.type == 'mps'\n",
    "        for layer in (self.lin1, self.lin2, self.lin3):\n",
    "            features = layer(features, t)\n",
    "            if on_mps:\n",
    "                # softplus is evaluated on the CPU for MPS tensors, then moved back\n",
    "                features = F.softplus(features.cpu()).to(y_t.device)\n",
    "            else:\n",
    "                features = F.softplus(features)\n",
    "        return self.lin4(features)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "675a00f6",
   "metadata": {},
   "outputs": [],
   "source": [
    "class DeterministicFeedForwardNeuralNetwork(nn.Module):\n",
    "\n",
    "    def __init__(self, dim_in, dim_out, hid_layers,\n",
    "                 use_batchnorm=False, negative_slope=0.01, dropout_rate=0):\n",
    "        super(DeterministicFeedForwardNeuralNetwork, self).__init__()\n",
    "        self.dim_in = dim_in  # dimension of nn input\n",
    "        self.dim_out = dim_out  # dimension of nn output\n",
    "        self.hid_layers = hid_layers  # nn hidden layer architecture\n",
    "        self.nn_layers = [self.dim_in] + self.hid_layers  # nn hidden layer architecture, except output layer\n",
    "        self.use_batchnorm = use_batchnorm  # whether apply batch norm\n",
    "        self.negative_slope = negative_slope  # negative slope for LeakyReLU\n",
    "        self.dropout_rate = dropout_rate\n",
    "        layers = self.create_nn_layers()\n",
    "        self.network = nn.Sequential(*layers)\n",
    "\n",
    "    def create_nn_layers(self):\n",
    "        layers = []\n",
    "        for idx in range(len(self.nn_layers) - 1):\n",
    "            layers.append(nn.Linear(self.nn_layers[idx], self.nn_layers[idx + 1]))\n",
    "            if self.use_batchnorm:\n",
    "                layers.append(nn.BatchNorm1d(self.nn_layers[idx + 1]))\n",
    "            layers.append(nn.LeakyReLU(negative_slope=self.negative_slope))\n",
    "            layers.append(nn.Dropout(p=self.dropout_rate))\n",
    "        layers.append(nn.Linear(self.nn_layers[-1], self.dim_out))\n",
    "        return layers\n",
    "\n",
    "    def forward(self, x):\n",
    "        return self.network(x)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1be65a27",
   "metadata": {},
   "source": [
    "#### TMDM"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "e92c98f3",
   "metadata": {},
   "outputs": [],
   "source": [
    "class TMDM(nn.Module):\n",
    "    \"\"\"\n",
    "    Noise-prediction network for the TMDM diffusion model.\n",
    "\n",
    "    On construction it loads a YAML diffusion config, builds the beta / alpha schedule\n",
    "    tensors on `device`, and creates a ConditionalGuidedModel plus a DataEmbedding\n",
    "    for the conditioning history. forward() returns the network output, which\n",
    "    p_sample uses as eps_theta.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, configs, device):\n",
    "        super(TMDM, self).__init__()\n",
    "\n",
    "        # NOTE(review): yaml.unsafe_load can construct arbitrary Python objects; only\n",
    "        # load trusted config files (yaml.safe_load suffices for plain-data configs).\n",
    "        with open(configs.diffusion_config_dir, \"r\") as f:\n",
    "            config = yaml.unsafe_load(f)\n",
    "            diffusion_config = dict2namespace(config)\n",
    "\n",
    "        # the experiment's timesteps override the value from the YAML file\n",
    "        diffusion_config.diffusion.timesteps = configs.timesteps\n",
    "        \n",
    "        self.args = configs\n",
    "        self.device = device\n",
    "        self.diffusion_config = diffusion_config\n",
    "\n",
    "        self.model_var_type = diffusion_config.model.var_type\n",
    "        self.num_timesteps = diffusion_config.diffusion.timesteps\n",
    "        self.vis_step = diffusion_config.diffusion.vis_step\n",
    "        self.num_figs = diffusion_config.diffusion.num_figs\n",
    "        self.dataset_object = None\n",
    "\n",
    "        betas = make_beta_schedule(schedule=diffusion_config.diffusion.beta_schedule, num_timesteps=self.num_timesteps,\n",
    "                                   start=diffusion_config.diffusion.beta_start, end=diffusion_config.diffusion.beta_end)\n",
    "        betas = self.betas = betas.float().to(self.device)\n",
    "        self.betas_sqrt = torch.sqrt(betas)\n",
    "        alphas = 1.0 - betas\n",
    "        self.alphas = alphas\n",
    "        self.one_minus_betas_sqrt = torch.sqrt(alphas)\n",
    "        # cumprod is computed on the CPU and moved back to self.device\n",
    "        # (presumably a backend-support workaround -- TODO confirm)\n",
    "        alphas_cumprod = alphas.to('cpu').cumprod(dim=0).to(self.device)\n",
    "        self.alphas_bar_sqrt = torch.sqrt(alphas_cumprod)\n",
    "        self.one_minus_alphas_bar_sqrt = torch.sqrt(1 - alphas_cumprod)\n",
    "        if diffusion_config.diffusion.beta_schedule == \"cosine\":\n",
    "            self.one_minus_alphas_bar_sqrt *= 0.9999  # avoid division by 0 for 1/sqrt(alpha_bar_t) during inference\n",
    "        alphas_cumprod_prev = torch.cat(\n",
    "            [torch.ones(1, device=self.device), alphas_cumprod[:-1]], dim=0\n",
    "        )\n",
    "        self.alphas_cumprod_prev = alphas_cumprod_prev\n",
    "        # DDPM posterior q(y_{t-1} | y_t, y_0): mean coefficients and variance\n",
    "        self.posterior_mean_coeff_1 = (\n",
    "                betas * torch.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod)\n",
    "        )\n",
    "        self.posterior_mean_coeff_2 = (\n",
    "                torch.sqrt(alphas) * (1 - alphas_cumprod_prev) / (1 - alphas_cumprod)\n",
    "        )\n",
    "        posterior_variance = (\n",
    "                betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)\n",
    "        )\n",
    "        self.posterior_variance = posterior_variance\n",
    "        # NOTE(review): self.logvar is only set for 'fixedlarge' / 'fixedsmall';\n",
    "        # any other var_type leaves it undefined.\n",
    "        if self.model_var_type == \"fixedlarge\":\n",
    "            self.logvar = betas.log()\n",
    "            # torch.cat(\n",
    "            # [posterior_variance[1:2], betas[1:]], dim=0).log()\n",
    "        elif self.model_var_type == \"fixedsmall\":\n",
    "            self.logvar = posterior_variance.clamp(min=1e-20).log()\n",
    "\n",
    "        self.tau = None  # precision for test NLL computation\n",
    "\n",
    "        # CATE MLP\n",
    "        self.diffussion_model = ConditionalGuidedModel(diffusion_config, self.args)\n",
    "\n",
    "        self.enc_embedding = DataEmbedding(configs.enc_in, configs.CART_input_x_embed_dim, configs.embed, configs.freq,\n",
    "                                           configs.dropout)\n",
    "\n",
    "\n",
    "    def forward(self, x, x_mark, y, y_t, y_0_hat, t):\n",
    "        # NOTE: `y` is not used in this forward pass.\n",
    "        enc_out = self.enc_embedding(x, x_mark) #  B, T, CART_input_x_embed_dim (assumed DataEmbedding output width)\n",
    "        dec_out = self.diffussion_model(enc_out, y_t, y_0_hat, t)\n",
    "\n",
    "        return dec_out"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "542ca1eb",
   "metadata": {},
   "outputs": [],
   "source": [
    "from dataclasses import asdict, dataclass\n",
    "import datetime\n",
    "import hashlib\n",
    "import json\n",
    "import os\n",
    "import random\n",
    "import time\n",
    "from typing import Dict, List, Type, Union\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import torch\n",
    "from torchmetrics import MeanAbsoluteError, MeanSquaredError, MetricCollection\n",
    "from tqdm import tqdm\n",
    "from torch.nn import MSELoss, L1Loss\n",
    "from torch.optim import *\n",
    "from torch_timeseries.dataset import *\n",
    "from torch_timeseries.scaler import *\n",
    "\n",
    "from torch_timeseries.utils.model_stats import count_parameters\n",
    "from torch_timeseries.utils.early_stop import EarlyStopping\n",
    "from torch_timeseries.utils.parse_type import parse_type\n",
    "from torch_timeseries.utils.reproduce import reproducible\n",
    "from torch_timeseries.core import TimeSeriesDataset, BaseIrrelevant, BaseRelevant\n",
    "from torch_timeseries.dataloader import SlidingWindowTS, ETTHLoader, ETTMLoader\n",
    "from torch_timeseries.experiments import ForecastExp\n",
    "from torch_timeseries.utils import asdict_exc\n",
    "import torch.multiprocessing as mp\n",
    "\n",
    "\n",
    "def update_metrics(preds, truths, metrics):\n",
    "    \"\"\"Function to update metrics in a separate process.\"\"\"\n",
    "    metrics.update(preds, truths)\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "@dataclass\n",
    "class ProbForecastExp(ForecastExp):\n",
    "    \"\"\"Probabilistic forecasting experiment built on ``ForecastExp``.\n",
    "\n",
    "    Subclasses implement ``_process_train_batch`` / ``_process_val_batch``. This class\n",
    "    drives the epoch loop, evaluation with probabilistic metrics (CRPS, CRPS-sum, QICE,\n",
    "    PICP, MSE, MAE, RMSE), run checkpointing/resuming, early stopping on validation CRPS\n",
    "    and optional wandb logging.\n",
    "\n",
    "    NOTE(review): ``_train`` and ``_evaluate`` take their model inputs from the\n",
    "    notebook-level globals ``future_cw_train`` / ``future_cw_val`` rather than from the\n",
    "    loader batches; those globals must exist before ``run`` is called.\n",
    "    \"\"\"\n",
    "    loss_func_type : str = 'mse'\n",
    "    epochs : int = 10\n",
    "\n",
    "    def _init_metrics(self):\n",
    "        \"\"\"Create the CPU-side probabilistic metric collection and a process pool.\"\"\"\n",
    "        self.metrics = MetricCollection(\n",
    "            metrics={\n",
    "                \"crps\": CRPS(),\n",
    "                \"crps_sum\": CRPSSum(),\n",
    "                \"qice\": QICE(),\n",
    "                \"picp\": PICP(),\n",
    "                \"mse\": ProbMSE(),\n",
    "                \"mae\": ProbMAE(),\n",
    "                \"rmse\": ProbRMSE(),\n",
    "            }\n",
    "        )\n",
    "        self.metrics.to(\"cpu\")\n",
    "        # NOTE(review): no method of this class uses task_pool and it is never closed;\n",
    "        # spawning 32 processes per experiment is costly -- confirm it is still needed.\n",
    "        ctx = mp.get_context(\"spawn\")  # Options: 'fork', 'spawn', 'forkserver'\n",
    "        self.task_pool = ctx.Pool(processes=32)\n",
    "\n",
    "    def _init_dataset(self):\n",
    "        \"\"\"Instantiate the dataset class named by ``dataset_type`` rooted at ``data_path``.\"\"\"\n",
    "        self.dataset = parse_type(self.dataset_type, globals())(\n",
    "            root=self.data_path\n",
    "        )\n",
    "\n",
    "    def _train(self):\n",
    "        \"\"\"Run one training epoch and return the list of per-batch loss values.\n",
    "\n",
    "        Model inputs come from ``future_cw_train[i]`` (i = batch index); only ``origin_y``\n",
    "        is taken from the train loader and it is the target when ``invtrans_loss`` is set.\n",
    "        \"\"\"\n",
    "        with torch.enable_grad(), tqdm(total=len(self.train_loader.dataset)) as progress_bar:\n",
    "            self.model.train()\n",
    "            train_loss = []\n",
    "            for i, (\n",
    "                batch_x,\n",
    "                batch_y,\n",
    "                origin_x,\n",
    "                origin_y,\n",
    "                batch_x_date_enc,\n",
    "                batch_y_date_enc,\n",
    "            ) in enumerate(self.train_loader):\n",
    "                origin_y = origin_y.to(self.device).float()\n",
    "                # Loader inputs are replaced by the precomputed batch with the same index.\n",
    "                cached = future_cw_train[i]\n",
    "                batch_x = cached['batch_history'].to(self.device).float()\n",
    "                batch_y = cached['batch_future_cw'].to(self.device).float()\n",
    "                batch_x_date_enc = cached['batch_history_mark'].to(self.device).float()\n",
    "                batch_y_date_enc = cached['batch_future_mark'].to(self.device).float()\n",
    "                self.model_optim.zero_grad()\n",
    "                pred, true = self._process_train_batch(\n",
    "                    batch_x, batch_y, batch_x_date_enc, batch_y_date_enc\n",
    "                )\n",
    "                if self.invtrans_loss:\n",
    "                    pred = self.scaler.inverse_transform(pred)\n",
    "                    true = origin_y\n",
    "                loss = self.loss_func(pred, true)\n",
    "                loss.backward()\n",
    "\n",
    "                torch.nn.utils.clip_grad_norm_(\n",
    "                    self.model.parameters(), self.max_grad_norm\n",
    "                )\n",
    "\n",
    "                progress_bar.update(batch_x.size(0))\n",
    "\n",
    "                train_loss.append(loss.item())\n",
    "                progress_bar.set_postfix(\n",
    "                    loss=loss.item(),\n",
    "                    lr=self.model_optim.param_groups[0][\"lr\"],\n",
    "                    epoch=self.current_epoch,\n",
    "                    refresh=True,\n",
    "                )\n",
    "                self.model_optim.step()\n",
    "\n",
    "            return train_loss\n",
    "\n",
    "    def _process_train_batch(\n",
    "        self,\n",
    "        batch_x,\n",
    "        batch_y,\n",
    "        batch_origin_x,\n",
    "        batch_origin_y,\n",
    "        batch_x_date_enc,\n",
    "        batch_y_date_enc,\n",
    "    ):\n",
    "        \"\"\"Abstract: return ``(pred, label)`` for one training batch.\n",
    "\n",
    "        NOTE(review): ``_train`` calls this with four positional arguments,\n",
    "        ``(batch_x, batch_y, batch_x_date_enc, batch_y_date_enc)``; overrides should\n",
    "        follow that call convention rather than the six-parameter signature here.\n",
    "        \"\"\"\n",
    "        # inputs:\n",
    "        # batch_x:  (B, T, N)\n",
    "        # batch_y:  (B, Steps,T)\n",
    "        # batch_x_date_enc:  (B, T, N)\n",
    "        # batch_y_date_enc:  (B, T, Steps)\n",
    "\n",
    "        # outputs:\n",
    "        # pred: (B, O, N)\n",
    "        # label:  (B,O,N)\n",
    "        # for single step you should output (B, N)\n",
    "        # for multiple steps you should output (B, O, N)\n",
    "        raise NotImplementedError()\n",
    "\n",
    "    # batch\n",
    "    def _process_val_batch(\n",
    "        self,\n",
    "        batch_x,\n",
    "        batch_origin_x,\n",
    "        batch_x_date_enc,\n",
    "        batch_y_date_enc,\n",
    "    ):\n",
    "        \"\"\"Abstract: return ``(preds, truths)`` for one evaluation batch.\n",
    "\n",
    "        NOTE(review): ``_evaluate`` passes ``(batch_x, batch_y, batch_x_date_enc,\n",
    "        batch_y_date_enc)`` positionally, so the second argument is ``batch_y``,\n",
    "        not ``batch_origin_x`` as named here.\n",
    "        \"\"\"\n",
    "        # inputs:\n",
    "        # batch_x:  (B, T, N)\n",
    "        # batch_y:  (B, Steps,T)\n",
    "        # batch_x_date_enc:  (B, T, N)\n",
    "        # batch_y_date_enc:  (B, T, Steps)\n",
    "\n",
    "        # outputs:\n",
    "        # pred: (B, O, N)\n",
    "        # label:  (B,O,N)\n",
    "        # for single step you should output (B, N)\n",
    "        # for multiple steps you should output (B, O, N)\n",
    "        raise NotImplementedError()\n",
    "\n",
    "    def _evaluate(self, dataloader, out_fid = False):\n",
    "        \"\"\"Evaluate the model on ``dataloader`` and return ``{metric_name: float}``.\n",
    "\n",
    "        Runs under ``torch.no_grad()`` because evaluation never back-propagates.\n",
    "        ``out_fid`` is accepted for interface compatibility and is unused.\n",
    "\n",
    "        NOTE(review): every iteration reads the same cached batch ``future_cw_val[0]``,\n",
    "        whichever loader (val or test) is passed; the loader only supplies ``origin_y``\n",
    "        (the target when ``invtrans_loss`` is set). Confirm this is intended.\n",
    "        \"\"\"\n",
    "        self.model.eval()\n",
    "        self.metrics.reset()\n",
    "        with torch.no_grad(), tqdm(total=len(dataloader.dataset)) as progress_bar:\n",
    "            for batch_x, batch_y, origin_x, origin_y, batch_x_date_enc, batch_y_date_enc in dataloader:\n",
    "                origin_y = origin_y.to(self.device).float()\n",
    "                cached = future_cw_val[0]\n",
    "                batch_x = cached['batch_history'].to(self.device).float()\n",
    "                batch_y = cached['batch_future_cw'].to(self.device).float()\n",
    "                batch_x_date_enc = cached['batch_history_mark'].to(self.device).float()\n",
    "                batch_y_date_enc = cached['batch_future_mark'].to(self.device).float()\n",
    "                preds, truths = self._process_val_batch(\n",
    "                    batch_x, batch_y, batch_x_date_enc, batch_y_date_enc\n",
    "                )\n",
    "                if self.invtrans_loss:\n",
    "                    preds = self.scaler.inverse_transform(preds)\n",
    "                    truths = origin_y\n",
    "\n",
    "                self.metrics.update(preds.contiguous().cpu().detach(), truths.contiguous().cpu().detach())\n",
    "\n",
    "                progress_bar.update(batch_x.shape[0])\n",
    "\n",
    "        result = {name: float(metric.compute()) for name, metric in self.metrics.items()}\n",
    "        return result\n",
    "\n",
    "    def _init_data_loader(self, shuffle=True, fast_test=True, fast_val=True):\n",
    "        \"\"\"Build the dataset, the scaler and the train/val/test loaders.\n",
    "\n",
    "        ETTh* / ETTm* datasets use their dedicated loaders; every other dataset uses\n",
    "        ``SlidingWindowTS`` with the ``train_ratio`` / ``test_ratio`` split.\n",
    "\n",
    "        Raises:\n",
    "            ValueError: if ``dataset_type`` starts with \"ETT\" but is neither ETTh* nor\n",
    "                ETTm*, instead of leaving ``self.dataloader`` unset or stale.\n",
    "        \"\"\"\n",
    "        self._init_dataset()\n",
    "\n",
    "        self.scaler = parse_type(self.scaler_type, globals=globals())()\n",
    "        if self.dataset_type.startswith(\"ETT\"):\n",
    "            if self.dataset_type.startswith(\"ETTh\"):\n",
    "                loader_cls = ETTHLoader\n",
    "            elif self.dataset_type.startswith(\"ETTm\"):\n",
    "                loader_cls = ETTMLoader\n",
    "            else:\n",
    "                raise ValueError(f\"Unsupported ETT dataset type: {self.dataset_type}\")\n",
    "            self.dataloader = loader_cls(\n",
    "                self.dataset,\n",
    "                self.scaler,\n",
    "                window=self.windows,\n",
    "                horizon=self.horizon,\n",
    "                steps=self.pred_len,\n",
    "                shuffle_train=shuffle,\n",
    "                freq=self.dataset.freq,\n",
    "                batch_size=self.batch_size,\n",
    "                num_worker=self.num_worker,\n",
    "                fast_test=fast_test,\n",
    "                fast_val=fast_val,\n",
    "            )\n",
    "        else:\n",
    "            self.dataloader = SlidingWindowTS(\n",
    "                self.dataset,\n",
    "                self.scaler,\n",
    "                window=self.windows,\n",
    "                horizon=self.horizon,\n",
    "                steps=self.pred_len,\n",
    "                scale_in_train=True,\n",
    "                shuffle_train=shuffle,\n",
    "                freq=self.dataset.freq,\n",
    "                batch_size=self.batch_size,\n",
    "                train_ratio=self.train_ratio,\n",
    "                test_ratio=self.test_ratio,\n",
    "                num_worker=self.num_worker,\n",
    "                fast_test=fast_test,\n",
    "                fast_val=fast_val,\n",
    "            )\n",
    "\n",
    "        self.train_loader, self.val_loader, self.test_loader = (\n",
    "            self.dataloader.train_loader,\n",
    "            self.dataloader.val_loader,\n",
    "            self.dataloader.test_loader,\n",
    "        )\n",
    "        self.train_steps = len(self.train_loader.dataset)\n",
    "        self.val_steps = len(self.val_loader.dataset)\n",
    "        self.test_steps = len(self.test_loader.dataset)\n",
    "\n",
    "        print(f\"train steps: {self.train_steps}\")\n",
    "        print(f\"val steps: {self.val_steps}\")\n",
    "        print(f\"test steps: {self.test_steps}\")\n",
    "\n",
    "    def _test(self) -> Dict[str, float]:\n",
    "        \"\"\"Evaluate on the test loader, log and return the metrics.\"\"\"\n",
    "        print(\"Testing .... \")\n",
    "        test_result = self._evaluate(self.test_loader,out_fid = True)\n",
    "\n",
    "        self._run_print(f\"test_results: {test_result}\")\n",
    "        return test_result\n",
    "\n",
    "    def _val(self):\n",
    "        \"\"\"Evaluate on the validation loader, log and return the metrics.\"\"\"\n",
    "        print(\"Validating .... \")\n",
    "        val_result = self._evaluate(self.val_loader)\n",
    "\n",
    "        self._run_print(f\"vali_results: {val_result}\")\n",
    "        return val_result\n",
    "\n",
    "    def _check_run_exist(self, seed: str):\n",
    "        \"\"\"Ensure ``run_save_dir`` exists, dump the config to args.json and report\n",
    "        whether a resumable run checkpoint is already present.\"\"\"\n",
    "        if not os.path.exists(self.run_save_dir):\n",
    "            os.makedirs(self.run_save_dir)\n",
    "            print(f\"Creating running results saving dir: '{self.run_save_dir}'.\")\n",
    "        else:\n",
    "            print(f\"result directory exists: {self.run_save_dir}\")\n",
    "        with open(\n",
    "            os.path.join(self.run_save_dir, \"args.json\"), \"w\", encoding=\"utf-8\"\n",
    "        ) as f:\n",
    "            json.dump(asdict(self), f, ensure_ascii=False, indent=4)\n",
    "\n",
    "        exists = os.path.exists(self.run_checkpoint_filepath)\n",
    "        return exists\n",
    "\n",
    "    def _load_best_model(self):\n",
    "        \"\"\"Load the early-stopping best weights into ``self.model``.\"\"\"\n",
    "        self.model.load_state_dict(\n",
    "            torch.load(self.best_checkpoint_filepath, map_location=self.device)\n",
    "        )\n",
    "\n",
    "    def _run_print(self, *args, **kwargs):\n",
    "        \"\"\"Print to stdout and append a timestamped copy to ``<run_save_dir>/output.log``.\n",
    "\n",
    "        The log timestamp is local time shifted by +8 hours.\n",
    "        \"\"\"\n",
    "        # Named `timestamp` so it does not shadow the `time` module.\n",
    "        timestamp = (\n",
    "            \"[\"\n",
    "            + str(datetime.datetime.now() + datetime.timedelta(hours=8))[:19]\n",
    "            + \"] -\"\n",
    "        )\n",
    "        print(*args, **kwargs)\n",
    "        with open(os.path.join(self.run_save_dir, \"output.log\"), \"a+\") as f:\n",
    "            print(timestamp, *args, flush=True, file=f)\n",
    "\n",
    "    def _resume_run(self, seed):\n",
    "        \"\"\"Restore model, optimizer, epoch counter and early-stopping state.\n",
    "\n",
    "        Only training state is checkpointed so validation and test stay consistent.\n",
    "        Reads ``self.run_checkpoint_filepath``, the same file ``_check_run_exist``\n",
    "        tests for and ``_save_run_check_point`` writes.\n",
    "        \"\"\"\n",
    "        print(f\"resuming from {self.run_checkpoint_filepath}\")\n",
    "\n",
    "        check_point = torch.load(self.run_checkpoint_filepath, map_location=self.device)\n",
    "\n",
    "        self.model.load_state_dict(check_point[\"model\"])\n",
    "        self.model_optim.load_state_dict(check_point[\"optimizer\"])\n",
    "        self.current_epoch = check_point[\"current_epoch\"]\n",
    "\n",
    "        self.early_stopper.set_state(check_point[\"early_stopping\"])\n",
    "\n",
    "    def _use_wandb(self):\n",
    "        \"\"\"True when a ``wandb`` attribute has been set on the experiment.\"\"\"\n",
    "        return hasattr(self, \"wandb\")\n",
    "\n",
    "    def run(self, seed=42) -> Dict[str, float]:\n",
    "        \"\"\"Train / validate / test for one seed, resuming if a run checkpoint exists.\n",
    "\n",
    "        Early stopping monitors validation CRPS. After the loop the best checkpoint is\n",
    "        reloaded and its test metrics are returned; returns ``{}`` when wandb is in use\n",
    "        and ``_init_wandb`` returns a falsy value.\n",
    "        \"\"\"\n",
    "        if self._use_wandb() and not self._init_wandb(self.project, seed):\n",
    "            return {}\n",
    "\n",
    "        self._setup_run(seed)\n",
    "        if self._check_run_exist(seed):\n",
    "            self._resume_run(seed)\n",
    "\n",
    "        self._run_print(f\"run : {self.current_run} in seed: {seed}\")\n",
    "\n",
    "        parameter_tables, model_parameters_num = count_parameters(self.model)\n",
    "        self._run_print(f\"parameter_tables: {parameter_tables}\")\n",
    "        self._run_print(f\"model parameters: {model_parameters_num}\")\n",
    "\n",
    "        if self._use_wandb():\n",
    "            wandb.run.summary[\"parameters\"] = model_parameters_num\n",
    "\n",
    "        # Loop on the persisted epoch counter so a resumed run continues where it stopped.\n",
    "        while self.current_epoch < self.epochs:\n",
    "            epoch_start_time = time.time()\n",
    "            if self.early_stopper.early_stop is True:\n",
    "                self._run_print(\n",
    "                    f\"val loss no decreased for patience={self.patience} epochs,  early stopping ....\"\n",
    "                )\n",
    "                break\n",
    "\n",
    "            # Reseed per epoch for resumable reproducibility.\n",
    "            reproducible(seed + self.current_epoch)\n",
    "            train_losses = self._train()\n",
    "            self._run_print(\n",
    "                \"Epoch: {} cost time: {}s\".format(\n",
    "                    self.current_epoch + 1, time.time() - epoch_start_time\n",
    "                )\n",
    "            )\n",
    "            self._run_print(f\"Traininng loss : {np.mean(train_losses)}\")\n",
    "\n",
    "            val_result = self._val()\n",
    "            test_result = self._test()\n",
    "\n",
    "            self.current_epoch = self.current_epoch + 1\n",
    "            self.early_stopper(val_result['crps'], model=self.model)\n",
    "\n",
    "            self._save_run_check_point(seed)\n",
    "\n",
    "            if self._use_wandb():\n",
    "                wandb.log({'training_loss' : np.mean(train_losses)}, step=self.current_epoch)\n",
    "                wandb.log( {f\"val_{k}\": v for k, v in val_result.items()}, step=self.current_epoch)\n",
    "                wandb.log( {f\"test_{k}\": v for k, v in test_result.items()}, step=self.current_epoch)\n",
    "\n",
    "        self._load_best_model()\n",
    "        best_test_result = self._test()\n",
    "        if self._use_wandb():\n",
    "            for k, v in best_test_result.items():\n",
    "                wandb.run.summary[f\"best_test_{k}\"] = v\n",
    "            wandb.finish()\n",
    "        return best_test_result\n",
    "\n",
    "    def runs(self, seeds: List[int] = [1, 2, 3, 4, 5]):\n",
    "        \"\"\"Call ``run`` once per seed and return the per-seed test results in order.\"\"\"\n",
    "        # The default list is only iterated, never mutated.\n",
    "        return [self.run(seed=seed) for seed in seeds]\n",
    "\n",
    "    def _save_run_check_point(self, seed):\n",
    "        \"\"\"Save the resumable run state to ``self.run_checkpoint_filepath``.\n",
    "\n",
    "        ``seed`` is unused. The saved dict is also kept on ``self.run_state``.\n",
    "        \"\"\"\n",
    "        if not os.path.exists(self.run_save_dir):\n",
    "            os.makedirs(self.run_save_dir)\n",
    "        print(f\"Saving run checkpoint to '{self.run_save_dir}'.\")\n",
    "\n",
    "        self.run_state = {\n",
    "            \"model\": self.model.state_dict(),\n",
    "            \"current_epoch\": self.current_epoch,\n",
    "            \"optimizer\": self.model_optim.state_dict(),\n",
    "            \"rng_state\": torch.get_rng_state(),\n",
    "            \"early_stopping\": self.early_stopper.get_state(),\n",
    "        }\n",
    "\n",
    "        torch.save(self.run_state, f\"{self.run_checkpoint_filepath}\")\n",
    "        print(\"Run state saved ... \")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "3876b11d",
   "metadata": {},
   "outputs": [],
   "source": [
    "from dataclasses import dataclass, field\n",
    "import sys\n",
    "from typing import List, Dict\n",
    "import os\n",
    "import wandb\n",
    "import torch\n",
    "from dataclasses import dataclass, asdict, field\n",
    "from torch_timeseries.nn.embedding import freq_map\n",
    "from torchmetrics import MeanAbsoluteError, MeanSquaredError, MetricCollection\n",
    "from torch.optim import *\n",
    "from tqdm import tqdm\n",
    "from torch_timeseries.utils.model_stats import count_parameters\n",
    "from torch_timeseries.utils.reproduce import reproducible\n",
    "import time\n",
    "# import multiprocessing\n",
    "import torch.multiprocessing as mp\n",
    "from torch_timeseries.utils.parse_type import parse_type\n",
    "\n",
    "from torch_timeseries.utils.early_stop import EarlyStopping\n",
    "\n",
    "import numpy as np\n",
    "import torch.distributed as dist\n",
    "import torch\n",
    "from tqdm import tqdm\n",
    "import concurrent.futures\n",
    "from types import SimpleNamespace\n",
    "\n",
    "\n",
    "\n",
    "class TMDMEarlyStopping(EarlyStopping):\n",
    "    \"\"\"``EarlyStopping`` that checkpoints both TMDM sub-models on improvement.\n",
    "\n",
    "    The ``model`` argument must be subscriptable with the keys ``'model'`` and\n",
    "    ``'cond_pred_model'`` (e.g. a dict of modules); their state dicts are written to\n",
    "    ``<path>/model.pth`` and ``<path>/cond_pred_model.pth``.\n",
    "    NOTE(review): ``ProbForecastExp.run`` passes ``model=self.model`` (a single module);\n",
    "    unless the TMDM experiment overrides ``run``, that call would fail here -- confirm.\n",
    "    \"\"\"\n",
    "    def save_checkpoint(self, val_loss, model):\n",
    "        \"\"\"Persist both sub-models and record ``val_loss`` as the new best loss.\"\"\"\n",
    "        if self.verbose:\n",
    "            message = f\"Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).  Saving model ...\"\n",
    "            self.trace_func(message)\n",
    "        targets = ((\"model\", \"model.pth\"), (\"cond_pred_model\", \"cond_pred_model.pth\"))\n",
    "        for key, filename in targets:\n",
    "            torch.save(model[key].state_dict(), os.path.join(self.path, filename))\n",
    "        self.val_loss_min = val_loss\n",
    "        \n",
    "        \n",
    "        \n",
    "def log_normal(x, mu, var):\n",
    "    \"\"\"Logarithm of normal distribution with mean=mu and variance=var\n",
    "       log(x|μ, σ^2) = loss = -0.5 * Σ log(2π) + log(σ^2) + ((x - μ)/σ)^2\n",
    "\n",
    "    Args:\n",
    "       x: (array) corresponding array containing the input\n",
    "       mu: (array) corresponding array containing the mean\n",
    "       var: (array) corresponding array containing the variance\n",
    "\n",
    "    Returns:\n",
    "       output: (array/float) depending on average parameters the result will be the mean\n",
    "                            of all the sample losses or an array with the losses per sample\n",
    "    \"\"\"\n",
    "    eps = 1e-8\n",
    "    if eps > 0.0:\n",
    "        var = var + eps\n",
    "    return 0.5 * torch.mean(\n",
    "        np.log(2.0 * np.pi) + torch.log(var) + torch.pow(x - mu, 2) / var)\n",
    "\n",
    "\n",
    "\n",
    "@dataclass\n",
    "class TMDMParameters:\n",
    "\n",
    "    beta_start: float =  0.0001\n",
    "    beta_end: float =  0.5\n",
    "    d_model: int =  512\n",
    "    n_heads: int =  8\n",
    "    e_layers: int =  2\n",
    "    d_layers: int =  1\n",
    "    d_ff: int =  1024\n",
    "    diffusion_steps :int = 100\n",
    "    moving_avg: int =  25\n",
    "    factor: int =  3\n",
    "    distil: bool =  True\n",
    "    dropout: float =  0.05\n",
    "    activation: str = 'gelu'\n",
    "    k_z: float =  1e-2\n",
    "    k_cond: int =  1\n",
    "    d_z: int =  8\n",
    "    CART_input_x_embed_dim : int= 32\n",
    "    p_hidden_layers : int = 2\n",
    "\n",
    "@dataclass\n",
    "class TMDMForecast(ProbForecastExp, TMDMParameters):\n",
    "    model_type: str = \"TMDM_cw\"\n",
    "    def _init_model(self):\n",
    "        \"\"\"Build the TMDM denoiser (``self.model``) and its conditional model\n",
    "        (``self.cond_pred_model``) from a shared ``self.args`` namespace.\n",
    "\n",
    "        ``label_len`` is half of the input window.\n",
    "        NOTE(review): ``p_hidden_dims`` is fixed at [64, 64] regardless of\n",
    "        ``p_hidden_layers`` -- confirm that is intended.\n",
    "        \"\"\"\n",
    "        self.label_len = self.windows // 2\n",
    "        num_features = self.dataset.num_features\n",
    "        self.args = SimpleNamespace(\n",
    "            seq_len=self.windows,\n",
    "            device=self.device,\n",
    "            pred_len=self.pred_len,\n",
    "            label_len=self.label_len,\n",
    "            features='M',\n",
    "            beta_start=self.beta_start,\n",
    "            beta_end=self.beta_end,\n",
    "            enc_in=num_features,\n",
    "            dec_in=num_features,\n",
    "            c_out=num_features,\n",
    "            d_model=self.d_model,\n",
    "            n_heads=self.n_heads,\n",
    "            e_layers=self.e_layers,\n",
    "            d_layers=self.d_layers,\n",
    "            d_ff=self.d_ff,\n",
    "            moving_avg=self.moving_avg,\n",
    "            timesteps=self.diffusion_steps,\n",
    "            factor=self.factor,\n",
    "            distil=self.distil,\n",
    "            embed='timeF',\n",
    "            dropout=self.dropout,\n",
    "            activation=self.activation,\n",
    "            output_attention=False,\n",
    "            do_predict=True,\n",
    "            k_z=self.k_z,\n",
    "            k_cond=self.k_cond,\n",
    "            p_hidden_dims=[64, 64],\n",
    "            freq=self.dataset.freq,\n",
    "            CART_input_x_embed_dim=self.CART_input_x_embed_dim,\n",
    "            p_hidden_layers=self.p_hidden_layers,\n",
    "            d_z=self.d_z,\n",
    "            diffusion_config_dir=\"NsDiff-main/configs/tmdm.yml\",\n",
    "        )\n",
    "\n",
    "        self.model = TMDM(self.args, self.device).to(self.device)\n",
    "        # The parameter count is logged by run() via count_parameters(self.model).\n",
    "        self.cond_pred_model = Model(self.args).float().to(self.device)\n",
    "\n",
    "\n",
    "    def _init_optimizer(self):\n",
    "        \"\"\"Create one optimizer of type ``optm_type`` over both sub-models.\n",
    "\n",
    "        The denoiser and the conditional model are separate parameter groups that share ``self.lr``.\n",
    "        \"\"\"\n",
    "        optimizer_cls = parse_type(self.optm_type, globals=globals())\n",
    "        param_groups = [\n",
    "            {'params': self.model.parameters()},\n",
    "            {'params': self.cond_pred_model.parameters()},\n",
    "        ]\n",
    "        self.model_optim = optimizer_cls(param_groups, lr=self.lr)\n",
    "\n",
    "\n",
    "    def _load_best_model(self):\n",
    "        \"\"\"Load the best denoiser and conditional-model weights saved in ``run_save_dir``.\n",
    "\n",
    "        NOTE(review): this class body defines ``_load_best_model`` twice; the later\n",
    "        definition takes precedence, so this one is dead code and could be deleted.\n",
    "        \"\"\"\n",
    "        self.model.load_state_dict(\n",
    "            torch.load(os.path.join(self.run_save_dir, 'model.pth'), map_location=self.device)\n",
    "        )\n",
    "        self.cond_pred_model.load_state_dict(\n",
    "            torch.load(os.path.join(self.run_save_dir, 'cond_pred_model.pth'), map_location=self.device)\n",
    "        )\n",
    "\n",
    "    def _setup_early_stopper(self):\n",
    "        \"\"\"Record the best-checkpoint paths and create the TMDM early stopper.\n",
    "\n",
    "        Both checkpoints live directly in ``run_save_dir``, which is also the stopper's save path.\n",
    "        \"\"\"\n",
    "        save_dir = self.run_save_dir\n",
    "        self.best_checkpoint_filepath = os.path.join(save_dir, \"model.pth\")\n",
    "        self.best_cond_checkpoint_filepath = os.path.join(save_dir, \"cond_pred_model.pth\")\n",
    "        self.early_stopper = TMDMEarlyStopping(self.patience, verbose=True, path=save_dir)\n",
    "\n",
    "\n",
    "    def _save_run_check_point(self, seed):\n",
    "        \"\"\"Save the resumable run state (both models, optimizer, epoch, RNG, early stopper).\n",
    "\n",
    "        ``seed`` is unused. The saved dict is also kept on ``self.run_state``.\n",
    "        \"\"\"\n",
    "        save_dir = self.run_save_dir\n",
    "        if not os.path.exists(save_dir):\n",
    "            os.makedirs(save_dir)\n",
    "        print(f\"Saving run checkpoint to '{save_dir}'.\")\n",
    "\n",
    "        state = {\n",
    "            \"model\": self.model.state_dict(),\n",
    "            \"cond_pred_model\": self.cond_pred_model.state_dict(),\n",
    "            \"current_epoch\": self.current_epoch,\n",
    "            \"optimizer\": self.model_optim.state_dict(),\n",
    "            \"rng_state\": torch.get_rng_state(),\n",
    "            \"early_stopping\": self.early_stopper.get_state(),\n",
    "        }\n",
    "        self.run_state = state\n",
    "        torch.save(state, f\"{self.run_checkpoint_filepath}\")\n",
    "        print(\"Run state saved ... \")\n",
    "\n",
    "    def _load_best_model(self):\n",
    "        \"\"\"Restore the best-checkpoint weights of both sub-models.\n",
    "\n",
    "        The paths are the ones recorded by ``_setup_early_stopper``.\n",
    "        \"\"\"\n",
    "        for module, path in (\n",
    "            (self.model, self.best_checkpoint_filepath),\n",
    "            (self.cond_pred_model, self.best_cond_checkpoint_filepath),\n",
    "        ):\n",
    "            module.load_state_dict(torch.load(path, map_location=self.device))\n",
    "\n",
    "\n",
    "    def _resume_run(self, seed):\n",
    "        \"\"\"Restore both models, the optimizer, the epoch counter and early-stopping state.\n",
    "\n",
    "        Only training state is checkpointed so validation and test stay consistent.\n",
    "        ``seed`` is accepted for interface compatibility and unused.\n",
    "        \"\"\"\n",
    "        state = torch.load(self.run_checkpoint_filepath, map_location=self.device)\n",
    "        for owner, key in (\n",
    "            (self.model, \"model\"),\n",
    "            (self.cond_pred_model, \"cond_pred_model\"),\n",
    "            (self.model_optim, \"optimizer\"),\n",
    "        ):\n",
    "            owner.load_state_dict(state[key])\n",
    "        self.current_epoch = state[\"current_epoch\"]\n",
    "        self.early_stopper.set_state(state[\"early_stopping\"])\n",
    "\n",
    "    def _train(self):\n",
    "        \"\"\"Run one TMDM training epoch and return the list of per-batch loss values.\n",
    "\n",
    "        Both sub-models are in train mode during the epoch and switched back to eval mode\n",
    "        afterwards. Model inputs come from the global ``future_cw_train[i]`` (i = batch\n",
    "        index); no inverse scaling is applied to the loss.\n",
    "\n",
    "        Raises:\n",
    "            NotImplementedError: if ``invtrans_loss`` is set -- ``_process_train_batch``\n",
    "                returns only a loss, so there is no prediction to inverse-transform.\n",
    "        \"\"\"\n",
    "        if self.invtrans_loss:\n",
    "            raise NotImplementedError(\n",
    "                \"invtrans_loss is not supported by TMDMForecast._train: _process_train_batch returns only a loss\"\n",
    "            )\n",
    "        self.model.train()\n",
    "        self.cond_pred_model.train()\n",
    "\n",
    "        with torch.enable_grad(), tqdm(total=len(self.train_loader.dataset)) as progress_bar:\n",
    "            train_loss = []\n",
    "            for i, (\n",
    "                batch_x,\n",
    "                batch_y,\n",
    "                origin_x,\n",
    "                origin_y,\n",
    "                batch_x_date_enc,\n",
    "                batch_y_date_enc,\n",
    "            ) in enumerate(self.train_loader):\n",
    "                # Loader inputs are replaced by the precomputed batch with the same index.\n",
    "                cached = future_cw_train[i]\n",
    "                batch_x = cached['batch_history'].to(self.device).float()\n",
    "                batch_y = cached['batch_future_cw'].to(self.device).float()\n",
    "                batch_x_date_enc = cached['batch_history_mark'].to(self.device).float()\n",
    "                batch_y_date_enc = cached['batch_future_mark'].to(self.device).float()\n",
    "                loss = self._process_train_batch(\n",
    "                    batch_x, batch_y, batch_x_date_enc, batch_y_date_enc\n",
    "                )\n",
    "                loss.backward()\n",
    "\n",
    "                progress_bar.update(batch_x.size(0))\n",
    "\n",
    "                train_loss.append(loss.item())\n",
    "                progress_bar.set_postfix(\n",
    "                    loss=loss.item(),\n",
    "                    lr=self.model_optim.param_groups[0][\"lr\"],\n",
    "                    epoch=self.current_epoch,\n",
    "                    refresh=True,\n",
    "                )\n",
    "                self.model_optim.step()\n",
    "                self.model_optim.zero_grad()\n",
    "\n",
    "        self.model.eval()\n",
    "        self.cond_pred_model.eval()\n",
    "        return train_loss\n",
    "\n",
    "    def _process_train_batch(self, batch_x, batch_y, batch_x_mark, batch_y_mark):\n",
    "        \"\"\"Compute the TMDM training loss for one batch and return it as a scalar tensor.\n",
    "\n",
    "        ``batch_y`` / ``batch_y_mark`` are prefixed with the last ``label_len`` history\n",
    "        steps. The decoder input is that history slice followed by ``pred_len`` zero steps.\n",
    "\n",
    "        loss = mean((noise - denoiser_output)^2)\n",
    "               + k_cond * (log_normal(batch_y, y_0_hat, 1) + k_z * KL)\n",
    "\n",
    "        Diffusion timesteps: n // 2 + 1 are drawn uniformly and concatenated with their\n",
    "        mirror ``num_timesteps - 1 - t``, then truncated to the batch size.\n",
    "        \"\"\"\n",
    "        history_tail = batch_x[:, -self.label_len:, :]\n",
    "        batch_y = torch.concat([history_tail, batch_y], dim=1)\n",
    "        batch_y_mark = torch.concat([batch_x_mark[:, -self.label_len:, :], batch_y_mark], dim=1)\n",
    "\n",
    "        n = batch_x.size(0)\n",
    "        zeros_future = torch.zeros([n, self.pred_len, self.dataset.num_features]).to(self.device)\n",
    "        dec_inp = torch.cat([history_tail.to(self.device), zeros_future], dim=1)\n",
    "\n",
    "        total_steps = self.model.num_timesteps\n",
    "        half = torch.randint(low=0, high=total_steps, size=(n // 2 + 1,)).to(self.device)\n",
    "        t = torch.cat([half, total_steps - 1 - half], dim=0)[:n]\n",
    "\n",
    "        _, y_0_hat_batch, KL_loss, _ = self.cond_pred_model(batch_x, batch_x_mark, dec_inp, batch_y_mark)\n",
    "        unit_var = torch.from_numpy(np.array(1))\n",
    "        loss_vae_all = log_normal(batch_y, y_0_hat_batch, unit_var) + self.k_z * KL_loss\n",
    "\n",
    "        noise = torch.randn_like(batch_y).to(self.device)\n",
    "        y_t_batch = q_sample(batch_y, y_0_hat_batch, self.model.alphas_bar_sqrt,\n",
    "                             self.model.one_minus_alphas_bar_sqrt, t, noise=noise)\n",
    "\n",
    "        output = self.model(batch_x, batch_x_mark, batch_y, y_t_batch, y_0_hat_batch, t)\n",
    "        denoise_loss = (noise - output).square().mean()\n",
    "        return denoise_loss + self.args.k_cond * loss_vae_all\n",
    "\n",
    "\n",
    "    def _process_val_batch(self, batch_x, batch_y, batch_x_mark, batch_y_mark):\n",
    "        # inputs:\n",
    "        # batch_x: (B, T, N)\n",
    "        # batch_y: (B, O, N)\n",
    "        # ouputs:\n",
    "        # - pred: (B, N)/(B, O, N)\n",
    "        # - label: (B, N)/(B, O, N)\n",
    "        # - pred: (B, N)/(B, O, N)\n",
    "        # - label: (B, N)/(B, O, N)\n",
    "        b = batch_x.shape[0]\n",
    "        gen_y_by_batch_list = [[] for _ in range(self.diffusion_steps + 1)]\n",
    "        y_se_by_batch_list = [[] for _ in range(self.diffusion_steps + 1)]\n",
    "        minisample = 10\n",
    "        \n",
    "        batch_y = torch.concat([batch_x[:, -self.label_len:, :], batch_y], dim=1)\n",
    "        batch_y_mark = torch.concat([batch_x_mark[:, -self.label_len:, :], batch_y_mark], dim=1)\n",
    "\n",
    "        dec_inp_pred = torch.zeros(\n",
    "            [batch_x.size(0), self.pred_len, self.dataset.num_features]\n",
    "        ).to(self.device)\n",
    "        dec_inp_label = batch_x[:, -self.label_len :, :].to(self.device)\n",
    "        dec_inp = torch.cat([dec_inp_label, dec_inp_pred], dim=1)\n",
    "\n",
    "\n",
    "        def store_gen_y_at_step_t(config, config_diff, idx, y_tile_seq):\n",
    "            \"\"\"\n",
    "            Store generated y from a mini-batch to the array of corresponding time step.\n",
    "            \"\"\"\n",
    "            current_t = self.diffusion_steps - idx\n",
    "            gen_y = y_tile_seq[idx].reshape(b,\n",
    "                                            # int(config_diff.testing.n_z_samples / config_diff.testing.n_z_samples_depart),\n",
    "                                            minisample,\n",
    "                                            (config.label_len + config.pred_len),\n",
    "                                            config.c_out).cpu()\n",
    "            # directly modify the dict value by concat np.array instead of append np.array gen_y to list\n",
    "            # reduces a huge amount of memory consumption\n",
    "            if len(gen_y_by_batch_list[current_t]) == 0:\n",
    "                gen_y_by_batch_list[current_t] = gen_y\n",
    "            else:\n",
    "                gen_y_by_batch_list[current_t] = torch.concat([gen_y_by_batch_list[current_t], gen_y], dim=0)\n",
    "            return gen_y\n",
    "\n",
    "\n",
    "        # dec_inp = torch.zeros_like(batch_y[:, -self.args.pred_len:, :]).float()\n",
    "        # dec_inp = torch.cat([batch_y[:, :self.label_len, :], dec_inp], dim=1).float().to(self.device)\n",
    "\n",
    "        n = batch_x.size(0)\n",
    "        t = torch.randint(\n",
    "            low=0, high=self.model.num_timesteps, size=(n // 2 + 1,)\n",
    "        ).to(self.device)\n",
    "        t = torch.cat([t, self.model.num_timesteps - 1 - t], dim=0)[:n]\n",
    "        \n",
    "        _, y_0_hat_batch, _, z_sample = self.cond_pred_model(batch_x, batch_x_mark, dec_inp,batch_y_mark)\n",
    "        preds = []\n",
    "        for i in range(self.model.diffusion_config.testing.n_z_samples //minisample):\n",
    "            repeat_n = int(minisample)\n",
    "            y_0_hat_tile = y_0_hat_batch.repeat(repeat_n, 1, 1, 1)\n",
    "            y_0_hat_tile = y_0_hat_tile.transpose(0, 1).flatten(0, 1).to(self.device)\n",
    "            y_T_mean_tile = y_0_hat_tile\n",
    "            x_tile = batch_x.repeat(repeat_n, 1, 1, 1)\n",
    "            x_tile = x_tile.transpose(0, 1).flatten(0, 1).to(self.device)\n",
    "\n",
    "            x_mark_tile = batch_x_mark.repeat(repeat_n, 1, 1, 1)\n",
    "            x_mark_tile = x_mark_tile.transpose(0, 1).flatten(0, 1).to(self.device)\n",
    "\n",
    "            gen_y_box = []\n",
    "            for _ in range(self.model.diffusion_config.testing.n_z_samples_depart):\n",
    "                for _ in range(self.model.diffusion_config.testing.n_z_samples_depart):\n",
    "                    y_tile_seq = p_sample_loop(self.model, x_tile, x_mark_tile, y_0_hat_tile, y_T_mean_tile,\n",
    "                                                self.model.num_timesteps,\n",
    "                                                self.model.alphas, self.model.one_minus_alphas_bar_sqrt)\n",
    "                gen_y = store_gen_y_at_step_t(config=self.model.args,\n",
    "                                                config_diff=self.model.diffusion_config,\n",
    "                                                idx=self.model.num_timesteps, y_tile_seq=y_tile_seq)\n",
    "                gen_y_box.append(gen_y)\n",
    "            outputs = torch.concat(gen_y_box, dim=1)\n",
    "\n",
    "            f_dim = -1 if self.args.features == 'MS' else 0\n",
    "            \n",
    "            outputs = outputs[:, :, -self.pred_len:, f_dim:] # B, S, O, N\n",
    "\n",
    "            pred = outputs  # outputs.detach().cpu().numpy()  # .squeeze()\n",
    "\n",
    "            preds.append(pred) # numberof_testbatch,  B, S, O, N\n",
    "            # trues.append(true) # numberof_testbatch, B, T, N\n",
    "        preds = torch.concat(preds, dim=1)\n",
    "        batch_y = batch_y[:, -self.pred_len:, f_dim:].to(self.device) # B, T, N\n",
    "\n",
    "        outs = preds.permute(0, 2, 3, 1)\n",
    "        assert (outs.shape[1], outs.shape[2], outs.shape[3]) == (self.pred_len, self.dataset.num_features, self.model.diffusion_config.testing.n_z_samples)\n",
    "        return outs, batch_y\n",
    "\n",
    "\n",
    "    def run(self, seed=42) -> Dict[str, float]:\n",
    "        \"\"\"Train with early stopping, then score the best checkpoint on the test split.\n",
    "\n",
    "        If ``self._check_run_exist(seed)`` reports an earlier run, it is resumed via\n",
    "        ``self._resume_run(seed)`` and the loop continues from ``self.current_epoch``.\n",
    "\n",
    "        Args:\n",
    "            seed: base random seed; the epoch with index ``e`` is seeded with ``seed + e``.\n",
    "\n",
    "        Returns:\n",
    "            The metrics dict from ``self._test()`` after ``self._load_best_model()``, or\n",
    "            ``{}`` when wandb is enabled and ``self._init_wandb`` returns a falsy value.\n",
    "        \"\"\"\n",
    "        if self._use_wandb() and not self._init_wandb(self.project, seed):\n",
    "            return {}\n",
    "\n",
    "        self._setup_run(seed)\n",
    "        if self._check_run_exist(seed):\n",
    "            self._resume_run(seed)\n",
    "\n",
    "        self._run_print(f\"run : {self.current_run} in seed: {seed}\")\n",
    "\n",
    "        parameter_tables, model_parameters_num = count_parameters(self.model)\n",
    "        self._run_print(f\"parameter_tables: {parameter_tables}\")\n",
    "        self._run_print(f\"model parameters: {model_parameters_num}\")\n",
    "        if self._use_wandb():\n",
    "            wandb.run.summary[\"parameters\"] = model_parameters_num\n",
    "\n",
    "        while self.current_epoch < self.epochs:\n",
    "            if self.early_stopper.early_stop is True:\n",
    "                self._run_print(\n",
    "                    f\"val loss no decreased for patience={self.patience} epochs,  early stopping ....\"\n",
    "                )\n",
    "                break\n",
    "\n",
    "            epoch_start_time = time.time()\n",
    "            # Re-seed from the epoch index so a resumed run replays the same randomness.\n",
    "            reproducible(seed + self.current_epoch)\n",
    "            train_losses = self._train()\n",
    "            train_loss = np.mean(train_losses)\n",
    "            self._run_print(\n",
    "                \"Epoch: {} cost time: {}s\".format(\n",
    "                    self.current_epoch + 1, time.time() - epoch_start_time\n",
    "                )\n",
    "            )\n",
    "            self._run_print(f\"Traininng loss : {train_loss}\")\n",
    "\n",
    "            # Model selection relies on the validation split only; the test split is\n",
    "            # scored once, on the best checkpoint, after training finishes.\n",
    "            val_result = self._val()\n",
    "\n",
    "            self.current_epoch = self.current_epoch + 1\n",
    "            self.early_stopper(val_result['crps'], model={'model': self.model, 'cond_pred_model': self.cond_pred_model})\n",
    "\n",
    "            self._save_run_check_point(seed)\n",
    "\n",
    "            if self._use_wandb():\n",
    "                wandb.log({'training_loss': train_loss}, step=self.current_epoch)\n",
    "                wandb.log({f\"val_{k}\": v for k, v in val_result.items()}, step=self.current_epoch)\n",
    "\n",
    "        self._load_best_model()\n",
    "        best_test_result = self._test()\n",
    "        if self._use_wandb():\n",
    "            for k, v in best_test_result.items():\n",
    "                wandb.run.summary[f\"best_test_{k}\"] = v\n",
    "            wandb.finish()\n",
    "        return best_test_result\n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "349deadd",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "ExperimentConfig(data_path='ts_datasets', dataset_type='ETTh1', windows=168, pred_len=192, batch_size=32, epochs=10, lr=0.001, device='cuda:3', num_worker=2, patience=3, scaler_type='StandardScaler', optm_type='Adam', train_ratio=0.7, test_ratio=0.2, invtrans_loss=False)\n"
     ]
    }
   ],
   "source": [
    "from dataclasses import dataclass\n",
    "\n",
    "\n",
    "@dataclass\n",
    "class ExperimentConfig:\n",
    "    \"\"\"Settings passed as keyword arguments to ``TMDMForecast`` in the following cell.\"\"\"\n",
    "\n",
    "    # dataset and sequence lengths\n",
    "    data_path: str = \"ts_datasets\"\n",
    "    dataset_type: str = \"ETTh1\"\n",
    "    windows: int = 168\n",
    "    pred_len: int = 192\n",
    "    # optimisation loop\n",
    "    batch_size: int = 32\n",
    "    epochs: int = 10\n",
    "    lr: float = 0.001\n",
    "    # runtime; `device` is read from the notebook's global scope when the class is defined\n",
    "    device: str = device\n",
    "    num_worker: int = 2\n",
    "    # early-stopping patience, in epochs\n",
    "    patience: int = 3\n",
    "    # preprocessing, optimiser choice and split ratios\n",
    "    scaler_type: str = \"StandardScaler\"\n",
    "    optm_type: str = \"Adam\"\n",
    "    train_ratio: float = 0.7\n",
    "    test_ratio: float = 0.2\n",
    "    invtrans_loss: bool = False\n",
    "\n",
    "\n",
    "config = ExperimentConfig()\n",
    "print(config)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "id": "88d1774f",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train steps: 8281\n",
      "val steps: 8\n",
      "test steps: 8\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/yang/anaconda3/envs/nsdiff_env/lib/python3.9/site-packages/torch/__init__.py:1240: UserWarning: torch.set_default_tensor_type() is deprecated as of PyTorch 2.1, please use torch.set_default_dtype() and torch.set_default_device() as alternatives. (Triggered internally at /pytorch/torch/csrc/tensor/python_tensor.cpp:434.)\n",
      "  _C._set_default_tensor_type(t)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "75527\n",
      "result directory exists: ./results/runs/TMDM_cw/ETTh1/w168h1s192/6ddade749fd83388f4725a5a87c21015\n",
      "run : 0 in seed: 42\n",
      "parameter_tables: +------------------------------------------------+------------+\n",
      "|                    Modules                     | Parameters |\n",
      "+------------------------------------------------+------------+\n",
      "|        diffussion_model.lin1.lin.weight        |    1792    |\n",
      "|         diffussion_model.lin1.lin.bias         |    128     |\n",
      "|       diffussion_model.lin1.embed.weight       |   12928    |\n",
      "|        diffussion_model.lin2.lin.weight        |   16384    |\n",
      "|         diffussion_model.lin2.lin.bias         |    128     |\n",
      "|       diffussion_model.lin2.embed.weight       |   12928    |\n",
      "|        diffussion_model.lin3.lin.weight        |   16384    |\n",
      "|         diffussion_model.lin3.lin.bias         |    128     |\n",
      "|       diffussion_model.lin3.embed.weight       |   12928    |\n",
      "|          diffussion_model.lin4.weight          |    896     |\n",
      "|           diffussion_model.lin4.bias           |     7      |\n",
      "| enc_embedding.value_embedding.tokenConv.weight |    672     |\n",
      "|  enc_embedding.value_embedding.tokenConv.bias  |     32     |\n",
      "| enc_embedding.temporal_embedding.embed.weight  |    160     |\n",
      "|  enc_embedding.temporal_embedding.embed.bias   |     32     |\n",
      "+------------------------------------------------+------------+\n",
      "model parameters: 75527\n",
      "val loss no decreased for patience=3 epochs,  early stopping ....\n",
      "Testing .... \n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 8/8 [00:05<00:00,  1.55it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[1751 1399  710  663  563  602  548  565  603  749 1240 1359]\n",
      "test_results: {'crps': 1.0201311111450195, 'crps_sum': 3.9066505432128906, 'mae': 1.2204753160476685, 'mse': 3.8023321628570557, 'picp': 5.1156734116375446e-05, 'qice': 0.06693823635578156, 'rmse': 1.949957013130188}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "# Every ExperimentConfig field has the same name as a TMDMForecast keyword argument,\n",
    "# so the config's attribute dict is unpacked directly.\n",
    "exp = TMDMForecast(**vars(config))\n",
    "# seed_idx (0..9, defined elsewhere) selects which of the repeated runs this is.\n",
    "result = exp.run(seed=42 + seed_idx)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "id": "35872655",
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): scratch cell with no code; candidate for deletion."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "id": "b17eefc7",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import torch\n",
    "\n",
    "# Run the trained model, in eval mode, on the first batch of `future_cw_test`.\n",
    "set_seed(1920 + seed_idx)\n",
    "exp.model.eval()\n",
    "exp.cond_pred_model.eval()\n",
    "\n",
    "test_batch = future_cw_test[0]\n",
    "batch_x = test_batch['batch_history'].to(device).float()\n",
    "batch_y_cw = test_batch['batch_future_cw'].to(device).float()\n",
    "# The model is given the 'batch_future' targets; the 'batch_future_cw' variant stays in batch_y_cw.\n",
    "batch_y = test_batch['batch_future'].to(device).float()\n",
    "batch_x_mark = test_batch['batch_history_mark'].to(device).float()\n",
    "batch_y_mark = test_batch['batch_future_mark'].to(device).float()\n",
    "# Aliases kept for any cell that still refers to the *_date_enc names.\n",
    "batch_x_date_enc, batch_y_date_enc = batch_x_mark, batch_y_mark\n",
    "\n",
    "with torch.no_grad():\n",
    "    preds, truths = exp._process_val_batch(batch_x, batch_y, batch_x_mark, batch_y_mark)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "id": "d7002554",
   "metadata": {},
   "outputs": [],
   "source": [
    "cov_sqrt_test = future_cw_test[0]['cov_sqrt'].to(device)\n",
    "miu_test = future_cw_test[0]['miu_pred'].to(device)\n",
    "# For each sample s: multiply preds[..., s] by cov_sqrt per (batch, time) and add miu_pred,\n",
    "# then append a singleton sample axis so the list can be concatenated on dim=-1.\n",
    "generated_batch = [\n",
    "    (torch.einsum('btij,btj->bti', cov_sqrt_test, preds[:, :, :, s].to(device)) + miu_test).unsqueeze(-1)\n",
    "    for s in range(preds.shape[-1])\n",
    "]\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1fc2190a",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "MetricCollection(\n",
       "  (crps): CRPS()\n",
       "  (crps_sum): CRPSSum()\n",
       "  (mae): ProbMAE()\n",
       "  (mse): ProbMSE()\n",
       "  (picp): PICP()\n",
       "  (qice): QICE()\n",
       "  (rmse): ProbRMSE()\n",
       ")"
      ]
     },
     "execution_count": 33,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# One collection holding every forecast metric, kept on the CPU.\n",
    "metric_fns = {\n",
    "    \"crps\": CRPS(),\n",
    "    \"crps_sum\": CRPSSum(),\n",
    "    \"qice\": QICE(),\n",
    "    \"picp\": PICP(),\n",
    "    \"mse\": ProbMSE(),\n",
    "    \"mae\": ProbMAE(),\n",
    "    \"rmse\": ProbRMSE(),\n",
    "}\n",
    "metrics = MetricCollection(metrics=metric_fns)\n",
    "metrics.to(\"cpu\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "735d059c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[ 761 1204  725  699  741  677  699  814  842  973 1701  916]\n",
      "{'crps': 0.4071166515350342, 'crps_sum': 1.73665189743042, 'mae': 0.5315874218940735, 'mse': 0.5648790597915649, 'picp': 6.254027539398521e-05, 'qice': 0.04523066431283951, 'rmse': 0.7515843510627747}\n"
     ]
    }
   ],
   "source": [
    "metrics.reset()\n",
    "# Concatenate the per-sample tensors on the last axis and score them on the CPU.\n",
    "generated_samples = torch.cat(generated_batch, dim=-1).float().detach().cpu()\n",
    "metrics.update(generated_samples, batch_y.detach().cpu())\n",
    "scores = {name: float(metric.compute()) for name, metric in metrics.items()}\n",
    "print(scores)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0f9f1149",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2.836060252548963\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "(tensor(0.2063), tensor(0.0033))"
      ]
     },
     "execution_count": 35,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "real_xy = torch.cat([batch_x, batch_y], dim=1)\n",
    "# Append every generated sample to the same history and stack the samples on the batch axis.\n",
    "pred_xy = torch.cat([torch.cat([batch_x, x.squeeze(-1)], dim=1) for x in generated_batch], dim=0)\n",
    "\n",
    "print(Context_FID(ori_data=real_xy.detach().cpu().numpy(),\n",
    "                  generated_data=pred_xy.detach().cpu().numpy()))\n",
    "\n",
    "# The real series' cacf does not depend on the sample, so it is computed once, not per sample.\n",
    "real_cacf = cacf_torch(x=real_xy)\n",
    "cacf_list = [\n",
    "    (real_cacf - cacf_torch(x=torch.cat([batch_x, sample.squeeze(-1)], dim=1))).abs().mean()\n",
    "    for sample in generated_batch\n",
    "]\n",
    "cacf_scores = torch.tensor(cacf_list)\n",
    "cacf_scores.mean(), cacf_scores.std()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f8085350",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "nsdiff_env",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.23"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
