# filename: codebase/per_ic_pca_analysis.py
import numpy as np
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
import os
import time

def perform_per_ic_pca_analysis(data_path, database_path="data/"):
    """
    Performs per-initial condition (IC) PCA on the latent space data.

    The loaded array is sliced as ``[:, :, :, 3:]``: the third axis indexes the
    ICs and the channels from index 3 onward on the last axis are used as the
    latent features. For every IC the first two axes are flattened into
    samples, the centroid is computed and a PCA with as many components as
    latent features is fitted. Per-IC results and IC-averaged summary
    statistics are printed, the per-IC arrays are written to a timestamped
    ``.npz`` file, and two PNG plots (average scree plot and a histogram of the
    number of components needed to reach 95% cumulative variance) are written
    to the same directory.

    Problems with the input (unreadable file, wrong rank, no latent channels,
    too few samples) and failures while saving outputs are reported with
    ``print`` instead of being raised.

    Args:
        data_path (str): Path to the .npy data file.
        database_path (str): Path to the directory where results (data and plots) will be saved.

    Returns:
        None
    """
    print("Starting Step 2: Per-initial condition (IC) latent space analysis and PCA")

    # Create database_path directory if it doesn't exist. An empty path means
    # "current directory", which os.makedirs cannot create, so it is skipped.
    if database_path and not os.path.exists(database_path):
        # exist_ok covers the directory appearing between the check and the call.
        os.makedirs(database_path, exist_ok=True)
        print("Created directory: " + str(database_path))

    # Timestamp for unique filenames
    timestamp = time.strftime("%Y%m%d-%H%M%S")

    # 1. Data Loading
    try:
        data_bundle = np.load(data_path)
    except Exception as e:  # deliberately broad: np.load raises several unrelated types
        print("Error loading data: " + str(e))
        return
    print("Successfully loaded data from: " + str(data_path))

    # The slicing below needs four axes and at least one channel past index 2.
    bundle_shape = getattr(data_bundle, "shape", None)
    if bundle_shape is None or len(bundle_shape) != 4 or bundle_shape[3] <= 3:
        print("Error: expected a 4-D array with more than 3 channels on the last axis, got shape "
              + str(bundle_shape))
        return

    # Extract latent space data
    latent_space_data = data_bundle[:, :, :, 3:]
    n_x, n_t, n_ic, n_latent_dim = latent_space_data.shape
    print("Shape of extracted latent_space_data: " + str(latent_space_data.shape))

    # A full-rank PCA needs at least as many samples as components, and the
    # summary statistics need at least one IC.
    n_samples = n_x * n_t
    min_samples = max(n_latent_dim, 2)
    if n_ic == 0 or n_samples < min_samples:
        print("Error: need at least one IC and at least " + str(min_samples)
              + " samples per IC, got " + str(n_ic) + " ICs with "
              + str(n_samples) + " samples each")
        return

    # Initialize storage for per-IC results
    all_centroids = np.zeros((n_ic, n_latent_dim))
    all_principal_vectors = np.zeros((n_ic, n_latent_dim, n_latent_dim))  # (n_ic, n_components, n_features)
    all_eigenvalues = np.zeros((n_ic, n_latent_dim))
    all_explained_variance_ratios = np.zeros((n_ic, n_latent_dim))
    all_cumulative_explained_variance = np.zeros((n_ic, n_latent_dim))
    intrinsic_dims_95_variance = np.zeros(n_ic, dtype=int)

    variance_threshold = 0.95

    # 2. Per-IC PCA Loop
    print("\n--- Per-Initial Condition PCA Results ---")
    for k in range(n_ic):
        print("\n--- PCA for Initial Condition " + str(k) + " ---")
        # Flatten the first two axes into samples: (n_x * n_t, n_latent_dim)
        L_k_flat = latent_space_data[:, :, k, :].reshape((n_samples, n_latent_dim))

        # Centroid
        C_k = np.mean(L_k_flat, axis=0)
        all_centroids[k, :] = C_k
        print("Centroid (C_" + str(k) + "): " + str(C_k))

        # PCA
        pca_k = PCA(n_components=n_latent_dim)
        pca_k.fit(L_k_flat)

        eigenvalues_k = pca_k.explained_variance_
        explained_variance_ratio_k = pca_k.explained_variance_ratio_

        all_eigenvalues[k, :] = eigenvalues_k
        all_explained_variance_ratios[k, :] = explained_variance_ratio_k
        all_principal_vectors[k, :, :] = pca_k.components_

        print("Eigenvalues: " + str(eigenvalues_k))
        print("Explained Variance Ratio (%): " + str(explained_variance_ratio_k * 100.0))

        cumulative_explained_k = np.cumsum(explained_variance_ratio_k)
        all_cumulative_explained_variance[k, :] = cumulative_explained_k
        print("Cumulative Explained Variance (%): " + str(cumulative_explained_k * 100.0))

        # Intrinsic dimensionality: first component count whose cumulative
        # ratio reaches the threshold; fall back to all components if none does.
        reached = np.flatnonzero(cumulative_explained_k >= variance_threshold)
        dim_95 = int(reached[0]) + 1 if reached.size else n_latent_dim
        intrinsic_dims_95_variance[k] = dim_95
        print("Intrinsic Dimensionality (" + str(variance_threshold * 100) + "% variance): " + str(dim_95) + " components")

    # 3. Summary Statistics Calculation
    print("\n\n--- Summary of Per-IC PCA Results (averaged over " + str(n_ic) + " ICs) ---")
    avg_eigenvalues = np.mean(all_eigenvalues, axis=0)
    avg_explained_variance_ratios = np.mean(all_explained_variance_ratios, axis=0)
    avg_cumulative_explained_variance = np.mean(all_cumulative_explained_variance, axis=0)
    std_cumulative_explained_variance = np.std(all_cumulative_explained_variance, axis=0)

    print("Principal Component | Avg. Eigenvalue | Avg. Var. Explained (%) | Avg. Cum. Var. Explained (%) | StdDev of Cum. Var. Expl. (%)")
    print("--------------------|-----------------|-------------------------|------------------------------|-------------------------------")
    summary_rows = zip(avg_eigenvalues,
                       avg_explained_variance_ratios,
                       avg_cumulative_explained_variance,
                       std_cumulative_explained_variance)
    for pc_number, (eig, ratio, cum, cum_std) in enumerate(summary_rows, start=1):
        # str.format is required for "{...}" templates; the "%" operator
        # raises TypeError on them.
        pc_str = "PC" + str(pc_number)
        avg_eig_str = "{:.4f}".format(eig)
        avg_var_exp_str = "{:.2f}".format(ratio * 100.0)
        avg_cum_var_exp_str = "{:.2f}".format(cum * 100.0)
        std_cum_var_exp_str = "{:.2f}".format(cum_std * 100.0)
        print(pc_str.ljust(20) + "|" + avg_eig_str.rjust(17) + "|" + avg_var_exp_str.rjust(25) + "|" + avg_cum_var_exp_str.rjust(30) + "|" + std_cum_var_exp_str.rjust(31))

    print("\n--- Distribution of Intrinsic Dimensionalities (" + str(variance_threshold * 100) + "% variance threshold) ---")
    print("Values: " + str(intrinsic_dims_95_variance))
    print("Min: " + str(np.min(intrinsic_dims_95_variance)))
    print("Max: " + str(np.max(intrinsic_dims_95_variance)))
    print("Mean: " + str(np.mean(intrinsic_dims_95_variance)))
    print("Median: " + str(np.median(intrinsic_dims_95_variance)))
    print("Std Dev: " + str(np.std(intrinsic_dims_95_variance)))

    # 4. Data Saving (best effort: a failed save must not prevent the plots)
    npz_filename = os.path.join(database_path, "per_ic_pca_results_" + timestamp + ".npz")
    try:
        np.savez(npz_filename,
                 all_centroids=all_centroids,
                 all_principal_vectors=all_principal_vectors,
                 all_eigenvalues=all_eigenvalues,
                 all_explained_variance_ratios=all_explained_variance_ratios,
                 all_cumulative_explained_variance=all_cumulative_explained_variance,
                 intrinsic_dims_95_variance=intrinsic_dims_95_variance,
                 data_path=str(data_path),
                 variance_threshold=variance_threshold)
        print("\nPer-IC PCA results saved to: " + npz_filename)
    except Exception as e:
        print("Error saving .npz file: " + str(e))

    # 5. Plotting
    plt.rcParams['text.usetex'] = False  # Disable LaTeX rendering

    # Plot 1: Average Per-IC Scree Plot
    scree_fig = _build_average_scree_figure(avg_explained_variance_ratios,
                                            avg_cumulative_explained_variance,
                                            std_cumulative_explained_variance)
    _save_and_close_figure(
        scree_fig,
        os.path.join(database_path, "per_ic_avg_scree_plot_2_" + timestamp + ".png"),
        "\nAverage per-IC scree plot saved to: ",
        "Plot Description: Scree plot showing the average percentage of variance explained by each principal component across "
        + str(n_ic)
        + " ICs. Blue bars represent average individual explained variance. Red line shows average cumulative explained variance, with shaded area indicating +/- 1 standard deviation.",
        "Error saving plot 1: ",
    )

    # Plot 2: Distribution of Intrinsic Dimensionalities
    hist_fig = _build_intrinsic_dim_histogram(intrinsic_dims_95_variance, variance_threshold)
    _save_and_close_figure(
        hist_fig,
        os.path.join(database_path, "intrinsic_dim_dist_plot_2_" + timestamp + ".png"),
        "\nIntrinsic dimensionality distribution plot saved to: ",
        "Plot Description: Histogram showing the distribution of intrinsic dimensionalities (number of principal components needed to explain "
        + str(variance_threshold * 100)
        + "% of variance) for the "
        + str(n_ic)
        + " initial conditions.",
        "Error saving plot 2: ",
    )

    print("\nFinished Step 2.")


def _build_average_scree_figure(avg_ratios, avg_cumulative, std_cumulative):
    """Build the scree figure: bars for average individual variance, line plus band for cumulative."""
    pc_numbers = range(1, len(avg_ratios) + 1)
    fig, ax_bar = plt.subplots(figsize=(12, 7))

    # Bar plot for average individual explained variance
    ax_bar.bar(pc_numbers, avg_ratios * 100,
               alpha=0.7, align='center', label='Avg. Individual Explained Variance (%)', color='deepskyblue')
    ax_bar.set_xlabel('Principal Component Number')
    ax_bar.set_ylabel('Avg. Explained Variance (%)', color='deepskyblue')
    ax_bar.tick_params(axis='y', labelcolor='deepskyblue')
    ax_bar.set_xticks(pc_numbers)
    ax_bar.set_ylim(0, max(avg_ratios * 100) * 1.1)

    # Line plot for average cumulative explained variance on a second y-axis
    ax_line = ax_bar.twinx()
    ax_line.plot(pc_numbers, avg_cumulative * 100,
                 color='crimson', marker='o', linestyle='-', linewidth=2, label='Avg. Cumulative Explained Variance (%)')
    # Shaded region for std dev of cumulative variance
    ax_line.fill_between(pc_numbers,
                         (avg_cumulative - std_cumulative) * 100,
                         (avg_cumulative + std_cumulative) * 100,
                         color='crimson', alpha=0.2, label='Std. Dev. of Cum. Var. Expl.')
    ax_line.set_ylabel('Avg. Cumulative Explained Variance (%)', color='crimson')
    ax_line.tick_params(axis='y', labelcolor='crimson')
    ax_line.set_ylim(0, 105)

    ax_line.set_title('Average Per-IC Latent Space: Explained Variance by Principal Components')
    fig.tight_layout()

    # One legend holding the entries of both y-axes
    bar_handles, bar_labels = ax_bar.get_legend_handles_labels()
    line_handles, line_labels = ax_line.get_legend_handles_labels()
    ax_line.legend(bar_handles + line_handles, bar_labels + line_labels, loc='center right')

    ax_line.grid(True, linestyle='--', alpha=0.7)
    return fig


def _build_intrinsic_dim_histogram(intrinsic_dims, variance_threshold):
    """Build a histogram of per-IC intrinsic dimensionalities with one bin per integer value."""
    fig, ax = plt.subplots(figsize=(10, 6))
    min_dim = np.min(intrinsic_dims)
    max_dim = np.max(intrinsic_dims)
    if min_dim == max_dim:  # All ICs share one value: a single bin centered on it
        bins = [min_dim - 0.5, min_dim + 0.5]
    else:
        bins = np.arange(min_dim, max_dim + 2) - 0.5  # Bin edges halfway between integers

    ax.hist(intrinsic_dims, bins=bins, rwidth=0.8, color='mediumseagreen', alpha=0.9)
    ax.set_xlabel('Intrinsic Dimensionality (Number of PCs for ' + str(variance_threshold * 100) + '% Variance)')
    ax.set_ylabel('Number of Initial Conditions (Frequency)')
    ax.set_title('Distribution of Intrinsic Dimensionalities Across ICs')
    ax.set_xticks(np.arange(min_dim, max_dim + 1))  # Ensure integer ticks
    ax.grid(axis='y', linestyle='--', alpha=0.7)

    fig.tight_layout()
    return fig


def _save_and_close_figure(fig, filename, saved_prefix, description, error_prefix):
    """Save ``fig`` as a 300-dpi image, report the outcome with print, and always close the figure."""
    try:
        fig.savefig(filename, dpi=300)
        print(saved_prefix + filename)
        print(description)
    except Exception as e:  # best effort: a failed save is reported, not raised
        print(error_prefix + str(e))
    finally:
        plt.close(fig)


if __name__ == "__main__":
    import sys

    # Script entry point. With no arguments this reproduces the original
    # hard-coded run; optional positional arguments override the defaults:
    #     python per_ic_pca_analysis.py [data_file_path] [output_database_path]
    cli_args = sys.argv[1:]
    data_file_path = (
        cli_args[0]
        if len(cli_args) >= 1
        else '/Users/fanonymous/Documents/Software/AstroPilot/Project_turbulenceV1/data_for_Paco_turbulence_bundle.npy'
    )
    output_database_path = cli_args[1] if len(cli_args) >= 2 else "data/"

    # Check up front so a wrong path gets an explicit hint instead of only the loader's error.
    if not os.path.exists(data_file_path):
        print("Error: Data file not found at " + data_file_path)
        print("Please ensure the file path is correct or the file is available.")
    else:
        perform_per_ic_pca_analysis(data_file_path, database_path=output_database_path)