import mlflow
import mlflow.artifacts
from pathlib import Path
from mlflow.entities import Run

from rddleval.artifacts import get_artifact_paths, get_runs_by_query


def download_artifacts(
    out_dir_name: str, runs: "pandas.DataFrame | list[Run]"
) -> None:
    """Download the ``pth`` artifacts of the given runs into a local directory.

    For each run, the artifacts listed directly under the run's artifact URI
    are inspected, and every artifact whose path contains the substring
    ``"pth"`` is downloaded into ``out_dir_name``.

    Args:
        out_dir_name: Local destination directory. It is created, including
            missing parents, if it does not exist.
        runs: The runs to fetch artifacts for. Either a DataFrame with
            ``run_id`` and ``artifact_uri`` columns, or a list of
            ``mlflow.entities.Run`` objects.

    NOTE(review): artifacts from different runs that share the same relative
    path will presumably overwrite each other in ``out_dir_name`` -- confirm
    whether a per-run subdirectory is wanted.
    """
    out_dir = Path(out_dir_name)
    out_dir.mkdir(parents=True, exist_ok=True)

    # Normalise both supported input shapes to (run_id, artifact_uri) pairs.
    if hasattr(runs, "iterrows"):
        run_refs = [
            (row["run_id"], row["artifact_uri"]) for _, row in runs.iterrows()
        ]
    else:
        run_refs = [(run.info.run_id, run.info.artifact_uri) for run in runs]

    for run_id, uri in run_refs:
        for artifact in mlflow.artifacts.list_artifacts(artifact_uri=uri):
            if "pth" not in artifact.path:
                continue
            # artifact_path must be the artifact's path inside the run (not a
            # local path), and dst_path is the local directory to write into.
            mlflow.artifacts.download_artifacts(
                run_id=run_id,
                artifact_path=artifact.path,
                dst_path=str(out_dir),
            )


def get_all_paths():
    """Write the artifact paths of the runs matching an empty query to JSON.

    The runs returned by ``get_runs_by_query("")`` are passed to
    ``get_artifact_paths`` and the result is serialised to
    ``model_paths.json`` in the current working directory, replacing any
    existing file of that name.
    """
    import json

    # NOTE(review): an empty query presumably selects every run -- confirm
    # against get_runs_by_query.
    every_run = get_runs_by_query("")
    artifact_paths = get_artifact_paths(every_run)

    with open("model_paths.json", "w") as out_file:
        json.dump(artifact_paths, out_file)


# Script entry point: refresh model_paths.json, then print the artifact paths
# of every run whose ``batch_id`` param equals the first command-line argument.
if __name__ == "__main__":
    import sys

    mlflow.set_tracking_uri("http://localhost:5000")
    # Example of a more specific filter string:
    # search_string = "params.anneal_lr = 'False' and params.remove_false = 'True' and params.total_timesteps =  '500000' and params.ent_coef='0.0'"

    # Runs before argument validation so that model_paths.json is refreshed
    # even when no batch id is supplied.
    get_all_paths()

    # Fail with a usage message instead of a bare IndexError on a missing arg.
    if len(sys.argv) < 2:
        sys.exit(f"usage: {sys.argv[0]} BATCH_ID")
    batch_id = sys.argv[1]

    runs = get_runs_by_query(f"params.batch_id = '{batch_id}'")
    paths = get_artifact_paths(runs)

    # for p in paths.values():
    #     print(p)

    # Alternative: select runs by explicit IDs and dump their paths instead.
    # runs = get_runs_by_ids([
    #     "0f5bd2102aea410887e750312f567812",
    #     "f42147d3b94a4205be7d631f851053dd",
    #     "e8aa2172fd1544acac44018985ebb6f3",
    #     "9fecda1e543643dbb1e1cdf2823ac094",
    #     "f862a63c7a284a7c917f39e3b0299a78",
    #     "88fcd8c2eede49bf8ba2a054927bcb7a",
    #     "ca8bfa8bf96b4a02b547f56023b3f967",
    #     "da5da7537cb24f77844660ee9709c5d6",
    #     "108f9c7bc3964a8e90022341f41dae4d",
    # ])

    # paths = get_artifact_paths(runs)
    # import json
    # with open("checkpoint_model_paths.json", "w") as f:
    #     json.dump(paths, f)

    for k, p in paths.items():
        print(k, p)
