# filename: codebase/global_pca_analysis.py
import numpy as np
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
import os
import time


# Shape the data bundle is expected to have; any other 4-D shape only triggers a warning.
_EXPECTED_BUNDLE_SHAPE = (101, 103, 25, 13)
# Index on the last axis where the latent-space components start.
_LATENT_START_INDEX = 3


def _load_data_bundle(data_path):
    """Load the array stored at ``data_path``; report the problem and return None on failure."""
    try:
        data_bundle = np.load(data_path)
    except Exception as e:
        # Deliberately broad: this is the script's reporting boundary and
        # np.load can surface several unrelated exception types.
        print("Error loading data: " + str(e))
        return None

    if not isinstance(data_bundle, np.ndarray):
        # np.load returns a lazy archive object for .npz files; it has no
        # .shape and keeps the file open, so release it and bail out.
        close = getattr(data_bundle, "close", None)
        if callable(close):
            close()
        print("Error loading data: expected a single array (.npy) but got "
              + type(data_bundle).__name__)
        return None

    print("Successfully loaded data from: " + str(data_path))
    return data_bundle


def _save_scree_plot(explained_variance_ratio, cumulative_explained_variance, plot_filename):
    """Draw the individual/cumulative explained-variance chart and save it to ``plot_filename``."""
    n_components = len(explained_variance_ratio)
    component_numbers = range(1, n_components + 1)

    plt.rcParams['text.usetex'] = False  # Disable LaTeX rendering

    fig, ax1 = plt.subplots(figsize=(10, 6))
    try:
        # Bar plot for individual explained variance (left y-axis)
        ax1.bar(component_numbers, explained_variance_ratio * 100,
                alpha=0.7, align='center', label='Individual Explained Variance (%)', color='steelblue')
        ax1.set_xlabel('Principal Component Number')
        ax1.set_ylabel('Explained Variance (%)', color='steelblue')
        ax1.tick_params(axis='y', labelcolor='steelblue')
        ax1.set_xticks(component_numbers)

        # Line plot for cumulative explained variance (right y-axis)
        ax2 = ax1.twinx()
        ax2.plot(component_numbers, cumulative_explained_variance * 100,
                 color='firebrick', marker='o', linestyle='-', linewidth=2,
                 label='Cumulative Explained Variance (%)')
        ax2.set_ylabel('Cumulative Explained Variance (%)', color='firebrick')
        ax2.tick_params(axis='y', labelcolor='firebrick')
        ax2.set_ylim(0, 105)  # Ensure 100% is visible

        # Title and grid are attached to the twin axes explicitly rather than
        # through pyplot's implicit "current axes".
        ax2.set_title('Global Latent Space: Explained Variance by Principal Components')
        fig.tight_layout()  # Adjust layout to prevent overlap

        # One combined legend built from the handles of both axes
        lines, labels = ax1.get_legend_handles_labels()
        lines2, labels2 = ax2.get_legend_handles_labels()
        ax2.legend(lines + lines2, labels + labels2, loc='center right')

        ax2.grid(True, linestyle='--', alpha=0.7)

        try:
            fig.savefig(plot_filename, dpi=300)
            print("\nScree plot saved to: " + plot_filename)
            print("Plot Description: Scree plot showing the percentage of variance explained by each of the "
                  + str(n_components)
                  + " principal components for the global latent space, along with the cumulative explained variance. The x-axis lists the principal components, the left y-axis shows individual explained variance (%), and the right y-axis shows cumulative explained variance (%).")
        except Exception as e:
            # Failing to write the image must not abort the analysis.
            print("Error saving plot: " + str(e))
    finally:
        # Always release the figure, even if drawing itself failed.
        plt.close(fig)


def perform_global_pca_analysis(data_path, database_path="data/"):
    """
    Load the data bundle, run a global PCA on its latent space, and save a scree plot.

    The latent space is taken as every channel from index 3 onwards on the
    last axis of the 4-D bundle. All positions along the first three axes are
    flattened into samples, and the mean vector, covariance matrix, PCA
    eigenvalues and explained-variance ratios are printed to stdout.

    Args:
        data_path (str): Path to the .npy data file.
        database_path (str): Directory where the scree plot is saved. It is
            created if missing; an empty string means the current directory.

    Returns:
        None. Results are printed and the plot is written to ``database_path``.
        If the data cannot be loaded or is unusable (not a 4-D array, no
        latent channels, fewer than two samples), an error is printed and the
        function returns early without producing a plot.
    """
    print("Starting Step 1: Data loading, latent space extraction, and global latent space PCA")

    # Create the output directory if needed. exist_ok guards against another
    # process creating it between the check and the call; an empty path
    # refers to the current directory and needs no creation.
    if database_path and not os.path.exists(database_path):
        os.makedirs(database_path, exist_ok=True)
        print("Created directory: " + database_path)

    # 1. Data Loading
    data_bundle = _load_data_bundle(data_path)
    if data_bundle is None:
        return

    # 2. Verification
    print("Shape of loaded data_bundle: " + str(data_bundle.shape))
    if data_bundle.shape != _EXPECTED_BUNDLE_SHAPE:
        print("Warning: Expected data dimensions " + str(_EXPECTED_BUNDLE_SHAPE)
              + ", but got " + str(data_bundle.shape))
    else:
        print("Data dimensions successfully verified: " + str(_EXPECTED_BUNDLE_SHAPE))

    # An unexpected size is tolerated (warning above), but the slicing and
    # reshaping below cannot work unless the array is 4-D and has at least one
    # channel after the latent start index.
    if data_bundle.ndim != 4 or data_bundle.shape[3] <= _LATENT_START_INDEX:
        print("Error: data must be a 4-D array with more than " + str(_LATENT_START_INDEX)
              + " channels on the last axis; got shape " + str(data_bundle.shape))
        return

    # 3. Latent Space Extraction
    latent_space_data = data_bundle[:, :, :, _LATENT_START_INDEX:]
    print("Shape of extracted latent_space_data: " + str(latent_space_data.shape))  # Expected: (101, 103, 25, 10)

    # 4. Reshape for Global Analysis: one row per position, one column per latent dimension
    n_x, n_t, n_ic, n_latent_dim = latent_space_data.shape
    L_global = latent_space_data.reshape((n_x * n_t * n_ic, n_latent_dim))
    print("Shape of L_global (reshaped for PCA): " + str(L_global.shape))  # Expected: (101*103*25, 10) = (260075, 10)

    n_samples = L_global.shape[0]
    if n_samples < 2:
        print("Error: at least 2 samples are required for covariance/PCA; got " + str(n_samples))
        return

    # 5. Compute Mean, Covariance, and Perform PCA on Global Latent Space
    print("\n--- Global Latent Space PCA ---")

    mean_L_global = np.mean(L_global, axis=0)
    print("Mean vector of L_global (shape " + str(mean_L_global.shape) + "):\n" + str(mean_L_global))

    # np.cov subtracts the column means itself, so no separately centered
    # copy of the (large) sample matrix is needed.
    cov_L_global = np.cov(L_global, rowvar=False)
    print("Covariance matrix of L_global (shape " + str(cov_L_global.shape) + "):\n" + str(cov_L_global))

    # PCA centers the data internally. The number of components cannot exceed
    # either dimension of the sample matrix.
    pca_global = PCA(n_components=min(n_samples, n_latent_dim))
    pca_global.fit(L_global)

    eigenvalues_global = pca_global.explained_variance_
    explained_variance_ratio_global = pca_global.explained_variance_ratio_
    cumulative_explained_variance_global = np.cumsum(explained_variance_ratio_global)

    print("\nGlobal PCA Results:")
    print("Eigenvalues (Explained Variance per component):")
    for i, eigval in enumerate(eigenvalues_global):
        print("  PC" + str(i + 1) + ": " + str(eigval))

    print("\nExplained Variance Ratio per component:")
    for i, ratio in enumerate(explained_variance_ratio_global):
        print("  PC" + str(i + 1) + ": " + str(ratio * 100.0) + "%")

    print("\nCumulative Explained Variance:")
    for i, cum_ratio in enumerate(cumulative_explained_variance_global):
        print("  Up to PC" + str(i + 1) + ": " + str(cum_ratio * 100.0) + "%")

    # 6. Save the scree plot; the timestamp keeps earlier runs from being overwritten
    timestamp = time.strftime("%Y%m%d-%H%M%S")
    plot_filename = os.path.join(database_path, "global_pca_scree_plot_1_" + timestamp + ".png")
    _save_scree_plot(explained_variance_ratio_global, cumulative_explained_variance_global, plot_filename)

    print("\nFinished Step 1.")


if __name__ == "__main__":
    import sys

    # Usage: python global_pca_analysis.py [DATA_FILE] [OUTPUT_DIR]
    #
    # Both arguments are optional. Without them the script falls back to the
    # machine-specific data path originally supplied by the user and to the
    # "data/" output directory, so a plain invocation behaves as before.
    default_data_file_path = '/Users/fanonymous/Documents/Software/AstroPilot/Project_turbulenceV1/data_for_Paco_turbulence_bundle.npy'
    default_output_database_path = "data/"

    cli_args = sys.argv[1:]
    data_file_path = cli_args[0] if len(cli_args) >= 1 else default_data_file_path
    output_database_path = cli_args[1] if len(cli_args) >= 2 else default_output_database_path

    # Check up front so a missing file yields a clear message rather than a
    # generic load error.
    if not os.path.exists(data_file_path):
        print("Error: Data file not found at " + data_file_path)
        print("Please ensure the file path is correct or the file is available.")
    else:
        perform_global_pca_analysis(data_file_path, database_path=output_database_path)