"""Defining the PPO loss for actor critic type models."""
import abc
import math
import typing
from typing import Dict, Union
from typing import Optional

import torch
import torch.nn.functional as F

from onpolicy_sync.losses import A2C, PPO
from onpolicy_sync.losses.abstract_loss import AbstractActorCriticLoss
from rl_base.common import ActorCriticOutput
from rl_base.distributions import CategoricalDistr
from utils.experiment_utils import Builder


class AlphaScheduler(abc.ABC):
    """Interface for objects that supply a value of `alpha` for a given
    training step."""

    def next(self, step_count: int, *args, **kwargs):
        """Return the `alpha` value to use at training step `step_count`.

        Subclasses must override this method; the base implementation
        always raises `NotImplementedError`.
        """
        raise NotImplementedError()


class AdvisorImitationStage(AbstractActorCriticLoss):
    """Implementation of the Advisor loss' stage 1 when main and auxiliary
    actors are equally weighted.

    The loss is the sum of two masked cross entropy terms: one between the
    expert and the main policy (`actor_critic_output.distributions`) and one
    between the expert and the auxiliary policy
    (`actor_critic_output.extras["auxiliary_distributions"]`).
    """

    def loss(  # type: ignore
        self,
        step_count: int,
        batch: Dict[str, Union[torch.Tensor, Dict[str, torch.Tensor]]],
        actor_critic_output: ActorCriticOutput[CategoricalDistr],
        *args,
        **kwargs
    ):
        """Compute the equally weighted imitation loss of both actors.

        # Parameters

        step_count : Unused by this loss.
        batch : Must contain an `"observations"` dict holding either
            * `"expert_action"`: a tensor whose last dimension has size 2,
              index 0 being the expert action and index 1 a mask used to
              weight the corresponding step (presumably 1 when an expert
              action is available and 0 otherwise), or
            * `"expert_policy"`: a tensor whose last dimension holds the
              expert's action probabilities followed by one mask entry.
        actor_critic_output : Output of the model; the auxiliary policy is
            read from `extras["auxiliary_distributions"]`.

        # Returns

        `(0, {})` if the mask sums to zero, otherwise a tuple
        `(total_loss, info)` where `info` maps `"main_expert_ce_loss"`,
        `"aux_expert_ce_loss"` and `"total_loss"` to python floats.

        # Raises

        NotImplementedError : If neither the `expert_action` nor the
            `expert_policy` observation is present.
        """
        observations = typing.cast(Dict[str, torch.Tensor], batch["observations"])

        if "expert_action" in observations:
            expert_actions_and_mask = observations["expert_action"]
            assert expert_actions_and_mask.shape[-1] == 2
            leading_shape = expert_actions_and_mask.shape[:-1]
            flat_actions_and_mask = expert_actions_and_mask.view(-1, 2)

            expert_actions = flat_actions_and_mask[:, 0].view(*leading_shape, 1)
            expert_actions_masks = (
                flat_actions_and_mask[:, 1].float().view(*leading_shape, 1)
            )

            expert_successes = expert_actions_masks.sum()
            if expert_successes.item() == 0:
                # No step has a usable expert action: nothing to imitate.
                return 0, {}

            main_expert_log_probs = actor_critic_output.distributions.log_probs(
                typing.cast(torch.LongTensor, expert_actions)
            )
            aux_expert_log_probs = actor_critic_output.extras[
                "auxiliary_distributions"
            ].log_probs(typing.cast(torch.LongTensor, expert_actions))
        elif "expert_policy" in observations:
            # Mirrors the `expert_policy` handling of `AdvisorWeightedStage`:
            # the negative cross entropy is the expert-probability-weighted
            # sum of the actor's log probabilities.
            expert_policy_and_mask = observations["expert_policy"]
            expert_policies = expert_policy_and_mask[..., :-1]
            expert_actions_masks = expert_policy_and_mask[..., -1:]

            expert_successes = expert_actions_masks.sum()
            if expert_successes.item() == 0:
                return 0, {}

            main_expert_log_probs = (
                actor_critic_output.distributions.log_probs_tensor * expert_policies
            ).sum(-1, keepdim=True)
            aux_expert_log_probs = (
                actor_critic_output.extras["auxiliary_distributions"].log_probs_tensor
                * expert_policies
            ).sum(-1, keepdim=True)
        else:
            raise NotImplementedError(
                "Imitation loss requires either `expert_action` or `expert_policy`"
                " sensor to be active."
            )

        # `expert_successes` is > 0 here, the clamp only guards the division.
        normalizer = torch.clamp(expert_successes, min=1)
        main_expert_ce_loss = (
            -(expert_actions_masks * main_expert_log_probs).sum() / normalizer
        )
        aux_expert_ce_loss = (
            -(expert_actions_masks * aux_expert_log_probs).sum() / normalizer
        )
        total_loss = main_expert_ce_loss + aux_expert_ce_loss

        # weights := torch.ones_like(main_expert_log_probs)
        # Hence weights for ppo := torch.zeros_like(main_expert_log_probs)

        return (
            total_loss,
            {
                "main_expert_ce_loss": main_expert_ce_loss.item(),
                "aux_expert_ce_loss": aux_expert_ce_loss.item(),
                "total_loss": total_loss.item(),
            },
        )


class AdvisorWeightedStage(AbstractActorCriticLoss):
    """Implementation of the Advisor loss' second stage (simplest variant).

    # Attributes

    rl_loss: The RL loss to use, should be a loss object of type `PPO` or `A2C`
        (or a `Builder` that when called returns such a loss object). If `None`,
        only the imitation terms are computed.
    alpha : Exponent to use when reweighting the expert cross entropy loss.
        Larger alpha means an (exponentially) smaller weight assigned to the cross entropy
        loss. E.g. if the weight with alpha=1 is 0.6 then with alpha=2 it is 0.6^2=0.36.
    bound : Threshold on the auxiliary policy's agreement with the expert. Whenever the
        auxiliary policy's negative cross entropy with the expert is below `log(bound)`,
        the weight given to the main actor's cross entropy loss at that step is set to 0.
        Must be strictly positive as its logarithm is taken.
    alpha_scheduler : An object of type `AlphaScheduler` which is queried before computing
        the loss in order to get a new value for `alpha`.
    smooth_expert_weight_decay : If not None, will redistribute (smooth) the weight assigned to the cross
        entropy loss at a particular step over the following `smooth_expert_steps` steps. Values
        of `smooth_expert_weight_decay` near 1 will increase how evenly weight is assigned
        to future steps. Values near 0 will decrease how evenly this weight is distributed
        with larger weight being given steps less far into the `future`.
        Here `smooth_expert_steps` is automatically defined from `smooth_expert_weight_decay` as detailed below.
    smooth_expert_steps : The number of "future" steps over which to distribute the current steps weight.
        This value is computed as `math.ceil(-math.log(1 + ((1 - r) / r) / 0.05) / math.log(r)) - 1` where
        `r=smooth_expert_weight_decay`. This ensures that the weight is always distributed over at least
        one additional step and that it is never distributed more than 20 steps into the future.
        It is 0 when `smooth_expert_weight_decay` is None.
    """

    def __init__(
        self,
        rl_loss: Optional[Union[Union[PPO, A2C], Builder[Union[PPO, A2C]]]],
        fixed_alpha: Optional[float] = 1,
        fixed_bound: Optional[float] = 0.1,
        alpha_scheduler: Optional[AlphaScheduler] = None,
        smooth_expert_weight_decay: Optional[float] = None,
        *args,
        **kwargs
    ):
        """Initializer.

        See the class documentation for parameter definitions not included below.

        fixed_alpha: This fixed value of `alpha` to use. This value is *IGNORED* if
            alpha_scheduler is not None.
        fixed_bound: This fixed value of the `bound` to use.
        """
        # Extra positional/keyword arguments are accepted by the signature
        # but are not supported.
        assert len(kwargs) == len(args) == 0

        super().__init__(*args, **kwargs)
        self.rl_loss: Optional[Union[PPO, A2C]]
        if isinstance(rl_loss, Builder):
            self.rl_loss = rl_loss()
        else:
            self.rl_loss = rl_loss

        self.alpha = fixed_alpha
        self.bound = fixed_bound
        self.alpha_scheduler = alpha_scheduler
        self.smooth_expert_weight_decay = smooth_expert_weight_decay
        assert smooth_expert_weight_decay is None or (
            0 < smooth_expert_weight_decay < 1
        ), "`smooth_expert_weight_decay` must be between 0 and 1."

        # Always define the attribute so that it can be inspected even when
        # smoothing is disabled.
        self.smooth_expert_steps: int = 0
        if smooth_expert_weight_decay is not None:
            r = smooth_expert_weight_decay

            self.smooth_expert_steps = (
                math.ceil(-math.log(1 + ((1 - r) / r) / 0.05) / math.log(r)) - 1
            )

    def _smooth_expert_weights(
        self,
        use_expert_weights: torch.Tensor,
        batch: Dict[str, Union[torch.Tensor, Dict[str, torch.Tensor]]],
    ) -> torch.Tensor:
        """Redistribute each step's expert weight over the following steps.

        A weight p assigned to a step at time t is redistributed to steps
        t, t+1, ..., t + self.smooth_expert_steps with geometrically
        decaying coefficients. This redistribution of weight p is not
        allowed to pass from one episode to another and so `batch["masks"]`
        is applied both to the redistributed weights and to the normalizing
        constants. As the result is a (masked) weighted average of the input
        weights, it stays in [0, 1] whenever the input weights do.

        `use_expert_weights` is modified in place and returned with its
        original shape.
        """
        # The number of rollouts is read from the second dimension of the
        # recurrent hidden states, which must have exactly three dimensions.
        _, nrollouts, _ = typing.cast(
            torch.Tensor, batch["recurrent_hidden_states"]
        ).shape

        start_shape = use_expert_weights.shape
        use_expert_weights = use_expert_weights.view(-1, nrollouts)

        # Zero padding at the start of the first dimension lets every shifted
        # slice below have the same number of rows as the unpadded tensors.
        # `F.pad` copies, hence the in place updates below do not alter
        # `padded_weights`.
        padded_weights = F.pad(use_expert_weights, [0, 0, self.smooth_expert_steps, 0])
        masks = typing.cast(torch.Tensor, batch["masks"]).view(-1, nrollouts)
        padded_masks = F.pad(masks, [0, 0, self.smooth_expert_steps, 0])
        divisors = torch.ones_like(masks)  # Keep track of normalizing constants
        for i in range(1, self.smooth_expert_steps + 1):
            decay = self.smooth_expert_weight_decay ** i
            shifted = slice(self.smooth_expert_steps - i, -i)

            # `masks` accumulates the product of the masks of the current step
            # and of the `i` preceding ones: it is 0 as soon as an episode
            # boundary was encountered.
            masks = masks * padded_masks[shifted, :]

            # The weight of the step `i` steps earlier only contributes if no
            # episode boundary lies in between. Leaving the contribution
            # unmasked (while masking the divisor) would leak weight across
            # episodes and could produce weights larger than 1, i.e. negative
            # weights for the RL loss.
            use_expert_weights += decay * masks * padded_weights[shifted, :]
            divisors += masks * decay
        use_expert_weights /= divisors
        return use_expert_weights.view(*start_shape)

    def loss(  # type: ignore
        self,
        step_count: int,
        batch: Dict[str, Union[torch.Tensor, Dict[str, torch.Tensor]]],
        actor_critic_output: ActorCriticOutput[CategoricalDistr],
        *args,
        **kwargs
    ):
        """Compute the weighted combination of imitation and RL losses.

        # Parameters

        step_count : Current training step, passed to the alpha scheduler (if
            any) and to the RL loss.
        batch : Must contain an `"observations"` dict with an `"expert_action"`
            or `"expert_policy"` entry. When smoothing is enabled, `"masks"` and
            `"recurrent_hidden_states"` are also read.
        actor_critic_output : Output of the model; the auxiliary policy is
            read from `extras["auxiliary_distributions"]`.

        # Returns

        `(0, {})` if the expert mask sums to zero, otherwise a tuple
        `(total_loss, info)` where `info` maps loss names to python floats.

        # Raises

        NotImplementedError : If neither the `expert_action` nor the
            `expert_policy` observation is present.
        """
        if self.alpha_scheduler is not None:
            self.alpha = self.alpha_scheduler.next(step_count=step_count)

        # Imitation calculation
        observations = typing.cast(Dict[str, torch.Tensor], batch["observations"])
        if "expert_action" in observations:
            expert_actions_and_mask = observations["expert_action"]
            assert expert_actions_and_mask.shape[-1] == 2
            expert_actions_and_mask_reshaped = expert_actions_and_mask.view(-1, 2)

            expert_actions = expert_actions_and_mask_reshaped[:, 0].view(
                *expert_actions_and_mask.shape[:-1], 1
            )
            expert_actions_masks = (
                expert_actions_and_mask_reshaped[:, 1]
                .float()
                .view(*expert_actions_and_mask.shape[:-1], 1)
            )

            expert_successes = expert_actions_masks.sum()
            if expert_successes.item() == 0:
                return 0, {}

            main_expert_neg_cross_entropy = actor_critic_output.distributions.log_probs(
                typing.cast(torch.LongTensor, expert_actions)
            )
            aux_expert_neg_cross_entropy = actor_critic_output.extras[
                "auxiliary_distributions"
            ].log_probs(typing.cast(torch.LongTensor, expert_actions))

            aux_expert_ce_loss = -(
                expert_actions_masks * aux_expert_neg_cross_entropy
            ).sum() / torch.clamp(expert_successes, min=1)
        elif "expert_policy" in observations:
            # The last dimension holds the expert's action probabilities
            # followed by a single mask entry. Indexing with `...` (rather
            # than `:`) always targets that last dimension, whatever the
            # number of leading dimensions.
            expert_policy_and_mask = observations["expert_policy"]
            expert_policies = expert_policy_and_mask[..., :-1]
            expert_actions_masks = expert_policy_and_mask[..., -1:]

            expert_successes = expert_actions_masks.sum()
            if expert_successes.item() == 0:
                return 0, {}

            main_expert_log_probs_tensor = (
                actor_critic_output.distributions.log_probs_tensor
            )
            main_expert_neg_cross_entropy = (
                (main_expert_log_probs_tensor * expert_policies).sum(-1).unsqueeze(-1)
            )

            aux_expert_log_probs_tensor = actor_critic_output.extras[
                "auxiliary_distributions"
            ].log_probs_tensor

            aux_expert_neg_cross_entropy = (
                (aux_expert_log_probs_tensor * expert_policies).sum(-1).unsqueeze(-1)
            )
            aux_expert_ce_loss = (
                -aux_expert_neg_cross_entropy * expert_actions_masks
            ).sum() / expert_successes
        else:
            raise NotImplementedError(
                "Imitation loss requires either `expert_action` or `expert_policy`"
                " sensor to be active."
            )

        # Steps at which the auxiliary policy's negative cross entropy with
        # the expert falls below `log(bound)` get an expert weight of 0.
        top_bound = math.log(self.bound)

        # Detached: the weights are treated as constants by the optimizer.
        use_expert_weights = (
            torch.exp(self.alpha * aux_expert_neg_cross_entropy)
            * expert_actions_masks
            * (aux_expert_neg_cross_entropy >= top_bound).float()
        ).detach()

        if self.smooth_expert_weight_decay:
            use_expert_weights = self._smooth_expert_weights(
                use_expert_weights=use_expert_weights, batch=batch
            )

        # noinspection PyTypeChecker
        use_rl_weights = 1 - use_expert_weights

        weighted_main_expert_ce_loss = -(
            use_expert_weights * main_expert_neg_cross_entropy
        ).mean()

        total_loss = aux_expert_ce_loss + weighted_main_expert_ce_loss
        output_dict = {
            "aux_expert_ce_loss": aux_expert_ce_loss.item(),
            "weighted_main_expert_ce_loss": weighted_main_expert_ce_loss.item(),
            "non_zero_weight": (use_expert_weights > 0).float().mean().item(),
            "weight": use_expert_weights.mean().item(),
        }

        # RL Loss Computation
        if self.rl_loss is not None:
            rl_losses = self.rl_loss.loss_per_step(
                step_count=step_count,
                batch=batch,
                actor_critic_output=actor_critic_output,
            )

            # Every entry of `rl_losses` is a `(per step loss, weight)` pair.
            action_loss, rl_action_loss_weight = rl_losses["action"]
            assert rl_action_loss_weight is None
            entropy_loss, rl_entropy_loss_weight = rl_losses["entropy"]

            def reweight(loss, w):
                return loss if w is None else loss * w

            # The action and entropy terms are scaled, per step, by the
            # weight not given to the expert; the value term is not.
            weighted_action_loss = (
                use_rl_weights * (reweight(action_loss, rl_action_loss_weight))
            ).mean()

            weighted_entropy_loss = (
                use_rl_weights * reweight(entropy_loss, rl_entropy_loss_weight)
            ).mean()

            value_loss = rl_losses["value"][0].mean()
            total_loss += (
                (value_loss * rl_losses["value"][1])
                + weighted_action_loss
                + weighted_entropy_loss
            )
            output_dict.update(
                {
                    "value_loss": value_loss.item(),
                    "weighted_action_loss": weighted_action_loss.item(),
                    "entropy_loss": entropy_loss.mean().item(),
                }
            )

        output_dict["total_loss"] = total_loss.item()

        return (total_loss, output_dict)


class LinearAlphaScheduler(AlphaScheduler):
    """Scheduler linearly interpolating `alpha` from `start` to `end` over
    the first `total_steps` training steps; `end` is returned afterwards."""

    def __init__(self, start: float, end: float, total_steps: int):
        """Initializer.

        start : Value of `alpha` at step 0.
        end : Value of `alpha` from step `total_steps` onwards.
        total_steps : Number of steps over which to interpolate. A
            non-positive value means the schedule is already finished.
        """
        self.start = start
        self.end = end
        self.total_steps = total_steps

    def next(self, step_count: int, *args, **kwargs):
        """Return the interpolated `alpha` for training step `step_count`."""
        if self.total_steps <= 0:
            # Degenerate schedule: avoid a division by zero (or a negative
            # interpolation factor) and go straight to the final value.
            p = 1
        else:
            # Fraction of the schedule completed, capped at 1.
            p = min(step_count / self.total_steps, 1)
        return self.start * (1.0 - p) + self.end * p
