#!/usr/bin/env python3
# -*- coding: utf-8 -*-

"""
Encapsulates calculation methods for four evaluation metrics:
1. Pairwise Agreement
2. Pearson Correlation Coefficient
3. Spearman Correlation Coefficient
4. Human Evaluator Agreement (ICC)
"""

import numpy as np
from scipy.stats import pearsonr, spearmanr
import pandas as pd
import logging
from collections import defaultdict
import pingouin as pg # For ICC calculation

def get_model_pairs():
    """Return the six model pairings compared by the agreement metrics.

    The order is fixed. A new list is built on every call, so callers may
    mutate the result without affecting later calls.

    Returns:
    - List of (model_a, model_b) name tuples.
    """
    return [
        ('openai', 'perplexity'),
        ('openai', 'grok'),
        ('openai', 'gemini'),
        ('gemini', 'perplexity'),
        ('gemini', 'grok'),
        ('perplexity', 'grok'),
    ]

def get_preference(scores, model_a, model_b):
    """Return which of two models a score mapping prefers.

    Args:
    - scores: Mapping of model name to score; a model that is absent counts as 0.
    - model_a: Name of the first model.
    - model_b: Name of the second model.

    Returns:
    - 1 when model_a scores strictly higher than model_b, otherwise 0.
      A tie therefore counts as a preference for model_b.
    """
    score_a = scores.get(model_a, 0)
    score_b = scores.get(model_b, 0)
    return 1 if score_a > score_b else 0

def calculate_human_peer_agreement(human_data):
    """
    Measure how often two human experts prefer the same model within a pair.

    For every entry, the 'overall_scores' of up to three experts
    ('expert_1' .. 'expert_3') are collected. Each pair of experts is then
    compared on every model pair that both experts scored.

    Args:
    - human_data: Iterable of entries; an expert is used only when its value is a dict
                  carrying a non-empty 'overall_scores'.

    Returns:
    - Percentage (0-100) of comparisons in which both experts expressed the same
      preference, or 0.0 when no comparison could be made.
    """
    model_pairs = get_model_pairs()
    expert_keys = [f'expert_{idx}' for idx in range(1, 4)]  # Max 3 experts
    matches = 0
    comparisons = 0

    for entry in human_data:
        present = (entry[key] for key in expert_keys if key in entry)
        score_maps = [
            block['overall_scores']
            for block in present
            if isinstance(block, dict) and block.get('overall_scores')
        ]

        # A single expert has nobody to agree with.
        if len(score_maps) < 2:
            continue

        # Visit every unordered pair of experts exactly once.
        for pos, first in enumerate(score_maps):
            for second in score_maps[pos + 1:]:
                for model_a, model_b in model_pairs:
                    rated_by_both = all(
                        model in rating
                        for rating in (first, second)
                        for model in (model_a, model_b)
                    )
                    if not rated_by_both:
                        continue

                    first_pref = get_preference(first, model_a, model_b)
                    second_pref = get_preference(second, model_a, model_b)
                    comparisons += 1
                    if first_pref == second_pref:
                        matches += 1

    return (matches / comparisons) * 100 if comparisons > 0 else 0.0

def calculate_icc(expert_scores_for_models):
    """
    Compute the ICC2 intraclass correlation for a raters-by-targets score table.

    Args:
    - expert_scores_for_models: List of lists; each inner list holds one rater's scores, with
                                position k referring to the same target for every rater.
                                Example: [[model1_exp1, model2_exp1], [model1_exp2, model2_exp2]]
    Returns:
    - The 'ICC' value of the 'ICC2' row reported by pingouin, or np.nan when there are fewer
      than 2 raters, fewer than 2 targets, raters disagree on the number of targets, or any
      exception is raised along the way (the exception is logged, not re-raised).
    """
    try:
        rater_count = len(expert_scores_for_models)
        if rater_count < 2:  # Min 2 raters
            return np.nan

        # The first rater defines how many targets everyone is expected to have scored.
        target_count = len(expert_scores_for_models[0])
        if target_count < 2:  # Min 2 targets
            return np.nan

        if any(len(row) != target_count for row in expert_scores_for_models):
            logging.warning("Inconsistent number of targets rated by experts for ICC.")
            return np.nan

        # Long format expected by pingouin: one record per (rater, target) observation.
        records = [
            {
                'target_id': f'target_{target_idx}',
                'rater_id': f'rater_{rater_idx}',
                'rating': score,
            }
            for rater_idx, row in enumerate(expert_scores_for_models)
            for target_idx, score in enumerate(row)
        ]
        if not records:
            return np.nan

        long_df = pd.DataFrame(records)

        # ICC2: two-way random effects, absolute agreement, single rater/measurement.
        icc_table = pg.intraclass_corr(
            data=long_df,
            targets='target_id',
            raters='rater_id',
            ratings='rating',
        )
        by_type = icc_table.set_index('Type')
        return by_type.loc['ICC2', 'ICC']
    except Exception as e:
        logging.error(f"Error calculating ICC: {e}", exc_info=True)
        return np.nan

def calculate_pair_wise_agreement(human_data, model_data_dict):
    """
    Measure how often a judge's pairwise preference matches each human expert's.

    For every human entry whose prompt_id (compared as a string) is a key of
    model_data_dict, the judge's preference on each model pair is compared with the
    preference of every expert who scored both models of that pair.

    Args:
    - human_data: Iterable of dict entries with 'prompt_id' and 'expert_1' .. 'expert_3'.
    - model_data_dict: Dict keyed by prompt_id; each value provides 'overall_scores'.

    Returns:
    - Percentage (0-100) of judge/expert comparisons that agree, or 0.0 when there
      was nothing to compare.
    """
    model_pairs = get_model_pairs()
    expert_keys = [f'expert_{idx}' for idx in range(1, 4)]
    matches = 0
    comparisons = 0

    for human_entry in human_data:
        prompt_id = str(human_entry.get('prompt_id', ''))
        if not (prompt_id and prompt_id in model_data_dict):
            continue

        judge_scores = model_data_dict[prompt_id].get('overall_scores', {})
        if not judge_scores:
            continue

        # Keep only experts that are dicts with a non-empty 'overall_scores'.
        present = (human_entry[key] for key in expert_keys if key in human_entry)
        expert_score_maps = [
            block['overall_scores']
            for block in present
            if isinstance(block, dict) and block.get('overall_scores')
        ]
        if not expert_score_maps:
            continue

        for model_a, model_b in model_pairs:
            if not (model_a in judge_scores and model_b in judge_scores):
                continue

            # get_preference only yields 0 or 1, so there is no "undecided" case to filter.
            judge_pref = get_preference(judge_scores, model_a, model_b)

            for expert_scores in expert_score_maps:
                if model_a not in expert_scores or model_b not in expert_scores:
                    continue

                expert_pref = get_preference(expert_scores, model_a, model_b)
                comparisons += 1  # Every valid judge/expert comparison counts
                if expert_pref == judge_pref:
                    matches += 1

    return (matches / comparisons) * 100 if comparisons > 0 else 0.0

def calculate_model_avg_correlation(model_data_dict, human_data, correlation_func):
    """Correlate per-model mean judge scores with per-model mean human scores.

    Judge scores are pooled over every entry of model_data_dict and human scores
    over every usable expert of every human entry (no prompt matching is done).
    The two sets of per-model means are then correlated over the models that
    appear on both sides.

    Args:
    - model_data_dict: Dict whose values provide 'overall_scores' ({model: score}).
    - human_data: Iterable of entries with 'expert_1' .. 'expert_3' rating dicts.
    - correlation_func: Callable taking two sequences and returning a pair whose
                        first element is the coefficient (e.g. pearsonr).

    Returns:
    - Coefficient multiplied by 100, or 0.0 when fewer than two models are shared,
      the coefficient is NaN, or correlation_func raises.
    """
    judge_pool = defaultdict(list)
    human_pool = defaultdict(list)

    # Pool the judge's scores per model across all prompts.
    for data_item in model_data_dict.values():
        for model_name, score in data_item.get('overall_scores', {}).items():
            judge_pool[model_name].append(score)

    # Pool the experts' scores per model across all prompts and experts.
    expert_keys = [f'expert_{idx}' for idx in range(1, 4)]
    for human_entry in human_data:
        for expert_key in expert_keys:
            if expert_key not in human_entry:
                continue
            expert_block = human_entry[expert_key]
            if not isinstance(expert_block, dict) or not expert_block.get('overall_scores'):
                continue
            for model_name, score in expert_block['overall_scores'].items():
                human_pool[model_name].append(score)

    judge_means = {name: np.mean(values) for name, values in judge_pool.items() if values}
    human_means = {name: np.mean(values) for name, values in human_pool.items() if values}

    shared_models = list(set(judge_means.keys()) & set(human_means.keys()))
    if len(shared_models) < 2:
        return 0.0  # Not enough common models for correlation

    judge_series = [judge_means[name] for name in shared_models]
    human_series = [human_means[name] for name in shared_models]

    try:
        corr, _ = correlation_func(judge_series, human_series)
        if np.isnan(corr):
            return 0.0
        return corr * 100
    except Exception as e:
        logging.error(f"Error calculating model average correlation: {e}")
        return 0.0

def calculate_prompt_correlation_avg(model_data_dict, human_data, correlation_func):
    """Calculate the average of per-prompt correlations, filtering by human ICC >= 0.

    For each human entry that has a matching, non-empty judge entry:
    1. Collect the 'overall_scores' of up to three experts ('expert_1' .. 'expert_3').
    2. Restrict to the models scored by every collected expert and compute the
       experts' ICC on them via calculate_icc.
    3. If the ICC is a number >= 0, correlate the judge's scores with the experts'
       mean scores over the models both sides have.

    Args:
    - model_data_dict: Dict keyed by prompt_id (string); values provide 'overall_scores'.
    - human_data: Iterable of dict entries with 'prompt_id' and expert rating dicts.
    - correlation_func: Callable taking two sequences and returning a pair whose
                        first element is the coefficient (e.g. pearsonr, spearmanr).

    Returns:
    - Mean of the retained per-prompt coefficients multiplied by 100, or 0.0 when no
      prompt produced a usable coefficient.
    """
    expert_keys = [f'expert_{i}' for i in range(1, 4)]
    prompt_correlations = []
    prompts_with_icc_too_low = 0
    prompts_processed_for_icc = 0

    for human_entry in human_data:
        prompt_id = str(human_entry.get('prompt_id', ''))
        if not prompt_id or prompt_id not in model_data_dict:
            continue

        current_model_scores = model_data_dict[prompt_id].get('overall_scores', {})
        if not current_model_scores:
            continue

        # Count only prompts that actually reach the ICC filter, so the logged
        # "skipped/processed" ratio is not diluted by prompts without judge scores.
        prompts_processed_for_icc += 1

        # One {model: score} dict per expert that supplied non-empty overall scores.
        expert_ratings_for_prompt = []
        for expert_key in expert_keys:
            expert_block = human_entry.get(expert_key)
            if isinstance(expert_block, dict) and expert_block.get('overall_scores'):
                expert_ratings_for_prompt.append(expert_block['overall_scores'])

        if len(expert_ratings_for_prompt) < 2:  # Need at least 2 experts for ICC
            prompts_with_icc_too_low += 1
            continue

        # Models rated by every available expert for this prompt.
        common_rated_models = set(expert_ratings_for_prompt[0].keys())
        for other_expert_scores in expert_ratings_for_prompt[1:]:
            common_rated_models.intersection_update(other_expert_scores.keys())

        common_rated_models = list(common_rated_models)
        if len(common_rated_models) < 2:  # Need at least 2 models for ICC and correlation
            prompts_with_icc_too_low += 1
            continue

        # Rows: experts, columns: commonly rated models (same column order for every row).
        icc_expert_data_matrix = [
            [expert_scores[model] for model in common_rated_models]
            for expert_scores in expert_ratings_for_prompt
        ]

        icc_value = calculate_icc(icc_expert_data_matrix)
        if np.isnan(icc_value) or icc_value < 0:
            prompts_with_icc_too_low += 1
            continue

        # ICC >= 0: average the experts per model. Every expert has every common
        # model by construction, so no membership filter is needed here.
        human_avg_scores_for_prompt = {
            model: np.mean([expert_scores[model] for expert_scores in expert_ratings_for_prompt])
            for model in common_rated_models
        }

        valid_models_for_corr = list(set(current_model_scores.keys()) & set(human_avg_scores_for_prompt.keys()))
        if len(valid_models_for_corr) < 2:
            continue  # Not enough models for correlation

        judge_scores_for_prompt_list = [current_model_scores[model] for model in valid_models_for_corr]
        human_scores_for_prompt_list = [human_avg_scores_for_prompt[model] for model in valid_models_for_corr]

        try:
            corr, _ = correlation_func(judge_scores_for_prompt_list, human_scores_for_prompt_list)
            if not np.isnan(corr):
                prompt_correlations.append(corr)
        except Exception as e:
            # Best effort: one bad prompt must not abort the whole metric.
            logging.warning(f"Error calculating correlation for prompt {prompt_id}: {e}")

    if prompts_processed_for_icc > 0:
        logging.info(f"ICC Filter: Skipped {prompts_with_icc_too_low}/{prompts_processed_for_icc} prompts due to low/invalid ICC.")

    if prompt_correlations:
        return np.mean(prompt_correlations) * 100
    return 0.0

def evaluate_all_metrics(human_data, model_data_dict):
    """
    Run every judge-vs-human metric and summarise them in one dictionary.

    Args:
    - human_data: List of human rating data.
    - model_data_dict: Dictionary of model rating data, keyed by prompt_id.

    Returns:
    - Dictionary with the keys "Pairwise Agreement", "Model Mean Correlation",
      "Average Pearson Coefficient", "Average Spearman Coefficient" and
      "Overall Score". The overall score is the mean of the non-NaN values of
      the other four, or 0.0 if all of them are NaN.
    """
    metrics = {
        "Pairwise Agreement": calculate_pair_wise_agreement(human_data, model_data_dict),
        "Model Mean Correlation": calculate_model_avg_correlation(model_data_dict, human_data, pearsonr),
        "Average Pearson Coefficient": calculate_prompt_correlation_avg(model_data_dict, human_data, pearsonr),
        "Average Spearman Coefficient": calculate_prompt_correlation_avg(model_data_dict, human_data, spearmanr),
    }

    # Computed before "Overall Score" is inserted, so only the four base metrics are averaged.
    usable = [value for value in metrics.values() if not np.isnan(value)]
    if usable:
        metrics["Overall Score"] = np.mean(usable)
    else:
        metrics["Overall Score"] = 0.0

    return metrics