import json
import os
import random

from miditoolkit import MidiFile
import pandas as pd
from tqdm import tqdm

from src.utils.midi import midi_to_ids
from src.model.piano_bert import PianoBertConfig

def _find_midi_files(root):
    """Return the paths of every ``.mid`` file found anywhere under *root*."""
    return [
        os.path.join(dir_path, name)
        for dir_path, _, names in os.walk(root)
        if names
        for name in names
        if name.endswith(".mid")
    ]


def _write_shard(lines, out_dir, index):
    """Write pre-serialised JSONL *lines* to ``<out_dir>/<index>.jsonl``.

    Args:
        lines: Iterable of strings, each already terminated by a newline.
        out_dir: Existing directory that receives the shard.
        index: Shard number used as the file stem.

    Returns:
        The path of the file that was written.
    """
    path = os.path.join(out_dir, f"{index}.jsonl")
    with open(path, "w", encoding="utf-8") as f:
        f.writelines(lines)
    return path


def main(
    midi_root="data/midi/giant-piano/midis",
    out_dir="data/processed/pretrain_post",
    shard_size=1000,
):
    """Convert a directory tree of MIDI files into shuffled JSONL shards.

    Every ``.mid`` file under *midi_root* is loaded with ``MidiFile`` and
    passed to ``midi_to_ids``; the result is stored as one JSON object per
    line with the keys ``input_ids`` and ``source`` (the file path). Shards
    are numbered from 0 and hold *shard_size* records each, except possibly
    the last one.

    Args:
        midi_root: Directory searched recursively for ``.mid`` files.
        out_dir: Directory receiving ``<n>.jsonl`` shards; created if absent.
        shard_size: Number of records per shard.

    Raises:
        OSError: If the output directory or a shard cannot be written.
            Unlike per-file conversion errors, these are not suppressed,
            because continuing would silently lose data.
    """
    files = _find_midi_files(midi_root)
    random.shuffle(files)

    # Fail early (before hours of conversion) if the destination is unusable.
    os.makedirs(out_dir, exist_ok=True)

    config = PianoBertConfig()
    lines = []
    shard_index = 0
    skipped = 0
    for file_name in tqdm(files):
        # Only the per-file work is guarded: one unreadable or unconvertible
        # file must not stop the run, but write errors must not be hidden.
        try:
            ids = midi_to_ids(config, MidiFile(file_name))
            line = json.dumps({"input_ids": ids, "source": file_name}) + "\n"
        except Exception as exc:
            # ``Exception`` (not a bare ``except``) so Ctrl-C still aborts.
            skipped += 1
            tqdm.write(f"Skipping {file_name}: {exc!r}")
            continue

        lines.append(line)
        if len(lines) >= shard_size:
            _write_shard(lines, out_dir, shard_index)
            shard_index += 1
            lines = []

    # Flush the remainder, but never emit an empty trailing shard.
    if lines:
        _write_shard(lines, out_dir, shard_index)

    if skipped:
        tqdm.write(f"Skipped {skipped} of {len(files)} files due to errors.")


if __name__ == "__main__":
    main()
    