# filename: codebase/robust_inference.py
import pandas as pd
import numpy as np
import pickle
import os

def get_summary_medians(dataframes, key_parameters, model_names):
    """
    Calculates the median of every key parameter for every model.

    Args:
        dataframes (dict): A dictionary of pandas DataFrames, keyed by model name.
        key_parameters (list): A list of column names for which to compute medians.
        model_names (list): An ordered list of model names.

    Returns:
        pd.DataFrame: A DataFrame with a single 'Median' column, indexed by a
                      ('Parameter', 'Model') MultiIndex. Rows are ordered by
                      parameter first, then by model. If either input list is
                      empty, an empty DataFrame with the same index names and
                      column is returned.

    Raises:
        KeyError: If a model is missing from `dataframes` or a parameter is
                  missing from a model's DataFrame.
    """
    summary_data = [
        {
            'Parameter': param,
            'Model': model,
            'Median': dataframes[model][param].median(),
        }
        for param in key_parameters
        for model in model_names
    ]
    # Pin the column schema explicitly: without it an empty record list yields
    # a column-less frame and set_index() fails with a confusing KeyError.
    summary_df = pd.DataFrame(summary_data, columns=['Parameter', 'Model', 'Median'])
    return summary_df.set_index(['Parameter', 'Model'])


def identify_robust_parameters(jsd_matrices, summary_medians_df, key_parameters, jsd_threshold, median_range_threshold):
    """
    Identifies robustly constrained parameters based on JSD and median range criteria.

    A parameter is labelled 'Robust' only when both criteria hold:
      1. the largest entry of its JSD matrix is strictly below `jsd_threshold`;
      2. the relative range of its per-model medians, (max - min) / |mean|,
         is strictly below `median_range_threshold`.

    When the mean of the medians is exactly zero the ratio is undefined. It is
    then taken as 0 if all medians are identical and as infinity otherwise, so
    medians that disagree around zero can never pass criterion 2.

    A per-parameter report is printed to stdout.

    Args:
        jsd_matrices (dict): Dictionary of JSD matrices for each parameter. Each
            value must support `.max().max()` (e.g. a pandas DataFrame).
        summary_medians_df (pd.DataFrame): DataFrame with a 'Median' column whose
            first index level is the parameter name.
        key_parameters (list): List of parameters to evaluate.
        jsd_threshold (float): The maximum allowed JSD for a parameter to be considered robust.
        median_range_threshold (float): The maximum allowed relative range of medians.

    Returns:
        dict: A dictionary mapping each parameter to its robustness status ('Robust' or 'Model-Dependent').
    """
    print("--- Step 5.1: Identifying Robustly Constrained Parameters ---")
    print("Robustness Criteria:")
    print("1. Max JSD between any model pair < " + str(jsd_threshold))
    print("2. Relative Median Range ((max-min)/mean) < " + str(median_range_threshold * 100) + "%")
    print("-" * 60)

    robustness_status = {}
    for param in key_parameters:
        # Criterion 1: JSD
        max_jsd = jsd_matrices[param].max().max()
        jsd_passed = max_jsd < jsd_threshold

        # Criterion 2: Median Range
        medians = summary_medians_df.loc[param]['Median']
        median_spread = medians.max() - medians.min()
        mean_median = medians.mean()
        if mean_median != 0:
            relative_range = median_spread / abs(mean_median)
        elif median_spread == 0:
            # All medians are identical (and zero): perfect agreement.
            relative_range = 0
        else:
            # Medians disagree but average to zero; the relative range is
            # unbounded, so the criterion must fail rather than report 0%.
            relative_range = float('inf')
        median_range_passed = relative_range < median_range_threshold

        # Final Decision
        is_robust = jsd_passed and median_range_passed
        status = 'Robust' if is_robust else 'Model-Dependent'
        robustness_status[param] = status

        print("Parameter: " + param)
        print("  - Max JSD: " + str(round(max_jsd, 4)) + " (Pass: " + str(jsd_passed) + ")")
        print("  - Relative Median Range: " + str(round(relative_range * 100, 2)) + "% (Pass: " + str(median_range_passed) + ")")
        print("  => Status: " + status + "\n")

    return robustness_status


def get_physical_discrepancy_note(param, subspace_jsd_matrices):
    """
    Builds a short note tying a parameter's model discrepancy to a physical subspace.

    Args:
        param (str): The parameter name.
        subspace_jsd_matrices (dict): JSD matrices keyed by subspace name; the
            matrix for the matching subspace must expose `.values`.

    Returns:
        str: A note naming the subspace and its largest JSD value (rounded to
             three decimals), or a fallback message when the parameter does not
             belong to any known subspace.
    """
    # Subspace -> member parameters; inverted below into a per-parameter lookup.
    subspace_members = {
        'Mass & Distance': ('mass_1_source', 'mass_2_source', 'redshift'),
        'Effective Spin': ('chi_eff', 'chi_p'),
        'Remnant Properties': ('final_mass_source', 'final_spin'),
    }
    owner_of = {
        member: name
        for name, members in subspace_members.items()
        for member in members
    }

    owner = owner_of.get(param)
    if owner is None:
        return "Source of discrepancy not mapped to a specific subspace."

    peak_jsd = subspace_jsd_matrices[owner].values.max()
    return "Discrepancy linked to '{}' subspace (max subspace JSD: {}).".format(
        owner, str(round(peak_jsd, 3))
    )


def compile_final_results(robustness_status, dataframes, summary_medians_df, subspace_jsd_matrices, key_parameters, model_names):
    """
    Builds the final per-parameter summary table of astrophysical inferences.

    For a 'Robust' parameter the samples of all models are pooled and the
    pooled median and 5%-95% interval are reported. For any other status the
    range of per-model medians is reported together with a note on the
    physical subspace the discrepancy is linked to.

    Args:
        robustness_status (dict): Robustness status of each parameter.
        dataframes (dict): All model DataFrames, keyed by model name.
        summary_medians_df (pd.DataFrame): Per-parameter, per-model median values.
        subspace_jsd_matrices (dict): Subspace JSD matrices.
        key_parameters (list): Parameters to include in the table, in order.
        model_names (list): All model names.

    Returns:
        pd.DataFrame: One row per parameter with the columns 'Parameter',
                      'Status', 'Consensus Median', 'Consensus 90% CI' and
                      'Physical Discrepancy Source'.
    """
    print("\n--- Step 5.2 & 5.3: Deriving Consensus Constraints and Compiling Final Table ---")

    rows = []
    for param in key_parameters:
        status = robustness_status[param]

        if status != 'Robust':
            # Model-dependent: report the spread of the individual medians.
            per_model = summary_medians_df.loc[param]['Median']
            rows.append({
                'Parameter': param,
                'Status': status,
                'Consensus Median': str(round(per_model.min(), 3)) + " - " + str(round(per_model.max(), 3)) + " (Range)",
                'Consensus 90% CI': 'See individual models',
                'Physical Discrepancy Source': get_physical_discrepancy_note(param, subspace_jsd_matrices),
            })
            continue

        # Robust: pool every model's samples and summarise the pooled set.
        per_model_samples = [dataframes[model][param].values for model in model_names]
        pooled = np.concatenate(per_model_samples)
        lower, middle, upper = np.quantile(pooled, [0.05, 0.5, 0.95])
        rows.append({
            'Parameter': param,
            'Status': status,
            'Consensus Median': middle,
            'Consensus 90% CI': str(round(lower, 3)) + " - " + str(round(upper, 3)),
            'Physical Discrepancy Source': 'N/A',
        })

    return pd.DataFrame(rows)


def main():
    """
    Main function to execute Step 5 of the analysis.

    Loads the three pickled inputs from the 'data' directory, computes the
    per-model medians, classifies each key parameter as robust or
    model-dependent, compiles the final summary table, prints it and writes it
    to 'data/final_astrophysical_inference.csv'.

    All problems are reported on stdout rather than raised: a loading failure
    aborts the run, and a failure to write the CSV ends the run without
    announcing success.

    Returns:
        None
    """
    print("--- Starting Step 5: Robust Astrophysical Inference and Consensus Constraints ---")

    # --- Configuration ---
    data_dir = 'data'
    model_names = [
        'NRSur7dq4', 'IMRPhenomXO4a', 'SEOBNRv5PHM',
        'IMRPhenomXPHM', 'IMRPhenomTPHM'
    ]
    key_parameters = [
        'mass_1_source', 'mass_2_source', 'chi_eff', 'chi_p',
        'redshift', 'final_mass_source', 'final_spin'
    ]
    # Robustness criteria thresholds
    JSD_THRESHOLD = 0.05
    MEDIAN_RANGE_THRESHOLD = 0.10  # 10%

    # --- Load Data from Previous Steps ---
    # NOTE: pickle.load can execute arbitrary code; these files must only ever
    # be the trusted outputs of the earlier pipeline steps.
    input_files = (
        'gw_data_all_models.pkl',
        'jsd_matrices.pkl',
        'subspace_jsd_matrices.pkl',
    )
    print("\nLoading data from previous steps...")
    loaded = []
    try:
        for file_name in input_files:
            with open(os.path.join(data_dir, file_name), 'rb') as f:
                loaded.append(pickle.load(f))
        print("Data loaded successfully.")
    except FileNotFoundError as e:
        print("Error: Could not find required data file: " + str(e))
        print("Please ensure Steps 1, 2, and 4 have been run successfully.")
        return
    except Exception as e:
        # Top-level boundary: report any other loading problem and stop.
        print("An error occurred while loading data: " + str(e))
        return
    dataframes, jsd_matrices, subspace_jsd_matrices = loaded

    # --- Run Analysis ---
    summary_medians_df = get_summary_medians(dataframes, key_parameters, model_names)

    robustness_status = identify_robust_parameters(jsd_matrices, summary_medians_df, key_parameters, JSD_THRESHOLD, MEDIAN_RANGE_THRESHOLD)

    final_table = compile_final_results(robustness_status, dataframes, summary_medians_df, subspace_jsd_matrices, key_parameters, model_names)

    # --- Display and Save Final Results ---
    print("\n--- Step 5.4: Final Results Compilation ---")
    pd.set_option('display.max_rows', None)
    pd.set_option('display.max_columns', None)
    pd.set_option('display.width', 200)
    pd.set_option('display.max_colwidth', 100)
    print("\nFinal Astrophysical Inference Summary Table:")
    print(final_table)

    # Save the table to a CSV file
    output_path = os.path.join(data_dir, 'final_astrophysical_inference.csv')
    try:
        final_table.to_csv(output_path, index=False)
    except Exception as e:
        print("\nError saving final table to CSV: " + str(e))
        # Do not claim success when the results could not be persisted.
        print("\n--- Step 5 execution finished with errors: final table was not saved. ---")
        return
    print("\nSuccessfully saved final inference table to: " + output_path)

    print("\n--- Step 5 execution completed successfully. ---")


# Run the Step 5 pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()