{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Latex-OCR SFT\n",
    "\n",
    "Here is a demonstration of using python to perform Latex-OCR SFT of Qwen2-VL-2B-Instruct. Through this tutorial, you can quickly understand some details of swift sft, which will be of great help in customizing ms-swift for you~\n",
    "\n",
    "Are you ready? Let's begin the journey..."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "vscode": {
     "languageId": "shellscript"
    }
   },
   "outputs": [],
   "source": [
    "# # install ms-swift\n",
    "# pip install ms-swift -U"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# import some libraries\n",
    "import os\n",
    "os.environ['CUDA_VISIBLE_DEVICES'] = '0'\n",
    "\n",
    "from swift.llm import (\n",
    "    get_model_tokenizer, load_dataset, get_template, EncodePreprocessor, get_model_arch,\n",
    "    get_multimodal_target_regex, LazyLLMDataset\n",
    ")\n",
    "from swift.utils import get_logger, get_model_parameter_info, plot_images, seed_everything\n",
    "from swift.tuners import Swift, LoraConfig\n",
    "from swift.trainers import Seq2SeqTrainer, Seq2SeqTrainingArguments\n",
    "from functools import partial\n",
    "\n",
    "logger = get_logger()\n",
    "seed_everything(42)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hyperparameters for training\n",
    "# model\n",
    "model_id_or_path = 'Qwen/Qwen2-VL-2B-Instruct'\n",
    "system = None  # Using the default system defined in the template.\n",
    "output_dir = 'output'\n",
    "\n",
    "# dataset\n",
    "dataset = ['AI-ModelScope/LaTeX_OCR#20000']  # dataset_id or dataset_path. Sampling 20000 data points\n",
    "data_seed = 42\n",
    "max_length = 2048\n",
    "split_dataset_ratio = 0.01  # Split validation set\n",
    "num_proc = 4  # The number of processes for data loading.\n",
    "\n",
    "# lora\n",
    "lora_rank = 8\n",
    "lora_alpha = 32\n",
    "freeze_llm = False\n",
    "freeze_vit = True\n",
    "freeze_aligner = True\n",
    "\n",
    "# training_args\n",
    "training_args = Seq2SeqTrainingArguments(\n",
    "    output_dir=output_dir,\n",
    "    learning_rate=1e-4,\n",
    "    per_device_train_batch_size=1,\n",
    "    per_device_eval_batch_size=1,\n",
    "    gradient_checkpointing=True,\n",
    "    weight_decay=0.1,\n",
    "    lr_scheduler_type='cosine',\n",
    "    warmup_ratio=0.05,\n",
    "    report_to=['tensorboard'],\n",
    "    logging_first_step=True,\n",
    "    save_strategy='steps',\n",
    "    save_steps=50,\n",
    "    eval_strategy='steps',\n",
    "    eval_steps=50,\n",
    "    gradient_accumulation_steps=16,\n",
    "    # To observe the training results more quickly, this is set to 1 here. \n",
    "    # Under normal circumstances, a larger number should be used.\n",
    "    num_train_epochs=1,\n",
    "    metric_for_best_model='loss',\n",
    "    save_total_limit=5,\n",
    "    logging_steps=5,\n",
    "    dataloader_num_workers=4,\n",
    "    data_seed=data_seed,\n",
    "    remove_unused_columns=False,\n",
    ")\n",
    "\n",
    "output_dir = os.path.abspath(os.path.expanduser(output_dir))\n",
    "logger.info(f'output_dir: {output_dir}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Obtain the model and template\n",
    "model, processor = get_model_tokenizer(model_id_or_path)\n",
    "logger.info(f'model_info: {model.model_info}')\n",
    "template = get_template(model.model_meta.template, processor, default_system=system, max_length=max_length)\n",
    "template.set_mode('train')\n",
    "\n",
    "# Get target_modules and add trainable LoRA modules to the model.\n",
    "model_arch = get_model_arch(model.model_meta.model_arch)\n",
    "target_modules = get_multimodal_target_regex(model_arch, freeze_llm=freeze_llm, freeze_vit=freeze_vit, \n",
    "                            freeze_aligner=freeze_aligner)\n",
    "lora_config = LoraConfig(task_type='CAUSAL_LM', r=lora_rank, lora_alpha=lora_alpha,\n",
    "                         target_modules=target_modules)\n",
    "model = Swift.prepare_model(model, lora_config)\n",
    "logger.info(f'lora_config: {lora_config}')\n",
    "\n",
    "# Print model structure and trainable parameters.\n",
    "logger.info(f'model: {model}')\n",
    "model_parameter_info = get_model_parameter_info(model)\n",
    "logger.info(f'model_parameter_info: {model_parameter_info}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Download and load the dataset, split it into a training set and a validation set,\n",
    "# and encode the text data into tokens.\n",
    "train_dataset, val_dataset = load_dataset(dataset, split_dataset_ratio=split_dataset_ratio, num_proc=num_proc,\n",
    "                                          seed=data_seed)\n",
    "\n",
    "logger.info(f'train_dataset: {train_dataset}')\n",
    "logger.info(f'val_dataset: {val_dataset}')\n",
    "logger.info(f'train_dataset[0]: {train_dataset[0]}')\n",
    "\n",
    "train_dataset = LazyLLMDataset(train_dataset, template.encode, random_state=data_seed)\n",
    "val_dataset = LazyLLMDataset(val_dataset, template.encode, random_state=data_seed)\n",
    "data = train_dataset[0]\n",
    "logger.info(f'encoded_train_dataset[0]: {data}')\n",
    "\n",
    "template.print_inputs(data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Get the trainer and start the training.\n",
    "model.enable_input_require_grads()  # Compatible with gradient checkpointing\n",
    "trainer = Seq2SeqTrainer(\n",
    "    model=model,\n",
    "    args=training_args,\n",
    "    data_collator=template.data_collator,\n",
    "    train_dataset=train_dataset,\n",
    "    eval_dataset=val_dataset,\n",
    "    template=template,\n",
    ")\n",
    "trainer.train()\n",
    "\n",
    "last_model_checkpoint = trainer.state.last_model_checkpoint\n",
    "logger.info(f'last_model_checkpoint: {last_model_checkpoint}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Visualize the training loss.\n",
    "# You can also use the TensorBoard visualization interface during training by entering\n",
    "# `tensorboard --logdir '{output_dir}/runs'` at the command line.\n",
    "images_dir = os.path.join(output_dir, 'images')\n",
    "logger.info(f'images_dir: {images_dir}')\n",
    "plot_images(images_dir, training_args.logging_dir, ['train/loss'], 0.9)  # save images\n",
    "\n",
    "# Read and display the image.\n",
    "# The light yellow line represents the actual loss value,\n",
    "# while the yellow line represents the loss value smoothed with a smoothing factor of 0.9.\n",
    "from IPython.display import display\n",
    "from PIL import Image\n",
    "image = Image.open(os.path.join(images_dir, 'train_loss.png'))\n",
    "display(image)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "py310",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.15"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
