import wandb
import os
import yaml
from omegaconf import OmegaConf
import omegaconf
from src.utils import setup

def get_run_path(run_id, project, entity):
    """Build a W&B run path of the form ``"entity/project/run_id"``.

    This is the path format accepted by ``wandb.Api().run``.

    Args:
        run_id: ID of the run.
        project: Project the run belongs to.
        entity: Entity (user or team) owning the project.

    Returns:
        The three components formatted and joined with ``/``.
    """
    parts = (f"{entity}", f"{project}", f"{run_id}")
    return "/".join(parts)

# Access a W&B run through the public API using its run path.
def get_run(run_path):
    """Fetch a run object through the W&B public API.

    Args:
        run_path: Path of the form ``"entity/project/run_id"``, e.g. as
            produced by ``get_run_path``.

    Returns:
        Whatever ``wandb.Api().run(run_path)`` returns.
    """
    return wandb.Api().run(run_path)

def resume_run(cfg):
    """Attach to the W&B run whose ID is stored in ``cfg.general.wandb_id``.

    The run is opened with ``resume="allow"``. Per the W&B docs, this resumes
    the run if it already exists and starts a new one with that ID otherwise.

    Args:
        cfg: Config object exposing ``cfg.general.wandb_id``,
            ``cfg.general.project`` and ``cfg.general.wandb_team``.

    Returns:
        The run returned by ``wandb.init``.

    Raises:
        ValueError: If ``cfg.general.wandb_id`` is ``None`` or empty.
    """
    general = cfg.general
    # Explicit raise instead of ``assert``: asserts are stripped under
    # ``python -O``, which would silently start a run with no ID.
    if general.wandb_id is None or general.wandb_id == "":
        raise ValueError("wandb_id must be set if wandb_resume is True")
    return wandb.init(
        id=general.wandb_id,
        project=general.project,
        entity=general.wandb_team,
        resume="allow",
    )

def download_config_to(savedir, run):
    """Write *run*'s config to ``<savedir>/config.yaml``, creating *savedir* if needed.

    The config is shared between processes, so several processes may call
    this for the same *savedir* at the same time.

    Args:
        savedir: Directory that receives ``config.yaml``; created if missing.
        run: Object whose ``config`` attribute is dumped as YAML.

    Returns:
        *savedir*, unchanged.
    """
    # exist_ok=True avoids the race in a separate exists()/makedirs() check:
    # another process creating the directory in between would otherwise
    # make makedirs raise FileExistsError.
    os.makedirs(savedir, exist_ok=True)
    config_path = os.path.join(savedir, 'config.yaml')
    with open(config_path, 'w', encoding='utf-8') as f:
        yaml.dump(run.config, f)
    return savedir

def load_config_from(savedir):
    """Load ``<savedir>/config.yaml`` with ``OmegaConf.load``.

    Args:
        savedir: Directory containing ``config.yaml``.

    Returns:
        The config object returned by ``OmegaConf.load``.
    """
    return OmegaConf.load(os.path.join(savedir, 'config.yaml'))

def download_checkpoint(savedir, run, epoch_num):
    """Download the ``eval_epoch{epoch_num}`` checkpoint artifact logged by *run*.

    Artifacts are matched by name prefix ``"eval_epoch{epoch_num}:"``. If
    several versions match, the last one yielded by
    ``run.logged_artifacts()`` is used (presumably the newest version;
    verify against the W&B API ordering). Only that artifact is downloaded.

    Args:
        savedir: Base directory. The artifact is downloaded into
            ``<savedir>/artifacts/<artifact name>/``.
        run: Run object exposing ``logged_artifacts()``.
        epoch_num: Epoch number embedded in the artifact name.

    Returns:
        Path to ``<savedir>/artifacts/<artifact name>/eval_epoch{epoch_num}.pt``.
        This assumes the artifact stores its checkpoint under that file name.

    Raises:
        ValueError: If no matching artifact was logged by the run.
    """
    artifact_name_prefix = f"eval_epoch{epoch_num}"

    # Select the last matching artifact first, so that only one version is
    # downloaded instead of every match.
    artifact = None
    for a in run.logged_artifacts():
        if a.name.startswith(artifact_name_prefix + ":"):
            artifact = a
    if artifact is None:
        raise ValueError(f"Artifact with prefix {artifact_name_prefix} not found for the specified run.")

    artifact_name = artifact.name
    # Download into the directory the returned path refers to. Passing
    # root=savedir would place the files directly in savedir, so the
    # returned path would not exist.
    download_dir = os.path.join(savedir, "artifacts", artifact_name)
    artifact.download(root=download_dir)

    return os.path.join(download_dir, artifact_name.split(":")[0] + ".pt")

def save_file(run_id, project, entity, file_path):
    """Upload *file_path* to the W&B run *run_id* (e.g. a modified config file).

    The run is opened with ``resume="allow"`` and finished when the ``with``
    block exits.

    Args:
        run_id: ID of the run to attach to.
        project: W&B project name.
        entity: W&B entity (user or team).
        file_path: Path of the file to upload.
    """
    with wandb.init(id=run_id, project=project, entity=entity, resume="allow") as run:
        # Save through the run handle opened above rather than the global
        # ``wandb.save``. No working-directory change is involved here.
        run.save(file_path)

def save_results(run_id, project, entity, results, directory, file_prefix, file_postfix):
    """Log per-epoch results to the W&B run *run_id* and upload accompanying files.

    Each ``(epoch, value)`` entry of *results* is logged with its own
    ``log`` call as ``{str(epoch): value}``. If
    ``<directory>/<file_prefix><epoch><file_postfix>`` exists, that file is
    uploaded to the run as well.

    Args:
        run_id: ID of the run to attach to (opened with ``resume="allow"``).
        project: W&B project name.
        entity: W&B entity (user or team).
        results: Mapping from epoch number to the result to log.
        directory: Directory searched for accompanying files.
        file_prefix: File-name prefix placed before the epoch number.
        file_postfix: File-name suffix placed after the epoch number.
    """
    with wandb.init(id=run_id, project=project, entity=entity, resume="allow") as run:
        for epoch, value in results.items():
            # wandb.log rejects non-string keys, and epochs are typically
            # ints, so stringify. String keys are unaffected by str().
            run.log({str(epoch): value})
            accompanying_file = os.path.join(directory, f"{file_prefix}{epoch}{file_postfix}")
            if os.path.exists(accompanying_file):
                run.save(accompanying_file)