import matplotlib.pyplot as plt
import seaborn as sns
from typing import Tuple, List, Optional
import numpy as np
from .mixture_gen import Component
import scipy.stats as stats


def plot_mixture_data(
    X: np.ndarray,
    y: np.ndarray,
    components: List[Component],
    feature_idx: int = 0,
    figsize: Tuple[int, int] = (12, 6),
) -> None:
    """
    Plot the generated mixture data.

    Draws a two-panel figure and shows it with ``plt.show()``:

    * left: scatter of ``X[:, feature_idx]`` against ``y``. For every
      component that has a rule on ``feature_idx`` the rule interval is
      marked with dashed boundary lines and a shaded band; components whose
      ``distribution`` is ``"normal"`` also get a horizontal segment at
      ``dist_params["loc"]`` spanning that interval.
    * right: kernel density estimate of ``y`` with, for every ``"normal"``
      component, its weighted normal pdf overlaid.

    Args:
        X: Feature array of shape (n_samples, n_features)
        y: Target array of shape (n_samples, 1)
        components: List of Component objects
        feature_idx: Index of feature to plot against y
        figsize: Figure size (width, height)
    """
    # plt.subplots creates the figure itself; calling plt.figure() first
    # would only leave an additional blank figure behind.
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=figsize)

    # Scatter plot of X vs y
    ax1.scatter(X[:, feature_idx], y, alpha=0.5, s=5)
    ax1.set_xlabel(f"X{feature_idx}")
    ax1.set_ylabel("y")
    ax1.set_title("Scatter Plot: X vs y")

    # One color per component, shared by both panels
    colors = plt.cm.rainbow(np.linspace(0, 1, len(components)))

    # Add component rules visualization
    for comp, color in zip(components, colors):
        if feature_idx not in comp.rules:
            continue

        lower, upper = comp.rules[feature_idx]
        ax1.axvline(x=lower, color=color, linestyle="--", alpha=0.5)
        ax1.axvline(x=upper, color=color, linestyle="--", alpha=0.5)
        # Add shaded region for component
        ax1.axvspan(lower, upper, alpha=0.1, color=color)

        # Add distribution mean if normal
        if comp.distribution == "normal":
            # Draw the segment in data coordinates. axhline's xmin/xmax are
            # axes fractions, which stop matching [lower, upper] as soon as
            # the x-limits change after the call (e.g. when a later
            # component's boundary lines trigger autoscaling).
            loc = comp.dist_params["loc"]
            ax1.plot(
                [lower, upper],
                [loc, loc],
                color=color,
                linestyle="-",
                alpha=0.5,
            )

    # Density plot of y
    sns.kdeplot(data=y.flatten(), ax=ax2)
    ax2.set_xlabel("y")
    ax2.set_title("Density Plot of y")

    # Add component distributions to density plot
    x_range = np.linspace(y.min(), y.max(), 100)
    for comp, color in zip(components, colors):
        if comp.distribution != "normal":
            continue
        loc = comp.dist_params["loc"]
        scale = comp.dist_params["scale"]
        # Scale by the component weight so the curve shows the component's
        # contribution to the mixture rather than a unit-area density.
        density = comp.weight * stats.norm.pdf(x_range, loc, scale)
        ax2.plot(
            x_range,
            density,
            color=color,
            linestyle="--",
            alpha=0.5,
            label=f"Component {loc:.1f} ± {scale:.1f}",
        )

    # Only build a legend when at least one labelled curve was drawn;
    # otherwise matplotlib emits a "no artists with labels" warning.
    handles, _ = ax2.get_legend_handles_labels()
    if handles:
        ax2.legend()

    plt.tight_layout()
    plt.show()


def plot_component_placement(
    components: List[Component],
    X: Optional[np.ndarray] = None,
    y: Optional[np.ndarray] = None,
    feature_dims: Tuple[int, int] = (0, 1),
    figsize: Tuple[int, int] = (10, 10),
    title: str = "Component Placement",
    show_rules: bool = True,
    show_centers: bool = True,
    alpha: float = 0.2,
    colormap: str = "rainbow",
    ax: Optional[plt.Axes] = None,
) -> Optional[plt.Axes]:
    """
    Visualize the placement of components in a 2D feature space.

    Each component that has a rule on both plotted dimensions is drawn as a
    filled rectangle with an outline (``show_rules``) and/or a numbered
    center marker (``show_centers``). Components lacking a rule on either
    dimension are skipped. The axes limits are fixed to ``[-0.1, 1.1]`` on
    both axes.

    When ``ax`` is None a new figure is created, a colorbar (if ``y`` is
    given) and a legend (if ``show_rules`` and at most 10 components) are
    added, and the figure is shown with ``plt.show()``. When ``ax`` is
    supplied the caller keeps control of the figure: no colorbar, legend,
    layout adjustment or ``show`` is performed.

    Args:
        components: List of Component objects to visualize
        X: Optional feature array of shape (n_samples, n_features)
        y: Optional target array of shape (n_samples,); used to color the
            scatter points. Ignored when X is None.
        feature_dims: Which two feature dimensions to plot (default: first two dimensions)
        figsize: Figure size (width, height) - only used if ax is None
        title: Plot title
        show_rules: Whether to show component rule boundaries
        show_centers: Whether to show component centers
        alpha: Transparency level for component regions
        colormap: Matplotlib colormap name
        ax: Optional matplotlib Axes to plot on

    Returns:
        The matplotlib Axes object if ax was provided, otherwise None
    """
    created_fig = ax is None
    if created_fig:
        plt.figure(figsize=figsize)
        ax = plt.gca()

    # Extract the 2D dimensions we're plotting
    dim1, dim2 = feature_dims

    # Plot data points if provided
    if X is not None:
        if y is not None:
            # Color by target value if y is provided
            scatter = ax.scatter(
                X[:, dim1],
                X[:, dim2],
                c=y,
                alpha=0.6,
                s=15,
                cmap='viridis',
                edgecolors='w',
                linewidths=0.5,
            )
            if created_fig:
                plt.colorbar(scatter, ax=ax, label='Target value')
        else:
            # Single color if no y
            ax.scatter(X[:, dim1], X[:, dim2], alpha=0.6, s=15, c='gray', edgecolors='w')

    # Create color map for components. plt.get_cmap is used instead of
    # plt.cm.get_cmap, which was deprecated in Matplotlib 3.7 and removed
    # in 3.9.
    colors = plt.get_cmap(colormap)(np.linspace(0, 1, len(components)))

    # Plot component regions and boundaries
    for i, (comp, color) in enumerate(zip(components, colors)):
        if dim1 not in comp.rules or dim2 not in comp.rules:
            # Nothing to draw without a rule on both plotted dimensions.
            continue

        # Get rule ranges for both dimensions
        x_min, x_max = comp.rules[dim1]
        y_min, y_max = comp.rules[dim2]

        if show_rules:
            # Plot rectangle for the component's 2D region
            rect = plt.Rectangle(
                (x_min, y_min),
                x_max - x_min,
                y_max - y_min,
                fill=True,
                alpha=alpha,
                color=color,
                label=f"Component {i+1}",
            )
            ax.add_patch(rect)

            # Add border lines (opaque outline on top of the translucent fill)
            ax.plot(
                [x_min, x_max, x_max, x_min, x_min],
                [y_min, y_min, y_max, y_max, y_min],
                color=color,
                linewidth=1.5,
            )

        if show_centers:
            # Calculate and plot center, labelled with the 1-based index
            center_x = (x_min + x_max) / 2
            center_y = (y_min + y_max) / 2
            ax.plot(center_x, center_y, 'o', color=color, markersize=8)
            ax.text(
                center_x,
                center_y,
                str(i+1),
                ha='center',
                va='center',
                color='white',
                fontsize=8,
                fontweight='bold',
            )

    # Set plot limits slightly larger than the feature space
    # NOTE(review): assumes features live in [0, 1] - confirm against the generator.
    ax.set_xlim(-0.1, 1.1)
    ax.set_ylim(-0.1, 1.1)

    # Add grid, labels, and title
    ax.grid(True, linestyle='--', alpha=0.7)
    ax.set_xlabel(f"Feature {dim1}")
    ax.set_ylabel(f"Feature {dim2}")
    ax.set_title(title)

    if not created_fig:
        return ax

    # Only add legend if we created the figure, there are not too many
    # components, and at least one labelled rectangle was actually drawn
    # (avoids matplotlib's "no artists with labels" warning).
    if show_rules and len(components) <= 10:
        handles, _ = ax.get_legend_handles_labels()
        if handles:
            ax.legend(loc='upper right', bbox_to_anchor=(1.1, 1.1))
    plt.tight_layout()
    plt.show()
    return None


def showcase_component_placement(
    n_features: int = 2,
    distributions: List[str] = ["normal"],
    n_components_list: List[int] = [5, 10],
    strategies: List[str] = ["random", "grid", "poisson_disk"],
    spacing_factors: List[float] = [1.2, 2.0],
    base_size: float = 0.5,
    random_seed: int = 42,
    figsize: Tuple[int, int] = (15, 12),
) -> None:
    """
    Showcase different parameters of the distributed component generator.

    This function creates a grid of plots showing how different parameters
    affect component placement in 2D feature space. Rows correspond to
    ``spacing_factors``; columns enumerate every (strategy, component count)
    pair, grouped by strategy. Each panel is generated with a distinct seed
    (``random_seed`` plus the panel's running index) and drawn with
    ``plot_component_placement``. The figure is shown with ``plt.show()``.

    Args:
        n_features: Number of features (should be 2 for visualization);
            any other value prints a warning and is replaced by 2
        distributions: List of distribution types
        n_components_list: List of different component counts to showcase
        strategies: List of placement strategies to compare
        spacing_factors: List of spacing factors to compare
        base_size: Base size for components
        random_seed: Random seed for reproducibility
        figsize: Figure size (width, height)
    """
    from ..mixture.mixture_gen_controlled import distributed_component_generator

    if n_features != 2:
        print("Warning: n_features should be 2 for visualization purposes")
        n_features = 2

    # Create subplot grid. squeeze=False guarantees a 2D array of Axes even
    # when the grid has a single row and/or a single column, so the
    # axes[r, col_idx] lookup below is always valid.
    n_rows = len(spacing_factors)
    n_cols = len(strategies) * len(n_components_list)
    fig, axes = plt.subplots(n_rows, n_cols, figsize=figsize, squeeze=False)

    # Iterate through all combinations
    plot_idx = 0
    for r, spacing in enumerate(spacing_factors):
        for s, strategy in enumerate(strategies):
            for c, n_comps in enumerate(n_components_list):
                col_idx = s * len(n_components_list) + c
                ax = axes[r, col_idx]

                # Generate components with current parameters. A copy of
                # `distributions` is passed so the shared default list can
                # never be altered by the callee.
                components = distributed_component_generator(
                    n_components=n_comps,
                    n_features=n_features,
                    distributions=list(distributions),
                    base_size=base_size,
                    spacing_factor=spacing,
                    placement_strategy=strategy,
                    adaptive_sizing=True,
                    vary_size=False,
                    random_seed=random_seed + plot_idx,
                )

                # Plot the components
                title = f"{strategy.capitalize()}, {n_comps} comps\nSpacing={spacing}"
                plot_component_placement(
                    components=components,
                    feature_dims=(0, 1),
                    title=title,
                    alpha=0.3,
                    ax=ax,
                    show_rules=True,
                    show_centers=True,
                )

                # Increment plot index (also varies the seed per panel)
                plot_idx += 1

    # Add a title for the entire figure
    fig.suptitle(
        f"Component Placement Comparison (base_size={base_size})",
        fontsize=16,
        y=0.99,
    )

    # Adjust layout to make room for suptitle
    fig.tight_layout(rect=[0, 0, 1, 0.97])
    plt.show()


def plot_feature_distributions(
    X: np.ndarray, n_noise_features: int = 0, figsize: Optional[Tuple[int, int]] = None
) -> None:
    """
    Draw one histogram per feature column and show the figure.

    Panels are laid out on a grid that is at most four columns wide. The
    last ``n_noise_features`` columns are titled "Noise", all earlier ones
    "Informative". Grid cells without a feature are removed.

    Args:
        X: Feature array of shape (n_samples, n_features)
        n_noise_features: Number of noise features
        figsize: Figure size (width, height); when None, 4 units of width
            per grid column and 3 units of height per grid row are used
    """
    n_features = X.shape[1]
    first_noise_idx = n_features - n_noise_features

    # Grid dimensions: up to 4 columns, as many rows as needed
    n_cols = min(4, n_features)
    n_rows = (n_features + 3) // 4

    if figsize is None:
        figsize = (4 * n_cols, 3 * n_rows)

    # squeeze=False always yields a 2D array of Axes, so a single flatten
    # covers the one-panel, one-row and multi-row cases alike.
    fig, grid = plt.subplots(n_rows, n_cols, figsize=figsize, squeeze=False)
    panels = grid.flatten()

    for col, panel in enumerate(panels[:n_features]):
        kind = "Informative" if col < first_noise_idx else "Noise"
        sns.histplot(data=X[:, col], ax=panel)
        panel.set_title(f"{kind} Feature {col}")

    # Remove empty subplots
    for unused in panels[n_features:]:
        fig.delaxes(unused)

    plt.tight_layout()
    plt.show()
