import os
import json
from pathlib import Path
from collections import defaultdict

# =========================
# 路径（按需改）
# =========================
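# TRACES_JSONL = "<path to the traces .jsonl file>"
# TRAIN_JSON   = "<path to the training .json file>"
# OUT_JSON     = "<path of the filtered output .json file>"
# NOTE: main() reads TRACES_JSONL, TRAIN_JSON and OUT_JSON as module-level
# names, so they must be defined here before the script is run. The original
# assignments (and possibly further path constants) were lost from this line.
# =========================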
# 第一阶段：trip_id 过滤
# =========================
def load_trip_ids(jsonl_path: str) -> set:
    """Collect the distinct trip ids found in a JSON Lines file.

    Every non-blank line is parsed as JSON. The trip id is taken from the
    record's top-level ``"trip_id"``; when that is missing or falsy, the
    ``"trip_id"`` inside the record's ``"summary"`` object is used instead.

    Lines that are not valid JSON, or that decode to something other than a
    JSON object, are counted as bad lines and skipped instead of aborting
    the whole scan. A one-line summary is printed at the end.

    Args:
        jsonl_path: Path of the UTF-8 encoded ``.jsonl`` file.

    Returns:
        The set of all non-``None`` trip ids that were found.
    """
    trip_ids = set()
    bad_lines = 0
    with open(jsonl_path, "r", encoding="utf-8") as f:
        for raw_line in f:
            line = raw_line.strip()
            if not line:
                continue
            try:
                obj = json.loads(line)
            except json.JSONDecodeError:
                bad_lines += 1
                continue
            # A line such as `[1, 2]` or `"text"` is valid JSON but has no
            # keys to look up; treat it like any other unusable line.
            if not isinstance(obj, dict):
                bad_lines += 1
                continue
            tid = obj.get("trip_id")
            if not tid:
                # Fall back to the nested summary, but only when it really is
                # an object (it may be null, a string, ...).
                summary = obj.get("summary")
                tid = summary.get("trip_id") if isinstance(summary, dict) else None
            if tid is not None:
                trip_ids.add(tid)
    print(f"[1] trip_id unique count = {len(trip_ids)}, bad lines = {bad_lines}")
    return trip_ids


def get_trip_id(sample: dict):
    """Return the value stored under ``"trip_id"`` in *sample*.

    Yields ``None`` when *sample* is not a dict or has no ``"trip_id"`` key.
    """
    if not isinstance(sample, dict):
        return None
    return sample.get("trip_id")


def load_train_items(path: str):
    """Load the training samples from a JSON file.

    Two layouts are accepted: a top-level list of samples, or a top-level
    object whose ``"data"`` entry is a list of samples.

    Args:
        path: Path of the UTF-8 encoded JSON file.

    Returns:
        A ``(items, mode, original)`` tuple. For a top-level list this is
        ``(the_list, "list", None)``; for the object layout it is
        ``(obj["data"], "dict_data", obj)``.

    Raises:
        ValueError: If the file content matches neither layout.
    """
    with open(path, "r", encoding="utf-8") as fh:
        payload = json.load(fh)

    if isinstance(payload, dict):
        records = payload.get("data")
        if isinstance(records, list):
            return records, "dict_data", payload
    elif isinstance(payload, list):
        return payload, "list", None

    raise ValueError(f"Unrecognized JSON structure in {path}")


def save_json(obj, path: str):
    """Write *obj* to *path* as pretty-printed UTF-8 JSON.

    Missing parent directories are created first. Non-ASCII characters are
    written as-is (not ``\\u`` escaped) and the output uses a 2-space indent.
    """
    target = Path(path)
    target.parent.mkdir(parents=True, exist_ok=True)
    with target.open("w", encoding="utf-8") as fh:
        json.dump(obj, fh, ensure_ascii=False, indent=2)


# =========================
# 第二阶段： (trip_id, final_id_list) 过滤
# =========================
def canonical_final_id_list(final_id_list):
    """Normalise an id list into an order-independent, hashable form.

    ``None`` entries are dropped, every remaining entry is converted with
    ``str``, duplicates are removed and the result is sorted (as strings).

    Returns:
        A tuple of the normalised ids; the empty tuple when *final_id_list*
        is not a list.
    """
    if isinstance(final_id_list, list):
        distinct = set()
        for item in final_id_list:
            if item is not None:
                distinct.add(str(item))
        return tuple(sorted(distinct))
    return ()


def make_sample_key(trip_id, final_id_list):
    """Build the ``(trip_id, canonical id tuple)`` key used to match samples.

    The id list is normalised with ``canonical_final_id_list`` so that two
    lists with the same ids compare equal regardless of order or duplicates.
    """
    canonical_ids = canonical_final_id_list(final_id_list)
    return trip_id, canonical_ids


def final_id_list_from_filter_sample(sample):
    """Extract the final id list from a training sample.

    When ``sample["applied_modification_chains"]`` is a non-empty dict, the
    result is the last element of each of its values that is a non-empty
    list (other values are ignored), and ``"final_id_list"`` is not
    consulted. Otherwise ``sample["final_id_list"]`` is returned if it is a
    list, and an empty list if not.
    """
    chains = sample.get("applied_modification_chains")
    if isinstance(chains, dict) and chains:
        return [
            steps[-1]
            for steps in chains.values()
            if isinstance(steps, list) and steps
        ]

    explicit = sample.get("final_id_list")
    return explicit if isinstance(explicit, list) else []


def final_id_list_from_trace_record(rec):
    """Extract the final id list from a trace record.

    ``rec["summary"]["final_id_list"]`` takes precedence; the top-level
    ``rec["final_id_list"]`` is the fallback. A candidate is only accepted
    when it is a list. Returns an empty list when neither qualifies.
    """
    # Look in the summary first, then in the record itself.
    for source in (rec.get("summary") or {}, rec):
        candidate = source.get("final_id_list")
        if isinstance(candidate, list):
            return candidate
    return []


def iter_jsonl(path):
    """Lazily yield the decoded value of each parseable line of a JSONL file.

    Blank lines are skipped silently. A line that cannot be decoded is
    reported with a ``[WARN]`` message (including its 1-based line number)
    and skipped, so one corrupt line does not stop the iteration.

    Args:
        path: Path of the UTF-8 encoded ``.jsonl`` file.

    Yields:
        The result of ``json.loads`` for every line that decodes.
    """
    with open(path, "r", encoding="utf-8") as f:
        for ln, line in enumerate(f, 1):
            line = line.strip()
            if not line:
                continue
            # Only the decode step is guarded. Keeping the `yield` outside the
            # `try` ensures an exception thrown into the generator by its
            # consumer propagates instead of being reported as a bad line.
            try:
                rec = json.loads(line)
            except (ValueError, RecursionError) as e:
                # json.JSONDecodeError is a ValueError subclass; RecursionError
                # keeps pathologically nested lines on the skip path.
                print(f"[WARN] skip bad jsonl line {ln}: {e}")
                continue
            yield rec

# =========================
# main
# =========================
def main():
    """Filter the training samples against the traces file and save the result.

    Steps:
      1. Keep training samples whose ``trip_id`` occurs in ``TRACES_JSONL``.
      2. Collect every ``(trip_id, canonical final_id_list)`` key present in
         ``TRACES_JSONL``.
      3. Keep only the step-1 samples whose own key is among those keys.
      4. Write the kept samples to ``OUT_JSON``, preserving the layout
         (bare list, or object with a ``"data"`` list) of ``TRAIN_JSON``.

    Progress counters are printed after each step.
    """
    # -------- 1) trip_id filter --------
    known_trip_ids = load_trip_ids(TRACES_JSONL)

    samples, layout, wrapper = load_train_items(TRAIN_JSON)
    print(f"[2] train total = {len(samples)}")

    trip_matched = []
    without_trip_id = 0
    for sample in samples:
        sample_tid = get_trip_id(sample)
        if sample_tid:
            if sample_tid in known_trip_ids:
                trip_matched.append(sample)
        else:
            # Samples with a missing/falsy trip_id are counted and dropped.
            without_trip_id += 1

    print(f"[2] after trip_id filter = {len(trip_matched)}, missing trip_id = {without_trip_id}")

    # -------- 2) collect the keys that really exist in the jsonl --------
    trace_keys = set()
    for record in iter_jsonl(TRACES_JSONL):
        record_tid = record.get("trip_id") or (record.get("summary") or {}).get("trip_id")
        if record_tid:
            record_ids = final_id_list_from_trace_record(record)
            trace_keys.add(make_sample_key(record_tid, record_ids))

    print(f"[3] jsonl unique (trip_id, final_id_list) = {len(trace_keys)}")

    # -------- 3) second filter on (trip_id, final_id_list) --------
    kept, dropped = [], []

    for sample in trip_matched:
        sample_tid = sample.get("trip_id")
        sample_key = make_sample_key(sample_tid, final_id_list_from_filter_sample(sample))

        if sample_key not in trace_keys:
            dropped.append({
                "trip_id": sample_tid,
                "final_id_list_canonical": list(sample_key[1]),
                "reason": "not_in_jsonl"
            })
            continue
        kept.append(sample)

    print(f"[4] kept = {len(kept)}, dropped = {len(dropped)}")

    # -------- 4) write the result --------
    out_obj = kept
    if layout == "dict_data":
        # Shallow-copy the original wrapper so its other keys are carried over.
        out_obj = dict(wrapper)
        out_obj["data"] = kept

    save_json(out_obj, OUT_JSON)

    print(f"[DONE] wrote:")
    print(f"  kept   -> {OUT_JSON}")


# Run the filtering pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()

