# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import types
from typing import TYPE_CHECKING, Any, Optional

from ..extras import logging


if TYPE_CHECKING:
    from ..hparams import TrainingArguments


# Module-scoped logger from the project's logging helpers; the functions below log through its
# ``info_rank0`` (rank-0 only) and ``debug`` methods.
logger = logging.get_logger(__name__)


def create_fp8_kwargs(training_args: "TrainingArguments") -> list[Any]:
    """Create the Accelerate FP8 recipe kwargs for the configured backend.

    With ``fp8_backend="te"`` a Transformer Engine ``FP8RecipeKwargs`` is returned. For any other
    backend value an ``AORecipeKwargs`` is returned. It carries a rowwise ``Float8LinearConfig``
    when the backend is ``"torchao"`` or ``"auto"`` (``config=None`` otherwise), plus a module
    filter that excludes embedding/output-style layers and weights whose dimensions are not
    multiples of 16.

    Args:
        training_args: Training arguments containing FP8 configuration. ``fp8`` is required.
            ``fp8_backend`` (default ``"auto"``) and ``fp8_enable_fsdp_float8_all_gather`` are
            read only if present.

    Returns:
        A single-element list with the recipe kwargs. Returns an empty list if FP8 is disabled or
        the configuration could not be built; a build failure is logged as a warning.
    """
    if not training_args.fp8:
        return []

    backend = getattr(training_args, "fp8_backend", "auto")
    logger.info_rank0(f"Creating FP8 configuration with backend: {backend}")

    try:
        # Transformer Engine backend (optimal for Hopper GPUs)
        if backend == "te":
            from accelerate.utils import FP8RecipeKwargs

            logger.info_rank0("Using Transformer Engine FP8 backend")
            return [FP8RecipeKwargs(backend="TE", fp8_format="HYBRID", amax_history_len=16, amax_compute_algo="max")]

        # TorchAO backend (default)
        from accelerate.utils import AORecipeKwargs

        config = None
        if backend in ("torchao", "auto"):
            import dataclasses

            from torchao.float8 import Float8LinearConfig

            # Rowwise scaling, as recommended by torchao.
            config = Float8LinearConfig.from_recipe_name("rowwise")

            # These flags exist only on some torchao versions. Float8LinearConfig is declared as a
            # frozen dataclass in torchao (confirm for the pinned version), so assigning to it
            # directly raises FrozenInstanceError. The broad except below would then silently
            # disable FP8. Build a modified copy instead.
            overrides = {
                flag: True for flag in ("enable_amax_init", "enable_pre_and_post_forward") if hasattr(config, flag)
            }
            if overrides:
                config = dataclasses.replace(config, **overrides)

        # Layers whose names contain any of these substrings stay in higher precision.
        skip_keywords = ("embed", "lm_head", "output", "classifier")

        def module_filter_func(module, layer_name):
            """Return True if ``module`` (named ``layer_name``) should be converted to FP8."""
            # Skip embedding and output layers for numerical stability
            lowered_name = layer_name.lower()
            if any(keyword in lowered_name for keyword in skip_keywords):
                return False

            # Only convert modules with a 2-D weight. ``weight`` may be present but set to None.
            weight = getattr(module, "weight", None)
            if weight is None or len(weight.shape) != 2:
                return False

            out_features, in_features = weight.shape[0], weight.shape[1]

            # Skip layers with dimensions not divisible by 16 to avoid FP8 kernel errors
            if in_features % 16 != 0 or out_features % 16 != 0:
                logger.debug(
                    f"Skipping layer {layer_name} with dimensions {out_features}x{in_features} (not divisible by 16)"
                )
                return False

            return True

        # The FSDP all-gather option is only reported here; it is not applied to ``config``.
        if getattr(training_args, "fp8_enable_fsdp_float8_all_gather", False):
            logger.info_rank0("FSDP float8 all-gather optimization requested")

        return [AORecipeKwargs(config=config, module_filter_func=module_filter_func)]
    except Exception as e:
        # Best-effort: fall back to no FP8 recipe, but make the failure visible.
        logger.warning_rank0(f"Failed to create FP8 configuration: {e}")
        return []


def get_fp8_mixed_precision(training_args: "TrainingArguments") -> Optional[str]:
    """Map the FP8 flag to Accelerate's ``mixed_precision`` value.

    Args:
        training_args: Training arguments; only ``fp8`` is consulted.

    Returns:
        ``"fp8"`` when ``training_args.fp8`` is truthy, ``None`` otherwise.
    """
    if not training_args.fp8:
        return None
    return "fp8"


def configure_fp8_environment(training_args: "TrainingArguments") -> None:
    """Export the environment variables used for FP8 training through HuggingFace Accelerate.

    Accelerate drives FP8 whether DeepSpeed or FSDP handles distribution. This helper therefore
    only sets process environment variables and builds the FP8 recipe kwargs once, so that their
    count is logged. It does nothing when ``training_args.fp8`` is falsy.

    Args:
        training_args: Training arguments containing FP8 configuration. ``fp8_backend`` and
            ``fp8_enable_fsdp_float8_all_gather`` are optional.
    """
    if not training_args.fp8:
        return

    # Mixed precision for Accelerate is always fp8 here.
    os.environ["ACCELERATE_MIXED_PRECISION"] = "fp8"
    logger.info_rank0("Set ACCELERATE_MIXED_PRECISION=fp8")

    # Only an explicit (non-"auto") backend choice is exported.
    chosen_backend = getattr(training_args, "fp8_backend", "auto")
    if chosen_backend != "auto":
        os.environ["FP8_BACKEND"] = chosen_backend
        logger.info_rank0(f"Set FP8_BACKEND={chosen_backend}")

    # Build the recipe kwargs purely for validation/logging; the result is not kept.
    recipe_kwargs = create_fp8_kwargs(training_args)
    logger.info_rank0(f"FP8 AORecipeKwargs created: {len(recipe_kwargs)} items")

    if getattr(training_args, "fp8_enable_fsdp_float8_all_gather", False):
        os.environ["FP8_ENABLE_FSDP_FLOAT8_ALL_GATHER"] = "true"
        logger.info_rank0("Set FP8_ENABLE_FSDP_FLOAT8_ALL_GATHER=true")

    logger.info_rank0("FP8 environment configured - all FP8 training handled by HuggingFace Accelerate")


def verify_fp8_status(accelerator, training_args: "TrainingArguments") -> None:
    """Log the FP8 state that Accelerate reports after model preparation.

    Args:
        accelerator: The HuggingFace Accelerator instance. Its ``fp8_enabled`` and ``fp8_backend``
            attributes are read via ``getattr``, so an accelerator lacking them is tolerated
            (reported as ``False`` / ``"UNKNOWN"``).
        training_args: Training arguments containing FP8 configuration.
    """
    if not training_args.fp8:
        return

    # Check Accelerate's FP8 status
    fp8_enabled = getattr(accelerator, "fp8_enabled", False)
    fp8_backend_type = getattr(accelerator, "fp8_backend", "UNKNOWN")

    backend = getattr(training_args, "fp8_backend", "auto")
    if backend in ("torchao", "auto"):
        logger.info_rank0(
            "FP8 training enabled with TorchAO backend. For optimal performance, "
            "ensure model layer dimensions are mostly divisible by 16. "
            "If you encounter issues, try fp8_backend='te' with Transformer Engine."
        )
    else:
        logger.info_rank0(f"FP8 training enabled with {backend} backend.")

    logger.info_rank0(f"Accelerate FP8 status - enabled: {fp8_enabled}, backend: {fp8_backend_type}")

    if not fp8_enabled:
        # Use the warning level (not an INFO message with a "WARNING:" prefix) so this is not
        # hidden when INFO-level logs are filtered out.
        logger.warning_rank0("FP8 was requested but Accelerate shows fp8_enabled=False. FP8 may not be working.")


def patch_accelerator_for_fp8() -> None:
    """Patch ``Accelerator.__init__`` to inject a Transformer Engine FP8 recipe.

    HuggingFace Trainer doesn't pass ``kwargs_handlers`` to Accelerator. This monkey-patch
    therefore supplies a TE FP8 recipe (``HYBRID`` format, amax history length 16, ``max`` amax
    algorithm) and forces ``mixed_precision="fp8"`` whenever the caller passed no (or empty)
    ``kwargs_handlers``. Callers that pass their own non-empty handlers are left untouched.
    Repeated calls are no-ops.

    Raises:
        ImportError: If ``transformer_engine.pytorch`` or ``accelerate`` cannot be imported.
    """
    import functools

    import transformer_engine.pytorch as te
    from accelerate import Accelerator

    # Guard against multiple patches
    if getattr(Accelerator, "_te_fp8_patched", False):
        return

    # Accelerate 1.12+ expects te.fp8.check_mxfp8_support, which some TE versions lack. Check for
    # the function itself rather than only the module. Otherwise an existing ``te.fp8`` that
    # predates check_mxfp8_support would be left without the stub.
    if not hasattr(te, "fp8"):
        te.fp8 = types.ModuleType("fp8")
    if not hasattr(te.fp8, "check_mxfp8_support"):
        te.fp8.check_mxfp8_support = lambda: (False, "MXFP8 not supported")

    try:
        from accelerate.utils import TERecipeKwargs as FP8Recipe

        use_te_recipe = True
    except ImportError:
        from accelerate.utils import FP8RecipeKwargs as FP8Recipe

        use_te_recipe = False

    # Only the legacy FP8RecipeKwargs is given an explicit backend argument.
    recipe_kwargs = {"fp8_format": "HYBRID", "amax_history_len": 16, "amax_compute_algo": "max"}
    if not use_te_recipe:
        recipe_kwargs["backend"] = "TE"

    original_init = Accelerator.__init__

    # functools.wraps keeps the original name/docstring and sets __wrapped__, so introspection
    # (e.g. inspect.signature) still reports the real Accelerator.__init__ parameters.
    @functools.wraps(original_init)
    def patched_init(self, *args, **kwargs):
        if not kwargs.get("kwargs_handlers"):
            # A fresh recipe object per Accelerator, as before.
            kwargs["kwargs_handlers"] = [FP8Recipe(**recipe_kwargs)]
            # Only force mixed_precision when we inject handlers
            kwargs["mixed_precision"] = "fp8"
        return original_init(self, *args, **kwargs)

    Accelerator.__init__ = patched_init
    Accelerator._te_fp8_patched = True
