import torch.nn as nn
import torch


class MINE(nn.Module):
    """Small MLP that maps an ``(x, y)`` pair to one unbounded scalar score.

    The network is ``Linear -> ReLU -> Linear -> ReLU -> Linear`` over the
    concatenation of ``x`` and ``y`` along the feature (last) axis. The class
    name suggests it serves as the statistics network ``T(x, y)`` in a Mutual
    Information Neural Estimation (MINE) setup. The loss and estimator are
    not part of this class; confirm against the caller.

    Args:
        x_dim: Size of the last (feature) dimension of ``x``.
        y_dim: Size of the last (feature) dimension of ``y``.
        hidden_size: Width of both hidden layers.
    """

    def __init__(self, x_dim: int, y_dim: int, hidden_size: int = 128):
        super().__init__()
        # Attribute name and layer order are kept stable so previously saved
        # state_dicts (keys "net.0.*", "net.2.*", "net.4.*") still load.
        self.net = nn.Sequential(
            nn.Linear(x_dim + y_dim, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, 1),
        )

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """Score each ``(x, y)`` pair.

        Args:
            x: Tensor of shape ``(*batch, x_dim)``.
            y: Tensor of shape ``(*batch, y_dim)``. Its leading dimensions
                must match those of ``x``.

        Returns:
            Tensor of shape ``(*batch, 1)`` holding one score per pair.
        """
        # Concatenate on the last axis: that is the axis nn.Linear consumes.
        # For 2-D (batch, features) inputs this is identical to dim=1. It
        # also works for extra leading batch dims and a single unbatched pair.
        input_pair = torch.cat([x, y], dim=-1)
        return self.net(input_pair)
