from collections import defaultdict
from pprint import pprint

import wandb

# Run-config key whose value is used to bucket runs in the printed report.
GROUPBY_CONFIG_KEY: str = "memory_type"

# W&B location to query (entity/project, plus the run group to filter on).
# Replace the placeholder strings before running the script.
WANDB_ENTITY: str = "ENTER WANDB USERNAME HERE"
WANDB_PROJECT: str = "ENTER WANDB PROJECT HERE"
WANDB_GROUP: str = "ENTER WANDB GROUP HERE"


def main():
    """Print a grouped listing of matching W&B runs and the smallest episode count.

    Queries ``WANDB_ENTITY/WANDB_PROJECT`` for runs in ``WANDB_GROUP`` whose
    config matches the filters below. The runs are grouped by
    ``run.config[GROUPBY_CONFIG_KEY]``. Each group gets a header line,
    followed by one line per run::

        "<run name>",        # <summary 'episode'>   <reward summary 'max'>

    Finally the minimum ``episode`` summary value over all runs is printed.

    A run may lack the group-by key, the ``episode`` summary, or the
    ``validation/len300/reward_mean`` summary. Such runs are still listed,
    with the missing value shown as ``None``, instead of aborting the report
    with a ``KeyError``. Runs without an ``episode`` value are left out of
    the minimum.
    """
    from collections.abc import Mapping

    reward_key = "validation/len300/reward_mean"

    api = wandb.Api()
    # Filters use W&B's MongoDB-style query operators (e.g. "$ne", "$in").
    # To restrict the query to specific runs, add a filter such as
    # "display_name": {"$in": [<run names>]}.
    runs = api.runs(f"{WANDB_ENTITY}/{WANDB_PROJECT}", filters={
        "group": WANDB_GROUP,
        "config.memory_hidden_size": {"$ne": 3},
        "config.learning_rate": 0.001,
    })

    print(f"{len(runs)} runs found.")

    # Keys are raw config values (not necessarily strings). Values are the
    # public-API run objects yielded by ``api.runs``.
    groups: defaultdict[object, list] = defaultdict(list)
    for run in runs:
        groups[run.config.get(GROUPBY_CONFIG_KEY)].append(run)

    min_ep = float("inf")
    for group_value, run_list in groups.items():
        print(f"    # {group_value} ({len(run_list)})")
        for run in run_list:
            summary = run.summary_metrics
            episode = summary.get("episode")
            reward = summary.get(reward_key)
            # This summary is expected to be a nested dict holding a "max"
            # entry (presumably logged via define_metric(summary="max")).
            # Fall back to the raw value if it is a plain scalar.
            reward_max = reward.get("max") if isinstance(reward, Mapping) else reward
            print(f"    \"{run.name}\",        # {episode}   {reward_max}")
            if episode is not None:
                min_ep = min(min_ep, episode)

    if min_ep == float("inf"):
        print("No run reported an 'episode' summary value.")
    else:
        print(f"{min_ep}")


if __name__ == "__main__":
    # Only query W&B and print the report when run as a script, not on import.
    main()
