import argparse
from pathlib import Path

import gymnasium as gym
import numpy as np
import torch

from ppo_fetch import Agent

def save_model_to_npz(model, filepath):
    """
    Write every entry of a PyTorch model's state_dict into an NPZ archive.

    Args:
        model: PyTorch module; each tensor returned by ``model.state_dict()``
            is copied to CPU, converted to a NumPy array and stored under its
            state_dict key.
        filepath: Destination passed straight to ``np.savez`` (which appends
            ``.npz`` when the name does not already end with it).
    """
    arrays = {key: tensor.cpu().numpy() for key, tensor in model.state_dict().items()}
    np.savez(filepath, **arrays)

def _default_output_path(checkpoint):
    """Return *checkpoint* with only its final suffix replaced by ``.npz``."""
    # Path.with_suffix touches just the last suffix; str.replace('.pt', ...)
    # rewrote every '.pt' in the path and turned '.pth' into '.npzh'.
    return str(Path(checkpoint).with_suffix(".npz"))


def main():
    """
    Convert a PyTorch ``Agent`` checkpoint into an NPZ weight archive.

    Command-line arguments:
        --checkpoint: Path to the PyTorch checkpoint loaded into ``Agent``
            via ``load_state_dict``.
        --output: Optional NPZ destination. Defaults to the checkpoint path
            with its final suffix replaced by ``.npz``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--checkpoint", type=str, required=True, help="Path to PyTorch checkpoint file (.pt)")
    parser.add_argument("--output", type=str, help="Output NPZ file path (optional)")
    args = parser.parse_args()

    # Generate output filename if not provided
    if args.output is None:
        args.output = _default_output_path(args.checkpoint)
    # np.savez appends ".npz" when it is missing; mirror that so the final
    # message reports the file that is actually written.
    if not args.output.endswith(".npz"):
        args.output += ".npz"

    def make_env():
        env = gym.make("fetch-v0")
        env = gym.wrappers.FlattenObservation(env)
        return env

    # Agent is constructed from a vectorized env (presumably only to read its
    # spaces -- confirm in ppo_fetch). Nothing is stepped here, so a single
    # in-process SyncVectorEnv avoids spawning an AsyncVectorEnv worker.
    envs = gym.vector.SyncVectorEnv([make_env])
    try:
        agent = Agent(envs)

        # NOTE(review): torch.load unpickles the file -- only load trusted
        # checkpoints (weights_only=True is available on torch >= 1.13).
        checkpoint = torch.load(args.checkpoint, map_location=torch.device('cpu'))
        agent.load_state_dict(checkpoint)

        save_model_to_npz(agent, args.output)
    finally:
        # Release the environment even if loading or saving fails.
        envs.close()

    print(f"Converted {args.checkpoint} to {args.output}")

if __name__ == "__main__":
    # Run the converter only when executed as a script, not when imported.
    main()