import os
import sys

import torch

# Default location of the model-state file to inspect; override by passing a
# path as the first command-line argument.
checkpoint_path = "/root/checkpoints/my_training/global_step2000/"
model_state_filename = "mp_rank_00_model_states.pt"
full_path = os.path.join(checkpoint_path, model_state_filename)


def load_checkpoint(path):
    """Load the checkpoint at *path* onto the CPU and return the loaded object.

    ``torch.load`` unpickles the file, so only use this on checkpoints you
    trust. NOTE(review): on newer torch releases ``weights_only`` defaults to
    True, which rejects pickled non-tensor objects; if loading fails for that
    reason and the file is trusted, pass ``weights_only=False`` here.
    """
    return torch.load(path, map_location="cpu")


def module_keys(state_dict):
    """Return the keys stored under ``state_dict['module']`` as a list.

    Returns None when the loaded object has no top-level ``'module'`` entry.

    Raises:
        TypeError: if the loaded object, or its ``'module'`` entry, is not a
            dict, so the caller can report the real layout problem instead of
            an opaque ``AttributeError`` from ``.keys()``.
    """
    if not isinstance(state_dict, dict):
        raise TypeError(
            f"expected a dict at the top level, got {type(state_dict).__name__}"
        )
    if "module" not in state_dict:
        return None
    module = state_dict["module"]
    if not isinstance(module, dict):
        raise TypeError(
            f"state_dict['module'] is {type(module).__name__}, not a dict"
        )
    return list(module.keys())


def main(argv=None):
    """Print every key stored under ``'module'`` in a checkpoint file.

    Args:
        argv: argument list (defaults to ``sys.argv[1:]``); its first element,
            if present, is the checkpoint path, otherwise ``full_path`` is used.

    Returns:
        A process exit status: 0 on success (including the "no 'module' key"
        report), 1 if the file cannot be loaded or has an unexpected layout.
    """
    args = sys.argv[1:] if argv is None else argv
    path = args[0] if args else full_path

    print(f"Loading checkpoint from: {path}")
    try:
        state_dict = load_checkpoint(path)
    except Exception as e:  # top-level boundary: report any load failure
        print(f"Error loading checkpoint: {e}", file=sys.stderr)
        return 1

    try:
        keys = module_keys(state_dict)
    except TypeError as e:
        print(f"Unexpected checkpoint layout: {e}", file=sys.stderr)
        return 1

    if keys is None:
        print("No 'module' key found in the top level state_dict.")
        # Show what the file does contain to help locate the weights.
        print("Top-level keys:", ", ".join(map(str, state_dict)))
        return 0

    print("Keys in state_dict['module']:")
    for key in keys:
        print(key)
    return 0


if __name__ == "__main__":
    sys.exit(main())

