import tensorflow as tf
from tensorflow.keras.regularizers import Regularizer, L1, L2, L1L2
        
class SpectralNormRegularizer(Regularizer):
    """Penalize the spectral norm (largest singular value) of a 2-D kernel.

    The spectral norm ``sigma`` is estimated with ``num_iter`` steps of power
    iteration, and the returned penalty is ``reg_lambda * sigma ** p``. The
    iteration vector ``v`` lives in a non-trainable ``tf.Variable`` and is
    updated on every call, so later calls warm-start from the previous
    estimate.

    Args:
        dim: Number of rows of ``v``. ``__call__`` computes ``W^T v``, so this
            must equal ``W.shape[0]``.
        reg_lambda: Scale factor applied to the penalty.
        num_iter: Power-iteration steps per call; must be at least 1.
        p: Exponent applied to the estimated spectral norm.

    Raises:
        ValueError: If ``num_iter`` is less than 1.
    """

    # Floor for squared norms during normalization. It keeps the result (and
    # its gradient) finite when W^T v or W u vanishes, e.g. for a zero kernel.
    _EPS = 1e-12

    def __init__(self, dim, reg_lambda=1.0, num_iter=2, p=1.0):
        self.dim = int(dim)
        self.reg_lambda = float(reg_lambda)
        self.num_iter = int(num_iter)
        self.p = float(p)
        if self.num_iter < 1:
            # With zero iterations __call__ has no u/v to build sigma from.
            raise ValueError(f"num_iter must be >= 1, got {self.num_iter}")

        self.v = tf.Variable(
            tf.random.normal((self.dim, 1), 0.0, 0.05),
            trainable=False,
        )

    def __call__(self, W):
        """Return ``reg_lambda * sigma(W) ** p`` as a scalar tensor.

        Args:
            W: 2-D kernel, expected as (dim_in, dim_out); ``W.shape[0]`` must
                equal ``dim``.

        Side effect: ``self.v`` is overwritten with the newest estimate after
        every iteration. The exception is when ``W u`` is numerically zero:
        then the previous estimate is kept, so a degenerate kernel cannot
        leave ``self.v`` stuck at zeros/NaNs for all later calls.
        """
        for _ in range(self.num_iter):
            # self.v is re-read every step. Reading a non-trainable variable
            # carries no gradient, so only the last step contributes to dL/dW.
            v_prev = tf.cast(self.v, W.dtype)
            u = tf.math.l2_normalize(
                tf.matmul(W, v_prev, transpose_a=True), epsilon=self._EPS
            )
            wu = tf.matmul(W, u)
            v = tf.math.l2_normalize(wu, epsilon=self._EPS)
            keep_new = tf.reduce_sum(tf.square(wu)) > self._EPS
            self.v.assign(tf.where(keep_new, tf.cast(v, self.v.dtype), self.v))
        # sigma = v^T W u as a 1x1 matrix; in exact arithmetic this is ||W u||.
        sigma = tf.squeeze(tf.matmul(tf.matmul(v, W, transpose_a=True), u))
        return tf.math.pow(sigma, self.p) * self.reg_lambda

    def get_config(self):
        """Return the constructor arguments so Keras can rebuild this object."""
        return {
            "dim": self.dim,
            "reg_lambda": self.reg_lambda,
            "num_iter": self.num_iter,
            "p": self.p,
        }
    
    
class IntervalRegularizer(Regularizer):
    """Combined L1/L2 penalty over a kernel and two auxiliary tensors.

    ``__call__`` takes four tensors, so it must be invoked explicitly by the
    caller. It is not usable as a layer's ``kernel_regularizer``, because
    Keras passes a single tensor to regularizers. The penalty is::

        l1_lambda * (mean|W| + mean|Cl - Cr| + mean|B|)
      + l2_lambda * (mean(W**2) + mean((Cl - Cr)**2) + mean(B**2))

    Args:
        l1_lambda: Initial L1 coefficient. It is stored in a non-trainable
            variable, so it can be changed later with ``assign``.
        l2_lambda: Initial L2 coefficient, stored the same way.
    """

    def __init__(self, l1_lambda=1e-4, l2_lambda=1e-4):
        self.l1_lambda = tf.Variable(l1_lambda, trainable=False, name="lambda_l1")
        self.l2_lambda = tf.Variable(l2_lambda, trainable=False, name="lambda_l2")

    def __call__(self, W, Cl, Cr, B):
        """Return the combined L1/L2 penalty as a scalar tensor.

        Each tensor is averaged separately. This equals ``mean(|W| + |Cl-Cr| +
        |B|)`` whenever those shapes broadcast together, and it also works
        when they do not. ``Cl`` and ``Cr`` must still broadcast against each
        other.
        """
        center_diff = Cl - Cr
        terms = (W, center_diff, B)

        l1_sum = tf.add_n([tf.reduce_mean(tf.math.abs(t)) for t in terms])
        l2_sum = tf.add_n([tf.reduce_mean(tf.math.square(t)) for t in terms])

        # Match the coefficient dtype to the loss. This covers int-valued
        # lambdas (stored as int32) and non-float32 weights.
        l1_lambda = tf.cast(self.l1_lambda, l1_sum.dtype)
        l2_lambda = tf.cast(self.l2_lambda, l2_sum.dtype)
        return l1_lambda * l1_sum + l2_lambda * l2_sum

    def get_config(self):
        """Return the current coefficients as constructor arguments."""
        return {
            "l1_lambda": float(self.l1_lambda.numpy()),
            "l2_lambda": float(self.l2_lambda.numpy()),
        }