# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.

import os
import socket
from datetime import timedelta

try:
    import nvidia_resiliency_ext.inprocess as inprocess
except ImportError:
    inprocess = None

import warnings

import torch

from megatron.core import rerun_state_machine
from megatron.training import get_args
from megatron.training.async_utils import (
    reset_persistent_async_worker,
)

from . import arguments


def destroy_state():
    from . import training
    training.destroy_global_state()
    rerun_state_machine.destroy_rerun_state_machine()

def inprocess_restart(train, args):
    if inprocess is None:
        warnings.warn('In-process restart is not available')
        return train

    if 'TORCH_CPP_LOG_LEVEL' not in os.environ or os.environ['TORCH_CPP_LOG_LEVEL'] not in (
        'error',
        'fatal',
    ):
        warnings.warn(
            'Set TORCH_CPP_LOG_LEVEL=error to suppress c10d waitForInput timeout warning messages'
        )

    # Layers represents a configuration for a layer of branches at a certain
    # depth in a topology tree constructed by inprocess.rank_assignment.Tree.
    # First layer contains all ranks and it's the root of the topology tree,
    # the second optional layer groups ranks by nodes.
    layers = [
        inprocess.rank_assignment.Layer(
            min_ranks=args.inprocess_active_world_size,
            max_ranks=args.inprocess_active_world_size,
            flag=inprocess.rank_assignment.LayerFlag.RESERVE,
        )
    ]
    if args.inprocess_granularity == 'node':
        device_count = torch.cuda.device_count()

        layers.append(
            inprocess.rank_assignment.Layer(
                min_ranks=device_count,
                max_ranks=device_count,
                key_or_fn=lambda _: socket.gethostname(),
                flag=inprocess.rank_assignment.LayerFlag.RESERVE,
            )
        )

    finalize = [
        inprocess.finalize.ThreadedFinalize(timeout=timedelta(seconds=10), fn=destroy_state)
    ]

    if args.inprocess_empty_cuda_cache:
        finalize.append(
            inprocess.finalize.ThreadedFinalize(
                timeout=timedelta(seconds=10), fn=torch.cuda.empty_cache
            )
        )

    initialize = inprocess.Compose(
        inprocess.initialize.RetryController(min_world_size=args.inprocess_active_world_size),
        inprocess.nested_restarter.NestedRestarterHandlingCompleted(),
    )

    class AbortCheckpoint(inprocess.abort.Abort):
        def __call__(
            self, state: inprocess.state.FrozenState
        ) -> inprocess.state.FrozenState:
            reset_persistent_async_worker()
            return state

    abort = inprocess.Compose(
        inprocess.abort.AbortTransformerEngine(),
        inprocess.abort.AbortTorchDistributed(),
        AbortCheckpoint(),
        inprocess.nested_restarter.NestedRestarterHandlingStarting(),
    )
    completion = inprocess.nested_restarter.NestedRestarterFinalized()
    terminate = inprocess.nested_restarter.NestedRestarterAborted()

    train = inprocess.Wrapper(
        store_kwargs={
            'timeout': timedelta(seconds=300),
            'port': int(os.environ['MASTER_PORT']) + 2,
        },
        initialize=initialize,
        abort=abort,
        completion=completion,
        terminate=terminate,
        health_check=inprocess.health_check.CudaHealthCheck(timeout=timedelta(seconds=10)),
        rank_assignment=inprocess.rank_assignment.Tree(layers=layers),
        finalize=inprocess.Compose(*finalize),
        heartbeat_interval=timedelta(seconds=args.inprocess_heartbeat_interval),
        heartbeat_timeout=timedelta(seconds=args.inprocess_heartbeat_timeout),
        barrier_timeout=timedelta(seconds=args.inprocess_barrier_timeout),
        completion_timeout=timedelta(seconds=args.inprocess_completion_timeout),
        monitor_process_interval=timedelta(seconds=args.inprocess_monitor_process_interval),
        monitor_thread_interval=timedelta(seconds=args.inprocess_monitor_thread_interval),
        last_call_wait=timedelta(seconds=args.inprocess_last_call_wait),
        soft_timeout=timedelta(seconds=args.inprocess_soft_timeout),
        hard_timeout=timedelta(seconds=args.inprocess_hard_timeout),
        termination_grace_time=timedelta(seconds=args.inprocess_termination_grace_time),
        enabled=True,
    )(train)

    return train


def maybe_wrap_for_inprocess_restart(pretrain):

    args = arguments.parse_args(ignore_unknown_args=True)

    if args.inprocess_restart:
        pretrain = inprocess_restart(pretrain, args)

        store = torch.distributed.TCPStore(
            host_name=os.environ['MASTER_ADDR'],
            port=int(os.environ['MASTER_PORT'])+1,
            world_size=int(os.getenv('WORLD_SIZE', '1')),
            is_master=(int(os.getenv('RANK', '0')) == 0),
            timeout=timedelta(seconds=300),
            wait_for_workers=True,
            use_libuv=True,
        )
    else:
        store = None

    return pretrain, store


def maybe_force_nccl_backend_init(device_id):

    args = get_args()

    # Inprocess uses destroy_process_group to terminate NCCL backend, which
    # does not terminate NCCL kernels if NCCL backend wasn't fully initialized
    # before additional distributed subgroups are created. This forces initialization
    # of the NCCL backend.
    if args.inprocess_restart:
        tensor = torch.ones(128, device=device_id)
        torch.distributed.all_reduce(tensor)
        torch.cuda.synchronize()
