import argparse
import os

import numpy as np
from stable_baselines3 import DQN, HerReplayBuffer
from stable_baselines3.common.env_checker import check_env
from stable_baselines3.common.logger import configure
from stable_baselines3.common.utils import set_random_seed

from envs.mnist_grid import MNISTHyperGrid


def parse_args(argv=None):
    """Parse command-line options for the goal-conditioned DQN + HER run.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        ``argparse.Namespace`` with ``n_sample``, ``grid_size``, ``eps``,
        ``seed``, ``log_dir`` and ``n_eval_episodes``.
    """
    # The first positional argument of ArgumentParser is ``prog`` (the program
    # name), so the human-readable text must be passed as ``description``.
    parser = argparse.ArgumentParser(description="Train goal-conditioned DQN with HER.")
    parser.add_argument("--n-sample", help="Number of samples", type=int, required=True)
    parser.add_argument("--grid-size", help="Size in each dimension", nargs="+", type=int, required=True)
    parser.add_argument("--eps", help="The stochasticity rate of the environment", type=float, required=True)
    parser.add_argument("--seed", type=int, required=True)
    # Defaults reproduce the previously hard-coded values. A fixed log dir makes
    # runs with different seeds overwrite each other's progress.csv, hence the option.
    parser.add_argument("--log-dir", help="Directory for training logs", default="out/logs/")
    parser.add_argument(
        "--n-eval-episodes",
        help="Number of greedy evaluation episodes after training",
        type=int,
        default=100,
    )
    args = parser.parse_args(argv)
    if args.n_eval_episodes < 0:
        parser.error("--n-eval-episodes must be non-negative")
    return args


def evaluate(model, env, n_episodes=100, seed=None):
    """Roll out ``model`` greedily in ``env`` and return per-episode returns.

    Args:
        model: Object exposing ``predict(obs, deterministic=...)`` that returns
            an ``(action, state)`` pair, as SB3 models do.
        env: Gymnasium-style env: ``reset(seed=...)`` returns ``(obs, info)``
            and ``step(action)`` returns
            ``(obs, reward, terminated, truncated, info)``.
        n_episodes: Number of episodes to run; ``0`` yields an empty list.
        seed: If not ``None``, passed to the first ``reset`` only, so the
            evaluation episodes are reproducible (Gymnasium convention: seed
            once, then let the env's RNG continue).

    Returns:
        List of undiscounted cumulative rewards, one float per episode.
    """
    rewards = []
    for e in range(n_episodes):
        obs, _ = env.reset(seed=seed if e == 0 else None)
        done = False
        cum_rew = 0.0
        while not done:
            action, _ = model.predict(obs, deterministic=True)
            obs, r, terminated, truncated, _ = env.step(action)
            cum_rew += float(r)
            done = terminated or truncated
        rewards.append(cum_rew)
        print(f"Episode {e}, cum. rew.: {cum_rew:.1f}")
    return rewards


if __name__ == "__main__":
    args = parse_args()

    set_random_seed(args.seed)
    env = MNISTHyperGrid(dimensions=args.grid_size, eps=args.eps, max_episode_steps=20, goal_conditioned=True)
    check_env(env)

    try:
        logger = configure(args.log_dir, ["csv", "stdout"])
        model = DQN(
            "MultiInputPolicy",
            env,
            replay_buffer_class=HerReplayBuffer,
            replay_buffer_kwargs=dict(
                n_sampled_goal=4,
                goal_selection_strategy="future",
                copy_info_dict=True,
            ),
            verbose=1,
            train_freq=1,
            # set_random_seed() alone does not seed the wrapped env or the
            # action space used for epsilon-greedy exploration; passing the
            # seed here lets SB3 seed those as well.
            seed=args.seed,
        )
        print(model.q_net)
        model.set_logger(logger)
        model.learn(total_timesteps=args.n_sample)
        model.save(os.path.join("save", f"gcdqn_{args.n_sample}_{args.seed}"))

        rewards = evaluate(model, env, n_episodes=args.n_eval_episodes, seed=args.seed)
        if rewards:
            print(f"Average cumulative reward: {np.mean(rewards):.2f}")
    finally:
        # Release env resources even if training or evaluation fails.
        env.close()
