import os
import pickle

import click
import gym
import numpy as np
from mjrl.utils.gym_env import GymEnv

import d4rl_alt

# from mjrl.policies.gaussian_mlp import MLP

DESC = """
Helper script to visualize policy (in mjrl format).\n
USAGE:\n
    Visualizes policy on the env\n
    $ python visualize_policy.py --env_name door-v0 \n
    $ python visualize_policy.py --env_name door-v0 --policy my_policy.pickle --mode evaluation --episodes 10 \n
"""


class RandomPolicy(object):
    def __init__(self, env):
        self.env = env

    def get_action(self, obs):
        return [
            self.env.action_space.sample(),
            {"evaluation": self.env.action_space.sample()},
        ]


# MAIN =========================================================
@click.command(help=DESC)
@click.option("--env_name", type=str, help="environment to load", required=True)
@click.option(
    "--policy", type=str, help="absolute path of the policy file", default=None
)
@click.option(
    "--mode",
    type=str,
    help="exploration or evaluation mode for policy",
    default="evaluation",
)
@click.option(
    "--seed", type=int, help="seed for generating environment instances", default=123
)
@click.option(
    "--episodes", type=int, help="number of episodes to visualize", default=10
)
def main(env_name, policy, mode, seed, episodes):
    e = GymEnv(env_name)
    e.set_seed(seed)
    """
    if policy is not None:
        pi = pickle.load(open(policy, 'rb'))
    else:
        pi = MLP(e.spec, hidden_sizes=(32,32), seed=seed, init_log_std=-1.0)
    """
    pi = RandomPolicy(e)
    # render policy
    e.visualize_policy(pi, num_episodes=episodes, horizon=e.horizon, mode=mode)


if __name__ == "__main__":
    main()
