{
 "cells": [
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "# Using FourierFT for sequence classification",
   "id": "150971b27da47fd6"
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "In this example, we fine-tune RoBERTa (base) on a sequence classification task (GLUE MRPC) using FourierFT.",
   "id": "625c98b5376c1531"
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "## Imports",
   "id": "2c545d5bd8842005"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "#  To run this notebook, please run `pip install evaluate` to install additional dependencies not covered by PEFT.\n",
    "import torch\n",
    "from torch.optim import AdamW\n",
    "from torch.utils.data import DataLoader\n",
    "from peft import (\n",
    "    get_peft_model,\n",
    "    FourierFTConfig,\n",
    "    PeftType,\n",
    ")\n",
    "\n",
    "import evaluate\n",
    "from datasets import load_dataset\n",
    "from transformers import AutoModelForSequenceClassification, AutoTokenizer, get_linear_schedule_with_warmup, set_seed, \\\n",
    "    AutoConfig\n",
    "from tqdm import tqdm"
   ],
   "id": "b8c9e9d6959a7929"
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "## Parameters",
   "id": "b15728fe787bfa22"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "batch_size = 32\n",
    "model_name_or_path = \"roberta-base\"\n",
    "task = \"mrpc\"  # GLUE task name, passed to `load_dataset` and `evaluate.load`\n",
    "peft_type = PeftType.FOURIERFT\n",
    "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "num_epochs = 5  # for better results, increase this number\n",
    "n_frequency = 1000  # for better results, increase this number\n",
    "scaling = 150.0  # forwarded to FourierFTConfig(scaling=...)\n",
    "max_length = 512  # sequence pairs are truncated to this many tokens when tokenizing\n",
    "\n",
    "# Seed Python's `random`, NumPy and PyTorch (CPU and CUDA) in a single call so that\n",
    "# a fresh Restart & Run All reproduces the same run.\n",
    "set_seed(0)"
   ],
   "id": "b38c08cf4f640303"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "# Learning rates for the two optimizer parameter groups.\n",
    "head_lr = 6e-3  # the learning rate for the classification head for NLU tasks\n",
    "fft_lr = 6e-2  # the learning rate for the parameters other than the classification head (q,v in this case)\n",
    "\n",
    "# FourierFT adapter settings: sequence classification, applied to the `query` and `value` modules.\n",
    "fourierft_kwargs = dict(\n",
    "    task_type=\"SEQ_CLS\",\n",
    "    n_frequency=n_frequency,\n",
    "    target_modules=[\"query\", \"value\"],\n",
    "    scaling=scaling,\n",
    ")\n",
    "peft_config = FourierFTConfig(**fourierft_kwargs)"
   ],
   "id": "126fbb3f8cb435e5"
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "## Loading data",
   "id": "a2c6db145ffce86f"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "# Checkpoints whose name contains one of these substrings are padded on the left, all others on the right.\n",
    "left_padded_families = (\"gpt\", \"opt\", \"bloom\")\n",
    "padding_side = \"left\" if any(k in model_name_or_path for k in left_padded_families) else \"right\"\n",
    "\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, padding_side=padding_side)\n",
    "\n",
    "# Fall back to the EOS token id when the tokenizer defines no padding token.\n",
    "if tokenizer.pad_token_id is None:\n",
    "    tokenizer.pad_token_id = tokenizer.eos_token_id"
   ],
   "id": "7565b1a69d764b9b"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "# The dataset and its metric are both looked up under the same benchmark / task pair.\n",
    "glue_benchmark = \"glue\"\n",
    "datasets = load_dataset(glue_benchmark, task)\n",
    "metric = evaluate.load(glue_benchmark, task)"
   ],
   "id": "770253705e0dab"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "def tokenize_function(examples):\n",
    "    \"\"\"Tokenize a batch of sentence pairs, truncating each pair to `max_length` tokens.\n",
    "\n",
    "    No padding is applied here.\n",
    "    \"\"\"\n",
    "    return tokenizer(\n",
    "        examples[\"sentence1\"],\n",
    "        examples[\"sentence2\"],\n",
    "        truncation=True,\n",
    "        max_length=max_length,\n",
    "    )\n",
    "\n",
    "\n",
    "# Columns dropped once the sentence pairs have been tokenized.\n",
    "dropped_columns = [\"idx\", \"sentence1\", \"sentence2\"]\n",
    "\n",
    "# Tokenize in batches, then rename 'label' to 'labels', which is the name the models of the\n",
    "# transformers library expect for labels.\n",
    "tokenized_datasets = datasets.map(\n",
    "    tokenize_function,\n",
    "    batched=True,\n",
    "    remove_columns=dropped_columns,\n",
    ").rename_column(\"label\", \"labels\")"
   ],
   "id": "7be923780be3f60c"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "def collate_fn(examples):\n",
    "    \"\"\"Pad all examples of a batch to the longest one and return them as PyTorch tensors.\"\"\"\n",
    "    padded_batch = tokenizer.pad(examples, padding=\"longest\", return_tensors=\"pt\")\n",
    "    return padded_batch\n",
    "\n",
    "\n",
    "# Instantiate dataloaders; both share the collate function and batch size, only the train split is shuffled.\n",
    "shared_loader_kwargs = dict(collate_fn=collate_fn, batch_size=batch_size)\n",
    "train_dataloader = DataLoader(tokenized_datasets[\"train\"], shuffle=True, **shared_loader_kwargs)\n",
    "eval_dataloader = DataLoader(tokenized_datasets[\"validation\"], shuffle=False, **shared_loader_kwargs)"
   ],
   "id": "6f277582799c4bef"
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "## Preparing the FourierFT model",
   "id": "8eaa4fb7ae6aae45"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "# Load the pretrained sequence classifier, then wrap it according to `peft_config`.\n",
    "base_model = AutoModelForSequenceClassification.from_pretrained(\n",
    "    model_name_or_path,\n",
    "    return_dict=True,\n",
    "    max_length=None,\n",
    ")\n",
    "model = get_peft_model(base_model, peft_config)\n",
    "model.print_trainable_parameters()"
   ],
   "id": "6e2dc17383c0a0b0"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "# Split the trainable parameters into two groups so that the classification head and the\n",
    "# remaining parameters can be trained with different learning rates.\n",
    "head_param = {id(p) for p in model.classifier.parameters()}  # a set gives O(1) membership tests\n",
    "head_trainable = [p for p in model.classifier.parameters() if p.requires_grad]\n",
    "# Frozen parameters never receive gradients, so they are left out of the optimizer.\n",
    "others_param = [p for p in model.parameters() if p.requires_grad and id(p) not in head_param]\n",
    "\n",
    "optimizer = AdamW(\n",
    "    [\n",
    "        {\"params\": head_trainable, \"lr\": head_lr},\n",
    "        {\"params\": others_param, \"lr\": fft_lr},\n",
    "    ],\n",
    "    weight_decay=0.0,\n",
    ")\n",
    "\n",
    "# Instantiate scheduler: warm up over the first 6% of the optimizer steps.\n",
    "num_training_steps = len(train_dataloader) * num_epochs\n",
    "lr_scheduler = get_linear_schedule_with_warmup(\n",
    "    optimizer=optimizer,\n",
    "    num_warmup_steps=int(0.06 * num_training_steps),  # the scheduler documents an integer step count\n",
    "    num_training_steps=num_training_steps,\n",
    ")"
   ],
   "id": "e146ef8f489a02fd"
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "## Training",
   "id": "51394cf88bbadaa7"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "model.to(device)\n",
    "for epoch in range(num_epochs):\n",
    "    # Optimisation pass over the training split.\n",
    "    model.train()\n",
    "    for batch in tqdm(train_dataloader):\n",
    "        batch.to(device)  # the result is discarded, so the batch is expected to be moved in place\n",
    "        loss = model(**batch).loss\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        lr_scheduler.step()\n",
    "        optimizer.zero_grad()\n",
    "\n",
    "    # Evaluation pass over the validation split; no gradients are needed here.\n",
    "    model.eval()\n",
    "    for batch in tqdm(eval_dataloader):\n",
    "        batch.to(device)\n",
    "        with torch.no_grad():\n",
    "            logits = model(**batch).logits\n",
    "        metric.add_batch(\n",
    "            predictions=logits.argmax(dim=-1),\n",
    "            references=batch[\"labels\"],\n",
    "        )\n",
    "\n",
    "    # Report the metric accumulated over this epoch's evaluation batches.\n",
    "    eval_metric = metric.compute()\n",
    "    print(f\"epoch {epoch}:\", eval_metric)"
   ],
   "id": "228d77872a3d2851"
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "## Share adapters on the 🤗 Hub",
   "id": "6b1e08e55065fc07"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "# Placeholder: replace `...` with your Hugging Face Hub account ID.\n",
    "account_id = ..."
   ],
   "id": "db4735504e8ae477"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "# Repository the model is pushed to, namespaced by the account ID set above.\n",
    "hub_repo_id = f\"{account_id}/roberta-base-mrpc-peft-fourierft\"\n",
    "model.push_to_hub(hub_repo_id)"
   ],
   "id": "44f6171aa2524e68"
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "## Load adapters from the Hub\n",
    "\n",
    "You can also directly load adapters from the Hub using the commands below:"
   ],
   "id": "78454938281059f7"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "# Everything the adapter-loading cells below need, including the model class used to load the base model.\n",
    "import torch\n",
    "from peft import PeftConfig, PeftModel\n",
    "from transformers import AutoModelForSequenceClassification, AutoTokenizer"
   ],
   "id": "92fec5e4b02a7eb3"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "peft_model_id = f\"{account_id}/roberta-base-mrpc-peft-fourierft\"\n",
    "\n",
    "# The adapter config carries the base model name/path used to load both the tokenizer and the base model.\n",
    "config = PeftConfig.from_pretrained(peft_model_id)\n",
    "base_checkpoint = config.base_model_name_or_path\n",
    "\n",
    "tokenizer = AutoTokenizer.from_pretrained(base_checkpoint)\n",
    "inference_model = AutoModelForSequenceClassification.from_pretrained(base_checkpoint)"
   ],
   "id": "833f101babdbc9cd"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "# Wrap the base model with the FourierFT adapter stored under `peft_model_id`.\n",
    "inference_model = PeftModel.from_pretrained(\n",
    "    inference_model,\n",
    "    peft_model_id,\n",
    "    config=config,\n",
    ")"
   ],
   "id": "60d4ccf468cae332"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "# Evaluate the reloaded adapter on the validation dataloader built earlier in the notebook.\n",
    "inference_model.to(device)\n",
    "inference_model.eval()\n",
    "for batch in tqdm(eval_dataloader):\n",
    "    batch.to(device)  # the result is discarded, so the batch is expected to be moved in place\n",
    "    with torch.no_grad():\n",
    "        logits = inference_model(**batch).logits\n",
    "    metric.add_batch(\n",
    "        predictions=logits.argmax(dim=-1),\n",
    "        references=batch[\"labels\"],\n",
    "    )\n",
    "\n",
    "eval_metric = metric.compute()\n",
    "print(eval_metric)"
   ],
   "id": "dfe296a08e402f6d"
  }
 ],
 "metadata": {},
 "nbformat": 4,
 "nbformat_minor": 5
}
