{
 "cells": [
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "# PEFT with DNA Language Models",
   "id": "5813ea61984f612f"
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "This notebook demonstrates how to use parameter-efficient fine-tuning (PEFT) techniques from the PEFT library to fine-tune a DNA Language Model (DNA-LM). The fine-tuned DNA-LM is then applied to a task from the Nucleotide Transformer downstream-task benchmark. PEFT techniques make it practical to adapt large pre-trained models to specific tasks with limited computational resources.",
   "id": "ce34a98e26a66f79"
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "### 1. Import relevant libraries",
   "id": "83620a55c621714"
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "We'll start by importing the required libraries, including the PEFT library and other dependencies.",
   "id": "e6a1fc548a3eba9"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "import torch\n",
    "import transformers\n",
    "import peft\n",
    "import tqdm\n",
    "import numpy as np"
   ],
   "id": "54b6947516fb01f4"
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "### 2. Load models\n",
   "id": "2dcec3cb20f69f48"
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "We'll load a pre-trained DNA Language Model, \"SpeciesLM\", that serves as the base for fine-tuning. This is done using the transformers library from HuggingFace.\n",
    "\n",
    "The tokenizer and the model come from the paper \"Species-aware DNA language models capture regulatory elements and their evolution\". [Paper Link](https://www.biorxiv.org/content/10.1101/2023.01.26.525670v2), [Code Link](https://github.com/gagneurlab/SpeciesLM). The authors introduce a species-aware DNA language model, which is trained on more than 800 species spanning over 500 million years of evolution."
   ],
   "id": "801215edb15c8261"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": "from transformers import AutoTokenizer, AutoModelForMaskedLM",
   "id": "bf7c274bbfaf288a"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "tokenizer = AutoTokenizer.from_pretrained(\"gagneurlab/SpeciesLM\", revision=\"downstream_species_lm\")\n",
    "lm = AutoModelForMaskedLM.from_pretrained(\"gagneurlab/SpeciesLM\", revision=\"downstream_species_lm\")"
   ],
   "id": "9e006d277d4e4c3b"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "lm.eval()\n",
    "lm.to(\"cuda\");"
   ],
   "id": "7c8ccbeced7d2a87"
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "### 3. Prepare datasets",
   "id": "4739a0545930373d"
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "We'll load the `nucleotide_transformer_downstream_tasks` dataset, which contains 18 downstream tasks from the Nucleotide Transformer paper. This dataset provides a consistent genomics benchmark with binary classification tasks.",
   "id": "ed329a2f60d835c7"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "from datasets import load_dataset\n",
    "\n",
    "raw_data = load_dataset(\"InstaDeepAI/nucleotide_transformer_downstream_tasks\", \"H3\")"
   ],
   "id": "d49b652fbb24a9"
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "We'll use the \"H3\" subset of this dataset, which contains 13,468 rows in the training data and 1,497 rows in the test data.",
   "id": "535463f2c5e0eb94"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": "raw_data",
   "id": "5f9e5b0815a743a7"
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "The dataset consists of three columns: `sequence`, `name`, and `label`. A row in this dataset looks like:",
   "id": "54a6dbc6feb861d9"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": "raw_data['train'][0]",
   "id": "d0994e3690345e2a"
  },
  {
   "metadata": {},
   "cell_type": "markdown",
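   "source": "We split our dataset into training, validation, and test sets: 15% of the original training split is held out for validation, and the original test split is kept as is.",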
   "id": "3ff1b1207d3857d1"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "from datasets import DatasetDict\n",
    "\n",
    "# Hold out 15% of the original training split for validation\n",
    "train_valid_split = raw_data['train'].train_test_split(test_size=0.15, seed=42)\n",
    "\n",
    "ds = DatasetDict({\n",
    "    'train': train_valid_split['train'],\n",
    "    'validation': train_valid_split['test'],\n",
    "    'test': raw_data['test']\n",
    "})"
   ],
   "id": "8e510e285e3b908a"
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "Then, we use the tokenizer and a utility function, `get_kmers`, to generate the final inputs and labels. `get_kmers` splits a sequence into the overlapping 6-mers expected by the language model: with k=6 and stride=1, consecutive windows are shifted by one nucleotide, so every position is presented together with its local context. Each input string is also prefixed with a species name (`candida_glabrata` in this notebook) before tokenization."
   ],
   "id": "31a80ee0d7c4d85c"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "def get_kmers(seq, k=6, stride=1):\n",
    "    \"\"\"Return the k-mers of `seq`, taking one window every `stride` characters.\"\"\"\n",
    "    return [seq[i:i + k] for i in range(0, len(seq) - k + 1, stride)]"
   ],
   "id": "579a95eec4e93811"
  },
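  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "As a quick check of the windowing described above, the cell below runs `get_kmers` on a short made-up sequence and builds the space-separated string that is passed to the tokenizer. This is a minimal sketch: `toy_seq` is illustrative and not taken from the dataset, and the cell assumes `get_kmers` from the previous cell has been run.",
   "id": "kmers-toy-example-md"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "# Toy sequence for illustration only (not from the dataset)\n",
    "toy_seq = \"ATGCATGCA\"\n",
    "\n",
    "toy_kmers = get_kmers(toy_seq)  # k=6, stride=1\n",
    "print(toy_kmers)\n",
    "\n",
    "# With stride 1 there are len(seq) - k + 1 windows (here 9 - 6 + 1 = 4)\n",
    "assert len(toy_kmers) == len(toy_seq) - 6 + 1\n",
    "\n",
    "# Same input format as in the preprocessing loop: species name, then the k-mers\n",
    "print(\"candida_glabrata \" + \" \".join(toy_kmers))"
   ],
   "id": "kmers-toy-example-code"
  },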
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "dataset_limit = 200  # NOTE: limited to 200 examples per split so that training runs faster; set to None to use the entire dataset\n",
    "\n",
    "\n",
    "def encode_split(split, limit=None):\n",
    "    \"\"\"Tokenize the first `limit` rows of a split (all rows if None); return (input_ids, labels).\"\"\"\n",
    "    if limit is not None:\n",
    "        split = split.select(range(min(limit, len(split))))\n",
    "    sequences = []\n",
    "    for sequence in split['sequence']:\n",
    "        # Species name followed by the space-separated overlapping 6-mers\n",
    "        text = \"candida_glabrata \" + \" \".join(get_kmers(sequence))\n",
    "        sequences.append(tokenizer(text)[\"input_ids\"])\n",
    "    return sequences, list(split['label'])\n",
    "\n",
    "\n",
    "train_sequences, train_labels = encode_split(ds['train'], dataset_limit)\n",
    "val_sequences, val_labels = encode_split(ds['validation'], dataset_limit)\n",
    "test_sequences, test_labels = encode_split(ds['test'], dataset_limit)"
   ],
   "id": "9d4f9c7c48dc53fe"
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "Finally, we create a Dataset object for each of our sets.",
   "id": "d2316a46f3dff2c9"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "from datasets import Dataset\n",
    "\n",
    "train_dataset = Dataset.from_dict({\"input_ids\": train_sequences, \"labels\": train_labels})\n",
    "val_dataset = Dataset.from_dict({\"input_ids\": val_sequences, \"labels\": val_labels})\n",
    "test_dataset = Dataset.from_dict({\"input_ids\": test_sequences, \"labels\": test_labels})"
   ],
   "id": "2928ba222170a13f"
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "### 4. Train model",
   "id": "cb8f94a23da4c069"
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "Now, we'll train our DNA Language Model on the training dataset. We add a linear classification layer on top of the language model's final hidden layer, and then train all the parameters of the model.",
   "id": "6f45df8e88181958"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "from transformers import DataCollatorWithPadding\n",
    "\n",
    "data_collator = DataCollatorWithPadding(tokenizer=tokenizer)"
   ],
   "id": "17921dc94f82cacc"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "import torch\n",
    "from torch import nn\n",
    "\n",
    "\n",
    "class DNA_LM(nn.Module):\n",
    "    def __init__(self, model, num_labels):\n",
    "        super(DNA_LM, self).__init__()\n",
    "        self.model = model.bert\n",
    "        self.in_features = model.config.hidden_size\n",
    "        self.out_features = num_labels\n",
    "        self.classifier = nn.Linear(self.in_features, self.out_features)\n",
    "\n",
    "    def forward(self, input_ids, attention_mask=None, labels=None):\n",
    "        outputs = self.model(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True)\n",
    "        sequence_output = outputs.hidden_states[-1]\n",
    "        # Use the [CLS] token for classification\n",
    "        cls_output = sequence_output[:, 0, :]\n",
    "        logits = self.classifier(cls_output)\n",
    "\n",
    "        loss = None\n",
    "        if labels is not None:\n",
    "            loss_fct = nn.CrossEntropyLoss()\n",
    "            loss = loss_fct(logits.view(-1, self.out_features), labels.view(-1))\n",
    "\n",
    "        return (loss, logits) if loss is not None else logits\n",
    "\n",
    "\n",
    "# Number of classes for your classification task\n",
    "num_labels = 2\n",
    "classification_model = DNA_LM(lm, num_labels)\n",
    "classification_model.to('cuda');"
   ],
   "id": "c9e566f1fcdba2b9"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "from transformers import Trainer, TrainingArguments\n",
    "\n",
    "# Define training arguments\n",
    "training_args = TrainingArguments(\n",
    "    output_dir='./results',\n",
    "    eval_strategy=\"epoch\",\n",
    "    learning_rate=2e-5,\n",
    "    per_device_train_batch_size=16,\n",
    "    per_device_eval_batch_size=16,\n",
    "    num_train_epochs=5,\n",
    "    weight_decay=0.01,\n",
    "    eval_steps=1,\n",
    "    logging_steps=1,\n",
    ")\n",
    "\n",
    "# Initialize Trainer\n",
    "trainer = Trainer(\n",
    "    model=classification_model,\n",
    "    args=training_args,\n",
    "    train_dataset=train_dataset,\n",
    "    eval_dataset=val_dataset,\n",
    "    tokenizer=tokenizer,\n",
    "    data_collator=data_collator,\n",
    ")\n",
    "\n",
    "# Train the model\n",
    "trainer.train()"
   ],
   "id": "742ab728ad976d0"
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "### 5. Evaluation",
   "id": "2e3668c4f59cd307"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "# Generate predictions\n",
    "\n",
    "predictions = trainer.predict(test_dataset)\n",
    "logits = predictions.predictions\n",
    "predicted_labels = logits.argmax(axis=-1)\n",
    "print(predicted_labels)"
   ],
   "id": "29b8046769a0ebe6"
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "Then, we create a function to calculate the accuracy from the test and predicted labels.",
   "id": "4c16e323f9bf4d58"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "def calculate_accuracy(true_labels, predicted_labels):\n",
    "    # Convert to arrays so that == compares element-wise, even for Python lists\n",
    "    true_labels = np.asarray(true_labels)\n",
    "    predicted_labels = np.asarray(predicted_labels)\n",
    "    assert len(true_labels) == len(predicted_labels), \"Arrays must have the same length\"\n",
    "    return np.mean(true_labels == predicted_labels)\n",
    "\n",
    "\n",
    "accuracy = calculate_accuracy(test_labels, predicted_labels)\n",
    "print(f\"Accuracy: {accuracy:.2f}\")"
   ],
   "id": "d16cbfbe153a79b8"
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "The accuracy is modest, which we can attribute to the small training set (only `dataset_limit = 200` examples are used).",
   "id": "fef317c27b13029a"
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "### 6. Parameter-Efficient Fine-Tuning Techniques",
   "id": "5af35ae8d04fca90"
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "In this section, we demonstrate how to employ parameter-efficient fine-tuning (PEFT) techniques to adapt a pre-trained model for specific genomics tasks using the PEFT library.",
   "id": "99fd88c628e54d31"
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "The `LoraConfig` object is instantiated to configure the PEFT parameters:\n",
    "\n",
    "- r: The rank of the LoRA update matrices.\n",
    "- lora_alpha: Scaling factor applied to the LoRA update.\n",
    "- target_modules: Modules within the model to which LoRA is applied (the query, key, and value projections in this example).\n",
    "- lora_dropout: Dropout probability used in the LoRA layers during fine-tuning.\n",
    "- task_type: Not set in this example; the model here is a custom `nn.Module` wrapper around the language model."
   ],
   "id": "d71ccb1bcfe9e6a8"
  },
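  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "To make the effect of `r` and `target_modules` concrete, the cell below counts the parameters that LoRA adds. It is a back-of-the-envelope sketch under assumed architecture values that are not stated in this notebook (a BERT-base-like encoder with hidden size 768 and 12 layers); the actual values can be read from `lm.config.hidden_size` and `lm.config.num_hidden_layers`. Each adapted square projection gains two low-rank matrices, A (r x d) and B (d x r), and the update is scaled by `lora_alpha / r`.",
   "id": "lora-param-count-md"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "# Assumed architecture values (hypothetical, BERT-base-like); check lm.config for the real ones\n",
    "hidden_size = 768\n",
    "num_layers = 12\n",
    "\n",
    "# Values matching the LoraConfig used in this notebook\n",
    "r = 8\n",
    "lora_alpha = 32\n",
    "target_modules = [\"query\", \"key\", \"value\"]\n",
    "\n",
    "# Each adapted d x d projection gains A (r x d) and B (d x r)\n",
    "params_per_module = r * hidden_size + hidden_size * r\n",
    "lora_params = num_layers * len(target_modules) * params_per_module\n",
    "\n",
    "print(f\"LoRA parameters: {lora_params:,}\")  # 442,368 under the assumed sizes\n",
    "print(f\"Update scaling (lora_alpha / r): {lora_alpha / r}\")"
   ],
   "id": "lora-param-count-code"
  },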
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
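    "# Reload the pre-trained weights: the full fine-tuning run above updated `lm` in place\n",
    "lm = AutoModelForMaskedLM.from_pretrained(\"gagneurlab/SpeciesLM\", revision=\"downstream_species_lm\")\n",
    "\n",
    "# Number of classes for your classification task\n",
    "num_labels = 2\n",
    "classification_model = DNA_LM(lm, num_labels)\n",
    "classification_model.to('cuda');"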
   ],
   "id": "61741510f312d1db"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "from peft import LoraConfig, TaskType\n",
    "\n",
    "peft_config = LoraConfig(\n",
    "    r=8,\n",
    "    lora_alpha=32,\n",
    "    target_modules=[\"query\", \"key\", \"value\"],\n",
    "    lora_dropout=0.01,\n",
    ")"
   ],
   "id": "e0b9af45e454bee3"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
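    "from peft import get_peft_model\n",
    "\n",
    "# get_peft_model freezes every parameter except the injected LoRA matrices, so the\n",
    "# `classifier` head of DNA_LM stays frozen unless it is listed in LoraConfig's `modules_to_save`.\n",
    "peft_model = get_peft_model(classification_model, peft_config)\n",
    "peft_model.print_trainable_parameters()"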
   ],
   "id": "5d758ca0df1a9c2b"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": "peft_model",
   "id": "c59aeaef198006d2"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "# Define training arguments\n",
    "training_args = TrainingArguments(\n",
    "    output_dir='./results',\n",
    "    eval_strategy=\"epoch\",\n",
    "    learning_rate=2e-5,\n",
    "    per_device_train_batch_size=16,\n",
    "    per_device_eval_batch_size=16,\n",
    "    num_train_epochs=5,\n",
    "    weight_decay=0.01,\n",
    "    eval_steps=1,\n",
    "    logging_steps=1,\n",
    ")\n",
    "\n",
    "# Initialize Trainer\n",
    "trainer = Trainer(\n",
    "    model=peft_model.model,\n",
    "    args=training_args,\n",
    "    train_dataset=train_dataset,\n",
    "    eval_dataset=val_dataset,\n",
    "    tokenizer=tokenizer,\n",
    "    data_collator=data_collator,\n",
    ")\n",
    "\n",
    "# Train the model\n",
    "trainer.train()"
   ],
   "id": "704da0276b3ef827"
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "### 7. Evaluate PEFT Model",
   "id": "c75b2fa4d36f34fc"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "# Generate predictions\n",
    "\n",
    "predictions = trainer.predict(test_dataset)\n",
    "logits = predictions.predictions\n",
    "predicted_labels = logits.argmax(axis=-1)\n",
    "print(predicted_labels)"
   ],
   "id": "9cae699c03f89045"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "# Reuse calculate_accuracy, defined in the evaluation section above\n",
    "accuracy = calculate_accuracy(test_labels, predicted_labels)\n",
    "print(f\"Accuracy: {accuracy:.2f}\")"
   ],
   "id": "a2ec67e7004e78ea"
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "As we can see, the PEFT model achieves similar performance to the baseline model, demonstrating that PEFT can adapt pre-trained models to specific tasks with limited computational resources.\n",
    "\n",
    "With PEFT, we only train 442,368 parameters, which is 0.49% of the total parameters in the model. This is a significant reduction in trainable parameters compared to fine-tuning the entire model.\n",
    "\n",
    "We can improve the results by using a larger dataset, fine-tuning the model for more epochs, or changing the hyperparameters (rank, learning rate, etc.).\n"
   ],
   "id": "97b792f4ff8d490"
  }
 ],
 "metadata": {},
 "nbformat": 4,
 "nbformat_minor": 5
}
