{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "xTG2AJZT89qk"
      },
      "source": [
        "## Fine-tune Video-LLaVa on CinePile dataset\n",
        "\n",
        "In this notebook, we are going to fine-tune the [Video-LLaVa](https://huggingface.co/docs/transformers/main/en/model_doc/video_llava) model on the CinePile dataset, a question-answering-based, long-form video understanding dataset.\n",
        "\n",
        "Video-LLaVa is an open-source multimodal model that can accept both images and videos as input in an interleaved manner. Its architecture is very similar to that of [LLaVa](https://huggingface.co/docs/transformers/main/en/model_doc/llava).\n",
        "\n",
        "The goal of the model in this notebook is to answer multiple-choice questions based on the video. The questions can relate to temporal aspects, character and relationship dynamics, narrative and plot analysis, or theme exploration.\n",
        "\n",
        "Sources:\n",
        "\n",
        "* Video-LLaVa [documentation](https://huggingface.co/docs/transformers/main/en/model_doc/video_llava)\n",
        "* Video-LLaVa [checkpoint on the hub that we use for fine-tuning](https://huggingface.co/LanguageBind/Video-LLaVA-7B-hf)\n",
        "\n",
        "**Note: this notebook is a direct adaptation of Niels' [LLaVa notebook](https://github.com/NielsRogge/Transformers-Tutorials/blob/master/LLaVa/Fine_tune_LLaVa_on_a_custom_dataset_(with_PyTorch_Lightning).ipynb).**"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "# Prerequisites\n",
        "\n",
        "This notebook assumes that you have downloaded the videos referenced in the CinePile dataset and that they are accessible in a local folder.\n",
        "We used [Video2Dataset](https://github.com/iejMac/video2dataset) for this. Our YAML config file (`video2dataset-config.yaml`) and the list of video links (`dataset.csv`) can be found in this repo."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Szf17AKL89qm"
      },
      "source": [
        "## Define variables\n",
        "\n",
        "We'll first set some variables used throughout this notebook and do all the necessary imports."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "LJtnWc3b89qn",
        "outputId": "8306b1f9-be6b-4083-f1c7-429a21984f96"
      },
      "outputs": [],
      "source": [
        "import os\n",
        "import sys\n",
        "import json\n",
        "import av\n",
        "import re\n",
        "import bisect\n",
        "import numpy as np\n",
        "import wandb\n",
        "import datetime\n",
        "import cv2\n",
        "\n",
        "from transformers import BitsAndBytesConfig, VideoLlavaForConditionalGeneration, VideoLlavaProcessor\n",
        "from peft import LoraConfig, prepare_model_for_kbit_training, get_peft_model\n",
        "\n",
        "import torch\n",
        "from torch.utils.data import Dataset\n",
        "from torch.utils.data import DataLoader\n",
        "from datasets import load_dataset, concatenate_datasets, load_from_disk\n",
        "\n",
        "import lightning as L\n",
        "from lightning.pytorch.callbacks.early_stopping import EarlyStopping\n",
        "from lightning.pytorch.callbacks import Callback\n",
        "from lightning.pytorch.profilers import SimpleProfiler\n",
        "\n",
        "\n",
        "NUM_FRAMES_VIDEO = 8  # number of frames sampled per video clip\n",
        "MAX_LENGTH_PROCESSOR = 2048  # maximum sequence length passed to the processor (truncation limit)\n",
        "\n",
        "MODEL_ID = \"LanguageBind/Video-LLaVA-7B-hf\"\n",
        "\n",
        "# Path to the download folder of Video2Dataset\n",
        "VIDEO_SNAPSHOT_PATH = \"/path/to/cinepile/fulldatasetvideoscenes/\"\n",
        "\n",
        "# Base path for temporary files and model snapshots\n",
        "LOCAL_PATH = \"/path/to/video-llava-data-cinepile/\"\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "# Auxiliary video processing functions"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "yVdfSXMGlvDz"
      },
      "outputs": [],
      "source": [
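        "def resize_and_crop(img, target_height=224):\n",
        "    \"\"\"Converts a PyAV frame to an RGB array, resizes it to `target_height` keeping the aspect ratio, then center-crops it to a square.\"\"\"\n",
        "    img = img.to_ndarray(format=\"rgb24\")\n",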
        "    if img is None or not isinstance(img, np.ndarray):\n",
        "        raise ValueError(\"Input image is not a valid NumPy array.\")\n",
        "\n",
        "    # Ensure the image is in uint8 format\n",
        "    if img.dtype != np.uint8:\n",
        "        img = np.clip(img, 0, 255).astype(np.uint8)\n",
        "\n",
        "    height, width = img.shape[:2]\n",
        "    \n",
        "    if height <= 0 or width <= 0:\n",
        "        raise ValueError(f\"Image dimensions are invalid: {height}x{width}.\")\n",
        "    \n",
        "    # Calculate the new width while maintaining aspect ratio\n",
        "    aspect_ratio = width / height\n",
        "    new_width = int(target_height * aspect_ratio)\n",
        "    \n",
        "    # Resize image\n",
        "    try:\n",
        "        resized_img = cv2.resize(img, (new_width, target_height))\n",
        "    except cv2.error as e:\n",
        "        raise RuntimeError(f\"Error resizing image: {e}\")\n",
        "    \n",
        "    # Center-crop horizontally so that the width equals the height (square output)\n",
        "    if new_width < target_height:\n",
        "        raise ValueError(f\"Resized width {new_width} is smaller than target height {target_height}.\")\n",
        "    \n",
        "    start_x = (new_width - target_height) // 2\n",
        "    cropped_img = resized_img[:, start_x:start_x + target_height]\n",
        "    \n",
        "    return cropped_img\n",
        "\n",
        "def read_equidistant_frames_pyav(video_path, num_frames):\n",
        "    \"\"\"Reads a video and uniformly samples `num_frames` frames from it, starting after the first second when possible.\"\"\"\n",
        "    container = av.open(video_path)\n",
        "    video = container.streams.video[0]  # first video stream (stream 0 is not guaranteed to be video)\n",
        "\n",
        "    av_timestamps = [\n",
        "        int(packet.pts * video.time_base) for packet in container.demux(video) if packet.pts is not None\n",
        "    ]\n",
        "\n",
        "    av_timestamps.sort()\n",
        "    start_id = bisect.bisect_left(av_timestamps, 1)\n",
        "    end_id = bisect.bisect_left(av_timestamps, 1e10)\n",
        "\n",
        "    # For very short videos, widen the window before sampling\n",
        "    if end_id - start_id < 10:\n",
        "        end_id += 10\n",
        "        start_id -= 10\n",
        "\n",
        "    end_id = min(len(av_timestamps) - 1, end_id)\n",
        "    start_id = max(1, start_id)\n",
        "    indices = np.linspace(start_id, end_id, num_frames).astype(int)\n",
        "\n",
        "    frames = []\n",
        "    container.seek(0)\n",
        "    for i, frame in enumerate(container.decode(video=0)):\n",
        "        if i > end_id:\n",
        "            break\n",
        "        if i >= start_id and i in indices:\n",
        "            frames.append(resize_and_crop(frame))\n",
        "    container.close()\n",
        "    assert len(frames) == num_frames, f\"Got {len(frames)} frames but should be {num_frames}. Check the indices: {indices}, start_id: {start_id}, end_id: {end_id}. Len of video is {len(av_timestamps)} frames.\"\n",
        "    return np.stack(frames)\n",
        "\n",
        "\n",
        "def read_specific_frames_pyav(video_path, frames):\n",
        "    \"\"\"Decodes a video and returns the frames at the given frame indices.\"\"\"\n",
        "    frames_out = []\n",
        "    container = av.open(video_path)\n",
        "    container.seek(0)\n",
        "    last_frame = max(frames)\n",
        "\n",
        "    for idx, frame in enumerate(container.decode(video=0)):\n",
        "        if idx > last_frame:\n",
        "            break  # no need to decode the rest of the video\n",
        "        if idx in frames:\n",
        "            frames_out.append(resize_and_crop(frame))\n",
        "    container.close()\n",
        "\n",
        "    assert len(frames) == len(frames_out), f\"{video_path}: Got {len(frames_out)} frames extracted but should be {len(frames)}\"\n",
        "    return np.stack(frames_out)\n",
        "\n",
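        "def get_frames_for_video(path, cuts, num_frames):\n",
        "    \"\"\"Returns the first frame of `num_frames` evenly spaced scene cuts, or equidistant frames if there are fewer cuts than `num_frames`.\"\"\"\n",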
        "    num_segments = len(cuts)\n",
        "    if num_segments < num_frames:\n",
        "        return read_equidistant_frames_pyav(path, num_frames)\n",
        "    else:\n",
        "        step = num_segments / num_frames\n",
        "        frame_starts = []\n",
        "        for i in range(num_frames):\n",
        "            index = int(i * step)\n",
        "            frame_starts.append(cuts[index][0])\n",
        "\n",
        "        return read_specific_frames_pyav(path, frame_starts)\n",
        "\n",
        "def collate_read_video(example, lookup, num_frames):\n",
        "    # Examples whose video was not downloaded get clip=None\n",
        "    if example['yt_clip_link'] not in lookup:\n",
        "        example['clip'] = None\n",
        "        return example\n",
        "\n",
        "    # Extract the frames once per clip and cache them in the lookup table, since several questions\n",
        "    # can refer to the same clip (with num_proc > 1, each worker process keeps its own cache)\n",
        "    if lookup[example['yt_clip_link']]['precalculation'] is None:\n",
        "        clip = get_frames_for_video(lookup[example['yt_clip_link']]['path'],\n",
        "                        lookup[example['yt_clip_link']]['cuts'],\n",
        "                        num_frames)\n",
        "        lookup[example['yt_clip_link']]['precalculation'] = clip\n",
        "    \n",
        "    example['clip'] = lookup[example['yt_clip_link']]['precalculation']\n",
        "    return example\n",
        "\n",
        "def filter_no_video(example, lookup):\n",
        "    # Keep only the examples whose video was downloaded\n",
        "    return example['yt_clip_link'] in lookup"
      ]
    },
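    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The sampling and cropping logic above can be checked without decoding any video. The two helpers below, `center_crop_box_sketch` and `select_cut_starts_sketch`, are hypothetical and are not used elsewhere in this notebook. They only mirror the index arithmetic of `resize_and_crop` and `get_frames_for_video`. We assume `cuts` is a list of `[start_frame, end_frame]` pairs, as read from the video2dataset metadata in the next section. For a 1280x720 frame, the resized width is 398 pixels and the crop starts at x = 87. With 20 cuts and 8 frames, the step is 2.5, so cuts 0, 2, 5, 7, 10, 12, 15 and 17 are used."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Hypothetical helpers mirroring the index arithmetic above (no video decoding)\n",
        "def center_crop_box_sketch(height, width, target_height=224):\n",
        "    \"\"\"Same arithmetic as resize_and_crop: resized width and left offset of the square crop.\"\"\"\n",
        "    aspect_ratio = width / height\n",
        "    new_width = int(target_height * aspect_ratio)\n",
        "    start_x = (new_width - target_height) // 2\n",
        "    return new_width, start_x\n",
        "\n",
        "\n",
        "def select_cut_starts_sketch(cuts, num_frames):\n",
        "    \"\"\"Same selection as get_frames_for_video: start frames of num_frames evenly spaced cuts.\"\"\"\n",
        "    if len(cuts) < num_frames:\n",
        "        # The notebook falls back to read_equidistant_frames_pyav in this case\n",
        "        raise ValueError('fewer cuts than frames: use equidistant sampling instead')\n",
        "    step = len(cuts) / num_frames\n",
        "    return [cuts[int(i * step)][0] for i in range(num_frames)]\n",
        "\n",
        "\n",
        "print(center_crop_box_sketch(720, 1280))  # prints (398, 87)\n",
        "\n",
        "# Hypothetical clip with 20 scene cuts of 10 frames each\n",
        "example_cuts = [[i * 10, i * 10 + 9] for i in range(20)]\n",
        "print(select_cut_starts_sketch(example_cuts, 8))  # prints [0, 20, 50, 70, 100, 120, 150, 170]"
      ]
    },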
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "# Dataset preparation\n",
        "In this section, we combine the metadata from CinePile with the frames extracted from the videos downloaded using Video2Dataset."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "def generate_video_md_lookup_table(video_base_dir):\n",
        "    \"\"\"Builds a lookup table {yt_clip_link: {path, cuts, precalculation}} from the video2dataset download folder, which may be split into shard subfolders. Only clips that have both a .json metadata file and an .mp4 video are added.\"\"\"\n",
        "    lookup_table = {}\n",
        "\n",
        "    # Traverse the directory recursively\n",
        "    for root, _, files in os.walk(video_base_dir):\n",
        "        for file in files:\n",
        "            if file.endswith('.json'):\n",
        "                json_path = os.path.join(root, file)\n",
        "                video_path = json_path[:-5] + '.mp4'  # Assuming json_path always ends with .json\n",
        "\n",
        "                # Check if the corresponding .mp4 file exists\n",
        "                if os.path.exists(video_path):\n",
        "                    # Read the JSON file\n",
        "                    with open(json_path, 'r') as f:\n",
        "                        data = json.load(f)\n",
        "\n",
        "                    # Extract the cuts section from the JSON\n",
        "                    cuts = data['cuts']['cuts_original_fps']\n",
        "\n",
        "                    # Extract yt_clip_link from the JSON (assuming it's a unique identifier)\n",
        "                    yt_clip_link = data['url']\n",
        "\n",
        "                    # Store the result in the lookup table\n",
        "                    lookup_table[yt_clip_link] = {\"path\": video_path, \"cuts\": cuts, \"precalculation\":None}\n",
        "\n",
        "    return lookup_table"
      ]
    },
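    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The lookup table expects each video2dataset shard folder to contain pairs of files such as `00000.mp4` and `00000.json`. The JSON should have a `url` field and a `cuts` object with a `cuts_original_fps` list of `[start_frame, end_frame]` pairs. These are the keys read by `generate_video_md_lookup_table`. The sketch below fakes such a layout in a temporary folder, with a made-up URL, file names and cut values. It then builds the same `{url: {path, cuts, precalculation}}` structure with `build_lookup_sketch`, a hypothetical compact re-implementation. Only the clip that has both files ends up in the table."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import json\n",
        "import os\n",
        "import tempfile\n",
        "\n",
        "\n",
        "def build_lookup_sketch(video_base_dir):\n",
        "    \"\"\"Compact version of generate_video_md_lookup_table, for illustration only.\"\"\"\n",
        "    lookup_table = {}\n",
        "    for root, _, files in os.walk(video_base_dir):\n",
        "        for file in files:\n",
        "            if not file.endswith('.json'):\n",
        "                continue\n",
        "            json_path = os.path.join(root, file)\n",
        "            video_path = json_path[:-5] + '.mp4'\n",
        "            if not os.path.exists(video_path):\n",
        "                continue  # metadata without a downloaded video is skipped\n",
        "            with open(json_path) as f:\n",
        "                data = json.load(f)\n",
        "            lookup_table[data['url']] = {'path': video_path, 'cuts': data['cuts']['cuts_original_fps'], 'precalculation': None}\n",
        "    return lookup_table\n",
        "\n",
        "\n",
        "with tempfile.TemporaryDirectory() as tmp_dir:\n",
        "    shard_dir = os.path.join(tmp_dir, '00000')\n",
        "    os.makedirs(shard_dir)\n",
        "    # Made-up samples: clip_a has a video file, clip_b only has metadata\n",
        "    for name, url, has_video in [('00000', 'https://example.com/clip_a', True), ('00001', 'https://example.com/clip_b', False)]:\n",
        "        with open(os.path.join(shard_dir, name + '.json'), 'w') as f:\n",
        "            json.dump({'url': url, 'cuts': {'cuts_original_fps': [[0, 120], [121, 300]]}}, f)\n",
        "        if has_video:\n",
        "            open(os.path.join(shard_dir, name + '.mp4'), 'wb').close()  # empty placeholder instead of a real video\n",
        "    demo_lookup = build_lookup_sketch(tmp_dir)\n",
        "\n",
        "print(list(demo_lookup))  # prints ['https://example.com/clip_a']"
      ]
    },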
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "yD9StuiIlvD0",
        "scrolled": true
      },
      "outputs": [],
      "source": [
        "# Load the CinePile metadata saved locally and build the lookup table of downloaded videos\n",
        "ds = load_from_disk(\"../dataset/cinepile\")\n",
        "lookup = generate_video_md_lookup_table(VIDEO_SNAPSHOT_PATH)\n",
        "num_processes = 8\n",
        "train_ds = ds['train']\n",
        "\n",
        "print(f\"Initial train dataset length: {len(train_ds)}\")\n",
        "dataset = train_ds.filter(lambda example: filter_no_video(example, lookup))\n",
        "print(f\"Dataset size after filtering non-video cases: {len(dataset)}\")\n",
        "\n",
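        "# Shard the mapping work so that progress is saved to disk block by block\n",
        "num_blocks = 200\n",
        "# Calculate block size (the last block also takes the remainder)\n",
        "block_size = len(dataset) // num_blocks\n",
        "print(f\"Block size: {block_size}\")\n",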
        "remainder = len(dataset) % num_blocks\n",
        "# Iterate through each block\n",
        "for i in range(num_blocks):\n",
        "    start_idx = i * block_size\n",
        "    end_idx = start_idx + block_size\n",
        "    \n",
        "    # For the last block, include any remaining samples\n",
        "    if i == num_blocks - 1:\n",
        "        end_idx = len(dataset)\n",
        "    \n",
        "    # Select the current shard\n",
        "    print(f\"Selecting between {start_idx}-{end_idx}\")\n",
        "    save_directory = LOCAL_PATH + f'/saved_datasets/cinepile/train_dataset_small/shard_{i}/'\n",
        "\n",
        "    # Useful in case of a crash while preparing the dataset\n",
        "    if os.path.exists(save_directory):\n",
        "        print(\"\\t Folder for that shard found. Skipping recalculation\")\n",
        "        continue\n",
        "\n",
        "    curr_shard = dataset.select(range(start_idx, end_idx))\n",
        "    curr_shard = curr_shard.map(\n",
        "        collate_read_video, \n",
        "        batched=False, \n",
        "        fn_kwargs={\"lookup\": lookup, \"num_frames\": NUM_FRAMES_VIDEO},\n",
        "        num_proc=num_processes,\n",
        "        writer_batch_size=10\n",
        "    )\n",
        "    \n",
        "    \n",
        "    curr_shard.save_to_disk(save_directory)\n",
        "    print(f\"Shard {i} - mapping complete\")\n",
        "    dataset.cleanup_cache_files()\n"
      ]
    },
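    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The block boundaries used in the loop above can be reproduced with `shard_ranges_sketch`, a hypothetical helper with the same arithmetic. Every block gets `len(dataset) // num_blocks` rows, and the last block also absorbs the remainder. For example, 1003 rows in 200 blocks gives 199 blocks of 5 rows and a last block of 8 rows. Note that the loop assumes the dataset has at least `num_blocks` rows. Otherwise the block size is 0 and every shard except the last one is empty."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "def shard_ranges_sketch(num_rows, num_blocks):\n",
        "    \"\"\"Returns the (start_idx, end_idx) pairs used by the sharding loop above.\"\"\"\n",
        "    block_size = num_rows // num_blocks\n",
        "    ranges = []\n",
        "    for i in range(num_blocks):\n",
        "        start_idx = i * block_size\n",
        "        # The last block includes any remaining samples\n",
        "        end_idx = num_rows if i == num_blocks - 1 else start_idx + block_size\n",
        "        ranges.append((start_idx, end_idx))\n",
        "    return ranges\n",
        "\n",
        "\n",
        "demo_ranges = shard_ranges_sketch(1003, 200)\n",
        "print(demo_ranges[0], demo_ranges[-1])  # prints (0, 5) (995, 1003)"
      ]
    },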
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Reload stored partitions\n",
        "num_blocks = 200\n",
        "sharded_dataset = []\n",
        "for i in range(num_blocks):\n",
        "    storage_directory = LOCAL_PATH + f'/saved_datasets/cinepile/train_dataset_small/shard_{i}/'\n",
        "    sharded_dataset.append(load_from_disk(storage_directory))\n",
        "dataset = concatenate_datasets(sharded_dataset)\n",
        "print(f\"load complete - len:{len(dataset)}\")\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "GRTrOYp6lvD0",
        "outputId": "7c90aeec-491a-4f00-de39-43aa0d8d2546"
      },
      "outputs": [],
      "source": [
        "processor = VideoLlavaProcessor.from_pretrained(MODEL_ID)\n",
        "processor.tokenizer.padding_side = \"right\" # during training, one always uses padding on the right"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "rtxYp7h2lvD0"
      },
      "source": [
        "## Custom Dataset Class\n",
        "\n",
        "In the next step, we'll define a custom dataset class and the necessary functions to prepare our data for fine-tuning the Video-LLaVA model. The VideoLlavaDataset class extends the [PyTorch Dataset](https://pytorch.org/tutorials/beginner/basics/data_tutorial.html) class to facilitate loading and processing CinePile. This class handles the conversion of dataset samples into the format required for training and evaluation by building a prompt and converting the video frames into arrays.\n",
        "\n",
        "NOTE: Video-LLaVa accepts videos in one of the following formats:\n",
        "- an array or tensor of shape: (batch-size, frames, channel, height, width) where batch-size is an optional dimension\n",
        "- a list of arrays of shape: (frames, channel, height, width)\n",
        "- a nested list of video frames, where each frame is an image\n",
        "\n",
        "\n",
        "Next, we define collate functions to handle the batching of data during training and evaluation. These functions ensure that the input data is properly formatted and padded.\n",
        "\n",
        "It's only here that we're going to use the processor to turn the (video, target token sequence) into the format that the model expects (which is pixel_values, input_ids etc.). The reason we do that here is because it allows for dynamic padding of the batches: each batch contains ground truth sequences of varying lengths. By only using the processor here, we will pad the input_ids up to the largest sequence in the batch.\n",
        "\n",
        "We also limit the length of the text tokens (input_ids) to a maximum length due to memory constraints. Feel free to increase it if your target token sequences are longer (I'd recommend plotting the average token length of your dataset to determine the optimal value). Keep in mind that, by default, truncation removes tokens from the end of the sequence, so a limit that is too small can cut off the answer at the end of the prompt.\n",
        "\n",
        "The formatting of the input_ids is super important: we need to respect a so-called [chat template](https://huggingface.co/docs/transformers/main/en/chat_templating). At the time of writing, Video-LLaVa does not support chat templates, so we manually write the prompt in the correct format (which starts with USER and ends with ASSISTANT). You could also omit this and just train the model on (video, instruction) pairs without a text prompt.\n",
        "\n",
        "Labels are created for the model by simply copying the inputs to the LLM (input_ids), but with padding tokens replaced by the ignore index of the loss function. This ensures that the model doesn't need to learn to predict padding tokens (used to batch examples together).\n",
        "\n",
        "Why are the labels a copy of the model inputs, you may ask? The model internally shifts the labels by one position, so that the prediction at each position is compared with the following token. This is how the model learns to predict the next token.\n",
        "\n",
        "The collate function for evaluation is different, since there we only need to feed the prompt to the model, as we'll use the `generate()` method to autoregressively generate a completion."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "un_H3-pPlvD0"
      },
      "outputs": [],
      "source": [
        "class VideoLlavaDataset(Dataset):\n",
        "    \"\"\"\n",
        "    PyTorch Dataset wrapping a Hugging Face Dataset of CinePile samples. Each item is a (prompt, clip) pair.\n",
        "    \"\"\"\n",
        "    \n",
        "    def format_question_and_options(self, question, options):\n",
        "        \"\"\"\n",
        "        Formats a question and a list of options into a single string with options labeled A, B, C, etc.\n",
        "\n",
        "        Parameters:\n",
        "        - question (str): The question to be formatted.\n",
        "        - options (list of str): The options for the question.\n",
        "\n",
        "        Returns:\n",
        "        - str: The formatted question and options.\n",
        "        \"\"\"\n",
        "        formatted_string = f\"{question}\\n\"\n",
        "        option_labels = [chr(ord('A') + i) for i in range(len(options))]  # Generate option labels dynamically\n",
        "\n",
        "        for label, option in zip(option_labels, options):\n",
        "            formatted_string += f\"- {label}) {option}\\n\"\n",
        "\n",
        "        return formatted_string\n",
        "\n",
        "    def __init__(\n",
        "        self,\n",
        "        dataset,  # a Hugging Face datasets.Dataset with the CinePile columns and the extracted 'clip'\n",
        "    ):\n",
        "        super().__init__()\n",
        "        self.dataset = dataset\n",
        "        self.id2choice = {0: \"A\", 1: \"B\", 2: \"C\", 3: \"D\", 4: \"E\"}\n",
        "\n",
        "    def __len__(self) -> int:\n",
        "        return len(self.dataset)\n",
        "\n",
        "    def __getitem__(self, idx: int):\n",
        "        sample = self.dataset[idx]\n",
        "        clip = np.array(sample[\"clip\"])\n",
        "        vision_and_language_dependence_prompt = '''USER: <prompt>You will be provided with subtitles from a specific scene of a movie and a few frames from that scene. After going through the movie scene and seeing the frames, please answer the question that follows. The question will have five possible answers labeled A, B, C, D, and E, please try to provide the most probable answer in your opinion. Your output should be just one of A,B,C,D,E and nothing else.\n",
        "\n",
        "**Output Format:**\n",
        "    **Answer:** <Option_key>\n",
        "**Video:** <video>\\n\n",
        "**Subtitles:** \\n{subs}\\nQuestion: {question}\n",
        "\n",
        "Note: Follow the output format strictly. Only answer with the option key (A, B, C, D, E) and nothing else.\n",
        "ASSISTANT:{choice}'''\n",
        "\n",
        "        formatted_question = self.format_question_and_options(sample['question'], sample['choices'])\n",
        "        prompt = vision_and_language_dependence_prompt.format(subs=sample['subtitles'], question=formatted_question, choice=self.id2choice[sample['answer_key_position']])\n",
        "        return prompt, clip"
      ]
    },
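    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To see what the end of a training string looks like, the sketch below reproduces the option formatting of `format_question_and_options` on a made-up question. The question, options and answer index are hypothetical, and the long instruction and subtitle part of the template is left out. Because the template ends with `ASSISTANT:{choice}`, the answer letter is always the last character of the prompt. Removing that single character leaves a prompt that ends with `ASSISTANT:`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "def format_options_sketch(question, options):\n",
        "    \"\"\"Same formatting as VideoLlavaDataset.format_question_and_options.\"\"\"\n",
        "    formatted = f'{question}\\n'\n",
        "    for i, option in enumerate(options):\n",
        "        label = chr(ord('A') + i)  # A, B, C, ...\n",
        "        formatted += f'- {label}) {option}\\n'\n",
        "    return formatted\n",
        "\n",
        "\n",
        "demo_id2choice = {0: 'A', 1: 'B', 2: 'C', 3: 'D', 4: 'E'}\n",
        "demo_question = format_options_sketch(\n",
        "    'Why does the man leave the room?',\n",
        "    ['To answer the phone', 'He is angry', 'To find his keys', 'Someone calls him', 'It is late'],\n",
        ")\n",
        "# Only the tail of the template is reproduced here\n",
        "demo_prompt_tail = f'Question: {demo_question}\\nASSISTANT:{demo_id2choice[2]}'\n",
        "print(demo_prompt_tail)\n",
        "print(repr(demo_prompt_tail[:-1][-10:]), repr(demo_prompt_tail[-1]))  # prints 'ASSISTANT:' 'C'"
      ]
    },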
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "AQ6rlxudlvD1"
      },
      "outputs": [],
      "source": [
        "def train_collate_fn(examples):\n",
        "    texts, videos = list(zip(*examples))\n",
        "    # Note: truncation cuts from the end, so a too-small max_length can drop the answer letter\n",
        "    batch = processor(text=texts, videos=videos, padding=True, truncation=True, max_length=MAX_LENGTH_PROCESSOR, return_tensors=\"pt\")\n",
        "\n",
        "    # Labels are a copy of input_ids with padding positions set to -100 (ignored by the loss)\n",
        "    labels = batch[\"input_ids\"].clone()\n",
        "    labels[labels == processor.tokenizer.pad_token_id] = -100\n",
        "\n",
        "    input_ids = batch[\"input_ids\"]\n",
        "    attention_mask = batch[\"attention_mask\"]\n",
        "    pixel_values_videos = batch[\"pixel_values_videos\"]\n",
        "\n",
        "    return input_ids, attention_mask, pixel_values_videos, labels\n",
        "\n",
        "\n",
        "def eval_collate_fn(examples):\n",
        "    now = datetime.datetime.now()  # `datetime` is imported as a module above\n",
        "    current_time = now.strftime(\"%Y-%m-%d_%H-%M-%S\")\n",
        "    print(f\"{current_time}-inside eval\")\n",
        "\n",
        "    # We only feed the prompt to the model: drop the trailing answer letter so that the text ends with \"ASSISTANT:\"\n",
        "    textsOriginal, videos = list(zip(*examples))\n",
        "    texts = [text[:-1] for text in textsOriginal]\n",
        "    batch = processor(text=texts, videos=videos, padding=True, truncation=True, max_length=MAX_LENGTH_PROCESSOR, return_tensors=\"pt\")\n",
        "\n",
        "    input_ids = batch[\"input_ids\"]\n",
        "    attention_mask = batch[\"attention_mask\"]\n",
        "    pixel_values_videos = batch[\"pixel_values_videos\"]\n",
        "    answer_choice = [text[-1] for text in textsOriginal]\n",
        "    return input_ids, attention_mask, pixel_values_videos, answer_choice"
      ]
    },
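    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The label construction in `train_collate_fn` and the next-token shift described earlier can be reproduced on toy tensors. In this sketch, the token ids, the pad id `0` and the random logits are made up. The shift (`logits[:, :-1]` scored against `labels[:, 1:]`) follows the usual causal language modeling convention, which we assume matches what the model does internally. Positions labeled `-100` do not contribute to the loss."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import torch\n",
        "import torch.nn.functional as F\n",
        "\n",
        "toy_pad_id = 0  # hypothetical pad token id\n",
        "toy_input_ids = torch.tensor([[5, 6, 7, 8, 0, 0]])\n",
        "\n",
        "# Same masking as train_collate_fn: padding positions get the ignore index -100\n",
        "toy_labels = toy_input_ids.clone()\n",
        "toy_labels[toy_labels == toy_pad_id] = -100\n",
        "print(toy_labels.tolist())  # prints [[5, 6, 7, 8, -100, -100]]\n",
        "\n",
        "# Next-token objective: the logits at position t are scored against the label at position t + 1\n",
        "toy_vocab_size = 10\n",
        "torch.manual_seed(0)\n",
        "toy_logits = torch.randn(1, toy_input_ids.shape[1], toy_vocab_size)\n",
        "shift_logits = toy_logits[:, :-1, :].reshape(-1, toy_vocab_size)\n",
        "shift_labels = toy_labels[:, 1:].reshape(-1)\n",
        "print(shift_labels.tolist())  # prints [6, 7, 8, -100, -100]\n",
        "toy_loss = F.cross_entropy(shift_logits, shift_labels, ignore_index=-100)\n",
        "print(toy_loss.item())"
      ]
    },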
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "0sH2oWArlvD1"
      },
      "source": [
        "## Combining and Splitting the Dataset"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "DenEK2IqlvD1"
      },
      "outputs": [],
      "source": [
        "dataset = dataset.shuffle(seed=42)\n",
        "dataset = dataset.train_test_split(test_size=0.2, seed=42)  # a fixed seed makes the split reproducible\n",
        "train_dataset = VideoLlavaDataset(dataset[\"train\"])\n",
        "eval_dataset = VideoLlavaDataset(dataset[\"test\"])"
      ]
    },
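    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "`shuffle(seed=42)` makes the shuffle reproducible. However, `train_test_split` also shuffles the rows by default and, unless a `seed` is passed, draws fresh randomness on every run. The quick check below uses a toy dataset with a hypothetical `idx` column and shows that passing the same `seed` gives the same split every time."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "from datasets import Dataset as HFDataset  # aliased to avoid shadowing torch's Dataset\n",
        "\n",
        "toy_ds = HFDataset.from_dict({'idx': list(range(10))})\n",
        "split_a = toy_ds.train_test_split(test_size=0.2, seed=42)\n",
        "split_b = toy_ds.train_test_split(test_size=0.2, seed=42)\n",
        "\n",
        "test_ids_a = [row['idx'] for row in split_a['test']]\n",
        "test_ids_b = [row['idx'] for row in split_b['test']]\n",
        "print(len(split_a['train']), len(split_a['test']))  # prints 8 2\n",
        "print(test_ids_a == test_ids_b)  # prints True"
      ]
    },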
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "pd4mrynzlvD2"
      },
      "source": [
        "## Load model\n",
        "Next, we're going to load the Video-LLaVa model from the hub. This model has about 7 billion parameters (it combines a LLaMa-7B language model with a relatively small vision encoder). Note that the checkpoint we load has already undergone supervised fine-tuning (SFT) on the VideoChat instruction dataset, so we can benefit from that fine-tuning.\n",
        "\n",
        "## Q-LoRa\n",
        "As this model has 7 billion parameters, that's going to have quite an impact on the amount of memory used. For reference, fine-tuning a model with the AdamW optimizer (which is often used to optimize neural networks) in mixed precision requires roughly 18 bytes of GPU memory per parameter. So in this case, we would need 18 x 7 billion bytes = 126 GB of GPU RAM if we wanted to update all the parameters of the model! That's huge, right? And infeasible for most people.\n",
        "\n",
        "Luckily, some clever people came up with the LoRa method (LoRa is short for low-rank adaptation). It lets you freeze the existing weights and only train a couple of adapter layers on top of the base model. Hugging Face offers the separate [PEFT library](https://huggingface.co/docs/peft/main/en/index) for easy use of LoRa, along with other Parameter-Efficient Fine-Tuning methods (that's where the name PEFT comes from).\n",
        "\n",
        "Moreover, one can not only freeze the existing base model but also quantize it (which means shrinking down its size). A neural network's parameters are typically stored in either float32 (32 bits or 4 bytes per parameter value) or float16 (16 bits or 2 bytes, also called half precision). However, with some clever algorithms one can shrink each parameter to just 8 or 4 bits (half a byte!) without a significant effect on final performance. Read all about it here: https://huggingface.co/blog/4bit-transformers-bitsandbytes.\n",
        "\n",
        "Of course, if you have the memory available, feel free to use full fine-tuning or LoRa without quantization! The code snippet below covers only the QLoRa case. For full fine-tuning, you would drop the quantization config, and you could additionally load the model with Flash Attention, which considerably speeds up computations.\n",
        "\n",
        "There exist many forms of quantization, here we leverage the [BitsAndBytes integration](https://huggingface.co/docs/transformers/main_classes/quantization#transformers.BitsAndBytesConfig)."
      ]
    },
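    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The back-of-the-envelope numbers above can be reproduced with a few lines of Python. This only counts the memory needed to store the weights, plus the rule of thumb of roughly 18 bytes per parameter quoted above for full fine-tuning with AdamW. Activations, quantization overhead and framework overhead are ignored, so real usage is higher."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "approx_num_params = 7e9  # roughly 7 billion parameters\n",
        "\n",
        "bytes_per_param = {\n",
        "    'float32': 4,\n",
        "    'float16': 2,\n",
        "    '8-bit': 1,\n",
        "    '4-bit': 0.5,\n",
        "    'full fine-tuning with AdamW (rule of thumb)': 18,\n",
        "}\n",
        "\n",
        "memory_gb = {name: approx_num_params * nbytes / 1e9 for name, nbytes in bytes_per_param.items()}\n",
        "for name, gb in memory_gb.items():\n",
        "    print(f'{name}: {gb:.1f} GB')  # 28.0, 14.0, 7.0, 3.5 and 126.0 GB"
      ]
    },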
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "referenced_widgets": [
            "0b428184283c4dcd8f2fa10ff3adbf04"
          ]
        },
        "id": "DQ0nTqbVlvD2",
        "outputId": "044885b1-6d44-46bb-e2c5-16038a34fdb2"
      },
      "outputs": [],
      "source": [
        "## Load model\n",
        "# QLoRA: model uses 4-bit quantization, which helps in reducing memory usage while maintaining performance.\n",
        "\n",
        "\n",
        "bnb_config = BitsAndBytesConfig(\n",
        "    load_in_4bit=True,\n",
        "    bnb_4bit_quant_type=\"nf4\",\n",
        "    bnb_4bit_compute_dtype=torch.float16,\n",
        ")\n",
        " \n",
        "model = VideoLlavaForConditionalGeneration.from_pretrained(\n",
        "    MODEL_ID,\n",
        "    torch_dtype=torch.float16,\n",
        "    quantization_config=bnb_config,\n",
        "    device_map=\"auto\",\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "aNtOGpvplvD2"
      },
      "source": [
        "## Apply PEFT\n",
        "After loading the base model, we're going to add LoRa adapter layers. We're going to only train these adapter layers (the base model is kept frozen).\n",
        "\n",
        "The difference here with other models are the layers at which we're going to add adapters (in PEFT this is called target_modules). This typically depends a bit on the model.\n",
        "\n",
        "We define a function that finds all linear layers (`nn.Linear`) in the model, excluding those in the multimodal projector and the vision encoder. These are the layers to which we apply LoRA, which means that we mostly adapt the language model part of Video-LLaVa for our use case. Note that PEFT matches each entry of a `target_modules` list against the end of a module's name. As a result, vision-encoder layers that share a name with a language-model layer (such as `q_proj`) may still receive adapters."
      ]
    },
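    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To build some intuition for what an adapter adds, here is a minimal LoRA layer in plain PyTorch. It is a sketch of the general technique, not PEFT's implementation. The frozen base weight `W` is left untouched, and the output becomes `W x + (alpha / r) * B A x`. `A` is initialized randomly and `B` with zeros, so training starts from exactly the base model's behavior. The 4096 x 4096 layer size is illustrative (it matches the hidden size of 7B LLaMA-style models). With `r=8`, the adapter trains 8 x (4096 + 4096) = 65,536 parameters instead of the roughly 16.8 million in the frozen weight."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import torch\n",
        "import torch.nn as nn\n",
        "\n",
        "\n",
        "class LoRALinearSketch(nn.Module):\n",
        "    \"\"\"Minimal LoRA wrapper around a frozen nn.Linear (illustration only).\"\"\"\n",
        "\n",
        "    def __init__(self, base: nn.Linear, r: int = 8, alpha: int = 8):\n",
        "        super().__init__()\n",
        "        self.base = base\n",
        "        for p in self.base.parameters():\n",
        "            p.requires_grad = False  # the pre-trained weights stay frozen\n",
        "        self.lora_A = nn.Parameter(torch.randn(r, base.in_features) / r)  # down-projection\n",
        "        self.lora_B = nn.Parameter(torch.zeros(base.out_features, r))  # up-projection, zero init\n",
        "        self.scaling = alpha / r\n",
        "\n",
        "    def forward(self, x):\n",
        "        return self.base(x) + self.scaling * (x @ self.lora_A.T @ self.lora_B.T)\n",
        "\n",
        "\n",
        "demo_base = nn.Linear(4096, 4096, bias=False)\n",
        "demo_lora = LoRALinearSketch(demo_base, r=8, alpha=8)\n",
        "demo_trainable = sum(p.numel() for p in demo_lora.parameters() if p.requires_grad)\n",
        "demo_frozen = sum(p.numel() for p in demo_base.parameters())\n",
        "print(demo_trainable, demo_frozen)  # prints 65536 16777216\n",
        "\n",
        "demo_x = torch.randn(2, 4096)\n",
        "print(torch.allclose(demo_lora(demo_x), demo_base(demo_x)))  # prints True, because B starts at zero"
      ]
    },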
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "MYDW50LslvD2"
      },
      "outputs": [],
      "source": [
        "def find_all_linear_names(model):\n",
        "    cls = torch.nn.Linear\n",
        "    lora_module_names = set()\n",
        "    multimodal_keywords = ['multi_modal_projector', 'vision_model']\n",
        "    for name, module in model.named_modules():\n",
        "        if any(mm_keyword in name for mm_keyword in multimodal_keywords):\n",
        "            continue\n",
        "        if isinstance(module, cls):\n",
        "            names = name.split('.')\n",
        "            lora_module_names.add(names[0] if len(names) == 1 else names[-1])\n",
        "\n",
        "    if 'lm_head' in lora_module_names: # needed for 16-bit\n",
        "        lora_module_names.remove('lm_head')\n",
        "    return list(lora_module_names)\n",
        "\n",
        "\n",
        "lora_config = LoraConfig(\n",
        "    r=8,\n",
        "    lora_alpha=8,\n",
        "    lora_dropout=0.1,\n",
        "    target_modules=find_all_linear_names(model),\n",
        "    init_lora_weights=\"gaussian\",\n",
        ")\n",
        "\n",
        "model = prepare_model_for_kbit_training(model)\n",
        "model = get_peft_model(model, lora_config)"
      ]
    },
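    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "PEFT treats each entry of a `target_modules` list as a module-name suffix, so a leaf name such as `q_proj` can also match the attention projections inside the vision encoder. If you want to restrict the adapters strictly to the language model, PEFT also accepts a single regular-expression string that has to match the full module name. The sketch below builds such a pattern and compares both matching rules on a few illustrative module names. The exact prefixes depend on your Transformers version, so print the names from `model.named_modules()` to confirm them before relying on the pattern."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import re\n",
        "\n",
        "# Leaf names of the linear layers in a LLaMA-style language model (illustrative)\n",
        "lm_leaf_names = ['q_proj', 'k_proj', 'v_proj', 'o_proj', 'gate_proj', 'up_proj', 'down_proj']\n",
        "lm_only_pattern = r'.*language_model.*\\.(' + '|'.join(lm_leaf_names) + r')'\n",
        "\n",
        "# Hypothetical module names; check model.named_modules() for the real ones\n",
        "demo_module_names = [\n",
        "    'language_model.model.layers.0.self_attn.q_proj',\n",
        "    'video_tower.vision_model.encoder.layers.0.self_attn.q_proj',\n",
        "    'multi_modal_projector.linear_1',\n",
        "]\n",
        "for module_name in demo_module_names:\n",
        "    # List rule: exact name or name ending with '.' + entry\n",
        "    suffix_match = any(module_name == leaf or module_name.endswith('.' + leaf) for leaf in lm_leaf_names)\n",
        "    # String rule: the regular expression must match the whole name\n",
        "    regex_match = re.fullmatch(lm_only_pattern, module_name) is not None\n",
        "    print(f'{module_name}: list match={suffix_match}, regex match={regex_match}')\n",
        "\n",
        "# The pattern could then be passed as LoraConfig(target_modules=lm_only_pattern, ...)"
      ]
    },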
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "zuv04h-DlvD2"
      },
      "source": [
        "## Define PyTorch Lightning Module for Video-LLaVA\n",
        "To streamline the training and evaluation of the Video-LLaVA model, we use [LightningModule](https://lightning.ai/docs/pytorch/stable/common/lightning_module.html), which abstracts away much of the boilerplate code and provides a structured framework for model training. In this section, we define the VideoLlavaModelPLModule, a custom PyTorch Lightning module that encapsulates the model, training loop, validation loop, and optimizer configuration.\n",
        "\n",
        "### VideoLlavaModelPLModule Class\n",
        "\n",
        "The VideoLlavaModelPLModule class inherits from LightningModule and includes methods for training, validation, and optimizer configuration. This setup ensures a clean and efficient training process.\n",
        "\n",
        "Basically, PyTorch Lightning will take care of all device placements (.to(device)) for us, as well as the backward pass, putting the model in training mode, etc.\n",
        "\n",
        "Notice the difference between a training step and an evaluation step:\n",
        "\n",
        "- a training step only consists of a forward pass, in which we compute the cross-entropy loss between the model's next token predictions and the ground truth (in parallel for all tokens, this technique is known as \"teacher forcing\"). The backward pass is handled by PyTorch Lightning.\n",
        "- an evaluation step consists of making the model autoregressively complete the prompt using the generate() method. After that, we compute an evaluation metric between the predicted sequences and the ground truth ones. This allows us to see how the model is improving over the course of training. The metric we use here is accuracy of answering the question.\n",
        "\n",
        "Besides that, we define the optimizer (AdamW is a good default choice) and the data loaders, which use the collate functions defined above to batch together items of the PyTorch datasets. Note that AdamW is a fairly memory-hungry optimizer, since it keeps two extra states per trainable parameter, but as we're training with QLoRA we only need to store optimizer states for the adapter layers. For full fine-tuning, one could consider more memory-friendly optimizers such as 8-bit Adam."
      ]
    },
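    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To make teacher forcing concrete, here is a toy, pure-Python illustration that is not part of the training code. It assumes, as is typical for Hugging Face causal language models, that the model shifts the labels internally so that the output at position t is scored against the token at position t + 1, and that positions labelled -100 (the default `ignore_index` of PyTorch's cross-entropy loss) are skipped. Whether the prompt tokens are masked with -100 depends on the collate function defined earlier; the token IDs below are made up.\n",
        "\n",
        "```python\n",
        "# Toy illustration of teacher forcing; the token IDs are made up.\n",
        "IGNORE_INDEX = -100  # default ignore_index of torch.nn.CrossEntropyLoss\n",
        "\n",
        "def next_token_targets(labels):\n",
        "    # For every position t, the target is the ground-truth token at t + 1.\n",
        "    # All targets are known up front, so every position can be scored in parallel.\n",
        "    pairs = []\n",
        "    for t in range(len(labels) - 1):\n",
        "        target = labels[t + 1]\n",
        "        if target != IGNORE_INDEX:  # masked positions contribute nothing to the loss\n",
        "            pairs.append((t, target))\n",
        "    return pairs\n",
        "\n",
        "# 3 masked prompt tokens followed by a 2-token answer\n",
        "labels = [IGNORE_INDEX, IGNORE_INDEX, IGNORE_INDEX, 42, 7]\n",
        "print(next_token_targets(labels))  # [(2, 42), (3, 7)]\n",
        "```\n",
        "\n",
        "Only the answer tokens produce a loss term here, and each one is predicted from the ground-truth tokens before it rather than from the model's own earlier predictions."
      ]
    },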
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "config = {\"max_epochs\": 5,\n",
        "          \"val_check_interval\": 0.2, # validate every 20% of a training epoch\n",
        "          \"check_val_every_n_epoch\": 1,\n",
        "          \"gradient_clip_val\": 1.0,\n",
        "          \"accumulate_grad_batches\": 8, # effective batch size per device = batch_size * 8\n",
        "          \"lr\": 1e-3,\n",
        "          \"batch_size\": 1,\n",
        "          \"num_nodes\": 1,\n",
        "          \"warmup_steps\": 50, # not read by the module below; only relevant if you add a learning rate scheduler\n",
        "}"
      ]
    },
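    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As a quick sanity check on these values, the sketch below (illustrative arithmetic, not part of the training code) shows how they combine on the single GPU used in this notebook. The training-set size of 1,000 examples is hypothetical, and the step count assumes that a final, incomplete accumulation window still triggers an optimizer step.\n",
        "\n",
        "```python\n",
        "import math\n",
        "\n",
        "def effective_batch_size(batch_size, accumulate_grad_batches, devices=1, num_nodes=1):\n",
        "    # number of examples that contribute to one optimizer step\n",
        "    return batch_size * accumulate_grad_batches * devices * num_nodes\n",
        "\n",
        "def optimizer_steps_per_epoch(num_train_examples, batch_size, accumulate_grad_batches):\n",
        "    batches = math.ceil(num_train_examples / batch_size)\n",
        "    return math.ceil(batches / accumulate_grad_batches)\n",
        "\n",
        "def batches_between_validations(num_train_examples, batch_size, val_check_interval):\n",
        "    # a float val_check_interval is interpreted as a fraction of one training epoch\n",
        "    batches = math.ceil(num_train_examples / batch_size)\n",
        "    return max(1, int(batches * val_check_interval))\n",
        "\n",
        "n = 1000  # hypothetical number of training examples\n",
        "print(effective_batch_size(1, 8))              # 8\n",
        "print(optimizer_steps_per_epoch(n, 1, 8))      # 125\n",
        "print(batches_between_validations(n, 1, 0.2))  # 200\n",
        "```\n",
        "\n",
        "With these numbers, the 50 warmup steps in the config would cover the first 400 training examples (50 steps of 8 examples each)."
      ]
    },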
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "WvNdvJ1ylvD2"
      },
      "outputs": [],
      "source": [
        "class VideoLlavaModelPLModule(L.LightningModule):\n",
        "    def __init__(self, config, processor, model):\n",
        "        super().__init__()\n",
        "        self.config = config\n",
        "        self.processor = processor\n",
        "        self.model = model\n",
        "\n",
        "        self.batch_size = config.get(\"batch_size\")\n",
        "\n",
        "    def training_step(self, batch, batch_idx):\n",
        "\n",
        "        input_ids, attention_mask, pixel_values_videos, labels = batch\n",
        "\n",
        "        outputs = self.model(\n",
        "            input_ids=input_ids,\n",
        "            attention_mask=attention_mask,\n",
        "            pixel_values_videos=pixel_values_videos,\n",
        "            labels=labels\n",
        "        )\n",
        "        loss = outputs.loss\n",
        "\n",
        "        self.log(\"train_loss\", loss)\n",
        "\n",
        "        return loss\n",
        "\n",
        "    def validation_step(self, batch, batch_idx, dataset_idx=0):\n",
        "        with torch.no_grad():\n",
        "            MAX_NEW_TOKENS = 256\n",
        "            input_ids, attention_mask, pixel_values_videos, answers = batch\n",
        "\n",
        "            # autoregressively generate token IDs\n",
        "            generated_ids = self.model.generate(\n",
        "                input_ids=input_ids,\n",
        "                attention_mask=attention_mask,\n",
        "                pixel_values_videos=pixel_values_videos,\n",
        "                max_new_tokens=MAX_NEW_TOKENS,\n",
        "                do_sample=False,\n",
        "            )\n",
        "            # turn them back into text, chopping off the prompt\n",
        "            predictions = self.processor.batch_decode(generated_ids[:, input_ids.size(1):], skip_special_tokens=True)\n",
        "\n",
        "            # exact-match accuracy, ignoring case and surrounding whitespace\n",
        "            correct = 0\n",
        "            for pred, answer in zip(predictions, answers):\n",
        "                correct += (pred.strip().lower() == answer.strip().lower())\n",
        "\n",
        "            # pass batch_size explicitly so Lightning does not have to infer it from the batch\n",
        "            self.log(\"val_accuracy\", float(correct) / len(answers), batch_size=len(answers))\n",
        "\n",
        "            return correct\n",
        "\n",
        "    def configure_optimizers(self):\n",
        "        # you could also add a learning rate scheduler if you want\n",
        "        optimizer = torch.optim.AdamW(self.parameters(), lr=self.config.get(\"lr\"))\n",
        "\n",
        "        return optimizer\n",
        "\n",
        "    def train_dataloader(self):\n",
        "        return DataLoader(train_dataset, collate_fn=train_collate_fn, batch_size=self.batch_size, shuffle=True, num_workers=3)\n",
        "\n",
        "    def val_dataloader(self):\n",
        "        return DataLoader(eval_dataset, collate_fn=eval_collate_fn, batch_size=self.batch_size, shuffle=False, num_workers=3)"
      ]
    },
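    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The `configure_optimizers` method above uses a constant learning rate, so `warmup_steps` from the config has no effect. If you want a warmup, one option is PyTorch's `torch.optim.lr_scheduler.LambdaLR`, which multiplies the base learning rate by a user-supplied function of the step count. The sketch below shows such a function for a linear warmup (the name `warmup_lambda` is our own, not a library API); the commented lines show one way it could be wired into Lightning.\n",
        "\n",
        "```python\n",
        "def warmup_lambda(warmup_steps):\n",
        "    # Returns a multiplier for the base learning rate: it grows linearly\n",
        "    # from 1 / warmup_steps to 1.0 over the first warmup_steps optimizer steps,\n",
        "    # then stays at 1.0 (constant learning rate afterwards).\n",
        "    def lr_factor(step):\n",
        "        if step < warmup_steps:\n",
        "            return (step + 1) / warmup_steps\n",
        "        return 1.0\n",
        "    return lr_factor\n",
        "\n",
        "factor = warmup_lambda(50)\n",
        "print([factor(s) for s in (0, 24, 49, 50, 500)])  # [0.02, 0.5, 1.0, 1.0, 1.0]\n",
        "\n",
        "# Possible use inside VideoLlavaModelPLModule.configure_optimizers:\n",
        "#     optimizer = torch.optim.AdamW(self.parameters(), lr=self.config.get('lr'))\n",
        "#     scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, warmup_lambda(self.config.get('warmup_steps')))\n",
        "#     return {'optimizer': optimizer, 'lr_scheduler': {'scheduler': scheduler, 'interval': 'step'}}\n",
        "```\n",
        "\n",
        "Setting `'interval': 'step'` asks Lightning to advance the scheduler after every optimizer step rather than once per epoch, which is what a warmup measured in steps needs."
      ]
    },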
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "u6iWFbnQlvD2"
      },
      "source": [
        "Let's instantiate it based on the config dictionary defined above, which holds all training hyperparameters. We also create an early-stopping callback that monitors the validation accuracy.\n",
        "\n",
        "The batch size was determined based on the compute available.\n",
        "\n",
        "Feel free to play around with the hyperparameters; I use reasonable defaults here: 5 epochs, a learning rate of 1e-3, gradient accumulation over 8 batches, gradient clipping at 1.0, and mixed-precision training (set in the Trainer below, which saves memory). One could extend this with techniques such as gradient checkpointing.\n",
        "\n",
        "I recommend [this guide](https://huggingface.co/docs/transformers/v4.20.1/en/perf_train_gpu_one), which covers tips and tricks for efficient fine-tuning on a single GPU, including consumer hardware."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "K2I7vJaDlvD2"
      },
      "outputs": [],
      "source": [
        "model_module = VideoLlavaModelPLModule(config, processor, model)\n",
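        "# accuracy should increase, so stop when it has not improved (mode=\"max\") for 3 consecutive validation runs\n",
        "early_stop_callback = EarlyStopping(monitor=\"val_accuracy\", patience=3, verbose=False, mode=\"max\")"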
      ]
    },
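    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Why `mode` matters: `EarlyStopping` counts a validation run as an improvement only if the monitored value moves in the direction given by `mode`. Accuracy should go up, so it needs `mode='max'`; with `mode='min'`, a rising accuracy would be counted as no improvement and training would stop for the wrong reason. The simplified re-implementation below (ours, not Lightning's code; it ignores options such as `min_delta`) makes this visible on made-up accuracy values.\n",
        "\n",
        "```python\n",
        "def stop_index(scores, patience, mode='max'):\n",
        "    # Simplified patience logic for illustration; returns the index of the\n",
        "    # validation run after which training would stop, or None.\n",
        "    best = None\n",
        "    bad_runs = 0\n",
        "    for i, score in enumerate(scores):\n",
        "        if mode == 'max':\n",
        "            improved = best is None or score > best\n",
        "        else:\n",
        "            improved = best is None or score < best\n",
        "        if improved:\n",
        "            best = score\n",
        "            bad_runs = 0\n",
        "        else:\n",
        "            bad_runs += 1\n",
        "            if bad_runs >= patience:\n",
        "                return i\n",
        "    return None\n",
        "\n",
        "# made-up val_accuracy values from six consecutive validation runs\n",
        "accuracies = [0.2, 0.4, 0.6, 0.6, 0.8, 0.8]\n",
        "print(stop_index(accuracies, patience=3, mode='max'))  # None: accuracy keeps improving\n",
        "print(stop_index(accuracies, patience=3, mode='min'))  # 3: stops after the 4th validation run\n",
        "```\n",
        "\n",
        "Because validation runs every 20% of an epoch here (`val_check_interval=0.2`), `patience=3` corresponds to roughly 60% of an epoch without improvement."
      ]
    },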
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "I-EncDAllvD2"
      },
      "source": [
        "## Define callbacks\n",
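        "Optionally, Lightning allows you to define so-called [callbacks](https://lightning.ai/docs/pytorch/stable/extensions/callbacks.html), which are arbitrary pieces of code that can be executed during training.\n",
        "We will use one to save a checkpoint of the model at the end of every epoch and once more when training finishes. Since the model is a PEFT model, `save_pretrained` stores only the LoRA adapter weights, not the full base model."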
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "from datetime import datetime\n",
        "\n",
        "from lightning.pytorch.callbacks import Callback\n",
        "\n",
        "\n",
        "class SaveModelCallback(Callback):\n",
        "    \"\"\"Saves the (LoRA adapter) weights after every training epoch and at the end of training.\"\"\"\n",
        "\n",
        "    def on_train_epoch_end(self, trainer, pl_module):\n",
        "        current_time = datetime.now().strftime(\"%Y-%m-%d_%H-%M-%S\")\n",
        "        output_dir = f\"{LOCAL_PATH}/weights/{current_time}-checkpoint-{trainer.current_epoch}\"\n",
        "        pl_module.model.save_pretrained(output_dir)\n",
        "        print(f\"Model checkpoint saved at epoch {trainer.current_epoch} to {output_dir}\")\n",
        "\n",
        "    def on_train_end(self, trainer, pl_module):\n",
        "        current_time = datetime.now().strftime(\"%Y-%m-%d_%H-%M-%S\")\n",
        "        output_dir = f\"{LOCAL_PATH}/weights/{current_time}-checkpoint-final\"\n",
        "        pl_module.model.save_pretrained(output_dir)\n",
        "        print(f\"Model checkpoint saved at the end of training to {output_dir}\")"
      ]
    },
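    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Since each checkpoint folder name starts with a zero-padded, year-first `%Y-%m-%d_%H-%M-%S` timestamp, sorting the names as plain strings also sorts them chronologically. The hypothetical helper below (not part of the original notebook) uses this to find the most recent checkpoint under `{LOCAL_PATH}/weights`.\n",
        "\n",
        "```python\n",
        "import os\n",
        "\n",
        "def latest_checkpoint(weights_dir):\n",
        "    # Keep only the folders written by SaveModelCallback.\n",
        "    names = [\n",
        "        name for name in os.listdir(weights_dir)\n",
        "        if '-checkpoint-' in name and os.path.isdir(os.path.join(weights_dir, name))\n",
        "    ]\n",
        "    if not names:\n",
        "        return None\n",
        "    # Timestamps are zero-padded and year-first, so the largest string is the newest folder.\n",
        "    return os.path.join(weights_dir, max(names))\n",
        "\n",
        "# Example (LOCAL_PATH is the folder used by SaveModelCallback above):\n",
        "# print(latest_checkpoint(f'{LOCAL_PATH}/weights'))\n",
        "```\n",
        "\n",
        "The returned folder contains the saved adapter, which could then be loaded back onto the base model with `PeftModel.from_pretrained(base_model, path)` from the `peft` library."
      ]
    },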
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "S-ylKSrMlvD2"
      },
      "source": [
        "## Train!\n",
        "Alright, we're set to start training!\n",
        "\n",
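        "A few of the Trainer flags below are worth pointing out: `limit_val_batches=5` restricts each validation run to 5 batches to keep it fast, `num_sanity_val_steps=1` runs one validation batch before training starts to catch bugs early, and `precision=\"16-mixed\"` enables mixed-precision training. The Trainer class supports many more flags; see the [docs](https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.trainer.trainer.Trainer.html#lightning.pytorch.trainer.trainer.Trainer)."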
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "liZhW7NhlvD3",
        "outputId": "16ab8773-d40f-48f7-e485-b3960c2bea47"
      },
      "outputs": [],
      "source": [
        "trainer = L.Trainer(\n",
        "    accelerator=\"gpu\",\n",
        "    devices=1,\n",
        "    max_epochs=config.get(\"max_epochs\"),\n",
        "    accumulate_grad_batches=config.get(\"accumulate_grad_batches\"),\n",
        "    gradient_clip_val=config.get(\"gradient_clip_val\"),\n",
        "    precision=\"16-mixed\",\n",
        "    limit_val_batches=5,  # only evaluate 5 batches per validation run\n",
        "    num_sanity_val_steps=1,\n",
        "    callbacks=[early_stop_callback, SaveModelCallback()],\n",
        "    val_check_interval=config.get(\"val_check_interval\"),\n",
        "    # fast_dev_run=True,  # uncomment to run a single training and validation batch as a quick smoke test\n",
        ")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "referenced_widgets": [
            "",
            "9c99e8a134844fc6bc734868514d92fc"
          ]
        },
        "id": "X1XGM6QflvD3",
        "outputId": "113a2fc3-d7b3-4140-b614-1b46e14b411e",
        "scrolled": true
      },
      "outputs": [],
      "source": [
        "trainer.fit(model_module)"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3 (ipykernel)",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.12.4"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
