import wandb
import pandas as pd
from tqdm import tqdm
import contextlib
import os
import json

# Read the tab-separated species table and collect the scientific names of
# every species whose movement pattern is exactly "Full Migrant".
movement_df = pd.read_csv("birdsnap-species-movement-patterns.csv", sep="\t")
migratory_species = set(
    movement_df.loc[movement_df["movement_patterns"] == "Full Migrant", "scientific"]
)

def is_migratory(scientific_name):
    """Return True if `scientific_name` is in the `migratory_species` set.

    The lookup is an exact membership test against the module-level set, so
    the name must match an entry exactly.
    """
    found = scientific_name in migratory_species
    return found

# Connect to the W&B public API and fetch the runs to export.
api = wandb.Api()
# NOTE(review): the runs path is an empty string as written — presumably a
# placeholder for "entity/project"; confirm before running.
runs = api.runs("")

filter_string = ""  # Change this to your filter string

# Exported predictions and configs are written under
# artifacts/predictions_<filter_string>, so create that directory up front.
os.makedirs(os.path.join("artifacts", f"predictions_{filter_string}"), exist_ok=True)

def _load_predictions_table(run):
    """Return a run's ``predictions_sample`` table as a DataFrame, or None.

    Walks the run's logged artifacts and uses the first one whose name
    contains ``predictions_sample`` and that provides
    ``predictions_sample.table.json``. A copy already on disk under
    ``artifacts/<artifact name with ':' replaced by '_'>`` is reused;
    otherwise the artifact is downloaded to that directory first.

    Returns None when no matching artifact yields the JSON file.
    """
    for artifact in run.logged_artifacts():
        # Use only the predictions_sample artifact for each run
        if "predictions_sample" not in artifact.name:
            continue

        artifact_dir = os.path.join("artifacts", artifact.name.replace(":", "_"))
        json_path = os.path.join(artifact_dir, "predictions_sample.table.json")

        if not os.path.exists(json_path):
            # Not on disk yet: download with stdout/stderr redirected to
            # os.devnull. If that attempt raises, retry once with output
            # visible; a second failure propagates to the caller.
            try:
                with open(os.devnull, "w") as fnull:
                    with contextlib.redirect_stdout(fnull), contextlib.redirect_stderr(fnull):
                        pred_dir = artifact.download(root=artifact_dir)
            except Exception:
                pred_dir = artifact.download(root=artifact_dir)
            json_path = os.path.join(pred_dir, "predictions_sample.table.json")
            if not os.path.exists(json_path):
                # This artifact does not contain the table; try the next one.
                continue

        # Read as UTF-8 explicitly so the result does not depend on the
        # platform's default text encoding.
        with open(json_path, "r", encoding="utf-8") as f:
            data = json.load(f)

        # The table JSON stores the header under "columns" and rows under "data".
        return pd.DataFrame(data["data"], columns=data["columns"])

    return None


# Every export goes to the same directory, so resolve and create it once.
out_dir = os.path.join("artifacts", f"predictions_{filter_string}")
os.makedirs(out_dir, exist_ok=True)

for i, run in enumerate(tqdm(runs, desc="Processing runs")):
    # Substring match on the run name; an empty filter_string keeps every run.
    if filter_string not in run.name:
        continue

    pred_df = _load_predictions_table(run)
    if pred_df is None:
        continue

    # Save pred_df to a CSV file for loading later. `i` is the run's position
    # in `runs` (skipped runs included), so output numbering can have gaps.
    out_path = os.path.join(out_dir, f"predictions_{i}.csv")
    pred_df.to_csv(out_path, index=False)
    print(f"Saved predictions to {out_path}")

    # Save run config as JSON next to the predictions, sharing the same index.
    config_path = os.path.join(out_dir, f"config_{i}.json")
    with open(config_path, "w", encoding="utf-8") as f:
        json.dump(dict(run.config), f, indent=2)
    print(f"Saved config to {config_path}")
