import matplotlib.pyplot as plt
from matplotlib.colors import LogNorm
import numpy as np
import torch
from sklearn.linear_model import RANSACRegressor, LinearRegression
import statsmodels.api as sm
from matplotlib.colors import LinearSegmentedColormap

import utils.const as C
import utils.helpers as UH

# Ten-color palette running from dark navy through teal and sage to warm
# orange. It backs the module colormap below, and its first entry is used as
# the accent color for fit lines and confidence bands.
SCIENTIFIC_PALETTE = [
    "#0A2F51",
    "#0E4C6D",
    "#137177",
    "#489B8C",
    "#7EB09B",
    "#B4C7B8",
    "#D4BBA8",
    "#E8A87C",
    "#F38D68",
    "#E76F51",
]

# Eight-step cool palette, listed from the palest grey-blue down to the
# deepest navy.
PALETTE_COOL = [
    "#E0E7EC",
    "#C7D3DD",
    "#A8BCCE",
    "#829BB5",
    "#5C7A99",
    "#3D5A80",
    "#2E4057",
    "#1A1B41",
]

# Continuous colormap interpolated across SCIENTIFIC_PALETTE; shared by the
# scatter plots in this module.
cmap = LinearSegmentedColormap.from_list("path", SCIENTIFIC_PALETTE)

def _median_grid(data):
    """Arrange a ``{(x, y): value}`` mapping into a dense 2-D array.

    Coordinates are cast with ``int()`` (non-integral keys are truncated and
    may therefore collide). Rows follow the sorted distinct ``x`` values and
    columns the sorted distinct ``y`` values; cells with no entry stay 0.

    Returns:
        ``(grid, xs, ys)`` where ``grid[r, k]`` holds the value stored under
        ``(xs[r], ys[k])``.
    """
    xs = sorted({int(x) for x, _ in data})
    ys = sorted({int(y) for _, y in data})
    # Dict lookups instead of list.index(): O(1) per cell, and keyed by the
    # same int() cast that built the axes, so float keys like 2.0 resolve.
    row_of = {x: r for r, x in enumerate(xs)}
    col_of = {y: k for k, y in enumerate(ys)}

    grid = np.zeros((len(xs), len(ys)))
    for (x, y), v in data.items():
        grid[row_of[int(x)], col_of[int(y)]] = v
    return grid, xs, ys


def plot_median_heatmap(Ks, name):
    """Save a heatmap of the median K value for every ``(x, y)`` key.

    ``UH.get_median_Ks(Ks)`` is expected to return a ``{(x, y): value}``
    mapping (schema not visible here -- TODO confirm against utils.helpers).
    Image rows are the distinct ``x`` values and columns the distinct ``y``
    values; missing cells are drawn as 0.

    The image is written to ``plots/<encoder>/<name>_heatmaps.png`` where
    ``<encoder>`` is the second ``/``-separated part of ``C.ENCODER_NAME``.
    """
    data = UH.get_median_Ks(Ks)
    heatmap, xs, ys = _median_grid(data)

    # Start from a fresh figure so nothing left on the current one leaks in.
    plt.figure()
    plt.imshow(heatmap, origin="lower", cmap="viridis")
    plt.colorbar(label="Value")

    plt.xticks(range(len(ys)), ys)
    plt.yticks(range(len(xs)), xs)

    plt.xlabel("Y")
    plt.ylabel("X")
    plt.title("2D Heatmap")
    encoder_name = C.ENCODER_NAME.split("/")[1]
    plt.savefig(f"plots/{encoder_name}/{name}_heatmaps.png")
    plt.close()

def plot_median_logplot(Ks, name):
    """Save a log-log scatter of median K values keyed by ``(i, j)``.

    Every entry ``(i, j) -> value`` returned by ``UH.get_median_Ks(Ks)``
    becomes one marker at ``(i, value)``; ``j`` sets its color on a
    logarithmic scale. The figure is written to
    ``plots/<encoder>/<name>_logplot.png``, with ``<encoder>`` taken as the
    second ``/``-separated part of ``C.ENCODER_NAME``.
    """
    medians = UH.get_median_Ks(Ks)
    triples = [(i, value, j) for (i, j), value in medians.items()]
    first_index = [t[0] for t in triples]   # x-axis
    median_value = [t[1] for t in triples]  # y-axis
    second_index = [t[2] for t in triples]  # marker color

    plt.figure(figsize=(6, 4))
    # Log color scale spanning the observed second-index range.
    color_scale = LogNorm(vmin=min(second_index), vmax=max(second_index))
    markers = plt.scatter(
        first_index,
        median_value,
        c=second_index,
        cmap="viridis",
        norm=color_scale,
        s=120,
        edgecolors="black",
    )

    for set_scale in (plt.xscale, plt.yscale):
        set_scale("log")

    plt.xlabel("First index (x)")
    plt.ylabel("Value")
    plt.title("Log–Log Scatter Colored by Second Index")

    plt.colorbar(markers, label="Second index")

    plt.grid(True, which="both", ls="--", alpha=0.4)
    encoder_dir = C.ENCODER_NAME.split("/")[1]
    plt.savefig(f"plots/{encoder_dir}/{name}_logplot.png")
    plt.close()

def _unit_field(field):
    """Normalize each row vector of ``field`` to unit length.

    The norm is taken over dim 1 and clamped to at least ``1e-8`` so zero
    vectors do not divide by zero (they come back as zero vectors).

    Returns:
        ``(directions, magnitudes)`` where ``magnitudes`` keeps dim 1
        (``keepdim=True``) and holds the clamped norms.
    """
    magnitudes = torch.norm(field, dim=1, keepdim=True).clamp(min=1e-8)
    return field / magnitudes, magnitudes


def visualize_score_field(points, doc_embeddings, score_field, name):
    """Save a quiver plot of a vector field with document positions overlaid.

    Only the first two columns of ``points``, ``doc_embeddings`` and
    ``score_field`` are drawn (assumes they are already 2-D projections --
    TODO confirm). Arrows show the field direction at unit length and are
    colored by the field magnitude; large inputs are thinned with a stride
    of ``len(points) // 5000``.

    The ``seaborn-v0_8-darkgrid`` style is applied to this figure only, and
    the result is written to ``plots/<name>.png``.

    Args:
        points: Tensor of arrow anchor positions.
        doc_embeddings: Tensor of document positions, drawn as stars.
        score_field: Tensor of field vectors, presumably one per row of
            ``points`` (TODO confirm alignment with the caller).
        name: File stem for the saved image.
    """
    # detach() so tensors that require grad can still be converted to numpy.
    docs = doc_embeddings.detach().cpu().float().numpy()
    anchors = points.detach().cpu().float().numpy()
    directions, magnitudes = _unit_field(score_field.detach().cpu().float())
    directions = directions.numpy()
    magnitudes = magnitudes.numpy()

    step = max(1, len(anchors) // 5000)

    # style.context rather than style.use: style.use rewrites the global
    # rcParams and would restyle every later plot in the process.
    with plt.style.context("seaborn-v0_8-darkgrid"):
        plt.figure(figsize=(7, 7))
        plt.scatter(
            docs[:, 0], docs[:, 1],
            color="#e74c3c",
            s=160,
            label="Document centroids",
            zorder=3,
            marker="*",
        )

        qv = plt.quiver(
            anchors[::step, 0],
            anchors[::step, 1],
            directions[::step, 0],
            directions[::step, 1],
            magnitudes[::step, 0],
            cmap="viridis",
            scale=30,
            alpha=0.7,
        )

        cbar = plt.colorbar(qv, shrink=0.9, pad=0.02)
        cbar.set_label("Force magnitude", fontsize=11, fontweight="bold")
        cbar.ax.tick_params(labelsize=10)

        plt.legend()
        plt.axis("equal")
        plt.tight_layout()
        plt.savefig(f"plots/{name}.png")
        plt.close()


def plot_individ_median_logplot_for_errors(Ks, name):
    """Fit a saturating stretched exponential to median errors and save a plot.

    Model: ``y = y_inf * (1 - exp(-c * x**B))``. ``y_inf`` is fixed to the
    largest observed positive ``y``; ``c`` and ``B`` come from the
    linearization ``log(-log(1 - y/y_inf)) = log(c) + B*log(x)``, fitted by
    OLS on the RANSAC inliers. Parameters, their 95% CIs and the linearized
    R² are printed; the plot is written to
    ``plots/individuals/<name>_saturating_ci_logplot.png``.

    Points with ``x > 0`` and ``y > 0`` are fitted; points with ``y == 0``
    are drawn separately as the "zero regime".

    Args:
        Ks: Input forwarded to ``UH.get_median_Ks``, expected to return
            ``{(i, j): value}`` (``i`` -> x-axis, ``j`` -> marker color).
        name: File stem for the saved figure.

    Raises:
        ValueError: If fewer than three points are strictly positive.
    """
    data = UH.get_median_Ks(Ks)

    # --- 1. PREPARE DATA ---
    x, y, c = [], [], []
    for (i, j), v in data.items():
        x.append(i)
        y.append(v)
        c.append(j)

    x_arr = np.array(x, dtype=float)
    y_arr = np.array(y, dtype=float)
    c_arr = np.array(c, dtype=float)

    # --- 2. DEFINE MASKS ---
    positive_mask = (x_arr > 0) & (y_arr > 0)
    zero_mask = (y_arr == 0)

    if positive_mask.sum() < 3:
        raise ValueError("Not enough positive points for fitting")

    x_fit = x_arr[positive_mask]
    y_fit = y_arr[positive_mask]
    c_fit = c_arr[positive_mask]

    # --- 3. SATURATION LEVEL ---
    y_inf = np.max(y_fit)

    # Keep y/y_inf strictly inside (0, 1) so the double log below is finite.
    # NOTE: the maximum point itself is pinned to 1 - eps, i.e. an artificial
    # z = log(-log(eps)); RANSAC may flag it as an outlier.
    eps = 1e-6
    y_norm = np.clip(y_fit / y_inf, eps, 1 - eps)

    # --- 4. LINEARIZED SPACE ---
    log_x = np.log(x_fit).reshape(-1, 1)
    z = np.log(-np.log(1 - y_norm))

    # --- 5. RANSAC (OUTLIER DETECTION) ---
    ransac = RANSACRegressor(
        estimator=LinearRegression(),
        min_samples=0.5,
        random_state=42
    )
    ransac.fit(log_x, z)

    inlier_mask = ransac.inlier_mask_
    outlier_mask = ~inlier_mask

    # --- 6. OLS ON INLIERS ---
    X_in = sm.add_constant(log_x[inlier_mask])
    z_in = z[inlier_mask]

    results = sm.OLS(z_in, X_in).fit()

    intercept, slope = results.params
    r_squared = results.rsquared

    # z = log(c) + B*log(x)
    c_param = np.exp(intercept)
    B = slope

    ci = np.asarray(results.conf_int(alpha=0.05))
    c_ci = np.exp(ci[0])
    B_ci = ci[1]

    print(
        "Saturating stretched-exponential fit\n"
        f"y = y_inf (1 - exp(-c x^B))\n"
        f"y_inf = {y_inf:.4f}\n"
        f"c = {c_param:.4e}  (95% CI [{c_ci[0]:.4e}, {c_ci[1]:.4e}])\n"
        f"B = {B:.4f}  (95% CI [{B_ci[0]:.4f}, {B_ci[1]:.4f}])\n"
        f"R² (linearized, inliers) = {r_squared:.4f}"
    )

    # --- 7. CONFIDENCE BAND ---
    x_plot = np.logspace(0, np.log10(x_arr.max()), 400)
    X_plot = sm.add_constant(np.log(x_plot))

    pred = results.get_prediction(X_plot).summary_frame(alpha=0.05)

    z_mean = pred["mean"].values
    z_lo = pred["mean_ci_lower"].values
    z_hi = pred["mean_ci_upper"].values

    # y = y_inf * (1 - exp(-exp(z))) is increasing in z, so the lower z bound
    # maps to the lower y bound and the upper to the upper.
    y_mean = y_inf * (1 - np.exp(-np.exp(z_mean)))
    y_lo = y_inf * (1 - np.exp(-np.exp(z_lo)))
    y_hi = y_inf * (1 - np.exp(-np.exp(z_hi)))

    # --- 8. PLOT ---
    plt.figure(figsize=(7, 5))
    # One log color scale shared by every marker group and the colorbar.
    color_norm = LogNorm(vmin=c_arr.min(), vmax=c_arr.max())

    # Inliers
    sc = plt.scatter(
        x_fit[inlier_mask],
        y_fit[inlier_mask],
        c=c_fit[inlier_mask],
        cmap=cmap,
        norm=color_norm,
        s=120,
        edgecolors="black",
        alpha=0.9,
        label="Inliers"
    )

    # NOTE(review): matplotlib gives `c` precedence over `facecolors`, so the
    # outlier / zero-regime markers below are probably drawn filled rather
    # than hollow -- confirm which look is intended.

    # Outliers
    if outlier_mask.any():
        plt.scatter(
            x_fit[outlier_mask],
            y_fit[outlier_mask],
            c=c_fit[outlier_mask],
            cmap=cmap,
            norm=color_norm,
            s=120,
            facecolors="none",
            edgecolors="red",
            linewidth=2,
            label="Outliers"
        )

    # Zero point(s)
    if zero_mask.any():
        plt.scatter(
            x_arr[zero_mask],
            y_arr[zero_mask],
            c=c_arr[zero_mask],
            cmap=cmap,
            norm=color_norm,
            s=140,
            facecolors="none",
            edgecolors="black",
            linewidth=2.5,
            label="Zero regime"
        )

    # Fit line ("k" in the format string already makes it black).
    plt.plot(
        x_plot,
        y_mean,
        "k--",
        lw=2.5,
        label=rf"fit: $y={y_inf:.2f}(1-e^{{-{c_param:.2e}x^{{{B:.2f}}}}})$"
              f"\n$R^2={r_squared:.3f}$"
    )

    # CI band
    plt.fill_between(
        x_plot,
        y_lo,
        y_hi,
        color=SCIENTIFIC_PALETTE[0],
        alpha=0.15,
        label="95% CI"
    )

    plt.xscale("log")
    plt.xlabel("Documents per Query")
    plt.ylabel("Error")

    plt.colorbar(sc, label="Documents per Query")
    plt.legend()
    plt.grid(True, which="both", ls="--", alpha=0.4)
    plt.tight_layout()

    plt.savefig(f"plots/individuals/{name}_saturating_ci_logplot.png")
    plt.close()


def plot_individ_median_logplot_with_fit(Ks, name, use_fixed_scaling=True):
    """Robustly fit a power law ``y = a * x**b`` to median K values and plot it.

    The fit is done in log-log space: RANSAC flags outliers, then OLS on the
    inliers gives ``log(a)`` and ``b`` plus a 95% confidence band for the
    mean prediction. Parameters and R² (inliers only) are printed; the plot
    is saved to ``plots/<name>_robust_ci_logplot.png``.

    Args:
        Ks: Input forwarded to ``UH.get_median_Ks``, expected to return
            ``{(i, j): value}`` (``i`` -> x-axis, ``j`` -> marker color).
        name: File stem for the saved figure.
        use_fixed_scaling: If true, clamp the y-axis to ``[0.5, 2]``.

    Raises:
        ValueError: If no point has both ``x > 0`` and ``y > 0``.
    """
    # --- 1. PREPARE DATA ---
    data = UH.get_median_Ks(Ks)
    x, y, c = [], [], []
    for (i, j), v in data.items():
        x.append(i)
        y.append(v)
        c.append(j)

    x_arr = np.array(x, dtype=float)
    y_arr = np.array(y, dtype=float)
    c_arr = np.array(c, dtype=float)

    # log() needs strictly positive inputs: non-positive points would become
    # -inf/NaN (rejected by sklearn's input validation) and cannot be shown
    # on log axes anyway.
    positive = (x_arr > 0) & (y_arr > 0)
    if not positive.any():
        raise ValueError("No strictly positive points to fit")
    x_arr, y_arr, c_arr = x_arr[positive], y_arr[positive], c_arr[positive]

    log_x = np.log(x_arr).reshape(-1, 1)
    log_y = np.log(y_arr)

    # --- 2. DETECT OUTLIERS (RANSAC) ---
    # RANSAC is used only for its inlier mask; the reported fit is OLS below.
    ransac = RANSACRegressor(
        estimator=LinearRegression(),
        min_samples=0.5,
        random_state=42
    )
    ransac.fit(log_x, log_y)
    inlier_mask = ransac.inlier_mask_
    outlier_mask = ~inlier_mask

    # --- 3. STATISTICAL MODELING (OLS on Inliers) ---
    # statsmodels needs an explicit constant column for the intercept.
    X_inliers_const = sm.add_constant(log_x[inlier_mask])
    results = sm.OLS(log_y[inlier_mask], X_inliers_const).fit()

    intercept = results.params[0]
    slope = results.params[1]
    # log(y) = log(a) + b*log(x)  =>  y = a * x**b
    b = slope
    a = np.exp(intercept)
    r_squared = results.rsquared

    print(f"Robust Fit Parameters: a = {a:.4f}, b = {b:.4f}")
    print(f"R-squared (on inliers): {r_squared:.4f}")

    # --- 4. CALCULATE CONFIDENCE INTERVALS ---
    x_plot = np.geomspace(x_arr.min(), x_arr.max(), 100)
    X_plot_const = sm.add_constant(np.log(x_plot))
    pred_summary = results.get_prediction(X_plot_const).summary_frame(alpha=0.05)

    # Back-transform the mean prediction and its 95% CI from log space.
    y_fitted = np.exp(pred_summary['mean'])
    ci_lower = np.exp(pred_summary['mean_ci_lower'])
    ci_upper = np.exp(pred_summary['mean_ci_upper'])
    print(ci_lower.mean(), ci_upper.mean())

    # --- 5. PLOTTING ---
    plt.figure(figsize=(7, 5))
    # Shared by inliers and outliers so both match the colorbar; without an
    # explicit norm, matplotlib would auto-scale the outlier colors linearly
    # over the outliers' own range.
    color_norm = LogNorm(vmin=c_arr.min(), vmax=c_arr.max())

    sc = plt.scatter(
        x_arr[inlier_mask],
        y_arr[inlier_mask],
        c=c_arr[inlier_mask],
        cmap=cmap,
        norm=color_norm,
        s=120,
        edgecolors="black",
        alpha=0.9,
        label="Inliers"
    )

    if outlier_mask.any():
        plt.scatter(
            x_arr[outlier_mask],
            y_arr[outlier_mask],
            c=c_arr[outlier_mask],
            cmap=cmap,
            norm=color_norm,
            s=120,
            linewidth=2,
            label="Outliers (Excluded)",
            alpha=0.5,
            edgecolors="red"
        )

    plt.plot(x_plot, y_fitted, 'k--', linewidth=2,
             label=f'Fit: $y={a:.2f}x^{{{b:.2f}}}$\n$R^2={r_squared:.3f}$')

    plt.fill_between(
        x_plot,
        ci_lower,
        ci_upper,
        color=SCIENTIFIC_PALETTE[0],
        alpha=0.15,
        label='95% Confidence Interval'
    )

    plt.xscale("log")
    plt.yscale("log")
    plt.xlabel("Documents per Query")
    plt.ylabel("K-value")
    plt.colorbar(sc, label="Documents per Query")
    if use_fixed_scaling:
        plt.gca().set_ylim(0.5, 2)
    plt.legend(loc='best')
    plt.grid(True, which="both", ls="--", alpha=0.4)
    plt.tight_layout()

    plt.savefig(f"plots/{name}_robust_ci_logplot.png")
    plt.close()

def plot_individ_median_logplot_with_fit_linear(Ks, name, use_fixed_scaling=True):
    """Fit a power law ``y = a * x**b`` by plain OLS and plot it on linear axes.

    There is no outlier rejection: every point with ``x > 0`` and ``y > 0``
    enters an OLS fit of ``log(y) = log(a) + b*log(x)``. Because the axes
    stay linear, all points (including non-positive ones) are still drawn.
    Parameters and R² are printed; the plot is saved as
    ``plots/<name>_linear_ci_logplot.pdf``.

    Args:
        Ks: Input forwarded to ``UH.get_median_Ks``, expected to return
            ``{(i, j): value}`` (``i`` -> x-axis, ``j`` -> marker color).
        name: File stem for the saved figure; also printed before the fit.
        use_fixed_scaling: If true, clamp the y-axis to ``[0.6, 1.5]``.

    Raises:
        ValueError: If no point has both ``x > 0`` and ``y > 0``.
    """
    # --- 1. PREPARE DATA ---
    data = UH.get_median_Ks(Ks)
    x, y, c = [], [], []
    for (i, j), v in data.items():
        x.append(i)
        y.append(v)
        c.append(j)

    x_arr = np.array(x, dtype=float)
    y_arr = np.array(y, dtype=float)
    c_arr = np.array(c)

    # Only strictly positive points can be log-transformed for the fit.
    fit_mask = (x_arr > 0) & (y_arr > 0)
    if not fit_mask.any():
        raise ValueError("No strictly positive points to fit")
    x_fit = x_arr[fit_mask]
    log_x = np.log(x_fit)
    log_y = np.log(y_arr[fit_mask])

    # --- 2. STANDARD LINEAR REGRESSION (OLS) in log-log space ---
    results = sm.OLS(log_y, sm.add_constant(log_x)).fit()

    intercept = results.params[0]
    slope = results.params[1]
    b = slope
    a = np.exp(intercept)
    r_squared = results.rsquared
    print(name)
    print(f"Standard Fit Parameters: a = {a:.4f}, b = {b:.4f}")
    print(f"R-squared (all data): {r_squared:.4f}")

    # --- 3. CALCULATE CONFIDENCE INTERVALS ---
    x_plot = np.geomspace(x_fit.min(), x_fit.max(), 100)
    X_plot_const = sm.add_constant(np.log(x_plot))
    pred_summary = results.get_prediction(X_plot_const).summary_frame(alpha=0.05)

    # Back-transform the mean prediction and its 95% CI from log space.
    y_fitted = np.exp(pred_summary['mean'])
    ci_lower = np.exp(pred_summary['mean_ci_lower'])
    ci_upper = np.exp(pred_summary['mean_ci_upper'])

    # --- 4. PLOTTING ---
    plt.figure(figsize=(7, 5))

    sc = plt.scatter(
        x_arr,
        y_arr,
        c=c_arr,
        cmap=cmap,
        norm=LogNorm(vmin=c_arr.min(), vmax=c_arr.max()),
        s=120,
        edgecolors="black",
        alpha=0.9,
        label="Data"
    )

    plt.plot(x_plot, y_fitted, 'k--', linewidth=2,
             label=f'Fit: $y={a:.2f}x^{{{b:.2f}}}$\n$R^2={r_squared:.3f}$')

    plt.fill_between(
        x_plot,
        ci_lower,
        ci_upper,
        color=SCIENTIFIC_PALETTE[0],
        alpha=0.15,
        label='95% Confidence Interval'
    )

    # Axes are deliberately left linear in this variant.
    plt.xlabel("Documents per Query")
    plt.ylabel("K-delta")

    plt.colorbar(sc, label="Documents per Query")

    if use_fixed_scaling:
        plt.gca().set_ylim(0.6, 1.5)

    plt.legend(loc='best')
    plt.grid(True, which="both", ls="--", alpha=0.4)
    plt.tight_layout()

    plt.savefig(f"plots/{name}_linear_ci_logplot.pdf", format="pdf")
    plt.close()

def plot_individ_median_logplot_with_fit_exponential(Ks, name, use_fixed_scaling=True):
    """Fit ``y = a * exp(b * x)`` to median K values and plot it on log axes.

    The model is linearized as ``log(y) = log(a) + b*x`` (semi-log OLS, no
    outlier rejection). Only points with ``x > 0`` and ``y > 0`` are used,
    since ``log(y)``, the geometric plot grid and the log axes all require
    them. Parameters and R² are printed; the plot is saved as
    ``plots/<name>_exponential_ci_logplot.pdf``.

    Args:
        Ks: Input forwarded to ``UH.get_median_Ks``, expected to return
            ``{(i, j): value}`` (``i`` -> x-axis, ``j`` -> marker color).
        name: File stem for the saved figure; also printed before the fit.
        use_fixed_scaling: If true, clamp the y-axis to ``[0.6, 1.5]``.

    Raises:
        ValueError: If no point has both ``x > 0`` and ``y > 0``.
    """
    # --- 1. PREPARE DATA ---
    data = UH.get_median_Ks(Ks)
    x, y, c = [], [], []
    for (i, j), v in data.items():
        x.append(i)
        y.append(v)
        c.append(j)

    x_arr = np.array(x, dtype=float)
    y_arr = np.array(y, dtype=float)
    c_arr = np.array(c)

    positive = (x_arr > 0) & (y_arr > 0)
    if not positive.any():
        raise ValueError("No strictly positive points to fit")
    x_arr, y_arr, c_arr = x_arr[positive], y_arr[positive], c_arr[positive]

    # --- 2. EXPONENTIAL REGRESSION (OLS on Semi-Log) ---
    # x stays linear here; only y is log-transformed.
    X_const = sm.add_constant(x_arr)
    log_y = np.log(y_arr)
    results = sm.OLS(log_y, X_const).fit()

    intercept = results.params[0]
    slope = results.params[1]

    # log(y) = log(a) + b*x  =>  y = a * exp(b*x)
    b = slope
    a = np.exp(intercept)
    r_squared = results.rsquared

    print(name)
    print(f"Exponential Fit Parameters: a = {a:.4f}, b = {b:.4f}")
    print(f"R-squared: {r_squared:.4f}")

    # --- 3. CALCULATE CONFIDENCE INTERVALS ---
    # Geometric spacing gives evenly spaced samples on the log x-axis; the
    # model itself is evaluated on linear x.
    x_plot = np.geomspace(x_arr.min(), x_arr.max(), 100)
    X_plot_const = sm.add_constant(x_plot)
    pred_summary = results.get_prediction(X_plot_const).summary_frame(alpha=0.05)

    # Back-transform the mean prediction and its 95% CI from log(y) space.
    y_fitted = np.exp(pred_summary['mean'])
    ci_lower = np.exp(pred_summary['mean_ci_lower'])
    ci_upper = np.exp(pred_summary['mean_ci_upper'])

    # --- 4. PLOTTING ---
    plt.figure(figsize=(7, 5))

    sc = plt.scatter(
        x_arr,
        y_arr,
        c=c_arr,
        cmap=cmap,
        norm=LogNorm(vmin=c_arr.min(), vmax=c_arr.max()),
        s=120,
        edgecolors="black",
        alpha=0.9,
        label="Data"
    )

    # Dashed fit line in the palette's accent color. The format string
    # carries only the line style so it does not conflict with `color=`.
    plt.plot(x_plot, y_fitted, '--', linewidth=2, color=SCIENTIFIC_PALETTE[0],
             label=f'Fit: $y={a:.2f}e^{{{b:.4f}x}}$\n$R^2={r_squared:.3f}$')

    plt.fill_between(
        x_plot,
        ci_lower,
        ci_upper,
        color=SCIENTIFIC_PALETTE[0],
        alpha=0.2,
        label='95% Confidence Interval'
    )

    # Logarithmic display axes (the fit is semi-log regardless).
    plt.xscale("log")
    plt.yscale("log")

    plt.xlabel("Documents per Query")
    plt.ylabel("K-delta")
    plt.colorbar(sc, label="Documents per Query")

    if use_fixed_scaling:
        # Fixed y-window, independent of the data.
        plt.gca().set_ylim(0.6, 1.5)

    plt.legend(loc='best')
    plt.grid(True, which="both", ls="--", alpha=0.4)
    plt.tight_layout()

    plt.savefig(f"plots/{name}_exponential_ci_logplot.pdf", format="pdf")
    plt.close()
def bootstrap_median_ci(values, n_boot=1000, alpha=0.05, rng=None):
    """Return the median of *values* with a percentile-bootstrap CI.

    Args:
        values: 1-D sequence of numbers to summarize.
        n_boot: Number of bootstrap resamples, each drawn with replacement
            and of size ``len(values)``.
        alpha: Two-sided significance level; the interval spans the
            ``alpha/2`` and ``1 - alpha/2`` percentiles of the bootstrap
            medians.
        rng: Optional seed or ``numpy.random.Generator`` for reproducible
            resampling. ``None`` (default) draws from NumPy's global random
            state, exactly as before this parameter existed.

    Returns:
        ``(median, ci_low, ci_high)``; all three are ``nan`` for empty input.
    """
    if len(values) == 0:
        return np.nan, np.nan, np.nan
    # np.random (module) and a Generator both expose .choice(a, size, replace).
    sampler = np.random if rng is None else np.random.default_rng(rng)
    n = len(values)
    boot_medians = [
        np.median(sampler.choice(values, n, replace=True)) for _ in range(n_boot)
    ]
    median = np.median(values)
    ci_low = np.percentile(boot_medians, 100 * alpha / 2)
    ci_high = np.percentile(boot_medians, 100 * (1 - alpha / 2))
    return median, ci_low, ci_high

def plot_raw_logplot_with_fit(raw_bin_data, name, use_fixed_scaling=True):
    """Plot per-bin ΔK distributions and a robust power-law fit to their medians.

    ``raw_bin_data`` maps ``(bin_key, bin_key)`` pairs to lists of raw K
    values; only diagonal, non-empty entries are used. Each value is divided
    by the median of the smallest bin (ΔK = K / K_0). For every bin the
    median ΔK and a bootstrap 95% CI are computed. A power law
    ``ΔK = a * bin**b`` is then fitted in log-log space: RANSAC flags
    outlier medians, and WLS on the inliers (weighted by bin size) gives the
    reported parameters and confidence band.

    The figure (box plots, median error bars, fit line and CI band) is saved
    to ``plots/<name>_robust_ci_logplot_with_distributions.png``.

    Args:
        raw_bin_data: ``{(bin_key, bin_key): list_of_K_values}``.
        name: File stem for the saved figure.
        use_fixed_scaling: If true, clamp the y-axis to ``[0.93, 1.07]``.

    Returns:
        None. Prints a message and returns early when there is no data or
        fewer than two usable bins.
    """
    # Diagonal, non-empty bins only, in ascending key order.
    bin_keys = sorted({k[0] for k in raw_bin_data if k[0] == k[1] and len(raw_bin_data[k]) > 0})

    if not bin_keys:
        print("No data available for plotting.")
        return

    # The first (smallest) bin's median is the normalization reference.
    first_bin = bin_keys[0]
    first_values = raw_bin_data[(first_bin, first_bin)]
    median_0, _, _ = bootstrap_median_ci(first_values)
    if median_0 == 0:
        print("Warning: Median of first bin is zero; using additive deltas instead.")
        median_0 = 1.0  # v / 1.0 == v - 0, i.e. additive deltas against K_0 = 0

    # Per-bin ΔK distributions, their medians, bootstrap CIs and sizes.
    delta_bin_data = []
    medians = []
    ci_lows = []
    ci_highs = []
    bin_sizes = []
    for bk in bin_keys:
        values = raw_bin_data[(bk, bk)]
        delta_values = [v / median_0 for v in values]  # Relative delta: ∆K = K / K_0
        delta_bin_data.append(delta_values)
        med, low, high = bootstrap_median_ci(delta_values)
        medians.append(med)
        ci_lows.append(low)
        ci_highs.append(high)
        bin_sizes.append(len(delta_values))

    x_arr = np.array(bin_keys, dtype=float)
    y_arr = np.array(medians, dtype=float)
    ci_low_arr = np.array(ci_lows)
    ci_high_arr = np.array(ci_highs)
    bin_sizes_arr = np.array(bin_sizes)

    # Both axes are log-transformed for the fit and drawn on log scales, so
    # keep only bins with a finite, strictly positive median and key; NaN or
    # -inf logs would otherwise make RANSAC reject the input.
    valid_mask = np.isfinite(y_arr) & (y_arr > 0) & (x_arr > 0)
    x_arr = x_arr[valid_mask]
    y_arr = y_arr[valid_mask]
    ci_low_arr = ci_low_arr[valid_mask]
    ci_high_arr = ci_high_arr[valid_mask]
    bin_sizes_arr = bin_sizes_arr[valid_mask]
    delta_bin_data = [d for d, keep in zip(delta_bin_data, valid_mask) if keep]

    if len(x_arr) < 2:
        print("Not enough valid data for fitting.")
        return

    # Log-log space for fitting
    log_x = np.log(x_arr).reshape(-1, 1)
    log_y = np.log(y_arr)

    # RANSAC for outlier detection on medians
    ransac = RANSACRegressor(
        estimator=LinearRegression(),
        min_samples=0.5,
        random_state=42
    )
    ransac.fit(log_x, log_y)
    inlier_mask = ransac.inlier_mask_
    outlier_mask = ~inlier_mask

    # WLS on inliers, weighted by bin sizes so well-populated bins count more.
    X_inliers_const = sm.add_constant(log_x[inlier_mask])
    model = sm.WLS(log_y[inlier_mask], X_inliers_const, weights=bin_sizes_arr[inlier_mask])
    results = model.fit()

    # log(ΔK) = log(a) + b*log(bin)
    intercept = results.params[0]
    slope = results.params[1]
    a = np.exp(intercept)
    b = slope
    r_squared = results.rsquared

    print(f"Robust Fit Parameters: a = {a:.4f}, b = {b:.4f}")
    print(f"R-squared (on inliers): {r_squared:.4f}")

    # Predictions and CI for the fit line, back-transformed from log space.
    x_plot = np.geomspace(x_arr.min(), x_arr.max(), 100)
    X_plot_const = sm.add_constant(np.log(x_plot))
    pred_summary = results.get_prediction(X_plot_const).summary_frame(alpha=0.05)
    y_fitted = np.exp(pred_summary['mean'])
    ci_lower = np.exp(pred_summary['mean_ci_lower'])
    ci_upper = np.exp(pred_summary['mean_ci_upper'])

    # Plotting
    plt.figure(figsize=(10, 7))

    # Box plots of the ΔK distributions; widths scale with position so the
    # boxes look comparable on the log x-axis.
    box_positions = x_arr
    plt.boxplot(delta_bin_data, positions=box_positions, widths=0.15 * box_positions,
                manage_ticks=False, showfliers=True, flierprops={'marker': 'o', 'markersize': 5, 'alpha': 0.5},
                boxprops={'linewidth': 1.5}, medianprops={'color': 'red', 'linewidth': 2})

    # Error bars: bootstrap CI around each median
    plt.errorbar(x_arr, y_arr, yerr=[y_arr - ci_low_arr, ci_high_arr - y_arr],
                 fmt='none', ecolor='black', capsize=5, elinewidth=1.5, alpha=0.8)

    # Scatter for medians (inliers/outliers)
    plt.scatter(x_arr[inlier_mask], y_arr[inlier_mask],
                s=120, edgecolors="black", alpha=0.9, label="Inlier Medians", c='blue', zorder=3)
    if np.any(outlier_mask):
        plt.scatter(x_arr[outlier_mask], y_arr[outlier_mask],
                    s=120, linewidth=2, alpha=0.5, edgecolors="red", label="Outlier Medians", c='blue', zorder=3)

    # Fit line and shaded CI
    plt.plot(x_plot, y_fitted, 'k--', linewidth=2,
             label=f'Fit: $y={a:.2f}x^{{{b:.2f}}}$\n$R^2={r_squared:.3f}$')
    plt.fill_between(x_plot, ci_lower, ci_upper, color=SCIENTIFIC_PALETTE[0], alpha=0.15,
                     label='95% Confidence Interval on Fit')

    plt.xscale("log")
    plt.yscale("log")
    plt.xlabel("Documents per Query (log_2 binned)")
    plt.ylabel("ΔK (Relative Sensitivity)")
    if use_fixed_scaling:
        plt.ylim(0.93, 1.07)
    plt.legend(loc='best')
    plt.grid(True, which="both", ls="--", alpha=0.4)
    plt.tight_layout()

    plt.savefig(f"plots/{name}_robust_ci_logplot_with_distributions.png")
    plt.close()