{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import numpy as np\n",
    "import pandas as pd"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_mse_mae1(PATH, settings_list, cw,patch_size, model_dim, load_epoch, ft_epoch):\n",
    "    y = [x for x in settings_list if f'z' not in x]\n",
    "    y = [x for x in y if f'patch{patch_size}' in x]\n",
    "    y = [x for x in y if f'cw{cw}' in x]\n",
    "    y = [x for x in y if f'_D{model_dim}' in x]\n",
    "    try:\n",
    "        arch = os.path.join(PATH,y[0])\n",
    "    except:\n",
    "        return [999,999,999,999], [999,999,999,999]\n",
    "    #print(arch)\n",
    "    files = os.listdir(arch)\n",
    "    if len(files)==0:\n",
    "        return [999,999,999,999], [999,999,999,999]\n",
    "    files = [x for x in files if 'acc' in x]\n",
    "    files = [x for x in files if f'load_ep{load_epoch}' in x]\n",
    "    files = [x for x in files if f'ft_ep{ft_epoch}' in x]\n",
    "    idx_order = [int(x.split('_')[0].split('tw')[1]) for x in files]\n",
    "    files =  list(np.array(files)[np.argsort(idx_order)])\n",
    "    files = [os.path.join(x) for x in files]\n",
    "    files = [os.path.join(arch, f) for f in files]\n",
    "    mse_result = []\n",
    "    mae_result = []\n",
    "    for f in files:\n",
    "        try:\n",
    "            df = pd.read_csv(f)\n",
    "            mse_result.append(df['mse'][0])\n",
    "            mae_result.append(df['mae'][0])\n",
    "        except:\n",
    "            mse_result.append(999)\n",
    "            mae_result.append(999)\n",
    "    return mse_result, mae_result"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_mse_mae2(PATH, settings_list, cw,patch_size, model_dim, tau_temp, load_epoch, ft_epoch):\n",
    "    y = [x for x in settings_list if f'z' not in x]\n",
    "    y = [x for x in y if f'cw{cw}' in x]\n",
    "    y = [x for x in y if f'patch{patch_size}' in x]\n",
    "    y = [x for x in y if f'_D{model_dim}' in x]\n",
    "    y = [x for x in y if f'tau_temp{tau_temp}' in x]\n",
    "    try:\n",
    "        arch = os.path.join(PATH,y[0])\n",
    "    except:\n",
    "        return [999,999,999,999], [999,999,999,999]\n",
    "    #print(arch)\n",
    "    files = os.listdir(arch)\n",
    "    if len(files)==0:\n",
    "        return [999,999,999,999], [999,999,999,999]\n",
    "    files = [x for x in files if 'acc' in x]\n",
    "    files = [x for x in files if f'load_ep{load_epoch}' in x]\n",
    "    files = [x for x in files if f'ft_ep{ft_epoch}' in x]\n",
    "    idx_order = [int(x.split('_')[0].split('tw')[1]) for x in files]\n",
    "    files =  list(np.array(files)[np.argsort(idx_order)])\n",
    "    files = [os.path.join(x) for x in files]\n",
    "    files = [os.path.join(arch, f) for f in files]\n",
    "    mse_result = []\n",
    "    mae_result = []\n",
    "    for f in files:\n",
    "        try:\n",
    "            df = pd.read_csv(f)\n",
    "            mse_result.append(df['mse'][0])\n",
    "            mae_result.append(df['mae'][0])\n",
    "        except:\n",
    "            mse_result.append(999)\n",
    "            mae_result.append(999)\n",
    "    return mse_result, mae_result"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_mse_mae3(PATH, settings_list, cw,patch_size, model_dim, tau_temp,tau_inst, load_epoch, ft_epoch):\n",
    "    y = [x for x in settings_list if f'z' not in x]\n",
    "    y = [x for x in y if f'cw{cw}' in x]\n",
    "    y = [x for x in y if f'patch{patch_size}' in x]\n",
    "    y = [x for x in y if f'_D{model_dim}' in x]\n",
    "    y = [x for x in y if f'tau_temp{tau_temp}' in x]\n",
    "    y = [x for x in y if f'tau_inst{tau_inst}' in x]\n",
    "    try:\n",
    "        arch = os.path.join(PATH,y[0])\n",
    "    except:\n",
    "        return [999,999,999,999], [999,999,999,999]\n",
    "    #print(arch)\n",
    "    files = os.listdir(arch)\n",
    "    if len(files)==0:\n",
    "        return [999,999,999,999], [999,999,999,999]\n",
    "    files = [x for x in files if 'acc' in x]\n",
    "    files = [x for x in files if f'load_ep{load_epoch}' in x]\n",
    "    files = [x for x in files if f'ft_ep{ft_epoch}' in x]\n",
    "    idx_order = [int(x.split('_')[0].split('tw')[1]) for x in files]\n",
    "    files =  list(np.array(files)[np.argsort(idx_order)])\n",
    "    files = [os.path.join(x) for x in files]\n",
    "    files = [os.path.join(arch, f) for f in files]\n",
    "    mse_result = []\n",
    "    mae_result = []\n",
    "    for f in files:\n",
    "        try:\n",
    "            df = pd.read_csv(f)\n",
    "            mse_result.append(df['mse'][0])\n",
    "            mae_result.append(df['mae'][0])\n",
    "        except:\n",
    "            mse_result.append(999)\n",
    "            mae_result.append(999)\n",
    "    return mse_result, mae_result"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_summary(dataset_list, type_='hard'):\n",
    "    for data in dataset_list:\n",
    "        print('='*50)\n",
    "        PATH1 = '/home/seunghan9613/PatchTST_sim/PatchTST_self_supervised/saved_models/'\n",
    "        PATH2 = f'{data}2{data}'\n",
    "        PATH3 = f'masked_patchtst_sim_half_v3_mean_FC2_R/based_model/max'\n",
    "        PATH = os.path.join(PATH1, PATH2, PATH3)\n",
    "        settings = os.listdir(PATH)\n",
    "        \n",
    "        for share in [1]:\n",
    "            if share==0:\n",
    "                settings = [x for x in settings if 'no_share' in x]\n",
    "            else:\n",
    "                settings = [x for x in settings if 'no_share' not  in x]\n",
    "            \n",
    "            if type_ == 'hard':\n",
    "                settings = [x for x in settings if 'tau' not in x]\n",
    "            elif type_ == 'soft1':\n",
    "                settings = [x for x in settings if 'tau' in x]\n",
    "                settings = [x for x in settings if 'tau_inst' not in x]\n",
    "            elif type_ == 'soft2':\n",
    "                settings = [x for x in settings if 'tau' in x]\n",
    "                settings = [x for x in settings if 'tau_inst' in x]\n",
    "                \n",
    "            if type_ == 'hard':\n",
    "                for cw in [336,512,768]:\n",
    "                    for ps in [12,18,24]:\n",
    "                        for dim in [32,64,128,256]:\n",
    "                            for load_ep in [20,40,60,80,100,120,150]:\n",
    "                                for ft_ep in [20,40]:\n",
    "                                    mse_result, mae_result = get_mse_mae1(PATH = PATH, settings_list = settings,\n",
    "                                                                        cw = cw, patch_size=ps, \n",
    "                                                                        model_dim=dim, \n",
    "                                                                        load_epoch = load_ep, ft_epoch=ft_ep)\n",
    "                                    try:\n",
    "                                        if (mse_result[0] != 999):\n",
    "                                        #if (mse_result[0] != 999) & (len(mse_result)==4):\n",
    "                                            #print([ps,dim,load_ep,ft_ep],mse_result, mae_result)\n",
    "                                            print([cw,ps,dim,load_ep,ft_ep],np.mean(mse_result).round(3), mse_result)\n",
    "                                    except:\n",
    "                                        pass\n",
    "                \n",
    "            elif type_ == 'soft1':\n",
    "                for cw in [336,512,768]:\n",
    "                    for ps in [12,18,24]:\n",
    "                        print(f'-------- cw={cw},ps={ps} ---------')\n",
    "                        min_val = 999\n",
    "                        for dim in [32,64,128,256]:\n",
    "                            for tau1 in [1,3,5]:\n",
    "                                for load_ep in [20,40,60,80,100,120,150]:\n",
    "                                    for ft_ep in [20,40]:\n",
    "                                        mse_result, mae_result = get_mse_mae2(PATH = PATH, settings_list = settings,\n",
    "                                                                            cw=cw, patch_size=ps, \n",
    "                                                                            model_dim=dim, tau_temp=tau1,\n",
    "                                                                            load_epoch = load_ep, ft_epoch=ft_ep)\n",
    "                                        try:\n",
    "                                            if (mse_result[0] != 999):\n",
    "                                                if np.sum(mse_result)<min_val:\n",
    "                                                    print([cw,ps,dim,tau1,load_ep,ft_ep],np.mean(mse_result).round(3),mse_result)\n",
    "                                                    min_val = np.sum(mse_result)\n",
    "                                        except:\n",
    "                                            pass\n",
    "                \n",
    "            elif type_ == 'soft2':\n",
    "                for cw in [336,512,768]:\n",
    "                        \n",
    "                    for ps in [12,18,24]:\n",
    "                        for dim in [32,64,128,256]:\n",
    "                            for tau1 in [1,3,5]:\n",
    "                                for tau2 in [1,3,5]:\n",
    "                                    for load_ep in [20,40,60,80,100,120,150]:\n",
    "                                        for ft_ep in [20,40]:\n",
    "                                            mse_result, mae_result = get_mse_mae3(PATH = PATH,settings_list = settings,\n",
    "                                                                                cw=cw,patch_size=ps, \n",
    "                                                                                model_dim=dim, \n",
    "                                                                                tau_temp=tau1, tau_inst=tau2,\n",
    "                                                                                load_epoch = load_ep, ft_epoch=ft_ep)\n",
    "                                            try:\n",
    "                                                if (mse_result[0] != 999) & (len(mse_result)==4):\n",
    "                                                    #print([ps,dim,tau1,tau2,load_ep,ft_ep],mse_result, mae_result)\n",
    "                                                    print([cw,ps,dim,tau1,tau2,load_ep,ft_ep],np.mean(mse_result).round(3),mse_result)\n",
    "                                            except:\n",
    "                                                pass\n",
    "                            "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_summary_soft1(dataset_list, type_='soft1'):\n",
    "    for cw in [336,512,768]:\n",
    "        for ps in [12,18,24]:\n",
    "            print(f'-------- cw={cw},ps={ps} ---------')\n",
    "            for dim in [32,64,128,256]:\n",
    "                for tau1 in [1,3,5]:\n",
    "                    for load_ep in [40,60,80,100,120,150]:\n",
    "                        for ft_ep in [20,40]:\n",
    "                            print('='*40)\n",
    "                            print([cw,ps,dim,tau1,load_ep,ft_ep])\n",
    "                            for data in dataset_list:\n",
    "                                #print(f'============== {data} ==============')\n",
    "                                min_val = 999\n",
    "                                PATH1 = '/home/seunghan9613/PatchTST_sim/PatchTST_self_supervised/saved_models/'\n",
    "                                PATH2 = f'{data}2{data}'\n",
    "                                PATH3 = f'masked_patchtst_sim_half_v3_mean_FC2_R/based_model/max'\n",
    "                                PATH = os.path.join(PATH1, PATH2, PATH3)\n",
    "                                try:\n",
    "                                    settings = os.listdir(PATH)\n",
    "                                    \n",
    "                                    for share in [1]:\n",
    "                                        settings = [x for x in settings if 'no_share' not  in x]\n",
    "                                        settings = [x for x in settings if 'tau' in x]\n",
    "                                        settings = [x for x in settings if 'tau_inst' not in x]\n",
    "        \n",
    "                                        mse_result, mae_result = get_mse_mae2(PATH = PATH, settings_list = settings,\n",
    "                                                                            cw=cw, patch_size=ps, \n",
    "                                                                            model_dim=dim, tau_temp=tau1,\n",
    "                                                                            load_epoch = load_ep, ft_epoch=ft_ep)\n",
    "                                        try:\n",
    "                                            if len(mse_result)>0:\n",
    "                                                print(data, np.mean(mse_result).round(3), mse_result)\n",
    "                                            \n",
    "                                        except:\n",
    "                                            pass\n",
    "                                except:\n",
    "                                    pass\n",
    "                                "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_summary_hard(dataset_list, cw_list = [336,512,768],\n",
    "                      ps_list = [12,18,24], \n",
    "                      load_ep_list = [40,60,80,100,120,150],\n",
    "                      ft_ep_list = [20,40],\n",
    "                      type_='soft1'):\n",
    "    \n",
    "    \n",
    "    for cw in cw_list:\n",
    "        for ps in ps_list:\n",
    "            min_val = 999999\n",
    "            print('='*60)\n",
    "            print(f'-------- cw={cw},ps={ps} ---------')\n",
    "            print('='*60)\n",
    "            for dim in [32,64,128,256]:\n",
    "                for load_ep in load_ep_list:\n",
    "                    for ft_ep in ft_ep_list:\n",
    "                        #print('='*40)\n",
    "                        #print([cw,ps,dim,tau1,load_ep,ft_ep])\n",
    "                        data_sum = 0\n",
    "                        data_results = []\n",
    "                        for data in dataset_list:\n",
    "                            #print(f'============== {data} ==============')\n",
    "                            PATH1 = '/home/seunghan9613/PatchTST_sim/PatchTST_self_supervised/saved_models/'\n",
    "                            PATH2 = f'{data}2{data}'\n",
    "                            PATH3 = f'masked_patchtst_sim_half_v3_mean_FC2_R/based_model/max'\n",
    "                            PATH = os.path.join(PATH1, PATH2, PATH3)\n",
    "                            try:\n",
    "                                settings = os.listdir(PATH)\n",
    "                                \n",
    "                                for share in [1]:\n",
    "                                    settings = [x for x in settings if 'no_share' not  in x]\n",
    "                                    settings = [x for x in settings if 'tau' not in x]\n",
    "    \n",
    "                                    mse_result, mae_result = get_mse_mae1(PATH = PATH, settings_list = settings,\n",
    "                                                                        cw=cw, patch_size=ps, \n",
    "                                                                        model_dim=dim, \n",
    "                                                                        load_epoch = load_ep, ft_epoch=ft_ep)\n",
    "                                    \n",
    "                                    try:\n",
    "                                        if len(mse_result)>0:\n",
    "                                            #print(data, np.mean(mse_result).round(3), mse_result)\n",
    "                                            data_sum += np.sum(mse_result)\n",
    "                                            data_results.append([data, np.mean(mse_result).round(3), mse_result])\n",
    "                                        \n",
    "                                    except:\n",
    "                                        pass\n",
    "                                    \n",
    "                            except:\n",
    "                                pass\n",
    "                        \n",
    "                        #for k in data_results:\n",
    "                            #print(k)\n",
    "                            \n",
    "                        if min_val > data_sum:\n",
    "                            min_val = data_sum\n",
    "                            print('-'*50)\n",
    "                            print([dim,load_ep,ft_ep])\n",
    "                            for k in data_results:\n",
    "                                print(k)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_summary_soft2(dataset_list, cw_list = [336,512,768],\n",
    "                      ps_list = [12,18,24], \n",
    "                      load_ep_list = [40,60,80,100,120,150],\n",
    "                      ft_ep_list = [20,40],\n",
    "                      type_='soft1'):\n",
    "    \n",
    "    \n",
    "    for cw in cw_list:\n",
    "        for ps in ps_list:\n",
    "            min_val = 999999\n",
    "            print('='*60)\n",
    "            print(f'-------- cw={cw},ps={ps} ---------')\n",
    "            print('='*60)\n",
    "            for dim in [32,64,128,256]:\n",
    "                for tau1 in [1,3,5]:\n",
    "                    for load_ep in load_ep_list:\n",
    "                        for ft_ep in ft_ep_list:\n",
    "                            #print('='*40)\n",
    "                            #print([cw,ps,dim,tau1,load_ep,ft_ep])\n",
    "                            data_sum = 0\n",
    "                            data_results = []\n",
    "                            for data in dataset_list:\n",
    "                                #print(f'============== {data} ==============')\n",
    "                                PATH1 = '/home/seunghan9613/PatchTST_sim/PatchTST_self_supervised/saved_models/'\n",
    "                                PATH2 = f'{data}2{data}'\n",
    "                                PATH3 = f'masked_patchtst_sim_half_v3_mean_FC2_R/based_model/max'\n",
    "                                PATH = os.path.join(PATH1, PATH2, PATH3)\n",
    "                                try:\n",
    "                                    settings = os.listdir(PATH)\n",
    "                                    \n",
    "                                    for share in [1]:\n",
    "                                        settings = [x for x in settings if 'no_share' not  in x]\n",
    "                                        settings = [x for x in settings if 'tau' in x]\n",
    "                                        settings = [x for x in settings if 'tau_inst' not in x]\n",
    "        \n",
    "                                        mse_result, mae_result = get_mse_mae2(PATH = PATH, settings_list = settings,\n",
    "                                                                            cw=cw, patch_size=ps, \n",
    "                                                                            model_dim=dim, tau_temp=tau1,\n",
    "                                                                            load_epoch = load_ep, ft_epoch=ft_ep)\n",
    "                                        \n",
    "                                        try:\n",
    "                                            if len(mse_result)>0:\n",
    "                                                #print(data, np.mean(mse_result).round(3), mse_result)\n",
    "                                                data_sum += np.sum(mse_result)\n",
    "                                                data_results.append([data, np.mean(mse_result).round(3), mse_result])\n",
    "                                            \n",
    "                                        except:\n",
    "                                            pass\n",
    "                                        \n",
    "                                except:\n",
    "                                    pass\n",
    "                            \n",
    "                            #for k in data_results:\n",
    "                            #    print(k)\n",
    "                                \n",
    "                            if min_val > data_sum:\n",
    "                                min_val = data_sum\n",
    "                                print('-'*50)\n",
    "                                print([dim,tau1,load_ep,ft_ep])\n",
    "                                for k in data_results:\n",
    "                                    print(k)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "============================================================\n",
      "-------- cw=512,ps=12 ---------\n",
      "============================================================\n",
      "--------------------------------------------------\n",
      "[32, 100, 20]\n",
      "['etth1', 999.0, [999, 999, 999, 999]]\n",
      "['etth2', 999.0, [999, 999, 999, 999]]\n",
      "--------------------------------------------------\n",
      "[128, 100, 20]\n",
      "['etth1', 0.415, [0.371676, 0.402808, 0.422699, 0.462367]]\n",
      "['etth2', 0.36, [0.289003, 0.354964, 0.382737, 0.412405]]\n",
      "============================================================\n",
      "-------- cw=512,ps=12 ---------\n",
      "============================================================\n",
      "--------------------------------------------------\n",
      "[32, 100, 10]\n",
      "['etth1', 999.0, [999, 999, 999, 999]]\n",
      "['etth2', 999.0, [999, 999, 999, 999]]\n",
      "--------------------------------------------------\n",
      "[128, 100, 10]\n",
      "['etth1', 0.417, [0.37242, 0.409006, 0.422224, 0.464799]]\n",
      "['etth2', 0.352, [0.306646, 0.362051, 0.387208]]\n"
     ]
    }
   ],
   "source": [
    "dataset_list = ['etth1','etth2']\n",
    "\n",
    "get_summary_hard(dataset_list,\n",
    "                 cw_list = [512],\n",
    "                 ps_list = [12],\n",
    "                 load_ep_list= [100]\n",
    "                 )\n",
    "\n",
    "get_summary_hard(dataset_list,\n",
    "                 cw_list = [512],\n",
    "                 ps_list = [12],\n",
    "                 load_ep_list= [100],\n",
    "                 ft_ep_list = [10,20,40,60]\n",
    "                 )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "============================================================\n",
      "-------- cw=512,ps=12 ---------\n",
      "============================================================\n",
      "--------------------------------------------------\n",
      "[32, 100, 10]\n",
      "['etth1', 999.0, [999, 999, 999, 999]]\n",
      "['etth2', 999.0, [999, 999, 999, 999]]\n",
      "--------------------------------------------------\n",
      "[128, 100, 10]\n",
      "['etth1', 0.417, [0.37242, 0.409006, 0.422224, 0.464799]]\n",
      "['etth2', 0.366, [0.306646, 0.362051, 0.387208, 0.408051]]\n"
     ]
    }
   ],
   "source": [
    "# 결론\n",
    "\n",
    "get_summary_hard(dataset_list,\n",
    "                 cw_list = [512],\n",
    "                 ps_list = [12],\n",
    "                 load_ep_list= [100],\n",
    "                 ft_ep_list = [10]\n",
    "                 )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 68,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "============================================================\n",
      "-------- cw=768,ps=24 ---------\n",
      "============================================================\n",
      "--------------------------------------------------\n",
      "[32, 80, 20]\n",
      "['ettm1', 0.352, [0.290961, 0.337703, 0.359835, 0.418035]]\n",
      "['ettm2', 0.249, [0.160919, 0.215947, 0.267152, 0.351188]]\n"
     ]
    }
   ],
   "source": [
    "# [32, 80, 20]\n",
    "dataset_list = ['ettm1','ettm2']\n",
    "\n",
    "get_summary_hard(dataset_list,\n",
    "                 cw_list = [768],ps_list = [24],\n",
    "                      load_ep_list = [80],\n",
    "                      ft_ep_list = [20]\n",
    "                 )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "============================================================\n",
      "-------- cw=768,ps=18 ---------\n",
      "============================================================\n",
      "--------------------------------------------------\n",
      "[32, 100, 20]\n",
      "['ettm1', 999.0, [999, 999, 999, 999]]\n",
      "--------------------------------------------------\n",
      "[128, 100, 20]\n",
      "['ettm1', 0.357, [0.307685, 0.340585, 0.364624, 0.415912]]\n",
      "============================================================\n",
      "-------- cw=768,ps=18 ---------\n",
      "============================================================\n",
      "--------------------------------------------------\n",
      "[32, 100, 10]\n",
      "['ettm1', 999.0, [999, 999, 999, 999]]\n",
      "--------------------------------------------------\n",
      "[128, 100, 10]\n",
      "['ettm1', 0.356, [0.302219, 0.332107, 0.363056, 0.426204]]\n",
      "============================================================\n",
      "-------- cw=1024,ps=24 ---------\n",
      "============================================================\n",
      "--------------------------------------------------\n",
      "[32, 100, 20]\n",
      "['ettm2', 999.0, [999, 999, 999, 999]]\n",
      "--------------------------------------------------\n",
      "[128, 100, 20]\n",
      "['ettm2', 0.258, [0.167318, 0.224509, 0.274439, 0.364615]]\n",
      "--------------------------------------------------\n",
      "[128, 100, 40]\n",
      "['ettm2', 0.254, [0.166652, 0.225389, 0.274167, 0.351233]]\n",
      "============================================================\n",
      "-------- cw=1024,ps=24 ---------\n",
      "============================================================\n",
      "--------------------------------------------------\n",
      "[32, 100, 10]\n",
      "['ettm2', 999.0, [999, 999, 999, 999]]\n",
      "--------------------------------------------------\n",
      "[128, 100, 10]\n",
      "['ettm2', 0.256, [0.170782, 0.225859, 0.274265, 0.353619]]\n",
      "--------------------------------------------------\n",
      "[128, 100, 40]\n",
      "['ettm2', 0.254, [0.166652, 0.225389, 0.274167, 0.351233]]\n"
     ]
    }
   ],
   "source": [
    "dataset_list = ['ettm1']\n",
    "\n",
    "get_summary_hard(dataset_list,\n",
    "                 cw_list = [768],\n",
    "                 ps_list = [18],\n",
    "                 load_ep_list= [100]\n",
    "                 )\n",
    "\n",
    "get_summary_hard(dataset_list,\n",
    "                 cw_list = [768],\n",
    "                 ps_list = [18],\n",
    "                 load_ep_list= [100],\n",
    "                 ft_ep_list = [10,20,40,60]\n",
    "                 )\n",
    "\n",
    "dataset_list = ['ettm2']\n",
    "\n",
    "get_summary_hard(dataset_list,\n",
    "                 cw_list = [1024],\n",
    "                 ps_list = [24],\n",
    "                 load_ep_list= [100]\n",
    "                 )\n",
    "\n",
    "get_summary_hard(dataset_list,\n",
    "                 cw_list = [1024],\n",
    "                 ps_list = [24],\n",
    "                 load_ep_list= [100],\n",
    "                 ft_ep_list = [10,20,40,60]\n",
    "                 )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 70,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "============================================================\n",
      "-------- cw=768,ps=24 ---------\n",
      "============================================================\n",
      "--------------------------------------------------\n",
      "[32, 1, 100, 40]\n",
      "['ettm1', 999.0, [999, 999, 999, 999]]\n",
      "['ettm2', 999.0, [999, 999, 999, 999]]\n",
      "--------------------------------------------------\n",
      "[32, 3, 100, 40]\n",
      "['ettm1', 0.355, [0.297847, 0.342519, 0.364561, 0.415406]]\n",
      "['ettm2', 999.0, [999, 999, 999, 999]]\n",
      "--------------------------------------------------\n",
      "[32, 5, 100, 40]\n",
      "['ettm1', 0.353, [0.296363, 0.336456, 0.365984, 0.412797]]\n",
      "['ettm2', 999.0, [999, 999, 999, 999]]\n",
      "--------------------------------------------------\n",
      "[64, 1, 100, 40]\n",
      "['ettm1', 0.35, [0.292794, 0.327668, 0.361679, 0.418451]]\n",
      "['ettm2', 0.252, [0.163636, 0.219648, 0.272633, 0.350361]]\n"
     ]
    }
   ],
   "source": [
    "# [64, 1, 100, 40]\n",
    "dataset_list = ['ettm1','ettm2']\n",
    "#dataset_list = ['ettm1']\n",
    "\n",
    "get_summary_soft2(dataset_list,\n",
    "                 cw_list = [768],ps_list = [24],\n",
    "                      load_ep_list = [100],\n",
    "                      ft_ep_list = [40]\n",
    "                      \n",
    "                 )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 55,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "def rename_files_with_string(directory, old_string, new_string):\n",
    "    for root, dirs, files in os.walk(directory):\n",
    "        for file in files:\n",
    "            if old_string in file:\n",
    "                old_path = os.path.join(root, file)\n",
    "                new_file = file.replace(old_string, new_string)\n",
    "                new_path = os.path.join(root, new_file)\n",
    "                os.rename(old_path, new_path)\n",
    "\n",
    "# Example usage\n",
    "directory = '/home/seunghan9613/PatchTST_sim/PatchTST_self_supervised/saved_models/ettm2/masked_patchtst_sim_half_v3_mean_FC_wo_MTM_R'\n",
    "old_string = '-pretrain150_'\n",
    "new_string = '-pretrain100_'\n",
    "\n",
    "rename_files_with_string(directory, old_string, new_string)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "============================================================\n",
      "-------- cw=336,ps=12 ---------\n",
      "============================================================\n",
      "--------------------------------------------------\n",
      "[32, 1, 40, 20]\n",
      "['ettm1', 999.0, [999, 999, 999, 999]]\n",
      "============================================================\n",
      "-------- cw=336,ps=18 ---------\n",
      "============================================================\n",
      "--------------------------------------------------\n",
      "[32, 1, 40, 20]\n",
      "['ettm1', 999.0, [999, 999, 999, 999]]\n",
      "============================================================\n",
      "-------- cw=336,ps=24 ---------\n",
      "============================================================\n",
      "--------------------------------------------------\n",
      "[32, 1, 40, 20]\n",
      "['ettm1', 999.0, [999, 999, 999, 999]]\n",
      "============================================================\n",
      "-------- cw=512,ps=12 ---------\n",
      "============================================================\n",
      "--------------------------------------------------\n",
      "[32, 1, 40, 20]\n",
      "['ettm1', 999.0, [999, 999, 999, 999]]\n",
      "============================================================\n",
      "-------- cw=512,ps=18 ---------\n",
      "============================================================\n",
      "--------------------------------------------------\n",
      "[32, 1, 40, 20]\n",
      "['ettm1', 0.357, [0.300547, 0.334937, 0.369806, 0.423647]]\n",
      "--------------------------------------------------\n",
      "[32, 1, 40, 40]\n",
      "['ettm1', 0.355, [0.296704, 0.34282, 0.367656, 0.414131]]\n",
      "--------------------------------------------------\n",
      "[32, 1, 60, 20]\n",
      "['ettm1', 0.355, [0.296817, 0.336059, 0.365133, 0.421573]]\n",
      "--------------------------------------------------\n",
      "[32, 1, 60, 40]\n",
      "['ettm1', 0.355, [0.296622, 0.344318, 0.366487, 0.411628]]\n",
      "--------------------------------------------------\n",
      "[32, 1, 120, 20]\n",
      "['ettm1', 0.354, [0.299303, 0.339006, 0.365856, 0.412505]]\n",
      "--------------------------------------------------\n",
      "[32, 3, 40, 20]\n",
      "['ettm1', 0.353, [0.297734, 0.336297, 0.363208, 0.415379]]\n",
      "--------------------------------------------------\n",
      "[32, 5, 60, 20]\n",
      "['ettm1', 0.353, [0.296598, 0.332272, 0.366775, 0.416034]]\n",
      "--------------------------------------------------\n",
      "[32, 5, 60, 40]\n",
      "['ettm1', 0.351, [0.295075, 0.333829, 0.361125, 0.415158]]\n",
      "--------------------------------------------------\n",
      "[32, 5, 150, 20]\n",
      "['ettm1', 0.372, [0.334196, 0.369228, 0.413884]]\n",
      "--------------------------------------------------\n",
      "[32, 5, 150, 40]\n",
      "['ettm1', 0.357, [0.295548, 0.363236, 0.413197]]\n",
      "--------------------------------------------------\n",
      "[64, 5, 60, 40]\n",
      "['ettm1', 0.331, [0.29516, 0.337533, 0.359445]]\n",
      "--------------------------------------------------\n",
      "[128, 1, 60, 20]\n",
      "['ettm1', 0.304, [0.303713]]\n",
      "--------------------------------------------------\n",
      "[128, 1, 60, 40]\n",
      "============================================================\n",
      "-------- cw=512,ps=24 ---------\n",
      "============================================================\n",
      "--------------------------------------------------\n",
      "[32, 1, 40, 20]\n",
      "['ettm1', 999.0, [999, 999, 999, 999]]\n",
      "============================================================\n",
      "-------- cw=768,ps=12 ---------\n",
      "============================================================\n",
      "--------------------------------------------------\n",
      "[32, 1, 40, 20]\n",
      "['ettm1', 999.0, [999, 999, 999, 999]]\n",
      "============================================================\n",
      "-------- cw=768,ps=18 ---------\n",
      "============================================================\n",
      "--------------------------------------------------\n",
      "[32, 1, 40, 20]\n",
      "['ettm1', 999.0, [999, 999, 999, 999]]\n",
      "============================================================\n",
      "-------- cw=768,ps=24 ---------\n",
      "============================================================\n",
      "--------------------------------------------------\n",
      "[32, 1, 40, 20]\n",
      "['ettm1', 999.0, [999, 999, 999, 999]]\n",
      "--------------------------------------------------\n",
      "[32, 3, 40, 20]\n",
      "['ettm1', 0.352, [0.297086, 0.330612, 0.359805, 0.419866]]\n",
      "--------------------------------------------------\n",
      "[32, 3, 150, 20]\n",
      "['ettm1', 0.297, [0.297163]]\n"
     ]
    }
   ],
   "source": [
    "dataset_list = ['ettm1','ettm2']\n",
    "\n",
    "get_summary_soft2(dataset_list)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "==================================================\n",
      "[512, 18, 32, 40, 20] 0.41 [0.36716, 0.400839, 0.419742, 0.451334]\n",
      "[512, 18, 32, 40, 40] 0.41 [0.365256, 0.404765, 0.423866, 0.447744]\n",
      "[512, 18, 32, 60, 20] 0.408 [0.364521, 0.399924, 0.417852, 0.449267]\n",
      "[512, 18, 32, 60, 40] 0.408 [0.365142, 0.400457, 0.422431, 0.445234]\n",
      "[512, 18, 32, 80, 20] 0.407 [0.366638, 0.399722, 0.417352, 0.444421]\n",
      "[512, 18, 32, 80, 40] 0.408 [0.364119, 0.39938, 0.421672, 0.445064]\n",
      "[512, 18, 32, 100, 20] 0.407 [0.365624, 0.3985, 0.419463, 0.445174]\n",
      "[512, 18, 32, 100, 40] 0.408 [0.363248, 0.40262, 0.418684, 0.446275]\n",
      "[512, 18, 32, 120, 20] 0.407 [0.365166, 0.397805, 0.418563, 0.444608]\n",
      "[512, 18, 32, 120, 40] 0.406 [0.364134, 0.398247, 0.41764, 0.445882]\n",
      "[512, 18, 32, 150, 20] 0.405 [0.364929, 0.397642, 0.414816, 0.444084]\n",
      "[512, 18, 32, 150, 40] 0.407 [0.364156, 0.39803, 0.418594, 0.445248]\n",
      "[512, 18, 64, 40, 20] 0.412 [0.368275, 0.399434, 0.424424, 0.45396]\n",
      "[512, 18, 64, 40, 40] 0.412 [0.368419, 0.409365, 0.418575, 0.453586]\n",
      "[512, 18, 64, 60, 20] 0.411 [0.366687, 0.397166, 0.423451, 0.455559]\n",
      "[512, 18, 64, 60, 40] 0.41 [0.366962, 0.403178, 0.416336, 0.453823]\n",
      "[512, 18, 64, 80, 20] 0.411 [0.3665, 0.404335, 0.418353, 0.455637]\n",
      "[512, 18, 64, 80, 40] 0.41 [0.366746, 0.405954, 0.414197, 0.451939]\n",
      "[512, 18, 64, 100, 20] 0.411 [0.36648, 0.401578, 0.418372, 0.457807]\n",
      "[512, 18, 64, 100, 40] 0.41 [0.366296, 0.407742, 0.412816, 0.451667]\n",
      "[512, 18, 64, 120, 20] 0.409 [0.366343, 0.401489, 0.417366, 0.449675]\n",
      "[512, 18, 64, 120, 40] 0.409 [0.365942, 0.406257, 0.412446, 0.451846]\n",
      "[512, 18, 64, 150, 20] 0.409 [0.366281, 0.402884, 0.417249, 0.449525]\n",
      "[512, 18, 64, 150, 40] 0.412 [0.365891, 0.406463, 0.422453, 0.452025]\n",
      "[512, 18, 128, 40, 20] 0.412 [0.36914, 0.400416, 0.415054, 0.4621]\n",
      "[512, 18, 128, 40, 40] 0.411 [0.368823, 0.398967, 0.422588, 0.452105]\n",
      "[512, 18, 128, 60, 20] 0.414 [0.370851, 0.401395, 0.416747, 0.467263]\n",
      "[512, 18, 128, 60, 40] 0.413 [0.373546, 0.3984, 0.422675, 0.456285]\n",
      "[512, 18, 128, 80, 20] 0.414 [0.372355, 0.40083, 0.412523, 0.472278]\n",
      "[512, 18, 128, 80, 40] 0.412 [0.370145, 0.397906, 0.420961, 0.458535]\n",
      "[512, 18, 128, 100, 20] 0.41 [0.372216, 0.401057, 0.409604, 0.458331]\n",
      "[512, 18, 128, 100, 40] 0.412 [0.367023, 0.397235, 0.423726, 0.458359]\n",
      "[512, 18, 128, 120, 20] 0.411 [0.371405, 0.400702, 0.409772, 0.460861]\n",
      "[512, 18, 128, 120, 40] 0.414 [0.367194, 0.397166, 0.424684, 0.465298]\n",
      "[512, 18, 128, 150, 20] 0.415 [0.371853, 0.400567, 0.427491, 0.461657]\n",
      "[512, 18, 128, 150, 40] 0.414 [0.367122, 0.396958, 0.423582, 0.466506]\n",
      "[512, 18, 256, 40, 20] 0.412 [0.370983, 0.39937, 0.418833, 0.456986]\n",
      "[512, 18, 256, 40, 40] 0.414 [0.374174, 0.399499, 0.418287, 0.463048]\n",
      "[512, 18, 256, 60, 20] 0.414 [0.372177, 0.402571, 0.417825, 0.463324]\n",
      "[512, 18, 256, 60, 40] 0.415 [0.37472, 0.40046, 0.416746, 0.469481]\n",
      "[512, 18, 256, 80, 20] 0.415 [0.374697, 0.402429, 0.420527, 0.462715]\n",
      "[512, 18, 256, 80, 40] 0.416 [0.374351, 0.400888, 0.417747, 0.472434]\n",
      "[512, 18, 256, 100, 20] 0.414 [0.376091, 0.401026, 0.417656, 0.460857]\n",
      "[512, 18, 256, 100, 40] 0.418 [0.379625, 0.400297, 0.418651, 0.473488]\n",
      "[512, 18, 256, 120, 20] 0.417 [0.386451, 0.400838, 0.416831, 0.464736]\n",
      "[512, 18, 256, 120, 40] 0.419 [0.378929, 0.400177, 0.418511, 0.477794]\n",
      "[512, 18, 256, 150, 20] 0.415 [0.385565, 0.400495, 0.416636, 0.458867]\n",
      "[512, 18, 256, 150, 40] 0.419 [0.377895, 0.400107, 0.418529, 0.480106]\n",
      "==================================================\n",
      "[512, 18, 32, 40, 20] 0.352 [0.281341, 0.359763, 0.368792, 0.397942]\n",
      "[512, 18, 32, 40, 40] 0.354 [0.28953, 0.355631, 0.372666, 0.397381]\n",
      "[512, 18, 32, 60, 20] 0.354 [0.286745, 0.354237, 0.375135, 0.40094]\n",
      "[512, 18, 32, 60, 40] 0.358 [0.288666, 0.368767, 0.374586, 0.401067]\n",
      "[512, 18, 32, 80, 20] 0.359 [0.294002, 0.366626, 0.373912, 0.402481]\n",
      "[512, 18, 32, 80, 40] 0.357 [0.286547, 0.365962, 0.372917, 0.402056]\n",
      "[512, 18, 32, 100, 20] 0.36 [0.29426, 0.366428, 0.374475, 0.402951]\n",
      "[512, 18, 32, 100, 40] 0.358 [0.290499, 0.365357, 0.374137, 0.402438]\n",
      "[512, 18, 32, 120, 20] 0.362 [0.294682, 0.366446, 0.38395, 0.400985]\n",
      "[512, 18, 32, 120, 40] 0.357 [0.287073, 0.364907, 0.37405, 0.400737]\n",
      "[512, 18, 32, 150, 20] 0.362 [0.294598, 0.366566, 0.38369, 0.401205]\n",
      "[512, 18, 32, 150, 40] 0.357 [0.287154, 0.364771, 0.373962, 0.400918]\n",
      "[512, 18, 64, 40, 20] 0.353 [0.290296, 0.348747, 0.364994, 0.408905]\n",
      "[512, 18, 64, 40, 40] 0.357 [0.310899, 0.346009, 0.367209, 0.404536]\n",
      "[512, 18, 64, 60, 20] 0.358 [0.286714, 0.356659, 0.373527, 0.414986]\n",
      "[512, 18, 64, 60, 40] 0.357 [0.287713, 0.355426, 0.375521, 0.410677]\n",
      "[512, 18, 64, 80, 20] 0.357 [0.284233, 0.354639, 0.376343, 0.413147]\n",
      "[512, 18, 64, 80, 40] 0.358 [0.284382, 0.356498, 0.375638, 0.414562]\n",
      "[512, 18, 64, 100, 20] 0.357 [0.283456, 0.354527, 0.381937, 0.40958]\n",
      "[512, 18, 64, 100, 40] 0.356 [0.285319, 0.354294, 0.37539, 0.407639]\n",
      "[512, 18, 64, 120, 20] 0.359 [0.286537, 0.354564, 0.382724, 0.412665]\n",
      "[512, 18, 64, 120, 40] 0.355 [0.284121, 0.353678, 0.375565, 0.407715]\n",
      "[512, 18, 64, 150, 20] 0.359 [0.286519, 0.354579, 0.383409, 0.412723]\n",
      "[512, 18, 64, 150, 40] 0.355 [0.284014, 0.353548, 0.375293, 0.407793]\n",
      "[512, 18, 128, 40, 20] 0.357 [0.289745, 0.355865, 0.370478, 0.410089]\n",
      "[512, 18, 128, 40, 40] 0.356 [0.291935, 0.351424, 0.372379, 0.408173]\n",
      "[512, 18, 128, 60, 20] 0.365 [0.291763, 0.363237, 0.38387, 0.420106]\n",
      "[512, 18, 128, 60, 40] 0.364 [0.294414, 0.365269, 0.388001, 0.408472]\n",
      "[512, 18, 128, 80, 20] 0.365 [0.291173, 0.361365, 0.382828, 0.42405]\n",
      "[512, 18, 128, 80, 40] 0.364 [0.29788, 0.361734, 0.38602, 0.411334]\n",
      "[512, 18, 128, 100, 20] 0.364 [0.290246, 0.360501, 0.380491, 0.423749]\n",
      "[512, 18, 128, 100, 40] 0.369 [0.296136, 0.366039, 0.386843, 0.428176]\n",
      "[512, 18, 128, 120, 20] 0.365 [0.290031, 0.360093, 0.384869, 0.423542]\n",
      "[512, 18, 128, 120, 40] 0.368 [0.29862, 0.362491, 0.383585, 0.427385]\n",
      "[512, 18, 128, 150, 20] 0.365 [0.289915, 0.361047, 0.385081, 0.423605]\n",
      "[512, 18, 128, 150, 40] 0.367 [0.291817, 0.366968, 0.383454, 0.426762]\n",
      "[512, 18, 256, 40, 20] 0.355 [0.28285, 0.357526, 0.368882, 0.409885]\n",
      "[512, 18, 256, 40, 40] 0.354 [0.284503, 0.348006, 0.379697, 0.40226]\n",
      "[512, 18, 256, 60, 20] 0.362 [0.293841, 0.361845, 0.377539, 0.414338]\n",
      "[512, 18, 256, 60, 40] 0.373 [0.297524, 0.370843, 0.384187, 0.437947]\n",
      "[512, 18, 256, 80, 20] 0.36 [0.297691, 0.358605, 0.375266, 0.408666]\n",
      "[512, 18, 256, 80, 40] 0.369 [0.309174, 0.368745, 0.385184, 0.411181]\n",
      "[512, 18, 256, 100, 20] 0.361 [0.298074, 0.357657, 0.375477, 0.414258]\n",
      "[512, 18, 256, 100, 40] 0.368 [0.309439, 0.37073, 0.383889, 0.407469]\n",
      "[512, 18, 256, 120, 20] 0.36 [0.297835, 0.359099, 0.373822, 0.408659]\n",
      "[512, 18, 256, 120, 40] 0.364 [0.310855, 0.360077, 0.374454, 0.410327]\n",
      "[512, 18, 256, 150, 20] 0.36 [0.297712, 0.359539, 0.375084, 0.408666]\n",
      "[512, 18, 256, 150, 40] 0.364 [0.310699, 0.359743, 0.37561, 0.410512]\n"
     ]
    }
   ],
   "source": [
    "dataset_list = ['etth1','etth2']\n",
    "#dataset_list = ['traffic']\n",
    "get_summary(dataset_list, type_='hard')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "==================================================\n",
      "[512, 18, 32, 40, 20] 0.352 [0.281341, 0.359763, 0.368792, 0.397942]\n",
      "[512, 18, 32, 40, 40] 0.354 [0.28953, 0.355631, 0.372666, 0.397381]\n",
      "[512, 18, 32, 60, 20] 0.354 [0.286745, 0.354237, 0.375135, 0.40094]\n",
      "[512, 18, 32, 60, 40] 0.358 [0.288666, 0.368767, 0.374586, 0.401067]\n",
      "[512, 18, 32, 80, 20] 0.359 [0.294002, 0.366626, 0.373912, 0.402481]\n",
      "[512, 18, 32, 80, 40] 0.357 [0.286547, 0.365962, 0.372917, 0.402056]\n",
      "[512, 18, 32, 100, 20] 0.36 [0.29426, 0.366428, 0.374475, 0.402951]\n",
      "[512, 18, 32, 100, 40] 0.358 [0.290499, 0.365357, 0.374137, 0.402438]\n",
      "[512, 18, 32, 120, 20] 0.362 [0.294682, 0.366446, 0.38395, 0.400985]\n",
      "[512, 18, 32, 120, 40] 0.357 [0.287073, 0.364907, 0.37405, 0.400737]\n",
      "[512, 18, 32, 150, 20] 0.362 [0.294598, 0.366566, 0.38369, 0.401205]\n",
      "[512, 18, 32, 150, 40] 0.357 [0.287154, 0.364771, 0.373962, 0.400918]\n",
      "[512, 18, 64, 40, 20] 0.353 [0.290296, 0.348747, 0.364994, 0.408905]\n",
      "[512, 18, 64, 40, 40] 0.357 [0.310899, 0.346009, 0.367209, 0.404536]\n",
      "[512, 18, 64, 60, 20] 0.358 [0.286714, 0.356659, 0.373527, 0.414986]\n",
      "[512, 18, 64, 60, 40] 0.357 [0.287713, 0.355426, 0.375521, 0.410677]\n",
      "[512, 18, 64, 80, 20] 0.357 [0.284233, 0.354639, 0.376343, 0.413147]\n",
      "[512, 18, 64, 80, 40] 0.358 [0.284382, 0.356498, 0.375638, 0.414562]\n",
      "[512, 18, 64, 100, 20] 0.357 [0.283456, 0.354527, 0.381937, 0.40958]\n",
      "[512, 18, 64, 100, 40] 0.356 [0.285319, 0.354294, 0.37539, 0.407639]\n",
      "[512, 18, 64, 120, 20] 0.359 [0.286537, 0.354564, 0.382724, 0.412665]\n",
      "[512, 18, 64, 120, 40] 0.355 [0.284121, 0.353678, 0.375565, 0.407715]\n",
      "[512, 18, 64, 150, 20] 0.359 [0.286519, 0.354579, 0.383409, 0.412723]\n",
      "[512, 18, 64, 150, 40] 0.355 [0.284014, 0.353548, 0.375293, 0.407793]\n",
      "[512, 18, 128, 40, 20] 0.357 [0.289745, 0.355865, 0.370478, 0.410089]\n",
      "[512, 18, 128, 40, 40] 0.356 [0.291935, 0.351424, 0.372379, 0.408173]\n",
      "[512, 18, 128, 60, 20] 0.365 [0.291763, 0.363237, 0.38387, 0.420106]\n",
      "[512, 18, 128, 60, 40] 0.364 [0.294414, 0.365269, 0.388001, 0.408472]\n",
      "[512, 18, 128, 80, 20] 0.365 [0.291173, 0.361365, 0.382828, 0.42405]\n",
      "[512, 18, 128, 80, 40] 0.364 [0.29788, 0.361734, 0.38602, 0.411334]\n",
      "[512, 18, 128, 100, 20] 0.364 [0.290246, 0.360501, 0.380491, 0.423749]\n",
      "[512, 18, 128, 100, 40] 0.369 [0.296136, 0.366039, 0.386843, 0.428176]\n",
      "[512, 18, 128, 120, 20] 0.365 [0.290031, 0.360093, 0.384869, 0.423542]\n",
      "[512, 18, 128, 120, 40] 0.368 [0.29862, 0.362491, 0.383585, 0.427385]\n",
      "[512, 18, 128, 150, 20] 0.365 [0.289915, 0.361047, 0.385081, 0.423605]\n",
      "[512, 18, 128, 150, 40] 0.367 [0.291817, 0.366968, 0.383454, 0.426762]\n",
      "[512, 18, 256, 40, 20] 0.355 [0.28285, 0.357526, 0.368882, 0.409885]\n",
      "[512, 18, 256, 40, 40] 0.354 [0.284503, 0.348006, 0.379697, 0.40226]\n",
      "[512, 18, 256, 60, 20] 0.362 [0.293841, 0.361845, 0.377539, 0.414338]\n",
      "[512, 18, 256, 60, 40] 0.373 [0.297524, 0.370843, 0.384187, 0.437947]\n",
      "[512, 18, 256, 80, 20] 0.36 [0.297691, 0.358605, 0.375266, 0.408666]\n",
      "[512, 18, 256, 80, 40] 0.369 [0.309174, 0.368745, 0.385184, 0.411181]\n",
      "[512, 18, 256, 100, 20] 0.361 [0.298074, 0.357657, 0.375477, 0.414258]\n",
      "[512, 18, 256, 100, 40] 0.368 [0.309439, 0.37073, 0.383889, 0.407469]\n",
      "[512, 18, 256, 120, 20] 0.36 [0.297835, 0.359099, 0.373822, 0.408659]\n",
      "[512, 18, 256, 120, 40] 0.364 [0.310855, 0.360077, 0.374454, 0.410327]\n",
      "[512, 18, 256, 150, 20] 0.36 [0.297712, 0.359539, 0.375084, 0.408666]\n",
      "[512, 18, 256, 150, 40] 0.364 [0.310699, 0.359743, 0.37561, 0.410512]\n"
     ]
    }
   ],
   "source": [
    "dataset_list = ['etth2']\n",
    "#dataset_list = ['traffic']\n",
    "get_summary(dataset_list, type_='hard')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "==================================================\n"
     ]
    },
    {
     "ename": "FileNotFoundError",
     "evalue": "[Errno 2] No such file or directory: '/home/seunghan9613/PatchTST_sim/PatchTST_self_supervised/saved_models/traffic2traffic/masked_patchtst_sim_half_v3_mean_FC2_R/based_model/max'",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mFileNotFoundError\u001b[0m                         Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[34], line 3\u001b[0m\n\u001b[1;32m      1\u001b[0m \u001b[39m#dataset_list = ['etth2']\u001b[39;00m\n\u001b[1;32m      2\u001b[0m dataset_list \u001b[39m=\u001b[39m [\u001b[39m'\u001b[39m\u001b[39mtraffic\u001b[39m\u001b[39m'\u001b[39m]\n\u001b[0;32m----> 3\u001b[0m get_summary(dataset_list, type_\u001b[39m=\u001b[39;49m\u001b[39m'\u001b[39;49m\u001b[39mhard\u001b[39;49m\u001b[39m'\u001b[39;49m)\n",
      "Cell \u001b[0;32mIn[28], line 8\u001b[0m, in \u001b[0;36mget_summary\u001b[0;34m(dataset_list, type_)\u001b[0m\n\u001b[1;32m      6\u001b[0m PATH3 \u001b[39m=\u001b[39m \u001b[39mf\u001b[39m\u001b[39m'\u001b[39m\u001b[39mmasked_patchtst_sim_half_v3_mean_FC2_R/based_model/max\u001b[39m\u001b[39m'\u001b[39m\n\u001b[1;32m      7\u001b[0m PATH \u001b[39m=\u001b[39m os\u001b[39m.\u001b[39mpath\u001b[39m.\u001b[39mjoin(PATH1, PATH2, PATH3)\n\u001b[0;32m----> 8\u001b[0m settings \u001b[39m=\u001b[39m os\u001b[39m.\u001b[39;49mlistdir(PATH)\n\u001b[1;32m     10\u001b[0m \u001b[39mfor\u001b[39;00m share \u001b[39min\u001b[39;00m [\u001b[39m1\u001b[39m]:\n\u001b[1;32m     11\u001b[0m     \u001b[39mif\u001b[39;00m share\u001b[39m==\u001b[39m\u001b[39m0\u001b[39m:\n",
      "\u001b[0;31mFileNotFoundError\u001b[0m: [Errno 2] No such file or directory: '/home/seunghan9613/PatchTST_sim/PatchTST_self_supervised/saved_models/traffic2traffic/masked_patchtst_sim_half_v3_mean_FC2_R/based_model/max'"
     ]
    }
   ],
   "source": [
    "#dataset_list = ['etth2']\n",
    "dataset_list = ['traffic']\n",
    "get_summary(dataset_list, type_='hard')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "ssl_ts",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.11"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
