import os
#os.environ['OPENBLAS_NUM_THREADS'] = '1'

import argparse
import os
from stable_baselines3 import SAC
from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv
from stable_baselines3.common.logger import configure

from IPython import embed
from distutils.util import strtobool

def _str2bool(value):
    """Parse a command-line truth value into a ``bool``.

    Accepts the same spellings as ``distutils.util.strtobool`` (case-insensitive
    ``y/yes/t/true/on/1`` and ``n/no/f/false/off/0``) without depending on
    distutils, which was removed from the standard library in Python 3.12.

    Raises:
        ValueError: for any other string; argparse reports this as an invalid
            argument value and exits.
    """
    lowered = value.lower()
    if lowered in ("y", "yes", "t", "true", "on", "1"):
        return True
    if lowered in ("n", "no", "f", "false", "off", "0"):
        return False
    raise ValueError(f"invalid truth value {value!r}")


def parse_args(argv=None):
    """Build the experiment's command-line parser and parse the arguments.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, as before.

    Returns:
        argparse.Namespace with the experiment, environment and SAC settings.
        Boolean flags accept ``--flag`` (meaning True) or ``--flag <bool>``.
    """
    # fmt: off
    parser = argparse.ArgumentParser()
    # Script file name without its ".py" extension. splitext removes the suffix,
    # whereas rstrip(".py") would also eat trailing 'p'/'y'/'.' characters.
    parser.add_argument("--xpid", type=str, default=os.path.splitext(os.path.basename(__file__))[0],
        help="the name of this experiment")
    parser.add_argument("--savedir", type=str, default="godot_runs")

    parser.add_argument("--seed", type=int, default=1,
        help="seed of the experiment")

    # Algorithm specific arguments
    parser.add_argument("--env-id", type=str, default="Godot")
    parser.add_argument("--nb-agents", type=int, default=80)
    parser.add_argument("--proc_gen", type=_str2bool, default=False, nargs="?", const=True, help="")

    parser.add_argument("--total-timesteps", type=int, default=50000000)
    parser.add_argument("--num-envs", type=int, default=4)

    parser.add_argument("--parallelize-cpu", type=_str2bool, default=False, nargs="?", const=True, help="")
    parser.add_argument("--augment", type=_str2bool, default=False, nargs="?", const=True, help="")
    parser.add_argument("--episodic", type=_str2bool, default=False, nargs="?", const=True, help="")

    # SAC params
    parser.add_argument("--learning_rate", type=float, default=0.0003)
    parser.add_argument("--buffer_size", type=int, default=1000000)
    parser.add_argument("--batch_size", type=int, default=256)
    parser.add_argument("--gamma", type=float, default=0.99)
    parser.add_argument("--tau", type=float, default=0.005)
    parser.add_argument("--train_freq", type=int, default=1)
    # type=bool would turn any non-empty string (including "False") into True.
    parser.add_argument("--use_sde", type=_str2bool, default=True, nargs="?", const=True)
    parser.add_argument("--log_std_init", type=float, default=-3)
    parser.add_argument("--net_size", type=int, default=400)
    parser.add_argument("--learning_starts", type=int, default=100)

    return parser.parse_args(argv)

import socket
from contextlib import closing

def find_free_port():
    """Ask the operating system for a currently unused TCP port number.

    A throwaway IPv4 TCP socket is bound to port 0, which lets the kernel
    choose a free port. The chosen port is read back and the socket is closed.
    Nothing reserves the port after the close, so another process could claim
    it before the caller binds it.

    Returns:
        int: the port number picked by the OS.
    """
    probe = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    with closing(probe):
        probe.bind(("", 0))
        probe.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        _host, port = probe.getsockname()
    return port

def make_env(idx, config=None):
    """Return a zero-argument factory ("thunk") that builds one environment.

    The vectorized-env wrappers in the ``__main__`` section receive a list of
    these thunks and call each one to create an environment.

    Args:
        idx: Index of this environment copy. It is currently unused and kept so
            existing callers keep working.
        config: Settings object providing ``env_id``, ``augment``, ``episodic``
            and ``nb_agents``. Defaults to the module-level ``args`` assigned in
            the ``__main__`` section. That global is looked up only when the
            thunk is called.

    Returns:
        A callable that constructs and returns the environment.

    Raises (when the thunk is called):
        ValueError: if ``env_id`` is anything other than ``"Godot"``, the only
            environment this factory knows how to build.
    """
    def thunk():
        cfg = args if config is None else config
        if cfg.env_id != "Godot":
            raise ValueError(
                f"Unsupported env_id {cfg.env_id!r}; only 'Godot' is supported")

        # Imported here, so the Godot wrapper is only loaded when an env is built.
        from countbased.envs.Godot.godot import GodotGymWrapper
        agent_type = "countbased-augmented" if cfg.augment else "countbased"

        return GodotGymWrapper(
            env_path=f"{os.getcwd()}/envs/Godot/linux/Project88.x86_64",
            agent_type=agent_type,
            episodic=int(cfg.episodic),
            show_window=True,
            framerate=60,
            nb_agents=cfg.nb_agents,
            action_repeat=10,
            # A fresh free port per env instance, presumably so that
            # concurrently running Godot instances do not collide.
            port=find_free_port(),
        )
    return thunk

if __name__ == "__main__":
    args = parse_args()

    #######################
    # Export this interpreter's module search path as PYTHONPATH, presumably
    # so that child processes (e.g. SubprocVecEnv workers) can import the
    # same project packages.
    import sys
    sep = os.pathsep
    os.environ['PYTHONPATH'] = sep.join(sys.path)
    #######################

    ####################### ENV #######################
    # Train Environment: one thunk per environment copy. make_env only takes
    # the env index; every other setting is read from ``args``.
    env_fns = [make_env(i) for i in range(args.num_envs)]
    if args.parallelize_cpu:
        # NOTE(review): unlike the project DummyVecEnv below, SB3's
        # SubprocVecEnv is not told nb_agents. Confirm that the multi-agent
        # Godot env works with it.
        vec_env = SubprocVecEnv(env_fns)
    else:
        # Project-specific DummyVecEnv (shadows SB3's) that takes nb_agents.
        from countbased.wrappers.dummy_vec_env import DummyVecEnv
        vec_env = DummyVecEnv(env_fns, nb_agents=args.nb_agents)

    ####################### LOGGING #######################
    tmp_path = f"godot_runs/{args.savedir}/{args.xpid}"
    new_logger = configure(tmp_path, ["csv", "tensorboard"])

    ## CALLBACKS
    from countbased.callbacks.callbacks import LogRewardsGodotCallback
    # Rewards are tracked per agent, across all environment copies.
    log_reward_callback = LogRewardsGodotCallback(log_every=100, num_envs=args.num_envs * args.nb_agents)

    from countbased.models.models import CustomAugmentedExtractorCNNGodot
    policy = "MultiInputPolicy"
    policy_kwargs = dict(
        normalize_images=False,
        features_extractor_class=CustomAugmentedExtractorCNNGodot,
    )

    # The forwarded CLI flags have defaults equal to SAC's constructor
    # defaults, so a run without flags behaves as before.
    # NOTE(review): --train_freq, --use_sde, --log_std_init and --net_size are
    # still NOT forwarded. Their CLI defaults (1, True, -3, 400) differ from
    # what this script trains with (train_freq=4, no SDE, SB3's default
    # network), so wiring them in would silently change existing runs.
    model = SAC(policy,
                vec_env,
                learning_rate=args.learning_rate,
                buffer_size=args.buffer_size,
                batch_size=args.batch_size,
                gamma=args.gamma,
                tau=args.tau,
                learning_starts=args.learning_starts,
                tensorboard_log=tmp_path,
                verbose=1,
                seed=args.seed,
                train_freq=4,
                gradient_steps=1,
                policy_kwargs=policy_kwargs,
                )

    # Replaces the logger SAC created from tensorboard_log with the csv+tensorboard one.
    model.set_logger(new_logger)
    print(model.policy)

    try:
        model.learn(
            total_timesteps=args.total_timesteps,
            log_interval=1,
            progress_bar=True,
            callback=[
                log_reward_callback,
            ]
        )
        # Saved relative to the current working directory, not under tmp_path.
        model.save(f"{args.xpid}/{args.xpid}")
    finally:
        # Release the environments (and any processes they own) even if
        # training is interrupted or fails.
        vec_env.close()