import torch
import torch.nn as nn
from causally.model.utils import get_linear_layers
from causally.model.abstract_model import AbstractModel, reward_format

class DRNet(AbstractModel):
    """Doubly-robust treatment-effect network with jointly trained nuisances.

    Four MLPs with the same input width and hidden layout are built:

    * ``outcome_model_treated`` / ``outcome_model_control`` -- outcome
      regressions fitted on treated (``t == 1``) and control (``t == 0``)
      samples respectively,
    * ``propensity_model`` -- ends in a sigmoid and is fitted to ``t`` with
      binary cross-entropy,
    * ``tau_model`` -- regressed on the doubly-robust pseudo-outcome built
      from the three models above; it is the only model used by ``forward``.

    Config keys read here: ``v_input_type``, ``bn``, ``repre_layer_sizes``.
    """

    def __init__(self, config, dataset):
        super().__init__(config, dataset)

        # Input width of every sub-network.  With v_input_type == 'x_and_z'
        # the input is half as wide again as dataset.size[1]
        # (presumably covariates plus extra z features -- confirm against the
        # data pipeline).
        n_features = self.dataset.size[1]
        if self.config['v_input_type'] == 'x_and_z':
            self.in_feature_tau = n_features + n_features // 2
        else:
            self.in_feature_tau = n_features

        self.bn = self.config['bn']
        self.repre_layer_sizes = self.config['repre_layer_sizes']

        # Creation order is kept (treated, control, propensity, tau) so that
        # weight initialisation consumes the RNG in the same sequence.
        self.outcome_model_treated = self._make_mlp()
        self.outcome_model_control = self._make_mlp()
        self.propensity_model = self._make_mlp(nn.Sigmoid())
        self.tau_model = self._make_mlp()

        # Per-sample losses; reduction is done by the caller.
        self.mse_loss = nn.MSELoss(reduction='none')
        self.bce_loss = nn.BCELoss(reduction='none')

    def _make_mlp(self, out_activation=None):
        """Build one sub-network: optional input BatchNorm, hidden stack, 1-unit head.

        Args:
            out_activation: optional module appended after the final linear
                layer (e.g. ``nn.Sigmoid()`` for the propensity model).

        Returns:
            An ``nn.Sequential`` mapping ``in_feature_tau`` features to 1 output.
        """
        layers = [nn.BatchNorm1d(self.in_feature_tau)] if self.bn else []
        layers.extend(get_linear_layers(self.in_feature_tau, self.repre_layer_sizes, self.bn, nn.ReLU))
        layers.append(nn.Linear(self.repre_layer_sizes[-1], 1))
        if out_activation is not None:
            layers.append(out_activation)
        return nn.Sequential(*layers)

    def forward(self, x, t):
        """Return the predicted treatment effect ``tau_model(x)``.

        ``t`` is accepted for interface compatibility but not used.  All
        size-1 dimensions are squeezed out, so a batch of one yields a
        0-dim tensor.
        """
        tau = self.tau_model(x)
        return tau.squeeze()

    def _per_sample_losses(self, x, t, y):
        """Compute the three per-sample loss components.

        Args:
            x: input batch fed to every sub-network.
            t: treatment indicator with one value per sample (any shape with
                N elements; cast to float internally).
            y: observed outcome with one value per sample.

        Returns:
            Tuple ``(loss_outcome, loss_propensity, loss_tau)``, each of
            shape ``(N, 1)``.
        """
        # Bring t and y to column shape so every term below is (N, 1) and
        # nothing silently broadcasts to (N, N).
        t = t.float().reshape(-1, 1)
        y = y.reshape(-1, 1)

        # One forward pass per outcome model; the predictions are reused for
        # both the factual loss and the doubly-robust score.
        mu1 = self.outcome_model_treated(x)
        mu0 = self.outcome_model_control(x)

        # Each outcome model is only penalised on its own treatment arm.
        mask_treated = (t == 1).float()
        mask_control = (t == 0).float()
        loss_outcome = (self.mse_loss(mu1, y) * mask_treated
                        + self.mse_loss(mu0, y) * mask_control)

        # Clamp e(x) away from 0 and 1 to keep the inverse-propensity
        # weights bounded.  reshape (not squeeze) keeps a batch of one valid.
        propensity = torch.clamp(self.propensity_model(x).reshape(-1, 1), 0.01, 0.99)
        loss_propensity = self.bce_loss(propensity, t)

        # Doubly-robust pseudo-outcome: model prediction plus an
        # inverse-propensity-weighted residual correction for each arm.
        dr_treated = (t / propensity) * (y - mu1) + mu1
        dr_control = ((1 - t) / (1 - propensity)) * (y - mu0) + mu0
        # The pseudo-outcome is a regression *target* for tau_model; detach
        # it so the tau loss does not push gradients into the outcome and
        # propensity networks.
        dr_score = (dr_treated - dr_control).detach()

        loss_tau = self.mse_loss(self.tau_model(x).reshape(-1, 1), dr_score)
        return loss_outcome, loss_propensity, loss_tau

    def calculate_loss(self, x, t, y, w):
        """Return the scalar training loss for one batch.

        The loss is the sum over samples of the factual outcome MSE, the
        propensity BCE and the tau-vs-pseudo-outcome MSE.  ``w`` is accepted
        for interface compatibility but not used.
        """
        loss_outcome, loss_propensity, loss_tau = self._per_sample_losses(x, t, y)
        return loss_outcome.sum() + loss_propensity.sum() + loss_tau.sum()

    @torch.no_grad()
    def get_reward(self, x, t, y, w):
        """Return the per-sample total loss passed through ``reward_format``.

        Switches the module to eval mode and leaves it there.  ``w`` is
        accepted for interface compatibility but not used.
        """
        self.eval()
        loss_outcome, loss_propensity, loss_tau = self._per_sample_losses(x, t, y)
        loss = loss_outcome.reshape(-1) + loss_propensity.reshape(-1) + loss_tau.reshape(-1)
        return reward_format(loss)
