# Copyright 2024 OpenAccess AI Collective and the LlamaFactory team.
#
# This code is inspired by the OpenAccess AI Collective's axolotl library.
# https://github.com/OpenAccess-AI-Collective/axolotl/blob/main/src/axolotl/monkeypatch/utils.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Sequence

import torch
import torch.nn.functional as F
from transformers import DataCollatorForSeq2Seq

from ..extras.constants import IGNORE_INDEX, IMAGE_PLACEHOLDER
from ..extras.packages import is_pillow_available


if is_pillow_available():
    from PIL import Image


if TYPE_CHECKING:
    from transformers import ProcessorMixin

    from .template import Template


def prepare_4d_attention_mask(attention_mask_with_indices: "torch.Tensor", dtype: "torch.dtype") -> "torch.Tensor":
    r"""
    Expands the attention mask with indices from (batch_size, seq_len) to (batch_size, 1, seq_len, seq_len),
    while handling packed sequences and transforming the mask to lower triangular form to prevent future peeking.

    A query token may attend to a key token only if both carry the same non-zero sequence index and the key
    does not come after the query. Index 0 marks padding, which attends to nothing and is attended by nothing.

    e.g.
    ```python
    # input
    [[1, 1, 2, 2, 2, 0]]
    # output
    [
        [
            [
                [o, x, x, x, x, x],
                [o, o, x, x, x, x],
                [x, x, o, x, x, x],
                [x, x, o, o, x, x],
                [x, x, o, o, o, x],
                [x, x, x, x, x, x],
            ]
        ]
    ]
    ```
    where `o` equals to `0.0`, `x` equals to `min_dtype`.

    Args:
        attention_mask_with_indices: 2-D tensor of shape (batch_size, seq_len) holding a packed-sequence index
            per token (0 for padding).
        dtype: floating-point dtype of the returned additive mask.

    Returns:
        Additive mask of shape (batch_size, 1, seq_len, seq_len) in `dtype`, on the input's device, with 0.0 for
        allowed positions and `torch.finfo(dtype).min` for masked ones. Rows index queries, columns index keys.
    """
    bsz, seq_len = attention_mask_with_indices.size()
    device = attention_mask_with_indices.device
    min_dtype = torch.finfo(dtype).min
    # expanded_mask[b, 0, i, j] == index of key token j (broadcast along the query axis).
    expanded_mask = attention_mask_with_indices[:, None, None, :].expand(bsz, 1, seq_len, seq_len)
    # Create a binary mask from the original mask where zeros remain zeros and all other values are set to one
    padding_mask = torch.where(expanded_mask != 0, 1, 0)
    # Create a block-diagonal mask: query i and key j must belong to the same packed sequence.
    attention_mask_4d = torch.eq(expanded_mask, expanded_mask.transpose(-1, -2)).int() * padding_mask
    # Use the lower triangular mask to zero out the upper triangular part (keys after the query).
    # Built on the input's device so the in-place multiply also works for non-CPU inputs.
    causal_mask = torch.tril(torch.ones((seq_len, seq_len), dtype=torch.long, device=device))
    attention_mask_4d *= causal_mask
    # Invert the attention mask: allowed -> 0.0, disallowed -> most negative finite value of `dtype`.
    attention_mask_4d = torch.where(attention_mask_4d != 0, torch.tensor(0, dtype=dtype, device=device), min_dtype)
    return attention_mask_4d


@dataclass
class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
    r"""
    Data collator that supports VLMs.

    Features should contain input_ids, attention_mask, labels, and optionally contain images and videos.
    The multimodal tensors are produced by `template.mm_plugin.get_mm_inputs` and merged into the padded batch.
    """

    template: Optional["Template"] = None
    processor: Optional["ProcessorMixin"] = None

    def __post_init__(self):
        if self.template is None:
            raise ValueError("Template is required for MultiModalDataCollator.")

    def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "torch.Tensor"]:
        r"""
        Pads text features and attaches multimodal inputs.

        Note: the "images" / "videos" keys are popped from each feature dict, and features[0] may be extended
        with fake tokens, so the caller's dicts are mutated.

        Returns:
            The padded batch from `DataCollatorForSeq2Seq` updated with the entries of `get_mm_inputs`.
        """
        batch_images, batch_videos, batch_imglens, batch_vidlens, batch_input_ids = [], [], [], [], []
        for feature in features:
            images = feature.pop("images", None) or []
            videos = feature.pop("videos", None) or []
            batch_images.extend(images)
            batch_videos.extend(videos)
            batch_imglens.append(len(images))
            batch_vidlens.append(len(videos))
            batch_input_ids.append(feature["input_ids"])

        if self.processor is not None and sum(batch_imglens) == 0:  # avoid process hanging in zero3/fsdp case
            # Inject one blank image into sample 0. Its tokens are masked out of attention (0) and loss
            # (IGNORE_INDEX), so they should not affect training; presumably this keeps every rank running
            # the vision tower so collective ops stay in sync.
            fake_messages = [{"role": "user", "content": IMAGE_PLACEHOLDER}]
            fake_images = [Image.new("RGB", (64, 64), (255, 255, 255))]
            fake_messages = self.template.mm_plugin.process_messages(fake_messages, fake_images, [], self.processor)
            fake_input_ids = self.processor.tokenizer.encode(fake_messages[0]["content"], add_special_tokens=False)
            features[0]["input_ids"] = features[0]["input_ids"] + fake_input_ids
            features[0]["attention_mask"] = features[0]["attention_mask"] + [0] * len(fake_input_ids)
            features[0]["labels"] = features[0]["labels"] + [IGNORE_INDEX] * len(fake_input_ids)
            batch_images = fake_images
            # Keep per-sample image counts consistent with `batch_images`: the fake image belongs to sample 0.
            batch_imglens[0] = 1
            batch_input_ids[0] = features[0]["input_ids"]

        mm_inputs = self.template.mm_plugin.get_mm_inputs(
            batch_images, batch_videos, batch_imglens, batch_vidlens, batch_input_ids, self.processor
        )
        if "token_type_ids" in mm_inputs:
            # Attach per-sample token_type_ids before padding so the parent collator pads them with input_ids.
            token_type_ids = mm_inputs.pop("token_type_ids")
            for i, feature in enumerate(features):
                feature["token_type_ids"] = token_type_ids[i]

        batch: Dict[str, "torch.Tensor"] = super().__call__(features)

        if "cross_attention_mask" in mm_inputs:  # for mllama inputs when pad_to_multiple_of is enabled
            # Pad the third-to-last dim (dim 1 for a 4-D mask -- TODO confirm rank) to the padded text length.
            cross_attention_mask = mm_inputs.pop("cross_attention_mask")
            seq_len = batch["input_ids"].size(1)
            orig_len = cross_attention_mask.size(1)
            mm_inputs["cross_attention_mask"] = F.pad(cross_attention_mask, (0, 0, 0, 0, 0, seq_len - orig_len))

        batch.update(mm_inputs)
        if isinstance(batch.get("pixel_values"), list):  # for pixtral inputs
            batch = batch.data  # use default_collate() instead of BatchEncoding.to()

        return batch


@dataclass
class SFTDataCollatorWith4DAttentionMask(MultiModalDataCollatorForSeq2Seq):
    r"""
    Data collator for 4d attention mask.

    When `block_diag_attn` is set and the attention implementation is not flash_attention_2, the 2-D
    packed-sequence index mask is expanded into a 4-D additive block-diagonal causal mask.
    """

    block_diag_attn: bool = False
    attn_implementation: Literal["eager", "sdpa", "flash_attention_2"] = "eager"
    compute_dtype: "torch.dtype" = torch.float32

    def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "torch.Tensor"]:
        batch = super().__call__(features)
        if self.block_diag_attn and self.attn_implementation != "flash_attention_2":
            batch["attention_mask"] = prepare_4d_attention_mask(batch["attention_mask"], self.compute_dtype)

        return batch


@dataclass
class PairwiseDataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq):
    r"""
    Data collator for pairwise data.
    """

    def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "torch.Tensor"]:
        r"""
        Pads batched data to the longest sequence in the batch.

        We generate 2 * n examples where the first n examples represent chosen examples and
        the last n examples represent rejected examples.

        Each input feature must provide `chosen_*` and `rejected_*` input_ids / attention_mask / labels;
        `images` and `videos` are optional and shared by both halves of a pair.
        """
        concatenated_features = []
        for key in ("chosen", "rejected"):
            for feature in features:
                target_feature = {
                    "input_ids": feature[f"{key}_input_ids"],
                    "attention_mask": feature[f"{key}_attention_mask"],
                    "labels": feature[f"{key}_labels"],
                    # Optional: the parent collator treats a missing/None value as "no media".
                    "images": feature.get("images"),
                    "videos": feature.get("videos"),
                }
                concatenated_features.append(target_feature)

        return super().__call__(concatenated_features)


@dataclass
class KTODataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq):
    r"""
    Data collator for KTO data.

    Collates the target examples and their KL (mismatched) counterparts separately, then merges the KL
    tensors into the target batch under `kl_*` keys together with a `kto_tags` tensor.
    """

    def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "torch.Tensor"]:
        target_features = []
        kl_features = []
        kto_tags = []
        for feature in features:
            # Optional: the parent collator treats a missing/None value as "no media".
            images = feature.get("images")
            videos = feature.get("videos")
            target_feature = {
                "input_ids": feature["input_ids"],
                "attention_mask": feature["attention_mask"],
                "labels": feature["labels"],
                "images": images,
                "videos": videos,
            }
            kl_feature = {
                "input_ids": feature["kl_input_ids"],
                "attention_mask": feature["kl_attention_mask"],
                "labels": feature["kl_labels"],
                "images": images,
                "videos": videos,
            }
            target_features.append(target_feature)
            kl_features.append(kl_feature)
            kto_tags.append(feature["kto_tags"])

        batch = super().__call__(target_features)
        kl_batch = super().__call__(kl_features)
        batch["kl_input_ids"] = kl_batch["input_ids"]
        batch["kl_attention_mask"] = kl_batch["attention_mask"]
        batch["kl_labels"] = kl_batch["labels"]
        if "token_type_ids" in kl_batch:
            batch["kl_token_type_ids"] = kl_batch["token_type_ids"]

        batch["kto_tags"] = torch.tensor(kto_tags)
        return batch