# %%
import wandb
import pandas as pd
import matplotlib.pyplot as plt

# W&B project to read runs from. Left blank here; fill in before running.
wandb_project = ""
api = wandb.Api()
runs = api.runs(wandb_project)

# Substring a run's name must contain to be kept; "" matches every run.
filter_string = ""

# One row per kept run: the regularization weight, the seed, and the two
# final accuracy metrics read from the run summary.
records = []
for run in runs:
    if filter_string not in run.name:
        continue
    config = run.config
    summary = run.summary
    # Look each value up once, then reuse it for both the log line and the row.
    record = {
        "ortho_weight": config.get("ortho_weight"),
        "seed": config.get("seed"),
        "combined accuracy": summary.get("Top-1 Combined Accuracy"),
        "location accuracy": summary.get("Top-1 Location Accuracy"),
    }
    print(
        "ortho_weight:", record["ortho_weight"],
        "seed:", record["seed"],
        "combined accuracy:", record["combined accuracy"],
        "location accuracy:", record["location accuracy"],
    )
    records.append(record)

if not records:
    print(f"No runs in {wandb_project!r} matched filter {filter_string!r}")

# Passing the columns explicitly keeps the schema stable even when no run was
# kept; without it an empty frame has no "ortho_weight" column and the filter
# below raises KeyError.
df = pd.DataFrame(
    records,
    columns=["ortho_weight", "seed", "combined accuracy", "location accuracy"],
)

# Drop rows where ortho_weight is missing
df = df[df["ortho_weight"].notnull()]

# %%
# Summary statistics per regularization weight, computed over all kept runs.
by_weight = df.groupby("ortho_weight")

# Combined accuracy: mean and standard deviation.
agg_combined = by_weight["combined accuracy"].agg(["mean", "std"]).reset_index()

# Location accuracy: also keep the number of non-missing values so the
# standard error (std / sqrt(count)) can be derived.
agg_location = by_weight["location accuracy"].agg(["mean", "std", "count"]).reset_index()
agg_location["se"] = agg_location["std"] / agg_location["count"].pow(0.5)

# %%
agg_combined

# %%
agg_location

# %%
# Plotting
# Combined accuracy per regularization weight: one marker per weight at the
# mean, with whiskers of one standard deviation. Markers and whiskers use the
# same color so they read as a single series (as in the location plot).
plt.scatter(
    agg_combined["ortho_weight"],
    agg_combined["mean"],
    label='Coordinate & Visual Embeddings \n (10 Random Seeds)',
    color='steelblue',
)
plt.errorbar(
    agg_combined["ortho_weight"],
    agg_combined["mean"],
    yerr=agg_combined["std"],
    fmt='none',
    capsize=5,
    color='steelblue',
)

# Add horizontal line for visual embedding performance at 0.7
plt.axhline(0.7, xmin=0, xmax=1, color='orange', linestyle='--',
    label='Visual Embedding')
plt.plot([0], [0.7], 'r.')  # Optional: mark the start at x=0

# Set xscale to symlog to handle zero and positive values
plt.xscale("symlog", linthresh=1e-5)

plt.xlabel("Regularization Parameter")
plt.ylabel("Top-1 Accuracy (mean ± std)")
plt.title("Regularization Parameter vs Predictive Performance")
plt.grid(True, which="both", ls="--", lw=0.5)

# Set xticks to include 0 and all unique ortho_weight values, but avoid overlap:
# walking the sorted values, a tick is kept only if it is more than min_dist
# away from the previously kept one.
xticks = sorted(set(agg_combined["ortho_weight"].tolist() + [0]))
min_dist = 1e-6
filtered_xticks = [xticks[0]]
for x in xticks[1:]:
    if abs(x - filtered_xticks[-1]) > min_dist:
        filtered_xticks.append(x)
plt.xticks(filtered_xticks, labels=[str(x) for x in filtered_xticks], rotation=30, ha='right')

# Show 0.76 as a tick on y axis and leave some vertical space above it.
# Compare with a tolerance rather than exact equality: the automatic tick
# locations are floats, and a value a rounding error away from 0.76 would
# otherwise end up as a duplicate tick.
yticks = plt.yticks()[0].tolist()
if all(abs(y - 0.76) > 1e-6 for y in yticks):
    yticks.append(0.76)
# Exclude 0.69 from y-ticks
yticks = [y for y in yticks if abs(y - 0.69) > 1e-6]
plt.yticks(sorted(yticks))

# Leave vertical space above 0.76
plt.ylim(top=0.762)

plt.tight_layout()
plt.legend(loc='center left', bbox_to_anchor=(1.05, 0.9))  # Increased horizontal space
plt.show()



# %%
# Location accuracy per regularization weight: one marker per weight at the
# mean, with standard-error whiskers, both drawn in the same muted blue.
loc_weights = agg_location["ortho_weight"]
loc_means = agg_location["mean"]
series_color = 'steelblue'

plt.scatter(
    loc_weights,
    loc_means,
    color=series_color,
    label='Coordinate & Visual Embeddings \n (10 Random Seeds)',
)
# Whiskers use the standard error rather than the standard deviation.
plt.errorbar(loc_weights, loc_means, yerr=agg_location["se"],
             fmt='none', capsize=5, color=series_color)

# symlog keeps a weight of exactly zero plottable alongside the positive ones.
plt.xscale("symlog", linthresh=1e-5)

plt.title("Regularization Parameter vs Location Predictive Performance")
plt.xlabel("Regularization Parameter")
plt.ylabel("Top-1 Location Accuracy (mean ± SE)")
plt.grid(True, which="both", ls="--", lw=0.5)

# Tick at 0 and at every distinct weight; walking the sorted candidates, any
# value within min_dist of the previously kept tick is skipped.
tick_candidates = agg_location["ortho_weight"].tolist()
tick_candidates.append(0)
xticks = sorted(set(tick_candidates))
min_dist = 1e-6
filtered_xticks = xticks[:1]
for tick in xticks[1:]:
    gap = abs(tick - filtered_xticks[-1])
    if gap > min_dist:
        filtered_xticks.append(tick)
tick_labels = list(map(str, filtered_xticks))
plt.xticks(filtered_xticks, labels=tick_labels, rotation=30, ha='right')

plt.tight_layout()
# Anchor the legend just outside the right edge of the axes.
plt.legend(loc='center left', bbox_to_anchor=(1.05, 0.9))
plt.show()

# %%
