import torch
from torch import nn
import torch.nn.functional as F


class CNN(nn.Module):
    """Text encoder: Conv1d + ReLU, max-over-time pooling, then Linear + Tanh.

    Args:
        config: Object supplying ``word_dim`` (input channels of the
            convolution), ``kernel_count`` (number of convolution filters)
            and ``feature_dim`` (width of the returned feature vector).
        kernel_size: Width of the 1-D convolution window.
    """

    def __init__(self, config, kernel_size):
        super().__init__()

        self.kernel_count = config.kernel_count
        # NOTE: the attribute name is misspelled ("kerenl"). It is kept as-is
        # because code outside this class may read it.
        self.kerenl_size = kernel_size

        # (k - 1) // 2 zeros on each side: the convolution output keeps the
        # input length when k is odd and is one step shorter when k is even.
        side_padding = (kernel_size - 1) // 2
        convolution = nn.Conv1d(
            config.word_dim,
            config.kernel_count,
            kernel_size,
            padding=side_padding,
        )
        self.conv = nn.Sequential(convolution, nn.ReLU())

        projection = nn.Linear(config.kernel_count, config.feature_dim)
        self.linear = nn.Sequential(projection, nn.Tanh())

    def forward(self, vec):
        """Encode a batch of embedded sequences into fixed-size vectors.

        Args:
            vec: Rank-3 tensor laid out as (batch, sequence_length, word_dim).

        Returns:
            Tensor of shape (batch, feature_dim), squashed by tanh.
        """
        # Conv1d expects channels before length: (batch, word_dim, length).
        channels_first = vec.permute(0, 2, 1)
        activations = self.conv(channels_first)  # (batch, kernel_count, length')

        # Max over time: a single pooling window spanning the whole length
        # axis leaves one value per filter.
        window = activations.size(2)
        pooled = F.max_pool1d(activations, kernel_size=window)  # (batch, kernel_count, 1)

        features = pooled.squeeze(2)  # (batch, kernel_count)
        return self.linear(features)


class FactorizationMachine(nn.Module):
    """Second-order factorization machine producing one score per input row.

    The score is ``linear(x) + 0.5 * dropout(pairwise)``, where::

        pairwise = sum_f ((x @ v)_f ** 2 - ((x ** 2) @ (v ** 2))_f)

    Half of ``pairwise`` equals the sum over feature pairs ``i < j`` of
    ``<v_i, v_j> * x_i * x_j``, computed in O(p * k) instead of O(p^2 * k).

    Args:
        p: Number of input features (size of the last dimension of ``x``).
        k: Length of the latent factor vector kept for each feature.
        dropout: Drop probability applied to the summed pairwise term.
            Defaults to 0.5, the value that used to be hard-coded.
    """

    def __init__(self, p, k, dropout=0.5):  # p=cnn_out_dim
        super().__init__()

        # Factor matrix drawn uniformly from [0, 0.1).
        self.v = nn.Parameter(torch.rand(p, k) / 10)
        self.linear = nn.Linear(p, 1, bias=True)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Score each feature vector in ``x``.

        Args:
            x: Tensor of shape (..., p). Any number of leading batch
                dimensions is accepted; the usual case is (batch_size, p).

        Returns:
            Tensor of shape (..., 1).
        """
        linear_part = self.linear(x)

        # matmul (rather than mm) and dim=-1 let the same code handle inputs
        # with extra leading dimensions; for 2-D input nothing changes.
        square_of_sum = torch.matmul(x, self.v) ** 2
        sum_of_square = torch.matmul(x ** 2, self.v ** 2)

        pair_interactions = torch.sum(
            square_of_sum - sum_of_square, dim=-1, keepdim=True
        )
        # Dropout acts on the already-summed interaction term, i.e. it drops
        # (or rescales) the whole pairwise contribution of a row at once.
        pair_interactions = self.dropout(pair_interactions)

        return linear_part + 0.5 * pair_interactions
    

class SourceNet(nn.Module):
    """Two-tower network: user and item text encoders, a transform MLP, and an FM head.

    Args:
        config: Object supplying ``kernel_size``, ``feature_dim`` and
            ``dropout_prob`` here, plus whatever ``CNN`` reads from it.
        word_emb: Pretrained word-embedding matrix; converted with
            ``torch.Tensor`` and loaded via ``nn.Embedding.from_pretrained``.
        with_idemb: Stored as ``self.extend_model``; not otherwise used
            inside this class.
    """

    def __init__(self, config, word_emb, with_idemb=False):
        super().__init__()

        self.extend_model = with_idemb
        self.embedding = nn.Embedding.from_pretrained(torch.Tensor(word_emb))

        # Separate encoders (no weight sharing) for the user and item text.
        self.cnn_u = CNN(config, config.kernel_size)
        self.cnn_i = CNN(config, config.kernel_size)

        # Maps the concatenated (user, item) features back to feature_dim.
        width = config.feature_dim
        self.transform = nn.Sequential(
            nn.Linear(2 * width, width),
            nn.Tanh(),
            nn.Linear(width, width),
            nn.Tanh(),
            nn.Dropout(p=config.dropout_prob),
        )

        self.fm = FactorizationMachine(p=width, k=8)

    def forward(self, user_reviews, item_reviews):
        """Return the transformed joint representation and the FM prediction.

        Args:
            user_reviews: Token-index tensor for the user text. Presumably
                (batch_size, sequence_length): the embedded result goes
                straight into ``CNN``, which needs a rank-3 input.
            item_reviews: Token-index tensor for the item text, same layout.

        Returns:
            ``(trans_latent, prediction)``: the (batch_size, feature_dim)
            output of ``self.transform`` and the FM output computed from it.
        """
        user_latent = self.cnn_u(self.embedding(user_reviews))
        item_latent = self.cnn_i(self.embedding(item_reviews))

        joint = torch.cat((user_latent, item_latent), dim=1)  # (batch_size, 2 * feature_dim)
        trans_latent = self.transform(joint)

        # The FM sees a detached copy, so gradients of `prediction` reach only
        # the FM parameters and never the encoders or the transform.
        prediction = self.fm(trans_latent.detach())

        return trans_latent, prediction

    def trans_param(self):
        """List the parameters of both CNN encoders and the transform MLP.

        The embedding and the FM head are not included.
        """
        return [
            *self.cnn_u.parameters(),
            *self.cnn_i.parameters(),
            *self.transform.parameters(),
        ]



class TargetNet(nn.Module):
    """Single-tower network: embedding -> CNN encoder -> dropout -> FM head.

    Args:
        config: Object supplying ``dropout_prob`` and ``feature_dim`` here,
            plus whatever ``CNN`` reads from it.
        word_emb: Pretrained word-embedding matrix; converted with
            ``torch.Tensor`` and loaded via ``nn.Embedding.from_pretrained``.
    """

    def __init__(self, config, word_emb):
        super().__init__()

        pretrained = torch.Tensor(word_emb)
        self.embedding = nn.Embedding.from_pretrained(pretrained)

        # The convolution width is fixed at 3 here rather than read from config.
        self.cnn = CNN(config, kernel_size=3)

        # CNN contains no dropout layer, so dropout is applied to its output
        # right before the factorization machine.
        head = [
            nn.Dropout(config.dropout_prob),
            FactorizationMachine(p=config.feature_dim, k=8),
        ]
        self.fm = nn.Sequential(*head)

    def forward(self, reviews):
        """Return the CNN representation of ``reviews`` and the FM prediction.

        Args:
            reviews: Token-index tensor, (batch_size, review_length) per the
                original annotation.

        Returns:
            ``(cnn_latent, prediction)``: the encoder output and the FM score
            computed from it (after dropout).
        """
        cnn_latent = self.cnn(self.embedding(reviews))
        return cnn_latent, self.fm(cnn_latent)