import math
import time
import argparse


import torch
import torch.cuda
import torch.nn as nn
from torch.optim.lr_scheduler import StepLR

import model
import corpus
import custom_optimizers as OP

# Command-line configuration for the language-model training run.
parser = argparse.ArgumentParser(description='PyTorch PennTreeBank LSTM LM')
parser.add_argument('--data', type=str, default='./input',
                    help='location of the data corpus')
parser.add_argument('--bptt', type=int, default=35,
                    help='number of rows (time steps) per training/eval chunk')
parser.add_argument('--epochs', type=int, default=39,
                    help='number of training epochs')
parser.add_argument('--clip', type=float, default=5,
                    help='max gradient norm passed to clip_grad_norm_')
parser.add_argument('--dropout', type=float, default=0.5,
                    help='dropout value passed to model.RNNModel')
parser.add_argument('--seed', type=int, default=1111,
                    help='seed passed to torch.manual_seed')

parser.add_argument('--num_layers', type=int, default=2,
                    help='num_layers passed to model.RNNModel')
parser.add_argument('--embed_size', type=int, default=650,
                    help='embedding size passed to model.RNNModel')
parser.add_argument('--num_hidden', type=int, default=650,
                    help='hidden size passed to model.RNNModel')
parser.add_argument('--batch_size', type=int, default=20,
                    help='number of parallel columns in the training data')

parser.add_argument('--lr', type=float, default=0.001,
                    help='initial learning rate')
# Validate the optimizer name at parse time instead of failing later inside
# get_optimizer; these are the names get_optimizer can actually build.
parser.add_argument('--optimizer', type=str, default='adam',
                    choices=['adam', 'sgd', 'amsgrad', 'scrms',
                             'scadagrad', 'ogd', 'sadam'],
                    help='name of the optimizer built by get_optimizer')

parser.add_argument('--beta1', type=float, default=0.9,
                    help='Beta1 Hyperparam for SAdam')
parser.add_argument('--gamma', type=float, default=0.9,
                    help='Gamma Hyperparam for SAdam')
args = parser.parse_args()

# Default 1111 matches the previously hard-coded seed.
torch.manual_seed(args.seed)


def get_optimizer(params, name, lr, convex=False,
                  decay=0, beta_1=0.9, gamma=0.9):
    """Build and return the optimizer called ``name`` over ``params``.

    Only the requested optimizer is constructed, so ``params`` is consumed
    exactly once (safe even if it is a one-shot iterable).

    Args:
        params: parameters to optimize.
        name: one of "adam", "sgd", "amsgrad", "scrms", "scadagrad", "ogd",
            "sadam".
        lr: learning rate.
        convex: forwarded to the custom "scrms", "scadagrad" and "ogd"
            optimizers only.
        decay: weight decay; not forwarded to plain "sgd".
        beta_1: forwarded to "sadam" only.
        gamma: forwarded to "sadam" only.

    Returns:
        The constructed optimizer instance.

    Raises:
        ValueError: if ``name`` is not a supported optimizer.
    """
    # Factories (not instances) so that only the chosen optimizer is built.
    factories = {
        "adam": lambda: torch.optim.Adam(params, lr=lr, weight_decay=decay),
        "sgd": lambda: torch.optim.SGD(params, lr=lr),
        "amsgrad": lambda: torch.optim.Adam(
            params, lr=lr, amsgrad=True, weight_decay=decay),
        "scrms": lambda: OP.SC_RMSprop(
            params, lr=lr, weight_decay=decay, convex=convex),
        "scadagrad": lambda: OP.SC_Adagrad(
            params, lr=lr, weight_decay=decay, convex=convex),
        "ogd": lambda: OP.SC_SGD(
            params, convex=convex, lr=lr, weight_decay=decay),
        "sadam": lambda: OP.SAdam(
            params, lr=lr, weight_decay=decay,
            beta_1=beta_1, gamma=gamma),
    }
    try:
        factory = factories[name]
    except KeyError:
        raise ValueError("Unknown Optimization: {!r} (expected one of {})"
                         .format(name, sorted(factories))) from None
    return factory()


def batchify(data, batch_size):
    """Lay ``data`` out as ``batch_size`` parallel columns.

    Rows of ``data`` beyond the largest multiple of ``batch_size`` are
    dropped. For a 1-D ``data`` the result has shape
    ``(len(data) // batch_size, batch_size)``: column ``j`` holds the ``j``-th
    contiguous slice of the input, so stepping down the rows advances every
    column by one position.
    """
    rows_per_column = data.size(0) // batch_size
    usable = data[:rows_per_column * batch_size]
    columns = usable.view(batch_size, -1)
    return columns.transpose(0, 1).contiguous()


# Training code
def repackage_hidden(h):
    """Detach a hidden state from the autograd graph.

    Accepts a single tensor or an arbitrarily nested tuple/list of tensors
    (e.g. an ``(h, c)`` pair). It returns the same structure with every tensor
    detached, so backpropagation stops at the batch boundary. Containers are
    returned as tuples.
    """
    if isinstance(h, torch.Tensor):
        # A bare tensor must not be iterated: that would split it along dim 0.
        return h.detach()
    return tuple(repackage_hidden(v) for v in h)


def get_batch(source, i, bptt=None):
    """Slice one chunk of inputs and next-row targets starting at row ``i``.

    Args:
        source: batchified tensor (see ``batchify``) whose rows are positions
            and whose columns are parallel sequences.
        i: first row of the chunk.
        bptt: maximum chunk length; defaults to ``args.bptt``.

    Returns:
        ``(data, target)``. ``data`` is rows ``i .. i+seq_len-1`` of
        ``source``. ``target`` is the same window shifted one row forward,
        flattened to 1-D. ``seq_len`` is clipped so the shifted window stays
        inside ``source``. Both are copies, not views of ``source``.
    """
    if bptt is None:
        bptt = args.bptt
    seq_len = min(bptt, len(source) - 1 - i)
    data = source[i:i + seq_len].clone().detach()
    target = source[i + 1:i + 1 + seq_len].clone().detach().view(-1)
    return data, target


def evaluate_epoch(data_source, device):
    """Return the average loss of the global ``model`` on ``data_source``.

    Runs ``model`` in eval mode with gradients disabled. It feeds
    ``data_source`` in chunks of ``args.bptt`` rows and carries the hidden
    state (created for ``eval_batch_size`` columns) across chunks. Reads the
    module globals ``model``, ``eval_batch_size`` and ``loss_criterion``.

    Args:
        data_source: batchified tensor (rows are positions).
        device: device each chunk is moved to.

    Returns:
        float: the chunk-length-weighted mean of ``loss_criterion`` over all
        rows that have a next-row target.

    Raises:
        ValueError: if ``data_source`` has fewer than 2 rows (nothing to
        predict).
    """
    # Only rows 0..len-2 have a next-row target, and the chunk lengths
    # produced by get_batch sum to exactly this count.
    num_target_rows = len(data_source) - 1
    if num_target_rows <= 0:
        raise ValueError('data_source needs at least 2 rows to evaluate')

    # Turn on evaluation mode which disables dropout.
    model.eval()
    total_loss = 0.0
    with torch.no_grad():
        hidden = model.init_hidden(eval_batch_size)
        # Under no_grad no graph is built, so the hidden state need not be
        # detached between chunks.
        for i in range(0, num_target_rows, args.bptt):
            data, targets = get_batch(data_source, i)
            data, targets = data.to(device), targets.to(device)
            output, hidden = model(data, hidden)
            # Weight each chunk's mean loss by its length (the last chunk
            # may be shorter than args.bptt).
            total_loss += len(data) * loss_criterion(output, targets).item()
    return total_loss / num_target_rows


def training_epoch(optimizer, device):
    """Run one epoch of chunked training over the global ``train_data``.

    Reads the module globals ``model``, ``train_data``, ``loss_criterion``,
    ``interval`` (logging period in batches) and ``epoch`` (used only in the
    log line). Every ``interval`` batches it prints the mean loss, perplexity
    and ms/batch over the batches seen since the previous report.

    Args:
        optimizer: optimizer that steps ``model``'s parameters.
        device: device each batch is moved to.
    """
    model.train()
    total_loss = 0.0
    window_batches = 0
    start_time = time.time()
    hidden = model.init_hidden(args.batch_size)
    num_batches = len(train_data) // args.bptt
    # train_data rows are positions; columns are the args.batch_size streams.
    for batch, i in enumerate(range(0, train_data.size(0) - 1, args.bptt)):
        data, targets = get_batch(train_data, i)
        # Detach the hidden state carried over from the previous batch so
        # backprop stops here instead of reaching back to the start.
        hidden = repackage_hidden(hidden)
        data, targets = data.to(device), targets.to(device)
        output, hidden = model(data, hidden)
        loss = loss_criterion(output, targets)
        optimizer.zero_grad()
        loss.backward()

        # `clip_grad_norm_` helps prevent exploding gradients
        torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
        optimizer.step()

        # detach() keeps the running sum on-device without retaining the
        # graph and without a host sync on every batch.
        total_loss += loss.detach()
        window_batches += 1

        if batch % interval == 0 and batch > 0:
            # The first window spans batches 0..interval (interval + 1
            # batches), so average over what was actually accumulated.
            cur_loss = float(total_loss) / window_batches
            elapsed = time.time() - start_time
            print('| epoch {:3d} | {:5d}/{:5d} batches | ms/batch {:5.2f} | '
                  'loss {:5.4f} | ppl {:8.2f}'.format(
                    epoch, batch, num_batches,
                    elapsed * 1000 / window_batches, cur_loss,
                    math.exp(cur_loss)))
            total_loss = 0.0
            window_batches = 0
            start_time = time.time()


if __name__ == "__main__":
    # Load data
    interval = 200  # print training progress every `interval` batches
    best_val_loss = float('inf')
    eval_batch_size = 10

    # NOTE: rebinding `corpus` and `model` shadows the imported modules; the
    # functions above read `model` (among others) as a module global.
    corpus = corpus.Corpus(args.data)
    ntokens = len(corpus.dictionary)

    test_data = batchify(corpus.test, eval_batch_size)
    val_data = batchify(corpus.valid, eval_batch_size)
    train_data = batchify(corpus.train, args.batch_size)

    # Build the model
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model = model.RNNModel(ntokens, args.embed_size,
                           args.num_hidden, args.num_layers, args.dropout)

    model.to(device)
    print(model)
    loss_criterion = nn.CrossEntropyLoss()

    # Path at which the model is to be saved
    model_save_path = args.optimizer + str(args.num_layers) + '.pt'

    # Load the appropriate optimizer. beta1/gamma MUST be passed by keyword:
    # positionally they would land in `convex` and `decay` (weight decay).
    optimizer = get_optimizer(
                    list(model.parameters()), args.optimizer,
                    args.lr, beta_1=args.beta1, gamma=args.gamma)
    # Halve the learning rate every 4 epochs.
    scheduler = StepLR(optimizer, step_size=4, gamma=0.5)

    # In-memory copy of the best weights, restored after training.
    best_state = None

    # Training Loop
    for epoch in range(1, args.epochs+1):
        epoch_start = time.time()
        # Training and validation step
        training_epoch(optimizer, device)
        val_loss = evaluate_epoch(val_data, device)
        scheduler.step()

        print('| end of epoch {:3d} | time: {:5.2f}s | valid loss {:5.4f} | '
              'valid ppl {:8.2f}'.format(
                epoch, (time.time() - epoch_start),
                val_loss, math.exp(val_loss)))

        # Save the model if the validation loss is the best we've seen so far.
        if val_loss < best_val_loss:
            with open(model_save_path, 'wb') as f:
                torch.save(model, f)
            # Keep a CPU snapshot as well, so the best weights can be restored
            # without unpickling a whole module from disk (torch.load defaults
            # to weights_only=True since PyTorch 2.6, which rejects that), and
            # without picking up a stale file from an earlier run.
            best_state = {k: v.detach().cpu().clone()
                          for k, v in model.state_dict().items()}
            best_val_loss = val_loss

    # Restore the best weights seen during training.
    if best_state is not None:
        model.load_state_dict(best_state)

    # Evaluation on Test Data
    test_loss = evaluate_epoch(test_data, device)
    print('| End of training | test loss {:5.4f} | test ppl {:8.2f}'.format(
        test_loss, math.exp(test_loss)))
