import torch as th
import torch.nn as nn
import numpy as np


class MAOSDQNCritic(nn.Module):
    """Critic made of two independent MLP Q-networks.

    Each network takes the state vector concatenated with the flattened
    one-hot actions of all agents and produces one scalar per batch row.
    ``forward`` evaluates exactly one of the two networks, chosen by its
    ``id`` argument.
    """

    def __init__(self, scheme, args):
        """Build both Q-networks.

        Args:
            scheme: Mapping where ``scheme["state"]["vshape"]`` is the size of
                the state vector (used as an integer feature count).
            args: Config object providing ``n_actions``, ``n_agents`` and
                ``hidden_dim``. An optional ``random_output`` flag enables
                occasional debug printing in ``forward``.
        """
        super(MAOSDQNCritic, self).__init__()

        self.args = args
        self.n_actions = args.n_actions
        self.n_agents = args.n_agents

        input_shape = self._get_input_shape(scheme)
        self.output_type = "q"

        # Two structurally identical but separately parameterised networks.
        # They are built one after another so parameter names
        # (``net.0.*``, ``net.1.*``) and initialisation order stay stable.
        self.net = nn.ModuleList([
            self._build_q_net(input_shape, args.hidden_dim),
            self._build_q_net(input_shape, args.hidden_dim),
        ])

    @staticmethod
    def _build_q_net(input_shape, hidden_dim):
        """Return a 3-layer ReLU MLP mapping ``input_shape`` features to 1 value."""
        return nn.Sequential(
            nn.Linear(input_shape, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
        )

    def forward(self, states, actions_onehot, id=0):
        """Evaluate one of the two Q-networks.

        Args:
            states: 2-D tensor whose rows are state vectors; it is
                concatenated along ``dim=1`` with the flattened actions.
            actions_onehot: Tensor of one-hot actions; it is flattened to
                rows of ``n_agents * n_actions`` values.
            id: Index into the network list (``0`` or ``1``). The name
                shadows the builtin but is kept for caller compatibility.

        Returns:
            Tensor with one Q-value per row (last dimension of size 1).
        """
        # ``reshape`` rather than ``view``: identical for contiguous input,
        # but also accepts non-contiguous tensors (e.g. after a transpose),
        # where ``view`` would raise a RuntimeError.
        joint_actions = actions_onehot.reshape(-1, self.n_agents * self.n_actions)
        inputs = th.cat((states, joint_actions), dim=1)
        q = self.net[id](inputs)

        # Optional debug sampling: print the first row on ~0.5% of calls.
        # The flag is treated as off when the config does not define it.
        if getattr(self.args, "random_output", False) and np.random.rand() < 0.005:
            print("critic output: ",inputs[0],q[0])
        return q

    def _get_input_shape(self, scheme):
        """Return the network input width: state size plus joint-action size."""
        # states
        input_shape = scheme["state"]["vshape"]
        # joint actions
        input_shape += self.n_agents * self.n_actions
        return input_shape
