import numpy as np
import os
import datetime
import collections
from os.path import dirname, abspath
from copy import deepcopy
from sacred import Experiment, SETTINGS
from sacred.observers import FileStorageObserver
from sacred.utils import apply_backspaces_and_linefeeds
import sys
import torch as th
from utils.logging import get_logger
import yaml

# import run program
from run_online import run as run_online
from run_ad import run_ad
from run_population import run_population

# Configure sacred's stdout/stderr capture before the experiment is created.
SETTINGS['CAPTURE_MODE'] = "fd" # set to "no" if you want to see stdout/stderr in console
logger = get_logger()

# The sacred experiment object; the entry point is registered on it via @ex.main.
ex = Experiment("pymarl-ext")
# Route sacred's own log output through the project logger.
ex.logger = logger
# Filter sacred applies to captured console output before storing it.
ex.captured_out_filter = apply_backspaces_and_linefeeds

# "results" directory located in the parent of this file's directory.
results_path = os.path.join(dirname(dirname(abspath(__file__))), "results")

def _get_run_file(params):
    """Remove and return the name carried by the first bare ``--<name>`` token.

    A token qualifies when it starts with ``--`` and contains no ``=``. The
    matching token is deleted from *params* in place and returned without its
    ``--`` prefix. When no token qualifies, *params* is left untouched and
    ``None`` is returned.

    NOTE(review): any bare ``--flag`` that appears before the intended
    run-file flag will be picked up as the run file instead.
    """
    position = next(
        (idx for idx, token in enumerate(params)
         if token.startswith('--') and '=' not in token),
        None,
    )
    if position is None:
        return None
    return params.pop(position)[2:]

def recursive_dict_update(d, u):
    """Merge the overrides in *u* into dict *d* in place and return *d*.

    Values from *u* win over values already in *d*. Where *u* holds a mapping,
    it is merged key-by-key into whatever mapping *d* holds at that key
    instead of replacing it wholesale.

    Args:
        d: Destination dict; mutated in place.
        u: Mapping of overrides. ``None`` (e.g. what an empty YAML document
            loads as) is treated as "no overrides".

    Returns:
        *d*, so calls can be chained or reassigned.
    """
    if u is None:
        return d
    for k, v in u.items():
        if isinstance(v, collections.abc.Mapping):
            existing = d.get(k)
            # If d holds no mapping here (missing key, None, or a scalar),
            # start from an empty dict; item-assigning into a scalar would fail.
            if not isinstance(existing, collections.abc.Mapping):
                existing = {}
            d[k] = recursive_dict_update(existing, v)
        else:
            d[k] = v
    return d

def config_copy(config):
    """Return an independent deep copy of a (possibly nested) config value.

    ``dict`` and ``list`` containers, including subclasses such as
    ``OrderedDict``, are rebuilt level by level as plain ``dict``/``list``
    objects. Every other value, tuples included, is duplicated with
    ``copy.deepcopy``.
    """
    if isinstance(config, list):
        return list(map(config_copy, config))
    if isinstance(config, dict):
        return {key: config_copy(value) for key, value in config.items()}
    return deepcopy(config)

def _get_config(params, arg_name, subfolder):
    """Consume ``<arg_name>=<name>`` from *params* and load its YAML config.

    The file read is ``config/<subfolder>/<name>.yaml`` next to this script.
    Only the first matching token is used; it is deleted from *params* in
    place so it is not forwarded to sacred's command-line parser.

    Args:
        params: Mutable list of command-line tokens (a copy of ``sys.argv``).
        arg_name: Flag to look for, e.g. ``"--config"`` or ``"--env-config"``.
        subfolder: Directory under ``config/`` holding the YAML files.

    Returns:
        The parsed YAML content, or ``{}`` when the flag is absent or the
        file is empty.

    Raises:
        ValueError: If the flag has no ``=<name>`` part or the YAML is invalid.
    """
    config_name = None
    for idx, arg in enumerate(params):
        # partition splits on the first '=' only, so the value is kept intact.
        key, sep, value = arg.partition("=")
        if key == arg_name:
            if not sep:
                raise ValueError("{} expects a value, e.g. {}=<name>".format(arg_name, arg_name))
            config_name = value
            del params[idx]
            break

    if config_name is None:
        return {}

    config_path = os.path.join(os.path.dirname(__file__), "config", subfolder, "{}.yaml".format(config_name))
    with open(config_path, "r") as f:
        # yaml.CLoader is only defined when PyYAML is built against libyaml;
        # otherwise use the pure-Python loader.
        loader = getattr(yaml, "CLoader", yaml.Loader)
        try:
            config_dict = yaml.load(f, Loader=loader)
        except yaml.YAMLError as exc:
            # Raise explicitly: an ``assert`` is stripped under ``python -O``.
            raise ValueError("{}.yaml error: {}".format(config_name, exc)) from exc

    # An empty YAML document loads as None; normalise so callers can merge it.
    return config_dict if config_dict is not None else {}

# get config from argv, such as "remark"
def _get_argv_config(params):
    """Consume ``--key=value`` overrides from *params* and return them as a dict.

    Each value is passed through ``eval`` so numbers, booleans, lists, etc.
    become Python objects. A value that fails to evaluate, such as a bare word
    like ``--remark=test``, is kept as the raw string. Only the first ``=``
    separates key from value, so ``--remark=a=b`` yields ``"a=b"``.

    Matched tokens are removed from *params* in place. Tokens without ``=``,
    including bare ``--flag`` switches, are left in *params* so they still
    reach sacred's own command-line parser.

    NOTE(review): ``eval`` executes arbitrary expressions from the command
    line. That is acceptable for a local launcher, but never feed it
    untrusted input.

    Args:
        params: Mutable list of command-line tokens.

    Returns:
        Dict mapping option names (without the leading ``--``) to values.
    """
    config = {}
    kept = []
    for arg in params:
        key, sep, raw_value = arg.partition("=")
        # NOTE(review): ``key`` still carries its "--" prefix, so the
        # "envs"/"algs" exclusion can never match; kept to preserve behaviour.
        # Confirm whether "--envs"/"--algs" were meant to be skipped.
        if not (sep and key.startswith("--") and key not in ["envs", "algs"]):
            kept.append(arg)
            continue
        try:
            value = eval(raw_value)
        except Exception:
            # Not a valid Python expression: keep the literal string.
            value = raw_value
        config[key[2:]] = value
    # Update the caller's list in place; the caller reuses it afterwards.
    params[:] = kept
    return config


@ex.main
def my_main(_run, _config, _log):
    """Sacred entry point: seed the RNGs, then hand off to the selected runner.

    The sacred config is copied into plain dicts/lists first. The runner is
    chosen by ``config['run_file']``. Unknown names raise ``ValueError``.
    """
    config = config_copy(_config)
    seed = config["seed"]
    # Seed numpy and torch, and forward the same seed to the environment.
    np.random.seed(seed)
    th.manual_seed(seed)
    config['env_args']['seed'] = seed

    dispatch = (
        ('run_online', run_online),
        ('run_ad', run_ad),
        ('run_population', run_population),
    )
    for name, runner in dispatch:
        if config['run_file'] == name:
            runner(_run, config, _log)
            return
    raise ValueError(f"Undefined run file {config['run_file']}")


if __name__ == '__main__':
    # Work on a copy so the parsers below can consume tokens without
    # mutating sys.argv itself.
    params = deepcopy(sys.argv)

    # Get the defaults from default.yaml
    default_config_path = os.path.join(os.path.dirname(__file__), "config", "default.yaml")
    with open(default_config_path, "r") as f:
        # yaml.CLoader is only defined when PyYAML is built against libyaml;
        # otherwise fall back to the pure-Python loader.
        yaml_loader = getattr(yaml, "CLoader", yaml.Loader)
        try:
            config_dict = yaml.load(f, Loader=yaml_loader)
        except yaml.YAMLError as exc:
            # Raise explicitly: an assert is stripped under `python -O`,
            # which would leave config_dict unbound.
            raise ValueError("default.yaml error: {}".format(exc)) from exc

    # read the run file (first bare `--<name>` token); default is run_online
    run_file = _get_run_file(params)
    if run_file is None:
        run_file = 'run_online'
    config_dict['run_file'] = run_file

    # Load algorithm base configs: --config=<name> -> config/algs/<name>.yaml
    alg_config = _get_config(params, "--config", "algs")
    config_dict = recursive_dict_update(config_dict, alg_config)

    # Load env config: --env-config=<name> -> config/envs/<name>.yaml
    env_config = _get_config(params, "--env-config", "envs")
    config_dict = recursive_dict_update(config_dict, env_config)

    # read other commandline args (--key=value); these override everything above
    commandline_options = _get_argv_config(params)
    # map_name belongs under env_args, so route it there instead of top level
    if "map_name" in commandline_options:
        config_dict["env_args"]["map_name"] = commandline_options.pop("map_name")
    config_dict = recursive_dict_update(config_dict, commandline_options)

    # Turn the optional remark into a '_'-prefixed suffix for run names.
    # str() handles remarks that _get_argv_config eval'd into non-strings
    # (e.g. --remark=1 gives the int 1); a missing or null remark yields ''.
    remark = config_dict.get('remark')
    config_dict['remark'] = '_' + str(remark) if remark is not None else ''
    timestamp = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    unique_token = "{}{}_{}".format(config_dict['name'], config_dict['remark'], timestamp)
    config_dict['unique_token'] = unique_token

    # Evaluation runs are stored under results/evaluate instead of results.
    if config_dict['evaluate']:
        results_path = os.path.join(results_path, 'evaluate')
    # When a map_name is configured, it adds a directory level under the env.
    env_dir = config_dict['env']
    if 'map_name' in config_dict['env_args']:
        env_dir = config_dict['env'] + os.sep + str(config_dict['env_args']['map_name'])
    results_save_dir = os.path.join(
        results_path, run_file,
        env_dir,
        "{}{}".format(config_dict['name'], config_dict['remark']),
        unique_token
    )
    # Path without the unique token for saving common models or others
    results_dir_without_token = dirname(results_save_dir)

    os.makedirs(results_save_dir, exist_ok=True)
    config_dict['results_save_dir'] = results_save_dir
    config_dict['results_dir_without_token'] = results_dir_without_token

    # Save to disk by default for sacred
    file_obs_path = os.path.join(results_save_dir, "sacred")
    ex.observers.append(FileStorageObserver.create(file_obs_path))
    logger.info("Saving to FileStorageObserver in {}.".format(file_obs_path))

    # now add all the config to sacred
    ex.add_config(config_dict)

    # Whatever the parsers above left in params (e.g. `with k=v`, bare sacred
    # flags) is handed to sacred's own command-line handling.
    ex.run_commandline(params)
