# filename: codebase/physics_informed_decomposition.py
import pandas as pd
import numpy as np
import pickle
import os
import time
import matplotlib
matplotlib.use('Agg')  # Use non-interactive backend
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.stats import gaussian_kde
from sklearn.preprocessing import StandardScaler

def _pairwise_kde_jsd(samples_a, samples_b, epsilon=1e-100):
    """
    Estimate the Jensen-Shannon Divergence between two sample sets via KDE.

    Both sample sets are z-scored with the statistics of their pooled data,
    a Gaussian KDE (Silverman bandwidth) is fitted to each set, and the two
    KL terms of the JSD are approximated by averaging log-density ratios
    over each set's own samples.

    Args:
        samples_a (np.ndarray): Samples of the first model, shape (n_samples, n_dims).
        samples_b (np.ndarray): Samples of the second model, shape (n_samples, n_dims).
        epsilon (float): Small constant added inside the logs to avoid log(0).

    Returns:
        float: The estimated JSD, clamped below at 0. NaN is returned
               unchanged (not clamped) so that a failed estimate stays visible.
    """
    # Standardize with the pooled statistics so both sets share one frame.
    scaler = StandardScaler().fit(np.vstack((samples_a, samples_b)))
    # gaussian_kde expects shape (n_dims, n_samples), hence the transposes.
    points_a = scaler.transform(samples_a).T
    points_b = scaler.transform(samples_b).T

    kde_a = gaussian_kde(points_a, bw_method='silverman')
    kde_b = gaussian_kde(points_b, bw_method='silverman')

    # Each density evaluated on its own samples, plus the mixture M = (A + B) / 2.
    pdf_a_on_a = kde_a(points_a)
    pdf_b_on_b = kde_b(points_b)
    m_on_a = 0.5 * (pdf_a_on_a + kde_b(points_a))
    m_on_b = 0.5 * (kde_a(points_b) + pdf_b_on_b)

    # KL(A || M) and KL(B || M) as sample means of the log-density ratios.
    kl_a_m = np.mean(np.log(pdf_a_on_a + epsilon) - np.log(m_on_a + epsilon))
    kl_b_m = np.mean(np.log(pdf_b_on_b + epsilon) - np.log(m_on_b + epsilon))

    # Clamp tiny negative values caused by numerical noise. np.maximum is used
    # instead of the builtin max() because max(0, nan) evaluates to 0, which
    # would silently report a failed estimate as "no divergence".
    return np.maximum(0.5 * (kl_a_m + kl_b_m), 0.0)


def calculate_subspace_jsd(dataframes, model_names, subspaces):
    """
    Calculates pairwise multi-dimensional JSD for defined parameter subspaces.

    For each subspace, this function iterates through all unordered pairs of
    distinct models, standardizes the data within the subspace for the pair,
    builds a multi-dimensional KDE for each model's samples, and computes the
    Jensen-Shannon Divergence (JSD) between the two models' distributions.
    The diagonal of every matrix is left at 0 and the matrix is symmetric.

    Args:
        dataframes (dict): A dictionary of pandas DataFrames, keyed by model name.
        model_names (list): An ordered list of model names.
        subspaces (dict): A dictionary where keys are subspace names and values
                          are lists of parameter names in that subspace.

    Returns:
        dict: A dictionary where keys are subspace names and values are
              pandas DataFrames containing the pairwise JSD matrices. An entry
              is NaN when the estimate for that pair was not finite.

    Raises:
        KeyError: If a model name is missing from `dataframes` or a parameter
                  is missing from a model's DataFrame.
    """
    print("--- Step 4.2: Quantifying Subspace-Specific Discrepancies ---")
    print("\nKDE Setup Documentation:")
    print("- Method: Multi-dimensional Kernel Density Estimation using scipy.stats.gaussian_kde.")
    print("- Bandwidth Selection: Silverman's rule ('silverman') is used for all KDEs.")
    print("- JSD Calculation: A Monte Carlo approximation of JSD is used. This approach is deterministic as it utilizes all available samples.")
    print("- Standardization: For each pairwise comparison, parameters are standardized (z-scored) to ensure equal weighting in the KDE.\n")

    jsd_subspace_matrices = {}
    epsilon = 1e-100  # Small constant to avoid log(0)
    n_models = len(model_names)  # Loop-invariant: same models for every subspace

    for subspace_name, params in subspaces.items():
        print("----------------------------------------------------")
        print("Processing Subspace: " + subspace_name + " (" + str(len(params)) + " dimensions)")
        jsd_matrix = pd.DataFrame(np.zeros((n_models, n_models)), index=model_names, columns=model_names)

        # Visit each unordered pair (i < j) once; the diagonal stays at zero.
        for i in range(n_models):
            model_a_name = model_names[i]
            for j in range(i + 1, n_models):
                model_b_name = model_names[j]

                print("  Calculating JSD for " + model_a_name + " vs " + model_b_name + "...")

                # Extract samples for the current subspace
                samples_a = dataframes[model_a_name][params].values
                samples_b = dataframes[model_b_name][params].values

                jsd = _pairwise_kde_jsd(samples_a, samples_b, epsilon)
                if not np.isfinite(jsd):
                    print("  Warning: JSD for " + model_a_name + " vs " + model_b_name + " is not finite; storing NaN.")

                # JSD is symmetric, so fill both triangles.
                jsd_matrix.loc[model_a_name, model_b_name] = jsd
                jsd_matrix.loc[model_b_name, model_a_name] = jsd

        jsd_subspace_matrices[subspace_name] = jsd_matrix
        print("\nJSD Matrix for Subspace: " + subspace_name)
        print(jsd_matrix)

    return jsd_subspace_matrices


def plot_jsd_heatmaps(jsd_matrices, subspaces, output_dir):
    """
    Generates and saves heatmaps of the JSD matrices for each subspace.

    One annotated heatmap is drawn per subspace on a two-column grid whose
    number of rows grows with the number of subspaces (never fewer than two
    rows, so up to four subspaces use a 2x2 grid). Unused grid cells are
    hidden. The figure is written to `output_dir` as a timestamped PNG.

    Args:
        jsd_matrices (dict): Dictionary of JSD matrices for each subspace.
        subspaces (dict): Dictionary defining the subspaces; its key order
                          determines the panel order.
        output_dir (str): The directory where the plot will be saved.

    Returns:
        None. If `subspaces` is empty, nothing is plotted or saved.

    Raises:
        KeyError: If a subspace name has no entry in `jsd_matrices`.
    """
    print("\n--- Step 4.3: Generating and Saving JSD Heatmaps ---")

    subspace_names = list(subspaces.keys())
    n_subspaces = len(subspace_names)
    if n_subspaces == 0:
        # Nothing to draw; avoid creating and saving an empty figure.
        print("No subspaces defined; skipping heatmap generation.")
        return

    # Two columns; enough rows (ceil division) to hold every subspace. The
    # minimum of two rows keeps the 2x2 / (16, 14) layout for <= 4 subspaces.
    n_cols = 2
    n_rows = max(2, (n_subspaces + n_cols - 1) // n_cols)
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(8 * n_cols, 7 * n_rows))
    axes = axes.flatten()

    for ax, subspace_name in zip(axes, subspace_names):
        matrix = jsd_matrices[subspace_name]
        sns.heatmap(matrix, annot=True, fmt=".3f", cmap="cividis", ax=ax, linewidths=.5)
        ax.set_title("JSD Matrix for " + subspace_name + " Subspace", fontsize=14)
        ax.tick_params(axis='x', rotation=45)
        ax.tick_params(axis='y', rotation=0)

    # Hide any unused subplots
    for ax in axes[n_subspaces:]:
        ax.set_visible(False)

    fig.suptitle("Physics-Informed Discrepancy Decomposition via JSD", fontsize=20, y=1.02)
    plt.tight_layout()

    # Timestamped name so repeated runs do not overwrite earlier figures.
    timestamp = time.strftime("%Y%m%d-%H%M%S")
    filename = 'subspace_jsd_heatmaps_4_' + timestamp + '.png'
    filepath = os.path.join(output_dir, filename)
    plt.savefig(filepath, dpi=300, bbox_inches='tight')
    plt.close(fig)

    print("\nSuccessfully saved JSD heatmaps to: " + filepath)
    # The wording below mirrors the default four-subspace configuration.
    print("Plot Description: Heatmaps of pairwise Jensen-Shannon Divergence (JSD) for each of the four defined physical parameter subspaces. Higher values indicate greater disagreement between models within that specific physical domain.")

def main():
    """
    Run Step 4 of the analysis end to end.

    Defines the model list and the physical parameter subspaces, loads the
    per-model DataFrames from the input pickle, computes the subspace JSD
    matrices, pickles them into the output directory, and renders the
    heatmaps. Returns early (after printing an error) when the input file
    is missing or cannot be unpickled. A failure while saving the matrices
    is reported but does not stop the plotting step.
    """
    print("--- Starting Step 4: Physics-Informed Discrepancy Decomposition ---")

    # --- Configuration ---
    source_pickle = 'data/gw_data_all_models.pkl'
    results_dir = 'data'

    waveform_models = [
        'NRSur7dq4',
        'IMRPhenomXO4a',
        'SEOBNRv5PHM',
        'IMRPhenomXPHM',
        'IMRPhenomTPHM',
    ]

    # 1. Define Physical Parameter Subspaces
    print("\n--- Step 4.1: Defining Physical Parameter Subspaces ---")
    parameter_groups = {
        'Mass & Distance': ['mass_1_source', 'mass_2_source', 'redshift'],
        'Effective Spin': ['chi_eff', 'chi_p'],
        'Individual Spin & Orientation': [
            'a_1', 'a_2', 'cos_tilt_1', 'cos_tilt_2', 'cos_theta_jn', 'phi_jl',
        ],
        'Remnant Properties': ['final_mass_source', 'final_spin'],
    }
    for group_name, group_params in parameter_groups.items():
        print("- {}: {}".format(group_name, group_params))

    # --- Load Data ---
    print("\nLoading preprocessed data from " + source_pickle)
    if not os.path.exists(source_pickle):
        print("Error: Input data file not found. Please run Step 1 first.")
        return

    # NOTE: unpickling can execute arbitrary code; only load trusted files.
    try:
        with open(source_pickle, 'rb') as handle:
            samples_by_model = pickle.load(handle)
        print("Data loaded successfully.")
    except Exception as load_error:
        print("Error loading pickle file: {}".format(load_error))
        return

    # --- Run Analysis ---
    matrices = calculate_subspace_jsd(samples_by_model, waveform_models, parameter_groups)

    # Save the results (best effort: a failure here is reported, not fatal)
    matrices_path = os.path.join(results_dir, 'subspace_jsd_matrices.pkl')
    try:
        with open(matrices_path, 'wb') as handle:
            pickle.dump(matrices, handle)
        print("\nSuccessfully saved subspace JSD matrices to: " + matrices_path)
    except Exception as save_error:
        print("Error saving JSD matrices to pickle file: {}".format(save_error))

    # --- Plot Results ---
    plot_jsd_heatmaps(matrices, parameter_groups, results_dir)

    print("\n--- Step 4 execution completed successfully. ---")


if __name__ == '__main__':
    # Create the data directory up front when it is absent, then run Step 4.
    data_dir = 'data'
    data_dir_present = os.path.exists(data_dir)
    if not data_dir_present:
        os.makedirs(data_dir)
    main()