import warnings
from dataclasses import dataclass, field
from typing import Any, Literal, Optional

from verl.base_config import BaseConfig
from verl.trainer.config import CheckpointConfig

from ...utils.profiler import ProfilerConfig
from .model import HFModelConfig
from .optimizer import OptimizerConfig

# Names exported by a star-import of this module.
# The `EngineConfig` base class defined in this file is not part of this list.
__all__ = ["FSDPEngineConfig", "McoreEngineConfig", "TrainingWorkerConfig", "VeOmniEngineConfig"]


@dataclass
class EngineConfig(BaseConfig):
    """Options shared by the engine configs defined in this module.

    The backend-specific subclasses in this file extend these fields and override the
    default of ``strategy`` with their own backend name.

    Args:
        param_offload (bool): Whether to offload parameters, default False
        optimizer_offload (bool): Whether to offload the optimizer, default False
        grad_offload (bool): Whether to offload gradients, default False
        forward_only (bool): Whether the engine is forward only (e.g., ref policy), default False
        strategy (Optional[str]): The strategy (backend) name, default None
        dtype (str): Model dtype, "bfloat16" or "float16", default "bfloat16"
        use_dynamic_bsz (bool): Whether to use dynamic batch size, default True
        max_token_len_per_gpu (Optional[int]): Per-GPU max token length for training, default None
        micro_batch_size_per_gpu (Optional[int]): Per-GPU micro batch size for training, default None
        infer_max_token_len_per_gpu (Optional[int]): Per-GPU max token length for inference, default None
        infer_micro_batch_size_per_gpu (Optional[int]): Per-GPU micro batch size for inference, default None
        use_fused_kernels (bool): Whether to use the fused lm head kernel, default False
        use_remove_padding (bool): Whether to remove padding, default True
        seed (int): Random seed, default 42
        full_determinism (bool): Full determinism switch, default False
    """

    # Extends the set inherited from BaseConfig.
    # NOTE(review): presumably these are the fields that may be reassigned after construction -
    # confirm against BaseConfig.
    _mutable_fields = BaseConfig._mutable_fields | {
        "use_dynamic_bsz",
        "max_token_len_per_gpu",
        "micro_batch_size_per_gpu",
        "infer_max_token_len_per_gpu",
        "infer_micro_batch_size_per_gpu",
        "use_fused_kernels",
        "use_remove_padding",
    }

    # offloading switches
    param_offload: bool = False
    optimizer_offload: bool = False
    grad_offload: bool = False
    # whether the engine is forward only (e.g., ref policy)
    forward_only: bool = False
    # the strategy (backend); None here, subclasses provide a concrete default
    strategy: Optional[str] = None
    # model dtype
    dtype: str = "bfloat16"  # ["bfloat16", "float16"]
    # whether to use dynamic bsz
    use_dynamic_bsz: bool = True
    # for training (None means "not set")
    max_token_len_per_gpu: Optional[int] = None
    micro_batch_size_per_gpu: Optional[int] = None
    # for inference (None means "not set")
    infer_max_token_len_per_gpu: Optional[int] = None
    infer_micro_batch_size_per_gpu: Optional[int] = None
    # whether use fuse lm head kernel
    use_fused_kernels: bool = False
    # TODO (this may conflict with the one in model config)
    use_remove_padding: bool = True

    seed: int = 42

    full_determinism: bool = False

    def __post_init__(self):
        """Validation hook called after dataclass init; currently performs no checks."""
        # TODO: turn on this check after we reorg config
        # if self.use_dynamic_bsz:
        #     assert self.max_token_len_per_gpu is not None
        # else:
        #     assert self.micro_batch_size_per_gpu is not None


@dataclass
class McoreEngineConfig(EngineConfig):
    """Configuration for Megatron parallelism.

    The inheritance from BaseConfig provides omegaconf.DictConfig-like interface for a dataclass config.

    Args:
        param_offload (bool): Whether to offload parameters to CPU.
        grad_offload (bool): Whether to offload gradients to CPU.
        optimizer_offload (bool): Whether to offload optimizer states to CPU.
        forward_only (bool): Whether the engine is forward only (e.g., ref policy).
        tensor_model_parallel_size (int): Tensor model parallel size.
        expert_model_parallel_size (int): Expert model parallel size for MoE models.
        expert_tensor_parallel_size (Optional[int]): Expert tensor parallel size for MoE models.
        pipeline_model_parallel_size (int): Pipeline model parallel size.
        virtual_pipeline_model_parallel_size (Optional[int]): Virtual pipeline model parallel size
            for interleaved scheduling.
        context_parallel_size (int): Context parallel size for long sequences.
        sequence_parallel (bool): Whether to enable sequence parallelism. Forced to False in
            ``__post_init__`` when ``tensor_model_parallel_size`` is 1.
        use_distributed_optimizer (bool): Whether to use distributed optimizer.
        use_dist_checkpointing (bool): Whether to use distributed checkpointing.
        dist_checkpointing_path (Optional[str]): Path for distributed checkpointing.
        dist_checkpointing_prefix (str): Prefix used together with distributed checkpointing, default ""
            (exact semantics are defined by the consumer of this config - TODO confirm).
        seed (int): Random seed for reproducibility.
        override_ddp_config (dict[str, Any]): Override configuration for DDP.
        override_transformer_config (dict[str, Any]): Override configuration for transformer.
        override_mcore_model_config (dict[str, Any]): Override configuration for the mcore model
            (presumably applied like the other override dicts - TODO confirm).
        use_mbridge (bool): Whether to use MBridge for communication.
        vanilla_mbridge (bool): MBridge variant switch, default True (meaning is not visible in this
            file - TODO confirm with the consumer).
        strategy (str): Backend name; must be "megatron".
        dtype (str): Mixed precision training param dtype, default "bfloat16"
    """

    # sequence_parallel is not listed as a frozen field for auto-correction purpose
    _mutable_fields = EngineConfig._mutable_fields | {"sequence_parallel"}
    # mcore parallelism
    tensor_model_parallel_size: int = 1
    expert_model_parallel_size: int = 1
    expert_tensor_parallel_size: Optional[int] = None
    pipeline_model_parallel_size: int = 1
    virtual_pipeline_model_parallel_size: Optional[int] = None
    context_parallel_size: int = 1
    sequence_parallel: bool = True
    use_distributed_optimizer: bool = True
    use_dist_checkpointing: bool = False
    dist_checkpointing_path: Optional[str] = None
    dist_checkpointing_prefix: str = ""
    override_ddp_config: dict[str, Any] = field(default_factory=dict)
    override_transformer_config: dict[str, Any] = field(default_factory=dict)
    override_mcore_model_config: dict[str, Any] = field(default_factory=dict)
    use_mbridge: bool = True
    vanilla_mbridge: bool = True
    strategy: str = "megatron"

    def __post_init__(self) -> None:
        """Validate the config and auto-correct ``sequence_parallel``.

        Raises:
            AssertionError: If ``strategy`` is not "megatron" or ``dtype`` is neither
                "bfloat16" nor "float16".
        """
        super().__post_init__()
        assert self.strategy == "megatron"
        assert self.dtype in ["bfloat16", "float16"], f"dtype {self.dtype} not supported"
        if self.tensor_model_parallel_size == 1:
            # Only warn when the setting actually changes; a config that already disabled
            # sequence parallelism should not trigger a spurious warning.
            if self.sequence_parallel:
                warnings.warn("set sequence parallel to false as TP size is 1", stacklevel=2)
            # Always normalise to False so the resulting state matches the TP == 1 requirement.
            self.sequence_parallel = False


@dataclass
class FSDPEngineConfig(EngineConfig):
    """Engine settings for the FSDP (Fully Sharded Data Parallel) backends.

    BaseConfig, inherited through EngineConfig, gives this dataclass an omegaconf.DictConfig-like interface.

    Args:
        wrap_policy (dict[str, Any]): FSDP wrap policy configuration, default empty dict
        param_offload (bool): Offload parameters to CPU, default False
        optimizer_offload (bool): Offload optimizer states to CPU, default False
        offload_policy (bool): Offload policy model parameters, default False
        reshard_after_forward (bool): Reshard parameters after the forward pass, default True
        fsdp_size (int): Size of the FSDP group; -1 means use all available GPUs. default -1
        forward_prefetch (bool): Prefetch parameters for the next forward pass, default False
        model_dtype (str): Data type used to initialize the transformers model, default "fp32"
        use_orig_params (bool): Use original parameters when initializing FSDP1, default False
        mixed_precision (Optional[dict[str, Any]]): FSDP mixed precision configuration, default None
        ulysses_sequence_parallel_size (int): Ulysses sequence parallel size, default 1
        entropy_from_logits_with_chunking (bool): default False
            (presumably computes entropy from logits chunk by chunk - TODO confirm)
        use_torch_compile (bool): default True (presumably toggles torch.compile - TODO confirm)
        entropy_checkpointing (bool): default False
            (presumably checkpoints the entropy computation - TODO confirm)
        strategy (str): Backend name, either "fsdp" or "fsdp2", default "fsdp"
        seed (int): Random seed for reproducibility.
        full_determinism (bool): If true, enable_full_determinism is called to ensure reproducible results
            in distributed training. Important: this will negatively impact performance, so only use it for
            debugging.
        dtype (str): Param dtype for mixed precision training, default "bfloat16"
    """

    # ulysses_sequence_parallel_size stays mutable so older callers keep working
    _mutable_fields = EngineConfig._mutable_fields | {"ulysses_sequence_parallel_size"}

    # --- FSDP wrapping / sharding ---
    wrap_policy: dict[str, Any] = field(default_factory=dict)
    offload_policy: bool = False
    reshard_after_forward: bool = True
    fsdp_size: int = -1
    forward_prefetch: bool = False
    # --- model initialisation / precision ---
    model_dtype: str = "fp32"
    use_orig_params: bool = False
    mixed_precision: Optional[dict[str, Any]] = None
    # --- sequence parallelism and entropy options ---
    ulysses_sequence_parallel_size: int = 1
    entropy_from_logits_with_chunking: bool = False
    use_torch_compile: bool = True
    entropy_checkpointing: bool = False
    # --- backend selector, checked in __post_init__ ---
    strategy: str = "fsdp"

    def __post_init__(self):
        """Run the parent hook, then reject any strategy other than "fsdp" or "fsdp2"."""
        super().__post_init__()
        supported_strategies = ["fsdp", "fsdp2"]
        assert self.strategy in supported_strategies, f"strategy {self.strategy} not supported"


@dataclass
class VeOmniEngineConfig(EngineConfig):
    """Configuration for VeOmni.

    The inheritance from BaseConfig provides omegaconf.DictConfig-like interface for a dataclass config.

    Args:
        wrap_policy (dict[str, Any]): Configuration for FSDP wrap policy.
        param_offload (bool): Whether to offload parameters to CPU, default False
        optimizer_offload (bool): Whether to offload optimizer states to CPU, default False
        offload_policy (bool): Whether to offload policy model parameters, default False
        reshard_after_forward (bool): Whether to reshard parameters after forward pass, default True
        forward_prefetch (bool): Whether to prefetch parameters for next forward pass, default False
        use_orig_params (bool): Whether to use original parameters when initialize FSDP1, default False
        entropy_from_logits_with_chunking (bool): default False
            (presumably computes entropy from logits chunk by chunk - TODO confirm)
        use_torch_compile (bool): default True (presumably toggles torch.compile - TODO confirm)
        entropy_checkpointing (bool): default False
            (presumably checkpoints the entropy computation - TODO confirm)
        strategy (str): Backend name; must be "veomni".
        data_parallel_size (int): FSDP group size, default 1
        data_parallel_replicate_size (int): Data parallel replicate size, default 1
        data_parallel_shard_size (int): Data parallel shard degree, default 1
        tensor_parallel_size (int): Tensor parallel size, default 1
        expert_parallel_size (int): Expert parallel size, default 1
        pipeline_parallel_size (int): Pipeline parallel size, default 1
        context_parallel_size (int): Ring-attn context parallel size, default 1
        ulysses_parallel_size (int): Ulysses sequence parallel size, default 1
        data_parallel_mode (str): Data parallel mode, default "fsdp". Note that this default is not
            one of the annotated options ("ddp", "fsdp1", "fsdp2").
        seed (int): Random seed for reproducibility.
        full_determinism (bool): If true, enable_full_determinism is called to ensure reproducible results
            in distributed training. Important: this will negatively impact performance, so only use it for
            debugging.
        mixed_precision (bool): Mixed precision switch, default False
        init_device (str): Device to initialize model weights.
            1. `cpu`: Init parameters on CPU in rank0 only.
            2. `cuda`: Init parameters on GPU.
            3. `meta`: Init parameters on meta.
            4. `npu`: Init parameters on Ascend NPU.
            default "meta"
        enable_full_shard (bool): Enable fully shard for FSDP training (ZeRO-3), default False
        ckpt_manager (str): Checkpoint manager; "dcp" is the only annotated option, default "dcp"
        load_checkpoint_path (Optional[str]): Checkpoint path to load from, default None
        enable_fsdp_offload (bool): Enable CPU offload for FSDP1, default False
        enable_reentrant (bool): Use reentrant gradient checkpointing, default False
        attn_implementation (Optional[str]): Attention implementation to use, one of "eager", "sdpa",
            "flash_attention_2", "native-sparse" or None, default "flash_attention_2"
        moe_implementation (Optional[str]): MoE implementation to use, "eager", "fused" or None,
            default "eager"
        force_use_huggingface (bool): Force loading model from huggingface, default False
        activation_gpu_limit (float): When enabling activation offload, `activation_gpu_limit` GB
            activations are allowed to reserve on GPU, default 0.0
        basic_modules (Optional[list[str]]): List of basic modules to use, default empty list

    """

    wrap_policy: dict[str, Any] = field(default_factory=dict)
    offload_policy: bool = False
    reshard_after_forward: bool = True
    forward_prefetch: bool = False
    use_orig_params: bool = False
    entropy_from_logits_with_chunking: bool = False
    use_torch_compile: bool = True
    entropy_checkpointing: bool = False
    strategy: str = "veomni"
    data_parallel_size: int = 1
    data_parallel_replicate_size: int = 1
    data_parallel_shard_size: int = 1
    tensor_parallel_size: int = 1
    expert_parallel_size: int = 1
    pipeline_parallel_size: int = 1
    context_parallel_size: int = 1
    ulysses_parallel_size: int = 1
    # NOTE(review): the default "fsdp" is not one of the Literal options below. The runtime default
    # is kept unchanged here; decide whether the default or the annotation is the intended one.
    data_parallel_mode: Literal["ddp", "fsdp1", "fsdp2"] = "fsdp"
    # seed / full_determinism restate the defaults already declared on EngineConfig
    seed: int = 42
    full_determinism: bool = False
    mixed_precision: bool = False
    init_device: str = "meta"
    enable_full_shard: bool = False
    ckpt_manager: Literal["dcp"] = "dcp"
    load_checkpoint_path: Optional[str] = None
    enable_fsdp_offload: bool = False
    enable_reentrant: bool = False
    attn_implementation: Optional[Literal["eager", "sdpa", "flash_attention_2", "native-sparse"]] = "flash_attention_2"
    # Optional[...] already admits None, so it is not repeated inside the Literal
    moe_implementation: Optional[Literal["eager", "fused"]] = "eager"
    force_use_huggingface: bool = False
    activation_gpu_limit: float = 0.0
    basic_modules: Optional[list[str]] = field(default_factory=list)

    def __post_init__(self):
        """Run the parent hook, then check the backend name.

        Raises:
            AssertionError: If ``strategy`` is not "veomni".
        """
        super().__post_init__()
        assert self.strategy == "veomni", f"strategy {self.strategy} not supported"


@dataclass
class TrainingWorkerConfig(BaseConfig):
    """Container grouping the model, engine, optimizer, checkpoint and profiler configs.

    Every field defaults to None, so the annotations are Optional; no validation is done here.

    Args:
        model_type (Optional[str]): Model type (language_model/value_model), default None
        model_config (Optional[HFModelConfig]): Model configuration, default None
        engine_config (Optional[EngineConfig]): Engine configuration, default None
        optimizer_config (Optional[OptimizerConfig]): Optimizer configuration, default None
        checkpoint_config (Optional[CheckpointConfig]): Checkpoint configuration, default None
        profiler_config (Optional[ProfilerConfig]): Profiler configuration, default None
    """

    model_type: Optional[str] = None  # model type (language_model/value_model)
    model_config: Optional[HFModelConfig] = None
    engine_config: Optional[EngineConfig] = None
    optimizer_config: Optional[OptimizerConfig] = None
    checkpoint_config: Optional[CheckpointConfig] = None
    profiler_config: Optional[ProfilerConfig] = None
