import argparse
import time
import math
import numpy as np
import torch
import torch.nn as nn
import os
import data
import model
import pickle
from utils import batchify, get_batch, repackage_hidden
from utils import getOptimizer
# from adabound import AdaBound
# from AdaBelief import AdaBelief
# from yogi import Yogi
# from MSVAG import MSVAG
# from RAdam import RAdam
# from AdamW import AdamW
# from fromage import Fromage

# Command-line interface of the training script. Every option below ends up
# as an attribute on the parsed namespace (dashes in flag names become
# underscores, e.g. `--log-interval` -> `log_interval`).
parser = argparse.ArgumentParser(description='PyTorch PennTreeBank RNN/LSTM Language Model')

# --- Data and model architecture -------------------------------------------
parser.add_argument('--data', type=str, default='./data/penn/',
                    help='location of the data corpus')
parser.add_argument('--model', type=str, default='LSTM',
                    help='type of recurrent net (LSTM, QRNN, GRU)')
parser.add_argument('--emsize', type=int, default=400,
                    help='size of word embeddings')
parser.add_argument('--nhid', type=int, default=1150,
                    help='number of hidden units per layer')
parser.add_argument('--nlayers', type=int, default=3,
                    help='number of layers')

# --- Core training schedule -------------------------------------------------
parser.add_argument('--lr', type=float, default=30,
                    help='initial learning rate')
parser.add_argument('--clip', type=float, default=0.25,
                    help='gradient clipping')
parser.add_argument('--epochs', type=int, default=200,
                    help='upper epoch limit')
parser.add_argument('--batch_size', type=int, default=20, metavar='N',
                    help='batch size')
parser.add_argument('--bptt', type=int, default=70,
                    help='sequence length')

# --- Dropout variants -------------------------------------------------------
parser.add_argument('--dropout', type=float, default=0.4,
                    help='dropout applied to layers (0 = no dropout)')
parser.add_argument('--dropouth', type=float, default=0.3,
                    help='dropout for rnn layers (0 = no dropout)')
parser.add_argument('--dropouti', type=float, default=0.65,
                    help='dropout for input embedding layers (0 = no dropout)')
parser.add_argument('--dropoute', type=float, default=0.1,
                    help='dropout to remove words from embedding layer (0 = no dropout)')
parser.add_argument('--wdrop', type=float, default=0.5,
                    help='amount of weight dropout to apply to the RNN hidden to hidden matrix')

# --- Reproducibility, device, logging ---------------------------------------
parser.add_argument('--seed', type=int, default=1111,
                    help='random seed')
# The help text used to be a copy-paste of --seed. The value is used by the
# training loop as the look-back window of the non-monotone ASGD trigger.
parser.add_argument('--nonmono', type=int, default=5,
                    help='non-monotone interval: look-back window (in epochs) of the '
                         'validation-loss check that triggers the switch to ASGD')
parser.add_argument('--cuda', action='store_true',
                    help='use CUDA')
parser.add_argument('--log-interval', type=int, default=200, metavar='N',
                    help='report interval')

# Default checkpoint name: the current timestamp with the decimal point
# removed, so separate launches do not overwrite each other's files.
randomhash = ''.join(str(time.time()).split('.'))
parser.add_argument('--save', type=str,  default=randomhash+'.pt',
                    help='path to save the final model')

# --- Regularization ---------------------------------------------------------
parser.add_argument('--rnnalpha', type=float, default=2,
                    help='alpha L2 regularization on RNN activation (alpha = 0 means no regularization)')
parser.add_argument('--rnnbeta', type=float, default=1,
                    help='beta slowness regularization applied on RNN activiation (beta = 0 means no regularization)')
parser.add_argument('--weight_decay', type=float, default=1.2e-6,
                    help='weight decay applied to all weights')

# --- Resume / schedule / bookkeeping ----------------------------------------
parser.add_argument('--resume', type=str,  default='',
                    help='path of model to resume')
parser.add_argument('--eps_sqrt', type=float,  default=1e-8)
parser.add_argument('--when', nargs="+", type=int, default=[100,150],
                    help='When (which epochs) to divide the learning rate by 10 - accepts multiple')
parser.add_argument('--run', type=int, default=0,
                    help='Number of runs')

# --- Optimizer hyper-parameters (consumed by utils.getOptimizer) ------------
parser.add_argument('--momentum', default=0, type=float, metavar='M',help='momentum')
parser.add_argument('--beta1', default=0.9, type=float, help='Adam coefficients beta_1')
parser.add_argument('--beta2', default=0.999, type=float, help='Adam coefficients beta_2')
parser.add_argument('--eps', default=1e-8, type=float, help='eps for var adam')
parser.add_argument('--lr-gamma', default=0.1, type=float, help='learning rate')
parser.add_argument('--final_lr', default=0.1, type=float,help='final learning rate of AdaBound')
parser.add_argument('--gamma', default=1e-3, type=float,help='convergence speed term of AdaBound')
parser.add_argument('--beta', default=1.0, type=float,help='beta (default: 1.0)')
parser.add_argument('--alpha', default=1.0, type=float,help='alpha (default: 1.0)')
parser.add_argument('--damp', default=1e-2, type=float,help='damp (default: 1e-2)')
parser.add_argument('--damp1', default=1e-2, type=float,help='damp (default: 1e-2)')
parser.add_argument('--damp2', default=1e-2, type=float,help='damp (default: 1e-2)')
parser.add_argument('--period', default=1, type=int,help='period (default: 1)')
parser.add_argument('--hist_length', default=10, type=int,help='hist-length (default: 10)')
parser.add_argument('--pullback_momentum',default='reset',type=str,
                    help='pullback_momentum for lookahead(default: reset)')
parser.add_argument('--optim', default='sgdm', type=str, help='optimizer',
                        choices=['sgdm', 'adam', 'adadelta','adamw', 'adabelief','adabound','amsgrad','lookahead',
                                 'adasam','oaar','aenr','padasam','cr','pca'])

args = parser.parse_args()

# Encode the main hyper-parameters into the checkpoint file name so that
# different configurations never share a file.
args.save += (
    '-niter-{}-optimizer-{}-nlayers{}-lr{}-clip-{}-eps{}-epsqrt{}'
    '-betas-{}-{}-run{}-wdecay{}-when-{}'
).format(
    args.epochs, args.optim, args.nlayers, args.lr, args.clip, args.eps,
    args.eps_sqrt, args.beta1, args.beta2, args.run, args.weight_decay, args.when,
)

# Encoder/decoder weight tying is always enabled for this script.
args.tied = True

# Per-epoch metric history; the training loop pickles it to disk.
results = {
    metric: []
    for metric in ('trainloss', 'trainppl', 'valloss', 'valppl', 'testppl', 'testloss')
}

# Seed every random number generator in use for reproducibility.
np.random.seed(args.seed)
torch.manual_seed(args.seed)
if torch.cuda.is_available():
    if args.cuda:
        torch.cuda.manual_seed(args.seed)
    else:
        print("WARNING: You have a CUDA device, so you should probably run with --cuda")

###############################################################################
# Load data
###############################################################################

def model_save(fn):
    """Serialize the global model, criterion and optimizer to the file `fn`.

    The three objects are written as a single list with `torch.save`, so
    `model_load` can restore them together.
    """
    with open(fn, 'wb') as checkpoint_file:
        checkpoint = [model, criterion, optimizer]
        torch.save(checkpoint, checkpoint_file)

def model_load(fn):
    """Restore the global model, criterion and optimizer from the file `fn`.

    Expects a file holding the three-element sequence written by
    `model_save` and rebinds the module-level names to the loaded objects.
    """
    global model, criterion, optimizer
    # NOTE(review): torch.load unpickles whole Python objects here, so only
    # checkpoints from a trusted source should be passed in.
    with open(fn, 'rb') as checkpoint_file:
        restored = torch.load(checkpoint_file)
    model, criterion, optimizer = restored

import os
import hashlib
# The processed corpus is cached in the working directory under a name
# derived from the MD5 digest of the data path, so each dataset location gets
# its own cache file.
fn = 'corpus.{}.data'.format(hashlib.md5(args.data.encode()).hexdigest())
if not os.path.exists(fn):
    print('Producing dataset...')
    corpus = data.Corpus(args.data)
    torch.save(corpus, fn)
else:
    print('Loading cached dataset...')
    corpus = torch.load(fn)

# Validation and test use fixed batch sizes that do not depend on
# --batch_size.
eval_batch_size = 10
test_batch_size = 1
train_data = batchify(corpus.train, args.batch_size, args)
val_data = batchify(corpus.valid, eval_batch_size, args)
test_data = batchify(corpus.test, test_batch_size, args)

###############################################################################
# Build the model
###############################################################################

from splitcross import SplitCrossEntropyLoss
# `criterion` stays None unless a resumed checkpoint provides one.
criterion = None

ntokens = len(corpus.dictionary)
# NOTE: from here on the name `model` refers to the network instance and no
# longer to the imported `model` module.
model = model.RNNModel(args.model, ntokens, args.emsize, args.nhid, args.nlayers,
                       args.dropout, args.dropouth, args.dropouti, args.dropoute,
                       args.wdrop, args.tied)

# Optionally resume from a checkpoint. model_load() rebinds the global
# model / criterion / optimizer to the objects stored in the file.
if args.resume:
    print('Resuming model ...')
    model_load(args.resume)
    optimizer.param_groups[0]['lr'] = args.lr
    # Re-apply the dropout rates given on the command line to the loaded
    # model. The last target used to be `args.dropoute`, a self-assignment
    # that left the model's embedding dropout at its checkpointed value.
    model.dropouti, model.dropouth, model.dropout, model.dropoute = (
        args.dropouti, args.dropouth, args.dropout, args.dropoute)
    if args.wdrop:
        from weight_drop import WeightDrop
        for rnn in model.rnns:
            if isinstance(rnn, WeightDrop):
                rnn.dropout = args.wdrop
            elif rnn.zoneout > 0:
                rnn.zoneout = args.wdrop

# Build a fresh criterion when none was restored from a checkpoint.
if criterion is None:
    splits = []
    if ntokens > 500000:
        # One Billion
        # This produces fairly even matrix mults for the buckets:
        # 0: 11723136, 1: 10854630, 2: 11270961, 3: 11219422
        splits = [4200, 35000, 180000]
    elif ntokens > 75000:
        # WikiText-103
        splits = [2800, 20000, 76000]
    print('Using', splits)
    criterion = SplitCrossEntropyLoss(args.emsize, splits=splits, verbose=False)

if args.cuda:
    model = model.cuda()
    criterion = criterion.cuda()

# Everything the optimizer should update: the model's weights plus any
# parameters owned by the criterion.
params = list(model.parameters()) + list(criterion.parameters())
# numel() counts every element regardless of tensor rank; the previous
# formula multiplied only the first two dimensions.
total_params = sum(p.numel() for p in params)
print('Args:', args)
print('Model total parameters:', total_params)

###############################################################################
# Training code
###############################################################################

def evaluate(data_source, batch_size=10):
    """Compute the average loss of the global model over `data_source`.

    Args:
        data_source: batchified data tensor; it is walked along dimension 0
            in windows of `args.bptt` steps.
        batch_size: batch size handed to `model.init_hidden`; it has to match
            the batch size `data_source` was batchified with.

    Returns:
        A Python float: the sum of the per-window losses, each weighted by
        the window length, divided by `len(data_source)`.
    """
    # Turn on evaluation mode which disables dropout.
    model.eval()
    if args.model == 'QRNN':
        model.reset()
    total_loss = 0
    hidden = model.init_hidden(batch_size)
    # No gradients are needed for evaluation; disabling autograd avoids
    # building a graph for every window.
    with torch.no_grad():
        for i in range(0, data_source.size(0) - 1, args.bptt):
            inputs, targets = get_batch(data_source, i, args, evaluation=True)
            output, hidden = model(inputs, hidden)
            total_loss += len(inputs) * criterion(model.decoder.weight, model.decoder.bias, output, targets).data
            hidden = repackage_hidden(hidden)
    # float() accepts both the accumulated one-element tensor and the initial
    # int 0 (when the loop body never ran), unlike `.item()`.
    return float(total_loss) / len(data_source)


def train():
    """Run one pass over `train_data` and return the mean raw loss per batch.

    Works entirely on module-level state (`model`, `criterion`, `optimizer`,
    `params`, `train_data`, `args`, and `epoch` for logging). Each iteration
    draws a random sequence length, temporarily rescales the learning rate of
    the first parameter group in proportion to it, and performs one
    `optimizer.step(closure)`.

    Returns:
        `sum_loss / batch`: the sum of the per-batch raw losses (without the
        activation regularization terms) divided by the number of batches.
        Raises ZeroDivisionError if the loop body never runs.
    """
    # QRNN models are reset at the start of every pass.
    if args.model == 'QRNN': model.reset()
    total_loss = 0
    start_time = time.time()
    # NOTE(review): `ntokens` is not used anywhere in this function.
    ntokens = len(corpus.dictionary)
    hidden = model.init_hidden(args.batch_size)
    batch, i = 0, 0
    # NOTE(review): `indices` is never read in this function, but the call
    # still draws from NumPy's global RNG, so removing it would change the
    # sequence lengths sampled below for a given seed.
    indices = np.random.permutation(train_data.size(0)-1-1)
    sum_loss = 0
    while i < train_data.size(0) - 1 - 1:
        # Use the nominal BPTT length 95% of the time, half of it otherwise.
        bptt = args.bptt if np.random.random() < 0.95 else args.bptt / 2.
        # Sample the actual length around `bptt` (std 5) and
        # prevent excessively small or negative sequence lengths
        seq_len = max(5, int(np.random.normal(bptt, 5)))
        # There's a very small chance that it could select a very long sequence length resulting in OOM
        # seq_len = min(seq_len, args.bptt + 10)

        # Scale the learning rate of the first parameter group by
        # seq_len / args.bptt for this step only; the saved value `lr2` is
        # written back after the step.
        lr2 = optimizer.param_groups[0]['lr']
        optimizer.param_groups[0]['lr'] = lr2 * seq_len / args.bptt
        # Turn on training mode which enables dropout.
        model.train()
        data, targets = get_batch(train_data, i, args, seq_len=seq_len)

        # Starting each batch, we detach the hidden state from how it was previously produced.
        # If we didn't, the model would try backpropagating all the way to start of the dataset.

        def closure():
            """Forward/backward pass for the current batch.

            Returns the regularized loss, the raw criterion loss and the new
            hidden state. Reads `hidden`, `data` and `targets` from the
            enclosing loop iteration.
            """
            # Gradient tracking is switched on for the forward/backward pass
            # and off again before returning.
            torch.set_grad_enabled(True)
            hiddenx = repackage_hidden(hidden)
            optimizer.zero_grad()

            #output, hidden, rnn_hs, dropped_rnn_hs = model(data, hidden, return_h=True)
            output, hiddenx, rnn_hs, dropped_rnn_hs = model(data, hiddenx, return_h=True)
            raw_loss = criterion(model.decoder.weight, model.decoder.bias, output, targets)

            loss = raw_loss
            # Activation Regularization: penalize the mean squared value of
            # the last entry of `dropped_rnn_hs`.
            if args.rnnalpha: loss = loss + sum(args.rnnalpha * dropped_rnn_h.pow(2).mean() for dropped_rnn_h in dropped_rnn_hs[-1:])
            # Temporal Activation Regularization (slowness): penalize the
            # difference between consecutive steps of the last entry of `rnn_hs`.
            if args.rnnbeta: loss = loss + sum(args.rnnbeta * (rnn_h[1:] - rnn_h[:-1]).pow(2).mean() for rnn_h in rnn_hs[-1:])
            loss.backward()

            # `clip_grad_norm` helps prevent the exploding gradient problem in RNNs / LSTMs.
            if args.clip: torch.nn.utils.clip_grad_norm_(params, args.clip)
            torch.set_grad_enabled(False)
            return loss,raw_loss,hiddenx

        # Gradient tracking is globally disabled around the optimizer step
        # and re-enabled afterwards; the closure above turns it on only for
        # its own forward/backward pass.
        # NOTE(review): presumably this keeps the optimizers' parameter
        # updates out of autograd; confirm against utils.getOptimizer.
        torch.set_grad_enabled(False)
        # The optimizer is expected to hand back the closure's return value.
        loss,raw_loss,hidden = optimizer.step(closure)
        torch.set_grad_enabled(True)

        # `total_loss` feeds the periodic log line and is reset after each
        # report; `sum_loss` accumulates over the whole pass.
        total_loss += raw_loss.data
        sum_loss += raw_loss.data
        # Restore the learning rate that was in place before the rescale.
        optimizer.param_groups[0]['lr'] = lr2

        if batch % args.log_interval == 0 and batch > 0:
            cur_loss = total_loss.item() / args.log_interval
            elapsed = time.time() - start_time
            print('| epoch {:3d} | {:5d}/{:5d} batches | lr {:05.5f} | ms/batch {:5.2f} | '
                    'loss {:5.2f} | ppl {:8.2f} | bpc {:8.3f}'.format(
                epoch, batch, len(train_data) // args.bptt, optimizer.param_groups[0]['lr'],
                elapsed * 1000 / args.log_interval, cur_loss, math.exp(cur_loss), cur_loss / math.log(2)))
            total_loss = 0
            start_time = time.time()
        ###
        batch += 1
        # Advance by the sampled length, so consecutive windows do not overlap.
        i += seq_len
    #return math.exp(cur_loss)
    return sum_loss / batch

# Loop over epochs.
lr = args.lr
best_val_loss = []
stored_loss = 100000000

# At any point you can hit Ctrl + C to break out of training early.
try:
    # Ensure the optimizer is optimizing params, which includes both the model's weights
    # as well as the criterion's weight (i.e. Adaptive Softmax). The optimizer type is
    # selected from `args` by utils.getOptimizer.
    # NOTE: this replaces any optimizer restored through --resume.
    optimizer = getOptimizer(args, params)

    # Report whether the optimizer starts out with a 't0' entry in its first
    # parameter group (the marker this script uses to detect ASGD).
    print('t0' in optimizer.param_groups[0])

    for epoch in range(1, args.epochs+1):
        epoch_start_time = time.time()
        train_loss = train()
        results['trainloss'].append(float(train_loss))
        results['trainppl'].append(math.exp(float(train_loss)))

        # A 't0' entry in the first parameter group marks an ASGD optimizer.
        # NOTE: swapping in ASGD's averaged parameters for validation is not
        # implemented here, so both cases validate the current weights.
        using_asgd = 't0' in optimizer.param_groups[0]

        val_loss = evaluate(val_data, eval_batch_size)
        print('-' * 89)
        print('| end of epoch {:3d} | time: {:5.2f}s | valid loss {:5.2f} | '
              'valid ppl {:8.2f} | valid bpc {:8.3f}'.format(
                  epoch, (time.time() - epoch_start_time), val_loss,
                  math.exp(val_loss), val_loss / math.log(2)))
        print('-' * 89)

        # Keep only the checkpoint with the best validation loss so far.
        if val_loss < stored_loss:
            model_save(args.save)
            if using_asgd:
                print('Saving Averaged!')
            else:
                print('Saving model (new best validation)')
            stored_loss = val_loss

        if not using_asgd:
            # Non-monotone trigger: switch to ASGD once the current validation
            # loss is worse than the best one seen more than `nonmono` epochs ago.
            # NOTE(review): 'sgd' is not among the --optim choices, so this
            # branch is currently unreachable from the command line.
            if args.optim == 'sgd' and (len(best_val_loss) > args.nonmono
                                        and val_loss > min(best_val_loss[:-args.nonmono])):
                print('Switching to ASGD')
                # The weight-decay flag is `--weight_decay`; `args.wdecay` does not exist.
                optimizer = torch.optim.ASGD(model.parameters(), lr=args.lr, t0=0, lambd=0.,
                                             weight_decay=args.weight_decay)

            best_val_loss.append(val_loss)

        results['valloss'].append(float(val_loss))
        results['valppl'].append(math.exp(float(val_loss)))

        if epoch in args.when:
            print('Saving model before learning rate decreased')
            model_save('{}.e{}'.format(args.save, epoch))
            print('Dividing learning rate by 10')
            if args.optim in ['lookahead', 'adasam', 'oaar', 'aenr']:
                # These optimizers wrap an inner optimizer that owns the parameter groups.
                for param_group in optimizer.optimizer.param_groups:
                    param_group['lr'] /= 10.
                if args.optim in ['adasam','oaar','aenr']:
                    optimizer.alpha = optimizer.alpha * 0.06
                    optimizer.beta = optimizer.beta * 0.06
            else:
                for param_group in optimizer.param_groups:
                    param_group['lr'] /= 10.
                if args.optim in ['cr']:
                    group = optimizer.param_groups[0]
                    group['mix'] /= 10.

        # Persist the metric history after every epoch; the context manager
        # closes the file handle instead of leaking it.
        with open('output.ser', 'wb') as results_file:
            pickle.dump(results, results_file)

except KeyboardInterrupt:
    print('-' * 89)
    print('Exiting from training early')

# Load the best saved model. If training stopped before the first checkpoint
# was written, fall back to the in-memory model instead of crashing.
if os.path.exists(args.save):
    model_load(args.save)
else:
    print('No checkpoint found at {}; evaluating the current model'.format(args.save))

# Run on test data.
test_loss = evaluate(test_data, test_batch_size)
print('=' * 89)
print('| End of training | test loss {:5.2f} | test ppl {:8.2f} | test bpc {:8.3f}'.format(
    test_loss, math.exp(test_loss), test_loss / math.log(2)))
print('=' * 89)

results['testloss'].append(float(test_loss))
results['testppl'].append(math.exp(float(test_loss)))
with open('output.ser', 'wb') as results_file:
    pickle.dump(results, results_file)
