"""
Model comparison tools for HMM-GLM evaluation.

This module provides functions for comparing different HMM-GLM models and
comparing HMM-GLM models with baseline models.
"""

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from typing import Dict, List, Union, Optional, Tuple, Any
from sklearn.model_selection import KFold, StratifiedKFold
from sklearn.metrics import (
    accuracy_score,
    roc_auc_score,
    brier_score_loss,
    log_loss,
    precision_score,
    recall_score,
    f1_score
)
from sklearn.linear_model import LogisticRegression
from scipy import stats
import logging

# Import local modules
from .metrics import evaluate_model

# Setup logging
logger = logging.getLogger(__name__)


def compare_models(models: List[Any], model_names: List[str],
                 X: np.ndarray, y: np.ndarray, sequences: np.ndarray,
                 metrics: Optional[List[str]] = None) -> pd.DataFrame:
    """
    Compare multiple HMM-GLM models on the same data.
    
    Every model is scored with ``evaluate_model``. A model whose evaluation
    raises is logged and reported as a row of NaN values so that one failing
    model does not abort the whole comparison.
    
    Parameters:
    -----------
    models : list
        List of fitted HMM-GLM models
    model_names : list of str
        Names for each model (same length and order as ``models``)
    X : numpy.ndarray
        Feature matrix
    y : numpy.ndarray
        True labels
    sequences : numpy.ndarray
        Sequence IDs for each sample
    metrics : list of str, optional
        List of metrics to calculate. Defaults to accuracy, auc,
        brier_score, log_loss, delta_loglikelihood and state_diversity.
        
    Returns:
    --------
    pandas.DataFrame
        One row per model with 'model' as the first column, followed by the
        evaluation metrics. If ``models`` is empty the frame has no rows and
        the columns 'model' plus the requested metrics.
        
    Raises:
    -------
    ValueError
        If ``models`` and ``model_names`` differ in length
    """
    if len(models) != len(model_names):
        raise ValueError("Number of models must match number of model names")
    
    # Default metrics
    if metrics is None:
        metrics = ['accuracy', 'auc', 'brier_score', 'log_loss', 
                  'delta_loglikelihood', 'state_diversity']
    
    results = []
    for model, name in zip(models, model_names):
        try:
            # Build a new dict so the mapping returned by evaluate_model is
            # not mutated; 'model' is written last so it always holds the name.
            model_metrics = {**evaluate_model(model, X, y, sequences, metrics),
                             'model': name}
        except Exception as e:
            logger.error(f"Error evaluating model {name}: {e}")
            # Keep a placeholder row so the model still shows up in the table
            model_metrics = {'model': name, **{metric: np.nan for metric in metrics}}
        results.append(model_metrics)
    
    # Nothing to compare: an empty frame built from [] has no 'model' column,
    # so the reordering below would raise KeyError.
    if not results:
        return pd.DataFrame(columns=['model', *metrics])
    
    results_df = pd.DataFrame(results)
    
    # Reorder columns to put model first
    cols = ['model'] + [col for col in results_df.columns if col != 'model']
    return results_df[cols]


def _baseline_metric_scores(y: np.ndarray, y_pred: np.ndarray, y_pred_proba: np.ndarray,
                            metrics: List[str], constant_baseline: bool) -> Dict[str, Any]:
    """Score one baseline's predictions on the requested metrics (unknown names are skipped)."""
    label_scorers = {
        'accuracy': accuracy_score,
        'precision': precision_score,
        'recall': recall_score,
        'f1': f1_score,
    }
    scores = {}
    for metric in metrics:
        if metric in label_scorers:
            scores[metric] = label_scorers[metric](y, y_pred)
        elif metric == 'brier_score':
            scores[metric] = brier_score_loss(y, y_pred_proba)
        elif metric == 'auc':
            if constant_baseline:
                # A constant score cannot rank samples
                scores[metric] = 0.5  # Random guessing
            elif len(np.unique(y)) > 1:
                scores[metric] = roc_auc_score(y, y_pred_proba)
            else:
                scores[metric] = np.nan
        elif metric == 'log_loss':
            if constant_baseline:
                # The constant probability is exactly 0 or 1: avoid log(0)
                eps = 1e-15
                scores[metric] = log_loss(y, np.clip(y_pred_proba, eps, 1 - eps))
            elif len(np.unique(y)) > 1:
                scores[metric] = log_loss(y, y_pred_proba)
            else:
                scores[metric] = np.nan
    return scores


def compare_with_baseline(model: Any, X: np.ndarray, y: np.ndarray, sequences: np.ndarray,
                         baseline_type: str = 'logistic',
                         metrics: Optional[List[str]] = None) -> pd.DataFrame:
    """
    Compare HMM-GLM model with baseline model.
    
    The baseline is fitted and scored on the same ``X``/``y`` that is passed
    in, so its numbers are in-sample. A failure while evaluating either model
    is logged and reported as a row of NaN values.
    
    Parameters:
    -----------
    model : HMMGLMModel
        Fitted HMM-GLM model
    X : numpy.ndarray
        Feature matrix
    y : numpy.ndarray
        True labels
    sequences : numpy.ndarray
        Sequence IDs for each sample
    baseline_type : str, optional
        Type of baseline model: 'logistic' (LogisticRegression with
        max_iter=1000) or 'constant' (always predict the most frequent class)
    metrics : list of str, optional
        List of metrics to calculate. Defaults to accuracy, auc, brier_score
        and log_loss. For the baseline, names other than accuracy, auc,
        brier_score, log_loss, precision, recall and f1 are skipped.
        
    Returns:
    --------
    pandas.DataFrame
        Two rows ('HMM-GLM' and the baseline) with 'model' as the first
        column, followed by the evaluation metrics
        
    Raises:
    -------
    ValueError
        If ``baseline_type`` is not 'logistic' or 'constant'. This is checked
        before any model is evaluated.
    """
    # Default metrics
    if metrics is None:
        metrics = ['accuracy', 'auc', 'brier_score', 'log_loss']
    
    # (row label, wording used in the log message) per supported baseline
    baselines = {
        'logistic': ('Logistic Regression', 'logistic regression'),
        'constant': ('Constant Prediction', 'constant'),
    }
    # Fail fast: reject a bad baseline_type before the (potentially slow)
    # HMM-GLM evaluation instead of after it.
    if baseline_type not in baselines:
        raise ValueError(f"Unknown baseline type: {baseline_type}")
    baseline_name, baseline_label = baselines[baseline_type]
    
    results = []
    
    # Evaluate HMM-GLM model
    try:
        hmm_glm_metrics = evaluate_model(model, X, y, sequences, metrics)
        hmm_glm_metrics['model'] = 'HMM-GLM'
    except Exception as e:
        logger.error(f"Error evaluating HMM-GLM model: {e}")
        hmm_glm_metrics = {'model': 'HMM-GLM', **{metric: np.nan for metric in metrics}}
    results.append(hmm_glm_metrics)
    
    # Create and evaluate baseline model
    try:
        if baseline_type == 'logistic':
            baseline_model = LogisticRegression(max_iter=1000)
            baseline_model.fit(X, y)
            y_pred = baseline_model.predict(X)
            y_pred_proba = baseline_model.predict_proba(X)[:, 1]
        else:
            # Constant prediction (most frequent class)
            most_frequent = np.argmax(np.bincount(y.astype(int)))
            y_pred = np.full_like(y, most_frequent)
            y_pred_proba = np.full_like(y, float(most_frequent), dtype=float)
        
        baseline_metrics = _baseline_metric_scores(
            y, y_pred, y_pred_proba, metrics,
            constant_baseline=(baseline_type == 'constant')
        )
        baseline_metrics['model'] = baseline_name
    except Exception as e:
        logger.error(f"Error evaluating {baseline_label} baseline: {e}")
        baseline_metrics = {'model': baseline_name, **{metric: np.nan for metric in metrics}}
    results.append(baseline_metrics)
    
    # Convert to DataFrame
    results_df = pd.DataFrame(results)
    
    # Reorder columns to put model first
    cols = ['model'] + [col for col in results_df.columns if col != 'model']
    return results_df[cols]


def _sst_accepts_sequences(method: Any) -> bool:
    """Return True when *method* declares a parameter named ``sequences``."""
    import inspect

    try:
        return 'sequences' in inspect.signature(method).parameters
    except (TypeError, ValueError):
        # Callables without an introspectable signature (e.g. some builtins)
        return False


def _sst_predict(model: Any, X: np.ndarray, sequences: np.ndarray,
                 want_proba: bool, force_sequences: bool = False) -> np.ndarray:
    """Get probabilities (if wanted and available) or labels from *model*."""
    use_proba = want_proba and hasattr(model, 'predict_proba')
    method = model.predict_proba if use_proba else model.predict
    if force_sequences or _sst_accepts_sequences(method):
        return np.asarray(method(X, sequences=sequences))
    output = np.asarray(method(X))
    # sklearn-style predict_proba returns one column per class; keep class 1
    return output[:, 1] if use_proba else output


def _sst_to_labels(pred: np.ndarray) -> np.ndarray:
    """Threshold probabilities at 0.5; pass hard labels through as ints."""
    if pred.ndim > 1 or np.any((pred > 0) & (pred < 1)):
        return (pred > 0.5).astype(int)
    return pred.astype(int)


def _sst_metric_pair(y_true: np.ndarray, pred1: np.ndarray, pred2: np.ndarray,
                     metric: str) -> Tuple[float, float]:
    """Compute *metric* for both prediction vectors against *y_true*."""
    label_scorers = {
        'accuracy': accuracy_score,
        'precision': precision_score,
        'recall': recall_score,
        'f1': f1_score,
    }
    if metric in label_scorers:
        scorer = label_scorers[metric]
        return (scorer(y_true, _sst_to_labels(pred1)),
                scorer(y_true, _sst_to_labels(pred2)))
    if metric == 'auc':
        return roc_auc_score(y_true, pred1), roc_auc_score(y_true, pred2)
    if metric == 'brier_score':
        return brier_score_loss(y_true, pred1), brier_score_loss(y_true, pred2)
    if metric == 'log_loss':
        # Clip probabilities to avoid numerical issues
        eps = 1e-15
        return (log_loss(y_true, np.clip(pred1, eps, 1 - eps)),
                log_loss(y_true, np.clip(pred2, eps, 1 - eps)))
    raise ValueError(f"Unknown metric: {metric}")


def statistical_significance_test(model1: Any, model2: Any,
                                X: np.ndarray, y: np.ndarray, sequences: np.ndarray,
                                metric: str = 'accuracy',
                                n_bootstrap: int = 1000,
                                alpha: float = 0.05,
                                random_state: Optional[int] = None) -> Dict[str, Any]:
    """
    Perform a bootstrap significance test between two fitted models.
    
    Both models predict once on the full data. Each bootstrap replicate then
    draws whole sequences with replacement (a sequence drawn twice counts
    twice) and recomputes the metric difference ``model1 - model2`` on the
    rows of the drawn sequences, so labels and predictions stay aligned.
    
    NOTE(review): reusing the full-data predictions assumes a model's
    predictions for one sequence do not depend on which other sequences are
    passed alongside it - confirm for custom models.
    
    Parameters:
    -----------
    model1 : HMMGLMModel
        First model; always called with ``sequences=``
    model2 : Any
        Second model (can be HMMGLMModel or sklearn-compatible model).
        ``sequences=`` is passed only if the method declares that parameter;
        otherwise ``predict_proba(X)[:, 1]`` / ``predict(X)`` is used.
    X : numpy.ndarray
        Feature matrix
    y : numpy.ndarray
        True labels
    sequences : numpy.ndarray
        Sequence IDs for each sample
    metric : str, optional
        Metric to compare: accuracy, auc, precision, recall, f1 (higher is
        better) or brier_score, log_loss (lower is better)
    n_bootstrap : int, optional
        Number of bootstrap samples
    alpha : float, optional
        Significance level
    random_state : int, optional
        Seed for the resampling. If None, the global NumPy random state is
        used, so ``np.random.seed`` still controls reproducibility.
        
    Returns:
    --------
    dict
        Dictionary with keys 'metric', 'model1_value', 'model2_value',
        'observed_diff', 'p_value' (two-sided: twice the smaller fraction of
        bootstrap differences on either side of zero, capped at 1),
        'confidence_interval' (percentile interval of the differences at
        level 1 - alpha), 'significant' and 'better_model' ('model1',
        'model2' or 'neither'). If anything fails, the error is logged and
        the values are NaN, 'better_model' is 'error' and an 'error' key
        holds the message.
    """
    higher_is_better = metric in ['accuracy', 'auc', 'precision', 'recall', 'f1']
    
    try:
        y = np.asarray(y)
        sequences = np.asarray(sequences)
        want_proba = metric in ['auc', 'brier_score', 'log_loss']
        
        # Predict once; the fitted models do not change between replicates
        y_pred1 = _sst_predict(model1, X, sequences, want_proba, force_sequences=True)
        y_pred2 = _sst_predict(model2, X, sequences, want_proba)
        
        # Calculate observed difference
        metric1, metric2 = _sst_metric_pair(y, y_pred1, y_pred2, metric)
        observed_diff = metric1 - metric2
        
        # Row indices of every sequence, computed once outside the loop
        unique_seqs = np.unique(sequences)
        seq_rows = [np.flatnonzero(sequences == seq) for seq in unique_seqs]
        n_seqs = len(seq_rows)
        
        rng = np.random if random_state is None else np.random.RandomState(random_state)
        
        bootstrap_diffs = []
        for _ in range(n_bootstrap):
            # Resample sequences, not individual rows, to respect the
            # dependence between samples of the same sequence
            drawn = rng.choice(n_seqs, n_seqs, replace=True)
            rows = np.concatenate([seq_rows[k] for k in drawn])
            
            try:
                metric1_boot, metric2_boot = _sst_metric_pair(
                    y[rows], y_pred1[rows], y_pred2[rows], metric
                )
            except ValueError as e:
                # e.g. a replicate that contains a single class
                logger.warning(f"Error in bootstrap iteration: {e}")
                continue
            bootstrap_diffs.append(metric1_boot - metric2_boot)
        
        if not bootstrap_diffs:
            raise ValueError("No valid bootstrap samples were obtained")
        
        diffs = np.asarray(bootstrap_diffs, dtype=float)
        
        # Two-sided p-value: a one-sided test could never flag model2 as the
        # better model, which 'better_model' below is meant to report.
        p_value = float(min(1.0, 2.0 * min(np.mean(diffs <= 0), np.mean(diffs >= 0))))
        
        # Calculate confidence interval
        lower, upper = np.percentile(diffs, [alpha / 2 * 100, (1 - alpha / 2) * 100])
        
        # Determine significance
        significant = bool(p_value < alpha)
        
        if higher_is_better:
            model1_wins, model2_wins = observed_diff > 0, observed_diff < 0
        else:
            model1_wins, model2_wins = observed_diff < 0, observed_diff > 0
        
        if significant and model1_wins:
            better_model = 'model1'
        elif significant and model2_wins:
            better_model = 'model2'
        else:
            better_model = 'neither'
        
        return {
            'metric': metric,
            'model1_value': metric1,
            'model2_value': metric2,
            'observed_diff': observed_diff,
            'p_value': p_value,
            'confidence_interval': (float(lower), float(upper)),
            'significant': significant,
            'better_model': better_model
        }
    
    except Exception as e:
        logger.error(f"Error in statistical significance test: {e}")
        return {
            'metric': metric,
            'model1_value': np.nan,
            'model2_value': np.nan,
            'observed_diff': np.nan,
            'p_value': np.nan,
            'confidence_interval': (np.nan, np.nan),
            'significant': False,
            'better_model': 'error',
            'error': str(e)
        }


def cross_validation_comparison(models: List[Any], model_names: List[str],
                              X: np.ndarray, y: np.ndarray, sequences: np.ndarray,
                              n_splits: int = 5,
                              metrics: Optional[List[str]] = None,
                              stratify: bool = True,
                              random_state: int = 42) -> Tuple[pd.DataFrame, pd.DataFrame]:
    """
    Compare models using sequence-level cross-validation.
    
    Whole sequences are assigned to folds, so samples of one sequence never
    appear in both the training and the test part of a split. The assignment
    is a seeded shuffle of a round-robin labelling, which keeps fold sizes
    within one sequence of each other and leaves no fold empty as long as
    there are at least ``n_splits`` sequences.
    
    For every fold each model is re-created from ``get_params()``, fitted on
    the training sequences and scored with ``evaluate_model`` on the test
    sequences (whose IDs are remapped to 0..k-1). A failure is logged and
    recorded as a row of NaN values.
    
    Parameters:
    -----------
    models : list
        List of unfitted HMM-GLM models; each must provide ``get_params()``
        and ``fit(X, y, sequences=...)``
    model_names : list of str
        Names for each model
    X : numpy.ndarray
        Feature matrix
    y : numpy.ndarray
        True labels
    sequences : numpy.ndarray
        Sequence IDs for each sample
    n_splits : int, optional
        Number of cross-validation splits (at least 2)
    metrics : list of str, optional
        List of metrics to calculate. Defaults to accuracy, auc, brier_score
        and log_loss.
    stratify : bool, optional
        Accepted for backward compatibility; currently has no effect because
        folds are built from sequence IDs, not from individual labels
    random_state : int, optional
        Random seed for the fold assignment. The global NumPy random state
        is left untouched.
        
    Returns:
    --------
    tuple of (pandas.DataFrame, pandas.DataFrame)
        ``results_df`` with one row per model and fold, and ``summary_df``
        with the mean, std, min and max of every metric per model
        
    Raises:
    -------
    ValueError
        If ``models`` and ``model_names`` differ in length, or ``n_splits``
        is not an integer >= 2
    """
    if len(models) != len(model_names):
        raise ValueError("Number of models must match number of model names")
    
    if isinstance(n_splits, bool) or not isinstance(n_splits, (int, np.integer)) or n_splits < 2:
        raise ValueError(f"n_splits must be an integer >= 2, got {n_splits!r}")
    
    # Default metrics
    if metrics is None:
        metrics = ['accuracy', 'auc', 'brier_score', 'log_loss']
    
    sequences = np.asarray(sequences)
    unique_sequences = np.unique(sequences)
    
    # Balanced sequence-to-fold assignment drawn from a local generator:
    # independent random draws per sequence can leave folds empty, and
    # seeding the global generator would silently affect the caller.
    rng = np.random.RandomState(random_state)
    sequence_folds = rng.permutation(np.arange(len(unique_sequences)) % n_splits)
    
    # Fold of every sample, looked up through its sequence's position in the
    # sorted unique IDs
    sample_folds = sequence_folds[np.searchsorted(unique_sequences, sequences)]
    
    all_results = []
    
    # Perform cross-validation
    for fold in range(n_splits):
        test_mask = sample_folds == fold
        train_mask = ~test_mask
        
        X_train, X_test = X[train_mask], X[test_mask]
        y_train, y_test = y[train_mask], y[test_mask]
        sequences_train, sequences_test = sequences[train_mask], sequences[test_mask]
        
        # Remap test sequence IDs to consecutive integers starting at 0
        sequences_test_remapped = np.searchsorted(np.unique(sequences_test), sequences_test)
        
        for model, name in zip(models, model_names):
            try:
                # Create a fresh copy of the model
                model_copy = model.__class__(**model.get_params())
                model_copy.fit(X_train, y_train, sequences=sequences_train)
                
                model_metrics = evaluate_model(model_copy, X_test, y_test, 
                                             sequences=sequences_test_remapped, 
                                             metrics=metrics)
                model_metrics['model'] = name
                model_metrics['fold'] = fold
            except Exception as e:
                logger.error(f"Error evaluating model {name} on fold {fold}: {e}")
                # Add empty results
                model_metrics = {
                    'model': name, 
                    'fold': fold, 
                    **{metric: np.nan for metric in metrics}
                }
            all_results.append(model_metrics)
    
    results_df = pd.DataFrame(all_results)
    
    # Guarantee the columns the summary reads, even when there are no rows or
    # a requested metric was not reported by any evaluation
    for column in ['model', 'fold', *metrics]:
        if column not in results_df.columns:
            results_df[column] = np.nan
    
    # Calculate mean and std for each model and metric
    summary = []
    for name in model_names:
        model_results = results_df[results_df['model'] == name]
        
        for metric in metrics:
            values = model_results[metric]
            summary.append({
                'model': name,
                'metric': metric,
                'mean': values.mean(),
                'std': values.std(),
                'min': values.min(),
                'max': values.max()
            })
    
    summary_df = pd.DataFrame(summary, columns=['model', 'metric', 'mean', 'std', 'min', 'max'])
    
    return results_df, summary_df


if __name__ == "__main__":
    # Demo: fit an HMM-GLM on synthetic data, compare it with a logistic
    # baseline and run the significance test. The synthetic data depends on
    # the order of the random draws below, so keep that order intact.
    import sys
    from src.core.hmm_glm import CategoricalHMMComponent, LogisticGLMComponent, HMMGLMModel
    
    np.random.seed(42)
    n_samples, n_features, n_states = 1000, 5, 3
    
    # Equal-length sequences: 50 IDs, each repeated for consecutive samples
    n_sequences = 50
    sequence_length = n_samples // n_sequences
    sequences = np.repeat(np.arange(n_sequences), sequence_length)
    
    # Features and a random state label per sample
    X = np.random.randn(n_samples, n_features)
    states = np.random.randint(0, n_states, n_samples)
    
    # Each state gets its own random coefficient vector for the
    # logistic link that produces the positive-class probability
    probs = np.zeros(n_samples)
    for state in range(n_states):
        in_state = states == state
        beta = np.random.randn(n_features)
        probs[in_state] = 1 / (1 + np.exp(-(X[in_state] @ beta)))
    
    y = (np.random.random(n_samples) < probs).astype(int)
    
    # Build and fit the HMM-GLM
    hmm_part = CategoricalHMMComponent(n_states=n_states, n_categories=2)
    glm_part = LogisticGLMComponent()
    hmm_glm_model = HMMGLMModel(hmm_component=hmm_part, glm_component=glm_part)
    hmm_glm_model.fit(X, y, sequences=sequences)
    
    # HMM-GLM versus logistic-regression baseline
    baseline_comparison = compare_with_baseline(
        hmm_glm_model, X, y, sequences, baseline_type='logistic'
    )
    print("Comparison with baseline:")
    print(baseline_comparison)
    
    # Bootstrap significance test against a separately fitted baseline
    baseline_model = LogisticRegression().fit(X, y)
    significance_test = statistical_significance_test(
        hmm_glm_model, baseline_model, X, y, sequences, metric='accuracy'
    )
    print("\nStatistical significance test:")
    for field, result in significance_test.items():
        print(f"{field}: {result}")
