import copy
import json
import os

import wandb
from collections import OrderedDict

import yaml

# Shared wandb API client; every query in this file (e.g. `api.runs`) goes
# through it. It is created at import time.
api = wandb.Api()

# Project path handed to `api.runs` as the first argument when querying runs.
name_project = "daoyuan/pFL-bench"

# Dataset settings of the "main" table: {dataset_name: filter}.
# Each filter is the query passed as `filters=` to `api.runs`; insertion order
# is the order in which the datasets are processed.
#
# Disabled alternatives, kept for reference:
#   "all": None
#   "FEMNIST-all": {"$and": [{"config.data.type": "femnist"}]}
#   "cifar10-alpha05": {"$and": [{"config.data.type": "CIFAR10@torchvision"},
#                                {"config.data.splitter_args": [{"alpha": 0.5}]}]}
filters_each_line_main_table = OrderedDict()
filters_each_line_main_table["FEMNIST-s02"] = {
    "$and": [
        {"config.data.type": "femnist"},
        {"config.federate.sample_client_rate": 0.2},
        {"state": "finished"},
    ]
}
filters_each_line_main_table["sst2"] = {
    "$and": [
        {"config.data.type": "sst2@huggingface_datasets"},
    ]
}
filters_each_line_main_table["pubmed"] = {
    "$and": [
        {"config.data.type": "pubmed"},
    ]
}

# CIFAR-10 settings: {dataset_name: filter}.
# The three entries differ only in the `alpha` value of the splitter args,
# so they are generated from (name, alpha) pairs.
filters_each_line_all_cifar10 = OrderedDict(
    (
        setting_name,
        {
            "$and": [
                {"config.data.type": "CIFAR10@torchvision"},
                {"config.data.splitter_args": [{"alpha": alpha}]},
            ]
        },
    )
    for setting_name, alpha in (
        ("cifar10-alpha5", 5),
        ("cifar10-alpha05", 0.5),
        ("cifar10-alpha01", 0.1),
    )
)

# FEMNIST settings: {dataset_name: filter}.
# The entries differ only in the client sample rate; every one additionally
# requires the run state to be "finished".
filters_each_line_femnist_all_s = OrderedDict(
    (
        setting_name,
        {
            "$and": [
                {"config.data.type": "femnist"},
                {"config.federate.sample_client_rate": sample_rate},
                {"state": "finished"},
            ]
        },
    )
    for setting_name, sample_rate in (
        ("FEMNIST-s02", 0.2),
        ("FEMNIST-s01", 0.1),
        ("FEMNIST-s005", 0.05),
    )
)

# Graph dataset settings: {dataset_name: filter}.
# For these datasets the table name equals the `config.data.type` value.
filters_each_line_all_graph = OrderedDict(
    (graph_name, {"$and": [{"config.data.type": graph_name}]})
    for graph_name in ("pubmed", "cora", "citeseer")
)

# NLP dataset settings: {dataset_name: filter}.
filters_each_line_all_nlp = OrderedDict(
    (setting_name, {"$and": [{"config.data.type": data_type}]})
    for setting_name, data_type in (
        ("cola", "cola@huggingface_datasets"),
        ("sst2", "sst2@huggingface_datasets"),
    )
)


# Run header -> id of the first sweep seen under that header; filled while the
# result tables are built and used to report duplicated sweeps.
sweep_name_2_id = {}

# Summary keys read for the generalization columns. Only the first two are
# looked up in the run summary; the third column is computed from them.
column_names_generalization = [
    "best_client_summarized_weighted_avg/test_acc",
    "best_unseen_client_summarized_weighted_avg_unseen/test_acc",
    "participation_gap",
]

# Summary keys read for the fairness columns.
column_names_fair = [
    "best_client_summarized_avg/test_acc",
    "best_client_summarized_fairness/test_acc_std",
    "best_client_summarized_fairness/test_acc_bottom_decile",
]

# Summary keys read for the efficiency columns.
# "sys_avg/local_convergence_round" is intentionally not collected.
column_names_efficiency = [
    "sys_avg/total_flops",
    "sys_avg/total_upload_bytes",
    "sys_avg/total_download_bytes",
    "sys_avg/global_convergence_round",
]

# Method key -> name printed in the first column of every table.
# The insertion order is the row order of the printed tables.
sorted_keys = OrderedDict(
    [
        ("global-train", "Global Train"),
        ("isolated-train", "Isolated"),
        ("fedavg", "FedAvg"),
        ("fedavg-ft", "FedAvg-FT"),
        ("fedopt", "FedOpt"),
        ("fedopt-ft", "FedOpt-FT"),
        ("pfedme", "pFedMe"),
        ("ft-pfedme", "pFedMe-FT"),
        ("fedbn", "FedBN"),
        ("fedbn-ft", "FedBN-FT"),
        ("fedbn-fedopt", "FedBN-FedOPT"),
        ("fedbn-fedopt-ft", "FedBN-FedOPT-FT"),
        ("ditto", "Ditto"),
        ("ditto-ft", "Ditto-FT"),
        ("ditto-fedbn", "Ditto-FedBN"),
        ("ditto-fedbn-ft", "Ditto-FedBN-FT"),
        ("ditto-fedbn-fedopt", "Ditto-FedBN-FedOpt"),
        ("ditto-fedbn-fedopt-ft", "Ditto-FedBN-FedOpt-FT"),
        ("fedem", "FedEM"),
        ("fedem-ft", "FedEM-FT"),
        ("fedbn-fedem", "FedEM-FedBN"),
        ("fedbn-fedem-ft", "FedEM-FedBN-FT"),
        ("fedbn-fedem-fedopt", "FedEM-FedBN-FedOPT"),
        ("fedbn-fedem-fedopt-ft", "FedEM-FedBN-FedOPT-FT"),
    ]
)

# All method keys a table is expected to contain; used to detect methods
# without any result for a dataset.
expected_keys = set(sorted_keys)


def bytes_to_unit_size(size_bytes):
    """Format a number with a 1024-based unit suffix, e.g. ``1536 -> "1.5K"``.

    The value is divided by 1024 until it is below 1024 (or the largest known
    suffix is reached), rounded to two decimals and followed by the suffix.

    Args:
        size_bytes: the number to format (int or float).

    Returns:
        ``"0"`` for zero; otherwise ``"<value><suffix>"`` where suffix is one
        of ``"", "K", "M", "G", "T", "P", "E", "Z", "Y"``. Values below 1 keep
        the empty suffix, values beyond the "Y" range are expressed in "Y",
        and negative values are formatted by magnitude with a leading "-".
    """
    if size_bytes == 0:
        return "0"
    size_name = ("", "K", "M", "G", "T", "P", "E", "Z", "Y")
    sign = "-" if size_bytes < 0 else ""
    value = float(abs(size_bytes))
    unit_idx = 0
    # Dividing by 1024 is exact in binary floating point, so unlike a
    # logarithm it can never put an exact power of 1024 into the wrong unit.
    # The loop also clamps the index to the available suffixes on both ends.
    while value >= 1024 and unit_idx < len(size_name) - 1:
        value /= 1024
        unit_idx += 1
    return f"{sign}{round(value, 2)}{size_name[unit_idx]}"

def get_sweep_filter_by(filter_name, filters_each_line_table):
    """Return the ids of all sweeps that own at least one run matching a filter.

    Args:
        filter_name: key into ``filters_each_line_table``.
        filters_each_line_table: mapping from dataset name to the query that
            is passed as ``filters`` to ``api.runs``.

    Returns:
        A list of distinct sweep ids (order not defined, it comes from a set).
        Runs that do not belong to a sweep are counted but contribute nothing.
    """
    run_query = filters_each_line_table[filter_name]
    matched_runs = api.runs(name_project, filters=run_query)
    sweep_ids = set()
    num_checked = 0
    # may hang on (original author's remark about iterating the query result)
    for num_checked, run in enumerate(matched_runs, start=1):
        if run.sweep is None:
            continue
        sweep_ids.add(run.sweep.id)
    print(f"check_run_cnt is {num_checked}")
    return list(sweep_ids)


def get_runs_filter_by(filter_name, filters_each_line_table):
    """Query the project for the runs matching the filter stored under ``filter_name``.

    Args:
        filter_name: key into ``filters_each_line_table``.
        filters_each_line_table: mapping from dataset name to the query that
            is passed as ``filters`` to ``api.runs``.

    Returns:
        Whatever ``api.runs`` returns for that query.
    """
    return api.runs(name_project, filters=filters_each_line_table[filter_name])



# Sort key: leading "-" followed by the summary metric path of the weighted
# average validation accuracy. Not referenced by the code visible in this file.
order = "-summary_metrics.best_client_summarized_weighted_avg/val_acc"


def _remove_key_recursively(d, remove_key):
    """Delete ``remove_key`` from ``d`` and from every nested dict value, in place."""
    if isinstance(d, dict):
        for key in list(d.keys()):
            if key == remove_key:
                del d[key]
            else:
                _remove_key_recursively(d[key], remove_key)


def _format_percent_cells(values):
    """Render numeric cells as percentages with two decimals; keep "-" placeholders."""
    return ["{:.2f}".format(v * 100) if v != "-" else v for v in values]


def print_table_datasets_list(filters_each_line_table):
    """Collect the best run of every sweep per dataset and print result tables.

    For each dataset in ``filters_each_line_table`` the sweeps owning matching
    runs are looked up and the best run of each sweep is fetched. The config
    of that run is dumped as YAML into ``../yaml_best_runs`` (relative to this
    file) and its summary metrics are appended to the row of its method.
    Methods from ``sorted_keys`` without a result for a dataset are padded
    with "-" cells so that all rows stay aligned.

    Finally five comma-separated tables are printed, one row per method in
    ``sorted_keys`` order: generalization, fairness, efficiency,
    (flops, upload bytes, accuracy) and (convergence round, accuracy).

    Args:
        filters_each_line_table: ordered mapping from dataset name to the
            query used to select the runs of that dataset.

    Returns:
        None. Results are printed and configs are written to disk;
        ``sweep_name_2_id`` is updated as a side effect.
    """
    res_of_each_line_generalization = OrderedDict()
    res_of_each_line_fair = OrderedDict()
    res_of_each_line_efficiency = OrderedDict()
    res_of_each_line_commu_acc_trade = OrderedDict()
    res_of_each_line_conver_acc_trade = OrderedDict()
    res_of_all_sweeps = OrderedDict()

    size_units = ("K", "M", "G", "T", "P", "E", "Z", "Y")
    # Sub-strings that mark a header token as dataset/setting info rather
    # than as part of the method name.
    dataset_tokens = (
        "femnist", "cifar", "cora", "cola", "pubmed", "citeseer", "sst2",
        "s02", "s005", "s01",
        "alpha5", "alpha0.5", "alpha0.1",
    )
    # Directory the repeat-run command generator in this file points at.
    parent_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "..")
    best_cfg_dir = os.path.join(parent_dir, "yaml_best_runs")

    for data_name in filters_each_line_table:
        unseen_keys = copy.copy(expected_keys)
        print(f"======= processing dataset {data_name}")
        sweep_ids = get_sweep_filter_by(data_name, filters_each_line_table)
        for sweep_id in sweep_ids:
            # get_sweep_filter_by yields sweep ids, so the run to report has
            # to be fetched from the sweep first.
            best_run = api.sweep(f"{name_project}/{sweep_id}").best_run()
            if best_run is None:
                print(f"no best run found for sweep_id={sweep_id}, skip it")
                continue

            res_all_generalization = []
            res_all_fair = []
            res_all_efficiency = []
            # Reset per sweep: a missing accuracy metric invalidates only
            # the sweep it belongs to.
            wrong_sweep = False
            if best_run.state != "finished":
                print(f"==================Waring: the best_run with id={best_run} has state {best_run.state}. "
                      )

            # TODO download best cfg, then send to all, then generate the python string,
            # TODO: filters_add dataset name + seed
            best_run_cfg = best_run.config
            _remove_key_recursively(best_run_cfg, "cfg_check_funcs")
            # The config is indexed like a dict everywhere else in this function.
            run_header = best_run_cfg["expname_tag"]
            run_header = run_header.split("_")[0]

            # for generalization results
            # Isolated/global training only reports the participating-client
            # accuracy; the other two cells are placeholders.
            only_seen_clients = "isolated" in run_header.lower() or "global" in run_header.lower()
            wanted_columns = column_names_generalization[0:1] if only_seen_clients \
                else column_names_generalization[0:2]
            for column_name in wanted_columns:
                try:
                    res_all_generalization.append(best_run.summary[column_name])
                except KeyError:
                    print(f"KeyError with key={column_name}, sweep_id={sweep_id}, sweep_name={run_header}, best_run_id={best_run.id}")
                    wrong_sweep = True
            if wrong_sweep:
                # The method stays in unseen_keys and is padded with "-" below.
                continue
            if only_seen_clients:
                res_all_generalization.append("-")  # un-seen
                res_all_generalization.append("-")  # Gap
            else:
                res_all_generalization.append(res_all_generalization[-1] - res_all_generalization[-2])

            # -============== for fairness results ======
            for column_name in column_names_fair:
                if "global" in run_header:
                    # Exactly one placeholder per column keeps the row at
                    # three fairness cells per dataset.
                    res_all_fair.append("-")
                    continue
                try:
                    res_all_fair.append(best_run.summary[column_name])
                except KeyError:
                    print(f"KeyError with key={column_name}, sweep_id={sweep_id}, sweep_name={run_header}, best_run_id={best_run.id}")
                    res_all_fair.append("-")

            # -============== for efficiency results ======
            for column_name in column_names_efficiency:
                try:
                    res = best_run.summary[column_name]
                except KeyError:
                    print(f"KeyError with key={column_name}, sweep_id={sweep_id}, sweep_name={run_header}, best_run_id={best_run.id}")
                    res_all_efficiency.append("-")
                    continue
                # Values that already contain a unit letter are kept verbatim.
                if not any(size_unit in str(res) for size_unit in size_units):
                    res = bytes_to_unit_size(float(res))
                res_all_efficiency.append(res)

            old_run_header = run_header
            if best_run_cfg["trainer"]["finetune"]["before_eval"] is True and "ft" not in run_header:
                run_header = run_header + ",ft"
            elif best_run_cfg["fedopt"]["use"] is True and "fedopt" not in run_header:
                run_header = run_header + ",fedopt"
            if old_run_header != run_header:
                print(f"processed {old_run_header} to new run header {run_header}")

            if run_header not in res_of_all_sweeps:
                res_of_all_sweeps[run_header] = res_all_generalization
                sweep_name_2_id[run_header] = sweep_id
            else:
                print(f"processed duplicated sweep with name {run_header}, plz check it with id {sweep_id}. "
                      f"The first appeared sweep has id {sweep_name_2_id[run_header]}")

                # Append "_dup" until the header is unused, so that repeated
                # duplicates never overwrite each other.
                while run_header in res_of_all_sweeps:
                    run_header = run_header + "_dup"
                print(f"processed to new run header {run_header}")
                res_of_all_sweeps[run_header] = res_all_generalization

            run_header = run_header.replace("-", ",")
            run_header = run_header.replace("+", ",")
            split_res = run_header.split(",")
            filter_split_res = [
                sub for sub in split_res
                if not any(token in sub for token in dataset_tokens)
            ]
            method_header = "-".join(sorted(filter_split_res))
            if method_header in unseen_keys:
                unseen_keys.remove(method_header)

            # save config
            if method_header in sorted_keys:
                method_name = sorted_keys[method_header]
            else:
                # e.g. duplicated sweeps or methods outside sorted_keys:
                # still save the config, but such rows are never printed.
                print(f"unknown method header {method_header} for sweep_id={sweep_id}, "
                      f"it will not appear in the printed tables")
                method_name = method_header
            os.makedirs(best_cfg_dir, exist_ok=True)
            yaml_f_name = f"best_{method_name}_on_{data_name}.yaml"
            with open(os.path.join(best_cfg_dir, yaml_f_name), 'w') as yml_f:
                yaml.dump(best_run_cfg, yml_f, allow_unicode=True)

            # Extend fresh per-method lists instead of storing the per-run
            # lists themselves, which are also referenced by res_of_all_sweeps.
            res_of_each_line_generalization.setdefault(method_header, []).extend(res_all_generalization)
            res_of_each_line_fair.setdefault(method_header, []).extend(res_all_fair)
            res_of_each_line_efficiency.setdefault(method_header, []).extend(res_all_efficiency)

        for missing_header in unseen_keys:
            print(f"the header is missing {missing_header} in dataset {data_name}")
            res_of_each_line_generalization.setdefault(missing_header, []).extend(["-"] * 3)
            res_of_each_line_fair.setdefault(missing_header, []).extend(["-"] * 3)
            res_of_each_line_efficiency.setdefault(missing_header, []).extend(["-"] * 4)

    dataset_names = ",".join(list(filters_each_line_table.keys()))
    # Every dataset contributes one group of cells to each row.
    dataset_num = len(filters_each_line_table)

    print("\n=============res_of_each_line [Generalization]===============" + dataset_names)
    # Acc, Unseen-ACC, Delta
    for key in sorted_keys:
        res_to_print = _format_percent_cells(res_of_each_line_generalization.get(key, []))
        print(",".join([sorted_keys[key]] + res_to_print))

    print("\n=============res_of_each_line [Fairness]===============" + dataset_names)
    for key in sorted_keys:
        res_to_print = _format_percent_cells(res_of_each_line_fair.get(key, []))
        print(",".join([sorted_keys[key]] + res_to_print))

    print("\n=============res_of_each_line [All Efficiency]===============" + dataset_names)
    # FLOPS, UPLOAD, DOWNLOAD
    for key in sorted_keys:
        res_to_print = [str(v) for v in res_of_each_line_efficiency.get(key, [])]
        print(",".join([sorted_keys[key]] + res_to_print))

    print("\n=============res_of_each_line [flops, communication, acc]===============" + dataset_names)
    for key in sorted_keys:
        efficiency_row = res_of_each_line_efficiency.get(key, [])
        fair_row = res_of_each_line_fair.get(key, [])
        res_of_each_line_commu_acc_trade[key] = []
        for i in range(dataset_num):
            # 4 efficiency cells and 3 fairness cells per dataset.
            res_of_each_line_commu_acc_trade[key].extend(
                [str(efficiency_row[i * 4]), str(efficiency_row[i * 4 + 1])]
                + _format_percent_cells(fair_row[i * 3:i * 3 + 1])
            )
        res_to_print = [str(v) for v in res_of_each_line_commu_acc_trade[key]]
        print(",".join([sorted_keys[key]] + res_to_print))

    print("\n=============res_of_each_line [converge_round, acc]===============" + dataset_names)
    for key in sorted_keys:
        efficiency_row = res_of_each_line_efficiency.get(key, [])
        fair_row = res_of_each_line_fair.get(key, [])
        res_of_each_line_conver_acc_trade[key] = []
        for i in range(dataset_num):
            res_of_each_line_conver_acc_trade[key].extend(
                [str(efficiency_row[i * 4 + 3])]
                + _format_percent_cells(fair_row[i * 3:i * 3 + 1])
            )
        res_to_print = [str(v) for v in res_of_each_line_conver_acc_trade[key]]
        print(",".join([sorted_keys[key]] + res_to_print))
    # print("\n=============res_of_all_sweeps [Generalization]===============")
    # for key in sorted(res_of_all_sweeps.keys()):
    #     res_to_print = ["{:.2f}".format(v * 100) if v != "-" else v for v in res_of_all_sweeps[key]]
    #     res_to_print = [key] + res_to_print
    #     print(",".join(res_to_print))
    #

def generate_repeat_scripts(best_cfg_path, seed_sets=None):
    """Print one launch command per (config file, seed) pair.

    Every regular file directly inside ``best_cfg_path`` is treated as a
    config. For each seed a ``python federatedscope/main.py`` command is
    printed whose ``expname_tag`` is the file name without a trailing
    ``.yaml`` plus ``_seed<seed>``.

    Args:
        best_cfg_path: directory containing the config files.
        seed_sets: seeds to generate commands for; defaults to ``[2, 3]``.

    Returns:
        None. The commands and progress messages are printed to stdout.
    """
    from os import listdir
    from os.path import isfile, join

    if seed_sets is None:
        seed_sets = [2, 3]
    file_cnt = 0
    # listdir returns entries in arbitrary order; sort for reproducible output.
    onlyfiles = sorted(f for f in listdir(best_cfg_path) if isfile(join(best_cfg_path, f)))
    # TODO: repeat runs with new project: pfl-bench-best-repeat;  cfg.expname_tag = yaml_name + seed
    for file_name in onlyfiles:
        # Strip only a trailing ".yaml". The name is not split on "_" because
        # method and dataset names may themselves contain underscores.
        if file_name.endswith(".yaml"):
            exp_name = file_name[:-len(".yaml")]
        else:
            exp_name = file_name
        for seed in seed_sets:
            print(f"python federatedscope/main.py --cfg scripts/personalization_exp_scripts/pfl_bench/yaml_best_runs/{file_name} seed {seed} expname_tag {exp_name}_seed{seed} wandb.name_project pfl-bench-best-repeat")
            file_cnt += 1
            if file_cnt % 10 == 0:
                print(f"Seed={seed}, totally generated {file_cnt} run scripts\n\n")

    print(f"Seed={seed_sets}, totally generated {file_cnt} run scripts")
    print("=============================== END ===============================")


def generate_res_table():
    """Print the result tables of every dataset group defined in this file.

    The groups are processed in a fixed order: main table, FEMNIST sample
    rates, CIFAR-10 alphas, NLP datasets, graph datasets.
    """
    for table_filters in (
        filters_each_line_main_table,
        filters_each_line_femnist_all_s,
        filters_each_line_all_cifar10,
        filters_each_line_all_nlp,
        filters_each_line_all_graph,
    ):
        print_table_datasets_list(table_filters)


# Script entry: runs whenever this module is executed or imported.
# The generator is called once per seed, so the printed commands are grouped
# by seed rather than by config file.
seed_sets = [2, 3]
_BEST_CFG_DIR = "/mnt/daoyuanchen.cdy/FederatedScope/scripts/personalization_exp_scripts/pfl_bench/yaml_best_runs"
for seed in seed_sets:
    generate_repeat_scripts(_BEST_CFG_DIR, seed_sets=[seed])