import os
import sys
import re
import yaml
import copy
import itertools
import collections.abc
from functools import partial, partialmethod
from importlib import import_module
import socket
import numpy as np
from ray.tune.registry import register_env
from ray.rllib.models import ModelCatalog

import envs
import models
from . import callbacks
from .policy_manager import partialclass


def register_custom_env(env_name):
    """ Register custom environment in rllib.

    Looks up ``env_name`` in the project ``envs`` package, binds that class to
    an env-creator function and registers the creator with rllib under the
    same name.

    The creator works on a shallow copy of the env config it receives and:
      * if the config has a ``worker_index`` attribute (presumably an rllib
        ``EnvContext`` -- TODO confirm), picks the ``worker_index``-th port in
        ``port_range`` (popped from the config, default ``[5005, 6000]``) that
        ``is_port_in_use`` reports as free and passes ``worker_id = port - 5005``
        to the env; a worker index of 0 (local process) is treated as 1;
      * if the config has a ``wrappers`` list, builds the env and wraps it with
        each named class from ``envs.wrappers`` in order, passing that
        wrapper's entry of the optional ``wrappers_config`` mapping as kwargs;
      * forwards all remaining config keys as kwargs to the env class.

    Args:
        env_name: attribute name of the env class in ``envs``; also used as the
            registered env name.

    Returns:
        The env-creator callable that was registered.
    """
    def _env_creator_base(env_config, env_cls):
        # keep original env_config intact
        env_config = copy.copy(env_config)

        # append port / worker id
        if hasattr(env_config, 'worker_index'):  # handle case of partial method
            if not env_config.worker_index:  # for local process
                env_config.worker_index = 1
            port_range = env_config.pop('port_range', [5005, 6000])
            free_ports = (port for port in range(*port_range) if not is_port_in_use(port))
            # Only the worker_index-th free port is needed, so stop probing there
            # instead of checking every port in the range.
            port = next(itertools.islice(free_ports, env_config.worker_index, None), None)
            if port is None:
                raise IndexError(
                    'Fewer than {} free ports in range {} (worker_index={})'.format(
                        env_config.worker_index + 1, port_range, env_config.worker_index))
            # convert port to worker_id (only determine port in env creation);
            # NOTE(review): 5005 is assumed to be the env's base port -- confirm.
            env_config['worker_id'] = port - 5005

        # add wrapper
        if 'wrappers' in env_config.keys():
            wrappers = env_config.pop('wrappers') or ()
            # wrappers_config is optional; missing or None means "no kwargs for any wrapper"
            wrappers_config = env_config.pop('wrappers_config', None) or {}
            env = env_cls(**env_config)
            for wrapper_name in wrappers:
                wrapper_kwargs = wrappers_config.get(wrapper_name) or {}
                wrapper_cls = getattr(envs.wrappers, wrapper_name)
                env = wrapper_cls(env, **wrapper_kwargs)
        else:
            env = env_cls(**env_config)

        return env

    # register environment
    _env_creator = partial(_env_creator_base, env_cls=getattr(envs, env_name))
    register_env(env_name, _env_creator)

    return _env_creator


def register_custom_model(model_config, register_action_dist=True):
    """ Register custom model (and optionally its action distribution) in rllib.

    Args:
        model_config: rllib-style model config dict. Reads ``custom_model``,
            ``custom_action_dist`` and ``custom_model_config``; the optional
            ``custom_action_dist_config`` entry of the latter holds the kwargs
            bound to the action distribution class via ``partialclass``.
            Missing or None ``custom_model_config`` is treated as empty.
        register_action_dist: if True and ``custom_action_dist_config`` is
            present, also register the custom action distribution.

    Returns:
        None. Does nothing for the model when ``custom_model`` is absent.
    """
    # Tolerate configs without (or with a null) custom_model_config.
    custom_model_config = model_config.get('custom_model_config') or {}

    # register action distribution
    if register_action_dist and 'custom_action_dist_config' in custom_model_config:
        action_dist_name = model_config['custom_action_dist']
        ActDist = partialclass(getattr(models, action_dist_name),
                               **custom_model_config['custom_action_dist_config'])
        ModelCatalog.register_custom_action_dist(action_dist_name, ActDist)

    # register model
    if 'custom_model' not in model_config:
        return
    model_name = model_config['custom_model']
    ModelCatalog.register_custom_model(model_name, getattr(models, model_name))


def load_yaml(fpath):
    """ Load yaml file with scientific notation.

    PyYAML's default float pattern requires a dot, so values such as ``5e-6``
    would otherwise load as strings. An extra implicit float resolver is added
    to a private ``SafeLoader`` subclass created for this call, so the global
    ``yaml.SafeLoader`` (and thus ``yaml.safe_load`` elsewhere in the process)
    is not modified and the resolver is not appended again on every call.

    Ref: https://stackoverflow.com/questions/30458977/yaml-loads-5e-6-as-string-and-not-a-number

    Args:
        fpath: path of the YAML file to read.

    Returns:
        The parsed document (safe-loaded).
    """
    class _SciNotationLoader(yaml.SafeLoader):
        pass

    # add_implicit_resolver copies the inherited resolver table onto the
    # subclass before appending, leaving yaml.SafeLoader untouched.
    _SciNotationLoader.add_implicit_resolver(
        u'tag:yaml.org,2002:float',
        re.compile(u'''^(?:
            [-+]?(?:[0-9][0-9_]*)\\.[0-9_]*(?:[eE][-+]?[0-9]+)?
            |[-+]?(?:[0-9][0-9_]*)(?:[eE][-+]?[0-9]+)
            |\\.[0-9_]+(?:[eE][-+][0-9]+)?
            |[-+]?[0-9][0-9_]*(?::[0-5]?[0-9])+\\.[0-9_]*
            |[-+]?\\.(?:inf|Inf|INF)
            |\\.(?:nan|NaN|NAN))$''', re.X),
        list(u'-+0123456789.'))

    with open(fpath) as f:
        return yaml.load(f, Loader=_SciNotationLoader)


def get_latest_checkpoint(ckpt_root_dir):
    """ Return the path of the newest checkpoint file under ``ckpt_root_dir``.

    Entries whose name contains ``checkpoint`` and ends in ``_<number>``
    (e.g. ``checkpoint_10``) are treated as checkpoint directories; the one
    with the largest number wins (first one found on ties). Other entries are
    ignored. Inside it, the file ``<dirname with '_' replaced by '-'>`` is
    used; failing that, ``checkpoint-<number>`` without zero padding.

    Args:
        ckpt_root_dir: directory holding the checkpoint sub-directories
            (e.g. a trial directory).

    Returns:
        Path to the checkpoint file inside the newest checkpoint directory.

    Raises:
        ValueError: if no numbered checkpoint directory exists.
        FileNotFoundError: if the expected checkpoint file is missing.
    """
    numbered = []
    for name in os.listdir(ckpt_root_dir):
        suffix = name.split('_')[-1]
        # Skip entries like "checkpoint_notes" instead of crashing on int().
        if 'checkpoint' in name and suffix.isdecimal():
            numbered.append((int(suffix), name))
    if not numbered:
        raise ValueError('No checkpoint directory found in {}'.format(ckpt_root_dir))

    # key= keeps the first maximum, matching the previous np.argmax choice.
    latest_num, latest_dirname = max(numbered, key=lambda item: item[0])
    latest_dir = os.path.join(ckpt_root_dir, latest_dirname)
    # NOTE(review): newer Ray versions appear to zero-pad the directory index
    # (checkpoint_000010) while the file inside is checkpoint-10 -- hence the
    # unpadded fallback; verify against the Ray version in use.
    candidates = [latest_dirname.replace('_', '-'), 'checkpoint-{}'.format(latest_num)]
    for filename in candidates:
        latest_ckpt = os.path.join(latest_dir, filename)
        if os.path.exists(latest_ckpt):
            return latest_ckpt
    raise FileNotFoundError('No checkpoint file {} found in {}'.format(candidates, latest_dir))


def update_dict(d, u):
    """ Recursively merge mapping ``u`` into dict ``d`` in place and return ``d``.

    Nested mappings in ``u`` are merged key by key into the corresponding value
    of ``d`` rather than replacing it; any other value in ``u`` overwrites the
    one in ``d``. If ``d`` holds no mapping at that key (missing, ``None`` or
    a scalar), a fresh dict is started there.

    Args:
        d: dict to update (mutated).
        u: mapping with the updates.

    Returns:
        ``d`` itself.
    """
    for key, value in u.items():
        if isinstance(value, collections.abc.Mapping):
            current = d.get(key)
            # e.g. a YAML "key: null" in d must not break merging a sub-dict into it
            if not isinstance(current, collections.abc.Mapping):
                current = {}
            d[key] = update_dict(current, value)
        else:
            d[key] = value
    return d


def resume_from_latest_ckpt(exp, exp_name):
    """ Set ``exp['restore']`` to the newest checkpoint of the newest trial.

    The newest trial directory under ``<exp['local_dir']>/<exp_name>`` is the
    one whose last two ``_``-separated name fields compare greatest as strings
    (presumably a ``<date>_<time>`` suffix -- TODO confirm naming scheme).
    The checkpoint inside it is resolved with ``get_latest_checkpoint``.

    Args:
        exp: experiment spec dict; reads ``exp['local_dir']`` and sets
            ``exp['restore']`` in place.
        exp_name: name of the experiment directory under ``local_dir``.

    Raises:
        FileNotFoundError: if the experiment directory is missing or holds
            no trial directories. Errors from ``get_latest_checkpoint``
            propagate unchanged.
    """
    exp_dir = os.path.join(exp['local_dir'], exp_name)
    if not os.path.isdir(exp_dir):
        raise FileNotFoundError('Experiment directory not found: {}'.format(exp_dir))

    trial_dirs = [path for path in (os.path.join(exp_dir, name) for name in os.listdir(exp_dir))
                  if os.path.isdir(path)]
    if not trial_dirs:
        raise FileNotFoundError('No trial directories found in {}'.format(exp_dir))

    # max() keeps the first maximum, matching the previous np.argmax choice.
    latest_trial_dir = max(trial_dirs, key=lambda path: '_'.join(path.split('_')[-2:]))
    exp['restore'] = get_latest_checkpoint(latest_trial_dir)


def get_dict_value_by_str(data, string, delimiter=':'):
    """ Access a nested dict with a delimited key path,
        e.g., get_dict_value_by_str(d, 'a:b') := d['a']['b'].

    Args:
        data: nested dict to read from.
        string: key path, keys separated by ``delimiter``.
        delimiter: separator between keys (default ``':'``).

    Returns:
        The value at the key path.

    Raises:
        KeyError: if a key along the path is missing.
        AssertionError: if a container along the path is not a dict.
    """
    # Iterate the split keys directly; the former recursion re-joined them with a
    # hard-coded ':' and broke paths of 3+ keys for any other delimiter.
    node = data
    for key in string.split(delimiter):
        assert isinstance(node, dict)
        node = node[key]
    return node


def set_dict_value_by_str(data, string, value, delimiter=':'):
    """ Set a value in a nested dict addressed by a delimited key path,
        e.g., set_dict_value_by_str(d, 'a:b', 1) := d['a']['b'] = 1.

    Args:
        data: nested dict to modify in place.
        string: key path, keys separated by ``delimiter``; all keys but the
            last must already exist.
        value: value to store under the last key.
        delimiter: separator between keys (default ``':'``).

    Returns:
        None.

    Raises:
        KeyError: if an intermediate key is missing.
        AssertionError: if a container along the path is not a dict.
    """
    # Walk the split keys directly; the former recursion re-joined them with a
    # hard-coded ':' and broke paths of 3+ keys for any other delimiter.
    *parent_keys, last_key = string.split(delimiter)
    node = data
    for key in parent_keys:
        assert isinstance(node, dict)
        node = node[key]
    assert isinstance(node, dict)
    node[last_key] = value


def set_callbacks(exp, agent_ids):
    """ Turn the callbacks class name in ``exp['config']['callbacks']`` into a callable.

    The name is looked up in the project ``callbacks`` module and bound with
    ``functools.partial`` to ``agent_ids`` plus any keyword arguments stored
    under ``exp['config']['multiagent']['callbacks_config']`` (none if either
    key is absent). The result replaces ``exp['config']['callbacks']`` in place.
    """
    config = exp['config']
    callback_cls = getattr(callbacks, config['callbacks'])

    extra_kwargs = dict()
    if 'multiagent' in config.keys():
        multiagent = config['multiagent']
        if 'callbacks_config' in multiagent.keys():
            extra_kwargs = multiagent['callbacks_config']

    config['callbacks'] = partial(callback_cls, agent_ids=agent_ids, **extra_kwargs)


def is_port_in_use(port):
    """ Return whether TCP ``port`` on localhost is already taken.

    On Linux a throw-away socket (with SO_REUSEADDR set) is bound to
    ``("localhost", port)``; a failed bind (``OSError``, e.g. address in use or
    permission denied) means the port is reported as in use. On Windows and
    macOS the check is not implemented yet and ports are always reported free.

    Args:
        port: TCP port number to check.

    Returns:
        True if the port cannot be bound, False otherwise.

    Raises:
        NotImplementedError: on platforms other than Linux, Windows and macOS.
    """
    if sys.platform.startswith('linux'):
        # The socket is closed on leaving the with-block; unlike the former
        # `finally: return False`, this no longer overrides the in-use result.
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
            s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            try:
                s.bind(("localhost", port))
            except OSError:
                return True
        return False
    elif sys.platform == 'win32':
        return False  # TODO: check port use in windows
    elif sys.platform == 'darwin':
        return False  # TODO: check port use in MacOS
    else:
        raise NotImplementedError("Check-used-port function is not implemented in {} system.".format(sys.platform))
