import torch
import torch.nn as nn
import torch.jit as jit
import torch.nn.functional as F
from network.utils import Linear


class Q_network_RNN(jit.ScriptModule):
    """Recurrent Q-network: Linear -> ReLU -> GRU -> Linear.

    The GRU hidden state is kept on the module between ``forward`` calls
    as ``self.rnn_hidden``. A 0-dimensional tensor is used as the "no
    hidden state yet" sentinel; ``reset`` restores that sentinel so the
    next ``forward`` call lets the GRU start from its default initial
    state.

    Args:
        ob_dim: Size of the last dimension of the input (observation).
        ac_dim: Size of the last dimension of the output (one value per
            action).
        h_dim: Width of the hidden layer and of the GRU state.
    """

    def __init__(self, ob_dim, ac_dim, h_dim):
        super().__init__()
        # Scalar sentinel meaning "no recurrent state stored yet".
        self.rnn_hidden = torch.tensor(0.0)
        self.fc1 = Linear(ob_dim, h_dim)
        self.rnn = nn.GRU(h_dim, h_dim, batch_first=True)
        # Best-effort: place the GRU on the GPU and compact its weights.
        # Only attempted when CUDA reports itself available, and only the
        # error types expected from CUDA initialisation/allocation are
        # tolerated, so unrelated failures (KeyboardInterrupt, programming
        # errors, ...) are no longer silently swallowed.
        if torch.cuda.is_available():
            try:
                self.rnn.cuda()
                self.rnn.flatten_parameters()
            except (RuntimeError, AssertionError):
                # Keep going with the GRU wherever it currently lives;
                # construction must not fail just because the GPU is
                # unusable.
                pass
        self.fc2 = Linear(h_dim, ac_dim)

    @jit.script_method
    def reset(self):
        """Forget the stored hidden state (restore the scalar sentinel)."""
        self.rnn_hidden = torch.tensor(0.0)

    @jit.script_method
    def forward(self, inputs) -> torch.Tensor:
        """Compute Q-values and update the stored hidden state.

        Args:
            inputs: Observation tensor whose last dimension is ``ob_dim``.
                Presumably shaped (batch, seq, ob_dim) since the GRU is
                built with ``batch_first=True`` -- confirm against callers.

        Returns:
            Tensor of Q-values whose last dimension is ``ac_dim``.

        Note:
            The hidden state returned by the GRU is stored as-is (it is
            not detached here) and is fed back in on the next call until
            ``reset`` is invoked.
        """
        feat = F.relu(self.fc1(inputs))
        if self.rnn_hidden.dim() == 0:
            # Sentinel present: let the GRU use its default initial state.
            feat, self.rnn_hidden = self.rnn(feat)
        else:
            # Continue from the state left by the previous call.
            feat, self.rnn_hidden = self.rnn(feat, self.rnn_hidden)
        Q = self.fc2(feat)
        return Q