# Usage examples:
# python evaluation/tts.py  \
#   --audio_dir_or_jsonl '/path/to/experiments/voiceflow_infer' \
#   --xp_name voiceflow

# python evaluation/tts.py  \
#     --audio_dir_or_jsonl '/path/to/experiments/voiceflow_infer' \
#     --libritts_txt_dir '/path/to/data/tts/LibriTTS' \
#     --xp_name voiceflow

from pathlib import Path
import os
import json
import jiwer
import argparse
import string
import numpy as np

from tqdm import tqdm
from whisper_normalizer.english import EnglishTextNormalizer
import torch.nn.functional as F
import torchaudio
import nemo.collections.asr as nemo_asr

from utils.general import read_jsonl_to_mapping

english_normalizer = EnglishTextNormalizer()

# asr model: https://huggingface.co/nvidia/stt_en_conformer_transducer_xlarge


# speaker model: https://huggingface.co/nvidia/speakerverification_en_titanet_large
def get_audio_duration(filepath):
    """Return the duration of an audio file in seconds.

    The length is the frame count divided by the sample rate, both taken
    from ``torchaudio.info``. Any failure while probing the file is printed
    and reported as a duration of ``0.0`` instead of being raised.
    """
    try:
        info = torchaudio.info(filepath)
        return info.num_frames / info.sample_rate
    except Exception as err:
        print(f"Skipping file {filepath}, reason: {err}")
        return 0.0


def load_asr_model(model_name, lang='en', ckpt_dir=""):
    """Load the ASR backend used to transcribe audio.

    Backend libraries are imported lazily inside the matching branch, so
    only the selected backend has to be importable.

    Args:
        model_name: ``"whisper"`` or ``"nemo"``.
        lang: Only consulted for ``"whisper"``. ``"en"`` loads a
            faster-whisper ``WhisperModel`` on CUDA in float16; ``"zh"``
            loads FunASR ``paraformer-zh`` from ``ckpt_dir``.
        ckpt_dir: For ``"zh"``, the directory containing ``paraformer-zh``.
            For ``"en"``, the model path/size handed to ``WhisperModel``;
            an empty string selects ``"large-v3"``. Ignored for ``"nemo"``.

    Returns:
        The loaded model object of the selected backend.

    Raises:
        ValueError: If ``model_name`` is unknown, or ``lang`` is not
            supported for ``"whisper"``. Previously these cases fell through
            to an unbound local variable at ``return``.
    """
    if model_name == "whisper":  # requires numpy==2.2
        if lang == "zh":
            from funasr import AutoModel
            return AutoModel(
                model=os.path.join(ckpt_dir, "paraformer-zh"),
                disable_update=True,
            )  # following seed-tts setting
        if lang == "en":
            from faster_whisper import WhisperModel
            model_size = "large-v3" if ckpt_dir == "" else ckpt_dir
            return WhisperModel(
                model_size, device="cuda", compute_type="float16"
            )
        raise ValueError(
            f"Unsupported language {lang!r} for model 'whisper'; "
            "expected 'en' or 'zh'"
        )

    if model_name == "nemo":  # requires numpy<2.0
        import nemo.collections.asr as nemo_asr
        return nemo_asr.models.EncDecRNNTBPEModel.from_pretrained(
            "nvidia/stt_en_conformer_transducer_xlarge"
        )

    raise ValueError(
        f"Unsupported ASR model {model_name!r}; expected 'whisper' or 'nemo'"
    )


def get_all_audio_files(directory, exts={".wav"}):
    """Recursively collect the paths of audio files below ``directory``.

    A file is kept when its extension, lower-cased, is a member of ``exts``.
    Each returned path is the walked folder joined with the file name, in
    ``os.walk`` order.
    """
    return [
        os.path.join(folder, name)
        for folder, _, names in os.walk(directory)
        for name in names
        if os.path.splitext(name)[-1].lower() in exts
    ]


def get_libritts_text_mapping(txt_root):
    """Build an utterance-id -> reference-text mapping from transcript files.

    Walks ``txt_root`` recursively and reads every ``*.normalized.txt`` file.
    The key is the file name without that suffix and the value is the file
    content with surrounding whitespace stripped.
    Example: '5683_32865_000001_000000' -> 'reference text'

    Args:
        txt_root: Directory to search (e.g. LibriTTS test-clean).

    Returns:
        dict mapping utterance id to transcript text. Empty when no matching
        file is found, including when ``txt_root`` does not exist.
    """
    suffix = ".normalized.txt"
    mapping = {}
    for root, _, files in os.walk(txt_root):
        for name in files:
            if not name.endswith(suffix):
                continue
            # Strip only the trailing suffix; str.replace would also remove
            # any earlier occurrence inside the file name.
            utt_id = name[:-len(suffix)]
            # Read as UTF-8 explicitly so the result does not depend on the
            # platform's default locale encoding.
            with open(os.path.join(root, name), "r", encoding="utf-8") as t:
                mapping[utt_id] = t.read().strip()
    return mapping


def get_reference_audio_mapping(txt_root):
    """
    Map utterance IDs to reference .wav paths found under ``txt_root``.
    The directory tree is walked recursively; each file name ending in
    '.wav' becomes a key (with '.wav' removed) pointing at its joined path.
    Paths that do not resolve to an existing file are left out.
    Returns: {'5683_32865_000001_000000': '/path/to/ref.wav'}
    """
    ref_paths = {}
    for folder, _, names in os.walk(txt_root):
        wav_names = (name for name in names if name.endswith(".wav"))
        for wav_name in wav_names:
            candidate = os.path.join(folder, wav_name)
            if not os.path.exists(candidate):
                continue
            ref_paths[wav_name.replace(".wav", "")] = candidate
    return ref_paths


def extract_utt_id(filepath):
    """Extract utterance ID like '5683_32865_000001_000000' from file path.

    The ID is the file name without its directory and extension, with any
    leading '0' characters removed.

    Args:
        filepath: Path to an audio file.

    Returns:
        The utterance ID string. Empty when the file stem is empty or
        consists only of zeros; the previous character-by-character loop
        raised IndexError in that case.
    """
    stem = os.path.splitext(os.path.basename(filepath))[0]
    return stem.lstrip("0")


def _load_or_build_reference_map(cache_path, build_fn, txt_root):
    """Return the JSON mapping cached at ``cache_path``, building it on a miss.

    On a miss the mapping is produced by ``build_fn(txt_root)`` and written
    to ``cache_path`` (parent directories are created as needed).
    """
    cache_file = Path(cache_path)
    if cache_file.exists():
        with open(cache_file, 'r', encoding='utf-8') as f:
            return json.load(f)

    mapping = build_fn(txt_root)
    cache_file.parent.mkdir(parents=True, exist_ok=True)
    with open(cache_file, 'w', encoding='utf-8') as f:
        json.dump(mapping, f, indent=2)
    return mapping


def evaluate_tts(
    audio_dir_or_jsonl: str, libritts_txt_dir: str, ref_transcript_path: str,
    ref_audio_path: str, output_path: str, model_name: str, xp_name: str
) -> None:
    """Score generated speech by ASR word error rate and speaker similarity.

    Each audio file is transcribed with the selected ASR model and compared
    with its reference transcript (WER). A speaker embedding of the file is
    compared with the embedding of its reference audio (cosine similarity).
    Per-file results plus a final summary line are written as JSONL.

    Args:
        audio_dir_or_jsonl: Directory searched recursively for audio files,
            or a ``.jsonl`` file read through ``read_jsonl_to_mapping`` with
            the keys ``"audio_id"`` and ``"audio"``.
        libritts_txt_dir: Root used to build the reference maps when the
            cache files below do not exist yet.
        ref_transcript_path: JSON cache of utterance id -> reference text.
        ref_audio_path: JSON cache of utterance id -> reference audio path.
        output_path: Destination JSONL file. An empty string selects
            ``./evaluation/result/tts_results_{model_name}_{xp_name}.jsonl``.
        model_name: ASR backend name passed to ``load_asr_model``.
        xp_name: Experiment name used in the default output file name.

    Raises:
        ValueError: If ``audio_dir_or_jsonl`` is neither an existing
            directory nor a path ending in ``.jsonl``.
    """
    # Fail fast on an unusable input before loading any large model.
    input_is_dir = Path(audio_dir_or_jsonl).is_dir()
    if not input_is_dir and not audio_dir_or_jsonl.endswith(".jsonl"):
        raise ValueError(
            f"{audio_dir_or_jsonl!r} is neither an existing directory "
            "nor a .jsonl file"
        )

    print("Loading ASR model...")
    asr_model = load_asr_model(model_name)

    print("Loading speaker embedding model...")
    speaker_model = nemo_asr.models.EncDecSpeakerLabelModel.from_pretrained(
        "nvidia/speakerverification_en_titanet_large"
    )

    print("Building reference transcript map...")
    ref_map = _load_or_build_reference_map(
        ref_transcript_path, get_libritts_text_mapping, libritts_txt_dir
    )

    print("Building reference audio map...")
    ref_audio_map = _load_or_build_reference_map(
        ref_audio_path, get_reference_audio_mapping, libritts_txt_dir
    )

    if input_is_dir:
        audio_files = get_all_audio_files(audio_dir_or_jsonl)
    else:
        aid_to_audio = read_jsonl_to_mapping(
            audio_dir_or_jsonl, "audio_id", "audio"
        )
        audio_files = list(aid_to_audio.values())

    print(f"Found {len(audio_files)} audio files")

    total_word_errors = 0.0
    total_words = 0
    similarities = []
    results = []

    # One translation table removes all ASCII punctuation in a single pass.
    strip_punctuation = str.maketrans("", "", string.punctuation)

    for wav_path in tqdm(audio_files, desc="Evaluating audio files"):
        utt_id = extract_utt_id(wav_path)
        if utt_id not in ref_map:
            print(f"Skipping {wav_path}, reference text not found")
            continue
        if utt_id not in ref_audio_map:
            print(f"Skipping {wav_path}, reference audio not found")
            continue
        wav_dur = get_audio_duration(wav_path)
        if wav_dur < 0.5:
            print(f"Skipping {wav_path}, duration less than 0.5s")
            continue

        reference = ref_map[utt_id]
        try:
            if model_name == 'whisper':
                segments, _ = asr_model.transcribe(
                    wav_path, beam_size=5, language="en"
                )
                pred_text = "".join(
                    " " + segment.text for segment in segments
                )
            elif model_name == 'nemo':
                pred_text = asr_model.transcribe([wav_path],
                                                 verbose=False)[0].text.strip()

            # Apply the same text normalization to hypothesis and reference.
            pred_text = pred_text.translate(strip_punctuation)
            reference = reference.translate(strip_punctuation)
            pred_text = english_normalizer(pred_text.replace("  ", " "))
            reference = english_normalizer(reference.replace("  ", " "))

            ref_len = len(reference.lower().split())
            wer = jiwer.wer(reference.lower(), pred_text.lower())

            pred_emb = speaker_model.get_embedding(wav_path)
            ref_emb = speaker_model.get_embedding(ref_audio_map[utt_id])
            sim = F.cosine_similarity(pred_emb, ref_emb)[0].item()
        except Exception as e:
            # Best effort: one failing file must not abort the whole run.
            print(f"Skipping file {wav_path}, reason: {e}")
            continue

        # Accumulate only after both metrics succeeded, so that the averages
        # cover exactly the samples listed in `results`.
        total_word_errors += wer * ref_len
        total_words += ref_len
        similarities.append(sim)
        results.append({
            "utt_id": utt_id,
            "audio": wav_path,
            # Path of the reference audio (taken from the audio map, not the
            # transcript map).
            "ref_audio": ref_audio_map[utt_id],
            "reference": reference,
            "prediction": pred_text,
            "WER": wer,
            "SIM": float(sim),
        })

    # Corpus-level WER: total word errors over total reference words.
    avg_wer = total_word_errors / total_words if total_words > 0 else 0.0
    avg_sim = float(np.mean(similarities)) if similarities else 0.0

    if output_path == '':
        output_path = Path(
            './evaluation/result'
        ) / f'tts_results_{model_name}_{xp_name}.jsonl'
    else:
        output_path = Path(output_path)

    output_path.parent.mkdir(exist_ok=True, parents=True)
    with open(output_path, 'w', encoding='utf-8') as f:
        for r in results:
            json.dump(r, f)
            f.write('\n')
        # The last line holds the aggregate metrics.
        json.dump({
            "average_wer": avg_wer,
            "average_cosine_similarity": avg_sim
        }, f)
        f.write('\n')

    print(
        f"\n✅ Evaluation done: {len(results)} samples, average WER (weighted): {avg_wer}"
    )
    print(
        f"\n✅ Evaluation done: {len(similarities)} samples, average Cosine similarity: {avg_sim}"
    )
    print(f"Results saved to {output_path}")


if __name__ == "__main__":
    # Command-line entry point: parse the options and run one evaluation.
    parser = argparse.ArgumentParser(
        description='Evaluate TTS outputs with ASR-based WER and '
        'speaker cosine similarity.'
    )
    parser.add_argument(
        '--audio_dir_or_jsonl',
        type=str,
        required=True,
        help=
        'Directory or JSONL file containing TTS-generated audio files (recursive search)'
    )
    parser.add_argument(
        '--libritts_txt_dir',
        type=str,
        default='/path/to/data/tts/LibriTTS/test-clean',
        help='Directory for LibriTTS test-clean files'
    )
    parser.add_argument(
        '--ref_transcript_path',
        type=str,
        default='./data/libritts/voiceflow_test/ref_transcript.json',
        help='Path to reference transcript JSON file'
    )
    parser.add_argument(
        '--ref_audio_path',
        type=str,
        default='./data/libritts/voiceflow_test/ref_audio.json',
        help='Path to reference audio JSON file'
    )
    parser.add_argument(
        '--output_path',
        type=str,
        default='',
        help='Output path for evaluation results'
    )
    # Only these two backends are implemented; reject anything else at parse
    # time instead of failing later during model loading.
    parser.add_argument(
        '--model_name',
        type=str,
        default='nemo',
        choices=['nemo', 'whisper'],
        help='Name of ASR model to use'
    )
    parser.add_argument(
        '--xp_name', type=str, default='', help='Experiment name'
    )

    args = parser.parse_args()

    evaluate_tts(
        args.audio_dir_or_jsonl, args.libritts_txt_dir,
        args.ref_transcript_path, args.ref_audio_path, args.output_path,
        args.model_name, args.xp_name
    )
