#!/usr/bin/env python3
# -*- coding: utf-8 -*-

# ----------------------------------------
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'  # avoid tensorflow warning
import tensorflow as tf
import numpy as np
import re

# input parameters --------------------------------------
from util.input_args import input_params
p = input_params()
# vars() returns p.__dict__ itself, so attributes set on p later (exptype,
# expname, ...) are also visible in `param`, which is pickled with the results.
param = vars(p)

# Experiment tag built from the divergence name f, the optional scalar alpha
# (forwarded to the loss through `par`) and Gamma.
# NOTE: alpha == 0 is treated the same as "no alpha" because of the truthiness test.
if p.alpha:
    par = [p.alpha]
    p.exptype = '%s=%05.2f-%s' % (p.f, p.alpha, p.Gamma)
else:
    par = []
    p.exptype = '%s-%s' % (p.f, p.Gamma)

# Experiment name additionally records L; L = None is labelled 'inf'
# (presumably an unbounded Lipschitz constant -- confirm in util.construct_NN).
if p.L is None:
    p.expname = '%s_%s' % (p.exptype, 'inf')
else:
    p.expname = '%s_%.2f' % (p.exptype, p.L)

# Data generation ----------------------------------------
from util.generate_data import generate_data
p, X_, Y_, X_label, Y_label = generate_data(p)

# Q: fixed target samples (constant). P: particles that are moved during
# training, hence a tf.Variable.
# BreastCancer samples are divided by 10 here; saved trajectories are scaled
# back up by 10 in the training loop.
if p.dataset not in ['BreastCancer',]:
    Q = tf.constant(X_, dtype=tf.float32)
    P = tf.Variable(Y_, dtype=tf.float32)
else:
    Q = tf.constant(X_/10.0, dtype=tf.float32)
    P = tf.Variable(Y_/10.0, dtype=tf.float32)

# Labels are only turned into tensors for conditional runs (N_conditions > 1).
Q_label, P_label = None, None
if p.N_conditions > 1:
    Q_label = tf.constant(X_label, dtype=tf.float32)
    P_label = tf.constant(Y_label, dtype=tf.float32)

# Sample/label bookkeeping handed to the loss functions.
data_par = dict(
    P_label=P_label,
    Q_label=Q_label,
    mb_size_P=p.mb_size_P,
    mb_size_Q=p.mb_size_Q,
    N_samples_P=p.N_samples_P,
    N_samples_Q=p.N_samples_Q,
)

# Discriminator construction using Neural Network -----------------
from util.construct_NN import check_nn_topology, initialize_NN, model

# Validate the requested topology; the (possibly adjusted) activation is
# stored back on p so it ends up in the saved parameters.
N_fnn_layers, N_cnn_layers, p.activation_ftn = check_nn_topology(
    p.NN_model, p.N_fnn_layers, p.N_cnn_layers, p.N_dim, p.activation_ftn)

NN_par = dict(
    NN_model=p.NN_model,
    activation_ftn=p.activation_ftn,
    N_dim=p.N_dim,
    N_cnn_layers=N_cnn_layers,
    N_fnn_layers=N_fnn_layers,
    N_conditions=p.N_conditions,
    constraint=p.constraint,
    L=p.L,
)

W, b = initialize_NN(NN_par)  # discriminator weights and biases
phi = model(NN_par)           # discriminator, called as phi(x, label, W, b, NN_par)

# Loss & Loss_first_variation ------------------------------
from util.losses import divergence, divergence_large, first_variation, gradient_penalty
loss_par = dict(f=p.f, formulation=p.formulation, par=par, reverse=p.reverse)

# Scalar auxiliary variable of the f-divergence objective; it is optimised
# jointly with W and b in the discriminator loop.
nu = tf.Variable(0.0, dtype=tf.float32)
        
# Train -------------------------------------------------
from util.train_NN import decay_learning_rate, sgd, adam
lr_NN = tf.Variable(p.lr_NN, trainable=False)  # discriminator step size
lr_P = tf.Variable(p.lr_P, trainable=False)    # particle step size (decayed during training)
lr_Ps = []                                     # lr_P history, one entry per epoch

from util.evaluate_metric import calc_fid, calc_ke, calc_grad_phi
# Histories filled by the training loop.
trajectories = []
vectorfields = []
divergences = []
KE_Ps = []
FIDs = []
# Snapshots of phi evaluated on the grid `xx` (only recorded for 1-D/2-D data).
# Bound unconditionally so later code can always refer to it.
phis = []

# Evaluation grid used to visualise the discriminator in low dimensions.
if p.N_dim == 1:
    xx = np.linspace(-10, 10, 300)
    xx = tf.constant(np.reshape(xx, (-1, 1)), dtype=tf.float32)
elif p.N_dim == 2:
    grid_1d = np.linspace(-10, 10, 40)
    XX, YY = np.meshgrid(grid_1d, grid_1d)
    # Flatten the 40x40 mesh into an (1600, 2) array of (x, y) points.
    xx = tf.constant(np.column_stack((XX.ravel(), YY.ravel())), dtype=tf.float32)

import time
t0 = time.time()

def _discriminator_loss(P, W, b, nu):
    """Return the penalised divergence objective for particles P and discriminator (W, b, nu).

    Uses `divergence` when either minibatch size reaches its sample count and
    `divergence_large` otherwise; the gradient penalty is added in both cases.
    Reads the module-level phi, Q, NN_par, loss_par, data_par and p.
    """
    penalty = gradient_penalty(phi, P, Q, W, b, NN_par, data_par, p.lamda)
    if p.mb_size_P >= p.N_samples_P or p.mb_size_Q >= p.N_samples_Q:
        div = divergence(phi, nu, P, Q, W, b, NN_par, loss_par, data_par)
    else:
        div = divergence_large(phi, nu, P, Q, W, b, NN_par, loss_par, data_par)
    return div + penalty


def _fid_model_name(dataset):
    """Return the autoencoder name calc_fid should use for `dataset`, or None if unsupported."""
    if 'MNIST' in dataset:
        return 'autoencoder_mnist'
    if 'CIFAR10' in dataset:
        return 'autoencoder_cifar10'
    return None


if p.optimizer not in ('sgd', 'adam'):
    # Without this check an unknown optimizer would silently skip discriminator training.
    raise ValueError("Unknown optimizer %r: expected 'sgd' or 'adam'" % (p.optimizer,))

# Reported in the progress message. It stays NaN if the inner loop never runs
# (epochs_nn == 0), instead of raising NameError.
dW_norm = float('nan')

for it in range(1, p.epochs+1):  # outer loop: update particles P
    # Objective before this epoch's discriminator updates; only recorded/logged.
    current_loss = _discriminator_loss(P, W, b, nu).numpy()

    # Inner loop: train the discriminator (W, b) and nu by gradient ascent (descent=False).
    for in_it in range(p.epochs_nn):
        with tf.GradientTape(watch_accessed_variables=False) as tape:
            tape.watch([W, b, nu])
            loss = _discriminator_loss(P, W, b, nu)
        dW, db, dnu = tape.gradient(loss, [W, b, nu])

        if p.optimizer == 'sgd':
            W, b, nu, dW_norm = sgd(W, b, nu, dW, db, dnu, lr_NN, NN_par, descent=False, calc_dW_norm=True)
        else:  # 'adam' (validated above)
            if in_it == 0:
                # Adam moment estimates are reset at the start of every outer epoch.
                m_W, v_W, m_b, v_b, m_nu, v_nu = [0]*len(W), [0]*len(W), [0]*len(W), [0]*len(W), [0], [0]
            W, b, nu, dW_norm, m_W, v_W, m_b, v_b, m_nu, v_nu = adam(
                W, b, nu, dW, db, dnu, m_W, v_W, m_b, v_b, m_nu, v_nu,
                lr_NN, NN_par, in_it, descent=False, calc_dW_norm=True)

    # Particle step: move P against the gradient of the first variation.
    with tf.GradientTape(watch_accessed_variables=False) as tape:
        tape.watch(P)
        loss_first_variation = first_variation(phi, P, W, b, NN_par, loss_par, data_par)
    dP = tape.gradient(loss_first_variation, P)
    P.assign(P - lr_P*dP)

    # save results
    divergences.append(current_loss)
    KE_P = calc_ke(dP, p.N_samples_P)
    KE_Ps.append(KE_P)

    if p.epochs <= 100 or it % p.save_iter == 0:
        if p.dataset in ['BreastCancer',]:
            trajectories.append(P.numpy()*10)  # undo the /10 scaling applied at load time
        else:
            trajectories.append(P.numpy())
        n_features = np.prod(p.N_dim)
        if n_features < 500:
            vectorfields.append(dP.numpy())
        elif n_features >= 784:  # image data
            fid_model_name = _fid_model_name(p.dataset)
            # No FID autoencoder is known for other image datasets; FID is not tracked then.
            if fid_model_name is not None:
                FIDs.append(calc_fid(pred=P.numpy(), real=Q.numpy(), model_name=fid_model_name))

    # adjust learning rates (decay starts after the first 100 epochs)
    lr_Ps.append(lr_P.numpy())
    if it >= 100:
        lr_P = decay_learning_rate(lr_P, p.lr_P_decay, {'epochs': p.epochs-100, 'epoch': it-100, 'KE_P': KE_P})

    # display intermediate results (10 times over the run)
    if it % (p.epochs/10) == 0:
        display_msg = 'iter %6d: loss = %.10f, grad_norm of W = %.2f, kinetic energy of P = %.10f, learning rate for P = %.6f' % (it, current_loss, dW_norm, KE_P, lr_P.numpy())
        if len(FIDs) > 0:
            display_msg = display_msg + ', FID = %.3f' % FIDs[-1]
        print(display_msg)

        if p.N_dim == 1 or p.N_dim == 2:
            zz = phi(xx, None, W, b, NN_par).numpy()
            zz = np.reshape(zz, -1)
            phis.append(zz)

total_time = time.time() - t0
print(f'total time {total_time:.3f}s')

# Save result ------------------------------------------------------
import pickle
os.makedirs(p.dataset, exist_ok=True)

if '1D' in p.dataset:
    # Append a zero block of the same shape to every sample array; for (N, 1)
    # data this yields (N, 2), presumably so the 2-D plotting code can be reused.
    X_ = np.concatenate((X_, np.zeros(shape=X_.shape)), axis=1)
    Y_ = np.concatenate((Y_, np.zeros(shape=Y_.shape)), axis=1)

    trajectories = [np.concatenate((x, np.zeros(shape=x.shape)), axis=1) for x in trajectories]
    vectorfields = [np.concatenate((x, np.zeros(shape=x.shape)), axis=1) for x in vectorfields]

param.update({'X_': X_, 'Y_': Y_, 'lr_Ps': lr_Ps})

p.expname = p.expname+'_%04d_%04d_%02d_%s' % (p.N_samples_Q, p.N_samples_P, p.random_seed, p.exp_no)
filename = p.dataset+'/%s.pickle' % (p.expname)

result = {'trajectories': trajectories, 'vectorfields': vectorfields, 'divergences': divergences, 'KE_Ps': KE_Ps, 'FIDs': FIDs,}

# Export the final particle positions as CSV. The previous code referenced the
# undefined name `p2`. This is skipped when no trajectory was stored.
if p.dataset in ['BreastCancer',] and trajectories:
    out_dir = "gene_expression_example/GPL570/" + p.dataset
    os.makedirs(out_dir, exist_ok=True)
    np.savetxt(out_dir + '/output_norm_dataset_dim_%d.csv' % p.N_dim, trajectories[-1], delimiter=",")

## Save trained data
with open(filename, "wb") as fw:
    pickle.dump([param, result], fw)
print("Results saved at:", filename)

# plot -----------------------------------------------------------
# Explicit comparison kept: the type of plot_result is decided by util.input_args.
if p.plot_result == True:
    # Keyword arguments shared by every plot_losses call below.
    loss_plot_kwargs = dict(loss_state=None, plot_scale='semilogy', fitting_line=False, save_iter=1,
                            dt=None, iter_nos=None, exp_alias_=None, epochs=0, ylims=None,
                            physical_time=True, filename=filename)

    if 'MNIST' in filename or 'CIFAR10' in filename:
        from plot_result import plot_losses, plot_trajectories_img, plot_trained_img, images_to_animation, plot_tiled_images
        epochs = 0
        iter_nos = None

        images_to_animation(trajectories=None, N_samples_P=None, dt=None, physical_time=True, pick_samples=None, epochs=epochs, save_gif=True, filename=filename)
        plot_trajectories_img(X_=None, Y_=None, trajectories=None, dt=None, pick_samples=None, epochs=epochs, iter_nos=iter_nos, physical_time=True, filename=filename)
        plot_trained_img(X_=None, trajectories=None, pick_samples=None, epochs=0, filename=filename)

        for loss_type in ('divergences', 'KE_Ps', 'FIDs'):
            plot_losses(loss_type=loss_type, **loss_plot_kwargs)

        if 'all' in filename:  # conditional gpa
            plot_tiled_images(print_multiplier=10, last_traj=None, last_digit=None, epochs=0, data=None, data_label=None, filename=filename)

    ## 2D embedded in high dimensions examples
    elif 'submnfld' in filename:
        from plot_result import plot_losses, plot_trajectories, plot_speeds, plot_orth_axes_saturation, plot_initial_data
        track_velocity = True
        iscolor = True
        quantile = True

        plot_initial_data(proj_axes=[5, 6], x_lim=[None, None], y_lim=[None, None], filename=filename)
        plot_trajectories(trajectories=None, dt=None, X_=None, Y_=None, r_param=None, vectorfields=[], proj_axes=[5, 6], pick_samples=None, epochs=0, iter_nos=None, physical_time=True, save_iter=1, track_velocity=track_velocity, arrow_scale=1, iscolor=iscolor, quantile=quantile, exp_alias_=None, x_lim=[None, None], y_lim=[None, None], filename=filename)
        plot_orth_axes_saturation(N_dim=None, Y_=None, trajectories=None, save_iter=1, dt=None, proj_axes=[5, 6], epochs=0, iter_nos=None, physical_time=True, filename=filename)
        plot_speeds(vectorfields=None, N_dim=None, save_iter=1, dt=None, plot_scale='semilogy', proj_axes=[5, 6], physical_time=True, epochs=0, filename=filename)
        for loss_type in ('divergences', 'KE_Ps'):
            plot_losses(loss_type=loss_type, **loss_plot_kwargs)

    ## low dimensional example
    elif filename != "":
        from plot_result import plot_trajectories, plot_losses, plot_initial_data
        try:
            # Optional: only drawn if plot_result provides it (it was previously
            # called without ever being imported).
            from plot_result import plot_output_target
        except ImportError:
            plot_output_target = None
        track_velocity = True
        iscolor = True
        quantile = True
        if 'student_t' in p.dataset:
            x_lim = [-30, 30]
            y_lim = [-30, 30]
        else:
            x_lim = [None, None]
            y_lim = [None, None]

        plot_initial_data(proj_axes=[0, 1], x_lim=x_lim, y_lim=y_lim, filename=filename)
        if plot_output_target is not None:
            plot_output_target(proj_axes=[0, 1], x_lim=x_lim, y_lim=y_lim, filename=filename)
        plot_trajectories(trajectories=None, dt=None, X_=None, Y_=None, r_param=None, vectorfields=[], proj_axes=[0, 1], pick_samples=None, epochs=0, iter_nos=None, physical_time=True, save_iter=1, track_velocity=track_velocity, arrow_scale=1, iscolor=iscolor, quantile=quantile, exp_alias_=None, x_lim=x_lim, y_lim=y_lim, filename=filename)
        for loss_type in ('divergences', 'KE_Ps'):
            plot_losses(loss_type=loss_type, **loss_plot_kwargs)

        if '1D' in p.dataset:
            import matplotlib.pyplot as plt
            # One curve per recorded snapshot; snapshots were taken every epochs/10 iterations.
            for i, phi_vals in enumerate(phis):
                plt.plot(np.asarray(xx), phi_vals, label='t=%.2f' % ((i+1)*p.epochs/10*p.lr_P))
            plt.legend()
            plt.title(r'$\phi_n, n=10*k$')
            f = filename.split('.pickle')
            plt.savefig(f[0]+"-phis.png")
            plt.show()

        elif '2D' in p.dataset and phis:
            import matplotlib.pyplot as plt
            from mpl_toolkits.mplot3d import Axes3D  # noqa: F401 -- registers the '3d' projection on older matplotlib
            fig = plt.figure()
            # Axes3D(fig) no longer attaches itself to the figure in recent matplotlib.
            ax = fig.add_subplot(projection='3d')
            # Surface of the last phi snapshot over the evaluation mesh (XX, YY).
            ZZ = np.reshape(phis[-1], XX.shape)
            ax.plot_surface(XX, YY, ZZ)
            ax.set_title(r'$\phi$')
            f = filename.split('.pickle')
            plt.savefig(f[0]+"-phis.png")
            plt.show()
