import os
import json
from collections import defaultdict
from datetime import datetime

# =========================
# Category mapping (as provided): major category -> sub-category -> list of
# general evaluation function names. The names are matched against the "name"
# field of the entries in eval_result["general_results"] when scoring samples.
# =========================
evaluation_functions = {
    "Basic Feasibility": {
        "JSON Structural Validity": [
            "validate_trip_plan_json"
        ],
        "POI Validity": [
            "evaluate_id_existence",
            "evaluate_city_consistency"
        ],
        "Information Completeness": [
            "evaluate_meta"
        ]
    },
    "Planning Soundness": {
        "Temporal Reasonableness": [
            "evaluate_daily_timing",
            "evaluate_activity_duration",
            "evaluate_opening_hours",
            "evaluate_intercity_transportation",
            "evaluate_local_transportation"
        ],
        "Spatial Logic": [
            "evaluate_restaurant_location"
        ],
        "Experience Diversity": [
            "evaluate_attraction_restaurant_diversity"
        ],
        "Product Consistency": [
            "evaluate_product_requirements"
        ]
    }
}

# =========================
# Paths
# =========================
# NOTE(review): the path literals in this section were garbled in this copy of
# the file (unterminated string literals -> SyntaxError for the whole script).
# They are restored as placeholders that can be overridden via environment
# variables; point them at the real files before running.
#
# traces: a .jsonl file (one JSON object per line), a single .json file, or a
# directory that is searched recursively for .json files.
trace_path = os.environ.get("TRACE_PATH", "traces.jsonl")
# Target file / testset: a list[dict] of samples (or {"data": list[dict]}).
testset_path = os.environ.get("TESTSET_PATH", "testset.json")

# =========================
# Utility functions
# =========================
def load_json(path):
    """Read the file at *path* as UTF-8 text and return the decoded JSON value."""
    with open(path, encoding="utf-8") as handle:
        payload = json.load(handle)
    return payload


def safe_div(a, b):
    """Return a / b, or 0.0 when the divisor is falsy (0, 0.0, None, ...)."""
    if not b:
        return 0.0
    return a / b


def parse_time(s):
    """
    Parse an ISO-8601 timestamp string into a ``datetime``; return None on failure.

    A trailing ``Z`` (UTC designator) is rewritten to ``+00:00`` first, because
    ``datetime.fromisoformat`` only accepts ``Z`` from Python 3.11 onward. On
    older interpreters such timestamps would otherwise silently parse as None.

    Args:
        s: value to parse; falsy values (None, "") yield None.

    Returns:
        A naive or timezone-aware ``datetime`` (whatever the string encodes),
        or None if *s* is empty, not a string, or not in ISO format.
    """
    if not s:
        return None
    if isinstance(s, str) and s.endswith("Z"):
        s = s[:-1] + "+00:00"
    try:
        return datetime.fromisoformat(s)
    except (TypeError, ValueError):
        # Best-effort: non-string or malformed timestamps are treated as absent.
        return None


def get_last_round_dict(rounds):
    """
    Return the last element of *rounds* that is a dict.

    The list is scanned from the end; non-dict entries are skipped. Returns
    None when *rounds* is not a list or contains no dict at all.
    """
    if isinstance(rounds, list):
        return next((entry for entry in reversed(rounds) if isinstance(entry, dict)), None)
    return None


def build_func_maps(evaluation_functions):
    """
    Index a nested {major: {sub: [func, ...]}} mapping.

    Returns:
        A 4-tuple ``(func2group, all_funcs, major2funcs, sub2funcs)``:
        - func2group: func name -> (major, sub); a name listed twice keeps its
          last group.
        - all_funcs: distinct func names in first-seen order.
        - major2funcs: defaultdict(list) major -> func names.
        - sub2funcs: defaultdict(list) (major, sub) -> func names.
        Groups with an empty function list get no entry in the defaultdicts.
    """
    func2group = {}
    major2funcs = defaultdict(list)
    sub2funcs = defaultdict(list)

    for major, submap in evaluation_functions.items():
        for sub, funcs in submap.items():
            group = (major, sub)
            for fn in funcs:
                func2group[fn] = group
                major2funcs[major].append(fn)
                sub2funcs[group].append(fn)

    return func2group, list(func2group), major2funcs, sub2funcs


def get_expected_rubric_keys(sample):
    """
    Return the set of rubric keys a testset sample expects (each key once).

    Uses the keys of ``applied_modification_chains`` if it is a non-empty
    dict; otherwise the keys of ``rubric_results`` if that is a non-empty
    dict; otherwise an empty set.
    """
    for field in ("applied_modification_chains", "rubric_results"):
        candidate = sample.get(field)
        if isinstance(candidate, dict) and candidate:
            return set(candidate)
    return set()


def final_id_list_from_test_sample(sample):
    """
    Derive a testset sample's final id list.

    If ``applied_modification_chains`` is a non-empty dict, return the last id
    of every chain that is a non-empty list (in dict order). In that case no
    fallback is attempted, even if the result is empty. Otherwise return
    ``sample["final_id_list"]`` when it is a list, else ``[]``.
    """
    chains = sample.get("applied_modification_chains")
    if isinstance(chains, dict) and chains:
        return [chain[-1] for chain in chains.values() if isinstance(chain, list) and chain]

    fallback = sample.get("final_id_list")
    return fallback if isinstance(fallback, list) else []


def final_id_list_from_trace_record(data):
    """
    Return a trace record's final id list.

    Looks at ``data["summary"]["final_id_list"]`` first, then at a top-level
    ``data["final_id_list"]``; the first one that is a list wins. Returns
    ``[]`` if neither is a list.
    """
    summary = data.get("summary") or {}
    for candidate in (summary.get("final_id_list"), data.get("final_id_list")):
        if isinstance(candidate, list):
            return candidate
    return []


def canonical_final_id_list(final_id_list):
    """
    Normalise an id list into an order-independent, hashable tuple.

    Every non-None element is converted with ``str``; duplicates are dropped
    and the result is sorted, so lists with the same ids in any order (or with
    repeats) map to the same tuple. Anything that is not a list yields ``()``.
    """
    if isinstance(final_id_list, list):
        unique_ids = {str(item) for item in final_id_list if item is not None}
        return tuple(sorted(unique_ids))
    return ()


def make_sample_key(trip_id, final_id_list):
    """
    Build the key used to align testset samples with trace records.

    Returns ``(trip_id, canonical ids)``. The trip_id alone can repeat, so the
    canonicalised final id list is part of the key.
    """
    canonical_ids = canonical_final_id_list(final_id_list)
    return trip_id, canonical_ids


def get_rounds_from_trace(data):
    """
    Return the list of rounds stored in a trace record.

    A top-level ``data["rounds"]`` list takes precedence; otherwise
    ``data["summary"]["rounds"]`` is used. Returns ``[]`` if neither is a list.
    """
    top_level = data.get("rounds")
    if isinstance(top_level, list):
        return top_level
    nested = (data.get("summary") or {}).get("rounds")
    return nested if isinstance(nested, list) else []


def iter_trace_records(path_or_dir):
    """
    Yield trace records (parsed JSON values) one at a time.

    - Directory: recursively load every ``*.json`` file beneath it; ``*.jsonl``
      files inside a directory are not read. Sub-directories and files are
      visited in sorted order. ``os.walk`` yields entries in arbitrary
      filesystem order, and the caller keeps the first record on score ties,
      so an unsorted walk would make results depend on the filesystem.
    - ``*.jsonl`` file: one JSON value per non-blank line.
    - ``*.json`` file: the whole file is one JSON value.

    Files or lines that fail to read or parse are reported and skipped.

    Args:
        path_or_dir: a ``str`` or any ``os.PathLike`` (e.g. ``pathlib.Path``).

    Raises:
        ValueError: if the path is neither a directory nor a ``.json`` /
            ``.jsonl`` file.
    """
    path = os.fspath(path_or_dir)  # accept pathlib.Path as well as str

    if os.path.isdir(path):
        for root, dirs, files in os.walk(path):
            dirs.sort()  # sorting in place steers the order os.walk descends in
            for filename in sorted(files):
                if not filename.endswith(".json"):
                    continue
                file_path = os.path.join(root, filename)
                try:
                    record = load_json(file_path)
                except Exception as e:  # best-effort: one bad file must not abort the scan
                    print(f"[跳过] 读取失败 {file_path}: {e}")
                    continue
                yield record
        return

    if path.endswith(".jsonl"):
        with open(path, "r", encoding="utf-8") as f:
            for ln, line in enumerate(f, start=1):
                line = line.strip()
                if not line:
                    continue
                try:
                    record = json.loads(line)
                except Exception as e:
                    print(f"[跳过] 读取失败 {path} line={ln}: {e}")
                    continue
                yield record
        return

    if path.endswith(".json"):
        try:
            record = load_json(path)
        except Exception as e:
            print(f"[跳过] 读取失败 {path}: {e}")
            return
        yield record
        return

    raise ValueError(f"Unsupported path: {path}")


def record_score(data, last_round):
    """
    Rank trace records that share a sample_key so the "best" one can be kept.

    The returned tuple compares lexicographically (larger is better):
      1) a last round carrying a dict ``eval_result`` wins;
      2) then a larger ``round_idx`` of the last round;
      3) then a later end time. It is taken from the last round, else the
         record's top-level ``end_time``, else ``summary.end_time``; a missing
         end time counts as 0.

    Args:
        data: one trace record (dict).
        last_round: the last dict in the record's rounds, or None.

    Returns:
        ``(has_eval, round_idx, end_timestamp)``. ``round_idx`` is always
        numeric: a missing, null or non-numeric value is treated as 0, so
        tuples from different records can always be compared. A ``None`` here
        would make ``>`` raise TypeError whenever ``has_eval`` ties.
    """
    is_round = isinstance(last_round, dict)
    has_eval = 1 if is_round and isinstance(last_round.get("eval_result"), dict) else 0

    ridx = last_round.get("round_idx", 0) if is_round else 0
    if not isinstance(ridx, (int, float)):
        ridx = 0

    endt = parse_time(last_round.get("end_time")) if is_round else None
    if endt is None:
        endt = parse_time(data.get("end_time"))
    if endt is None:
        endt = parse_time((data.get("summary") or {}).get("end_time"))

    endt_ts = endt.timestamp() if endt else 0
    return (has_eval, ridx, endt_ts)


# =========================
# 1) Load the testset: build sample_key -> expected rubric keys, and count the
#    rubric denominators.
# =========================
testset = load_json(testset_path)

# Accept either a bare list of samples or an object wrapping it as {"data": [...]}.
if isinstance(testset, dict):
    if isinstance(testset.get("data"), list):
        testset_list = testset["data"]
    else:
        raise ValueError(f"Unrecognized testset format in {testset_path}")
elif isinstance(testset, list):
    testset_list = testset
else:
    raise ValueError(f"Unrecognized testset format in {testset_path}")

k = len(testset_list)  # denominator = number of records in the target file (same as before), even if some sample_keys repeat

expected_keys_per_sample = {}           # sample_key -> set(expected rubric_keys)
test_sample_keys = set()               # every sample_key (used to align trace records)
rubric_expected_counter = defaultdict(int)  # rubric_key -> number of testset records expecting it

dup_keys_in_testset = 0

for sample in testset_list:
    trip_id = sample.get("trip_id")
    fil = final_id_list_from_test_sample(sample)
    sk = make_sample_key(trip_id, fil)

    if sk in test_sample_keys:
        dup_keys_in_testset += 1
    test_sample_keys.add(sk)

    keys = get_expected_rubric_keys(sample)
    # For a duplicated sample_key the later record's keys replace the earlier ones here,
    # while rubric_expected_counter below still counts the keys of every record.
    expected_keys_per_sample[sk] = keys

    for rk in keys:
        rubric_expected_counter[rk] += 1

total_expected_rubrics = sum(rubric_expected_counter.values())

print(f"[INFO] testset k = {k}")
print(f"[INFO] unique sample_key in testset = {len(test_sample_keys)}")
if dup_keys_in_testset:
    print(f"[WARN] duplicated sample_key in testset = {dup_keys_in_testset}  (后写会覆盖前写的 expected_keys_per_sample)")
print(f"[INFO] total expected rubric keys (per-sample each key counted once) = {total_expected_rubrics}")


# =========================
# 2) Setup for the general-function statistics
# =========================
func2group, all_funcs, major2funcs, sub2funcs = build_func_maps(evaluation_functions)

# Sample-level counters (incremented once per evaluated sample; reported over k).
general_all_ok_count, validate_trip_plan_ok_count = 0, 0

# Keyed counters:
#   func_true_counts:    function name -> #samples where it reported ok=True
#   major_all_ok_counts: major category -> #samples where all its functions passed
#   sub_all_ok_counts:   (major, sub) -> #samples where all its functions passed
func_true_counts, major_all_ok_counts, sub_all_ok_counts = (
    defaultdict(int),
    defaultdict(int),
    defaultdict(int),
)


# =========================
# 3) Setup for the rubric statistics
# =========================
# Sample-level strict counters.
rubric_all_correct_trip_count, general_and_rubric_all_correct = 0, 0

# Per-rubric-key counters, at most one increment per (sample_key, rubric_key):
#   rubric_key_pass_counts: key was evaluated (seen > 0) and had no failing item
#   rubric_key_seen_counts: key was evaluated at least once
rubric_key_pass_counts, rubric_key_seen_counts = defaultdict(int), defaultdict(int)

# Totals across all keys and samples.
total_key_pass = total_key_seen = 0
rubric_key_missing_total = rubric_key_failed_total = 0


# =========================
# 4) Scan the traces first and aggregate by sample_key (resolves repeated trip_ids)
# =========================
best_record_by_key = {}   # sample_key -> (score, data, last_round) of the best-scoring record
seen_key_any = set()      # sample_keys that appeared in the traces (with or without eval)

skipped_not_in_testset = 0
trace_records_total = 0
trace_records_in_testset = 0

for data in iter_trace_records(trace_path):
    trace_records_total += 1

    # A parsed record that is not a JSON object (e.g. a .json file holding a
    # list, or a bare value on a .jsonl line) has no trip_id / final_id_list.
    # Skip it, as unreadable records are skipped, instead of crashing on data.get(...).
    if not isinstance(data, dict):
        print(f"[跳过] 非 dict 的 trace 记录 #{trace_records_total}: {type(data).__name__}")
        continue

    trip_id = data.get("trip_id") or (data.get("summary") or {}).get("trip_id")
    fil = final_id_list_from_trace_record(data)
    sk = make_sample_key(trip_id, fil)

    if sk not in test_sample_keys:
        skipped_not_in_testset += 1
        continue

    trace_records_in_testset += 1
    seen_key_any.add(sk)

    rounds = get_rounds_from_trace(data)
    last_round = get_last_round_dict(rounds)
    sc = record_score(data, last_round)

    # Strict ">" keeps the earliest record on ties, so the choice depends on the
    # order in which iter_trace_records yields records.
    if sk not in best_record_by_key or sc > best_record_by_key[sk][0]:
        best_record_by_key[sk] = (sc, data, last_round)


# =========================
# 5) Walk every testset sample_key and accumulate statistics (denominator stays k)
# =========================
used_samples = 0              # sample_keys whose best trace record carries an eval_result
skipped_no_eval = 0           # sample_key appeared in traces but no usable eval_result was found
missing_in_traces = 0         # sample_key never appeared in the traces

for sk in test_sample_keys:
    if sk not in seen_key_any:
        missing_in_traces += 1
        continue

    sc_data = best_record_by_key.get(sk)
    if not sc_data:
        # Defensive only: step 4 stores a best record for every key it adds to seen_key_any.
        skipped_no_eval += 1
        continue

    _, data, last_round = sc_data
    if not isinstance(last_round, dict) or not isinstance(last_round.get("eval_result"), dict):
        skipped_no_eval += 1
        continue

    used_samples += 1
    eval_result = last_round["eval_result"]

    # -------- general_results: entries shaped like {"name": <func>, "ok": <value>}
    general_results = eval_result.get("general_results", [])
    func_ok_map = {}  # func -> bool; only a literal True counts as passing
    if isinstance(general_results, list):
        for g in general_results:
            if not isinstance(g, dict):
                continue
            name = g.get("name")
            ok = g.get("ok")
            if name:
                func_ok_map[name] = (ok is True)

    # A configured function missing from general_results counts as failed.
    all_general_ok = True
    for fn in all_funcs:
        ok = func_ok_map.get(fn, False)
        if ok:
            func_true_counts[fn] += 1
        else:
            all_general_ok = False

    if func_ok_map.get("validate_trip_plan_json", False):
        validate_trip_plan_ok_count += 1
    if all_general_ok:
        general_all_ok_count += 1

    # macro: a group scores for this sample only if every function in it passed
    for major, funcs in major2funcs.items():
        if all(func_ok_map.get(fn, False) for fn in funcs):
            major_all_ok_counts[major] += 1
    for (major, sub), funcs in sub2funcs.items():
        if all(func_ok_map.get(fn, False) for fn in funcs):
            sub_all_ok_counts[(major, sub)] += 1

    # -------- rubric: aggregate per rubric_key (each sample_key yields one verdict per key)
    # Validate these containers the same way general_results is validated above.
    # A null value (e.g. "id_true_results": null) would otherwise raise TypeError
    # on iteration and abort the whole report.
    id_true_results = eval_result.get("id_true_results", [])
    if not isinstance(id_true_results, list):
        id_true_results = []
    id_false_results = eval_result.get("id_false_results", [])
    if not isinstance(id_false_results, list):
        id_false_results = []

    rk_stats = defaultdict(lambda: {"seen": 0, "false": 0})

    # id_true_results: an item fails unless its "ok" is literally True.
    for item in id_true_results:
        if not isinstance(item, dict):
            continue
        rk = item.get("rubric_key")
        if not rk:
            continue
        rk_stats[rk]["seen"] += 1
        if item.get("ok") is not True:
            rk_stats[rk]["false"] += 1

    # id_false_results: every item counts as a failure regardless of any "ok" field.
    # NOTE(review): presumably this list only contains failing checks; confirm
    # against the evaluator's output schema.
    for item in id_false_results:
        if not isinstance(item, dict):
            continue
        rk = item.get("rubric_key")
        if not rk:
            continue
        rk_stats[rk]["seen"] += 1
        rk_stats[rk]["false"] += 1

    expected_keys = expected_keys_per_sample.get(sk)
    rubric_all_correct = True

    if expected_keys is not None:
        for rk in expected_keys:
            st = rk_stats.get(rk, {"seen": 0, "false": 0})

            if st["seen"] > 0:
                rubric_key_seen_counts[rk] += 1
                total_key_seen += 1

            # Expected key never evaluated for this sample -> "missing".
            if st["seen"] <= 0:
                rubric_all_correct = False
                rubric_key_missing_total += 1
                continue

            # Evaluated, but at least one item failed.
            if st["false"] > 0:
                rubric_all_correct = False
                rubric_key_failed_total += 1
                continue

            rubric_key_pass_counts[rk] += 1
            total_key_pass += 1
    else:
        rubric_all_correct = False

    if rubric_all_correct:
        rubric_all_correct_trip_count += 1
    if all_general_ok and rubric_all_correct:
        general_and_rubric_all_correct += 1


# =========================
# 6) Output: general-function statistics
# =========================
# Every ratio below divides by k (the testset record count), not by
# used_samples, so samples without a usable eval_result count as not passing.
print("\n====================")
print("GENERAL 统计（分母固定 k）")
print("====================\n")

# Alignment diagnostics between the traces and the testset.
print(f"[INFO] trace_records_total = {trace_records_total}")
print(f"[INFO] trace_records_in_testset = {trace_records_in_testset}")
print(f"[INFO] used_samples_from_traces(sample_key with eval) = {used_samples}")
print(f"[INFO] missing_in_traces(sample_key never appeared) = {missing_in_traces}")
print(f"[INFO] skipped_not_in_testset(trace records) = {skipped_not_in_testset}")
print(f"[INFO] skipped_no_eval(sample_key appeared but no eval) = {skipped_no_eval}")
print(f"[INFO] NOTE: 所有比例分母都用 testset k={k}（不是 used_samples）。\n")

print(f"① validate_trip_plan_json 通过率: {validate_trip_plan_ok_count} / {k} = {safe_div(validate_trip_plan_ok_count, k):.2%}")
print(f"② general 全部函数全对比例:       {general_all_ok_count} / {k} = {safe_div(general_all_ok_count, k):.2%}\n")

# Per-function pass rate: #samples with ok=True / k.
print("---- 单个 general function 通过率（true_count/k）----")
for fn in all_funcs:
    cnt = func_true_counts.get(fn, 0)
    print(f"{fn:35s} {cnt:4d} / {k} = {safe_div(cnt, k):.2%}")
print()


def report_group(group_name, funcs, macro_count):
    """
    Print micro and macro pass rates for one group of general functions.

    micro = (sum of per-function ok=True counts over the group) / (k * len(funcs))
    macro = macro_count / k, where macro_count is the number of samples in
            which every function of the group passed.

    Reads the module-level ``func_true_counts`` and ``k``. Prints nothing for
    an empty group.
    """
    n = len(funcs)
    if not n:
        return
    true_sum = sum(func_true_counts.get(fn, 0) for fn in funcs)
    report_lines = (
        f"{group_name}",
        f"  micro: {true_sum} / ({k}*{n}) = {safe_div(true_sum, k * n):.2%}",
        f"  macro: {macro_count} / {k} = {safe_div(macro_count, k):.2%}",
    )
    for text in report_lines:
        print(text)


# Major categories: micro = passed checks / (k * #functions in the category);
# macro = samples where every function in the category passed / k.
print("---- 大类 Basic Feasibility / Planning Soundness micro&macro ----")
for major, funcs in major2funcs.items():
    report_group(f"[{major}]", funcs, major_all_ok_counts.get(major, 0))
print()

# Same two metrics per (major, sub) sub-category.
print("---- 子类 micro&macro ----")
for (major, sub), funcs in sub2funcs.items():
    report_group(f"[{major} :: {sub}]", funcs, sub_all_ok_counts.get((major, sub), 0))
print()


# =========================
# 7) Output: rubric statistics
# =========================
print("\n====================")
print("RUBRIC 统计（按 rubric_key；每个 sample_key 每个 key 只算 1 条 rubric）")
print("====================\n")

print(f"[INFO] total expected rubric keys = {total_expected_rubrics}")

print(f"[INFO] rubric macro（sample strict：该样本所有 expected keys seen>0 且无 false） = "
      f"{rubric_all_correct_trip_count} / {k} = {safe_div(rubric_all_correct_trip_count, k):.2%}")
print(f"[INFO] general+rubric strict(all-correct) = "
      f"{general_and_rubric_all_correct} / {k} = {safe_div(general_and_rubric_all_correct, k):.2%}\n")

# "with missing": the denominator is every expected (record, key) pair from the
#   testset, including samples absent from the traces or without an eval.
# "ignore missing": the denominator is only pairs whose key was actually evaluated.
rubric_micro_with_missing = safe_div(total_key_pass, total_expected_rubrics)
rubric_micro_ignore_missing = safe_div(total_key_pass, total_key_seen)

print(f"rubric micro（考虑 missing）: {total_key_pass} / {total_expected_rubrics} = {rubric_micro_with_missing:.2%}")
print(f"rubric micro（不考虑 missing）: {total_key_pass} / {total_key_seen} = {rubric_micro_ignore_missing:.2%}\n")

print("---- 每条 rubric（按 rubric_key）正确率：不考虑 missing vs 考虑 missing ----")
print(f"{'rubric_key':35s} "
      f"{'pass':>6s} "
      f"{'seen':>6s} "
      f"{'expected':>9s} "
      f"{'acc_no_missing':>15s} "
      f"{'acc_with_missing':>17s} "
      f"{'missing':>8s}")

# Sort by expected count (descending), then by rubric_key so ties print in a
# stable order. rubric_expected_counter's insertion order comes from iterating
# sets of strings in step 1, and that order changes between runs under Python's
# string hash randomization.
for rk, expected_cnt in sorted(rubric_expected_counter.items(), key=lambda x: (-x[1], x[0])):
    pass_cnt = rubric_key_pass_counts.get(rk, 0)
    seen_cnt = rubric_key_seen_counts.get(rk, 0)
    # Includes samples that never reached step 5's evaluation (absent from traces / no eval).
    missing_cnt = expected_cnt - seen_cnt

    acc_no_missing = safe_div(pass_cnt, seen_cnt)
    acc_with_missing = safe_div(pass_cnt, expected_cnt)

    print(f"{rk:35s} "
          f"{pass_cnt:6d} "
          f"{seen_cnt:6d} "
          f"{expected_cnt:9d} "
          f"{acc_no_missing:15.2%} "
          f"{acc_with_missing:17.2%} "
          f"{missing_cnt:8d}")

# NOTE: the DIAG totals below only cover samples that had a usable eval_result.
# Unlike the per-key "missing" column above, they do not count samples that were
# absent from the traces or lacked an eval, although the percentages are taken
# over all expected keys.
print("\n---- [DIAG] 期望 key 维度的缺失/失败统计（按 expected key 计数）----")
print(f"[DIAG] expected-key missing(total) = {rubric_key_missing_total}  "
      f"({safe_div(rubric_key_missing_total, total_expected_rubrics):.2%} of expected)")
print(f"[DIAG] expected-key failed(total)  = {rubric_key_failed_total}  "
      f"({safe_div(rubric_key_failed_total, total_expected_rubrics):.2%} of expected)")
