"""
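WhisperX: transcribe and align each audio file to word-level timestamps,
then merge the per-file JSONs into a single JSONL ordered by the numeric N
in 'sent_N'.

Usage:
  1) Set AUDIO_DIR, OUTPUT_DIR, and OUTPUT_JSONL below.
  2) Run: python script.py

Notes:
  - Audio files are expected to be named like 'sent_1.wav', 'sent_2.wav', ...
    Every .wav file in AUDIO_DIR is transcribed, but only per-file JSONs
    named 'sent_N.json' are included in the merged JSONL. Any such JSON
    already present in OUTPUT_DIR (e.g. from an earlier run) is merged too.
  - Per-file JSONs are written as 'sent_1.json', 'sent_2.json', ... in OUTPUT_DIR.
  - The final JSONL is sorted by the integer N from 'sent_N.json', so
    'sent_2' comes before 'sent_10'.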
"""

import re
import json
from pathlib import Path

import whisperx

# --- Paths (edit these before running) ---------------------------------------
# Input folder of .wav files, e.g. /Users/you/.../gold_eval
AUDIO_DIR = Path("/path/to/audio_dir")
# Folder receiving one word-level JSON per audio file, e.g. /Users/you/.../whisperx_output_eval
OUTPUT_DIR = Path("/path/to/output_json_dir")
# Merged output, one JSON record per line, e.g. /Users/you/.../whisperx_merged.jsonl
OUTPUT_JSONL = Path("/path/to/merged.jsonl")

# --- WhisperX settings -------------------------------------------------------
MODEL_NAME = "large-v2"   # ASR model name handed to whisperx.load_model
DEVICE = "cpu"            # device for both the ASR and the alignment model
COMPUTE_TYPE = "int8"
LANGUAGE = "en"           # used for transcription and for picking the alignment model
BATCH_SIZE = 8

# Per-file JSON names eligible for merging: 'sent_<N>.json' (matched
# case-insensitively); group 1 captures the numeric N.
MERGE_PATTERN = re.compile(r"^sent_(\d+)\.json$", re.IGNORECASE)


def list_audio_files(folder: Path):
    """Return the .wav files directly inside *folder*, in natural sort order.

    The '.wav' extension check is case-insensitive. File names are compared
    case-insensitively, with runs of digits compared as integers, so
    'sent_2.wav' sorts before 'sent_10.wav' (plain string sorting would put
    'sent_10.wav' first). Names whose keys tie (e.g. 'sent_1.wav' vs
    'sent_01.wav') are ordered by their exact name, so the result is
    deterministic regardless of directory listing order.

    Raises:
        FileNotFoundError: if *folder* does not exist (from ``Path.iterdir``).
    """

    def natural_key(path: Path):
        # re.split with a capturing group alternates text and captured digit
        # runs, so odd positions are always digit runs; converting only those
        # keeps int-vs-int and str-vs-str comparisons aligned.
        chunks = re.split(r"(\d+)", path.name.casefold())
        key = [int(chunk) if i % 2 else chunk for i, chunk in enumerate(chunks)]
        return key, path.name

    wavs = [p for p in folder.iterdir() if p.is_file() and p.suffix.lower() == ".wav"]
    return sorted(wavs, key=natural_key)


def transcribe_and_align_all():
    """Transcribe and align every .wav in AUDIO_DIR, saving word-level timestamps as JSON.

    OUTPUT_DIR is created if missing. For each audio file, '<stem>.json' is
    written to OUTPUT_DIR containing the ``word_segments`` list returned by
    ``whisperx.align`` (an empty list if that key is absent).

    The input directory is scanned before any model is loaded, so an empty
    AUDIO_DIR returns immediately with a warning. This avoids first paying the
    cost of loading the ASR and alignment models.
    """
    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

    audio_files = list_audio_files(AUDIO_DIR)
    if not audio_files:
        print(f"[warn] No .wav files found in: {AUDIO_DIR}")
        return

    # Load both models once and reuse them for every file.
    print(f"[info] Loading ASR model: {MODEL_NAME} on {DEVICE} ({COMPUTE_TYPE})")
    asr_model = whisperx.load_model(MODEL_NAME, device=DEVICE, compute_type=COMPUTE_TYPE)

    print(f"[info] Loading alignment model for language_code='{LANGUAGE}' on {DEVICE}")
    align_model, metadata = whisperx.load_align_model(language_code=LANGUAGE, device=DEVICE)

    for wav_path in audio_files:
        out_json = OUTPUT_DIR / f"{wav_path.stem}.json"

        print(f"[run] Transcribe: {wav_path.name}")
        # Decode the audio once and share the array. Passing the path to both
        # transcribe() and align() makes whisperx decode the same file twice.
        audio = whisperx.load_audio(str(wav_path))
        result = asr_model.transcribe(audio, batch_size=BATCH_SIZE, language=LANGUAGE)

        print(f"[run] Align: {wav_path.name}")
        aligned = whisperx.align(
            result.get("segments", []),
            align_model,
            metadata,
            audio,
            device=DEVICE,
        )

        word_segments = aligned.get("word_segments", [])
        out_json.write_text(json.dumps(word_segments, ensure_ascii=False, indent=2), encoding="utf-8")
        print(f"[ok] Saved word-level timestamps -> {out_json}")


def merge_jsons_to_jsonl(*, output_dir=None, output_jsonl=None, pattern=None):
    """
    Merge per-file JSONs into one JSONL file, ordered by sentence number.

    Every file in *output_dir* whose name matches *pattern* (by default
    'sent_(N).json', case-insensitive) is read and written as one line of
    *output_jsonl*. Lines are sorted by the integer N captured from the file
    name. Ties (e.g. 'sent_1.json' vs 'sent_01.json') are ordered by file name,
    so the output is deterministic. Each JSONL line has the structure:
      {
        "sent_id": N,
        "source": "<file stem>",      # e.g. "sent_N"
        "word_segments": <parsed JSON content of the file>
      }

    Files that cannot be read or parsed are skipped with a warning. A first ID
    other than 1, gaps, and duplicate IDs are reported but do not stop the
    merge.

    Args:
        output_dir: Directory to scan. Defaults to OUTPUT_DIR.
        output_jsonl: Destination file; its parent directory is created if
            needed. Defaults to OUTPUT_JSONL.
        pattern: Compiled regex whose group 1 is the numeric ID. Defaults to
            MERGE_PATTERN.

    Returns:
        The number of records written. Returns 0 if no matching files were
        found; in that case *output_jsonl* is not touched.
    """
    output_dir = OUTPUT_DIR if output_dir is None else Path(output_dir)
    output_jsonl = OUTPUT_JSONL if output_jsonl is None else Path(output_jsonl)
    pattern = MERGE_PATTERN if pattern is None else pattern

    # (numeric id, path) for every matching file. With the default pattern,
    # \d+ only captures decimal digits, which int() always accepts.
    candidates = []
    for p in output_dir.iterdir():
        if not p.is_file():
            continue
        m = pattern.match(p.name)
        if m:
            candidates.append((int(m.group(1)), p))

    if not candidates:
        print(f"[warn] No JSON files matching 'sent_(N).json' found in: {output_dir}")
        return 0

    # Sort by numeric ID; the file name breaks ties so the order is deterministic.
    candidates.sort(key=lambda item: (item[0], item[1].name))

    # Gentle sanity checks (warnings only).
    ids = [idx for idx, _ in candidates]
    if ids[0] != 1:
        print(f"[warn] First ID is {ids[0]}, expected 1. Proceeding anyway.")
    gaps = []
    duplicates = []
    for prev, cur in zip(ids, ids[1:]):
        if cur == prev:
            # ids are sorted, so repeats are adjacent; record each ID once.
            if not duplicates or duplicates[-1] != cur:
                duplicates.append(cur)
        elif cur != prev + 1:
            gaps.append((prev, cur))
    if gaps:
        print(f"[warn] Detected non-contiguous IDs: {gaps}. Proceeding with sorted merge.")
    if duplicates:
        print(f"[warn] Detected duplicate IDs: {duplicates}. All matching files will be merged.")

    output_jsonl.parent.mkdir(parents=True, exist_ok=True)
    written = 0
    with output_jsonl.open("w", encoding="utf-8") as fw:
        for idx, path in candidates:
            try:
                data = json.loads(path.read_text(encoding="utf-8"))
            except (OSError, ValueError) as e:
                # OSError: the file could not be read. ValueError covers both
                # json.JSONDecodeError and UnicodeDecodeError.
                print(f"[warn] Skipping unreadable JSON: {path} ({e})")
                continue

            record = {
                "sent_id": idx,
                "source": path.stem,
                "word_segments": data,
            }
            fw.write(json.dumps(record, ensure_ascii=False) + "\n")
            written += 1

    print(f"[ok] Merged {written} files -> {output_jsonl}")
    return written


def main():
    """Run the whole pipeline: per-file transcription/alignment, then the merge."""
    pipeline = (
        transcribe_and_align_all,  # step 1: word-level timestamps per .wav
        merge_jsons_to_jsonl,      # step 2: combine the JSONs into one sent_N-ordered JSONL
    )
    for step in pipeline:
        step()


if __name__ == "__main__":
    main()
