import json
from pathlib import Path

# === Paths (adjust as needed) ===
# JSON Lines input: per line, target tokens (with WAIT markers) plus a
# target->source word alignment.
ALIGNMENT_PATH = "/path/to/alignment_with_wait_tokens.jsonl"
# JSON Lines input: per line, source-language timing info; must have the same
# number of lines as ALIGNMENT_PATH (items are paired by position).
TIMESTAMP_PATH = "/path/to/timestamp_src_language.jsonl"
# JSON Lines output: per input line, a list of {"text", "start"} segments.
OUTPUT_PATH    = "/path/to/output.jsonl"

# === Parameters ===
WAIT_TOKEN = "<WAIT>"          # segment separator; matched case-insensitively after strip()
MERGE_BACKWARD_EPS = 1e-3      # how far (in timestamp units, seconds) a start may step back before merging
PRINT_DEBUG_SAMPLES = 8        # print diagnostics for this many leading items
SYNTHETIC_TOKEN_SEC = 0.30   # fallback step length (sec/token)

# === I/O ===
def read_jsonl(path):
    """Load a UTF-8 JSON Lines file, returning one parsed value per non-blank line."""
    with open(path, "r", encoding="utf-8") as handle:
        return [json.loads(raw) for raw in handle if raw.strip()]

def write_jsonl(path, rows):
    """Write *rows* as UTF-8 JSON Lines (non-ASCII kept verbatim), creating parent dirs."""
    target = Path(path)
    target.parent.mkdir(parents=True, exist_ok=True)
    with open(target, "w", encoding="utf-8") as sink:
        sink.writelines(json.dumps(r, ensure_ascii=False) + "\n" for r in rows)

# === Utils ===
def is_wait(tok, wait_token=WAIT_TOKEN):
    """Return True if *tok* is a string equal to *wait_token*, ignoring case and outer whitespace.

    Non-string tokens are never WAIT. If *wait_token* is None or empty, the
    marker becomes "" and blank/whitespace-only strings match.
    """
    if not isinstance(tok, str):
        return False
    marker = (wait_token or "").strip().upper()
    return tok.strip().upper() == marker

def build_tgt_to_en_map(alignment):
    """Group source (en) indices by target index from an alignment list.

    Accepted pair forms:
      * a list/tuple of length >= 2 read as ``(en_idx, tgt_idx, ...)``;
      * a dict with ``"en_idx"`` plus the first present of ``"tgt_idx"``,
        ``"zh_idx"``, ``"de_idx"``.
    Anything else, or pairs whose indices are not ints, is skipped.

    Returns:
        dict mapping tgt_idx -> list of en_idx in alignment order
        (``{}`` if *alignment* is not a list/tuple).
    """
    tgt_keys = ("tgt_idx", "zh_idx", "de_idx")
    mapping = {}
    if not isinstance(alignment, (list, tuple)):
        return mapping
    for pair in alignment:
        if isinstance(pair, (list, tuple)):
            if len(pair) < 2:
                continue
            en_idx, tgt_idx = pair[0], pair[1]
        elif isinstance(pair, dict) and "en_idx" in pair:
            key = next((k for k in tgt_keys if k in pair), None)
            if key is None:
                continue
            en_idx, tgt_idx = pair["en_idx"], pair[key]
        else:
            continue
        if isinstance(en_idx, int) and isinstance(tgt_idx, int):
            mapping.setdefault(tgt_idx, []).append(en_idx)
    return mapping

def _extract_segments_anywhere(time_info):
    if not isinstance(time_info, dict):
        return []
    for key in ("whisper_result", "result", "whisper", "asr", "stt"):
        obj = time_info.get(key)
        if isinstance(obj, dict) and isinstance(obj.get("segments"), list):
            return obj["segments"]
    if isinstance(time_info.get("segments"), list):
        return time_info["segments"]
    return []

def _first_positive_duration(mapping):
    """Return the first positive numeric duration field in *mapping*, else None."""
    for key in ("audio_duration", "duration", "total_duration", "final_duration"):
        if key not in mapping:
            continue
        try:
            value = float(mapping[key])
        except (TypeError, ValueError, OverflowError):
            continue  # non-numeric value: try the next key
        if value > 0:
            return value
    return None


def _try_total_duration(time_info):
    """Best-effort total duration for *time_info*, or None if nothing usable is found.

    Sources, in order:
      1. the last segment's ``end`` minus the first segment's ``start``;
      2. a top-level ``audio_duration``/``duration``/``total_duration``/
         ``final_duration`` field;
      3. the same fields inside a ``meta``/``info``/``stats`` dict.
    Every source must yield a positive number to be accepted, so a zero or
    negative segment span falls through to the explicit fields instead of
    producing a collapsed or reversed timeline. Non-numeric values are
    skipped rather than raised.
    """
    if not isinstance(time_info, dict):
        return None

    segs = _extract_segments_anywhere(time_info)
    if segs and isinstance(segs[0], dict) and isinstance(segs[-1], dict):
        first_start = segs[0].get("start")
        last_end = segs[-1].get("end")
        if first_start is not None and last_end is not None:
            try:
                span = float(last_end) - float(first_start)
            except (TypeError, ValueError, OverflowError):
                span = None
            if span is not None and span > 0:
                return span

    direct = _first_positive_duration(time_info)
    if direct is not None:
        return direct

    for container_key in ("meta", "info", "stats"):
        container = time_info.get(container_key)
        if isinstance(container, dict):
            nested = _first_positive_duration(container)
            if nested is not None:
                return nested
    return None

def _spread_evenly(tokens, start, total_dur):
    """Give each of the (non-empty) *tokens* an equal consecutive slice of [start, start + total_dur]."""
    step = total_dur / len(tokens)
    # Index-based positions avoid the drift of a running `cur += step` sum,
    # while still making each token's end equal the next token's start.
    return [
        {"word": tok, "start": start + i * step, "end": start + (i + 1) * step}
        for i, tok in enumerate(tokens)
    ]


def flatten_word_times_flexible(time_info, fallback_en_tokens=None):
    """
    Build a flat, index-addressable list of source-word timestamps.

    Each entry is ``{"word", "start", "end"}``. Strategies, in order:
    1) Use word-level times if available (start/end copied as-is, may be None).
    2) Otherwise, distribute tokens evenly across the span from the first to
       the last segment that has both start and end.
    3) If no timing info at all, lay tokens on a timeline starting at 0.0
       whose length is ``_try_total_duration(time_info)``, or
       ``SYNTHETIC_TOKEN_SEC`` per token when that is unavailable.

    Tokens for 2)/3) are *fallback_en_tokens*, or the whitespace-split
    concatenation of segment texts when none are given.

    Returns:
        ``(flat, used_synthetic)``; ``used_synthetic`` is True for strategy 3
        and when there are no tokens at all (``flat`` is then ``[]``).
    """
    # Non-dict entries carry no usable timing; drop them instead of crashing.
    segs = [s for s in (_extract_segments_anywhere(time_info) or []) if isinstance(s, dict)]

    # Case 1: word-level timestamps
    flat = []
    for seg in segs:
        words = seg.get("words")
        if not (isinstance(words, list) and words):
            continue
        for w in words:
            if isinstance(w, dict):
                flat.append({
                    "word": w.get("word") or "",
                    "start": w.get("start", None),
                    "end": w.get("end", None),
                })
            else:
                # Keep a placeholder so alignment indices into this list stay valid.
                flat.append({"word": w if isinstance(w, str) else "", "start": None, "end": None})
    if flat:
        return flat, False

    # Prepare fallback tokens
    toks = list(fallback_en_tokens or [])
    if not toks:
        joined = " ".join(s["text"] for s in segs if isinstance(s.get("text"), str)).strip()
        toks = joined.split()

    # Case 2: distribute across segment duration (assumes segments are chronological)
    timed = [s for s in segs if s.get("start") is not None and s.get("end") is not None]
    if timed and toks:
        total_start = float(timed[0]["start"])
        total_dur = max(0.0, float(timed[-1]["end"]) - total_start)
        return _spread_evenly(toks, total_start, total_dur), False

    # Case 3: synthetic timeline
    if toks:
        total_dur = _try_total_duration(time_info)
        if total_dur is None:
            total_dur = SYNTHETIC_TOKEN_SEC * len(toks)
        return _spread_evenly(toks, 0.0, total_dur), True
    return [], True

def merge_early_or_invalid_segments_stable(time_table, eps=MERGE_BACKWARD_EPS):
    """Fold segments with a missing or backwards-jumping start into the previous segment.

    Walks *time_table* (dicts with ``"text"`` and ``"start"``, start may be
    None) in order. A segment is merged into the last kept one, with texts
    concatenated without a separator, when either start is None or the
    current start is more than *eps* earlier than the kept start. The kept
    segment then takes the earliest known start of the two, or None if
    neither is known.

    Input dicts are shallow-copied, not mutated. Output starts are rounded to
    2 decimals.

    Returns:
        A new list of merged segment dicts.
    """
    merged = []
    for seg in time_table:
        seg = dict(seg)
        s_cur = seg.get("start")
        if not merged:
            seg["_raw_start"] = s_cur
            merged.append(seg)
            continue
        last = merged[-1]
        s_last = last.get("_raw_start")
        if s_cur is None or s_last is None or s_cur < s_last - eps:
            last["text"] = last.get("text", "") + seg.get("text", "")
            # default=None: when both starts are unknown, min() over nothing
            # previously raised ValueError.
            last["_raw_start"] = min(
                (x for x in (s_cur, s_last) if x is not None), default=None
            )
        else:
            seg["_raw_start"] = s_cur
            merged.append(seg)
    for seg in merged:
        st = seg.pop("_raw_start", None)
        seg["start"] = round(float(st), 2) if st is not None else None
    return merged

def detect_tokens_with_waits(align_obj):
    """Locate the token list in *align_obj* that carries WAIT markers.

    Returns ``(field_name, tokens, wait_count)``:
      * the non-empty all-string list field with the most WAIT tokens (the
        earliest such field in dict order wins ties);
      * otherwise the first non-empty all-string list among ``tgt_tokens``,
        ``de_tokens``, ``zh_tokens``, ``target_tokens`` (its wait_count is
        necessarily 0 here);
      * otherwise ``(None, [], 0)``.
    """
    def _is_str_list(value):
        return isinstance(value, list) and bool(value) and all(isinstance(t, str) for t in value)

    scored = [
        (name, value, sum(map(is_wait, value)))
        for name, value in align_obj.items()
        if _is_str_list(value)
    ]
    with_waits = [entry for entry in scored if entry[2] > 0]
    if with_waits:
        # max() returns the first maximal entry, matching a stable descending sort.
        return max(with_waits, key=lambda entry: entry[2])

    for name in ("tgt_tokens", "de_tokens", "zh_tokens", "target_tokens"):
        value = align_obj.get(name)
        if _is_str_list(value):
            return name, value, sum(map(is_wait, value))
    return None, [], 0

# === Helpers for the main process ===
def _split_by_wait(tokens):
    """Split *tokens* at WAIT markers into ``(start, end, kept_tokens)`` spans.

    ``start``/``end`` index into *tokens* (end exclusive); spans without any
    non-WAIT token (e.g. consecutive WAITs) are dropped.
    """
    spans, buffer, base_idx = [], [], 0
    for i, tok in enumerate(tokens):
        if is_wait(tok):
            if buffer:
                spans.append((base_idx, i, buffer))
                buffer = []
            base_idx = i + 1
        else:
            buffer.append(tok)
    if buffer:
        spans.append((base_idx, len(tokens), buffer))
    return spans


def _anchor_en_indices(left, right, tgt_tokens, tgt_to_en):
    """Source indices aligned to tgt_tokens[left:right]; else nearest aligned token after, then before.

    The returned list may be an entry of *tgt_to_en*; treat it as read-only.
    """
    en_indices = []
    for j in range(left, right):
        if not is_wait(tgt_tokens[j]):
            en_indices.extend(tgt_to_en.get(j, []))
    if en_indices:
        return en_indices
    for j in range(right, len(tgt_tokens)):
        if not is_wait(tgt_tokens[j]) and tgt_to_en.get(j):
            return tgt_to_en[j]
    for j in range(left - 1, -1, -1):
        if not is_wait(tgt_tokens[j]) and tgt_to_en.get(j):
            return tgt_to_en[j]
    return []


def _alignment_link_stats(alignment, n_words):
    """Return (max en_idx, total source links, links with 0 <= en_idx < n_words) for debugging."""
    max_en_idx, total, valid = -1, 0, 0
    if not isinstance(alignment, (list, tuple)):
        # e.g. "alignment": null — build_tgt_to_en_map treats this as empty too.
        return max_en_idx, total, valid
    for pair in alignment:
        if isinstance(pair, (list, tuple)) and len(pair) >= 2 and isinstance(pair[0], int):
            en_idx = pair[0]
        elif isinstance(pair, dict) and isinstance(pair.get("en_idx"), int):
            en_idx = pair["en_idx"]
        else:
            continue
        total += 1
        max_en_idx = max(max_en_idx, en_idx)
        if 0 <= en_idx < n_words:
            valid += 1
    return max_en_idx, total, valid


# === Main process ===
alignments = read_jsonl(ALIGNMENT_PATH)
timestamps = read_jsonl(TIMESTAMP_PATH)
if len(alignments) != len(timestamps):
    # Explicit check rather than `assert`, so it still fires under `python -O`.
    raise ValueError(f"Mismatched data length: {len(alignments)} vs {len(timestamps)}")

all_results = []
synthetic_used_any = False

for idx, (align_info, time_info) in enumerate(zip(alignments, timestamps)):
    en_tokens = align_info.get("en_tokens") or align_info.get("source_tokens") or []
    tgt_field, tgt_tokens, wait_count = detect_tokens_with_waits(align_info)
    if not tgt_tokens:
        tgt_tokens = (
            align_info.get("tgt_tokens") or align_info.get("zh_tokens")
            or align_info.get("de_tokens") or align_info.get("target_tokens") or []
        )
        wait_count = sum(1 for x in tgt_tokens if is_wait(x))

    tgt_to_en = build_tgt_to_en_map(align_info.get("alignment", []))

    en_word_times, used_synthetic = flatten_word_times_flexible(time_info, fallback_en_tokens=en_tokens)
    synthetic_used_any = synthetic_used_any or used_synthetic
    N = len(en_word_times)

    # Segment by WAIT, then anchor each segment to its earliest aligned source word.
    segs_with_time = []
    for l, r, toks in _split_by_wait(tgt_tokens):
        en_indices = _anchor_en_indices(l, r, tgt_tokens, tgt_to_en)
        valid_starts = [
            float(en_word_times[ei]["start"])
            for ei in en_indices
            if 0 <= ei < N and en_word_times[ei].get("start") is not None
        ]
        start_raw = min(valid_starts) if valid_starts else None
        # NOTE(review): tokens are joined with no separator, which suits
        # unsegmented scripts (e.g. Chinese); confirm for space-delimited targets.
        segs_with_time.append({"text": "".join(toks), "start": start_raw})

    merged_segments = merge_early_or_invalid_segments_stable(segs_with_time, eps=MERGE_BACKWARD_EPS)
    all_results.append(merged_segments)

    # Debug output
    if idx < PRINT_DEBUG_SAMPLES:
        max_en_idx, total_en_links, valid_en_links = _alignment_link_stats(
            align_info.get("alignment", []), N
        )
        print(f"[{idx}] tokens_field={tgt_field or '(fallback)'} wait_count={wait_count}")
        print(f"[{idx}] N words: {N} | max en_idx in alignment: {max_en_idx} | synthetic={used_synthetic}")
        print(f"[{idx}] valid en_idx ratio: {valid_en_links}/{total_en_links}")
        for k, seg in enumerate(merged_segments[:3]):
            print(f"   - #{k} start={seg['start']} text={seg['text'][:60]}{'...' if len(seg['text'])>60 else ''}")
        print("-" * 60)

# === Save results ===
write_jsonl(OUTPUT_PATH, all_results)
print(f"Saved -> {OUTPUT_PATH}")
print(f"Total items: {len(all_results)}")
if synthetic_used_any:
    # The fixed step only applies when no duration metadata was found;
    # otherwise the known duration is spread evenly over the tokens.
    print(
        f"NOTE: Some items used a SYNTHETIC timeline. Step = {SYNTHETIC_TOKEN_SEC:.2f}s/token "
        "when no duration metadata was available."
    )
