import numpy as np
import pandas as pd
import time
import random
import os
import shap
import lime
import lime.lime_tabular
import matplotlib.pyplot as plt

from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.ensemble import RandomForestRegressor
from sklearn.linear_model import LinearRegression
from sklearn.svm import SVR

from scipy.stats import pearsonr

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

from captum.attr import IntegratedGradients, Occlusion

#####################################################################
# 1) Data Generation
#####################################################################
def generate_linear_data(n_samples=1000, n_features=6, noise=1.0, random_state=42,
                         n_informative=2):
    """
    Generate linear regression data with sklearn's ``make_regression``.

    ``n_informative`` features (2 by default) receive non-zero true
    coefficients. ``make_regression`` shuffles the feature columns by default,
    so the informative features are NOT necessarily columns 0 and 1. Their
    indices are therefore read back from the returned ground-truth
    coefficients.

    :param n_samples: number of rows
    :param n_features: number of columns
    :param noise: std of the Gaussian noise added to y
    :param random_state: seed for the generator
    :param n_informative: number of features with non-zero coefficients
    :return: (X of shape (n_samples, n_features), y of shape (n_samples,),
              causal_features: sorted list of informative column indices)
    """
    from sklearn.datasets import make_regression

    rng = np.random.RandomState(random_state)
    # coef=True only returns the ground truth; X and y are generated exactly
    # as without it.
    X, y, coef = make_regression(
        n_samples=n_samples,
        n_features=n_features,
        n_informative=n_informative,
        noise=noise,
        random_state=rng,
        coef=True,
    )
    causal_features = np.flatnonzero(coef).tolist()
    return X, y, causal_features

def generate_nonlinear_data(n_samples=1000, random_state=42):
    """
    Generate a 6-feature non-linear regression problem.

    Target: y = X0 * sin(X0) * log(1 + |X1|) + 0.2 * N(0, 1).
    All features are drawn i.i.d. from N(0, 1); only columns 0 and 1 enter
    the target.

    :return: (X of shape (n_samples, 6), y of shape (n_samples,), [0, 1])
    """
    rng = np.random.RandomState(random_state)
    # Draw order matters for reproducibility: features first, then noise.
    X = rng.normal(0, 1, size=(n_samples, 6))
    eps = 0.2 * rng.normal(0, 1, n_samples)

    x0 = X[:, 0]
    x1 = X[:, 1]
    y = x0 * np.sin(x0) * np.log(1.0 + np.abs(x1)) + eps
    return X, y, [0, 1]

#####################################################################
# 2) PyTorch MLP for regression
#####################################################################
class MLPRegressorTorch(nn.Module):
    """
    Small fully-connected regressor: n_features -> 32 -> 16 -> 1.

    ReLU follows each of the two hidden layers; the output layer is linear.
    For a 2-D input of shape (batch, n_features) the output is (batch, 1).
    """
    def __init__(self, n_features=6):
        super().__init__()
        self.fc1 = nn.Linear(n_features, 32)
        self.fc2 = nn.Linear(32, 16)
        self.fc3 = nn.Linear(16, 1)

    def forward(self, x):
        hidden = x
        for layer in (self.fc1, self.fc2):
            hidden = F.relu(layer(hidden))
        return self.fc3(hidden)

def train_pytorch_mlp(X_train, y_train, n_features=6, epochs=50, lr=1e-3,
                      batch_size=64, verbose=False):
    """
    Train an MLPRegressorTorch with Adam and MSE loss on shuffled mini-batches.

    :param X_train: feature matrix, array-like of shape (n_samples, n_features)
    :param y_train: targets, array-like of length n_samples
    :param n_features: input width passed to MLPRegressorTorch
    :param epochs: number of full passes over the data
    :param lr: Adam learning rate
    :param batch_size: mini-batch size (default 64, the previous fixed value)
    :param verbose: if True, print the mean batch loss every 10 epochs
    :return: the trained model, on CUDA if available, otherwise on CPU
    :raises ValueError: if batch_size < 1
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = MLPRegressorTorch(n_features=n_features).to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criterion = nn.MSELoss()

    # np.asarray admits lists / pandas objects; float32 matches the weights.
    X_t = torch.as_tensor(np.asarray(X_train), dtype=torch.float32).to(device)
    y_t = torch.as_tensor(np.asarray(y_train), dtype=torch.float32).to(device).view(-1, 1)
    dataset_size = X_t.shape[0]

    for epoch in range(epochs):
        model.train()
        # Reshuffle every epoch (permutation drawn on CPU, as before).
        perm = torch.randperm(dataset_size)
        X_t, y_t = X_t[perm], y_t[perm]

        running_loss, n_steps = 0.0, 0
        for start in range(0, dataset_size, batch_size):
            x_batch = X_t[start:start + batch_size]
            y_batch = y_t[start:start + batch_size]

            optimizer.zero_grad()
            loss = criterion(model(x_batch), y_batch)
            loss.backward()
            optimizer.step()

            if verbose:
                # .item() forces a device sync; only pay for it when logging.
                running_loss += loss.item()
                n_steps += 1

        if verbose and epoch % 10 == 0 and n_steps:
            print(f"[Epoch {epoch}] loss: {running_loss / n_steps:.4f}")

    return model

def evaluate_pytorch_mlp(model, X_test, y_test):
    """
    Score a PyTorch regressor with R^2 on (X_test, y_test).

    The model is switched to eval mode. Inputs are moved to CUDA if available,
    otherwise CPU; the model is assumed to already live on that device.
    A 1e-9 term guards the denominator against constant targets.

    :return: (r2, predictions as a flat numpy array)
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.eval()
    inputs = torch.from_numpy(X_test).float().to(device)
    with torch.no_grad():
        output = model(inputs)
    predictions = output.cpu().numpy().flatten()

    total_ss = np.sum((y_test - np.mean(y_test)) ** 2)
    residual_ss = np.sum((y_test - predictions) ** 2)
    return 1 - residual_ss / (total_ss + 1e-9), predictions

#####################################################################
# 3) Train scikit model
#####################################################################
def train_sklearn_model(X_train, y_train, model_name="rf"):
    """
    Fit and return a scikit-learn regressor selected by short name.

    Supported names:
      "rf"      -> RandomForestRegressor(n_estimators=50, random_state=42)
      "linear"  -> LinearRegression()
      "svm"     -> SVR(kernel="linear")
      "svm_rbf" -> SVR(kernel="rbf")

    :raises ValueError: for any other model_name
    """
    factories = {
        "rf": lambda: RandomForestRegressor(n_estimators=50, random_state=42),
        "linear": lambda: LinearRegression(),
        "svm": lambda: SVR(kernel="linear"),
        "svm_rbf": lambda: SVR(kernel="rbf"),
    }
    # Only strings can match a key; this also avoids hashing odd inputs.
    factory = factories.get(model_name) if isinstance(model_name, str) else None
    if factory is None:
        raise ValueError(f"Unknown model_name={model_name}")

    model = factory()
    model.fit(X_train, y_train)
    return model

def evaluate_sklearn_model(model, X_test, y_test):
    """
    Score any model exposing ``predict`` with R^2 on (X_test, y_test).

    A 1e-9 term guards the denominator against constant targets.

    :return: (r2, predictions returned by model.predict)
    """
    predictions = model.predict(X_test)
    total_ss = np.sum((y_test - np.mean(y_test)) ** 2)
    residual_ss = np.sum((y_test - predictions) ** 2)
    return 1 - residual_ss / (total_ss + 1e-9), predictions

#####################################################################
# 4) Explanation methods
#####################################################################
def get_shap_attributions_sklearn(model, X_train, X_test, max_test_samples=50):
    """
    SHAP attributions for a scikit-learn regressor.

    Explains the first ``max_test_samples`` rows of X_test.
    RandomForestRegressor uses shap.TreeExplainer. Any other model uses
    shap.KernelExplainer over a 10-centroid shap.kmeans summary of X_train,
    with nsamples=50 per row.

    :return: (shap values, rows that were explained)
    """
    X_test_sub = X_test[:min(max_test_samples, len(X_test))]

    if isinstance(model, RandomForestRegressor):
        tree_explainer = shap.TreeExplainer(model)
        return tree_explainer.shap_values(X_test_sub), X_test_sub

    background = shap.kmeans(X_train, 10)
    kernel_explainer = shap.KernelExplainer(model.predict, background)
    return kernel_explainer.shap_values(X_test_sub, nsamples=50), X_test_sub

def get_lime_attributions_sklearn(model, X_train, X_test, max_test_samples=50):
    """
    LIME local-linear weights for a scikit-learn regressor.

    Explains the first ``max_test_samples`` rows of X_test with a
    LimeTabularExplainer fitted on X_train (continuous features, no
    discretization). Every feature is requested, so each row is dense.

    :return: (array of shape (n_used, n_features), rows that were explained)
    """
    n_rows = min(max_test_samples, len(X_test))
    X_test_sub = X_test[:n_rows]
    n_features = X_train.shape[1]

    explainer = lime.lime_tabular.LimeTabularExplainer(
        training_data=X_train,
        mode='regression',
        discretize_continuous=False
    )
    lime_array = np.zeros((n_rows, n_features))

    for row_idx, x_row in enumerate(X_test_sub):
        explanation = explainer.explain_instance(
            x_row,
            model.predict,
            num_features=n_features
        )
        # NOTE(review): this takes the *first* entry of as_map(). In LIME's
        # regression mode the map may contain more than one label (possibly a
        # sign-flipped copy); confirm which key carries the actual weights
        # before relying on attribution signs. The other LIME helpers in this
        # file use the same convention, so RBP comparisons stay consistent.
        first_label_weights = list(explanation.as_map().values())[0]
        for feat_idx, weight in first_label_weights:
            lime_array[row_idx, feat_idx] = weight

    return lime_array, X_test_sub

# --- PyTorch-based explanation for MLP only ---
def get_ig_attributions_pytorch(mlp_model, X_data, max_test_samples=50):
    """
    Integrated Gradients attributions for a PyTorch regressor.

    Explains the first ``max_test_samples`` rows of X_data one at a time,
    against an all-zeros baseline with 50 integration steps. The model is put
    in eval mode and moved to CUDA if available, otherwise CPU.

    :return: (float array of shape (n_used, d), rows that were explained)
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    mlp_model.eval().to(device)

    n_features = X_data.shape[1]
    X_sub = X_data[:min(max_test_samples, len(X_data))]

    # Captum wants one scalar per example, so drop the trailing output dim.
    ig = IntegratedGradients(lambda inp: mlp_model(inp).squeeze(-1))
    zero_ref = torch.zeros((1, n_features)).float().to(device)

    rows = []
    for row in X_sub:
        x_tensor = torch.from_numpy(row[None, :]).float().to(device)
        attr, _delta = ig.attribute(
            inputs=x_tensor,
            baselines=zero_ref,
            n_steps=50,
            return_convergence_delta=True
        )
        rows.append(attr.detach().cpu().numpy()[0])

    ig_array = np.asarray(rows, dtype=float).reshape(len(X_sub), n_features)
    return ig_array, X_sub

def get_occlusion_attributions_pytorch(mlp_model, X_data, max_test_samples=50):
    """
    Captum Occlusion attributions for a PyTorch regressor.

    Explains the first ``max_test_samples`` rows of X_data one at a time.
    Each occlusion uses a sliding window of one feature, stride 1, and
    baseline value 0.0. The model is put in eval mode and moved to CUDA if
    available, otherwise CPU.

    :return: (float array of shape (n_used, d), rows that were explained)
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    mlp_model.eval().to(device)

    n_features = X_data.shape[1]
    X_sub = X_data[:min(max_test_samples, len(X_data))]

    # Captum wants one scalar per example, so drop the trailing output dim.
    occluder = Occlusion(lambda inp: mlp_model(inp).squeeze(-1))

    rows = []
    for row in X_sub:
        x_tensor = torch.from_numpy(row[None, :]).float().to(device)
        attr = occluder.attribute(
            inputs=x_tensor,
            sliding_window_shapes=(1,),
            strides=(1,),
            baselines=0.0
        )
        rows.append(attr.detach().cpu().numpy()[0])

    occ_array = np.asarray(rows, dtype=float).reshape(len(X_sub), n_features)
    return occ_array, X_sub

#####################################################################
# 5) RBP: Reproduce-by-Poking
#####################################################################
def apply_rbp(
    model,
    X_data,
    baseline_attributions,
    method_str,
    X_train=None,
    is_pytorch=False,
    n_pert=3,
    perturb_scale=0.05,
    lambda_=1.0,
    attribution_fn=None,
):
    """
    Refine baseline attributions by checking their stability under small pokes.

    For each sample i and feature j:
      1) add N(0, perturb_scale) noise to feature j, n_pert times, and
         re-explain the perturbed sample;
      2) delta^j = mean |a_pert^j - a^j| over the re-explanations that
         succeeded (draws where the explainer returns None are excluded
         from the average; if none succeed, delta^j = 0);
      3) a'^j = a^j / (1 + lambda_ * delta^j).

    Noise is drawn from NumPy's global RNG (``np.random``).

    :param model: trained model
    :param X_data: array of shape (N, d)
    :param baseline_attributions: array of shape (N, d)
    :param method_str: "shap", "lime", "ig", or "occlusion"
    :param X_train: background/training data for SHAP/LIME re-runs
    :param is_pytorch: whether ``model`` is a PyTorch module
    :param n_pert: number of perturbations per feature
    :param perturb_scale: std of the perturbation noise
    :param lambda_: penalty factor
    :param attribution_fn: optional callable with the signature
        ``(model, x_1d, method_str, X_train, is_pytorch=...) -> 1-D array or None``.
        Defaults to ``_local_attribution_for_rbp``.
    :return: refined float attributions with the same shape as the baseline
    """
    if attribution_fn is None:
        attribution_fn = _local_attribution_for_rbp

    N, d = X_data.shape
    refined_attribs = np.array(baseline_attributions, dtype=float, copy=True)

    for i in range(N):
        # Float copy so that integer inputs are not truncated by the noise.
        x_orig = np.array(X_data[i], dtype=float)
        base_attr_vec = np.array(baseline_attributions[i], dtype=float)
        deltas = np.zeros(d)

        for j in range(d):
            deviations = []
            for _ in range(n_pert):
                x_pert = x_orig.copy()
                x_pert[j] += np.random.normal(0, perturb_scale)

                a_pert = attribution_fn(
                    model,
                    x_pert,
                    method_str,
                    X_train,
                    is_pytorch=is_pytorch
                )
                if a_pert is None:
                    # Method cannot be re-run for this model: skip this draw.
                    continue
                deviations.append(abs(a_pert[j] - base_attr_vec[j]))

            # Average only over draws that actually produced an attribution.
            deltas[j] = np.mean(deviations) if deviations else 0.0

        refined_attribs[i] = base_attr_vec / (1.0 + lambda_ * deltas)

    return refined_attribs

def _local_attribution_for_rbp(model, x_single, method_str, X_train, is_pytorch=False):
    """
    Recompute the explanation for a single sample x_single.
    Returns a 1D array of attributions of length d.
    This can be expensive in practice.

    For simplicity, we do a quick kernel-based approach for shap/lime,
    or single-sample IG, or single-sample Occlusion.
    If too slow, consider caching or approximations.
    """

    x_single_2d = x_single.reshape(1, -1)
    d = x_single.shape[0]

    if method_str == "shap":
        # Kernel SHAP for single sample
        # We'll do a mini approach:
        import shap
        def predict_fn(x_np):
            if not is_pytorch:
                return model.predict(x_np)
            else:
                device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
                x_t = torch.from_numpy(x_np).float().to(device)
                with torch.no_grad():
                    return model(x_t).cpu().numpy().flatten()

        # We need some background. Just use x_single or a small random subset of X_train
        if X_train is not None and len(X_train) > 5:
            X_bg = X_train[:5]
        else:
            X_bg = x_single_2d  # fallback

        explainer = shap.KernelExplainer(predict_fn, X_bg)
        shap_vals = explainer.shap_values(x_single_2d, nsamples=20)
        # shap_vals is (1, d)
        return shap_vals[0]

    elif method_str == "lime":
        import lime
        import lime.lime_tabular
        if X_train is None:
            return None
        explainer = lime.lime_tabular.LimeTabularExplainer(
            training_data=X_train,
            mode='regression',
            discretize_continuous=False
        )

        def predict_fn(x_np):
            if not is_pytorch:
                return model.predict(x_np)
            else:
                device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
                x_t = torch.from_numpy(x_np).float().to(device)
                with torch.no_grad():
                    return model(x_t).cpu().numpy().flatten()

        exp = explainer.explain_instance(
            x_single,
            predict_fn,
            num_features=d
        )
        exp_map = exp.as_map()
        local_map_items = list(exp_map.values())[0]
        row_attrib = np.zeros(d)
        for (feat_idx, weight) in local_map_items:
            row_attrib[feat_idx] = weight
        return row_attrib

    elif method_str == "ig":
        # Single-sample IG
        if not is_pytorch:
            return None  # can't do IG for scikit model here
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        x_tensor = torch.from_numpy(x_single_2d).float().to(device)

        def forward_fn(x):
            return model(x).squeeze(-1)

        ig = IntegratedGradients(forward_fn)
        baseline = torch.zeros_like(x_tensor)
        attributions, _ = ig.attribute(
            inputs=x_tensor,
            baselines=baseline,
            n_steps=20,
            return_convergence_delta=True
        )
        return attributions.detach().cpu().numpy().reshape(-1)

    elif method_str == "occlusion":
        if not is_pytorch:
            return None
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        x_tensor = torch.from_numpy(x_single_2d).float().to(device)

        def forward_fn(x):
            return model(x).squeeze(-1)

        oc = Occlusion(forward_fn)
        attributions = oc.attribute(
            inputs=x_tensor,
            sliding_window_shapes=(1,),
            strides=(1,),
            baselines=0.0
        )
        return attributions.detach().cpu().numpy().reshape(-1)

    else:
        # unknown method
        return None

#####################################################################
# 6) Inversion Scores
#####################################################################
def _predict_single(model, x_arr, is_pytorch=False):
    if not is_pytorch:
        return model.predict([x_arr])[0]
    else:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        arr_t = torch.from_numpy(x_arr).float().unsqueeze(0).to(device)
        with torch.no_grad():
            pred = model(arr_t).cpu().numpy().flatten()[0]
        return pred

def compute_inversion_scores(
    model,
    X_data,
    base_attributions,
    is_pytorch=False,
    n_pert=1,
    p=2,
    perturb_scale=0.1,
    ground_truth_vec=None,
    attribution_fn=None,
    faithfulness_scale=0.5,
):
    """
    Compute Reliance (R), Faithfulness (F) and the Inversion Score
    IS = ((R^p + (1-F)^p) / 2)^(1/p), plus an optional alignment A.

    R: Pearson correlation, clamped at 0, between the change in feature j's
       attribution and the change in model output when feature j is nudged by
       N(0, perturb_scale) noise. It is pooled over samples, features and
       n_pert draws. This needs re-explaining the perturbed input, so R is
       only computed when ``attribution_fn`` is given. Otherwise R = 0.0,
       which is what the former constant-zero placeholder always produced.
    F: 1 - exp(-mean |delta prediction|) after adding N(0, faithfulness_scale)
       noise to each sample's two largest-|attribution| features.
    A: mean cosine similarity between each attribution row and
       ``ground_truth_vec`` (0 for zero-norm rows); only if that vector is given.

    Noise is drawn from NumPy's global RNG (``np.random``).

    :param attribution_fn: optional callable ``x_1d -> 1-D attribution array``
        (or None to skip that draw) used to re-explain perturbed inputs for R
    :param faithfulness_scale: std of the noise used for F (default 0.5)
    :return: dict with keys "R", "F", "IS", and "A" when ground_truth_vec is given
    """
    X_data = np.asarray(X_data, dtype=float)
    N, d = X_data.shape

    baseline_preds = np.array([_predict_single(model, x, is_pytorch) for x in X_data])

    # 1) Reliance
    R_val = 0.0
    if attribution_fn is not None:
        all_dA, all_dM = [], []
        for i in range(N):
            a_orig = base_attributions[i]
            for j in range(d):
                for _ in range(n_pert):
                    x_pert = X_data[i].copy()
                    x_pert[j] += np.random.normal(0, perturb_scale)
                    a_pert = attribution_fn(x_pert)
                    if a_pert is None:
                        continue
                    pred_pert = _predict_single(model, x_pert, is_pytorch)
                    all_dA.append(a_pert[j] - a_orig[j])
                    all_dM.append(pred_pert - baseline_preds[i])
        # Pearson is undefined for fewer than 2 points or constant series.
        if len(all_dA) >= 2 and np.std(all_dA) >= 1e-15 and np.std(all_dM) >= 1e-15:
            R_val, _ = pearsonr(all_dA, all_dM)
    R_val_clamped = max(0.0, float(R_val))

    # 2) Faithfulness: perturb the top-2 attributed features.
    changes = []
    for i in range(N):
        top2 = np.abs(base_attributions[i]).argsort()[-2:]
        x_pert = X_data[i].copy()
        for f_idx in top2:
            x_pert[f_idx] += np.random.normal(0, faithfulness_scale)
        new_pred = _predict_single(model, x_pert, is_pytorch)
        changes.append(abs(new_pred - baseline_preds[i]))
    # Saturating map of the mean output change into [0, 1).
    F_val = 1.0 - np.exp(-np.mean(changes)) if changes else 0.0

    # 3) Inversion Score
    IS_val = ((R_val_clamped ** p + (1 - F_val) ** p) / 2.0) ** (1.0 / p)

    results = {
        "R": R_val_clamped,
        "F": F_val,
        "IS": IS_val
    }

    # 4) Optional alignment with a ground-truth attribution direction.
    if ground_truth_vec is not None:
        gt_norm = np.linalg.norm(ground_truth_vec)
        cos_sims = []
        for i in range(N):
            a_vec = base_attributions[i]
            denom = np.linalg.norm(a_vec) * gt_norm
            cos_sims.append(0.0 if denom < 1e-15 else np.dot(a_vec, ground_truth_vec) / denom)
        results["A"] = np.mean(cos_sims)

    return results

#####################################################################
# 7) Spurious Scenario
#####################################################################
def create_spurious_test(X_test, y_test, spurious_feature=2, coef=0.8, noise_std=1.0):
    """
    Return a copy of X_test whose column ``spurious_feature`` is replaced by
    ``coef * y_test + N(0, noise_std)`` noise, so it correlates with the target.

    Rows are matched to targets by position: ``y_test`` goes through
    np.asarray, so a pandas index is ignored. Integer inputs are promoted to
    float so the injected values are not truncated. Noise comes from NumPy's
    global RNG (``np.random``). The input array is not modified.

    :param X_test: array of shape (n, d)
    :param y_test: array-like of length n
    :param spurious_feature: column index to overwrite
    :param coef: scale applied to y_test (default 0.8)
    :param noise_std: std of the additive Gaussian noise (default 1.0)
    :return: floating-point array with the same shape as X_test
    """
    X_sp = np.array(X_test, copy=True)
    if not np.issubdtype(X_sp.dtype, np.floating):
        X_sp = X_sp.astype(float)
    y_arr = np.asarray(y_test, dtype=float).reshape(-1)
    X_sp[:, spurious_feature] = y_arr * coef + np.random.normal(0, noise_std, size=len(X_sp))
    return X_sp

#####################################################################
# 8) Visualization for Tabular Explanations
#####################################################################
def save_tabular_explanations(
    scenario_name,
    model_name,
    expl_method,
    X_used,
    attributions,
    out_dir="./saliency_visuals/tab/",
    num_samples_to_plot=5,
    tag=""
):
    """
    Save one PNG bar chart of |attribution| per feature for the first
    ``num_samples_to_plot`` samples (capped at len(X_used)).

    Files are written to ``out_dir`` (created if missing) as
    ``{scenario}_{model}_{method}{tag}_sample{i}.png`` at 120 dpi.
    """
    os.makedirs(out_dir, exist_ok=True)
    title_prefix = f"{scenario_name}-{model_name}-{expl_method}{tag}"
    file_stem = f"{scenario_name}_{model_name}_{expl_method}{tag}"

    for sample_idx in range(min(num_samples_to_plot, len(X_used))):
        # Magnitudes only: the sign is dropped for visualization.
        magnitudes = np.abs(attributions[sample_idx])
        plt.figure(figsize=(6, 4))
        plt.bar(range(len(magnitudes)), magnitudes, color='steelblue')
        plt.xlabel("Feature index")
        plt.ylabel("Absolute Attribution")
        plt.title(f"{title_prefix}\nSample {sample_idx}")
        plt.savefig(
            os.path.join(out_dir, f"{file_stem}_sample{sample_idx}.png"),
            dpi=120,
            bbox_inches='tight'
        )
        plt.close()

#####################################################################
# 9) Full Experiment + Logging
#####################################################################
def _shap_for_pytorch_mlp(pytorch_model, X_train, X_test, max_test_samples=50):
    """
    Kernel SHAP attributions for a PyTorch regressor.

    Background is a 10-centroid shap.kmeans summary of X_train, with
    nsamples=50 per explained row. Explains the first ``max_test_samples``
    rows of X_test. The model is put in eval mode and moved to CUDA if
    available, otherwise CPU.

    :return: (shap values, rows that were explained)
    """
    import shap
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    pytorch_model.eval().to(device)

    def predict_fn(x_np):
        batch = torch.from_numpy(x_np).float().to(device)
        with torch.no_grad():
            output = pytorch_model(batch)
        return output.cpu().numpy().flatten()

    X_test_sub = X_test[:min(max_test_samples, len(X_test))]
    background = shap.kmeans(X_train, 10)
    kernel_explainer = shap.KernelExplainer(predict_fn, background)
    return kernel_explainer.shap_values(X_test_sub, nsamples=50), X_test_sub

def _lime_for_pytorch_mlp(pytorch_model, X_train, X_test, max_test_samples=50):
    """
    LIME local-linear weights for a PyTorch regressor.

    Explains the first ``max_test_samples`` rows of X_test with a
    LimeTabularExplainer fitted on X_train (continuous features, no
    discretization). The model is put in eval mode and moved to CUDA if
    available, otherwise CPU.

    :return: (array of shape (n_used, n_features), rows that were explained)
    """
    import lime
    import lime.lime_tabular
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    pytorch_model.eval().to(device)

    def predict_fn(x_np):
        batch = torch.from_numpy(x_np).float().to(device)
        with torch.no_grad():
            output = pytorch_model(batch)
        return output.cpu().numpy().flatten()

    n_features = X_train.shape[1]
    X_test_sub = X_test[:min(max_test_samples, len(X_test))]

    explainer = lime.lime_tabular.LimeTabularExplainer(
        training_data=X_train,
        mode='regression',
        discretize_continuous=False
    )
    lime_array = np.zeros((len(X_test_sub), n_features))

    for row_idx, x_row in enumerate(X_test_sub):
        explanation = explainer.explain_instance(
            x_row,
            predict_fn,
            num_features=n_features
        )
        # NOTE(review): this takes the *first* entry of as_map(). In LIME's
        # regression mode the map may contain more than one label (possibly a
        # sign-flipped copy); confirm which key carries the actual weights.
        first_label_weights = list(explanation.as_map().values())[0]
        for feat_idx, weight in first_label_weights:
            lime_array[row_idx, feat_idx] = weight

    return lime_array, X_test_sub


def _tabular_base_attributions(model, expl_method, is_pytorch_model, X_train, X_eval):
    """Run one baseline explainer on X_eval; return (attributions, rows explained)."""
    if expl_method == "shap":
        if is_pytorch_model:
            return _shap_for_pytorch_mlp(model, X_train, X_eval)
        return get_shap_attributions_sklearn(model, X_train, X_eval)
    if expl_method == "lime":
        if is_pytorch_model:
            return _lime_for_pytorch_mlp(model, X_train, X_eval)
        return get_lime_attributions_sklearn(model, X_train, X_eval)
    if expl_method == "ig":
        return get_ig_attributions_pytorch(model, X_eval)
    # "occlusion"
    return get_occlusion_attributions_pytorch(model, X_eval)


def _append_log(log_filename, lines):
    """Append the given pre-terminated lines to the log file."""
    with open(log_filename, "a") as f:
        f.writelines(lines)


def run_tabular_experiment_4methods(log_filename="tabular_4methods_spurious_log.txt"):
    """
    Run the full tabular benchmark and write results to ``log_filename``.

    The file is truncated at the start of the run. The grid covers each data
    scenario ("linear", "nonlinear"), each model ("rf", "linear", "svm",
    "svm_rbf", "mlp_torch"), and each applicable explanation method. SHAP and
    LIME run for every model; IG and occlusion run only for the PyTorch MLP.
    For each combination:
      1. baseline attributions on the test set, bar plots, and R/F/IS scores;
      2. RBP-refined attributions, bar plots, and R/F/IS scores;
      3. the same baseline and RBP scoring on a spurious copy of the test set
         (feature 2 overwritten by create_spurious_test).
    """
    with open(log_filename, "w") as f:
        f.write("Tabular Experiment (RBP + Baseline): SHAP, LIME, IG, Occlusion\n")
        f.write("Using new Inversion Score, Nonlinear data, and Spurious scenario.\n")
        f.write("="*80 + "\n")

    data_scenarios = ["linear", "nonlinear"]
    model_names = ["rf", "linear", "svm", "svm_rbf", "mlp_torch"]
    explanation_methods = ["shap", "lime", "ig", "occlusion"]
    rbp_kwargs = dict(n_pert=3, perturb_scale=0.05, lambda_=1.0)
    iq_kwargs = dict(n_pert=1, p=2, perturb_scale=0.1)

    for data_scenario in data_scenarios:
        if data_scenario == "linear":
            X, y, _ = generate_linear_data(n_samples=300, n_features=6, noise=5.0)
        else:
            X, y, _ = generate_nonlinear_data(n_samples=300)

        X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
        scaler = StandardScaler()
        X_train = scaler.fit_transform(X_train)
        X_test = scaler.transform(X_test)

        # Build the spurious test set once per scenario so every model/method
        # pair is scored on identical inputs. It used to be redrawn inside the
        # method loop, which made the spurious rows incomparable.
        X_test_sp = create_spurious_test(X_test, y_test, spurious_feature=2)

        for model_name in model_names:
            is_pytorch_model = (model_name == "mlp_torch")

            if is_pytorch_model:
                model = train_pytorch_mlp(X_train, y_train, n_features=X_train.shape[1], epochs=50, lr=1e-3)
                train_r2, _ = evaluate_pytorch_mlp(model, X_train, y_train)
                test_r2, _ = evaluate_pytorch_mlp(model, X_test, y_test)
            else:
                model = train_sklearn_model(X_train, y_train, model_name)
                train_r2, _ = evaluate_sklearn_model(model, X_train, y_train)
                test_r2, _ = evaluate_sklearn_model(model, X_test, y_test)

            for expl_method in explanation_methods:
                # IG and occlusion need a torch module; skip them for sklearn models.
                if (not is_pytorch_model) and expl_method in ("ig", "occlusion"):
                    continue

                # 1) Baseline attributions, plots, and scores ("Time" covers all three).
                start_t = time.time()
                base_attribs, X_used = _tabular_base_attributions(
                    model, expl_method, is_pytorch_model, X_train, X_test
                )
                save_tabular_explanations(
                    data_scenario, model_name, expl_method, X_used, base_attribs,
                    tag="_baseline"
                )
                iq_res_base = compute_inversion_scores(
                    model, X_used, base_attribs, is_pytorch=is_pytorch_model, **iq_kwargs
                )
                elapsed = time.time() - start_t

                # 2) RBP refinement, plots, and scores.
                start_rbp = time.time()
                rbp_attribs = apply_rbp(
                    model, X_used, base_attribs,
                    method_str=expl_method, X_train=X_train,
                    is_pytorch=is_pytorch_model, **rbp_kwargs
                )
                rbp_elapsed = time.time() - start_rbp
                save_tabular_explanations(
                    data_scenario, model_name, expl_method, X_used, rbp_attribs,
                    tag="_RBP"
                )
                iq_res_rbp = compute_inversion_scores(
                    model, X_used, rbp_attribs, is_pytorch=is_pytorch_model, **iq_kwargs
                )

                _append_log(log_filename, [
                    f"[{data_scenario}] Model={model_name}, Expl={expl_method}\n",
                    f"TrainR2={train_r2:.3f}, TestR2={test_r2:.3f}, #Samples={len(X_used)}\n",
                    f"Baseline => R={iq_res_base['R']:.3f}, F={iq_res_base['F']:.3f}, IS={iq_res_base['IS']:.3f}, Time={elapsed:.2f}s\n",
                    f"RBP      => R={iq_res_rbp['R']:.3f}, F={iq_res_rbp['F']:.3f}, IS={iq_res_rbp['IS']:.3f}, RBPTime={rbp_elapsed:.2f}s\n",
                ])

                # 3) Spurious scenario: same pipeline on the spurious test set.
                spur_base_attribs, X_used_sp = _tabular_base_attributions(
                    model, expl_method, is_pytorch_model, X_train, X_test_sp
                )
                sp_rbp_attribs = apply_rbp(
                    model, X_used_sp, spur_base_attribs,
                    method_str=expl_method, X_train=X_train,
                    is_pytorch=is_pytorch_model, **rbp_kwargs
                )
                iq_sp_base = compute_inversion_scores(model, X_used_sp, spur_base_attribs, is_pytorch=is_pytorch_model)
                iq_sp_rbp = compute_inversion_scores(model, X_used_sp, sp_rbp_attribs, is_pytorch=is_pytorch_model)

                save_tabular_explanations(
                    data_scenario+"-SPURIOUS", model_name, expl_method,
                    X_used_sp, spur_base_attribs, tag="_BASE"
                )
                save_tabular_explanations(
                    data_scenario+"-SPURIOUS", model_name, expl_method,
                    X_used_sp, sp_rbp_attribs, tag="_RBP"
                )

                _append_log(log_filename, [
                    f"--- SPURIOUS Baseline => R={iq_sp_base['R']:.3f}, F={iq_sp_base['F']:.3f}, IS={iq_sp_base['IS']:.3f}\n",
                    f"--- SPURIOUS RBP      => R={iq_sp_rbp['R']:.3f}, F={iq_sp_rbp['F']:.3f}, IS={iq_sp_rbp['IS']:.3f}\n",
                    "="*80 + "\n",
                ])

#####################################################################
# Entry
#####################################################################
if __name__ == "__main__":
    # Script entry point: run the whole experiment grid, logging to the default file.
    run_tabular_experiment_4methods(log_filename="tabular_4methods_spurious_log.txt")
