#!/usr/bin/env python3
# -*- coding: utf-8 -*-

# ----------------------------------------
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'  # avoid tensorflow warning
import tensorflow as tf
import numpy as np
import re

# input parameters --------------------------------------
from util.input_args import input_params
# Read the experiment options. `param` is the option object's attribute
# dictionary (vars() returns the live __dict__), so attributes assigned on
# `p` further down are also visible through `param`.
p = input_params()
param = vars(p)

# `par` holds the extra divergence parameter; it stays empty when no alpha
# is given. Note that a falsy alpha (None or 0) takes the "no alpha" branch.
if p.alpha:
    par = [p.alpha]
    p.exptype = '%s=%05.2f-%s' % (p.f, p.alpha, p.Gamma)
else:
    par = []
    p.exptype = '%s-%s' % (p.f, p.Gamma)

# A missing L is labelled 'inf' in the experiment name.
# Compare against the None singleton by identity (PEP 8), not equality.
if p.L is None:
    p.expname = '%s_%s' % (p.exptype, 'inf')
else:
    p.expname = '%s_%.2f' % (p.exptype, p.L)

# Data generation ----------------------------------------
from util.generate_data import generate_data

# generate_data returns the (possibly updated) options together with the
# two sample sets and their labels.
p, X_, Y_, X_label, Y_label = generate_data(p)

# Dimension of the space the particles are transported in. For autoencoder
# datasets this is the latent dimension, not the data dimension.
N_dim = p.N_dim
if 'ae' in p.dataset:
    N_dim = p.N_latent_dim  # NOTE! N_dim != p.N_dim
    from util.autoencoder import load_autoencoder, preprocess_input, postprocess_output

    # The saved model is looked up by the dataset name without its '_ae'
    # suffix and by the latent dimension.
    data_name = p.dataset.split('_ae')
    ae_filepath = 'util/saved_model/%s_deep_%d' % (data_name[0], N_dim)
    encoder = load_autoencoder(ae_filepath, encoder=True)
    decoder = load_autoencoder(ae_filepath, encoder=False)

    # X__ / Y__ are the encoded versions of X_ / Y_.
    X__ = encoder.predict(preprocess_input(X_))
    if p.sample_latent == False:
        Y__ = encoder.predict(preprocess_input(Y_))
    else:
        # Y_ already holds latent samples: use them directly and replace Y_
        # by their decoded counterpart.
        Y__ = Y_
        Y_ = postprocess_output(decoder.predict(Y_), p.N_dim)
else:
    # No autoencoder: the particles are the raw samples. Without this
    # fallback X__ / Y__ would be undefined below for non-'ae' datasets.
    X__, Y__ = X_, Y_

Q = tf.constant(X__, dtype=tf.float32)  # constant
P = tf.Variable(Y__, dtype=tf.float32)  # variable

# Labels are only turned into tensors when there is more than one condition.
if p.N_conditions > 1:
    Q_label = tf.constant(X_label, dtype=tf.float32)
    P_label = tf.constant(Y_label, dtype=tf.float32)
else:
    Q_label, P_label = None, None

data_par = {
    'P_label': P_label,
    'Q_label': Q_label,
    'mb_size_P': p.mb_size_P,
    'mb_size_Q': p.mb_size_Q,
    'N_samples_P': p.N_samples_P,
    'N_samples_Q': p.N_samples_Q,
}

# Discriminator construction using Neural Network -----------------
from util.construct_NN import check_nn_topology, initialize_NN, model

# check_nn_topology is given the working dimension N_dim and returns the
# layer specifications and activation that are used from here on.
topology = check_nn_topology(
    p.NN_model,
    p.N_fnn_layers,
    p.N_cnn_layers,
    N_dim,
    p.activation_ftn,
)
N_fnn_layers, N_cnn_layers, p.activation_ftn = topology

# NOTE(review): 'N_dim' below is p.N_dim, whereas the topology check above
# used the local N_dim (which can differ for autoencoder datasets) - confirm
# this is intended.
NN_par = {
    'NN_model': p.NN_model,
    'activation_ftn': p.activation_ftn,
    'N_dim': p.N_dim,
    'N_cnn_layers': N_cnn_layers,
    'N_fnn_layers': N_fnn_layers,
    'N_conditions': p.N_conditions,
    'constraint': p.constraint,
    'L': p.L,
}

# Discriminator parameters (W, b) and the callable phi built from NN_par.
W, b = initialize_NN(NN_par)
phi = model(NN_par)

# Loss & Loss_first_variation ------------------------------
from util.losses import divergence, divergence_large, first_variation, gradient_penalty
loss_par = {
    'f': p.f,
    'formulation': p.formulation,
    'par': par,
    'reverse': p.reverse,
}

# scalar optimal value optimization for f-divergence
nu = tf.Variable(0.0, dtype=tf.float32)

# Train -------------------------------------------------
from util.train_NN import decay_learning_rate, sgd, adam

# Learning rates are kept as non-trainable variables; lr_Ps records the
# particle learning rate used at every iteration.
lr_NN = tf.Variable(p.lr_NN, trainable=False)
lr_P = tf.Variable(p.lr_P, trainable=False)
lr_Ps = []

from util.evaluate_metric import calc_fid, calc_ke
import time

# Per-iteration / per-snapshot records that are saved after training.
trajectories = []   # particle snapshots mapped back to data space
vectorfields = []   # particle gradients dP at the snapshot iterations
divergences = []    # loss value at the start of every outer iteration
KE_Ps = []          # kinetic energy of the particles (from calc_ke)
FIDs = []           # FID of the snapshots (image data only)


def _discriminator_loss():
    """Return the divergence estimate plus gradient penalty.

    Reads the module-level W, b, nu and P at call time, so it always uses
    the values most recently produced by the optimizer step.
    """
    penalty = gradient_penalty(phi, P, Q, W, b, NN_par, data_par, p.lamda)
    # `divergence` is used when a minibatch would cover a whole sample set;
    # otherwise `divergence_large` is used.
    if p.mb_size_P >= p.N_samples_P or p.mb_size_Q >= p.N_samples_Q:
        estimate = divergence(phi, nu, P, Q, W, b, NN_par, loss_par, data_par)
    else:
        estimate = divergence_large(phi, nu, P, Q, W, b, NN_par, loss_par, data_par)
    return estimate + penalty


def _to_data_space(particles):
    """Map particles back to data space; decode only for autoencoder datasets."""
    if 'ae' in p.dataset:
        return postprocess_output(decoder.predict(particles), p.N_dim)
    return particles


# FID is only evaluated for image-sized data with a known reference model.
# Leaving the name as None skips FID instead of hitting an unbound variable.
is_image_data = np.prod(p.N_dim) >= 64
if 'MNIST' in p.dataset:
    fid_model_name = 'autoencoder_mnist'
elif 'CIFAR10' in p.dataset:
    fid_model_name = 'autoencoder_cifar10'
else:
    fid_model_name = None
Q_ = None  # data-space version of Q; Q is constant, so it is built only once

# Integer reporting interval (about ten reports per run). A float interval
# such as epochs/10 never matches for epoch counts like 33.
display_every = max(1, p.epochs // 10)
# Stays NaN if the discriminator loop never runs (e.g. epochs_nn == 0).
dW_norm = float('nan')

t0 = time.time()
for it in range(1, p.epochs+1):  # Loop for updating particles P
    # Loss before this iteration's discriminator update; recorded only.
    current_loss = _discriminator_loss().numpy()

    for in_it in range(p.epochs_nn):  # Loop for training NN discriminator phi*
        with tf.GradientTape(watch_accessed_variables=False) as tape:
            tape.watch([W, b, nu])
            loss = _discriminator_loss()
        dW, db, dnu = tape.gradient(loss, [W, b, nu])

        # descent=False: the discriminator parameters take an ascent step.
        if p.optimizer == 'sgd':
            W, b, nu, dW_norm = sgd(W, b, nu, dW, db, dnu, lr_NN, NN_par, descent=False, calc_dW_norm=True)  # update phi(W,b)
        elif p.optimizer == 'adam':
            if in_it == 0:
                # Adam moment estimates restart with every outer iteration.
                m_W, v_W, m_b, v_b, m_nu, v_nu = [0]*len(W), [0]*len(W), [0]*len(W), [0]*len(W), [0], [0]
            W, b, nu, dW_norm, m_W, v_W, m_b, v_b, m_nu, v_nu = adam(W, b, nu, dW, db, dnu, m_W, v_W, m_b, v_b, m_nu, v_nu, lr_NN, NN_par, in_it, descent=False, calc_dW_norm=True)  # update phi(W,b)

    # Move the particles along the gradient of the first variation.
    with tf.GradientTape(watch_accessed_variables=False) as tape:
        tape.watch(P)
        loss_first_variation = first_variation(phi, P, W, b, NN_par, loss_par, data_par)
    dP = tape.gradient(loss_first_variation, P)
    P.assign(P - lr_P*dP)  # update P

    # save results
    divergences.append(current_loss)
    KE_P = calc_ke(dP, p.N_samples_P)
    KE_Ps.append(KE_P)

    # Short runs are snapshotted every iteration, long runs every save_iter.
    if p.epochs <= 100 or it % p.save_iter == 0:
        # reconstruct real space variable
        P_ = _to_data_space(P.numpy())
        trajectories.append(P_)
        if np.prod(p.N_dim) < 100 and 'ae' not in p.dataset:
            vectorfields.append(dP.numpy())
        if is_image_data and fid_model_name is not None:
            if Q_ is None:
                Q_ = _to_data_space(Q.numpy())
            FIDs.append(calc_fid(pred=P_, real=Q_, model_name=fid_model_name))

    # adjust learning rates (decay starts at iteration 100)
    lr_Ps.append(lr_P.numpy())
    if it >= 100:
        lr_P = decay_learning_rate(lr_P, p.lr_P_decay, {'epochs': p.epochs-100, 'epoch': it-100, 'KE_P': KE_P})

    # display intermediate results
    if it % display_every == 0:
        display_msg = 'iter %6d: loss = %.10f, grad_norm of W = %.2f, kinetic energy of P = %.10f, learning rate for P = %.6f' % (it, current_loss, dW_norm, KE_P, lr_P.numpy())
        if len(FIDs) > 0:
            display_msg = display_msg + ', FID = %.3f' % FIDs[-1]
        print(display_msg)


total_time = time.time() - t0
print(f'total time {total_time:.3f}s')

# Save result ------------------------------------------------------
import pickle

# Results go into a directory named after the dataset. exist_ok=True avoids
# the check-then-create race of testing os.path.exists first.
os.makedirs(p.dataset, exist_ok=True)


def _append_zero_columns(x):
    """Concatenate a zero array of the same shape as x along axis 1."""
    return np.concatenate((x, np.zeros(shape=x.shape)), axis=1)


# 1D datasets are padded with zeros along axis 1 before being stored.
if '1D' in p.dataset:
    X_ = _append_zero_columns(X_)
    Y_ = _append_zero_columns(Y_)

    trajectories = [_append_zero_columns(x) for x in trajectories]
    vectorfields = [_append_zero_columns(x) for x in vectorfields]

param.update({'X_': X_, 'Y_': Y_, 'lr_Ps': lr_Ps})

# The final experiment name also encodes sample counts, seed and run number.
p.expname = p.expname+'_%04d_%04d_%02d_%s' % (p.N_samples_Q, p.N_samples_P, p.random_seed, p.exp_no)
filename = p.dataset+'/%s.pickle' % (p.expname)

result = {
    'trajectories': trajectories,
    'vectorfields': vectorfields,
    'divergences': divergences,
    'KE_Ps': KE_Ps,
    'FIDs': FIDs,
}

## Save trained data
# The pickle holds a two-element list: [param, result].
with open(filename, "wb") as fw:
    pickle.dump([param, result], fw)
print("Results saved at:", filename)

# plot -----------------------------------------------------------
if p.plot_result == True:
    # Velocity plotting is switched off when no vector fields were recorded
    # or for 1D/2D datasets. (The flag is not passed to any call below.)
    plot_velocity = not (len(vectorfields) == 0 or '1D' in p.dataset or '2D' in p.dataset)

    quantile = True

    # Keyword arguments shared by every plot_losses call in this section.
    # The plotting helpers are handed the pickle path, not in-memory data.
    losses_kwargs = dict(
        loss_state=None,
        plot_scale='semilogy',
        fitting_line=False,
        save_iter=1,
        dt=None,
        iter_nos=None,
        exp_alias_=None,
        epochs=0,
        ylims=None,
        physical_time=True,
        filename=filename,
    )

    if 'MNIST' in filename or 'CIFAR10' in filename:
        from plot_result import plot_losses, plot_trajectories_img, plot_trained_img, images_to_animation, plot_tiled_images
        epochs = 0
        iter_nos = None

        images_to_animation(
            trajectories=None,
            N_samples_P=None,
            dt=None,
            physical_time=True,
            pick_samples=None,
            epochs=epochs,
            save_gif=True,
            filename=filename,
        )
        plot_trajectories_img(
            X_=None,
            Y_=None,
            trajectories=None,
            dt=None,
            pick_samples=None,
            epochs=epochs,
            iter_nos=iter_nos,
            physical_time=True,
            filename=filename,
        )
        plot_trained_img(
            X_=None,
            trajectories=None,
            pick_samples=None,
            epochs=0,
            filename=filename,
        )

        for loss_type in ('divergences', 'KE_Ps', 'FIDs'):
            plot_losses(loss_type=loss_type, **losses_kwargs)

        if 'all' in filename:  # conditional gpa
            plot_tiled_images(
                print_multiplier=10,
                last_traj=None,
                last_digit=None,
                epochs=0,
                data=None,
                data_label=None,
                filename=filename,
            )

    ## 2D embedded in high dimensions examples
    elif 'submnfld' in filename:
        from plot_result import plot_losses, plot_trajectories, plot_speeds, plot_orth_axes_saturation, plot_initial_data
        iter_nos = None
        exp_alias_ = None
        track_velocity = True
        iscolor = True
        quantile = True

        plot_initial_data(
            proj_axes=[5, 6],
            x_lim=[None, None],
            y_lim=[None, None],
            filename=filename,
        )
        plot_trajectories(
            trajectories=None,
            dt=None,
            X_=None,
            Y_=None,
            r_param=None,
            vectorfields=[],
            proj_axes=[5, 6],
            pick_samples=None,
            epochs=0,
            iter_nos=None,
            physical_time=True,
            save_iter=1,
            track_velocity=track_velocity,
            arrow_scale=1,
            iscolor=iscolor,
            quantile=quantile,
            exp_alias_=exp_alias_,
            x_lim=[None, None],
            y_lim=[None, None],
            filename=filename,
        )
        plot_orth_axes_saturation(
            N_dim=None,
            Y_=None,
            trajectories=None,
            save_iter=1,
            dt=None,
            proj_axes=[5, 6],
            epochs=0,
            iter_nos=None,
            physical_time=True,
            filename=filename,
        )
        plot_speeds(
            vectorfields=None,
            N_dim=None,
            save_iter=1,
            dt=None,
            plot_scale='semilogy',
            proj_axes=[5, 6],
            physical_time=True,
            epochs=0,
            filename=filename,
        )
        for loss_type in ('divergences', 'KE_Ps'):
            plot_losses(loss_type=loss_type, **losses_kwargs)

    ## low dimensional example
    elif filename != "":
        from plot_result import plot_trajectories, plot_losses, plot_initial_data
        iter_nos = None
        exp_alias_ = None
        track_velocity = True
        iscolor = True
        quantile = True

        plot_initial_data(
            proj_axes=[0, 1],
            x_lim=[None, None],
            y_lim=[None, None],
            filename=filename,
        )
        plot_trajectories(
            trajectories=None,
            dt=None,
            X_=None,
            Y_=None,
            r_param=None,
            vectorfields=[],
            proj_axes=[0, 1],
            pick_samples=None,
            epochs=0,
            iter_nos=None,
            physical_time=True,
            save_iter=1,
            track_velocity=track_velocity,
            arrow_scale=1,
            iscolor=iscolor,
            quantile=quantile,
            exp_alias_=exp_alias_,
            x_lim=[None, None],
            y_lim=[None, None],
            filename=filename,
        )
        for loss_type in ('divergences', 'KE_Ps'):
            plot_losses(loss_type=loss_type, **losses_kwargs)
