from typing import Optional
import torch.backends.cudnn as cudnn
from accelerate import Accelerator, GradScalerKwargs, DistributedDataParallelKwargs
from accelerate.utils import set_seed

# Public API of this module: only the accelerator wrapper is exported.
__all__ = ["PadoAccelerator"]


class PadoAccelerator(Accelerator):
    """
    Wrapper of ``accelerate.Accelerator`` with project defaults.

    * Because almost always this class will be initialized first, ``__init__`` also seeds the
      RNGs and configures cuDNN (see :meth:`set_seed`).
    * Adds short aliases ``is_master``, ``world_size`` and ``local_rank``.

    Args:
        seed: seed forwarded to ``accelerate.utils.set_seed``.
        cpu: forwarded to ``Accelerator(cpu=...)``.
        fp16: forwarded to ``Accelerator(fp16=...)``; also enables the grad scaler.
        split_batches: forwarded to ``Accelerator(split_batches=...)``.
        cudnn_benchmark: value for ``torch.backends.cudnn.benchmark``; forced off when
            ``cudnn_deterministic`` is True.
        cudnn_deterministic: value for ``torch.backends.cudnn.deterministic``.
        fp16_init_scale: initial loss scale of the grad scaler.
        fp16_growth_interval: grad-scaler growth interval; ``None`` keeps the library default.
        ddp_broadcast_buffers: forwarded to DDP ``broadcast_buffers``.
        ddp_find_unused_parameters: forwarded to DDP ``find_unused_parameters``.
    """

    def __init__(self,
                 seed: int = 1234,
                 cpu: bool = False,
                 fp16: bool = False,
                 split_batches: bool = False,
                 cudnn_benchmark: bool = True,
                 cudnn_deterministic: bool = False,
                 fp16_init_scale: float = 2048.0,
                 fp16_growth_interval: Optional[int] = 1000,
                 ddp_broadcast_buffers: bool = True,
                 ddp_find_unused_parameters: bool = False,
                 ) -> None:
        # -------------------------------------------------------------------------------------- #
        ddp_kwargs = DistributedDataParallelKwargs(
            dim=0,
            broadcast_buffers=ddp_broadcast_buffers,
            find_unused_parameters=ddp_find_unused_parameters,
            # Gradients become views into DDP's all-reduce buckets: saves memory and a copy.
            # Per the PyTorch DDP docs, ``detach_()`` can then not be called on gradients.
            gradient_as_bucket_view=True,
        )
        scaler_options = dict(init_scale=fp16_init_scale, enabled=fp16)
        # The scaler expects an int growth interval; treat ``None`` as "use the library
        # default" instead of forwarding ``None`` to it.
        if fp16_growth_interval is not None:
            scaler_options["growth_interval"] = fp16_growth_interval
        scaler_kwargs = GradScalerKwargs(**scaler_options)
        super().__init__(device_placement=True,
                         split_batches=split_batches, fp16=fp16, cpu=cpu,
                         kwargs_handlers=[ddp_kwargs, scaler_kwargs])
        # -------------------------------------------------------------------------------------- #
        self.set_seed(seed, cudnn_benchmark, cudnn_deterministic)

        # Aliases (values captured once, at construction time).
        # NOTE: ``is_master`` is the *local* (per-node) main process, consistent with
        # ``local_rank``; use ``is_main_process`` / ``process_index`` for the global view.
        self.is_master: bool = self.is_local_main_process
        self.world_size: int = self.num_processes
        # Rank within this node. ``process_index`` is the global rank and only coincides
        # with the local rank on single-node runs.
        self.local_rank: int = self.local_process_index

    @staticmethod
    def set_seed(seed: int,
                 cudnn_benchmark: bool = True,
                 cudnn_deterministic: bool = False) -> None:
        """Seed the RNGs and configure cuDNN.

        Args:
            seed: forwarded to ``accelerate.utils.set_seed``.
            cudnn_benchmark: value for ``cudnn.benchmark``. Forced to False when
                ``cudnn_deterministic`` is True, because benchmark mode auto-selects
                algorithms and may pick non-deterministic ones.
            cudnn_deterministic: value for ``cudnn.deterministic``.
        """
        # Class attributes are not in scope inside a method body, so this resolves to the
        # module-level ``accelerate.utils.set_seed`` -- it is not a recursive call.
        set_seed(seed)
        # Assign both flags explicitly so they can also be turned off, and so that asking
        # for determinism actually disables benchmark mode.
        cudnn.deterministic = cudnn_deterministic
        cudnn.benchmark = cudnn_benchmark and not cudnn_deterministic
