{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# trainer\n",
    "\n",
    "> Fill in a module description here"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| default_exp trainer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| hide\n",
    "from nbdev.showdoc import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "from transformers import Trainer, TrainerCallback, TrainerState, TrainerControl\n",
    "from ProtMamba_ssm.utils import *\n",
    "import re\n",
    "import torch\n",
    "import os\n",
    "\n",
    "class MambaTrainer(Trainer):\n",
    "    \"\"\"\n",
    "    Base HuggingFace Trainer used for training.\n",
    "    \n",
    "    from https://github.com/havenhq/mamba-chat/blob/main/trainer/mamba_trainer.py\"\"\"\n",
    "    def __init__(self, compute_only_fim_loss, **kwargs,):\n",
    "        super().__init__(**kwargs)\n",
    "        self.compute_only_fim_loss = compute_only_fim_loss\n",
    "        \n",
    "    def compute_loss(self, model, inputs, return_outputs=False):\n",
    "        input_ids = inputs.pop(\"input_ids\")\n",
    "        if \"seq_position_ids\" in inputs and \"position_ids\" in inputs:\n",
    "            position_ids = inputs.pop(\"position_ids\")\n",
    "            seq_position_ids = inputs.pop(\"seq_position_ids\")\n",
    "            output = model(input_ids, position_ids=position_ids, seq_position_ids=seq_position_ids)\n",
    "        elif \"position_ids\" in inputs:\n",
    "            position_ids = inputs.pop(\"position_ids\")\n",
    "            output = model(input_ids, position_ids=position_ids)\n",
    "        else:\n",
    "            output = model(input_ids)\n",
    "        lm_logits = output.logits\n",
    "\n",
    "        labels = input_ids.to(lm_logits.device)\n",
    "        shift_logits = lm_logits[:, :-1, :].contiguous()\n",
    "        labels = labels[:, 1:].contiguous()\n",
    "\n",
    "        loss_fct = torch.nn.CrossEntropyLoss()\n",
    "        if self.compute_only_fim_loss:\n",
    "            # start and end tokens\n",
    "            is_cls_tokens = (labels == AA_TO_ID[\"<cls>\"])\n",
    "            is_eos_tokens = (labels == AA_TO_ID[\"<eos>\"])\n",
    "            bool_fim = find_fim_indices(is_cls_tokens, is_eos_tokens)\n",
    "            # include also the cls token\n",
    "            bool_fim = bool_fim | is_cls_tokens\n",
    "            inds = torch.where(bool_fim)\n",
    "            lm_loss = loss_fct(shift_logits[inds[0], inds[1], :], labels[bool_fim])\n",
    "        else:\n",
    "            lm_loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), labels.view(-1))\n",
    "\n",
    "        return (lm_loss, output) if return_outputs else lm_loss\n",
    "\n",
    "    def save_model(self, output_dir, _internal_call):\n",
    "        if not os.path.exists(output_dir):\n",
    "            os.makedirs(output_dir)\n",
    "        self.model.save_pretrained(output_dir)\n",
    "\n",
    "PREFIX_CHECKPOINT_DIR = \"checkpoint\"\n",
    "_re_checkpoint = re.compile(r\"^\" + PREFIX_CHECKPOINT_DIR + r\"\\-(\\d+)$\")\n",
    "\n",
    "def get_last_checkpoint(folder, max_steps=None):\n",
    "    content = os.listdir(folder)\n",
    "    checkpoints = [\n",
    "        path\n",
    "        for path in content\n",
    "        if _re_checkpoint.search(path) is not None and os.path.isdir(os.path.join(folder, path))\n",
    "    ]\n",
    "    if len(checkpoints) == 0:\n",
    "        return\n",
    "    \n",
    "    max_steps = max_steps if max_steps is not None else float(\"inf\")\n",
    "    # func = lambda x: int(_re_checkpoint.search(x).groups()[0])\n",
    "    def func(x):\n",
    "        num = int(_re_checkpoint.search(x).groups()[0])\n",
    "        return num if num < max_steps else -1\n",
    "    return os.path.join(folder, max(checkpoints, key=func))\n",
    "\n",
    "class EarlyStoppingCallback(TrainerCallback):\n",
    "    def __init__(self, train_path, config=None):\n",
    "        self.step_counter = 0\n",
    "        self.best_loss = None\n",
    "        self.train_path = train_path\n",
    "        self.patience = config[\"patience\"]\n",
    "        self.metric_name = config[\"early_stopping_metric\"]\n",
    "        self.checkpoint_path = None\n",
    "        self.should_restart = False\n",
    "        self.eval_steps = config[\"eval_steps\"]\n",
    "        self.loss_increase_factor = config[\"loss_increase_factor\"]\n",
    "    \n",
    "    def get_checkpoint_path(self, max_steps):\n",
    "        last_checkpoint = None\n",
    "        if os.path.exists(self.train_path):\n",
    "            last_checkpoint = get_last_checkpoint(self.train_path, max_steps)\n",
    "            if last_checkpoint is None:\n",
    "                print(\"No checkpoint found, starting training from scratch.\")\n",
    "            else:\n",
    "                print(f\"Max checkpoint allowed: {max_steps}, restarting from {last_checkpoint}.\")\n",
    "        return last_checkpoint\n",
    "\n",
    "    def on_evaluate(self, args, state, control, model, metrics, **kwargs):\n",
    "        if self.metric_name in metrics:\n",
    "            if self.best_loss is None:\n",
    "                self.best_loss = metrics[self.metric_name]\n",
    "            elif self.best_loss*self.loss_increase_factor < metrics[self.metric_name]:\n",
    "                self.step_counter += 1\n",
    "                if self.step_counter >= self.patience:\n",
    "                    checkpoint_path = self.get_checkpoint_path(max_steps=(state.global_step-self.patience*self.eval_steps))\n",
    "                    control.should_training_stop = True\n",
    "                    self.checkpoint_path = checkpoint_path\n",
    "                    self.should_restart = True \n",
    "            else:\n",
    "                self.step_counter = 0\n",
    "                self.best_loss = min(self.best_loss, metrics[self.metric_name])\n",
    "                self.should_restart = False\n",
    "\n",
    "    def on_train_begin(self, args, state, control, **kwargs):\n",
    "        self.step_counter = 0\n",
    "        self.best_loss = None\n",
    "        self.should_restart = False"
   ]
  },
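  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick sanity check, the cell below creates a temporary folder with a few fake `checkpoint-<step>` directories and shows how `get_last_checkpoint` picks the most recent one, optionally capped by `max_steps`. The directory names are made up purely for illustration."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tempfile\n",
    "\n",
    "with tempfile.TemporaryDirectory() as tmp:\n",
    "    # Fake checkpoint directories plus one unrelated folder that should be ignored.\n",
    "    for name in [\"checkpoint-100\", \"checkpoint-200\", \"checkpoint-350\", \"logs\"]:\n",
    "        os.makedirs(os.path.join(tmp, name))\n",
    "    print(get_last_checkpoint(tmp))                 # .../checkpoint-350\n",
    "    print(get_last_checkpoint(tmp, max_steps=300))  # .../checkpoint-200\n",
    "    print(get_last_checkpoint(tmp, max_steps=50))   # None: no checkpoint below the cap"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A minimal sketch of how the callback might be configured. The keys are the ones `EarlyStoppingCallback.__init__` reads; the concrete values and the `./train_output` path are placeholders rather than the settings used in the actual training runs. The callback would then be passed to `MambaTrainer` through its `callbacks` argument, together with a model, a data collator and `TrainingArguments`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Example values only: the keys are what the callback expects, the numbers are placeholders.\n",
    "early_stopping_config = {\n",
    "    \"patience\": 3,                 # consecutive worsening evaluations tolerated before stopping\n",
    "    \"early_stopping_metric\": \"eval_loss\",\n",
    "    \"eval_steps\": 500,             # should match the evaluation interval in TrainingArguments\n",
    "    \"loss_increase_factor\": 1.05,  # allow up to a 5% increase over the best loss so far\n",
    "}\n",
    "early_stopping = EarlyStoppingCallback(train_path=\"./train_output\", config=early_stopping_config)\n",
    "early_stopping.patience, early_stopping.metric_name"
   ]
  },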
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| hide\n",
    "import nbdev; nbdev.nbdev_export()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "python3",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
