{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# core\n",
    "\n",
    "> Training and finetuning loop for ProtMamba models on the Uniclust30 dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| default_exp core"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| hide\n",
    "from nbdev.showdoc import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "\n",
    "# use_one_gpu = input(\"Do you want to use just one gpu? If yes, type the number of the gpu. If no, type -1.\\n\")\n",
    "# import os\n",
    "# if int(use_one_gpu) >= 0:\n",
    "#     os.environ[\"CUDA_VISIBLE_DEVICES\"] = use_one_gpu\n",
    "\n",
    "# Modules\n",
    "from ProtMamba_ssm.modules import *\n",
    "# Trainer\n",
    "from ProtMamba_ssm.trainer import *\n",
    "# Dataloaders\n",
    "from ProtMamba_ssm.dataloaders import *\n",
    "# Utils\n",
    "from ProtMamba_ssm.utils import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| export\n",
    "\n",
    "from transformers import TrainingArguments, get_cosine_schedule_with_warmup, get_constant_schedule_with_warmup, get_cosine_with_hard_restarts_schedule_with_warmup\n",
    "# from transformers.trainer_utils import get_last_checkpoint\n",
    "import yaml\n",
    "import torch\n",
    "import numpy as np\n",
    "from transformers.integrations import TensorBoardCallback\n",
    "from torch.optim import AdamW\n",
    "\n",
    "# List of available models, keyed by the `add_position_ids` config value\n",
    "_mamba_model = {\"none\": MambaLMHeadModelSafe, \"1d\": MambaLMHeadModelwithPosids, \"2d\": MambaLMHeadModelwith2DPosids}\n",
    "\n",
    "def run(config, namedir=None, finetune_model_path=None, restart_optimizer_and_scheduler=False):\n",
    "    r\"\"\"Run the training/finetuning loop.\n",
    "\n",
    "    Args:\n",
    "        config (dict): dictionary with the configuration parameters.\n",
    "        namedir (str, optional): name of the directory where the model will be saved. If None, the name will be taken from the config file.\n",
    "        finetune_model_path (str, optional): path to the model to be finetuned. If None, a new model will be created.\n",
    "        restart_optimizer_and_scheduler (bool, optional): only used when `finetune_model_path` is given. If False,\n",
    "            the optimizer and scheduler states are restored from `optimizer.pt` and `scheduler.pt` in that directory;\n",
    "            if True, they start from scratch.\n",
    "\n",
    "    Returns:\n",
    "        MambaTrainer: the trainer once training has stopped.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if `config[\"dtype\"]` or `config[\"scheduler\"]` has an unsupported value.\n",
    "    \"\"\"\n",
    "    import os  # local import: the module-level `import os` is commented out\n",
    "\n",
    "    if namedir is None:\n",
    "        namedir = config[\"namedir\"]\n",
    "    # Plain concatenation: presumably `output_dir` ends with a path separator\n",
    "    output_path = config[\"output_dir\"] + namedir\n",
    "\n",
    "    # Load Dataset\n",
    "    full_dataset = Uniclust30_Dataset(filename=config[\"train_dataset_path\"],\n",
    "                                      filepath=config[\"data_dir\"],\n",
    "                                      sample=config[\"sample_sequences\"],\n",
    "                                      max_msa_len=config[\"max_msa_len\"],\n",
    "                                      reverse=config[\"reverse\"],\n",
    "                                      seed=config[\"seed_sequence_sampling\"],\n",
    "                                      troubleshoot=False,\n",
    "                                      fim_strategy=config[\"fim_strategy\"],\n",
    "                                      always_mask=config[\"always_mask\"],\n",
    "                                      max_position_embeddings=config[\"max_position_embeddings\"],\n",
    "                                      max_seq_position_embeddings=config[\"max_seq_position_embeddings\"],\n",
    "                                      add_position_ids=config[\"add_position_ids\"])\n",
    "\n",
    "    assert len(AA_TO_ID) == config[\"vocab_size\"], f\"Vocab size in the config file does not match the one in the code. It should be {len(AA_TO_ID)}\"\n",
    "\n",
    "    # Split dataset in train and validation (seeded for reproducibility)\n",
    "    gen = torch.Generator()\n",
    "    gen.manual_seed(config[\"seed_datasets\"])\n",
    "    np_gen = np.random.default_rng(seed=config[\"seed_datasets\"])\n",
    "    len_full_dataset = len(full_dataset)\n",
    "    len_val = config[\"size_validation\"]\n",
    "    len_train = len_full_dataset - len_val\n",
    "    assert len_val < len_full_dataset, \"Validation set is larger than the full dataset\"\n",
    "    assert len_val % config[\"batch_size\"] == 0, \"Validation set size must be a multiple of the batch size\"\n",
    "    print(f\"Train: {len_train} samples, Validation: {len_val} samples\")\n",
    "    train_dataset, eval_dataset = torch.utils.data.random_split(full_dataset, [len_train, len_val], generator=gen)\n",
    "    # Random training subset, the same size as the validation set, used to report train metrics.\n",
    "    # Indices are drawn with replacement (default of Generator.choice), so duplicates are possible.\n",
    "    eval_train_dataset = torch.utils.data.Subset(train_dataset,\n",
    "                                                 np_gen.choice(len(train_dataset), len(eval_dataset)))\n",
    "    # Create data collator for batched training\n",
    "    data_collator = DataCollatorForUniclust30Dataset()\n",
    "\n",
    "    # Configure Mamba model\n",
    "    Mamba = _mamba_model[config[\"add_position_ids\"]]\n",
    "    if config[\"dtype\"] == \"float32\":\n",
    "        dtype = torch.float32\n",
    "        print(\"Using float32\")\n",
    "    elif config[\"dtype\"] == \"bfloat16\":\n",
    "        dtype = torch.bfloat16\n",
    "        print(\"Using bfloat16\")\n",
    "    else:\n",
    "        raise ValueError(\"dtype must be either float32 or bfloat16\")\n",
    "    if finetune_model_path is not None:\n",
    "        # Load model for finetuning\n",
    "        model = load_model(finetune_model_path, device=\"cuda\", model_class=Mamba, dtype=dtype, checkpoint_mixer=config[\"checkpoint_mixer\"])\n",
    "    else:\n",
    "        # Create model for training\n",
    "        mamba_config = MambaConfig(d_model=config[\"d_model\"],\n",
    "                                   n_layer=config[\"n_layer\"],\n",
    "                                   vocab_size=config[\"vocab_size\"],\n",
    "                                   residual_in_fp32=config[\"residual_in_fp32\"])\n",
    "        model = Mamba(mamba_config, dtype=dtype, checkpoint_mixer=config[\"checkpoint_mixer\"])\n",
    "\n",
    "    # Print model information\n",
    "    print_number_of_parameters(model)\n",
    "    print(f\"Epochs: {config['num_epochs']}\")\n",
    "    print(f\"Batch size: {config['batch_size']}\")\n",
    "    print(f\"Gradient accumulation steps: {config['gradient_accumulation_steps']}\")\n",
    "    nr_gpus = torch.cuda.device_count()\n",
    "    # max(..., 1) keeps the effective batch size non-zero (and avoids division by zero below) when no GPU is visible\n",
    "    eff_batch_size = config['batch_size'] * config['gradient_accumulation_steps'] * max(nr_gpus, 1)\n",
    "    print(f\"Effective batch size: {eff_batch_size}\")\n",
    "    print(f\"Steps per training epoch: {len(train_dataset) // config['batch_size']}, eff. steps: {len(train_dataset) // eff_batch_size}\")\n",
    "    print(f\"Steps per evaluation epoch: {len(eval_dataset) // config['batch_size']}\")\n",
    "    print(f\"Max MSA length: {config['max_msa_len']}\")\n",
    "    ev_epochs = round(config['eval_steps']*config[\"batch_size\"]/len(train_dataset), 3)\n",
    "    print(f\"Evaluation every {config['eval_steps']} steps, i.e. {ev_epochs} epochs. Effectively every {config['eval_steps']*config['gradient_accumulation_steps']} steps, i.e. {ev_epochs*config['gradient_accumulation_steps']} epochs.\")\n",
    "\n",
    "    # Training callbacks\n",
    "    es_callback = EarlyStoppingCallback(train_path=output_path, config=config)\n",
    "    callbacks = [TensorBoardCallback(), es_callback]\n",
    "\n",
    "    # Optimizer and Schedulers\n",
    "    optimizer = AdamW(model.parameters(),\n",
    "                      lr=config[\"learning_rate\"],\n",
    "                      betas=(config[\"beta1\"], config[\"beta2\"]),\n",
    "                      weight_decay=config[\"weight_decay\"])\n",
    "    num_training_steps = config[\"num_epochs\"] * len(train_dataset) // eff_batch_size\n",
    "    # One if/elif chain so every supported name selects exactly one scheduler\n",
    "    # (a plain `if` for \"cosine-restarts\" made \"cosine\" fall through to the ValueError).\n",
    "    if config[\"scheduler\"] == \"cosine\":\n",
    "        print(\"Using cosine scheduler\")\n",
    "        scheduler = get_cosine_schedule_with_warmup(optimizer,\n",
    "                                                    num_warmup_steps=config[\"warmup_steps\"],\n",
    "                                                    num_training_steps=num_training_steps)\n",
    "    elif config[\"scheduler\"] == \"cosine-restarts\":\n",
    "        print(\"Using cosine scheduler with hard restarts\")\n",
    "        scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(optimizer,\n",
    "                                                                       num_warmup_steps=config[\"warmup_steps\"],\n",
    "                                                                       num_training_steps=num_training_steps,\n",
    "                                                                       num_cycles=config[\"num_cycles\"])\n",
    "    elif config[\"scheduler\"] == \"constant\":\n",
    "        print(\"Using constant scheduler with warmup\")\n",
    "        scheduler = get_constant_schedule_with_warmup(optimizer, num_warmup_steps=config[\"warmup_steps\"])\n",
    "    else:\n",
    "        raise ValueError(\"Scheduler must be either cosine, cosine-restarts or constant\")\n",
    "\n",
    "    if finetune_model_path is not None and not restart_optimizer_and_scheduler:\n",
    "        # Restore optimizer/scheduler state stored next to the finetuned model\n",
    "        optimizer.load_state_dict(torch.load(os.path.join(finetune_model_path, \"optimizer.pt\")))\n",
    "        scheduler.load_state_dict(torch.load(os.path.join(finetune_model_path, \"scheduler.pt\")))\n",
    "\n",
    "    # Find checkpoint if available (only when not finetuning).\n",
    "    # NOTE(review): get_last_checkpoint's transformers import is commented out; presumably it comes from a\n",
    "    # ProtMamba_ssm star import — confirm.\n",
    "    last_checkpoint = None\n",
    "    if finetune_model_path is None and os.path.exists(output_path):\n",
    "        last_checkpoint = get_last_checkpoint(output_path)\n",
    "        if last_checkpoint is None:\n",
    "            print(\"No checkpoint found, starting training from scratch.\")\n",
    "        else:\n",
    "            print(f\"Resuming training from the last checkpoint: {last_checkpoint}\")\n",
    "    if config[\"checkpoint_mixer\"]:\n",
    "        print(\"Using gradient checkpointing\")\n",
    "    # Create model trainer\n",
    "    trainer = MambaTrainer(\n",
    "        model=model,\n",
    "        train_dataset=train_dataset,\n",
    "        eval_dataset={\"valid\": eval_dataset, \"train\": eval_train_dataset},\n",
    "        optimizers=(optimizer, scheduler),\n",
    "        args=TrainingArguments(\n",
    "            learning_rate=config[\"learning_rate\"],\n",
    "            num_train_epochs=config[\"num_epochs\"],\n",
    "            per_device_train_batch_size=config[\"batch_size\"],\n",
    "            per_device_eval_batch_size=config[\"batch_size\"],\n",
    "            gradient_accumulation_steps=config[\"gradient_accumulation_steps\"],\n",
    "            eval_accumulation_steps=config[\"eval_accumulation_steps\"],\n",
    "            evaluation_strategy=\"steps\",\n",
    "            max_grad_norm=config[\"max_grad_norm\"],\n",
    "            bf16=config[\"dtype\"] == \"bfloat16\",\n",
    "            dataloader_num_workers=10,\n",
    "            logging_steps=config[\"logging_steps\"],\n",
    "            eval_steps=config[\"eval_steps\"],\n",
    "            save_steps=config[\"save_steps\"],\n",
    "            output_dir=output_path,\n",
    "            logging_dir=output_path,\n",
    "            overwrite_output_dir=False,\n",
    "            push_to_hub=False,\n",
    "            label_names=[\"labels\"],\n",
    "        ),\n",
    "        compute_only_fim_loss=config[\"compute_only_fim_loss\"],\n",
    "        data_collator=data_collator,\n",
    "        compute_metrics=compute_metrics,\n",
    "        callbacks=callbacks)\n",
    "    assert trainer.args._n_gpu == nr_gpus, \"Number of gpus used is not the same as the number of gpus available\"\n",
    "    if trainer.compute_only_fim_loss:\n",
    "        print(\"Computing only FIM loss for training\")\n",
    "    # Train model; the loop only repeats when the early-stopping callback asks for a restart\n",
    "    while True:\n",
    "        if last_checkpoint is None and trainer.state.global_step == 0:\n",
    "            eval_results = trainer.evaluate()\n",
    "            print(f\">>> Initial Perplexity: {eval_results['eval_valid_perplexity/batch']:.2f}\")\n",
    "        else:\n",
    "            print(f\"Resuming training from the last checkpoint: {last_checkpoint}\")\n",
    "        # Train\n",
    "        trainer.train(resume_from_checkpoint=last_checkpoint)\n",
    "        # Stop when no restart was requested or the number of epochs is reached\n",
    "        if not es_callback.should_restart or trainer.state.epoch >= config[\"num_epochs\"]:\n",
    "            eval_results = trainer.evaluate()\n",
    "            print(f\">>> Final Perplexity: {eval_results['eval_valid_perplexity/batch']:.2f}\")\n",
    "            break\n",
    "        # If the training was interrupted because of a loss spike, restart from the last checkpoint\n",
    "        last_checkpoint = es_callback.checkpoint_path\n",
    "\n",
    "    return trainer"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# `use_one_gpu` is no longer defined (its input() cell is commented out); GPU selection is done by\n",
    "# setting CUDA_VISIBLE_DEVICES before starting the kernel.\n",
    "visible_gpus = os.environ.get(\"CUDA_VISIBLE_DEVICES\")\n",
    "if visible_gpus:\n",
    "    print(f\"Using gpu {visible_gpus}\")\n",
    "print(\"Number of gpus used: \", torch.cuda.device_count())\n",
    "\n",
    "with open(\"../configs/default_config.yaml\", \"r\") as file:\n",
    "    defaultconfig = yaml.safe_load(file)\n",
    "\n",
    "namedir = input(\"Enter name of directory to save results: \")\n",
    "\n",
    "trainer = run(defaultconfig, namedir)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Finetune"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# `use_one_gpu` is no longer defined (its input() cell is commented out); GPU selection is done by\n",
    "# setting CUDA_VISIBLE_DEVICES before starting the kernel.\n",
    "visible_gpus = os.environ.get(\"CUDA_VISIBLE_DEVICES\")\n",
    "if visible_gpus:\n",
    "    print(f\"Using gpu {visible_gpus}\")\n",
    "print(\"Number of gpus used: \", torch.cuda.device_count())\n",
    "\n",
    "# TODO: set to the checkpoint directory to finetune. `run` passes it to load_model and, unless\n",
    "# restart_optimizer_and_scheduler=True, also reads optimizer.pt and scheduler.pt from it.\n",
    "finetune_model_path = \"\"\n",
    "assert finetune_model_path, \"Set finetune_model_path to the directory of the checkpoint to finetune\"\n",
    "\n",
    "with open(\"../configs/default_config.yaml\", \"r\") as file:\n",
    "    defaultconfig = yaml.safe_load(file)\n",
    "\n",
    "namedir = input(\"Enter name of directory to save results: \")\n",
    "\n",
    "trainer = run(config=defaultconfig,\n",
    "              namedir=namedir,\n",
    "              finetune_model_path=finetune_model_path,\n",
    "              restart_optimizer_and_scheduler=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Compile package"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| hide\n",
    "import nbdev\n",
    "\n",
    "# Export the `#| export` cells into the package, then (re)build README.md with nbdev_readme\n",
    "nbdev.nbdev_export()\n",
    "!nbdev_readme"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "python3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
