from __future__ import print_function
import argparse
import time
import torch
import torch.utils.data
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
import math
import random
import os
import datetime
import sys

sys.path.append("./")
sys.path.append("../integer_discrete_flows")
sys.path.append("../EinsumNetworks")

from prep_idf import prep_idf
from optimization.training import train, evaluate
import models.Model as Model


def main(gpu_idx=1, num_epochs=2000):
    """Train the model built by ``models.Model`` and export its latent codes.

    Data loaders and ``args`` come from ``prep_idf()``. Each epoch calls
    ``train`` and then advances the learning-rate schedule. During the first
    four epochs, and afterwards every ``args.evaluate_interval_epochs``
    epochs, the model is evaluated; whenever the mean validation bpd improves,
    the model and optimizer are saved under ``data/`` and the latent codes of
    the train/validation/test loaders are regenerated with
    ``gen_latent_code``.

    Args:
        gpu_idx: Currently unused; kept so existing callers keep working.
        num_epochs: Number of training epochs to run.

    Raises:
        ValueError: If the validation loss returned by ``evaluate`` is NaN.
    """
    train_loader, val_loader, test_loader, args = prep_idf()

    ## Generate model ##
    model = Model.Model(args)
    # Fall back to the CPU so the script still starts on hosts without CUDA.
    args.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.set_temperature(args.temperature)
    model.enable_hard_round(args.hard_round)

    ## Data-dependent initialization on CPU ##
    # One forward pass on the first training batch, done before the model is
    # moved to args.device.
    for data, _ in train_loader:
        model(data)
        break

    if torch.cuda.device_count() > 1:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        model = torch.nn.DataParallel(model, dim=0)

    ## Move model to the target device ##
    model.to(args.device)

    ## Optimizer ##
    def lr_lambda(epoch):
        # Linear warm-up over args.warmup epochs, then exponential decay.
        return min(1., (epoch + 1) / args.warmup) * np.power(args.lr_decay, epoch)

    optimizer = optim.Adamax(model.parameters(), lr=args.learning_rate, eps=1.e-7)
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda, last_epoch=-1)

    # Every artifact below is written into ./data; make sure it exists so the
    # first checkpoint does not fail after a full epoch of training.
    os.makedirs('data', exist_ok=True)

    ## Train ##
    best_train_bpd, best_val_bpd = np.inf, np.inf
    for epoch in range(1, num_epochs + 1):
        t_start = time.time()
        tr_loss, tr_bpd = train(epoch, train_loader, model, optimizer, args)
        scheduler.step()
        mean_tr_bpd = np.mean(tr_bpd)
        if mean_tr_bpd < best_train_bpd:
            best_train_bpd = mean_tr_bpd
        print('One training epoch took %.2f seconds' % (time.time()-t_start))

        if epoch < 5 or epoch % args.evaluate_interval_epochs == 0:
            # The same model is passed for both model arguments of evaluate().
            v_loss, v_bpd = evaluate(
                train_loader, val_loader, model, model, args,
                epoch=epoch, file=None)

            mean_v_bpd = np.mean(v_bpd)
            if mean_v_bpd < best_val_bpd:
                # Only a DataParallel wrapper has `.module`; on a single
                # device (or CPU) the model itself is what must be saved.
                if isinstance(model, torch.nn.DataParallel):
                    bare_model = model.module
                else:
                    bare_model = model
                torch.save(bare_model, 'data/a.model')
                torch.save(optimizer, 'data/a.optimizer')

                best_val_bpd = mean_v_bpd
                print("Generate latent codes for train")
                gen_latent_code(model, train_loader, args, file_name="data/train_zs.npz")
                print("Generate latent codes for validation")
                gen_latent_code(model, val_loader, args, file_name="data/valid_zs.npz")
                print("Generate latent codes for test")
                gen_latent_code(model, test_loader, args, file_name="data/test_zs.npz")

            print('(BEST: train bpd {:.4f}, validation bpd {:.4f})\n'.format(
                best_train_bpd, best_val_bpd))

            if math.isnan(v_loss):
                raise ValueError('NaN encountered!')
                

def gen_latent_code(model, data_loader, args, file_name = None):
    """Run ``model`` over ``data_loader`` and collect its latent outputs.

    For every batch the model must return the 8-tuple
    ``(loss, batch_bpd, bpd_per_prior, pz, z, pys, ys, ldj)``. ``z``,
    ``ys[0]``, ``ys[1]`` and the two entries of ``pys[0]`` / ``pys[1]`` are
    concatenated over all batches along dim 0. For ``pz`` only the first batch
    element of the last batch is kept (presumably the prior parameters of
    ``z`` do not depend on the sample -- TODO confirm).

    The log-probabilities under a discretized logistic (bin width 1/256) and
    the corresponding bits-per-dimension values are printed.

    Args:
        model: Callable returning the 8-tuple described above.
        data_loader: Iterable yielding ``(data, label)`` pairs; labels are
            ignored.
        args: Needs ``device`` (where the batch is moved) and ``input_size``
            (per-sample shape used to reshape the batch).
        file_name: If not None, all collected arrays are written there with
            ``np.savez``.
    """

    def log_min_exp(a, b, epsilon=1e-8):
        # log(exp(a) - exp(b)); epsilon guards the inner log.
        return a + torch.log(1 - torch.exp(b - a) + epsilon)

    def log_discretized_logistic(x, mean, logscale, inverse_bin_width):
        # Log of the logistic CDF mass between the two edges of x's bin.
        scale = torch.exp(logscale)
        half_bin = 0.5 / inverse_bin_width
        return log_min_exp(
            F.logsigmoid((x + half_bin - mean) / scale),
            F.logsigmoid((x - half_bin - mean) / scale))

    def to_cpu(tensor):
        return tensor.detach().cpu().clone()

    def concat(empty_shape, chunks):
        # A single concatenation instead of re-copying a growing tensor on
        # every batch. The leading zero-row tensor fixes the expected
        # per-sample shape and is the result for an empty data loader.
        return torch.cat([torch.zeros(empty_shape)] + chunks, dim = 0)

    def to_bpd(log_p):
        # Nats -> bits, divided by 32 * 32 * 3 (hard-coded; presumably the
        # number of input dimensions -- verify against args.input_size).
        return -log_p * np.log(np.e) / np.log(2.0) / 32 / 32 / 3

    with torch.no_grad():
        z_chunks, y1_chunks, y2_chunks = [], [], []
        py1_chunks = [[], []]
        py2_chunks = [[], []]
        pzs = [torch.zeros([48, 4, 4]), torch.zeros([48, 4, 4])]

        for data, _ in data_loader:
            data = data.to(args.device)
            data = data.view(-1, *args.input_size)
            loss, batch_bpd, bpd_per_prior, pz, z, pys, ys, ldj = model(data)

            z_chunks.append(to_cpu(z))
            y1_chunks.append(to_cpu(ys[0]))
            y2_chunks.append(to_cpu(ys[1]))

            for k in range(2):
                # Only the first element of the batch is kept for pz.
                pzs[k] = to_cpu(pz[k][0, :, :, :])
                py1_chunks[k].append(to_cpu(pys[0][k]))
                py2_chunks[k].append(to_cpu(pys[1][k]))

        zs = concat([0, 48, 4, 4], z_chunks)
        ys1 = concat([0, 6, 16, 16], y1_chunks)
        ys2 = concat([0, 12, 8, 8], y2_chunks)
        pys1 = [concat([0, 6, 16, 16], chunks) for chunks in py1_chunks]
        pys2 = [concat([0, 12, 8, 8], chunks) for chunks in py2_chunks]

        log_pzs = log_discretized_logistic(zs, pzs[0].unsqueeze(0), pzs[1].unsqueeze(0), 256.0).mean(dim = 0).sum()
        log_pys1 = log_discretized_logistic(ys1, pys1[0], pys1[1], 256.0).mean(dim = 0).sum()
        log_pys2 = log_discretized_logistic(ys2, pys2[0], pys2[1], 256.0).mean(dim = 0).sum()
        bpd_pzs = to_bpd(log_pzs)
        bpd_pys1 = to_bpd(log_pys1)
        bpd_pys2 = to_bpd(log_pys2)
        print(f"log_pzs {log_pzs}; log_pys1 {log_pys1}; log_pys2 {log_pys2}")
        print(f"bpd_pzs {bpd_pzs}; bpd_pys1 {bpd_pys1}; bpd_pys2 {bpd_pys2}")

    if file_name is not None:
        np.savez(file_name,
                 zs = zs.numpy(), ys1 = ys1.numpy(), ys2 = ys2.numpy(),
                 pzs_0 = pzs[0].numpy(), pzs_1 = pzs[1].numpy(),
                 pys1_0 = pys1[0].numpy(), pys1_1 = pys1[1].numpy(),
                 pys2_0 = pys2[0].numpy(), pys2_1 = pys2[1].numpy())
                

# Script entry point: run training with main()'s default arguments.
if __name__ == "__main__":
    main()
