import torch
import torch.nn as nn
from causally.model.utils import get_linear_layers
from causally.model.abstract_model import AbstractModel, reward_format

class XNet(AbstractModel):
    """X-learner for heterogeneous treatment-effect estimation.

    Stage 1 fits separate outcome regressors: ``model_treated`` on treated
    units and ``model_control`` on control units.

    Stage 2 fits effect regressors on imputed treatment effects:
    ``tau_treated`` on treated units with target ``y - model_control(x)``,
    and ``tau_control`` on control units with target ``model_treated(x) - y``.
    It also fits ``propensity_model`` (sigmoid output, trained with BCE
    against ``t``).

    Both stages are optimised together: ``calculate_loss`` returns the sum of
    the two stage losses.
    """

    def __init__(self, config, dataset):
        super().__init__(config, dataset)
        if self.config['v_input_type'] == 'x_and_z':
            # NOTE(review): presumably z contributes size[1] // 2 extra
            # columns on top of x -- confirm against the dataset.
            self.in_feature = self.dataset.size[1] + self.dataset.size[1] // 2
        else:
            self.in_feature = self.dataset.size[1]

        self.alpha = self.config['alpha']
        self.bn = self.config['bn']
        self.repre_layer_sizes = self.config['repre_layer_sizes']
        # NOTE(review): read from config but not used by any network below.
        self.pred_layer_sizes = self.config['pred_layer_sizes']

        # Construction order matches the original, so parameter initialisation
        # consumes the RNG identically under a fixed seed.
        self.model_treated = self._build_mlp()
        self.model_control = self._build_mlp()
        self.propensity_model = self._build_mlp(nn.Sigmoid())
        self.tau_treated = self._build_mlp()
        self.tau_control = self._build_mlp()

        self.mse_loss = nn.MSELoss(reduction='none')
        self.bce_loss = nn.BCELoss(reduction='none')
        # Kept for backward compatibility. It is not used by this model's
        # losses: cross-entropy over a single sigmoid column is identically
        # zero, so the propensity score is trained with BCE instead.
        self.ce_loss = nn.CrossEntropyLoss(reduction='none')

    def _build_mlp(self, out_activation=None):
        """Build one sub-network with a single-unit output head.

        Layer order: optional input BatchNorm, the hidden stack from
        ``get_linear_layers``, then ``Linear(repre_layer_sizes[-1], 1)``.

        :param out_activation: optional module appended after the output
            layer (e.g. ``nn.Sigmoid()`` for the propensity model).
        :return: an ``nn.Sequential`` taking ``in_feature`` inputs.
        """
        layers = [nn.BatchNorm1d(self.in_feature)] if self.bn else []
        layers += get_linear_layers(self.in_feature, self.repre_layer_sizes, self.bn, nn.ReLU)
        layers.append(nn.Linear(self.repre_layer_sizes[-1], 1))
        if out_activation is not None:
            layers.append(out_activation)
        return nn.Sequential(*layers)

    @staticmethod
    def _column(tensor):
        """Reshape ``tensor`` to ``(N, 1)``.

        The network heads emit ``(N, 1)``. Reshaping targets and masks to the
        same shape keeps per-sample terms elementwise, so a 1-D ``y`` or
        ``t`` cannot broadcast to ``(N, N)``.
        """
        return tensor.reshape(-1, 1)

    def forward(self, x, t):
        """Return the stage-2 effect estimates and the propensity score for ``x``.

        :param x: input features.
        :param t: unused; accepted for interface compatibility.
        :return: ``(tau_treated(x), tau_control(x), propensity_model(x))``,
            each with size-1 dimensions squeezed.
        """
        tau1 = self.tau_treated(x)
        tau0 = self.tau_control(x)
        ps = self.propensity_model(x)
        return tau1.squeeze(), tau0.squeeze(), ps.squeeze()

    def _first_step(self, x, t, y, w, for_reward=False):
        """Stage-1 loss: masked squared error of the two outcome regressors.

        ``model_treated`` is scored only on units with ``t == 1``, and
        ``model_control`` only on units with ``t == 0``.

        :param w: unused; kept for interface compatibility.
        :param for_reward: if True return the per-sample ``(N, 1)`` losses,
            otherwise their scalar sum.
        """
        y = self._column(y)
        t = self._column(t)
        mask_treated = (t == 1).float()
        mask_control = (t == 0).float()
        per_sample = (self.mse_loss(self.model_treated(x), y) * mask_treated
                      + self.mse_loss(self.model_control(x), y) * mask_control)
        return per_sample if for_reward else per_sample.sum()

    def _second_step(self, x, t, y, w, for_reward=False):
        """Stage-2 loss: imputed-effect regression plus propensity BCE.

        :param w: unused; kept for interface compatibility.
        :param for_reward: if True return the per-sample ``(N, 1)`` losses,
            otherwise their scalar sum.
        """
        y = self._column(y)
        # BCE needs a floating-point target, so cast t here.
        t = self._column(t).float()
        mask_treated = (t == 1).float()
        mask_control = (t == 0).float()
        # Imputed treatment effects. The stage-1 predictions are detached so
        # they act as fixed regression targets. Otherwise this loss would
        # back-propagate into model_treated / model_control and pull them away
        # from their own outcome-regression objective.
        pseudo_control = self.model_treated(x).detach() - y   # mu1(x) - y on controls
        pseudo_treated = y - self.model_control(x).detach()   # y - mu0(x) on treated
        per_sample = (self.mse_loss(self.tau_treated(x), pseudo_treated) * mask_treated
                      + self.mse_loss(self.tau_control(x), pseudo_control) * mask_control
                      # Sigmoid output vs. binary t: binary cross-entropy.
                      # (The previous CrossEntropyLoss on a single column was
                      # always 0 and never trained the propensity model.)
                      + self.bce_loss(self.propensity_model(x), t))
        return per_sample if for_reward else per_sample.sum()

    def calculate_loss(self, x, t, y, w):
        """Training objective: summed stage-1 loss plus summed stage-2 loss."""
        return self._first_step(x, t, y, w) + self._second_step(x, t, y, w)

    @torch.no_grad()
    def get_reward(self, x, t, y, w):
        """Compute per-sample stage-1 + stage-2 losses and pass them to ``reward_format``.

        Runs without gradient tracking. It switches the model to eval mode and
        does not switch it back.
        """
        # NOTE(review): the caller is responsible for calling train() again
        # before further optimisation.
        self.eval()
        t = t.float()
        loss = (self._first_step(x, t, y, w, for_reward=True)
                + self._second_step(x, t, y, w, for_reward=True))
        return reward_format(loss)

