#!/usr/bin/env python3
import argparse
import time

from torch import cuda
from torch.nn.init import xavier_uniform_

from data import Dataset
from trunc_latent import StochasticEmbeddingLM
from utils import *

# Command-line interface. Flag names use underscores; the older dash-style
# spellings '--no-truncation' and '--test-only' are still accepted as aliases
# (argparse derives dest from the first long option, so args.no_truncation /
# args.test_only are unchanged).
parser = argparse.ArgumentParser()

# Data path options
parser.add_argument('--train_file', default='data/ptb-train.pkl')
parser.add_argument('--val_file', default='data/ptb-val.pkl')
parser.add_argument('--save_path', default='latent-embed.pt', help='where to save the model')
parser.add_argument('--load_checkpoint', default=None, help='where to load the model')
# Model options
parser.add_argument('--z_dim', default=64, type=int, help='latent dimension')
parser.add_argument('--h_dim', default=512, type=int, help='hidden dim for variational LSTM')
parser.add_argument('--prior_layers', default=3, type=int, help='prior layers')
parser.add_argument('--prior_made_upscale', default=20, type=int, help='prior made upscale')
parser.add_argument('--word_drop', default=0.05, type=float, help='word dropout')
parser.add_argument('--dropout', default=0.1, type=float, help='dropout')
parser.add_argument('--unk_latent', action='store_true', help='estimate unk latent')
parser.add_argument('--argmax_flow', action='store_true', help='use argmax flow')
parser.add_argument('--no_truncation', '--no-truncation', action='store_true', help='use CatNF')
parser.add_argument('--conditional_word_prior', action='store_true', help='word priors are conditional on context')
# Optimization options
parser.add_argument('--num_epochs', default=10, type=int, help='number of training epochs')
parser.add_argument('--lr', default=0.001, type=float, help='starting learning rate')
parser.add_argument('--max_grad_norm', default=3, type=float, help='gradient clipping parameter')
parser.add_argument('--weight_decay', default=1e-6, type=float, help='weight decay (L2 penalty) for adam')
# Sentence lengths are integral, matching --final_max_length below.
parser.add_argument('--max_length', default=200, type=int, help='max sentence length cutoff start')
parser.add_argument('--len_incr', default=1, type=int, help='increment max length each epoch')
parser.add_argument('--final_max_length', default=200, type=int, help='final max length cutoff')
parser.add_argument('--beta1', default=0.75, type=float, help='beta1 for adam')
parser.add_argument('--beta2', default=0.999, type=float, help='beta2 for adam')
parser.add_argument('--gpu', default=0, type=int, help='which gpu to use')
parser.add_argument('--seed', default=3435, type=int, help='random seed')
parser.add_argument('--print_every', type=int, default=1000, help='print stats after N batches')
parser.add_argument('--test_only', '--test-only', action='store_true', help='Only run eval')


def _save_checkpoint(args, model, train_data):
    """Save the whole model plus vocabulary maps to ``args.save_path``.

    The model is moved to CPU for ``torch.save`` and moved back to the current
    CUDA device afterwards (in-place module moves keep the same Parameter
    objects, so an existing optimizer keeps tracking them).
    """
    checkpoint = {
        'args': args.__dict__,
        'model': model.cpu(),
        'word2idx': train_data.word2idx,
        'idx2word': train_data.idx2word
    }
    print('Saving checkpoint to %s' % args.save_path)
    torch.save(checkpoint, args.save_path)
    model.cuda()


def _train_epoch(args, model, optimizer, train_data, epoch, best_val_ppl):
    """Run one training pass over ``train_data`` in a random batch order.

    Args:
        args: parsed command-line options (uses ``max_grad_norm`` and
            ``print_every``).
        model: module whose forward returns ``(nll, kl)`` tensors.
        optimizer: optimizer over ``model``'s parameters.
        train_data: indexable dataset; ``train_data[i]`` yields a 7-tuple whose
            first three entries are ``(sents, length, batch_size)``.
        epoch: 1-based epoch number (used for the non-finite-loss policy and
            for logging).
        best_val_ppl: best validation perplexity so far (logging only).

    Every ``args.print_every`` batches the running reconstruction PPL, KL,
    PPL bound and throughput are printed, then the running statistics reset.
    """
    start_time = time.time()
    train_nll = 0.
    train_kl = 0.
    num_sents = 0.
    num_words = 0.
    grad_norm = 0.
    for b, i in enumerate(np.random.permutation(len(train_data)), start=1):
        sents, length, batch_size, *_ = train_data[i]
        if length == 1:
            continue
        length = length + 1  # we implicitly generate </s> so we explicitly count it
        sents = sents.cuda()
        nll, kl = model(sents)
        loss = (nll + kl).mean()
        # From epoch 2 on, a batch with a non-finite loss (NaN, +inf or -inf) is
        # skipped entirely: no backward, no optimizer step, and no contribution
        # to the running statistics (a NaN/inf there would corrupt every logged
        # PPL until the next reset). In epoch 1 every batch is used.
        if epoch > 1 and not torch.isfinite(loss):
            print("SKIP")
        else:
            loss.backward()
            # clip_grad_norm_ returns the total gradient norm *before* clipping;
            # keep it for logging, since grads are cleared right after the step.
            grad_norm = float(torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm))
            optimizer.step()
            train_nll += nll.sum().item()
            train_kl += kl.sum().item()
            num_sents += batch_size
            num_words += batch_size * length
        optimizer.zero_grad()

        # Skip the report if every batch since the last reset was skipped, to
        # avoid dividing by zero.
        if b % args.print_every == 0 and num_words > 0:
            with torch.no_grad():
                param_norm = sum(p.norm() ** 2 for p in model.parameters()).item() ** 0.5
            log_str = 'Epoch: %d, Batch: %d/%d, |Param|: %.6f, |GParam|: %.2f,  LR: %.4f, ' + \
                      'ReconPPL: %.2f, KL: %.4f, PPLBound: %.2f, ValPPL: %.2f, ' + \
                      'Throughput: %.2f examples/sec'
            print(log_str %
                  (epoch, b, len(train_data), param_norm, grad_norm,
                   optimizer.param_groups[0]['lr'],
                   np.exp(train_nll / num_words), train_kl / num_sents,
                   np.exp((train_nll + train_kl) / num_words), best_val_ppl,
                   num_sents / (time.time() - start_time)))
            train_nll = 0.
            train_kl = 0.
            num_sents = 0.
            num_words = 0.
            start_time = time.time()


def main(args):
    """Train (unless ``args.test_only``) and then evaluate a StochasticEmbeddingLM.

    Training builds the model, or replaces it with the model stored in
    ``args.load_checkpoint`` when that is given. It optimizes with Adam, using
    ReduceLROnPlateau on validation perplexity. The best model seen, by lowest
    validation PPL and including the starting model, is kept at
    ``args.save_path``.

    Evaluation reloads the model from ``args.save_path`` and scores the
    validation set with ``k=1000`` importance samples. In test-only mode only
    this step runs, so ``args.save_path`` must already exist.
    """
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    train_data = Dataset(args.train_file)
    val_data = Dataset(args.val_file)
    vocab_size = int(train_data.vocab_size)
    max_len = max(val_data.sents.size(1), train_data.sents.size(1))
    print('Train: %d sents / %d batches, Val: %d sents / %d batches' %
          (train_data.sents.size(0), len(train_data), val_data.sents.size(0), len(val_data)))
    print('Vocab size: %d, Max Sent Len: %d' % (vocab_size, max_len))
    print('Save Path', args.save_path)
    cuda.set_device(args.gpu)
    if not args.test_only:
        model = StochasticEmbeddingLM(
            vocab=vocab_size,
            unk_idx=train_data.word2idx['<unk>'], pad_idx=-1,
            h_dim=args.h_dim, z_dim=args.z_dim,
            argmax_flow=args.argmax_flow,
            prior_layers=args.prior_layers,
            prior_made_upscale=args.prior_made_upscale,
            prior_word_drop=args.word_drop,
            prior_dropout=args.dropout,
            conditional_word_prior=args.conditional_word_prior,
            unk_encoder=args.unk_latent,
            truncated=not args.no_truncation
        )

        print("model architecture")
        if args.load_checkpoint is not None:
            # Resume: the stored model object (and its vocabulary) replaces the
            # freshly constructed one.
            checkpoint = torch.load(args.load_checkpoint)
            model = checkpoint['model'].cuda()
            train_data.word2idx = checkpoint['word2idx']
            train_data.idx2word = checkpoint['idx2word']

        print(model)
        model.train()
        model.cuda()
        optimizer = torch.optim.Adam(
            model.parameters(), lr=args.lr,
            betas=(args.beta1, args.beta2),
            weight_decay=args.weight_decay
        )
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer,
            mode='min', factor=0.5, patience=1,
            min_lr=5e-5,
            verbose=True
        )

        print('--------------------------------')
        print('Checking validation perf...')
        best_val_ppl, _ = eval(val_data, model)
        print('--------------------------------')
        # Save the starting model too, so args.save_path always holds the best
        # model of *this* run. Otherwise, if no epoch improves on it, the reload
        # below would fail or silently evaluate a stale file from an earlier run.
        _save_checkpoint(args, model, train_data)

        for epoch in range(1, args.num_epochs + 1):
            print('Starting epoch %d' % epoch)
            _train_epoch(args, model, optimizer, train_data, epoch, best_val_ppl)

            args.max_length = min(args.final_max_length, args.max_length + args.len_incr)
            print('--------------------------------')
            print('Checking validation perf...')
            val_ppl, _ = eval(val_data, model)
            print('--------------------------------')
            if val_ppl < best_val_ppl:
                best_val_ppl = val_ppl
                _save_checkpoint(args, model, train_data)

            scheduler.step(val_ppl)

    # Final evaluation always uses the model stored at args.save_path.
    checkpoint = torch.load(args.save_path)
    model = checkpoint['model'].cuda()
    train_data.word2idx = checkpoint['word2idx']
    train_data.idx2word = checkpoint['idx2word']
    print(model)
    print('--------------------------------')
    print('Checking validation perf...')
    eval(val_data, model, k=1000)
    print('--------------------------------')



def eval(data, model, k=5, chunk_size=5, device='cuda'):
    """Score ``model`` on ``data`` with a k-sample importance-weighted bound.

    For every batch, ``model(sents)`` is called once to accumulate the
    single-sample reconstruction NLL and KL. It is then called on ``k``
    repeated copies of each sentence, in chunks of at most ``chunk_size``
    copies per forward pass. This gives the estimate
    ``log p(x) ~= logsumexp_j(-(nll_j + kl_j)) - log k``.

    Args:
        data: indexable dataset; ``data[i]`` yields a tuple whose first three
            entries are ``(sents, length, batch_size)``. Batches with
            ``length == 1`` are skipped.
        model: module whose forward returns ``(nll, kl)``; each is summed over
            dim 1 per sentence (presumably (batch, seq) — TODO confirm).
        k: number of importance samples per sentence (must be >= 1).
        chunk_size: max samples per forward pass (must be >= 1); ``k`` need
            not be a multiple of it.
        device: device the batches are moved to (default ``'cuda'``).

    Returns:
        ``(ppl, 0.)``. ``ppl = exp(-sum(log p(x)) / num_words)``, where
        ``num_words`` counts one extra token (</s>) per sentence; it is printed
        as "PPL (Upper Bound)". The second element is a placeholder (F1 is not
        computed).

    Raises:
        ValueError: if ``k`` or ``chunk_size`` is < 1, or no batch was scored.

    The model's train/eval mode is restored on return, even on error.
    """
    if k < 1 or chunk_size < 1:
        raise ValueError('k and chunk_size must be >= 1, got k=%r, chunk_size=%r' % (k, chunk_size))
    chunk_size = min(k, chunk_size)
    was_training = model.training
    model.eval()
    num_words = 0
    total_nll = 0.
    total_kl = 0.
    total_log_p_x = 0.
    try:
        with torch.no_grad():
            for batch_idx in range(len(data)):
                sents, length, batch_size, *_ = data[batch_idx]
                if length == 1:
                    continue
                length = length + 1  # count the implicitly generated </s>
                sents = sents.to(device)
                nll, kl = model(sents)
                total_nll += nll.sum().item()
                total_kl += kl.sum().item()

                # Draw k importance samples per sentence; the last chunk may be
                # smaller than chunk_size when k is not a multiple of it.
                results = []
                remaining = k
                while remaining > 0:
                    n = min(chunk_size, remaining)
                    # Repeat each sentence n times: (batch, seq) -> (batch * n, seq).
                    repeated = sents[:, None].expand(-1, n, -1).flatten(0, 1)
                    s_nll, s_kl = model(repeated)
                    results.append(-(s_nll.sum(1) + s_kl.sum(1)).view(sents.size(0), n))
                    remaining -= n

                log_ratios = torch.cat(results, dim=1)  # (batch, k)
                log_k = torch.log(torch.tensor(k, dtype=torch.float, device=sents.device))
                log_p_x = torch.logsumexp(log_ratios, dim=1) - log_k
                total_log_p_x += log_p_x.sum().item()

                num_words += batch_size * length
    finally:
        model.train(was_training)

    if num_words == 0:
        raise ValueError('no evaluable batches (all had length == 1 or data was empty)')
    recon_nll = total_nll / num_words
    kl = total_kl / num_words
    nll = -total_log_p_x / num_words
    ppl_elbo = np.exp(nll)
    print('NLL: %.4f, ReconNLL: %.4f, KL: %.4f, PPL (Upper Bound): %.2f' %
          (nll, recon_nll, kl, ppl_elbo))
    return ppl_elbo, 0.  # sent_f1*100


if __name__ == '__main__':
    # Script entry point: parse the command line and run training/evaluation.
    main(parser.parse_args())
