# cvxnn_utils.py

import numpy as np
import cvxpy as cp

def relu(x):
    """Rectified linear unit, applied elementwise: max(0, x).

    Accepts anything np.maximum accepts (scalars or arrays) and returns
    the NumPy result.
    """
    floor = 0
    return np.maximum(floor, x)
def drelu(x):
    """Indicator of the ReLU's active region: True where x >= 0.

    Zero counts as active. cvx_nn_gen_mask relies on this convention to
    turn X @ u into 0/1 sign patterns.
    """
    return 0 <= x
def cvx_nn_gen_mask(X, sample_num = 200, seed = 0):
    """Sample a finite set of ReLU sign patterns (hyperplane arrangements) of X.

    Draws ``sample_num`` standard-normal directions u and records
    ``drelu(X @ u)``, i.e. which rows of X satisfy x @ u >= 0. Duplicate
    patterns are removed. This gives a finite approximation of all
    achievable sign patterns.

    Args:
        X: 2-D array of shape (N, d).
        sample_num: number of random directions to draw.
        seed: seed for the private random generator.

    Returns:
        Boolean array of shape (p, N), with p <= sample_num. Each row is one
        unique pattern. Rows are in np.unique's sorted order.

    The generator is a local ``np.random.RandomState(seed)``. It yields
    exactly the same draws as ``np.random.seed(seed)`` followed by
    ``np.random.randn``, but leaves the global NumPy RNG state untouched.
    """
    N, d = X.shape

    rng = np.random.RandomState(seed)
    directions = rng.randn(d, sample_num)  # one random direction per column
    # Transpose so each row is the pattern of one direction across all N samples.
    patterns = drelu(X @ directions).T
    dmat = np.unique(patterns, axis=0)

    return dmat

def cvx_nn_max_margin(X,y,dmat,verbose=True):
    """Solve the convex max-margin program over fixed ReLU sign patterns with cvxpy.

    For each pattern row D_i = diag(dmat[i]) there are two weight vectors,
    w_p[i] and w_m[i]. The program is

        minimize    sum_i ||w_p[i]||_2 + ||w_m[i]||_2
        subject to  (2 D_i - I) X w_p[i] >= 0,  (2 D_i - I) X w_m[i] >= 0
                    y * sum_i D_i X (w_p[i] - w_m[i]) >= 1   (elementwise)

    The first pair of constraints keeps each weight vector consistent with
    its assigned pattern (rows where dmat[i] is 1 have x @ w >= 0, the
    others x @ w <= 0).

    Args:
        X: 2-D array of shape (N, d).
        y: length-N label vector. Presumably in {-1, +1} (TODO: confirm with
            callers); the margin constraint multiplies it elementwise.
        dmat: 2-D 0/1 (or boolean) array of shape (p, N), one sign pattern
            per row, e.g. the output of cvx_nn_gen_mask.
        verbose: forwarded to cvxpy's ``Problem.solve``.

    Returns:
        ``(p_star, (w_p, w_m))``. ``p_star`` is the value returned by
        ``Problem.solve``. ``w_p`` and ``w_m`` are the (p, d) variable values
        and may be None if cvxpy did not find a solution.

    Raises:
        ValueError: if dmat is not 2-D with N columns.
    """
    N, d = X.shape
    if dmat.ndim != 2 or dmat.shape[1] != N:
        raise ValueError(
            f"dmat must have shape (p, {N}) to match X; got {dmat.shape}"
        )
    p = dmat.shape[0]
    w_p = cp.Variable((p, d))
    w_m = cp.Variable((p, d))

    # Constant data matrices, built once per pattern with plain NumPy:
    #   masked[i] = D_i X          (rows of X zeroed where dmat[i] is 0)
    #   signed[i] = (2 D_i - I) X  (rows of X negated where dmat[i] is 0)
    masked = [X * row[:, None] for row in dmat]
    signed = [X * (2 * row - 1)[:, None] for row in dmat]

    # Network output on every sample: sum_i D_i X (w_p[i] - w_m[i]).
    output = 0
    for i, masked_X in enumerate(masked):
        output += masked_X @ (w_p[i, :] - w_m[i, :])

    # Group-lasso penalty: Euclidean norm of every row of w_p and w_m.
    regularization = cp.sum(cp.norm(w_p, axis=1) + cp.norm(w_m, axis=1))

    constraints = []
    for i, signed_X in enumerate(signed):
        # Keep both weight vectors inside the cone of pattern i.
        constraints.append(signed_X @ w_p[i, :] >= 0)
        constraints.append(signed_X @ w_m[i, :] >= 0)
    # Unit functional margin on every sample.
    constraints.append(cp.multiply(y, output) >= 1)

    p_star = cp.Problem(cp.Minimize(regularization), constraints).solve(verbose=verbose)
    return p_star, (w_p.value, w_m.value)

def cvx_nn_max_margin_dual(X,y,dmat,verbose=True):
    """Solve, with cvxpy, the dual-form program paired with cvx_nn_max_margin.

    Variables: ``lbd`` (length N), and ``z_p`` and ``z_m`` (shape (p, N),
    where p = dmat.shape[0]). The program is

        maximize    sum(y * lbd)
        subject to  z_p >= 0, z_m >= 0, y * lbd >= 0
                    || X^T (d_i * lbd) - X^T ((2 d_i - 1) * z_p[i]) ||_2 <= 1
                    ||-X^T (d_i * lbd) - X^T ((2 d_i - 1) * z_m[i]) ||_2 <= 1
        for every pattern row d_i = dmat[i].

    Presumably y is in {-1, +1} (TODO: confirm), so y * lbd >= 0 makes the
    margin multipliers non-negative.

    NOTE(review): under the usual sign convention (L = f - z^T g for g >= 0),
    the multiplier of w_p's pattern constraint appears as
    ``+ X^T ((2 d_i - 1) * z)``. Here z enters with a minus sign, which is
    the form of w_m's constraint. The feasible set in ``lbd`` is the same,
    but the returned z_p and z_m may be swapped relative to w_p and w_m.
    Verify before interpreting them individually.

    Returns:
        ``(p_star, (lbd, z_p, z_m))``. ``p_star`` is the value returned by
        ``Problem.solve``; the rest are the variable values.
    """
    num_samples, _ = X.shape
    num_patterns = dmat.shape[0]
    z_p = cp.Variable((num_patterns, num_samples))
    z_m = cp.Variable((num_patterns, num_samples))
    lbd = cp.Variable(num_samples)

    constraints = [z_p >= 0, z_m >= 0, cp.multiply(y, lbd) >= 0]
    for i, pattern in enumerate(dmat):
        sign = 2 * pattern - 1
        constraints += [
            cp.norm(X.T @ cp.multiply(pattern, lbd) - X.T @ cp.multiply(sign, z_p[i, :])) <= 1,
            cp.norm(-X.T @ cp.multiply(pattern, lbd) - X.T @ cp.multiply(sign, z_m[i, :])) <= 1,
        ]

    objective = cp.Maximize(cp.sum(cp.multiply(y, lbd)))
    p_star = cp.Problem(objective, constraints).solve(verbose=verbose)

    return p_star, (lbd.value, z_p.value, z_m.value)





