import os
import json
from collections import defaultdict
from datetime import datetime

# =========================
# Paths (edit as needed, or override via environment variables)
# =========================
# trace_path: a directory (scanned recursively for *.json files), a single
# .json file, or a .jsonl file -- see iter_trace_records.
trace_path = os.environ.get("TRACE_PATH", "path/to/traces")
# testset_path: a JSON file holding either a list of samples or {"data": [...]}.
testset_path = os.environ.get("TESTSET_PATH", "path/to/testset.json")

# =========================
# Utility functions
# =========================
def normalize_phrase(s: str) -> str:
    """Trim whitespace and drop any run of trailing '.' characters.

    Non-string input yields "". Whitespace uncovered while removing dots is
    stripped too, so "a . ." becomes "a".
    """
    if isinstance(s, str):
        text = s.strip()
        while text[-1:] == ".":
            text = text[:-1].rstrip()
        return text
    return ""


def load_json(path):
    """Read the file at *path* as UTF-8 text and return the decoded JSON value."""
    with open(path, encoding="utf-8") as fh:
        payload = json.load(fh)
    return payload

def parse_time(s):
    """Parse an ISO-8601 timestamp string into a ``datetime``.

    Returns None for empty or non-string input and for text that cannot be
    parsed. A trailing "Z" (UTC designator) is rewritten to "+00:00" first,
    because ``datetime.fromisoformat`` only accepts "Z" from Python 3.11 on;
    without this, UTC timestamps would silently become None on older versions.
    """
    if not s or not isinstance(s, str):
        return None
    text = s.strip()
    if text.endswith(("Z", "z")):
        text = text[:-1] + "+00:00"
    try:
        return datetime.fromisoformat(text)
    except ValueError:
        # Malformed timestamp: treat as "no time available".
        return None

def get_last_round_dict(rounds):
    """Return the last element of *rounds* that is a dict, or None.

    A non-list argument, an empty list, or a list containing no dict all
    yield None.
    """
    if isinstance(rounds, list):
        return next((item for item in reversed(rounds) if isinstance(item, dict)), None)
    return None

def final_id_list_from_test_sample(sample):
    """Return the final modification ids for one testset sample.

    Primary source: ``sample["applied_modification_chains"]``, a dict whose
    values are chains (lists) of ids; the last id of every non-empty list
    chain is collected in dict order. Only when that dict is missing or empty
    does it fall back to a list-valued ``sample["final_id_list"]``; otherwise
    [] is returned. Note that a non-empty chains dict whose chains are all
    empty or non-list returns [] without consulting the fallback.
    """
    chains = sample.get("applied_modification_chains")
    if not (isinstance(chains, dict) and chains):
        fallback = sample.get("final_id_list")
        return fallback if isinstance(fallback, list) else []
    return [chain[-1] for chain in chains.values() if isinstance(chain, list) and chain]

def final_id_list_from_trace_record(data):
    """Return the final id list stored in a trace record.

    Looks at ``data["summary"]["final_id_list"]`` first and then at the
    top-level ``data["final_id_list"]``; the first list-valued one wins,
    otherwise [] is returned.
    """
    for container in (data.get("summary") or {}, data):
        candidate = container.get("final_id_list")
        if isinstance(candidate, list):
            return candidate
    return []

def canonical_final_id_list(final_id_list):
    """Return an order-independent, deduplicated form of *final_id_list*.

    Non-None items are converted with ``str``, deduplicated, and returned as
    a sorted tuple; a non-list argument gives an empty tuple.
    """
    if isinstance(final_id_list, list):
        unique_ids = set(map(str, filter(lambda item: item is not None, final_id_list)))
        return tuple(sorted(unique_ids))
    return ()

def make_sample_key(trip_id, final_id_list):
    """Build the key used to align testset samples with trace records.

    Returns ``(trip_id, canonical_final_id_list(final_id_list))``. A non-None
    trip_id is converted to ``str``, mirroring how the final ids are
    stringified by ``canonical_final_id_list``: a trip stored as 123 in one
    file and "123" in the other must still align, and the key must stay
    hashable even if trip_id is a list/dict. None is kept as None.
    """
    tid = None if trip_id is None else str(trip_id)
    return (tid, canonical_final_id_list(final_id_list))

def get_rounds_from_trace(data):
    """Return the list of rounds stored in a trace record.

    Prefers a list-valued top-level ``data["rounds"]``, then
    ``data["summary"]["rounds"]``; returns [] when neither is a list.
    """
    top_level = data.get("rounds")
    if isinstance(top_level, list):
        return top_level
    nested = (data.get("summary") or {}).get("rounds")
    return nested if isinstance(nested, list) else []

def _records_from_obj(obj, source):
    """Yield the trace dict(s) contained in one decoded JSON value.

    A dict is yielded as-is; a list is flattened one level (its dict elements
    are yielded in order); any other value is reported and skipped.
    """
    items = obj if isinstance(obj, list) else [obj]
    for item in items:
        if isinstance(item, dict):
            yield item
        else:
            print(f"[跳过] 非 dict 记录 {source}: {type(item).__name__}")

def _read_json_file(file_path):
    """Yield trace dicts from one .json file; report and skip unreadable files."""
    try:
        with open(file_path, "r", encoding="utf-8") as f:
            obj = json.load(f)
    except (OSError, ValueError) as e:
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
        print(f"[跳过] 读取失败 {file_path}: {e}")
        return
    yield from _records_from_obj(obj, file_path)

def iter_trace_records(path_or_dir):
    """Yield trace records (dicts) one at a time from *path_or_dir*.

    - Directory: every ``*.json`` file underneath, walked recursively in
      sorted order so that record order (and thus tie-breaking between
      equally scored records downstream) is deterministic.
    - ``.jsonl`` file: one JSON value per non-blank line.
    - ``.json`` file: the decoded document.

    A decoded value that is a JSON array is flattened into its dict elements;
    non-dict values are reported and skipped, so callers only ever see dicts.
    Unreadable files / lines are reported and skipped.

    Raises:
        ValueError: (on first iteration) if the path is neither a directory
            nor a ``.json`` / ``.jsonl`` file.
    """
    if os.path.isdir(path_or_dir):
        for root, dirs, files in os.walk(path_or_dir):
            dirs.sort()  # in-place sort steers os.walk's descent order
            for filename in sorted(files):
                if filename.endswith(".json"):
                    yield from _read_json_file(os.path.join(root, filename))
        return

    if path_or_dir.endswith(".jsonl"):
        with open(path_or_dir, "r", encoding="utf-8") as f:
            for ln, line in enumerate(f, start=1):
                line = line.strip()
                if not line:
                    continue
                try:
                    obj = json.loads(line)
                except ValueError as e:
                    print(f"[跳过] 读取失败 {path_or_dir} line={ln}: {e}")
                    continue
                yield from _records_from_obj(obj, f"{path_or_dir} line={ln}")
        return

    if path_or_dir.endswith(".json"):
        yield from _read_json_file(path_or_dir)
        return

    raise ValueError(f"Unsupported path: {path_or_dir}")

def record_score(data, last_round):
    """Rank a trace record so the "best" one per sample_key can be kept.

    Returns a tuple compared lexicographically (larger is better):
    1) 1 if ``last_round`` has a dict ``eval_result``, else 0;
    2) ``last_round["round_idx"]`` (numeric as-is, numeric strings via
       ``int``, anything else such as null -> 0);
    3) end-time epoch seconds, taken from ``last_round["end_time"]``, then
       ``data["end_time"]``, then ``data["summary"]["end_time"]``; 0 if none
       parses.
    Coercing round_idx keeps scores mutually comparable: a null or string
    round_idx would otherwise make ``sc > best_score`` raise TypeError.
    """
    round_ok = isinstance(last_round, dict)
    has_eval = 1 if round_ok and isinstance(last_round.get("eval_result"), dict) else 0

    ridx = 0
    if round_ok:
        raw_idx = last_round.get("round_idx", 0)
        if isinstance(raw_idx, (int, float)):
            ridx = raw_idx
        else:
            try:
                ridx = int(raw_idx)
            except (TypeError, ValueError):
                ridx = 0

    endt = parse_time(last_round.get("end_time")) if round_ok else None
    if endt is None:
        endt = parse_time(data.get("end_time"))
    if endt is None:
        summary = data.get("summary")
        if isinstance(summary, dict):  # tolerate a missing / non-dict summary
            endt = parse_time(summary.get("end_time"))

    endt_ts = endt.timestamp() if endt is not None else 0
    return (has_eval, ridx, endt_ts)

def safe_div(a, b):
    """Return ``a / b``, or 0.0 when ``b`` is falsy (0, 0.0, None, ...)."""
    if not b:
        return 0.0
    return a / b

# =========================
# 1) Load the testset and collect its sample_keys (used to align traces)
# =========================
testset = load_json(testset_path)

# Accepted layouts: a bare list of samples, or {"data": [samples...]}.
if isinstance(testset, list):
    testset_list = testset
elif isinstance(testset, dict) and isinstance(testset.get("data"), list):
    testset_list = testset["data"]
else:
    raise ValueError(f"Unrecognized testset format in {testset_path}")

# k = number of entries in the testset file (duplicates included); it is
# shown as the overall denominator in the summary output.
k = len(testset_list)

test_sample_keys = set()
dup_keys_in_testset = 0
for sample in testset_list:
    sk = make_sample_key(
        sample.get("trip_id"),
        final_id_list_from_test_sample(sample),
    )
    if sk in test_sample_keys:
        dup_keys_in_testset += 1
    test_sample_keys.add(sk)

print(f"[INFO] testset k = {k}")
print(f"[INFO] unique sample_key in testset = {len(test_sample_keys)}")
if dup_keys_in_testset:
    print(f"[WARN] duplicated sample_key in testset = {dup_keys_in_testset}")

# =========================
# 2) Walk the traces: group records by sample_key (trip_id alone can repeat)
#    and keep the best-scoring record for each key
# =========================
best_record_by_key = {}  # sample_key -> (score, data, last_round)
seen_key_any = set()     # sample_keys seen in traces (with or without eval)

trace_records_total = 0
trace_records_in_testset = 0
skipped_not_in_testset = 0
skipped_non_dict = 0     # records that are not JSON objects and cannot be keyed

for data in iter_trace_records(trace_path):
    trace_records_total += 1

    # Guard: e.g. a .json file holding a top-level array arrives here as a
    # non-dict and would otherwise crash on data.get(...).
    if not isinstance(data, dict):
        skipped_non_dict += 1
        continue

    trip_id = data.get("trip_id") or (data.get("summary") or {}).get("trip_id")
    fil = final_id_list_from_trace_record(data)
    sk = make_sample_key(trip_id, fil)

    if sk not in test_sample_keys:
        skipped_not_in_testset += 1
        continue

    trace_records_in_testset += 1
    seen_key_any.add(sk)

    rounds = get_rounds_from_trace(data)
    last_round = get_last_round_dict(rounds)
    sc = record_score(data, last_round)

    # Strict ">" keeps the earliest-seen record when scores tie.
    if sk not in best_record_by_key or sc > best_record_by_key[sk][0]:
        best_record_by_key[sk] = (sc, data, last_round)

print(f"[INFO] trace_records_total = {trace_records_total}")
print(f"[INFO] trace_records_in_testset = {trace_records_in_testset}")
print(f"[INFO] skipped_not_in_testset(trace records) = {skipped_not_in_testset}")
if skipped_non_dict:
    print(f"[WARN] skipped_non_dict(trace records) = {skipped_non_dict}")

# =========================
# 3) Tally evaluate_local_transportation msgs (split on ";")
# =========================
local_msg_counter = defaultdict(int)          # phrase -> occurrences (repeats counted)
local_msg_total = 0                           # total phrase occurrences
local_msg_sample_counter = defaultdict(int)   # phrase -> #samples containing it (once per sample)
local_msg_samples = 0                         # #samples with a non-empty msg for this function

used_samples = 0
skipped_no_eval = 0
missing_in_traces = 0

for sk in test_sample_keys:
    if sk not in seen_key_any:
        missing_in_traces += 1
        continue

    sc_data = best_record_by_key.get(sk)
    if not sc_data:
        skipped_no_eval += 1
        continue

    last_round = sc_data[2]
    if not isinstance(last_round, dict) or not isinstance(last_round.get("eval_result"), dict):
        skipped_no_eval += 1
        continue

    used_samples += 1
    general_results = last_round["eval_result"].get("general_results", [])
    if not isinstance(general_results, list):
        continue

    # Gather phrases from EVERY matching entry of this sample before counting,
    # so the per-sample metrics below count the sample once even when the
    # function appears in several general_results entries.
    sample_phrases = []
    for g in general_results:
        if not isinstance(g, dict) or g.get("name") != "evaluate_local_transportation":
            continue
        msg = g.get("msg") or ""
        if not isinstance(msg, str):
            continue  # unexpected msg type; cannot be split into phrases
        for raw in msg.split(";"):
            phrase = normalize_phrase(raw)
            if phrase:
                sample_phrases.append(phrase)

    if not sample_phrases:
        continue

    local_msg_samples += 1

    # Metric A: raw occurrence counts (repeated phrases all counted)
    for p in sample_phrases:
        local_msg_counter[p] += 1
    local_msg_total += len(sample_phrases)

    # Metric B: sample coverage (each phrase counted at most once per sample)
    for p in set(sample_phrases):
        local_msg_sample_counter[p] += 1

print(f"[INFO] used_samples(sample_key with eval) = {used_samples}")
print(f"[INFO] missing_in_traces(sample_key never appeared) = {missing_in_traces}")
print(f"[INFO] skipped_no_eval(sample_key appeared but no eval) = {skipped_no_eval}")

# =========================
# 4) Print the shares
# =========================
print("\n====================")
print("evaluate_local_transportation 的 msg 说法占比（按 ; 切分）")
print("====================\n")

if local_msg_total == 0:
    print("[INFO] 没有收集到任何 evaluate_local_transportation 的非空 msg")
else:
    print(f"[INFO] 有 msg 的样本数: {local_msg_samples} / used_samples={used_samples} / k={k}")
    print(f"[INFO] phrase 总出现次数(允许重复计数): {local_msg_total}\n")

    # Sort by count descending, then by phrase, so equal-count phrases print in
    # the same order on every run. The counters' insertion order follows
    # iteration over a set of str-keyed tuples, which can differ between runs
    # because str hashing is randomized per process.
    print("---- 口径A：按出现次数占比（允许同一 msg 重复计数）----")
    for phrase, cnt in sorted(local_msg_counter.items(), key=lambda kv: (-kv[1], kv[0])):
        print(f"{phrase}  |  {cnt:6d}  |  {safe_div(cnt, local_msg_total):.2%}")

    print("\n---- 口径B：按样本覆盖率（同一样本去重）----")
    for phrase, scnt in sorted(local_msg_sample_counter.items(), key=lambda kv: (-kv[1], kv[0])):
        print(f"{phrase}  |  {scnt:6d}  |  {safe_div(scnt, local_msg_samples):.2%}")
