import os
#os.environ['OPENBLAS_NUM_THREADS'] = '1'

import gym
import argparse
import os
from sb3_contrib import RecurrentPPO
from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv
from stable_baselines3.common.logger import configure

from IPython import embed
from distutils.util import strtobool
from countbased.wrappers.countbased import CountBasedExplorationWrapper
from countbased.wrappers.e3b import E3BWrapper

def _str2bool(value):
    """Parse a command-line boolean spelling into a ``bool``.

    Accepts the same spellings as ``distutils.util.strtobool`` (distutils is
    deprecated and removed in Python 3.12), case-insensitively.

    Raises:
        ValueError: if *value* is not a recognised boolean spelling; argparse
            reports this as a usage error.
    """
    lowered = value.lower()
    if lowered in ("y", "yes", "t", "true", "on", "1"):
        return True
    if lowered in ("n", "no", "f", "false", "off", "0"):
        return False
    raise ValueError(f"invalid truth value {value!r}")


def _add_bool_flag(parser, name, help_text=""):
    """Register an optional boolean flag on *parser*.

    The flag defaults to False. A bare ``--flag`` means True, and
    ``--flag <value>`` is parsed with ``_str2bool``.
    """
    parser.add_argument(name, type=_str2bool, default=False, nargs="?", const=True, help=help_text)


def parse_args(argv=None):
    """Parse the command line for the count-based / E3B RecurrentPPO experiments.

    Args:
        argv: Optional list of argument strings. When None (the default),
            argparse reads ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with the parsed options. Some options are then
        overridden by preset combinations:
        - ``--e3b`` forces the E3B hyper-parameters and ``BiomesEnv``.
        - ``--no_intr_rew`` forces ``add_true_rew=False``, ``maze_id=2`` and
          ``augment=False``.
        - ``--add_true_rew`` (if still set) forces ``beta=0.1`` and ``maze_id=2``.
    """
    # fmt: off
    parser = argparse.ArgumentParser()
    # splitext removes the ".py" suffix; str.rstrip(".py") would strip any trailing '.', 'p', 'y' characters.
    parser.add_argument("--xpid", type=str, default=os.path.splitext(os.path.basename(__file__))[0],
        help="the name of this experiment")
    parser.add_argument("--savedir", type=str, default="runs")

    parser.add_argument("--seed", type=int, default=1,
        help="seed of the experiment")

    # Algorithm specific arguments
    parser.add_argument("--env-id", type=str, default="MazeEnv")
    parser.add_argument("--maze-id", type=int, default=1)
    _add_bool_flag(parser, "--proc_gen")

    parser.add_argument("--total-timesteps", type=int, default=30000000)
    parser.add_argument("--num-envs", type=int, default=16)
    parser.add_argument("--n-steps", type=int, default=2048)
    parser.add_argument("--learning-rate", type=float, default=0.0003)
    parser.add_argument("--ent-coef", type=float, default=0.0)
    parser.add_argument("--batch-size", type=int, default=2048)
    parser.add_argument("--n-epochs", type=int, default=10)
    parser.add_argument("--policy", type=str, default="MlpPolicy")

    parser.add_argument("--frame-stack", type=int, default=-1, help="")
    for flag in ("--parallelize-cpu", "--augment", "--episodic", "--use_cnn", "--salesman", "--partial_obs"):
        _add_bool_flag(parser, flag)

    _add_bool_flag(parser, "--add_true_rew")
    _add_bool_flag(parser, "--no_intr_rew")
    # float, not int: the --add_true_rew preset below sets beta to 0.1.
    parser.add_argument("--beta", type=float, default=1.0)

    _add_bool_flag(parser, "--e3b")
    parser.add_argument("--e3b_hidden_dim", type=int, default=256)
    # float, not int: the ridge coefficient defaults to 0.1 and must accept fractional values.
    parser.add_argument("--e3b_ridge", type=float, default=0.1)
    # fmt: on

    args = parser.parse_args(argv)

    if args.e3b:
        # E3B preset: fixed PPO hyper-parameters and environment.
        args.batch_size = 512
        args.n_epochs = 4
        args.n_steps = 256
        args.num_envs = 16
        args.total_timesteps = 28000000
        args.env_id = "BiomesEnv"

    if args.no_intr_rew:
        args.add_true_rew = False
        args.maze_id = 2
        args.augment = False

    if args.add_true_rew:
        args.beta = 0.1
        args.maze_id = 2
        args.no_intr_rew = False

    return args

if __name__ == "__main__":
    args = parse_args()

    def make_env():
        """Return a thunk that builds one wrapped environment from ``args``.

        The vectorised-env classes below call each thunk to construct one
        environment instance.
        """

        def thunk():
            if args.env_id not in ("MazeEnv", "BiomesEnv"):
                # Raising a plain string is itself a TypeError in Python 3; raise a real exception.
                raise ValueError(
                    f"Unsupported --env-id {args.env_id!r}: only MazeEnv and "
                    "BiomesEnv are supported for countbased experiments"
                )

            if args.env_id == "MazeEnv":
                ##### REGISTER MAZE ENV ####
                from countbased.envs.maze_proc_gen import ProcGenMazeEnv, ProcGenPOMazeEnv
                from countbased.envs.maze_level_generator import LabyrinthLevelGenerator

                # Grid (width, height) for each supported --maze-id.
                maze_sizes = {1: (32, 32), 2: (32, 32), 3: (32, 32), 4: (20, 20), 5: (64, 64)}
                try:
                    w, h = maze_sizes[args.maze_id]
                except KeyError:
                    raise ValueError(
                        f"Unsupported --maze-id {args.maze_id}; expected one of {sorted(maze_sizes)}"
                    ) from None

                config = {
                    'width': w,
                    'height': h,
                    'wall_density': 0.8,
                    'num_goals': 0
                }

                # the level generator will only be used if args.proc_gen = True
                level_generator = LabyrinthLevelGenerator(config)

                # 1500-step episodes with --episodic, --add_true_rew or --no_intr_rew; otherwise 250.
                if args.episodic or args.add_true_rew or args.no_intr_rew:
                    max_steps = 1500
                else:
                    max_steps = 250

                maze_cls = ProcGenPOMazeEnv if args.partial_obs else ProcGenMazeEnv
                env = maze_cls(level_generator, config, max_steps=max_steps, pixel_obs=bool(args.e3b))

                # NOTE(review): the file name is always "maze_32x32_*" even for mazes 4 (20x20) and 5 (64x64) — confirm.
                with open(f"envs/maze_32x32_{args.maze_id}.txt", "r") as level_file:
                    level_string = level_file.read()
            else:
                from countbased.envs.biomes import BiomesEnv
                env = BiomesEnv(max_steps=10)
                level_string = ""

            if not args.e3b:
                # The count-based wrapper needs the maze config for its heatmap shape,
                # which only exists for MazeEnv.
                if args.env_id != "MazeEnv":
                    raise ValueError(
                        "The count-based wrapper requires --env-id MazeEnv; use --e3b with BiomesEnv"
                    )
                env = CountBasedExplorationWrapper(
                    env,
                    heatmap_shape=(config["width"], config["height"]),
                    no_intr_rew=args.no_intr_rew,
                    beta=args.beta,
                    add_true_rew=args.add_true_rew,
                    proc_gen=args.proc_gen,
                    episodic=args.episodic,
                    augment=args.augment,
                    use_cnn=args.use_cnn,
                    partial_obs=args.partial_obs,
                    salesman=args.salesman,
                    level_string=level_string
                )
            else:
                env = E3BWrapper(
                    env,
                    hidden_dim=args.e3b_hidden_dim,
                    ridge=args.e3b_ridge,
                    add_true_rew=False,
                    proc_gen=args.proc_gen,
                    episodic=args.episodic,
                    augment=args.augment,
                    level_string=level_string,
                    is_biomes=args.env_id == "BiomesEnv"
                )

            return env

        return thunk

    #######################
    # Export the current sys.path as PYTHONPATH; presumably so SubprocVecEnv
    # worker processes resolve the same imports — confirm.
    import sys
    os.environ['PYTHONPATH'] = os.pathsep.join(sys.path)
    #######################

    ####################### ENV #######################
    # Train Environment
    vec_env_cls = SubprocVecEnv if args.parallelize_cpu else DummyVecEnv
    vec_env = vec_env_cls([make_env() for _ in range(args.num_envs)])

    ### CALLBACKS
    from countbased.callbacks.callbacks import LogCoverageCallback
    log_heatmap_every = -1 if args.e3b else 1500
    # NOTE(review): map_size is 32 only for maze_id 2, although mazes 1 and 3 are also 32x32 — confirm intended.
    log_coverage_callback = LogCoverageCallback(
        log_heatmap_every=log_heatmap_every,
        map_size=32 if args.maze_id == 2 else 64,
    )
    callback = [log_coverage_callback]

    if args.e3b:
        from countbased.callbacks.train_e3b import TrainE3BCallback
        callback.append(TrainE3BCallback(hidden_dim=args.e3b_hidden_dim))

    tmp_path = f"mazes_with_goals_lstm/{args.savedir}/{args.xpid}"

    # set up logger
    new_logger = configure(tmp_path, ["csv", "tensorboard"])

    ####################### Policy #######################
    # NOTE(review): --policy is parsed but ignored; every configuration uses the recurrent multi-input policy.
    policy = "MultiInputLstmPolicy"
    if args.e3b:
        from countbased.models.models import CustomAugmentedExtractorNatureCNN
        policy_kwargs = dict(
            share_features_extractor=True,
            normalize_images=False,
            features_extractor_class=CustomAugmentedExtractorNatureCNN,
            features_extractor_kwargs=dict(hidden_dim=args.e3b_hidden_dim),
        )
    elif args.use_cnn:
        from countbased.models.models import CustomAugmentedExtractorCNN
        policy_kwargs = dict(
            share_features_extractor=True,
            normalize_images=False,
            features_extractor_class=CustomAugmentedExtractorCNN,
        )
    else:
        from countbased.models.models import CustomAugmentedExtractorMLP
        policy_kwargs = dict(
            share_features_extractor=True,
            normalize_images=False,
            features_extractor_class=CustomAugmentedExtractorMLP,
        )
    ####################### #######################

    # NOTE(review): --seed is parsed but never passed to the model or envs.
    model = RecurrentPPO(
        policy,
        vec_env,
        batch_size=args.batch_size,
        n_steps=args.n_steps,
        n_epochs=args.n_epochs,
        ent_coef=args.ent_coef,
        learning_rate=args.learning_rate,
        policy_kwargs=policy_kwargs,
        verbose=1,
    )

    model.set_logger(new_logger)

    print(model.policy)
    model.learn(
        total_timesteps=args.total_timesteps,
        log_interval=1,
        progress_bar=True,
        callback=callback
    )

    #model.save(f"{args.xpid}")