{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "args: Namespace(cls=0, reverse=1, is_soft_instance=0, is_soft_temporal=0, padmask=0, mask_ends=0, permute=0, share=1, loss_fn='mse', loss_weight='uniform', loss_sigma=0, tau_inst=0, tau_temp=0, head_dropout=0.2, dset_pretrain='traffic', context_points=768, target_points=96, batch_size=32, num_workers=0, scaler='standard', features='M', patch_len=24, stride=24, revin=1, d_model=256, mask_ratio=0.5, mask_schedule=0, n_epochs_pretrain=100, lr=0.001, pretrained_model_id=1, model_type='based_model', device_id=6, seed=1, overlap=0.7)\n",
      "traffic\n",
      "New Training\n",
      "---------------------------------------------------------------------------------------------------- \n",
      " ----------------------------------------------------------------------------------------------------\n",
      "Finding the best Learning Rate\n",
      "==================================================\n",
      "Loading DataLoaders\n",
      "==================================================\n",
      "X shape : torch.Size([768, 862])\n",
      "Y shape : torch.Size([96, 862])\n",
      "dls.vars : 862\n",
      "dls.c (=output length): 96\n",
      "dls.len (=input length): 768\n",
      "--------------------------------------------------------------------------------\n",
      "==================================================\n",
      "Loading Models\n",
      "==================================================\n",
      "input TS length: 768\n",
      "patch size: 24\n",
      "number of patches: 32\n",
      "Non-Overlapping patches...\n",
      "--------------------------------------------------------------------------------\n",
      "number of model params 78360\n",
      "suggested_lr 0.0013219411484660286\n",
      "---------------------------------------------------------------------------------------------------- \n",
      " ----------------------------------------------------------------------------------------------------\n",
      "Start Pretraining\n",
      "==================================================\n",
      "Loading DataLoaders\n",
      "==================================================\n",
      "X shape : torch.Size([768, 862])\n",
      "Y shape : torch.Size([96, 862])\n",
      "dls.vars : 862\n",
      "dls.c (=output length): 96\n",
      "dls.len (=input length): 768\n",
      "--------------------------------------------------------------------------------\n",
      "==================================================\n",
      "Loading Models\n",
      "==================================================\n",
      "input TS length: 768\n",
      "patch size: 24\n",
      "number of patches: 32\n",
      "Non-Overlapping patches...\n",
      "--------------------------------------------------------------------------------\n",
      "number of model params 78360\n",
      "          epoch     train_loss     valid_loss           time\n",
      "              0      49.928336      35.189739          02:12\n",
      "              1      30.031951      20.736017          02:13\n",
      "              2      17.816150      11.717252          02:12\n",
      "              3      10.040677       6.252008          02:12\n",
      "              4       5.368891       3.182643          02:12\n",
      "              5       2.767160       1.592461          02:12\n",
      "              6       1.384788       0.773343          02:12\n",
      "              7       0.676664       0.374969          02:14\n",
      "              8       0.337965       0.201726          02:17\n",
      "              9       0.188207       0.124387          02:10\n",
      "             10       0.119517       0.086678          02:11\n",
      "             11       0.083751       0.064808          02:10\n",
      "             12       0.063200       0.051513          03:51\n",
      "             13       0.051381       0.044602          02:10\n",
      "             14       0.044557       0.040420          02:10\n",
      "             15       0.040550       0.038229          02:10\n",
      "             16       0.038455       0.036840          02:10\n",
      "             17       0.037092       0.035801          02:11\n",
      "             18       0.036078       0.036093          02:11\n",
      "             19       0.035479       0.034874          02:11\n",
      "             20       0.034989       0.034517          02:10\n",
      "             21       0.034686       0.034553          02:10\n",
      "             22       0.034465       0.033767          02:09\n",
      "             23       0.033890       0.033462          02:13\n",
      "             24       0.033506       0.033383          02:11\n",
      "             25       0.033204       0.032425          02:11\n",
      "             26       0.032596       0.032967          02:10\n",
      "             27       0.032687       0.032195          02:09\n",
      "             28       0.032273       0.031687          02:12\n",
      "             29       0.032082       0.031467          02:10\n",
      "             30       0.031773       0.031336          02:12\n",
      "             31       0.031484       0.031638          02:10\n",
      "             32       0.031554       0.031080          02:09\n",
      "             33       0.031145       0.030503          04:27\n",
      "             34       0.031195       0.030495          04:41\n",
      "             35       0.030874       0.031024          02:08\n",
      "             36       0.030895       0.031316          02:08\n",
      "             37       0.031147       0.030494          02:11\n"
     ]
    }
   ],
   "source": [
    "#%%capture\n",
    "import os\n",
    "\n",
    "# Project root: set the PATCHTST_SSL_DIR environment variable to run from another checkout.\n",
    "PROJECT_DIR = os.environ.get('PATCHTST_SSL_DIR', '/home/seunghan9613/PatchTST_sim/PatchTST_self_supervised')\n",
    "os.chdir(PROJECT_DIR)\n",
    "\n",
    "#ds_pretrain_list = ['etth1','etth2','ettm1','ettm2']\n",
    "dset_pretrain = 'traffic'\n",
    "d = dset_pretrain\n",
    "ep_pretrain = 100\n",
    "\n",
    "device = 6  # device index passed to --device_id\n",
    "########################################################\n",
    "patch_len = 24\n",
    "stride = patch_len  # stride == patch_len -> non-overlapping patches\n",
    "cp = 768  # context length passed as --context_points\n",
    "# Named num_patch (not `np`) so the conventional numpy alias is not shadowed; informational only.\n",
    "num_patch = cp // stride\n",
    "########################################################\n",
    "\n",
    "for lr in [1e-3]:\n",
    "    for d_model in [256]:\n",
    "        !python patchtst_pretrain_sim_half_v3_mean_FC2_sep.py \\\n",
    "            --device_id {device} --dset_pretrain {d} \\\n",
    "            --n_epochs_pretrain {ep_pretrain} --reverse 1 --context_points {cp} \\\n",
    "            --d_model {d_model} --patch_len {patch_len} --stride {stride} --lr {lr} --batch_size 32"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#%%capture\n",
    "import os\n",
    "\n",
    "# Project root: set the PATCHTST_SSL_DIR environment variable to run from another checkout.\n",
    "PROJECT_DIR = os.environ.get('PATCHTST_SSL_DIR', '/home/seunghan9613/PatchTST_sim/PatchTST_self_supervised')\n",
    "os.chdir(PROJECT_DIR)\n",
    "\n",
    "#ds_pretrain_list = ['etth1','etth2','ettm1','ettm2']\n",
    "dset_pretrain = 'traffic'\n",
    "d = dset_pretrain\n",
    "ep_pretrain = 100\n",
    "\n",
    "device = 6  # device index passed to --device_id\n",
    "########################################################\n",
    "patch_len = 18\n",
    "stride = patch_len  # stride == patch_len -> non-overlapping patches\n",
    "cp = 768  # context length passed as --context_points\n",
    "# NOTE: 768 is not a multiple of 18 (768 % 18 == 12): only 42 full patches fit and 12 time steps are left over.\n",
    "# How the pretraining script handles those leftover steps is not visible here - confirm before comparing with patch_len=24.\n",
    "# Named num_patch (not `np`) so the conventional numpy alias is not shadowed; informational only.\n",
    "num_patch = cp // stride\n",
    "########################################################\n",
    "\n",
    "for lr in [1e-3]:\n",
    "    for d_model in [128,256]:\n",
    "        !python patchtst_pretrain_sim_half_v3_mean_FC2_sep.py \\\n",
    "            --device_id {device} --dset_pretrain {d} \\\n",
    "            --n_epochs_pretrain {ep_pretrain} --reverse 1 --context_points {cp} \\\n",
    "            --d_model {d_model} --patch_len {patch_len} --stride {stride} --lr {lr} --batch_size 32"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "ssl_ts",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.11"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
