# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

"""
Utility functions for SLURM configuration and cluster settings.
"""

import os
import platform
import socket
import typing as tp
from enum import Enum

import omegaconf


class ClusterType(Enum):
    """Kinds of compute environment distinguished by this module.

    The string values are the identifiers associated with each member.
    """

    # Linux host whose kernel release ends with "-aws" or whose FQDN contains ".ec2".
    AWS = "aws"
    # Host whose FQDN ends with ".fair".
    FAIR = "fair"
    # Host whose FQDN ends with ".facebook.com".
    RSC = "rsc"
    # Local machine reporting "Darwin" as its system name.
    LOCAL_DARWIN = "darwin"
    DEFAULT = "default"  # used for any other cluster.


def _guess_cluster_type() -> ClusterType:
    """Infer the cluster type from the local operating system and hostname.

    The checks run in order and the first match wins:

    1. Linux with a kernel release ending in ``-aws``, or an FQDN containing
       ``.ec2``: ``ClusterType.AWS``.
    2. FQDN ending in ``.fair``: ``ClusterType.FAIR``.
    3. FQDN ending in ``.facebook.com``: ``ClusterType.RSC``.
    4. System name ``Darwin``: ``ClusterType.LOCAL_DARWIN``.
    5. Anything else: ``ClusterType.DEFAULT``.

    Returns:
        The first cluster type whose rule matches the current host.
    """
    # `platform` is used instead of `os.uname()` because `os.uname` only
    # exists on Unix; on other systems it would raise AttributeError before
    # the DEFAULT fallback could be reached.
    system = platform.system()
    release = platform.release()
    fqdn = socket.getfqdn()

    if system == "Linux" and (release.endswith("-aws") or ".ec2" in fqdn):
        return ClusterType.AWS

    if fqdn.endswith(".fair"):
        return ClusterType.FAIR

    if fqdn.endswith(".facebook.com"):
        return ClusterType.RSC

    if system == "Darwin":
        return ClusterType.LOCAL_DARWIN

    return ClusterType.DEFAULT


def get_cluster_type(
    cluster_type: tp.Optional[ClusterType] = None,
) -> ClusterType:
    """Return the cluster type to use, guessing it when none is provided.

    Args:
        cluster_type: Explicit cluster type. When ``None``, the type is
            inferred from the current host via ``_guess_cluster_type``.

    Returns:
        The provided ``cluster_type`` unchanged, or the guessed one. This
        function never returns ``None``.
    """
    if cluster_type is not None:
        return cluster_type

    return _guess_cluster_type()


def get_slurm_parameters(
    cfg: omegaconf.DictConfig, cluster_type: tp.Optional[ClusterType] = None
) -> omegaconf.DictConfig:
    """Adjust the SLURM parameters in ``cfg`` for the target cluster.

    If the cluster type is not specified, it is inferred automatically.
    ``cfg`` is modified in place and the same object is returned.

    Args:
        cfg: SLURM configuration to update.
        cluster_type: Cluster to tailor the configuration for, or ``None``
            to detect it through ``get_cluster_type``.

    Returns:
        The ``cfg`` object that was passed in, after the updates.
    """
    # Imported here rather than at module level, as in the original code.
    from ..environment import AudioCraftEnvironment

    resolved = get_cluster_type(cluster_type)

    # Ordered (key, value) overrides for the resolved cluster; other
    # cluster types leave the configuration untouched.
    overrides: tp.List[tp.Tuple[str, tp.Any]]
    if resolved == ClusterType.AWS:
        overrides = [
            ("mem_per_gpu", None),
            ("constraint", None),
            ("setup", []),
        ]
    elif resolved == ClusterType.RSC:
        overrides = [
            ("mem_per_gpu", None),
            ("setup", []),
            ("constraint", None),
            ("partition", "learn"),
        ]
    else:
        overrides = []

    for key, value in overrides:
        cfg[key] = value

    # An exclusion value is applied on every cluster when the environment
    # provides one.
    excluded = AudioCraftEnvironment.get_slurm_exclude()
    if excluded is not None:
        cfg["exclude"] = excluded

    return cfg
