import os
from typing import Callable, Optional

import flatten_dict
import wandb
from dotenv import load_dotenv
from omegaconf import DictConfig, OmegaConf
from wandb.sdk.wandb_run import Run

from utils.paths import ROOT, WANDB_PATH


def init(config: DictConfig, dir: Optional[str] = None, **kwargs) -> Run:
    """Initialize a Weights & Biases run.

    Loads environment variables from ``ROOT/.env`` and checks the required
    W&B settings. It then creates the run directory and starts a run whose
    config is flattened into path-style keys, using ``flatten_dict``'s
    ``"path"`` reducer.

    Args:
        config: Experiment configuration. An OmegaConf ``DictConfig`` is
            converted to plain Python containers, with interpolations
            resolved, before flattening. Any other mapping is flattened as-is.
        dir: Optional directory to save the wandb run, defaults to `WANDB_PATH`.
        **kwargs: Forwarded unchanged to ``wandb.init``. ``project`` must be
            given here unless ``WANDB_PROJECT`` is set in the environment
            (e.g. via ``.env``).

    Returns:
        The run returned by ``wandb.init``.

    Raises:
        ValueError: If ``WANDB_API_KEY`` is not set, or if no project name is
            available from ``kwargs`` or ``WANDB_PROJECT``.
    """
    # Load environment variables from `ROOT/.env`
    load_dotenv(dotenv_path=os.path.join(ROOT, ".env"))

    # Check if required environment variables are set
    if "WANDB_API_KEY" not in os.environ:
        raise ValueError("Please add `WANDB_API_KEY` to the file `.env`")

    # wandb.init falls back to the WANDB_PROJECT environment variable when no
    # `project` argument is passed, so either source is sufficient.
    if "project" not in kwargs and "WANDB_PROJECT" not in os.environ:
        raise ValueError(
            "Please pass `project=...` to `init` or add `WANDB_PROJECT` to the file `.env`"
        )

    # Make directory to save the wandb run
    run_dir = dir or WANDB_PATH
    os.makedirs(run_dir, exist_ok=True)

    # Convert OmegaConf containers (including nested ListConfig values) to plain
    # dicts/lists, so wandb receives JSON-friendly values with interpolations
    # resolved. Missing ("???") values still raise instead of being logged.
    if isinstance(config, DictConfig):
        config = OmegaConf.to_container(config, resolve=True, throw_on_missing=True)
    flat_config = flatten_dict.flatten(config, reducer="path")

    # Start a new wandb run
    return wandb.init(config=flat_config, dir=run_dir, **kwargs)


def save_artifact(
    run: Run,
    local_path: str,
    name: str,
    type: str,
    log: Callable[[str], None] = print,
):
    """Upload a single local file to W&B as an artifact, on a best-effort basis.

    Failures while building or logging the artifact are reported through
    ``log`` rather than raised. A W&B problem therefore does not interrupt
    the caller.

    Args:
        run: Run the artifact is logged to via ``run.log_artifact``.
        local_path: Path of the file added to the artifact.
        name: Artifact name.
        type: Artifact type, passed to ``wandb.Artifact``.
        log: Callable that receives one status message; defaults to ``print``.
    """
    try:
        artifact = wandb.Artifact(name=name, type=type)
        artifact.add_file(local_path)
        run.log_artifact(artifact)
    except Exception as e:  # deliberate best-effort: report, never crash the caller
        log(f"Failed to log checkpoint to wandb: {e}")
    else:
        # Kept outside the `try` so an error raised by `log` itself is not
        # misreported as a wandb failure.
        log(f"Checkpoint logged to wandb: {local_path}")
