import os
#os.environ['OPENBLAS_NUM_THREADS'] = '1'

import gym
import argparse
import os
from stable_baselines3.common.callbacks import EvalCallback
from stable_baselines3.common.vec_env import VecFrameStack
from stable_baselines3 import DQN
from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv
from stable_baselines3.common.logger import configure
from countbased.wrappers.countbased import CountBasedExplorationWrapper


from IPython import embed
from distutils.util import strtobool

def _str2bool(value):
    """Convert a command-line truth string to ``bool``.

    Accepts the same spellings as ``distutils.util.strtobool``
    (case-insensitive): ``y/yes/t/true/on/1`` and ``n/no/f/false/off/0``.
    Implemented locally so that argument parsing does not depend on
    ``distutils``, which was removed from the standard library in Python 3.12.

    Raises:
        ValueError: If ``value`` is not one of the recognised spellings
            (argparse turns this into a regular usage error).
    """
    normalized = value.lower()
    if normalized in ("y", "yes", "t", "true", "on", "1"):
        return True
    if normalized in ("n", "no", "f", "false", "off", "0"):
        return False
    raise ValueError(f"invalid truth value {value!r}")


def parse_args(argv=None):
    """Build the experiment argument parser and parse the command line.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, which is the original behaviour.

    Returns:
        The ``argparse.Namespace`` holding every option defined below.
    """
    # fmt: off
    parser = argparse.ArgumentParser()

    # splitext removes only the extension; str.rstrip(".py") would strip any
    # trailing '.', 'p' or 'y' characters (e.g. "happy.py" -> "ha").
    script_name = os.path.splitext(os.path.basename(__file__))[0]

    def add_flag(name, help_text):
        # Boolean switch: a bare "--flag" means True, "--flag false" means False.
        parser.add_argument(name, type=_str2bool, default=False, nargs="?", const=True,
            help=help_text)

    parser.add_argument("--xpid", type=str, default=script_name,
        help="the name of this experiment")
    parser.add_argument("--savedir", type=str, default="runs")

    parser.add_argument("--seed", type=int, default=1,
        help="seed of the experiment")

    # Algorithm specific arguments
    parser.add_argument("--env-id", type=str, default="MazeEnv",
        help="the id of the environment")
    parser.add_argument("--maze-id", type=int, default=1,
        help="id of the maze layout (1-5)")
    add_flag("--proc_gen", "if toggled, use the procedural level generator")

    parser.add_argument("--total-timesteps", type=int, default=30000000,
        help="total timesteps of the experiments")
    parser.add_argument("--num-envs", type=int, default=1,
        help="number of environments in the vectorized training env")
    parser.add_argument("--gradient-steps", type=int, default=1,
        help="number of gradient steps per training update")
    parser.add_argument("--target-update-interval", type=int, default=10000,
        help="interval (in steps) between target network updates")
    parser.add_argument("--learning-rate", type=float, default=2.5e-4,
        help="the learning rate of the optimizer")
    parser.add_argument("--buffer-size", type=int, default=1000000,
        help="size of the replay buffer")
    parser.add_argument("--learning-starts", type=int, default=50000,
        help="number of steps to collect before learning starts")
    parser.add_argument("--batch-size", type=int, default=32,
        help="minibatch size")
    parser.add_argument("--train-freq", type=int, default=4,
        help="training frequency (in steps)")
    parser.add_argument("--exploration-fraction", type=float, default=0.65,
        help="fraction of training over which epsilon is annealed")
    parser.add_argument("--exploration-initial-eps", type=float, default=1.0,
        help="initial value of epsilon")
    parser.add_argument("--exploration-final-eps", type=float, default=0.05,
        help="final value of epsilon")
    parser.add_argument("--replay-buffer-class", type=str, default="ReplayBuffer",
        help="name of the replay buffer class")
    parser.add_argument("--policy", type=str, default="MlpPolicy",
        help="name of the policy class")

    parser.add_argument("--frame-stack", type=int, default=-1,
        help="number of frames to stack (values <= 0 disable frame stacking)")
    add_flag("--parallelize-cpu", "if toggled, use SubprocVecEnv instead of DummyVecEnv")
    add_flag("--augment", "forwarded to CountBasedExplorationWrapper as `augment`")
    add_flag("--episodic", "forwarded to CountBasedExplorationWrapper as `episodic`; "
                           "also selects episodic coverage for the reported mean")
    add_flag("--use_cnn", "forwarded to CountBasedExplorationWrapper as `use_cnn`")
    add_flag("--salesman", "forwarded to CountBasedExplorationWrapper as `salesman`")
    add_flag("--partial_obs", "if toggled, use the partially observable maze env")

    return parser.parse_args(argv)

if __name__ == "__main__":

    # --- Register the Griddly maze environment with gym ---
    from griddly import GymWrapperFactory, gd
    from gym.envs.registration import register as gym_register

    griddly_factory = GymWrapperFactory()
    maze_yaml = f"{os.getcwd()}/envs/maze.yaml"  # resolved against the current working directory
    griddly_factory.build_gym_from_yaml("_MazeEnv", maze_yaml)
    gym_register(id="GDY-MazeEnv-v0", entry_point="envs.maze:MazeEnv")

    # --- Command-line options ---
    args = parse_args()

    # NOTE(review): this unconditionally replaces whatever --total-timesteps
    # was given on the command line -- confirm the override is intended.
    args.total_timesteps = 100000

    # --- Export the interpreter's module search path as PYTHONPATH ---
    # Presumably so that child processes resolve the same modules -- confirm.
    import sys
    os.environ["PYTHONPATH"] = os.pathsep.join(sys.path)

    ####################### ENV #######################
    def make_env(env_id, augment):
        """Return a thunk that builds one count-based-exploration maze env.

        Args:
            env_id: Environment id; only ``"MazeEnv"`` is supported.
            augment: Forwarded to ``CountBasedExplorationWrapper`` as ``augment``.

        Returns:
            A zero-argument callable that constructs and returns the wrapped
            environment. All other settings (maze id, partial observability,
            wrapper switches) are read from the module-level ``args`` when the
            thunk is called.

        The thunk raises ``ValueError`` when ``env_id`` is not ``"MazeEnv"`` or
        when ``args.maze_id`` has no known grid size.
        """
        # (width, height) of the grid for every supported maze id.
        maze_sizes = {
            1: (32, 32),
            2: (32, 32),
            3: (32, 32),
            4: (20, 20),
            5: (64, 64),
        }

        def thunk():
            if env_id != "MazeEnv":
                # Raising a bare string is itself a TypeError in Python 3.
                raise ValueError("Only Maze Env is supported now for countbased experiments")
            if args.maze_id not in maze_sizes:
                raise ValueError(
                    f"Unsupported maze id {args.maze_id}; expected one of {sorted(maze_sizes)}"
                )

            # Imported lazily so the project modules are only loaded when an
            # environment is actually built.
            from countbased.envs.maze_proc_gen import ProcGenMazeEnv, ProcGenPOMazeEnv
            from countbased.envs.maze_level_generator import LabyrinthLevelGenerator

            w, h = maze_sizes[args.maze_id]
            config = {
                'width': w,
                'height': h,
                'wall_density': 0.8,
                'num_goals': 0
            }

            # the level generator will only be used if args.proc_gen = True
            level_generator = LabyrinthLevelGenerator(config)
            env_cls = ProcGenPOMazeEnv if args.partial_obs else ProcGenMazeEnv
            env = env_cls(level_generator, config, max_steps=1500)

            # NOTE(review): the file name says "32x32" for every maze id,
            # including the 20x20 and 64x64 ones -- confirm this is intended.
            with open(f"envs/maze_32x32_{args.maze_id}.txt", "r") as level_file:
                level_string = level_file.read()

            env = CountBasedExplorationWrapper(
                env,
                heatmap_shape=(config["width"], config["height"]),
                beta=1,
                add_true_rew=False,
                proc_gen=args.proc_gen,
                episodic=args.episodic,
                augment=augment,
                use_cnn=args.use_cnn,
                partial_obs=args.partial_obs,
                salesman=args.salesman,
                level_string=level_string
            )
            return env
        return thunk
    
    # Vectorized training environment: one thunk per requested env, run in
    # subprocesses when --parallelize-cpu is set, otherwise in-process.
    env_fns = [make_env(args.env_id, args.augment) for _ in range(args.num_envs)]
    vec_env = SubprocVecEnv(env_fns) if args.parallelize_cpu else DummyVecEnv(env_fns)

    # Frame stacking is applied only for a positive --frame-stack value.
    if args.frame_stack > 0:
        vec_env = VecFrameStack(vec_env, args.frame_stack)

    # Roll out uniformly sampled actions in a separate, single environment and
    # record the coverage reported in `info` at the end of every episode.
    env = make_env(args.env_id, args.augment)()
    num_episodes = 100
    coverage_key = "episodic_coverage" if args.episodic else "coverage"
    coverages = []
    for _ in range(num_episodes):
        obs = env.reset()
        steps = 0
        done = False
        while not done:
            steps += 1
            obs, reward, done, info = env.step(env.action_space.sample())
        # The loop body always runs at least once, so `info` belongs to the
        # terminal step of this episode.
        coverages.append(info[coverage_key])
        print("Episode finished after {} timesteps".format(steps), info["coverage"], info["episodic_coverage"])

    import numpy as np
    print(np.mean(coverages))
    