"""
Review Metrics Module

Contains all metrics for evaluating review mode performance:
1. BP Raw Error - Average absolute error between BP raw scores and true qualities
2. BP F1 Loss - F1 score loss when using 50% acceptance rate
3. KL Divergence - Divergence between review distribution and true distribution
"""

import numpy as np


def calculate_calibration_error(final_decisions, true_qualities):
    """
    Metric 1: Calculate calibration error between BP probabilities and true qualities

    Raw BP scores are mapped to probabilities with a logistic sigmoid and
    compared against the true qualities mapped from {-1, +1} to {0, 1}. The
    per-paper error is the absolute difference; the metric is its mean.

    Args:
        final_decisions: Raw BP scores (continuous values, preserving information gain)
        true_qualities: True paper qualities (±1), same shape as final_decisions

    Returns:
        dict: Contains calibration_error and related statistics
        (individual_errors, max_error, min_error). If final_decisions is None
        or empty, 'calibration_error' is None and an 'error' message explains why.

    Raises:
        ValueError: If final_decisions and true_qualities differ in shape.
    """
    if final_decisions is None:
        return {'calibration_error': None, 'error': 'Final decisions not available'}

    true_qualities = np.array(true_qualities, dtype=float)
    bp_raw_scores = np.array(final_decisions, dtype=float)

    # Without this check a length-1 input would silently broadcast against the other.
    if bp_raw_scores.shape != true_qualities.shape:
        raise ValueError(
            f"final_decisions and true_qualities must have the same shape, "
            f"got {bp_raw_scores.shape} and {true_qualities.shape}"
        )
    if bp_raw_scores.size == 0:
        # np.max/np.min raise on empty arrays and np.mean would return nan.
        return {'calibration_error': None, 'error': 'No papers to evaluate'}

    # Logistic sigmoid; clipping to ±500 keeps np.exp well inside the float64
    # range (it overflows near 709) while the output is already saturated.
    bp_probabilities = 1 / (1 + np.exp(-np.clip(bp_raw_scores, -500, 500)))

    # Convert true qualities (±1) to probabilities (0, 1)
    true_probabilities = (true_qualities + 1) / 2

    calibration_errors = np.abs(bp_probabilities - true_probabilities)
    average_calibration_error = np.mean(calibration_errors)

    return {
        'calibration_error': average_calibration_error,
        'individual_errors': calibration_errors,
        'max_error': np.max(calibration_errors),
        'min_error': np.min(calibration_errors),
        'raw_scores_preserved': True,
        'calibration_based': True
    }


def calculate_bp_raw_error(final_decisions, true_qualities):
    """
    Backward-compatible alias of calculate_calibration_error

    Runs the calibration-error metric and, when it produced a value, mirrors
    that value under the legacy 'bp_raw_error' key so older callers keep working.

    Returns:
        dict: The calibration-error result, plus 'bp_raw_error' when available
    """
    metrics = calculate_calibration_error(final_decisions, true_qualities)
    error_value = metrics.get('calibration_error')
    if error_value is None:
        return metrics
    metrics['bp_raw_error'] = error_value
    return metrics


def calculate_bp_decision_f1_loss(final_decisions, true_qualities, acceptance_rate=0.5):
    """
    Metric 2: Calculate F1 score loss when using specified acceptance rate

    Papers are ranked by raw BP score (highest first) and the top
    floor(n * acceptance_rate) are accepted (+1); the rest are rejected (-1).
    These decisions are scored against the true qualities with F1, treating
    +1 as the positive class.

    Args:
        final_decisions: Raw BP scores (continuous values)
        true_qualities: True paper qualities (±1), same shape as final_decisions
        acceptance_rate: Target acceptance rate (default 0.5); the resulting
            number of accepted papers is clamped to [0, n]

    Returns:
        dict: Contains f1_loss (1 - F1), f1_score, acceptance_rate (the rate
        actually realised), precision, recall and confusion_matrix. If
        final_decisions is None or empty, f1_loss/f1_score/acceptance_rate
        are None and an 'error' message is included.

    Raises:
        ValueError: If final_decisions and true_qualities differ in shape.
    """
    if final_decisions is None:
        return {'f1_loss': None, 'f1_score': None, 'acceptance_rate': None,
                'error': 'Final decisions not available'}

    true_qualities = np.array(true_qualities, dtype=float)
    bp_raw_scores = np.array(final_decisions, dtype=float)

    if bp_raw_scores.shape != true_qualities.shape:
        raise ValueError(
            f"final_decisions and true_qualities must have the same shape, "
            f"got {bp_raw_scores.shape} and {true_qualities.shape}"
        )
    num_papers = len(bp_raw_scores)
    if num_papers == 0:
        # Avoid a 0/0 acceptance rate (nan + RuntimeWarning).
        return {'f1_loss': None, 'f1_score': None, 'acceptance_rate': None,
                'error': 'No papers to evaluate'}

    # Rank papers by BP score, descending. A stable sort breaks ties in favour
    # of the earlier paper so tie handling is well-defined.
    sorted_indices = np.argsort(-bp_raw_scores, kind='stable')

    # floor(n * rate) with a small tolerance, so float error in the product
    # (e.g. 100 * 0.29 == 28.999999999999996) does not drop a paper. Clamp so a
    # negative rate cannot turn the slice below into "all but the last k".
    num_accept = int(np.floor(num_papers * acceptance_rate + 1e-9))
    num_accept = min(max(num_accept, 0), num_papers)

    bp_decisions = np.full(num_papers, -1.0)  # Initialize all as reject
    bp_decisions[sorted_indices[:num_accept]] = 1.0  # Top papers as accept

    accepted = bp_decisions == 1
    truly_good = true_qualities == 1
    truly_bad = true_qualities == -1

    tp = np.sum(accepted & truly_good)
    fp = np.sum(accepted & truly_bad)
    fn = np.sum(~accepted & truly_good)

    precision = tp / (tp + fp) if (tp + fp) > 0 else 0
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0
    f1_score = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0

    f1_loss = 1 - f1_score
    actual_acceptance_rate = np.sum(accepted) / num_papers

    return {
        'f1_loss': f1_loss,
        'f1_score': f1_score,
        'acceptance_rate': actual_acceptance_rate,
        'precision': precision,
        'recall': recall,
        'confusion_matrix': {'tp': tp, 'fp': fp, 'fn': fn}
    }


def calculate_kl_divergence(final_decisions, true_qualities):
    """
    Metric 3: Calculate KL divergence between review distribution and true distribution

    Models the review process as a classification problem with +1 (accept) and
    -1 (reject). P is the empirical accept/reject frequency of the true
    qualities; Q is the accept/reject frequency of sign(final_decisions).
    Returns KL(P || Q) using the natural log. Probabilities are floored at 1e-10
    (without renormalisation) to avoid log(0).

    Note: a score of exactly 0 has sign 0 and is counted as neither accept nor
    reject, so Q(+1) + Q(-1) can be below 1 in that case.

    Args:
        final_decisions: Raw BP scores (continuous values)
        true_qualities: True paper qualities (±1), same shape as final_decisions

    Returns:
        dict: Contains kl_divergence plus the two distributions and raw counts.
        If final_decisions is None or empty, 'kl_divergence' is None and an
        'error' message is included.

    Raises:
        ValueError: If final_decisions and true_qualities differ in shape.
    """
    if final_decisions is None:
        return {'kl_divergence': None, 'error': 'Final decisions not available'}

    true_qualities = np.array(true_qualities, dtype=float)
    bp_raw_scores = np.array(final_decisions, dtype=float)

    # Review counts are divided by the number of true labels, so the lengths must agree.
    if bp_raw_scores.shape != true_qualities.shape:
        raise ValueError(
            f"final_decisions and true_qualities must have the same shape, "
            f"got {bp_raw_scores.shape} and {true_qualities.shape}"
        )
    total_papers = len(true_qualities)
    if total_papers == 0:
        # Frequencies would be 0/0 (nan) and propagate into the result.
        return {'kl_divergence': None, 'error': 'No papers to evaluate'}

    predicted_qualities = np.sign(bp_raw_scores)

    # True distribution: P(+1), P(-1)
    true_accept_count = np.sum(true_qualities == 1)
    true_reject_count = np.sum(true_qualities == -1)
    p_true_accept = true_accept_count / total_papers
    p_true_reject = true_reject_count / total_papers

    # Review distribution: Q(+1), Q(-1)
    review_accept_count = np.sum(predicted_qualities == 1)
    review_reject_count = np.sum(predicted_qualities == -1)
    p_review_accept = review_accept_count / total_papers
    p_review_reject = review_reject_count / total_papers

    # Floor every probability at epsilon to avoid log(0) / division by zero.
    epsilon = 1e-10
    p_true_accept = max(p_true_accept, epsilon)
    p_true_reject = max(p_true_reject, epsilon)
    p_review_accept = max(p_review_accept, epsilon)
    p_review_reject = max(p_review_reject, epsilon)

    # KL(P||Q) = sum(P(x) * log(P(x)/Q(x))), P = true, Q = review
    kl_divergence = (p_true_accept * np.log(p_true_accept / p_review_accept) +
                     p_true_reject * np.log(p_true_reject / p_review_reject))

    return {
        'kl_divergence': kl_divergence,
        'true_distribution': {'accept': p_true_accept, 'reject': p_true_reject},
        'review_distribution': {'accept': p_review_accept, 'reject': p_review_reject},
        'true_counts': {'accept': true_accept_count, 'reject': true_reject_count},
        'review_counts': {'accept': review_accept_count, 'reject': review_reject_count}
    }


def calculate_js_divergence(final_decisions, true_qualities):
    """
    Metric 4: Calculate Jensen-Shannon divergence between review distribution and true distribution

    Uses the same accept/reject distributions as the KL metric: P from the true
    qualities and Q from sign(final_decisions). Epsilon is added to both and they
    are renormalised, then JS(P, Q) = 0.5 * KL(P||M) + 0.5 * KL(Q||M) with
    M = (P + Q) / 2. JS is symmetric and always finite. Because the natural log
    is used, it is bounded in [0, ln 2] (about 0.693), not [0, 1]. Divide by
    ln 2 for a [0, 1] scale.

    Args:
        final_decisions: Raw BP scores (continuous values)
        true_qualities: True paper qualities (±1), same shape as final_decisions

    Returns:
        dict: Contains js_divergence plus the true, review and middle
        distributions and raw counts. If final_decisions is None or empty,
        'js_divergence' is None and an 'error' message is included.

    Raises:
        ValueError: If final_decisions and true_qualities differ in shape.
    """
    if final_decisions is None:
        return {'js_divergence': None, 'error': 'Final decisions not available'}

    true_qualities = np.array(true_qualities, dtype=float)
    bp_raw_scores = np.array(final_decisions, dtype=float)

    # Review counts are divided by the number of true labels, so the lengths must agree.
    if bp_raw_scores.shape != true_qualities.shape:
        raise ValueError(
            f"final_decisions and true_qualities must have the same shape, "
            f"got {bp_raw_scores.shape} and {true_qualities.shape}"
        )
    total_papers = len(true_qualities)
    if total_papers == 0:
        # Frequencies would be 0/0 (nan) and propagate into the result.
        return {'js_divergence': None, 'error': 'No papers to evaluate'}

    # Scores of exactly 0 have sign 0 and fall in neither bucket; the
    # renormalisation below spreads that mass proportionally.
    predicted_qualities = np.sign(bp_raw_scores)

    # True distribution: P(+1), P(-1)
    true_accept_count = np.sum(true_qualities == 1)
    true_reject_count = np.sum(true_qualities == -1)
    p_true_accept = true_accept_count / total_papers
    p_true_reject = true_reject_count / total_papers

    # Review distribution: Q(+1), Q(-1)
    review_accept_count = np.sum(predicted_qualities == 1)
    review_reject_count = np.sum(predicted_qualities == -1)
    p_review_accept = review_accept_count / total_papers
    p_review_reject = review_reject_count / total_papers

    # Add epsilon to avoid log(0), then renormalise to valid distributions.
    epsilon = 1e-10
    p_true = np.array([p_true_accept, p_true_reject]) + epsilon
    p_review = np.array([p_review_accept, p_review_reject]) + epsilon
    p_true = p_true / np.sum(p_true)
    p_review = p_review / np.sum(p_review)

    # Middle point M = 0.5 * (P + Q)
    m = 0.5 * (p_true + p_review)

    # JS(P,Q) = 0.5 * KL(P||M) + 0.5 * KL(Q||M)
    js_divergence = (0.5 * np.sum(p_true * np.log(p_true / m)) +
                     0.5 * np.sum(p_review * np.log(p_review / m)))

    return {
        'js_divergence': js_divergence,
        'true_distribution': {'accept': p_true[0], 'reject': p_true[1]},
        'review_distribution': {'accept': p_review[0], 'reject': p_review[1]},
        'middle_distribution': {'accept': m[0], 'reject': m[1]},
        'true_counts': {'accept': true_accept_count, 'reject': true_reject_count},
        'review_counts': {'accept': review_accept_count, 'reject': review_reject_count}
    }


def calculate_all_metrics(final_decisions, true_qualities, acceptance_rate=0.5):
    """
    Calculate every review metric at once

    Args:
        final_decisions: Raw BP scores (continuous values)
        true_qualities: True paper qualities (±1)
        acceptance_rate: Target acceptance rate for F1 loss calculation

    Returns:
        dict: Per-metric result dicts under 'calibration_error', 'bp_f1_loss',
        'kl_divergence' and 'js_divergence'. 'bp_raw_error' is a
        backward-compatible alias referring to the same dict as
        'calibration_error'. When a value is available, that dict also carries
        a 'bp_raw_error' key, matching calculate_bp_raw_error.
    """
    calibration_error_result = calculate_calibration_error(final_decisions, true_qualities)
    # Mirror the legacy key (as calculate_bp_raw_error does) so consumers of the
    # 'bp_raw_error' alias can still read result['bp_raw_error']['bp_raw_error'].
    if calibration_error_result.get('calibration_error') is not None:
        calibration_error_result['bp_raw_error'] = calibration_error_result['calibration_error']

    bp_f1_loss_result = calculate_bp_decision_f1_loss(final_decisions, true_qualities, acceptance_rate)
    kl_divergence_result = calculate_kl_divergence(final_decisions, true_qualities)
    js_divergence_result = calculate_js_divergence(final_decisions, true_qualities)

    return {
        'calibration_error': calibration_error_result,
        'bp_raw_error': calibration_error_result,  # Backward compatibility
        'bp_f1_loss': bp_f1_loss_result,
        'kl_divergence': kl_divergence_result,
        'js_divergence': js_divergence_result
    }