import json

# === Path settings (edit these three only) ===
ALIGNMENT_PATH = "/path/to/src_tgt_alignment_with_waits.jsonl"   # JSONL, one record per line; each has "zh_tokens" and "alignment"
SRC_WORD_TIMESTAMPS_PATH = "/path/to/src_word_timestamps.jsonl"  # JSONL, one record per line; each has "words" with word-level times
TGT_SEGMENT_TIMETABLE_PATH = "/path/to/tgt_segment_timetable.jsonl"  # output JSONL


# === Load data ===
def _read_jsonl(path):
    """Read a JSON Lines file and return its records as a list.

    Blank (whitespace-only) lines are skipped, so a stray empty line (for
    example an extra newline at the end of the file) does not abort the run
    with a JSONDecodeError.
    """
    with open(path, "r", encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]


alignments = _read_jsonl(ALIGNMENT_PATH)
timestamps = _read_jsonl(SRC_WORD_TIMESTAMPS_PATH)

# The two files are paired record-by-record (zipped) later in this script, so
# they must contain the same number of records. Use an explicit raise rather
# than `assert`, which is removed when Python runs with -O.
if len(alignments) != len(timestamps):
    raise ValueError(
        f"Mismatched data length! {len(alignments)} alignment records vs "
        f"{len(timestamps)} timestamp records"
    )

# === Build tgt->src alignment map ===
def build_tgt_to_src_map(alignment):
    """
    Invert a list of (src_idx, tgt_idx) alignment pairs into a lookup table.

    alignment: iterable of 2-element pairs (src_idx, tgt_idx); lists or tuples.
    Returns: dict keyed by tgt_idx whose value is the list of src_idx values
    aligned to it, in the order the pairs appear in `alignment`.
    """
    mapping = {}
    for pair in alignment:
        source_index, target_index = pair
        if target_index not in mapping:
            mapping[target_index] = []
        mapping[target_index].append(source_index)
    return mapping

# === Post-processing merge function ===
def merge_early_or_invalid_segments(time_table):
    """
    Fold segments whose start time does not advance into the previous segment.

    Segments are scanned in order. The first segment is always kept. Each
    later segment starts a new entry only if its "start" is strictly greater
    than the "start" of the most recently kept segment. Otherwise its "text"
    is appended to that kept segment, which keeps its own (earlier) start.

    In this script, segments without a usable timestamp are given start 0.0.
    They are therefore folded into the preceding segment, assuming real
    timestamps are non-negative. A leading segment keeps its 0.0.

    time_table: list of dicts with at least "text" (str) and "start" (number).
    Returns: a new list of new dicts. The input list and its dicts are not
    modified.
    """
    merged = []
    for seg in time_table:
        if merged and seg["start"] <= merged[-1]["start"]:
            # Non-increasing start: concatenate onto the last kept segment.
            merged[-1]["text"] += seg["text"]
        else:
            # Shallow-copy so the `+=` above never mutates the caller's dicts.
            merged.append(dict(seg))
    return merged

# === Main processing logic ===
WAIT_TOKEN = "<WAIT>"


def _split_on_wait(tgt_tokens):
    """Split target tokens into runs delimited by WAIT_TOKEN.

    Returns a list of (start_idx, tokens) pairs. start_idx is the index in
    tgt_tokens of the run's first token, so it never points at a WAIT_TOKEN.
    Empty runs (leading, trailing or consecutive WAITs) are dropped.
    """
    segments = []
    buffer = []
    base_tgt_idx = 0  # index in tgt_tokens where the current run begins
    for i, tok in enumerate(tgt_tokens):
        if tok == WAIT_TOKEN:
            if buffer:
                segments.append((base_tgt_idx, buffer))
                buffer = []
            base_tgt_idx = i + 1
        else:
            buffer.append(tok)
    if buffer:
        segments.append((base_tgt_idx, buffer))
    return segments


def _valid_src_starts(src_indices, src_word_times):
    """Return the "start" values of the source words at in-range indices."""
    return [
        src_word_times[i]["start"]
        for i in src_indices
        if 0 <= i < len(src_word_times)
    ]


def _segment_start_time(tgt_start_idx, tgt_tokens, tgt_to_src, src_word_times):
    """Return the earliest usable source start time for a segment, or None.

    Starts at the segment's first target token. If that token has no aligned
    source word with an in-range timestamp, it scans forward to the nearest
    target token that does, skipping WAIT_TOKEN positions. Aligned indices
    outside src_word_times do not stop the scan.

    Note: like the original fallback, the scan is not limited to the current
    segment and may take its time from a token in a later segment.
    """
    for j in range(tgt_start_idx, len(tgt_tokens)):
        if tgt_tokens[j] == WAIT_TOKEN:
            continue
        starts = _valid_src_starts(tgt_to_src.get(j, []), src_word_times)
        if starts:
            return min(starts)
    return None


all_results = []

for align_info, time_info in zip(alignments, timestamps):
    # Note: the input schema still uses "zh_tokens"; we bind it to tgt_tokens for public, language-agnostic code.
    tgt_tokens = align_info["zh_tokens"]
    src_word_times = time_info["words"]
    tgt_to_src = build_tgt_to_src_map(align_info["alignment"])

    # 1) Split target tokens by <WAIT>; 2) give each segment a start time.
    tgt_segment_with_time = []
    for tgt_start_idx, tgt_words in _split_on_wait(tgt_tokens):
        start_time = _segment_start_time(
            tgt_start_idx, tgt_tokens, tgt_to_src, src_word_times
        )
        tgt_segment_with_time.append({
            "text": ''.join(tgt_words),
            # No usable timestamp: placeholder 0.0. The merge step folds a
            # non-leading placeholder into the preceding segment.
            "start": 0.0 if start_time is None else round(start_time, 2),
        })

    # 3) Merge segments with invalid or non-increasing start times
    all_results.append(merge_early_or_invalid_segments(tgt_segment_with_time))

# === Save output ===
# Write one JSON value per line (JSONL), in input order; non-ASCII text is kept as-is.
with open(TGT_SEGMENT_TIMETABLE_PATH, "w", encoding="utf-8") as out_file:
    out_file.writelines(
        json.dumps(record, ensure_ascii=False) + "\n" for record in all_results
    )
