# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from __future__ import annotations

import logging
import math

from dataclasses import dataclass
from typing import Dict, Iterable, Iterator, List, Tuple

import torch

from torchtitan.config.job_config import GradientClipping as GradientClippingConfig
from torchtitan.distributed import utils as dist_utils
from torchtitan.distributed.parallel_dims import ParallelDims

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Flipped to True by `_iter_layer_param_groups` after it has been iterated to
# completion, so parameter-group details are logged only once per process.
_LOGGED_PARAM_GROUPS = False


def _all_parameters_from_model_parts(
    model_parts: List[torch.nn.Module],
) -> List[torch.nn.Parameter]:
    return [p for m in model_parts for p in m.parameters()]


class BaseGradClipper:
    """Common interface for the gradient clippers built by this module.

    Subclasses override :meth:`step`. :meth:`get_per_layer_norms` has a
    default implementation so callers holding only the base type can always
    ask for per-layer norms.
    """

    def step(
        self,
        model_parts: List[torch.nn.Module],
        parallel_dims: ParallelDims,
        max_norm: float,
        max_norm_last_layer: float | None = None,
    ) -> torch.Tensor:
        """Clip the gradients of ``model_parts`` in place.

        Args:
            model_parts: Modules whose parameter gradients are clipped.
            parallel_dims: Parallelism description forwarded to the
                distributed norm computation.
            max_norm: Maximum allowed gradient norm.
            max_norm_last_layer: Optional separate limit for the last layer;
                how (and whether) it is used is up to the subclass.

        Returns:
            The gradient norm reported by the subclass (the implementations in
            this file return the norm over all parameters).

        Raises:
            NotImplementedError: Always, on the base class.
        """
        raise NotImplementedError(f"{type(self).__name__} must implement step()")

    def get_per_layer_norms(self) -> Dict[str, float]:
        """Return per-layer gradient norms from the last step.

        The base implementation tracks nothing and returns an empty dict.
        """
        return {}


def _unwrap_model(m: torch.nn.Module) -> torch.nn.Module:
    target = m
    for attr in ("module", "_ddp_wrapped_module", "_orig_mod", "model"):
        if hasattr(target, attr):
            target = getattr(target, attr)
    return target


# Attribute names looked up on an unwrapped model part to find its output-side
# ("final") components and its input embeddings.
_FINAL_COMPONENT_NAMES = ("norm", "output", "lm_head")
_EMBED_COMPONENT_NAMES = (
    "tok_embeddings",
    "embeddings",
    "embedding",
    "word_embeddings",
)

# Text placed right after the group name in the one-time log line, per group type.
_GROUP_LOG_SUFFIX = {
    "transformer_layer": "",
    "final_layers": " (final layers - norm/output)",
    "single_group": " (single group)",
    "embeddings": " (embeddings)",
    "ungrouped": " (not in layers/embeddings/final layers)",
}


def _trainable_named_params(
    module: torch.nn.Module, prefix: str = ""
) -> Tuple[List[str], List[torch.nn.Parameter]]:
    """Return index-aligned (names, params) for the trainable parameters of ``module``."""
    names: List[str] = []
    params: List[torch.nn.Parameter] = []
    for name, param in module.named_parameters(prefix=prefix):
        if param.requires_grad:
            names.append(name)
            params.append(param)
    return names, params


def _collect_component_params(
    owner: object, component_names: Iterable[str]
) -> Tuple[List[str], List[torch.nn.Parameter]]:
    """Collect trainable params of the named child modules of ``owner`` that exist."""
    names: List[str] = []
    params: List[torch.nn.Parameter] = []
    for component_name in component_names:
        component = getattr(owner, component_name, None)
        # Missing, None, or non-module attributes contribute nothing.
        if isinstance(component, torch.nn.Module):
            sub_names, sub_params = _trainable_named_params(
                component, prefix=component_name
            )
            names.extend(sub_names)
            params.extend(sub_params)
    return names, params


def _format_param_names(
    names: Iterable[str], params: Iterable[torch.nn.Parameter]
) -> str:
    """Render ``name(shape)`` entries, comma separated, for log output."""
    return ", ".join(f"{name}({param.shape})" for name, param in zip(names, params))


def _iter_layer_param_groups(
    model_parts: List[torch.nn.Module],
) -> Iterator[Tuple[str, List[torch.nn.Parameter]]]:
    """Yield ``(group_name, params)`` for each logical layer group on this rank.

    Grouping rules, per model part (after ``_unwrap_model``):

    * If the part exposes ``layers`` as a ``dict``/``ModuleDict``, every entry
      with trainable parameters becomes ``part{i}/layer{key}``.
      Embedding parameters are merged into the lowest numerically-keyed layer
      (renamed ``..._with_embed``) and norm/output/lm_head parameters into the
      highest one (renamed ``..._with_final``).
    * If such a part has no numerically-keyed trainable layer to merge into,
      embeddings and final components are emitted as their own groups
      (``part{i}/embeddings`` and ``part{i}/final_layers``) instead of being
      dropped.
    * Any remaining trainable parameter of the part that belongs to none of
      the groups above is emitted as ``part{i}/ungrouped`` so that every
      trainable parameter is covered by some group.
    * A part without a ``layers`` mapping is a single group: it is named
      ``part{i}/final_layers_last{max_layer_id}`` when it holds final
      components and an earlier part had a ``layers`` mapping, else ``part{i}``.

    Only groups with at least one trainable parameter are produced. This is a
    generator: nothing runs until iteration starts, and the "already logged"
    flag is only set once the generator has been exhausted.
    """
    global _LOGGED_PARAM_GROUPS
    should_log = not _LOGGED_PARAM_GROUPS

    # Each record is [group_name, params, param_names, group_type]; the names
    # and params lists are index-aligned.
    all_groups: List[list] = []
    has_transformer_layers = False
    max_layer_id = -1

    for part_idx, m in enumerate(model_parts):
        base = _unwrap_model(m)
        layers_map = getattr(base, "layers", None)

        if not isinstance(layers_map, (dict, torch.nn.ModuleDict)):
            # No layer mapping: the whole part forms one group.
            names, params = _trainable_named_params(m)
            if not params:
                continue
            _, final_params = _collect_component_params(base, _FINAL_COMPONENT_NAMES)
            if final_params and has_transformer_layers:
                # Name signals that this part belongs with the last layer.
                group_name = f"part{part_idx}/final_layers_last{max_layer_id}"
                group_type = "final_layers"
            else:
                group_name = f"part{part_idx}"
                group_type = "single_group"
            all_groups.append([group_name, params, names, group_type])
            continue

        has_transformer_layers = True
        final_names, final_params = _collect_component_params(
            base, _FINAL_COMPONENT_NAMES
        )
        embed_names, embed_params = _collect_component_params(
            base, _EMBED_COMPONENT_NAMES
        )

        # Build one group per layer, remembering which groups hold the lowest
        # and highest numeric layer ids so embeddings/final params can be merged.
        local_groups: List[list] = []
        first_idx = last_idx = -1
        first_layer_id = last_layer_id = None
        for layer_key, layer_mod in layers_map.items():
            names, params = _trainable_named_params(layer_mod)
            if not params:
                continue
            try:
                layer_id = int(layer_key)
            except (TypeError, ValueError):
                layer_id = None  # non-numeric key: never a merge target
            if layer_id is not None:
                if first_layer_id is None or layer_id < first_layer_id:
                    first_layer_id, first_idx = layer_id, len(local_groups)
                if last_layer_id is None or layer_id > last_layer_id:
                    last_layer_id, last_idx = layer_id, len(local_groups)
                max_layer_id = max(max_layer_id, layer_id)
            local_groups.append(
                [f"part{part_idx}/layer{layer_key}", params, names, "transformer_layer"]
            )

        # Summary numbers refer to the plain layer groups, before any merge.
        layer_count = len(local_groups)
        total_params = (
            sum(p.numel() for group in local_groups for p in group[1])
            if should_log
            else 0
        )

        if first_idx >= 0:
            # first_idx and last_idx are set together, so both merges are possible.
            if embed_params:
                head = local_groups[first_idx]
                head[0] = f"part{part_idx}/layer{first_layer_id}_with_embed"
                head[1] = embed_params + head[1]
                head[2] = embed_names + head[2]
                if should_log:
                    logger.info(
                        f"Merging embedding params into first layer group: {head[0]}"
                    )
                    logger.info(
                        f"  Embed parameters: {_format_param_names(embed_names, embed_params)}"
                    )
            if final_params:
                tail = local_groups[last_idx]
                tail[0] = f"part{part_idx}/layer{last_layer_id}_with_final"
                tail[1] = tail[1] + final_params
                tail[2] = tail[2] + final_names
                if should_log:
                    logger.info(
                        f"Merging final layer params into last layer group: {tail[0]}"
                    )
                    logger.info(
                        f"  Final layer parameters: {_format_param_names(final_names, final_params)}"
                    )
        else:
            # Nothing to merge into: keep these parameters as standalone groups
            # so they are still clipped.
            if embed_params:
                local_groups.insert(
                    0,
                    [f"part{part_idx}/embeddings", embed_params, embed_names, "embeddings"],
                )
            if final_params:
                local_groups.append(
                    [f"part{part_idx}/final_layers", final_params, final_names, "final_layers"]
                )

        # Catch trainable parameters that no rule above picked up (including
        # ones that live on a wrapper around ``base``).
        grouped_ids = {id(p) for group in local_groups for p in group[1]}
        rest_names: List[str] = []
        rest_params: List[torch.nn.Parameter] = []
        for name, param in m.named_parameters():
            if param.requires_grad and id(param) not in grouped_ids:
                rest_names.append(name)
                rest_params.append(param)
        if rest_params:
            local_groups.append(
                [f"part{part_idx}/ungrouped", rest_params, rest_names, "ungrouped"]
            )

        all_groups.extend(local_groups)

        if should_log and layer_count > 0:
            logger.info(
                f"Model part {part_idx}: Found {layer_count} transformer layer groups with {total_params:,} total parameters"
            )

    # Yield the groups and, on the first full pass only, describe them in the log.
    for group_name, params, param_names, group_type in all_groups:
        if should_log:
            param_count = sum(p.numel() for p in params)
            suffix = _GROUP_LOG_SUFFIX[group_type]
            logger.info(
                f"Gradient clipping layer group: {group_name}{suffix} "
                f"with {len(params)} parameters ({param_count:,} total params)"
            )
            logger.info(f"  Parameters: {_format_param_names(param_names, params)}")
        yield (group_name, params)

    if should_log:
        _LOGGED_PARAM_GROUPS = True


class VanillaGradClipper(BaseGradClipper):
    """Plain max-norm gradient clipping, applied globally or per layer group.

    With ``cfg.scope == "per_layer"`` every group produced by
    ``_iter_layer_param_groups`` is clipped on its own; any other scope clips
    all parameters together in one call.
    """

    def __init__(self, cfg: GradientClippingConfig):
        self.scope = cfg.scope
        # Per-layer norms recorded by the most recent per-layer step.
        self.per_layer_norms: Dict[str, float] = {}

    @staticmethod
    def _clip_grad_norm(
        params: List[torch.nn.Parameter],
        max_norm: float,
        parallel_dims: ParallelDims,
    ) -> torch.Tensor:
        """Call ``dist_utils.clip_grad_norm_`` with this module's standard arguments.

        Passing ``float("inf")`` as ``max_norm`` is used throughout this file
        to measure the norm without changing the gradients.
        """
        return dist_utils.clip_grad_norm_(
            params,
            max_norm,
            foreach=True,
            pp_mesh=(
                parallel_dims.world_mesh["pp"] if parallel_dims.pp_enabled else None
            ),
            ep_dense_params_mesh_ndim=(
                parallel_dims.dense_params_mesh_ndim
                if parallel_dims.ep_enabled
                else None
            ),
        )

    def step(
        self,
        model_parts: List[torch.nn.Module],
        parallel_dims: ParallelDims,
        max_norm: float,
        max_norm_last_layer: float | None = None,
    ) -> torch.Tensor:
        """Clip gradients in place and return the global gradient norm.

        Args:
            model_parts: Modules whose parameter gradients are clipped.
            parallel_dims: Parallelism description used to pick the pipeline
                mesh and expert-parallel settings for the norm computation.
            max_norm: Clipping limit. In per-layer scope, ``None`` or infinity
                disables per-layer clipping entirely (only the global norm is
                computed).
            max_norm_last_layer: Per-layer scope only. When given, it replaces
                ``max_norm`` for groups whose name contains "final" or "last".

        Returns:
            The norm over all parameters as returned by
            ``dist_utils.clip_grad_norm_``.
        """
        all_params = _all_parameters_from_model_parts(model_parts)

        if self.scope != "per_layer":
            # Global scope: clip all params together.
            return self._clip_grad_norm(all_params, max_norm, parallel_dims)

        # Always compute the global norm (without clipping) for logging consistency.
        global_norm = self._clip_grad_norm(all_params, float("inf"), parallel_dims)

        # Drop norms from the previous step up front so an early return never
        # leaves stale values behind.
        self.per_layer_norms.clear()

        if max_norm is None or max_norm == float("inf"):
            return global_norm

        # NOTE(review): each rank iterates only its local groups while pp_mesh is
        # passed for every group; confirm dist_utils.clip_grad_norm_ tolerates
        # ranks that differ in group count/content under pipeline parallelism.
        for layer_key, params in _iter_layer_param_groups(model_parts):
            if not params:
                continue

            # Group names carry the "last layer" marker ("final"/"last").
            lowered_key = layer_key.lower()
            is_last_layer = "final" in lowered_key or "last" in lowered_key

            current_max_norm = (
                max_norm_last_layer
                if is_last_layer and max_norm_last_layer is not None
                else max_norm
            )

            # One call both measures and clips: the returned value is used as the
            # layer's norm, relying on clip_grad_norm_ reporting the norm measured
            # before scaling (as torch.nn.utils.clip_grad_norm_ does). An infinite
            # limit therefore only measures.
            layer_norm = self._clip_grad_norm(params, current_max_norm, parallel_dims)
            self.per_layer_norms[layer_key] = float(layer_norm.item())

        return global_norm

    def get_per_layer_norms(self) -> Dict[str, float]:
        """Return a copy of the per-layer gradient norms from the last step."""
        return self.per_layer_norms.copy()

    def compute_per_layer_norms(
        self,
        model_parts: List[torch.nn.Module],
        parallel_dims: ParallelDims,
    ) -> Dict[str, float]:
        """Compute per-layer gradient norms without applying clipping.

        Intended for use in global vanilla clipping, before and/or after
        clipping. Does not touch ``self.per_layer_norms``.
        """
        norms: Dict[str, float] = {}
        for layer_key, params in _iter_layer_param_groups(model_parts):
            if not params:
                continue
            layer_norm = self._clip_grad_norm(params, float("inf"), parallel_dims)
            norms[layer_key] = float(layer_norm.item())
        return norms


@dataclass
class _EMAStats:
    """Running gradient-norm statistics for one clipping scope (global or one layer group)."""

    # True once mean/var have been seeded from the warmup buffer.
    initialized: bool = False
    # Exponential moving average of the gradient norm; None until seeded.
    mean: float | None = None
    # Exponential moving estimate of the norm's variance; None until seeded.
    var: float | None = None
    # Norms collected during warmup; reset to an empty list once mean/var are seeded.
    buffer: list[float] | None = None


class ZClipGradClipper(BaseGradClipper):
    """Adaptive gradient clipping driven by running statistics of the gradient norm.

    A mean/variance estimate of the gradient norm is kept as an exponential
    moving average (seeded from the first ``warmup_steps`` norms). A norm that
    is an outlier with respect to those statistics is clipped to a threshold
    derived from them; the vanilla ``max_norm`` limit is applied on top.
    Statistics are kept either globally or per layer group, depending on
    ``cfg.scope`` ("global" or "per_layer").
    """

    def __init__(self, cfg: GradientClippingConfig):
        self.alpha = cfg.alpha
        self.z_thresh = cfg.z_thresh
        self.eps = cfg.eps
        self.warmup_steps = cfg.warmup_steps
        self.mode = cfg.mode
        # clip_option only has meaning in z-score mode.
        self.clip_option = cfg.clip_option if cfg.mode == "zscore" else None
        self.clip_factor = cfg.clip_factor
        self.skip_update_on_spike = cfg.skip_update_on_spike
        self.scope = cfg.scope

        # Global-scope statistics, per-layer statistics, and the per-layer
        # norms seen in the most recent step.
        self.stats = _EMAStats(initialized=False, mean=None, var=None, buffer=[])
        self.layer_stats: Dict[str, _EMAStats] = {}
        self.per_layer_norms: Dict[str, float] = {}

    def _compute_total_grad_norm(
        self, params: Iterable[torch.nn.Parameter], parallel_dims: ParallelDims
    ) -> torch.Tensor:
        """Return the gradient norm of ``params`` without modifying gradients."""
        # Use the distributed-aware norm computation; max_norm=inf ensures no clipping happens.
        return dist_utils.clip_grad_norm_(
            list(params),
            float("inf"),
            foreach=True,
            pp_mesh=(
                parallel_dims.world_mesh["pp"] if parallel_dims.pp_enabled else None
            ),
            ep_dense_params_mesh_ndim=(
                parallel_dims.dense_params_mesh_ndim
                if parallel_dims.ep_enabled
                else None
            ),
        )

    # ---- global-scope helpers: thin wrappers over the per-stats versions ----

    def _initialize_ema(self):
        """Seed the global statistics from the global warmup buffer."""
        self._initialize_layer_ema(self.stats)

    def _update_ema(self, effective_norm: float):
        """Fold ``effective_norm`` into the global statistics."""
        self._update_layer_ema(self.stats, effective_norm)

    def _compute_clip_val(self, total_norm: float) -> float | None:
        """Return the clipping threshold for the global norm, or None for no ZClip."""
        return self._compute_layer_clip_val(self.stats, total_norm)

    # ---- statistics helpers operating on an explicit _EMAStats ----

    def _get_layer_stats(self, key: str) -> _EMAStats:
        """Return the statistics for layer group ``key``, creating them on first use."""
        if key not in self.layer_stats:
            self.layer_stats[key] = _EMAStats(
                initialized=False, mean=None, var=None, buffer=[]
            )
        return self.layer_stats[key]

    def _initialize_layer_ema(self, st: _EMAStats):
        """Seed ``st`` with the mean and (population) variance of its warmup buffer."""
        assert st.buffer is not None and len(st.buffer) > 0
        mean = sum(st.buffer) / len(st.buffer)
        var = sum((x - mean) ** 2 for x in st.buffer) / len(st.buffer)
        st.mean = float(mean)
        st.var = float(var)
        st.initialized = True
        st.buffer = []

    def _update_layer_ema(self, st: _EMAStats, effective_norm: float):
        """Fold ``effective_norm`` into ``st`` with decay ``alpha``.

        Non-finite values are ignored: a single NaN/inf would otherwise turn
        mean/var into NaN permanently and silently disable all later clipping.
        """
        assert st.mean is not None and st.var is not None
        if not math.isfinite(effective_norm):
            return
        mean = self.alpha * st.mean + (1.0 - self.alpha) * effective_norm
        # Variance uses the deviation from the *updated* mean.
        var = self.alpha * st.var + (1.0 - self.alpha) * (effective_norm - mean) ** 2
        st.mean = float(mean)
        st.var = float(var)

    def _compute_layer_clip_val(self, st: _EMAStats, total_norm: float) -> float | None:
        """Return the norm ``total_norm`` should be clipped to, or None.

        * ``mode == "percentile"``: the threshold ``mean + z_thresh * std`` is
          returned when the norm exceeds it.
        * otherwise a z-score ``(norm - mean) / (std + eps)`` is computed; when
          it exceeds ``z_thresh`` the result depends on ``clip_option``:
          "adaptive_scaling" gives
          ``(mean + z_thresh * std / (z / z_thresh)) * clip_factor``, "mean"
          gives the running mean, and any other value gives None.
        """
        assert st.mean is not None and st.var is not None
        std = st.var**0.5
        if self.mode == "percentile":
            threshold = st.mean + self.z_thresh * std
            if total_norm > threshold:
                return float(threshold)
            return None
        # zscore mode
        z = (total_norm - st.mean) / (std + self.eps)
        if z > self.z_thresh:
            if self.clip_option == "adaptive_scaling":
                eta = z / self.z_thresh
                threshold = st.mean + (self.z_thresh * std) / max(eta, 1e-12)
                threshold *= self.clip_factor
                return float(threshold)
            elif self.clip_option == "mean":
                return float(st.mean)
        return None

    @staticmethod
    def _apply_in_place_clipping(
        params: Iterable[torch.nn.Parameter],
        total_norm: float,
        max_allowed_norm: float | None,
    ) -> None:
        """Scale gradients in place so a norm of ``total_norm`` becomes ``max_allowed_norm``.

        Does nothing when no limit is given or the norm is already within it.
        """
        if max_allowed_norm is None or total_norm <= max_allowed_norm:
            return
        # Small constant guards against division by zero.
        coef = max_allowed_norm / (total_norm + 1e-6)
        for p in params:
            if p.grad is not None:
                p.grad.mul_(coef)

    def _clip_with_stats(
        self,
        st: _EMAStats,
        params: List[torch.nn.Parameter],
        norm: float,
        max_norm: float | None,
    ) -> None:
        """Run one ZClip update for ``params`` whose gradient norm is ``norm``.

        During warmup the norm is buffered and only the vanilla ``max_norm`` is
        applied. Afterwards the smaller of the ZClip threshold and ``max_norm``
        is applied and the statistics are updated.
        """
        if not st.initialized:
            if st.buffer is None:
                st.buffer = []
            # Keep NaN/inf out of the warmup samples so they cannot poison the seed.
            if math.isfinite(norm):
                st.buffer.append(norm)
            if st.buffer and len(st.buffer) >= self.warmup_steps:
                self._initialize_layer_ema(st)
            # During warmup we only apply vanilla max_norm, if provided.
            if max_norm is not None and max_norm != float("inf"):
                self._apply_in_place_clipping(params, norm, max_norm)
            return

        clip_val = self._compute_layer_clip_val(st, norm)
        if clip_val is None:
            effective_max = max_norm
        elif max_norm is None or max_norm == float("inf"):
            effective_max = clip_val
        else:
            effective_max = min(clip_val, max_norm)

        if effective_max is not None and effective_max != float("inf"):
            self._apply_in_place_clipping(params, norm, effective_max)

        # A spike (clip_val set) updates the statistics with the clipped value,
        # or not at all when skip_update_on_spike is enabled.
        if clip_val is None or not self.skip_update_on_spike:
            self._update_layer_ema(st, norm if clip_val is None else clip_val)

    def step(
        self,
        model_parts: List[torch.nn.Module],
        parallel_dims: ParallelDims,
        max_norm: float,
        max_norm_last_layer: float | None = None,
    ) -> torch.Tensor:
        """Apply ZClip to the gradients in place and return the global norm.

        Args:
            model_parts: Modules whose parameter gradients are clipped.
            parallel_dims: Parallelism description used for the distributed
                norm computation.
            max_norm: Vanilla upper limit applied in addition to ZClip;
                ``None`` or infinity disables it.
            max_norm_last_layer: Accepted for interface compatibility; not
                used by this clipper.

        Returns:
            The norm over all parameters, measured before any clipping.

        Raises:
            ValueError: If ``self.scope`` is neither "per_layer" nor "global".
        """
        params = _all_parameters_from_model_parts(model_parts)
        # Global norm is returned in both scopes for logging consistency.
        total_norm_tensor = self._compute_total_grad_norm(params, parallel_dims)

        if self.scope == "per_layer":
            self.per_layer_norms.clear()
            # Per-layer ZClip: each group present on this rank keeps its own
            # statistics and warmup, independent of the global ones.
            for key, layer_params in _iter_layer_param_groups(model_parts):
                if not layer_params:
                    continue
                layer_norm = float(
                    self._compute_total_grad_norm(layer_params, parallel_dims).item()
                )
                self.per_layer_norms[key] = layer_norm
                self._clip_with_stats(
                    self._get_layer_stats(key), layer_params, layer_norm, max_norm
                )
            return total_norm_tensor

        if self.scope == "global":
            # No per-layer norms are produced in global scope.
            self.per_layer_norms.clear()
            # The host-side float is only needed here, so the .item() sync is
            # not paid in per-layer scope.
            total_norm = float(total_norm_tensor.item())
            self._clip_with_stats(self.stats, params, total_norm, max_norm)
            return total_norm_tensor

        raise ValueError(f"Unknown gradient clipping scope: {self.scope}")

    def get_per_layer_norms(self) -> Dict[str, float]:
        """Return a copy of the per-layer gradient norms from the last step."""
        return self.per_layer_norms.copy()


def build_grad_clipper(cfg: GradientClippingConfig) -> BaseGradClipper:
    """Create the gradient clipper selected by ``cfg.method``.

    "vanilla" yields a ``VanillaGradClipper`` and "zclip" a
    ``ZClipGradClipper``; both are constructed from ``cfg``.

    Raises:
        ValueError: If ``cfg.method`` names neither of the two.
    """
    method = cfg.method
    if method != "vanilla" and method != "zclip":
        raise ValueError(f"Unknown gradient clipping method: {cfg.method}")
    if method == "vanilla":
        clipper = VanillaGradClipper(cfg)
    else:
        clipper = ZClipGradClipper(cfg)
    return clipper
