import os
#os.environ['OPENBLAS_NUM_THREADS'] = '1'
import time
import matplotlib.pyplot as plt
import numpy as np
import cv2
import imageio
import argparse
import os
from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv
from countbased.wrappers.goal_reaching_wrapper import GoalReachingWrapper
from countbased.wrappers.countbased import CountBasedExplorationWrapper

from IPython import embed
from distutils.util import strtobool

def parse_args():
    """Parse command-line options for the count-based maze evaluation script.

    Returns:
        argparse.Namespace with one attribute per option below; dashes in
        option names become underscores (e.g. ``--env-id`` -> ``env_id``).
    """

    def _str2bool(value):
        """Convert a truthy/falsy string to bool.

        Accepts the same vocabulary as ``distutils.util.strtobool``
        (case-insensitive). ``distutils`` is deprecated by PEP 632 and removed
        in Python 3.12, so the conversion is done locally.

        Raises:
            ValueError: for unrecognised strings (argparse reports this as an
                invalid option value).
        """
        value = value.lower()
        if value in ("y", "yes", "t", "true", "on", "1"):
            return True
        if value in ("n", "no", "f", "false", "off", "0"):
            return False
        raise ValueError(f"invalid truth value {value!r}")

    # fmt: off
    parser = argparse.ArgumentParser()
    # splitext removes the ".py" suffix; str.rstrip(".py") would strip any
    # trailing '.', 'p' or 'y' characters (e.g. "happy.py" -> "ha").
    parser.add_argument("--xpid", type=str, default=os.path.splitext(os.path.basename(__file__))[0],
        help="the name of this experiment")
    parser.add_argument("--savedir", type=str, default="runs")

    parser.add_argument("--seed", type=int, default=1,
        help="seed of the experiment")

    # Algorithm specific arguments
    parser.add_argument("--env-id", type=str, default="MazeEnv")
    parser.add_argument("--maze-id", type=int, default=1)

    parser.add_argument("--total-timesteps", type=int, default=15000000)
    parser.add_argument("--num-envs", type=int, default=1)
    parser.add_argument("--n-steps", type=int, default=2048)
    parser.add_argument("--learning-rate", type=float, default=0.0003)
    parser.add_argument("--ent-coef", type=float, default=0.0)
    parser.add_argument("--batch-size", type=int, default=2048)
    parser.add_argument("--n-epochs", type=int, default=10)
    parser.add_argument("--policy", type=str, default="MlpPolicy")

    # Boolean flags: a bare `--flag` means True (const); `--flag false` etc.
    # sets the value explicitly.
    parser.add_argument("--frame-stack", type=int, default=-1, help="")
    parser.add_argument("--parallelize-cpu", type=_str2bool, default=False, nargs="?", const=True, help="")
    parser.add_argument("--augment", type=_str2bool, default=False, nargs="?", const=True, help="")
    parser.add_argument("--episodic", type=_str2bool, default=False, nargs="?", const=True, help="")
    parser.add_argument("--use_cnn", type=_str2bool, default=True, nargs="?", const=True, help="")
    parser.add_argument("--salesman", type=_str2bool, default=False, nargs="?", const=True, help="")
    parser.add_argument("--proc_gen", type=_str2bool, default=False, nargs="?", const=True, help="")
    parser.add_argument("--partial_obs", type=_str2bool, default=False, nargs="?", const=True, help="")
    parser.add_argument("--goal_eval", type=_str2bool, default=True, nargs="?", const=True, help="")

    parser.add_argument("--e3b", type=_str2bool, default=False, nargs="?", const=True, help="")
    parser.add_argument("--e3b_hidden_dim", type=int, default=256)
    # float, not int: the default is 0.1 and an int parser rejects "0.5".
    parser.add_argument("--e3b_ridge", type=float, default=0.1)

    args = parser.parse_args()
    return args

if __name__ == "__main__":
    # Checkpoint evaluated by this script (path relative to the working directory).
    CHECKPOINT_PATH = "checkpoints/ppo_aug_episodic"
    # Output directory for the rendered rollout GIF.
    VIDEO_DIR = "videos"

    def maze_shape(maze_id):
        """Return the ``(width, height)`` grid size for evaluation maze ``maze_id``.

        Raises:
            ValueError: if no grid size is configured for ``maze_id``
                (previously this surfaced later as an UnboundLocalError).
        """
        if maze_id in (1, 2, 3):
            return 32, 32
        raise ValueError(f"Unsupported --maze-id {maze_id}; expected 1, 2 or 3")

    def make_env():
        """Return a zero-argument factory that builds one wrapped maze env.

        The factory is what DummyVecEnv / SubprocVecEnv expect. It reads the
        module-level ``args`` when called, so ``args`` must be parsed before
        the vectorised env is constructed.
        """
        def thunk():
            if args.env_id != "MazeEnv":
                # Raising a bare string is itself a TypeError in Python 3.
                raise ValueError("Only Maze Env is supported now for countbased experiments")

            ##### REGISTER MAZE ENV ####
            from countbased.envs.maze_proc_gen import ProcGenMazeEnv, ProcGenPOMazeEnv
            from countbased.envs.maze_level_generator import LabyrinthLevelGenerator

            w, h = maze_shape(args.maze_id)
            config = {
                'width': w,
                'height': h,
                'wall_density': 0.8,
                'num_goals': 0
            }

            # NOTE(review): args.proc_gen is not forwarded to the env here;
            # confirm whether the level generator is actually used.
            level_generator = LabyrinthLevelGenerator(config)
            env_cls = ProcGenPOMazeEnv if args.partial_obs else ProcGenMazeEnv
            env = env_cls(level_generator, config, max_steps=1500, pixel_obs=bool(args.e3b))

            # Fixed evaluation layout; `with` ensures the file handle is closed.
            with open(f"envs/eval_mazes/eval_maze_m2_{args.maze_id}.txt", "r") as maze_file:
                level_string = maze_file.read()

            if args.goal_eval:
                env = GoalReachingWrapper(
                    env,
                    heatmap_shape=(w, h),
                    level_string=level_string,
                    episodic=args.episodic,
                    salesman=args.salesman,
                    partial_obs=args.partial_obs,
                    use_cnn=args.use_cnn,
                )
            else:
                # NOTE(review): use_cnn is hard-coded to True here, while the
                # goal-eval branch forwards args.use_cnn; confirm intended.
                env = CountBasedExplorationWrapper(
                    env,
                    heatmap_shape=(w, h),
                    beta=1,
                    add_true_rew=False,
                    proc_gen=False,
                    episodic=args.episodic,
                    augment=args.augment,
                    use_cnn=True,
                    partial_obs=args.partial_obs,
                    salesman=args.salesman,
                    level_string=level_string
                )

            return env
        return thunk

    args = parse_args()

    #######################
    # Export the current sys.path as PYTHONPATH, presumably so SubprocVecEnv
    # worker processes resolve the same imports as this process.
    import sys
    os.environ['PYTHONPATH'] = os.pathsep.join(sys.path)
    #######################

    ####################### ENV #######################
    # Train Environment
    vec_env_cls = SubprocVecEnv if args.parallelize_cpu else DummyVecEnv
    vec_env = vec_env_cls([make_env() for _ in range(args.num_envs)])

    ####################### Policy #######################
    # Feature extractor matching the trained checkpoint; passed through
    # custom_objects so PPO.load uses these policy kwargs. --e3b takes
    # precedence over --use_cnn.
    if args.e3b:
        from countbased.models.models import CustomAugmentedExtractorNatureCNN
        policy_kwargs = dict(
            share_features_extractor=True,
            normalize_images=False,
            features_extractor_class=CustomAugmentedExtractorNatureCNN,
            features_extractor_kwargs=dict(hidden_dim=args.e3b_hidden_dim),
        )
    else:
        if args.use_cnn:
            from countbased.models.models import CustomAugmentedExtractorCNN as extractor_cls
        else:
            from countbased.models.models import CustomAugmentedExtractorMLP as extractor_cls
        policy_kwargs = dict(
            share_features_extractor=True,
            normalize_images=False,
            features_extractor_class=extractor_cls,
        )

    ####################### Rollout #######################
    model = PPO.load(CHECKPOINT_PATH, custom_objects={"policy_kwargs": policy_kwargs})

    obs = vec_env.reset()

    # Cells of the first sub-env visited during the episode (1 = visited).
    heatmap = np.zeros(maze_shape(args.maze_id))

    # Frames of each render() call, assembled into a GIF at the end.
    frames = []

    while True:
        action, _states = model.predict(obs, deterministic=False)
        obs, _rewards, dones, _infos = vec_env.step(action)

        # VecEnvs auto-reset a finished sub-env inside step(). Rendering or
        # reading its position after `done` would record the *next* episode's
        # start state, so stop before recording.
        if dones[0]:
            break

        frames.append(vec_env.render(mode="rgb_array"))
        vec_env.render("human")

        # get_attr works for DummyVecEnv and SubprocVecEnv alike;
        # `vec_env.envs` exists only on DummyVecEnv.
        x = vec_env.get_attr("x", indices=0)[0]
        y = vec_env.get_attr("y", indices=0)[0]
        heatmap[x, y] = 1

        # NOTE: disabled per-step heatmap-overlay visualisation (kept for debugging).
        # cmap = plt.get_cmap('Greens')
        # cmap.set_under((0,0,0,0))
        # cmap_args = dict(cmap=cmap, vmin=1)

        # fig = plt.figure()
        # background_img = vec_env.render(mode="rgb_array")

        # pixel_size = 32 * 20

        # if background_img.shape[0] > pixel_size:
        #     background_img = background_img[0:pixel_size, 0:pixel_size]

        # background = cv2.resize(background_img, dsize=(32, 32), interpolation=cv2.INTER_AREA)

        # plt.imshow(background, alpha=1)
        # plt.imshow(heatmap.T, **cmap_args, interpolation='nearest')
        # plt.xticks([])
        # plt.yticks([])
        # plt.colorbar()
        # plt.savefig("lol.png")
        # plt.close(fig)
        # plt.clf()

    # build video (mimsave does not create missing directories)
    os.makedirs(VIDEO_DIR, exist_ok=True)
    imageio.mimsave(os.path.join(VIDEO_DIR, f"{args.maze_id}.gif"), frames, fps=10)

    # Release render windows and, for SubprocVecEnv, the worker processes.
    vec_env.close()