# %%
import math
import matplotlib.pyplot as plt
import wandb
import numpy as np
from tqdm import tqdm
import seaborn as sns

# download results from wandb and plot

# %%
api = wandb.Api()
# Get the runs tagged with "final" from anonymous/inr-classification
project_name = "anonymous/inr-classification"
runs = api.runs(project_name, {"tags": "final"})

# Metric to report and the global_step at which to read it.
METRIC = "test/best_acc"
FINAL_STEP = 63999

# Extract METRIC at FINAL_STEP for every run, grouped by run name: runs that
# share a name are collected as repeated measurements of one configuration.
results = {}

for run in tqdm(runs):
    df = run.history(x_axis="global_step", keys=[METRIC])
    if {"global_step", METRIC} <= set(df.columns):
        rows = df.loc[df["global_step"] == FINAL_STEP, METRIC]
    else:
        rows = []  # nothing logged for this metric
    if len(rows) != 1:
        # NOTE(review): run.history() returns a *sampled* subset of the logged
        # rows by default; if FINAL_STEP is missing even though the run
        # finished, run.scan_history() reads the unsampled history.
        raise ValueError(
            f"Run {run.name!r} ({run.id}) has {len(rows)} rows for {METRIC!r} "
            f"at global_step={FINAL_STEP}; expected exactly 1."
        )
    results.setdefault(run.name, []).append(rows.item())

# Every configuration must have the same number of runs, otherwise the
# mean/std printed below are not comparable across configurations.
run_counts = {name: len(values) for name, values in results.items()}
if len(set(run_counts.values())) > 1:
    raise ValueError(f"Unequal number of runs per configuration: {run_counts}")

for name, result in results.items():
    values = np.array(result)
    # ddof=1: sample standard deviation across the repeated runs.
    print(name, values.mean(), values.std(ddof=1))


# %%
# Map each wandb run name to a (model name, number of probes) key.
# Run names without a "pf" suffix are mapped to 0 probes.
remap_names = {
    'fmnist_rt_drop20_pf8': ('Relation Transformer', 8),
    'fmnist_rt_drop20_pf32': ('Relation Transformer', 32),
    'fmnist_pna_dhid64_drop20_pf8': ('PNA', 8),
    'fmnist_pna_dhid64_drop20_pf32': ('PNA', 32),
    'fmnist_rt_drop20_pf64': ('Relation Transformer', 64),
    'fmnist_pna_dhid64_drop20_pf64': ('PNA', 64),
    'fmnist_rt_drop20_pf16': ('Relation Transformer', 16),
    'fmnist_pna_dhid64_drop20_pf16': ('PNA', 16),
    'fmnist_pna_dhid64_drop20_pf4': ('PNA', 4),
    'fmnist_pna_dhid64_drop20': ('PNA', 0),
    'fmnist_rt_drop20_4pf': ('Relation Transformer', 4),
    'fmnist_rt_drop20': ('Relation Transformer', 0),
}
# Report every downloaded run that has no mapping, instead of failing with a
# bare KeyError on the first one.
unmapped = sorted(set(results) - set(remap_names))
if unmapped:
    raise KeyError(f"No entry in remap_names for run(s): {unmapped}")
results = {remap_names[name]: result for name, result in results.items()}
# %%
def plot_bars(data_dict, *, fontsize=36, ylim_bottom=0.5):
    """
    Plot the mean of each (name, probe) entry as a bar with a std error bar.

    Bars sharing a name are drawn side by side with zero margin, and each
    group is separated from the next by one empty bar slot. Bar colour
    encodes the probe value, and a legend maps colours to probe values.

    Args:
    - data_dict: Dictionary of format {(name, probe): [values]}
    - fontsize: Font size of the y label, x tick labels and legend title;
      y tick labels and legend entries use ``fontsize - 4``.
    - ylim_bottom: Lower limit of the y axis.
    """
    # Sort by name, then probe, so each group's bars are contiguous and
    # ordered by increasing probe value.
    items = sorted(data_dict.items(), key=lambda kv: (kv[0][0], kv[0][1]))

    names = [name for (name, _), _ in items]
    probes = [probe for (_, probe), _ in items]
    unique_names = sorted(set(names))
    means = [np.mean(values) for _, values in items]
    # Sample std (ddof=1), consistent with the per-configuration std printed
    # after download. A single value has no spread, so draw no error bar
    # instead of the NaN that ddof=1 would produce.
    stds = [np.std(values, ddof=1) if len(values) > 1 else 0.0 for _, values in items]

    fig, ax = plt.subplots(figsize=(10, 6))

    # Bar positions in slot units; groups follow the same sorted name order
    # as `items`, so positions line up with means/stds by index.
    width = 1
    bar_positions = []
    group_centers = []
    slot = 0
    for name in unique_names:
        n_bars = names.count(name)
        bar_positions.extend(i * width for i in range(slot, slot + n_bars))
        group_centers.append((slot + (n_bars - 1) / 2) * width)
        slot += n_bars + 1  # leave one empty slot between groups

    # Sequential single-hue palette with two extra shades; the two lightest
    # are skipped (presumably so the lightest bar stays visible on white).
    unique_probes = sorted(set(probes))
    palette = sns.color_palette("Blues", len(unique_probes) + 2)
    color_mapping = {probe: palette[i + 2] for i, probe in enumerate(unique_probes)}

    # Label only the first bar of each probe value so the legend lists each
    # probe value once; matplotlib omits empty labels from the legend.
    labelled = set()
    for pos, mean, std, probe in zip(bar_positions, means, stds, probes):
        label = "" if probe in labelled else f"{probe}"
        labelled.add(probe)
        ax.bar(pos, mean, width=width, alpha=1, yerr=std,
               color=color_mapping[probe], label=label)

    # Axis and labeling adjustments
    ax.set_ylabel('Accuracy', fontsize=fontsize)
    ax.set_xticks(group_centers)
    # Abbreviate the long model name so tick labels fit under each group.
    tick_labels = ['RT' if name == 'Relation Transformer' else name for name in unique_names]
    ax.set_xticklabels(tick_labels, rotation=0, ha="center", fontsize=fontsize)

    # Aesthetics
    ax.yaxis.grid(True, linestyle='--', which='major', color='grey', alpha=.25)
    ax.tick_params(axis='y', labelsize=fontsize - 4)
    ax.set_ylim(bottom=ylim_bottom)
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)

    # Legend placed outside the axes, to the right of the plot.
    ax.legend(title="# Probes", loc="upper left", frameon=False,
              bbox_to_anchor=(1, 1), fontsize=fontsize - 4, title_fontsize=fontsize,
              fancybox=False, edgecolor=None, facecolor=None, shadow=False)

    plt.tight_layout()
    plt.show()



# %%
# Appears to be a pasted snapshot of `results` after remapping (keys and
# their order match remap_names). It is bound to a name so the figure can be
# regenerated without querying wandb; pass `results` instead to plot freshly
# downloaded values.
cached_results = {
    ('Relation Transformer', 8): [0.7494999766349792, 0.7045999765396118,
                                  0.7454999685287476, 0.7461999654769897],
    ('Relation Transformer', 32): [0.7547000050544739, 0.7197999954223633,
                                   0.7540000081062317, 0.7450000047683716],
    ('PNA', 8): [0.724399983882904, 0.7342000007629395,
                 0.7319999933242798, 0.7197999954223633],
    ('PNA', 32): [0.7425000071525574, 0.7426999807357788,
                  0.724399983882904, 0.7379999756813049],
    ('Relation Transformer', 64): [0.7567999958992004, 0.7572999596595764,
                                   0.7448999881744385, 0.7366999983787537],
    ('PNA', 64): [0.7411999702453613, 0.7407000064849854,
                  0.7383999824523926, 0.7482999563217163],
    ('Relation Transformer', 16): [0.7422999739646912, 0.7545999884605408,
                                   0.751800000667572, 0.7170000076293945],
    ('PNA', 16): [0.7267000079154968, 0.7335000038146973,
                  0.7319999933242798, 0.7330999970436096],
    ('PNA', 4): [0.7184000015258789, 0.717199981212616,
                 0.7252999544143677, 0.7166999578475952],
    ('PNA', 0): [0.6796999573707581, 0.6832999587059021,
                 0.6769999861717224, 0.6803999543190002],
    ('Relation Transformer', 4): [0.740399956703186, 0.7329999804496765,
                                  0.7407999634742737, 0.7371000051498413],
    ('Relation Transformer', 0): [0.7196999788284302, 0.7364999651908875,
                                  0.7251999974250793, 0.7250999808311462],
}
plot_bars(cached_results)
