import os
#os.environ['OPENBLAS_NUM_THREADS'] = '1'

import gym
import argparse
import numpy as np
import os
from stable_baselines3.common.vec_env import VecFrameStack
from stable_baselines3 import DDPG
from stable_baselines3.common.noise import NormalActionNoise, OrnsteinUhlenbeckActionNoise
from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv
from stable_baselines3.common.logger import configure

from IPython import embed
from distutils.util import strtobool

def parse_args(argv=None):
    """Parse the command-line options of this training script.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``, which is what the
            script entry point relies on.

    Returns:
        argparse.Namespace: the parsed options.
    """
    # fmt: off
    def str_to_bool(value):
        # Shared converter for the boolean flags; a bare "--flag" never goes
        # through it because argparse substitutes ``const`` directly.
        return bool(strtobool(value))

    # Default experiment id: the script's file name without its ".py" suffix.
    # str.rstrip(".py") is NOT a suffix removal -- it strips every trailing
    # '.', 'p' or 'y' character (e.g. "happy.py" -> "ha") -- so the suffix is
    # removed explicitly instead.
    script_name = os.path.basename(__file__)
    if script_name.endswith(".py"):
        script_name = script_name[:-len(".py")]

    parser = argparse.ArgumentParser()
    parser.add_argument("--xpid", type=str, default=script_name,
        help="the name of this experiment")
    parser.add_argument("--savedir", type=str, default="runs")

    parser.add_argument("--seed", type=int, default=1,
        help="seed of the experiment")

    # Algorithm specific arguments
    parser.add_argument("--env-id", type=str, default="Godot")
    parser.add_argument("--num-envs", type=int, default=4)
    parser.add_argument("--nb-agents", type=int, default=80)
    parser.add_argument("--port", type=int, default=0)

    parser.add_argument("--total-timesteps", type=int, default=50000000)

    parser.add_argument("--parallelize-cpu", type=str_to_bool, default=False, nargs="?", const=True,
        help="build the environments with SubprocVecEnv instead of DummyVecEnv")
    parser.add_argument("--augment", type=str_to_bool, default=False, nargs="?", const=True,
        help="use the 'countbased-augmented' agent type instead of 'countbased'")
    parser.add_argument("--episodic", type=str_to_bool, default=False, nargs="?", const=True,
        help="value forwarded (as an int) to the environment's 'episodic' option")

    return parser.parse_args(argv)

import socket
from contextlib import closing

def find_free_port():
    """Return a TCP port number that the OS reported as free.

    A temporary socket is bound to port 0 so the operating system picks the
    port; the socket is closed before returning, so the port is only known to
    have been free at the time of the call.
    """
    probe = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    with probe:
        probe.bind(('', 0))
        probe.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        _host, port = probe.getsockname()
        return port
    
def make_env(idx):
    """Build a zero-argument factory that creates one training environment.

    The factory reads its configuration from the module-level ``args``
    namespace at the time it is called.

    Args:
        idx: Index of the environment in the vectorised set (currently unused).

    Returns:
        A callable taking no arguments that constructs and returns the
        environment selected by ``args.env_id``.

    Raises:
        ValueError: raised by the returned callable when ``args.env_id`` is
            not a supported environment id.
    """
    def thunk():
        # Fail with an explicit message; previously an unknown id fell through
        # to ``return env`` and surfaced as an UnboundLocalError.
        if args.env_id != "Godot":
            raise ValueError(
                f"Unsupported env id {args.env_id!r}; only 'Godot' is available"
            )

        # Imported only when an environment is actually built.
        from countbased.envs.Godot.godot import GodotGymWrapper
        agent_type = "countbased-augmented" if args.augment else "countbased"

        env = GodotGymWrapper(
            env_path=f"{os.getcwd()}/envs/Godot/linux/Project88.x86_64",
            agent_type=agent_type,
            episodic=int(args.episodic),
            show_window=False,
            framerate=60,
            nb_agents=args.nb_agents,
            action_repeat=10,
            # Each environment gets its own OS-assigned port.
            port=find_free_port())
        return env
    return thunk

if __name__ == "__main__":
    # ``args`` stays a module-level name: the factories from make_env read it.
    args = parse_args()

    # Export the interpreter's current import search path through PYTHONPATH.
    import sys
    os.environ["PYTHONPATH"] = os.pathsep.join(sys.path)

    # ---------------- Environment ----------------
    # One factory per environment; the vectorised wrapper depends on the
    # --parallelize-cpu flag.
    env_fns = [make_env(env_idx) for env_idx in range(args.num_envs)]
    if not args.parallelize_cpu:
        # Project-specific DummyVecEnv (takes nb_agents); rebinds the name
        # imported from stable_baselines3 at the top of the file.
        from countbased.wrappers.dummy_vec_env import DummyVecEnv
        vec_env = DummyVecEnv(env_fns, nb_agents=args.nb_agents)
    else:
        vec_env = SubprocVecEnv(env_fns)

    # ---------------- Logging ----------------
    log_dir = f"godot_runs/{args.savedir}/{args.xpid}"
    run_logger = configure(log_dir, ["csv", "tensorboard"])

    # ---------------- Exploration noise for DDPG ----------------
    action_dim = vec_env.action_space.shape[-1]
    action_noise = NormalActionNoise(
        mean=np.zeros(action_dim),
        sigma=0.1 * np.ones(action_dim),
    )

    # ---------------- Callbacks ----------------
    from countbased.callbacks.callbacks import LogRewardsGodotCallback
    reward_logger = LogRewardsGodotCallback(
        log_every=10,
        num_envs=args.num_envs * args.nb_agents,
    )

    # ---------------- Model ----------------
    from countbased.models.models import CustomAugmentedExtractorCNNGodot
    policy_kwargs = {
        "normalize_images": False,
        "features_extractor_class": CustomAugmentedExtractorCNNGodot,
    }

    model = DDPG(
        "MultiInputPolicy",
        vec_env,
        action_noise=action_noise,
        verbose=1,
        seed=args.seed,
        train_freq=(4, "step"),
        gradient_steps=1,
        batch_size=256,
        policy_kwargs=policy_kwargs,
    )
    model.set_logger(run_logger)

    print(model.policy)

    # ---------------- Training ----------------
    model.learn(
        total_timesteps=args.total_timesteps,
        log_interval=1,
        progress_bar=True,
        callback=[reward_logger],
    )

    # The saved model is named after the experiment id.
    model.save(f"{args.xpid}")