import torch
import torch.nn as nn
import torch.nn.functional as F


# Neural_FactorizationMachine
class Neural_FactorizationMachine(nn.Module):
    """Neural Factorization Machine scoring head.

    For a 2-D input ``x`` of shape ``(batch, p)`` this returns a ``(batch, 1)``
    tensor computed as::

        h(dropout(relu(mlp(B(x))))) + linear(x)

    where ``B(x) = 0.5 * ((x @ V) ** 2 - (x ** 2) @ (V ** 2))`` is the FM
    bi-interaction vector over the ``(p, k)`` latent-factor matrix ``V``
    (powers are element-wise). Dropout (p=0.5) is applied only to the hidden
    layer output; the linear term and the interaction vector are not dropped.

    Args:
        p: Input feature width (e.g. the size of the fused CNN/ID features).
        k: Latent-factor dimension. Defaults to 16.
    """

    def __init__(self, p, k=16):
        super().__init__()

        # Latent factors: U[0, 1) scaled by 1/10, i.e. values in [0, 0.1).
        self.v = nn.Parameter(torch.rand(p, k) / 10)
        # First-order term with bias: (batch, p) -> (batch, 1).
        self.linear = nn.Linear(p, 1, bias=True)
        self.dropout = nn.Dropout(0.5)

        # One hidden layer on the interaction vector, then a bias-free projection.
        self.mlp = nn.Linear(k, k)
        self.h = nn.Linear(k, 1, bias=False)

    def forward(self, x):
        """Score ``x`` of shape ``(batch, p)``; returns shape ``(batch, 1)``."""
        factors = self.v

        # FM identity: pairwise interactions = 0.5 * (square-of-sum - sum-of-squares).
        square_of_sum = torch.mm(x, factors).pow(2)
        sum_of_squares = torch.mm(x.pow(2), factors.pow(2))
        interaction = (square_of_sum - sum_of_squares) * 0.5

        hidden = self.dropout(F.relu(self.mlp(interaction)))
        return self.h(hidden) + self.linear(x)



class NoAttn(nn.Module):
    """DeepCoNN-style rating model with user/item ID embeddings and no attention.

    Each side (user, item) encodes its review-word indices with a shared
    pretrained word embedding (``from_pretrained`` freezes it by default), a
    per-side 1-D CNN, max-over-time pooling and a fully connected layer. User
    and item IDs are looked up in trainable embeddings and projected to the
    same ``feature_dim``. Each of these four branches ends in ReLU + dropout.
    The branches are fused according to ``config.fusion_method`` and scored by
    a ``Neural_FactorizationMachine``:

    * ``'cat'``: concatenate ``[uid, user_doc, iid, item_doc]`` (``4 * feature_dim``).
    * ``'sum'``: concatenate ``[uid + user_doc, iid + item_doc]`` (``2 * feature_dim``).

    Args:
        config: Object exposing ``word_dim``, ``kernel_count``, ``kernel_size``,
            ``feature_dim``, ``dropout_prob``, ``user_num``, ``item_num`` and
            ``fusion_method``. ``user_num`` / ``item_num`` must match the dataset.
        word_emb: Pretrained word vectors accepted by ``torch.Tensor``. The CNN
            expects their width to equal ``config.word_dim`` (presumably shape
            ``(vocab_size, word_dim)`` -- confirm against the data loader).

    Raises:
        ValueError: If ``config.fusion_method`` is neither ``'cat'`` nor ``'sum'``.
    """

    _FUSION_METHODS = ('cat', 'sum')

    def __init__(self, config, word_emb):
        super().__init__()

        # Hyperparameters (typical values: word_dim=50, kernel_count=100,
        # kernel_size=3, feature_dim=50, dropout_prob=0.5).
        self.word_dim = config.word_dim
        self.kernel_count = config.kernel_count
        self.kernel_size = config.kernel_size
        self.feature_dim = config.feature_dim
        self.dropout_prob = config.dropout_prob

        # Embedding-table sizes; these must be set per dataset.
        self.user_num = config.user_num
        self.item_num = config.item_num

        self.fusion_method = config.fusion_method
        # Fail fast on an unknown method instead of building a mis-sized FM
        # head that can never produce a prediction in forward().
        if self.fusion_method not in self._FUSION_METHODS:
            raise ValueError(
                f"fusion_method must be one of {self._FUSION_METHODS}, "
                f"got {self.fusion_method!r}"
            )

        print(f'DeepCoNN with ID embedding')
        print(f'(idemb-setting) user_num : {self.user_num}, item_num : {self.item_num}')
        print(f'(fusion_method) {self.fusion_method}, (feature_dim) {self.feature_dim}')

        # ID lookup embeddings (width kernel_count; projected to feature_dim below).
        self.uid_embedding = nn.Embedding(self.user_num, self.kernel_count)
        self.iid_embedding = nn.Embedding(self.item_num, self.kernel_count)

        # Shared pretrained word embedding for both review documents.
        self.embedding = nn.Embedding.from_pretrained(torch.Tensor(word_emb))

        # One CNN per side. Padding (kernel_size - 1) // 2 keeps the sequence
        # length unchanged for odd kernel sizes.
        self.user_cnn = nn.Conv1d(
            in_channels=self.word_dim,
            out_channels=self.kernel_count,
            kernel_size=self.kernel_size,
            padding=(self.kernel_size - 1) // 2,
        )
        self.item_cnn = nn.Conv1d(
            in_channels=self.word_dim,
            out_channels=self.kernel_count,
            kernel_size=self.kernel_size,
            padding=(self.kernel_size - 1) // 2,
        )

        # Projections of pooled review features to feature_dim.
        self.user_fc_linear = nn.Linear(self.kernel_count, self.feature_dim)
        self.item_fc_linear = nn.Linear(self.kernel_count, self.feature_dim)

        # Projections of ID embeddings to feature_dim.
        self.u_id_fc_linear = nn.Linear(self.kernel_count, self.feature_dim)
        self.i_id_fc_linear = nn.Linear(self.kernel_count, self.feature_dim)

        self.dropout = nn.Dropout(p=self.dropout_prob)

        # 'cat' fuses four feature_dim vectors, 'sum' fuses two.
        n_fused = 4 if self.fusion_method == 'cat' else 2
        self.fm = Neural_FactorizationMachine(self.feature_dim * n_fused)

    def _encode_doc(self, vocablist, cnn, fc):
        """Encode word indices into a ReLU + dropout ``(batch, feature_dim)`` vector."""
        doc = self.embedding(vocablist)                    # (batch, seq_len, word_dim)
        fea = F.relu(cnn(doc.permute(0, 2, 1)))            # (batch, kernel_count, L)
        fea = F.max_pool1d(fea, kernel_size=fea.size(2))   # (batch, kernel_count, 1)
        return self.dropout(F.relu(fc(fea.reshape(-1, self.kernel_count))))

    def _encode_id(self, ids, emb, fc):
        """Look up ID embeddings and project them to a ReLU + dropout feature vector."""
        # squeeze(1) drops a size-1 dim 1, e.g. ids shaped (batch, 1).
        return self.dropout(F.relu(fc(emb(ids).squeeze(1))))

    def forward(self, uid, iid, user_vocablist, item_vocablist):
        """Predict one score per (user, item) pair in the batch.

        Args:
            uid: User-ID indices; presumably ``(batch,)`` or ``(batch, 1)``.
            iid: Item-ID indices, same convention as ``uid``.
            user_vocablist: Word indices of the user's review document,
                ``(batch, seq_len)``.
            item_vocablist: Word indices of the item's review document,
                ``(batch, seq_len)``.

        Returns:
            Tensor of shape ``(batch, 1)`` produced by the FM head.

        Raises:
            ValueError: If ``self.fusion_method`` is not ``'cat'`` or ``'sum'``.
        """
        # Dropout is applied in the same order as before (user doc, item doc,
        # user ID, item ID), so seeded runs reproduce the same masks.
        u_fea = self._encode_doc(user_vocablist, self.user_cnn, self.user_fc_linear)
        i_fea = self._encode_doc(item_vocablist, self.item_cnn, self.item_fc_linear)

        uid_emb = self._encode_id(uid, self.uid_embedding, self.u_id_fc_linear)
        iid_emb = self._encode_id(iid, self.iid_embedding, self.i_id_fc_linear)

        if self.fusion_method == 'cat':
            fused = torch.cat((uid_emb, u_fea, iid_emb, i_fea), dim=1)
        elif self.fusion_method == 'sum':
            # Out-of-place add: in eval mode dropout passes its input through,
            # so an in-place `+=` would overwrite the ReLU output autograd saved
            # for backward and make `.backward()` raise.
            fused = torch.cat((u_fea + uid_emb, i_fea + iid_emb), dim=1)
        else:
            raise ValueError(
                f"fusion_method must be one of {self._FUSION_METHODS}, "
                f"got {self.fusion_method!r}"
            )

        return self.fm(fused)