import tensorflow as tf
from tonic import logger, replays
from tonic.tensorflow import agents, models, normalizers, updaters



def default_model():
    """Build the model used when an agent in this file receives none.

    An actor and a critic are each assembled from an encoder, an MLP torso
    and a head, then passed together with a ``normalizers.MeanStd``
    observation normalizer to ``models.ActorTwinCriticWithTargets``.
    """
    # Actor: observation encoder -> MLP -> DeterministicPolicyHead.
    policy = models.Actor(
        encoder=models.ObservationEncoder(),
        torso=models.MLP((256, 256), 'relu'),
        head=models.DeterministicPolicyHead(),
    )

    # Critic: observation/action encoder -> MLP -> DistributionalValueHead.
    # NOTE(review): the head arguments look like (lowest value, highest
    # value, number of atoms) -- confirm against DistributionalValueHead.
    value = models.Critic(
        encoder=models.ObservationActionEncoder(),
        torso=models.MLP((256, 256), 'relu'),
        head=models.DistributionalValueHead(-150., 150., 51),
    )

    obs_normalizer = normalizers.MeanStd()
    return models.ActorTwinCriticWithTargets(
        actor=policy,
        critic=value,
        observation_normalizer=obs_normalizer,
    )


class TD4Actor(updaters.actors.DistributionalDeterministicPolicyGradient):
    """Actor updater whose default optimizer is Adam with a 1e-4 step size.

    Args:
        optimizer: Optimizer to store on the updater. When falsy (e.g.
            ``None``), an Adam optimizer with ``learning_rate=1e-4`` and
            ``epsilon=1e-8`` is created.
        gradient_clip: Stored as ``self.gradient_clip``.
            NOTE(review): presumably 0 disables clipping in the parent
            class -- confirm.
    """

    def __init__(self, optimizer=None, gradient_clip=0):
        # NOTE(review): the parent __init__ is not called; this assumes the
        # two attributes set here are all the parent relies on -- confirm.
        # `learning_rate` is the supported keyword; the `lr` alias is
        # deprecated and is rejected by newer Keras optimizers.
        self.optimizer = optimizer or tf.keras.optimizers.Adam(
            learning_rate=1e-4, epsilon=1e-8)
        self.gradient_clip = gradient_clip


class TD4Critic(updaters.critics.TwinCriticDistributionalDeterministicQLearning):
    """Critic updater with a 1e-4 Adam default and target action noise.

    Args:
        optimizer: Optimizer to store on the updater. When falsy (e.g.
            ``None``), an Adam optimizer with ``learning_rate=1e-4`` and
            ``epsilon=1e-8`` is created.
        target_action_noise: Noise object to store on the updater. When
            falsy, ``updaters.critics.TargetActionNoise(scale=0.25,
            clip=0.5)`` is created.
        gradient_clip: Stored as ``self.gradient_clip``.
            NOTE(review): presumably 0 disables clipping in the parent
            class -- confirm.
    """

    def __init__(self, optimizer=None, target_action_noise=None, gradient_clip=0):
        # NOTE(review): the parent __init__ is not called; this assumes the
        # attributes set here are all the parent relies on -- confirm.
        # `learning_rate` is the supported keyword; the `lr` alias is
        # deprecated and is rejected by newer Keras optimizers.
        self.optimizer = optimizer or tf.keras.optimizers.Adam(
            learning_rate=1e-4, epsilon=1e-8)
        self.target_action_noise = target_action_noise or \
            updaters.critics.TargetActionNoise(scale=0.25, clip=0.5)
        self.gradient_clip = gradient_clip


class TunedTD4(agents.TD3):
    """TD3 agent preconfigured with this file's defaults.

    Any of the following arguments left falsy (e.g. ``None``) is replaced
    before being handed to ``agents.TD3``: ``model`` by ``default_model()``,
    ``replay`` by ``replays.Buffer(return_steps=5)``, ``actor_updater`` by
    ``TD4Actor()`` and ``critic_updater`` by ``TD4Critic()``.
    ``exploration`` and ``delay_steps`` are forwarded unchanged.
    """

    def __init__(
        self, model=None, replay=None, exploration=None, actor_updater=None,
        critic_updater=None, delay_steps=2
    ):
        # Fill in every component the caller did not supply.
        if not model:
            model = default_model()
        if not replay:
            replay = replays.Buffer(return_steps=5)
        if not actor_updater:
            actor_updater = TD4Actor()
        if not critic_updater:
            critic_updater = TD4Critic()

        parent_kwargs = {
            'model': model,
            'replay': replay,
            'exploration': exploration,
            'actor_updater': actor_updater,
            'critic_updater': critic_updater,
            'delay_steps': delay_steps,
        }
        super().__init__(**parent_kwargs)

class DistributionalMPO(agents.MPO):
    """MPO agent defaulting to this file's model and a twin critic updater.

    Any argument left falsy (e.g. ``None``) is replaced: ``model`` by
    ``default_model()``, ``replay`` by ``replays.Buffer(return_steps=3)``,
    ``actor_updater`` by ``updaters.MaximumAPosterioriPolicyOptimization()``
    and ``critic_updater`` by
    ``updaters.critics.TwinCriticDistributionalDeterministicQLearning()``.
    """

    def __init__(
        self, model=None, replay=None, actor_updater=None, critic_updater=None
    ):
        # NOTE(review): the parent __init__ is not called; this assumes the
        # attributes set here are all the parent relies on -- confirm.
        self.model = model if model else default_model()
        self.replay = replay if replay else replays.Buffer(return_steps=3)

        if actor_updater:
            self.actor_updater = actor_updater
        else:
            self.actor_updater = \
                updaters.MaximumAPosterioriPolicyOptimization()

        if critic_updater:
            self.critic_updater = critic_updater
        else:
            critics = updaters.critics
            self.critic_updater = \
                critics.TwinCriticDistributionalDeterministicQLearning()

        # Expose the first critic and its target under the plain
        # `critic` / `target_critic` attribute names on the model.
        net = self.model
        net.critic = net.critic_1
        net.target_critic = net.target_critic_1


