import torch
from torch import nn

from .utils import build_mlp


class StateFunction(nn.Module):
    """State-value network with input standardization.

    Incoming states are standardized with ``mean`` / ``std``, clipped to
    [-5, 5], and fed to an MLP built by ``build_mlp`` with ``output_dim=1``.

    Args:
        state_shape: Shape of a single state. ``state_shape[0]`` is the MLP
            input width; ``mean`` / ``std`` are created with this full shape.
        hidden_units: Hidden-layer widths forwarded to ``build_mlp``.
        hidden_activation: Activation module forwarded to ``build_mlp``. The
            default ``nn.Tanh()`` is evaluated once at class-definition time,
            so the same instance is shared by every object using the default.
    """

    def __init__(self, state_shape, hidden_units=(64, 64),
                 hidden_activation=nn.Tanh()):
        super().__init__()

        self.net = build_mlp(
            input_dim=state_shape[0],
            output_dim=1,
            hidden_units=hidden_units,
            hidden_activation=hidden_activation
        )
        # Normalization statistics (identity transform by default). These are
        # registered as buffers so that ``.to(device)`` / ``.cuda()`` /
        # ``.double()`` move them together with the weights. As plain tensor
        # attributes they stayed on the CPU, and ``forward`` failed on other
        # devices. ``persistent=False`` keeps the ``state_dict`` keys exactly
        # as before, so existing checkpoints load with ``strict=True``.
        self.register_buffer(
            "mean", torch.zeros(state_shape, dtype=torch.float),
            persistent=False)
        self.register_buffer(
            "std", torch.ones(state_shape, dtype=torch.float),
            persistent=False)

    def _normalize(self, states):
        """Standardize ``states`` with ``mean`` / ``std`` and clip to [-5, 5]."""
        return torch.clamp((states - self.mean) / self.std, min=-5.0, max=5.0)

    def forward(self, states):
        """Return the MLP output for the standardized, clipped ``states``."""
        return self.net(self._normalize(states))


class StateActionFunction(nn.Module):
    """State-action network with state standardization.

    States are standardized with ``mean`` / ``std`` and clipped to [-5, 5].
    Actions are left untouched. The two are concatenated along the last
    dimension and fed to an MLP built by ``build_mlp`` with ``output_dim=1``.

    Args:
        state_shape: Shape of a single state. ``state_shape[0]`` contributes
            to the MLP input width; ``mean`` / ``std`` use this full shape.
        action_shape: Shape of a single action. ``action_shape[0]`` is added
            to the MLP input width.
        hidden_units: Hidden-layer widths forwarded to ``build_mlp``.
        hidden_activation: Activation module forwarded to ``build_mlp``. The
            default ``nn.Tanh()`` is evaluated once at class-definition time,
            so the same instance is shared by every object using the default.
    """

    def __init__(self, state_shape, action_shape, hidden_units=(100, 100),
                 hidden_activation=nn.Tanh()):
        super().__init__()

        self.net = build_mlp(
            input_dim=state_shape[0] + action_shape[0],
            output_dim=1,
            hidden_units=hidden_units,
            hidden_activation=hidden_activation
        )
        # State normalization statistics (identity transform by default).
        # These are buffers so device/dtype moves (``.to``, ``.cuda``,
        # ``.double``) carry them along with the weights. ``persistent=False``
        # leaves the ``state_dict`` keys unchanged for existing checkpoints.
        self.register_buffer(
            "mean", torch.zeros(state_shape, dtype=torch.float),
            persistent=False)
        self.register_buffer(
            "std", torch.ones(state_shape, dtype=torch.float),
            persistent=False)

    def _normalize(self, states):
        """Standardize ``states`` with ``mean`` / ``std`` and clip to [-5, 5]."""
        return torch.clamp((states - self.mean) / self.std, min=-5.0, max=5.0)

    def forward(self, states, actions):
        """Return the MLP output for [normalized states, actions].

        Assumes ``states`` and ``actions`` share all leading (batch)
        dimensions, since they are joined with ``torch.cat(dim=-1)``.
        """
        return self.net(torch.cat([self._normalize(states), actions], dim=-1))


class TwinnedStateActionFunction(nn.Module):
    """Two independent state-action MLPs evaluated on the same input.

    States are standardized with ``mean`` / ``std`` and clipped to [-5, 5],
    then concatenated with the (unmodified) actions along the last dimension.
    ``forward`` returns the outputs of both networks; ``q1`` returns only the
    first network's output.

    Args:
        state_shape: Shape of a single state. ``state_shape[0]`` contributes
            to the MLP input width; ``mean`` / ``std`` use this full shape.
        action_shape: Shape of a single action. ``action_shape[0]`` is added
            to the MLP input width.
        hidden_units: Hidden-layer widths forwarded to ``build_mlp`` for both
            networks.
        hidden_activation: Activation module forwarded to ``build_mlp`` for
            both networks. The default ``nn.ReLU(inplace=True)`` is evaluated
            once at class-definition time, so the same instance is shared by
            both networks and every object using the default.
    """

    def __init__(self, state_shape, action_shape, hidden_units=(256, 256),
                 hidden_activation=nn.ReLU(inplace=True)):
        super().__init__()

        self.net1 = build_mlp(
            input_dim=state_shape[0] + action_shape[0],
            output_dim=1,
            hidden_units=hidden_units,
            hidden_activation=hidden_activation
        )
        self.net2 = build_mlp(
            input_dim=state_shape[0] + action_shape[0],
            output_dim=1,
            hidden_units=hidden_units,
            hidden_activation=hidden_activation
        )
        # State normalization statistics (identity transform by default).
        # These are buffers so device/dtype moves (``.to``, ``.cuda``,
        # ``.double``) carry them along with the weights. ``persistent=False``
        # leaves the ``state_dict`` keys unchanged for existing checkpoints.
        self.register_buffer(
            "mean", torch.zeros(state_shape, dtype=torch.float),
            persistent=False)
        self.register_buffer(
            "std", torch.ones(state_shape, dtype=torch.float),
            persistent=False)

    def _inputs(self, states, actions):
        """Return [standardized, clipped states, actions] joined on dim -1."""
        states = torch.clamp((states - self.mean) / self.std, min=-5.0, max=5.0)
        return torch.cat([states, actions], dim=-1)

    def forward(self, states, actions):
        """Return ``(net1(x), net2(x))`` for the shared normalized input ``x``."""
        xs = self._inputs(states, actions)
        # NOTE(review): ``xs`` is fed to both networks. This is only safe if
        # build_mlp's first layer does not modify its input in place (the
        # default activation is an in-place ReLU) — confirm in build_mlp.
        return self.net1(xs), self.net2(xs)

    def q1(self, states, actions):
        """Return only the first network's output for the normalized input."""
        return self.net1(self._inputs(states, actions))
