#import torch
import numpy as np
from torch import nn

class FactExtractor(nn.Module):
    """
    Holds one MLP per predicate and applies the selected predicate's MLP to a slot.

    Subclassing ``nn.Module`` and keeping the MLPs in an ``nn.ModuleList`` registers their
    weights with this object. As a result, ``parameters()``, ``to(device)``, ``train()``/``eval()``
    and ``state_dict()`` all reach them. A plain Python list would hide the weights from all of
    these, so an optimizer built from this object's parameters would never update them.
    """

    def __init__(self, P, F, N_slot, hidden_channels=64):
        """
        :param P: Sized collection of predicates; one MLP is built per entry.
        :param F: Sized collection of facts; stored together with its length.
        :param N_slot: Number of slots; stored but not used within this class.
        :param hidden_channels: Input/hidden/output width of every predicate MLP.
        """
        super().__init__()
        self.F = F  # facts
        self.F_dim = len(F)
        self.P = P  # predicates
        self.P_dim = len(P)
        self.N_slot = N_slot
        self.hidden_channels = hidden_channels
        # ModuleList (not a plain list) so the MLP weights are registered as submodules.
        self.mlps = nn.ModuleList(MLP(hidden_channels) for _ in range(self.P_dim))

    def extract(self, pred_id, slot, slot_id):
        """
        Apply the MLP belonging to predicate ``pred_id`` to ``slot``.

        :param pred_id: Index of the predicate whose MLP is used.
        :param slot: Input passed to the MLP; presumably its last dimension is
            ``hidden_channels`` -- TODO confirm against callers.
        :param slot_id: Currently unused; kept for interface compatibility.
        :return: The output of the selected MLP.
        """
        return self.mlps[pred_id](slot)



class MLP(nn.Module):
    """
    Two-layer perceptron (Linear -> ReLU -> Linear) whose input, hidden and output widths are equal.

    Described as the set-prediction MLP for CLEVR from Locatello et al. 2020.
    """

    def __init__(self, hidden_channels):
        """
        Assemble the layers.

        :param hidden_channels: Integer width shared by the input, the hidden layer and the
            output of the network.
        """
        super().__init__()
        width = hidden_channels
        layers = [
            nn.Linear(width, width),
            nn.ReLU(inplace=True),
            nn.Linear(width, width),
        ]
        # The attribute name and layer order determine the state_dict keys
        # (network.0.*, network.2.*); keep them stable for checkpoint compatibility.
        self.network = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the network; its last dimension must equal ``hidden_channels``."""
        out = self.network(x)
        return out


def build_grid(resolution):
    """
    Builds the coordinate grid used for the position embedding.

    :param resolution: Sequence of two ints, the spatial dimensions of the encoder's latent
        feature map.
    :return: float32 numpy array of shape ``(1, resolution[0], resolution[1], 4)``. Along the
        last axis it holds ``[c0, c1, 1 - c0, 1 - c1]``. ``c0`` runs linearly from 0 to 1 along
        the first spatial axis, and ``c1`` does the same along the second.
    :raises ValueError: if ``resolution`` does not have exactly two entries.
    """
    resolution = tuple(resolution)
    if len(resolution) != 2:
        # Other lengths previously failed obscurely or returned a mangled array.
        raise ValueError(f"resolution must have exactly two entries, got {resolution!r}")
    ranges = [np.linspace(0., 1., num=res) for res in resolution]
    # "ij" indexing: the first grid varies along axis 0, the second along axis 1.
    grid = np.meshgrid(*ranges, sparse=False, indexing="ij")
    grid = np.stack(grid, axis=-1)          # (H, W, 2)
    grid = np.expand_dims(grid, axis=0)     # (1, H, W, 2)
    grid = grid.astype(np.float32)
    return np.concatenate([grid, 1.0 - grid], axis=-1)  # (1, H, W, 4)

