# Copyright IRT Antoine de Saint Exupéry et Université Paul Sabatier Toulouse III - All
# rights reserved. DEEL is a research program operated by IVADO, IRT Saint Exupéry,
# CRIAQ and ANITI - https://www.deel.ai/
# =====================================================================================
"""
This module contains computation function, for Bjorck and spectral
normalization. This is done for internal use only.
"""
import tensorflow as tf
from tensorflow.keras import backend as K
import numpy as np
from tensorflow_riemopt.manifolds import StiefelCayley
import tensorflow_riemopt as riemopt
from .utils import padding_circular, transposeKernel, zero_upscale2D
import sys
# Default `niter` of the Bjorck functions defined in this module.
DEFAULT_NITER_BJORCK = 15
# Default `niter` of the power-iteration / spectral functions in this module.
DEFAULT_NITER_SPECTRAL = 5
# NOTE(review): not referenced in the visible part of this module; presumably the
# iteration count used when `u` is first initialized -- confirm against callers.
DEFAULT_NITER_SPECTRAL_INIT = 10
# Default `beta` coefficient of the Bjorck update.
DEFAULT_BETA_BJORCK = 0.25
# Default `eps` of the power iteration stopping test norm(u_t - u_{t-1}) >= eps.
DEFAULT_EPS_SPECTRAL = 1e-3
# Default `eps` of the Bjorck stopping test norm(w_t - w_{t-1}) >= eps.
DEFAULT_EPS_BJORCK = 5e-2
# Forwarded as `swap_memory` by the tf.while_loop calls that read it
# (see `set_swap_memory`).
SWAP_MEMORY = True
# When True, `_power_iteration` applies tf.stop_gradient to the vectors it returns
# (see `set_stop_grad_spectral`).
STOP_GRAD_SPECTRAL = False
# When True, `bjorck_normalization_old` wraps its loop in tf.grad_pass_through
# (see `set_grad_passthrough_bjorck`).
GRAD_PASSTHROUGH_BJORCK = False

def set_swap_memory(value: bool):
    """
    Set the module-level SWAP_MEMORY flag to `value`. This function must be called
    before constructing the model (first call of `reshaped_kernel_orthogonalization`)
    in order to be taken into account.
    Args:
        value: boolean that will be used as the swap_memory parameter in while loops
            in spectral and bjorck algorithms.
    """
    global SWAP_MEMORY
    SWAP_MEMORY = value


def set_stop_grad_spectral(value: bool):
    """
    Set the module-level STOP_GRAD_SPECTRAL flag to `value`. This function must be
    called before constructing the model (first call of
    `reshaped_kernel_orthogonalization`) in order to be taken into account.
    Args:
        value: boolean, when set to True, disable back-propagation through the power
            iteration algorithm. The back-propagation will account for how updates
            affect the maximum singular value but not for how they affect the largest
            singular vector. When set to False, back-propagate through the while loop.
    """
    global STOP_GRAD_SPECTRAL
    STOP_GRAD_SPECTRAL = value


def set_grad_passthrough_bjorck(value: bool):
    """
    Set the module-level GRAD_PASSTHROUGH_BJORCK flag to `value`. This function must
    be called before constructing the model (first call of
    `reshaped_kernel_orthogonalization`) in order to be taken into account.
    Args:
        value: boolean, when set to True, only back-propagate through a single bjorck
            iteration. In other words it will act with the usual variable number of
            iterations during forward but will have bjorck_niter=1 during backward.
            This saves time and memory. When set to False, back-propagate
            through the bjorck loop.
    """
    global GRAD_PASSTHROUGH_BJORCK
    GRAD_PASSTHROUGH_BJORCK = value


@tf.function
def get_operator_norm(conv, w, h):
    """
    Largest singular value of the per-frequency matrices of a zero-padded kernel.

    The kernel is zero-padded up to (w, h) along its first two axes, a 2D FFT is
    taken over those axes, the real part of the coefficients is kept and an SVD
    is run on the matrix found at every frequency.

    NOTE(review): only the real part of the FFT reaches the SVD, whereas
    `get_operator_stats_np` decomposes the complex coefficients -- confirm that
    dropping the imaginary part is intended.

    Args:
        conv: 4D kernel tensor whose first two axes are the ones being padded.
        w: padded size along axis 0 of `conv`.
        h: padded size along axis 1 of `conv`.

    Returns:
        scalar tensor: the maximum over all computed singular values.

    """
    k0, k1 = conv.get_shape().as_list()[:2]
    # fft2d transforms the two innermost axes, so move the padded axes last
    kernel_c = tf.cast(tf.transpose(conv, perm=[2, 3, 0, 1]), tf.complex64)
    pad_spec = tf.constant([[0, 0], [0, 0], [0, w - k0], [0, h - k1]])
    spectrum = tf.math.real(tf.signal.fft2d(tf.pad(kernel_c, pad_spec)))
    # bring the frequency axes back in front: one matrix per frequency for the SVD
    per_frequency = tf.transpose(spectrum, perm=[2, 3, 0, 1])
    singular_values = tf.linalg.svd(per_frequency, compute_uv=False)
    return tf.reduce_max(singular_values)

@tf.function
def get_operator_stats(conv, w, h):
    """
    Max, min and mean singular value of the per-frequency matrices of a kernel.

    Same pipeline as `get_operator_norm`: zero-pad the first two axes up to
    (w, h), take a 2D FFT over them, keep the real part, and run an SVD on the
    matrix found at every frequency.

    NOTE(review): only the real part of the FFT reaches the SVD -- confirm that
    dropping the imaginary part is intended.

    Args:
        conv: 4D kernel tensor whose first two axes are the ones being padded.
        w: padded size along axis 0 of `conv`.
        h: padded size along axis 1 of `conv`.

    Returns:
        tuple (max, min, mean) of scalar tensors, each reduced over all the
        computed singular values.

    """
    k0, k1 = conv.get_shape().as_list()[:2]
    # fft2d transforms the two innermost axes, so move the padded axes last
    kernel_c = tf.cast(tf.transpose(conv, perm=[2, 3, 0, 1]), tf.complex64)
    pad_spec = tf.constant([[0, 0], [0, 0], [0, w - k0], [0, h - k1]])
    spectrum = tf.math.real(tf.signal.fft2d(tf.pad(kernel_c, pad_spec)))
    per_frequency = tf.transpose(spectrum, perm=[2, 3, 0, 1])
    singular_values = tf.linalg.svd(per_frequency, compute_uv=False)
    return (
        tf.reduce_max(singular_values),
        tf.reduce_min(singular_values),
        tf.reduce_mean(singular_values),
    )
def get_operator_stats_np(conv, input_shape):
    """
    NumPy computation of the singular value statistics of a convolution kernel.

    The kernel is Fourier-transformed over its first two axes (zero-padded to
    `input_shape[1:3]`) and an SVD is run on the matrix found at every frequency.

    Args:
        conv: 4D numpy array whose first two axes are the ones transformed.
        input_shape: shape of the layer input; only `input_shape[1:3]` is read
            and used as the FFT size.

    Returns:
        tuple (max, min, mean) of the singular values, in the same order as
        `get_operator_stats`.
    """
    kernel_n = conv.astype(dtype="float32")
    transforms = np.fft.fft2(kernel_n, input_shape[1:3], axes=[0, 1])
    # svd works on the two last axes: one decomposition per frequency
    svd = np.linalg.svd(transforms, compute_uv=False)
    return np.max(svd), np.min(svd), np.mean(svd)

def aol_conv2d_rescale(kernel_parameters):
    """
    Rescale a convolution kernel so that the convolution is 1-Lipschitz
    (with respect to the L2 norm).

    Args:
        kernel_parameters: kernel of shape [ks1, ks2, nrof_in, nrof_out].

    Returns:
        the kernel multiplied, input channel by input channel, by the factors
        returned by `get_aol_conv2d_rescale`.
    """
    per_input_channel_scale = get_aol_conv2d_rescale(kernel_parameters)
    # broadcast the [nrof_in] factors over the other three kernel axes
    return kernel_parameters * per_input_channel_scale[None, None, :, None]


def get_aol_conv2d_rescale(kernel_parameters, epsilon=1e-6):
    """
    Compute the per-input-channel rescaling factors used by `aol_conv2d_rescale`.

    For every pair (i, j) of input channels, kernel[:, :, i, :] is convolved with
    kernel[:, :, j, :]. All pairs are handled by a single tf.nn.conv2d call in
    which one copy of the kernel uses its input-channel axis as the batch axis
    and the other uses it as the output axis.

    Args:
        kernel_parameters: kernel of shape [ks1, ks2, nrof_in, nrof_out].
        epsilon: added to the squared bound before taking the inverse square root.

    Returns:
        tensor of shape [nrof_in] holding (bound_squared + epsilon) ** (-1/2).
    """
    ks1, ks2 = kernel_parameters.shape[0], kernel_parameters.shape[1]
    kernel_as_batch = tf.transpose(kernel_parameters, [2, 0, 1, 3])
    kernel_as_filters = tf.transpose(kernel_parameters, [0, 1, 3, 2])

    # Pad by kernel_size - 1 on each side so that every position where the two
    # kernels overlap is visited.
    overlaps = tf.nn.conv2d(
        input=kernel_as_batch,
        filters=kernel_as_filters,
        strides=[1, 1, 1, 1],
        padding=[[0, 0], [ks1 - 1, ks1 - 1], [ks2 - 1, ks2 - 1], [0, 0]],
    )  # shape [nrof_in, 2*ks1-1, 2*ks2-1, nrof_in]

    # Absolute values summed over the spatial axes (1 and 2) and over the
    # second input-channel axis (3).
    bound_squared = tf.reduce_sum(tf.abs(overlaps), axis=(1, 2, 3))
    return (bound_squared + epsilon) ** (-1 / 2)


@tf.function
def stiefel_project(kernel):
    """
    Apply `StiefelCayley().projx` to a 2D kernel.

    A kernel with more rows than columns is projected directly; otherwise its
    transpose is projected and the result is transposed back.

    Args:
        kernel: 2D tensor with a statically known shape.

    Returns:
        the projected kernel, same orientation as the input.
    """
    stiefel = riemopt.manifolds.StiefelCayley()
    n_rows, n_cols = kernel.shape[0], kernel.shape[1]
    if n_rows > n_cols:
        return stiefel.projx(kernel)
    return tf.transpose(stiefel.projx(tf.transpose(kernel)))

@tf.function
def reshaped_kernel_orthogonalization_dense_no_grad(
    kernel,
    u,
    adjustment_coef,
    niter_spectral=DEFAULT_NITER_SPECTRAL,
    niter_bjorck=DEFAULT_NITER_BJORCK,
    beta=DEFAULT_BETA_BJORCK,stop_gradient = False
):
    """
    Variant of `reshaped_kernel_orthogonalization_dense` whose Bjorck output is
    wrapped in tf.stop_gradient.

    The kernel is divided by the sigma estimated by power iteration, then (when
    niter_bjorck > 0) 2 Bjorck iterations are run, followed by niter_bjorck - 2
    further iterations whose result goes through tf.stop_gradient.

    NOTE(review): since the tensor that is finally kept is the output of
    tf.stop_gradient, no gradient appears to flow back to `kernel` through W_bar
    when niter_bjorck > 0, including through the first 2 iterations -- confirm
    that this is the intended behaviour.

    Args:
        kernel: 2D kernel to orthogonalize
        u: vector used to initialize the power iteration
        adjustment_coef: factor applied to the orthogonalized kernel
        niter_spectral: number of iterations of the power iteration
        niter_bjorck: total number of bjorck iterations (0 disables bjorck)
        beta: the beta used in the bjorck algorithm
        stop_gradient: when True, sigma goes through tf.stop_gradient before the
            kernel is divided by it

    Returns: the orthogonalized kernel, the new u, and sigma
    """
    # spectral_normalization_dense hands the kernel back unscaled, with sigma
    W_bar, u, sigma = spectral_normalization_dense(kernel, u, niter=niter_spectral)
    if stop_gradient:
        sigma = tf.stop_gradient(sigma)
    W_bar = W_bar/sigma
    if niter_bjorck>0:
        W_bar = bjorck_normalization(W_bar, niter=2, beta=beta)
        W_bar = tf.stop_gradient(bjorck_normalization(W_bar, niter=niter_bjorck-2, beta=beta))

    #W_bar = stiefel_project(kernel)
    W_bar = W_bar * adjustment_coef
    return W_bar, u, sigma


@tf.function
def reshaped_kernel_orthogonalization_dense(
    kernel,
    u,
    adjustment_coef,
    niter_spectral=DEFAULT_NITER_SPECTRAL,
    niter_bjorck=DEFAULT_NITER_BJORCK,
    beta=DEFAULT_BETA_BJORCK,stop_gradient = False
):
    """
    Orthogonalize a dense (2D) kernel: divide it by the sigma estimated by power
    iteration, then run the Bjorck algorithm on the result.

    Args:
        kernel: 2D kernel to orthogonalize
        u: vector used to initialize the power iteration
        adjustment_coef: factor applied to the orthogonalized kernel
        niter_spectral: number of iterations of the power iteration
        niter_bjorck: number of bjorck iterations (0 disables bjorck)
        beta: the beta used in the bjorck algorithm
        stop_gradient: when True, sigma goes through tf.stop_gradient before the
            kernel is divided by it

    Returns: the orthogonalized kernel, the new u, and sigma
    """
    # the kernel comes back unscaled: the division by sigma happens here
    w_cur, u, sigma = spectral_normalization_dense(kernel, u, niter=niter_spectral)
    if stop_gradient:
        sigma = tf.stop_gradient(sigma)
    w_cur = w_cur / sigma
    if niter_bjorck > 0:
        w_cur = bjorck_normalization(w_cur, niter=niter_bjorck, beta=beta)
    return w_cur * adjustment_coef, u, sigma

@tf.function
def spectral_normalization_dense(kernel, u, niter=DEFAULT_NITER_SPECTRAL):
    """
    Estimate sigma = v @ kernel @ u^T from the vectors given by power iteration.

    Unlike `spectral_normalization`, the kernel is NOT divided by sigma here:
    it is returned untouched and the caller performs the division.

    Args:
        kernel: the 2D kernel to analyse
        u: initialization of the vector iterated by the power method
        niter: number of iterations

    Returns:
        the unchanged kernel, the vector u after power iteration, and sigma

    """
    u_vec, v_vec = _power_iteration(kernel, u, niter=niter)
    sigma = (v_vec @ kernel) @ tf.transpose(u_vec)
    return kernel, u_vec, sigma

def reshaped_kernel_orthogonalization(
    kernel,
    u,
    adjustment_coef,
    niter_spectral=DEFAULT_NITER_SPECTRAL,
    niter_bjorck=DEFAULT_NITER_BJORCK,
    beta=DEFAULT_BETA_BJORCK,
):
    """
    Perform reshaped kernel orthogonalization (RKO) on the kernel given as input.
    The power method is applied to find the largest singular value and the Bjorck
    algorithm is then applied to the rescaled kernel. Rescaling first greatly
    improves the stability and the convergence speed of the Bjorck algorithm.

    Args:
        kernel: the kernel to orthogonalize
        u: the vector used to do the power iteration method
        adjustment_coef: the adjustment coefficient as used in convolution
        niter_spectral: number of iterations to do in the spectral algorithm
        niter_bjorck: number of iterations of the bjorck algorithm (0 skips it)
        beta: the beta used in the bjorck algorithm

    Returns: the orthogonalized kernel (reshaped to `kernel.shape`), the new u,
        and sigma which is the largest singular value

    """
    w_normalized, u, sigma = spectral_normalization(kernel, u, niter=niter_spectral)
    if niter_bjorck > 0:
        w_normalized = bjorck_normalization(
            w_normalized, niter=niter_bjorck, beta=beta
        )
    w_out = K.reshape(w_normalized * adjustment_coef, kernel.shape)
    return w_out, u, sigma

@tf.function
def reshaped_depth_kernel_orthogonalization(
    kernel,
    niter_bjorck=7):
    """
    Run batched Bjorck iterations on every slice kernel[:, :, i] of a kernel.

    The kernel is reshaped to its first three dimensions, axis 2 is moved in
    front to serve as the batch axis of `bjork_multi_tf`, and the result is
    brought back to the original shape.

    Args:
        kernel: tensor whose element count equals the product of its first three
            dimensions (the reshape below requires it).
        niter_bjorck: number of iterations forwarded to `bjork_multi_tf`.

    Returns:
        tensor with the same shape as `kernel`.
    """
    full_shape = tf.shape(kernel)
    stacked = tf.reshape(kernel, (full_shape[0], full_shape[1], full_shape[2]))
    batched = tf.transpose(stacked, perm=[2, 0, 1])
    batched = bjork_multi_tf(batched, niter=niter_bjorck)
    return tf.reshape(tf.transpose(batched, perm=[1, 2, 0]), full_shape)

@tf.function
def bjork_multi_tf(w, niter=20, beta=0.5):
    """
    Batched Bjorck iterations: w <- (1 + beta) * w - beta * w @ (w^T @ w).

    Args:
        w: 3D tensor read as a batch of matrices (the transpose swaps the two
            last axes).
        niter: number of iterations; the loop is additionally capped at 20
            iterations whatever the value of `niter`.
        beta: coefficient of the update.

    Returns:
        the updated batch of matrices.
    """
    def keep_going(mats, step):
        return step < niter

    def bjorck_step(mats, step):
        gram = tf.linalg.matmul(tf.transpose(mats, perm=[0, 2, 1]), mats)
        return (1 + beta) * mats - beta * tf.linalg.matmul(mats, gram), step + 1

    w, _ = tf.while_loop(
        keep_going,
        bjorck_step,
        (w, tf.constant(0)),
        parallel_iterations=1,
        maximum_iterations=20,
    )
    return w

@tf.function
def bjorck_1(w,eps=DEFAULT_EPS_BJORCK, niter=DEFAULT_NITER_BJORCK, beta=DEFAULT_BETA_BJORCK):
    """
    Run Bjorck updates, evaluating the cubic term as (w @ w^T) @ w.

    `bjorck_normalization` selects this variant when w has fewer rows than
    columns.

    Args:
        w: matrix to orthogonalize
        eps: kept in the signature but not used by this function
        niter: number of updates; the loop is additionally capped at 20
            iterations whatever the value of `niter`
        beta: coefficient of the update w <- (1 + beta) * w - beta * w w^T w

    Returns:
        the updated matrix
    """
    def keep_going(mat, step):
        return step < niter

    def bjorck_step(mat, step):
        cubic = (mat @ tf.transpose(mat)) @ mat
        return (1 + beta) * mat - beta * cubic, step + 1

    w, _ = tf.while_loop(
        keep_going,
        bjorck_step,
        (w, tf.constant(0)),
        parallel_iterations=1,
        maximum_iterations=20,
    )
    return w

@tf.function
def bjorck_2(w,eps=DEFAULT_EPS_BJORCK, niter=DEFAULT_NITER_BJORCK, beta=DEFAULT_BETA_BJORCK):
    """
    Run Bjorck updates, evaluating the cubic term as w @ (w^T @ w).

    `bjorck_normalization` selects this variant when w does not have fewer rows
    than columns.

    Args:
        w: matrix to orthogonalize
        eps: kept in the signature but not used by this function
        niter: number of updates; the loop is additionally capped at 20
            iterations whatever the value of `niter`
        beta: coefficient of the update w <- (1 + beta) * w - beta * w w^T w

    Returns:
        the updated matrix
    """
    def keep_going(mat, step):
        return step < niter

    def bjorck_step(mat, step):
        cubic = mat @ (tf.transpose(mat) @ mat)
        return (1 + beta) * mat - beta * cubic, step + 1

    w, _ = tf.while_loop(
        keep_going,
        bjorck_step,
        (w, tf.constant(0)),
        parallel_iterations=1,
        maximum_iterations=20,
    )
    return w

def _wwtw(w):
    """
    Compute w @ w^T @ w, choosing the association order from the static shape:
    the Gram matrix formed first is w^T @ w when w has more rows than columns,
    and w @ w^T otherwise.
    """
    w_t = tf.transpose(w)
    if w.shape[0] > w.shape[1]:
        return w @ (w_t @ w)
    return (w @ w_t) @ w


def bjorck_normalization_old(
    w, eps=DEFAULT_EPS_BJORCK, beta=DEFAULT_BETA_BJORCK, niter=DEFAULT_NITER_BJORCK
):
    """
    apply Bjorck normalization on w.
    Args:
        w (tf.Tensor): weight to normalize, in order to work properly, we must have
            max_eigenval(w) ~= 1
        eps (float): epsilon stopping criterion: norm(wt - wt-1) must be less than eps
        beta (float): beta used in each iteration, must be in the interval ]0, 0.5]
        niter (int): maximum number of iterations of the while loop (one extra
            update is always performed before the loop starts)
    Returns:
        tf.Tensor: the orthonormal weights
    """
    # perform a first update outside of the loop: this gives a genuine (w, old_w)
    # pair for the first evaluation of the stopping criterion.
    old_w = w
    w = (1 + beta) * w - beta * _wwtw(w)
    # define the loop condition

    def cond(w, old_w):
        return tf.linalg.norm(w - old_w) >= eps

    # define the loop body
    def body(w, old_w):
        old_w = w
        w = (1 + beta) * w - beta * _wwtw(w)
        return w, old_w

    # apply the loop
    def while_func(w, old_w):
        return tf.while_loop(
        cond,
        body,
        (w, old_w),
        parallel_iterations=30,
        maximum_iterations=niter,
        swap_memory=SWAP_MEMORY,
        )
    # with GRAD_PASSTHROUGH_BJORCK the whole loop is wrapped in
    # tf.grad_pass_through; the update done before the loop stays outside of it.
    if GRAD_PASSTHROUGH_BJORCK:
        w, old_w = tf.grad_pass_through(while_func)(w, old_w)
    else:
        w, old_w = while_func(w, old_w)
    return w

@tf.function
def bjorck_normalization(w,eps=DEFAULT_EPS_BJORCK, niter=DEFAULT_NITER_BJORCK, beta=DEFAULT_BETA_BJORCK):
    """
    apply Bjorck normalization on w.

    The work is delegated to `bjorck_1` when w has fewer rows than columns and to
    `bjorck_2` otherwise; the choice is made with tf.cond on the dynamic shape.

    Args:
        w: weight to normalize, in order to work properly, we must have
            max_eigenval(w) ~= 1
        eps: forwarded to `bjorck_1` / `bjorck_2`
        niter: number of iterations, forwarded to `bjorck_1` / `bjorck_2`
        beta: beta used in each iteration, must be in the interval ]0, 0.5]
    Returns:
        the orthonormal weights
    """
    dims = tf.shape(w)
    return tf.cond(
        tf.less(dims[0], dims[1]),
        true_fn=lambda: bjorck_1(w, eps=eps, niter=niter, beta=beta),
        false_fn=lambda: bjorck_2(w, eps=eps, niter=niter, beta=beta),
    )

@tf.function
def bjorck_normalization_opti(w,eps=DEFAULT_EPS_BJORCK, niter=DEFAULT_NITER_BJORCK, beta=DEFAULT_BETA_BJORCK):
    """
    apply Bjorck normalization on w, stopping as soon as the update is small.

    Args:
        w: weight to normalize, in order to work properly, we must have
            max_eigenval(w) ~= 1
        eps: epsilon stopping criterion: iterate while norm(wt - wt-1) >= eps
        niter: kept in the signature but not used; the loop is capped at 20
            iterations
        beta: beta used in each iteration, must be in the interval ]0, 0.5]
    Returns:
        the orthonormal weights
    """
    def not_converged(cur, prev):
        return tf.linalg.norm(cur - prev) >= eps

    def bjorck_step(cur, prev):
        # the previous iterate handed to the next test is the current one
        return (1 + beta) * cur - beta * cur @ tf.transpose(cur) @ cur, cur

    # 10 * w is a placeholder "previous" value for the first evaluation of the
    # stopping test; the first step overwrites it.
    w, _ = tf.while_loop(
        not_converged,
        bjorck_step,
        (w, 10 * w),
        parallel_iterations=1,
        maximum_iterations=20,
    )
    return w

@tf.function
def _power_iteration_(w, u, niter=DEFAULT_NITER_SPECTRAL, eps=DEFAULT_EPS_SPECTRAL):
    """
    Internal function that performs a fixed number of power iteration steps.
    Args:
        w: weights matrix whose leading singular vectors are wanted
        u: initialization of the vector u
        niter: number of steps; the loop is additionally capped at 20 iterations
        eps: kept in the signature but not used by this function
    Returns:
         u and v after the last step
    """
    # initial value of the loop variable v (it fixes its shape and dtype)
    v_start = tf.math.l2_normalize(u @ tf.transpose(w))

    def keep_going(u_cur, v_cur, step):
        return step < niter

    def one_step(u_cur, v_cur, step):
        v_next = tf.math.l2_normalize(u_cur @ tf.transpose(w))
        u_next = tf.math.l2_normalize(v_next @ w)
        return u_next, v_next, step + 1

    u_final, v_final, _ = tf.while_loop(
        keep_going,
        one_step,
        (u, v_start, tf.constant(0)),
        parallel_iterations=1,
        maximum_iterations=20,
    )
    return u_final, v_final

@tf.function
def _power_iteration_opti(w, u, niter=DEFAULT_NITER_SPECTRAL, eps=DEFAULT_EPS_SPECTRAL):
    """
    Internal function that performs power iteration until u stops moving.
    Args:
        w: weights matrix whose leading singular vectors are wanted
        u: initialization of the vector u
        niter: kept in the signature but not used; the loop is capped at 20
            iterations
        eps: epsilon stopping criterion: iterate while norm(ut - ut-1) >= eps
    Returns:
         u and v after the last step
    """
    # initial value of the loop variable v (it fixes its shape and dtype)
    v_start = tf.math.l2_normalize(u @ tf.transpose(w))

    def not_converged(u_cur, v_cur, u_prev):
        return tf.linalg.norm(u_cur - u_prev) >= eps

    def one_step(u_cur, v_cur, u_prev):
        v_next = tf.math.l2_normalize(u_cur @ tf.transpose(w))
        u_next = tf.math.l2_normalize(v_next @ w)
        return u_next, v_next, u_cur

    # 10 * u is a placeholder "previous" value for the first evaluation of the
    # stopping test; the first step overwrites it.
    u_final, v_final, _ = tf.while_loop(
        not_converged,
        one_step,
        (u, v_start, 10 * u),
        parallel_iterations=1,
        maximum_iterations=20,
    )
    return u_final, v_final


@tf.function
def spectral_normalization_old(kernel, u, niter=DEFAULT_NITER_SPECTRAL):
    """
    Flatten the kernel to 2D and divide it by its estimated largest singular value.

    Args:
        kernel: the kernel to normalize; it is reshaped to [-1, last dimension]
        u: initialization of the vector used by the power iteration; when None, a
            vector of ones is used and the number of iterations is doubled
        niter: number of iterations

    Returns:
        the flattened kernel divided by sigma, the vector u after power iteration,
        and sigma

    """
    last_dim = kernel.shape[-1]
    if u is None:
        # no previous estimate of u to start from: iterate twice as long
        niter *= 2
        u = tf.ones(shape=(1, last_dim))
    flat_kernel = tf.reshape(kernel, [-1, last_dim])
    u_vec, v_vec = _power_iteration(flat_kernel, u, niter=niter)
    sigma = (v_vec @ flat_kernel) @ tf.transpose(u_vec)
    return flat_kernel / sigma, u_vec, sigma


def _power_iteration(w, u, eps=DEFAULT_EPS_SPECTRAL, niter=DEFAULT_NITER_SPECTRAL):
    """
    Internal function that performs the power iteration algorithm.
    Args:
        w: 2D weights matrix whose leading singular vectors are wanted
        u: initialization of the vector u, of shape (1, w.shape[-1]); when None a
            normalized random uniform vector is drawn
        eps: epsilon stopping criterion: iterate while norm(ut - ut-1) >= eps
        niter: maximum number of iterations for the algorithm
    Returns:
         u and v after the last step (both go through tf.stop_gradient when
         STOP_GRAD_SPECTRAL is set)
    """
    if u is None:
        u = tf.linalg.l2_normalize(
            tf.random.uniform(
                shape=(1, w.shape[-1]), minval=0.0, maxval=1.0, dtype=w.dtype
            )
        )
    # v only needs the right shape and dtype here: the first step overwrites it
    v_start = tf.zeros((1, w.shape[0]), dtype=w.dtype)

    def not_converged(u_cur, v_cur, u_prev):
        return tf.linalg.norm(u_cur - u_prev) >= eps

    def one_step(u_cur, v_cur, u_prev):
        v_next = tf.math.l2_normalize(u_cur @ tf.transpose(w))
        u_next = tf.math.l2_normalize(v_next @ w)
        return u_next, v_next, u_cur

    # 10 * u is a placeholder "previous" value for the first evaluation of the
    # stopping test; the first step overwrites it.
    u_out, v_out, _ = tf.while_loop(
        not_converged,
        one_step,
        (u, v_start, 10 * u),
        parallel_iterations=30,
        maximum_iterations=niter,
        swap_memory=SWAP_MEMORY,
    )
    if STOP_GRAD_SPECTRAL:
        u_out = tf.stop_gradient(u_out)
        v_out = tf.stop_gradient(v_out)
    return u_out, v_out


def spectral_normalization(
    kernel, u, eps=DEFAULT_EPS_SPECTRAL, niter=DEFAULT_NITER_SPECTRAL
):
    """
    Normalize the kernel so that its largest singular value is about 1.
    Args:
        kernel (tf.Tensor): the kernel to normalize, assuming a 2D kernel
        u (tf.Tensor): initialization of the vector used by the power iteration
        eps (float): epsilon stopping criterion: norm(ut - ut-1) must be less than
            eps; also added to sigma in the division
        niter (int): maximum number of iterations for the algorithm
    Returns:
        the normalized kernel w_bar, the vector u after power iteration, and the
            estimated maximum singular value sigma.
    """
    u_vec, v_vec = _power_iteration(kernel, u, eps, niter)
    sigma = (v_vec @ kernel) @ tf.transpose(u_vec)
    # Divide by sigma + eps rather than sigma: in the worst case the estimate is
    # off by eps, and the margin keeps the operator norm of w_bar below one,
    # which keeps bjorck stable even when beta=0.5.
    return kernel / (sigma + eps), u_vec, sigma

@tf.function
def _power_iteration_conv(w, u, stride = 1.0, conv_first = True, w_pad=None,h_pad = None,niter=DEFAULT_NITER_SPECTRAL, bigConstant=-1):
    """
    Internal function that performs the power iteration algorithm on a
    convolution kernel.

    Each step rescales u by its norm and applies two convolutions, one with `w`
    and one with `transposeKernel(w, True)`, in the order chosen by `conv_first`.

    Args:
        w: convolution kernel
        u: initialization of the iterated tensor
        stride: stride of the convolution done with `w`
        conv_first: when True, convolve with `w` first and with the transposed
            kernel second; when False, the other way round
        w_pad: first padding argument given to `padding_circular`; when None the
            convolutions use 'SAME' padding, otherwise 'VALID'
        h_pad: second padding argument given to `padding_circular`
        niter: number of iteration, must be greater than 0
        bigConstant: when > 0, each step returns bigConstant * u - unew instead
            of unew

    Returns:
         the (un-normalized) u produced by the last step and the intermediate
         tensor v of that step

    """
    pad_mode = 'SAME' if w_pad is None else 'VALID'

    def one_step(vec):
        vec = vec / tf.norm(vec)
        if conv_first:
            padded = padding_circular(vec, w_pad, h_pad)
            mid = tf.nn.conv2d(
                padded, w, padding=pad_mode, strides=(1, stride, stride, 1)
            )
            upscaled = zero_upscale2D(mid, (stride, stride))
            upscaled = padding_circular(upscaled, w_pad, h_pad)
            w_adj = transposeKernel(w, True)
            nxt = tf.nn.conv2d(upscaled, w_adj, padding=pad_mode, strides=1)
        else:
            upscaled = zero_upscale2D(vec, (stride, stride))
            padded = padding_circular(upscaled, w_pad, h_pad)
            w_adj = transposeKernel(w, True)
            mid = tf.nn.conv2d(padded, w_adj, padding=pad_mode, strides=1)
            mid_padded = padding_circular(mid, w_pad, h_pad)
            nxt = tf.nn.conv2d(
                mid_padded, w, padding=pad_mode, strides=(1, stride, stride, 1)
            )
        if bigConstant > 0:
            nxt = bigConstant * vec - nxt
        return nxt, mid

    _u = u
    for _ in range(niter):
        _u, _v = one_step(_u)
    return _u, _v

@tf.function
def _power_iteration_conv_test(w, u, stride = 1.0, conv_first = True, w_pad=None,h_pad = None, niter=DEFAULT_NITER_SPECTRAL, bigConstant=-1):
    """
    tf.while_loop variant of `_power_iteration_conv`.

    Each step rescales u by its norm and applies two convolutions, one with `w`
    and one with `transposeKernel(w, True)`, in the order chosen by `conv_first`.

    Args:
        w: convolution kernel, unpacked as a 4D shape
        u: initialization of the iterated tensor, unpacked as a 4D shape
        stride: stride of the convolution done with `w`
        conv_first: when True, convolve with `w` first and with the transposed
            kernel second; when False, the other way round
        w_pad: first padding argument given to `padding_circular`; when None the
            convolutions use 'SAME' padding, otherwise 'VALID'
        h_pad: second padding argument given to `padding_circular`
        niter: number of iteration, must be greater than 0; the loop is
            additionally capped at 20 iterations
        bigConstant: when > 0, each step returns bigConstant * u - unew instead
            of unew

    Returns:
         the (un-normalized) u produced by the last step and the intermediate
         tensor v of that step

    """
    # loop condition: plain iteration counter
    def cond(u,v_tmp, i):
        return i < niter
    # loop body: one power iteration step
    def body(u,v_tmp,i):
        u=u/tf.norm(u)
        if w_pad is None:
            padType = 'SAME'
        else:
            padType='VALID'

        if conv_first:
            u_pad=padding_circular(u,w_pad,h_pad)
            v= tf.nn.conv2d(u_pad,w,padding=padType,strides=(1,stride,stride,1))
            v1 = zero_upscale2D(v,(stride,stride))
            v1=padding_circular(v1,w_pad,h_pad)
            wAdj=transposeKernel(w,True)
            unew=tf.nn.conv2d(v1,wAdj,padding=padType,strides=1)
        else:
            u1 = zero_upscale2D(u,(stride,stride))
            u_pad=padding_circular(u1,w_pad,h_pad)
            wAdj=transposeKernel(w,True)
            v=tf.nn.conv2d(u_pad,wAdj,padding=padType,strides=1)
            v1=padding_circular(v,w_pad,h_pad)
            unew= tf.nn.conv2d(v1,w,padding=padType,strides=(1,stride,stride,1))
        if bigConstant> 0:
            unew = bigConstant*u-unew
        return unew,v,i+1
    i=tf.constant(0)
    _u = u
    _,_,N_in,N_out = w.shape
    k1,k2,U_in,U_out = u.shape
    #tf.print(N_out)
    # initial value of the loop variable _v, built from the shapes of u and w.
    # NOTE(review): the 3*(stride-1) offsets presumably size _v like the v
    # produced by the body -- confirm, tf.while_loop expects matching shapes.
    if conv_first:
        _v=tf.zeros((k1,k2+3*(stride-1),U_in+3*(stride-1),N_out))
    else :
        _v=tf.zeros((k1,k2+3*(stride-1),U_in+3*(stride-1),N_in))
    _u,_v,i = tf.while_loop(
        cond, body, (_u,_v,i), parallel_iterations=1, maximum_iterations=20
    )

    #for i in range(niter):
    #    _u,_v = iter_f(_u,None)
    return _u, _v


@tf.function
def _power_iteration_conv_optim(w, u, stride = 1.0, conv_first = True, cPad=None, niter=DEFAULT_NITER_SPECTRAL, bigConstant=-1, eps=DEFAULT_EPS_SPECTRAL):
    """
    Variant of the convolutional power iteration that stops on convergence.

    Each step rescales u by its norm and applies two convolutions, one with `w`
    and one with `transposeKernel(w, True)`, in the order chosen by `conv_first`.
    The loop runs while norm(u_t - u_{t-1}) >= eps, for at most 20 iterations.

    NOTE(review): `padding_circular` is called here with two arguments
    (tensor, cPad) whereas `_power_iteration_conv` calls it with three
    (tensor, w_pad, h_pad) -- confirm against the signature in `.utils`.

    Args:
        w: convolution kernel, unpacked as a 4D shape
        u: initialization of the iterated tensor, unpacked as a 4D shape
        stride: stride of the convolution done with `w`
        conv_first: when True, convolve with `w` first and with the transposed
            kernel second; when False, the other way round
        cPad: padding argument given to `padding_circular`; when None the
            convolutions use 'SAME' padding, otherwise 'VALID'
        niter: not used by this function (the loop is capped at 20 iterations)
        bigConstant: when > 0, each step returns bigConstant * u - unew instead
            of unew
        eps: epsilon stopping criterion: norm(ut - ut-1) must be less than eps

    Returns:
         the (un-normalized) u produced by the last step and the intermediate
         tensor v of that step

    """


    # loop condition: keep iterating while u still moves by at least eps
    def cond(_u,  _v, old_u):
        return tf.linalg.norm(_u - old_u) >= eps

    # loop body: one power iteration step
    def iter_f(_u, _v, _old_u):
        # the "previous" u of the next test is the un-normalized current one
        _old_u = _u

        _u=_u/tf.norm(_u)
        if cPad is None:
            padType = 'SAME'
        else:
            padType='VALID'

        if conv_first:
            u_pad=padding_circular(_u,cPad)
            _v= tf.nn.conv2d(u_pad,w,padding=padType,strides=(1,stride,stride,1))
            v1 = zero_upscale2D(_v,(stride,stride))
            v1=padding_circular(v1,cPad)
            wAdj=transposeKernel(w,True)
            unew=tf.nn.conv2d(v1,wAdj,padding=padType,strides=1)
        else:
            u1 = zero_upscale2D(_u,(stride,stride))
            u_pad=padding_circular(u1,cPad)
            wAdj=transposeKernel(w,True)
            _v=tf.nn.conv2d(u_pad,wAdj,padding=padType,strides=1)
            v1=padding_circular(_v,cPad)
            unew= tf.nn.conv2d(v1,w,padding=padType,strides=(1,stride,stride,1))
        if bigConstant> 0:
            unew = bigConstant*_u-unew
        return unew,_v,_old_u

    _u = u
    _,_,N_in,N_out = w.shape
    k1,k2,U_in,U_out = u.shape
    #tf.print(N_out)
    # initial value of the loop variable _v, built from the shapes of u and w.
    # NOTE(review): the 3*(stride-1) offsets presumably size _v like the _v
    # produced by the body -- confirm, tf.while_loop expects matching shapes.
    if conv_first:
        _v=tf.zeros((k1,k2+3*(stride-1),U_in+3*(stride-1),N_out))
    else :
        _v=tf.zeros((k1,k2+3*(stride-1),U_in+3*(stride-1),N_in))
    #tf.print("_v start",_v.shape)
    # placeholder "previous" u for the first evaluation of the stopping test;
    # the first step overwrites it.
    _old_u = 10 * _u
    #for i in range(niter):
    #    _u,_v,_old_u = iter_f(_u, _v, _old_u)
    _u, _v, _old_u = tf.while_loop(
        cond, iter_f, (_u, _v, _old_u), parallel_iterations=1, maximum_iterations=20
    )
    #if tf.linalg.norm(_u - old_u) >= eps :
    #    return _u,
    #tf.print( "sigma",w.shape,_u.shape,_v.shape, tf.norm(_v))
    #_v = tf.norm(_v)
    #return tf.cond(tf.greater(tf.linalg.norm(_u - _old_u), eps),
    #               lambda: (_u,tf.constant(1.2,shape=_v.shape)),
    #               lambda: (_u, _v))
    return _u, _v


@tf.function
def spectral_normalization_conv(kernel, u=None, stride = 1.0, conv_first = True,  w_pad=None,h_pad = None, niter=DEFAULT_NITER_SPECTRAL):
    """
    Estimate sigma for a convolution kernel with the convolutional power
    iteration. The kernel itself is returned unchanged: dividing by sigma is
    left to the caller.

    NOTE(review): when `u` is None the default has shape (1, kernel.shape[-1]),
    i.e. 2D, while `_power_iteration_conv` feeds `u` to tf.nn.conv2d -- confirm
    that the None case is actually usable.

    Args:
        kernel: the convolution kernel
        u: initialization of the iterated tensor; when None a random normal
            tensor is drawn and the number of iterations is doubled
        stride: stride parameter of the convolution
        conv_first: RO or CO case stride^2*C<M
        w_pad: first circular padding value, forwarded to `_power_iteration_conv`
        h_pad: second circular padding value, forwarded to `_power_iteration_conv`
        niter: number of iterations; when <= 0 the power iteration is skipped and
            sigma is reported as 1.0

    Returns:
        the unchanged kernel, the u produced by the power iteration, and sigma
        (the norm of the v produced by the power iteration)

    """

    if u is None:
        W_shape = kernel.shape
        niter *= 2  # if u was not known increase number of iterations
        u = K.random_normal(shape=tuple([1, W_shape[-1]]))
    # nothing to iterate: hand the inputs back with sigma = 1.0
    #W_reshaped = K.reshape(kernel, [-1, W_shape[-1]])
    if niter <= 0:
        return kernel, u, 1.0
    _u, _v = _power_iteration_conv(kernel, u, stride = stride, conv_first = conv_first, w_pad=w_pad,h_pad=h_pad, niter=niter)
    # Calculate Sigma
    sigma = tf.norm(_v)
    W_bar = kernel
    return W_bar, _u, sigma


#@tf.function
def spectral_normalization_r(kernel, u, k_shape,k_flat,niter=DEFAULT_NITER_SPECTRAL):
    """
    Estimate sigma = v @ kernel @ u^T of a 2D kernel with the power iteration.

    Args:
        kernel: the 2D kernel to analyse
        u: initialization of the vector used by the power iteration; when None, a
            tensor of ones of shape `k_flat` is used and the number of iterations
            is doubled
        k_shape: not used by this function
        k_flat: shape of the default `u`
        niter: number of iterations

    Returns:
        sigma only (neither the kernel nor u is returned)

    """
    if u is None:
        # no previous estimate of u to start from: iterate twice as long
        niter *= 2
        u = tf.ones(shape=k_flat)
    u_vec, v_vec = _power_iteration(kernel, u, niter=niter)
    return (v_vec @ kernel) @ tf.transpose(u_vec)

@tf.function
def bjork_normalization_conv(kernel,k_shape,k_flat, u=None, stride = 1.0, conv_first = True, cPad=None, niter=DEFAULT_NITER_SPECTRAL):
    """
    Orthogonalize the kernel with RKO, then estimate sigma of the result with the
    convolutional power iteration.

    Args:
        kernel: the kernel to process
        k_shape: not used by this function
        k_flat: not used by this function
        u: initialization of the tensor iterated by `_power_iteration_conv`
        stride: stride parameter of the convolution
        conv_first: forwarded to `_power_iteration_conv`
        cPad: circular padding given as a pair (w_pad, h_pad), or None
        niter: number of iterations of the convolutional power iteration

    Returns:
        the orthogonalized kernel, the u produced by the power iteration, and
        sigma (the norm of the v produced by the power iteration)
    """
    wbar, _, _ = reshaped_kernel_orthogonalization(
        kernel, None, 1, niter_spectral=3, niter_bjorck=5, beta=DEFAULT_BETA_BJORCK
    )
    # `_power_iteration_conv` has no `cPad` parameter: it takes the two padding
    # values separately, so the pair is unpacked here.
    if cPad is None:
        w_pad, h_pad = None, None
    else:
        w_pad, h_pad = cPad[0], cPad[1]
    _u, _v = _power_iteration_conv(
        wbar,
        u,
        stride=stride,
        conv_first=conv_first,
        w_pad=w_pad,
        h_pad=h_pad,
        niter=niter,
    )
    sigma = tf.norm(_v)
    return wbar, _u, sigma

def get_spectral_input_shape(kernel, stride):
    """
    Derive the padding and the shape of the power iteration input for a kernel.

    Args:
        kernel: object whose `.shape` unpacks as (R0, R, C, M).
        stride: stride of the convolution.

    Returns:
        tuple (pad0, pad1, RO_case, spectral_input_shape, usize) where pad0 and
        pad1 are int(R0/2) and int(R/2), RO_case is True when C*stride**2 > M,
        spectral_input_shape is a 3-tuple and usize its number of elements.
    """
    (k0, k1, c, m) = kernel.shape

    # minimal spatial size N of the power iteration input
    half = k1 // 2
    if half < 1:
        side = 5
    else:
        side = 4 * half + 1
        if stride > 1:
            side = int(0.5 + side / stride)

    if c * stride**2 > m:
        ro_case = True
        input_shape = (side, side, m)
    else:
        ro_case = False
        input_shape = (stride * side, stride * side, c)

    return int(k0 / 2), int(k1 / 2), ro_case, input_shape, np.prod(input_shape)


def compute_sigma(kernel, stride = 1, u = None, niter=DEFAULT_NITER_SPECTRAL):
    """
    Estimate sigma of a convolution kernel with the convolutional power iteration.

    NOTE(review): `niter` is accepted but not used; the power iteration always
    runs with niter=20 -- confirm whether the parameter should be honoured.

    Args:
        kernel: convolution kernel, passed to `get_spectral_input_shape`
        stride: stride of the convolution
        u: initialization of the iterated tensor; when None a random normal
            tensor of shape (1,) + spectral_input_shape is drawn
        niter: currently ignored (see the note above)

    Returns:
        the norm of the v produced by the power iteration
    """
    w_pad, h_pad, ro_case, u_shape, _ = get_spectral_input_shape(kernel, stride)
    if u is None:
        u = tf.random.normal(shape=(1,) + u_shape)

    _, v_last = _power_iteration_conv(
        kernel,
        u,
        stride=stride,
        conv_first=not ro_case,
        w_pad=w_pad,
        h_pad=h_pad,
        niter=20,
    )
    return tf.norm(v_last)


def get_spectral_u(kernel,stride = 1):
    """
    Create the variable `u` used by the convolutional power iteration.

    The shape is derived with the same rules as `get_spectral_input_shape`.

    Args:
        kernel: object whose `.shape` unpacks as (R0, R, C, M).
        stride: stride of the convolution.

    Returns:
        tuple (u, cPad): u is a float32 tf.Variable of shape
        (1,) + spectral_input_shape drawn from a normal distribution (mean 0,
        stddev 1), cPad is the list [int(R0/2), int(R/2)].
    """
    (k0, k1, c, m) = kernel.shape
    pad = [int(k0 / 2), int(k1 / 2)]

    # minimal spatial size N of the power iteration input
    half = k1 // 2
    if half < 1:
        side = 5
    else:
        side = 4 * half + 1
        if stride > 1:
            side = int(0.5 + side / stride)

    if c * stride**2 > m:
        input_shape = (side, side, m)
    else:
        input_shape = (stride * side, stride * side, c)

    draw = tf.random_normal_initializer(mean=0.0, stddev=1)
    return tf.Variable(draw(shape=(1,) + input_shape, dtype=tf.float32)), pad

def spectral_conv_value(kernel, stride = 1.0, conv_first = False, niter=DEFAULT_NITER_SPECTRAL):
    """
    Estimate sigma of a convolution kernel with the convolutional power
    iteration, starting from a freshly drawn `u`.

    Args:
        kernel: the convolution kernel
        stride: stride parameter of the convolution, forwarded to the power
            iteration (the initial `u` is built by `get_spectral_u` with its
            default stride)
        conv_first: RO or CO case stride^2*C<M
        niter: number of iterations

    Returns:
        the norm of the v produced by the power iteration

    """
    u, cPad = get_spectral_u(kernel)
    # `_power_iteration_conv` has no `cPad` parameter: it takes the two padding
    # values separately, so the pair returned by `get_spectral_u` is unpacked.
    _, v_last = _power_iteration_conv(
        kernel,
        u,
        stride=stride,
        conv_first=conv_first,
        w_pad=cPad[0],
        h_pad=cPad[1],
        niter=niter,
    )
    return tf.norm(v_last)
