"""
Script for running online preference-based reward learning.
"""

import argparse
from pprint import pprint
import gym
from torch import device, load
from torch.cuda import is_available

import rlutils
import fastjet
from config.base import P as base


if __name__ == "__main__":

    # Command-line interface: task, agent and number of episodes are positional and required;
    # everything else is optional.
    parser = argparse.ArgumentParser()
    parser.add_argument("task", type=str)
    parser.add_argument("agent", type=str)
    parser.add_argument("num_eps", type=int)
    parser.add_argument("--model", type=str)
    parser.add_argument("--schedule", type=str, default="1k_200_1_1")
    parser.add_argument("--irrationality", type=str)
    parser.add_argument("--render_freq", type=int, default=0)
    parser.add_argument("--save_freq", type=int, default=0)
    parser.add_argument("--do_wandb", type=int, default=0)
    parser.add_argument("--wandb_project", type=str, default="reward_learning_with_trees")
    args = parser.parse_args()

    # Build configs for various components of the system by reading in parameter dictionaries.
    # The model, features, schedule and sampler entries are only requested when --model is given,
    # and the irrationality entry only when --irrationality is given; an empty string is passed
    # in place of every entry that is not requested.
    use_model = args.model is not None
    P = rlutils.build_params([
        f"task.{args.task}",
        f"oracle.{args.task}",
        f"agent.{args.agent}",
        f"model.{args.model}" if use_model else "",
        "features.default" if use_model else "",
        f"schedule.{args.schedule}" if use_model else "",
        "sampler.uniform_recency" if use_model else "",
        f"irrationality.{args.irrationality}" if args.irrationality is not None else ""
        ], base, root_dir="config")

    # Command-line values override whatever the parameter dictionaries specified
    P["deployment"]["num_episodes"] = args.num_eps
    P["deployment"]["render_freq"] = args.render_freq
    P["pbrl"]["save_freq"] = args.save_freq
    pprint(P)

    # Create fast jet environment. Actions are continuous for every agent except "dqn", and
    # on-screen rendering is only requested when a positive render frequency was given.
    env = gym.make("FastJet-v0",
        task=P["deployment"]["task"],
        continuous=(P["deployment"]["agent"] != "dqn"),
        skip_frames=P["deployment"]["skip_frames"],
        render_mode=("human" if args.render_freq > 0 else False),
        camera_angle=P["deployment"]["camera_angle"]
    )

    # Create "observer" instance that handles reward learning
    pbrl = rlutils.RewardLearningHandler(P=P["pbrl"])
    P["deployment"]["observers"]["pbrl"] = pbrl

    # Create either PETS or SAC agent, setting up link required for agent to access the reward model.
    # PETS receives the reward function through its hyperparameters before construction, whereas
    # SAC is constructed first and linked to the handler afterwards.
    agent_type = P["deployment"]["agent"]
    if agent_type == "pets":
        P["agent"]["reward"] = pbrl.reward
        if P["agent"]["pretrained_model"]:
            # Replace the truthy config flag with the loaded dynamics model itself.
            # NOTE: torch.load unpickles the file, so only load dynamics files from a trusted source.
            P["agent"]["pretrained_model"] = load(f"pretrained_dynamics/{P['deployment']['task']}.dynamics",
                                                  map_location=device("cuda" if is_available() else "cpu"))
    elif agent_type != "sac":
        # Explicit check rather than an assert, so that it still fires under `python -O`
        raise ValueError(f"Unsupported agent {agent_type!r}: expected 'pets' or 'sac'")
    agent = rlutils.make(agent_type, env=env, hyperparameters=P["agent"])
    if agent_type == "sac":
        pbrl.link(agent)

    # Run the online process, logging to Weights & Biases only when --do_wandb is non-zero
    rlutils.deploy(agent, P=P["deployment"], train=P["deployment"]["train"],
                   wandb_config={"project": args.wandb_project} if args.do_wandb else None)
