import os
from enum import Enum
from pathlib import Path
from typing import Any, Dict, Optional
class ClusterType(Enum):
    """Clusters for which this module provides Slurm/checkpoint settings.

    Each member's value is a short lowercase identifier string.
    """

    CW = "cw"
def _guess_cluster_type() -> ClusterType:
    """Return the cluster type to use when the caller did not specify one.

    No detection is performed: the result is always ``ClusterType.CW``.
    """
    guessed = ClusterType.CW
    return guessed
def get_cluster_type(
    cluster_type: Optional[ClusterType] = None,
) -> Optional[ClusterType]:
    """Resolve the effective cluster type.

    Args:
        cluster_type: Explicit cluster type; when None, ``_guess_cluster_type()``
            is consulted instead.

    Returns:
        ``cluster_type`` unchanged if given, otherwise the guessed cluster type.
    """
    return _guess_cluster_type() if cluster_type is None else cluster_type
def get_slurm_account(cluster_type: Optional[ClusterType] = None) -> Optional[str]:
    """Return the Slurm account configured for a cluster.

    Args:
        cluster_type: Cluster to look up; resolved via ``get_cluster_type``.

    Returns:
        The account name, or None if no cluster type could be resolved.

    Raises:
        KeyError: If the resolved cluster type has no account entry.
    """
    resolved = get_cluster_type(cluster_type)
    if resolved is None:
        return None
    account_by_cluster = {ClusterType.CW: "fair_amaia_cw_explore"}
    return account_by_cluster[resolved]
def get_checkpoint_path(cluster_type: Optional[ClusterType] = None) -> Optional[Path]:
    """Return the root checkpoint directory for a cluster.

    The path is ``/`` joined with a per-cluster directory name. For
    ``ClusterType.CW`` that name is empty, so the result is ``Path("/")``.

    Args:
        cluster_type: Cluster to look up; resolved via ``get_cluster_type``.

    Returns:
        The checkpoint root, or None if no cluster type could be resolved.

    Raises:
        KeyError: If the resolved cluster type has no directory entry.
    """
    resolved = get_cluster_type(cluster_type)
    if resolved is None:
        return None
    dirname_by_cluster = {ClusterType.CW: ""}
    root = Path("/")
    return root.joinpath(dirname_by_cluster[resolved])
def get_user_checkpoint_path(
    cluster_type: Optional[ClusterType] = None,
) -> Optional[Path]:
    """Return the per-user checkpoint directory for a cluster.

    The result is ``get_checkpoint_path(cluster_type) / $USER``.

    Args:
        cluster_type: Cluster to look up; forwarded to ``get_checkpoint_path``.

    Returns:
        The per-user checkpoint directory, or None when ``get_checkpoint_path``
        returns None.

    Raises:
        RuntimeError: If the ``USER`` environment variable is unset or empty.
    """
    checkpoint_path = get_checkpoint_path(cluster_type)
    if checkpoint_path is None:
        return None
    username = os.environ.get("USER")
    # Explicit check rather than `assert`: asserts are stripped under
    # `python -O`. `not username` also rejects an empty USER, which would
    # otherwise make `checkpoint_path / ""` silently return the shared root
    # instead of a per-user directory.
    if not username:
        raise RuntimeError(
            "USER environment variable is not set; "
            "cannot build a per-user checkpoint path"
        )
    return checkpoint_path / username
def get_slurm_qos(cluster_type: Optional[ClusterType] = None) -> Optional[str]:
    """Return the Slurm QOS configured for a cluster.

    Args:
        cluster_type: Cluster to look up; resolved via ``get_cluster_type``.

    Returns:
        The QOS name. Returns None if no cluster type could be resolved, or
        if the resolved type has no QOS entry (lookup uses ``dict.get``).
    """
    resolved = get_cluster_type(cluster_type)
    if resolved is None:
        return None
    qos_by_cluster = {ClusterType.CW: "explore"}
    return qos_by_cluster.get(resolved)
def get_slurm_partition(cluster_type: Optional[ClusterType] = None) -> Optional[str]:
    """Return the Slurm partition configured for a cluster.

    Args:
        cluster_type: Cluster to look up; resolved via ``get_cluster_type``.

    Returns:
        The partition name, or None if no cluster type could be resolved.

    Raises:
        KeyError: If the resolved cluster type has no partition entry.
    """
    resolved = get_cluster_type(cluster_type)
    partition_by_cluster = {ClusterType.CW: "learn"}
    return None if resolved is None else partition_by_cluster[resolved]
def get_slurm_executor_parameters(
    nodes: int,
    num_gpus_per_node: int,
    cluster_type: Optional[ClusterType] = None,
    **kwargs,
) -> Dict[str, Any]:
    """Build a dict of Slurm executor parameters.

    Args:
        nodes: Number of nodes to request.
        num_gpus_per_node: GPUs per node; one task is launched per GPU.
        cluster_type: Cluster whose settings to use; resolved via
            ``get_cluster_type``.
        **kwargs: Extra entries that are merged last, overriding any default
            of the same name.

    Returns:
        A new dict of parameters. ``cpus_per_task`` is 16 on
        ``ClusterType.CW`` and 10 otherwise.
    """
    partition = get_slurm_partition(cluster_type)
    resolved = get_cluster_type(cluster_type)
    cpus_per_task = 16 if resolved == ClusterType.CW else 10
    defaults = {
        "mem_gb": 0,  # Requests all memory on a node, see https://slurm.schedmd.com/sbatch.html
        "gpus_per_node": num_gpus_per_node,
        "tasks_per_node": num_gpus_per_node,  # one task per GPU
        "cpus_per_task": cpus_per_task,
        "nodes": nodes,
        "slurm_partition": partition,
    }
    # Caller-supplied kwargs win over the defaults above.
    return {**defaults, **kwargs}
