"""
Plotting functions for LapBoost visualizations.

This module provides functions for visualizing various aspects of the LapBoost models,
including graph structure, confidence distributions, learning curves, and decision boundaries.
"""

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from typing import Optional, Union, List, Dict, Any, Tuple
import networkx as nx
from scipy.sparse import csr_matrix
from sklearn.decomposition import PCA

from lapboost.core.graph import GraphConstructor


def plot_confidence_distribution(
    confidences: np.ndarray,
    true_labels: Optional[np.ndarray] = None,
    predicted_labels: Optional[np.ndarray] = None,
    title: str = "Confidence Distribution",
    bins: int = 20,
    figsize: Tuple[int, int] = (10, 6),
    save_path: Optional[str] = None
) -> plt.Figure:
    """
    Plot the distribution of confidence values for predictions.

    If both ``true_labels`` and ``predicted_labels`` are provided, correct and
    incorrect predictions are drawn as two overlaid histograms that share the
    same bin edges, so their bars line up and can be compared directly.
    Otherwise a single histogram of all confidences is drawn. A dashed red
    vertical line marks the mean of the finite confidence values.

    Parameters
    ----------
    confidences : np.ndarray
        Array of confidence scores (array-likes are converted with
        ``np.asarray``).
    true_labels : np.ndarray, optional
        True labels for coloring by correctness; must have the same length
        as ``confidences``.
    predicted_labels : np.ndarray, optional
        Predicted labels for determining correctness; must have the same
        length as ``confidences``.
    title : str, default="Confidence Distribution"
        Plot title
    bins : int, default=20
        Number of histogram bins. Any value accepted by
        ``np.histogram_bin_edges`` (e.g. an explicit array of edges) also works.
    figsize : tuple, default=(10, 6)
        Figure size
    save_path : str, optional
        Path to save the figure

    Returns
    -------
    matplotlib.figure.Figure
        The figure object

    Raises
    ------
    ValueError
        If ``true_labels`` / ``predicted_labels`` differ in length from
        ``confidences``.
    """
    confidences = np.asarray(confidences)
    # NaN/inf would make the automatic bin range and the mean undefined.
    finite_conf = confidences[np.isfinite(confidences)]

    split_by_correctness = true_labels is not None and predicted_labels is not None
    if split_by_correctness:
        true_labels = np.asarray(true_labels)
        predicted_labels = np.asarray(predicted_labels)
        # Validate before creating the figure so no figure is leaked on error.
        if not len(true_labels) == len(predicted_labels) == len(confidences):
            raise ValueError(
                "confidences, true_labels and predicted_labels must have the "
                f"same length, got {len(confidences)}, {len(true_labels)} "
                f"and {len(predicted_labels)}"
            )
        correct = true_labels == predicted_labels
        # Shared edges: letting each histplot call choose its own bins would
        # give the two histograms different bar widths and positions.
        edges = np.histogram_bin_edges(finite_conf, bins=bins)

    fig, ax = plt.subplots(figsize=figsize)

    if split_by_correctness:
        sns.histplot(confidences[correct], bins=edges, alpha=0.7,
                     label="Correct predictions", ax=ax)
        sns.histplot(confidences[~correct], bins=edges, alpha=0.7,
                     label="Incorrect predictions", ax=ax)
    else:
        sns.histplot(confidences, bins=bins, ax=ax)

    ax.set_xlabel("Confidence")
    ax.set_ylabel("Count")
    ax.set_title(title)

    # Vertical line for the mean confidence (only when there is something to average).
    if finite_conf.size:
        mean_conf = finite_conf.mean()
        ax.axvline(mean_conf, color='red', linestyle='--',
                   label=f"Mean: {mean_conf:.3f}")

    # Avoid matplotlib's "no artists with labels" warning on an empty plot.
    if ax.get_legend_handles_labels()[0]:
        ax.legend()

    fig.tight_layout()

    if save_path:
        fig.savefig(save_path, dpi=300, bbox_inches='tight')

    return fig


def plot_graph_structure(
    graph_constructor: GraphConstructor,
    features: np.ndarray,
    labels: Optional[np.ndarray] = None,
    n_samples: int = 200,
    dim_reduction: str = "pca",
    title: str = "k-NN Graph Structure",
    figsize: Tuple[int, int] = (10, 8),
    save_path: Optional[str] = None
) -> plt.Figure:
    """
    Visualize the k-NN graph structure used in LapBoost.

    If there are more than ``n_samples`` points, a random subset (fixed seed
    42) is drawn together with the corresponding sub-block of the adjacency
    matrix. Points are projected to 2-D for layout when they have more than two
    features. Edge transparency is mapped linearly from edge weight into
    ``[0.2, 1.0]`` (a constant 0.5 when all weights are equal).

    Parameters
    ----------
    graph_constructor : GraphConstructor
        Fitted graph constructor object (must expose ``adjacency_matrix_``,
        which is passed to ``nx.from_scipy_sparse_array``).
    features : np.ndarray
        Input features used to build the graph, one row per node.
    labels : np.ndarray, optional
        Labels for coloring nodes
    n_samples : int, default=200
        Maximum number of samples to visualize
    dim_reduction : str, default="pca"
        Dimensionality reduction method ("pca" or "tsne", case-insensitive)
    title : str, default="k-NN Graph Structure"
        Plot title
    figsize : tuple, default=(10, 8)
        Figure size
    save_path : str, optional
        Path to save the figure

    Returns
    -------
    matplotlib.figure.Figure
        The figure object

    Raises
    ------
    ValueError
        If the graph constructor is not fitted or ``dim_reduction`` is not
        one of "pca" / "tsne".
    """
    # Local import, following this module's existing local-import style.
    import matplotlib.colors as mcolors

    if not hasattr(graph_constructor, 'adjacency_matrix_'):
        raise ValueError("GraphConstructor must be fitted before visualization")

    method = dim_reduction.lower()
    if method not in ("pca", "tsne"):
        raise ValueError(
            f"dim_reduction must be 'pca' or 'tsne', got {dim_reduction!r}"
        )

    features = np.asarray(features)
    adjacency = graph_constructor.adjacency_matrix_
    labels_arr = None if labels is None else np.asarray(labels)

    # Subsample if needed. labels_sub is bound on every path, including
    # subsampling with labels=None.
    n_total = features.shape[0]
    if n_total > n_samples:
        rng = np.random.RandomState(42)
        idx = rng.choice(n_total, n_samples, replace=False)
        features_sub = features[idx]
        adjacency_sub = adjacency[idx][:, idx]
        labels_sub = None if labels_arr is None else labels_arr[idx]
    else:
        features_sub = features
        adjacency_sub = adjacency
        labels_sub = labels_arr

    # Bring features to exactly two columns for the node layout.
    n_dims = features_sub.shape[1]
    if n_dims > 2:
        if method == "pca":
            reducer = PCA(n_components=2, random_state=42)
        else:
            from sklearn.manifold import TSNE
            reducer = TSNE(n_components=2, random_state=42)
        features_2d = reducer.fit_transform(features_sub)
    elif n_dims == 2:
        features_2d = features_sub
    else:
        # Fewer than two features: pad with zeros so every node has (x, y).
        features_2d = np.hstack(
            [features_sub, np.zeros((features_sub.shape[0], 2 - n_dims))]
        )

    G = nx.from_scipy_sparse_array(adjacency_sub)

    # Node positions come from the 2-D coordinates, indexed by row number.
    pos = {i: (features_2d[i, 0], features_2d[i, 1])
           for i in range(features_2d.shape[0])}

    fig, ax = plt.subplots(figsize=figsize)

    # Draw nodes, one call per class when labels are available.
    if labels_sub is not None:
        unique_labels = np.unique(labels_sub)
        colors = plt.cm.rainbow(np.linspace(0, 1, len(unique_labels)))
        for color, label in zip(colors, unique_labels):
            nx.draw_networkx_nodes(
                G, pos,
                nodelist=np.flatnonzero(labels_sub == label).tolist(),
                node_color=[color],
                node_size=50,
                alpha=0.8,
                label=f"Class {label}",
                ax=ax
            )
    else:
        nx.draw_networkx_nodes(
            G, pos,
            node_size=50,
            alpha=0.8,
            ax=ax
        )

    # Draw all edges in a single call; per-edge transparency is encoded
    # in RGBA edge colors instead of issuing one draw call per edge.
    edge_list = list(G.edges(data="weight", default=1.0))
    if edge_list:
        weights = np.array([w for _, _, w in edge_list], dtype=float)
        min_weight = weights.min()
        weight_range = weights.max() - min_weight

        if weight_range > 0:
            alphas = (weights - min_weight) / weight_range * 0.8 + 0.2
        else:
            alphas = np.full(weights.shape, 0.5)

        nx.draw_networkx_edges(
            G, pos,
            edgelist=[(u, v) for u, v, _ in edge_list],
            width=1.0,
            edge_color=[mcolors.to_rgba('gray', float(a)) for a in alphas],
            ax=ax
        )

    ax.set_title(title)
    ax.axis('off')
    # Only classes carry legend labels; skip the legend otherwise to avoid a warning.
    if labels_sub is not None:
        ax.legend()
    fig.tight_layout()

    if save_path:
        fig.savefig(save_path, dpi=300, bbox_inches='tight')

    return fig


def plot_learning_curves(
    history: List[Dict[str, float]],
    metric_keys: Optional[List[str]] = None,
    iterations: Optional[List[int]] = None,
    title: str = "LapBoost Learning Curves",
    figsize: Tuple[int, int] = (10, 6),
    save_path: Optional[str] = None
) -> plt.Figure:
    """
    Plot learning curves from iterative training.

    Each metric is drawn as one line over the iterations. An iteration whose
    history entry lacks a metric contributes NaN, which appears as a gap in
    that line. A metric is plotted as long as it appears in at least one
    history entry.

    Parameters
    ----------
    history : list of dict
        Performance history from each iteration (one dict per iteration).
    metric_keys : list of str, optional
        Specific metrics to plot. Defaults to every key seen in ``history``,
        in first-seen order.
    iterations : list of int, optional
        Iteration numbers (x-axis); must have the same length as ``history``.
        Defaults to ``1..len(history)``.
    title : str, default="LapBoost Learning Curves"
        Plot title
    figsize : tuple, default=(10, 6)
        Figure size
    save_path : str, optional
        Path to save the figure

    Returns
    -------
    matplotlib.figure.Figure
        The figure object

    Raises
    ------
    ValueError
        If ``history`` is empty or ``iterations`` has a different length.
    """
    if not history:
        raise ValueError("History cannot be empty")

    # Ordered union of keys across all entries (dict preserves insertion order),
    # so metrics first recorded in later iterations are not dropped.
    if metric_keys is None:
        metric_keys = list(dict.fromkeys(key for entry in history for key in entry))

    if iterations is None:
        iterations = list(range(1, len(history) + 1))
    elif len(iterations) != len(history):
        raise ValueError(
            f"iterations has {len(iterations)} entries but history has "
            f"{len(history)}"
        )

    fig, ax = plt.subplots(figsize=figsize)

    plotted_any = False
    for metric in metric_keys:
        # Skip metrics that never occur instead of drawing an all-NaN line.
        if not any(metric in entry for entry in history):
            continue

        values = [entry.get(metric, np.nan) for entry in history]
        ax.plot(iterations, values, 'o-', label=metric)
        plotted_any = True

    ax.set_xlabel("Iteration")
    ax.set_ylabel("Metric Value")
    ax.set_title(title)
    # Avoid matplotlib's "no artists with labels" warning when nothing was drawn.
    if plotted_any:
        ax.legend()
    ax.grid(True, alpha=0.3)

    fig.tight_layout()

    if save_path:
        fig.savefig(save_path, dpi=300, bbox_inches='tight')

    return fig


def plot_decision_boundary(
    model,
    X: np.ndarray,
    y: Optional[np.ndarray] = None,
    mesh_step_size: float = 0.02,
    title: str = "Decision Boundary",
    feature_indices: Optional[Tuple[int, int]] = None,
    figsize: Tuple[int, int] = (10, 8),
    save_path: Optional[str] = None
) -> plt.Figure:
    """
    Plot the decision boundary of a LapBoost classifier.

    The 2-D plotting plane is chosen as follows:

    * ``feature_indices`` given: those two features are used. Every other
      feature is held at its median over ``X`` when the mesh is predicted.
    * otherwise, more than two features: a 2-D PCA projection
      (``random_state=42``) is used. Mesh points are mapped back to the
      original space with ``PCA.inverse_transform`` before prediction.
    * otherwise: the two features of ``X`` are used directly.

    If prediction fails in the PCA case, a ``RuntimeWarning`` is emitted and
    only the data points are drawn.

    Parameters
    ----------
    model : LapBoostClassifier
        Trained model exposing ``predict``.
    X : np.ndarray
        Features, shape ``(n_samples, n_features)`` with ``n_features >= 2``.
    y : np.ndarray, optional
        True labels used to color the data points.
    mesh_step_size : float, default=0.02
        Step size for the mesh grid. The number of predicted points grows with
        ``(range / mesh_step_size) ** 2``, so increase it for wide-ranged data.
    title : str, default="Decision Boundary"
        Plot title
    feature_indices : tuple, optional
        Indices of two features to use for plotting
    figsize : tuple, default=(10, 8)
        Figure size
    save_path : str, optional
        Path to save the figure

    Returns
    -------
    matplotlib.figure.Figure
        The figure object

    Raises
    ------
    ValueError
        If ``X`` is not 2-D with at least two features, or ``feature_indices``
        does not contain exactly two indices.
    """
    import warnings

    X = np.asarray(X)
    if X.ndim != 2 or X.shape[1] < 2:
        raise ValueError("X must be a 2-D array with at least two features")

    use_pca = feature_indices is None and X.shape[1] > 2

    # Choose the 2-D plotting coordinates.
    if use_pca:
        pca = PCA(n_components=2, random_state=42)
        X_2d = pca.fit_transform(X)
        feature_names = ["PCA Component 1", "PCA Component 2"]
    else:
        cols = [0, 1] if feature_indices is None else [int(i) for i in feature_indices]
        if len(cols) != 2:
            raise ValueError(
                f"feature_indices must contain exactly two indices, got {len(cols)}"
            )
        X_2d = X[:, cols]
        feature_names = [f"Feature {cols[0]}", f"Feature {cols[1]}"]

    # Mesh grid covering the data with a 0.5 margin on each side.
    x_min, x_max = X_2d[:, 0].min() - 0.5, X_2d[:, 0].max() + 0.5
    y_min, y_max = X_2d[:, 1].min() - 0.5, X_2d[:, 1].max() + 0.5
    xx, yy = np.meshgrid(
        np.arange(x_min, x_max, mesh_step_size),
        np.arange(y_min, y_max, mesh_step_size)
    )
    grid_2d = np.c_[xx.ravel(), yy.ravel()]

    # Predict a class for every mesh point in the model's input space.
    if use_pca:
        try:
            Z = model.predict(pca.inverse_transform(grid_2d))
        except Exception as exc:  # best-effort: still show the data points
            warnings.warn(
                f"Could not compute decision boundary: {exc}",
                RuntimeWarning,
                stacklevel=2,
            )
            Z = None
    else:
        # Hold unplotted features at their training medians so the slice
        # passes through the bulk of the data (with exactly two features the
        # grid overwrites every column, so this reduces to the grid itself).
        mesh_points = np.tile(np.median(X, axis=0), (grid_2d.shape[0], 1))
        mesh_points[:, cols] = grid_2d
        Z = model.predict(mesh_points)

    fig, ax = plt.subplots(figsize=figsize)

    # Plot decision regions if available.
    if Z is not None:
        Z = np.asarray(Z).reshape(xx.shape)
        if not np.issubdtype(Z.dtype, np.number):
            # contourf needs numeric values; map class labels (e.g. strings)
            # to integer codes.
            Z = np.unique(Z.ravel(), return_inverse=True)[1].reshape(xx.shape)
        ax.contourf(xx, yy, Z, alpha=0.3, cmap=plt.cm.coolwarm)
        ax.contour(xx, yy, Z, colors='black', linewidths=0.5, alpha=0.5)

    # Plot the data points, one scatter per class when labels are given.
    if y is not None:
        y = np.asarray(y)
        for cls in np.unique(y):
            mask = y == cls
            ax.scatter(
                X_2d[mask, 0], X_2d[mask, 1],
                alpha=0.8,
                label=f"Class {cls}",
                edgecolor='k'
            )
    else:
        ax.scatter(X_2d[:, 0], X_2d[:, 1], alpha=0.8, edgecolor='k')

    ax.set_xlim(xx.min(), xx.max())
    ax.set_ylim(yy.min(), yy.max())
    ax.set_xlabel(feature_names[0])
    ax.set_ylabel(feature_names[1])
    ax.set_title(title)
    # Only class scatters carry labels; skip the legend otherwise.
    if y is not None:
        ax.legend()

    fig.tight_layout()

    if save_path:
        fig.savefig(save_path, dpi=300, bbox_inches='tight')

    return fig
