from lzero.entry import eval_muzero
import numpy as np

if __name__ == "__main__":
    """
    Overview:
        Main script to evaluate the MuZero model on Atari games. The script will loop over multiple seeds,
        evaluating a certain number of episodes per seed. Results are aggregated and printed.

    Variables:
        - model_path (:obj:`Optional[str]`): The pretrained model path, pointing to the ckpt file of the pretrained model. 
          The path is usually something like ``exp_name/ckpt/ckpt_best.pth.tar``.
        - seeds (:obj:`List[int]`): List of seeds to use for the evaluations.
        - num_episodes_each_seed (:obj:`int`): Number of episodes to evaluate for each seed.
        - total_test_episodes (:obj:`int`): Total number of test episodes, calculated as num_episodes_each_seed * len(seeds).
        - returns_mean_seeds (:obj:`np.array`): Array of mean return values for each seed.
        - returns_seeds (:obj:`np.array`): Array of all return values for each seed.
    """
    # Importing the necessary configuration files from the atari muzero configuration in the zoo directory.
    from zoo.atari.config.atari_muzero_config import main_config, create_config

    # model_path is the path to the trained MuZero model checkpoint.
    # If no path is provided, the script will use the default model.
    model_path = None

    # seeds is a list of seed values for the random number generator, used to initialize the environment.
    seeds = [0]
    # num_episodes_each_seed is the number of episodes to run for each seed.
    num_episodes_each_seed = 1
    # total_test_episodes is the total number of test episodes, calculated as the product of the number of seeds and the number of episodes per seed
    total_test_episodes = num_episodes_each_seed * len(seeds)

    # Setting the type of the environment manager to 'base' for the visualization purposes.
    create_config.env_manager.type = 'base'
    # The number of environments to evaluate concurrently. Set to 1 for visualization purposes.
    main_config.env.evaluator_env_num = 1
    # The total number of evaluation episodes that should be run.
    main_config.env.n_evaluator_episode = total_test_episodes
    # A boolean flag indicating whether to render the environments in real-time.
    main_config.env.render_mode_human = False

    # A boolean flag indicating whether to save the video of the environment.
    main_config.env.save_replay = True
    # The path where the recorded video will be saved.
    main_config.env.replay_path = './video'
    # The maximum number of steps for each episode during evaluation. This may need to be adjusted based on the specific characteristics of the environment.
    main_config.env.eval_max_episode_steps = int(100)

    # These lists will store the mean and total rewards for each seed.
    returns_mean_seeds = []
    returns_seeds = []

    # The main evaluation loop. For each seed, the MuZero model is evaluated and the mean and total rewards are recorded.
    for seed in seeds:
        returns_mean, returns = eval_muzero(
            [main_config, create_config],
            seed=seed,
            num_episodes_each_seed=num_episodes_each_seed,
            print_seed_details=False,
            model_path=model_path
        )
        print(returns_mean, returns)
        returns_mean_seeds.append(returns_mean)
        returns_seeds.append(returns)

    # Convert the list of mean and total rewards into numpy arrays for easier statistical analysis.
    returns_mean_seeds = np.array(returns_mean_seeds)
    returns_seeds = np.array(returns_seeds)

    # Printing the evaluation results. The average reward and the total reward for each seed are displayed, followed by the mean reward across all seeds.
    print("=" * 20)
    print(f"We evaluated a total of {len(seeds)} seeds. For each seed, we evaluated {num_episodes_each_seed} episode(s).")
    print(f"For seeds {seeds}, the mean returns are {returns_mean_seeds}, and the returns are {returns_seeds}.")
    print("Across all seeds, the mean reward is:", returns_mean_seeds.mean())
    print("=" * 20)