"""prepare for the pkl data
    Author: Joey.Zhu
    Email: joey8273@qq.com
    Date: 2024/12/19
    
"""
import sys
import os
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import tokenizer
import glob
import pickle
from concurrent.futures import ProcessPoolExecutor, as_completed
from functools import partial
import json


# Module-level RemiPlus tokenizer instance. It is bound into
# process_midi_file with functools.partial further down and handed to the
# process pool together with every submitted MIDI path.
remi_tokenizer = tokenizer.RemiPlus()


# ---- Paths: adjust these to match your own setup ----
# Root directory of the input MIDI files; it is searched recursively.
MIDI_FILES = "./data/midis/lmd_matched"

# Disabled setting kept for reference (directory for per-file token pickles):
# MIDI_FILES_PKLs_DIR = "./generated/pop909_midi_pkls"

# Destination pickles for the training and validation splits.
OUTPUT_TRAINING_DATA_PKL = "./generated/data_pkl/train_lmd_matched_1024.pkl"
OUTPUT_VALIDATE_DATA_PKL = "./generated/data_pkl/val_lmd_matched_1024.pkl"

# When true, per-song features are attached to every tokenized record.
useLabel = True

# Every file below MIDI_FILES whose name ends in ".mid" (case-sensitive).
all_mids = [
    os.path.join(folder, name)
    for folder, _, names in os.walk(MIDI_FILES)
    for name in names
    if name.endswith('.mid')
]

def read_json_keys(file_path):
    """Parse the UTF-8 JSON file at ``file_path`` and return its content.

    Despite the name, the whole parsed document is returned (not only its
    keys), whatever its top-level JSON type is.
    """
    with open(file_path, mode='r', encoding='utf-8') as handle:
        return json.load(handle)

# Feature labels loaded once at import time. process_midi_file looks an entry
# up by the MIDI file's base name (the text before its first "."), and reads
# the arousal, valence, danceability, energy, instrumentalness and liveness
# values from that entry.
label_data = read_json_keys("./data/features.json")


def process_midi_file(midi_file, tokenizer):
    """Tokenize one MIDI file and write the result as a pickle next to it.

    Args:
        midi_file: Path of the MIDI file to convert.
        tokenizer: Object providing ``midi2RemiPlus(path)`` and
            ``preprocessRemiPlus(seq)``.

    Returns:
        Path of the written ``.pkl`` file (the input path with its extension
        replaced), or ``None`` when either tokenizer step returns ``None``;
        nothing is written in that case.

    Raises:
        KeyError: If ``useLabel`` is set and ``label_data`` has no entry for
            this file, or the entry lacks one of the copied features.
    """
    remi_seq = tokenizer.midi2RemiPlus(midi_file)
    if remi_seq is None:
        return None

    ret = tokenizer.preprocessRemiPlus(remi_seq)
    # Bail out before touching ``ret``: the label assignments below would
    # raise TypeError on a None result instead of skipping the file.
    if ret is None:
        return None

    if useLabel:
        # Labels are keyed by the file name up to its first ".".
        song_id = os.path.basename(midi_file).split(".")[0]
        features = label_data[song_id]
        for key in ("arousal", "valence", "danceability", "energy",
                    "instrumentalness", "liveness"):
            ret[key] = features[key]

    # Swap only the extension; str.replace would also rewrite any ".mid"
    # occurring earlier in the path (e.g. inside a directory name).
    output_pkl_fp = os.path.splitext(midi_file)[0] + ".pkl"
    with open(output_pkl_fp, 'wb') as f:
        pickle.dump(ret, f, protocol=pickle.HIGHEST_PROTOCOL)
    return output_pkl_fp

def build_samples(record, with_labels=True):
    """Expand one tokenized record into a list of per-segment samples.

    Args:
        record: Mapping loaded from a per-file pickle. It must provide
            ``src``, ``tgt_segments`` and ``tgt_segments_chord_binary_msk``;
            when ``with_labels`` is true it must also provide the six
            feature labels listed in the body.
        with_labels: Copy the feature labels into every sample. Pass False
            for records that were tokenized without labels.

    Returns:
        One dict per entry of ``record["tgt_segments"]`` holding ``src``,
        ``tgt``, ``tgt_chord_msk`` and, if requested, the feature labels.
    """
    src = record["src"]
    segments = record["tgt_segments"]
    if not segments:
        return []

    label_keys = ("arousal", "valence", "danceability", "energy",
                  "instrumentalness", "liveness")
    labels = {key: record[key] for key in label_keys} if with_labels else {}
    masks = record["tgt_segments_chord_binary_msk"]
    return [
        {"src": src, "tgt": tgt, "tgt_chord_msk": masks[idx], **labels}
        for idx, tgt in enumerate(segments)
    ]


def dump_pickle(obj, path):
    """Pickle ``obj`` to ``path``, creating the parent directory if needed."""
    out_dir = os.path.dirname(path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    # The context manager guarantees the file is flushed and closed.
    with open(path, 'wb') as f:
        pickle.dump(obj, f, protocol=pickle.HIGHEST_PROTOCOL)


# Process start methods that re-import the main module (spawn, forkserver)
# would otherwise re-run this whole pipeline inside every worker process.
if __name__ == "__main__":
    processed_files = []
    with ProcessPoolExecutor() as executor:
        # Bind the tokenizer with partial so each task only takes a MIDI path.
        process_fn = partial(process_midi_file, tokenizer=remi_tokenizer)
        future_to_file = {
            executor.submit(process_fn, midi_file): midi_file
            for midi_file in all_mids
        }

        for future in as_completed(future_to_file):
            midi_file = future_to_file[future]
            try:
                result = future.result()
            except Exception as e:
                # Best effort: one broken file must not abort the whole run.
                print(f"Error processing {midi_file}: {e}")
            else:
                if result:
                    processed_files.append(result)
                    print(f"Processed: {midi_file}")
                else:
                    print(f"Skipped: {midi_file}")

    print(f"Finished processing {len(processed_files)} MIDI files.")

    # Collect all .pkl files and generate the .pkl files for training/testing.
    # This picks up every pickle below MIDI_FILES, including ones left over
    # from earlier runs.
    al_pkls = sorted(glob.glob(os.path.join(MIDI_FILES, '**', '*.pkl'), recursive=True))

    total_data = []
    for i, fn in enumerate(al_pkls):
        print(f">>[{i + 1}/{len(al_pkls)}][train] Now processing {os.path.basename(fn)}")
        # NOTE: unpickling can execute arbitrary code; only run this on a
        # dataset directory whose contents you trust.
        with open(fn, "rb") as f:
            data = pickle.load(f)
        total_data.extend(build_samples(data, with_labels=useLabel))

    # First 90% of the samples go to training, the remainder to validation.
    split = int(len(total_data) * 0.9)
    dump_pickle(total_data[:split], OUTPUT_TRAINING_DATA_PKL)
    dump_pickle(total_data[split:], OUTPUT_VALIDATE_DATA_PKL)

    print(f"total data length:{len(total_data)}, training data length: {len(total_data[:split])}, validate data length:{len(total_data[split:])}")






