import torch
from projop.utils import *

def l2_ball_feat (X, X_base, bound):
    """Project ``X`` onto the L2 (Frobenius) ball of radius ``bound`` around ``X_base``.

    The distance is the Frobenius norm of ``X - X_base`` (all elements).
    If it is at most ``bound``, ``X`` is returned unchanged. Otherwise ``X`` is
    moved along the segment towards ``X_base`` until it lies exactly on the
    sphere of radius ``bound``.

    Args:
        X: tensor to project.
        X_base: centre of the ball; must be broadcast-compatible with ``X``.
        bound: radius of the ball (non-negative scalar).

    Returns:
        The projected tensor.
    """
    diff = X - X_base
    diff_norm = torch.norm(diff, p='fro')
    if diff_norm <= bound:
        return X
    # Rescale the offset to length ``bound`` (previously the ``bound`` factor
    # was missing, which projected onto the unit sphere instead).
    return X_base + diff * bound / diff_norm
        
def l2_ball_adj (A, A_base, bound):
    """Project ``A`` onto the L2 (Frobenius) ball of radius ``bound`` around ``A_base``.

    The distance is the Frobenius norm of ``A - A_base`` (all elements).
    If it is at most ``bound``, ``A`` is returned unchanged. Otherwise ``A`` is
    pulled towards ``A_base`` so that it lies exactly on the sphere of radius
    ``bound``.

    Args:
        A: tensor to project.
        A_base: centre of the ball; must be broadcast-compatible with ``A``.
        bound: radius of the ball (non-negative scalar).

    Returns:
        The projected tensor.
    """
    diff = A - A_base
    diff_norm = torch.norm(diff, p='fro')
    if diff_norm <= bound:
        return A
    # Rescale the offset to length ``bound`` (the ``bound`` factor was
    # previously missing, projecting onto the unit sphere).
    return A_base + diff * bound / diff_norm
        
def linf_ball_vec (v, bound):
    """Project ``v`` onto the L-infinity ball of radius ``bound`` centred at the origin.

    Entries with ``|v_i| <= bound`` are kept as they are; every other entry is
    replaced by ``bound * sign(v_i)``. For ``bound >= 0`` this is elementwise
    clipping of ``v`` to ``[-bound, bound]``.

    Args:
        v: tensor to project (any shape; the rule is applied elementwise).
        bound: radius of the ball.

    Returns:
        Tensor of the same shape as ``v``.
    """
    inside = torch.abs(v) <= bound
    return torch.where(inside, v, bound * torch.sign(v))
        
def l1_ball_vec (v, bound):
    """Project ``v`` onto the L1 ball of radius ``bound`` centred at the origin.

    If ``||v||_1 <= bound``, ``v`` is returned unchanged. Otherwise a threshold
    ``mu`` in ``[0, max|v|]`` is found with ``bisection`` as the root of
    ``plus_fn(|v| - mu).sum() - bound``, and ``soft(v, mu)`` is returned.

    ``bisection``, ``plus_fn`` and ``soft`` come from ``projop.utils``.
    Presumably ``plus_fn`` is ``max(., 0)`` and ``soft`` is soft-thresholding
    (TODO confirm).
    """
    if torch.norm(v, p=1) <= bound:
        # Already feasible: the projection is the identity.
        return v

    def l1_residual(mu):
        # Residual whose root gives the threshold level.
        return plus_fn(torch.abs(v) - mu).sum() - bound

    # The bracket [0, max|v|] is valid; tighter brackets are possible.
    upper = torch.max(torch.abs(v))
    mu_star = bisection(v, func=l1_residual, a=0, b=upper)
    return soft(v, mu_star)

def l0_ball_vec (v, bound):
    """Restrict ``v`` to the L0 "ball": at most ``bound`` nonzero entries.

    ``torch.norm(v, p=0)`` counts the nonzero entries of ``v``, positive and
    negative alike. The constraint is therefore ``#{i : v_i != 0} <= bound``.
    If it already holds, ``v`` is returned unchanged.

    Otherwise ``bisection`` searches ``mu`` in ``[0, max|v|**2 / 2]`` for a root
    of ``#{i : |v_i| >= sqrt(2*mu)} - bound``, and ``hard(v, mu)`` is returned.
    ``bisection`` and ``hard`` come from ``projop.utils``; presumably ``hard``
    zeroes entries below the ``sqrt(2*mu)`` threshold (TODO confirm).

    NOTE(review): the count is a step function. With ``bound == 0``, or with
    more than ``bound`` entries tied at ``max|v|``, it never reaches zero on
    the bracket. The result then depends on how ``bisection`` handles a
    missing sign change. Confirm.
    """
    if torch.norm(v, p=0) <= bound:
        return v
    a, b = 0, torch.max(torch.abs(v))**2 / 2
    lambda_star = bisection(v, func=lambda mu: torch.sum(torch.abs(v) >= (2*mu)**0.5) - bound, a=a, b=b)
    return hard (v, lambda_star)

def l2_ball_vecs (vs, bound):
    """Project each slice of ``vs`` along dim 1 onto the L2 ball of radius ``bound``.

    For a 2-D ``(N, D)`` input this projects every row independently. Slices
    whose L2 norm is at most ``bound`` are returned unchanged; longer ones are
    rescaled to have norm exactly ``bound``.

    Args:
        vs: tensor with at least 2 dims; norms are taken over dim 1.
        bound: radius of the ball (non-negative scalar).

    Returns:
        Tensor of the same shape as ``vs``.
    """
    norms = torch.norm(vs, dim=1, p=2, keepdim=True)
    inside = norms <= bound
    # Kept slices divide by 1 instead of their (possibly zero) norm. This
    # prevents the untaken torch.where branch from producing inf/NaN, which
    # would turn the gradient of all-zero rows into NaN (and, for bound == 0,
    # the output itself).
    safe_norms = torch.where(inside, torch.ones_like(norms), norms)
    return torch.where(inside, vs, vs * bound / safe_norms)

def l2_ball_vec (v, bound):
    """Project ``v`` onto the L2 ball of radius ``bound`` centred at the origin.

    The norm is taken over all elements of ``v``. If it is at most ``bound``,
    ``v`` is returned unchanged; otherwise ``v`` is rescaled to have norm
    exactly ``bound``.

    Args:
        v: tensor to project.
        bound: radius of the ball (non-negative scalar).

    Returns:
        Tensor of the same shape as ``v``.
    """
    norm = torch.norm(v, p=2)
    inside = norm <= bound
    # Avoid dividing by a zero norm in the untaken torch.where branch. That
    # division would yield NaN gradients (and a NaN output when bound == 0).
    safe_norm = torch.where(inside, torch.ones_like(norm), norm)
    return torch.where(inside, v, v * bound / safe_norm)