import numpy as np
import os
import datetime
import collections
from os.path import dirname, abspath
from copy import deepcopy
from sacred import Experiment, SETTINGS
from sacred.observers import FileStorageObserver
from sacred.utils import apply_backspaces_and_linefeeds
import sys
import torch as th
from utils.logging import get_logger
import yaml

# import run program
from run import run as run
from baseline_run import run as baseline_run
from data_collect import run as data_collect
from bikt import run as bikt



# Use sacred's file-descriptor ("fd") output-capture mode for runs.
SETTINGS['CAPTURE_MODE'] = "fd" 
logger = get_logger()

# Sacred experiment named "BiKT", logging through the project logger.
ex = Experiment("BiKT")
ex.logger = logger
# Filter captured output through sacred's backspace/line-feed processing
# (presumably so carriage-return progress bars don't bloat the stored log).
ex.captured_out_filter = apply_backspaces_and_linefeeds

# "<parent of this file's directory>/results". The __main__ block rebinds it
# to ".../results/evaluate" when config 'evaluate' is set, and builds every
# per-run save directory underneath it.
results_path = os.path.join(dirname(dirname(abspath(__file__))), "results")


@ex.main
def my_main(_run, _config, _log):
    """Sacred entry point: seed the RNGs, then hand off to the selected runner.

    Works on a plain-dict deep copy of sacred's config. The configured seed is
    applied to numpy and torch and copied into ``env_args['seed']``. Run files
    whose name starts with ``'bikt'`` are dispatched to ``bikt``; every other
    run file goes to ``run``.
    """
    cfg = config_copy(_config)
    seed = cfg["seed"]
    np.random.seed(seed)
    th.manual_seed(seed)
    cfg['env_args']['seed'] = seed
    runner = bikt if cfg['run_file'].startswith('bikt') else run
    runner(_run, cfg, _log)


def _get_config(params, arg_name, subfolder):
    """Pop ``<arg_name>=<name>`` from *params* and load ``config/<subfolder>/<name>.yaml``.

    Only the first matching argument is consumed; it is removed from *params*
    in place. Later matches are left untouched.

    Args:
        params: argv-style list of strings; modified in place.
        arg_name: flag to look for, e.g. ``"--config"``.
        subfolder: sub-directory of ``config/`` (next to this file) holding
            the YAML files.

    Returns:
        The parsed YAML content. Returns ``{}`` when no matching argument is
        present or the YAML file is empty.

    Raises:
        FileNotFoundError: the referenced YAML file does not exist.
        ValueError: the YAML file could not be parsed.
    """
    config_name = None
    for idx, arg in enumerate(params):
        key, _, value = arg.partition("=")
        if key == arg_name:
            # partition keeps everything after the first "=", so a name that
            # itself contains "=" is not truncated (split("=")[1] did that).
            config_name = value
            del params[idx]
            break

    if config_name is None:
        return {}

    path = os.path.join(os.path.dirname(__file__), "config", subfolder, "{}.yaml".format(config_name))
    with open(path, "r") as f:
        try:
            # Explicit Loader: a bare yaml.load(f) warns on PyYAML 5.x and is a
            # TypeError on PyYAML >= 6. FullLoader (PyYAML >= 5.1) is the
            # closest safe-ish match to the old default loader.
            config_dict = yaml.load(f, Loader=yaml.FullLoader)
        except yaml.YAMLError as exc:
            # raise instead of `assert False`: asserts are stripped under
            # `python -O`, which would leave config_dict unbound.
            raise ValueError("{}.yaml error: {}".format(config_name, exc)) from exc

    # An empty YAML file parses to None; callers merge the result as a mapping.
    return {} if config_dict is None else config_dict

def _get_run_file(params):
    """Pop the first bare ``--<name>`` flag from *params* and return ``<name>``.

    A bare flag starts with ``--`` and contains no ``=``. Only the first such
    entry is removed, in place. Returns ``''`` when there is none, in which
    case *params* is left unchanged.
    """
    for position, arg in enumerate(params):
        if not arg.startswith('--') or '=' in arg:
            continue
        params.pop(position)
        return arg[2:]
    return ''

def recursive_dict_update(d, u):
    """Deep-merge mapping *u* into dict *d* in place and return *d*.

    Each mapping value in *u* is merged key-by-key into the corresponding
    value of *d*. Every other value in *u* overwrites the entry in *d*. If *d*
    holds a non-mapping value (e.g. ``None`` from an empty YAML key) where *u*
    has a mapping, the merged mapping replaces it.

    Args:
        d: destination dict; mutated.
        u: mapping whose entries take precedence.

    Returns:
        *d* itself.
    """
    for k, v in u.items():
        if isinstance(v, collections.abc.Mapping):
            base = d.get(k)
            if not isinstance(base, collections.abc.Mapping):
                # Recursing into None/a scalar used to fail with a TypeError on
                # item assignment; start from an empty dict instead.
                base = {}
            d[k] = recursive_dict_update(base, v)
        else:
            d[k] = v
    return d

def config_copy(config):
    """Return a deep copy of *config* built from plain ``dict``/``list`` objects.

    Dicts and lists (including subclasses) are rebuilt recursively as builtin
    ``dict`` and ``list``. Every other value is copied with ``copy.deepcopy``.
    """
    if isinstance(config, list):
        return [config_copy(item) for item in config]
    if not isinstance(config, dict):
        return deepcopy(config)
    return {key: config_copy(value) for key, value in config.items()}

# get config from argv, such as "remark"
def _get_argv_config(params):
    """Pop every ``--key=value`` argument from *params* into a config dict.

    Values go through ``eval`` so numbers, booleans, lists, etc. get their
    Python type. A value that fails to evaluate is kept as the raw string.
    Everything after the first ``=`` is the value. Arguments that start with
    ``--`` but contain no ``=`` are left in *params* untouched.

    Args:
        params: argv-style list of strings; matching entries are removed in place.

    Returns:
        dict mapping ``key`` (without the leading ``--``) to the parsed value.
    """
    config = {}
    to_del = []
    for arg in params:
        key, sep, raw_value = arg.partition("=")
        # Only "--key=value" pairs are overrides. A bare "--flag" has no value
        # (it used to raise IndexError here), so leave it for sacred.
        if not (key.startswith("--") and sep):
            continue
        # SECURITY NOTE(review): eval runs arbitrary expressions taken from the
        # command line. Acceptable for a locally launched script; switch to
        # ast.literal_eval if argv can ever come from an untrusted source.
        try:
            value = eval(raw_value)
        except Exception:
            # Not a Python expression (e.g. a bare word): keep the raw string.
            value = raw_value
        config[key[2:]] = value
        to_del.append(arg)
    for arg in to_del:
        params.remove(arg)
    return config

def _stage_save_dir(base_path, config, stage, unique_token):
    """Build the save directory for one training/pretraining stage.

    Layout: ``<base_path>/<run_file>/<env>[/<task>]/<stage>/<unique_token>``.
    The ``task`` level is added only when ``config['env']`` starts with
    ``'sc2'``; otherwise ``config['task']`` is not read.
    """
    env = config['env']
    env_parts = (env, config['task']) if env.startswith('sc2') else (env,)
    return os.path.join(base_path, config['run_file'], *env_parts, stage, unique_token)


if __name__ == '__main__':
    params = deepcopy(sys.argv)

    # Get the defaults from default.yaml
    with open(os.path.join(os.path.dirname(__file__), "config", "default.yaml"), "r") as f:
        try:
            # Explicit Loader: a bare yaml.load(f) warns on PyYAML 5.x and is a
            # TypeError on PyYAML >= 6.
            config_dict = yaml.load(f, Loader=yaml.FullLoader)
        except yaml.YAMLError as exc:
            # raise rather than `assert False`: asserts vanish under `python -O`.
            raise ValueError("default.yaml error: {}".format(exc)) from exc

    # The first bare "--<name>" argument names the run file.
    config_dict['run_file'] = _get_run_file(params)

    # Merge algorithm, env and task configs in that order, then the remaining
    # "--key=value" overrides; later sources win on conflicting keys.
    for arg_name, subfolder in (
        ("--config", "algs"),
        ("--env-config", "envs"),
        ("--task-config", "tasks"),
    ):
        config_dict = recursive_dict_update(config_dict, _get_config(params, arg_name, subfolder))
    config_dict = recursive_dict_update(config_dict, _get_argv_config(params))

    # overwrite map_name config
    if "map_name" in config_dict:
        config_dict["env_args"]["map_name"] = config_dict["map_name"]

    # Result-path suffix. str(): "--remark=3" arrives as an int because
    # _get_argv_config evals values, and '_' + 3 would raise TypeError.
    remark = config_dict.get('remark')
    config_dict['remark'] = '' if remark is None else '_' + str(remark)
    unique_token = "{}{}_{}".format(config_dict['name'], config_dict['remark'], datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S"))
    config_dict['unique_token'] = unique_token

    if config_dict['evaluate']:
        results_path = os.path.join(results_path, 'evaluate')

    bikt_results_save_dir = _stage_save_dir(results_path, config_dict, "train_dt", unique_token)
    pretrain_save_dir = _stage_save_dir(results_path, config_dict, "pretrain", unique_token)
    pretrain_vqvae_save_dir = _stage_save_dir(results_path, config_dict, "pretrain_vqvae", unique_token)
    dt_w_glsk_results_save_dir = _stage_save_dir(results_path, config_dict, "train_dt_w_glsk", unique_token)

    vae_pretrain_save_dir = os.path.join(pretrain_save_dir, 'vae')
    vqvae_pretrain_save_dir = os.path.join(pretrain_vqvae_save_dir, 'vqvae')

    config_dict['bikt_results_save_dir'] = bikt_results_save_dir
    config_dict['vae_pretrain_save_dir'] = vae_pretrain_save_dir
    config_dict['vqvae_pretrain_save_dir'] = vqvae_pretrain_save_dir
    config_dict['dt_w_glsk_results_save_dir'] = dt_w_glsk_results_save_dir

    # The first enabled stage flag, checked in this order, decides where the
    # sacred FileStorageObserver writes.
    stage_dirs = (
        ('train_DT', bikt_results_save_dir),
        ('pretrain_vae', vae_pretrain_save_dir),
        ('pretrain_vqvae', vqvae_pretrain_save_dir),
        ('train_DT_w_glsk', dt_w_glsk_results_save_dir),
    )
    stage_dir = next((path for flag, path in stage_dirs if config_dict.get(flag)), None)
    if stage_dir is not None:
        os.makedirs(stage_dir, exist_ok=True)
        file_obs_path = os.path.join(stage_dir, "sacred")
    else:
        # No stage flag set: file_obs_path used to be left unbound, raising a
        # NameError below. Fall back to <results_path>/sacred.
        file_obs_path = os.path.join(results_path, "sacred")

    ex.observers.append(FileStorageObserver.create(file_obs_path))

    ex.add_config(config_dict)

    ex.run_commandline(params)
