{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 48,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import pandas as pd\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['SleepEEG2EMG', 'SleepEEG2Epilepsy', 'SleepEEG2FD_B', 'SleepEEG2Gesture']"
      ]
     },
     "execution_count": 49,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "source_dataset = 'SleepEEG'\n",
    "target_dataset = ['EMG','Epilepsy','FD_B','Gesture']\n",
    "aggregate_type_list = ['max','avg','concat']\n",
    "# one transfer-experiment name per target, e.g. 'SleepEEG2EMG'\n",
    "datasets = [f'{source_dataset}2{name}' for name in target_dataset]\n",
    "datasets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 50,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE: first version of get_result (accuracy only); later cells redefine it.\n",
    "def get_result(source_data, target_data, agg, load_epoch, D, C, P, finetune_epoch, type2):\n",
    "    '''Collect the accuracy of every matching setting, sorted ascending by accuracy.\n",
    "\n",
    "    A setting is a subdirectory of\n",
    "    ../saved_models/{source_data}2{target_data}/masked_patchtst_sim_half_v3_mean_FC2_R/based_model/{agg}\n",
    "    whose name contains '_D{D}_' and 'cw{C}_patch{P}_stride{P}' and does not contain 'tau'.\n",
    "    For each setting the first value of the 'acc' column of its result CSV is read;\n",
    "    settings whose CSV is missing, unreadable or lacks that value are skipped.\n",
    "\n",
    "    type2 == 1 reads the '...type2_acc.csv' file, any other value the '..._acc.csv' file.\n",
    "    Returns {setting_name: acc}. Raises KeyError for an unknown target_data and\n",
    "    FileNotFoundError if the {agg} directory does not exist.\n",
    "    '''\n",
    "    data = f'{source_data}2{target_data}'\n",
    "    DATA_PATH = f'../saved_models/{data}/masked_patchtst_sim_half_v3_mean_FC2_R/based_model/{agg}'\n",
    "\n",
    "    pattern = f'_D{D}_'\n",
    "    pattern2 = f'cw{C}_patch{P}_stride{P}'\n",
    "    settings = os.listdir(DATA_PATH)\n",
    "\n",
    "    settings = [s for s in settings if pattern in s]\n",
    "    settings = [s for s in settings if pattern2 in s]\n",
    "\n",
    "    ################## HARD CL #######################\n",
    "    settings = [s for s in settings if 'tau' not in s]\n",
    "    ##################################################\n",
    "\n",
    "    # per-target value that appears as tw{...} in the result file name\n",
    "    ft_class_dict = {'EMG': 3, 'FD_B': 3, 'Gesture': 8, 'Gesture2': 8, 'Epilepsy': 2}\n",
    "    target = ft_class_dict[target_data]\n",
    "    suffix = 'type2_acc.csv' if type2==1 else '_acc.csv'\n",
    "    file_name = f'tw{target}_ft_ep{finetune_epoch}_model1_load_ep{load_epoch}{suffix}'\n",
    "\n",
    "    result_dict = dict()\n",
    "    for setting in settings:\n",
    "        try:\n",
    "            result = pd.read_csv(os.path.join(DATA_PATH, setting, file_name))\n",
    "            result_dict[setting] = result['acc'][0]\n",
    "        except (OSError, ValueError, KeyError, IndexError):\n",
    "            # Best-effort skip of missing/empty/malformed files (pandas parse errors are\n",
    "            # ValueErrors). A bare except would also hide NameErrors and KeyboardInterrupt.\n",
    "            continue\n",
    "    return dict(sorted(result_dict.items(), key=lambda x: x[1]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 51,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE: second version of get_result (all six metrics); a later cell redefines it.\n",
    "def get_result(source_data, target_data, agg, load_epoch, D, C, P, finetune_epoch, type2):\n",
    "    '''Collect all result metrics of every matching setting.\n",
    "\n",
    "    A setting is a subdirectory of\n",
    "    ../saved_models/{source_data}2{target_data}/masked_patchtst_sim_half_v3_mean_FC2_R/based_model/{agg}\n",
    "    whose name contains '_D{D}_' and 'cw{C}_patch{P}_stride{P}' and does not contain 'tau'.\n",
    "    type2 == 1 reads the '...type2_acc.csv' file, any other value the '..._acc.csv' file.\n",
    "\n",
    "    Returns {setting_name: [acc, weighted_F1, micro_F1, macro_F1, precision, recall]}\n",
    "    (first CSV row) in os.listdir order; settings whose CSV is missing, unreadable or\n",
    "    lacks a column are skipped. Raises KeyError for an unknown target_data and\n",
    "    FileNotFoundError if the {agg} directory does not exist.\n",
    "    '''\n",
    "    data = f'{source_data}2{target_data}'\n",
    "    DATA_PATH = f'../saved_models/{data}/masked_patchtst_sim_half_v3_mean_FC2_R/based_model/{agg}'\n",
    "\n",
    "    pattern = f'_D{D}_'\n",
    "    pattern2 = f'cw{C}_patch{P}_stride{P}'\n",
    "    settings = os.listdir(DATA_PATH)\n",
    "\n",
    "    settings = [s for s in settings if pattern in s]\n",
    "    settings = [s for s in settings if pattern2 in s]\n",
    "\n",
    "    ################## HARD CL #######################\n",
    "    settings = [s for s in settings if 'tau' not in s]\n",
    "    ##################################################\n",
    "\n",
    "    # per-target value that appears as tw{...} in the result file name\n",
    "    ft_class_dict = {'EMG': 3, 'FD_B': 3, 'Gesture': 8, 'Gesture2': 8, 'Epilepsy': 2}\n",
    "    target = ft_class_dict[target_data]\n",
    "    suffix = 'type2_acc.csv' if type2==1 else '_acc.csv'\n",
    "    file_name = f'tw{target}_ft_ep{finetune_epoch}_model1_load_ep{load_epoch}{suffix}'\n",
    "    metric_cols = ('acc', 'weighted_F1', 'micro_F1', 'macro_F1', 'precision', 'recall')\n",
    "\n",
    "    result_dict = dict()\n",
    "    for setting in settings:\n",
    "        try:\n",
    "            result = pd.read_csv(os.path.join(DATA_PATH, setting, file_name))\n",
    "            result_dict[setting] = [result[col][0] for col in metric_cols]\n",
    "        except (OSError, ValueError, KeyError, IndexError):\n",
    "            # Best-effort skip of missing/empty/malformed files. Deliberately not a bare\n",
    "            # except: that hid a leftover debug NameError and dropped every setting.\n",
    "            continue\n",
    "    return result_dict"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 52,
   "metadata": {},
   "outputs": [],
   "source": [
    "# same source dataset as source_dataset in the setup cell (single source of truth)\n",
    "source_data = source_dataset\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 53,
   "metadata": {},
   "outputs": [],
   "source": [
    "#arch1 = [4,16,128]\n",
    "#arch2 = [8,128,256]\n",
    "#arch3 = [16,128,512]\n",
    "\n",
    "# D values to sweep (matched as '_D{D}_' in setting directory names): 32, 64, 128, 256\n",
    "arch_list = [2 ** k for k in range(5, 9)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 54,
   "metadata": {},
   "outputs": [],
   "source": [
    "# [P, C] pairs, unpacked as P, C in the sweeps and matched as 'cw{C}_patch{P}_stride{P}'\n",
    "data_struc1, data_struc2, data_struc3 = [16,176], [8,176], [4,176]\n",
    "\n",
    "data_struc_list = [data_struc1, data_struc2, data_struc3]\n",
    "#data_struc_list = [data_struc2]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 55,
   "metadata": {},
   "outputs": [],
   "source": [
    "finetune_epoch_list = list(range(100, 401, 100))  # 100, 200, 300, 400\n",
    "\n",
    "load_epoch_list = list(range(20, 101, 20))  # 20, 40, 60, 80, 100"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 56,
   "metadata": {},
   "outputs": [],
   "source": [
    "target_data_list = 'Epilepsy FD_B Gesture EMG'.split()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 57,
   "metadata": {},
   "outputs": [],
   "source": [
    "# aggregation variants, used as the last directory level under based_model/\n",
    "ag_list = 'concat avg max'.split()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 58,
   "metadata": {},
   "outputs": [],
   "source": [
    "# type2 flag passed to get_result: 0 -> '..._acc.csv', 1 -> '...type2_acc.csv'\n",
    "type_list = list(range(2))"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 반드시 같아야 함 (must be identical between the runs being compared):\n",
    "- ag_list\n",
    "- load_epoch_list\n",
    "- data_struc_list\n",
    "- arch_list\n",
    "- type_list (passed to get_result as type2)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 달라도 OK (may differ between runs):\n",
    "- target_data_list\n",
    "- finetune_epoch_list"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 59,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_result(source_data, target_data, agg, load_epoch, D, C, P, finetune_epoch, type2):\n",
    "    '''Collect all result metrics of every matching setting.\n",
    "\n",
    "    A setting is a subdirectory of\n",
    "    ../saved_models/{source_data}2{target_data}/masked_patchtst_sim_half_v3_mean_FC2_R/based_model/{agg}\n",
    "    whose name contains '_D{D}_' and 'cw{C}_patch{P}_stride{P}' and does not contain 'tau'.\n",
    "    type2 == 1 reads the '...type2_acc.csv' file, any other value the '..._acc.csv' file.\n",
    "\n",
    "    Returns {setting_name: [acc, weighted_F1, micro_F1, macro_F1, precision, recall]}\n",
    "    (first CSV row) in os.listdir order; settings whose CSV is missing, unreadable or\n",
    "    lacks a column are skipped. Raises KeyError for an unknown target_data and\n",
    "    FileNotFoundError if the {agg} directory does not exist.\n",
    "    '''\n",
    "    data = f'{source_data}2{target_data}'\n",
    "    DATA_PATH = f'../saved_models/{data}/masked_patchtst_sim_half_v3_mean_FC2_R/based_model/{agg}'\n",
    "\n",
    "    pattern = f'_D{D}_'\n",
    "    pattern2 = f'cw{C}_patch{P}_stride{P}'\n",
    "    # HARD CL: settings containing 'tau' are excluded\n",
    "    settings = [s for s in os.listdir(DATA_PATH)\n",
    "                if pattern in s and pattern2 in s and 'tau' not in s]\n",
    "\n",
    "    # per-target value that appears as tw{...} in the result file name\n",
    "    ft_class_dict = {'EMG': 3, 'FD_B': 3, 'Gesture': 8, 'Gesture2': 8, 'Epilepsy': 2}\n",
    "    target = ft_class_dict[target_data]\n",
    "    suffix = 'type2_acc.csv' if type2==1 else '_acc.csv'\n",
    "    file_name = f'tw{target}_ft_ep{finetune_epoch}_model1_load_ep{load_epoch}{suffix}'\n",
    "    metric_cols = ('acc', 'weighted_F1', 'micro_F1', 'macro_F1', 'precision', 'recall')\n",
    "\n",
    "    result_dict = dict()\n",
    "    for setting in settings:\n",
    "        try:\n",
    "            result = pd.read_csv(os.path.join(DATA_PATH, setting, file_name))\n",
    "            result_dict[setting] = [result[col][0] for col in metric_cols]\n",
    "        except (OSError, ValueError, KeyError, IndexError):\n",
    "            # Best-effort skip of missing/empty/malformed files (pandas parse errors are\n",
    "            # ValueErrors). Not a bare except, so typos and KeyboardInterrupt surface.\n",
    "            continue\n",
    "    return result_dict"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 60,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "\"\\n#target_data_list = ['Epilepsy','FD_B','Gesture','EMG']\\n#target_data_list = ['Epilepsy','FD_B','EMG']\\ntarget_data_list = ['Epilepsy','FD_B','Gesture','EMG']\\n\\nag_list = ['concat','avg','max']\\n#ag_list = ['max']\\n\\n\\nbest_acc_summary = 0\\nbest_acc = 0\\n\\nacc_list = []\\nacc_total_list = []\\nstruc_list = []\\n\\nfor ag in ag_list:\\n    print('='*50)\\n    print('='*50)\\n    print(ag)\\n    print('='*50)\\n    print('='*50)\\n    for ep in load_epoch_list:\\n        for type_ in type_list:\\n            for D in arch_list:\\n                for data_struc in data_struc_list:\\n                    P,C = data_struc\\n                    S = P\\n                    num_patch = int(C/S)\\n                    print(ag, ep, D, data_struc)\\n                    acc_total = []\\n                    for target_data in target_data_list:\\n                        acc = 0\\n                        for ft_epoch in finetune_epoch_list:\\n                            #------------------------------------#\\n                            result_concat = get_result(source_data, target_data, agg=ag, load_epoch=ep,\\n                                                        D=D, C=C, P=P,finetune_epoch=ft_epoch, type2=type_)\\n                            \\n                            try:\\n                                if list(result_concat.values())[-1] > acc:\\n                                    acc = list(result_concat.values())[-1]\\n                                #print(list(result_concat.values())[-1].round(3))\\n                            except:\\n                                pass\\n                        acc_total.append(acc)\\n                    summary = np.mean(acc_total)\\n                    print(acc_total, '-------------', summary.round(3))\\n                    acc_list.append(summary)\\n                    acc_total_list.append(acc_total)\\n                    struc_list.append([ag, ep, type_, D, data_struc])\\n                    
\\n                    if best_acc_summary<summary:\\n                        best_acc_summary = summary\\n                        best_acc = acc_total\\n                        best_struc = [ag, ep, D, data_struc]\\n                    \\n                    print('--------------')\\n\""
      ]
     },
     "execution_count": 60,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Disabled earlier sweep kept as a string literal: running this cell only echoes the text.\n",
    "'''\n",
    "#target_data_list = ['Epilepsy','FD_B','Gesture','EMG']\n",
    "#target_data_list = ['Epilepsy','FD_B','EMG']\n",
    "target_data_list = ['Epilepsy','FD_B','Gesture','EMG']\n",
    "\n",
    "ag_list = ['concat','avg','max']\n",
    "#ag_list = ['max']\n",
    "\n",
    "\n",
    "best_acc_summary = 0\n",
    "best_acc = 0\n",
    "\n",
    "acc_list = []\n",
    "acc_total_list = []\n",
    "struc_list = []\n",
    "\n",
    "for ag in ag_list:\n",
    "    print('='*50)\n",
    "    print('='*50)\n",
    "    print(ag)\n",
    "    print('='*50)\n",
    "    print('='*50)\n",
    "    for ep in load_epoch_list:\n",
    "        for type_ in type_list:\n",
    "            for D in arch_list:\n",
    "                for data_struc in data_struc_list:\n",
    "                    P,C = data_struc\n",
    "                    S = P\n",
    "                    num_patch = int(C/S)\n",
    "                    print(ag, ep, D, data_struc)\n",
    "                    acc_total = []\n",
    "                    for target_data in target_data_list:\n",
    "                        acc = 0\n",
    "                        for ft_epoch in finetune_epoch_list:\n",
    "                            #------------------------------------#\n",
    "                            result_concat = get_result(source_data, target_data, agg=ag, load_epoch=ep,\n",
    "                                                        D=D, C=C, P=P,finetune_epoch=ft_epoch, type2=type_)\n",
    "                            \n",
    "                            try:\n",
    "                                if list(result_concat.values())[-1] > acc:\n",
    "                                    acc = list(result_concat.values())[-1]\n",
    "                                #print(list(result_concat.values())[-1].round(3))\n",
    "                            except:\n",
    "                                pass\n",
    "                        acc_total.append(acc)\n",
    "                    summary = np.mean(acc_total)\n",
    "                    print(acc_total, '-------------', summary.round(3))\n",
    "                    acc_list.append(summary)\n",
    "                    acc_total_list.append(acc_total)\n",
    "                    struc_list.append([ag, ep, type_, D, data_struc])\n",
    "                    \n",
    "                    if best_acc_summary<summary:\n",
    "                        best_acc_summary = summary\n",
    "                        best_acc = acc_total\n",
    "                        best_struc = [ag, ep, D, data_struc]\n",
    "                    \n",
    "                    print('--------------')\n",
    "'''                    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 61,
   "metadata": {},
   "outputs": [],
   "source": [
    "# [P, C] pairs (re-declared with the same values as the earlier data_struc cell)\n",
    "data_struc1, data_struc2, data_struc3 = [16,176], [8,176], [4,176]\n",
    "\n",
    "data_struc_list = [data_struc1, data_struc2, data_struc3]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 62,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "==================================================\n",
      "==================================================\n",
      "concat\n",
      "==================================================\n",
      "==================================================\n"
     ]
    }
   ],
   "source": [
    "#target_data_list = ['Epilepsy','FD_B','Gesture','EMG']\n",
    "#target_data_list = ['Epilepsy','FD_B','EMG']\n",
    "target_data_list = ['Epilepsy','FD_B','Gesture','EMG']\n",
    "\n",
    "#ag_list = ['concat','avg','max']\n",
    "ag_list = ['concat']\n",
    "#ag_list = ['max']\n",
    "\n",
    "# Sweep: for every (ag, load epoch, type, D, [P, C]) keep, per evaluated target, the\n",
    "# metrics of the highest-accuracy run over all fine-tune epochs and all matching\n",
    "# setting directories; one row per configuration is collected into data3.\n",
    "eval_targets = target_data_list[2:3]  # currently ['Gesture']\n",
    "VERBOSE = False  # True -> print configurations that beat the adjusted SOTA somewhere\n",
    "\n",
    "best_acc_summary = 0\n",
    "best_acc = 0\n",
    "\n",
    "acc_list = []\n",
    "weighted_F1_list = []\n",
    "micro_F1_list = []\n",
    "macro_F1_list = []\n",
    "precision_list = []\n",
    "recall_list = []\n",
    "\n",
    "acc_total_list = []\n",
    "weighted_F1_total_list = []\n",
    "micro_F1_total_list = []\n",
    "macro_F1_total_list = []\n",
    "precision_total_list = []\n",
    "recall_total_list = []\n",
    "\n",
    "struc_list = []\n",
    "\n",
    "# Reference (SOTA) scores in percent; assumed to follow target_data_list order\n",
    "# (Epilepsy, FD_B, Gesture, EMG) -- TODO confirm\n",
    "SOTA_acc = np.array([95.49,69.40,80.00,97.56])\n",
    "SOTA_f1 = np.array([92.81,75.11,78.67,98.14])\n",
    "SOTA_prec = np.array([94.56,75.59,79.03,98.33])\n",
    "SOTA_rec = np.array([92.28,76.41,80.00,98.04])\n",
    "\n",
    "SOTA_acc /= 100\n",
    "SOTA_f1 /= 100\n",
    "SOTA_prec /= 100\n",
    "SOTA_rec /= 100\n",
    "\n",
    "# 1 percentage-point margin on each reference metric (previously SOTA_acc was\n",
    "# lowered twice and SOTA_rec never)\n",
    "SOTA_acc -= 0.01\n",
    "SOTA_f1 -= 0.01\n",
    "SOTA_prec -= 0.01\n",
    "SOTA_rec -= 0.01\n",
    "\n",
    "# compare each evaluated target only with its own reference entry\n",
    "sota_idx = [target_data_list.index(t) for t in eval_targets]\n",
    "\n",
    "for ag in ag_list:\n",
    "    print('='*50)\n",
    "    print('='*50)\n",
    "    print(ag)\n",
    "    print('='*50)\n",
    "    print('='*50)\n",
    "    for ep in load_epoch_list:\n",
    "        for type_ in type_list:\n",
    "            for D in arch_list:\n",
    "                for data_struc in data_struc_list:\n",
    "                    P,C = data_struc\n",
    "                    acc_total = []\n",
    "                    weighted_F1_total = []\n",
    "                    micro_F1_total = []\n",
    "                    macro_F1_total = []\n",
    "                    precision_total = []\n",
    "                    recall_total = []\n",
    "                    for target_data in eval_targets:\n",
    "                        # stays all-zero when no result file exists for this target\n",
    "                        acc = weighted_F1 = micro_F1 = macro_F1 = precision = recall = 0\n",
    "                        for ft_epoch in finetune_epoch_list:\n",
    "                            result_concat = get_result(source_data, target_data, agg=ag, load_epoch=ep,\n",
    "                                                        D=D, C=C, P=P,finetune_epoch=ft_epoch, type2=type_)\n",
    "                            # every matching setting is considered, not just the first one listed\n",
    "                            for results in result_concat.values():\n",
    "                                if results[0] > acc:\n",
    "                                    acc,weighted_F1,micro_F1,macro_F1,precision,recall = results\n",
    "\n",
    "                        acc_total.append(acc)\n",
    "                        weighted_F1_total.append(weighted_F1)\n",
    "                        micro_F1_total.append(micro_F1)\n",
    "                        macro_F1_total.append(macro_F1)\n",
    "                        precision_total.append(precision)\n",
    "                        recall_total.append(recall)\n",
    "\n",
    "                    acc_summary = np.mean(acc_total)\n",
    "                    weighted_F1_summary = np.mean(weighted_F1_total)\n",
    "                    micro_F1_summary = np.mean(micro_F1_total)\n",
    "                    macro_F1_summary = np.mean(macro_F1_total)\n",
    "                    precision_summary = np.mean(precision_total)\n",
    "                    recall_summary = np.mean(recall_total)\n",
    "\n",
    "                    acc_diff = np.array(acc_total) - SOTA_acc[sota_idx]\n",
    "                    #################################################################\n",
    "                    f1_diff = np.array(weighted_F1_total) - SOTA_f1[sota_idx]\n",
    "                    #f1_diff = np.array(micro_F1_total) - SOTA_f1[sota_idx]\n",
    "                    #f1_diff = np.array(macro_F1_total) - SOTA_f1[sota_idx]\n",
    "                    #################################################################\n",
    "                    prec_diff = np.array(precision_total) - SOTA_prec[sota_idx]\n",
    "                    rec_diff = np.array(recall_total) - SOTA_rec[sota_idx]\n",
    "                    # number of (target, metric) pairs above the adjusted SOTA\n",
    "                    plus_count = int(np.sum(acc_diff>0) + np.sum(f1_diff>0)\n",
    "                                     + np.sum(prec_diff>0) + np.sum(rec_diff>0))\n",
    "                    if VERBOSE and plus_count>0:\n",
    "                        print(ag, ep, D, type_, data_struc)\n",
    "                        print(acc_total, acc_diff,'-------------', acc_summary.round(3))\n",
    "                        print(precision_total, prec_diff, '-------------', precision_summary.round(3))\n",
    "                        print(recall_total, rec_diff,'-------------', recall_summary.round(3))\n",
    "                        print(weighted_F1_total, f1_diff,'-------------', weighted_F1_summary.round(3))\n",
    "                        print(plus_count)\n",
    "\n",
    "                    acc_list.append(acc_summary)\n",
    "                    weighted_F1_list.append(weighted_F1_summary)\n",
    "                    micro_F1_list.append(micro_F1_summary)\n",
    "                    macro_F1_list.append(macro_F1_summary)\n",
    "                    precision_list.append(precision_summary)\n",
    "                    recall_list.append(recall_summary)\n",
    "\n",
    "                    acc_total_list.append(acc_total)\n",
    "                    weighted_F1_total_list.append(weighted_F1_total)\n",
    "                    micro_F1_total_list.append(micro_F1_total)\n",
    "                    macro_F1_total_list.append(macro_F1_total)\n",
    "                    precision_total_list.append(precision_total)\n",
    "                    recall_total_list.append(recall_total)\n",
    "\n",
    "                    struc_list.append([ag, ep, type_, D, data_struc])\n",
    "\n",
    "                    if best_acc_summary<acc_summary:\n",
    "                        best_acc_summary = acc_summary\n",
    "                        best_acc = acc_total\n",
    "                        best_struc = [ag, ep, D, data_struc]\n",
    "\n",
    "data3 = pd.DataFrame({'structure':struc_list,\n",
    "              'acc':acc_total_list,\n",
    "              'F1':weighted_F1_total_list,\n",
    "              'prec':precision_total_list,\n",
    "              'rec':recall_total_list})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 63,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "==================================================\n",
      "==================================================\n",
      "max\n",
      "==================================================\n",
      "==================================================\n"
     ]
    }
   ],
   "source": [
    "#target_data_list = ['Epilepsy','FD_B','Gesture','EMG']\n",
    "#target_data_list = ['Epilepsy','FD_B','EMG']\n",
    "target_data_list = ['Epilepsy','FD_B','Gesture','EMG']\n",
    "\n",
    "#ag_list = ['concat','avg','max']\n",
    "ag_list = ['max']\n",
    "\n",
    "# Sweep: for every (ag, load epoch, type, D, [P, C]) keep, per evaluated target, the\n",
    "# metrics of the highest-accuracy run over all fine-tune epochs and all matching\n",
    "# setting directories; one row per configuration is collected into data2.\n",
    "eval_targets = target_data_list[1:2]  # currently ['FD_B']\n",
    "VERBOSE = False  # True -> print configurations that beat the adjusted SOTA somewhere\n",
    "\n",
    "best_acc_summary = 0\n",
    "best_acc = 0\n",
    "\n",
    "acc_list = []\n",
    "weighted_F1_list = []\n",
    "micro_F1_list = []\n",
    "macro_F1_list = []\n",
    "precision_list = []\n",
    "recall_list = []\n",
    "\n",
    "acc_total_list = []\n",
    "weighted_F1_total_list = []\n",
    "micro_F1_total_list = []\n",
    "macro_F1_total_list = []\n",
    "precision_total_list = []\n",
    "recall_total_list = []\n",
    "\n",
    "struc_list = []\n",
    "\n",
    "# Reference (SOTA) scores in percent; assumed to follow target_data_list order\n",
    "# (Epilepsy, FD_B, Gesture, EMG) -- TODO confirm\n",
    "SOTA_acc = np.array([95.49,69.40,80.00,97.56])\n",
    "SOTA_f1 = np.array([92.81,75.11,78.67,98.14])\n",
    "SOTA_prec = np.array([94.56,75.59,79.03,98.33])\n",
    "SOTA_rec = np.array([92.28,76.41,80.00,98.04])\n",
    "\n",
    "SOTA_acc /= 100\n",
    "SOTA_f1 /= 100\n",
    "SOTA_prec /= 100\n",
    "SOTA_rec /= 100\n",
    "\n",
    "# 1 percentage-point margin on each reference metric (previously SOTA_acc was\n",
    "# lowered twice and SOTA_rec never)\n",
    "SOTA_acc -= 0.01\n",
    "SOTA_f1 -= 0.01\n",
    "SOTA_prec -= 0.01\n",
    "SOTA_rec -= 0.01\n",
    "\n",
    "# compare each evaluated target only with its own reference entry\n",
    "sota_idx = [target_data_list.index(t) for t in eval_targets]\n",
    "\n",
    "for ag in ag_list:\n",
    "    print('='*50)\n",
    "    print('='*50)\n",
    "    print(ag)\n",
    "    print('='*50)\n",
    "    print('='*50)\n",
    "    for ep in load_epoch_list:\n",
    "        for type_ in type_list:\n",
    "            for D in arch_list:\n",
    "                for data_struc in data_struc_list:\n",
    "                    P,C = data_struc\n",
    "                    acc_total = []\n",
    "                    weighted_F1_total = []\n",
    "                    micro_F1_total = []\n",
    "                    macro_F1_total = []\n",
    "                    precision_total = []\n",
    "                    recall_total = []\n",
    "                    for target_data in eval_targets:\n",
    "                        # stays all-zero when no result file exists for this target\n",
    "                        acc = weighted_F1 = micro_F1 = macro_F1 = precision = recall = 0\n",
    "                        for ft_epoch in finetune_epoch_list:\n",
    "                            result_concat = get_result(source_data, target_data, agg=ag, load_epoch=ep,\n",
    "                                                        D=D, C=C, P=P,finetune_epoch=ft_epoch, type2=type_)\n",
    "                            # every matching setting is considered, not just the first one listed\n",
    "                            for results in result_concat.values():\n",
    "                                if results[0] > acc:\n",
    "                                    acc,weighted_F1,micro_F1,macro_F1,precision,recall = results\n",
    "\n",
    "                        acc_total.append(acc)\n",
    "                        weighted_F1_total.append(weighted_F1)\n",
    "                        micro_F1_total.append(micro_F1)\n",
    "                        macro_F1_total.append(macro_F1)\n",
    "                        precision_total.append(precision)\n",
    "                        recall_total.append(recall)\n",
    "\n",
    "                    acc_summary = np.mean(acc_total)\n",
    "                    weighted_F1_summary = np.mean(weighted_F1_total)\n",
    "                    micro_F1_summary = np.mean(micro_F1_total)\n",
    "                    macro_F1_summary = np.mean(macro_F1_total)\n",
    "                    precision_summary = np.mean(precision_total)\n",
    "                    recall_summary = np.mean(recall_total)\n",
    "\n",
    "                    acc_diff = np.array(acc_total) - SOTA_acc[sota_idx]\n",
    "                    #################################################################\n",
    "                    f1_diff = np.array(weighted_F1_total) - SOTA_f1[sota_idx]\n",
    "                    #f1_diff = np.array(micro_F1_total) - SOTA_f1[sota_idx]\n",
    "                    #f1_diff = np.array(macro_F1_total) - SOTA_f1[sota_idx]\n",
    "                    #################################################################\n",
    "                    prec_diff = np.array(precision_total) - SOTA_prec[sota_idx]\n",
    "                    rec_diff = np.array(recall_total) - SOTA_rec[sota_idx]\n",
    "                    # number of (target, metric) pairs above the adjusted SOTA\n",
    "                    plus_count = int(np.sum(acc_diff>0) + np.sum(f1_diff>0)\n",
    "                                     + np.sum(prec_diff>0) + np.sum(rec_diff>0))\n",
    "                    if VERBOSE and plus_count>0:\n",
    "                        print(ag, ep, D, type_, data_struc)\n",
    "                        print(acc_total, acc_diff,'-------------', acc_summary.round(3))\n",
    "                        print(precision_total, prec_diff, '-------------', precision_summary.round(3))\n",
    "                        print(recall_total, rec_diff,'-------------', recall_summary.round(3))\n",
    "                        print(weighted_F1_total, f1_diff,'-------------', weighted_F1_summary.round(3))\n",
    "                        print(plus_count)\n",
    "\n",
    "                    acc_list.append(acc_summary)\n",
    "                    weighted_F1_list.append(weighted_F1_summary)\n",
    "                    micro_F1_list.append(micro_F1_summary)\n",
    "                    macro_F1_list.append(macro_F1_summary)\n",
    "                    precision_list.append(precision_summary)\n",
    "                    recall_list.append(recall_summary)\n",
    "\n",
    "                    acc_total_list.append(acc_total)\n",
    "                    weighted_F1_total_list.append(weighted_F1_total)\n",
    "                    micro_F1_total_list.append(micro_F1_total)\n",
    "                    macro_F1_total_list.append(macro_F1_total)\n",
    "                    precision_total_list.append(precision_total)\n",
    "                    recall_total_list.append(recall_total)\n",
    "\n",
    "                    struc_list.append([ag, ep, type_, D, data_struc])\n",
    "\n",
    "                    if best_acc_summary<acc_summary:\n",
    "                        best_acc_summary = acc_summary\n",
    "                        best_acc = acc_total\n",
    "                        best_struc = [ag, ep, D, data_struc]\n",
    "\n",
    "data2 = pd.DataFrame({'structure':struc_list,\n",
    "              'acc':acc_total_list,\n",
    "              'F1':weighted_F1_total_list,\n",
    "              'prec':precision_total_list,\n",
    "              'rec':recall_total_list})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 64,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "==================================================\n",
      "==================================================\n",
      "max\n",
      "==================================================\n",
      "==================================================\n"
     ]
    }
   ],
   "source": [
    "# Best fine-tuning result per configuration for target_data_list[0:1] (Epilepsy); collected into data1.\n",
    "target_data_list = ['Epilepsy','FD_B','Gesture','EMG']\n",
    "# Targets evaluated by this cell; the same slice selects the matching SOTA entries below.\n",
    "target_slice = slice(0, 1)\n",
    "\n",
    "ag_list = ['max']  # aggregation types to scan; options: 'concat', 'avg', 'max'\n",
    "\n",
    "# Set to True to print every configuration that beats the adjusted SOTA on at least one metric.\n",
    "show_sota_beats = False\n",
    "\n",
    "best_acc_summary = 0\n",
    "best_acc = 0\n",
    "\n",
    "# Per-configuration mean over the evaluated targets.\n",
    "acc_list = []\n",
    "weighted_F1_list = []\n",
    "micro_F1_list = []\n",
    "macro_F1_list = []\n",
    "precision_list = []\n",
    "recall_list = []\n",
    "\n",
    "# Per-configuration list holding one value per evaluated target.\n",
    "acc_total_list = []\n",
    "weighted_F1_total_list = []\n",
    "micro_F1_total_list = []\n",
    "macro_F1_total_list = []\n",
    "precision_total_list = []\n",
    "recall_total_list = []\n",
    "\n",
    "struc_list = []\n",
    "\n",
    "# Reference (SOTA) scores in percent, assumed to be ordered like target_data_list.\n",
    "SOTA_acc = np.array([95.49,69.40,80.00,97.56])\n",
    "SOTA_f1 = np.array([92.81,75.11,78.67,98.14])\n",
    "SOTA_prec = np.array([94.56,75.59,79.03,98.33])\n",
    "SOTA_rec = np.array([92.28,76.41,80.00,98.04])\n",
    "\n",
    "SOTA_acc /= 100\n",
    "SOTA_f1 /= 100\n",
    "SOTA_prec /= 100\n",
    "SOTA_rec /= 100\n",
    "\n",
    "# Lower every reference by one percentage point before comparing.\n",
    "SOTA_acc -= 0.01\n",
    "SOTA_f1 -= 0.01\n",
    "SOTA_prec -= 0.01\n",
    "SOTA_rec -= 0.01\n",
    "\n",
    "for ag in ag_list:\n",
    "    print('='*50)\n",
    "    print('='*50)\n",
    "    print(ag)\n",
    "    print('='*50)\n",
    "    print('='*50)\n",
    "    for ep in load_epoch_list:\n",
    "        for type_ in type_list:\n",
    "            for D in arch_list:\n",
    "                for data_struc in data_struc_list:\n",
    "                    P,C = data_struc\n",
    "                    S = P\n",
    "                    num_patch = int(C/S)\n",
    "                    acc_total = []\n",
    "                    weighted_F1_total = []\n",
    "                    micro_F1_total = []\n",
    "                    macro_F1_total = []\n",
    "                    precision_total = []\n",
    "                    recall_total = []\n",
    "                    for target_data in target_data_list[target_slice]:\n",
    "                        # Keep the metrics of the fine-tuning epoch with the highest accuracy.\n",
    "                        acc = 0\n",
    "                        weighted_F1 = 0\n",
    "                        micro_F1 = 0\n",
    "                        macro_F1 = 0\n",
    "                        precision = 0\n",
    "                        recall = 0\n",
    "                        for ft_epoch in finetune_epoch_list:\n",
    "                            result_concat = get_result(source_data, target_data, agg=ag, load_epoch=ep,\n",
    "                                                        D=D, C=C, P=P,finetune_epoch=ft_epoch, type2=type_)\n",
    "                            try:\n",
    "                                results = list(result_concat.values())[0]\n",
    "                                acc_temp,weighted_F1_temp,micro_F1_temp,macro_F1_temp,precision_temp,recall_temp = results\n",
    "                                if acc_temp > acc:\n",
    "                                    acc = acc_temp\n",
    "                                    weighted_F1 = weighted_F1_temp\n",
    "                                    micro_F1 = micro_F1_temp\n",
    "                                    macro_F1 = macro_F1_temp\n",
    "                                    precision = precision_temp\n",
    "                                    recall = recall_temp\n",
    "                            except (AttributeError, IndexError, TypeError, ValueError):\n",
    "                                # No usable result for this fine-tuning epoch (nothing found or an\n",
    "                                # unexpected format): skip it, but let interrupts and other errors surface.\n",
    "                                continue\n",
    "                        acc_total.append(acc)\n",
    "                        weighted_F1_total.append(weighted_F1)\n",
    "                        micro_F1_total.append(micro_F1)\n",
    "                        macro_F1_total.append(macro_F1)\n",
    "                        precision_total.append(precision)\n",
    "                        recall_total.append(recall)\n",
    "                    \n",
    "                    acc_summary = np.mean(acc_total)\n",
    "                    weighted_F1_summary = np.mean(weighted_F1_total)\n",
    "                    micro_F1_summary = np.mean(micro_F1_total)\n",
    "                    macro_F1_summary = np.mean(macro_F1_total)\n",
    "                    precision_summary = np.mean(precision_total)\n",
    "                    recall_summary = np.mean(recall_total)\n",
    "                    # Compare only against the SOTA entries of the evaluated targets; subtracting the\n",
    "                    # full 4-element arrays would broadcast these targets against every reference.\n",
    "                    acc_diff = acc_total - SOTA_acc[target_slice]\n",
    "                    # weighted F1 is compared; use micro_F1_total / macro_F1_total for the other variants\n",
    "                    f1_diff = weighted_F1_total - SOTA_f1[target_slice]\n",
    "                    prec_diff = precision_total - SOTA_prec[target_slice]\n",
    "                    rec_diff = recall_total - SOTA_rec[target_slice]\n",
    "                    # Number of (target, metric) pairs that beat the adjusted SOTA.\n",
    "                    plus_count = 0\n",
    "                    plus_count += np.sum(acc_diff>0)\n",
    "                    plus_count += np.sum(f1_diff>0)\n",
    "                    plus_count += np.sum(prec_diff>0)\n",
    "                    plus_count += np.sum(rec_diff>0)\n",
    "                    if show_sota_beats and plus_count>0:\n",
    "                        print(ag, ep, D, type_, data_struc)\n",
    "                        print(acc_total, acc_diff,'-------------', acc_summary.round(3))\n",
    "                        print(precision_total, prec_diff, '-------------', precision_summary.round(3))\n",
    "                        print(recall_total, rec_diff,'-------------', recall_summary.round(3))\n",
    "                        print(weighted_F1_total, f1_diff,'-------------', weighted_F1_summary.round(3))\n",
    "                        print(plus_count)\n",
    "                    \n",
    "                    acc_list.append(acc_summary)\n",
    "                    weighted_F1_list.append(weighted_F1_summary)\n",
    "                    micro_F1_list.append(micro_F1_summary)\n",
    "                    macro_F1_list.append(macro_F1_summary)\n",
    "                    precision_list.append(precision_summary)\n",
    "                    recall_list.append(recall_summary)\n",
    "                    \n",
    "                    acc_total_list.append(acc_total)\n",
    "                    weighted_F1_total_list.append(weighted_F1_total)\n",
    "                    micro_F1_total_list.append(micro_F1_total)\n",
    "                    macro_F1_total_list.append(macro_F1_total)\n",
    "                    precision_total_list.append(precision_total)\n",
    "                    recall_total_list.append(recall_total)\n",
    "                    \n",
    "                    struc_list.append([ag, ep, type_, D, data_struc])\n",
    "                    \n",
    "                    if best_acc_summary<acc_summary:\n",
    "                        best_acc_summary = acc_summary\n",
    "                        best_acc = acc_total\n",
    "                        best_struc = [ag, ep, D, data_struc]\n",
    "\n",
    "data1 = pd.DataFrame({'structure':struc_list,\n",
    "              'acc':acc_total_list,\n",
    "              'F1':weighted_F1_total_list,\n",
    "              'prec':precision_total_list,\n",
    "              'rec':recall_total_list})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 65,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "==================================================\n",
      "==================================================\n",
      "max\n",
      "==================================================\n",
      "==================================================\n"
     ]
    }
   ],
   "source": [
    "# Best fine-tuning result per configuration for target_data_list[3:4] (EMG); collected into data4.\n",
    "target_data_list = ['Epilepsy','FD_B','Gesture','EMG']\n",
    "# Targets evaluated by this cell; the same slice selects the matching SOTA entries below.\n",
    "target_slice = slice(3, 4)\n",
    "\n",
    "ag_list = ['max']  # aggregation types to scan; options: 'concat', 'avg', 'max'\n",
    "\n",
    "# Set to True to print every configuration that beats the adjusted SOTA on at least one metric.\n",
    "show_sota_beats = False\n",
    "\n",
    "best_acc_summary = 0\n",
    "best_acc = 0\n",
    "\n",
    "# Per-configuration mean over the evaluated targets.\n",
    "acc_list = []\n",
    "weighted_F1_list = []\n",
    "micro_F1_list = []\n",
    "macro_F1_list = []\n",
    "precision_list = []\n",
    "recall_list = []\n",
    "\n",
    "# Per-configuration list holding one value per evaluated target.\n",
    "acc_total_list = []\n",
    "weighted_F1_total_list = []\n",
    "micro_F1_total_list = []\n",
    "macro_F1_total_list = []\n",
    "precision_total_list = []\n",
    "recall_total_list = []\n",
    "\n",
    "struc_list = []\n",
    "\n",
    "# Reference (SOTA) scores in percent, assumed to be ordered like target_data_list.\n",
    "SOTA_acc = np.array([95.49,69.40,80.00,97.56])\n",
    "SOTA_f1 = np.array([92.81,75.11,78.67,98.14])\n",
    "SOTA_prec = np.array([94.56,75.59,79.03,98.33])\n",
    "SOTA_rec = np.array([92.28,76.41,80.00,98.04])\n",
    "\n",
    "SOTA_acc /= 100\n",
    "SOTA_f1 /= 100\n",
    "SOTA_prec /= 100\n",
    "SOTA_rec /= 100\n",
    "\n",
    "# Lower every reference by one percentage point before comparing.\n",
    "SOTA_acc -= 0.01\n",
    "SOTA_f1 -= 0.01\n",
    "SOTA_prec -= 0.01\n",
    "SOTA_rec -= 0.01\n",
    "\n",
    "for ag in ag_list:\n",
    "    print('='*50)\n",
    "    print('='*50)\n",
    "    print(ag)\n",
    "    print('='*50)\n",
    "    print('='*50)\n",
    "    for ep in load_epoch_list:\n",
    "        for type_ in type_list:\n",
    "            for D in arch_list:\n",
    "                for data_struc in data_struc_list:\n",
    "                    P,C = data_struc\n",
    "                    S = P\n",
    "                    num_patch = int(C/S)\n",
    "                    acc_total = []\n",
    "                    weighted_F1_total = []\n",
    "                    micro_F1_total = []\n",
    "                    macro_F1_total = []\n",
    "                    precision_total = []\n",
    "                    recall_total = []\n",
    "                    for target_data in target_data_list[target_slice]:\n",
    "                        # Keep the metrics of the fine-tuning epoch with the highest accuracy.\n",
    "                        acc = 0\n",
    "                        weighted_F1 = 0\n",
    "                        micro_F1 = 0\n",
    "                        macro_F1 = 0\n",
    "                        precision = 0\n",
    "                        recall = 0\n",
    "                        for ft_epoch in finetune_epoch_list:\n",
    "                            result_concat = get_result(source_data, target_data, agg=ag, load_epoch=ep,\n",
    "                                                        D=D, C=C, P=P,finetune_epoch=ft_epoch, type2=type_)\n",
    "                            try:\n",
    "                                results = list(result_concat.values())[0]\n",
    "                                acc_temp,weighted_F1_temp,micro_F1_temp,macro_F1_temp,precision_temp,recall_temp = results\n",
    "                                if acc_temp > acc:\n",
    "                                    acc = acc_temp\n",
    "                                    weighted_F1 = weighted_F1_temp\n",
    "                                    micro_F1 = micro_F1_temp\n",
    "                                    macro_F1 = macro_F1_temp\n",
    "                                    precision = precision_temp\n",
    "                                    recall = recall_temp\n",
    "                            except (AttributeError, IndexError, TypeError, ValueError):\n",
    "                                # No usable result for this fine-tuning epoch (nothing found or an\n",
    "                                # unexpected format): skip it, but let interrupts and other errors surface.\n",
    "                                continue\n",
    "                        acc_total.append(acc)\n",
    "                        weighted_F1_total.append(weighted_F1)\n",
    "                        micro_F1_total.append(micro_F1)\n",
    "                        macro_F1_total.append(macro_F1)\n",
    "                        precision_total.append(precision)\n",
    "                        recall_total.append(recall)\n",
    "                    \n",
    "                    acc_summary = np.mean(acc_total)\n",
    "                    weighted_F1_summary = np.mean(weighted_F1_total)\n",
    "                    micro_F1_summary = np.mean(micro_F1_total)\n",
    "                    macro_F1_summary = np.mean(macro_F1_total)\n",
    "                    precision_summary = np.mean(precision_total)\n",
    "                    recall_summary = np.mean(recall_total)\n",
    "                    # Compare only against the SOTA entries of the evaluated targets; subtracting the\n",
    "                    # full 4-element arrays would broadcast these targets against every reference.\n",
    "                    acc_diff = acc_total - SOTA_acc[target_slice]\n",
    "                    # weighted F1 is compared; use micro_F1_total / macro_F1_total for the other variants\n",
    "                    f1_diff = weighted_F1_total - SOTA_f1[target_slice]\n",
    "                    prec_diff = precision_total - SOTA_prec[target_slice]\n",
    "                    rec_diff = recall_total - SOTA_rec[target_slice]\n",
    "                    # Number of (target, metric) pairs that beat the adjusted SOTA.\n",
    "                    plus_count = 0\n",
    "                    plus_count += np.sum(acc_diff>0)\n",
    "                    plus_count += np.sum(f1_diff>0)\n",
    "                    plus_count += np.sum(prec_diff>0)\n",
    "                    plus_count += np.sum(rec_diff>0)\n",
    "                    if show_sota_beats and plus_count>0:\n",
    "                        print(ag, ep, D, type_, data_struc)\n",
    "                        print(acc_total, acc_diff,'-------------', acc_summary.round(3))\n",
    "                        print(precision_total, prec_diff, '-------------', precision_summary.round(3))\n",
    "                        print(recall_total, rec_diff,'-------------', recall_summary.round(3))\n",
    "                        print(weighted_F1_total, f1_diff,'-------------', weighted_F1_summary.round(3))\n",
    "                        print(plus_count)\n",
    "                    \n",
    "                    acc_list.append(acc_summary)\n",
    "                    weighted_F1_list.append(weighted_F1_summary)\n",
    "                    micro_F1_list.append(micro_F1_summary)\n",
    "                    macro_F1_list.append(macro_F1_summary)\n",
    "                    precision_list.append(precision_summary)\n",
    "                    recall_list.append(recall_summary)\n",
    "                    \n",
    "                    acc_total_list.append(acc_total)\n",
    "                    weighted_F1_total_list.append(weighted_F1_total)\n",
    "                    micro_F1_total_list.append(micro_F1_total)\n",
    "                    macro_F1_total_list.append(macro_F1_total)\n",
    "                    precision_total_list.append(precision_total)\n",
    "                    recall_total_list.append(recall_total)\n",
    "                    \n",
    "                    struc_list.append([ag, ep, type_, D, data_struc])\n",
    "                    \n",
    "                    if best_acc_summary<acc_summary:\n",
    "                        best_acc_summary = acc_summary\n",
    "                        best_acc = acc_total\n",
    "                        best_struc = [ag, ep, D, data_struc]\n",
    "\n",
    "data4 = pd.DataFrame({'structure':struc_list,\n",
    "              'acc':acc_total_list,\n",
    "              'F1':weighted_F1_total_list,\n",
    "              'prec':precision_total_list,\n",
    "              'rec':recall_total_list})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 66,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _expand_structure(frame):\n",
    "    \"\"\"Unpack a collected results frame in place.\n",
    "\n",
    "    Adds method/load_ep/type_/dim/patch columns from `structure`, which holds\n",
    "    [ag, ep, type_, D, (P, C)] as appended in the collection cells, and replaces each\n",
    "    metric list (one value per evaluated target) by its first element.\n",
    "    Not idempotent: a second run fails because the metrics are already unwrapped.\n",
    "    \"\"\"\n",
    "    structure = frame['structure']\n",
    "    frame['method'] = structure.apply(lambda s: s[0])\n",
    "    frame['load_ep'] = structure.apply(lambda s: s[1])\n",
    "    frame['type_'] = structure.apply(lambda s: s[2])\n",
    "    frame['dim'] = structure.apply(lambda s: s[3])\n",
    "    frame['patch'] = structure.apply(lambda s: s[4][0])\n",
    "    for metric in ['acc', 'F1', 'prec', 'rec']:\n",
    "        frame[metric] = frame[metric].apply(lambda values: values[0])\n",
    "\n",
    "\n",
    "_expand_structure(data1)\n",
    "_expand_structure(data2)\n",
    "_expand_structure(data3)\n",
    "_expand_structure(data4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 67,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _select_results(frame, patch=8, dim=None, exclude_method=None):\n",
    "    \"\"\"Return a new frame without `structure`, filtered and sorted by accuracy (descending).\n",
    "\n",
    "    Keeps rows whose `patch` equals `patch`; optionally also only rows with the given `dim`\n",
    "    and/or drops rows whose aggregation `method` equals `exclude_method` (e.g. 'concat').\n",
    "    \"\"\"\n",
    "    selected = frame.drop('structure', axis=1)\n",
    "    selected = selected[selected.patch == patch]\n",
    "    if dim is not None:\n",
    "        selected = selected[selected.dim == dim]\n",
    "    if exclude_method is not None:\n",
    "        selected = selected[selected.method != exclude_method]\n",
    "    return selected.sort_values('acc', ascending=False)\n",
    "\n",
    "\n",
    "# Trailing numbers kept from the original notes (presumably each target's input length - TODO confirm).\n",
    "data1_ = _select_results(data1)  # 178\n",
    "data2_ = _select_results(data2)  # 5120\n",
    "data3_ = _select_results(data3)  # 315 -> concat\n",
    "data4_ = _select_results(data4)  # 1500"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 68,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Keep the reported metrics, rounded to 4 decimals and expressed in percent.\n",
    "# NOTE: rebinds data1_..data4_ to the scaled frames, so re-running this cell multiplies by 100 again.\n",
    "_report_columns = ['acc','prec','rec','F1']\n",
    "data1_, data2_, data3_, data4_ = [frame[_report_columns].round(4) * 100\n",
    "                                  for frame in (data1_, data2_, data3_, data4_)]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Data 1 : Epilepsy (178)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 69,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>acc</th>\n",
       "      <th>prec</th>\n",
       "      <th>rec</th>\n",
       "      <th>F1</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>95.59</td>\n",
       "      <td>95.53</td>\n",
       "      <td>95.59</td>\n",
       "      <td>95.54</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>0.00</td>\n",
       "      <td>0.00</td>\n",
       "      <td>0.00</td>\n",
       "      <td>0.00</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>91</th>\n",
       "      <td>0.00</td>\n",
       "      <td>0.00</td>\n",
       "      <td>0.00</td>\n",
       "      <td>0.00</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>70</th>\n",
       "      <td>0.00</td>\n",
       "      <td>0.00</td>\n",
       "      <td>0.00</td>\n",
       "      <td>0.00</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>73</th>\n",
       "      <td>0.00</td>\n",
       "      <td>0.00</td>\n",
       "      <td>0.00</td>\n",
       "      <td>0.00</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "      acc   prec    rec     F1\n",
       "7   95.59  95.53  95.59  95.54\n",
       "1    0.00   0.00   0.00   0.00\n",
       "91   0.00   0.00   0.00   0.00\n",
       "70   0.00   0.00   0.00   0.00\n",
       "73   0.00   0.00   0.00   0.00"
      ]
     },
     "execution_count": 69,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "data1_.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Data 2 : 5120"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 70,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>acc</th>\n",
       "      <th>prec</th>\n",
       "      <th>rec</th>\n",
       "      <th>F1</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>103</th>\n",
       "      <td>69.47</td>\n",
       "      <td>69.42</td>\n",
       "      <td>69.47</td>\n",
       "      <td>68.96</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>0.00</td>\n",
       "      <td>0.00</td>\n",
       "      <td>0.00</td>\n",
       "      <td>0.00</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>88</th>\n",
       "      <td>0.00</td>\n",
       "      <td>0.00</td>\n",
       "      <td>0.00</td>\n",
       "      <td>0.00</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>67</th>\n",
       "      <td>0.00</td>\n",
       "      <td>0.00</td>\n",
       "      <td>0.00</td>\n",
       "      <td>0.00</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>70</th>\n",
       "      <td>0.00</td>\n",
       "      <td>0.00</td>\n",
       "      <td>0.00</td>\n",
       "      <td>0.00</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "       acc   prec    rec     F1\n",
       "103  69.47  69.42  69.47  68.96\n",
       "1     0.00   0.00   0.00   0.00\n",
       "88    0.00   0.00   0.00   0.00\n",
       "67    0.00   0.00   0.00   0.00\n",
       "70    0.00   0.00   0.00   0.00"
      ]
     },
     "execution_count": 70,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "data2_.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Data 3 : 315"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 71,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>acc</th>\n",
       "      <th>prec</th>\n",
       "      <th>rec</th>\n",
       "      <th>F1</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>91.67</td>\n",
       "      <td>92.39</td>\n",
       "      <td>91.67</td>\n",
       "      <td>91.68</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>0.00</td>\n",
       "      <td>0.00</td>\n",
       "      <td>0.00</td>\n",
       "      <td>0.00</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>91</th>\n",
       "      <td>0.00</td>\n",
       "      <td>0.00</td>\n",
       "      <td>0.00</td>\n",
       "      <td>0.00</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>70</th>\n",
       "      <td>0.00</td>\n",
       "      <td>0.00</td>\n",
       "      <td>0.00</td>\n",
       "      <td>0.00</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>73</th>\n",
       "      <td>0.00</td>\n",
       "      <td>0.00</td>\n",
       "      <td>0.00</td>\n",
       "      <td>0.00</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "      acc   prec    rec     F1\n",
       "7   91.67  92.39  91.67  91.68\n",
       "1    0.00   0.00   0.00   0.00\n",
       "91   0.00   0.00   0.00   0.00\n",
       "70   0.00   0.00   0.00   0.00\n",
       "73   0.00   0.00   0.00   0.00"
      ]
     },
     "execution_count": 71,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "data3_.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Data 4 : EMG (1500)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 72,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>acc</th>\n",
       "      <th>prec</th>\n",
       "      <th>rec</th>\n",
       "      <th>F1</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>55</th>\n",
       "      <td>97.56</td>\n",
       "      <td>97.97</td>\n",
       "      <td>97.56</td>\n",
       "      <td>97.64</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>0.00</td>\n",
       "      <td>0.00</td>\n",
       "      <td>0.00</td>\n",
       "      <td>0.00</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>64</th>\n",
       "      <td>0.00</td>\n",
       "      <td>0.00</td>\n",
       "      <td>0.00</td>\n",
       "      <td>0.00</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>70</th>\n",
       "      <td>0.00</td>\n",
       "      <td>0.00</td>\n",
       "      <td>0.00</td>\n",
       "      <td>0.00</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>73</th>\n",
       "      <td>0.00</td>\n",
       "      <td>0.00</td>\n",
       "      <td>0.00</td>\n",
       "      <td>0.00</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "      acc   prec    rec     F1\n",
       "55  97.56  97.97  97.56  97.64\n",
       "1    0.00   0.00   0.00   0.00\n",
       "64   0.00   0.00   0.00   0.00\n",
       "70   0.00   0.00   0.00   0.00\n",
       "73   0.00   0.00   0.00   0.00"
      ]
     },
     "execution_count": 72,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "data4_.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Rows whose metrics are all `0.00` are most likely configurations without a usable fine-tuning result: the collection cells start every metric at `0` and only overwrite it when `get_result` yields a result with higher accuracy."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "ssl_ts",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.11"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
