import torch
import torch.nn as nn
from causally.model.utils import get_linear_layers
from causally.model.abstract_model import AbstractModel, reward_format

class TNet(AbstractModel):
    """Outcome model made of two independent networks, one per treatment arm.

    ``model_treated`` and ``model_control`` share the same architecture but
    not their weights. Each maps the input features to a single scalar
    prediction; the treatment indicator ``t`` decides which prediction (and
    which per-sample loss) is used for a given row.
    """

    def __init__(self, config, dataset):
        """Read the architecture settings from ``config`` and build both nets.

        Args:
            config: mapping read for the keys ``'v_input_type'``, ``'bn'``,
                ``'repre_layer_sizes'`` and ``'pred_layer_sizes'``.
            dataset: object whose ``size[1]`` is used as the base feature
                count.
        """
        super().__init__(config, dataset)

        n_features = self.dataset.size[1]
        if self.config['v_input_type'] == 'x_and_z':
            # NOTE(review): the input is widened by half the feature count;
            # presumably z is half as wide as x -- confirm with the data
            # pipeline.
            self.in_feature = n_features + n_features // 2
        else:
            self.in_feature = n_features

        self.bn = self.config['bn']
        self.repre_layer_sizes = self.config['repre_layer_sizes']
        # Stored for parity with the configuration; nothing in this class
        # reads it.
        self.pred_layer_sizes = self.config['pred_layer_sizes']

        # Keep this creation order (treated first) so parameter
        # initialisation consumes the RNG in the same sequence as before.
        self.model_treated = self._build_outcome_net()
        self.model_control = self._build_outcome_net()

        # Per-element loss so that it can be masked by treatment arm.
        self.mse_loss = nn.MSELoss(reduction='none')

    def _build_outcome_net(self):
        """Create one single-output regressor over ``in_feature`` inputs."""
        layers = [nn.BatchNorm1d(self.in_feature)] if self.bn else []
        layers += get_linear_layers(self.in_feature, self.repre_layer_sizes,
                                    self.bn, nn.ReLU)
        # Fall back to the input width when no hidden sizes are configured,
        # so the output layer still gets a valid input size instead of an
        # IndexError on ``[-1]``.
        if len(self.repre_layer_sizes) > 0:
            last_width = self.repre_layer_sizes[-1]
        else:
            last_width = self.in_feature
        layers.append(nn.Linear(last_width, 1))
        return nn.Sequential(*layers)

    @staticmethod
    def _align_to(tensor, reference):
        """Reshape a flat per-sample ``tensor`` to the shape of ``reference``.

        The networks end in ``nn.Linear(..., 1)``, so their output carries a
        trailing dimension of size one. Combining that with a tensor that
        lacks the trailing dimension would broadcast to an ``N x N`` result
        without raising. Only that case is rewritten; tensors whose rank or
        element count differ in any other way are returned unchanged.
        """
        same_count = tensor.numel() == reference.numel()
        if same_count and tensor.dim() == reference.dim() - 1:
            return tensor.reshape(reference.shape)
        return tensor

    def _masked_losses(self, x, t, y):
        """Return ``(treated_loss, control_loss)`` as per-sample tensors.

        Each loss is the element-wise squared error of the corresponding
        network, zeroed for rows that do not belong to that arm.
        """
        pred_treated = self.model_treated(x)
        pred_control = self.model_control(x)
        t = self._align_to(t, pred_treated)
        y = self._align_to(y, pred_treated)
        treated_loss = self.mse_loss(pred_treated, y) * (t == 1).float()
        control_loss = self.mse_loss(pred_control, y) * (t == 0).float()
        return treated_loss, control_loss

    def forward(self, x, t):
        """Predict the outcome under the treatment given in ``t``.

        Args:
            x: input features fed to both networks.
            t: treatment indicator; rows equal to 1 take the treated
                network's prediction, all other rows the control one.

        Returns:
            Tensor of predictions with the shape of the network output.
        """
        y1 = self.model_treated(x)
        y0 = self.model_control(x)
        return torch.where(self._align_to(t, y1) == 1, y1, y0)

    def calculate_loss(self, x, t, y, w):
        """Return the summed squared error of each arm on its own rows.

        Args:
            x: input features.
            t: treatment indicator (1 = treated, 0 = control); rows with any
                other value contribute nothing.
            y: observed outcomes.
            w: accepted for interface compatibility; not used here.

        Returns:
            Scalar tensor: treated-arm loss sum plus control-arm loss sum.
        """
        treated_loss, control_loss = self._masked_losses(x, t, y)
        return torch.sum(treated_loss) + torch.sum(control_loss)

    @torch.no_grad()
    def get_reward(self, x, t, y, w):
        """Compute the per-sample loss and pass it through ``reward_format``.

        Side effect: switches the module to evaluation mode and leaves it
        there, so a caller that continues training has to call ``train()``
        again.

        Args:
            x: input features.
            t: treatment indicator (1 = treated, 0 = control).
            y: observed outcomes.
            w: accepted for interface compatibility; not used here.

        Returns:
            Whatever ``reward_format`` returns for the per-sample loss tensor.
        """
        self.eval()
        treated_loss, control_loss = self._masked_losses(x, t, y)
        return reward_format(treated_loss + control_loss)

