import os
from pathlib import Path
from typing import Optional, Union


# Name of the environment variable that get_artifacts_dir() falls back to
# when no artifacts directory has been stored yet.
ARTIFACTS_DIR_KEY = "ARTIFACTS_DIR"
# Currently stored artifacts directory. It stays None until set_artifacts_dir()
# assigns it, either when called directly or through get_artifacts_dir()'s
# environment-variable fallback.
_artifacts_dir: Optional[Path] = None

def set_artifacts_dir(dir: Union[str, Path]):
    """Store ``dir`` as the artifacts directory returned by ``get_artifacts_dir``.

    Args:
        dir: Target directory. A ``str`` is converted to ``Path``. The path is
            neither validated nor created here.
    """
    # ``dir`` shadows the builtin of the same name. It is kept unchanged so
    # callers that pass it by keyword keep working.
    global _artifacts_dir
    as_path = Path(dir)
    _artifacts_dir = as_path

def get_artifacts_dir() -> Path:
    """Return the directory where artifacts (training logs and checkpoints) go.

    A directory stored via ``set_artifacts_dir`` takes precedence. Otherwise
    the environment variable named by ``ARTIFACTS_DIR_KEY`` is read. If present,
    its value is stored via ``set_artifacts_dir``, so later calls reuse it
    without consulting the environment again.

    Returns:
        The configured artifacts directory.

    Raises:
        ValueError: If no directory has been stored and the environment
            variable is not set.
    """
    if _artifacts_dir is None and ARTIFACTS_DIR_KEY in os.environ:
        # Cache the environment value. Once stored, later changes to the
        # variable are not picked up.
        # NOTE(review): an empty value becomes Path("") == Path("."), i.e. the
        # current working directory. Confirm that is acceptable.
        set_artifacts_dir(os.environ[ARTIFACTS_DIR_KEY])
    if _artifacts_dir is not None:
        return _artifacts_dir
    raise ValueError(
        # Adjacent literals are concatenated, so the trailing space is needed
        # to separate the two sentences.
        "Artifacts directory not set. "
        f"Either set an environment variable named {ARTIFACTS_DIR_KEY} "
        "to point to the directory where artifacts "
        "(training logs and checkpoints) should be located "
        "or pass it as a Hydra configuration option under "
        "'dirs.artifacts'."
    )



def _append_directories(
    base_dir: Path,
    directory_names: Union[str, list[str], tuple[str, ...]],
) -> Path:
    """Join one or more directory names beneath ``base_dir``.

    Args:
        base_dir: Directory to start from.
        directory_names: A single name (``str`` or path-like), or a sequence /
            iterable of names appended in order.

    Returns:
        ``base_dir`` with every name appended. Only the path is computed;
        nothing is created on disk.

    Raises:
        ValueError: If ``directory_names`` contains no names.
    """
    if isinstance(directory_names, (str, os.PathLike)):
        names = (directory_names,)
    else:
        # Materialise once so generators and other one-shot iterables work
        # (len() on them raises TypeError) and can be checked for emptiness.
        names = tuple(directory_names)
    if not names:
        raise ValueError("The name sequence must contain at least one item.")
    # NOTE(review): pathlib drops everything before an absolute segment, and a
    # ".." segment climbs out of base_dir. This is fine for trusted config
    # values; validate the names first if they can come from untrusted input.
    return base_dir.joinpath(*names)


def get_logging_dir(
    task_name_parts: Union[str, list[str]],
) -> Path:
    """Return the ``logs`` directory of a task under the artifacts directory.

    Args:
        task_name_parts: One directory name, or a list of nested names,
            identifying the task beneath the artifacts directory.

    Returns:
        ``<artifacts_dir>/<task_name_parts...>/logs``. No timestamp or other
        run-specific component is added, so repeated runs of the same task
        resolve to the same path. The directory is not created here.

    Raises:
        ValueError: If the artifacts directory is not configured, or if
            ``task_name_parts`` is an empty list.
    """
    task_dir = _append_directories(get_artifacts_dir(), task_name_parts)
    return task_dir / "logs"


def get_checkpoints_dir(
    task_name_parts: Union[str, list[str]],
) -> Path:
    """Return the ``checkpoints`` directory of a task under the artifacts directory.

    Args:
        task_name_parts: One directory name, or a list of nested names,
            identifying the task beneath the artifacts directory.

    Returns:
        ``<artifacts_dir>/<task_name_parts...>/checkpoints``. The directory is
        not created here.

    Raises:
        ValueError: If the artifacts directory is not configured, or if
            ``task_name_parts`` is an empty list.
    """
    artifacts_root = get_artifacts_dir()
    task_root = _append_directories(artifacts_root, task_name_parts)
    return task_root.joinpath("checkpoints")


def recreate_dir(directory: Union[str, Path]) -> None:
    """Ensure ``directory`` exists, creating any missing parent directories.

    Despite its name, this does NOT delete or empty an existing directory.
    Files already inside it are preserved, and calling it repeatedly is safe.

    Args:
        directory: Directory to create. A ``str`` or other path-like value is
            accepted and converted to ``Path``.

    Raises:
        FileExistsError: If ``directory`` already exists but is not a directory.
    """
    Path(directory).mkdir(parents=True, exist_ok=True)


def get_results_dir(
    task_name_parts: Union[str, list[str]],
) -> Path:
    """Return the ``results`` directory of a task under the artifacts directory.

    Args:
        task_name_parts: One directory name, or a list of nested names,
            identifying the task beneath the artifacts directory.

    Returns:
        ``<artifacts_dir>/<task_name_parts...>/results``. The directory is not
        created here.

    Raises:
        ValueError: If the artifacts directory is not configured, or if
            ``task_name_parts`` is an empty list.
    """
    per_task = _append_directories(get_artifacts_dir(), task_name_parts)
    results_path = per_task / "results"
    return results_path
