import numpy as np
import torch
import torch.nn as nn
from tqdm import tqdm

from ..utils.nn import one_hot

class MultCombination(nn.Module):
    """Multiplicative combination of two inputs.

    Description
    -----------
    Implements the multiplicative interactions described in Oh et al.'s
    Action-Conditional Video Prediction paper
    (https://arxiv.org/pdf/1507.08750.pdf). Here, X is their encoded feature,
    Y is their control input, and this block computes the factored
    transformation:

        x_ = T_out (T_x X * T_y Y) + b

    where ``*`` is the element-wise product and ``b`` is the bias of T_out.

    Parameters
    ----------
    x_size : int
        The size of the encoded feature vector.
    y_size : int
        The size of the one-hot control vector, i.e. the number of actions.
    n_factors : int
        The number of factors to use in place of the full |X|*|X|*|Y| tensor.
        The 3-way tensor is replaced by three 2-way tensors:
            T_out: |X|*|F|
            T_x: |F|*|X|
            T_y: |F|*|Y|
    """

    def __init__(self, x_size, y_size, n_factors):
        super().__init__()
        self.x_size = x_size
        self.y_size = y_size
        self.n_factors = n_factors
        # Only the output projection carries a bias (the "+ b" term); the two
        # input projections are pure linear maps into factor space.
        self.T_x = nn.Linear(x_size, n_factors, bias=False)
        self.T_y = nn.Linear(y_size, n_factors, bias=False)
        self.T_out = nn.Linear(n_factors, x_size, bias=True)

    def extra_repr(self):
        """Return the constructor arguments shown by ``repr(module)``.

        Implementing ``extra_repr`` (rather than overriding ``__repr__``)
        keeps PyTorch's default module printing, so the T_x / T_y / T_out
        submodules are listed as well.
        """
        return 'x_size={}, y_size={}, n_factors={}'.format(
            self.x_size, self.y_size, self.n_factors)

    def forward(self, x, y):
        """Combine encoded features ``x`` with control input ``y``.

        Parameters
        ----------
        x : torch.Tensor
            Encoded features whose last dimension is ``x_size``.
        y : torch.Tensor
            Control input whose last dimension is ``y_size`` (e.g. a one-hot
            action vector). It is fed directly to ``nn.Linear``, so its dtype
            must match the layer weights (a floating dtype).

        Returns
        -------
        torch.Tensor
            Tensor whose last dimension is ``x_size``. Its leading dimensions
            are those of the element-wise product of the two projections.
        """
        f_x = self.T_x(x)  # features -> factor space
        f_y = self.T_y(y)  # control input -> factor space
        return self.T_out(f_x * f_y)

def main():
    """Smoke-test MultCombination on a batch of random one-hot actions.

    Builds the module, pairs a batch of all-ones feature vectors with
    one-hot action vectors drawn at random, and checks the output shape.
    """
    batch_size = 32
    feature_dim, num_actions, num_factors = 128, 5, 16

    module = MultCombination(feature_dim, num_actions, num_factors)
    features = torch.ones(batch_size, feature_dim)

    # Equal weights -> each action index is equally likely; sampling with
    # replacement lets the batch be larger than the number of actions.
    weights = torch.ones(num_actions)
    action_idx = torch.multinomial(weights, batch_size, replacement=True)
    actions = one_hot(action_idx, depth=num_actions)

    combined = module(features, actions)
    expected_shape = torch.Size([batch_size, feature_dim])
    assert combined.shape == expected_shape
    print('Testing complete.')

if __name__ == '__main__':
    # Run the smoke test only when executed directly. Because this file uses
    # a relative import (``from ..utils.nn import one_hot``), it must be run
    # as a module (``python -m <package>.<module>``), not as a plain script.
    main()
