import os
import sys
import json
from tqdm import tqdm
from pydub import AudioSegment
import torchaudio

# === Location of the CosyVoice checkout (edit as needed) ===
# Both the repo root and its bundled Matcha-TTS directory are appended to the
# import search path so the `cosyvoice` import further down can be resolved.
COSYVOICE_ROOT = "/path/to/CosyVoice"
sys.path.extend(
    [
        COSYVOICE_ROOT,
        os.path.join(COSYVOICE_ROOT, "third_party/Matcha-TTS"),
    ]
)

from cosyvoice.cli.cosyvoice import CosyVoice

# === File paths (edit as needed) ===
# SEGMENTS_JSONL: JSONL input; each line is a list of {"text": str, "start": float}
# MODEL_PATH:     CosyVoice model directory
# OUTPUT_DIR:     receives the final <idx>.wav files
# TMP_DIR:        receives the intermediate per-chunk WAVs
SEGMENTS_JSONL = "data/segment_timetable.jsonl"
MODEL_PATH = "checkpoints/CosyVoice-300M-SFT"
OUTPUT_DIR = "outputs/sentences"
TMP_DIR = "outputs/tmp"

# === Synthesis settings ===
# VOICE must be a speaker tag that is valid for the loaded CosyVoice model.
VOICE = "中文女"
# Length (milliseconds) of the silence placed at the start of every output.
PREPEND_SILENCE_MS = 500

# Make sure both output locations exist before anything is written to them.
for _out_dir in (OUTPUT_DIR, TMP_DIR):
    os.makedirs(_out_dir, exist_ok=True)

# === Initialize model ===
cosyvoice = CosyVoice(MODEL_PATH, load_jit=True, fp16=True)

# === Load segment timetable ===
# Whitespace-only lines are dropped; every remaining line is parsed as JSON.
with open(SEGMENTS_JSONL, "r", encoding="utf-8") as fin:
    all_segments = list(map(json.loads, filter(str.strip, fin)))

def find_start_idx(output_dir: str, total: int) -> int:
    """Return the first index in ``range(total)`` that has no ``<idx>.wav``.

    Only the exact file name ``f"{idx}.wav"`` counts as "already generated".
    Zero-padded names such as ``02.wav`` and names whose stem merely looks
    numeric (e.g. a superscript digit) are ignored instead of being coerced
    through ``int()``, which could either mis-attribute them to an index or
    raise ``ValueError``.

    Args:
        output_dir: Directory that holds the finished ``<idx>.wav`` files.
            It must exist; ``os.listdir`` raises ``FileNotFoundError``
            otherwise.
        total: Number of items expected (indices ``0 .. total - 1``).

    Returns:
        The smallest missing index, or ``total`` when every index is present.
    """
    # Single directory scan; membership tests below are O(1) each.
    present = set(os.listdir(output_dir))
    return next(
        (i for i in range(total) if f"{i}.wav" not in present),
        total,
    )

# === Decide where to resume ===
n_total = len(all_segments)
start_idx = find_start_idx(OUTPUT_DIR, n_total)

if start_idx >= n_total:
    print("All items already generated.")
    sys.exit(0)

print(f"Start from first missing index: {start_idx}/{n_total-1}")

# === Main synthesis loop ===
# Every timetable entry (one JSONL line) is rendered into one <idx>.wav:
# each of its text segments is synthesized and placed on a shared timeline
# according to its "start" time (seconds).
for idx in tqdm(range(start_idx, n_total), desc="Synthesize"):
    segments = all_segments[idx]  # list of {"text": str, "start": float}
    final_path = os.path.join(OUTPUT_DIR, f"{idx}.wav")
    if os.path.exists(final_path):
        # Indices after the first gap may already be done; skip those.
        continue

    tmp_sentence_dir = os.path.join(TMP_DIR, str(idx))
    os.makedirs(tmp_sentence_dir, exist_ok=True)

    # Start with a short leading silence. It is created at the model's sample
    # rate so that the output rate is the same even when no segment in this
    # item yields audio (pydub's default for silence is 11025 Hz).
    full_audio = AudioSegment.silent(
        duration=PREPEND_SILENCE_MS,
        frame_rate=cosyvoice.sample_rate,
    )

    for seg_id, seg in enumerate(segments):
        text = seg.get("text", "")
        seg_start = seg.get("start", None)  # seconds

        if not text:
            continue

        # TTS: the model may yield several chunks for one text; each chunk is
        # written to disk and read back as a pydub segment.
        for i, out in enumerate(cosyvoice.inference_sft(text, VOICE, stream=False)):
            tmp_path = os.path.join(tmp_sentence_dir, f"{seg_id}_{i}.wav")
            torchaudio.save(tmp_path, out["tts_speech"], cosyvoice.sample_rate)

            seg_audio = AudioSegment.from_wav(tmp_path)

            # Align by start time on the global timeline. Note that the
            # timeline position is the total length of full_audio so far,
            # which includes the leading silence.
            current_duration = full_audio.duration_seconds
            if seg_start is None:
                seg_start = current_duration  # append immediately if no timestamp

            # Pad with silence up to the requested start. When the audio so
            # far already runs past seg_start, the chunk is simply appended
            # (no overlap, no trimming).
            delay_ms = max(0, int((seg_start - current_duration) * 1000))
            if delay_ms:
                full_audio += AudioSegment.silent(duration=delay_ms)
            full_audio += seg_audio

    # Export the combined audio for this index. Resuming treats any existing
    # <idx>.wav as finished, so write to a side file first and rename it into
    # place only once it is complete; an interrupted run then never leaves a
    # truncated <idx>.wav behind.
    part_path = final_path + ".part"
    exported = full_audio.export(part_path, format="wav")
    exported.close()  # export() returns the still-open output handle
    os.replace(part_path, final_path)
