"""Custom approximators for SB3 policies utilizing polynomial bases.  
"""
import math
from collections import deque
from typing import Callable, Optional, Tuple

import torch as th
from torch.distributions import Normal
from torchrl.modules.utils.mappings import expln

from polyagents import multivariate_polynomial_basis as multivarpoly


__author__ = "anonymizedforblindreview"
__version__ = "0.1"
__email__ = "anonymizedforblindreview"


# Flat initialization offset of the log std (sigma) approximator. exp(-0.1) is roughly 0.905,
# and because the value is negative, expln starts out in its exponential branch (exp(x) for x < 0).
INITIAL_LOG_SIGMA: float = -0.1


class PolynomialPPOApproximator:
    """
    Custom approximator for PPO policy and value function.

    The policy mean, the value function and the policy log standard deviation are each
    represented by a separate multivariate polynomial (``MultiVarPoly``), giving a
    state-dependent Gaussian policy.

    :param feature_dim: Observation space dimension.
    :param degree: Polynomial (max-)degree of the policy and value approximators. If ``coeffs``
        are passed, the degree actually used is read back from the policy approximator.
    :param basis: Polynomial basis: ``'chebyshev'``, ``'bernstein'`` or ``'power'``.
    :param initialization: Initialization of polynomial coefficients. See class MultiVarPoly for details.
    :param coeffs: Optional sequence ``(policy_coeffs, value_coeffs, sigma_coeffs)`` used to
        initialize the three approximators.
    :param use_expln: Only used when ``use_fixed_std_schedule=False``. Map the predicted log std
        to a std with ``expln`` (linear when x >= 0, exponential when x < 0) instead of ``exp``.
    :param log_std_batch_size: Length of the queue of recent log std predictions kept for
        logging when ``use_fixed_std_schedule=False``.
    :param update_log_std_in_parent_fn: Optional callback, called with the stacked queue of
        recent log std predictions after every learned-std evaluation.
    :param use_fixed_std_schedule: If True, the std is not learned but taken from ``std_schedule_fn``.
    :param std_schedule_fn: Function returning the std when ``use_fixed_std_schedule=True``.
    :raises ValueError: If ``coeffs`` cannot be unpacked into three entries.
    """

    def __init__(
        self,
        feature_dim: int,
        degree: int = 3,
        basis: str = 'chebyshev',
        initialization: str = 'random',
        coeffs: Optional[list] = None,
        use_expln: bool = True,
        log_std_batch_size: int = 64,
        update_log_std_in_parent_fn: Optional[Callable[..., None]] = None,
        use_fixed_std_schedule: bool = False,
        std_schedule_fn: Optional[Callable[..., float]] = None,
    ):
        self.degree = degree
        self.basis = basis
        self.use_expln = use_expln
        self.training_mode = False
        self.use_fixed_std_schedule = use_fixed_std_schedule

        policy_coeffs = None
        value_coeffs = None
        sigma_coeffs = None

        if coeffs is not None:
            try:
                policy_coeffs, value_coeffs, sigma_coeffs = coeffs[0], coeffs[1], coeffs[2]
            except (IndexError, KeyError, TypeError) as err:
                raise ValueError('Error parsing coefficients') from err

        self.policy_approximator = multivarpoly.MultiVarPoly(dim=feature_dim, degree=degree, basis=basis, initialization=initialization, coeffs=policy_coeffs)
        self.value_approximator = multivarpoly.MultiVarPoly(dim=feature_dim, degree=degree, basis=basis, initialization=initialization, coeffs=value_coeffs)
        # Use a fixed, smaller degree for the log std to improve convergence behaviour.
        self.sigma_approximator = multivarpoly.MultiVarPoly(dim=feature_dim, degree=2, basis=basis, initialization='flat', flat_init_offset=INITIAL_LOG_SIGMA, coeffs=sigma_coeffs)

        # Adapt degree in case coefficients were passed.
        self.degree = self.policy_approximator.degree

        # Recent log std predictions; the PPO class uses their mean for logging "train/std".
        self.log_std_list = deque(maxlen=log_std_batch_size)

        # Callback that pushes the recent log stds to the owning class.
        self.update_log_std_in_parent_fn = update_log_std_in_parent_fn

        # Fixed std schedule (only used when use_fixed_std_schedule=True).
        self.std_schedule_fn = std_schedule_fn

    def set_training_mode(self, mode):
        """Set approximator training mode, i.e. activate or deactivate computation graph."""
        self.training_mode = mode
        self.policy_approximator.set_training_mode(mode)
        self.sigma_approximator.set_training_mode(mode)
        self.value_approximator.set_training_mode(mode)

    def parameters(self):
        """
        Return the coefficients of the policy, value and sigma approximators.

        Used to initialize the optimizer and by ppo.py calling
        ``th.nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)``.
        The sigma coefficients are returned even with a fixed std schedule, in which case
        this class does not use them in any forward pass.
        """
        return [self.policy_approximator.coeffs, self.value_approximator.coeffs, self.sigma_approximator.coeffs]

    def forward(self, features: th.Tensor) -> Tuple[th.Tensor, th.Tensor]:
        """Forward one or multiple observations and get corresponding action(s) and value(s)."""
        return self.forward_actor(features), self.forward_critic(features)

    def forward_actor(self, features: th.Tensor) -> th.Tensor:
        """Forward one or multiple observations and get corresponding action(s)."""
        return self.policy_approximator.forward(features)

    def forward_critic(self, features: th.Tensor) -> th.Tensor:
        """Forward one or multiple observations and get corresponding value(s)."""
        return self.value_approximator.forward(features)

    def sample_actions(self, observations):
        """
        Sample one action per observation from the state-dependent Gaussian policy.

        Samples are drawn with ``rsample()`` (reparameterization trick) so gradients can pass
        through them.

        :param observations: A single 1-D observation, or a batch whose first dimension
            indexes the observations.
        :return: The sample for a single observation, or the per-observation samples stacked
            along a new leading dimension.
        """
        if observations.dim() == 1:
            return self._sample_action(observations)
        return th.stack([self._sample_action(o) for o in observations])

    def _sample_action(self, observation):
        """Draw one reparameterized sample from the Gaussian policy at a single observation."""
        mu = self.policy_approximator.forward(observation)
        sigma = self.evaluate_std_at(observation)
        return Normal(loc=mu, scale=sigma).rsample()

    def sum_independent_dims(self, tensor: th.Tensor) -> th.Tensor:
        """
        Code Taken from stable baselines3 common.distributions.Distribution:
        Continuous actions are usually considered to be independent,
        so we can sum components of the ``log_prob`` or the entropy.

        :param tensor: shape: (n_batch, n_actions) or (n_batch,)
        :return: shape: (n_batch,) for (n_batch, n_actions) input, scalar for (n_batch,) input
        """
        if len(tensor.shape) > 1:
            return tensor.sum(dim=1)
        return tensor.sum()

    def get_log_probs(self, obs):
        """
        Return the policy log standard deviation(s) at the given observation(s).

        Despite its name, this returns log *std* values (see ``evaluate_log_std_at``),
        not action log-probabilities.

        :param obs: A single observation (0-D or 1-D tensor) or a batch whose first
            dimension indexes the observations; each batch entry is flattened first.
        :return: The log std for a single observation, or the per-observation results
            stacked along a new leading dimension.
        """
        # Test the rank of ``obs`` itself: ``obs.flatten().dim()`` is always 1, which would
        # make the per-observation batch branch unreachable.
        if obs.dim() <= 1:
            return self.evaluate_log_std_at(obs)
        return th.stack([self.evaluate_log_std_at(o.flatten()) for o in obs])

    def get_action_dist(self, obs, evaluate=False):
        """
        Build the Gaussian action distribution for a batch of observations.

        Mean and std are evaluated separately for every observation (with a learned std this
        also records the predicted log stds for logging). The last dimension of the stacked
        means and stds is squeezed if it has size 1.

        :param obs: Iterable of observations.
        :param evaluate: Forwarded to ``evaluate_std_at`` (only used by a fixed std schedule).
        :return: A ``torch.distributions.Normal`` with one entry per observation.
        """
        mus = []
        sigmas = []

        # Keep the per-observation call order: evaluate_std_at has side effects (logging queue).
        for o in obs:
            mus.append(self.policy_approximator.forward(o))  # deterministic mu at o
            sigmas.append(self.evaluate_std_at(o, evaluate))

        return Normal(loc=th.stack(mus).squeeze(-1), scale=th.stack(sigmas).squeeze(-1))

    def get_log_probs_from_dist(self, dist, actions):
        """
        Return the log-probabilities of ``actions`` under ``dist``.

        The values are not summed over action dimensions (see ``sum_independent_dims``), because
        multi-dimensional action spaces are not supported by this class at the moment.
        """
        return dist.log_prob(actions)

    def evaluate_std_at(self, observation, evaluate=False):
        """
        Return the policy standard deviation at ``observation``.

        With a learned std (``use_fixed_std_schedule=False``) the sigma approximator predicts a
        log std. That value is recorded for logging and mapped to a positive std, either with
        ``expln`` (``use_expln=True``) or with ``exp``. ``expln`` is a smooth, continuous
        positive mapping presented in "State-Dependent Exploration for Policy Gradient Methods"
        (https://people.idsia.ch/~juergen/ecml2008rueckstiess.pdf). It behaves linearly for
        positive inputs (x >= 0 --> x + 1) and exponentially for negative inputs (x < 0 --> exp(x)).

        With a fixed schedule the observation is ignored and ``std_schedule_fn(evaluate)``
        is returned.

        :param observation: Observation at which the std is evaluated.
        :param evaluate: Only forwarded to ``std_schedule_fn`` when a fixed schedule is used.
        """
        if self.use_fixed_std_schedule:
            return self.std_schedule_fn(evaluate)

        log_std = self.sigma_approximator.forward(observation)
        # The history only feeds logging. Store a detached copy so the queue does not keep up to
        # maxlen autograd graphs alive, and th.stack below does not build a new graph node on
        # every call. The returned std is still computed from the attached log_std.
        self.log_std_list.append(log_std.detach())

        if self.update_log_std_in_parent_fn is not None:
            self.update_log_std_in_parent_fn(th.stack(list(self.log_std_list)))

        # TODO: Investigate: If no MINIMAL_SIGMA is added here, training will result in NaN coefficients soon
        if self.use_expln:
            return expln(log_std)
        return th.exp(log_std)

    def evaluate_log_std_at(self, observation):
        """
        Return the policy log std at ``observation``.

        Unlike ``evaluate_std_at``, this does not record the value in the logging queue.
        With a fixed schedule the observation is ignored and ``log(std_schedule_fn())`` is
        returned as a tensor.
        """
        if self.use_fixed_std_schedule:
            # NOTE(review): evaluate_std_at calls std_schedule_fn(evaluate) while this calls it
            # without arguments; confirm the schedule function accepts both forms.
            return th.tensor(math.log(self.std_schedule_fn()))
        return self.sigma_approximator.forward(observation)


class PolynomialARSApproximator:
    """
    Custom approximator for ARS linear policy represented as polynomial approximator.

    :param feature_dim: Observation space dimension.
    :param degree: Polynomial (max-)degree. If ``coeffs`` are passed, the degree actually used
        is read back from the resulting approximator.
    :param basis: Polynomial basis: ``'chebyshev'``, ``'bernstein'`` or ``'power'``.
    :param initialization: Initialization of polynomial coefficients. See class MultiVarPoly for details.
    :param coeffs: Optional polynomial coefficients to initialize the policy with.
    """

    def __init__(
        self,
        feature_dim: int,
        degree: int = 3,
        basis: str = 'chebyshev',
        initialization: str = 'random',
        coeffs: Optional[list] = None,
    ):
        super().__init__()

        self.policy = multivarpoly.MultiVarPoly(dim=feature_dim, degree=degree, basis=basis, initialization=initialization, coeffs=coeffs)

        # Adapt degree in case coefficients were passed.
        self.degree = self.policy.degree

        self.training_mode = False

    def set_training_mode(self, mode):
        """Set approximator training mode, i.e. activate or deactivate computation graph."""
        self.training_mode = mode
        self.policy.set_training_mode(mode)

    def forward(self, features: th.Tensor) -> th.Tensor:
        """
        Forward one or multiple observations and get corresponding action(s).

        Returns the output of the policy polynomial (a single tensor, not an action/value pair).
        """
        return self.policy.forward(features)