# filename: codebase/regression_modeling.py
import torch
import numpy as np
import sklearn
from sklearn.model_selection import GridSearchCV, train_test_split
from sklearn.linear_model import LinearRegression
from sklearn.ensemble import RandomForestRegressor
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import mean_squared_error, r2_score, make_scorer
from sklearn.multioutput import MultiOutputRegressor
import xgboost as xgb
import matplotlib.pyplot as plt
import scipy.stats
import os
import time
import datetime
import torch.serialization  # Added for safe_globals if needed, or to be aware of its context

# Required for unpickling data from previous steps if they contain these types
from collections import defaultdict
from torch_geometric.data import Data


# --- Configuration ---
# Input artifacts written by earlier pipeline steps. Every loader below reads
# them with torch.load(..., weights_only=False), which unpickles arbitrary
# Python objects -- only point these paths at files you trust.
QITT_PROCESSED_DATA_PATH = 'data/qitt_processed_data.pt'  # QITT features + labels (load_qitt_data)
PROCESSED_MERGER_TREES_PATH = 'data/processed_merger_trees.pt'  # For baselines B1, B2
FINAL_PROCESSED_DATA_PATH = 'data/final_processed_data.pt'   # For baseline B4
OUTPUT_DIR = 'data'
PLOTS_DIR = OUTPUT_DIR  # Save plots in data/ (main() creates the directory if missing)

# Seeds the global NumPy and torch RNGs at import time. The RandomForest and
# XGBoost models built in main() receive RANDOM_SEED separately via random_state.
RANDOM_SEED = 42
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)

# Matplotlib settings: render text without an external LaTeX installation and
# use 300 dpi both on screen and when saving figures.
plt.rcParams['text.usetex'] = False
plt.rcParams['figure.dpi'] = 300
plt.rcParams['savefig.dpi'] = 300

# Constants from previous steps (verify these match your actual outputs)
# MAX_N_SUB: substructures kept per tree for Baseline B2 (longer lists are
# truncated, shorter ones zero-padded); also passed to the B4 loader.
MAX_N_SUB = 60
NUM_PHYSICAL_FEATURES = 10  # Number of features in sub_data.physical_features
D_FEAT_COMBINED = 74        # NUM_PHYSICAL_FEATURES + GNN_EMBEDDING_DIM (implies a 64-dim embedding); passed to the B4 loader


# --- Helper Functions ---

def get_timestamp_str():
    """Return the current local time as a ``YYYYmmdd_HHMMSS`` string.

    The plot filenames written by this module embed this value.
    """
    now = datetime.datetime.now()
    return f"{now:%Y%m%d_%H%M%S}"


def custom_scorer(y_true, y_pred):
    """Score two-target predictions for GridSearchCV, where higher is better.

    The score is the negated sum of the per-column RMSEs for column 0
    (Omega_m) and column 1 (sigma_8). Negating the error lets it be wrapped
    with ``make_scorer(..., greater_is_better=True)``.
    """
    total_rmse = 0.0
    for col in (0, 1):
        total_rmse += np.sqrt(mean_squared_error(y_true[:, col], y_pred[:, col]))
    return -total_rmse


# --- Data Loading and Baseline Feature Extraction ---

def load_qitt_data(path):
    """Load QITT feature matrices and labels, merging the train and validation splits.

    The file at ``path`` must hold a dict with ``{train,val,test}_qitt_features``
    and ``{train,val,test}_labels`` entries. These must be arrays that
    ``np.concatenate`` accepts and that expose ``.shape``. Training and
    validation rows are stacked along axis 0 so that hyperparameter tuning
    works on the larger pool.

    Returns:
        ``(X_train_full, y_train_full, X_test, y_test)``.

    Note:
        ``weights_only=False`` lets ``torch.load`` unpickle arbitrary objects
        (this file contains NumPy arrays). Only load files you trust.
    """
    print("Loading QITT data from: " + str(path))
    payload = torch.load(path, map_location='cpu', weights_only=False)

    X_train, X_val, X_test = (payload[split + '_qitt_features'] for split in ('train', 'val', 'test'))
    y_train, y_val, y_test = (payload[split + '_labels'] for split in ('train', 'val', 'test'))

    X_train_full = np.concatenate([X_train, X_val])
    y_train_full = np.concatenate([y_train, y_val])

    print("QITT data loaded. Shapes:")
    for label, array in (
        ("X_train_full (train+val)", X_train_full),
        ("y_train_full (train+val)", y_train_full),
        ("X_test", X_test),
        ("y_test", y_test),
    ):
        print("  " + label + ": " + str(array.shape))
    return X_train_full, y_train_full, X_test, y_test


def _infer_x_norm_width(data_list, default=4):
    """Return the column count of the first 2-D ``x_norm`` in *data_list*, else *default*."""
    for tree_data in data_list:
        x_norm = getattr(tree_data, 'x_norm', None)
        if x_norm is not None and len(x_norm.shape) == 2:
            return int(x_norm.shape[1])
    return default


def extract_aggregate_features_for_set(data_list, num_node_features=None):
    """Extract one fixed-length vector of graph-level aggregate features per tree.

    Each row is ``[num_nodes, num_edges, total_mass, mean_0, std_0, mean_1, std_1, ...]``.
    ``total_mass = sum(10 ** x[:, 0])``, and ``mean_i`` / ``std_i`` are the mean
    and (unbiased) standard deviation of column ``i`` of ``x_norm``.

    Args:
        data_list: iterable of tree objects exposing ``x``, ``x_norm``,
            ``num_nodes`` and optionally ``edge_index``.
        num_node_features: number of ``x_norm`` columns. If None, it is taken
            from the first tree with a 2-D ``x_norm``, falling back to 4 (the
            width previously hard-coded). Trees with a missing or empty
            ``x_norm`` are zero-padded to this width, so every row has the
            same length.

    Returns:
        float64 array of shape ``(len(data_list), 3 + 2 * num_node_features)``.
        The array is 2-D even when ``data_list`` is empty, so splits can always
        be concatenated.

    Raises:
        ValueError: if a non-empty ``x_norm`` has a different number of columns.

    Note:
        A single-node tree yields NaN standard deviations (torch's unbiased
        ``std`` with one sample). Callers must handle NaNs.
    """
    trees = list(data_list)
    if num_node_features is None:
        num_node_features = _infer_x_norm_width(trees)
    num_agg_feats = 3 + 2 * num_node_features

    rows = []
    for tree_data in trees:
        x = getattr(tree_data, 'x', None)
        x_norm = getattr(tree_data, 'x_norm', None)
        if x is None or x_norm is None:
            print("Warning: Tree found with missing 'x' or 'x_norm'. Skipping for aggregate features.")
            # Still emit a (zero) row so features stay aligned with the labels.
            rows.append(np.zeros(num_agg_feats))
            continue

        edge_index = getattr(tree_data, 'edge_index', None)
        num_edges = edge_index.shape[1] if edge_index is not None else 0
        # x[:, 0] is exponentiated base 10 before summing -- presumably log10(mass); confirm upstream.
        total_mass = torch.sum(10 ** x[:, 0]).item() if x.shape[0] > 0 else 0.0

        if x_norm.shape[0] > 0:
            if x_norm.shape[1] != num_node_features:
                raise ValueError(
                    "x_norm has " + str(x_norm.shape[1]) + " columns, expected "
                    + str(num_node_features) + "; aggregate feature rows would be ragged."
                )
            # Interleave per-column statistics as mean_0, std_0, mean_1, std_1, ...
            stats = torch.stack(
                (torch.mean(x_norm, dim=0), torch.std(x_norm, dim=0)), dim=1
            ).reshape(-1).tolist()
        else:
            stats = [0.0] * (2 * num_node_features)

        rows.append(np.array([tree_data.num_nodes, num_edges, total_mass] + stats, dtype=float))

    # The explicit reshape keeps an empty input 2-D: (0, num_agg_feats).
    return np.array(rows, dtype=float).reshape(len(rows), num_agg_feats)


def load_baseline1_aggregate_features(path):
    """Build graph-level aggregate features and labels for Baseline B1.

    Reads the ``train_data`` / ``val_data`` / ``test_data`` tree lists from
    ``path``. It turns each split into aggregate features via
    ``extract_aggregate_features_for_set`` and stacks each tree's ``y``
    (squeezed on dim 1). Train and validation are concatenated along axis 0.

    Returns:
        ``(X_train_full_agg, y_train_full, X_test_agg, y_test)``.

    Note:
        ``weights_only=False`` unpickles arbitrary objects; only load trusted files.
    """
    print("Loading data for Baseline B1 (Aggregate Features) from: " + str(path))
    step1 = torch.load(path, map_location='cpu', weights_only=False)
    split_trees = [step1[key] for key in ('train_data', 'val_data', 'test_data')]

    X_tr, X_va, X_te = [extract_aggregate_features_for_set(trees) for trees in split_trees]
    X_train_full_agg = np.concatenate((X_tr, X_va), axis=0)

    y_tr, y_va, y_te = [
        torch.stack([tree.y for tree in trees]).squeeze(1).numpy() for trees in split_trees
    ]
    y_train_full = np.concatenate((y_tr, y_va), axis=0)

    print("Baseline B1 features extracted. Shapes:")
    print("  X_train_full_agg: " + str(X_train_full_agg.shape))
    print("  X_test_agg: " + str(X_te.shape))
    return X_train_full_agg, y_train_full, X_te, y_te


def extract_raw_substructure_features_for_set(data_list, max_n_sub, num_physical_features):
    """Flatten each tree's substructure physical features into one fixed-length row.

    Each tree contributes the ``physical_features`` of at most ``max_n_sub``
    substructures, in their stored order; extra ones are dropped. Trees with
    fewer substructures, or with a missing or None ``substructures`` attribute,
    are zero-padded.

    Args:
        data_list: iterable of tree objects with an optional ``substructures``
            sequence whose items expose a ``physical_features`` tensor.
        max_n_sub: number of substructure slots per tree.
        num_physical_features: expected number of values per substructure.

    Returns:
        float64 array of shape ``(len(data_list), max_n_sub * num_physical_features)``.
        The array is 2-D even for an empty ``data_list``, so splits can always
        be concatenated.

    Raises:
        ValueError: if a substructure's feature vector does not contain exactly
            ``num_physical_features`` values.
    """
    trees = list(data_list)
    # Zero-initialised, so unused slots are already the padding vector.
    features = np.zeros((len(trees), max_n_sub, num_physical_features))
    for row, tree_data in enumerate(trees):
        substructures = getattr(tree_data, 'substructures', None)
        if substructures is None:
            continue
        # zip with range() stops after max_n_sub items, truncating longer lists.
        for slot, sub_data in zip(range(max_n_sub), substructures):
            vec = sub_data.physical_features.detach().cpu().numpy().reshape(-1)
            if vec.size != num_physical_features:
                raise ValueError(
                    "Substructure physical_features has " + str(vec.size)
                    + " values, expected " + str(num_physical_features) + "."
                )
            features[row, slot] = vec
    return features.reshape(len(trees), max_n_sub * num_physical_features)


def load_baseline2_raw_substructure_features(path, max_n_sub, num_physical_features):
    """Build flattened raw substructure physical features and labels for Baseline B2.

    Reads the ``train_data`` / ``val_data`` / ``test_data`` tree lists from
    ``path``. Each split goes through ``extract_raw_substructure_features_for_set``
    with ``max_n_sub`` / ``num_physical_features``, and each tree's ``y`` is
    stacked (squeezed on dim 1). Train and validation are concatenated along
    axis 0.

    Returns:
        ``(X_train_full_raw, y_train_full, X_test_raw, y_test)``.

    Note:
        ``weights_only=False`` unpickles arbitrary objects; only load trusted files.
    """
    print("Loading data for Baseline B2 (Raw Substructure Physical Feats) from: " + str(path))
    step1 = torch.load(path, map_location='cpu', weights_only=False)
    split_trees = [step1[key] for key in ('train_data', 'val_data', 'test_data')]

    X_tr, X_va, X_te = [
        extract_raw_substructure_features_for_set(trees, max_n_sub, num_physical_features)
        for trees in split_trees
    ]
    X_train_full_raw = np.concatenate((X_tr, X_va), axis=0)

    y_tr, y_va, y_te = [
        torch.stack([tree.y for tree in trees]).squeeze(1).numpy() for trees in split_trees
    ]
    y_train_full = np.concatenate((y_tr, y_va), axis=0)

    print("Baseline B2 features extracted. Shapes:")
    print("  X_train_full_raw: " + str(X_train_full_raw.shape))
    print("  X_test_raw: " + str(X_te.shape))
    return X_train_full_raw, y_train_full, X_te, y_te


def load_baseline4_flattened_combined_features(path, max_n_sub, d_feat_combined):
    """Load per-tree combined feature tensors and flatten them (Baseline B4).

    Args:
        path: ``.pt`` file holding ``{train,val,test}_tensors`` (lists of
            per-tree tensors) and ``{train,val,test}_labels`` (lists of label
            tensors; they are stacked and squeezed on dim 1).
        max_n_sub, d_feat_combined: expected per-tree layout. Each tree should
            flatten to ``max_n_sub * d_feat_combined`` values. A mismatch prints
            a warning, but the features are still used as loaded.

    Returns:
        ``(X_train_full_flat, y_train_full, X_test_flat, y_test)``. The ``*_full``
        arrays are train and validation concatenated along axis 0.

    Note:
        ``weights_only=False`` unpickles arbitrary objects; only load trusted files.
    """
    print("Loading data for Baseline B4 (Flattened Combined Feats) from: " + str(path))
    data_step2 = torch.load(path, map_location='cpu', weights_only=False)
    expected_width = max_n_sub * d_feat_combined

    def _flatten(split):
        """Stack the split's tensors as rows of flattened values, checking the width."""
        tensors = data_step2[split + '_tensors']
        if len(tensors) == 0:
            # np.array([]) would be 1-D and could not be concatenated with the other splits.
            return np.empty((0, expected_width))
        flat = np.array([t.detach().cpu().numpy().reshape(-1) for t in tensors])
        if flat.ndim == 2 and flat.shape[1] != expected_width:
            print("Warning: B4 " + split + " tensors flatten to " + str(flat.shape[1])
                  + " values per tree, expected max_n_sub * d_feat_combined = "
                  + str(expected_width) + ". Check the constants against the saved data.")
        return flat

    def _labels(split):
        """Stack the split's label tensors into a NumPy array, squeezing dim 1."""
        return torch.stack(data_step2[split + '_labels']).squeeze(1).detach().numpy()

    X_train_flat = _flatten('train')
    X_val_flat = _flatten('val')
    X_test_flat = _flatten('test')
    X_train_full_flat = np.concatenate((X_train_flat, X_val_flat), axis=0)

    y_train = _labels('train')
    y_val = _labels('val')
    y_test = _labels('test')
    y_train_full = np.concatenate((y_train, y_val), axis=0)

    print("Baseline B4 features extracted. Shapes:")
    print("  X_train_full_flat: " + str(X_train_full_flat.shape))
    print("  X_test_flat: " + str(X_test_flat.shape))
    return X_train_full_flat, y_train_full, X_test_flat, y_test


# --- Model Training and Evaluation ---

def train_evaluate_model(model_name, model_base, param_grid, X_train, y_train, X_test, y_test):
    """Fit a two-target regressor, optionally grid-searching it, and score it on the test set.

    ``model_base`` is wrapped in a ``MultiOutputRegressor`` (one estimator per
    target column). If ``param_grid`` is non-empty, its keys are prefixed with
    ``estimator__``. The grid is then searched with 5-fold ``GridSearchCV``
    scored by ``custom_scorer``. Otherwise the wrapper is fitted with default
    parameters.

    Args:
        model_name: label used in the progress output.
        model_base: unfitted scikit-learn-compatible regressor.
        param_grid: mapping of base-estimator parameter names to candidate
            values; empty to skip tuning.
        X_train, y_train: training data (``y_train`` has one column per target).
        X_test, y_test: held-out data; columns 0 and 1 of ``y_test`` are scored.

    Returns:
        dict with these keys:
        - ``RMSE_Omega_m``, ``RMSE_sigma_8``, ``R2_Omega_m``, ``R2_sigma_8``
          (computed on ``y_test`` columns 0 and 1);
        - ``model_object``: the fitted estimator;
        - ``y_pred_test``: the test-set predictions.
    """
    print("\nTraining and evaluating: " + str(model_name))
    wrapper = MultiOutputRegressor(model_base, n_jobs=-1)

    if not param_grid:
        fitted_model = wrapper
        print("  Skipping GridSearchCV (no param_grid). Fitting with default parameters.")
        fitted_model.fit(X_train, y_train)
    else:
        # Parameters of the wrapped estimator are addressed as estimator__<name>.
        search_space = {}
        for param, values in param_grid.items():
            search_space['estimator__' + param] = values
        search = GridSearchCV(
            wrapper,
            search_space,
            cv=5,
            scoring=make_scorer(custom_scorer, greater_is_better=True),
            verbose=0,
            n_jobs=-1,
        )
        print("  Starting GridSearchCV for " + str(model_name) + "...")
        search.fit(X_train, y_train)
        fitted_model = search.best_estimator_
        print("  Best hyperparameters: " + str(search.best_params_))

    y_pred_test = fitted_model.predict(X_test)

    target_names = ('Omega_m', 'sigma_8')
    rmse_values = [np.sqrt(mean_squared_error(y_test[:, col], y_pred_test[:, col])) for col in range(2)]
    r2_values = [r2_score(y_test[:, col], y_pred_test[:, col]) for col in range(2)]

    results = {}
    for target, value in zip(target_names, rmse_values):
        results['RMSE_' + target] = value
    for target, value in zip(target_names, r2_values):
        results['R2_' + target] = value
    results['model_object'] = fitted_model
    results['y_pred_test'] = y_pred_test

    print("  Test Set Performance:")
    for prefix, key in (
        ("    RMSE Omega_m: ", 'RMSE_Omega_m'),
        ("    RMSE sigma_8: ", 'RMSE_sigma_8'),
        ("    R2 Omega_m:   ", 'R2_Omega_m'),
        ("    R2 sigma_8:   ", 'R2_sigma_8'),
    ):
        print(prefix + str(round(results[key], 4)))

    return results


# --- Plotting Functions ---
plot_counter = 1


def plot_model_performance_comparison(results_dict, timestamp):
    """Save grouped bar charts comparing RMSE and R2 across all trained models.

    Two figures are produced, one for RMSE and one for R2. In each, every model
    gets a pair of bars (Omega_m, sigma_8). Files are written to PLOTS_DIR, and
    the module-level ``plot_counter`` is incremented once per figure.

    Args:
        results_dict: mapping of model name to a result dict holding the
            ``RMSE_*`` / ``R2_*`` keys.
        timestamp: string embedded in the output filenames.
    """
    global plot_counter
    labels = list(results_dict)
    width = 0.35
    targets = ('Omega_m', 'sigma_8')
    panels = (
        ('RMSE', ('RMSE_Omega_m', 'RMSE_sigma_8')),
        ('R2 Score', ('R2_Omega_m', 'R2_sigma_8')),
    )

    for metric_type, (first_key, second_key) in panels:
        # The taller canvas leaves room for the rotated tick labels.
        fig, ax = plt.subplots(figsize=(12, 8))
        x_pos = np.arange(len(labels))

        first_values = [results_dict[m][first_key] for m in labels]
        second_values = [results_dict[m][second_key] for m in labels]

        ax.bar(x_pos - width / 2, first_values, width, label=targets[0])
        ax.bar(x_pos + width / 2, second_values, width, label=targets[1])

        ax.set_xlabel('Model Configuration')
        ax.set_ylabel(metric_type)
        ax.set_title(metric_type + ' Comparison Across Models')
        ax.set_xticks(x_pos)
        ax.set_xticklabels(labels, rotation=45, ha="right")
        ax.legend()
        ax.grid(True, linestyle='--', alpha=0.7)

        fig.tight_layout()
        metric_slug = metric_type.replace(' ', '_').lower()
        plot_filename = os.path.join(PLOTS_DIR, f"model_comparison_{metric_slug}_{plot_counter}_{timestamp}.png")
        plt.savefig(plot_filename)
        print(f"Saved plot: {plot_filename}")
        print(f"  Description: Bar chart comparing {metric_type} for Omega_m and sigma_8 across all trained models.")
        plt.close(fig)
        plot_counter += 1


def plot_feature_importances(model, feature_names, model_name_str, timestamp, top_n=20):
    """Plot the largest feature importances of a fitted tree-based model.

    Two kinds of model are supported:
    - a fitted ``MultiOutputRegressor`` whose per-target estimators expose
      ``feature_importances_`` (their importances are averaged across targets);
    - any model that exposes ``feature_importances_`` directly.
    Any other model only prints a warning, and no plot is produced.

    Args:
        model: fitted estimator.
        feature_names: labels indexed by feature column, or None to use
            ``"Feature <i>"``.
        model_name_str: name used in the title and filename.
        timestamp: string embedded in the filename.
        top_n: maximum number of features to show. It is clamped to the number
            of features available.

    Side effects:
        Saves a PNG under PLOTS_DIR and increments the global ``plot_counter``.
    """
    global plot_counter

    if isinstance(model, MultiOutputRegressor) and hasattr(model.estimators_[0], 'feature_importances_'):
        # One fitted estimator per target (e.g. RandomForest or XGBoost): average them.
        importances = np.mean([est.feature_importances_ for est in model.estimators_], axis=0)
    elif hasattr(model, 'feature_importances_'):
        importances = np.asarray(model.feature_importances_)
    else:
        # MultiOutputRegressor has ``estimators_``/``estimator`` but no ``estimator_``,
        # so a wrapped model without importances (e.g. LinearRegression) ends up here.
        print("Warning: Feature importances not available or not extracted for " + str(model_name_str))
        return

    # Never request more bars than features; otherwise barh/yticks get mismatched lengths.
    n_show = min(top_n, len(importances))
    if n_show <= 0:
        print("Warning: Feature importances not available or not extracted for " + str(model_name_str))
        return
    indices = np.argsort(importances)[::-1][:n_show]

    plt.figure(figsize=(10, max(6, n_show * 0.3)))
    plt.title("Top " + str(n_show) + " Feature Importances for " + str(model_name_str))

    if feature_names is None:
        feat_labels = ["Feature " + str(i) for i in indices]
    else:
        feat_labels = [feature_names[i] for i in indices]

    # Reverse both lists so the most important feature appears at the top of the chart.
    plt.barh(range(n_show), importances[indices][::-1], align="center")
    plt.yticks(range(n_show), np.array(feat_labels)[::-1])
    plt.xlabel("Importance")
    plt.tight_layout()

    plot_filename = os.path.join(PLOTS_DIR, "feature_importances_" + model_name_str.replace(' ', '_').replace('(', '').replace(')', '') + "_" + str(plot_counter) + "_" + str(timestamp) + ".png")
    plt.savefig(plot_filename)
    print("Saved plot: " + str(plot_filename))
    print("  Description: Top " + str(n_show) + " feature importances for the " + str(model_name_str) + " model.")
    plt.close()
    plot_counter += 1


def plot_predicted_vs_true(y_true, y_pred, model_name_str, param_name, timestamp):
    """Save a scatter plot of predicted against true values for one target parameter.

    A dashed red y = x line spanning the combined range of both arrays marks
    perfect predictions. The PNG is written to PLOTS_DIR, and the module-level
    ``plot_counter`` is incremented.
    """
    global plot_counter
    plt.figure(figsize=(8, 8))
    plt.scatter(y_true, y_pred, alpha=0.5, label='Predictions')

    lo = min(np.min(y_true), np.min(y_pred))
    hi = max(np.max(y_true), np.max(y_pred))
    diagonal = [lo, hi]
    plt.plot(diagonal, diagonal, 'r--', lw=2, label='Ideal (y=x)')

    plt.xlabel(f"True {param_name}")
    plt.ylabel(f"Predicted {param_name}")
    plt.title(f"Predicted vs. True for {param_name} ({model_name_str})")
    plt.legend()
    plt.grid(True, linestyle='--', alpha=0.7)
    plt.axis('equal')
    plt.tight_layout()

    model_slug = model_name_str.replace(' ', '_').replace('(', '').replace(')', '')
    param_slug = param_name.replace('_', '')
    plot_filename = os.path.join(PLOTS_DIR, f"pred_vs_true_{model_slug}_{param_slug}_{plot_counter}_{timestamp}.png")
    plt.savefig(plot_filename)
    print(f"Saved plot: {plot_filename}")
    print(f"  Description: Scatter plot of predicted vs. true values for {param_name} using the {model_name_str} model.")
    plt.close()
    plot_counter += 1


# --- Main Execution ---

def _check_labels_aligned(baseline_name, y_train_ref, y_test_ref, y_train_other, y_test_other):
    """Raise ValueError unless a baseline's labels equal the reference (QITT) labels.

    main() fits and scores every feature set against the QITT labels. A baseline
    whose samples are stored in a different order or count would otherwise be
    trained and evaluated against the wrong targets without any error.
    """
    for split_name, y_ref, y_other in (("train+val", y_train_ref, y_train_other),
                                       ("test", y_test_ref, y_test_other)):
        y_ref = np.asarray(y_ref)
        y_other = np.asarray(y_other)
        if y_ref.shape != y_other.shape or not np.allclose(y_ref, y_other):
            raise ValueError(
                baseline_name + " " + split_name + " labels do not match the QITT labels (shapes "
                + str(y_ref.shape) + " vs " + str(y_other.shape)
                + "); the feature sets are not aligned sample-by-sample."
            )


def _print_summary_table(all_model_results):
    """Print a fixed-width table of test-set RMSE and R2 for every model."""
    print("\n--- Overall Model Performance Summary ---")
    header = "Model Configuration".ljust(30) + " | " + "RMSE Omega_m".ljust(15) + " | " + "RMSE sigma_8".ljust(15) + " | " + "R2 Omega_m".ljust(15) + " | " + "R2 sigma_8".ljust(15)
    print(header)
    print("-" * 95)
    for name, res in all_model_results.items():
        line = name.ljust(30) + " | " + str(round(res['RMSE_Omega_m'], 4)).ljust(15) + " | " + str(round(res['RMSE_sigma_8'], 4)).ljust(15) + " | " + str(round(res['R2_Omega_m'], 4)).ljust(15) + " | " + str(round(res['R2_sigma_8'], 4)).ljust(15)
        print(line)


def _select_qitt_reference(all_model_results):
    """Return the key of the QITT reference model; prefer QITT_XGBoost, else the first QITT_* result."""
    reference_key = "QITT_XGBoost"
    if reference_key not in all_model_results:
        qitt_keys = [k for k in all_model_results if k.startswith("QITT_")]
        if not qitt_keys:
            print("QITT model results not found for t-tests.")
        else:
            reference_key = qitt_keys[0]
            print("Warning: QITT_XGBoost not found, using " + str(reference_key) + " as QITT reference for t-tests.")
    return reference_key


def _print_paired_ttests(all_model_results, reference_key, y_test):
    """Print paired t-tests on per-sample squared errors: reference vs. each B*_XGBoost model."""
    ref_sq_err = (all_model_results[reference_key]['y_pred_test'] - y_test) ** 2
    for model_name_full, results_data in all_model_results.items():
        if not (model_name_full.startswith("B") and model_name_full.endswith("XGBoost")):
            continue
        base_sq_err = (results_data['y_pred_test'] - y_test) ** 2
        t_stat_om, p_val_om = scipy.stats.ttest_rel(ref_sq_err[:, 0], base_sq_err[:, 0])
        t_stat_s8, p_val_s8 = scipy.stats.ttest_rel(ref_sq_err[:, 1], base_sq_err[:, 1])
        print("  " + reference_key + " vs " + model_name_full + ":")
        print("    Omega_m: p-value = " + format(p_val_om, '.4e') + " (t-stat = " + str(round(t_stat_om, 3)) + ")")
        print("    sigma_8: p-value = " + format(p_val_s8, '.4e') + " (t-stat = " + str(round(t_stat_s8, 3)) + ")")


def main():
    """Run Step 4: fit regressors on QITT and baseline features, compare them, and save plots.

    The steps are:
    1. Load the QITT features and the B1/B2/B4 baseline features, and check that
       each baseline's labels match the QITT labels.
    2. Standardize each feature set, with the scaler fitted on train+val only.
    3. Train LinearRegression, RandomForest and XGBoost on every set.
    4. Print a summary table and paired t-tests against the QITT reference model.
    5. Write comparison, feature-importance and predicted-vs-true plots to PLOTS_DIR.

    Raises:
        ValueError: if a baseline's labels are not aligned with the QITT labels.
    """
    os.makedirs(PLOTS_DIR, exist_ok=True)

    current_timestamp = get_timestamp_str()
    all_model_results = {}

    X_train_qitt, y_train_qitt, X_test_qitt, y_test_qitt = load_qitt_data(QITT_PROCESSED_DATA_PATH)
    # Every feature set is trained and scored against these labels, so each
    # baseline must list its samples in exactly the same order (checked below).
    y_train_master = y_train_qitt
    y_test_master = y_test_qitt

    X_train_b1, y_train_b1, X_test_b1, y_test_b1 = load_baseline1_aggregate_features(PROCESSED_MERGER_TREES_PATH)
    _check_labels_aligned("B1_Aggregate", y_train_master, y_test_master, y_train_b1, y_test_b1)
    X_train_b2, y_train_b2, X_test_b2, y_test_b2 = load_baseline2_raw_substructure_features(
        PROCESSED_MERGER_TREES_PATH, MAX_N_SUB, NUM_PHYSICAL_FEATURES)
    _check_labels_aligned("B2_RawSubPhys", y_train_master, y_test_master, y_train_b2, y_test_b2)
    print("\nSkipping Baseline B3 (Graphlet Counts) due to implementation complexity in this step.")
    X_train_b4, y_train_b4, X_test_b4, y_test_b4 = load_baseline4_flattened_combined_features(
        FINAL_PROCESSED_DATA_PATH, MAX_N_SUB, D_FEAT_COMBINED)
    _check_labels_aligned("B4_FlatCombined", y_train_master, y_test_master, y_train_b4, y_test_b4)

    feature_sets_train = {
        "QITT": X_train_qitt,
        "B1_Aggregate": X_train_b1,
        "B2_RawSubPhys": X_train_b2,
        "B4_FlatCombined": X_train_b4,
    }
    feature_sets_test = {
        "QITT": X_test_qitt,
        "B1_Aggregate": X_test_b1,
        "B2_RawSubPhys": X_test_b2,
        "B4_FlatCombined": X_test_b4,
    }

    # The same template instances are reused for every feature set;
    # MultiOutputRegressor/GridSearchCV fit clones, so the templates stay unfitted.
    models_to_run = {
        "LinearRegression": (LinearRegression(), {}),
        "RandomForest": (RandomForestRegressor(random_state=RANDOM_SEED, n_jobs=-1), {
            'n_estimators': [50, 100],
            'max_depth': [None, 10],
            'min_samples_split': [2, 5],
            'min_samples_leaf': [1, 2],
        }),
        "XGBoost": (xgb.XGBRegressor(random_state=RANDOM_SEED, objective='reg:squarederror', n_jobs=-1), {
            'n_estimators': [50, 100],
            'learning_rate': [0.05, 0.1],
            'max_depth': [3, 5],
            'subsample': [0.8, 1.0],
        })
    }

    for fs_name, X_train_fs in feature_sets_train.items():
        X_test_fs = feature_sets_test[fs_name]

        print("\nProcessing Feature Set: " + str(fs_name))
        print("  Train shape: " + str(X_train_fs.shape) + ", Test shape: " + str(X_test_fs.shape))

        # Fit scaling statistics on train+val only; the test set reuses them.
        # NOTE(review): the scaler is fit before GridSearchCV, so CV folds share
        # scaling statistics (mild leakage); a Pipeline would avoid this.
        scaler = StandardScaler()
        X_train_scaled = scaler.fit_transform(X_train_fs)
        X_test_scaled = scaler.transform(X_test_fs)

        # StandardScaler passes NaNs through; replace them (and infs) with finite values.
        X_train_scaled = np.nan_to_num(X_train_scaled)
        X_test_scaled = np.nan_to_num(X_test_scaled)

        for model_key, (model_instance, params) in models_to_run.items():
            full_model_name = fs_name + "_" + model_key
            results = train_evaluate_model(full_model_name, model_instance, params,
                                           X_train_scaled, y_train_master, X_test_scaled, y_test_master)
            all_model_results[full_model_name] = results

    _print_summary_table(all_model_results)

    print("\n--- Statistical Significance (Paired t-tests vs. QITT_XGBoost) ---")
    qitt_best_model_key = _select_qitt_reference(all_model_results)
    if qitt_best_model_key in all_model_results:
        _print_paired_ttests(all_model_results, qitt_best_model_key, y_test_master)
    else:
        print("QITT_XGBoost model results not found, skipping t-tests.")

    plot_model_performance_comparison(all_model_results, current_timestamp)

    models_for_fi_plot = ["QITT_RandomForest", "QITT_XGBoost",
                          "B4_FlatCombined_RandomForest", "B4_FlatCombined_XGBoost"]
    for model_name_fi in models_for_fi_plot:
        if model_name_fi not in all_model_results:
            continue
        # Name features after the columns of the feature set this model was trained on.
        fs_name = next((fs for fs in feature_sets_train if model_name_fi.startswith(fs + "_")), None)
        if fs_name is None:
            print("Warning: could not determine the feature set for " + model_name_fi + "; skipping feature importances.")
            continue
        num_input_features = feature_sets_train[fs_name].shape[1]
        generic_feature_names = ["feature_" + str(i) for i in range(num_input_features)]
        model_obj = all_model_results[model_name_fi]['model_object']
        plot_feature_importances(model_obj, generic_feature_names, model_name_fi, current_timestamp)

    if qitt_best_model_key in all_model_results:
        best_qitt_results = all_model_results[qitt_best_model_key]
        plot_predicted_vs_true(y_test_master[:, 0], best_qitt_results['y_pred_test'][:, 0],
                               qitt_best_model_key, "Omega_m", current_timestamp)
        plot_predicted_vs_true(y_test_master[:, 1], best_qitt_results['y_pred_test'][:, 1],
                               qitt_best_model_key, "sigma_8", current_timestamp)

    print("\nStep 4 (Regression Modeling and Baseline Comparisons) complete.")


if __name__ == '__main__':
    main()
