import torch.nn as nn
import torch
import math
import torch.nn.functional as F
class Uni_model(nn.Module):
    """Score (user, item) pairs from frozen feature tables via attention pooling.

    Each entity's feature row is expanded into one embedding per feature slot
    (feature value times a per-slot vector, passed through an MLP, plus a
    sinusoidal slot encoding).  The slot embeddings are pooled with a learned
    softmax attention, and the pooled user and item vectors are combined with
    a dot product that is scaled by a sigmoid of an item statistic.
    """

    def __init__(self, embedding_dim, num_users, num_items, user_feature, item_feature, user_embedding, item_embedding,
                 cold, pretrain, feature1_user=None, feature1_item=None, train_stat=None, test_stat=None, drop_ratio=0.3):
        """Build the model.

        Args:
            embedding_dim: Width of every embedding / hidden layer.  Must be
                divisible by 4 because of the 4-head ``MultiheadAttention``.
            num_users: Stored on the instance; not otherwise used here.
            num_items: Stored on the instance; not otherwise used here.
            user_feature: 2-D tensor wrapped in a frozen ``nn.Embedding``;
                row ``u`` holds the feature values of user ``u``.
            item_feature: Same as ``user_feature`` but for items.  Its second
                dimension defines ``self.feature_dim``.
            user_embedding: Accepted for interface compatibility; unused.
            item_embedding: Accepted for interface compatibility; unused.
            cold: Stored on the instance; not otherwise used here.
            pretrain: Stored on the instance; not otherwise used here.
            feature1_user: Accepted for interface compatibility; unused.
            feature1_item: Accepted for interface compatibility; unused.
            train_stat: Optional tensor.  When given, its transpose replaces
                ``item_percentile.weight`` as a non-trainable parameter
                (presumably shaped ``(feature_dim, embedding_dim)`` --
                confirm against the caller).
            test_stat: Accepted for interface compatibility; unused.
            drop_ratio: Dropout probability.
        """
        super(Uni_model, self).__init__()
        self.num_users = num_users
        self.num_items = num_items
        self.embedding_dim = embedding_dim
        self.drop_ratio = drop_ratio
        self.dropout = torch.nn.Dropout(p=drop_ratio)
        self.cold = cold
        self.cos = nn.CosineSimilarity(dim=-1, eps=1e-6)
        self.pretrain = pretrain
        self.feature_dim = item_feature.shape[1]
        # NOTE: PositionalEncoding is not defined in this class; it must be
        # available at module scope.
        self.pos_encoder = PositionalEncoding(embedding_dim, drop_ratio)

        # Sinusoidal encoding with one row per feature slot.  Registered as a
        # non-persistent buffer so that it follows the module across
        # .to()/.cuda()/.cpu() calls (instead of being pinned to the default
        # CUDA device) while leaving the state_dict keys untouched.
        position = torch.arange(self.feature_dim).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, embedding_dim, 2) * (-math.log(10000.0) / embedding_dim))
        pe = torch.zeros(self.feature_dim, embedding_dim)
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        self.register_buffer("pe", pe, persistent=False)

        self.user_feature = nn.Embedding.from_pretrained(user_feature, freeze=True)
        self.item_feature = nn.Embedding.from_pretrained(item_feature, freeze=True)

        self.transform = nn.Sequential(
            nn.Linear(embedding_dim, embedding_dim),
            nn.ReLU(),
            nn.Linear(embedding_dim, embedding_dim),
            nn.ReLU(),
            nn.Linear(embedding_dim, embedding_dim)
        )
        self.transform2 = nn.Sequential(
            nn.Linear(embedding_dim, embedding_dim),
            nn.ReLU(),
            nn.Linear(embedding_dim, embedding_dim),
            nn.ReLU(),
            nn.Linear(embedding_dim, embedding_dim)
        )
        # Only the *weights* of these layers are used: column f is the
        # embedding vector of feature slot f.
        self.user_percentile = nn.Linear(user_feature.shape[1], embedding_dim, bias=False)
        self.item_percentile = nn.Linear(item_feature.shape[1], embedding_dim, bias=False)

        self.u_pos = nn.Embedding(user_feature.shape[1], embedding_dim)
        self.i_pos = nn.Embedding(item_feature.shape[1], embedding_dim)

        self.multi_head = torch.nn.MultiheadAttention(embedding_dim, num_heads=4, batch_first=True)
        self.act = nn.Tanh()
        self.triplet_loss = nn.TripletMarginLoss(margin=1.0, p=2)
        self.atten_weight = nn.Linear(embedding_dim, 1, bias=False)
        # Xavier-initialise every Linear layer, including those nested in
        # submodules.
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
        self.train_stat = False
        if train_stat is not None:
            self.train_stat = True
            print("Loading the Pre-computed stats")
            self.item_percentile.weight = torch.nn.Parameter(train_stat.permute(1, 0), requires_grad=False)

    def _attentive_pool(self, feature, percentile, transform):
        """Pool per-slot embeddings of ``feature`` into one vector per row.

        Args:
            feature: ``(..., F)`` tensor of feature values; entries equal to
                zero are excluded from the attention.
            percentile: ``(F, D)`` tensor, one vector per feature slot.
            transform: Callable applied to the ``(..., F, D)`` slot embeddings.

        Returns:
            ``(..., D)`` tensor.  Leading axes of size one are preserved.
        """
        pe = self.pe
        mask = (feature > 0).float()
        # (F, D) broadcasts against (..., F, 1); no explicit repeat needed.
        curr_emb = self.dropout(transform(percentile * feature.unsqueeze(-1)) + pe * mask.unsqueeze(-1))
        # Drop only the trailing singleton produced by the Linear layer.  A
        # bare squeeze() would also drop a batch/candidate axis of size one
        # and make the mask below incompatible with the scores.
        scores = self.dropout(self.atten_weight(curr_emb).squeeze(-1))
        scores = scores.masked_fill(feature == 0, -1e9)
        atten_weight = F.softmax(scores, -1).unsqueeze(-1)  # (..., F, 1)
        # The slot encoding is added a second time here (unmasked), as in the
        # original formulation.
        return (atten_weight * self.dropout(curr_emb + pe)).sum(-2)

    def get_user_embedding(self, feature, percentile, transform, self_stat, self_trans, positional):
        """Return the attention-pooled user embedding.

        ``self_stat``, ``self_trans`` and ``positional`` are accepted for
        interface compatibility and are not used.
        """
        return self._attentive_pool(feature, percentile, transform)

    def get_item_embedding(self, feature, percentile, transform, self_stat, self_trans, positional):
        """Return the attention-pooled item embedding.

        ``self_stat``, ``self_trans`` and ``positional`` are accepted for
        interface compatibility and are not used.
        """
        return self._attentive_pool(feature, percentile, transform)

    #  Batch of user, item (item can positive or negative)
    def forward(self, user, item, user_stat=None, item_stat=None):
        """Compute logits for a batch of (user, item) pairs.

        Args:
            user: Integer index tensor into the user feature table.
            item: Integer index tensor into the item feature table.
            user_stat: Statistics tensor for the users.  Required despite the
                ``None`` default: its rank selects how both stat tensors are
                indexed.
            item_stat: Statistics tensor for the items.  Required despite the
                ``None`` default: entry ``1`` along the last axis scales the
                score through a sigmoid.

        Returns:
            Logits with a trailing axis of size one.
        """
        user_feature = self.user_feature(user)
        item_feature = self.item_feature(item)
        # The rank of user_stat decides the indexing of BOTH stat tensors.
        if len(user_stat.shape) > 2:
            user_self_stat = user_stat[:, :, 0]
            item_self_stat = item_stat[:, :, 0]
            curr_item_stat = item_stat[:, :, 1]
        else:
            user_self_stat = user_stat[:, 0]
            item_self_stat = item_stat[:, 0]
            curr_item_stat = item_stat[:, 1]
        # Both towers share the item slot vectors; only the MLP differs.
        slot_vectors = self.item_percentile.weight.permute(1, 0)
        user_embedding = self.get_user_embedding(user_feature, slot_vectors, self.transform,
                                                 user_self_stat, self.transform2, self.u_pos)
        item_embedding = self.get_item_embedding(item_feature, slot_vectors, self.transform2,
                                                 item_self_stat, self.transform, self.i_pos)
        mf_vector = torch.mul(user_embedding, item_embedding)
        logits = torch.sum(mf_vector, dim=-1).unsqueeze(-1) * torch.sigmoid(curr_item_stat).unsqueeze(-1)
        return logits

    def l2_loss_function(self, pos_prediction, neg_prediction):
        """Return the mean squared deviation of (pos - neg) from a margin of 1."""
        loss = torch.mean((pos_prediction - neg_prediction - 1) ** 2)
        return loss

    def reg_loss(self, n):
        """Return a smoothness penalty over consecutive feature-slot vectors.

        Sums the L2 distance between each pair of adjacent slot vectors of
        ``item_percentile`` and divides by ``feature_dim``.  ``n`` is accepted
        for interface compatibility and is not used.
        """
        rows = self.item_percentile.weight.permute(1, 0)[:self.feature_dim]
        if rows.shape[0] < 2:
            # No adjacent pair to compare.
            return 0.0
        gaps = rows[1:] - rows[:-1]
        return torch.norm(gaps, p=2, dim=-1).sum() / self.feature_dim

    def loss_function(self, positive_predictions, negative_predictions):
        """Return mean of ``1 - sigmoid(pos - neg)`` over all pos/neg pairs.

        ``positive_predictions`` gets a trailing axis so that it broadcasts
        against ``negative_predictions``.
        """
        # LightGCN loss.
        bpr_loss = (1.0 - torch.sigmoid(positive_predictions.unsqueeze(-1) -
                                        negative_predictions)).mean()
        return bpr_loss

    def multi_loss(self, logits, targets):
        """Return the mean soft-target cross entropy over the last axis."""
        n_classes = logits.shape[-1]
        return -torch.mean(
            torch.sum(F.log_softmax(logits.view(-1, n_classes), 1) * targets.view(-1, n_classes), -1))

