import sys
import importlib
import yaml
import traceback
import ray
import time

from expground import settings
from expground.types import Dict, PolicyConfig, RolloutConfig, TrainingConfig, Tuple
from expground.logger import Log
from expground.learner import get_learner
from expground.utils import rollout
from expground.cmd_utils import common_arg_parser, parse_unknown_args
from expground.utils.preprocessor import get_preprocessor
from expground.utils.logging import ExpConfig
from expground.utils.path import load_class_from_str, parse_env_config


def parse_learner_config(learner_config: Dict) -> Tuple[type, Dict]:
    """Resolve the learner class and its extra keyword arguments.

    Args:
        learner_config: The learner section of the experiment config. Must
            contain ``"type"``; ``"params"`` is optional.

    Returns:
        A ``(learner_cls, params)`` pair, where ``learner_cls`` is whatever
        ``get_learner`` returns for the configured type and ``params`` is the
        ``"params"`` entry (an empty dict when absent).

    Raises:
        KeyError: If ``"type"`` is missing from ``learner_config``.
    """
    resolved_cls = get_learner(learner_config["type"])
    return resolved_cls, learner_config.get("params", {})


def parse_rollout_config(rollout_config: Dict) -> RolloutConfig:
    """Build a ``RolloutConfig`` from the rollout section of the config.

    Args:
        rollout_config: Mapping with the required keys ``"type"``,
            ``"fragment_length"``, ``"max_step"`` and ``"num_simulation"``.
            ``"vector_mode"`` (default ``False``) and ``"max_episode"``
            (default ``None``) are optional.

    Returns:
        A ``RolloutConfig`` whose ``caller`` is the rollout function looked up
        from the configured type.

    Raises:
        KeyError: If any required key is missing.
    """
    # The rollout function is resolved first, before the remaining keys are read.
    fields = {"caller": rollout.get_rollout_func(rollout_config["type"])}

    # Required entries: a missing key raises KeyError.
    for required_key in ("fragment_length", "max_step", "num_simulation"):
        fields[required_key] = rollout_config[required_key]

    # Optional entries fall back to their defaults.
    fields["vector_mode"] = rollout_config.get("vector_mode", False)
    fields["max_episode"] = rollout_config.get("max_episode")

    return RolloutConfig(**fields)


def parse_training_config(global_config: Dict):
    """Build a ``TrainingConfig`` from the full experiment config.

    Args:
        global_config: The whole experiment config. ``"training_config"`` is
            required; ``"algorithm"`` is optional and, when present and
            truthy, must contain a ``"trainer"`` entry.

    Returns:
        A ``TrainingConfig`` carrying the trainer class (``None`` when no
        algorithm section is configured) and the raw ``"training_config"``
        entry as ``hyper_params``.

    Raises:
        KeyError: If ``"training_config"`` is missing, or if an algorithm
            section is given without ``"trainer"``.
    """
    algorithm_section = global_config.get("algorithm")

    trainer_cls = None
    if algorithm_section:
        # The trainer is named by a string resolved under expground.algorithms.
        trainer_cls = load_class_from_str(
            "expground.algorithms", algorithm_section["trainer"]
        )

    hyper_params = global_config["training_config"]
    return TrainingConfig(trainer_cls=trainer_cls, hyper_params=hyper_params)


def parse_cmdline_kwargs(args):
    """
    Convert a list of '='-spaced command-line arguments to a dictionary,
    evaluating python objects when possible.

    Args:
        args: The unknown command-line arguments; they are handed to
            ``parse_unknown_args`` to obtain key/value pairs.

    Returns:
        A dict mapping each key to the evaluated value, or to the raw string
        when the value cannot be evaluated as a Python expression.
    """

    def parse(v):
        # Internal invariant: values handed over by parse_unknown_args are
        # expected to be strings.
        assert isinstance(v, str)
        try:
            # SECURITY NOTE(review): eval() executes arbitrary expressions
            # taken from the command line, and it resolves names against this
            # module's globals (so a value such as "time" yields the imported
            # module, not the string). Only use with trusted invocations;
            # consider ast.literal_eval if plain literals are sufficient.
            return eval(v)
        except Exception:
            # Any evaluation failure -- not only NameError/SyntaxError but
            # also e.g. AttributeError for "settings.yaml" or
            # ZeroDivisionError for "1/0" -- means the value is not a usable
            # Python expression, so keep it as the raw string.
            return v

    return {k: parse(v) for k, v in parse_unknown_args(args).items()}


def main(args):
    """Run one training experiment described by a YAML config file.

    The function parses the command line, starts Ray, loads the YAML file at
    ``args.config``, builds the environment / learner / rollout / training /
    policy configurations from it, constructs the learner and calls its
    ``learn`` method. Whatever happens during learning, the learner (if it
    was created) is saved and Ray is shut down before the process exits.

    Args:
        args: Raw command-line argument list (the script entry point passes
            ``sys.argv``).

    Raises:
        SystemExit: Always, once the learner has been constructed or has
            failed to construct. The exit status is 0 on success and 1 when
            an exception was caught while building or running the learner.
        KeyError: If a required section of the YAML config is missing
            (raised before the learner is constructed).
    """
    arg_parser = common_arg_parser()
    args, unknown_args = arg_parser.parse_known_args(args)
    # NOTE(review): the parsed extras are not consumed anywhere below; the
    # call is kept so unknown arguments are still parsed as before.
    extra_args = parse_cmdline_kwargs(unknown_args)

    ray.init(local_mode=args.debug)

    Log.info("Config file path is specified, will load from: %s", args.config)
    with open(args.config, "r") as f:
        raw_yaml = yaml.safe_load(f)

    env_desc, env_lib = parse_env_config(raw_yaml["env_config"])
    learner_cls, other_params = parse_learner_config(raw_yaml["learner_config"])
    rollout_config = parse_rollout_config(raw_yaml["rollout_config"])
    trainining_config = parse_training_config(raw_yaml)

    # "custom_config" is optional: read it once so the preprocessor lookup and
    # the policy config agree (a missing or empty section becomes {}).
    custom_config = raw_yaml.get("custom_config") or {}
    preprocess_mode = custom_config.get("preprocess_mode", "flatten")

    # load sampler_config from environment lib
    env_config = env_desc["config"]
    action_spaces = env_config["action_spaces"]
    observation_spaces = env_config["observation_spaces"]
    preprocessors = {
        aid: get_preprocessor(observation_space, mode=preprocess_mode)(
            observation_space
        )
        for aid, observation_space in observation_spaces.items()
    }

    # Policy and loss are only resolved when an algorithm section is given.
    algorithm_config = raw_yaml.get("algorithm")
    if algorithm_config:
        policy_cls = load_class_from_str(
            "expground.algorithms", algorithm_config["policy"]
        )
        loss_func = load_class_from_str(
            "expground.algorithms", algorithm_config["loss"]
        )
    else:
        policy_cls = None
        loss_func = None

    # lambda for multi-agent cases
    sampler_config = lambda aid: env_lib.basic_sampler_config(
        observation_spaces[aid],
        action_spaces[aid],
        preprocessors[aid],
        capacity=raw_yaml["sampler_config"]["params"]["capacity"],
        learning_starts=raw_yaml["sampler_config"]["params"].get("learning_starts", -1),
    )

    policy_config = PolicyConfig(
        policy=policy_cls,
        mapping=lambda agent: agent,
        observation_space=lambda k: env_config["observation_spaces"][k],
        action_space=lambda k: env_config["action_spaces"][k],
        custom_config=custom_config,
        model_config=raw_yaml.get("model_config", {}),
    )
    exp_config = ExpConfig(raw_yaml, base_path=args.log_path, seed=args.seed)
    exp_config.init_logpath()

    # Bound before the try block so the cleanup below can tell whether the
    # learner was ever constructed.
    learner = None
    exit_code = 0
    try:
        algorithm = algorithm_config["name"] if algorithm_config else "default"
        exp_prefix = raw_yaml.get("exp_prefix", None)

        learner = learner_cls(
            experiment=settings.EXP_NAME_FORMAT_LAMBDA(
                args=(
                    exp_prefix,
                    f"{env_desc['config']['env_id']}_{algorithm}",
                    # whole-second timestamp
                    str(time.time()).split(".")[0],
                )
            ),
            policy_config=policy_config,
            env_description=env_desc,
            rollout_config=rollout_config,
            training_config=trainining_config,
            loss_func=loss_func,
            exp_config=exp_config,
            **other_params,
        )
        general_stop_conditions = {
            "stop_conditions": raw_yaml["learner_config"]["stop_conditions"]
        }
        inner_stop_conditions = raw_yaml["learner_config"].get(
            "inner_stop_conditions", None
        )
        if inner_stop_conditions:
            general_stop_conditions["inner_stop_conditions"] = inner_stop_conditions
        analysis = learner.learn(sampler_config, **general_stop_conditions)
        print(analysis)
    except Exception:
        print(traceback.format_exc())
        # Report the failure through the exit status instead of exiting 0.
        exit_code = 1
    finally:
        try:
            if learner is not None:
                Log.info("saving model ....")
                learner.save()
        finally:
            # Shut Ray down even when saving the model fails.
            ray.shutdown()
        # Exiting from ``finally`` also turns an interrupt (e.g. Ctrl-C during
        # learning) into a normal exit after the model has been saved.
        sys.exit(exit_code)


# Script entry point: the complete ``sys.argv`` (program name included) is
# passed to ``main``, which forwards it to the argument parser.
# NOTE(review): presumably argv[0] ends up among the unknown arguments and is
# ignored by ``parse_unknown_args`` -- confirm against cmd_utils.
if __name__ == "__main__":
    main(sys.argv)
