import torch
from torch import nn
from torch.nn.utils import clip_grad_norm_
import time
from tqdm import tqdm
import sys
from torch.autograd import Variable
from utils.initial_x import getU, getV, getH

import torch.optim as optim
from UHstruct.fobj_val import UltraE_fval_obj
from UHstruct.get_hit_ranks import UltraE_test_hits_rank
from dataset.gettriple import Corrupt


class CS(nn.Module):
    """Trainer that jointly optimises entity vectors, biases and relation matrices.

    The relation matrices are not free parameters: ``getM`` rebuilds them as the
    batched product ``U @ H @ V`` from the learnable ``theta``, ``mu`` and
    ``ksi`` tensors and exposes the result as ``self.vec_relation``.
    """

    def __init__(self, triple, ttest, vec_entity, theta, mu, ksi, vec_bias, config_yaml):
        """Store the data, read hyper-parameters, register parameters and optimizers.

        Args:
            triple: Training triples; forwarded unchanged to ``Corrupt`` and
                ``UltraE_fval_obj``.
            ttest: Evaluation triples; forwarded unchanged to
                ``UltraE_test_hits_rank``.
            vec_entity: Initial entity tensor; its first dimension is used as
                the number of entities.
            theta: Initial tensor iterated row-wise by ``getU``.
            mu: Initial tensor, clamped to [-1, 1] and iterated row-wise by
                ``getH``.
            ksi: Initial tensor iterated row-wise by ``getV``; its first
                dimension is stored as ``Rnum``.
            vec_bias: Initial bias tensor.
            config_yaml: Mapping with a ``"device"`` entry, a ``"run"`` section
                (``maxiter``, ``Rmaxiter``, ``Cmaxiter``, ``lr_eb``, ``lr_r``,
                ``weight_decay``, ``inner_dispgap``) and a ``"datafeature"``
                section (``d``, ``p``, ``margin``, ``knum``).
        """
        super().__init__()
        self.device = config_yaml["device"]
        self.triple = triple
        self.ttest = ttest
        self.config_yaml = config_yaml

        run_cfg = config_yaml["run"]
        data_cfg = config_yaml["datafeature"]

        # Iteration counts go through float() first so values written in
        # scientific notation (e.g. "1e3") are accepted.
        self.maxiter = int(float(run_cfg["maxiter"]))
        self.Rmaxiter = int(float(run_cfg["Rmaxiter"]))
        self.Cmaxiter = int(float(run_cfg["Cmaxiter"]))
        self.d = data_cfg["d"]
        self.p = data_cfg["p"]
        self.margin = data_cfg["margin"]
        self.knum = data_cfg["knum"]
        self.Rnum = ksi.shape[0]
        self.entity_num = vec_entity.shape[0]
        self.stopindex = 0

        self.vec_entity = nn.Parameter(data=vec_entity)
        self.vec_bias = nn.Parameter(data=vec_bias)
        self.theta = nn.Parameter(data=theta)
        self.mu = nn.Parameter(data=mu)
        self.ksi = nn.Parameter(data=ksi)
        # Build U, H, V and the relation matrices from the initial parameters.
        self.getM()

        self.inneriter = int(1e2)
        # J is the d x d identity with the trailing (d - p) diagonal block negated.
        J = torch.eye(self.d)
        J[self.p:, self.p:] = -1 * torch.eye(self.d - self.p)
        self.J = J.to(self.device)
        self.lambdaa = 0
        self.myeps = 1e-8

        weight_decay = float(run_cfg["weight_decay"])
        # Entity vectors and biases are updated with Adam ...
        self.optimizerEB = optim.Adam([self.vec_entity, self.vec_bias],
                                      lr=float(run_cfg["lr_eb"]),
                                      weight_decay=weight_decay)
        # ... while the relation-generating parameters use Adagrad.
        self.optimizerR = optim.Adagrad([self.theta, self.mu, self.ksi],
                                        lr=float(run_cfg["lr_r"]),
                                        weight_decay=weight_decay)

    def Train(self):
        """Run ``maxiter`` outer iterations and collect per-iteration history.

        Each outer iteration draws fresh corrupted triples, performs
        ``Cmaxiter`` calls to ``step_forward`` (rebuilding the relation
        matrices after each one), then records the objective, the elapsed
        time and the test metrics.

        Returns:
            A tuple ``(hist_CObj, vec_relation, hist_t, hist_hits, hist_MRR)``
            where ``hist_CObj`` is the 1-D running sum of the objective,
            ``vec_relation`` is the current relation tensor, and the remaining
            entries are the elapsed-time, hits and MRR histories, all trimmed
            to the number of completed outer iterations.
        """
        hist_Obj = torch.zeros(self.maxiter, 1).to(self.device)
        hist_CObj = torch.zeros(self.maxiter, 1).to(self.device)
        hist_t = torch.zeros(self.maxiter, 1).to(self.device)
        hist_hits = torch.zeros(self.maxiter, 3).to(self.device)
        hist_MRR = torch.zeros(self.maxiter, 1).to(self.device)
        start_time = time.time()
        completed = 0
        for outer in tqdm(range(self.maxiter), desc='Outiter'):
            # get sample
            self.Corrupted_triple = Corrupt(self.triple, self.entity_num, self.knum)

            # train
            for Citer in tqdm(range(self.Cmaxiter), position=0):
                if Citer % self.config_yaml['run']['inner_dispgap'] == 0:
                    # Progress report only; no graph is needed for it.
                    with torch.no_grad():
                        temp = UltraE_fval_obj(self.triple, self.Corrupted_triple,
                                               self.vec_entity, self.vec_relation,
                                               self.vec_bias, self.config_yaml)
                    print('CS  Before R iter:{},citer:{}/{}, obj:{:.4f}'.format(
                        outer, Citer, self.Cmaxiter, temp))

                self.step_forward()
                # The parameters changed, so the relation matrices must be rebuilt.
                self.getM()

            # log
            with torch.no_grad():
                hist_t[outer] = time.time() - start_time
                hist_Obj[outer] = UltraE_fval_obj(self.triple, self.Corrupted_triple,
                                                  self.vec_entity, self.vec_relation,
                                                  self.vec_bias, self.config_yaml)
                if outer == 0:
                    hist_CObj[outer] = hist_Obj[outer]
                else:
                    hist_CObj[outer] = hist_CObj[outer - 1] + hist_Obj[outer]

                hist_hits[outer, :], hist_MRR[outer] = UltraE_test_hits_rank(
                    self.ttest, self.vec_entity, self.vec_relation,
                    self.vec_bias, self.config_yaml)
                print('CS iter:{}, csumfval:{:.2f}, hist hits:{:.4f}-{:.4f}-{:.4f}, hist MRR:{:.4f}'.format(
                    outer, hist_CObj[outer, 0].item(), hist_hits[outer, 0].item(),
                    hist_hits[outer, 1].item(), hist_hits[outer, 2].item(),
                    hist_MRR[outer, 0].item()))
            completed = outer + 1

        # Trim by the number of completed iterations rather than by "value != 0":
        # an objective that is exactly zero is a legitimate entry, and filtering
        # it out would shorten and misalign the other histories.
        hist_CObj = hist_CObj[:completed].reshape(-1)
        hist_t = hist_t[:completed]
        hist_hits = hist_hits[:completed, :]
        hist_MRR = hist_MRR[:completed]
        return hist_CObj, self.vec_relation, hist_t, hist_hits, hist_MRR

    @staticmethod
    def _zero_nonfinite_grad(tensor):
        """Zero the NaN/Inf entries of ``tensor.grad`` in place; no-op without a grad."""
        grad = tensor.grad
        if grad is None:
            # The tensor did not receive a gradient in this backward pass.
            return
        grad.masked_fill_(~torch.isfinite(grad), 0)

    def step_forward(self):
        """Perform one optimisation step on all learnable tensors.

        Evaluates the objective on the current corrupted triples,
        back-propagates, replaces NaN/Inf gradient entries with zero, steps
        both optimizers and finally clears all gradients.
        """
        fval = UltraE_fval_obj(self.triple, self.Corrupted_triple, self.vec_entity,
                               self.vec_relation, self.vec_bias, self.config_yaml)
        fval.backward()

        # A single non-finite gradient entry would corrupt the optimizer state,
        # so such entries are zeroed before stepping.
        for tensor in (self.vec_entity, self.vec_bias, self.vec_relation,
                       self.theta, self.ksi, self.mu):
            self._zero_nonfinite_grad(tensor)

        self.optimizerR.step()
        self.optimizerEB.step()
        self.optimizerR.zero_grad()
        self.optimizerEB.zero_grad()
        # vec_relation is not owned by either optimizer; drop its retained grad by hand.
        self.vec_relation.grad = None

    def getM(self):
        """Rebuild ``U``, ``H``, ``V`` and the relation matrices ``X = U @ H @ V``.

        One matrix is produced per row of ``theta``, ``mu`` (clamped to
        [-1, 1]) and ``ksi`` via ``getU``, ``getH`` and ``getV``. The product
        is stored both as ``self.X`` and as ``self.vec_relation``.
        """
        self.U = torch.stack([getU(row) for row in self.theta]).to(self.device)
        self.H = torch.stack(
            [getH(self.d, self.p, row) for row in torch.clamp(self.mu, -1, 1)]
        ).to(self.device)
        self.V = torch.stack([getV(row) for row in self.ksi]).to(self.device)
        self.X = torch.bmm(self.U, torch.bmm(self.H, self.V)).to(self.device)

        # These are intermediate (non-leaf) tensors; retain_grad keeps their
        # .grad available after backward().
        for intermediate in (self.U, self.V, self.H, self.X):
            intermediate.retain_grad()
        self.vec_relation = self.X

