"""
Pytorch models for continuous control.

All models assume that a feature representing the
current timestep is used in addition to the features
received from the environment.
"""

import torch

from all import nn


def fc_q(env, hidden1=400, hidden2=300):
    return nn.Sequential(
        nn.Float(),
        nn.Linear(env.state_space.shape[0] + env.action_space.shape[0], hidden1),
        nn.ReLU(),
        nn.Linear(hidden1, hidden2),
        nn.ReLU(),
        nn.Linear0(hidden2, 1),
    )


def fc_v(env, hidden1=400, hidden2=300):
    return nn.Sequential(
        nn.Float(),
        nn.Linear(env.state_space.shape[0], hidden1),
        nn.ReLU(),
        nn.Linear(hidden1, hidden2),
        nn.ReLU(),
        nn.Linear0(hidden2, 1),
    )


def fc_deterministic_policy(env, hidden1=400, hidden2=300):
    return nn.Sequential(
        nn.Float(),
        nn.Linear(env.state_space.shape[0], hidden1),
        nn.ReLU(),
        nn.Linear(hidden1, hidden2),
        nn.ReLU(),
        nn.Linear0(hidden2, env.action_space.shape[0]),
    )


def fc_soft_policy(env, hidden1=400, hidden2=300):
    return nn.Sequential(
        nn.Float(),
        nn.Linear(env.state_space.shape[0], hidden1),
        nn.ReLU(),
        nn.Linear(hidden1, hidden2),
        nn.ReLU(),
        nn.Linear0(hidden2, env.action_space.shape[0] * 2),
    )


class fc_policy(nn.Module):
    def __init__(self, env, hidden1=400, hidden2=300):
        super().__init__()
        self.model = nn.Sequential(
            nn.Float(),
            nn.Linear(env.state_space.shape[0], hidden1),
            nn.Tanh(),
            nn.Linear(hidden1, hidden2),
            nn.Tanh(),
            nn.Linear(hidden2, env.action_space.shape[0]),
        )
        self.log_stds = nn.Parameter(torch.zeros(env.action_space.shape[0]))

    def forward(self, x):
        means = self.model(x)
        stds = self.log_stds.expand(*means.shape)
        return torch.cat((means, stds), 1)
