{
 "cells": [
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-05-19T05:05:05.027895Z",
     "start_time": "2025-05-19T05:04:59.716963Z"
    }
   },
   "cell_type": "code",
   "source": [
    "import os\n",
    "# Replace 'your path' with the project root directory (the folder containing the\n",
    "# `train/` package and `data/`). This must run BEFORE the `train.*` imports below:\n",
    "# the relative `./data/...` paths used later, and presumably the local `train`\n",
    "# package lookup, both depend on the current working directory.\n",
    "os.chdir('your path')\n",
    "\n",
    "from pathlib import Path\n",
    "\n",
    "import numpy as np\n",
    "import torch\n",
    "from sklearn.model_selection import StratifiedKFold\n",
    "from torch.utils.data import DataLoader\n",
    "\n",
    "# Explicit imports (instead of `from train.utils import *`) so every name used in this\n",
    "# notebook has a visible origin.\n",
    "from train.utils import set_seed, DualSourceDataset, DualSourceDataSplitter, StandardizedDataset\n",
    "from train.GACET import GACET\n",
    "from train.trainer import Trainer"
   ],
   "id": "338ce8497b111312",
   "outputs": [],
   "execution_count": 1
  },
  {
   "metadata": {
    "collapsed": true,
    "ExecuteTime": {
     "end_time": "2025-05-19T05:23:04.414098Z",
     "start_time": "2025-05-19T05:16:43.363526Z"
    }
   },
   "cell_type": "code",
   "source": [
    "# Cross-session evaluation for one subject: for each permutation of the three\n",
    "# recording days, train on two days and test on the held-out day, with stratified\n",
    "# k-fold cross-validation over the training days.\n",
    "SEED = 42\n",
    "N_SPLITS = 5\n",
    "BATCH_SIZE = 32\n",
    "LEARNING_RATE = 1e-4\n",
    "EMBED_DIM = 300\n",
    "\n",
    "set_seed(SEED)\n",
    "# One generator shared by every DataLoader below. Keep it on the non-shuffled loaders\n",
    "# too: DataLoader iterators draw their base seed from `generator` (PyTorch internals,\n",
    "# verify before changing), so removing it would alter the training shuffle order and\n",
    "# therefore the reported accuracies.\n",
    "g = torch.Generator().manual_seed(SEED)\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "\n",
    "data_SampEn = Path('./data/SampEn/sub-01').resolve()\n",
    "data_DE = Path('./data/DE/sub-01').resolve()\n",
    "task_order = ['MATB_level0.pkl', 'MATB_level1.pkl', 'MATB_level2.pkl', 'MATB_level3.pkl', 'MATB_level4.pkl']\n",
    "acc_list = []  # one value per (day permutation, fold), as returned by Trainer.train()\n",
    "day_permutations = [\n",
    "\t([1], [2], [3]),\n",
    "\t([1], [3], [2]),\n",
    "\t([2], [3], [1])\n",
    "]\n",
    "\n",
    "def load_days(days):\n",
    "\t\"\"\"Build a DualSourceDataset from the SampEn and DE feature folders for the given days.\"\"\"\n",
    "\treturn DualSourceDataset(data_SampEn, data_DE, days=days, task_order=task_order)\n",
    "\n",
    "for day1, day2, day_test in day_permutations:\n",
    "\tprint(f'training on {day1}, {day2} and testing on {day_test}')\n",
    "\tdataset_train_1 = load_days(day1)\n",
    "\tdataset_train_2 = load_days(day2)\n",
    "\tdataset_test = load_days(day_test)\n",
    "\n",
    "\t# Equalize the two training days so the fold indices computed on day 1 can be\n",
    "\t# applied to day 2 as well (both are passed to DualSourceDataSplitter below).\n",
    "\tmin_len = min(len(dataset_train_1), len(dataset_train_2))\n",
    "\tfor train_ds in (dataset_train_1, dataset_train_2):\n",
    "\t\tif len(train_ds) > min_len:\n",
    "\t\t\ttrain_ds.trim_to_length(min_len)\n",
    "\n",
    "\t# Folds are stratified on day-1 labels only; this assumes both training days share\n",
    "\t# the same sample/label ordering after trimming (TODO confirm in DualSourceDataSplitter).\n",
    "\tkf = StratifiedKFold(n_splits=N_SPLITS, shuffle=False)\n",
    "\tfor fold, (train_idx, val_idx) in enumerate(kf.split(dataset_train_1.data[0], dataset_train_1.labels), start=1):\n",
    "\t\tprint(f'fold {fold} start')\n",
    "\t\tsplitter = DualSourceDataSplitter(dataset_train_1, dataset_train_2, train_idx, val_idx)\n",
    "\t\tcombined_train = splitter.train_dataset\n",
    "\t\tcombined_val = splitter.val_dataset\n",
    "\n",
    "\t\t# Normalization statistics come from the training split only and are reused for\n",
    "\t\t# validation and test, so no held-out data leaks into preprocessing.\n",
    "\t\tstandardized_train = StandardizedDataset(combined_train, is_train=True)\n",
    "\t\tmean, std = standardized_train.get_mean_std()\n",
    "\t\tstandardized_val = StandardizedDataset(combined_val, is_train=False, mean=mean, std=std)\n",
    "\t\tstandardized_test = StandardizedDataset(dataset_test, is_train=False, mean=mean, std=std)\n",
    "\n",
    "\t\ttrain_loader = DataLoader(standardized_train, batch_size=BATCH_SIZE, shuffle=True, generator=g)\n",
    "\t\tval_loader = DataLoader(standardized_val, batch_size=BATCH_SIZE, shuffle=False, generator=g)\n",
    "\t\ttest_loader = DataLoader(standardized_test, batch_size=BATCH_SIZE, shuffle=False, generator=g)\n",
    "\n",
    "\t\tmodel = GACET(num_classes=len(task_order), embed_dim=EMBED_DIM)\n",
    "\t\tmodel.to(device)\n",
    "\t\toptimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)\n",
    "\t\tcriterion = torch.nn.CrossEntropyLoss()\n",
    "\n",
    "\t\ttrainer = Trainer(model, train_loader, val_loader, test_loader, criterion, optimizer, device)\n",
    "\t\tacc = trainer.train()\n",
    "\t\tacc_list.append(acc)\n",
    "\n",
    "# Mean over all permutations x folds; the accuracies are fractions, hence the * 100.\n",
    "print(f\"acc: {np.mean(acc_list) * 100:.2f}%\")"
   ],
   "id": "initial_id",
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "training on [1], [2] and testing on [3]\n",
      "fold 1 start\n",
      "The Best Validation Accuracy: 60.71%\n",
      "Loaded best model state from training.\n",
      "Test Accuracy: 33.81%\n",
      "fold 2 start\n",
      "The Best Validation Accuracy: 65.48%\n",
      "Loaded best model state from training.\n",
      "Test Accuracy: 39.05%\n",
      "fold 3 start\n",
      "The Best Validation Accuracy: 60.71%\n",
      "Loaded best model state from training.\n",
      "Test Accuracy: 33.81%\n",
      "fold 4 start\n",
      "The Best Validation Accuracy: 60.71%\n",
      "Loaded best model state from training.\n",
      "Test Accuracy: 41.43%\n",
      "fold 5 start\n",
      "The Best Validation Accuracy: 61.90%\n",
      "Loaded best model state from training.\n",
      "Test Accuracy: 39.05%\n",
      "training on [1], [3] and testing on [2]\n",
      "fold 1 start\n",
      "The Best Validation Accuracy: 64.29%\n",
      "Loaded best model state from training.\n",
      "Test Accuracy: 43.81%\n",
      "fold 2 start\n",
      "The Best Validation Accuracy: 63.10%\n",
      "Loaded best model state from training.\n",
      "Test Accuracy: 40.48%\n",
      "fold 3 start\n",
      "The Best Validation Accuracy: 63.10%\n",
      "Loaded best model state from training.\n",
      "Test Accuracy: 45.71%\n",
      "fold 4 start\n",
      "The Best Validation Accuracy: 65.48%\n",
      "Loaded best model state from training.\n",
      "Test Accuracy: 41.43%\n",
      "fold 5 start\n",
      "The Best Validation Accuracy: 58.33%\n",
      "Loaded best model state from training.\n",
      "Test Accuracy: 50.48%\n",
      "training on [2], [3] and testing on [1]\n",
      "fold 1 start\n",
      "The Best Validation Accuracy: 66.67%\n",
      "Loaded best model state from training.\n",
      "Test Accuracy: 50.95%\n",
      "fold 2 start\n",
      "The Best Validation Accuracy: 54.76%\n",
      "Loaded best model state from training.\n",
      "Test Accuracy: 50.95%\n",
      "fold 3 start\n",
      "The Best Validation Accuracy: 65.48%\n",
      "Loaded best model state from training.\n",
      "Test Accuracy: 55.24%\n",
      "fold 4 start\n",
      "The Best Validation Accuracy: 63.10%\n",
      "Loaded best model state from training.\n",
      "Test Accuracy: 50.95%\n",
      "fold 5 start\n",
      "The Best Validation Accuracy: 51.19%\n",
      "Loaded best model state from training.\n",
      "Test Accuracy: 49.05%\n",
      "acc: 44.41%\n"
     ]
    }
   ],
   "execution_count": 3
  },
  {
   "metadata": {},
   "cell_type": "markdown",
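   "source": "The results above show the test-set predictions for Subject 1 using DE and SampEn features. Each of the three day permutations (train on two days, test on the held-out third day) runs one round of five-fold cross-validation, and the per-fold test accuracies of these three rounds are consistent with the Subject_1 Performance results presented in Section F2.4: Dataset 2 (5-class) on page 37. The final value of 44.41% is the mean of all 15 fold-level test accuracies (3 permutations × 5 folds) and matches the result reported in Table 21 on page 26.\n",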
   "id": "b444711ad5704929"
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
