from common.imports import *
from common.utils import Linear, th_act_fns

class RQNetwork(nn.Module):
    """Recurrent (GRU) Q-network for one agent of a vectorized environment.

    Pipeline: ``Linear(obs -> h_size)`` + activation -> single-layer GRU
    (``batch_first=True``) -> linear head with one output per action.

    Args:
        venv: vectorized env exposing ``single_observation_space`` and
            ``single_action_space``, both indexable by ``agent_id``.
        args: config providing ``act_fn`` (key into ``th_act_fns``),
            ``n_envs`` and ``h_size``.
        agent_id: index/key selecting this agent's spaces.
    """

    def __init__(self, venv, args, agent_id):
        super().__init__()

        act_str, self.act_fn = args.act_fn, th_act_fns[args.act_fn]
        self.n_envs = args.n_envs

        self.obs_space = venv.single_observation_space[agent_id]
        self.act_space = venv.single_action_space[agent_id]

        self.hs = args.h_size
        # NOTE(review): input size is taken from ``obs_space.n``, so the
        # observation space is presumably Discrete (one-hot inputs) — confirm.
        self.hidden = Linear(self.obs_space.n, self.hs, act_str)
        self.gru = nn.GRU(self.hs, self.hs, batch_first=True)
        self.output = Linear(self.hs, self.act_space.n, 'linear')

    def forward(self, x, h):    # x = (n_envs, 1, size), h = (1, n_envs, size)
        """Run one recurrent step.

        Returns:
            (q_values, h): Q-values of shape ``(n_envs, 1, n_actions)`` and the
            GRU hidden state with its leading layer dimension.
        """
        x = self.act_fn(self.hidden(x))
        x, h = self.gru(x, h)
        return self.output(x), h

    def get_action(self, x, h, eps=0.0):
        """Select epsilon-greedy actions and advance the hidden state.

        A single uniform draw decides for the whole batch: with probability
        ``eps`` every env gets a uniformly random action, otherwise every env
        gets the greedy (argmax-Q) action. The forward pass always runs, so
        the hidden state is advanced in both cases.

        Args:
            x: observations, ``(n_envs, 1, size)``.
            h: hidden state without the GRU layer dim, ``(n_envs, h_size)``.
            eps: exploration probability.

        Returns:
            (actions, h): actions of shape ``(x.shape[0],)`` on the model's
            device, and the new hidden state of shape ``(n_envs, h_size)``.
        """
        q_values, h = self.forward(x, h.unsqueeze(0))
        h = h.squeeze(0)
        if np.random.rand() < eps:
            # Allocate on the same device as the greedy branch so callers get
            # a consistently placed tensor regardless of which branch fired.
            actions = th.randint(
                high=self.act_space.n,
                size=(x.shape[0],),
                device=q_values.device,
            )
            return actions, h
        return th.argmax(q_values, dim=-1).squeeze(-1), h

    def init_hidden(self, device=None):
        """Return an all-zero hidden state of shape ``(n_envs, h_size)``.

        Args:
            device: target device; defaults to the device of this module's
                parameters so the state can be fed straight into ``forward``.
        """
        if device is None:
            device = next(self.parameters()).device
        return th.zeros([self.n_envs, self.hs], device=device)
    
class DuelRQNetwork(RQNetwork):
    """Dueling variant of ``RQNetwork``.

    Replaces the single Q head with a state-value head ``V`` (one output) and
    an advantage head ``A`` (one output per action), combined as
    ``Q = V + (A - mean_over_actions(A))``.
    """

    def __init__(self, venv, args, agent_id):
        super().__init__(venv, args, agent_id)

        # The parent's single Q head is unused here; drop it so it does not
        # appear among this module's parameters.
        del self.output
        self.V = Linear(self.hs, 1, 'linear')
        self.A = Linear(self.hs, self.act_space.n, 'linear')

    def forward(self, x, h):    # x = (n_envs, 1, size), h = (1, n_envs, size)
        """Run one recurrent step and return ``(q_values, h)``.

        ``q_values`` has shape ``(n_envs, 1, n_actions)``; ``h`` is the GRU
        hidden state with its leading layer dimension.
        """
        x = self.act_fn(self.hidden(x))
        x, h = self.gru(x, h)
        v = self.V(x)
        a = self.A(x)
        # Centre advantages over the action dimension only. A global mean
        # would mix statistics across envs/timesteps, so one env's Q-values
        # would depend on the other envs in the batch.
        return v + (a - a.mean(dim=-1, keepdim=True)), h