import json
import os

from miditoolkit import MidiFile
import pandas as pd
from tqdm import tqdm

from src.utils.midi import align_score_and_performance, ids_to_midi
from src.model.piano_bert import PianoBertConfig

if __name__ == "__main__":
    # Build the SFT dataset from the ASAP metadata table: for every row, align
    # the score MIDI with its performance MIDI, write both token sequences back
    # out as MIDI files, and append the pair as one JSON line to the dataset.
    config = PianoBertConfig()
    data_path = "data/midi/asap-dataset-master/"
    scores_dir = "data/midi/sft_test/scores"
    labels_dir = "data/midi/sft_test/labels"
    jsonl_path = "data/processed/sft.jsonl"
    # First metadata row to process; rows before it are skipped.
    # NOTE(review): looks like a resume offset from an earlier run — confirm
    # before reusing this script on a fresh output directory.
    start_index = 925

    metadata = pd.read_csv(os.path.join(data_path, "metadata.csv"))

    # open(..., "a") does not create missing parent directories (and the MIDI
    # dump presumably does not either), so make sure they exist up front
    # instead of failing on the first sample.
    for out_dir in (scores_dir, labels_dir, os.path.dirname(jsonl_path)):
        os.makedirs(out_dir, exist_ok=True)

    for i in tqdm(range(start_index, len(metadata))):
        score_midi_obj = MidiFile(os.path.join(data_path, metadata["midi_score"][i]))
        performance_midi_obj = MidiFile(os.path.join(data_path, metadata["midi_performance"][i]))
        try:
            x, label = align_score_and_performance(config, score_midi_obj, performance_midi_obj)
        except Exception as e:
            # Best-effort: report the failure and move on to the next pair.
            print(e)
            continue

        ids_to_midi(config, x).dump(os.path.join(scores_dir, f"{i}.mid"))
        ids_to_midi(config, label).dump(os.path.join(labels_dir, f"{i}.mid"))

        # Append one record per sample right away so an interrupted run keeps
        # everything processed so far; nothing is accumulated in memory.
        with open(jsonl_path, "a", encoding="utf-8") as f:
            f.write(json.dumps({"x": x, "label": label}) + "\n")
        print(f"Sample {i} successfully write!")
