import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import torch
import logging
import random
import os
import joblib
import optuna
import torch
import hydra
import pickle
from omegaconf import DictConfig as Config
from scipy.stats import ks_2samp, gaussian_kde, wasserstein_distance, entropy, chi2_contingency
from scipy.spatial.distance import cdist, jensenshannon
from sklearn.metrics import classification_report, balanced_accuracy_score, roc_auc_score, confusion_matrix, pairwise_distances, f1_score
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.svm import SVC
from sklearn.neural_network import MLPClassifier
from xgboost import XGBClassifier
from lightgbm import LGBMClassifier
from catboost import CatBoostClassifier
from sklearn.neighbors import NearestNeighbors
from .my_mlflow import log_eval_experiment
from sklearn.preprocessing import LabelEncoder
from sklearn.cluster import KMeans
from sklearn.preprocessing import StandardScaler, OneHotEncoder, LabelEncoder
from sklearn.compose import ColumnTransformer
from sklearn.feature_selection import mutual_info_classif


# Module-level logger used by every evaluation routine in this file.
_log = logging.getLogger(__name__)


# ------------------ UTILITY EVALUATION FUNCTIONS ------------------ #


# Display name -> name of the classifier class imported at module level.
# The values are strings rather than classes: evaluate_mle resolves each one
# through globals() when it instantiates a model. Insertion order is the order
# in which the classifiers are evaluated.
classifiers_dict = {
    "RandomForest": "RandomForestClassifier",
    "XGBoost": "XGBClassifier",
    "LightGBM": "LGBMClassifier",
    "CatBoost": "CatBoostClassifier",
    "LogisticRegression": "LogisticRegression",
    "SVM": "SVC",
    "MLP": "MLPClassifier",
}


def is_gpu_available() -> bool:
    """Return the result of ``torch.cuda.is_available()``."""
    cuda_ready = torch.cuda.is_available()
    return cuda_ready


def _fixed_model_params(model_name):
    """Return the constructor arguments that are held constant (not tuned).

    Keeping them in one place lets ``optimize_model`` hand them back together
    with the tuned values, so a model rebuilt from the returned dict is
    configured exactly like the one Optuna scored.

    Raises:
        ValueError: if ``model_name`` is not one of the supported models.
    """
    use_gpu = is_gpu_available()
    fixed_by_model = {
        "RandomForest": {"n_jobs": -1},
        "XGBoost": {
            "eval_metric": "logloss",
            # The `device` argument belongs to the XGBoost >= 2.0 API, where
            # tree_method="gpu_hist" is deprecated: the GPU is selected via
            # `device` and the tree method stays "hist".
            "tree_method": "hist",
            "device": "cuda" if use_gpu else "cpu",
        },
        "LightGBM": {"device": "gpu" if use_gpu else "cpu"},
        "CatBoost": {
            "verbose": 0,
            "task_type": "GPU" if use_gpu else "CPU",
            "devices": "0",
        },
        "LogisticRegression": {"n_jobs": -1},
        "SVM": {"probability": True},
        "MLP": {"max_iter": 500},
    }
    if model_name not in fixed_by_model:
        raise ValueError(f"Model {model_name} is not supported.")
    return fixed_by_model[model_name]


def _suggest_model_params(trial, model_name):
    """Ask the Optuna ``trial`` for the tunable hyperparameters of ``model_name``."""
    if model_name == "RandomForest":
        return {
            "n_estimators": trial.suggest_int("n_estimators", 50, 300),
            "max_depth": trial.suggest_int("max_depth", 3, 20),
            "min_samples_split": trial.suggest_int("min_samples_split", 2, 10),
            "min_samples_leaf": trial.suggest_int("min_samples_leaf", 1, 10),
        }
    if model_name == "XGBoost":
        return {
            "n_estimators": trial.suggest_int("n_estimators", 50, 300),
            "max_depth": trial.suggest_int("max_depth", 3, 20),
            "learning_rate": trial.suggest_float("learning_rate", 0.01, 0.3),
        }
    if model_name == "LightGBM":
        return {
            "n_estimators": trial.suggest_int("n_estimators", 50, 300),
            "num_leaves": trial.suggest_int("num_leaves", 20, 100),
            "learning_rate": trial.suggest_float("learning_rate", 0.01, 0.3),
        }
    if model_name == "CatBoost":
        return {
            "iterations": trial.suggest_int("iterations", 50, 300),
            "depth": trial.suggest_int("depth", 3, 10),
            "learning_rate": trial.suggest_float("learning_rate", 0.01, 0.3),
        }
    if model_name == "LogisticRegression":
        return {
            "C": trial.suggest_float("C", 0.01, 10.0, log=True),
            "penalty": trial.suggest_categorical("penalty", ["l1", "l2"]),
            "solver": trial.suggest_categorical("solver", ["liblinear", "saga"]),
        }
    if model_name == "SVM":
        return {
            "C": trial.suggest_float("C", 0.01, 10.0, log=True),
            "kernel": trial.suggest_categorical("kernel", ["linear", "rbf"]),
        }
    if model_name == "MLP":
        return {
            "hidden_layer_sizes": trial.suggest_categorical("hidden_layer_sizes", [(100,), (50, 50), (100, 50)]),
            "activation": trial.suggest_categorical("activation", ["relu", "tanh"]),
            "alpha": trial.suggest_float("alpha", 1e-5, 1e-1, log=True),
        }
    raise ValueError(f"Model {model_name} is not supported.")


def _model_class(model_name):
    """Map a supported model name to its classifier class."""
    model_classes = {
        "RandomForest": RandomForestClassifier,
        "XGBoost": XGBClassifier,
        "LightGBM": LGBMClassifier,
        "CatBoost": CatBoostClassifier,
        "LogisticRegression": LogisticRegression,
        "SVM": SVC,
        "MLP": MLPClassifier,
    }
    if model_name not in model_classes:
        raise ValueError(f"Model {model_name} is not supported.")
    return model_classes[model_name]


def optuna_objective(trial, model_name, X_train, y_train, X_test, y_test):
    """Optuna objective: fit one candidate model and return its macro F1.

    Args:
        trial: Optuna trial used to sample the tunable hyperparameters.
        model_name: one of "RandomForest", "XGBoost", "LightGBM", "CatBoost",
            "LogisticRegression", "SVM", "MLP".
        X_train, y_train: data the candidate is fitted on.
        X_test, y_test: data the candidate is scored on.

    Returns:
        Macro-averaged F1 score of the fitted model on ``X_test``.

    Raises:
        ValueError: if ``model_name`` is not supported.
    """
    fixed_params = _fixed_model_params(model_name)
    tuned_params = _suggest_model_params(trial, model_name)
    model = _model_class(model_name)(**tuned_params, **fixed_params)

    model.fit(X_train, y_train)
    preds = model.predict(X_test)
    return f1_score(y_test, preds, average="macro")


def optimize_model(model_name, X_train, y_train, X_test, y_test, n_trials=30):
    """Tune ``model_name`` with Optuna (TPE sampler, seed 42) and return its parameters.

    The returned dict holds the best trial's tuned values merged with the
    fixed arguments used during tuning (e.g. ``probability=True`` for the SVM,
    ``max_iter=500`` for the MLP, the device settings). Passing it straight to
    the classifier constructor therefore reproduces the tuned configuration
    instead of silently falling back to library defaults for the fixed part.

    NOTE(review): trials are scored on the ``X_test``/``y_test`` supplied by
    the caller; if that is the final hold-out split, the reported scores are
    optimistic. Confirm whether a separate validation split should be used.
    """
    study = optuna.create_study(direction="maximize", sampler=optuna.samplers.TPESampler(seed=42))
    study.optimize(
        lambda trial: optuna_objective(trial, model_name, X_train, y_train, X_test, y_test),
        n_trials=n_trials,
    )
    return {**study.best_trial.params, **_fixed_model_params(model_name)}


def _fit_and_score(model_cls, params, X_train, y_train, X_test, y_test):
    """Fit a fresh ``model_cls(**params)`` and score its hard predictions on the test split.

    Returns a dict with the fitted model, its confusion matrix and the scalar
    metrics tracked by ``evaluate_mle``.
    """
    model = model_cls(**params)
    model.fit(X_train, y_train)
    preds = model.predict(X_test)
    report = classification_report(y_test, preds, output_dict=True)
    return {
        "model": model,
        "cm": confusion_matrix(y_test, preds),
        "balanced_acc": balanced_accuracy_score(y_test, preds),
        # The report is keyed by str(label); the '0'/'1' lookups therefore
        # require a binary target whose labels print as "0" and "1".
        "f1_class0": report['0']['f1-score'],
        "f1_class1": report['1']['f1-score'],
        "f1_macro": report['macro avg']['f1-score'],
        # NOTE(review): AUC is computed from hard class predictions, not from
        # probabilities/decision scores, so it carries no ranking information.
        "auc": roc_auc_score(y_test, preds),
    }


def evaluate_mle(real_data, synthetic_data, test_data, cfg: Config, n_runs=10):
    """Machine-learning-efficacy evaluation: train on real vs. synthetic, test on held-out data.

    For every entry of ``classifiers_dict`` the hyperparameters are tuned once
    on the real training data (``optimize_model``); the classifier is then
    fitted ``n_runs`` times on the real data and ``n_runs`` times on the
    synthetic data, each time scored on ``test_data``.

    The "best" real and synthetic models (saved with joblib, and whose
    confusion matrices are reported) are the runs with the highest macro F1,
    the same criterion the Optuna objective maximises.

    Args:
        real_data, synthetic_data, test_data: DataFrames containing the
            feature columns plus ``cfg.dataset.target_column``.
        cfg: config providing ``dataset.target_column``, ``dataset.name``,
            ``datagen_method.name`` and ``paths.MLE_models_dir``.
        n_runs: number of repeated fits per classifier and training set.

    Returns:
        DataFrame with one row per classifier holding mean/std of each metric,
        the synthetic-minus-real differences, the confusion matrices of the
        best runs and the hyperparameters used.
    """
    _log.info("Starting MLE evaluation...")
    _log.info(f"Real data shape: {real_data.shape}, Synthetic data shape: {synthetic_data.shape}, Test data shape: {test_data.shape}")

    target_col = cfg.dataset.target_column
    X_real = real_data.drop(columns=[target_col])
    y_real = real_data[target_col]
    X_syn = synthetic_data.drop(columns=[target_col])
    y_syn = synthetic_data[target_col]
    X_test = test_data.drop(columns=[target_col])
    y_test = test_data[target_col]

    metric_keys = ("balanced_acc", "f1_class0", "f1_class1", "f1_macro", "auc")
    results = []

    for name, clf in classifiers_dict.items():
        _log.info(f"Evaluating classifier: {name}")
        _log.info(f"Optimizing hyperparameters for {name} using Optuna")
        best_params = optimize_model(name, X_real, y_real, X_test, y_test)
        # classifiers_dict stores class names; resolve to the imported class.
        model_cls = globals()[clf]

        real_scores = {key: np.zeros(n_runs) for key in metric_keys}
        syn_scores = {key: np.zeros(n_runs) for key in metric_keys}

        # Best run so far (by macro F1). Starting from None rather than a
        # score of 0 guarantees a model is selected even if every run scores 0.
        best_real = None
        best_syn = None

        # average results over 'n_runs' runs for each model
        for i in range(n_runs):
            _log.info(f"Run {i + 1}/{n_runs} for {name}")

            real_run = _fit_and_score(model_cls, best_params, X_real, y_real, X_test, y_test)
            syn_run = _fit_and_score(model_cls, best_params, X_syn, y_syn, X_test, y_test)

            for key in metric_keys:
                real_scores[key][i] = real_run[key]
                syn_scores[key][i] = syn_run[key]

            if best_real is None or real_run["f1_macro"] > best_real["f1_macro"]:
                best_real = real_run
            if best_syn is None or syn_run["f1_macro"] > best_syn["f1_macro"]:
                best_syn = syn_run

        # save best models
        model_dir = os.path.join(
            cfg.paths.MLE_models_dir,
            cfg.datagen_method.name,
            cfg.dataset.name
        )
        os.makedirs(model_dir, exist_ok=True)

        real_model_path = os.path.join(model_dir, f"{name}_best_real_model.pkl")
        syn_model_path = os.path.join(model_dir, f"{name}_best_synth_model.pkl")

        # best_* stay None only when n_runs == 0.
        best_real_model = best_real["model"] if best_real is not None else None
        best_syn_model = best_syn["model"] if best_syn is not None else None

        _log.info(f"Saving best REAL model for {name} to: {real_model_path}")
        joblib.dump(best_real_model, real_model_path)

        _log.info(f"Saving best SYNTH model for {name} to: {syn_model_path}")
        joblib.dump(best_syn_model, syn_model_path)

        results.append({
            'Classifier': name,
            'Avg Balanced Accuracy (Real)': np.mean(real_scores["balanced_acc"]),
            'Std Balanced Accuracy (Real)': np.std(real_scores["balanced_acc"]),
            'Avg Balanced Accuracy (Synthetic)': np.mean(syn_scores["balanced_acc"]),
            'Std Balanced Accuracy (Synthetic)': np.std(syn_scores["balanced_acc"]),
            'Avg Balanced Accuracy Diff': np.mean(syn_scores["balanced_acc"] - real_scores["balanced_acc"]),

            'Avg F1 Score (Class 0, Real)': np.mean(real_scores["f1_class0"]),
            'Std F1 Score (Class 0, Real)': np.std(real_scores["f1_class0"]),
            'Avg F1 Score (Class 0, Synthetic)': np.mean(syn_scores["f1_class0"]),
            'Std F1 Score (Class 0, Synthetic)': np.std(syn_scores["f1_class0"]),
            'Avg F1 Score Diff (Class 0)': np.mean(syn_scores["f1_class0"] - real_scores["f1_class0"]),

            'Avg F1 Score (Class 1, Real)': np.mean(real_scores["f1_class1"]),
            'Std F1 Score (Class 1, Real)': np.std(real_scores["f1_class1"]),
            'Avg F1 Score (Class 1, Synthetic)': np.mean(syn_scores["f1_class1"]),
            'Std F1 Score (Class 1, Synthetic)': np.std(syn_scores["f1_class1"]),
            'Avg F1 Score Diff (Class 1)': np.mean(syn_scores["f1_class1"] - real_scores["f1_class1"]),

            'Avg F1 Score (Real)': np.mean(real_scores["f1_macro"]),
            'Std F1 Score (Real)': np.std(real_scores["f1_macro"]),
            'Avg F1 Score (Synthetic)': np.mean(syn_scores["f1_macro"]),
            'Std F1 Score (Synthetic)': np.std(syn_scores["f1_macro"]),
            'Avg F1 Score Diff': np.mean(syn_scores["f1_macro"] - real_scores["f1_macro"]),

            'Avg AUC-ROC (Real)': np.mean(real_scores["auc"]),
            'Std AUC-ROC (Real)': np.std(real_scores["auc"]),
            'Avg AUC-ROC (Synthetic)': np.mean(syn_scores["auc"]),
            'Std AUC-ROC (Synthetic)': np.std(syn_scores["auc"]),
            'Avg AUC-ROC Diff': np.mean(syn_scores["auc"] - real_scores["auc"]),

            'Confusion Matrix of Best Model (Real)': best_real["cm"].tolist() if best_real is not None else None,
            'Confusion Matrix of Best Model (Synthetic)': best_syn["cm"].tolist() if best_syn is not None else None,

            'Best Hyperparameters': best_params
        })

        _log.info(f"Finished evaluating {name}.")

    mle_results = pd.DataFrame(results)
    _log.info("MLE evaluation completed successfully.")

    return mle_results


# ------------------ FIDELITY EVALUATION FUNCTIONS ------------------ #


def compute_density_error_WD_JSD(real_data, synthetic_data, num_columns):
    """
    Compute column-wise density error between real and synthetic data:
    - Wasserstein Distance for numerical columns, after scaling BOTH columns
      with the real column's mean and standard deviation. Using one shared
      scale keeps the metric unit-free while still penalising shifts in
      location or spread (standardizing each side with its own statistics
      would cancel exactly those differences).
    - Jensen-Shannon (base 2) for every other column, computed with
      ``scipy.spatial.distance.jensenshannon`` (which returns the JS
      *distance*, i.e. the square root of the divergence).

    Columns are taken from ``synthetic_data``; a column is treated as
    numerical iff its name is in ``num_columns``. Columns with no non-missing
    values on either side are skipped with a warning.

    Returns:
        avg_wd, avg_jsd (each 0.0 when no column of that kind was scored)
    """
    _log.info("Computing density error with Wasserstein Distance (WD) and Jensen-Shannon Divergence (JSD)...")

    numeric_cols = set(num_columns)
    errors_wd = []
    jsds = []
    for col in synthetic_data.columns:
        if col in numeric_cols:
            _log.info(f"Computing WD for numerical column: {col}")
            real_col = real_data[col].dropna()
            synth_col = synthetic_data[col].dropna()
            if real_col.empty or synth_col.empty:
                _log.warning(f"Skipping WD for column {col}: no non-missing values to compare.")
                continue

            # Shared reference scale taken from the real data. A constant or
            # single-valued real column has no usable spread, so fall back to
            # the raw units instead of dividing by ~0.
            center = real_col.mean()
            spread = real_col.std()
            scale = spread if np.isfinite(spread) and spread > 1e-10 else 1.0

            wd = wasserstein_distance((real_col - center) / scale, (synth_col - center) / scale)
            _log.info(f"WD for numerical column {col}: {wd:.4f}")
            errors_wd.append(wd)
        else:
            _log.info(f"Computing JSD for categorical column: {col}")
            real_counts = real_data[col].value_counts(normalize=True)
            synth_counts = synthetic_data[col].value_counts(normalize=True)
            if real_counts.empty or synth_counts.empty:
                _log.warning(f"Skipping JSD for column {col}: no non-missing values to compare.")
                continue

            # Align both distributions on the union of observed categories.
            all_categories = list(set(real_counts.index) | set(synth_counts.index))
            real_probs = real_counts.reindex(all_categories, fill_value=0).to_numpy(dtype=float)
            synth_probs = synth_counts.reindex(all_categories, fill_value=0).to_numpy(dtype=float)

            jsd = jensenshannon(real_probs, synth_probs, base=2.0)
            jsds.append(jsd)

    avg_wd = np.mean(errors_wd) if errors_wd else 0.0
    avg_jsd = np.mean(jsds) if jsds else 0.0

    _log.info(f"Average WD: {avg_wd:.4f}")
    _log.info(f"Average JSD: {avg_jsd:.4f}")

    return avg_wd, avg_jsd


def compute_pairwise_correlation_error(real_corr, synthetic_corr):
    """Mean absolute difference (in percent) between two correlation matrices.

    Only the strict upper triangle is used, so every unordered column pair
    counts exactly once and the diagonal is ignored. Pairs whose difference is
    NaN (e.g. undefined correlation for a constant column) are left out.

    Args:
        real_corr, synthetic_corr: square correlation matrices of the same
            shape (DataFrames with matching labels, or arrays).

    Returns:
        Average absolute difference over all column pairs times 100, or NaN
        when there is no valid pair to average.
    """
    _log.info("Computing pairwise correlation error...")
    error_matrix = np.asarray(np.abs(real_corr - synthetic_corr), dtype=float)

    # Average over the individual pairs; averaging per-column means first
    # would weight pairs unequally because upper-triangle columns have
    # different lengths.
    pair_errors = error_matrix[np.triu_indices_from(error_matrix, k=1)]
    pair_errors = pair_errors[~np.isnan(pair_errors)]
    if pair_errors.size == 0:
        return float("nan")
    return float(pair_errors.mean() * 100)


def plot_correlation_error(real_corr, synth_corr, cfg: Config):
    """
    Plot the absolute differences between the correlation matrices of the
    real and synthetic datasets as a heatmap and save it as
    ``correlation_difference.png``.

    The figure is written to
    ``<results_dir>/<datagen method>/<dataset>/<sampling strategy>/``; the
    directory is created if needed.

    Returns:
        Path of the saved image.
    """
    _log.info("Plotting correlation error heatmap...")
    error_matrix = np.abs(real_corr - synth_corr)

    # os.path.join inserts separators regardless of whether results_dir ends
    # with one, and the output directory itself (not just its parent) must
    # exist before savefig is called.
    save_fig_path = os.path.join(
        str(cfg.paths.results_dir),
        str(cfg.datagen_method.name),
        str(cfg.dataset.name),
        str(cfg.datagen_method.params.sampling_strategy),
    )
    os.makedirs(save_fig_path, exist_ok=True)
    save_path = os.path.join(save_fig_path, "correlation_difference.png")

    fig, ax = plt.subplots(figsize=(10, 8))
    sns.heatmap(error_matrix, annot=False, cmap='Reds', cbar=True, vmin=0, vmax=1, ax=ax)
    ax.set_title(f'Absolute Correlation Difference - {cfg.datagen_method.name} ({cfg.dataset.name})')
    fig.savefig(save_path)
    plt.close(fig)

    return save_path


def plot_feature_distributions(real_data, synthetic_data, cfg: Config):
    """Plot real vs. synthetic distributions of every column in one grid figure.

    Numeric columns (by dtype of ``real_data``) are drawn as overlaid
    histograms with KDE; all other columns as side-by-side bars of relative
    frequencies. The figure is saved as ``feature_distributions.png`` under
    ``<results_dir>/<datagen method>/<dataset>/<sampling strategy>/``.

    Returns:
        Path of the saved image.
    """
    _log.info("Plotting feature distributions for numerical and categorical columns...")
    num_cols = real_data.select_dtypes(include=[np.number]).columns
    cat_cols = real_data.select_dtypes(exclude=[np.number]).columns
    method_label = cfg.datagen_method.name

    total_features = len(num_cols) + len(cat_cols)
    num_cols_per_row = 3
    # At least one row so plt.subplots is valid even for a frame without columns.
    num_rows = max(1, int(np.ceil(total_features / num_cols_per_row)))

    # squeeze=False always yields a 2-D array of axes, so flatten() is safe.
    fig, axes = plt.subplots(
        nrows=num_rows, ncols=num_cols_per_row, figsize=(15, num_rows * 4), squeeze=False
    )
    fig.subplots_adjust(hspace=0.6, wspace=0.3)
    axes = axes.flatten()

    for i, col in enumerate(num_cols):
        sns.histplot(real_data[col], label='Real', color='blue', alpha=0.5, kde=True, ax=axes[i])
        sns.histplot(synthetic_data[col], label=method_label, color='red', alpha=0.5, kde=True, ax=axes[i])
        axes[i].set_title(f"{col}")
        axes[i].legend()

    for i, col in enumerate(cat_cols, start=len(num_cols)):
        real_counts = real_data[col].value_counts(normalize=True)
        synth_counts = synthetic_data[col].value_counts(normalize=True)
        # Sorted so the bar order is the same on every run (set iteration
        # order of strings varies between processes).
        categories = sorted(set(real_counts.index) | set(synth_counts.index), key=str)
        real_values = [real_counts.get(cat, 0) for cat in categories]
        synth_values = [synth_counts.get(cat, 0) for cat in categories]

        x = np.arange(len(categories))
        width = 0.4
        axes[i].bar(x - width/2, real_values, width=width, label='Real', color='blue', alpha=0.7)
        axes[i].bar(x + width/2, synth_values, width=width, label=method_label, color='red', alpha=0.7)
        axes[i].set_xticks(x)
        axes[i].set_xticklabels(categories, rotation=45, ha='right')
        axes[i].set_title(f"{col}")
        axes[i].legend()

    # Hide unused subplots
    for j in range(total_features, len(axes)):
        fig.delaxes(axes[j])

    # Create the output directory itself (not only its parent) before saving.
    save_fig_path = os.path.join(
        str(cfg.paths.results_dir),
        str(cfg.datagen_method.name),
        str(cfg.dataset.name),
        str(cfg.datagen_method.params.sampling_strategy),
    )
    os.makedirs(save_fig_path, exist_ok=True)
    save_path = os.path.join(save_fig_path, "feature_distributions.png")
    fig.tight_layout()
    fig.savefig(save_path)
    plt.close(fig)

    return save_path


def plot_feature_distributions_per_class(real_data, synthetic_data, cfg: Config):
    """
    Plot per-class feature distributions between real and synthetic data.

    For each value of the target column found in ``real_data`` one grid
    figure is produced from the rows of that class: numeric columns as
    overlaid histograms with KDE, other columns as side-by-side bars of
    relative frequencies. Figures are saved under
    ``<results_dir>/<datagen method>/<dataset>/<sampling strategy>/per_class_feature_distributions``.

    Returns:
        List of saved image paths, one per class.
    """
    _log.info("Plotting per-class feature distributions...")
    target_col = cfg.dataset.target_column
    output_dir = f'{cfg.paths.results_dir}/{cfg.datagen_method.name}/{cfg.dataset.name}/{cfg.datagen_method.params.sampling_strategy}/per_class_feature_distributions'
    os.makedirs(output_dir, exist_ok=True)

    classes = sorted(real_data[target_col].unique())
    cols_per_row = 3
    fig_paths = []

    for cls in classes:
        real_subset = real_data[real_data[target_col] == cls]
        synth_subset = synthetic_data[synthetic_data[target_col] == cls]

        num_cols = real_subset.select_dtypes(include=[np.number]).columns
        cat_cols = real_subset.select_dtypes(exclude=[np.number]).columns

        total_features = len(num_cols) + len(cat_cols)
        # At least one row so plt.subplots stays valid; squeeze=False keeps
        # the axes array 2-D so flatten() always works.
        nrows = max(1, int(np.ceil(total_features / cols_per_row)))

        fig, axes = plt.subplots(
            nrows=nrows, ncols=cols_per_row, figsize=(15, nrows * 4), squeeze=False
        )
        fig.subplots_adjust(hspace=0.6, wspace=0.3)
        axes = axes.flatten()

        for i, col in enumerate(num_cols):
            sns.histplot(real_subset[col], label='Real', color='blue', kde=True, alpha=0.5, ax=axes[i])
            sns.histplot(synth_subset[col], label='Synthetic', color='red', kde=True, alpha=0.5, ax=axes[i])
            axes[i].set_title(f"{col}")
            axes[i].legend()

        for i, col in enumerate(cat_cols, start=len(num_cols)):
            real_counts = real_subset[col].value_counts(normalize=True)
            synth_counts = synth_subset[col].value_counts(normalize=True)
            # Sorted so the bar order is reproducible across runs.
            categories = sorted(set(real_counts.index) | set(synth_counts.index), key=str)
            real_vals = [real_counts.get(cat, 0) for cat in categories]
            synth_vals = [synth_counts.get(cat, 0) for cat in categories]
            x = np.arange(len(categories))
            width = 0.4
            axes[i].bar(x - width/2, real_vals, width=width, label='Real', color='blue')
            axes[i].bar(x + width/2, synth_vals, width=width, label='Synthetic', color='red')
            axes[i].set_xticks(x)
            axes[i].set_xticklabels(categories, rotation=45)
            axes[i].set_title(f"{col}")
            axes[i].legend()

        # Hide unused subplots
        for j in range(total_features, len(axes)):
            fig.delaxes(axes[j])

        fig.tight_layout()
        fig_path = f'{output_dir}/feature_distributions_per_class_{cls}.png'
        fig.savefig(fig_path)
        plt.close(fig)
        _log.info(f"Saved per-class distribution plot for class {cls} to {fig_path}")
        fig_paths.append(fig_path)

    return fig_paths


def cramers_v(x, y):
    """Bias-corrected Cramér's V statistic for categorical-categorical association.

    Args:
        x, y: equally long categorical sequences (anything ``pd.crosstab``
            accepts).

    Returns:
        A float >= 0. Returns 0.0 when the association is undefined: fewer
        than two observations, either variable has a single category, or the
        bias-corrected table dimensions leave nothing to divide by.
    """
    contingency = pd.crosstab(x, y)
    n = contingency.to_numpy().sum()
    r, k = contingency.shape
    # With one category (or <= 1 observation) the formula degenerates to 0/0.
    if n <= 1 or min(r, k) < 2:
        return 0.0

    chi2 = chi2_contingency(contingency)[0]
    phi2 = chi2 / n
    phi2_corr = max(0.0, phi2 - ((k - 1) * (r - 1)) / (n - 1))
    r_corr = r - ((r - 1) ** 2) / (n - 1)
    k_corr = k - ((k - 1) ** 2) / (n - 1)
    denominator = min(k_corr - 1, r_corr - 1)
    if denominator <= 0:
        return 0.0
    return float(np.sqrt(phi2_corr / denominator))


def correlation_ratio(categories, measurements):
    """Correlation ratio (eta) for numerical-categorical association.

    Computes sqrt(between-group sum of squares / total sum of squares), i.e.
    eta, the square root of eta squared.

    Args:
        categories: sequence of category labels (list, array or Series).
        measurements: numeric sequence of the same length, matched to
            ``categories`` by position.

    Returns:
        A float; 0.0 when there is no usable observation or the measurements
        have no variance. Observations with a missing category or a NaN
        measurement are excluded from the whole computation.
    """
    # dtype=object keeps arbitrary labels intact; factorize codes missing
    # labels as -1.
    codes, _ = pd.factorize(pd.Series(categories, dtype=object))
    values = np.asarray(measurements, dtype=float)

    usable = (codes >= 0) & ~np.isnan(values)
    codes = codes[usable]
    values = values[usable]
    if values.size == 0:
        return 0.0

    group_sizes = np.bincount(codes)
    group_sums = np.bincount(codes, weights=values)
    # A category can end up empty once rows with NaN measurements are dropped.
    occupied = group_sizes > 0
    group_means = group_sums[occupied] / group_sizes[occupied]

    grand_mean = values.mean()
    numerator = np.sum(group_sizes[occupied] * (group_means - grand_mean) ** 2)
    denominator = np.sum((values - grand_mean) ** 2)
    if denominator == 0:
        return 0.0
    return float(np.sqrt(numerator / denominator))


def compute_mixed_correlation(df, numerical_cols, categorical_cols):
    """Build a symmetric association matrix over numerical and categorical columns.

    Pair types and the measure used:
    - numerical / numerical: Pearson correlation
    - categorical / categorical: ``cramers_v``
    - numerical / categorical: ``correlation_ratio``

    Args:
        df: DataFrame containing all listed columns.
        numerical_cols, categorical_cols: column names; any sequence type is
            accepted (list, tuple, config list).

    Returns:
        Float DataFrame indexed and labelled by ``numerical_cols`` followed by
        ``categorical_cols``, with 1.0 on the diagonal.
    """
    numerical_cols = list(numerical_cols)
    categorical_cols = list(categorical_cols)
    all_cols = numerical_cols + categorical_cols
    numeric = set(numerical_cols)
    categorical = set(categorical_cols)

    corr_matrix = pd.DataFrame(np.eye(len(all_cols)), index=all_cols, columns=all_cols, dtype=float)

    # Each measure is symmetric in its two columns, so every pair is computed
    # once and mirrored.
    for i, col1 in enumerate(all_cols):
        for col2 in all_cols[i + 1:]:
            if col1 in numeric and col2 in numeric:
                value = df[[col1, col2]].corr(method="pearson").iloc[0, 1]
            elif col1 in categorical and col2 in categorical:
                value = cramers_v(df[col1], df[col2])
            else:
                # Mixed: num-cat or cat-num
                num_col, cat_col = (col1, col2) if col1 in numeric else (col2, col1)
                value = correlation_ratio(df[cat_col].values, df[num_col].values)
            corr_matrix.loc[col1, col2] = value
            corr_matrix.loc[col2, col1] = value
    return corr_matrix


def evaluate_correlation(real_data, synthetic_data, cfg: Config):
    """Compare the column-association structure of real and synthetic data.

    Builds the mixed correlation matrix of each dataset over
    ``cfg.dataset.numerical_columns`` + ``cfg.dataset.categorical_columns``,
    writes both matrices to CSV in the results directory, computes the
    pairwise correlation error and plots the absolute-difference heatmap.

    Returns:
        (error_rate, corr_error_plot_path): the pairwise correlation error in
        percent and the path of the saved heatmap.
    """
    _log.info("Evaluating correlation differences between real and synthetic data...")
    # Plain lists so concatenation and DataFrame column selection behave the
    # same for config list types and ordinary sequences.
    num_cols = list(cfg.dataset.numerical_columns)
    cat_cols = list(cfg.dataset.categorical_columns)
    all_cols = num_cols + cat_cols

    real_corr_matrix = compute_mixed_correlation(real_data[all_cols], num_cols, cat_cols)
    synth_corr_matrix = compute_mixed_correlation(synthetic_data[all_cols], num_cols, cat_cols)

    # Save correlation matrices separately
    # (index=False: row labels are not written; rows follow the header order).
    output_dir = f'{cfg.paths.results_dir}/{cfg.datagen_method.name}/{cfg.dataset.name}/{cfg.datagen_method.params.sampling_strategy}'
    os.makedirs(output_dir, exist_ok=True)
    real_corr_matrix.to_csv(f'{output_dir}/real_correlation_matrix.csv', index=False)
    synth_corr_matrix.to_csv(f'{output_dir}/synth_correlation_matrix.csv', index=False)
    _log.info(f"Saved correlation matrices to {output_dir}")

    error_rate = compute_pairwise_correlation_error(real_corr_matrix, synth_corr_matrix)
    # Reported through the module logger like every other metric in this file.
    _log.info(f"Pairwise Column Correlation Error: {error_rate:.2f}%")

    corr_error_plot_path = plot_correlation_error(real_corr_matrix, synth_corr_matrix, cfg)

    return error_rate, corr_error_plot_path


# ------------------ PRIVACY EVALUATION FUNCTIONS ------------------ #


def compute_dcr_and_nndr(X_real, X_syn, sample_frac=0.15):
    """Distance-to-closest-record (DCR) and nearest-neighbour distance ratio (NNDR).

    A random subsample of each dataset (seed 42) is standardized and the
    Euclidean nearest / second-nearest neighbour distances are computed for
    three pairings: real->synthetic (RF), real->real (RR) and
    synthetic->synthetic (FF), excluding self-matches within a set. The 5th
    percentile of the nearest distance (DCR) and of the nearest/second-nearest
    ratio (NNDR) is reported for each pairing.

    Both subsamples are scaled with ONE scaler fitted on the real subsample,
    so real and synthetic points live in the same coordinate system and the
    cross-set distances are meaningful.

    Args:
        X_real, X_syn: numeric feature DataFrames with identical columns.
        sample_frac: fraction of rows sampled from each dataset. At least 3
            rows (or all rows, if fewer) are always drawn, because a second
            nearest neighbour must exist.

    Returns:
        Dict with keys DCR_RF, DCR_RR, DCR_FF, NNDR_RF, NNDR_RR, NNDR_FF.

    Raises:
        ValueError: if either dataset has fewer than 3 rows.
    """
    min_rows = 3

    def _subsample(frame):
        # Same size as sample(frac=...) whenever that yields >= min_rows rows.
        n_rows = min(len(frame), max(min_rows, int(round(sample_frac * len(frame)))))
        return frame.sample(n=n_rows, random_state=42).to_numpy()

    real_sampled = _subsample(X_real)
    synth_sampled = _subsample(X_syn)
    if len(real_sampled) < min_rows or len(synth_sampled) < min_rows:
        raise ValueError(
            f"DCR/NNDR need at least {min_rows} rows per dataset "
            f"(got {len(real_sampled)} real, {len(synth_sampled)} synthetic)."
        )

    scaler = StandardScaler().fit(real_sampled)
    real_scaled = scaler.transform(real_sampled)
    synth_scaled = scaler.transform(synth_sampled)

    # Computing pair-wise distances between real and synthetic
    dist_rf = pairwise_distances(real_scaled, synth_scaled)
    # Computing pair-wise distances within real
    dist_rr = pairwise_distances(real_scaled)
    # Computing pair-wise distances within synthetic
    dist_ff = pairwise_distances(synth_scaled)

    # Drop the diagonal so a point is never its own nearest neighbour.
    dist_rr = dist_rr[~np.eye(dist_rr.shape[0], dtype=bool)].reshape(dist_rr.shape[0], -1)
    dist_ff = dist_ff[~np.eye(dist_ff.shape[0], dtype=bool)].reshape(dist_ff.shape[0], -1)

    # Columns 0 and 1: nearest and second-nearest distance per row.
    min_rf = np.sort(dist_rf, axis=1)[:, :2]
    min_rr = np.sort(dist_rr, axis=1)[:, :2]
    min_ff = np.sort(dist_ff, axis=1)[:, :2]

    dcr_rf = np.percentile(min_rf[:, 0], 5)
    dcr_rr = np.percentile(min_rr[:, 0], 5)
    dcr_ff = np.percentile(min_ff[:, 0], 5)

    # 1e-10 guards against division by zero for duplicated points.
    nndr_rf = np.percentile(min_rf[:, 0] / (min_rf[:, 1] + 1e-10), 5)
    nndr_rr = np.percentile(min_rr[:, 0] / (min_rr[:, 1] + 1e-10), 5)
    nndr_ff = np.percentile(min_ff[:, 0] / (min_ff[:, 1] + 1e-10), 5)

    return {
        "DCR_RF": dcr_rf,
        "DCR_RR": dcr_rr,
        "DCR_FF": dcr_ff,
        "NNDR_RF": nndr_rf,
        "NNDR_RR": nndr_rr,
        "NNDR_FF": nndr_ff
    }


# ------------------ RUN ALL EVALUATION FUNCTIONS ------------------ #


def fit_and_save_preprocessor(cfg: Config, df: pd.DataFrame) -> ColumnTransformer:
    """Fit the feature preprocessor on ``df`` and persist it with joblib.

    Rows with missing values are dropped and the target column removed; the
    numerical columns are standardized, the categorical columns one-hot
    encoded (unknown categories ignored at transform time) and any remaining
    column is passed through unchanged.

    The fitted transformer is written to
    ``<cfg.paths.processed_data_dir>/preprocessor.pkl``; the directory is
    created if it does not exist.

    Returns:
        The fitted ``ColumnTransformer``.
    """
    df = df.dropna()
    X = df.drop(columns=[cfg.dataset.target_column])

    preprocessor = ColumnTransformer(
        transformers=[
            ('num', StandardScaler(), list(cfg.dataset.numerical_columns)),
            ('cat', OneHotEncoder(handle_unknown='ignore'), list(cfg.dataset.categorical_columns))
        ],
        remainder='passthrough',
    )

    preprocessor.fit(X)

    # Create the target directory itself: os.path.dirname() of a path without
    # a trailing slash would only create its parent.
    output_dir = str(cfg.paths.processed_data_dir)
    os.makedirs(output_dir, exist_ok=True)
    preprocessor_path = os.path.join(output_dir, "preprocessor.pkl")
    joblib.dump(preprocessor, preprocessor_path)
    _log.info(f"Preprocessor saved to {preprocessor_path}")

    return preprocessor


def transform_with_preprocessor(cfg: Config, df: pd.DataFrame) -> pd.DataFrame:
    """Apply the saved preprocessor to ``df`` and re-attach the target column.

    The transformer is loaded from
    ``<cfg.paths.processed_data_dir>/preprocessor.pkl``. Rows containing
    missing values are dropped before transforming; the result keeps the
    index of the surviving rows and uses the transformer's output feature
    names as column labels.
    """
    fitted = joblib.load(f"{cfg.paths.processed_data_dir}/preprocessor.pkl")
    target_col = cfg.dataset.target_column

    complete_rows = df.dropna()
    labels = complete_rows[target_col]

    encoded = fitted.transform(complete_rows.drop(columns=[target_col]))
    column_names = fitted.get_feature_names_out()

    # A sparse result (exposes ``toarray``) is densified for the DataFrame.
    dense = encoded.toarray() if hasattr(encoded, 'toarray') else encoded

    transformed = pd.DataFrame(dense, columns=column_names, index=complete_rows.index)
    transformed[target_col] = labels

    return transformed


# uncomment line below if running this file directly 
# @hydra.main(version_base=None, config_path='../conf', config_name="datagen")
def run_evaluation(cfg: Config, training_run_id=None, gen_time=None):
    """Run the full evaluation of one synthetic dataset against its real counterpart.

    Steps, in order:
    1. Load the real train split, the test split and the synthetic data.
    2. Fit the preprocessor on the real data and transform all three sets.
    3. ML efficacy (``evaluate_mle``) on the preprocessed data; saved as
       ``mle_score.csv`` in the results directory.
    4. Privacy: global and per-class DCR/NNDR on the preprocessed features.
    5. Fidelity on the raw data: average and per-class WD/JSD, feature
       distribution plots, correlation error.
    6. If ``cfg.mlflow`` is truthy, log everything via ``log_eval_experiment``.

    Args:
        cfg: Hydra/OmegaConf configuration (paths, dataset, datagen_method, mlflow).
        training_run_id: forwarded unchanged to ``log_eval_experiment``.
        gen_time: forwarded unchanged to ``log_eval_experiment``.
    """
    _log.info("Starting run_evaluation process...")

    real_data = pd.read_csv(f"{cfg.paths.processed_data_dir}{cfg.dataset.splits.train.path}")
    test_data = pd.read_csv(f"{cfg.paths.processed_data_dir}{cfg.dataset.splits.test.path}")
    synthetic_data = pd.read_csv(f"{cfg.paths.synth_data_dir}{cfg.datagen_method.name}/{cfg.datagen_method.params.sampling_strategy}/{cfg.dataset.synthetic_path}")

    fit_and_save_preprocessor(cfg, real_data)

    real_data_preprocessed = transform_with_preprocessor(cfg, real_data)
    synthetic_data_preprocessed = transform_with_preprocessor(cfg, synthetic_data)
    test_data_preprocessed = transform_with_preprocessor(cfg, test_data)

    target_col = cfg.dataset.target_column
    numerical_columns = cfg.dataset.numerical_columns

    # MLE scores
    mle_results = evaluate_mle(real_data_preprocessed, synthetic_data_preprocessed, test_data_preprocessed, cfg)
    mle_dir = os.path.join(
        cfg.paths.results_dir,
        cfg.datagen_method.name,
        cfg.dataset.name,
        cfg.datagen_method.params.sampling_strategy
    )
    os.makedirs(mle_dir, exist_ok=True)
    mle_results.to_csv(f"{mle_dir}/mle_score.csv", index=True)
    _log.info("MLE scores calculated")

    # Global privacy metrics
    dcr_metrics = compute_dcr_and_nndr(
        X_real=real_data_preprocessed.drop(columns=[target_col]),
        X_syn=synthetic_data_preprocessed.drop(columns=[target_col])
    )
    _log.info(f"DCR and NNDR metrics: {dcr_metrics}")

    # Classes are taken from the raw real data and reused by both per-class loops.
    classes = sorted(real_data[target_col].unique())

    # Per-class DCR/NNDR
    dcr_nndr_per_class_records = []
    for cls in classes:
        real_cls = real_data_preprocessed[real_data_preprocessed[target_col] == cls]
        synth_cls = synthetic_data_preprocessed[synthetic_data_preprocessed[target_col] == cls]

        if len(real_cls) < 2 or len(synth_cls) < 2:
            _log.warning(f"Too few samples for class {cls} in '{cfg.dataset.name}' for method '{cfg.datagen_method.name}'. Skipping.")
            continue

        # Best effort: a failure for one class must not abort the evaluation.
        try:
            privacy_per_class = compute_dcr_and_nndr(
                real_cls.drop(columns=[target_col]),
                synth_cls.drop(columns=[target_col])
            )
            dcr_nndr_per_class_records.append({
                "class": cls,
                **privacy_per_class
            })
        except Exception as e:
            _log.warning(f"Failed to compute per-class DCR/NNDR for '{cfg.dataset.name}' for method '{cfg.datagen_method.name}'-class{cls}: {e}")

    # Average density errors
    _log.info("Computing average density errors...")
    avg_wd, avg_jsd = compute_density_error_WD_JSD(real_data, synthetic_data, numerical_columns)
    _log.info(f"'WD average (numerical columns)': {avg_wd}, 'JSD average (categorical columns)': {avg_jsd}")

    # Average density errors per class
    _log.info("Computing per-class density errors...")
    wd_jsd_per_class = []
    for cls in classes:
        real_cls = real_data[real_data[target_col] == cls]
        synth_cls = synthetic_data[synthetic_data[target_col] == cls]

        # A class the generator never produced has nothing to compare;
        # skip it instead of letting the distance computation fail.
        if real_cls.empty or synth_cls.empty:
            _log.warning(f"No samples for class {cls} in real or synthetic data of '{cfg.dataset.name}'. Skipping per-class density error.")
            continue

        wd, jsd = compute_density_error_WD_JSD(real_cls, synth_cls, numerical_columns)
        _log.info(f"Class {cls} - WD: {wd:.4f}, JSD: {jsd:.4f}")

        wd_jsd_per_class.append({
            "class": cls,
            "WD": wd,
            "JSD": jsd
        })

    ft_distr_path = plot_feature_distributions(real_data, synthetic_data, cfg)
    ft_distr_paths_per_class = plot_feature_distributions_per_class(real_data, synthetic_data, cfg)

    corr_results, matrix_path = evaluate_correlation(real_data, synthetic_data, cfg)
    _log.info(f"Correlation Evaluation: {corr_results}")

    if cfg.mlflow:
        # Log evaluation experiment to MLflow
        log_eval_experiment(
            cfg=cfg,
            mle_results=mle_results,
            dcr_metrics=dcr_metrics,
            dcr_nndr_per_class_records=dcr_nndr_per_class_records,
            avg_wd=avg_wd,
            avg_jsd=avg_jsd,
            wd_jsd_per_class=wd_jsd_per_class,
            corr_results=corr_results,
            matrix_path=matrix_path,
            ft_distr_path=ft_distr_path,
            ft_distr_paths_per_class=ft_distr_paths_per_class,
            training_run_id=training_run_id,
            gen_time=gen_time
        )
        _log.info("Eval metrics saved to MLflow server.")


if __name__ == "__main__":
    # run_evaluation needs a Hydra config and cannot be called bare. Wrapping
    # it with hydra.main at call time lets the module be executed directly
    # without editing the source, while importers still get the plain function.
    hydra.main(version_base=None, config_path='../conf', config_name="datagen")(run_evaluation)()
