import tensorflow as tf
from tonic import logger
from tonic import replays
from tonic.tensorflow import agents
from tonic.tensorflow import models
from tonic.tensorflow import normalizers
from tonic.tensorflow import updaters


def default_model():
    """Build the default actor / twin-critic model used by the agents below.

    Each call constructs new component instances:

    * actor: ``ObservationEncoder`` -> ``MLP((256, 256), "relu")`` ->
      ``DeterministicPolicyHead``;
    * critic: ``ObservationActionEncoder`` -> ``MLP((256, 256), "relu")`` ->
      ``DistributionalValueHead(-150.0, 150.0, 51)``;
    * observation normalizer: ``normalizers.MeanStd``.

    Returns:
        A ``models.ActorTwinCriticWithTargets`` wrapping those components.
    """
    hidden_sizes = (256, 256)

    # Components are created in the same order as a nested constructor call
    # would create them, so any auto-generated layer names stay the same.
    policy = models.Actor(
        encoder=models.ObservationEncoder(),
        torso=models.MLP(hidden_sizes, "relu"),
        head=models.DeterministicPolicyHead(),
    )
    q_function = models.Critic(
        encoder=models.ObservationActionEncoder(),
        torso=models.MLP(hidden_sizes, "relu"),
        # Presumably (lower bound, upper bound, number of atoms) of the value
        # support; confirm against models.DistributionalValueHead.
        head=models.DistributionalValueHead(-150.0, 150.0, 51),
    )
    return models.ActorTwinCriticWithTargets(
        actor=policy,
        critic=q_function,
        observation_normalizer=normalizers.MeanStd(),
    )


class TD4Actor(updaters.actors.DistributionalDeterministicPolicyGradient):
    """Distributional deterministic-policy-gradient actor updater, TD4 defaults.

    Args:
        optimizer: Optimizer used for the actor. When falsy (the default
            ``None``), an Adam optimizer with ``learning_rate=1e-4`` and
            ``epsilon=1e-8`` is created.
        gradient_clip: Stored as ``self.gradient_clip`` for the parent
            updater's use (presumably 0 disables clipping; confirm in the
            parent class).
    """

    def __init__(self, optimizer=None, gradient_clip=0):
        # NOTE(review): the parent ``__init__`` is deliberately not called;
        # this assumes it only sets ``optimizer`` and ``gradient_clip``.
        # Confirm if the parent class changes.
        #
        # ``learning_rate`` is the documented keyword. The legacy ``lr``
        # alias is deprecated. Newer Keras optimizers either drop it with only
        # a warning, which silently trains at the default 1e-3, or reject it.
        self.optimizer = optimizer or tf.keras.optimizers.Adam(
            learning_rate=1e-4, epsilon=1e-8
        )
        self.gradient_clip = gradient_clip


class TD4Critic(updaters.critics.TwinCriticDistributionalDeterministicQLearning):
    """Twin distributional deterministic Q-learning critic updater, TD4 defaults.

    Args:
        optimizer: Optimizer used for the critics. When falsy (the default
            ``None``), an Adam optimizer with ``learning_rate=1e-4`` and
            ``epsilon=1e-8`` is created.
        target_action_noise: Noise object stored as
            ``self.target_action_noise``. When falsy, defaults to
            ``updaters.critics.TargetActionNoise(scale=0.25, clip=0.5)``.
        gradient_clip: Stored as ``self.gradient_clip`` for the parent
            updater's use (presumably 0 disables clipping; confirm in the
            parent class).
    """

    def __init__(self, optimizer=None, target_action_noise=None, gradient_clip=0):
        # NOTE(review): the parent ``__init__`` is deliberately not called;
        # this assumes it only sets the three attributes assigned here.
        # Confirm if the parent class changes.
        #
        # ``learning_rate`` is the documented keyword. The legacy ``lr``
        # alias is deprecated. Newer Keras optimizers either drop it with only
        # a warning, which silently trains at the default 1e-3, or reject it.
        self.optimizer = optimizer or tf.keras.optimizers.Adam(
            learning_rate=1e-4, epsilon=1e-8
        )
        # Presumably scale = noise std and clip = bound on the noise added to
        # target actions; confirm against updaters.critics.TargetActionNoise.
        self.target_action_noise = (
            target_action_noise
            or updaters.critics.TargetActionNoise(scale=0.25, clip=0.5)
        )
        self.gradient_clip = gradient_clip


class TunedTD4(agents.TD3):
    """``agents.TD3`` preconfigured with this module's TD4 components.

    Any of ``model``, ``replay``, ``actor_updater`` or ``critic_updater`` that
    is falsy is replaced by, respectively, ``default_model()``,
    ``replays.Buffer(return_steps=5)``, ``TD4Actor()`` and ``TD4Critic()``.
    ``exploration`` and ``delay_steps`` are forwarded to ``agents.TD3``
    unchanged.
    """

    def __init__(
        self,
        model=None,
        replay=None,
        exploration=None,
        actor_updater=None,
        critic_updater=None,
        delay_steps=2,
    ):
        # Keyword arguments are evaluated left to right, so the defaults are
        # built in the order model, replay, actor updater, critic updater.
        super().__init__(
            model=model or default_model(),
            replay=replay or replays.Buffer(return_steps=5),
            exploration=exploration,
            actor_updater=actor_updater or TD4Actor(),
            critic_updater=critic_updater or TD4Critic(),
            delay_steps=delay_steps,
        )


class DistributionalMPO(agents.MPO):
    """``agents.MPO`` paired with a twin distributional critic.

    Any falsy argument is replaced by a default:

    * ``model``: ``default_model()``
    * ``replay``: ``replays.Buffer(return_steps=3)``
    * ``actor_updater``: ``updaters.MaximumAPosterioriPolicyOptimization()``
    * ``critic_updater``:
      ``updaters.critics.TwinCriticDistributionalDeterministicQLearning()``

    For twin-critic models (those exposing ``critic_1``), the first critic
    and its target are also exposed as ``model.critic`` and
    ``model.target_critic``.

    NOTE(review): ``default_model()`` gives the actor a
    ``DeterministicPolicyHead``. MPO usually samples from a stochastic
    policy, so confirm that ``agents.MPO`` and its actor updater accept this
    actor before relying on the default model.
    """

    def __init__(
        self, model=None, replay=None, actor_updater=None, critic_updater=None
    ):
        model = model or default_model()
        # Expose the first twin critic under the single-critic names
        # (presumably what the MPO code paths read). The alias is set before
        # the parent initializer runs, so the model is complete by then.
        # Models without ``critic_1`` (e.g. a caller-supplied single-critic
        # model) are left untouched instead of raising AttributeError.
        if hasattr(model, "critic_1"):
            model.critic = model.critic_1
            model.target_critic = model.target_critic_1
        # Delegate attribute setup to the parent, as TunedTD4 does, instead of
        # duplicating it and skipping the parent initializer.
        super().__init__(
            model=model,
            replay=replay or replays.Buffer(return_steps=3),
            actor_updater=(
                actor_updater or updaters.MaximumAPosterioriPolicyOptimization()
            ),
            critic_updater=(
                critic_updater
                or updaters.critics.TwinCriticDistributionalDeterministicQLearning()
            ),
        )
