"""
Deploy PETS with a pretrained dynamics model and a learnt reward model.
"""
import argparse
from torch import device, load
from torch.cuda import is_available
import gym, fastjet
from rlutils import build_params, make, deploy
from rlutils.rewards.handler import RewardLearningHandler
    

# Command-line interface:
#   task      -- task name; selects the "task.<task>" and "oracle.<task>" configs,
#                is forwarded to the environment, and forms part of the model paths.
#   model     -- either the literal "oracle", or the stem of a saved reward model
#                under trained_models/<task>/<model>.reward.
#   --num_eps -- number of episodes to deploy for (defaults to 100).
parser = argparse.ArgumentParser()
for positional in ("task", "model"):
    parser.add_argument(positional, type=str)
parser.add_argument("--num_eps", type=int, default=100)
args = parser.parse_args()

# Assemble the parameter dictionary from the PETS agent config plus the
# task-specific and oracle-specific configs, all looked up under "config".
config_names = [
    "agent.pets_moreplanning",
    f"task.{args.task}",
    f"oracle.{args.task}",
]
P = build_params(config_names, root_dir="config")

# Create the fast jet environment, rendered in "human" mode, with frame
# skipping and camera angle taken from the deployment section of the config.
# NOTE(review): "FastJet-v0" is presumably registered by importing `fastjet` — confirm.
env_kwargs = {
    "task": args.task,
    "skip_frames": P["deployment"]["skip_frames"],
    "render_mode": "human",
    "camera_angle": P["deployment"]["camera_angle"],
}
env = gym.make("FastJet-v0", **env_kwargs)

# Create the "observer" instance that handles agent-reward interaction.
# NOTE(review): the handler is built before "reward_source" is written below, so
# this relies on it consulting the shared P["pbrl"] dict later rather than
# snapshotting it at construction time — confirm against RewardLearningHandler.
pbrl = RewardLearningHandler(P["pbrl"])

# Prefer the GPU whenever CUDA is available; otherwise fall back to the CPU.
device_name = "cuda" if is_available() else "cpu"
device_ = device(device_name)

# Rewards come either from the oracle or from a previously trained reward model.
use_oracle = args.model == "oracle"
P["pbrl"]["reward_source"] = "oracle" if use_oracle else "model"
if not use_oracle:
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    reward_model_path = f"trained_models/{args.task}/{args.model}.reward"
    pbrl.model = load(reward_model_path, map_location=device_)
    # Record the device on the loaded model so it matches map_location.
    pbrl.model.device = device_

# Create the PETS agent: hand it the pretrained dynamics model for this task
# and the reward callable exposed by the reward-learning handler.
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
dynamics_path = f"pretrained_dynamics/{args.task}.dynamics"
agent_params = P["agent"]
agent_params["pretrained_model"] = load(dynamics_path, map_location=device_)
agent_params["reward"] = pbrl.reward
agent = make("pets", env, hyperparameters=agent_params)

# Deploy the agent for the requested number of episodes, using the episode
# time limit from the deployment config.
# NOTE(review): render_freq=1 presumably renders every episode — confirm in rlutils.deploy.
deploy_params = {
    "num_episodes": args.num_eps,
    "episode_time_limit": P["deployment"]["episode_time_limit"],
    "render_freq": 1,
}
deploy(agent=agent, P=deploy_params)

