# filename: codebase/centroid_pca_analysis.py
import numpy as np
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
import os
import time
import glob


def _load_centroids(path):
    """Load the ``all_centroids`` array from an .npz file; return None on failure."""
    try:
        # Context manager closes the underlying zip file; indexing copies the
        # array into memory, so it stays valid after the file is closed.
        with np.load(path) as data:
            centroids = data['all_centroids']
    except Exception as e:
        # Best-effort boundary: report any load problem (missing file, missing
        # key, corrupt archive, non-.npz file) and let the caller bail out.
        print("Error loading centroids data: " + str(e))
        return None
    print("Successfully loaded centroids from: " + path)
    print("Shape of all_centroids: " + str(centroids.shape))
    return centroids


def _ic_colorbar_ticks(n_centroids):
    """Return roughly ten evenly spaced IC-index ticks for a colorbar."""
    return np.arange(0, n_centroids, max(1, n_centroids // 10))


def _print_pca_table(eigenvalues, ratios, cumulative):
    """Print eigenvalues and per-component / cumulative explained variance."""
    print("\nEigenvalues (Explained Variance per component):")
    for i, eigval in enumerate(eigenvalues):
        print("  CPC" + str(i + 1) + ": " + str(eigval))

    print("\nExplained Variance Ratio per component:")
    header = "Centroid PC | Eigenvalue      | Var. Explained (%) | Cum. Var. Explained (%)"
    print(header)
    print("-" * len(header))
    for i, (eigval, ratio, cum) in enumerate(zip(eigenvalues, ratios, cumulative)):
        row = ("CPC" + str(i + 1)).ljust(11) + " | " + \
              "{:.6f}".format(eigval).rjust(15) + " | " + \
              "{:.2f}".format(ratio * 100.0).rjust(18) + " | " + \
              "{:.2f}".format(cum * 100.0).rjust(23)
        print(row)


def _save_scree_plot(ratios, cumulative, n_centroids, database_path, timestamp):
    """Save a bar (individual) + line (cumulative) explained-variance plot."""
    fig, ax_bar = plt.subplots(figsize=(10, 6))
    pc_numbers = range(1, len(ratios) + 1)

    ax_bar.bar(pc_numbers, ratios * 100, alpha=0.7, align='center',
               label='Individual Explained Variance (%)', color='darkcyan')
    ax_bar.set_xlabel('Principal Component Number (for Centroids)')
    ax_bar.set_ylabel('Explained Variance (%)', color='darkcyan')
    ax_bar.tick_params(axis='y', labelcolor='darkcyan')
    ax_bar.set_xticks(pc_numbers)
    # Target the primary axes explicitly: after twinx() the pyplot "current"
    # axes is the twin, so plt.grid()/plt.title() would land on it instead.
    ax_bar.grid(True, linestyle='--', alpha=0.7)
    ax_bar.set_title('PCA of IC Centroids: Explained Variance')

    ax_cum = ax_bar.twinx()
    ax_cum.plot(pc_numbers, cumulative * 100, color='orangered', marker='o',
                linestyle='-', linewidth=2, label='Cumulative Explained Variance (%)')
    ax_cum.set_ylabel('Cumulative Explained Variance (%)', color='orangered')
    ax_cum.tick_params(axis='y', labelcolor='orangered')
    ax_cum.set_ylim(0, 105)

    # Merge both axes' legend entries into a single legend.
    lines, labels = ax_bar.get_legend_handles_labels()
    lines2, labels2 = ax_cum.get_legend_handles_labels()
    ax_cum.legend(lines + lines2, labels + labels2, loc='center right')

    fig.tight_layout()
    filename = os.path.join(database_path, "centroid_pca_scree_plot_3_" + timestamp + ".png")
    try:
        fig.savefig(filename, dpi=300)
        print("\nCentroid PCA scree plot saved to: " + filename)
        print("Plot Description: Scree plot for PCA performed on the " + str(n_centroids)
              + " IC centroids. It shows individual and cumulative variance explained by"
              " principal components of the centroid distribution.")
    except Exception as e:
        print("Error saving plot 1 (scree plot): " + str(e))
    finally:
        plt.close(fig)


def _save_scatter_2d(projected, n_centroids, database_path, timestamp):
    """Save a CPC1-vs-CPC2 scatter plot, colored and labeled by IC index."""
    fig, ax = plt.subplots(figsize=(10, 8))
    scatter = ax.scatter(projected[:, 0], projected[:, 1],
                         c=range(n_centroids), cmap='viridis', s=50, alpha=0.8)
    ax.set_xlabel('Centroid Principal Component 1 (CPC1)')
    ax.set_ylabel('Centroid Principal Component 2 (CPC2)')
    ax.set_title('IC Centroids Projected onto First Two Principal Components')
    ax.grid(True, linestyle='--', alpha=0.7)

    cbar = fig.colorbar(scatter, ax=ax, label='Initial Condition (IC) Index')
    cbar.set_ticks(_ic_colorbar_ticks(n_centroids))

    # Offset each label by 1% of the axis span. Text does not change the
    # autoscaled limits, so they are read once rather than per point.
    x_lo, x_hi = ax.get_xlim()
    y_lo, y_hi = ax.get_ylim()
    dx = 0.01 * (x_hi - x_lo)
    dy = 0.01 * (y_hi - y_lo)
    for i in range(n_centroids):
        ax.text(projected[i, 0] + dx, projected[i, 1] + dy, str(i), fontsize=9)

    fig.tight_layout()
    filename = os.path.join(database_path, "centroid_pca_scatter_2D_3_" + timestamp + ".png")
    try:
        fig.savefig(filename, dpi=300)
        print("\n2D scatter plot of centroids saved to: " + filename)
        print("Plot Description: Scatter plot of the " + str(n_centroids)
              + " IC centroids projected onto their first two principal components"
              " (CPC1 vs CPC2). Points are colored by IC index, and labeled with their index.")
    except Exception as e:
        print("Error saving plot 2 (2D scatter): " + str(e))
    finally:
        plt.close(fig)


def _save_scatter_3d(projected, n_centroids, database_path, timestamp):
    """Save a CPC1/CPC2/CPC3 3D scatter plot, colored by IC index (unlabeled)."""
    fig = plt.figure(figsize=(12, 10))
    ax = fig.add_subplot(111, projection='3d')
    scatter = ax.scatter(projected[:, 0], projected[:, 1], projected[:, 2],
                         c=range(n_centroids), cmap='viridis', s=50, alpha=0.8)
    ax.set_xlabel('Centroid Principal Component 1 (CPC1)')
    ax.set_ylabel('Centroid Principal Component 2 (CPC2)')
    ax.set_zlabel('Centroid Principal Component 3 (CPC3)')
    ax.set_title('IC Centroids Projected onto First Three Principal Components')

    cbar = fig.colorbar(scatter, ax=ax, label='Initial Condition (IC) Index', shrink=0.75)
    cbar.set_ticks(_ic_colorbar_ticks(n_centroids))
    # Point labels are intentionally omitted: they become unreadable in 3D.

    fig.tight_layout()
    filename = os.path.join(database_path, "centroid_pca_scatter_3D_3_" + timestamp + ".png")
    try:
        fig.savefig(filename, dpi=300)
        print("\n3D scatter plot of centroids saved to: " + filename)
        print("Plot Description: 3D scatter plot of the " + str(n_centroids)
              + " IC centroids projected onto their first three principal components"
              " (CPC1, CPC2, CPC3). Points are colored by IC index.")
    except Exception as e:
        print("Error saving plot 3 (3D scatter): " + str(e))
    finally:
        plt.close(fig)


def analyze_centroid_structure(per_ic_pca_results_path, database_path="data/"):
    """
    Analyze the geometric structure of latent-space centroids with PCA.

    Loads the ``all_centroids`` array (rows = centroids, columns = dimensions)
    from an .npz file, fits PCA to it, prints an eigenvalue / explained-variance
    table, and saves a scree plot plus 2D and 3D scatter plots of the projected
    centroids (the scatter plots only when enough components exist).

    Args:
        per_ic_pca_results_path (str): Path to the .npz file containing per-IC
            PCA results; it must contain an ``all_centroids`` array.
        database_path (str): Directory where plots are saved (created if missing).

    Returns:
        None. Problems (unreadable file, non-2D data, no features, fewer than
        two centroids, PCA failure) are reported via ``print`` and cause an
        early return.

    Side effects:
        Sets ``plt.rcParams['text.usetex'] = False`` globally and writes PNG
        files whose names carry a ``%Y%m%d-%H%M%S`` timestamp.
    """
    print("Starting Step 3: Analysis of centroids and their geometric structure")

    if not os.path.exists(database_path):
        os.makedirs(database_path, exist_ok=True)
        print("Created directory: " + database_path)

    # Timestamp for unique output filenames.
    timestamp = time.strftime("%Y%m%d-%H%M%S")

    # 1. Load and validate centroids.
    all_centroids = _load_centroids(per_ic_pca_results_path)
    if all_centroids is None:
        return
    if all_centroids.ndim != 2:
        print("Error: Expected a 2D centroids array (n_centroids, n_dims), got shape "
              + str(all_centroids.shape) + ".")
        return
    n_centroids, n_dims = all_centroids.shape
    if n_dims == 0:
        print("Error: Centroids data is empty or has no features.")
        return
    if n_centroids < 2:
        # PCA variance uses (n_samples - 1) in the denominator; with a single
        # sample every variance/ratio is NaN or inf.
        print("Error: At least 2 centroids are required for PCA, got " + str(n_centroids) + ".")
        return

    # 2. Perform PCA; n_components can be at most min(n_samples, n_features).
    n_components_pca = min(n_centroids, n_dims)
    pca_centroids = PCA(n_components=n_components_pca)
    try:
        projected_centroids = pca_centroids.fit_transform(all_centroids)
    except ValueError as ve:
        print("ValueError during PCA fitting: " + str(ve))
        print("Check the centroids for NaN/inf values or degenerate data.")
        return

    eigenvalues_centroids = pca_centroids.explained_variance_
    explained_variance_ratio_centroids = pca_centroids.explained_variance_ratio_
    cumulative_explained_variance_centroids = np.cumsum(explained_variance_ratio_centroids)

    # 3. Report PCA results.
    print("\n--- PCA of IC Centroids (" + str(n_centroids) + " centroids, "
          + str(n_dims) + " dimensions) ---")
    print("Number of principal components considered: " + str(n_components_pca))
    _print_pca_table(eigenvalues_centroids, explained_variance_ratio_centroids,
                     cumulative_explained_variance_centroids)

    # 4. Generate and save plots.
    plt.rcParams['text.usetex'] = False  # Disable LaTeX rendering
    _save_scree_plot(explained_variance_ratio_centroids, cumulative_explained_variance_centroids,
                     n_centroids, database_path, timestamp)

    if n_components_pca >= 2:
        _save_scatter_2d(projected_centroids, n_centroids, database_path, timestamp)
    else:
        print("\nSkipping 2D scatter plot: Not enough principal components ("
              + str(n_components_pca) + ").")

    if n_components_pca >= 3:
        _save_scatter_3d(projected_centroids, n_centroids, database_path, timestamp)
    else:
        print("\nSkipping 3D scatter plot: Not enough principal components ("
              + str(n_components_pca) + ").")

    print("\nFinished Step 3.")


def find_latest_per_ic_results(directory, pattern='per_ic_pca_results_*.npz'):
    """
    Return the most recently modified file in *directory* matching *pattern*.

    Uses modification time (``os.path.getmtime``) rather than ``getctime``:
    on POSIX, ctime is the inode-change time (bumped by chmod/rename), not
    the creation time, so it does not reliably identify the newest result.

    Returns:
        str | None: The matching path, or None when nothing matches.
    """
    candidates = glob.glob(os.path.join(directory, pattern))
    if not candidates:
        return None
    return max(candidates, key=os.path.getmtime)


def main(output_database_path="data/"):
    """Run Step 3 on the newest Step 2 results found in *output_database_path*."""
    latest_per_ic_results_file = find_latest_per_ic_results(output_database_path)
    if latest_per_ic_results_file is None:
        print("Error: No 'per_ic_pca_results_*.npz' file found in " + output_database_path)
        print("Please ensure Step 2 (per_ic_pca_analysis) has been run successfully.")
        return
    print("Using per-IC PCA results from: " + latest_per_ic_results_file)
    analyze_centroid_structure(latest_per_ic_results_file, database_path=output_database_path)


if __name__ == "__main__":
    main()
