"""
Visualization functions for HMM-GLM models.

This module provides functions for visualizing HMM-GLM model components and results.
"""

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

def _row_normalized_exp(logits):
    """
    Turn unnormalized log-weights into row-normalized weights (row-wise softmax).

    Each row's maximum is subtracted before exponentiating so that large
    log-weights cannot overflow to ``inf`` (which would make the whole row
    ``nan`` after normalization). Normalization is along axis 1, exactly as
    in the original ``exp`` / row-sum computation.
    """
    logits = np.asarray(logits, dtype=float)
    shifted = logits - np.max(logits, axis=1, keepdims=True)
    weights = np.exp(shifted)
    return weights / np.sum(weights, axis=1, keepdims=True)


def _extract_transition_matrix(model):
    """Return the transition matrix held by an HMMGLMModel or HMMComponent."""
    # An HMMGLMModel wraps its HMM in ``hmm_component``; a bare HMMComponent
    # carries the fitted attributes itself.
    hmm = getattr(model, 'hmm_component', model)
    if hasattr(hmm, 'transmat_'):
        return np.asarray(hmm.transmat_)
    # Context-aware HMM: there is no fixed ``transmat_``, so approximate it by
    # row-normalizing exp(alpha_).
    # NOTE(review): assumes alpha_ is 2-D (from_state, to_state) — confirm.
    return _row_normalized_exp(hmm.alpha_)


def plot_state_transition_matrix(model, title=None, cmap='YlGnBu', figsize=(8, 6)):
    """
    Plot the state transition matrix as an annotated heatmap.

    Uses ``transmat_`` when the HMM has one; otherwise (context-aware HMM) the
    matrix is approximated by a numerically stable row-wise softmax of
    ``alpha_``. The colour scale is fixed to [0, 1].

    Parameters:
    -----------
    model : HMMGLMModel or HMMComponent
        Fitted model with transition matrix
    title : str, optional
        Plot title (defaults to 'State Transition Matrix')
    cmap : str, optional
        Colormap for the heatmap
    figsize : tuple, optional
        Figure size

    Returns:
    --------
    fig : matplotlib.figure.Figure
        Figure object
    ax : matplotlib.axes.Axes
        Axes object
    """
    transmat = _extract_transition_matrix(model)

    n_states = transmat.shape[0]
    state_labels = [f'State {i+1}' for i in range(n_states)]

    fig, ax = plt.subplots(figsize=figsize)

    # Passing the labels to seaborn (rather than calling set_*ticklabels
    # afterwards) guarantees one label per cell; with the default "auto"
    # seaborn may thin the ticks, and a later set_*ticklabels would mismatch.
    sns.heatmap(transmat, annot=True, fmt='.2f', cmap=cmap, ax=ax,
                cbar=True, square=True, vmin=0, vmax=1,
                xticklabels=state_labels, yticklabels=state_labels)

    ax.set_title(title or 'State Transition Matrix')
    ax.set_xlabel('To State')
    ax.set_ylabel('From State')

    plt.tight_layout()

    return fig, ax

def plot_emission_probabilities(model, title=None, figsize=(10, 6)):
    """
    Plot emission ("success") probabilities for each state as a bar chart.

    For models with a categorical emission matrix ``emissionprob_``, the bar
    height is column 1 of that matrix. For an HMMGLMModel without
    ``emissionprob_``, each state's GLM is queried at an all-zero feature
    vector instead; states without a fitted GLM are shown as 0.

    Parameters:
    -----------
    model : HMMGLMModel or HMMComponent
        Fitted model with emission probabilities
    title : str, optional
        Plot title
    figsize : tuple, optional
        Figure size

    Returns:
    --------
    fig : matplotlib.figure.Figure or None
        Figure object, or None for an HMMComponent without ``emissionprob_``
    ax : matplotlib.axes.Axes or None
        Axes object, or None in the same case
    """
    if hasattr(model, 'hmm_component'):
        # HMMGLMModel
        hmm = model.hmm_component
        if hasattr(hmm, 'emissionprob_'):
            # Column 1 of the emission matrix is the "success" symbol.
            success_probs = np.asarray(hmm.emissionprob_)[:, 1]
        else:
            # Non-categorical (e.g. Gaussian) emissions have no single success
            # probability, so ask each state's GLM instead.
            glm_component = model.glm_component
            glm_models = glm_component.models
            # Prefer the HMM's state count: a state without a fitted GLM is
            # absent from ``glm_models``, so len(glm_models) would undercount
            # and silently drop the highest-numbered states.
            n_states = getattr(hmm, 'n_states', len(glm_models))
            success_probs = np.zeros(n_states)
            for state in range(n_states):
                if state not in glm_models:
                    continue  # no GLM for this state; its bar stays at 0
                n_features = glm_models[state].coef_.shape[1]
                # All-zero features. This equals the feature mean only if the
                # GLM inputs were centred — NOTE(review): confirm upstream.
                X_dummy = np.zeros((1, n_features))
                success_probs[state] = glm_component.predict_proba(
                    X_dummy, np.array([state]))[0]
    else:
        # HMMComponent
        if not hasattr(model, 'emissionprob_'):
            # Non-categorical emissions and no GLM to fall back on.
            return None, None
        success_probs = np.asarray(model.emissionprob_)[:, 1]

    n_states = len(success_probs)

    fig, ax = plt.subplots(figsize=figsize)

    x = np.arange(n_states)
    ax.bar(x, success_probs, color='skyblue')

    # Numeric value just above each bar.
    for i, v in enumerate(success_probs):
        ax.text(i, v + 0.02, f'{v:.2f}', ha='center')

    ax.set_title(title or 'Emission Probabilities (Success Rate) by State')
    ax.set_xlabel('State')
    ax.set_ylabel('Success Probability')
    ax.set_xticks(x)
    ax.set_xticklabels([f'State {i+1}' for i in range(n_states)])
    ax.set_ylim(0, 1)

    plt.tight_layout()

    return fig, ax

def plot_state_sequence(model, X_hmm, X_glm, y, contexts=None, window=100,
                        title=None, figsize=(12, 6), random_state=None):
    """
    Plot a segment of the predicted state sequence together with outcomes.

    If the sequence is longer than ``window``, a contiguous segment of
    ``window`` samples is chosen uniformly at random (every start position,
    including the last one, is possible). The x-axis shows the original
    sample indices of the plotted segment.

    Parameters:
    -----------
    model : HMMGLMModel
        Fitted HMM-GLM model
    X_hmm : ndarray, shape (n_samples, n_hmm_features)
        Feature matrix for the HMM component
    X_glm : ndarray, shape (n_samples, n_glm_features)
        Feature matrix for the GLM component
    y : array-like, shape (n_samples,)
        True labels
    contexts : ndarray, shape (n_samples, n_contexts), optional
        Context variables for context-aware transitions
    window : int, optional
        Number of samples to plot
    title : str, optional
        Plot title
    figsize : tuple, optional
        Figure size
    random_state : int, numpy.random.Generator or None, optional
        Seed or generator used to choose the segment. ``None`` (default)
        draws from NumPy's global random state.

    Returns:
    --------
    fig : matplotlib.figure.Figure
        Figure object
    ax : matplotlib.axes.Axes
        Axes object
    """
    _, states = model.predict(X_hmm, X_glm, contexts)
    states = np.asarray(states)
    y = np.asarray(y)

    start_idx = 0
    n_samples = len(states)
    if n_samples > window:
        # The upper bound is exclusive, so +1 makes the final segment
        # (start_idx == n_samples - window) reachable as well.
        high = n_samples - window + 1
        if random_state is None:
            start_idx = np.random.randint(0, high)
        else:
            start_idx = int(np.random.default_rng(random_state).integers(0, high))
        end_idx = start_idx + window
        states = states[start_idx:end_idx]
        y = y[start_idx:end_idx]

    fig, ax = plt.subplots(figsize=figsize)

    # Plot against the original sample indices, so the "Time" axis tells
    # the reader where in the full sequence this segment sits.
    ax.plot(start_idx + np.arange(len(states)), states, 'b-', label='State')
    ax.scatter(start_idx + np.arange(len(y)), y, color='r', alpha=0.5,
               label='Outcome')

    ax.set_title(title or 'State Sequence and Outcomes')
    ax.set_xlabel('Time')
    ax.set_ylabel('State / Outcome')
    ax.legend()

    # States and outcomes share the y-axis; ticks are labelled by state.
    n_states = model.hmm_component.n_states
    ax.set_yticks(np.arange(n_states))
    ax.set_yticklabels([f'State {i+1}' for i in range(n_states)])

    plt.tight_layout()

    return fig, ax

def plot_delta_log_likelihood_distribution(delta_ll_values, title=None, figsize=(10, 6)):
    """
    Plot a histogram of delta log-likelihood values.

    A dashed red line marks 0 (no improvement) and a solid green line marks
    the mean. The default title reports the percentage of strictly positive
    values.

    Parameters:
    -----------
    delta_ll_values : array-like, shape (n_samples,)
        Delta log-likelihood values (lists are accepted)
    title : str, optional
        Plot title
    figsize : tuple, optional
        Figure size

    Returns:
    --------
    fig : matplotlib.figure.Figure
        Figure object
    ax : matplotlib.axes.Axes
        Axes object
    """
    # Coerce to an array: the elementwise ``> 0`` below raises TypeError on a
    # plain Python list.
    delta_ll_values = np.asarray(delta_ll_values, dtype=float)

    mean_delta_ll = np.mean(delta_ll_values)
    pct_positive = np.mean(delta_ll_values > 0) * 100

    fig, ax = plt.subplots(figsize=figsize)

    ax.hist(delta_ll_values, bins=30, color='skyblue', edgecolor='black')

    # Reference lines: zero (no improvement) and the sample mean.
    ax.axvline(x=0, color='red', linestyle='--', label='No Improvement')
    ax.axvline(x=mean_delta_ll, color='green', linestyle='-',
               label=f'Mean: {mean_delta_ll:.3f}')

    ax.set_title(title or f'Delta Log-Likelihood Distribution ({pct_positive:.1f}% Positive)')
    ax.set_xlabel('Delta Log-Likelihood')
    ax.set_ylabel('Count')
    ax.legend()

    plt.tight_layout()

    return fig, ax

def plot_state_goal_probabilities(model, X_hmm, X_glm, y, contexts=None,
                                  title=None, figsize=(10, 6)):
    """
    Plot the empirical goal rate (mean of ``y``) for each predicted state.

    Each bar is annotated with its rate above the bar and with the number of
    samples assigned to that state at its base. States with no assigned
    samples are shown with rate 0 and ``n=0``.

    Parameters:
    -----------
    model : HMMGLMModel
        Fitted HMM-GLM model
    X_hmm : ndarray, shape (n_samples, n_hmm_features)
        Feature matrix for the HMM component
    X_glm : ndarray, shape (n_samples, n_glm_features)
        Feature matrix for the GLM component
    y : array-like, shape (n_samples,)
        True labels
    contexts : ndarray, shape (n_samples, n_contexts), optional
        Context variables for context-aware transitions
    title : str, optional
        Plot title
    figsize : tuple, optional
        Figure size

    Returns:
    --------
    fig : matplotlib.figure.Figure
        Figure object
    ax : matplotlib.axes.Axes
        Axes object
    """
    _, states = model.predict(X_hmm, X_glm, contexts)
    states = np.asarray(states)
    # Boolean-mask indexing below requires an ndarray (fails on a list).
    y = np.asarray(y)

    n_states = model.hmm_component.n_states
    goal_probs = np.zeros(n_states)
    goal_counts = np.zeros(n_states, dtype=int)

    for state in range(n_states):
        in_state = states == state
        count = int(np.sum(in_state))
        if count:
            goal_probs[state] = np.mean(y[in_state])
            goal_counts[state] = count

    # Leave 20% headroom above the tallest bar. Fall back to 1.0 when every
    # rate is 0, otherwise the limits would be identical (0, 0).
    max_prob = np.max(goal_probs, initial=0.0)
    y_top = max_prob * 1.2 if max_prob > 0 else 1.0

    fig, ax = plt.subplots(figsize=figsize)

    x = np.arange(n_states)
    bars = ax.bar(x, goal_probs, color='skyblue')

    # Offsets are relative to the axis height, so labels stay inside the axes
    # even for small rates. A fixed +0.02 can exceed 1.2 * max_prob.
    for i, v in enumerate(goal_probs):
        ax.text(i, v + 0.02 * y_top, f'{v:.3f}', ha='center')

    for bar, count in zip(bars, goal_counts):
        ax.text(bar.get_x() + bar.get_width() / 2., 0.01 * y_top,
                f'n={int(count)}', ha='center', va='bottom',
                color='black', rotation=90)

    ax.set_title(title or 'Goal Probability by State')
    ax.set_xlabel('State')
    ax.set_ylabel('Goal Probability')
    ax.set_xticks(x)
    ax.set_xticklabels([f'State {i+1}' for i in range(n_states)])
    ax.set_ylim(0, y_top)

    plt.tight_layout()

    return fig, ax


