import os
import json
from collections import defaultdict

# =========================
# Category mapping (as provided): major category -> sub-category -> evaluation function names.
# Order matters: it drives the order of every report section below.
# =========================
evaluation_functions = {
    "Basic Feasibility": {
        "JSON Structural Validity": ["validate_trip_plan_json"],
        "POI Validity": ["evaluate_id_existence", "evaluate_city_consistency"],
        "Information Completeness": ["evaluate_meta"],
    },
    "Planning Soundness": {
        "Temporal Reasonableness": [
            "evaluate_daily_timing",
            "evaluate_activity_duration",
            "evaluate_opening_hours",
            "evaluate_intercity_transportation",
            "evaluate_local_transportation",
        ],
        "Spatial Logic": ["evaluate_restaurant_location"],
        "Experience Diversity": ["evaluate_attraction_restaurant_diversity"],
        "Product Consistency": ["evaluate_product_requirements"],
    },
}

# =========================
# Paths
# =========================
# Trace directory to evaluate. The single-turn run lives at
#   "interact/traces/test_easy_single_turn/LongCat-Flash-Chat"
# (previously assigned here and immediately overwritten; swap it in to evaluate that run).
base_dir = "interact/traces/test_easy_multi_turn/LongCat-Flash-Chat"
testset_path = "interact/test/test_easy.json"


# =========================
# 工具函数
# =========================
def load_json(path):
    """Decode and return the JSON document stored in the UTF-8 file at *path*."""
    with open(path, encoding="utf-8") as fh:
        payload = json.load(fh)
    return payload


def safe_div(a, b):
    """Return ``a / b``, or ``0.0`` when the divisor is falsy (0, None, ...)."""
    if not b:
        return 0.0
    return a / b


def get_last_round_dict(rounds):
    """Return the last element of *rounds* that is a dict.

    Scans from the end of the list. Returns None when *rounds* is not a list
    or contains no dict.
    """
    if not isinstance(rounds, list):
        return None
    return next((entry for entry in reversed(rounds) if isinstance(entry, dict)), None)


def build_func_maps(evaluation_functions):
    """Index the two-level category mapping by function.

    Args:
        evaluation_functions: ``{major: {sub: [func_name, ...]}}``.

    Returns:
        A 4-tuple ``(func2group, all_funcs, major2funcs, sub2funcs)``:
        - ``func2group``: func -> ``(major, sub)``. If a function is listed more
          than once, the last listing wins.
        - ``all_funcs``: distinct function names in first-seen order.
        - ``major2funcs``: defaultdict(list), major -> funcs in listing order.
        - ``sub2funcs``: defaultdict(list), ``(major, sub)`` -> funcs in listing order.
        Groups with no functions get no entry in the defaultdicts.
    """
    # Flatten first so the three indexes below are all built from one ordered pass.
    triples = [
        (major, sub, fn)
        for major, submap in evaluation_functions.items()
        for sub, funcs in submap.items()
        for fn in funcs
    ]

    func2group = {}
    major2funcs = defaultdict(list)
    sub2funcs = defaultdict(list)
    for major, sub, fn in triples:
        func2group[fn] = (major, sub)
        major2funcs[major].append(fn)
        sub2funcs[(major, sub)].append(fn)

    return func2group, list(func2group), major2funcs, sub2funcs


def get_expected_rubric_keys(sample):
    """Return the set of rubric keys a test sample expects.

    Each key counts once per sample, regardless of how long its modification
    chain is. The keys come from the first of these fields that holds a
    non-empty dict:
    1. ``applied_modification_chains``
    2. ``rubric_results`` (fallback)
    An empty set is returned when neither qualifies.
    """
    for field in ("applied_modification_chains", "rubric_results"):
        candidate = sample.get(field)
        if isinstance(candidate, dict) and candidate:
            return set(candidate)
    return set()


# =========================
# 1) Load the test set and build the rubric denominators (per rubric_key)
# =========================
testset = load_json(testset_path)

# Accept either a bare list of samples or an object of the form {"data": [...]}.
if isinstance(testset, list):
    testset_list = testset
elif isinstance(testset, dict) and isinstance(testset.get("data"), list):
    testset_list = testset["data"]
else:
    raise ValueError(f"Unrecognized testset format in {testset_path}")

# k is the fixed denominator used by every ratio reported below.
k = len(testset_list)

test_trip_ids = set()

# trip_id -> expected rubric key set (used by the strict per-trip check)
expected_keys_per_trip = {}

# rubric_key -> number of test samples expecting it (each sample counts a key once)
rubric_expected_counter = defaultdict(int)

# Samples that can never be matched one-to-one with a trace. These are either
# samples with no trip_id, or repeats of an earlier trip_id. For repeats, only
# the last sample's keys survive in expected_keys_per_trip, but every repeat
# still adds to rubric_expected_counter.
samples_without_trip_id = 0
duplicate_trip_ids = []

for sample in testset_list:
    trip_id = sample.get("trip_id")
    keys = get_expected_rubric_keys(sample)

    if trip_id:
        if trip_id in test_trip_ids and trip_id not in duplicate_trip_ids:
            duplicate_trip_ids.append(trip_id)
        test_trip_ids.add(trip_id)
        expected_keys_per_trip[trip_id] = keys
    else:
        samples_without_trip_id += 1

    for rk in keys:
        rubric_expected_counter[rk] += 1

total_expected_rubrics = sum(rubric_expected_counter.values())

print(f"[INFO] testset k = {k}")
print(f"[INFO] total expected rubric keys (per-trip each key counted once) = {total_expected_rubrics}")
if samples_without_trip_id:
    print(f"[WARN] {samples_without_trip_id} testset sample(s) have no trip_id; "
          f"they count toward k but can never be matched to a trace")
if duplicate_trip_ids:
    print(f"[WARN] {len(duplicate_trip_ids)} trip_id(s) appear more than once in the testset "
          f"(e.g. {duplicate_trip_ids[:5]}); their rubric keys are counted per sample "
          f"but only one trace per trip_id is evaluated")


# =========================
# 2) Accumulators for the general (rule-based) checks
# =========================
func2group, all_funcs, major2funcs, sub2funcs = build_func_maps(evaluation_functions)

# Per-sample counters; every ratio printed later divides by k (the test-set size).
general_all_ok_count = validate_trip_plan_ok_count = 0  # all general funcs passed / validate_trip_plan_json passed
func_true_counts = defaultdict(int)      # func -> samples where that function passed
major_all_ok_counts = defaultdict(int)   # major -> samples where every func of the major category passed
sub_all_ok_counts = defaultdict(int)     # (major, sub) -> samples where every func of the sub-category passed

# =========================
# 3) Accumulators for the rubric checks (per rubric_key; each key counted once per trip)
# =========================
# Macro, strict per trip: every expected key must be seen and never fail.
rubric_all_correct_trip_count = general_and_rubric_all_correct = 0

# Per-key counts, feeding the two accuracies (missing ignored / missing counted as fail).
rubric_key_seen_counts = defaultdict(int)  # rk -> trips where rk appeared in the eval output at least once
rubric_key_pass_counts = defaultdict(int)  # rk -> trips where rk appeared and never failed

# Micro totals: pass / expected (missing counts as fail) and pass / seen (missing ignored).
total_key_pass = total_key_seen = 0

# Global diagnostics over the expected keys of processed traces.
rubric_key_missing_total = 0  # expected key never appeared in the eval output (seen == 0)
rubric_key_failed_total = 0   # expected key appeared but had at least one failure


# =========================
# 4) Walk the trace files (only the eval_result of the last dict in `rounds` is used)
# =========================
seen_trip_ids = set()
used_samples = 0
skipped_not_in_testset = 0
skipped_no_eval = 0

# os.walk on a missing directory yields nothing, which would silently report 0% everywhere.
if not os.path.isdir(base_dir):
    print(f"[WARN] trace directory not found: {base_dir}")

for root, dirs, files in os.walk(base_dir):
    # os.walk order depends on the filesystem. Sorting makes "first trace per
    # trip_id wins" pick the same file on every machine. Sorting `dirs` in place
    # fixes the order of the subdirectories os.walk descends into.
    dirs.sort()
    for filename in sorted(files):
        if not filename.endswith(".json"):
            continue

        file_path = os.path.join(root, filename)
        try:
            data = load_json(file_path)
        except Exception as e:
            print(f"[跳过] 读取失败 {file_path}: {e}")
            continue

        # A trace must be a JSON object; anything else would raise AttributeError on .get().
        if not isinstance(data, dict):
            print(f"[跳过] top-level JSON is not an object: {file_path}")
            skipped_no_eval += 1
            continue

        trip_id = data.get("trip_id")

        # Deduplicate by trip_id (first file wins), then keep only test-set trips.
        if trip_id:
            if trip_id in seen_trip_ids:
                continue
            seen_trip_ids.add(trip_id)
            if test_trip_ids and (trip_id not in test_trip_ids):
                skipped_not_in_testset += 1
                continue

        rounds = data.get("rounds", [])
        last_round = get_last_round_dict(rounds)
        if not last_round:
            skipped_no_eval += 1
            continue

        eval_result = last_round.get("eval_result", {})
        if not isinstance(eval_result, dict):
            skipped_no_eval += 1
            continue

        used_samples += 1

        # -------- general_results: pass/fail per evaluation function name
        general_results = eval_result.get("general_results", [])
        func_ok_map = {}  # func name -> True only when "ok" is literally True
        if isinstance(general_results, list):
            for g in general_results:
                if not isinstance(g, dict):
                    continue
                name = g.get("name")
                if name:
                    func_ok_map[name] = g.get("ok") is True

        # A function absent from general_results counts as failed.
        all_general_ok = True
        for fn in all_funcs:
            if func_ok_map.get(fn, False):
                func_true_counts[fn] += 1
            else:
                all_general_ok = False

        if func_ok_map.get("validate_trip_plan_json", False):
            validate_trip_plan_ok_count += 1

        if all_general_ok:
            general_all_ok_count += 1

        # Macro: a sample passes a group only if every function in the group passed.
        for major, funcs in major2funcs.items():
            if all(func_ok_map.get(fn, False) for fn in funcs):
                major_all_ok_counts[major] += 1

        for (major, sub), funcs in sub2funcs.items():
            if all(func_ok_map.get(fn, False) for fn in funcs):
                sub_all_ok_counts[(major, sub)] += 1

        # -------- rubric: aggregate per rubric_key (each key counts once per trip)
        id_true_results = eval_result.get("id_true_results", [])
        id_false_results = eval_result.get("id_false_results", [])
        # A null/non-list value (e.g. "id_true_results": null) would make the loops
        # below raise TypeError; treat it as "no results", like general_results above.
        if not isinstance(id_true_results, list):
            id_true_results = []
        if not isinstance(id_false_results, list):
            id_false_results = []

        # rk -> {"seen": occurrences in the eval output, "false": failed occurrences}
        rk_stats = defaultdict(lambda: {"seen": 0, "false": 0})

        # Entries in id_true_results still count as failures unless "ok" is literally True (defensive).
        for item in id_true_results:
            if not isinstance(item, dict):
                continue
            rk = item.get("rubric_key")
            if not rk:
                continue
            rk_stats[rk]["seen"] += 1
            if item.get("ok") is not True:
                rk_stats[rk]["false"] += 1

        # Every entry in id_false_results counts as a failure.
        for item in id_false_results:
            if not isinstance(item, dict):
                continue
            rk = item.get("rubric_key")
            if not rk:
                continue
            rk_stats[rk]["seen"] += 1
            rk_stats[rk]["false"] += 1

        expected_keys = expected_keys_per_trip.get(trip_id) if trip_id else None
        # An empty expected set passes vacuously.
        rubric_all_correct = True

        if expected_keys is not None:
            for rk in expected_keys:
                st = rk_stats.get(rk, {"seen": 0, "false": 0})

                if st["seen"] > 0:
                    rubric_key_seen_counts[rk] += 1
                    total_key_seen += 1

                if st["seen"] <= 0:
                    # missing: counted as a failure under the "with missing" metrics
                    rubric_all_correct = False
                    rubric_key_missing_total += 1
                    continue

                if st["false"] > 0:
                    rubric_all_correct = False
                    rubric_key_failed_total += 1
                    continue

                # pass: seen > 0 and no failures
                rubric_key_pass_counts[rk] += 1
                total_key_pass += 1
        else:
            # No expected keys are known for this trace (no trip_id, or the trip_id
            # is absent from expected_keys_per_trip), so the strict check cannot pass.
            rubric_all_correct = False

        if rubric_all_correct:
            rubric_all_correct_trip_count += 1

        if all_general_ok and rubric_all_correct:
            general_and_rubric_all_correct += 1

# =========================
# 5) Report: general-check micro/macro rates, per category level
# =========================
print("\n====================")
print("GENERAL 统计（分母固定 k）")
print("====================\n")

# Trace-processing counters. Every ratio below divides by the test-set size k,
# not by the number of traces actually used.
print(f"[INFO] used_samples_from_traces = {used_samples}")
print(f"[INFO] skipped_not_in_testset = {skipped_not_in_testset}")
print(f"[INFO] skipped_no_eval = {skipped_no_eval}")
print(f"[INFO] NOTE: 所有比例分母都用 testset k={k}（不是 used_samples）。\n")

# Headline rates: JSON-structure validity, and samples where every general function passed.
print(f"① validate_trip_plan_json 通过率: {validate_trip_plan_ok_count} / {k} = {safe_div(validate_trip_plan_ok_count, k):.2%}")
print(f"② general 全部函数全对比例:       {general_all_ok_count} / {k} = {safe_div(general_all_ok_count, k):.2%}\n")

# Per-function pass rate: samples where the function passed / k.
print("---- 单个 general function 通过率（true_count/k）----")
for fn in all_funcs:
    cnt = func_true_counts.get(fn, 0)
    print(f"{fn:35s} {cnt:4d} / {k} = {safe_div(cnt, k):.2%}")
print()

def report_group(group_name, funcs, macro_count):
    """Print the micro and macro pass rates of one group of general functions.

    micro = (sum of per-function pass counts) / (k * number of functions)
    macro = (samples in which every function of the group passed) / k

    Prints nothing when *funcs* is empty. Reads the module-level
    ``func_true_counts`` and ``k``.
    """
    n = len(funcs)
    if not n:
        return
    true_sum = 0
    for f in funcs:
        true_sum += func_true_counts.get(f, 0)
    macro = safe_div(macro_count, k)
    micro = safe_div(true_sum, k * n)
    print(f"{group_name}")
    print(f"  micro: {true_sum} / ({k}*{n}) = {micro:.2%}")
    print(f"  macro: {macro_count} / {k} = {macro:.2%}")


# Major categories: micro and macro pass rates, in mapping order.
print("---- 大类 Basic Feasibility / Planning Soundness micro&macro ----")
for major, funcs in major2funcs.items():
    report_group(f"[{major}]", funcs, major_all_ok_counts.get(major, 0))
print()

# Sub-categories: micro and macro pass rates, labelled "[major :: sub]".
print("---- 子类 micro&macro ----")
for (major, sub), funcs in sub2funcs.items():
    report_group(f"[{major} :: {sub}]", funcs, sub_all_ok_counts.get((major, sub), 0))
print()


# =========================
# 6) Report: rubric statistics (two accuracies: ignoring missing / counting missing)
# =========================
print("\n====================")
print("RUBRIC 统计（按 rubric_key；每个 trip 每个 key 只算 1 条 rubric）")
print("====================\n")

print(f"[INFO] total expected rubric keys = {total_expected_rubrics}")

# Macro: strict per-trip check (a missing key also fails the trip).
print(f"[INFO] rubric macro（trip strict：该 trip 所有 expected keys seen>0 且无 false） = "
      f"{rubric_all_correct_trip_count} / {k} = {safe_div(rubric_all_correct_trip_count, k):.2%}")
print(f"[INFO] general+rubric strict(all-correct) = "
      f"{general_and_rubric_all_correct} / {k} = {safe_div(general_and_rubric_all_correct, k):.2%}\n")

# Micro, two variants:
#   with missing:   pass / expected -> any expected key not passed counts as a failure
#   ignore missing: pass / seen     -> accuracy only over keys that appeared in eval output
rubric_micro_with_missing = safe_div(total_key_pass, total_expected_rubrics)
rubric_micro_ignore_missing = safe_div(total_key_pass, total_key_seen)

print(f"rubric micro（考虑 missing）: {total_key_pass} / {total_expected_rubrics} = {rubric_micro_with_missing:.2%}")
print(f"rubric micro（不考虑 missing）: {total_key_pass} / {total_key_seen} = {rubric_micro_ignore_missing:.2%}\n")

# Per-rubric table (both accuracies), most frequently expected keys first.
print("---- 每条 rubric（按 rubric_key）正确率：不考虑 missing vs 考虑 missing ----")
print(f"{'rubric_key':35s} "
      f"{'pass':>6s} "
      f"{'seen':>6s} "
      f"{'expected':>9s} "
      f"{'acc_no_missing':>15s} "
      f"{'acc_with_missing':>17s} "
      f"{'missing':>8s}")

for rk, expected_cnt in sorted(rubric_expected_counter.items(), key=lambda x: x[1], reverse=True):
    pass_cnt = rubric_key_pass_counts.get(rk, 0)
    seen_cnt = rubric_key_seen_counts.get(rk, 0)
    # Includes samples whose trace was never evaluated, not only keys absent from an evaluated trace.
    missing_cnt = expected_cnt - seen_cnt

    acc_no_missing = safe_div(pass_cnt, seen_cnt)          # ignore missing: only over seen samples
    acc_with_missing = safe_div(pass_cnt, expected_cnt)    # count missing: missing treated as fail

    print(f"{rk:35s} "
          f"{pass_cnt:6d} "
          f"{seen_cnt:6d} "
          f"{expected_cnt:9d} "
          f"{acc_no_missing:15.2%} "
          f"{acc_with_missing:17.2%} "
          f"{missing_cnt:8d}")

# Global missing/failed diagnostics, counted over expected keys.
print("\n---- [DIAG] 期望 key 维度的缺失/失败统计（按 expected key 计数）----")
print(f"[DIAG] expected-key missing(total) = {rubric_key_missing_total}  "
      f"({safe_div(rubric_key_missing_total, total_expected_rubrics):.2%} of expected)")
print(f"[DIAG] expected-key failed(total)  = {rubric_key_failed_total}  "
      f"({safe_div(rubric_key_failed_total, total_expected_rubrics):.2%} of expected)")
# missing/failed above only cover traces that were actually evaluated. This is
# the remainder: expected keys of samples that never got judged. That covers
# samples with no trip_id, earlier duplicates of a trip_id, and trips with no
# usable trace. With it, pass + failed + missing + no-trace == total expected.
rubric_key_untraced_total = (
    total_expected_rubrics - total_key_pass - rubric_key_failed_total - rubric_key_missing_total
)
print(f"[DIAG] expected-key no-trace(total) = {rubric_key_untraced_total}  "
      f"({safe_div(rubric_key_untraced_total, total_expected_rubrics):.2%} of expected)")
