{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import numpy as np\n",
    "import pandas as pd"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_mse_mae1(PATH, settings_list, cw,patch_size, model_dim, load_epoch, ft_epoch):\n",
    "    y = [x for x in settings_list if f'z' not in x]\n",
    "    y = [x for x in y if f'patch{patch_size}' in x]\n",
    "    y = [x for x in y if f'cw{cw}' in x]\n",
    "    y = [x for x in y if f'_D{model_dim}' in x]\n",
    "    try:\n",
    "        arch = os.path.join(PATH,y[0])\n",
    "    except:\n",
    "        return [999,999,999,999], [999,999,999,999]\n",
    "    #print(arch)\n",
    "    files = os.listdir(arch)\n",
    "    if len(files)==0:\n",
    "        return [999,999,999,999], [999,999,999,999]\n",
    "    files = [x for x in files if 'acc' in x]\n",
    "    files = [x for x in files if f'load_ep{load_epoch}' in x]\n",
    "    files = [x for x in files if f'ft_ep{ft_epoch}' in x]\n",
    "    idx_order = [int(x.split('_')[0].split('tw')[1]) for x in files]\n",
    "    files =  list(np.array(files)[np.argsort(idx_order)])\n",
    "    files = [os.path.join(x) for x in files]\n",
    "    files = [os.path.join(arch, f) for f in files]\n",
    "    mse_result = []\n",
    "    mae_result = []\n",
    "    for f in files:\n",
    "        try:\n",
    "            df = pd.read_csv(f)\n",
    "            mse_result.append(df['mse'][0])\n",
    "            mae_result.append(df['mae'][0])\n",
    "        except:\n",
    "            mse_result.append(999)\n",
    "            mae_result.append(999)\n",
    "    return mse_result, mae_result"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_mse_mae2(PATH, settings_list, cw,patch_size, model_dim, tau_temp, load_epoch, ft_epoch):\n",
    "    y = [x for x in settings_list if f'z' not in x]\n",
    "    y = [x for x in y if f'cw{cw}' in x]\n",
    "    y = [x for x in y if f'patch{patch_size}' in x]\n",
    "    y = [x for x in y if f'_D{model_dim}' in x]\n",
    "    y = [x for x in y if f'tau_temp{tau_temp}' in x]\n",
    "    try:\n",
    "        arch = os.path.join(PATH,y[0])\n",
    "    except:\n",
    "        return [999,999,999,999], [999,999,999,999]\n",
    "    #print(arch)\n",
    "    files = os.listdir(arch)\n",
    "    if len(files)==0:\n",
    "        return [999,999,999,999], [999,999,999,999]\n",
    "    files = [x for x in files if 'acc' in x]\n",
    "    files = [x for x in files if f'load_ep{load_epoch}' in x]\n",
    "    files = [x for x in files if f'ft_ep{ft_epoch}' in x]\n",
    "    idx_order = [int(x.split('_')[0].split('tw')[1]) for x in files]\n",
    "    files =  list(np.array(files)[np.argsort(idx_order)])\n",
    "    files = [os.path.join(x) for x in files]\n",
    "    files = [os.path.join(arch, f) for f in files]\n",
    "    mse_result = []\n",
    "    mae_result = []\n",
    "    for f in files:\n",
    "        try:\n",
    "            df = pd.read_csv(f)\n",
    "            mse_result.append(df['mse'][0])\n",
    "            mae_result.append(df['mae'][0])\n",
    "        except:\n",
    "            mse_result.append(999)\n",
    "            mae_result.append(999)\n",
    "    return mse_result, mae_result"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_mse_mae3(PATH, settings_list, cw,patch_size, model_dim, tau_temp,tau_inst, load_epoch, ft_epoch):\n",
    "    y = [x for x in settings_list if f'z' not in x]\n",
    "    y = [x for x in y if f'cw{cw}' in x]\n",
    "    y = [x for x in y if f'patch{patch_size}' in x]\n",
    "    y = [x for x in y if f'_D{model_dim}' in x]\n",
    "    y = [x for x in y if f'tau_temp{tau_temp}' in x]\n",
    "    y = [x for x in y if f'tau_inst{tau_inst}' in x]\n",
    "    try:\n",
    "        arch = os.path.join(PATH,y[0])\n",
    "    except:\n",
    "        return [999,999,999,999], [999,999,999,999]\n",
    "    #print(arch)\n",
    "    files = os.listdir(arch)\n",
    "    if len(files)==0:\n",
    "        return [999,999,999,999], [999,999,999,999]\n",
    "    files = [x for x in files if 'acc' in x]\n",
    "    files = [x for x in files if f'load_ep{load_epoch}' in x]\n",
    "    files = [x for x in files if f'ft_ep{ft_epoch}' in x]\n",
    "    idx_order = [int(x.split('_')[0].split('tw')[1]) for x in files]\n",
    "    files =  list(np.array(files)[np.argsort(idx_order)])\n",
    "    files = [os.path.join(x) for x in files]\n",
    "    files = [os.path.join(arch, f) for f in files]\n",
    "    mse_result = []\n",
    "    mae_result = []\n",
    "    for f in files:\n",
    "        try:\n",
    "            df = pd.read_csv(f)\n",
    "            mse_result.append(df['mse'][0])\n",
    "            mae_result.append(df['mae'][0])\n",
    "        except:\n",
    "            mse_result.append(999)\n",
    "            mae_result.append(999)\n",
    "    return mse_result, mae_result"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_summary(dataset_list, type_='hard'):\n",
    "    for data in dataset_list:\n",
    "        print('='*50)\n",
    "        PATH1 = '/home/seunghan9613/PatchTST_sim/PatchTST_self_supervised/saved_models/'\n",
    "        PATH2 = f'{data}2{data}'\n",
    "        PATH3 = f'masked_patchtst_sim_half_v3_mean_FC2_sep_rev_R/based_model/max'\n",
    "        PATH = os.path.join(PATH1, PATH2, PATH3)\n",
    "        settings = os.listdir(PATH)\n",
    "        \n",
    "        for share in [1]:\n",
    "            if share==0:\n",
    "                settings = [x for x in settings if 'no_share' in x]\n",
    "            else:\n",
    "                settings = [x for x in settings if 'no_share' not  in x]\n",
    "            \n",
    "            if type_ == 'hard':\n",
    "                settings = [x for x in settings if 'tau' not in x]\n",
    "            elif type_ == 'soft1':\n",
    "                settings = [x for x in settings if 'tau' in x]\n",
    "                settings = [x for x in settings if 'tau_inst' not in x]\n",
    "            elif type_ == 'soft2':\n",
    "                settings = [x for x in settings if 'tau' in x]\n",
    "                settings = [x for x in settings if 'tau_inst' in x]\n",
    "                \n",
    "            if type_ == 'hard':\n",
    "                for cw in [336,512,768]:\n",
    "                    for ps in [12,18,24]:\n",
    "                        for dim in [32,64,128,256]:\n",
    "                            for load_ep in [20,40,60,80,100,120,150]:\n",
    "                                for ft_ep in [20,40]:\n",
    "                                    mse_result, mae_result = get_mse_mae1(PATH = PATH, settings_list = settings,\n",
    "                                                                        cw = cw, patch_size=ps, \n",
    "                                                                        model_dim=dim, \n",
    "                                                                        load_epoch = load_ep, ft_epoch=ft_ep)\n",
    "                                    try:\n",
    "                                        if (mse_result[0] != 999):\n",
    "                                        #if (mse_result[0] != 999) & (len(mse_result)==4):\n",
    "                                            #print([ps,dim,load_ep,ft_ep],mse_result, mae_result)\n",
    "                                            print([cw,ps,dim,load_ep,ft_ep],np.mean(mse_result).round(3), mse_result)\n",
    "                                    except:\n",
    "                                        pass\n",
    "                \n",
    "            elif type_ == 'soft1':\n",
    "                for cw in [336,512,768]:\n",
    "                    for ps in [12,18,24]:\n",
    "                        print(f'-------- cw={cw},ps={ps} ---------')\n",
    "                        min_val = 999\n",
    "                        for dim in [32,64,128,256]:\n",
    "                            for tau1 in [1,3,5]:\n",
    "                                for load_ep in [20,40,60,80,100,120,150]:\n",
    "                                    for ft_ep in [20,40]:\n",
    "                                        mse_result, mae_result = get_mse_mae2(PATH = PATH, settings_list = settings,\n",
    "                                                                            cw=cw, patch_size=ps, \n",
    "                                                                            model_dim=dim, tau_temp=tau1,\n",
    "                                                                            load_epoch = load_ep, ft_epoch=ft_ep)\n",
    "                                        try:\n",
    "                                            if (mse_result[0] != 999):\n",
    "                                                if np.sum(mse_result)<min_val:\n",
    "                                                    print([cw,ps,dim,tau1,load_ep,ft_ep],np.mean(mse_result).round(3),mse_result)\n",
    "                                                    min_val = np.sum(mse_result)\n",
    "                                        except:\n",
    "                                            pass\n",
    "                \n",
    "            elif type_ == 'soft2':\n",
    "                for cw in [336,512,768]:\n",
    "                        \n",
    "                    for ps in [12,18,24]:\n",
    "                        for dim in [32,64,128,256]:\n",
    "                            for tau1 in [1,3,5]:\n",
    "                                for tau2 in [1,3,5]:\n",
    "                                    for load_ep in [20,40,60,80,100,120,150]:\n",
    "                                        for ft_ep in [20,40]:\n",
    "                                            mse_result, mae_result = get_mse_mae3(PATH = PATH,settings_list = settings,\n",
    "                                                                                cw=cw,patch_size=ps, \n",
    "                                                                                model_dim=dim, \n",
    "                                                                                tau_temp=tau1, tau_inst=tau2,\n",
    "                                                                                load_epoch = load_ep, ft_epoch=ft_ep)\n",
    "                                            try:\n",
    "                                                if (mse_result[0] != 999) & (len(mse_result)==4):\n",
    "                                                    #print([ps,dim,tau1,tau2,load_ep,ft_ep],mse_result, mae_result)\n",
    "                                                    print([cw,ps,dim,tau1,tau2,load_ep,ft_ep],np.mean(mse_result).round(3),mse_result)\n",
    "                                            except:\n",
    "                                                pass\n",
    "                            "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_summary_soft1(dataset_list, type_='soft1'):\n",
    "    for cw in [336,512,768]:\n",
    "        for ps in [12,18,24]:\n",
    "            print(f'-------- cw={cw},ps={ps} ---------')\n",
    "            for dim in [32,64,128,256]:\n",
    "                for tau1 in [1,3,5]:\n",
    "                    for load_ep in [40,60,80,100,120,150]:\n",
    "                        for ft_ep in [20,40]:\n",
    "                            print('='*40)\n",
    "                            print([cw,ps,dim,tau1,load_ep,ft_ep])\n",
    "                            for data in dataset_list:\n",
    "                                #print(f'============== {data} ==============')\n",
    "                                min_val = 999\n",
    "                                PATH1 = '/home/seunghan9613/PatchTST_sim/PatchTST_self_supervised/saved_models/'\n",
    "                                PATH2 = f'{data}2{data}'\n",
    "                                PATH3 = f'masked_patchtst_sim_half_v3_mean_FC2_sep_rev_R/based_model/max'\n",
    "                                PATH = os.path.join(PATH1, PATH2, PATH3)\n",
    "                                try:\n",
    "                                    settings = os.listdir(PATH)\n",
    "                                    \n",
    "                                    for share in [1]:\n",
    "                                        settings = [x for x in settings if 'no_share' not  in x]\n",
    "                                        settings = [x for x in settings if 'tau' in x]\n",
    "                                        settings = [x for x in settings if 'tau_inst' not in x]\n",
    "        \n",
    "                                        mse_result, mae_result = get_mse_mae2(PATH = PATH, settings_list = settings,\n",
    "                                                                            cw=cw, patch_size=ps, \n",
    "                                                                            model_dim=dim, tau_temp=tau1,\n",
    "                                                                            load_epoch = load_ep, ft_epoch=ft_ep)\n",
    "                                        try:\n",
    "                                            if len(mse_result)>0:\n",
    "                                                print(data, np.mean(mse_result).round(3), mse_result)\n",
    "                                            \n",
    "                                        except:\n",
    "                                            pass\n",
    "                                except:\n",
    "                                    pass\n",
    "                                "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_summary_hard(dataset_list, cw_list = [336,512,768],\n",
    "                      ps_list = [12,18,24], \n",
    "                      load_ep_list = [40,60,80,100,120,150],\n",
    "                      ft_ep_list = [20,40],\n",
    "                      type_='soft1'):\n",
    "    \n",
    "    \n",
    "    for cw in cw_list:\n",
    "        for ps in ps_list:\n",
    "            min_val = 999999\n",
    "            print('='*60)\n",
    "            print(f'-------- cw={cw},ps={ps} ---------')\n",
    "            print('='*60)\n",
    "            for dim in [32,64,128,256]:\n",
    "                for load_ep in load_ep_list:\n",
    "                    for ft_ep in ft_ep_list:\n",
    "                        #print('='*40)\n",
    "                        #print([cw,ps,dim,tau1,load_ep,ft_ep])\n",
    "                        data_sum = 0\n",
    "                        data_results = []\n",
    "                        for data in dataset_list:\n",
    "                            #print(f'============== {data} ==============')\n",
    "                            PATH1 = '/home/seunghan9613/PatchTST_sim/PatchTST_self_supervised/saved_models/'\n",
    "                            PATH2 = f'{data}2{data}'\n",
    "                            PATH3 = f'masked_patchtst_sim_half_v3_mean_FC2_sep_rev_R/based_model/max'\n",
    "                            PATH = os.path.join(PATH1, PATH2, PATH3)\n",
    "                            try:\n",
    "                                settings = os.listdir(PATH)\n",
    "                                \n",
    "                                for share in [1]:\n",
    "                                    settings = [x for x in settings if 'no_share' not  in x]\n",
    "                                    settings = [x for x in settings if 'tau' not in x]\n",
    "    \n",
    "                                    mse_result, mae_result = get_mse_mae1(PATH = PATH, settings_list = settings,\n",
    "                                                                        cw=cw, patch_size=ps, \n",
    "                                                                        model_dim=dim, \n",
    "                                                                        load_epoch = load_ep, ft_epoch=ft_ep)\n",
    "                                    \n",
    "                                    try:\n",
    "                                        if len(mse_result)>0:\n",
    "                                            #print(data, np.mean(mse_result).round(3), mse_result)\n",
    "                                            data_sum += np.sum(mse_result)\n",
    "                                            data_results.append([data, np.mean(mse_result).round(3), mse_result])\n",
    "                                        \n",
    "                                    except:\n",
    "                                        pass\n",
    "                                    \n",
    "                            except:\n",
    "                                pass\n",
    "                        \n",
    "                        ##for k in data_results:\n",
    "                            #print(k)\n",
    "                            \n",
    "                        if min_val > data_sum:\n",
    "                            min_val = data_sum\n",
    "                            print('-'*50)\n",
    "                            print([dim,load_ep,ft_ep])\n",
    "                            for k in data_results:\n",
    "                                print(k)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_summary_soft2(dataset_list, cw_list = [336,512,768],\n",
    "                      ps_list = [12,18,24], \n",
    "                      load_ep_list = [40,60,80,100,120,150],\n",
    "                      ft_ep_list = [20,40],\n",
    "                      type_='soft1'):\n",
    "    \n",
    "    \n",
    "    for cw in cw_list:\n",
    "        for ps in ps_list:\n",
    "            min_val = 999999\n",
    "            print('='*60)\n",
    "            print(f'-------- cw={cw},ps={ps} ---------')\n",
    "            print('='*60)\n",
    "            for dim in [32,64,128,256]:\n",
    "                for tau1 in [1,3,5]:\n",
    "                    for load_ep in load_ep_list:\n",
    "                        for ft_ep in ft_ep_list:\n",
    "                            #print('='*40)\n",
    "                            #print([cw,ps,dim,tau1,load_ep,ft_ep])\n",
    "                            data_sum = 0\n",
    "                            data_results = []\n",
    "                            for data in dataset_list:\n",
    "                                #print(f'============== {data} ==============')\n",
    "                                PATH1 = '/home/seunghan9613/PatchTST_sim/PatchTST_self_supervised/saved_models/'\n",
    "                                PATH2 = f'{data}2{data}'\n",
    "                                PATH3 = f'masked_patchtst_sim_half_v3_mean_FC2_sep_rev_R/based_model/max'\n",
    "                                PATH = os.path.join(PATH1, PATH2, PATH3)\n",
    "                                try:\n",
    "                                    settings = os.listdir(PATH)\n",
    "                                    #if 'etth' in data:\n",
    "                                    #    settings = [x for x in settings if 'pretrain150' in x]\n",
    "                                    ##else:\n",
    "                                     #   settings = [x for x in settings if 'pretrain100' in x]\n",
    "                                        \n",
    "                                    \n",
    "                                    for share in [1]:\n",
    "                                        settings = [x for x in settings if 'no_share' not  in x]\n",
    "                                        settings = [x for x in settings if 'tau' in x]\n",
    "                                        settings = [x for x in settings if 'tau_inst' not in x]\n",
    "        \n",
    "                                        mse_result, mae_result = get_mse_mae2(PATH = PATH, settings_list = settings,\n",
    "                                                                            cw=cw, patch_size=ps, \n",
    "                                                                            model_dim=dim, tau_temp=tau1,\n",
    "                                                                            load_epoch = load_ep, ft_epoch=ft_ep)\n",
    "                                        \n",
    "                                        try:\n",
    "                                            if len(mse_result)>0:\n",
    "                                                #print(data, np.mean(mse_result).round(3), mse_result)\n",
    "                                                data_sum += np.sum(mse_result)\n",
    "                                                data_results.append([data, np.mean(mse_result).round(3), mse_result])\n",
    "                                            \n",
    "                                        except:\n",
    "                                            pass\n",
    "                                        \n",
    "                                except:\n",
    "                                    pass\n",
    "                            \n",
    "                            #for k in data_results:\n",
    "                            #    print(k)\n",
    "                                \n",
    "                            if min_val > data_sum:\n",
    "                                min_val = data_sum\n",
    "                                print('-'*50)\n",
    "                                print([dim,tau1,load_ep,ft_ep])\n",
    "                                for k in data_results:\n",
    "                                    print(k)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "============================================================\n",
      "-------- cw=512,ps=18 ---------\n",
      "============================================================\n",
      "--------------------------------------------------\n",
      "[32, 40, 20]\n",
      "['etth1', 0.433, [0.371023, 0.415114, 0.448655, 0.496418]]\n",
      "--------------------------------------------------\n",
      "[32, 40, 40]\n",
      "['etth1', 0.43, [0.374084, 0.404705, 0.437627, 0.504386]]\n",
      "--------------------------------------------------\n",
      "[32, 60, 40]\n",
      "['etth1', 0.425, [0.373264, 0.407041, 0.434586, 0.484077]]\n"
     ]
    }
   ],
   "source": [
    "# Hard (no-tau) grid on ETTh1 for cw=512, patch size 18.\n",
    "# Author's earlier note: [128, 80, 40] (presumably a best [dim, load_ep, ft_ep]; unverified).\n",
    "dataset_list = ['etth1']\n",
    "\n",
    "get_summary_hard(dataset_list, cw_list=[512], ps_list=[18])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "============================================================\n",
      "-------- cw=512,ps=12 ---------\n",
      "============================================================\n",
      "--------------------------------------------------\n",
      "[32, 100, 20]\n",
      "['etth1', 999.0, [999, 999, 999, 999]]\n",
      "['etth2', 999.0, [999, 999, 999, 999]]\n",
      "--------------------------------------------------\n",
      "[128, 100, 20]\n",
      "['etth1', 0.445, [0.384156, 0.433304, 0.456439, 0.506423]]\n",
      "['etth2', 0.378, [0.310063, 0.392734, 0.387055, 0.420505]]\n",
      "--------------------------------------------------\n",
      "[128, 100, 40]\n",
      "['etth1', 0.449, [0.390132, 0.470237, 0.486511]]\n",
      "['etth2', 0.373, [0.312299, 0.371213, 0.386502, 0.421732]]\n",
      "============================================================\n",
      "-------- cw=512,ps=12 ---------\n",
      "============================================================\n",
      "--------------------------------------------------\n",
      "[32, 100, 10]\n",
      "['etth1', 999.0, [999, 999, 999, 999]]\n",
      "['etth2', 999.0, [999, 999, 999, 999]]\n",
      "--------------------------------------------------\n",
      "[128, 100, 10]\n",
      "['etth1', 0.442, [0.381379, 0.41556, 0.462356, 0.509226]]\n",
      "['etth2', 0.371, [0.303087, 0.373172, 0.390653, 0.415658]]\n",
      "--------------------------------------------------\n",
      "[128, 100, 40]\n",
      "['etth1', 0.449, [0.390132, 0.470237, 0.486511]]\n",
      "['etth2', 0.373, [0.312299, 0.371213, 0.386502, 0.421732]]\n"
     ]
    }
   ],
   "source": [
    "# Hard grid for ETTh1 + ETTh2 at cw=512, patch size 12, load_ep=100:\n",
    "# first with the default fine-tuning epochs [20, 40], then with a wider sweep.\n",
    "dataset_list = ['etth1','etth2']\n",
    "\n",
    "for ft_eps in ([20, 40], [10, 20, 40, 60]):\n",
    "    get_summary_hard(dataset_list, cw_list=[512], ps_list=[12],\n",
    "                     load_ep_list=[100], ft_ep_list=ft_eps)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "============================================================\n",
      "-------- cw=512,ps=12 ---------\n",
      "============================================================\n",
      "--------------------------------------------------\n",
      "[32, 100, 10]\n",
      "['etth1', 999.0, [999, 999, 999, 999]]\n",
      "['etth2', 999.0, [999, 999, 999, 999]]\n",
      "--------------------------------------------------\n",
      "[128, 100, 10]\n",
      "['etth1', 0.442, [0.381379, 0.41556, 0.462356, 0.509226]]\n",
      "['etth2', 0.371, [0.303087, 0.373172, 0.390653, 0.415658]]\n"
     ]
    }
   ],
   "source": [
    "# Conclusion: re-run of the selected ETTh1/ETTh2 configuration\n",
    "# (cw=512, patch size 12, load_ep=100, ft_ep=10).\n",
    "# dataset_list is set here so the cell does not depend on an earlier cell's state.\n",
    "dataset_list = ['etth1','etth2']\n",
    "\n",
    "get_summary_hard(dataset_list,\n",
    "                 cw_list=[512],\n",
    "                 ps_list=[12],\n",
    "                 load_ep_list=[100],\n",
    "                 ft_ep_list=[10])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "============================================================\n",
      "-------- cw=768,ps=24 ---------\n",
      "============================================================\n",
      "--------------------------------------------------\n",
      "[32, 1, 60, 40]\n",
      "['ettm1', 0.358, [0.295886, 0.344792, 0.366597, 0.424954]]\n",
      "['ettm2', 999.0, [999, 999, 999, 999]]\n",
      "--------------------------------------------------\n",
      "[32, 3, 60, 40]\n",
      "['ettm1', 0.357, [0.296716, 0.341614, 0.364876, 0.423956]]\n",
      "['ettm2', 999.0, [999, 999, 999, 999]]\n",
      "--------------------------------------------------\n",
      "[32, 5, 60, 40]\n",
      "['ettm1', 0.357, [0.29344, 0.341223, 0.368417, 0.423297]]\n",
      "['ettm2', 0.253, [0.167097, 0.223627, 0.275526, 0.347306]]\n"
     ]
    }
   ],
   "source": [
    "# Earlier note: [32, 5, 60, 40] (presumably a best [dim, tau_temp, load_ep, ft_ep]; unverified).\n",
    "dataset_list = ['ettm1','ettm2']\n",
    "\n",
    "get_summary_soft2(dataset_list, cw_list=[768], ps_list=[24],\n",
    "                  load_ep_list=[60], ft_ep_list=[40])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "============================================================\n",
      "-------- cw=768,ps=18 ---------\n",
      "============================================================\n",
      "--------------------------------------------------\n",
      "[32, 100, 20]\n",
      "['ettm1', 999.0, [999, 999, 999, 999]]\n",
      "--------------------------------------------------\n",
      "[128, 100, 20]\n",
      "['ettm1', 0.363, [0.311554, 0.344623, 0.36803, 0.427217]]\n",
      "--------------------------------------------------\n",
      "[128, 100, 40]\n",
      "['ettm1', 0.362, [0.305707, 0.340249, 0.374459, 0.427508]]\n",
      "============================================================\n",
      "-------- cw=768,ps=18 ---------\n",
      "============================================================\n",
      "--------------------------------------------------\n",
      "[32, 100, 10]\n",
      "['ettm1', 999.0, [999, 999, 999, 999]]\n",
      "--------------------------------------------------\n",
      "[128, 100, 10]\n",
      "['ettm1', 0.358, [0.300058, 0.336193, 0.369499, 0.426053]]\n",
      "============================================================\n",
      "-------- cw=1024,ps=24 ---------\n",
      "============================================================\n",
      "--------------------------------------------------\n",
      "[32, 100, 20]\n",
      "['ettm2', 999.0, [999, 999, 999, 999]]\n",
      "--------------------------------------------------\n",
      "[128, 100, 20]\n",
      "['ettm2', 0.266, [0.175785, 0.233851, 0.284664, 0.371678]]\n",
      "============================================================\n",
      "-------- cw=1024,ps=24 ---------\n",
      "============================================================\n",
      "--------------------------------------------------\n",
      "[32, 100, 10]\n",
      "['ettm2', 999.0, [999, 999, 999, 999]]\n",
      "--------------------------------------------------\n",
      "[128, 100, 10]\n",
      "['ettm2', 0.265, [0.170532, 0.235247, 0.27776, 0.375613]]\n",
      "--------------------------------------------------\n",
      "[128, 100, 60]\n",
      "['ettm2', 0.264, [0.174225, 0.240644, 0.282656, 0.357446]]\n"
     ]
    }
   ],
   "source": [
    "# Hard grid at load_ep=100 for ETTm1 (cw=768, ps=18) and ETTm2 (cw=1024, ps=24),\n",
    "# each first with the default fine-tuning epochs [20, 40] and then a wider sweep.\n",
    "for dataset_list, cw, ps in [(['ettm1'], 768, 18), (['ettm2'], 1024, 24)]:\n",
    "    for ft_eps in ([20, 40], [10, 20, 40, 60]):\n",
    "        get_summary_hard(dataset_list, cw_list=[cw], ps_list=[ps],\n",
    "                         load_ep_list=[100], ft_ep_list=ft_eps)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "============================================================\n",
      "-------- cw=768,ps=24 ---------\n",
      "============================================================\n",
      "--------------------------------------------------\n",
      "[32, 1, 40, 20]\n",
      "['ettm1', 999.0, [999, 999, 999, 999]]\n",
      "--------------------------------------------------\n",
      "[32, 3, 40, 20]\n",
      "['ettm1', 0.359, [0.300196, 0.341279, 0.368797, 0.424531]]\n",
      "--------------------------------------------------\n",
      "[32, 3, 60, 20]\n",
      "['ettm1', 0.357, [0.299221, 0.340648, 0.366037, 0.422851]]\n",
      "--------------------------------------------------\n",
      "[32, 3, 60, 40]\n",
      "['ettm1', 0.357, [0.296716, 0.341614, 0.364876, 0.423956]]\n",
      "--------------------------------------------------\n",
      "[32, 3, 150, 20]\n",
      "['ettm1', 0.299, [0.298981]]\n",
      "--------------------------------------------------\n",
      "[32, 3, 150, 40]\n"
     ]
    }
   ],
   "source": [
    "dataset_list = ['ettm1','ettm2']\n",
    "\n",
    "get_summary_soft2(dataset_list,\n",
    "                  cw_list = [768],ps_list = [24])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "============================================================\n",
      "-------- cw=512,ps=18 ---------\n",
      "============================================================\n",
      "--------------------------------------------------\n",
      "[32, 40, 20]\n",
      "['etth1', 0.433, [0.371023, 0.415114, 0.448655, 0.496418]]\n",
      "['etth2', 0.369, [0.297906, 0.369378, 0.379957, 0.430536]]\n",
      "--------------------------------------------------\n",
      "[32, 80, 20]\n",
      "['etth1', 0.431, [0.37483, 0.408849, 0.445778, 0.495186]]\n",
      "['etth2', 0.37, [0.300597, 0.371917, 0.389164, 0.416744]]\n",
      "--------------------------------------------------\n",
      "[32, 150, 20]\n",
      "['etth1', 0.429, [0.378907, 0.415618, 0.436571, 0.482948]]\n",
      "['etth2', 0.372, [0.302057, 0.375161, 0.391676, 0.418842]]\n"
     ]
    }
   ],
   "source": [
    "dataset_list = ['etth1','etth2']\n",
    "\n",
    "get_summary_hard(dataset_list,\n",
    "                 cw_list = [512],ps_list = [18])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "============================================================\n",
      "-------- cw=512,ps=18 ---------\n",
      "============================================================\n",
      "--------------------------------------------------\n",
      "[32, 1, 40, 20]\n",
      "['etth1', 0.422, [0.375807, 0.407021, 0.429339, 0.477402]]\n",
      "['etth2', 0.364, [0.294597, 0.373749, 0.374376, 0.412283]]\n",
      "--------------------------------------------------\n",
      "[32, 1, 60, 20]\n",
      "['etth1', 0.414, [0.369475, 0.40065, 0.415859, 0.469105]]\n",
      "['etth2', 0.37, [0.302094, 0.377081, 0.381576, 0.418018]]\n",
      "--------------------------------------------------\n",
      "[32, 5, 40, 20]\n",
      "['etth1', 0.417, [0.370462, 0.402293, 0.427214, 0.466974]]\n",
      "['etth2', 0.364, [0.301904, 0.368115, 0.379677, 0.407984]]\n"
     ]
    }
   ],
   "source": [
    "dataset_list = ['etth1','etth2']\n",
    "\n",
    "get_summary_soft2(dataset_list,\n",
    "                  cw_list = [512],ps_list = [18])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "============================================================\n",
      "-------- cw=336,ps=12 ---------\n",
      "============================================================\n",
      "--------------------------------------------------\n",
      "[32, 40, 20]\n",
      "============================================================\n",
      "-------- cw=336,ps=18 ---------\n",
      "============================================================\n",
      "--------------------------------------------------\n",
      "[32, 40, 20]\n",
      "============================================================\n",
      "-------- cw=336,ps=24 ---------\n",
      "============================================================\n",
      "--------------------------------------------------\n",
      "[32, 40, 20]\n",
      "============================================================\n",
      "-------- cw=512,ps=12 ---------\n",
      "============================================================\n",
      "--------------------------------------------------\n",
      "[32, 40, 20]\n",
      "============================================================\n",
      "-------- cw=512,ps=18 ---------\n",
      "============================================================\n",
      "--------------------------------------------------\n",
      "[32, 40, 20]\n",
      "============================================================\n",
      "-------- cw=512,ps=24 ---------\n",
      "============================================================\n",
      "--------------------------------------------------\n",
      "[32, 40, 20]\n",
      "============================================================\n",
      "-------- cw=768,ps=12 ---------\n",
      "============================================================\n",
      "--------------------------------------------------\n",
      "[32, 40, 20]\n",
      "============================================================\n",
      "-------- cw=768,ps=18 ---------\n",
      "============================================================\n",
      "--------------------------------------------------\n",
      "[32, 40, 20]\n",
      "============================================================\n",
      "-------- cw=768,ps=24 ---------\n",
      "============================================================\n",
      "--------------------------------------------------\n",
      "[32, 40, 20]\n"
     ]
    }
   ],
   "source": [
    "dataset_list = ['electricity']\n",
    "\n",
    "get_summary_hard(dataset_list,\n",
    "                  load_ep_list = [40,60,80,100]\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "ssl_ts",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.11"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
