import os
import re
import json
from tqdm import tqdm
from pydub import AudioSegment
import torchaudio
import torch
import numpy as np

# === Filesystem locations ===
# Input timetable, read line by line as JSON.
SEGMENT_JSONL = "/path/to/timetable.jsonl"
# Destination of the assembled "<index>.wav" files.
OUTPUT_DIR    = "/path/to/tts_output_fullsentence"
# Scratch area for the individually synthesized blocks.
TMP_DIR       = "/path/to/tts_output_segments"

# Create both directories up front; already existing ones are left untouched.
for _directory in (OUTPUT_DIR, TMP_DIR):
    os.makedirs(_directory, exist_ok=True)

# === Coqui TTS engine (German voice) ===
from TTS.api import TTS

# Alternative model id kept for reference: "facebook/mms-tts-deu"
MODEL_NAME = "tts_models/de/thorsten/tacotron2-DDC"
tts = TTS(MODEL_NAME, progress_bar=False, gpu=False)

# Read the output rate from the engine's synthesizer; fall back to 22050 when
# the synthesizer or its `output_sample_rate` attribute is missing.
_synthesizer = getattr(tts, "synthesizer", None)
SAMPLE_RATE = getattr(_synthesizer, "output_sample_rate", 22050)

# === Load timetable ===
# One JSON document per line. Blank lines (for example a stray empty line at
# the end of the file) are skipped, because json.loads() raises
# JSONDecodeError on empty input.
with open(SEGMENT_JSONL, "r", encoding="utf-8") as fin:
    all_segments = [json.loads(line) for line in fin if line.strip()]

# === Resume from breakpoint: Locate the first ungenerated index ===
def find_start_idx(output_dir: str, total: int) -> int:
    """Return the first index in ``range(total)`` that has no output file.

    Index ``i`` counts as generated only when a file named exactly
    ``"<i>.wav"`` exists in *output_dir* -- the same name the main loop
    writes. Matching on the exact name avoids calling ``int()`` on stems
    that ``str.isdigit()`` accepts but ``int()`` rejects (such as ``"²"``),
    and keeps look-alike names such as ``"01.wav"`` from hiding a missing
    ``"1.wav"``.

    Args:
        output_dir: Directory that holds the finished ``<index>.wav`` files.
        total: Number of timetable entries.

    Returns:
        The smallest missing index, or *total* when every index is present.

    Raises:
        OSError: If *output_dir* cannot be listed.
    """
    present = set(os.listdir(output_dir))
    return next((i for i in range(total) if f"{i}.wav" not in present), total)

start_idx = find_start_idx(OUTPUT_DIR, len(all_segments))
# Report where this run resumes, or that nothing is left to synthesize.
if start_idx < len(all_segments):
    print(f"starts from the first ungenerated index:{start_idx}/{len(all_segments)-1}")
else:
    print("All generated. No need to continue.")

# === Text normalization and short-segment merging ===
# Matches a single German letter: ASCII A-Z/a-z, the umlauts and ß.
GERMAN_ALPHA_RE = re.compile(r"[A-Za-zÄÖÜäöüß]")
# Letter-count threshold: text with fewer matching letters than this is
# merged with the segments that follow it.
MIN_ALPHA_CHARS = 5

def normalize_de_text(s: str) -> str:
    """Normalize one timetable text fragment before synthesis.

    Trims surrounding whitespace, collapses every whitespace run to a single
    space and drops one trailing hyphen. Whitespace exposed by removing the
    hyphen (as in ``"Wort -"``) is trimmed as well, so the result never ends
    in a space.

    Args:
        s: Raw segment text; any non-string value is treated as missing.

    Returns:
        The cleaned text, or ``""`` when *s* is not a string.
    """
    if not isinstance(s, str):
        return ""
    s = re.sub(r"\s+", " ", s.strip())
    if s.endswith("-"):
        # Only a single hyphen is removed; re-trim so "Wort -" gives "Wort".
        s = s[:-1].rstrip()
    return s

def alpha_count(s: str) -> int:
    """Return how many characters of *s* match ``GERMAN_ALPHA_RE``."""
    return sum(1 for _ in GERMAN_ALPHA_RE.finditer(s))

def make_tts_blocks(seg_list):
    """Merge consecutive timetable segments into blocks long enough for TTS.

    Segments are normalized with ``normalize_de_text`` and appended to a
    buffer. As soon as the buffer holds at least ``MIN_ALPHA_CHARS`` letters
    (per ``alpha_count``) it is emitted as one block. Whatever is left in the
    buffer after the last segment is emitted too, even if it is still short.

    Segments without any letter or digit (empty or punctuation-only text) are
    skipped. Digit-only segments such as ``"1990"`` are kept: they carry
    spoken content even though they add nothing to the letter count.

    Args:
        seg_list: Iterable of mappings read via ``.get("text")`` and
            ``.get("start")``; a missing or falsy ``start`` becomes ``0.0``.

    Returns:
        A list of ``{"text": str, "start": float}`` dicts, where ``start`` is
        the start value of the first segment merged into that block.
    """
    blocks = []
    buf_text = ""
    buf_start = None

    def flush():
        # Emit the buffered text (if any) as a block and reset the buffer.
        nonlocal buf_text, buf_start
        if buf_text.strip():
            blocks.append({"text": buf_text.strip(), "start": buf_start})
        buf_text, buf_start = "", None

    for seg in seg_list:
        raw_text = seg.get("text", "")
        start = float(seg.get("start", 0.0) or 0.0)
        text = normalize_de_text(raw_text)

        # Nothing pronounceable (empty or punctuation only): skip it. The
        # next block still carries its own start time.
        if not any(ch.isalnum() for ch in text):
            continue

        if not buf_text:
            # First segment of a new block defines the block's start time.
            buf_text = text
            buf_start = start
        else:
            buf_text = (buf_text + " " + text).strip()

        # Emit as soon as the merged text has enough letters.
        if alpha_count(buf_text) >= MIN_ALPHA_CHARS:
            flush()

    # Emit the trailing remainder, which may still be below the threshold.
    flush()
    return blocks

# === Synthesis ===
def synth_to_wav(text: str, wav_path: str, sample_rate: int = SAMPLE_RATE):
    """Render *text* with the module-level Coqui ``tts`` object into *wav_path*.

    If synthesis raises a ``RuntimeError`` whose message contains
    ``"Kernel size"``, the text is reduced to letters, digits, spaces and
    basic punctuation and synthesis is retried once. When the reduced text
    has fewer than ``MIN_ALPHA_CHARS`` letters, ``" ja"`` is appended first
    to lengthen it. Any other error, including one raised by the retry,
    propagates to the caller.

    Args:
        text: Text to speak.
        wav_path: Destination WAV file; overwritten by the retry if one runs.
        sample_rate: Sample rate passed to ``torchaudio.save``.
    """
    def _render(utterance: str):
        samples = tts.tts(utterance)
        if not isinstance(samples, np.ndarray):
            samples = np.array(samples, dtype=np.float32)
        # Add a leading axis before saving: [T] -> [1, T]
        # (assumes tts.tts yields a 1-D sample sequence).
        waveform = torch.from_numpy(samples.astype(np.float32)).unsqueeze(0)
        torchaudio.save(wav_path, waveform, sample_rate)

    try:
        _render(text)
    except RuntimeError as err:
        if "Kernel size" not in str(err):
            raise
        # Retry with a conservative character set.
        cleaned = re.sub(r"[^0-9A-Za-zÄÖÜäöüß ,.;:!?()\-\']", " ", text)
        cleaned = re.sub(r"\s+", " ", cleaned).strip().strip("-")
        if alpha_count(cleaned) < MIN_ALPHA_CHARS:
            # Still too short: pad with a filler word.
            cleaned = (cleaned + " ja").strip()
        _render(cleaned)

# === Main ===
# For every timetable entry without an output file: synthesize its merged
# text blocks one by one and lay them on a single timeline, padding with
# silence so that a block does not begin before its start time.
for idx in tqdm(range(start_idx, len(all_segments))):
    seg_list = all_segments[idx]
    final_path = os.path.join(OUTPUT_DIR, f"{idx}.wav")
    if os.path.exists(final_path):
        # Already produced (possible when the finished files have gaps).
        continue

    tmp_sentence_dir = os.path.join(TMP_DIR, str(idx))
    os.makedirs(tmp_sentence_dir, exist_ok=True)

    # Merge short segments first so that very short texts are not sent to
    # the synthesizer on their own.
    blocks = make_tts_blocks(seg_list)

    # Every output starts with 0.5 s of silence.
    full_audio = AudioSegment.silent(duration=500)

    for b_id, block in enumerate(blocks):
        text = block["text"]
        seg_start = float(block["start"] or 0.0)

        if not text.strip():
            continue

        tmp_path = os.path.join(tmp_sentence_dir, f"{b_id}.wav")
        synth_to_wav(text, tmp_path, SAMPLE_RATE)

        seg_audio = AudioSegment.from_wav(tmp_path)

        # Pad with silence up to the block's start time, measured from the
        # beginning of the output (lead-in included). If the audio assembled
        # so far already runs past that point, the block is appended at once.
        current_duration = full_audio.duration_seconds
        delay_ms = max(0, int(round((seg_start - current_duration) * 1000)))
        if delay_ms > 0:
            full_audio += AudioSegment.silent(duration=delay_ms)

        full_audio += seg_audio

    # Export under a temporary name and rename afterwards. An interrupted
    # run then never leaves a truncated "<idx>.wav" that the resume logic
    # would mistake for a finished file. export() returns the open output
    # handle, which is closed explicitly before the rename.
    partial_path = final_path + ".part"
    full_audio.export(partial_path, format="wav").close()
    os.replace(partial_path, final_path)
