# filename: codebase/data_preprocessing.py
import pandas as pd
import os
import pickle


def _verify_dataframe(df, model, expected_columns):
    """Print column/NaN/log-likelihood checks for one model; return True if all pass."""
    passed = True
    print(f"\nData Verification for {model}:")

    # Compare parameter columns only; 'model' was added by the loader, not the CSV.
    actual_columns = df.columns.drop('model').tolist()
    if sorted(actual_columns) == sorted(expected_columns):
        print(f"- Column check: Passed. Found {len(actual_columns)} expected parameter columns.")
    else:
        print("- Column check: Failed.")
        print(f"  Expected: {sorted(expected_columns)}")
        print(f"  Found:    {sorted(actual_columns)}")
        passed = False

    nan_count = df.isnull().sum().sum()
    if nan_count == 0:
        print("- NaN check: Passed. No missing values found.")
    else:
        print(f"- NaN check: Failed. Found {nan_count} missing values.")
        passed = False

    # A constant log_likelihood column usually indicates a broken export.
    if 'log_likelihood' not in df.columns:
        print("- Log-likelihood check: Failed. Column not found.")
        passed = False
    elif df['log_likelihood'].nunique() > 1:
        print("- Log-likelihood check: Passed. Values are not all identical.")
    else:
        print("- Log-likelihood check: Failed. All log_likelihood values are identical.")
        passed = False

    return passed


def _save_dataframes(dataframes, output_dir):
    """Pickle ``dataframes`` to ``<output_dir>/gw_data_all_models.pkl``; report (not raise) failures."""
    output_path = os.path.join(output_dir, 'gw_data_all_models.pkl')
    try:
        if not os.path.isdir(output_dir):
            # exist_ok avoids a race if another process creates it concurrently;
            # a regular file at this path still raises and is reported below.
            os.makedirs(output_dir, exist_ok=True)
            print(f"Created directory: {output_dir}")
        with open(output_path, 'wb') as f:
            pickle.dump(dataframes, f)
        print(f"Successfully saved the consolidated data dictionary to: {output_path}")
    except (OSError, pickle.PicklingError) as e:
        print(f"Error saving data to pickle file: {e}")


def load_and_preprocess_data(file_paths, model_names, output_dir='data'):
    """
    Loads, preprocesses, and verifies gravitational wave data from multiple CSV files.

    For each model name, the corresponding CSV is loaded into a pandas DataFrame
    and a 'model' column holding the model name is added. The function then
    verifies that the parameter columns match the expected set, checks for
    missing values, and checks that 'log_likelihood' is present and not
    constant. It also prints ``DataFrame.describe()`` for a first look at
    physical plausibility. All results are printed. The resulting dictionary
    is pickled to ``<output_dir>/gw_data_all_models.pkl``.

    Args:
        file_paths (dict): Mapping of model name -> CSV file path (str or PathLike).
        model_names (list): Model names to process, in order.
        output_dir (str or PathLike): Directory for the pickle file. Defaults
            to 'data' (relative to the current working directory). It is
            created if missing.

    Returns:
        dict: Mapping of model name -> DataFrame. A model is skipped, with an
        error printed, if it has no entry in ``file_paths``, its file does not
        exist, or its CSV cannot be read. Models that load but fail a
        verification check are still included; the failure is reported in
        the printed output and the final summary.
    """
    print("--- Step 1: Data Aggregation, Pre-processing, and Sanity Checks ---")

    # Expected columns based on the problem description
    expected_columns = [
        'mass_1_source', 'mass_2_source', 'a_1', 'a_2', 'final_mass_source',
        'final_spin', 'redshift', 'cos_tilt_1', 'cos_tilt_2', 'chi_eff',
        'chi_p', 'cos_theta_jn', 'phi_jl', 'log_likelihood'
    ]

    dataframes = {}
    all_checks_passed = True

    for model in model_names:
        path = file_paths.get(model)
        print("\n" + "----------------------------------------------------")
        print(f"Processing model: {model}")

        # Without this guard, a model missing from file_paths crashed the run
        # with a TypeError when the None path was concatenated into a string.
        if path is None:
            print(f"Error: No file path provided for model {model}")
            all_checks_passed = False
            continue

        print(f"File path: {path}")

        if not os.path.exists(path):
            print(f"Error: File not found at {path}")
            all_checks_passed = False
            continue

        # 1. Load Data. pandas parse/decode errors subclass ValueError;
        # I/O problems (permissions, directories) are OSError.
        try:
            df = pd.read_csv(path)
        except (OSError, ValueError) as e:
            print(f"Error loading CSV file: {e}")
            all_checks_passed = False
            continue
        print(f"Successfully loaded {len(df)} samples.")

        # 2. Add 'model' column (overwrites any existing column of that name)
        df['model'] = model

        # 3. Data Cleaning and Verification
        if not _verify_dataframe(df, model, expected_columns):
            all_checks_passed = False

        # 4. Physical Plausibility Check (Summary Statistics).
        # option_context scopes the display settings to this print instead of
        # permanently altering pandas' global options for the caller.
        print(f"\nPhysical Plausibility Check (Summary Statistics) for {model}:")
        with pd.option_context('display.max_columns', None, 'display.width', 1000):
            print(df.describe())

        dataframes[model] = df

    print("\n" + "----------------------------------------------------")
    if all_checks_passed:
        print("\n--- Summary: All data integrity and sanity checks passed successfully. ---")
    else:
        print("\n--- Summary: One or more data integrity or sanity checks failed. Please review the output above. ---")

    # Save the dictionary of dataframes for the next step
    _save_dataframes(dataframes, output_dir)

    print("\nStep 1 execution completed.")
    return dataframes


if __name__ == '__main__':
    import sys

    # Directory holding the GW231123_<model>.csv files. The hard-coded default
    # is a site-specific location; pass a different directory as the first
    # command-line argument to run elsewhere.
    base_path = '/mnt/ceph/users/fanonymous/AstroPilot/GW/Iteration1/data/'
    if len(sys.argv) > 1:
        base_path = sys.argv[1]

    model_names = [
        'NRSur7dq4',
        'IMRPhenomXO4a',
        'SEOBNRv5PHM',
        'IMRPhenomXPHM',
        'IMRPhenomTPHM'
    ]
    # os.path.join is correct whether or not base_path ends with a separator.
    file_paths = {
        name: os.path.join(base_path, 'GW231123_' + name + '.csv')
        for name in model_names
    }

    # Run the data loading and preprocessing
    loaded_data = load_and_preprocess_data(file_paths, model_names)
