import csv
from collections import defaultdict
from pathlib import Path
from typing import Optional

import wandb
from wandb.apis.public.runs import Run


def extract_data(
    run: Run,
    *,
    samples: int = 1000,
    metrics: Optional[dict[str, str]] = None,
) -> dict[str, list[float]]:
    """Collect the sampled metric history of a W&B run as named columns.

    Args:
        run: The W&B run whose history is read via ``run.history``.
        samples: Value passed as ``samples`` to ``run.history`` (number of
            history rows to request).
        metrics: Mapping from output column name to W&B history key. Defaults
            to the two validation optimality-gap metrics this script was
            written for.

    Returns:
        A plain dict whose first column is ``"step"`` (from ``_step``),
        followed by one column per entry of ``metrics`` in mapping order.
        All columns have equal length, and they exist even when the run has
        no matching history.

    Raises:
        ValueError: If ``metrics`` tries to define a column named ``"step"``.
    """
    if metrics is None:
        # NOTE(review): "tsp-val.opt-gap" is labelled as TSP-100 here; presumably
        # the un-suffixed validation set is TSP-100 — confirm against the run config.
        metrics = {
            "tsp-20": "tsp-20-val.opt-gap",
            "tsp-100": "tsp-val.opt-gap",
        }
    if "step" in metrics:
        raise ValueError('"step" is reserved for the _step column')

    # Pre-create every column so an empty history still yields a full header
    # downstream, and so a typo on read raises KeyError instead of silently
    # creating an empty column (as a defaultdict would).
    history: dict[str, list[float]] = {"step": []}
    history.update({column: [] for column in metrics})

    keys = ["_step", *metrics.values()]
    for sample in run.history(samples=samples, keys=keys, pandas=False):
        history["step"].append(sample["_step"])
        for column, key in metrics.items():
            history[column].append(sample[key])

    return history


def save_data(history: dict[str, list[float]], output: Path) -> None:
    """Write ``history`` as CSV to ``<output>/history.csv``.

    The first row holds the column names in ``history``'s insertion order.
    Each following row holds one sample: the i-th value of every column.

    Args:
        history: Column name -> values. All columns must have the same length.
            An empty dict produces a file containing only an empty header row.
        output: Directory to write into (a ``str`` is also accepted). It is
            created, including parents, if it does not exist.

    Raises:
        ValueError: If the columns have different lengths. This is checked
            before the output file is opened, so no partial file is written.
    """
    columns = list(history)
    lengths = {column: len(values) for column, values in history.items()}
    if len(set(lengths.values())) > 1:
        raise ValueError(f"history columns have mismatched lengths: {lengths}")

    output = Path(output)
    output.mkdir(parents=True, exist_ok=True)
    with open(output / "history.csv", "w", newline="", encoding="utf-8") as csvfile:
        writer = csv.writer(csvfile)
        writer.writerow(columns)
        # zip(*...) transposes the column lists into rows. It cannot truncate
        # here because all lengths were verified equal above.
        writer.writerows(zip(*(history[column] for column in columns)))


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(
        description="Export the optimality-gap history of a W&B run to CSV."
    )
    # Both arguments are required: without them the script would previously
    # query run ".../None" or fail with `None / "history.csv"`.
    parser.add_argument("--run-id", required=True, help="WandB run id")
    parser.add_argument("--output", type=Path, required=True, help="Output directory")
    parser.add_argument(
        "--project",
        default="neuralcombopt/tsp-equivariant",
        help="WandB '<entity>/<project>' path containing the run",
    )
    args = parser.parse_args()

    api = wandb.Api()
    run = api.run(f"{args.project}/{args.run_id}")
    history = extract_data(run)
    save_data(history, args.output)
