import utils
from jax import vmap, lax # for auto-vectorizing functions
from jax import jacfwd, jacrev,hessian
from functools import partial # for use with vmap
from jax import jit, grad, random # for compiling functions for speedup
from jax.experimental import stax # neural network library
from jax.experimental.stax import Conv, Dense, Relu, LeakyRelu, Flatten, LogSoftmax # neural network layers
from jax.experimental.optimizers import l2_norm
from jax.nn.initializers import zeros
from jax.nn import leaky_relu
from jax.experimental.stax import elementwise
import matplotlib.pyplot as plt
from jax import numpy as np
import cvxpy as cp
import numpy as onp
from jax.interpreters import xla

from jax.experimental import optimizers
from jax.tree_util import tree_multimap  # Element-wise manipulation of collections of numpy arrays

import pandas as pd
import scprep
import m_phate

from scipy.spatial.distance import cdist, pdist, squareform
from sklearn.neighbors import NearestNeighbors, DistanceMetric
import pickle 

import joblib
from joblib import Parallel, delayed
from joblib import dump, load

#from tqdm import tqdm
from tqdm.notebook import tqdm
import time
import itertools
import numpy.random as npr

import m_phate.train
import m_phate.data

"""
continuous surrogate for hamming
distance between z and y with parameter a
"""
def _cts_hamming(a,z,y):
    return np.divide(np.linalg.norm(np.abs(np.tanh(a*z) - np.tanh(a*y))), 2*z.shape[0])

"""
compute cts hamming distance
between elements in profiles 1 and 2
"""
def cts_hamming(a,profiles1,profiles2):
    hammings = [_cts_hamming(a,z,y) for (z,y) in zip(profiles1,profiles2)]
    #hammings = vmap(partial(_cts_hamming,a))(profiles1, profiles2)
    h = np.divide(np.sum(hammings),len(profiles1))
    return(h)

"""
base hamming regularizer
"""
def hamming_regularizer(params, inputs, targets, eps=0.3,a=8):
    p = np.where(random.bernoulli(rng,p=0.5,shape=inputs.shape)==1,eps,-eps)
    x_p1 = inputs + p
    x_p2 = inputs - p
    logits1, activations1 = net_walk(params,x_p1)
    logits2, activations2 = net_walk(params,x_p2)
    #profiles1 = [np.multiply(act, (act > 0)) for act in activations1]
    #profiles2 = [np.multiply(act, (act > 0)) for act in activations2]
    hammings = vmap(partial(cts_hamming,a))(activations1, activations2)
    reg = np.mean(hammings)
    return reg

"""
todo
"""
def sampled_lipschitz_regularizer(params, inputs, targets):
    eps = 0.1
    p = np.where(random.bernoulli(rng,p=0.5,shape=inputs.shape)==1,eps,-eps)
    x_p1 = inputs + p
    x_p2 = inputs - p
    #logits1, profiles1, activations1 = net_walk(params,x_p1)
    #logits2, profiles2, activations2 = net_walk(params,x_p2)
    logits1 = net_apply(params,x_p1)
    logits2 = net_apply(params,x_p2)
    reg = np.linalg.norm(logits1 - logits2) / np.linalg.norm(x_p1 - x_p2)
    return reg

"""
calculate pairwise euclidean distance
between all pair-rows of A and B
"""
def fastdiff(A,B):
    #return np.sum(np.abs(A[:, None, :] - B[None, :, :]), axis=-1)
    return np.sum((A[:, None, :] - B[None, :, :])**2, axis=-1)
    #return np.sqrt(((a[:, None] - b[:, :, None]) ** 2).sum(0))
    #return ((A[:, None] - B[:, :, None]) ** 2).sum(0)
    #return np.einsum('ij,ij->i',A,A)[:,None] + np.einsum('ij,ij->i',B,B) - \
    #                                                        2*np.dot(A,B.T)

"""
base pairwise stochastic manifold regularizer
regularize all input/output pairs for all layers
"""
def m_regularizer_pairwise(params, inputs, logits_ij, targets): 
    N_b = inputs.shape[0]

    #xijdiff = fastdiff(inputs) + np.eye(N_b)
    fijdiff = fastdiff(logits_ij,logits_ij)
    quot = np.divide(fijdiff, xijdiff)
    reg = np.sum(quot)
    reg = np.divide(reg,np.square(N_b))
    for i in range(1,a):
        xijdiff = fastdiff(acts[i], acts[i])
        nn = np.nan_to_num(np.divide(fijdiff, xijdiff))
        reg += np.sum(nn)
    reg = np.divide(reg, a * np.square(N_b))
    return reg

"""
base pairwise stochastic manifold regularizer
regularize all input/output pairs for all layers
"""
def l_regularizer_pairwise(params, inputs, logits_ij, targets): 
    N_b = inputs.shape[0]
    logits_ij, profiles1, acts = net_walk(params,inputs)
    a = len(acts)
    
    sigma = 100.0
    xijdiff = fastdiff(inputs,inputs)
    xijdiff = np.exp(-1*np.divide(xijdiff,2*np.square(sigma)))
    
    logits_ij = np.divide(logits_ij, np.repeat(np.expand_dims(np.sum(xijdiff, axis=-1),axis=1), 
                                               logits_ij.shape[-1], axis=1))
    
    fijdiff = fastdiff(logits_ij,logits_ij)
    quot = np.multiply(fijdiff, xijdiff)
    reg = np.sum(quot)
    reg = np.divide(reg,np.square(N_b))
    for i in range(1,a):
        xijdiff = fastdiff(acts[i], acts[i])
        xijdiff = np.exp(-1*np.divide(xijdiff,2*np.square(sigma)))
        
        logits_ij = np.divide(logits_ij, np.repeat(np.expand_dims(np.sum(xijdiff, axis=-1),axis=1), 
                                               logits_ij.shape[-1], axis=1))
        fijdiff = fastdiff(logits_ij,logits_ij)
        
        nn = np.multiply(fijdiff, xijdiff)
        
        reg += np.sum(nn)
    reg = np.divide(reg, a * np.square(N_b))
    return reg


"""
base stochastic laplacian masked regularizer
regularize only input and output logits
"""
def l_regularizer_normalized_masked(inputs, logits_ij, targets, sigma=100):
    N_b = inputs.shape[0]
    xijdiff = fastdiff(inputs,inputs)
    sigma = 100.0
    
    #mask = np.array(onp.fromfunction(lambda i, j: (targets[i] == targets[j]).all(), (N_b, N_b), dtype=int))
    mask = np.dot(targets, targets.T)
    xijdiff = np.exp(-1*np.divide(xijdiff,2*np.square(sigma)))
    xijdiff = np.multiply(xijdiff, mask)
    # normalize logits by degree
    logits_ij = np.divide(logits_ij, np.repeat(np.expand_dims(np.sum(xijdiff, axis=-1),axis=1), 
                                               logits_ij.shape[-1], axis=1))
    fijdiff = fastdiff(logits_ij,logits_ij)
    
    nn = np.multiply(fijdiff, xijdiff)
    reg = np.divide(np.sum(nn),np.square(N_b))
        
    return reg  

def laplip_grad_reg(params, inputs,logits_ij,targets):
    """Gaussian-kernel regularizer on per-example gradients of ``loss``.

    For every example the first component of ``grad(loss)`` is evaluated on a
    batch of one. The gradients are divided by the row sums of a Gaussian
    affinity (sigma = 100) built from pairwise squared input distances, and
    the affinity-weighted mean of their pairwise squared distances is
    returned.

    Args:
        params: network parameters (first argument of ``loss``).
        inputs: batch of inputs; the leading axis is the batch.
        logits_ij: unused.
        targets: per-example targets forwarded to ``loss``.

    Returns:
        Scalar regularization value.
    """
    N_b = inputs.shape[0]
    sigma = 100.0

    def per_example_grad(x, t):
        # grad differentiates w.r.t. the first argument of `loss`; only the
        # first entry of the result is kept.
        return grad(loss)(params, np.expand_dims(x, axis=0), t)[0]

    g = vmap(per_example_grad)(inputs, targets)

    affinity = np.exp(-1*np.divide(fastdiff(inputs, inputs), 2*np.square(sigma)))
    # normalize gradients by degree
    degree = np.expand_dims(np.sum(affinity, axis=-1), axis=1)
    g = np.divide(g, np.repeat(degree, g[0].shape[0], axis=1))

    weighted = np.multiply(fastdiff(g, g), affinity)
    return np.divide(np.sum(weighted), np.square(N_b))

def laplip_grad_reg_samp(params, inputs,logits_ij,targets):
    """Sampled Lipschitz quotient of per-example gradients of ``loss``.

    Perturbs the inputs by ``+/-eps`` per entry, evaluates the first
    component of ``grad(loss)`` for every example at both perturbed points
    and returns the norm of the gradient difference divided by the norm of
    the input difference.

    NOTE(review): ``eps`` and ``rng`` are not defined in this function; they
    are read from module scope -- confirm both exist there.

    Args:
        params: network parameters (first argument of ``loss``).
        inputs: batch of inputs.
        logits_ij: unused.
        targets: per-example targets forwarded to ``loss``.

    Returns:
        Scalar ratio of gradient change to input change.
    """
    N_b = inputs.shape[0]
    flips = random.bernoulli(rng, p=0.5, shape=inputs.shape)
    delta = np.where(flips == 1, eps, -eps)
    x_plus = inputs + delta
    x_minus = inputs - delta

    def per_example_grad(x, t):
        return grad(loss)(params, np.expand_dims(x, axis=0), t)[0]

    batched_grad = vmap(per_example_grad)
    g_plus = batched_grad(x_plus, targets)
    g_minus = batched_grad(x_minus, targets)
    return np.linalg.norm(g_plus - g_minus) / np.linalg.norm(x_plus - x_minus)

def lipschitz_gradient(params, inputs,logits_ij,targets):
    """Pairwise Lipschitz quotient of per-example gradients of ``lo``.

    For each example the first component of ``grad(lo)`` (differentiated with
    respect to its first argument, the ``(input, target)`` pair) is flattened
    to a vector. The result is the mean over all pairs of squared gradient
    distance divided by squared input distance.

    Args:
        params: network parameters (second argument of ``lo``).
        inputs: batch of inputs; the leading axis is the batch.
        logits_ij: unused.
        targets: per-example targets.

    Returns:
        Scalar regularization value.
    """
    N_b = inputs.shape[0]

    def gfun(p, inp, targ):
        # Flatten to 1-D so the vmapped result is (N_b, D). The previous
        # (1, -1) reshape produced (N_b, 1, D), which made fastdiff return
        # (N_b, N_b, 1) and the division below broadcast to (N_b, N_b, N_b).
        return grad(lo)((np.expand_dims(inp,axis=0), targ),p)[0].reshape(-1)

    g = vmap(partial(gfun,params))(inputs, targets)
    flat_inputs = inputs.reshape(N_b,-1)
    # Identity on the diagonal avoids 0/0 for the i == j pairs.
    xijdiff = fastdiff(flat_inputs, flat_inputs) + np.eye(N_b)

    fijdiff = fastdiff(g,g)
    quot = np.divide(fijdiff, xijdiff)
    # Duplicate inputs give zero off-diagonal distances; drop those quotients.
    quot = np.nan_to_num(quot, nan=0.0, posinf=0.0, neginf=0.0)
    reg = np.sum(quot)
    reg = np.divide(reg,np.square(N_b))
    return reg

def analytic_volume(params, inputs, logits_ij, targets):
    """Mean determinant of the per-example ``A`` matrices from ``get_pwl``.

    The first activation returned by ``net_walk`` is turned into a profile
    per example with ``act2prof``; ``get_pwl`` maps each profile to a tuple
    whose first two entries are used as ``A`` and ``b``.

    Args:
        params: network parameters.
        inputs: batch of inputs.
        logits_ij: unused.
        targets: unused.

    Returns:
        Scalar mean of ``det(A)`` over the batch.
    """
    _, activations = net_walk(params, inputs)
    profiles = vmap(act2prof)(activations[0])
    polys = vmap(partial(get_pwl, params))(profiles)

    def region_det(inp, A, b):
        # Only A contributes; inp and b are mapped alongside it.
        return np.linalg.det(A)

    per_example = vmap(region_det)(inputs, polys[0], polys[1])
    return np.mean(per_example)

def boost_margin(params, inputs, logits_ij, targets):
    """Mean of ``1 - <logits / ||logits||, target>`` over the batch.

    Args:
        params: unused.
        inputs: unused.
        logits_ij: network outputs, one row per example.
        targets: target vectors, one row per example.

    Returns:
        Scalar margin penalty.
    """
    def one_minus_alignment(logit_vec, target_vec):
        unit_logits = logit_vec / np.linalg.norm(logit_vec)
        return 1 - np.dot(unit_logits, target_vec)

    per_example = vmap(one_minus_alignment)(logits_ij, targets)
    return np.mean(per_example)

def mmr_margin(params, inputs, logits_ij, targets):
    """Margin term built from logit gaps and linear-region normals.

    For every example, the logit of the labelled class is compared with every
    class logit, each gap is divided by ``||v_y - v_s|| + 1`` where ``v`` is
    the second-to-last entry returned by ``get_pwl`` for that example, and the
    quotients are summed. The result is the batch mean.

    Args:
        params: network parameters.
        inputs: batch of inputs.
        logits_ij: unused (logits are recomputed with ``net_walk``).
        targets: label vectors; ``argmax`` picks the labelled class.

    Returns:
        Scalar regularization value.
    """
    logits, activations = net_walk(params,inputs)

    profiles = vmap(act2prof)(activations[0])
    polys = vmap(partial(get_pwl,params))(profiles)
    V = polys[-2]

    def _d_d(lgts, lab, v):
        label_idx = np.argmax(lab)

        def lip(fv_y, f_s, v_s):
            f_y, v_y = fv_y
            # The +1 keeps the denominator non-zero for the labelled class.
            return (f_y - f_s) / (np.linalg.norm(v_y - v_s) + 1)

        d_D_i = vmap(partial(lip, (lgts[label_idx], v[label_idx])))(lgts, v)
        return np.sum(np.nan_to_num(d_D_i))

    reg = np.mean(vmap(_d_d)(logits, targets, V))
    return reg

def minimal_analytic_center(params, inputs, logits_ij, ac_args, targets):
    """Log-barrier term restricted to the rows chosen by ``utils.compute_redun``.

    For each example the rows of ``A`` are normalised to unit norm,
    ``utils.compute_redun`` is called on gradient-stopped NumPy copies to
    obtain row indices, and ``-sum(log(1 - A[idx] @ x + b[idx]))`` is
    evaluated. The fraction of examples for which ``compute_redun`` reported
    an error is printed.

    Args:
        params: network parameters.
        inputs: batch of inputs.
        logits_ij: unused.
        ac_args: forwarded unchanged to ``utils.compute_redun``.
        targets: unused.

    Returns:
        Scalar mean of the per-example barrier values.
    """
    _, activations = net_walk(params, inputs)
    profiles = vmap(act2prof)(activations[0])
    polys = vmap(partial(get_pwl, params))(profiles)

    @jit
    def barrier(inp, A, b, idxs):
        rows = np.take(A, idxs, axis=0)
        slack = 1 + (-np.dot(rows, inp)) + np.take(b, idxs)
        return -np.sum(np.log(slack))

    barrier_values = []
    num_errs = 0
    for inp, A, b in zip(inputs, polys[0], polys[1]):
        A = A / np.linalg.norm(A, axis=1, keepdims=True)
        # compute_redun works on plain NumPy arrays outside the autodiff graph.
        A_host = onp.asarray(lax.stop_gradient(A))
        b_host = onp.asarray(lax.stop_gradient(b))
        inp_host = onp.asarray(lax.stop_gradient(inp))
        keep, err, _ = utils.compute_redun(
            A_host, b_host.reshape(b_host.shape[0], 1), inp_host, ac_args, pl=False)
        num_errs += int(err is True)
        barrier_values.append(barrier(inp, A, b, keep))

    reg = np.mean(np.array(barrier_values))

    print("errs / inps: ",num_errs / len(inputs))
    return reg

def analytic_center(params, inputs, logits_ij, targets):
    """Log-barrier term over all rows of each example's ``(A, b)`` pair.

    For each example the rows of ``A`` are normalised to unit norm and
    ``-sum(log(max(b - A @ x, 1e-8)))`` is evaluated over every row.

    Args:
        params: network parameters.
        inputs: batch of inputs.
        logits_ij: unused.
        targets: unused.

    Returns:
        Scalar mean of the per-example barrier values.
    """
    _, activations = net_walk(params, inputs)
    profiles = vmap(act2prof)(activations[0])
    polys = vmap(partial(get_pwl, params))(profiles)

    @jit
    def barrier(inp, A, b, idxs):
        rows = np.take(A, idxs, axis=0)
        slack = (-np.dot(rows, inp)) + np.take(b, idxs)
        # Clamp so the log stays finite when a constraint is violated.
        return -np.sum(np.log(np.maximum(slack, 1e-8)))

    def example_barrier(inp, A, b):
        A = A / np.linalg.norm(A, axis=1, keepdims=True)
        all_rows = np.arange(A.shape[0])
        return barrier(inp, A, b, all_rows)

    values = [example_barrier(inp, A, b)
              for inp, A, b in zip(inputs, polys[0], polys[1])]
    return np.mean(np.array(values))

def angle_sum(params, inputs, logits_ij, targets, eps=0.3,):
    """Mean inverse cosine similarity between affine outputs at ``x +/- p``.

    The inputs are perturbed by ``+/-eps`` per entry (module-level ``rng``).
    For both perturbed batches the last two entries returned by ``get_pwl``
    are applied as ``V.T @ x + c``, and the reciprocal of the cosine
    similarity between the two results is averaged over the batch.

    Args:
        params: network parameters.
        inputs: batch of inputs.
        logits_ij: unused.
        targets: unused.
        eps: perturbation magnitude.

    Returns:
        Scalar regularization value.
    """
    p = np.where(random.bernoulli(rng,p=0.5,shape=inputs.shape)==1,eps,-eps)
    x_p1 = inputs + p
    x_p2 = inputs - p
    logits1, activations1 = net_walk(params,x_p1)
    logits2, activations2 = net_walk(params,x_p2)

    # profiles1 / profiles2 were referenced without being defined; derive
    # them from the first activation the same way analytic_volume,
    # mmr_margin and analytic_center do.
    profiles1 = vmap(act2prof)(activations1[0])
    profiles2 = vmap(act2prof)(activations2[0])

    poly1 = vmap(partial(get_pwl,params))(profiles1)
    poly2 = vmap(partial(get_pwl,params))(profiles2)

    def hp(x, poly):
        return np.dot(poly[-2].T, x) + poly[-1]

    def cossim(hp1, hp2):
        return  np.divide(np.dot(hp1, hp2), np.linalg.norm(hp1)*np.linalg.norm(hp2))

    hp1 = vmap(hp)(x_p1, poly1)
    hp2 = vmap(hp)(x_p2, poly2)

    cos = 1./vmap(cossim)(hp1,hp2)

    reg = np.mean(cos)
    return reg

def angle_sum_tst(params, inputs, logits_ij, targets, eps=0.3,static_argnums=(4,)):
    """Experimental variant of ``angle_sum`` comparing ``get_pwl`` outputs directly.

    The inputs are perturbed by ``+/-eps`` per entry (module-level ``rng``).
    For every example the outputs of ``get_pwl`` at the two perturbed points
    are compared with
    ``1 - ||V1 - V2|| / (||V1|| * ||V2||) + ||c1 - c2||``, where ``V`` and
    ``c`` are the last two entries returned by ``get_pwl``.

    Args:
        params: network parameters.
        inputs: batch of inputs.
        logits_ij: unused.
        targets: unused.
        eps: perturbation magnitude.
        static_argnums: unused; kept for interface compatibility.

    Returns:
        Scalar mean of the per-example comparison.
    """
    shift = np.where(random.bernoulli(rng,p=0.5,shape=inputs.shape)==1,eps,-eps)
    x_p1 = inputs + shift
    x_p2 = inputs - shift
    logits1, profiles1, activations1 = net_walk(params,x_p1)
    logits2, profiles2, activations2 = net_walk(params,x_p2)
    # p1 used to be built from profiles2 as well, so both sides were identical.
    p1 = [[np.asarray(prof,dtype=np.int32)] for prof in profiles1[0]]
    p2 = [[np.asarray(prof,dtype=np.int32)] for prof in profiles2[0]]

    polys1 = [get_pwl(params, profile) for profile in p1]
    polys2 = [get_pwl(params, profile) for profile in p2]

    def cossim(p1,p2):
        return  1 - np.divide(np.linalg.norm(p1[-2] - p2[-2]), np.linalg.norm(p1[-2])*np.linalg.norm(p2[-2])) + np.linalg.norm(p1[-1] - p2[-1])

    # The original called vmap on undefined names (poly1 / poly2); the
    # polytopes are Python lists here, so compare them pair by pair.
    cos = np.array([cossim(q1, q2) for q1, q2 in zip(polys1, polys2)])

    reg = np.mean(cos)
    return reg

def Q_rank(params, inputs, logits_ij, targets, eps=0.3,):
    """Mean Frobenius norm of the first ``get_pwl`` output per example.

    Args:
        params: network parameters.
        inputs: batch of inputs.
        logits_ij: unused.
        targets: unused.
        eps: unused.

    Returns:
        Scalar mean of the Frobenius norms.
    """
    _, profiles, _ = net_walk(params, inputs)
    region_polys = vmap(partial(get_pwl, params))(profiles)

    def frobenius(poly):
        return np.linalg.norm(poly[0], 'fro')

    per_example = vmap(frobenius)(region_polys)
    return np.mean(per_example)

"""
base stochastic laplacian masked regularizer
regularize only input and output logits
"""
def m_regularizer_masked(inputs, logits_ij, targets):
    N_b = inputs.shape[0]
    xijdiff = fastdiff(inputs,inputs)
    
    #mask = np.array(onp.fromfunction(lambda i, j: (targets[i] == targets[j]).all(), (N_b, N_b), dtype=int))
    mask = np.dot(targets, targets.T)
    xijdiff = fastdiff(inputs,inputs) + np.eye(N_b)
    fijdiff = fastdiff(logits_ij,logits_ij)
    fijdiff = np.multiply(fijdiff, mask)
    quot = np.divide(fijdiff, xijdiff)

    reg = np.sum(quot)
    #reg = np.divide(reg,np.square(N_b))
    reg = np.divide(reg,np.square(np.sum(mask)))
    return reg

"""
base stochastic laplacian regularizer
regularize only input and output logits
"""
def l_regularizer_normalized(inputs, logits_ij, targets,sigma=100):
    N_b = inputs.shape[0]
    xijdiff = fastdiff(inputs,inputs)
    sigma = 100.0
    xijdiff = np.exp(-1*np.divide(xijdiff,2*np.square(sigma)))
    # normalize logits by degree
    logits_ij = np.divide(logits_ij, np.repeat(np.expand_dims(np.sum(xijdiff, axis=-1),axis=1), 
                                               logits_ij.shape[-1], axis=1))
    fijdiff = fastdiff(logits_ij,logits_ij)
    
    nn = np.multiply(fijdiff, xijdiff)
    reg = np.divide(np.sum(nn),np.square(N_b))
        
    return reg

"""
base stochastic laplacian regularizer
regularize only input and output logits
"""
def l_regularizer(inputs, logits_ij, targets,sigma=100):
    N_b = inputs.shape[0]
    xijdiff = fastdiff(inputs,inputs)
    sigma = 100.0
    xijdiff = np.exp(-np.divide(xijdiff,2*np.square(sigma)))
    fijdiff = fastdiff(logits_ij,logits_ij)
    nn = np.multiply(fijdiff, xijdiff)
    reg = np.divide(np.sum(nn),np.square(N_b))

    return reg

"""
base stochastic manifold regularizer
regularize only input and output logits
"""
def m_regularizer(inputs, logits_ij, targets):
    N_b = inputs.shape[0]

    xijdiff = fastdiff(inputs,inputs) + np.eye(N_b)
    fijdiff = fastdiff(logits_ij,logits_ij)
    quot = np.divide(fijdiff, xijdiff)
    reg = np.sum(quot)
    reg = np.divide(reg,np.square(N_b))
    return reg

"""
l2-norm
regularization
"""
def k_regularizer(params):
    return optimizers.l2_norm(params)

def l1_regularizer(params):
    """L1-norm regularization: return ``utils.l1_norm(params)``."""
    penalty = utils.l1_norm(params)
    return penalty

def get_min_distances(dists, k):
    """Return the ``k`` smallest entries along the last axis, ascending.

    dists: bs x #neurons - tensor of distances to hyperplanes
    k: int - how many hyperplanes to take
    """
    # top_k picks the largest values, so negate going in and coming out.
    largest_negated, _ = lax.top_k(-dists, k)
    return -largest_negated


def zero_out_non_min_distances(dist, n_boundaries):
    """Build a 0/1 mask selecting the ``n_boundaries`` smallest distances per row.

    A small random jitter (one value per unit, drawn with the module-level
    ``rng`` and shared by all rows) is added first to break ties, so that no
    more than k zeros are counted in CNNs.

    Args:
        dist: bs x n_units array of distances.
        n_boundaries: number of smallest distances to keep per row.

    Returns:
        float32 array shaped like ``dist``: 1 where the jittered distance is
        <= the largest of the row's ``n_boundaries`` smallest, else 0.
    """
    n_units = dist.shape[1]
    jitter = 10**-5 * (random.normal(rng,(n_units,)) - 0.5)
    dist = dist + jitter
    k_smallest = get_min_distances(dist, n_boundaries)  # bs x n_boundaries

    # Per-row threshold: the maximum over the k smallest distances.
    cutoff = np.expand_dims(np.max(k_smallest, axis=1), 1)  # bs x 1
    cutoff = np.tile(cutoff, [1, n_units])
    return np.less_equal(dist, cutoff).astype(np.float32)


def calc_v_fc(V_prev, W):
    """Propagate the V matrix through a fully-connected layer.

    V_prev: bs x d x n_prev - previous V matrix
    W: n_prev x n_next - last weight matrix

    Returns the bs x d x n_next array obtained by right-multiplying every
    d x n_prev slice of V_prev by W.
    """
    d_in = V_prev.shape[1]
    fan_in, fan_out = W.shape

    flat = np.reshape(V_prev, [-1, fan_in])  # bs*d x n_prev
    projected = flat @ W  # bs*d x n_next
    return np.reshape(projected, [-1, d_in, fan_out])  # bs x d x n_next


def calc_v_conv(V_prev, w, stride, padding):
    """Propagate the V tensor through a convolutional layer.

    V_prev: bs x d x h_prev x w_prev x c_prev  - previous V matrix
    w: h_filter x w_filter x c_prev x c_next - next conv filter
    stride: stride used along both spatial axes
    padding: padding argument forwarded to lax.conv_general_dilated

    Returns a bs x d x h_next x w_next x c_next array.
    """
    d, h_prev, w_prev, c_prev = [int(V_prev.shape[axis]) for axis in (1, 2, 3, 4)]

    # Fold the d axis into the batch so one NHWC convolution handles all slices.
    folded = np.reshape(V_prev, [-1, h_prev, w_prev, c_prev])  # bs*d x h_prev x w_prev x c_prev
    convolved = lax.conv_general_dilated(
        folded,
        w,
        window_strides=[stride, stride],
        padding=padding,
        dimension_numbers=('NHWC', 'HWIO', 'NHWC'),
    )  # bs*d x h_next x w_next x c_next
    _, h_next, w_next, c_next = convolved.shape
    return np.reshape(convolved, [-1, d, h_next, w_next, c_next])  # bs x d x h_next x w_next x c_next

def mmr_cnn(params, x, y_true, n_h=200, n_rb=0.1, gamma_rb=0.2, gamma_db=0.2, bs=128, q=1):
    """
    Batch-wise implementation of the Maximum Margin Regularizer for CNNs, written with jax.numpy / jax.lax.

    The pre-activation feature maps are obtained from net_walk(params, x); the logits are appended to that
    list, so z_list[-1] are the logits.

    params: network parameters; the first element of every 2-element entry is taken as a weight tensor
    x: input points (bs x image_height x image_width x image_n_channels)
    y_true: one-hot encoded ground truth labels (bs x n_classes)
    n_h: used only to derive the number of closest region boundaries (0.08 * n_h)
    n_rb: ignored - overwritten with 0.08 * n_h inside the function
    gamma_rb: gamma for region boundaries (approx. corresponds to the radius of the Lp-ball that we want to
              certify robustness in)
    gamma_db: gamma for decision boundaries (approx. corresponds to the radius of the Lp-ball that we want to
              be robust in)
    bs: ignored - overwritten with x.shape[0] inside the function
    q: q-norm which is the dual norm to the p-norm that we aim to be robust at (e.g. if p=np.inf, q=1)

    The number of closest decision boundaries (n_db) is fixed to 10, and the conv strides / padding are fixed
    to [2, 2] / 'SAME' rather than read from a model object.

    Returns (rb_term, db_term): the region-boundary and decision-boundary hinge terms, each summed over axis 1.
    """

    eps_num_stabil = 1e-5  # epsilon for numerical stability in the denominator of the distances
    W = [p[0] for p in params if len(p) == 2]
    n_rb = 0.08 * n_h 
    n_db = 10
    bs = x.shape[0]
    logits, z_list = net_walk(params,x)
    
    # NOTE: this appends to the list object returned by net_walk
    z_list.append(logits)

    # the padding and strides should be the same as in the forward pass conv
    #strides, padding = model.strides, model.padding
    
    # NOTE(review): only two stride entries are provided, while the conv loop below indexes strides[i]
    strides = [2,2]
    padding = 'SAME'
    
    y_pred = z_list[-1]
    z_conv, z_fc, relus_conv, relus_fc, W_conv, W_fc = [], [], [], [], [], []
    for w, y in zip(W, z_list):  # Depending on the shape we form pre-activation values and their relu switches
        if len(y.shape) == 4:  # if conv layer
            z_conv.append(y)
            relu = np.greater(y, 0).astype(np.float32)
            relus_conv.append(np.expand_dims(relu, 1))
            W_conv.append(w)
        else:
            z_fc.append(y)
            relu = np.greater(y, 0).astype(np.float32)
            relus_fc.append(np.expand_dims(relu, 1))
            W_fc.append(w)
    h_in, w_in, c_in = int(x.shape[1]), int(x.shape[2]), int(x.shape[3])
    n_in = h_in * w_in * c_in
    n_out = y_true.shape[1]

    # z[0]: bs x h_next x w_next x n_next,  W[0]: h_filter x w_filter x n_prev x n_next
    w_matrix = np.reshape(W_conv[0], [-1, int(W_conv[0].shape[-1])])  # h_filter*w_filter*n_col x n_next
    denom = np.linalg.norm(w_matrix, axis=0, ord=q, keepdims=True)  # n_next
    dist_rb = np.abs(z_conv[0]) / denom  # bs x h_next x w_next x n_next
    dist_rb = np.reshape(dist_rb, [bs, int(z_conv[0].shape[1]*z_conv[0].shape[2]*z_conv[0].shape[3])])  # bs x h_next*w_next*n_next

    # We need to get the conv matrix. Instead of using loops to construct such a matrix, we can apply the W[0]
    # conv filter to a reshaped identity matrix. Then we duplicate the resulting tensor bs times.
    identity_input_fm = np.reshape(np.eye(n_in, n_in), [1, n_in, h_in, w_in, c_in])
    #print(identity_input_fm.shape, W_conv[0].shape)
    V = calc_v_conv(identity_input_fm, W_conv[0], strides[0], padding)  # 1 x d x h_next x w_next x c_next
    V = np.tile(V, [bs, 1, 1, 1, 1])  # bs x d x h_next x w_next x c_next
    V = V * relus_conv[0]
    for i in range(1, len(z_conv)):
        V = calc_v_conv(V, W_conv[i], strides[i], padding)  # bs x d x h_next x w_next x c_next 
        V_stable = V + eps_num_stabil * np.less(np.abs(V), eps_num_stabil).astype(np.float32)  # note: +eps would also work
        new_dist_rb = np.abs(z_conv[i]) / np.linalg.norm(V_stable, axis=1, ord=q)  # bs x h_next x w_next x c_next
        new_dist_rb = np.reshape(new_dist_rb, [bs, z_conv[i].shape[1]*z_conv[i].shape[2]*z_conv[i].shape[3]])  # bs x h_next*w_next*c_next
        dist_rb = np.concatenate([dist_rb, new_dist_rb], 1)  # bs x sum(n_neurons[1:i])
        V = V * relus_conv[i]  # element-wise mult using broadcasting, result: bs x d x h_cur x w_cur x c_cur

    # Flattening after the last conv layer
    V = np.reshape(V, [bs, n_in, V.shape[2] * V.shape[3] * V.shape[4]])  # bs x d x h_prev*w_prev*c_prev

    for i in range(len(z_fc) - 1):  # the last layer requires special handling
        V = calc_v_fc(V, W_fc[i])  # bs x d x n_hs[i]
        V_stable = V + eps_num_stabil * np.less(np.abs(V), eps_num_stabil).astype(np.float32)
        new_dist_rb = np.abs(z_fc[i]) / np.linalg.norm(V_stable, axis=1, ord=q)  # bs x n_hs[i]
        dist_rb = np.concatenate([dist_rb, new_dist_rb], 1)  # bs x sum(n_hs[1:i])
        V = V * relus_fc[i]  # element-wise mult using broadcasting, result: bs x d x n_cur

    # hinge loss on the n_rb closest region boundaries only
    th = zero_out_non_min_distances(dist_rb, n_rb)
    rb_term = np.sum(th * np.maximum(0.0, 1.0 - dist_rb / gamma_rb), axis=1)
    #rb_term = -np.sum(np.log(np.maximum(1e-8, dist_rb)),axis=1)

    
    # decision boundaries
    V = calc_v_fc(V, W_fc[-1])
    #y_true_diag = np.diag(y_true)
    N = y_true.shape[1]
    a = np.expand_dims(y_true, axis=1)
    y_true_diag = a*np.eye(N)  # batched diagonal matrices built from the one-hot labels: bs x K x K
    
    LLK2 = V @ y_true_diag  # bs x d x K  @  bs x K x K  =  bs x d x K
    l = np.sum(LLK2, axis=2)  # bs x d
    l = np.tile(l, [1, n_out])  # bs x d x K
    l = np.reshape(l, [-1, n_out, n_in])  # bs x K x d
    V_argmax = np.transpose(l, [0, 2, 1])  # bs x d x K
    diff_v = np.abs(V - V_argmax)
    diff_v = diff_v + eps_num_stabil * np.less(diff_v, eps_num_stabil).astype(np.float32)
    dist_db_denominator = np.linalg.norm(diff_v, axis=1, ord=q)

    y_pred_diag = np.expand_dims(y_pred, 1)
    y_pred_correct = y_pred_diag @ y_true_diag  # bs x 1 x K  @  bs x K x K  =  bs x 1 x K
    y_pred_correct = np.sum(y_pred_correct, axis=2)  # bs x 1
    y_pred_correct = np.tile(y_pred_correct, [1, n_out])  # bs x 1 x K
    y_pred_correct = np.reshape(y_pred_correct, [-1, n_out, 1])  # bs x K x 1
    y_pred_correct = np.transpose(y_pred_correct, [0, 2, 1])  # bs x 1 x K
    dist_db_numerator = np.squeeze(y_pred_correct - y_pred_diag, 1)  # bs x K
    # the +100 on the true-class column keeps that entry away from the smallest distances
    dist_db_numerator = dist_db_numerator + 100.0 * y_true  # bs x K

    dist_db = dist_db_numerator / dist_db_denominator + y_true * 2.0 * gamma_db

    # hinge loss on the n_db closest decision boundaries only
    th = zero_out_non_min_distances(dist_db, n_db)
    db_term = np.sum(th * np.maximum(0.0, 1.0 - dist_db / gamma_db), axis=1)
    
    #db_term = 0.0
    return rb_term, db_term

def mmr_fc(params, inputs, y_true, n_rb=0.1, n_hs=[1024], n_in=784, n_out=10, gamma_rb=0.1, gamma_db=0.1, bs=128, q=1):
    """
    Batch-wise implementation of the Maximum Margin Regularizer for fully-connected networks.

    The pre-activation feature maps are obtained from net_walk(params, inputs); the logits are appended to
    that list, so z_list[-1] are the logits.

    params: network parameters; the first element of every 2-element entry is taken as a weight matrix
    inputs: batch of input points
    y_true: one-hot encoded ground truth labels (bs x n_classes)
    n_rb: ignored - overwritten with 0.08 * sum(n_hs) inside the function
    n_hs: list of number of hidden units for every hidden layer (e.g. [1024] for FC1)
    n_in: total number of input pixels (e.g. 784 for MNIST)
    n_out: number of classes; also used as the number of closest decision boundaries to take
    gamma_rb: gamma for region boundaries (approx. corresponds to the radius of the Lp-ball that we want to
              certify robustness in)
    gamma_db: gamma for decision boundaries (approx. corresponds to the radius of the Lp-ball that we want to
              be robust in)
    bs: ignored - overwritten with inputs.shape[0] inside the function
    q: q-norm which is the dual norm to the p-norm that we aim to be robust at (e.g. if p=np.inf, q=1)

    Returns (rb_term, db_term): the region-boundary and decision-boundary hinge terms, each summed over axis 1.
    """
    eps_num_stabil = 1e-5  # epsilon for numerical stability in the denominator of the distances
    W = [x[0] for x in params if len(x) == 2]
    n_rb = 0.08 * np.sum(n_hs)
    n_db = n_out
    bs = inputs.shape[0]
    logits, z_list = net_walk(params,inputs)

    # NOTE: this appends to the list object returned by net_walk
    z_list.append(logits)

    n_hl = len(W) - 1  # number of hidden layers
    y_pred = z_list[-1]

    relus = []
    for i in range(n_hl):
        relu = np.greater(z_list[i], 0).astype(np.float32)
        relus.append(np.expand_dims(relu, 1))

    dist_rb = np.abs(z_list[0]) / np.linalg.norm(W[0], axis=0, ord=q)  # bs x n_hs[0]  due to broadcasting
    V = np.reshape(np.tile(W[0], [bs, 1]), [bs, n_in, n_hs[0]])  # bs x d x n1
    V = V * relus[0]  # element-wise mult using broadcasting, result: bs x d x n_cur
    for i in range(1, n_hl):
        V = calc_v_fc(V, W[i])  # bs x d x n_hs[i]
        new_dist_rb = np.abs(z_list[i]) / np.linalg.norm(V, axis=1, ord=q)  # bs x n_hs[i]
        # np.concatenate (not np.concat) is the call mmr_cnn uses and the one
        # available in every jax.numpy release.
        dist_rb = np.concatenate([dist_rb, new_dist_rb], 1)  # bs x sum(n_hs[1:i])
        V = V * relus[i]  # element-wise mult using broadcasting, result: bs x d x n_cur

    # hinge loss on the n_rb closest region boundaries only
    th = zero_out_non_min_distances(dist_rb, n_rb)
    rb_term = np.sum(th * np.maximum(0.0, 1.0 - dist_rb / gamma_rb), axis=1)

    # decision boundaries
    V_last = calc_v_fc(V, W[-1])
    N = y_true.shape[1]
    a = np.expand_dims(y_true, axis=1)
    y_true_diag = a*np.eye(N)  # batched diagonal matrices built from the one-hot labels: bs x K x K

    LLK2 = V_last @ y_true_diag  # bs x d x K  @  bs x K x K  =  bs x d x K
    l = np.sum(LLK2, axis=2)  # bs x d
    l = np.tile(l, [1, n_out])  # bs x d x K
    l = np.reshape(l, [-1, n_out, n_in])  # bs x K x d
    V_argmax = np.transpose(l, [0, 2, 1])  # bs x d x K
    diff_v = np.abs(V_last - V_argmax)
    diff_v = diff_v + eps_num_stabil * np.less(diff_v, eps_num_stabil).astype(np.float32)
    dist_db_denominator = np.linalg.norm(diff_v, axis=1, ord=q)

    y_pred_diag = np.expand_dims(y_pred, 1)
    y_pred_correct = y_pred_diag @ y_true_diag  # bs x 1 x K  @  bs x K x K  =  bs x 1 x K
    y_pred_correct = np.sum(y_pred_correct, axis=2)  # bs x 1
    y_pred_correct = np.tile(y_pred_correct, [1, n_out])  # bs x 1 x K
    y_pred_correct = np.reshape(y_pred_correct, [-1, n_out, 1])  # bs x K x 1
    y_pred_correct = np.transpose(y_pred_correct, [0, 2, 1])  # bs x 1 x K
    dist_db_numerator = np.squeeze(y_pred_correct - y_pred_diag, 1)  # bs x K
    # the +100 on the true-class column keeps that entry away from the smallest distances
    dist_db_numerator = dist_db_numerator + 100.0 * y_true  # bs x K

    dist_db = dist_db_numerator / dist_db_denominator + y_true * 2.0 * gamma_db

    # hinge loss on the n_db closest decision boundaries only
    th = zero_out_non_min_distances(dist_db, n_db)
    db_term = np.sum(th * np.maximum(0.0, 1.0 - dist_db / gamma_db), axis=1)

    return rb_term, db_term

def acr_cnn(params, x, y_true, n_h=200, n_rb=0.1, gamma_rb=0.1, gamma_db=0.1, bs=128, q=1,
            strides=(2, 2), padding='SAME'):
    """
    Batch-wise region-boundary regularizer for CNNs with ReLU switches, written with jax.numpy.
    For every hidden pre-activation z it accumulates |z| / ||V||_q, where V is the input-space
    linear map propagated through the layers seen so far, and returns the negative sum of the
    (clipped) log-distances per example.

    params: list of per-layer parameter tuples; every entry of length 2 contributes its first
            item as that layer's weight (presumably (W, b) pairs - confirm against net_walk)
    x: input points (bs x image_height x image_width x image_n_channels)
    y_true: accepted for signature parity with the other regularizers; not read here
    n_h, n_rb, gamma_rb, gamma_db: accepted for interface compatibility; the returned terms do
            not depend on them (the sum runs over all region boundaries)
    bs: ignored; the batch size is always taken from x.shape[0]
    q: q-norm which is the dual norm to the p-norm that we aim to be robust at (e.g. if p=np.inf, q=1)
    strides: one stride per conv layer, in layer order; must match the forward pass conv.
             Defaults to (2, 2), the values that used to be hard-coded.
    padding: padding mode handed to calc_v_conv; must match the forward pass conv

    Returns (rb_term, db_term): rb_term holds one value per example, db_term is the constant 0.0.
    Raises ValueError if there are more conv layers than entries in `strides`.
    """

    eps_num_stabil = 1e-5  # epsilon for numerical stability in the denominator of the distances
    W = [p[0] for p in params if len(p) == 2]
    bs = x.shape[0]
    logits, z_list = net_walk(params, x)

    z_list.append(logits)

    z_conv, z_fc, relus_conv, relus_fc, W_conv, W_fc = [], [], [], [], [], []
    for w, z in zip(W, z_list):  # Depending on the shape we form pre-activation values and their relu switches
        relu = np.expand_dims(np.greater(z, 0).astype(np.float32), 1)
        if len(z.shape) == 4:  # if conv layer
            z_conv.append(z)
            relus_conv.append(relu)
            W_conv.append(w)
        else:
            z_fc.append(z)
            relus_fc.append(relu)
            W_fc.append(w)

    # strides[i] is read once per conv layer below; fail early with a clear message
    # instead of an IndexError in the middle of the propagation.
    if len(z_conv) > len(strides):
        raise ValueError(
            "acr_cnn needs one stride per conv layer: got {} conv layers but {} strides".format(
                len(z_conv), len(strides)))

    h_in, w_in, c_in = int(x.shape[1]), int(x.shape[2]), int(x.shape[3])
    n_in = h_in * w_in * c_in

    # z[0]: bs x h_next x w_next x n_next,  W[0]: h_filter x w_filter x n_prev x n_next
    w_matrix = np.reshape(W_conv[0], [-1, int(W_conv[0].shape[-1])])  # h_filter*w_filter*n_col x n_next
    denom = np.linalg.norm(w_matrix, axis=0, ord=q, keepdims=True)  # n_next
    dist_rb = np.abs(z_conv[0]) / denom  # bs x h_next x w_next x n_next
    dist_rb = np.reshape(dist_rb, [bs, int(z_conv[0].shape[1]*z_conv[0].shape[2]*z_conv[0].shape[3])])  # bs x h_next*w_next*n_next

    # We need to get the conv matrix. Instead of using loops to construct such a matrix, we can apply the
    # W[0] conv filter to a reshaped identity matrix. Then we duplicate the resulting tensor bs times.
    identity_input_fm = np.reshape(np.eye(n_in, n_in), [1, n_in, h_in, w_in, c_in])
    V = calc_v_conv(identity_input_fm, W_conv[0], strides[0], padding)  # 1 x d x h_next x w_next x c_next
    V = np.tile(V, [bs, 1, 1, 1, 1])  # bs x d x h_next x w_next x c_next
    V = V * relus_conv[0]
    for i in range(1, len(z_conv)):
        V = calc_v_conv(V, W_conv[i], strides[i], padding)  # bs x d x h_next x w_next x c_next
        V_stable = V + eps_num_stabil * np.less(np.abs(V), eps_num_stabil).astype(np.float32)  # note: +eps would also work
        new_dist_rb = np.abs(z_conv[i]) / np.linalg.norm(V_stable, axis=1, ord=q)  # bs x h_next x w_next x c_next
        new_dist_rb = np.reshape(new_dist_rb, [bs, z_conv[i].shape[1]*z_conv[i].shape[2]*z_conv[i].shape[3]])  # bs x h_next*w_next*c_next
        dist_rb = np.concatenate([dist_rb, new_dist_rb], 1)  # bs x sum(n_neurons[1:i])
        V = V * relus_conv[i]  # element-wise mult using broadcasting, result: bs x d x h_cur x w_cur x c_cur

    # Flattening after the last conv layer
    V = np.reshape(V, [bs, n_in, V.shape[2] * V.shape[3] * V.shape[4]])  # bs x d x h_prev*w_prev*c_prev

    for i in range(len(z_fc) - 1):  # the last entry of z_fc is skipped (the logits, appended above)
        V = calc_v_fc(V, W_fc[i])  # bs x d x n_hs[i]
        V_stable = V + eps_num_stabil * np.less(np.abs(V), eps_num_stabil).astype(np.float32)
        new_dist_rb = np.abs(z_fc[i]) / np.linalg.norm(V_stable, axis=1, ord=q)  # bs x n_hs[i]
        dist_rb = np.concatenate([dist_rb, new_dist_rb], 1)  # bs x sum(n_hs[1:i])
        V = V * relus_fc[i]  # element-wise mult using broadcasting, result: bs x d x n_cur

    # Distances are clipped at 1e-8 so the log stays finite; every boundary contributes.
    rb_term = -np.sum(np.log(np.maximum(1e-8, dist_rb)), axis=1)

    db_term = 0.0  # no decision-boundary term is computed in this variant
    return rb_term, db_term

def acr_fc(params, inputs, y_true, n_rb, n_hs=[64], n_in=2, n_out=10, gamma_rb=0.1, gamma_db=0.1, bs=128, q=1):
    """
    Batch-wise region-boundary regularizer for fully-connected ReLU networks, written with jax.numpy.
    For every hidden pre-activation z it accumulates |z| / ||V||_q, where V is the input-space
    linear map propagated through the layers seen so far, and returns the negative sum of the
    (clipped) log-distances per example.

    params: list of per-layer parameter tuples; every entry of length 2 contributes its first
            item as that layer's weight (presumably (W, b) pairs - confirm against net_walk)
    inputs: batch of input points; its leading dimension is used as the batch size
    y_true: accepted for signature parity with the other regularizers; not read here
    n_rb: accepted for interface compatibility; not read, the sum runs over all region boundaries
    n_hs: list of number of hidden units for every hidden layer (only n_hs[0] is read here)
    n_in: total number of input dimensions (e.g. 784 for MNIST); must match the first weight matrix
    n_out, gamma_rb, gamma_db: accepted for interface compatibility; not read here
    bs: ignored; the batch size is always taken from inputs.shape[0]
    q: q-norm which is the dual norm to the p-norm that we aim to be robust at (e.g. if p=np.inf, q=1)

    Returns (rb_term, db_term): rb_term holds one value per example, db_term is the constant 0.0.
    """
    W = [layer[0] for layer in params if len(layer) == 2]
    bs = inputs.shape[0]
    logits, z_list = net_walk(params, inputs)

    z_list.append(logits)  # keep the pre-activation list laid out as [hidden..., logits]

    n_hl = len(W) - 1  # number of hidden layers

    # ReLU on/off switches per hidden layer, with an extra axis so they broadcast against V
    relus = [np.expand_dims(np.greater(z_list[i], 0).astype(np.float32), 1) for i in range(n_hl)]

    dist_rb = np.abs(z_list[0]) / np.linalg.norm(W[0], axis=0, ord=q)  # bs x n_hs[0]  due to broadcasting
    V = np.reshape(np.tile(W[0], [bs, 1]), [bs, n_in, n_hs[0]])  # bs x d x n1
    V = V * relus[0]  # element-wise mult using broadcasting, result: bs x d x n_cur
    for i in range(1, n_hl):
        V = calc_v_fc(V, W[i])  # bs x d x n_hs[i]
        new_dist_rb = np.abs(z_list[i]) / np.linalg.norm(V, axis=1, ord=q)  # bs x n_hs[i]
        # np.concatenate (not np.concat) matches acr_cnn and exists in every jax.numpy release.
        dist_rb = np.concatenate([dist_rb, new_dist_rb], 1)  # bs x sum(n_hs[1:i])
        V = V * relus[i]  # element-wise mult using broadcasting, result: bs x d x n_cur

    # Distances are clipped at 1e-8 so the log stays finite; every boundary contributes.
    rb_term = -np.sum(np.log(np.maximum(1e-8, dist_rb)), axis=1)

    db_term = 0.0  # no decision-boundary term is computed in this variant
    return rb_term, db_term

def acr_fc_min(params, inputs, y_true, n_rb, n_hs=[64], n_in=2, n_out=10, gamma_rb=0.1, gamma_db=0.1, bs=128, q=1):
    """
    Batch-wise implementation of the ACR Regularizer for fully-connected networks that only
    penalizes the entries selected by zero_out_non_min_distances.
    The first hidden layer contributes |z| / ||W[0]||_q per unit; every later hidden layer
    contributes the raw magnitude |z| of its pre-activations (no normalization by the
    propagated linear map). The result is the negative masked sum of the (clipped) logs.

    params: list of per-layer parameter tuples; every entry of length 2 contributes its first
            item as that layer's weight (presumably (W, b) pairs - confirm against net_walk)
    inputs: batch of input points, forwarded to net_walk
    y_true: accepted for signature parity with the other regularizers; not read here
    n_rb: NOTE the passed value is ignored; it is always replaced by 0.08 * sum(n_hs)
          before being handed to zero_out_non_min_distances
    n_hs: list of number of hidden units for every hidden layer (e.g. [1024] for FC1)
    n_in, n_out, gamma_rb, gamma_db, bs: accepted for interface compatibility; not read here
    q: q-norm which is the dual norm to the p-norm that we aim to be robust at (e.g. if p=np.inf, q=1)

    Returns (rb_term, db_term): rb_term holds one value per example, db_term is the constant 0.0.
    """
    W = [layer[0] for layer in params if len(layer) == 2]
    n_rb = 0.08 * np.sum(n_hs)
    logits, z_list = net_walk(params, inputs)

    z_list.append(logits)  # keep the pre-activation list laid out as [hidden..., logits]

    n_hl = len(W) - 1  # number of hidden layers

    dist_rb = np.abs(z_list[0]) / np.linalg.norm(W[0], axis=0, ord=q)  # bs x n_hs[0]  due to broadcasting
    for i in range(1, n_hl):
        # Later layers use the unnormalized |z|, so the propagated map V (and the ReLU
        # switches feeding it) is not needed and is no longer computed.
        # np.concatenate (not np.concat) matches acr_cnn and exists in every jax.numpy release.
        dist_rb = np.concatenate([dist_rb, np.abs(z_list[i])], 1)  # bs x sum(n_hs[1:i])

    th = zero_out_non_min_distances(dist_rb, n_rb)  # mask applied to the log terms below
    # Distances are clipped at 1e-8 so the log stays finite.
    rb_term = -np.sum(th * np.log(np.maximum(1e-8, dist_rb)), axis=1)

    db_term = 0.0  # no decision-boundary term is computed in this variant
    return rb_term, db_term