import torch
from torch import nn
from tqdm import tqdm
import time

from utils.Jorthogonal_test import Jtest
from algorithm.JOBCD.Parallel_updateX import Parallel_updateV
from UHstruct.fobj_val import UltraE_fval_obj
from utils.get_dist import get_dist

class VRJOBCD(nn.Module):
    """Block-coordinate solver with an optionally variance-reduced gradient.

    The module owns a matrix parameter ``Q`` and repeatedly
      1. evaluates a gradient of ``UltraE_fval_obj`` with respect to ``Q``
         (see :meth:`fgrad`), and
      2. passes ``Q``, that gradient and a random pairing of the ``d``
         coordinate indices to ``Parallel_updateV`` (see :meth:`innertrain`).

    Config keys read by ``__init__``:
      * ``device``
      * ``run``: ``maxiter``, ``Cmaxiter``, ``stop_t``, ``stop_lr``
      * ``datafeature``: ``d``, ``p``, ``N``
      * ``JOBCD``: ``Lconstant``, ``theta``

    ``d`` must be even, because the index permutation is reshaped into
    ``d / 2`` pairs in :meth:`innertrain`.
    """

    def __init__(self, T, H, Q, P, config_yaml):
        """Move the inputs to the configured device and read hyper-parameters.

        Args:
            T: tensor forwarded as first argument to ``UltraE_fval_obj``.
            H: tensor forwarded as second argument to ``UltraE_fval_obj``;
                its rows are sub-sampled (indices below ``N``) in the
                mini-batch branch of :meth:`fgrad`.
            Q: initial value of the optimised parameter.
            P: tensor forwarded as third argument to ``UltraE_fval_obj``.
            config_yaml: nested mapping with the keys listed on the class.
        """
        super().__init__()
        self.device = config_yaml["device"]
        self.T = T.to(self.device)
        self.H = H.to(self.device)
        self.Q = nn.Parameter(data=Q.to(self.device))
        self.P = P.to(self.device)
        # Snapshot of the previous iterate. It is cloned so that it never
        # shares storage with ``self.Q`` (``Q.to(device)`` returns the very
        # same tensor when Q already lives on ``device``).
        self.Q_VRold = nn.Parameter(data=Q.detach().to(self.device).clone())

        self.config_yaml = config_yaml
        # ``int(float(...))`` accepts values given as strings such as "1e3".
        self.maxiter = int(float(config_yaml["run"]["maxiter"]))
        # Cmaxiter and stoplr are stored but not used by the methods below.
        self.Cmaxiter = int(float(config_yaml["run"]["Cmaxiter"]))

        self.d = config_yaml["datafeature"]["d"]
        self.p = config_yaml["datafeature"]["p"]
        # Coerced like the other numeric settings so that a string value
        # ("1e3") can be compared with the elapsed time in Train().
        self.stopt = float(config_yaml["run"]["stop_t"])
        self.stoplr = float(config_yaml["run"]["stop_lr"])
        self.Lconstant = float(config_yaml["JOBCD"]["Lconstant"])
        # NOTE(review): int() truncates, so a fractional theta (e.g. 1e-5)
        # becomes 0 here -- confirm that an integer theta is intended.
        self.theta = torch.tensor(int(float(config_yaml["JOBCD"]["theta"])))
        self.b = config_yaml["datafeature"]["N"]
        # Mini-batch size: floor(sqrt(N)).
        self.bb = int(self.b**0.5)
        # Probability of taking the full-gradient branch in fgrad().
        self.gradp = self.bb / (self.b + self.bb)
        # Train() stops as soon as this flag equals 1; it is never reset.
        self.stopindex = 0

        # Diagonal signature matrix: +1 on the first p entries, -1 on the
        # remaining d - p. Not referenced by the methods of this class.
        J = torch.eye(self.d)
        J[self.p:, self.p:] = -1 * torch.eye(self.d - self.p)
        self.J = J.to(torch.float32).to(self.device)
        self.Xgrad = None
        # Last gradient estimate returned by fgrad(); None forces fgrad() to
        # take the full-gradient branch.
        self.grad_old = None

    def Train(self):
        """Run up to ``maxiter`` outer iterations.

        Iteration stops early once the elapsed wall-clock time exceeds
        ``stop_t`` (the iteration that crosses the limit is still recorded).

        Returns:
            Tuple ``(hist_Obj, Q, hist_t, hist_err)`` where, with ``n`` the
            number of iterations actually performed:
              * ``hist_Obj``: 1-D tensor (n,), objective *before* each update;
              * ``Q``: the current parameter;
              * ``hist_t``: tensor (n, 1), elapsed seconds after each update;
              * ``hist_err``: CPU tensor (n, 1), value of ``Jtest(Q, p)``
                after each update.
        """
        hist_Obj = torch.zeros(self.maxiter, 1).to(self.device)
        hist_t = torch.zeros(self.maxiter, 1).to(self.device)
        hist_err = torch.zeros(self.maxiter, 1)

        # Number of iterations whose history entries have been written.
        # Counting explicitly (instead of filtering ``hist_Obj != 0``) keeps
        # the three histories aligned even if an objective value is exactly 0.
        n_done = 0

        start_time = time.time()
        for it in tqdm(range(self.maxiter), desc='Outiter'):
            with torch.no_grad():
                hist_Obj[it] = UltraE_fval_obj(
                    self.T, self.H, self.P, self.Q, self.config_yaml
                )

            self.innertrain()

            with torch.no_grad():
                elapsed = time.time() - start_time
                hist_t[it] = elapsed
                hist_err[it] = Jtest(self.Q, self.p)
            n_done = it + 1

            if elapsed > self.stopt:
                self.stopindex = 1
            if self.stopindex == 1:
                break

            # NOTE(review): clearing grad_old after every outer iteration
            # means fgrad() always finds it None on the next call and takes
            # the full-gradient branch, so the variance-reduced branch is
            # never reached. Kept to preserve current numerical behaviour;
            # confirm whether this reset is intended.
            self.grad_old = None

        hist_Obj = hist_Obj[:n_done, 0]
        hist_t = hist_t[:n_done]
        hist_err = hist_err[:n_done]
        return hist_Obj, self.Q, hist_t, hist_err

    def innertrain(self):
        """Perform one update of ``Q``.

        Computes a gradient estimate, draws a random partition of the ``d``
        indices into pairs, stores the current ``Q`` in ``Q_VRold`` and
        replaces ``Q`` by the result of ``Parallel_updateV``.
        """
        self.Xgrad = self.fgrad()

        # Random pairing of the d indices, one pair per row (needs even d).
        # The two randperm draws are kept so the RNG stream is unchanged.
        original_vector = torch.randperm(self.d)
        B = original_vector[torch.randperm(self.d)].view(-1, 2)

        with torch.no_grad():
            # Copy by value: aliasing (``self.Q_VRold = self.Q``) would make
            # the snapshot follow any in-place modification of Q.
            self.Q_VRold.copy_(self.Q)

            gradX = self.Xgrad.to(torch.float32)
            new_Q = Parallel_updateV(
                self.Q, gradX, B, self.Lconstant, self.theta, self.p
            )

        # nn.Module refuses to bind a plain tensor to a registered parameter
        # name, so wrap the result when the updater returns one.
        if not isinstance(new_Q, nn.Parameter):
            new_Q = nn.Parameter(new_Q)
        self.Q = new_Q

    def fgrad(self):
        """Return a gradient estimate of the objective with respect to ``Q``.

        With probability ``gradp``, or whenever no previous estimate exists,
        the gradient is computed on the full ``(T, H)`` data. Otherwise a
        mini-batch of ``bb`` rows of ``H`` is drawn and the estimate is

            grad_old + (grad_batch(Q) - grad_batch(Q_VRold)).

        NaN and +/-Inf entries of every back-propagated gradient are set to 0.
        The result is also stored in ``self.grad_old``.
        """
        # The random draw comes first so it is consumed on every call.
        take_full = bool(torch.rand(1) < self.gradp)
        if take_full or getattr(self, "grad_old", None) is None:
            grad_R = self._objective_grad(self.T, self.H, self.Q)
        else:
            sampled_numbers = torch.randperm(self.b)[:self.bb]
            H = self.H[sampled_numbers, :]
            T = get_dist(H).to(torch.float32)

            grad_new = self._objective_grad(T, H, self.Q)
            grad_prev = self._objective_grad(T, H, self.Q_VRold)
            grad_R = self.grad_old + (grad_new - grad_prev)

        self.grad_old = grad_R
        return grad_R

    def _objective_grad(self, T, H, Q):
        """Back-propagate ``UltraE_fval_obj(T, H, P, Q, cfg)`` into ``Q``.

        Returns the sanitised gradient and leaves ``Q.grad`` set to None so
        that nothing accumulates between calls.
        """
        fval = UltraE_fval_obj(T, H, self.P, Q, self.config_yaml)
        fval.requires_grad_(True)
        fval.backward()

        grad = Q.grad
        if grad is None:
            raise RuntimeError(
                "backward() produced no gradient for Q; the objective does "
                "not depend on it through autograd"
            )
        Q.grad = None
        return self._zero_nonfinite_(grad)

    @staticmethod
    def _zero_nonfinite_(grad):
        """Replace NaN and +/-Inf entries of ``grad`` by 0 in place."""
        with torch.no_grad():
            grad.masked_fill_(~torch.isfinite(grad), 0)
        return grad


