{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import numpy as np\n",
    "import pandas as pd"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_mse_mae1(PATH, settings_list, cw,patch_size, stride, model_dim, load_epoch, ft_epoch):\n",
    "    y = [x for x in settings_list if f'z' not in x]\n",
    "    y = [x for x in y if f'patch{patch_size}' in x]\n",
    "    y = [x for x in y if f'stride{stride}' in x]\n",
    "    y = [x for x in y if f'cw{cw}' in x]\n",
    "    y = [x for x in y if f'_D{model_dim}' in x]\n",
    "    try:\n",
    "        arch = os.path.join(PATH,y[0])\n",
    "    except:\n",
    "        \n",
    "        return [999,999,999,999], [999,999,999,999]\n",
    "    #print(arch)\n",
    "    files = os.listdir(arch)\n",
    "    if len(files)==0:\n",
    "        #print(123)\n",
    "        return [999,999,999,999], [999,999,999,999]\n",
    "    files = [x for x in files if 'acc' in x]\n",
    "    files = [x for x in files if f'load_ep{load_epoch}' in x]\n",
    "    files = [x for x in files if f'lp_ep{ft_epoch}' in x]\n",
    "    #files = [x for x in files if f'_lp_' in x]\n",
    "    idx_order = [int(x.split('_')[0].split('tw')[1]) for x in files]\n",
    "    files =  list(np.array(files)[np.argsort(idx_order)])\n",
    "    files = [os.path.join(x) for x in files]\n",
    "    files = [os.path.join(arch, f) for f in files]\n",
    "    mse_result = []\n",
    "    mae_result = []\n",
    "    for f in files:\n",
    "        try:\n",
    "            df = pd.read_csv(f)\n",
    "            mse_result.append(df['mse'][0])\n",
    "            mae_result.append(df['mae'][0])\n",
    "        except:\n",
    "            mse_result.append(999)\n",
    "            mae_result.append(999)\n",
    "    return mse_result, mae_result"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_summary_hard(dataset_list, cw_list = [336,512,768,1024],\n",
    "                      ps_list = [12,18,24], \n",
    "                      load_ep_list = [40,50,60,80,100,120,150],\n",
    "                      ft_ep_list = [10,20,40,60,80],\n",
    "                      dim_list = [32,64,128,256],\n",
    "                      lr = 1e-4,\n",
    "                      same_patch = True,\n",
    "                      mask = 0.5,\n",
    "                      stride_size = 999):\n",
    "    \n",
    "    for cw in cw_list:\n",
    "        for ps in ps_list:\n",
    "            if same_patch:\n",
    "                stride = ps\n",
    "            else:\n",
    "                stride = stride_size\n",
    "            min_val = 999999\n",
    "            print('='*60)\n",
    "            print(f'-------- cw={cw},ps={ps} ---------')\n",
    "            print('='*60)\n",
    "            for dim in dim_list:\n",
    "                for load_ep in load_ep_list:\n",
    "                    for ft_ep in ft_ep_list:\n",
    "                        data_sum = 0\n",
    "                        data_results = []\n",
    "                        data_results2 = []\n",
    "                        print('~~~~~~~~~~~'*8)\n",
    "                        for data in dataset_list:\n",
    "                            \n",
    "                            PATH1 = '/home/seunghan9613/PatchTST_sim/PatchTST_self_supervised/saved_models/'\n",
    "                            PATH2 = f'{data}2{data}'\n",
    "                            PATH3 = f'masked_patchtst_sim_half_v3_mean_FC2_sep_R/based_model/max'\n",
    "                            PATH = os.path.join(PATH1, PATH2, PATH3)\n",
    "                            settings = os.listdir(PATH)\n",
    "                            settings = [x for x in settings if f'mask{mask}' in x]\n",
    "                            \n",
    "\n",
    "                            for share in [1]:\n",
    "                                settings = [x for x in settings if 'no_share' not in x]\n",
    "                                settings = [x for x in settings if 'tau' not in x]\n",
    "                                mse_result, mae_result = get_mse_mae1(PATH = PATH, settings_list = settings,\n",
    "                                                                    cw=cw, patch_size=ps,stride=stride, \n",
    "                                                                    model_dim=dim, \n",
    "                                                                    load_epoch = load_ep, ft_epoch=ft_ep)\n",
    "                                \n",
    "                                try:\n",
    "                                    if len(mse_result)>0:\n",
    "                                        data_sum += np.sum(mse_result)\n",
    "                                        data_results.append([data, np.mean(mse_result).round(3), mse_result])\n",
    "                                        data_results2.append([data, np.mean(mae_result).round(3), mae_result])\n",
    "                                    \n",
    "                                except:\n",
    "                                    pass\n",
    "                                    \n",
    "                        \n",
    "                        for i,k in enumerate(data_results):\n",
    "                            if k[1]!=999:\n",
    "                                print(cw,ps,[dim,load_ep,ft_ep])\n",
    "                                print(data_results[i])\n",
    "                                print(data_results2[i])\n",
    "                            \n",
    "                        #if min_val > data_sum:\n",
    "                        #    min_val = data_sum\n",
    "                        #    print('-'*50)\n",
    "                        #    print([dim,load_ep,ft_ep])\n",
    "                        #    for k in data_results:\n",
    "                        #        print(k)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "ep_pretrain = 150        \n",
    "load = 100\n",
    "REVERSE = 1\n",
    "\n",
    "patch_len = 12\n",
    "stride = patch_len\n",
    "cp = 512"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "============================================================\n",
      "-------- cw=512,ps=12 ---------\n",
      "============================================================\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
      "512 12 [128, 100, 10]\n",
      "['etth1', 0.403, [0.365846, 0.397887, 0.418816, 0.430048]]\n",
      "['etth1', 0.422, [0.391674, 0.413596, 0.428234, 0.453614]]\n",
      "512 12 [128, 100, 10]\n",
      "['etth2', 0.334, [0.269377, 0.33189, 0.352427, 0.383354]]\n",
      "['etth2', 0.382, [0.333125, 0.373433, 0.395061, 0.425904]]\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
      "512 12 [128, 100, 20]\n",
      "['etth1', 0.408, [0.366387, 0.398322, 0.423847, 0.443452]]\n",
      "['etth1', 0.423, [0.39189, 0.413082, 0.429283, 0.459689]]\n",
      "512 12 [128, 100, 20]\n",
      "['etth2', 0.338, [0.273597, 0.337479, 0.355617, 0.383834]]\n",
      "['etth2', 0.383, [0.334609, 0.375743, 0.396659, 0.426349]]\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
      "512 12 [128, 100, 40]\n",
      "['etth1', 0.409, [0.368184, 0.399129, 0.421221, 0.445949]]\n",
      "['etth1', 0.424, [0.393511, 0.412986, 0.428699, 0.461133]]\n",
      "512 12 [128, 100, 40]\n",
      "['etth2', 0.339, [0.276504, 0.337047, 0.355445, 0.386487]]\n",
      "['etth2', 0.384, [0.337228, 0.375653, 0.39652, 0.426843]]\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
      "512 12 [128, 100, 60]\n",
      "['etth1', 0.41, [0.367244, 0.400256, 0.427214, 0.444801]]\n",
      "['etth1', 0.424, [0.393, 0.412888, 0.430421, 0.460626]]\n",
      "512 12 [128, 100, 60]\n",
      "['etth2', 0.339, [0.272681, 0.3398, 0.355123, 0.386613]]\n",
      "['etth2', 0.384, [0.333877, 0.377625, 0.396629, 0.426789]]\n"
     ]
    }
   ],
   "source": [
    "dataset_list = ['etth1','etth2']\n",
    "\n",
    "get_summary_hard(dataset_list, cw_list = [512],\n",
    "                 ps_list = [12],\n",
    "                 load_ep_list = [100],\n",
    "                 dim_list = [128],\n",
    "                 ft_ep_list = [10,20,40,60])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "============================================================\n",
      "-------- cw=768,ps=24 ---------\n",
      "============================================================\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
      "768 24 [128, 100, 10]\n",
      "['weather', 0.239, [0.167878, 0.213109, 0.256246, 0.319912]]\n",
      "['weather', 0.279, [0.222199, 0.261458, 0.292794, 0.338656]]\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
      "768 24 [128, 100, 20]\n",
      "['weather', 0.239, [0.167716, 0.211861, 0.255975, 0.319517]]\n",
      "['weather', 0.278, [0.222041, 0.259535, 0.292653, 0.337991]]\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
      "768 24 [128, 100, 40]\n",
      "['weather', 0.239, [0.167462, 0.211691, 0.256053, 0.319141]]\n",
      "['weather', 0.278, [0.221253, 0.259131, 0.292767, 0.337779]]\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
      "768 24 [128, 100, 60]\n",
      "['weather', 0.239, [0.167344, 0.211773, 0.256702, 0.319168]]\n",
      "['weather', 0.278, [0.221112, 0.259298, 0.293135, 0.337786]]\n"
     ]
    }
   ],
   "source": [
    "dataset_list = ['weather']\n",
    "\n",
    "get_summary_hard(dataset_list, cw_list = [768],\n",
    "                 ps_list = [24],\n",
    "                 load_ep_list = [100],\n",
    "                 dim_list = [128],\n",
    "                 ft_ep_list = [10,20,40,60])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "============================================================\n",
      "-------- cw=768,ps=18 ---------\n",
      "============================================================\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
      "768 18 [64, 100, 10]\n",
      "['ettm1', 0.357, [0.307326, 0.338512, 0.366165, 0.414497]]\n",
      "['ettm1', 0.379, [0.350017, 0.36853, 0.384788, 0.41139]]\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
      "768 18 [64, 100, 20]\n",
      "['ettm1', 0.357, [0.307496, 0.339571, 0.365736, 0.414419]]\n",
      "['ettm1', 0.379, [0.349867, 0.369538, 0.384507, 0.41166]]\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
      "768 18 [64, 100, 40]\n",
      "['ettm1', 0.356, [0.307231, 0.337379, 0.365075, 0.415021]]\n",
      "['ettm1', 0.378, [0.349597, 0.36816, 0.383948, 0.412007]]\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
      "768 18 [64, 100, 60]\n",
      "['ettm1', 0.356, [0.307614, 0.337959, 0.365202, 0.41479]]\n",
      "['ettm1', 0.379, [0.349712, 0.368482, 0.384063, 0.411953]]\n"
     ]
    }
   ],
   "source": [
    "dataset_list = ['ettm1']\n",
    "\n",
    "get_summary_hard(dataset_list, cw_list = [768],\n",
    "                 ps_list = [18],\n",
    "                 load_ep_list = [100],\n",
    "                 dim_list = [64],\n",
    "                 ft_ep_list = [10,20,40,60])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "============================================================\n",
      "-------- cw=1024,ps=24 ---------\n",
      "============================================================\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
      "1024 24 [64, 100, 10]\n",
      "['ettm2', 0.245, [0.162885, 0.213217, 0.262819, 0.339211]]\n",
      "['ettm2', 0.311, [0.255236, 0.290433, 0.324812, 0.373649]]\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
      "1024 24 [64, 100, 20]\n",
      "['ettm2', 0.244, [0.160155, 0.213531, 0.262399, 0.34154]]\n",
      "['ettm2', 0.31, [0.252341, 0.289857, 0.323522, 0.376064]]\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
      "1024 24 [64, 100, 40]\n",
      "['ettm2', 0.244, [0.160499, 0.213968, 0.263456, 0.338818]]\n",
      "['ettm2', 0.31, [0.252701, 0.290202, 0.324443, 0.37397]]\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
      "1024 24 [64, 100, 60]\n",
      "['ettm2', 0.245, [0.160409, 0.212972, 0.26451, 0.343032]]\n",
      "['ettm2', 0.311, [0.25249, 0.289568, 0.325385, 0.375938]]\n"
     ]
    }
   ],
   "source": [
    "dataset_list = ['ettm2']\n",
    "\n",
    "get_summary_hard(dataset_list, cw_list = [1024],\n",
    "                 ps_list = [24],\n",
    "                 load_ep_list = [100],\n",
    "                 dim_list = [64],\n",
    "                 ft_ep_list = [10,20,40,60])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "============================================================\n",
      "-------- cw=768,ps=24 ---------\n",
      "============================================================\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
      "768 24 [256, 100, 10]\n",
      "['traffic', 0.413, [0.413033]]\n",
      "['traffic', 0.296, [0.296312]]\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
      "768 24 [256, 100, 20]\n",
      "['traffic', 0.409, [0.392538, 0.426028]]\n",
      "['traffic', 0.29, [0.277056, 0.30321]]\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n"
     ]
    }
   ],
   "source": [
    "#768 24 [256, 100, 40]\n",
    "dataset_list = ['traffic']\n",
    "\n",
    "get_summary_hard(dataset_list, cw_list = [768],\n",
    "                 ps_list = [24],\n",
    "                 load_ep_list = [100],\n",
    "                 dim_list = [256],\n",
    "                 ft_ep_list = [10,20,40,60])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "============================================================\n",
      "-------- cw=1024,ps=32 ---------\n",
      "============================================================\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
      "1024 32 [256, 100, 20]\n",
      "['electricity', 0.137, [0.136535]]\n",
      "['electricity', 0.235, [0.234999]]\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n"
     ]
    }
   ],
   "source": [
    "# FINAL : 1024/32\n",
    "# [256, 100, 60]\n",
    "dataset_list = ['electricity']\n",
    "\n",
    "get_summary_hard(dataset_list, cw_list = [1024],\n",
    "                 ps_list = [32],\n",
    "                 load_ep_list = [100],\n",
    "                 dim_list = [256],\n",
    "                 ft_ep_list = [10,20,40,60])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "ssl_ts",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.11"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
