# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.

from dataclasses import dataclass
from typing import Callable, ContextManager, Optional

import torch


@dataclass
class ModelParallelConfig:
    """Base configuration for Megatron Core

    The initialization function has an argument for each parameter. After initialization,
    ``__post_init__`` validates combinations of options and fills in derived defaults
    (see its docstring for details).
    """

    ###################
    # Model parallelism
    ###################
    tensor_model_parallel_size: int = 1
    """Intra-layer model parallelism. Splits tensors across GPU ranks."""

    pipeline_model_parallel_size: int = 1
    """Inter-layer model parallelism. Splits transformer layers across GPU ranks."""

    virtual_pipeline_model_parallel_size: Optional[int] = None
    """Interleaved pipeline parallelism is used to improve performance by reducing the pipeline
       bubble.  Considers a transformer block as a list of smaller transformer (virtual) blocks.
       The number of virtual blocks per pipeline model parallel rank is the virtual model parallel
       size.  See Efficient Large-Scale Language Model Training on GPU Clusters Using Megatron-LM:
       arxiv.org/pdf/2104.04473.pdf for more details.
    """

    sequence_parallel: bool = False
    """Makes tensor parallelism more memory efficient for LLMs (20B+) by parallelizing layer norms
       and dropout sequentially.  See Reducing Activation Recomputation in Large Transformer Models
       (https://arxiv.org/abs/2205.05198) for more details. Requires
       tensor_model_parallel_size > 1.
    """

    context_parallel_size: int = 1
    """Splits network input along sequence dimension across GPU ranks."""

    expert_model_parallel_size: int = 1
    """Distributes MoE Experts across sub data parallel dimension."""

    ###################
    # Initialization
    ###################
    perform_initialization: bool = True
    """If true, weights are initialized. This option can be useful when you know you are going to
       load values from a checkpoint.
    """

    use_cpu_initialization: bool = False
    """When set to False, we initialize the weights directly on the GPU. CPU initialization is the
       same regardless of tensor model parallelism, but GPU initialization is not. Transferring
       weights from CPU to GPU can take a significant amount of time for large models.
    """

    ###################
    # Training
    ###################
    fp16: bool = False
    """If true, train with fp16 mixed precision training."""

    bf16: bool = False
    """If true, train with bf16 mixed precision training."""

    params_dtype: torch.dtype = torch.float32
    """dtype used when initializing the weights."""

    timers: Optional[Callable] = None
    """Timers object to call for various timing functions. See megatron.core.timers.Timers"""

    finalize_model_grads_func: Optional[Callable] = None
    """Function that finalizes gradients on all workers. Could include ensuring that grads are
       all-reduced across data parallelism, pipeline parallelism, and sequence parallelism
       dimensions.
    """

    grad_scale_func: Optional[Callable] = None
    """If using loss scaling, this function should take the loss and return the scaled loss. If
       None, no function is called on the loss.
    """

    no_sync_func: Optional[Callable] = None
    """Function that creates a context that suppresses asynchronous data-parallel communication. If
       the model is an instance of core.distributed.DistributedDataParallel, the default is to use
       core.distributed.DistributedDataParallel.no_sync.
    """

    grad_sync_func: Optional[Callable] = None
    """Function that launches asynchronous gradient reductions (e.g. distributed optimizer gradient
       reduce-scatters). The function should take one argument: an iterable of parameters whose
       gradients are to be synchronized.
    """

    param_sync_func: Optional[Callable] = None
    """Function that launches asynchronous parameter synchronizations (e.g. distributed optimizer
       parameter all-gathers). The function should take one argument: an iterable of parameters to
       be synchronized.
    """

    enable_autocast: bool = False
    """If true runs the forward step function inside torch.autocast context."""

    autocast_dtype: Optional[torch.dtype] = None
    """dtype to pass to torch.amp.autocast when enabled. If None, is set to params_dtype."""

    num_microbatches_with_partial_activation_checkpoints: Optional[int] = None
    """If int, set the number of microbatches where not all of the layers will be checkpointed and
       recomputed. The rest of the microbatches within the window of maximum outstanding
       microbatches will recompute all layers (either full recompute or selective recompute). If
       None, the checkpoint and recompute will be left up to the forward_step function.
    """

    ###################
    # Optimizations
    ###################
    gradient_accumulation_fusion: bool = False
    """If true, fuses weight gradient accumulation to GEMMs. Requires the custom CUDA extension
       fused_weight_gradient_mlp_cuda module. To use gradient_accumulation_fusion you must install
       APEX with --cpp_ext and --cuda_ext. For example: "pip install --global-option=\"--cpp_ext\"
       --global-option=\"--cuda_ext\" ". Note that the extension requires CUDA>=11. Otherwise, you
       must turn off gradient accumulation fusion.
    """

    async_tensor_model_parallel_allreduce: bool = False
    """If true, enables asynchronous execution of tensor-model-parallel all-reduce with weight
       gradient computation of a column-linear layer. Reset to False when sequence_parallel is
       enabled.
    """

    tp_comm_overlap: bool = False
    """If true, allows overlapping of Linear layer execution with tensor parallel communication
       collectives like AllGather/ReduceScatter. Overlapping is done for the linear layers wherever
       possible during the forward and the backward pass.
    """

    tp_comm_bulk_wgrad: bool = True
    """If true, allows All-Gather overlap with Bprop activation gradient GEMM. Don't care if
       tp_comm_overlap is False.
    """

    tp_comm_bulk_dgrad: bool = True
    """If true, allows Reduce-Scatter overlap with Bprop weight gradient GEMM. Don't care if
       tp_comm_overlap is False.
    """

    tp_comm_overlap_ag: bool = True
    """If true, allows All-Gather overlap with GEMM by pipelining the GEMM and All-Gather.
       Don't care if tp_comm_overlap is False.
    """

    tp_comm_overlap_rs: bool = True
    """If true, allows Reduce-Scatter overlap with GEMM by pipelining the GEMM and Reduce-Scatter.
       Don't care if tp_comm_overlap is False.
    """

    tp_comm_split_ag: bool = True
    """Deprecated from TransformerEngine v1.6.0.
       If true, allows All-Gather overlap with Fprop GEMM by pipelining the GEMM and All-Gather
       splits. Don't care if tp_comm_overlap is False.
    """

    tp_comm_atomic_ag: bool = False
    """Deprecated from TransformerEngine v1.6.0.
       If true, allows All-Gather overlap with Fprop GEMM by pipelining the GEMM and All-Gather both
       done atomically. Don't care if tp_comm_overlap is False.
    """

    tp_comm_split_rs: bool = True
    """Deprecated from TransformerEngine v1.6.0.
       If true, allows Reduce-Scatter overlap with Fprop GEMM by pipelining the GEMM and
       Reduce-Scatter splits. Don't care if tp_comm_overlap is False.
    """

    tp_comm_atomic_rs: bool = False
    """Deprecated from TransformerEngine v1.6.0.
       If true, allows Reduce-Scatter overlap with Fprop GEMM by pipelining the GEMM and
       Reduce-Scatter both done atomically. Don't care if tp_comm_overlap is False.
    """

    ###################
    # Pipeline Parallel
    ###################
    pipeline_dtype: Optional[torch.dtype] = None
    """dtype used in p2p communication, usually params_dtype. Must be specified when
       pipeline_model_parallel_size > 1.
    """

    variable_seq_lengths: bool = False
    """Support for variable sequence lengths across microbatches. Setting this communicates the size
       of tensors during pipeline parallelism communication, because of this extra overhead it
       should only be set if the sequence length varies by microbatch within a global batch.
    """

    overlap_p2p_comm: bool = False
    """When True some of the peer to peer communication for pipeline parallelism will overlap with
       computation. Must be False if batch_p2p_comm is true (this constraint is not checked in
       __post_init__).
    """

    batch_p2p_comm: bool = True
    """Use batch_isend_irecv instead of individual isend/irecv calls. Must be False if
       overlap_p2p_comm is True (this constraint is not checked in __post_init__).
    """

    batch_p2p_sync: bool = True
    """When using batch_isend_irecv, do a cuda.device.synchronize afterward to work around a bug in
       older version of PyTorch.
    """

    use_ring_exchange_p2p: bool = False
    """Use custom ring_exchange kernel instead of torch.distributed.batch_isend_irecv(). Requires
       custom built torch with torch.distributed.ring_exchange.
    """

    deallocate_pipeline_outputs: bool = False
    """If True, output data is deallocated after the tensor is sent to the next pipeline stage.
       Helps with saving memory, does nothing when pipeline parallel is not used.
    """

    defer_embedding_wgrad_compute: bool = False
    """If true, defers the embedding WGRAD GEMMs while pipeline flush is
       taking place enabling us to hide pipeline flush latency. Defaults to False. Requires
       pipeline_model_parallel_size > 1 and gradient_accumulation_fusion.
    """

    pipeline_model_parallel_split_rank: Optional[int] = None
    """If int, rank where encoder and decoder should be split in cases where the model has both an
       encoder and decoder (e.g., T5). Ignored if None.
    """

    ###################
    # CPU Offloading
    ###################
    cpu_offloading: bool = False
    """When set to True, all the activations are offloaded to the CPU asynchronously."""

    cpu_offloading_num_layers: int = 0
    """Tells the number of transformer layers for which activations have to be offloaded."""

    # Used for internal use only, not to be set by the user.
    # TODO: Need to move to the 'right' place when possible.
    _cpu_offloading_context: Optional[ContextManager] = None
    """For internal use only, do not set."""

    cpu_offloading_activations: bool = True
    """If True, offloads the activations to CPU."""

    cpu_offloading_weights: bool = True
    """If True, offloads the weights to CPU."""

    ###################
    # Timing
    ###################
    barrier_with_L1_time: bool = True
    """If true, use barrier with level 1 time measurements. It is up to the user to make sure
       calling barrier with their timers will not result in hangs. This can happen if for example
       the user adds a level 1 timer that is not called by all ranks.
    """

    def __post_init__(self):
        """Validate option combinations and fill in derived defaults after initialization.

        Python dataclass hook; see
        https://docs.python.org/3/library/dataclasses.html#post-init-processing for more details.

        Side effects:
            * ``async_tensor_model_parallel_allreduce`` is reset to False when
              ``sequence_parallel`` is enabled.
            * ``autocast_dtype`` is set to ``params_dtype`` when it was left as None.

        Raises:
            ValueError: if ``sequence_parallel`` is enabled with
                ``tensor_model_parallel_size <= 1``; if ``pipeline_model_parallel_size > 1``
                and ``pipeline_dtype`` is None; if ``defer_embedding_wgrad_compute`` is set
                without pipeline parallelism or without ``gradient_accumulation_fusion``; or
                if expert and tensor parallelism are both used without ``sequence_parallel``.
        """
        if self.sequence_parallel:
            if self.tensor_model_parallel_size <= 1:
                raise ValueError("Can not use sequence paralllelism without tensor parallelism")
            if self.async_tensor_model_parallel_allreduce:
                # sequence_parallelism already does this async
                self.async_tensor_model_parallel_allreduce = False

        if self.pipeline_model_parallel_size > 1 and self.pipeline_dtype is None:
            raise ValueError("When using pipeline parallelism, pipeline_dtype must be specified")

        if self.autocast_dtype is None:
            self.autocast_dtype = self.params_dtype

        if self.defer_embedding_wgrad_compute:
            # Checked in this order so the pipeline-parallel error takes precedence.
            if self.pipeline_model_parallel_size == 1:
                raise ValueError(
                    "Cannot defer embedding wgrad compute when pipeline model parallel is not used"
                )
            if not self.gradient_accumulation_fusion:
                raise ValueError(
                    "Cannot defer embedding wgrad compute when gradient accumulation fusion is not used"
                )

        if self.expert_model_parallel_size > 1 and self.tensor_model_parallel_size > 1:
            # Truthiness (rather than `is False`) keeps this consistent with the
            # sequence_parallel check above for falsy non-bool values such as 0.
            if not self.sequence_parallel:
                raise ValueError(
                    "When using expert parallelism and tensor parallelism, sequence parallelism must be used"
                )
