
import wandb
from tqdm import tqdm
from datetime import datetime, timedelta

def extract_generation_idx(config):
    """Return the generation index stored in a run config, or ``None``.

    The index is read from ``config["task"]["generation"]["generation_idx"]``.
    ``None`` is returned when any level of that path is missing or is not
    subscriptable, so a single malformed run cannot abort the whole crawl.
    """
    try:
        return config["task"]["generation"]["generation_idx"]
    except (KeyError, TypeError):
        return None


def find_remaining_generation_idxs(run_configs, num_generation_idxs: int):
    """Return the sorted generation indices that no run config covers.

    Args:
        run_configs: Iterable of run config mappings; each is inspected with
            ``extract_generation_idx``.
        num_generation_idxs: Size of the index range; candidates are
            ``0 .. num_generation_idxs - 1``.

    Returns:
        A sorted list of the candidate indices not found in ``run_configs``.
        Configs without a readable index, or with an index outside the
        candidate range, leave the result unchanged.
    """
    remaining = set(range(num_generation_idxs))
    for config in run_configs:
        # discard() is a no-op for None and for indices outside the range.
        remaining.discard(extract_generation_idx(config))
    return sorted(remaining)


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--beta", type=float)
    parser.add_argument("--policy_name", type=str)
    parser.add_argument("--seed", type=int)
    parser.add_argument("--num_generation_idxs", type=int, default=200)
    args = parser.parse_args()

    api = wandb.Api()
    runs = api.runs(
        "anonymous/llm-exploration",
        filters={
            "config.coreset.elliptical.beta": args.beta,
            "config.policy.name": args.policy_name,
            "config.seed": args.seed,
            "tags": {"$in": ["gpt-4o-mini-hard"]},
        },
    )

    # Feed tqdm a generator (as the original enumerate() did) so it does not
    # try to call len() on the wandb result object.
    run_configs = (run.config for run in runs)
    remaining_generation_idxs = find_remaining_generation_idxs(
        tqdm(run_configs, desc="Crawling wandb for existing runs ..."),
        args.num_generation_idxs,
    )

    # print to allow shell script to save; sorted so the output is stable
    print(",".join(map(str, remaining_generation_idxs)))

    
