# filename: codebase/global_projection_analysis.py
import numpy as np
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
import os
import time
import glob


def _save_and_close_figure(fig, filename, plot_label, saved_prefix, description):
    """Save ``fig`` to ``filename`` at 300 dpi, report the outcome, and close it.

    Args:
        fig (matplotlib.figure.Figure): Figure to write to disk.
        filename (str): Destination path of the image.
        plot_label (str): Short label used in the error message (e.g. "plot 1").
        saved_prefix (str): Text printed before the filename on success.
        description (str): Human-readable description printed on success.
    """
    try:
        # Save through the figure object itself so the result does not depend
        # on which figure pyplot currently considers "active".
        fig.savefig(filename, dpi=300)
        print("\n" + saved_prefix + filename)
        print("Plot Description: " + description)
    except Exception as e:
        # Best effort: a failed save is reported but must not abort the analysis.
        print("Error saving " + plot_label + ": " + str(e))
    finally:
        plt.close(fig)


def _plot_variance_captured(variance_captured_ratios, d_glob, database_path, timestamp):
    """Draw and save the bar chart of per-IC variance captured by the global subspace."""
    n_ic = len(variance_captured_ratios)
    fig, ax = plt.subplots(figsize=(12, 7))
    ax.bar(range(n_ic), variance_captured_ratios * 100, color='teal', alpha=0.8)
    ax.set_xlabel("Initial Condition (IC) Index")
    ax.set_ylabel("Per-IC Variance Captured by Global Subspace (%)")
    ax.set_title("Percentage of Each IC's Variance Captured by " + str(d_glob) + "-dim Global Subspace")
    ax.set_xticks(range(n_ic))
    ax.set_xticklabels([str(i) for i in range(n_ic)], rotation=45, ha="right")
    ax.set_ylim(0, 105)
    ax.grid(axis='y', linestyle='--', alpha=0.7)
    fig.tight_layout()
    filename = os.path.join(database_path, "variance_captured_by_global_subspace_5_" + timestamp + ".png")
    _save_and_close_figure(
        fig,
        filename,
        "plot 1",
        "Plot 1 (Variance Captured) saved to: ",
        "Bar chart showing the percentage of each IC's intrinsic variance that is captured when projected onto the "
        + str(d_glob) + "-dimensional global principal subspace.",
    )


def _plot_projected_centroids_2d(projected_centroids_coords, d_glob, database_path, timestamp):
    """Draw and save the labelled scatter of IC centroids on the first two global PCs."""
    n_ic = projected_centroids_coords.shape[0]
    fig, ax = plt.subplots(figsize=(10, 8))
    scatter = ax.scatter(projected_centroids_coords[:, 0], projected_centroids_coords[:, 1],
                         c=np.arange(n_ic), cmap='viridis', s=50, alpha=0.8)
    ax.set_xlabel("Global PC1 Coordinate")
    ax.set_ylabel("Global PC2 Coordinate")
    ax.set_title("IC Centroids Projected onto First Two Global PCs (d_glob=" + str(d_glob) + ")")
    ax.grid(True, linestyle='--', alpha=0.7)
    cbar = plt.colorbar(scatter, ax=ax, label='Initial Condition (IC) Index')
    cbar.set_ticks(np.arange(0, n_ic, max(1, n_ic // 10)))

    # Nudge each label by 0.1% of the axis span so it does not sit on its marker.
    # The limits are queried once; the original re-queried them for every label.
    x_low, x_high = ax.get_xlim()
    y_low, y_high = ax.get_ylim()
    x_offset = 0.001 * (x_high - x_low)
    y_offset = 0.001 * (y_high - y_low)
    for i in range(n_ic):
        ax.text(projected_centroids_coords[i, 0] + x_offset,
                projected_centroids_coords[i, 1] + y_offset,
                str(i), fontsize=9)

    fig.tight_layout()
    filename = os.path.join(database_path, "projected_centroids_2D_5_" + timestamp + ".png")
    _save_and_close_figure(
        fig,
        filename,
        "plot 2",
        "Plot 2 (Projected Centroids 2D) saved to: ",
        "Scatter plot of the " + str(n_ic) + " IC centroids projected onto the first two global principal components. "
        "Points are colored and labeled by IC index.",
    )


def _plot_projected_centroids_3d(projected_centroids_coords, d_glob, database_path, timestamp):
    """Draw and save the 3D scatter of IC centroids on the first three global PCs."""
    n_ic = projected_centroids_coords.shape[0]
    fig = plt.figure(figsize=(12, 10))
    ax = fig.add_subplot(111, projection='3d')
    scatter = ax.scatter(projected_centroids_coords[:, 0],
                         projected_centroids_coords[:, 1],
                         projected_centroids_coords[:, 2],
                         c=np.arange(n_ic), cmap='viridis', s=50, alpha=0.8)
    ax.set_xlabel("Global PC1 Coordinate")
    ax.set_ylabel("Global PC2 Coordinate")
    ax.set_zlabel("Global PC3 Coordinate")
    ax.set_title("IC Centroids Projected onto First Three Global PCs (d_glob=" + str(d_glob) + ")")
    cbar = plt.colorbar(scatter, ax=ax, label='Initial Condition (IC) Index', shrink=0.75)
    cbar.set_ticks(np.arange(0, n_ic, max(1, n_ic // 10)))
    # Per-point text labels are intentionally omitted in 3D to keep the plot readable.
    fig.tight_layout()
    filename = os.path.join(database_path, "projected_centroids_3D_5_" + timestamp + ".png")
    _save_and_close_figure(
        fig,
        filename,
        "plot 3",
        "Plot 3 (Projected Centroids 3D) saved to: ",
        "3D scatter plot of the " + str(n_ic) + " IC centroids projected onto the first three global principal components. "
        "Points are colored by IC index.",
    )


def analyze_projection_on_global_pca(data_path, per_ic_pca_results_path, database_path="data/"):
    """
    Relates per-IC latent space structures to the global latent space structure.

    A PCA is fitted on all latent vectors pooled together, and the smallest
    number of leading components reaching 99% cumulative explained variance
    defines the "global subspace". For every initial condition (IC) the
    function then reports which fraction of that IC's own variance survives
    projection onto the global subspace, and projects the IC centroid into
    global PC coordinates. Results are printed and up to three PNG plots are
    written to ``database_path``.

    Args:
        data_path (str): Path to the original .npy data file. The array is
            indexed as 4-D; entries 0-2 of the last axis are dropped and the
            remaining axes are unpacked as (n_x, n_t, n_ic, n_latent_dim).
        per_ic_pca_results_path (str): Path to the .npz file from Step 2. It
            must provide 'all_centroids' with shape (n_ic, n_latent_dim) and
            'all_eigenvalues' with one row per IC.
        database_path (str): Path to save plots and results.

    Returns:
        None. Load or validation problems are printed and the function
        returns early instead of raising.
    """
    print("Starting Step 5: Relating per-IC manifolds to global latent space structure")

    if not os.path.exists(database_path):
        # exist_ok guards against another process creating the directory
        # between the check above and this call.
        os.makedirs(database_path, exist_ok=True)
        print("Created directory: " + database_path)

    timestamp = time.strftime("%Y%m%d-%H%M%S")
    plt.rcParams['text.usetex'] = False  # Disable LaTeX rendering

    # 1. Load Data
    try:
        data_bundle = np.load(data_path)
        latent_space_data = data_bundle[:, :, :, 3:]  # Drop the first three entries of the last axis
        print("Successfully loaded data_bundle from: " + data_path)
        print("Shape of latent_space_data: " + str(latent_space_data.shape))
    except Exception as e:
        print("Error loading data_bundle: " + str(e))
        return

    if latent_space_data.size == 0:
        # An empty array would only fail later, inside PCA or np.min, with a
        # far less helpful message.
        print("Error loading data_bundle: latent space data is empty, shape " + str(latent_space_data.shape))
        return

    try:
        # NpzFile keeps the archive open until closed; the context manager
        # releases the file handle once both arrays have been read into memory.
        with np.load(per_ic_pca_results_path) as per_ic_results:
            all_centroids = per_ic_results['all_centroids']
            all_eigenvalues_per_ic = per_ic_results['all_eigenvalues']
        print("Successfully loaded per-IC PCA results from: " + per_ic_pca_results_path)
    except Exception as e:
        print("Error loading per-IC PCA results: " + str(e))
        return

    n_x, n_t, n_ic, n_latent_dim = latent_space_data.shape

    # Reject per-IC results that do not belong to this data set; otherwise the
    # arithmetic below would crash or silently broadcast to wrong values.
    if all_centroids.shape != (n_ic, n_latent_dim):
        print("Error loading per-IC PCA results: 'all_centroids' has shape " + str(all_centroids.shape)
              + ", expected " + str((n_ic, n_latent_dim)))
        return
    if all_eigenvalues_per_ic.ndim < 2 or all_eigenvalues_per_ic.shape[0] != n_ic:
        print("Error loading per-IC PCA results: 'all_eigenvalues' has shape " + str(all_eigenvalues_per_ic.shape)
              + ", expected " + str(n_ic) + " rows")
        return

    # 2. Global PCA
    print("\n--- Performing Global PCA ---")
    L_global = latent_space_data.reshape((n_x * n_t * n_ic, n_latent_dim))
    pca_global = PCA(n_components=n_latent_dim)
    pca_global.fit(L_global)

    cumulative_global_variance = np.cumsum(pca_global.explained_variance_ratio_)

    # Determine d_glob: smallest component count with >= 99% cumulative variance
    variance_threshold_global = 0.99
    reached = np.flatnonzero(cumulative_global_variance >= variance_threshold_global)
    # Fall back to all components if rounding keeps the cumulative sum below the threshold.
    d_glob = int(reached[0]) + 1 if reached.size else n_latent_dim

    U_glob = pca_global.components_[:d_glob, :]  # Shape (d_glob, n_latent_dim)
    mean_L_global = pca_global.mean_             # Shape (n_latent_dim,)

    print("Global PCA: Determined d_glob = " + str(d_glob) + " components to explain >= " + str(variance_threshold_global * 100) + "% of global variance.")
    print("Actual global variance explained by " + str(d_glob) + " components: " + str(cumulative_global_variance[d_glob-1] * 100) + "%")

    # 3. Per-IC Analysis in Global Subspace
    print("\n--- Analyzing Per-IC Data Projection onto Global Subspace ---")
    variance_captured_ratios = np.zeros(n_ic)

    for k in range(n_ic):
        # Total intrinsic variance of IC k (sum of its eigenvalues from per-IC PCA)
        total_variance_k = np.sum(all_eigenvalues_per_ic[k, :])
        if total_variance_k == 0:  # Avoid division by zero if an IC has no variance
            variance_captured_ratios[k] = 0.0
            print("Warning: IC " + str(k) + " has zero total intrinsic variance.")
            continue

        L_k_flat = latent_space_data[:, :, k, :].reshape((n_x * n_t, n_latent_dim))
        # Coordinates of the centroid-centered IC data in the global subspace,
        # shape (n_points, d_glob).
        L_k_projected_coords = (L_k_flat - all_centroids[k, :]) @ U_glob.T

        # NOTE(review): np.var uses ddof=0; if the Step 2 eigenvalues were
        # computed with ddof=1 the ratio is off by (n-1)/n -- confirm against Step 2.
        variance_in_global_subspace_k = np.sum(np.var(L_k_projected_coords, axis=0))
        variance_captured_ratios[k] = variance_in_global_subspace_k / total_variance_k

    # Project all centroids (relative to the global mean) in one matrix product,
    # giving global PC coordinates of shape (n_ic, d_glob).
    projected_centroids_coords = (all_centroids - mean_L_global) @ U_glob.T

    # 4. Print Quantitative Results
    print("\n--- Percentage of Per-IC Variance Captured by Global " + str(d_glob) + "-dim Subspace ---")
    for k in range(n_ic):
        print("  IC " + str(k) + ": " + str(variance_captured_ratios[k] * 100.0) + "%")

    print("\nSummary for Variance Captured Ratios:")
    print("  Mean: " + str(np.mean(variance_captured_ratios) * 100.0) + "%")
    print("  Min:  " + str(np.min(variance_captured_ratios) * 100.0) + "%")
    print("  Max:  " + str(np.max(variance_captured_ratios) * 100.0) + "%")
    print("  Std:  " + str(np.std(variance_captured_ratios) * 100.0) + "%")

    # 5. Plotting
    # Plot 1: Variance Captured per IC
    _plot_variance_captured(variance_captured_ratios, d_glob, database_path, timestamp)

    # Plot 2: Projected Centroids in Global PC Space (2D)
    if d_glob >= 2:
        _plot_projected_centroids_2d(projected_centroids_coords, d_glob, database_path, timestamp)
    else:
        print("\nSkipping 2D scatter plot of projected centroids: d_glob < 2.")

    # Plot 3: Projected Centroids in Global PC Space (3D)
    if d_glob >= 3:
        _plot_projected_centroids_3d(projected_centroids_coords, d_glob, database_path, timestamp)
    else:
        print("\nSkipping 3D scatter plot of projected centroids: d_glob < 3.")

    print("\nFinished Step 5.")


if __name__ == "__main__":
    # Script entry point: locate the inputs and run the Step 5 analysis.

    # Path of the main data bundle, as provided by the user.
    data_file_path = '/Users/fanonymous/Documents/Software/AstroPilot/Project_turbulenceV1/data_for_Paco_turbulence_bundle.npy'
    output_database_path = "data/"  # General output directory

    # Candidate per-IC PCA result files written by Step 2.
    per_ic_result_files = glob.glob(os.path.join(output_database_path, 'per_ic_pca_results_*.npz'))

    if not os.path.exists(data_file_path):
        print("Error: Data file not found at " + data_file_path)
    elif not per_ic_result_files:
        print("Error: No 'per_ic_pca_results_*.npz' file found in " + output_database_path)
        print("Please ensure Step 2 (per_ic_pca_analysis) has been run successfully.")
    else:
        # Pick the most recently written file. Modification time is used because
        # on Unix os.path.getctime reports the last metadata change (chmod,
        # rename, ...), not when the contents were produced.
        latest_per_ic_results_file = max(per_ic_result_files, key=os.path.getmtime)
        print("Using per-IC PCA results from: " + latest_per_ic_results_file)

        analyze_projection_on_global_pca(data_file_path,
                                         latest_per_ic_results_file,
                                         database_path=output_database_path)