import argparse
import datetime
import sys
import json
from collections import defaultdict
from pathlib import Path
from tempfile import mkdtemp

import numpy as np
import torch
import torch.nn as nn
from torch import optim

import models
import objectives_together
from utils import Logger, Timer, save_model, save_vars, unpack_data


# TensorBoard
from torch.utils.tensorboard import SummaryWriter

# Command-line interface for the training script.
# Help texts state the defaults that are actually configured below.
parser = argparse.ArgumentParser(description='Multi-Modal VAEs')
parser.add_argument('--experiment', type=str, default='', metavar='E',
                    help='experiment name')
# Valid model names are the attributes of `models` that start with `VAE_`,
# with that prefix stripped.
parser.add_argument('--model', type=str, default='cubIS_conv', metavar='M',
                    choices=[s[len('VAE_'):] for s in dir(models) if s.startswith('VAE_')],
                    help='model name (default: cubIS_conv)')
parser.add_argument('--obj', type=str, default='dreg', metavar='O',
                    choices=['elbo_naive','elbo', 'iwae', 'dreg'],
                    help='objective to use (default: dreg)')
parser.add_argument('--K', type=int, default=10, metavar='K',
                    help='number of particles to use for iwae/dreg (default: 10)')
# `store_true` with default=True can never be switched off, so each such flag
# gets a `--no-*` counterpart writing to the same destination. The original
# flags are still accepted and the defaults are unchanged.
parser.add_argument('--looser', action='store_true', default=True,
                    help='use the looser version of IWAE/DREG (enabled by default)')
parser.add_argument('--no-looser', dest='looser', action='store_false',
                    help='do not use the looser version of IWAE/DREG')
parser.add_argument('--llik_scaling', type=float, default=0.,
                    help='likelihood scaling for cub images/svhn modality when running in '
                         'multimodal setting, set as 0 to use default value')
parser.add_argument('--batch-size', type=int, default=64, metavar='N',
                    help='batch size for data (default: 64)')
parser.add_argument('--epochs', type=int, default=50, metavar='E',
                    help='number of epochs to train (default: 50)')
parser.add_argument('--latent-dim-w', type=int, default=16, metavar='L',
                    help='dimensionality of latent w (default: 16)')
parser.add_argument('--latent-dim-u', type=int, default=48, metavar='L',
                    help='dimensionality of latent u (default: 48)')
parser.add_argument('--pre-trained', type=str, default="",
                    help='path to pre-trained model (train from scratch if empty)')
parser.add_argument('--learn-prior', action='store_true', default=False,
                    help='learn model prior parameters for w')
parser.add_argument('--logp', action='store_true', default=False,
                    help='estimate tight marginal likelihood on completion')
parser.add_argument('--print-freq', type=int, default=100, metavar='f',
                    help='frequency with which to print stats, 0 to disable (default: 100)')
parser.add_argument('--no-analytics', action='store_true', default=False,
                    help='disable plotting analytics')
parser.add_argument('--no-cuda', action='store_true', default=False,
                    help='disable CUDA use')
parser.add_argument('--seed', type=int, default=1, metavar='S',
                    help='random seed (default: 1)')
parser.add_argument('--w-from-prior', type=str, default='single', metavar='W',
                    choices=['single', 'natural', 'natural-nograd'],
                    help='w for reconstruction options (default: single)')
parser.add_argument('--generate-multiple-ws', action='store_true', default=False,
                    help='also generate samples that share the same u and only resample w')
parser.add_argument('--learn-prior-w-sent', action='store_true', default=True,
                    help='learn model prior parameters for w, sentences (enabled by default)')
parser.add_argument('--no-learn-prior-w-sent', dest='learn_prior_w_sent', action='store_false',
                    help='do not learn model prior parameters for w, sentences')
parser.add_argument('--learn-prior-w-img', action='store_true', default=True,
                    help='learn model prior parameters for w, images (enabled by default)')
parser.add_argument('--no-learn-prior-w-img', dest='learn_prior_w_img', action='store_false',
                    help='do not learn model prior parameters for w, images')
parser.add_argument('--beta', type=float, default=1.0, metavar='B',
                    help='value of beta (default: 1.0)')
parser.add_argument('--recon-option', type=str, default='jointprior', metavar='O',
                    choices=['natural','stdprior', 'jointprior', 'tunedsingleprior', 'None'],
                    help='cross-recon-option to use (default: jointprior)')
parser.add_argument('--recon-option-factor', type=float, default=1.0, metavar='F',
                    help='cross-recon-option-factor to use (default: 1.0)')
parser.add_argument('--llik_scaling_sent', type=float, default=5.0,
                    help='likelihood scaling factor sentences (default: 5.0)')
#parser.add_argument('--batchnorm_momentum', type=float, default=-1,
#                    help='likelihood scaling factor sentences')

# args
args = parser.parse_args()


# random seed
# https://pytorch.org/docs/stable/notes/randomness.html
# NOTE(review): the notes linked above say cudnn.benchmark = True may pick
# different convolution algorithms between runs, so seeding alone may not give
# reproducible CUDA results. Left enabled, as in the original configuration.
torch.backends.cudnn.benchmark = True
torch.manual_seed(args.seed)
np.random.seed(args.seed)

# load args from disk if pretrained model path is given
# (the stored namespace replaces the command-line one entirely)
pretrained_path = ""
if args.pre_trained:
    pretrained_path = args.pre_trained
    args = torch.load(args.pre_trained + '/args.rar')

# CUDA is used only when it is both available and not disabled via --no-cuda.
args.cuda = not args.no_cuda and torch.cuda.is_available()
device = torch.device("cuda" if args.cuda else "cpu")

# The model class is looked up by name in `models` and built from the args.
modelC = getattr(models, 'VAE_{}'.format(args.model))
model = modelC(args).to(device)

if pretrained_path:
    print('Loading model {} from {}'.format(model.modelName, pretrained_path))
    # map_location remaps the stored tensors onto the device selected above, so
    # a checkpoint written on a GPU can still be restored on a CPU-only run.
    model.load_state_dict(torch.load(pretrained_path + '/model.rar', map_location=device))
    # NOTE(review): self-assignment kept from the original; it looks like a
    # no-op unless `_pz_params` has a custom setter -- confirm before removing.
    model._pz_params = model._pz_params

# Fall back to the model's own name when no experiment name was given.
if not args.experiment:
    args.experiment = model.modelName

# set up run path
# The run id is the underscore-joined list of: latent sizes (w, u), beta, seed,
# the likelihood scaling(s) reported by the model, and an ISO timestamp.
run_id_fields = [args.latent_dim_w, args.latent_dim_u, args.beta, args.seed]
if hasattr(model, 'vaes'):
    # multi-modal model: one scaling per modality (first two entries of `vaes`)
    run_id_fields.extend([model.vaes[0].llik_scaling, model.vaes[1].llik_scaling])
else:
    run_id_fields.append(model.llik_scaling)
run_id_fields.append(datetime.datetime.now().isoformat())
runId = '_'.join(str(field) for field in run_id_fields)

# Each run gets a fresh directory under the experiment folder; mkdtemp adds a
# unique suffix after the run id.
experiment_dir = Path('../experiments/' + args.experiment)
experiment_dir.mkdir(parents=True, exist_ok=True)
runPath = mkdtemp(dir=str(experiment_dir), prefix=runId)
# From here on stdout goes through the project's Logger, set up with run.log.
sys.stdout = Logger('{}/run.log'.format(runPath))
print('Expt:', runPath)
print('RunID:', runId)

# save args to run
args_json_path = '{}/args.json'.format(runPath)
with open(args_json_path, 'w') as fp:
    json.dump(vars(args), fp)
# -- also save object because we want to recover these for other things
args_object_path = '{}/args.rar'.format(runPath)
torch.save(args, args_object_path)

# TensorBoard
tensorboard_log_dir = "../runs/" + args.experiment + "/" + runId
writer = SummaryWriter(log_dir=tensorboard_log_dir)

objectives = objectives_together

# preparation for training
# Only parameters that require gradients are handed to the optimizer.
optimizer = optim.Adam((p for p in model.parameters() if p.requires_grad),
                       lr=1e-3, amsgrad=True)


train_loader, test_loader = model.getDataLoaders(args.batch_size, device=device)
#n_iter_test = len(test_loader)

# Models exposing `vaes` are treated as multi-modal and use the objectives
# whose names start with `m_`.
is_multimodal = hasattr(model, 'vaes')
objective_prefix = 'm_' if is_multimodal else ''

# --looser is on by default, but its help text only promises a looser variant
# for IWAE/DREG. Append the suffix only when such an objective really exists,
# so e.g. `--obj elbo_naive` resolves to the plain objective instead of failing
# on a non-existent `*_looser` attribute.
objective_name = objective_prefix + args.obj
looser_objective_name = objective_name + '_looser'
if args.looser and args.obj != 'elbo' and hasattr(objectives, looser_objective_name):
    objective_name = looser_objective_name
objective = getattr(objectives, objective_name)

# Evaluation always uses the IWAE objective (`m_iwae_test` for multi-modal
# models, `iwae` otherwise).
t_objective = getattr(objectives,
                      objective_prefix + 'iwae' + ('_test' if is_multimodal else ''))


#if args.batchnorm_momentum == -1:
#    batchnorm_momentum_arg = None
#else:
#    batchnorm_momentum_arg = args.batchnorm_momentum

#def set_momentum_batchnorm(m):
#    if type(m) == nn.BatchNorm1d:
#        m.momentum=batchnorm_momentum_arg

#model.apply(set_momentum_batchnorm)

#def avoid_tracking_runningstats_batchnorm(m):
#    if type(m) == nn.BatchNorm1d:
#        m.track_running_stats=False
#def enable_tracking_runningstats_batchnorm(m):
#    if type(m) == nn.BatchNorm1d:
#        m.track_running_stats=True

def batchnorm_to_train_mode(m):
    """Switch ``m`` to training mode if it is exactly an ``nn.BatchNorm2d``.

    Meant to be passed to ``model.apply``; every other module type is left
    untouched. BatchNorm is unstable for certain runs, so it is sometimes
    required to keep it in train mode and play around with it at test time.
    Will be fixed in future implementations.
    """
    if type(m) is not nn.BatchNorm2d:
        return
    m.train()
def batchnorm_to_test_mode(m):
    """Switch ``m`` to evaluation mode if it is exactly an ``nn.BatchNorm2d``.

    Counterpart of ``batchnorm_to_train_mode``, meant to be passed to
    ``model.apply``; every other module type is left untouched.
    """
    if type(m) is nn.BatchNorm2d:
        # nn.Module has no `test()` method; evaluation mode is entered via eval().
        m.eval()

def train(epoch, agg):
    """Run one pass over ``train_loader``, updating the model after every batch.

    The loss of a batch is the negated training objective evaluated with
    ``args.K`` particles. The summed batch losses, divided by the number of
    examples in the training dataset, are written to TensorBoard under
    ``Loss/train``, appended to ``agg['train_loss']`` and printed.

    Args:
        epoch: epoch number, used as the TensorBoard step and in the log line.
        agg: mapping of metric name to list of per-epoch values.
    """
    model.train()
    running_loss = 0
    for step, batch in enumerate(train_loader):
        inputs = unpack_data(batch, device=device)

        optimizer.zero_grad()
        loss = -objective(model, inputs, K=args.K)
        loss.backward()
        optimizer.step()

        step_loss = loss.item()
        running_loss += step_loss

        # Periodic progress line; a non-positive print frequency disables it.
        if args.print_freq > 0 and step % args.print_freq == 0:
            print("iteration {:04d}: loss: {:6.3f}".format(step, step_loss / args.batch_size))

    epoch_loss = running_loss / len(train_loader.dataset)
    writer.add_scalar("Loss/train", epoch_loss, epoch)
    agg['train_loss'].append(epoch_loss)
    print('====> Epoch: {:03d} Train loss: {:.4f}'.format(epoch, epoch_loss))


def _save_reconstructions(data, epoch):
    """Write reconstructions (and, for multi-modal models, cross-modal outputs) for one batch."""
    model.reconstruct(data, runPath, epoch)
    if not hasattr(model, 'vaes'):
        return
    model.reconstruct_options(data, runPath, epoch, args.recon_option, args.recon_option_factor)
    # cross-generation in both directions: 0 -> 1, then 1 -> 0
    for source, target in ((0, 1), (1, 0)):
        model.cross_generate(data, runPath, epoch, source, target, args.recon_option_factor)


def test(epoch, agg):
    """Evaluate the model on ``test_loader`` and record the mean test loss.

    The model is put in eval mode, after which its ``nn.BatchNorm2d`` layers
    are switched back to train mode via ``batchnorm_to_train_mode``. The loss
    of a batch is the negated test objective with ``args.K`` particles. The
    summed losses, divided by the number of examples in the test dataset, are
    written to TensorBoard under ``Loss/test``, appended to
    ``agg['test_loss']`` and printed.

    Besides the loss, reconstructions are written for batches 0-9 and 11-19;
    for the latter the whole model is first put back in eval mode (which then
    stays in effect for the remaining batches) and ``model.analyse`` is run
    unless ``args.no_analytics`` is set.

    NOTE(review): batch index 10 matches neither window -- confirm whether
    that gap is intentional.

    Args:
        epoch: epoch number, used as the TensorBoard step and passed to the
            model's output routines.
        agg: mapping of metric name to list of per-epoch values.
    """
    model.eval()
    model.apply(batchnorm_to_train_mode)

    # Kept from the original condition; always true for integer epochs.
    write_outputs = epoch % 1 == 0

    running_loss = 0
    with torch.no_grad():
        for batch_idx, batch in enumerate(test_loader):
            inputs = unpack_data(batch, device=device)
            batch_loss = -t_objective(model, inputs, K=args.K)
            running_loss += batch_loss.item()

            if not write_outputs:
                continue

            if batch_idx < 10:
                # first window: BatchNorm2d layers are still in train mode
                _save_reconstructions(inputs, epoch)
            elif 10 < batch_idx < 20:
                # second window: everything, BatchNorm included, in eval mode
                model.eval()
                _save_reconstructions(inputs, epoch)
                if not args.no_analytics:
                    model.analyse(inputs, runPath, epoch)

    epoch_loss = running_loss / len(test_loader.dataset)
    writer.add_scalar("Loss/test", epoch_loss, epoch)
    agg['test_loss'].append(epoch_loss)
    print('====>             Test loss: {:.4f}'.format(epoch_loss))



def estimate_log_marginal(K):
    """Compute an IWAE estimate of the log-marginal likelihood of test data.

    Sums the negated test objective over all test batches, divides by the
    number of examples in the test dataset, logs the result to TensorBoard
    under ``Marginal_llik`` and prints it.

    NOTE(review): the accumulated quantity is the *negated* objective, i.e. the
    same sign as the test loss, while the printed label says log likelihood --
    confirm the intended sign.

    Args:
        K: number of particles passed to the test objective.
    """
    model.eval()
    total = 0
    with torch.no_grad():
        for batch in test_loader:
            inputs = unpack_data(batch, device=device)
            negated_objective = -t_objective(model, inputs, K)
            total += negated_objective.item()

    per_example = total / len(test_loader.dataset)
    writer.add_scalar("Marginal_llik", per_example)
    print('Marginal Log Likelihood (IWAE, K = {}): {:.4f}'.format(K, per_example))

if __name__ == '__main__':
    # Script entry point: train and evaluate for `args.epochs` epochs, writing
    # samples, a model checkpoint and the loss history after every epoch.
    with Timer('MM-VAE') as t:
        agg = defaultdict(list)
        try:
            for epoch in range(1, args.epochs + 1):
                train(epoch, agg)
                test(epoch, agg)
                model.generate(runPath, epoch)
                save_model(model, runPath + '/model_'+str(epoch)+'.rar')
                save_vars(agg, runPath + '/losses_'+str(epoch)+'.rar')
                if args.generate_multiple_ws:
                    model.generate_multiple_ws(runPath, epoch)
            # --logp promises a tight marginal-likelihood estimate on completion
            # but was never acted upon. A tight IWAE bound needs many particles,
            # hence the large K.
            if args.logp:
                estimate_log_marginal(5000)
        finally:
            # Flush and close the TensorBoard writer even if training is
            # interrupted, so already-logged scalars are not lost.
            writer.flush()
            writer.close()