from transformers import AutoConfig, AutoModelForCausalLM
from accelerate import init_empty_weights, load_checkpoint_and_dispatch
import torch

# Path to the model checkpoint. AutoConfig accepts a Hub id, but
# load_checkpoint_and_dispatch (used below) reads weights only from a local
# file or folder, so point this at a downloaded snapshot directory.
model_name_or_path = "your-model-name-or-path"

# STEP 1: Build the model skeleton without allocating real weights.
# init_empty_weights() creates the parameters on the "meta" device, not on
# CPU, so this costs almost no memory; real tensors are materialized later when
# a checkpoint is loaded and dispatched. Only module construction needs to run
# inside the context, so the config is loaded outside it.
config = AutoConfig.from_pretrained(model_name_or_path)
with init_empty_weights():
    model = AutoModelForCausalLM.from_config(config)

# STEP 2: Load real weights and dispatch them onto GPU(s).
def load_model_onto_gpu(
    model,
    model_name_or_path,
    *,
    device_map="auto",
    dtype=torch.float16,
    no_split_module_classes=None,
):
    """Load checkpoint weights into ``model`` and dispatch it to device(s).

    Args:
        model: A model skeleton, e.g. one built under ``init_empty_weights()``.
        model_name_or_path: Local checkpoint file or folder accepted by
            ``accelerate.load_checkpoint_and_dispatch``.
        device_map: ``"auto"`` to let accelerate place layers across the
            available devices, or an explicit map such as ``{"": "cuda:0"}``.
        dtype: Dtype the weights are cast to on load, e.g. ``torch.float16``
            or ``torch.bfloat16``.
        no_split_module_classes: Optional list of module class names whose
            instances must not be split across devices (model-dependent,
            e.g. ``["Block"]``).

    Returns:
        The dispatched model. Callers must use this return value.
    """
    return load_checkpoint_and_dispatch(
        model,
        checkpoint=model_name_or_path,
        device_map=device_map,
        # accelerate names this argument ``dtype``; ``torch_dtype`` is the
        # transformers ``from_pretrained`` spelling and raises TypeError here.
        dtype=dtype,
        no_split_module_classes=no_split_module_classes,
    )

# ... use the model on GPU ...

# STEP 3: Offload the model to CPU.
# A model dispatched across several devices by accelerate carries device hooks
# and should not simply be moved with .to("cpu"); rebuilding an empty skeleton
# and reloading the checkpoint with device_map="cpu" is the straightforward,
# safe route.
# NOTE(review): this creates a *new* CPU copy; it does not free the GPU copy.
# Drop every reference to the GPU model (then call torch.cuda.empty_cache())
# if the goal is to release GPU memory.
with init_empty_weights():
    model_cpu = AutoModelForCausalLM.from_config(config)

model_cpu = load_checkpoint_and_dispatch(
    model_cpu,
    checkpoint=model_name_or_path,
    device_map="cpu",
    # Keep the same dtype as the float16 GPU load in STEP 2; without it the
    # weights are materialized in the skeleton's default dtype (typically
    # float32), doubling memory. Switch to torch.bfloat16/float32 if you run
    # CPU inference and hit ops without half-precision support.
    dtype=torch.float16,
)

# ... model is now on CPU ...

# STEP 4: Reload the model onto GPU(s) later.
# The model is rebuilt from the checkpoint rather than moved from model_cpu.
# NOTE(review): model_cpu is still referenced at this point; delete it first
# if host RAM matters.
with init_empty_weights():
    model_gpu = AutoModelForCausalLM.from_config(config)

model_gpu = load_checkpoint_and_dispatch(
    model_gpu,
    checkpoint=model_name_or_path,
    device_map="auto",
    # Match the float16 used for the GPU load in STEP 2; without it the
    # weights load in the skeleton's default dtype (typically float32) and
    # need roughly twice the GPU memory.
    dtype=torch.float16,
)
