import tensorflow as tf
import numpy as np
from tensorflow.keras.callbacks import ReduceLROnPlateau, LearningRateScheduler
from tensorflow.keras.layers import (
    Layer,
    Dense,
    Conv2D,
    InputLayer,
    AveragePooling2D,
    GlobalAveragePooling2D,
    BatchNormalization,
    Lambda,
    
    Maximum,
    Concatenate,
    ReLU
)
from tensorflow.keras import backend as K

from deel.lip.extra_layers import SpectralDepthwiseConv2D
from deel.lip.normalizers import compute_sigma, get_operator_norm,get_operator_stats,get_operator_stats_np
from deel.lip.regularizers import LorthRegularizer
from deel.lip.layers import SpectralConv2D, SpectralDense, FrobeniusDense, ScaledAveragePooling2D, ScaledL2NormPooling2D, InvertibleDownSampling,ScaledGlobalAveragePooling2D
from deel.lip.model import Model

def analyze_kernel(kernel, strides, w, h, input_shape, verbose=False):
    """Collect singular-value statistics for a convolution kernel.

    Two families of statistics are gathered:

    * "dense" statistics: singular values of the kernel array reshaped to
      2-D (all leading axes flattened, last axis kept);
    * "conv" statistics: the values returned by ``get_operator_stats`` and
      ``compute_sigma`` for the same kernel.

    Args:
        kernel: object exposing ``.numpy()`` and ``.shape`` (it is also
            handed as-is to ``get_operator_stats``, ``compute_sigma`` and,
            in verbose mode, to a ``LorthRegularizer``).
        strides: indexable of strides; only ``strides[0]`` is read.
        w: unused, kept so existing callers keep working.
        h: unused, kept so existing callers keep working.
        input_shape: unused, kept so existing callers keep working.
        verbose: when true, print a two-line summary of the statistics.

    Returns:
        dict with keys ``"s_max"``, ``"s_min"``, ``"s_mean"``, ``"coeff"``,
        ``"s_conv_max"``, ``"s_conv_min"`` and ``"s_conv_mean"``.
    """
    mat = kernel.numpy()
    # First axis of the kernel array is taken as the kernel size.
    kernel_size = mat.shape[0]
    mat = mat.reshape([-1, mat.shape[-1]])

    # Only the singular values are needed: skip building U and V.
    s = np.linalg.svd(mat, compute_uv=False)
    s_max = s.max()
    s_min = s.min()
    s_mean = s.mean()
    coeff = float(strides[0]) / float(kernel_size)

    # NOTE(review): the two extra arguments are kernel.shape[1] and
    # kernel.shape[2], not the `w`/`h` parameters -- confirm this is what
    # get_operator_stats expects.
    s_conv_max_nostride, s_conv_min, s_conv_mean = get_operator_stats(
        kernel, kernel.shape[1], kernel.shape[2]
    )
    s_conv_max = compute_sigma(kernel, stride=strides[0], niter=2)

    info_dict = {
        "s_max": s_max,
        "s_min": s_min,
        "s_mean": s_mean,
        "coeff": coeff,
        "s_conv_max": s_conv_max,
        "s_conv_min": s_conv_min,
        "s_conv_mean": s_conv_mean,
    }

    if verbose:
        # The regularizer is only needed for the printed "orth" value, so it
        # is built lazily instead of on every call.
        reg = LorthRegularizer(
            kernel_shape=kernel.shape,
            stride=strides[0],
            lambdaLorth=1.,
            flag_deconv=False,
        )
        print(f"singular dense : {s_min:0.2f} <= {s_mean:0.2f} <= {s_max:0.2f} coeff : {coeff:0.2f}")
        print(f"singular conv  : {s_conv_min:0.2f} <= {s_conv_mean:0.2f} <= {s_conv_max:0.2f} [{s_conv_max_nostride:0.2f}] orth {reg(kernel)}")
    return info_dict

def analyse_network(model):
    """Print and collect kernel statistics for the layers of ``model``.

    Walks ``model.layers`` and, depending on the exact layer type:

    * ``Conv2D``: prints the kernel shape and first stride, then runs
      ``analyze_kernel`` (verbose) on ``layer.kernel``;
    * ``SpectralConv2D``: same, but also prints ``layer.sig`` and analyses
      ``layer.wbar`` instead of ``layer.kernel``;
    * ``FrobeniusDense``: prints the norm of one kernel column.

    Args:
        model: object exposing an iterable ``layers`` attribute.

    Returns:
        dict mapping layer name to the dict returned by ``analyze_kernel``
        for every ``Conv2D`` / ``SpectralConv2D`` layer encountered.
    """
    net_infos = {}
    for layer in model.layers:
        # Exact type match is kept on purpose: an isinstance check would also
        # catch subclasses, and the branches read different attributes
        # (`kernel` vs `wbar`).
        layer_type = type(layer)

        if layer_type is Conv2D:
            print("conv2D shape ", layer.kernel.shape, "strides :", layer.strides[0])
            net_infos[layer.name] = analyze_kernel(
                layer.kernel,
                layer.strides,
                layer.input_shape[1] - 2,
                layer.input_shape[2] - 2,
                layer.input.shape,
                verbose=True,
            )
            print("     *****       ")
        elif layer_type is SpectralConv2D:
            print(
                "spectral shape ", layer.kernel.shape,
                "strides :", layer.strides[0],
                "sigma", layer.sig.numpy(),
            )
            net_infos[layer.name] = analyze_kernel(
                layer.wbar,
                layer.strides,
                layer.input_shape[1] - 2,
                layer.input_shape[2] - 2,
                layer.input.shape,
                verbose=True,
            )
            print("     *****       ")
        elif layer_type is FrobeniusDense:
            mat = layer.kernel.numpy()
            # NOTE(review): only column index 1 is inspected; this raises
            # IndexError if the kernel has a single column -- confirm intent.
            print(np.linalg.norm(mat[:, 1]))

    # Hand the collected statistics back to the caller instead of dropping
    # them when the function exits.
    return net_infos
        