"""
Visualization tools for HMM-GLM model evaluation.

This module provides functions for visualizing HMM-GLM model evaluation results,
including confusion matrices, ROC, precision-recall and calibration curves,
state transitions, state probability distributions, state sequences, and
per-state feature importance.
"""

import logging
from typing import Any, Dict, List, Optional, Tuple, Union

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from matplotlib.colors import LinearSegmentedColormap
# calibration_curve lives in sklearn.calibration, not sklearn.metrics.
from sklearn.calibration import calibration_curve
from sklearn.metrics import (
    auc,
    confusion_matrix,
    precision_recall_curve,
    roc_curve,
)

# Module-level logger named after this module; no handlers are attached here.
logger: logging.Logger = logging.getLogger(__name__)


def plot_confusion_matrix(y_true: np.ndarray, y_pred: np.ndarray, 
                         normalize: bool = True, 
                         title: str = 'Confusion Matrix',
                         cmap: str = 'Blues',
                         ax: Optional[plt.Axes] = None,
                         figsize: Tuple[int, int] = (8, 6)) -> plt.Axes:
    """
    Plot a binary confusion matrix.
    
    Parameters:
    -----------
    y_true : numpy.ndarray
        True binary labels (0 = negative, 1 = positive)
    y_pred : numpy.ndarray
        Predicted binary labels or predicted probabilities. Probabilities
        are thresholded at 0.5. A 2-D array is treated as one score column
        per class, and its last column is used as the positive-class score.
    normalize : bool, optional
        Whether to normalize each row (true class) to sum to 1. Rows with
        no samples are shown as 0.
    title : str, optional
        Plot title
    cmap : str, optional
        Colormap name
    ax : matplotlib.axes.Axes, optional
        Axes to plot on; a new figure is created if omitted
    figsize : tuple, optional
        Figure size used when a new figure is created
        
    Returns:
    --------
    matplotlib.axes.Axes
        Axes with the plot
    """
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)

    # A 2-D input holds one score column per class (e.g. predict_proba
    # output). Keep the last (positive-class) column so y_pred is 1-D like
    # y_true; sklearn rejects a 2-D prediction against 1-D binary labels.
    multi_column = y_pred.ndim > 1
    if multi_column:
        y_pred = y_pred.reshape(y_pred.shape[0], -1)[:, -1]

    # Threshold probabilities at 0.5; values already in {0, 1} pass through.
    if multi_column or np.any((y_pred > 0) & (y_pred < 1)):
        y_pred_binary = (y_pred > 0.5).astype(int)
    else:
        y_pred_binary = y_pred.astype(int)

    # Fix the label set so the matrix is always 2x2, even when only one class
    # occurs in the data; the tick labels and y-limits below assume 2 classes.
    classes = ['Negative', 'Positive']
    cm = confusion_matrix(y_true, y_pred_binary, labels=[0, 1])

    if normalize:
        row_sums = cm.sum(axis=1, keepdims=True)
        # Rows without any true samples would divide by zero; leave them at 0.
        cm = np.divide(cm, row_sums, out=np.zeros(cm.shape, dtype=float),
                       where=row_sums > 0)

    if ax is None:
        _, ax = plt.subplots(figsize=figsize)

    im = ax.imshow(cm, interpolation='nearest', cmap=cmap)
    ax.figure.colorbar(im, ax=ax)

    ax.set(xticks=np.arange(cm.shape[1]),
           yticks=np.arange(cm.shape[0]),
           xticklabels=classes,
           yticklabels=classes,
           title=title,
           ylabel='True Label',
           xlabel='Predicted Label')

    # Rotate tick labels and set alignment
    plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")

    # Annotate each cell; use white text on dark (above half-max) cells.
    fmt = '.2f' if normalize else 'd'
    thresh = cm.max() / 2.
    for i in range(cm.shape[0]):
        for j in range(cm.shape[1]):
            ax.text(j, i, format(cm[i, j], fmt),
                    ha="center", va="center",
                    color="white" if cm[i, j] > thresh else "black")

    ax.set_ylim(len(classes) - 0.5, -0.5)  # Reverse y-axis
    # Lay out the figure that owns `ax`, which need not be the current figure.
    ax.figure.tight_layout()

    return ax


def plot_roc_curve(y_true: np.ndarray, y_pred: np.ndarray,
                  title: str = 'ROC Curve',
                  ax: Optional[plt.Axes] = None,
                  figsize: Tuple[int, int] = (8, 6)) -> plt.Axes:
    """
    Plot ROC curve.
    
    Parameters:
    -----------
    y_true : numpy.ndarray
        True binary labels
    y_pred : numpy.ndarray
        Predicted probabilities or scores for the positive class. A 2-D
        array is treated as one score column per class, and its last column
        is used.
    title : str, optional
        Plot title
    ax : matplotlib.axes.Axes, optional
        Axes to plot on; a new figure is created if omitted
    figsize : tuple, optional
        Figure size used when a new figure is created
        
    Returns:
    --------
    matplotlib.axes.Axes
        Axes with the plot. If y_true contains fewer than two distinct
        values, a warning is logged and the axes only shows a message.
    """
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    if y_pred.ndim > 1:
        # roc_curve needs 1-D scores: keep the positive (last) class column.
        y_pred = y_pred.reshape(y_pred.shape[0], -1)[:, -1]

    if ax is None:
        _, ax = plt.subplots(figsize=figsize)

    # The ROC curve is undefined unless both classes are present.
    if np.unique(y_true).size < 2:
        logger.warning("ROC curve requires at least two classes.")
        ax.text(0.5, 0.5, "Insufficient class diversity for ROC curve",
                ha='center', va='center', fontsize=12)
        ax.set_title(title)
        return ax

    fpr, tpr, _ = roc_curve(y_true, y_pred)
    roc_auc = auc(fpr, tpr)

    ax.plot(fpr, tpr, lw=2, label=f'ROC curve (AUC = {roc_auc:.3f})')
    ax.plot([0, 1], [0, 1], 'k--', lw=2)  # chance-level diagonal

    ax.set(xlim=[0.0, 1.0],
           ylim=[0.0, 1.05],
           xlabel='False Positive Rate',
           ylabel='True Positive Rate',
           title=title)

    ax.legend(loc="lower right")
    # Lay out the figure that owns `ax`, which need not be the current figure.
    ax.figure.tight_layout()

    return ax


def plot_precision_recall_curve(y_true: np.ndarray, y_pred: np.ndarray,
                               title: str = 'Precision-Recall Curve',
                               ax: Optional[plt.Axes] = None,
                               figsize: Tuple[int, int] = (8, 6)) -> plt.Axes:
    """
    Plot precision-recall curve with a no-skill baseline.
    
    Parameters:
    -----------
    y_true : numpy.ndarray
        True binary labels (the positive class is label 1)
    y_pred : numpy.ndarray
        Predicted probabilities or scores for the positive class. A 2-D
        array is treated as one score column per class, and its last column
        is used.
    title : str, optional
        Plot title
    ax : matplotlib.axes.Axes, optional
        Axes to plot on; a new figure is created if omitted
    figsize : tuple, optional
        Figure size used when a new figure is created
        
    Returns:
    --------
    matplotlib.axes.Axes
        Axes with the plot. If y_true contains fewer than two distinct
        values, a warning is logged and the axes only shows a message.
    """
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    if y_pred.ndim > 1:
        # precision_recall_curve needs 1-D scores: keep the last class column.
        y_pred = y_pred.reshape(y_pred.shape[0], -1)[:, -1]

    if ax is None:
        _, ax = plt.subplots(figsize=figsize)

    if np.unique(y_true).size < 2:
        logger.warning("Precision-recall curve requires at least two classes.")
        ax.text(0.5, 0.5, "Insufficient class diversity for precision-recall curve",
                ha='center', va='center', fontsize=12)
        ax.set_title(title)
        return ax

    precision, recall, _ = precision_recall_curve(y_true, y_pred)
    # Trapezoidal area under the PR curve (not average precision).
    pr_auc = auc(recall, precision)

    ax.plot(recall, precision, lw=2, label=f'PR curve (AUC = {pr_auc:.3f})')

    ax.set(xlim=[0.0, 1.0],
           ylim=[0.0, 1.05],
           xlabel='Recall',
           ylabel='Precision',
           title=title)

    # No-skill baseline = fraction of positive samples. Comparing against 1
    # (instead of averaging the raw labels) stays correct for {-1, 1} labels.
    baseline = np.mean(y_true == 1)
    ax.axhline(y=baseline, color='r', linestyle='--',
               label=f'Baseline ({baseline:.3f})')

    ax.legend(loc="lower left")
    # Lay out the figure that owns `ax`, which need not be the current figure.
    ax.figure.tight_layout()

    return ax


def plot_calibration_curve(y_true: np.ndarray, y_pred: np.ndarray,
                          n_bins: int = 10,
                          title: str = 'Calibration Curve',
                          ax: Optional[plt.Axes] = None,
                          figsize: Tuple[int, int] = (8, 6),
                          strategy: str = 'uniform') -> plt.Axes:
    """
    Plot calibration curve (reliability diagram).
    
    Parameters:
    -----------
    y_true : numpy.ndarray
        True binary labels
    y_pred : numpy.ndarray
        Predicted probabilities for the positive class, expected in [0, 1].
        A 2-D array is treated as one probability column per class, and its
        last column is used.
    n_bins : int, optional
        Number of bins for calibration curve
    title : str, optional
        Plot title
    ax : matplotlib.axes.Axes, optional
        Axes to plot on; a new figure is created if omitted
    figsize : tuple, optional
        Figure size used when a new figure is created
    strategy : str, optional
        Binning strategy passed to ``calibration_curve``: ``'uniform'``
        (equal-width bins, the default) or ``'quantile'`` (equal-count bins).
        
    Returns:
    --------
    matplotlib.axes.Axes
        Axes with the plot. If y_true contains fewer than two distinct
        values, a warning is logged and the axes only shows a message.
    """
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    if y_pred.ndim > 1:
        # calibration_curve needs 1-D probabilities: keep the last class column.
        y_pred = y_pred.reshape(y_pred.shape[0], -1)[:, -1]

    if ax is None:
        _, ax = plt.subplots(figsize=figsize)

    if np.unique(y_true).size < 2:
        logger.warning("Calibration curve requires at least two classes.")
        ax.text(0.5, 0.5, "Insufficient class diversity for calibration curve",
                ha='center', va='center', fontsize=12)
        ax.set_title(title)
        return ax

    # calibration_curve returns (fraction of positives, mean predicted prob) per bin.
    prob_true, prob_pred = calibration_curve(y_true, y_pred, n_bins=n_bins,
                                             strategy=strategy)

    ax.plot(prob_pred, prob_true, 's-', label='Calibration curve')
    ax.plot([0, 1], [0, 1], 'k--', label='Perfect calibration')

    ax.set(xlim=[0.0, 1.0],
           ylim=[0.0, 1.0],
           xlabel='Mean predicted probability',
           ylabel='Fraction of positives',
           title=title)

    ax.legend(loc="lower right")
    # Lay out the figure that owns `ax`, which need not be the current figure.
    ax.figure.tight_layout()

    return ax


def plot_state_transitions(transition_matrix: np.ndarray,
                          title: str = 'State Transition Probabilities',
                          ax: Optional[plt.Axes] = None,
                          figsize: Tuple[int, int] = (8, 6),
                          cmap: str = 'viridis',
                          annot: bool = True) -> plt.Axes:
    """
    Plot state transition matrix as a heatmap.
    
    Rows are labelled as the source ("From") state and columns as the
    destination ("To") state; the colour scale is fixed to [0, 1].
    
    Parameters:
    -----------
    transition_matrix : numpy.ndarray
        2-D state transition probability matrix (array-like accepted)
    title : str, optional
        Plot title
    ax : matplotlib.axes.Axes, optional
        Axes to plot on; a new figure is created if omitted
    figsize : tuple, optional
        Figure size used when a new figure is created
    cmap : str, optional
        Colormap name
    annot : bool, optional
        Whether to annotate cells with their value (2 decimals)
        
    Returns:
    --------
    matplotlib.axes.Axes
        Axes with the plot
    """
    transition_matrix = np.asarray(transition_matrix)

    if ax is None:
        _, ax = plt.subplots(figsize=figsize)

    # Label rows and columns by their own counts so the tick labels always
    # match the matrix shape.
    n_from, n_to = transition_matrix.shape
    sns.heatmap(transition_matrix, annot=annot, cmap=cmap,
                vmin=0, vmax=1, ax=ax, fmt='.2f',
                xticklabels=range(n_to),
                yticklabels=range(n_from))

    ax.set(xlabel='To State',
           ylabel='From State',
           title=title)

    # Lay out the figure that owns `ax`, which need not be the current figure.
    ax.figure.tight_layout()

    return ax


def plot_state_distributions(state_probs: np.ndarray,
                            title: str = 'State Probability Distributions',
                            ax: Optional[plt.Axes] = None,
                            figsize: Tuple[int, int] = (10, 6)) -> plt.Axes:
    """
    Plot the distribution of each state's probability as violin plots.
    
    Parameters:
    -----------
    state_probs : numpy.ndarray
        State probabilities for each sample [n_samples, n_states]
    title : str, optional
        Plot title
    ax : matplotlib.axes.Axes, optional
        Axes to plot on; a new figure is created if omitted
    figsize : tuple, optional
        Figure size used when a new figure is created
        
    Returns:
    --------
    matplotlib.axes.Axes
        Axes with the plot
    """
    state_probs = np.asarray(state_probs)

    if ax is None:
        _, ax = plt.subplots(figsize=figsize)

    n_states = state_probs.shape[1]

    # Long format: one row per (state, probability) pair, as seaborn expects.
    df = pd.DataFrame(state_probs, columns=[f'State {i}' for i in range(n_states)])
    df_melted = df.melt(var_name='State', value_name='Probability')

    # cut=0 ends each density at the observed min/max. Seaborn's default
    # extends the KDE past the data, i.e. outside the valid [0, 1] range.
    sns.violinplot(x='State', y='Probability', data=df_melted, cut=0, ax=ax)

    ax.set(ylim=[0, 1],
           title=title)

    # Lay out the figure that owns `ax`, which need not be the current figure.
    ax.figure.tight_layout()

    return ax


def plot_feature_importance(model: Any, feature_names: Optional[List[str]] = None,
                           title: str = 'Feature Importance by State',
                           figsize: Tuple[int, int] = (12, 8)) -> plt.Figure:
    """
    Plot GLM coefficients for each hidden state as horizontal bar charts.
    
    Parameters:
    -----------
    model : HMMGLMModel
        Fitted HMM-GLM model. Reads ``model.hmm_component.n_states`` and
        ``model.glm_component.models``, which is indexed by state and holds
        estimators exposing ``coef_``. States without an entry are plotted
        as all-zero coefficients.
    feature_names : list of str, optional
        Feature names; defaults to ``'Feature i'``
    title : str, optional
        Plot title
    figsize : tuple, optional
        Figure size
        
    Returns:
    --------
    matplotlib.figure.Figure
        Figure with one subplot per state (negative coefficients red,
        non-negative blue). On any error, the error is logged with its
        traceback and a figure showing the error message is returned.
    """
    fig = None
    try:
        n_states = model.hmm_component.n_states
        glm_models = model.glm_component.models

        # Collect coefficients of the states that actually have a fitted GLM.
        fitted = {}
        for state in range(n_states):
            if state in glm_models:
                state_coef = np.asarray(glm_models[state].coef_)
                if state_coef.ndim > 1:
                    state_coef = state_coef[0]  # Get first class for multiclass
                fitted[state] = state_coef

        # Decide the coefficient length before zero-filling missing states, so
        # the zero rows match the fitted rows. Otherwise, with feature_names=None,
        # they would get length 1 and the array would be ragged.
        if fitted:
            n_features = len(next(iter(fitted.values())))
        else:
            n_features = len(feature_names) if feature_names else 1

        coeffs = np.array([fitted.get(state, np.zeros(n_features))
                           for state in range(n_states)])
        n_features = coeffs.shape[1]

        if feature_names is None:
            feature_names = [f'Feature {i}' for i in range(n_features)]
        names = np.array(feature_names)

        # squeeze=False always yields a 2-D grid, even for a single state.
        fig, axes = plt.subplots(n_states, 1, figsize=figsize, sharex=True,
                                 squeeze=False)

        for state, ax in enumerate(axes[:, 0]):
            # Ascending |coef| order: the largest bar is drawn last (top row).
            order = np.argsort(np.abs(coeffs[state]))
            ordered = coeffs[state][order]
            colors = ['r' if c < 0 else 'b' for c in ordered]
            ax.barh(names[order], ordered, color=colors)

            ax.set(ylabel='Feature',
                   title=f'State {state}')

            ax.axvline(x=0, color='k', linestyle='-', alpha=0.3)  # zero line

        fig.suptitle(title, fontsize=16)
        fig.tight_layout(rect=[0, 0, 1, 0.95])  # Adjust for suptitle

        return fig

    except Exception as e:
        # Best-effort plotting: record the full traceback, then return a
        # placeholder figure instead of raising.
        logger.exception("Error plotting feature importance: %s", e)
        if fig is not None:
            plt.close(fig)  # drop the partially drawn figure
        fig, ax = plt.subplots(figsize=figsize)
        ax.text(0.5, 0.5, f"Error plotting feature importance: {e}",
                ha='center', va='center', fontsize=12)
        return fig


def plot_state_sequences(states: np.ndarray, sequences: np.ndarray,
                        n_sequences: int = 5,
                        title: str = 'State Sequences',
                        figsize: Tuple[int, int] = (12, 8)) -> plt.Figure:
    """
    Plot state sequences for a subset of sequences, one subplot each.
    
    Parameters:
    -----------
    states : numpy.ndarray
        State assignments for each sample
    sequences : numpy.ndarray
        Sequence IDs for each sample (same length as ``states``)
    n_sequences : int, optional
        Maximum number of sequences to plot. When more exist, evenly spaced
        IDs (in sorted order) are selected.
    title : str, optional
        Plot title
    figsize : tuple, optional
        Figure size
        
    Returns:
    --------
    matplotlib.figure.Figure
        Figure with the plot. If nothing can be plotted, a warning is logged
        and the figure only shows a message.
    """
    states = np.asarray(states)
    sequences = np.asarray(sequences)

    unique_sequences = np.unique(sequences)

    # Negative counts would make np.linspace raise; treat them as 0.
    n_show = max(int(n_sequences), 0)
    if len(unique_sequences) > n_show:
        # Pick evenly spaced sequences across the sorted IDs.
        indices = np.linspace(0, len(unique_sequences) - 1, n_show, dtype=int)
        selected_sequences = unique_sequences[indices]
    else:
        selected_sequences = unique_sequences

    if len(selected_sequences) == 0:
        # plt.subplots(0, 1) would raise; return an explanatory figure instead.
        logger.warning("No sequences available to plot.")
        fig, ax = plt.subplots(figsize=figsize)
        ax.text(0.5, 0.5, "No sequences to plot",
                ha='center', va='center', fontsize=12)
        ax.set_title(title)
        return fig

    # Tick the actual state labels, computed once. Using the distinct values
    # keeps non-contiguous or non-zero-based labels (e.g. {0, 2}) correct.
    state_ticks = np.unique(states)

    # squeeze=False always yields a 2-D grid, even for a single sequence.
    fig, axes = plt.subplots(len(selected_sequences), 1, figsize=figsize,
                             sharex=True, squeeze=False)

    for ax, seq_id in zip(axes[:, 0], selected_sequences):
        seq_states = states[sequences == seq_id]
        ax.step(range(len(seq_states)), seq_states, where='post')
        ax.set(ylabel=f'Sequence {seq_id}')
        ax.set_yticks(state_ticks)

    axes[-1, 0].set(xlabel='Position in Sequence')
    fig.suptitle(title, fontsize=16)
    fig.tight_layout(rect=[0, 0, 1, 0.95])  # Adjust for suptitle

    return fig


if __name__ == "__main__":
    # Example usage: fit an HMM-GLM on synthetic data and show the plots.
    from src.core.hmm_glm import CategoricalHMMComponent, LogisticGLMComponent, HMMGLMModel

    # Generate synthetic data
    np.random.seed(42)
    n_samples = 1000
    n_features = 5
    n_states = 3

    # Create equal-length sequences
    n_sequences = 50
    sequence_length = n_samples // n_sequences
    sequences = np.repeat(np.arange(n_sequences), sequence_length)

    # Generate features and labels
    X = np.random.randn(n_samples, n_features)
    states = np.random.randint(0, n_states, n_samples)

    # Each state gets its own logistic coefficients
    probs = np.zeros(n_samples)
    for state in range(n_states):
        mask = (states == state)
        beta = np.random.randn(n_features)
        logits = X[mask] @ beta
        probs[mask] = 1 / (1 + np.exp(-logits))

    y = (np.random.random(n_samples) < probs).astype(int)

    # Create and fit model
    hmm_comp = CategoricalHMMComponent(n_states=n_states, n_categories=2)
    glm_comp = LogisticGLMComponent()
    model = HMMGLMModel(hmm_component=hmm_comp, glm_component=glm_comp)

    model.fit(X, y, sequences=sequences)

    # Make predictions
    y_pred = model.predict_proba(X, sequences=sequences)
    states_pred = model.predict_states(X, sequences=sequences)

    feature_names = [f'Feature {i}' for i in range(n_features)]

    # Pass each axes explicitly. Without ax=..., every helper opens its own
    # figure and the 2x2 grid stays empty.
    fig, axes = plt.subplots(2, 2, figsize=(15, 12))
    plot_confusion_matrix(y, y_pred, ax=axes[0, 0])
    plot_roc_curve(y, y_pred, ax=axes[0, 1])
    plot_precision_recall_curve(y, y_pred, ax=axes[1, 0])
    plot_state_transitions(model.hmm_component.model.transmat_, ax=axes[1, 1])
    fig.tight_layout()

    # These helpers build their own multi-panel figures.
    plot_feature_importance(model, feature_names=feature_names)
    plot_state_sequences(states_pred, sequences)

    plt.show()
