from src.utils import *
from torch_scatter import scatter_add, scatter_mean, scatter_max


class BaseModel(nn.Module):
    """Base knowledge-graph embedding model with a margin ranking loss.

    Holds one entity and one relation embedding table (float64, placed on
    ``args.device``) sized from the first snapshot of ``kg``, scores triples
    with the L1 distance ``||s + r - o||_1`` over L2-normalised embeddings,
    and exposes no-op hooks (``switch_snapshot``, ``pre_snapshot``, ...)
    that subclasses can override.

    Attributes read from ``args`` by this class: ``emb_dim``, ``device``,
    ``margin``, ``neg_ratio``, ``snapshot`` and ``snapshot_test``.
    """

    def __init__(self, args, kg):
        """Create the embedding tables, the lambda scalars and the loss.

        Args:
            args: Configuration namespace (see the class docstring for the
                attributes that are read).
            kg: Object exposing ``snapshots``, a sequence whose items have
                ``num_ent`` and ``num_rel`` attributes.
        """
        super().__init__()
        self.args = args
        self.kg = kg

        first_snapshot = self.kg.snapshots[0]
        self.ent_embeddings = nn.Embedding(
            first_snapshot.num_ent, self.args.emb_dim
        ).to(self.args.device).double()
        self.rel_embeddings = nn.Embedding(
            first_snapshot.num_rel, self.args.emb_dim
        ).to(self.args.device).double()
        xavier_normal_(self.ent_embeddings.weight)
        xavier_normal_(self.rel_embeddings.weight)

        # Learnable scalars initialised to 1.0. They are not referenced
        # anywhere else in this class -- presumably consumed by subclasses.
        self.lambda1 = nn.Parameter(torch.ones(1) * 1.0)
        self.lambda2 = nn.Parameter(torch.ones(1) * 1.0)
        self.lambda3 = nn.Parameter(torch.ones(1) * 1.0)
        self.lambda4 = nn.Parameter(torch.ones(1) * 1.0)

        # Loss function: summed margin ranking loss.
        self.margin_loss_func = nn.MarginRankingLoss(
            margin=float(self.args.margin), reduction="sum"
        )

    def reinit_param(self):
        """Re-initialise trainable weight matrices with Xavier-normal values.

        ``xavier_normal_`` needs fan-in/fan-out and therefore raises
        ``ValueError`` for tensors with fewer than two dimensions (such as
        the ``lambda*`` scalars); those parameters are left unchanged.
        """
        for _, param in self.named_parameters():
            if param.requires_grad and param.dim() >= 2:
                xavier_normal_(param)

    def expand_embedding_size(self):
        """Build fresh embedding tables sized for the next snapshot.

        Returns:
            A ``(ent_embeddings, rel_embeddings)`` pair of new
            ``nn.Embedding`` modules (float64, on ``args.device``) whose row
            counts come from ``kg.snapshots[args.snapshot + 1]``. The model's
            own tables are not modified.
        """
        next_snapshot = self.kg.snapshots[self.args.snapshot + 1]
        ent_embeddings = nn.Embedding(
            next_snapshot.num_ent, self.args.emb_dim
        ).to(self.args.device).double()
        rel_embeddings = nn.Embedding(
            next_snapshot.num_rel, self.args.emb_dim
        ).to(self.args.device).double()
        xavier_normal_(ent_embeddings.weight)
        xavier_normal_(rel_embeddings.weight)
        # The modules were just created here, so no defensive copy is needed.
        return ent_embeddings, rel_embeddings

    def switch_snapshot(self):
        """Hook with no behaviour in the base class."""
        pass

    def pre_snapshot(self):
        """Hook with no behaviour in the base class."""
        pass

    def epoch_post_processing(self, size=None):
        """Hook with no behaviour in the base class; ``size`` is ignored."""
        pass

    def snapshot_post_processing(self):
        """Hook with no behaviour in the base class."""
        pass

    def _register_old_data(self):
        """Copy every trainable parameter into an ``old_data_<name>`` buffer.

        Dots in the parameter name are replaced by underscores because buffer
        names may not contain ``.``. Existing buffers of the same name are
        overwritten.
        """
        # Materialise the iterator first so registering buffers cannot
        # interfere with the traversal of the module tree.
        for name, param in list(self.named_parameters()):
            if param.requires_grad:
                buffer_name = 'old_data_{}'.format(name.replace('.', '_'))
                self.register_buffer(buffer_name, param.data.clone().detach())

    def store_old_parameters(self):
        """Snapshot the current trainable parameters into ``old_data_*`` buffers."""
        self._register_old_data()

    def initialize_old_data(self):
        """Create the ``old_data_*`` buffers from the current trainable parameters."""
        self._register_old_data()

    def embedding(self, stage=None):
        """Return the ``(entity, relation)`` embedding weight matrices.

        ``stage`` is accepted for interface compatibility and is ignored by
        the base implementation.
        """
        return self.ent_embeddings.weight, self.rel_embeddings.weight

    def new_loss(self, head, rel, tail=None, label=None):
        """Return the summed margin loss divided by the number of rows in ``head``.

        Despite the ``None`` defaults, ``tail`` and ``label`` are required by
        ``margin_loss``.
        """
        return self.margin_loss(head, rel, tail, label) / head.size(0)

    def margin_loss(self, head, rel, tail, label=None):
        """Compute the summed margin ranking loss for a batch of triples.

        Args:
            head: Index tensor of head entities.
            rel: Index tensor of relations.
            tail: Index tensor of tail entities.
            label: Tensor aligned with the triples; entries ``> 0`` mark
                positives and entries ``< 0`` mark negatives.

        Returns:
            A scalar tensor: ``sum(max(0, p_score - n_score + margin))``.
        """
        ent_embeddings, rel_embeddings = self.embedding('Train')

        s = torch.index_select(ent_embeddings, 0, head)
        r = torch.index_select(rel_embeddings, 0, rel)
        o = torch.index_select(ent_embeddings, 0, tail)
        score = self.score_fun(s, r, o)
        p_score, n_score = self.split_pn_score(score, label)
        # Target -1 asks for p_score (a distance) to be ranked below n_score.
        # Match dtype/device of the scores instead of hard-coding float32.
        y = torch.tensor([-1.0], dtype=p_score.dtype, device=p_score.device)
        return self.margin_loss_func(p_score, n_score, y)

    def split_pn_score(self, score, label):
        """Split scores into positive scores and per-group mean negative scores.

        Negative scores are reshaped to ``(-1, args.neg_ratio)`` and averaged
        along the last axis, so the number of negatives must be a multiple of
        ``args.neg_ratio``.
        """
        p_score = score[label > 0]
        n_score = score[label < 0].reshape(-1, self.args.neg_ratio).mean(dim=1)
        return p_score, n_score

    def score_fun(self, s, r, o):
        """Return the L1 distance ``||s + r - o||_1`` over normalised inputs."""
        s = self.norm_ent(s)
        r = self.norm_rel(r)
        o = self.norm_ent(o)
        return torch.norm(s + r - o, 1, -1)

    def score_fun1(self, s, r, o):
        """Return the L1 distance ``||s * r - o||_1`` over normalised inputs."""
        s = self.norm_ent(s)
        r = self.norm_rel(r)
        o = self.norm_ent(o)
        return torch.norm(s * r - o, 1, -1)

    def predict(self, sub, rel, stage='Valid'):
        """Score every candidate tail entity for each ``(sub, rel)`` query.

        Args:
            sub: Index tensor of subject entities.
            rel: Index tensor of relations.
            stage: ``'Test'`` selects the entity count of
                ``kg.snapshots[args.snapshot_test]``; any other value uses
                ``kg.snapshots[args.snapshot]``.

        Returns:
            A tensor with one row per query and one column per candidate
            entity holding ``sigmoid(9.0 - ||s + r - o||_1)``.
        """
        if stage != 'Test':
            num_ent = self.kg.snapshots[self.args.snapshot].num_ent
        else:
            num_ent = self.kg.snapshots[self.args.snapshot_test].num_ent
        ent_embeddings, rel_embeddings = self.embedding(stage)
        s = torch.index_select(ent_embeddings, 0, sub)
        r = torch.index_select(rel_embeddings, 0, rel)
        # Candidates are the first ``num_ent`` rows of the entity table.
        o_all = ent_embeddings[:num_ent]
        s = self.norm_ent(s)
        r = self.norm_rel(r)
        o_all = self.norm_ent(o_all)
        pred_o = s + r
        # Broadcast each query against all candidates; smaller distance
        # yields a larger score.
        score = 9.0 - torch.norm(pred_o.unsqueeze(1) - o_all, p=1, dim=2)
        return torch.sigmoid(score)

    def norm_rel(self, r):
        """L2-normalise relation embeddings along the last dimension."""
        return nn.functional.normalize(r, 2, -1)

    def norm_ent(self, e):
        """L2-normalise entity embeddings along the last dimension."""
        return nn.functional.normalize(e, 2, -1)
