import torch
import torch.nn as nn
from causally.model.utils import get_linear_layers
from causally.model.abstract_model import AbstractModel, reward_format

class RNet(AbstractModel):
    """R-learner style network for treatment-effect estimation.

    Three small MLPs are trained jointly from the same input ``x``:

    * ``outcome_model``     -- regresses the outcome ``y`` on ``x``.
    * ``propensity_model``  -- sigmoid-terminated head fitted to the
      treatment ``t`` with binary cross-entropy.
    * ``tau_model``         -- effect head, fitted so that
      ``tau(x) * (t - e(x))`` matches the outcome residual ``y - m(x)``.

    ``forward`` returns only the ``tau_model`` prediction.
    """

    def __init__(self, config, dataset):
        """Build the sub-networks from ``config``.

        Config keys read here: ``'v_input_type'``, ``'bn'`` and
        ``'repre_layer_sizes'``. The input width is derived from
        ``dataset.size[1]``.
        """
        super().__init__(config, dataset)

        n_features = self.dataset.size[1]
        if self.config['v_input_type'] == 'x_and_z':
            # Input is widened by half of the base width (integer division).
            # NOTE(review): presumably x concatenated with a z block half the
            # size of x -- confirm against the data loader.
            self.in_feature_tau = n_features + n_features // 2
        else:
            self.in_feature_tau = n_features

        self.bn = self.config['bn']
        self.repre_layer_sizes = self.config['repre_layer_sizes']

        # Sub-modules are created in the same order and under the same
        # attribute names as before, so parameter initialisation order and
        # state_dict keys are unchanged.
        # The treated/control heads are not used by any method visible in
        # this class; they are kept for checkpoint compatibility.
        self.outcome_model_treated = self._build_mlp()
        self.outcome_model_control = self._build_mlp()
        self.outcome_model = self._build_mlp()

        self.propensity_model = self._build_mlp(nn.Sigmoid())

        self.tau_model = self._build_mlp()

        # reduction='none' keeps per-sample losses so get_reward can return
        # one value per sample while calculate_loss sums them.
        self.mse_loss = nn.MSELoss(reduction='none')
        self.bce_loss = nn.BCELoss(reduction='none')
        self.ce_loss = nn.CrossEntropyLoss(reduction='none')

    def _build_mlp(self, output_activation=None):
        """Return one scalar-output MLP: [BatchNorm] -> hidden layers -> Linear(1) [-> activation]."""
        layers = [nn.BatchNorm1d(self.in_feature_tau)] if self.bn else []
        layers += get_linear_layers(self.in_feature_tau, self.repre_layer_sizes, self.bn, nn.ReLU)
        layers.append(nn.Linear(self.repre_layer_sizes[-1], 1))
        if output_activation is not None:
            layers.append(output_activation)
        return nn.Sequential(*layers)

    def _per_sample_losses(self, x, t, y):
        """Return the per-sample (outcome, propensity, tau) losses as 1-D tensors.

        Every prediction and target is flattened before the element-wise
        losses are taken. Without this, an ``(N, 1)`` prediction compared
        with an ``(N,)`` target broadcasts to ``(N, N)``.
        """
        t_flat = t.float().reshape(-1)
        y_flat = y.reshape(-1)

        # Each network is evaluated exactly once per call, so BatchNorm
        # running statistics are updated once per step in training mode.
        y_hat = self.outcome_model(x).reshape(-1)
        e_hat = self.propensity_model(x).reshape(-1)
        tau_hat = self.tau_model(x).reshape(-1)

        loss_outcome = self.mse_loss(y_hat, y_flat)
        loss_propensity = self.bce_loss(e_hat, t_flat)
        # NOTE(review): y_hat and e_hat are not detached, so this term also
        # backpropagates into the outcome and propensity networks. Confirm
        # this joint training is intended.
        loss_tau = self.mse_loss(tau_hat * (t_flat - e_hat), y_flat - y_hat)
        return loss_outcome, loss_propensity, loss_tau

    def forward(self, x, t):
        """Return the ``tau_model`` prediction for ``x`` with size-1 dims squeezed.

        ``t`` is accepted for interface compatibility but is not used.
        """
        tau = self.tau_model(x)
        return tau.squeeze()

    def calculate_loss(self, x, t, y, w):
        """Return the training loss as a tensor of shape ``(1,)``.

        The loss is the sum over the batch of the outcome MSE, the
        propensity BCE and the tau residual MSE. ``w`` is not used.
        """
        loss_outcome, loss_propensity, loss_tau = self._per_sample_losses(x, t, y)
        total_loss = loss_outcome.sum() + loss_propensity.sum() + loss_tau.sum()
        return total_loss.reshape(-1)

    @torch.no_grad()
    def get_reward(self, x, t, y, w):
        """Return ``reward_format`` applied to the per-sample total loss.

        Switches the module to eval mode (it is not switched back here)
        and runs without gradient tracking. ``w`` is not used.
        """
        self.eval()
        loss_outcome, loss_propensity, loss_tau = self._per_sample_losses(x, t, y)
        loss = loss_outcome + loss_propensity + loss_tau
        return reward_format(loss)

