import argparse
import os
from datetime import datetime

import ray
import torch
from gym.spaces import Tuple
from ray import tune

from experiments import *
from unrealpose import DEBUG, ROOT_DIR
from unrealpose.config import config as unreal_config
from unrealpose.config import update_config
from unrealpose.control.envs import make_train_env as make_env_impl
from unrealpose.custom import load_custom_models, load_trainers

# Register project-specific RLlib models before any trainer/config refers to them.
load_custom_models()

parser = argparse.ArgumentParser(description='')
parser.add_argument('--tags', nargs='+', default=[], help='wandb tags')
parser.add_argument('--project', type=str, required=True, help='wandb project')
parser.add_argument('--group', type=str, required=True, help='wandb group')
parser.add_argument('--exp', type=str, required=True)
parser.add_argument('--restore', type=str, default=None, help='path to rllib checkpoints')
args = parser.parse_args()

# Tune results are written under <ROOT_DIR>/ray_results.
LOCAL_DIR = os.path.join(ROOT_DIR, 'ray_results')

print(f'ROOT DIR: {ROOT_DIR}')

os.environ['TUNE_RESULT_DIR'] = LOCAL_DIR
# NOTE(review): DISPLAY is forced to ':' — presumably so the UE4 binary does not
# try to open a window on an X server; confirm.
os.environ['DISPLAY'] = ':'

# Local hardware view, used only to size ray.init(); both values are replaced
# below by what Ray actually reports for the cluster.
NUM_NODE_CPUS = os.cpu_count()
NUM_NODE_GPUS = torch.cuda.device_count()
ray.init(num_cpus=NUM_NODE_CPUS, local_mode=DEBUG)

# ray.cluster_resources() only lists resources that exist, so 'GPU' is absent on
# CPU-only machines; default it to 0 instead of crashing with a KeyError.
cluster_resource = ray.cluster_resources()
NUM_NODE_CPUS = int(cluster_resource['CPU'])
NUM_NODE_GPUS = int(cluster_resource.get('GPU', 0))

print(f'DEBUG MODE: {DEBUG}')
print(f'NUM_NODE_CPUS: {NUM_NODE_CPUS}')
print(f'NUM_NODE_GPUS: {NUM_NODE_GPUS}')



# NOTE(review): the same value is also written into env_config later in this
# script; presumably read by the UE4 environment wrapper — confirm which wins.
os.environ['UE4Binary_SLEEPTIME'] = '120'



def make_env(env_config):
    """Build one training environment from an RLlib ``env_config``.

    Thin adapter over ``make_env_impl`` (imported from
    ``unrealpose.control.envs.make_train_env``) so that it can be registered
    with ``tune.register_env``.

    Args:
        env_config: the environment config handed over by Tune/RLlib.

    Returns:
        The environment object produced by ``make_env_impl``.
    """
    return make_env_impl(env_config)


# Resolve the experiment object by name. It is expected to be one of the names
# brought in by `from experiments import *`; fail with a clear CLI error
# rather than a bare KeyError when the name is unknown.
try:
    exp = globals()[args.exp]
except KeyError:
    parser.error(f"unknown experiment {args.exp!r}: no such name in the experiments module")

# Aliases into the experiment spec; these are the same dict objects, so
# mutations through them update exp.spec in place.
exp_config = exp.spec['config']
tmp_env_config = exp_config['env_config']

# NOTE(review): --group is declared required=True, so this fallback is only
# reachable if that argument is ever made optional.
if args.group is None:
    wandb_group = f"{tmp_env_config['args'].env_name}"
else:
    wandb_group = args.group

# User-supplied tags are only appended outside DEBUG mode.
if DEBUG:
    wandb_tags = [tmp_env_config['algo'], "DEBUG"]
else:
    wandb_tags = [tmp_env_config['algo'], "dev"] + args.tags
# NOTE(review): wandb_group / wandb_tags are not used anywhere below — confirm
# whether they were meant to be passed to a logger/callback.

exp.spec['trial_dirname_creator'] = lambda trial: f"{trial.trainable_name}_{trial.trial_id}_{datetime.now().strftime('%b%d')}"
tmp_env_config['UE4Binary_SLEEPTIME'] = '120'

if DEBUG:
    # Shrink the run to a single worker/env and a one-fragment batch for fast iteration.
    exp_config['num_workers'] = 1
    exp_config['num_envs_per_worker'] = 1
    tmp_env_config['UE4Binary_SLEEPTIME'] = '10'
    exp_config['train_batch_size'] = (
        exp_config['num_workers']
        * exp_config['num_envs_per_worker']
        * exp_config['rollout_fragment_length']
    )
    # One SGD minibatch covers the whole (tiny) train batch.
    exp_config['sgd_minibatch_size'] = exp_config['train_batch_size']
tune.register_env('urealpose-parallel', make_env)


update_config(unreal_config, tmp_env_config['args'])


# Grouped-agent algorithms need the env wrapped with agent groups. Build a
# throwaway env only to read its agent ids and spaces, and always close it,
# even if registration fails, so the UE4 process is not leaked.
algo_name = tmp_env_config['algo'].lower()
if any(key in algo_name for key in ('qmix', 'ppo_group')):
    tmp_env = make_env(tmp_env_config)
    try:
        agent_group = {'camera': tmp_env.agent_ids}
        obs_group = Tuple([tmp_env.observation_space for _ in agent_group["camera"]])
        act_group = Tuple([tmp_env.action_space for _ in agent_group["camera"]])
        tune.register_env(
            "urealpose-parallel-grouped",
            lambda config: make_env(config).with_agent_groups(
                groups=agent_group, obs_space=obs_group, act_space=act_group))
        exp_config['env'] = 'urealpose-parallel-grouped'
    finally:
        tmp_env.close()

if args.restore:
    exp.spec['restore'] = args.restore


print(f"train_batch_size = {exp_config['train_batch_size']}, sgd_minibatch_size = {exp_config['sgd_minibatch_size']}")
print(exp_config)

tune.run_experiments(
    experiments=exp,
    verbose=3
)
