#!/usr/bin/env python3 -u
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
Train a new model on one or across multiple GPUs.
"""

import argparse
import functools
import logging
import logging.config
import math
import os
import sys
from typing import Dict, Optional, Any, List, Tuple, Callable

import numpy as np
import torch
import torch.distributed as dist
from omegaconf import DictConfig, OmegaConf

from fairseq import (
    checkpoint_utils,
    options,
    quantization_utils,
    tasks,
    utils,
)
from fairseq.data import iterators, data_utils
from fairseq.data.plasma_utils import PlasmaStore
from fairseq.dataclass.configs import FairseqConfig
from fairseq.dataclass.utils import convert_namespace_to_omegaconf
from fairseq.distributed import fsdp_enable_wrap, fsdp_wrap, utils as distributed_utils
from fairseq.file_io import PathManager
from fairseq.logging import meters, metrics, progress_bar
from fairseq.model_parallel.megatron_trainer import MegatronTrainer
from fairseq.trainer import Trainer
from fairseq.utils import print_r0

# Root logging for this entry point: timestamped, pipe-separated records on
# stdout; the threshold comes from the LOGLEVEL environment variable and
# defaults to WARN.
logging.basicConfig(
    stream=sys.stdout,
    level=os.environ.get("LOGLEVEL", "WARN").upper(),
    datefmt="%Y-%m-%d %H:%M:%S",
    format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
)
logger = logging.getLogger("fairseq_cli.train")


def _build_train_progress(
    cfg: FairseqConfig,
    epoch_itr,
    update_freq: Optional[int] = None,
):
    """Start the next pass over ``epoch_itr`` and wrap it in a progress bar.

    Used both at the start of an epoch and after a DEA revert in :func:`main`,
    so every pass over the training data is grouped, (optionally) TPU-wrapped
    and logged to the same sinks.

    Args:
        cfg: full training config.
        epoch_itr: epoch batch iterator; its ``next_epoch_itr`` is called here.
        update_freq: number of batches grouped into one optimizer update.
            When ``None`` it is looked up in ``cfg.optimization.update_freq``
            for the epoch ``epoch_itr`` is on after ``next_epoch_itr``
            (falling back to the last entry).

    Returns:
        ``(itr, progress, update_freq)``: the grouped iterator whose
        ``has_next()`` the training loop checks, the progress bar wrapping it,
        and the update frequency that was used.
    """
    itr = epoch_itr.next_epoch_itr(
        fix_batches_to_gpus=cfg.distributed_training.fix_batches_to_gpus,
        shuffle=(epoch_itr.next_epoch_idx > cfg.dataset.curriculum),
    )
    if update_freq is None:
        # Looked up after next_epoch_itr so epoch_itr.epoch is the epoch
        # about to be trained.
        update_freq = (
            cfg.optimization.update_freq[epoch_itr.epoch - 1]
            if epoch_itr.epoch <= len(cfg.optimization.update_freq)
            else cfg.optimization.update_freq[-1]
        )

    itr = iterators.GroupedIterator(itr, update_freq)
    if cfg.common.tpu:
        itr = utils.tpu_data_loader(itr)

    is_master = distributed_utils.is_master(cfg.distributed_training)
    progress = progress_bar.progress_bar(
        itr,
        log_format=cfg.common.log_format,
        log_file=cfg.common.log_file,
        log_interval=cfg.common.log_interval,
        epoch=epoch_itr.epoch,
        tensorboard_logdir=(cfg.common.tensorboard_logdir if is_master else None),
        default_log_format=("tqdm" if not cfg.common.no_progress_bar else "simple"),
        wandb_project=(cfg.common.wandb_project if is_master else None),
        wandb_run_name=os.environ.get(
            "WANDB_NAME", os.path.basename(cfg.checkpoint.save_dir)
        ),
        azureml_logging=(cfg.common.azureml_logging if is_master else False),
    )
    progress.update_config(_flatten_config(cfg))
    return itr, progress, update_freq


def _validate_and_adapt_topk(
    cfg: FairseqConfig,
    trainer: Trainer,
    task,
    epoch_itr,
    valid_subsets: List[str],
    end_of_epoch: bool,
) -> Tuple[List[float], bool]:
    """Validate on every subset and apply the DEA per-subset top-k schedule.

    For each subset, log the validation stats and collect
    ``stats[cfg.checkpoint.best_checkpoint_metric]``. The logged stats include
    ``best_<metric>``, ``best_num_updates`` and the subset's ``top_k``.

    Unless this is the end of an epoch, then update the DEA state on ``task``:

    * If the metric is ``<=`` the subset's previous best, record it as the new
      best and clear ``not_improving_flag``.
    * Otherwise, ``current_k`` may change. This requires at least two
      validation intervals since the best step, the subset not frozen
      (``stop_changing``), and ``current_k < cfg.model.moe_expert_count``.
      The first time, ``current_k`` is raised by one. The second time it is
      lowered back and the subset is frozen. Either way a rerun is requested.

    The ``<=`` comparison treats the metric as lower-is-better.

    Returns:
        ``(valid_losses, recompile)``: one metric value per subset, in
        ``valid_subsets`` order, and whether the caller should roll back and
        retrain the last interval.
    """
    if cfg.dataset.fixed_validation_seed is not None:
        # set fixed seed for every validation
        utils.set_torch_seed(cfg.dataset.fixed_validation_seed)

    trainer.begin_valid_epoch(epoch_itr.epoch)
    valid_losses: List[float] = []
    recompile = False  # [dea]: topk+1 not improving, rerun this interval with ori topk

    if not end_of_epoch:
        logger.info(f'[dea] before validation topk {task.current_k}')

    is_master = distributed_utils.is_master(cfg.distributed_training)
    metric = cfg.checkpoint.best_checkpoint_metric
    maximize = cfg.checkpoint.maximize_best_checkpoint_metric

    for subset in valid_subsets:
        logger.debug('begin validation on "{}" subset on rank {}'.format(
            subset, distributed_utils.get_global_rank()))

        # Initialize data iterator
        valid_itr = trainer.get_valid_iterator(subset).next_epoch_itr(
            shuffle=False, set_dataset_epoch=False  # use a fixed valid set
        )
        if cfg.common.tpu:
            valid_itr = utils.tpu_data_loader(valid_itr)

        logger.debug('got valid iterator on "{}" subset on rank {}'.format(
            subset, distributed_utils.get_global_rank()))

        valid_progress = progress_bar.progress_bar(
            valid_itr,
            log_format=cfg.common.log_format,
            log_interval=cfg.common.log_interval,
            epoch=epoch_itr.epoch,
            prefix=f"valid on '{subset}' subset",
            tensorboard_logdir=(cfg.common.tensorboard_logdir if is_master else None),
            default_log_format=("tqdm" if not cfg.common.no_progress_bar else "simple"),
            wandb_project=(cfg.common.wandb_project if is_master else None),
            wandb_run_name=os.environ.get(
                "WANDB_NAME", os.path.basename(cfg.checkpoint.save_dir)
            ),
        )

        logger.info('Begin looping over validation "{}" subset with length "{}"'.format(subset, len(valid_progress)))

        # create a new root metrics aggregator so validation metrics
        # don't pollute other aggregators (e.g., train meters)
        with metrics.aggregate(new_root=True) as agg:
            trainer.unwrap_model()
            for step, sample in enumerate(valid_progress):
                if cfg.dataset.max_valid_steps is not None and step > cfg.dataset.max_valid_steps:
                    break
                trainer.valid_step(sample)

        # [dea] track the best metric per subset (inlined get_valid_stats)
        stats = agg.get_smoothed_values()
        stats["num_updates"] = trainer.get_num_updates()
        if hasattr(checkpoint_utils.save_checkpoint, "best"):  # True after second validation
            prev_best = checkpoint_utils.save_checkpoint.best
            current = stats[metric]
            best_function = max if maximize else min
            stats["best_{0}".format(metric)] = best_function(prev_best, current)
            # Honor the metric direction when deciding whether this step ties/beats the best.
            if (current >= prev_best) if maximize else (current <= prev_best):
                stats['best_num_updates'] = trainer.get_num_updates()
        stats['top_k'] = task.current_k[subset]
        valid_progress.print(stats, tag=subset, step=trainer.get_num_updates())

        valid_loss = stats[metric]
        valid_losses.append(valid_loss)

        if hasattr(task, "post_validate"):
            task.post_validate(trainer.get_model(), stats, agg)

        if end_of_epoch:  # [dea] stop dea checking at end of epoch, avoid insufficient rerun
            logger.info('[dea] end of epoch, validate without dea checking')
            continue

        logger.info(f'[dea] check {subset} valid loss {valid_loss}, prev best {task.prev_best_val_loss[subset]}, decreased? {valid_loss <= task.prev_best_val_loss[subset]}')
        logger.info(f'[dea] {subset} stop_changing: {task.stop_changing[subset]}, not_improving_flag: {task.not_improving_flag[subset]}')
        if valid_loss <= task.prev_best_val_loss[subset]:
            # current topk improves (or ties) the best valid loss
            task.prev_best_val_loss[subset] = valid_loss
            task.best_step[subset] = trainer.get_num_updates()
            task.not_improving_flag[subset] = False
        elif (
            trainer.get_num_updates() - task.best_step[subset]
            >= 2 * cfg.dataset.validate_interval_updates
            and not task.stop_changing[subset]
            and task.current_k[subset] < cfg.model.moe_expert_count
        ):
            # Change topk and rerun this validate_interval_updates.
            if task.not_improving_flag[subset]:
                # (2nd time) topk+1 still doesn't improve: revert to topk and freeze
                task.current_k[subset] -= 1
                task.stop_changing[subset] = True
            else:
                # (1st time) topk doesn't improve: try topk+1
                task.current_k[subset] += 1
            task.not_improving_flag[subset] = True
            recompile = True

    if not end_of_epoch:
        logger.info(f'[dea] after validation topk {task.current_k}')

    return valid_losses, recompile


def main(cfg: FairseqConfig) -> None:
    """Train a model according to ``cfg``, with DEA per-subset top-k adaptation.

    Setup and restore:

    * Set up the task, model, criterion and trainer.
    * Restore the latest checkpoint. If the checkpoint's extra state contains
      ``current_k``, also restore ``current_k``, ``prev_best_val_loss`` and
      ``best_step`` onto the task.

    Each training epoch:

    * At every validation point, :func:`_validate_and_adapt_topk` may change
      ``task.current_k``.
    * When it requests a rerun, training state is reloaded with
      ``checkpoint_utils.revert_checkpoint`` (meant to reload the last saved
      iterator/progress). The training iterator is then rebuilt,
      ``max_update`` is extended by one validation interval, and training
      resumes.

    Args:
        cfg: training config. An ``argparse.Namespace`` is converted to an
            OmegaConf config first.
    """
    if isinstance(cfg, argparse.Namespace):
        cfg = convert_namespace_to_omegaconf(cfg)

    utils.import_user_module(cfg.common)

    if distributed_utils.is_master(cfg.distributed_training) and "job_logging_cfg" in cfg:
        # make hydra logging work with ddp (see https://github.com/facebookresearch/hydra/issues/1126)
        logging.config.dictConfig(OmegaConf.to_container(cfg.job_logging_cfg))

    assert (
        cfg.dataset.max_tokens is not None or cfg.dataset.batch_size is not None
    ), "Must specify batch size either with --max-tokens or --batch-size"
    metrics.reset()

    if cfg.common.log_file is not None:
        handler = logging.FileHandler(filename=cfg.common.log_file)
        logger.addHandler(handler)

    np.random.seed(cfg.common.seed)
    utils.set_torch_seed(cfg.common.seed)

    checkpoint_utils.verify_checkpoint_directory(cfg.checkpoint.save_dir)

    # Print nvidia smi stats
    logger.info(metrics.get_nvidia_smi_gpu_memory_stats_str())

    # Print args
    logger.info(cfg)

    if cfg.checkpoint.write_checkpoints_asynchronously:
        try:
            import iopath  # noqa: F401
        except ImportError:
            logger.exception(
                "Asynchronous checkpoint writing is specified but iopath is "
                "not installed: `pip install iopath`"
            )
            return

    # Setup task, e.g., translation, language modeling, etc.
    task = tasks.setup_task(cfg.task)

    assert cfg.criterion, "Please specify criterion to train a model"
    if utils.is_moe(cfg.model) and getattr(cfg.model, "moe_expert_count", 0) < distributed_utils.get_global_world_size():
        assert cfg.distributed_training.ddp_backend == 'fully_sharded', 'num_experts < num_gpus only supported by FSDP'

    # Build model and criterion
    if cfg.distributed_training.ddp_backend == "fully_sharded":
        extra = {
            "is_moe": utils.is_moe(cfg.model),
            "use_sharded_state": cfg.distributed_training.use_sharded_state,
        }

        with fsdp_enable_wrap(cfg.distributed_training, **extra):
            model = fsdp_wrap(task.build_model(cfg.model))
    else:
        model = task.build_model(cfg.model)
    criterion = task.build_criterion(cfg.criterion)

    def is_expert_param(p):
        # Parameters tagged ``expert`` or ``base_expert`` are reported separately.
        return getattr(p, "expert", False) or getattr(p, "base_expert", False)

    def count_params(expert: bool, trainable_only: bool) -> int:
        # ``_orig_size`` is preferred when present (presumably the unsharded
        # size under FSDP — TODO confirm).
        return sum(
            getattr(p, "_orig_size", p).numel()
            for p in model.parameters()
            if bool(is_expert_param(p)) == expert and (not trainable_only or p.requires_grad)
        )

    logger.debug(model)
    logger.debug("task: {}".format(task.__class__.__name__))
    logger.debug("model: {}".format(model.__class__.__name__))
    logger.debug("criterion: {}".format(criterion.__class__.__name__))
    logger.debug(
        "num. non-expert model params: {:,} (num. trained: {:,})".format(
            count_params(expert=False, trainable_only=False),
            count_params(expert=False, trainable_only=True),
        )
    )
    logger.debug(
        "num. expert model params: {:,} (num. trained: {:,})".format(
            count_params(expert=True, trainable_only=False),
            count_params(expert=True, trainable_only=True),
        )
    )
    logger.debug(metrics.get_nvidia_smi_gpu_memory_stats_str())

    # Load valid dataset (we load training data below, based on the latest checkpoint)
    # We load the valid dataset AFTER building the model
    data_utils.raise_if_valid_subsets_unintentionally_ignored(cfg)
    if cfg.dataset.combine_valid_subsets:
        task.load_dataset("valid", combine=True, epoch=1)
    else:
        for valid_sub_split in cfg.dataset.valid_subset.split(","):
            task.load_dataset(valid_sub_split, combine=False, epoch=1)

    if cfg.task._name in ["multilingual_language_modeling", "translation_multi_simple_epoch"]:
        valid_subsets = task.args.valid_subset.split(",")
    else:
        valid_subsets = cfg.dataset.valid_subset.split(",")

    # (optionally) Configure quantization
    if cfg.common.quantization_config_path is not None:
        quantizer = quantization_utils.Quantizer(
            config_path=cfg.common.quantization_config_path,
            max_epoch=cfg.optimization.max_epoch,
            max_update=cfg.optimization.max_update,
        )
    else:
        quantizer = None

    # Build trainer
    if cfg.common.model_parallel_size == 1:
        trainer = Trainer(cfg, task, model, criterion, quantizer)
    else:
        trainer = MegatronTrainer(cfg, task, model, criterion)
    logger.debug(
        "training on {} devices (GPUs/TPUs)".format(
            cfg.distributed_training.distributed_world_size
        )
    )
    logger.debug(
        "max tokens per GPU = {} and batch size per GPU = {}".format(
            cfg.dataset.max_tokens,
            cfg.dataset.batch_size,
        )
    )
    logger.debug(metrics.get_nvidia_smi_gpu_memory_stats_str())

    # Load the latest checkpoint if one is available and restore the
    # corresponding train iterator
    extra_state, epoch_itr = checkpoint_utils.load_checkpoint(
        cfg.checkpoint,
        trainer,
        # don't cache epoch iterators for sharded datasets
        disable_iterator_cache=task.has_sharded_data("train"),
    )
    if extra_state and 'current_k' in extra_state:
        # Restore DEA state (presumably written via ``dynamic_dicts_list`` in
        # save_checkpoint). NOTE(review): stop_changing / not_improving_flag
        # are not restored here — confirm that is intended.
        task.current_k = extra_state['current_k']
        task.prev_best_val_loss = extra_state['prev_best_val_loss']
        task.best_step = extra_state['best_step']
        logger.info(f'[dea] load topk from ckpt: {task.current_k}')

    max_epoch = cfg.optimization.max_epoch or math.inf
    max_update = cfg.optimization.max_update or math.inf
    lr = trainer.get_lr()
    train_meter = meters.StopwatchMeter()
    train_meter.start()

    # Bind loop state up front so the code after the loops never reads an
    # unbound name (an epoch without validation, or an empty iterator).
    valid_losses: List[float] = []
    should_stop = False
    num_updates = trainer.get_num_updates()

    while epoch_itr.next_epoch_idx <= max_epoch:
        if lr <= cfg.optimization.stop_min_lr:
            logger.info(
                f"stopping training because current learning rate ({lr}) is smaller "
                "than or equal to minimum learning rate "
                f"(--stop-min-lr={cfg.optimization.stop_min_lr})"
            )
            break

        # Train the model for one epoch (including DEA reruns).
        with metrics.aggregate("train"):
            itr, progress, update_freq = _build_train_progress(cfg, epoch_itr)

            trainer.begin_epoch(epoch_itr.epoch)

            logger.info("Start iterating over samples")

            # Each pass trains until the iterator is exhausted, training must
            # stop, or validation requests a DEA rerun (which reverts training
            # and starts another pass).
            while True:
                # Reset per pass: a True left over from the previous pass would
                # otherwise trigger another revert even if no validation ran.
                recompile = False
                for i, samples in enumerate(progress):
                    with metrics.aggregate("train_inner"), torch.autograd.profiler.record_function(
                        "train_step-%d" % i
                    ):
                        log_output = trainer.train_step(samples)

                    if log_output is not None:  # not OOM, overflow, ...
                        # log mid-epoch stats
                        num_updates = trainer.get_num_updates()
                        if num_updates % cfg.common.log_interval == 0:
                            stats = get_training_stats(metrics.get_smoothed_values("train_inner"))
                            progress.log(stats, tag="train_inner", step=num_updates)
                            # reset mid-epoch stats after each log interval
                            # the end-of-epoch stats will still be preserved
                            metrics.reset_meters("train_inner")

                    end_of_epoch = not itr.has_next()

                    num_updates = trainer.get_num_updates()
                    should_stop = False

                    if num_updates >= max_update:
                        should_stop = True
                        logger.info(
                            f"Stopping training due to "
                            f"num_updates: {num_updates} >= max_update: {max_update}"
                        )

                    training_time_hours = trainer.cumulative_training_time() / (60 * 60)
                    if (
                        cfg.optimization.stop_time_hours > 0
                        and training_time_hours > cfg.optimization.stop_time_hours
                    ):
                        should_stop = True
                        logger.info(
                            f"Stopping training due to "
                            f"cumulative_training_time: {training_time_hours} > "
                            f"stop_time_hours: {cfg.optimization.stop_time_hours} hour(s)"
                        )

                    # [dea]: do_save = do_validate.
                    # NOTE(review): should_stop only ends the pass at a validation
                    # step (below) or at the end of the epoch, so training can run
                    # past max_update / stop_time_hours until then.
                    do_validate = (
                        (end_of_epoch and epoch_itr.epoch % cfg.dataset.validate_interval == 0)
                        or (
                            cfg.dataset.validate_interval_updates > 0
                            and num_updates > 0
                            and num_updates % cfg.dataset.validate_interval_updates == 0
                        )
                    ) and not cfg.dataset.disable_validation

                    if not do_validate:
                        continue

                    valid_losses, recompile = _validate_and_adapt_topk(
                        cfg, trainer, task, epoch_itr, valid_subsets, end_of_epoch
                    )

                    if recompile:
                        logger.info('[dea] recompile is True, break training progress')
                        break

                    # No rerun requested: save a checkpoint, marked as final when
                    # training is about to stop.
                    checkpoint_utils.save_checkpoint(
                        cfg.checkpoint,
                        trainer,
                        epoch_itr,
                        sum(valid_losses) / len(valid_losses) if valid_losses else None,
                        training_finished=should_stop,
                        async_callback_fn=functools.partial(post_checkpoint_callback, cfg) if cfg.checkpoint.s3_upload_path else None,
                        dynamic_dicts_list=[task.current_k, task.best_step, task.prev_best_val_loss],
                    )
                    trainer.reset_dummy_batch(epoch_itr.first_batch)

                    if should_stop:
                        break

                if not recompile:
                    # The pass ended normally (or training must stop): leave the rerun loop.
                    break

                # recompile: reload the last saved itr & progress and rerun
                # training from there with the updated top-k.
                logger.info('[dea] recompile is True, revert training progress')
                _, _, epoch_itr = checkpoint_utils.revert_checkpoint(
                    cfg.checkpoint,
                    trainer,
                    # don't cache epoch iterators for sharded datasets
                    disable_iterator_cache=task.has_sharded_data("train"),
                )
                itr, progress, _ = _build_train_progress(cfg, epoch_itr, update_freq)
                # The reverted interval is retrained, so extend the update
                # budget by one validation interval.
                max_update_ = max_update
                max_update += cfg.dataset.validate_interval_updates
                logger.info(f'[dea] extend max updates from {max_update_} to {max_update}')

            # log end-of-epoch stats
            logger.info("end of epoch {} (average epoch stats below)".format(epoch_itr.epoch))
            stats = get_training_stats(metrics.get_smoothed_values("train"))
            progress.print(stats, tag="train", step=num_updates)

            # reset epoch-level meters
            metrics.reset_meters("train")

        if should_stop:
            break
        # [dea] use the average loss of the most recent validation to update the
        # learning rate (None when no validation has run yet).
        lr = trainer.lr_step(
            epoch_itr.epoch,
            sum(valid_losses) / len(valid_losses) if valid_losses else None,
        )

        epoch_itr = trainer.get_train_iterator(
            epoch_itr.next_epoch_idx,
            # sharded data: get train iterator for next epoch
            load_dataset=task.has_sharded_data("train"),
            # don't cache epoch iterators for sharded datasets
            disable_iterator_cache=task.has_sharded_data("train"),
        )
    train_meter.stop()
    logger.info("done training in {:.1f} seconds".format(train_meter.sum))

    # ioPath implementation to wait for all asynchronous file writes to complete.
    if cfg.checkpoint.write_checkpoints_asynchronously:
        logger.info(
            "ioPath PathManager waiting for all asynchronous checkpoint "
            "writes to finish."
        )
        PathManager.async_close()
        logger.info("ioPath PathManager finished waiting.")


def should_stop_early(cfg: DictConfig, valid_loss: float) -> bool:
    """Decide whether training should stop because validation stopped improving.

    State is kept on the function object itself: ``should_stop_early.best``
    holds the best value seen so far and ``should_stop_early.num_runs`` counts
    consecutive calls without improvement.

    Args:
        cfg: config providing ``checkpoint.patience`` and
            ``checkpoint.maximize_best_checkpoint_metric``.
        valid_loss: latest validation metric, or ``None`` if no validation ran.

    Returns:
        ``True`` once ``num_runs`` reaches ``patience``; ``False`` otherwise,
        including when ``valid_loss`` is ``None`` or ``patience <= 0``.
    """
    # No validation this epoch -> nothing to compare against.
    if valid_loss is None:
        return False
    patience = cfg.checkpoint.patience
    # Non-positive patience disables early stopping entirely.
    if patience <= 0:
        return False

    maximize = cfg.checkpoint.maximize_best_checkpoint_metric
    best_so_far = getattr(should_stop_early, "best", None)
    if best_so_far is None:
        improved = True
    elif maximize:
        improved = valid_loss > best_so_far
    else:
        improved = valid_loss < best_so_far

    if improved:
        should_stop_early.best = valid_loss
        should_stop_early.num_runs = 0
        return False

    should_stop_early.num_runs += 1
    if should_stop_early.num_runs < cfg.checkpoint.patience:
        return False
    logger.info(
        "early stop since valid performance hasn't improved for last {} runs".format(
            cfg.checkpoint.patience
        )
    )
    return True


def _flatten_config(cfg: DictConfig):
    """Convert ``cfg`` into plain containers suitable for logging back-ends.

    Top-level values that are legacy ``argparse.Namespace`` objects are
    removed. The last such namespace (in key order) is re-added as a plain
    dict (its ``vars()``) under the single key ``"args"``.
    """
    flat = OmegaConf.to_container(cfg)
    legacy_keys = [
        key for key, value in flat.items() if isinstance(value, argparse.Namespace)
    ]
    if not legacy_keys:
        return flat

    last_namespace = flat[legacy_keys[-1]]
    for key in legacy_keys:
        del flat[key]
    flat["args"] = vars(last_namespace)
    return flat


def post_checkpoint_callback(cfg, filename):
    """Copy a written checkpoint file to ``cfg.checkpoint.s3_upload_path``.

    Does nothing when no upload path is configured. The destination keeps the
    checkpoint's base name. Failures raised as ``FileNotFoundError`` or
    ``AssertionError`` are logged and swallowed, so they do not interrupt
    training.
    """
    upload_root = cfg.checkpoint.s3_upload_path
    if upload_root is None:
        return
    destination = os.path.join(upload_root, os.path.basename(filename))
    try:
        # PathManager only supports writing to S3, but this function call
        # can be replaced with other APIs for copying checkpoints.
        PathManager.copy_from_local(filename, destination, overwrite=True)
    except (FileNotFoundError, AssertionError) as err:
        logger.info(f'could not upload {filename}: {err}')

def get_training_stats(stats: Dict[str, Any]) -> Dict[str, Any]:
    """Add the elapsed wall-clock time to ``stats`` (in place) and return it.

    The value comes from the ``"wall"`` meter of the ``"default"`` metrics
    aggregator, rounded to whole units.
    """
    wall_meter = metrics.get_meter("default", "wall")
    stats["wall"] = round(wall_meter.elapsed_time, 0)
    return stats

def display_token_expert(trainer, k=50):
    """Print, per layer and per expert, the ``k`` tokens most routed to it.

    Processing steps:

    * Read ``trainer.model.src_token_to_expert`` and ``tgt_token_to_expert``.
    * Sum them across workers.
    * Normalize each row over the last (expert) dimension.
    * Print the top-``k`` source/target tokens of every expert through
      ``utils.print_r0``.

    NOTE(review): the indexing below assumes both tensors are
    (num_layers, vocab_size, num_experts) — confirm against the model.
    NOTE(review): if ``distributed_utils.all_reduce`` reduces in place it also
    overwrites the model's counters — confirm before calling this repeatedly.
    An all-zero counter tensor yields NaNs from the normalization.

    Args:
        trainer: trainer whose ``model`` carries the routing counters and
            whose ``task`` provides ``src_dict`` / ``tgt_dict``.
        k: tokens to show per expert; clamped to the size of dimension 1.
    """
    src_token_to_expert = trainer.model.src_token_to_expert
    tgt_token_to_expert = trainer.model.tgt_token_to_expert
    # Sum the routing counters over all workers before ranking.
    src_token_to_expert = distributed_utils.all_reduce(src_token_to_expert, group=None, op='sum')
    tgt_token_to_expert = distributed_utils.all_reduce(tgt_token_to_expert, group=None, op='sum')

    def _normalize(counts):
        # Scale by the global max, then make each row over the expert
        # dimension sum to one.
        counts = counts / counts.max()
        return counts / counts.sum(dim=-1, keepdim=True)

    src_token_to_expert = _normalize(src_token_to_expert)
    tgt_token_to_expert = _normalize(tgt_token_to_expert)

    # Rank tokens along dim 1 for every (layer, expert) pair; clamp k so topk
    # does not fail when dim 1 is smaller than k.
    _, topk_src_tokens = src_token_to_expert.topk(k=min(k, src_token_to_expert.size(1)), dim=1)
    # Target tokens are ranked from the *target* counters.
    _, topk_tgt_tokens = tgt_token_to_expert.topk(k=min(k, tgt_token_to_expert.size(1)), dim=1)

    src_dict = trainer.task.src_dict
    tgt_dict = trainer.task.tgt_dict

    enc_layer_num = len(src_token_to_expert)
    dec_layer_num = len(tgt_token_to_expert)
    # Encoder and decoder expert counts are taken from their own tensors.
    num_src_experts = src_token_to_expert.shape[-1]
    num_tgt_experts = tgt_token_to_expert.shape[-1]

    for i in range(enc_layer_num):
        utils.print_r0('encoder layer:{}'.format(i) + '-' * 40)
        for j in range(num_src_experts):
            utils.print_r0('encoder expert {}:{}'.format(j, src_dict.string(topk_src_tokens[i, :, j])))

    for i in range(dec_layer_num):
        utils.print_r0('decoder layer:{}'.format(i) + '-' * 40)
        for j in range(num_tgt_experts):
            utils.print_r0('decoder expert {}:{}'.format(j, tgt_dict.string(topk_tgt_tokens[i, :, j])))


def get_valid_stats(
    cfg: DictConfig,
    trainer: Trainer,
    stats: Dict[str, Any],
) -> Dict[str, Any]:
    """Add update-count and best-so-far fields to validation ``stats``.

    Mutates and returns ``stats``:

    * ``num_updates`` is always set.
    * Once ``checkpoint_utils.save_checkpoint.best`` exists, ``best_<metric>``
      is also set. It is the better of that value and the current metric,
      per ``cfg.checkpoint.maximize_best_checkpoint_metric``.
    * ``best_num_updates`` is set when the current value ties or beats it.

    Args:
        cfg: config providing the ``checkpoint.*`` metric settings.
        trainer: trainer used to read the current update count.
        stats: smoothed validation metrics; must contain
            ``cfg.checkpoint.best_checkpoint_metric`` once a best exists.
    """
    stats["num_updates"] = trainer.get_num_updates()
    # Keep the cross-rank synchronization point, but only when a process
    # group exists: an unconditional barrier raises in non-distributed runs.
    if torch.distributed.is_available() and torch.distributed.is_initialized():
        torch.distributed.barrier()
    if hasattr(checkpoint_utils.save_checkpoint, "best"):
        metric = cfg.checkpoint.best_checkpoint_metric
        maximize = cfg.checkpoint.maximize_best_checkpoint_metric
        prev_best = checkpoint_utils.save_checkpoint.best
        current = stats[metric]
        best_function = max if maximize else min
        stats["best_{0}".format(metric)] = best_function(prev_best, current)
        # Honor the metric direction when deciding whether this step ties/beats the best.
        if (current >= prev_best) if maximize else (current <= prev_best):
            stats['best_num_updates'] = trainer.get_num_updates()
    return stats


def cli_main(
    modify_parser: Optional[Callable[[argparse.ArgumentParser], None]] = None
) -> None:
    """Command-line entry point: parse training options and launch :func:`main`.

    Optionally starts a plasma store first. ``main`` is run through
    ``distributed_utils.call_main``, inside CUDA profiler / NVTX contexts when
    ``--profile`` is given.

    Args:
        modify_parser: optional callback forwarded to
            ``options.parse_args_and_arch`` as ``modify_parser``.
    """
    training_parser = options.get_training_parser()
    parsed_args = options.parse_args_and_arch(training_parser, modify_parser=modify_parser)
    cfg = convert_namespace_to_omegaconf(parsed_args)

    if cfg.common.use_plasma_view:
        # Held in a local until cli_main returns.
        plasma_store = PlasmaStore(path=cfg.common.plasma_path)
        logger.info(f"Started plasma server pid {plasma_store.server.pid} {cfg.common.plasma_path}")
        # NOTE: the server is not explicitly killed on exit; the
        # ``plasma_store.server.kill()`` call after training is disabled.

    if parsed_args.profile:
        with torch.cuda.profiler.profile(), torch.autograd.profiler.emit_nvtx():
            distributed_utils.call_main(cfg, main)
    else:
        distributed_utils.call_main(cfg, main)


if __name__ == "__main__":
    cli_main()
