from omegaconf import DictConfig
import wandb
from tqdm import tqdm

import utils

import logging
# Module-level logger named after this module.
log = logging.getLogger(__name__)

# Passed as the `samples` argument to `run.history(...)` inside `run`, i.e. the
# number of history rows requested from wandb.
# NOTE(review): presumably chosen to exceed the number of logged steps so the
# history is not subsampled — confirm against the training configuration.
MAX_STEPS = 300_000

def run(config: DictConfig):
    """Log the mean ``predictions/flip_rate`` recorded by a wandb run.

    The config is first passed through ``utils.preprocess_config``. The run is
    then looked up via the wandb public API under
    ``<config.wandb.entity>/<config.wandb.project>/<config.exp.run_id>``, its
    history for the ``predictions/flip_rate`` key is fetched (requesting up to
    ``MAX_STEPS`` samples), rows containing NaN are dropped, and the mean of
    the remaining values is written to the log.

    Args:
        config: Configuration providing ``exp.run_id``, ``wandb.entity`` and
            ``wandb.project``.

    Returns:
        Always ``0``.

    Raises:
        ValueError: If ``config.exp.run_id`` is ``None``.
        KeyError: If the fetched history has no ``predictions/flip_rate``
            column (e.g. the run never logged that metric).
    """
    utils.preprocess_config(config)
    # Explicit check instead of `assert`, which is removed under `python -O`.
    if config.exp.run_id is None:
        raise ValueError("Please provide a run_id to calculate the metric")

    log.info("Connecting to wandb")
    project_name = f"{config.wandb.entity}/{config.wandb.project}"
    run_path = f"{project_name}/{config.exp.run_id}"

    api = wandb.Api()
    # Named `wandb_run` so the local does not shadow this function's own name.
    wandb_run = api.run(run_path)

    metric_key = "predictions/flip_rate"
    history = wandb_run.history(samples=MAX_STEPS, keys=[metric_key])

    # When no row contains the requested key the returned frame lacks the
    # column entirely; fail with a message that names the metric and the run.
    if metric_key not in history.columns:
        raise KeyError(f"Metric '{metric_key}' not found in history of run {run_path}")

    series = history.dropna()[metric_key]

    if series.empty:
        # The mean of an empty series is NaN; report that explicitly rather
        # than logging a meaningless number.
        log.warning(f"No non-NaN values of '{metric_key}' found for run {run_path}")
        return 0

    log.info(f"Flip rate: {series.mean()}")

    return 0