import torch

# Input checkpoint and the plain state_dict file written next to it.
CKPT_PATH = "p4rl_assets/inv_dynamics_new/1x_mlp_noise_free_350epochs.ckpt"
OUT_PATH = "p4rl_assets/inv_dynamics_new/1x_mlp_noise_free_350epochs.pt"


def extract_state_dict(checkpoint):
    """Return the model weights contained in a loaded checkpoint.

    If *checkpoint* has a ``'state_dict'`` key, that entry is returned.
    Otherwise, if every value is a tensor, *checkpoint* is treated as an
    already-bare state_dict and returned unchanged.

    Raises:
        TypeError: if *checkpoint* is not a dict.
        KeyError: if there is no ``'state_dict'`` key and *checkpoint* does
            not look like a bare state_dict.
    """
    if not isinstance(checkpoint, dict):
        raise TypeError(
            f"expected a dict checkpoint, got {type(checkpoint).__name__}"
        )
    if "state_dict" in checkpoint:
        return checkpoint["state_dict"]
    if checkpoint and all(torch.is_tensor(v) for v in checkpoint.values()):
        return checkpoint
    raise KeyError(
        f"checkpoint has no 'state_dict' key (keys: {list(checkpoint)})"
    )


def strip_prefix(state_dict, prefix="model."):
    """Return a new dict with a leading *prefix* removed from each key.

    Only a prefix at the start of a key is removed. Keys that do not start
    with *prefix* are kept as-is. ``str.replace`` would also delete the text
    in the middle of a key: ``"model.submodel.bias"`` would become
    ``"subbias"``.

    Raises:
        ValueError: if stripping makes two keys identical, which would
            otherwise silently drop one of the tensors.
    """
    cleaned = {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }
    if len(cleaned) != len(state_dict):
        raise ValueError(
            f"stripping prefix {prefix!r} produced duplicate keys"
        )
    return cleaned


def main():
    """Convert CKPT_PATH into a plain state_dict saved at OUT_PATH."""
    # map_location remaps every stored tensor to CPU, so no GPU is needed.
    # NOTE: torch.load unpickles the file; only load trusted checkpoints.
    checkpoint = torch.load(CKPT_PATH, map_location=torch.device("cpu"))
    state_dict = extract_state_dict(checkpoint)
    # Adjust the prefix (e.g. "net.") if the wrapper used a different name.
    cleaned_state_dict = strip_prefix(state_dict, "model.")
    torch.save(cleaned_state_dict, OUT_PATH)
    print("Checkpoint converted and saved.")


if __name__ == "__main__":
    main()