from __future__ import print_function
import argparse
import time
import torch
import torch.nn as nn
import torch.utils.data
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
import math
import random
import os
import datetime
import sys
import warnings
warnings.filterwarnings("ignore")
# os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

sys.path.append("./")
sys.path.append("../integer_discrete_flows")
sys.path.append("../EinsumNetworks/src")

from prep_idf import prep_idf
from optimization.training import train, evaluate
import models.Model as Model

from EinsumNetwork import Graph, EinsumNetwork

from training import my_train_multi_gpus, my_evaluate_multi_gpus


def _resolve_log_file_name(args, num_repetitions, num_sums, num_input_distributions):
    """Return the path of the text log for this run.

    When ``args.log_file_name`` is the sentinel string ``"none"`` the name is
    derived from the circuit hyperparameters; otherwise the user-supplied name
    is placed under ``logs/``.
    """
    if args.log_file_name == "none":
        return "logs/multigpus_pcflow_{}_{}_{}.txt".format(
            num_repetitions, num_sums, num_input_distributions)
    return "logs/" + args.log_file_name


def _append_log(log_file_name, *lines):
    """Append each entry of ``lines`` to ``log_file_name``, one per line."""
    with open(log_file_name, "a+") as f:
        for line in lines:
            f.write(line + "\n")


def _first_batch(loader):
    """Return the first batch yielded by ``loader`` (a fresh iteration each call).

    Raises:
        ValueError: if ``loader`` yields nothing. Previously an empty loader
            silently skipped data-dependent initialization.
    """
    for batch in loader:
        return batch
    raise ValueError("data loader is empty; cannot run data-dependent initialization")


def _make_lr_lambda(warmup, lr_decay):
    """Build the ``LambdaLR`` multiplier: linear warmup times exponential decay.

    The returned function computes
    ``min(1, (epoch + 1) / warmup) * lr_decay ** epoch``.
    A non-positive ``warmup`` disables the ramp instead of dividing by zero.
    """
    def lr_lambda(epoch):
        if warmup > 0:
            # float() keeps this a true division even under Python 2 semantics,
            # otherwise the warmup multiplier would truncate to 0.
            ramp = min(1., (epoch + 1) / float(warmup))
        else:
            ramp = 1.
        return ramp * np.power(lr_decay, epoch)
    return lr_lambda


def _build_einsum_models(num_repetitions, num_sums, num_input_distributions,
                         online_em_stepsize, device):
    """Create one initialized EinsumNetwork per latent tensor and move it to ``device``.

    Returns:
        list of three EinsumNetwork instances, in the order of ``shapes`` below.
    """
    # (channels, height, width) of the three latents modeled by the circuits.
    # NOTE(review): hard-coded; presumably matches the flow output printed by
    # main() for the configured dataset — confirm when changing data/flow config.
    shapes = [(48, 8, 8), (6, 32, 32), (12, 16, 16)]

    einsum_models = []
    for shape in shapes:
        # Derive the variable count from the shape so the two cannot drift apart.
        num_var = int(np.prod(shape))
        # Other structures tried: Graph.random_binary_trees and
        # Graph.poon_domingos_structure(shape=shape, axes=[0, 1], delta=[8]).
        # NOTE(review): depth uses the natural log, not log2 — confirm intended.
        graph = Graph.closeness_binary_trees(num_var=num_var, shape=shape,
                                             depth=int(np.ceil(np.log(num_var))),
                                             num_repetitions=num_repetitions)

        einet_args = EinsumNetwork.Args(num_var=num_var, num_dims=1, num_classes=1,
                                        num_input_distributions=num_input_distributions,
                                        num_sums=num_sums,
                                        exponential_family=EinsumNetwork.LogisticArray,
                                        exponential_family_args={'scale_min': 1e-6, 'scale_max': 1.0},
                                        online_em_frequency=1,
                                        online_em_stepsize=online_em_stepsize,
                                        uniform_params=False)

        einet = EinsumNetwork.EinsumNetwork(graph, einet_args)
        einet.initialize()
        einet.to(device)

        print(einet)
        print(sum(p.numel() for p in einet.parameters()))

        einsum_models.append(einet)
    return einsum_models


def main(num_epochs=4000):
    """Train the IDF flow model jointly with EinsumNetwork priors on two GPUs.

    Data loaders and ``args`` come from ``prep_idf()``. The flow runs on
    ``cuda:1`` and the einsum networks on ``cuda:0``. Per-epoch training
    metrics are appended to a log file. Every second epoch, validation and
    test bpd are evaluated and logged. The reported test bpd is the one at the
    best validation bpd.

    Args:
        num_epochs: number of training epochs (1-based loop, inclusive).
    """
    train_loader, val_loader, test_loader, args = prep_idf()

    #### Hyperparameters ####
    num_repetitions = args.num_repetitions
    num_sums = args.num_sums
    num_input_distributions = args.num_input_distributions
    online_em_stepsize = args.online_em_stepsize
    #########################

    log_file_name = _resolve_log_file_name(args, num_repetitions, num_sums,
                                           num_input_distributions)
    # Create the log directory up front so the first write does not fail.
    log_dir = os.path.dirname(log_file_name)
    if log_dir:
        os.makedirs(log_dir, exist_ok=True)

    #### Generate flow model ####
    args.num_prior_leaf_nodes = num_input_distributions * num_repetitions

    flow_model = Model.Model(args)
    # NOTE(review): devices are hard-coded; args.gpu_idx is not consulted.
    args.flow_device = torch.device("cuda:1")
    args.einets_device = torch.device("cuda:0")
    args.cpu = torch.device("cpu")
    flow_model.set_temperature(args.temperature)
    flow_model.enable_hard_round(args.hard_round)

    ## Data-dependent initialization on CPU: one forward pass over one batch ##
    data, _ = _first_batch(train_loader)
    _, z, _, ys, _ = flow_model.forward_only(data)
    print(z.size(), ys[0].size(), ys[1].size())

    ## Move model to GPU ##
    flow_model = flow_model.to(args.flow_device)

    # Sanity check: the model still runs after being moved to the flow device.
    data, _ = _first_batch(train_loader)
    _, z, _, ys, _ = flow_model.forward_only(data.to(args.flow_device))
    print(z.size(), ys[0].size(), ys[1].size())

    ## Optimizer ##
    optimizer = optim.Adamax(flow_model.parameters(), lr=args.learning_rate, eps=1.e-7)
    scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer, _make_lr_lambda(args.warmup, args.lr_decay), last_epoch=-1)

    #### Generate Einsum Networks ####
    einsum_models = _build_einsum_models(num_repetitions, num_sums,
                                         num_input_distributions,
                                         online_em_stepsize, args.einets_device)

    #### Train ####
    best_train_bpd, best_val_bpd, best_test_bpd = np.inf, np.inf, np.inf
    for epoch in range(1, num_epochs + 1):
        t_start = time.time()

        tr_loss, tr_bpd = my_train_multi_gpus(epoch, flow_model, einsum_models,
                                              train_loader, optimizer, args)
        _append_log(log_file_name,
                    '====> Epoch: {:3d} Average train loss: {:.4f} Average bpd: {:.3f}'.format(
                        epoch, tr_loss, tr_bpd))

        scheduler.step()
        if tr_bpd < best_train_bpd:
            best_train_bpd = tr_bpd
        print('One training epoch took %.2f seconds' % (time.time() - t_start))

        if epoch % 2 != 0:
            continue

        val_bpd = my_evaluate_multi_gpus(epoch, flow_model, einsum_models, val_loader, args)
        test_bpd = my_evaluate_multi_gpus(epoch, flow_model, einsum_models, test_loader, args)
        _append_log(log_file_name,
                    '====> [eval] Epoch: {:3d} Average bpd: {:.3f}'.format(epoch, val_bpd),
                    '====> [test] Epoch: {:3d} Average bpd: {:.3f}'.format(epoch, test_bpd))

        # Model selection on validation bpd; report the matching test bpd.
        if val_bpd < best_val_bpd:
            best_val_bpd = val_bpd
            best_test_bpd = test_bpd

        print("Best val_bpd: {}".format(best_val_bpd))
        print("Best test_bpd: {}".format(best_test_bpd))
        _append_log(log_file_name,
                    "Best val_bpd: {}".format(best_val_bpd),
                    "Best test_bpd: {}".format(best_test_bpd))
        
        
if __name__ == "__main__":
    # Script entry point: run training with main()'s default epoch budget.
    main()