from utils.fobj_val import fobj_val
import torch
from utils.Jexc import exc_operator
from utils.Jorthogonal_test import Jtest
import time

# Seed torch's default (global) RNG at import time. This is a process-wide
# side effect: nothing visible in this module draws random numbers itself, so
# it mainly affects other torch RNG users in the same process.
torch.manual_seed(seed=0)
class UMCM_al:
    """Iteratively update a square matrix X while monitoring X^T J X against J.

    J is the diagonal signature matrix diag(I_p, -I_{d-p}) built in
    ``__init__``. ``train`` repeatedly calls ``step_forward`` and records,
    per iteration, the penalised objective
    ``fobj_val(X, C) + Jerrindex * ||X^T J X - J||_F`` and the elapsed
    wall-clock time.

    NOTE(review): the matrix products in ``step_forward`` require X and C to
    be square (d, d). ``fobj_val`` and ``Jtest`` are project helpers whose
    exact semantics are not visible here (``Jtest`` presumably measures
    deviation from J-orthogonality; confirm).
    """

    def __init__(self, X, config_yaml):
        """Read solver settings from ``config_yaml`` and initialise the iterate.

        Args:
            X: initial iterate. It is cloned, so the caller's tensor is neither
                mutated nor aliased.
            config_yaml: nested mapping providing ``run.maxiter``,
                ``run.stop_t``, ``run.stop_lr``, ``run.Jerrindex``,
                ``datafeature.d`` and ``datafeature.p``.
        """
        self.X = X.clone()
        self.maxiter = int(float(config_yaml["run"]["maxiter"]))
        self.d = config_yaml["datafeature"]["d"]
        self.p = config_yaml["datafeature"]["p"]
        # float() for the same reason maxiter/stop_lr use it: PyYAML loads
        # exponent literals without a decimal point (e.g. ``1e3``) as strings,
        # which would make the time comparison / penalty product in ``train``
        # raise TypeError.
        self.stopt = float(config_yaml["run"]["stop_t"])
        self.stoplr = int(float(config_yaml["run"]["stop_lr"]))  # not referenced in this class
        self.Jerrindex = float(config_yaml["run"]["Jerrindex"])

        # Signature matrix J = diag(I_p, -I_{d-p}), forced to float32.
        J = torch.eye(self.d)
        J[self.p:, self.p:] = -1 * torch.eye(self.d - self.p)
        self.J = J.to(torch.float32)

        self.stepsize = 1e-3  # initial step size; halved during ``train``
        # myeps and rho are not referenced in this class.
        self.myeps = 1e-8
        self.rho = 1e-3
        # Keep a copy, not an alias of the caller's tensor.
        self.X_old = self.X.clone()

        # Baseline Jtest value; step_forward rejects steps that exceed 100x it.
        self.Jerr0 = Jtest(self.X, self.p)

    def step_forward(self, C):
        """Take one explicit step on X, reverting it if the Jtest check fails.

        With ``G = C @ X`` the step direction is ``J X^T G - G^T X J``, which
        is skew-symmetric because J is symmetric. The updated iterate is
        clamped element-wise to [-100, 100]. If ``Jtest`` on the new iterate
        exceeds ``100 * Jerr0`` or ``1e-4``, the previous iterate is restored.

        Args:
            C: matrix multiplied on the left of X (square (d, d) for the
                shapes below to agree).
        """
        self.X_old = self.X.clone().detach()
        G = C @ self.X
        # NOTE(review): an earlier version added ``2 * ||X - X_old||_F`` here.
        # X_old was set to a copy of X just above, so that term was always
        # exactly zero and has been dropped. If a proximal term against the
        # *previous* iterate was intended, X_old must be captured before it is
        # overwritten (and the gradient of ||X - X_prev||_F^2 is the matrix
        # 2 * (X - X_prev), not a scalar norm).
        deltaL = self.J @ self.X.t() @ G - G.t() @ self.X @ self.J
        self.X = torch.clamp(self.X_old - deltaL * self.stepsize, -100, 100)

        # Evaluate the (presumably pure) check once instead of twice.
        jerr = Jtest(self.X, self.p)
        if jerr > self.Jerr0 * 100 or jerr > 1e-4:
            self.X = self.X_old

    def train(self, C):
        """Run up to ``maxiter`` steps or until ``stop_t`` seconds have elapsed.

        Before each step, the penalised objective of the current iterate is
        recorded. The step size is halved after every even-indexed iteration
        while it is above 1e-8. This decay persists on the instance across
        ``train`` calls.

        Args:
            C: matrix passed to ``fobj_val`` and ``step_forward``.

        Returns:
            A tuple ``(hist_Obj, X, hist_t)``:
            - ``hist_Obj``: 1-D tensor with one objective value per completed
              iteration.
            - ``X``: the final iterate.
            - ``hist_t``: ``(n, 1)`` tensor of cumulative elapsed seconds,
              aligned with ``hist_Obj``.
        """
        hist_Obj = torch.zeros(self.maxiter, 1)
        hist_t = torch.zeros(self.maxiter, 1)
        start_time = time.time()
        n_done = 0  # number of iterations whose history rows were written

        for it in range(self.maxiter):
            Jerr = self.Jerrindex * torch.norm(self.X.t() @ self.J @ self.X - self.J, 'fro')
            hist_Obj[it] = fobj_val(self.X, C) + Jerr

            self.step_forward(C)
            elapsed = time.time() - start_time
            hist_t[it] = elapsed
            n_done = it + 1
            if elapsed > self.stopt:
                break

            if it % 2 == 0 and self.stepsize > 1e-8:
                self.stepsize = self.stepsize / 2

        # Trim by iteration count, not by masking zeros: a legitimate objective
        # value of exactly 0 must be kept, and hist_t must stay aligned.
        hist_Obj = hist_Obj[:n_done].reshape(-1)
        hist_t = hist_t[:n_done]
        return hist_Obj, self.X, hist_t


