# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING, Any, Dict, Optional, TypedDict

import torch
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    AutoModelForVision2Seq,
    AutoProcessor,
    AutoTokenizer,
)
from trl import AutoModelForCausalLMWithValueHead

from ..extras import logging
from ..extras.misc import (
    count_parameters,
    skip_check_imports,
    try_download_model_from_other_hub,
)
from .adapter import init_adapter
from .model_utils.liger_kernel import apply_liger_kernel
from .model_utils.misc import register_autoclass
from .model_utils.mod import convert_pretrained_model_to_mod, load_mod_pretrained_model
from .model_utils.unsloth import load_unsloth_pretrained_model
from .model_utils.valuehead import load_valuehead_params
from .patcher import (
    patch_config,
    patch_model,
    patch_processor,
    patch_tokenizer,
    patch_valuehead_model,
)


if TYPE_CHECKING:
    from transformers import (
        PretrainedConfig,
        PreTrainedModel,
        PreTrainedTokenizer,
        ProcessorMixin,
    )

    from ..hparams import FinetuningArguments, ModelArguments


logger = logging.get_logger(__name__)


# Return type of `load_tokenizer`; `processor` is None when no usable processor was loaded.
class TokenizerModule(TypedDict):
    tokenizer: "PreTrainedTokenizer"
    processor: Optional["ProcessorMixin"]


def _get_init_kwargs(model_args: "ModelArguments") -> Dict[str, Any]:
    r"""
    Gets arguments to load config/tokenizer/model.

    Note: including inplace operation of model_args.

    ``model_args.model_name_or_path`` is overwritten with the value returned by
    ``try_download_model_from_other_hub`` before the kwargs are built, so every
    later ``from_pretrained`` call in this module sees the resolved location.

    Returns:
        Keyword arguments shared by the config/tokenizer/processor/model loaders.
    """
    skip_check_imports()
    model_args.model_name_or_path = try_download_model_from_other_hub(model_args)
    return {
        # NOTE(review): always enabled, so loading a hub repository may execute
        # Python code shipped with it.
        "trust_remote_code": True,
        "cache_dir": model_args.cache_dir,
        "revision": model_args.model_revision,
        "token": model_args.hf_hub_token,
    }


def load_tokenizer(model_args: "ModelArguments") -> "TokenizerModule":
    r"""
    Loads pretrained tokenizer and optionally loads processor.

    Note: including inplace operation of model_args.

    Besides resolving ``model_name_or_path`` (see ``_get_init_kwargs``), this
    may set ``model_args.resize_vocab = True`` when new special tokens are added.

    Returns:
        A ``TokenizerModule`` dict with the tokenizer and the processor
        (``None`` if no processor could be loaded).

    Raises:
        OSError: if the first tokenizer load fails with a non-``ValueError``.
    """
    init_kwargs = _get_init_kwargs(model_args)
    config = load_config(model_args)
    try:
        tokenizer = AutoTokenizer.from_pretrained(
            model_args.model_name_or_path,
            use_fast=model_args.use_fast_tokenizer,
            split_special_tokens=model_args.split_special_tokens,
            padding_side="right",
            **init_kwargs,
        )
    except ValueError:  # try the fast one
        # An exception raised by this fallback propagates as-is; it is not
        # wrapped in OSError by the clause below.
        tokenizer = AutoTokenizer.from_pretrained(
            model_args.model_name_or_path,
            use_fast=True,
            padding_side="right",
            **init_kwargs,
        )
    except Exception as e:
        raise OSError("Failed to load tokenizer.") from e

    if model_args.new_special_tokens is not None:
        num_added_tokens = tokenizer.add_special_tokens(
            dict(additional_special_tokens=model_args.new_special_tokens),
            replace_additional_special_tokens=False,
        )
        logger.info_rank0(
            "Add {} to special tokens.".format(",".join(model_args.new_special_tokens))
        )
        # Newly added tokens need embedding rows, so force vocab resizing on.
        if num_added_tokens > 0 and not model_args.resize_vocab:
            model_args.resize_vocab = True
            logger.warning_rank0(
                "New tokens have been added, changed `resize_vocab` to True."
            )

    patch_tokenizer(tokenizer)
    try:
        processor = AutoProcessor.from_pretrained(
            model_args.model_name_or_path, **init_kwargs
        )
        patch_processor(processor, config, tokenizer, model_args)
    except Exception as e:
        # Best effort: text-only models have no processor.
        logger.debug(f"Processor was not found: {e}.")
        processor = None

    # Avoid load tokenizer, see:
    # https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/auto/processing_auto.py#L324
    # i.e. discard whatever AutoProcessor returned if it is not a "*Processor" class.
    if processor is not None and "Processor" not in processor.__class__.__name__:
        processor = None

    return {"tokenizer": tokenizer, "processor": processor}


def load_config(model_args: "ModelArguments") -> "PretrainedConfig":
    r"""
    Loads model config.

    Note: resolves ``model_args.model_name_or_path`` in place via
    ``_get_init_kwargs`` before loading.
    """
    init_kwargs = _get_init_kwargs(model_args)
    return AutoConfig.from_pretrained(model_args.model_name_or_path, **init_kwargs)


def load_model(
    tokenizer: "PreTrainedTokenizer",
    model_args: "ModelArguments",
    finetuning_args: "FinetuningArguments",
    is_trainable: bool = False,
    add_valuehead: bool = False,
) -> "PreTrainedModel":
    r"""
    Loads pretrained model.

    Args:
        tokenizer: passed to the config/model patchers and autoclass registration.
        model_args: model arguments; ``model_name_or_path`` is resolved in place.
        finetuning_args: fine-tuning arguments; ``stage`` decides whether the
            Liger kernel patch must keep logits (any stage other than pt/sft).
        is_trainable: when False, gradients are disabled, float32 parameters are
            cast to ``model_args.compute_dtype`` and the model is put in eval mode;
            when True the model is put in train mode.
        add_valuehead: wrap the model in ``AutoModelForCausalLMWithValueHead`` and
            load value-head weights if ``load_valuehead_params`` finds any.

    Returns:
        The loaded model, after ``init_adapter`` and optional value-head wrapping.
    """
    init_kwargs = _get_init_kwargs(model_args)
    config = load_config(model_args)
    patch_config(config, tokenizer, model_args, init_kwargs, is_trainable)
    apply_liger_kernel(
        config,
        model_args,
        is_trainable,
        require_logits=(finetuning_args.stage not in ["pt", "sft"]),
    )

    model = None
    lazy_load = False
    if model_args.use_unsloth:
        if model_args.adapter_name_or_path is not None:
            # Model stays None here and is handed to init_adapter below,
            # which presumably performs the unsloth load itself — TODO confirm.
            lazy_load = True
        elif is_trainable:
            model = load_unsloth_pretrained_model(config, model_args)
        # Otherwise (unsloth, no adapter, inference) fall through to the
        # regular loading path below.

    if model is None and not lazy_load:
        init_kwargs["config"] = config
        init_kwargs["pretrained_model_name_or_path"] = model_args.model_name_or_path

        if model_args.mixture_of_depths == "load":
            model = load_mod_pretrained_model(**init_kwargs)
        else:
            if type(config) in AutoModelForVision2Seq._model_mapping.keys():
                # assume built-in models
                load_class = AutoModelForVision2Seq
            else:
                load_class = AutoModelForCausalLM

            if model_args.train_from_scratch:
                model = load_class.from_config(config, trust_remote_code=True)
            else:
                model = load_class.from_pretrained(**init_kwargs)

        if model_args.mixture_of_depths == "convert":
            model = convert_pretrained_model_to_mod(model, config, model_args)

    if not lazy_load:
        patch_model(model, tokenizer, model_args, is_trainable, add_valuehead)
        register_autoclass(config, model, tokenizer)

    model = init_adapter(config, model, model_args, finetuning_args, is_trainable)

    if add_valuehead:
        model = AutoModelForCausalLMWithValueHead.from_pretrained(model)
        patch_valuehead_model(model)

        # Value-head weights are looked up in the last adapter if any,
        # otherwise in the base model directory.
        if model_args.adapter_name_or_path is not None:
            vhead_path = model_args.adapter_name_or_path[-1]
        else:
            vhead_path = model_args.model_name_or_path

        vhead_params = load_valuehead_params(vhead_path, model_args)
        if vhead_params is not None:
            # strict=False: presumably vhead_params holds only value-head keys,
            # so the rest of the state dict is left untouched — TODO confirm.
            model.load_state_dict(vhead_params, strict=False)
            logger.info_rank0(f"Loaded valuehead from checkpoint: {vhead_path}")

    if not is_trainable:
        model.requires_grad_(False)
        # Inference only: downcast float32 params to the compute dtype.
        for param in model.parameters():
            if (
                param.data.dtype == torch.float32
                and model_args.compute_dtype != torch.float32
            ):
                param.data = param.data.to(model_args.compute_dtype)

        model.eval()
    else:
        model.train()

    trainable_params, all_param = count_parameters(model)
    if is_trainable:
        param_stats = (
            "trainable params: {:,} || all params: {:,} || trainable%: {:.4f}".format(
                trainable_params, all_param, 100 * trainable_params / all_param
            )
        )
    else:
        param_stats = f"all params: {all_param:,}"

    logger.info_rank0(param_stats)

    if model_args.print_param_status:
        for name, param in model.named_parameters():
            print(
                "name: {}, dtype: {}, device: {}, trainable: {}".format(
                    name, param.dtype, param.device, param.requires_grad
                )
            )

    return model