import os
from pathlib import Path
from collections import defaultdict
# from miditoolkit import MidiFile # 不再需要，除非你的normalize_midi函数依赖它
import numpy as np
import pandas as pd
from tqdm import tqdm
import partitura as pt
# from src.utils.midi import normalize_midi # 假设这一步已在预处理中完成
import parangonar as pa
from copy import deepcopy
import matplotlib.pyplot as plt
import seaborn as sns
import json
from scipy.stats import pearsonr, spearmanr
from scipy.stats import wasserstein_distance
from scipy.spatial.distance import jensenshannon
import miditoolkit
from src.utils.midi import midi_to_ids
from src.model.pianoformer import PianoT5GemmaConfig

# --- 文件路径管理函数 (无需改动) ---
def get_evaluate_list(gt_path, pred_path, not_in=None):
    """
    Build the list of (ground truth, prediction) MIDI path pairs to evaluate.

    Every ``*.mid`` file in ``gt_path`` is expected to be named
    ``<prefix>-<number>.mid``; its prediction is ``<pred_path>/<prefix>.mid``.
    Only the file name is parsed and it is split on its LAST hyphen. Prefixes
    that contain hyphens (``a-b-2.mid`` -> ``a-b``) are therefore kept intact,
    and hyphens or OS separators in directory names cannot leak into the prefix.

    Args:
        gt_path: Directory containing the ground-truth MIDI files.
        pred_path: Directory containing the predicted MIDI files.
        not_in: Optional collection of prefixes to skip (``None`` = skip none).

    Returns:
        list[dict]: ``{"gt": str, "pred": str}`` entries ordered by GT path.
        GT files whose name contains no hyphen do not follow the naming
        scheme and are skipped.
    """
    # None instead of a mutable [] default; membership semantics unchanged.
    excluded = not_in if not_in is not None else ()
    gt_path = Path(gt_path)
    pred_path = Path(pred_path)

    out = []
    for gt_file in sorted(gt_path.glob("*.mid")):
        prefix, sep, _number = gt_file.stem.rpartition("-")
        if not sep:
            continue  # no "<prefix>-<number>" structure
        if prefix in excluded:
            continue
        out.append({
            "gt": os.path.join(gt_path, gt_file.name),
            "pred": os.path.join(pred_path, f"{prefix}.mid"),
        })
    return out

# --- 核心重构部分：对齐与缓存 ---
def align_mid_pt(evaluate_list, cache_path=None, overwrite_cache=False):
    """
    Note-align every GT/prediction MIDI pair with partitura + parangonar.

    Results are stored as pairs of note ids rather than array indices, and can
    be cached to / loaded from a JSON file.

    Args:
        evaluate_list: iterable of ``{"gt": path, "pred": path}`` dicts.
        cache_path: optional JSON cache file. If it exists and
            ``overwrite_cache`` is False, its content is returned as-is. The
            cache is NOT checked against ``evaluate_list``.
        overwrite_cache: re-run the alignment even if the cache exists.

    Returns:
        list[dict]: one entry per successfully aligned pair, with keys
        ``"gt"``, ``"pred"`` and ``"note_id_pairs"``. Each pair is one of:

        * ``(gt_id, pred_id)`` for a "match";
        * ``(gt_id, None)`` for a "deletion" (GT note absent from the prediction).

        "insertion" entries are ignored. Pairs read back from the JSON cache
        are 2-element lists instead of tuples. Pairs that fail to load or
        align are reported and skipped, so a written cache may be incomplete.
    """
    if cache_path and os.path.exists(cache_path) and not overwrite_cache:
        print(f"Loading alignment from cache: {cache_path}")
        with open(cache_path, "r", encoding="utf-8") as f:
            return json.load(f)

    print("Cache not found or overwrite forced. Running alignment...")
    aligned_results = []
    failed = 0
    matcher = pa.TheGlueNoteMatcher()
    for item in tqdm(evaluate_list, desc="Aligning MIDI pairs"):
        try:
            gt_path = item["gt"]
            pred_path = item["pred"]

            # GT goes into the matcher's first ("score") slot and the prediction
            # into the second ("performance") slot, so score_id refers to a GT
            # note and performance_id to a predicted note.
            gt_note_array = pt.load_performance_midi(gt_path).note_array()
            pred_note_array = pt.load_performance_midi(pred_path).note_array()
            alignment = matcher(gt_note_array, pred_note_array)

            note_id_pairs = []
            for align_item in alignment:
                label = align_item["label"]
                if label == "match":
                    note_id_pairs.append(
                        (align_item["score_id"], align_item["performance_id"])
                    )
                elif label == "deletion":
                    # GT note without a counterpart in the prediction.
                    note_id_pairs.append((align_item["score_id"], None))
                # "insertion" (prediction-only notes) is intentionally ignored.

            aligned_results.append({
                "gt": gt_path,
                "pred": pred_path,
                "note_id_pairs": note_id_pairs
            })

        except Exception as e:
            # Best effort over the batch: identify the failing pair and go on.
            failed += 1
            print(f"Error aligning pair. GT: {item.get('gt')}, Pred: {item.get('pred')}. Error: {e}")

    if failed:
        print(f"Warning: {failed} pair(s) could not be aligned and were skipped.")

    if cache_path:
        # Make sure the cache directory exists.
        Path(cache_path).parent.mkdir(parents=True, exist_ok=True)
        print(f"Saving alignment cache to: {cache_path}")
        with open(cache_path, "w", encoding="utf-8") as f:
            json.dump(aligned_results, f, indent=4)

    return aligned_results

# --- 指标计算函数 (基于Partitura重构) ---
def compute_vel_l1_pt(aligned_list):
    """
    Mean absolute velocity error (L1 / MAE) over all aligned note pairs.

    Both MIDI files of every entry are reloaded with partitura. A pair
    contributes only when both ids are not None and both are found in the
    reloaded note arrays. Entries that raise are reported and skipped; pairs
    already counted for such an entry are kept.

    Args:
        aligned_list: entries with "gt", "pred" and "note_id_pairs".

    Returns:
        MAE in MIDI velocity units, or NaN when no pair was counted.
    """
    error_sum = 0
    n_matched = 0

    print("Computing Velocity L1 Loss (MAE)...")
    for entry in tqdm(aligned_list, desc="Calculating L1 Loss"):
        try:
            gt_perf = pt.load_performance_midi(entry["gt"])
            pred_perf = pt.load_performance_midi(entry["pred"])
            gt_arr = gt_perf.note_array()
            pred_arr = pred_perf.note_array()

            # note id -> velocity lookups
            gt_lookup = {rec['id']: rec['velocity'] for rec in gt_arr}
            pred_lookup = {rec['id']: rec['velocity'] for rec in pred_arr}

            for gt_id, pred_id in entry["note_id_pairs"]:
                if gt_id is None or pred_id is None:
                    continue  # unmatched note
                if gt_id not in gt_lookup or pred_id not in pred_lookup:
                    continue
                error_sum += abs(gt_lookup[gt_id] - pred_lookup[pred_id])
                n_matched += 1

        except Exception as e:
            print(f"Error processing item for L1 loss. GT: {entry['gt']}. Error: {e}")

    return error_sum / n_matched if n_matched else float('nan')

def compute_vel_correlation_pt(aligned_list):
    """
    Pearson and Spearman correlation of velocities over aligned note pairs.

    A pair contributes only when both ids are not None and both are found in
    the reloaded note arrays. Entries that raise are reported and skipped.

    Args:
        aligned_list: entries with "gt", "pred" and "note_id_pairs".

    Returns:
        dict:
            ``{"pearson": {...}, "spearman": {...}, "matched_note_count": n}``,
            with correlation and p-value as plain floats.

        * If either side has zero variance, the correlation is undefined.
          It is reported as 0.0 (p = 1.0) with an explanatory "note", the
          same convention as ``compute_dur_correlation_pt``.
        * ``{}`` when fewer than two pairs matched.
    """
    gt_velocities = []
    pred_velocities = []

    print("Computing Velocity Correlation...")
    for item in tqdm(aligned_list, desc="Calculating Correlation"):
        try:
            gt_perf = pt.load_performance_midi(item["gt"])
            pred_perf = pt.load_performance_midi(item["pred"])

            gt_note_array = gt_perf.note_array()
            pred_note_array = pred_perf.note_array()

            # note_id -> velocity lookups
            gt_vel_map = {note['id']: note['velocity'] for note in gt_note_array}
            pred_vel_map = {note['id']: note['velocity'] for note in pred_note_array}

            for gt_id, pred_id in item["note_id_pairs"]:
                if gt_id is not None and pred_id is not None:
                    if gt_id in gt_vel_map and pred_id in pred_vel_map:
                        gt_velocities.append(gt_vel_map[gt_id])
                        pred_velocities.append(pred_vel_map[pred_id])

        except Exception as e:
            print(f"Error processing item for correlation. GT: {item['gt']}. Error: {e}")

    if len(gt_velocities) < 2:
        print("Warning: Not enough matched notes to compute correlation.")
        return {}

    # Constant input makes both correlations undefined (scipy would return NaN
    # plus a warning); report it explicitly instead.
    if np.std(gt_velocities) == 0 or np.std(pred_velocities) == 0:
        print("Warning: One of the velocity arrays has zero standard deviation. Correlation is undefined.")
        return {
            "pearson": {"correlation": 0.0, "p_value": 1.0},
            "spearman": {"correlation": 0.0, "p_value": 1.0},
            "matched_note_count": len(gt_velocities),
            "note": "Correlation is undefined (zero variance), reported as 0."
        }

    pearson_corr, p_pearson = pearsonr(gt_velocities, pred_velocities)
    spearman_corr, p_spearman = spearmanr(gt_velocities, pred_velocities)

    return {
        "pearson": {"correlation": float(pearson_corr), "p_value": float(p_pearson)},
        "spearman": {"correlation": float(spearman_corr), "p_value": float(p_spearman)},
        "matched_note_count": len(gt_velocities)
    }

def z_score_normalize(sequence):
    """
    Z-score normalize a sequence: ``(x - mean) / std``, using the population
    std (``ddof=0``) over all elements.

    Args:
        sequence: array-like of numbers (list, tuple or numpy array).

    Returns:
        np.ndarray: float array with the same shape.

        * If the std is 0 (all values equal, or a single element), an
          all-zero float array is returned instead of dividing by zero.
        * An empty input yields an empty float array.
    """
    values = np.asarray(sequence, dtype=float)
    if values.size == 0:
        return values  # avoid mean/std RuntimeWarnings on empty input
    std = values.std()
    if std == 0:
        return np.zeros_like(values)
    return (values - values.mean()) / std

def compute_velocity_contour_similarity(aligned_list, window_size=64):
    """
    Velocity-contour similarity between aligned GT and predicted notes.

    For every entry:

    1. Build the sequence of matched velocities, in alignment order.
    2. Slide a window of ``window_size`` consecutive notes over it (stride 1).
    3. Compute the Pearson correlation of each window after z-score
       normalization.

    Windows where either side is constant are skipped, because the correlation
    is undefined there. Entries with fewer than ``window_size`` matched notes
    contribute no windows. Entries that raise are reported and skipped.

    Args:
        aligned_list (list): entries with "gt", "pred" and "note_id_pairs".
        window_size (int): number of consecutive matched notes per window;
            must be >= 2 (Pearson needs at least two points).

    Returns:
        float: mean Pearson correlation over all windows of all entries, or
        NaN when no window could be scored.

    Raises:
        ValueError: if ``window_size`` < 2.
    """
    # Validate once up front: otherwise pearsonr raises for every entry and
    # the error is swallowed per file, silently yielding NaN.
    if window_size < 2:
        raise ValueError(f"window_size must be >= 2, got {window_size}")

    all_window_correlations = []

    print(f"Computing Velocity Contour Similarity (window_size={window_size})...")
    for item in tqdm(aligned_list, desc="Processing Contour Similarity"):
        try:
            gt_perf = pt.load_performance_midi(item["gt"])
            pred_perf = pt.load_performance_midi(item["pred"])

            gt_note_array = gt_perf.note_array()
            pred_note_array = pred_perf.note_array()

            gt_vel_map = {note['id']: note['velocity'] for note in gt_note_array}
            pred_vel_map = {note['id']: note['velocity'] for note in pred_note_array}

            # Aligned velocity sequences (matched pairs only).
            gt_aligned_vels = []
            pred_aligned_vels = []
            for gt_id, pred_id in item["note_id_pairs"]:
                if gt_id is not None and pred_id is not None:
                    if gt_id in gt_vel_map and pred_id in pred_vel_map:
                        gt_aligned_vels.append(gt_vel_map[gt_id])
                        pred_aligned_vels.append(pred_vel_map[pred_id])

            if len(gt_aligned_vels) < window_size:
                continue  # not enough notes for a single window

            for i in range(len(gt_aligned_vels) - window_size + 1):
                gt_window = np.array(gt_aligned_vels[i : i + window_size])
                pred_window = np.array(pred_aligned_vels[i : i + window_size])

                gt_norm = z_score_normalize(gt_window)
                pred_norm = z_score_normalize(pred_window)

                # A constant window normalizes to all zeros; the correlation
                # is undefined there, so the window is skipped.
                if np.std(gt_norm) == 0 or np.std(pred_norm) == 0:
                    continue

                corr, _ = pearsonr(gt_norm, pred_norm)

                # Guard against NaN from degenerate numeric cases.
                if not np.isnan(corr):
                    all_window_correlations.append(corr)

        except Exception as e:
            print(f"Error processing item for contour similarity. GT: {item['gt']}. Error: {e}")

    if not all_window_correlations:
        return float('nan')

    return np.mean(all_window_correlations)

def compute_dur_l1_pt(aligned_list):
    """
    Mean absolute error (L1 / MAE) of note durations over aligned note pairs.

    Durations come from partitura's ``duration_tick`` field, so the result is
    in MIDI ticks, even though the progress message says "(sec)". It equals
    milliseconds only if the MIDI files were normalized to 1 tick = 1 ms.
    TODO confirm against the preprocessing.

    A pair contributes only when both ids are not None and both are found in
    the reloaded note arrays. Entries that raise are reported and skipped.

    Args:
        aligned_list: entries with "gt", "pred" and "note_id_pairs".

    Returns:
        MAE in ticks, or NaN when no pair was counted.
    """
    error_sum = 0
    n_matched = 0

    print("Computing Duration (sec) L1 Loss (MAE)...")
    for entry in tqdm(aligned_list, desc="Calculating Duration L1 Loss"):
        try:
            gt_perf = pt.load_performance_midi(entry["gt"])
            pred_perf = pt.load_performance_midi(entry["pred"])
            gt_arr = gt_perf.note_array()
            pred_arr = pred_perf.note_array()

            # note id -> duration (ticks) lookups
            gt_lookup = {rec['id']: rec['duration_tick'] for rec in gt_arr}
            pred_lookup = {rec['id']: rec['duration_tick'] for rec in pred_arr}

            for gt_id, pred_id in entry["note_id_pairs"]:
                if gt_id is None or pred_id is None:
                    continue  # unmatched note
                if gt_id not in gt_lookup or pred_id not in pred_lookup:
                    continue
                error_sum += abs(gt_lookup[gt_id] - pred_lookup[pred_id])
                n_matched += 1

        except Exception as e:
            print(f"Error processing item for Duration L1 loss. GT: {entry['gt']}. Error: {e}")

    return error_sum / n_matched if n_matched else float('nan')

def compute_dur_correlation_pt(aligned_list):
    """
    Pearson and Spearman correlation of note durations over aligned note pairs.

    Durations come from partitura's ``duration_tick`` field, i.e. MIDI ticks,
    even though the progress message says "(sec)". A pair contributes only
    when both ids are not None and both are found in the reloaded note arrays.
    Entries that raise are reported and skipped.

    Args:
        aligned_list: entries with "gt", "pred" and "note_id_pairs".

    Returns:
        dict:
            ``{"pearson": {...}, "spearman": {...}, "matched_note_count": n}``
            with float correlation and p-value.

        * If either side has zero variance, correlations are reported as 0.0
          (p = 1.0) and an extra "note" key is added.
        * ``{}`` when fewer than two pairs matched.
    """
    gt_durations, pred_durations = [], []

    print("Computing Duration (sec) Correlation...")
    for entry in tqdm(aligned_list, desc="Calculating Duration Correlation"):
        try:
            gt_perf = pt.load_performance_midi(entry["gt"])
            pred_perf = pt.load_performance_midi(entry["pred"])
            gt_arr = gt_perf.note_array()
            pred_arr = pred_perf.note_array()

            # note id -> duration (ticks) lookups
            gt_lookup = {rec['id']: rec['duration_tick'] for rec in gt_arr}
            pred_lookup = {rec['id']: rec['duration_tick'] for rec in pred_arr}

            for gt_id, pred_id in entry["note_id_pairs"]:
                both_present = gt_id is not None and pred_id is not None
                if both_present and gt_id in gt_lookup and pred_id in pred_lookup:
                    gt_durations.append(gt_lookup[gt_id])
                    pred_durations.append(pred_lookup[pred_id])

        except Exception as e:
            print(f"Error processing item for duration correlation. GT: {entry['gt']}. Error: {e}")

    if len(gt_durations) < 2:
        print("Warning: Not enough matched notes to compute duration correlation.")
        return {}

    n_matched = len(gt_durations)

    # A constant side would make pearsonr/spearmanr return NaN.
    if np.std(gt_durations) == 0 or np.std(pred_durations) == 0:
        print("Warning: One of the duration arrays has zero standard deviation. Correlation is undefined.")
        return {
            "pearson": {"correlation": 0.0, "p_value": 1.0},
            "spearman": {"correlation": 0.0, "p_value": 1.0},
            "matched_note_count": n_matched,
            "note": "Correlation is undefined (zero variance), reported as 0.",
        }

    r_pearson, pval_pearson = pearsonr(gt_durations, pred_durations)
    r_spearman, pval_spearman = spearmanr(gt_durations, pred_durations)

    return {
        "pearson": {"correlation": float(r_pearson), "p_value": float(pval_pearson)},
        "spearman": {"correlation": float(r_spearman), "p_value": float(pval_spearman)},
        "matched_note_count": n_matched,
    }

def compute_ioi_metrics_pt(aligned_list, ticks_per_second=1000, include_simultaneous=True):
    """
    Inter-onset-interval (IOI) L1 error and correlation over aligned notes.

    For each entry, pairs matched on both sides are taken in alignment order.
    For every two consecutive matched pairs, the GT IOI and the predicted IOI
    are their onset differences in ticks. IOIs are converted to seconds by
    dividing by ``ticks_per_second``. All metrics are computed in one pass.

    Args:
        aligned_list: entries with "gt", "pred" and "note_id_pairs".
        ticks_per_second: tick-to-second factor. The default 1000 reproduces
            the previous hard-coded ``/ 1000``, which assumes MIDI normalized
            so that one tick is one millisecond. TODO confirm against the
            preprocessing.
        include_simultaneous: if True (default, previous behaviour), IOIs of
            0 ticks (simultaneous onsets, e.g. chord tones) are kept. If False,
            only strictly positive IOIs are used. Negative IOIs (pairs out of
            onset order) are always dropped.

    Returns:
        dict with:

        * ``'l1_error_ms'``: mean |GT IOI - pred IOI| in milliseconds.
        * ``'correlation'``: Pearson/Spearman results. They are reported as
          0.0 (p = 1.0) plus a ``"note"`` when either side has zero variance.
        * ``'matched_ioi_count'``: number of IOIs used.

        Returns ``{}`` when fewer than two IOIs were collected.

    Raises:
        ValueError: if ``ticks_per_second`` is not positive.
    """
    if ticks_per_second <= 0:
        raise ValueError(f"ticks_per_second must be positive, got {ticks_per_second}")

    gt_iois_sec = []
    pred_iois_sec = []

    print("Computing Inter-Onset Interval (IOI) Metrics...")
    for item in tqdm(aligned_list, desc="Calculating IOI Metrics"):
        try:
            gt_perf = pt.load_performance_midi(item["gt"])
            pred_perf = pt.load_performance_midi(item["pred"])

            gt_note_array = gt_perf.note_array()
            pred_note_array = pred_perf.note_array()

            # note_id -> onset (ticks) lookups
            gt_onset_map = {note['id']: note['onset_tick'] for note in gt_note_array}
            pred_onset_map = {note['id']: note['onset_tick'] for note in pred_note_array}

            # Only pairs matched on both sides carry timing information.
            matched_pairs = [
                (gt_id, pred_id) for gt_id, pred_id in item["note_id_pairs"]
                if gt_id is not None and pred_id is not None
            ]

            # IOIs between consecutive matched pairs (alignment order).
            for (prev_gt_id, prev_pred_id), (curr_gt_id, curr_pred_id) in zip(matched_pairs, matched_pairs[1:]):
                if not (prev_gt_id in gt_onset_map and curr_gt_id in gt_onset_map
                        and prev_pred_id in pred_onset_map and curr_pred_id in pred_onset_map):
                    continue

                gt_ioi_tick = gt_onset_map[curr_gt_id] - gt_onset_map[prev_gt_id]
                pred_ioi_tick = pred_onset_map[curr_pred_id] - pred_onset_map[prev_pred_id]

                # Zero IOIs (simultaneous onsets) are kept unless the caller opts out.
                if include_simultaneous:
                    keep = gt_ioi_tick >= 0 and pred_ioi_tick >= 0
                else:
                    keep = gt_ioi_tick > 0 and pred_ioi_tick > 0
                if keep:
                    gt_iois_sec.append(gt_ioi_tick / ticks_per_second)
                    pred_iois_sec.append(pred_ioi_tick / ticks_per_second)

        except Exception as e:
            print(f"Error processing item for IOI metrics. GT: {item['gt']}. Error: {e}")

    if len(gt_iois_sec) < 2:
        print("Warning: Not enough consecutive matched notes to compute IOI metrics.")
        return {}

    # --- L1 / MAE ---
    gt_iois_arr = np.array(gt_iois_sec)
    pred_iois_arr = np.array(pred_iois_sec)
    l1_error_sec = np.mean(np.abs(gt_iois_arr - pred_iois_arr))

    # --- Correlation ---
    if np.std(gt_iois_arr) > 0 and np.std(pred_iois_arr) > 0:
        pearson_corr, p_pearson = pearsonr(gt_iois_arr, pred_iois_arr)
        spearman_corr, p_spearman = spearmanr(gt_iois_arr, pred_iois_arr)
        correlation_results = {
            "pearson": {"correlation": float(pearson_corr), "p_value": float(p_pearson)},
            "spearman": {"correlation": float(spearman_corr), "p_value": float(p_spearman)},
        }
    else:
        # The note promises values "reported as 0"; actually report them.
        correlation_results = {
            "pearson": {"correlation": 0.0, "p_value": 1.0},
            "spearman": {"correlation": 0.0, "p_value": 1.0},
            "note": "Correlation is undefined (zero variance), reported as 0."
        }

    return {
        'l1_error_ms': float(l1_error_sec * 1000),  # seconds -> milliseconds
        'correlation': correlation_results,
        'matched_ioi_count': len(gt_iois_sec)
    }

# --- 可视化函数 (基于Partitura重构) ---
"""
def plot_velocity_distribution_pt(evaluate_list, method):
    method_name = method.split("/")[-1]
    gt_velocities = []
    pred_velocities = []

    print("Extracting velocity data for distribution plot...")
    for item in tqdm(evaluate_list, desc=f"Plotting Vel Dist for {method_name}"):
        try:
            gt_perf = pt.load_performance_midi(item["gt"])
            gt_velocities.extend(gt_perf.note_array()['velocity'])
            
            pred_perf = pt.load_performance_midi(item["pred"])
            pred_velocities.extend(pred_perf.note_array()['velocity'])
        except Exception as e:
            print(f"Warning: Could not process file pair for plot. GT: {item['gt']}. Error: {e}")
            continue

    if not gt_velocities or not pred_velocities:
        print("No notes found to plot velocity distribution.")
        return

    sns.set_style("whitegrid")
    plt.figure(figsize=(12, 6))
    sns.histplot(gt_velocities, color="skyblue", label="Ground Truth (Human)", kde=True, bins=128, binrange=(0, 128))
    sns.histplot(pred_velocities, color="red", label=f"Prediction ({method_name})", kde=True, bins=128, binrange=(0, 128), alpha=0.6)
    plt.title("Velocity Distribution Comparison", fontsize=16)
    plt.xlabel("MIDI Velocity", fontsize=12)
    plt.ylabel("Frequency", fontsize=12)
    plt.legend()
    # 确保图片保存目录存在
    Path("results/imgs/").mkdir(parents=True, exist_ok=True)
    plt.savefig(f"results/imgs/vel_{method_name}.png")
    plt.close()
"""
def plot_velocity_distribution_and_calculate_metrics_pt(evaluate_list, method):
    """
    Compare GT and predicted velocity distributions: compute similarity metrics
    and save an overlaid histogram plot.

    Velocities of *all* notes in every GT / prediction file are pooled per side
    (no alignment is used). Metrics are computed on 128 unit-width histogram
    bins over [0, 128], not on a KDE:

    * ``js_divergence``: ``scipy.spatial.distance.jensenshannon`` with
      ``base=2``. NOTE: scipy returns the Jensen-Shannon *distance* (the square
      root of the divergence), in [0, 1]. The printed label keeps the word
      "Divergence" for continuity with earlier results. Lower is better.
    * ``w_distance``: 1-D Wasserstein distance on the raw velocity values.
      Lower is better.
    * ``intersection``: histogram intersection of the epsilon-smoothed,
      normalized histograms, in [0, 1]. Higher is better.

    The figure is saved to ``results/imgs/vel_dist_<method_name>.png``, where
    ``method_name`` is the last "/"-separated component of ``method``.

    Args:
        evaluate_list (list[dict]): entries with "gt" and "pred" MIDI paths.
        method (str): method identifier / path used for labels and file name.

    Returns:
        tuple: (js_divergence, w_distance, intersection), or
        (None, None, None) if either side yielded no velocities.
    """
    method_name = method.split("/")[-1]
    gt_velocities = []
    pred_velocities = []

    print(f"[{method_name}] Extracting velocity data for distribution analysis...")
    for item in tqdm(evaluate_list, desc=f"Analyzing Vel Dist for {method_name}"):
        try:
            gt_perf = pt.load_performance_midi(item["gt"])
            gt_note_array = gt_perf.note_array()
            if gt_note_array.size > 0:
                gt_velocities.extend(gt_note_array['velocity'])

            pred_perf = pt.load_performance_midi(item["pred"])
            pred_note_array = pred_perf.note_array()
            if pred_note_array.size > 0:
                pred_velocities.extend(pred_note_array['velocity'])
        except Exception as e:
            print(f"Warning: Could not process file pair for plot. GT: {item['gt']}. Error: {e}")
            continue

    if not gt_velocities or not pred_velocities:
        print(f"Error: No valid velocity data found for {method_name}. Skipping plot and metrics.")
        return None, None, None

    # --- 1. Histograms: identical bins on both sides are required for the
    # bin-wise metrics. Edges 0..128 -> 128 bins of width 1.
    bins = np.arange(0, 129)
    gt_hist, _ = np.histogram(gt_velocities, bins=bins)
    pred_hist, _ = np.histogram(pred_velocities, bins=bins)

    # --- 2. Distribution metrics on normalized histograms. The epsilon keeps
    # every bin strictly positive (no log(0)); then re-normalize to sum to 1.
    epsilon = 1e-10
    gt_prob = (gt_hist / np.sum(gt_hist)) + epsilon
    pred_prob = (pred_hist / np.sum(pred_hist)) + epsilon
    gt_prob /= np.sum(gt_prob)
    pred_prob /= np.sum(pred_prob)

    js_divergence = jensenshannon(gt_prob, pred_prob, base=2)
    # Wasserstein on the raw samples is more precise than on the histogram.
    w_distance = wasserstein_distance(gt_velocities, pred_velocities)
    intersection = np.sum(np.minimum(gt_prob, pred_prob))

    print(f"--- Distribution Metrics for {method_name} ---")
    print(f"Jensen-Shannon Divergence: {js_divergence:.4f} (↓ lower is better)")
    print(f"Wasserstein Distance: {w_distance:.4f} (↓ lower is better)")
    print(f"Histogram Intersection: {intersection:.4f} (↑ higher is better)")

    # --- 3. Plot (plain count histograms, no KDE overlay). The figure is
    # always closed, even if plotting or saving raises, so repeated calls do
    # not accumulate open figures.
    sns.set_style("whitegrid")
    fig, ax = plt.subplots(figsize=(12, 6))
    try:
        sns.histplot(gt_velocities, color="skyblue", label="Ground Truth (Human)", kde=False, bins=bins, alpha=0.7, ax=ax)
        sns.histplot(pred_velocities, color="red", label=f"Prediction ({method_name})", kde=False, bins=bins, alpha=0.5, ax=ax)

        ax.set_title(f"Velocity Distribution Comparison: {method_name}", fontsize=16)
        ax.set_xlabel("MIDI Velocity", fontsize=12)
        ax.set_ylabel("Frequency", fontsize=12)
        ax.legend()
        ax.set_xlim(0, 128)

        # Annotate the figure with the computed metrics.
        metrics_text = (f"JS Divergence: {js_divergence:.4f}\n"
                        f"Wasserstein Distance: {w_distance:.4f}\n"
                        f"Intersection: {intersection:.4f}")
        ax.text(0.05, 0.95, metrics_text, transform=ax.transAxes, fontsize=10,
                verticalalignment='top', bbox=dict(boxstyle='round,pad=0.5', fc='wheat', alpha=0.5))

        output_dir = Path("results/imgs/")
        output_dir.mkdir(parents=True, exist_ok=True)
        fig.savefig(output_dir / f"vel_dist_{method_name}.png")
    finally:
        plt.close(fig)

    return js_divergence, w_distance, intersection

"""
def plot_duration_distribution_pt(evaluate_list, method, duration_range=(0, 2000.0)):
    method_name = method.split("/")[-1]
    gt_durations = []
    pred_durations = []
    
    print("Extracting duration data for distribution plot...")
    for item in tqdm(evaluate_list, desc=f"Plotting Dur Dist for {method_name}"):
        try:
            gt_perf = pt.load_performance_midi(item["gt"])
            gt_durs = gt_perf.note_array()['duration_tick']
            gt_durations.extend(gt_durs[(gt_durs >= duration_range[0]) & (gt_durs < duration_range[1])])
            
            pred_perf = pt.load_performance_midi(item["pred"])
            pred_durs = pred_perf.note_array()['duration_tick']
            pred_durations.extend(pred_durs[(pred_durs >= duration_range[0]) & (pred_durs < duration_range[1])])
        except Exception as e:
            print(f"Warning: Could not process file pair for plot. GT: {item['gt']}. Error: {e}")
            continue
            
    if not gt_durations or not pred_durations:
        print("No notes found to plot duration distribution.")
        return

    sns.set_style("whitegrid")
    plt.figure(figsize=(12, 6))
    sns.histplot(gt_durations, color="skyblue", label="Ground Truth (Human)", kde=True, bins=200)
    sns.histplot(pred_durations, color="red", label=f"Prediction ({method_name})", kde=True, bins=200, alpha=0.6)
    plt.title("Note Duration Distribution Comparison", fontsize=16)
    plt.xlabel("Duration (seconds)", fontsize=12)
    plt.ylabel("Frequency", fontsize=12)
    plt.legend()
    Path("results/imgs/").mkdir(parents=True, exist_ok=True)
    plt.savefig(f"results/imgs/dur_{method_name}.png")
    plt.close()
"""
def plot_duration_distribution_and_metrics_pt(evaluate_list, method, duration_range=(0, 500), num_bins=250):
    """
    Compare GT and predicted note-duration distributions (in MIDI ticks):
    compute similarity metrics and save an overlaid histogram plot.

    The ``duration_tick`` values of *all* notes are pooled per side, and only
    values in ``[duration_range[0], duration_range[1])`` are kept. Metrics use
    a ``num_bins``-bin histogram over that range:

    * JS value from ``scipy.spatial.distance.jensenshannon`` with ``base=2``.
      This is the Jensen-Shannon *distance* (square root of the divergence),
      despite the "Divergence" label in the printed output.
    * 1-D Wasserstein distance on the raw tick values.
    * Histogram intersection of the epsilon-smoothed, normalized histograms.

    The figure is saved to ``results/imgs/dur_dist_<method_name>.png``, where
    ``method_name`` is the last "/"-separated component of ``method``.

    Args:
        evaluate_list (list[dict]): entries with "gt" and "pred" MIDI paths.
        method (str): method identifier / path used for labels and file name.
        duration_range (tuple): half-open [low, high) range of ticks to keep.
        num_bins (int): number of histogram bins over ``duration_range``.

    Returns:
        tuple: (js_divergence, w_distance, intersection), or
        (None, None, None) if either side has no duration in range.
    """
    method_name = method.split("/")[-1]
    gt_durations = []
    pred_durations = []

    print(f"[{method_name}] Extracting duration data for distribution analysis...")
    for item in tqdm(evaluate_list, desc=f"Analyzing Dur Dist for {method_name}"):
        try:
            gt_perf = pt.load_performance_midi(item["gt"])
            gt_note_array = gt_perf.note_array()
            if gt_note_array.size > 0:
                gt_durations.extend(gt_note_array['duration_tick'])

            pred_perf = pt.load_performance_midi(item["pred"])
            pred_note_array = pred_perf.note_array()
            if pred_note_array.size > 0:
                pred_durations.extend(pred_note_array['duration_tick'])
        except Exception as e:
            print(f"Warning: Could not process file pair. GT: {item['gt']}. Error: {e}")
            continue

    # Keep only values inside the analysis range.
    gt_durations = np.array(gt_durations)
    pred_durations = np.array(pred_durations)
    gt_durations = gt_durations[(gt_durations >= duration_range[0]) & (gt_durations < duration_range[1])]
    pred_durations = pred_durations[(pred_durations >= duration_range[0]) & (pred_durations < duration_range[1])]

    if gt_durations.size == 0 or pred_durations.size == 0:
        print(f"Error: No valid duration data found in range for {method_name}. Skipping.")
        return None, None, None

    # --- 1. Histograms on shared bins ---
    bins = np.linspace(duration_range[0], duration_range[1], num_bins + 1)
    gt_hist, _ = np.histogram(gt_durations, bins=bins)
    pred_hist, _ = np.histogram(pred_durations, bins=bins)

    # --- 2. Distribution metrics (epsilon avoids empty bins, then re-normalize) ---
    epsilon = 1e-10
    gt_prob = (gt_hist / np.sum(gt_hist)) + epsilon
    pred_prob = (pred_hist / np.sum(pred_hist)) + epsilon
    gt_prob /= np.sum(gt_prob)
    pred_prob /= np.sum(pred_prob)

    js_divergence = jensenshannon(gt_prob, pred_prob, base=2)
    w_distance = wasserstein_distance(gt_durations, pred_durations)
    intersection = np.sum(np.minimum(gt_prob, pred_prob))

    print(f"--- Duration Distribution Metrics for {method_name} (Range: {duration_range}) ---")
    print(f"Jensen-Shannon Divergence: {js_divergence:.4f} (↓ lower is better)")
    print(f"Wasserstein Distance: {w_distance:.4f} (↓ lower is better)")
    print(f"Histogram Intersection: {intersection:.4f} (↑ higher is better)")

    # --- 3. Plot; the figure is closed even if plotting/saving raises ---
    sns.set_style("whitegrid")
    fig, ax = plt.subplots(figsize=(12, 6))
    try:
        sns.histplot(gt_durations, color="skyblue", label="Ground Truth (Human)", kde=False, bins=bins, alpha=0.7, ax=ax)
        sns.histplot(pred_durations, color="red", label=f"Prediction ({method_name})", kde=False, bins=bins, alpha=0.5, ax=ax)
        ax.set_title(f"Note Duration Distribution Comparison: {method_name}", fontsize=16)
        ax.set_xlabel("Duration (ticks)", fontsize=12)
        ax.set_ylabel("Frequency", fontsize=12)
        ax.legend()
        ax.set_xlim(duration_range)

        metrics_text = (f"JS Divergence: {js_divergence:.4f}\n"
                        f"Wasserstein Distance: {w_distance:.4f}\n"
                        f"Intersection: {intersection:.4f}")
        ax.text(0.65, 0.95, metrics_text, transform=ax.transAxes, fontsize=10,
                verticalalignment='top', bbox=dict(boxstyle='round,pad=0.5', fc='wheat', alpha=0.5))

        output_dir = Path("results/imgs/")
        output_dir.mkdir(parents=True, exist_ok=True)
        fig.savefig(output_dir / f"dur_dist_{method_name}.png")
    finally:
        plt.close(fig)

    return js_divergence, w_distance, intersection

"""
def plot_ioi_distribution_pt(evaluate_list, method, ioi_range=(0, 1000.0)):
    method_name = method.split("/")[-1]
    gt_iois = []
    pred_iois = []

    print("Extracting IOI data for distribution plot...")
    for item in tqdm(evaluate_list, desc=f"Plotting IOI Dist for {method_name}"):
        try:
            gt_perf = pt.load_performance_midi(item["gt"])
            gt_onsets = gt_perf.note_array()['onset_tick']
            # np.diff 计算相邻元素的差，因为NoteArray已排序，这正是IOI
            gt_ioi_vals = np.diff(gt_onsets)
            gt_iois.extend(gt_ioi_vals[(gt_ioi_vals >= ioi_range[0]) & (gt_ioi_vals < ioi_range[1])])

            pred_perf = pt.load_performance_midi(item["pred"])
            pred_onsets = pred_perf.note_array()['onset_tick']
            pred_ioi_vals = np.diff(pred_onsets)
            pred_iois.extend(pred_ioi_vals[(pred_ioi_vals >= ioi_range[0]) & (pred_ioi_vals < ioi_range[1])])
        except Exception as e:
            print(f"Warning: Could not process file pair for plot. GT: {item['gt']}. Error: {e}")
            continue

    if not gt_iois or not pred_iois:
        print("No notes found to plot IOI distribution.")
        return

    sns.set_style("whitegrid")
    plt.figure(figsize=(12, 6))
    bins = 200
    sns.histplot(gt_iois, color="skyblue", label="Ground Truth (Human)", kde=True, bins=bins)
    sns.histplot(pred_iois, color="red", label=f"Prediction ({method_name})", kde=True, bins=bins, alpha=0.6)
    plt.title(f"Inter-Onset Interval (IOI) Distribution", fontsize=16)
    plt.xlabel("IOI (seconds)", fontsize=12)
    plt.ylabel("Frequency", fontsize=12)
    plt.legend()
    Path("results/imgs/").mkdir(parents=True, exist_ok=True)
    plt.savefig(f"results/imgs/ioi_{method_name}.png")
    plt.close()
"""

def plot_ioi_distribution_and_metrics_pt(evaluate_list, method, ioi_range=(0, 200), num_bins=200):
    """
    Compare GT and predicted inter-onset-interval (IOI) distributions (in
    MIDI ticks): compute similarity metrics and save an overlaid histogram plot.

    IOIs are the differences between consecutive onsets of *all* notes of each
    file, computed per file and then pooled per side. Only values in
    ``[ioi_range[0], ioi_range[1])`` are kept; with the default lower bound 0,
    simultaneous onsets (IOI = 0, e.g. chord tones) are included. Metrics use a
    ``num_bins``-bin histogram over that range:

    * JS value from ``scipy.spatial.distance.jensenshannon`` with ``base=2``.
      This is the Jensen-Shannon *distance* (square root of the divergence),
      despite the "Divergence" label in the printed output.
    * 1-D Wasserstein distance on the raw tick values.
    * Histogram intersection of the epsilon-smoothed, normalized histograms.

    The figure is saved to ``results/imgs/ioi_dist_<method_name>.png``, where
    ``method_name`` is the last "/"-separated component of ``method``.

    Args:
        evaluate_list (list[dict]): entries with "gt" and "pred" MIDI paths.
        method (str): method identifier / path used for labels and file name.
        ioi_range (tuple): half-open [low, high) range of IOI ticks to keep.
        num_bins (int): number of histogram bins over ``ioi_range``.

    Returns:
        tuple: (js_divergence, w_distance, intersection), or
        (None, None, None) if either side has no IOI in range.
    """
    method_name = method.split("/")[-1]
    gt_iois = []
    pred_iois = []

    print(f"[{method_name}] Extracting IOI data for distribution analysis...")
    for item in tqdm(evaluate_list, desc=f"Analyzing IOI Dist for {method_name}"):
        try:
            gt_perf = pt.load_performance_midi(item["gt"])
            gt_note_array = gt_perf.note_array()
            if gt_note_array.size > 1:
                # Sort first so np.diff yields true IOIs even if the note array
                # is not onset-ordered (a no-op when it already is).
                gt_iois.extend(np.diff(np.sort(gt_note_array['onset_tick'])))

            pred_perf = pt.load_performance_midi(item["pred"])
            pred_note_array = pred_perf.note_array()
            if pred_note_array.size > 1:
                pred_iois.extend(np.diff(np.sort(pred_note_array['onset_tick'])))
        except Exception as e:
            print(f"Warning: Could not process file pair. GT: {item['gt']}. Error: {e}")
            continue

    gt_iois = np.array(gt_iois)
    pred_iois = np.array(pred_iois)
    gt_iois = gt_iois[(gt_iois >= ioi_range[0]) & (gt_iois < ioi_range[1])]
    pred_iois = pred_iois[(pred_iois >= ioi_range[0]) & (pred_iois < ioi_range[1])]

    if gt_iois.size == 0 or pred_iois.size == 0:
        print(f"Error: No valid IOI data found in range for {method_name}. Skipping.")
        return None, None, None

    # --- 1. Histograms on shared bins ---
    bins = np.linspace(ioi_range[0], ioi_range[1], num_bins + 1)
    gt_hist, _ = np.histogram(gt_iois, bins=bins)
    pred_hist, _ = np.histogram(pred_iois, bins=bins)

    # --- 2. Distribution metrics (epsilon avoids empty bins, then re-normalize) ---
    epsilon = 1e-10
    gt_prob = (gt_hist / np.sum(gt_hist)) + epsilon
    pred_prob = (pred_hist / np.sum(pred_hist)) + epsilon
    gt_prob /= np.sum(gt_prob)
    pred_prob /= np.sum(pred_prob)

    js_divergence = jensenshannon(gt_prob, pred_prob, base=2)
    w_distance = wasserstein_distance(gt_iois, pred_iois)
    intersection = np.sum(np.minimum(gt_prob, pred_prob))

    print(f"--- IOI Distribution Metrics for {method_name} (Range: {ioi_range}) ---")
    print(f"Jensen-Shannon Divergence: {js_divergence:.4f} (↓ lower is better)")
    print(f"Wasserstein Distance: {w_distance:.4f} (↓ lower is better)")
    print(f"Histogram Intersection: {intersection:.4f} (↑ higher is better)")

    # --- 3. Plot; the figure is closed even if plotting/saving raises ---
    sns.set_style("whitegrid")
    fig, ax = plt.subplots(figsize=(12, 6))
    try:
        sns.histplot(gt_iois, color="skyblue", label="Ground Truth (Human)", kde=False, bins=bins, alpha=0.7, ax=ax)
        sns.histplot(pred_iois, color="red", label=f"Prediction ({method_name})", kde=False, bins=bins, alpha=0.5, ax=ax)
        ax.set_title(f"Inter-Onset Interval (IOI) Distribution: {method_name}", fontsize=16)
        ax.set_xlabel("IOI (ticks)", fontsize=12)
        ax.set_ylabel("Frequency", fontsize=12)
        ax.legend()
        ax.set_xlim(ioi_range)

        metrics_text = (f"JS Divergence: {js_divergence:.4f}\n"
                        f"Wasserstein Distance: {w_distance:.4f}\n"
                        f"Intersection: {intersection:.4f}")
        ax.text(0.65, 0.95, metrics_text, transform=ax.transAxes, fontsize=10,
                verticalalignment='top', bbox=dict(boxstyle='round,pad=0.5', fc='wheat', alpha=0.5))

        output_dir = Path("results/imgs/")
        output_dir.mkdir(parents=True, exist_ok=True)
        fig.savefig(output_dir / f"ioi_dist_{method_name}.png")
    finally:
        plt.close(fig)

    return js_divergence, w_distance, intersection

def plot_pedal_pattern_distribution(
    evaluate_list: list[dict],
    method: str,
    pedal_token_base: int = 5261,
    pedal_binarize_threshold: int = 64
):
    """
    Extract binarized pedal patterns from GT/prediction MIDI pairs, compare their
    distributions, and save a comparison bar chart.

    Processing:
    1. For every (ground truth, prediction) pair in ``evaluate_list``, load both files
       with miditoolkit and convert them to token ids via ``midi_to_ids``.
    2. Split each token sequence into chunks of 8 tokens (one note event, according to
       the tokenizer layout this script assumes); an incomplete trailing chunk is skipped.
    3. Treat the last 4 tokens of each chunk as pedal tokens. A token counts as a pedal
       token only if it lies in ``[pedal_token_base, pedal_token_base + 128)``; its pedal
       value is ``token - pedal_token_base`` (0-127).
    4. Binarize each value (``>= pedal_binarize_threshold`` -> 1, else 0) and pack the 4
       bits into an integer 0-15, first pedal token as the most significant bit
       (e.g. [1, 0, 1, 1] -> 11). Chunks where any of the 4 tokens falls outside the
       pedal range are skipped.
    5. Compute JS divergence, Wasserstein distance and histogram intersection between
       the GT and prediction pattern distributions, and plot both as grouped bars.

    File pairs that fail to load or tokenize are reported and skipped.

    Args:
        evaluate_list (list[dict]): items of the form {"gt": "path/to/gt.mid", "pred": "path/to/pred.mid"}.
        method (str): method name or path; ``Path(method).stem`` is used in titles and file names.
        pedal_token_base (int): token id of pedal value 0.
        pedal_binarize_threshold (int): pedal values at or above this count as "pressed".

    Returns:
        tuple: (js_divergence, w_distance, intersection), or (None, None, None) when no
        valid pedal pattern was found for either the GT or the prediction side.
        The plot is written to ``results/imgs/pedal_pattern_dist_<method_name>.png``.
    """
    chunk_len = 8        # tokens per note event
    pedal_offset = 4     # pedal tokens are note_chunk[4:8]
    num_patterns = 16    # 4 binary pedal values -> 2**4 patterns

    def extract_patterns(tokens) -> list[int]:
        """Return the 0-15 pedal pattern of every complete, fully-valid chunk in `tokens`."""
        patterns = []
        for start in range(0, len(tokens), chunk_len):
            note_chunk = tokens[start:start + chunk_len]
            if len(note_chunk) < chunk_len:
                continue  # skip an incomplete trailing chunk
            pattern = 0
            valid = True
            for token in note_chunk[pedal_offset:]:
                token = int(token)
                if not (pedal_token_base <= token < pedal_token_base + 128):
                    valid = False  # not a pedal token -> discard the whole chunk
                    break
                bit = 1 if token - pedal_token_base >= pedal_binarize_threshold else 0
                pattern = (pattern << 1) | bit  # first pedal token ends up as the MSB
            if valid:
                patterns.append(pattern)
        return patterns

    gt_patterns = []
    pred_patterns = []
    method_name = Path(method).stem  # e.g. 'AI-Pianist'
    config = PianoT5GemmaConfig()
    print(f"[{method_name}] Analyzing pedal patterns...")

    for item in tqdm(evaluate_list, desc=f"Analyzing Pedals for {method_name}"):
        try:
            gt_tokens = midi_to_ids(config, miditoolkit.MidiFile(item["gt"]))
            pred_tokens = midi_to_ids(config, miditoolkit.MidiFile(item["pred"]))
            gt_patterns.extend(extract_patterns(gt_tokens))
            pred_patterns.extend(extract_patterns(pred_tokens))
        except Exception as e:
            # Best-effort: a single unreadable pair must not abort the whole evaluation.
            print(f"Warning: Could not process file pair. GT: {item.get('gt')}, "
                  f"Pred: {item.get('pred')}. Error: {e}")
            continue

    if not gt_patterns or not pred_patterns:
        print(f"Error: No valid pedal data found for {method_name}. Skipping analysis.")
        return None, None, None

    # --- 1. Histograms over the 16 pattern ids ---
    bins = np.arange(num_patterns + 1)  # [0, 1), [1, 2), ..., [15, 16]
    gt_hist, _ = np.histogram(gt_patterns, bins=bins)
    pred_hist, _ = np.histogram(pred_patterns, bins=bins)

    # --- 2. Distribution similarity metrics ---
    epsilon = 1e-10
    # Smooth empty bins, then renormalize so both are proper probability distributions
    # (same procedure as the IOI metrics, so the intersections are comparable).
    gt_prob = (gt_hist / np.sum(gt_hist)) + epsilon
    pred_prob = (pred_hist / np.sum(pred_hist)) + epsilon
    gt_prob /= np.sum(gt_prob)
    pred_prob /= np.sum(pred_prob)

    js_divergence = jensenshannon(gt_prob, pred_prob, base=2)
    # NOTE: computed on the pattern ids, i.e. treats them as ordinal values.
    w_distance = wasserstein_distance(gt_patterns, pred_patterns)
    intersection = np.sum(np.minimum(gt_prob, pred_prob))

    print(f"--- Pedal Pattern Distribution Metrics for {method_name} ---")
    print(f"Jensen-Shannon Divergence: {js_divergence:.4f} (↓ lower is better)")
    print(f"Wasserstein Distance: {w_distance:.4f} (↓ lower is better)")
    print(f"Histogram Intersection: {intersection:.4f} (↑ higher is better)")

    # --- 3. Plot (grouped bars suit discrete categories) ---
    sns.set_style("whitegrid")
    fig = plt.figure(figsize=(14, 7))
    try:
        x = np.arange(num_patterns)
        width = 0.35  # bar width

        plt.bar(x - width/2, gt_hist, width, label="Ground Truth (Human)", color="skyblue", alpha=0.8)
        plt.bar(x + width/2, pred_hist, width, label=f"Prediction ({method_name})", color="red", alpha=0.6)

        plt.title(f"Pedal Pattern Distribution Comparison: {method_name}", fontsize=16)
        plt.xlabel("Pedal Pattern ID (0-15)", fontsize=12)
        plt.ylabel("Frequency", fontsize=12)
        plt.xticks(x, [f'{i}\n({int(i):04b})' for i in x])  # id plus its 4-bit binary form
        plt.legend()

        metrics_text = (f"JS Divergence: {js_divergence:.4f}\n"
                        f"Wasserstein Distance: {w_distance:.4f}\n"
                        f"Intersection: {intersection:.4f}")
        plt.text(0.75, 0.95, metrics_text, transform=plt.gca().transAxes, fontsize=10,
                 verticalalignment='top', bbox=dict(boxstyle='round,pad=0.5', fc='wheat', alpha=0.5))

        output_dir = Path("results/imgs/")
        output_dir.mkdir(parents=True, exist_ok=True)
        plt.savefig(output_dir / f"pedal_pattern_dist_{method_name}.png")
    finally:
        plt.close(fig)  # never leak the figure, even if plotting/saving fails

    return js_divergence, w_distance, intersection



# --- Main entry point ---
if __name__ == "__main__":
    # --- Configuration ---
    # Directory with the normalized test-set MIDI files (one sub-directory per system).
    BASE_PATH = "data/midis/testset-norm"
    # Systems that can be evaluated (display name / plot colour):
    #MODEL_NAME = "virtuosoNet-han"  # name: HAN, color #d390ff
    #MODEL_NAME = "score"            # name: Score, color #a6a6a6
    #MODEL_NAME = "human"            # name: Human, color #1f77b4
    #MODEL_NAME = "virtuosoNet"      # name: VirtuosoNet (original note said "Human" - likely a copy-paste slip; confirm), color #1b9e77
    MODEL_NAME = "ai-pianist-165M"   # name: Ours, color #ff6f00

    # Note alignment is only needed for the note-level (aligned) metrics. The
    # distribution plots below work directly on the file list, so they always run.
    RUN_ALIGNMENT = False

    gt_path = os.path.join(BASE_PATH, "human")
    pred_path = os.path.join(BASE_PATH, MODEL_NAME)

    # One alignment cache file per model.
    ALIGNMENT_CACHE_PATH = f"temp/alignment_cache_{MODEL_NAME}.json"

    # --- Pipeline ---

    # 1. Build the (GT, prediction) file list over all pieces except the excluded ones.
    evaluate_list = get_evaluate_list(gt_path, pred_path, not_in=["12", "22"])

    # 2. Optional alignment (loaded from cache when present).
    if RUN_ALIGNMENT:
        aligned_list = align_mid_pt(evaluate_list, cache_path=ALIGNMENT_CACHE_PATH, overwrite_cache=False) or []
        # A cache may have been built from a different file list: keep only current pairs.
        wanted_gt = {item["gt"] for item in evaluate_list}
        aligned_list = [entry for entry in aligned_list if entry["gt"] in wanted_gt]
        if not aligned_list:
            print("Alignment failed or produced no results. Exiting.")
            raise SystemExit(1)
        # Alignment-based metrics (enable as needed):
        #vel_l1_mae = compute_vel_l1_pt(aligned_list)
        #print(f"\nVelocity L1 MAE: {vel_l1_mae:.4f}")
        #correlation_results = compute_vel_correlation_pt(aligned_list)
        #print("\n--- Velocity Correlation Analysis ---")
        #print(json.dumps(correlation_results, indent=4))
        #dur_l1_mae = compute_dur_l1_pt(aligned_list)
        #print(f"\nDuration L1 MAE: {dur_l1_mae:.4f}")
        #dur_correlation_results = compute_dur_correlation_pt(aligned_list)
        #print("\n--- Duration Correlation Analysis ---")
        #print(json.dumps(dur_correlation_results, indent=4))
        #ioi_results = compute_ioi_metrics_pt(aligned_list)
        #print(json.dumps(ioi_results, indent=4))
        #velocity_contour_similarity = compute_velocity_contour_similarity(aligned_list)
        #print(velocity_contour_similarity)

    # 3. Distribution plots + metrics (these use the raw file list `evaluate_list`).
    print("\nGenerating distribution plots...")
    metric_results = {
        "velocity": plot_velocity_distribution_and_calculate_metrics_pt(evaluate_list, pred_path),
        "duration": plot_duration_distribution_and_metrics_pt(evaluate_list, pred_path),
        "ioi": plot_ioi_distribution_and_metrics_pt(evaluate_list, pred_path),
        "pedal": plot_pedal_pattern_distribution(evaluate_list, pred_path),
    }

    js = []   # 1 - JS divergence per feature (a similarity: higher is better)
    ins = []  # histogram intersection per feature (higher is better)
    for feature, result in metric_results.items():
        js_div, _, intersection = result if result is not None else (None, None, None)
        if js_div is None or intersection is None:
            # e.g. the pedal analysis returns (None, None, None) when no data was found.
            print(f"Warning: no {feature} metrics available; excluded from the overall score.")
            continue
        js.append(1 - js_div)
        ins.append(intersection)

    print(js)
    if js:
        # Overall score: mean of (average JS similarity, average intersection).
        print((sum(js) / len(js) + sum(ins) / len(ins)) / 2)
    else:
        print("No valid distribution metrics were produced.")
    print("Plots saved to results/imgs/")