{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f636a13d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Copyright authors of TSPulse"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "8ccc6220-6a78-4aa1-b3ac-737223dbe987",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import math\n",
    "import os\n",
    "import sys\n",
    "import tempfile\n",
    "import warnings\n",
    "from types import SimpleNamespace\n",
    "\n",
    "import numpy as np\n",
    "import torch\n",
    "from torch.optim import AdamW\n",
    "from torch.optim.lr_scheduler import OneCycleLR\n",
    "from torch.utils.data import ConcatDataset, DataLoader\n",
    "from transformers import EarlyStoppingCallback, Trainer, TrainingArguments, set_seed\n",
    "from transformers.data.data_collator import default_data_collator\n",
    "from transformers.trainer_utils import RemoveColumnsCollator\n",
    "\n",
    "warnings.filterwarnings(\"ignore\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "ee950055-4527-40eb-8fe2-ce023003f2b8",
   "metadata": {},
   "outputs": [],
   "source": [
    "sys.path.append(\"..\")\n",
    "from models.tspulse import TSPulseForClassificationOrRegression\n",
    "from classification.data import get_uea_classification_data\n",
    "from classification.utils import optimal_lr_finder"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "dc3eedad-946d-41fa-b2c2-838307091d7a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Seed the global RNGs through transformers.set_seed for reproducible runs.\n",
    "seed = 42\n",
    "set_seed(seed=seed)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "873aecaf-b329-4bc8-a72b-40989bd9a6d2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Output directory for fine-tuned model artifacts (presumably consumed by the training cells below).\n",
    "OUT_DIR = \"tspulse_finetuned_models/\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "99640a13-a249-4dca-8ab0-aedf23d326b1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-dataset hyperparameters: clf_params[<dataset>] provides \"DATA_PARAMS\" and \"MODEL_PARAMS\".\n",
    "# Read with an explicit encoding so parsing does not depend on the platform locale.\n",
    "json_file = \"config.json\"\n",
    "with open(json_file, \"r\", encoding=\"utf-8\") as file:\n",
    "    clf_params = json.load(file)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3adfcd30-f9c9-406a-b058-2745bb2cd38a",
   "metadata": {},
   "source": [
    "## UEA Multivariate Datasets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "938db5cc-a761-4eb7-9042-e42aa8075a5d",
   "metadata": {},
   "outputs": [],
   "source": [
    "dset = \"ArticularyWordRecognition\"\n",
    "\n",
    "# Copy the dataset parameters so `clf_params` keeps the values read from config.json.\n",
    "ds = dict(clf_params[dset][\"DATA_PARAMS\"])\n",
    "# Model input length; presumably must match the pretrained checkpoint's context length - TODO confirm.\n",
    "ds[\"context_points\"] = 512\n",
    "ds[\"dset\"] = dset\n",
    "ds[\"data_path\"] = \"datasets/\"\n",
    "\n",
    "args = SimpleNamespace(**ds)\n",
    "\n",
    "output = get_uea_classification_data(args)\n",
    "train_dataset = output[\"dset_train\"]\n",
    "valid_dataset = output[\"dset_valid\"]\n",
    "test_dataset = output[\"dset_test\"]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b0976700-08fc-42c6-8dab-c1134c0fcc21",
   "metadata": {},
   "source": [
    "### Configs for the TSPulse model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "07694897-1fa7-46d5-8b01-71af52a5a64d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Copy the model parameters so `clf_params` keeps the values read from config.json.\n",
    "config_dict = dict(clf_params[dset][\"MODEL_PARAMS\"])\n",
    "config_dict[\"loss\"] = \"cross_entropy\"\n",
    "# Forwarded to from_pretrained below: skip checkpoint weights whose shapes differ from this task's head.\n",
    "config_dict[\"ignore_mismatched_sizes\"] = True\n",
    "\n",
    "# Input channels and number of classes come from the loaded dataset.\n",
    "config_dict[\"num_input_channels\"] = output[\"num_input_channels\"]\n",
    "config_dict[\"num_targets\"] = output[\"num_targets\"]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0f48a3aa-10bb-476d-a0bf-c20b47f286d7",
   "metadata": {},
   "source": [
    "## Loading the pretrained model with the above configuration"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "e2f7a221-91b7-4bfb-b397-ffc441746786",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of TSPulseForClassificationOrRegression were not initialized from the model checkpoint at ../../model-binaries/tspulse_classification/tspulse_model and are newly initialized: ['decoder_with_head.decoder.decoder_block.mixers.0.channel_feature_mixer.gating_block.attn_layer.bias', 'decoder_with_head.decoder.decoder_block.mixers.0.channel_feature_mixer.gating_block.attn_layer.weight', 'decoder_with_head.decoder.decoder_block.mixers.0.channel_feature_mixer.mlp.fc1.bias', 'decoder_with_head.decoder.decoder_block.mixers.0.channel_feature_mixer.mlp.fc1.weight', 'decoder_with_head.decoder.decoder_block.mixers.0.channel_feature_mixer.mlp.fc2.bias', 'decoder_with_head.decoder.decoder_block.mixers.0.channel_feature_mixer.mlp.fc2.weight', 'decoder_with_head.decoder.decoder_block.mixers.0.channel_feature_mixer.norm.norm.bias', 'decoder_with_head.decoder.decoder_block.mixers.0.channel_feature_mixer.norm.norm.weight', 'decoder_with_head.decoder.decoder_block.mixers.1.channel_feature_mixer.gating_block.attn_layer.bias', 'decoder_with_head.decoder.decoder_block.mixers.1.channel_feature_mixer.gating_block.attn_layer.weight', 'decoder_with_head.decoder.decoder_block.mixers.1.channel_feature_mixer.mlp.fc1.bias', 'decoder_with_head.decoder.decoder_block.mixers.1.channel_feature_mixer.mlp.fc1.weight', 'decoder_with_head.decoder.decoder_block.mixers.1.channel_feature_mixer.mlp.fc2.bias', 'decoder_with_head.decoder.decoder_block.mixers.1.channel_feature_mixer.mlp.fc2.weight', 'decoder_with_head.decoder.decoder_block.mixers.1.channel_feature_mixer.norm.norm.bias', 'decoder_with_head.decoder.decoder_block.mixers.1.channel_feature_mixer.norm.norm.weight', 'decoder_with_head.head.head_norm.norm.bias', 'decoder_with_head.head.head_norm.norm.weight', 'decoder_with_head.head.loc_scale_norm.bias', 'decoder_with_head.head.loc_scale_norm.weight', 'decoder_with_head.head.projection.bias', 'decoder_with_head.head.projection.weight', 'decoder_with_head.head.reduce_proj.bias', 
'decoder_with_head.head.reduce_proj.weight']\n",
      "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Initializing Linear layers with method: pytorch\n",
      "Initializing Linear layers with method: pytorch\n",
      "Initializing Linear layers with method: pytorch\n",
      "Initializing Linear layers with method: pytorch\n",
      "Initializing Linear layers with method: pytorch\n",
      "Initializing Linear layers with method: pytorch\n",
      "Initializing Linear layers with method: pytorch\n",
      "Initializing Linear layers with method: pytorch\n",
      "Initializing Linear layers with method: pytorch\n",
      "Initializing Linear layers with method: pytorch\n",
      "Initializing Linear layers with method: pytorch\n",
      "Initializing Linear layers with method: pytorch\n",
      "Initializing Linear layers with method: pytorch\n",
      "Initializing Linear layers with method: pytorch\n",
      "Initializing Linear layers with method: pytorch\n",
      "Initializing Linear layers with method: pytorch\n",
      "Initializing Linear layers with method: pytorch\n",
      "Initializing Linear layers with method: pytorch\n",
      "Initializing Linear layers with method: pytorch\n",
      "Initializing Linear layers with method: pytorch\n",
      "Initializing Linear layers with method: pytorch\n",
      "Initializing Linear layers with method: pytorch\n",
      "Initializing Linear layers with method: pytorch\n",
      "Initializing Linear layers with method: pytorch\n",
      "Initializing Linear layers with method: pytorch\n",
      "Initializing Linear layers with method: pytorch\n",
      "Initializing Linear layers with method: pytorch\n",
      "Initializing Linear layers with method: pytorch\n",
      "Initializing Linear layers with method: pytorch\n",
      "Initializing Linear layers with method: pytorch\n",
      "Initializing Linear layers with method: pytorch\n",
      "Initializing Linear layers with method: pytorch\n",
      "Initializing Linear layers with method: pytorch\n",
      "Initializing Linear layers with method: pytorch\n",
      "Initializing Linear layers with method: pytorch\n",
      "Initializing Linear layers with method: pytorch\n",
      "Initializing Linear layers with method: pytorch\n",
      "Initializing Linear layers with method: pytorch\n",
      "Initializing Linear layers with method: pytorch\n",
      "Initializing Linear layers with method: pytorch\n",
      "Initializing Linear layers with method: pytorch\n",
      "Initializing Linear layers with method: pytorch\n",
      "Initializing Linear layers with method: pytorch\n",
      "Initializing Linear layers with method: pytorch\n",
      "Initializing Linear layers with method: pytorch\n",
      "Initializing Linear layers with method: pytorch\n",
      "Initializing Linear layers with method: pytorch\n",
      "Initializing Linear layers with method: pytorch\n",
      "Initializing Linear layers with method: pytorch\n",
      "Initializing Linear layers with method: pytorch\n",
      "Initializing Linear layers with method: pytorch\n",
      "Initializing Linear layers with method: pytorch\n",
      "Initializing Linear layers with method: pytorch\n",
      "Initializing Linear layers with method: pytorch\n",
      "Initializing Linear layers with method: pytorch\n",
      "Initializing Linear layers with method: pytorch\n",
      "Initializing Linear layers with method: pytorch\n",
      "Initializing Linear layers with method: pytorch\n",
      "Initializing Linear layers with method: pytorch\n",
      "Identity Init in Module:  TSPulseChannelFeatureMixerBlock\n",
      "Init identity weights for channel mixing\n",
      "Try identity init in Gated Attention.\n",
      "Initializing Linear layers with method: pytorch\n",
      "Initializing Linear layers with method: pytorch\n",
      "Initializing Linear layers with method: pytorch\n",
      "Initializing Linear layers with method: pytorch\n",
      "Initializing Linear layers with method: pytorch\n",
      "Initializing Linear layers with method: pytorch\n",
      "Initializing Linear layers with method: pytorch\n",
      "Initializing Linear layers with method: pytorch\n",
      "Initializing Linear layers with method: pytorch\n",
      "Identity Init in Module:  TSPulseChannelFeatureMixerBlock\n",
      "Init identity weights for channel mixing\n",
      "Try identity init in Gated Attention.\n",
      "Initializing Linear layers with method: pytorch\n",
      "Initializing Linear layers with method: pytorch\n",
      "Initializing Linear layers with method: pytorch\n"
     ]
    }
   ],
   "source": [
    "# Load the pretrained TSPulse checkpoint; every entry of `config_dict` is forwarded as a keyword argument.\n",
    "model_path = \"../../model-binaries/tspulse_classification/tspulse_model\"\n",
    "model = TSPulseForClassificationOrRegression.from_pretrained(\n",
    "    model_path,\n",
    "    **config_dict,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "17dcd1b1-6e88-4359-9214-882e6e7ffea7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Use the GPU when present; fall back to CPU so the notebook does not crash on CPU-only machines.\n",
    "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "model = model.to(device).float()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "15e12599-c519-49b4-9e5b-d32dfec912e2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Freeze the whole backbone, then re-enable gradients only for its\n",
    "# time-domain and FFT-domain encoding layers (presumably the patch embeddings).\n",
    "# The trailing ';' suppresses the module repr that requires_grad_ returns.\n",
    "model.backbone.requires_grad_(False)\n",
    "model.backbone.time_encoding.requires_grad_(True)\n",
    "model.backbone.fft_encoding.requires_grad_(True);"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "03211294-ff0d-4358-aa43-dd024387f2e1",
   "metadata": {},
   "source": [
    "## Fine-tuning the classifier head (plus the unfrozen encoding layers)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "98b4ab6a-ee7b-4602-b207-b0cb9dc3ae54",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "LR Finder: Running learning rate (LR) finder algorithm. If the suggested LR is very low, we suggest setting the LR manually.\n",
      "LR Finder: Using GPU:0.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "LR Finder: Suggested learning rate = 0.0013219411484660286\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "      \n",
       "      <progress value='1600' max='1600' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      [1600/1600 01:29, Epoch 200/200]\n",
       "    </div>\n",
       "    <table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       " <tr style=\"text-align: left;\">\n",
       "      <th>Epoch</th>\n",
       "      <th>Training Loss</th>\n",
       "      <th>Validation Loss</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>1</td>\n",
       "      <td>3.304000</td>\n",
       "      <td>3.333564</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2</td>\n",
       "      <td>3.278500</td>\n",
       "      <td>3.318591</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3</td>\n",
       "      <td>3.270400</td>\n",
       "      <td>3.301692</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4</td>\n",
       "      <td>3.246500</td>\n",
       "      <td>3.288637</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>5</td>\n",
       "      <td>3.225700</td>\n",
       "      <td>3.271371</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>6</td>\n",
       "      <td>3.184200</td>\n",
       "      <td>3.254283</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>7</td>\n",
       "      <td>3.181000</td>\n",
       "      <td>3.234792</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>8</td>\n",
       "      <td>3.136400</td>\n",
       "      <td>3.206495</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>9</td>\n",
       "      <td>3.148800</td>\n",
       "      <td>3.154639</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>10</td>\n",
       "      <td>3.076700</td>\n",
       "      <td>3.089476</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>11</td>\n",
       "      <td>2.991600</td>\n",
       "      <td>2.976810</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>12</td>\n",
       "      <td>2.872900</td>\n",
       "      <td>2.814673</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>13</td>\n",
       "      <td>2.705100</td>\n",
       "      <td>2.575317</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>14</td>\n",
       "      <td>2.421100</td>\n",
       "      <td>2.273871</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>15</td>\n",
       "      <td>2.117100</td>\n",
       "      <td>1.828617</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>16</td>\n",
       "      <td>1.704600</td>\n",
       "      <td>1.376519</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>17</td>\n",
       "      <td>1.308700</td>\n",
       "      <td>1.024923</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>18</td>\n",
       "      <td>0.922600</td>\n",
       "      <td>0.765895</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>19</td>\n",
       "      <td>0.687900</td>\n",
       "      <td>0.546625</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>20</td>\n",
       "      <td>0.475100</td>\n",
       "      <td>0.410213</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>21</td>\n",
       "      <td>0.325800</td>\n",
       "      <td>0.327102</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>22</td>\n",
       "      <td>0.240200</td>\n",
       "      <td>0.278404</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>23</td>\n",
       "      <td>0.171500</td>\n",
       "      <td>0.228393</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>24</td>\n",
       "      <td>0.132300</td>\n",
       "      <td>0.194863</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>25</td>\n",
       "      <td>0.099300</td>\n",
       "      <td>0.181284</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>26</td>\n",
       "      <td>0.073600</td>\n",
       "      <td>0.155605</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>27</td>\n",
       "      <td>0.057800</td>\n",
       "      <td>0.128473</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>28</td>\n",
       "      <td>0.043600</td>\n",
       "      <td>0.128538</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>29</td>\n",
       "      <td>0.036500</td>\n",
       "      <td>0.124963</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>30</td>\n",
       "      <td>0.033000</td>\n",
       "      <td>0.114599</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>31</td>\n",
       "      <td>0.024800</td>\n",
       "      <td>0.114230</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>32</td>\n",
       "      <td>0.028300</td>\n",
       "      <td>0.108646</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>33</td>\n",
       "      <td>0.015600</td>\n",
       "      <td>0.111331</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>34</td>\n",
       "      <td>0.015500</td>\n",
       "      <td>0.101725</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>35</td>\n",
       "      <td>0.014700</td>\n",
       "      <td>0.085431</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>36</td>\n",
       "      <td>0.012600</td>\n",
       "      <td>0.090809</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>37</td>\n",
       "      <td>0.010000</td>\n",
       "      <td>0.085086</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>38</td>\n",
       "      <td>0.012900</td>\n",
       "      <td>0.080192</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>39</td>\n",
       "      <td>0.010000</td>\n",
       "      <td>0.086008</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>40</td>\n",
       "      <td>0.008500</td>\n",
       "      <td>0.088949</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>41</td>\n",
       "      <td>0.007100</td>\n",
       "      <td>0.092885</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>42</td>\n",
       "      <td>0.006300</td>\n",
       "      <td>0.085050</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>43</td>\n",
       "      <td>0.005700</td>\n",
       "      <td>0.080808</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>44</td>\n",
       "      <td>0.005200</td>\n",
       "      <td>0.069821</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>45</td>\n",
       "      <td>0.005000</td>\n",
       "      <td>0.066910</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>46</td>\n",
       "      <td>0.003300</td>\n",
       "      <td>0.067954</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>47</td>\n",
       "      <td>0.003600</td>\n",
       "      <td>0.061832</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>48</td>\n",
       "      <td>0.003500</td>\n",
       "      <td>0.056266</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>49</td>\n",
       "      <td>0.003900</td>\n",
       "      <td>0.062069</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>50</td>\n",
       "      <td>0.005000</td>\n",
       "      <td>0.077475</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>51</td>\n",
       "      <td>0.003000</td>\n",
       "      <td>0.069107</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>52</td>\n",
       "      <td>0.002900</td>\n",
       "      <td>0.058045</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>53</td>\n",
       "      <td>0.003300</td>\n",
       "      <td>0.057712</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>54</td>\n",
       "      <td>0.004100</td>\n",
       "      <td>0.056876</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>55</td>\n",
       "      <td>0.002000</td>\n",
       "      <td>0.057304</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>56</td>\n",
       "      <td>0.002100</td>\n",
       "      <td>0.057638</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>57</td>\n",
       "      <td>0.002100</td>\n",
       "      <td>0.061736</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>58</td>\n",
       "      <td>0.001600</td>\n",
       "      <td>0.066391</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>59</td>\n",
       "      <td>0.002100</td>\n",
       "      <td>0.063031</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>60</td>\n",
       "      <td>0.002400</td>\n",
       "      <td>0.057825</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>61</td>\n",
       "      <td>0.002600</td>\n",
       "      <td>0.066053</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>62</td>\n",
       "      <td>0.002700</td>\n",
       "      <td>0.067239</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>63</td>\n",
       "      <td>0.004200</td>\n",
       "      <td>0.055008</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>64</td>\n",
       "      <td>0.002300</td>\n",
       "      <td>0.049982</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>65</td>\n",
       "      <td>0.001600</td>\n",
       "      <td>0.048912</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>66</td>\n",
       "      <td>0.001500</td>\n",
       "      <td>0.052118</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>67</td>\n",
       "      <td>0.001600</td>\n",
       "      <td>0.058328</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>68</td>\n",
       "      <td>0.002200</td>\n",
       "      <td>0.068968</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>69</td>\n",
       "      <td>0.001400</td>\n",
       "      <td>0.067260</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>70</td>\n",
       "      <td>0.002600</td>\n",
       "      <td>0.055577</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>71</td>\n",
       "      <td>0.004200</td>\n",
       "      <td>0.133069</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>72</td>\n",
       "      <td>0.003100</td>\n",
       "      <td>0.112142</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>73</td>\n",
       "      <td>0.001700</td>\n",
       "      <td>0.057151</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>74</td>\n",
       "      <td>0.002000</td>\n",
       "      <td>0.046345</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>75</td>\n",
       "      <td>0.001300</td>\n",
       "      <td>0.046908</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>76</td>\n",
       "      <td>0.001000</td>\n",
       "      <td>0.051823</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>77</td>\n",
       "      <td>0.000900</td>\n",
       "      <td>0.053974</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>78</td>\n",
       "      <td>0.001100</td>\n",
       "      <td>0.054040</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>79</td>\n",
       "      <td>0.001100</td>\n",
       "      <td>0.054258</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>80</td>\n",
       "      <td>0.000700</td>\n",
       "      <td>0.058498</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>81</td>\n",
       "      <td>0.000800</td>\n",
       "      <td>0.055266</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>82</td>\n",
       "      <td>0.000900</td>\n",
       "      <td>0.050853</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>83</td>\n",
       "      <td>0.000700</td>\n",
       "      <td>0.048124</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>84</td>\n",
       "      <td>0.000600</td>\n",
       "      <td>0.046582</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>85</td>\n",
       "      <td>0.000800</td>\n",
       "      <td>0.044850</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>86</td>\n",
       "      <td>0.002800</td>\n",
       "      <td>0.168220</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>87</td>\n",
       "      <td>0.003900</td>\n",
       "      <td>0.106710</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>88</td>\n",
       "      <td>0.001900</td>\n",
       "      <td>0.043439</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>89</td>\n",
       "      <td>0.001000</td>\n",
       "      <td>0.048086</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>90</td>\n",
       "      <td>0.000900</td>\n",
       "      <td>0.045705</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>91</td>\n",
       "      <td>0.000600</td>\n",
       "      <td>0.042380</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>92</td>\n",
       "      <td>0.000700</td>\n",
       "      <td>0.040863</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>93</td>\n",
       "      <td>0.000900</td>\n",
       "      <td>0.039625</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>94</td>\n",
       "      <td>0.000600</td>\n",
       "      <td>0.038800</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>95</td>\n",
       "      <td>0.001000</td>\n",
       "      <td>0.047493</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>96</td>\n",
       "      <td>0.000900</td>\n",
       "      <td>0.047691</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>97</td>\n",
       "      <td>0.001500</td>\n",
       "      <td>0.042556</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>98</td>\n",
       "      <td>0.001000</td>\n",
       "      <td>0.039142</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>99</td>\n",
       "      <td>0.000700</td>\n",
       "      <td>0.036535</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>100</td>\n",
       "      <td>0.000700</td>\n",
       "      <td>0.036946</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>101</td>\n",
       "      <td>0.000700</td>\n",
       "      <td>0.039354</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>102</td>\n",
       "      <td>0.000600</td>\n",
       "      <td>0.038275</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>103</td>\n",
       "      <td>0.000600</td>\n",
       "      <td>0.038373</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>104</td>\n",
       "      <td>0.000600</td>\n",
       "      <td>0.037890</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>105</td>\n",
       "      <td>0.001000</td>\n",
       "      <td>0.043615</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>106</td>\n",
       "      <td>0.000800</td>\n",
       "      <td>0.042967</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>107</td>\n",
       "      <td>0.000600</td>\n",
       "      <td>0.040066</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>108</td>\n",
       "      <td>0.000500</td>\n",
       "      <td>0.037998</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>109</td>\n",
       "      <td>0.000600</td>\n",
       "      <td>0.036763</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>110</td>\n",
       "      <td>0.000600</td>\n",
       "      <td>0.035939</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>111</td>\n",
       "      <td>0.000400</td>\n",
       "      <td>0.035461</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>112</td>\n",
       "      <td>0.000500</td>\n",
       "      <td>0.034601</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>113</td>\n",
       "      <td>0.002700</td>\n",
       "      <td>0.029410</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>114</td>\n",
       "      <td>0.001100</td>\n",
       "      <td>0.029230</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>115</td>\n",
       "      <td>0.000500</td>\n",
       "      <td>0.028483</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>116</td>\n",
       "      <td>0.000500</td>\n",
       "      <td>0.028119</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>117</td>\n",
       "      <td>0.000600</td>\n",
       "      <td>0.028065</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>118</td>\n",
       "      <td>0.000600</td>\n",
       "      <td>0.028456</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>119</td>\n",
       "      <td>0.000600</td>\n",
       "      <td>0.028818</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>120</td>\n",
       "      <td>0.000600</td>\n",
       "      <td>0.029760</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>121</td>\n",
       "      <td>0.000400</td>\n",
       "      <td>0.030621</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>122</td>\n",
       "      <td>0.000500</td>\n",
       "      <td>0.030784</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>123</td>\n",
       "      <td>0.000400</td>\n",
       "      <td>0.030744</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>124</td>\n",
       "      <td>0.000500</td>\n",
       "      <td>0.030176</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>125</td>\n",
       "      <td>0.000400</td>\n",
       "      <td>0.029265</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>126</td>\n",
       "      <td>0.000400</td>\n",
       "      <td>0.028672</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>127</td>\n",
       "      <td>0.000400</td>\n",
       "      <td>0.028624</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>128</td>\n",
       "      <td>0.000300</td>\n",
       "      <td>0.028648</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>129</td>\n",
       "      <td>0.000600</td>\n",
       "      <td>0.028474</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>130</td>\n",
       "      <td>0.000400</td>\n",
       "      <td>0.028463</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>131</td>\n",
       "      <td>0.000400</td>\n",
       "      <td>0.028160</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>132</td>\n",
       "      <td>0.000400</td>\n",
       "      <td>0.027873</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>133</td>\n",
       "      <td>0.000400</td>\n",
       "      <td>0.027770</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>134</td>\n",
       "      <td>0.000400</td>\n",
       "      <td>0.027629</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>135</td>\n",
       "      <td>0.000300</td>\n",
       "      <td>0.027281</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>136</td>\n",
       "      <td>0.000300</td>\n",
       "      <td>0.027041</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>137</td>\n",
       "      <td>0.000300</td>\n",
       "      <td>0.026751</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>138</td>\n",
       "      <td>0.000500</td>\n",
       "      <td>0.026508</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>139</td>\n",
       "      <td>0.000300</td>\n",
       "      <td>0.026219</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>140</td>\n",
       "      <td>0.000300</td>\n",
       "      <td>0.026207</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>141</td>\n",
       "      <td>0.000300</td>\n",
       "      <td>0.026333</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>142</td>\n",
       "      <td>0.000300</td>\n",
       "      <td>0.026400</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>143</td>\n",
       "      <td>0.000400</td>\n",
       "      <td>0.026395</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>144</td>\n",
       "      <td>0.000300</td>\n",
       "      <td>0.026449</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>145</td>\n",
       "      <td>0.000400</td>\n",
       "      <td>0.026381</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>146</td>\n",
       "      <td>0.000300</td>\n",
       "      <td>0.026424</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>147</td>\n",
       "      <td>0.000400</td>\n",
       "      <td>0.026439</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>148</td>\n",
       "      <td>0.000400</td>\n",
       "      <td>0.026313</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>149</td>\n",
       "      <td>0.000300</td>\n",
       "      <td>0.025981</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>150</td>\n",
       "      <td>0.000300</td>\n",
       "      <td>0.025783</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>151</td>\n",
       "      <td>0.000400</td>\n",
       "      <td>0.025650</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>152</td>\n",
       "      <td>0.000400</td>\n",
       "      <td>0.025601</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>153</td>\n",
       "      <td>0.001000</td>\n",
       "      <td>0.024532</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>154</td>\n",
       "      <td>0.000400</td>\n",
       "      <td>0.023899</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>155</td>\n",
       "      <td>0.000400</td>\n",
       "      <td>0.023440</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>156</td>\n",
       "      <td>0.000400</td>\n",
       "      <td>0.023299</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>157</td>\n",
       "      <td>0.000300</td>\n",
       "      <td>0.023289</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>158</td>\n",
       "      <td>0.000300</td>\n",
       "      <td>0.023346</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>159</td>\n",
       "      <td>0.000200</td>\n",
       "      <td>0.023375</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>160</td>\n",
       "      <td>0.000300</td>\n",
       "      <td>0.023414</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>161</td>\n",
       "      <td>0.000300</td>\n",
       "      <td>0.023453</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>162</td>\n",
       "      <td>0.000300</td>\n",
       "      <td>0.023505</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>163</td>\n",
       "      <td>0.000400</td>\n",
       "      <td>0.023610</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>164</td>\n",
       "      <td>0.000400</td>\n",
       "      <td>0.023695</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>165</td>\n",
       "      <td>0.000600</td>\n",
       "      <td>0.023813</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>166</td>\n",
       "      <td>0.000500</td>\n",
       "      <td>0.023989</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>167</td>\n",
       "      <td>0.000400</td>\n",
       "      <td>0.024141</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>168</td>\n",
       "      <td>0.000300</td>\n",
       "      <td>0.024279</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>169</td>\n",
       "      <td>0.000300</td>\n",
       "      <td>0.024327</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>170</td>\n",
       "      <td>0.000300</td>\n",
       "      <td>0.024355</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>171</td>\n",
       "      <td>0.000400</td>\n",
       "      <td>0.024418</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>172</td>\n",
       "      <td>0.000300</td>\n",
       "      <td>0.024448</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>173</td>\n",
       "      <td>0.000300</td>\n",
       "      <td>0.024462</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>174</td>\n",
       "      <td>0.000300</td>\n",
       "      <td>0.024435</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>175</td>\n",
       "      <td>0.000400</td>\n",
       "      <td>0.024384</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>176</td>\n",
       "      <td>0.000300</td>\n",
       "      <td>0.024311</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>177</td>\n",
       "      <td>0.000300</td>\n",
       "      <td>0.024259</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>178</td>\n",
       "      <td>0.000300</td>\n",
       "      <td>0.024225</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>179</td>\n",
       "      <td>0.000300</td>\n",
       "      <td>0.024191</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>180</td>\n",
       "      <td>0.000200</td>\n",
       "      <td>0.024168</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>181</td>\n",
       "      <td>0.000400</td>\n",
       "      <td>0.024151</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>182</td>\n",
       "      <td>0.000300</td>\n",
       "      <td>0.024129</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>183</td>\n",
       "      <td>0.000400</td>\n",
       "      <td>0.024113</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>184</td>\n",
       "      <td>0.000300</td>\n",
       "      <td>0.024082</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>185</td>\n",
       "      <td>0.000200</td>\n",
       "      <td>0.024059</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>186</td>\n",
       "      <td>0.000300</td>\n",
       "      <td>0.024047</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>187</td>\n",
       "      <td>0.000400</td>\n",
       "      <td>0.024047</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>188</td>\n",
       "      <td>0.000200</td>\n",
       "      <td>0.024043</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>189</td>\n",
       "      <td>0.000300</td>\n",
       "      <td>0.024043</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>190</td>\n",
       "      <td>0.000400</td>\n",
       "      <td>0.024047</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>191</td>\n",
       "      <td>0.000200</td>\n",
       "      <td>0.024050</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>192</td>\n",
       "      <td>0.000300</td>\n",
       "      <td>0.024050</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>193</td>\n",
       "      <td>0.000300</td>\n",
       "      <td>0.024049</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>194</td>\n",
       "      <td>0.000300</td>\n",
       "      <td>0.024049</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>195</td>\n",
       "      <td>0.000300</td>\n",
       "      <td>0.024049</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>196</td>\n",
       "      <td>0.000200</td>\n",
       "      <td>0.024050</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>197</td>\n",
       "      <td>0.000300</td>\n",
       "      <td>0.024050</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>198</td>\n",
       "      <td>0.000300</td>\n",
       "      <td>0.024051</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>199</td>\n",
       "      <td>0.000400</td>\n",
       "      <td>0.024051</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>200</td>\n",
       "      <td>0.000300</td>\n",
       "      <td>0.024051</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table><p>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "TrainOutput(global_step=1600, training_loss=0.2590252548675926, metrics={'train_runtime': 90.4591, 'train_samples_per_second': 548.314, 'train_steps_per_second': 17.688, 'total_flos': 555520558694400.0, 'train_loss': 0.2590252548675926, 'epoch': 200.0})"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "temp_dir = tempfile.mkdtemp()  # scratch checkpoint dir; the best model is reloaded into memory at the end\n",
    "\n",
    "# Set this to a float to skip the LR finder and fine-tune with a fixed learning rate.\n",
    "suggested_lr = None\n",
    "\n",
    "train_dict = clf_params[dset][\"TRAINING_PARAMS\"]\n",
    "EPOCHS = train_dict[\"num_train_epochs\"]\n",
    "BATCH_SIZE = train_dict[\"per_device_train_batch_size\"]\n",
    "eval_accumulation_steps = train_dict[\"eval_accumulation_steps\"]\n",
    "NUM_WORKERS = 1\n",
    "NUM_GPUS = 1  # used below to size the OneCycleLR schedule\n",
    "\n",
    "# Re-seed with the notebook-level `seed` so the LR search and fine-tuning are reproducible.\n",
    "set_seed(seed)\n",
    "if suggested_lr is None:\n",
    "    # The finder returns the suggested LR together with a model object that replaces `model`.\n",
    "    lr, model = optimal_lr_finder(\n",
    "        model,\n",
    "        train_dataset,\n",
    "        batch_size=BATCH_SIZE,\n",
    "    )\n",
    "    suggested_lr = lr\n",
    "\n",
    "finetune_args = TrainingArguments(\n",
    "    output_dir=temp_dir,\n",
    "    overwrite_output_dir=True,\n",
    "    learning_rate=suggested_lr,\n",
    "    num_train_epochs=EPOCHS,\n",
    "    do_eval=True,\n",
    "    eval_strategy=\"epoch\",\n",
    "    per_device_train_batch_size=BATCH_SIZE,\n",
    "    per_device_eval_batch_size=BATCH_SIZE,\n",
    "    eval_accumulation_steps=eval_accumulation_steps,\n",
    "    dataloader_num_workers=NUM_WORKERS,\n",
    "    report_to=\"tensorboard\",\n",
    "    save_strategy=\"epoch\",  # must match eval_strategy when load_best_model_at_end=True\n",
    "    logging_strategy=\"epoch\",\n",
    "    save_total_limit=1,\n",
    "    logging_dir=os.path.join(OUT_DIR, \"output\"),  # TensorBoard logs live outside the temp checkpoint dir\n",
    "    load_best_model_at_end=True,  # Load the best model (lowest eval_loss) when training ends\n",
    "    metric_for_best_model=\"eval_loss\",  # Metric monitored for checkpoint selection and early stopping\n",
    "    greater_is_better=False,  # Lower loss is better\n",
    ")\n",
    "\n",
    "# Stop once eval_loss fails to improve by at least the threshold for `patience` consecutive evaluations.\n",
    "# NOTE(review): the `EARYL_STOPPING_PARAMS` key is kept verbatim; it presumably matches the spelling in the\n",
    "# config that populates clf_params - verify there before renaming.\n",
    "early_stopping_callback = EarlyStoppingCallback(\n",
    "    early_stopping_patience=clf_params[dset][\"EARYL_STOPPING_PARAMS\"][\n",
    "        \"early_stopping_patience\"\n",
    "    ],  # Number of epochs with no improvement after which to stop\n",
    "    early_stopping_threshold=0.0001,  # Minimum improvement required to consider as improvement\n",
    ")\n",
    "\n",
    "# A custom optimizer/scheduler passed to the Trainer replaces its defaults.\n",
    "# OneCycleLR advances once per optimizer step, so steps_per_epoch must equal the number of training\n",
    "# batches per epoch; the final partial batch is kept, hence the ceil.\n",
    "optimizer = AdamW(model.parameters(), lr=suggested_lr)\n",
    "scheduler = OneCycleLR(\n",
    "    optimizer,\n",
    "    suggested_lr,\n",
    "    epochs=EPOCHS,\n",
    "    steps_per_epoch=math.ceil(len(train_dataset) / (BATCH_SIZE * NUM_GPUS)),\n",
    ")\n",
    "\n",
    "finetune_trainer = Trainer(\n",
    "    model=model,\n",
    "    args=finetune_args,\n",
    "    train_dataset=train_dataset,\n",
    "    eval_dataset=valid_dataset,\n",
    "    callbacks=[early_stopping_callback],\n",
    "    optimizers=(optimizer, scheduler),\n",
    ")\n",
    "\n",
    "# Fine tune\n",
    "finetune_trainer.train()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f91d63f9-1bcf-457b-a582-7453f9ab2716",
   "metadata": {},
   "source": [
    "## Classification Scores"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "98dce5c3-33ad-4859-99d2-0ef7e54e76aa",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Run inference on the test split. `predict` returns a `PredictionOutput`; because training used\n",
    "# load_best_model_at_end=True, the trainer's model is the best checkpoint by eval_loss.\n",
    "predictions_dict = finetune_trainer.predict(test_dataset)\n",
    "# `predictions` is a single array when the model yields one output and a tuple of arrays otherwise;\n",
    "# the first entry is used below as per-class scores (argmax over axis 1).\n",
    "preds_np = (\n",
    "    predictions_dict.predictions[0]\n",
    "    if isinstance(predictions_dict.predictions, (tuple, list))\n",
    "    else predictions_dict.predictions\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "6562744b-490a-40e0-8d82-ca86b1b44d70",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "test_accuracy :  0.98\n"
     ]
    }
   ],
   "source": [
    "# Collect ground-truth labels in dataset order; shuffle=False matches the sequential order used by `predict`.\n",
    "# RemoveColumnsCollator drops every field except `target_values` before collation.\n",
    "remove_columns_collator = RemoveColumnsCollator(\n",
    "    data_collator=default_data_collator,\n",
    "    signature_columns=[\"target_values\"],\n",
    "    logger=None,\n",
    "    description=None,\n",
    "    model_name=\"temp\",\n",
    ")\n",
    "\n",
    "test_dataloader = DataLoader(test_dataset, batch_size=32, shuffle=False, collate_fn=remove_columns_collator)\n",
    "target_list = []\n",
    "for batch in test_dataloader:\n",
    "    batch_labels = batch[\"target_values\"].numpy()\n",
    "    target_list.append(batch_labels)\n",
    "# Flatten so a trailing singleton label dimension cannot silently broadcast against the predictions.\n",
    "targets_np = np.concatenate(target_list, axis=0).reshape(-1)\n",
    "pred_labels = np.argmax(preds_np, axis=1)\n",
    "assert len(targets_np) == len(pred_labels), \"label/prediction count mismatch\"\n",
    "test_accuracy = np.mean(targets_np == pred_labels)\n",
    "print(\"test_accuracy : \", test_accuracy)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
