import tensorflow as tf
from tensorflow.keras import mixed_precision

def get_normalized_matrix(X, eps=1e-10):
    """Scale ``X`` to unit norm jointly over axes 1 and 2.

    Args:
        X: Tensor of rank >= 3. The norm is taken over axes 1 and 2 together,
            e.g. each matrix of a ``(batch, rows, cols)`` batch is normalized.
        eps: Constant added to the norm to avoid division by zero. The default
            ``1e-10`` underflows to 0 in float16, so pass a larger value when
            running under a float16 policy.

    Returns:
        Tensor with the same shape as ``X``.
    """
    norm = tf.norm(X, axis=(1, 2), keepdims=True)
    return X / (norm + eps)

def get_normalized_vector(x, eps=1e-10):
    """Scale ``x`` to unit Euclidean norm along axis 1.

    Args:
        x: Tensor of rank >= 2. For a ``(batch, dim)`` input each row is
            normalized.
        eps: Constant added to the norm to avoid division by zero. The default
            ``1e-10`` underflows to 0 in float16, so pass a larger value when
            running under a float16 policy.

    Returns:
        Tensor with the same shape as ``x``.
    """
    norm = tf.norm(x, axis=1, keepdims=True)
    return x / (norm + eps)

def matrix_spectral_norm(W, n_iter=1, eps=1e-8):
    """Estimate the largest singular value of a 2-D matrix by power iteration.

    Starting from a random left vector ``u``, the loop alternately applies
    ``W^T`` and ``W``, normalizing after each step. It then returns the
    bilinear form ``u^T W v``.

    Args:
        W: 2-D tensor of shape ``(d1, d2)``.
        n_iter: Number of power-iteration steps; must be >= 1.
        eps: Added to each vector norm to avoid division by zero.

    Returns:
        Tensor of shape ``(1, 1)`` holding the estimate.

    Raises:
        ValueError: If ``n_iter`` < 1 (the right vector ``v`` would be undefined).
    """
    n_iter = int(n_iter)
    if n_iter < 1:
        raise ValueError(f"n_iter must be >= 1, got {n_iter}")

    d1 = W.shape[0]
    # Draw u in W's dtype so the matmuls also work under float64/float16 precision.
    u = tf.random.normal((d1, 1), dtype=W.dtype)

    # n_iter is a host int, so a Python loop suffices. Unlike iterating over
    # tf.range, it keeps `v` defined after the loop under tf.function/autograph.
    for _ in range(n_iter):
        wu = tf.matmul(W, u, transpose_a=True)  # W^T u -> (d2, 1)
        v = wu / (tf.norm(wu) + eps)
        wv = tf.matmul(W, v)  # W v -> (d1, 1)
        u = wv / (tf.norm(wv) + eps)

    # u^T W v -> (1, d2) @ (d2, 1) = (1, 1)
    return tf.matmul(tf.matmul(u, W, transpose_a=True), v)

def set_precision(precision):
    """Install ``precision`` as the global Keras mixed-precision policy.

    When the name does not contain ``"mixed"`` (e.g. ``"float64"``), the Keras
    backend default float type is also set to that name first.
    """
    if "mixed" not in precision:
        # Pure (non-mixed) precision: align the backend default float type too.
        tf.keras.backend.set_floatx(precision)
    mixed_precision.set_global_policy(mixed_precision.Policy(precision))

def calc_weightHessianEigen(X,y,mdl,primal=None,num_iter=2,batch_size=100,seed=None):
    """Estimate per-variable dominant Hessian magnitudes of ``mdl.loss`` by power iteration.

    Each iteration computes a Hessian-vector product with forward-over-backward
    autodiff. A ``tf.autodiff.ForwardAccumulator`` wraps the mini-batch
    gradient computation, and ``acc.jvp`` of the accumulated gradient gives
    ``H u``. Each variable's block of ``H u`` is then normalized separately
    and used as the next tangent.

    NOTE(review): because blocks are normalized independently while ``H u``
    still couples them through off-diagonal Hessian blocks, the returned
    per-variable norms are not exactly the eigenvalues of the per-variable
    diagonal Hessian blocks. Confirm this approximation is intended.

    Args:
        X: Inputs, sliceable along axis 0 and supporting ``len``.
        y: Targets, sliceable along axis 0 in step with ``X``.
        mdl: Model exposing ``loss``, ``trainable_variables`` and
            ``__call__(inputs, training)``.
        primal: Variables to differentiate with respect to; defaults to
            ``mdl.trainable_variables``.
        num_iter: Number of power-iteration steps; must be >= 1.
        batch_size: Mini-batch size used when sweeping over ``X``.
        seed: Optional integer. When given, the random starting vectors come
            from ``tf.random.Generator.from_seed(seed)`` for reproducibility.

    Returns:
        Tuple ``(hessian_prod, norm, gradient_norm)``:
        - ``hessian_prod``: list of per-variable unit-norm tensors (last ``H u`` block, normalized).
        - ``norm``: list of ``||(H u)_block||`` per variable. This is a magnitude,
          so the sign of the eigenvalue is lost.
        - ``gradient_norm``: list of norms of the accumulated gradient per variable.

    Raises:
        ValueError: If ``num_iter`` < 1.
    """
    primal = mdl.trainable_variables if primal is None else primal
    num_iter = int(num_iter)
    if num_iter < 1:
        raise ValueError(f"num_iter must be >= 1, got {num_iter}")

    rng = None if seed is None else tf.random.Generator.from_seed(seed)

    def _random_unit(elem):
        # Random tangent in the variable's dtype (float64 policies would
        # otherwise mismatch), normalized so the first jvp is not scaled by
        # an arbitrary random norm.
        if rng is None:
            sample = tf.random.normal(elem.shape, dtype=elem.dtype)
        else:
            sample = rng.normal(elem.shape, dtype=elem.dtype)
        return tf.math.divide_no_nan(sample, tf.norm(sample))

    init_u = [_random_unit(elem) for elem in primal]

    for _ in range(num_iter):
        # Reset every iteration; otherwise the gradient is summed num_iter times
        # and the returned gradient norms are inflated.
        backward_list = [tf.zeros_like(elem) for elem in primal]
        with tf.autodiff.ForwardAccumulator(primal, init_u) as acc:
            for start in range(0, len(X), batch_size):
                end = start + batch_size
                with tf.GradientTape() as tape:
                    loss_value = mdl.loss(y[start:end], mdl(X[start:end], False))  # scalar
                backward = tape.gradient(loss_value, primal)
                for index, grad in enumerate(backward):
                    if grad is None:
                        # Variable not connected to the loss: contributes zero.
                        continue
                    # NOTE(review): the accumulated quantity is
                    # sum_b grad(loss_b) / batch_size. If mdl.loss averages
                    # over the batch, this is the dataset-mean gradient scaled
                    # by num_batches / batch_size, and a final partial batch is
                    # still divided by the full batch_size. Confirm the intended
                    # scaling before comparing across dataset or batch sizes.
                    backward_list[index] += tf.convert_to_tensor(grad) / batch_size
        # Unconnected variables get a zero Hessian product instead of None.
        hessian_prod = acc.jvp(backward_list, unconnected_gradients=tf.UnconnectedGradients.ZERO)
        norm = [tf.norm(elem) for elem in hessian_prod]

        # divide_no_nan keeps an all-zero Hessian block at zero rather than NaN.
        hessian_prod = [tf.math.divide_no_nan(elem1, elem2) for elem1, elem2 in zip(hessian_prod, norm)]
        init_u = hessian_prod
    gradient_norm = [tf.norm(x) for x in backward_list]
    return hessian_prod, norm, gradient_norm

def get_singular_values(model,target_func=lambda x: "kernel" in x.name,topk=5):
    """Return the largest singular values of selected rank-2/rank-3 variables.

    Args:
        model: Object exposing ``trainable_variables`` (e.g. a Keras model).
        target_func: Predicate on a variable choosing which ones to analyse.
            The default selects variables whose ``name`` contains ``"kernel"``.
        topk: Number of largest singular values to keep per variable.

    Returns:
        Dict mapping variable name to a 1-D tensor of its top-``topk``
        singular values, in descending order (as returned by ``tf.linalg.svd``).

    Raises:
        AssertionError: If a selected variable is neither rank 2 nor rank 3.
    """
    kernel_weights = {W.name: W for W in model.trainable_variables if target_func(W)}
    singular_values = {}
    for name, w in kernel_weights.items():
        w_dim = len(w.shape)
        assert w_dim in (2, 3), f"expected a rank-2 or rank-3 weight for {name}, got rank {w_dim}"
        if w_dim == 3:
            # Fold the leading axes into rows. With a leading dim of 1 this is
            # identical to squeezing it. With a larger leading dim (e.g. a
            # (kernel_size, in, out) convolution kernel) it analyses the unfolded
            # (kernel_size * in, out) matrix, where tf.squeeze(w, 0) would fail.
            w = tf.reshape(w, (-1, w.shape[-1]))
        singular_values[name] = tf.linalg.svd(w, compute_uv=False)[:topk]
    return singular_values

def calc_hess_weight(model,train_data=None,valid_data=None,test_data=None, target_func=lambda W:True,num_iter=3,batch_size=100):
    """Compute per-variable gradient norms and Hessian power-iteration norms per data split.

    Args:
        model: Model passed to ``calc_weightHessianEigen``.
        train_data, valid_data, test_data: Optional ``(X, y)`` pairs; a split
            whose value is ``None`` is skipped.
        target_func: Predicate selecting which trainable variables to analyse.
        num_iter: Power-iteration steps forwarded to ``calc_weightHessianEigen``.
        batch_size: Mini-batch size forwarded to ``calc_weightHessianEigen``.
            It was previously ignored in favour of a hard-coded 100.

    Returns:
        Dict ``{split: {variable_name: (gradient_norm, hessian_norm)}}`` for
        each provided split among ``"train"``, ``"valid"``, ``"test"``.
    """
    target_weights = [W for W in model.trainable_variables if target_func(W)]
    target_names = [W.name for W in target_weights]
    splits = {"train": train_data, "valid": valid_data, "test": test_data}

    result_dict = {}
    for split_name, data in splits.items():
        if data is None:
            continue
        _, eigenvalues, gradients = calc_weightHessianEigen(
            data[0], data[1], model, target_weights,
            num_iter=num_iter, batch_size=batch_size,
        )
        result_dict[split_name] = {
            tgt_name: (grad_value, hess_value)
            for tgt_name, grad_value, hess_value in zip(target_names, gradients, eigenvalues)
        }
    return result_dict

def kld_loss_func(p1,p2,eps=1e-10):
    """Mean KL divergence ``KL(p1 || p2)``.

    The divergence is summed over axis 1 and then averaged over all remaining
    entries. ``eps`` is added inside both logarithms to avoid ``log(0)``.
    """
    log_ratio = tf.math.log(p1 + eps) - tf.math.log(p2 + eps)
    per_row = tf.reduce_sum(p1 * log_ratio, axis=1)
    return tf.reduce_mean(per_row)