"""
Evaluation metrics for HMM-GLM models.

This module provides functions for calculating various evaluation metrics
for HMM-GLM models, including accuracy, AUC, Brier score, log loss,
delta log-likelihood, and state diversity.
"""

import numpy as np
import pandas as pd
from typing import Dict, List, Union, Optional, Tuple, Any
from sklearn.metrics import (
    accuracy_score,
    roc_auc_score,
    brier_score_loss,
    log_loss as sklearn_log_loss,
    confusion_matrix,
    precision_score,
    recall_score,
    f1_score
)
import logging

# Setup logging
logger = logging.getLogger(__name__)


def calculate_accuracy(y_true: np.ndarray, y_pred: np.ndarray) -> float:
    """
    Compute the fraction of correctly classified samples.
    
    Parameters:
    -----------
    y_true : numpy.ndarray
        True binary labels
    y_pred : numpy.ndarray
        Either hard labels or probabilities. The input is treated as
        probabilities when it has more than one dimension or when any
        value lies strictly between 0 and 1; in that case values greater
        than 0.5 become class 1. Otherwise the values are cast to int.
        
    Returns:
    --------
    float
        Accuracy score as returned by sklearn's accuracy_score
    """
    # Short-circuit: the element-wise check only runs for 1-D input
    looks_like_probabilities = (
        y_pred.ndim > 1 or np.any((y_pred > 0) & (y_pred < 1))
    )
    
    hard_labels = (y_pred > 0.5 if looks_like_probabilities else y_pred).astype(int)
    
    return accuracy_score(y_true, hard_labels)


def calculate_auc(y_true: np.ndarray, y_pred: np.ndarray) -> float:
    """
    Compute the area under the ROC curve.
    
    Parameters:
    -----------
    y_true : numpy.ndarray
        True binary labels
    y_pred : numpy.ndarray
        Predicted probabilities
        
    Returns:
    --------
    float
        AUC score, or NaN when y_true holds fewer than two distinct
        classes or when the underlying computation raises
    """
    # ROC AUC is undefined unless both classes are present
    if np.unique(y_true).size < 2:
        logger.warning("AUC calculation requires at least two classes. Returning NaN.")
        return np.nan
    
    try:
        auc_value = roc_auc_score(y_true, y_pred)
    except Exception as e:
        # Failures are logged and reported as NaN instead of propagating
        logger.error(f"Error calculating AUC: {e}")
        return np.nan
    else:
        return auc_value


def calculate_brier_score(y_true: np.ndarray, y_pred: np.ndarray) -> float:
    """
    Compute the Brier score of probabilistic predictions.
    
    Parameters:
    -----------
    y_true : numpy.ndarray
        True binary labels
    y_pred : numpy.ndarray
        Predicted probabilities
        
    Returns:
    --------
    float
        Brier score as returned by sklearn's brier_score_loss, or NaN
        when that call raises
    """
    try:
        brier_value = brier_score_loss(y_true, y_pred)
    except Exception as e:
        # Failures are logged and reported as NaN instead of propagating
        logger.error(f"Error calculating Brier score: {e}")
        return np.nan
    else:
        return brier_value


def calculate_log_loss(y_true: np.ndarray, y_pred: np.ndarray, eps: float = 1e-15) -> float:
    """
    Compute the log loss of probabilistic predictions.
    
    Parameters:
    -----------
    y_true : numpy.ndarray
        True binary labels
    y_pred : numpy.ndarray
        Predicted probabilities
    eps : float, optional
        Probabilities are clipped to [eps, 1 - eps] so that log(0) is
        never evaluated
        
    Returns:
    --------
    float
        Log loss, or NaN when y_true holds fewer than two distinct
        classes or when the underlying computation raises
    """
    # This function reports NaN rather than a loss for single-class targets
    if np.unique(y_true).size < 2:
        logger.warning("Log loss calculation requires at least two classes. Returning NaN.")
        return np.nan
    
    # Keep probabilities strictly inside (0, 1)
    bounded_probs = np.clip(y_pred, eps, 1 - eps)
    
    try:
        loss_value = sklearn_log_loss(y_true, bounded_probs)
    except Exception as e:
        # Failures are logged and reported as NaN instead of propagating
        logger.error(f"Error calculating log loss: {e}")
        return np.nan
    else:
        return loss_value


def calculate_delta_loglikelihood(model: Any, X: np.ndarray, y: np.ndarray, 
                                 sequences: np.ndarray) -> float:
    """
    Compute the log-likelihood gain of the model over a constant baseline.
    
    The baseline predicts the empirical mean of y for every sample; its
    Bernoulli log-likelihood is summed over all samples and subtracted
    from ``model.score(X, y, sequences)``.
    
    Parameters:
    -----------
    model : HMMGLMModel
        Fitted HMM-GLM model exposing ``score(X, y, sequences)``
    X : numpy.ndarray
        Feature matrix
    y : numpy.ndarray
        True labels
    sequences : numpy.ndarray
        Sequence IDs for each sample
        
    Returns:
    --------
    float
        Model log-likelihood minus baseline log-likelihood, or NaN when
        any step raises
    """
    try:
        fitted_ll = model.score(X, y, sequences)
        
        # Constant-probability reference model: p = mean(y)
        p_hat = np.mean(y)
        # The 1e-15 offset keeps both logs finite when p_hat is 0 or 1
        log_p_positive = np.log(p_hat + 1e-15)
        log_p_negative = np.log(1 - p_hat + 1e-15)
        reference_ll = np.sum(y * log_p_positive + (1 - y) * log_p_negative)
        
        return fitted_ll - reference_ll
    except Exception as e:
        logger.error(f"Error calculating delta log-likelihood: {e}")
        return np.nan


def calculate_state_diversity(model: Any, X: np.ndarray, 
                             sequences: np.ndarray) -> float:
    """
    Compute the mean entropy of the per-sample state probabilities.
    
    Parameters:
    -----------
    model : HMMGLMModel
        Fitted HMM-GLM model exposing ``predict_state_probs(X, sequences)``
    X : numpy.ndarray
        Feature matrix
    sequences : numpy.ndarray
        Sequence IDs for each sample
        
    Returns:
    --------
    float
        Entropy (natural log) summed along axis 1 of the state
        probabilities and averaged over the results, or NaN when any
        step raises
    """
    try:
        posterior = model.predict_state_probs(X, sequences)
        
        # p * log(p); the 1e-15 offset keeps log finite for zero entries
        weighted_log = posterior * np.log(posterior + 1e-15)
        per_sample_entropy = -np.sum(weighted_log, axis=1)
        
        return np.mean(per_sample_entropy)
    except Exception as e:
        logger.error(f"Error calculating state diversity: {e}")
        return np.nan


def evaluate_model(model: Any, X: np.ndarray, y: np.ndarray, 
                  sequences: np.ndarray, 
                  metrics: Optional[List[str]] = None) -> Dict[str, float]:
    """
    Compute a set of named evaluation metrics for a HMM-GLM model.
    
    Parameters:
    -----------
    model : HMMGLMModel
        Fitted HMM-GLM model exposing ``predict_proba`` and ``predict``
    X : numpy.ndarray
        Feature matrix
    y : numpy.ndarray
        True labels
    sequences : numpy.ndarray
        Sequence IDs for each sample
    metrics : list of str, optional
        Names of the metrics to compute. Recognised names are
        'accuracy', 'auc', 'brier_score', 'log_loss',
        'delta_loglikelihood', 'state_diversity', 'precision', 'recall'
        and 'f1'. If None, the first six are computed.
        
    Returns:
    --------
    dict
        Metric name -> value, in the order requested. Unknown names are
        logged and omitted. If obtaining predictions fails, every
        requested name maps to NaN.
    """
    if metrics is None:
        metrics = ['accuracy', 'auc', 'brier_score', 'log_loss', 
                  'delta_loglikelihood', 'state_diversity']
    
    # Predictions are always fetched up front, whatever metrics were requested
    try:
        y_pred_proba = model.predict_proba(X, sequences)
        y_pred = model.predict(X, sequences)
    except Exception as e:
        logger.error(f"Error getting predictions: {e}")
        return {metric: np.nan for metric in metrics}
    
    # Dispatch table: metric name -> zero-argument callable producing its value
    calculators = {
        'accuracy': lambda: calculate_accuracy(y, y_pred),
        'auc': lambda: calculate_auc(y, y_pred_proba),
        'brier_score': lambda: calculate_brier_score(y, y_pred_proba),
        'log_loss': lambda: calculate_log_loss(y, y_pred_proba),
        'delta_loglikelihood': lambda: calculate_delta_loglikelihood(model, X, y, sequences),
        'state_diversity': lambda: calculate_state_diversity(model, X, sequences),
        'precision': lambda: precision_score(y, y_pred),
        'recall': lambda: recall_score(y, y_pred),
        'f1': lambda: f1_score(y, y_pred),
    }
    
    results = {}
    for metric in metrics:
        compute = calculators.get(metric)
        if compute is None:
            logger.warning(f"Unknown metric: {metric}")
            continue
        results[metric] = compute()
    
    return results


def calculate_per_state_metrics(model: Any, X: np.ndarray, y: np.ndarray, 
                               sequences: np.ndarray) -> pd.DataFrame:
    """
    Calculate metrics for each latent state.
    
    Samples are grouped by the state returned from
    ``model.predict_states`` and the classification metrics are computed
    separately inside each group.
    
    Parameters:
    -----------
    model : HMMGLMModel
        Fitted HMM-GLM model exposing ``predict_states``,
        ``predict_proba``, ``predict`` and ``hmm_component.n_states``
    X : numpy.ndarray
        Feature matrix
    y : numpy.ndarray
        True labels
    sequences : numpy.ndarray
        Sequence IDs for each sample
        
    Returns:
    --------
    pandas.DataFrame
        One row per state that has at least one sample, with columns
        'state', 'count', 'proportion', 'positive_rate', 'accuracy',
        'auc', 'brier_score' and 'log_loss'. The last three are NaN for
        states whose labels contain a single class. An empty DataFrame
        is returned when any step raises.
    """
    try:
        # Get state assignments
        states = model.predict_states(X, sequences)
        
        # Get predictions
        y_pred_proba = model.predict_proba(X, sequences)
        y_pred = model.predict(X, sequences)
        
        n_states = model.hmm_component.n_states
        results = []
        
        for state in range(n_states):
            # Boolean selector for the samples assigned to this state
            mask = (states == state)
            
            # States that were never assigned produce no row
            if not np.any(mask):
                continue
            
            # Only labels and predictions are needed per state; the
            # feature matrix is not sliced because no metric uses it.
            y_state = y[mask]
            y_pred_state = y_pred[mask]
            y_pred_proba_state = y_pred_proba[mask]
            
            state_metrics = {
                'state': state,
                'count': np.sum(mask),          # samples in this state
                'proportion': np.mean(mask),    # share of all samples
                'positive_rate': np.mean(y_state),
                'accuracy': calculate_accuracy(y_state, y_pred_state)
            }
            
            # Probabilistic metrics are only reported when both classes
            # are present among this state's labels
            if len(np.unique(y_state)) > 1:
                state_metrics['auc'] = calculate_auc(y_state, y_pred_proba_state)
                state_metrics['brier_score'] = calculate_brier_score(y_state, y_pred_proba_state)
                state_metrics['log_loss'] = calculate_log_loss(y_state, y_pred_proba_state)
            else:
                state_metrics['auc'] = np.nan
                state_metrics['brier_score'] = np.nan
                state_metrics['log_loss'] = np.nan
            
            results.append(state_metrics)
        
        # Convert to DataFrame
        return pd.DataFrame(results)
    
    except Exception as e:
        logger.error(f"Error calculating per-state metrics: {e}")
        return pd.DataFrame()


def calculate_transition_metrics(model: Any, X: np.ndarray, 
                                sequences: np.ndarray) -> Dict[str, Any]:
    """
    Calculate metrics related to state transitions.
    
    Transitions are counted between consecutive samples that share a
    sequence ID, using the states returned by ``model.predict_states``.
    Samples of one sequence are taken in the order they appear in the
    input arrays.
    
    Parameters:
    -----------
    model : HMMGLMModel
        Fitted HMM-GLM model exposing ``predict_states`` and
        ``hmm_component.n_states``
    X : numpy.ndarray
        Feature matrix
    sequences : numpy.ndarray
        Sequence IDs for each sample
        
    Returns:
    --------
    dict
        Dictionary with transition metrics:
        
        - 'transition_counts': (n_states, n_states) array of observed
          from-state -> to-state counts
        - 'transition_probabilities': row-normalised counts; rows of
          states with no observed outgoing transition stay all zero
        - 'self_transition_rate': mean diagonal probability over the
          states that have at least one observed outgoing transition
          (NaN if no transition was observed at all)
        - 'state_change_rate': 1 - self_transition_rate
        - 'most_common_transition': (from_state, to_state) index of the
          largest off-diagonal count; (0, 0) when no off-diagonal
          transition was observed
        
        An empty dict is returned when any step raises.
    """
    try:
        states = np.asarray(model.predict_states(X, sequences))
        sequences = np.asarray(sequences)
        
        n_states = model.hmm_component.n_states
        transition_counts = np.zeros((n_states, n_states))
        
        for seq_id in np.unique(sequences):
            seq_states = states[sequences == seq_id]
            # Pair every state with its successor. np.add.at accumulates
            # repeated (from, to) pairs, which plain fancy-index "+=" would
            # count only once.
            np.add.at(transition_counts, (seq_states[:-1], seq_states[1:]), 1)
        
        # Row-normalise. The (n, 1) row sums broadcast across columns, and
        # rows without any observation keep their zero initialisation.
        # (Indexing the (n, n) matrix with an (n, 1) boolean mask raises
        # IndexError in NumPy, so the division is done with `where=`.)
        row_sums = transition_counts.sum(axis=1, keepdims=True)
        transition_probs = np.divide(
            transition_counts,
            row_sums,
            out=np.zeros_like(transition_counts),
            where=row_sums > 0,
        )
        
        # Average only over states that were actually left at least once;
        # an all-zero row is "no data", not "never stays in this state".
        observed = row_sums[:, 0] > 0
        if observed.any():
            self_transition_rate = np.mean(np.diag(transition_probs)[observed])
        else:
            self_transition_rate = np.nan
        
        # Zero the diagonal so argmax only considers real state changes
        off_diagonal_counts = transition_counts.copy()
        np.fill_diagonal(off_diagonal_counts, 0)
        
        metrics = {
            'transition_counts': transition_counts,
            'transition_probabilities': transition_probs,
            'self_transition_rate': self_transition_rate,
            'state_change_rate': 1 - self_transition_rate,
            'most_common_transition': np.unravel_index(
                np.argmax(off_diagonal_counts), 
                transition_counts.shape
            )
        }
        
        return metrics
    
    except Exception as e:
        logger.error(f"Error calculating transition metrics: {e}")
        return {}


if __name__ == "__main__":
    # Example usage: fit a HMM-GLM model on synthetic data and print every
    # metric this module provides.
    from src.core.hmm_glm import CategoricalHMMComponent, LogisticGLMComponent, HMMGLMModel
    
    # Generate synthetic data (seeded so the demo is reproducible)
    np.random.seed(42)
    n_samples = 1000
    n_features = 5
    n_states = 3
    
    # Create sequences: equally long blocks of consecutive samples
    n_sequences = 50
    sequence_length = n_samples // n_sequences
    sequences = np.repeat(np.arange(n_sequences), sequence_length)
    
    # Generate features and a random latent state per sample
    X = np.random.randn(n_samples, n_features)
    states = np.random.randint(0, n_states, n_samples)
    
    # Each state gets its own random logistic coefficients, so the
    # positive-class probability depends on the state
    probs = np.zeros(n_samples)
    for state in range(n_states):
        mask = (states == state)
        beta = np.random.randn(n_features)
        logits = X[mask] @ beta
        probs[mask] = 1 / (1 + np.exp(-logits))
    
    # Draw binary labels from the per-sample probabilities
    y = (np.random.random(n_samples) < probs).astype(int)
    
    # Create and fit model
    hmm_comp = CategoricalHMMComponent(n_states=n_states, n_categories=2)
    glm_comp = LogisticGLMComponent()
    model = HMMGLMModel(hmm_component=hmm_comp, glm_component=glm_comp)
    
    model.fit(X, y, sequences=sequences)
    
    # Evaluate model
    metrics = evaluate_model(model, X, y, sequences)
    print("Overall metrics:")
    for metric, value in metrics.items():
        print(f"{metric}: {value:.4f}")
    
    # Calculate per-state metrics
    state_metrics = calculate_per_state_metrics(model, X, y, sequences)
    print("\nPer-state metrics:")
    print(state_metrics)
    
    # Calculate transition metrics
    transition_metrics = calculate_transition_metrics(model, X, sequences)
    print("\nTransition metrics:")
    for metric, value in transition_metrics.items():
        # Arrays are printed on their own lines for readability
        if isinstance(value, np.ndarray):
            print(f"{metric}:\n{value}")
        else:
            print(f"{metric}: {value}")
