import json, yaml
import argparse

# Loss-abbreviation table, read once at import time. generate() iterates
# abb["losses"], and get_exp_details() reads each entry's "base", "prefix",
# "suffix" and "sub" fields.
with open("exp_params/abb.json") as abb_file:
    abb = json.loads(abb_file.read())


def format_params(params):
    """Render ``params`` as a dict literal whose top-level keys are unquoted.

    The text starts as ``str(params)``. Each ``'<key>':`` occurrence for a
    top-level key is then rewritten to ``<key>:``; for example
    ``{'margin': 0.1}`` becomes ``{margin: 0.1}``. Values keep their repr
    (strings stay quoted).

    The rewrite is purely textual. A nested key that equals a top-level key
    is unquoted as well, and so is a string value containing ``'<key>':``.
    """
    rendered = str(params)
    for name in params:
        quoted, bare = f"'{name}':", f"{name}:"
        rendered = rendered.replace(quoted, bare)
    return rendered


def gen_script(
    exp_name,
    exp_detail,
    class_sampler_opt,
    miner_opt,
    optimizer_opt,
    test_opt,
    batch128=False,
):
    """Fill the training-script template for a single experiment.

    Args:
        exp_name: Substituted for the ``{exp_name}`` placeholder.
        exp_detail: Loss definition. Its ``"params"`` dict is rendered with
            ``format_params`` into ``{loss_params}``, and its ``"name"``
            becomes ``{loss_name}``.
        class_sampler_opt: Substituted for ``{mpcs_opt}``.
        miner_opt, optimizer_opt, test_opt: Substituted for the placeholders
            of the same name.
        batch128: When truthy, read ``scripts/format_batch128.txt`` instead
            of ``scripts/format.txt``.

    Returns:
        The template text formatted with the fields above.
    """
    template_path = (
        "scripts/format_batch128.txt" if batch128 else "scripts/format.txt"
    )
    with open(template_path) as template_file:
        template = template_file.read()

    rendered_params = format_params(params=exp_detail["params"])
    substitutions = {
        "exp_name": exp_name,
        "loss_name": exp_detail["name"],
        "loss_params": rendered_params,
        "mpcs_opt": class_sampler_opt,
        "miner_opt": miner_opt,
        "optimizer_opt": optimizer_opt,
        "test_opt": test_opt,
    }
    return template.format(**substitutions)


def get_exp_details(loss, abb):
    """Assemble the experiment name and CLI option strings for one loss.

    Args:
        loss: Key into ``abb["losses"]``.
        abb: Abbreviation table. Each ``abb["losses"][loss]`` entry has a
            required ``"base"`` and optional ``"prefix"``, ``"suffix"`` and
            ``"sub"`` strings; missing ones default to ``""``.

    The loss definition is loaded from ``exp_params/<base>.json`` and then
    adjusted in this order:

    * A Proxy prefix, or a definition already of type ``classification``,
      forces ``embedding_size = 128`` and type ``classification``. Failing
      that, a MeanField prefix does the same and also adds
      ``mf_reg~BAYESIAN~``.
    * Suffix ``WithMFReg`` adds ``mf_reg~BAYESIAN~``; suffix ``WithRegFace``
      adds ``rf_reg~BAYESIAN~``.
    * Sub ``WithMSMiner`` enables a MultiSimilarityMiner option, and sub
      ``l1``/``l2`` sets ``mf_power``.

    Returns:
        Tuple ``(exp_name, exp_detail, class_sampler_opt, miner_opt,
        optimizer_opt, test_opt)``.

    Raises:
        ValueError: if the final ``type`` is neither ``"classification"``
            nor ``"pairwise"``, or if ``normalize`` is not exactly ``True``
            or ``False``.
    """
    entry = abb["losses"][loss]
    base = entry["base"]
    prefix, suffix, sub = (entry.get(part, "") for part in ("prefix", "suffix", "sub"))
    exp_name = prefix + base + suffix + sub

    with open(f"exp_params/{base}.json") as detail_file:
        exp_detail = json.load(detail_file)

    exp_detail["name"] = prefix + exp_detail["name"] + suffix

    # "type" is only consulted when the prefix does not already mark a proxy loss.
    is_proxy = "Proxy" in prefix or "proxy" in prefix
    on_classification_path = is_proxy or exp_detail["type"] == "classification"
    is_mean_field = "MeanField" in prefix or "meanField" in prefix
    if on_classification_path or is_mean_field:
        exp_detail["params"]["embedding_size"] = 128
        exp_detail["type"] = "classification"
        if not on_classification_path:
            # MeanField-only path: also tune the mean-field regulariser.
            exp_detail["params"]["mf_reg~BAYESIAN~"] = [0, 1]

    # First matching suffix marker wins.
    for marker, reg_key in (
        ("WithMFReg", "mf_reg~BAYESIAN~"),
        ("WithRegFace", "rf_reg~BAYESIAN~"),
    ):
        if marker in suffix:
            exp_detail["params"][reg_key] = [0, 1]
            break

    miner_opt = ""
    if "WithMSMiner" in sub:
        miner_opt = "--mining_funcs {tuple_miner: {MultiSimilarityMiner: {epsilon: 0.1}}}"

    # First matching norm marker wins ("l1" is checked before "l2").
    for marker, power in (("l1", 1), ("l2", 2)):
        if marker in sub:
            exp_detail["params"]["mf_power"] = power
            break

    loss_type = exp_detail["type"]
    if loss_type == "pairwise":
        optimizer_opt = "--config_optimizers [default]"
        class_sampler_opt = "4"
    elif loss_type == "classification":
        optimizer_opt = "--config_optimizers [default,with_metric_loss_optimizer_bayes_opt]"
        class_sampler_opt = "1"
    else:
        raise ValueError("Invalid loss type")

    # Identity checks on purpose: 0/1 or strings are rejected, not coerced.
    normalize = exp_detail["normalize"]
    if normalize is True:
        test_opt = "--tester~APPLY~2 {dataloader_num_workers: 8}"
    elif normalize is False:
        test_opt = "--tester~APPLY~2 {dataloader_num_workers: 8, normalize_embeddings: False} --ensemble~APPLY~2 {normalize_embeddings: False}"
    else:
        raise ValueError("Invalid normalize option")

    return exp_name, exp_detail, class_sampler_opt, miner_opt, optimizer_opt, test_opt


def generate(loss="all", batch128=False):
    """Write the training script for one loss key, or for every key.

    Args:
        loss: Key into the module-level ``abb["losses"]`` table, or ``"all"``
            to recurse over every key in that table. A key literally named
            ``"all"`` is skipped during that recursion.
        batch128: Forwarded to ``gen_script``. When truthy, the output goes to
            ``scripts/<loss>_batch128.sh`` instead of ``scripts/<loss>.sh``.
    """
    if loss == "all":
        for key in abb["losses"]:
            if key == "all":
                continue
            generate(key, batch128)
        return

    exp_name, exp_detail, class_sampler_opt, miner_opt, optimizer_opt, test_opt = (
        get_exp_details(loss=loss, abb=abb)
    )
    rendered = gen_script(
        exp_name=exp_name,
        exp_detail=exp_detail,
        class_sampler_opt=class_sampler_opt,
        miner_opt=miner_opt,
        optimizer_opt=optimizer_opt,
        test_opt=test_opt,
        batch128=batch128,
    )
    out_path = f"scripts/{loss}_batch128.sh" if batch128 else f"scripts/{loss}.sh"
    with open(out_path, "w") as script_file:
        script_file.write(rendered)


# Sentinel meaning "read this template field from config['general']".
_FROM_GENERAL = object()


def _resolve_eval_mode(eval):
    """Return the sweep mode selected by the ``--eval`` name ``eval``.

    The exact names ``eval``, ``dim``, ``reg``, ``mfcwms`` and ``mfcont``
    select themselves. Any other name containing ``"batch"`` selects
    ``"batch"``.

    Raises:
        ValueError: if ``eval`` matches none of the known modes.
    """
    if eval == "eval":
        return "eval"
    if "batch" in eval:
        return "batch"
    if eval in ("dim", "reg", "mfcwms", "mfcont"):
        return eval
    raise ValueError(f"Unknown eval mode: {eval!r}")


def _sweep_points(mode, config):
    """List the hyper-parameter settings swept by ``mode`` for one dataset.

    Each element is a dict of keyword arguments for ``_render_scripts``. In
    two-axis sweeps the first axis is the outer loop, so scripts are ordered
    ``beta``-major (``mfcwms``) or ``pos_mrg``-major (``mfcont``).

    Raises:
        ValueError: if ``mode`` is not a known sweep mode.
    """
    if mode == "eval":
        return [{}]
    general = config["general"]
    if mode == "batch":
        return [{"batch_size": size} for size in general["batch_size"]]
    if mode == "dim":
        return [{"embedding_size": size} for size in general["embedding_size"]]
    if mode == "reg":
        return [{"loss_overrides": (("reg", reg),)} for reg in general["reg"]]
    if mode == "mfcwms":
        return [
            {"loss_overrides": (("beta", beta), ("mrg", mrg))}
            for beta in general["beta"]
            for mrg in general["mrg"]
        ]
    if mode == "mfcont":
        return [
            {"loss_overrides": (("pos_mrg", pos_mrg), ("neg_mrg", neg_mrg))}
            for pos_mrg in general["pos_mrg"]
            for neg_mrg in general["neg_mrg"]
        ]
    raise ValueError(f"Unknown eval mode: {mode!r}")


def _render_scripts(
    template,
    config,
    dataset,
    loss_overrides=(),
    embedding_size=_FROM_GENERAL,
    batch_size=_FROM_GENERAL,
):
    """Render one command per (loss, seed) for a single sweep point.

    Args:
        template: Contents of ``scripts/eval_format.txt``, used as a
            ``str.format`` template.
        config: Parsed YAML config with ``general``, ``datasets`` and
            ``losses`` sections. It is mutated in place in two ways: a
            missing or null ``general.seed`` is replaced by ``[1]``, and each
            ``config["losses"][loss]`` dict receives ``loss_overrides``
            followed by ``seed``. Its items are then rendered as
            ``--key value`` pairs, so insertion order decides argument order.
        dataset: Key into ``config["datasets"]``, which supplies ``lr``,
            ``weight_decay``, ``lr_decay_step`` and ``lr_decay_gamma``.
        loss_overrides: Ordered ``(key, value)`` pairs written into every
            loss dict before ``seed``.
        embedding_size, batch_size: Explicit template values. When omitted,
            they are read from ``config["general"]``.

    Returns:
        List of rendered commands, with losses in config order and seeds
        innermost.
    """
    dataset_cfg = config["datasets"][dataset]
    rendered = []
    for loss in config["losses"]:
        general = config["general"]
        if general.get("seed", None) is None:
            general["seed"] = [1]
        for seed in general["seed"]:
            loss_cfg = config["losses"][loss]
            for key, value in loss_overrides:
                loss_cfg[key] = value
            loss_cfg["seed"] = seed
            loss_params = " ".join([f"--{k} {v}" for k, v in loss_cfg.items()])
            rendered.append(
                template.format(
                    dataset=dataset,
                    embedding_size=(
                        general["embedding_size"]
                        if embedding_size is _FROM_GENERAL
                        else embedding_size
                    ),
                    batch_size=(
                        general["batch_size"]
                        if batch_size is _FROM_GENERAL
                        else batch_size
                    ),
                    loss=loss,
                    params=loss_params,
                    lr=dataset_cfg["lr"],
                    weight_decay=dataset_cfg["weight_decay"],
                    lr_decay_step=dataset_cfg["lr_decay_step"],
                    lr_decay_gamma=dataset_cfg["lr_decay_gamma"],
                    num_workers=general["num_workers"],
                    patience=general["patience"],
                    remark=general["remark"],
                )
            )
    return rendered


def _generate(eval):
    """Generate evaluation sweep scripts described by ``exp_params/{eval}.yaml``.

    ``eval`` selects the sweep:

    * ``"eval"``: one run per (loss, seed).
    * Any other name containing ``"batch"``: sweep ``general.batch_size``.
    * ``"dim"``: sweep ``general.embedding_size``.
    * ``"reg"``: sweep ``general.reg``, passed to each loss as ``--reg``.
    * ``"mfcwms"``: sweep ``general.beta`` x ``general.mrg``.
    * ``"mfcont"``: sweep ``general.pos_mrg`` x ``general.neg_mrg``.

    For every dataset this writes ``scripts/<mode>_<dataset>.sh``, where
    ``<mode>`` is ``batch`` for batch sweeps and ``eval`` itself otherwise.
    Each file holds the rendered commands joined by ``" \\n"``.

    It then writes an aggregate script that ``sh``-runs each per-dataset
    script. For batch sweeps the aggregate is ``scripts/{eval}.sh``; for
    every other mode except ``reg`` it is ``scripts/{eval}_all.sh``.

    Raises:
        ValueError: if ``eval`` does not name a known mode. This is checked
            before the YAML file is opened.
    """
    # ``eval`` shadows the builtin; the name is kept for keyword callers.
    mode = _resolve_eval_mode(eval)
    with open(f"exp_params/{eval}.yaml") as f:
        config = yaml.safe_load(f)

    datasets = config["datasets"]
    for dataset in datasets.keys():
        with open("scripts/eval_format.txt") as f:
            template = f.read()
        scripts = []
        for point in _sweep_points(mode, config):
            scripts.extend(_render_scripts(template, config, dataset, **point))
        with open(f"scripts/{mode}_{dataset.lower()}.sh", "w") as f:
            f.write(" \n".join(scripts))

    if mode == "reg":
        # NOTE(review): unlike every other mode, the reg sweep writes no
        # aggregate script. This behaviour is kept as-is; confirm whether it
        # is intentional.
        return
    aggregate_path = (
        f"scripts/{eval}.sh" if mode == "batch" else f"scripts/{eval}_all.sh"
    )
    with open(aggregate_path, "w") as f:
        f.write(
            " \n".join(
                [f"sh scripts/{mode}_{dataset.lower()}.sh" for dataset in datasets.keys()]
            )
        )


if __name__ == "__main__":
    # CLI dispatch. Without --eval, generate() builds the training scripts
    # (the original note says these are for the powerful benchmarker). With
    # --eval NAME, _generate() builds sweep scripts from exp_params/NAME.yaml.
    cli = argparse.ArgumentParser()
    cli.add_argument("--loss", type=str, default="all")
    cli.add_argument("--batch128", action="store_true")
    cli.add_argument("--eval", type=str, default=None)
    opts = cli.parse_args()
    if opts.eval is not None:
        _generate(opts.eval)
    else:
        generate(loss=opts.loss, batch128=opts.batch128)
