from utils.fobj_val import fobj_val
import torch
import torch.optim as optim
from utils.initial_x import getU,getV,getH
from utils.Jorthogonal_test import Jtest
import time

class CSDM_al:
    """Adam-based optimiser for X = U @ H @ V parameterised by (theta, mu, ksi).

    The factors U, H and V are produced from the trainable tensors ``theta``,
    ``mu`` and ``ksi`` by the project helpers ``getU``, ``getH`` and ``getV``.
    Each optimisation step minimises ``fobj_val(X, C)``; the recorded history
    additionally adds the weighted J-orthogonality error
    ``Jerrindex * ||X^T J X - J||_F`` with ``J = diag(I_p, -I_{d-p})``.
    """

    def __init__(self, theta, mu, ksi, config_yaml):
        """Copy the initial parameters, read the config and build the optimiser.

        Args:
            theta, mu, ksi: initial parameter tensors. They are cloned and
                detached, so the caller's tensors are never modified.
            config_yaml: nested mapping; reads ``run.maxiter``, ``run.stop_t``,
                ``run.stop_lr``, ``run.Jerrindex``, ``datafeature.d``,
                ``datafeature.p`` and ``CS.lr``.
        """
        self.theta = theta.clone().detach().requires_grad_(True)
        self.mu = mu.clone().detach().requires_grad_(True)
        self.ksi = ksi.clone().detach().requires_grad_(True)

        # float() first so values written in scientific notation (e.g. "1e4")
        # are accepted before conversion to int.
        self.maxiter = int(float(config_yaml["run"]["maxiter"]))
        self.d = config_yaml["datafeature"]["d"]
        self.p = config_yaml["datafeature"]["p"]
        # Wall-clock budget in seconds, compared against elapsed time in train().
        self.stopt = config_yaml["run"]["stop_t"]
        # Not read anywhere in this class.
        self.stoplr = int(float(config_yaml["run"]["stop_lr"]))
        # Weight of the J-orthogonality error in the recorded history.
        self.Jerrindex = config_yaml["run"]["Jerrindex"]

        # J = diag(I_p, -I_{d-p}).
        J = torch.eye(self.d)
        J[self.p:, self.p:] = -1 * torch.eye(self.d - self.p)
        self.J = J.to(torch.float32)
        self.U = getU(self.d, self.p, self.theta)
        # NOTE(review): mu is used unclamped here, whereas _refresh_factors()
        # clamps it to [-1, 1] after every step -- confirm this is intended.
        self.H = getH(self.d, self.p, self.mu)
        self.V = getV(self.d, self.p, self.ksi)
        self.U.retain_grad()
        self.V.retain_grad()
        self.H.retain_grad()
        self.X = self._build_X()
        self.optimizer = optim.Adam(
            [self.theta, self.mu, self.ksi], lr=float(config_yaml["CS"]["lr"])
        )

    def _build_X(self):
        """Return ``U @ (H @ V)`` from the current factors, cast to float32."""
        return torch.mm(self.U, torch.mm(self.H, self.V)).to(torch.float32)

    def _refresh_factors(self):
        """Rebuild U, H, V from the current parameters (mu clamped to [-1, 1])."""
        self.U = getU(self.d, self.p, self.theta)
        self.H = getH(self.d, self.p, torch.clamp(self.mu, -1, 1))
        self.V = getV(self.d, self.p, self.ksi)

    def step_forward(self, C):
        """Take one Adam step on ``fobj_val(X, C)`` using the current factors.

        X is rebuilt from ``self.U``, ``self.H`` and ``self.V`` so the backward
        pass runs on a fresh graph, and its gradient is retained on ``self.X``.
        The factors themselves are not refreshed here; ``train`` does that
        after each call.
        """
        self.X = self._build_X()
        self.X.retain_grad()
        fval = fobj_val(self.X, C)
        fval.backward()

        self.optimizer.step()
        self.optimizer.zero_grad()

    def train(self, C):
        """Run Adam until ``maxiter`` iterations or the ``stop_t`` time budget.

        Each iteration first logs the objective at the current parameters and
        the elapsed time, then takes one step. The iteration whose elapsed time
        exceeds ``stop_t`` is still logged, but no step is taken for it.

        Returns:
            (hist_CS, X, hist_t):
            - ``hist_CS`` is a 1-D tensor holding
              ``fobj_val(X, C) + Jerrindex * ||X^T J X - J||_F`` for each
              logged iteration.
            - ``X`` is ``U @ H @ V`` at the final parameters.
            - ``hist_t`` is an ``(n, 1)`` tensor of elapsed seconds.
        """
        hist_CS = torch.zeros(self.maxiter, 1)
        hist_t = torch.zeros(self.maxiter, 1)
        n_logged = 0
        start_time = time.time()
        for it in range(self.maxiter):
            # Log the *current* iterate (not the X built before the previous
            # step), under no_grad so the history buffer does not chain an
            # autograd graph across iterations.
            with torch.no_grad():
                X_cur = self._build_X()
                Jerr = self.Jerrindex * torch.norm(
                    X_cur.t() @ self.J @ X_cur - self.J, 'fro'
                )
                hist_CS[it] = fobj_val(X_cur, C) + Jerr
            hist_t[it] = time.time() - start_time
            n_logged = it + 1
            if hist_t[it] > self.stopt:
                break
            self.step_forward(C)
            self._refresh_factors()

        # Make the returned X reflect the final parameters.
        self.X = self._build_X()
        # Truncate by count rather than by filtering zeros, which would also
        # drop genuine zero objective values and misalign hist_t.
        return hist_CS[:n_logged].reshape(-1), self.X, hist_t[:n_logged]


