import torch
from catatonic import custom_torso
from tonic.torch import agents, models, normalizers, updaters
from tonic import logger, replays  # noqa
from catatonic.utils import reduce_actor_observations

# Small constant added to the softplus-transformed dual variables
# (temperatures and alphas) so they stay bounded away from zero.
FLOAT_EPSILON = 1e-8


# class DistributionalDeterministicPolicyGradient:
#     def __init__(self, optimizer=None, gradient_clip=0):
#         self.optimizer = optimizer or \
#                 tf.keras.optimizers.Adam(lr=1e-4, epsilon=1e-8)
#         self.gradient_clip = gradient_clip
#
# class TwinCriticDistributionalDeterministicQLearning:
#     def __init__(self, optimizer=None, target_action_noise=None, gradient_clip=0):
#         self.optimizer = optimizer or \
#             tf.keras.optimizers.Adam(lr=1e-4, epsilon=1e-8)
#         self.target_action_noise = target_action_noise or \
#             TargetActionNoise(scale=0.2, clip=0.5)
#         self.gradient_clip = gradient_clip


def default_model(hidden_size=256):
    """Build an actor / twin-critic model with a distributional critic head.

    Args:
        hidden_size: width of each entry in the ``(hidden_size, hidden_size)``
            MLP torsos of the actor and the critics. Defaults to 256.

    Returns:
        A ``models.ActorTwinCriticWithTargets`` with a Gaussian policy head,
        a ``DistributionalValueHead(-150.0, 150.0, 51)`` critic head and a
        ``MeanStd`` observation normalizer.
    """
    hidden_sizes = (hidden_size, hidden_size)
    return models.ActorTwinCriticWithTargets(
        actor=models.Actor(
            encoder=models.ObservationEncoder(),
            # tonic.torch's MLP instantiates the activation for each layer, so
            # it needs a module class (as in `retnorm_mpo`), not a string name.
            torso=models.MLP(hidden_sizes, torch.nn.ReLU),
            head=models.GaussianPolicyHead(),
        ),
        critic=models.Critic(
            encoder=models.ObservationActionEncoder(),
            torso=models.MLP(hidden_sizes, torch.nn.ReLU),
            head=models.DistributionalValueHead(-150.0, 150.0, 51),
        ),
        observation_normalizer=normalizers.MeanStd(),
    )


def retnorm_mpo(hidden_size=256):
    """Build an MPO actor-critic with observation and return normalisation.

    Args:
        hidden_size: width of each entry in the ``(hidden_size, hidden_size)``
            MLP torsos of the actor and the critic. Defaults to 256.

    Returns:
        A ``models.ActorCriticWithTargets`` with a Gaussian policy head, a
        ``ValueHead`` critic head, a ``MeanStd`` observation normalizer and a
        ``normalizers.returns.Return(0.99)`` return normalizer (0.99 is
        presumably the discount factor — confirm).
    """
    # Previously the torsos were hard-coded to (256, 256), silently ignoring
    # `hidden_size`; the default keeps that architecture.
    hidden_sizes = (hidden_size, hidden_size)
    return models.ActorCriticWithTargets(
        actor=models.Actor(
            encoder=models.ObservationEncoder(),
            torso=models.MLP(hidden_sizes, torch.nn.ReLU),
            head=models.GaussianPolicyHead(),
        ),
        critic=models.Critic(
            encoder=models.ObservationActionEncoder(),
            torso=models.MLP(hidden_sizes, torch.nn.ReLU),
            head=models.ValueHead(),
        ),
        observation_normalizer=normalizers.MeanStd(),
        return_normalizer=normalizers.returns.Return(0.99),
    )


class TunedMPO(agents.MPO):
    """Maximum a Posteriori Policy Optimisation with tunable hyperparameters.

    MPO: https://arxiv.org/pdf/1806.06920.pdf
    MO-MPO: https://arxiv.org/pdf/2005.07513.pdf

    `set_params` replaces the critic/actor updaters with the tuned variants,
    rebuilds the model and optionally adjusts replay settings.
    """

    def set_params(
        self,
        lr_critic=1e-3,
        grad_clip_critic=0,
        lr_actor=3e-4,
        lr_dual=1e-2,
        grad_clip_actor=0,
        hidden_size=None,
        batch_size=None,
        retnorm=None,
        return_steps=None,
    ):
        """Replace the updaters and model, and optionally tweak the replay.

        NOTE(review): this presumably must run before `initialize`, where the
        parent agent wires the model and updaters together — confirm.

        Args:
            lr_critic: Adam learning rate for the critic.
            grad_clip_critic: gradient clipping value forwarded to the critic
                updater.
            lr_actor: Adam learning rate for the actor.
            lr_dual: Adam learning rate for the MPO dual variables.
            grad_clip_actor: gradient clipping value forwarded to the actor
                updater.
            hidden_size: hidden width passed to the model builder; defaults
                to 256 when None.
            batch_size: if not None, overrides ``self.replay.batch_size``.
            retnorm: truthy to build the return-normalised model
                (``retnorm_mpo``); falsy (including the default None and
                False) builds ``custom_torso.custom_model_mpo``.
            return_steps: if not None, overrides ``self.replay.return_steps``.
        """

        def critic_optimizer(params):
            return torch.optim.Adam(params, lr=lr_critic, capturable=True)

        def actor_optimizer(params):
            return torch.optim.Adam(params, lr=lr_actor, capturable=True)

        def dual_optimizer(params):
            return torch.optim.Adam(params, lr=lr_dual, capturable=True)

        self.critic_updater = TunedExpectedSARSA(
            optimizer=critic_optimizer, gradient_clip=grad_clip_critic
        )
        self.actor_updater = TunedMaximumAPosteriori(
            actor_optimizer=actor_optimizer,
            dual_optimizer=dual_optimizer,
            gradient_clip=grad_clip_actor,
        )

        if hidden_size is None:
            hidden_size = 256
        # Truthiness (not `is not None`) so that retnorm=False / 0 disables
        # return normalisation instead of enabling it.
        if retnorm:
            print("retnorm")
            self.model = retnorm_mpo(hidden_size=hidden_size)
        else:
            print("no retnorm")
            self.model = custom_torso.custom_model_mpo(hidden_size=hidden_size)

        if batch_size is not None:
            self.replay.batch_size = batch_size
        if return_steps is not None:
            self.replay.return_steps = return_steps

    def initialize(self, *args, **kwargs):
        """Initialize through the parent agent, then print the actor updater."""
        super().initialize(*args, **kwargs)
        print(f"{self.actor_updater}")


class TunedExpectedSARSA(updaters.critics.ExpectedSARSA):
    """Expected SARSA critic updater with a tuned default optimizer.

    Accepts the same constructor arguments as the parent. Without an explicit
    `optimizer` factory it builds Adam with lr=3e-4, weight_decay=1e-5 and
    capturable=True; without a `loss` it uses ``torch.nn.MSELoss()``. The
    parent constructor is not called — every attribute is assigned here.
    """

    def __init__(self, num_samples=20, loss=None, optimizer=None, gradient_clip=0):
        def tuned_adam(params):
            return torch.optim.Adam(
                params, lr=3e-4, weight_decay=1e-5, capturable=True)

        self.num_samples = num_samples
        self.loss = loss if loss else torch.nn.MSELoss()
        self.optimizer = optimizer if optimizer else tuned_adam
        self.gradient_clip = gradient_clip


class TunedMaximumAPosteriori(updaters.actors.MaximumAPosterioriPolicyOptimization):
    """MPO actor updater with tuned default optimizers.

    Accepts the same constructor arguments as the parent updater. When no
    optimizer factories are given, the actor uses Adam(lr=3e-4) and the dual
    variables use Adam(lr=1e-2), both with weight_decay=1e-5 and
    capturable=True. The parent constructor is not called — every attribute
    is assigned here.
    """

    def __init__(
        self,
        num_samples=20,
        epsilon=1e-1,
        epsilon_penalty=1e-3,
        epsilon_mean=1e-3,
        epsilon_std=1e-6,
        initial_log_temperature=1.0,
        initial_log_alpha_mean=1.0,
        initial_log_alpha_std=10.0,
        min_log_dual=-18.0,
        per_dim_constraining=True,
        action_penalization=True,
        actor_optimizer=None,
        dual_optimizer=None,
        gradient_clip=0,
    ):
        def adam_factory(learning_rate):
            # Returns a callable that builds Adam over the given parameters.
            def build(params):
                return torch.optim.Adam(
                    params, lr=learning_rate, weight_decay=1e-5,
                    capturable=True)
            return build

        # Sampling and constraint settings.
        self.num_samples = num_samples
        self.epsilon = epsilon
        self.epsilon_mean = epsilon_mean
        self.epsilon_std = epsilon_std
        self.epsilon_penalty = epsilon_penalty
        self.per_dim_constraining = per_dim_constraining
        self.action_penalization = action_penalization

        # Initial values and lower bound of the log dual variables.
        self.initial_log_temperature = initial_log_temperature
        self.initial_log_alpha_mean = initial_log_alpha_mean
        self.initial_log_alpha_std = initial_log_alpha_std
        self.min_log_dual = torch.as_tensor(min_log_dual, dtype=torch.float32)

        # Optimizer factories and gradient clipping.
        self.actor_optimizer = (
            actor_optimizer if actor_optimizer else adam_factory(3e-4))
        self.dual_optimizer = (
            dual_optimizer if dual_optimizer else adam_factory(1e-2))
        self.gradient_clip = gradient_clip


class WiseCriticMPO(agents.MPO):
    """MPO agent for a "wise critic" setup with reduced actor observations.

    Observations given to `step` and `test_step` are passed through
    `reduce_actor_observations` before the parent implementation sees them.
    The default updaters reduce replayed observations for the actor and give
    the critic the unreduced ones.
    """

    def __init__(
        self, model=None, replay=None, actor_updater=None, critic_updater=None
    ):
        """Store the model, replay buffer and updaters, building defaults.

        Args:
            model: model to train; defaults to ``retnorm_mpo(hidden_size=256)``.
            replay: replay buffer; defaults to
                ``replays.Buffer(return_steps=5)``.
            actor_updater: defaults to
                ``WiseCriticMaximumAPosterioriPolicyOptimization()``.
            critic_updater: defaults to ``WiseCriticExpectedSARSA()``.
        """
        # Honour an explicitly supplied model, like the other arguments
        # (it used to be silently replaced by the default).
        self.model = (
            model if model is not None else retnorm_mpo(hidden_size=256))
        self.replay = replay or replays.Buffer(return_steps=5)
        self.actor_updater = actor_updater or \
            WiseCriticMaximumAPosterioriPolicyOptimization()
        self.critic_updater = critic_updater or WiseCriticExpectedSARSA()

    def step(self, observations, steps):
        """Delegate to the parent `step` with reduced observations."""
        # NOTE(review): if the parent `step` keeps the observations it
        # receives for replay storage, replayed observations will already be
        # reduced, while the updaters reduce them again for the actor and feed
        # them to the critic as-is — confirm this is intended.
        observations = reduce_actor_observations(observations)
        return super().step(observations, steps)

    def test_step(self, observations, steps):
        """Delegate to the parent `test_step` with reduced observations."""
        observations = reduce_actor_observations(observations)
        return super().test_step(observations, steps)


class WiseCriticMaximumAPosterioriPolicyOptimization(updaters.MaximumAPosterioriPolicyOptimization):
    """MPO actor updater where the actors and the critic see different inputs.

    The online and target actors are queried with
    `reduce_actor_observations(observations)`. The target critic that scores
    the sampled actions receives the unreduced `observations`.
    """

    def __call__(self, observations):
        """Run one combined policy and dual-variable update.

        Args:
            observations: observation batch in the form given to the critic;
                it is reduced here with `reduce_actor_observations` before
                being fed to the actors.

        Returns:
            dict of detached losses (policy_mean_loss, policy_std_loss,
            kl_mean_loss, kl_std_loss, alpha_mean_loss, alpha_std_loss,
            temperature_loss) and detached dual variables (temperature,
            alpha_mean, alpha_std, plus penalty_temperature when action
            penalization is enabled).
        """
        actor_observations = reduce_actor_observations(observations)
        def parametric_kl_and_dual_losses(kl, alpha, epsilon):
            # Average the KL over dim 0. The policy's KL penalty treats alpha
            # as a constant; alpha's dual loss treats the KL as a constant.
            kl_mean = kl.mean(dim=0)
            kl_loss = (alpha.detach() * kl_mean).sum()
            alpha_loss = (alpha * (epsilon - kl_mean.detach())).sum()
            return kl_loss, alpha_loss

        def weights_and_temperature_loss(q_values, epsilon, temperature):
            # Non-parametric E-step weights: softmax over dim 0 (the sampled
            # action axis for `values`, viewed as (num_samples, -1) below) of
            # the tempered Q-values, treated as constants.
            tempered_q_values = q_values.detach() / temperature
            weights = torch.nn.functional.softmax(tempered_q_values, dim=0)
            weights = weights.detach()

            # Temperature loss (dual of the E-step).
            # Subtracting log(num_actions) turns the log-sum-exp over the
            # sample axis into a log-mean-exp.
            q_log_sum_exp = torch.logsumexp(tempered_q_values, dim=0)
            num_actions = torch.as_tensor(
                q_values.shape[0], dtype=torch.float32)
            log_num_actions = torch.log(num_actions)
            loss = epsilon + (q_log_sum_exp).mean() - log_num_actions
            loss = temperature * loss

            return weights, loss

        # Use independent normals to satisfy KL constraints per-dimension.
        # Builds Normal(mean of distribution_1, stddev of distribution_2);
        # distribution_2 defaults to distribution_1.
        # NOTE(review): torch documents `reinterpreted_batch_ndims` as the
        # number of batch dims to turn into event dims; -1 is unusual —
        # confirm it behaves as intended (most uses below go via .base_dist).
        def independent_normals(distribution_1, distribution_2=None):
            distribution_2 = distribution_2 or distribution_1
            return torch.distributions.independent.Independent(
                torch.distributions.normal.Normal(
                    distribution_1.mean, distribution_2.stddev), -1)

        with torch.no_grad():
            # Clamp the log dual variables from below at min_log_dual,
            # in place and outside autograd.
            self.log_temperature.data.copy_(
                torch.maximum(self.min_log_dual, self.log_temperature))
            self.log_alpha_mean.data.copy_(
                torch.maximum(self.min_log_dual, self.log_alpha_mean))
            self.log_alpha_std.data.copy_(
                torch.maximum(self.min_log_dual, self.log_alpha_std))
            if self.action_penalization:
                self.log_penalty_temperature.data.copy_(torch.maximum(
                    self.min_log_dual, self.log_penalty_temperature))

            # Draw num_samples actions from the target policy, which sees the
            # reduced observations.
            target_distributions = self.model.target_actor(actor_observations)
            actions = target_distributions.sample((self.num_samples,))

            # Score every sampled action with the target critic on the
            # unreduced observations, then view as (num_samples, -1).
            tiled_observations = updaters.tile(observations, self.num_samples)
            flat_observations = updaters.merge_first_two_dims(
                tiled_observations)
            flat_actions = updaters.merge_first_two_dims(actions)
            values = self.model.target_critic(flat_observations, flat_actions)
            values = values.view(self.num_samples, -1)

            assert isinstance(
                target_distributions, torch.distributions.normal.Normal)
            target_distributions = independent_normals(target_distributions)

        self.actor_optimizer.zero_grad()
        self.dual_optimizer.zero_grad()

        # Online policy, also evaluated on the reduced observations.
        distributions = self.model.actor(actor_observations)
        distributions = independent_normals(distributions)

        # Softplus maps the log dual variables to positive values;
        # FLOAT_EPSILON keeps them away from zero.
        temperature = torch.nn.functional.softplus(
            self.log_temperature) + FLOAT_EPSILON
        alpha_mean = torch.nn.functional.softplus(
            self.log_alpha_mean) + FLOAT_EPSILON
        alpha_std = torch.nn.functional.softplus(
            self.log_alpha_std) + FLOAT_EPSILON
        weights, temperature_loss = weights_and_temperature_loss(
            values, self.epsilon, temperature)

        # Action penalization: the cost is the negative Euclidean norm of how
        # far each sampled action lies outside [-1, 1] (zero inside the box).
        if self.action_penalization:
            penalty_temperature = torch.nn.functional.softplus(
                self.log_penalty_temperature) + FLOAT_EPSILON
            diff_bounds = actions - torch.clamp(actions, -1, 1)
            action_bound_costs = -torch.norm(diff_bounds, dim=-1)
            penalty_weights, penalty_temperature_loss = \
                weights_and_temperature_loss(
                    action_bound_costs,
                    self.epsilon_penalty, penalty_temperature)
            # Both weight sets are softmaxes over dim 0, so the combined
            # weights sum to 2 along that axis.
            weights += penalty_weights
            temperature_loss += penalty_temperature_loss

        # Decompose the policy into fixed-mean and fixed-std distributions.
        fixed_std_distribution = independent_normals(
            distributions.base_dist, target_distributions.base_dist)
        fixed_mean_distribution = independent_normals(
            target_distributions.base_dist, distributions.base_dist)

        # Compute the decomposed policy losses.
        # Weighted log-likelihood of the sampled actions, summed over action
        # dims and the sample axis, then averaged over the remaining axis.
        policy_mean_losses = (fixed_std_distribution.base_dist.log_prob(
            actions).sum(dim=-1) * weights).sum(dim=0)
        policy_mean_loss = -(policy_mean_losses).mean()
        policy_std_losses = (fixed_mean_distribution.base_dist.log_prob(
            actions).sum(dim=-1) * weights).sum(dim=0)
        policy_std_loss = -policy_std_losses.mean()

        # Compute the decomposed KL between the target and online policies.
        # Per-dim constraining keeps one KL per action dimension (base_dist);
        # otherwise the Independent wrappers are compared directly.
        if self.per_dim_constraining:
            kl_mean = torch.distributions.kl.kl_divergence(
                target_distributions.base_dist,
                fixed_std_distribution.base_dist)
            kl_std = torch.distributions.kl.kl_divergence(
                target_distributions.base_dist,
                fixed_mean_distribution.base_dist)
        else:
            kl_mean = torch.distributions.kl.kl_divergence(
                target_distributions, fixed_std_distribution)
            kl_std = torch.distributions.kl.kl_divergence(
                target_distributions, fixed_mean_distribution)

        # Compute the alpha-weighted KL-penalty and dual losses.
        kl_mean_loss, alpha_mean_loss = parametric_kl_and_dual_losses(
            kl_mean, alpha_mean, self.epsilon_mean)
        kl_std_loss, alpha_std_loss = parametric_kl_and_dual_losses(
            kl_std, alpha_std, self.epsilon_std)

        # Combine losses.
        policy_loss = policy_mean_loss + policy_std_loss
        kl_loss = kl_mean_loss + kl_std_loss
        dual_loss = alpha_mean_loss + alpha_std_loss + temperature_loss
        loss = policy_loss + kl_loss + dual_loss

        # One backward pass feeds both the actor and the dual optimizers.
        loss.backward()
        if self.gradient_clip > 0:
            torch.nn.utils.clip_grad_norm_(
                self.actor_variables, self.gradient_clip)
            torch.nn.utils.clip_grad_norm_(
                self.dual_variables, self.gradient_clip)
        self.actor_optimizer.step()
        self.dual_optimizer.step()

        dual_variables = dict(
            temperature=temperature.detach(), alpha_mean=alpha_mean.detach(),
            alpha_std=alpha_std.detach())
        if self.action_penalization:
            dual_variables['penalty_temperature'] = \
                penalty_temperature.detach()

        return dict(
            policy_mean_loss=policy_mean_loss.detach(),
            policy_std_loss=policy_std_loss.detach(),
            kl_mean_loss=kl_mean_loss.detach(),
            kl_std_loss=kl_std_loss.detach(),
            alpha_mean_loss=alpha_mean_loss.detach(),
            alpha_std_loss=alpha_std_loss.detach(),
            temperature_loss=temperature_loss.detach(),
            **dual_variables)


class WiseCriticExpectedSARSA(updaters.ExpectedSARSA):
    """Expected SARSA critic updater for the "wise critic" setup.

    The target actor samples next actions from the reduced next observations
    (`reduce_actor_observations`). The target and online critics are evaluated
    on the unreduced observations.
    """

    def __call__(
            self, observations, actions, next_observations, rewards, discounts
    ):
        """Take one critic step towards expected-SARSA targets.

        Targets are ``rewards + discounts * mean(Q_target(s', a'))``, where the
        mean runs over ``num_samples`` actions drawn from the target policy.

        Returns:
            dict with the detached ``loss`` and the detached critic output
            ``q``.
        """
        reduced_next = reduce_actor_observations(next_observations)
        count = self.num_samples

        with torch.no_grad():
            # Sample next actions from the target policy's reduced view.
            next_policy = self.model.target_actor(reduced_next)
            sampled_actions = updaters.merge_first_two_dims(
                next_policy.rsample((count,)))
            # Pair each sampled action with the full next observation.
            repeated_next = updaters.merge_first_two_dims(
                updaters.tile(next_observations, count))
            next_q = self.model.target_critic(repeated_next, sampled_actions)
            expected_next_q = next_q.view(count, -1).mean(dim=0)
            targets = rewards + discounts * expected_next_q

        self.optimizer.zero_grad()
        predicted_q = self.model.critic(observations, actions)
        critic_loss = self.loss(targets, predicted_q)
        critic_loss.backward()

        if self.gradient_clip > 0:
            torch.nn.utils.clip_grad_norm_(self.variables, self.gradient_clip)
        self.optimizer.step()

        return dict(loss=critic_loss.detach(), q=predicted_q.detach())
