"""
Evaluation metrics for HMM-GLM models.

This module provides functions for evaluating HMM-GLM models and comparing them
with baseline models.
"""

import numpy as np
from sklearn.metrics import (
    roc_auc_score,
    accuracy_score,
    log_loss,
    brier_score_loss
)

def evaluate_hmm_glm_model(model, X_hmm_test, X_glm_test, y_test, contexts_test=None):
    """
    Evaluate an HMM-GLM model on test data.

    Parameters:
    -----------
    model : HMMGLMModel
        Fitted HMM-GLM model. Must expose ``predict`` (returning a
        ``(labels, states)`` pair), ``predict_proba`` and
        ``hmm_component.n_states``.
    X_hmm_test : ndarray, shape (n_samples, n_hmm_features)
        Test feature matrix for the HMM component
    X_glm_test : ndarray, shape (n_samples, n_glm_features)
        Test feature matrix for the GLM component
    y_test : array-like, shape (n_samples,)
        True test labels (0/1). Lists and pandas Series are accepted.
    contexts_test : ndarray, shape (n_samples, n_contexts), optional
        Test context variables for context-aware transitions

    Returns:
    --------
    metrics : dict
        Dictionary with the keys 'accuracy', 'auc', 'brier_score',
        'log_loss' and 'state_diversity'.
    """
    eps = 1e-15

    # Work on a plain ndarray: the single-class branch below indexes
    # position 0, and label-based containers (e.g. a pandas Series coming
    # out of a shuffled split) would look up the *label* 0 instead.
    y_true = np.asarray(y_test)

    # Hard labels plus the decoded state sequence, then probabilities.
    # The probabilities are used below as P(y == 1) for each sample.
    y_pred, states = model.predict(X_hmm_test, X_glm_test, contexts_test)
    y_proba = np.asarray(model.predict_proba(X_hmm_test, X_glm_test, contexts_test))

    # AUC and sklearn's log_loss are only defined when both classes occur.
    has_both_classes = np.unique(y_true).size > 1

    metrics = {}
    metrics['accuracy'] = accuracy_score(y_true, y_pred)

    # AUC is undefined for a single class; report chance level instead.
    metrics['auc'] = roc_auc_score(y_true, y_proba) if has_both_classes else 0.5

    metrics['brier_score'] = brier_score_loss(y_true, y_proba)

    if has_both_classes:
        try:
            metrics['log_loss'] = log_loss(y_true, y_proba)
        except ValueError:
            # Retry with probabilities kept strictly inside (0, 1).
            metrics['log_loss'] = log_loss(y_true, np.clip(y_proba, eps, 1 - eps))
    else:
        # Only one class present: average -log of the probability assigned
        # to that class (p for label 1, 1 - p otherwise).
        p_observed = y_proba if y_true[0] == 1 else 1 - y_proba
        metrics['log_loss'] = -np.mean(np.log(np.clip(p_observed, eps, 1.0)))

    # State diversity: Shannon entropy of the empirical state distribution,
    # normalised by log(n_states) so the value lies in [0, 1].
    n_states = model.hmm_component.n_states
    state_counts = np.bincount(states, minlength=n_states)
    state_probs = state_counts / len(states)
    state_probs = state_probs[state_probs > 0]  # drop unused states: avoids log(0)
    entropy = -np.sum(state_probs * np.log(state_probs))
    metrics['state_diversity'] = entropy / np.log(n_states) if n_states > 1 else 0.0

    return metrics

def calculate_delta_log_likelihood(model, baseline_model, X_hmm_test, X_glm_test, y_test, contexts_test=None):
    """
    Calculate delta log-likelihood between HMM-GLM model and baseline model.

    The result is the per-sample average Bernoulli log-likelihood of the
    HMM-GLM model minus that of the baseline; positive values favour the
    HMM-GLM model.

    Parameters:
    -----------
    model : HMMGLMModel
        Fitted HMM-GLM model
    baseline_model : object
        Fitted baseline model with predict_proba method. If it also exposes
        ``classes_`` and was fitted on a single class (one probability
        column), the positive-class probability is derived from that column.
    X_hmm_test : ndarray, shape (n_samples, n_hmm_features)
        Test feature matrix for the HMM component
    X_glm_test : ndarray, shape (n_samples, n_glm_features)
        Test feature matrix for the GLM component
    y_test : array-like, shape (n_samples,)
        True test labels (0/1). Lists and pandas Series are accepted.
    contexts_test : ndarray, shape (n_samples, n_contexts), optional
        Test context variables for context-aware transitions

    Returns:
    --------
    delta_ll : float
        Delta log-likelihood
    """
    eps = 1e-15

    # Numeric ndarray so that `1 - y_true` works for lists as well.
    y_true = np.asarray(y_test, dtype=float)

    p_model = np.asarray(
        model.predict_proba(X_hmm_test, X_glm_test, contexts_test), dtype=float
    )

    baseline_proba = np.asarray(baseline_model.predict_proba(X_glm_test), dtype=float)
    classes = getattr(baseline_model, 'classes_', None)
    fitted_on_one_class = (
        baseline_proba.ndim == 2
        and baseline_proba.shape[1] == 1
        and classes is not None
        and len(classes) == 1
    )
    if fitted_on_one_class:
        # A classifier fitted on a single label returns one column, so
        # column 1 does not exist. Derive P(y == 1) from the lone column.
        only_column = baseline_proba[:, 0]
        p_baseline = only_column if classes[0] == 1 else 1.0 - only_column
    else:
        p_baseline = baseline_proba[:, 1]

    # Clip probabilities to avoid log(0)
    p_model = np.clip(p_model, eps, 1 - eps)
    p_baseline = np.clip(p_baseline, eps, 1 - eps)

    # Bernoulli log-likelihood of the observed labels under each model
    ll_hmm_glm = np.sum(y_true * np.log(p_model) + (1 - y_true) * np.log(1 - p_model))
    ll_baseline = np.sum(y_true * np.log(p_baseline) + (1 - y_true) * np.log(1 - p_baseline))

    # Average the difference over samples
    delta_ll = (ll_hmm_glm - ll_baseline) / len(y_true)

    return delta_ll

def _score_baseline_predictions(y_true, y_pred, y_proba):
    """
    Compute the metric set reported for a stateless baseline classifier.

    Parameters:
    -----------
    y_true : array-like, shape (n_samples,)
        True labels (0/1)
    y_pred : array-like, shape (n_samples,)
        Predicted class labels
    y_proba : array-like, shape (n_samples,)
        Predicted probability of label 1

    Returns:
    --------
    scores : dict
        Dictionary with the keys 'accuracy', 'auc', 'brier_score',
        'log_loss' and 'state_diversity' (always 0.0: no hidden states).
    """
    eps = 1e-15
    y_true = np.asarray(y_true)
    y_proba = np.asarray(y_proba, dtype=float)

    # AUC and sklearn's log_loss are only defined when both classes occur.
    has_both_classes = np.unique(y_true).size > 1

    scores = {}
    scores['accuracy'] = accuracy_score(y_true, y_pred)
    scores['auc'] = roc_auc_score(y_true, y_proba) if has_both_classes else 0.5
    scores['brier_score'] = brier_score_loss(y_true, y_proba)

    if has_both_classes:
        try:
            scores['log_loss'] = log_loss(y_true, y_proba)
        except ValueError:
            # Retry with probabilities kept strictly inside (0, 1).
            scores['log_loss'] = log_loss(y_true, np.clip(y_proba, eps, 1 - eps))
    else:
        # Single class: average -log of the probability given to that class.
        p_observed = y_proba if y_true[0] == 1 else 1 - y_proba
        scores['log_loss'] = -np.mean(np.log(np.clip(p_observed, eps, 1.0)))

    scores['state_diversity'] = 0.0
    return scores


def compare_with_baseline(model, X_hmm_test, X_glm_test, y_test, contexts_test=None,
                          *, X_glm_train=None, y_train=None):
    """
    Compare HMM-GLM model with baseline models.

    Two baselines are fitted on the GLM features: a logistic regression and
    a "Bernoulli" model (``DummyClassifier(strategy="prior")``). All models
    are scored on the test data.

    Parameters:
    -----------
    model : HMMGLMModel
        Fitted HMM-GLM model
    X_hmm_test : ndarray, shape (n_samples, n_hmm_features)
        Test feature matrix for the HMM component
    X_glm_test : ndarray, shape (n_samples, n_glm_features)
        Test feature matrix for the GLM component
    y_test : ndarray, shape (n_samples,)
        True test labels
    contexts_test : ndarray, shape (n_samples, n_contexts), optional
        Test context variables for context-aware transitions
    X_glm_train : ndarray, shape (n_train_samples, n_glm_features), optional
        Training features used to fit the baselines. Keyword-only.
    y_train : ndarray, shape (n_train_samples,), optional
        Training labels used to fit the baselines. Keyword-only.
        When both training arguments are omitted the baselines are fitted on
        the test data itself (legacy behaviour) and a ``UserWarning`` is
        emitted, because that makes the baseline scores optimistic.

    Returns:
    --------
    results : dict
        Dictionary of evaluation results for all models, with the keys
        'hmm_glm', 'logistic', 'bernoulli' (metric dictionaries) and
        'delta_ll' (with 'vs_logistic' and 'vs_bernoulli').

    Raises:
    -------
    ValueError
        If only one of ``X_glm_train`` / ``y_train`` is given.
        NOTE(review): scikit-learn's LogisticRegression is also expected to
        raise ValueError when the fitting labels contain a single class.
    """
    import warnings

    from sklearn.linear_model import LogisticRegression
    from sklearn.dummy import DummyClassifier

    if (X_glm_train is None) != (y_train is None):
        raise ValueError("X_glm_train and y_train must be provided together.")

    if X_glm_train is None:
        # Legacy path: fitting and scoring on the same data leaks the test
        # labels into the baselines, so they look better than they are.
        warnings.warn(
            "compare_with_baseline: no training data supplied; the baselines "
            "are fitted on the test set, so their scores are optimistic. "
            "Pass X_glm_train and y_train for an unbiased comparison.",
            UserWarning,
            stacklevel=2,
        )
        X_fit, y_fit = X_glm_test, y_test
    else:
        X_fit, y_fit = X_glm_train, y_train

    results = {}

    # Evaluate HMM-GLM model
    print("Evaluating HMM-GLM model...")
    results['hmm_glm'] = evaluate_hmm_glm_model(model, X_hmm_test, X_glm_test, y_test, contexts_test)

    # Logistic regression baseline
    print("Training logistic regression model...")
    logistic = LogisticRegression(random_state=42)
    logistic.fit(X_fit, y_fit)
    results['logistic'] = _score_baseline_predictions(
        y_test,
        logistic.predict(X_glm_test),
        logistic.predict_proba(X_glm_test)[:, 1],
    )

    # Bernoulli baseline (predicts the class prior of the fitting labels)
    print("Creating Bernoulli model...")
    bernoulli = DummyClassifier(strategy="prior", random_state=42)
    bernoulli.fit(X_fit, y_fit)
    results['bernoulli'] = _score_baseline_predictions(
        y_test,
        bernoulli.predict(X_glm_test),
        bernoulli.predict_proba(X_glm_test)[:, 1],
    )
    # A constant score cannot rank samples, so AUC is reported as chance.
    results['bernoulli']['auc'] = 0.5

    # Delta log-likelihoods of the HMM-GLM model against each baseline
    print("Calculating delta log-likelihood...")
    results['delta_ll'] = {}
    results['delta_ll']['vs_logistic'] = calculate_delta_log_likelihood(
        model, logistic, X_hmm_test, X_glm_test, y_test, contexts_test)
    results['delta_ll']['vs_bernoulli'] = calculate_delta_log_likelihood(
        model, bernoulli, X_hmm_test, X_glm_test, y_test, contexts_test)

    return results


