
"""
this is the metrics of chord progression accuracy
only need the conditional generated midi
"""
import os
import sys
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
import os
from preprocess.tokenizer import RemiPlus
# Module-level RemiPlus tokenizer instance, shared by get_accuracy() below.
tokenizer = RemiPlus()

# Directory that is walked recursively for generated ``.mid`` files when this
# module is run as a script.
GEN_MIDI_FILES = "./generated/midis/pop909"

def get_midi_files(file_path):
    """Recursively collect the paths of all ``.mid`` files below *file_path*.

    Args:
        file_path: Directory to traverse with ``os.walk``.

    Returns:
        A list of paths, each formed by joining the containing directory with
        the file name, in the order ``os.walk`` yields them. The list is empty
        when no file name ends with ``.mid`` (the check is case-sensitive).
    """
    return [
        os.path.join(folder, name)
        for folder, _, names in os.walk(file_path)
        for name in names
        if name.endswith('.mid')
    ]


def get_accuracy(mids):
    """Return the fraction of MIDI files whose chords match their file name.

    For each file, the chord progression obtained from the module-level
    ``tokenizer`` is compared, as a set, with the labels encoded in the file
    name. The file name is split on ``'-'`` and the first and last fields are
    dropped, so names are expected to look like
    ``<prefix>-<chord>-...-<chord>-<suffix>``. A file counts as a hit when the
    size of the intersection divided by the size of the larger set exceeds 0.5;
    the path of every hit is printed.

    Files for which ``tokenizer.getChordProgression`` returns ``None`` are
    skipped and do not count towards the denominator.

    Args:
        mids: Iterable of MIDI file paths.

    Returns:
        ``hits / evaluated`` as a float, or ``0.0`` when no file was evaluated
        (empty input, or every file was skipped).
    """
    hits, evaluated = 0, 0
    for mid_file in mids:
        remi_seq = tokenizer.midi2RemiPlus(mid_file)
        chords = tokenizer.getChords(remi_seq.events)
        progression = tokenizer.getChordProgression(chords)
        if progression is None:
            continue
        # NOTE(review): presumably an iterable of chord-name strings comparable
        # with the labels in the file name -- confirm against the tokenizer.
        generated = set(progression)
        # TODO: optimize
        # basename() instead of split('/') so paths built with os.path.join
        # are handled on every platform.
        expected = set(os.path.basename(mid_file).split('-')[1:-1])
        overlap = generated & expected
        larger = max(len(generated), len(expected))
        # Both sets empty: nothing matched, so score the file as a miss rather
        # than dividing by zero.
        ratio = len(overlap) / larger if larger else 0.0
        if ratio > 0.5:
            print(mid_file)
            hits += 1
        evaluated += 1
    # Guard against an empty input or all files having been skipped.
    return hits / evaluated if evaluated else 0.0

if __name__ == '__main__':
    # Collect the generated MIDI files, score them, and report the accuracy.
    midi_paths = get_midi_files(GEN_MIDI_FILES)
    print(get_accuracy(midi_paths))

