import torch
import time
import random

from utils.fobj_val import fobj_val
from algorithm.JOBCD.nonconvex_orth2d_quad_notsame import nonconvex_orth2d_quad_notsame
from algorithm.JOBCD.nonconvex_orth2d_quad_L0 import nonconvex_orth2d_quad_L0
from utils.Jorthogonal_test import Jtest

class GS_JOBCD:
    """Row-pair block coordinate descent on ``X`` with signature matrix ``J``.

    ``J = diag(I_p, -I_{d-p})`` is built in ``__init__``. Each inner step of
    ``step_forward`` selects two row indices ``B``, obtains a 2x2 matrix ``V``
    from one of the imported subproblem solvers, and overwrites ``X[B, :]``
    with ``V @ X[B, :]``. ``train`` repeats this and records the objective
    (``fobj_val``), the elapsed wall-clock time and the ``Jtest`` error.

    NOTE(review): the class/solver names and ``Jtest`` suggest the updates are
    meant to keep ``X`` J-orthogonal; that property is delegated to the
    imported solvers and is not visible here.
    """

    def __init__(self, X, config_yaml):
        """Store a float32 copy of the starting point and read the run settings.

        Args:
            X: starting iterate; presumably has ``d`` rows with ``d`` equal to
                ``config_yaml["datafeature"]["d"]`` (not checked here).
            config_yaml: nested mapping providing ``run.maxiter``,
                ``run.stop_t``, ``run.Jerrindex``, ``datafeature.d``,
                ``datafeature.p``, ``JOBCD.inneriter`` and ``JOBCD.theta``.
                Numeric entries may be numbers or strings such as ``"1e4"``.
        """
        # .to() returns X itself when it is already float32, and step_forward
        # writes into self.X in place; copy so the caller's tensor (e.g. a
        # starting point shared with other solvers) is not overwritten.
        self.X = X.to(torch.float32).clone()
        # int(float(...)) accepts both numbers and strings like "1e4".
        self.maxiter = int(float(config_yaml["run"]["maxiter"]))
        self.d = config_yaml["datafeature"]["d"]
        self.p = config_yaml["datafeature"]["p"]
        # Compared against elapsed seconds in train(); convert so that a string
        # value such as "1e3" does not break the comparison.
        self.stopt = float(config_yaml["run"]["stop_t"])
        self.inneriter = int(float(config_yaml["JOBCD"]["inneriter"]))
        # Keep the fractional part: int() would silently turn e.g. 1e-3 into 0.
        self.theta = torch.tensor(
            float(config_yaml["JOBCD"]["theta"]), dtype=torch.float32
        )
        self.Jerrindex = config_yaml["run"]["Jerrindex"]

        # J = diag(I_p, -I_{d-p})
        J = torch.eye(self.d)
        J[self.p:, self.p:] = -1 * torch.eye(self.d - self.p)
        self.J = J.to(torch.float32)

    def step_forward(self, C):
        """Apply ``self.inneriter`` two-row updates to ``self.X`` in place.

        Args:
            C: matrix supplying the gradient rows ``(C @ X)[B, :]`` and the
                2x2 block ``C[B][:, B]``; must share ``self.X``'s dtype
                (float32) for ``torch.mm``.
        """
        n = self.d
        # The first block is rows 1 and 2 (kept so existing runs reproduce);
        # row 2 does not exist when d == 2, so use rows 0 and 1 there.
        B = torch.tensor([1, 2] if n > 2 else [0, 1])
        hessian = C.data
        theta = self.theta
        sqrt_theta = torch.sqrt(theta)
        eye2 = torch.eye(2)
        eye4 = torch.eye(4)

        for _ in range(self.inneriter):
            # Diagonal entries of J for the selected rows (each +1 or -1).
            j_signs = self.J[B, B]
            Z = self.X[B, :]
            ZZ = torch.mm(Z, Z.T)
            # Only rows B of C @ X are used, so compute just those rows.
            grad_B = torch.mm(C[B, :], self.X)
            hess_BB = hessian[B][:, B]

            P = torch.mm(grad_B, Z.T) - torch.mm(hess_BB, ZZ) - theta * eye2
            Q = torch.kron(ZZ, hess_BB) + eye4 * sqrt_theta

            if torch.sum(j_signs) == 0:
                # One +1 row and one -1 row: mixed-signature subproblem.
                V = nonconvex_orth2d_quad_notsame(Q, P)
                # NOTE(review): magic threshold - a V whose mean absolute entry
                # exceeds 0.56 is discarded in favour of the identity (no
                # update); the rationale is not visible here.
                if torch.mean(abs(V)) > 0.56:
                    V = torch.eye(2).to(torch.float32)
            else:
                # Both selected rows have the same J sign.
                V = nonconvex_orth2d_quad_L0(Q, P)
            self.X[B, :] = torch.mm(V, Z)

            # Later blocks: two distinct row indices drawn with random.sample.
            B = torch.tensor(random.sample(range(n), 2))

    def train(self, C):
        """Run outer iterations until ``maxiter`` or the time budget is hit.

        Before each call to ``step_forward`` the objective ``fobj_val(X, C)``,
        the elapsed time and ``Jtest(X, p)`` are recorded. The time check
        happens after the step, so one more step runs once the recorded time
        exceeds ``stop_t``.

        Args:
            C: matrix passed to ``fobj_val`` and ``step_forward``.

        Returns:
            Tuple ``(hist_obj, X, hist_t, hist_err)``: ``hist_obj`` is 1-D with
            one entry per recorded iteration, ``X`` is the final iterate, and
            ``hist_t`` / ``hist_err`` are ``(k, 1)`` with the same ``k``.
        """
        hist_JOBCD = torch.zeros(self.maxiter, 1)
        hist_t = torch.zeros(self.maxiter, 1)
        hist_err = torch.zeros(self.maxiter, 1)

        # Count recorded rows explicitly: filtering on "!= 0" would drop a
        # genuine zero objective and misalign the three histories.
        n_recorded = 0
        start_time = time.time()
        for it in range(self.maxiter):
            hist_JOBCD[it] = fobj_val(self.X, C)
            hist_t[it] = time.time() - start_time
            hist_err[it] = Jtest(self.X, self.p)
            n_recorded = it + 1

            self.step_forward(C)
            if hist_t[it] > self.stopt:
                break

        hist_JOBCD = hist_JOBCD[:n_recorded].reshape(-1)
        hist_t = hist_t[:n_recorded]
        hist_err = hist_err[:n_recorded]
        return hist_JOBCD, self.X, hist_t, hist_err


