"""
Metrics utilities for LapBoost package.

This module contains functions for evaluating model performance and
computing confidence metrics for pseudo-labeling.
"""

import numpy as np
from typing import Dict, List, Optional, Tuple, Any, Union
from sklearn.metrics import (
    accuracy_score, precision_score, recall_score, f1_score,
    mean_squared_error, r2_score, mean_absolute_error,
    roc_auc_score, confusion_matrix
)


def _class_distribution(y_true: np.ndarray) -> np.ndarray:
    """Return per-class sample counts for ``y_true``.

    Non-negative integer labels (integer dtypes, bools, and integer-valued
    floats) are counted with ``np.bincount``, so index ``k`` holds the count
    of label ``k`` (zero for labels that never occur). Any other labels
    (strings, negative or non-integer values) cannot be used as bincount
    indices, so they fall back to ``np.unique`` counts ordered by sorted label.
    """
    labels = np.asarray(y_true)
    if labels.size > 0 and labels.dtype.kind in 'biuf':
        # Casting NaN to int is undefined; silence the warning and let the
        # equality check below reject such inputs.
        with np.errstate(invalid='ignore'):
            as_int = labels.astype(int)
        if np.all(as_int == labels) and as_int.min() >= 0:
            return np.bincount(as_int)
    _, counts = np.unique(labels, return_counts=True)
    return counts


def compute_metrics(
    y_true: np.ndarray,
    y_pred: np.ndarray,
    y_proba: Optional[np.ndarray] = None,
    task: str = 'classification',
    multiclass: bool = False,
    sample_weight: Optional[np.ndarray] = None
) -> Dict[str, Any]:
    """
    Compute evaluation metrics for model predictions.
    
    Parameters
    ----------
    y_true : np.ndarray
        True target values
    y_pred : np.ndarray
        Predicted target values
    y_proba : np.ndarray, optional
        Predicted probabilities (classification only). For binary tasks a
        2-D array uses column 1 as the positive-class score; a 1-D array
        (or single column) is used as-is. For multiclass tasks it is passed
        to one-vs-rest ROC AUC.
    task : str, default='classification'
        Task type. Exactly ``'classification'`` selects classification
        metrics; any other value computes regression metrics.
    multiclass : bool, default=False
        Whether the task is multiclass classification
    sample_weight : np.ndarray, optional
        Sample weights, forwarded to every scikit-learn metric
        (not used for ``class_distribution``).
        
    Returns
    -------
    dict
        Classification: ``accuracy`` plus either
        ``precision_macro``/``recall_macro``/``f1_macro``, ``confusion_matrix``
        (ndarray), ``class_distribution`` (ndarray of per-class counts) and,
        when computable, ``roc_auc_ovr`` (multiclass), or
        ``precision``/``recall``/``f1`` and, when computable, ``roc_auc``
        (binary). Regression: ``mse``, ``rmse``, ``mae``, ``r2``.
    """
    metrics: Dict[str, Any] = {}
    
    if task == 'classification':
        metrics['accuracy'] = accuracy_score(y_true, y_pred, sample_weight=sample_weight)
        
        if multiclass:
            metrics['precision_macro'] = precision_score(
                y_true, y_pred, average='macro', sample_weight=sample_weight
            )
            metrics['recall_macro'] = recall_score(
                y_true, y_pred, average='macro', sample_weight=sample_weight
            )
            metrics['f1_macro'] = f1_score(
                y_true, y_pred, average='macro', sample_weight=sample_weight
            )
            
            metrics['confusion_matrix'] = confusion_matrix(
                y_true, y_pred, sample_weight=sample_weight
            )
            
            # Unweighted class counts; works for integer and non-integer labels.
            metrics['class_distribution'] = _class_distribution(y_true)
            
            if y_proba is not None:
                try:
                    metrics['roc_auc_ovr'] = roc_auc_score(
                        y_true, y_proba, multi_class='ovr', sample_weight=sample_weight
                    )
                except ValueError:
                    # Best effort: sklearn raises ValueError when AUC is
                    # undefined for the given inputs; the key is then omitted.
                    pass
        else:
            metrics['precision'] = precision_score(y_true, y_pred, sample_weight=sample_weight)
            metrics['recall'] = recall_score(y_true, y_pred, sample_weight=sample_weight)
            metrics['f1'] = f1_score(y_true, y_pred, sample_weight=sample_weight)
            
            if y_proba is not None:
                proba = np.asarray(y_proba)
                if proba.ndim > 1 and proba.shape[1] > 1:
                    # Column 1 is taken as the positive-class probability.
                    y_proba_pos = proba[:, 1]
                else:
                    y_proba_pos = proba
                    
                try:
                    metrics['roc_auc'] = roc_auc_score(
                        y_true, y_proba_pos, sample_weight=sample_weight
                    )
                except ValueError:
                    # Best effort: sklearn raises ValueError when AUC is
                    # undefined for the given inputs; the key is then omitted.
                    pass
    else:
        metrics['mse'] = mean_squared_error(y_true, y_pred, sample_weight=sample_weight)
        metrics['rmse'] = np.sqrt(metrics['mse'])
        metrics['mae'] = mean_absolute_error(y_true, y_pred, sample_weight=sample_weight)
        metrics['r2'] = r2_score(y_true, y_pred, sample_weight=sample_weight)
        
    return metrics


def confidence_metrics(
    y_true: np.ndarray,
    y_pred: np.ndarray,
    confidence: np.ndarray,
    n_bins: int = 10,
    task: str = 'classification'
) -> Dict[str, Any]:
    """
    Compute metrics for assessing pseudo-label confidence reliability.
    
    Parameters
    ----------
    y_true : np.ndarray
        True target values
    y_pred : np.ndarray
        Predicted target values
    confidence : np.ndarray
        Confidence scores for predictions. Scores outside [0, 1] fall in no
        bin and are ignored by the bin/ECE statistics.
    n_bins : int, default=10
        Number of equal-width confidence bins over [0, 1]. Bins are
        ``[low, high)`` except the last, which is ``[low, 1.0]``.
    task : str, default='classification'
        Task type. Exactly ``'classification'`` treats correctness as
        ``y_true == y_pred``; any other value is treated as regression.
        
    Returns
    -------
    dict
        ``confidence_correctness_correlation``: Pearson correlation between
        confidence and correctness (binary for classification, negative
        absolute error for regression); NaN when undefined (fewer than two
        samples or a constant input).
        ``bin_edges``, ``bin_confidences``, ``bin_accuracies``: per non-empty
        bin, its ``(low, high)`` edges, mean confidence and accuracy
        (for regression, ``1 / (1 + RMSE)`` of the bin).
        ``expected_calibration_error``: present only if some bin is non-empty;
        the sample-weighted mean of ``|confidence - accuracy|`` over bins.

    Raises
    ------
    ValueError
        If ``n_bins`` is less than 1.
    """
    if n_bins < 1:
        raise ValueError(f"n_bins must be a positive integer, got {n_bins}")

    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    conf = np.asarray(confidence, dtype=float)

    if task == 'classification':
        correctness = (y_true == y_pred).astype(float)
    else:
        # Higher is better: negative absolute error.
        correctness = -np.abs(y_true - y_pred)

    # Pearson correlation divides by both standard deviations; report NaN
    # explicitly when it is undefined instead of triggering a RuntimeWarning.
    if conf.size > 1 and np.std(conf) > 0 and np.std(correctness) > 0:
        correlation = np.corrcoef(conf, correctness)[0, 1]
    else:
        correlation = np.nan

    metrics: Dict[str, Any] = {
        'confidence_correctness_correlation': correlation,
        'bin_edges': [],
        'bin_confidences': [],
        'bin_accuracies': [],
    }

    bin_edges = np.linspace(0.0, 1.0, n_bins + 1)
    # Sample count of each non-empty bin, aligned with the per-bin lists above.
    bin_counts = []

    for i in range(n_bins):
        low, high = bin_edges[i], bin_edges[i + 1]
        # Close the last bin so a confidence of exactly 1.0 is not dropped.
        if i == n_bins - 1:
            in_bin = (conf >= low) & (conf <= high)
        else:
            in_bin = (conf >= low) & (conf < high)

        count = int(np.sum(in_bin))
        if count == 0:
            continue

        bin_conf = np.mean(conf[in_bin])
        if task == 'classification':
            bin_acc = np.mean(correctness[in_bin])
        else:
            # Map the bin RMSE into (0, 1], higher is better.
            bin_rmse = np.sqrt(np.mean((y_true[in_bin] - y_pred[in_bin]) ** 2))
            bin_acc = 1.0 / (1.0 + bin_rmse)

        metrics['bin_edges'].append((low, high))
        metrics['bin_confidences'].append(bin_conf)
        metrics['bin_accuracies'].append(bin_acc)
        bin_counts.append(count)

    # Expected Calibration Error: each bin's |confidence - accuracy| gap is
    # weighted by its share of the binned samples (lower is better).
    if bin_counts:
        weights = np.asarray(bin_counts, dtype=float) / sum(bin_counts)
        gaps = np.abs(
            np.asarray(metrics['bin_confidences']) - np.asarray(metrics['bin_accuracies'])
        )
        metrics['expected_calibration_error'] = float(np.sum(weights * gaps))

    return metrics
