import os
from pathlib import Path
from collections import defaultdict
import numpy as np
import pandas as pd
from tqdm import tqdm
import partitura as pt
import parangonar as pa
import matplotlib.pyplot as plt
import seaborn as sns
import json
from scipy.stats import pearsonr, spearmanr
from scipy.stats import wasserstein_distance
from scipy.spatial.distance import jensenshannon
import miditoolkit
from src.utils.midi import midi_to_ids
from src.model.pianoformer import PianoT5GemmaConfig

# --- Evaluation helpers: file pairing, note alignment, and metric functions. ---
# The distribution-extraction and plotting functions further down, together
# with the __main__ block, are what this script actually drives.

# --- File path management functions ---
def get_evaluate_list(gt_path, pred_path, not_in=None):
    """
    Build the list of ground-truth / prediction MIDI path pairs to evaluate.

    Ground-truth files are expected to be named ``<prefix>-<number>.mid``.
    Each one is paired with ``<pred_path>/<prefix>.mid``, except when the
    prediction directory is named ``human`` or ``human-test``: then every
    ground-truth file is paired with itself.

    Args:
        gt_path: Directory containing the ground-truth ``*.mid`` files.
        pred_path: Directory containing the prediction files.
        not_in: Optional collection of prefixes to skip. ``None`` (the
            default) skips nothing.

    Returns:
        A list of ``{"gt": str, "pred": str}`` dicts, ordered by sorted
        ground-truth path. Only pairs whose prediction file exists are kept.
    """
    excluded = () if not_in is None else not_in
    gt_dir = Path(gt_path)
    pred_dir = Path(pred_path)
    pair_with_self = pred_dir.name in ("human", "human-test")

    pairs = []
    for gt_file in sorted(gt_dir.glob("*.mid")):
        # Parse the file name only. Splitting the full path string would let
        # dashes in directory names (e.g. "testset-norm") leak into the
        # prefix and would depend on "/" being the path separator.
        name_parts = gt_file.name.split("-")
        if len(name_parts) < 2:
            # Does not follow the "<prefix>-<number>.mid" convention.
            continue
        prefix = name_parts[-2]
        if prefix in excluded:
            continue

        pred_file = gt_file if pair_with_self else pred_dir / f"{prefix}.mid"
        if pred_file.exists():
            pairs.append({"gt": str(gt_file), "pred": str(pred_file)})
    return pairs

# --- Core refactored part: alignment and caching ---
def align_mid_pt(evaluate_list, cache_path=None, overwrite_cache=False):
    """
    Align every ground-truth / prediction MIDI pair note by note.

    If ``cache_path`` points to an existing file and ``overwrite_cache`` is
    false, the cached JSON is returned as-is and no alignment is run.
    Otherwise each pair is loaded with partitura and matched with
    ``pa.TheGlueNoteMatcher``; the ground-truth note array is passed as the
    matcher's first argument, so its ids appear under ``"score_id"``.

    Args:
        evaluate_list: Iterable of dicts with ``"gt"`` and ``"pred"`` paths.
        cache_path: Optional path of the JSON cache to read and/or write.
        overwrite_cache: When true, ignore an existing cache and recompute.

    Returns:
        A list of dicts with keys ``"gt"``, ``"pred"`` and
        ``"note_id_pairs"``. Each pair is ``(gt_id, pred_id)`` for a
        ``"match"`` and ``(gt_id, None)`` for a ``"deletion"``; other
        alignment labels are dropped. Results loaded from the cache contain
        lists instead of tuples because of the JSON round trip. Pairs whose
        alignment raised are reported and left out of the result.
    """
    if cache_path and not overwrite_cache and os.path.exists(cache_path):
        print(f"Loading alignment from cache: {cache_path}")
        with open(cache_path, "r", encoding="utf-8") as f:
            return json.load(f)

    print("Cache not found or overwrite forced. Running alignment...")
    aligned_results = []
    matcher = pa.TheGlueNoteMatcher()
    for item in tqdm(evaluate_list, desc="Aligning MIDI pairs"):
        # Read the paths before the try block so the error report below can
        # always name the pair that failed.
        gt_path = item.get("gt")
        pred_path = item.get("pred")
        try:
            gt_perf = pt.load_performance_midi(gt_path)
            pred_perf = pt.load_performance_midi(pred_path)
            gt_note_array = gt_perf.note_array()
            pred_note_array = pred_perf.note_array()
            alignment = matcher(gt_note_array, pred_note_array)

            note_id_pairs = []
            for align_item in alignment:
                label = align_item["label"]
                if label == "match":
                    note_id_pairs.append(
                        (align_item["score_id"], align_item["performance_id"])
                    )
                elif label == "deletion":
                    # Ground-truth note with no counterpart in the prediction.
                    note_id_pairs.append((align_item["score_id"], None))

            aligned_results.append({
                "gt": gt_path,
                "pred": pred_path,
                "note_id_pairs": note_id_pairs,
            })
        except Exception as e:
            # Best effort: one broken pair must not abort the whole run.
            print(f"Error during alignment for GT: {gt_path}, pred: {pred_path}. Error: {e}")

    if cache_path:
        Path(cache_path).parent.mkdir(parents=True, exist_ok=True)
        print(f"Saving alignment cache to: {cache_path}")
        with open(cache_path, "w", encoding="utf-8") as f:
            json.dump(aligned_results, f, indent=4)
    return aligned_results

# --- Metric computation functions (refactored on top of partitura) ---
def compute_vel_l1_pt(aligned_list):
    """
    Mean absolute velocity difference over all aligned note pairs.

    Args:
        aligned_list: Entries with ``"gt"`` and ``"pred"`` MIDI paths and a
            ``"note_id_pairs"`` list of ``(gt_id, pred_id)`` pairs.

    Returns:
        The summed absolute velocity error divided by the number of pairs
        whose ids were both present in the respective note arrays, or NaN
        when no such pair exists. Files that raise are reported and skipped.
    """
    abs_error_sum = 0
    n_matched = 0
    for entry in tqdm(aligned_list, desc="Calculating L1 Loss", leave=False):
        try:
            gt_perf = pt.load_performance_midi(entry["gt"])
            pred_perf = pt.load_performance_midi(entry["pred"])
            gt_notes = gt_perf.note_array()
            pred_notes = pred_perf.note_array()

            # Note id -> velocity lookup tables for both performances.
            gt_velocity = dict(zip(gt_notes['id'], gt_notes['velocity']))
            pred_velocity = dict(zip(pred_notes['id'], pred_notes['velocity']))

            for gt_id, pred_id in entry["note_id_pairs"]:
                # Unmatched notes carry None on one side.
                if gt_id is None or pred_id is None:
                    continue
                if gt_id not in gt_velocity or pred_id not in pred_velocity:
                    continue
                abs_error_sum += abs(gt_velocity[gt_id] - pred_velocity[pred_id])
                n_matched += 1
        except Exception as e:
            print(f"Error in L1 loss for GT: {entry['gt']}. Error: {e}")

    if n_matched == 0:
        return float('nan')
    return abs_error_sum / n_matched

def compute_vel_correlation_pt(aligned_list):
    """
    Pearson and Spearman correlation of velocities over aligned note pairs.

    Args:
        aligned_list: Entries with ``"gt"`` and ``"pred"`` MIDI paths and a
            ``"note_id_pairs"`` list of ``(gt_id, pred_id)`` pairs.

    Returns:
        ``{"pearson": ..., "spearman": ...}`` computed over every pair whose
        ids were both found, pooled across all files, or an empty dict when
        fewer than two such pairs exist. Files that raise are reported and
        skipped.
    """
    gt_values = []
    pred_values = []
    for entry in tqdm(aligned_list, desc="Calculating Correlation", leave=False):
        try:
            gt_perf = pt.load_performance_midi(entry["gt"])
            pred_perf = pt.load_performance_midi(entry["pred"])
            gt_notes = gt_perf.note_array()
            pred_notes = pred_perf.note_array()

            # Note id -> velocity lookup tables for both performances.
            gt_velocity = dict(zip(gt_notes['id'], gt_notes['velocity']))
            pred_velocity = dict(zip(pred_notes['id'], pred_notes['velocity']))

            for gt_id, pred_id in entry["note_id_pairs"]:
                # Unmatched notes carry None on one side.
                if gt_id is None or pred_id is None:
                    continue
                if gt_id not in gt_velocity or pred_id not in pred_velocity:
                    continue
                gt_values.append(gt_velocity[gt_id])
                pred_values.append(pred_velocity[pred_id])
        except Exception as e:
            print(f"Error in correlation for GT: {entry['gt']}. Error: {e}")

    # A correlation needs at least two observations.
    if len(gt_values) < 2:
        return {}

    pearson_corr, _ = pearsonr(gt_values, pred_values)
    spearman_corr, _ = spearmanr(gt_values, pred_values)
    return {"pearson": pearson_corr, "spearman": spearman_corr}
# (Other metric functions like duration, IOI, etc., would go here, unchanged)


# --- NEW: Data Extraction Helper Function ---
def _pedal_patterns_from_tokens(tokens, pedal_token_base, tokens_per_note=8, pedal_tokens_per_note=4):
    """Decode one 4-bit pedal pattern per complete note chunk of a token sequence."""
    patterns = []
    first_pedal = tokens_per_note - pedal_tokens_per_note
    # Only complete chunks are visited; a trailing partial chunk is ignored.
    for start in range(0, len(tokens) - tokens_per_note + 1, tokens_per_note):
        pedal_tokens = tokens[start + first_pedal:start + tokens_per_note]
        # Keep tokens inside the 128-wide pedal range; an offset of 64 or
        # more counts as 1, anything below as 0.
        bits = [
            1 if (tok - pedal_token_base) >= 64 else 0
            for tok in pedal_tokens
            if pedal_token_base <= tok < pedal_token_base + 128
        ]
        if len(bits) != pedal_tokens_per_note:
            # At least one token was outside the pedal range: skip this note.
            continue
        pattern = 0
        for bit in bits:
            # The first pedal token becomes the most significant bit.
            pattern = (pattern << 1) | bit
        patterns.append(pattern)
    return patterns


def extract_all_distribution_data(evaluate_list, config, max_duration_tick=500, max_ioi_tick=200):
    """
    Collect velocity, duration, IOI and pedal-pattern samples from MIDI files.

    Every file is read once with partitura (note attributes) and once with
    miditoolkit (tokenised via ``midi_to_ids`` for the pedal patterns), so
    the plotting code does not have to re-read files per plot.

    Args:
        evaluate_list: Iterable of dicts whose ``"pred"`` entry is a MIDI path.
        config: Tokeniser configuration; passed to ``midi_to_ids`` and read
            for ``pedal_token_base``.
        max_duration_tick: Exclusive upper bound for kept durations
            (values outside ``[0, max_duration_tick)`` are dropped).
        max_ioi_tick: Exclusive upper bound for kept inter-onset intervals
            (values outside ``[0, max_ioi_tick)`` are dropped).

    Returns:
        A ``defaultdict(list)`` with keys ``"velocity"``,
        ``"duration_tick"``, ``"ioi_tick"`` and ``"pedal_pattern"`` pooled
        over all files. Files that raise are reported and skipped.
    """
    data = defaultdict(list)
    print("Extracting distribution data...")
    for item in tqdm(evaluate_list, desc="Extracting Data", leave=False):
        try:
            # Partitura for velocity, duration and IOI.
            note_array = pt.load_performance_midi(item["pred"]).note_array()
            if note_array.size > 0:
                data["velocity"].extend(note_array['velocity'])
                data["duration_tick"].extend(note_array['duration_tick'])
                if note_array.size > 1:
                    # Sort first so each interval is between temporally
                    # adjacent onsets regardless of the note array's order.
                    onsets = np.sort(note_array['onset_tick'])
                    data["ioi_tick"].extend(np.diff(onsets))

            # Miditoolkit + tokeniser for pedal patterns.
            pred_midi = miditoolkit.MidiFile(item["pred"])
            pred_tokens = midi_to_ids(config, pred_midi)
            data["pedal_pattern"].extend(
                _pedal_patterns_from_tokens(pred_tokens, config.pedal_token_base)
            )
        except Exception as e:
            print(f"Warning: Could not process file {item['pred']} for data extraction. Error: {e}")

    # Restrict to the plotted ranges.
    data['duration_tick'] = [d for d in data['duration_tick'] if 0 <= d < max_duration_tick]
    data['ioi_tick'] = [ioi for ioi in data['ioi_tick'] if 0 <= ioi < max_ioi_tick]

    return data


# --- REFACTORED AND NEW PLOTTING FUNCTIONS ---
def plot_combined_distribution(
    predictions_data: dict,
    ground_truth_data: list,
    models_config: list,
    attribute_name: str,
    xlabel: str,
    output_path: Path,
    xlim: tuple = None,
):
    """
    Draw one KDE curve per method for a single attribute and save the figure.

    Args:
        predictions_data: Mapping from model id to that model's samples.
        ground_truth_data: Ground-truth samples (dashed, filled black curve).
        models_config: Model descriptors with ``'id'``, ``'name'`` and
            ``'color'``; models missing from ``predictions_data`` are skipped.
        attribute_name: Attribute name used in the title and log messages.
        xlabel: Label of the x axis.
        output_path: Destination file of the figure.
        xlim: Optional x-axis limits; left untouched when falsy.
    """
    print(f"Generating combined plot for {attribute_name}...")
    plt.style.use('seaborn-v0_8-whitegrid')
    fig, ax = plt.subplots(figsize=(10, 6))

    # Reference curve first so model curves are drawn on top of it.
    sns.kdeplot(
        ground_truth_data,
        ax=ax,
        color='black',
        linestyle='--',
        label='Ground Truth',
        linewidth=2,
        fill=True,
        alpha=0.1,
    )

    # One curve per model that actually has data.
    for entry in models_config:
        if entry['id'] not in predictions_data:
            continue
        sns.kdeplot(
            predictions_data[entry['id']],
            ax=ax,
            label=entry['name'],
            color=entry['color'],
            linewidth=2.5,
            alpha=0.7,
        )

    ax.set_title(f'{attribute_name} Distribution Comparison', fontsize=16)
    ax.set_xlabel(xlabel, fontsize=12)
    ax.set_ylabel('Density', fontsize=12)
    if xlim:
        ax.set_xlim(xlim)
    ax.legend(title='Method', fontsize=10)

    fig.tight_layout()
    fig.savefig(output_path)
    plt.close(fig)
    print(f"Saved plot to {output_path}")

def plot_combined_pedal_distribution(
    predictions_data: dict,
    ground_truth_data: list,
    models_config: list,
    output_path: Path,
):
    """
    Plot the per-method distribution of 4-bit pedal patterns as dodged bars.

    Each method's histogram is normalised independently, so bar heights are
    proportions within that method.

    Args:
        predictions_data: Mapping from model id to a list of pattern ids
            (integers 0-15).
        ground_truth_data: Pattern ids of the ground truth.
        models_config: Model descriptors with ``'id'``, ``'name'`` and
            ``'color'``; models missing from ``predictions_data`` are skipped.
        output_path: Destination file of the figure.

    Nothing is plotted or saved when no method has any pattern.
    """
    print("Generating combined plot for Pedal Patterns...")

    # Long-form rows (one per observed pattern) for seaborn.
    rows = [{'Pattern ID': pattern, 'Method': 'Ground Truth'} for pattern in ground_truth_data]
    for model in models_config:
        rows.extend(
            {'Pattern ID': pattern, 'Method': model['name']}
            for pattern in predictions_data.get(model['id'], ())
        )

    if not rows:
        # An empty DataFrame has no 'Pattern ID' column and would make
        # histplot raise.
        print("No pedal pattern data available; skipping pedal pattern plot.")
        return

    plot_df = pd.DataFrame(rows)

    # Keep the configured order, but only for methods that actually have
    # data, so the legend does not list methods without bars.
    methods_with_data = set(plot_df['Method'])
    method_order = [
        name
        for name in ['Ground Truth'] + [m['name'] for m in models_config]
        if name in methods_with_data
    ]
    color_palette = {'Ground Truth': 'black'}
    for m in models_config:
        color_palette[m['name']] = m['color']

    plt.style.use('seaborn-v0_8-whitegrid')
    plt.figure(figsize=(14, 7))

    ax = sns.histplot(
        data=plot_df,
        x='Pattern ID',
        hue='Method',
        hue_order=method_order,
        multiple='dodge',
        stat='probability',  # Use probability for fair comparison
        common_norm=False,  # Normalize each histogram independently
        shrink=0.8,
        palette=color_palette,
        bins=np.arange(17) - 0.5,  # Center bins on integers
    )

    plt.title('Pedal Pattern Distribution Comparison', fontsize=16)
    plt.xlabel('Pedal Pattern ID', fontsize=12)
    plt.ylabel('Proportion', fontsize=12)
    plt.xticks(range(16), [f'{i}\n({i:04b})' for i in range(16)])

    # Restyle the legend seaborn created instead of calling plt.legend():
    # the histogram bars carry no labels, so rebuilding the legend from the
    # axes' artists would replace seaborn's legend with an empty one.
    legend = ax.get_legend()
    if legend is not None:
        legend.set_title('Method')
        for text in legend.get_texts():
            text.set_fontsize(10)

    plt.tight_layout()
    plt.savefig(output_path)
    plt.close()
    print(f"Saved plot to {output_path}")


# --- Main program entry point ---
if __name__ == "__main__":
    # --- 1. Configuration ---
    BASE_PATH = Path("data/midis/testset-norm")
    RESULTS_DIR = Path("results")
    CACHE_DIR = Path("temp")
    IMG_DIR = RESULTS_DIR / "imgs"
    # Ground-truth prefixes left out when pairing model predictions.
    EXCLUDED_PREFIXES = ["12", "22"]

    # Create output directories (parents=True so the order of creation and
    # any future nesting of these paths does not matter).
    for directory in (RESULTS_DIR, CACHE_DIR, IMG_DIR):
        directory.mkdir(parents=True, exist_ok=True)

    # Define all models, their legend names, and colors
    MODELS_TO_EVALUATE = [
        {'id': 'human-test', 'name': 'Human', 'color': '#1f77b4'},
        {'id': 'score', 'name': 'Score', 'color': '#a6a6a6'},
        {'id': 'virtuosoNet', 'name': 'VirtuosoNet', 'color': '#1b9e77'},
        {'id': 'virtuosoNet-han', 'name': 'HAN', 'color': '#d390ff'},
        {'id': 'ai-pianist-165M', 'name': 'Ours', 'color': '#ff6f00'},
    ]

    gt_path = BASE_PATH / "human"
    config = PianoT5GemmaConfig()

    # --- 2. Data Extraction for Plotting ---
    # attribute name -> {model id -> list of samples}
    all_distribution_data = defaultdict(dict)

    # First, extract data for the Ground Truth (every *.mid in the GT folder).
    print("--- Extracting Ground Truth Data ---")
    gt_files_list = [{"pred": str(p)} for p in sorted(gt_path.glob("*.mid"))]
    ground_truth_distributions = extract_all_distribution_data(gt_files_list, config)

    # Then, extract data for each model to be evaluated
    print("\n--- Extracting Model Prediction Data ---")
    for model in tqdm(MODELS_TO_EVALUATE, desc="Processing Models"):
        model_id = model['id']

        if model_id == 'human-test':
            # 'human-test' is read from the same folder, with the same file
            # list, as the Ground Truth, so reuse that extraction instead of
            # parsing every file a second time.
            # NOTE(review): this makes the 'Human' curve identical to the
            # 'Ground Truth' curve - confirm that is intended.
            model_distributions = ground_truth_distributions
        else:
            # Pair each GT file with this model's prediction file; only the
            # prediction paths are needed for distribution extraction.
            evaluate_list = get_evaluate_list(
                gt_path, BASE_PATH / model_id, not_in=EXCLUDED_PREFIXES
            )
            pred_files_list = [{"pred": item["pred"]} for item in evaluate_list]
            model_distributions = extract_all_distribution_data(pred_files_list, config)

        for key, value in model_distributions.items():
            all_distribution_data[key][model_id] = value

    # --- 3. Plotting ---
    print("\n--- Generating Combined Plots ---")

    # (data key, title attribute, x label, output file name, x limits)
    KDE_PLOTS = [
        ('velocity', 'Velocity', 'MIDI Velocity',
         'combined_velocity_distribution.pdf', (0, 128)),
        ('duration_tick', 'Note Duration', 'Duration (ticks)',
         'combined_duration_distribution.pdf', (0, 500)),
        ('ioi_tick', 'Inter-Onset Interval (IOI)', 'IOI (ticks)',
         'combined_ioi_distribution.pdf', (0, 200)),
    ]
    for data_key, attribute_name, xlabel, filename, xlim in KDE_PLOTS:
        plot_combined_distribution(
            predictions_data=all_distribution_data[data_key],
            ground_truth_data=ground_truth_distributions[data_key],
            models_config=MODELS_TO_EVALUATE,
            attribute_name=attribute_name,
            xlabel=xlabel,
            output_path=IMG_DIR / filename,
            xlim=xlim,
        )

    # Pedal Pattern Plot
    plot_combined_pedal_distribution(
        predictions_data=all_distribution_data['pedal_pattern'],
        ground_truth_data=ground_truth_distributions['pedal_pattern'],
        models_config=MODELS_TO_EVALUATE,
        output_path=IMG_DIR / 'combined_pedal_pattern_distribution.pdf'
    )

    # --- 4. Metric Calculation ---
    print("\n--- Calculating Performance Metrics ---")
    all_metrics = {}

    for model in MODELS_TO_EVALUATE:
        model_id = model['id']
        model_name = model['name']
        print(f"\n--- Evaluating: {model_name} ---")

        pred_path = BASE_PATH / model_id
        evaluate_list = get_evaluate_list(gt_path, pred_path, not_in=EXCLUDED_PREFIXES)

        if not evaluate_list:
            print(f"No files found for {model_name}. Skipping.")
            continue

        # Alignment and metric computation are currently disabled; re-enable
        # the lines below to fill model_metrics.
        # Align files
        #cache_path = CACHE_DIR / f"alignment_cache_{model_id}.json"
        #aligned_list = align_mid_pt(evaluate_list, cache_path=cache_path, overwrite_cache=False)

        #if not aligned_list:
        #    print(f"Alignment failed for {model_name}. Skipping metrics.")
        #    continue

        # Compute metrics
        model_metrics = {}
        #model_metrics['Vel. L1 MAE'] = compute_vel_l1_pt(aligned_list)
        #vel_corr = compute_vel_correlation_pt(aligned_list)
        #model_metrics['Vel. Pearson'] = vel_corr.get('pearson')
        # ... Add other metric calculations here if needed ...

        all_metrics[model_name] = model_metrics

    # --- 5. Display Final Results ---
    print("\n\n--- FINAL METRIC RESULTS ---")
    # Every evaluated model gets an entry even when its metric dict is empty,
    # so test for actual metric values rather than for the presence of keys.
    if any(all_metrics.values()):
        results_df = pd.DataFrame.from_dict(all_metrics, orient='index')
        print(results_df.to_string(float_format="%.4f"))
    else:
        print("No metrics were calculated.")