# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Optional

import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

from cosmos_predict1.utils import distributed
from cosmos_predict1.utils.callback import Callback


@torch.jit.script
def _fused_nan_to_num(params: List[torch.Tensor]):
    for param in params:
        torch.nan_to_num(param, nan=0.0, posinf=0.0, neginf=0.0, out=param)


class GradClip(Callback):
    def __init__(
        self, clip_norm=1.0, force_finite: bool = True, model_key: Optional[str] = None, fsdp_enabled: bool = False
    ):
        self.clip_norm = clip_norm
        self.force_finite = force_finite
        self.model_key = model_key
        self.fsdp_enabled = fsdp_enabled

    def on_before_optimizer_step(
        self,
        model_ddp: distributed.DistributedDataParallel,
        optimizer: torch.optim.Optimizer,
        scheduler: torch.optim.lr_scheduler.LRScheduler,
        grad_scaler: torch.amp.GradScaler,
        iteration: int = 0,
    ) -> None:
        del optimizer, scheduler
        if isinstance(model_ddp, distributed.DistributedDataParallel):
            model = model_ddp.module
        else:
            model = model_ddp

        # select sub-network if specified
        if self.model_key is not None:
            items = self.model_key.split(".")
            for item in items:
                model = getattr(model, item)

        if self.force_finite:
            params = []
            for param in model.parameters():
                if param.grad is not None:
                    params.append(param.grad)
                    # torch.nan_to_num(param.grad, nan=0, posinf=0, neginf=0, out=param.grad)
            _fused_nan_to_num(params)

        # check if FSDP is used
        # total_norm
        if isinstance(model, FSDP) and self.fsdp_enabled:
            model.clip_grad_norm_(self.clip_norm)
        else:
            torch.nn.utils.clip_grad_norm_(model.parameters(), self.clip_norm, foreach=True)
