'''
Calculate bar-level drum note F1 for CA v2's MIDI output against the Slakh2100 test references.
'''

import os
import sys
dirof = os.path.dirname
sys.path.insert(0, dirof(dirof(dirof(dirof(os.path.abspath(__file__))))))

from utils_common.utils import *
from [anonymous] import MultiTrack, Bar
from m2m.evaluate import Metric
from tqdm import tqdm


def main():
    '''
    Script entry point: score the CA v2 outputs.

    Run check_lost_files() separately beforehand to list references whose
    output file is missing.
    '''
    calculate_metrics()


def procedures():
    '''Run the whole pipeline: report missing outputs, then compute the metrics.'''
    pipeline = (check_lost_files, calculate_metrics)
    for step in pipeline:
        step()


def check_lost_files(
    in_dir='/data2/[anonymous]/Datasets/slakh2100_flac_redux/test_normalized',
    out_dir='/data2/[anonymous]/[anonymous]_data/results_cav2',
):
    '''
    Print the reference files in `in_dir` that have no model output in `out_dir`.

    Each reference name is mapped to its expected output name with the same
    rule calculate_metrics() uses ('Track01234.mid' -> 'midi_export1234.mid').
    The expected name is then looked up among the files in `out_dir`. Unrelated
    files in `out_dir` are simply ignored. Parsing them back into reference
    names, as before, could raise IndexError on names without a 't'.

    Args:
        in_dir: directory of reference MIDI files.
        out_dir: directory of model outputs.

    Returns:
        List of reference file names whose output is missing (also printed).
    '''
    existing_outputs = set(ls(out_dir))

    lost = []
    for fn in ls(in_dir):
        # Same naming scheme as calculate_metrics: drop the extension, the
        # 'Track' prefix and the first digit after it.
        track_name = fn.split('.')[0].split('k')[1][1:]
        if f'midi_export{track_name}.mid' not in existing_outputs:
            print(fn)
            lost.append(fn)
    return lost

def calculate_metrics(
    in_dir='/data2/[anonymous]/Datasets/slakh2100_flac_redux/test_normalized',
    out_dir='/data2/[anonymous]/[anonymous]_data/results_cav2',
):
    '''
    Compute the average bar-level drum note F1 of the outputs against the references.

    For every reference MIDI file in `in_dir`, the matching output
    'midi_export{track_name}.mid' is loaded from `out_dir`. Bars are paired in
    order and scored with Metric.calculate_note_f1_q16. A running average is
    printed after each song, and the overall average at the end.

    Songs whose output file does not exist are skipped and reported, instead
    of aborting the whole evaluation.

    Args:
        in_dir: directory of reference MIDI files.
        out_dir: directory of model outputs.

    Returns:
        Mean note F1 over all scored bars, or nan if no bar was scored.
    '''
    def _average(values):
        # np.mean([]) warns and returns nan; return nan quietly instead.
        return float(np.mean(values)) if values else float('nan')

    song_fns = ls(in_dir)
    metric = Metric()

    note_f1s = []
    missing = []
    for song_fn in tqdm(song_fns):
        # e.g. 'Track01234.mid' -> '1234': drop the extension, the 'Track'
        # prefix and the first digit after it.
        track_name = song_fn.split('.')[0].split('k')[1][1:]
        song_fp_in = jpath(in_dir, song_fn)
        song_fp_out = jpath(out_dir, f'midi_export{track_name}.mid')

        if not os.path.exists(song_fp_out):
            missing.append(song_fn)
            continue

        mt_ref = MultiTrack.from_midi(song_fp_in)
        mt_out = MultiTrack.from_midi(song_fp_out)

        # NOTE: zip() stops at the shorter track; extra bars are not scored.
        for bar_ref, bar_out in zip(mt_ref, mt_out):
            ref_seq = bar_ref.get_drum_content_seq()
            out_seq = bar_out.get_drum_content_seq()

            # Skip bars whose reference sequence is a single token
            # (presumably a bar with no drum content) -- TODO confirm.
            if len(ref_seq) == 1:
                continue

            note_f1s.append(metric.calculate_note_f1_q16(
                out_seq=out_seq,
                tgt_seq=ref_seq,
            ))

        # tqdm.write keeps the progress bar intact, unlike print().
        tqdm.write(f'Average note F1: {_average(note_f1s)}')

    if missing:
        print(f'Skipped {len(missing)} song(s) with no output file: {missing}')

    avg_f1 = _average(note_f1s)
    print(f'Average note F1: {avg_f1}')
    return avg_f1

# Run the evaluation only when executed as a script, not when imported.
if __name__ == '__main__':
    main()