from __future__ import with_statement

import collections
import itertools
import json
import operator
import os
from collections import namedtuple
from copy import deepcopy
from functools import reduce
import random
import yaml

from codes.utils.config import read_config_file
from codes.utils.util import natural_keys

# One hyperparameter assignment: the key path into a nested config dict and
# the value to store at that path.
Hyperparam = namedtuple("Hyperparam", "key_list value")


def getFromDict(dataDict, mapList):
    """Return the value reached by indexing ``dataDict`` with each key of ``mapList`` in turn.

    An empty ``mapList`` returns ``dataDict`` itself. A missing key or index
    propagates the underlying ``KeyError``/``IndexError``.
    """
    node = dataDict
    for key in mapList:
        node = node[key]
    return node


def setInDict(dataDict, mapList, value):
    """Store ``value`` in ``dataDict`` at the key path ``mapList``.

    Every key except the last is used to walk into the nested containers,
    and the last key is assigned on the container reached. An empty
    ``mapList`` raises ``IndexError``.
    """
    container = dataDict
    for key in mapList[:-1]:
        container = container[key]
    container[mapList[-1]] = value


def create_list_of_Hyperparams(hyperparams_dict):
    """Expand a nested dict of candidate values into every hyperparameter combination.

    ``hyperparams_dict`` is a (possibly nested) mapping whose non-mapping
    leaves are iterables of candidate values, e.g.
    ``{"model": {"lr": [1e-3, 1e-4]}}``. Nested mappings are walked
    depth-first in insertion order. Each leaf becomes one axis of the sweep,
    with one ``Hyperparam(key_path, value)`` per candidate value; key paths
    are lists such as ``["model", "lr"]``.

    :param hyperparams_dict: nested mapping of key -> iterable of candidate values
    :return: an ``itertools.product`` iterator over tuples holding one
        ``Hyperparam`` per leaf (the full Cartesian product). An empty dict
        yields a single empty tuple.
    """
    # collections.MutableMapping was removed in Python 3.10; the ABC lives in
    # collections.abc.
    from collections.abc import MutableMapping

    def _leaf_paths(d, prefix):
        """Yield ``(key_path, leaf_value)`` pairs depth-first, in insertion order."""
        for key, value in d.items():
            # Key paths are built as lists directly, rather than joined with a
            # separator string and split again. This keeps keys that contain
            # the separator, and non-str keys, intact.
            path = prefix + [key]
            if isinstance(value, MutableMapping):
                for leaf in _leaf_paths(value, path):
                    yield leaf
            else:
                yield path, value

    # One list per leaf, holding a Hyperparam for every candidate value.
    per_leaf_choices = [
        [Hyperparam(key_list, val) for val in val_list]
        for key_list, val_list in _leaf_paths(hyperparams_dict, [])
    ]
    return itertools.product(*per_leaf_choices)


def create_configs(
    start_config_id=0,
    tag="",
    base_config_file="multitask/signature_learn_1",
    num_trials=0,
):
    """Write one YAML config per hyperparameter combination, plus a ``run.sh``.

    The base config ``base_config_file`` is loaded with ``read_config_file``.
    Every combination produced by ``create_list_of_Hyperparams`` on the local
    ``hyperparams_dict`` is applied to a deep copy of it. Each result is
    written to ``<repo>/config/<base_config_file><tag><id>.yaml``. A matching
    ``python .../submitit_runner.py --config_id <id>`` line is written to
    ``<repo>/config/run.sh``. The user must confirm on stdin before any file
    is written.

    :param start_config_id: numeric suffix of the first generated config id
    :param tag: string inserted between ``base_config_file`` and the numeric id
    :param base_config_file: config id of the base config to copy and override
    :param num_trials: if > 0, then choose that many trials randomly (capped
        at the number of available combinations)
    :return: the last config id written (``start_config_id - 1`` when there
        are no combinations), or ``None`` if the user did not confirm
    """
    # base_config = read_config_file(config_id="e120")

    # Repository root: everything before the "/codes" package directory.
    path = os.path.dirname(os.path.realpath(__file__)).split("/codes")[0]
    target_dir = os.path.join(path, "config")

    # base_str_id = str("sines_0")
    base_config = read_config_file(config_id=base_config_file)
    # Final ICML Runs
    # Each leaf list holds the candidate values for that config key; the sweep
    # is the Cartesian product of all leaf lists.
    hyperparams_dict = {
        "general": {
            "base_path": [os.path.expanduser("~/checkpoint/lgw/ckpt/")],
            "batch_size": [256],
            "train_mode": ["supervised_valid"],
            "data_name": [
                "comp_r10_n100_ov"
            ],  # 'composition_100_easy_10','composition_100_10_hard', 'representation_copy_sanity_100_easy_10', 'representation_copy_sanity_100_10_hard',
            "train_rule": ["rule_51", "rule_52", "rule_53"],
            "test_rule": ["rule_4"],
            "is_meta": [False],
            "seed": [42],
        },
        "logger": {
            "should_use": [False],
            "project_name": ["lgw_icml_supervised_logic_all_gamma"],
        },
        "model": {
            "learn_relation_weights": [True],
            "num_nodes": [100],
            "num_epochs": [501],
            "persist_frequency": [100],  # should be 1 for seq_mult
            "weight_norm": [True],
            "optim": {
                "name": ["Adam"],
                "learning_rate": [0.0001],
                "weight_decay": [0.0001],
                "inner_weight_decay": [0.0000001],
                "scheduler_gamma": [0.8],
                "scheduler_patience": [10],
            },
            "lr_inner": [0.001],
            "clamp": [0.1],
            "tasks_per_metaupdate": [10],
            "num_inner_updates": [1],
            "gat": {
                "num_layers": [6],
                "num_heads": [2],
                # 'bias': [True, False],
                # 'concat': [True, False],
                "dropout": [0.4],
            },
            "signature_gat": {
                "num_layers": [2],
                "num_heads": [2],
                # 'bias': [True, False],
                # 'concat': [True, False],
                "dropout": [0.4],
            },
            "gcn": {"cached": [False]},
            "relation_embedding_dim": [200],
            "composition_fn_path": [
                "codes.model.gat.edge_gat.GatedGatEncoder",
                "codes.model.rgcn.rgcn.CompositionRGCNEncoder",
            ],  # "codes.model.gat.edge_gat.GatedGatEncoder". "codes.model.rgcn.rgcn.CompositionRGCNEncoder",
            "representation_fn_path": [
                "codes.model.gat.sig_edge_gat.GatedNodeGatEncoder",
                "codes.model.models.Param",
                "codes.model.gcn.gcn.RepresentationGCNEncoder",
            ],  # codes.model.gat.sig_edge_gat.GatedNodeGatEncoder , "codes.model.models.ParamLinear",
            "use_composition_fn": [False],
            "use_representation_fn": [False],
            "freeze_composition_fn": [False],
            "freeze_representation_fn": [False],
            # 'classify_layers': [1, 2],
            # 'classify_hidden': [50, 100, 200]
        },
    }

    ## supervised pretrained models
    # Checkpoint directories indexed as
    # pretrained[representation_fn_path][composition_fn_path][data_name].
    # NOTE(review): only referenced by the commented-out ``load_dir`` line in
    # the loop below.
    pretrained = {
        "codes.model.gat.sig_edge_gat.GatedNodeGatEncoder": {
            "codes.model.gat.edge_gat.GatedGatEncoder": {
                "representation_copy_sanity_100_easy_10": os.path.expanduser(
                    "~/checkpoint/lgw/ckpt/multitask/multitask_exp_icml_6"
                ),
                "composition_100_10_hard": os.path.expanduser(
                    "~/checkpoint/lgw/ckpt/multitask/multitask_exp_icml_6"
                ),
                "comp_r10_n100_t1": os.path.expanduser(
                    "~/checkpoint/lgw/ckpt/multitask/multitask_logic_0"
                ),
                # 'comp_r10_n100_ov': os.path.expanduser("~/checkpoint/lgw/ckpt/multitask/multitask_logic_uc_ur_6)"
                "comp_r10_n100_ov": os.path.expanduser(
                    "~/checkpoint/lgw/ckpt/multitask/multitask_logic_easy_easy_0"
                ),
            },
            "codes.model.rgcn.rgcn.CompositionRGCNEncoder": {
                "representation_copy_sanity_100_easy_10": os.path.expanduser(
                    "~/checkpoint/lgw/ckpt/multitask/multitask_exp_icml_9"
                ),
                "composition_100_10_hard": os.path.expanduser(
                    "~/checkpoint/lgw/ckpt/multitask/multitask_exp_icml_9"
                ),
                "comp_r10_n100_t1": os.path.expanduser(
                    "~/checkpoint/lgw/ckpt/multitask/multitask_logic_3"
                ),
                # 'comp_r10_n100_ov': os.path.expanduser("~/checkpoint/lgw/ckpt/multitask/multitask_logic_uc_ur_9")
                # NOTE(review): this path repeats "multitask/" unlike its
                # siblings -- possibly a typo; confirm against the checkpoint dir.
                "comp_r10_n100_ov": os.path.expanduser(
                    "~/checkpoint/lgw/ckpt/multitask/multitask/multitask_logic_easy_easy_3"
                ),
            },
        },
        "codes.model.models.Param": {
            "codes.model.gat.edge_gat.GatedGatEncoder": {
                "representation_copy_sanity_100_easy_10": os.path.expanduser(
                    "~/checkpoint/lgw/ckpt/multitask/multitask_exp_icml_7"
                ),
                "composition_100_10_hard": os.path.expanduser(
                    "~/checkpoint/lgw/ckpt/multitask/multitask_exp_icml_7"
                ),
                "comp_r10_n100_t1": os.path.expanduser(
                    "~/checkpoint/lgw/ckpt/multitask/multitask_logic_1"
                ),
                # 'comp_r10_n100_ov': os.path.expanduser("~/checkpoint/lgw/ckpt/multitask/multitask_logic_uc_ur_7")
                "comp_r10_n100_ov": os.path.expanduser(
                    "~/checkpoint/lgw/ckpt/multitask/multitask_logic_easy_easy_1"
                ),
            },
            "codes.model.rgcn.rgcn.CompositionRGCNEncoder": {
                "representation_copy_sanity_100_easy_10": os.path.expanduser(
                    "~/checkpoint/lgw/ckpt/multitask/multitask_exp_icml_10"
                ),
                "composition_100_10_hard": os.path.expanduser(
                    "~/checkpoint/lgw/ckpt/multitask/multitask_exp_icml_10"
                ),
                "comp_r10_n100_t1": os.path.expanduser(
                    "~/checkpoint/lgw/ckpt/multitask/multitask_logic_4"
                ),
                # 'comp_r10_n100_ov': os.path.expanduser("~/checkpoint/lgw/ckpt/multitask/multitask_logic_uc_ur_10")
                "comp_r10_n100_ov": os.path.expanduser(
                    "~/checkpoint/lgw/ckpt/multitask/multitask_logic_easy_easy_4"
                ),
            },
        },
        "codes.model.gcn.gcn.RepresentationGCNEncoder": {
            "codes.model.gat.edge_gat.GatedGatEncoder": {
                "representation_copy_sanity_100_easy_10": os.path.expanduser(
                    "~/checkpoint/lgw/ckpt/multitask/multitask_exp_icml_8"
                ),
                "composition_100_10_hard": os.path.expanduser(
                    "~/checkpoint/lgw/ckpt/multitask/multitask_exp_icml_8"
                ),
                "comp_r10_n100_t1": os.path.expanduser(
                    "~/checkpoint/lgw/ckpt/multitask/multitask_logic_2"
                ),
                # 'comp_r10_n100_ov': os.path.expanduser("~/checkpoint/lgw/ckpt/multitask/multitask_logic_uc_ur_8")
                "comp_r10_n100_ov": os.path.expanduser(
                    "~/checkpoint/lgw/ckpt/multitask/multitask_logic_easy_easy_2"
                ),
            },
            "codes.model.rgcn.rgcn.CompositionRGCNEncoder": {
                "representation_copy_sanity_100_easy_10": os.path.expanduser(
                    "~/checkpoint/lgw/ckpt/multitask/multitask_exp_icml_11"
                ),
                "composition_100_10_hard": os.path.expanduser(
                    "~/checkpoint/lgw/ckpt/multitask/multitask_exp_icml_11"
                ),
                "comp_r10_n100_t1": os.path.expanduser(
                    "~/checkpoint/lgw/ckpt/multitask/multitask_logic_5"
                ),
                # 'comp_r10_n100_ov': os.path.expanduser("~/checkpoint/lgw/ckpt/multitask/multitask_logic_uc_ur_11")
                "comp_r10_n100_ov": os.path.expanduser(
                    "~/checkpoint/lgw/ckpt/multitask/multitask_logic_easy_easy_5"
                ),
            },
        },
    }
    rule_folders = []
    # Number of train rules sampled (and comma-joined into one run) in
    # "run_mult" mode.
    n = 50
    # get all supervised train samples
    if hyperparams_dict["general"]["train_rule"][0] == "all":
        print("getting all train rules")
        data_path = os.path.join(
            os.path.expanduser("~/checkpoint/lgw/data"),
            hyperparams_dict["general"]["data_name"][0],
            "train",
        )
        # os.listdir returns bare entry names. Keep only rule directories
        # that contain a config.json.
        rule_folders = [
            folder
            for folder in os.listdir(data_path)
            if os.path.isdir(os.path.join(data_path, folder))
            and os.path.exists(os.path.join(data_path, folder, "config.json"))
        ]
        rule_folders.sort(key=natural_keys)
        train_mode = hyperparams_dict["general"]["train_mode"][0]
        if train_mode == "supervised":
            # rule_folders = [r for r in rule_folders if r not in ['rule_48','rule_4','rule_14']]
            hyperparams_dict["general"]["train_rule"] = rule_folders
        elif train_mode == "run_mult" and n > 0:
            # random.sample raises ValueError when asked for more items than
            # exist, so cap the sample size.
            hyperparams_dict["general"]["train_rule"] = [
                ",".join(random.sample(rule_folders, min(n, len(rule_folders))))
            ]

    total_combinations = list(create_list_of_Hyperparams(hyperparams_dict))
    if num_trials > 0:
        # Cap at the number of combinations (random.sample would raise ValueError).
        num_trials = min(num_trials, len(total_combinations))
        print("choosing {} trials at random".format(num_trials))
        total_combinations = random.sample(total_combinations, num_trials)
    question = "Generating {} combinations, ok ? [y/n]".format(len(total_combinations))
    reply = str(input(question + " (y/n): ")).lower().strip()
    # startswith() instead of reply[0]: an empty answer means "no" instead of
    # raising IndexError.
    if not reply.startswith("y"):
        return None

    exp_ct = start_config_id
    # Last config id written. It stays at start_config_id - 1 when there are
    # no combinations, instead of being unbound.
    hp_idx = start_config_id - 1
    runner_script = os.path.expanduser("~/mlp/lgw/codes/app/submitit_runner.py")
    with open(os.path.join(target_dir, "run.sh"), "w") as sfp:
        sfp.write("#!/bin/sh\n")
        for hp_idx, hyperparams in enumerate(total_combinations, start_config_id):
            new_config = deepcopy(base_config)
            current_str_id = base_config_file + tag + str(exp_ct)
            new_config["general"]["id"] = current_str_id
            new_config["model"]["save_dir"] = (
                os.path.expanduser("~/checkpoint/lgw/ckpt/") + current_str_id
            )
            for hyperparam in hyperparams:
                setInDict(new_config, hyperparam.key_list, hyperparam.value)
            new_config_file = os.path.join(target_dir, current_str_id + ".yaml")
            if new_config["general"]["train_mode"] in [
                "seq_mult",
                "seq_mult_comp",
                "seq_mult_rep",
            ]:
                # Sequential modes train on every discovered rule in one run.
                new_config["general"]["train_rule"] = ",".join(rule_folders)
            # if new_config["general"]["train_mode"] == "supervised":
            #     if new_config["model"]["composition_fn_path"] == "codes.model.gat.edge_gat.GatedGatEncoder" and new_config["model"]["representation_fn_path"] == "codes.model.gat.sig_edge_gat.GatedNodeGatEncoder":
            #         continue
            # if new_config["model"]["use_composition_fn"] and new_config["model"]["freeze_composition_fn"]:
            #     continue
            # if "composition" in new_config["general"]["data_name"]:
            #     # skipping
            #     if new_config["model"]["use_composition_fn"] or new_config["model"]["freeze_composition_fn"]:
            #         continue
            # if "representation" in new_config["general"]["data_name"]:
            # load the correct model
            # if new_config["general"]["train_mode"] != "supervised":
            # new_config["model"]["load_dir"] = pretrained[new_config["model"]["representation_fn_path"]][new_config["model"]["composition_fn_path"]][new_config["general"]["data_name"]]
            # Close each config file deterministically instead of leaking the
            # handle until garbage collection.
            with open(new_config_file, "w") as config_fp:
                yaml.dump(new_config, config_fp, default_flow_style=False)
            sfp.write(
                "python "
                + runner_script
                + " --config_id {}\n".format(current_str_id)
            )
            exp_ct += 1

    return hp_idx


def create_pt(start_config_id, end_config_id, prefix):
    """Write ``<repo>/config/pt1.yaml`` with one job per config id, cloned from ``pt0``.

    The first job of the ``pt0`` config serves as a template. For each id in
    ``start_config_id..end_config_id`` (inclusive), a deep copy is made with
    ``name`` set to the id and ``command[0]`` set to
    ``"<executable> <prefix><id>"``. ``<executable>`` is the text before the
    first space of the template's ``command[0]``. The rest of ``pt0`` is kept,
    its ``jobs`` list is replaced, and the result is written with
    ``json.dumps``. The command that launches it is then printed.

    NOTE(review): the output file has a ``.yaml`` extension but JSON content;
    presumably the consumer parses it as YAML — confirm.

    :param start_config_id: first config id (inclusive)
    :param end_config_id: last config id (inclusive)
    :param prefix: string placed before each id in the job command
    """
    path = os.path.dirname(os.path.realpath(__file__)).split("/codes")[0]
    target_dir = os.path.join(path, "config")
    template_command = [
        "cp config/{}.yaml {}.yaml",
        "pt run {}.yaml pt_logs/{}",
        "rm config/{}.yaml",
    ]

    current_str_id = "pt" + str(0)
    current_config = read_config_file(config_id=current_str_id)
    template_job = current_config["jobs"][0]
    # Executable part of the template command (text before the first space).
    executable = template_job["command"][0].split(" ")[0]

    jobs = []
    for config_id in range(start_config_id, end_config_id + 1):
        job = deepcopy(template_job)
        job["name"] = str(config_id)
        job["command"][0] = "{} {}{}".format(executable, prefix, config_id)
        jobs.append(job)

    new_str_id = "pt" + str(1)
    new_config_file = os.path.join(target_dir, "{}.yaml".format(new_str_id))
    # Reuse the already-loaded pt0 config instead of reading it a second time.
    # Only its "jobs" entry is replaced; the jobs above are deep copies, so
    # the template job is not aliased. The config is read before the output
    # file is opened, so a read failure no longer leaves an empty pt1.yaml.
    new_config = current_config
    new_config["jobs"] = jobs
    with open(new_config_file, "w") as f:
        f.write(json.dumps(new_config, indent=4))
    print(template_command[1].format(new_str_id, new_str_id))


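# Example usage (stale): create_configs() takes `tag`, not `prefix`, and builds
# config ids as base_config_file + tag + id. Adjust these arguments before
# re-enabling.
# start_config_id = 1
# prefix = "sines_"
# end_config_id = create_configs(start_config_id=start_config_id, prefix=prefix)

# create_pt(start_config_id=start_config_id,
#           end_config_id=end_config_id,
#           prefix=prefix)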

if __name__ == "__main__":
    # Generate the supervised-validation sweep starting at config id 0. A
    # non-positive num_trials keeps every hyperparameter combination.
    create_configs(
        start_config_id=0,
        base_config_file="supervised/supervised",
        tag="_logic_sup_valid_",
        num_trials=-1,
    )
