import tensorflow as tf
import numpy as np
from util.construct_NN import spectral_normalization

# learning rate decay schedule
def decay_learning_rate(lr, schedule, param_dict):
    """Update the learning rate ``lr`` in place according to ``schedule``.

    Schedules:
        'rational': multiply the current value by ``1 - epoch/epochs`` using
            ``param_dict['epoch']`` and ``param_dict['epochs']``. The factor is
            applied to the *current* value, so calling this once per epoch
            compounds the decay.
        'step': halve the value when ``param_dict['KE_P']`` is at or below
            ``1e-4 + 2*0.001``.
        anything else: ``lr`` is left untouched.

    Args:
        lr: object exposing ``assign`` (presumably a ``tf.Variable`` —
            confirm against callers).
        schedule: schedule name, see above.
        param_dict: dict holding the keys read by the chosen schedule.

    Returns:
        The same ``lr`` object.
    """
    if schedule == 'step':
        # thresholds on the monitored KE_P value; meaning defined by the caller
        ke_p_floor = 1e-4
        ke_p_alpha_floor = 0.001
        if param_dict['KE_P'] <= ke_p_floor + 2*ke_p_alpha_floor:
            lr.assign(lr/2)
        return lr

    if schedule == 'rational':
        total_epochs = param_dict['epochs']
        epoch_now = param_dict['epoch']
        lr.assign(lr*(1 - epoch_now/total_epochs))
    return lr
    
    
def sgd(W, b, nu, dW, db, dnu, lr_NN, NN_par, descent=True, calc_dW_norm=False):
    """Take one plain gradient step on all network parameters, in place.

    Every ``W[l]`` is replaced by ``W[l] - step*dW[l]`` via ``.assign``.
    ``b[l]`` is updated the same way only when ``db[l]`` is not None, and
    ``nu`` only when ``dnu`` is not None.

    Args:
        W, b: per-layer parameter containers whose elements expose ``.assign``
            (presumably ``tf.Variable`` — confirm against callers).
        nu: additional parameter exposing ``.assign``; updated only if ``dnu``
            is not None.
        dW, db, dnu: gradients matching ``W``, ``b`` and ``nu``. ``db[l]`` may
            be None for layers without a bias update.
        lr_NN: learning rate.
        NN_par: dict; reads ``'constraint'`` and ``'L'``. When the constraint is
            ``'hard'`` and ``L`` is not None, each updated weight is passed to
            ``spectral_normalization`` together with ``L**(1/len(W))``.
        descent: if falsy, the sign of the step is flipped (gradient ascent).
        calc_dW_norm: if truthy, also compute the largest ``tf.norm(dW[l])``.

    Returns:
        ``(W, b, nu, dW_norm)``: the same (mutated) objects plus the maximum
        gradient norm, which stays 0 when ``calc_dW_norm`` is falsy.
    """
    step = lr_NN if descent else -lr_NN
    n_layers = len(W)

    # Resolve the constraint once rather than per layer. Short-circuiting keeps
    # NN_par['L'] unread unless the constraint is 'hard'; the n_layers guard
    # avoids a 1/0 when W is empty.
    layer_bound = None
    if n_layers > 0 and NN_par['constraint'] == 'hard' and NN_par['L'] is not None:
        # per-layer share of L: the product of n_layers such factors equals L
        layer_bound = NN_par['L']**(1/n_layers)

    dW_norm = 0
    for l in range(n_layers):
        if calc_dW_norm:
            dW_norm = max(dW_norm, tf.norm(dW[l]))

        W[l].assign(W[l] - step*dW[l])
        if layer_bound is not None:  # spectral normalization
            spectral_normalization(W[l], layer_bound)
        # 'is not None' instead of '!= None': on an array '!=' is elementwise
        # and its truth value is ambiguous for more than one element.
        if db[l] is not None:  # fnn
            b[l].assign(b[l] - step*db[l])

    if dnu is not None:
        nu.assign(nu - step*dnu)

    return W, b, nu, dW_norm


def adam_update(grad, iter, m, v, beta1=0.9, beta2=0.999, eps=1e-8):
    """Return the bias-corrected Adam step direction for one parameter.

    Args:
        grad: gradient of the parameter. Objects with a ``.numpy()`` method
            (e.g. eager TensorFlow tensors) are converted with it; any other
            array-like is converted with ``np.asarray``.
        iter: zero-based update counter (shadows the builtin ``iter``; name
            kept for caller compatibility). It drives the bias correction, and
            when it is 0 the incoming ``m`` and ``v`` are ignored and reset to
            zeros.
        m, v: first- and second-moment estimates returned by the previous call.
        beta1, beta2: decay rates of the first- and second-moment averages.
        eps: added to ``sqrt(v_hat)`` in the denominator to avoid division by
            zero.

    Returns:
        ``(grad_hat, m, v)`` as NumPy values: the direction
        ``m_hat / (sqrt(v_hat) + eps)`` and the updated moment estimates, which
        the caller must pass back in on the next call.
    """
    # Only call .numpy() when it exists so NumPy arrays / lists also work.
    grad = grad.numpy() if hasattr(grad, 'numpy') else np.asarray(grad)
    if iter == 0:
        m, v = np.zeros_like(grad), np.zeros_like(grad)

    # exponential moving averages of the gradient and the squared gradient
    m = beta1*m + (1-beta1)*grad
    v = beta2*v + (1-beta2)*grad**2

    # bias correction: both averages start at zero, so early values are scaled up
    m_hat = m/(1-beta1**(iter+1))
    v_hat = v/(1-beta2**(iter+1))

    grad_hat = m_hat/(np.sqrt(v_hat) + eps)

    return grad_hat, m, v
 
def adam(W, b, nu, dW, db, dnu, m_W, v_W, m_b, v_b, m_nu, v_nu, lr_NN, NN_par, iter, descent=True, calc_dW_norm=False, beta1=0.9, beta2=0.999, eps=1e-8):
    """Take one Adam step on all network parameters, in place.

    Each gradient is turned into an Adam direction by ``adam_update`` and
    applied with ``.assign``. ``b[l]`` is updated only when ``db[l]`` is not
    None, and ``nu`` only when ``dnu`` is not None.

    Args:
        W, b: per-layer parameter containers whose elements expose ``.assign``
            (presumably ``tf.Variable`` — confirm against callers).
        nu: additional parameter exposing ``.assign``.
        dW, db, dnu: gradients matching ``W``, ``b`` and ``nu``; ``db[l]`` may
            be None for layers without a bias update.
        m_W, v_W, m_b, v_b: per-layer lists of Adam moment estimates. Entries
            are overwritten with the updated values and the lists are returned.
        m_nu, v_nu: moment estimates for ``nu``. They are rebound locally, so
            the caller must use the returned values.
        lr_NN: learning rate.
        NN_par: dict; reads ``'constraint'`` and ``'L'``. When the constraint is
            ``'hard'`` and ``L`` is not None, each updated weight is passed to
            ``spectral_normalization`` together with ``L**(1/len(W))``.
        iter: step counter forwarded to ``adam_update``. At 0, ``adam_update``
            resets the moment estimates to zeros.
        descent: if falsy, the sign of the step is flipped (gradient ascent).
        calc_dW_norm: if truthy, also compute the largest ``tf.norm(dW[l])``.
        beta1, beta2, eps: Adam hyper-parameters forwarded to ``adam_update``.
            The defaults match ``adam_update``'s own.

    Returns:
        ``(W, b, nu, dW_norm, m_W, v_W, m_b, v_b, m_nu, v_nu)``. ``dW_norm``
        stays 0 when ``calc_dW_norm`` is falsy.
    """
    step = lr_NN if descent else -lr_NN
    n_layers = len(W)
    hyper = dict(beta1=beta1, beta2=beta2, eps=eps)

    # Resolve the constraint once rather than per layer. Short-circuiting keeps
    # NN_par['L'] unread unless the constraint is 'hard'; the n_layers guard
    # avoids a 1/0 when W is empty.
    layer_bound = None
    if n_layers > 0 and NN_par['constraint'] == 'hard' and NN_par['L'] is not None:
        # per-layer share of L: the product of n_layers such factors equals L
        layer_bound = NN_par['L']**(1/n_layers)

    dW_norm = 0
    for l in range(n_layers):
        if calc_dW_norm:
            dW_norm = max(dW_norm, tf.norm(dW[l]))

        dW_hat, m_W[l], v_W[l] = adam_update(dW[l], iter, m_W[l], v_W[l], **hyper)
        W[l].assign(W[l] - step*dW_hat)
        if layer_bound is not None:  # spectral normalization
            spectral_normalization(W[l], layer_bound)

        # 'is not None' instead of '!= None': on an array '!=' is elementwise
        # and its truth value is ambiguous for more than one element.
        if db[l] is not None:  # fnn
            db_hat, m_b[l], v_b[l] = adam_update(db[l], iter, m_b[l], v_b[l], **hyper)
            b[l].assign(b[l] - step*db_hat)

    if dnu is not None:
        dnu_hat, m_nu, v_nu = adam_update(dnu, iter, m_nu, v_nu, **hyper)
        nu.assign(nu - step*dnu_hat)

    return W, b, nu, dW_norm, m_W, v_W, m_b, v_b, m_nu, v_nu
