#! /usr/bin/env python3
# -*- coding: utf-8 -*-
# File   : gen_read_me.py
# Author : Anonymous1
# Email  : anonymous1@anon
#
# Distributed under terms of the MIT license.

"""
Helper script to run and read results from synthetic graph theory datasets

When generating cmds:
Usage: python3 scripts/gen_read_me.py -dn {DATASETS} -ds {GRAPH_GENERATORS} \
    -md {MODELS} -ly {LAYERS} -mh {MAX_HEIGHTS} -nl {N_LAYERS} -se {SEED} \
    -pr {PARALLEL_RUNS} -gids {GPU_IDS} -du {DUMP_DIR}
Then: `bash scripts/bash/{GENERATED_BASH_FILE}`

When reading results:
Usage: python3 scripts/gen_read_me.py -get -du {DUMP_DIR} \
    -dn {DATASETS} -ds {GRAPH_GENERATORS} -re {REGEX}
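The regex selects which run directories are read: only directories whose name
matches it (via `re.match`, i.e. anchored at the start of the name) are kept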
"""

from collections import defaultdict
import os
import os.path as osp
import argparse
import glob
import json
import re
import numpy as np

from megraph.datasets.datasets.synthetic_graph_tasks import AVAILABLE_GT_DATASETS
from megraph.datasets.utils.graph_generators import GRAPH_GENERATOR_NAMES
from sweeping.utils import (
    get_localtime_str,
    get_summary,
    get_stats,
    get_stats_str,
    f2str_with_proper_length,
    f2str,
)

# Command-line interface of the script.
parser = argparse.ArgumentParser()

# Dump directory: forwarded to the generated commands and used as the root
# under which run directories are looked up when reading results.
parser.add_argument(
    "--dump-dir", "-du", type=str, default="dumps/me_datasets", help="The dump dir"
)
# Directory in which the generated bash file is written.
parser.add_argument(
    "--out-dir",
    "-o",
    type=str,
    default=osp.join("scripts", "bash"),
    help="The directory to write the generated bash file to",
)
# ---- Which datasets / graph generators to cover ----
parser.add_argument(
    "--dataset",
    "-dn",
    help="The datasets",
    nargs="+",
    type=str,
    default=AVAILABLE_GT_DATASETS,
    choices=AVAILABLE_GT_DATASETS,
)
parser.add_argument(
    "--subname",
    "-ds",
    help="The subnames of datasets",
    nargs="+",
    type=str,
    default=GRAPH_GENERATOR_NAMES,
    choices=GRAPH_GENERATOR_NAMES,
)

# ---- Grid of configurations used when generating commands ----
parser.add_argument(
    "--model", "-md", type=str, nargs="+", default=["megraph"], help="The models"
)
parser.add_argument(
    "--layer", "-ly", type=str, nargs="+", default=["gfn"], help="The layers"
)
parser.add_argument(
    "--max-height", "-mh", type=int, nargs="+", default=[1, 5], help="The max height"
)
parser.add_argument(
    "--n-layers", "-nl", type=int, nargs="+", default=[1, 5], help="The num layers"
)
parser.add_argument("--seed", "-se", type=int, default=None, help="The seed")
parser.add_argument("--extra", "-ex", type=str, default=None, help="The extra cmd")

# ---- Scheduling of the generated commands ----
parser.add_argument(
    "--parallel-runs",
    "-pr",
    help="The number of parallel runs",
    type=int,
    default=8,
)
parser.add_argument(
    "--gpu-ids",
    "-gids",
    help="Available GPU ids",
    nargs="+",
    type=int,
    default=None,
)

# ---- Result reading ----
parser.add_argument(
    "--get-results",
    "-get",
    help="Read results from runs",
    action="store_true",
)
parser.add_argument(
    "--filter",
    "-fi",
    help="filter task type",
    type=str,
    default=None,
    choices=["reg", "pred"],
)
parser.add_argument(
    "--verbose",
    "-v",
    help="show details for each run",
    action="store_true",
)
parser.add_argument(
    "--regex", "-re", type=str, default=None, help="The regex to filter dirs"
)
parser.add_argument(
    "--params-fname",
    "-pf",
    help="The params file of the run",
    type=str,
    default="params.json",
)
parser.add_argument(
    "--summary-fname",
    "-sf",
    help="The name of the summary file",
    type=str,
    default="progress.csv",
)

# Parsed once at import time; the functions below read this module-level object.
args = parser.parse_args()


def need_run(dataset, method):
    """Tell whether a (dataset, graph generator) combination is enabled.

    The first three characters of ``dataset`` (the "me_" prefix) are dropped
    before the task name is inspected.

    Args:
        dataset: Dataset name, including its three-character prefix.
        method: Graph generator name.

    Returns:
        True if the combination should be processed, False if it is skipped.
    """
    task = dataset[3:]  # "me_"
    # Task families that are switched off entirely:
    # reach*, big* (TODO: enable edge colors) and cntc* (TODO add color).
    if task.startswith(("reach", "big", "cntc")):
        return False
    if (task, method) == ("sssp_nreg", "ba"):
        return False
    excluded_for_paths = ("mix", "er", "caveman", "star")
    if task.startswith(("sssp", "diameter")) and method in excluded_for_paths:
        return False
    if task == "diameter_npred":
        return False
    if task.startswith("ecc"):
        return method not in ("star", "caveman")
    return True


def gen_cmd(main_file, dataset, scale, method, model, layer, h, n):
    """Assemble the command line of a single run.

    Args:
        main_file: Script passed to ``python3``.
        dataset: Value of ``-dname``.
        scale, method: Joined as ``{scale}_{method}`` for ``-dsub``.
        model: Value of ``-md``.
        layer: Value of ``-ly``.
        h: Value of ``-mh``.
        n: Value of ``-nl``.

    Returns:
        The command as one string. ``-se`` and the extra command are appended
        only when the corresponding CLI options were given.
    """
    pieces = [
        f"python3 {main_file}",
        f"-dname {dataset}",
        f"-dsub {scale}_{method}",
        f"-du {args.dump_dir}",
        f"-md {model}",
        f"-ly {layer}",
        f"-mh {h}",
        f"-nl {n}",
    ]
    if args.seed is not None:
        pieces.append(f"-se {args.seed}")
    if args.extra is not None:
        pieces.append(f"{args.extra}")
    return " ".join(pieces)


def gen_cmds(dataset, method):
    """List the commands for one dataset / graph generator pair.

    Returns an empty list when ``need_run`` rejects the pair; otherwise one
    command per combination of max height, number of layers, model and layer
    taken from the CLI options (iterated in that nesting order).
    """
    if not need_run(dataset, method):
        return []
    main_file = "main.py"
    scale = "small"  # Now use small only
    print(f"dataset: {dataset}, subname: {scale}_{method}")
    return [
        gen_cmd(main_file, dataset, scale, method, model, layer, h, n)
        for h in args.max_height
        for n in args.n_layers
        for model in args.model
        for layer in args.layer
    ]


def gen_all_cmds():
    """Generate the commands of all requested runs and write them to a bash file.

    The file is named ``graph_theory_{local time}.sh`` and placed in
    ``args.out_dir``, which is created if it does not exist yet. When GPU ids
    are given they are assigned to the commands round-robin via ``-gid``.
    """
    cmds = []
    for dataset in args.dataset:
        for subname in args.subname:
            cmds.extend(gen_cmds(dataset, subname))
    print(f"The total runs {len(cmds)}")
    # Make sure the target directory exists; otherwise open() below fails with
    # FileNotFoundError. An empty out_dir means the current directory.
    if args.out_dir:
        os.makedirs(args.out_dir, exist_ok=True)
    out_file = osp.join(args.out_dir, f"graph_theory_{get_localtime_str()}.sh")
    with open(out_file, "w") as f:
        for i, cmd in enumerate(cmds):
            if args.gpu_ids is not None:
                gpu_id = args.gpu_ids[i % len(args.gpu_ids)]
                cmd += f" -gid {gpu_id}"
            # Every `parallel_runs`-th command (and the very last one) is
            # written as a foreground command; the others are backgrounded
            # with "&".
            # NOTE(review): no `wait` is emitted, so the next batch starts as
            # soon as the foreground command of the current batch finishes.
            if (i + 1) % args.parallel_runs == 0 or i + 1 == len(cmds):
                f.write(cmd + "\n")
            else:
                f.write(cmd + " &\n")
        print(f"The bash file is generated as {out_file}")


def get_results(dump_dir, is_reg):
    """Load the params and the test statistics of one run directory.

    Args:
        dump_dir: Directory of the run; the params and summary files named by
            the CLI options are read from it.
        is_reg: Passed to ``get_summary`` as ``smaller_better``. When False
            the test results are multiplied by 100 (percentages).

    Returns:
        ``(params, stats)``: the loaded params JSON and the result of
        ``get_stats`` on the best test results.
    """
    run_name = osp.basename(dump_dir)
    with open(osp.join(dump_dir, args.params_fname), "r") as f:
        params = json.load(f)
    raw_cmd = params.get("raw_cmdline", None)
    if args.verbose:
        rule = "-" * 20
        print(rule + f" Dir: {run_name} " + rule)
        if raw_cmd:
            print(raw_cmd)

    _, summary = get_summary(
        osp.join(dump_dir, args.summary_fname),
        smaller_better=is_reg,
        complete_runs=True,
    )
    best_epoch_id, _train, _val, best_test_ress, avg_time = summary
    if not is_reg:  # to percentage
        best_test_ress = [res * 100.0 for res in best_test_ress]
    stats = get_stats(best_test_ress)
    info = f" (Completed Runs: {len(best_epoch_id)}, Avg Epoch time: {avg_time:.4f} s)"
    if args.verbose:
        fmt_kwargs = {"proper_length": 5} if is_reg else {"precision": 2}
        print(get_stats_str(stats, **fmt_kwargs) + info)
    return params, stats


def list_str_to_list(st):
    """Convert the string form of a list (e.g. ``"['a', 'b']"``) to a list of str.

    Args:
        st: The string to convert. ``None`` and the string ``"None"`` give an
            empty list; an actual list/tuple is returned as a list of the
            string forms of its elements.

    Returns:
        The elements with surrounding single quotes and spaces removed.
        Elements that are empty after stripping are dropped.
    """
    if st is None or st == "None":
        return []
    if isinstance(st, (list, tuple)):
        return [str(x) for x in st]
    # Strip before filtering so that whitespace-only pieces (as in "[ ]" or
    # "['a', ]") do not end up as empty-string elements.
    stripped = (x.strip("' ") for x in st.strip("[]").split(","))
    return [x for x in stripped if x]


def get_name_from_params(params, verbose=False):
    """Build a short name describing a run's configuration from its params.

    The name starts with ``n{n_layers}`` and ``h{max_height}`` and is extended
    with an abbreviation for each parameter that differs from its default
    (see the ``add_*`` helpers below).

    Args:
        params: The params dict of the run.
        verbose: If True, print a warning for each looked-up parameter that
            is missing from ``params``.

    Returns:
        ``(name, tie_breaker)``: the "_"-joined name and the run's
        ``local_time`` param, or a random integer below 1000 (as a string)
        when that param is missing.
    """
    # NOTE: The True and False are store as string in json.
    n = f"n{params.get('n_layers', '1')}"
    h = f"h{params.get('max_height', '1')}"
    keys = [n, h]

    def check(name):
        # Warn about a missing parameter (verbose mode only).
        if name not in params and verbose:
            print(f"Warning: {name} not found in params")

    def add_key(abbv, name, default):
        # Append "{abbv}_{value}" when the value differs from the default.
        check(name)
        v = params.get(name, default)
        if v != default:
            keys.append(f"{abbv}_{v}")
        return v

    def add_bool(abbv, name, default):
        # Append the abbreviation when the (string) value is "True".
        check(name)
        v = params.get(name, default)
        if v == "True":
            keys.append(abbv)
        return v

    def add_list(abbv, name, default):
        # Append the abbreviation followed by each element of a non-empty list.
        check(name)
        v = params.get(name, default)
        v = list_str_to_list(v)
        if len(v):
            keys.append(abbv)
            for x in v:
                keys.append(x)
        return v

    lr = params.get("lr", "None")
    if lr == "0.001":
        keys.append("lr0001")
    add_bool("embed", "use_input_embedding", "False")
    add_key("pae", "pool_aggr_edge", "none")
    add_key("xua", "x_update_att", "None")
    xum = add_key("xum", "cross_update_method", "conv")
    add_key("nsc", "num_shared", "1")  # old param
    add_key("nsc", "num_shared_convs", "1")
    add_key("nsp", "num_shared_pools", "1")
    add_key("nr", "num_recurrent", "None")
    add_key("ft", "fully_thresh", "None")
    add_key("pnr", "pool_node_ratio", "None")
    add_key("pdr", "pool_degree_ratio", "None")
    add_key("csl", "cluster_size_limit", "None")
    if xum == "pool":
        add_bool("ucs", "unpool_with_cluster_score", "True")
        add_bool("pcs", "pool_with_cluster_score", "True")
    # Old PE (to remove)
    pe = add_key("pe", "pe_type", "none")
    use_pe = pe != "none"
    # New PEs
    if not use_pe:
        pes = add_list("pe", "pe_types", "[]")
        if len(pes):
            use_pe = True
    if use_pe:
        add_key("pdim", "pe_dim", "1")
        add_key("prep", "pe_rep", "1")
    add_key("pfu", "pool_feature_using", "node")
    add_key("drop", "dropout", "0.0")
    add_key("edrop", "edgedrop", "0.0")
    add_key("ew", "edge_weights_method", "const")
    add_list("nh", "num_heads", "None")

    # add_list("gp", "global_pool_methods", '["mean"]')
    add_bool("gs", "use_global_pool_scales", "False")
    add_bool("unet", "unet_like", "False")
    add_key("xb", "cross_beta", "0.5")
    add_key("bb", "branch_beta", "0.5")
    # Runs sharing the same name are ordered by this value when listed.
    tie_breaker = str(np.random.randint(1000))
    tie_breaker = params.get("local_time", tie_breaker)
    return "_".join(keys), tie_breaker


def read(dataset, method):
    """Read and list the results of all runs of one dataset / generator pair.

    Run directories are the entries of
    ``{args.dump_dir}/{dataset}_small_{method}`` whose name matches
    ``args.regex`` (all entries when no regex is given). Directories whose
    results cannot be read are skipped.

    Args:
        dataset: Dataset name; a name ending with "reg" is treated as a
            regression task.
        method: Graph generator name.

    Returns:
        A dict mapping the configuration name of each readable run to its
        mean test result. Empty when the pair is disabled, filtered out by
        ``args.filter`` or has no matching run directory.
    """
    results = {}
    if not need_run(dataset, method):
        return {}
    scale = "small"  # Now use small only
    is_reg = dataset.endswith("reg")
    if (args.filter == "reg" and not is_reg) or (args.filter == "pred" and is_reg):
        return {}
    dump_dir = osp.join(args.dump_dir, f"{dataset}_{scale}_{method}")
    # Compile the filter once instead of on every directory.
    pattern = re.compile(args.regex) if args.regex is not None else None
    dirs = [
        d
        for d in glob.glob(osp.join(dump_dir, "*"))
        if pattern is None or pattern.match(osp.basename(d))
    ]
    if len(dirs) > 0:
        print("-" * 40 + f" {dataset}_{scale}_{method} " + "-" * 40)
        all_stats = []
        for d in dirs:
            try:
                params, stats = get_results(d, is_reg)
                key, tie_breaker = get_name_from_params(params)
                mean = stats["mean"]
            except Exception as err:
                # Best effort: unreadable or incomplete runs are skipped.
                # KeyboardInterrupt / SystemExit are deliberately not caught.
                if args.verbose:
                    print(f"Skipping {osp.basename(d)}: {type(err).__name__}: {err}")
                continue
            all_stats.append((key, tie_breaker, stats))
            results[key] = mean
        print("Begin listing")
        # Sort on (name, tie_breaker) only: comparing whole tuples would fall
        # through to comparing the stats dicts (TypeError) when both are equal.
        all_stats.sort(key=lambda item: (item[0], item[1]))
        for name, _, stats in all_stats:
            if is_reg:
                mean = f2str_with_proper_length(stats["mean"], length=5)
                std = f2str_with_proper_length(stats["std"], length=5)
            else:
                mean = f2str(stats["mean"], precision=2)
                std = f2str(stats["std"], precision=2)
            print(f"{name}\t{mean} | {std}")
        print("")
    return results


def read_results():
    """Read the results of every requested pair and print per-name averages.

    The mean results returned by ``read`` are grouped by configuration name
    across all dataset / subname pairs; the mean and std of each group are
    then printed in sorted name order.
    """
    grouped = defaultdict(list)
    for dataset in args.dataset:
        for subname in args.subname:
            for name, value in read(dataset, subname).items():
                grouped[name].append(value)

    print("Final averaged results:")
    for name in sorted(grouped):
        stats = get_stats(grouped[name])
        mean = f2str_with_proper_length(stats["mean"], length=6)
        std = f2str_with_proper_length(stats["std"], length=6)
        print(f"{name}\t{mean} | {std}")


def main():
    """Read results when ``--get-results`` is set, else generate the bash file."""
    action = read_results if args.get_results else gen_all_cmds
    action()


if __name__ == "__main__":
    main()
