import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import os


def generate_community_labels(n, K):
    """
    Return block-community labels for a graph with ``n`` nodes and ``K`` blocks.

    Nodes are assigned to contiguous blocks of size ``n // K``: nodes
    ``0 .. n//K - 1`` get label 0, the next ``n//K`` nodes get label 1, and so
    on. When ``n`` is not divisible by ``K``, the ``n % K`` leftover nodes are
    appended to the last block (label ``K - 1``) instead of silently keeping
    label 0. When ``K > n``, each node gets its own label ``0 .. n-1``.

    Parameters
    ----------
    n : int
        Number of nodes.
    K : int
        Number of communities; must be positive.

    Returns
    -------
    numpy.ndarray
        Integer array of length ``n`` with values in ``[0, K-1]``.

    Raises
    ------
    ValueError
        If ``K`` is not positive.
    """
    if K <= 0:
        raise ValueError(f"K must be a positive integer, got {K}")
    # Guard against block_size == 0 (K > n), which would otherwise divide by zero.
    block_size = max(n // K, 1)
    # Clamp to K-1 so the remainder nodes join the last block.
    return np.minimum(np.arange(n) // block_size, K - 1).astype(int)

def plot_embedding(ax, s_n_val, K_val, n_val, data_dir):
    """
    Load an embedding and draw a 1-D scatter plot of its first column.

    The embedding is read (no header row) from
    ``<data_dir>/X_s{s_n_val}_n{n_val}_K{K_val}.csv``. Points are placed on the
    line y = 0 and colored by the label from ``generate_community_labels``.
    The x-limits are computed for this subplot only.

    Failures (missing file, empty embedding, row-count mismatch, no finite
    values, unreadable data) are printed to stdout and rendered as a
    placeholder message in ``ax`` rather than raised. One bad file therefore
    does not abort the whole figure.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        Subplot to draw into.
    s_n_val, K_val, n_val :
        Values used to build the embedding filename. ``K_val`` and ``n_val``
        are also the number of communities and nodes used for coloring.
    data_dir : str
        Directory containing the embedding CSV files.
    """
    embedding_filename = os.path.join(data_dir, f"X_s{s_n_val}_n{n_val}_K{K_val}.csv")

    if not os.path.exists(embedding_filename):
        print(f"Warning: Embedding file not found, skipping plot for K={K_val}: {embedding_filename}")
        ax.set_title(f"K={K_val}\n(Data file not found)", fontsize=14)
        ax.text(0.5, 0.5, "Data not found", ha='center', va='center', transform=ax.transAxes)
        return

    try:
        # Load the embedding data
        X = pd.read_csv(embedding_filename, header=None).values

        if X.shape[1] < 1:
            # Handle case where embedding is empty or has no dimensions
            ax.set_title(f"K={K_val}\n(Dim < 1)", fontsize=14)
            ax.text(0.5, 0.5, "Dimension < 1", ha='center', va='center', transform=ax.transAxes)
            return

        x = X[:, 0].astype(float)

        # Generate the true community labels for coloring
        labels = generate_community_labels(n_val, K_val)
        if x.shape[0] != labels.shape[0]:
            # Report the mismatch explicitly instead of letting scatter() fail
            # with an opaque "c argument has N elements" error.
            raise ValueError(
                f"embedding has {x.shape[0]} rows but {labels.shape[0]} labels were generated (n={n_val})"
            )

        # Limits come from finite values only; NaN/inf would make set_xlim fail.
        finite_x = x[np.isfinite(x)]
        if finite_x.size == 0:
            raise ValueError("embedding contains no finite values")

        # Pin the color normalization to the full label range 0..K-1 so the
        # legend below reproduces exactly the colors used here, even if some
        # labels are absent. max(..., 1) avoids vmin == vmax when K == 1.
        scatter = ax.scatter(
            x, np.zeros_like(x), c=labels, cmap='viridis',
            vmin=0, vmax=max(K_val - 1, 1), alpha=0.7,
        )

        ax.set_title(f"K = {K_val}", fontsize=14)
        ax.set_xlabel("Embedding Value")
        ax.set_ylabel("")
        ax.grid(True, linestyle='--', alpha=0.5)

        # --- Calculate and apply limits for THIS subplot only ---
        x_min, x_max = finite_x.min(), finite_x.max()
        x_buffer = (x_max - x_min) * 0.1
        if x_buffer == 0:
            # All values identical: use a non-zero margin so set_xlim does not
            # receive identical low/high limits.
            x_buffer = max(abs(x_min) * 0.1, 1.0)
        ax.set_xlim(x_min - x_buffer, x_max + x_buffer)
        ax.set_ylim(-0.1, 0.1)  # Fixed small range around 0 for y-axis

        # Manual legend: one dummy marker per community, colored through the
        # scatter's own colormap and normalization so the colors always match.
        legend_handles = [
            plt.Line2D([0], [0], marker='o', color='w',
                       markerfacecolor=scatter.cmap(scatter.norm(i)), markersize=10,
                       label=f'Community {i+1}')
            for i in range(K_val)
        ]
        ax.legend(handles=legend_handles, title="Communities")

    except Exception as e:
        # Plot-level boundary: report and show a placeholder instead of
        # aborting the remaining subplots.
        print(f"An error occurred while processing K={K_val}: {e}")
        ax.set_title(f"K={K_val}\n(Error)", fontsize=14)
        ax.text(0.5, 0.5, "Error during processing", ha='center', va='center', transform=ax.transAxes)

def main():
    """
    Build one subplot per K value, draw each embedding, and save the figure as EPS.

    Embedding CSVs are read from ``../dat`` (relative to the current working
    directory). The figure is written to
    ``../img/embedding_direct_plot_n{n}_sn{s_n}.eps``, and ``../img`` is
    created if missing.
    """
    # --- Configuration ---
    n = 1200
    s_n = 1
    K_values = [2, 3, 4]

    data_dir = "../dat"
    img_dir = "../img"
    os.makedirs(img_dir, exist_ok=True)

    # --- Plotting Setup ---
    plt.style.use('seaborn-v0_8-whitegrid')
    # squeeze=False keeps `axes` a 2-D array even when K_values has a single
    # entry; otherwise subplots() returns a bare Axes that cannot be indexed.
    fig, axes = plt.subplots(1, len(K_values), figsize=(len(K_values) * 7, 6), squeeze=False)

    # --- Generate each subplot (axis limits are set per subplot) ---
    for ax, K in zip(axes[0], K_values):
        plot_embedding(ax, s_n, K, n, data_dir)

    fig.tight_layout()

    # --- Save the plot ---
    output_filename = os.path.join(img_dir, f"embedding_direct_plot_n{n}_sn{s_n}.eps")
    # NOTE(review): the PostScript backend does not support transparency, so
    # any semi-transparent markers are rendered opaque in the EPS output.
    fig.savefig(output_filename, format='eps', bbox_inches='tight')
    # Release the figure; pyplot otherwise keeps a reference to it alive.
    plt.close(fig)
    print(f"\nPlot saved to {output_filename}")

if __name__ == "__main__":
    # Generate and save the figure only when run as a script, not on import.
    main()
