import os
from pathlib import Path
from collections import defaultdict
# from miditoolkit import MidiFile # 不再需要，除非你的normalize_midi函数依赖它
import numpy as np
import pandas as pd
from tqdm import tqdm
import partitura as pt
# from src.utils.midi import normalize_midi # 假设这一步已在预处理中完成
import parangonar as pa
from copy import deepcopy
import matplotlib.pyplot as plt
import seaborn as sns
import json
from scipy.stats import pearsonr, spearmanr
from scipy.spatial.distance import jensenshannon
import miditoolkit
from src.utils.midi import midi_to_ids
from src.model.pianoformer import PianoT5GemmaConfig

# --- 文件路径管理函数 (无需改动) ---
def get_evaluate_list(gt_path, pred_path, not_in=None):
    """
    Build a list of Ground Truth / Prediction file-path pairs.

    Ground-truth files are expected to be named ``{prefix}-{number}.mid`` inside
    ``gt_path``; the matching prediction is ``{prefix}.mid`` inside ``pred_path``.

    Args:
        gt_path: Directory containing the ground-truth ``*.mid`` files.
        pred_path: Directory containing the predicted ``*.mid`` files.
        not_in: Optional collection of prefixes to exclude.

    Returns:
        list[dict]: One ``{"gt": str, "pred": str}`` dict per ground-truth file,
        in sorted path order. Files whose stem contains no ``-`` are skipped
        with a warning.
    """
    # None instead of a mutable [] default, so the default is never shared.
    excluded = not_in if not_in is not None else ()
    gt_path = Path(gt_path)
    pred_path = Path(pred_path)
    out = []
    for gt_file in sorted(gt_path.glob("*.mid")):
        # Split the file *stem* on its LAST hyphen only: prefixes that contain
        # hyphens ("a-b-3.mid" -> "a-b", "3") and hyphenated directory names
        # no longer corrupt the parse, and it is independent of the OS separator.
        prefix, sep, number = gt_file.stem.rpartition("-")
        if not sep:
            print(f"Warning: skipping {gt_file}, expected '<prefix>-<number>.mid'.")
            continue
        if prefix in excluded:
            continue
        out.append({
            "gt": os.path.join(gt_path, f"{prefix}-{number}.mid"),
            "pred": os.path.join(pred_path, f"{prefix}.mid"),
        })
    return out

# --- 核心重构部分：对齐与缓存 ---
def align_mid_pt(evaluate_list, cache_path=None, overwrite_cache=False):
    """
    Align the notes of each GT/prediction MIDI pair with partitura + parangonar.

    The result stores partitura note ids (not array indices), so it can be cached
    as JSON and re-joined later with freshly loaded note arrays.

    Args:
        evaluate_list: List of ``{"gt": path, "pred": path}`` dicts.
        cache_path: Optional JSON file. If it exists and ``overwrite_cache`` is
            False, its contents are returned as-is. They are NOT checked against
            ``evaluate_list``.
        overwrite_cache: Recompute (and rewrite the cache) even if it exists.

    Returns:
        list[dict]: One ``{"gt", "pred", "note_id_pairs"}`` dict per pair that was
        aligned without error. ``note_id_pairs`` holds ``(gt_id, pred_id)`` for
        matches and ``(gt_id, None)`` for deletions (GT notes without a predicted
        counterpart). Insertions are not recorded. After a JSON round-trip the
        pairs come back as 2-element lists.
    """
    if cache_path and os.path.exists(cache_path) and not overwrite_cache:
        print(f"Loading alignment from cache: {cache_path}")
        with open(cache_path, "r") as f:
            return json.load(f)

    print("Cache not found or overwrite forced. Running alignment...")
    aligned_results = []
    matcher = pa.TheGlueNoteMatcher()
    for item in tqdm(evaluate_list, desc="Aligning MIDI pairs"):
        try:
            gt_path = item["gt"]
            pred_path = item["pred"]

            # Load the (already normalized) MIDI files as partitura performances.
            gt_note_array = pt.load_performance_midi(gt_path).note_array()
            pred_note_array = pt.load_performance_midi(pred_path).note_array()

            # GT is passed in the matcher's "score" slot and the prediction in the
            # "performance" slot: score_id -> GT note id, performance_id -> pred id.
            alignment = matcher(gt_note_array, pred_note_array)

            note_id_pairs = []
            for align_item in alignment:
                label = align_item["label"]
                if label == "match":
                    note_id_pairs.append(
                        (align_item["score_id"], align_item["performance_id"])
                    )
                elif label == "deletion":
                    # GT note with no counterpart in the prediction.
                    note_id_pairs.append((align_item["score_id"], None))
                # "insertion" (prediction-only notes) is deliberately not recorded.

            aligned_results.append({
                "gt": gt_path,
                "pred": pred_path,
                "note_id_pairs": note_id_pairs
            })

        except Exception as e:
            # Name the failing pair (as the metric functions do) so it can be found;
            # the pair is left out of the results.
            print(f"Error aligning pair. GT: {item.get('gt')}. Pred: {item.get('pred')}. Error: {e}")

    if cache_path:
        # Make sure the cache directory exists.
        Path(cache_path).parent.mkdir(parents=True, exist_ok=True)
        print(f"Saving alignment cache to: {cache_path}")
        with open(cache_path, "w") as f:
            json.dump(aligned_results, f, indent=4)

    return aligned_results

# --- 指标计算函数 (基于Partitura重构) ---
def compute_vel_l1_pt(aligned_list):
    """
    Mean absolute velocity error (MAE / L1) over all matched note pairs.

    Args:
        aligned_list: Output of ``align_mid_pt``, i.e. dicts with ``"gt"``,
            ``"pred"`` and ``"note_id_pairs"``.

    Returns:
        The average ``|gt_velocity - pred_velocity|`` across every matched pair
        of every item, or ``nan`` when no pair could be matched. Items that fail
        to load are reported and skipped.
    """
    error_sum = 0
    n_pairs = 0

    print("Computing Velocity L1 Loss (MAE)...")
    for entry in tqdm(aligned_list, desc="Calculating L1 Loss"):
        try:
            gt_performance = pt.load_performance_midi(entry["gt"])
            pred_performance = pt.load_performance_midi(entry["pred"])

            gt_notes = gt_performance.note_array()
            pred_notes = pred_performance.note_array()

            # note id -> velocity lookups for both performances
            gt_velocity = dict(zip(gt_notes['id'], gt_notes['velocity']))
            pred_velocity = dict(zip(pred_notes['id'], pred_notes['velocity']))

            for gt_id, pred_id in entry["note_id_pairs"]:
                # Deletions (pred_id None) and unknown ids do not contribute.
                if gt_id is None or pred_id is None:
                    continue
                if gt_id not in gt_velocity or pred_id not in pred_velocity:
                    continue
                error_sum += abs(gt_velocity[gt_id] - pred_velocity[pred_id])
                n_pairs += 1

        except Exception as e:
            print(f"Error processing item for L1 loss. GT: {entry['gt']}. Error: {e}")

    return error_sum / n_pairs if n_pairs else float('nan')

def compute_vel_correlation_pt(aligned_list):
    """
    Pearson and Spearman correlation between GT and predicted velocities,
    pooled over the matched note pairs of every item.

    Args:
        aligned_list: Output of ``align_mid_pt``.

    Returns:
        dict: ``{"pearson": {"correlation", "p_value"}, "spearman": {...},
        "matched_note_count": int}`` with plain Python floats.
        If either velocity series is constant, the correlation is undefined. It is
        then reported as 0.0 (p = 1.0) with an explanatory ``"note"``, the same
        convention as ``compute_dur_correlation_pt``.
        Returns ``{}`` when fewer than two pairs were matched.
    """
    gt_velocities = []
    pred_velocities = []

    print("Computing Velocity Correlation...")
    for item in tqdm(aligned_list, desc="Calculating Correlation"):
        try:
            gt_perf = pt.load_performance_midi(item["gt"])
            pred_perf = pt.load_performance_midi(item["pred"])

            gt_note_array = gt_perf.note_array()
            pred_note_array = pred_perf.note_array()

            # note id -> velocity lookups
            gt_vel_map = {note['id']: note['velocity'] for note in gt_note_array}
            pred_vel_map = {note['id']: note['velocity'] for note in pred_note_array}

            for gt_id, pred_id in item["note_id_pairs"]:
                # Deletions (pred_id None) and unknown ids are skipped.
                if gt_id is not None and pred_id is not None:
                    if gt_id in gt_vel_map and pred_id in pred_vel_map:
                        gt_velocities.append(gt_vel_map[gt_id])
                        pred_velocities.append(pred_vel_map[pred_id])

        except Exception as e:
            print(f"Error processing item for correlation. GT: {item['gt']}. Error: {e}")

    if len(gt_velocities) < 2:
        print("Warning: Not enough matched notes to compute correlation.")
        return {}

    # pearsonr/spearmanr return nan (with a warning) for constant input.
    if np.std(gt_velocities) == 0 or np.std(pred_velocities) == 0:
        print("Warning: One of the velocity arrays has zero standard deviation. Correlation is undefined.")
        return {
            "pearson": {"correlation": 0.0, "p_value": 1.0},
            "spearman": {"correlation": 0.0, "p_value": 1.0},
            "matched_note_count": len(gt_velocities),
            "note": "Correlation is undefined (zero variance), reported as 0."
        }

    pearson_corr, p_pearson = pearsonr(gt_velocities, pred_velocities)
    spearman_corr, p_spearman = spearmanr(gt_velocities, pred_velocities)

    # Plain floats keep the result JSON-friendly.
    return {
        "pearson": {"correlation": float(pearson_corr), "p_value": float(p_pearson)},
        "spearman": {"correlation": float(spearman_corr), "p_value": float(p_spearman)},
        "matched_note_count": len(gt_velocities)
    }

def z_score_normalize(sequence):
    """
    Z-score normalize a 1-D sequence: ``(x - mean) / std``.

    Args:
        sequence: Array-like of numbers (NumPy array, list, tuple, ...).

    Returns:
        np.ndarray of floats with zero mean and unit (population) standard
        deviation. If all values are identical (std == 0), an all-zero array is
        returned instead, to avoid dividing by zero.
    """
    # asarray lets plain lists work too (list - float would raise TypeError).
    values = np.asarray(sequence, dtype=float)
    std = values.std()  # computed once, reused below
    if std == 0:
        return np.zeros_like(values)
    return (values - values.mean()) / std

def compute_velocity_contour_similarity(aligned_list, window_size=64):
    """
    Velocity-contour similarity between GT and prediction over matched notes.

    For each item, the fully matched note pairs are ordered by GT onset. A window
    of ``window_size`` consecutive matched notes is then slid over them with
    stride 1. In each window, the GT and predicted velocity sequences are z-score
    normalized and compared with Pearson's r. Windows where either sequence is
    constant are skipped, as are items with fewer than ``window_size`` matched
    notes.

    Args:
        aligned_list (list): Output of ``align_mid_pt`` (GT/Pred paths plus
            ``note_id_pairs``).
        window_size (int): Number of consecutive matched notes per window.

    Returns:
        float: Mean Pearson r over all windows of all items, or ``nan`` if no
        window could be evaluated.
    """
    all_window_correlations = []

    print(f"Computing Velocity Contour Similarity (window_size={window_size})...")
    for item in tqdm(aligned_list, desc="Processing Contour Similarity"):
        try:
            gt_perf = pt.load_performance_midi(item["gt"])
            pred_perf = pt.load_performance_midi(item["pred"])

            gt_note_array = gt_perf.note_array()
            pred_note_array = pred_perf.note_array()

            gt_vel_map = {note['id']: note['velocity'] for note in gt_note_array}
            pred_vel_map = {note['id']: note['velocity'] for note in pred_note_array}
            gt_onset_map = {note['id']: note['onset_tick'] for note in gt_note_array}

            # Keep only fully matched pairs whose ids exist in both performances.
            matched = [
                (gt_id, pred_id) for gt_id, pred_id in item["note_id_pairs"]
                if gt_id is not None and pred_id is not None
                and gt_id in gt_vel_map and pred_id in pred_vel_map
            ]
            # A contour is only meaningful in temporal order, and the aligner's
            # output order is not guaranteed to be chronological. The stable sort
            # keeps the original order among notes with the same GT onset.
            matched.sort(key=lambda pair: gt_onset_map[pair[0]])

            if len(matched) < window_size:
                continue  # too few matched notes to fill a single window

            gt_aligned_vels = np.array([gt_vel_map[gt_id] for gt_id, _ in matched])
            pred_aligned_vels = np.array([pred_vel_map[pred_id] for _, pred_id in matched])

            # Sliding window, stride 1.
            for i in range(len(matched) - window_size + 1):
                gt_norm = z_score_normalize(gt_aligned_vels[i : i + window_size])
                pred_norm = z_score_normalize(pred_aligned_vels[i : i + window_size])

                # A constant window normalizes to all zeros; Pearson's r is
                # undefined there, so the window is skipped.
                if np.std(gt_norm) == 0 or np.std(pred_norm) == 0:
                    continue

                corr, _ = pearsonr(gt_norm, pred_norm)

                # pearsonr can still return nan; ignore such windows.
                if not np.isnan(corr):
                    all_window_correlations.append(corr)

        except Exception as e:
            print(f"Error processing item for contour similarity. GT: {item['gt']}. Error: {e}")

    if not all_window_correlations:
        return float('nan')

    return float(np.mean(all_window_correlations))

def compute_dur_l1_pt(aligned_list):
    """
    Mean absolute error (L1 / MAE) of note durations over matched note pairs.

    Durations come from partitura's ``duration_tick`` field, so the result is in
    MIDI ticks, not seconds. NOTE(review): this equals milliseconds only if the
    files were normalized to 1 tick == 1 ms; confirm against the preprocessing.

    Args:
        aligned_list: Output of ``align_mid_pt``, i.e. dicts with ``"gt"``,
            ``"pred"`` and ``"note_id_pairs"``.

    Returns:
        float: Mean ``|gt_duration - pred_duration|`` in ticks over every matched
        pair, or ``nan`` when nothing could be matched.
    """
    total_abs_error = 0
    match_count = 0

    print("Computing Duration (tick) L1 Loss (MAE)...")
    for item in tqdm(aligned_list, desc="Calculating Duration L1 Loss"):
        try:
            gt_perf = pt.load_performance_midi(item["gt"])
            pred_perf = pt.load_performance_midi(item["pred"])

            gt_note_array = gt_perf.note_array()
            pred_note_array = pred_perf.note_array()

            # note id -> duration (ticks) lookups
            gt_dur_map = {note['id']: note['duration_tick'] for note in gt_note_array}
            pred_dur_map = {note['id']: note['duration_tick'] for note in pred_note_array}

            for gt_id, pred_id in item["note_id_pairs"]:
                # Deletions (pred_id None) and ids absent from the loaded note
                # arrays are ignored.
                if gt_id is not None and pred_id is not None:
                    if gt_id in gt_dur_map and pred_id in pred_dur_map:
                        total_abs_error += abs(gt_dur_map[gt_id] - pred_dur_map[pred_id])
                        match_count += 1

        except Exception as e:
            print(f"Error processing item for Duration L1 loss. GT: {item['gt']}. Error: {e}")

    if match_count == 0:
        return float('nan')

    return total_abs_error / match_count

def compute_dur_correlation_pt(aligned_list):
    """
    Pearson and Spearman correlation of note durations over matched note pairs.

    Durations are partitura ``duration_tick`` values (MIDI ticks, not seconds).
    Correlation is scale-invariant, so the unit does not affect the result.

    Args:
        aligned_list: Output of ``align_mid_pt``.

    Returns:
        dict: ``{"pearson": {...}, "spearman": {...}, "matched_note_count": int}``
        with plain floats. If either series is constant, the correlation is
        reported as 0.0 (p = 1.0) with a ``"note"``. Returns ``{}`` when fewer
        than two pairs were matched.
    """
    gt_durations = []
    pred_durations = []

    print("Computing Duration (tick) Correlation...")
    for item in tqdm(aligned_list, desc="Calculating Duration Correlation"):
        try:
            gt_perf = pt.load_performance_midi(item["gt"])
            pred_perf = pt.load_performance_midi(item["pred"])

            gt_note_array = gt_perf.note_array()
            pred_note_array = pred_perf.note_array()

            # note id -> duration (ticks) lookups
            gt_dur_map = {note['id']: note['duration_tick'] for note in gt_note_array}
            pred_dur_map = {note['id']: note['duration_tick'] for note in pred_note_array}

            for gt_id, pred_id in item["note_id_pairs"]:
                # Deletions (pred_id None) and unknown ids are skipped.
                if gt_id is not None and pred_id is not None:
                    if gt_id in gt_dur_map and pred_id in pred_dur_map:
                        gt_durations.append(gt_dur_map[gt_id])
                        pred_durations.append(pred_dur_map[pred_id])

        except Exception as e:
            print(f"Error processing item for duration correlation. GT: {item['gt']}. Error: {e}")

    if len(gt_durations) < 2:
        print("Warning: Not enough matched notes to compute duration correlation.")
        return {}

    # Guard against constant input, for which pearsonr/spearmanr return nan.
    if np.std(gt_durations) == 0 or np.std(pred_durations) == 0:
        print("Warning: One of the duration arrays has zero standard deviation. Correlation is undefined.")
        return {
            "pearson": {"correlation": 0.0, "p_value": 1.0},
            "spearman": {"correlation": 0.0, "p_value": 1.0},
            "matched_note_count": len(gt_durations),
            "note": "Correlation is undefined (zero variance), reported as 0."
        }

    pearson_corr, p_pearson = pearsonr(gt_durations, pred_durations)
    spearman_corr, p_spearman = spearmanr(gt_durations, pred_durations)

    return {
        "pearson": {"correlation": float(pearson_corr), "p_value": float(p_pearson)},
        "spearman": {"correlation": float(spearman_corr), "p_value": float(p_spearman)},
        "matched_note_count": len(gt_durations)
    }

def compute_ioi_metrics_pt(aligned_list):
    """
    Inter-onset-interval (IOI) L1 error and correlation over matched notes.

    For each item, the fully matched note pairs are ordered by GT onset. The IOI
    between each two consecutive matched notes is computed in both the GT and the
    prediction. A pair of IOIs is discarded if either one is negative, e.g. when
    the prediction plays the two notes in reverse order. Zero IOIs (simultaneous
    onsets such as chord notes) are kept. IOIs are pooled across all items.

    Onsets are partitura ``onset_tick`` values. NOTE(review): ``l1_error_ms`` is
    the mean absolute IOI difference in ticks. That equals milliseconds only if
    the MIDI files were normalized to 1 tick == 1 ms; confirm against the
    preprocessing.

    Args:
        aligned_list: Output of ``align_mid_pt``.

    Returns:
        dict with ``l1_error_ms``, ``correlation`` and ``matched_ioi_count``.
        ``correlation`` holds Pearson/Spearman results, or 0.0 (p = 1.0) plus a
        ``note`` when a series has zero variance. Returns ``{}`` when fewer than
        two IOIs could be formed.
    """
    gt_iois = []
    pred_iois = []

    print("Computing Inter-Onset Interval (IOI) Metrics...")
    for item in tqdm(aligned_list, desc="Calculating IOI Metrics"):
        try:
            gt_perf = pt.load_performance_midi(item["gt"])
            pred_perf = pt.load_performance_midi(item["pred"])

            gt_note_array = gt_perf.note_array()
            pred_note_array = pred_perf.note_array()

            # note id -> onset (ticks) lookups
            gt_onset_map = {note['id']: note['onset_tick'] for note in gt_note_array}
            pred_onset_map = {note['id']: note['onset_tick'] for note in pred_note_array}

            # Only fully matched pairs whose ids exist in both performances.
            matched_pairs = [
                (gt_id, pred_id) for gt_id, pred_id in item["note_id_pairs"]
                if gt_id is not None and pred_id is not None
                and gt_id in gt_onset_map and pred_id in pred_onset_map
            ]
            # "Consecutive" must mean consecutive in time, and the aligner's
            # output order is not guaranteed chronological. The stable sort keeps
            # the original order among notes that share a GT onset.
            matched_pairs.sort(key=lambda pair: gt_onset_map[pair[0]])

            for (prev_gt_id, prev_pred_id), (curr_gt_id, curr_pred_id) in zip(
                matched_pairs, matched_pairs[1:]
            ):
                gt_ioi = gt_onset_map[curr_gt_id] - gt_onset_map[prev_gt_id]
                pred_ioi = pred_onset_map[curr_pred_id] - pred_onset_map[prev_pred_id]

                # Keep non-negative IOIs (zero = simultaneous onsets). A negative
                # predicted IOI means the prediction reversed the note order.
                if gt_ioi >= 0 and pred_ioi >= 0:
                    gt_iois.append(gt_ioi)
                    pred_iois.append(pred_ioi)

        except Exception as e:
            print(f"Error processing item for IOI metrics. GT: {item['gt']}. Error: {e}")

    if len(gt_iois) < 2:
        print("Warning: Not enough consecutive matched notes to compute IOI metrics.")
        return {}

    gt_iois_arr = np.asarray(gt_iois, dtype=float)
    pred_iois_arr = np.asarray(pred_iois, dtype=float)

    # --- L1 / MAE (ticks; reported as ms, see docstring) ---
    l1_error = float(np.mean(np.abs(gt_iois_arr - pred_iois_arr)))

    # --- Correlation (scale-invariant, so the tick unit does not matter) ---
    if np.std(gt_iois_arr) > 0 and np.std(pred_iois_arr) > 0:
        pearson_corr, p_pearson = pearsonr(gt_iois_arr, pred_iois_arr)
        spearman_corr, p_spearman = spearmanr(gt_iois_arr, pred_iois_arr)
        correlation_results = {
            "pearson": {"correlation": float(pearson_corr), "p_value": float(p_pearson)},
            "spearman": {"correlation": float(spearman_corr), "p_value": float(p_spearman)},
        }
    else:
        # Same convention as compute_dur_correlation_pt: report 0 with a note.
        correlation_results = {
            "pearson": {"correlation": 0.0, "p_value": 1.0},
            "spearman": {"correlation": 0.0, "p_value": 1.0},
            "note": "Correlation is undefined (zero variance), reported as 0."
        }

    return {
        'l1_error_ms': l1_error,
        'correlation': correlation_results,
        'matched_ioi_count': len(gt_iois)
    }

# --- 可视化函数 (基于Partitura重构) ---
"""
def plot_velocity_distribution_pt(evaluate_list, method):
    method_name = method.split("/")[-1]
    gt_velocities = []
    pred_velocities = []

    print("Extracting velocity data for distribution plot...")
    for item in tqdm(evaluate_list, desc=f"Plotting Vel Dist for {method_name}"):
        try:
            gt_perf = pt.load_performance_midi(item["gt"])
            gt_velocities.extend(gt_perf.note_array()['velocity'])
            
            pred_perf = pt.load_performance_midi(item["pred"])
            pred_velocities.extend(pred_perf.note_array()['velocity'])
        except Exception as e:
            print(f"Warning: Could not process file pair for plot. GT: {item['gt']}. Error: {e}")
            continue

    if not gt_velocities or not pred_velocities:
        print("No notes found to plot velocity distribution.")
        return

    sns.set_style("whitegrid")
    plt.figure(figsize=(12, 6))
    sns.histplot(gt_velocities, color="skyblue", label="Ground Truth (Human)", kde=True, bins=128, binrange=(0, 128))
    sns.histplot(pred_velocities, color="red", label=f"Prediction ({method_name})", kde=True, bins=128, binrange=(0, 128), alpha=0.6)
    plt.title("Velocity Distribution Comparison", fontsize=16)
    plt.xlabel("MIDI Velocity", fontsize=12)
    plt.ylabel("Frequency", fontsize=12)
    plt.legend()
    # 确保图片保存目录存在
    Path("results/imgs/").mkdir(parents=True, exist_ok=True)
    plt.savefig(f"results/imgs/vel_{method_name}.png")
    plt.close()
"""
def plot_velocity_distribution_and_calculate_metrics_pt(evaluate_list, method):
    """
    Plot GT vs. predicted velocity histograms and compute distribution metrics.

    All notes of all files are pooled; no alignment is needed. Metrics are
    computed on 128-bin (width 1) velocity histograms, not on a KDE.

    Args:
        evaluate_list: List of ``{"gt": path, "pred": path}`` dicts.
        method: Name or path of the evaluated method. Its last ``/`` component is
            used in titles and in the output file name.

    Returns:
        tuple: ``(js_distance, intersection)``.
        ``js_distance`` is the base-2 Jensen-Shannon *distance* returned by
        ``scipy.spatial.distance.jensenshannon``; it is the square root of the JS
        divergence, lies in [0, 1], and lower is better.
        ``intersection`` is the histogram intersection in [0, 1]; higher is better.
        Returns ``(None, None)`` if no velocities could be read.

    Side effects:
        Saves ``results/imgs/vel_dist_{method_name}.png``.
    """
    method_name = method.split("/")[-1]
    gt_velocities = []
    pred_velocities = []

    print(f"[{method_name}] Extracting velocity data for distribution analysis...")
    for item in tqdm(evaluate_list, desc=f"Analyzing Vel Dist for {method_name}"):
        try:
            gt_perf = pt.load_performance_midi(item["gt"])
            gt_note_array = gt_perf.note_array()
            if gt_note_array.size > 0:
                gt_velocities.extend(gt_note_array['velocity'])

            pred_perf = pt.load_performance_midi(item["pred"])
            pred_note_array = pred_perf.note_array()
            if pred_note_array.size > 0:
                pred_velocities.extend(pred_note_array['velocity'])
        except Exception as e:
            print(f"Warning: Could not process file pair for plot. GT: {item['gt']}. Error: {e}")
            continue

    if not gt_velocities or not pred_velocities:
        print(f"Error: No valid velocity data found for {method_name}. Skipping plot and metrics.")
        # Same arity as the success path, so callers can always unpack two values.
        return None, None

    # --- 1. Histograms ---
    # Both series must share the same bins to be comparable: edges 0..128 give
    # 128 bins of width 1, one per MIDI velocity value.
    bins = np.arange(0, 129)
    gt_hist, _ = np.histogram(gt_velocities, bins=bins)
    pred_hist, _ = np.histogram(pred_velocities, bins=bins)

    # --- 2. Distribution similarity ---
    # Turn counts into probabilities, add a tiny epsilon to every bin, and
    # renormalize so both distributions sum to 1.
    epsilon = 1e-10
    gt_prob = (gt_hist / np.sum(gt_hist)) + epsilon
    pred_prob = (pred_hist / np.sum(pred_hist)) + epsilon
    gt_prob /= np.sum(gt_prob)
    pred_prob /= np.sum(pred_prob)

    # scipy's jensenshannon returns the JS *distance* (sqrt of the divergence).
    js_distance = jensenshannon(gt_prob, pred_prob, base=2)
    intersection = np.sum(np.minimum(gt_prob, pred_prob))

    print(f"--- Distribution Metrics for {method_name} ---")
    print(f"Jensen-Shannon Distance: {js_distance:.4f} (↓ lower is better)")
    print(f"Histogram Intersection: {intersection:.4f} (↑ higher is better)")

    # --- 3. Plot ---
    sns.set_style("whitegrid")
    fig = plt.figure(figsize=(12, 6))
    try:
        sns.histplot(gt_velocities, color="skyblue", label="Ground Truth (Human)", kde=False, bins=bins, alpha=0.7)
        sns.histplot(pred_velocities, color="red", label=f"Prediction ({method_name})", kde=False, bins=bins, alpha=0.5)

        plt.title(f"Velocity Distribution Comparison: {method_name}", fontsize=16)
        plt.xlabel("MIDI Velocity", fontsize=12)
        plt.ylabel("Frequency", fontsize=12)
        plt.legend()
        plt.xlim(0, 128)

        # Annotate the plot with the computed metrics.
        metrics_text = (f"JS Distance: {js_distance:.4f}\n"
                        f"Intersection: {intersection:.4f}")
        plt.text(0.05, 0.95, metrics_text, transform=plt.gca().transAxes, fontsize=10,
                 verticalalignment='top', bbox=dict(boxstyle='round,pad=0.5', fc='wheat', alpha=0.5))

        output_dir = Path("results/imgs/")
        output_dir.mkdir(parents=True, exist_ok=True)
        plt.savefig(output_dir / f"vel_dist_{method_name}.png")
    finally:
        # Release the figure even if plotting or saving fails.
        plt.close(fig)

    return js_distance, intersection

"""
def plot_duration_distribution_pt(evaluate_list, method, duration_range=(0, 2000.0)):
    method_name = method.split("/")[-1]
    gt_durations = []
    pred_durations = []
    
    print("Extracting duration data for distribution plot...")
    for item in tqdm(evaluate_list, desc=f"Plotting Dur Dist for {method_name}"):
        try:
            gt_perf = pt.load_performance_midi(item["gt"])
            gt_durs = gt_perf.note_array()['duration_tick']
            gt_durations.extend(gt_durs[(gt_durs >= duration_range[0]) & (gt_durs < duration_range[1])])
            
            pred_perf = pt.load_performance_midi(item["pred"])
            pred_durs = pred_perf.note_array()['duration_tick']
            pred_durations.extend(pred_durs[(pred_durs >= duration_range[0]) & (pred_durs < duration_range[1])])
        except Exception as e:
            print(f"Warning: Could not process file pair for plot. GT: {item['gt']}. Error: {e}")
            continue
            
    if not gt_durations or not pred_durations:
        print("No notes found to plot duration distribution.")
        return

    sns.set_style("whitegrid")
    plt.figure(figsize=(12, 6))
    sns.histplot(gt_durations, color="skyblue", label="Ground Truth (Human)", kde=True, bins=200)
    sns.histplot(pred_durations, color="red", label=f"Prediction ({method_name})", kde=True, bins=200, alpha=0.6)
    plt.title("Note Duration Distribution Comparison", fontsize=16)
    plt.xlabel("Duration (seconds)", fontsize=12)
    plt.ylabel("Frequency", fontsize=12)
    plt.legend()
    Path("results/imgs/").mkdir(parents=True, exist_ok=True)
    plt.savefig(f"results/imgs/dur_{method_name}.png")
    plt.close()
"""
def plot_duration_distribution_and_metrics_pt(evaluate_list, method, duration_range=(0, 500), num_bins=250):
    """
    Plot GT vs. predicted note-duration (tick) histograms and compute metrics.

    Durations of all notes of all files are pooled (partitura ``duration_tick``).
    Values outside ``[duration_range[0], duration_range[1])`` are dropped before
    the histogram is built.

    Args:
        evaluate_list: List of ``{"gt": path, "pred": path}`` dicts.
        method: Name or path of the evaluated method. Its last ``/`` component is
            used in titles and in the output file name.
        duration_range: ``(low, high)`` tick range to analyze.
        num_bins: Number of equal-width histogram bins over ``duration_range``.

    Returns:
        tuple: ``(js_distance, intersection)``.
        ``js_distance`` is the base-2 Jensen-Shannon *distance* from
        ``scipy.spatial.distance.jensenshannon`` (sqrt of the divergence, lower is
        better). ``intersection`` is the histogram intersection (higher is better).
        Returns ``(None, None)`` when no duration falls inside the range.

    Side effects:
        Saves ``results/imgs/dur_dist_{method_name}.png``.
    """
    method_name = method.split("/")[-1]
    gt_durations = []
    pred_durations = []

    print(f"[{method_name}] Extracting duration data for distribution analysis...")
    for item in tqdm(evaluate_list, desc=f"Analyzing Dur Dist for {method_name}"):
        try:
            gt_perf = pt.load_performance_midi(item["gt"])
            gt_note_array = gt_perf.note_array()
            if gt_note_array.size > 0:
                gt_durations.extend(gt_note_array['duration_tick'])

            pred_perf = pt.load_performance_midi(item["pred"])
            pred_note_array = pred_perf.note_array()
            if pred_note_array.size > 0:
                pred_durations.extend(pred_note_array['duration_tick'])
        except Exception as e:
            print(f"Warning: Could not process file pair. GT: {item['gt']}. Error: {e}")
            continue

    # Convert to arrays so out-of-range values can be masked away.
    gt_durations = np.array(gt_durations)
    pred_durations = np.array(pred_durations)
    gt_durations = gt_durations[(gt_durations >= duration_range[0]) & (gt_durations < duration_range[1])]
    pred_durations = pred_durations[(pred_durations >= duration_range[0]) & (pred_durations < duration_range[1])]

    if gt_durations.size == 0 or pred_durations.size == 0:
        print(f"Error: No valid duration data found in range for {method_name}. Skipping.")
        # Same arity as the success path, so callers can always unpack two values.
        return None, None

    # --- 1. Histograms (shared bins) ---
    bins = np.linspace(duration_range[0], duration_range[1], num_bins + 1)
    gt_hist, _ = np.histogram(gt_durations, bins=bins)
    pred_hist, _ = np.histogram(pred_durations, bins=bins)

    # --- 2. Distribution similarity ---
    # Probabilities with a tiny epsilon per bin, renormalized to sum to 1.
    epsilon = 1e-10
    gt_prob = (gt_hist / np.sum(gt_hist)) + epsilon
    pred_prob = (pred_hist / np.sum(pred_hist)) + epsilon
    gt_prob /= np.sum(gt_prob)
    pred_prob /= np.sum(pred_prob)

    # scipy's jensenshannon returns the JS *distance* (sqrt of the divergence).
    js_distance = jensenshannon(gt_prob, pred_prob, base=2)
    intersection = np.sum(np.minimum(gt_prob, pred_prob))

    print(f"--- Duration Distribution Metrics for {method_name} (Range: {duration_range}) ---")
    print(f"Jensen-Shannon Distance: {js_distance:.4f} (↓ lower is better)")
    print(f"Histogram Intersection: {intersection:.4f} (↑ higher is better)")

    # --- 3. Plot ---
    sns.set_style("whitegrid")
    fig = plt.figure(figsize=(12, 6))
    try:
        sns.histplot(gt_durations, color="skyblue", label="Ground Truth (Human)", kde=False, bins=bins, alpha=0.7)
        sns.histplot(pred_durations, color="red", label=f"Prediction ({method_name})", kde=False, bins=bins, alpha=0.5)
        plt.title(f"Note Duration Distribution Comparison: {method_name}", fontsize=16)
        plt.xlabel("Duration (ticks)", fontsize=12)
        plt.ylabel("Frequency", fontsize=12)
        plt.legend()
        plt.xlim(duration_range)

        metrics_text = (f"JS Distance: {js_distance:.4f}\n"
                        f"Intersection: {intersection:.4f}")
        plt.text(0.65, 0.95, metrics_text, transform=plt.gca().transAxes, fontsize=10,
                 verticalalignment='top', bbox=dict(boxstyle='round,pad=0.5', fc='wheat', alpha=0.5))

        output_dir = Path("results/imgs/")
        output_dir.mkdir(parents=True, exist_ok=True)
        plt.savefig(output_dir / f"dur_dist_{method_name}.png")
    finally:
        # Release the figure even if plotting or saving fails.
        plt.close(fig)

    return js_distance, intersection

"""
def plot_ioi_distribution_pt(evaluate_list, method, ioi_range=(0, 1000.0)):
    method_name = method.split("/")[-1]
    gt_iois = []
    pred_iois = []

    print("Extracting IOI data for distribution plot...")
    for item in tqdm(evaluate_list, desc=f"Plotting IOI Dist for {method_name}"):
        try:
            gt_perf = pt.load_performance_midi(item["gt"])
            gt_onsets = gt_perf.note_array()['onset_tick']
            # np.diff 计算相邻元素的差，因为NoteArray已排序，这正是IOI
            gt_ioi_vals = np.diff(gt_onsets)
            gt_iois.extend(gt_ioi_vals[(gt_ioi_vals >= ioi_range[0]) & (gt_ioi_vals < ioi_range[1])])

            pred_perf = pt.load_performance_midi(item["pred"])
            pred_onsets = pred_perf.note_array()['onset_tick']
            pred_ioi_vals = np.diff(pred_onsets)
            pred_iois.extend(pred_ioi_vals[(pred_ioi_vals >= ioi_range[0]) & (pred_ioi_vals < ioi_range[1])])
        except Exception as e:
            print(f"Warning: Could not process file pair for plot. GT: {item['gt']}. Error: {e}")
            continue

    if not gt_iois or not pred_iois:
        print("No notes found to plot IOI distribution.")
        return

    sns.set_style("whitegrid")
    plt.figure(figsize=(12, 6))
    bins = 200
    sns.histplot(gt_iois, color="skyblue", label="Ground Truth (Human)", kde=True, bins=bins)
    sns.histplot(pred_iois, color="red", label=f"Prediction ({method_name})", kde=True, bins=bins, alpha=0.6)
    plt.title(f"Inter-Onset Interval (IOI) Distribution", fontsize=16)
    plt.xlabel("IOI (seconds)", fontsize=12)
    plt.ylabel("Frequency", fontsize=12)
    plt.legend()
    Path("results/imgs/").mkdir(parents=True, exist_ok=True)
    plt.savefig(f"results/imgs/ioi_{method_name}.png")
    plt.close()
"""

def plot_ioi_distribution_and_metrics_pt(evaluate_list, method, ioi_range=(0, 200), num_bins=200):
    """
    Plot GT vs. predicted inter-onset-interval (tick) histograms and compute metrics.

    Per file, IOIs are the differences between successive onsets (partitura
    ``onset_tick``, sorted first). IOIs from all files are pooled, and values
    outside ``[ioi_range[0], ioi_range[1])`` are dropped.

    Args:
        evaluate_list: List of ``{"gt": path, "pred": path}`` dicts.
        method: Name or path of the evaluated method. Its last ``/`` component is
            used in titles and in the output file name.
        ioi_range: ``(low, high)`` tick range to analyze.
        num_bins: Number of equal-width histogram bins over ``ioi_range``.

    Returns:
        tuple: ``(js_distance, intersection)``.
        ``js_distance`` is the base-2 Jensen-Shannon *distance* from
        ``scipy.spatial.distance.jensenshannon`` (sqrt of the divergence, lower is
        better). ``intersection`` is the histogram intersection (higher is better).
        Returns ``(None, None)`` when no IOI falls inside the range.

    Side effects:
        Saves ``results/imgs/ioi_dist_{method_name}.png``.
    """
    method_name = method.split("/")[-1]
    gt_iois = []
    pred_iois = []

    print(f"[{method_name}] Extracting IOI data for distribution analysis...")
    for item in tqdm(evaluate_list, desc=f"Analyzing IOI Dist for {method_name}"):
        try:
            gt_perf = pt.load_performance_midi(item["gt"])
            gt_note_array = gt_perf.note_array()
            if gt_note_array.size > 1:
                # Sort explicitly so np.diff yields IOIs even if the note array is
                # not onset-ordered (a no-op when it already is).
                gt_iois.extend(np.diff(np.sort(gt_note_array['onset_tick'])))

            pred_perf = pt.load_performance_midi(item["pred"])
            pred_note_array = pred_perf.note_array()
            if pred_note_array.size > 1:
                pred_iois.extend(np.diff(np.sort(pred_note_array['onset_tick'])))
        except Exception as e:
            print(f"Warning: Could not process file pair. GT: {item['gt']}. Error: {e}")
            continue

    gt_iois = np.array(gt_iois)
    pred_iois = np.array(pred_iois)
    gt_iois = gt_iois[(gt_iois >= ioi_range[0]) & (gt_iois < ioi_range[1])]
    pred_iois = pred_iois[(pred_iois >= ioi_range[0]) & (pred_iois < ioi_range[1])]

    if gt_iois.size == 0 or pred_iois.size == 0:
        print(f"Error: No valid IOI data found in range for {method_name}. Skipping.")
        # Same arity as the success path, so callers can always unpack two values.
        return None, None

    # --- 1. Histograms (shared bins) ---
    bins = np.linspace(ioi_range[0], ioi_range[1], num_bins + 1)
    gt_hist, _ = np.histogram(gt_iois, bins=bins)
    pred_hist, _ = np.histogram(pred_iois, bins=bins)

    # --- 2. Distribution similarity ---
    # Probabilities with a tiny epsilon per bin, renormalized to sum to 1.
    epsilon = 1e-10
    gt_prob = (gt_hist / np.sum(gt_hist)) + epsilon
    pred_prob = (pred_hist / np.sum(pred_hist)) + epsilon
    gt_prob /= np.sum(gt_prob)
    pred_prob /= np.sum(pred_prob)

    # scipy's jensenshannon returns the JS *distance* (sqrt of the divergence).
    js_distance = jensenshannon(gt_prob, pred_prob, base=2)
    intersection = np.sum(np.minimum(gt_prob, pred_prob))

    print(f"--- IOI Distribution Metrics for {method_name} (Range: {ioi_range}) ---")
    print(f"Jensen-Shannon Distance: {js_distance:.4f} (↓ lower is better)")
    print(f"Histogram Intersection: {intersection:.4f} (↑ higher is better)")

    # --- 3. Plot ---
    sns.set_style("whitegrid")
    fig = plt.figure(figsize=(12, 6))
    try:
        sns.histplot(gt_iois, color="skyblue", label="Ground Truth (Human)", kde=False, bins=bins, alpha=0.7)
        sns.histplot(pred_iois, color="red", label=f"Prediction ({method_name})", kde=False, bins=bins, alpha=0.5)
        plt.title(f"Inter-Onset Interval (IOI) Distribution: {method_name}", fontsize=16)
        plt.xlabel("IOI (ticks)", fontsize=12)
        plt.ylabel("Frequency", fontsize=12)
        plt.legend()
        plt.xlim(ioi_range)

        metrics_text = (f"JS Distance: {js_distance:.4f}\n"
                        f"Intersection: {intersection:.4f}")
        plt.text(0.65, 0.95, metrics_text, transform=plt.gca().transAxes, fontsize=10,
                 verticalalignment='top', bbox=dict(boxstyle='round,pad=0.5', fc='wheat', alpha=0.5))

        output_dir = Path("results/imgs/")
        output_dir.mkdir(parents=True, exist_ok=True)
        plt.savefig(output_dir / f"ioi_dist_{method_name}.png")
    finally:
        # Release the figure even if plotting or saving fails.
        plt.close(fig)

    return js_distance, intersection

def plot_pedal_pattern_distribution(
    evaluate_list: list[dict], 
    method: str, 
    pedal_token_base: int = 5261, 
    pedal_binarize_threshold: int = 64
):
    gt_patterns = []
    pred_patterns = []
    method_name = Path(method).stem # 获取方法名，如 'AI-Pianist'
    config = PianoT5GemmaConfig()
    print(f"[{method_name}] Analyzing pedal patterns...")

    for item in tqdm(evaluate_list, desc=f"Analyzing Pedals for {method_name}"):
        try:
            # --- 处理 Ground Truth ---
            gt_midi = miditoolkit.MidiFile(item["gt"])
            gt_tokens = midi_to_ids(config, gt_midi)
            # --- 处理 Prediction ---
            pred_midi = miditoolkit.MidiFile(item["pred"])
            pred_tokens = midi_to_ids(config, pred_midi)

            # 提取踏板模式的辅助函数
            def extract_patterns(tokens):
                patterns = []
                for i in range(0, len(tokens), 8):
                    note_chunk = tokens[i : i + 8]
                    if len(note_chunk) < 8:
                        continue  # 跳过末尾不完整的chunk

                    pedal_tokens = note_chunk[4:]
                    binary_values = []
                    for token in pedal_tokens:
                        # 检查token是否在踏板范围内
                        if pedal_token_base <= token < pedal_token_base + 128:
                            pedal_value = token - pedal_token_base
                            binary_value = 1 if pedal_value >= pedal_binarize_threshold else 0
                            binary_values.append(binary_value)
                    
                    if len(binary_values) == 4:
                        # 将4位二进制数转换为一个0-15的整数
                        # 例: [1, 0, 1, 1] -> 1*8 + 0*4 + 1*2 + 1*1 = 11
                        pattern_decimal = (
                            binary_values[0] * 8 +
                            binary_values[1] * 4 +
                            binary_values[2] * 2 +
                            binary_values[3] * 1
                        )
                        patterns.append(pattern_decimal)
                return patterns

            gt_patterns.extend(extract_patterns(gt_tokens))
            pred_patterns.extend(extract_patterns(pred_tokens))

        except Exception as e:
            print(f"Warning: Could not process file pair. GT: {item['gt']}. Error: {e}")
            continue

    if not gt_patterns or not pred_patterns:
        print(f"Error: No valid pedal data found for {method_name}. Skipping analysis.")
        return None, None, None

    # --- 1. 计算直方图 (16个类别) ---
    bins = np.arange(17)  # Bins are [0, 1), [1, 2), ..., [15, 16)
    gt_hist, _ = np.histogram(gt_patterns, bins=bins)
    pred_hist, _ = np.histogram(pred_patterns, bins=bins)

    # --- 2. 计算分布相似性指标 ---
    epsilon = 1e-10
    # 归一化为概率分布
    gt_prob = (gt_hist / np.sum(gt_hist)) + epsilon
    pred_prob = (pred_hist / np.sum(pred_hist)) + epsilon
    
    js_divergence = jensenshannon(gt_prob, pred_prob, base=2)
    intersection = np.sum(np.minimum(gt_prob, pred_prob))

    print(f"--- Pedal Pattern Distribution Metrics for {method_name} ---")
    print(f"Jensen-Shannon Divergence: {js_divergence:.4f} (↓ lower is better)")
    print(f"Histogram Intersection: {intersection:.4f} (↑ higher is better)")

    # --- 3. 绘图 (使用条形图更适合离散类别) ---
    sns.set_style("whitegrid")
    plt.figure(figsize=(14, 7))
    
    x = np.arange(16)  # 16个类别
    width = 0.35  # 条形宽度

    # 绘制条形图
    plt.bar(x - width/2, gt_hist, width, label="Ground Truth (Human)", color="skyblue", alpha=0.8)
    plt.bar(x + width/2, pred_hist, width, label=f"Prediction ({method_name})", color="red", alpha=0.6)

    plt.title(f"Pedal Pattern Distribution Comparison: {method_name}", fontsize=16)
    plt.xlabel("Pedal Pattern ID (0-15)", fontsize=12)
    plt.ylabel("Frequency", fontsize=12)
    plt.xticks(x, [f'{i}\n({i:04b})' for i in x]) # 显示ID和对应的4位二进制
    plt.legend()
    
    metrics_text = (f"JS Divergence: {js_divergence:.4f}\n"
                    f"Intersection: {intersection:.4f}")
    plt.text(0.75, 0.95, metrics_text, transform=plt.gca().transAxes, fontsize=10,
             verticalalignment='top', bbox=dict(boxstyle='round,pad=0.5', fc='wheat', alpha=0.5))

    output_dir = Path("results/imgs/")
    output_dir.mkdir(parents=True, exist_ok=True)
    plt.savefig(output_dir / f"pedal_pattern_dist_{method_name}.png")
    plt.close()
    
    return js_divergence, intersection



# --- Main entry point ---
if __name__ == "__main__":
    # --- Configuration ---
    # Normalized (pre-processed) MIDI files are expected under BASE_PATH/<model>/.
    BASE_PATH = "data/midis/testset-norm"
    # Alternative models (display name and plot colour from the original notes):
    # MODEL_NAME = "virtuosoNet-han"  # name HAN, colour #d390ff
    # MODEL_NAME = "score"            # name Score, colour #a6a6a6
    # MODEL_NAME = "human"            # name Human, colour #1f77b4
    # MODEL_NAME = "virtuosoNet"      # colour #1b9e77 (original note said "Human"; likely a copy-paste slip)
    MODEL_NAME = "ai-pianist-165M"    # name Ours, colour #ff6f00

    gt_path = os.path.join(BASE_PATH, "human")
    pred_path = os.path.join(BASE_PATH, MODEL_NAME)

    # Separate alignment cache per model (used by the disabled alignment step below).
    ALIGNMENT_CACHE_PATH = f"temp/alignment_cache_{MODEL_NAME}.json"

    # 1. Collect GT/prediction file pairs over the whole test set (pieces "12" and "22" excluded).
    evaluate_list = get_evaluate_list(gt_path, pred_path, not_in=["12", "22"])

    # 2./3. Note-level metrics require a note alignment and are currently disabled.
    # To re-enable them:
    #     aligned_list = align_mid_pt(evaluate_list, cache_path=ALIGNMENT_CACHE_PATH, overwrite_cache=False)
    #     if not aligned_list:
    #         raise SystemExit("Alignment failed or produced no results. Exiting.")
    #     print(f"\nVelocity L1 MAE: {compute_vel_l1_pt(aligned_list):.4f}")
    #     print(json.dumps(compute_vel_correlation_pt(aligned_list), indent=4))
    #     print(f"\nDuration L1 MAE: {compute_dur_l1_pt(aligned_list):.4f}")
    #     print(json.dumps(compute_dur_correlation_pt(aligned_list), indent=4))
    #     print(json.dumps(compute_ioi_metrics_pt(aligned_list), indent=4))
    #     print(compute_velocity_contour_similarity(aligned_list))

    # 4. Distribution plots work on the raw file list and need no alignment.
    print("\nGenerating distribution plots...")
    plot_fns = [
        ("velocity", plot_velocity_distribution_and_calculate_metrics_pt),
        ("duration", plot_duration_distribution_and_metrics_pt),
        ("ioi", plot_ioi_distribution_and_metrics_pt),
        ("pedal", plot_pedal_pattern_distribution),
    ]
    js = []   # stores 1 - JS distance, so that higher is better (like intersection)
    ins = []  # histogram intersections
    for metric_name, plot_fn in plot_fns:
        result = plot_fn(evaluate_list, pred_path)
        # Plot helpers signal "no data" with a tuple of Nones (2 or 3 entries);
        # skip those instead of crashing on `1 - None`.
        if result is None or result[0] is None or result[1] is None:
            print(f"Warning: {metric_name} metrics unavailable; excluded from the summary score.")
            continue
        js_dist, intersection = result[0], result[1]
        js.append(1 - js_dist)
        ins.append(intersection)

    print(js)
    if js:
        # Overall score: mean of (average JS similarity, average intersection).
        print((sum(js) / len(js) + sum(ins) / len(ins)) / 2)
    else:
        print("No distribution metrics could be computed.")
    print("Plots saved to results/imgs/")