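r"""
NeuMF_IPM
################################################
NeuMF-style backbone with separate treated/control prediction heads and an
MMD imbalance penalty computed on the shared representation.

Reference:
    Xiangnan He et al. "Neural Collaborative Filtering." in WWW 2017.
"""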

import torch
import torch.nn as nn
from torch.nn.init import normal_

from causally.model.recommender.abstract_recommender import AbstractRecommender
# from recbole.model.layers import MLPLayers
from causally.model.utils import MLPLayers
from causally.model.utils import MMDDistance

class NeuMF_IPM(AbstractRecommender):
    """NeuMF-style outcome model with treated/control heads and an MMD penalty.

    User covariates and item covariates are each projected by linear layers
    into two branches: a matrix-factorisation branch (element-wise product of
    the two projections) and an MLP branch (concatenated projections fed to
    ``MLPLayers``). The two branch outputs are concatenated into a shared
    representation that feeds two prediction heads; the treatment indicator
    ``t`` selects, element-wise, which head's output is returned.

    The training loss is the summed squared error plus ``alpha`` times the
    value returned by ``MMDDistance`` for the shared representation and ``t``.

    Config keys read: ``alpha``, ``mf_embedding_size``,
    ``mlp_embedding_size``, ``mlp_hidden_size``, ``dropout_prob``.
    """

    def __init__(self, config, dataset):
        """Read hyper-parameters from ``config`` and build all sub-modules.

        Args:
            config: Mapping-like object providing the keys listed in the
                class docstring.
            dataset: Forwarded unchanged to the parent constructor.
        """
        super().__init__(config, dataset)

        # --- hyper-parameters -------------------------------------------
        self.alpha = config['alpha']  # weight of the imbalance term in the loss
        self.mf_embedding_size = config["mf_embedding_size"]
        self.mlp_embedding_size = config["mlp_embedding_size"]
        self.mlp_hidden_size = config["mlp_hidden_size"]
        self.dropout_prob = config["dropout_prob"]

        # --- covariate projections --------------------------------------
        # x_n_covariate / v_n_covariate are not assigned in this class;
        # presumably the parent constructor sets them -- TODO confirm.
        user_in = self.x_n_covariate
        item_in = self.v_n_covariate
        self.user_mf_embedding = nn.Linear(user_in, self.mf_embedding_size)
        self.item_mf_embedding = nn.Linear(item_in, self.mf_embedding_size)
        self.user_mlp_embedding = nn.Linear(user_in, self.mlp_embedding_size)
        self.item_mlp_embedding = nn.Linear(item_in, self.mlp_embedding_size)

        # --- MLP branch over the concatenated user/item projections -----
        mlp_sizes = [2 * self.mlp_embedding_size] + self.mlp_hidden_size
        self.mlp_layers = MLPLayers(mlp_sizes, self.dropout_prob)
        self.mlp_layers.logger = None  # remove logger to use torch.save()

        # --- one prediction head per treatment arm ----------------------
        # Both heads consume the shared representation (MF output
        # concatenated with the last MLP layer) and share the same layer
        # sizes, but hold separate parameters.
        head_in = self.mf_embedding_size + self.mlp_hidden_size[-1]
        self.treat_predict_layer = MLPLayers([head_in, 64, 1])
        self.control_predict_layer = MLPLayers([head_in, 64, 1])

        # --- loss components --------------------------------------------
        # reduction='none' keeps per-element errors; calculate_loss sums them.
        self.mse_loss = nn.MSELoss(reduction='none')
        self.mmd_distance = MMDDistance()

        # The `self.apply(self._init_weights)` call is disabled, so
        # `_init_weights` is never invoked from this constructor.

    def _init_weights(self, module):
        """Initialise ``nn.Embedding`` weights from N(0, 0.01); ignore others.

        NOTE(review): this class builds ``nn.Linear`` projections rather than
        ``nn.Embedding`` layers, so none of the layers created directly in
        ``__init__`` match the type check below.
        """
        if not isinstance(module, nn.Embedding):
            return
        normal_(module.weight.data, mean=0.0, std=0.01)

    def forward(self, user, item, t):
        """Return the prediction of the head selected by ``t``.

        Side effect: the shared representation is stored on ``self.repre``
        (``calculate_loss`` reads it afterwards).

        Args:
            user: Input to the user projection layers.
            item: Input to the item projection layers.
            t: Treatment indicator; entries equal to 1 take the treated
                head's output, all other entries take the control head's.

        Returns:
            The element-wise selection between the two heads' outputs.
        """
        is_treated = t == 1

        # Matrix-factorisation branch: element-wise product of projections.
        mf_part = torch.mul(
            self.user_mf_embedding(user), self.item_mf_embedding(item)
        )

        # MLP branch: concatenate projections on the last axis, then MLP.
        mlp_input = torch.cat(
            (self.user_mlp_embedding(user), self.item_mlp_embedding(item)), -1
        )
        mlp_part = self.mlp_layers(mlp_input)

        # Shared representation used by both heads and by the MMD term.
        self.repre = torch.cat((mf_part, mlp_part), -1)

        # Both heads are always evaluated; torch.where then picks per element.
        # NOTE(review): torch.where broadcasts its arguments -- assumes `t`
        # has a shape compatible with the heads' outputs; TODO confirm.
        treated_out = self.treat_predict_layer(self.repre)
        control_out = self.control_predict_layer(self.repre)
        return torch.where(is_treated, treated_out, control_out)

    def calculate_loss(self, x, t, y, v):
        """Return summed squared error plus ``alpha`` times the MMD term.

        Args:
            x: User-side input passed to ``forward`` as ``user``.
            t: Treatment indicator.
            y: Target compared against the model output.
            v: Item-side input passed to ``forward`` as ``item``.

        Returns:
            The combined loss value.
        """
        prediction = self.forward(x, v, t)
        # NOTE(review): element-wise MSE broadcasts too -- assumes `y` has
        # the same shape as the prediction; TODO confirm.
        squared_error = self.mse_loss(prediction, y).sum()
        # self.repre was refreshed by the forward call just above.
        imbalance = self.mmd_distance(self.repre, t, None)
        return squared_error + self.alpha * imbalance

    def predict(self, data):
        """Run ``forward`` on ``data["x"]``, ``data["v"]`` and ``data["t"]``."""
        return self.forward(data["x"], data["v"], data["t"])

