# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
import time, os

from .global_vars import get_one_logger, get_args

# Version tag of these one-logger utilities; it is reported as the
# 'one_logger_utils_version' metric by on_pretrain_start().
_one_logger_utils_version = "1.2.0-mlm"


def get_timestamp_in_ms():
    """Return the current wall-clock time as integer milliseconds.

    The value is derived from ``time.time()`` (seconds since the epoch),
    scaled to milliseconds and rounded to the nearest integer.

    Returns:
        int: timestamp in ms
    """
    seconds_since_epoch = time.time()
    return round(seconds_since_epoch * 1000.0)


def on_train_start(iteration, consumed_train_samples, train_samples, seq_length,
                   train_iters, save, async_save, log_throughput,
                   num_floating_point_operations_so_far):
    """Record the E2E baseline at the beginning of the train function.

    Seeds the one-logger store with the starting counters and zeroed
    accumulators, then logs the initial set of E2E metrics. Does nothing when
    no one-logger is configured.

    Args:
        iteration (int): current iteration number
        consumed_train_samples (int): consumed sample numbers so far
        train_samples (int): total train sample number
        seq_length (int): sequence length
        train_iters (int): target iteration
        save (str): output directory to save checkpoints to
        async_save (bool): apply async checkpointing save
        log_throughput (bool): log throughput or not
        num_floating_point_operations_so_far (int): flops so far
    """
    args = get_args()
    one_logger = get_one_logger()
    if not one_logger:
        return

    with one_logger.get_context_manager():
        # Timestamp marking the start of the app train loop
        loop_start_ms = get_timestamp_in_ms()

        # Starting values followed by the accumulators that later calls build on
        initial_store_entries = (
            ('app_train_loop_start_time', loop_start_ms),
            ('iteration_start', iteration),
            ('train_samples_start', consumed_train_samples),
            ('train_iterations_time_msecs_total', 0),
            ('tracked_train_iterations', iteration),
            ('validation_iterations_time_msecs_total', 0),
            ('tracked_validation_iterations', 0),
            ('save_checkpoint_count', 0),
            ('save_checkpoint_sync_time_total', 0.0),
        )
        for store_key, store_value in initial_store_entries:
            one_logger.store_set(store_key, store_value)

        # For iteration-based training the sample target is derived from iters
        if train_samples:
            samples_target = train_samples
        else:
            samples_target = train_iters * args.global_batch_size
        tokens_target = seq_length * samples_target

        start_metrics = {
            'train_samples_start': consumed_train_samples,
            'train_iterations_start': iteration,
            'train_samples_target': samples_target,
            'train_iterations_target': train_iters,
            'train_tokens_target': tokens_target,
            'app_train_loop_start_time': loop_start_ms,
            'is_save_checkpoint_enabled': save is not None,
            'save_checkpoint_strategy': 'async' if async_save else 'sync',
        }
        if log_throughput:
            start_metrics['train_tflop_start'] = \
                float(num_floating_point_operations_so_far) / (10**12)
        one_logger.log_metrics(start_metrics)


def _track_min_iteration_time(one_logger, phase, time_msecs_total, iterations):
    """Fold newly run iterations into the cached minimal per-iteration duration.

    Uses these one-logger store keys (``phase`` is 'train' or 'validation'):
    ``{phase}_iterations_time_msecs_total``, ``tracked_{phase}_iterations`` and
    ``{phase}_iterations_time_msecs_min``.

    Args:
        one_logger: one-logger instance exposing store_get/store_set/store_has_key
        phase (str): key prefix selecting which counters to update
        time_msecs_total (float): cumulative duration of the phase so far, in ms
        iterations (int): cumulative number of iterations of the phase so far

    Returns:
        float or None: the updated minimal per-iteration duration in ms, or None
            when no new iteration was run since the previously tracked one.
    """
    total_key = f'{phase}_iterations_time_msecs_total'
    tracked_key = f'tracked_{phase}_iterations'
    min_key = f'{phase}_iterations_time_msecs_min'

    prev_time_msecs_total = one_logger.store_get(total_key)
    tracked_iterations = one_logger.store_get(tracked_key)
    if not iterations > tracked_iterations:
        return None

    # Average duration of only the iterations run since the last update
    time_msecs = (
        (time_msecs_total - prev_time_msecs_total) / (iterations - tracked_iterations)
    )
    if one_logger.store_has_key(min_key):
        time_msecs_min = min(one_logger.store_get(min_key), time_msecs)
    else:
        time_msecs_min = time_msecs

    one_logger.store_set(min_key, time_msecs_min)
    one_logger.store_set(total_key, time_msecs_total)
    one_logger.store_set(tracked_key, iterations)
    return time_msecs_min


def _produce_e2e_metrics(log_throughput=False, throughput=None):
    """ Generate APP metrics for E2E tracking
    NOTE: always call this function after barrier call

    Args:
        log_throughput (bool, optional): if log throughput or not. Defaults to False.
        throughput (int, optional): throughput value to log. Defaults to None.

    Returns:
        dict: all E2E metrics, or None when no one-logger is configured
    """
    one_logger = get_one_logger()

    if one_logger:
        with one_logger.get_context_manager():
            # Unpack and assign local vars.
            # NOTE(review): this relies on the value order of the mapping returned
            # by the stored 'get_e2e_base_metrics' callable - confirm at its definition.
            base_metrics = one_logger.store_get('get_e2e_base_metrics')()
            (iteration, train_duration, eval_duration, eval_iterations,
             total_flops_since_current_train_start, num_floating_point_operations_so_far,
             consumed_train_samples, world_size, seq_length) = base_metrics.values()

            iteration_start = one_logger.store_get('iteration_start')
            train_samples_start = one_logger.store_get('train_samples_start')

            train_samples = consumed_train_samples - train_samples_start
            train_iterations = iteration - iteration_start
            # No iteration may have run yet since train start; report None rather
            # than dividing by zero (same convention as the validation average).
            if train_iterations:
                train_iterations_time_msecs_avg = (train_duration * 1000.0) / train_iterations
            else:
                train_iterations_time_msecs_avg = None
            if eval_iterations:
                validation_iterations_time_msecs_avg = (eval_duration * 1000.0) / eval_iterations
            else:
                validation_iterations_time_msecs_avg = None

            # The first time metrics are produced fixes the "first logged" timestamp
            if not one_logger.store_has_key('first_logged_train_iterations_finish_time'):
                one_logger.store_set(
                    'first_logged_train_iterations_finish_time',
                    get_timestamp_in_ms()
                )

            train_tokens = train_samples * seq_length

            e2e_metrics = {
                'first_logged_train_iterations_finish_time': \
                    one_logger.store_get('first_logged_train_iterations_finish_time'),
                'train_iterations_end': iteration,
                'train_samples_end': consumed_train_samples,
                'train_iterations': train_iterations,
                'train_samples': train_samples,
                'train_iterations_time_msecs_avg': train_iterations_time_msecs_avg,
                'validation_iterations_time_total': eval_duration,
                'validation_iterations_time_msecs_avg': validation_iterations_time_msecs_avg,
                'train_tokens': train_tokens,
                'train_iterations_time_total': train_duration,
                'last_logged_train_iterations_finish_time': get_timestamp_in_ms(),
            }

            if log_throughput:
                if train_duration:
                    train_throughput_per_gpu = total_flops_since_current_train_start / (
                        train_duration * 10**12 * world_size)
                else:
                    train_throughput_per_gpu = 0.0

                # Running maximum only moves when a throughput sample is supplied
                train_throughput_per_gpu_max = one_logger.store_get('train_throughput_per_gpu_max')
                if throughput:
                    train_throughput_per_gpu_max = max(throughput, train_throughput_per_gpu_max)
                    one_logger.store_set('train_throughput_per_gpu_max', train_throughput_per_gpu_max)

                e2e_metrics.update({
                    'train_tflop_end': float(num_floating_point_operations_so_far) / (10**12),
                    'train_tflop': float(total_flops_since_current_train_start) / (10**12),
                    'train_throughput_per_gpu': train_throughput_per_gpu,
                    'train_throughput_per_gpu_max': train_throughput_per_gpu_max,
                })

            # Tracking minimal train/validation iteration duration metrics.
            # The *_min keys are only emitted when new iterations were observed.
            train_iterations_time_msecs_min = _track_min_iteration_time(
                one_logger, 'train', train_duration * 1000.0, iteration)
            if train_iterations_time_msecs_min is not None:
                e2e_metrics['train_iterations_time_msecs_min'] = train_iterations_time_msecs_min

            validation_iterations_time_msecs_min = _track_min_iteration_time(
                one_logger, 'validation', eval_duration * 1000.0, eval_iterations)
            if validation_iterations_time_msecs_min is not None:
                e2e_metrics['validation_iterations_time_msecs_min'] = \
                    validation_iterations_time_msecs_min
            return e2e_metrics


def track_e2e_metrics(log_throughput=False, throughput=None):
    """Produce the current E2E application metrics and log them via one-logger.

    NOTE: the function should be called after barrier call.

    Args:
        log_throughput (bool, optional): if log throughput or not. Defaults to False.
        throughput (int, optional): throughput value to log. Defaults to None.
    """
    one_logger = get_one_logger()
    if not one_logger:
        # Nothing to track without a configured one-logger
        return

    with one_logger.get_context_manager():
        metrics = _produce_e2e_metrics(log_throughput, throughput)
        one_logger.log_metrics(metrics)


def on_save_checkpoint_start(async_save):
    """Snapshot productive metrics right before a checkpoint save begins.

    Bumps and logs the save-checkpoint counter, then builds the metrics that
    are meant to be logged once the checkpoint has succeeded.

    Args:
        async_save (bool): apply async checkpointing save

    Returns:
        dict: productive metrics to be stored to DB after ckpt succeeds
            (None when no one-logger is configured)
    """
    one_logger = get_one_logger()
    if not one_logger:
        return None

    with one_logger.get_context_manager():
        # Only part of the base metrics is needed here; the tuple must still be
        # unpacked in full, unused fields carry a leading underscore.
        fetch_base_metrics = one_logger.store_get('get_e2e_base_metrics')
        (iteration, train_duration, eval_duration, _eval_iterations,
         _flops_since_train_start, flops_so_far,
         consumed_train_samples, _world_size, _seq_length) = fetch_base_metrics().values()

        ckpt_count = one_logger.store_get('save_checkpoint_count') + 1
        one_logger.store_set('save_checkpoint_count', ckpt_count)
        one_logger.log_metrics({
            'train_iterations_save_checkpoint_end': iteration,
            'save_checkpoint_count': ckpt_count,
        })

        productive_metrics = {
            'train_tflop_productive_end': float(flops_so_far) / (10**12),
            'train_iterations_productive_end': iteration,
            'train_samples_productive_end': consumed_train_samples,
            'train_iterations_time_total_productive': train_duration,
            'validation_iterations_time_total_productive': eval_duration,
        }
        if async_save:
            productive_metrics['save_checkpoint_async_count'] = ckpt_count
        return productive_metrics

            
def on_pretrain_start():
    """Track E2E meta data at the start of the pretrain function.

    Caches the app tag run name/version and a zeroed max throughput in the
    one-logger store, then logs the static configuration metrics of the run.
    """
    args = get_args()
    one_logger = get_one_logger()
    if not one_logger:
        return

    with one_logger.get_context_manager():
        # An explicitly configured run name wins; otherwise use the SLURM job name
        run_name = args.app_tag_run_name
        if not run_name:
            run_name = os.environ.get('SLURM_JOB_NAME', None)
        run_version = args.app_tag_run_version

        one_logger.store_set('app_tag_run_name', run_name)
        one_logger.store_set('app_tag_run_version', run_version)
        one_logger.store_set('train_throughput_per_gpu_max', 0.0)

        meta_metrics = {
            'train_iterations_warmup': 5,
            'data_parallel_size': args.data_parallel_size,
            'context_parallel_size': args.context_parallel_size,
            'global_batch_size': args.global_batch_size,
            'micro_batch_size': args.micro_batch_size,
            'pipeline_model_parallel_size': args.pipeline_model_parallel_size,
            'tensor_model_parallel_size': args.tensor_model_parallel_size,
            'expert_model_parallel_size': args.expert_model_parallel_size,
            'world_size': args.world_size,
            'model_seq_length': args.seq_length,
            'app_tag_run_name': run_name,
            'app_tag_run_version': run_version,
            'is_log_throughput_enabled': args.log_throughput,
            'app_run_type': 'training',
            'summary_data_schema_version': '1.0.0',
            'app_metrics_feature_tags': 'full',
            'one_logger_utils_version': _one_logger_utils_version,
        }
        one_logger.log_metrics(meta_metrics)

def track_config_flags(train_iters, skip_train, do_train, do_valid, do_test,
                           dataloader_type, retro_project_dir, retro_cyclic_train_iters):
    """Track flags about train/validation/test enablement

    All three flags are logged as real booleans.

    Args:
        train_iters (int): target train iteration number
        skip_train (bool): flag to skip train iterations
        do_train (bool): flags to do train
        do_valid (bool): flags to do validation
        do_test (bool): flags to do test
        dataloader_type (str): dataloader type
        retro_project_dir (str): Retro project directory
        retro_cyclic_train_iters (int): iteration number for cyclic retro training
    """
    one_logger = get_one_logger()
    if one_logger:
        with one_logger.get_context_manager():
            # Update train_iters for cyclic loader
            if dataloader_type == 'cyclic' and retro_project_dir:
                assert retro_cyclic_train_iters is not None
                train_iters = retro_cyclic_train_iters
            # Track if training is enabled. Can only be done once args.do_train is assigned after dataloader is built.
            # The `and` chain yields its last evaluated operand (e.g. None or 0),
            # so normalise to bool like the validation/test flags below.
            train_enabled = bool(
                train_iters and (not skip_train) and do_train and train_iters > 0
            )
            one_logger.log_metrics({
                'is_train_iterations_enabled': train_enabled,
                'is_validation_iterations_enabled': bool(do_valid),
                'is_test_iterations_enabled': bool(do_test),
            })

def on_save_checkpoint_success(productive_metrics, async_save):
    """Log productive metrics once a checkpoint has succeeded and is persisted.

    Args:
        productive_metrics (dict): productive related E2E metrics generated at the start of
            save checkpoint; updated in place with the cached sync timings when async_save is set
        async_save (bool): apply async checkpointing save
    """
    one_logger = get_one_logger()
    if not one_logger:
        return

    with one_logger.get_context_manager():
        # Iteration this checkpoint belongs to
        prod_iteration = productive_metrics['train_iterations_productive_end']

        # On the very first successful checkpoint, log the train loop start timestamp
        if not one_logger.store_has_key('first_checkpoint_success'):
            loop_start_time = one_logger.store_get('app_train_loop_start_time')
            one_logger.store_set('first_checkpoint_success', True)
            one_logger.log_metrics({
                'first_saved_train_iterations_start_time': loop_start_time
            })

        # Async checkpoint callbacks may arrive out of order: skip any callback whose
        # iteration is not newer than the newest one already recorded.
        is_stale = (
            one_logger.store_has_key('iters_prod_max')
            and not prod_iteration > one_logger.store_get('iters_prod_max')
        )
        if is_stale:
            return

        # Update cache
        one_logger.store_set('iters_prod_max', prod_iteration)

        if async_save:
            # Retrieve (and drop) the values cached for this iteration
            sync_time_total = one_logger.store_pop(
                f'save_checkpoint_sync_time_total_productive:{prod_iteration}')
            sync_finish_time = one_logger.store_pop(
                f'save_checkpoint_sync_finish_time:{prod_iteration}')
            productive_metrics['save_checkpoint_sync_time_total_productive'] = sync_time_total
            productive_metrics['last_successful_save_checkpoint_sync_finish_time'] = \
                sync_finish_time

        one_logger.log_metrics(productive_metrics)


def on_save_checkpoint_end(save_checkpoint_duration, current_iteration, async_save):
    """Update and log sync-checkpointing metrics after a checkpoint save ends.

    Args:
        save_checkpoint_duration (float): duration of current save checkpoint process
        current_iteration (int): current train iteration step number
        async_save (bool): apply async checkpointing save
    """
    one_logger = get_one_logger()
    if not one_logger:
        return

    with one_logger.get_context_manager():
        sync_finish_time = get_timestamp_in_ms()

        # Log the finish timestamp once: on the first call made after a
        # checkpoint success has been recorded.
        first_end_pending = (
            one_logger.store_has_key('first_checkpoint_success')
            and not one_logger.store_has_key('first_successful_checkpoint_end')
        )
        if first_end_pending:
            one_logger.store_set('first_successful_checkpoint_end', True)
            one_logger.log_metrics({
                'first_successful_save_checkpoint_sync_finish_time': sync_finish_time
            })

        sync_count = one_logger.store_get('save_checkpoint_count')

        # Accumulate total sync checkpointing duration
        sync_time_total = (
            one_logger.store_get('save_checkpoint_sync_time_total') + save_checkpoint_duration
        )
        one_logger.store_set('save_checkpoint_sync_time_total', sync_time_total)

        if async_save:
            # Cache per-iteration values so they can be picked up when the
            # checkpoint for this iteration is reported as successful
            one_logger.store_set(
                f'save_checkpoint_sync_time_total_productive:{current_iteration}',
                sync_time_total
            )
            one_logger.store_set(
                f'save_checkpoint_sync_finish_time:{current_iteration}',
                sync_finish_time
            )
            metrics = {}
        else:
            # Sync checkpointing: the productive totals can be tracked directly
            metrics = {
                'save_checkpoint_sync_time_total_productive': sync_time_total,
                'last_successful_save_checkpoint_sync_finish_time': sync_finish_time,
            }

        # Seed the max/min trackers on the first comparison
        for bound_key in ('save_checkpoint_sync_time_max', 'save_checkpoint_sync_time_min'):
            if not one_logger.store_has_key(bound_key):
                one_logger.store_set(bound_key, save_checkpoint_duration)

        sync_time_max = max(
            one_logger.store_get('save_checkpoint_sync_time_max'), save_checkpoint_duration)
        sync_time_min = min(
            one_logger.store_get('save_checkpoint_sync_time_min'), save_checkpoint_duration)
        one_logger.store_set('save_checkpoint_sync_time_max', sync_time_max)
        one_logger.store_set('save_checkpoint_sync_time_min', sync_time_min)

        metrics['save_checkpoint_sync_count'] = sync_count
        metrics['save_checkpoint_sync_time_max'] = sync_time_max
        metrics['save_checkpoint_sync_time_min'] = sync_time_min
        metrics['save_checkpoint_sync_time_total'] = sync_time_total
        one_logger.log_metrics(metrics)


def track_app_tag(batch_size, world_size, seq_length):
    """Compose the app tag for the current configuration and log it.

    The tag joins the cached run name and run version with the batch size,
    world size and sequence length, separated by underscores.

    Args:
        batch_size (int): current batch size
        world_size (int): the number of processes of current job
        seq_length (int): current sequence length
    """
    one_logger = get_one_logger()
    if not one_logger:
        return

    with one_logger.get_context_manager():
        # Run name/version are read back from the one-logger store
        app_tag_run_name = one_logger.store_get('app_tag_run_name')
        app_tag_run_version = one_logger.store_get('app_tag_run_version')
        one_logger.log_app_tag(
            f'{app_tag_run_name}_{app_tag_run_version}_{batch_size}'
            f'_{world_size}_{seq_length}'
        )


def finish():
    """Finish the one-logger session, flushing E2E metrics to the remote server.

    Does nothing when no one-logger is configured.
    """
    one_logger = get_one_logger()
    if not one_logger:
        return

    with one_logger.get_context_manager():
        one_logger.finish()
