# Usage: python evaluation/sr.py \
#   --ref_audio_jsonl /path/to/data/moisesdb/test/audio.jsonl \
#   --gen_audio_jsonl /path/to/sr_files.jsonl \
#   --output_file /path/to/evaluation/result/sr.jsonl \
#   -c 16
# where `gen_audio_jsonl` can be generated by `generate_postprocess/make_audio_jsonl.py`

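# Alternatively, take the generated audio from a directory instead of a jsonl:
# python evaluation/sr.py \
#   --ref_audio_jsonl /path/to/data/moisesdb/test/audio.jsonl \
#   --gen_audio_dir /path/to/experiments/dummy/inference/audio_super_resolution \
#   --output_file /path/to/evaluation/result/sr.jsonl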
import json
from pathlib import Path
import argparse
from concurrent.futures import ThreadPoolExecutor, as_completed
from functools import partial

from tqdm import tqdm
from ssr_eval.metrics import AudioMetrics
import librosa
import numpy as np

from utils.general import read_jsonl_to_mapping, audio_dir_to_mapping


# Define worker function to be executed in parallel
def process_audio(audio_id, base_audio_dir, ref_dict, gen_dict, evaluator):
    """Compute the LSD metric for one generated/reference audio pair.

    Args:
        audio_id: Key identifying the sample in ``ref_dict`` and ``gen_dict``.
        base_audio_dir: Optional directory prepended to the reference entry.
            When ``None``, the entry from ``ref_dict`` is used as the path
            unchanged.
        ref_dict: Mapping from audio id to the reference audio path.
        gen_dict: Mapping from audio id to the generated audio path.
        evaluator: Object whose ``evaluation(gen, ref, gen_path)`` returns a
            mapping that contains an ``"lsd"`` entry.

    Returns:
        The ``"lsd"`` value reported by ``evaluator``, or ``None`` when the
        sample has no reference entry or loading/evaluation fails.
    """
    if audio_id not in ref_dict:
        print(
            f"[WARN] audio_id {audio_id} not found in reference dict, skipping..."
        )
        return None

    # The base directory is optional; joining ``None`` with a path would raise
    # a TypeError outside the try block and abort the whole evaluation run.
    ref_path = ref_dict[audio_id]
    if base_audio_dir is not None:
        ref_path = Path(base_audio_dir) / ref_path
    gen_path = gen_dict[audio_id]

    try:
        # Load reference and generated audio, both resampled to 24 kHz
        ref, _ = librosa.core.load(ref_path, sr=24000)
        gen, _ = librosa.core.load(gen_path, sr=24000)

        # Ensure generated audio matches the reference length:
        # zero-pad at the end when shorter, truncate otherwise.
        if len(gen) < len(ref):
            gen = np.pad(gen, (0, len(ref) - len(gen)), "constant")
        else:
            gen = gen[:len(ref)]

        # Compute evaluation metrics
        metrics = evaluator.evaluation(gen, ref, gen_path)
        return metrics["lsd"]
    except Exception as e:
        # Best-effort: one bad sample is reported and skipped so the
        # remaining samples are still evaluated.
        print(f"[ERROR] Failed on {audio_id}: {e}")
        return None


def evaluate(args):
    """Compute the mean LSD over all generated audio and save it as JSON.

    Args:
        args: Parsed command-line namespace. The attributes read are
            ``ref_audio_jsonl``, ``gen_audio_jsonl``, ``gen_audio_dir``,
            ``ref_audio_base_dir``, ``num_workers`` and ``output_file``.
            When both ``gen_audio_jsonl`` and ``gen_audio_dir`` are set, the
            jsonl file takes precedence.

    Raises:
        ValueError: If neither ``gen_audio_jsonl`` nor ``gen_audio_dir`` is
            provided.
    """
    # Validate first: without this, ``gen_dict`` would never be bound and the
    # function would fail later with an opaque UnboundLocalError.
    if args.gen_audio_jsonl is None and args.gen_audio_dir is None:
        raise ValueError(
            "either --gen_audio_jsonl or --gen_audio_dir must be provided"
        )

    ref_dict = read_jsonl_to_mapping(args.ref_audio_jsonl, "audio_id", "audio")
    if args.gen_audio_jsonl is not None:
        gen_dict = read_jsonl_to_mapping(
            args.gen_audio_jsonl, "audio_id", "audio"
        )
    else:
        gen_dict = audio_dir_to_mapping(
            args.gen_audio_dir, "audio_id", "audio"
        )

    base_audio_dir = args.ref_audio_base_dir
    if base_audio_dir is not None:
        base_audio_dir = Path(base_audio_dir)

    # Same 24 kHz rate that process_audio uses when loading the audio.
    evaluator = AudioMetrics(rate=24000)

    total_lsd = 0.0
    count = 0

    worker = partial(
        process_audio,
        base_audio_dir=base_audio_dir,
        ref_dict=ref_dict,
        gen_dict=gen_dict,
        evaluator=evaluator
    )

    with ThreadPoolExecutor(max_workers=args.num_workers) as executor:
        futures = {
            executor.submit(worker, audio_id): audio_id
            for audio_id in gen_dict
        }

        for future in tqdm(as_completed(futures), total=len(futures)):
            result = future.result()
            # ``None`` marks a skipped or failed sample; exclude it from the mean.
            if result is not None:
                total_lsd += result
                count += 1

    if count == 0:
        # Nothing was evaluated, so there is no mean to report or save.
        print("not run eval")
        return

    avg_lsd = total_lsd / count
    print(f"\n eval  {count}  audio samples")
    print(f" mean LSD: {avg_lsd:.6f}")

    # Write the result as a JSON file
    output_path = Path(args.output_file)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    with open(output_path, "w", encoding="utf-8") as f:
        json.dump({"avg_lsd": avg_lsd, "num_samples": count}, f, indent=2)
        f.write("\n")
    print(f"result save to : {output_path}")


if __name__ == "__main__":
    # Command-line entry point: parse the arguments and run the evaluation.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--ref_audio_jsonl",
        "-r",
        type=str,
        required=True,
        help="path to reference audio jsonl file"
    )
    parser.add_argument(
        "--ref_audio_base_dir",
        "-rb",
        type=str,
        help="path to reference audio base directory"
    )
    parser.add_argument(
        "--gen_audio_dir",
        "-gd",
        type=str,
        help="path to generated audio directory"
    )
    parser.add_argument(
        "--gen_audio_jsonl",
        "-gj",
        type=str,
        help="path to generated audio jsonl file"
    )

    parser.add_argument(
        "--output_file",
        "-o",
        type=str,
        required=True,
        help="path to output file"
    )
    parser.add_argument(
        "--num_workers",
        "-c",
        default=4,
        type=int,
        help="number of worker threads for parallel processing"
    )

    args = parser.parse_args()
    # At least one source of generated audio is needed; report it as a usage
    # error here rather than failing later inside evaluate().
    if args.gen_audio_jsonl is None and args.gen_audio_dir is None:
        parser.error(
            "one of --gen_audio_jsonl or --gen_audio_dir is required"
        )
    evaluate(args)
