import torch
import numpy as np
from numpy import *

def S(g, tau):
    """Apply the element-wise soft-thresholding (shrinkage) operator.

    The result is the positive part of ``g - tau`` plus the negative part of
    ``g + tau``. For ``tau >= 0`` this equals ``sign(g) * max(|g| - tau, 0)``.

    Args:
        g: array or tensor supporting boolean-mask assignment
            (e.g. ``torch.Tensor`` or ``numpy.ndarray``).
        tau: threshold (scalar or broadcast-compatible with ``g``).

    Returns:
        A new array/tensor; ``g`` itself is not modified.
    """
    # Both shifts allocate fresh objects, so the in-place clipping below
    # never touches the caller's ``g``.
    upper = g - tau
    lower = g + tau
    upper[upper < 0] = 0  # keep only the part of g above +tau
    lower[lower > 0] = 0  # keep only the part of g below -tau
    return upper + lower

def admm_lasso(X, Y, W, C, P, E, rho, iter, gama2, flag="", tol=1e-3):
    """Estimate ``q`` for an L1-regularised least-squares problem via ADMM.

    The objective tracked (and printed) is

        ||Y - C @ P - E @ q||^2 + ||Y - E @ q - X @ W||^2 + gama2 * ||q||_1

    Each iteration performs three steps:
      * a closed-form q-step solving
        ``(4 E^T E + rho I) q = 2 E^T (2Y - CP - XW) + rho z - u``;
      * a z-step via soft-thresholding ``S(q + u / rho, gama2 / rho)``;
      * a dual step ``u += rho * (q - z)``.

    Args:
        X, W: tensors whose product ``X @ W`` is compatible with ``Y``.
        Y: tensor of shape ``(n, c)``.
        C, P: tensors whose product ``C @ P`` is compatible with ``Y``.
        E: tensor with ``n`` columns; ``E @ q`` must be compatible with ``Y``.
            Its device and dtype are used for all internal buffers.
        rho: ADMM penalty parameter (presumably > 0 -- not checked).
        iter: maximum number of iterations (``<= 0`` runs none).
        gama2: weight of the L1 term.
        flag: label included in the printed loss message.
        tol: stop when the absolute change of the objective drops below this.

    Returns:
        Tensor of shape ``(n, c)`` on ``E``'s device: the ``q`` iterate from
        *before* the last update performed (zeros if no iteration ran).
        NOTE(review): returning the pre-update iterate is preserved from the
        original. It means ``iter=1`` returns zeros -- confirm this is intended
        rather than returning the latest ``q``.
    """
    n, c = Y.shape
    device, dtype = E.device, E.dtype

    # The q-step system matrix and the data part of its right-hand side do not
    # change across iterations; compute them (including the O(n^3) inverse) once.
    I = torch.eye(n, device=device, dtype=dtype)
    q_system_inv = torch.inverse(4 * (E.T @ E) + rho * I)
    CP = C @ P
    XW = X @ W
    q_rhs_data = 2 * E.T @ (2 * Y - CP - XW)

    def objective(q_val):
        # Same expression (and evaluation order) as the tracked loss above.
        return (torch.pow(torch.norm(Y - CP - E @ q_val), 2)
                + torch.pow(torch.norm(Y - E @ q_val - XW), 2)
                + gama2 * torch.norm(q_val, 1))

    q = torch.zeros((n, c), device=device, dtype=dtype)
    z_old = torch.zeros((n, c), device=device, dtype=dtype)
    u_old = torch.zeros((n, c), device=device, dtype=dtype)
    # Bind q_s up front so iter <= 0 returns the initial iterate instead of
    # raising UnboundLocalError.
    q_s = q.clone().detach()

    e_his = objective(q)

    for _ in range(iter):
        q_s = q.clone().detach()
        q = q_system_inv @ (q_rhs_data + rho * z_old - u_old)

        z_new = S(q + u_old / rho, gama2 / rho)
        u_new = u_old + rho * (q - z_new)

        e = objective(q)
        if torch.abs(e_his - e) < tol:
            break

        e_his = e.clone().detach()
        z_old = z_new.clone().detach()
        u_old = u_new.clone().detach()
    print("As for {}, Admm_loss: ".format(flag), e_his.item())
    return q_s

