"""
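Compute the chord shot rate of MIDI files: the fraction of a file's chords
that appear in its reference chord progression.
TODO: add a comparison plot for model outputs generated without any input.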
"""
import os
import sys
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
from preprocess.tokenizer import RemiPlus
import json
import matplotlib.pyplot as plt

"""
    The parameters below need to be adjusted to match the data being evaluated.
"""
# Directory that is walked recursively for the generated MIDI files.
GEN_MIDI_FILES = "./generated/midis/pop909/great"
# Directory that is walked recursively for the original MIDI files.
ORI_MIDI_FILES = "./data/POP909"
# JSON files that receive the per-file chord shot rates of the generated
# and original sets respectively.
GEN_OUT_DICT = './generated/metrics/gen_chord_shot_pop909.json'
ORI_OUT_DICT = './generated/metrics/ori_chord_shot_pop909.json'

# Shared tokenizer instance used to convert MIDI files and extract chords.
tokenizer = RemiPlus()

def get_midi_files(file_path):
    """Recursively collect the paths of all ``.mid`` files under *file_path*.

    Args:
        file_path: Directory to walk.

    Returns:
        A list of paths, each formed by joining the visited directory with
        the file name. Only names ending in the exact suffix ``.mid`` are
        kept; the order is the one produced by ``os.walk``.
    """
    return [
        os.path.join(directory, name)
        for directory, _, names in os.walk(file_path)
        for name in names
        if name.endswith('.mid')
    ]


def get_chord_shot(mids, ori=False):
    """Compute, per MIDI file, the fraction of its chords found in a reference progression.

    Args:
        mids: Iterable of MIDI file paths.
        ori: When true, the reference progression is obtained from the
            file's own chords via ``tokenizer.getChordProgression``.
            Otherwise it is parsed from the file name: the text before the
            first ``_`` is split on ``-``.

    Returns:
        A dict mapping each processed path to ``matches / len(chords)``.
        Files that yield no chords, or for which the progression is
        ``None``, are left out of the result.
    """
    chord_shot = {}
    for mid_file in mids:
        remi_seq = tokenizer.midi2RemiPlus(mid_file)
        chords = tokenizer.getChords(remi_seq.events)
        if len(chords) == 0:
            # Nothing to score, and it would divide by zero below.
            continue
        if ori:
            cp = tokenizer.getChordProgression(chords)
        else:
            # basename() rather than split('/'): paths built with
            # os.path.join use the platform separator, which is not always '/'.
            file_name = os.path.basename(mid_file)
            cp = file_name.split('_')[0].split('-')
        if cp is None:
            continue
        hits = sum(1 for chord in chords if chord in cp)
        chord_shot[mid_file] = hits / len(chords)
    return chord_shot


def fillter_by_epoch(mids, epoch=100):
    """Keep only the MIDI paths whose file name carries the given epoch tag.

    A path is kept when the part of its file name before the first ``-``
    equals ``'ep' + str(epoch)`` (for example ``ep100-xxx.mid`` for
    ``epoch=100``).

    Args:
        mids: Iterable of MIDI file paths.
        epoch: Epoch number to select.

    Returns:
        A list of the matching paths, in their original order.
    """
    wanted = f'ep{epoch}'
    # basename() rather than split('/'): paths built with os.path.join use
    # the platform separator, which is not always '/'.
    return [
        mid for mid in mids
        if os.path.basename(mid).split('-')[0] == wanted
    ]

def get_probability(chord_shot):
    """Bucket chord shot rates into four quarter-width bins and normalise.

    Args:
        chord_shot: Mapping of file path to a rate, expected in ``[0, 1]``.

    Returns:
        A list of four floats: the share of entries whose rate is below
        0.25, in ``[0.25, 0.5)``, in ``[0.5, 0.75)``, and 0.75 or above.
        All four are ``0.0`` when *chord_shot* is empty.
    """
    total = len(chord_shot)
    if total == 0:
        # Avoid dividing by zero when no file could be scored.
        return [0.0, 0.0, 0.0, 0.0]

    counts = [0, 0, 0, 0]
    for rate in chord_shot.values():
        if rate < 0.25:
            counts[0] += 1
        elif rate < 0.5:
            counts[1] += 1
        elif rate < 0.75:
            counts[2] += 1
        else:
            counts[3] += 1
    return [count / total for count in counts]

def _ensure_parent_dir(path):
    """Create the directory that will contain *path* if it is missing."""
    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)


def _draw_distribution(position, probabilities, title):
    """Draw one four-bin bar chart into slot *position* of a 1x2 subplot grid."""
    centers = [0.125, 0.375, 0.625, 0.875]
    plt.subplot(1, 2, position)
    plt.bar(centers, probabilities, width=0.2, align='center',
            color='skyblue', edgecolor='black')
    # Title, axis labels and bin labels.
    plt.title(title)
    plt.xlabel('Probability Range')
    plt.ylabel('Frequency')
    plt.xticks(centers, ['0-0.25', '0.25-0.5', '0.5-0.75', '0.75-1'])


def generate_compare_pic(filter_epoch=None):
    """Plot the chord shot rate distributions of original vs. generated MIDI files.

    Args:
        filter_epoch: When given, only generated files tagged with this
            epoch are evaluated. When ``None``, every generated file is
            evaluated and the per-file rates are also written to
            ``GEN_OUT_DICT`` / ``ORI_OUT_DICT`` as JSON.

    The figure is saved under ``./generated/metrics/`` with the epoch (or
    ``all``) in its file name.
    """
    # 1. Collect the MIDI files.
    generated_mids = get_midi_files(GEN_MIDI_FILES)
    if filter_epoch is not None:
        # Forward the requested epoch; relying on the helper's default would
        # always select epoch 100 regardless of the argument.
        generated_mids = fillter_by_epoch(generated_mids, epoch=filter_epoch)
    else:
        filter_epoch = 'all'
    ori_mids = get_midi_files(ORI_MIDI_FILES)

    # 2. Compute the chord shot rate of every file.
    gen_chord_shot = get_chord_shot(generated_mids)
    ori_chord_shot = get_chord_shot(ori_mids, ori=True)

    # 3. Persist the raw per-file rates (only for the unfiltered run).
    if filter_epoch == 'all':
        for out_path, rates in ((GEN_OUT_DICT, gen_chord_shot),
                                (ORI_OUT_DICT, ori_chord_shot)):
            _ensure_parent_dir(out_path)
            with open(out_path, 'w') as f:
                json.dump(rates, f)

    # 4. Draw the two bar charts side by side.
    gen_p = get_probability(gen_chord_shot)
    ori_p = get_probability(ori_chord_shot)

    fig = plt.figure(figsize=(12, 6))
    _draw_distribution(1, ori_p, 'Original Probability Distribution')
    # Plot the computed distribution rather than fixed placeholder numbers.
    _draw_distribution(2, gen_p, 'Generation Probability Distribution')

    # Overall figure title.
    plt.suptitle('Chord Shot Rate', fontsize=16)

    # Leave room at the top so the overall title does not overlap the subplots.
    plt.tight_layout(rect=[0, 0, 1, 0.96])

    figure_path = f'./generated/metrics/chord_shot_pop909_epoch{filter_epoch}.png'
    _ensure_parent_dir(figure_path)
    plt.savefig(figure_path)
    # Release the figure so repeated calls (one per epoch) do not pile up
    # open figures.
    plt.close(fig)



if __name__ == '__main__':
    # Default run: evaluate every generated file and render a single figure.
    # To render one figure per checkpoint instead, use the loop below:
    # for epoch in range(1, 50):
    #     generate_compare_pic(filter_epoch=epoch * 100)
    generate_compare_pic()


