import collections
import os
import sys
from collections.abc import Mapping
from copy import deepcopy
from os.path import abspath, dirname
from types import SimpleNamespace as SN

import numpy as np
import torch as th
import yaml
from sacred import SETTINGS, Experiment
from sacred.observers import FileStorageObserver
from sacred.utils import apply_backspaces_and_linefeeds

from utils.logging import get_logger


from run import run

# Sacred output capture mode: "fd" captures at the file-descriptor level.
# Change it to "no" to see stdout/stderr directly in the console.
SETTINGS['CAPTURE_MODE'] = "fd"
logger = get_logger()

# The single sacred experiment driven by this script. It logs through the
# project logger and post-processes captured output with sacred's
# backspace/linefeed filter (so progress-bar style output is stored cleanly).
ex = Experiment("pymarl")
ex.logger = logger
ex.captured_out_filter = apply_backspaces_and_linefeeds

# "<directory two levels above this file>/results"
results_path = os.path.join(
    dirname(dirname(abspath(__file__))),
    "results",
)


@ex.main
def my_main(_run, _config, _log):
    """Sacred main command: seed the RNGs, then launch the framework via ``run``.

    Sacred injects the arguments by parameter name, so the names must stay.

    Args:
        _run: sacred run object, forwarded unchanged to ``run``.
        _config: sacred configuration. It is deep-copied first so the
            ``env_args`` edit below never writes into sacred's own object.
        _log: sacred logger, forwarded unchanged to ``run``.
    """
    config = config_copy(_config)
    seed = config["seed"]

    # Seed numpy first, then torch, and hand the same seed to the environment.
    for seeder in (np.random.seed, th.manual_seed):
        seeder(seed)
    config['env_args']['seed'] = seed

    run(_run, config, _log)


def _get_config(params, arg_name, subfolder):
    config_name = None
    for _i, _v in enumerate(params):
        if _v.split("=")[0] == arg_name:
            config_name = _v.split("=")[1]
            del params[_i]
            break

    if config_name is not None:
        with open(os.path.join(os.path.dirname(__file__), "config", subfolder, "{}.yaml".format(config_name)), "r") as f:
            try:
                config_dict = yaml.load(f)
            except yaml.YAMLError as exc:
                assert False, "{}.yaml error: {}".format(config_name, exc)
        return config_dict


def recursive_dict_update(d, u):
    """Merge mapping *u* into dict *d* in place, recursing into nested mappings.

    Values from *u* override those in *d*. When a value in *u* is a mapping,
    it is merged into the matching entry of *d* rather than replacing it. If
    that entry is missing or is not a mapping (e.g. ``null`` in a YAML
    default), the merge starts from a fresh empty dict.

    Returns:
        *d* itself (mutated), for call chaining.
    """
    for k, v in u.items():
        if isinstance(v, Mapping):
            base = d.get(k)
            if not isinstance(base, Mapping):
                # Recursing into None/scalars would fail on item assignment.
                base = {}
            d[k] = recursive_dict_update(base, v)
        else:
            d[k] = v
    return d


def config_copy(config):
    """Return a deep copy of *config*, rebuilding containers as plain types.

    Any ``dict`` instance (subclasses included) comes back as a plain
    ``dict`` and any ``list`` as a plain ``list``, with their contents copied
    recursively. Every other value is copied with ``copy.deepcopy``.
    """
    if isinstance(config, list):
        return list(map(config_copy, config))

    if isinstance(config, dict):
        copied = {}
        for key, value in config.items():
            copied[key] = config_copy(value)
        return copied

    return deepcopy(config)


if __name__ == '__main__':
    params = deepcopy(sys.argv)

    # Get the defaults from default.yaml
    with open(os.path.join(os.path.dirname(__file__), "config", "default.yaml"), "r") as f:
        try:
            # PyYAML >= 6 rejects yaml.load() without a Loader; safe_load works on
            # every PyYAML version and is sufficient for plain config files.
            # An empty file parses to None, so fall back to an empty dict.
            config_dict = yaml.safe_load(f) or {}
        except yaml.YAMLError as exc:
            # Explicit raise instead of `assert False`: asserts vanish under `python -O`.
            raise AssertionError("default.yaml error: {}".format(exc)) from exc

    # Load algorithm and env base configs. _get_config returns None when its
    # flag is absent (or the file is empty); treat that as "no overrides".
    env_config = _get_config(params, "--env-config", "envs") or {}
    alg_config = _get_config(params, "--config", "algs") or {}
    # Precedence: default.yaml < env config < algorithm config.
    config_dict = recursive_dict_update(config_dict, env_config)
    config_dict = recursive_dict_update(config_dict, alg_config)

    # now add all the config to sacred
    ex.add_config(config_dict)

    # The results directory is named after the algorithm in the merged config.
    save_name = '_alg_' + str(config_dict.get("algname"))
    file_obs_path = os.path.join(results_path, "sacred_" + save_name)

    # Save to disk by default for sacred
    logger.info("Saving to FileStorageObserver in {}.".format(file_obs_path))
    ex.observers.append(FileStorageObserver.create(file_obs_path))

    ex.run_commandline(params)

