import numpy as np
import tensorflow as tf

def spectral_norm(pWeight, pIteration=1):
    """Apply spectral normalization to a weight tensor.

    The weight is flattened to a 2-D matrix ``w`` of shape
    ``[-1, shape[-1]]``. Its largest singular value ``sigma`` is estimated
    with ``pIteration`` steps of power iteration. The result is
    ``w / sigma`` reshaped back to the original shape.

    Side effects: a non-trainable variable named ``<weight name>_u`` with
    shape ``[1, shape[-1]]`` is created via ``tf.get_variable`` in the
    current variable scope. The power iteration starts from this variable.
    The variable is overwritten with the new estimate whenever the returned
    tensor is evaluated, so successive evaluations continue from the
    previous estimate.

    Args:
        pWeight (float tensor): Weight tensor. Its static shape should be
            fully defined, because the list from ``shape.as_list()`` is
            passed back to ``tf.reshape`` for the output.
        pIteration (int): Number of power-iteration steps per call; must
            be at least 1.
    Returns:
        float tensor: Normalized weight tensor with the same shape as
        ``pWeight``.
    Raises:
        ValueError: If ``pIteration`` is less than 1.
    """
    # With zero iterations vHat would remain None and graph construction
    # would fail later with an unhelpful error, so reject it up front.
    if pIteration < 1:
        raise ValueError(
            "pIteration must be >= 1, got {}".format(pIteration))

    wShape = pWeight.shape.as_list()
    # Collapse all leading dimensions so the weight becomes a 2-D matrix.
    w = tf.reshape(pWeight, [-1, wShape[-1]])

    # Derive the state-variable name from the weight's last name component
    # (e.g. "scope/kernel:0" -> "kernel") so it is stored as "kernel_u".
    weightName = pWeight.name.split('/')[-1]
    weightName = weightName.split(':')[0]
    u = tf.get_variable(weightName + "_u", [1, wShape[-1]],
        initializer=tf.random_normal_initializer(),
        trainable=False)

    uHat = u
    vHat = None
    for _ in range(pIteration):
        # One power-iteration step: v <- normalize(u W^T), u <- normalize(v W).
        vHat = tf.nn.l2_normalize(
            tf.matmul(uHat, tf.transpose(w)))
        uHat = tf.nn.l2_normalize(
            tf.matmul(vHat, w))

    # The singular-vector estimates are treated as constants for backprop.
    uHat = tf.stop_gradient(uHat)
    vHat = tf.stop_gradient(vHat)

    # sigma = v W u^T, a [1, 1] estimate of the largest singular value.
    sigma = tf.matmul(tf.matmul(vHat, w), tf.transpose(uHat))

    # Ops built inside this context run only after u has been updated, so
    # evaluating the result also persists the new estimate.
    with tf.control_dependencies([u.assign(uHat)]):
        wNorm = w / sigma
        wNorm = tf.reshape(wNorm, wShape)

    return wNorm
