"""
Statistical functions for grading comparison and inter-rater reliability.

Provides metrics to assess consistency between multiple graders:
- Raw agreement percentage
- Correlation coefficients (for progress grades)
"""

import math
from typing import Dict, Optional

from scipy import stats


def calculate_agreement(session1_gradings, session2_gradings) -> Dict:
    """
    Calculate agreement metrics between two grading sessions.

    Gradings are matched by ``model_answer_id``; only answers present in both
    sessions are compared. If a session contains several gradings for the same
    answer, the last one iterated is the one used.

    Args:
        session1_gradings: QuerySet or list of ModelGrading objects from first session
        session2_gradings: QuerySet or list of ModelGrading objects from second session

    Returns:
        Dictionary containing:
        - category_agreements: Dict mapping each binary category to
          ``{'agreement_pct', 'agreements', 'total'}`` (``agreement_pct`` is
          None when no answer was rated by both graders in that category);
          empty when the sessions share no answers
        - overall_agreement: Agreement percentage pooled over all categories
          (0 when nothing could be compared)
        - total_agreements: Number of matching category ratings
        - total_comparisons: Number of category ratings compared
        - progress_correlation: Correlation coefficient for progress grades,
          or None when it cannot be computed
        - progress_agreement: Dict with exact / within-1 agreement and mean
          absolute difference for progress grades
        - n_compared: Number of answers compared
        - error: Present only when the sessions share no answers
    """
    # Create mappings by model_answer_id
    s1_map = {g.model_answer_id: g for g in session1_gradings}
    s2_map = {g.model_answer_id: g for g in session2_gradings}

    # Find common answers
    common_ids = s1_map.keys() & s2_map.keys()

    if not common_ids:
        # Return the same keys as the normal path so callers can index the
        # result without special-casing the empty comparison.
        return {
            'category_agreements': {},
            'overall_agreement': 0,
            'total_agreements': 0,
            'total_comparisons': 0,
            'progress_correlation': None,
            # Same shape _calculate_progress_grade_agreement yields when
            # there is nothing to compare.
            'progress_agreement': {
                'exact_agreement_pct': None,
                'within_1_agreement_pct': None,
                'mean_absolute_difference': None,
                'n_compared': 0,
            },
            'n_compared': 0,
            'error': 'No common answers between sessions',
        }

    # Categories to compare (8 binary categories)
    binary_categories = (
        'error_incorrect_logic',
        'error_hallucinated',
        'error_calculation',
        'error_conceptual',
        'achievement_understanding',
        'achievement_correct_result',
        'achievement_insight',
        'achievement_usefulness',
    )

    # Pair the gradings once instead of re-doing both lookups per category.
    grading_pairs = [(s1_map[answer_id], s2_map[answer_id]) for answer_id in common_ids]

    # Calculate agreement for each category
    category_agreements = {}
    total_agreements = 0
    total_comparisons = 0

    for category in binary_categories:
        agreements = 0
        valid_comparisons = 0

        for g1, g2 in grading_pairs:
            # Get the state for this category (it's a ForeignKey to GradingState)
            state1 = getattr(g1, category)
            state2 = getattr(g2, category)

            # Only compare if both graders provided a rating (not None)
            if state1 is None or state2 is None:
                continue

            valid_comparisons += 1
            if state1.state_code == state2.state_code:
                agreements += 1

        if valid_comparisons > 0:
            agreement_pct = round((agreements / valid_comparisons) * 100, 1)
        else:
            # No answer was rated by both graders: percentage is undefined.
            agreement_pct = None

        category_agreements[category] = {
            'agreement_pct': agreement_pct,
            'agreements': agreements,
            'total': valid_comparisons,
        }
        total_agreements += agreements
        total_comparisons += valid_comparisons

    # Overall agreement across all categories
    overall_agreement = 0
    if total_comparisons > 0:
        overall_agreement = round((total_agreements / total_comparisons) * 100, 1)

    # Calculate correlation for progress grades
    progress_correlation = _calculate_progress_correlation(s1_map, s2_map, common_ids)

    # Calculate progress grade agreement metrics
    progress_agreement = _calculate_progress_grade_agreement(s1_map, s2_map, common_ids)

    return {
        'category_agreements': category_agreements,
        'overall_agreement': overall_agreement,
        'total_agreements': total_agreements,
        'total_comparisons': total_comparisons,
        'progress_correlation': progress_correlation,
        'progress_agreement': progress_agreement,
        'n_compared': len(common_ids),
    }


def _calculate_progress_grade_agreement(s1_map, s2_map, common_ids) -> Dict:
    """
    Calculate agreement metrics specifically for progress grades (0-3 scale).

    Answers where either grading has no progress grade (None) are skipped.

    Args:
        s1_map: Mapping of answer id to the first session's grading object
        s2_map: Mapping of answer id to the second session's grading object
        common_ids: Iterable of answer ids present in both mappings

    Returns:
        Dictionary containing:
        - exact_agreement_pct: Percentage of exact matches
        - within_1_agreement_pct: Percentage within 1 point (exact matches included)
        - mean_absolute_difference: Average difference between grades
        - n_compared: Number of answers with both grades
        The first three values are None when no answer has both grades.
    """
    exact_matches = 0
    within_1_matches = 0
    total_difference = 0
    n_compared = 0

    for answer_id in common_ids:
        grade1 = s1_map[answer_id].progress_grade
        grade2 = s2_map[answer_id].progress_grade

        # Only compare if both have progress grades
        if grade1 is None or grade2 is None:
            continue

        n_compared += 1
        difference = abs(grade1 - grade2)
        total_difference += difference

        if difference == 0:
            exact_matches += 1
        # "<= 1" rather than "== 0 or == 1" so a fractional difference
        # (e.g. 0.5) still counts as within one point; identical for ints.
        if difference <= 1:
            within_1_matches += 1

    if n_compared == 0:
        return {
            'exact_agreement_pct': None,
            'within_1_agreement_pct': None,
            'mean_absolute_difference': None,
            'n_compared': 0,
        }

    return {
        'exact_agreement_pct': round((exact_matches / n_compared) * 100, 1),
        'within_1_agreement_pct': round((within_1_matches / n_compared) * 100, 1),
        'mean_absolute_difference': round(total_difference / n_compared, 2),
        'n_compared': n_compared,
    }


def _calculate_progress_correlation(s1_map, s2_map, common_ids) -> Optional[float]:
    """
    Calculate Spearman rank correlation coefficient for progress grades.

    Progress grades are on a 0-3 ordinal scale (none, minimal, substantial, near-complete).
    Spearman correlation is more appropriate than Pearson for ordinal data as it
    doesn't assume a linear relationship and is based on rank order.

    Args:
        s1_map: Mapping of answer id to the first session's grading object
        s2_map: Mapping of answer id to the second session's grading object
        common_ids: Iterable of answer ids present in both mappings

    Returns:
        The correlation rounded to 3 decimals as a plain float, or None when
        it is undefined: fewer than two answers have both grades, or one
        grader gave the same grade to every compared answer.
    """
    grades1 = []
    grades2 = []

    for answer_id in common_ids:
        grade1 = s1_map[answer_id].progress_grade
        grade2 = s2_map[answer_id].progress_grade

        # Only include if both have progress grades
        if grade1 is not None and grade2 is not None:
            grades1.append(grade1)
            grades2.append(grade2)

    if len(grades1) < 2:
        return None

    # A constant series has no rank variance, so the coefficient is undefined.
    # Return early instead of letting scipy warn and hand back NaN.
    if len(set(grades1)) < 2 or len(set(grades2)) < 2:
        return None

    # Calculate Spearman correlation using scipy (p-value is not needed)
    correlation, _ = stats.spearmanr(grades1, grades2)

    # Defensive: any remaining NaN result is reported as "no correlation available"
    if math.isnan(correlation):
        return None

    # Convert the numpy scalar to a built-in float to match the annotation.
    return round(float(correlation), 3)
