import os
from typing import Optional, Callable

import torch
import torch.distributed as dist


class DistributedManager:
    """Class-level holder for ``torch.distributed`` process-group state.

    All state is stored on the class itself and every method is a
    classmethod, so the class is used directly rather than instantiated.
    Until ``init()`` assigns them, the attributes keep the single-process
    defaults declared below.
    """

    # True between a successful ``init()`` and the matching ``destroy()``.
    initialized: bool = False
    # Backend name passed to ``dist.init_process_group``; None until ``init()`` stores one.
    backend: Optional[str] = None
    local_rank: int = 0
    rank: int = 0
    local_world_size: int = 1
    world_size: int = 1

    @staticmethod
    def _read_env_int(name: str) -> int:
        """Return environment variable *name* parsed as an integer.

        Raises:
            ValueError: if the variable is unset or is not an integer literal.
        """
        raw = os.environ.get(name)
        if raw is None:
            raise ValueError(
                f'Environment variable {name} is not set; launch the process '
                'with torchrun or export it before calling DistributedManager.init()'
            )
        try:
            return int(raw)
        except ValueError as err:
            raise ValueError(
                f'Environment variable {name} must be an integer, got {raw!r}'
            ) from err

    @classmethod
    def init(
            cls,
            backend: Optional[str] = None,
            device_func: Optional[Callable[[int], int]] = None
    ) -> None:
        """Read the launcher environment and create the default process group.

        Args:
            backend: Backend name forwarded to ``dist.init_process_group``;
                ``None`` lets torch choose.
            device_func: Optional mapping from the local rank to the CUDA
                device index to select. When omitted, the local rank itself
                is used as the device index.

        Raises:
            ValueError: if the manager is already initialized, or if any of
                ``LOCAL_RANK``, ``RANK``, ``LOCAL_WORLD_SIZE`` or
                ``WORLD_SIZE`` is missing or not an integer.
        """
        if cls.initialized:
            raise ValueError('DistributedManager already initialized')

        # Parse every variable before touching class state so that a missing
        # or malformed value cannot leave the manager half-updated.
        local_rank = cls._read_env_int('LOCAL_RANK')
        rank = cls._read_env_int('RANK')
        local_world_size = cls._read_env_int('LOCAL_WORLD_SIZE')
        world_size = cls._read_env_int('WORLD_SIZE')

        cls.backend = backend
        cls.local_rank = local_rank
        cls.rank = rank
        cls.local_world_size = local_world_size
        cls.world_size = world_size

        # Select the device before creating the process group.
        if torch.cuda.is_available():
            device = device_func(cls.local_rank) if device_func is not None else cls.local_rank
            torch.cuda.set_device(device)
        dist.init_process_group(cls.backend)
        cls.initialized = True

    @classmethod
    def main_rank(cls) -> int:
        """Return the global rank that is treated as the main process."""
        return 0

    @classmethod
    def is_main(cls) -> bool:
        """Return True when this process's global rank is the main rank."""
        return cls.rank == cls.main_rank()

    @classmethod
    def destroy(cls) -> None:
        """Destroy the process group created by ``init()``.

        Does nothing when the manager is not initialized. Afterwards the
        manager is marked uninitialized, so a repeated ``destroy()`` is a
        no-op and ``init()`` may be called again.
        """
        if not cls.initialized:
            return
        # The group may already have been torn down directly through
        # torch.distributed; destroying it a second time would raise.
        if dist.is_initialized():
            dist.destroy_process_group()
        cls.initialized = False
