import torch
import torch.nn as nn
import torch.nn.functional as F

class SimSHAPTabular(nn.Module):
    """Three-layer MLP that emits one ``in_dim``-sized vector per output.

    The network is ``Linear -> ReLU -> Linear -> ReLU -> Linear``; the last
    layer produces ``out_dim * in_dim`` values per sample, which ``forward``
    reshapes to ``(batch, out_dim, in_dim)``.

    NOTE(review): the class name suggests these are per-feature attribution
    (Shapley-style) values for each model output -- confirm with the
    training code that consumes this module.

    Args:
        in_dim: Size of the last dimension of the input.
        out_dim: Number of ``in_dim``-sized vectors produced per sample.
        hidden_dim: Width of both hidden layers.
    """

    def __init__(self, in_dim: int, out_dim: int, hidden_dim: int = 128) -> None:
        super().__init__()

        # Kept on the instance because forward() needs them for the reshape.
        self.in_dim = in_dim
        self.out_dim = out_dim

        # Attribute names double as state_dict keys, so they stay as-is.
        self.linear = nn.Linear(in_dim, hidden_dim)
        self.relu1 = nn.ReLU()
        self.linear2 = nn.Linear(hidden_dim, hidden_dim)
        self.relu2 = nn.ReLU()
        # One flat row of out_dim * in_dim values per sample.
        self.linear3 = nn.Linear(hidden_dim, out_dim * in_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the MLP and reshape its flat output.

        Args:
            x: Input whose last dimension is ``in_dim``; the reshape below
                takes the first dimension as the batch size, so a 2-D
                ``(batch, in_dim)`` tensor is the expected form.

        Returns:
            Tensor of shape ``(batch, out_dim, in_dim)``.
        """
        hidden = self.relu1(self.linear(x))
        hidden = self.relu2(self.linear2(hidden))
        flat = self.linear3(hidden)

        batch_size = flat.shape[0]
        return flat.view(batch_size, self.out_dim, self.in_dim)