def our_reg_loss(p, W, b, x, y, mu, sigma, t, lambda_reg=0.01):
    """Characteristic-function regulariser comparing model outputs with a normal law.

    The softmax outputs of ``x @ W + b`` (softmax over the last axis) are
    treated, per sample, as success probabilities ``p_i`` of independent
    Bernoulli variables ``B_i``. For each sample, the characteristic function
    of ``sum_i (B_i - p_i) / s`` with ``s = sqrt(sum_i p_i (1 - p_i))`` is
    evaluated at the frequencies ``t``. It is then compared with the
    characteristic function of ``Normal(mu, sigma)``. The ``p``-norm (over
    ``t``) of the modulus of the difference is averaged over the batch.

    Args:
        p: Order of the norm taken over ``t``; ``-1`` selects the infinity
            (max) norm.
        W: Weight matrix of the linear layer feeding the softmax.
        b: Bias of the linear layer.
        x: Input batch. Samples are indexed along axis 0 of ``x @ W + b``
            (presumably ``(batch, features)`` -- TODO confirm with callers).
        y: Unused; kept for signature compatibility.
        mu: Mean of the target normal distribution.
        sigma: Standard deviation of the target normal distribution.
        t: Frequencies at which the characteristic functions are compared.
        lambda_reg: Multiplicative weight of the regulariser.

    Returns:
        ``lambda_reg`` times the batch mean of the per-sample norms.
    """

    def characteristic_function_normal(mu, sigma, t):
        # phi(t) = exp(i mu t - sigma^2 t^2 / 2)
        return jnp.exp(1j * mu * t - 0.5 * (sigma ** 2) * (t ** 2))

    def characteristic_function_bernoulli(p_values, t):
        """Characteristic function of sum_i (B_i - p_i) / s at frequencies ``t``.

        ``B_i ~ Bernoulli(p_i)`` are independent and
        ``s = sqrt(sum_i p_i (1 - p_i))`` is the standard deviation of the sum.
        """
        # Var(B_i) = p_i (1 - p_i); computed once and shared by every factor.
        scale = jnp.sqrt(jnp.sum(p_values * (1 - p_values)))

        def phi_single(prob):
            # E[exp(i t (B - prob) / scale)] for one Bernoulli(prob) variable:
            # B = 0 with weight (1 - prob), B = 1 with weight prob.
            return ((1 - prob) * jnp.exp((-1j * t * prob) / scale)
                    + prob * jnp.exp((1j * t * (1 - prob)) / scale))

        # Independence: the characteristic function of the sum is the product
        # of the per-variable characteristic functions.
        return jnp.prod(jax.vmap(phi_single)(p_values), axis=0)

    def lossfx(model_charc_fx, mu, sigma, t, p):
        norm_charc = characteristic_function_normal(mu, sigma, t)
        diff = jnp.abs(model_charc_fx - norm_charc)
        use_inf_norm = jnp.equal(p, -1)
        # jnp.where evaluates and differentiates both branches. With p == -1
        # the Lp branch would raise zero entries of diff to a negative power
        # (inf forward, NaN gradient), so give it a harmless exponent instead.
        safe_p = jnp.where(use_inf_norm, 1.0, p)
        return jnp.where(
            use_inf_norm,
            jnp.max(diff),  # infinity norm
            jnp.sum(diff ** safe_p) ** (1 / safe_p),  # Lp norm
        )

    # One probability vector per sample (softmax over the last axis).
    bernouli_ps = jax.nn.softmax(jnp.dot(x, W) + b)
    # Per-sample characteristic function of the standardised Bernoulli sum.
    model_charc_fx = jax.vmap(characteristic_function_bernoulli, in_axes=(0, None))(bernouli_ps, t)

    # Per-sample distance to the normal characteristic function, then batch mean.
    per_sample_loss = jax.vmap(lossfx, in_axes=(0, None, None, None, None))(
        model_charc_fx, mu, sigma, t, p
    )
    return lambda_reg * jnp.mean(per_sample_loss)

def reg_loss_lp(p, W, lambda_reg):
    """Return ``lambda_reg`` times the entrywise Lp norm of ``W``.

    All entries of ``W`` are summed together, so a matrix is treated as a
    flattened vector (not an operator norm).

    Args:
        p: Norm order; ``-1`` selects the infinity norm ``max |W_ij|``.
        W: Array of parameters to regularise.
        lambda_reg: Multiplicative weight of the penalty.

    Returns:
        The weighted norm as a scalar array.
    """
    use_inf_norm = jnp.equal(p, -1)
    # jnp.where differentiates both branches. With p == -1, any zero entry of W
    # would be raised to the power -1 in the unused Lp branch (inf forward,
    # NaN gradient), so that branch gets a harmless exponent instead.
    safe_p = jnp.where(use_inf_norm, 1.0, p)
    reg_loss = jnp.where(
        use_inf_norm,
        jnp.max(jnp.abs(W)),  # L-infinity norm
        jnp.sum(jnp.abs(W) ** safe_p) ** (1 / safe_p),  # Lp norm
    )
    return lambda_reg * reg_loss
