import torch
import torch.nn as nn
import torch.nn.functional as F
from causally.model.utils import BayesianLinear
from causally.model.abstract_model import AbstractModel
from causally.model.utils import MMDDistance

class UITE_MMD(AbstractModel):
    r"""Two-headed Bayesian outcome model with an MMD balancing penalty.

    A shared stack of ``BayesianLinear`` layers maps the input ``x`` to a
    representation. Two separate ``BayesianLinear`` heads (treated / control)
    map that representation to a single output. The per-sample output is
    selected with ``t == 1``.

    Training loss (see :meth:`calculate_loss`) is::

        (log q(w) - log p(w)) / batch_size + sum(loss_fct(out, y) * w)
            + alpha * MMDDistance(repre, t, w)

    Config keys read: ``alpha``, ``beta``, ``bn``, ``repre_layer_sizes``,
    ``pred_layer_sizes``, ``train_batch_size`` and, in :meth:`predict`, ``k``.

    NOTE(review): no activation functions appear between the
    ``BayesianLinear`` layers here. Unless ``BayesianLinear`` applies one
    internally, each stack is a purely linear map. Confirm this is intended.
    """

    def __init__(self, config, dataset):
        super().__init__(config, dataset)
        self.in_feature = self.dataset.size[1]
        self.alpha = self.config['alpha']
        # NOTE(review): ``beta`` and ``bn`` are read from config but not used
        # anywhere in this class.
        self.beta = self.config['beta']
        self.bn = self.config['bn']
        self.repre_layer_sizes = [self.in_feature] + self.config['repre_layer_sizes']
        # Each head consumes the final representation width and ends in one output unit.
        self.pred_layer_sizes = ([self.config['repre_layer_sizes'][-1]]
                                 + self.config['pred_layer_sizes'] + [1])
        self.batch_size = self.config['train_batch_size']

        # Built in the same order as before (repre, treated, control) so
        # parameter initialisation consumes the RNG identically.
        self.repre_layers = self._build_stack(self.repre_layer_sizes)
        self.pred_layers_treated = self._build_stack(self.pred_layer_sizes)
        self.pred_layers_control = self._build_stack(self.pred_layer_sizes)

        # reduction='none' so per-sample weights ``w`` can be applied in calculate_loss.
        if self.loss_type == 'MSE':
            self.loss_fct = nn.MSELoss(reduction='none')
        elif self.loss_type == 'CE':
            self.loss_fct = nn.BCEWithLogitsLoss(reduction='none')
        else:
            raise NotImplementedError("Make sure 'loss_type' in ['MSE', 'CE']!")
        self.mmd_distance = MMDDistance()

    @staticmethod
    def _build_stack(sizes):
        """Return an ``nn.Sequential`` of ``BayesianLinear`` layers, one per consecutive pair in *sizes*."""
        return nn.Sequential(*[
            BayesianLinear(in_s, out_s) for in_s, out_s in zip(sizes[:-1], sizes[1:])
        ])

    def forward(self, x, t):
        """Run one forward pass and collect the variational terms.

        Args:
            x: model input, fed to ``repre_layers``.
            t: treatment indicator; rows where ``t == 1`` take the treated head,
                all others take the control head.

        Returns:
            tuple: ``(outputs, log_variational_posterior, log_prior)``.

        Side effect: stores the representation in ``self.repre`` (read by
        :meth:`calculate_loss`).
        """
        self.repre = self.repre_layers(x)
        # Both heads are evaluated on every row, and torch.where picks per row.
        # NOTE(review): the heads output a trailing dim of size 1. If ``t`` has a
        # different rank (e.g. shape (batch,)), torch.where broadcasts to
        # (batch, batch). Confirm that ``t`` is shaped like the head output.
        outputs = torch.where(t == 1, self.pred_layers_treated(self.repre),
                              self.pred_layers_control(self.repre))
        log_prior = self.log_prior()
        log_variational_posterior = self.log_variational_posterior()

        return outputs, log_variational_posterior, log_prior

    def get_repre(self, x, device):
        """Return the representation of ``x`` computed without gradient tracking.

        NOTE(review): this switches the whole model to eval mode and does not
        restore the previous mode. It also moves only ``repre_layers`` (not the
        prediction heads) to ``device``. Both effects persist after the call.
        """
        self.eval()
        with torch.no_grad():
            return self.repre_layers.to(device)(x.to(device))

    def calculate_loss(self, x, t, y, w):
        """Return the training loss for one batch.

        The loss is the variational term ``(log q - log p) / batch_size``, plus
        the weighted summed per-sample loss, plus ``alpha`` times the MMD
        penalty on the representation stored by :meth:`forward`.

        NOTE(review): the KL term is divided by the configured batch size. It is
        not divided by the number of minibatches, and it does not adapt to a
        smaller final batch. Confirm this scaling is intended.
        """
        outputs, log_variational_posterior, log_prior = self.forward(x, t)
        negative_log_likelihood = torch.sum(self.loss_fct(outputs, y) * w)
        vlb_loss = (log_variational_posterior - log_prior) / self.batch_size + negative_log_likelihood
        # self.repre was set by the forward() call just above.
        imb_loss = self.mmd_distance(self.repre, t, w)

        loss = vlb_loss + self.alpha * imb_loss
        return loss

    def _sum_layer_attr(self, attr):
        """Sum attribute ``attr`` over every layer in the representation and both prediction stacks."""
        stacks = (self.repre_layers, self.pred_layers_treated, self.pred_layers_control)
        # sum() starts from 0, matching the original accumulator initialisation.
        return sum(getattr(module, attr) for stack in stacks for module in stack)

    def log_prior(self):
        """Return the sum of ``log_prior`` over all ``BayesianLinear`` layers."""
        return self._sum_layer_attr('log_prior')

    def log_variational_posterior(self):
        """Return the sum of ``log_variational_posterior`` over all ``BayesianLinear`` layers."""
        return self._sum_layer_attr('log_variational_posterior')

    def predict(self, x, t):
        r"""Run ``config['k']`` forward passes and stack their outputs.

        Args:
            x: model input, passed to :meth:`forward`.
            t: treatment indicator, passed to :meth:`forward`.

        Returns:
            torch.Tensor: the ``k`` outputs of :meth:`forward` stacked along a new
            trailing dimension. They are returned as-is when ``loss_type`` is
            ``'MSE'``, and passed through ``torch.sigmoid`` otherwise.

        Raises:
            ValueError: if ``config['k']`` is less than 1.
        """
        n_samples = self.config['k']
        if n_samples < 1:
            raise ValueError(
                f"config['k'] must be >= 1 to draw predictive samples, got {n_samples}")
        # Collect first and stack once. Repeated torch.cat in a loop is quadratic in k.
        samples = [self.forward(x, t)[0] for _ in range(n_samples)]
        ret = torch.stack(samples, dim=-1)
        if self.loss_type == 'MSE':
            return ret
        return torch.sigmoid(ret)

