# Training script: fits one VAE-with-planar-flow model per flow depth listed in
# --num_layers_lst on MNIST (transformed with ToBinary) and saves the summary
# results under pickle/ and the best checkpoints under pt/.
import torch
import numpy as np
from time import time
t0_script = time()  # wall-clock start; used for the total-runtime report at the end
torch.set_num_threads(1)  # limit intra-op CPU parallelism to a single thread
torch.set_default_dtype(torch.float64)  # floating-point tensors default to double precision
# cuDNN settings that trade autotuned speed for reproducible GPU results.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

import sys
# Relative sys.path entries are resolved against the current working directory,
# so the project imports below only work when the script is launched from the
# directory this path was written for.
sys.path.append('../../../package')
# Project modules. NOTE(review): the star imports presumably provide VAEFlow
# and ToBinary (vaeflow) and PlanarFlow (planar) -- confirm.
from vaeflow import *
from planar import *

# Check available device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'device {device} \n')

#------------------------------------------------------ parameter settings
import argparse
parser = argparse.ArgumentParser()
# -- model / experiment configuration
# NOTE(review): --replicate has no default; if omitted, None is passed to
# PlanarFlow and embedded in the output filenames -- confirm whether it should
# also be required.
parser.add_argument('--replicate',        type=int)
# One model is trained per listed flow depth, so at least one value is
# mandatory; without it the script would only fail later, at
# len(num_layers_lst), after the dataset download.
parser.add_argument('--num_layers_lst',   type=int, nargs='+', required=True)
parser.add_argument('--num_channels',     type=int, default=[16,32], nargs='+')
parser.add_argument('--latent_dim',       type=int, default=20)
# Passed through unchanged (a string, or None when omitted) to PlanarFlow.
parser.add_argument('--reparam')

# -- optimisation / evaluation settings
parser.add_argument('--IS_size',          type=int,   default=500)
parser.add_argument('--batch_size',       type=int,   default=250)
parser.add_argument('--num_epochs',       type=int,   default=500)
parser.add_argument('--anneal_epochs',    type=int,   default=100)
parser.add_argument('--lr_init',          type=float, default=0.001)
parser.add_argument('--lr_factor',        type=float, default=0.75)
parser.add_argument('--lr_patience',      type=int  , default=10)
parser.add_argument('--lr_min',           type=float, default=1e-5)

# -- seeds for the individual sources of randomness
parser.add_argument('--seed_init',        type=int,   default=235711131719)
parser.add_argument('--seed_draw',        type=int,   default=31415926)
parser.add_argument('--seed_model',       type=int,   default=11235813)
parser.add_argument('--seed_dataloader',  type=int,   default=42)

args = parser.parse_args()
replicate        = args.replicate
num_layers_lst   = args.num_layers_lst
num_channels     = args.num_channels
latent_dim       = args.latent_dim
reparam          = args.reparam

IS_size          = args.IS_size
batch_size       = args.batch_size
num_epochs       = args.num_epochs
anneal_epochs    = args.anneal_epochs
lr_init          = args.lr_init
lr_factor        = args.lr_factor
lr_patience      = args.lr_patience
lr_min           = args.lr_min

seed_init        = args.seed_init
seed_draw        = args.seed_draw
seed_model       = args.seed_model
seed_dataloader  = args.seed_dataloader

# Echo the configuration as aligned "name = value" lines (names padded to 17
# columns), with a trailing space and a blank line after the last setting.
_echo_order = ('replicate', 'num_layers_lst', 'num_channels', 'latent_dim',
               'reparam', 'IS_size', 'batch_size', 'num_epochs',
               'anneal_epochs', 'lr_init', 'lr_factor', 'lr_patience',
               'lr_min', 'seed_init', 'seed_draw', 'seed_model',
               'seed_dataloader')
print('\n'.join(f'{name:<17}= {getattr(args, name)}' for name in _echo_order)
      + ' \n')

#------------------------------------------------------ load data
from torchvision import datasets
from torchvision.transforms import Compose, ToTensor
from torch.utils.data import DataLoader

def _load_binary_mnist(is_train):
    """Return the MNIST training split (is_train=True) or test split (False).

    Samples are transformed with ToTensor followed by ToBinary. The files are
    downloaded into the shared data directory on first use; if they are
    already present locally, they are not downloaded again.
    """
    return datasets.MNIST(
        root='../../../data',
        train=is_train,
        download=True,
        transform=Compose([ToTensor(), ToBinary()]),
    )


# Build the training split first, then the held-out test split.
train_data = _load_binary_mnist(True)
test_data = _load_binary_mnist(False)

#------------------------------------------------------ define and train models
# Per-depth result containers; index i corresponds to num_layers_lst[i].
best_model     = {}  # 'K<num_layers>' -> state_dict of the reverted (best) model
# Shape (depths, num_epochs+1, 3). NOTE(review): presumably one row per epoch
# plus an initial evaluation, holding the same three quantities as `evaluate`
# -- confirm against VAEFlow.train_test.
test_loss_hist = torch.zeros(len(num_layers_lst), num_epochs+1, 3)
best_test_loss = torch.zeros(len(num_layers_lst))
best_epoch     = torch.zeros(len(num_layers_lst))  # float (default dtype), not int
evaluate       = torch.zeros(len(num_layers_lst), 3)  # -ELBO, -logp(x), KL per depth

for i_num_layers, num_layers in enumerate(num_layers_lst):
    print(f'---- num_layers = {num_layers} ----')
    # Normalizing Flow
    # NOTE(review): the meaning of the positional `False` is defined by
    # PlanarFlow; passing it by keyword would make it readable here.
    flow = PlanarFlow(latent_dim, num_layers, reparam, False, 
                      replicate, seed_init)
    
    # VAE with Flow
    # Reseed the global RNG right before building the model so every depth
    # starts model construction from the same RNG state.
    torch.manual_seed(seed_model)
    model = VAEFlow(latent_dim, num_channels, flow, seed_draw).to(device)

    # define dataloader
    # A freshly seeded generator per depth gives every model the same sequence
    # of shuffled training batches; the test loader is not shuffled.
    g_loader = torch.Generator()
    g_loader.manual_seed(seed_dataloader)
    test_loader  = DataLoader(test_data,  batch_size=batch_size)
    train_loader = DataLoader(train_data, batch_size=batch_size, 
                              shuffle=True, generator=g_loader)

    # define optimization algorithm
    # ReduceLROnPlateau multiplies the LR by lr_factor after lr_patience
    # steps without improvement of the metric train_test passes to step(),
    # never going below lr_min.
    # NOTE(review): recent PyTorch releases deprecate the `verbose` argument
    # of LR schedulers -- check the installed version.
    optimizer = torch.optim.Adam(model.parameters(), lr=lr_init)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, 
        factor=lr_factor, 
        patience=lr_patience, 
        min_lr=lr_min, verbose=True)

    # training and testing
    anneal = True
    train_test = model.train_test(num_epochs, 
                                  optimizer, scheduler, 
                                  train_loader, test_loader, 
                                  anneal, anneal_epochs, verbose=False)
    # Assumes 'test_loss_hist' has shape (num_epochs+1, 3) to fit this row
    # -- TODO confirm.
    test_loss_hist[i_num_layers] = train_test['test_loss_hist']
    best_test_loss[i_num_layers] = train_test['best_test_loss']
    best_epoch[i_num_layers]     = train_test['best_epoch']

    # revert to the best model
    # NOTE(review): if train_test stored 'best_model_state' as a live
    # model.state_dict() without copying, this reload is a no-op -- confirm.
    model.load_state_dict(train_test['best_model_state'])
    # state_dict() holds references to this model's tensors (on `device`);
    # a new model is built every iteration, so entries do not alias each other.
    best_model[f'K{num_layers}'] = model.state_dict()

    # estimate -ELBO, -logp(x), and KL[ q(z|x) || p(z|x) ]
    # IS_size is forwarded to model.evaluate (presumably the number of
    # importance samples -- confirm).
    est = model.evaluate(test_data, IS_size, reduction='mean')
    evaluate[i_num_layers] = est     # tensor (3,)
    print(f"Estimates are {est} at epoch {train_test['best_epoch']} \n")
    
    
#------------------------------------------------------------ save 
import os
import pickle

# All outputs share one filename stem. Note that `{num_layers_lst}` renders
# the Python list repr (brackets, commas and spaces) into the name,
# e.g. "...-[1, 2, 4]-...".
run_stem = f'r{replicate}-{latent_dim}d-{num_layers_lst}-{reparam}_'

result = {'test_loss_hist':test_loss_hist,
          'best_test_loss':best_test_loss,
          'best_epoch':best_epoch,
          'evaluate':evaluate}

# The output directories are relative to the working directory and may not
# exist yet; create them so saving cannot fail after training has finished.
os.makedirs('pickle', exist_ok=True)
os.makedirs('pt', exist_ok=True)

# Summary tensors: loss histories, best test losses/epochs, final estimates.
file = f'pickle/{run_stem}.pickle'
with open(file, 'wb') as f:
    pickle.dump(result, f, pickle.HIGHEST_PROTOCOL)

# Best state_dict per flow depth, keyed 'K<num_layers>'. The tensors keep the
# device they were trained on, so loading a CUDA run elsewhere may need
# torch.load(..., map_location=...).
PATH = f'pt/{run_stem}.pt'
torch.save(best_model, PATH)

# Total wall-clock runtime in hours, rounded to one decimal.
print(f'Total time for running this script  '
      f'{round((time()-t0_script)/3600, 1)} hrs')

