# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

import torch

from llamafactory.extras.misc import get_current_device
from llamafactory.train.test_utils import load_train_model


# Model used by every test; override through the TINY_LLAMA environment variable.
TINY_LLAMA = os.getenv("TINY_LLAMA", "llamafactory/tiny-random-Llama-3")

# Baseline training arguments shared by all tests; each test adds one extra flag.
TRAIN_ARGS = {
    "model_name_or_path": TINY_LLAMA,
    "stage": "sft",
    "do_train": True,
    "finetuning_type": "lora",
    "lora_target": "all",
    "dataset": "llamafactory/tiny-supervised-dataset",
    # NOTE(review): "ONLINE" presumably makes the loader resolve the dataset
    # from the hub instead of a local directory -- confirm against the loader.
    "dataset_dir": "ONLINE",
    "template": "llama3",
    "cutoff_len": 1024,
    "overwrite_cache": True,
    "output_dir": "dummy_dir",
    "overwrite_output_dir": True,
    "fp16": True,
}


def test_checkpointing_enable():
    """Modules exposing ``gradient_checkpointing`` have it set to True by default."""
    model = load_train_model(disable_gradient_checkpointing=False, **TRAIN_ARGS)
    for module in model.modules():
        # Only modules that carry the flag are checked; all others are skipped.
        if hasattr(module, "gradient_checkpointing"):
            assert module.gradient_checkpointing is True


def test_checkpointing_disable():
    """Modules exposing ``gradient_checkpointing`` have it set to False when disabled."""
    model = load_train_model(disable_gradient_checkpointing=True, **TRAIN_ARGS)
    for module in model.modules():
        if hasattr(module, "gradient_checkpointing"):
            assert module.gradient_checkpointing is False


def test_unsloth_gradient_checkpointing():
    """With ``use_unsloth_gc`` the checkpointing function is bound to the Unsloth class."""
    model = load_train_model(use_unsloth_gc=True, **TRAIN_ARGS)
    for module in model.modules():
        if not hasattr(module, "gradient_checkpointing"):
            continue

        # ``__self__`` is the object the checkpointing callable is bound to;
        # the test identifies it by its ``__name__``.
        bound_owner = module._gradient_checkpointing_func.__self__
        assert bound_owner.__name__ == "UnslothGradientCheckpointing"


def test_upcast_layernorm():
    """With ``upcast_layernorm`` every 1-D parameter named like a norm is float32."""
    model = load_train_model(upcast_layernorm=True, **TRAIN_ARGS)
    for name, param in model.named_parameters():
        # Norm weights are selected by rank and by name substring.
        if param.ndim == 1 and "norm" in name:
            assert param.dtype == torch.float32


def test_upcast_lmhead_output():
    """With ``upcast_lmhead_output`` the output embedding returns float32 for fp16 input."""
    model = load_train_model(upcast_lmhead_output=True, **TRAIN_ARGS)
    # assumes the tiny model's hidden size is 16, matching the last input dim -- TODO confirm
    inputs = torch.randn((1, 16), dtype=torch.float16, device=get_current_device())
    outputs: "torch.Tensor" = model.get_output_embeddings()(inputs)
    assert outputs.dtype == torch.float32