import torch.nn as nn
import torch.nn.functional as F


class DuelingRNNAgent(nn.Module):
    """Agent network with separate value and advantage output heads.

    The input is passed through a ReLU-activated linear layer and then
    through either a ``GRUCell`` (``args.use_rnn`` truthy) or a second
    ReLU-activated linear layer. The resulting ``hidden_dim`` features are
    split column-wise: the first ``hidden_dim // 2`` feed the scalar value
    head, the remaining ones feed the advantage head with ``n_actions``
    outputs.

    ``args`` must expose ``hidden_dim``, ``n_actions`` and ``use_rnn``.
    """

    def __init__(self, input_shape, args):
        """Build the layers.

        Args:
            input_shape: Size of the last dimension of the network input
                (passed as ``in_features`` to the first linear layer).
            args: Configuration object; ``hidden_dim``, ``n_actions`` and
                ``use_rnn`` are read from it. It is also stored on the
                module and consulted again in ``forward``/``init_hidden``.
        """
        super().__init__()
        self.args = args

        self.fc1 = nn.Linear(input_shape, args.hidden_dim)
        if self.args.use_rnn:
            self.rnn = nn.GRUCell(args.hidden_dim, args.hidden_dim)
        else:
            # Feed-forward stand-in that keeps the attribute name ``rnn``.
            self.rnn = nn.Linear(args.hidden_dim, args.hidden_dim)

        # The two heads read disjoint column slices of the hidden features;
        # the advantage head receives the extra column when hidden_dim is odd.
        self.hidden_dim_v = args.hidden_dim // 2
        self.hidden_dim_a = args.hidden_dim - self.hidden_dim_v
        self.fc2_v = nn.Linear(self.hidden_dim_v, 1)
        self.fc2_a = nn.Linear(self.hidden_dim_a, args.n_actions)
        self.init_net_a()

    def init_hidden(self):
        """Return a zero hidden state of shape ``(1, hidden_dim)``.

        The tensor is created from ``fc1.weight`` so that it shares the
        model's dtype and device.
        """
        return self.fc1.weight.new_zeros(1, self.args.hidden_dim)

    def init_net_a(self):
        """Zero the advantage head's weight and bias.

        With both set to zero the advantage output is 0 for every action
        until the head's parameters are updated. An alternative uniform
        initialisation in ``[-sqrt(1 / hidden_dim_a), sqrt(1 / hidden_dim_a)]``
        was previously sketched here but left disabled.
        """
        nn.init.zeros_(self.fc2_a.weight)
        nn.init.zeros_(self.fc2_a.bias)

    def forward(self, inputs, hidden_state):
        """Compute value, advantages and the next hidden features.

        Args:
            inputs: Input tensor; expected to be 2-D
                ``(batch, input_shape)`` since the result is sliced along
                dimension 1 -- TODO confirm against callers.
            hidden_state: Previous hidden state. When ``args.use_rnn`` is
                truthy it is reshaped to ``(-1, hidden_dim)`` and fed to the
                GRU cell. When ``args.use_rnn`` is falsy it is ignored
                entirely, so any value (including ``None``) is accepted.

        Returns:
            A tuple ``(v, a, h)``: ``v`` with one column (value head),
            ``a`` with ``n_actions`` columns (advantage head) and ``h``,
            the ``hidden_dim``-wide features both heads were computed from.
            ``v`` and ``a`` are returned separately; they are not combined
            into Q-values here.
        """
        x = F.relu(self.fc1(inputs))
        if self.args.use_rnn:
            # Only the recurrent path consumes the hidden state, so only
            # this path imposes a shape requirement on it.
            h_in = hidden_state.reshape(-1, self.args.hidden_dim)
            h = self.rnn(x, h_in)
        else:
            h = F.relu(self.rnn(x))

        # Column split: leading part -> value head, trailing part -> advantage head.
        h_v = h[:, :self.hidden_dim_v]
        h_a = h[:, self.hidden_dim_v:]
        v = self.fc2_v(h_v)
        a = self.fc2_a(h_a)
        return v, a, h

