import os
#os.environ['OPENBLAS_NUM_THREADS'] = '1'

import gym
import argparse
import os
from stable_baselines3.common.callbacks import EvalCallback
from stable_baselines3.common.vec_env import VecFrameStack
from stable_baselines3 import A2C
from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv

from IPython import embed
from distutils.util import strtobool
from stable_baselines3.common.logger import configure
from stable_baselines3.common.sb2_compat.rmsprop_tf_like import RMSpropTFLike
from countbased.wrappers.countbased import CountBasedExplorationWrapper

def parse_args(argv=None):
    """Parse the command-line options of the count-based A2C maze experiment.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) lets
            argparse read ``sys.argv[1:]``, which is what the script does.

    Returns:
        argparse.Namespace: the parsed options after the reward-mode
        overrides at the end of this function have been applied.
    """
    # fmt: off
    parser = argparse.ArgumentParser()
    # splitext removes only the extension; str.rstrip(".py") would strip any
    # trailing '.', 'p' or 'y' characters (e.g. "happy.py" -> "ha").
    parser.add_argument("--xpid", type=str,
        default=os.path.splitext(os.path.basename(__file__))[0],
        help="the name of this experiment")
    parser.add_argument("--savedir", type=str, default="runs")

    parser.add_argument("--seed", type=int, default=1,
        help="seed of the experiment")

    def add_bool_flag(name):
        # `--flag` alone means True; `--flag <value>` is parsed by strtobool.
        parser.add_argument(name, type=lambda x: bool(strtobool(x)),
            default=False, nargs="?", const=True, help="")

    # Algorithm specific arguments
    parser.add_argument("--env-id", type=str, default="MazeEnv")
    parser.add_argument("--maze-id", type=int, default=1)
    add_bool_flag("--proc_gen")

    parser.add_argument("--total-timesteps", type=int, default=20000000)
    parser.add_argument("--num-envs", type=int, default=16)
    parser.add_argument("--n-steps", type=int, default=5)
    parser.add_argument("--learning-rate", type=float, default=0.0007)
    parser.add_argument("--policy", type=str, default="MlpPolicy")

    parser.add_argument("--frame-stack", type=int, default=-1, help="")
    for flag in ("--parallelize-cpu", "--augment", "--episodic", "--use_cnn",
                 "--salesman", "--partial_obs", "--add_true_rew", "--no_intr_rew"):
        add_bool_flag(flag)
    # float, not int: the add_true_rew preset below assigns 0.1, so fractional
    # values must also be accepted from the command line.
    parser.add_argument("--beta", type=float, default=1.0)
    # fmt: on

    args = parser.parse_args(argv)

    # Reward-mode presets. --no_intr_rew takes precedence over --add_true_rew
    # when both are given; both presets force maze 2.
    if args.no_intr_rew:
        args.add_true_rew = False
        args.maze_id = 2
        args.augment = False
    elif args.add_true_rew:
        args.beta = 0.1
        args.maze_id = 2

    return args

def make_env():
    """Return a zero-argument factory that builds one wrapped maze environment.

    The factory reads the module-level ``args`` namespace when it is called,
    so ``args`` must be defined before the returned thunk runs.

    Returns:
        callable: a thunk that creates a ``ProcGenMazeEnv`` (or
        ``ProcGenPOMazeEnv`` when ``args.partial_obs`` is set) wrapped in
        ``CountBasedExplorationWrapper``.

    The thunk raises ``ValueError`` for an unsupported ``args.env_id`` or
    ``args.maze_id``.
    """
    # (width, height) of the grid for every supported maze id.
    maze_sizes = {
        1: (32, 32),
        2: (32, 32),
        3: (32, 32),
        4: (20, 20),
        5: (64, 64),
    }

    def thunk():
        if args.env_id != "MazeEnv":
            # Raising a bare string is itself a TypeError in Python 3.
            raise ValueError("Only Maze Env is supported now for countbased experiments")
        if args.maze_id not in maze_sizes:
            raise ValueError(
                f"Unsupported maze id {args.maze_id!r}; expected one of {sorted(maze_sizes)}"
            )

        ##### REGISTER MAZE ENV ####
        from countbased.envs.maze_proc_gen import ProcGenMazeEnv, ProcGenPOMazeEnv
        from countbased.envs.maze_level_generator import LabyrinthLevelGenerator

        width, height = maze_sizes[args.maze_id]
        config = {
            'width': width,
            'height': height,
            'wall_density': 0.8,
            'num_goals': 0
        }

        # the level generator will only be used if args.proc_gen = True
        level_generator = LabyrinthLevelGenerator(config)

        # Episodic runs and both true-reward modes use the long horizon.
        long_horizon = args.episodic or args.add_true_rew or args.no_intr_rew
        max_steps = 1500 if long_horizon else 250

        env_cls = ProcGenPOMazeEnv if args.partial_obs else ProcGenMazeEnv
        env = env_cls(level_generator, config, max_steps=max_steps)

        # NOTE(review): the file name says 32x32 for every maze id, including
        # the 20x20 and 64x64 mazes - confirm this is intended.
        with open(f"envs/maze_32x32_{args.maze_id}.txt", "r") as level_file:
            level_string = level_file.read()

        env = CountBasedExplorationWrapper(
            env,
            heatmap_shape=(config["width"], config["height"]),
            no_intr_rew=args.no_intr_rew,
            beta=args.beta,
            add_true_rew=args.add_true_rew,
            proc_gen=args.proc_gen,
            episodic=args.episodic,
            augment=args.augment,
            use_cnn=args.use_cnn,
            partial_obs=args.partial_obs,
            salesman=args.salesman,
            level_string=level_string
        )
        return env
    return thunk

if __name__ == "__main__":
    # Script entry point: build the vectorized maze envs, train A2C on them
    # and log CSV/TensorBoard metrics plus coverage heatmaps.
    # `args` must stay a module-level name: make_env() reads it as a global.
    args = parse_args()

    #######################
    # Export the current module search path through PYTHONPATH
    # (presumably so SubprocVecEnv workers resolve the same imports - confirm).
    import sys
    os.environ['PYTHONPATH'] = os.pathsep.join(sys.path)
    #######################

    ####################### ENV #######################
    # Train Environment: one process per env when --parallelize-cpu is set,
    # otherwise all envs are stepped in this process.
    vec_env_cls = SubprocVecEnv if args.parallelize_cpu else DummyVecEnv
    vec_env = vec_env_cls([make_env() for _ in range(args.num_envs)])
    ####################### #######################

    try:
        ### CALLBACKS
        from countbased.callbacks.callbacks import LogCoverageCallback
        # NOTE(review): map_size is fixed at 32 although maze ids 4 and 5 use
        # 20x20 and 64x64 grids - confirm against LogCoverageCallback.
        log_coverage_callback = LogCoverageCallback(log_heatmap_every=1500, map_size=32)

        tmp_path = f"mazes_with_goals_a2c/{args.savedir}/{args.xpid}"
        # set up logger
        new_logger = configure(tmp_path, ["csv", "tensorboard"])

        ####################### Policy #######################
        # Both variants use MultiInputPolicy and differ only in the features
        # extractor; the --policy option is not consulted here.
        if args.use_cnn:
            from countbased.models.models import CustomAugmentedExtractorCNN as extractor_cls
        else:
            from countbased.models.models import CustomAugmentedExtractorMLP as extractor_cls
        policy = "MultiInputPolicy"
        policy_kwargs = dict(
            share_features_extractor=True,
            normalize_images=False,
            features_extractor_class=extractor_cls,
        )
        ####################### #######################

        # Forward the parsed hyper-parameters; they were previously accepted
        # on the command line but silently ignored. Their defaults equal the
        # A2C defaults, so default runs are unchanged.
        model = A2C(
            policy,
            vec_env,
            learning_rate=args.learning_rate,
            n_steps=args.n_steps,
            verbose=1,
            seed=args.seed,
            policy_kwargs=policy_kwargs,
        )

        model.set_logger(new_logger)

        print(model.policy)
        model.learn(
            total_timesteps=args.total_timesteps,
            log_interval=50,
            progress_bar=True,
            callback=[log_coverage_callback]
        )

        #model.save(f"{args.xpid}")
    finally:
        # Always release the envs (and SubprocVecEnv worker processes),
        # including when training is interrupted or fails.
        vec_env.close()