#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import os
import json
import glob
import pandas as pd
import numpy as np
from collections import defaultdict
import sys

from evaluation_metrics import evaluate_all_metrics, calculate_human_peer_agreement

# Directory containing this script, resolved to an absolute path.
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
# Parent directory of the script directory; the paths below are built from it.
ABLATION_STUDY_DIR = os.path.dirname(SCRIPT_DIR)
# Directory that is scanned for experiment result files and that receives the
# generated Markdown table.
COMBINED_RESULTS_DIR = os.path.join(ABLATION_STUDY_DIR, "combined_results")
# JSONL file holding the human evaluation data.
HUMAN_DATA_FILE = os.path.join(ABLATION_STUDY_DIR, "data", "human_data.jsonl")

# Metric names in the order they are printed per setting and laid out as
# columns of the final table.
METRIC_DISPLAY_ORDER = [
    "Pairwise Agreement", 
    "Model Mean Correlation", 
    "Average Pearson Coefficient", 
    "Average Spearman Coefficient", 
    "Overall Score"
]

def load_jsonl(file_path):
    """Read a JSON Lines file and return its parsed records as a list.

    Blank (whitespace-only) lines are ignored. A line that is not valid JSON
    is reported on stdout and skipped. If the file cannot be opened or read,
    the error is reported on stdout and whatever was parsed so far (possibly
    an empty list) is returned; no exception propagates to the caller.
    """
    records = []
    try:
        with open(file_path, 'r', encoding='utf-8') as handle:
            for raw_line in handle:
                if not raw_line.strip():
                    continue
                try:
                    parsed = json.loads(raw_line)
                except json.JSONDecodeError:
                    print(f"JSONDecodeError: Failed to parse line in {file_path}: {raw_line[:70]}...")
                else:
                    records.append(parsed)
    except Exception as e:
        # Best-effort loader: report the failure and return what we have.
        print(f"Error: Failed to read file {file_path}: {str(e)}")
    return records

def convert_model_data_list_to_dict_by_prompt(model_data_list):
    """Convert a list of model data items to a dict keyed by prompt identifier.

    The key for each item is its ``"prompt_id"`` value when that is truthy;
    otherwise the string form of its ``"prompt"`` value. Items that provide
    neither (or only empty values) are skipped. When several items share a
    key, the last one wins.

    Args:
        model_data_list: Iterable of dict-like records.

    Returns:
        A dict mapping each key to its record.
    """
    result_dict = {}
    for item in model_data_list:
        prompt_id = item.get("prompt_id")
        if not prompt_id:
            prompt = item.get("prompt")
            # Only stringify a prompt that is actually present: str(None) is
            # the truthy string "None", which would otherwise file every
            # unidentified record under one bogus "None" key.
            prompt_id = str(prompt) if prompt is not None else None
        if prompt_id:
            result_dict[prompt_id] = item
    return result_dict

def calculate_human_agreement_metrics(human_data_list):
    """Build the metrics row for the human evaluators themselves.

    The value returned by ``calculate_human_peer_agreement`` is used for both
    "Pairwise Agreement" and "Overall Score". The three correlation metrics
    are set to NaN because they are not applicable for human vs human.
    """
    agreement = calculate_human_peer_agreement(human_data_list)

    metrics = {"Pairwise Agreement": agreement}
    # Correlation metrics are left as NaN for the human row.
    for correlation_metric in (
        "Model Mean Correlation",
        "Average Pearson Coefficient",
        "Average Spearman Coefficient",
    ):
        metrics[correlation_metric] = np.nan
    # Peer agreement doubles as the representative overall score.
    metrics["Overall Score"] = agreement
    return metrics

def get_setting_display_name(file_basename):
    """Return the display name for an experiment result filename.

    Known filenames map to fixed labels. Any other name is returned with a
    trailing ``.jsonl`` extension removed; names without that extension are
    returned unchanged.

    Args:
        file_basename: Filename without directory components.

    Returns:
        The label to use for this setting in printed output and the table.
    """
    name_map = {
        "ablation_exp_v2.jsonl": "基准线 (Baseline)",
        "no_weights_ablation_exp_v2.jsonl": "无权重 (No Weights)",
        "no_criteria_weights_ablation_exp_v2.jsonl": "无标准权重 (No Criteria Weights)",
        "no_dim_weights_ablation_exp_v2.jsonl": "无维度权重 (No Dim Weights)",
        # Names from process_combined_results.py (directory names become filenames)
        "Baseline.jsonl": "基准线 (Baseline - run_ablation)",
        "No_Weights.jsonl": "无权重 (No Weights - run_ablation)",
        "Pointwise.jsonl": "点评分 (Pointwise - run_ablation)",
        "Static_Criteria_Merged.jsonl": "静态标准-合并 (Static Merged - run_ablation)",
        "Vanilla_Prompt.jsonl": "简单提示 (Vanilla Prompt - run_ablation)",
        "Pointwise_Static_Fallback.jsonl": "点评分-静态回退 (Pointwise Static Fallback - run_ablation)",
        "Pointwise_No_Weights.jsonl": "点评分-无权重 (Pointwise No Weights - run_ablation)",
        "Static_Criteria_Merged_No_Weights.jsonl": "静态标准-合并-无权重 (Static Merged No Weights - run_ablation)",
        "Pointwise_Static_Fallback_No_Weights.jsonl": "点评分-静态回退-无权重 (Pointwise Static No Weights - run_ablation)"
        # Add other specific mappings if needed for files from process_combined_results.py
    }
    if file_basename in name_map:
        return name_map[file_basename]

    # Strip only the trailing extension; a blanket str.replace would also
    # delete ".jsonl" occurring in the middle of the name.
    extension = ".jsonl"
    if file_basename.endswith(extension):
        return file_basename[:-len(extension)]
    return file_basename

def generate_metrics_table():
    """Build the ablation metrics comparison table and save it as Markdown.

    Steps:
      1. Load the human data from ``HUMAN_DATA_FILE``.
      2. Collect the ``*.jsonl`` files directly inside ``COMBINED_RESULTS_DIR``
         (excluding ``human_data.jsonl`` and ``process_errors*`` files).
      3. Add a "Human Evaluators" row from
         ``calculate_human_agreement_metrics`` and one row per experiment
         file from ``evaluate_all_metrics``.
      4. Write the table plus metric explanations to
         ``ablation_metrics_comparison_table.md`` in ``COMBINED_RESULTS_DIR``
         and print the table to stdout.

    Returns:
        None. The function returns early, after printing a message, when the
        human data is missing/empty or when no experiment files are found.
        Settings whose file yields no usable records are skipped.
    """
    print(f"Loading human data from: {HUMAN_DATA_FILE}")
    human_data_list = load_jsonl(HUMAN_DATA_FILE)
    if not human_data_list:
        print(f"Critical Error: Human data not found or empty at {HUMAN_DATA_FILE}. Cannot generate metrics table.")
        return
    print(f"Successfully loaded {len(human_data_list)} human data entries.")

    # Files to process are the outputs of process_ablation_baseline.py and process_combined_results.py
    # These are expected to be directly in COMBINED_RESULTS_DIR.
    # Filter out human_data.jsonl itself and process_errors* files, and sort
    # the paths: glob returns them in filesystem-dependent order, which would
    # make the row order of the table differ from run to run.
    experiment_result_files = sorted(
        f
        for f in glob.glob(os.path.join(COMBINED_RESULTS_DIR, "*.jsonl"))
        if os.path.basename(f) != "human_data.jsonl"
        and not os.path.basename(f).startswith("process_errors")
    )

    if not experiment_result_files:
        print(f"No experiment result files (.jsonl) found in {COMBINED_RESULTS_DIR} to generate metrics from.")
        return

    print(f"Found {len(experiment_result_files)} experiment result files to process:")
    for f_path in experiment_result_files:
        print(f"  - {os.path.basename(f_path)}")

    # Maps row label -> {metric name: value}; insertion order is row order.
    all_metrics_results = {}

    # Calculate metrics for human evaluators (self-agreement)
    print("\nCalculating metrics for Human Evaluators...")
    human_agreement_calc = calculate_human_agreement_metrics(human_data_list)
    all_metrics_results["Human Evaluators"] = human_agreement_calc
    print(f"  Pairwise Agreement: {human_agreement_calc.get('Pairwise Agreement', float('nan')):.2f}%")

    # Calculate metrics for each experiment setting
    for res_file_path in experiment_result_files:
        setting_basename = os.path.basename(res_file_path)
        setting_display_name = get_setting_display_name(setting_basename)
        print(f"\nCalculating metrics for setting: {setting_display_name} (from file: {setting_basename})...")

        setting_data_list = load_jsonl(res_file_path)
        if not setting_data_list:
            print(f"  Warning: No data loaded from {res_file_path}. Skipping this setting.")
            continue
        print(f"  Loaded {len(setting_data_list)} entries for {setting_display_name}.")

        # evaluation_metrics expects model_data as a dict {prompt_id: data}
        setting_data_dict = convert_model_data_list_to_dict_by_prompt(setting_data_list)
        if not setting_data_dict:
            print(f"  Warning: Data for {setting_display_name} could not be converted to prompt-keyed dict. Skipping.")
            continue

        current_setting_metrics = evaluate_all_metrics(human_data_list, setting_data_dict)
        all_metrics_results[setting_display_name] = current_setting_metrics
        for metric_name in METRIC_DISPLAY_ORDER:
            print(f"  {metric_name}: {current_setting_metrics.get(metric_name, float('nan')):.2f}%")

    # Create DataFrame from the collected metrics
    metrics_df = pd.DataFrame(all_metrics_results).T  # Transpose to have settings as rows

    # Put columns in the desired order. reindex (rather than plain column
    # selection) yields a NaN column instead of raising KeyError when no
    # setting reported one of the metrics.
    metrics_df = metrics_df.reindex(columns=METRIC_DISPLAY_ORDER)
    metrics_df.index.name = "Evaluation Setting"

    # Generate Markdown table
    markdown_table = metrics_df.to_markdown(floatfmt=".2f")  # Format floats to 2 decimal places

    table_title = "# Ablation Study Metrics Comparison\n"
    metrics_explanation = "\n## Metric Explanations:\n"
    metrics_explanation += "*   **Pairwise Agreement**: Degree of agreement between the model's (or within human evaluators') preference and the average preference of human evaluators. Higher is better.\n"
    metrics_explanation += "*   **Model Mean Correlation**: Pearson correlation between the average scores of each model and the corresponding human average scores. Higher is better.\n"
    metrics_explanation += "*   **Average Pearson Coefficient**: Pearson correlation calculated for each prompt between model scores and human average scores, then averaged. Higher is better (based on ICC filtered prompts).\n"
    metrics_explanation += "*   **Average Spearman Coefficient**: Spearman rank correlation calculated for each prompt between model scores and human average scores, then averaged. Higher is better (based on ICC filtered prompts).\n"
    metrics_explanation += "*   **Overall Score**: Arithmetic mean of the four core metrics above. Higher is better.\n"

    final_markdown_output = table_title + markdown_table + metrics_explanation

    output_md_file_path = os.path.join(COMBINED_RESULTS_DIR, "ablation_metrics_comparison_table.md")
    try:
        with open(output_md_file_path, 'w', encoding='utf-8') as f:
            f.write(final_markdown_output)
        print(f"\nSuccessfully saved metrics comparison table to: {output_md_file_path}")
    except IOError as e:
        # A failed save is reported but does not prevent printing the table.
        print(f"Error saving metrics table: {e}")

    print("\nFinal Metrics Comparison Table:\n")
    print(markdown_table)

# Generate the table only when this file is run as a script, not on import.
if __name__ == "__main__":
    generate_metrics_table() 