{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import os"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "setting = 'etth22etth1'\n",
    "BASE_PATH = f'/home/seunghan9613/PatchTST_sim/PatchTST_self_supervised/saved_models/{setting}/masked_patchtst_sim_half_v3_mean_FC2_sep_R/based_model/max'\n",
    "# os.listdir returns entries in arbitrary order; sort so the selected run\n",
    "# directory is the same on every execution.\n",
    "PATH = os.path.join(BASE_PATH, sorted(os.listdir(BASE_PATH))[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "# Files in PATH whose names contain 'acc', in sorted (reproducible) order.\n",
    "results = [os.path.join(PATH, x) for x in sorted(os.listdir(PATH)) if 'acc' in x]\n",
    "# Match the 'tw<horizon>' tag (as get_result does) against the file name only:\n",
    "# a bare '96' also matches the '9613' in the home-directory part of the full\n",
    "# path, which would put every file into results1.\n",
    "results1 = [x for x in results if 'tw96' in os.path.basename(x)]\n",
    "results2 = [x for x in results if 'tw192' in os.path.basename(x)]\n",
    "results3 = [x for x in results if 'tw336' in os.path.basename(x)]\n",
    "results4 = [x for x in results if 'tw720' in os.path.basename(x)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_result(setting, n_runs=4,\n",
    "               root='/home/seunghan9613/PatchTST_sim/PatchTST_self_supervised/saved_models'):\n",
    "    \"\"\"Print the '_lp_' result rows of the first run directory of `setting`.\n",
    "\n",
    "    NOTE: a later cell redefines get_result to cover every run directory;\n",
    "    once that cell has run, it replaces this definition.\n",
    "\n",
    "    The first entry (in sorted order) of\n",
    "    ``{root}/{setting}/masked_patchtst_sim_half_v3_mean_FC2_sep_R/based_model/max``\n",
    "    is used. Its files whose names contain both 'acc' and '_lp_' are grouped\n",
    "    by their 'tw96' / 'tw192' / 'tw336' / 'tw720' tag; for up to `n_runs`\n",
    "    runs, the first row of each tag's CSV is printed.\n",
    "\n",
    "    Args:\n",
    "        setting: name of the experiment folder, e.g. 'etth22etth1'.\n",
    "        n_runs: maximum number of runs printed (default 4).\n",
    "        root: folder that contains the per-setting result folders.\n",
    "    \"\"\"\n",
    "    tags = ('tw96', 'tw192', 'tw336', 'tw720')\n",
    "    base = os.path.join(root, setting, 'masked_patchtst_sim_half_v3_mean_FC2_sep_R',\n",
    "                        'based_model', 'max')\n",
    "    # os.listdir order is arbitrary; sort so the chosen directory is reproducible.\n",
    "    path = os.path.join(base, sorted(os.listdir(base))[0])\n",
    "    # Filter on the file name, not the full path, so directory names cannot\n",
    "    # produce spurious matches.\n",
    "    names = sorted(f for f in os.listdir(path) if 'acc' in f and '_lp_' in f)\n",
    "    groups = [[os.path.join(path, f) for f in names if tag in f] for tag in tags]\n",
    "    # zip stops at the shortest group instead of raising IndexError when a\n",
    "    # horizon has fewer than n_runs files.\n",
    "    for run_files in list(zip(*groups))[:n_runs]:\n",
    "        rows = [pd.read_csv(p).values[0] for p in run_files]\n",
    "        for row in rows:\n",
    "            print(row)\n",
    "        print('-----------')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_result(setting, n_runs=4,\n",
    "               root='/home/seunghan9613/PatchTST_sim/PatchTST_self_supervised/saved_models'):\n",
    "    \"\"\"Print the '_lp_' result rows for every run directory of `setting`.\n",
    "\n",
    "    Each sub-directory of\n",
    "    ``{root}/{setting}/masked_patchtst_sim_half_v3_mean_FC2_sep_R/based_model/max``\n",
    "    is visited in sorted order. Its files whose names contain both 'acc' and\n",
    "    '_lp_' are grouped by their 'tw96' / 'tw192' / 'tw336' / 'tw720' tag, and\n",
    "    the i-th file of every group forms run i. For up to `n_runs` runs, the\n",
    "    first row of each CSV is printed rounded to 3 decimals, followed by the\n",
    "    element-wise mean over the four tags.\n",
    "\n",
    "    Args:\n",
    "        setting: name of the experiment folder, e.g. 'etth12ettm1'.\n",
    "        n_runs: maximum number of runs printed per directory (default 4).\n",
    "        root: folder that contains the per-setting result folders.\n",
    "\n",
    "    Entries that are not directories, or that have no complete run, are\n",
    "    skipped. Unreadable directories and CSVs are reported instead of being\n",
    "    silently swallowed by a bare `except`.\n",
    "    \"\"\"\n",
    "    tags = ('tw96', 'tw192', 'tw336', 'tw720')\n",
    "    base = os.path.join(root, setting, 'masked_patchtst_sim_half_v3_mean_FC2_sep_R',\n",
    "                        'based_model', 'max')\n",
    "    # os.listdir order is arbitrary; sort so the printout order is reproducible.\n",
    "    for run_dir in sorted(os.listdir(base)):\n",
    "        path = os.path.join(base, run_dir)\n",
    "        if not os.path.isdir(path):\n",
    "            continue\n",
    "        try:\n",
    "            entries = os.listdir(path)\n",
    "        except OSError as err:\n",
    "            print(f'skipping {path}: {err}')\n",
    "            continue\n",
    "        # Filter on the file name, not the full path, so directory names cannot\n",
    "        # produce spurious matches.\n",
    "        names = sorted(f for f in entries if 'acc' in f and '_lp_' in f)\n",
    "        groups = [[os.path.join(path, f) for f in names if tag in f] for tag in tags]\n",
    "        # zip stops at the shortest group, so a horizon with fewer files no\n",
    "        # longer raises IndexError half-way through the printout.\n",
    "        runs = list(zip(*groups))[:n_runs]\n",
    "        if not runs:\n",
    "            continue\n",
    "        for run_files in runs:\n",
    "            try:\n",
    "                rows = [pd.read_csv(p).values[0] for p in run_files]\n",
    "            except (OSError, ValueError, IndexError) as err:\n",
    "                # ValueError covers pandas' EmptyDataError / ParserError;\n",
    "                # IndexError comes from .values[0] on a CSV with no data rows.\n",
    "                print(f'skipping {run_files}: {err}')\n",
    "                continue\n",
    "            for row in rows:\n",
    "                print(row.round(3))\n",
    "            print((sum(rows) / len(rows)).round(3))\n",
    "            print('-----------')\n",
    "        print('==========================')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[0.305 0.348]\n",
      "[0.339 0.368]\n",
      "[0.367 0.383]\n",
      "[0.422 0.414]\n",
      "[0.358 0.378]\n",
      "-----------\n",
      "[0.305 0.347]\n",
      "[0.337 0.366]\n",
      "[0.368 0.384]\n",
      "[0.423 0.415]\n",
      "[0.358 0.378]\n",
      "-----------\n",
      "[0.302 0.346]\n",
      "[0.335 0.365]\n",
      "[0.367 0.384]\n",
      "[0.424 0.415]\n",
      "[0.357 0.377]\n",
      "-----------\n",
      "[0.304 0.347]\n",
      "[0.337 0.366]\n",
      "[0.368 0.384]\n",
      "[0.421 0.414]\n",
      "[0.358 0.378]\n",
      "-----------\n",
      "==========================\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "'\\n[0.301 0.353]\\n[0.341 0.377]\\n[0.364 0.39 ]\\n[0.404 0.417]\\n[0.353 0.384]\\n'"
      ]
     },
     "execution_count": 29,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Other settings used in this notebook: etth22etth1, ettm12etth1, ettm22etth1,\n",
    "# etth22ettm1, ettm22ettm1\n",
    "setting = 'etth12ettm1'\n",
    "get_result(setting)\n",
    "\n",
    "# Reference numbers recorded for this setting (their source is not stated in\n",
    "# the notebook -- TODO confirm). Kept as comments so they are not echoed as an\n",
    "# escaped string repr in the cell output.\n",
    "# [0.301 0.353]\n",
    "# [0.341 0.377]\n",
    "# [0.364 0.39 ]\n",
    "# [0.404 0.417]\n",
    "# [0.353 0.384]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[0.305 0.347]\n",
      "[0.337 0.365]\n",
      "[0.367 0.383]\n",
      "[0.421 0.413]\n",
      "[0.357 0.377]\n",
      "-----------\n",
      "[0.305 0.347]\n",
      "[0.337 0.366]\n",
      "[0.368 0.384]\n",
      "[0.422 0.414]\n",
      "[0.358 0.378]\n",
      "-----------\n",
      "[0.304 0.346]\n",
      "[0.336 0.365]\n",
      "[0.367 0.383]\n",
      "[0.423 0.414]\n",
      "[0.357 0.377]\n",
      "-----------\n",
      "[0.305 0.347]\n",
      "[0.337 0.366]\n",
      "[0.368 0.384]\n",
      "[0.422 0.414]\n",
      "[0.358 0.378]\n",
      "-----------\n",
      "==========================\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "'\\n[0.3   0.354]\\n[0.335 0.375]\\n[0.361 0.393]\\n[0.403 0.417]\\n[0.35  0.385]\\n'"
      ]
     },
     "execution_count": 30,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Other settings used in this notebook: etth22etth1, ettm12etth1, ettm22etth1,\n",
    "# etth12ettm1, ettm22ettm1\n",
    "setting = 'etth22ettm1'\n",
    "get_result(setting)\n",
    "\n",
    "# Reference numbers recorded for this setting (their source is not stated in\n",
    "# the notebook -- TODO confirm). Kept as comments so they are not echoed as an\n",
    "# escaped string repr in the cell output.\n",
    "# [0.3   0.354]\n",
    "# [0.335 0.375]\n",
    "# [0.361 0.393]\n",
    "# [0.403 0.417]\n",
    "# [0.35  0.385]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "ssl_ts",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.11"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
