import os
import argparse
import subprocess
import numpy as np
import ray
from ray import tune
from ray.tune.config_parser import make_parser
from ray.tune.experiment import Experiment

import misc
from trainers import get_trainer_class

import torch
# Restrict this process to GPU index 1 unless CUDA_VISIBLE_DEVICES was already
# chosen by the caller (e.g. the shell or a job scheduler); overwriting it
# unconditionally would silently ignore that choice.  CUDA reads this variable
# when it initialises, so it is set before the torch.cuda call below.
os.environ.setdefault('CUDA_VISIBLE_DEVICES', '1')
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
print("Using {} device".format(device))
# Let cuDNN benchmark and cache the fastest convolution algorithms.
torch.backends.cudnn.benchmark = True

def parse_args():
    """Build the command-line parser and parse ``sys.argv``.

    The parser starts from Ray Tune's standard options via ``make_parser``
    (see https://github.com/ray-project/ray/blob/master/python/ray/tune/config_parser.py)
    and is extended with the options below.

    Returns:
        argparse.Namespace: the parsed arguments.
    """
    import tempfile

    parser = make_parser(
        formatter_class=argparse.RawDescriptionHelpFormatter,
        description='Train a reinforcement learning agent.')
    parser.add_argument(
        '-f',
        '--config-file',
        default=None,
        type=str,
        help='YAML file with the experiment configuration; '
             'it must define exactly one experiment.')
    parser.add_argument(
        '--temp-dir',
        # Read RAY_TMPDIR with .get so a missing variable does not crash the
        # parser (even when --temp-dir is given explicitly); fall back to a
        # "ray" directory under the system temp dir.
        default=os.environ.get(
            'RAY_TMPDIR', os.path.join(tempfile.gettempdir(), 'ray')),
        type=str,
        help='Directory for temporary files generated by ray.')
    parser.add_argument(
        '--resume',
        action='store_true',
        help='Whether to attempt to resume previous Tune experiments.')
    parser.add_argument(
        '-v', action='store_true', help='Whether to use INFO level logging.')
    parser.add_argument(
        '-vv', action='store_true', help='Whether to use DEBUG level logging.')
    parser.add_argument(
        '--local-mode',
        action='store_true',
        help='Whether to run ray with `local_mode=True`. '
             'Only if --ray-num-nodes is not used.')
    # Override the max_failures default that make_parser sets.
    parser.set_defaults(max_failures=1)

    return parser.parse_args()


def _resolve_path(path):
    """Expand ``$VAR`` path components and ``~``, and return an absolute path.

    Only whole ``/``-separated components of the form ``$NAME`` are replaced
    by the value of the environment variable ``NAME``.

    Raises:
        KeyError: if a referenced environment variable is not set.
    """
    parts = []
    for part in path.split('/'):
        if part.startswith('$'):
            var = part[1:]
            if var not in os.environ:
                raise KeyError(
                    'Environment variable {!r} used in path {!r} is not set'
                    .format(var, path))
            part = os.environ[var]
        parts.append(part)
    # '/a/b'.split('/') starts with '' and os.path.join would drop the root,
    # turning an absolute path into a cwd-relative one; restore it.
    if path.startswith('/'):
        parts[0] = os.sep
    return os.path.abspath(os.path.expanduser(os.path.join(*parts)))


def _resolve_config_paths(exp, keys=('local_dir',)):
    """Resolve the path-valued entries *keys* of *exp* in place.

    Each key is read with ``misc.get_dict_value_by_str``; keys missing from
    the config (``KeyError``) are skipped.  A value may be a single path or a
    list of paths; each one goes through `_resolve_path`.
    """
    for key in keys:
        try:
            value = misc.get_dict_value_by_str(exp, key)
        except KeyError:
            continue  # key not present in this experiment config
        if isinstance(value, list):
            resolved = [_resolve_path(p) for p in value]
        else:
            resolved = _resolve_path(value)
        misc.set_dict_value_by_str(exp, key, resolved)


def _make_env_creator(env_name):
    """Return an environment factory for *env_name*.

    If ``gym`` is importable and ``gym.make(env_name)`` succeeds, the factory
    is ``functools.partial(gym.make, env_name)``.  Otherwise (gym missing, or
    gym cannot build the env) the env is registered through
    ``misc.register_custom_env`` and its return value is used.
    """
    from functools import partial

    try:
        import gym
        probe = gym.make(env_name)  # built only to check gym can construct it
    except Exception:  # ImportError, gym.error.*, env constructor failures
        return misc.register_custom_env(env_name)
    # The probe instance is not used for training; release its resources.
    probe.close()
    return partial(gym.make, env_name)


def _register_models(exp):
    """Register the custom model(s) required by ``exp['run']``.

    Raises:
        NotImplementedError: if ``exp['run']`` is not a supported algorithm.
    """
    run = exp['run']
    config = exp['config']
    if run == 'PPO':
        # A missing 'custom_action_dist' key is treated like an explicit None.
        register_action_dist = config['model'].get('custom_action_dist') is not None
        misc.register_custom_model(config['model'], register_action_dist=register_action_dist)
    elif run == 'CustomSAC':
        misc.register_custom_model(config['Q_model'], register_action_dist=False)
        misc.register_custom_model(config['policy_model'], register_action_dist=False)
        # Presumably the YAML value can load as a float (e.g. 1.0e+6); the
        # buffer size is cast to int.
        config['buffer_size'] = int(config['buffer_size'])
    else:
        raise NotImplementedError('Unrecognized algorithm {}'.format(run))


def _ray_init_kwargs(args, exp):
    """Build keyword arguments for ``ray.init`` from CLI *args* and ``exp['ray_resources']``."""
    kwargs = {
        'local_mode': args.local_mode,
        '_temp_dir': args.temp_dir,
        'include_dashboard': False,
    }
    resources = exp.get('ray_resources') or {}
    for key in ('num_cpus', 'num_gpus'):
        if key in resources:
            kwargs[key] = resources[key]
    return kwargs


def main():
    """Train the single experiment described by ``--config-file`` with Ray Tune.

    Loads the YAML config and resolves its paths. Applies CLI overrides and
    registers the environment and custom models. Starts Ray and wires up
    callbacks and multi-agent policies. Runs ``tune.run``, then shuts Ray
    down, also when any step after ``ray.init`` raises.

    Raises:
        ValueError: if the config file does not define exactly one experiment.
        NotImplementedError: if ``run`` names an unsupported algorithm.
    """
    # Get experiment configuration
    args = parse_args()
    experiments = misc.load_yaml(args.config_file)
    if len(experiments) != 1:
        raise ValueError(
            'Expected exactly one experiment in {}, found {}: {}'.format(
                args.config_file, len(experiments), list(experiments)))
    exp_name = next(iter(experiments))
    exp = experiments[exp_name]
    exp['trial_name_creator'] = trial_name_creator
    _resolve_config_paths(exp, keys=('local_dir',))

    # CLI overrides; -vv wins over -v because it is applied last.
    verbose = 1
    if args.v:
        exp['config']['log_level'] = 'INFO'
        verbose = 2
    if args.vv:
        exp['config']['log_level'] = 'DEBUG'
        verbose = 3
    if args.local_mode:
        exp['config']['num_workers'] = 0

    # TODO: Copy config file

    # Register custom model and environments
    env_creator = _make_env_creator(exp['env'])
    _register_models(exp)

    # Start ray (earlier to accommodate in-the-env agent)
    args.temp_dir = os.path.abspath(os.path.expanduser(args.temp_dir))
    ray.init(**_ray_init_kwargs(args, exp))
    try:
        # Set callbacks (should be prior to setting multi-agent attribute)
        policy_manager = misc.PolicyManager(env_creator, exp['config'])
        if 'callbacks' in exp['config']:
            agent_ids = policy_manager.env.agent_ids
            misc.set_callbacks(exp, agent_ids)

        # Setup multi-agent: replace the listed policy ids with the policies
        # returned by the manager (kept as an explicit loop to preserve the
        # original order of mutations on the shared config).
        if 'multiagent' in exp['config']:
            policy_ids = exp['config']['multiagent']['policies']
            exp['config']['multiagent']['policies'] = dict()
            for p_id in policy_ids:
                exp['config']['multiagent']['policies'][p_id] = policy_manager.get_policy(p_id)
            exp['config']['multiagent']['policy_mapping_fn'] = policy_manager.get_policy_mapping_fn(
                exp['config']['multiagent']['policy_mapping_fn'])

        # Move env inside config to follow the Experiment format and make
        # sure the optional keys read below exist.
        exp['config']['env'] = exp.pop('env')
        for key in ('restore', 'keep_checkpoints_num', 'resources_per_trial'):
            exp.setdefault(key, None)

        # Get trainer
        trainer = get_trainer_class(exp['run'])
        if args.resume:
            misc.resume_from_latest_ckpt(exp, exp_name)

        # Run experiment; optional keys fall back to tune.run's own defaults.
        tune.run(trainer,
                 name=exp_name,
                 stop=exp.get('stop'),
                 config=exp['config'],
                 resources_per_trial=exp['resources_per_trial'],
                 num_samples=exp.get('num_samples', 1),
                 local_dir=exp.get('local_dir'),
                 keep_checkpoints_num=exp['keep_checkpoints_num'],
                 checkpoint_freq=exp.get('checkpoint_freq', 0),
                 checkpoint_at_end=exp.get('checkpoint_at_end', False),
                 verbose=verbose,
                 trial_name_creator=exp['trial_name_creator'],
                 restore=exp['restore'],
                 )
    finally:
        # End ray, also when setup or training raised.
        ray.shutdown()


def trial_name_creator(trial):
    """Name a Tune trial after itself plus the current short git commit hash.

    Args:
        trial: the trial being named; only its ``str()`` is used.

    Returns:
        str: ``'<str(trial)>_<short hash>'``, or ``'<str(trial)>_nogit'`` when
        the hash cannot be determined (git not installed, or the working
        directory is not inside a git repository).
    """
    try:
        githash = subprocess.check_output(
            ['git', 'rev-parse', '--short', 'HEAD'],
            stderr=subprocess.DEVNULL).strip().decode('utf-8')
    except (subprocess.CalledProcessError, OSError):
        # Don't abort the whole experiment just because the code is not run
        # from a git checkout.
        githash = 'nogit'
    return '{}_{}'.format(trial, githash)


if __name__ == '__main__':
    # Launch training only when executed as a script, not when imported.
    main()
