import os
import shutil
import time
from threading import Thread
from ray import tune

from unrealpose.envs import *
from .wrappers import *

# Number of binary folders available on each server (21 as of now).
# make_train_env assigns each worker a folder index modulo this value.
NUM_BINARIES_PER_SERVER = 21

# For debugging (colorful print)


class bcolors:
    """ANSI terminal escape sequences for colourful debug output.

    Wrap text as ``f"{bcolors.OKGREEN}text{bcolors.ENDC}"`` so that the
    colour/attribute is reset afterwards.
    """

    # Foreground colours
    HEADER = "\033[95m"
    OKBLUE = "\033[94m"
    OKCYAN = "\033[96m"
    OKGREEN = "\033[92m"
    WARNING = "\033[93m"
    FAIL = "\033[91m"

    # Reset all attributes
    ENDC = "\033[0m"

    # Text attributes
    BOLD = "\033[1m"
    UNDERLINE = "\033[4m"


def remove_unreal_logs(binary_path, interval=30 * 60, max_iterations=None):
    """Periodically delete the ``Saved`` directory that belongs to a binary.

    The directory is resolved as ``<directory of binary_path>/../../Saved``.
    Presumably this is where the Unreal build writes its logs (hence the
    name). Confirm this if the binary layout changes.

    With the default arguments the function never returns. It is meant to run
    in a daemon thread. The first clean-up happens immediately, and each later
    one happens after sleeping ``interval`` seconds.

    Args:
        binary_path: Path of the binary whose ``Saved`` directory is cleaned.
        interval: Seconds to sleep between clean-up passes (default 30 min).
        max_iterations: Number of clean-up passes to run before returning.
            ``None`` (the default) loops forever, as the original did.
    """
    binary_dir = os.path.dirname(os.path.abspath(binary_path))
    log_path = os.path.abspath(os.path.join(binary_dir, '..', '..', 'Saved'))

    passes = 0
    while max_iterations is None or passes < max_iterations:
        if passes:
            time.sleep(interval)
        if os.path.exists(log_path):
            # Best effort: ignore_errors keeps this daemon alive when some
            # files cannot be removed or disappear while deleting.
            shutil.rmtree(log_path, ignore_errors=True)
        passes += 1


# Absolute binary paths that already have a remove_unreal_logs daemon thread in
# this process. make_train_env may be called several times in one process, and
# without this registry each call would start another never-ending cleanup
# thread for the same binary.
_LOG_CLEANER_PATHS = set()


def _start_log_cleaner(bin_path):
    """Start a daemon thread running ``remove_unreal_logs`` once per binary path."""
    key = os.path.abspath(bin_path)
    # No lock needed: a concurrent race can at worst start a duplicate cleaner,
    # which is harmless because remove_unreal_logs deletes with
    # ignore_errors=True.
    if key in _LOG_CLEANER_PATHS:
        return
    _LOG_CLEANER_PATHS.add(key)
    daemon = Thread(name='remove_unreal_logs', target=remove_unreal_logs, args=(bin_path,), daemon=True)
    daemon.start()


def _enabled_wrapper_args(env_config, key):
    """Return kwargs for an optional ``{'use': bool, 'args': dict}`` wrapper config.

    Returns ``None`` when ``env_config[key]`` is missing/falsy or its ``'use'``
    flag is falsy. A missing or null ``'args'`` entry yields ``{}``.
    """
    wrapper_cfg = env_config.get(key, False)
    if not (wrapper_cfg and wrapper_cfg.get('use', False)):
        return None
    return wrapper_cfg.get('args') or {}


def _apply_info_and_reward_wrappers(env, env_config):
    """Apply the optional RemoveInfo and RunningNormalizedReward wrappers, in that order."""
    if env_config.get('remove_info', False):
        env = RemoveInfo(env)

    if env_config.get('running_normalized_reward', False):
        env = RunningNormalizedReward(env, momentum=0.1)
    return env


def _wrap_multi_agent(env, env_config):
    """Apply the multi-agent wrapper stack, ending in ``RllibMultiAgentAPI``."""
    env = SplitActionSpace(env)
    if env_config.get('convert_multi_discrete', False):
        env = DiscreteAction(env)

    if env_config.get('expert_action', False):
        env = ExpertAction(env, **env_config['expert_action'].get('args', dict()))

    env = JointObservationTuneReward(env, teammate_stats_dim=env_config.get('teammate_stats_dim', None),
                                     reward_dict=env_config.get('reward_dict', None))

    if env_config.get('shapley_reward', False):
        env = ShapleyValueReward(env)

    colliding_args = _enabled_wrapper_args(env_config, 'done_when_colliding')
    if colliding_args is not None:
        env = DoneWhenColliding(env, **colliding_args)

    env = _apply_info_and_reward_wrappers(env, env_config)
    return RllibMultiAgentAPI(env)


def _wrap_single_agent(env, env_config, is_training):
    """Apply the single-agent wrapper stack (flattened actions and observations)."""
    env = FlattenAction(env)

    # only DQN uses this wrapper now
    if env_config.get('convert_multi_discrete', False):
        env = DiscreteAction(env)

    env = FlattenObservation(env)

    if is_training:
        env = SingleAgentRewardLogger(env=env, reward_dict=env_config.get('reward_dict', None))

    return _apply_info_and_reward_wrappers(env, env_config)


def make_train_env(env_config, is_training=True):
    """Build the wrapped UnrealPose environment described by ``env_config``.

    Side effects:
        * Sets the ``UE4Binary_SLEEPTIME`` and ``DISPLAY`` environment
          variables of the current process.
        * When ``is_training`` is true, mutates ``env_config['args']``
          (``worker_index``, ``binary_path_index`` and possibly
          ``num_humans``).
        * For non-numerical envs, starts a daemon thread that runs
          ``remove_unreal_logs`` (at most one per binary path).

    Args:
        env_config: Environment config mapping. ``worker_index`` is read as an
            attribute (as RLlib's ``EnvContext`` provides); 0 is used when it
            is absent. The whole mapping is also forwarded as keyword
            arguments to ``make_env`` / ``make_env_numerical``. The remaining
            keys toggle the optional wrappers applied below.
        is_training: Whether to assign per-worker ``args`` and, in
            single-agent mode, add ``SingleAgentRewardLogger``.

    Returns:
        An ``RllibMultiAgentAPI``-wrapped env if the underlying env is
        multi-agent and ``force_single_agent`` is not set. Otherwise, a
        single-agent env with flattened actions and observations.

    Raises:
        KeyError: If ``is_training`` is true and ``env_config`` has no
            ``'args'`` entry.
    """
    worker_index = getattr(env_config, 'worker_index', 0)

    # os.environ only accepts str; configs may give the sleep time as a number.
    os.environ['UE4Binary_SLEEPTIME'] = str(env_config.get('UE4Binary_SLEEPTIME', '120'))
    os.environ['DISPLAY'] = ':'

    if is_training:
        args = env_config['args']
        args.worker_index = worker_index
        # Worker 1 maps to folder 0; worker 0 wraps around to the last folder.
        args.binary_path_index = (worker_index - 1) % NUM_BINARIES_PER_SERVER

        if env_config.get('in_evaluation', False):
            args.num_humans = 7
        elif env_config.get('mixed_training', True):
            # Vary the number of humans (1..6) across workers.
            args.num_humans = 1 + worker_index % 6

    if env_config.get('use_numerical', False):
        env = make_env_numerical(**env_config)
    else:
        env = make_env(**env_config)  # update config in make_env
        _start_log_cleaner(env.binary.bin_path)
    env = SingleEvaluationStep(env)

    if env_config.get('communicate_monocular_3d', False):
        env = CommunicateMonocular3D(env)

    env = NormalizeObservation(env)

    aux_args = _enabled_wrapper_args(env_config, 'aux_rewards')
    if aux_args is not None:
        env = Aux_Rewards(env, **aux_args)

    truncate_args = _enabled_wrapper_args(env_config, 'truncate_observation')
    if truncate_args is not None:
        env = TruncateObservation(env, **truncate_args)

    if env_config.get('override_no_action', False):
        env = OverrideNoAction(env)

    if env.unwrapped.MULTI_AGENT and not env_config.get('force_single_agent', False):
        return _wrap_multi_agent(env, env_config)
    return _wrap_single_agent(env, env_config, is_training)


# Register make_train_env with Ray Tune under the env id 'urealpose'.
# NOTE(review): 'urealpose' looks like a misspelling of 'unrealpose'. Any
# trainer config that selects the env by this exact string would break if it
# were renamed, so confirm with callers before changing it.
tune.register_env('urealpose', make_train_env)
