{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import pandas as pd"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
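    "# MLP\n",
    "\n",
    "Averaged downstream MSE / MAE (patch length 12) for each model dimension (`dim`), head-dropout ratio (`do`) and `ft_ep` setting. See also:\n",
    "\n",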
    "- `results_with_FC2_sep.ipynb`\n",
    "- `results_with_FC2_wo_CL.ipynb`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 52,
   "metadata": {},
   "outputs": [],
   "source": [
    "def MLP_data_input_output_patchlen(data, dim, do, patch_len=12, ft_epochs=(10, 20, 40, 60),\n",
    "                                   base_dir='/home/seunghan9613/PatchTST_sim/PatchTST_self_supervised/saved_models'):\n",
    "    \"\"\"Print averaged downstream MSE/MAE for one (dataset, dim, head-dropout) MLP run.\n",
    "\n",
    "    Looks up the results directory\n",
    "    ``{base_dir}/{data}2{data}/XY_ablation_FC2_downstream_no_doO_O/based_model/max``,\n",
    "    picks the run sub-directory matching ``patch_len``, ``dim`` and ``do``, and for each\n",
    "    entry of ``ft_epochs`` averages the ``mse``/``mae`` columns over every matching\n",
    "    ``*acc.csv`` file, printing one line ``<n_files> --- ft=<ft>: mse=...,mae=...``.\n",
    "    ``ft`` values with no result files are skipped silently.\n",
    "\n",
    "    Args:\n",
    "        data: dataset name, used as ``{data}2{data}`` in the path (e.g. 'ettm2').\n",
    "        dim: model dimension, matched against ``_D{dim}_`` in the run directory name.\n",
    "        do: head dropout ratio. 0.2 selects the run directory without a ``head_drop``\n",
    "            tag; any other value selects the one containing ``head_drop_{do}``.\n",
    "        patch_len: patch length, matched against ``patch{patch_len}_``.\n",
    "        ft_epochs: ``ft`` values to report, matched against ``ft_ep{ft}`` in file names.\n",
    "        base_dir: root ``saved_models`` directory (defaults to the original absolute path).\n",
    "\n",
    "    Raises:\n",
    "        FileNotFoundError: if the results directory is missing, or no run directory\n",
    "            matches ``patch_len``, ``dim`` and ``do``.\n",
    "    \"\"\"\n",
    "    input_ = 'O'\n",
    "    output_ = 'O'\n",
    "\n",
    "    print('=' * 50)\n",
    "    print(output_, 'patch size=', patch_len, 'dim=', dim)\n",
    "    print('=' * 50)\n",
    "\n",
    "    root = os.path.join(base_dir, f'{data}2{data}',\n",
    "                        f'XY_ablation_FC2_downstream_no_do{input_}_{output_}',\n",
    "                        'based_model', 'max')\n",
    "    # Sort so the choice below is deterministic if several directories match.\n",
    "    candidates = sorted(x for x in os.listdir(root)\n",
    "                        if f'patch{patch_len}_' in x and f'_D{dim}_' in x)\n",
    "    if do == 0.2:\n",
    "        # The do=0.2 run is the one whose directory name has no 'head_drop' tag.\n",
    "        candidates = [x for x in candidates if 'head_drop' not in x]\n",
    "    else:\n",
    "        candidates = [x for x in candidates if f'head_drop_{do}' in x]\n",
    "    if not candidates:\n",
    "        raise FileNotFoundError(\n",
    "            f'no run directory for patch_len={patch_len}, dim={dim}, do={do} in {root}')\n",
    "    run_dir = os.path.join(root, candidates[0])\n",
    "    acc_files = [x for x in os.listdir(run_dir) if 'acc.csv' in x]\n",
    "\n",
    "    for ft in ft_epochs:\n",
    "        ft_files = [x for x in acc_files if f'ft_ep{ft}' in x]\n",
    "        if not ft_files:\n",
    "            continue  # no result files for this ft value; nothing to report\n",
    "        try:\n",
    "            frames = [pd.read_csv(os.path.join(run_dir, name)) for name in ft_files]\n",
    "            # Average only the printed metrics so extra non-numeric columns cannot break mean().\n",
    "            result = pd.concat(frames, axis=0)[['mse', 'mae']].mean(axis=0)\n",
    "        except (OSError, KeyError, ValueError) as e:\n",
    "            # Report unreadable or malformed result files instead of hiding them.\n",
    "            print(f'{len(ft_files)} --- ft={ft}: skipped ({type(e).__name__}: {e})')\n",
    "            continue\n",
    "        mse = result['mse']\n",
    "        mae = result['mae']\n",
    "        print(f'{len(ft_files)} --- ft={ft}: mse={mse.round(3)},mae={mae.round(3)}')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 53,
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset = 'ettm2'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 58,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "==================================================\n",
      "O patch size= 12 dim= 32\n",
      "==================================================\n",
      "4 --- ft=10: mse=0.254,mae=0.313\n",
      "4 --- ft=20: mse=0.254,mae=0.312\n",
      "4 --- ft=40: mse=0.253,mae=0.312\n",
      "4 --- ft=60: mse=0.254,mae=0.313\n",
      "==================================================\n",
      "O patch size= 12 dim= 64\n",
      "==================================================\n",
      "4 --- ft=10: mse=0.254,mae=0.313\n",
      "4 --- ft=20: mse=0.254,mae=0.313\n",
      "4 --- ft=40: mse=0.254,mae=0.313\n",
      "4 --- ft=60: mse=0.255,mae=0.314\n",
      "==================================================\n",
      "O patch size= 12 dim= 128\n",
      "==================================================\n",
      "4 --- ft=10: mse=0.256,mae=0.316\n",
      "4 --- ft=20: mse=0.253,mae=0.313\n",
      "4 --- ft=40: mse=0.253,mae=0.313\n",
      "4 --- ft=60: mse=0.252,mae=0.312\n",
      "====================================================================================================\n"
     ]
    }
   ],
   "source": [
    "# ettm2 with the default head dropout (0.2), one summary per model dimension.\n",
    "dataset = 'ettm2'\n",
    "do_ratio = 0.2\n",
    "\n",
    "for dim in (32, 64, 128):\n",
    "    MLP_data_input_output_patchlen(data=dataset, dim=dim, do=do_ratio)\n",
    "print('=' * 100)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 81,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "==================================================\n",
      "O patch size= 12 dim= 8\n",
      "==================================================\n",
      "4 --- ft=10: mse=0.341,mae=0.389\n",
      "4 --- ft=20: mse=0.342,mae=0.39\n",
      "4 --- ft=40: mse=0.347,mae=0.391\n",
      "4 --- ft=60: mse=0.351,mae=0.395\n",
      "====================================================================================================\n"
     ]
    }
   ],
   "source": [
    "dataset = 'etth2'\n",
    "do_ratio = 0.0\n",
    "\n",
    "# Only dim=8 is reported for etth2; add 4 / 6 to the tuple to run those dims too.\n",
    "for dim in (8,):\n",
    "    MLP_data_input_output_patchlen(data=dataset, dim=dim, do=do_ratio)\n",
    "print('=' * 100)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 56,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "==================================================\n",
      "O patch size= 12 dim= 32\n",
      "==================================================\n",
      "4 --- ft=10: mse=0.26,mae=0.322\n",
      "4 --- ft=20: mse=0.258,mae=0.319\n",
      "4 --- ft=40: mse=0.258,mae=0.32\n",
      "4 --- ft=60: mse=0.257,mae=0.318\n",
      "==================================================\n",
      "O patch size= 12 dim= 64\n",
      "==================================================\n",
      "4 --- ft=10: mse=0.257,mae=0.318\n",
      "4 --- ft=20: mse=0.258,mae=0.32\n",
      "4 --- ft=40: mse=0.258,mae=0.32\n",
      "4 --- ft=60: mse=0.257,mae=0.318\n",
      "==================================================\n",
      "O patch size= 12 dim= 128\n",
      "==================================================\n",
      "4 --- ft=10: mse=0.256,mae=0.319\n",
      "4 --- ft=20: mse=0.257,mae=0.32\n",
      "4 --- ft=40: mse=0.256,mae=0.316\n",
      "4 --- ft=60: mse=0.255,mae=0.317\n",
      "====================================================================================================\n"
     ]
    }
   ],
   "source": [
    "dataset = 'ettm2'  # explicit: an earlier cell sets dataset = 'etth2'\n",
    "do_ratio = 0.0\n",
    "\n",
    "for dim in (32, 64, 128):\n",
    "    MLP_data_input_output_patchlen(data=dataset, dim=dim, do=do_ratio)\n",
    "print('=' * 100)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 57,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "==================================================\n",
      "O patch size= 12 dim= 32\n",
      "==================================================\n",
      "4 --- ft=10: mse=0.254,mae=0.312\n",
      "4 --- ft=20: mse=0.253,mae=0.312\n",
      "4 --- ft=40: mse=0.253,mae=0.312\n",
      "4 --- ft=60: mse=0.253,mae=0.312\n",
      "==================================================\n",
      "O patch size= 12 dim= 64\n",
      "==================================================\n",
      "4 --- ft=10: mse=0.254,mae=0.312\n",
      "4 --- ft=20: mse=0.254,mae=0.313\n",
      "4 --- ft=40: mse=0.254,mae=0.313\n",
      "4 --- ft=60: mse=0.254,mae=0.314\n",
      "==================================================\n",
      "O patch size= 12 dim= 128\n",
      "==================================================\n",
      "4 --- ft=10: mse=0.255,mae=0.315\n",
      "4 --- ft=20: mse=0.254,mae=0.312\n",
      "4 --- ft=40: mse=0.254,mae=0.314\n",
      "4 --- ft=60: mse=0.254,mae=0.314\n",
      "====================================================================================================\n"
     ]
    }
   ],
   "source": [
    "dataset = 'ettm2'  # explicit: an earlier cell sets dataset = 'etth2'\n",
    "do_ratio = 0.1\n",
    "\n",
    "for dim in (32, 64, 128):\n",
    "    MLP_data_input_output_patchlen(data=dataset, dim=dim, do=do_ratio)\n",
    "print('=' * 100)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 58,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "==================================================\n",
      "O patch size= 12 dim= 32\n",
      "==================================================\n",
      "4 --- ft=10: mse=0.255,mae=0.314\n",
      "4 --- ft=20: mse=0.253,mae=0.312\n",
      "4 --- ft=40: mse=0.253,mae=0.312\n",
      "4 --- ft=60: mse=0.253,mae=0.312\n",
      "==================================================\n",
      "O patch size= 12 dim= 64\n",
      "==================================================\n",
      "4 --- ft=10: mse=0.253,mae=0.312\n",
      "4 --- ft=20: mse=0.254,mae=0.313\n",
      "4 --- ft=40: mse=0.253,mae=0.312\n",
      "4 --- ft=60: mse=0.254,mae=0.312\n",
      "==================================================\n",
      "O patch size= 12 dim= 128\n",
      "==================================================\n",
      "4 --- ft=10: mse=0.255,mae=0.315\n",
      "4 --- ft=20: mse=0.253,mae=0.313\n",
      "4 --- ft=40: mse=0.253,mae=0.313\n",
      "4 --- ft=60: mse=0.254,mae=0.314\n",
      "====================================================================================================\n"
     ]
    }
   ],
   "source": [
    "dataset = 'ettm2'  # explicit: an earlier cell sets dataset = 'etth2'\n",
    "do_ratio = 0.2\n",
    "\n",
    "for dim in (32, 64, 128):\n",
    "    MLP_data_input_output_patchlen(data=dataset, dim=dim, do=do_ratio)\n",
    "print('=' * 100)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 59,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "==================================================\n",
      "O patch size= 12 dim= 32\n",
      "==================================================\n",
      "4 --- ft=10: mse=0.254,mae=0.313\n",
      "4 --- ft=20: mse=0.253,mae=0.311\n",
      "4 --- ft=40: mse=0.253,mae=0.312\n",
      "4 --- ft=60: mse=0.253,mae=0.311\n",
      "==================================================\n",
      "O patch size= 12 dim= 64\n",
      "==================================================\n",
      "4 --- ft=10: mse=0.253,mae=0.312\n",
      "4 --- ft=20: mse=0.254,mae=0.313\n",
      "4 --- ft=40: mse=0.253,mae=0.312\n",
      "4 --- ft=60: mse=0.253,mae=0.311\n",
      "==================================================\n",
      "O patch size= 12 dim= 128\n",
      "==================================================\n",
      "4 --- ft=10: mse=0.255,mae=0.315\n",
      "4 --- ft=20: mse=0.253,mae=0.312\n",
      "4 --- ft=40: mse=0.253,mae=0.312\n",
      "4 --- ft=60: mse=0.254,mae=0.314\n",
      "====================================================================================================\n"
     ]
    }
   ],
   "source": [
    "dataset = 'ettm2'  # explicit: an earlier cell sets dataset = 'etth2'\n",
    "do_ratio = 0.3\n",
    "\n",
    "for dim in (32, 64, 128):\n",
    "    MLP_data_input_output_patchlen(data=dataset, dim=dim, do=do_ratio)\n",
    "print('=' * 100)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 60,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "==================================================\n",
      "O patch size= 12 dim= 32\n",
      "==================================================\n",
      "4 --- ft=10: mse=0.254,mae=0.313\n",
      "4 --- ft=20: mse=0.252,mae=0.311\n",
      "4 --- ft=40: mse=0.253,mae=0.312\n",
      "4 --- ft=60: mse=0.253,mae=0.312\n",
      "==================================================\n",
      "O patch size= 12 dim= 64\n",
      "==================================================\n",
      "4 --- ft=10: mse=0.253,mae=0.312\n",
      "4 --- ft=20: mse=0.252,mae=0.312\n",
      "4 --- ft=40: mse=0.253,mae=0.312\n",
      "4 --- ft=60: mse=0.254,mae=0.312\n",
      "==================================================\n",
      "O patch size= 12 dim= 128\n",
      "==================================================\n",
      "4 --- ft=10: mse=0.254,mae=0.314\n",
      "4 --- ft=20: mse=0.253,mae=0.312\n",
      "4 --- ft=40: mse=0.253,mae=0.312\n",
      "4 --- ft=60: mse=0.254,mae=0.314\n",
      "====================================================================================================\n"
     ]
    }
   ],
   "source": [
    "dataset = 'ettm2'  # explicit: an earlier cell sets dataset = 'etth2'\n",
    "do_ratio = 0.4\n",
    "\n",
    "for dim in (32, 64, 128):\n",
    "    MLP_data_input_output_patchlen(data=dataset, dim=dim, do=do_ratio)\n",
    "print('=' * 100)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 61,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "==================================================\n",
      "O patch size= 12 dim= 32\n",
      "==================================================\n",
      "4 --- ft=10: mse=0.254,mae=0.313\n",
      "4 --- ft=20: mse=0.253,mae=0.313\n",
      "4 --- ft=40: mse=0.253,mae=0.312\n",
      "4 --- ft=60: mse=0.252,mae=0.311\n",
      "==================================================\n",
      "O patch size= 12 dim= 64\n",
      "==================================================\n",
      "4 --- ft=10: mse=0.253,mae=0.312\n",
      "4 --- ft=20: mse=0.252,mae=0.312\n",
      "4 --- ft=40: mse=0.253,mae=0.312\n",
      "4 --- ft=60: mse=0.253,mae=0.312\n",
      "==================================================\n",
      "O patch size= 12 dim= 128\n",
      "==================================================\n",
      "3 --- ft=10: mse=0.265,mae=0.322\n",
      "4 --- ft=20: mse=0.253,mae=0.312\n",
      "4 --- ft=40: mse=0.253,mae=0.313\n",
      "4 --- ft=60: mse=0.254,mae=0.314\n",
      "====================================================================================================\n"
     ]
    }
   ],
   "source": [
    "dataset = 'ettm2'  # explicit: an earlier cell sets dataset = 'etth2'\n",
    "do_ratio = 0.5\n",
    "\n",
    "for dim in (32, 64, 128):\n",
    "    MLP_data_input_output_patchlen(data=dataset, dim=dim, do=do_ratio)\n",
    "print('=' * 100)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "ssl_ts",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.11"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
