import scipy as sp
import numpy as np
import torch
# import proxop
from projop.lp_balls import *
from projop.utils import *

def prox_linfnorm(v, lambda1):
    """Proximal operator of ``lambda1 * ||v||_inf``.

    Uses the Moreau decomposition. The dual norm of l-inf is l1, so

        prox(v) = v - lambda1 * P_{B_1}(v / lambda1),

    where ``P_{B_1}`` is the projection onto the unit l1 ball. Here that
    projection is ``l1_ball_vec(., 1)``; the second argument is presumably the
    ball radius (TODO confirm against projop.lp_balls).

    Args:
        v: input vector (a torch tensor, like the rest of this module).
        lambda1: non-negative regularisation weight.

    Returns:
        A new tensor with the same shape as ``v``.
    """
    if lambda1 == 0:
        # The prox of the zero function is the identity. Return early so we
        # never compute v / 0, which gives inf (or NaN for zero entries).
        return v.clone()
    return v - lambda1 * l1_ball_vec(v / lambda1, 1)

def prox_l1norm(v, lambda1):
    """Proximal operator of ``lambda1 * ||v||_1``.

    Uses the Moreau decomposition. The dual norm of l1 is l-inf, so

        prox(v) = v - lambda1 * P_{B_inf}(v / lambda1),

    where ``P_{B_inf}`` is the projection onto the unit l-inf ball, computed by
    ``linf_ball_vec(., 1)``; the second argument is presumably the ball radius
    (TODO confirm against projop.lp_balls). Mathematically this equals
    elementwise soft-thresholding at ``lambda1``.

    Args:
        v: input vector (a torch tensor, like the rest of this module).
        lambda1: non-negative regularisation weight.

    Returns:
        A new tensor with the same shape as ``v``.
    """
    if lambda1 == 0:
        # The prox of the zero function is the identity. Return early so we
        # never compute v / 0, which gives inf (or NaN for zero entries).
        return v.clone()
    return v - lambda1 * linf_ball_vec(v / lambda1, 1)

def prox_l2norm(v, lambda1):
    """Proximal operator of ``lambda1 * ||v||_2`` (block soft-thresholding).

    Shrinks ``v`` towards the origin along its own direction:

        prox(v) = (1 - lambda1 / ||v||_2) * v   if ||v||_2 > lambda1
                = 0                            otherwise

    Args:
        v: torch tensor; the 2-norm is taken over all entries (flattened).
        lambda1: non-negative threshold.

    Returns:
        A tensor with the same shape as ``v``.
    """
    v_l2 = torch.linalg.vector_norm(v)
    # Use a strict inequality. At ||v|| == lambda1 both branches give 0, but
    # the old ``>=`` sent lambda1 == 0, v == 0 into this branch and produced
    # 0 / 0 = NaN.
    if v_l2 > lambda1:
        return v - lambda1 * v / v_l2
    return torch.zeros_like(v)

def prox_spectral(A, phi_name, lambda1):
    """Apply a proximal operator to the spectrum of the graph Laplacian of ``A``.

    Builds ``L = diag(rowsum(A)) - A`` and eigendecomposes it as
    ``L = Q diag(eigs) Q^T``. It then returns ``Q diag(prox(eigs)) Q^T``, where
    ``prox`` is ``proximal_vec(., phi_name, lambda1)`` (the branches visible
    there are 'l0', 'linf' and 'l1').

    The graph is assumed to be undirected (``A`` symmetric).
    ``torch.linalg.eigh`` only reads one triangle of its input (the lower one
    by default), so a non-symmetric ``A`` would not be rejected here.

    Args:
        A: square adjacency matrix (torch tensor).
        phi_name: name of the spectral penalty, forwarded to ``proximal_vec``.
        lambda1: regularisation weight, forwarded to ``proximal_vec``.

    Returns:
        A matrix with the same eigenvectors as ``L`` and proximally mapped
        eigenvalues.
    """
    degrees = A.sum(dim=1)
    laplacian = torch.diag(degrees) - A
    eigvals, eigvecs = torch.linalg.eigh(laplacian)
    mapped = proximal_vec(eigvals, phi_name, lambda1)
    return eigvecs @ torch.diag(mapped) @ eigvecs.T

def proximal_vec(x, phi_name, lambda1):
    """Apply the proximal operator of ``lambda1 * phi`` to the vector ``x``.

    Args:
        x: input vector (e.g. Laplacian eigenvalues from ``prox_spectral``).
        phi_name: which penalty ``phi`` to use:
            - ``'l0'``: count of non-zeros. Its prox is hard thresholding at
              ``sqrt(2 * lambda1)`` (via ``hard`` from projop.utils).
            - ``'linf'``: max-abs norm, delegated to ``prox_linfnorm``.
            - ``'l1'``: sum of absolute values, via ``soft`` from projop.utils
              (presumably elementwise soft-thresholding).
        lambda1: non-negative regularisation weight.

    Returns:
        The proximally mapped vector.

    Raises:
        ValueError: if ``phi_name`` is not one of the supported penalties.
    """
    if phi_name == 'l0':
        return hard(x, (2 * lambda1) ** 0.5)
    if phi_name == 'linf':
        # Delegate so the Moreau decomposition lives in one place. The previous
        # inline form ``x - lambda1 * l1_ball_vec(x, lambda1)`` applied lambda1
        # twice, unlike prox_linfnorm's ``l1_ball_vec(x / lambda1, 1)``.
        return prox_linfnorm(x, lambda1)
    if phi_name == 'l1':
        return soft(x, lambda1)
    # Previously an unknown name silently returned None, and the caller then
    # failed obscurely (e.g. inside torch.diag).
    raise ValueError(
        f"unknown phi_name {phi_name!r}; expected 'l0', 'linf' or 'l1'"
    )

def proximal_adj (A, phi_name, lambda1):
    if phi_name == 'rank':
        L_prox = prox_spectral(A, 'l0', lambda1)
    elif phi_name == 'nedges':
        L_prox = prox_spectral(A, 'l1', lambda1)
    elif phi_name == 'eigmax':
        L_prox = prox_spectral(A, 'linf', lambda1)
    elif phi_name == 'eig2max':
        L = torch.diag (torch.sum(A, dim=1)) - A
        eval1, evec1 = torch.lobpcg(L, k=1)
        L_prox = eval1 * evec1 @ evec1.T + prox_spectral(A - eval1 * evec1 @ evec1.T, 
                                                         'linf', lambda1)
    elif phi_name == 'eig2min':
        L = torch.diag (torch.sum(A, dim=1)) - A