import os
import json
import torch
import whisperx
import librosa
import pandas as pd
import math

# ======== Input locations (fill these in before running) ========
translation_txt = ""        # text file read line by line as target sentences
tgt_audio_dir = ""          # directory joined with "<n>.wav" target audio names
timestamp_jsonl_path = ""   # JSONL file holding source word-level timestamps

# ======== WhisperX setup ========
# Run on the GPU when torch reports one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# ASR model used to transcribe the target audio.
model = whisperx.load_model(
    "medium",
    device=device,
    language="",
    compute_type="int8",
)  # Change according to needs

# Alignment model (plus its metadata) used to obtain word-level timestamps.
align_model, metadata = whisperx.load_align_model(language_code="", device=device)

# ======== Load the translation text ========
# One target sentence per line; every line is stripped of surrounding
# whitespace and lines that end up empty are dropped.
with open(translation_txt, "r", encoding="utf-8") as txt_file:
    tgt_sentences = list(filter(None, map(str.strip, txt_file)))

# ======== Load the source language timestamp data (JSONL) ========
# timestamp_dict maps the 1-based physical line number of each JSONL record to
# its list of word-level timestamp entries. Records without usable words are
# reported and left out of the mapping.
timestamp_dict = {}
with open(timestamp_jsonl_path, "r", encoding="utf-8") as f:
    for k, line in enumerate(f, start=1):
        if not line.strip():
            # A blank line is not a JSON document and would make json.loads
            # raise; treat it as a missing entry but keep the line numbering.
            print(f"[Skip] Missing source language timestamp: index {k}")
            continue

        obj = json.loads(line)

        # Take the word-level timestamps: prioritize the top-level "words",
        # otherwise aggregate them from whisper_result.segments.
        words = obj.get("words")
        if words is None:
            # "or" guards against keys that are present but null.
            segments = (obj.get("whisper_result") or {}).get("segments") or []
            words = [
                word
                for seg in segments
                for word in (seg.get("words") or [])
            ]

        if not words:
            print(f"[Skip] Missing source language timestamp: index {k}")
            continue

        timestamp_dict[k] = words


# ---------- Tool: Map the "Xth src word" to time ----------
def time_at_token_index(src_end_times, x):
    """Map a 1-based, possibly fractional, source word index to a time.

    Args:
        src_end_times: Sequence of times; entry ``i - 1`` is the time
            associated with index ``i``.
        x: Index to look up. Values ``<= 1`` map to the first entry and
            values ``>= len(src_end_times)`` map to the last entry.

    Returns:
        The time at index ``x``, linearly interpolated between the two
        neighbouring entries when ``x`` is fractional.

    Raises:
        ValueError: If ``src_end_times`` is empty.
    """
    count = len(src_end_times)
    if not count:
        raise ValueError("empty src_end_times")

    # Clamp to the ends of the table.
    if x <= 1:
        return src_end_times[0]
    if x >= count:
        return src_end_times[-1]

    # Here 1 < x < count, so both neighbours exist.
    whole = math.floor(x)
    below = src_end_times[whole - 1]
    above = src_end_times[whole]
    return below + (x - whole) * (above - below)

# ---------- Calculate AL (seconds) ----------
def _al_seconds_from_times(tgt_times, src_end_times):
    """Return Average Lagging in seconds from target start / source end times.

    Both sequences must be non-empty. The result is clamped to be >= 0.
    """
    X_len = len(src_end_times)
    Y_len = len(tgt_times)

    # g[t]: number of leading source words whose end time is <= the start
    # time of target word t (counting stops at the first later end time).
    g = []
    for tgt_time in tgt_times:
        cnt = 0
        for et in src_end_times:
            if et <= tgt_time:
                cnt += 1
            else:
                break
        g.append(cnt)

    gamma = Y_len / X_len

    # tau: 1-based position of the first target word that starts once every
    # source word has ended; falls back to the last target word.
    tau = next((t for t in range(1, Y_len + 1) if g[t - 1] >= X_len), Y_len)

    s = 0.0
    for t in range(1, tau + 1):
        # Time at which the source words counted for target word t had ended.
        policy_idx = min(max(float(g[t - 1]), 1.0), float(X_len))
        policy_time = time_at_token_index(src_end_times, policy_idx)

        # Time the diagonal reference (t - 1) / gamma corresponds to.
        diag_idx = min(max((t - 1) / gamma, 1.0), float(X_len))
        diag_time = time_at_token_index(src_end_times, diag_idx)

        s += policy_time - diag_time

    return max(s / tau, 0.0)


def compute_average_lagging_seconds(tgt_audio_path, src_words):
    """Compute Average Lagging (in seconds) for one target audio file.

    The target audio is transcribed and word-aligned with the module-level
    WhisperX models; the start time of each target word is then compared
    against the end times of the source words.

    Args:
        tgt_audio_path: Path of the target-language audio file.
        src_words: Iterable of dicts describing source words; only entries
            with a non-null ``"end"`` value are used.

    Returns:
        The Average Lagging in seconds (never negative), or ``None`` when the
        audio file does not exist, no source end time is available, or the
        alignment yields no target words.
    """
    if not os.path.exists(tgt_audio_path):
        return None

    # Validate the source side first so that the costly ASR/alignment pass is
    # skipped when there is nothing to compare against. Null end times are
    # ignored instead of crashing in float().
    src_end_times = [
        float(w["end"]) for w in src_words if w.get("end") is not None
    ]
    if not src_end_times:
        return None

    audio = whisperx.load_audio(tgt_audio_path)
    duration = librosa.get_duration(path=tgt_audio_path)
    asr_result = model.transcribe(audio, batch_size=8)
    aligned = whisperx.align(
        asr_result["segments"],
        align_model, metadata,
        audio, device,
        return_char_alignments=False
    )
    tgt_units = aligned.get("word_segments", [])
    if not tgt_units:
        return None

    # A target word without a start time (presumably one the aligner could
    # not place -- confirm against WhisperX output) is treated as emitted at
    # the end of the audio. Previously such words silently got time 0.0 and
    # the intended fallback to `duration` was unreachable.
    tgt_times = []
    for u in tgt_units:
        start = u.get("start")
        if start is None:
            tgt_times.append(duration)
        else:
            tgt_times.append(max(0.0, float(start)))

    return _al_seconds_from_times(tgt_times, src_end_times)

def compute_average_lagging_tokens_from_g(g, X_len, Y_len):
    """Compute token-level Average Lagging from a read-count sequence.

    Args:
        g: Indexable sequence where ``g[t - 1]`` is the number of source
            tokens counted for target token ``t``.
        X_len: Number of source tokens.
        Y_len: Number of target tokens.

    Returns:
        The mean of ``g[t - 1] - (t - 1) / gamma`` over ``t = 1..tau``, where
        ``gamma = Y_len / X_len`` and ``tau`` is the first ``t`` with
        ``g[t - 1] >= X_len`` (or ``Y_len`` if there is none).
    """
    ratio = Y_len / X_len

    # First 1-based position at which the whole source has been counted.
    cutoff = next(
        (pos for pos in range(1, Y_len + 1) if g[pos - 1] >= X_len),
        Y_len,
    )

    # Accumulate in order with an explicit loop to keep float results stable.
    total = 0.0
    for step in range(cutoff):
        total += g[step] - step / ratio
    return total / cutoff

# ======== Main loop ========
# Sentence i (1-based) is paired with entry i of timestamp_dict and with the
# audio file "<i-1>.wav" inside tgt_audio_dir. Every successfully scored
# sentence appends {"index", "al_seconds"} to results.
results = []
for i in range(1, len(tgt_sentences) + 1):
    tgt_audio_path = os.path.join(tgt_audio_dir, f"{i-1}.wav")
    src_words = timestamp_dict.get(i)

    if not src_words:
        print(f"[Skip] Missing source language timestamp: index {i}")
        continue
    if not os.path.exists(tgt_audio_path):
        print(f"[Skip] The tgt language audio does not exist: {tgt_audio_path}")
        continue

    # compute_average_lagging_seconds takes (tgt_audio_path, src_words); the
    # sentence text is not an argument (passing it raised TypeError).
    al_sec = compute_average_lagging_seconds(tgt_audio_path, src_words)
    if al_sec is not None:
        results.append({"index": i, "al_seconds": al_sec})
        print(f"Index {i}: AL_seconds = {al_sec:.3f} s")
    else:
        print(f"[Skip] AL_seconds calculation failed: index {i}")

# ======== Summarized results ========
# Print the per-sentence table and the mean AL over all scored sentences.
if not results:
    print("There are no valid results.")
else:
    df = pd.DataFrame(results)
    print("\n Sentence-by-sentence Average per second Lagging:")
    print(df)
    global_al = df['al_seconds'].mean()
    print(f"\n Global Average Lagging (seconds): {global_al:.3f} s")
