import argparse
import copy
import json
import sys



# Training-loop defaults shared by every simulator family; each entry is a
# (parameter_name, default_value) pair.
train_defaults = [
    ("save_reset_freq", 10),
    ("test_log_freq", 100),
    ("smoothing_eps", 0.0),
    ("pos_weight", 1.0),
    ("lr_net", 0.005),
    ("ws_reg", 0.01),
    ("ws_norm", -1),
    ("init_scale", 1.0),
    ("optimizer_net", "SGD"),
    ("gradient_step_ratio", 0),
    ("vis_batch_size", 5),
    ("curriculum", False),
]

# Constructor defaults per model architecture, keyed by ``model_name``.
# Each value is a list of (init_kwarg, default_value) pairs.
# Every ``special_mode`` entry is its own ``{}`` literal.
model_defaults = {
    "MLP": [
        ("world_state_size", 100), ("n_hidden", 4), ("hidden_size", 512),
        ("num_entities", 32), ("num_relations", 3),
    ],
    "Residual": [
        ("world_state_size", 100), ("n_blocks", 2), ("block_depth", 2),
        ("hidden_size", 512),
        ("num_entities", 32), ("num_relations", 3),
    ],
    "TransformerExtractor": [
        ("world_state_size", 100), ("n_blocks", 2), ("embedding_size", 64),
        ("num_entities", 32), ("num_relations", 3),
        ("world_token_count", 2), ("nheads", 8), ("dim_feedforward", 2048),
    ],
    "TransformerExtractorV2": [
        ("world_state_size", 100), ("n_blocks", 2), ("embedding_size", 64),
        ("num_entities", 32), ("num_relations", 3),
        ("world_token_count", 2), ("nheads", 8), ("dim_feedforward", 2048),
        ("special_mode", {}),
    ],
    "TransformerUpdater": [
        ("world_state_size", 100), ("n_blocks", 6), ("embedding_size", 64),
        ("num_entities", 32), ("num_relations", 3),
        ("world_token_count", 5), ("nheads", 8), ("dim_feedforward", 2048),
        ("special_mode", {}),
    ],
    "TransformerUpdaterV2": [
        ("world_state_size", 100), ("n_blocks", 6), ("embedding_size", 64),
        ("num_entities", 32), ("num_relations", 3),
        ("world_token_count", 5), ("nheads", 8), ("dim_feedforward", 2048),
        ("special_mode", {}),
        ("n_pre_blocks", 2), ("n_post_blocks", 2),
    ],
}


# Defaults for the updater-model training phase; each entry is a
# (parameter_name, default_value) pair.
updater_defaults = [
    ("updater_ws_reg", 0.05),
    ("updater_lr_net", 0.005),
    ("updater_optimizer_net", "SGD"),
    ("n_updater_updates", 1000),
    ("n_updater_steps", 10),
    ("min_train", 1),
    ("max_train", 1),
    ("min_test", 0),
    ("max_test", 2),
    ("train_samples_per_step", 100),
    ("test_samples_per_step", 100),
    ("updater_smoothing_eps", 0.0),
    ("variable_update_requests", 0),
    ("lr_scheduler", None),
    ("updater_lr_scheduler", None),
]

# Accepted argument names that carry no default value (1-tuples, name only);
# DefaultSetter.check_args uses them to recognise these keys.
extra_args = [
    ("save_path",),
    ("load_path",),
]



class DefaultSetter:
    """Fill in and sanity-check hyper-parameter dicts for one simulator family.

    Each instance pairs the simulator-specific ``sim_defaults`` passed to the
    constructor with the module-level ``train_defaults``, ``updater_defaults``,
    ``model_defaults`` and ``extra_args`` tables. The ``add_*`` methods mutate
    the dict they receive and only fill keys that are not already present
    (except ``add_updater_defaults``, which force-copies a few model sizes).
    """

    def __init__(self, sim_defaults):
        """Store the default tables.

        Args:
            sim_defaults: list of ``(name, default)`` pairs for this simulator.
        """
        self.sim_defaults = sim_defaults
        self.train_defaults = train_defaults
        self.updater_defaults = updater_defaults
        self.model_defaults = model_defaults
        self.extra_args = extra_args

    def _model_defaults_for(self, model_name):
        """Return the ``(name, default)`` list registered for *model_name*.

        Raises:
            KeyError: if *model_name* has no entry in ``self.model_defaults``.
        """
        try:
            return self.model_defaults[model_name]
        except KeyError:
            raise KeyError("unknown model_name {!r}; expected one of {}".format(
                model_name, sorted(self.model_defaults))) from None

    def add_core_defaults(self, params):
        """Fill model, simulator and training defaults into *params*.

        Args:
            params: config dict, mutated in place.

        Returns:
            The same *params* dict, for convenience.
        """
        self.add_model_defaults(params)
        add_default(params, self.sim_defaults)
        add_default(params, self.train_defaults)
        return params

    def add_model_defaults(self, params, arg_name="model_vars"):
        """Ensure ``params[arg_name]`` has a ``model_name`` and fully populated ``model_inits``.

        A missing model section or ``model_name`` defaults to ``"MLP"``.

        Args:
            params: config dict, mutated in place.
            arg_name: key of the model section inside *params*.

        Raises:
            KeyError: if the resulting ``model_name`` is not a known model.
        """
        model_vars = params.setdefault(arg_name, {})
        model_vars.setdefault("model_name", "MLP")
        model_vars.setdefault("model_inits", {})
        add_default(model_vars["model_inits"], self._model_defaults_for(model_vars["model_name"]))

    def add_updater_defaults(self, params, extractor_params):
        """Fill updater defaults into *params* and align the updater model with the extractor.

        Args:
            params: config dict, mutated in place.
            extractor_params: the extractor's model section; its ``model_inits`` must
                contain ``num_entities``, ``num_relations`` and ``world_state_size``.
        """
        add_default(params, self.updater_defaults)
        self.add_model_defaults(params, arg_name="updater_model_vars")

        # Unconditionally overwrite these sizes with the extractor's values
        # (presumably so both models agree on input/output dimensions).
        updater_inits = params["updater_model_vars"]["model_inits"]
        extractor_inits = extractor_params["model_inits"]
        for key in ("num_entities", "num_relations", "world_state_size"):
            updater_inits[key] = extractor_inits[key]

    def check_model_args(self, model_dict):
        """Print a warning for each ``model_inits`` key the selected model does not define.

        ``model_name`` falls back to ``"MLP"`` and ``model_inits`` to empty, mirroring
        ``add_model_defaults``, so this can run before defaults are applied.

        Raises:
            KeyError: if ``model_name`` is not a known model.
        """
        model_name = model_dict.get("model_name", "MLP")
        known = {name for name, _ in self._model_defaults_for(model_name)}
        for k in model_dict.get("model_inits", {}):
            if k not in known:
                print("arg {} not recognized".format(k))

    def check_args(self, config):
        """Print a warning for every top-level or model key in *config* that is not recognised."""
        recognized = {"updater_model_vars", "model_vars"}
        for defaults in (self.sim_defaults, self.train_defaults,
                         self.updater_defaults, self.extra_args):
            # extra_args entries are 1-tuples, so index rather than unpack.
            recognized.update(entry[0] for entry in defaults)
        for k in config:
            if k not in recognized:
                print("arg {} not recognized".format(k))

        # Both model sections are optional; check whichever are present.
        for model_key in ("updater_model_vars", "model_vars"):
            if model_key in config:
                self.check_model_args(config[model_key])


# Simulator-specific defaults for the pathfinder setup; each entry is a
# (parameter_name, default_value) pair.
pathfinder_sim_defaults = [
    ("world_resets", 50),
    ("n_worlds", 5),
    ("sampling_schedule", "uniform"),
    ("direct_samples", 1024),
    ("tail_samples", 1),
    ("random_seed", 0),
    ("acc_thresh", (0.25, 0.25)),
]

PathfinderDefaults = DefaultSetter(pathfinder_sim_defaults)


# Simulator-specific defaults for the Conway setup; each entry is a
# (parameter_name, default_value) pair.
conway_sim_defaults = [
    ("world_resets", 50),
    ("world_length", 14),
    ("n_worlds", 5),
    ("train_size", 100000),
    ("min_init", 1),
    ("max_init", 1),
    ("random_seed", 0),
    ("acc_thresh", (0.5, 0.01)),
    ("grid_restriction", False),
]

ConwayDefaults = DefaultSetter(conway_sim_defaults)


# Simulator-specific defaults for the MNIST setup; each entry is a
# (parameter_name, default_value) pair.
mnist_sim_defaults = [
    ("world_resets", 50),
    ("n_worlds", 5),
    ("random_seed", 0),
    ("acc_thresh", (0.5, 0.01)),
    ("num_digits", 10),
    ("shift_min", 1),
    ("shift_max", 1),
]

MNISTDefaults = DefaultSetter(mnist_sim_defaults)


def add_default(params, default_vars):
    """Insert each ``(name, default)`` pair into *params* unless *name* is already set.

    Defaults are deep-copied so that mutable defaults (e.g. the ``{}`` used for
    ``special_mode``) are not shared between configs or with the module-level
    default tables; mutating one filled-in config must not leak into another.

    Args:
        params: dict updated in place.
        default_vars: iterable of ``(name, default)`` pairs.
    """
    for arg, default in default_vars:
        if arg not in params:
            params[arg] = copy.deepcopy(default)


def read_params(argv=None):
    """Decode the JSON config passed as the first command-line argument.

    Args:
        argv: argument vector to read from; defaults to ``sys.argv``
            (index 0 is the program name, index 1 the JSON string).

    Returns:
        The decoded JSON value (normally a dict of parameters).

    Raises:
        IndexError: if no config argument was supplied.
        json.JSONDecodeError: if the argument is not valid JSON.
    """
    if argv is None:
        argv = sys.argv
    if len(argv) < 2:
        raise IndexError("expected a JSON config string as the first command-line argument")
    return json.loads(argv[1])

def merge_two_configs(conf_original, conf_new):
    """Recursively fill keys missing from *conf_new* with values from *conf_original*.

    Values already present in *conf_new* win. Nested dicts are merged only when
    both sides hold a dict at that key; otherwise the value in *conf_new* is kept
    as-is. Values taken from *conf_original* are deep-copied so later in-place
    edits of the merged config cannot mutate *conf_original*.

    Args:
        conf_original: base config providing fallback values (not modified).
        conf_new: overriding config; mutated in place.

    Returns:
        *conf_new*, after merging.
    """
    for key in conf_original:
        original_value = conf_original[key]
        if key not in conf_new:
            conf_new[key] = copy.deepcopy(original_value)
        elif isinstance(conf_new[key], dict) and isinstance(original_value, dict):
            conf_new[key] = merge_two_configs(original_value, conf_new[key])
    return conf_new



