import torch
import numpy as np
from numpy import *
from numpy.linalg import svd, norm, inv

def prox_l1(v, lambdat):
    """
    Elementwise shrinkage (soft-thresholding) operator.

    Returns max(0, v - lambdat) - max(0, -v - lambdat): for lambdat >= 0 every
    entry is moved lambdat towards zero, and entries with magnitude at most
    lambdat become zero. Works on scalars and arrays alike.
    """
    above = np.maximum(0, v - lambdat)   # amount by which v exceeds +lambdat
    below = np.maximum(0, -v - lambdat)  # amount by which v falls below -lambdat
    return above - below


def prox_matrix(v, lambdat, prox_f):
    """
    Singular value thresholding (SVT): apply ``prox_f`` to the singular values of ``v``.

    v       matrix whose singular values are shrunk
    lambdat threshold passed to ``prox_f``
    prox_f  proximal operator applied to the singular values (e.g. ``prox_l1``
            for soft thresholding); called as ``prox_f(S, lambdat)`` where ``S``
            is the column of singular values, shape (k, 1)

    Returns U @ diag(prox_f(S, lambdat)) @ Vh, which has the same shape as ``v``.
    """
    U, S, Vh = svd(v, full_matrices=False)
    # numpy's svd returns Vh = V^H already (v == U @ diag(S) @ Vh), so no extra
    # transpose or conjugate is needed; conjugating again is wrong for complex input.
    shrunk = np.ravel(prox_f(S.reshape((len(S), 1)), lambdat))
    # Scaling U's columns equals U @ diag(shrunk) without building a k x k matrix.
    return (U * shrunk).dot(Vh)


def objective(Y, E, C, P, Q, gamma1, rho=1.0):
    """
    Evaluate ||Y - C P - E Q||_F^2 + gamma1 * ||P||_* (nuclear norm of P).

    ``rho`` is accepted (``admm`` passes it) but does not affect the result.
    """
    residual = Y - C.dot(P) - E.dot(Q)
    fit = np.square(norm(residual, 'fro'))
    # Nuclear-norm regularizer: singular values are non-negative, so the
    # matrix 1-norm of their (k, 1) column is simply their sum.
    sigma = svd(P, compute_uv=False).reshape(-1, 1)
    return fit + gamma1 * norm(sigma, 1)



def admm(Y, E, C, Q, gamma1, rho=1.0, mu=1.0, max_iter=100, tol=1e-3, device="cuda"):
    """
    Solve a nuclear-norm regularized least-squares problem with ADMM.

    Minimizes  ||Y - C P - E Q||_F^2 + gamma1 * ||P||_*  over P using the
    splitting P = Z and an unscaled dual variable U:

        P <- (2 C^T C + rho I)^{-1} (2 C^T (Y - E Q) + rho Z - U)
        Z <- SVT_{gamma1 / rho}(P + U / rho)
        U <- U + rho (P - Z)

    Parameters
    ----------
    Y : ndarray, shape (m, c)
    E, Q : ndarrays such that E.dot(Q) has the same shape as Y
    C : ndarray, shape (m, k); the solution P has shape (k, c)
    gamma1 : weight of the nuclear-norm term
    rho : ADMM penalty parameter
    mu : unused; kept for backward compatibility
    max_iter : maximum number of ADMM iterations
    tol : stop once the absolute change of the objective drops below this
    device : torch device of the returned tensor (default "cuda")

    Returns
    -------
    torch.Tensor holding the latest P iterate (built with ``torch.Tensor``, so it
    uses torch's default float dtype), on ``device``. With ``max_iter=0`` this
    is the all-zero initial P.
    """
    _, c = Y.shape
    k = C.shape[1]
    P = np.zeros((k, c), dtype="float32")
    Z = np.zeros((k, c), dtype="float32")
    U = np.zeros((k, c), dtype="float32")

    # The system matrix and the data part of the right-hand side do not change
    # between iterations, so compute them once.
    system_inv = np.linalg.inv(2 * C.T.dot(C) + rho * np.identity(k))
    data_rhs = 2 * C.T.dot(Y - E.dot(Q))

    e_prev = objective(Y, E, C, P, Q, gamma1, rho)
    for _ in range(max_iter):
        # P-update: closed-form minimizer of the smooth part (set gradient to zero).
        P = system_inv.dot(data_rhs + rho * Z - U)

        # Z-update: singular value soft-thresholding.
        Z = prox_matrix(P + U / rho, gamma1 / rho, prox_l1)

        # Dual ascent; with an unscaled dual the step is rho * (P - Z).
        U = U + rho * (P - Z)

        # Convergence: absolute change of the objective value.
        e = objective(Y, E, C, P, Q, gamma1, rho)
        if abs(e_prev - e) < tol:
            break
        e_prev = e

    return torch.Tensor(P).to(device)
