# coding: utf-8

# Copyright (c) 2019-2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#       http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import functools
import itertools
import logging
import math
import os
import shutil
import sys
import time
import warnings

import dllogger
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import yaml
try:
    from apex import amp
except ModuleNotFoundError:
    warnings.warn('APEX AMP is unavailable')

from torch.nn.parallel import DistributedDataParallel

import lamb
import utils
from data_utils import get_lm_corpus
from mem_transformer import MemTransformerLM
from utils.data_parallel import BalancedDataParallel
from utils.exp_utils import AverageMeter
from utils.exp_utils import TimeoutHandler
from utils.exp_utils import benchmark
from utils.exp_utils import create_exp_dir
from utils.exp_utils import l2_promote
from utils.exp_utils import log_env_info
from utils.exp_utils import register_ignoring_timeout_handler

sys.path.append("../..")
from Core.group import group_model
from Core.optimizer import ProxSGD, ProxAdamW, RMDA, RAMDA, AdamW
from Core.scheduler import multistep_param_scheduler

def parse_args():
    """Build the Transformer-XL training CLI and return the parsed arguments.

    Parsing happens in two passes. A minimal parser first reads only
    ``--config`` / ``--config_file``; when a config file is given, the
    ``[config]['train']`` section of that YAML file is installed as parser
    defaults, so explicit command-line flags still take precedence over it.
    Unknown command-line arguments are ignored (``parse_known_args``).

    After parsing, derived fields are filled in (``args.tied``; ``args.d_embed``
    falls back to ``args.d_model`` when negative) and cross-argument
    constraints are validated.

    Returns:
        argparse.Namespace with all training options.

    Raises:
        RuntimeError: if ext_len is negative, eval_tgt_len exceeds the
            available context, (local_)batch_size is not divisible by
            batch_chunk, or apex AMP is requested but apex is not loaded.
    """
    parent_parser = argparse.ArgumentParser(
        description='PyTorch Transformer-XL Language Model',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        add_help=False,
        )

    # ``parents=`` only copies *arguments*; description and formatter_class
    # are not inherited, so forward them explicitly to get the intended help
    # output (description plus "(default: ...)" annotations).
    parser = argparse.ArgumentParser(
        parents=[parent_parser],
        description=parent_parser.description,
        formatter_class=parent_parser.formatter_class,
        add_help=True,
        )
    cfg_parser = argparse.ArgumentParser(parents=[parent_parser], add_help=False)

    cfg_parser.add_argument('--config', default='default')
    cfg_parser.add_argument('--config_file', default=None)

    # First pass: only pick out the config selection; everything else is
    # left for the full parser below.
    config_args, _ = cfg_parser.parse_known_args()

    if config_args.config is not None and config_args.config_file is not None:
        # NOTE(review): FullLoader is used on a user-supplied local file; do
        # not point --config_file at untrusted YAML.
        with open(config_args.config_file) as f:
            config = yaml.load(f, Loader=yaml.FullLoader)[config_args.config]['train']
    else:
        config = {}

    general = parser.add_argument_group('general setup')
    general.add_argument('--work_dir', default='LM-TFM', type=str,
                         help='Directory for the results')
    general.add_argument('--append_dataset', action='store_true',
                         help='Automatically append dataset name to work_dir')
    general.add_argument('--append_time', action='store_true',
                         help='Automatically append current time to work_dir')
    general.add_argument('--cuda', action='store_true',
                         help='Run training on a GPU using CUDA')
    general.add_argument('--fp16', action='store_true',
                         help='Run training in fp16/mixed precision')
    general.add_argument('--restart', type=str, default='',
                         help='Restart training from the saved checkpoint')
    general.add_argument('--debug', action='store_true',
                         help='Run in debug mode (do not create exp dir)')
    general.add_argument('--log_all_ranks', action='store_true',
                         help='Enable logging from all distributed ranks')
    general.add_argument('--dllog_file', type=str, default='train_log.json',
                         help='Name of the DLLogger output file')
    general.add_argument('--txtlog_file', type=str, default='train_log.log',
                         help='Name of the txt log file')
    general.add_argument('--save_all', action='store_true',
                         help='Save all checkpoints')
    general.add_argument('--no_env', action='store_true',
                         help='Do not print info on execution env')
    general.add_argument('--no_eval', action='store_true',
                         help='Disable model evaluation')
    general.add_argument('--no_test', action='store_true',
                         help='Disable model evaluation on test data')
    general.add_argument('--log_interval', type=int, default=10,
                         help='Report interval')
    general.add_argument('--target_throughput', type=float, default=None,
                         help='Target training throughput (for benchmarking)')
    general.add_argument('--target_perplexity', type=float, default=None,
                         help='Target validation perplexity (for benchmarking)')
    general.add_argument('--apex_amp_opt_level', type=str, default='O2',
                         choices=['O0', 'O1', 'O2', 'O3'],
                         help='Optimization level for apex amp')
    general.add_argument('--amp', choices=['apex', 'pytorch'], default='apex',
                         help='Implementation of automatic mixed precision')
    general.add_argument('--affinity', type=str,
                         default='socket_unique_interleaved',
                         choices=['socket', 'single', 'single_unique',
                                  'socket_unique_interleaved',
                                  'socket_unique_continuous',
                                  'disabled'],
                         help='type of CPU affinity')

    dataset = parser.add_argument_group('dataset setup')
    dataset.add_argument('--data', type=str, default='../data/wikitext-103',
                         help='Location of the data corpus')
    dataset.add_argument('--dataset', type=str, default='wt103',
                         choices=['wt103', 'lm1b', 'enwik8', 'text8'],
                         help='Dataset name')
    dataset.add_argument('--vocab', type=str, default='word', choices=['word', 'bpe'],
                         help='Type of vocabulary')

    model = parser.add_argument_group('model setup')
    model.add_argument('--n_layer', type=int, default=16,
                       help='Number of total layers')
    model.add_argument('--n_head', type=int, default=8,
                       help='Number of heads')
    model.add_argument('--d_head', type=int, default=64,
                       help='Head dimension')
    model.add_argument('--d_embed', type=int, default=-1,
                       help='Embedding dimension')
    model.add_argument('--d_model', type=int, default=512,
                       help='Model dimension')
    model.add_argument('--d_inner', type=int, default=2048,
                       help='Inner dimension in feedforward layer')
    model.add_argument('--dropout', type=float, default=0.1,
                       help='Global dropout rate')
    model.add_argument('--dropatt', type=float, default=0.0,
                       help='Attention probability dropout rate')
    model.add_argument('--pre_lnorm', action='store_true',
                       help='Apply LayerNorm to the input instead of the output')
    model.add_argument('--attn_type', type=int, default=0,
                       help='Attention type. 0 for ours, 1 for Shaw et al,'
                       '2 for Vaswani et al, 3 for Al Rfou et al.')
    model.add_argument('--not_tied', action='store_true',
                       help='Do not tie the word embedding and softmax weights')
    model.add_argument('--clamp_len', type=int, default=-1,
                       help='Use the same pos embeddings after clamp_len')
    model.add_argument('--adaptive', action='store_true',
                       help='Use adaptive softmax')
    model.add_argument('--div_val', type=int, default=1,
                       help='Dividend value for adaptive input and softmax')
    model.add_argument('--sample_softmax', type=int, default=-1,
                       help='Number of samples in sampled softmax')
    model.add_argument('--init', default='normal', type=str,
                       help='Parameter initializer to use')
    model.add_argument('--emb_init', default='normal', type=str,
                       help='Parameter initializer to use')
    model.add_argument('--init_range', type=float, default=0.1,
                       help='Parameters initialized by U(-init_range, init_range)')
    model.add_argument('--emb_init_range', type=float, default=0.01,
                       help='Parameters initialized by U(-init_range, init_range)')
    model.add_argument('--init_std', type=float, default=0.02,
                       help='Parameters initialized by N(0, init_std)')
    model.add_argument('--proj_init_std', type=float, default=0.01,
                       help='Parameters initialized by N(0, init_std)')

    opt = parser.add_argument_group('optimizer setup')
    opt.add_argument('--optim', default='jitlamb', type=str,
                     choices=['adam', 'sgd', 'adagrad', 'lamb', 'jitlamb', 'RAMDA', 'RMDA', 'ProxAdamW', 'ProxSGD'],
                     help='Optimizer to use')
    opt.add_argument('--lr', type=float, default=0.01,
                     help='Initial learning rate')
    opt.add_argument('--mom', type=float, default=0.0,
                     help='Momentum for sgd')
    opt.add_argument('--scheduler', default='cosine', type=str,
                     choices=['cosine', 'inv_sqrt', 'dev_perf', 'constant', 'stage_wise'],
                     help='LR scheduler to use')
    opt.add_argument('--max_step_scheduler', type=int, default=None,
                     help='Max number of training steps for LR scheduler')
    opt.add_argument('--warmup_step', type=int, default=1000,
                     help='Number of iterations for LR warmup')
    opt.add_argument('--decay_rate', type=float, default=0.5,
                     help='Decay factor when ReduceLROnPlateau is used')
    opt.add_argument('--lr_min', type=float, default=0.0,
                     help='Minimum learning rate during annealing')
    opt.add_argument('--clip', type=float, default=0.25,
                     help='Gradient clipping')
    opt.add_argument('--weight_decay', type=float, default=0.0,
                     help='Weight decay for adam|lamb')
    opt.add_argument('--lambda_', type=float, default=0.0,
                     help='regularization weight for RAMDA|RMDA|ProxAdamW|ProxSGD')
    opt.add_argument('--clip_nonemb', action='store_true',
                     help='Only clip the gradient of non-embedding params')
    opt.add_argument('--patience', type=int, default=0,
                     help='Patience')
    opt.add_argument('--eta_min', type=float, default=0.001,
                     help='Min learning rate for cosine scheduler')
    opt.add_argument('--milestones', type=int, nargs='+', default=[10000, 20000, 30000],
                     help='Milestones for stage-wise LR scheduler')
    opt.add_argument('--gamma', type=float, default=1e-1,
                     help='Decay factor when stage-wise LR scheduler is used')

    training = parser.add_argument_group('training setup')
    training.add_argument('--max_step', type=int, default=40000,
                          help='Max number of training steps')
    training.add_argument('--batch_size', type=int, default=256,
                          help='Global batch size')
    training.add_argument('--local_batch_size', type=int, default=None,
                          help='Local (per-device) batch size, this setting \
                          overrides global --batch_size and sets batch_size \
                          to local_batch_size * world_size')
    training.add_argument('--batch_chunk', type=int, default=1,
                          help='Split batch into chunks and train with '
                          'gradient accumulation')
    training.add_argument('--roll', action='store_true',
                          help='Enable random shifts within each data stream')
    training.add_argument('--tgt_len', type=int, default=192,
                          help='Number of tokens to predict')
    training.add_argument('--ext_len', type=int, default=0,
                          help='Length of the extended context')
    training.add_argument('--mem_len', type=int, default=192,
                          help='Length of the retained previous heads')
    training.add_argument('--seed', type=int, default=1111,
                          help='Random seed')
    training.add_argument('--multi_gpu', default=None, type=str,
                          choices=['ddp', 'dp'],
                          help='Use multiple GPU')
    training.add_argument('--gpu0_bsz', type=int, default=-1,
                          help='Batch size on gpu 0 (for "dp" backend)')
    training.add_argument('--same_length', action='store_true',
                          help='Use the same attn length for all tokens')
    training.add_argument('--varlen', action='store_true',
                          help='Use variable length')
    training.add_argument('--swap_mem', action='store_true',
                          help='Swap memory tensors to cpu')

    val = parser.add_argument_group('validation setup')
    val.add_argument('--eval_tgt_len', type=int, default=192,
                     help='Number of tokens to predict for evaluation')
    val.add_argument('--eval_batch_size', type=int, default=16,
                     help='Eval batch size')
    val.add_argument('--eval_max_steps', type=int, default=-1,
                     help='Max eval steps')
    val.add_argument('--eval_interval', type=int, default=5000,
                     help='Evaluation interval')

    dist = parser.add_argument_group('distributed setup')
    # os.getenv returns a string when LOCAL_RANK is set; argparse applies
    # type=int to string defaults, so args.local_rank is an int either way.
    dist.add_argument('--local_rank',  type=int,
                      default=os.getenv('LOCAL_RANK', 0),
                      help='Used for multi-process training.')

    # Config-file values override the hard-coded defaults above; explicit
    # command-line flags still override both.
    parser.set_defaults(**config)
    args, _ = parser.parse_known_args()

    args.tied = not args.not_tied

    if args.d_embed < 0:
        args.d_embed = args.d_model

    if args.ext_len < 0:
        raise RuntimeError('Extended context length must be non-negative')

    # Evaluation may use a longer target than training, but only up to the
    # context that training provides (ext_len without memory, mem_len with).
    if args.mem_len == 0:
        if args.eval_tgt_len > args.ext_len + args.tgt_len:
            raise RuntimeError('eval_tgt_len should be <= tgt_len + ext_len; '
                               f'eval_tgt_len: {args.eval_tgt_len}, '
                               f'tgt_len: {args.tgt_len}, '
                               f'ext_len: {args.ext_len}')
    else:
        if args.eval_tgt_len > args.mem_len + args.tgt_len:
            raise RuntimeError('eval_tgt_len should be <= tgt_len + mem_len; '
                               f'eval_tgt_len: {args.eval_tgt_len}, '
                               f'tgt_len: {args.tgt_len}, '
                               f'mem_len: {args.mem_len}')

    if args.batch_size % args.batch_chunk != 0:
        raise RuntimeError('Batch size needs to be divisible by batch chunk')

    if (
        args.local_batch_size is not None
        and args.local_batch_size % args.batch_chunk != 0
    ):
        raise RuntimeError('Local batch size needs to be divisible by '
                           'batch chunk')

    if args.fp16 and args.amp == 'apex' and 'apex' not in sys.modules:
        raise RuntimeError(
            'APEX AMP unavailable, install APEX or switch to pytorch AMP'
        )

    return args


def save_checkpoint(args, model, mems, model_config, optimizer, scheduler,
                    scaler, vocab, epoch, batch, last_iter, train_step,
                    best_val_loss, is_best, work_dir, device):
    """Snapshot the training state and write it to ``work_dir``.

    ``mems`` are passed through ``utils.distributed.all_gather_tensors`` on
    every rank. Only rank 0 writes files: ``checkpoint_last.pt`` always, a
    copy named ``checkpoint_best.pt`` when ``is_best`` is true, and a copy
    named ``checkpoint_{train_step}.pt`` when ``args.save_all`` is set.

    The scheduler's ``state_dict()`` is stored under ``'scheduler_state'``
    only for the 'cosine', 'inv_sqrt' and 'dev_perf' schedulers. For
    'stage_wise' and 'constant' that key is omitted.

    Raises:
        ValueError: if ``args.scheduler`` is none of the names above. The
            check runs up front on every rank, so an unsupported value cannot
            leave ``state`` unbound and fail on rank 0 alone, which may leave
            the other ranks waiting.
    """
    schedulers_with_state = ('cosine', 'inv_sqrt', 'dev_perf')
    schedulers_without_state = ('stage_wise', 'constant')
    if args.scheduler not in schedulers_with_state + schedulers_without_state:
        raise ValueError(
            f'Unsupported scheduler for checkpointing: {args.scheduler}'
        )

    if args.fp16:
        if args.amp == 'pytorch':
            amp_state = scaler.state_dict()
        elif args.amp == 'apex':
            amp_state = amp.state_dict()
    else:
        amp_state = None

    memory = [
        utils.distributed.all_gather_tensors(mem, device) for mem in mems
    ]

    # Build the state in insertion order; 'scheduler_state', when present,
    # sits between 'optimizer_state' and 'rng_states'.
    state = {
        'args': args,
        'model_config': model_config,
        'model_state': model.state_dict(),
        'optimizer_state': optimizer.state_dict(),
        }
    if args.scheduler in schedulers_with_state:
        state['scheduler_state'] = scheduler.state_dict()
    state.update({
        'rng_states': utils.exp_utils.get_default_rng_states(device),
        'memory': memory,
        'vocab': vocab,
        'amp_state': amp_state,
        'epoch': epoch,
        'batch': batch,
        'last_iter': last_iter,
        'train_step': train_step,
        'best_val_loss': best_val_loss,
        })

    last_chkpt_fname = 'checkpoint_last.pt'

    with utils.distributed.sync_workers() as rank:
        last_chkpt_path = os.path.join(work_dir, last_chkpt_fname)
        if rank == 0:
            # always save last checkpoint
            logging.info(f'Saving checkpoint to {last_chkpt_path}')
            torch.save(state, last_chkpt_path)

            # save best checkpoint if better than previous best
            if is_best:
                best_chkpt_fname = 'checkpoint_best.pt'
                best_chkpt_path = os.path.join(work_dir, best_chkpt_fname)
                logging.info(f'Saving checkpoint to {best_chkpt_path}')
                shutil.copy(last_chkpt_path, best_chkpt_path)

            # save every checkpoint if save_all is true
            if args.save_all:
                step_chkpt_fname = f'checkpoint_{train_step}.pt'
                step_chkpt_path = os.path.join(work_dir, step_chkpt_fname)
                logging.info(f'Saving checkpoint to {step_chkpt_path}')
                shutil.copy(last_chkpt_path, step_chkpt_path)


def load_checkpoint(path):
    """Load a training checkpoint written by ``save_checkpoint``.

    Args:
        path: a checkpoint file, or a directory, in which case its
            ``checkpoint_last.pt`` is loaded.

    Returns:
        The deserialized checkpoint. Tensors are mapped onto the current CUDA
        device, or onto the CPU when CUDA is unavailable.

    Only load checkpoints from trusted sources: they are unpickled with
    ``weights_only=False`` where ``torch.load`` supports that flag.
    """
    import inspect

    if os.path.isdir(path):
        path = os.path.join(path, 'checkpoint_last.pt')

    if torch.cuda.is_available():
        dst = f'cuda:{torch.cuda.current_device()}'
    else:
        # Allow inspecting / evaluating checkpoints on machines without a GPU.
        dst = 'cpu'
    logging.info(f'Loading checkpoint from {path}')

    load_kwargs = {'map_location': dst}
    # Checkpoints pickle non-tensor objects (the argparse Namespace, vocab),
    # which newer PyTorch rejects under its default weights_only=True.
    # Older PyTorch has no such parameter, so only pass it when it exists.
    if 'weights_only' in inspect.signature(torch.load).parameters:
        load_kwargs['weights_only'] = False
    checkpoint = torch.load(path, **load_kwargs)
    return checkpoint


def init_weight(weight, args):
    """Re-initialise ``weight`` in place according to ``args.init``.

    'normal' samples N(0, args.init_std); 'uniform' samples
    U(-args.init_range, args.init_range). Any other scheme leaves
    ``weight`` untouched.
    """
    scheme = args.init
    if scheme == 'normal':
        nn.init.normal_(weight, 0.0, args.init_std)
    elif scheme == 'uniform':
        bound = args.init_range
        nn.init.uniform_(weight, -bound, bound)


def init_bias(bias):
    """Zero ``bias`` in place."""
    nn.init.zeros_(bias)


def weights_init(m, args):
    """Initialise the parameters owned directly by module ``m``, in place.

    Dispatch is by substring of the class name, and the first match wins.
    'AdaptiveEmbedding' is therefore tested before 'Embedding'.

    * ``*Linear*``: weight via ``init_weight``, bias zeroed.
    * ``*AdaptiveEmbedding*``: each ``emb_projs`` entry ~ N(0, args.proj_init_std).
    * ``*Embedding*``: weight via ``init_weight``.
    * ``*ProjectedAdaptiveLogSoftmax*``: ``cluster_weight`` via ``init_weight``,
      ``cluster_bias`` zeroed, each ``out_projs`` entry ~ N(0, args.proj_init_std),
      each ``out_layers_weights`` entry via ``init_weight``.
    * ``*LayerNorm*``: weight ~ N(1, args.init_std), bias zeroed.
    * ``*TransformerLM*``: ``r_emb`` / ``r_w_bias`` / ``r_r_bias`` via
      ``init_weight``, ``r_bias`` zeroed.

    Attributes that are missing or set to ``None`` are skipped. For example,
    ``nn.LayerNorm(..., elementwise_affine=False)`` has ``weight=None``.

    NOTE(review): presumably applied per-module via ``model.apply`` with
    ``args`` bound (e.g. functools.partial) — confirm at the call site.
    """
    classname = m.__class__.__name__

    def attr(name):
        # Missing attributes and attributes explicitly set to None both mean
        # "nothing to initialise".
        return getattr(m, name, None)

    if 'Linear' in classname:
        if attr('weight') is not None:
            init_weight(m.weight, args)
        if attr('bias') is not None:
            init_bias(m.bias)
    elif 'AdaptiveEmbedding' in classname:
        emb_projs = attr('emb_projs')
        if emb_projs is not None:
            for proj in emb_projs:
                if proj is not None:
                    nn.init.normal_(proj, 0.0, args.proj_init_std)
    elif 'Embedding' in classname:
        if attr('weight') is not None:
            init_weight(m.weight, args)
    elif 'ProjectedAdaptiveLogSoftmax' in classname:
        if attr('cluster_weight') is not None:
            init_weight(m.cluster_weight, args)
        if attr('cluster_bias') is not None:
            init_bias(m.cluster_bias)
        out_projs = attr('out_projs')
        if out_projs is not None:
            for proj in out_projs:
                if proj is not None:
                    nn.init.normal_(proj, 0.0, args.proj_init_std)
        out_layers_weights = attr('out_layers_weights')
        if out_layers_weights is not None:
            for weight in out_layers_weights:
                if weight is not None:
                    init_weight(weight, args)
    elif 'LayerNorm' in classname:
        # Guard against weight/bias being None (non-affine LayerNorm);
        # nn.init.normal_(None, ...) would raise.
        if attr('weight') is not None:
            nn.init.normal_(m.weight, 1.0, args.init_std)
        if attr('bias') is not None:
            init_bias(m.bias)
    elif 'TransformerLM' in classname:
        for name in ('r_emb', 'r_w_bias', 'r_r_bias'):
            param = attr(name)
            if param is not None:
                init_weight(param, args)
        if attr('r_bias') is not None:
            init_bias(m.r_bias)


def update_dropout(m, args):
    """Set the drop probability of dropout-like modules to ``args.dropout``.

    A module qualifies when its class name contains 'Dropout' and it has a
    ``p`` attribute; every other module is left unchanged.
    """
    looks_like_dropout = 'Dropout' in m.__class__.__name__
    if looks_like_dropout and hasattr(m, 'p'):
        m.p = args.dropout


def update_dropatt(m, args):
    """Set ``m.dropatt.p`` to ``args.dropatt`` if ``m`` has a ``dropatt``."""
    try:
        attn_dropout = m.dropatt
    except AttributeError:
        return
    attn_dropout.p = args.dropatt


def _sparsity_stats(model, args):
    """Return ``(unstructured, structured, weighted_structured)`` sparsity.

    * ``unstructured``: fraction of exactly-zero entries across
      ``model.parameters()``.
    * ``structured``: over every parameter in a ``group_model`` group whose
      ``"dim"`` is 0, the fraction of slices along dim 1 that are entirely
      zero (input-wise group sparsity).
    * ``weighted_structured``: like ``structured``, but each slice is weighted
      by ``p.shape[0]``.

    A zero denominator (no parameters, or no dim-0 groups) yields 0.0 instead
    of raising ZeroDivisionError.
    """
    nonzero = 0.0
    num_el = 0.0
    for p in model.parameters():
        nonzero += p.count_nonzero().item()
        num_el += p.numel()
    unstructured = (1.0 - nonzero / num_el) if num_el else 0.0

    groups = group_model(model=model, name="Transformer-XL", lambda_=args.lambda_)

    # Both structured metrics share one per-parameter reduction, so compute it
    # once (this also works if group["params"] is a one-shot iterable).
    # NOTE(review): assumes dim-0 group params are 2-D; for other ranks
    # p.shape[1] would not match the reduced size — confirm in Core.group.
    kept = 0.0
    total = 0.0
    weighted_kept = 0.0
    weighted_total = 0.0
    for group in groups:
        if group["dim"] != 0:
            continue
        for p in group["params"]:
            n_kept = p.count_nonzero(dim=0).count_nonzero().item()
            rows, cols = p.shape[0], p.shape[1]
            kept += n_kept
            total += cols
            weighted_kept += n_kept * rows
            weighted_total += cols * rows
    structured = (1.0 - kept / total) if total else 0.0
    weighted = (1.0 - weighted_kept / weighted_total) if weighted_total else 0.0
    return unstructured, structured, weighted


def evaluate(eval_iter, model, args):
    """Evaluate ``model`` on ``eval_iter`` and report parameter sparsity.

    The model is switched to eval mode with evaluation segment lengths. It is
    always restored to training lengths and train mode afterwards, even if
    evaluation raises.

    Returns:
        ``(unstructured_sparsity, structured_sparsity,
        weighted_structured_sparsity, mean_loss)``. ``mean_loss`` is the
        ``seq_len``-weighted average loss over batches flagged ``warm``. It
        raises ZeroDivisionError if no warm batch is seen.
    """
    # Turn on evaluation mode which disables dropout.
    model.eval()
    try:
        # If the model does not use memory at all, make the ext_len longer.
        # Otherwise, make the mem_len longer and keep the ext_len the same.
        if args.mem_len == 0:
            model.reset_length(tgt_len=args.eval_tgt_len,
                               ext_len=args.ext_len + args.tgt_len - args.eval_tgt_len,
                               mem_len=args.mem_len
                               )
        else:
            model.reset_length(tgt_len=args.eval_tgt_len,
                               ext_len=args.ext_len,
                               mem_len=args.mem_len + args.tgt_len - args.eval_tgt_len,
                               )

        (unstructured_sparsity, structured_sparsity,
         weighted_structured_sparsity) = _sparsity_stats(model, args)

        total_len, total_loss = 0, 0.
        enable_autocast = args.fp16 and args.amp == 'pytorch'
        with torch.no_grad():
            mems = None
            for i, (data, target, seq_len, warm) in enumerate(eval_iter):
                if 0 < args.eval_max_steps <= i:
                    break
                with torch.cuda.amp.autocast(enable_autocast):
                    loss, mems = model(data, target, mems)
                    loss = loss.float().mean().type_as(loss)
                if warm:
                    total_loss += seq_len * loss.item()
                    total_len += seq_len
    finally:
        # Switch back to the training configuration.
        model.reset_length(tgt_len=args.tgt_len,
                           ext_len=args.ext_len,
                           mem_len=args.mem_len
                           )
        model.train()

    return (unstructured_sparsity, structured_sparsity,
            weighted_structured_sparsity, total_loss / total_len)


def train_iteration(model, i, mems, data_chunks, target_chunks, scaler,
                    optimizer, device, delay_unscale, args):
    """Run forward + backward on chunk ``i`` of the current batch.

    ``mems[i]`` is replaced in place by the memory the model returns. With
    ``args.swap_mem`` it is moved to ``device`` before the forward pass and
    back to the CPU afterwards. The loss is divided by ``args.batch_chunk``
    so gradients accumulated over all chunks average rather than sum. The
    backward pass uses the PyTorch ``scaler``, apex ``amp.scale_loss``, or a
    plain backward depending on ``args.fp16`` / ``args.amp``.

    Returns:
        The (chunk-scaled) loss as a Python float.
    """
    host = torch.device('cpu')
    inputs = data_chunks[i].contiguous()
    targets = target_chunks[i].contiguous()

    if args.swap_mem and mems[i] is not None:
        mems[i] = mems[i].to(device, non_blocking=True)

    use_torch_amp = args.fp16 and args.amp == 'pytorch'
    with torch.cuda.amp.autocast(use_torch_amp):
        raw_loss, mems[i] = model(inputs, targets, mems[i])
        chunk_loss = raw_loss.float().mean().type_as(raw_loss) / args.batch_chunk

    if args.swap_mem and mems[i] is not None:
        mems[i] = mems[i].to(host, non_blocking=True)

    if not args.fp16:
        chunk_loss.backward()
    elif args.amp == 'pytorch':
        scaler.scale(chunk_loss).backward()
    elif args.amp == 'apex':
        with amp.scale_loss(chunk_loss, optimizer, delay_unscale=delay_unscale) as scaled:
            scaled.backward()

    return chunk_loss.float().item()


def train(tr_iter, va_iter, model, para_model, mems, model_config, optimizer,
          optimizer_sparse, scheduler, scheduler_sparse, scaler, vocab, epoch,
          last_batch, last_iter, train_step, best_val_loss, meters,
          timeout_handler, device, args, valid_perplexities,
          unstructured_sparsities, structured_sparsities, weighted_structured_sparsities):
    # Turn on training mode which enables dropout.
    model.train()

    train_loss = 0
    cur_loss = float('inf')
    target_tokens = 0
    log_step = 0
    utils.distributed.barrier()
    log_start_time = time.time()

    if args.varlen:
        train_iter = tr_iter.get_varlen_iter(start=last_iter)
    else:
        train_iter = tr_iter.get_fixlen_iter(start=last_iter)

    for batch, (data, target, seq_len, _) in enumerate(train_iter, start=last_batch+1):
        log_step += 1
        target_tokens += target.numel()

        for param in model.parameters():
            param.grad = None

        data_chunks = torch.chunk(data, args.batch_chunk, 1)
        target_chunks = torch.chunk(target, args.batch_chunk, 1)

        for i in range(args.batch_chunk):
            if i < args.batch_chunk - 1 and isinstance(para_model, DistributedDataParallel):
                with para_model.no_sync():
                    train_loss_chunk = train_iteration(
                        para_model, i, mems, data_chunks, target_chunks, scaler,
                        optimizer, device, True, args
                    )
            else:
                train_loss_chunk = train_iteration(
                    para_model, i, mems, data_chunks, target_chunks, scaler,
                    optimizer, device, False, args
                )

            train_loss += train_loss_chunk

        if args.fp16:
            if args.amp == 'pytorch':
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
            elif args.amp == 'apex':
                torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.clip)
        else:
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)

        if args.fp16 and args.amp == 'pytorch':
            scaler.step(optimizer)
            scaler.update()
        else:
            optimizer.step()
            if optimizer_sparse:
                optimizer_sparse.step()

        # step-wise learning rate annealing
        train_step += 1
        if args.scheduler in ['cosine', 'constant', 'dev_perf']:
            # linear warmup stage
            if train_step < args.warmup_step:
                curr_lr = args.lr * train_step / args.warmup_step
                optimizer.param_groups[0]['lr'] = curr_lr
                if optimizer_sparse:
                    optimizer_sparse.param_groups[0]['lr'] = curr_lr * 2
            else:
                if args.scheduler == 'cosine':
                    scheduler.step(train_step - args.warmup_step)
                    if scheduler_sparse:
                        scheduler_sparse.step(train_step - args.warmup_step)
        elif args.scheduler == 'inv_sqrt':
            scheduler.step(train_step)
            if scheduler_sparse:
                scheduler_sparse.step(train_step)
        elif args.scheduler == 'stage_wise':
            scheduler.momentum_step(optimizer=optimizer, epoch=train_step)
            scheduler.step(optimizer=optimizer, epoch=train_step)

        if train_step % args.log_interval == 0:
            cur_loss = train_loss / log_step
            cur_loss = utils.distributed.all_reduce_item(cur_loss, op='mean')
            train_loss = 0

            utils.distributed.barrier()
            current_time = time.time()
            elapsed = current_time - log_start_time
            avg_elapsed = elapsed / log_step
            avg_elapsed = utils.distributed.all_reduce_item(avg_elapsed, op='max')
            log_start_time = current_time
            log_step = 0

            lr = optimizer.param_groups[0]['lr']
            if args.optim == "ProxAdamW" or args.optim.lower() == "adam" or args.optim == "lamb" or args.optim == "jitlamb":
                momentum = optimizer.state_dict()['param_groups'][0]['betas']
            else:
                momentum = optimizer.state_dict()['param_groups'][0]['momentum']
            throughput = target_tokens / elapsed
            throughput = utils.distributed.all_reduce_item(throughput, op='sum')
            meters['train_throughput'].update(throughput, elapsed)
            target_tokens = 0

            log_str = '| epoch {:3d} step {:>8d} | batches {:>6d} / {:d} | lr {:.3e} | momentum {}' \
                '| ms/batch {:5.1f} | tok/s {:7.0f} | loss {:5.2f}'.format(
                    epoch,
                    train_step,
                    batch,
                    tr_iter.n_batch,
                    lr,
                    momentum,
                    avg_elapsed * 1000,
                    throughput,
                    cur_loss,
                    )

            dllogger_data = {
                'epoch': epoch,
                'train_batch': batch+1,
                'lr': lr,
                'momentum': momentum,
                'train_time/batch': avg_elapsed * 1000,
                'train_throughput': throughput,
                'train_loss': cur_loss,
                }

            if args.dataset in ['enwik8', 'text8']:
                log_str += ' | bpc {:9.5f}'.format(cur_loss / math.log(2))
                dllogger_data['train_bits_per_character'] = cur_loss / math.log(2)
            else:
                log_str += ' | ppl {:9.2f}'.format(math.exp(cur_loss))
                dllogger_data['train_perplexity'] = math.exp(cur_loss)

            logging.info(log_str)
            dllogger.log(step=tuple([train_step]), data=dllogger_data)

        do_periodic_eval = train_step % args.eval_interval == 0
        is_final_step = train_step == args.max_step
        interrupted = timeout_handler.interrupted

        if (do_periodic_eval or is_final_step or interrupted) and not args.no_eval:
            utils.distributed.barrier()
            eval_start_time = time.time()

            unstructured_sparsity, structured_sparsity, weighted_structured_sparsity, val_loss = evaluate(va_iter, model, args)
            val_loss = utils.distributed.all_reduce_item(val_loss, op='mean')

            utils.distributed.barrier()
            eval_elapsed = time.time() - eval_start_time

            logging.info('-' * 100)
            log_str = '| Eval {:3d} at step {:>8d} | time: {:5.2f}s ' \
                      '| valid loss {:5.2f} | unstructured sparsity {:5.4f} ' \
                      '| structured sparsity {:5.4f} | weighted structured sparsity {:5.4f}'.format(
                          train_step // args.eval_interval,
                          train_step,
                          eval_elapsed,
                          val_loss,
                          unstructured_sparsity,
                          structured_sparsity,
                          weighted_structured_sparsity
                          )

            dllogger_data = {
                'valid_elapsed': eval_elapsed,
                'valid_loss': val_loss,
                'unstructured_sparsity': unstructured_sparsity,
                'structured_sparsity': structured_sparsity,
                'weighted_structured_sparsity': weighted_structured_sparsity,
                }

            if args.dataset in ['enwik8', 'text8']:
                log_str += ' | bpc {:9.5f}'.format(val_loss / math.log(2))
                dllogger_data['valid_bits_per_character'] = val_loss / math.log(2)
            else:
                log_str += ' | valid ppl {:9.3f}'.format(math.exp(val_loss))
                dllogger_data['valid_perplexity'] = math.exp(val_loss)
            logging.info(log_str)
            logging.info('-' * 100)
            dllogger.log(step=tuple([train_step]), data=dllogger_data)

            last_iter = tr_iter.last_iter

            # Check if the validation loss is the best we've seen so far.
            is_best = False
            if not best_val_loss or val_loss < best_val_loss:
                best_val_loss = val_loss
                is_best = True

            if not args.debug:
                save_checkpoint(args, model, mems, model_config, optimizer,
                                scheduler, scaler, vocab, epoch, batch,
                                last_iter, train_step, best_val_loss, is_best,
                                args.work_dir, device)

            # dev-performance based learning rate annealing
            if args.scheduler == 'dev_perf':
                scheduler.step(val_loss)
                if scheduler_sparse:
                    scheduler_sparse.step(val_loss)

            # subtract eval time from timers for training
            utils.distributed.barrier()
            log_start_time += time.time() - eval_start_time

            valid_perplexities.append(math.exp(val_loss))
            unstructured_sparsities.append(unstructured_sparsity)
            structured_sparsities.append(structured_sparsity)
            weighted_structured_sparsities.append(weighted_structured_sparsity)
            
            f = open(args.work_dir+"/{}_Transformer-XL_on_WikiText-103_presentation_{}.txt".format(args.optim, args.seed), 'w+')

            f.write("final valid perplexity: {}".format(valid_perplexities[-1])+'\n')
            f.write("final unstructured sparsity: {}".format(unstructured_sparsities[-1])+'\n')
            f.write("final structured sparsity: {}".format(structured_sparsities[-1])+'\n')
            f.write("final weighted structured sparsity: {}".format(weighted_structured_sparsities[-1])+'\n')
    
            f.write("\n")  
        
            f.write("lr: {}".format(args.lr)+'\n')
            if args.optim.lower() != "adam" and args.optim != "sgd" and args.optim != "lamb" and args.optim != "jitlamb":
                f.write("lambda_: {}".format(args.lambda_)+'\n')
            else:
                f.write("weight decay: {}".format(args.weight_decay)+'\n')
        
            f.write("\n")

            for i, r in enumerate(zip(unstructured_sparsities, structured_sparsities, weighted_structured_sparsities, valid_perplexities)):
                f.write("step:{:<8d}\tunstructured sparsity:{:<20.15f}\tstructured sparsity:{:<20.15f}\tweighted structured sparsity:{:<20.15f}\tvalid perplexity:{:<20.15f}".format((i+1)*args.eval_interval, r[0], r[1], r[2], r[3])+'\n')

            f.close()
            
            f = open(args.work_dir+"/{}_Transformer-XL_on_WikiText-103_forplotting_{}.txt".format(args.optim, args.seed), 'w+')

            f.write('unstructured sparsity\n')
            for i, unstructured_sparsity in enumerate(unstructured_sparsities):
                f.write("step {}: {}".format((i+1)*args.eval_interval, unstructured_sparsity)+'\n')
            f.write('structured sparsity\n')
            for i, structured_sparsity in enumerate(structured_sparsities):
                f.write("step {}: {}".format((i+1)*args.eval_interval, structured_sparsity)+'\n')
            f.write('weighted structured sparsity\n')
            for i, weighted_structured_sparsity in enumerate(weighted_structured_sparsities):
                f.write("step {}: {}".format((i+1)*args.eval_interval, weighted_structured_sparsity)+'\n')    
            f.write('valid perplexity\n')
            for i, valid_perplexity in enumerate(valid_perplexities):
                f.write("step {}: {}".format((i+1)*args.eval_interval, valid_perplexity)+'\n')
                
            f.close()

        if interrupted:
            logging.info(f'Received SIGTERM, exiting')
            sys.exit(0)

        if is_final_step:
            break
    return train_step, best_val_loss, cur_loss


def _split_sparse_params(model):
    """Split ``model``'s parameters into ``(dense_params, sparse_params)``.

    A parameter goes into ``sparse_params`` when its shape equals the shape of
    ``model.word_emb.weight``; every other parameter is dense. Only used when
    ``args.sample_softmax > 0`` so the sparse parameters get their own optimizer.
    """
    emb_size = model.word_emb.weight.size()
    dense_params, sparse_params = [], []
    for param in model.parameters():
        if param.size() == emb_size:
            sparse_params.append(param)
        else:
            dense_params.append(param)
    return dense_params, sparse_params


def _build_optimizer(args, model):
    """Create the optimizer(s) selected by ``args.optim`` for ``model``.

    Returns:
        ``(optimizer, optimizer_sparse)``. ``optimizer_sparse`` is None unless
        ``args.sample_softmax > 0`` with the 'sgd' or 'adam' optimizers.

    Raises:
        ValueError: if ``args.optim`` is not a supported optimizer name.
    """
    optim_name = args.optim.lower()
    optimizer_sparse = None

    if optim_name == 'sgd':
        if args.sample_softmax > 0:
            dense_params, sparse_params = _split_sparse_params(model)
            optimizer_sparse = optim.SGD(sparse_params, lr=args.lr * 2)
            optimizer = optim.SGD(dense_params, lr=args.lr, momentum=args.mom)
        else:
            optimizer = optim.SGD(model.parameters(), lr=args.lr,
                                  momentum=args.mom)
    elif optim_name == 'adam':
        if args.sample_softmax > 0:
            dense_params, sparse_params = _split_sparse_params(model)
            optimizer_sparse = optim.SparseAdam(sparse_params, lr=args.lr)
            optimizer = optim.Adam(dense_params, lr=args.lr,
                                   weight_decay=args.weight_decay)
        else:
            optimizer = optim.Adam(model.parameters(), lr=args.lr,
                                   weight_decay=args.weight_decay)
    elif optim_name == 'adagrad':
        optimizer = optim.Adagrad(model.parameters(), lr=args.lr)
    elif optim_name == 'lamb':
        optimizer = lamb.Lamb(model.parameters(), lr=args.lr,
                              weight_decay=args.weight_decay)
    elif optim_name == 'jitlamb':
        optimizer = lamb.JITLamb(model.parameters(), lr=args.lr,
                                 weight_decay=args.weight_decay)
    elif args.optim in ('RAMDA', 'RMDA', 'ProxAdamW', 'ProxSGD'):
        # These optimizers take the parameter groups built by group_model,
        # which receives args.lambda_ (presumably the regularization weight).
        optimizer_grouped_parameters = group_model(
            model=model, name="Transformer-XL", lambda_=args.lambda_)
        optimizer_cls = {
            'RAMDA': RAMDA,
            'RMDA': RMDA,
            'ProxAdamW': ProxAdamW,
            'ProxSGD': ProxSGD,
        }[args.optim]
        optimizer = optimizer_cls(optimizer_grouped_parameters, lr=args.lr)
    else:
        # Fail fast instead of hitting a NameError on `optimizer` later on.
        raise ValueError(f'Unsupported optimizer: {args.optim}')

    return optimizer, optimizer_sparse


def _build_scheduler(args, optimizer, optimizer_sparse):
    """Create the learning-rate scheduler(s) selected by ``args.scheduler``.

    Args:
        args: parsed command-line arguments.
        optimizer: optimizer for the dense parameters.
        optimizer_sparse: optional optimizer for the sparse parameters, or None.

    Returns:
        ``(scheduler, scheduler_sparse)``. ``scheduler_sparse`` is None unless
        ``args.sample_softmax > 0`` and ``optimizer_sparse`` is given; it is
        always None for the 'stage_wise' and 'constant' schedulers.

    Raises:
        ValueError: if ``args.scheduler`` is not a supported scheduler name.
    """
    use_sparse = args.sample_softmax > 0 and optimizer_sparse is not None
    scheduler_sparse = None

    if args.scheduler == 'cosine':
        max_step = args.max_step_scheduler or args.max_step
        scheduler = optim.lr_scheduler.CosineAnnealingLR(
            optimizer, max_step - args.warmup_step, eta_min=args.eta_min)
        if use_sparse:
            scheduler_sparse = optim.lr_scheduler.CosineAnnealingLR(
                optimizer_sparse, max_step - args.warmup_step,
                eta_min=args.eta_min)
    elif args.scheduler == 'inv_sqrt':
        # originally used for Transformer (in Attention is all you need)
        def lr_lambda(step):
            # return a multiplier instead of a learning rate
            if step == 0 and args.warmup_step == 0:
                return 1.
            if step > args.warmup_step:
                return 1. / (step ** 0.5)
            return step / (args.warmup_step ** 1.5)
        scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)
        if use_sparse:
            scheduler_sparse = optim.lr_scheduler.LambdaLR(
                optimizer_sparse,
                lr_lambda=lr_lambda
                )
    elif args.scheduler == 'dev_perf':
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, factor=args.decay_rate, patience=args.patience,
            min_lr=args.lr_min,
            )
        if use_sparse:
            scheduler_sparse = optim.lr_scheduler.ReduceLROnPlateau(
                optimizer_sparse, factor=args.decay_rate, patience=args.patience,
                min_lr=args.lr_min,
                )
    elif args.scheduler == 'stage_wise':
        scheduler = multistep_param_scheduler(name=args.optim, optimizer=optimizer,
                                              milestones=args.milestones,
                                              gamma=args.gamma)
    elif args.scheduler == 'constant':
        # An identity LambdaLR keeps the learning rate unchanged while still
        # providing step()/state_dict(). Leaving `scheduler` unbound made the
        # train(...) call in main() fail with NameError.
        scheduler = optim.lr_scheduler.LambdaLR(optimizer,
                                                lr_lambda=lambda step: 1.0)
    else:
        raise ValueError(f'Unsupported scheduler: {args.scheduler}')

    return scheduler, scheduler_sparse


def main():
    """Run a full Transformer-XL training job.

    Sets up CPU affinity, the device and distributed backend, the work
    directory and logging; loads the corpus; builds the model, optimizer(s),
    AMP and the data-parallel wrapper, plus the LR scheduler(s). It then
    optionally restores from ``args.restart``, trains until ``args.max_step``
    (or interruption), evaluates the best checkpoint on the test split, and
    logs a summary through dllogger.

    Exits with status 1 when a restored checkpoint has already reached
    ``args.max_step`` or when the ``benchmark`` check does not pass.
    """
    args = parse_args()
    if args.affinity != 'disabled':
        nproc_per_node = torch.cuda.device_count()
        affinity = utils.gpu_affinity.set_affinity(
            args.local_rank,
            nproc_per_node,
            args.affinity
        )
        print(f'{args.local_rank}: thread affinity: {affinity}')

    # Initialize device and distributed backend
    torch.cuda.set_device(args.local_rank)
    l2_promote()
    device = torch.device('cuda' if args.cuda else 'cpu')
    utils.distributed.init_distributed(args.cuda)

    # Keep runs with different optimizers / seeds in separate sub-directories.
    args.work_dir = os.path.join(args.work_dir, f'{args.optim}_{args.seed}')

    args.work_dir = utils.exp_utils.build_work_dir_name(args.work_dir,
                                                        args.dataset,
                                                        args.append_dataset,
                                                        args.append_time,
                                                        )

    with utils.distributed.sync_workers() as rank:
        if rank == 0:
            create_exp_dir(args.work_dir,
                           scripts_to_save=['train.py', 'mem_transformer.py'],
                           debug=args.debug)

    # Setup logging
    if args.log_all_ranks:
        log_file = f'train_log_rank_{utils.distributed.get_rank()}.log'
    else:
        log_file = "{}_Transformer-XL_on_WikiText-103_log_{}.log".format(args.optim, args.seed)

    dllog_file = "{}_Transformer-XL_on_WikiText-103_log_{}.json".format(args.optim, args.seed)
    log_file = os.path.join(args.work_dir, log_file)
    dllog_file = os.path.join(args.work_dir, dllog_file)

    if args.debug:
        log_file = os.devnull
        dllog_file = os.devnull

    utils.exp_utils.setup_logging(log_all_ranks=args.log_all_ranks,
                                  filename=log_file,
                                  )
    utils.exp_utils.setup_dllogger(enabled=True, filename=dllog_file)

    if args.local_batch_size is not None:
        world_size = utils.distributed.get_world_size()
        args.batch_size = world_size * args.local_batch_size
        logging.info(f'--local_batch_size was set, adjusting global batch size'
                     f' to {args.batch_size} (local_batch_size * world_size)')
        if args.batch_size % args.batch_chunk != 0:
            raise RuntimeError('Batch size needs to be divisible by '
                               'batch chunk')

    logging.info(args)
    dllogger.log(step='PARAMETER', data=vars(args))

    dllogger.metadata('train_throughput', {'unit': 'tokens/s'})
    # train_elapsed is reported as elapsed / 60 below, hence minutes.
    dllogger.metadata('train_elapsed', {'unit': 'min'})
    # valid_elapsed is logged as a raw time.time() difference, i.e. seconds.
    dllogger.metadata('valid_elapsed', {'unit': 's'})
    dllogger.metadata('train_perplexity', {'unit': None})
    dllogger.metadata('valid_perplexity', {'unit': None})
    dllogger.metadata('train_loss', {'unit': None})
    dllogger.metadata('valid_loss', {'unit': None})

    logging.info(f'world size: {utils.distributed.get_world_size()}')

    if not args.no_env:
        log_env_info()

    register_ignoring_timeout_handler()

    # Set the random seed manually for reproducibility.
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    ###########################################################################
    # Load data
    ###########################################################################
    corpus = get_lm_corpus(args.data, args.dataset, args.vocab)
    ntokens = len(corpus.vocab)
    vocab = corpus.vocab
    args.n_token = ntokens

    if args.mem_len == 0:
        eval_mem_len = 0
    else:
        eval_mem_len = args.mem_len + args.tgt_len - args.eval_tgt_len

    tr_iter = corpus.get_iterator('train', args.batch_size, args.tgt_len,
                                  device=device, ext_len=args.ext_len)
    va_iter = corpus.get_iterator('valid', args.eval_batch_size,
                                  args.eval_tgt_len, device=device,
                                  mem_len=eval_mem_len, ext_len=args.ext_len)
    te_iter = corpus.get_iterator('test', args.eval_batch_size,
                                  args.eval_tgt_len, device=device,
                                  mem_len=eval_mem_len, ext_len=args.ext_len)

    # adaptive softmax / embedding
    cutoffs, tie_projs = [], [False]
    if args.adaptive:
        assert args.dataset in ['wt103', 'lm1b']
        if args.dataset == 'wt103':
            cutoffs = [19997, 39997, 199997]
            tie_projs += [True] * len(cutoffs)
        elif args.dataset == 'lm1b':
            cutoffs = [59997, 99997, 639997]
            tie_projs += [False] * len(cutoffs)

    ###########################################################################
    # Build the model
    ###########################################################################
    model_config = {
        'n_token': ntokens,
        'n_layer': args.n_layer,
        'n_head': args.n_head,
        'd_model': args.d_model,
        'd_head': args.d_head,
        'd_inner': args.d_inner,
        'dropout': args.dropout,
        'dropatt': args.dropatt,
        'dtype': None,
        'tie_weight': args.tied,
        'd_embed': args.d_embed,
        'div_val': args.div_val,
        'tie_projs': tie_projs,
        'pre_lnorm': args.pre_lnorm,
        'tgt_len': args.tgt_len,
        'ext_len': args.ext_len,
        'mem_len': args.mem_len,
        'cutoffs': cutoffs,
        'same_length': args.same_length,
        'attn_type': args.attn_type,
        'clamp_len': args.clamp_len,
        'sample_softmax': args.sample_softmax,
        }

    model = MemTransformerLM(**model_config)

    model.apply(functools.partial(weights_init, args=args))
    # ensure embedding init is not overridden by out_layer in case of weight sharing
    model.word_emb.apply(functools.partial(weights_init, args=args))

    args.n_all_param = sum([p.nelement() for p in model.parameters()])
    args.n_nonemb_param = sum([p.nelement() for p in model.layers.parameters()])

    # optimizer
    optimizer, optimizer_sparse = _build_optimizer(args, model)

    model = model.to(device)

    scaler = None
    if args.fp16:
        if args.amp == 'pytorch':
            scaler = torch.cuda.amp.GradScaler()
        elif args.amp == 'apex':
            model, optimizer = amp.initialize(
                model,
                optimizer,
                opt_level=args.apex_amp_opt_level,
                )

    if args.multi_gpu == 'ddp' and torch.distributed.is_initialized():
        para_model = DistributedDataParallel(model,
                                             device_ids=[args.local_rank],
                                             output_device=args.local_rank,
                                             broadcast_buffers=False,
                                             find_unused_parameters=True,
                                             )
    elif args.multi_gpu == 'dp':
        if args.gpu0_bsz >= 0:
            para_model = BalancedDataParallel(args.gpu0_bsz // args.batch_chunk,
                                              model, dim=1).to(device)
        else:
            para_model = nn.DataParallel(model, dim=1).to(device)
    else:
        para_model = model

    # scheduler (built after amp.initialize so it wraps the final optimizer)
    scheduler, scheduler_sparse = _build_scheduler(args, optimizer,
                                                   optimizer_sparse)

    logging.info('=' * 100)
    for k, v in args.__dict__.items():
        logging.info('    - {} : {}'.format(k, v))
    logging.info('=' * 100)
    logging.info('#params = {}'.format(args.n_all_param))
    logging.info('#non emb params = {}'.format(args.n_nonemb_param))

    train_step = 0
    start_epoch = 1
    last_batch = 0
    last_iter = 0
    best_val_loss = None
    cur_loss = float('inf')

    train_mems = [None for _ in range(args.batch_chunk)]

    if args.restart:
        try:
            checkpoint = load_checkpoint(args.restart)
            model.load_state_dict(checkpoint['model_state'])
            optimizer.load_state_dict(checkpoint['optimizer_state'])
            if args.scheduler not in ('stage_wise', 'constant'):
                scheduler.load_state_dict(checkpoint['scheduler_state'])
            if args.fp16:
                if args.amp == 'pytorch':
                    scaler.load_state_dict(checkpoint['amp_state'])
                elif args.amp == 'apex':
                    amp.load_state_dict(checkpoint['amp_state'])
            utils.exp_utils.set_default_rng_states(
                checkpoint['rng_states'], device
            )

            train_mems = [
                checkpoint['memory'][i][utils.distributed.get_rank()]
                for i in range(args.batch_chunk)
            ]

            train_step = checkpoint['train_step']
            start_epoch = checkpoint['epoch']
            last_batch = checkpoint['batch']
            last_iter = checkpoint['last_iter']
            best_val_loss = checkpoint['best_val_loss']

            if train_step >= args.max_step:
                logging.info(f'Loaded checkpoint after {train_step} steps, but '
                             f'this run was scheduled for a total of '
                             f'{args.max_step} steps, exiting')
                sys.exit(1)

            model.apply(functools.partial(update_dropout, args=args))
            model.apply(functools.partial(update_dropatt, args=args))
        except FileNotFoundError:
            logging.info(f'Could not load checkpoint from {args.restart}, '
                         f'starting training from random init')

    meters = {}
    warmup = args.mem_len // args.tgt_len + 2
    meters['train_throughput'] = AverageMeter(warmup=warmup)
    ###########################################################################
    # Train
    ###########################################################################
    # Loop over epochs.
    # At any point you can hit Ctrl + C to break out of training early.
    # Per-evaluation histories; train() appends to these in place.
    valid_perplexities = []
    unstructured_sparsities = []
    structured_sparsities = []
    weighted_structured_sparsities = []

    utils.distributed.barrier()
    start_time = time.time()

    with TimeoutHandler() as timeout_handler:
        try:
            for epoch in itertools.count(start=start_epoch):
                if args.roll:
                    tr_iter.roll(seed=args.seed + epoch)
                train_step, best_val_loss, cur_loss = train(
                    tr_iter, va_iter, model, para_model, train_mems,
                    model_config, optimizer, optimizer_sparse, scheduler,
                    scheduler_sparse, scaler, vocab, epoch, last_batch,
                    last_iter, train_step, best_val_loss, meters,
                    timeout_handler, device, args,
                    valid_perplexities, unstructured_sparsities,
                    structured_sparsities, weighted_structured_sparsities
                    )

                # Only the first (possibly resumed) epoch starts mid-way.
                last_batch = 0
                last_iter = 0

                if train_step == args.max_step:
                    logging.info('-' * 100)
                    logging.info('End of training')
                    break
        except KeyboardInterrupt:
            logging.info('-' * 100)
            logging.info('Exiting from training early')
    utils.distributed.barrier()
    elapsed = time.time() - start_time

    ###########################################################################
    # Test
    ###########################################################################
    summary = {}
    test_path = os.path.join(args.work_dir, 'checkpoint_best.pt')
    if (
        not args.debug
        and not args.no_test
        and not args.no_eval
        and os.path.exists(test_path)
    ):
        # Load the best saved model.
        checkpoint = load_checkpoint(test_path)
        model.load_state_dict(checkpoint['model_state'])

        # Run on test data.
        utils.distributed.barrier()
        test_start_time = time.time()

        unstructured_sparsity, structured_sparsity, weighted_structured_sparsity, test_loss = evaluate(te_iter, model, args)
        test_loss = utils.distributed.all_reduce_item(test_loss, 'mean')

        utils.distributed.barrier()
        test_elapsed = time.time() - test_start_time

        logging.info('=' * 100)
        if args.dataset in ['enwik8', 'text8']:
            logging.info('| End of training | test time: {:5.2f}s | test loss {:5.2f} | test bpc {:9.5f}'.format(
                test_elapsed, test_loss, test_loss / math.log(2)))
        else:
            logging.info('| End of training | test time: {:5.2f}s | test loss {:5.2f} | test ppl {:9.3f}'.format(
                test_elapsed, test_loss, math.exp(test_loss)))
        logging.info('=' * 100)

        summary.update({
            'test_elapsed': test_elapsed,
            'test_loss': test_loss,
            })

        if args.dataset in ['enwik8', 'text8']:
            summary['test_bits_per_character'] = test_loss / math.log(2)
        else:
            summary['test_perplexity'] = math.exp(test_loss)

    logging.info(f'Training time: {(elapsed / 60):.2f} minutes')
    logging.info(f'Training throughput: {meters["train_throughput"].avg:.2f} tok/s')

    # best_val_loss stays None when no validation pass ever ran.
    if best_val_loss is not None:
        best_val_perplexity = math.exp(best_val_loss)
    else:
        best_val_perplexity = None

    summary.update({
        'train_throughput': meters['train_throughput'].avg,
        'train_elapsed': elapsed / 60,
        'train_loss': cur_loss,
        'valid_loss': best_val_loss,
        'valid_perplexity': best_val_perplexity,
        })
    dllogger.log(step=tuple(), data=summary)

    passed = benchmark(
        target_perplexity=args.target_perplexity,
        test_perplexity=best_val_perplexity,
        target_throughput=args.target_throughput,
        test_throughput=meters['train_throughput'].avg
        )
    if not passed:
        sys.exit(1)


if __name__ == "__main__":
    # Switch off the TorchScript profiling executor/mode. Torch builds that do
    # not expose these private hooks raise AttributeError, which is ignored.
    try:
        torch._C._jit_set_profiling_executor(False)
        torch._C._jit_set_profiling_mode(False)
    except AttributeError:
        pass

    # Without this registration, APEX AMP runs torch.einsum in "promote" mode
    # (fp32 ops). Registering einsum as a half function must happen before
    # any model is built. `--apex_amp_opt_level O2` makes it unnecessary,
    # but it remains valid there too.
    apex_was_imported = 'apex' in sys.modules
    if apex_was_imported:
        amp.register_half_function(torch, 'einsum')

    main()
