import numpy as np
from matplotlib import pyplot as plt
import json
import os

# --- Configuration ---
# Evaluation datasets, one per figure column: result-directory key -> column title.
eval_ons = dict(
    imagenet1k="ImageNet-1K",
    food101="Food101",
    cifar100="CIFAR100",
    objectnet="ObjectNet",
    imagenetr="ImageNet-R",
    imagenets="ImageNet-Sketch",
)

# Proxy (training-data) datasets, one per figure row: result-directory key -> row label.
ds_names = dict(
    cc="Conceptual Captions",
    datacomp="DataComp",
    laion="LAION",
)

# Seed indices whose `<seed>.json` result files are averaged per curve.
seeds = range(0, 5)

# Methods to compare: result-directory key -> legend label.
method_names = dict(
    repvlm="REPVLM",
    probvlm="ProbVLM",
    mcdo="MC-Dropout",
)

# One row per proxy dataset, one column per evaluation dataset.
nrows = len(ds_names)
ncols = len(eval_ons)

# squeeze=False keeps `axes` a 2-D array even when there is only a single proxy
# dataset or a single evaluation set, so `axes[row][col]` indexing always works.
# y-limits are shared per column (same eval set), x-limits across the grid.
fig, axes = plt.subplots(
    nrows,
    ncols,
    figsize=(3 * ncols, 2.5 * nrows),
    sharex=True,
    sharey="col",
    squeeze=False,
)

# Convert dicts to lists of (key, display name) pairs for positional iteration.
proxy_items = list(ds_names.items())
eval_items = list(eval_ons.items())

# --- Main Loop ---
# Fixed colour per method so a method looks the same in every panel (and
# matches the shared legend) even when some panels have no data for it.
# "C0", "C1", ... are the default colour-cycle entries, so with complete data
# the colours are the same as matplotlib's automatic cycling would give.
method_colors = {key: f"C{i}" for i, key in enumerate(method_names)}


def _load_i2t_accs(eval_key, proxy_key, method_key):
    """Return the per-seed ``"i2t"`` entries for one (eval, proxy, method) cell.

    Seeds whose result file does not exist are skipped silently. Files that
    exist but cannot be read or parsed, or that lack an ``"i2t"`` entry, are
    skipped with a printed warning instead of being silently ignored.
    """
    collected = []
    for seed in seeds:
        path = f"results/cls/{eval_key}/{proxy_key}/{method_key}/{seed}.json"
        if not os.path.exists(path):
            continue
        try:
            # `with` guarantees the file handle is closed.
            with open(path) as f:
                collected.append(json.load(f)["i2t"])
        except (OSError, ValueError, KeyError, TypeError) as err:
            # Best-effort: one corrupt/partial run should not abort the figure,
            # but it should not vanish without a trace either.
            print(f"Warning: skipping {path}: {err!r}")
    return collected


for row_idx, (proxy_key, proxy_name) in enumerate(proxy_items):
    for col_idx, (eval_key, eval_name) in enumerate(eval_items):
        ax = axes[row_idx][col_idx]

        # 1. Collect data: mean/std over available seeds, in percent.
        values = {}
        for method_key in method_names:
            i2t_accs = _load_i2t_accs(eval_key, proxy_key, method_key)
            if not i2t_accs:
                continue
            # NOTE(review): assumes every seed stores an equally long accuracy
            # curve; ragged curves make np.array fail on recent numpy — confirm.
            i2t_accs = np.array(i2t_accs)
            values[method_key] = {
                "mean": i2t_accs.mean(axis=0) * 100,  # Convert to percentage
                "std": i2t_accs.std(axis=0) * 100,    # Convert to percentage
            }

        # 2. Plot one mean line plus a +/- 1 std band per method.
        for method_key, method_label in method_names.items():
            if method_key not in values:
                continue
            data = values[method_key]
            color = method_colors[method_key]

            # X-axis: rejected fraction, spread evenly from 0% to 90% over the
            # stored points (assumed to be evenly spaced rejection levels).
            rej_fracs = np.linspace(0, 90, len(data["mean"]))

            ax.plot(
                rej_fracs, data["mean"],
                label=method_label, marker="o", markersize=3, color=color,
            )
            ax.fill_between(
                rej_fracs,
                data["mean"] - data["std"],
                data["mean"] + data["std"],
                color=color,
                alpha=0.2,
            )

        # 3. Formatting
        ax.grid(True, linestyle='--', alpha=0.5)
        ax.set_xlim(0, 90)

        # A. Column titles (eval dataset name) - only on the top row.
        if row_idx == 0:
            ax.set_title(eval_name, fontsize="x-large", pad=10)

        # B/C. Left column only: row label (proxy name) placed to the LEFT of
        # the first column, in axes coordinates, plus the y-axis label.
        if col_idx == 0:
            ax.text(-0.3, 0.5, proxy_name, transform=ax.transAxes,
                    rotation=90, va='center', ha='left', fontsize="x-large")
            ax.set_ylabel("Accuracy (%)", fontsize="large")

        # D. X-axis label - only on the bottom row.
        if row_idx == nrows - 1:
            ax.set_xlabel("Rejected Fraction (%)", fontsize="large")

# --- Global Legend ---
# Gather legend entries from every panel rather than a single one, so a method
# missing from one particular panel still appears in the shared legend.
# Entries are de-duplicated by label and ordered as in `method_names`.
legend_entries = {}
for panel in fig.axes:
    for handle, label in zip(*panel.get_legend_handles_labels()):
        legend_entries.setdefault(label, handle)
legend_labels = [label for label in method_names.values() if label in legend_entries]
handles = [legend_entries[label] for label in legend_labels]

if handles:
    # Anchored at the bottom centre of the figure and hanging downwards; the
    # tight bounding box used when saving keeps it in the output file.
    fig.legend(
        handles, legend_labels,
        loc='upper center', bbox_to_anchor=(0.5, 0.02),
        ncol=len(handles), fontsize="large",
    )
else:
    print("Warning: no results found; saving figure without a legend")

# tight_layout already reserves room for the column titles; the legend sits at
# the bottom, so no extra top margin is needed.
fig.tight_layout()

# --- Save ---
save_path = "results.pdf"
fig.savefig(save_path, bbox_inches="tight")
print(f"Saved stacked figure to {save_path}")
plt.close(fig)