import os
import pandas as pd
import pprint
import math
import json


def read_json(path):
    """Parse the JSON file at *path* and return the decoded object."""
    with open(path, "r") as handle:
        return json.load(handle)

def get_dataset(root):
    """Infer the dataset name ("url" or "ember") from a result directory path.

    A path belongs to a dataset when one of its components is exactly that
    name; "url" is checked before "ember". The path is normalised first, so
    redundant separators (and, on Windows, "/" vs "\\") do not matter.

    Raises:
        ValueError: if neither dataset name is a component of *root*. A normal
            exception (rather than ``exit()``/``SystemExit``) lets callers such
            as ``load_all_result`` report and skip the directory.
    """
    parts = os.path.normpath(root).split(os.sep)
    for name in ("url", "ember"):
        if name in parts:
            return name
    raise ValueError(f"unknown dataset: {root}")

def get_bf_info(bf_model_folder_path):
    """Return the parsed ``info.json`` stored directly in *bf_model_folder_path*."""
    path = os.path.join(bf_model_folder_path, "info.json")
    with open(path, "r") as handle:
        return json.load(handle)

def get_ml_info(bf_model_folder_path):
    """Return the info.json of the trained XGBoost model referenced by a model folder.

    The folder's own ``info.json`` may name a ``trained_xgboost_folder_path``.
    If it does, the ``info.json`` inside that folder is parsed and returned;
    otherwise an empty dict is returned.
    """
    def _load_info(folder):
        with open(os.path.join(folder, "info.json"), "r") as handle:
            return json.load(handle)

    model_info = _load_info(bf_model_folder_path)
    if "trained_xgboost_folder_path" in model_info:
        return _load_info(model_info["trained_xgboost_folder_path"])
    return {}

# Bloom-filter group layouts understood by get_bf_config, tried in this order.
# For a group name ``g`` the config stores per-filter proportions under
# ``"g_" + g`` and per-filter false-positive rates under ``g``. A layout is
# selected when the proportion key of its first group is present.
_BLOOM_FILTER_LAYOUTS = (
    ("b", "t", "f"),
    ("b_l", "b_r", "t", "f"),
)


def _bloom_filter_mem_kib(n, pr, f):
    """Return the size in KiB of a Bloom filter for ``n * pr`` items at false-positive rate *f*.

    Bits per item is ``-ln(f) / ln(2)**2`` rounded to the nearest integer, with
    a minimum of 1. A rate ``f >= 1`` needs no filter and yields 0.0. A rate
    ``f <= 0`` is clamped to 1e-9 so the logarithm stays finite.
    """
    if f >= 1.0:
        return 0.0
    if f <= 0.0:
        f = 1e-9
    bits_per_item = max(1, int(-math.log(f) / math.log(2) / math.log(2) + 0.5))
    return (n * pr * bits_per_item) / 8 / 1024


def get_bf_config(bf_model_folder_path):
    """Load ``config.json`` from *bf_model_folder_path* and annotate it with Bloom-filter memory use.

    Returns ``{}`` when ``config.json`` does not exist. Otherwise the config
    must contain ``n`` (KeyError if missing). If it has the ``g_b`` layout
    (groups b, t, f) or, failing that, the ``g_b_l`` layout (groups b_l, b_r,
    t, f), then for every group ``g`` the config gains:

    * ``bloom_filter_<g>s``: per-filter memory in KiB, computed pairwise from
      ``config["g_" + g]`` and ``config[g]`` (``zip`` semantics: stops at the
      shorter list);
    * ``bloom_filter_<g>_mem_sum``: the sum of that list.

    It also gains ``bloom_filter_mem_sum``, the total over all groups. A config
    that matches neither layout is returned unchanged.
    """
    config_path = os.path.join(bf_model_folder_path, "config.json")
    if not os.path.exists(config_path):
        return {}
    with open(config_path, "r") as fh:
        config = json.load(fh)
    # Read up front: a config without "n" is an error even when no layout matches.
    n = config["n"]

    groups = next(
        (layout for layout in _BLOOM_FILTER_LAYOUTS if ("g_" + layout[0]) in config),
        None,
    )
    if groups is None:
        return config

    # Compute everything before writing to the config.
    per_filter = {
        g: [_bloom_filter_mem_kib(n, pr, fpr) for pr, fpr in zip(config["g_" + g], config[g])]
        for g in groups
    }
    group_sums = {g: sum(mems) for g, mems in per_filter.items()}

    # Insert keys in a fixed order: all lists, then per-group sums, then the
    # total. This keeps DataFrame column order stable for callers.
    for g in groups:
        config[f"bloom_filter_{g}s"] = per_filter[g]
    for g in groups:
        config[f"bloom_filter_{g}_mem_sum"] = group_sums[g]
    config["bloom_filter_mem_sum"] = sum(group_sums.values())
    return config

def info_merge(info1, info2, prefix=""):
    """Return a copy of *info1* extended with *info2*'s entries, keys prefixed by *prefix*.

    Neither input is modified. Entries from *info2* keep their iteration order
    and are stored under ``prefix + key``.

    Raises:
        ValueError: if a prefixed key from *info2* already exists in *info1*.
            A normal exception (rather than ``exit()``/``SystemExit``) lets
            callers such as ``load_all_result`` report and skip the entry.
    """
    info = info1.copy()
    for key, value in info2.items():
        key = prefix + key
        if key in info:
            raise ValueError(f"key {key} is already in info1")
        info[key] = value
    return info

def load_all_result(result_dir="results"):
    """Collect every ``result.json`` under *result_dir* into a single DataFrame.

    Each result is enriched with:

    * the dataset name inferred from its directory path;
    * the ``info.json`` of the folder named by its ``bf_model_folder_path``;
    * the trained XGBoost info, with keys prefixed ``trained_xgboost_``;
    * the annotated ``config.json``.

    A directory that fails to load is reported on stdout and skipped.
    Directories are visited in sorted order, so row order is reproducible.
    """
    all_results = []

    for root, dirs, files in os.walk(result_dir):
        # os.walk's order is filesystem-dependent. Sorting in place (allowed
        # for topdown walks) makes the DataFrame row order deterministic.
        dirs.sort()
        if "result.json" not in files:
            continue
        try:
            results = read_json(os.path.join(root, "result.json"))
            results["dataset"] = get_dataset(root)

            bf_info = get_bf_info(results["bf_model_folder_path"])
            ml_info = get_ml_info(results["bf_model_folder_path"])
            bf_config = get_bf_config(results["bf_model_folder_path"])

            results = info_merge(results, bf_info)
            results = info_merge(results, ml_info, prefix="trained_xgboost_")
            results = info_merge(results, bf_config)
        except Exception as e:
            # Deliberate best-effort boundary: report the bad directory and keep going.
            print(f"ERROR during loading {root}")
            print(e)
            continue
        else:
            all_results.append(results)

    return pd.DataFrame(all_results)