import torch
import time

from utils.fobj_val import fobj_val
from algorithm.JOBCD.Parallel_updateX import Parallel_updateV

class JJOBCD:
    """Randomised pairwise block-coordinate solver driven by ``Parallel_updateV``.

    Each iteration randomly pairs the ``d`` coordinate indices and hands the
    pairs to ``Parallel_updateV`` to produce the next ``X``. Progress is
    recorded as ``fobj_val(X, C) + Jerrindex * ||X^T J X - J||_F`` with
    ``J = diag(I_p, -I_{d-p})``.

    Config keys read from ``config_yaml``:
        run.maxiter      -- iteration cap (parsed via float, then int).
        run.stop_t       -- wall-clock budget in seconds.
        run.Jerrindex    -- weight of the J-constraint violation in the history.
        datafeature.d    -- size of J.
        datafeature.p    -- number of +1 entries on J's diagonal.
        JOBCD.theta      -- forwarded to Parallel_updateV.
    """

    def __init__(self, X, config_yaml):
        self.X = X.to(torch.float32)
        self.maxiter = int(float(config_yaml["run"]["maxiter"]))
        self.d = config_yaml["datafeature"]["d"]
        self.p = config_yaml["datafeature"]["p"]
        self.stopt = config_yaml["run"]["stop_t"]
        # NOTE(review): int() truncates theta, so a fractional value such as
        # 1e-5 becomes 0. Kept for compatibility; confirm this is intended.
        self.theta = torch.tensor(int(float(config_yaml["JOBCD"]["theta"])))
        self.Jerrindex = config_yaml["run"]["Jerrindex"]

        # J = diag(I_p, -I_{d-p}); train() measures how far X is from
        # satisfying X^T J X = J.
        J = torch.eye(self.d)
        J[self.p:, self.p:] = -1 * torch.eye(self.d - self.p)
        self.J = J.to(torch.float32)

    def step_forward(self, C, iter):
        """Replace ``self.X`` with the result of one ``Parallel_updateV`` call.

        Args:
            C: matrix applied on the left of X to form ``gradX = C @ X``.
                ``torch.mm`` requires it to share X's dtype (float32).
            iter: current iteration index; not used here, kept for the call
                signature used by ``train``.
        """
        gradX = torch.mm(C, self.X).to(torch.float32)
        theta = self.theta
        # Max absolute entry of C, passed as Lconst (presumably a
        # Lipschitz-type bound inside Parallel_updateV -- confirm there).
        Lconst = torch.max(abs(C))

        # Randomly split the d indices into d/2 pairs. view(-1, 2) raises if d
        # is odd. The second randperm is statistically redundant but is kept
        # so the RNG stream (and seeded results) stay unchanged.
        original_vector = torch.randperm(self.d)
        B = original_vector[torch.randperm(self.d)].view(-1, 2)

        self.X = Parallel_updateV(self.X, gradX, B, Lconst, theta, self.p)

    def train(self, C):
        """Run up to ``maxiter`` iterations or until ``stop_t`` seconds elapse.

        Before each step, the objective ``fobj_val(X, C)`` plus the weighted
        J-constraint error and the elapsed time are recorded.

        Returns:
            (hist, X, hist_t): ``hist`` is a 1-D tensor with one entry per
            completed iteration, ``X`` is the final iterate, and ``hist_t`` is
            an ``(n, 1)`` tensor of the matching elapsed times in seconds.
        """
        hist_JOBCD = torch.zeros(self.maxiter, 1)
        hist_t = torch.zeros(self.maxiter, 1)
        n_done = 0
        start_time = time.time()
        for it in range(self.maxiter):
            # X^T J X - J requires X to have d rows; X is expected to be d x d.
            Jerr = self.Jerrindex * torch.norm(self.X.t() @ self.J @ self.X - self.J, 'fro')
            hist_JOBCD[it] = fobj_val(self.X, C) + Jerr
            hist_t[it] = time.time() - start_time
            n_done = it + 1
            self.step_forward(C, it)
            # The timestamp was taken before this step, so the budget check
            # lags by one step.
            if hist_t[it] > self.stopt:
                break
        # Trim by the number of recorded iterations instead of filtering out
        # zero values: an objective of exactly 0 must not be dropped, and
        # hist_t must stay aligned with the objective history.
        return hist_JOBCD[:n_done, 0], self.X, hist_t[:n_done]





