{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "db8a044a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Copyright authors of TSPulse"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "67425436-c877-46cc-b604-74fc1ededb61",
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "import os\n",
    "import sys\n",
    "import tempfile\n",
    "import warnings\n",
    "from types import SimpleNamespace\n",
    "\n",
    "import numpy as np\n",
    "import torch\n",
    "from torch.optim import AdamW\n",
    "from torch.optim.lr_scheduler import OneCycleLR\n",
    "from torch.utils.data import DataLoader\n",
    "from tqdm import tqdm\n",
    "from transformers import EarlyStoppingCallback, Trainer, TrainingArguments, set_seed\n",
    "\n",
    "# Suppress all Python warnings raised from here on to keep cell outputs readable.\n",
    "warnings.filterwarnings(\"ignore\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "7ad2ee73-2750-4301-a47c-9295d3753cbb",
   "metadata": {},
   "outputs": [],
   "source": [
    "sys.path.append(\"..\")\n",
    "from models.tspulse import TSPulseForReconstruction\n",
    "from imputation.utils import mse, mask_generate\n",
    "from imputation.datautils.dls import get_dls\n",
    "from imputation.lr_finder import optimal_lr_finder"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "da9da22d-7f2b-4927-9970-ee64276c0d88",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Global RNG seed; transformers.set_seed applies it to the libraries' random generators.\n",
    "seed = 42\n",
    "set_seed(seed=seed)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "d5399321-7aef-4991-b139-62a4f2b98154",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Use the GPU when available, otherwise fall back to CPU instead of failing.\n",
    "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "# Number of past time points per input window (passed to the loaders as context_points).\n",
    "CONTEXT_LEN = 512\n",
    "# No forecast targets: the notebook only reconstructs values inside the context window.\n",
    "FORECAST_LEN = 0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "bcd23050-b85f-413d-a563-76437738cddd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Dataset to fine-tune/evaluate on, and the masking setup shared by training and evaluation.\n",
    "DATASET, mask_ratio, mask_type = \"etth1\", 0.375, \"hybrid\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "4cf7cd37-5934-4a9a-9dcd-15ca3f62eef2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Loader settings for get_dls; target_points=FORECAST_LEN means no forecast horizon.\n",
    "data_params_ = dict(\n",
    "    dset=f\"{DATASET}\",\n",
    "    features=\"M\",\n",
    "    data_mount_path=\"datasets/\",\n",
    "    context_points=CONTEXT_LEN,\n",
    "    target_points=FORECAST_LEN,\n",
    "    return_dict=True,\n",
    "    prefix_freq=False,\n",
    "    scale=True,\n",
    ")\n",
    "data_params = SimpleNamespace(**data_params_)\n",
    "dls = get_dls(data_params)\n",
    "\n",
    "# Direct handles to the dataset behind each split's loader.\n",
    "train_dataset, valid_dataset, test_dataset = (\n",
    "    dls.train.dataset,\n",
    "    dls.valid.dataset,\n",
    "    dls.test.dataset,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "ffb9ed8d-7f7a-4f0a-8e72-37df57401ca5",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of the model checkpoint at ../../model-binaries/tspulse_hybrid_sign20/tspulse_model were not used when initializing TSPulseForReconstruction: ['decoder_with_head.fft_softmax_mapping.0.mixers.0.feature_mixer.gating_block.attn_layer.bias', 'decoder_with_head.fft_softmax_mapping.0.mixers.0.feature_mixer.gating_block.attn_layer.weight', 'decoder_with_head.fft_softmax_mapping.0.mixers.0.feature_mixer.mlp.fc1.bias', 'decoder_with_head.fft_softmax_mapping.0.mixers.0.feature_mixer.mlp.fc1.weight', 'decoder_with_head.fft_softmax_mapping.0.mixers.0.feature_mixer.mlp.fc2.bias', 'decoder_with_head.fft_softmax_mapping.0.mixers.0.feature_mixer.mlp.fc2.weight', 'decoder_with_head.fft_softmax_mapping.0.mixers.0.feature_mixer.norm.norm.bias', 'decoder_with_head.fft_softmax_mapping.0.mixers.0.feature_mixer.norm.norm.weight', 'decoder_with_head.fft_softmax_mapping.0.mixers.0.patch_mixer.gating_block.attn_layer.bias', 'decoder_with_head.fft_softmax_mapping.0.mixers.0.patch_mixer.gating_block.attn_layer.weight', 'decoder_with_head.fft_softmax_mapping.0.mixers.0.patch_mixer.mlp.fc1.bias', 'decoder_with_head.fft_softmax_mapping.0.mixers.0.patch_mixer.mlp.fc1.weight', 'decoder_with_head.fft_softmax_mapping.0.mixers.0.patch_mixer.mlp.fc2.bias', 'decoder_with_head.fft_softmax_mapping.0.mixers.0.patch_mixer.mlp.fc2.weight', 'decoder_with_head.fft_softmax_mapping.0.mixers.0.patch_mixer.norm.norm.bias', 'decoder_with_head.fft_softmax_mapping.0.mixers.0.patch_mixer.norm.norm.weight', 'decoder_with_head.fft_softmax_mapping.0.mixers.1.feature_mixer.gating_block.attn_layer.bias', 'decoder_with_head.fft_softmax_mapping.0.mixers.1.feature_mixer.gating_block.attn_layer.weight', 'decoder_with_head.fft_softmax_mapping.0.mixers.1.feature_mixer.mlp.fc1.bias', 'decoder_with_head.fft_softmax_mapping.0.mixers.1.feature_mixer.mlp.fc1.weight', 'decoder_with_head.fft_softmax_mapping.0.mixers.1.feature_mixer.mlp.fc2.bias', 
'decoder_with_head.fft_softmax_mapping.0.mixers.1.feature_mixer.mlp.fc2.weight', 'decoder_with_head.fft_softmax_mapping.0.mixers.1.feature_mixer.norm.norm.bias', 'decoder_with_head.fft_softmax_mapping.0.mixers.1.feature_mixer.norm.norm.weight', 'decoder_with_head.fft_softmax_mapping.0.mixers.1.patch_mixer.gating_block.attn_layer.bias', 'decoder_with_head.fft_softmax_mapping.0.mixers.1.patch_mixer.gating_block.attn_layer.weight', 'decoder_with_head.fft_softmax_mapping.0.mixers.1.patch_mixer.mlp.fc1.bias', 'decoder_with_head.fft_softmax_mapping.0.mixers.1.patch_mixer.mlp.fc1.weight', 'decoder_with_head.fft_softmax_mapping.0.mixers.1.patch_mixer.mlp.fc2.bias', 'decoder_with_head.fft_softmax_mapping.0.mixers.1.patch_mixer.mlp.fc2.weight', 'decoder_with_head.fft_softmax_mapping.0.mixers.1.patch_mixer.norm.norm.bias', 'decoder_with_head.fft_softmax_mapping.0.mixers.1.patch_mixer.norm.norm.weight', 'decoder_with_head.fft_softmax_mapping.0.mixers.2.feature_mixer.gating_block.attn_layer.bias', 'decoder_with_head.fft_softmax_mapping.0.mixers.2.feature_mixer.gating_block.attn_layer.weight', 'decoder_with_head.fft_softmax_mapping.0.mixers.2.feature_mixer.mlp.fc1.bias', 'decoder_with_head.fft_softmax_mapping.0.mixers.2.feature_mixer.mlp.fc1.weight', 'decoder_with_head.fft_softmax_mapping.0.mixers.2.feature_mixer.mlp.fc2.bias', 'decoder_with_head.fft_softmax_mapping.0.mixers.2.feature_mixer.mlp.fc2.weight', 'decoder_with_head.fft_softmax_mapping.0.mixers.2.feature_mixer.norm.norm.bias', 'decoder_with_head.fft_softmax_mapping.0.mixers.2.feature_mixer.norm.norm.weight', 'decoder_with_head.fft_softmax_mapping.0.mixers.2.patch_mixer.gating_block.attn_layer.bias', 'decoder_with_head.fft_softmax_mapping.0.mixers.2.patch_mixer.gating_block.attn_layer.weight', 'decoder_with_head.fft_softmax_mapping.0.mixers.2.patch_mixer.mlp.fc1.bias', 'decoder_with_head.fft_softmax_mapping.0.mixers.2.patch_mixer.mlp.fc1.weight', 
'decoder_with_head.fft_softmax_mapping.0.mixers.2.patch_mixer.mlp.fc2.bias', 'decoder_with_head.fft_softmax_mapping.0.mixers.2.patch_mixer.mlp.fc2.weight', 'decoder_with_head.fft_softmax_mapping.0.mixers.2.patch_mixer.norm.norm.bias', 'decoder_with_head.fft_softmax_mapping.0.mixers.2.patch_mixer.norm.norm.weight', 'decoder_with_head.fft_softmax_mapping.2.bias', 'decoder_with_head.fft_softmax_mapping.2.weight', 'decoder_with_head.fft_softmax_mapping.3.bias', 'decoder_with_head.fft_softmax_mapping.3.weight']\n",
      "- This IS expected if you are initializing TSPulseForReconstruction from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
      "- This IS NOT expected if you are initializing TSPulseForReconstruction from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
      "Some weights of TSPulseForReconstruction were not initialized from the model checkpoint at ../../model-binaries/tspulse_hybrid_sign20/tspulse_model and are newly initialized: ['decoder_with_head.decoder.decoder_block.mixers.0.channel_feature_mixer.gating_block.attn_layer.bias', 'decoder_with_head.decoder.decoder_block.mixers.0.channel_feature_mixer.gating_block.attn_layer.weight', 'decoder_with_head.decoder.decoder_block.mixers.0.channel_feature_mixer.mlp.fc1.bias', 'decoder_with_head.decoder.decoder_block.mixers.0.channel_feature_mixer.mlp.fc1.weight', 'decoder_with_head.decoder.decoder_block.mixers.0.channel_feature_mixer.mlp.fc2.bias', 'decoder_with_head.decoder.decoder_block.mixers.0.channel_feature_mixer.mlp.fc2.weight', 'decoder_with_head.decoder.decoder_block.mixers.0.channel_feature_mixer.norm.norm.bias', 'decoder_with_head.decoder.decoder_block.mixers.0.channel_feature_mixer.norm.norm.weight', 'decoder_with_head.decoder.decoder_block.mixers.1.channel_feature_mixer.gating_block.attn_layer.bias', 'decoder_with_head.decoder.decoder_block.mixers.1.channel_feature_mixer.gating_block.attn_layer.weight', 'decoder_with_head.decoder.decoder_block.mixers.1.channel_feature_mixer.mlp.fc1.bias', 'decoder_with_head.decoder.decoder_block.mixers.1.channel_feature_mixer.mlp.fc1.weight', 'decoder_with_head.decoder.decoder_block.mixers.1.channel_feature_mixer.mlp.fc2.bias', 'decoder_with_head.decoder.decoder_block.mixers.1.channel_feature_mixer.mlp.fc2.weight', 'decoder_with_head.decoder.decoder_block.mixers.1.channel_feature_mixer.norm.norm.bias', 'decoder_with_head.decoder.decoder_block.mixers.1.channel_feature_mixer.norm.norm.weight']\n",
      "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Initializing Linear layers with method: pytorch\n",
      "Initializing Linear layers with method: pytorch\n",
      "Initializing Linear layers with method: pytorch\n",
      "Identity Init in Module:  TSPulseChannelFeatureMixerBlock\n",
      "Init identity weights for channel mixing\n",
      "Try identity init in Gated Attention.\n",
      "Initializing Linear layers with method: pytorch\n",
      "Initializing Linear layers with method: pytorch\n",
      "Initializing Linear layers with method: pytorch\n",
      "Identity Init in Module:  TSPulseChannelFeatureMixerBlock\n",
      "Init identity weights for channel mixing\n",
      "Try identity init in Gated Attention.\n"
     ]
    }
   ],
   "source": [
    "model_path = \"../../model-binaries/tspulse_hybrid_sign20/tspulse_model\"\n",
    "\n",
    "# Keyword overrides passed to from_pretrained alongside the checkpoint.\n",
    "model_dict = {\n",
    "    \"mask_ratio\": mask_ratio,\n",
    "    \"mask_type\": mask_type,\n",
    "    \"prediction_length\": FORECAST_LEN,  # keep in sync with the loaders' target_points\n",
    "    \"fft_time_add_forecasting_pt_loss\": False,\n",
    "    \"enable_fft_prob_loss\": False,\n",
    "    \"fft_time_consistent_masking\": True,\n",
    "    \"fft_original_signal_loss_weight\": 0,\n",
    "    \"loss_apply_mode\": \"mask\",\n",
    "    \"fft_weight\": 0,\n",
    "    # NOTE(review): scales with mask_ratio (0.375 -> 12 full patches); confirm what 0.125 and 4 encode.\n",
    "    \"num_full_patches_for_hybrid_mask\": int((mask_ratio / 0.125) * 4),\n",
    "    \"decoder_mode\": \"mix_channel\",\n",
    "    \"channel_consistent_masking\": False,\n",
    "    \"dropout\": 0,\n",
    "    \"head_dropout\": 0,\n",
    "}\n",
    "\n",
    "# The channel count is read from the data so the model matches the dataset's width.\n",
    "model_dict[\"num_input_channels\"] = train_dataset[0][\"past_values\"].shape[-1]\n",
    "model = TSPulseForReconstruction.from_pretrained(model_path, **model_dict).to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "8981979e-7c42-4f03-b20b-9b5f08493b4c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Directory that receives the Trainer's logs during fine-tuning.\n",
    "OUT_DIR = \"tspulse_finetuned_models/\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "66d1308f-89f1-4c93-a618-d35bad23830f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Place the model on the configured device and cast its parameters to float32 before training.\n",
    "model = model.to(device).float()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "d0b54983-3985-4c2b-b828-a1e1b3fac588",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Make every parameter trainable so the whole model is fine-tuned.\n",
    "model.requires_grad_(True);"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0e05aa2f-bff0-4fa0-b0c2-dfd7ef7dd971",
   "metadata": {},
   "source": [
    "## Fine-tuning the full model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "619502da-53e8-4978-bdba-b3448c22fcc1",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "LR Finder: Running learning rate (LR) finder algorithm. If the suggested LR is very low, we suggest setting the LR manually.\n",
      "LR Finder: Using GPU:0.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "LR Finder: Suggested learning rate = 0.00020565123083486514\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "      \n",
       "      <progress value='45765' max='101700' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      [ 45765/101700 20:46 < 25:23, 36.72 it/s, Epoch 45/100]\n",
       "    </div>\n",
       "    <table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       " <tr style=\"text-align: left;\">\n",
       "      <th>Epoch</th>\n",
       "      <th>Training Loss</th>\n",
       "      <th>Validation Loss</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>1</td>\n",
       "      <td>0.124500</td>\n",
       "      <td>0.147419</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2</td>\n",
       "      <td>0.113900</td>\n",
       "      <td>0.142947</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3</td>\n",
       "      <td>0.108500</td>\n",
       "      <td>0.140042</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4</td>\n",
       "      <td>0.104700</td>\n",
       "      <td>0.138109</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>5</td>\n",
       "      <td>0.100300</td>\n",
       "      <td>0.136230</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>6</td>\n",
       "      <td>0.094100</td>\n",
       "      <td>0.134347</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>7</td>\n",
       "      <td>0.086300</td>\n",
       "      <td>0.137088</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>8</td>\n",
       "      <td>0.079400</td>\n",
       "      <td>0.134114</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>9</td>\n",
       "      <td>0.074700</td>\n",
       "      <td>0.130689</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>10</td>\n",
       "      <td>0.071600</td>\n",
       "      <td>0.128349</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>11</td>\n",
       "      <td>0.068800</td>\n",
       "      <td>0.124605</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>12</td>\n",
       "      <td>0.066700</td>\n",
       "      <td>0.122224</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>13</td>\n",
       "      <td>0.065200</td>\n",
       "      <td>0.120375</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>14</td>\n",
       "      <td>0.063400</td>\n",
       "      <td>0.117920</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>15</td>\n",
       "      <td>0.062000</td>\n",
       "      <td>0.118346</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>16</td>\n",
       "      <td>0.060900</td>\n",
       "      <td>0.116761</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>17</td>\n",
       "      <td>0.059800</td>\n",
       "      <td>0.114692</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>18</td>\n",
       "      <td>0.058800</td>\n",
       "      <td>0.114011</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>19</td>\n",
       "      <td>0.058000</td>\n",
       "      <td>0.113132</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>20</td>\n",
       "      <td>0.057000</td>\n",
       "      <td>0.110611</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>21</td>\n",
       "      <td>0.056100</td>\n",
       "      <td>0.111019</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>22</td>\n",
       "      <td>0.055500</td>\n",
       "      <td>0.111485</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>23</td>\n",
       "      <td>0.054500</td>\n",
       "      <td>0.108878</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>24</td>\n",
       "      <td>0.053900</td>\n",
       "      <td>0.105477</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>25</td>\n",
       "      <td>0.053400</td>\n",
       "      <td>0.106081</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>26</td>\n",
       "      <td>0.052800</td>\n",
       "      <td>0.106087</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>27</td>\n",
       "      <td>0.052400</td>\n",
       "      <td>0.104086</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>28</td>\n",
       "      <td>0.051700</td>\n",
       "      <td>0.104113</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>29</td>\n",
       "      <td>0.051500</td>\n",
       "      <td>0.104038</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>30</td>\n",
       "      <td>0.050900</td>\n",
       "      <td>0.101676</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>31</td>\n",
       "      <td>0.050500</td>\n",
       "      <td>0.101702</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>32</td>\n",
       "      <td>0.050200</td>\n",
       "      <td>0.101286</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>33</td>\n",
       "      <td>0.049800</td>\n",
       "      <td>0.100265</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>34</td>\n",
       "      <td>0.049300</td>\n",
       "      <td>0.100978</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>35</td>\n",
       "      <td>0.049100</td>\n",
       "      <td>0.098621</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>36</td>\n",
       "      <td>0.048700</td>\n",
       "      <td>0.099243</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>37</td>\n",
       "      <td>0.048200</td>\n",
       "      <td>0.099059</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>38</td>\n",
       "      <td>0.048000</td>\n",
       "      <td>0.098801</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>39</td>\n",
       "      <td>0.047800</td>\n",
       "      <td>0.100396</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>40</td>\n",
       "      <td>0.047400</td>\n",
       "      <td>0.099082</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>41</td>\n",
       "      <td>0.047200</td>\n",
       "      <td>0.099863</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>42</td>\n",
       "      <td>0.046800</td>\n",
       "      <td>0.099860</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>43</td>\n",
       "      <td>0.046600</td>\n",
       "      <td>0.098920</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>44</td>\n",
       "      <td>0.046400</td>\n",
       "      <td>0.099554</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>45</td>\n",
       "      <td>0.046200</td>\n",
       "      <td>0.098583</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table><p>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "TrainOutput(global_step=45765, training_loss=0.06296574823793578, metrics={'train_runtime': 1246.5086, 'train_samples_per_second': 652.142, 'train_steps_per_second': 81.588, 'total_flos': 7837779087452160.0, 'train_loss': 0.06296574823793578, 'epoch': 45.0})"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Set to a float to skip the LR finder and fine-tune with that learning rate.\n",
    "suggested_lr = None\n",
    "\n",
    "# Single source of truth for the Trainer configuration: every key below is passed\n",
    "# to TrainingArguments, so nothing here is silently ignored.\n",
    "train_dict = {\n",
    "    \"overwrite_output_dir\": True,\n",
    "    \"num_train_epochs\": 100,\n",
    "    \"eval_strategy\": \"epoch\",\n",
    "    \"per_device_train_batch_size\": 8,\n",
    "    \"per_device_eval_batch_size\": 8,\n",
    "    \"dataloader_num_workers\": 1,\n",
    "    \"eval_accumulation_steps\": 50,\n",
    "    \"report_to\": \"tensorboard\",\n",
    "    \"save_strategy\": \"epoch\",\n",
    "    \"logging_strategy\": \"epoch\",\n",
    "    \"save_total_limit\": 1,\n",
    "    \"load_best_model_at_end\": True,  # Load the best model when training ends\n",
    "    \"metric_for_best_model\": \"eval_loss\",  # Metric to monitor for early stopping\n",
    "    \"greater_is_better\": False,  # For loss\n",
    "    \"seed\": seed,\n",
    "}\n",
    "\n",
    "EPOCHS = train_dict[\"num_train_epochs\"]\n",
    "BATCH_SIZE = train_dict[\"per_device_train_batch_size\"]\n",
    "NUM_GPUS = 1\n",
    "\n",
    "set_seed(seed)\n",
    "if suggested_lr is None:\n",
    "    suggested_lr, model = optimal_lr_finder(\n",
    "        model,\n",
    "        train_dataset,\n",
    "        batch_size=BATCH_SIZE,\n",
    "    )\n",
    "\n",
    "finetune_args = TrainingArguments(\n",
    "    # Checkpoints go under OUT_DIR, next to the logs, so they are easy to find and clean up.\n",
    "    output_dir=os.path.join(OUT_DIR, \"checkpoints\"),\n",
    "    logging_dir=os.path.join(OUT_DIR, \"output\"),\n",
    "    learning_rate=suggested_lr,\n",
    "    do_eval=True,\n",
    "    **train_dict,\n",
    ")\n",
    "\n",
    "early_stopping_callback = EarlyStoppingCallback(\n",
    "    early_stopping_patience=10,  # Number of epochs with no improvement after which to stop\n",
    "    early_stopping_threshold=0.0001,  # Minimum improvement required to consider as improvement\n",
    ")\n",
    "\n",
    "# Custom optimizer + one-cycle schedule. steps_per_epoch must match the Trainer's optimizer\n",
    "# steps per epoch (ceil: the last partial batch is kept, no gradient accumulation), otherwise\n",
    "# OneCycleLR runs out of scheduled steps before training ends.\n",
    "optimizer = AdamW(model.parameters(), lr=suggested_lr)\n",
    "scheduler = OneCycleLR(\n",
    "    optimizer,\n",
    "    suggested_lr,\n",
    "    epochs=EPOCHS,\n",
    "    steps_per_epoch=math.ceil(len(train_dataset) / (BATCH_SIZE * NUM_GPUS)),\n",
    ")\n",
    "\n",
    "finetune_trainer = Trainer(\n",
    "    model=model,\n",
    "    args=finetune_args,\n",
    "    train_dataset=train_dataset,\n",
    "    eval_dataset=valid_dataset,\n",
    "    callbacks=[early_stopping_callback],\n",
    "    optimizers=(optimizer, scheduler),\n",
    ")\n",
    "\n",
    "# Fine tune\n",
    "finetune_trainer.train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "a2b21701-3562-4900-9db4-12ca8e364e12",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Persist the best weights (restored via load_best_model_at_end) so the evaluation can reload them.\n",
    "finetuned_dir = \"finetuned_models\"\n",
    "os.makedirs(finetuned_dir, exist_ok=True)\n",
    "path_to_save_model = f\"{finetuned_dir}/finetuned_model_{DATASET}_{mask_ratio}_{mask_type}\"\n",
    "finetune_trainer.save_model(path_to_save_model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "78c9c5a7-caa7-4fb3-b297-117655db7a79",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "a2280021-9005-46d5-bfae-747d86f1f294",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 46/46 [00:00<00:00, 62.20it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Dataset = etth1  : Mask Type = hybrid  : Mask Ratio = 0.375\n",
      "Mean Squarred Error (MSE)=0.065\n"
     ]
    }
   ],
   "source": [
    "# Inference batch size: 64 for the ETT benchmarks, 4 for any other dataset.\n",
    "batch_size = 64 if DATASET in [\"etth1\", \"etth2\", \"ettm1\", \"ettm2\"] else 4\n",
    "\n",
    "# Reuse the test split loaded earlier instead of rebuilding identical loaders.\n",
    "num_channels = test_dataset[0][\"past_values\"].shape[-1]\n",
    "test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)\n",
    "\n",
    "# Reload the fine-tuned weights saved above.\n",
    "# NOTE(review): mask_type=\"user\" presumably makes the model use the caller's past_observed_mask - confirm.\n",
    "model = TSPulseForReconstruction.from_pretrained(\n",
    "    path_to_save_model,\n",
    "    fft_time_add_forecasting_pt_loss=False,\n",
    "    num_input_channels=num_channels,\n",
    "    mask_type=\"user\",\n",
    ").to(device)\n",
    "model.eval()  # from_pretrained already returns eval mode; explicit for inference\n",
    "\n",
    "# NOTE(review): presumably the patch length mask_generate uses to build masks - confirm.\n",
    "MASK_PATCH_LEN = 8\n",
    "\n",
    "# Dedicated, seeded generator so the evaluation masks are reproducible.\n",
    "g = torch.Generator(device=device)\n",
    "g.manual_seed(seed)\n",
    "\n",
    "trues, preds, masks = [], [], []\n",
    "with torch.no_grad():\n",
    "    for batch in tqdm(test_dataloader):\n",
    "        batch_x = batch[\"past_values\"].to(device)  # b l c\n",
    "\n",
    "        mask = mask_generate(g, batch_x, MASK_PATCH_LEN, mask_ratio, mask_type)\n",
    "\n",
    "        # Positions not covered by the mask are treated as observed.\n",
    "        output = model(past_values=batch_x, past_observed_mask=~mask)\n",
    "\n",
    "        trues.append(batch_x.detach().cpu().numpy())\n",
    "        preds.append(output.reconstruction_outputs.detach().cpu().numpy())\n",
    "        masks.append(mask.detach().cpu().numpy())\n",
    "\n",
    "preds = np.concatenate(preds)\n",
    "trues = np.concatenate(trues)\n",
    "masks = np.concatenate(masks)\n",
    "\n",
    "# Score only the masked positions, i.e. the values the model had to impute.\n",
    "MSE = mse(y=trues[masks == 1], y_hat=preds[masks == 1], reduction=\"mean\")\n",
    "print(f\"Dataset = {DATASET}  : Mask Type = {mask_type}  : Mask Ratio = {mask_ratio}\")\n",
    "print(f\"Mean Squared Error (MSE)={MSE:.3f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d79010d2-d1b0-490a-9e76-15b66d413863",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
