"""
"""
import os
import sys
import pretty_midi
import numpy as np
from collections import Counter
import json
import muspy
import math
# Append this script's parent directory to the import search path.
# NOTE(review): presumably so project-level modules can be imported when the
# script is run directly; nothing in this file imports one - confirm.
_script_dir = os.path.dirname(__file__)
_parent_dir = os.path.abspath(os.path.join(_script_dir, '..'))
sys.path.append(_parent_dir)

# Dataset to evaluate. It selects the input folder below and the per-dataset
# output folder used by the metric functions' default report paths.
# The original note lists pop909 and lmd_full as other choices.
dataset = 'lmd_matched'
# Folder that is scanned recursively for the dataset's MIDI files.
files_dir = f'./data/midis/{dataset}'


def get_midi_files(file_path, extensions=('.mid', '.midi')):
    """Recursively collect the paths of MIDI files under ``file_path``.

    Args:
        file_path: Root directory to walk.
        extensions: File-name suffixes treated as MIDI files. They are
            compared case-insensitively, so ``song.MID`` matches ``'.mid'``.

    Returns:
        A list of paths, each built as ``os.path.join(root, name)``, in the
        order ``os.walk`` yields them. The list is empty when nothing matches
        or when ``file_path`` does not exist.
    """
    # str.endswith accepts a tuple, so one call covers every suffix.
    suffixes = tuple(ext.lower() for ext in extensions)
    mids = []
    for root, _, names in os.walk(file_path):
        mids.extend(
            os.path.join(root, name)
            for name in names
            if name.lower().endswith(suffixes)
        )
    return mids

'''
Just ignore this: `handscript` is not called from the `__main__` block.
'''
def handscript():
    """Hand-rolled pitch class entropy over every MIDI file under ``files_dir``.

    For each file, the pitch class (``pitch % 12``) of every note in every
    instrument is counted, the counts are turned into relative frequencies,
    and their base-2 Shannon entropy is computed. The report written to
    ``./generated/metrics/lmd/pce.json`` maps each file to its
    ``'pitch_class_probabilities'`` and ``'entropy'``, and stores the mean
    entropy under the top-level key ``"entropy"``.

    Files that pretty_midi cannot parse are reported and left out of the
    mean instead of aborting the run. A file without any note yields an
    entropy of 0 and still counts towards the mean. When no file could be
    processed the mean is written as ``null`` rather than raising
    ZeroDivisionError.
    """
    finnal_obj = {}
    total_pce = 0
    processed = 0
    for midi_file in get_midi_files(files_dir):
        try:
            midi_data = pretty_midi.PrettyMIDI(midi_file)
        except Exception as e:
            # Same policy as the muspy-based metrics in this file: report the
            # unreadable file and move on.
            print(f"Error reading {midi_file}: {e}")
            continue

        # Count the pitch class (pitch modulo 12) of every note.
        pitch_class_counts = Counter(
            note.pitch % 12
            for instrument in midi_data.instruments
            for note in instrument.notes
        )

        # Relative frequency of each pitch class that occurs in the file.
        total_notes = sum(pitch_class_counts.values())
        pitch_class_probabilities = {
            k: v / total_notes for k, v in pitch_class_counts.items()
        }

        # Shannon entropy in bits; every stored probability is > 0 because it
        # comes from a positive count.
        entropy = -sum(p * np.log2(p) for p in pitch_class_probabilities.values())
        finnal_obj[midi_file] = {
            'pitch_class_probabilities': pitch_class_probabilities,
            'entropy': entropy
        }
        print(f'Pitch Class Entropy: {entropy}')
        total_pce += entropy
        processed += 1

    # Average over the files that were actually processed.
    finnal_pce = total_pce / processed if processed else None
    finnal_obj["entropy"] = finnal_pce
    print(f"avg pce is:{finnal_pce}")

    # NOTE(review): the folder name 'lmd' is hard-coded here and does not
    # follow the module-level `dataset` setting - confirm this is intended.
    output_path = './generated/metrics/lmd/pce.json'
    os.makedirs(os.path.dirname(output_path), exist_ok=True)
    with open(output_path, 'w') as json_file:
        json.dump(finnal_obj, json_file, indent=4)


def picth_class_entropy(midis, generated_file=f'./generated/metrics/{dataset}/pce.json'):
    """Compute MusPy's pitch class entropy per MIDI file and save a JSON report.

    Args:
        midis: Iterable of MIDI file paths.
        generated_file: Destination JSON path. Its parent directory is
            created when missing.

    The report maps every measured file to ``{'entropy': value}`` and adds
    three summary keys: ``'entropy'`` (mean), ``'low_pce'`` (minimum) and
    ``'high_pce'`` (maximum). Files that cannot be read, or whose metric is
    NaN, are skipped. When no file yields a value, the three summary keys
    are written as ``null`` instead of raising ZeroDivisionError.
    """
    finnal_obj = {}
    total, count = 0, 0
    low_pce, high_pce = math.inf, -math.inf
    for midi_file in midis:
        try:
            music = muspy.read(midi_file)
        except Exception as e:
            # Unreadable files are reported and left out of the statistics.
            print(f"Error reading {midi_file}: {e}")
            continue
        pce = muspy.pitch_class_entropy(music)
        if math.isnan(pce):
            # A NaN result cannot be averaged, so the file is skipped.
            continue
        finnal_obj[midi_file] = {
            'entropy': pce
        }
        total += pce
        count += 1
        low_pce = min(low_pce, pce)
        high_pce = max(high_pce, pce)

    if count:
        avg_pce = total / count
    else:
        # Nothing was measured: store nulls rather than the inf sentinels.
        avg_pce = low_pce = high_pce = None

    finnal_obj['entropy'] = avg_pce
    finnal_obj['low_pce'] = low_pce
    finnal_obj['high_pce'] = high_pce

    # Make sure the report folder exists so the run does not fail at the end.
    out_dir = os.path.dirname(generated_file)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    with open(generated_file, 'w') as json_file:
        json.dump(finnal_obj, json_file, indent=4)
    print(f"metric [pitch class entropy] finished!, {dataset}'s avg entropy is {avg_pce}")


"""
Scale consistency ratio
"""
def scale_consistency(midis, generated_file=f'./generated/metrics/{dataset}/scale_consistency.json'):
    """Compute MusPy's scale consistency per MIDI file and save a JSON report.

    Args:
        midis: Iterable of MIDI file paths.
        generated_file: Destination JSON path. Its parent directory is
            created when missing.

    The report maps every measured file to ``{'scale_consistency': value}``
    and adds ``'avg_scale_consistency'``, ``'low_scale_consistency'`` and
    ``'high_scale_consistency'``. Files that cannot be read, or whose metric
    is NaN, are skipped. When no file yields a value, the three summary keys
    are written as ``null`` instead of raising ZeroDivisionError.
    """
    finnal_obj = {}
    total, count = 0, 0
    low_scale_consistency, high_scale_consistency = math.inf, -math.inf
    for midi_file in midis:
        try:
            music = muspy.read(midi_file)
        except Exception as e:
            # Unreadable files are reported and left out of the statistics.
            print(f"Error reading {midi_file}: {e}")
            continue
        # Named `value` so the local does not shadow this function's name.
        value = muspy.scale_consistency(music)
        if math.isnan(value):
            # A NaN result cannot be averaged, so the file is skipped.
            continue
        finnal_obj[midi_file] = {
            'scale_consistency': value
        }
        total += value
        count += 1
        low_scale_consistency = min(low_scale_consistency, value)
        high_scale_consistency = max(high_scale_consistency, value)

    if count:
        avg_scale_consistency = total / count
    else:
        # Nothing was measured: store nulls rather than the inf sentinels.
        avg_scale_consistency = low_scale_consistency = high_scale_consistency = None

    finnal_obj['avg_scale_consistency'] = avg_scale_consistency
    finnal_obj['low_scale_consistency'] = low_scale_consistency
    finnal_obj['high_scale_consistency'] = high_scale_consistency

    # Make sure the report folder exists so the run does not fail at the end.
    out_dir = os.path.dirname(generated_file)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    with open(generated_file, 'w') as json_file:
        json.dump(finnal_obj, json_file, indent=4)
    print(f"metric [scale_consistency] finished!, {dataset}'s avg scale consistency is {avg_scale_consistency}")


"""
muspy.empty_beat_rate
"""


def empty_beat_rate(midis, generated_file=f'./generated/metrics/{dataset}/empty_beat_rate.json'):
    """Compute MusPy's empty beat rate per MIDI file and save a JSON report.

    Args:
        midis: Iterable of MIDI file paths.
        generated_file: Destination JSON path. Its parent directory is
            created when missing.

    The report maps every measured file to ``{'empty_beat_rate': value}``
    and adds ``'avg_empty_beat_rate'``, ``'low_empty_beat_rate'`` and
    ``'high_empty_beat_rate'``. Files that cannot be read, that contain no
    tracks, or whose metric is NaN are skipped. When no file yields a value,
    the three summary keys are written as ``null`` instead of raising
    ZeroDivisionError.
    """
    finnal_obj = {}
    total, count = 0, 0
    low_empty_beat_rate, high_empty_beat_rate = math.inf, -math.inf
    for midi_file in midis:
        try:
            music = muspy.read(midi_file)
        except Exception as e:
            # Unreadable files are reported and left out of the statistics.
            print(f"Error reading {midi_file}: {e}")
            continue
        if len(music.tracks) == 0:
            # Files without tracks are skipped before the metric is computed.
            continue
        # Named `value` so the local does not shadow this function's name.
        value = muspy.empty_beat_rate(music)
        if math.isnan(value):
            # A NaN result cannot be averaged, so the file is skipped.
            continue
        finnal_obj[midi_file] = {
            'empty_beat_rate': value
        }
        total += value
        count += 1
        low_empty_beat_rate = min(low_empty_beat_rate, value)
        high_empty_beat_rate = max(high_empty_beat_rate, value)

    if count:
        avg_empty_beat_rate = total / count
    else:
        # Nothing was measured: store nulls rather than the inf sentinels.
        avg_empty_beat_rate = low_empty_beat_rate = high_empty_beat_rate = None

    finnal_obj['avg_empty_beat_rate'] = avg_empty_beat_rate
    finnal_obj['low_empty_beat_rate'] = low_empty_beat_rate
    finnal_obj['high_empty_beat_rate'] = high_empty_beat_rate

    # Make sure the report folder exists so the run does not fail at the end.
    out_dir = os.path.dirname(generated_file)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    with open(generated_file, 'w') as json_file:
        json.dump(finnal_obj, json_file, indent=4)
    print(f"metric [empty_beat_rate] finished!, {dataset}'s avg empty beat rate is {avg_empty_beat_rate}")


def groove_consistency(midis, generated_file=f'./generated/metrics/{dataset}/groove_consistency.json',
                       measure_resolution=16):
    """Compute MusPy's groove consistency per MIDI file and save a JSON report.

    Args:
        midis: Iterable of MIDI file paths.
        generated_file: Destination JSON path. Its parent directory is
            created when missing.
        measure_resolution: Value passed as the second argument of
            ``muspy.groove_consistency``. Defaults to 16, the value that
            was previously hard-coded.
            NOTE(review): MusPy documents this as time steps per measure;
            confirm 16 matches the resolution of the files being read.

    The report maps every measured file to ``{'groove_consistency': value}``
    and adds ``'avg_groove_consistency'``, ``'low_groove_consistency'`` and
    ``'high_groove_consistency'``. Files that cannot be read, that contain
    no tracks, or whose metric is NaN are skipped. When no file yields a
    value, the three summary keys are written as ``null`` instead of raising
    ZeroDivisionError.
    """
    finnal_obj = {}
    total, count = 0, 0
    low_groove_consistency, high_groove_consistency = math.inf, -math.inf
    for midi_file in midis:
        try:
            music = muspy.read(midi_file)
        except Exception as e:
            # Unreadable files are reported and left out of the statistics.
            print(f"Error reading {midi_file}: {e}")
            continue
        if len(music.tracks) == 0:
            # Files without tracks are skipped before the metric is computed.
            continue
        # Named `value` so the local does not shadow this function's name.
        value = muspy.groove_consistency(music, measure_resolution)
        if math.isnan(value):
            # A NaN result cannot be averaged, so the file is skipped.
            continue
        finnal_obj[midi_file] = {
            'groove_consistency': value
        }
        total += value
        count += 1
        low_groove_consistency = min(low_groove_consistency, value)
        high_groove_consistency = max(high_groove_consistency, value)

    if count:
        avg_groove_consistency = total / count
    else:
        # Nothing was measured: store nulls rather than the inf sentinels.
        avg_groove_consistency = low_groove_consistency = high_groove_consistency = None

    finnal_obj['avg_groove_consistency'] = avg_groove_consistency
    finnal_obj['low_groove_consistency'] = low_groove_consistency
    # Store the tracked maximum, not the value of the last file processed.
    finnal_obj['high_groove_consistency'] = high_groove_consistency

    # Make sure the report folder exists so the run does not fail at the end.
    out_dir = os.path.dirname(generated_file)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    with open(generated_file, 'w') as json_file:
        json.dump(finnal_obj, json_file, indent=4)
    print(f"metric [groove_consistency] finished!, {dataset}'s avg groove consistency is {avg_groove_consistency}")


"""
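Polyphony is an important measure of how many notes are played at the same time in a piece of music, and is commonly used to describe its harmonic complexity and structure. In music analysis and generation, polyphony is an important parameter for understanding and evaluating musical texture, rhythm and harmony.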
"""
def polyphony(midis, generated_file=f'./generated/metrics/{dataset}/polyphony.json'):
    """Compute MusPy's polyphony metric per MIDI file and save a JSON report.

    Args:
        midis: Iterable of MIDI file paths.
        generated_file: Destination JSON path. Its parent directory is
            created when missing.

    The report maps every measured file to ``{'polyphony': value}`` and adds
    ``'avg_polyphony'``, ``'low_polyphony'`` and ``'high_polyphony'``. Files
    that cannot be read, that contain no tracks, or whose metric is NaN are
    skipped. When no file yields a value, the three summary keys are written
    as ``null`` instead of raising ZeroDivisionError.
    """
    finnal_obj = {}
    total, count = 0, 0
    # Start the minimum at +inf and the maximum at -inf so the first measured
    # value replaces both. Starting them at 0 and 100 pinned the minimum at 0
    # and the maximum at 100 for any value in between.
    low_polyphony, high_polyphony = math.inf, -math.inf
    for midi_file in midis:
        try:
            music = muspy.read(midi_file)
        except Exception as e:
            # Unreadable files are reported and left out of the statistics.
            print(f"Error reading {midi_file}: {e}")
            continue
        if len(music.tracks) == 0:
            # Files without tracks are skipped before the metric is computed.
            continue
        # Named `value` so the local does not shadow this function's name.
        value = muspy.polyphony(music)
        if math.isnan(value):
            # A NaN result cannot be averaged, so the file is skipped.
            continue
        finnal_obj[midi_file] = {
            'polyphony': value
        }
        total += value
        count += 1
        low_polyphony = min(low_polyphony, value)
        high_polyphony = max(high_polyphony, value)

    if count:
        avg_polyphony = total / count
    else:
        # Nothing was measured: store nulls rather than the inf sentinels.
        avg_polyphony = low_polyphony = high_polyphony = None

    finnal_obj['avg_polyphony'] = avg_polyphony
    finnal_obj['low_polyphony'] = low_polyphony
    finnal_obj['high_polyphony'] = high_polyphony

    # Make sure the report folder exists so the run does not fail at the end.
    out_dir = os.path.dirname(generated_file)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    with open(generated_file, 'w') as json_file:
        json.dump(finnal_obj, json_file, indent=4)
    print(f"metric [polyphony] finished!, {dataset}'s avg polyphony is {avg_polyphony}")

if __name__ == '__main__':
    # Scan the dataset folder once, then run every metric on the same file
    # list, in the same order as before.
    midis = get_midi_files(files_dir)
    metric_functions = (
        picth_class_entropy,
        scale_consistency,
        empty_beat_rate,
        groove_consistency,
        polyphony,
    )
    for run_metric in metric_functions:
        run_metric(midis)
    
        
    