from torch import nn
import torch as T
import torch.nn.functional as F

class GRC(nn.Module):
    """Gated cell that composes two child vectors into one parent vector.

    The two children are concatenated and passed through a two-layer MLP
    (``wcell1`` -> GELU -> dropout -> ``wcell2``).  The MLP output is split
    into three sigmoid gates ``f1``, ``f2``, ``i`` and an (ungated-activation)
    candidate ``parent``; the result is::

        LayerNorm(f1 * left + f2 * right + i * parent)

    Args:
        hidden_size: size of the last dimension of ``left``/``right`` and of
            the output.
        cell_hidden_size: width of the intermediate MLP layer.
        dropout: dropout probability applied to the intermediate activation
            (only while the module is in training mode).
    """

    def __init__(self, hidden_size, cell_hidden_size, dropout):
        super(GRC, self).__init__()
        self.hidden_dim = hidden_size
        self.wcell1 = nn.Linear(2 * hidden_size, cell_hidden_size)
        # Emits 4 chunks of size hidden_size: gates f1, f2, i and the candidate parent.
        self.wcell2 = nn.Linear(cell_hidden_size, 4 * hidden_size)
        self.LN = nn.LayerNorm(hidden_size)
        self.dropout = dropout

    def forward(self, left=None, right=None):
        """Compose ``left`` and ``right`` into a parent representation.

        Args:
            left: tensor of shape ``(*, hidden_size)``; any number of leading
                dimensions (including none) is accepted.
            right: tensor with exactly the same shape as ``left``.

        Returns:
            Tensor with the same shape as ``left``.

        Raises:
            ValueError: if either input is missing, the shapes differ, or the
                last dimension is not ``hidden_size``.
        """
        # Both arguments are required; the ``None`` defaults exist only for
        # call-signature compatibility.
        if left is None or right is None:
            raise ValueError("GRC.forward requires both `left` and `right` tensors")
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if right.size() != left.size():
            raise ValueError(
                "`left` and `right` must have the same shape, got "
                f"{tuple(left.size())} and {tuple(right.size())}"
            )
        D = left.size(-1)
        if D != self.hidden_dim:
            raise ValueError(
                f"expected last dimension {self.hidden_dim}, got {D}"
            )

        concated = T.cat([left, right], dim=-1)

        intermediate = F.gelu(self.wcell1(concated))
        intermediate = F.dropout(intermediate, p=self.dropout, training=self.training)
        contents = self.wcell2(intermediate)

        # Split the last dim into (4, D) while keeping every leading dim, so
        # 1-D, 2-D (N, D) and higher-rank batched inputs all work.
        contents = contents.reshape(*left.shape[:-1], 4, D)
        gates = T.sigmoid(contents[..., 0:3, :])
        parent = contents[..., 3, :]

        f1 = gates[..., 0, :]
        f2 = gates[..., 1, :]
        i = gates[..., 2, :]

        out = self.LN(f1 * left + f2 * right + i * parent)

        return out