{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Run Experiments\n",
    "\n",
    "This notebook can be used to automatically construct the configs for the different experiments and submit slurm jobs to run them."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from core.utils import create_trial_configs, submit_slurm_job"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Config Creation"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Baseline Experiments"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# TUAB Baseline\n",
    "configs = create_trial_configs(\n",
    "    model=[\"TCN\", \"mAtt\", \"LaBraM\"],\n",
    "    dataset=\"TUAB\",\n",
    "    task=\"normality\",\n",
    "    batch_size=64,\n",
    "    max_batches=50000,\n",
    "    evaluation_interval=500,\n",
    "    early_stopping_patience=5,\n",
    "    n_participants=[25, 50, 100, 200, 400, 800, 1600],\n",
    "    n_segments=[5, 10, 20, 40, 80, 160, 320],\n",
    "    seed=[42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66],\n",
    "    experiment_name=\"baseline\"\n",
    ")\n",
    "\n",
    "# CAUEEG Baseline\n",
    "configs += create_trial_configs(\n",
    "    model=[\"TCN\", \"mAtt\", \"LaBraM\"],\n",
    "    dataset=\"CAUEEG\",\n",
    "    task=\"dementia\",\n",
    "    batch_size=64,\n",
    "    max_batches=50000,\n",
    "    evaluation_interval=500,\n",
    "    early_stopping_patience=5,\n",
    "    n_participants=[25, 50, 100, 200, 400, 800],\n",
    "    n_segments=[5, 10, 20, 40, 80],\n",
    "    seed=[42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66],\n",
    "    experiment_name=\"baseline\"\n",
    ")\n",
    "\n",
    "# PhysioNet Baseline\n",
    "configs += create_trial_configs(\n",
    "    model=[\"TCN\", \"mAtt\", \"LaBraM\"],\n",
    "    dataset=\"PhysioNet\",\n",
    "    task=\"sleep_stage\",\n",
    "    batch_size=64,\n",
    "    max_batches=50000,\n",
    "    evaluation_interval=500,\n",
    "    early_stopping_patience=5,\n",
    "    n_participants=[25, 50, 100, 200, 400, 800],\n",
    "    n_segments=[5, 10, 20, 40, 80, 160, 320, 640],\n",
    "    seed=[42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66],\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Augmentation Experiments"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Augmentations on TUAB\n",
    "configs += create_trial_configs(\n",
    "    experiment_name=\"Augmentation_AmplitudeScaling\",\n",
    "    model=[\"TCN\", \"mAtt\", \"LaBraM\"],\n",
    "    dataset=\"TUAB\",\n",
    "    task=\"normality\",\n",
    "    batch_size=64,\n",
    "    max_batches=50000,\n",
    "    evaluation_interval=500,\n",
    "    early_stopping_patience=5,\n",
    "    augmentation=\"AmplitudeScaling\",\n",
    "    n_participants=[25, 50, 100, 200, 400, 800, 1600],\n",
    "    n_segments=[5, 10, 20, 40, 80, 160, 320],\n",
    "    seed=[42, 43, 44, 45, 46],\n",
    ")\n",
    "\n",
    "configs += create_trial_configs(\n",
    "    experiment_name=\"Augmentation_FrequencyShift\",\n",
    "    model=[\"TCN\", \"mAtt\", \"LaBraM\"],\n",
    "    dataset=\"TUAB\",\n",
    "    task=\"normality\",\n",
    "    batch_size=64,\n",
    "    max_batches=50000,\n",
    "    evaluation_interval=500,\n",
    "    early_stopping_patience=5,\n",
    "    augmentation=\"FrequencyShift\",\n",
    "    n_participants=[25, 50, 100, 200, 400, 800, 1600],\n",
    "    n_segments=[5, 10, 20, 40, 80, 160, 320],\n",
    "    seed=[42, 43, 44, 45, 46],\n",
    ")\n",
    "\n",
    "configs += create_trial_configs(\n",
    "    experiment_name=\"Augmentation_PhaseRandomization\",\n",
    "    model=[\"TCN\", \"mAtt\", \"LaBraM\"],\n",
    "    dataset=\"TUAB\",\n",
    "    task=\"normality\",\n",
    "    batch_size=64,\n",
    "    max_batches=50000,\n",
    "    evaluation_interval=500,\n",
    "    early_stopping_patience=5,\n",
    "    augmentation=\"PhaseRandomization\",\n",
    "    n_participants=[25, 50, 100, 200, 400, 800, 1600],\n",
    "    n_segments=[5, 10, 20, 40, 80, 160, 320],\n",
    "    seed=[42, 43, 44, 45, 46],\n",
    ")\n",
    "\n",
    "# Augmentations on CAUEEG\n",
    "configs += create_trial_configs(\n",
    "    experiment_name=\"Augmentation_AmplitudeScaling\",\n",
    "    model=[\"TCN\", \"mAtt\", \"LaBraM\"],\n",
    "    dataset=\"CAUEEG\",\n",
    "    task=\"dementia\",\n",
    "    batch_size=64,\n",
    "    max_batches=50000,\n",
    "    evaluation_interval=500,\n",
    "    early_stopping_patience=5,\n",
    "    augmentation=\"AmplitudeScaling\",\n",
    "    n_participants=[25, 50, 100, 200, 400, 800],\n",
    "    n_segments=[5, 10, 20, 40, 80],\n",
    "    seed=[42, 43, 44, 45, 46],\n",
    ")\n",
    "\n",
    "configs += create_trial_configs(\n",
    "    experiment_name=\"Augmentation_FrequencyShift\",\n",
    "    model=[\"TCN\", \"mAtt\", \"LaBraM\"],\n",
    "    dataset=\"CAUEEG\",\n",
    "    task=\"dementia\",\n",
    "    batch_size=64,\n",
    "    max_batches=50000,\n",
    "    evaluation_interval=500,\n",
    "    early_stopping_patience=5,\n",
    "    augmentation=\"FrequencyShift\",\n",
    "    n_participants=[25, 50, 100, 200, 400, 800],\n",
    "    n_segments=[5, 10, 20, 40, 80],\n",
    "    seed=[42, 43, 44, 45, 46],\n",
    ")\n",
    "\n",
    "configs += create_trial_configs(\n",
    "    experiment_name=\"Augmentation_PhaseRandomization\",\n",
    "    model=[\"TCN\", \"mAtt\", \"LaBraM\"],\n",
    "    dataset=\"CAUEEG\",\n",
    "    task=\"dementia\",\n",
    "    batch_size=64,\n",
    "    max_batches=50000,\n",
    "    evaluation_interval=500,\n",
    "    early_stopping_patience=5,\n",
    "    augmentation=\"PhaseRandomization\",\n",
    "    n_participants=[25, 50, 100, 200, 400, 800],\n",
    "    n_segments=[5, 10, 20, 40, 80],\n",
    "    seed=[42, 43, 44, 45, 46],\n",
    ")\n",
    "\n",
    "# Augmentations on PhysioNet\n",
    "configs += create_trial_configs(\n",
    "    experiment_name=\"Augmentation_AmplitudeScaling\",\n",
    "    model=[\"TCN\", \"mAtt\", \"LaBraM\"],\n",
    "    dataset=\"PhysioNet\",\n",
    "    task=\"sleep_stage\",\n",
    "    batch_size=64,\n",
    "    max_batches=50000,\n",
    "    evaluation_interval=500,\n",
    "    early_stopping_patience=5,\n",
    "    augmentation=\"AmplitudeScaling\",\n",
    "    n_participants=[25, 50, 100, 200, 400, 800],\n",
    "    n_segments=[5, 10, 20, 40, 80, 160, 320, 640],\n",
    "    seed=[42, 43, 44, 45, 46],\n",
    ")\n",
    "\n",
    "configs += create_trial_configs(\n",
    "    experiment_name=\"Augmentation_FrequencyShift\",\n",
    "    model=[\"TCN\", \"mAtt\", \"LaBraM\"],\n",
    "    dataset=\"PhysioNet\",\n",
    "    task=\"sleep_stage\",\n",
    "    batch_size=64,\n",
    "    max_batches=50000,\n",
    "    evaluation_interval=500,\n",
    "    early_stopping_patience=5,\n",
    "    augmentation=\"FrequencyShift\",\n",
    "    n_participants=[25, 50, 100, 200, 400, 800],\n",
    "    n_segments=[5, 10, 20, 40, 80, 160, 320, 640],\n",
    "    seed=[42, 43, 44, 45, 46],\n",
    ")\n",
    "\n",
    "configs += create_trial_configs(\n",
    "    experiment_name=\"Augmentation_PhaseRandomization\",\n",
    "    model=[\"TCN\", \"mAtt\", \"LaBraM\"],\n",
    "    dataset=\"PhysioNet\",\n",
    "    task=\"sleep_stage\",\n",
    "    batch_size=64,\n",
    "    max_batches=50000,\n",
    "    evaluation_interval=500,\n",
    "    early_stopping_patience=5,\n",
    "    augmentation=\"PhaseRandomization\",\n",
    "    n_participants=[25, 50, 100, 200, 400, 800],\n",
    "    n_segments=[5, 10, 20, 40, 80, 160, 320, 640],\n",
    "    seed=[42, 43, 44, 45, 46],\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Pretraining Experiments"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# TUAB Pretraining Experiments\n",
    "configs += create_trial_configs(\n",
    "    model=[\"LaBraM\"],\n",
    "    dataset=\"TUAB\",\n",
    "    task=\"normality\",\n",
    "    batch_size=64,\n",
    "    max_batches=50000,\n",
    "    evaluation_interval=500,\n",
    "    early_stopping_patience=5,\n",
    "    checkpoint=\"labram-base.pth\",\n",
    "    n_participants=[25, 50, 100, 200, 400, 800, 1600],\n",
    "    n_segments=[5, 10, 20, 40, 80, 160, 320],\n",
    "    seed=[42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66],\n",
    "    experiment_name=\"pretrained\",\n",
    ")\n",
    "\n",
    "# CAUEEG Pretraining Experiments\n",
    "configs += create_trial_configs(\n",
    "    model=[\"LaBraM\"],\n",
    "    dataset=\"CAUEEG\",\n",
    "    task=\"dementia\",\n",
    "    batch_size=64,\n",
    "    max_batches=50000,\n",
    "    evaluation_interval=500,\n",
    "    early_stopping_patience=5,\n",
    "    checkpoint=\"labram-base.pth\",\n",
    "    n_participants=[25, 50, 100, 200, 400, 800],\n",
    "    n_segments=[5, 10, 20, 40, 80],\n",
    "    seed=[42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66],\n",
    "    experiment_name=\"pretrained\",\n",
    ")\n",
    "\n",
    "# PhysioNet Pretraining Experiments\n",
    "configs += create_trial_configs(\n",
    "    model=[\"LaBraM\"],\n",
    "    dataset=\"PhysioNet\",\n",
    "    task=\"sleep_stage\",\n",
    "    batch_size=64,\n",
    "    max_batches=50000,\n",
    "    evaluation_interval=500,\n",
    "    early_stopping_patience=5,\n",
    "    checkpoint=\"labram-base.pth\",\n",
    "    n_participants=[25, 50, 100, 200, 400, 800],\n",
    "    n_segments=[5, 10, 20, 40, 80, 160, 320, 640],\n",
    "    seed=[42, 43, 44, 45, 46],\n",
    "    experiment_name=\"pretrained\",\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Ablation: TCN trained with LaBraM-compatible Preprocessing"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "configs += create_trial_configs(\n",
    "    model=[\"TCN\"],\n",
    "    dataset=\"PhysioNet\",\n",
    "    task=\"sleep_stage\",\n",
    "    batch_size=64,\n",
    "    max_batches=50000,\n",
    "    evaluation_interval=500,\n",
    "    early_stopping_patience=5,\n",
    "    n_participants=[25, 50, 100, 200, 400, 800],\n",
    "    n_segments=[40],\n",
    "    seed=[42],\n",
    "    experiment_name=\"LaBraM-Preproc\",\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Job Submission"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The following code block can be used to submit individual jobs for each configuration to a Slurm cluster.\n",
    "\n",
    "Customize the setup and cleanup commands if you need to manage scratch space or set environment variables on the compute nodes."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "use_wandb = False  # Set to True to use Weights & Biases for logging\n",
    "custom_output_root = None  # If None, the output_root set in user_config.yml will be used"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for cfg in configs:\n",
    "    cmd_args = \" \".join([f\"--{key} {value}\" for key, value in cfg.items()])\n",
    "    \n",
    "    if custom_output_root:\n",
    "        cmd_args += f\" --output_root {custom_output_root}\"\n",
    "    \n",
    "    if use_wandb:\n",
    "        cmd_args += \" --use_wandb\"\n",
    "\n",
    "    command = f\"python core/run_trial.py {cmd_args}\"\n",
    "\n",
    "    # Feel free to add setup or cleanup commands that should be run on the compute node\n",
    "    # before and after the main command. This can be useful, e.g. to manage scratch space\n",
    "    # or set up environment variables.\n",
    "    setup_command = \"\"\n",
    "    if setup_command:\n",
    "        command = setup_command + \"; \" + command\n",
    "\n",
    "    cleanup_command = \"\"\n",
    "    if cleanup_command:\n",
    "        command = command + \"; \" + cleanup_command\n",
    "\n",
    "    job_name = f\"{cfg['experiment_name']}-{cfg['dataset']}-{cfg['model']}-{cfg['task']}-{cfg['n_participants']}p-{cfg['n_segments']}s-{cfg['seed']}\"\n",
    "\n",
    "    submit_slurm_job(\n",
    "        command=command,\n",
    "        job_name=job_name,\n",
    "        log_dir=\"logs\",\n",
    "        cpus=2,\n",
    "        gpus=1,\n",
    "        time_limit=\"24:00:00\",\n",
    "    )"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "participant-diversity-paper",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
