
import argparse
import functools
import itertools
import logging
import math
import os
import shutil
import sys
import time
import warnings
import json
import torch.nn.functional as F
import wandb

import dllogger
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import yaml
try:
    from apex import amp
except ModuleNotFoundError:
    warnings.warn('APEX AMP is unavailable')

from torch.nn.parallel import DistributedDataParallel

import lamb
import utils
from data_utils import get_lm_corpus
from mem_transformer import MemTransformerLM
from utils.data_parallel import BalancedDataParallel
from utils.exp_utils import AverageMeter
from utils.exp_utils import TimeoutHandler
from utils.exp_utils import benchmark
from utils.exp_utils import create_exp_dir
from utils.exp_utils import l2_promote
from utils.exp_utils import log_env_info
from utils.exp_utils import register_ignoring_timeout_handler


def parse_args():
    """Parse the command-line options for teacher/student Transformer-XL training.

    Parsing happens in two passes. A small preliminary parser reads only
    ``--config`` and ``--config_file``; when both are set, the ``train``
    section of the selected entry in the YAML file is installed as parser
    defaults, so explicit command-line flags still override the file.
    Unknown arguments are ignored in both passes (``parse_known_args``).

    Returns:
        argparse.Namespace: the parsed options, with two derived values:
        ``tied`` (the negation of ``--not_tied``) and ``d_embed`` replaced by
        ``d_model`` when a negative value was given.

    Raises:
        RuntimeError: if ``ext_len`` is negative, if ``eval_tgt_len`` exceeds
            ``tgt_len`` plus ``ext_len`` (no memory) or ``mem_len`` (with
            memory), if ``batch_size`` / ``local_batch_size`` is not divisible
            by ``batch_chunk``, or if apex AMP is requested with fp16 but the
            ``apex`` module was not imported.
    """
    parent_parser = argparse.ArgumentParser(
        description='PyTorch Transformer-XL Language Model',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        add_help=False,
        )
    parser = argparse.ArgumentParser(parents=[parent_parser], add_help=True)
    cfg_parser = argparse.ArgumentParser(parents=[parent_parser], add_help=False)

    # First pass: only discover which YAML file / entry supplies defaults.
    cfg_parser.add_argument('--config', default='default')
    cfg_parser.add_argument('--config_file', default=None)

    config_args, _ = cfg_parser.parse_known_args()

    # --config defaults to 'default', so in practice --config_file decides
    # whether a YAML file is read at all.
    if config_args.config is not None and config_args.config_file is not None:
        # NOTE(review): the config file is assumed to be trusted;
        # yaml.safe_load would be the stricter choice.
        with open(config_args.config_file) as f:
            config = yaml.load(f, Loader=yaml.FullLoader)[config_args.config]['train']
    else:
        config = {}

    general = parser.add_argument_group('general setup')
    # LoT
    general.add_argument('--teacher', type=str, default='wt103_base')
    general.add_argument('--student', type=str, default='')
    general.add_argument('--alpha', type=float, default=1.0)
    general.add_argument('--T', type=float, default=1.5)
    general.add_argument('--exp_name', type=str, default='TransformerXL')
    general.add_argument('--start_step', type=int, default=1000)

    general.add_argument('--work_dir', default='LM-TFM', type=str,
                         help='Directory for the results')
    general.add_argument('--append_dataset', action='store_true',
                         help='Automatically append dataset name to work_dir')
    general.add_argument('--append_time', action='store_true',
                         help='Automatically append current time to work_dir')
    general.add_argument('--cuda', action='store_true',
                         help='Run training on a GPU using CUDA')
    general.add_argument('--fp16', action='store_true',
                         help='Run training in fp16/mixed precision')
    general.add_argument('--restart', type=str, default='',
                         help='Restart training from the saved checkpoint')
    general.add_argument('--debug', action='store_true',
                         help='Run in debug mode (do not create exp dir)')
    general.add_argument('--log_all_ranks', action='store_true',
                         help='Enable logging from all distributed ranks')
    general.add_argument('--dllog_file', type=str, default='train_log.json',
                         help='Name of the DLLogger output file')
    general.add_argument('--txtlog_file', type=str, default='train_log.log',
                         help='Name of the txt log file')
    general.add_argument('--save_all', action='store_true',
                         help='Save all checkpoints')
    general.add_argument('--no_env', action='store_true',
                         help='Do not print info on execution env')
    general.add_argument('--no_eval', action='store_true',
                         help='Disable model evaluation')
    general.add_argument('--no_test', action='store_true',
                         help='Disable model evaluation on test data')
    general.add_argument('--log_interval', type=int, default=10,
                         help='Report interval')
    general.add_argument('--target_throughput', type=float, default=None,
                         help='Target training throughput (for benchmarking)')
    general.add_argument('--target_perplexity', type=float, default=None,
                         help='Target validation perplexity (for benchmarking)')
    general.add_argument('--apex_amp_opt_level', type=str, default='O2',
                         choices=['O0', 'O1', 'O2', 'O3'],
                         help='Optimization level for apex amp')
    general.add_argument('--amp', choices=['apex', 'pytorch'], default='pytorch',
                         help='Implementation of automatic mixed precision')
    general.add_argument('--affinity', type=str,
                         default='socket_unique_interleaved',
                         choices=['socket', 'single', 'single_unique',
                                  'socket_unique_interleaved',
                                  'socket_unique_continuous',
                                  'disabled'],
                         help='type of CPU affinity')

    dataset = parser.add_argument_group('dataset setup')
    dataset.add_argument('--data', type=str, default='../data/wikitext-103',
                         help='Location of the data corpus')
    dataset.add_argument('--dataset', type=str, default='wt103',
                         choices=['wt103', 'lm1b', 'enwik8', 'text8'],
                         help='Dataset name')
    dataset.add_argument('--vocab', type=str, default='word', choices=['word', 'bpe'],
                         help='Type of vocabulary')

    model = parser.add_argument_group('model setup')
    model.add_argument('--n_layer', type=int, default=16,
                       help='Number of total layers')
    model.add_argument('--n_head', type=int, default=8,
                       help='Number of heads')
    model.add_argument('--d_head', type=int, default=64,
                       help='Head dimension')
    model.add_argument('--d_embed', type=int, default=-1,
                       help='Embedding dimension')
    model.add_argument('--d_model', type=int, default=512,
                       help='Model dimension')
    model.add_argument('--d_inner', type=int, default=2048,
                       help='Inner dimension in feedforward layer')
    model.add_argument('--dropout', type=float, default=0.1,
                       help='Global dropout rate')
    model.add_argument('--dropatt', type=float, default=0.0,
                       help='Attention probability dropout rate')
    model.add_argument('--pre_lnorm', action='store_true',
                       help='Apply LayerNorm to the input instead of the output')
    model.add_argument('--attn_type', type=int, default=0,
                       help='Attention type. 0 for ours, 1 for Shaw et al,'
                       '2 for Vaswani et al, 3 for Al Rfou et al.')
    model.add_argument('--not_tied', action='store_true',
                       help='Do not tie the word embedding and softmax weights')
    model.add_argument('--clamp_len', type=int, default=-1,
                       help='Use the same pos embeddings after clamp_len')
    model.add_argument('--adaptive', action='store_true',
                       help='Use adaptive softmax')
    model.add_argument('--div_val', type=int, default=1,
                       help='Dividend value for adaptive input and softmax')
    model.add_argument('--sample_softmax', type=int, default=-1,
                       help='Number of samples in sampled softmax')
    model.add_argument('--init', default='normal', type=str,
                       help='Parameter initializer to use')
    model.add_argument('--emb_init', default='normal', type=str,
                       help='Parameter initializer to use')
    model.add_argument('--init_range', type=float, default=0.1,
                       help='Parameters initialized by U(-init_range, init_range)')
    model.add_argument('--emb_init_range', type=float, default=0.01,
                       help='Parameters initialized by U(-init_range, init_range)')
    model.add_argument('--init_std', type=float, default=0.02,
                       help='Parameters initialized by N(0, init_std)')
    model.add_argument('--proj_init_std', type=float, default=0.01,
                       help='Parameters initialized by N(0, init_std)')

    opt = parser.add_argument_group('optimizer setup')
    opt.add_argument('--optim', default='jitlamb', type=str,
                     choices=['adam', 'sgd', 'adagrad', 'lamb', 'jitlamb'],
                     help='Optimizer to use')
    opt.add_argument('--lr', type=float, default=0.01,
                     help='Initial learning rate')
    opt.add_argument('--mom', type=float, default=0.0,
                     help='Momentum for sgd')
    opt.add_argument('--scheduler', default='cosine', type=str,
                     choices=['cosine', 'inv_sqrt', 'dev_perf', 'constant'],
                     help='LR scheduler to use')
    opt.add_argument('--max_step_scheduler', type=int, default=None,
                     help='Max number of training steps for LR scheduler')
    opt.add_argument('--warmup_step', type=int, default=1000,
                     help='Number of iterations for LR warmup')
    opt.add_argument('--decay_rate', type=float, default=0.5,
                     help='Decay factor when ReduceLROnPlateau is used')
    opt.add_argument('--lr_min', type=float, default=0.0,
                     help='Minimum learning rate during annealing')
    opt.add_argument('--clip', type=float, default=0.25,
                     help='Gradient clipping')
    opt.add_argument('--weight_decay', type=float, default=0.0,
                     help='Weight decay for adam|lamb')
    opt.add_argument('--clip_nonemb', action='store_true',
                     help='Only clip the gradient of non-embedding params')
    opt.add_argument('--patience', type=int, default=0,
                     help='Patience')
    opt.add_argument('--eta_min', type=float, default=0.001,
                     help='Min learning rate for cosine scheduler')

    training = parser.add_argument_group('training setup')
    training.add_argument('--max_step', type=int, default=40000,
                          help='Max number of training steps')
    training.add_argument('--batch_size', type=int, default=256,
                          help='Global batch size')
    training.add_argument('--local_batch_size', type=int, default=None,
                          help='Local (per-device) batch size, this setting \
                          overrides global --batch_size and sets batch_size \
                          to local_batch_size * world_size')
    training.add_argument('--batch_chunk', type=int, default=1,
                          help='Split batch into chunks and train with '
                          'gradient accumulation')
    training.add_argument('--roll', action='store_true',
                          help='Enable random shifts within each data stream')
    training.add_argument('--tgt_len', type=int, default=192,
                          help='Number of tokens to predict')
    training.add_argument('--ext_len', type=int, default=0,
                          help='Length of the extended context')
    training.add_argument('--mem_len', type=int, default=192,
                          help='Length of the retained previous heads')
    training.add_argument('--seed', type=int, default=1111,
                          help='Random seed')
    training.add_argument('--multi_gpu', default='ddp', type=str,
                          choices=['ddp', 'dp'],
                          help='Use multiple GPU')
    training.add_argument('--gpu0_bsz', type=int, default=-1,
                          help='Batch size on gpu 0 (for "dp" backend)')
    training.add_argument('--same_length', action='store_true',
                          help='Use the same attn length for all tokens')
    training.add_argument('--varlen', action='store_true',
                          help='Use variable length')
    training.add_argument('--swap_mem', action='store_true',
                          help='Swap memory tensors to cpu')

    val = parser.add_argument_group('validation setup')
    val.add_argument('--eval_tgt_len', type=int, default=192,
                     help='Number of tokens to predict for evaluation')
    val.add_argument('--eval_batch_size', type=int, default=16,
                     help='Eval batch size')
    val.add_argument('--eval_max_steps', type=int, default=-1,
                     help='Max eval steps')
    val.add_argument('--eval_interval', type=int, default=5000,
                     help='Evaluation interval')

    dist = parser.add_argument_group('distributed setup')
    # os.getenv returns a string when LOCAL_RANK is set; argparse applies
    # `type` to string defaults, so the value still ends up as an int.
    dist.add_argument('--local_rank',  type=int,
                      default=os.getenv('LOCAL_RANK', 0),
                      help='Used for multi-process training.')

    # Second pass: YAML values become defaults, command-line flags win.
    parser.set_defaults(**config)
    args, _ = parser.parse_known_args()

    args.tied = not args.not_tied

    # A negative embedding size means "same as the model dimension".
    if args.d_embed < 0:
        args.d_embed = args.d_model

    if args.ext_len < 0:
        raise RuntimeError('Extended context length must be non-negative')

    # The evaluation target length must fit in the training window; which
    # length provides the slack depends on whether memory is used.
    if args.mem_len == 0:
        if args.eval_tgt_len > args.ext_len + args.tgt_len:
            raise RuntimeError('eval_tgt_len should be <= tgt_len + ext_len; '
                               f'eval_tgt_len: {args.eval_tgt_len}, '
                               f'tgt_len: {args.tgt_len}, '
                               f'ext_len: {args.ext_len}')
    else:
        if args.eval_tgt_len > args.mem_len + args.tgt_len:
            raise RuntimeError('eval_tgt_len should be <= tgt_len + mem_len; '
                               f'eval_tgt_len: {args.eval_tgt_len}, '
                               f'tgt_len: {args.tgt_len}, '
                               f'mem_len: {args.mem_len}')

    if args.batch_size % args.batch_chunk != 0:
        raise RuntimeError('Batch size needs to be divisible by batch chunk')

    if (
        args.local_batch_size is not None
        and args.local_batch_size % args.batch_chunk != 0
    ):
        raise RuntimeError('Local batch size needs to be divisible by '
                           'batch chunk')

    # The apex import at the top of the file is optional; detect its absence
    # here rather than failing later inside the training loop.
    if args.fp16 and args.amp == 'apex' and 'apex' not in sys.modules:
        raise RuntimeError(
            'APEX AMP unavailable, install APEX or switch to pytorch AMP'
        )

    return args


def save_checkpoint(args, model, mems, model_config, optimizer, scheduler,
                    scaler, vocab, epoch, batch, last_iter, train_step,
                    best_val_loss, is_best, work_dir, device, type='teacher'):
    """Write the full training state of one model to ``work_dir``.

    The state is always saved as ``<type>checkpoint_last.pt``. That file is
    then copied to ``<type>checkpoint_best.pt`` when ``is_best`` is true and
    to ``<type>checkpoint_<train_step>.pt`` when ``args.save_all`` is set.
    Files are only written where ``utils.distributed.sync_workers`` yields
    rank 0.

    Args:
        args: parsed options; reads ``fp16``, ``amp`` and ``save_all`` and is
            itself stored in the checkpoint.
        model, optimizer, scheduler: objects whose ``state_dict()`` is saved.
        mems: iterable of memory tensors, each passed through
            ``utils.distributed.all_gather_tensors`` before saving.
        model_config, vocab, epoch, batch, last_iter, train_step,
        best_val_loss: stored verbatim in the checkpoint.
        scaler: gradient scaler whose state is saved for PyTorch AMP.
        is_best: whether to refresh the "best" copy.
        work_dir: destination directory.
        device: forwarded to the memory gather and RNG-state helpers.
        type: filename prefix, e.g. ``'teacher'`` or ``'student'``.
    """
    # Mixed-precision state depends on which AMP backend is active.
    if not args.fp16:
        amp_state = None
    elif args.amp == 'pytorch':
        amp_state = scaler.state_dict()
    elif args.amp == 'apex':
        amp_state = amp.state_dict()

    memory = []
    for mem in mems:
        memory.append(utils.distributed.all_gather_tensors(mem, device))

    state = dict(
        args=args,
        model_config=model_config,
        model_state=model.state_dict(),
        optimizer_state=optimizer.state_dict(),
        scheduler_state=scheduler.state_dict(),
        rng_states=utils.exp_utils.get_default_rng_states(device),
        memory=memory,
        vocab=vocab,
        amp_state=amp_state,
        epoch=epoch,
        batch=batch,
        last_iter=last_iter,
        train_step=train_step,
        best_val_loss=best_val_loss,
    )

    last_chkpt_path = os.path.join(work_dir, type + 'checkpoint_last.pt')

    with utils.distributed.sync_workers() as rank:
        if rank == 0:
            # The "last" checkpoint is written unconditionally.
            logging.info(f'Saving checkpoint to {last_chkpt_path}')
            torch.save(state, last_chkpt_path)

            # Optional copies: the best-so-far one, then the per-step one.
            copy_names = []
            if is_best:
                copy_names.append(type + 'checkpoint_best.pt')
            if args.save_all:
                copy_names.append(type + f'checkpoint_{train_step}.pt')

            for fname in copy_names:
                copy_path = os.path.join(work_dir, fname)
                logging.info(f'Saving checkpoint to {copy_path}')
                shutil.copy(last_chkpt_path, copy_path)


def load_checkpoint(path, type):
    """Load a checkpoint written by ``save_checkpoint``.

    Args:
        path: a checkpoint file, or a directory in which case
            ``<type>checkpoint_last.pt`` inside it is loaded.
        type: filename prefix used when ``path`` is a directory
            (e.g. ``'teacher'`` or ``'student'``).

    Returns:
        The object stored in the checkpoint file, with tensors mapped to the
        current CUDA device when CUDA is available and to the CPU otherwise.
    """
    if os.path.isdir(path):
        path = os.path.join(path, type + 'checkpoint_last.pt')

    # Querying the current CUDA device fails on a CPU-only host, so only
    # target the GPU when one is actually usable.
    if torch.cuda.is_available():
        dst = f'cuda:{torch.cuda.current_device()}'
    else:
        dst = 'cpu'

    logging.info(f'Loading checkpoint from {path}')
    # NOTE: torch.load unpickles the file; only load trusted checkpoints.
    checkpoint = torch.load(path, map_location=dst)
    return checkpoint


def init_weight(weight, args):
    """Initialise ``weight`` in place according to ``args.init``.

    ``'uniform'`` draws from U(-args.init_range, args.init_range) and
    ``'normal'`` from N(0, args.init_std); any other value leaves the
    tensor untouched.
    """
    scheme = args.init
    if scheme == 'normal':
        nn.init.normal_(weight, mean=0.0, std=args.init_std)
        return
    if scheme == 'uniform':
        bound = args.init_range
        nn.init.uniform_(weight, a=-bound, b=bound)


def init_bias(bias):
    """Fill ``bias`` with zeros in place."""
    nn.init.zeros_(bias)


def weights_init(m, args):
    """Initialise the parameters of module ``m`` based on its class name.

    The class name is matched by substring, in this order: ``Linear``,
    ``AdaptiveEmbedding``, ``Embedding``, ``ProjectedAdaptiveLogSoftmax``,
    ``LayerNorm``, ``TransformerLM``. Only the first match is applied;
    modules matching none of them are left unchanged.
    """
    name = m.__class__.__name__

    if 'Linear' in name:
        if getattr(m, 'weight', None) is not None:
            init_weight(m.weight, args)
        if getattr(m, 'bias', None) is not None:
            init_bias(m.bias)
        return

    # Must be tested before the plain 'Embedding' case, which it contains.
    if 'AdaptiveEmbedding' in name:
        for proj in getattr(m, 'emb_projs', ()):
            if proj is not None:
                nn.init.normal_(proj, 0.0, args.proj_init_std)
        return

    if 'Embedding' in name:
        if hasattr(m, 'weight'):
            init_weight(m.weight, args)
        return

    if 'ProjectedAdaptiveLogSoftmax' in name:
        if getattr(m, 'cluster_weight', None) is not None:
            init_weight(m.cluster_weight, args)
        if getattr(m, 'cluster_bias', None) is not None:
            init_bias(m.cluster_bias)
        for proj in getattr(m, 'out_projs', ()):
            if proj is not None:
                nn.init.normal_(proj, 0.0, args.proj_init_std)
        for layer_weight in getattr(m, 'out_layers_weights', ()):
            if layer_weight is not None:
                init_weight(layer_weight, args)
        return

    if 'LayerNorm' in name:
        # Scale is drawn around 1.0, the shift is zeroed.
        if hasattr(m, 'weight'):
            nn.init.normal_(m.weight, 1.0, args.init_std)
        if getattr(m, 'bias', None) is not None:
            init_bias(m.bias)
        return

    if 'TransformerLM' in name:
        for attr in ('r_emb', 'r_w_bias', 'r_r_bias'):
            if hasattr(m, attr):
                init_weight(getattr(m, attr), args)
        if hasattr(m, 'r_bias'):
            init_bias(m.r_bias)


def update_dropout(m, args):
    """Set ``m.p`` to ``args.dropout`` when ``m`` looks like a dropout layer.

    A module qualifies when its class name contains ``'Dropout'`` and it has
    a ``p`` attribute; every other module is left unchanged.
    """
    is_dropout = 'Dropout' in m.__class__.__name__
    if is_dropout and hasattr(m, 'p'):
        m.p = args.dropout


def update_dropatt(m, args):
    """Set ``m.dropatt.p`` to ``args.dropatt`` for modules that have ``dropatt``."""
    if not hasattr(m, 'dropatt'):
        return
    m.dropatt.p = args.dropatt


def kl_div_logits(p, q, T):
    """Temperature-scaled KL divergence between two sets of logits.

    Both inputs are divided by ``T`` and passed through ``log_softmax`` over
    the last dimension. ``p`` is the input and ``q`` the (log-space) target
    of the KL loss, reduced with ``'batchmean'``; the result is multiplied
    by ``T`` squared.
    """
    log_p = F.log_softmax(p / T, dim=-1)
    log_q = F.log_softmax(q / T, dim=-1)
    divergence = F.kl_div(log_p, log_q, reduction='batchmean', log_target=True)
    return divergence * T * T


def evaluate(eval_iter, model, args):
    """Compute the average loss of ``model`` over ``eval_iter``.

    The model is put in eval mode and its lengths are reset for evaluation;
    afterwards the training lengths and training mode are restored. Only
    batches flagged ``warm`` contribute, each weighted by its ``seq_len``.

    Args:
        eval_iter: iterable of ``(data, target, seq_len, warm)`` tuples.
        model: callable as ``model(data, target, mems)`` returning
            ``(loss, mems, logit)``.
        args: reads ``mem_len``, ``ext_len``, ``tgt_len``, ``eval_tgt_len``,
            ``eval_max_steps``, ``fp16`` and ``amp``.

    Returns:
        The accumulated loss divided by the accumulated sequence length.
    """
    # Evaluation mode disables dropout.
    model.eval()

    # The slack between training and evaluation target lengths extends the
    # extended context when no memory is used, and the memory otherwise.
    slack = args.tgt_len - args.eval_tgt_len
    if args.mem_len == 0:
        eval_ext_len, eval_mem_len = args.ext_len + slack, args.mem_len
    else:
        eval_ext_len, eval_mem_len = args.ext_len, args.mem_len + slack
    model.reset_length(tgt_len=args.eval_tgt_len,
                       ext_len=eval_ext_len,
                       mem_len=eval_mem_len)

    loss_sum = 0.
    length_sum = 0
    use_autocast = args.fp16 and args.amp == 'pytorch'
    mems = None
    with torch.no_grad():
        for step, (data, target, seq_len, warm) in enumerate(eval_iter):
            if args.eval_max_steps > 0 and step >= args.eval_max_steps:
                break
            with torch.cuda.amp.autocast(use_autocast):
                loss, mems, _ = model(data, target, mems)
                loss = loss.float().mean().type_as(loss)
            if not warm:
                continue
            loss_sum += seq_len * loss.item()
            length_sum += seq_len

    # Restore the training configuration before handing the model back.
    model.reset_length(tgt_len=args.tgt_len,
                       ext_len=args.ext_len,
                       mem_len=args.mem_len)
    model.train()

    return loss_sum / length_sum


def train_iteration(teacher, student, i, teacher_mems, student_mems, data_chunks, target_chunks, scaler,
                    teacher_optimizer, student_optimizer, device, delay_unscale, args):
    """Run forward and backward for chunk ``i`` of the current batch.

    The teacher (and, when given, the student) are evaluated on the same
    chunk. Once ``args.train_step`` is past both ``args.warmup_step`` and
    ``args.start_step``, each model additionally receives a KL term towards
    the other's detached logits, weighted by ``args.alpha`` (which is
    rescaled here from ``args.original_alpha``). Gradients are accumulated;
    no optimizer step is taken.

    Args:
        teacher, student: models called as ``model(data, target, mems)`` and
            returning ``(loss, mems, logit)``; ``student`` may be falsy to
            train the teacher alone.
        i: index of the chunk to process.
        teacher_mems, student_mems: per-chunk memories, updated in place.
        data_chunks, target_chunks: per-chunk inputs and targets.
        scaler: gradient scaler used with PyTorch AMP.
        teacher_optimizer: used by apex loss scaling.
        student_optimizer: unused here; kept for interface symmetry.
        device: device memories are moved to when ``args.swap_mem`` is set.
        delay_unscale: forwarded to ``amp.scale_loss`` (apex only).
        args: parsed options (mutated: ``args.alpha``).

    Returns:
        Tuple of Python floats ``(teacher_loss, teacher_ce_loss,
        teacher_lot_loss)`` for this chunk.
    """
    cpu = torch.device('cpu')
    data_i = data_chunks[i].contiguous()
    target_i = target_chunks[i].contiguous()

    # With swap_mem the memories live on the CPU between iterations.
    if args.swap_mem and teacher_mems[i] is not None:
        teacher_mems[i] = teacher_mems[i].to(device, non_blocking=True)
    if args.swap_mem and student and student_mems[i] is not None:
        student_mems[i] = student_mems[i].to(device, non_blocking=True)

    # train() publishes the step on args only after the first optimizer
    # step; treat a missing value as step 0.
    train_step = getattr(args, 'train_step', 0)

    student_loss = None
    enable_autocast = args.fp16 and args.amp == 'pytorch'
    with torch.cuda.amp.autocast(enable_autocast):
        teacher_ce_loss, teacher_mems[i], teacher_logit = teacher(data_i, target_i, teacher_mems[i])
        teacher_ce_loss = teacher_ce_loss.float().mean().type_as(teacher_ce_loss) / args.batch_chunk
        teacher_lot_loss = 0
        teacher_loss = teacher_ce_loss + teacher_lot_loss
        if student:
            student_ce_loss, student_mems[i], student_logit = student(data_i, target_i, student_mems[i])
            student_ce_loss = student_ce_loss.float().mean().type_as(student_ce_loss) / args.batch_chunk

            if train_step > args.warmup_step and train_step > args.start_step:
                # The distillation weight grows linearly with training progress.
                args.alpha = args.original_alpha * (0.5 + train_step / args.max_step)
                # Each model is pulled towards the other's detached logits,
                # so the two losses have independent autograd graphs.
                # NOTE(review): unlike the CE terms, the KL terms are not
                # divided by batch_chunk - confirm this is intended.
                teacher_lot_loss = args.alpha*kl_div_logits(teacher_logit, student_logit.detach(), args.T)
                student_lot_loss = args.alpha*kl_div_logits(student_logit, teacher_logit.detach(), args.T)
                teacher_loss = teacher_ce_loss + teacher_lot_loss
                student_loss = student_ce_loss + student_lot_loss
            else:
                student_loss = student_ce_loss

    if args.swap_mem and teacher_mems[i] is not None:
        teacher_mems[i] = teacher_mems[i].to(cpu, non_blocking=True)
    if args.swap_mem and student and student_mems[i] is not None:
        student_mems[i] = student_mems[i].to(cpu, non_blocking=True)

    if args.fp16:
        if args.amp == 'pytorch':
            scaler.scale(teacher_loss).backward()
            if student:
                scaler.scale(student_loss).backward()
        elif args.amp == 'apex':
            # NOTE(review): only the teacher is back-propagated under apex;
            # student training is not wired up for this backend.
            with amp.scale_loss(teacher_loss, teacher_optimizer, delay_unscale=delay_unscale) as scaled_loss:
                scaled_loss.backward()
    else:
        teacher_loss.backward()
        # The student must be back-propagated in full precision as well,
        # otherwise it never receives gradients.
        if student:
            student_loss.backward()

    # Return plain floats so callers can accumulate them without keeping
    # autograd graphs alive.
    return (
        teacher_loss.float().item(),
        teacher_ce_loss.float().item(),
        float(teacher_lot_loss),
    )


def train(tr_iter, va_iter, teacher, para_teacher, student, para_student, teacher_mems, student_mems, model_config, teacher_optimizer,
          student_optimizer, teacher_scheduler, student_scheduler, scaler, vocab, epoch,
          last_batch, last_iter, train_step, best_val_loss, meters,
          timeout_handler, device, args):
    """Train the teacher (and optional student) over one pass of ``tr_iter``.

    Each batch is split into ``args.batch_chunk`` chunks whose gradients are
    accumulated by ``train_iteration``; gradients are then clipped and the
    optimizers stepped. Training statistics are logged every
    ``args.log_interval`` steps. Evaluation and checkpointing run every
    ``args.eval_interval`` steps, on the final step, and when the timeout
    handler reports an interruption (in which case the process then exits).

    Args:
        tr_iter, va_iter: training and validation iterators.
        teacher, student: the underlying models; ``student`` may be falsy.
        para_teacher, para_student: the (possibly parallel-wrapped) models
            used for the forward/backward passes.
        teacher_mems, student_mems: per-chunk memories, updated in place.
        model_config, vocab: stored in checkpoints.
        teacher_optimizer, student_optimizer: optimizers to step.
        teacher_scheduler, student_scheduler: learning-rate schedulers.
        scaler: gradient scaler for PyTorch AMP.
        epoch: current epoch, used for logging and checkpoints.
        last_batch, last_iter: position to resume the iterator from.
        train_step: number of optimizer steps taken so far.
        best_val_loss: best teacher validation loss seen so far (falsy if none).
        meters: mapping with a ``'train_throughput'`` meter.
        timeout_handler: object exposing ``interrupted``.
        device: forwarded to ``train_iteration`` and ``save_checkpoint``.
        args: parsed options (mutated: ``args.train_step``).

    Returns:
        Tuple ``(train_step, best_val_loss, cur_loss)`` where ``cur_loss`` is
        the most recently logged average training loss (``inf`` if nothing
        was logged).
    """
    # Turn on training mode which enables dropout.
    teacher.train()
    if student:
        student.train()

    # train_iteration reads the step from args; publish it before the first
    # batch so the very first iteration sees a defined value.
    args.train_step = train_step

    train_loss = 0
    ce_loss = 0
    lot_loss = 0
    cur_loss = float('inf')
    target_tokens = 0
    log_step = 0
    utils.distributed.barrier()
    log_start_time = time.time()

    if args.varlen:
        train_iter = tr_iter.get_varlen_iter(start=last_iter)
    else:
        train_iter = tr_iter.get_fixlen_iter(start=last_iter)

    for batch, (data, target, seq_len, _) in enumerate(train_iter, start=last_batch+1):
        log_step += 1
        target_tokens += target.numel()

        for param in teacher.parameters():
            param.grad = None
        if student:
            for param in student.parameters():
                param.grad = None

        data_chunks = torch.chunk(data, args.batch_chunk, 1)
        target_chunks = torch.chunk(target, args.batch_chunk, 1)

        for i in range(args.batch_chunk):
            # Under DDP, skip the gradient all-reduce for every chunk except
            # the last one of the batch.
            if i < args.batch_chunk - 1 and isinstance(para_teacher, DistributedDataParallel):
                with para_teacher.no_sync():
                    train_loss_chunk, ce_loss_chunk, lot_loss_chunk = train_iteration(
                        para_teacher, para_student, i, teacher_mems, student_mems, data_chunks, target_chunks, scaler,
                        teacher_optimizer, student_optimizer, device, True, args
                    )
            else:
                train_loss_chunk, ce_loss_chunk, lot_loss_chunk = train_iteration(
                    para_teacher, para_student, i, teacher_mems, student_mems, data_chunks, target_chunks, scaler,
                    teacher_optimizer, student_optimizer, device, False, args
                )

            # Accumulate plain floats so no autograd graph is retained
            # between logging intervals.
            train_loss += float(train_loss_chunk)
            ce_loss += float(ce_loss_chunk)
            lot_loss += float(lot_loss_chunk)

        if args.fp16:
            if args.amp == 'pytorch':
                scaler.unscale_(teacher_optimizer)
                torch.nn.utils.clip_grad_norm_(teacher.parameters(), args.clip)
                if student:
                    scaler.unscale_(student_optimizer)
                    torch.nn.utils.clip_grad_norm_(student.parameters(), args.clip)
            elif args.amp == 'apex':
                # NOTE(review): apex mixed precision only handles the teacher.
                torch.nn.utils.clip_grad_norm_(amp.master_params(teacher_optimizer), args.clip)
        else:
            torch.nn.utils.clip_grad_norm_(teacher.parameters(), args.clip)
            if student:
                torch.nn.utils.clip_grad_norm_(student.parameters(), args.clip)

        if args.fp16 and args.amp == 'pytorch':
            scaler.step(teacher_optimizer)
            if student:
                scaler.step(student_optimizer)
            scaler.update()
        else:
            teacher_optimizer.step()
            # In full precision the student has to be stepped too; under
            # apex it is left alone because only the teacher is scaled.
            if student and not args.fp16:
                student_optimizer.step()

        # step-wise learning rate annealing
        train_step += 1
        args.train_step = train_step
        if args.scheduler in ['cosine', 'constant', 'dev_perf']:
            # linear warmup stage
            if train_step < args.warmup_step:
                curr_lr = args.lr * train_step / args.warmup_step
                teacher_optimizer.param_groups[0]['lr'] = curr_lr
                if student:
                    student_optimizer.param_groups[0]['lr'] = curr_lr
            else:
                if args.scheduler == 'cosine':
                    teacher_scheduler.step(train_step - args.warmup_step)
                    if student:
                        student_scheduler.step(train_step - args.warmup_step)

        elif args.scheduler == 'inv_sqrt':
            teacher_scheduler.step(train_step)
            if student:
                student_scheduler.step(train_step)

        if train_step % args.log_interval == 0:
            cur_loss = train_loss / log_step
            cur_ce_loss = ce_loss / log_step
            cur_lot_loss = lot_loss / log_step
            cur_loss = utils.distributed.all_reduce_item(cur_loss, op='mean')
            cur_ce_loss = utils.distributed.all_reduce_item(cur_ce_loss, op='mean')
            cur_lot_loss = utils.distributed.all_reduce_item(cur_lot_loss, op='mean')
            train_loss = 0
            ce_loss = 0
            lot_loss = 0

            utils.distributed.barrier()
            current_time = time.time()
            elapsed = current_time - log_start_time
            avg_elapsed = elapsed / log_step
            avg_elapsed = utils.distributed.all_reduce_item(avg_elapsed, op='max')
            log_start_time = current_time
            log_step = 0

            lr = teacher_optimizer.param_groups[0]['lr']
            throughput = target_tokens / elapsed
            throughput = utils.distributed.all_reduce_item(throughput, op='sum')
            meters['train_throughput'].update(throughput, elapsed)
            target_tokens = 0

            log_str = '| epoch {:3d} step {:>8d} | batches {:>6d} / {:d} | lr {:.3e} ' \
                '| ms/batch {:5.1f} | tok/s {:7.0f} | loss {:5.2f}'.format(
                    epoch,
                    train_step,
                    batch,
                    tr_iter.n_batch,
                    lr,
                    avg_elapsed * 1000,
                    throughput,
                    cur_loss,
                    )

            dllogger_data = {
                'epoch': epoch,
                'train_batch': batch+1,
                'lr': lr,
                'train_time/batch': avg_elapsed * 1000,
                'train_throughput': throughput,
                'train_loss': cur_loss,
                }
            if args.local_rank == 0:
                wandb.log({'teacher_lr': lr, 'teacher_train_loss': cur_loss,
                           'teacher_ce_loss': cur_ce_loss, 'teacher_lot_loss': cur_lot_loss,
                           'teacher_train_ppl': math.exp(cur_ce_loss)}, step=train_step)

            # Both metrics are derived from the cross-entropy term alone;
            # the distillation term is not a likelihood.
            if args.dataset in ['enwik8', 'text8']:
                log_str += ' | bpc {:9.5f}'.format(cur_ce_loss / math.log(2))
                dllogger_data['train_bits_per_character'] = cur_ce_loss / math.log(2)
            else:
                log_str += ' | ppl {:9.2f}'.format(math.exp(cur_ce_loss))
                dllogger_data['train_perplexity'] = math.exp(cur_ce_loss)

            logging.info(log_str)
            dllogger.log(step=tuple([train_step]), data=dllogger_data)

        do_periodic_eval = train_step % args.eval_interval == 0
        is_final_step = train_step == args.max_step
        interrupted = timeout_handler.interrupted

        if (do_periodic_eval or is_final_step or interrupted) and not args.no_eval:
            utils.distributed.barrier()
            eval_start_time = time.time()

            teacher_val_loss = evaluate(va_iter, teacher, args)
            teacher_val_loss = utils.distributed.all_reduce_item(teacher_val_loss, op='mean')
            if student:
                student_val_loss = evaluate(va_iter, student, args)
                student_val_loss = utils.distributed.all_reduce_item(student_val_loss, op='mean')

            utils.distributed.barrier()
            eval_elapsed = time.time() - eval_start_time

            logging.info('-' * 100)
            log_str = '| Eval {:3d} at step {:>8d} | time: {:5.2f}s ' \
                      '| valid loss {:5.2f}'.format(
                          train_step // args.eval_interval,
                          train_step,
                          eval_elapsed,
                          teacher_val_loss,
                          )

            dllogger_data = {
                'valid_elapsed': eval_elapsed,
                'valid_loss': teacher_val_loss,
                }

            if args.dataset in ['enwik8', 'text8']:
                log_str += ' | bpc {:9.5f}'.format(teacher_val_loss / math.log(2))
                dllogger_data['valid_bits_per_character'] = teacher_val_loss / math.log(2)
            else:
                log_str += ' | teacher valid ppl {:9.3f}'.format(math.exp(teacher_val_loss))
                dllogger_data['teacher_valid_perplexity'] = math.exp(teacher_val_loss)
                if student:
                    log_str += ' | student valid ppl {:9.3f}'.format(math.exp(student_val_loss))
                    dllogger_data['student_valid_perplexity'] = math.exp(student_val_loss)
            logging.info(log_str)
            logging.info('-' * 100)
            dllogger.log(step=tuple([train_step]), data=dllogger_data)
            if args.local_rank == 0:
                wandb.log({'teacher_valid_ppl': math.exp(teacher_val_loss)}, step=train_step)
                if student:
                    wandb.log({'student_valid_ppl': math.exp(student_val_loss)}, step=train_step)

            last_iter = tr_iter.last_iter

            # Check if the validation loss is the best we've seen so far.
            is_best = False
            if not best_val_loss or teacher_val_loss < best_val_loss:
                best_val_loss = teacher_val_loss
                is_best = True

            if not args.debug:
                save_checkpoint(args, teacher, teacher_mems, model_config, teacher_optimizer,
                                teacher_scheduler, scaler, vocab, epoch, batch,
                                last_iter, train_step, best_val_loss, is_best,
                                args.work_dir, device, type='teacher')
                if student:
                    # NOTE(review): `is_best` reflects the teacher's
                    # validation loss, so the student's "best" checkpoint is
                    # refreshed whenever the teacher improves; tracking a
                    # separate student best would need extra state.
                    save_checkpoint(args, student, student_mems, model_config, student_optimizer,
                                    student_scheduler, scaler, vocab, epoch, batch,
                                    last_iter, train_step, student_val_loss, is_best,
                                    args.work_dir, device, type='student')

            # dev-performance based learning rate annealing
            if args.scheduler == 'dev_perf':
                teacher_scheduler.step(teacher_val_loss)
                if student:
                    student_scheduler.step(student_val_loss)

            # subtract eval time from timers for training
            utils.distributed.barrier()
            log_start_time += time.time() - eval_start_time

        if interrupted:
            logging.info('Received SIGTERM, exiting')
            sys.exit(0)

        if is_final_step:
            break
    return train_step, best_val_loss, cur_loss


def _split_dense_sparse_params(model):
    """Partition ``model``'s parameters into ``(dense, sparse)`` lists.

    A parameter is put in the sparse list when its shape equals the shape of
    ``model.word_emb.weight``; every other parameter is dense.
    """
    emb_size = model.word_emb.weight.size()
    dense_params, sparse_params = [], []
    for param in model.parameters():
        if param.size() == emb_size:
            sparse_params.append(param)
        else:
            dense_params.append(param)
    return dense_params, sparse_params


def _build_optimizer(model, args):
    """Create the optimizer(s) selected by ``args.optim`` for ``model``.

    Returns:
        ``(optimizer, optimizer_sparse)``. ``optimizer_sparse`` is only
        created for ``sgd``/``adam`` when ``args.sample_softmax > 0`` and is
        ``None`` otherwise.

    Raises:
        ValueError: if ``args.optim`` names an optimizer not handled here.
    """
    name = args.optim.lower()
    use_sparse = args.sample_softmax > 0
    optimizer_sparse = None

    if name == 'sgd':
        if use_sparse:
            dense_params, sparse_params = _split_dense_sparse_params(model)
            optimizer_sparse = optim.SGD(sparse_params, lr=args.lr * 2)
            optimizer = optim.SGD(dense_params, lr=args.lr, momentum=args.mom)
        else:
            optimizer = optim.SGD(model.parameters(), lr=args.lr,
                                  momentum=args.mom)
    elif name == 'adam':
        if use_sparse:
            dense_params, sparse_params = _split_dense_sparse_params(model)
            optimizer_sparse = optim.SparseAdam(sparse_params, lr=args.lr)
            optimizer = optim.Adam(dense_params, lr=args.lr,
                                   weight_decay=args.weight_decay)
        else:
            optimizer = optim.Adam(model.parameters(), lr=args.lr,
                                   weight_decay=args.weight_decay)
    elif name == 'adagrad':
        optimizer = optim.Adagrad(model.parameters(), lr=args.lr)
    elif name == 'lamb':
        optimizer = lamb.Lamb(model.parameters(), lr=args.lr,
                              weight_decay=args.weight_decay)
    elif name == 'jitlamb':
        optimizer = lamb.JITLamb(model.parameters(), lr=args.lr,
                                 weight_decay=args.weight_decay)
    else:
        raise ValueError(f'Unsupported optimizer: {args.optim}')

    return optimizer, optimizer_sparse


def _build_scheduler(optimizer, optimizer_sparse, args):
    """Create the LR scheduler(s) selected by ``args.scheduler``.

    Returns:
        ``(scheduler, scheduler_sparse)``. Both are ``None`` for the
        ``constant`` schedule (and for any unrecognised value);
        ``scheduler_sparse`` is ``None`` unless ``args.sample_softmax > 0``
        and a sparse optimizer exists.
    """
    scheduler = None
    scheduler_sparse = None
    wants_sparse = args.sample_softmax > 0 and optimizer_sparse is not None

    if args.scheduler == 'cosine':
        if args.max_step_scheduler:
            max_step = args.max_step_scheduler
        else:
            max_step = args.max_step

        scheduler = optim.lr_scheduler.CosineAnnealingLR(
            optimizer, max_step - args.warmup_step, eta_min=args.eta_min)
        if wants_sparse:
            scheduler_sparse = optim.lr_scheduler.CosineAnnealingLR(
                optimizer_sparse, max_step - args.warmup_step,
                eta_min=args.eta_min)
    elif args.scheduler == 'inv_sqrt':
        # originally used for Transformer (in Attention is all you need)
        def lr_lambda(step):
            # return a multiplier instead of a learning rate
            if step == 0 and args.warmup_step == 0:
                return 1.
            else:
                return 1. / (step ** 0.5) if step > args.warmup_step \
                    else step / (args.warmup_step ** 1.5)

        scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)
        if wants_sparse:
            scheduler_sparse = optim.lr_scheduler.LambdaLR(
                optimizer_sparse,
                lr_lambda=lr_lambda
                )
    elif args.scheduler == 'dev_perf':
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, factor=args.decay_rate, patience=args.patience,
            min_lr=args.lr_min,
            )
        if wants_sparse:
            scheduler_sparse = optim.lr_scheduler.ReduceLROnPlateau(
                optimizer_sparse, factor=args.decay_rate, patience=args.patience,
                min_lr=args.lr_min,
                )
    elif args.scheduler == 'constant':
        # No scheduler object for a constant schedule.
        # NOTE(review): train()/save_checkpoint() receive None in this case;
        # confirm they tolerate it.
        pass

    return scheduler, scheduler_sparse


def _wrap_parallel(model, args, device):
    """Wrap ``model`` for multi-GPU execution according to ``args.multi_gpu``.

    Returns ``None`` when ``model`` is ``None``, a ``DistributedDataParallel``
    wrapper for ``ddp`` (only when torch.distributed is initialized), a
    ``BalancedDataParallel``/``nn.DataParallel`` wrapper for ``dp``, and the
    unwrapped model otherwise.
    """
    if model is None:
        return None

    if args.multi_gpu == 'ddp' and torch.distributed.is_initialized():
        return DistributedDataParallel(model,
                                       device_ids=[args.local_rank],
                                       output_device=args.local_rank,
                                       broadcast_buffers=False,
                                       find_unused_parameters=True,
                                       )
    if args.multi_gpu == 'dp':
        if args.gpu0_bsz >= 0:
            return BalancedDataParallel(args.gpu0_bsz // args.batch_chunk,
                                        model, dim=1).to(device)
        return nn.DataParallel(model, dim=1).to(device)

    return model


def main():
    """Script entry point: set up, train and test the teacher (and student).

    Parses the command line, initialises the device / distributed backend,
    logging and wandb (rank 0 only), loads the corpus, builds the teacher
    model and, when ``args.student`` is set, a student model with the same
    configuration. It then runs ``train`` epoch by epoch until
    ``args.max_step`` is reached or a ``KeyboardInterrupt`` arrives, evaluates
    the best saved checkpoints on the test split, logs a summary, and exits
    with status 1 if ``benchmark`` reports that the targets were missed.
    """
    args = parse_args()
    args.original_alpha = args.alpha
    if args.local_rank == 0:
        print(json.dumps(vars(args), indent=4))
    if args.affinity != 'disabled':
        nproc_per_node = torch.cuda.device_count()
        affinity = utils.gpu_affinity.set_affinity(
            args.local_rank,
            nproc_per_node,
            args.affinity
        )
        print(f'{args.local_rank}: thread affinity: {affinity}')

    # Initialize device and distributed backend
    torch.cuda.set_device(args.local_rank)
    l2_promote()
    device = torch.device('cuda' if args.cuda else 'cpu')
    utils.distributed.init_distributed(args.cuda)

    args.work_dir = utils.exp_utils.build_work_dir_name(args.work_dir,
                                                        args.dataset,
                                                        args.append_dataset,
                                                        args.append_time,
                                                        )
    with utils.distributed.sync_workers() as rank:
        if rank == 0:
            create_exp_dir(args.work_dir,
                           scripts_to_save=['transformer_xl/pytorch/train_lot.py', 'transformer_xl/pytorch/mem_transformer.py'],
                           debug=args.debug)

    # Setup logging
    if args.log_all_ranks:
        log_file = f'train_log_rank_{utils.distributed.get_rank()}.log'
    else:
        log_file = args.txtlog_file
    dllog_file = args.dllog_file
    log_file = os.path.join(args.work_dir, log_file)
    dllog_file = os.path.join(args.work_dir, dllog_file)

    if args.debug:
        log_file = os.devnull
        dllog_file = os.devnull

    utils.exp_utils.setup_logging(log_all_ranks=args.log_all_ranks,
                                  filename=log_file,
                                  )
    utils.exp_utils.setup_dllogger(enabled=True, filename=dllog_file)

    if args.local_batch_size is not None:
        world_size = utils.distributed.get_world_size()
        args.batch_size = world_size * args.local_batch_size
        logging.info(f'--local_batch_size was set, adjusting global batch size'
                     f' to {args.batch_size} (local_batch_size * world_size)')
        if args.batch_size % args.batch_chunk != 0:
            raise RuntimeError('Batch size needs to be divisible by '
                               'batch chunk')

    logging.info(args)
    dllogger.log(step='PARAMETER', data=vars(args))

    dllogger.metadata('train_throughput', {'unit': 'tokens/s'})
    dllogger.metadata('train_elapsed', {'unit': 'min'})
    dllogger.metadata('valid_elapsed', {'unit': 'min'})
    dllogger.metadata('train_perplexity', {'unit': None})
    dllogger.metadata('valid_perplexity', {'unit': None})
    dllogger.metadata('train_loss', {'unit': None})
    dllogger.metadata('valid_loss', {'unit': None})

    logging.info(f'world size: {utils.distributed.get_world_size()}')

    if not args.no_env:
        log_env_info()

    register_ignoring_timeout_handler()

    # Only rank 0 talks to Weights & Biases; credentials come from the env.
    if args.local_rank == 0:
        wandb_username = os.environ.get('WANDB_USER_NAME')
        wandb_key = os.environ.get('WANDB_API_KEY')
        wandb.login(key=wandb_key)
        wandb.init(project='LoT_LM_TransformerXL_'+args.dataset, entity=wandb_username, name=args.exp_name)

    # Set the random seed manually for reproducibility.
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    ###########################################################################
    # Load data
    ###########################################################################
    corpus = get_lm_corpus(args.data, args.dataset, args.vocab)
    ntokens = len(corpus.vocab)
    vocab = corpus.vocab
    args.n_token = ntokens

    if args.mem_len == 0:
        eval_mem_len = 0
    else:
        eval_mem_len = args.mem_len + args.tgt_len - args.eval_tgt_len

    tr_iter = corpus.get_iterator('train', args.batch_size, args.tgt_len,
                                  device=device, ext_len=args.ext_len)
    va_iter = corpus.get_iterator('valid', args.eval_batch_size,
                                  args.eval_tgt_len, device=device,
                                  mem_len=eval_mem_len, ext_len=args.ext_len)
    te_iter = corpus.get_iterator('test', args.eval_batch_size,
                                  args.eval_tgt_len, device=device,
                                  mem_len=eval_mem_len, ext_len=args.ext_len)

    # adaptive softmax / embedding
    cutoffs, tie_projs = [], [False]
    if args.adaptive:
        assert args.dataset in ['wt103', 'lm1b']
        if args.dataset == 'wt103':
            cutoffs = [19997, 39997, 199997]
            tie_projs += [True] * len(cutoffs)
        elif args.dataset == 'lm1b':
            cutoffs = [59997, 99997, 639997]
            tie_projs += [False] * len(cutoffs)

    ###########################################################################
    # Build the model
    ###########################################################################
    model_config = {
        'n_token': ntokens,
        'n_layer': args.n_layer,
        'n_head': args.n_head,
        'd_model': args.d_model,
        'd_head': args.d_head,
        'd_inner': args.d_inner,
        'dropout': args.dropout,
        'dropatt': args.dropatt,
        'dtype': None,
        'tie_weight': args.tied,
        'd_embed': args.d_embed,
        'div_val': args.div_val,
        'tie_projs': tie_projs,
        'pre_lnorm': args.pre_lnorm,
        'tgt_len': args.tgt_len,
        'ext_len': args.ext_len,
        'mem_len': args.mem_len,
        'cutoffs': cutoffs,
        'same_length': args.same_length,
        'attn_type': args.attn_type,
        'clamp_len': args.clamp_len,
        'sample_softmax': args.sample_softmax,
        }

    teacher = MemTransformerLM(**model_config)
    teacher.apply(functools.partial(weights_init, args=args))
    # ensure embedding init is not overridden by out_layer in case of weight sharing
    teacher.word_emb.apply(functools.partial(weights_init, args=args))
    if args.student:
        student = MemTransformerLM(**model_config)
        student.apply(functools.partial(weights_init, args=args))
        student.word_emb.apply(functools.partial(weights_init, args=args))
    else:
        student = None

    args.n_all_param = sum([p.nelement() for p in teacher.parameters()])
    args.n_nonemb_param = sum([p.nelement() for p in teacher.layers.parameters()])

    # optimizer (one per model, built the same way for teacher and student)
    teacher_optimizer, teacher_optimizer_sparse = _build_optimizer(teacher, args)
    if student is not None:
        student_optimizer, student_optimizer_sparse = _build_optimizer(student, args)
    else:
        student_optimizer, student_optimizer_sparse = None, None

    if teacher_optimizer_sparse is not None or student_optimizer_sparse is not None:
        # The train() call below only receives the dense optimizers.
        logging.warning('sample_softmax > 0: sparse optimizers were created '
                        'but are not passed to train()')

    teacher = teacher.to(device)
    if student is not None:
        student = student.to(device)

    scaler = None
    if args.fp16:
        if args.amp == 'pytorch':
            scaler = torch.cuda.amp.GradScaler()
        elif args.amp == 'apex':
            # amp.initialize accepts lists of models/optimizers and returns
            # lists in that case, so both models are initialised in one call.
            if student is not None:
                amp_models, amp_optimizers = amp.initialize(
                    [teacher, student],
                    [teacher_optimizer, student_optimizer],
                    opt_level=args.apex_amp_opt_level,
                    )
                teacher, student = amp_models
                teacher_optimizer, student_optimizer = amp_optimizers
            else:
                teacher, teacher_optimizer = amp.initialize(
                    teacher,
                    teacher_optimizer,
                    opt_level=args.apex_amp_opt_level,
                    )

    # Parallel wrappers; para_student is None when there is no student.
    para_teacher = _wrap_parallel(teacher, args, device)
    para_student = _wrap_parallel(student, args, device)

    # scheduler
    # The sparse schedulers are built for parity with the sparse optimizers
    # but, like them, are not consumed by the train() call below.
    teacher_scheduler, _ = _build_scheduler(
        teacher_optimizer, teacher_optimizer_sparse, args)
    if student is not None:
        student_scheduler, _ = _build_scheduler(
            student_optimizer, student_optimizer_sparse, args)
    else:
        student_scheduler = None

    logging.info('=' * 100)
    for k, v in args.__dict__.items():
        logging.info('    - {} : {}'.format(k, v))
    logging.info('=' * 100)
    logging.info('#params = {}'.format(args.n_all_param))
    logging.info('#non emb params = {}'.format(args.n_nonemb_param))

    train_step = 0
    start_epoch = 1
    last_batch = 0
    last_iter = 0
    best_val_loss = None
    cur_loss = float('inf')
    args.train_step = train_step

    teacher_mems = [None for _ in range(args.batch_chunk)]
    if student is not None:
        student_mems = [None for _ in range(args.batch_chunk)]
    else:
        student_mems = None

    if args.restart:
        try:
            checkpoint = load_checkpoint(args.restart)
            teacher.load_state_dict(checkpoint['model_state'])
            teacher_optimizer.load_state_dict(checkpoint['optimizer_state'])
            if teacher_scheduler is not None:
                teacher_scheduler.load_state_dict(checkpoint['scheduler_state'])
            if args.fp16:
                if args.amp == 'pytorch':
                    scaler.load_state_dict(checkpoint['amp_state'])
                elif args.amp == 'apex':
                    amp.load_state_dict(checkpoint['amp_state'])
            utils.exp_utils.set_default_rng_states(
                checkpoint['rng_states'], device
            )

            # Restore this rank's memory for every batch chunk.
            teacher_mems = [
                checkpoint['memory'][i][utils.distributed.get_rank()]
                for i in range(args.batch_chunk)
            ]

            train_step = checkpoint['train_step']
            start_epoch = checkpoint['epoch']
            last_batch = checkpoint['batch']
            last_iter = checkpoint['last_iter']
            best_val_loss = checkpoint['best_val_loss']

            if train_step >= args.max_step:
                logging.info(f'Loaded checkpoint after {train_step} steps, but '
                             f'this run was scheduled for a total of '
                             f'{args.max_step} steps, exiting')
                sys.exit(1)

            teacher.apply(functools.partial(update_dropout, args=args))
            teacher.apply(functools.partial(update_dropatt, args=args))

            if student is not None:
                # Only one checkpoint path is given, so only the teacher is
                # restored from it.
                logging.warning(f'Restored the teacher from {args.restart}; '
                                f'the student keeps its fresh initialization')
        except FileNotFoundError:
            logging.info(f'Could not load checkpoint from {args.restart}, '
                         f'starting training from random init')

    meters = {}
    warmup = args.mem_len // args.tgt_len + 2
    meters['train_throughput'] = AverageMeter(warmup=warmup)
    ###########################################################################
    # Train
    ###########################################################################
    # Loop over epochs.
    # At any point you can hit Ctrl + C to break out of training early.

    utils.distributed.barrier()
    start_time = time.time()

    with TimeoutHandler() as timeout_handler:
        try:
            for epoch in itertools.count(start=start_epoch):
                if args.roll:
                    tr_iter.roll(seed=args.seed + epoch)
                train_step, best_val_loss, cur_loss = train(
                    tr_iter, va_iter, teacher, para_teacher, student, para_student, teacher_mems, student_mems,
                    model_config, teacher_optimizer, student_optimizer, teacher_scheduler,
                    student_scheduler, scaler, vocab, epoch, last_batch,
                    last_iter, train_step, best_val_loss, meters,
                    timeout_handler, device, args
                    )

                # Resume offsets only apply to the first (restored) epoch.
                last_batch = 0
                last_iter = 0

                if train_step == args.max_step:
                    logging.info('-' * 100)
                    logging.info('End of training')
                    break
        except KeyboardInterrupt:
            logging.info('-' * 100)
            logging.info('Exiting from training early')
    utils.distributed.barrier()
    elapsed = time.time() - start_time

    ###########################################################################
    # Test
    ###########################################################################
    summary = {}
    test_path = os.path.join(args.work_dir, 'teachercheckpoint_best.pt')
    if (
        not args.debug
        and not args.no_test
        and not args.no_eval
        and os.path.exists(test_path)
    ):
        # Load the best saved model.
        checkpoint = load_checkpoint(test_path, type='teacher')
        teacher.load_state_dict(checkpoint['model_state'])
        if student is not None:
            test_path = os.path.join(args.work_dir, 'studentcheckpoint_best.pt')
            checkpoint = load_checkpoint(test_path, type='student')
            student.load_state_dict(checkpoint['model_state'])

        # Run on test data.
        utils.distributed.barrier()
        test_start_time = time.time()

        teacher_test_loss = evaluate(te_iter, teacher, args)
        teacher_test_loss = utils.distributed.all_reduce_item(teacher_test_loss, 'mean')
        if student is not None:
            student_test_loss = evaluate(te_iter, student, args)
            student_test_loss = utils.distributed.all_reduce_item(student_test_loss, 'mean')

        utils.distributed.barrier()
        test_elapsed = time.time() - test_start_time

        logging.info('=' * 100)
        if args.dataset in ['enwik8', 'text8']:
            logging.info('| End of training | test time: {:5.2f}s | test loss {:5.2f} | test bpc {:9.5f}'.format(
                test_elapsed, teacher_test_loss, teacher_test_loss / math.log(2)))
        else:
            logging.info('| End of training | test time: {:5.2f}s | teacher test loss {:5.2f} | teacher test ppl {:9.3f}'.format(
                test_elapsed, teacher_test_loss, math.exp(teacher_test_loss)))
            if student is not None:
                logging.info('| End of training | student test loss {:5.2f} | student test ppl {:9.3f}'.format(
                    student_test_loss, math.exp(student_test_loss)))
        logging.info('=' * 100)

        if args.local_rank == 0:
            wandb.log({'teacher_test_ppl': math.exp(teacher_test_loss)}, step=train_step)
            if student is not None:
                wandb.log({'student_test_ppl': math.exp(student_test_loss)}, step=train_step)

        summary.update({
            'test_elapsed': test_elapsed,
            'teacher_test_loss': teacher_test_loss,
            })

        if args.dataset in ['enwik8', 'text8']:
            summary['teacher_test_bits_per_character'] = teacher_test_loss / math.log(2)
        else:
            summary['teacher_test_perplexity'] = math.exp(teacher_test_loss)
            if student is not None:
                summary['student_test_perplexity'] = math.exp(student_test_loss)

    logging.info(f'Training time: {(elapsed / 60):.2f} minutes')
    logging.info(f'Training throughput: {meters["train_throughput"].avg:.2f} tok/s')

    # Compare against None so that a loss of exactly 0.0 still counts.
    if best_val_loss is not None:
        best_val_perplexity = math.exp(best_val_loss)
    else:
        best_val_perplexity = None

    summary.update({
        'train_throughput': meters['train_throughput'].avg,
        'train_elapsed': elapsed / 60,
        'train_loss': cur_loss,
        'valid_loss': best_val_loss,
        'valid_perplexity': best_val_perplexity,
        })
    dllogger.log(step=tuple(), data=summary)

    passed = benchmark(
        target_perplexity=args.target_perplexity,
        test_perplexity=best_val_perplexity,
        target_throughput=args.target_throughput,
        test_throughput=meters['train_throughput'].avg
        )
    if not passed:
        sys.exit(1)


if __name__ == "__main__":
    # Turn off the JIT profiling executor and profiling mode. These are
    # private torch._C hooks; when one is missing the AttributeError is
    # swallowed and the remaining toggle(s) are skipped.
    try:
        for _jit_toggle in ('_jit_set_profiling_executor',
                            '_jit_set_profiling_mode'):
            getattr(torch._C, _jit_toggle)(False)
    except AttributeError:
        pass

    # torch.einsum must be registered as a half function before any model is
    # touched so that APEX AMP runs it in fp16; otherwise AMP falls back to
    # "promote" mode and the op executes in fp32. With
    # `--apex_amp_opt_level O2` this registration is unnecessary, yet it
    # remains valid.
    if 'apex' in sys.modules:
        amp.register_half_function(torch, 'einsum')

    main()
