from utils.fobj_val import fobj_val
import torch
from utils.Jexc import exc_operator
from utils.Jorthogonal_test import Jtest
import time

# Fix torch's global RNG seed when this module is imported, so the random
# initial dual variables (Pi, ksi) drawn in admm.__init__ are reproducible.
# NOTE: this is an import-time side effect on torch's *global* generator and
# therefore also affects any other code running in the same process.
torch.random.manual_seed(0)
class admm:
    """ADMM iteration on a square matrix X under a J-orthogonality constraint.

    The constraint is split through an auxiliary matrix Y and enforced by two
    dual variables, as read from the dual updates in ``step_forward``::

        ksi <- ksi + lambda * (X^T Y - J)
        Pi  <- Pi  + lambda * (J X - Y)

    so at a fixed point X^T J X = J, with J = diag(I_p, -I_{d-p}).
    X must be (d, d): both ``X.T @ Y - J`` and ``J @ X`` require it.

    Config keys read in ``__init__``:
        run.maxiter       -> self.maxiter   (iteration cap)
        run.stop_t        -> self.stopt     (elapsed-time budget, seconds)
        run.stop_lr       -> self.stoplr    (stored; not used in this class)
        run.Jerrindex     -> self.Jerrindex (stored; not used in this class)
        datafeature.d     -> self.d
        datafeature.p     -> self.p
    """

    def __init__(self, X, config_yaml, lambda_val, mode):
        """Initialise solver state.

        Args:
            X: initial (d, d) iterate; it is cloned, never modified in place.
            config_yaml: parsed config mapping (keys listed on the class).
            lambda_val: initial ADMM penalty parameter.
            mode: ``'d'`` doubles ``lambda_val`` every other iteration in
                ``train`` (while it is below 1e12); any other value keeps it
                fixed.
        """
        run_cfg = config_yaml["run"]
        feat_cfg = config_yaml["datafeature"]

        self.X = X.clone()
        # Own copy: Y must not alias X even if an update ever becomes in-place.
        self.Y = self.X.clone()
        # YAML can load values like "1e4" as strings, hence the float() calls.
        self.maxiter = int(float(run_cfg["maxiter"]))
        self.d = feat_cfg["d"]
        self.p = feat_cfg["p"]
        self.stopt = float(run_cfg["stop_t"])
        self.stoplr = int(float(run_cfg["stop_lr"]))
        self.Jerrindex = run_cfg["Jerrindex"]
        self.mode = mode

        # Auxiliary tensors follow X's floating dtype/device. Hard-coding
        # float32 made every pinverse product raise for float64 inputs, which
        # was then silently skipped so X never changed.
        dtype = X.dtype if X.is_floating_point() else torch.float32
        device = X.device

        J = torch.eye(self.d, dtype=dtype, device=device)
        J[self.p:, self.p:] = -torch.eye(self.d - self.p, dtype=dtype, device=device)
        self.J = J
        self.myeps = 1e-8  # not used in this class
        self.lambda_val = lambda_val
        # Weight of the l*I / -l*X pair in sub-problem 1 (keeps X near its
        # previous value).
        self.l = 1e4
        # Drawn with the default CPU generator, then cast, so the values are
        # the same as before for a given seed regardless of X's dtype/device.
        self.Pi = torch.randn(X.size()).to(dtype=dtype, device=device) * 0.1
        self.ksi = torch.randn(X.size()).to(dtype=dtype, device=device) * 0.1

        self.Jerr0 = Jtest(self.X, self.p)

    def train(self, C):
        """Iterate ``step_forward`` until ``maxiter`` or the time budget.

        Before each step, ``fobj_val(X, C)`` and ``Jtest(X, p)`` are recorded
        for the current iterate.

        Args:
            C: matrix passed to ``fobj_val`` and ``step_forward``.

        Returns:
            ``(hist_Obj, X, hist_t, hist_err)`` where ``hist_Obj`` is 1-D with
            one entry per completed iteration, ``hist_t`` (cumulative seconds
            measured after each step) and ``hist_err`` have shape (n, 1) for
            the same n, and ``X`` is the final iterate.
        """
        hist_Obj = torch.zeros(self.maxiter, 1)
        hist_t = torch.zeros(self.maxiter, 1)
        hist_err = torch.zeros(self.maxiter, 1)
        n_done = 0
        start_time = time.perf_counter()

        for it in range(self.maxiter):
            hist_Obj[it] = fobj_val(self.X, C)
            hist_err[it] = Jtest(self.X, self.p)

            self.step_forward(C)

            hist_t[it] = time.perf_counter() - start_time
            n_done = it + 1
            if hist_t[it] > self.stopt:
                break

            # Dynamic-penalty mode: double lambda on even iterations, capped.
            if self.mode == 'd' and it % 2 == 0 and self.lambda_val < 1e12:
                self.lambda_val *= 2

        hist_Obj, hist_t, hist_err = self._truncate_history(
            n_done, hist_Obj, hist_t, hist_err
        )
        return hist_Obj, self.X, hist_t, hist_err

    @staticmethod
    def _truncate_history(n_done, hist_Obj, hist_t, hist_err):
        """Trim preallocated (maxiter, 1) histories to the first ``n_done`` rows.

        ``hist_Obj`` is returned 1-D; ``hist_t`` and ``hist_err`` keep shape
        (n_done, 1). Trimming by count (instead of dropping zero objective
        values) keeps the three histories aligned even if an objective value
        is exactly 0.
        """
        return hist_Obj[:n_done, 0], hist_t[:n_done], hist_err[:n_done]

    def step_forward(self, C):
        """Run one ADMM sweep, updating X, ksi, Y and Pi on ``self``.

        Update order: X (sub-problem 1), ksi (new X, old Y), Y (sub-problem 2),
        Pi (new X, new Y). If a pseudo-inverse raises ``RuntimeError``, the
        corresponding primal variable keeps its previous value.

        Args:
            C: (d, d) matrix; ``C @ X`` is used as the gradient term.
        """
        lam = self.lambda_val
        J = self.J
        eye = torch.eye(self.d, dtype=self.X.dtype, device=self.X.device)
        gradX = C @ self.X

        # sub-problem 1: closed-form X update
        # ((l + lam) I + lam Y Y^T) X = -(gradX - l X + Y ksi^T + J Pi - lam Y J - lam J Y)
        try:
            lhs = (self.l + lam) * eye + lam * self.Y @ self.Y.t()
            rhs = -(gradX - self.l * self.X + self.Y @ self.ksi.t() + J @ self.Pi
                    - lam * self.Y @ J - lam * J @ self.Y)
            self.X = torch.pinverse(lhs) @ rhs
        except RuntimeError:
            # e.g. torch.linalg.LinAlgError if the SVD fails to converge:
            # keep the previous X instead of aborting the run.
            pass
        self.X = torch.clamp(self.X, -100, 100)

        self.ksi += lam * (self.X.t() @ self.Y - J)
        self.ksi = torch.clamp(self.ksi, -10, 10)

        # sub-problem 2: closed-form Y update
        # lam (X X^T + I) Y = -(X ksi - J Pi - lam X J - lam J X)
        # NOTE(review): differentiating <Pi, J X - Y> w.r.t. Y gives -Pi, not
        # -J Pi; confirm this term against the intended derivation.
        try:
            lhs = lam * (self.X @ self.X.t() + eye)
            rhs = -(self.X @ self.ksi - J @ self.Pi
                    - lam * self.X @ J - lam * J @ self.X)
            self.Y = torch.pinverse(lhs) @ rhs
        except RuntimeError:
            # Keep the previous Y when the pseudo-inverse fails.
            pass

        self.Pi += lam * (J @ self.X - self.Y)
        self.Pi = torch.clamp(self.Pi, -10, 10)


