'''
Compute the song-level pitch WER (objective metric) of Moyu's model on the Slakh2100 test set.
'''

import os
import sys
dirof = os.path.dirname
sys.path.insert(0, dirof(dirof(dirof(__file__))))

from [anonymous] import MultiTrack, Bar  
from evaluations.piano_evaluator import song_level_pitch_wer
from utils_common.utils import ls, save_json
from tqdm import tqdm


def main(
    test_data_dir='/data2/[anonymous]/Datasets/slakh2100_flac_redux/original/test',
    out_dir='/data2/[anonymous]/[anonymous]_data/out_moyu/infer_test_set',
    save_fp='/data2/[anonymous]/[anonymous]_data/out_moyu/infer_test_set_wer.json',
):
    '''
    Evaluate generated MIDI files against references with song-level pitch WER.

    For every name returned by ``ls(test_data_dir)``, the generated file
    ``{out_dir}/{song_name}.mid`` is compared with the reference file
    ``{test_data_dir}/{song_name}/all_src.mid`` via ``song_level_pitch_wer``.
    Songs whose generated MIDI file is missing are skipped.

    Args:
        test_data_dir: Test-set directory; each song name is expected to be a
            sub-directory holding an ``all_src.mid`` reference.
        out_dir: Directory holding the model's generated ``{song_name}.mid`` files.
        save_fp: Path of the JSON file the results are written to.

    The saved JSON contains:
        'wer': unweighted mean of the per-song WERs, or None when no song
            could be evaluated.
        'all_wers': mapping from song name to its pitch WER.
    '''
    song_names = ls(test_data_dir)
    all_wers = {}
    for song_name in tqdm(song_names):
        out_midi_fp = f'{out_dir}/{song_name}.mid'
        ref_midi_fp = f'{test_data_dir}/{song_name}/all_src.mid'
        if not os.path.exists(out_midi_fp):
            # tqdm.write keeps the progress bar intact, unlike a bare print.
            tqdm.write(f'{out_midi_fp} does not exist. Skip...')
            continue
        all_wers[song_name] = song_level_pitch_wer(out_midi_fp, ref_midi_fp)

    if all_wers:
        wer = sum(all_wers.values()) / len(all_wers)
    else:
        # Guard against ZeroDivisionError when no generated file was found;
        # still save the (empty) results so the run leaves a record.
        print('No songs were evaluated; mean WER is undefined.')
        wer = None

    res = {
        'wer': wer,
        'all_wers': all_wers,
    }
    save_json(res, save_fp)


if __name__ == '__main__':
    main()