# coding: utf-8
import argparse
import os
import sys
import time

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
try:
    from apex import amp
except:
    print('Failed to import apex. You can still train with --precision {float|double}.')

from mup.coord_check import get_coord_data, plot_coord_data
from mup import MuAdam, MuSGD, get_shapes, make_base_shapes, set_base_shapes

import tokenizer
import model as mdl


###############################################################################
# Training code
###############################################################################

# get_batch subdivides the source data into chunks of length args.bptt.
# If source is equal to the example output of the batchify function, with
# a bptt-limit of 2, we'd get the following two tensors for i = 0
# (input on the left, target on the right; the target is returned flattened to 1-D):
# ┌ a g m s ┐ ┌ b h n t ┐
# └ b h n t ┘ └ c i o u ┘
# Note that despite the name of the function, the subdivision of data is not
# done along the batch dimension (i.e. dimension 1), since that was handled
# by the batchify function. The chunks are along dimension 0, corresponding
# to the sequence-length dimension of the model input.

def get_batch(source, i, bptt):
    """Slice one (input, target) pair out of a batchified tensor.

    Rows ``i .. i+n-1`` of ``source`` form the input and rows ``i+1 .. i+n``
    (flattened to 1-D) form the target. ``n`` is ``bptt``, clipped so that the
    target never runs past the last row of ``source``.
    """
    remaining = len(source) - 1 - i
    n = bptt if bptt < remaining else remaining
    end = i + n
    inputs = source[i:end]
    targets = source[i + 1:end + 1].view(-1)
    return inputs, targets

def batchloader(train_data, bptt):
    """Yield successive ``(data, target)`` pairs covering ``train_data``.

    Walks dimension 0 of ``train_data`` in strides of ``bptt`` and delegates to
    ``get_batch``, so the final pair may be shorter than ``bptt`` rows.
    This is a generator: it can only be iterated once.
    """
    for i in range(0, train_data.size(0) - 1, bptt):
        yield get_batch(train_data, i, bptt)

def batchify(data, bsz, device):
    """Arrange a token stream into ``bsz`` columns and move it to ``device``.

    ``data`` is assumed to be a 1-D tensor of token ids. Trailing tokens that
    do not fill a complete column are dropped. The result has shape
    ``(len(data) // bsz, bsz)``, and column ``k`` holds the ``k``-th contiguous
    segment of ``data``.
    """
    rows = data.size(0) // bsz
    kept = data.narrow(0, 0, rows * bsz)
    columns = kept.view(bsz, -1).t().contiguous()
    return columns.to(device)
    
def setprec(t, precision):
    """Cast a tensor or module to the requested floating-point precision.

    Args:
        t: a tensor or ``nn.Module`` (anything with ``.float()``/``.double()``).
        precision: one of ``'half'``, ``'float'`` or ``'double'``.

    Returns:
        ``t`` unchanged for ``'half'``; otherwise the result of ``t.float()`` or
        ``t.double()``.

    Raises:
        ValueError: if ``precision`` is not one of the supported strings.
    """
    if precision == 'half':
        # Half precision is applied later by apex AMP (amp.initialize), so leave t as-is.
        return t
    if precision == 'float':
        return t.float()
    if precision == 'double':
        return t.double()
    # Report the argument actually passed in: the module-level `args` only exists
    # when this file runs as a script.
    raise ValueError(f'invalid precision string {precision}')

def coord_check(mup, lr, optimizer, batch_size, nsteps, nseeds, data_dir, args, plotdir='', legend=False, widths=None):
    """Run a coordinate check over a sweep of Transformer widths and plot it.

    Builds one lazily-constructed ``TransformerModel`` per width, in μP when
    ``mup`` is True and in standard parametrization otherwise. Each is trained
    for ``nsteps`` steps on the training split of ``data_dir`` via
    ``mup.coord_check.get_coord_data``. The statistics are plotted with
    ``plot_coord_data``, and the figure is saved into ``plotdir``.

    Args:
        mup: use μP (requires ``args.load_base_shapes``) instead of SP.
        lr: learning rate passed to ``get_coord_data``.
        optimizer: ``'sgd'``/``'adam'``, optionally prefixed with ``'mu'``; the
            prefix is stripped because μP vs SP is selected by ``mup``.
        batch_size: number of columns used when batchifying the training data.
        nsteps: training steps per model.
        nseeds: number of random seeds.
        data_dir: corpus directory passed to ``tokenizer.Corpus``.
        args: parsed command-line namespace (model hyper-parameters, device, ...).
        plotdir: directory for the output PNG.
        legend: forwarded to ``plot_coord_data``.
        widths: model widths (``d_model``) to sweep; defaults to ``2**2 .. 2**9``.

    Returns:
        Whatever ``plot_coord_data`` returns.

    Raises:
        ValueError: if ``mup`` is True but ``args.load_base_shapes`` is empty.
    """
    if mup and not args.load_base_shapes:
        raise ValueError('load_base_shapes needs to be nonempty')

    corpus = tokenizer.Corpus(data_dir)
    ntokens = len(corpus.dictionary)

    def gen(w, standparam=False):
        # get_coord_data takes zero-argument model constructors, so return a factory.
        def f():
            model = mdl.TransformerModel(args, ntokens, ninp=w, nhead=args.nhead, nhid=w*args.ffn_ratio, nlayers=args.nlayers, dropout=args.dropout,
                                         tied=args.tied, bias=args.bias, encoder_var=args.init_var,
                                         decoder_var=args.init_var, standparam=standparam).to(args.device)
            model = setprec(model, args.precision)
            # SP models use their own shapes as base shapes; μP models load them from file.
            set_base_shapes(model, None if standparam else args.load_base_shapes)
            return model
        return f

    optimizer = optimizer.replace('mu', '')
    if widths is None:
        # Wider sweeps (e.g. 2**np.arange(7, 14)) do not fit on a 40GB GPU.
        widths = 2**np.arange(2, 10)
    models = {w: gen(w, standparam=not mup) for w in widths}

    train_data = batchify(corpus.train, batch_size, device=args.device)
    # NOTE(review): batchloader() is a one-shot generator. If get_coord_data iterates it
    # once per (seed, width) run, successive runs see successive batches rather than
    # identical ones — confirm this is intended.
    df = get_coord_data(models, batchloader(train_data, args.bptt), mup=mup, lr=lr, optimizer=optimizer, flatten_output=True, nseeds=nseeds, nsteps=nsteps, lossfn='nll')

    prm = 'μP' if mup else 'SP'
    return plot_coord_data(df, legend=legend,
        save_to=os.path.join(plotdir, f'{prm.lower()}_trsfmr_{optimizer}_coord.png'),
        suptitle=f'{prm} Transformer {optimizer} lr={lr} nseeds={nseeds}',
        face_color='xkcd:light grey' if not mup else None)

def find_gpu_with_most_memory():
    """Return the position of the GPU with the most free memory.

    Runs ``nvidia-smi`` and parses its free-memory column. Prints a progress line
    and the chosen GPU, then returns the ``np.argmax`` position of the largest
    free-memory value (rows are taken in ``nvidia-smi`` output order).

    NOTE(review): nvidia-smi lists GPUs in PCI bus order, while CUDA's default
    ordering (and any CUDA_VISIBLE_DEVICES remapping) may differ. Confirm the
    returned position matches torch's ``cuda:N`` on multi-GPU machines.
    """
    import subprocess
    # https://discuss.pytorch.org/t/access-gpu-memory-usage-in-pytorch/3192/4
    print('finding GPU with most memory...')

    query = ["nvidia-smi", "--query-gpu=index,memory.free", "--format=csv,noheader"]
    rows = subprocess.check_output(query).decode().strip().split("\n")

    free_mib = []
    for row in rows:
        # Each row looks like "<index>, <free> MiB"; keep the number only.
        free_field = row.split(",")[1]
        free_mib.append(int(free_field.split()[0]))

    best = np.argmax(free_mib)
    best_free = np.max(free_mib)
    print("GPU {} has {} MB memory free".format(best, best_free))
    return best


if __name__ == '__main__':

    # Command-line interface.
    parser = argparse.ArgumentParser(description=
    '''
    PyTorch Wikitext-2 Transformer Language Model, with μP.

    To train a μP model, one needs to first specify the base shapes. To save base shapes info, run, for example,

        python main.py --d_model 256 --save_base_shapes width256.bsh

    To train using MuAdam, run

        python main.py --d_model 256 --load_base_shapes width256.bsh --cuda --optimizer muadam

    To perform coord check, run

        python main.py --load_base_shapes width256.bsh --optimizer sgd --lr 0.5 --cuda --coord_check

        python main.py --load_base_shapes width256.bsh --optimizer adam --lr 0.01 --cuda --coord_check

    If you don't specify a base shape file, then you are using standard parametrization

        python main.py --d_model 256 --cuda --optimizer muadam

    Note that models of different depths need separate `.bsh` files.
    ''', formatter_class=argparse.RawTextHelpFormatter)
    parser.add_argument('--data', type=str, default='./data/wikitext-103',
                        help='location of the data corpus')
    parser.add_argument('--max_data', type=int, default=300000000,
                        help='maximum number of tokens in the dataset')
    parser.add_argument('--bias', action='store_true',
                        help='use bias')
    parser.add_argument('--save_base_shapes', type=str, default='',
                        help='file location to save base shapes at')
    parser.add_argument('--load_base_shapes', type=str, default='',
                        help='file location to load base shapes from')
    parser.add_argument('--d_model', type=int, default=256,
                        help='width of the model')
    parser.add_argument('--ffn_ratio', type=int, default=1,
                        help='the ratio of d_ffn to d_model')
    parser.add_argument('--nlayers', type=int, default=2,
                        help='number of layers')
    parser.add_argument('--nhead', type=int, default=2,
                        help='the number of heads in the encoder/decoder of the transformer model')
    parser.add_argument('--lr', type=float, default=0.5,
                        help='initial learning rate')
    # Raw string: '\s' is not a valid escape sequence (SyntaxWarning on newer Pythons);
    # the help text itself is unchanged.
    parser.add_argument('--lr_rescale', action='store_true',
                        help=r'rescale lr with \sqrt N')
    parser.add_argument('--momentum', type=float, default=0,
                        help='momentum')
    parser.add_argument('--output_mult', type=float, default=1,
                        help='output is multiplied by sqrt(output_mult/d_model)')
    parser.add_argument('--input_mult', type=float, default=1,
                        help='input is multiplied by sqrt(input_mult*d_model)')
    parser.add_argument('--attn_mult', type=float, default=1,
                        help='attn is multiplied by sqrt(attn_mult)/head_dim')
    parser.add_argument('--optimizer', default='musgd', choices=['sgd', 'musgd', 'adam', 'muadam'])
    parser.add_argument('--init_var', type=float, default=1,
                        help='weights are initialized with variance init_var/ninp')
    parser.add_argument('--clip', type=float, default=0.25,
                        help='gradient clipping')
    parser.add_argument('--epochs', type=int, default=100,
                        help='upper epoch limit')
    parser.add_argument('--batch_size', type=int, default=32, metavar='N',
                        help='batch size')
    parser.add_argument('--bptt', type=int, default=35,
                        help='sequence length')
    parser.add_argument('--dropout', type=float, default=0.0,
                        help='dropout applied to layers (0 = no dropout)')
    parser.add_argument('--tied', action='store_true',
                        help='tie the word embedding and softmax weights')
    parser.add_argument('--seed', type=int, default=1111,
                        help='random seed')
    parser.add_argument('--cuda', action='store_true',
                        help='use CUDA')
    # Validate at parse time; setprec() rejects anything else anyway.
    parser.add_argument('--precision', type=str, default='float', choices=['float', 'double', 'half'],
                        help='float | double | half')
    parser.add_argument('--log_interval', type=int, default=200, metavar='N',
                        help='report interval')
    parser.add_argument('--save_dir', type=str, default='./results',
                        help='path to save the final model')
    parser.add_argument('--resume_dir', type=str, default='./results',
                        help='path to resume training')
    parser.add_argument('--log_dir', type=str, default='./results',
                        help='path to save logs')
    parser.add_argument('--coord_check', action='store_true',
                        help='test μ parametrization is correctly implemented by collecting statistics on coordinate distributions for a few steps of training.')
    parser.add_argument('--coord_check_nsteps', type=int, default=3,
                        help='Do coord check with this many steps.')
    parser.add_argument('--coord_check_nseeds', type=int, default=3,
                        help='number of seeds for testing correctness of μ parametrization')
    parser.add_argument('--num_ens', type=int, default=1,
                        help='number of random inits to ensemble over')
    parser.add_argument('--subdivide_epoch', type=int, default=100,
                        help='number of sub-epochs to divide each epoch into')
    parser.add_argument('--eval_point', action='store_true',
                        help='Evaluate on a held out validation point')

    args = parser.parse_args()
    
    print(args)
    # Name of the corpus directory; normpath makes "./data/wikitext-103/" work too.
    dataset_name = os.path.basename(os.path.normpath(args.data))

    # Run identifier: every checkpoint/log file written by this script is prefixed with it.
    save_str = f'mutrans_{dataset_name}_E={args.num_ens}_d={args.d_model}_nheads={args.nhead}_nlayers={args.nlayers}_ffnrat={args.ffn_ratio}_lr={args.lr}_om={args.output_mult}_drp={args.dropout}_opt={args.optimizer}_bs={args.batch_size}_bptt={args.bptt}'
    if args.load_base_shapes:
        import re
        # The base width is read from the base-shapes file name. Both "width=128_..." and
        # "width256.bsh" (the form used in --help) are accepted. If no width is present,
        # the file's stem is used so the run still gets a distinct name.
        width_match = re.search(r'width=?(\d+)', args.load_base_shapes)
        if width_match:
            base_dmodel = int(width_match.group(1))
        else:
            base_dmodel = os.path.splitext(os.path.basename(args.load_base_shapes))[0]
        save_str += f'_bsh={base_dmodel}'

    if args.max_data != 300000000:
        save_str += f'_maxdata={args.max_data}'
    if args.lr_rescale:
        save_str += '_lrrescale'

    # Seed the global RNG, then pick the compute device: CPU unless CUDA is
    # available *and* --cuda was given, in which case use the GPU reporting the
    # most free memory.
    torch.manual_seed(args.seed)
    dev = "cpu"
    if not torch.cuda.is_available():
        print("NO CUDA :(")
    else:
        print(f"WE HAVE CUDA :) with {torch.cuda.device_count()} devices")
        if args.cuda:
            max_gpu = find_gpu_with_most_memory()
            dev = f"cuda:{max_gpu}"
            torch.cuda.empty_cache()
        else:
            print("WARNING: You have a CUDA device, so you should probably run with --cuda")
    # The device is also stored on args so coord_check() can reach it.
    device = args.device = torch.device(dev)
    E = args.num_ens  # ensemble size
    print(device)
    
    # Load data
    corpus = tokenizer.Corpus(args.data)

    # Starting from sequential data, batchify arranges the dataset into columns.
    # For instance, with the alphabet as the sequence and batch size 4, we'd get
    # ┌ a g m s ┐
    # │ b h n t │
    # │ c i o u │
    # │ d j p v │
    # │ e k q w │
    # └ f l r x ┘.
    # These columns are treated as independent by the model, which means that the
    # dependence of e.g. 'g' on 'f' cannot be learned, but allows more efficient
    # batch processing.

    eval_batch_size = 10
    train_data = batchify(corpus.train[:args.max_data], args.batch_size, device)
    val_data = batchify(corpus.valid, eval_batch_size, device)
    test_data = batchify(corpus.test, eval_batch_size, device)
    ntokens = len(corpus.dictionary)

    # For WikiText-103, one "epoch" of this script covers 1/subdivide_epoch of the
    # training rows; other corpora use the whole training split per epoch. The corpus
    # is recognised by its directory name, so "data/wikitext-103",
    # "./data/wikitext-103/" and absolute paths are all treated alike.
    if os.path.basename(os.path.normpath(args.data)) == 'wikitext-103':
        tokens_in_epoch = (train_data.size(0) - 1) // args.subdivide_epoch
    else:
        tokens_in_epoch = train_data.size(0) - 1

    def evaluate(models, data_source):
        """Compute per-member and ensemble loss on ``data_source``.

        Every one of the first ``E`` models is run over ``data_source`` in chunks
        of ``args.bptt`` rows (see ``get_batch``), in eval mode with gradients
        disabled. The ensemble output is the equal-weight (``1/E``) average of the
        members' outputs, scored with the same ``criterion``.

        Args:
            models: sequence holding at least ``E`` models.
            data_source: batchified token tensor (output of ``batchify``).

        Returns:
            ``(total_losses, ens_loss)``: a list with each member's loss and the
            ensemble's loss. Each is a chunk-length-weighted average normalised by
            ``len(data_source) - 1``.
        """
        # Turn on evaluation mode which disables dropout.
        for e in range(E):
            models[e].eval()

        total_losses = [0.0 for _ in range(E)]
        ntokens = len(corpus.dictionary)
        ens_loss = 0.0
        denom = len(data_source) - 1
        with torch.no_grad():
            for i in range(0, data_source.size(0) - 1, args.bptt):
                data, targets = get_batch(data_source, i, args.bptt)
                # One output row per target token, so the accumulator is sized from the
                # targets rather than from an extra forward pass.
                ens_out = torch.zeros(targets.numel(), ntokens, device=device)
                for e in range(E):
                    output = models[e](data)
                    output = output.view(-1, ntokens)
                    ens_out += 1.0/E * output
                    total_losses[e] += len(data) * criterion(output, targets).item() / denom
                ens_loss += len(data) * criterion(ens_out, targets).item() / denom
                del ens_out
        return total_losses, ens_loss

    def eval_on_point(models, data_source):
        """Record each model's output at the last target token of every chunk.

        For each ``args.bptt`` chunk of ``data_source``, takes ``output[-1, token]``,
        where ``token`` is the chunk's final target id. This is done for every
        ensemble member and for the ``1/E``-weighted ensemble average.

        Args:
            models: sequence holding at least ``E`` models.
            data_source: batchified token tensor (output of ``batchify``).

        Returns:
            ``(pt_estimates, ens_pt_estimates)``: CPU tensors of shape
            ``(E, n_slots)`` and ``(n_slots,)``, with
            ``n_slots = data_source.size(0) // args.bptt + 1``. Slot ``k`` holds
            the value for the chunk starting at row ``k * args.bptt``; unused
            slots stay 0.
        """
        for e in range(E):
            models[e].eval()
        ntokens = len(corpus.dictionary)
        # Size the buffers from the tensor actually being evaluated (not val_data),
        # so the function is also correct for other splits.
        n_slots = data_source.size(0) // args.bptt + 1
        pt_estimates = torch.zeros((E, n_slots))
        ens_pt_estimates = torch.zeros(n_slots)
        with torch.no_grad():
            for i in range(0, data_source.size(0) - 1, args.bptt):
                data, targets = get_batch(data_source, i, args.bptt)
                ens_out = torch.zeros(targets.numel(), ntokens, device=device)
                token_idx = targets[-1].item()
                slot = i // args.bptt
                for e in range(E):
                    output = models[e](data)
                    output = output.view(-1, ntokens)
                    ens_out += 1.0/E * output
                    pt_estimates[e, slot] = output[-1, token_idx].item()

                ens_pt_estimates[slot] = ens_out[-1, token_idx].item()

                del ens_out
        return pt_estimates, ens_pt_estimates

    def train(models, optimizers, epoch, tokens_in_epoch=tokens_in_epoch):
        """Train every ensemble member for one (sub-)epoch.

        Each call consumes ``tokens_in_epoch`` rows of ``train_data`` in chunks of
        ``args.bptt`` rows. Consecutive epochs use consecutive slices of
        ``train_data``, wrapping back to the start once every slice has been used.
        Every member takes an optimizer step on the same batch. The equal-weight
        average of the members' (detached) outputs is used only to report an
        ensemble loss.

        Args:
            models: sequence holding at least ``E`` models.
            optimizers: matching sequence of optimizers.
            epoch: 1-based epoch number.
            tokens_in_epoch: rows of ``train_data`` consumed per epoch.

        Returns:
            ``(epoch_losses, first_losses, ens_loss)``:
            - ``epoch_losses``: each member's chunk-length-weighted training loss.
            - ``first_losses``: each member's mean loss over the first logging
              interval of this epoch (``None`` if no log line was printed).
            - ``ens_loss``: the ensemble's chunk-length-weighted training loss.

        Raises ``SystemExit(0)``, ending the run, if any member's loss is NaN.
        """
        # train_data holds n_slices disjoint slices of tokens_in_epoch rows. Epoch k
        # trains on slice (k - 1) mod n_slices instead of indexing past the end of the
        # data (which previously fed empty batches from the second pass onward).
        n_slices = max(1, (train_data.size(0) - 1) // max(tokens_in_epoch, 1))
        starting_index = tokens_in_epoch * ((epoch - 1) % n_slices)

        for e in range(E):
            # Turn on training mode which enables dropout.
            models[e].train()

        total_losses = [0. for _ in range(E)]
        epoch_losses = [0. for _ in range(E)]
        first_losses = [None for _ in range(E)]
        ens_loss = 0.0
        start_time = time.time()
        ntokens = len(corpus.dictionary)

        for batch, i in enumerate(range(starting_index, starting_index+tokens_in_epoch, args.bptt)):
            data, targets = get_batch(train_data, i, args.bptt)
            # Running 1/E-weighted average of the members' outputs, one row per target token.
            ens_out = torch.zeros(targets.numel(), ntokens, device=device)

            for e in range(E):
                optimizer = optimizers[e]
                model = models[e]
                optimizer.zero_grad()

                output = model(data)
                output = output.view(-1, ntokens)
                # Detached: the ensemble average is only used to report a loss and must
                # not keep every member's autograd graph alive for the whole batch.
                ens_out += 1/E * output.detach()

                loss = criterion(output, targets)
                if torch.isnan(loss):
                    print(f'NaN loss for model {e} at epoch {epoch}, batch {batch}; exiting.')
                    sys.exit(0)
                if args.precision == 'half':
                    with amp.scale_loss(loss, optimizer) as scaled_loss:
                        scaled_loss.backward()
                else:
                    loss.backward()

                if args.clip > 0:
                    # Gradient-norm clipping guards against exploding gradients.
                    if args.precision == 'half':
                        torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.clip)
                    else:
                        torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)

                optimizer.step()

                loss_value = loss.item()
                total_losses[e] += loss_value
                epoch_losses[e] += len(data) * loss_value / tokens_in_epoch

            ens_loss += len(data) * criterion(ens_out, targets).item() / tokens_in_epoch
            del ens_out

            if batch % args.log_interval == 0 and batch > 0:
                # Measure once per interval: every batch steps the whole ensemble, so
                # all members report the same wall-clock time per batch.
                elapsed = time.time() - start_time
                for e in range(E):
                    cur_loss = total_losses[e] / args.log_interval
                    print('e = {:3d} | epoch {:3d} | {:5d}/{:5d} batches | lr {:02.5f} | ms/batch {:5.2f} | '
                            'loss {:5.2f} | ppl {:8.2f}'.format(e,
                        epoch, batch, tokens_in_epoch // args.bptt, lr,
                        elapsed * 1000 / args.log_interval, cur_loss, np.exp(cur_loss)))
                    total_losses[e] = 0.0
                    # Track each member's own first logged loss.
                    if first_losses[e] is None:
                        first_losses[e] = cur_loss
                start_time = time.time()

        return epoch_losses, first_losses, ens_loss

    
    # Base learning rate, optionally rescaled by 64/sqrt(d_model).
    lr = args.lr
    if args.lr_rescale:
        lr = args.lr * 64.0 / np.sqrt(args.d_model)
        factor = 64.0/np.sqrt(args.d_model)
        print(f'lr rescaled by {factor}')

    # --coord_check: produce μP and SP coordinate-check plots, then stop.
    if args.coord_check:
        print('testing parametrization')
        plotdir = 'coord_checks'
        os.makedirs(plotdir, exist_ok=True)
        for use_mup in (True, False):
            coord_check(mup=use_mup, lr=lr, optimizer=args.optimizer, batch_size=args.batch_size,
                        nsteps=args.coord_check_nsteps, nseeds=args.coord_check_nseeds,
                        data_dir=args.data, args=args, plotdir=plotdir, legend=False)
        sys.exit()

    # Build the ensemble. All members, and the width-doubled model used to derive
    # base shapes, share the CLI hyper-parameters; only the width differs.
    def _new_transformer(width):
        """Instantiate a TransformerModel of width ``width`` from the CLI settings.

        The feed-forward width is ``width * args.ffn_ratio``. ``standparam`` is
        True exactly when no base-shapes file was given.
        """
        return mdl.TransformerModel(args, ntokens, ninp=width, nhead=args.nhead,
                                    nhid=width*args.ffn_ratio, nlayers=args.nlayers,
                                    dropout=args.dropout, tied=args.tied, bias=args.bias,
                                    encoder_var=args.init_var, decoder_var=args.init_var,
                                    standparam=args.load_base_shapes=='')

    models = []
    for e in range(E):
        # NOTE(review): each member is re-seeded with its own index, so --seed has no
        # effect on model initialisation — confirm this is intended.
        torch.manual_seed(e)
        models.append(_new_transformer(args.d_model))

    model = models[0]
    if args.save_base_shapes:
        print(f'saving base shapes at {args.save_base_shapes}')
        base_shapes = get_shapes(model)
        # Only the scaled dimension(s) need to differ, so double the width.
        delta_shapes = get_shapes(_new_transformer(args.d_model*2))
        make_base_shapes(base_shapes, delta_shapes, savefile=args.save_base_shapes)
        print('done and exit')
        sys.exit()

    if args.load_base_shapes:
        print(f'loading base shapes from {args.load_base_shapes}')
        shape_source = args.load_base_shapes
    else:
        print(f'using own shapes')
        shape_source = None
    for member in models:
        set_base_shapes(member, shape_source)
    print('done')

    # Move to the selected device and cast to the requested precision.
    for e, member in enumerate(models):
        models[e] = setprec(member.to(device), args.precision)

    # NOTE(review): coord_check() passes lossfn='nll' to get_coord_data, while training
    # uses CrossEntropyLoss. The two coincide only if the model already returns
    # log-probabilities (log_softmax is idempotent) — confirm against model.py.
    criterion = nn.CrossEntropyLoss()

    if args.save_dir is not None:
        os.makedirs(args.save_dir, exist_ok=True)

    # Loop over epochs.
    best_val_loss = float('inf')

    if args.optimizer == 'sgd':
        optimizers = [optim.SGD(models[e].parameters(), lr=lr, momentum=args.momentum) for e in range(E)]
    elif args.optimizer == 'musgd':
        optimizers = [MuSGD(models[e].parameters(), lr=lr, momentum=args.momentum) for e in range(E)]
    elif args.optimizer == 'adam':
        optimizers = [optim.Adam(models[e].parameters(), lr=lr) for e in range(E)]
    elif args.optimizer == 'muadam':
        optimizers = [MuAdam(models[e].parameters(), lr=lr) for e in range(E)]
    else:
        raise ValueError(f'unknown optimizer {args.optimizer!r}')

    # Half precision via apex AMP. amp.initialize is documented to be called once,
    # taking lists when there are several models/optimizers.
    if args.precision == 'half':
        models, optimizers = amp.initialize(
            models,
            optimizers,
            opt_level='O1',
            min_loss_scale=0.0001,
            verbosity=0
            )

    # Per-member training logs and optional point-estimate logs (restored on resume).
    logs = [[] for _ in range(E)]
    ens_log = []
    point_log = [[] for _ in range(E)]
    ens_point_log = []
    start_epoch = 0
    if args.resume_dir and os.path.exists(os.path.join(args.resume_dir, save_str+'_model_last.pt')):
        # The GPU is picked by free memory at startup, so a checkpoint may come from a
        # different device index; map its tensors explicitly.
        checkpoint_log = torch.load(os.path.join(args.resume_dir, save_str+'_log_last.pt'), map_location='cpu')
        checkpoint_model = torch.load(os.path.join(args.resume_dir, save_str+'_model_last.pt'), map_location=device)
        for e in range(E):
            models[e].load_state_dict(checkpoint_model['models'][e])
            optimizers[e].load_state_dict(checkpoint_model['optimizers'][e])
            logs[e] = checkpoint_log['logs'][e]
        ens_log = checkpoint_log['ens_log']

        # Restore point logs too; otherwise '_point_last.pt' would later be overwritten
        # with only the post-resume entries.
        point_path = os.path.join(args.resume_dir, save_str+'_point_last.pt')
        if args.eval_point and os.path.exists(point_path):
            checkpoint_point = torch.load(point_path, map_location='cpu')
            point_log = checkpoint_point['point_logs']
            ens_point_log = checkpoint_point['ens_point_log']

        if args.precision == 'half':
            amp_state = checkpoint_model['amp']
            # A checkpoint written on an improving epoch may hold the AMP state
            # wrapped in a 1-tuple.
            if isinstance(amp_state, tuple):
                amp_state = amp_state[0]
            amp.load_state_dict(amp_state)
        start_epoch = checkpoint_log['epoch']
        # The saved log has no 'ens_loss' key (train() recomputes it every epoch).
        # Restore the best validation loss instead, so the resumed run only replaces
        # '_model_best.pt' with a genuinely better model. NaN entries never counted as
        # best in the training loop, so they are skipped here as well.
        best_val_loss = min((entry['val_ens_loss'] for entry in ens_log
                             if not np.isnan(entry['val_ens_loss'])),
                            default=float('inf'))
        

    def _save_checkpoint(obj, suffix):
        """Serialize ``obj`` with ``torch.save`` to ``<args.save_dir>/<save_str><suffix>``."""
        with open(os.path.join(args.save_dir, save_str + suffix), 'wb') as f:
            torch.save(obj, f)

    # At any point you can hit Ctrl + C to break out of training early.
    try:
        for epoch in range(start_epoch+1, args.epochs+1):
            epoch_start_time = time.time()

            # NOTE(review): validation runs *before* this epoch's training. The losses
            # logged for `epoch` (and used to pick '_model_best.pt') therefore describe
            # the weights from the end of the previous epoch, while the checkpoints
            # saved below hold the weights after this epoch's training.
            val_losses, val_ens_loss = evaluate(models, val_data)
            val_losses_mean = 0.0
            for e in range(E):
                val_losses_mean += 1/E * val_losses[e]
            if args.eval_point:
                pt_estimates, ens_point_estimates = eval_on_point(models, val_data)
            # train_losses and first_losses hold one entry per ensemble member.
            train_losses, first_losses, ens_loss = train(models, optimizers, epoch)
            print('-' * 89)
            print('| end of epoch {:3d} | time: {:5.2f}s | raw valid loss {:5.2f} | '
                    'raw valid ppl {:8.2f} | ens valid loss {:5.2f} | '
                    'ens valid ppl {:8.2f}'.format(epoch, (time.time() - epoch_start_time),
                                            val_losses_mean, np.exp(val_losses_mean), val_ens_loss, np.exp(val_ens_loss)))
            print('-' * 89)
            sys.stdout.flush()
            for e in range(E):
                logs[e].append(dict(
                    epoch=epoch,
                    train_loss=train_losses[e],
                    val_loss=val_losses[e],
                    first_loss=first_losses[e]
                ))
                if args.eval_point:
                    point_log[e].append(dict(
                        epoch=epoch,
                        pt_est=pt_estimates[e]
                    ))

            ens_log.append(dict(ens_loss=ens_loss, val_ens_loss=val_ens_loss))
            if args.eval_point:
                ens_point_log.append(dict(epoch=epoch, pt_est=ens_point_estimates))

            if args.save_dir is not None:
                checkpoint_model = {
                    'models': [models[e].state_dict() for e in range(E)],
                    'optimizers': [optimizers[e].state_dict() for e in range(E)],
                    'epoch': epoch
                }
                checkpoint_log = {
                    'epoch': epoch,
                    'logs': [logs[e] for e in range(E)],
                    'ens_log': ens_log
                }
                if args.eval_point:
                    checkpoint_point = {
                        'epoch': epoch,
                        'point_logs': [point_log[e] for e in range(E)],
                        'ens_point_log': ens_point_log
                    }
                amp_state = amp.state_dict() if args.precision == 'half' else None

                # Save the model if the validation loss is the best we've seen so far.
                if val_ens_loss < best_val_loss:
                    best_checkpoint = dict(checkpoint_model)
                    if amp_state is not None:
                        # The best checkpoint keeps the AMP state wrapped in a 1-tuple,
                        # the layout the post-training loader unwraps with [0].
                        best_checkpoint['amp'] = (amp_state,)
                    _save_checkpoint(best_checkpoint, '_model_best.pt')
                    _save_checkpoint(checkpoint_log, '_log_best.pt')
                    if args.eval_point:
                        _save_checkpoint(checkpoint_point, '_point_best.pt')
                    best_val_loss = val_ens_loss

                # The "last" checkpoint always stores the AMP state as a plain dict,
                # which is what the resume path passes to amp.load_state_dict.
                if amp_state is not None:
                    checkpoint_model['amp'] = amp_state
                _save_checkpoint(checkpoint_model, '_model_last.pt')
                _save_checkpoint(checkpoint_log, '_log_last.pt')
                if args.eval_point:
                    _save_checkpoint(checkpoint_point, '_point_last.pt')

    except KeyboardInterrupt:
        print('-' * 89)
        print('Exiting from training early')

    # Load the best saved model (if one was written) and report test loss.
    if args.save_dir is not None:
        best_path = os.path.join(args.save_dir, save_str+'_model_best.pt')
        if os.path.exists(best_path):
            with open(best_path, 'rb') as f:
                # Keep tensors on the device chosen for this run.
                checkpoint_model = torch.load(f, map_location=device)
            for e in range(E):
                models[e].load_state_dict(checkpoint_model['models'][e])
                optimizers[e].load_state_dict(checkpoint_model['optimizers'][e])
            if args.precision == 'half':
                amp_state = checkpoint_model['amp']
                # The best checkpoint stores the AMP state wrapped in a 1-tuple;
                # accept a bare dict as well.
                if isinstance(amp_state, tuple):
                    amp_state = amp_state[0]
                amp.load_state_dict(amp_state)
        else:
            # E.g. training was interrupted before the first epoch finished, or the
            # validation loss was NaN every epoch.
            print(f'WARNING: {best_path} not found; evaluating the current weights instead.')
        # Run on test data.
        test_losses, ens_test_loss = evaluate(models, test_data)
        print('=' * 89)
        print('| End of training | ens test loss {:5.2f} | ens test ppl {:8.2f}'.format(
            ens_test_loss, np.exp(ens_test_loss)))
        print('=' * 89)

    log_dir = os.path.expanduser(args.log_dir)
    os.makedirs(log_dir, exist_ok=True)
    log_path = os.path.join(log_dir, 'logs.tsv')
    # `logs` is a list of per-member lists of dicts. Flatten it to one row per
    # (member, epoch) so the TSV has real columns instead of dict-valued cells.
    logdf = pd.DataFrame([dict(member=e, **row)
                          for e, member_logs in enumerate(logs)
                          for row in member_logs])
    print(log_path)
    with open(log_path, 'w') as f:
        f.write(logdf.to_csv(sep='\t', float_format='%.4f'))
