import torch
import torch.nn as nn
import torch.nn.functional as F


# Neural_FactorizationMachine
class Neural_FactorizationMachine(nn.Module):
    """Neural Factorization Machine prediction head.

    The output is the sum of a first-order linear term and a second-order
    bi-interaction term that is passed through one hidden layer:

        y = linear(x) + h(dropout(relu(mlp(bi(x)))))
        bi(x) = 0.5 * ((x V)^2 - (x^2)(V^2))

    Args:
        p: Size of the last dimension of the input features.
        k: Number of latent factors per input feature.
        dropout_prob: Drop probability applied after the hidden layer.
            Defaults to 0.5, the value that used to be hard-coded.
    """

    def __init__(self, p, k=16, dropout_prob=0.5):
        super().__init__()

        # torch.rand samples from [0, 1), so the factors start in [0, 0.1).
        # Modules are created in the original order so that seeded
        # initialisation stays reproducible.
        self.v = nn.Parameter(torch.rand(p, k) / 10)
        self.linear = nn.Linear(p, 1, bias=True)
        self.dropout = nn.Dropout(dropout_prob)

        self.mlp = nn.Linear(k, k)
        self.h = nn.Linear(k, 1, bias=False)

    def forward(self, x):
        """Score a batch of feature vectors.

        Args:
            x: Float tensor whose last dimension has size ``p``; typically
                ``(batch_size, p)``.

        Returns:
            Tensor with the same leading dimensions as ``x`` and a last
            dimension of size 1.
        """
        linear_part = self.linear(x)

        # torch.matmul gives the same result as torch.mm for 2-D input but
        # also accepts extra leading batch dimensions.
        square_of_sum = torch.matmul(x, self.v) ** 2
        sum_of_square = torch.matmul(x ** 2, self.v ** 2)
        bilinear = 0.5 * (square_of_sum - sum_of_square)

        hidden = self.dropout(F.relu(self.mlp(bilinear)))
        return self.h(hidden) + linear_part


# main_net
class Ournet(nn.Module):
    """Rating predictor combining review-word features with ID embeddings.

    For the user side and the item side alike, word embeddings are re-weighted
    by a mutual (global) attention and a local attention, passed through a
    1-D convolution with max pooling over the sequence axis, and projected to
    ``feature_dim``.  The result is fused with a projected user/item ID
    embedding and scored by a neural factorization machine.

    Args:
        config: Object exposing ``word_dim``, ``kernel_count``,
            ``kernel_size``, ``feature_dim``, ``dropout_prob``, ``top_k``,
            ``user_num``, ``item_num`` and ``fusion_method``
            (``'cat'`` or ``'sum'``).
        word_emb: Pretrained word-embedding table, converted with
            ``torch.Tensor(...)`` and loaded through
            ``nn.Embedding.from_pretrained`` with its default arguments
            (which keep the weights frozen).
    """

    def __init__(self, config, word_emb):
        super().__init__()

        # hyper-parameters
        self.word_dim = config.word_dim
        self.kernel_count = config.kernel_count
        self.kernel_size = config.kernel_size
        self.feature_dim = config.feature_dim
        self.dropout_prob = config.dropout_prob
        self.vocab_count = config.top_k

        self.user_num = config.user_num
        self.item_num = config.item_num

        self.fusion_method = config.fusion_method

        print('with ID embedding')
        print(f'(idemb-setting) user_num : {self.user_num}, item_num : {self.item_num}')
        print(f'(fusion_method) {self.fusion_method}, (feature_dim) {self.feature_dim}')

        # ID lookup tables; their width matches the CNN output so both feature
        # kinds go through same-sized projection layers below.
        self.uid_embedding = nn.Embedding(self.user_num, self.kernel_count)
        self.iid_embedding = nn.Embedding(self.item_num, self.kernel_count)

        # word embedding shared by the user and the item branch
        self.embedding = nn.Embedding.from_pretrained(torch.Tensor(word_emb))

        conv_padding = (self.kernel_size - 1) // 2
        self.user_cnn = nn.Conv1d(
            in_channels=self.word_dim,
            out_channels=self.kernel_count,
            kernel_size=self.kernel_size,
            padding=conv_padding,
        )
        self.item_cnn = nn.Conv1d(
            in_channels=self.word_dim,
            out_channels=self.kernel_count,
            kernel_size=self.kernel_size,
            padding=conv_padding,
        )

        self.localattn_u = LocalAttention(config)
        self.localattn_i = LocalAttention(config)

        self.user_fc_linear = nn.Linear(self.kernel_count, self.feature_dim)
        self.item_fc_linear = nn.Linear(self.kernel_count, self.feature_dim)

        self.u_id_fc_linear = nn.Linear(self.kernel_count, self.feature_dim)
        self.i_id_fc_linear = nn.Linear(self.kernel_count, self.feature_dim)

        self.dropout = nn.Dropout(p=self.dropout_prob)

        self.globalattention = GlobalAttetnion()

        # 'cat' concatenates four feature_dim-wide vectors; every other value
        # gets the two-vector width used by 'sum'.  Values other than
        # 'cat'/'sum' are rejected later, in forward().
        if self.fusion_method == 'cat':
            self.fm = Neural_FactorizationMachine(self.feature_dim * 4)
        else:
            self.fm = Neural_FactorizationMachine(self.feature_dim * 2)

    def forward(self, uid, iid, user_vocablist, item_vocablist):
        """Predict a score for each (user, item) pair in the batch.

        Args:
            uid: User indices for ``uid_embedding``
                (presumably shaped (batch, 1) or (batch,) - confirm with caller).
            iid: Item indices for ``iid_embedding`` (same layout as ``uid``).
            user_vocablist: Word indices describing each user
                (assumed (batch, vocab) - confirm with caller).
            item_vocablist: Word indices describing each item
                (assumed (batch, vocab) - confirm with caller).

        Returns:
            Tuple ``(prediction, score_u, score_i, first_u_gattn,
            first_i_gattn)``: the factorization-machine output followed by
            the local attention scores and the mutual attention weights of
            the user and item side.

        Raises:
            ValueError: If ``fusion_method`` is neither ``'cat'`` nor ``'sum'``.
        """
        # word indices -> word embeddings, (batch, vocab, word_dim)
        user_vocab = self.embedding(user_vocablist)
        item_vocab = self.embedding(item_vocablist)

        # local attention scores, computed on the raw embeddings
        score_u = self.localattn_u(user_vocab)
        score_i = self.localattn_i(item_vocab)

        # mutual attention between the user words and the item words
        first_u_gattn, first_i_gattn = self.globalattention(user_vocab, item_vocab)

        # re-weight every word vector: mutual attention first, then local
        gattn_user_vocab = user_vocab.mul(first_u_gattn.unsqueeze(-1))
        gattn_item_vocab = item_vocab.mul(first_i_gattn.unsqueeze(-1))

        attn_user_vocab = gattn_user_vocab.mul(score_u)
        attn_item_vocab = gattn_item_vocab.mul(score_i)

        # Conv1d expects channels first: (batch, word_dim, vocab)
        u_fea = F.relu(self.user_cnn(attn_user_vocab.permute(0, 2, 1)))
        i_fea = F.relu(self.item_cnn(attn_item_vocab.permute(0, 2, 1)))

        # max pooling over the whole sequence axis -> (batch, kernel_count, 1)
        u_fea = F.max_pool1d(u_fea, kernel_size=u_fea.size(2))
        i_fea = F.max_pool1d(i_fea, kernel_size=i_fea.size(2))

        u_fea = self.dropout(F.relu(self.user_fc_linear(u_fea.reshape(-1, self.kernel_count))))
        i_fea = self.dropout(F.relu(self.item_fc_linear(i_fea.reshape(-1, self.kernel_count))))

        # ID embeddings projected to the same width as the text features
        uid_emb = self.uid_embedding(uid)
        iid_emb = self.iid_embedding(iid)

        uid_emb = self.dropout(F.relu(self.u_id_fc_linear(uid_emb.squeeze(1))))
        iid_emb = self.dropout(F.relu(self.i_id_fc_linear(iid_emb.squeeze(1))))

        if self.fusion_method == 'cat':
            concat_latent = torch.cat((uid_emb, u_fea, iid_emb, i_fea), dim=1)
        elif self.fusion_method == 'sum':
            # Out-of-place add: when dropout is a no-op (eval mode or p == 0)
            # u_fea/i_fea are the ReLU outputs that autograd keeps for the
            # backward pass, and `+=` would overwrite them in place.
            u_fea = u_fea + uid_emb
            i_fea = i_fea + iid_emb
            concat_latent = torch.cat((u_fea, i_fea), dim=1)
        else:
            # Previously this only printed a message and then failed on the
            # undefined `concat_latent` below.
            raise ValueError(
                f"Unknown fusion_method {self.fusion_method!r}; expected 'cat' or 'sum'"
            )

        # factorization machine on the fused representation
        prediction = self.fm(concat_latent)

        return prediction, score_u, score_i, first_u_gattn, first_i_gattn
    

# LocalAttention
class LocalAttention(nn.Module):
    """Scores each position of a sequence from a sliding window of rows.

    A single-channel 2-D convolution whose kernel spans ``config.kernel_size``
    consecutive rows and the full ``config.word_dim`` width yields one value
    per position; the values go through ReLU and are then normalised with a
    softmax along the sequence axis.

    Args:
        config: Object exposing ``kernel_size`` and ``word_dim``.
    """

    def __init__(self, config):
        super().__init__()

        window = config.kernel_size
        scorer = nn.Conv2d(
            1,
            1,
            kernel_size=(window, config.word_dim),
            # pad the sequence axis only; the width axis collapses to size 1
            padding=((window - 1) // 2, 0),
        )
        # softmax over dim=-2, i.e. the sequence axis of (bs, 1, vocab, 1)
        self.att_conv = nn.Sequential(scorer, nn.ReLU(), nn.Softmax(dim=-2))

    def forward(self, x):
        """Return attention scores for ``x``.

        Args:
            x: Tensor of shape (bs, vocab, word_dim).

        Returns:
            Scores with the channel axis removed again, i.e. (bs, vocab, 1)
            when ``kernel_size`` is odd.
        """
        # add the single input channel Conv2d expects: (bs, 1, vocab, word_dim)
        as_image = x.unsqueeze(1)
        return self.att_conv(as_image).squeeze(1)
    
# GlobalAttetnion
class GlobalAttetnion(nn.Module):
    """Parameter-free mutual attention between two sequences of vectors.

    Every vector of ``u`` is compared with every vector of ``i`` by Euclidean
    distance; the affinity ``1 / (1 + distance)`` is summed per position and
    normalised with a softmax to give one weight per position of each input.
    """

    def forward(self, u, i):
        """Compute the attention weights of both inputs.

        Args:
            u: Tensor of shape (batch, len_u, dim).
            i: Tensor of shape (batch, len_i, dim).

        Returns:
            Tuple ``(u_gattn, i_gattn)`` with shapes (batch, len_u) and
            (batch, len_i).
        """
        # (batch, dim, 1, len_u) against (batch, dim, len_i, 1) so that the
        # subtraction broadcasts over all (i position, u position) pairs
        u_cols = u.permute(0, 2, 1).unsqueeze(2)
        i_rows = i.permute(0, 2, 1).unsqueeze(3)

        # affinity has shape (batch, len_i, len_u)
        affinity = (self.get_distance(u_cols, i_rows) + 1).reciprocal()

        u_gattn = affinity.sum(dim=1).softmax(dim=1)
        i_gattn = affinity.sum(dim=2).softmax(dim=1)
        return u_gattn, i_gattn

    def get_distance(self, permute_u, permute_i):
        """Return the Euclidean distance of the broadcast inputs along dim 1."""
        # chained so the large intermediate tensors are released as soon as
        # the next operation has consumed them
        return (permute_u - permute_i).pow(2).sum(dim=1).sqrt()


