import torch
import numpy as np
from sklearn import datasets as skdatasets
def euclidean_proj_simplex(v, s=1):
    """Project a 1-D NumPy array onto the simplex of radius ``s``.

    Returns the point ``w`` closest to ``v`` in Euclidean distance such that
    ``w >= 0`` and ``w.sum() == s``, using the sort-and-threshold method.

    Args:
        v: 1-D NumPy array to project.
        s: Radius of the simplex (the required sum); must be > 0.

    Returns:
        ``v`` itself (same object) when it already lies on the simplex,
        otherwise a new array holding the projection.

    Raises:
        AssertionError: If ``s`` is not strictly positive.
        ValueError: If ``v`` is not 1-D.
    """
    assert s > 0, "Radius s must be strictly positive (%d <= 0)" % s
    n, = v.shape  # will raise ValueError if v is not 1-D
    # Already on the simplex: the best projection is v itself.
    if v.sum() == s and np.all(v >= 0):
        return v
    # Cumulative sums of a copy of v sorted in decreasing order.
    u = np.sort(v)[::-1]
    cssv = np.cumsum(u)
    # Zero-based index of the last sorted component that stays positive;
    # the number of positive components of the solution is therefore rho + 1.
    rho = np.nonzero(u * np.arange(1, n + 1) > (cssv - s))[0][-1]
    # Lagrange multiplier of the sum constraint. The divisor is the COUNT of
    # positive components (rho + 1), not the zero-based index rho.
    theta = float(cssv[rho] - s) / (rho + 1)
    # Shift by theta and clip at zero to obtain the projection.
    w = (v - theta).clip(min=0)
    return w


def euclidean_proj_l1ball(v, s=1):
    """Project a 1-D torch tensor onto the L1 ball of radius ``s``.

    The magnitudes ``|v|`` are projected onto the simplex of radius ``s`` via
    ``euclidean_proj_simplex`` (run on CPU with NumPy) and the original signs
    are then restored.

    Args:
        v: 1-D torch tensor to project.
        s: Radius of the L1 ball; must be > 0.

    Returns:
        ``v`` itself (same object) when its L1 norm is already <= ``s``,
        otherwise a new tensor on ``v``'s device holding the projection.

    Raises:
        AssertionError: If ``s`` is not strictly positive.
        ValueError: If ``v`` is not 1-D.
    """
    assert s > 0, "Radius s must be strictly positive (%d <= 0)" % s
    (_,) = v.shape  # will raise ValueError if v is not 1-D
    magnitudes = v.abs()
    # Already inside the ball: v is its own projection.
    if magnitudes.sum() <= s:
        return v
    # Otherwise the optimum lies on the boundary (norm == s): project the
    # magnitudes onto the simplex. detach() lets this work even when v is
    # attached to an autograd graph, where .numpy() would otherwise raise.
    w = euclidean_proj_simplex(magnitudes.detach().cpu().numpy(), s=s)
    # Keep the caller's floating-point precision instead of forcing float32;
    # non-floating inputs still get a float32 result as before.
    out_dtype = v.dtype if v.is_floating_point() else torch.float
    return torch.tensor(w, device=v.device, dtype=out_dtype) * torch.sign(v)
# Load the dataset from an svmlight/libsvm-format file. The path below is a
# placeholder that must be replaced with a real file.
data, label = skdatasets.load_svmlight_file('your_data_path')

# Densify the feature matrix and move features and labels to the GPU as
# float32. The dtype has to be passed to each torch.tensor() call; placing it
# inside the tuple literal is a syntax error.
# NOTE(review): the loss functions in this file presumably expect labels in
# {-1, +1} - confirm for the dataset being loaded.
data, label = (
    torch.tensor(data.toarray(), device='cuda', dtype=torch.float32),
    torch.tensor(label, device='cuda', dtype=torch.float32),
)

# One weight per feature column, initialised at the origin.
param = torch.zeros(data.shape[1], device='cuda', requires_grad=True)

def loss_fn(param, data, label, is_convex=True):
    """Return the mean loss of the linear model ``data @ param``.

    With ``margin = label * (data @ param)``:

    * ``is_convex`` truthy: logistic loss ``log(1 + exp(-margin))``.
    * otherwise: sigmoid loss ``1 / (1 + exp(margin))``.

    Both are evaluated through ``softplus`` / ``sigmoid`` so that margins of
    large magnitude do not overflow ``exp`` and turn the loss into ``inf``.

    Args:
        param: Weight tensor multiplied as ``data @ param``.
        data: Feature tensor.
        label: Per-sample labels, multiplied elementwise with the scores
            (presumably in {-1, +1} - confirm against the caller).
        is_convex: Selects the logistic loss (default) or the sigmoid loss.

    Returns:
        Scalar tensor holding the mean loss.
    """
    margin = label * (data @ param)
    if is_convex:
        # softplus(x) == log(1 + exp(x)), computed without overflow.
        return torch.nn.functional.softplus(-margin).mean()
    # sigmoid(-x) == 1 / (1 + exp(x)).
    return torch.sigmoid(-margin).mean()
        
def compute_avg_grad(
    param: torch.Tensor, data: torch.Tensor, label: torch.Tensor, is_convex: bool = True,
) -> torch.Tensor:
    """Return the closed-form gradient w.r.t. ``param`` of the mean loss.

    With ``margin = label * (data @ param)`` the per-sample derivative with
    respect to the score ``data @ param`` is

    * logistic loss ``log(1 + exp(-margin))``:
      ``-label * sigmoid(-margin)``
    * sigmoid loss ``1 / (1 + exp(margin))``:
      ``-label * sigmoid(-margin) * sigmoid(margin)``

    These are algebraically the same as
    ``-label * exp(-margin) / (1 + exp(-margin))`` (and the squared-denominator
    variant) but stay finite when ``exp(-margin)`` would overflow; the explicit
    ``exp`` form evaluates to ``inf / inf = NaN`` there.

    Args:
        param: 1-D weight tensor.
        data: 2-D feature tensor with one row per sample.
        label: 1-D tensor with one label per sample.
        is_convex: Selects the logistic loss (default) or the sigmoid loss.

    Returns:
        1-D tensor: the per-sample gradients averaged over the rows of
        ``data``.
    """
    margin = label * (data @ param)
    # Derivative of the logistic loss with respect to the score.
    coeff = -label * torch.sigmoid(-margin)
    if not is_convex:
        # The sigmoid loss carries one extra sigmoid(margin) factor.
        coeff = coeff * torch.sigmoid(margin)
    # Scale each sample's feature row by its coefficient, then average.
    return (coeff.unsqueeze(1) * data).mean(dim=0)

# Projected gradient descent on the L1 ball: take a gradient step, then
# project the iterate back onto the ball. The radius, step size and loss
# variant are named once so the feasibility check, the projection and the
# loss/gradient pair cannot drift apart.
L1_RADIUS = 20
STEP_SIZE = 1e-4
NUM_ITERS = 2_000
IS_CONVEX = False

# Smallest loss observed at a feasible iterate; stays inf if none is seen.
min_loss = float('inf')

for i in range(NUM_ITERS):
    # The gradient comes from the closed form in compute_avg_grad, so no
    # autograd graph or backward pass is needed anywhere in the loop.
    with torch.no_grad():
        loss = loss_fn(param, data, label, is_convex=IS_CONVEX)

        # Only report/record iterates that lie inside the L1 ball.
        l1_norm = param.abs().sum()
        if l1_norm <= L1_RADIUS:
            print(f'At iteration {i}, the loss is {loss}, the norm of param is {l1_norm}')
            min_loss = min(min_loss, loss.item())

        grad = compute_avg_grad(param, data, label, is_convex=IS_CONVEX)
        param.copy_(euclidean_proj_l1ball(param - STEP_SIZE * grad, L1_RADIUS))
print(min_loss)