import torch
import torch.nn as nn
import torch.nn.functional as F


class Neural_FactorizationMachine(nn.Module):
    """Neural Factorization Machine head that produces one score per input row.

    The output is ``linear(x) + h(dropout(relu(mlp(B(x)))))`` where ``B(x)`` is
    the factorization-machine bi-interaction pooling

        B(x) = 0.5 * ((x @ v) ** 2 - (x ** 2) @ (v ** 2)),

    which equals the sum over all feature pairs i < j of
    ``x_i * x_j * (v_i * v_j)`` (elementwise product of the factor rows).

    Args:
        p: Number of input features (last dimension of ``x``).
        k: Size of each latent factor vector; also the hidden width of the MLP.
        dropout: Dropout probability applied after the hidden ReLU layer.
            Defaults to 0.5, the value that used to be hard-coded.
    """

    def __init__(self, p, k, dropout=0.5):
        super().__init__()

        # Latent factors: one k-dim row per input feature. torch.rand samples
        # from [0, 1), so the initial values lie in [0, 0.1) (all non-negative).
        self.v = nn.Parameter(torch.rand(p, k) / 10)
        # First-order term; its bias is the model's global bias.
        self.linear = nn.Linear(p, 1, bias=True)
        self.dropout = nn.Dropout(dropout)

        # One hidden layer on the pairwise-interaction vector, then a bias-free
        # projection to a scalar (the linear term already carries the bias).
        self.mlp = nn.Linear(k, k)
        self.h = nn.Linear(k, 1, bias=False)

    def forward(self, x):
        """Score a batch of feature vectors.

        Args:
            x: Float tensor of shape (batch_size, p).

        Returns:
            Tensor of shape (batch_size, 1).
        """
        linear_part = self.linear(x)  # (batch_size, 1)

        # Bi-interaction pooling in O(p * k) instead of enumerating all pairs:
        # sum_{i<j} x_i x_j (v_i * v_j) == 0.5 * ((x v)^2 - x^2 v^2).
        square_of_sum = torch.mm(x, self.v) ** 2
        sum_of_square = torch.mm(x ** 2, self.v ** 2)
        bilinear = 0.5 * (square_of_sum - sum_of_square)  # (batch_size, k)

        hidden = F.relu(self.mlp(bilinear))
        hidden = self.dropout(hidden)
        return self.h(hidden) + linear_part  # (batch_size, 1)


class DAML(nn.Module):
    """Review-based scoring model with local/mutual attention and ID embeddings.

    Pipeline, as implemented in ``forward``:
      1. Embed user and item review word indices with a pretrained table
         (``nn.Embedding.from_pretrained``, frozen by default).
      2. ``local_attention_cnn``: gate each word with a sigmoid score from a
         shared one-channel CNN, then apply a side-specific document CNN.
      3. Build a mutual-attention matrix ``1 / (1 + ||u_i - v_j||)`` between
         every user position i and every item position j.
      4. ``local_pooling_cnn``: attention-weighted windowed pooling, an
         abstract-level CNN, average pooling over positions, FC + ReLU, dropout.
      5. Add user/item ID embeddings, concatenate, and score with an NFM head.

    Args:
        config: Object exposing ``word_dim``, ``kernel_count``, ``kernel_size``,
            ``dropout_prob``, ``feature_dim``, ``user_num`` and ``item_num``.
        word_emb: Pretrained word-embedding matrix, anything ``torch.Tensor``
            accepts. It is expected to have ``word_dim`` columns, because the
            CNN kernels span the full embedding width.

    Raises:
        ValueError: If ``config.kernel_size`` is even.
    """

    def __init__(self, config, word_emb):
        super().__init__()

        self.word_dim = config.word_dim
        self.kernel_count = config.kernel_count
        self.kernel_size = config.kernel_size

        self.dropout_prob = config.dropout_prob
        self.feature_dim = config.feature_dim

        self.user_num = config.user_num
        self.item_num = config.item_num

        # The padding below is (k - 1) // 2 per side, which preserves sequence
        # length only for odd k. With an even k the word gate and the unfold
        # reshape in local_pooling_cnn no longer line up and forward fails, so
        # reject it here with a clear message instead.
        if self.kernel_size % 2 == 0:
            raise ValueError(
                f"config.kernel_size must be odd, got {self.kernel_size}"
            )
        same_pad = ((self.kernel_size - 1) // 2, 0)

        self.embedding = nn.Embedding.from_pretrained(torch.Tensor(word_emb))

        # Shared CNN: a single output channel, i.e. one gate value per word position.
        self.word_cnn = nn.Conv2d(1, 1, (self.kernel_size, self.word_dim), padding=same_pad)

        # Document-level CNNs (one per side); length-preserving along doc_len.
        self.user_doc_cnn = nn.Conv2d(1, self.kernel_count, (self.kernel_size, self.word_dim), padding=same_pad)
        self.item_doc_cnn = nn.Conv2d(1, self.kernel_count, (self.kernel_size, self.word_dim), padding=same_pad)

        # Abstract-level CNNs over the pooled (doc_len x kernel_count) map, no padding.
        self.user_abs_cnn = nn.Conv2d(1, self.kernel_count, (self.kernel_size, self.kernel_count))
        self.item_abs_cnn = nn.Conv2d(1, self.kernel_count, (self.kernel_size, self.kernel_count))

        # Extracts a kernel_size-row window around every position of the
        # (doc_len x kernel_count) map; used by local_pooling_cnn.
        self.unfold = nn.Unfold((self.kernel_size, self.kernel_count), padding=same_pad)

        # Project the pooled kernel_count features to feature_dim.
        self.user_fc = nn.Linear(self.kernel_count, self.feature_dim)
        self.item_fc = nn.Linear(self.kernel_count, self.feature_dim)

        # ID embeddings, added to the review features in forward.
        self.uid_embedding = nn.Embedding(self.user_num, self.feature_dim)
        self.iid_embedding = nn.Embedding(self.item_num, self.feature_dim)

        print(f'DAML (idemb-setting) user_num : {self.user_num}, item_num : {self.item_num}')

        self.fm = Neural_FactorizationMachine(self.feature_dim * 2, 16)

    def forward(self, u_ids, i_ids, user_review, item_review):
        """Predict a score for each (user, item) pair in the batch.

        Args:
            u_ids, i_ids: Index tensors for the uid/iid embeddings. Presumably
                shaped (batch, 1) given the ``squeeze(1)`` below; confirm
                against the data loader.
            user_review, item_review: Long tensors (batch, doc_len) of word
                indices. The user and item doc lengths may differ.

        Returns:
            Tensor of shape (batch, 1) produced by the NFM head.
        """
        # ---- review encoder: (batch, doc_len) -> (batch, doc_len, word_dim)
        user_word_embs = self.embedding(user_review)
        item_word_embs = self.embedding(item_review)

        # (batch, kernel_count, doc_len, 1)
        user_local_fea = self.local_attention_cnn(user_word_embs, self.user_doc_cnn)
        item_local_fea = self.local_attention_cnn(item_word_embs, self.item_doc_cnn)

        # ---- mutual attention: (batch, user_len, item_len)
        # Broadcasting (batch, K, user_len, 1) - (batch, K, 1, item_len) gives
        # the Euclidean distance over channels for every pair of positions.
        # NOTE(review): sqrt has an infinite gradient at 0, so exactly equal
        # feature columns would produce non-finite grads; consider an epsilon.
        euclidean = (user_local_fea - item_local_fea.permute(0, 1, 3, 2)).pow(2).sum(1).sqrt()
        attention_matrix = 1.0 / (1 + euclidean)

        # Total attention of each position toward the other side's positions.
        user_attention = attention_matrix.sum(2)  # (batch, user_len)
        item_attention = attention_matrix.sum(1)  # (batch, item_len)

        # (batch, feature_dim) per side
        user_doc_fea = self.local_pooling_cnn(user_local_fea, user_attention, self.user_abs_cnn, self.user_fc)
        item_doc_fea = self.local_pooling_cnn(item_local_fea, item_attention, self.item_abs_cnn, self.item_fc)

        # ---- fusion with ID embeddings
        # Deliberately out-of-place: when dropout is a no-op (eval mode or
        # dropout_prob == 0), F.dropout can hand back its input tensor itself,
        # which is the ReLU output autograd saved for backward. An in-place
        # `+=` would then make backward() fail with an in-place modification error.
        uid_emb = self.uid_embedding(u_ids)
        iid_emb = self.iid_embedding(i_ids)
        user_doc_fea = user_doc_fea + uid_emb.squeeze(1)
        item_doc_fea = item_doc_fea + iid_emb.squeeze(1)

        concat_latent = torch.cat((user_doc_fea, item_doc_fea), dim=1)  # (batch, 2 * feature_dim)

        # NFM scoring head
        return self.fm(F.relu(concat_latent))

    def local_attention_cnn(self, word_embs, doc_cnn):
        """Gate each word by a local-context score, then apply a document CNN.

        Args:
            word_embs: (batch, doc_len, word_dim) embedded review.
            doc_cnn: The side-specific document-level ``nn.Conv2d``.

        Returns:
            (batch, kernel_count, doc_len, 1) feature map.
        """
        # (batch, 1, doc_len, 1) -> (batch, doc_len, 1): one sigmoid gate per word.
        local_att_words = self.word_cnn(word_embs.unsqueeze(1))
        local_word_weight = torch.sigmoid(local_att_words.squeeze(1))
        word_embs = word_embs * local_word_weight  # broadcast over word_dim
        return doc_cnn(word_embs.unsqueeze(1))

    def local_pooling_cnn(self, feature, attention, cnn, fc):
        """Apply attention-weighted windowed pooling, then the abstract CNN and FC.

        Args:
            feature: (batch, kernel_count, doc_len, 1) output of local_attention_cnn.
            attention: (batch, doc_len) per-position attention weights.
            cnn: Side-specific abstract-level ``nn.Conv2d``.
            fc: Side-specific ``nn.Linear`` (kernel_count -> feature_dim).

        Returns:
            (batch, feature_dim) features after ReLU and dropout.
        """
        bs, n_filters, doc_len, _ = feature.shape
        feature = feature.permute(0, 3, 2, 1)  # (bs, 1, doc_len, n_filters)
        attention = attention.reshape(bs, 1, doc_len, 1)
        pools = feature * attention  # weight every position

        # Unfold gives (bs, kernel_size * n_filters, doc_len): one window of
        # kernel_size rows per position. Sum the rows of each window.
        pools = self.unfold(pools)
        pools = pools.reshape(bs, self.kernel_size, n_filters, doc_len)
        pools = pools.sum(dim=1, keepdim=True)  # (bs, 1, n_filters, doc_len)
        pools = pools.transpose(2, 3)  # (bs, 1, doc_len, n_filters)

        # (bs, kernel_count, doc_len - kernel_size + 1), then average over positions.
        abs_fea = cnn(pools).squeeze(3)
        abs_fea = F.avg_pool1d(abs_fea, abs_fea.size(2))  # (bs, kernel_count, 1)
        abs_fea = F.relu(fc(abs_fea.squeeze(2)))  # (bs, feature_dim)

        return F.dropout(abs_fea, p=self.dropout_prob, training=self.training)