from dataclasses import dataclass

__all__ = ["SlurmJob"]  # public API: only the job-spec dataclass is exported


@dataclass(unsafe_hash=True)  # for batching jobs
class SlurmJob:
    partition: str
    gpu_type: str | None = None
    time_min: int = 60
    num_nodes: int = 1
    mem_per_node: int = 64
    cpus_per_node: int = 8
    gpus_per_node: int = 0
    account: str | None = None

    def to_parameters(self) -> dict:
        return dict(
            stderr_to_stdout=True,
            use_srun=False,
            time=self.time_min,
            nodes=self.num_nodes,
            ntasks_per_node=1,
            mem=f"{self.mem_per_node}G",
            cpus_per_task=self.cpus_per_node,
            gpus_per_node=self.gpus_per_node,
            constraint=self.gpu_type,
            partition=self.partition,
            account=self.account,
        )
