import torch
import torch.nn as nn
import torch.nn.functional as F

# Factorization Machine (FM) scoring layer
class FactorizationMachine(nn.Module):
    """Second-order factorization machine producing one score per sample.

    The output is ``linear(x) + sum_{i<j} <v_i, v_j> * x_i * x_j``, where
    ``v_i`` is row ``i`` of the latent matrix ``V``. The pairwise term is
    computed in O(p * k) time via the identity
    ``2 * sum_{i<j} <v_i, v_j> x_i x_j
    = sum_f [(x @ V)_f ** 2 - (x**2 @ V**2)_f]``.

    Args:
        p: Number of input features (size of the last input dimension).
        k: Rank of the latent factor matrix ``V`` (shape ``(p, k)``).
        dropout: Dropout probability applied to the summed pairwise
            interaction term (default 0.5, the previously hard-coded value).
    """

    def __init__(self, p, k=10, dropout=0.5):
        super().__init__()
        # Latent factors initialised uniformly in [0, 0.1).
        self.v = nn.Parameter(torch.rand(p, k) / 10)
        self.linear = nn.Linear(p, 1, bias=True)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Score ``x``.

        Args:
            x: Tensor of shape ``(..., p)``; any number of leading batch
                dimensions is accepted (e.g. ``(batch_size, p)``).

        Returns:
            Tensor of shape ``(..., 1)``.
        """
        linear_part = self.linear(x)

        # (sum_i v_if x_i)^2 - sum_i v_if^2 x_i^2 == 2 * sum_{i<j} v_if v_jf x_i x_j
        square_of_sum = torch.matmul(x, self.v) ** 2
        sum_of_square = torch.matmul(x ** 2, self.v ** 2)

        pair_interactions = torch.sum(
            square_of_sum - sum_of_square, dim=-1, keepdim=True
        )
        pair_interactions = self.dropout(pair_interactions)

        # The 0.5 undoes the factor of 2 from the identity above.
        return linear_part + 0.5 * pair_interactions


# DeepCoNN 

class DeepCoNN(nn.Module):
    """Score a (user, item) pair from their word-id documents.

    Both documents are embedded with one shared, frozen pretrained embedding
    table. Each document then goes through its own tower: 1-D convolution,
    ReLU, global max pooling over the sequence, a fully connected layer, ReLU
    and dropout. The two tower outputs are concatenated and scored by a
    ``FactorizationMachine``.

    Args:
        config: Object exposing ``word_dim``, ``kernel_count``,
            ``kernel_size``, ``feature_dim`` and ``dropout_prob``.
        word_emb: Pretrained embedding matrix (array-like or tensor),
            presumably ``(vocab_size, word_dim)``; it is copied as float32.
    """

    def __init__(self, config, word_emb):
        super().__init__()

        self.word_dim = config.word_dim
        self.kernel_count = config.kernel_count
        self.kernel_size = config.kernel_size
        self.feature_dim = config.feature_dim
        self.dropout_prob = config.dropout_prob

        # as_tensor accepts numpy arrays, nested lists and tensors of any
        # float dtype; clone() guarantees the frozen weight never shares
        # memory with the caller's buffer.
        weights = torch.as_tensor(word_emb, dtype=torch.float32).clone()
        self.embedding = nn.Embedding.from_pretrained(weights, freeze=True)

        # Preserves sequence length for odd kernel sizes; the exact length
        # does not matter because of the global max pooling in _encode.
        padding = (self.kernel_size - 1) // 2

        # NOTE: submodules are created in the original order so parameter
        # initialisation under a fixed seed and state_dict layout are unchanged.
        self.user_cnn = nn.Conv1d(
            in_channels=self.word_dim,
            out_channels=self.kernel_count,
            kernel_size=self.kernel_size,
            padding=padding,
        )
        self.item_cnn = nn.Conv1d(
            in_channels=self.word_dim,
            out_channels=self.kernel_count,
            kernel_size=self.kernel_size,
            padding=padding,
        )

        self.user_fc_linear = nn.Linear(self.kernel_count, self.feature_dim)
        self.item_fc_linear = nn.Linear(self.kernel_count, self.feature_dim)

        # A single Dropout module is shared by both towers (it has no parameters).
        self.dropout = nn.Dropout(p=self.dropout_prob)

        self.fm = FactorizationMachine(self.feature_dim * 2)

    def _encode(self, doc_ids, cnn, fc):
        """Run one tower: embed -> conv -> ReLU -> global max pool -> FC -> ReLU -> dropout."""
        emb = self.embedding(doc_ids)                 # (batch, seq_len, word_dim)
        feat = F.relu(cnn(emb.permute(0, 2, 1)))      # (batch, kernel_count, seq_len')
        feat = F.max_pool1d(feat, kernel_size=feat.size(2))  # (batch, kernel_count, 1)
        feat = feat.reshape(-1, self.kernel_count)    # (batch, kernel_count)
        return self.dropout(F.relu(fc(feat)))         # (batch, feature_dim)

    def forward(self, user_vocablist, item_vocablist):
        """Predict a score for each (user document, item document) pair.

        Args:
            user_vocablist: Word-id tensor of shape ``(batch, seq_len)``.
            item_vocablist: Word-id tensor of shape ``(batch, seq_len)``.

        Returns:
            Tensor of shape ``(batch, 1)`` produced by the FM layer.
        """
        u_fea = self._encode(user_vocablist, self.user_cnn, self.user_fc_linear)
        i_fea = self._encode(item_vocablist, self.item_cnn, self.item_fc_linear)

        concat_latent = torch.cat((u_fea, i_fea), dim=1)  # (batch, 2 * feature_dim)
        return self.fm(concat_latent)