# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING

from transformers.utils import is_flash_attn_2_available, is_torch_sdpa_available
from transformers.utils.versions import require_version

from ...extras import logging


if TYPE_CHECKING:
    from transformers import PretrainedConfig

    from ...hparams import ModelArguments


logger = logging.get_logger(__name__)


def _get_attn_attr_name(config: "PretrainedConfig") -> str:
    r"""Return the name of the config attribute that holds the attention implementation.

    InternLM2 is special-cased (custom model code) and uses the public
    ``attn_implementation`` attribute; every other model type uses
    ``_attn_implementation``. Shared by the configure and print functions so the
    two can never disagree about where the value lives.
    """
    if getattr(config, "model_type", None) == "internlm2":  # special case for custom models
        return "attn_implementation"

    return "_attn_implementation"


def configure_attn_implementation(
    config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool
) -> None:
    r"""Translate ``model_args.flash_attn`` into an attention implementation on ``config``.

    Mapping of ``model_args.flash_attn``:
        - ``"auto"``: return without touching ``config``.
        - ``"disabled"``: ``"eager"``.
        - ``"sdpa"``: ``"sdpa"`` if torch SDPA is available, else warn and return.
        - ``"fa2"``: ``"flash_attention_2"`` if FlashAttention-2 is available, else warn
          and return.

    For a trainable Gemma-2 model, ``model_args.flash_attn`` may first be rewritten
    in place: ``"auto"`` becomes ``"fa2"`` when FlashAttention-2 is installed, and
    ``"auto"``/``"fa2"`` become ``"disabled"`` when it is not.

    Args:
        config: model config; the chosen value is written to the attribute named by
            ``_get_attn_attr_name``.
        model_args: model arguments; ``flash_attn`` is read and may be mutated.
        is_trainable: whether the model will be trained (enables the Gemma-2 checks).

    Raises:
        NotImplementedError: if ``model_args.flash_attn`` is not one of the values above.
        Whatever ``require_version`` raises when the Gemma-2 + FlashAttention-2 version
        requirements are not met.
    """
    if getattr(config, "model_type", None) == "gemma2" and is_trainable:
        if model_args.flash_attn in ("auto", "fa2"):
            if is_flash_attn_2_available():
                require_version("transformers>=4.42.4", "To fix: pip install transformers>=4.42.4")
                require_version("flash_attn>=2.6.3", "To fix: pip install flash_attn>=2.6.3")
                if model_args.flash_attn != "fa2":
                    logger.warning_rank0(
                        "Gemma-2 should use flash attention 2, change `flash_attn` to fa2."
                    )
                    model_args.flash_attn = "fa2"
            else:
                logger.warning_rank0("FlashAttention-2 is not installed, use eager attention.")
                model_args.flash_attn = "disabled"
        elif model_args.flash_attn == "sdpa":
            # Only a warning: the user's explicit SDPA choice is still honored below.
            logger.warning_rank0(
                "Gemma-2 should use soft-capping attention, while the SDPA attention does not support it."
            )

    if model_args.flash_attn == "auto":
        return  # leave whatever the config already specifies
    elif model_args.flash_attn == "disabled":
        requested_attn_implementation = "eager"
    elif model_args.flash_attn == "sdpa":
        if not is_torch_sdpa_available():
            logger.warning_rank0("torch>=2.1.1 is required for SDPA attention.")
            return

        requested_attn_implementation = "sdpa"
    elif model_args.flash_attn == "fa2":
        if not is_flash_attn_2_available():
            logger.warning_rank0("FlashAttention-2 is not installed.")
            return

        requested_attn_implementation = "flash_attention_2"
    else:
        raise NotImplementedError(f"Unknown attention type: {model_args.flash_attn}")

    setattr(config, _get_attn_attr_name(config), requested_attn_implementation)


def print_attn_implementation(config: "PretrainedConfig") -> None:
    r"""Log (info level, rank 0) which attention implementation ``config`` carries.

    Any value other than ``"flash_attention_2"`` or ``"sdpa"`` (including a missing
    attribute) is reported as the vanilla implementation.
    """
    attn_implementation = getattr(config, _get_attn_attr_name(config), None)

    if attn_implementation == "flash_attention_2":
        logger.info_rank0("Using FlashAttention-2 for faster training and inference.")
    elif attn_implementation == "sdpa":
        logger.info_rank0("Using torch SDPA for faster training and inference.")
    else:
        logger.info_rank0("Using vanilla attention implementation.")