                                                      
                                                                 

import gc
import os
import time

from packaging.version import Version

# Recorded before `import torch` below, so elapsed times measured from this
# timestamp include the torch import.
_TRAIN_START_TIME = time.time()

import torch

from megatron.core.rerun_state_machine import (
    get_rerun_state_machine,
)
from megatron.core.pipeline_parallel import get_forward_backward_func
from megatron.core.num_microbatches_calculator import (
    get_num_microbatches,
)
import megatron.core.parallel_state as mpu
from megatron.core import package_info
from megatron.training.utils import (
    logical_and_across_model_parallel_group,
    reduce_max_stat_across_model_parallel_group,
    unwrap_model,
)
from megatron.training.global_vars import (
    get_args,
    get_timers,
)
from megatron.training import print_rank_0
try:
    from megatron.training.training import (
        cuda_graph_set_manual_hooks, 
        cuda_graph_capture,
        has_nvidia_modelopt,
    )
except ImportError:
    cuda_graph_set_manual_hooks = None
    cuda_graph_capture = None
    has_nvidia_modelopt = False

from mpatch.paddingiterator import pad_to_longest


def train_step_lt_0_13(forward_step_func, data_iterator, model, optimizer,
                       opt_param_scheduler, config):
    """Run one training iteration on megatron-core releases older than 0.13.

    Performs forward/backward over all microbatches (optionally through
    ``pad_to_longest`` when ``args.px_inputs_pad_to_longest`` is set), the
    optimizer step, the LR/param-scheduler step and, on the last pipeline
    stage, the averaging of per-microbatch losses.

    Args:
        forward_step_func: Per-microbatch forward function passed to the
            pipeline schedule.
        data_iterator: Iterator over training microbatches; also consulted by
            the rerun state machine.
        model: List of model chunks (iterated, and indexed as ``model[0]``).
        optimizer: Megatron optimizer whose ``step()`` returns
            ``(update_successful, grad_norm, num_zeros_in_grad)``.
        opt_param_scheduler: Scheduler advanced by the number of samples
            consumed in this iteration when the update succeeded.
        config: Model config forwarded to CUDA-graph capture and
            ``pad_to_longest``.

    Returns:
        ``(loss_dict, skipped_iter, should_checkpoint, should_exit, exit_code,
        grad_norm, num_zeros_in_grad)``. ``loss_dict`` is empty except on the
        last pipeline stage. If the rerun state machine requests an exit, the
        tuple is ``({}, True, should_checkpoint, should_exit, exit_code, None,
        None)`` and no optimizer step is taken.
    """
    args = get_args()
    timers = get_timers()

    # External CUDA-graph capture and manual hooks are only used on mcore >= 0.12.1.
    cuda_graph_api_available = Version(package_info.__version__) >= Version("0.12.1")

    def _is_graph_capture_iteration():
        # First iteration of this run with an external CUDA graph requested.
        # The version flag is checked first so older arg namespaces are never
        # queried for the CUDA-graph attributes.
        return (cuda_graph_api_available
                and args.curr_iteration == args.iteration
                and args.external_cuda_graph)

    def _zero_all_grads():
        # Clear every chunk's grad buffer, then the optimizer-side gradients.
        for chunk in model:
            chunk.zero_grad_buffer()
        optimizer.zero_grad()

    if _is_graph_capture_iteration():
        cuda_graph_capture(model, config, args)
        # Reset gradients after capture and release unused cached GPU memory.
        _zero_all_grads()
        gc.collect()
        torch.cuda.empty_cache()

    rerun_state_machine = get_rerun_state_machine()
    while rerun_state_machine.should_run_forward_backward(data_iterator):
        _zero_all_grads()

        if args.px_inputs_pad_to_longest:
            # pad_to_longest supplies a replacement iterator plus the sequence
            # lengths actually used by this global batch.
            batch_iterator, max_len_in_gb, dec_seq_len = pad_to_longest(
                args, model, config, data_iterator)
            print_rank_0(f"Train with longest pad and real seq_len in global_batch is {max_len_in_gb}")
        else:
            batch_iterator = data_iterator
            max_len_in_gb = args.seq_length
            dec_seq_len = args.decoder_seq_length

        schedule = get_forward_backward_func()
        losses_reduced = schedule(
            forward_step_func=forward_step_func,
            data_iterator=batch_iterator,
            model=model,
            num_microbatches=get_num_microbatches(),
            seq_length=max_len_in_gb,
            micro_batch_size=args.micro_batch_size,
            decoder_seq_length=dec_seq_len,
            forward_only=False,
        )

    should_checkpoint, should_exit, exit_code = rerun_state_machine.should_checkpoint_and_exit()
    if should_exit:
        return {}, True, should_checkpoint, should_exit, exit_code, None, None

    if args.empty_unused_memory_level >= 1:
        torch.cuda.empty_cache()

    # DINO vision pretraining hooks run around the optimizer step.
    is_dino = getattr(args, 'vision_pretraining', False) and args.vision_pretraining_type == "dino"
    if is_dino:
        unwrap_model(model[0]).cancel_gradients_last_layer(args.curr_iteration)

    timers('optimizer', log_level=1).start(barrier=args.barrier_with_L1_time)
    update_successful, grad_norm, num_zeros_in_grad = optimizer.step()
    timers('optimizer').stop()

    # Make every model-parallel rank agree on the update outcome and report the
    # same (max-reduced) gradient statistics.
    update_successful = logical_and_across_model_parallel_group(update_successful)
    grad_norm = reduce_max_stat_across_model_parallel_group(grad_norm)
    if args.log_num_zeros_in_grad:
        num_zeros_in_grad = reduce_max_stat_across_model_parallel_group(num_zeros_in_grad)

    if is_dino:
        unwrap_model(model[0]).update_momentum(args.curr_iteration)

    # Only advance the scheduler when the optimizer actually applied an update.
    if update_successful:
        samples_consumed = get_num_microbatches() * args.micro_batch_size * args.data_parallel_size
        opt_param_scheduler.step(increment=samples_consumed)
    skipped_iter = 0 if update_successful else 1

    if args.empty_unused_memory_level >= 2:
        torch.cuda.empty_cache()

    if (_is_graph_capture_iteration()
            and args.use_distributed_optimizer
            and args.overlap_param_gather):
        cuda_graph_set_manual_hooks(model)

    if not mpu.is_pipeline_last_stage(ignore_virtual=True):
        return {}, skipped_iter, should_checkpoint, should_exit, exit_code, grad_norm, num_zeros_in_grad

    def _average_over_microbatches(key):
        # Accumulate with `+=` so tensor accumulation matches the original semantics.
        total = 0
        count = 0
        for microbatch_losses in losses_reduced:
            entry = microbatch_losses[key]
            if isinstance(entry, (tuple, list)):
                # Pair entry: presumably (loss_sum, num_tokens), giving a
                # token-weighted mean over the global batch.
                total += entry[0]
                count += entry[1]
            else:
                # Single value per microbatch: plain mean over microbatches.
                total += entry
                count += 1
        return total / count

    loss_reduced = {key: _average_over_microbatches(key) for key in losses_reduced[0].keys()}
    return loss_reduced, skipped_iter, should_checkpoint, should_exit, exit_code, grad_norm, num_zeros_in_grad


def train_step(forward_step_func, data_iterator, model, optimizer, opt_param_scheduler, config):
    """Run one training iteration (forward/backward, optimizer step, loss reduction).

    Patched variant of Megatron's training step that can replace the data
    iterator and sequence lengths through ``pad_to_longest`` when
    ``args.px_inputs_pad_to_longest`` is set.

    Args:
        forward_step_func: Per-microbatch forward function passed to the
            pipeline schedule.
        data_iterator: Iterator over training microbatches; also consulted by
            the rerun state machine.
        model: List of model chunks (iterated, and indexed as ``model[0]``).
        optimizer: Megatron optimizer whose ``step()`` returns
            ``(update_successful, grad_norm, num_zeros_in_grad)``.
        opt_param_scheduler: Scheduler advanced by the number of samples
            consumed in this iteration when the update succeeded.
        config: Model config forwarded to CUDA-graph capture and
            ``pad_to_longest``.

    Returns:
        ``(loss_dict, skipped_iter, should_checkpoint, should_exit, exit_code,
        grad_norm, num_zeros_in_grad)``. ``loss_dict`` is populated only on the
        last pipeline stage. If the rerun state machine requests an exit, the
        tuple is ``({}, True, should_checkpoint, should_exit, exit_code, None,
        None)`` and no optimizer step is taken.

    Raises:
        ValueError: If a reported loss entry has neither 1 nor 2 elements.
    """
    args = get_args()
    timers = get_timers()

    # Capture the external CUDA graph once, on the first iteration of this run.
    if args.curr_iteration == args.iteration and args.external_cuda_graph:
        cuda_graph_capture(model, config, args)

        # Reset gradients after capture and release unused cached GPU memory.
        for model_chunk in model:
            model_chunk.zero_grad_buffer()
        optimizer.zero_grad()

        gc.collect()
        torch.cuda.empty_cache()

    # The distillation shape-adjust factory was referenced but never imported
    # here, which raised NameError whenever ModelOpt was available. Upstream
    # train_step resolves it from megatron.training.training under the same
    # `has_nvidia_modelopt` guard, so import it from there lazily.
    make_adjust_tensor_shapes_fn = None
    if has_nvidia_modelopt:
        from megatron.training.training import (
            get_tensor_shapes_adjust_fn_for_distillation as make_adjust_tensor_shapes_fn,
        )

    rerun_state_machine = get_rerun_state_machine()
    while rerun_state_machine.should_run_forward_backward(data_iterator):
        # Clear every chunk's grad buffer, then the optimizer-side gradients.
        for model_chunk in model:
            model_chunk.zero_grad_buffer()
        optimizer.zero_grad()

        # Resolve the iterator and sequence lengths for this global batch.
        # pad_to_longest returns a replacement iterator plus the lengths it
        # actually uses; every consumer below must use those, not the maxima.
        sub_iterator = data_iterator
        max_len_in_gb = args.seq_length
        decoder_seq_length = args.decoder_seq_length
        if args.px_inputs_pad_to_longest:
            sub_iterator, max_len_in_gb, decoder_seq_length = pad_to_longest(
                args, model, config, data_iterator)
            print_rank_0(f"Train with longest pad and real seq_len in global_batch is {max_len_in_gb}")

        # Build the distillation adjust fn from the same lengths the schedule
        # uses, so its tensor shapes agree with the (possibly padded) batch.
        if make_adjust_tensor_shapes_fn is not None:
            adjust_tensor_shapes_fn = make_adjust_tensor_shapes_fn(
                model, max_len_in_gb, args.micro_batch_size, decoder_seq_length
            )
        else:
            adjust_tensor_shapes_fn = None

        forward_backward_func = get_forward_backward_func()
        losses_reduced = forward_backward_func(
            forward_step_func=forward_step_func,
            data_iterator=sub_iterator,
            model=model,
            num_microbatches=get_num_microbatches(),
            seq_length=max_len_in_gb,
            micro_batch_size=args.micro_batch_size,
            # Previously args.decoder_seq_length, which ignored pad_to_longest.
            decoder_seq_length=decoder_seq_length,
            forward_only=False,
            adjust_tensor_shapes_fn=adjust_tensor_shapes_fn,
        )

    should_checkpoint, should_exit, exit_code = rerun_state_machine.should_checkpoint_and_exit()
    if should_exit:
        return {}, True, should_checkpoint, should_exit, exit_code, None, None

    if args.empty_unused_memory_level >= 1:
        torch.cuda.empty_cache()

    # DINO vision pretraining hook before the optimizer step.
    if args.vision_pretraining and args.vision_pretraining_type == "dino":
        unwrapped_model = unwrap_model(model[0])
        unwrapped_model.cancel_gradients_last_layer(args.curr_iteration)

    timers('optimizer', log_level=1).start(barrier=args.barrier_with_L1_time)
    update_successful, grad_norm, num_zeros_in_grad = optimizer.step()
    timers('optimizer').stop()

    # Make every model-parallel rank agree on the update outcome and report the
    # same (max-reduced) gradient statistics.
    update_successful = logical_and_across_model_parallel_group(update_successful)
    grad_norm = reduce_max_stat_across_model_parallel_group(grad_norm)
    if args.log_num_zeros_in_grad:
        num_zeros_in_grad = reduce_max_stat_across_model_parallel_group(num_zeros_in_grad)

    # DINO vision pretraining hook after the optimizer step.
    if args.vision_pretraining and args.vision_pretraining_type == "dino":
        unwrapped_model = unwrap_model(model[0])
        unwrapped_model.update_momentum(args.curr_iteration)

    # Only advance the scheduler when the optimizer actually applied an update.
    if update_successful:
        increment = get_num_microbatches() * args.micro_batch_size * args.data_parallel_size
        opt_param_scheduler.step(increment=increment)
        skipped_iter = 0
    else:
        skipped_iter = 1

    if args.empty_unused_memory_level >= 2:
        torch.cuda.empty_cache()

    # On the CUDA-graph capture iteration, install manual hooks when the
    # distributed optimizer overlaps parameter all-gather.
    if args.curr_iteration == args.iteration and args.external_cuda_graph:
        if args.use_distributed_optimizer and args.overlap_param_gather:
            cuda_graph_set_manual_hooks(model)

    if mpu.is_pipeline_last_stage(ignore_virtual=True):
        loss_reduced = {}

        for key in losses_reduced[0].keys():
            val = [x[key].view(-1) for x in losses_reduced]
            if val[0].numel() == 2:
                # Two-element entries: element 0 is divided by element 1 below,
                # presumably (loss_sum, num_tokens) per microbatch.
                dp_group = mpu.get_data_parallel_group(with_context_parallel=True)
                if args.sft:
                    # Normalize each microbatch first, average over microbatches,
                    # then average over data(+context)-parallel ranks.
                    val = torch.vstack(val)
                    val = (val[:, 0] / val[:, 1]).mean()
                    torch.distributed.all_reduce(val, group=dp_group)
                    val /= torch.distributed.get_world_size(group=dp_group)
                    loss_reduced[key] = val
                else:
                    # Sum numerators and denominators over microbatches and
                    # data(+context)-parallel ranks, then divide once.
                    val = torch.vstack(val).sum(dim=0)
                    torch.distributed.all_reduce(val, group=dp_group)
                    loss_reduced[key] = val[0] / val[1]
            elif val[0].numel() == 1:
                # Single value per microbatch: local mean over microbatches.
                val = torch.cat(val).mean()
                loss_reduced[key] = val
            else:
                raise ValueError(f"Invalid value shape: {val[0].shape} for key {key}")
        return (
            loss_reduced,
            skipped_iter,
            should_checkpoint,
            should_exit,
            exit_code,
            grad_norm,
            num_zeros_in_grad,
        )
    return {}, skipped_iter, should_checkpoint, should_exit, exit_code, grad_norm, num_zeros_in_grad
