import torch
from copy import deepcopy
import tqdm

@torch.no_grad()
def proximal_gradient_descent(X, Y, W, c, lambda_=0.1, learning_rate=0.0005, max_iter=200, tol=0.1, verbose=True):
    """Proximal Gradient Descent for L1-regularized least squares problem.

    Approximately minimises (ISTA)

        0.5 * ||hat_W @ X - Y||_F^2 + lambda_ * ||hat_W - W||_1

    starting from ``hat_W = W``. Each iteration takes a gradient step on the
    smooth least-squares term, then applies the proximal operator of the L1
    term: the deviation ``hat_W - W`` is soft-thresholded by
    ``learning_rate * lambda_``, so ``hat_W`` only departs from ``W`` where
    the fit improvement outweighs the penalty.

    Args:
        X: Input matrix; ``W @ X`` must be shape-compatible with ``Y``.
        Y: Target matrix.
        W: Reference weight matrix. It is not modified.
        c: Only used for the "item 1" diagnostic (norm of ``hat_W @ c - W``)
            printed when ``verbose`` is true.
        lambda_: L1 penalty weight on ``hat_W - W``.
        learning_rate: Gradient step size (also scales the soft threshold).
        max_iter: Maximum number of iterations.
        tol: Stop early once the Frobenius norm of the change in ``hat_W``
            between consecutive iterations falls below this value.
        verbose: Show a tqdm progress bar and print per-iteration diagnostics.

    Returns:
        The optimised ``hat_W``, a new tensor with ``W``'s shape.

    Raises:
        FloatingPointError: If the gradient becomes NaN (typically because
            ``learning_rate`` is too large).
    """
    hat_W = W.clone()
    threshold = learning_rate * lambda_
    steps = tqdm.tqdm(range(max_iter)) if verbose else range(max_iter)
    for _ in steps:
        # Gradient of 0.5 * ||hat_W @ X - Y||_F^2 with respect to hat_W.
        grad = (hat_W @ X - Y) @ X.T
        if torch.isnan(grad).any():
            raise FloatingPointError("NaN in gradient; try a smaller learning_rate")
        candidate = hat_W - learning_rate * grad
        # Proximal step for lambda_ * ||hat_W - W||_1: soft-threshold the
        # deviation from W, shrinking small deviations back to exactly W.
        deviation = candidate - W
        new_W = W + torch.sign(deviation) * torch.clamp(deviation.abs() - threshold, min=0)
        step_norm = torch.norm(new_W - hat_W)
        hat_W = new_W
        if verbose:
            print(
                "item 1:\n", torch.norm(hat_W @ c - W),
                "item2: \n", torch.norm(hat_W @ X - Y),
                "item3: \n", torch.norm(hat_W - W, p=1),
            )
        if step_norm < tol:
            break
    return hat_W

def _prune_to_top_k_(tensor, k):
    """Zero, in place, every entry of ``tensor`` smaller in magnitude than its k-th largest.

    Entries tied with the k-th largest magnitude are all kept, so slightly
    more than ``k`` entries can survive. ``k >= tensor.numel()`` leaves the
    tensor untouched, and ``k <= 0`` zeroes it. Must be called under
    ``torch.no_grad()`` when ``tensor`` is a leaf that requires grad.
    """
    if k >= tensor.numel():
        return
    if k <= 0:
        tensor.zero_()
        return
    magnitudes = tensor.abs()
    kth_largest = torch.topk(magnitudes.flatten(), k=k).values.min()
    tensor.masked_fill_(magnitudes < kth_largest, 0.0)


def optimize(X, Y, W, c, lambda_=100000, learning_rate=0.0005, max_iter=200, k=1000, verbose=True):
    """Learn a sparse additive correction ``delta_W`` to ``W`` with Adam.

    Each iteration minimises ``loss1 + loss2``, where

    * ``loss1 = mean(((W + delta_W) @ X - Y) ** 2)`` fits the target, and
    * ``loss2 = mean(((50 * c) @ delta_W.T) ** 2)`` penalises the correction
      as seen through ``c``.

    After every Adam step, all but the ``k`` largest-magnitude entries of
    ``delta_W`` are zeroed (ties at the cut-off are kept).
    ``loss3 = lambda_ * mean(|delta_W|)`` is computed for reporting only; it
    is not part of the optimised objective.

    Args:
        X: Input matrix; ``W @ X`` must be shape-compatible with ``Y``.
        Y: Target matrix.
        W: Base weight matrix. It is not modified.
        c: Matrix whose column count equals ``W.shape[1]``.
        lambda_: Scale of the reported L1 term ``loss3``.
        learning_rate: Adam learning rate.
        max_iter: Number of optimisation steps.
        k: Number of entries of ``delta_W`` kept non-zero after each step.
        verbose: Show a tqdm progress bar and print per-iteration losses and
            the non-zero count of ``delta_W``.

    Returns:
        The learned correction ``delta_W`` (detached, same shape as ``W``).
    """
    # NOTE(review): the factor 50 weights the c-penalty against the fit term;
    # its origin is not documented here.
    c = c * 50
    delta_W = torch.zeros_like(W, requires_grad=True)
    opt = torch.optim.Adam([delta_W], lr=learning_rate)
    mse = torch.nn.MSELoss()
    steps = tqdm.tqdm(range(max_iter)) if verbose else range(max_iter)
    for _ in steps:
        opt.zero_grad()
        loss1 = mse((W + delta_W) @ X, Y)
        # Mean of squares == MSE against an all-zero target, without
        # allocating that target every iteration.
        loss2 = (c @ delta_W.T).pow(2).mean()
        with torch.no_grad():
            loss3 = delta_W.abs().mean() * lambda_
        output = loss1 + loss2
        # Without backward() Adam never receives a gradient and delta_W
        # would stay at zero forever.
        output.backward()
        opt.step()
        with torch.no_grad():
            _prune_to_top_k_(delta_W, k)
        if verbose:
            print(loss1.item(), loss2.item(), loss3.item())
            print((delta_W != 0).sum().item())
    return delta_W.detach()

if __name__ == "__main__":
    # Fall back to CPU when no GPU is present (the full-size problem is slow
    # there, but the script no longer crashes on machines without CUDA).
    device = "cuda" if torch.cuda.is_available() else "cpu"
    X, Y = torch.randn(11008, 1).to(device), torch.randn(4096, 1).to(device)
    W = torch.randn(4096, 11008).to(device)
    C = torch.randn(11008, 11008).to(device)
    optimize(X, Y, W, C, max_iter=50, lambda_=100000)