{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "\n",
    "# Results of the '<data>2<data>' experiments live under saved_models/.\n",
    "data = 'etth1'\n",
    "PATH = f'/home/seunghan9613/PatchTST_sim/PatchTST_self_supervised/saved_models/{data}2{data}'\n",
    "\n",
    "# Keep only the ablation entries, skipping any whose name mentions 'norm'.\n",
    "settings = [name for name in os.listdir(PATH)\n",
    "            if 'ablation' in name and 'norm' not in name]\n",
    "\n",
    "# Work from inside the results directory from here on.\n",
    "os.chdir(PATH)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One list per model family, selected by a substring of the setting name\n",
    "# (order of tags matches the order of the names on the left).\n",
    "FC_settings, FC2_settings, Transformer_settings, TSmixer_settings = (\n",
    "    [name for name in settings if tag in name]\n",
    "    for tag in ('FC_', 'FC2_', 'Trans', 'mixer_')\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['XY_ablation_Transformer_O_O',\n",
       " 'XY_ablation_Transformer_O_X',\n",
       " 'XY_ablation_Transformer_OX_O',\n",
       " 'XY_ablation_Transformer_O_OX',\n",
       " 'XY_ablation_Transformer_OX_OX']"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Show which setting entries matched the 'Trans' filter.\n",
    "Transformer_settings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_mse_mae(data, type_, agg_type, dim, ft_epochs=(10, 20, 40, 60),\n",
    "                base_dir='/home/seunghan9613/PatchTST_sim/PatchTST_self_supervised/saved_models'):\n",
    "    \"\"\"Print the mean MSE/MAE of one ablation run for each fine-tuning epoch count.\n",
    "\n",
    "    The run folder is <base_dir>/<data>2<data>/<type_>/based_model/<agg_type>/<run>,\n",
    "    where <run> is the first entry whose name contains both 'D<dim>' and 'patch12'.\n",
    "    For every ft in ft_epochs, the '*acc.csv' files tagged 'ft_ep<ft>' are\n",
    "    concatenated and their 'mse' / 'mae' columns averaged over all rows.\n",
    "\n",
    "    Args:\n",
    "        data: dataset name, e.g. 'etth1'.\n",
    "        type_: ablation setting folder, e.g. 'XY_ablation_FC_O_X'.\n",
    "        agg_type: aggregation sub-folder, e.g. 'max', 'maxpool', 'avgpool'.\n",
    "        dim: model dimension, matched as the substring 'D<dim>' of the run name.\n",
    "        ft_epochs: fine-tuning epoch counts to report; a count with no result\n",
    "            files is reported as missing and skipped.\n",
    "        base_dir: root 'saved_models' directory.\n",
    "\n",
    "    Returns:\n",
    "        None. For each epoch count, prints the number of files averaged and\n",
    "        then 'ft=<ft>: mse=<mse>,mae=<mae>' (rounded to 3 decimals).\n",
    "\n",
    "    Raises:\n",
    "        FileNotFoundError: if no run folder matches 'D<dim>' and 'patch12'.\n",
    "\n",
    "    Note: changes the working directory to <base_dir>/<data>2<data>.\n",
    "    \"\"\"\n",
    "    print('=' * 50)\n",
    "    print(type_, agg_type, dim)\n",
    "    print('=' * 50)\n",
    "\n",
    "    root = os.path.join(base_dir, f'{data}2{data}')\n",
    "    # Kept for compatibility: the helper has always left the kernel in the results dir.\n",
    "    os.chdir(root)\n",
    "\n",
    "    agg_dir = os.path.join(root, type_, 'based_model', agg_type)\n",
    "    runs = [d for d in os.listdir(agg_dir) if f'D{dim}' in d and 'patch12' in d]\n",
    "    if not runs:\n",
    "        raise FileNotFoundError(f'no D{dim} / patch12 run found in {agg_dir}')\n",
    "    # os.listdir order is arbitrary; assumes a single matching run folder (TODO confirm).\n",
    "    run_dir = os.path.join(agg_dir, runs[0])\n",
    "    acc_files = [name for name in os.listdir(run_dir) if 'acc.csv' in name]\n",
    "\n",
    "    for ft in ft_epochs:\n",
    "        ft_files = [name for name in acc_files if f'ft_ep{ft}' in name]\n",
    "        print(len(ft_files))  # number of result files averaged for this ft\n",
    "        if not ft_files:\n",
    "            # Report a missing epoch count (e.g. no ft=60 runs yet) and move on,\n",
    "            # instead of letting pd.concat([]) raise.\n",
    "            print(f'ft={ft}: no result files')\n",
    "            continue\n",
    "        df = pd.concat([pd.read_csv(os.path.join(run_dir, name)) for name in ft_files], axis=0)\n",
    "        # Average only the metric columns; any non-numeric column would make\n",
    "        # DataFrame.mean() raise on recent pandas versions.\n",
    "        result = df[['mse', 'mae']].mean(axis=0)\n",
    "        mse, mae = result['mse'], result['mae']\n",
    "        print(f'ft={ft}: mse={mse.round(3)},mae={mae.round(3)}')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# FC"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "==================================================\n",
      "XY_ablation_FC_O_X max 32\n",
      "==================================================\n",
      "4\n",
      "ft=10: mse=0.41,mae=0.426\n",
      "4\n",
      "ft=20: mse=0.408,mae=0.424\n",
      "4\n",
      "ft=40: mse=0.409,mae=0.425\n",
      "4\n",
      "ft=60: mse=0.411,mae=0.426\n",
      "==================================================\n",
      "XY_ablation_FC_O_X max 64\n",
      "==================================================\n",
      "3\n",
      "ft=10: mse=0.412,mae=0.428\n",
      "4\n",
      "ft=20: mse=0.412,mae=0.426\n",
      "4\n",
      "ft=40: mse=0.409,mae=0.424\n",
      "4\n",
      "ft=60: mse=0.412,mae=0.426\n",
      "==================================================\n",
      "XY_ablation_FC_O_X max 128\n",
      "==================================================\n",
      "4\n",
      "ft=10: mse=0.41,mae=0.425\n",
      "4\n",
      "ft=20: mse=0.41,mae=0.424\n",
      "4\n",
      "ft=40: mse=0.41,mae=0.424\n",
      "4\n",
      "ft=60: mse=0.411,mae=0.425\n",
      "==================================================\n",
      "XY_ablation_FC_O_X max 32\n",
      "==================================================\n",
      "4\n",
      "ft=10: mse=0.41,mae=0.426\n",
      "4\n",
      "ft=20: mse=0.408,mae=0.424\n",
      "4\n",
      "ft=40: mse=0.409,mae=0.425\n",
      "4\n",
      "ft=60: mse=0.411,mae=0.426\n",
      "==================================================\n",
      "XY_ablation_FC_O_X max 64\n",
      "==================================================\n",
      "3\n",
      "ft=10: mse=0.412,mae=0.428\n",
      "4\n",
      "ft=20: mse=0.412,mae=0.426\n",
      "4\n",
      "ft=40: mse=0.409,mae=0.424\n",
      "4\n",
      "ft=60: mse=0.412,mae=0.426\n",
      "==================================================\n",
      "XY_ablation_FC_O_X max 128\n",
      "==================================================\n",
      "4\n",
      "ft=10: mse=0.41,mae=0.425\n",
      "4\n",
      "ft=20: mse=0.41,mae=0.424\n",
      "4\n",
      "ft=40: mse=0.41,mae=0.424\n",
      "4\n",
      "ft=60: mse=0.411,mae=0.425\n",
      "==================================================\n",
      "XY_ablation_FC_O_O max 32\n",
      "==================================================\n",
      "4\n",
      "ft=10: mse=0.41,mae=0.425\n",
      "4\n",
      "ft=20: mse=0.408,mae=0.423\n",
      "4\n",
      "ft=40: mse=0.409,mae=0.424\n",
      "4\n",
      "ft=60: mse=0.41,mae=0.425\n",
      "==================================================\n",
      "XY_ablation_FC_O_O max 64\n",
      "==================================================\n",
      "4\n",
      "ft=10: mse=0.409,mae=0.423\n",
      "4\n",
      "ft=20: mse=0.411,mae=0.425\n",
      "4\n",
      "ft=40: mse=0.409,mae=0.424\n",
      "4\n",
      "ft=60: mse=0.411,mae=0.425\n",
      "==================================================\n",
      "XY_ablation_FC_O_O max 128\n",
      "==================================================\n",
      "4\n",
      "ft=10: mse=0.409,mae=0.424\n",
      "4\n",
      "ft=20: mse=0.412,mae=0.425\n",
      "4\n",
      "ft=40: mse=0.41,mae=0.425\n",
      "4\n",
      "ft=60: mse=0.412,mae=0.425\n",
      "==================================================\n",
      "XY_ablation_FC_O_O max 32\n",
      "==================================================\n",
      "4\n",
      "ft=10: mse=0.41,mae=0.425\n",
      "4\n",
      "ft=20: mse=0.408,mae=0.423\n",
      "4\n",
      "ft=40: mse=0.409,mae=0.424\n",
      "4\n",
      "ft=60: mse=0.41,mae=0.425\n",
      "==================================================\n",
      "XY_ablation_FC_O_O max 64\n",
      "==================================================\n",
      "4\n",
      "ft=10: mse=0.409,mae=0.423\n",
      "4\n",
      "ft=20: mse=0.411,mae=0.425\n",
      "4\n",
      "ft=40: mse=0.409,mae=0.424\n",
      "4\n",
      "ft=60: mse=0.411,mae=0.425\n",
      "==================================================\n",
      "XY_ablation_FC_O_O max 128\n",
      "==================================================\n",
      "4\n",
      "ft=10: mse=0.409,mae=0.424\n",
      "4\n",
      "ft=20: mse=0.412,mae=0.425\n",
      "4\n",
      "ft=40: mse=0.41,mae=0.425\n",
      "4\n",
      "ft=60: mse=0.412,mae=0.425\n"
     ]
    }
   ],
   "source": [
    "# etth1 / FC settings / 'max' aggregation: every setting x dim combination, each run once.\n",
    "for type_ in ['XY_ablation_FC_O_X', 'XY_ablation_FC_O_O']:\n",
    "    for dim in [32, 64, 128]:\n",
    "        get_mse_mae(data='etth1', type_=type_, agg_type='max', dim=dim)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 88,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "==================================================\n",
      "XY_ablation_FC_O_X maxpool 64\n",
      "==================================================\n",
      "4\n",
      "ft=10: mse=0.658,mae=0.572\n",
      "4\n",
      "ft=20: mse=0.665,mae=0.575\n",
      "4\n",
      "ft=40: mse=0.679,mae=0.582\n",
      "4\n",
      "ft=60: mse=0.686,mae=0.585\n",
      "==================================================\n",
      "XY_ablation_FC_O_X avgpool 64\n",
      "==================================================\n",
      "4\n",
      "ft=10: mse=0.643,mae=0.558\n",
      "4\n",
      "ft=20: mse=0.643,mae=0.557\n",
      "4\n",
      "ft=40: mse=0.642,mae=0.557\n",
      "4\n",
      "ft=60: mse=0.643,mae=0.558\n",
      "==================================================\n",
      "XY_ablation_FC_O_O avgpool 64\n",
      "==================================================\n",
      "4\n",
      "ft=10: mse=0.642,mae=0.557\n",
      "4\n",
      "ft=20: mse=0.643,mae=0.557\n",
      "4\n",
      "ft=40: mse=0.642,mae=0.557\n",
      "4\n",
      "ft=60: mse=0.643,mae=0.557\n",
      "==================================================\n",
      "XY_ablation_FC_O_O maxpool 64\n",
      "==================================================\n",
      "4\n",
      "ft=10: mse=0.66,mae=0.573\n",
      "4\n",
      "ft=20: mse=0.668,mae=0.577\n",
      "4\n",
      "ft=40: mse=0.67,mae=0.577\n",
      "4\n",
      "ft=60: mse=0.677,mae=0.581\n"
     ]
    }
   ],
   "source": [
    "# etth1 / FC settings / pooled aggregations at dim=64 (dims 32 and 128 are disabled).\n",
    "for type_, agg_type in [('XY_ablation_FC_O_X', 'maxpool'),\n",
    "                        ('XY_ablation_FC_O_X', 'avgpool'),\n",
    "                        ('XY_ablation_FC_O_O', 'avgpool'),\n",
    "                        ('XY_ablation_FC_O_O', 'maxpool')]:\n",
    "    get_mse_mae(data='etth1', type_=type_, agg_type=agg_type, dim=64)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 101,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "==================================================\n",
      "XY_ablation_FC_O_X max 64\n",
      "==================================================\n",
      "4\n",
      "ft=10: mse=0.334,mae=0.383\n",
      "4\n",
      "ft=20: mse=0.336,mae=0.384\n",
      "4\n",
      "ft=40: mse=0.338,mae=0.385\n",
      "4\n",
      "ft=60: mse=0.343,mae=0.387\n",
      "==================================================\n",
      "XY_ablation_FC_O_X max 64\n",
      "==================================================\n",
      "4\n",
      "ft=10: mse=0.334,mae=0.383\n",
      "4\n",
      "ft=20: mse=0.336,mae=0.384\n",
      "4\n",
      "ft=40: mse=0.338,mae=0.385\n",
      "4\n",
      "ft=60: mse=0.343,mae=0.387\n",
      "==================================================\n",
      "XY_ablation_FC_O_O max 64\n",
      "==================================================\n",
      "4\n",
      "ft=10: mse=0.337,mae=0.385\n",
      "3\n",
      "ft=20: mse=0.324,mae=0.372\n",
      "4\n",
      "ft=40: mse=0.338,mae=0.384\n",
      "4\n",
      "ft=60: mse=0.343,mae=0.387\n",
      "==================================================\n",
      "XY_ablation_FC_O_O max 64\n",
      "==================================================\n",
      "4\n",
      "ft=10: mse=0.337,mae=0.385\n",
      "3\n",
      "ft=20: mse=0.324,mae=0.372\n",
      "4\n",
      "ft=40: mse=0.338,mae=0.384\n",
      "4\n",
      "ft=60: mse=0.343,mae=0.387\n"
     ]
    }
   ],
   "source": [
    "# etth2 / FC settings / 'max' aggregation at dim=64, each setting run once\n",
    "# (dims 32 and 128 are disabled).\n",
    "for type_ in ['XY_ablation_FC_O_X', 'XY_ablation_FC_O_O']:\n",
    "    get_mse_mae(data='etth2', type_=type_, agg_type='max', dim=64)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 89,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "==================================================\n",
      "XY_ablation_FC_O_X maxpool 32\n",
      "==================================================\n",
      "4\n",
      "ft=10: mse=0.402,mae=0.444\n",
      "4\n",
      "ft=20: mse=0.403,mae=0.445\n",
      "4\n",
      "ft=40: mse=0.403,mae=0.445\n",
      "4\n",
      "ft=60: mse=0.408,mae=0.45\n",
      "==================================================\n",
      "XY_ablation_FC_O_X avgpool 64\n",
      "==================================================\n",
      "4\n",
      "ft=10: mse=0.389,mae=0.432\n",
      "4\n",
      "ft=20: mse=0.391,mae=0.435\n",
      "4\n",
      "ft=40: mse=0.391,mae=0.435\n",
      "4\n",
      "ft=60: mse=0.392,mae=0.435\n",
      "==================================================\n",
      "XY_ablation_FC_O_O avgpool 32\n",
      "==================================================\n",
      "4\n",
      "ft=10: mse=0.389,mae=0.431\n",
      "4\n",
      "ft=20: mse=0.389,mae=0.431\n",
      "4\n",
      "ft=40: mse=0.39,mae=0.433\n",
      "4\n",
      "ft=60: mse=0.39,mae=0.434\n",
      "==================================================\n",
      "XY_ablation_FC_O_O maxpool 32\n",
      "==================================================\n",
      "4\n",
      "ft=10: mse=0.403,mae=0.446\n",
      "4\n",
      "ft=20: mse=0.408,mae=0.449\n",
      "4\n",
      "ft=40: mse=0.408,mae=0.45\n",
      "4\n",
      "ft=60: mse=0.407,mae=0.449\n"
     ]
    }
   ],
   "source": [
    "# etth2 / FC settings / pooled aggregations; each pair uses its own dim\n",
    "# (the remaining dim variants are disabled).\n",
    "for type_, agg_type, dim in [('XY_ablation_FC_O_X', 'maxpool', 32),\n",
    "                             ('XY_ablation_FC_O_X', 'avgpool', 64),\n",
    "                             ('XY_ablation_FC_O_O', 'avgpool', 32),\n",
    "                             ('XY_ablation_FC_O_O', 'maxpool', 32)]:\n",
    "    get_mse_mae(data='etth2', type_=type_, agg_type=agg_type, dim=dim)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 102,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "==================================================\n",
      "XY_ablation_FC_O_X max 32\n",
      "==================================================\n",
      "4\n",
      "ft=10: mse=0.359,mae=0.379\n",
      "4\n",
      "ft=20: mse=0.359,mae=0.378\n",
      "4\n",
      "ft=40: mse=0.362,mae=0.38\n",
      "3\n",
      "ft=60: mse=0.339,mae=0.367\n",
      "==================================================\n",
      "XY_ablation_FC_O_X max 64\n",
      "==================================================\n",
      "4\n",
      "ft=10: mse=0.358,mae=0.378\n",
      "4\n",
      "ft=20: mse=0.361,mae=0.379\n",
      "4\n",
      "ft=40: mse=0.361,mae=0.38\n",
      "4\n",
      "ft=60: mse=0.36,mae=0.379\n",
      "==================================================\n",
      "XY_ablation_FC_O_X max 128\n",
      "==================================================\n",
      "4\n",
      "ft=10: mse=0.358,mae=0.378\n",
      "4\n",
      "ft=20: mse=0.359,mae=0.379\n",
      "4\n",
      "ft=40: mse=0.361,mae=0.38\n",
      "4\n",
      "ft=60: mse=0.363,mae=0.381\n",
      "==================================================\n",
      "XY_ablation_FC_O_O max 32\n",
      "==================================================\n",
      "4\n",
      "ft=10: mse=0.359,mae=0.38\n",
      "4\n",
      "ft=20: mse=0.36,mae=0.379\n",
      "4\n",
      "ft=40: mse=0.363,mae=0.38\n",
      "4\n",
      "ft=60: mse=0.361,mae=0.38\n",
      "==================================================\n",
      "XY_ablation_FC_O_O max 64\n",
      "==================================================\n",
      "4\n",
      "ft=10: mse=0.358,mae=0.378\n",
      "4\n",
      "ft=20: mse=0.361,mae=0.379\n",
      "4\n",
      "ft=40: mse=0.361,mae=0.38\n",
      "4\n",
      "ft=60: mse=0.361,mae=0.38\n",
      "==================================================\n",
      "XY_ablation_FC_O_O max 128\n",
      "==================================================\n",
      "4\n",
      "ft=10: mse=0.358,mae=0.378\n",
      "4\n",
      "ft=20: mse=0.359,mae=0.379\n",
      "4\n",
      "ft=40: mse=0.361,mae=0.38\n",
      "4\n",
      "ft=60: mse=0.361,mae=0.38\n"
     ]
    }
   ],
   "source": [
    "# ettm1 / FC settings / 'max' aggregation over all dims.\n",
    "for type_ in ['XY_ablation_FC_O_X', 'XY_ablation_FC_O_O']:\n",
    "    for dim in (32, 64, 128):\n",
    "        get_mse_mae(data='ettm1', type_=type_, agg_type='max', dim=dim)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 77,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "==================================================\n",
      "XY_ablation_FC_O_X maxpool 32\n",
      "==================================================\n",
      "ft=10: mse=0.695,mae=0.563\n",
      "ft=20: mse=0.695,mae=0.563\n",
      "ft=40: mse=0.696,mae=0.562\n",
      "ft=60: mse=0.696,mae=0.562\n",
      "==================================================\n",
      "XY_ablation_FC_O_X avgpool 32\n",
      "==================================================\n",
      "ft=10: mse=0.68,mae=0.552\n",
      "ft=20: mse=0.679,mae=0.551\n",
      "ft=40: mse=0.679,mae=0.551\n",
      "ft=60: mse=0.68,mae=0.551\n",
      "==================================================\n",
      "XY_ablation_FC_O_O avgpool 32\n",
      "==================================================\n",
      "ft=10: mse=0.681,mae=0.553\n",
      "ft=20: mse=0.678,mae=0.55\n",
      "ft=40: mse=0.679,mae=0.55\n",
      "ft=60: mse=0.68,mae=0.552\n",
      "==================================================\n",
      "XY_ablation_FC_O_O maxpool 32\n",
      "==================================================\n",
      "ft=10: mse=0.695,mae=0.563\n",
      "ft=20: mse=0.696,mae=0.563\n",
      "ft=40: mse=0.697,mae=0.561\n",
      "ft=60: mse=0.698,mae=0.563\n"
     ]
    }
   ],
   "source": [
    "# ettm1 / FC settings / pooled aggregations at dim=32 (dims 64 and 128 are disabled).\n",
    "for type_, agg_type in [('XY_ablation_FC_O_X', 'maxpool'),\n",
    "                        ('XY_ablation_FC_O_X', 'avgpool'),\n",
    "                        ('XY_ablation_FC_O_O', 'avgpool'),\n",
    "                        ('XY_ablation_FC_O_O', 'maxpool')]:\n",
    "    get_mse_mae(data='ettm1', type_=type_, agg_type=agg_type, dim=32)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "==================================================\n",
      "XY_ablation_FC_O_X max 32\n",
      "==================================================\n",
      "4\n",
      "ft=10: mse=0.253,mae=0.312\n",
      "4\n",
      "ft=20: mse=0.254,mae=0.312\n",
      "4\n",
      "ft=40: mse=0.254,mae=0.311\n",
      "4\n",
      "ft=60: mse=0.254,mae=0.312\n",
      "==================================================\n",
      "XY_ablation_FC_O_O max 32\n",
      "==================================================\n",
      "4\n",
      "ft=10: mse=0.253,mae=0.312\n",
      "4\n",
      "ft=20: mse=0.254,mae=0.311\n",
      "4\n",
      "ft=40: mse=0.254,mae=0.311\n",
      "4\n",
      "ft=60: mse=0.253,mae=0.312\n"
     ]
    }
   ],
   "source": [
    "# ettm2 / FC settings / 'max' aggregation at dim=32 (dims 64 and 128 are disabled).\n",
    "for type_ in ('XY_ablation_FC_O_X', 'XY_ablation_FC_O_O'):\n",
    "    get_mse_mae(data='ettm2', type_=type_, agg_type='max', dim=32)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 51,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "==================================================\n",
      "XY_ablation_FC_O_X maxpool 64\n",
      "==================================================\n",
      "4\n",
      "ft=10: mse=0.345,mae=0.394\n",
      "4\n",
      "ft=20: mse=0.343,mae=0.394\n",
      "4\n",
      "ft=40: mse=0.344,mae=0.396\n",
      "4\n",
      "ft=60: mse=0.345,mae=0.396\n",
      "==================================================\n",
      "XY_ablation_FC_O_X avgpool 32\n",
      "==================================================\n",
      "4\n",
      "ft=10: mse=0.34,mae=0.386\n",
      "4\n",
      "ft=20: mse=0.339,mae=0.385\n",
      "4\n",
      "ft=40: mse=0.339,mae=0.385\n",
      "4\n",
      "ft=60: mse=0.341,mae=0.386\n",
      "==================================================\n",
      "XY_ablation_FC_O_O avgpool 32\n",
      "==================================================\n",
      "4\n",
      "ft=10: mse=0.34,mae=0.386\n",
      "4\n",
      "ft=20: mse=0.339,mae=0.386\n",
      "4\n",
      "ft=40: mse=0.339,mae=0.384\n",
      "4\n",
      "ft=60: mse=0.341,mae=0.386\n",
      "==================================================\n",
      "XY_ablation_FC_O_O maxpool 64\n",
      "==================================================\n",
      "4\n",
      "ft=10: mse=0.341,mae=0.392\n",
      "4\n",
      "ft=20: mse=0.342,mae=0.394\n",
      "4\n",
      "ft=40: mse=0.342,mae=0.394\n",
      "4\n",
      "ft=60: mse=0.343,mae=0.395\n"
     ]
    }
   ],
   "source": [
    "# ettm2 / FC settings: maxpool is read at dim=64, avgpool at dim=32\n",
    "# (the remaining dim variants are disabled).\n",
    "for type_, agg_type, dim in [('XY_ablation_FC_O_X', 'maxpool', 64),\n",
    "                             ('XY_ablation_FC_O_X', 'avgpool', 32),\n",
    "                             ('XY_ablation_FC_O_O', 'avgpool', 32),\n",
    "                             ('XY_ablation_FC_O_O', 'maxpool', 64)]:\n",
    "    get_mse_mae(data='ettm2', type_=type_, agg_type=agg_type, dim=dim)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# MLP (FC2 settings)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 105,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "==================================================\n",
      "XY_ablation_FC2_O_X max 128\n",
      "==================================================\n",
      "4\n",
      "ft=10: mse=0.419,mae=0.431\n",
      "4\n",
      "ft=20: mse=0.42,mae=0.432\n",
      "4\n",
      "ft=40: mse=0.419,mae=0.431\n",
      "4\n",
      "ft=60: mse=0.418,mae=0.431\n",
      "==================================================\n",
      "XY_ablation_FC2_O_O max 32\n",
      "==================================================\n",
      "4\n",
      "ft=10: mse=0.409,mae=0.424\n",
      "4\n",
      "ft=20: mse=0.407,mae=0.422\n",
      "4\n",
      "ft=40: mse=0.408,mae=0.423\n",
      "4\n",
      "ft=60: mse=0.409,mae=0.425\n"
     ]
    }
   ],
   "source": [
    "# etth1 / FC2 (MLP) settings / 'max' aggregation; the other dims are disabled.\n",
    "for type_, dim in [('XY_ablation_FC2_O_X', 128),\n",
    "                   ('XY_ablation_FC2_O_O', 32)]:\n",
    "    get_mse_mae(data='etth1', type_=type_, agg_type='max', dim=dim)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 91,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "==================================================\n",
      "XY_ablation_FC2_O_X maxpool 128\n",
      "==================================================\n",
      "4\n",
      "ft=10: mse=0.65,mae=0.566\n",
      "4\n",
      "ft=20: mse=0.647,mae=0.564\n",
      "4\n",
      "ft=40: mse=0.652,mae=0.567\n",
      "4\n",
      "ft=60: mse=0.656,mae=0.569\n",
      "==================================================\n",
      "XY_ablation_FC2_O_X avgpool 128\n",
      "==================================================\n",
      "4\n",
      "ft=10: mse=0.649,mae=0.563\n",
      "4\n",
      "ft=20: mse=0.645,mae=0.559\n",
      "4\n",
      "ft=40: mse=0.647,mae=0.56\n",
      "4\n",
      "ft=60: mse=0.647,mae=0.561\n",
      "==================================================\n",
      "XY_ablation_FC2_O_O avgpool 64\n",
      "==================================================\n",
      "3\n",
      "ft=10: mse=0.632,mae=0.547\n",
      "4\n",
      "ft=20: mse=0.642,mae=0.558\n",
      "4\n",
      "ft=40: mse=0.641,mae=0.557\n",
      "4\n",
      "ft=60: mse=0.642,mae=0.557\n",
      "==================================================\n",
      "XY_ablation_FC2_O_O maxpool 128\n",
      "==================================================\n",
      "4\n",
      "ft=10: mse=0.673,mae=0.577\n",
      "4\n",
      "ft=20: mse=0.678,mae=0.579\n",
      "4\n",
      "ft=40: mse=0.672,mae=0.577\n",
      "4\n",
      "ft=60: mse=0.661,mae=0.57\n"
     ]
    }
   ],
   "source": [
    "# etth1 / FC2 (MLP) settings / pooled aggregations; the other dim variants are disabled.\n",
    "for type_, agg_type, dim in [('XY_ablation_FC2_O_X', 'maxpool', 128),\n",
    "                             ('XY_ablation_FC2_O_X', 'avgpool', 128),\n",
    "                             ('XY_ablation_FC2_O_O', 'avgpool', 64),\n",
    "                             ('XY_ablation_FC2_O_O', 'maxpool', 128)]:\n",
    "    get_mse_mae(data='etth1', type_=type_, agg_type=agg_type, dim=dim)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 107,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "==================================================\n",
      "XY_ablation_FC2_O_X max 128\n",
      "==================================================\n",
      "4\n",
      "ft=10: mse=0.355,mae=0.399\n",
      "4\n",
      "ft=20: mse=0.362,mae=0.405\n",
      "4\n",
      "ft=40: mse=0.368,mae=0.41\n",
      "4\n",
      "ft=60: mse=0.369,mae=0.41\n",
      "==================================================\n",
      "XY_ablation_FC2_O_O max 32\n",
      "==================================================\n",
      "4\n",
      "ft=10: mse=0.337,mae=0.385\n",
      "4\n",
      "ft=20: mse=0.334,mae=0.383\n",
      "4\n",
      "ft=40: mse=0.335,mae=0.383\n",
      "4\n",
      "ft=60: mse=0.339,mae=0.386\n"
     ]
    }
   ],
   "source": [
    "# etth2, FC2 head, agg_type='max' (explored dim in {32, 64, 128}; one dim kept per setting)\n",
    "get_mse_mae(data='etth2', type_='XY_ablation_FC2_O_X', agg_type='max', dim=128)\n",
    "\n",
    "get_mse_mae(data='etth2', type_='XY_ablation_FC2_O_O', agg_type='max', dim=32)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 92,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "==================================================\n",
      "XY_ablation_FC2_O_X maxpool 64\n",
      "==================================================\n",
      "4\n",
      "ft=10: mse=0.411,mae=0.452\n",
      "4\n",
      "ft=20: mse=0.404,mae=0.446\n",
      "4\n",
      "ft=40: mse=0.404,mae=0.446\n",
      "4\n",
      "ft=60: mse=0.402,mae=0.445\n",
      "==================================================\n",
      "XY_ablation_FC2_O_X avgpool 64\n",
      "==================================================\n",
      "4\n",
      "ft=10: mse=0.393,mae=0.435\n",
      "4\n",
      "ft=20: mse=0.396,mae=0.439\n",
      "4\n",
      "ft=40: mse=0.397,mae=0.44\n",
      "4\n",
      "ft=60: mse=0.397,mae=0.439\n",
      "==================================================\n",
      "XY_ablation_FC2_O_O avgpool 32\n",
      "==================================================\n",
      "4\n",
      "ft=10: mse=0.394,mae=0.436\n",
      "4\n",
      "ft=20: mse=0.4,mae=0.442\n",
      "4\n",
      "ft=40: mse=0.392,mae=0.435\n",
      "4\n",
      "ft=60: mse=0.395,mae=0.438\n",
      "==================================================\n",
      "XY_ablation_FC2_O_O maxpool 32\n",
      "==================================================\n",
      "4\n",
      "ft=10: mse=0.398,mae=0.441\n",
      "4\n",
      "ft=20: mse=0.402,mae=0.445\n",
      "4\n",
      "ft=40: mse=0.411,mae=0.453\n",
      "4\n",
      "ft=60: mse=0.411,mae=0.453\n"
     ]
    }
   ],
   "source": [
    "# etth2, FC2 head, pooled aggregation (explored dim in {32, 64, 128}; one dim kept per setting)\n",
    "get_mse_mae(data='etth2', type_='XY_ablation_FC2_O_X', agg_type='maxpool', dim=64)\n",
    "get_mse_mae(data='etth2', type_='XY_ablation_FC2_O_X', agg_type='avgpool', dim=64)\n",
    "\n",
    "get_mse_mae(data='etth2', type_='XY_ablation_FC2_O_O', agg_type='avgpool', dim=32)\n",
    "get_mse_mae(data='etth2', type_='XY_ablation_FC2_O_O', agg_type='maxpool', dim=32)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 111,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "==================================================\n",
      "XY_ablation_FC2_O_X max 64\n",
      "==================================================\n",
      "4\n",
      "ft=10: mse=0.354,mae=0.381\n",
      "4\n",
      "ft=20: mse=0.356,mae=0.384\n",
      "4\n",
      "ft=40: mse=0.354,mae=0.384\n",
      "4\n",
      "ft=60: mse=0.355,mae=0.384\n",
      "==================================================\n",
      "XY_ablation_FC2_O_O max 128\n",
      "==================================================\n",
      "4\n",
      "ft=10: mse=0.358,mae=0.378\n",
      "4\n",
      "ft=20: mse=0.36,mae=0.382\n",
      "4\n",
      "ft=40: mse=0.357,mae=0.383\n",
      "4\n",
      "ft=60: mse=0.36,mae=0.385\n"
     ]
    }
   ],
   "source": [
    "# ettm1, FC2 head, agg_type='max' (explored dim in {32, 64, 128}; one dim kept per setting)\n",
    "get_mse_mae(data='ettm1', type_='XY_ablation_FC2_O_X', agg_type='max', dim=64)\n",
    "\n",
    "get_mse_mae(data='ettm1', type_='XY_ablation_FC2_O_O', agg_type='max', dim=128)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 93,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "==================================================\n",
      "XY_ablation_FC2_O_X maxpool 32\n",
      "==================================================\n",
      "4\n",
      "ft=10: mse=0.703,mae=0.562\n",
      "4\n",
      "ft=20: mse=0.707,mae=0.562\n",
      "4\n",
      "ft=40: mse=0.708,mae=0.562\n",
      "4\n",
      "ft=60: mse=0.711,mae=0.565\n",
      "==================================================\n",
      "XY_ablation_FC2_O_X avgpool 32\n",
      "==================================================\n",
      "4\n",
      "ft=10: mse=0.676,mae=0.552\n",
      "4\n",
      "ft=20: mse=0.676,mae=0.55\n",
      "4\n",
      "ft=40: mse=0.676,mae=0.552\n",
      "4\n",
      "ft=60: mse=0.677,mae=0.546\n",
      "==================================================\n",
      "XY_ablation_FC2_O_O avgpool 64\n",
      "==================================================\n",
      "3\n",
      "ft=10: mse=0.681,mae=0.553\n",
      "4\n",
      "ft=20: mse=0.678,mae=0.551\n",
      "4\n",
      "ft=40: mse=0.677,mae=0.551\n",
      "4\n",
      "ft=60: mse=0.677,mae=0.552\n",
      "==================================================\n",
      "XY_ablation_FC2_O_O maxpool 64\n",
      "==================================================\n",
      "4\n",
      "ft=10: mse=0.7,mae=0.56\n",
      "4\n",
      "ft=20: mse=0.703,mae=0.562\n",
      "4\n",
      "ft=40: mse=0.706,mae=0.561\n",
      "4\n",
      "ft=60: mse=0.704,mae=0.56\n"
     ]
    }
   ],
   "source": [
    "# ettm1, FC2 head, pooled aggregation (explored dim in {32, 64, 128}; one dim kept per setting)\n",
    "# NOTE(review): the O_O avgpool run prints 3 instead of 4 before ft=10 - presumably fewer result files were averaged; TODO confirm before comparing.\n",
    "get_mse_mae(data='ettm1', type_='XY_ablation_FC2_O_X', agg_type='maxpool', dim=32)\n",
    "get_mse_mae(data='ettm1', type_='XY_ablation_FC2_O_X', agg_type='avgpool', dim=32)\n",
    "\n",
    "get_mse_mae(data='ettm1', type_='XY_ablation_FC2_O_O', agg_type='avgpool', dim=64)\n",
    "get_mse_mae(data='ettm1', type_='XY_ablation_FC2_O_O', agg_type='maxpool', dim=64)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 113,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "==================================================\n",
      "XY_ablation_FC2_O_X max 64\n",
      "==================================================\n",
      "4\n",
      "ft=10: mse=0.258,mae=0.319\n",
      "4\n",
      "ft=20: mse=0.258,mae=0.32\n",
      "4\n",
      "ft=40: mse=0.259,mae=0.321\n",
      "4\n",
      "ft=60: mse=0.258,mae=0.32\n",
      "==================================================\n",
      "XY_ablation_FC2_O_O max 32\n",
      "==================================================\n",
      "4\n",
      "ft=10: mse=0.255,mae=0.314\n",
      "4\n",
      "ft=20: mse=0.253,mae=0.312\n",
      "4\n",
      "ft=40: mse=0.253,mae=0.312\n",
      "4\n",
      "ft=60: mse=0.253,mae=0.312\n"
     ]
    }
   ],
   "source": [
    "# ettm2, FC2 head, agg_type='max' (explored dim in {32, 64, 128}; one dim kept per setting)\n",
    "get_mse_mae(data='ettm2', type_='XY_ablation_FC2_O_X', agg_type='max', dim=64)\n",
    "\n",
    "get_mse_mae(data='ettm2', type_='XY_ablation_FC2_O_O', agg_type='max', dim=32)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 94,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "==================================================\n",
      "XY_ablation_FC2_O_X maxpool 64\n",
      "==================================================\n",
      "4\n",
      "ft=10: mse=0.342,mae=0.392\n",
      "4\n",
      "ft=20: mse=0.343,mae=0.392\n",
      "4\n",
      "ft=40: mse=0.34,mae=0.39\n",
      "4\n",
      "ft=60: mse=0.34,mae=0.389\n",
      "==================================================\n",
      "XY_ablation_FC2_O_X avgpool 128\n",
      "==================================================\n",
      "4\n",
      "ft=10: mse=0.341,mae=0.383\n",
      "4\n",
      "ft=20: mse=0.34,mae=0.384\n",
      "4\n",
      "ft=40: mse=0.339,mae=0.386\n",
      "4\n",
      "ft=60: mse=0.341,mae=0.387\n",
      "==================================================\n",
      "XY_ablation_FC2_O_O avgpool 128\n",
      "==================================================\n",
      "4\n",
      "ft=10: mse=0.342,mae=0.388\n",
      "4\n",
      "ft=20: mse=0.337,mae=0.385\n",
      "4\n",
      "ft=40: mse=0.34,mae=0.387\n",
      "4\n",
      "ft=60: mse=0.34,mae=0.389\n",
      "==================================================\n",
      "XY_ablation_FC2_O_O maxpool 64\n",
      "==================================================\n",
      "4\n",
      "ft=10: mse=0.339,mae=0.391\n",
      "4\n",
      "ft=20: mse=0.34,mae=0.392\n",
      "4\n",
      "ft=40: mse=0.339,mae=0.391\n",
      "4\n",
      "ft=60: mse=0.34,mae=0.391\n"
     ]
    }
   ],
   "source": [
    "# ettm2, FC2 head, pooled aggregation (explored dim in {32, 64, 128}; one dim kept per setting)\n",
    "get_mse_mae(data='ettm2', type_='XY_ablation_FC2_O_X', agg_type='maxpool', dim=64)\n",
    "get_mse_mae(data='ettm2', type_='XY_ablation_FC2_O_X', agg_type='avgpool', dim=128)\n",
    "\n",
    "get_mse_mae(data='ettm2', type_='XY_ablation_FC2_O_O', agg_type='avgpool', dim=128)\n",
    "get_mse_mae(data='ettm2', type_='XY_ablation_FC2_O_O', agg_type='maxpool', dim=64)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# MLPMixer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 116,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "==================================================\n",
      "XY_ablation_FC2mixer_O_X max 32\n",
      "==================================================\n",
      "4\n",
      "ft=10: mse=0.407,mae=0.427\n",
      "4\n",
      "ft=20: mse=0.407,mae=0.427\n",
      "4\n",
      "ft=40: mse=0.409,mae=0.428\n",
      "4\n",
      "ft=60: mse=0.408,mae=0.428\n",
      "==================================================\n",
      "XY_ablation_FC2mixer_O_O max 32\n",
      "==================================================\n",
      "4\n",
      "ft=10: mse=0.406,mae=0.422\n",
      "4\n",
      "ft=20: mse=0.406,mae=0.422\n",
      "4\n",
      "ft=40: mse=0.407,mae=0.423\n",
      "4\n",
      "ft=60: mse=0.405,mae=0.422\n"
     ]
    }
   ],
   "source": [
    "# etth1, FC2mixer head, agg_type='max' (explored dim in {32, 64, 128}; one dim kept per setting)\n",
    "get_mse_mae(data='etth1', type_='XY_ablation_FC2mixer_O_X', agg_type='max', dim=32)\n",
    "\n",
    "get_mse_mae(data='etth1', type_='XY_ablation_FC2mixer_O_O', agg_type='max', dim=32)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 56,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "==================================================\n",
      "XY_ablation_FC2mixer_O_X maxpool 64\n",
      "==================================================\n",
      "ft=10: mse=0.434,mae=0.45\n",
      "ft=20: mse=0.43,mae=0.447\n",
      "ft=40: mse=0.436,mae=0.452\n",
      "ft=60: mse=0.435,mae=0.452\n",
      "==================================================\n",
      "XY_ablation_FC2mixer_O_X avgpool 64\n",
      "==================================================\n",
      "ft=10: mse=0.421,mae=0.44\n",
      "ft=20: mse=0.439,mae=0.452\n",
      "ft=40: mse=0.447,mae=0.456\n",
      "ft=60: mse=0.442,mae=0.453\n",
      "==================================================\n",
      "XY_ablation_FC2mixer_O_O avgpool 128\n",
      "==================================================\n",
      "ft=10: mse=0.442,mae=0.448\n",
      "ft=20: mse=0.455,mae=0.455\n",
      "ft=40: mse=0.429,mae=0.442\n",
      "ft=60: mse=0.429,mae=0.441\n",
      "==================================================\n",
      "XY_ablation_FC2mixer_O_O maxpool 64\n",
      "==================================================\n",
      "ft=10: mse=0.441,mae=0.458\n",
      "ft=20: mse=0.434,mae=0.452\n",
      "ft=40: mse=0.43,mae=0.451\n",
      "ft=60: mse=0.442,mae=0.455\n"
     ]
    }
   ],
   "source": [
    "# etth1, FC2mixer head, pooled aggregation (explored dim in {32, 64, 128}; one dim kept per setting)\n",
    "get_mse_mae(data='etth1', type_='XY_ablation_FC2mixer_O_X', agg_type='maxpool', dim=64)\n",
    "get_mse_mae(data='etth1', type_='XY_ablation_FC2mixer_O_X', agg_type='avgpool', dim=64)\n",
    "\n",
    "get_mse_mae(data='etth1', type_='XY_ablation_FC2mixer_O_O', agg_type='avgpool', dim=128)\n",
    "get_mse_mae(data='etth1', type_='XY_ablation_FC2mixer_O_O', agg_type='maxpool', dim=64)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 117,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "==================================================\n",
      "XY_ablation_FC2mixer_O_X max 128\n",
      "==================================================\n",
      "4\n",
      "ft=10: mse=0.369,mae=0.402\n",
      "4\n",
      "ft=20: mse=0.369,mae=0.404\n",
      "4\n",
      "ft=40: mse=0.366,mae=0.403\n",
      "4\n",
      "ft=60: mse=0.364,mae=0.401\n",
      "==================================================\n",
      "XY_ablation_FC2mixer_O_O max 64\n",
      "==================================================\n",
      "4\n",
      "ft=10: mse=0.333,mae=0.383\n",
      "4\n",
      "ft=20: mse=0.34,mae=0.387\n",
      "4\n",
      "ft=40: mse=0.336,mae=0.385\n",
      "4\n",
      "ft=60: mse=0.339,mae=0.387\n"
     ]
    }
   ],
   "source": [
    "# etth2, FC2mixer head, agg_type='max' (explored dim in {32, 64, 128}; one dim kept per setting)\n",
    "get_mse_mae(data='etth2', type_='XY_ablation_FC2mixer_O_X', agg_type='max', dim=128)\n",
    "\n",
    "get_mse_mae(data='etth2', type_='XY_ablation_FC2mixer_O_O', agg_type='max', dim=64)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 60,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "==================================================\n",
      "XY_ablation_FC2mixer_O_X maxpool 32\n",
      "==================================================\n",
      "ft=10: mse=0.366,mae=0.405\n",
      "ft=20: mse=0.364,mae=0.403\n",
      "ft=40: mse=0.359,mae=0.402\n",
      "ft=60: mse=0.358,mae=0.401\n",
      "==================================================\n",
      "XY_ablation_FC2mixer_O_X avgpool 32\n",
      "==================================================\n",
      "ft=10: mse=0.363,mae=0.402\n",
      "ft=20: mse=0.364,mae=0.403\n",
      "ft=40: mse=0.364,mae=0.404\n",
      "ft=60: mse=0.362,mae=0.403\n",
      "==================================================\n",
      "XY_ablation_FC2mixer_O_O avgpool 64\n",
      "==================================================\n",
      "ft=10: mse=0.375,mae=0.411\n",
      "ft=20: mse=0.369,mae=0.407\n",
      "ft=40: mse=0.368,mae=0.406\n",
      "ft=60: mse=0.365,mae=0.404\n",
      "==================================================\n",
      "XY_ablation_FC2mixer_O_O maxpool 128\n",
      "==================================================\n",
      "ft=10: mse=0.379,mae=0.418\n",
      "ft=20: mse=0.382,mae=0.42\n",
      "ft=40: mse=0.392,mae=0.42\n",
      "ft=60: mse=0.384,mae=0.419\n"
     ]
    }
   ],
   "source": [
    "# etth2, FC2mixer head, pooled aggregation (explored dim in {32, 64, 128}; one dim kept per setting)\n",
    "get_mse_mae(data='etth2', type_='XY_ablation_FC2mixer_O_X', agg_type='maxpool', dim=32)\n",
    "get_mse_mae(data='etth2', type_='XY_ablation_FC2mixer_O_X', agg_type='avgpool', dim=32)\n",
    "\n",
    "get_mse_mae(data='etth2', type_='XY_ablation_FC2mixer_O_O', agg_type='avgpool', dim=64)\n",
    "get_mse_mae(data='etth2', type_='XY_ablation_FC2mixer_O_O', agg_type='maxpool', dim=128)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 119,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "==================================================\n",
      "XY_ablation_FC2mixer_O_X max 32\n",
      "==================================================\n",
      "4\n",
      "ft=10: mse=0.355,mae=0.385\n",
      "4\n",
      "ft=20: mse=0.355,mae=0.386\n",
      "4\n",
      "ft=40: mse=0.354,mae=0.385\n",
      "4\n",
      "ft=60: mse=0.356,mae=0.387\n",
      "==================================================\n",
      "XY_ablation_FC2mixer_O_O max 32\n",
      "==================================================\n",
      "4\n",
      "ft=10: mse=0.353,mae=0.378\n",
      "4\n",
      "ft=20: mse=0.352,mae=0.38\n",
      "4\n",
      "ft=40: mse=0.352,mae=0.38\n",
      "4\n",
      "ft=60: mse=0.351,mae=0.381\n"
     ]
    }
   ],
   "source": [
    "# ettm1, FC2mixer head, agg_type='max' (explored dim in {32, 64, 128}; one dim kept per setting)\n",
    "get_mse_mae(data='ettm1', type_='XY_ablation_FC2mixer_O_X', agg_type='max', dim=32)\n",
    "\n",
    "get_mse_mae(data='ettm1', type_='XY_ablation_FC2mixer_O_O', agg_type='max', dim=32)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 63,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "==================================================\n",
      "XY_ablation_FC2mixer_O_X maxpool 64\n",
      "==================================================\n",
      "ft=10: mse=0.363,mae=0.399\n",
      "ft=20: mse=0.365,mae=0.4\n",
      "ft=40: mse=0.363,mae=0.398\n",
      "ft=60: mse=0.361,mae=0.397\n",
      "==================================================\n",
      "XY_ablation_FC2mixer_O_X avgpool 64\n",
      "==================================================\n",
      "ft=10: mse=0.367,mae=0.404\n",
      "ft=20: mse=0.365,mae=0.403\n",
      "ft=40: mse=0.361,mae=0.399\n",
      "ft=60: mse=0.37,mae=0.404\n",
      "==================================================\n",
      "XY_ablation_FC2mixer_O_O avgpool 64\n",
      "==================================================\n",
      "ft=10: mse=0.401,mae=0.422\n",
      "ft=20: mse=0.372,mae=0.406\n",
      "ft=40: mse=0.373,mae=0.406\n",
      "ft=60: mse=0.373,mae=0.406\n",
      "==================================================\n",
      "XY_ablation_FC2mixer_O_O maxpool 64\n",
      "==================================================\n",
      "ft=10: mse=0.418,mae=0.434\n",
      "ft=20: mse=0.411,mae=0.428\n",
      "ft=40: mse=0.383,mae=0.413\n",
      "ft=60: mse=0.375,mae=0.409\n"
     ]
    }
   ],
   "source": [
    "# ettm1, FC2mixer head, pooled aggregation (explored dim in {32, 64, 128}; one dim kept per setting)\n",
    "get_mse_mae(data='ettm1', type_='XY_ablation_FC2mixer_O_X', agg_type='maxpool', dim=64)\n",
    "get_mse_mae(data='ettm1', type_='XY_ablation_FC2mixer_O_X', agg_type='avgpool', dim=64)\n",
    "\n",
    "get_mse_mae(data='ettm1', type_='XY_ablation_FC2mixer_O_O', agg_type='avgpool', dim=64)\n",
    "get_mse_mae(data='ettm1', type_='XY_ablation_FC2mixer_O_O', agg_type='maxpool', dim=64)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 121,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "==================================================\n",
      "XY_ablation_FC2mixer_O_X max 64\n",
      "==================================================\n",
      "4\n",
      "ft=10: mse=0.261,mae=0.321\n",
      "4\n",
      "ft=20: mse=0.259,mae=0.318\n",
      "4\n",
      "ft=40: mse=0.259,mae=0.318\n",
      "4\n",
      "ft=60: mse=0.258,mae=0.318\n",
      "==================================================\n",
      "XY_ablation_FC2mixer_O_O max 128\n",
      "==================================================\n",
      "4\n",
      "ft=10: mse=0.253,mae=0.314\n",
      "4\n",
      "ft=20: mse=0.254,mae=0.314\n",
      "4\n",
      "ft=40: mse=0.253,mae=0.313\n",
      "4\n",
      "ft=60: mse=0.253,mae=0.312\n"
     ]
    }
   ],
   "source": [
    "# ettm2, FC2mixer head, agg_type='max' (explored dim in {32, 64, 128}; one dim kept per setting)\n",
    "get_mse_mae(data='ettm2', type_='XY_ablation_FC2mixer_O_X', agg_type='max', dim=64)\n",
    "\n",
    "get_mse_mae(data='ettm2', type_='XY_ablation_FC2mixer_O_O', agg_type='max', dim=128)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "==================================================\n",
      "XY_ablation_FC2mixer_O_X maxpool 32\n",
      "==================================================\n",
      "4\n",
      "ft=10: mse=0.266,mae=0.33\n",
      "4\n",
      "ft=20: mse=0.265,mae=0.329\n",
      "4\n",
      "ft=40: mse=0.266,mae=0.329\n",
      "4\n",
      "ft=60: mse=0.267,mae=0.329\n",
      "==================================================\n",
      "XY_ablation_FC2mixer_O_X avgpool 64\n",
      "==================================================\n",
      "4\n",
      "ft=10: mse=0.264,mae=0.328\n",
      "4\n",
      "ft=20: mse=0.265,mae=0.331\n",
      "4\n",
      "ft=40: mse=0.26,mae=0.325\n",
      "4\n",
      "ft=60: mse=0.264,mae=0.328\n",
      "==================================================\n",
      "XY_ablation_FC2mixer_O_O avgpool 32\n",
      "==================================================\n",
      "4\n",
      "ft=10: mse=0.274,mae=0.335\n",
      "4\n",
      "ft=20: mse=0.269,mae=0.33\n",
      "4\n",
      "ft=40: mse=0.268,mae=0.33\n",
      "4\n",
      "ft=60: mse=0.269,mae=0.331\n",
      "==================================================\n",
      "XY_ablation_FC2mixer_O_O maxpool 128\n",
      "==================================================\n",
      "4\n",
      "ft=10: mse=0.281,mae=0.341\n",
      "4\n",
      "ft=20: mse=0.284,mae=0.345\n",
      "4\n",
      "ft=40: mse=0.285,mae=0.346\n",
      "3\n",
      "ft=60: mse=0.257,mae=0.327\n"
     ]
    }
   ],
   "source": [
    "# ettm2, FC2mixer head, pooled aggregation (explored dim in {32, 64, 128}; one dim kept per setting)\n",
    "# NOTE(review): the O_O maxpool run prints 3 instead of 4 before ft=60 - presumably fewer result files were averaged; TODO confirm before comparing.\n",
    "get_mse_mae(data='ettm2', type_='XY_ablation_FC2mixer_O_X', agg_type='maxpool', dim=32)\n",
    "get_mse_mae(data='ettm2', type_='XY_ablation_FC2mixer_O_X', agg_type='avgpool', dim=64)\n",
    "\n",
    "get_mse_mae(data='ettm2', type_='XY_ablation_FC2mixer_O_O', agg_type='avgpool', dim=32)\n",
    "get_mse_mae(data='ettm2', type_='XY_ablation_FC2mixer_O_O', agg_type='maxpool', dim=128)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Transformer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "==================================================\n",
      "XY_ablation_Transformer_O_X max 128\n",
      "==================================================\n",
      "4\n",
      "ft=10: mse=0.426,mae=0.441\n",
      "4\n",
      "ft=20: mse=0.425,mae=0.439\n",
      "4\n",
      "ft=40: mse=0.425,mae=0.439\n",
      "4\n",
      "ft=60: mse=0.422,mae=0.437\n",
      "==================================================\n",
      "XY_ablation_Transformer_O_O max 128\n",
      "==================================================\n",
      "3\n",
      "ft=10: mse=0.419,mae=0.439\n",
      "4\n",
      "ft=20: mse=0.413,mae=0.432\n",
      "3\n",
      "ft=40: mse=0.411,mae=0.432\n",
      "4\n",
      "ft=60: mse=0.417,mae=0.436\n"
     ]
    }
   ],
   "source": [
    "# etth1, Transformer head, agg_type='max', dim=128\n",
    "# NOTE(review): the O_O run prints 3 instead of 4 before ft=10 and ft=40 - presumably fewer result files were averaged; TODO confirm before comparing.\n",
    "get_mse_mae(data='etth1', type_='XY_ablation_Transformer_O_X', agg_type='max', dim=128)\n",
    "\n",
    "get_mse_mae(data='etth1', type_='XY_ablation_Transformer_O_O', agg_type='max', dim=128)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "==================================================\n",
      "XY_ablation_Transformer_O_X maxpool 128\n",
      "==================================================\n",
      "4\n",
      "ft=10: mse=0.421,mae=0.448\n",
      "3\n",
      "ft=20: mse=0.424,mae=0.45\n",
      "3\n",
      "ft=40: mse=0.434,mae=0.459\n",
      "4\n",
      "ft=60: mse=0.423,mae=0.451\n",
      "==================================================\n",
      "XY_ablation_Transformer_O_X avgpool 128\n",
      "==================================================\n",
      "2\n",
      "ft=10: mse=0.455,mae=0.461\n",
      "3\n",
      "ft=20: mse=0.453,mae=0.464\n",
      "3\n",
      "ft=40: mse=0.425,mae=0.45\n",
      "3\n",
      "ft=60: mse=0.457,mae=0.468\n",
      "==================================================\n",
      "XY_ablation_Transformer_O_O avgpool 128\n",
      "==================================================\n",
      "4\n",
      "ft=10: mse=0.529,mae=0.5\n",
      "4\n",
      "ft=20: mse=0.504,mae=0.479\n",
      "3\n",
      "ft=40: mse=0.501,mae=0.476\n",
      "2\n",
      "ft=60: mse=0.558,mae=0.515\n",
      "==================================================\n",
      "XY_ablation_Transformer_O_O maxpool 128\n",
      "==================================================\n",
      "3\n",
      "ft=10: mse=0.517,mae=0.495\n",
      "4\n",
      "ft=20: mse=0.466,mae=0.478\n",
      "3\n",
      "ft=40: mse=0.462,mae=0.475\n",
      "3\n",
      "ft=60: mse=0.468,mae=0.479\n"
     ]
    }
   ],
   "source": [
    "# Transformer ablation on etth1 with agg_type in {'maxpool', 'avgpool'}, dim=128.\n",
    "# The list order matches the order of the printed output blocks.\n",
    "for type_, agg_type in [\n",
    "    ('XY_ablation_Transformer_O_X', 'maxpool'),\n",
    "    ('XY_ablation_Transformer_O_X', 'avgpool'),\n",
    "    ('XY_ablation_Transformer_O_O', 'avgpool'),\n",
    "    ('XY_ablation_Transformer_O_O', 'maxpool'),\n",
    "]:\n",
    "    get_mse_mae(data='etth1', type_=type_, agg_type=agg_type, dim=128)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "==================================================\n",
      "XY_ablation_Transformer_O_X max 128\n",
      "==================================================\n",
      "4\n",
      "ft=10: mse=0.35,mae=0.393\n",
      "4\n",
      "ft=20: mse=0.352,mae=0.393\n",
      "4\n",
      "ft=40: mse=0.351,mae=0.391\n",
      "4\n",
      "ft=60: mse=0.353,mae=0.394\n",
      "==================================================\n",
      "XY_ablation_Transformer_O_O max 128\n",
      "==================================================\n",
      "4\n",
      "ft=10: mse=0.342,mae=0.388\n",
      "4\n",
      "ft=20: mse=0.346,mae=0.391\n",
      "4\n",
      "ft=40: mse=0.348,mae=0.393\n",
      "4\n",
      "ft=60: mse=0.342,mae=0.387\n"
     ]
    }
   ],
   "source": [
    "# Transformer ablation on etth2: O_X then O_O setting, agg_type='max', dim=128.\n",
    "for type_ in ['XY_ablation_Transformer_O_X', 'XY_ablation_Transformer_O_O']:\n",
    "    get_mse_mae(data='etth2', type_=type_, agg_type='max', dim=128)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 95,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "==================================================\n",
      "XY_ablation_Transformer_O_X maxpool 128\n",
      "==================================================\n",
      "3\n",
      "ft=10: mse=0.397,mae=0.431\n",
      "3\n",
      "ft=20: mse=0.377,mae=0.419\n",
      "3\n",
      "ft=40: mse=0.367,mae=0.411\n",
      "4\n",
      "ft=60: mse=0.373,mae=0.413\n",
      "==================================================\n",
      "XY_ablation_Transformer_O_X avgpool 128\n",
      "==================================================\n",
      "3\n",
      "ft=10: mse=0.393,mae=0.427\n",
      "4\n",
      "ft=20: mse=0.39,mae=0.42\n",
      "4\n",
      "ft=40: mse=0.387,mae=0.419\n",
      "4\n",
      "ft=60: mse=0.376,mae=0.413\n",
      "==================================================\n",
      "XY_ablation_Transformer_O_O avgpool 128\n",
      "==================================================\n",
      "4\n",
      "ft=10: mse=0.414,mae=0.438\n",
      "3\n",
      "ft=20: mse=0.373,mae=0.416\n",
      "2\n",
      "ft=40: mse=0.359,mae=0.406\n",
      "4\n",
      "ft=60: mse=0.384,mae=0.424\n",
      "==================================================\n",
      "XY_ablation_Transformer_O_O maxpool 128\n",
      "==================================================\n",
      "3\n",
      "ft=10: mse=0.388,mae=0.426\n",
      "4\n",
      "ft=20: mse=0.387,mae=0.421\n",
      "4\n",
      "ft=40: mse=0.399,mae=0.426\n",
      "4\n",
      "ft=60: mse=0.395,mae=0.424\n"
     ]
    }
   ],
   "source": [
    "# Transformer ablation on etth2 with agg_type in {'maxpool', 'avgpool'}, dim=128.\n",
    "# The list order matches the order of the printed output blocks.\n",
    "for type_, agg_type in [\n",
    "    ('XY_ablation_Transformer_O_X', 'maxpool'),\n",
    "    ('XY_ablation_Transformer_O_X', 'avgpool'),\n",
    "    ('XY_ablation_Transformer_O_O', 'avgpool'),\n",
    "    ('XY_ablation_Transformer_O_O', 'maxpool'),\n",
    "]:\n",
    "    get_mse_mae(data='etth2', type_=type_, agg_type=agg_type, dim=128)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "==================================================\n",
      "XY_ablation_Transformer_O_X max 128\n",
      "==================================================\n",
      "3\n",
      "ft=10: mse=0.35,mae=0.389\n",
      "4\n",
      "ft=20: mse=0.348,mae=0.385\n",
      "4\n",
      "ft=40: mse=0.35,mae=0.386\n",
      "4\n",
      "ft=60: mse=0.349,mae=0.388\n",
      "==================================================\n",
      "XY_ablation_Transformer_O_O max 128\n",
      "==================================================\n",
      "4\n",
      "ft=10: mse=0.35,mae=0.387\n",
      "4\n",
      "ft=20: mse=0.351,mae=0.387\n",
      "4\n",
      "ft=40: mse=0.35,mae=0.387\n",
      "4\n",
      "ft=60: mse=0.349,mae=0.386\n"
     ]
    }
   ],
   "source": [
    "# Transformer ablation on ettm1: O_X then O_O setting, agg_type='max', dim=128.\n",
    "for type_ in ['XY_ablation_Transformer_O_X', 'XY_ablation_Transformer_O_O']:\n",
    "    get_mse_mae(data='ettm1', type_=type_, agg_type='max', dim=128)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 96,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "==================================================\n",
      "XY_ablation_Transformer_O_X maxpool 128\n",
      "==================================================\n",
      "4\n",
      "ft=10: mse=0.378,mae=0.403\n",
      "4\n",
      "ft=20: mse=0.373,mae=0.401\n",
      "4\n",
      "ft=40: mse=0.376,mae=0.404\n",
      "4\n",
      "ft=60: mse=0.373,mae=0.401\n",
      "==================================================\n",
      "XY_ablation_Transformer_O_X avgpool 128\n",
      "==================================================\n",
      "4\n",
      "ft=10: mse=0.375,mae=0.412\n",
      "4\n",
      "ft=20: mse=0.378,mae=0.412\n",
      "4\n",
      "ft=40: mse=0.372,mae=0.41\n",
      "4\n",
      "ft=60: mse=0.37,mae=0.406\n",
      "==================================================\n",
      "XY_ablation_Transformer_O_O avgpool 128\n",
      "==================================================\n",
      "4\n",
      "ft=10: mse=0.49,mae=0.463\n",
      "3\n",
      "ft=20: mse=0.52,mae=0.483\n",
      "4\n",
      "ft=40: mse=0.455,mae=0.45\n",
      "4\n",
      "ft=60: mse=0.449,mae=0.446\n",
      "==================================================\n",
      "XY_ablation_Transformer_O_O maxpool 128\n",
      "==================================================\n",
      "4\n",
      "ft=10: mse=0.372,mae=0.406\n",
      "3\n",
      "ft=20: mse=0.397,mae=0.419\n",
      "4\n",
      "ft=40: mse=0.38,mae=0.406\n",
      "4\n",
      "ft=60: mse=0.38,mae=0.405\n"
     ]
    }
   ],
   "source": [
    "# Transformer ablation on ettm1 with agg_type in {'maxpool', 'avgpool'}, dim=128.\n",
    "# The list order matches the order of the printed output blocks.\n",
    "for type_, agg_type in [\n",
    "    ('XY_ablation_Transformer_O_X', 'maxpool'),\n",
    "    ('XY_ablation_Transformer_O_X', 'avgpool'),\n",
    "    ('XY_ablation_Transformer_O_O', 'avgpool'),\n",
    "    ('XY_ablation_Transformer_O_O', 'maxpool'),\n",
    "]:\n",
    "    get_mse_mae(data='ettm1', type_=type_, agg_type=agg_type, dim=128)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "==================================================\n",
      "XY_ablation_Transformer_O_X max 128\n",
      "==================================================\n",
      "4\n",
      "ft=10: mse=0.271,mae=0.329\n",
      "4\n",
      "ft=20: mse=0.273,mae=0.33\n",
      "4\n",
      "ft=40: mse=0.274,mae=0.331\n",
      "4\n",
      "ft=60: mse=0.275,mae=0.331\n",
      "==================================================\n",
      "XY_ablation_Transformer_O_O max 128\n",
      "==================================================\n",
      "4\n",
      "ft=10: mse=0.257,mae=0.32\n",
      "3\n",
      "ft=20: mse=0.269,mae=0.327\n",
      "4\n",
      "ft=40: mse=0.256,mae=0.318\n",
      "4\n",
      "ft=60: mse=0.255,mae=0.318\n"
     ]
    }
   ],
   "source": [
    "# Transformer ablation on ettm2: O_X then O_O setting, agg_type='max', dim=128.\n",
    "for type_ in ['XY_ablation_Transformer_O_X', 'XY_ablation_Transformer_O_O']:\n",
    "    get_mse_mae(data='ettm2', type_=type_, agg_type='max', dim=128)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 52,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "==================================================\n",
      "XY_ablation_Transformer_O_X maxpool 128\n",
      "==================================================\n",
      "4\n",
      "ft=10: mse=0.275,mae=0.336\n",
      "4\n",
      "ft=20: mse=0.274,mae=0.333\n",
      "3\n",
      "ft=40: mse=0.304,mae=0.355\n",
      "4\n",
      "ft=60: mse=0.273,mae=0.334\n",
      "==================================================\n",
      "XY_ablation_Transformer_O_X avgpool 128\n",
      "==================================================\n",
      "3\n",
      "ft=10: mse=0.334,mae=0.375\n",
      "4\n",
      "ft=20: mse=0.298,mae=0.349\n",
      "4\n",
      "ft=40: mse=0.291,mae=0.346\n",
      "4\n",
      "ft=60: mse=0.305,mae=0.355\n",
      "==================================================\n",
      "XY_ablation_Transformer_O_O avgpool 128\n",
      "==================================================\n",
      "4\n",
      "ft=10: mse=0.298,mae=0.355\n",
      "3\n",
      "ft=20: mse=0.251,mae=0.325\n",
      "4\n",
      "ft=40: mse=0.296,mae=0.353\n",
      "4\n",
      "ft=60: mse=0.299,mae=0.355\n",
      "==================================================\n",
      "XY_ablation_Transformer_O_O maxpool 128\n",
      "==================================================\n",
      "4\n",
      "ft=10: mse=0.317,mae=0.372\n",
      "3\n",
      "ft=20: mse=0.274,mae=0.345\n",
      "4\n",
      "ft=40: mse=0.304,mae=0.362\n",
      "4\n",
      "ft=60: mse=0.309,mae=0.365\n"
     ]
    }
   ],
   "source": [
    "# Transformer ablation on ettm2 with agg_type in {'maxpool', 'avgpool'}, dim=128.\n",
    "# The list order matches the order of the printed output blocks.\n",
    "for type_, agg_type in [\n",
    "    ('XY_ablation_Transformer_O_X', 'maxpool'),\n",
    "    ('XY_ablation_Transformer_O_X', 'avgpool'),\n",
    "    ('XY_ablation_Transformer_O_O', 'avgpool'),\n",
    "    ('XY_ablation_Transformer_O_O', 'maxpool'),\n",
    "]:\n",
    "    get_mse_mae(data='ettm2', type_=type_, agg_type=agg_type, dim=128)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 53,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Summary of results per backbone (Linear / MLP / TSmixer / Trans),\n",
    "# setting (OO / OX) and aggregation (con / avg / max).\n",
    "# NOTE(review): each array presumably holds one value per dataset in the order\n",
    "# (etth1, etth2, ettm1, ettm2), the order the ablation cells above are run in -- confirm.\n",
    "datasets4_Linear_OO_con = np.array([0.408,0.338,0.358,0.254])\n",
    "datasets4_Linear_OO_avg = np.array([0.642,0.389,0.678,0.339])\n",
    "datasets4_Linear_OO_max = np.array([0.660,0.389,0.695,0.341])\n",
    "\n",
    "datasets4_Linear_OX_con = np.array([0.408,0.341,0.359,0.254])\n",
    "# NOTE(review): identical to datasets4_Linear_OO_avg -- possible copy-paste; verify.\n",
    "datasets4_Linear_OX_avg = np.array([0.642,0.389,0.678,0.339])\n",
    "datasets4_Linear_OX_max = np.array([0.658,0.402,0.695,0.343])\n",
    "\n",
    "######################################################\n",
    "\n",
    "datasets4_MLP_OO_con = np.array([0.407,0.334,0.356,0.253])\n",
    "datasets4_MLP_OO_avg = np.array([0.632,0.392,0.677,0.337])\n",
    "datasets4_MLP_OO_max = np.array([0.651,0.398,0.700,0.339])\n",
    "\n",
    "datasets4_MLP_OX_con = np.array([0.418,0.361,0.356,0.258])\n",
    "datasets4_MLP_OX_avg = np.array([0.645,0.393,0.676,0.339])\n",
    "datasets4_MLP_OX_max = np.array([0.647,0.402,0.703,0.340])\n",
    "\n",
    "######################################################\n",
    "\n",
    "datasets4_TSmixer_OO_con = np.array([0.409,0.352,0.352,0.256])\n",
    "datasets4_TSmixer_OO_avg = np.array([0.429,0.365,0.372,0.268])\n",
    "datasets4_TSmixer_OO_max = np.array([0.430,0.379,0.375,0.257])\n",
    "\n",
    "datasets4_TSmixer_OX_con = np.array([0.420,0.389,0.354,0.263])\n",
    "datasets4_TSmixer_OX_avg = np.array([0.421,0.362,0.361,0.260])\n",
    "datasets4_TSmixer_OX_max = np.array([0.430,0.358,0.361,0.265])\n",
    "\n",
    "######################################################\n",
    "\n",
    "# TODO: Transformer results not filled in yet. NaN placeholders keep the same\n",
    "# length (4) as every other row, so element-wise comparisons or building a table\n",
    "# from these arrays yield NaN instead of failing on a (0,) vs (4,) shape mismatch.\n",
    "datasets4_Trans_OO_con = np.full(4, np.nan)\n",
    "datasets4_Trans_OO_avg = np.full(4, np.nan)\n",
    "datasets4_Trans_OO_max = np.full(4, np.nan)\n",
    "\n",
    "datasets4_Trans_OX_con = np.full(4, np.nan)\n",
    "datasets4_Trans_OX_avg = np.full(4, np.nan)\n",
    "datasets4_Trans_OX_max = np.full(4, np.nan)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "ssl_ts",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.11"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
