import torch
import numpy as np
from numpy import *

def S(g, tau):
    """Apply the elementwise soft-thresholding (shrinkage) operator.

    Computes ``max(g - tau, 0) + min(g + tau, 0)`` element by element. For a
    non-negative ``tau`` this shrinks every entry toward zero by ``tau`` and
    zeroes entries whose magnitude is at most ``tau``.

    Args:
        g: Array-like input (torch tensor or numpy array); it is not modified.
        tau: Threshold, a scalar or anything broadcastable against ``g``.

    Returns:
        A new array/tensor of the same kind as ``g - tau``.
    """
    # Positive part of the downward shift: entries below zero become 0.
    upper_part = (g - tau).clip(min=0)
    # Negative part of the upward shift: entries above zero become 0.
    lower_part = (g + tau).clip(max=0)
    return upper_part + lower_part

def _admm_o_objective(x, X, Y, W, C, P, Q, EQ_factor, o, gama3):
    """Evaluate the objective minimised by ``admm_o`` at the iterate ``o``.

    ``EQ_factor`` is the matrix that multiplies ``Q`` in the two data-fit terms
    (``admm_o`` passes ``E`` for its initial reference loss and
    ``o / v_num + tmp`` inside the ADMM loop).
    """
    return (
        torch.pow(torch.norm(x - (C + o) @ x, 2), 2)
        + torch.pow(torch.norm(Y - C @ P - EQ_factor @ Q, 2), 2)
        + torch.pow(torch.norm(Y - EQ_factor @ Q - X @ W, 2), 2)
        + gama3 * torch.norm(o, 1)
    )


def admm_o(x, X, Y, W, C, P, Q, E, tmp, rho, iter, gama3, v_num, flag="", tol=0.001):
    """Solve for the correction matrix ``o`` with scaled-form ADMM.

    Minimises

        ||x - (C + o) x||^2 + ||Y - C P - (o / v_num + tmp) Q||^2
        + ||Y - (o / v_num + tmp) Q - X W||^2 + gama3 * ||o||_1

    by splitting ``o = z``. Each iteration does three steps:

    * a closed-form least-squares o-step,
    * a soft-thresholding z-step (``S`` with threshold ``gama3 / rho``),
    * a dual ascent step on ``u``.

    Iteration stops after ``iter`` steps, or earlier once the objective changes
    by less than ``tol`` between consecutive iterates.

    Args:
        x: Matrix reconstructed as ``(C + o) @ x``. Presumably it has ``n``
            rows so that it is compatible with ``o``; TODO confirm.
        X, W: Enter the objective only through the product ``X @ W``.
        Y: 2-D target. Its row count ``n`` fixes ``o`` to shape ``(n, n)``,
            and its device is used for every work tensor.
        C, P: ``C`` enters ``(C + o) @ x`` and ``C @ P``.
        Q: Right factor multiplying ``o / v_num + tmp``.
        E: Used only in the initial reference loss, in place of
            ``o / v_num + tmp``.
        tmp: Offset added to ``o / v_num``.
        rho: ADMM penalty. It must be non-zero (it is divided by). With
            ``rho > 0`` the o-step system matrix is positive definite and
            hence invertible.
        iter: Maximum number of ADMM iterations. Values ``<= 0`` return the
            zero initialisation.
        gama3: Weight of the l1 penalty on ``o``.
        v_num: Non-zero scale applied to ``o`` in the data terms.
        flag: Label included in the printed loss message.
        tol: Early-stopping threshold on the absolute change of the objective.

    Returns:
        The ``o`` iterate held at the start of the final iteration (the zero
        matrix when no iteration runs). It is detached, with shape ``(n, n)``,
        on ``Y``'s device.

    Side effects:
        Prints the last accepted objective value.
    """
    with torch.no_grad():  # outputs were always detached; skip graph building
        n, _ = Y.shape
        device = Y.device

        I = torch.eye(n, device=device)
        o = torch.zeros((n, n), device=device)
        z_old = torch.zeros((n, n), device=device)
        u_old = torch.zeros((n, n), device=device)

        # o-step solves  o @ lhs = rhs_const + rho * z_old - u_old.
        # Neither lhs nor rhs_const depends on the ADMM iterates, so both are
        # computed once rather than re-inverting every iteration.
        rhs_const = 2*(x - C@x)@x.T + 2/v_num * (2*Y - 2*tmp @ Q - X @ W - C @ P) @ Q.T
        lhs_inv = torch.inverse(2 * x @ x.T + 4/(v_num ** 2) * (Q @ Q.T) + rho * I)

        # NOTE(review): the reference loss uses E where the loop uses
        # o / v_num + tmp; presumably the caller's E equals tmp plus a previous
        # o / v_num. Confirm, since it affects the first convergence check.
        e_his_loss = _admm_o_objective(x, X, Y, W, C, P, Q, E, o, gama3)

        # Defined up front so iter <= 0 no longer raises UnboundLocalError.
        o_s = o

        for _ in range(iter):
            # No tensor is ever modified in place, so holding a reference is
            # as safe as the previous clone().detach().
            o_s = o
            o = (rhs_const + rho * z_old - u_old) @ lhs_inv

            z_new = S(o + u_old / rho, gama3 / rho)
            u_new = u_old + rho * (o - z_new)

            e_loss = _admm_o_objective(x, X, Y, W, C, P, Q, o / v_num + tmp, o, gama3)
            if torch.abs(e_his_loss - e_loss) < tol:
                break

            e_his_loss = e_loss
            z_old = z_new
            u_old = u_new

    # NOTE(review): the returned o_s lags the latest o by one update. This is
    # kept for compatibility with the original return value.
    print("As for {}, Admm_loss: ".format(flag), e_his_loss.item())
    return o_s

