import torch
import torch.nn as nn
from causally.model.utils import get_linear_layers
from causally.model.abstract_model import AbstractModel
from causally.model.utils import WassDistance

class BNN(AbstractModel):
    r"""Representation-balancing outcome model (BNN-style, cf. Johansson et al., 2016).

    A shared representation network (``repre_layers``) feeds two outcome heads:
    ``pred_layers_treated`` (used where ``t == 1``) and ``pred_layers_control``
    (used elsewhere). The objective computed by :meth:`calculate_loss` is

        factual loss
        + beta  * counterfactual loss against nearest-neighbour outcomes
        + alpha * ``WassDistance`` between the representations.

    Config keys read here: ``alpha``, ``beta``, ``bn``, ``repre_layer_sizes``,
    ``pred_layer_sizes``. ``self.config``, ``self.dataset``, ``self.loss_type``
    and ``self.device`` are not set in this class -- presumably provided by
    ``AbstractModel``; confirm there.
    """

    def __init__(self, config, dataset):
        super().__init__(config, dataset)
        self.in_feature = self.dataset.size[1]
        self.alpha = self.config['alpha']  # weight of the imbalance (Wasserstein) term
        self.beta = self.config['beta']  # weight of the nearest-neighbour counterfactual term
        self.bn = self.config['bn']
        self.repre_layer_sizes = self.config['repre_layer_sizes']
        self.pred_layer_sizes = self.config['pred_layer_sizes']

        # Optional BatchNorm on the raw covariates, then the representation layers.
        input_norm = [nn.BatchNorm1d(self.in_feature)] if self.bn else []
        self.repre_layers = nn.Sequential(
            *input_norm,
            *get_linear_layers(self.in_feature, self.repre_layer_sizes, self.bn, nn.ReLU),
        )

        # Treated head is built first so the parameter-initialisation order
        # (and therefore RNG consumption) stays treated -> control.
        self.pred_layers_treated = self._build_head('out1')
        self.pred_layers_control = self._build_head('out0')

        if self.loss_type == 'MSE':
            self.loss_fct = nn.MSELoss(reduction='none')
        elif self.loss_type == 'CE':
            self.loss_fct = nn.BCEWithLogitsLoss(reduction='none')
        else:
            raise NotImplementedError("Make sure 'loss_type' in ['MSE', 'CE']!")
        self.wass_distance = WassDistance(eps=0.01, max_iter=10, device=self.device)

    def _build_head(self, out_name):
        """Build one outcome head.

        The head is ``get_linear_layers`` over the representation (``bn`` flag
        False, ``nn.ReLU`` activation). It is followed by a 1-unit
        ``nn.Linear`` output layer registered under ``out_name``.
        """
        head = nn.Sequential(*get_linear_layers(self.repre_layer_sizes[-1],
                                                self.pred_layer_sizes, False, nn.ReLU))
        head.add_module(out_name, nn.Linear(self.pred_layer_sizes[-1], 1))
        return head

    def _outcomes(self, x):
        """Run the representation network once and both heads on its output.

        Caches the representation on ``self.repre``, as ``forward`` always has.

        Returns:
            tuple: ``(repre, treated_output, control_output)``.
        """
        self.repre = self.repre_layers(x)
        return (self.repre,
                self.pred_layers_treated(self.repre),
                self.pred_layers_control(self.repre))

    def forward(self, x, t):
        """Return the outcome prediction for treatment ``t``.

        The treated-head output is used where ``t == 1`` and the control-head
        output elsewhere (element-wise ``torch.where``; ``t`` must broadcast
        against the head outputs). The representation is cached on
        ``self.repre`` as a side effect.
        """
        _, treat_outputs, control_outputs = self._outcomes(x)
        return torch.where(t == 1, treat_outputs, control_outputs)

    def get_repre(self, x, device):
        """Return the representation of ``x`` computed on ``device`` without gradients.

        Side effects: switches the whole model to eval mode and moves
        ``repre_layers`` to ``device``. Neither change is undone afterwards.
        """
        self.eval()
        with torch.no_grad():
            return self.repre_layers.to(device)(x.to(device))

    def calculate_loss(self, x, t, y, w):
        """Compute the training loss for one batch.

        Args:
            x (torch.Tensor): covariate batch.
            t (torch.Tensor): treatment indicators (1 = treated; assumed binary).
            y (torch.Tensor): observed (factual) outcomes.
            w (torch.Tensor): squeezed and passed to ``WassDistance``
                (presumably per-sample weights -- confirm there).

        Returns:
            torch.Tensor: ``factual + beta * counterfactual + alpha * imbalance``.
        """
        y_n = self.construct_neighbor_comes(x, t)
        # A single pass serves both the factual and the counterfactual
        # prediction. Calling ``forward`` twice would run the representation
        # layers twice and update BatchNorm running statistics twice per step.
        repre, treat_outputs, control_outputs = self._outcomes(x)
        is_treated = t == 1
        y_pred = torch.where(is_treated, treat_outputs, control_outputs)
        # Counterfactual prediction: the head of the treatment NOT received.
        n_pred = torch.where(is_treated, control_outputs, treat_outputs)

        factual_loss = torch.mean(self.loss_fct(y_pred, y))
        counterfactual_loss = torch.mean(self.loss_fct(n_pred, y_n))
        imb_loss = self.wass_distance(repre, t, w.squeeze())
        return factual_loss + self.beta * counterfactual_loss + self.alpha * imb_loss

    def construct_neighbor_comes(self, x, t):
        """Return nearest-neighbour outcomes from the OPPOSITE treatment group.

        For every unit, this is the observed outcome of its nearest neighbour
        (Euclidean distance in covariate space) among the opposite group. The
        values are surrogate targets for the counterfactual head. A treated
        unit's counterfactual is its control outcome, so it is matched against
        control units, and vice versa. Matching inside the unit's own group
        would find the unit itself whenever it is part of the stored pool,
        returning its own factual outcome.

        Args:
            x (torch.Tensor): 2-D covariate batch on ``self.device``.
            t (torch.Tensor): treatment indicators; column 0 is compared with 1.

        Returns:
            torch.Tensor: shape ``(len(x), 1)`` on ``self.device``.
        """
        y_n = torch.zeros(len(x), 1, device=self.device)
        is_treated = (t[:, 0] == 1).to(self.device)
        groups = (
            # treated units -> nearest control unit
            (is_treated, self.dataset.x_control, self.dataset.y_control),
            # control units -> nearest treated unit
            (~is_treated, self.dataset.x_treated, self.dataset.y_treated),
        )
        with torch.no_grad():
            for mask, pool_x, pool_y in groups:
                if not mask.any():
                    continue
                pool_x = pool_x.to(device=self.device, dtype=x.dtype)
                # Materialises an (n_selected, n_pool) matrix of exact Euclidean
                # distances (no matmul shortcut, so near-ties are not perturbed).
                # argmin returns the first index on ties.
                distance = torch.cdist(x[mask], pool_x,
                                       compute_mode='donot_use_mm_for_euclid_dist')
                nearest = distance.argmin(dim=1)
                y_n[mask, 0] = pool_y.to(self.device)[nearest, 0].to(y_n.dtype)
        return y_n

    def predict(self, x, t):
        r"""Predict outcomes for covariates ``x`` under treatment ``t``.

        Args:
            x (torch.Tensor): covariate batch.
            t (torch.Tensor): treatment indicators; entries equal to 1 use the
                treated head, all others the control head.

        Returns:
            torch.Tensor: the raw head output when ``loss_type`` is ``'MSE'``.
            Otherwise it is passed through a sigmoid, matching training with
            ``BCEWithLogitsLoss``.
        """
        y = self.forward(x, t)
        if self.loss_type == 'MSE':
            return y
        else:
            return torch.sigmoid(y)
