import numpy as np
import matplotlib.pyplot as plt
import os
import re
from pathlib import Path
from matplotlib.colors import LinearSegmentedColormap

# Reuse the histogram organization function from the previous implementation
def organize_histograms(root_folder, name_condition=None):
    """Organize histograms by pruner, n_layers, compression, and hidden_dim.

    Scans the immediate sub-directories of ``root_folder`` for run folders named
    like ``fc_<pruner>_L_<n_layers>_N_<hidden_dim>_compression_<c>_sorted<suffix>``
    and loads ``histogram_<pruner>.npy`` from each matching folder.

    Args:
        root_folder: Directory containing one sub-directory per run.
        name_condition: Optional suffix that must follow ``_sorted`` in the
            folder name (e.g. ``'first'`` selects ``*_sortedfirst``). When
            ``None``, folders ending in ``_sorted`` are used; if there are none,
            folders ending in ``_sortedfirst`` are used instead.

    Returns:
        dict: ``organized[pruner][n_layers][compression][hidden_dim] -> ndarray``
        with ``n_layers``/``hidden_dim`` as ints and ``compression`` as float.
        A compression level may map to an empty dict when its folder has no
        histogram file (a warning is printed in that case).
    """
    organized_data = {}

    # Sorted so that results (including the synflow 1024 -> 1000 alias below,
    # where a later folder overwrites an earlier one) don't depend on the
    # arbitrary order of os.listdir.
    folder_paths = sorted(
        f for f in os.listdir(root_folder)
        if os.path.isdir(os.path.join(root_folder, f))
    )

    if name_condition is not None:
        sorted_folders = [f for f in folder_paths if f.endswith(f"_sorted{name_condition}")]
    else:
        sorted_folders = [f for f in folder_paths if f.endswith("_sorted")]
        # Fall back to ``_sortedfirst`` runs when no plain ``_sorted`` runs exist.
        if not sorted_folders:
            sorted_folders = [f for f in folder_paths if f.endswith("_sortedfirst")]

    # re.search (not match/fullmatch): anything after "_sorted" is allowed.
    pattern = re.compile(r'fc_(\w+)_L_(\d+)_N_(\d+)_compression_([0-9.]+)_sorted')

    for folder in sorted_folders:
        match = pattern.search(folder)
        if not match:
            continue

        pruner, n_layers, hidden_dim, compression = match.groups()
        n_layers = int(n_layers)
        hidden_dim = int(hidden_dim)
        compression = float(compression)

        # Dataset-specific rules: drop synflow L=4/N=1000 runs, and report
        # synflow N=1024 runs under hidden_dim 1000.
        if pruner == 'synflow' and hidden_dim == 1000 and n_layers == 4:
            continue
        if hidden_dim == 1024 and pruner == 'synflow':
            hidden_dim = 1000

        by_hidden_dim = (organized_data
                         .setdefault(pruner, {})
                         .setdefault(n_layers, {})
                         .setdefault(compression, {}))

        histogram_file = os.path.join(root_folder, folder, f"histogram_{pruner}.npy")
        if os.path.exists(histogram_file):
            by_hidden_dim[hidden_dim] = np.load(histogram_file)
        else:
            print(f"Warning: Histogram file not found: {histogram_file}")

    return organized_data

def visualize_histogram_heatmaps(organized_data, pruner, n_layers, cmap_name='viridis', name_condition=None,
                                 vmin=0, vmax=1):
    """
    Visualize 2D histograms as heatmaps for a specific pruner and n_layers pair.

    Builds a grid of subplots with one row per compression (ascending) and one
    column per hidden dim (ascending). Each histogram is drawn with ``imshow`` on
    a [0, 1] x [0, 1] extent using a shared color range, and one colorbar is
    attached on the right. Cells with no histogram show "No data".

    Args:
        organized_data: Nested dict ``[pruner][n_layers][compression][hidden_dim]
            -> 2D array``, as built by ``organize_histograms``.
        pruner: Pruner name to visualize (e.g., 'rand', 'snip', 'synflow')
        n_layers: Number of layers to visualize
        cmap_name: Matplotlib colormap name to use
        name_condition: Optional label added to the figure title
            (e.g. ``'first'`` -> "..., first hidden layer").
        vmin, vmax: Color range shared by every heatmap (defaults to [0, 1]).

    Returns:
        The matplotlib Figure, or ``None`` if there is no data (or no loaded
        histogram) for the requested pruner/n_layers pair.
    """
    if pruner not in organized_data or n_layers not in organized_data[pruner]:
        print(f"No data available for pruner={pruner}, n_layers={n_layers}")
        return None

    by_compression = organized_data[pruner][n_layers]
    compressions = sorted(by_compression)
    hidden_dims = sorted({dim for comp in compressions for dim in by_compression[comp]})

    # plt.subplots cannot build a grid with zero rows/columns, and without any
    # histogram there is nothing for the colorbar to describe.
    if not hidden_dims:
        print(f"No histograms loaded for pruner={pruner}, n_layers={n_layers}")
        return None

    fig, axes = plt.subplots(len(compressions), len(hidden_dims),
                             figsize=(4 * len(hidden_dims), 3 * len(compressions)),
                             squeeze=False)
    cmap = plt.get_cmap(cmap_name)

    if name_condition is not None:
        fig.suptitle(f"2D Histograms for Pruner: {pruner}, {name_condition} hidden layer", fontsize=16)
    else:
        fig.suptitle(f"2D Histograms for Pruner: {pruner}", fontsize=16)

    # Compression value (from the run folder name) -> sparsity label shown on
    # the row header. Unlisted compressions are labelled with the raw value
    # instead of raising KeyError.
    compression_mapping = {
        0.5: '70%',
        0.75: '80%',
        1.0: '90%',
    }
    for i, comp in enumerate(compressions):
        sparsity = compression_mapping.get(comp)
        row_label = f"Sparsity: {sparsity}" if sparsity is not None else f"Compression: {comp}"
        axes[i, 0].set_ylabel(row_label, fontsize=12)

    for j, hidden_dim in enumerate(hidden_dims):
        axes[0, j].set_title(f"Hidden Dim: {hidden_dim}", fontsize=12)

    ticks = np.linspace(0, 1, 5)
    tick_labels = [f"{t:.1f}" for t in ticks]

    im = None
    for i, comp in enumerate(compressions):
        for j, hidden_dim in enumerate(hidden_dims):
            ax = axes[i, j]
            if hidden_dim not in by_compression[comp]:
                ax.text(0.5, 0.5, "No data", ha='center', va='center')
                ax.set_xticks([])
                ax.set_yticks([])
                continue

            # Fixed vmin/vmax keeps the color scale consistent across subplots.
            im = ax.imshow(by_compression[comp][hidden_dim], cmap=cmap, aspect='auto',
                           origin='lower', vmin=vmin, vmax=vmax,
                           extent=[0, 1, 0, 1])
            ax.set_xticks(ticks)
            ax.set_yticks(ticks)
            ax.set_xticklabels(tick_labels)
            ax.set_yticklabels(tick_labels)
            ax.grid(alpha=0.3)

    # Lay out the subplot grid before adding the free-standing colorbar axes:
    # tight_layout only handles subplot axes. The rect reserves the right 10%
    # for the colorbar and the top 5% for the suptitle.
    fig.tight_layout(rect=[0, 0, 0.9, 0.95])

    if im is not None:
        cbar_ax = fig.add_axes([0.92, 0.15, 0.02, 0.7])  # [left, bottom, width, height]
        cbar = fig.colorbar(im, cax=cbar_ax)
        cbar.set_label('Histogram Value')

    return fig

def visualize_all_heatmaps(organized_data, cmaps=None, name_condition=None):
    """
    Visualize and save 2D histogram heatmaps for every pruner/n_layers pair.

    For each pair, the figure from ``visualize_histogram_heatmaps`` is saved
    to ``histograms_<pruner>_L<n_layers>[_<name_condition>].png``. The file is
    written to the current working directory at 300 dpi, and the figure is then
    closed. Pairs for which no figure is produced are skipped.

    Args:
        organized_data: Nested dict ``[pruner][n_layers][compression][hidden_dim]
            -> 2D array``, as built by ``organize_histograms``.
        cmaps: Dict mapping pruner name -> matplotlib colormap name. Pruners not
            in the dict use 'viridis'. When ``None``, rand/snip/synflow use
            viridis/plasma/inferno.
        name_condition: Optional label forwarded to the figure title and
            appended to the filename. ``None`` keeps the plain filenames.
    """
    if cmaps is None:
        # One colormap per pruner for better visual distinction.
        cmaps = {
            'rand': 'viridis',
            'snip': 'plasma',
            'synflow': 'inferno'
        }

    suffix = f"_{name_condition}" if name_condition is not None else ""

    for pruner, by_layers in organized_data.items():
        cmap_name = cmaps.get(pruner, 'viridis')
        for n_layers in by_layers:
            fig = visualize_histogram_heatmaps(organized_data, pruner, n_layers, cmap_name,
                                               name_condition=name_condition)
            if fig is None:
                continue
            filename = f"histograms_{pruner}_L{n_layers}{suffix}.png"
            fig.savefig(filename, dpi=300, bbox_inches='tight')
            print(f"Saved {filename}")
            # Close each figure so a large sweep doesn't accumulate open figures.
            plt.close(fig)

def main(root_folder, pruner='synflow', n_layers=3, name_condition='first',
         output_dir='./visualize_histograms', visualize_all=False):
    """Execute the workflow to analyze and visualize histogram heatmaps.

    Loads the histograms under ``root_folder``, filtered by ``name_condition``,
    and renders the heatmap grid for one pruner/n_layers pair. The grid is
    saved to ``<output_dir>/histograms_<pruner>_L<n_layers>_<name_condition>.png``
    and then shown.

    Args:
        root_folder: Directory containing the per-run result folders.
        pruner: Pruner to plot in detail.
        n_layers: Layer count to plot in detail.
        name_condition: Folder suffix after ``_sorted`` (e.g. ``'first'``). It is
            used consistently for loading, the figure title and the output
            filename. ``None`` loads the plain ``_sorted`` folders.
        output_dir: Directory for the saved figure. Created if missing.
        visualize_all: If True, also save heatmaps for every pruner/n_layers pair
            via ``visualize_all_heatmaps``.
    """
    print("Organizing histograms...")
    # Load the same folder variant that the title and filename claim to show.
    organized_data = organize_histograms(root_folder, name_condition=name_condition)

    if visualize_all:
        cmaps = {
            'rand': 'viridis',    # Blue-green-yellow
            'snip': 'plasma',     # Purple-orange
            'synflow': 'inferno'  # Black-red-yellow
        }
        print("Visualizing histogram heatmaps...")
        visualize_all_heatmaps(organized_data, cmaps)

    print(f"Creating detailed visualization for {pruner} pruner with {n_layers} layers...")
    fig = visualize_histogram_heatmaps(organized_data, pruner, n_layers, 'viridis',
                                       name_condition=name_condition)
    if fig is not None:
        # savefig does not create missing directories.
        os.makedirs(output_dir, exist_ok=True)
        suffix = f"_{name_condition}" if name_condition is not None else ""
        out_path = os.path.join(output_dir, f"histograms_{pruner}_L{n_layers}{suffix}.png")
        fig.savefig(out_path, dpi=300, bbox_inches='tight')
        plt.show()

    print("Done!")

# Script entry point: render the default detailed heatmap figure.
if __name__ == "__main__":
    # Point this at the directory holding the per-run histogram folders.
    root_folder = "./histogram_results"
    main(root_folder=root_folder)