import numpy as np
import matplotlib.pyplot as plt
import os
import itertools
import math


def setup_plot_environment(debug_dir):
    """Apply the shared Matplotlib style and ensure the output directory exists.

    ``plt.style.use`` updates Matplotlib's global rcParams, so the style stays
    in effect for every figure created afterwards in this process.

    Args:
        debug_dir (str): Directory the plots will be written to. It is created,
            along with any missing parent directories, if it does not exist.

    Returns:
        str: The same ``debug_dir`` that was passed in.
    """
    style_name = 'seaborn-v0_8-whitegrid'
    plt.style.use(style_name)

    # exist_ok=True makes repeated calls harmless.
    os.makedirs(debug_dir, exist_ok=True)
    return debug_dir


def get_colors(n):
    """Generate ``n`` distinct colours for line plots.

    Up to 10 colours come from Matplotlib's qualitative ``tab10`` colormap.
    For more, colours are sampled evenly from the ``hsv`` colormap at
    ``i / n`` for ``i in range(n)``. Sampling stops short of 1.0, which is
    the same hue as 0.0 on this cyclic map.

    The colormaps are fetched with ``plt.get_cmap``. The previous
    ``plt.cm.get_cmap`` was deprecated in Matplotlib 3.7 and removed in 3.9.

    Args:
        n (int): Number of colours needed.

    Returns:
        list: ``n`` RGBA tuples (floats in [0, 1]).
    """
    if n <= 10:
        tab10 = plt.get_cmap('tab10')
        return [tab10(i) for i in range(n)]
    hsv = plt.get_cmap('hsv')
    return [hsv(i / n) for i in range(n)]


def get_contribution_colors(n):
    """Generate ``n`` colours for the stacked contribution plot.

    For ``n <= 6`` a fixed list of dark, muted hex colours is returned (its
    first ``n`` entries). For larger ``n`` the 8-colour qualitative ``Dark2``
    colormap is used, so colours repeat every 8 regressors. Both kinds of
    value are accepted by Matplotlib's ``color=`` arguments.

    Args:
        n (int): Number of colours needed.

    Returns:
        list: Hex colour strings when ``n <= 6``, otherwise RGBA tuples.
    """
    if n <= 6:
        # Hand-picked palette for the common small case.
        return [
            '#1A5276',  # Dark blue
            '#641E16',  # Dark red
            '#186A3B',  # Dark green
            '#7D3C98',  # Dark purple
            '#2E4053',  # Dark slate
            '#784212'   # Dark brown
        ][:n]

    # plt.get_cmap replaces plt.cm.get_cmap (removed in Matplotlib 3.9).
    dark_cmap = plt.get_cmap('Dark2')
    return [dark_cmap(i % dark_cmap.N) for i in range(n)]


def plot_predictions(G, y, coef, g_pred, debug_dir):
    """Overlay ground truth, each model's predictions and the aggregate.

    Every column of ``G`` gets its own colour (from ``get_colors``), a
    cycling line style and a cycling marker. Its legend entry shows the
    model's coefficient. Ground truth is drawn in red and the aggregated
    prediction in black. The figure is written to
    ``{debug_dir}/aggregator_predictions.png`` at 300 dpi and then closed.

    Args:
        G (np.ndarray): Individual model predictions, shape (n_samples, n_models).
        y (np.ndarray): Ground-truth values, shape (n_samples,).
        coef (np.ndarray): Per-model coefficients, shape (n_models,).
        g_pred (np.ndarray): Aggregated predictions, shape (n_samples,).
        debug_dir (str): Output directory for the PNG.

    Returns:
        None
    """
    n_models = G.shape[1]
    palette = get_colors(n_models)
    marker_cycle = itertools.cycle(['o', 's', '^', 'D', 'v', '<', '>', 'p', '*', 'h'])
    style_cycle = itertools.cycle(['-', '--', '-.', ':'])
    idx = np.arange(len(y))

    plt.figure(figsize=(12, 6), facecolor='white')

    # Ground truth: solid red line with small dots on every sample.
    plt.plot(idx, y, 'r-', linewidth=2.5, label='Ground Truth (y)')
    plt.plot(idx, y, 'ro', markersize=2)

    # One line plus a marker overlay per model. Styles and markers cycle
    # independently of the colour palette.
    for k, (shade, marker, ls) in enumerate(zip(palette, marker_cycle, style_cycle)):
        column = G[:, k]
        plt.plot(idx, column, ls, color=shade, linewidth=2,
                 label=f'Reg{k+1} (coef={coef[k]:.3f})')
        plt.plot(idx, column, marker, color=shade, markersize=2)

    # Aggregated prediction: solid black line with dots.
    plt.plot(idx, g_pred, 'k-', linewidth=2.5, label='Aggregated (g_pred)')
    plt.plot(idx, g_pred, 'ko', markersize=2)

    plt.xlabel('Sample Index', fontsize=12)
    plt.ylabel('Value', fontsize=12)
    plt.title('Model Predictions vs Ground Truth', fontsize=14, pad=10)
    plt.xticks(idx, fontsize=10)
    plt.yticks(fontsize=10)
    plt.legend(frameon=True, facecolor='white', edgecolor='lightgray', fontsize=10)
    plt.tight_layout()
    plt.savefig(f"{debug_dir}/aggregator_predictions.png", dpi=300, bbox_inches='tight')
    plt.close()


def plot_residuals(G, y, g_pred, debug_dir):
    """Plot the aggregated residuals ``y - g_pred`` with +/- RMSE bands.

    ``G`` is accepted for signature parity with the other plotting helpers
    but is not used. The residual series is drawn in purple over a faint
    zero line, with dashed red lines at +RMSE and -RMSE. The figure is
    written to ``{debug_dir}/aggregator_residuals.png`` and then closed.

    Args:
        G (np.ndarray): Individual model predictions (unused).
        y (np.ndarray): Ground-truth values, shape (n_samples,).
        g_pred (np.ndarray): Aggregated predictions, shape (n_samples,).
        debug_dir (str): Output directory for the PNG.

    Returns:
        None
    """
    idx = np.arange(len(y))
    resid = y - g_pred
    rmse = np.sqrt(np.mean(resid**2))

    plt.figure(figsize=(12, 6), facecolor='white')

    plt.plot(idx, resid, 'purple', linewidth=2.5, label='Residuals (y - g_pred)')
    plt.plot(idx, resid, 'o', color='purple', markersize=2)
    plt.axhline(y=0, color='k', linestyle='-', alpha=0.3, linewidth=1.5)

    # Dashed red band at +/- RMSE; only the upper line gets a legend entry.
    band_style = dict(color='r', linestyle='--', alpha=0.7, linewidth=1.5)
    plt.axhline(y=rmse, label=f'RMSE: {rmse:.3f}', **band_style)
    plt.axhline(y=-rmse, **band_style)

    plt.xlabel('Sample Index', fontsize=12)
    plt.ylabel('Residual Value', fontsize=12)
    plt.title('Prediction Residuals', fontsize=14, pad=10)
    plt.xticks(idx, fontsize=10)
    plt.yticks(fontsize=10)
    plt.legend(frameon=True, facecolor='white', edgecolor='lightgray', fontsize=10)
    plt.tight_layout()
    plt.savefig(f"{debug_dir}/aggregator_residuals.png", dpi=300, bbox_inches='tight')
    plt.close()


def plot_contributions(G, y, coef, g_pred, debug_dir):
    """Plot each model's weighted contribution as a stacked area chart.

    Model ``k`` contributes ``G[:, k] * coef[k]``. The contributions are
    stacked in column order, each band running from the running total so
    far to that total plus the model's contribution. Ground truth (red) and
    the aggregated prediction (black) are overlaid. A box at the bottom of
    the axes lists all coefficients. The figure is written to
    ``{debug_dir}/aggregator_contributions.png`` and then closed.

    Args:
        G (np.ndarray): Individual model predictions, shape (n_samples, n_models).
        y (np.ndarray): Ground-truth values, shape (n_samples,).
        coef (np.ndarray): Per-model coefficients, shape (n_models,).
        g_pred (np.ndarray): Aggregated predictions, shape (n_samples,).
        debug_dir (str): Output directory for the PNG.

    Returns:
        None
    """
    n_models = G.shape[1]
    idx = np.arange(len(y))

    plt.figure(figsize=(12, 6), facecolor='white')

    # Weighted contribution of every model, one column per model.
    contrib = np.zeros((len(y), n_models))
    for k in range(n_models):
        contrib[:, k] = G[:, k] * coef[k]

    palette = get_contribution_colors(n_models)

    # Stack bands on top of each other, carrying the running total forward.
    lower = np.zeros(len(y))
    for k in range(n_models):
        upper = lower + contrib[:, k]
        shade = palette[k]
        plt.fill_between(idx, lower, upper,
                         color=shade, alpha=0.5,
                         label=f'Reg{k+1} Contribution ({coef[k]:.3f})')
        # Thin edge lines on both band boundaries for visual separation.
        plt.plot(idx, lower, color=shade, linewidth=1.5, alpha=0.8)
        plt.plot(idx, upper, color=shade, linewidth=1.5, alpha=0.8)
        lower = upper

    # Reference series drawn on top of the stack.
    plt.plot(idx, y, 'r-', linewidth=2.5, label='Ground Truth (y)')
    plt.plot(idx, y, 'ro', markersize=2)
    plt.plot(idx, g_pred, 'k-', linewidth=2.5, label='Aggregated (g_pred)')
    plt.plot(idx, g_pred, 'ko', markersize=2)

    plt.xlabel('Sample Index', fontsize=12)
    plt.ylabel('Value', fontsize=12)
    plt.title('Model Contributions', fontsize=14, pad=10)
    plt.xticks(idx, fontsize=10)
    plt.yticks(fontsize=10)

    # Coefficient summary box, 1-based model numbering.
    coef_text = ", ".join(f"{k}: {c:.3f}" for k, c in enumerate(coef, start=1))
    plt.annotate(f'Coefficients: [{coef_text}]',
                 xy=(0.5, 0.01), xycoords='axes fraction',
                 ha='center', va='bottom',
                 bbox=dict(boxstyle="round,pad=0.5", fc="white", ec="lightgray", alpha=0.8))

    plt.legend(frameon=True, facecolor='white', edgecolor='lightgray', fontsize=10)
    plt.tight_layout()
    plt.savefig(f"{debug_dir}/aggregator_contributions.png", dpi=300, bbox_inches='tight')
    plt.close()


def plot_individual_predictions_grid(G, y, coef, g_pred, debug_dir):
    """Draw one subplot per model comparing it with truth and the aggregate.

    Each panel shows a single model's predictions in its own colour, the
    ground truth in red and the aggregated prediction as a black dashed
    line. The panel title reports that model's RMSE against ``y``. Panels
    are laid out in a near-square grid of at most 4 columns, and unused
    grid cells are hidden. The figure is written to
    ``{debug_dir}/individual_predictions.png`` and then closed.

    Args:
        G (np.ndarray): Individual model predictions, shape (n_samples, n_models).
        y (np.ndarray): Ground-truth values, shape (n_samples,).
        coef (np.ndarray): Per-model coefficients, shape (n_models,).
        g_pred (np.ndarray): Aggregated predictions, shape (n_samples,).
        debug_dir (str): Output directory for the PNG.

    Returns:
        None
    """
    n_models = G.shape[1]
    idx = np.arange(len(y))
    palette = get_colors(n_models)

    # Only the models get panels here (no separate aggregated panel).
    n_panels = n_models
    if n_panels <= 3:
        n_rows, n_cols = 1, n_panels
    else:
        # First threshold that fits wins; beyond 12 panels use 4 columns.
        for limit, shape in ((4, (2, 2)), (6, (2, 3)), (9, (3, 3)), (12, (3, 4))):
            if n_panels <= limit:
                n_rows, n_cols = shape
                break
        else:
            n_cols = 4
            n_rows = math.ceil(n_panels / n_cols)

    # squeeze=False always returns a 2-D array, so ravel() covers every shape.
    fig, grid = plt.subplots(n_rows, n_cols, figsize=(5 * n_cols, 4 * n_rows),
                             facecolor='white', constrained_layout=True,
                             squeeze=False)
    panels = grid.ravel()

    for k, ax in enumerate(panels[:n_models]):
        shade = palette[k]
        preds = G[:, k]
        err = np.sqrt(np.mean((preds - y) ** 2))

        ax.plot(idx, preds, color=shade, linewidth=2.5,
                label=f'Reg{k+1} (coef={coef[k]:.3f})')
        ax.plot(idx, preds, 'o', color=shade, markersize=2)

        # Ground truth and aggregate drawn on every panel for comparison.
        ax.plot(idx, y, 'r-', linewidth=2, label='Ground Truth', alpha=0.8)
        ax.plot(idx, y, 'ro', markersize=2, alpha=0.6)
        ax.plot(idx, g_pred, 'k--', linewidth=2, label='Aggregated', alpha=0.7)

        ax.set_title(f'Regressor {k+1} (RMSE: {err:.3f})', fontsize=12, pad=8)
        ax.set_ylabel('Value', fontsize=11)
        ax.grid(True, linestyle='--', alpha=0.7)
        ax.legend(loc='best', fontsize=10, frameon=True,
                  facecolor='white', edgecolor='lightgray')
        ax.tick_params(axis='both', which='major', labelsize=10)
        ax.set_xticks(idx)

    for ax in panels[n_models:]:
        ax.set_visible(False)

    fig.supxlabel('Sample Index', fontsize=14, y=-0.05)
    fig.suptitle('Individual Predictions with Ground Truth and Aggregated',
                 fontsize=16, y=1.05)

    fig.savefig(f"{debug_dir}/individual_predictions.png", dpi=300,
                bbox_inches='tight', facecolor='white')
    plt.close(fig)


def plot_residuals_grid(G, y, g_pred, debug_dir):
    """Plot a grid of per-model residuals plus the aggregated residuals.

    Every panel plots residuals as ``y - prediction``. This is the same
    convention as the aggregated residual ``y - g_pred``, so individual and
    aggregated panels can be compared directly. Each panel also shows a
    zero line and dashed lines at +/- that series' RMSE. The aggregated
    panel is always last, and unused grid cells are hidden. The figure is
    written to ``{debug_dir}/aggregator_residual_comparison.png`` and then
    closed.

    Args:
        G (np.ndarray): Matrix of individual model predictions, shape (n_samples, n_models)
        y (np.ndarray): Ground truth values, shape (n_samples,)
        g_pred (np.ndarray): Aggregated predictions, shape (n_samples,)
        debug_dir (str): Directory to save plots

    Returns:
        None
    """
    num_regressors = G.shape[1]
    x = np.arange(len(y))
    colors = get_colors(num_regressors)

    # One panel per regressor plus one for the aggregated prediction.
    total_plots = num_regressors + 1

    # Determine grid dimensions (as square as possible, at most 4 columns).
    if total_plots <= 3:
        n_rows, n_cols = 1, total_plots
    elif total_plots == 4:
        n_rows, n_cols = 2, 2
    elif total_plots <= 6:
        n_rows, n_cols = 2, 3
    elif total_plots <= 9:
        n_rows, n_cols = 3, 3
    elif total_plots <= 12:
        n_rows, n_cols = 3, 4
    else:
        n_cols = 4
        n_rows = math.ceil(total_plots / n_cols)

    # squeeze=False always returns a 2-D array, so ravel() works for any grid.
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(5 * n_cols, 4 * n_rows),
                             facecolor='white', constrained_layout=True,
                             squeeze=False)
    axes = axes.ravel()

    # All residuals are truth minus prediction, so individual and aggregated
    # panels share one sign convention. RMSE is unaffected by the sign.
    indiv_residuals = y.reshape(-1, 1) - G
    indiv_rmse = np.sqrt(np.mean(indiv_residuals**2, axis=0))
    residuals = y - g_pred
    rmse = np.sqrt(np.mean(residuals**2))

    # Individual residual panels.
    for i in range(num_regressors):
        ax = axes[i]
        ax.plot(x, indiv_residuals[:, i], color=colors[i], linewidth=2.5,
                label=f'Reg{i+1} Residuals')
        ax.plot(x, indiv_residuals[:, i], 'o', color=colors[i], markersize=2)
        ax.axhline(y=0, color='k', linestyle='-', alpha=0.3, linewidth=1.5)
        ax.axhline(y=indiv_rmse[i], color=colors[i], linestyle='--', alpha=0.7, linewidth=1.5,
                   label=f'RMSE: {indiv_rmse[i]:.3f}')
        ax.axhline(y=-indiv_rmse[i], color=colors[i], linestyle='--', alpha=0.7, linewidth=1.5)
        ax.set_title(f'Regressor {i+1}', fontsize=12, pad=8)
        ax.set_ylabel('Residual', fontsize=11)
        ax.grid(True, linestyle='--', alpha=0.7)
        ax.legend(loc='best', fontsize=10, frameon=True,
                  facecolor='white', edgecolor='lightgray')
        ax.tick_params(axis='both', which='major', labelsize=10)
        ax.set_xticks(x)

    # Aggregated residual panel (always last).
    ax = axes[num_regressors]
    ax.plot(x, residuals, 'purple', linewidth=2.5, label='Aggregated')
    ax.plot(x, residuals, 'o', color='purple', markersize=2)
    ax.axhline(y=0, color='k', linestyle='-', alpha=0.3, linewidth=1.5)
    ax.axhline(y=rmse, color='r', linestyle='--', alpha=0.7, linewidth=1.5,
               label=f'RMSE: {rmse:.3f}')
    ax.axhline(y=-rmse, color='r', linestyle='--', alpha=0.7, linewidth=1.5)
    ax.set_title('Aggregated Prediction', fontsize=12, pad=8)
    ax.set_ylabel('Residual', fontsize=11)
    ax.grid(True, linestyle='--', alpha=0.7)
    ax.legend(loc='best', fontsize=10, frameon=True,
              facecolor='white', edgecolor='lightgray')
    ax.tick_params(axis='both', which='major', labelsize=10)
    ax.set_xticks(x)

    # Hide grid cells that have no panel.
    for j in range(total_plots, len(axes)):
        axes[j].set_visible(False)

    fig.supxlabel('Sample Index', fontsize=14, y=0.02)
    fig.suptitle('Individual vs Aggregated Residuals Comparison',
                 fontsize=16, y=1.02)

    fig.savefig(f"{debug_dir}/aggregator_residual_comparison.png", dpi=300,
                bbox_inches='tight', facecolor='white')
    plt.close(fig)


def _annotate_bar_value(ax, bar):
    """Write a bar's height just past its end: above positive bars, below negative ones."""
    height = bar.get_height()
    above = height >= 0
    # The offset is in points, not data units, so the gap between bar and
    # label looks the same whatever the scale of the plotted values.
    ax.annotate(f'{height:.3f}',
                xy=(bar.get_x() + bar.get_width() / 2., height),
                xytext=(0, 3 if above else -3), textcoords='offset points',
                ha='center', va='bottom' if above else 'top', fontsize=10)


def plot_rmse_and_coef_comparison(G, y, coef, g_pred, debug_dir):
    """Plot RMSE comparison bar chart with model coefficients.

    Top panel: RMSE of each individual model (column of ``G``) and of the
    aggregated prediction, with a dashed line at the aggregated RMSE.

    Bottom panel: each model's coefficient, plus a final bar holding the
    sum of all coefficients. It includes a zero reference line and a dashed
    line at that sum. The y-range always includes 0, so the bar bases and
    the zero line stay visible even when every coefficient is negative.

    The figure is written to ``{debug_dir}/model_comparison.png`` and then
    closed.

    Args:
        G (np.ndarray): Matrix of individual model predictions, shape (n_samples, n_models)
        y (np.ndarray): Ground truth values, shape (n_samples,)
        coef (np.ndarray): Model coefficients, shape (n_models,)
        g_pred (np.ndarray): Aggregated predictions, shape (n_samples,)
        debug_dir (str): Directory to save plots

    Returns:
        None
    """
    num_regressors = G.shape[1]
    colors = get_colors(num_regressors)

    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(10, 10), facecolor='white',
                                    gridspec_kw={'height_ratios': [1, 1]})

    # Residuals and RMSE (the RMSE does not depend on the residual sign).
    indiv_residuals = G - y.reshape(-1, 1)
    indiv_rmse = np.sqrt(np.mean(indiv_residuals**2, axis=0))
    residuals = y - g_pred
    rmse = np.sqrt(np.mean(residuals**2))

    labels = [f'Reg{i+1}' for i in range(num_regressors)] + ['Aggregated']
    bar_colors = list(colors) + ['purple']

    # --- Top subplot: RMSE comparison -----------------------------------
    all_rmse = np.append(indiv_rmse, rmse)
    bars = ax1.bar(range(len(all_rmse)), all_rmse, color=bar_colors, alpha=0.7,
                   edgecolor='black', linewidth=1)

    # Highlight the aggregated bar.
    bars[-1].set_edgecolor('red')
    bars[-1].set_linewidth(2)

    for bar in bars:
        _annotate_bar_value(ax1, bar)

    # Extra headroom so value labels above the tallest bar stay inside the axes.
    ax1.margins(y=0.15)
    ax1.axhline(y=rmse, color='r', linestyle='--', alpha=0.7, linewidth=1.5,
                label=f'Aggregated RMSE: {rmse:.3f}')
    ax1.set_xticks(range(len(all_rmse)))
    ax1.set_xticklabels(labels, rotation=45 if len(all_rmse) > 6 else 0, fontsize=10)
    ax1.set_xlabel('Model', fontsize=12)
    ax1.set_ylabel('RMSE', fontsize=12)
    ax1.set_title('RMSE Comparison', fontsize=14, pad=10)
    ax1.grid(True, linestyle='--', alpha=0.7, axis='y')
    ax1.legend(fontsize=10, frameon=True, facecolor='white', edgecolor='lightgray')

    # --- Bottom subplot: coefficient comparison -------------------------
    # The last bar, under the "Aggregated" tick, is the sum of all
    # coefficients rather than a model coefficient.
    sum_coef = np.sum(coef)
    all_coef = np.append(coef, sum_coef)

    bars = ax2.bar(range(len(all_coef)), all_coef, color=bar_colors, alpha=0.7,
                   edgecolor='black', linewidth=1)

    # Highlight the aggregated bar (same as the RMSE plot).
    bars[-1].set_edgecolor('red')
    bars[-1].set_linewidth(2)

    for bar in bars:
        _annotate_bar_value(ax2, bar)

    ax2.axhline(y=0, color='k', linestyle='-', alpha=0.3, linewidth=1.5)
    ax2.axhline(y=sum_coef, color='g', linestyle='--', alpha=0.7, linewidth=1.5,
                label=f'Sum of Coefficients: {sum_coef:.3f}')

    ax2.set_xticks(range(len(all_coef)))
    ax2.set_xticklabels(labels, rotation=45 if len(all_coef) > 6 else 0, fontsize=10)
    ax2.set_xlabel('Model', fontsize=12)
    ax2.set_ylabel('Coefficient Value', fontsize=12)
    ax2.set_title('Model Coefficients', fontsize=14, pad=10)
    ax2.grid(True, linestyle='--', alpha=0.7, axis='y')
    ax2.legend(fontsize=10, frameon=True, facecolor='white', edgecolor='lightgray')

    # Keep 0 inside the y-range and pad in proportion to the data span, so
    # labels fit at any coefficient scale.
    lo = min(0.0, float(np.min(all_coef)))
    hi = max(0.0, float(np.max(all_coef)))
    pad = 0.1 * (hi - lo) if hi > lo else 0.1
    ax2.set_ylim(lo - pad if lo < 0 else 0.0, hi + pad)

    fig.suptitle('Model Performance and Coefficients', fontsize=16, y=1.0)

    fig.tight_layout()
    fig.subplots_adjust(top=0.94)  # Leave room for the suptitle.
    fig.savefig(f"{debug_dir}/model_comparison.png", dpi=300,
                bbox_inches='tight', facecolor='white')
    plt.close(fig)


def plot_all(G, y, coef, g_pred, debug_dir):
    """Produce every aggregator diagnostic plot in one call.

    First calls ``setup_plot_environment`` (global style + directory
    creation). Then runs each plotting helper in turn: predictions,
    residuals, contributions, the per-model prediction grid, the residual
    grid and the RMSE/coefficient comparison. Finally it prints where the
    files were written.

    Args:
        G (np.ndarray): Individual model predictions, shape (n_samples, n_models).
        y (np.ndarray): Ground-truth values, shape (n_samples,).
        coef (np.ndarray): Per-model coefficients, shape (n_models,).
        g_pred (np.ndarray): Aggregated predictions, shape (n_samples,).
        debug_dir (str): Output directory for all PNGs.

    Returns:
        str: The directory the plots were saved to.
    """
    out_dir = setup_plot_environment(debug_dir)

    # Two argument shapes: helpers that need the coefficients and those that don't.
    with_coef = (G, y, coef, g_pred, out_dir)
    without_coef = (G, y, g_pred, out_dir)

    plotters = (
        (plot_predictions, with_coef),
        (plot_residuals, without_coef),
        (plot_contributions, with_coef),
        (plot_individual_predictions_grid, with_coef),
        (plot_residuals_grid, without_coef),
        (plot_rmse_and_coef_comparison, with_coef),
    )
    for plotter, args in plotters:
        plotter(*args)

    print(f"Plots saved to {out_dir}/")
    return out_dir


# Example usage
if __name__ == "__main__":
    from residual_chronos.Aggregator import SPAAggregator

    # Three toy models, each a row of predictions for the same four samples.
    predictions = [[1.5, 2.5, 1.5, 2.5],
                   [2.5, 1.5, 2.5, 2.5],
                   [3.5, 3.5, 3.5, 3.5]]

    # Stack rows as columns: G has shape (4 samples, 3 models).
    G = np.stack(predictions, axis=1)
    y = np.array([2, 2, 2, 2])

    spa = SPAAggregator(num_models=G.shape[1], sigma=0.1)
    spa = spa.fit(G, y)
    coef = spa.coef_
    g_pred = spa.predict(G)

    # Sums of squared residuals: aggregate first, then one value per model.
    sse_aggregated = np.sum((y - g_pred) ** 2)
    sse_per_model = np.sum((G - y.reshape(-1, 1)) ** 2, axis=0)

    print(f"Coefficients: {coef}")
    print(f"sum of squared residuals: {sse_aggregated}, "
          f"sum of squared residuals of G: {sse_per_model}")

    # Generate all plots
    plot_all(G, y, coef, g_pred, debug_dir="examples/plots/debug")
