{
 "cells": [
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "# Fine-tuning Whisper-large-v2 on Colab using PEFT LoRA + BNB INT8 training",
   "id": "b9b641fdc24bba41"
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "In this Colab, we present a step-by-step guide on how to fine-tune Whisper on any multilingual ASR dataset using Hugging Face 🤗 Transformers and 🤗 PEFT. With 🤗 PEFT and `bitsandbytes`, you can train `whisper-large-v2` on a Colab instance with a T4 GPU (16 GB VRAM). Most of this notebook is adapted from [fine_tune_whisper.ipynb](https://colab.research.google.com/github/sanchit-gandhi/notebooks/blob/main/fine_tune_whisper.ipynb#scrollTo=BRdrdFIeU78w) to train with PEFT LoRA and BNB INT8.\n",
    "\n",
    "For more details on the model, datasets and metrics, refer to the blog post [Fine-Tune Whisper For Multilingual ASR with 🤗 Transformers](https://huggingface.co/blog/fine-tune-whisper).\n",
    "\n"
   ],
   "id": "28c5ecf7bfb49d06"
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "## Initial Setup",
   "id": "100ede8038fc21e4"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "!add-apt-repository -y ppa:jonathonf/ffmpeg-4\n",
    "!apt update\n",
    "!apt install -y ffmpeg"
   ],
   "id": "b6a21c7de022cba1"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "# Quote version specifiers so the shell does not treat `>` as an output redirect\n",
    "!pip install -q \"datasets>=2.6.1\" \"evaluate>=0.3.0\" librosa jiwer gradio\n",
    "!pip install -q bitsandbytes accelerate\n",
    "!pip install -q git+https://github.com/huggingface/transformers.git@main git+https://github.com/huggingface/peft.git@main"
   ],
   "id": "bf1d1871c48abd0c"
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "Linking the notebook to the Hub is straightforward - it simply requires entering your Hub authentication token when prompted. Find your Hub authentication token [here](https://huggingface.co/settings/tokens):",
   "id": "34a5833f9a9c829b"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "from huggingface_hub import notebook_login\n",
    "\n",
    "notebook_login()"
   ],
   "id": "1706e8431fac7aa"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "# Select CUDA device index\n",
    "import os\n",
    "\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\"\n",
    "model_name_or_path = \"openai/whisper-large-v2\"\n",
    "language = \"Marathi\"\n",
    "language_abbr = \"mr\"\n",
    "task = \"transcribe\"\n",
    "dataset_name = \"mozilla-foundation/common_voice_11_0\""
   ],
   "id": "4bbe30ff79760461"
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "## Load Dataset",
   "id": "a69a97e06974db9d"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "from datasets import load_dataset, DatasetDict\n",
    "\n",
    "common_voice = DatasetDict()\n",
    "\n",
    "common_voice[\"train\"] = load_dataset(dataset_name, language_abbr, split=\"train+validation\", use_auth_token=True)\n",
    "common_voice[\"test\"] = load_dataset(dataset_name, language_abbr, split=\"test\", use_auth_token=True)\n",
    "\n",
    "print(common_voice)"
   ],
   "id": "e88865e54182272e"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "common_voice = common_voice.remove_columns(\n",
    "    [\"accent\", \"age\", \"client_id\", \"down_votes\", \"gender\", \"locale\", \"path\", \"segment\", \"up_votes\"]\n",
    ")\n",
    "\n",
    "print(common_voice)"
   ],
   "id": "d1d2083ae028dbd8"
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "## Prepare Feature Extractor, Tokenizer and Data",
   "id": "f2c3a335ade6db8"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "from transformers import WhisperFeatureExtractor\n",
    "\n",
    "feature_extractor = WhisperFeatureExtractor.from_pretrained(model_name_or_path)"
   ],
   "id": "1923c4065d6002"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "from transformers import WhisperTokenizer\n",
    "\n",
    "tokenizer = WhisperTokenizer.from_pretrained(model_name_or_path, language=language, task=task)"
   ],
   "id": "ce98d574b5848a39"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "from transformers import WhisperProcessor\n",
    "\n",
    "processor = WhisperProcessor.from_pretrained(model_name_or_path, language=language, task=task)"
   ],
   "id": "2acce5f925d797b9"
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "### Prepare Data",
   "id": "ead1964ed5d522c5"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": "print(common_voice[\"train\"][0])",
   "id": "694a36c6a262bea1"
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "Since \n",
    "our input audio is sampled at 48kHz, we need to _downsample_ it to \n",
    "16kHz prior to passing it to the Whisper feature extractor, 16kHz being the sampling rate expected by the Whisper model. \n",
    "\n",
    "We'll set the audio inputs to the correct sampling rate using dataset's \n",
    "[`cast_column`](https://huggingface.co/docs/datasets/package_reference/main_classes.html?highlight=cast_column#datasets.DatasetDict.cast_column)\n",
    "method. This operation does not change the audio in-place, \n",
    "but rather signals to `datasets` to resample audio samples _on the fly_ the \n",
    "first time that they are loaded:"
   ],
   "id": "b5b2af62e57dea29"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "from datasets import Audio\n",
    "\n",
    "common_voice = common_voice.cast_column(\"audio\", Audio(sampling_rate=16000))"
   ],
   "id": "1fae05e49349b8a9"
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "Re-loading the first audio sample in the Common Voice dataset will resample \n",
    "it to the desired sampling rate:"
   ],
   "id": "716666f3c9b1d9ec"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": "print(common_voice[\"train\"][0])",
   "id": "4d1470c11429c084"
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "Now we can write a function to prepare our data ready for the model:\n",
    "1. We load and resample the audio data by calling `batch[\"audio\"]`. As explained above, 🤗 Datasets performs any necessary resampling operations on the fly.\n",
    "2. We use the feature extractor to compute the log-Mel spectrogram input features from our 1-dimensional audio array.\n",
    "3. We encode the transcriptions to label ids through the use of the tokenizer."
   ],
   "id": "5407a0ac4f39c14c"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "def prepare_dataset(batch):\n",
    "    # load and resample audio data from 48 to 16kHz\n",
    "    audio = batch[\"audio\"]\n",
    "\n",
    "    # compute log-Mel input features from input audio array\n",
    "    batch[\"input_features\"] = feature_extractor(audio[\"array\"], sampling_rate=audio[\"sampling_rate\"]).input_features[0]\n",
    "\n",
    "    # encode target text to label ids\n",
    "    batch[\"labels\"] = tokenizer(batch[\"sentence\"]).input_ids\n",
    "    return batch"
   ],
   "id": "3ee344591c3712ec"
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "We can apply the data preparation function to all of our training examples using dataset's `.map` method. The argument `num_proc` specifies how many CPU cores to use. Setting `num_proc` > 1 will enable multiprocessing. If the `.map` method hangs with multiprocessing, set `num_proc=1` and process the dataset sequentially.",
   "id": "46e15b002bf790e4"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": "common_voice = common_voice.map(prepare_dataset, remove_columns=common_voice.column_names[\"train\"], num_proc=2)",
   "id": "b8ca2bed3e92f724"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": "common_voice[\"train\"]",
   "id": "84e16760761af8b4"
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "## Training and Evaluation",
   "id": "231f8858c843ccd6"
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "### Define a Data Collator",
   "id": "ecf4771e4620adde"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "import torch\n",
    "\n",
    "from dataclasses import dataclass\n",
    "from typing import Any, Dict, List, Union\n",
    "\n",
    "\n",
    "@dataclass\n",
    "class DataCollatorSpeechSeq2SeqWithPadding:\n",
    "    processor: Any\n",
    "\n",
    "    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:\n",
    "        # split inputs and labels since they have to be of different lengths and need different padding methods\n",
    "        # first treat the audio inputs by simply returning torch tensors\n",
    "        input_features = [{\"input_features\": feature[\"input_features\"]} for feature in features]\n",
    "        batch = self.processor.feature_extractor.pad(input_features, return_tensors=\"pt\")\n",
    "\n",
    "        # get the tokenized label sequences\n",
    "        label_features = [{\"input_ids\": feature[\"labels\"]} for feature in features]\n",
    "        # pad the labels to max length\n",
    "        labels_batch = self.processor.tokenizer.pad(label_features, return_tensors=\"pt\")\n",
    "\n",
    "        # replace padding with -100 to ignore loss correctly\n",
    "        labels = labels_batch[\"input_ids\"].masked_fill(labels_batch.attention_mask.ne(1), -100)\n",
    "\n",
    "        # if a bos token was prepended in the previous tokenization step,\n",
    "        # cut it here as it is added again later anyway\n",
    "        if (labels[:, 0] == self.processor.tokenizer.bos_token_id).all().cpu().item():\n",
    "            labels = labels[:, 1:]\n",
    "\n",
    "        batch[\"labels\"] = labels\n",
    "\n",
    "        return batch"
   ],
   "id": "c471333936c4d797"
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "Let's initialise the data collator we've just defined:",
   "id": "ca7058273ecb38e1"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": "data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor)",
   "id": "66fd0c3ae90bdfae"
  },
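  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "To make the label handling in the collator concrete, here is a minimal sketch in plain Python lists (no `torch`). It assumes a hypothetical pad id of `0` and a hypothetical BOS id of `1`, and the helper name `pad_and_mask` is illustrative only, not part of 🤗 Transformers. It pads the label sequences to a common length, replaces padded positions with `-100` so that the loss skips them, and drops a leading BOS token only when every sequence starts with one:\n",
    "\n",
    "```python\n",
    "IGNORE_INDEX = -100  # label value that PyTorch's cross-entropy loss ignores by default\n",
    "\n",
    "\n",
    "def pad_and_mask(label_seqs, pad_id=0, bos_id=1):\n",
    "    # pad_id and bos_id are assumed values for illustration only\n",
    "    # Pad every sequence on the right to the length of the longest one.\n",
    "    max_len = max(len(seq) for seq in label_seqs)\n",
    "    padded = [seq + [pad_id] * (max_len - len(seq)) for seq in label_seqs]\n",
    "    # Build the mask from the original lengths rather than from token values,\n",
    "    # so a real token that happens to equal pad_id is never masked.\n",
    "    mask = [[1] * len(seq) + [0] * (max_len - len(seq)) for seq in label_seqs]\n",
    "    labels = [\n",
    "        [tok if keep else IGNORE_INDEX for tok, keep in zip(row, m)]\n",
    "        for row, m in zip(padded, mask)\n",
    "    ]\n",
    "    # Drop the leading BOS token only when all sequences start with it.\n",
    "    if all(row[0] == bos_id for row in labels):\n",
    "        labels = [row[1:] for row in labels]\n",
    "    return labels\n",
    "\n",
    "\n",
    "print(pad_and_mask([[1, 5, 6, 7], [1, 8]]))\n",
    "```\n",
    "\n",
    "Under these assumptions the call prints `[[5, 6, 7], [8, -100, -100]]`: the shorter sequence is padded, its padding is masked, and the shared BOS token is removed."
   ],
   "id": "sketch-label-masking"
  },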
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "### Evaluation Metrics",
   "id": "2defe13f30995a5a"
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "We'll use the word error rate (WER) metric, the 'de-facto' metric for assessing \n",
    "ASR systems. For more information, refer to the WER [docs](https://huggingface.co/metrics/wer). We'll load the WER metric from 🤗 Evaluate:"
   ],
   "id": "9241dd4a05fdfe30"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "import evaluate\n",
    "\n",
    "metric = evaluate.load(\"wer\")"
   ],
   "id": "6cd9b7b8474dbd7d"
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "We then simply have to define a function that takes our model \n",
    "predictions and returns the WER metric. This function, called\n",
    "`compute_metrics`, first replaces `-100` with the `pad_token_id`\n",
    "in the `label_ids` (undoing the step we applied in the \n",
    "data collator to ignore padded tokens correctly in the loss).\n",
    "It then decodes the predicted and label ids to strings. Finally,\n",
    "it computes the WER between the predictions and reference labels:"
   ],
   "id": "bbb4f42b12e34ad3"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "def compute_metrics(pred):\n",
    "    pred_ids = pred.predictions\n",
    "    label_ids = pred.label_ids\n",
    "\n",
    "    # replace -100 with the pad_token_id\n",
    "    label_ids[label_ids == -100] = tokenizer.pad_token_id\n",
    "\n",
    "    # we do not want to group tokens when computing the metrics\n",
    "    pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)\n",
    "    label_str = tokenizer.batch_decode(label_ids, skip_special_tokens=True)\n",
    "\n",
    "    wer = 100 * metric.compute(predictions=pred_str, references=label_str)\n",
    "\n",
    "    return {\"wer\": wer}"
   ],
   "id": "7da9d47c3412f946"
  },
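  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "For intuition about the number that `compute_metrics` returns, WER can be sketched in a few lines of plain Python: count the word-level edits (substitutions, insertions, deletions) needed to turn each prediction into its reference, then divide by the total number of reference words. This is a simplified stand-in with hypothetical function names, not the implementation used by 🤗 Evaluate, and it assumes whitespace tokenization:\n",
    "\n",
    "```python\n",
    "def edit_distance(ref_words, hyp_words):\n",
    "    # Classic dynamic-programming Levenshtein distance over words.\n",
    "    prev = list(range(len(hyp_words) + 1))\n",
    "    for i, ref_word in enumerate(ref_words, start=1):\n",
    "        curr = [i]\n",
    "        for j, hyp_word in enumerate(hyp_words, start=1):\n",
    "            cost = 0 if ref_word == hyp_word else 1\n",
    "            curr.append(min(prev[j] + 1,          # deletion\n",
    "                            curr[j - 1] + 1,      # insertion\n",
    "                            prev[j - 1] + cost))  # substitution or match\n",
    "        prev = curr\n",
    "    return prev[-1]\n",
    "\n",
    "\n",
    "def word_error_rate(references, predictions):\n",
    "    # Corpus-level WER: total word edits divided by total reference words.\n",
    "    errors = sum(edit_distance(r.split(), p.split()) for r, p in zip(references, predictions))\n",
    "    total = sum(len(r.split()) for r in references)\n",
    "    return errors / total\n",
    "\n",
    "\n",
    "print(100 * word_error_rate(['the cat sat on the mat'], ['the cat sit on mat']))\n",
    "```\n",
    "\n",
    "In this toy example there is one substitution and one deletion against six reference words, so the WER is 2/6, or roughly 33%."
   ],
   "id": "sketch-wer"
  },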
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "### Load a Pre-Trained Checkpoint",
   "id": "5af83cc34a3b9a9d"
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "Now let's load the pre-trained Whisper `large-v2` checkpoint in 8-bit precision. Again, this \n",
    "is trivial through use of 🤗 Transformers!"
   ],
   "id": "90c954a4c394edb5"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "from transformers import WhisperForConditionalGeneration, BitsAndBytesConfig\n",
    "\n",
    "model = WhisperForConditionalGeneration.from_pretrained(model_name_or_path,\n",
    "                                                        quantization_config=BitsAndBytesConfig(load_in_8bit=True))\n",
    "\n",
    "# model.hf_device_map - this should be {\"\": 0}"
   ],
   "id": "698c8065c11a2332"
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "Override generation arguments - no tokens are forced as decoder outputs (see [`forced_decoder_ids`](https://huggingface.co/docs/transformers/main_classes/text_generation#transformers.generation_utils.GenerationMixin.generate.forced_decoder_ids)), no tokens are suppressed during generation (see [`suppress_tokens`](https://huggingface.co/docs/transformers/main_classes/text_generation#transformers.generation_utils.GenerationMixin.generate.suppress_tokens)):",
   "id": "9630e7ae441f1e56"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "model.config.forced_decoder_ids = None\n",
    "model.config.suppress_tokens = []"
   ],
   "id": "4f53856dcaa7104"
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "### Post-processing on the model\n",
    "\n",
    "Finally, we need to apply some post-processing to the 8-bit model to enable training: we freeze all base-model layers and cast all non-`int8` layers to `float32` for stability."
   ],
   "id": "55ab448c3804a0c5"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "from peft import prepare_model_for_kbit_training\n",
    "\n",
    "model = prepare_model_for_kbit_training(model)"
   ],
   "id": "ab778a80733327ff"
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "### Apply LoRA\n",
    "\n",
    "Here comes the magic with `peft`! Let's wrap the model in a `PeftModel` and specify that we are going to use low-rank adapters (LoRA) via the `get_peft_model` utility function from `peft`."
   ],
   "id": "34dec97b23c8b12d"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "from peft import LoraConfig, get_peft_model\n",
    "\n",
    "config = LoraConfig(r=32, lora_alpha=64, target_modules=[\"q_proj\", \"v_proj\"], lora_dropout=0.05, bias=\"none\")\n",
    "\n",
    "model = get_peft_model(model, config)\n",
    "model.print_trainable_parameters()"
   ],
   "id": "178a0512cd434f54"
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "With LoRA, only about **1%** of the model's parameters are trainable, which is what makes this **Parameter-Efficient Fine-Tuning**.",
   "id": "ab5b186a00c3a53a"
  },
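  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "The trainable-parameter count can be estimated by hand from the `LoraConfig` above. The sketch below assumes the Whisper large-v2 dimensions (hidden size 1280, 32 encoder layers and 32 decoder layers); these values are assumptions here and are worth checking against `model.config`. It also assumes that `q_proj` and `v_proj` occur once per encoder layer and twice per decoder layer (self-attention and cross-attention):\n",
    "\n",
    "```python\n",
    "# Assumed Whisper large-v2 dimensions (verify against model.config before relying on them).\n",
    "d_model = 1280\n",
    "encoder_layers = 32\n",
    "decoder_layers = 32\n",
    "r = 32  # LoRA rank from the LoraConfig above\n",
    "\n",
    "# q_proj and v_proj: one attention block per encoder layer, two per decoder layer.\n",
    "adapted_modules = encoder_layers * 2 + decoder_layers * 4\n",
    "\n",
    "# Each adapted d_model x d_model linear layer gains A (r x d_model) and B (d_model x r).\n",
    "params_per_module = r * (d_model + d_model)\n",
    "trainable_params = adapted_modules * params_per_module\n",
    "\n",
    "print(f'{trainable_params:,} trainable LoRA parameters')\n",
    "print(f'{trainable_params * 4 / 1e6:.1f} MB if stored as float32')\n",
    "```\n",
    "\n",
    "Under these assumptions the adapters hold about 15.7 million parameters, which is roughly 1% of a model with about 1.5 billion parameters, and they take up about 63 MB when stored in `float32`."
   ],
   "id": "sketch-lora-param-count"
  },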
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "### Define the Training Configuration",
   "id": "18febfac8d7f38f0"
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "In the final step, we define all the parameters related to training. For more detail on the training arguments, refer to the Seq2SeqTrainingArguments [docs](https://huggingface.co/docs/transformers/main_classes/trainer#transformers.Seq2SeqTrainingArguments).",
   "id": "b0c9f0b42e12b7c6"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "from transformers import Seq2SeqTrainingArguments\n",
    "\n",
    "training_args = Seq2SeqTrainingArguments(\n",
    "    output_dir=\"temp\",  # change to a repo name of your choice\n",
    "    per_device_train_batch_size=8,\n",
    "    gradient_accumulation_steps=1,  # increase by 2x for every 2x decrease in batch size\n",
    "    learning_rate=1e-3,\n",
    "    warmup_steps=50,\n",
    "    num_train_epochs=3,\n",
    "    eval_strategy=\"epoch\",\n",
    "    fp16=True,\n",
    "    per_device_eval_batch_size=8,\n",
    "    generation_max_length=128,\n",
    "    logging_steps=25,\n",
    "    remove_unused_columns=False,\n",
    "    # required as the PeftModel forward doesn't have the signature of the wrapped model's forward\n",
    "    label_names=[\"labels\"],  # same reason as above\n",
    ")"
   ],
   "id": "2666fb8cbbf31ef8"
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "**A few important notes:**\n",
    "1. `remove_unused_columns=False` and `label_names=[\"labels\"]` are required because the `PeftModel`'s forward doesn't have the signature of the base model's forward.\n",
    "\n",
    "2. INT8 training requires autocasting. `predict_with_generate` can't be passed to the trainer because it internally calls 🤗 Transformers' `generate` without autocasting, which leads to errors.\n",
    "\n",
    "3. Because of point 2, `compute_metrics` shouldn't be passed to `Seq2SeqTrainer`; it is commented out in the cell below."
   ],
   "id": "f852b0b5d7b88c39"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "from transformers import Seq2SeqTrainer, TrainerCallback, TrainingArguments, TrainerState, TrainerControl\n",
    "from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR\n",
    "\n",
    "\n",
    "class SavePeftModelCallback(TrainerCallback):\n",
    "    def on_save(\n",
    "            self,\n",
    "            args: TrainingArguments,\n",
    "            state: TrainerState,\n",
    "            control: TrainerControl,\n",
    "            **kwargs,\n",
    "    ):\n",
    "        checkpoint_folder = os.path.join(args.output_dir, f\"{PREFIX_CHECKPOINT_DIR}-{state.global_step}\")\n",
    "\n",
    "        peft_model_path = os.path.join(checkpoint_folder, \"adapter_model\")\n",
    "        kwargs[\"model\"].save_pretrained(peft_model_path)\n",
    "\n",
    "        pytorch_model_path = os.path.join(checkpoint_folder, \"pytorch_model.bin\")\n",
    "        if os.path.exists(pytorch_model_path):\n",
    "            os.remove(pytorch_model_path)\n",
    "        return control\n",
    "\n",
    "\n",
    "trainer = Seq2SeqTrainer(\n",
    "    args=training_args,\n",
    "    model=model,\n",
    "    train_dataset=common_voice[\"train\"],\n",
    "    eval_dataset=common_voice[\"test\"],\n",
    "    data_collator=data_collator,\n",
    "    # compute_metrics=compute_metrics,\n",
    "    tokenizer=processor.feature_extractor,\n",
    "    callbacks=[SavePeftModelCallback],\n",
    ")\n",
    "model.config.use_cache = False  # silence the warnings. Please re-enable for inference!"
   ],
   "id": "ed1bc81e2a5b1594"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": "trainer.train()",
   "id": "88243dd7eb7cd12a"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "model_name_or_path = \"openai/whisper-large-v2\"\n",
    "peft_model_id = \"smangrul/\" + f\"{model_name_or_path}-{model.peft_config['default'].peft_type.value}-colab\".replace(\"/\", \"-\")\n",
    "model.push_to_hub(peft_model_id)\n",
    "print(peft_model_id)"
   ],
   "id": "8c7746d84ef3333d"
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "## Evaluation and Inference",
   "id": "e0d56e71ed40ba71"
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "**Important points to note during inference:**\n",
    "1. As `predict_with_generate` can't be used, we write the eval loop with `torch.cuda.amp.autocast()` as shown below.\n",
    "2. As the base model is frozen, the PEFT model sometimes fails to recognise the language while decoding. Hence, we force the starting tokens to specify the language we are transcribing. This is done via `forced_decoder_ids = processor.get_decoder_prompt_ids(language=\"Marathi\", task=\"transcribe\")` and passing that to the `model.generate` call.\n",
    "3. Please note that the [AutoEvaluate Leaderboard](https://huggingface.co/spaces/autoevaluate/leaderboards?dataset=mozilla-foundation%2Fcommon_voice_11_0&only_verified=0&task=automatic-speech-recognition&config=mr&split=test&metric=wer) for the `mr` language on `common_voice_11_0` has a bug wherein OpenAI's `BasicTextNormalizer` is used during evaluation, leading to degenerated output text. An example is shown below:\n",
    "```\n",
    "without normalizer: 'स्विच्चान नरुवित्तीची पद्दत मोठ्या प्रमाणात आमलात आणल्या बसोन या दुपन्याने अनेक राथ प्रवेश केला आहे.'\n",
    "with normalizer: 'स व च च न नर व त त च पद दत म ठ य प रम ण त आमल त आणल य बस न य द पन य न अन क र थ प रव श क ल आह'\n",
    "```\n",
    "After fixing this bug, we report two metrics for the top model of the leaderboard and the PEFT model:\n",
    "1. `wer`: WER computed without the `BasicTextNormalizer`, as it doesn't cater to most Indic languages. This is what we consider the true performance metric.\n",
    "2. `normalized_wer`: WER computed with the `BasicTextNormalizer`, to be comparable to the leaderboard metrics.\n",
    "\n",
    "Below are the results:\n",
    "\n",
    "| Model          | DrishtiSharma/whisper-large-v2-marathi | smangrul/openai-whisper-large-v2-LORA-colab |\n",
    "|----------------|----------------------------------------|---------------------------------------------|\n",
    "| wer            | 35.6457                                | 36.1356                                     |\n",
    "| normalized_wer | 13.6440                                | 14.0165                                     |\n",
    "\n",
    "We see that the PEFT model's performance is comparable to the fully fine-tuned model at the top of the leaderboard. At the same time, we are able to train the large model in a Colab notebook with limited GPU memory, with the added advantage that the resulting checkpoint is just `63` MB.\n",
    "\n"
   ],
   "id": "8fe777ab10b63f8b"
  },
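  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "The degenerated text in the example above is consistent with a normalizer that treats Unicode combining marks like punctuation. Devanagari vowel signs are combining marks, so replacing them with spaces splits words apart. The sketch below is a simplified stand-in written with the standard-library `unicodedata` module, with a hypothetical function name; it is not the actual `BasicTextNormalizer` code, only an illustration of the suspected mechanism:\n",
    "\n",
    "```python\n",
    "import unicodedata\n",
    "\n",
    "\n",
    "def strip_marks_and_symbols(text):\n",
    "    # Replace every character whose Unicode category is a mark (M), symbol (S)\n",
    "    # or punctuation (P) with a space, then collapse repeated whitespace.\n",
    "    normalized = unicodedata.normalize('NFKC', text)\n",
    "    replaced = ''.join(' ' if unicodedata.category(ch)[0] in 'MSP' else ch for ch in normalized)\n",
    "    return ' '.join(replaced.split())\n",
    "\n",
    "\n",
    "word = 'मराठी'\n",
    "print([unicodedata.category(ch) for ch in word])\n",
    "print(strip_marks_and_symbols(word))\n",
    "```\n",
    "\n",
    "For English text this only removes punctuation, but for the Marathi word the two vowel signs (category `Mc`) are replaced by spaces, which breaks the word in the same way as the normalized output shown above."
   ],
   "id": "sketch-normalizer-marks"
  },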
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "from peft import PeftModel, PeftConfig\n",
    "from transformers import WhisperForConditionalGeneration, BitsAndBytesConfig\n",
    "\n",
    "peft_model_id = \"smangrul/openai-whisper-large-v2-LORA-colab\"\n",
    "peft_config = PeftConfig.from_pretrained(peft_model_id)\n",
    "model = WhisperForConditionalGeneration.from_pretrained(\n",
    "    peft_config.base_model_name_or_path, quantization_config=BitsAndBytesConfig(load_in_8bit=True), device_map=\"auto\"\n",
    ")\n",
    "model = PeftModel.from_pretrained(model, peft_model_id)"
   ],
   "id": "e86ccf17077583e"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "from torch.utils.data import DataLoader\n",
    "from tqdm import tqdm\n",
    "import numpy as np\n",
    "import gc\n",
    "\n",
    "eval_dataloader = DataLoader(common_voice[\"test\"], batch_size=8, collate_fn=data_collator)\n",
    "\n",
    "model.eval()\n",
    "for step, batch in enumerate(tqdm(eval_dataloader)):\n",
    "    with torch.cuda.amp.autocast():\n",
    "        with torch.no_grad():\n",
    "            generated_tokens = (\n",
    "                model.generate(\n",
    "                    input_features=batch[\"input_features\"].to(\"cuda\"),\n",
    "                    decoder_input_ids=batch[\"labels\"][:, :4].to(\"cuda\"),\n",
    "                    max_new_tokens=255,\n",
    "                )\n",
    "                .cpu()\n",
    "                .numpy()\n",
    "            )\n",
    "            labels = batch[\"labels\"].cpu().numpy()\n",
    "            labels = np.where(labels != -100, labels, tokenizer.pad_token_id)\n",
    "            decoded_preds = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)\n",
    "            decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)\n",
    "            metric.add_batch(\n",
    "                predictions=decoded_preds,\n",
    "                references=decoded_labels,\n",
    "            )\n",
    "    del generated_tokens, labels, batch\n",
    "    gc.collect()\n",
    "wer = 100 * metric.compute()\n",
    "print(f\"{wer=}\")"
   ],
   "id": "61ac81e32b825717"
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "## Using AutomaticSpeechRecognitionPipeline",
   "id": "e2b512107b6a2387"
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "**A few important notes:**\n",
    "1. `pipe()` should be called inside the autocast context manager `with torch.cuda.amp.autocast():`\n",
    "2. `forced_decoder_ids` specifying the `language` being transcribed should be provided in the `generate_kwargs` dict.\n",
    "3. You will get a warning along the lines of the one below, which is **safe to ignore**.\n",
    "```\n",
    "The model 'PeftModel' is not supported for . Supported models are ['SpeechEncoderDecoderModel', 'Speech2TextForConditionalGeneration', 'SpeechT5ForSpeechToText', 'WhisperForConditionalGeneration', 'Data2VecAudioForCTC', 'HubertForCTC', 'MCTCTForCTC', 'SEWForCTC', 'SEWDForCTC', 'UniSpeechForCTC', 'UniSpeechSatForCTC', 'Wav2Vec2ForCTC', 'Wav2Vec2ConformerForCTC', 'WavLMForCTC'].\n",
    "\n",
    "```"
   ],
   "id": "3630e3da345c6713"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "import torch\n",
    "import gradio as gr\n",
    "from transformers import (\n",
    "    AutomaticSpeechRecognitionPipeline,\n",
    "    BitsAndBytesConfig,\n",
    "    WhisperForConditionalGeneration,\n",
    "    WhisperTokenizer,\n",
    "    WhisperProcessor,\n",
    ")\n",
    "from peft import PeftModel, PeftConfig\n",
    "\n",
    "peft_model_id = \"smangrul/openai-whisper-large-v2-LORA-colab\"\n",
    "language = \"Marathi\"\n",
    "task = \"transcribe\"\n",
    "peft_config = PeftConfig.from_pretrained(peft_model_id)\n",
    "model = WhisperForConditionalGeneration.from_pretrained(\n",
    "    peft_config.base_model_name_or_path, quantization_config=BitsAndBytesConfig(load_in_8bit=True), device_map=\"auto\"\n",
    ")\n",
    "\n",
    "model = PeftModel.from_pretrained(model, peft_model_id)\n",
    "tokenizer = WhisperTokenizer.from_pretrained(peft_config.base_model_name_or_path, language=language, task=task)\n",
    "processor = WhisperProcessor.from_pretrained(peft_config.base_model_name_or_path, language=language, task=task)\n",
    "feature_extractor = processor.feature_extractor\n",
    "forced_decoder_ids = processor.get_decoder_prompt_ids(language=language, task=task)\n",
    "pipe = AutomaticSpeechRecognitionPipeline(model=model, tokenizer=tokenizer, feature_extractor=feature_extractor)\n",
    "\n",
    "\n",
    "def transcribe(audio):\n",
    "    with torch.cuda.amp.autocast():\n",
    "        text = pipe(audio, generate_kwargs={\"forced_decoder_ids\": forced_decoder_ids}, max_new_tokens=255)[\"text\"]\n",
    "    return text\n",
    "\n",
    "\n",
    "iface = gr.Interface(\n",
    "    fn=transcribe,\n",
    "    inputs=gr.Audio(sources=[\"microphone\"], type=\"filepath\"),\n",
    "    outputs=\"text\",\n",
    "    title=\"PEFT LoRA + INT8 Whisper Large V2 Marathi\",\n",
    "    description=\"Realtime demo for Marathi speech recognition using `PEFT-LoRA+INT8` fine-tuned Whisper Large V2 model.\",\n",
    ")\n",
    "\n",
    "iface.launch(share=True)"
   ],
   "id": "c35459db646a6265"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": "",
   "id": "b13ce8f94f964044"
  }
 ],
 "metadata": {},
 "nbformat": 4,
 "nbformat_minor": 5
}
