{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e5486203",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Copyright authors of TSPulse"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "67425436-c877-46cc-b604-74fc1ededb61",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import sys\n",
    "import warnings\n",
    "import numpy as np\n",
    "from types import SimpleNamespace\n",
    "from torch.utils.data import DataLoader\n",
    "from tqdm import tqdm\n",
    "\n",
    "warnings.filterwarnings(\"ignore\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "7ad2ee73-2750-4301-a47c-9295d3753cbb",
   "metadata": {},
   "outputs": [],
   "source": [
    "sys.path.append(\"..\")\n",
    "from models.tspulse import TSPulseForReconstruction\n",
    "from imputation.utils import mse, mask_generate\n",
    "from imputation.datautils.dls import get_dls"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "d5399321-7aef-4991-b139-62a4f2b98154",
   "metadata": {},
   "outputs": [],
   "source": [
    "device = \"cuda:0\" if torch.cuda.is_available() else \"cpu\"\n",
    "CONTEXT_LEN = 512\n",
    "FORECAST_LEN = 0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "39704542-6fb4-4d86-8f83-707982a9eefc",
   "metadata": {},
   "outputs": [],
   "source": [
    "# inference on etth1\n",
    "dset = [\"etth1\"]  # \"etth1\", \"etth2\", \"ettm1\", \"ettm2\", \"weather\", \"electricity\"\n",
    "m_r = [0.125, 0.25, 0.375, 0.5]\n",
    "m_t = [\"block\", \"hybrid\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "a2280021-9005-46d5-bfae-747d86f1f294",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 46/46 [00:02<00:00, 20.07it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Dataset = etth1  : Mask Type = block  : Mask Ratio = 0.125\n",
      "Mean Squarred Error (MSE)=0.209\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 46/46 [00:00<00:00, 64.80it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Dataset = etth1  : Mask Type = block  : Mask Ratio = 0.25\n",
      "Mean Squarred Error (MSE)=0.225\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 46/46 [00:00<00:00, 68.17it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Dataset = etth1  : Mask Type = block  : Mask Ratio = 0.375\n",
      "Mean Squarred Error (MSE)=0.246\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 46/46 [00:00<00:00, 65.75it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Dataset = etth1  : Mask Type = block  : Mask Ratio = 0.5\n",
      "Mean Squarred Error (MSE)=0.272\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 46/46 [00:01<00:00, 44.49it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Dataset = etth1  : Mask Type = hybrid  : Mask Ratio = 0.125\n",
      "Mean Squarred Error (MSE)=0.146\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 46/46 [00:00<00:00, 65.72it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Dataset = etth1  : Mask Type = hybrid  : Mask Ratio = 0.25\n",
      "Mean Squarred Error (MSE)=0.155\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 46/46 [00:00<00:00, 67.78it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Dataset = etth1  : Mask Type = hybrid  : Mask Ratio = 0.375\n",
      "Mean Squarred Error (MSE)=0.168\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 46/46 [00:00<00:00, 67.82it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Dataset = etth1  : Mask Type = hybrid  : Mask Ratio = 0.5\n",
      "Mean Squarred Error (MSE)=0.183\n"
     ]
    }
   ],
   "source": [
    "# Evaluate TSPulse imputation over the grid defined above:\n",
    "# every dataset in `dset` x every mask type in `m_t` x every mask ratio in `m_r`.\n",
    "MODEL_PATH = \"../../model-binaries/tspulse_hybrid_sign20/tspulse_model\"\n",
    "SEED = 42\n",
    "# Third positional argument of mask_generate.\n",
    "# TODO(review): confirm its meaning in imputation.utils (likely a patch/block length).\n",
    "MASK_GEN_ARG = 8\n",
    "# ETT datasets are evaluated with batch size 64; all others (e.g. weather, electricity) with 4.\n",
    "ETT_DATASETS = (\"etth1\", \"etth2\", \"ettm1\", \"ettm2\")\n",
    "\n",
    "\n",
    "def build_test_loader(dataset_name, batch_size):\n",
    "    \"\"\"Build the test-split DataLoader for `dataset_name`.\n",
    "\n",
    "    Returns:\n",
    "        (DataLoader over the test split in its original order, number of input channels)\n",
    "    \"\"\"\n",
    "    data_params = SimpleNamespace(\n",
    "        dset=dataset_name,\n",
    "        features=\"M\",\n",
    "        data_mount_path=\"datasets/\",\n",
    "        context_points=CONTEXT_LEN,\n",
    "        target_points=FORECAST_LEN,\n",
    "        return_dict=True,\n",
    "        prefix_freq=False,\n",
    "        scale=True,\n",
    "    )\n",
    "    dset_test = get_dls(data_params).test.dataset\n",
    "    num_channels = dset_test[0][\"past_values\"].shape[-1]  # sample is (l, c)\n",
    "    return DataLoader(dset_test, batch_size=batch_size, shuffle=False), num_channels\n",
    "\n",
    "\n",
    "def evaluate_imputation(model, test_dataloader, mask_type, mask_ratio, seed=SEED):\n",
    "    \"\"\"Mask every test window, reconstruct it with `model`, and return the MSE on masked points only.\n",
    "\n",
    "    The mask generator is re-seeded on every call, so each (mask_type, mask_ratio)\n",
    "    configuration is scored on reproducible masks regardless of evaluation order.\n",
    "    \"\"\"\n",
    "    g = torch.Generator(device=device)\n",
    "    g.manual_seed(seed)\n",
    "\n",
    "    trues, preds, masks = [], [], []\n",
    "    with torch.no_grad():\n",
    "        for batch in tqdm(test_dataloader):\n",
    "            batch_x = batch[\"past_values\"].to(device)  # b l c\n",
    "            mask = mask_generate(g, batch_x, MASK_GEN_ARG, mask_ratio, mask_type)\n",
    "            # Positions set in `mask` are passed to the model as unobserved (observed = ~mask).\n",
    "            output = model(past_values=batch_x, past_observed_mask=~mask)\n",
    "            trues.append(batch_x.detach().cpu().numpy())\n",
    "            preds.append(output.reconstruction_outputs.detach().cpu().numpy())\n",
    "            masks.append(mask.detach().cpu().numpy())\n",
    "\n",
    "    trues = np.concatenate(trues)\n",
    "    preds = np.concatenate(preds)\n",
    "    masks = np.concatenate(masks)\n",
    "    # Score only the hidden positions: those are the values the model had to impute.\n",
    "    return mse(y=trues[masks == 1], y_hat=preds[masks == 1], reduction=\"mean\")\n",
    "\n",
    "\n",
    "results = {}  # (dataset, mask_type, mask_ratio) -> MSE\n",
    "# Dataset is the outer loop so data and model are loaded once per dataset (neither depends\n",
    "# on the mask settings); with several datasets the printed results are grouped by dataset.\n",
    "for DATASET in dset:\n",
    "    batch_size = 64 if DATASET in ETT_DATASETS else 4\n",
    "    test_dataloader, num_channels = build_test_loader(DATASET, batch_size)\n",
    "\n",
    "    model = TSPulseForReconstruction.from_pretrained(\n",
    "        MODEL_PATH, fft_time_add_forecasting_pt_loss=False, num_input_channels=num_channels, mask_type=\"user\"\n",
    "    ).to(device)\n",
    "    model.eval()  # make inference mode explicit (disables dropout and similar training-only layers)\n",
    "\n",
    "    for mask_type in m_t:\n",
    "        for mask_ratio in m_r:\n",
    "            MSE = evaluate_imputation(model, test_dataloader, mask_type, mask_ratio)\n",
    "            results[(DATASET, mask_type, mask_ratio)] = MSE\n",
    "            print(f\"Dataset = {DATASET}  : Mask Type = {mask_type}  : Mask Ratio = {mask_ratio}\")\n",
    "            print(f\"Mean Squarred Error (MSE)={MSE:.3f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8d31f252-6b2b-4573-ae8c-7b29503efe00",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
