import argparse
import sys

import numpy as np
import torch
from gym import spaces
import gym

import pfrl
from pfrl import experiments, explorers
from pfrl import nn as pnn
from pfrl import utils, replay_buffers
from pfrl.q_functions import DistributionalDuelingDQN
from pfrl.wrappers import atari_wrappers
from agent import ERBLearnCategoricalDoubleDQN
from buffer import GreedyReplayBuffer


def _build_arg_parser():
    """Build the command-line parser for this script.

    Returns:
        argparse.ArgumentParser: parser declaring all supported options.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--outdir",
        type=str,
        default="results",
        help=(
            "Directory path to save output files."
            " If it does not exist, it will be created."
        ),
    )
    parser.add_argument("--env", type=str, default="PongNoFrameskip-v4")
    parser.add_argument("--seed", type=int, default=0, help="Random seed [0, 2 ** 32)")
    parser.add_argument("--gpu", type=int, default=0)
    parser.add_argument(
        "--eval-epsilon",
        type=float,
        default=0.001,
        help="Exploration epsilon used during eval episodes.",
    )
    parser.add_argument("--noisy-net-sigma", type=float, default=0.5)
    parser.add_argument(
        "--demo", action="store_true", help="Just run evaluation, not training."
    )
    parser.add_argument(
        "--load", type=str, default=None, help="Directory to load agent from."
    )
    parser.add_argument(
        "--steps",
        type=int,
        default=10 ** 6,
        help="Total number of timesteps to train the agent.",
    )
    parser.add_argument(
        "--replay-start-size",
        type=int,
        default=10 ** 3,
        help="Minimum replay buffer size before performing gradient updates.",
    )
    parser.add_argument(
        "--target-update-interval",
        type=int,
        default=10 ** 3,
        help="Frequency (in timesteps) at which the target network is updated.",
    )
    parser.add_argument(
        "--update-interval",
        type=int,
        default=1,
        help="Frequency (in timesteps) of network updates.",
    )
    parser.add_argument(
        "--eval-n-runs",
        type=int,
        default=5,
        help="Number of episodes run for each evaluation.",
    )
    parser.add_argument(
        "--eval-interval",
        type=int,
        default=2 * 10 ** 4,
        help="Frequency (in timesteps) of evaluation phase.",
    )
    parser.add_argument("--n-steps", type=int, default=3)
    parser.add_argument("--gamma", type=float, default=0.99)
    parser.add_argument("--batch-size", type=int, default=32, help="Minibatch size")
    # NOTE(review): --reward-scale-factor, --prune-exp and --num-bits are
    # parsed but not read anywhere in main(); presumably kept for run
    # bookkeeping or for code outside this file -- confirm before removing.
    parser.add_argument("--reward-scale-factor", type=float, default=1)
    parser.add_argument("--load-pretrained", action="store_true", default=False)
    parser.add_argument(
        "--pretrained-type", type=str, default="best", choices=["best", "final"]
    )
    parser.add_argument(
        "--max-frames",
        type=int,
        default=30 * 60 * 60,  # 30 minutes with 60 fps
        help="Maximum number of frames for each episode.",
    )
    parser.add_argument(
        "--log-level",
        type=int,
        default=20,
        help="Logging level. 10:DEBUG, 20:INFO etc.",
    )
    parser.add_argument("--lr", type=float, default=1e-4, help="Learning rate.")
    parser.add_argument("--prune-exp", action="store_true", default=False)
    parser.add_argument("--greedy", action="store_true", default=False)
    parser.add_argument(
        "--rbuf-capacity",
        type=int,
        default=5 * 10 ** 5,
        help="Maximum size of replay buffer.",
    )
    parser.add_argument(
        "--checkpoint-frequency",
        type=int,
        default=None,
        help="Frequency at which agents are stored.",
    )
    parser.add_argument(
        "--max-grad-norm",
        type=float,
        default=None,
    )
    parser.add_argument(
        "--num-bits",
        type=int,
        default=16,
    )
    return parser


def main():
    """Train or evaluate an ``ERBLearnCategoricalDoubleDQN`` agent on Atari.

    Parses the command line, builds the training and evaluation
    environments, the Q-function, the replay buffer and the agent, optionally
    loads saved weights, and then either evaluates only (``--demo``) or
    trains with periodic evaluation.

    Raises:
        NotImplementedError: if ``--env`` is not a ``NoFrameskip-v4`` id;
            only the Atari pipeline is implemented.
        SystemExit: raised through ``parser.error`` when both ``--load`` and
            ``--load-pretrained`` are given.
    """
    import logging

    parser = _build_arg_parser()
    args = parser.parse_args()

    # Reject unusable configurations before creating the output directory or
    # building any environment/network. parser.error() is used rather than
    # ``assert`` so the check is not stripped under ``python -O``.
    if args.load and args.load_pretrained:
        parser.error("--load and --load-pretrained are mutually exclusive.")
    if "NoFrameskip-v4" not in args.env:
        raise NotImplementedError(
            "Only Atari 'NoFrameskip-v4' environments are supported, "
            f"got {args.env!r}."
        )

    logging.basicConfig(level=args.log_level)

    # Set a random seed used in PFRL
    utils.set_random_seed(args.seed)

    args.outdir = experiments.prepare_output_dir(args, args.outdir, argv=sys.argv)
    print(f"Output files are saved in {args.outdir}")

    # Use different random seeds for train and test envs.
    # NOTE(review): test_seed is non-negative only for seed < 2 ** 31, while
    # the --seed help advertises [0, 2 ** 32) -- confirm the intended range.
    train_seed = args.seed
    test_seed = 2 ** 31 - 1 - args.seed

    def make_atari_env(test=False):
        """Create a DeepMind-style wrapped Atari env for training or testing."""
        env_seed = test_seed if test else train_seed
        env = atari_wrappers.wrap_deepmind(
            atari_wrappers.make_atari(args.env, max_frames=args.max_frames),
            # Episodic-life handling and reward clipping are enabled for
            # training only; evaluation uses the unmodified episodes/rewards.
            episode_life=not test,
            clip_rewards=not test,
            fire_reset=True,
        )
        env.seed(int(env_seed))
        if test:
            # Randomize actions like epsilon-greedy in evaluation as well
            env = pfrl.wrappers.RandomizeAction(env, args.eval_epsilon)
        return env

    def atari_phi(x):
        """Feature extractor: cast a frame stack to float32 and divide by 255."""
        return np.asarray(x, dtype=np.float32) / 255

    env = make_atari_env(test=False)
    eval_env = make_atari_env(test=True)

    timestep_limit = env.spec.max_episode_steps
    n_actions = env.action_space.n

    # Arguments of the distributional Q-function: atom count and value bounds.
    n_atoms = 51
    v_max = 10
    v_min = -10
    q_func = DistributionalDuelingDQN(
        n_actions,
        n_atoms,
        v_min,
        v_max,
    )

    # Noisy nets: the Q-function's layers are converted to factorized noisy
    # layers, and a purely greedy explorer is used on top of them.
    pnn.to_factorized_noisy(q_func, sigma_scale=args.noisy_net_sigma)
    explorer = explorers.Greedy()
    opt = torch.optim.Adam(q_func.parameters(), lr=args.lr)

    # Prioritized replay: --greedy only switches the buffer class; both are
    # constructed with exactly the same arguments.
    betasteps = args.steps / args.update_interval
    if args.greedy:
        rbuf_cls = GreedyReplayBuffer
    else:
        rbuf_cls = replay_buffers.PrioritizedReplayBuffer
    rbuf = rbuf_cls(
        args.rbuf_capacity,
        alpha=0.5,
        beta0=0.4,
        betasteps=betasteps,
        num_steps=args.n_steps,
        normalize_by_max="memory",
    )

    agent = ERBLearnCategoricalDoubleDQN(
        q_func,
        opt,
        rbuf,
        gpu=args.gpu,
        phi=phi_or_default(atari_phi) if False else atari_phi,
        gamma=args.gamma,
        explorer=explorer,
        replay_start_size=args.replay_start_size,
        target_update_interval=args.target_update_interval,
        update_interval=args.update_interval,
        minibatch_size=args.batch_size,
        max_grad_norm=args.max_grad_norm,
    )

    # At most one of the two sources is set (validated right after parsing).
    if args.load:
        agent.load(args.load)
    elif args.load_pretrained:
        agent.load(
            utils.download_model(
                "Rainbow", args.env, model_type=args.pretrained_type
            )[0]
        )

    if args.demo:
        eval_stats = experiments.eval_performance(
            env=eval_env,
            agent=agent,
            n_steps=None,
            n_episodes=args.eval_n_runs,
            max_episode_len=timestep_limit,
        )
        print(
            "n_runs: {} mean: {} median: {} stdev {}".format(
                args.eval_n_runs,
                eval_stats["mean"],
                eval_stats["median"],
                eval_stats["stdev"],
            )
        )
    else:
        experiments.train_agent_with_evaluation(
            agent=agent,
            env=env,
            steps=args.steps,
            eval_n_steps=None,
            eval_n_episodes=args.eval_n_runs,
            checkpoint_freq=args.checkpoint_frequency,
            eval_interval=args.eval_interval,
            outdir=args.outdir,
            eval_env=eval_env,
            train_max_episode_len=timestep_limit,
            eval_during_episode=True,
        )


# Run the training/evaluation entry point only when this file is executed as
# a script, not when it is imported as a module.
if __name__ == "__main__":
    main()