import torch
import torch.nn as nn
from causally.model.utils import get_linear_layers
from causally.model.abstract_model import AbstractModel, reward_format

class SNet(AbstractModel):
    """Single-network outcome model taking covariates and treatment jointly.

    The treatment ``t`` is appended to the covariates ``x`` as one extra input
    column, and a single MLP (``repre_layers``) maps the concatenation to one
    scalar prediction per sample, trained against ``y`` with squared error.
    """

    def __init__(self, config, dataset):
        super(SNet, self).__init__(config, dataset)
        # Input width = covariate columns + 1 column for the treatment that
        # ``forward`` concatenates onto ``x``.
        # NOTE(review): the 'x_and_z' branch assumes the input carries an extra
        # ``dataset.size[1] // 2`` columns (presumably z) — confirm against the
        # data pipeline.
        n_cols = self.dataset.size[1]
        if self.config['v_input_type'] == 'x_and_z':
            self.in_feature = n_cols + n_cols // 2 + 1
        else:
            self.in_feature = n_cols + 1

        self.bn = self.config['bn']
        self.repre_layer_sizes = self.config['repre_layer_sizes']
        # Read from config but not used within this class.
        self.pred_layer_sizes = self.config['pred_layer_sizes']

        # Module names/order are kept stable: they determine state_dict keys
        # (repre_layers.0, repre_layers.1, ...) and parameter-init RNG order.
        input_norm = [nn.BatchNorm1d(self.in_feature)] if self.bn else []
        self.repre_layers = nn.Sequential(
            *input_norm,
            *get_linear_layers(self.in_feature, self.repre_layer_sizes, self.bn, nn.ReLU),
            # Final projection to a single output per sample.
            nn.Linear(self.repre_layer_sizes[-1], 1),
        )

        # Per-element squared error; callers choose how to reduce it.
        self.mse_loss = nn.MSELoss(reduction='none')

    def forward(self, x, t):
        """Predict the outcome for covariates ``x`` under treatment ``t``.

        Args:
            x: 2-D input tensor, one row per sample.
            t: treatment values, one per sample; flattened into a column.

        Returns:
            Tensor of shape ``(batch, 1)``.
        """
        # Cast t to x's dtype so the concatenated input matches the layers'
        # parameter dtype (an int/float64 t would otherwise change the dtype).
        t = t.reshape(-1, 1).to(x.dtype)
        x_t = torch.cat([x, t], dim=1)
        return self.repre_layers(x_t)

    def _per_sample_mse(self, x, t, y):
        """Return the elementwise squared error of the prediction, shape ``(batch, 1)``."""
        pred = self.forward(x, t)
        # Align y with pred: a flat y of shape (batch,) would otherwise
        # broadcast against the (batch, 1) prediction into a (batch, batch)
        # error matrix.
        target = y.reshape(pred.shape).to(pred.dtype)
        return self.mse_loss(pred, target)

    def calculate_loss(self, x, t, y, w):
        """Return the summed squared error over the batch.

        Args:
            x: covariates passed to ``forward``.
            t: treatments passed to ``forward``.
            y: observed outcomes, one per sample.
            w: accepted for interface compatibility; not used by this model.

        Returns:
            Scalar tensor: sum of per-sample squared errors.
        """
        return torch.sum(self._per_sample_mse(x, t, y))

    @torch.no_grad()
    def get_reward(self, x, t, y, w):
        """Return the per-sample squared error, formatted by ``reward_format``.

        Runs without gradient tracking. ``w`` is accepted for interface
        compatibility but is not used.
        """
        # NOTE(review): this switches the module to eval mode and does not
        # restore the previous mode; callers that resume training afterwards
        # must call ``train()`` themselves.
        self.eval()
        return reward_format(self._per_sample_mse(x, t, y))

