import abc

import numpy as np
import tensorflow as tf
from tensorflow.keras import backend as K
from tensorflow.keras.initializers import RandomNormal
from tensorflow.keras.layers import (
    Layer,
    Dense,
    Conv2D,
    Lambda,
    AveragePooling2D,
    GlobalAveragePooling2D,
)

from .constraints import SpectralConstraint,FrobeniusConstraint,SpectralConvConstraint,SpectralStiefelConstraint
from .initializers import SpectralInitializer,FrobenusInitializer,StiefelInitializer
from .normalizers import (
    DEFAULT_NITER_BJORCK,
    DEFAULT_NITER_SPECTRAL,
    DEFAULT_NITER_SPECTRAL_INIT,
    reshaped_kernel_orthogonalization,
    reshaped_kernel_orthogonalization_dense,
    reshaped_depth_kernel_orthogonalization,
    bjork_normalization_conv,
    spectral_normalization_conv,
    spectral_normalization,
    DEFAULT_BETA_BJORCK,
)
from tensorflow.keras.utils import register_keras_serializable
from tensorflow.keras.initializers import GlorotUniform
from .regularizers import Lorth2D,LorthRegularizer
from tensorflow.keras.layers import DepthwiseConv2D

from tensorflow.keras.utils import register_keras_serializable
from tensorflow.keras import backend as K
import tensorflow as tf
from deel.lip.layers import LipschitzLayer, Condensable
import numpy as np
from tensorflow_riemopt.variable import assign_to_manifold
from tensorflow_riemopt.manifolds import StiefelCayley,StiefelEuclidean,StiefelCanonical
def padding_circular(x, cPad):
    """Pad ``x`` with wrap-around (circular) copies of its own borders.

    Args:
        x: tensor sliced as ``x[:, a, b, :]`` (rank 4). Axis 1 and axis 2 are
            the padded axes; presumably a channels-last image batch, to be
            confirmed against callers.
        cPad: ``None`` for no padding, or a pair ``(w_pad, h_pad)`` giving the
            number of entries wrapped on each side of axis 1 and axis 2
            respectively. Non-positive amounts leave that axis untouched.

    Returns:
        ``x`` unchanged when nothing is padded, otherwise the concatenated
        tensor.
    """
    if cPad is None:
        return x

    pad_axis1, pad_axis2 = cPad

    if pad_axis1 > 0:
        # The trailing rows go in front, the leading rows go at the back.
        wrapped_front = x[:, -pad_axis1:, :, :]
        wrapped_back = x[:, :pad_axis1, :, :]
        x = tf.concat((wrapped_front, x, wrapped_back), axis=1)

    if pad_axis2 > 0:
        # Slices are taken from the tensor already padded along axis 1.
        wrapped_front = x[:, :, -pad_axis2:, :]
        wrapped_back = x[:, :, :pad_axis2, :]
        x = tf.concat((wrapped_front, x, wrapped_back), axis=2)

    return x


# Module-level defaults used as default argument values of
# BatchCentering.__init__; each one has a matching set_* function.
# NOTE: Python evaluates default arguments when the ``def`` statement runs, so
# rebinding these names later does not alter an already-defined signature.
# Weight of the previous moving mean in BatchCentering's running-mean update.
DEFAULT_ALPHA = 0.99
# Default for BatchCentering's ``init_first`` argument.
DEFAULT_INIT = False
# Default for BatchCentering's ``avg`` argument.
DEFAULT_AVG = False
# Default for BatchCentering's ``center`` argument (adds a trainable ``beta``).
DEFAULT_CENTER = False
# Default for BatchCentering's ``stop_grad`` argument.
DEFAULT_STOP_GRADIENT = False

def set_init(value: bool):
    """Rebind the module-level ``DEFAULT_INIT`` flag to ``value``.

    Default arguments that already captured the old value are not affected.
    """
    globals()["DEFAULT_INIT"] = value
def set_avg(value: bool):
    """Rebind the module-level ``DEFAULT_AVG`` flag to ``value``.

    Default arguments that already captured the old value are not affected.
    """
    globals()["DEFAULT_AVG"] = value
def set_center(value: bool):
    """Rebind the module-level ``DEFAULT_CENTER`` flag to ``value``.

    Default arguments that already captured the old value are not affected.
    """
    globals()["DEFAULT_CENTER"] = value
def set_alpha(value: float):
    """Rebind the module-level ``DEFAULT_ALPHA`` coefficient to ``value``.

    ``DEFAULT_ALPHA`` is a float (0.99 by default) used as the weight of the
    previous moving mean, hence the ``float`` annotation rather than ``bool``.
    Default arguments that already captured the old value are not affected.
    """
    global DEFAULT_ALPHA
    DEFAULT_ALPHA = value
def set_stop_gradient(value: bool):
    """Rebind the module-level ``DEFAULT_STOP_GRADIENT`` flag to ``value``.

    Default arguments that already captured the old value are not affected.
    """
    globals()["DEFAULT_STOP_GRADIENT"] = value

@register_keras_serializable("dlt", "BiasLayer")
class BiasLayer(tf.keras.layers.Layer):
    """Adds a stored, non-trainable ``bias`` tensor to its input.

    The bias shape is chosen in ``build`` from the input rank and the two
    flags so that it broadcasts against the input:

    * rank-4 input, both flags: one value per element of a sample;
    * rank-4 input, ``pixelwise`` only: one value per (axis 1, axis 2) position;
    * rank-4 input, ``channelwise`` only: one value per last-axis entry;
    * rank-4 input, no flag: a single scalar;
    * any other rank: one value per last-axis entry.

    Args:
        pixelwise: keep a distinct bias per position along axes 1 and 2.
        channelwise: keep a distinct bias per entry of the last axis.
    """

    def __init__(self, pixelwise=False, channelwise=True, *args, **kwargs):
        super(BiasLayer, self).__init__(*args, **kwargs)
        self.pixelwise = pixelwise
        self.channelwise = channelwise

    def build(self, input_shape):
        """Create the zero-initialised bias weight for ``input_shape``."""
        # Normalise to a plain tuple so that concatenation works whether a
        # tuple, a list or a TensorShape is received.
        dims = tuple(input_shape)
        if len(dims) == 4:
            if self.channelwise and self.pixelwise:
                shape = (1,) + dims[1:]
            elif self.pixelwise:
                shape = (1, dims[1], dims[2], 1)
            elif self.channelwise:
                shape = (1, 1, 1, dims[-1])
            else:
                shape = (1, 1, 1, 1)
        else:
            shape = (1,) * (len(dims) - 1) + (dims[-1],)
        # The name is passed by keyword: the position of ``name`` in
        # ``add_weight`` is not stable across Keras versions.
        self.bias = self.add_weight(
            name="bias",
            shape=shape,
            initializer="zeros",
            trainable=False,
        )
        self.built = True

    def call(self, x):
        """Return ``x`` shifted by the stored bias."""
        return x + self.bias

    def get_config(self):
        """Return the config, including the flags that fix the bias shape."""
        config = dict(super().get_config())
        config.update(
            {
                "pixelwise": self.pixelwise,
                "channelwise": self.channelwise,
            }
        )
        return config


@register_keras_serializable("dlt", "LayerCentering")
class LayerCentering(tf.keras.layers.Layer, LipschitzLayer):
    """Per-sample centering: subtracts the mean over axes 1 and 2.

    A trainable offset ``beta`` of shape ``(1, 1, 1, C)`` (``C`` being the
    size of the last input axis) is added back after centering. The weight
    shape implies a rank-4 input; presumably channels-last, to be confirmed.
    """

    def __init__(self, **kwargs):
        # Attributes are set before the Keras constructor runs, as in the
        # sibling centering layer of this module.
        self.axes = [1, 2]
        self.center = True
        super().__init__(**kwargs)

    def _compute_lip_coef(self, input_shape=None):
        """Return the Lipschitz correction factor (none is needed here)."""
        return 1.0

    def build(self, input_shape):
        """Create the ``beta`` offset when centering with offset is enabled."""
        super().build(input_shape)
        if self.center:
            n_last = input_shape[-1]
            self.beta = self.add_weight(
                shape=(1, 1, 1, n_last),
                initializer="zeros",
                name="beta",
                trainable=True,
                dtype=self.dtype,
            )
        self.built = True

    @tf.function
    def call(self, inputs, training=True, **kwargs):
        """Subtract the mean over ``self.axes`` and add ``beta`` if enabled.

        ``training`` is accepted for API compatibility and is not used.
        """
        centered = inputs - tf.reduce_mean(
            inputs, axis=self.axes, keepdims=True
        )
        if not self.center:
            return centered
        return centered + self.beta

    def get_config(self):
        """Return the base layer config (this layer adds no entry of its own)."""
        return dict(super().get_config())

@register_keras_serializable("dlt", "BatchCentering")
class BatchCentering(tf.keras.layers.Layer, LipschitzLayer, Condensable):
    """Mean-only batch normalisation: subtracts a mean, never rescales.

    In training mode (and when not frozen) the mean of the current batch is
    subtracted and a non-trainable ``moving_mean`` is updated. Otherwise the
    stored ``moving_mean`` is subtracted.

    Args:
        pixelwise: keep distinct statistics per position along axes 1 and 2
            (rank-4 inputs only).
        channelwise: keep distinct statistics per entry of the last axis
            (rank-4 inputs only).
        alpha: weight of the previous moving mean in the running update.
        stop_grad: block gradients through the batch mean (ignored when
            ``avg`` is set).
        center: add a trainable offset ``beta`` after centering.
        init_first: replace the running update by a single assignment of the
            batch mean, performed when the internal call counter equals 2.
        avg: subtract the freshly updated running mean instead of the batch
            mean during training.
        rescale_grad: stored on the layer; not read anywhere in this class.
        freeze: when true, training calls behave like inference calls.

    NOTE: the ``DEFAULT_*`` values are captured when this class is defined, so
    calling the module-level ``set_*`` functions afterwards does not change
    the defaults of this constructor.
    """

    def __init__(
        self,
        pixelwise=False,
        channelwise=True,
        alpha=DEFAULT_ALPHA,
        stop_grad=DEFAULT_STOP_GRADIENT,
        center=DEFAULT_CENTER,
        init_first=DEFAULT_INIT,
        avg=DEFAULT_AVG,
        rescale_grad=False,
        freeze=False,
        **kwargs
    ):
        self.pixelwise = pixelwise
        self.channelwise = channelwise
        self.axes = None
        self.freeze = freeze
        self.alpha = alpha
        self.init_first = init_first
        self.avg = avg
        self.rescale_grad = rescale_grad
        self.epsilon = 0.001
        self.center = center
        self.stop_grad = stop_grad
        super().__init__(**kwargs)

    def condense(self):
        """Nothing to condense for this layer."""
        pass

    def _compute_lip_coef(self, input_shape=None):
        """Return the Lipschitz correction factor (none is needed here)."""
        return 1.0

    def build(self, input_shape):
        """Create the statistics weights and pick the reduction axes."""
        super(BatchCentering, self).build(input_shape)
        # Normalise to a tuple so that concatenation works for tuples, lists
        # and TensorShape objects alike.
        dims = tuple(input_shape)
        if len(dims) == 4:
            if self.channelwise and self.pixelwise:
                shape = (1,) + dims[1:]
                # One statistic per element of a sample: reduce over the batch
                # axis only. ``None`` would reduce over every axis and
                # collapse the statistic to a single scalar.
                self.axes = [0]
            elif self.pixelwise:
                shape = (1, dims[1], dims[2], 1)
                self.axes = [0, -1]
            elif self.channelwise:
                shape = (1, 1, 1, dims[-1])
                self.axes = [0, 1, 2]
            else:
                shape = (1, 1, 1, 1)
                self.axes = [0, 1, 2, 3]
        else:
            shape = (1,) * (len(dims) - 1) + (dims[-1],)
            self.axes = list(range(len(dims) - 1))
        self.moving_mean = self.add_weight(
            shape=shape,
            initializer="zeros",
            name="moving_mean",
            trainable=False,
            dtype=self.dtype,
        )
        if self.center:
            self.beta = self.add_weight(
                shape=shape,
                initializer="zeros",
                name="beta",
                trainable=True,
                dtype=self.dtype,
            )
        if self.init_first:
            # Counter of training calls, used to trigger the one-off
            # initialisation of ``moving_mean``.
            self.nb = self.add_weight(
                shape=(1,),
                initializer="zeros",
                name="nb",
                trainable=False,
                dtype=self.dtype,
            )
        self.built = True

    @tf.function
    def call(self, inputs, training=True, **kwargs):
        """Center ``inputs`` with the batch mean (training) or the moving mean."""
        if training and not self.freeze:
            # update moving mean
            current_means = tf.reduce_mean(inputs, axis=self.axes, keepdims=True)
            updated_means = (
                self.alpha * self.moving_mean + (1.0 - self.alpha) * current_means
            )
            if self.init_first:
                # ``nb`` has shape (1,); reduce the comparison to a scalar so
                # the conditional receives a rank-0 predicate.
                if tf.reduce_all(tf.equal(self.nb, 2)):
                    tf.print("assign ", tf.reduce_mean(tf.abs(current_means)))
                    self.moving_mean.assign(current_means)
                self.nb.assign(self.nb + 1)
            else:
                self.moving_mean.assign(
                    tf.cast(updated_means, self.moving_mean.dtype)
                )
            # while the moving mean is updated, use only current mean
            # as is would scale the gradient with a factor of (1-alpha)
            if self.avg:
                x = inputs - updated_means
            else:
                if self.stop_grad:
                    current_means = tf.stop_gradient(current_means)
                x = inputs - current_means
            if self.center:
                x = x + self.beta
            return x
        else:
            # use the stored value at inference
            x = inputs - self.moving_mean
            if self.center:
                x = x + self.beta
            return x

    def vanilla_export(self):
        """Return a ``BiasLayer`` holding ``-moving_mean`` (plus ``beta``)."""
        bias = -self.moving_mean
        if self.center:
            bias = bias + self.beta

        layer = BiasLayer(pixelwise=self.pixelwise, channelwise=self.channelwise)
        layer.build(self.input_shape)
        layer.bias.assign(bias)

        return layer

    def get_config(self):
        """Return a config carrying every constructor argument.

        ``center`` and ``init_first`` decide which weights ``build`` creates,
        so they must be saved for a reloaded layer to match its weights.
        """
        config = {
            "pixelwise": self.pixelwise,
            "channelwise": self.channelwise,
            "alpha": self.alpha,
            "stop_grad": self.stop_grad,
            "center": self.center,
            "init_first": self.init_first,
            "avg": self.avg,
            "rescale_grad": self.rescale_grad,
            "freeze": self.freeze,
        }
        base_config = super().get_config()
        return dict(list(base_config.items()) + list(config.items()))


@register_keras_serializable("deel-lip", "StiefelConv")
class StiefelConv(Conv2D, LipschitzLayer, Condensable):
    """2-D convolution whose kernel is stored as a 2-D matrix.

    ``build`` replaces the convolution kernel by a matrix of shape ``(n, p)``
    with ``n >= p``, where the two sizes are ``kh * kw * c_in`` and
    ``filters`` (swapped, and ``self.transpose`` set, when the former is the
    smaller one). ``call`` divides that matrix by ``self.sig``, transposes it
    back if needed and reshapes it to the convolution kernel shape.

    Args:
        k_coef_lip: Lipschitz factor handed to ``set_klip_factor``.
        conv_first: forwarded to ``SpectralStiefelConstraint`` when
            ``normconstaint`` is set; otherwise only stored.
        normconstaint: when true, a ``SpectralStiefelConstraint`` replaces any
            user-supplied ``kernel_constraint``, the matrix is assigned to the
            ``StiefelCayley`` manifold and ``sig`` becomes a weight.

    The remaining arguments are those of ``Conv2D``.
    """

    def __init__(
        self,
        filters,
        kernel_size,
        strides=(1, 1),
        padding="same",
        data_format=None,
        dilation_rate=(1, 1),
        activation=None,
        use_bias=True,
        kernel_initializer=StiefelInitializer(),
        bias_initializer="zeros",
        kernel_regularizer=None,
        bias_regularizer=None,
        activity_regularizer=None,
        kernel_constraint=None,
        bias_constraint=None,
        k_coef_lip=1.0,
        conv_first=False,
        normconstaint=False,
        **kwargs
    ):
        self.normconstaint = normconstaint
        if normconstaint:
            self.norm_constraint = SpectralStiefelConstraint(
                strides=strides, conv_first=conv_first
            )
            kernel_constraint = self.norm_constraint
        # NOTE(review): an Lorth regularizer used to be instantiated here when
        # ``conv_first`` was false but was never attached to the layer, and
        # its ``strides[0]`` lookup failed for integer strides. It has been
        # dropped; pass ``kernel_regularizer`` explicitly if one is wanted.
        super(StiefelConv, self).__init__(
            filters=filters,
            kernel_size=kernel_size,
            strides=strides,
            padding=padding,
            data_format=data_format,
            dilation_rate=dilation_rate,
            activation=activation,
            use_bias=use_bias,
            kernel_initializer=kernel_initializer,
            bias_initializer=bias_initializer,
            kernel_regularizer=kernel_regularizer,
            bias_regularizer=bias_regularizer,
            activity_regularizer=activity_regularizer,
            kernel_constraint=kernel_constraint,
            bias_constraint=bias_constraint,
            **kwargs
        )
        self.conv_first = conv_first
        self._kwargs = kwargs
        self.set_klip_factor(k_coef_lip)

        self.built = False

    def build(self, input_shape):
        """Create the 2-D kernel matrix and, if requested, its constraint state."""
        super(StiefelConv, self).build(input_shape)
        self._init_lip_coef(input_shape)
        # Shape of the kernel created by Conv2D; kept to reshape back in call.
        self.conv_shape = self.kernel.shape
        n = self.kernel_size[0] * self.kernel_size[1] * self.conv_shape[2]
        p = self.filters
        self.transpose = False
        if n < p:
            # Store the matrix with the larger dimension first.
            n, p = p, n
            self.transpose = True
        # NOTE(review): this rebinds ``self.kernel``; the kernel created by the
        # parent ``build`` is no longer referenced by this class.
        self.kernel = self.add_weight(
            shape=(n, p),
            name="kernel",
            trainable=True,
            initializer=self.kernel_initializer,
            constraint=self.kernel_constraint,
            dtype=self.dtype,
        )
        self.sig = 1 / self._get_coef()
        if self.normconstaint:
            self.sig = self.add_weight(
                shape=tuple([1, 1]),  # maximum spectral  value
                name="sigma",
                trainable=False,
                initializer=tf.constant_initializer(value=1),
                dtype=self.dtype,
            )

            assign_to_manifold(self.kernel, StiefelCayley())
            self.sig.assign([[1 / self._get_coef()]])
            self.norm_constraint.set_shape(self.conv_shape, self.transpose)
            self.norm_constraint.set_sigma(self.sig)
            self.norm_constraint.__call__(self.kernel)
        if self.kernel_regularizer is not None:
            self.kernel_regularizer.set_kernel_shape(self.conv_shape, self.transpose)
        print(self.conv_shape, self.kernel.shape, n, p, self.transpose, self.conv_first, self.sig)

        self.built = True

    def _compute_lip_coef(self, input_shape=None):
        """Return the corrective factor for the configured padding and strides.

        Raises:
            RuntimeError: if ``data_format`` is neither ``"channels_last"``
                nor ``"channels_first"`` (non-"valid" padding only).
        """
        # According to the file lipschitz_CNN.pdf
        if self.padding == "valid":
            return float(self.strides[0]) / float(self.kernel_size[0])
        stride = np.prod(self.strides)
        k1 = self.kernel_size[0]
        k1_div2 = (k1 - 1) / 2
        k2 = self.kernel_size[1]
        k2_div2 = (k2 - 1) / 2
        if self.data_format == "channels_last":
            h = input_shape[-3]
            w = input_shape[-2]
        elif self.data_format == "channels_first":
            h = input_shape[-2]
            w = input_shape[-1]
        else:
            # The placeholder is required: without it the ``%`` operator
            # raises TypeError instead of the intended RuntimeError.
            raise RuntimeError("data_format not understood: %s" % self.data_format)
        if stride == 1:
            coefLip = np.sqrt(
                (w * h)
                / (
                    (k1 * h - k1_div2 * (k1_div2 + 1))
                    * (k2 * w - k2_div2 * (k2_div2 + 1))
                )
            )
        else:
            sn1 = self.strides[0]
            sn2 = self.strides[1]
            ho = np.floor(h / sn1)
            wo = np.floor(w / sn2)
            alphabar1 = np.floor(k1_div2 / sn1)
            alphabar2 = np.floor(k2_div2 / sn2)
            betabar1 = k1_div2 - alphabar1 * sn1
            betabar2 = k2_div2 - alphabar2 * sn2
            zl1 = (alphabar1 * sn1 + 2 * betabar1) * (alphabar1 + 1) / 2
            zl2 = (alphabar2 * sn2 + 2 * betabar2) * (alphabar2 + 1) / 2
            gamma1 = h - 1 - sn1 * np.ceil((h - 1 - k1_div2) / sn1)
            gamma2 = w - 1 - sn2 * np.ceil((w - 1 - k2_div2) / sn2)
            alphah1 = np.floor(gamma1 / sn1)
            alphaw2 = np.floor(gamma2 / sn2)
            zr1 = (alphah1 + 1) * (k1_div2 - gamma1 + sn1 * alphah1 / 2.0)
            zr2 = (alphaw2 + 1) * (k2_div2 - gamma2 + sn2 * alphaw2 / 2.0)
            coefLip = np.sqrt((h * w) / ((k1 * ho - zl1 - zr1) * (k2 * wo - zl2 - zr2)))
        return coefLip

    def _effective_kernel(self):
        """Return the stored matrix scaled by ``1 / sig`` in conv kernel shape."""
        wbar = self.kernel / self.sig
        if self.transpose:
            wbar = tf.transpose(wbar)
        return tf.reshape(wbar, self.conv_shape)

    def call(self, x, training=True):
        """Convolve ``x`` with the rescaled kernel; ``training`` is unused."""
        outputs = K.conv2d(
            x,
            self._effective_kernel(),
            strides=self.strides,
            padding=self.padding,
            data_format=self.data_format,
            dilation_rate=self.dilation_rate,
        )

        if self.use_bias:
            outputs = K.bias_add(outputs, self.bias, data_format=self.data_format)

        if self.activation is not None:
            return self.activation(outputs)

        return outputs

    def condense(self):
        """Nothing to condense for this layer."""
        pass

    def vanilla_export(self):
        """Return a plain ``Conv2D`` carrying the rescaled kernel and the bias."""
        self._kwargs["name"] = self.name
        layer = Conv2D(
            filters=self.filters,
            kernel_size=self.kernel_size,
            strides=self.strides,
            padding=self.padding,
            data_format=self.data_format,
            dilation_rate=self.dilation_rate,
            activation=self.activation,
            use_bias=self.use_bias,
            kernel_initializer="glorot_uniform",
            bias_initializer="zeros",
            **self._kwargs
        )
        layer.build(self.input_shape)
        layer.kernel.assign(self._effective_kernel())

        if self.use_bias:
            layer.bias.assign(self.bias)
        return layer

    def get_config(self):
        """Return the ``Conv2D`` config extended with this layer's own arguments."""
        config = dict(super(StiefelConv, self).get_config())
        config.update(
            {
                "k_coef_lip": self.k_coef_lip,
                "conv_first": self.conv_first,
                "normconstaint": self.normconstaint,
            }
        )
        return config
class Identity(Layer):
    """Layer that routes its input through ``gradient_scaler(k_coef_lip)``.

    ``gradient_scaler`` is not defined in the visible part of this module;
    presumably it returns a function that is the identity in the forward pass
    and rescales gradients - to be confirmed where it is defined.

    Args:
        k_coef_lip: value handed to ``gradient_scaler`` and kept on the layer.
    """

    def __init__(self, k_coef_lip):
        super().__init__()
        self.k_coef_lip = k_coef_lip
        self.grad_func = gradient_scaler(k_coef_lip)

    def call(self, inputs):
        """Apply the stored function to ``inputs`` and return its result."""
        return self.grad_func(inputs)



@register_keras_serializable("dlt", "SpectralDepthwiseConv2D")
class SpectralDepthwiseConv2D(DepthwiseConv2D, LipschitzLayer, Condensable):
    """Depthwise convolution with a normalised kernel.

    The kernel is normalised by ``normalize_kernel``: either per channel with
    the largest FFT magnitude of the zero-padded kernel (``fft=True``) or with
    ``reshaped_depth_kernel_orthogonalization`` (``fft=False``).

    Args:
        k_coef_lip: Lipschitz factor handed to ``set_klip_factor``.
        in_graph: when true, normalisation runs inside ``call`` during
            training and the result is cached in ``self.wbar`` for inference.
            When false, ``normalize_kernel`` is installed as the depthwise
            constraint instead and ``stop_gradient`` is forced to false.
        fft: select the FFT-based normalisation.
        stop_gradient: block gradients through the FFT normalisation factor.

    The remaining arguments are those of ``DepthwiseConv2D``.

    Raises:
        RuntimeError: if ``dilation_rate`` is not 1 or ``depth_multiplier``
            is not 1.
    """

    def __init__(
        self,
        kernel_size,
        strides=(1, 1),
        padding="valid",
        depth_multiplier=1,
        data_format=None,
        dilation_rate=(1, 1),
        activation=None,
        use_bias=True,
        depthwise_initializer="glorot_uniform",
        bias_initializer="zeros",
        depthwise_regularizer=None,
        bias_regularizer=None,
        activity_regularizer=None,
        depthwise_constraint=None,
        bias_constraint=None,
        k_coef_lip=1.0,
        in_graph=True,
        fft=True,
        stop_gradient=True,
        **kwargs
    ):
        self.in_graph = in_graph
        self.fft = fft
        self.stop_gradient = stop_gradient
        if not in_graph:
            # Normalisation then happens through the weight constraint, so any
            # user-supplied constraint is replaced.
            self.stop_gradient = False
            depthwise_constraint = self.normalize_kernel
        if not (
            (dilation_rate == (1, 1))
            or (dilation_rate == [1, 1])
            or (dilation_rate == 1)
        ):
            raise RuntimeError("SpectralDepthwiseConv2D does not support dilation rate")
        if depth_multiplier != 1:
            raise RuntimeError(
                "SpectralDepthwiseConv2D does not support depth multiplier"
            )

        super(SpectralDepthwiseConv2D, self).__init__(
            kernel_size,
            strides=strides,
            padding=padding,
            depth_multiplier=depth_multiplier,
            data_format=data_format,
            dilation_rate=dilation_rate,
            activation=activation,
            use_bias=use_bias,
            depthwise_initializer=depthwise_initializer,
            bias_initializer=bias_initializer,
            depthwise_regularizer=depthwise_regularizer,
            bias_regularizer=bias_regularizer,
            activity_regularizer=activity_regularizer,
            depthwise_constraint=depthwise_constraint,
            bias_constraint=bias_constraint,
            **kwargs
        )
        self._kwargs = kwargs
        self.set_klip_factor(k_coef_lip)
        self.u = None

    def build(self, input_shape):
        """Build the parent layer and, in graph mode, the cached kernel."""
        super(SpectralDepthwiseConv2D, self).build(input_shape)
        self._init_lip_coef(input_shape)
        if self.in_graph:
            # Non-trainable copy holding the last normalised kernel.
            self.wbar = tf.Variable(self.depthwise_kernel.read_value(), trainable=False)
        self.built = True

    def _compute_lip_coef(self, input_shape=None):
        """Return the corrective factor (1 when the FFT normalisation is used).

        Raises:
            RuntimeError: if ``data_format`` is neither ``"channels_last"``
                nor ``"channels_first"`` (reached only without FFT and with
                non-"valid" padding).
        """
        if self.fft:
            return 1
        if self.padding in ["valid"]:
            stride = np.sqrt(self.strides[0] * self.strides[1])
            ksize = np.sqrt(self.kernel_size[0] * self.kernel_size[1])
            return stride / ksize
        stride = np.prod(self.strides)
        k1 = self.kernel_size[0]
        k1_div2 = (k1 - 1) / 2
        k2 = self.kernel_size[1]
        k2_div2 = (k2 - 1) / 2
        if self.data_format == "channels_last":
            h = input_shape[-3]
            w = input_shape[-2]
        elif self.data_format == "channels_first":
            h = input_shape[-2]
            w = input_shape[-1]
        else:
            # The placeholder is required: without it the ``%`` operator
            # raises TypeError instead of the intended RuntimeError.
            raise RuntimeError("data_format not understood: %s" % self.data_format)
        if stride == 1:
            coefLip = np.sqrt(
                (w * h)
                / (
                    (k1 * h - k1_div2 * (k1_div2 + 1))
                    * (k2 * w - k2_div2 * (k2_div2 + 1))
                )
            )
        else:
            sn1 = self.strides[0]
            sn2 = self.strides[1]
            coefLip = np.sqrt(1.0 / (np.ceil(k1 / sn1) * np.ceil(k2 / sn2)))
        return coefLip

    @tf.function
    def normalize_kernel(self, kernel):
        """Return ``kernel`` normalised and multiplied by ``self._get_coef()``."""
        if self.fft:
            # Move the two spatial axes last so that fft2d acts on them.
            conv_tr = tf.cast(tf.transpose(kernel, perm=[2, 3, 0, 1]), tf.complex64)
            conv_shape = kernel.get_shape().as_list()
            # Zero-pad the kernel towards the size of axes 1 and 2 of the build
            # input shape. NOTE(review): an odd difference pads one entry short
            # because both sides receive ``pad // 2``.
            pad_1 = self._build_input_shape[1] - conv_shape[0]
            pad_2 = self._build_input_shape[2] - conv_shape[1]
            padding = tf.constant(
                [
                    [0, 0],
                    [0, 0],
                    [pad_1 // 2, pad_1 // 2],
                    [pad_2 // 2, pad_2 // 2],
                ]
            )
            conv_tr_padded = tf.pad(conv_tr, padding)
            # apply FFT
            transform_coeff = tf.abs(tf.signal.fft2d(conv_tr_padded))
            # Largest magnitude per leading pair of axes.
            l = tf.reduce_max(transform_coeff, axis=[-2, -1], keepdims=True)
            if self.stop_gradient:
                l = tf.stop_gradient(l)
            return (kernel / tf.transpose(l, perm=[2, 3, 0, 1])) * self._get_coef()

        wbar = reshaped_depth_kernel_orthogonalization(kernel, niter_bjorck=9)
        return wbar * self._get_coef()

    @tf.function
    def call(self, x, training=True):
        """Apply the depthwise convolution with the normalised kernel."""
        if self.in_graph:
            if training:
                wbar = self.normalize_kernel(self.depthwise_kernel)
                self.wbar.assign(wbar)
            else:
                wbar = self.wbar
        else:
            # The constraint keeps the stored kernel normalised.
            wbar = self.depthwise_kernel
        outputs = K.depthwise_conv2d(
            x,
            wbar,
            strides=self.strides,
            padding=self.padding,
            data_format=self.data_format,
            dilation_rate=self.dilation_rate,
        )
        if self.use_bias:
            outputs = K.bias_add(outputs, self.bias, data_format=self.data_format)
        if self.activation is not None:
            return self.activation(outputs)
        return outputs

    def get_config(self):
        """Return the parent config extended with this layer's own arguments."""
        config = dict(super(SpectralDepthwiseConv2D, self).get_config())
        config.update(
            {
                "k_coef_lip": self.k_coef_lip,
                "in_graph": self.in_graph,
                "fft": self.fft,
                "stop_gradient": self.stop_gradient,
            }
        )
        # ``old_padding`` is never set by this class; fall back to the regular
        # padding so that serialisation does not raise AttributeError.
        config["padding"] = getattr(self, "old_padding", self.padding)
        return config

    def condense(self):
        """Write the normalised kernel back into the weight (constraint mode only)."""
        if not self.in_graph:
            wbar = self.normalize_kernel(self.depthwise_kernel)
            self.depthwise_kernel.assign(wbar)

    def vanilla_export(self):
        """Return a plain ``DepthwiseConv2D`` with the kernel used at inference."""
        self._kwargs["name"] = self.name
        layer = DepthwiseConv2D(
            self.kernel_size,
            strides=self.strides,
            padding=self.padding,
            depth_multiplier=self.depth_multiplier,
            data_format=self.data_format,
            dilation_rate=self.dilation_rate,
            activation=self.activation,
            use_bias=self.use_bias,
            depthwise_initializer=self.depthwise_initializer,
            bias_initializer=self.bias_initializer,
            depthwise_regularizer=self.depthwise_regularizer,
            bias_regularizer=self.bias_regularizer,
            activity_regularizer=self.activity_regularizer,
            depthwise_constraint=self.depthwise_constraint,
            bias_constraint=self.bias_constraint,
            **self._kwargs
        )
        layer.build(self.input_shape)
        # ``wbar`` only exists in graph mode; otherwise ``call`` convolves with
        # the stored kernel directly, so export that one.
        exported_kernel = self.wbar if self.in_graph else self.depthwise_kernel
        layer.depthwise_kernel.assign(exported_kernel)
        if self.use_bias:
            layer.bias.assign(self.bias)
        return layer


@register_keras_serializable("dlt", "OrthoDepthwiseConv2D")
class OrthoDepthwiseConv2D(DepthwiseConv2D, LipschitzLayer, Condensable):
    """Depthwise convolution with a fixed, randomly drawn one-hot kernel.

    At build time every input channel receives a kernel holding a single 1
    (at the position of the maximum of a uniform random draw) and zeros
    elsewhere. The kernel is stored as a non-trainable weight.

    All arguments are those of ``DepthwiseConv2D``.

    NOTE(review): the generated kernel has a trailing dimension of 1, so the
    assignment presumably only succeeds with ``depth_multiplier == 1``.
    """

    def __init__(
        self,
        kernel_size,
        strides=(1, 1),
        padding="valid",
        depth_multiplier=1,
        data_format=None,
        dilation_rate=(1, 1),
        activation=None,
        use_bias=True,
        depthwise_initializer="glorot_uniform",
        bias_initializer="zeros",
        depthwise_regularizer=None,
        bias_regularizer=None,
        activity_regularizer=None,
        depthwise_constraint=None,
        bias_constraint=None,
        **kwargs
    ):
        super(OrthoDepthwiseConv2D, self).__init__(
            kernel_size,
            strides=strides,
            padding=padding,
            depth_multiplier=depth_multiplier,
            data_format=data_format,
            dilation_rate=dilation_rate,
            activation=activation,
            use_bias=use_bias,
            depthwise_initializer=depthwise_initializer,
            bias_initializer=bias_initializer,
            depthwise_regularizer=depthwise_regularizer,
            bias_regularizer=bias_regularizer,
            activity_regularizer=activity_regularizer,
            depthwise_constraint=depthwise_constraint,
            bias_constraint=bias_constraint,
            **kwargs
        )
        self._kwargs = kwargs

    def _compute_lip_coef(self, input_shape=None):
        """Return the Lipschitz correction factor (none is applied)."""
        return 1

    def build(self, input_shape):
        """Build the parent layer, then install the fixed one-hot kernel."""
        super(OrthoDepthwiseConv2D, self).build(input_shape)
        shape_k = self.depthwise_kernel.shape
        nb = shape_k[0] * shape_k[1]
        c_in = shape_k[2]
        ortho_kernel = tf.random.uniform(
            [c_in, nb], dtype=tf.float64, minval=-1, maxval=1
        )

        # Keep, for each channel, only the position of its maximum; dividing
        # by the maximum turns that entry into exactly 1.
        ortho_max = tf.reduce_max(ortho_kernel, axis=1, keepdims=True)
        ortho_kernel = tf.cast(
            tf.where(ortho_kernel == ortho_max, ortho_kernel / ortho_max, 0),
            tf.float32,
        )
        # Use both spatial sizes so that non-square kernels reshape correctly.
        ortho_kernel = tf.reshape(ortho_kernel, (c_in, shape_k[0], shape_k[1], 1))

        ortho_kernel = tf.transpose(ortho_kernel, perm=[1, 2, 0, 3])
        # NOTE(review): this rebinds ``self.depthwise_kernel``; the kernel
        # created by the parent ``build`` is no longer referenced here.
        self.depthwise_kernel = self.add_weight(
            shape=shape_k,
            name="depthwise_kernel",
            trainable=False,
            # Depthwise layers configure ``depthwise_*`` attributes, not
            # ``kernel_*`` ones.
            initializer=self.depthwise_initializer,
            constraint=self.depthwise_constraint,
            dtype=self.dtype,
        )
        self.depthwise_kernel.assign(ortho_kernel)
        tf.print(self.depthwise_kernel.shape)
        self.built = True

    def get_config(self):
        """Return the parent layer config."""
        config = dict(super(OrthoDepthwiseConv2D, self).get_config())
        # ``old_padding`` is never set by this class; fall back to the regular
        # padding so that serialisation does not raise AttributeError.
        config["padding"] = getattr(self, "old_padding", self.padding)
        return config

    def condense(self):
        """Nothing to condense for this layer."""
        pass

    def vanilla_export(self):
        """Return a plain ``DepthwiseConv2D`` carrying the same kernel and bias."""
        self._kwargs["name"] = self.name
        layer = DepthwiseConv2D(
            self.kernel_size,
            strides=self.strides,
            padding=self.padding,
            depth_multiplier=self.depth_multiplier,
            data_format=self.data_format,
            dilation_rate=self.dilation_rate,
            activation=self.activation,
            use_bias=self.use_bias,
            depthwise_initializer=self.depthwise_initializer,
            bias_initializer=self.bias_initializer,
            depthwise_regularizer=self.depthwise_regularizer,
            bias_regularizer=self.bias_regularizer,
            activity_regularizer=self.activity_regularizer,
            depthwise_constraint=self.depthwise_constraint,
            bias_constraint=self.bias_constraint,
            **self._kwargs
        )
        layer.build(self.input_shape)
        layer.depthwise_kernel.assign(self.depthwise_kernel)
        if self.use_bias:
            layer.bias.assign(self.bias)
        return layer

@tf.function
def cayley_norm(kernel):
    """Apply a Cayley-style transform to ``kernel``.

    The kernel is flattened to a matrix ``W`` of shape ``(-1, last_dim)``.
    With ``C`` the upper-triangular part of ``W`` and ``D = C - C^T``, the
    function returns ``(D + I) @ inv(I - D)`` reshaped to the kernel shape.

    Args:
        kernel: tensor whose flattened matrix is square (``C - C^T`` requires
            matching dimensions) and of a dtype accepted by ``tf.linalg.inv``.

    Returns:
        A tensor with the shape of ``kernel``.
    """
    kernel_shape = kernel.shape
    flat = tf.reshape(kernel, [-1, kernel_shape[-1]])
    n = flat.shape[0]
    upper = tf.linalg.band_part(flat, 0, -1)
    skew = upper - tf.transpose(upper)
    # Build the identity in the kernel's dtype; the float32 default of tf.eye
    # would clash with float16 or float64 kernels.
    identity = tf.eye(n, dtype=flat.dtype)
    inverse = tf.linalg.inv(identity - skew)
    return K.reshape((skew + identity) @ inverse, kernel.shape)
    
