# filename: codebase/manifold_orientation_analysis.py
import numpy as np
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
from scipy.linalg import svdvals
import os
import time
import glob


def _load_principal_vectors(path):
    """Load the stacked per-IC principal vectors from a Step 2 ``.npz`` file.

    Args:
        path (str): Path to the ``.npz`` file written by Step 2.

    Returns:
        numpy.ndarray | None: Float array of shape
        ``(n_ic, n_components, n_features)`` with every dimension >= 1, or
        ``None`` if the file cannot be read or the array has the wrong shape
        (the reason is printed).
    """
    try:
        # NOTE(review): allow_pickle=True is kept for compatibility with the
        # Step 2 output format; it can execute pickled code, so only load
        # files produced by this pipeline.
        # The ``with`` closes the zip handle np.load keeps open for .npz files.
        with np.load(path, allow_pickle=True) as data:
            vectors = np.asarray(data['all_principal_vectors'], dtype=float)
    except Exception as e:
        print("Error loading data: " + str(e))
        return None

    if vectors.ndim != 3 or 0 in vectors.shape:
        print("Error loading data: expected a non-empty 'all_principal_vectors' array of shape "
              "(n_ic, n_components, n_features), got shape " + str(vectors.shape))
        return None

    print("Successfully loaded data from: " + str(path))
    return vectors


def _summarize_off_diagonal(matrix):
    """Print mean/min/max/std of the off-diagonal entries of a square matrix."""
    off_diagonal = matrix[~np.eye(matrix.shape[0], dtype=bool)]
    if off_diagonal.size == 0:
        # A single IC has no pairs; np.min/np.max would raise on an empty array.
        print("  (no off-diagonal entries: at least two ICs are required)")
        return
    print("  Mean: " + str(np.mean(off_diagonal)))
    print("  Min:  " + str(np.min(off_diagonal)))
    print("  Max:  " + str(np.max(off_diagonal)))
    print("  Std:  " + str(np.std(off_diagonal)))


def _save_figure(fig, path, what, description):
    """Save *fig* at 300 dpi to *path*, print the outcome, and always close it.

    *what* is the lowercase figure name used in the printed messages.
    """
    try:
        fig.savefig(path, dpi=300)
        print("\n" + what[0].upper() + what[1:] + " saved to: " + path)
        print("Plot Description: " + description)
    except Exception as e:
        # Best effort: a failed save is reported but does not abort the step.
        print("Error saving " + what + ": " + str(e))
    finally:
        plt.close(fig)


def _pc_dot_product_matrices(all_principal_vectors, d_ic):
    """Return, for each j < d_ic, the (n_ic, n_ic) matrix |<v_kj, v_lj>|.

    The absolute value makes the measure insensitive to the arbitrary sign of
    each principal vector.
    """
    matrices = []
    for j in range(d_ic):
        vectors_j = all_principal_vectors[:, j, :]  # (n_ic, n_features)
        matrices.append(np.abs(vectors_j @ vectors_j.T))
    return matrices


def _plot_dot_product_heatmaps(dot_product_matrices):
    """Return a 1 x d_ic figure of dot-product heatmaps, one per PC index."""
    d_ic = len(dot_product_matrices)
    # squeeze=False keeps ``axes`` 2-D even when d_ic == 1.
    fig, axes = plt.subplots(1, d_ic, figsize=(7 * d_ic, 6), squeeze=False)
    for j, (ax, matrix) in enumerate(zip(axes[0], dot_product_matrices)):
        pc = str(j + 1)
        im = ax.imshow(matrix, cmap='viridis', vmin=0, vmax=1)
        ax.set_title("Dot Products: PC" + pc + " (IC k) vs PC" + pc + " (IC l)")
        ax.set_xlabel("Initial Condition Index (l)")
        ax.set_ylabel("Initial Condition Index (k)")
        fig.colorbar(im, ax=ax, label="Absolute Dot Product")
    fig.tight_layout()
    return fig


def _subspace_similarity_matrix(all_principal_vectors, d_ic):
    """Return the (n_ic, n_ic) matrix of mean squared cosines of principal angles.

    Assumes the rows of each ``all_principal_vectors[k, :d_ic, :]`` are
    orthonormal (presumably true for Step 2's PCA components - TODO confirm).
    Only then are the singular values of ``U_k @ U_l.T`` the cosines of the
    principal angles between the two subspaces.
    """
    bases = all_principal_vectors[:, :d_ic, :]  # (n_ic, d_ic, n_features)
    n_ic = bases.shape[0]
    similarity = np.zeros((n_ic, n_ic))
    for a in range(n_ic):
        for b in range(a, n_ic):
            cosines = svdvals(bases[a] @ bases[b].T)
            # svdvals(A) == svdvals(A.T), so the matrix is symmetric.
            similarity[a, b] = similarity[b, a] = np.sum(cosines ** 2) / d_ic
    return similarity


def _plot_subspace_similarity_heatmap(similarity, d_ic):
    """Return a figure with a single heatmap of the subspace similarity matrix."""
    fig, ax = plt.subplots(figsize=(8, 7))
    im = ax.imshow(similarity, cmap='viridis', vmin=0, vmax=1)
    ax.set_title("Subspace Similarity (d_ic=" + str(d_ic) + ") between IC k and IC l")
    ax.set_xlabel("Initial Condition Index (l)")
    ax.set_ylabel("Initial Condition Index (k)")
    fig.colorbar(im, ax=ax, label="Avg. Squared Cosine of Principal Angles")
    fig.tight_layout()
    return fig


def _plot_pca_of_principal_vectors(all_principal_vectors, d_ic):
    """Run PCA on the set of j-th principal vectors across ICs, for each j < d_ic.

    Prints eigenvalues and explained-variance ratios for each set. Returns a
    2 x d_ic figure: scree plots (top row) and projections onto the first two
    components (bottom row).

    NOTE(review): principal vectors are defined only up to sign, and no sign
    alignment is done here. As a result, v and -v land at opposite points even
    though they describe the same direction.
    """
    n_ic, _, n_features = all_principal_vectors.shape
    # PCA can extract at most min(n_samples, n_features) components.
    n_components = min(n_ic, n_features)
    pc_numbers = list(range(1, n_components + 1))
    # squeeze=False keeps ``axes`` 2-D even when d_ic == 1.
    fig, axes = plt.subplots(2, d_ic, figsize=(7 * d_ic, 10), squeeze=False)

    for j in range(d_ic):
        pc = str(j + 1)
        vectors_j = all_principal_vectors[:, j, :]  # (n_ic, n_features)
        pca = PCA(n_components=n_components)
        projected = pca.fit_transform(vectors_j)

        explained_ratio = pca.explained_variance_ratio_
        cumulative_ratio = np.cumsum(explained_ratio)

        print("\nPCA Results for the set of PC" + pc + " vectors from all ICs:")
        print("  Number of components considered: " + str(n_components))
        print("  Eigenvalues: " + str(pca.explained_variance_))
        print("  Explained Variance Ratio (%): " + str(explained_ratio * 100))
        print("  Cumulative Explained Variance (%): " + str(cumulative_ratio * 100))

        # Top row: individual EVR as bars, cumulative EVR on a twin y-axis.
        ax_scree = axes[0, j]
        ax_scree.bar(pc_numbers, explained_ratio * 100, alpha=0.7, label='Individual EVR (%)', color='steelblue')
        ax_scree.set_xlabel('PC of PC' + pc + ' vectors')
        ax_scree.set_ylabel('Explained Var. Ratio (%)')
        ax_scree.set_title('PCA of IC-PC' + pc + ' Vectors')
        ax_scree.set_xticks(pc_numbers)

        ax_scree_twin = ax_scree.twinx()
        ax_scree_twin.plot(pc_numbers, cumulative_ratio * 100, color='firebrick', marker='o', label='Cumulative EVR (%)')
        ax_scree_twin.set_ylabel('Cumulative EVR (%)', color='firebrick')
        ax_scree_twin.tick_params(axis='y', labelcolor='firebrick')
        ax_scree_twin.set_ylim(0, 105)

        # Merge both axes' entries into a single legend.
        lines, labels = ax_scree.get_legend_handles_labels()
        lines2, labels2 = ax_scree_twin.get_legend_handles_labels()
        ax_scree_twin.legend(lines + lines2, labels + labels2, loc='best', fontsize='small')
        ax_scree.grid(True, linestyle='--', alpha=0.5)

        # Bottom row: vectors projected onto their own first two components.
        ax_scatter = axes[1, j]
        if n_components >= 2:
            ax_scatter.scatter(projected[:, 0], projected[:, 1],
                               c=np.arange(n_ic), cmap='viridis', s=30, alpha=0.8)
            ax_scatter.set_xlabel('Component 1 of PC' + pc + ' vectors')
            ax_scatter.set_ylabel('Component 2 of PC' + pc + ' vectors')
            ax_scatter.set_title('IC-PC' + pc + ' Vectors (in their PC space)')
            ax_scatter.grid(True, linestyle='--', alpha=0.5)
        else:
            ax_scatter.text(0.5, 0.5, "Not enough PCs for 2D scatter", ha='center', va='center')

    fig.tight_layout(rect=[0, 0, 0.95, 1])
    return fig


def analyze_manifold_orientations(per_ic_pca_results_path, database_path="data/", *, d_ic=3):
    """Compare the orientations of the per-IC principal subspaces from Step 2.

    For the leading ``d_ic`` principal vectors of every initial condition (IC),
    this prints summary statistics and writes three PNG figures to
    ``database_path``:

    * absolute dot products between the j-th principal vectors of each IC pair;
    * subspace similarity (mean squared cosine of the principal angles)
      between each pair of ``d_ic``-dimensional principal subspaces;
    * a PCA of the set of j-th principal vectors across ICs (skipped when
      fewer than two ICs are available).

    Args:
        per_ic_pca_results_path (str): Path to the .npz file from Step 2. It
            must contain ``all_principal_vectors`` with shape
            ``(n_ic, n_components, n_features)``.
        database_path (str): Directory the plots are saved to. It is created
            if missing.
        d_ic (int): Number of leading principal vectors to analyze. Defaults
            to 3, the intrinsic dimension Step 2 reported for 95% variance.
            Clamped to the number of available components.

    Returns:
        None. Load failures are printed and end the step early.

    Raises:
        ValueError: If ``d_ic`` is less than 1.
    """
    if d_ic < 1:
        raise ValueError("d_ic must be at least 1, got " + str(d_ic))

    print("Starting Step 4: Comparative analysis of manifold orientations and relationships")

    if not os.path.exists(database_path):
        os.makedirs(database_path)
        print("Created directory: " + str(database_path))

    timestamp = time.strftime("%Y%m%d-%H%M%S")

    # 1. Load data
    all_principal_vectors = _load_principal_vectors(per_ic_pca_results_path)
    if all_principal_vectors is None:
        return

    n_ic, n_pca_components, n_features = all_principal_vectors.shape
    if d_ic > n_pca_components:
        print("Warning: requested d_ic = " + str(d_ic) + " exceeds the " + str(n_pca_components)
              + " available PCA components; using d_ic = " + str(n_pca_components) + ".")
        d_ic = n_pca_components
    print("Number of ICs: " + str(n_ic) + ", Latent features: " + str(n_features) + ", PCA components per IC: " + str(n_pca_components))
    print("Using d_ic = " + str(d_ic) + " for subspace analysis.")

    # 2. Alignment of principal vectors (dot products)
    print("\n--- 2. Alignment of Principal Vectors (Dot Products) ---")
    dot_product_matrices = _pc_dot_product_matrices(all_principal_vectors, d_ic)
    for j, dot_matrix in enumerate(dot_product_matrices):
        print("\nSummary for PC" + str(j + 1) + " vs PC" + str(j + 1) + " dot products (excluding diagonal):")
        _summarize_off_diagonal(dot_matrix)

    plt.rcParams['text.usetex'] = False
    _save_figure(
        _plot_dot_product_heatmaps(dot_product_matrices),
        os.path.join(database_path, "dot_product_heatmaps_4_" + timestamp + ".png"),
        "dot product heatmaps",
        "Heatmaps showing the absolute dot product between the " + str(d_ic) + " leading principal vectors (PC1 vs PC1, PC2 vs PC2, etc.) for all pairs of initial conditions.",
    )

    # 3. Subspace similarity analysis
    print("\n--- 3. Subspace Similarity Analysis (d_ic = " + str(d_ic) + ") ---")
    subspace_similarity_matrix = _subspace_similarity_matrix(all_principal_vectors, d_ic)
    print("\nSummary for Subspace Similarities (d_ic=" + str(d_ic) + ", excluding diagonal):")
    _summarize_off_diagonal(subspace_similarity_matrix)
    _save_figure(
        _plot_subspace_similarity_heatmap(subspace_similarity_matrix, d_ic),
        os.path.join(database_path, "subspace_similarity_heatmap_4_" + timestamp + ".png"),
        "subspace similarity heatmap",
        "Heatmap showing the similarity between the " + str(d_ic) + "-dimensional principal subspaces for all pairs of initial conditions. Similarity is measured as the average squared cosine of principal angles.",
    )

    # 4. PCA of the sets of principal vectors
    print("\n--- 4. PCA of Sets of Principal Vectors ---")
    if n_ic < 2:
        # PCA on a single sample divides by zero when computing variances.
        print("Skipping: PCA of principal vectors requires at least two ICs.")
    else:
        vector_names = ", ".join("v_k" + str(j + 1) for j in range(d_ic))
        _save_figure(
            _plot_pca_of_principal_vectors(all_principal_vectors, d_ic),
            os.path.join(database_path, "pca_of_principal_vectors_4_" + timestamp + ".png"),
            "PCA of principal vectors plots",
            "Figure with two rows. Top row: scree plots for PCA performed on the set of each of the first " + str(d_ic) + " principal vectors (" + vector_names + ") from all ICs. Bottom row: scatter plots of these " + vector_names + " vectors projected onto their own first two principal components, colored by IC index.",
        )

    print("\nFinished Step 4.")


def main(output_database_path="data/"):
    """Run Step 4 on the most recently written Step 2 results file.

    Args:
        output_database_path (str): Directory searched for
            ``per_ic_pca_results_*.npz`` files. Plots are also written here.

    Returns:
        str | None: Path of the results file that was analyzed, or ``None`` if
        no matching file was found.
    """
    list_of_files = glob.glob(os.path.join(output_database_path, 'per_ic_pca_results_*.npz'))
    if not list_of_files:
        print("Error: No 'per_ic_pca_results_*.npz' file found in " + output_database_path)
        print("Please ensure Step 2 (per_ic_pca_analysis) has been run successfully.")
        return None

    # Use getmtime (last content write), not getctime. The meaning of getctime
    # is platform-dependent: creation time on Windows, last metadata change on
    # Unix.
    latest_per_ic_results_file = max(list_of_files, key=os.path.getmtime)
    print("Using per-IC PCA results from: " + latest_per_ic_results_file)
    analyze_manifold_orientations(latest_per_ic_results_file, database_path=output_database_path)
    return latest_per_ic_results_file


if __name__ == "__main__":
    main()
