import pretty_midi
import numpy as np
from collections import Counter
import sys, os
# Make the directory one level above this script importable.
# NOTE(review): nothing visible in this file imports from that directory — confirm it is still needed.
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
# Name of the dataset subdirectory under ./generated/midis/ to evaluate
# (other dataset names mentioned previously: 'pop909', 'lmd_full').
dataset = 'emoTV'
# Relative to the current working directory, not to this script's location.
files_dir = f'./generated/midis/{dataset}'


def get_midi_files(file_path):
    """Recursively collect paths of ``.mid`` files under *file_path*.

    The directory tree is traversed with ``os.walk``; each returned path is
    the walk's directory joined with the file name. Only names ending in the
    exact, case-sensitive suffix ``.mid`` are kept. A missing directory
    yields an empty list.
    """
    return [
        os.path.join(dirpath, filename)
        for dirpath, _subdirs, filenames in os.walk(file_path)
        for filename in filenames
        if filename.endswith('.mid')
    ]

def compute_dche(midis):
    """Return the mean duration-class histogram entropy over MIDI files.

    For each file, every note of every instrument is assigned a duration
    class: its length in seconds quantized to 10 ms steps (e.g. 0.25 s -> 25).
    The Shannon entropy (base 2) of that file's class histogram is computed,
    and the per-file entropies are averaged with equal weight per file.

    Args:
        midis: Iterable of MIDI file paths accepted by ``pretty_midi.PrettyMIDI``.

    Returns:
        The average entropy. A file with no notes contributes 0 to the average.

    Raises:
        ValueError: If ``midis`` yields no files.
    """
    total_entropy, total_len = 0.0, 0
    for midi_file in midis:
        midi_data = pretty_midi.PrettyMIDI(midi_file)

        # Quantize each note duration to a 10 ms class. round() instead of int():
        # truncation maps float results such as 0.29 * 100 == 28.999... to 28.
        duration_classes = [
            round((note.end - note.start) * 100)
            for instrument in midi_data.instruments
            for note in instrument.notes
        ]

        # Shannon entropy (bits) of the duration-class frequency distribution.
        total_notes = len(duration_classes)
        entropy = 0.0
        for count in Counter(duration_classes).values():
            p = count / total_notes
            entropy -= p * np.log2(p)

        total_entropy += entropy
        total_len += 1

    if total_len == 0:
        raise ValueError("compute_dche() requires at least one MIDI file")
    return total_entropy / total_len


def compute_api(midi_file):
    """Return the Average Pitch Interval (API), averaged over MIDI files.

    For each file, notes from all instruments are merged and ordered by
    onset time (ties broken by pitch). The file's API is the mean absolute
    difference in MIDI pitch number between each pair of consecutive notes.
    Per-file values are averaged with equal weight per file.

    Args:
        midi_file: Iterable of MIDI file paths, or a single path
            (``str`` / ``os.PathLike``), accepted by ``pretty_midi.PrettyMIDI``.

    Returns:
        The average API. A file with fewer than two notes contributes 0.

    Raises:
        ValueError: If no MIDI files are given.
    """
    # The parameter was always meant to be the input (previously the body read
    # a global `midis` instead). Wrap a lone path so we don't iterate characters.
    if isinstance(midi_file, (str, os.PathLike)):
        midi_files = [midi_file]
    else:
        midi_files = midi_file

    total_api, total_len = 0.0, 0
    for path in midi_files:
        midi_data = pretty_midi.PrettyMIDI(path)

        # Merge all instruments into one time-ordered note sequence so that
        # "consecutive" means consecutive in time, not in storage order (which
        # would also pair one track's last note with the next track's first).
        # NOTE(review): drum instruments are not excluded; if their pitch
        # numbers encode percussion sounds rather than pitches, consider
        # filtering them — confirm.
        all_notes = sorted(
            (note for instrument in midi_data.instruments for note in instrument.notes),
            key=lambda note: (note.start, note.pitch),
        )

        pitch_intervals = [
            abs(cur.pitch - prev.pitch) for prev, cur in zip(all_notes, all_notes[1:])
        ]

        avg_pitch_interval = np.mean(pitch_intervals) if pitch_intervals else 0
        total_api += avg_pitch_interval
        total_len += 1

    if total_len == 0:
        raise ValueError("compute_api() requires at least one MIDI file")
    return total_api / total_len

if __name__ == '__main__':
    # Evaluate at most the first 2000 files in os.walk order. `midis` stays a
    # module-level name on purpose: it is kept for code that reads it globally.
    midis = get_midi_files(files_dir)[:2000]
    for metric in (compute_dche, compute_api):
        print(metric(midis))