# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING, Dict, Optional, Sequence, Set, Tuple, Union

import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM
from trl import AutoModelForCausalLMWithValueHead

from ..data import get_dataset, get_template_and_fix_tokenizer
from ..extras.misc import get_current_device
from ..hparams import get_infer_args, get_train_args
from ..model import load_model, load_tokenizer


if TYPE_CHECKING:
    from datasets import Dataset
    from peft import LoraModel
    from transformers import PreTrainedModel


def compare_model(
    model_a: "torch.nn.Module",
    model_b: "torch.nn.Module",
    diff_keys: Sequence[str] = (),
) -> None:
    r"""Assert that two models share state-dict keys and have the expected (dis)similarity.

    Args:
        model_a: First model.
        model_b: Second model.
        diff_keys: Substrings of state-dict entry names that are expected to DIFFER
            between the two models. An entry whose name contains any of these
            substrings must not be close; every other entry must be close.
            Closeness is ``torch.allclose(..., rtol=1e-4, atol=1e-5)``.

    Raises:
        AssertionError: If the key sets differ, or any entry violates the
            expectation described above.
    """
    state_dict_a = model_a.state_dict()
    state_dict_b = model_b.state_dict()
    assert set(state_dict_a.keys()) == set(state_dict_b.keys()), "State dict keys differ."
    for name, tensor_a in state_dict_a.items():
        is_close = torch.allclose(tensor_a, state_dict_b[name], rtol=1e-4, atol=1e-5)
        if any(key in name for key in diff_keys):
            assert not is_close, f"Expected `{name}` to differ, but the tensors are close."
        else:
            assert is_close, f"Expected `{name}` to match, but the tensors differ."


def check_lora_model(model: "LoraModel") -> Tuple[Set[str], Set[str]]:
    r"""Check trainability and dtype of every parameter of a LoRA-wrapped model.

    Parameters whose name contains ``lora_A``/``lora_B``, or ``modules_to_save``,
    must be trainable and float32; every other parameter must be frozen float16.

    Args:
        model: The PEFT model to inspect.

    Returns:
        A pair ``(linear_modules, extra_modules)``: the short names of the modules
        carrying LoRA adapters, and of the modules found under ``modules_to_save``.

    Raises:
        AssertionError: If any parameter has the wrong ``requires_grad`` or dtype.
    """
    linear_modules, extra_modules = set(), set()
    for name, param in model.named_parameters():
        if any(module in name for module in ["lora_A", "lora_B"]):
            # Take the path component just before ".lora_", e.g. a name such as
            # "...self_attn.q_proj.lora_A.default.weight" yields "q_proj".
            linear_modules.add(name.split(".lora_", maxsplit=1)[0].split(".")[-1])
            assert param.requires_grad is True, name
            assert param.dtype == torch.float32, name
        elif "modules_to_save" in name:
            extra_modules.add(name.split(".modules_to_save", maxsplit=1)[0].split(".")[-1])
            assert param.requires_grad is True, name
            assert param.dtype == torch.float32, name
        else:
            assert param.requires_grad is False, name
            assert param.dtype == torch.float16, name

    return linear_modules, extra_modules


def load_train_model(add_valuehead: bool = False, **kwargs) -> "PreTrainedModel":
    r"""Build a trainable model from training-argument ``kwargs``.

    ``kwargs`` are parsed by ``get_train_args``; the tokenizer is loaded from the
    resulting model arguments and passed to ``load_model`` with ``is_trainable=True``.
    """
    model_args, _, _, finetuning_args, _ = get_train_args(kwargs)
    tokenizer = load_tokenizer(model_args)["tokenizer"]
    return load_model(
        tokenizer,
        model_args,
        finetuning_args,
        is_trainable=True,
        add_valuehead=add_valuehead,
    )


def load_infer_model(add_valuehead: bool = False, **kwargs) -> "PreTrainedModel":
    r"""Build an inference (non-trainable) model from inference-argument ``kwargs``.

    ``kwargs`` are parsed by ``get_infer_args``; the tokenizer is loaded from the
    resulting model arguments and passed to ``load_model`` with ``is_trainable=False``.
    """
    model_args, _, finetuning_args, _ = get_infer_args(kwargs)
    tokenizer = load_tokenizer(model_args)["tokenizer"]
    return load_model(
        tokenizer,
        model_args,
        finetuning_args,
        is_trainable=False,
        add_valuehead=add_valuehead,
    )


def load_reference_model(
    model_path: str,
    lora_path: Optional[str] = None,
    use_lora: bool = False,
    use_pissa: bool = False,
    is_trainable: bool = False,
    add_valuehead: bool = False,
) -> Union["PreTrainedModel", "LoraModel"]:
    r"""Load a float16 reference model directly via transformers/PEFT/TRL.

    Args:
        model_path: Path or hub id of the base model.
        lora_path: Path or hub id of the adapter; required when ``use_lora`` or
            ``use_pissa`` is set.
        use_lora: Wrap the base model with the adapter at ``lora_path``.
        use_pissa: Like ``use_lora``, but load the adapter from the
            ``pissa_init`` subfolder of ``lora_path``.
        is_trainable: Load the adapter in trainable mode; for the value-head
            model, when False the value head is also cast to float16.
        add_valuehead: Load an ``AutoModelForCausalLMWithValueHead`` instead
            (the adapter options are ignored in this case).

    Returns:
        The loaded model, placed on the current device.

    Raises:
        ValueError: If an adapter is requested but ``lora_path`` is None.
    """
    current_device = get_current_device()
    if add_valuehead:
        model: "AutoModelForCausalLMWithValueHead" = AutoModelForCausalLMWithValueHead.from_pretrained(
            model_path, torch_dtype=torch.float16, device_map=current_device
        )
        if not is_trainable:
            model.v_head = model.v_head.to(torch.float16)

        return model

    model = AutoModelForCausalLM.from_pretrained(
        model_path, torch_dtype=torch.float16, device_map=current_device
    )
    if use_lora or use_pissa:
        if lora_path is None:
            raise ValueError("`lora_path` is required when `use_lora` or `use_pissa` is set.")

        model = PeftModel.from_pretrained(
            model,
            lora_path,
            subfolder="pissa_init" if use_pissa else None,
            is_trainable=is_trainable,
        )
        # Up-cast only the trainable (adapter) weights; the frozen base stays float16.
        for param in filter(lambda p: p.requires_grad, model.parameters()):
            param.data = param.data.to(torch.float32)

    return model


def load_train_dataset(**kwargs) -> "Dataset":
    r"""Build the training split of a dataset from training-argument ``kwargs``.

    ``kwargs`` are parsed by ``get_train_args`` and must also contain ``"stage"``,
    which is forwarded to ``get_dataset``.
    """
    model_args, data_args, training_args, _, _ = get_train_args(kwargs)
    tokenizer_module = load_tokenizer(model_args)
    template = get_template_and_fix_tokenizer(tokenizer_module["tokenizer"], data_args)
    dataset_module = get_dataset(
        template,
        model_args,
        data_args,
        training_args,
        kwargs["stage"],
        **tokenizer_module,
    )
    return dataset_module["train_dataset"]


def patch_valuehead_model() -> None:
    r"""Monkey-patch ``AutoModelForCausalLMWithValueHead.post_init``.

    The replacement keeps only the ``v_head.``-prefixed entries of the given state
    dict, strips that prefix, and loads them into ``self.v_head`` non-strictly.
    """
    prefix = "v_head."

    def post_init(
        self: "AutoModelForCausalLMWithValueHead", state_dict: Dict[str, "torch.Tensor"]
    ) -> None:
        v_head_state_dict = {
            key[len(prefix):]: value for key, value in state_dict.items() if key.startswith(prefix)
        }
        self.v_head.load_state_dict(v_head_state_dict, strict=False)

    AutoModelForCausalLMWithValueHead.post_init = post_init