import torch
import torch.optim as optim
from torch import nn
from torch.nn.utils import clip_grad_norm_
from tqdm import tqdm
import time

from algorithm.JOBCD.Parallel_updateX import Parallel_updateV
from UHstruct.fobj_val import UltraE_fval_obj
from UHstruct.get_hit_ranks import UltraE_test_hits_rank
from dataset.gettriple import Corrupt

class VRJJOBCD(nn.Module):
    """Alternating trainer: Adam on entity/bias embeddings, block-coordinate steps on relations.

    Each inner iteration (``fgrad`` followed by ``innertrain``):

    1. takes one Adam step on ``vec_entity`` and ``vec_bias`` with the relations held fixed;
    2. builds a relation-gradient estimate. With probability ``gradp``, or when no previous
       estimate exists, it uses the full-batch gradient. Otherwise it uses the recursive
       correction ``grad_old + g_B(vec_relation) - g_B(vec_relation_VRold)`` on a random
       minibatch of ``bb = int(sqrt(b))`` triples, which resembles the PAGE variance-reduced
       estimator;
    3. copies the current relations into ``vec_relation_VRold``, then updates every relation
       matrix with ``Parallel_updateV`` over a random pairing of its ``d`` coordinates.

    Args:
        triple: training triples; ``triple.shape[0]`` is the dataset size ``b``.
        ttest: test triples, forwarded to ``UltraE_test_hits_rank``.
        vec_entity: initial entity embeddings (wrapped as a Parameter; storage is shared).
        vec_relation: initial relation matrices, indexed as ``[relation, :, :]``.
        vec_bias: initial biases (wrapped as a Parameter; storage is shared).
        config_yaml: nested config dict with ``device``, ``run``, ``datafeature`` and
            ``JOBCD`` sections.
    """

    def __init__(self, triple, ttest, vec_entity, vec_relation, vec_bias, config_yaml):
        super().__init__()
        self.device = config_yaml["device"]
        self.ttest = ttest
        self.config_yaml = config_yaml
        run_cfg = config_yaml["run"]
        data_cfg = config_yaml["datafeature"]
        jobcd_cfg = config_yaml["JOBCD"]
        # Counts go through float() first so values written like "1e3" are accepted.
        self.maxiter = int(float(run_cfg["maxiter"]))
        self.Rmaxiter = int(float(run_cfg["Rmaxiter"]))
        self.Cmaxiter = int(float(run_cfg["Cmaxiter"]))
        self.d = data_cfg["d"]
        self.p = data_cfg["p"]
        self.Lconstant = float(jobcd_cfg["Lconstant"])
        self.theta = torch.tensor(int(float(jobcd_cfg["theta"])))
        self.margin = data_cfg["margin"]
        self.knum = data_cfg["knum"]
        # Dataset size b, minibatch size bb = floor(sqrt(b)), and the probability
        # that fgrad() recomputes the full-batch relation gradient.
        self.b = triple.shape[0]
        self.bb = int(self.b**0.5)
        self.gradp = self.bb / (self.b + self.bb)
        self.stopindex = 0

        self.triple = triple
        self.Corrupted_triple = None
        self.vec_entity = nn.Parameter(data=vec_entity)
        self.vec_relation = nn.Parameter(data=vec_relation)
        # Snapshot of the relations taken before the latest update. It must own its
        # storage: if it shared storage with vec_relation, the in-place relation updates
        # would overwrite it too. The correction grad(vec_relation) - grad(vec_relation_VRold)
        # would then be identically zero.
        self.vec_relation_VRold = nn.Parameter(data=vec_relation.detach().clone())
        self.vec_bias = nn.Parameter(data=vec_bias)
        self.Rnum = vec_relation.shape[0]
        self.entity_num = self.vec_entity.shape[0]

        # J = diag(I_p, -I_{d-p}).
        J = torch.eye(self.d)
        J[self.p:, self.p:] = -1 * torch.eye(self.d - self.p)
        self.J = J.to(torch.float32).to(self.device)
        self.lambdaa = 0
        self.myeps = 1e-8
        self.Xgrad = None
        # Previous relation-gradient estimate; None forces a full-gradient step.
        self.grad_old = None
        self.optimizer = optim.Adam([self.vec_entity, self.vec_bias], lr=float(run_cfg["lr_eb"]),
                                    weight_decay=float(run_cfg["weight_decay"]))

    def Train(self):
        """Run the outer training loop and record per-iteration metrics.

        Each outer iteration draws fresh corrupted triples and runs ``Cmaxiter`` inner
        iterations of ``fgrad`` + ``innertrain``. It then evaluates the objective and
        the test Hits/MRR without autograd.

        Returns:
            tuple ``(hist_CObj, vec_relation, hist_t, hist_hits, hist_MRR)``. Here
            ``hist_CObj`` is the running sum of the per-iteration objective, with zero
            entries dropped. ``hist_t`` is elapsed wall-clock seconds, ``hist_hits`` is
            an ``(n, 3)`` table from ``UltraE_test_hits_rank``, and ``hist_MRR`` is the
            MRR history.
        """
        hist_Obj = torch.zeros(self.maxiter, 1).to(self.device)
        hist_CObj = torch.zeros(self.maxiter, 1).to(self.device)
        hist_t = torch.zeros(self.maxiter, 1).to(self.device)
        hist_hits = torch.zeros(self.maxiter, 3).to(self.device)
        hist_MRR = torch.zeros(self.maxiter, 1).to(self.device)
        start_time = time.time()
        for it in tqdm(range(self.maxiter), desc='Outiter'):
            self.Corrupted_triple = Corrupt(self.triple, self.entity_num, self.knum)

            for Citer in tqdm(range(self.Cmaxiter), position=0):
                if Citer % self.config_yaml['run']['inner_dispgap'] == 0:
                    with torch.no_grad():
                        temp = UltraE_fval_obj(self.triple, self.Corrupted_triple, self.vec_entity,
                                               self.vec_relation, self.vec_bias, self.config_yaml)
                        print('VR-J-JOBCD  Before R iter:{},citer:{}/{}, obj:{:.4f}'.format(it, Citer, self.Cmaxiter, temp))

                self.Xgrad = self.fgrad()
                self.innertrain(Citer)

            # Evaluate without autograd. Writing a grad-tracking value into hist_Obj would
            # attach each iteration's computation graph to the history tensors and keep
            # every one of those graphs alive.
            with torch.no_grad():
                hist_Obj[it] = UltraE_fval_obj(self.triple, self.Corrupted_triple,
                                               self.vec_entity, self.vec_relation, self.vec_bias,
                                               self.config_yaml)
                if it == 0:
                    hist_CObj[it] = hist_Obj[it]
                else:
                    hist_CObj[it] = hist_CObj[it - 1] + hist_Obj[it]

                hist_t[it] = time.time() - start_time
                hist_hits[it, :], hist_MRR[it] = UltraE_test_hits_rank(self.ttest, self.vec_entity,
                                                                       self.vec_relation, self.vec_bias,
                                                                       self.config_yaml)
                print('VR-J-JOBCD iter:{}, csumfval:{:.2f}, hist hits:{:.4f}-{:.4f}-{:.4f}, hist MRR:{:.4f}'.format(
                    it, hist_CObj[it][0].item(), hist_hits[it, 0].item(), hist_hits[it, 1].item(),
                    hist_hits[it, 2].item(), hist_MRR[it][0].item()))
            # The next outer iteration uses new corrupted triples, so the stored estimate
            # no longer matches the objective; force a full-gradient step.
            self.grad_old = None

        # NOTE(review): entries that are exactly zero are treated as "not reached" and dropped.
        hist_Obj = hist_Obj[hist_Obj != 0]
        hist_CObj = hist_CObj[hist_CObj != 0]
        hist_t = hist_t[:len(hist_Obj)]
        hist_hits = hist_hits[:len(hist_Obj), :]
        hist_MRR = hist_MRR[:len(hist_Obj)]
        return hist_CObj, self.vec_relation, hist_t, hist_hits, hist_MRR

    def innertrain(self, Citer):
        """Snapshot the relations, then update each relation matrix using ``self.Xgrad``.

        For every relation, a random pairing of the ``d`` coordinates (``d // 2`` index
        pairs) is drawn. It is passed to ``Parallel_updateV`` together with that
        relation's gradient estimate. ``Citer`` is accepted for interface compatibility
        and is unused.
        """
        # NOTE(review): view(-1, 2) requires d to be even.
        allB = torch.zeros([self.Rnum, self.d // 2, 2]).to(self.device).to(torch.int64)
        for i in range(self.Rnum):
            perm = torch.randperm(self.d)
            perm = perm[torch.randperm(self.d)]
            allB[i, :, :] = perm.view(-1, 2)

        # Remember the pre-update point for the gradient-difference term in fgrad().
        self._snapshot_relation()

        for i in range(self.Rnum):
            with torch.no_grad():
                self.vec_relation[i, :, :] = Parallel_updateV(
                    self.vec_relation[i, :, :].to(torch.float32).to(self.device),
                    self.Xgrad[i, :, :].to(torch.float32), allB[i, :, :], self.Lconstant, self.theta, self.p)

    def fgrad(self):
        """Step entity/bias with Adam, then return a relation-gradient estimate.

        With probability ``gradp``, or whenever no previous estimate is stored, this
        returns the full-batch gradient at ``vec_relation``. Otherwise it samples ``bb``
        triples without replacement. It then corrects the previous estimate by the
        gradient difference between ``vec_relation`` and the snapshot
        ``vec_relation_VRold`` on that minibatch. Every gradient has NaN/Inf entries
        zeroed and is norm-clipped to ``Lconstant``. The result is also stored in
        ``grad_old``.
        """
        self._step_entity_bias()

        # rand() is drawn first so the RNG stream is consumed on every call.
        if torch.rand(1) < self.gradp or getattr(self, 'grad_old', None) is None:
            grad_R = self._relation_grad(self.triple, self.Corrupted_triple, self.vec_relation)
        else:
            sampled_numbers = torch.randperm(self.b)[:self.bb]
            triple = self.triple[sampled_numbers, :]
            Corrupted_triple = self.Corrupted_triple[sampled_numbers, :, :]

            grad_new = self._relation_grad(triple, Corrupted_triple, self.vec_relation)
            grad_prev = self._relation_grad(triple, Corrupted_triple, self.vec_relation_VRold)
            grad_R = self.grad_old + (grad_new - grad_prev)
        self.grad_old = grad_R
        return grad_R

    def _step_entity_bias(self):
        """Take one Adam step on ``vec_entity``/``vec_bias`` with the relations detached."""
        fval = UltraE_fval_obj(self.triple, self.Corrupted_triple, self.vec_entity, self.vec_relation.data,
                               self.vec_bias, self.config_yaml)
        fval.requires_grad_(True)
        fval.backward()
        self._sanitize_grad(self.vec_entity)
        self._sanitize_grad(self.vec_bias)
        self.optimizer.step()
        self.vec_entity.grad = None
        self.vec_bias.grad = None

    def _relation_grad(self, triple, corrupted_triple, relation):
        """Return the sanitized, norm-clipped objective gradient w.r.t. ``relation``.

        Entity and bias embeddings are passed as ``.data``, so only ``relation``
        receives a gradient. ``relation.grad`` is cleared before and after so that
        calls never accumulate into each other.
        """
        relation.grad = None
        fval = UltraE_fval_obj(triple, corrupted_triple, self.vec_entity.data, relation,
                               self.vec_bias.data, self.config_yaml)
        fval.requires_grad_(True)
        fval.backward()
        self._sanitize_grad(relation)
        clip_grad_norm_(relation, max_norm=self.Lconstant)
        grad = relation.grad
        relation.grad = None
        return grad

    def _snapshot_relation(self):
        """Copy the current relations into ``vec_relation_VRold`` without sharing storage."""
        with torch.no_grad():
            self.vec_relation_VRold.copy_(self.vec_relation)

    @staticmethod
    def _sanitize_grad(param):
        """Zero NaN and +/-Inf entries of ``param.grad`` in place; no-op if it is None."""
        if param.grad is None:
            return
        with torch.no_grad():
            bad = ~torch.isfinite(param.grad)
            if bad.any():
                param.grad[bad] = 0


