# filename: codebase/baseline_comparison.py
import pandas as pd
import numpy as np
import pickle
import os
import time
from scipy.stats import gaussian_kde, wasserstein_distance, entropy
import matplotlib
matplotlib.use('Agg')  # Use non-interactive backend
import matplotlib.pyplot as plt


def calculate_summary_statistics(dataframes, key_parameters, model_names):
    """
    Tabulate and print posterior summary statistics for every model.

    For each (parameter, model) combination the median and the bounds of the
    central 90% interval (5th and 95th percentiles) are read from the model's
    samples. Rows are ordered parameter-first, then by model, following the
    order of the supplied lists.

    Side effect: the global pandas display options (max rows, max columns,
    width) are widened so that the table is printed without truncation.

    Args:
        dataframes (dict): Mapping of model name to a pandas DataFrame of samples.
        key_parameters (list): Column names to summarise.
        model_names (list): Model names, in the order they should appear.

    Returns:
        pd.DataFrame: One row per (Parameter, Model) pair, with the columns
                      'Median', '5th Percentile' and '95th Percentile'.
    """
    print("--- Step 2.1: Calculating Summary Statistics ---")
    levels = [0.05, 0.5, 0.95]

    def _summarise(param, model):
        # A single quantile call yields all three levels, keyed by level.
        q = dataframes[model][param].quantile(levels)
        return {
            'Parameter': param,
            'Model': model,
            'Median': q[0.5],
            '5th Percentile': q[0.05],
            '95th Percentile': q[0.95],
        }

    rows = [
        _summarise(param, model)
        for param in key_parameters
        for model in model_names
    ]
    table = pd.DataFrame(rows).set_index(['Parameter', 'Model'])

    # Make pandas print every row and column of the table.
    display_settings = (
        ('display.max_rows', None),
        ('display.max_columns', None),
        ('display.width', 200),
    )
    for option_name, option_value in display_settings:
        pd.set_option(option_name, option_value)

    print("Summary Statistics (Median and 90% Credible Interval):")
    print(table)
    print("\nSummary statistics calculation complete.")
    return table


def _report_kde_bandwidths(dataframes, key_parameters, model_names):
    """Print Scott vs. Silverman KDE bandwidth factors for one sample column."""
    print("\nKDE Bandwidth Selection Report:")
    print("Comparing Scott's and Silverman's rules for bandwidth estimation.")

    # Prefer the historical reference column; otherwise fall back to the first
    # requested model/parameter so the report works for any set of inputs
    # instead of raising KeyError.
    if 'NRSur7dq4' in dataframes:
        sample_model = 'NRSur7dq4'
    elif model_names:
        sample_model = model_names[0]
    else:
        sample_model = None

    sample_param = None
    if sample_model is not None:
        if 'mass_1_source' in dataframes[sample_model].columns:
            sample_param = 'mass_1_source'
        elif key_parameters:
            sample_param = key_parameters[0]

    if sample_param is None:
        print("No model/parameter available for a bandwidth comparison; skipping it.")
    else:
        sample_data = dataframes[sample_model][sample_param]
        bw_scott = gaussian_kde(sample_data, bw_method='scott').factor
        bw_silverman = gaussian_kde(sample_data, bw_method='silverman').factor
        print("For a sample parameter (" + sample_param + "):")
        print("  - Scott's rule bandwidth factor: " + str(bw_scott))
        print("  - Silverman's rule bandwidth factor: " + str(bw_silverman))

    print("Silverman's rule typically provides a slightly larger bandwidth, leading to smoother KDEs.")
    print("We will proceed using Silverman's rule for all KDE-based calculations.\n")


def _jsd_on_grid(kde_a, kde_b, grid):
    """Return the Jensen-Shannon divergence (in nats) of two KDEs sampled on grid."""
    pdf_a = kde_a(grid)
    pdf_b = kde_b(grid)

    # Turn the sampled densities into discrete probability vectors.
    pdf_a /= pdf_a.sum()
    pdf_b /= pdf_b.sum()

    # Replace exact zeros with a tiny constant so no log/ratio sees a zero.
    pdf_a = np.where(pdf_a == 0, 1e-100, pdf_a)
    pdf_b = np.where(pdf_b == 0, 1e-100, pdf_b)

    # JSD = 0.5 * KL(a || m) + 0.5 * KL(b || m), with m the mixture of a and b.
    mixture = 0.5 * (pdf_a + pdf_b)
    return 0.5 * (entropy(pdf_a, mixture, base=np.e) + entropy(pdf_b, mixture, base=np.e))


def calculate_divergence_matrices(dataframes, key_parameters, model_names):
    """
    Computes pairwise JSD and Wasserstein distances for key parameters.

    For every key parameter and every unordered pair of models, this function
    calculates the 1-Wasserstein distance directly from the samples and the
    Jensen-Shannon Divergence (JSD, natural-log units) from Silverman-rule
    Gaussian KDEs evaluated on a shared 1000-point grid spanning both sample
    sets. Both matrices are symmetric with a zero diagonal. All matrices and a
    short KDE bandwidth report are printed to stdout.

    The bandwidth report uses dataframes['NRSur7dq4']['mass_1_source'] when it
    is available and otherwise falls back to the first supplied model and
    parameter, so the function also works on inputs without that column.

    Args:
        dataframes (dict): A dictionary of pandas DataFrames, keyed by model name.
        key_parameters (list): A list of column names to analyze.
        model_names (list): An ordered list of model names.

    Returns:
        tuple: A tuple containing two dictionaries:
               - jsd_matrices (dict): JSD matrices for each parameter.
               - wass_matrices (dict): Wasserstein distance matrices for each parameter.

    Raises:
        KeyError: If a model in model_names is missing from dataframes, or a
                  parameter is missing from one of the model DataFrames.
    """
    print("\n--- Step 2.2: Calculating Pairwise Statistical Divergence ---")

    # Report on KDE bandwidth selection
    _report_kde_bandwidths(dataframes, key_parameters, model_names)

    jsd_matrices = {}
    wass_matrices = {}
    n_models = len(model_names)

    for param in key_parameters:
        print("----------------------------------------------------")
        print("Processing parameter: " + param)
        # Zero-initialised, so the diagonal needs no explicit assignment.
        jsd_matrix = pd.DataFrame(np.zeros((n_models, n_models)), index=model_names, columns=model_names)
        wass_matrix = pd.DataFrame(np.zeros((n_models, n_models)), index=model_names, columns=model_names)

        samples = {model: dataframes[model][param].values for model in model_names}
        # KDE fits depend only on (model, param); fit lazily and reuse across pairs.
        kdes = {}

        for i, model_a in enumerate(model_names):
            for model_b in model_names[i + 1:]:
                samples_a = samples[model_a]
                samples_b = samples[model_b]

                # 1. Wasserstein Distance (computed directly from the samples)
                wass_dist = wasserstein_distance(samples_a, samples_b)
                wass_matrix.loc[model_a, model_b] = wass_dist
                wass_matrix.loc[model_b, model_a] = wass_dist

                # 2. Jensen-Shannon Divergence on a grid covering both sample sets
                min_val = min(samples_a.min(), samples_b.min())
                max_val = max(samples_a.max(), samples_b.max())
                grid = np.linspace(min_val, max_val, 1000)

                for model in (model_a, model_b):
                    if model not in kdes:
                        kdes[model] = gaussian_kde(samples[model], bw_method='silverman')

                jsd = _jsd_on_grid(kdes[model_a], kdes[model_b], grid)
                jsd_matrix.loc[model_a, model_b] = jsd
                jsd_matrix.loc[model_b, model_a] = jsd

        jsd_matrices[param] = jsd_matrix
        wass_matrices[param] = wass_matrix

        print("\nJensen-Shannon Divergence Matrix for " + param + ":")
        print(jsd_matrix)
        print("\n1-Wasserstein Distance Matrix for " + param + ":")
        print(wass_matrix)

    print("\nDivergence matrix calculations complete.")
    return jsd_matrices, wass_matrices


def plot_marginal_posteriors(dataframes, key_parameters, model_names, output_dir):
    """
    Plots and saves the 1D marginal posterior distributions for key parameters.

    This function creates a single figure with one subplot per key parameter,
    arranged in two columns. Each subplot shows a Silverman-rule Gaussian KDE
    of the samples for every model in model_names, allowing for visual
    comparison. The grid has at least two rows and grows with the number of
    parameters; unused subplots are hidden. The figure is written as a
    timestamped PNG into output_dir, which is created if it does not exist.

    Args:
        dataframes (dict): A dictionary of pandas DataFrames, keyed by model name.
        key_parameters (list): A list of column names to plot.
        model_names (list): An ordered list of model names.
        output_dir (str): The directory where the plot will be saved.
    """
    print("\n--- Generating Joint Plot of Marginal Posteriors ---")

    # Units appended to the x-axis label; an empty string means no unit suffix.
    param_units = {
        'mass_1_source': r'M$_\odot$',
        'mass_2_source': r'M$_\odot$',
        'final_mass_source': r'M$_\odot$',
        'redshift': '',
        'chi_eff': '',
        'chi_p': '',
        'final_spin': ''
    }

    n_params = len(key_parameters)
    # Two columns and as many rows as needed (ceil(n_params / 2)), never fewer
    # than two rows so `axes` is always a 2D array that can be flattened.
    n_cols = 2
    n_rows = max(2, (n_params + 1) // 2)

    fig, axes = plt.subplots(n_rows, n_cols, figsize=(6 * n_cols, 4 * n_rows))
    axes = axes.flatten()

    colors = plt.cm.viridis(np.linspace(0, 1, len(model_names)))
    model_colors = dict(zip(model_names, colors))

    for ax, param in zip(axes, key_parameters):
        for model in model_names:
            samples = dataframes[model][param].values
            kde = gaussian_kde(samples, bw_method='silverman')
            # Each curve is drawn over that model's own sample range.
            x_grid = np.linspace(samples.min(), samples.max(), 500)
            ax.plot(x_grid, kde(x_grid), color=model_colors[model], label=model)

        pretty_name = param.replace('_', ' ').title()
        unit = param_units.get(param, '')
        xlabel = pretty_name + ' (' + unit + ')' if unit else pretty_name

        ax.set_xlabel(xlabel)
        ax.set_ylabel('Probability Density')
        ax.set_title('Marginal Posterior for ' + pretty_name)
        ax.grid(True, linestyle='--', alpha=0.6)
        ax.legend()

    # Hide any unused subplots. Indexing from n_params (not the loop variable)
    # keeps this valid when key_parameters is empty.
    for unused_ax in axes[n_params:]:
        unused_ax.set_visible(False)

    plt.tight_layout()

    # Save the plot, creating the output directory first if necessary.
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    timestamp = time.strftime("%Y%m%d-%H%M%S")
    filename = 'marginal_posteriors_comparison_1_' + timestamp + '.png'
    filepath = os.path.join(output_dir, filename)
    plt.savefig(filepath, dpi=300)
    plt.close(fig)

    print("Successfully saved plot to: " + filepath)
    print("Plot Description: Comparison of 1D marginal posterior distributions for key astrophysical parameters across all five waveform models.")


def main():
    """
    Run Step 2 end to end: load the Step 1 pickle, compare posteriors, save outputs.

    The pickle is expected to hold a dict of DataFrames keyed by model name.
    If it is missing or cannot be loaded, an error is printed and the function
    returns without running any analysis. Otherwise the summary statistics are
    printed, the JSD and Wasserstein matrices are computed and pickled into
    the output directory, and the marginal posterior figure is saved there.

    Note: unpickling executes arbitrary code from the file, so the input must
    come from a trusted source (here, the Step 1 output).
    """
    print("--- Starting Step 2: Baseline Posterior Comparison and Pairwise Divergence Metrics ---")

    # --- Configuration ---
    output_dir = 'data'
    input_file = 'data/gw_data_all_models.pkl'

    model_names = [
        'NRSur7dq4',
        'IMRPhenomXO4a',
        'SEOBNRv5PHM',
        'IMRPhenomXPHM',
        'IMRPhenomTPHM',
    ]

    key_parameters = [
        'mass_1_source',
        'mass_2_source',
        'chi_eff',
        'chi_p',
        'redshift',
        'final_mass_source',
        'final_spin',
    ]

    # --- Load Data ---
    print(f"Loading preprocessed data from {input_file}")
    if not os.path.exists(input_file):
        print("Error: Input data file not found. Please run Step 1 first.")
        return

    try:
        with open(input_file, 'rb') as source:
            dataframes = pickle.load(source)
        print("Data loaded successfully.")
    except Exception as exc:
        print(f"Error loading pickle file: {exc}")
        return

    # --- Run Analysis Steps ---
    calculate_summary_statistics(dataframes, key_parameters, model_names)

    jsd_matrices, wass_matrices = calculate_divergence_matrices(
        dataframes, key_parameters, model_names
    )

    # Persist both sets of divergence matrices for later steps.
    artifacts = (
        ("\nSaved JSD matrices to: ", 'jsd_matrices.pkl', jsd_matrices),
        ("Saved Wasserstein matrices to: ", 'wass_matrices.pkl', wass_matrices),
    )
    for message, basename, payload in artifacts:
        target = os.path.join(output_dir, basename)
        with open(target, 'wb') as sink:
            pickle.dump(payload, sink)
        print(message + target)

    plot_marginal_posteriors(dataframes, key_parameters, model_names, output_dir)

    print("\n--- Step 2 execution completed successfully. ---")


# Run the Step 2 pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()