"""Consolidate a DeepSpeed ZeRO checkpoint and export it as a Hugging Face model.

Steps:
1. Reassemble the ZeRO checkpoint identified by ``tag`` into a single fp32
   state dict in memory, and also write it to ``consolidated_model.pt``.
2. Load the base architecture, fill in ``lm_head.weight`` from the input
   embedding if the checkpoint omits it, and strictly load the fine-tuned
   weights.
3. Write the result with ``save_pretrained``.
"""
import os

import torch
from deepspeed.utils.zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict  # noqa: F401 (kept; returns None, see below)
from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint
from transformers import AutoModelForCausalLM

# Base architecture (local path or Hub id) the fine-tuned weights are loaded into.
base_model_name = "Llama-3.2-1B-Instruct"
# DeepSpeed checkpoint folder; the ZeRO shard files are looked up for ``tag``.
checkpoint_dir = "Llama-3.2-1B-Instruct-SFT-DP_dist-EP8-EPOCH2-ACCS64"
tag = "global_step3436"
output_path = os.path.join(checkpoint_dir, "consolidated_model.pt")
# Destination of save_pretrained(); same directory as the original script used.
hf_output_dir = checkpoint_dir

# convert_zero_checkpoint_to_fp32_state_dict() only writes to disk and returns
# None (and its second argument is a file in older DeepSpeed releases but a
# directory in newer ones), so fetch the state dict in memory and save it here.
state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag=tag)
torch.save(state_dict, output_path)
print(f"Saved consolidated fp32 state dict to {output_path}")

# Load the base architecture only after the checkpoint was read successfully,
# so a bad checkpoint path fails fast.
model = AutoModelForCausalLM.from_pretrained(base_model_name)

# The checkpoint may lack lm_head.weight while strict loading still expects it.
# Reusing the input embedding is only correct when the model ties its input and
# output embeddings — NOTE(review): confirm for the architecture in use.
if "lm_head.weight" not in state_dict:
    if "model.embed_tokens.weight" not in state_dict:
        raise KeyError(
            "Checkpoint contains neither 'lm_head.weight' nor 'model.embed_tokens.weight'"
        )
    print("lm_head.weight not found. Using model.embed_tokens.weight as replacement.")
    state_dict["lm_head.weight"] = state_dict["model.embed_tokens.weight"]

# Compare against model.state_dict() rather than named_parameters(): the latter
# drops tied duplicates (e.g. lm_head.weight) and omits buffers, so it does not
# match the key set load_state_dict(strict=True) checks.
model_keys = set(model.state_dict())
state_dict_keys = set(state_dict)
missing_keys = model_keys - state_dict_keys
unexpected_keys = state_dict_keys - model_keys
print(f"Missing keys: {missing_keys}")
print(f"Unexpected keys: {unexpected_keys}")

model.load_state_dict(state_dict)  # strict: raises on any key mismatch
print('Done with model loading')
model.save_pretrained(hf_output_dir)

