import torch
import torch.nn as nn
import torch.nn.functional as F
from stmodels.layers import *

class GRUCell(torch.nn.Module):
    """GRU cell whose gate and candidate transforms are graph convolutions.

    Two ``GraphConv`` layers are owned by the cell:

    * ``ru_gconv`` produces the reset and update gates in a single pass
      (output width ``2 * embed_size``, split in half afterwards).
    * ``c_gconv`` produces the candidate state (output width ``embed_size``).

    Both layers are built with ``input_dim = 2 * embed_size`` and are fed the
    concatenation of ``inputs`` and a hidden state along dim 2.
    NOTE(review): this presumably requires the input feature width and the
    hidden width to both equal ``embed_size`` -- confirm against ``GraphConv``.
    """

    def __init__(
        self,
        embed_size,
        conv_ru,
        conv_c,
        nonlinearity='tanh'
        ):
        """Build the two graph convolutions and pick the candidate activation.

        :param embed_size: width of one gate / of the candidate state.
        :param conv_ru: forwarded as ``conv`` to the reset/update ``GraphConv``.
        :param conv_c: forwarded as ``conv`` to the candidate ``GraphConv``.
        :param nonlinearity: ``'tanh'`` selects ``torch.tanh``; any other value
            selects ``torch.relu``.
        """
        super().__init__()

        self._embed_size = embed_size

        # Only the exact string 'tanh' selects tanh; everything else is ReLU.
        if nonlinearity == 'tanh':
            self._activation = torch.tanh
        else:
            self._activation = torch.relu

        # Width of [inputs, hidden] after concatenation on the last axis.
        joined_dim = self._embed_size * 2

        # Registration order (ru first, c second) is kept so parameter order
        # and state-dict layout stay the same.
        self.ru_gconv = GraphConv(
            input_dim=joined_dim,
            output_dim=self._embed_size * 2,
            conv=conv_ru,
        )
        self.c_gconv = GraphConv(
            input_dim=joined_dim,
            output_dim=self._embed_size,
            conv=conv_c,
        )

    def forward(self, inputs, batch, t, hx):
        """Run one GRU step with graph-convolutional gates.

        :param inputs: (B, node_num, input_dim)
        :param batch: passed unchanged to both graph convolutions; its meaning
            is defined by ``GraphConv``.
        :param t: (B, num_time_feature); passed unchanged to both graph
            convolutions.
        :param hx: (B, node_num, rnn_units), the previous hidden state.
        :return: the new hidden state, a `3-D` tensor with shape
            `(B, node_num, rnn_units)`.
        """
        num_nodes = inputs.size(1)
        width = self._embed_size

        # Reset and update gates come from one convolution over [inputs, hx].
        gates = torch.sigmoid(
            self.ru_gconv(self._concat(inputs, hx), batch, t)
        )
        reset, update = gates.split(width, dim=-1)
        reset = reset.reshape(-1, num_nodes, width)
        update = update.reshape(-1, num_nodes, width)

        # Candidate state sees the hidden state scaled by the reset gate.
        candidate = self.c_gconv(self._concat(inputs, reset * hx), batch, t)
        activation = self._activation
        if activation is not None:
            candidate = activation(candidate)

        # Update gate interpolates between the old state and the candidate.
        return update * hx + (1.0 - update) * candidate

    @staticmethod
    def _concat(x, x_):
        """Concatenate two tensors along dim 2."""
        return torch.cat((x, x_), dim=2)