from pathlib import Path

import matplotlib.pyplot as plt
import wandb

def get_loss_curve(project_name, run_id, field_name='training_loss'):
    """
    Collect the logged values of one metric for a W&B run.

    The run is addressed as ``"<project_name>/<run_id>"`` through the W&B
    public API, and its history is fetched once.

    Args:
        project_name (str): The name of the W&B project.
        run_id (str): The ID of the W&B run.
        field_name (str): The history column to extract. Default is 'training_loss'.

    Returns:
        list: The non-missing values of ``field_name``, in the row order
        returned by ``run.history()``. The list may be empty if the column
        exists but holds only missing values.

    Raises:
        ValueError: If ``field_name`` is not a column of the run history.
    """
    api = wandb.Api()
    run = api.run(f"{project_name}/{run_id}")

    # Fetch the history once and reuse it: every ``run.history()`` call
    # re-queries the W&B API, so calling it twice doubles the network cost
    # and could observe two different snapshots of a live run.
    history = run.history()

    if field_name not in history:
        raise ValueError(f"Field '{field_name}' not found in run history.")

    # Drop missing entries (e.g. history rows where this metric was not logged).
    return history[field_name].dropna().tolist()

# W&B project that all compared runs were logged to.
PROJECT_NAME = "neurips-nesim-gpt-neo-125m-wikipedia"
# Destination of the bar chart; its parent directory is created if missing.
OUTPUT_PATH = Path("assets") / "losses.png"

# Display label -> W&B run ID. Insertion order defines the bar order.
wandb_run_ids = {
    "baseline": "zs3583o5",
    "topo-scale-5": "3jb9d7pe",
    "topo-scale-10": "xuln36m1",
    "topo-scale-50": "ru6zaviq",
}

# Fetch the training-loss curve of every run.
losses = {}
for label, run_id in wandb_run_ids.items():
    print(f"Running wandb api request for run: {label} id:{run_id}")
    losses[label] = get_loss_curve(
        project_name=PROJECT_NAME,
        run_id=run_id,
        field_name="training_loss",
    )

# A run whose 'training_loss' column exists but has no values yields an empty
# list; fail with a clear message instead of a bare IndexError below.
empty_runs = [label for label, curve in losses.items() if not curve]
if empty_runs:
    raise ValueError(f"No 'training_loss' values found for runs: {empty_runs}")

# Last logged loss of each run, in the same order as wandb_run_ids.
final_losses = [losses[label][-1] for label in wandb_run_ids]

# Bar labels and colors. The colors are positional, so the first run
# ("baseline") is drawn in orange and the remaining runs in blue/teal shades.
labels = list(wandb_run_ids.keys())
colors = ['orange', '#003f5c', '#005f73', '#0082c9']

# Plotting
fig = plt.figure()
bars = plt.bar(labels, final_losses, color=colors)
plt.grid(True, axis='y', linestyle='--', alpha=0.7)

# Annotate each bar with its value, centred horizontally on top of the bar.
for bar in bars:
    height = bar.get_height()
    plt.text(bar.get_x() + bar.get_width() / 2, height, f'{height:.4f}',
             ha='center', va='bottom')

# savefig does not create directories, so make sure the target folder exists.
OUTPUT_PATH.parent.mkdir(parents=True, exist_ok=True)
fig.savefig(OUTPUT_PATH, dpi=300)