import torch
from tonic.torch import agents
from tonic.torch import models
from tonic.torch import normalizers
from tonic.torch import updaters


def custom_model_mpo(hidden_size=256):
    """Assemble an actor-critic model with target networks.

    The actor combines an observation encoder, an MLP torso and a Gaussian
    policy head. The critic combines an observation-action encoder, an MLP
    torso and a value head. Observations are normalized with a ``MeanStd``
    normalizer. Judging by the name, this is intended for an MPO agent.

    Args:
        hidden_size: Size passed twice to each MLP torso, i.e. both torsos
            are configured with ``(hidden_size, hidden_size)``.

    Returns:
        A ``models.ActorCriticWithTargets`` instance.
    """
    layer_sizes = (hidden_size, hidden_size)

    policy_net = models.Actor(
        encoder=models.ObservationEncoder(),
        torso=models.MLP(layer_sizes, torch.nn.ReLU),
        head=models.GaussianPolicyHead(),
    )
    critic_net = models.Critic(
        encoder=models.ObservationActionEncoder(),
        torso=models.MLP(layer_sizes, torch.nn.ReLU),
        head=models.ValueHead(),
    )

    return models.ActorCriticWithTargets(
        actor=policy_net,
        critic=critic_net,
        observation_normalizer=normalizers.MeanStd(),
    )


def custom_return_mpo(hidden_size=256):
    """Assemble an actor-critic model with targets and return normalization.

    The actor combines an observation encoder, an MLP torso and a Gaussian
    policy head. The critic combines an observation-action encoder, an MLP
    torso and a value head. Observations are normalized with ``MeanStd`` and
    returns with ``normalizers.returns.Return(0.99)``.

    Args:
        hidden_size: Size passed twice to each MLP torso, i.e. both torsos
            are configured with ``(hidden_size, hidden_size)``.

    Returns:
        A ``models.ActorCriticWithTargets`` instance.
    """
    layer_sizes = (hidden_size, hidden_size)

    policy_net = models.Actor(
        encoder=models.ObservationEncoder(),
        torso=models.MLP(layer_sizes, torch.nn.ReLU),
        head=models.GaussianPolicyHead(),
    )
    critic_net = models.Critic(
        encoder=models.ObservationActionEncoder(),
        torso=models.MLP(layer_sizes, torch.nn.ReLU),
        head=models.ValueHead(),
    )

    observation_norm = normalizers.MeanStd()
    # NOTE(review): 0.99 is presumably the discount factor used by the
    # return normalizer -- confirm against the normalizer's signature.
    return_norm = normalizers.returns.Return(0.99)

    return models.ActorCriticWithTargets(
        actor=policy_net,
        critic=critic_net,
        observation_normalizer=observation_norm,
        return_normalizer=return_norm,
    )


def custom_model_d4pg(hidden_size=256):
    """Assemble an actor-critic model with a distributional critic.

    The actor combines an observation encoder, an MLP torso and a
    deterministic policy head. The critic combines an observation-action
    encoder, an MLP torso and a distributional value head. Observations are
    normalized with ``MeanStd``. Judging by the name, this is intended for a
    D4PG agent.

    Args:
        hidden_size: Size passed twice to each MLP torso, i.e. both torsos
            are configured with ``(hidden_size, hidden_size)``.

    Returns:
        A ``models.ActorCriticWithTargets`` instance.
    """
    layer_sizes = (hidden_size, hidden_size)

    policy_net = models.Actor(
        encoder=models.ObservationEncoder(),
        torso=models.MLP(layer_sizes, torch.nn.ReLU),
        head=models.DeterministicPolicyHead(),
    )

    critic_encoder = models.ObservationActionEncoder()
    critic_torso = models.MLP(layer_sizes, torch.nn.ReLU)
    # These values are for the control suite with 0.99 discount.
    critic_head = models.DistributionalValueHead(-150.0, 150.0, 51)
    critic_net = models.Critic(
        encoder=critic_encoder,
        torso=critic_torso,
        head=critic_head,
    )

    return models.ActorCriticWithTargets(
        actor=policy_net,
        critic=critic_net,
        observation_normalizer=normalizers.MeanStd(),
    )
