import torch
import torch.nn as nn
from causally.model.utils import get_linear_layers
from causally.model.abstract_model import AbstractModel
from causally.model.utils import SinkhornDistance

class DeRCFR(AbstractModel):
    r"""DeR-CFR: counterfactual regression with decomposed representations.

    The covariates are encoded by three independent representation branches
    (``delta``, ``tau`` and ``gamma``). Each is an optional input BatchNorm
    followed by an MLP built with ``get_linear_layers``:

    * ``delta`` feeds both the outcome heads and the treatment head;
    * ``gamma`` feeds only the outcome heads;
    * ``tau`` feeds only the treatment head.

    NOTE(review): these roles presumably correspond to the confounder,
    adjustment and instrumental factors of the DeR-CFR paper -- confirm.

    Config keys read here: ``bn``, ``repre_layer_sizes``, ``pred_layer_sizes``,
    ``ipm_weight``, ``dercfr_weight`` and ``dataset``.
    """

    def __init__(self, config, dataset):
        super(DeRCFR, self).__init__(config, dataset)
        self.in_feature = self.dataset.size[1]
        self.bn = self.config['bn']
        self.repre_layer_sizes = self.config['repre_layer_sizes']
        self.pred_layer_sizes = self.config['pred_layer_sizes']
        self.ipm_weight = self.config['ipm_weight']
        self.dercfr_weight = self.config['dercfr_weight']

        # Modules are created in a fixed order (delta, tau, gamma, then the
        # heads) so random initialisation and parameter ordering stay stable.
        self.rep_layer_list_delta = self._build_rep_branch()
        self.rep_layer_list_tau = self._build_rep_branch()
        self.rep_layer_list_gamma = self._build_rep_branch()
        # Each head consumes the concatenation of two representations.
        self.mlp_t = self._build_pred_head()  # outcome head, treated units
        self.mlp_c = self._build_pred_head()  # outcome head, control units
        self.mlp_x = self._build_pred_head()  # treatment-logit head
        self.output_layer_t = nn.Linear(self.pred_layer_sizes[-1], 1)
        self.output_layer_x = nn.Linear(self.pred_layer_sizes[-1], 1)
        self.output_layer_c = nn.Linear(self.pred_layer_sizes[-1], 1)

        if self.config['dataset'] in ['Jobs', 'kang']:
            # Outcome predictions for these datasets are squashed into (0, 1).
            self.label_layer = nn.Sigmoid()

        # NOTE(review): the positional arguments are presumably the Sinkhorn
        # (eps, max_iter) -- confirm against causally.model.utils.
        self.wasserstein_func = SinkhornDistance(0.1, 100, reduction='mean', device=self.device)

        self.regression_loss_func = torch.nn.MSELoss(reduction='mean')
        self.treatment_loss_func = torch.nn.BCEWithLogitsLoss(reduction='mean')

    def _build_rep_branch(self):
        """Build one representation branch: optional input BatchNorm + MLP."""
        head = [nn.BatchNorm1d(self.in_feature)] if self.bn else []
        return nn.Sequential(*(head + get_linear_layers(self.in_feature, self.repre_layer_sizes,
                                                        self.bn, nn.ReLU)))

    def _build_pred_head(self):
        """Build an MLP over two concatenated representations (no BatchNorm)."""
        return nn.Sequential(*get_linear_layers(self.repre_layer_sizes[-1] * 2,
                                                self.pred_layer_sizes, False, nn.ReLU))

    def _nonzero_rows(self, mask):
        """Return the 1-D row indices where ``mask`` is true, on the model device."""
        return torch.nonzero(mask).reshape(-1).to(self.device)

    def _ipm(self, a, b):
        """Sinkhorn distance between two samples, or 0 if either sample is empty.

        A batch may contain no treated (or no control) units. The distance
        between an empty sample and anything is undefined, so the term is
        skipped instead of being handed to the solver.
        """
        if a.shape[0] == 0 or b.shape[0] == 0:
            return a.new_zeros(())
        distance, _, _ = self.wasserstein_func(a, b)
        return distance

    def forward(self, x, t):
        """Predict factual outcomes and cache intermediates for ``calculate_loss``.

        Args:
            x: covariate batch; rows index units.
            t: treatment indicators, assumed to flatten to ``len(x)`` entries.
                Units with ``t == 1`` use the treated outcome head, units with
                ``t == 0`` the control head.

        Returns:
            torch.Tensor: outcome predictions of shape ``(len(x), 1)``. Units
            whose ``t`` is neither 0 nor 1 keep a prediction of 0 (before the
            optional sigmoid for the 'Jobs'/'kang' datasets).

        Side effect: stores a 10-tuple of intermediates in ``self.res``.
        """
        x = x.to(torch.float32)
        x_rep_delta = self.rep_layer_list_delta(x)
        x_rep_tau = self.rep_layer_list_tau(x)
        x_rep_gamma = self.rep_layer_list_gamma(x)

        t = t.to(torch.float32)
        flat_t = t.reshape(-1)
        t_ind = self._nonzero_rows(flat_t == 1)
        c_ind = self._nonzero_rows(flat_t == 0)
        t_rep_gamma = x_rep_gamma[t_ind]
        c_rep_gamma = x_rep_gamma[c_ind]
        t_rep_delta = x_rep_delta[t_ind]
        c_rep_delta = x_rep_delta[c_ind]

        # Outcome heads: each group only sees [gamma, delta] of its own units.
        y_pre_t = self.output_layer_t(self.mlp_t(torch.cat((t_rep_gamma, t_rep_delta), 1)))
        y_pre_c = self.output_layer_c(self.mlp_c(torch.cat((c_rep_gamma, c_rep_delta), 1)))
        # Treatment head: logits from [tau, delta] for every unit.
        t_pre = self.output_layer_x(self.mlp_x(torch.cat((x_rep_tau, x_rep_delta), 1)))

        # Scatter the per-group predictions back into batch order.
        y_pre = torch.zeros((len(x), 1), device=self.device)
        y_pre[t_ind] = y_pre_t
        y_pre[c_ind] = y_pre_c
        if self.config['dataset'] in ['Jobs', 'kang']:
            y_pre = self.label_layer(y_pre)

        # NOTE(review): both "tau" IPM pairs are split by treatment
        # (t >= 0.5), so the treated-side and control-side pairs below are the
        # same tensors and the two corresponding loss terms are equal. Code that
        # was commented out here used to threshold predicted outcomes at 0.5
        # instead; presumably the intended term is "tau independent of outcome
        # given treatment" -- confirm the intended split.
        tau_hi = x_rep_tau[self._nonzero_rows(flat_t >= 0.5)]
        tau_lo = x_rep_tau[self._nonzero_rows(flat_t < 0.5)]

        # Order: y_pre, t_pre, gamma (t, c), delta (t, c),
        #        tau_y1 (t, c), tau_y0 (t, c)
        self.res = (y_pre, t_pre, t_rep_gamma, c_rep_gamma, t_rep_delta, c_rep_delta,
                    tau_hi, tau_hi, tau_lo, tau_lo)
        return y_pre

    def calculate_loss(self, x, t, y, w):
        """Compute the DeR-CFR training loss for one batch.

        ``loss = MSE(y_pre, y)
                 + ipm_weight * (sum of four Sinkhorn IPM terms)
                 + dercfr_weight * BCEWithLogits(t_pre, t)``

        The IPM terms compare the treated vs. control ``gamma`` and ``delta``
        representations and the two ``tau`` splits cached by ``forward``. A term
        whose sample is empty contributes 0.

        Args:
            x: covariate batch.
            t: treatment indicators. They must have the same shape as the
                ``(len(x), 1)`` treatment logits, because BCEWithLogitsLoss
                requires matching shapes.
            y: observed outcomes, presumably shaped ``(len(x), 1)``. A 1-D
                ``y`` would broadcast against the predictions in MSELoss.
            w: accepted for interface compatibility; not used here.

        Returns:
            torch.Tensor: the combined loss.
        """
        self.forward(x, t)
        (y_pre, t_pre, t_rep, c_rep, t_rep_delta, c_rep_delta,
         t_rep_tau1, c_rep_tau1, t_rep_tau0, c_rep_tau0) = self.res
        loss1 = self.regression_loss_func(y_pre.to(torch.float32), y.to(torch.float32))
        loss21 = self._ipm(t_rep, c_rep)
        loss22 = self._ipm(t_rep_delta, c_rep_delta)
        loss23 = self._ipm(t_rep_tau1, t_rep_tau0)
        loss24 = self._ipm(c_rep_tau1, c_rep_tau0)
        loss2 = loss24 + loss23 + loss22 + loss21
        loss3 = self.treatment_loss_func(t_pre.to(torch.float32), t.to(torch.float32))
        return loss1 + self.ipm_weight * loss2 + self.dercfr_weight * loss3

    def predict(self, x, t):
        r"""Predict outcomes for a batch of units.

        Args:
            x: covariate batch.
            t: treatment indicators. They route each unit to the treated or
                control outcome head.

        Returns:
            torch.Tensor: predictions of shape ``(len(x), 1)``. They are passed
            through an extra ``torch.sigmoid`` unless ``self.loss_type == 'MSE'``.
        """
        y = self.forward(x, t)
        # NOTE(review): for the 'Jobs'/'kang' datasets forward() already applies
        # a sigmoid, so a non-'MSE' loss_type applies it twice -- confirm.
        # ``loss_type`` is not set in this class (presumably by AbstractModel).
        return y if self.loss_type == 'MSE' else torch.sigmoid(y)