import os
#os.environ['OPENBLAS_NUM_THREADS'] = '1'

import gym
import argparse
import os
from stable_baselines3.common.callbacks import EvalCallback
from stable_baselines3.common.vec_env import VecFrameStack
from stable_baselines3 import DQN
from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv
from stable_baselines3.common.logger import configure

from IPython import embed
from distutils.util import strtobool
from countbased.wrappers.countbased import CountBasedExplorationWrapper

def _str2bool(value):
    """Parse a CLI truth string with the same rules as ``distutils.util.strtobool``.

    Case-insensitively accepts y/yes/t/true/on/1 (-> True) and
    n/no/f/false/off/0 (-> False). Any other string raises ``ValueError``,
    which argparse turns into an "invalid value" usage error. Defined locally
    because ``distutils`` was removed from the standard library in Python 3.12.
    """
    lowered = value.lower()
    if lowered in ("y", "yes", "t", "true", "on", "1"):
        return True
    if lowered in ("n", "no", "f", "false", "off", "0"):
        return False
    raise ValueError(f"invalid truth value {value!r}")


def parse_args():
    """Parse command-line arguments for the count-based DQN maze experiment.

    Boolean flags accept an explicit value (``--flag false``) or act as a
    switch when given bare (``--flag`` -> True).

    After parsing, two presets are applied:
      * ``--no_intr_rew`` forces ``add_true_rew=False``, ``maze_id=2`` and
        ``augment=False``.
      * ``--add_true_rew`` (only if ``--no_intr_rew`` was not also given)
        forces ``beta=0.1``, ``maze_id=2`` and ``no_intr_rew=False``.

    Returns:
        argparse.Namespace with the parsed and post-processed options.
    """
    # fmt: off
    parser = argparse.ArgumentParser()
    # splitext removes exactly the extension; str.rstrip(".py") would strip
    # any trailing '.', 'p' or 'y' characters (e.g. "happy.py" -> "ha").
    parser.add_argument("--xpid", type=str,
        default=os.path.splitext(os.path.basename(__file__))[0],
        help="the name of this experiment")
    parser.add_argument("--savedir", type=str, default="runs",
        help="sub-directory of the output root used for logs")

    parser.add_argument("--seed", type=int, default=1,
        help="seed of the experiment")

    # Shared settings for boolean flags: bare "--flag" means True.
    bool_flag = dict(type=_str2bool, default=False, nargs="?", const=True)

    # Environment arguments
    parser.add_argument("--env-id", type=str, default="MazeEnv",
        help="the id of the environment")
    parser.add_argument("--maze-id", type=int, default=1,
        help="which maze layout / size to use")
    parser.add_argument("--proc_gen", help="use procedurally generated maze levels", **bool_flag)

    # Algorithm specific arguments
    parser.add_argument("--total-timesteps", type=int, default=15000000,
        help="total timesteps of the experiments")
    parser.add_argument("--num-envs", type=int, default=16,
        help="number of parallel environments")
    parser.add_argument("--gradient-steps", type=int, default=4,
        help="gradient steps performed per training update")
    parser.add_argument("--target-update-interval", type=int, default=10000,
        help="environment steps between target-network updates")
    parser.add_argument("--learning-rate", type=float, default=2.5e-4,
        help="the learning rate of the optimizer")
    parser.add_argument("--buffer-size", type=int, default=1000000,
        help="replay buffer capacity (transitions)")
    parser.add_argument("--learning-starts", type=int, default=50000,
        help="environment steps collected before learning starts")
    parser.add_argument("--batch-size", type=int, default=64,
        help="minibatch size sampled from the replay buffer")
    parser.add_argument("--train-freq", type=int, default=4,
        help="environment steps between training updates")
    parser.add_argument("--exploration-fraction", type=float, default=0.1,
        help="fraction of training over which epsilon is annealed")
    parser.add_argument("--exploration-initial-eps", type=float, default=1.0,
        help="initial epsilon for epsilon-greedy exploration")
    parser.add_argument("--exploration-final-eps", type=float, default=0.05,
        help="final epsilon for epsilon-greedy exploration")
    parser.add_argument("--replay-buffer-class", type=str, default="ReplayBuffer",
        help="name of the replay buffer class")
    parser.add_argument("--policy", type=str, default="MlpPolicy",
        help="policy class name")

    parser.add_argument("--frame-stack", type=int, default=-1, help="number of frames to stack")
    parser.add_argument("--parallelize-cpu", help="run environments in subprocesses", **bool_flag)
    parser.add_argument("--augment", help="augmented observations for the exploration wrapper", **bool_flag)
    parser.add_argument("--episodic", help="episodic mode for the exploration wrapper", **bool_flag)
    parser.add_argument("--use_cnn", help="use the CNN feature extractor instead of the MLP", **bool_flag)
    parser.add_argument("--salesman", help="salesman mode for the exploration wrapper", **bool_flag)
    parser.add_argument("--partial_obs", help="use the partially observable maze", **bool_flag)
    parser.add_argument("--add_true_rew", help="add the environment reward (sets beta=0.1, maze_id=2)", **bool_flag)
    parser.add_argument("--no_intr_rew", help="disable the intrinsic reward (sets maze_id=2)", **bool_flag)
    # float so fractional coefficients can be passed; default kept as before.
    parser.add_argument("--beta", type=float, default=1,
        help="beta coefficient for the count-based exploration wrapper")
    # fmt: on

    args = parser.parse_args()

    # --no_intr_rew takes precedence: it clears add_true_rew, so the
    # add_true_rew preset below is skipped when both flags are given.
    if args.no_intr_rew:
        args.add_true_rew = False
        args.maze_id = 2
        args.augment = False

    if args.add_true_rew:
        args.beta = 0.1
        args.maze_id = 2
        args.no_intr_rew = False

    return args
if __name__ == "__main__":

    ##### REGISTER MAZE ENV ####
    # Build the Griddly env from its YAML description and register the custom
    # MazeEnv entry point. Paths are resolved against the current working
    # directory, so the script is expected to run from the directory that
    # contains `envs/`.
    from griddly import GymWrapperFactory, gd
    from gym.envs.registration import register as gym_register
    wrapper = GymWrapperFactory()
    wrapper.build_gym_from_yaml('_MazeEnv', f"{os.getcwd()}/envs/maze.yaml")
    gym_register(
        id='GDY-MazeEnv-v0',
        entry_point='envs.maze:MazeEnv'
    )

    args = parse_args()

    #######################
    # Export the current import path through PYTHONPATH, presumably so that
    # spawned worker processes (e.g. SubprocVecEnv) can import the same local
    # packages.
    import sys
    os.environ['PYTHONPATH'] = os.pathsep.join(sys.path)
    #######################

    ####################### ENV #######################
    # Grid (width, height) for every supported --maze-id.
    MAZE_SIZES = {
        1: (32, 32),
        2: (32, 32),
        3: (32, 32),
        4: (20, 20),
        5: (64, 64),
    }

    # Validate up front. Previously this raised a bare string (a TypeError in
    # Python 3) or left w/h unbound; now bad flags fail with a clear message
    # before any environment is built.
    if args.env_id != "MazeEnv":
        raise ValueError("Only Maze Env is supported now for countbased experiments")
    if args.maze_id not in MAZE_SIZES:
        raise ValueError(
            f"Unsupported --maze-id {args.maze_id}; expected one of {sorted(MAZE_SIZES)}"
        )

    def make_env():
        """Return a zero-argument thunk that builds one wrapped maze env."""
        def thunk():
            # Deferred imports: they run wherever the thunk is executed.
            from countbased.envs.maze_proc_gen import ProcGenMazeEnv, ProcGenPOMazeEnv
            from countbased.envs.maze_level_generator import LabyrinthLevelGenerator

            width, height = MAZE_SIZES[args.maze_id]
            config = {
                'width': width,
                'height': height,
                'wall_density': 0.8,
                'num_goals': 0
            }

            # the level generator will only be used if args.proc_gen = True
            level_generator = LabyrinthLevelGenerator(config)

            # 1500-step episodes for the episodic and extrinsic-reward
            # settings, 250 steps otherwise.
            if args.episodic or args.add_true_rew or args.no_intr_rew:
                max_steps = 1500
            else:
                max_steps = 250

            env_cls = ProcGenPOMazeEnv if args.partial_obs else ProcGenMazeEnv
            env = env_cls(level_generator, config, max_steps=max_steps)

            # NOTE(review): the layout file name is always "32x32", even for
            # maze ids 4 (20x20) and 5 (64x64) — confirm those files exist.
            with open(f"envs/maze_32x32_{args.maze_id}.txt", "r") as level_file:
                level_string = level_file.read()

            env = CountBasedExplorationWrapper(
                env,
                heatmap_shape=(config["width"], config["height"]),
                no_intr_rew=args.no_intr_rew,
                beta=args.beta,
                add_true_rew=args.add_true_rew,
                proc_gen=args.proc_gen,
                episodic=args.episodic,
                augment=args.augment,
                use_cnn=args.use_cnn,
                partial_obs=args.partial_obs,
                salesman=args.salesman,
                level_string=level_string
            )
            return env
        return thunk

    # Train Environment
    vec_env_cls = SubprocVecEnv if args.parallelize_cpu else DummyVecEnv
    vec_env = vec_env_cls([make_env() for _ in range(args.num_envs)])

    ####################### #######################

    ### CALLBACKS
    from countbased.callbacks.callbacks import LogCoverageCallback
    # NOTE(review): map_size is hard-coded to 32 although maze ids 4/5 use
    # 20x20 / 64x64 grids — confirm what LogCoverageCallback expects.
    log_coverage_callback = LogCoverageCallback(log_heatmap_every=25000, map_size=32)

    tmp_path = f"mazes_with_goals_dqn/{args.savedir}/{args.xpid}"

    # set up logger
    new_logger = configure(tmp_path, ["csv", "tensorboard"])

    ####################### Policy #######################
    if args.use_cnn:
        from countbased.models.models import CustomAugmentedExtractorCNN as FeaturesExtractor
    else:
        from countbased.models.models import CustomAugmentedExtractorMLP as FeaturesExtractor
    # SB3's MultiInputPolicy (for Dict observation spaces) is used in both
    # cases; the --policy flag is not consulted.
    policy = "MultiInputPolicy"
    policy_kwargs = dict(
        normalize_images=False,
        features_extractor_class=FeaturesExtractor,
    )
    ####################### #######################

    # NOTE(review): --learning-rate, --buffer-size, --learning-starts,
    # --batch-size, --target-update-interval, --exploration-initial-eps,
    # --exploration-final-eps, --replay-buffer-class, --policy and
    # --frame-stack are parsed but NOT forwarded here, so SB3's DQN defaults
    # apply. Wiring them in would change the default run's hyperparameters.
    # TODO: decide and forward them explicitly.
    model = DQN(policy,
                vec_env,
                train_freq=args.train_freq,
                gradient_steps=args.gradient_steps,
                replay_buffer_kwargs={"handle_timeout_termination": False},
                optimize_memory_usage=False,
                verbose=1,
                seed=args.seed,
                exploration_fraction=args.exploration_fraction,
                policy_kwargs=policy_kwargs
            )

    model.set_logger(new_logger)

    print(model.policy)
    model.learn(
        total_timesteps=args.total_timesteps,
        log_interval=8,
        progress_bar=True,
        callback=[log_coverage_callback]
    )
    #model.save(f"{args.xpid}")