import torch
import torch.nn as nn
import torch.nn.functional as F

class prediction(nn.Module):
    """Score a (user, item) pair as the dot product of their feature vectors.

    ``forward`` multiplies the two tensors element-wise and sums over dim 1,
    keeping that dimension, so ``(bs, d)`` inputs give a ``(bs, 1)`` score.
    """

    def forward(self, u_fea, i_fea):
        elementwise = torch.mul(u_fea, i_fea)
        return torch.sum(elementwise, dim=1, keepdim=True)


class Dattn(nn.Module):
    """Two-tower review model scored by a dot product.

    User and item review token ids share one pre-trained embedding table.
    Each side is then encoded by its own (non-shared) ``Net`` tower, and the
    two feature vectors are combined by ``prediction``.

    Args:
        config: hyper-parameter object passed unchanged to both towers.
        word_emb: pre-trained embedding matrix (anything ``torch.Tensor``
            accepts). It is loaded via ``nn.Embedding.from_pretrained`` with
            that method's default ``freeze=True``.
    """

    def __init__(self, config, word_emb):
        super().__init__()

        pretrained = torch.Tensor(word_emb)
        self.embedding = nn.Embedding.from_pretrained(pretrained)

        # Separate parameters for the user side and the item side.
        self.user_net = Net(config)
        self.item_net = Net(config)

        self.prediction = prediction()

    def forward(self, user_review, item_review):
        """Return the score produced by ``prediction`` for each user/item pair.

        NOTE(review): review tensors are presumably (batch, seq_len) token
        ids; confirm against the data loader.
        """
        user_embs = self.embedding(user_review)
        item_embs = self.embedding(item_review)

        user_vec = self.user_net(user_embs)
        item_vec = self.item_net(item_embs)  # presumably (batch, feature_dim)

        return self.prediction(user_vec, item_vec)


class Net(nn.Module):
    """Review encoder combining local- and global-attention CNN features.

    The local branch yields one ``(bs, kernel_count)`` tensor and the global
    branch yields a list of them (one per conv filter size). They are
    concatenated along dim 1, passed through dropout, and projected to
    ``config.feature_dim`` by a small MLP.

    Reads ``config.kernel_count``, ``config.feature_dim`` and
    ``config.dropout_prob``, plus whatever the attention modules read.
    """

    def __init__(self, config):
        super().__init__()

        self.local_att = LocalAttention(config)
        self.global_att = GlobalAttention(config)

        # Width of the concatenated features: one kernel_count-wide block from
        # the local branch plus one per global conv branch. This is derived
        # from the module instead of a hard-coded multiplier, so it stays
        # correct if the global filter sizes change.
        n_branches = 1 + len(self.global_att.convs)
        kernel_dim = config.kernel_count * n_branches

        self.fc = nn.Sequential(
            nn.Linear(kernel_dim, kernel_dim),
            nn.Dropout(0.5),
            nn.ReLU(),
            nn.Linear(kernel_dim, config.feature_dim),
        )
        self.dropout = nn.Dropout(config.dropout_prob)

    def forward(self, x):
        """Encode embedded reviews ``x`` into ``config.feature_dim`` features."""
        local_fea = self.local_att(x)
        global_fea = self.global_att(x)  # list, one tensor per filter size
        cat_fea = torch.cat([local_fea, *global_fea], 1)
        cat_fea = self.dropout(cat_fea)
        return self.fc(cat_fea)


class LocalAttention(nn.Module):
    """Windowed attention over word embeddings followed by a 1-word CNN.

    A ``(kernel_size, word_dim)`` convolution with a sigmoid gives one score
    per sequence position. Embeddings are scaled by that score, passed
    through a ``(1, word_dim)`` convolution with ``config.kernel_count``
    filters and tanh, then max-pooled over the sequence.

    Reads ``config.kernel_size``, ``config.word_dim`` and
    ``config.kernel_count``.
    """

    def __init__(self, config):
        super().__init__()

        # The attention window needs (kernel_size - 1) rows of zero padding to
        # emit exactly one score per position. Splitting them top/bottom this
        # way reproduces symmetric ((k-1)//2) padding for odd k and also keeps
        # the length unchanged for even k. Symmetric Conv2d padding gives L-1
        # scores for even k, and the later broadcast fails.
        self.pad_top = (config.kernel_size - 1) // 2
        self.pad_bottom = config.kernel_size - 1 - self.pad_top

        # (bs, 1, L, word_dim) -> (bs, 1, L, 1)
        self.att_conv = nn.Sequential(
            nn.Conv2d(1, 1, kernel_size=(config.kernel_size, config.word_dim)),
            nn.Sigmoid()
        )

        # (bs, 1, L, word_dim) -> (bs, kernel_count, L, 1)
        self.cnn = nn.Conv2d(1, config.kernel_count, kernel_size=(1, config.word_dim))

    def forward(self, x):
        """Map ``x`` of shape (bs, L, word_dim) to (bs, kernel_count)."""
        # F.pad order is (last-dim left, right, 2nd-last top, bottom).
        padded = F.pad(x.unsqueeze(1), (0, 0, self.pad_top, self.pad_bottom))
        score = self.att_conv(padded).squeeze(1)          # bs, L, 1
        out = x.mul(score)                                # bs, L, word_dim
        out = out.unsqueeze(1)                            # bs, 1, L, word_dim
        out = torch.tanh(self.cnn(out)).squeeze(3)        # bs, kernel_count, L
        out = F.max_pool1d(out, out.size(2)).squeeze(2)   # bs, kernel_count
        return out


class GlobalAttention(nn.Module):
    """Whole-review attention followed by multi-width CNN branches.

    A single convolution spanning the full ``(vocab_count, word_dim)`` input
    yields one sigmoid score per sample, which scales every embedding. The
    scaled input then goes through one ``config.kernel_count``-filter
    convolution per entry of ``filters_size``; each output is tanh-activated
    and max-pooled over the sequence.

    Args:
        config: must provide ``word_dim`` and ``kernel_count``.
        filters_size: window heights (in words) of the CNN branches.
        vocab_count: sequence length the attention convolution spans. Because
            its output is broadcast back onto the input, inputs must have
            exactly this many positions along dim 1.
    """

    def __init__(self, config, filters_size=(2, 3, 4), vocab_count=400):
        super().__init__()

        self.vocab_count = vocab_count

        # (bs, 1, vocab_count, word_dim) -> (bs, 1, 1, 1)
        self.att_conv = nn.Sequential(
            nn.Conv2d(1, 1, kernel_size=(self.vocab_count, config.word_dim)),
            nn.Sigmoid()
        )

        # One branch per filter size: (bs, 1, L, word_dim) -> (bs, kernel_count, L-k+1, 1)
        self.convs = nn.ModuleList(
            [nn.Conv2d(1, config.kernel_count, (k, config.word_dim)) for k in filters_size]
        )

    def forward(self, x):
        """Map ``x`` of shape (bs, vocab_count, word_dim) to a list of
        ``len(filters_size)`` tensors, each of shape (bs, kernel_count)."""
        x = x.unsqueeze(1)         # bs, 1, L, word_dim
        score = self.att_conv(x)   # bs, 1, 1, 1
        x = x.mul(score)           # broadcast: bs, 1, L, word_dim

        conv_outs = [torch.tanh(cnn(x).squeeze(3)) for cnn in self.convs]            # bs, kernel_count, L-k+1
        conv_outs = [F.max_pool1d(out, out.size(2)).squeeze(2) for out in conv_outs]  # bs, kernel_count
        return conv_outs