# Usage: python evaluation/svs.py \
#   --ref_audio_jsonl data/m4singer/test/audio.jsonl \
#   --gen_audio_dir "/path/to/experiments/inference/singing_voice_synthesis" \
#   --xp_name "double"
# Instead of `--gen_audio_dir`, you can pass `--gen_audio_jsonl`; such a jsonl
# can be generated by `generate_postprocess/make_audio_jsonl.py`.

import argparse
from collections import defaultdict
from multiprocessing import Pool
from functools import partial
from math import log2

import numpy as np
import pysptk
import soundfile as sf
import librosa
# from fastdtw import fastdtw
from scipy import spatial
from tqdm import tqdm
import pyworld as pw
import json

from utils.general import read_jsonl_to_mapping, audio_dir_to_mapping

# Frame analysis settings, in samples, shared by the SPTK and WORLD extractors.
N_FFT: int = 1024  # analysis window / FFT length
N_SHIFT: int = 256  # hop size between consecutive frames

# F0 search range in Hz handed to WORLD harvest (f0_floor / f0_ceil).
F0_MIN: int = 40
F0_MAX: int = 800

# Mel-cepstrum order and all-pass coefficient. Leaving either as None makes
# the extractors fall back to the per-sampling-rate preset table.
MCEP_DIM = None
MCEP_ALPHA = None


def _get_best_mcep_params(fs: int) -> tuple[int, float]:
    if fs == 16000:
        return 23, 0.42
    elif fs == 22050:
        return 34, 0.45
    elif fs == 24000:
        return 34, 0.46
    elif fs == 44100:
        return 39, 0.53
    elif fs == 48000:
        return 39, 0.55
    else:
        raise ValueError(f"Not found the setting for {fs}.")


def sptk_extract(
    x: np.ndarray,
    fs: int,
    n_fft: int = 512,
    n_shift: int = 256,
    mcep_dim: int = 25,
    mcep_alpha: float = 0.41,
    is_padding: bool = False,
) -> np.ndarray:
    """Extract SPTK-based mel-cepstrum.

    The waveform is cut into Hamming-windowed frames of ``n_fft`` samples
    taken every ``n_shift`` samples, and ``pysptk.mcep`` is applied to each.

    Args:
        x (ndarray): 1D waveform array.
        fs (int): Sampling rate. Only used to look up preset mel-cepstrum
            parameters when ``mcep_dim`` or ``mcep_alpha`` is None.
        n_fft (int): FFT (frame) length in points (default=512).
        n_shift (int): Shift length in points (default=256).
        mcep_dim (int): Order of the mel-cepstrum (default=25). If this or
            ``mcep_alpha`` is None, both come from
            ``_get_best_mcep_params(fs)``.
        mcep_alpha (float): All pass filter coefficient (default=0.41).
        is_padding (bool): Whether to reflect-pad the end of the signal
            (default=False).

    Returns:
        ndarray: Mel-cepstrum with the size (N, mcep_dim + 1), where N is
        the number of frames.

    Raises:
        ValueError: If the (optionally padded) signal is shorter than
            ``n_fft`` so that no frame can be extracted, or if preset
            parameters are requested for an unsupported ``fs``.
    """
    # perform padding
    if is_padding:
        n_pad = n_fft - (len(x) - n_fft) % n_shift
        x = np.pad(x, (0, n_pad), "reflect")

    # get number of frames
    n_frame = (len(x) - n_fft) // n_shift + 1
    if n_frame <= 0:
        # Without this check np.stack([]) fails with an unhelpful message.
        raise ValueError(
            f"Signal of length {len(x)} is shorter than n_fft={n_fft}; "
            "cannot extract any frame."
        )

    # get window function
    win = pysptk.sptk.hamming(n_fft)

    # check mcep and alpha
    if mcep_dim is None or mcep_alpha is None:
        mcep_dim, mcep_alpha = _get_best_mcep_params(fs)

    # calculate mel-cepstrum frame by frame; multiplying by the float64
    # window also promotes integer (e.g. int16) input to float64.
    mcep = [
        pysptk.mcep(
            x[n_shift * i:n_shift * i + n_fft] * win,
            mcep_dim,
            mcep_alpha,
            eps=1e-6,
            etype=1,
        ) for i in range(n_frame)
    ]

    return np.stack(mcep)


def world_extract(
    x: np.ndarray,
    fs: int,
    f0min: int = 40,
    f0max: int = 800,
    n_fft: int = 512,
    n_shift: int = 256,
    mcep_dim: int = 25,
    mcep_alpha: float = 0.41,
) -> tuple[np.ndarray, np.ndarray]:
    """Extract World-based acoustic features.

    F0 is estimated with WORLD harvest. The spectral envelope comes from
    WORLD cheaptrick and is converted to mel-cepstrum with
    ``pysptk.sp2mc``.

    Args:
        x (ndarray): 1D waveform array (converted to float64 internally).
        fs (int): Sampling rate.
        f0min (int): Minimum f0 value in Hz, used as harvest ``f0_floor``
            (default=40).
        f0max (int): Maximum f0 value in Hz, used as harvest ``f0_ceil``
            (default=800).
        n_fft (int): FFT length in points, used as the cheaptrick
            ``fft_size`` (default=512).
        n_shift (int): Shift length in points (default=256). It is
            converted to the harvest frame period in milliseconds.
        mcep_dim (int): Order of the mel-cepstrum (default=25). If this or
            ``mcep_alpha`` is None, both come from
            ``_get_best_mcep_params(fs)``.
        mcep_alpha (float): All pass filter coefficient (default=0.41).

    Returns:
        ndarray: Mel-cepstrum with the size (N, mcep_dim + 1).
        ndarray: F0 sequence (N,); unvoiced frames are 0.

    """
    # extract features
    x = x.astype(np.float64)
    f0, time_axis = pw.harvest(
        x,
        fs,
        f0_floor=f0min,
        f0_ceil=f0max,
        frame_period=n_shift / fs * 1000,  # hop of n_shift samples, in ms
    )
    sp = pw.cheaptrick(x, f0, time_axis, fs, fft_size=n_fft)
    if mcep_dim is None or mcep_alpha is None:
        mcep_dim, mcep_alpha = _get_best_mcep_params(fs)
    mcep = pysptk.sp2mc(sp, mcep_dim, mcep_alpha)

    return mcep, f0


def _Hz2Semitone(freq):
    """_Hz2Semitone."""
    A4 = 440
    C0 = A4 * pow(2, -4.75)
    name = ["C", "C#", "D", "D#", "E", "F", "F#", "G", "G#", "A", "A#", "B"]

    if freq == 0:
        return "Sil"  # silence
    else:
        h = round(12 * log2(freq / C0))
        octave = h // 12
        n = h % 12
        return name[n] + "_" + str(octave)


def _dtw_indices(gen_feat: np.ndarray, gt_feat: np.ndarray):
    """Align two (frames, dim) feature sequences with DTW; return frame indices.

    Returns:
        Tuple ``(gen_idx, gt_idx)`` of equal-length integer arrays such that
        ``gen_feat[gen_idx[k]]`` is paired with ``gt_feat[gt_idx[k]]`` along
        the warping path (the path order does not matter for the averaged
        metrics computed from it).
    """
    _, path = librosa.sequence.dtw(
        X=gen_feat.T, Y=gt_feat.T, metric='euclidean'
    )
    twf = np.array(path).T
    return twf[0], twf[1]


def compute_metrics(entry: tuple[str, str, str], args):
    """Compute MCD, log-F0 RMSE and semitone accuracy for one audio pair.

    Args:
        entry: ``(audio_id, ref_audio, gen_audio)``, where the last two are
            audio file paths read with ``soundfile``.
        args: Parsed CLI arguments. Not used here; accepted so the function
            can be bound with ``functools.partial(..., args=args)``.

    Returns:
        ``(audio_id, metrics)`` where ``metrics`` maps:
            ``"mcd"``: mean mel-cepstral distortion over DTW-aligned frames,
            ``"f0"``: RMSE of log-F0 over aligned frames voiced in both
                signals (NaN if there is no such frame),
            ``"semitone"``: fraction of aligned frames whose semitone labels
                match (silent frames included).
    """
    audio_id, ref_audio, gen_audio = entry

    # Load both files as int16 so they share the same amplitude scale.
    gen_x, gen_fs = sf.read(gen_audio, dtype="int16")
    gt_x, gt_fs = sf.read(ref_audio, dtype="int16")

    # All features are extracted at the generated audio's sampling rate.
    fs = gen_fs
    if gen_fs != gt_fs:
        # Ensure the reference is resampled to the generated file's rate
        gt_x = librosa.resample(
            gt_x.astype(np.float32), orig_sr=gt_fs, target_sr=gen_fs
        )

    feat_kwargs = dict(
        fs=fs,
        n_fft=N_FFT,
        n_shift=N_SHIFT,
        mcep_dim=MCEP_DIM,
        mcep_alpha=MCEP_ALPHA,
    )

    ##########################################################################
    # Calculate MCD
    ##########################################################################
    gen_mcep = sptk_extract(x=gen_x, **feat_kwargs)
    gt_mcep = sptk_extract(x=gt_x, **feat_kwargs)

    gen_idx, gt_idx = _dtw_indices(gen_mcep, gt_mcep)
    diff2sum = np.sum((gen_mcep[gen_idx] - gt_mcep[gt_idx])**2, axis=1)
    mcd = np.mean(10.0 / np.log(10.0) * np.sqrt(2 * diff2sum))

    ##########################################################################
    # Calculate F0
    ##########################################################################
    f0_kwargs = dict(feat_kwargs, f0min=F0_MIN, f0max=F0_MAX)
    gen_world_mcep, gen_f0 = world_extract(x=gen_x, **f0_kwargs)
    gt_world_mcep, gt_f0 = world_extract(x=gt_x, **f0_kwargs)

    # Align on WORLD mel-cepstra, then read F0 along that path.
    gen_idx, gt_idx = _dtw_indices(gen_world_mcep, gt_world_mcep)
    gen_f0_dtw = gen_f0[gen_idx]
    gt_f0_dtw = gt_f0[gt_idx]

    # log F0 RMSE over frames voiced (non-zero F0) in both signals
    voiced = (gen_f0_dtw != 0) & (gt_f0_dtw != 0)
    if voiced.any():
        log_diff = np.log(gen_f0_dtw[voiced]) - np.log(gt_f0_dtw[voiced])
        log_f0_rmse = np.sqrt(np.mean(log_diff**2))
    else:
        # Same NaN that np.mean([]) yields, without its RuntimeWarning.
        log_f0_rmse = float("nan")

    ##########################################################################
    # Calculate Semitone Accuracy
    ##########################################################################
    gt_semitone = np.array([_Hz2Semitone(f0) for f0 in gt_f0_dtw])
    gen_semitone = np.array([_Hz2Semitone(f0) for f0 in gen_f0_dtw])
    semitone_acc = float((gt_semitone == gen_semitone).sum()
                        ) / len(gt_semitone)

    return audio_id, {
        "mcd": mcd,
        "f0": log_f0_rmse,
        "semitone": semitone_acc,
    }


def evaluate(args):
    """Calculate MCD, F0, and Semitone Accuracy.

    Reference audio paths are read from ``args.ref_audio_jsonl``. Generated
    audio paths are read from ``args.gen_audio_jsonl`` if given, otherwise
    from ``args.gen_audio_dir``. Metrics are computed per audio id in a
    process pool of ``args.num_workers`` workers.

    The output file (``args.output_file``, or
    ``./evaluation/result/svs_results_<xp_name>.jsonl`` when empty)
    receives one JSON line per audio id. It is followed by plain-text
    ``metric: mean`` summary lines, which are also printed.

    Raises:
        ValueError: If neither ``gen_audio_jsonl`` nor ``gen_audio_dir`` is
            given, or if reference and generated audio ids differ.
    """
    from pathlib import Path

    ref_aid_to_audios = read_jsonl_to_mapping(
        args.ref_audio_jsonl, "audio_id", "audio"
    )

    if args.gen_audio_jsonl is not None:
        gen_aid_to_audios = read_jsonl_to_mapping(
            args.gen_audio_jsonl, "audio_id", "audio"
        )
    elif args.gen_audio_dir is not None:
        gen_aid_to_audios = audio_dir_to_mapping(args.gen_audio_dir, 'svs')
    else:
        # Previously this fell through to an UnboundLocalError below.
        raise ValueError(
            "Either gen_audio_jsonl or gen_audio_dir must be provided"
        )

    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if ref_aid_to_audios.keys() != gen_aid_to_audios.keys():
        raise ValueError("Reference and generated audio IDs do not match")

    # Resolve and prepare the output location before the expensive metric
    # computation, so a bad path fails fast.
    if args.output_file == '':
        output_path = './evaluation/result/' + '_'.join([
            "svs_results", args.xp_name
        ]) + '.jsonl'
    else:
        output_path = args.output_file
    Path(output_path).parent.mkdir(parents=True, exist_ok=True)

    audio_ids = list(ref_aid_to_audios.keys())
    results = defaultdict(dict)  # metric -> {audio_id: value}
    entries = [(aid, ref_aid_to_audios[aid], gen_aid_to_audios[aid])
               for aid in audio_ids]

    with Pool(processes=args.num_workers) as pool:
        worker = partial(compute_metrics, args=args)
        # imap keeps results in the same order as `entries`.
        for audio_id, metrics in tqdm(
            pool.imap(worker, entries),
            total=len(entries),
            desc="Computing metrics"
        ):
            for metric, value in metrics.items():
                results[metric][audio_id] = value

    # Pivot to audio_id -> {metric: value} for per-utterance output lines.
    audio_id2metric = {audio_id: {} for audio_id in audio_ids}
    for metric, values in results.items():
        for audio_id, value in values.items():
            audio_id2metric[audio_id][metric] = value

    with open(output_path, "w") as writer:
        for audio_id, metrics in audio_id2metric.items():
            line = metrics
            line['audio_id'] = audio_id
            json.dump(line, writer, ensure_ascii=False)
            writer.write('\n')
        for metric, values in results.items():
            if metric == "semitone":
                print_msg = f"{metric}: {np.mean(list(values.values())):.3%}"
            else:
                print_msg = f"{metric}: {np.mean(list(values.values())):.5f}"
            print(print_msg)
            print(print_msg, file=writer)


if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description="Evaluate singing voice synthesis outputs with MCD, "
        "log-F0 RMSE and semitone accuracy."
    )
    parser.add_argument(
        "--ref_audio_jsonl",
        "-r",
        type=str,
        required=True,
        help="path to reference audio jsonl file"
    )
    parser.add_argument(
        "--gen_audio_dir",
        "-gd",
        type=str,
        help="path to generated audio directory"
    )
    parser.add_argument(
        "--gen_audio_jsonl",
        "-gj",
        type=str,
        help="path to generated audio jsonl file "
        "(takes precedence over --gen_audio_dir)"
    )
    parser.add_argument(
        "--output_file",
        "-o",
        type=str,
        default='',
        help="path to output file"
    )
    parser.add_argument(
        "--num_workers",
        "-c",
        default=16,
        type=int,
        help="number of workers for parallel processing"
    )
    parser.add_argument(
        '--xp_name', type=str, default='', help='experiment name'
    )
    args = parser.parse_args()

    # At least one source of generated audio is needed; report it as a
    # normal CLI usage error rather than failing deep inside evaluate().
    if args.gen_audio_dir is None and args.gen_audio_jsonl is None:
        parser.error(
            "one of --gen_audio_dir/-gd or --gen_audio_jsonl/-gj is required"
        )

    evaluate(args)
