import os
import glob
import pickle
import argparse
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
from matplotlib.patches import Rectangle, ConnectionPatch
from collections import defaultdict

# Global plotting defaults: seaborn "ticks" theme and 100-dpi figures.
sns.set_style("ticks")
plt.rcParams["figure.dpi"] = 100

# Script-wide configuration. Per-dataset entries hold the result directory,
# output file names (joined with "out_dir" by out_path) and axis limits.
# NOTE(review): both datasets use the same "rank_pdf" file name, so paths built
# with out_path() for that key are identical for mnist and cifar10 — confirm
# this is intended (e.g. a single shared rank figure).
cfg = {
    "out_dir": "figure",
    "n_per_provider": 100,
    "bases_all": [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048],
    "lam_bases_value": [2, 8, 32],
    "inset_anchor": (0.42, 0.41, 0.55, 0.55),
    "datasets": {
        "mnist": {
            "save_dir": "save/mnist",
            "lam_exp_value": "01022",
            "value_pdf": "value_mnist.pdf",
            "sensitivity_pdf": "sensitivity_mnist.pdf",
            "rank_pdf": "rank.pdf",
            "marginal_pdf": "marginal_mnist.pdf",
            "value_box": (-0.025, 0.11),
            "sensitivity_ylim": (-0.07, 0.14),
        },
        "cifar10": {
            "save_dir": "save/cifar10",
            "lam_exp_value": "01022",
            "value_pdf": "value_cifar10.pdf",
            "sensitivity_pdf": "sensitivity_cifar10.pdf",
            "rank_pdf": "rank.pdf",
            "marginal_pdf": "marginal_cifar10.pdf",
            "value_box": (-0.013, 0.052),
            "sensitivity_ylim": (-0.03, 0.07),
        },
    },
}

# Import-time side effect: the output directory is created as soon as the
# module is loaded.
os.makedirs(cfg["out_dir"], exist_ok=True)

# Display conventions shared by the plotting helpers:
#   role_order       - role keys accepted by role_key()
#   model_order      - model tags searched by booster_tag(); also the booster sort order
#   booster_name_map - exact (lower-cased) provider name -> label, used by pretty_provider()
#   palette_*        - role key -> colour, looked up by color_for()
style = {
    "role_order": ["owner", "anchor", "gan", "ddpm", "ddim", "fm", "copier", "poisoner"],
    "model_order": ["gan", "ddpm", "ddim", "fm"],
    "booster_name_map": {
        "booster1": "GAN",
        "booster2": "DDPM",
        "booster3": "DDIM",
        "booster4": "FM",
    },
    "palette_pastel": {
        "owner": "#a1c9f4",
        "anchor": "#ffb482",
        "gan": "#8de5a1",
        "ddpm": "#cfcfcf",
        "ddim": "#ff9f9b",
        "fm": "#debb9b",
        "copier": "#fab0e4",
        "poisoner": "#b3de69",
    },
    "palette_solid": {
        "owner": "#1f77b4",
        "anchor": "#ff7f0e",
        "gan": "#2ca02c",
        "ddpm": "#7f7f7f",
        "ddim": "#d62728",
        "fm": "#8c564b",
        "copier": "#e377c2",
        "poisoner": "#bcbd22",
    },
}

def out_path(ds_cfg, key):
    """Return the path of the file named ``ds_cfg[key]`` inside ``cfg["out_dir"]``."""
    target_dir = cfg["out_dir"]
    file_name = ds_cfg[key]
    return os.path.join(target_dir, file_name)

def is_booster(name):
    """Tell whether the string form of ``name`` contains "booster" (case-insensitive)."""
    lowered = str(name).lower()
    return lowered.find("booster") != -1

def booster_tag(name, all_names=None):
    """Map a provider name to its generative-model tag.

    The lower-cased name is checked for "booster1".."booster4" (in that order),
    which map to "gan", "ddpm", "ddim" and "fm". Failing that, the first tag of
    ``style["model_order"]`` contained in the name is returned.

    Args:
        name: Provider name; converted with ``str`` and lower-cased.
        all_names: Unused. Kept so existing callers that pass it keep working.
            It was only consulted by a trailing "booster2" branch that could
            never run, because "booster2" is already resolved to "ddpm" above.

    Returns:
        The tag string, or ``None`` when nothing matches.
    """
    lowered = str(name).lower()

    # Numbered boosters take precedence over tags embedded in the name.
    for marker, tag in (
        ("booster1", "gan"),
        ("booster2", "ddpm"),
        ("booster3", "ddim"),
        ("booster4", "fm"),
    ):
        if marker in lowered:
            return tag

    for tag in style["model_order"]:
        if tag in lowered:
            return tag
    return None

def pretty_provider(name):
    """Return the display label for a provider name.

    Exact matches in ``style["booster_name_map"]`` win; other booster names use
    their upper-cased model tag (or the raw name when no tag is found); anything
    else is the lower-cased name with its first letter capitalised.
    """
    lowered = str(name).lower()
    known_labels = style["booster_name_map"]
    try:
        return known_labels[lowered]
    except KeyError:
        pass

    if not is_booster(lowered):
        return lowered.capitalize()

    tag = booster_tag(lowered)
    if tag:
        return tag.upper()
    return str(name)

def role_key(name, all_names=None):
    """Return the role key of a provider, or ``None`` when it has no known role.

    Fixed roles ("owner", "anchor", "copier", "poisoner") are matched exactly
    after lower-casing; booster names resolve through ``booster_tag`` and are
    accepted only if the tag is listed in ``style["role_order"]``.
    """
    lowered = str(name).lower()
    if lowered in ("owner", "anchor", "copier", "poisoner"):
        return lowered
    if not is_booster(lowered):
        return None

    tag = booster_tag(lowered, all_names=all_names)
    if tag in style["role_order"]:
        return tag
    return None

def color_for(name, palette="pastel", all_names=None):
    """Return the colour of a provider's role from the chosen palette.

    Any ``palette`` value other than "pastel" selects the solid palette.
    Providers without a role key get the neutral grey "#cccccc".
    """
    palette_name = "palette_pastel" if palette == "pastel" else "palette_solid"
    colours = style[palette_name]
    return colours.get(role_key(name, all_names=all_names), "#cccccc")

def load_results(save_dir, n_per_provider, methods=None, lam_bases=None, lam_exponents="12345"):
    """Load pickled result payloads from ``save_dir`` and normalise their metadata.

    File selection:
      * ``methods is None``: every ``result_*_{n}.pkl`` and ``result_*_{n}_*.pkl``.
      * otherwise one ``result_{METHOD}_{n}.pkl`` per method (names upper-cased);
        for "PASV" either the two candidate names per entry of ``lam_bases``
        (``..._b{base}_exp{lam_exponents}.pkl`` and ``..._{base}.pkl``) or, when
        ``lam_bases`` is empty, every ``result_PASV_{n}_*.pkl``.
    Candidates that do not exist are dropped; the rest are loaded in sorted
    path order.

    Metadata missing from a payload is recovered from the file name
    (``result_<method>_<n>[_<base>[_exp<digits>]]``). For PASV results with no
    exponent information the exponents default to ``[1.0, ..., 5.0]``.

    Args:
        save_dir: Directory containing the ``result_*.pkl`` files.
        n_per_provider: Value used in the file names; also the fallback for the
            returned "n_per_provider" when the file name carries no integer.
        methods: Optional iterable of method names.
        lam_bases: Optional iterable of PASV bases (only used for "PASV").
        lam_exponents: Exponent string used in PASV candidate file names.

    Returns:
        A list of dicts with keys "path", "method", "provider_names",
        "group_sums_reps" (float ndarray), "nreps", "n_per_provider",
        "lam_base", "lam_exponents", "payload" and, when the payload has it,
        "mean_rank_reps".
    """
    if methods is not None:
        methods = [m.upper() for m in methods]

    if methods is None:
        candidates = glob.glob(os.path.join(save_dir, f"result_*_{n_per_provider}.pkl"))
        candidates += glob.glob(os.path.join(save_dir, f"result_*_{n_per_provider}_*.pkl"))
    else:
        candidates = []
        for m in methods:
            if m != "PASV":
                candidates.append(os.path.join(save_dir, f"result_{m}_{n_per_provider}.pkl"))
            elif lam_bases:
                for base in lam_bases:
                    candidates.append(
                        os.path.join(save_dir, f"result_PASV_{n_per_provider}_b{base}_exp{lam_exponents}.pkl")
                    )
                    candidates.append(os.path.join(save_dir, f"result_PASV_{n_per_provider}_{base}.pkl"))
            else:
                candidates.extend(glob.glob(os.path.join(save_dir, f"result_{m}_{n_per_provider}_*.pkl")))

    files = sorted({fp for fp in candidates if os.path.exists(fp)})

    out = []
    for fp in files:
        # SECURITY: pickle.load can execute arbitrary code; only point this at
        # result files produced by trusted runs.
        with open(fp, "rb") as f:
            payload = pickle.load(f)

        stem = os.path.splitext(os.path.basename(fp))[0]
        tokens = stem.split("_")
        is_result_file = tokens[0] == "result"

        method = payload.get("method")
        if method is None and len(tokens) >= 2 and is_result_file:
            method = tokens[1]
        method = method or "UNKNOWN"

        npp = None
        if len(tokens) >= 3 and is_result_file:
            try:
                npp = int(tokens[2])
            except ValueError:
                # Third token is not a count (e.g. "result_PASV_limit_...").
                npp = None

        lam_base = payload.get("lam_base")
        lam_exps = payload.get("lam_exponents")

        if method.upper() == "PASV":
            if lam_base is None and len(tokens) >= 4:
                t = tokens[3]
                try:
                    lam_base = float(t[1:]) if t.startswith("b") else float(t)
                except ValueError:
                    lam_base = None
            if lam_exps is None:
                if len(tokens) >= 5 and str(tokens[4]).startswith("exp"):
                    digits = str(tokens[4])[3:]
                    # One exponent per character, e.g. "01022" -> [0, 1, 0, 2, 2].
                    lam_exps = [float(ch) for ch in digits] if digits else None
                if lam_exps is None:
                    lam_exps = [1.0, 2.0, 3.0, 4.0, 5.0]

        group_sums = np.asarray(payload["group_sums_reps"], dtype=float)

        result = {
            "path": fp,
            "method": method,
            "provider_names": payload["provider_names"],
            "group_sums_reps": group_sums,
            "nreps": int(payload.get("nreps", group_sums.shape[0])),
            "n_per_provider": npp if npp is not None else n_per_provider,
            "lam_base": lam_base,
            "lam_exponents": lam_exps,
            "payload": payload,
        }
        # Copy mean_rank_reps to top level if present (for plot_rank compatibility)
        if "mean_rank_reps" in payload:
            result["mean_rank_reps"] = payload["mean_rank_reps"]
        out.append(result)
    return out

def order_value_results(results, lam_bases):
    """Arrange results for display: SV, WSV, PSV first, then PASV by base.

    Only the first result of each of SV/WSV/PSV is kept. PASV results are
    sorted (stably) by the position of their "lam_base" in ``lam_bases``;
    results whose base is missing or not listed go last. Other methods are
    dropped.
    """
    def method_of(entry):
        return str(entry.get("method", "")).upper()

    leading = []
    for wanted in ("SV", "WSV", "PSV"):
        for entry in results:
            if method_of(entry) == wanted:
                leading.append(entry)
                break

    position_of_base = {float(base): pos for pos, base in enumerate(lam_bases)}

    def pasv_position(entry):
        base = entry.get("lam_base")
        lookup = float("inf") if base is None else float(base)
        return position_of_base.get(lookup, 999)

    pasv_entries = [entry for entry in results if method_of(entry) == "PASV"]
    pasv_entries.sort(key=pasv_position)
    return leading + pasv_entries


def plot_value(
    results,
    providers=None,
    figsize=(18, 4),
    ref_providers=None,
    inset_anchor=None,
    box_y_bottom=-0.03,
    box_y_top=0.16,
    inset_ytick_fontsize=7,
    hide_inset_xtick=True,
    lam_bases=None,
    save_path=None,
):
    """Draw one bar chart of provider-wise values per result, side by side.

    Each panel shows the mean of "group_sums_reps" over repetitions with a
    95% normal-approximation interval (1.96 * sample std / sqrt(n_reps)).
    For every method except "SV" a rectangle is drawn around the non-owner
    bars and the same bars are repeated in a zoomed inset connected to the
    rectangle's top corners.

    Args:
        results: Sequence of result dicts with "provider_names",
            "group_sums_reps" and optionally "method" / "lam_base".
        providers: Optional subset/order of provider names to show; names not
            in a result are skipped for that result.
        figsize: Figure size passed to ``plt.subplots``.
        ref_providers: Defaults to the first result's providers; not otherwise
            used in this function.
        inset_anchor: ``bbox_to_anchor`` (axes fraction) for the inset; when
            ``None`` a 50% inset is placed at the upper right.
        box_y_bottom, box_y_top: Vertical extent of the rectangle and the
            inset's y-limits.
        inset_ytick_fontsize: Font size of the inset's y tick labels.
        hide_inset_xtick: Hide the inset's x ticks instead of labelling them.
        lam_bases: Not used in this function.
        save_path: If given, the figure is also saved there as PDF.
    """
    if ref_providers is None:
        ref_providers = results[0]["provider_names"]

    nrows, ncols = 1, len(results)
    fig, axes = plt.subplots(nrows, ncols, figsize=figsize, sharey=True)

    # plt.subplots returns a bare Axes for a single panel; normalise to a list.
    if isinstance(axes, np.ndarray):
        axes = axes.ravel().tolist()
    else:
        axes = [axes]

    for i, (ax, res) in enumerate(zip(axes, results)):
        provs = res["provider_names"] if providers is None else [p for p in providers if p in res["provider_names"]]

        # Group providers by role so bars appear as
        # owner, anchor, boosters (model order), copiers, poisoners, others.
        owners = [p for p in provs if str(p).lower() == "owner"]
        anchors = [p for p in provs if str(p).lower() == "anchor"]
        boosters = [p for p in provs if is_booster(p)]
        copiers = [p for p in provs if "copier" in str(p).lower()]
        poisoners = [p for p in provs if "poisoner" in str(p).lower()]
        others = [p for p in provs if p not in owners + anchors + boosters + copiers + poisoners]

        boosters_sorted = sorted(
            boosters,
            key=lambda p: style["model_order"].index(booster_tag(p, all_names=provs))
            if booster_tag(p, all_names=provs) in style["model_order"]
            else 999,
        )
        provs_sorted = owners + anchors + boosters_sorted + copiers + poisoners + others

        # Columns of group_sums_reps are indexed by position in provider_names.
        arr = res["group_sums_reps"]
        idx = [res["provider_names"].index(p) for p in provs_sorted]
        vals = arr[:, idx]
        mean = vals.mean(axis=0)
        # A single repetition has no sample std, so the interval collapses to zero.
        ci95 = 1.96 * vals.std(axis=0, ddof=1) / np.sqrt(vals.shape[0]) if vals.shape[0] > 1 else np.zeros_like(mean)

        x = np.arange(len(provs_sorted))
        colors = [color_for(p, palette="pastel", all_names=provs_sorted) for p in provs_sorted]
        labels = [pretty_provider(p) for p in provs_sorted]

        ax.bar(x, mean, yerr=ci95, color=colors, edgecolor="black", linewidth=0.2, capsize=3)
        ax.axhline(0, color="black", linewidth=0.6, linestyle="--")
        ax.set_xticks(x)
        ax.set_xticklabels(labels, rotation=45)

        method = str(res.get("method", ""))
        method_u = method.upper()

        # Zoom box + inset for every method except plain SV.
        if method_u != "SV":
            no_owner_idx = [j for j, p in enumerate(provs_sorted) if str(p).lower() != "owner"]
            if no_owner_idx:
                # 0.8 is the bar width assumed for the box edges (ax.bar above
                # is called without an explicit width).
                bar_w = 0.8
                x_left = x[no_owner_idx[0]] - bar_w / 2.0
                x_right = x[no_owner_idx[-1]] + bar_w / 2.0

                rect = Rectangle(
                    (x_left - 0.1, box_y_bottom),
                    (x_right - x_left) + 0.2,
                    (box_y_top - box_y_bottom),
                    fill=False, edgecolor="black", linewidth=0.7
                )
                ax.add_patch(rect)

                if inset_anchor is not None:
                    ax_in = inset_axes(
                        ax, width="100%", height="100%", loc="lower left",
                        bbox_to_anchor=inset_anchor, bbox_transform=ax.transAxes, borderpad=0
                    )
                else:
                    ax_in = inset_axes(ax, width="50%", height="50%", loc="upper right")

                mean_in = mean[no_owner_idx]
                ci_in = ci95[no_owner_idx]
                labels_in = [pretty_provider(provs_sorted[j]) for j in no_owner_idx]
                colors_in = [color_for(provs_sorted[j], palette="pastel", all_names=provs_sorted) for j in no_owner_idx]

                xi = np.arange(len(no_owner_idx))
                ax_in.bar(xi, mean_in, yerr=ci_in, color=colors_in, edgecolor="black", linewidth=0.2, capsize=2)
                ax_in.axhline(0, color="black", linewidth=0.4, linestyle="--")
                # The inset shows exactly the y-range enclosed by the rectangle.
                ax_in.set_ylim(box_y_bottom, box_y_top)
                ax_in.tick_params(axis="y", labelsize=inset_ytick_fontsize)
                ax_in.grid(False)

                if hide_inset_xtick:
                    ax_in.set_xticks([])
                    ax_in.tick_params(axis="x", which="both", bottom=False, top=False, labelbottom=False)
                else:
                    ax_in.set_xticks(xi)
                    ax_in.set_xticklabels(labels_in, rotation=30, fontsize=8)

                # Lines from the inset's bottom corners to the rectangle's top corners.
                b_left, b_right, b_top = (x_left - 0.1), (x_right + 0.1), box_y_top
                cp1 = ConnectionPatch(
                    xyA=(0.0, 0.0), xyB=(b_left, b_top),
                    coordsA="axes fraction", coordsB="data",
                    axesA=ax_in, axesB=ax,
                    color="black", linewidth=0.8
                )
                cp2 = ConnectionPatch(
                    xyA=(1.0, 0.0), xyB=(b_right, b_top),
                    coordsA="axes fraction", coordsB="data",
                    axesA=ax_in, axesB=ax,
                    color="black", linewidth=0.8
                )
                ax.add_artist(cp1)
                ax.add_artist(cp2)

        title = method
        if method_u == "PASV":
            b = res.get("lam_base")
            if b is not None:
                title = f"PASV ($b={int(float(b))}$)"
        ax.set_title(title)

        # Only the left-most panel keeps its y axis label and ticks.
        if i == 0:
            ax.set_ylabel("Provider-Wise Value")
        else:
            ax.set_ylabel("")
            ax.tick_params(axis="y", which="both", left=False, labelleft=False)

        ax.grid(False)
        ax.set_axisbelow(True)

    for ax in axes[len(results):]:
        ax.axis("off")

    plt.tight_layout()
    if save_path is not None:
        os.makedirs(os.path.dirname(save_path) if os.path.dirname(save_path) else ".", exist_ok=True)
        plt.savefig(save_path, format="pdf", bbox_inches="tight", pad_inches=0.1)
    plt.show()
    plt.close(fig)

def table_value_grid(
    results,
    providers=None,
    ref_providers=None,
    lam_bases=None,
):
    """Build a table of "mean ± 95% CI" strings, one row per provider and one column per result.

    Rows follow the display order of the first result (owner, anchor, boosters
    in model order, copiers, poisoners, others) and are labelled with
    ``pretty_provider``. Columns are named after the method, or
    "PASV (b=<base>)" for PASV results that carry a base. The interval is
    1.96 * sample std / sqrt(n_reps), or 0 with a single repetition.

    Cells are matched to rows by provider name. A provider absent from a later
    result yields "N/A" rather than a value shifted from another provider.

    Args:
        results: Sequence of result dicts with "provider_names",
            "group_sums_reps" and optionally "method" / "lam_base".
        providers: Optional subset/order of provider names to include.
        ref_providers: Unused; kept for signature compatibility.
        lam_bases: Unused; kept for signature compatibility.

    Returns:
        A ``pandas.DataFrame`` of strings; empty when ``results`` is empty.
    """
    if not results:
        return pd.DataFrame()

    row_providers = None
    columns = {}

    for res in results:
        names = list(res["provider_names"])
        provs = names if providers is None else [p for p in providers if p in names]

        owners = [p for p in provs if str(p).lower() == "owner"]
        anchors = [p for p in provs if str(p).lower() == "anchor"]
        boosters = [p for p in provs if is_booster(p)]
        copiers = [p for p in provs if "copier" in str(p).lower()]
        poisoners = [p for p in provs if "poisoner" in str(p).lower()]
        grouped = owners + anchors + boosters + copiers + poisoners
        others = [p for p in provs if p not in grouped]

        def booster_rank(p, _provs=provs):
            tag = booster_tag(p, all_names=_provs)
            order = style["model_order"]
            return order.index(tag) if tag in order else 999

        provs_sorted = owners + anchors + sorted(boosters, key=booster_rank) + copiers + poisoners + others

        # The first result fixes the table's rows.
        if row_providers is None:
            row_providers = provs_sorted

        arr = res["group_sums_reps"]
        vals = arr[:, [names.index(p) for p in provs_sorted]]
        mean = vals.mean(axis=0)
        if vals.shape[0] > 1:
            ci95 = 1.96 * vals.std(axis=0, ddof=1) / np.sqrt(vals.shape[0])
        else:
            ci95 = np.zeros_like(mean)

        cell_for = {
            p: f"{mean_val:.4f} ± {ci_val:.4f}"
            for p, mean_val, ci_val in zip(provs_sorted, mean, ci95)
        }

        method = str(res.get("method", ""))
        base = res.get("lam_base")
        if method.upper() == "PASV" and base is not None:
            col_name = f"PASV (b={int(float(base))})"
        else:
            col_name = method

        # A repeated column name keeps its position and takes the later values.
        columns[col_name] = [cell_for.get(p, "N/A") for p in row_providers]

    return pd.DataFrame(columns, index=[pretty_provider(p) for p in row_providers])

def _fetch_finite_result(save_dir, n_per_provider, base_val, exp_str, require_mean_rank=False):
    """Load the result for one finite base, or ``None`` if unavailable.

    A base of 1 is served by the "PSV" result (its "lam_base" is set to 1.0);
    any other base loads the "PASV" result for that base and exponent string.
    With ``require_mean_rank`` the result must contain "mean_rank_reps".
    """
    uses_psv = base_val == 1
    if uses_psv:
        loaded = load_results(save_dir, n_per_provider=n_per_provider, methods=["PSV"])
    else:
        loaded = load_results(
            save_dir,
            n_per_provider=n_per_provider,
            methods=["PASV"],
            lam_bases=[base_val],
            lam_exponents=exp_str,
        )

    if not loaded:
        return None

    first = loaded[0]
    if uses_psv:
        first["lam_base"] = 1.0
    if require_mean_rank and "mean_rank_reps" not in first:
        return None
    return first

def _fetch_limit_result(save_dir, n_per_provider, limit_provider, require_mean_rank=False, anchor_fallback_path=None):
    if limit_provider == "anchor":
        if anchor_fallback_path:
            fp = anchor_fallback_path
        else:
            return None
    else:
        fp = os.path.join(save_dir, f"result_PASV_limit_{limit_provider}_{n_per_provider}.pkl")
    
    if not os.path.exists(fp):
        return None
    with open(fp, "rb") as f:
        payload = pickle.load(f)
    
    if require_mean_rank and "mean_rank_reps" not in payload:
        return None
    return payload

def _base_to_exp(b, log_base=2.0):
    b = float(b)
    if b <= 0:
        return np.nan
    return np.log(b) / np.log(float(log_base))

def plot_sensitivity(
    save_dir,
    n_per_provider,
    bases_all,
    ylim=None,
    marker_size=2,
    x_log_base=2,
    figsize=(12, 3.2),
    save_path=None,
):
    """Plot provider-wise value against the log of the base for four scenarios.

    One panel per scenario (anchor, booster, copier, poisoner). For every base
    in ``bases_all`` whose result file exists, the mean of "group_sums_reps"
    is drawn per non-owner provider with a band of +/- one sample standard
    deviation. When a "limit" result exists for the scenario, its mean is
    drawn as a dashed horizontal line with its own band.

    Args:
        save_dir: Directory holding the result pickles.
        n_per_provider: Value used in the result file names.
        bases_all: Bases to look up; a base of 1 is served by the PSV result.
        ylim: Shared y-limits; derived from the plotted data when ``None``.
        marker_size: Marker size for the lines and legend handles.
        x_log_base: Base of the logarithm used for the x positions and label.
        figsize: Figure size passed to ``plt.subplots``.
        save_path: If given, the figure is also saved there as PDF.
    """
    # (panel key, exponent string used in the PASV file name, limit-run key)
    scenarios = [
        ("anchor", "01000", "anchor"),
        ("booster", "00100", "booster"),
        ("copier", "00010", "copier"),
        ("poisoner", "00001", "poisoner"),
    ]

    def fetch_finite(base_val, exp_str):
        return _fetch_finite_result(save_dir, n_per_provider, base_val, exp_str, require_mean_rank=False)

    def fetch_limit(limit_provider):
        # No anchor fallback path is passed, so the "anchor" limit is always None here.
        return _fetch_limit_result(save_dir, n_per_provider, limit_provider, require_mean_rank=False)

    def base_to_exp(b):
        return _base_to_exp(b, log_base=x_log_base)

    fig, axes = plt.subplots(1, 4, figsize=figsize, sharey=True)
    axes = np.array(axes).ravel()

    # Every plotted y value (including band edges), used for the automatic y-limits.
    all_y = []
    # Display label -> colour, filled while plotting and used for the shared legend.
    legend_map = {}

    for ax, (title_key, exp_str, limit_key) in zip(axes, scenarios):
        provider_order = None
        mean_series = {}
        sd_series = {}
        x_found = []   # exponent positions

        for b in bases_all:
            r = fetch_finite(b, exp_str)
            if r is None:
                continue

            # The first available result fixes the providers shown (owner excluded).
            if provider_order is None:
                provider_order = [p for p in r["provider_names"] if str(p).lower() != "owner"]
                for p in provider_order:
                    mean_series[p] = []
                    sd_series[p] = []

            arr = r["group_sums_reps"]
            m = arr.mean(axis=0)
            sd = arr.std(axis=0, ddof=1) if arr.shape[0] > 1 else np.zeros_like(m)

            # NOTE(review): float() fails if "lam_base" is present but None — confirm
            # that cannot happen for the files loaded here.
            base_used = float(r.get("lam_base", b))
            x_found.append(base_to_exp(base_used))

            for p in provider_order:
                j = r["provider_names"].index(p)
                mean_series[p].append(float(m[j]))
                sd_series[p].append(float(sd[j]))

        # Sort points by x so lines are drawn left to right.
        x_found = np.asarray(x_found, dtype=float)
        sort_idx = np.argsort(x_found)
        x_found = x_found[sort_idx]

        limit_payload = fetch_limit(limit_key)
        limit_vals = {}
        limit_sds = {}
        if limit_payload is not None:
            names = limit_payload["provider_names"]
            arr = np.asarray(limit_payload["group_sums_reps"], dtype=float)
            m = arr.mean(axis=0)
            sd = arr.std(axis=0, ddof=1) if arr.shape[0] > 1 else np.zeros_like(m)
            for p in (provider_order or []):
                if p in names:
                    j = names.index(p)
                    limit_vals[p] = float(m[j])
                    limit_sds[p] = float(sd[j])

        for p in (provider_order or []):
            y = np.asarray(mean_series[p], dtype=float)[sort_idx]
            ysd = np.asarray(sd_series[p], dtype=float)[sort_idx]

            c = color_for(p, palette="solid", all_names=provider_order)
            label = pretty_provider(p)
            legend_map[label] = c

            ax.plot(x_found, y, marker="o", markersize=marker_size, linewidth=1.2, color=c)
            ax.fill_between(x_found, y - ysd, y + ysd, color=c, alpha=0.18, linewidth=0)

            all_y.extend(y.tolist())
            all_y.extend((y - ysd).tolist())
            all_y.extend((y + ysd).tolist())

            if p in limit_vals:
                lv = limit_vals[p]
                ls = limit_sds.get(p, 0.0)
                if title_key != "anchor":
                    ax.axhline(lv, color=c, linestyle="--", linewidth=1.2, alpha=0.7)

                    # Band spans the same fixed x-range as the axis limits set below.
                    ax.fill_between([-0.1, 11.1], lv - ls, lv + ls, color=c, alpha=0.12, linewidth=0)
                all_y.extend([lv - ls, lv + ls])

        title_disp = "Boosters" if title_key == "booster" else title_key.capitalize()
        ax.set_title(f"Limiting {title_disp}", fontsize=12)

        # Fixed x-range 0..11, independent of bases_all and x_log_base.
        # NOTE(review): consistent with log2 of bases 1..2048; other inputs may fall outside.
        ax.set_xlim(-0.1, 11.1)
        ax.set_xticks(list(range(12)))
        ax.set_xlabel(r"$\log_2(b)$" if float(x_log_base) == 2.0 else rf"$\log_{{{int(x_log_base)}}}(b)$", fontsize=11)

        ax.grid(True, which="both", axis="both", alpha=0.25)

    # Only the left-most panel keeps its y axis label and ticks.
    axes[0].set_ylabel("Provider-Wise Value")
    for ax in axes[1:]:
        ax.set_ylabel("")
        ax.tick_params(axis="y", which="both", left=False, labelleft=False)

    # Automatic y-limits: data range padded by 10% (or 0.01 for a flat range).
    if ylim is None and all_y:
        y0, y1 = min(all_y), max(all_y)
        pad = 0.1 * (y1 - y0) if y1 > y0 else 0.01
        ylim = (y0 - pad, y1 + pad)
    if ylim is not None:
        for ax in axes:
            ax.set_ylim(ylim)

    labels = list(legend_map.keys())
    handles = [
        plt.Line2D([0], [0], color=legend_map[l], linewidth=1.2, marker="o", markersize=marker_size)
        for l in labels
    ]
    fig.legend(
        handles, labels,
        ncol=len(labels) if labels else 1,
        loc="lower center",
        bbox_to_anchor=(0.5, 0.06),
        frameon=False,
        fontsize=11
    )

    # Reserve the bottom 12% of the figure for the shared legend.
    plt.tight_layout(rect=[0, 0.12, 1, 1])
    if save_path is not None:
        os.makedirs(os.path.dirname(save_path) if os.path.dirname(save_path) else ".", exist_ok=True)
        plt.savefig(save_path, format="pdf", bbox_inches="tight", pad_inches=0.1)
    plt.show()
    plt.close(fig)



def load_main_experiment_data(save_root, dataset, result_filename="result_PSV_100.pkl"):
    """Read ``save_root/dataset/result_filename`` and extract its marginal-contribution data.

    Returns a dict with the three per-repetition sequences stored under the
    payload's "marginal_contribution" entry, the repetition count ("nreps",
    falling back to the number of size sequences), the provider names
    (``[]`` when absent) and the dataset name.
    """
    result_path = os.path.join(save_root, dataset, result_filename)
    with open(result_path, "rb") as handle:
        payload = pickle.load(handle)

    marginal = payload["marginal_contribution"]
    size_reps = marginal["coalition_sizes_reps"]
    n_reps = payload.get("nreps", len(size_reps))

    extracted = {"coalition_sizes_reps": size_reps}
    extracted["marginal_contributions_reps"] = marginal["marginal_contributions_reps"]
    extracted["providers_reps"] = marginal["providers_reps"]
    extracted["nreps"] = n_reps
    extracted["provider_names"] = payload.get("provider_names", [])
    extracted["dataset"] = dataset
    return extracted

def compute_marginal_contrib(data, min_size=20):
    """Summarise marginal contributions per provider and coalition size.

    Within each repetition, contributions with coalition size strictly greater
    than ``min_size`` are averaged per (provider, size). Across repetitions the
    mean and sample standard deviation (0.0 for a single value) of those
    averages are computed; non-finite entries are skipped.

    Returns a dict with "provider_order" (fixed roles present in the provider
    names, followed by booster/model-tagged names), "provider_mean_data"
    (provider -> arrays "sizes", "means", "stds") and "dataset".
    """
    size_reps = data["coalition_sizes_reps"]
    contrib_reps = data["marginal_contributions_reps"]
    provider_reps = data["providers_reps"]
    n_reps = data["nreps"]
    names = data["provider_names"]

    fixed_roles = ("owner", "anchor", "copier", "poisoner")
    model_tags = ("gan", "ddpm", "ddim", "fm")

    provider_order = [role for role in fixed_roles if role in names]
    for candidate in names:
        lowered = str(candidate).lower()
        if lowered in fixed_roles:
            continue
        if is_booster(lowered) or any(tag in lowered for tag in model_tags):
            provider_order.append(candidate)

    # provider -> size -> list of per-repetition means
    per_rep_means = defaultdict(lambda: defaultdict(list))

    for rep in range(n_reps):
        rep_sizes = np.asarray(size_reps[rep])
        rep_contribs = np.asarray(contrib_reps[rep])
        rep_providers = provider_reps[rep]

        grouped = defaultdict(lambda: defaultdict(list))
        for size, value, prov in zip(rep_sizes, rep_contribs, rep_providers):
            if size > min_size:
                grouped[prov][size].append(value)

        for prov, by_size in grouped.items():
            for size, values in by_size.items():
                per_rep_means[prov][size].append(float(np.mean(values)))

    summaries = {}
    for prov in provider_order:
        if prov not in per_rep_means:
            continue

        kept_sizes, kept_means, kept_stds = [], [], []
        for size in sorted(per_rep_means[prov]):
            values = per_rep_means[prov][size]
            if not values:
                continue
            centre = float(np.mean(values))
            spread = float(np.std(values, ddof=1)) if len(values) > 1 else 0.0
            if not (np.isfinite(centre) and np.isfinite(spread)):
                continue
            kept_sizes.append(size)
            kept_means.append(centre)
            kept_stds.append(spread)

        if kept_sizes:
            summaries[prov] = {
                "sizes": np.asarray(kept_sizes),
                "means": np.asarray(kept_means),
                "stds": np.asarray(kept_stds),
            }

    return {"provider_order": provider_order, "provider_mean_data": summaries, "dataset": data["dataset"]}

def compute_marginal(
    save_root,
    dataset,
    result_filename="result_PSV_100.pkl",
    min_size=20,
    out_pkl_path="marginal_plot_data.pkl",
):
    """Compute the marginal-contribution summary for a dataset and pickle it to ``out_pkl_path``.

    The parent directory of ``out_pkl_path`` is created when missing.
    """
    raw = load_main_experiment_data(save_root, dataset, result_filename=result_filename)
    summary = compute_marginal_contrib(raw, min_size=min_size)

    parent_dir = os.path.dirname(out_pkl_path)
    os.makedirs(parent_dir or ".", exist_ok=True)
    with open(out_pkl_path, "wb") as handle:
        pickle.dump(summary, handle)


def load_marginal_plot_data(pkl_path):
    """Return the object unpickled from ``pkl_path``."""
    with open(pkl_path, "rb") as handle:
        plot_data = pickle.load(handle)
    return plot_data


def plot_marginal_contrib(
    plot_data,
    save_path=None,
    figsize=(14, 2.6),
    linewidth=1,
    markersize=1,
    alpha_band=0.3,
    xlabel=r"$s$",
):
    """Plot mean marginal contribution against coalition size, one panel per provider.

    Up to eight providers are shown in the order owner, anchor, numbered
    boosters, copiers, poisoners, others. Each panel draws the mean line with
    a +/- one standard deviation band. The owner panel has its own y-limits;
    all other panels share limits derived from the non-owner data.

    Args:
        plot_data: Dict with "provider_mean_data" (provider -> "sizes",
            "means", "stds") and optionally "provider_order".
        save_path: If given, the figure is also saved there as PDF.
        figsize: Figure size.
        linewidth, markersize: Line style of the mean curves.
        alpha_band: Opacity of the standard-deviation band.
        xlabel: Label of every x axis.
    """
    d = plot_data["provider_mean_data"]
    base_order = [p for p in plot_data.get("provider_order", list(d.keys())) if p in d]

    def low(p): return str(p).lower()
    def is_owner(p): return low(p) == "owner"
    def is_anchor(p): return low(p) == "anchor"
    def is_copier(p): return "copier" in low(p)
    def is_poisoner(p): return "poisoner" in low(p)
    def booster_idx(p):
        # Only names containing "booster1".."booster4" count as boosters here;
        # anything else falls into `others` below.
        s = low(p)
        if "booster1" in s: return 1
        if "booster2" in s: return 2
        if "booster3" in s: return 3
        if "booster4" in s: return 4
        return None

    owners = [p for p in base_order if is_owner(p)]
    anchors = [p for p in base_order if is_anchor(p)]
    boosters = [p for p in base_order if booster_idx(p) is not None]
    boosters = sorted(boosters, key=lambda p: booster_idx(p))
    copiers = [p for p in base_order if is_copier(p)]
    poisoners = [p for p in base_order if is_poisoner(p)]

    picked = set(owners + anchors + boosters + copiers + poisoners)
    others = [p for p in base_order if p not in picked]

    # Truncated to 8 because exactly eight axes are created below.
    provider_order = (owners + anchors + boosters + copiers + poisoners + others)[:8]

    # Collect band edges separately for the owner and for everyone else.
    non_owner_vals, owner_vals = [], []
    for p in provider_order:
        m = d[p]["means"]; s = d[p]["stds"]
        lo = m - s; hi = m + s
        if is_owner(p):
            owner_vals.extend(lo.tolist() + hi.tolist())
        else:
            non_owner_vals.extend(lo.tolist() + hi.tolist())

    def ylims(vals):
        # Data range padded by 5% (0.01 for a flat range); (-0.1, 0.1) when empty.
        if not vals:
            return (-0.1, 0.1)
        v0, v1 = float(np.min(vals)), float(np.max(vals))
        pad = 0.05 * (v1 - v0) if v1 > v0 else 0.01
        return (v0 - pad, v1 + pad)

    y_non = ylims(non_owner_vals)
    y_owner = ylims(owner_vals)

    # Fixed x-range and ticks, independent of the sizes present in the data.
    x_min, x_max = -20, 820
    x_ticks = [0, 200, 400, 600, 800]

    fig = plt.figure(figsize=figsize)
    # Nine grid columns: column 1 (width 0.2) is left empty as a gap after the
    # first panel; the remaining eight columns each hold one panel.
    gs = fig.add_gridspec(1, 9, width_ratios=[1, 0.2, 1, 1, 1, 1, 1, 1, 1], wspace=0.15)

    axes = []
    axes.append(fig.add_subplot(gs[0, 0]))
    axes.append(fig.add_subplot(gs[0, 2]))
    for c in range(3, 9):
        axes.append(fig.add_subplot(gs[0, c]))

    for i, (ax, p) in enumerate(zip(axes, provider_order)):
        ax.tick_params(axis="y", labelsize=8)
        x = d[p]["sizes"]; m = d[p]["means"]; s = d[p]["stds"]
        c = color_for(p, palette="solid", all_names=provider_order)

        ax.plot(x, m, color=c, linewidth=linewidth, marker="o", markersize=markersize)
        ax.fill_between(x, m - s, m + s, color=c, alpha=alpha_band, linewidth=0)
        ax.axhline(0, color="black", linewidth=0.4, linestyle="--")

        ax.set_xlim(x_min, x_max)
        ax.set_xticks(x_ticks)
        ax.set_xticklabels([str(t) for t in x_ticks], fontsize=8, rotation=30)
        ax.set_ylim(y_owner if is_owner(p) else y_non)

        ax.set_title(pretty_provider(p), fontsize=10)
        ax.grid(True, alpha=0.3)

        # Panels 0 and 1 keep their y ticks; later panels hide them.
        if i == 0:
            ax.set_ylabel("Marginal Contribution", fontsize=8)
        elif i >= 2:
            ax.set_yticks([])

        ax.set_xlabel(xlabel, fontsize=10)

    plt.tight_layout()
    if save_path is not None:
        os.makedirs(os.path.dirname(save_path) if os.path.dirname(save_path) else ".", exist_ok=True)
        plt.savefig(save_path, format="pdf", bbox_inches="tight", pad_inches=0.1)
    plt.show()
    plt.close(fig)

def plot_marginal(
    in_pkl_path,
    save_path=None,
    figsize=(10, 4),
    linewidth=1,
    markersize=1,
    alpha_band=0.3,
    xlabel=r"$s$",
):
    """Load pickled marginal-contribution plot data and draw it with ``plot_marginal_contrib``.

    All styling arguments are forwarded unchanged.
    """
    loaded = load_marginal_plot_data(in_pkl_path)
    forwarded = {
        "save_path": save_path,
        "figsize": figsize,
        "linewidth": linewidth,
        "markersize": markersize,
        "alpha_band": alpha_band,
        "xlabel": xlabel,
    }
    plot_marginal_contrib(loaded, **forwarded)



def plot_rank(
    save_dir,
    n_per_provider,
    bases_all,
    marker_size=2,
    figsize=(12, 3.2),
    wspace=0.18,
    legend_anchor=(0.5, 0.06),
    save_path=None,
):
    """Plot mean rank against ``log2`` of the lambda base, one panel per provider group.

    A single row of four panels is drawn (Anchor, Boosters, Copier, Poisoner).
    Every panel contains one curve per scenario ("Limiting Anchor",
    "Limiting Boosters", "Limiting Copier", "Limiting Poisoner"); each curve
    is the mean over the rows of ``mean_rank_reps`` (averaged over the
    panel's provider columns) with a +/- one sample-standard-deviation band.
    For the three non-anchor panels the corresponding limit result is added
    as a dashed horizontal line with its own band. In the anchor panel the
    limit is not drawn, but its band still contributes to the y-range.

    The anchor panel gets its own y-limits; the other three panels share
    one common y-range.

    Each finite and limit result is fetched once and reused for all panels.

    Args:
        save_dir: Directory handed to the result loaders.
        n_per_provider: Forwarded to the result loaders.
        bases_all: Base values to look up; base ``1`` is always tried first,
            whether or not it is listed.
        marker_size: Marker size for the curves and the legend handles.
        figsize: Figure size passed to ``plt.subplots``.
        wspace: Horizontal spacing between the panels.
        legend_anchor: ``bbox_to_anchor`` of the figure-level legend.
        save_path: If not ``None``, the figure is written there as a PDF
            (the parent directory is created when needed).
    """
    scenarios = [
        ("anchor", "01000", "anchor"),
        ("booster", "00100", "booster"),
        ("copier", "00010", "copier"),
        ("poisoner", "00001", "poisoner"),
    ]

    providers_to_plot = [
        ("anchor", ["anchor"], "Anchor"),
        ("booster", ["booster1", "booster2", "booster3", "booster4"], "Boosters"),
        ("copier", ["copier"], "Copier"),
        ("poisoner", ["poisoner"], "Poisoner"),
    ]

    scenario_color = {
        "anchor": style["palette_solid"]["anchor"],
        "booster": style["palette_solid"]["gan"],
        "copier": style["palette_solid"]["copier"],
        "poisoner": style["palette_solid"]["poisoner"],
    }

    legend_map = {
        "Limiting Anchor": scenario_color["anchor"],
        "Limiting Boosters": scenario_color["booster"],
        "Limiting Copier": scenario_color["copier"],
        "Limiting Poisoner": scenario_color["poisoner"],
    }

    # Shared x-range: used for the axis limits and for the limit bands.
    # NOTE(review): fixed to exponents 0..11 regardless of ``bases_all``.
    x_lo, x_hi = -0.1, 11.1

    # The same results are needed in every panel, so fetch each one once.
    finite_cache = {}
    limit_cache = {}

    def fetch_finite(base_val, exp_str):
        """Return the (cached) finite-base result with a float ``lam_base``, or None."""
        key = (base_val, exp_str)
        if key not in finite_cache:
            result = _fetch_finite_result(
                save_dir, n_per_provider, base_val, exp_str, require_mean_rank=True
            )
            if result is not None:
                # Base 1 is pinned to exactly 1.0; otherwise prefer the stored value.
                if base_val == 1:
                    result["lam_base"] = 1.0
                else:
                    result["lam_base"] = float(result.get("lam_base", base_val))
            finite_cache[key] = result
        return finite_cache[key]

    def fetch_limit(limit_provider):
        """Return the (cached) limit result for ``limit_provider``, or None."""
        if limit_provider not in limit_cache:
            # NOTE(review): the fallback file name hard-codes "100"; presumably
            # this is n_per_provider -- confirm before changing that setting.
            anchor_fallback = (
                os.path.join(save_dir, "result_PASV_100_b1.04858e+06_exp01000.pkl")
                if limit_provider == "anchor"
                else None
            )
            limit_cache[limit_provider] = _fetch_limit_result(
                save_dir,
                n_per_provider,
                limit_provider,
                require_mean_rank=True,
                anchor_fallback_path=anchor_fallback,
            )
        return limit_cache[limit_provider]

    def rank_mean_sd(result, wanted_lower):
        """Return (mean, sample sd) of the per-row rank over the wanted providers.

        Returns None when none of the wanted providers appear in the result.
        """
        idxs = [
            i
            for i, p in enumerate(result["provider_names"])
            if str(p).lower() in wanted_lower
        ]
        if not idxs:
            return None
        arr = np.asarray(result["mean_rank_reps"], dtype=float)
        per_row = arr[:, idxs].mean(axis=1)
        # A sample standard deviation needs at least two rows.
        sd = float(per_row.std(ddof=1)) if per_row.shape[0] > 1 else 0.0
        return float(per_row.mean()), sd

    fig, axes = plt.subplots(1, 4, figsize=figsize, gridspec_kw={"wspace": wspace})

    all_y_anchor = []
    all_y_others = []

    bases_iter = [1] + [b for b in bases_all if b != 1]

    for ax, (prov_key, prov_names, title) in zip(axes, providers_to_plot):
        is_anchor_panel = prov_key == "anchor"
        wanted_lower = {q.lower() for q in prov_names}
        # Band extents of this panel feed the matching y-range pool.
        y_pool = all_y_anchor if is_anchor_panel else all_y_others

        for scen_name, exp_str, limit_key in scenarios:
            c = scenario_color.get(scen_name, "#cccccc")
            x_found, ranks, sds = [], [], []

            for b in bases_iter:
                r = fetch_finite(b, exp_str)
                if r is None:
                    continue

                x = _base_to_exp(r["lam_base"], log_base=2.0)
                if not np.isfinite(x):
                    continue

                stats = rank_mean_sd(r, wanted_lower)
                if stats is None:
                    continue

                x_found.append(float(x))
                ranks.append(stats[0])
                sds.append(stats[1])

            if x_found:
                x_found = np.asarray(x_found, dtype=float)
                ranks = np.asarray(ranks, dtype=float)
                sds = np.asarray(sds, dtype=float)

                # Base 1 is visited first, so order the points by exponent.
                si = np.argsort(x_found)
                x_found, ranks, sds = x_found[si], ranks[si], sds[si]

                ax.plot(x_found, ranks, marker="o", markersize=marker_size, linewidth=1.2, color=c)
                ax.fill_between(x_found, ranks - sds, ranks + sds, color=c, alpha=0.18, linewidth=0)

                y_pool.extend((ranks - sds).tolist() + (ranks + sds).tolist())

            lim = fetch_limit(limit_key)
            if lim is None:
                continue
            lim_stats = rank_mean_sd(lim, wanted_lower)
            if lim_stats is None:
                continue
            lv, ls = lim_stats

            # The limit line is only drawn outside the anchor panel ...
            if not is_anchor_panel:
                ax.axhline(lv, color=c, linestyle="--", linewidth=1.2, alpha=0.7)
                ax.fill_between([x_lo, x_hi], lv - ls, lv + ls, color=c, alpha=0.12, linewidth=0)

            # ... but its band always counts towards the panel's y-range.
            y_pool.extend([lv - ls, lv + ls])

        ax.set_title(title)
        ax.set_xlim(x_lo, x_hi)
        ax.set_xticks(list(range(12)))
        ax.set_xlabel(r"$\log_2(b)$", fontsize=11)
        ax.grid(True, which="both", axis="both", alpha=0.25)

    if all_y_anchor:
        y0, y1 = min(all_y_anchor), max(all_y_anchor)
        pad = 0.1 * (y1 - y0) if y1 > y0 else 0.2
        axes[0].set_ylim(y0 - pad, y1 + pad)

    if all_y_others:
        y0, y1 = min(all_y_others), max(all_y_others)
        pad = 0.1 * (y1 - y0) if y1 > y0 else 0.2
        for ax in axes[1:]:
            ax.set_ylim(y0 - pad, y1 + pad)

    axes[0].set_ylabel("Mean Rank")
    # Panels 2 and 3 share panel 1's y-range, so their tick labels are dropped.
    axes[2].tick_params(axis="y", which="both", left=False, labelleft=False)
    axes[3].tick_params(axis="y", which="both", left=False, labelleft=False)

    handles = [
        plt.Line2D([0], [0], color=color, linewidth=1.2, marker="o", markersize=marker_size)
        for color in legend_map.values()
    ]
    labels = list(legend_map)
    fig.legend(
        handles, labels,
        ncol=4,
        loc="lower center",
        bbox_to_anchor=legend_anchor,
        frameon=False,
        fontsize=11,
    )

    # Leave room at the bottom of the figure for the legend.
    plt.tight_layout(rect=[0, 0.10, 1, 1])

    if save_path is not None:
        os.makedirs(os.path.dirname(save_path) or ".", exist_ok=True)
        plt.savefig(save_path, format="pdf", bbox_inches="tight", pad_inches=0.1)

    plt.show()
    plt.close(fig)


def process_dataset(dataset):
    """Generate every figure and the value table for one configured dataset.

    Runs, in order: the value plot, the sensitivity plot, the marginal
    contribution computation and plot, the rank plot, and finally prints
    the value table.

    Args:
        dataset: Key into ``cfg["datasets"]``.

    Raises:
        ValueError: If ``dataset`` is not a key of ``cfg["datasets"]``.
    """
    if dataset not in cfg["datasets"]:
        raise ValueError(f"Unknown dataset: {dataset}. Available: {list(cfg['datasets'].keys())}")

    ds_cfg = cfg["datasets"][dataset]
    save_dir = ds_cfg["save_dir"]
    npp = cfg["n_per_provider"]
    lam_bases = cfg["lam_bases_value"]

    print(f"Processing {dataset.upper()}...")

    # Plot value
    results = load_results(
        save_dir,
        n_per_provider=npp,
        methods=["SV", "WSV", "PSV", "PASV"],
        lam_bases=lam_bases,
        lam_exponents=ds_cfg["lam_exp_value"],
    )
    results_plot = order_value_results(results, lam_bases)

    value_pdf = out_path(ds_cfg, "value_pdf")
    yb, yt = ds_cfg["value_box"]
    plot_value(
        results_plot,
        figsize=(15, 2.25),
        inset_anchor=cfg["inset_anchor"],
        box_y_bottom=yb,
        box_y_top=yt,
        inset_ytick_fontsize=7,
        hide_inset_xtick=True,
        lam_bases=lam_bases,
        save_path=value_pdf,
    )
    plt.close("all")
    print(f"  - Value plot saved: {value_pdf}")

    # Plot sensitivity
    sensitivity_pdf = out_path(ds_cfg, "sensitivity_pdf")
    plot_sensitivity(
        save_dir=save_dir,
        figsize=(15, 2.25),
        n_per_provider=npp,
        bases_all=cfg["bases_all"],
        ylim=ds_cfg["sensitivity_ylim"],
        marker_size=2,
        save_path=sensitivity_pdf,
    )
    print(f"  - Sensitivity plot saved: {sensitivity_pdf}")

    # Compute and plot marginal contributions
    # NOTE(review): the "save" root and the "100" in the result file name are
    # literals; presumably they mirror ds_cfg["save_dir"] and n_per_provider --
    # confirm before changing either setting.
    marginal_pkl = f"save/{dataset}/marginal.pkl"
    compute_marginal(
        save_root="save",
        dataset=dataset,
        result_filename="result_PSV_100.pkl",
        min_size=20,
        out_pkl_path=marginal_pkl,
    )

    # Resolve the output through the configuration, as the other figures do,
    # instead of hard-coding the "figure/" directory.
    marginal_pdf = out_path(ds_cfg, "marginal_pdf")
    plot_marginal(
        in_pkl_path=marginal_pkl,
        save_path=marginal_pdf,
        figsize=(15, 1.2),
        linewidth=1,
        markersize=0.0,
        alpha_band=0.3,
    )
    print(f"  - Marginal plot saved: {marginal_pdf}")

    # Plot rank
    # NOTE(review): every dataset configures rank_pdf as "rank.pdf", so a run
    # over all datasets presumably overwrites the earlier rank figure --
    # verify against out_path and give each dataset its own file name.
    rank_pdf = out_path(ds_cfg, "rank_pdf")
    plot_rank(
        save_dir=save_dir,
        n_per_provider=npp,
        bases_all=cfg["bases_all"],
        marker_size=2,
        figsize=(15, 2.25),
        legend_anchor=(0.5, -0.23),
        save_path=rank_pdf,
    )
    print(f"  - Rank plot saved: {rank_pdf}")

    # Print table (reuse results from value plot). The ordering is derived
    # again from the raw results rather than reusing the object that was
    # handed to plot_value.
    results_plot = order_value_results(results, lam_bases)

    table_df = table_value_grid(results_plot, lam_bases=lam_bases)
    print(f"\n{dataset.upper()} Table:")
    print(table_df)
    print("\n")


def main():
    """Parse ``--dataset`` and generate the figures for the selection.

    The accepted dataset names are the keys of ``cfg["datasets"]`` plus
    ``"all"`` (the default), which processes every configured dataset in
    configuration order.
    """
    # Take the names from the configuration so the CLI cannot drift from
    # what process_dataset accepts.
    available = list(cfg["datasets"])

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--dataset",
        type=str,
        choices=available + ["all"],
        default="all",
    )
    args = parser.parse_args()

    datasets = available if args.dataset == "all" else [args.dataset]

    for dataset in datasets:
        process_dataset(dataset)


# Script entry point: run the CLI only when executed directly, not on import.
if __name__ == "__main__":
    main()