{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import numpy as np\n",
    "import pandas as pd"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_mse_mae1(PATH, settings_list, cw,patch_size, model_dim, load_epoch, ft_epoch):\n",
    "    y = [x for x in settings_list if f'z' not in x]\n",
    "    y = [x for x in y if f'patch{patch_size}' in x]\n",
    "    y = [x for x in y if f'cw{cw}' in x]\n",
    "    y = [x for x in y if f'_D{model_dim}' in x]\n",
    "    try:\n",
    "        arch = os.path.join(PATH,y[0])\n",
    "    except:\n",
    "        return [999,999,999,999], [999,999,999,999]\n",
    "    #print(arch)\n",
    "    files = os.listdir(arch)\n",
    "    if len(files)==0:\n",
    "        return [999,999,999,999], [999,999,999,999]\n",
    "    files = [x for x in files if 'acc' in x]\n",
    "    files = [x for x in files if f'load_ep{load_epoch}' in x]\n",
    "    files = [x for x in files if f'ft_ep{ft_epoch}' in x]\n",
    "    idx_order = [int(x.split('_')[0].split('tw')[1]) for x in files]\n",
    "    files =  list(np.array(files)[np.argsort(idx_order)])\n",
    "    files = [os.path.join(x) for x in files]\n",
    "    files = [os.path.join(arch, f) for f in files]\n",
    "    mse_result = []\n",
    "    mae_result = []\n",
    "    for f in files:\n",
    "        try:\n",
    "            df = pd.read_csv(f)\n",
    "            mse_result.append(df['mse'][0])\n",
    "            mae_result.append(df['mae'][0])\n",
    "        except:\n",
    "            mse_result.append(999)\n",
    "            mae_result.append(999)\n",
    "    return mse_result, mae_result"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_mse_mae2(PATH, settings_list, cw,patch_size, model_dim, tau_temp, load_epoch, ft_epoch):\n",
    "    y = [x for x in settings_list if f'z' not in x]\n",
    "    y = [x for x in y if f'cw{cw}' in x]\n",
    "    y = [x for x in y if f'patch{patch_size}' in x]\n",
    "    y = [x for x in y if f'_D{model_dim}' in x]\n",
    "    y = [x for x in y if f'tau_temp{tau_temp}' in x]\n",
    "    try:\n",
    "        arch = os.path.join(PATH,y[0])\n",
    "    except:\n",
    "        return [999,999,999,999], [999,999,999,999]\n",
    "    #print(arch)\n",
    "    files = os.listdir(arch)\n",
    "    if len(files)==0:\n",
    "        return [999,999,999,999], [999,999,999,999]\n",
    "    files = [x for x in files if 'acc' in x]\n",
    "    files = [x for x in files if f'load_ep{load_epoch}' in x]\n",
    "    files = [x for x in files if f'ft_ep{ft_epoch}' in x]\n",
    "    idx_order = [int(x.split('_')[0].split('tw')[1]) for x in files]\n",
    "    files =  list(np.array(files)[np.argsort(idx_order)])\n",
    "    files = [os.path.join(x) for x in files]\n",
    "    files = [os.path.join(arch, f) for f in files]\n",
    "    mse_result = []\n",
    "    mae_result = []\n",
    "    for f in files:\n",
    "        try:\n",
    "            df = pd.read_csv(f)\n",
    "            mse_result.append(df['mse'][0])\n",
    "            mae_result.append(df['mae'][0])\n",
    "        except:\n",
    "            mse_result.append(999)\n",
    "            mae_result.append(999)\n",
    "    return mse_result, mae_result"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_mse_mae3(PATH, settings_list, cw,patch_size, model_dim, tau_temp,tau_inst, load_epoch, ft_epoch):\n",
    "    y = [x for x in settings_list if f'z' not in x]\n",
    "    y = [x for x in y if f'cw{cw}' in x]\n",
    "    y = [x for x in y if f'patch{patch_size}' in x]\n",
    "    y = [x for x in y if f'_D{model_dim}' in x]\n",
    "    y = [x for x in y if f'tau_temp{tau_temp}' in x]\n",
    "    y = [x for x in y if f'tau_inst{tau_inst}' in x]\n",
    "    try:\n",
    "        arch = os.path.join(PATH,y[0])\n",
    "    except:\n",
    "        return [999,999,999,999], [999,999,999,999]\n",
    "    #print(arch)\n",
    "    files = os.listdir(arch)\n",
    "    if len(files)==0:\n",
    "        return [999,999,999,999], [999,999,999,999]\n",
    "    files = [x for x in files if 'acc' in x]\n",
    "    files = [x for x in files if f'load_ep{load_epoch}' in x]\n",
    "    files = [x for x in files if f'ft_ep{ft_epoch}' in x]\n",
    "    idx_order = [int(x.split('_')[0].split('tw')[1]) for x in files]\n",
    "    files =  list(np.array(files)[np.argsort(idx_order)])\n",
    "    files = [os.path.join(x) for x in files]\n",
    "    files = [os.path.join(arch, f) for f in files]\n",
    "    mse_result = []\n",
    "    mae_result = []\n",
    "    for f in files:\n",
    "        try:\n",
    "            df = pd.read_csv(f)\n",
    "            mse_result.append(df['mse'][0])\n",
    "            mae_result.append(df['mae'][0])\n",
    "        except:\n",
    "            mse_result.append(999)\n",
    "            mae_result.append(999)\n",
    "    return mse_result, mae_result"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_summary(dataset_list, type_='hard'):\n",
    "    for data in dataset_list:\n",
    "        print('='*50)\n",
    "        PATH1 = '/home/seunghan9613/PatchTST_sim/PatchTST_self_supervised/saved_models/'\n",
    "        PATH2 = f'{data}2{data}'\n",
    "        PATH3 = f'masked_patchtst_sim_half_v3_mean_FC_wo_MTM_R/based_model/max'\n",
    "        PATH = os.path.join(PATH1, PATH2, PATH3)\n",
    "        settings = os.listdir(PATH)\n",
    "        \n",
    "        for share in [1]:\n",
    "            if share==0:\n",
    "                settings = [x for x in settings if 'no_share' in x]\n",
    "            else:\n",
    "                settings = [x for x in settings if 'no_share' not  in x]\n",
    "            \n",
    "            if type_ == 'hard':\n",
    "                settings = [x for x in settings if 'tau' not in x]\n",
    "            elif type_ == 'soft1':\n",
    "                settings = [x for x in settings if 'tau' in x]\n",
    "                settings = [x for x in settings if 'tau_inst' not in x]\n",
    "            elif type_ == 'soft2':\n",
    "                settings = [x for x in settings if 'tau' in x]\n",
    "                settings = [x for x in settings if 'tau_inst' in x]\n",
    "                \n",
    "            if type_ == 'hard':\n",
    "                for cw in [336,512,768]:\n",
    "                    for ps in [12,18,24]:\n",
    "                        for dim in [32,64,128,256]:\n",
    "                            for load_ep in [20,40,60,80,100,120,150]:\n",
    "                                for ft_ep in [20,40]:\n",
    "                                    mse_result, mae_result = get_mse_mae1(PATH = PATH, settings_list = settings,\n",
    "                                                                        cw = cw, patch_size=ps, \n",
    "                                                                        model_dim=dim, \n",
    "                                                                        load_epoch = load_ep, ft_epoch=ft_ep)\n",
    "                                    try:\n",
    "                                        if (mse_result[0] != 999):\n",
    "                                        #if (mse_result[0] != 999) & (len(mse_result)==4):\n",
    "                                            #print([ps,dim,load_ep,ft_ep],mse_result, mae_result)\n",
    "                                            print([cw,ps,dim,load_ep,ft_ep],np.mean(mse_result).round(3), mse_result)\n",
    "                                    except:\n",
    "                                        pass\n",
    "                \n",
    "            elif type_ == 'soft1':\n",
    "                for cw in [336,512,768]:\n",
    "                    for ps in [12,18,24]:\n",
    "                        print(f'-------- cw={cw},ps={ps} ---------')\n",
    "                        min_val = 999\n",
    "                        for dim in [32,64,128,256]:\n",
    "                            for tau1 in [1,3,5]:\n",
    "                                for load_ep in [20,40,60,80,100,120,150]:\n",
    "                                    for ft_ep in [20,40]:\n",
    "                                        mse_result, mae_result = get_mse_mae2(PATH = PATH, settings_list = settings,\n",
    "                                                                            cw=cw, patch_size=ps, \n",
    "                                                                            model_dim=dim, tau_temp=tau1,\n",
    "                                                                            load_epoch = load_ep, ft_epoch=ft_ep)\n",
    "                                        try:\n",
    "                                            if (mse_result[0] != 999):\n",
    "                                                if np.sum(mse_result)<min_val:\n",
    "                                                    print([cw,ps,dim,tau1,load_ep,ft_ep],np.mean(mse_result).round(3),mse_result)\n",
    "                                                    min_val = np.sum(mse_result)\n",
    "                                        except:\n",
    "                                            pass\n",
    "                \n",
    "            elif type_ == 'soft2':\n",
    "                for cw in [336,512,768]:\n",
    "                        \n",
    "                    for ps in [12,18,24]:\n",
    "                        for dim in [32,64,128,256]:\n",
    "                            for tau1 in [1,3,5]:\n",
    "                                for tau2 in [1,3,5]:\n",
    "                                    for load_ep in [20,40,60,80,100,120,150]:\n",
    "                                        for ft_ep in [20,40]:\n",
    "                                            mse_result, mae_result = get_mse_mae3(PATH = PATH,settings_list = settings,\n",
    "                                                                                cw=cw,patch_size=ps, \n",
    "                                                                                model_dim=dim, \n",
    "                                                                                tau_temp=tau1, tau_inst=tau2,\n",
    "                                                                                load_epoch = load_ep, ft_epoch=ft_ep)\n",
    "                                            try:\n",
    "                                                if (mse_result[0] != 999) & (len(mse_result)==4):\n",
    "                                                    #print([ps,dim,tau1,tau2,load_ep,ft_ep],mse_result, mae_result)\n",
    "                                                    print([cw,ps,dim,tau1,tau2,load_ep,ft_ep],np.mean(mse_result).round(3),mse_result)\n",
    "                                            except:\n",
    "                                                pass\n",
    "                            "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_summary_soft1(dataset_list, type_='soft1'):\n",
    "    for cw in [336,512,768]:\n",
    "        for ps in [12,18,24]:\n",
    "            print(f'-------- cw={cw},ps={ps} ---------')\n",
    "            for dim in [32,64,128,256]:\n",
    "                for tau1 in [1,3,5]:\n",
    "                    for load_ep in [40,60,80,100,120,150]:\n",
    "                        for ft_ep in [20,40]:\n",
    "                            print('='*40)\n",
    "                            print([cw,ps,dim,tau1,load_ep,ft_ep])\n",
    "                            for data in dataset_list:\n",
    "                                #print(f'============== {data} ==============')\n",
    "                                min_val = 999\n",
    "                                PATH1 = '/home/seunghan9613/PatchTST_sim/PatchTST_self_supervised/saved_models/'\n",
    "                                PATH2 = f'{data}2{data}'\n",
    "                                PATH3 = f'masked_patchtst_sim_half_v3_mean_FC_wo_MTM_R/based_model/max'\n",
    "                                PATH = os.path.join(PATH1, PATH2, PATH3)\n",
    "                                try:\n",
    "                                    settings = os.listdir(PATH)\n",
    "                                    \n",
    "                                    for share in [1]:\n",
    "                                        settings = [x for x in settings if 'no_share' not  in x]\n",
    "                                        settings = [x for x in settings if 'tau' in x]\n",
    "                                        settings = [x for x in settings if 'tau_inst' not in x]\n",
    "        \n",
    "                                        mse_result, mae_result = get_mse_mae2(PATH = PATH, settings_list = settings,\n",
    "                                                                            cw=cw, patch_size=ps, \n",
    "                                                                            model_dim=dim, tau_temp=tau1,\n",
    "                                                                            load_epoch = load_ep, ft_epoch=ft_ep)\n",
    "                                        try:\n",
    "                                            if len(mse_result)>0:\n",
    "                                                print(data, np.mean(mse_result).round(3), mse_result)\n",
    "                                            \n",
    "                                        except:\n",
    "                                            pass\n",
    "                                except:\n",
    "                                    pass\n",
    "                                "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_summary_hard(dataset_list, cw_list = [336,512,768],\n",
    "                      ps_list = [12,18,24], \n",
    "                      load_ep_list = [40,60,80,100,120,150],\n",
    "                      ft_ep_list = [20,40],\n",
    "                      type_='soft1'):\n",
    "    \n",
    "    \n",
    "    for cw in cw_list:\n",
    "        for ps in ps_list:\n",
    "            min_val = 999999\n",
    "            print('='*60)\n",
    "            print(f'-------- cw={cw},ps={ps} ---------')\n",
    "            print('='*60)\n",
    "            for dim in [32,64,128,256]:\n",
    "                for load_ep in load_ep_list:\n",
    "                    for ft_ep in ft_ep_list:\n",
    "                        #print('='*40)\n",
    "                        #print([cw,ps,dim,tau1,load_ep,ft_ep])\n",
    "                        data_sum = 0\n",
    "                        data_results = []\n",
    "                        for data in dataset_list:\n",
    "                            #print(f'============== {data} ==============')\n",
    "                            PATH1 = '/home/seunghan9613/PatchTST_sim/PatchTST_self_supervised/saved_models/'\n",
    "                            PATH2 = f'{data}2{data}'\n",
    "                            PATH3 = f'masked_patchtst_sim_half_v3_mean_FC_wo_MTM_R/based_model/max'\n",
    "                            PATH = os.path.join(PATH1, PATH2, PATH3)\n",
    "                            try:\n",
    "                                settings = os.listdir(PATH)\n",
    "                                \n",
    "                                for share in [1]:\n",
    "                                    settings = [x for x in settings if 'no_share' not  in x]\n",
    "                                    settings = [x for x in settings if 'tau' not in x]\n",
    "    \n",
    "                                    mse_result, mae_result = get_mse_mae1(PATH = PATH, settings_list = settings,\n",
    "                                                                        cw=cw, patch_size=ps, \n",
    "                                                                        model_dim=dim, \n",
    "                                                                        load_epoch = load_ep, ft_epoch=ft_ep)\n",
    "                                    \n",
    "                                    try:\n",
    "                                        if len(mse_result)>0:\n",
    "                                            #print(data, np.mean(mse_result).round(3), mse_result)\n",
    "                                            data_sum += np.sum(mse_result)\n",
    "                                            data_results.append([data, np.mean(mse_result).round(3), mse_result])\n",
    "                                        \n",
    "                                    except:\n",
    "                                        pass\n",
    "                                    \n",
    "                            except:\n",
    "                                pass\n",
    "                        \n",
    "                        #for k in data_results:\n",
    "                        #    print(k)\n",
    "                            \n",
    "                        if min_val > data_sum:\n",
    "                            min_val = data_sum\n",
    "                            print('-'*50)\n",
    "                            print([dim,load_ep,ft_ep])\n",
    "                            for k in data_results:\n",
    "                                print(k)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_summary_soft2(dataset_list, cw_list = [336,512,768],\n",
    "                      ps_list = [12,18,24], \n",
    "                      load_ep_list = [40,60,80,100,120,150],\n",
    "                      ft_ep_list = [20,40],\n",
    "                      type_='soft1'):\n",
    "    \n",
    "    \n",
    "    for cw in cw_list:\n",
    "        for ps in ps_list:\n",
    "            min_val = 999999\n",
    "            print('='*60)\n",
    "            print(f'-------- cw={cw},ps={ps} ---------')\n",
    "            print('='*60)\n",
    "            for dim in [32,64,128,256]:\n",
    "                for tau1 in [1,3,5]:\n",
    "                    for load_ep in load_ep_list:\n",
    "                        for ft_ep in ft_ep_list:\n",
    "                            #print('='*40)\n",
    "                            #print([cw,ps,dim,tau1,load_ep,ft_ep])\n",
    "                            data_sum = 0\n",
    "                            data_results = []\n",
    "                            for data in dataset_list:\n",
    "                                #print(f'============== {data} ==============')\n",
    "                                PATH1 = '/home/seunghan9613/PatchTST_sim/PatchTST_self_supervised/saved_models/'\n",
    "                                PATH2 = f'{data}2{data}'\n",
    "                                PATH3 = f'masked_patchtst_sim_half_v3_mean_FC_wo_MTM_R/based_model/max'\n",
    "                                PATH = os.path.join(PATH1, PATH2, PATH3)\n",
    "                                try:\n",
    "                                    settings = os.listdir(PATH)\n",
    "                                    \n",
    "                                    for share in [1]:\n",
    "                                        settings = [x for x in settings if 'no_share' not  in x]\n",
    "                                        settings = [x for x in settings if 'tau' in x]\n",
    "                                        settings = [x for x in settings if 'tau_inst' not in x]\n",
    "        \n",
    "                                        mse_result, mae_result = get_mse_mae2(PATH = PATH, settings_list = settings,\n",
    "                                                                            cw=cw, patch_size=ps, \n",
    "                                                                            model_dim=dim, tau_temp=tau1,\n",
    "                                                                            load_epoch = load_ep, ft_epoch=ft_ep)\n",
    "                                        \n",
    "                                        try:\n",
    "                                            if len(mse_result)>0:\n",
    "                                                #print(data, np.mean(mse_result).round(3), mse_result)\n",
    "                                                data_sum += np.sum(mse_result)\n",
    "                                                data_results.append([data, np.mean(mse_result).round(3), mse_result])\n",
    "                                            \n",
    "                                        except:\n",
    "                                            pass\n",
    "                                        \n",
    "                                except:\n",
    "                                    pass\n",
    "                            \n",
    "                            #for k in data_results:\n",
    "                            #    print(k)\n",
    "                                \n",
    "                            if min_val > data_sum:\n",
    "                                min_val = data_sum\n",
    "                                print('-'*50)\n",
    "                                print([dim,tau1,load_ep,ft_ep])\n",
    "                                for k in data_results:\n",
    "                                    print(k)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "============================================================\n",
      "-------- cw=336,ps=12 ---------\n",
      "============================================================\n",
      "--------------------------------------------------\n",
      "[32, 40, 20]\n",
      "['etth1', 999.0, [999, 999, 999, 999]]\n",
      "['etth2', 999.0, [999, 999, 999, 999]]\n",
      "============================================================\n",
      "-------- cw=336,ps=18 ---------\n",
      "============================================================\n",
      "--------------------------------------------------\n",
      "[32, 40, 20]\n",
      "['etth1', 999.0, [999, 999, 999, 999]]\n",
      "['etth2', 999.0, [999, 999, 999, 999]]\n",
      "============================================================\n",
      "-------- cw=336,ps=24 ---------\n",
      "============================================================\n",
      "--------------------------------------------------\n",
      "[32, 40, 20]\n",
      "['etth1', 999.0, [999, 999, 999, 999]]\n",
      "['etth2', 999.0, [999, 999, 999, 999]]\n",
      "============================================================\n",
      "-------- cw=512,ps=12 ---------\n",
      "============================================================\n",
      "--------------------------------------------------\n",
      "[32, 40, 20]\n",
      "['etth1', 999.0, [999, 999, 999, 999]]\n",
      "['etth2', 999.0, [999, 999, 999, 999]]\n",
      "============================================================\n",
      "-------- cw=512,ps=18 ---------\n",
      "============================================================\n",
      "--------------------------------------------------\n",
      "[32, 40, 20]\n",
      "['etth1', 999.0, [999, 999, 999, 999]]\n",
      "['etth2', 999.0, [999, 999, 999, 999]]\n",
      "--------------------------------------------------\n",
      "[128, 40, 20]\n",
      "['etth1', 0.71, [0.710824, 0.714722, 0.703709, 0.710498]]\n",
      "['etth2', 0.391, [0.372243, 0.383445, 0.38457, 0.424692]]\n",
      "--------------------------------------------------\n",
      "[128, 60, 20]\n",
      "['etth1', 0.71, [0.711737, 0.714334, 0.704519, 0.709207]]\n",
      "['etth2', 0.391, [0.371904, 0.383298, 0.384363, 0.424752]]\n",
      "--------------------------------------------------\n",
      "[128, 60, 40]\n",
      "['etth1', 0.709, [0.709252, 0.714221, 0.702818, 0.710309]]\n",
      "['etth2', 0.391, [0.372603, 0.383211, 0.384292, 0.424534]]\n",
      "--------------------------------------------------\n",
      "[128, 80, 40]\n",
      "['etth1', 0.709, [0.710257, 0.713525, 0.702713, 0.709581]]\n",
      "['etth2', 0.391, [0.372537, 0.382871, 0.38434, 0.424488]]\n",
      "--------------------------------------------------\n",
      "[128, 100, 20]\n",
      "['etth1', 0.709, [0.710351, 0.713935, 0.703633, 0.709209]]\n",
      "['etth2', 0.391, [0.371322, 0.383047, 0.384261, 0.424525]]\n",
      "--------------------------------------------------\n",
      "[128, 100, 40]\n",
      "['etth1', 0.709, [0.708831, 0.712677, 0.703819, 0.709011]]\n",
      "['etth2', 0.391, [0.371751, 0.383062, 0.384116, 0.42472]]\n",
      "--------------------------------------------------\n",
      "[128, 120, 20]\n",
      "['etth1', 0.708, [0.7083, 0.713067, 0.703112, 0.708941]]\n",
      "['etth2', 0.391, [0.371387, 0.383164, 0.384243, 0.424534]]\n",
      "--------------------------------------------------\n",
      "[128, 150, 20]\n",
      "['etth1', 0.709, [0.709727, 0.713602, 0.702901, 0.709994]]\n",
      "['etth2', 0.371, [0.37129]]\n",
      "============================================================\n",
      "-------- cw=512,ps=24 ---------\n",
      "============================================================\n",
      "--------------------------------------------------\n",
      "[32, 40, 20]\n",
      "['etth1', 999.0, [999, 999, 999, 999]]\n",
      "['etth2', 999.0, [999, 999, 999, 999]]\n",
      "============================================================\n",
      "-------- cw=768,ps=12 ---------\n",
      "============================================================\n",
      "--------------------------------------------------\n",
      "[32, 40, 20]\n",
      "['etth1', 999.0, [999, 999, 999, 999]]\n",
      "['etth2', 999.0, [999, 999, 999, 999]]\n",
      "============================================================\n",
      "-------- cw=768,ps=18 ---------\n",
      "============================================================\n",
      "--------------------------------------------------\n",
      "[32, 40, 20]\n",
      "['etth1', 999.0, [999, 999, 999, 999]]\n",
      "['etth2', 999.0, [999, 999, 999, 999]]\n",
      "============================================================\n",
      "-------- cw=768,ps=24 ---------\n",
      "============================================================\n",
      "--------------------------------------------------\n",
      "[32, 40, 20]\n",
      "['etth1', 999.0, [999, 999, 999, 999]]\n",
      "['etth2', 999.0, [999, 999, 999, 999]]\n"
     ]
    }
   ],
   "source": [
    "dataset_list = ['etth1','etth2']\n",
    "\n",
    "get_summary_hard(dataset_list)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "============================================================\n",
      "-------- cw=512,ps=18 ---------\n",
      "============================================================\n",
      "--------------------------------------------------\n",
      "[32, 1, 40, 20]\n",
      "['etth1', 0.71, [0.710529, 0.714454, 0.705124, 0.71031]]\n",
      "['etth2', 0.392, [0.37251, 0.383407, 0.385152, 0.424964]]\n",
      "--------------------------------------------------\n",
      "[32, 1, 60, 20]\n",
      "['etth1', 0.709, [0.707674, 0.7134, 0.706553, 0.709149]]\n",
      "['etth2', 0.391, [0.372611, 0.38334, 0.384953, 0.424919]]\n",
      "--------------------------------------------------\n",
      "[32, 1, 80, 20]\n",
      "['etth1', 0.709, [0.708363, 0.712998, 0.706118, 0.709853]]\n",
      "['etth2', 0.391, [0.372289, 0.383138, 0.384619, 0.424851]]\n",
      "--------------------------------------------------\n",
      "[32, 1, 100, 20]\n",
      "['etth1', 0.709, [0.708625, 0.71346, 0.705128, 0.710125]]\n",
      "['etth2', 0.391, [0.372027, 0.383101, 0.384716, 0.424897]]\n",
      "--------------------------------------------------\n",
      "[64, 3, 60, 20]\n",
      "['etth1', 0.709, [0.708979, 0.712257, 0.703241, 0.709613]]\n",
      "['etth2', 0.392, [0.372258, 0.383545, 0.386096, 0.424552]]\n",
      "--------------------------------------------------\n",
      "[64, 3, 80, 20]\n",
      "['etth1', 0.708, [0.708488, 0.712723, 0.702783, 0.709688]]\n",
      "['etth2', 0.391, [0.372075, 0.383331, 0.38537, 0.424449]]\n",
      "--------------------------------------------------\n",
      "[64, 3, 100, 20]\n",
      "['etth1', 0.709, [0.708353, 0.713538, 0.702823, 0.709908]]\n",
      "['etth2', 0.391, [0.371938, 0.383187, 0.384463, 0.424466]]\n",
      "--------------------------------------------------\n",
      "[64, 3, 100, 40]\n",
      "['etth1', 0.709, [0.708015, 0.713553, 0.703722, 0.70885]]\n",
      "['etth2', 0.391, [0.37224, 0.38309, 0.384253, 0.424885]]\n",
      "--------------------------------------------------\n",
      "[64, 5, 60, 40]\n",
      "['etth1', 0.708, [0.70771, 0.712589, 0.704027, 0.709041]]\n",
      "['etth2', 0.391, [0.372522, 0.383348, 0.384359, 0.424909]]\n",
      "--------------------------------------------------\n",
      "[64, 5, 120, 20]\n",
      "['etth1', 0.708, [0.708704, 0.712466, 0.702497, 0.709785]]\n",
      "['etth2', 0.391, [0.37219, 0.3829, 0.38464, 0.42445]]\n",
      "--------------------------------------------------\n",
      "[128, 1, 100, 40]\n",
      "['etth1', 0.708, [0.708714, 0.712068, 0.702918, 0.708227]]\n",
      "['etth2', 0.391, [0.370644, 0.38303, 0.384832, 0.425429]]\n",
      "--------------------------------------------------\n",
      "[128, 1, 150, 20]\n",
      "['etth1', 0.71, [0.712619, 0.713197, 0.705215, 0.70947]]\n",
      "['etth2', 0.371, [0.371169]]\n",
      "--------------------------------------------------\n",
      "[128, 1, 150, 40]\n",
      "['etth1', 0.709, [0.711049, 0.711887, 0.702829, 0.708364]]\n"
     ]
    }
   ],
   "source": [
    "dataset_list = ['etth1','etth2']\n",
    "\n",
    "get_summary_soft2(dataset_list,\n",
    "                 cw_list = [512],ps_list = [18])\n",
    "                  "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "============================================================\n",
      "-------- cw=512,ps=18 ---------\n",
      "============================================================\n",
      "--------------------------------------------------\n",
      "[32, 40, 20]\n",
      "['etth1', 999.0, [999, 999, 999, 999]]\n",
      "['etth2', 999.0, [999, 999, 999, 999]]\n",
      "--------------------------------------------------\n",
      "[128, 40, 20]\n",
      "['etth1', 0.71, [0.710824, 0.714722, 0.703709, 0.710498]]\n",
      "['etth2', 0.391, [0.372243, 0.383445, 0.38457, 0.424692]]\n",
      "--------------------------------------------------\n",
      "[128, 60, 20]\n",
      "['etth1', 0.71, [0.711737, 0.714334, 0.704519, 0.709207]]\n",
      "['etth2', 0.391, [0.371904, 0.383298, 0.384363, 0.424752]]\n",
      "--------------------------------------------------\n",
      "[128, 60, 40]\n",
      "['etth1', 0.709, [0.709252, 0.714221, 0.702818, 0.710309]]\n",
      "['etth2', 0.391, [0.372603, 0.383211, 0.384292, 0.424534]]\n",
      "--------------------------------------------------\n",
      "[128, 80, 40]\n",
      "['etth1', 0.709, [0.710257, 0.713525, 0.702713, 0.709581]]\n",
      "['etth2', 0.391, [0.372537, 0.382871, 0.38434, 0.424488]]\n",
      "--------------------------------------------------\n",
      "[128, 100, 20]\n",
      "['etth1', 0.709, [0.710351, 0.713935, 0.703633, 0.709209]]\n",
      "['etth2', 0.391, [0.371322, 0.383047, 0.384261, 0.424525]]\n",
      "--------------------------------------------------\n",
      "[128, 100, 40]\n",
      "['etth1', 0.709, [0.708831, 0.712677, 0.703819, 0.709011]]\n",
      "['etth2', 0.391, [0.371751, 0.383062, 0.384116, 0.42472]]\n",
      "--------------------------------------------------\n",
      "[128, 120, 20]\n",
      "['etth1', 0.708, [0.7083, 0.713067, 0.703112, 0.708941]]\n",
      "['etth2', 0.391, [0.371387, 0.383164, 0.384243, 0.424534]]\n",
      "--------------------------------------------------\n",
      "[128, 150, 20]\n",
      "['etth1', 0.709, [0.709727, 0.713602, 0.702901, 0.709994]]\n",
      "['etth2', 0.371, [0.37129]]\n"
     ]
    }
   ],
   "source": [
    "dataset_list = ['etth1','etth2']\n",
    "\n",
    "get_summary_hard(dataset_list,\n",
    "                 cw_list = [512],ps_list = [18])\n",
    "                  "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "============================================================\n",
      "-------- cw=768,ps=24 ---------\n",
      "============================================================\n",
      "--------------------------------------------------\n",
      "[32, 40, 20]\n",
      "['ettm1', 0.7, [0.690315, 0.691535, 0.70281, 0.715614]]\n",
      "--------------------------------------------------\n",
      "[32, 40, 40]\n",
      "['ettm1', 0.698, [0.684605, 0.693846, 0.701362, 0.714091]]\n",
      "--------------------------------------------------\n",
      "[32, 100, 20]\n",
      "['ettm1', 0.694, [0.687236, 0.692099, 0.701211]]\n",
      "--------------------------------------------------\n",
      "[32, 120, 20]\n"
     ]
    }
   ],
   "source": [
    "#dataset_list = ['etth1','etth2']\n",
    "dataset_list = ['ettm1']\n",
    "\n",
    "get_summary_hard(dataset_list,\n",
    "                 cw_list = [768],ps_list = [24])\n",
    "                  "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "============================================================\n",
      "-------- cw=768,ps=24 ---------\n",
      "============================================================\n",
      "--------------------------------------------------\n",
      "[32, 1, 40, 20]\n",
      "['ettm1', 0.699, [0.688356, 0.692315, 0.701, 0.71465]]\n",
      "--------------------------------------------------\n",
      "[32, 1, 80, 20]\n",
      "['ettm1', 0.699, [0.687071, 0.692139, 0.70075, 0.71433]]\n",
      "--------------------------------------------------\n",
      "[32, 1, 80, 40]\n",
      "['ettm1', 0.698, [0.686488, 0.69136, 0.700591, 0.714112]]\n",
      "--------------------------------------------------\n",
      "[32, 1, 120, 20]\n"
     ]
    }
   ],
   "source": [
    "#dataset_list = ['etth1','etth2']\n",
    "dataset_list = ['ettm1']\n",
    "\n",
    "get_summary_soft2(dataset_list,\n",
    "                 cw_list = [768],ps_list = [24])\n",
    "                  "
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Share ( CI )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 78,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 반드시 같아야\n",
    "- ag_list\n",
    "- load_epoch_list\n",
    "- data_struc_list\n",
    "- arch_list\n",
    "- type_list\n",
    "- tau_list"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 달라도 OK\n",
    "- target_dat_list\n",
    "- finetune_epoch_list"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 77,
   "metadata": {},
   "outputs": [],
   "source": [
    "best_acc_summary = 0\n",
    "best_acc = 0\n",
    "\n",
    "acc_list = []\n",
    "acc_total_list = []\n",
    "struc_list = []\n",
    "\n",
    "for ps in [18]:\n",
    "    for dim in [32,64,128,256]:\n",
    "        for tau1 in [1,3,5]:\n",
    "            for load_ep in [20,40,60,80,100,120,150]:\n",
    "                for ft_ep in [20,40]:\n",
    "                    mse_result, mae_result = get_result_SOFT_temp(patch_size=ps, \n",
    "                                                                  model_dim=dim, tau_temp=tau1,\n",
    "                                                                  load_epoch = load_ep, ft_epoch=ft_ep)\n",
    "                \n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 73,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "([0.366243, 0.396837, 0.419621, 0.444178],\n",
       " [0.392938, 0.411791, 0.426909, 0.46106])"
      ]
     },
     "execution_count": 73,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "get_result_SOFT_temp(patch_size=18, model_dim=32, tau_temp=5.0,\n",
    "                     load_epoch = 120, ft_epoch=40)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "SUB_PATH = os.path.join(PATH, settings_share[1])\n",
    "\n",
    "subsettings = os.listdir(SUB_PATH)\n",
    "subsettings = sorted([x for x in subsettings if 'acc.csv' in x])\n",
    "subsettings_96 = [x for x in subsettings if '96' in x]\n",
    "subsettings_192 = [x for x in subsettings if '192' in x]\n",
    "subsettings_336 = [x for x in subsettings if '336' in x]\n",
    "subsettings_720 = [x for x in subsettings if '720' in x]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0    0.402523\n",
       "Name: mse, dtype: float64"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "pd.read_csv(os.path.join(SUB_PATH,subsettings_192[0]))['mse']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ds_pretrain_list = ['etth1','etth2','ettm1','ettm2']\n",
    "\n",
    "ep_pretrain = 150\n",
    "\n",
    "tau_list = [1,3,5]\n",
    "\n",
    "m = 0.5\n",
    "\n",
    "patch_len = 12\n",
    "stride = 12\n",
    "\n",
    "p = 0\n",
    "\n",
    "for s in [0,1]:\n",
    "    for ds_pretraion in ds_pretrain_list:\n",
    "        for d_model in [32,64,128,256]:\n",
    "            !python patchtst_pretrain_sim_half_v3_mean_FC.py \\\n",
    "                --device_id {device} \\\n",
    "                --dset_pretrain {ds_pretrain} \\\n",
    "                --n_epochs_pretrain {ep_pretrain} \\\n",
    "                --reverse 1 \\\n",
    "                --mask_ratio {m} \\\n",
    "                --share {s} \\\n",
    "                --d_model {d_model} --patch_len {patch_len} --stride {stride} --permute {p} \n",
    "            \n",
    "            for tau in tau_list:\n",
    "                !python patchtst_pretrain_sim_half_v3_mean_FC.py \\\n",
    "                --device_id {device} \\\n",
    "                --dset_pretrain {ds_pretrain} \\\n",
    "                --n_epochs_pretrain {ep_pretrain} \\\n",
    "                --reverse 1 \\\n",
    "                --mask_ratio {m} \\\n",
    "                --share {s} \\\n",
    "                --tau_temp {tau2} \\\n",
    "                --d_model {d_model} --patch_len {patch_len} --stride {stride} --permute {p} \n",
    "            \n",
    "            for tau1 in tau_list:\n",
    "                for tau2 in tau_list:\n",
    "                    !python patchtst_pretrain_sim_half_v3_mean_FC.py \\\n",
    "                        --device_id {device} \\\n",
    "                        --dset_pretrain {ds_pretrain} \\\n",
    "                        --n_epochs_pretrain {ep_pretrain} \\\n",
    "                        --reverse 1 \\\n",
    "                        --mask_ratio {m} \\\n",
    "                        --share {s} \\\n",
    "                        --tau_inst {tau2} \\\n",
    "                        --tau_temp {tau2} \\\n",
    "                        --d_model {d_model} --patch_len {patch_len} --stride {stride} --permute {p} \n",
    "                \n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%capture\n",
    "# permute (X) + hard CL\n",
    "ds_pretrain = 'etth1'\n",
    "d = 'etth1'\n",
    "m = 0.5\n",
    "\n",
    "for _ in range(3):\n",
    "    for s in [0,1]:\n",
    "        for d_model in [64,128,256]:\n",
    "            for load in [20,40,60,80,100]:\n",
    "                for tp in [96,192,336,720]:\n",
    "                    for ep_ft_head in [10,20,30]:\n",
    "                        ep_ft_entire = ep_ft_head * 2\n",
    "                        !python patchtst_finetune_sim_half_v3_mean_FC.py \\\n",
    "                            --is_finetune 1 \\\n",
    "                            --target_points {tp} \\\n",
    "                            --device_id {device} \\\n",
    "                            --dset_pretrain {d} \\\n",
    "                            --dset_finetune {d} \\\n",
    "                            --n_epochs_finetune_head {ep_ft_head} \\\n",
    "                            --n_epochs_finetune_entire {ep_ft_entire} \\\n",
    "                            --n_epochs_pretrain {ep_pretrain} \\\n",
    "                            --reverse 1\\\n",
    "                            --n_epochs_load {load} \\\n",
    "                            --mask_ratio {m} --share {s}  --d_model {d_model} --permute {p}                    \n",
    "            "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "settings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['tw192_ft_ep20_model1_load_ep100_acc.csv',\n",
       " 'tw192_ft_ep20_model1_load_ep20_acc.csv',\n",
       " 'tw192_ft_ep20_model1_load_ep40_acc.csv',\n",
       " 'tw192_ft_ep20_model1_load_ep60_acc.csv',\n",
       " 'tw192_ft_ep20_model1_load_ep80_acc.csv',\n",
       " 'tw192_ft_ep40_model1_load_ep100_acc.csv',\n",
       " 'tw192_ft_ep40_model1_load_ep20_acc.csv',\n",
       " 'tw192_ft_ep40_model1_load_ep40_acc.csv',\n",
       " 'tw192_ft_ep40_model1_load_ep60_acc.csv',\n",
       " 'tw192_ft_ep40_model1_load_ep80_acc.csv',\n",
       " 'tw192_ft_ep60_model1_load_ep100_acc.csv',\n",
       " 'tw192_ft_ep60_model1_load_ep20_acc.csv',\n",
       " 'tw192_ft_ep60_model1_load_ep40_acc.csv',\n",
       " 'tw192_ft_ep60_model1_load_ep60_acc.csv',\n",
       " 'tw192_ft_ep60_model1_load_ep80_acc.csv',\n",
       " 'tw336_ft_ep20_model1_load_ep100_acc.csv',\n",
       " 'tw336_ft_ep20_model1_load_ep20_acc.csv',\n",
       " 'tw336_ft_ep20_model1_load_ep40_acc.csv',\n",
       " 'tw336_ft_ep20_model1_load_ep60_acc.csv',\n",
       " 'tw336_ft_ep20_model1_load_ep80_acc.csv',\n",
       " 'tw336_ft_ep40_model1_load_ep100_acc.csv',\n",
       " 'tw336_ft_ep40_model1_load_ep20_acc.csv',\n",
       " 'tw336_ft_ep40_model1_load_ep40_acc.csv',\n",
       " 'tw336_ft_ep40_model1_load_ep60_acc.csv',\n",
       " 'tw336_ft_ep40_model1_load_ep80_acc.csv',\n",
       " 'tw336_ft_ep60_model1_load_ep100_acc.csv',\n",
       " 'tw336_ft_ep60_model1_load_ep20_acc.csv',\n",
       " 'tw336_ft_ep60_model1_load_ep40_acc.csv',\n",
       " 'tw336_ft_ep60_model1_load_ep60_acc.csv',\n",
       " 'tw336_ft_ep60_model1_load_ep80_acc.csv',\n",
       " 'tw720_ft_ep20_model1_load_ep100_acc.csv',\n",
       " 'tw720_ft_ep20_model1_load_ep20_acc.csv',\n",
       " 'tw720_ft_ep20_model1_load_ep40_acc.csv',\n",
       " 'tw720_ft_ep20_model1_load_ep60_acc.csv',\n",
       " 'tw720_ft_ep20_model1_load_ep80_acc.csv',\n",
       " 'tw720_ft_ep40_model1_load_ep100_acc.csv',\n",
       " 'tw720_ft_ep40_model1_load_ep20_acc.csv',\n",
       " 'tw720_ft_ep40_model1_load_ep40_acc.csv',\n",
       " 'tw720_ft_ep40_model1_load_ep60_acc.csv',\n",
       " 'tw720_ft_ep40_model1_load_ep80_acc.csv',\n",
       " 'tw720_ft_ep60_model1_load_ep100_acc.csv',\n",
       " 'tw720_ft_ep60_model1_load_ep20_acc.csv',\n",
       " 'tw720_ft_ep60_model1_load_ep40_acc.csv',\n",
       " 'tw720_ft_ep60_model1_load_ep60_acc.csv',\n",
       " 'tw720_ft_ep60_model1_load_ep80_acc.csv',\n",
       " 'tw96_ft_ep20_model1_load_ep100_acc.csv',\n",
       " 'tw96_ft_ep20_model1_load_ep20_acc.csv',\n",
       " 'tw96_ft_ep20_model1_load_ep40_acc.csv',\n",
       " 'tw96_ft_ep20_model1_load_ep60_acc.csv',\n",
       " 'tw96_ft_ep20_model1_load_ep80_acc.csv',\n",
       " 'tw96_ft_ep40_model1_load_ep100_acc.csv',\n",
       " 'tw96_ft_ep40_model1_load_ep20_acc.csv',\n",
       " 'tw96_ft_ep40_model1_load_ep40_acc.csv',\n",
       " 'tw96_ft_ep40_model1_load_ep60_acc.csv',\n",
       " 'tw96_ft_ep40_model1_load_ep80_acc.csv',\n",
       " 'tw96_ft_ep60_model1_load_ep100_acc.csv',\n",
       " 'tw96_ft_ep60_model1_load_ep20_acc.csv',\n",
       " 'tw96_ft_ep60_model1_load_ep40_acc.csv',\n",
       " 'tw96_ft_ep60_model1_load_ep60_acc.csv',\n",
       " 'tw96_ft_ep60_model1_load_ep80_acc.csv']"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "subsettings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0    0.397272\n",
       "Name: mse, dtype: float64"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "pd.read_csv(os.path.join(SUB_PATH,subsettings_192[0]))['mse']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "ssl_ts",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.11"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
