{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import numpy as np\n",
    "import pandas as pd"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_mse_mae1(PATH, settings_list, cw,patch_size, stride, model_dim, load_epoch, ft_epoch):\n",
    "    y = [x for x in settings_list if f'z' not in x]\n",
    "    y = [x for x in y if f'patch{patch_size}' in x]\n",
    "    y = [x for x in y if f'stride{stride}' in x]\n",
    "    y = [x for x in y if f'cw{cw}' in x]\n",
    "    y = [x for x in y if f'_D{model_dim}' in x]\n",
    "    try:\n",
    "        arch = os.path.join(PATH,y[0])\n",
    "    except:\n",
    "        \n",
    "        return [999,999,999,999], [999,999,999,999]\n",
    "    #print(arch)\n",
    "    files = os.listdir(arch)\n",
    "    if len(files)==0:\n",
    "        #print(123)\n",
    "        return [999,999,999,999], [999,999,999,999]\n",
    "    files = [x for x in files if 'acc' in x]\n",
    "    files = [x for x in files if f'load_ep{load_epoch}' in x]\n",
    "    files = [x for x in files if f'ft_ep{ft_epoch}' in x]\n",
    "    idx_order = [int(x.split('_')[0].split('tw')[1]) for x in files]\n",
    "    files =  list(np.array(files)[np.argsort(idx_order)])\n",
    "    files = [os.path.join(x) for x in files]\n",
    "    files = [os.path.join(arch, f) for f in files]\n",
    "    mse_result = []\n",
    "    mae_result = []\n",
    "    for f in files:\n",
    "        try:\n",
    "            df = pd.read_csv(f)\n",
    "            mse_result.append(df['mse'][0])\n",
    "            mae_result.append(df['mae'][0])\n",
    "        except:\n",
    "            mse_result.append(999)\n",
    "            mae_result.append(999)\n",
    "    return mse_result, mae_result"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_summary_hard(dataset_list, cw_list = [336,512,768,1024],\n",
    "                      ps_list = [12,18,24], \n",
    "                      load_ep_list = [40,50,60,80,100,120,150],\n",
    "                      ft_ep_list = [10,20,40,60,80],\n",
    "                      dim_list = [32,64,128,256],\n",
    "                      lr = 1e-4,\n",
    "                      same_patch = True,\n",
    "                      mask = 0.5,\n",
    "                      stride_size = 999):\n",
    "    \n",
    "    for cw in cw_list:\n",
    "        for ps in ps_list:\n",
    "            if same_patch:\n",
    "                stride = ps\n",
    "            else:\n",
    "                stride = stride_size\n",
    "            min_val = 999999\n",
    "            print('='*60)\n",
    "            print(f'-------- cw={cw},ps={ps} ---------')\n",
    "            print('='*60)\n",
    "            for dim in dim_list:\n",
    "                for load_ep in load_ep_list:\n",
    "                    for ft_ep in ft_ep_list:\n",
    "                        data_sum = 0\n",
    "                        data_results = []\n",
    "                        data_results2 = []\n",
    "                        print('~~~~~~~~~~~'*8)\n",
    "                        for data in dataset_list:\n",
    "                            \n",
    "                            PATH1 = '/home/seunghan9613/PatchTST_sim/PatchTST_self_supervised/saved_models/'\n",
    "                            PATH2 = f'{data}2{data}'\n",
    "                            PATH3 = f'masked_patchtst_sim_half_v3_mean_FC2_wo_CL_R/based_model/max'\n",
    "                            PATH = os.path.join(PATH1, PATH2, PATH3)\n",
    "                            settings = os.listdir(PATH)\n",
    "                            settings = [x for x in settings if f'mask{mask}' in x]\n",
    "                            \n",
    "\n",
    "                            for share in [1]:\n",
    "                                settings = [x for x in settings if 'no_share' not in x]\n",
    "                                settings = [x for x in settings if 'tau' not in x]\n",
    "                                mse_result, mae_result = get_mse_mae1(PATH = PATH, settings_list = settings,\n",
    "                                                                    cw=cw, patch_size=ps,stride=stride, \n",
    "                                                                    model_dim=dim, \n",
    "                                                                    load_epoch = load_ep, ft_epoch=ft_ep)\n",
    "                                \n",
    "                                try:\n",
    "                                    if len(mse_result)>0:\n",
    "                                        data_sum += np.sum(mse_result)\n",
    "                                        data_results.append([data, np.mean(mse_result).round(3), mse_result])\n",
    "                                        data_results2.append([data, np.mean(mae_result).round(3), mae_result])\n",
    "                                    \n",
    "                                except:\n",
    "                                    pass\n",
    "                                    \n",
    "                        \n",
    "                        for i,k in enumerate(data_results):\n",
    "                            if k[1]!=999:\n",
    "                                print(cw,ps,[dim,load_ep,ft_ep])\n",
    "                                print(data_results[i])\n",
    "                                print(data_results2[i])\n",
    "                            \n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 58,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "============================================================\n",
      "-------- cw=768,ps=24 ---------\n",
      "============================================================\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
      "768 24 [32, 100, 10]\n",
      "['ettm1', 0.354, [0.301691, 0.336123, 0.364024, 0.414356]]\n",
      "['ettm1', 0.381, [0.352913, 0.370975, 0.388074, 0.411698]]\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
      "768 24 [32, 100, 20]\n",
      "['ettm1', 0.352, [0.296155, 0.335201, 0.365374, 0.413146]]\n",
      "['ettm1', 0.38, [0.349286, 0.370443, 0.389778, 0.410525]]\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
      "768 24 [32, 100, 40]\n",
      "['ettm1', 0.354, [0.293562, 0.339131, 0.364947, 0.41769]]\n",
      "['ettm1', 0.383, [0.348349, 0.374079, 0.391725, 0.417212]]\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
      "768 24 [128, 100, 10]\n",
      "['ettm1', 0.359, [0.311333, 0.339484, 0.367952, 0.416292]]\n",
      "['ettm1', 0.382, [0.356124, 0.371284, 0.387207, 0.413989]]\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
      "768 24 [128, 100, 20]\n",
      "['ettm1', 0.359, [0.30436, 0.349011, 0.365377, 0.415677]]\n",
      "['ettm1', 0.382, [0.352649, 0.37767, 0.383957, 0.412384]]\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
      "768 24 [128, 100, 40]\n",
      "['ettm1', 0.356, [0.301438, 0.341813, 0.365344, 0.416967]]\n",
      "['ettm1', 0.381, [0.351786, 0.373026, 0.384096, 0.413711]]\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
      "768 24 [256, 100, 20]\n",
      "['ettm1', 0.355, [0.303594, 0.336865, 0.365956, 0.415331]]\n",
      "['ettm1', 0.379, [0.351341, 0.367194, 0.384285, 0.412081]]\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
      "768 24 [256, 100, 40]\n",
      "['ettm1', 0.357, [0.305057, 0.343502, 0.366031, 0.414423]]\n",
      "['ettm1', 0.381, [0.353223, 0.373724, 0.384691, 0.411572]]\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n"
     ]
    }
   ],
   "source": [
    "# 64-60-20\n",
    "\n",
    "\n",
    "\n",
    "#dataset_list = ['ettm2']\n",
    "dataset_list = ['ettm1']\n",
    "\n",
    "get_summary_hard(dataset_list,\n",
    "                 cw_list = [768],\n",
    "                    ps_list = [24],\n",
    "                    load_ep_list = [100]\n",
    "                    )\n",
    " #for d_model in [64,128]:\n",
    " #                       for load in [60,80,100,120,150]:\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "============================================================\n",
      "-------- cw=512,ps=12 ---------\n",
      "============================================================\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
      "512 12 [128, 100, 10]\n",
      "['etth2', 0.336, [0.268658, 0.331473, 0.361416, 0.383505]]\n",
      "['etth2', 0.384, [0.333991, 0.375249, 0.401015, 0.427193]]\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
      "512 12 [128, 100, 20]\n",
      "['etth2', 0.349, [0.27367, 0.351924, 0.362329, 0.409405]]\n",
      "['etth2', 0.393, [0.337379, 0.38576, 0.403889, 0.446943]]\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
      "512 12 [128, 100, 40]\n",
      "['etth2', 0.342, [0.270805, 0.329655, 0.360081, 0.405818]]\n",
      "['etth2', 0.389, [0.334696, 0.374513, 0.401256, 0.444147]]\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
      "512 12 [128, 100, 60]\n",
      "['etth2', 0.335, [0.268891, 0.329424, 0.356182, 0.383998]]\n",
      "['etth2', 0.383, [0.333947, 0.374491, 0.39878, 0.42649]]\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n"
     ]
    }
   ],
   "source": [
    "# 64-60-20\n",
    "\n",
    "\n",
    "\n",
    "dataset_list = ['etth2']\n",
    "#dataset_list = ['etth1']\n",
    "\n",
    "get_summary_hard(dataset_list,\n",
    "                 cw_list = [512],\n",
    "                    ps_list = [12],\n",
    "                    load_ep_list = [100]\n",
    "                    )\n",
    " #for d_model in [64,128]:\n",
    " #                       for load in [60,80,100,120,150]:\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "============================================================\n",
      "-------- cw=512,ps=18 ---------\n",
      "============================================================\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n"
     ]
    }
   ],
   "source": [
    "# 64-60-20\n",
    "\n",
    "\n",
    "\n",
    "dataset_list = ['etth2']\n",
    "#dataset_list = ['etth1']\n",
    "\n",
    "get_summary_hard(dataset_list,\n",
    "                 cw_list = [512],\n",
    "                    ps_list = [18],\n",
    "                    load_ep_list = [100]\n",
    "                    )\n",
    " #for d_model in [64,128]:\n",
    " #                       for load in [60,80,100,120,150]:\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "============================================================\n",
      "-------- cw=512,ps=24 ---------\n",
      "============================================================\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n"
     ]
    }
   ],
   "source": [
    "# 64-60-20\n",
    "\n",
    "\n",
    "\n",
    "dataset_list = ['etth2']\n",
    "#dataset_list = ['etth1']\n",
    "\n",
    "get_summary_hard(dataset_list,\n",
    "                 cw_list = [512],\n",
    "                    ps_list = [24],\n",
    "                    load_ep_list = [100]\n",
    "                    )\n",
    " #for d_model in [64,128]:\n",
    " #                       for load in [60,80,100,120,150]:\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "============================================================\n",
      "-------- cw=512,ps=12 ---------\n",
      "============================================================\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
      "--------------------------------------------------\n",
      "[32, 100, 10]\n",
      "['etth1', 999.0, [999, 999, 999, 999]]\n",
      "['etth2', 999.0, [999, 999, 999, 999]]\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
      "512 12 [128, 100, 10]\n",
      "['etth1', 0.409, [0.367004, 0.399953, 0.425716, 0.443288]]\n",
      "['etth1', 0.423, [0.392642, 0.412829, 0.429453, 0.458915]]\n",
      "512 12 [128, 100, 10]\n",
      "['etth2', 0.336, [0.268658, 0.331473, 0.361416, 0.383505]]\n",
      "['etth2', 0.384, [0.333991, 0.375249, 0.401015, 0.427193]]\n",
      "--------------------------------------------------\n",
      "[128, 100, 10]\n",
      "['etth1', 0.409, [0.367004, 0.399953, 0.425716, 0.443288]]\n",
      "['etth2', 0.336, [0.268658, 0.331473, 0.361416, 0.383505]]\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
      "512 12 [128, 100, 20]\n",
      "['etth1', 0.412, [0.369833, 0.401135, 0.426986, 0.450556]]\n",
      "['etth1', 0.426, [0.395199, 0.415043, 0.430277, 0.46351]]\n",
      "512 12 [128, 100, 20]\n",
      "['etth2', 0.349, [0.27367, 0.351924, 0.362329, 0.409405]]\n",
      "['etth2', 0.393, [0.337379, 0.38576, 0.403889, 0.446943]]\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
      "512 12 [128, 100, 40]\n",
      "['etth1', 0.41, [0.367628, 0.398997, 0.424761, 0.448718]]\n",
      "['etth1', 0.424, [0.392802, 0.411693, 0.428869, 0.46172]]\n",
      "512 12 [128, 100, 40]\n",
      "['etth2', 0.342, [0.270805, 0.329655, 0.360081, 0.405818]]\n",
      "['etth2', 0.389, [0.334696, 0.374513, 0.401256, 0.444147]]\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
      "512 12 [128, 100, 60]\n",
      "['etth1', 0.413, [0.368143, 0.40163, 0.431294, 0.452574]]\n",
      "['etth1', 0.426, [0.393255, 0.413448, 0.432959, 0.464493]]\n",
      "512 12 [128, 100, 60]\n",
      "['etth2', 0.335, [0.268891, 0.329424, 0.356182, 0.383998]]\n",
      "['etth2', 0.383, [0.333947, 0.374491, 0.39878, 0.42649]]\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
      "--------------------------------------------------\n",
      "[128, 100, 80]\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n"
     ]
    }
   ],
   "source": [
    "# 64-60-20\n",
    "\n",
    "\n",
    "\n",
    "dataset_list = ['etth1','etth2']\n",
    "#dataset_list = ['etth1']\n",
    "\n",
    "get_summary_hard(dataset_list,\n",
    "                 cw_list = [512],\n",
    "                    ps_list = [12],\n",
    "                    load_ep_list = [100]\n",
    "                    )\n",
    " #for d_model in [64,128]:\n",
    " #                       for load in [60,80,100,120,150]:\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "============================================================\n",
      "-------- cw=512,ps=18 ---------\n",
      "============================================================\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
      "--------------------------------------------------\n",
      "[32, 100, 10]\n",
      "['etth1', 999.0, [999, 999, 999, 999]]\n",
      "['etth2', 999.0, [999, 999, 999, 999]]\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
      "--------------------------------------------------\n",
      "[64, 100, 10]\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
      "512 18 [64, 100, 20]\n",
      "['etth1', 0.41, [0.365218, 0.400995, 0.423332, 0.450563]]\n",
      "['etth1', 0.424, [0.391431, 0.413148, 0.427653, 0.462756]]\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
      "512 18 [64, 100, 40]\n",
      "['etth1', 0.409, [0.365598, 0.402412, 0.423375, 0.445605]]\n",
      "['etth1', 0.423, [0.391206, 0.414954, 0.427069, 0.46031]]\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
      "512 18 [128, 100, 20]\n",
      "['etth1', 0.409, [0.36793, 0.399297, 0.427911, 0.441933]]\n",
      "['etth1', 0.423, [0.393089, 0.411932, 0.430118, 0.458442]]\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
      "512 18 [128, 100, 40]\n",
      "['etth1', 0.41, [0.366477, 0.402587, 0.427737, 0.442802]]\n",
      "['etth1', 0.424, [0.391574, 0.414365, 0.430256, 0.45862]]\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n"
     ]
    }
   ],
   "source": [
    "# 64-60-20\n",
    "\n",
    "\n",
    "\n",
    "dataset_list = ['etth1','etth2']\n",
    "#dataset_list = ['etth1']\n",
    "\n",
    "get_summary_hard(dataset_list,\n",
    "                 cw_list = [512],\n",
    "                    ps_list = [18],\n",
    "                    load_ep_list = [100]\n",
    "                    )\n",
    " #for d_model in [64,128]:\n",
    " #                       for load in [60,80,100,120,150]:\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "============================================================\n",
      "-------- cw=768,ps=12 ---------\n",
      "============================================================\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
      "--------------------------------------------------\n",
      "[32, 100, 10]\n",
      "['ettm1', 999.0, [999, 999, 999, 999]]\n",
      "['ettm2', 999.0, [999, 999, 999, 999]]\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
      "============================================================\n",
      "-------- cw=768,ps=18 ---------\n",
      "============================================================\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
      "--------------------------------------------------\n",
      "[32, 100, 10]\n",
      "['ettm1', 999.0, [999, 999, 999, 999]]\n",
      "['ettm2', 999.0, [999, 999, 999, 999]]\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
      "768 18 [128, 100, 10]\n",
      "['ettm1', 0.355, [0.305161, 0.336787, 0.364896, 0.413482]]\n",
      "['ettm1', 0.379, [0.352607, 0.367616, 0.383543, 0.410637]]\n",
      "--------------------------------------------------\n",
      "[128, 100, 10]\n",
      "['ettm1', 0.355, [0.305161, 0.336787, 0.364896, 0.413482]]\n",
      "['ettm2', 999.0, [999, 999, 999, 999]]\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
      "768 18 [128, 100, 20]\n",
      "['ettm1', 0.356, [0.305438, 0.340933, 0.365049, 0.414066]]\n",
      "['ettm1', 0.38, [0.353531, 0.372629, 0.384395, 0.410978]]\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
      "768 18 [128, 100, 40]\n",
      "['ettm1', 0.357, [0.301221, 0.341196, 0.370752, 0.416189]]\n",
      "['ettm1', 0.382, [0.353436, 0.372335, 0.389673, 0.413392]]\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
      "768 18 [128, 100, 60]\n",
      "['ettm1', 0.357, [0.303194, 0.340484, 0.369642, 0.41494]]\n",
      "['ettm1', 0.382, [0.353322, 0.372772, 0.391113, 0.411955]]\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
      "--------------------------------------------------\n",
      "[128, 100, 80]\n",
      "['ettm2', 999.0, [999, 999, 999, 999]]\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
      "============================================================\n",
      "-------- cw=768,ps=24 ---------\n",
      "============================================================\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
      "--------------------------------------------------\n",
      "[32, 100, 10]\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
      "768 24 [32, 100, 20]\n",
      "['ettm1', 0.352, [0.296155, 0.335201, 0.365374, 0.413146]]\n",
      "['ettm1', 0.38, [0.349286, 0.370443, 0.389778, 0.410525]]\n",
      "768 24 [32, 100, 20]\n",
      "['ettm2', 0.247, [0.16001, 0.215011, 0.266477, 0.346457]]\n",
      "['ettm2', 0.311, [0.250721, 0.290827, 0.324929, 0.377749]]\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
      "768 24 [32, 100, 40]\n",
      "['ettm1', 0.354, [0.293562, 0.339131, 0.364947, 0.41769]]\n",
      "['ettm1', 0.383, [0.348349, 0.374079, 0.391725, 0.417212]]\n",
      "768 24 [32, 100, 40]\n",
      "['ettm2', 0.247, [0.160448, 0.214805, 0.266347, 0.34611]]\n",
      "['ettm2', 0.311, [0.251357, 0.290989, 0.324536, 0.377417]]\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
      "768 24 [64, 100, 20]\n",
      "['ettm2', 0.249, [0.160826, 0.214884, 0.266371, 0.354104]]\n",
      "['ettm2', 0.311, [0.251565, 0.28992, 0.32492, 0.377353]]\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
      "768 24 [64, 100, 40]\n",
      "['ettm2', 0.247, [0.160623, 0.214752, 0.265673, 0.347104]]\n",
      "['ettm2', 0.31, [0.251437, 0.289801, 0.322871, 0.377378]]\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
      "768 24 [128, 100, 20]\n",
      "['ettm1', 0.359, [0.30436, 0.349011, 0.365377, 0.415677]]\n",
      "['ettm1', 0.382, [0.352649, 0.37767, 0.383957, 0.412384]]\n",
      "768 24 [128, 100, 20]\n",
      "['ettm2', 0.248, [0.161379, 0.214941, 0.265864, 0.348564]]\n",
      "['ettm2', 0.311, [0.252051, 0.289749, 0.324662, 0.378198]]\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
      "768 24 [128, 100, 40]\n",
      "['ettm1', 0.356, [0.301438, 0.341813, 0.365344, 0.416967]]\n",
      "['ettm1', 0.381, [0.351786, 0.373026, 0.384096, 0.413711]]\n",
      "768 24 [128, 100, 40]\n",
      "['ettm2', 0.248, [0.160688, 0.214998, 0.266646, 0.348311]]\n",
      "['ettm2', 0.311, [0.250993, 0.290018, 0.324472, 0.37688]]\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
      "768 24 [256, 100, 20]\n",
      "['ettm1', 0.355, [0.303594, 0.336865, 0.365956, 0.415331]]\n",
      "['ettm1', 0.379, [0.351341, 0.367194, 0.384285, 0.412081]]\n",
      "768 24 [256, 100, 20]\n",
      "['ettm2', 0.247, [0.160855, 0.215184, 0.266942, 0.345996]]\n",
      "['ettm2', 0.311, [0.251564, 0.2901, 0.325314, 0.377305]]\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
      "768 24 [256, 100, 40]\n",
      "['ettm1', 0.357, [0.305057, 0.343502, 0.366031, 0.414423]]\n",
      "['ettm1', 0.381, [0.353223, 0.373724, 0.384691, 0.411572]]\n",
      "768 24 [256, 100, 40]\n",
      "['ettm2', 0.249, [0.160786, 0.21506, 0.266249, 0.354045]]\n",
      "['ettm2', 0.311, [0.251145, 0.290124, 0.324012, 0.378902]]\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n",
      "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n"
     ]
    }
   ],
   "source": [
    "# 64-60-20\n",
    "\n",
    "\n",
    "\n",
    "dataset_list = ['ettm1','ettm2']\n",
    "#dataset_list = ['etth1']\n",
    "\n",
    "get_summary_hard(dataset_list,\n",
    "                 cw_list = [768],\n",
    "                    ps_list = [12,18,24],\n",
    "                    load_ep_list = [100]\n",
    "                    )\n",
    " #for d_model in [64,128]:\n",
    " #                       for load in [60,80,100,120,150]:\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "============================================================\n",
      "-------- cw=512,ps=12 ---------\n",
      "============================================================\n",
      "--------------------------------------------------\n",
      "[64, 100, 20]\n",
      "['etth1', 999.0, [999, 999, 999, 999]]\n",
      "['etth2', 999.0, [999, 999, 999, 999]]\n",
      "--------------------------------------------------\n",
      "[128, 100, 20]\n",
      "['etth1', 0.412, [0.369833, 0.401135, 0.426986, 0.450556]]\n",
      "['etth2', 0.349, [0.27367, 0.351924, 0.362329, 0.409405]]\n",
      "--------------------------------------------------\n",
      "[128, 100, 40]\n",
      "['etth1', 0.41, [0.367628, 0.398997, 0.424761, 0.448718]]\n",
      "['etth2', 0.342, [0.270805, 0.329655, 0.360081, 0.405818]]\n",
      "============================================================\n",
      "-------- cw=512,ps=12 ---------\n",
      "============================================================\n",
      "--------------------------------------------------\n",
      "[64, 100, 10]\n",
      "['etth1', 999.0, [999, 999, 999, 999]]\n",
      "['etth2', 999.0, [999, 999, 999, 999]]\n",
      "--------------------------------------------------\n",
      "[128, 100, 10]\n",
      "['etth1', 0.409, [0.367004, 0.399953, 0.425716, 0.443288]]\n",
      "['etth2', 0.336, [0.268658, 0.331473, 0.361416, 0.383505]]\n"
     ]
    }
   ],
   "source": [
    "dataset_list = ['etth1','etth2']\n",
    "\n",
    "get_summary_hard(dataset_list,\n",
    "                 cw_list = [512],\n",
    "                 ps_list = [12],\n",
    "                 load_ep_list= [100]\n",
    "                 )\n",
    "\n",
    "get_summary_hard(dataset_list,\n",
    "                 cw_list = [512],\n",
    "                 ps_list = [12],\n",
    "                 load_ep_list= [100],\n",
    "                 ft_ep_list = [10,20,40,60]\n",
    "                 )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "============================================================\n",
      "-------- cw=512,ps=12 ---------\n",
      "============================================================\n",
      "--------------------------------------------------\n",
      "[64, 100, 10]\n",
      "['etth1', 999.0, [999, 999, 999, 999]]\n",
      "['etth2', 999.0, [999, 999, 999, 999]]\n",
      "--------------------------------------------------\n",
      "[128, 100, 10]\n",
      "['etth1', 0.409, [0.367004, 0.399953, 0.425716, 0.443288]]\n",
      "['etth2', 0.336, [0.268658, 0.331473, 0.361416, 0.383505]]\n"
     ]
    }
   ],
   "source": [
    "# 결론\n",
    "\n",
    "get_summary_hard(dataset_list,\n",
    "                 cw_list = [512],\n",
    "                 ps_list = [12],\n",
    "                 load_ep_list= [100],\n",
    "                 ft_ep_list = [10]\n",
    "                 )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "============================================================\n",
      "-------- cw=768,ps=24 ---------\n",
      "============================================================\n",
      "--------------------------------------------------\n",
      "[32, 40, 20]\n",
      "['ettm1', 0.353, [0.296881, 0.335654, 0.366059, 0.414436]]\n",
      "['ettm2', 0.247, [0.15993, 0.214827, 0.266328, 0.346441]]\n",
      "--------------------------------------------------\n",
      "[32, 40, 40]\n",
      "['ettm1', 0.353, [0.294341, 0.338516, 0.364529, 0.413831]]\n",
      "['ettm2', 0.247, [0.1604, 0.21458, 0.266214, 0.346066]]\n",
      "--------------------------------------------------\n",
      "[32, 80, 20]\n",
      "['ettm1', 0.353, [0.296547, 0.335128, 0.364217, 0.414499]]\n",
      "['ettm2', 0.247, [0.159995, 0.215045, 0.266283, 0.346428]]\n",
      "--------------------------------------------------\n",
      "[32, 100, 20]\n",
      "['ettm1', 0.352, [0.296155, 0.335201, 0.365374, 0.413146]]\n",
      "['ettm2', 0.247, [0.16001, 0.215011, 0.266477, 0.346457]]\n",
      "--------------------------------------------------\n",
      "[32, 120, 20]\n",
      "['ettm1', 0.352, [0.296041, 0.334922, 0.365546, 0.413186]]\n",
      "['ettm2', 0.247, [0.159988, 0.21503, 0.266379, 0.346454]]\n"
     ]
    }
   ],
   "source": [
    "dataset_list = ['ettm1','ettm2']\n",
    "\n",
    "get_summary_hard(dataset_list,\n",
    "                 cw_list = [768],\n",
    "                    ps_list = [24])\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "============================================================\n",
      "-------- cw=768,ps=18 ---------\n",
      "============================================================\n",
      "--------------------------------------------------\n",
      "[64, 100, 20]\n",
      "['ettm1', 999.0, [999, 999, 999, 999]]\n",
      "--------------------------------------------------\n",
      "[128, 100, 20]\n",
      "['ettm1', 0.356, [0.305438, 0.340933, 0.365049, 0.414066]]\n",
      "============================================================\n",
      "-------- cw=768,ps=18 ---------\n",
      "============================================================\n",
      "--------------------------------------------------\n",
      "[64, 100, 10]\n",
      "['ettm1', 999.0, [999, 999, 999, 999]]\n",
      "--------------------------------------------------\n",
      "[128, 100, 10]\n",
      "['ettm1', 0.355, [0.305161, 0.336787, 0.364896, 0.413482]]\n",
      "============================================================\n",
      "-------- cw=1024,ps=24 ---------\n",
      "============================================================\n",
      "--------------------------------------------------\n",
      "[64, 100, 20]\n",
      "['ettm2', 999.0, [999, 999, 999, 999]]\n",
      "--------------------------------------------------\n",
      "[128, 100, 20]\n",
      "['ettm2', 0.245, [0.160869, 0.214557, 0.262767, 0.343502]]\n",
      "--------------------------------------------------\n",
      "[128, 100, 40]\n",
      "['ettm2', 0.244, [0.160637, 0.216156, 0.262441, 0.338674]]\n",
      "============================================================\n",
      "-------- cw=1024,ps=24 ---------\n",
      "============================================================\n",
      "--------------------------------------------------\n",
      "[64, 100, 10]\n",
      "['ettm2', 999.0, [999, 999, 999, 999]]\n",
      "--------------------------------------------------\n",
      "[128, 100, 10]\n",
      "['ettm2', 0.244, [0.160251, 0.21331, 0.26441, 0.337841]]\n"
     ]
    }
   ],
   "source": [
    "dataset_list = ['ettm1']\n",
    "\n",
    "get_summary_hard(dataset_list,\n",
    "                 cw_list = [768],\n",
    "                 ps_list = [18],\n",
    "                 load_ep_list= [100]\n",
    "                 )\n",
    "\n",
    "get_summary_hard(dataset_list,\n",
    "                 cw_list = [768],\n",
    "                 ps_list = [18],\n",
    "                 load_ep_list= [100],\n",
    "                 ft_ep_list = [10,20,40,60]\n",
    "                 )\n",
    "\n",
    "dataset_list = ['ettm2']\n",
    "\n",
    "get_summary_hard(dataset_list,\n",
    "                 cw_list = [1024],\n",
    "                 ps_list = [24],\n",
    "                 load_ep_list= [100]\n",
    "                 )\n",
    "\n",
    "get_summary_hard(dataset_list,\n",
    "                 cw_list = [1024],\n",
    "                 ps_list = [24],\n",
    "                 load_ep_list= [100],\n",
    "                 ft_ep_list = [10,20,40,60]\n",
    "                 )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "ssl_ts",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.11"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
