import numpy
import torch
import torch.nn as nn


####################
# Actor blueprints #
####################

class ContinuousStatesDiscreteActionsNet(nn.Module):
    """Actor mapping a continuous state to probabilities over discrete actions.

    The network is a two-hidden-layer MLP whose output is normalized with a
    softmax over the action axis.

    Args:
        num_state: Number of input features (size of the state vector).
        num_action: Number of discrete actions (size of the output).
        hidden_layers_size: Widths of the hidden layers; entries 0 and 1 are used.
        activation_functions: Activation modules placed after each hidden
            layer; entries 0 and 1 are used.

    Raises:
        IndexError: If either list has fewer than two entries.
    """

    def __init__(self,
                 num_state: int,
                 num_action: int,
                 hidden_layers_size: list,
                 activation_functions: list):
        super(ContinuousStatesDiscreteActionsNet, self).__init__()

        self.model = torch.nn.Sequential(torch.nn.Linear(num_state, hidden_layers_size[0]),
                                         activation_functions[0],
                                         torch.nn.Linear(hidden_layers_size[0], hidden_layers_size[1]),
                                         activation_functions[1],
                                         torch.nn.Linear(hidden_layers_size[1], num_action),
                                         # Normalize over the last (action) axis: identical to
                                         # dim=1 for a (batch, num_state) input, and also valid
                                         # for a single unbatched state of shape (num_state,).
                                         torch.nn.Softmax(dim=-1))

    def forward(self, state):
        """Return a probability distribution over the action space.

        Args:
            state: Tensor whose last axis has size ``num_state``; it may be a
                single state or a batch of states.

        Returns:
            Tensor with the same leading axes as ``state`` and a last axis of
            size ``num_action`` whose entries sum to 1.
        """
        return self.model(state)


class ContinuousStatesContinuousActionsNet(nn.Module):
    """Actor for continuous actions producing weights and bounded action values.

    The network is a two-hidden-layer MLP with ``2 * num_action`` outputs
    squashed by ``tanh``. The first half of the outputs is turned into a
    probability vector; the second half is rescaled from ``[-1, 1]`` to
    ``[low, high]``.

    Args:
        num_state: Number of input features (size of the state vector).
        num_action: Size of each half of the network output.
        low: Lower bound of the rescaled action values.
        high: Upper bound of the rescaled action values.
        hidden_layers_size: Widths of the hidden layers; entries 0 and 1 are used.
        activation_functions: Activation modules placed after each hidden
            layer; entries 0 and 1 are used.

    Raises:
        IndexError: If either list has fewer than two entries.
    """

    def __init__(self,
                 num_state: int,
                 num_action: int,
                 low: float,
                 high: float,
                 hidden_layers_size: list,
                 activation_functions: list):
        super(ContinuousStatesContinuousActionsNet, self).__init__()

        self.low = low
        self.high = high
        self.num_action = num_action
        self.model = torch.nn.Sequential(torch.nn.Linear(num_state, hidden_layers_size[0]),
                                         activation_functions[0],
                                         torch.nn.Linear(hidden_layers_size[0], hidden_layers_size[1]),
                                         activation_functions[1],
                                         torch.nn.Linear(hidden_layers_size[1], num_action * 2),
                                         torch.nn.Tanh())

    def forward(self, state):
        """Return ``(prob, samples)`` computed from ``state``.

        Args:
            state: Tensor whose last axis has size ``num_state``; it may be a
                single state or a batch of states.

        Returns:
            A tuple ``(prob, samples)``. Both tensors keep the leading axes of
            ``state`` and have a last axis of size ``num_action``. ``prob`` is
            a softmax over the first half of the network output, and
            ``samples`` is the second half mapped linearly into
            ``[low, high]``.
        """
        y = self.model(state)
        # Split along the last (feature) axis so batched inputs are handled
        # per sample; slicing axis 0 would cut the batch instead.
        prob = torch.softmax(y[..., :self.num_action], dim=-1)
        # tanh output lies in [-1, 1]: shift to [0, 1], then scale to [low, high].
        samples = (y[..., self.num_action:] + 1.0) / 2.0 * (self.high - self.low) + self.low
        return prob, samples


#####################
# Critic blueprints #
#####################

class QFunctionNN(nn.Module):
    """Critic that maps a joint state-action vector to a single scalar.

    The input width is ``num_state + num_action``; two hidden layers follow,
    then a linear head with one output unit and a trailing identity module.
    Entries 0 and 1 of ``hidden_layers_size`` and ``activation_functions``
    define the hidden layers.
    """

    def __init__(self,
                 num_state: int,
                 num_action: int,
                 hidden_layers_size: list,
                 activation_functions: list):
        super().__init__()

        input_width = num_state + num_action
        # Layers are created in network order so parameter initialization
        # consumes the random generator in the same sequence.
        stack = [
            nn.Linear(input_width, hidden_layers_size[0]),
            activation_functions[0],
            nn.Linear(hidden_layers_size[0], hidden_layers_size[1]),
            activation_functions[1],
            nn.Linear(hidden_layers_size[1], 1),
            nn.Identity(),
        ]
        self.model = nn.Sequential(*stack)

    def forward(self, x):
        """Run ``x`` through the network and return its output."""
        q_value = self.model(x)
        return q_value


class VFunctionNN(nn.Module):
    """Critic that maps a state vector to a single scalar.

    Either one or two hidden layers are supported, selected by the lengths of
    ``hidden_layers_size`` and ``activation_functions`` (which must match).

    Args:
        num_state: Number of input features (size of the state vector).
        hidden_layers_size: Widths of the hidden layers (length 1 or 2).
        activation_functions: Activation modules placed after each hidden
            layer (same length as ``hidden_layers_size``).

    Raises:
        ValueError: If the two lists are not both of length 1 or both of
            length 2.
    """

    def __init__(self,
                 num_state: int,
                 hidden_layers_size: list,
                 activation_functions: list):
        super(VFunctionNN, self).__init__()

        if len(hidden_layers_size) == 2 and len(activation_functions) == 2:
            self.model = torch.nn.Sequential(torch.nn.Linear(num_state, hidden_layers_size[0]),
                                             activation_functions[0],
                                             torch.nn.Linear(hidden_layers_size[0], hidden_layers_size[1]),
                                             activation_functions[1],
                                             torch.nn.Linear(hidden_layers_size[1], 1),
                                             torch.nn.Identity())
        elif len(hidden_layers_size) == 1 and len(activation_functions) == 1:
            self.model = torch.nn.Sequential(torch.nn.Linear(num_state, hidden_layers_size[0]),
                                             activation_functions[0],
                                             torch.nn.Linear(hidden_layers_size[0], 1),
                                             torch.nn.Identity())
        else:
            # Fail at construction time; otherwise self.model would be left
            # undefined and the error would only surface later in forward().
            raise ValueError("Unknown shape.")

    def forward(self, x):
        """Run ``x`` through the network and return its output."""
        return self.model(x)
        
