"""Custom policies for SB3 utilizing polynomial approximators. Adapted from SB3 custom_policy.rst.
"""
import copy
from typing import Any, Optional
import numpy as np
import torch as th
from gymnasium import spaces
from stable_baselines3.common.policies import BasePolicy
from stable_baselines3.common.type_aliases import PyTorchObs
from stable_baselines3.common.policies import BasePolicy
from stable_baselines3.common.type_aliases import PyTorchObs
from polyagents.polynomial_approximators import PolynomialARSApproximator, PolynomialPPOApproximator
from stable_baselines3.common.policies import BasePolicy
from stable_baselines3.common.type_aliases import PyTorchObs, Schedule
from stable_baselines3.common.preprocessing import get_action_dim, get_flattened_obs_dim
from stable_baselines3.common.torch_layers import (
    BaseFeaturesExtractor,
    FlattenExtractor
)

# Module metadata.
__author__ = "anonymizedforblindreview"
__version__ = "0.1"
__email__ = "anonymizedforblindreview"


# Fixed sigma schedule used by PolynomialPPOPolicy.get_std:
# STD_INIT is the std returned while all training progress is still remaining,
# STD_END is the floor the std decays towards as the remaining progress shrinks.
STD_INIT = 0.99
STD_END = 0.20


class PolynomialPPOPolicy(BasePolicy):
    """
    Adapted from stable_baselines3.common.policies
    Policy and value functions are represented as polynomial approximators.

    Policy class for PPO algorithm (has both policy and value prediction).
    gSDE is not implemented and only one-dimensional action spaces are supported.

    :param observation_space: Observation space
    :param action_space: Action space (must have action dimension 1)
    :param lr_schedule: Learning rate schedule (could be constant)
    :param use_sde: Whether to use State Dependent Exploration or not (must be False)
    :param use_expln: Only used when use_fixed_std_schedule=False. linear when x > 0, exponential when x < 0
    :param squash_output: Only meaningful together with gSDE (must be False)
    :param optimizer_class: Optimizer class used to update the approximator parameters
    :param optimizer_kwargs: Keyword arguments for the optimizer (learning rate excluded).
        When omitted and Adam is used, ``eps=1e-5`` is set.
    :param features_extractor_class: Features extractor applied to observations
    :param features_extractor_kwargs: Keyword arguments for the features extractor
    :param use_fixed_std_schedule: A fixed std schedule is used that decays from STD_INIT
        towards STD_END as training progresses (see ``get_std``)
    :param degree: Polynomial degree of policy and critic approximators
    :param basis: Polynomial basis to use for all approximators
    :param initialization: Coefficient initialization method to use for policy and critic approximators
    :param coeffs: Polynomial coefficients to initialize approximators with
    :param log_std_batch_size: Queue size to calculate running std for logging when use_fixed_std_schedule = False
    :param args: Additional positional arguments, forwarded to the base class
    :param kwargs: Additional keyword arguments. They are accepted so that callers may pass
        further policy options, but they are ignored (not forwarded to the base class).
    :raises NotImplementedError: If gSDE / squashing is requested or the action space
        has more than one dimension.
    """

    def __init__(
        self,
        observation_space: spaces.Space,
        action_space: spaces.Space,
        lr_schedule: Schedule,
        use_sde: bool = False,
        use_expln: bool = True,
        squash_output: bool = False,
        optimizer_class: type[th.optim.Optimizer] = th.optim.Adam,
        optimizer_kwargs: Optional[dict[str, Any]] = None,
        features_extractor_class: type[BaseFeaturesExtractor] = FlattenExtractor,
        features_extractor_kwargs: Optional[dict[str, Any]] = None,
        use_fixed_std_schedule: bool = False,
        degree: int = 3,
        basis: str = "chebyshev",
        initialization: str = "random",
        coeffs: Optional[Any] = None,
        log_std_batch_size: int = 64,
        *args,
        **kwargs,
    ):
        # Reject unsupported configurations before anything is built.
        if squash_output or use_sde:  # squash_output=True is only available when using gSDE (use_sde=True)
            raise NotImplementedError('SDE not implemented.')
        if get_action_dim(action_space) > 1:
            raise NotImplementedError('Only one-dimensional action spaces supported at this time.')

        if optimizer_kwargs is None:
            optimizer_kwargs = {}
            # Small values to avoid NaN in Adam optimizer
            if optimizer_class == th.optim.Adam:
                optimizer_kwargs["eps"] = 1e-5

        # NOTE: **kwargs is intentionally not forwarded; the base class does not receive it.
        super().__init__(
            observation_space,
            action_space,
            features_extractor_class,
            features_extractor_kwargs,
            *args,  # Pass remaining positional arguments to base class
            optimizer_class=optimizer_class,
            optimizer_kwargs=optimizer_kwargs,
        )

        self.degree = degree
        self.basis = basis
        self.initialization = initialization
        self.initialization_coeffs = coeffs
        self.log_std_batch_size = log_std_batch_size
        self.use_expln = use_expln
        self.use_fixed_std_schedule = use_fixed_std_schedule
        self.current_progress_remaining = 1.0
        self.training_mode = False
        self.std = 1.0  # replaced by a tensor on the first non-evaluate get_std() call

        # Take the feature dimension from the extractor itself so that a custom
        # features_extractor_class is honoured (for FlattenExtractor this equals the
        # flattened observation dimension).
        self.features_extractor = self.make_features_extractor()
        self.features_dim = self.features_extractor.features_dim

        self._build(lr_schedule)

    def get_std(self, evaluate: bool = False):
        """
        Retrieve std parameter according to the current progress remaining (1 = 100% remaining) in case of fixed schedule.

        With ``p = 1 - current_progress_remaining`` the schedule is
        ``STD_END + (STD_INIT - STD_END) * (1 + k * p) ** (-alpha)`` with ``k = 3`` and
        ``alpha = 2.0``, i.e. it starts at STD_INIT for ``p = 0`` and decays towards STD_END.

        :param evaluate: If True, return the std computed by the last non-evaluate call
            instead of recomputing it.
        :return: A copy of the current std
        """
        if not evaluate:
            progress_done = 1 - self.current_progress_remaining
            k = 3
            alpha = 2.0
            self.std = th.tensor(STD_END + (STD_INIT - STD_END) * (1 + k * progress_done) ** (-alpha))
        # With evaluate=True the cached value is reused so that std remains the same between
        # rollout collection and evaluation, as current_progress_remaining is updated between
        # both in on_policy_algorithm.py.
        return copy.deepcopy(self.std)

    def update_log_std_list(self, log_std_list):
        """
        Update the log_std_list for usage in parent class.

        :param log_std_list: Value stored as ``self.log_std``
        """
        self.log_std = log_std_list

    def train(self, mode: bool = True):
        """
        Set approximator training mode, i.e. activate or deactivate computation graph.

        :param mode: Training mode flag handed to the approximator
        :return: self, matching the ``torch.nn.Module.train`` contract
        """
        self.training_mode = mode
        self.policy.set_training_mode(mode)
        return self

    def parameters(self):
        """
        Returns parameters of approximators.
        """
        return self.policy.parameters()

    def reset_noise(self, n_envs: int = 1) -> None:
        """
        Sample new weights for the exploration matrix.
        reset_noise() is only available when using gSDE

        :raises NotImplementedError: Always, gSDE is not implemented.
        """
        raise NotImplementedError("SDE not implemented.")

    def _build_mlp_extractor(self) -> None:
        """
        Not available for this policy.

        :raises NotImplementedError: Always, there is no MLP extractor.
        """
        raise NotImplementedError('No MLP Extractor is utilized, policy and value nets are realized by polynomial function approximators.')

    def _build(self, lr_schedule: Schedule) -> None:
        """
        Create the approximators and the optimizer.

        :param lr_schedule: Learning rate schedule
            lr_schedule(1) is the initial learning rate
        """
        self.policy = PolynomialPPOApproximator(
            feature_dim=self.features_dim,
            degree=self.degree,
            basis=self.basis,
            initialization=self.initialization,
            use_expln=self.use_expln,
            coeffs=self.initialization_coeffs,
            update_log_std_in_parent_fn=self.update_log_std_list,
            log_std_batch_size=self.log_std_batch_size,
            use_fixed_std_schedule=self.use_fixed_std_schedule,
            std_schedule_fn=self.get_std,
        )
        self.degree = self.policy.degree  # adapt degree in case coefficients were passed

        # Setup optimizer with initial learning rate
        self.optimizer = self.optimizer_class(self.policy.parameters(), lr=lr_schedule(1), **self.optimizer_kwargs)

    def _predict(self, observation: PyTorchObs, deterministic: bool = False) -> th.Tensor:
        """
        Used outside training to get the action according to the policy for a given observation.

        :param observation: The observation
        :param deterministic: Whether to use stochastic or deterministic actions
        :return: Taken action according to the policy
        """
        # Preprocess/flatten exactly like forward() does, so that prediction and rollout
        # collection feed the approximator with identically shaped inputs.
        features = self.extract_features(observation, self.features_extractor)
        if deterministic:
            return self.policy.forward_actor(features)
        return self.policy.sample_actions(features)

    def forward(self, obs: th.Tensor, deterministic: bool = False) -> tuple[th.Tensor, th.Tensor, th.Tensor]:
        """
        Forward pass in all the networks (actor and critic)
        Used during learn() to collect action rollouts.
        Here, obs is a single observation.

        on_policy_algorithm usually calls this method in the context of "with torch.no_grad()".

        :param obs: Observation
        :param deterministic: Whether to sample or use deterministic actions
        :return: action, value and log probability of the action
        """
        # Preprocess/flatten the observation. This is needed to flatten nested lists of features
        features = self.extract_features(obs, self.features_extractor)
        if deterministic:
            actions = self.policy.forward_actor(features)
        else:
            actions = self.policy.sample_actions(features)

        # Log probability of the chosen actions under the current action distribution.
        dist = self.policy.get_action_dist(features)
        log_prob = self.policy.get_log_probs_from_dist(dist, actions)

        values = self.policy.forward_critic(features)
        return actions, values, log_prob

    def evaluate_actions(self, obs: PyTorchObs, rollout_actions: th.Tensor) -> tuple[th.Tensor, th.Tensor, Optional[th.Tensor]]:
        """
        Evaluate actions according to the current policy, given the observations.
        Used during train() to compute policy update.
        Here, obs contains all observations from this rollout and actions all the corresponding actions.

        on_policy_algorithm usually calls this method with the gradient computation graph active.

        Mirrors ActorCriticPolicy.evaluate_actions of stable_baselines3.common.policies: one
        action distribution is built per observation of the batch with the current
        approximators, and the log probability of each rollout action is taken under the
        distribution belonging to its observation.

        :param obs: Observation
        :param rollout_actions: Actions collected during the rollout
        :return: estimated value, log likelihood of taking those actions
            and entropy of the action distribution.
        """
        features = self.extract_features(obs, self.features_extractor)
        # Newly compute values from observations
        values = self.policy.forward_critic(features)
        # evaluate=True is handed to the approximator; presumably it makes the fixed std
        # schedule reuse the std from rollout collection (see get_std) -- TODO confirm.
        dist = self.policy.get_action_dist(features, evaluate=True)
        # The trailing action axis is dropped before computing the log probabilities.
        log_probs = self.policy.get_log_probs_from_dist(dist, rollout_actions.squeeze(-1))
        entropies = dist.entropy()
        return values, log_probs, entropies

    def predict_values(self, obs: PyTorchObs) -> th.Tensor:
        """
        Get the estimated values according to the current policy given the observations.

        :param obs: Observation
        :return: the estimated values.
        """
        features = self.extract_features(obs, self.features_extractor)
        return self.policy.forward_critic(features)
    

class PolynomialARSPolicy(BasePolicy):
    """
    Adapted from sb3_contrib.ars.policies
    The policy is represented by a polynomial approximator.

    :param observation_space: The observation space of the environment
    :param action_space: The action space of the environment
    :param with_bias: Accepted for signature compatibility; not used by this policy
    :param squash_output: Whether to squash the output; only taken into account for Box action spaces
    :param degree: Polynomial degree of the policy approximator
    :param basis: Polynomial basis to use for the approximator
    :param initialization: Coefficient initialization method to use for the approximator
    :param coeffs: Polynomial coefficients to initialize the approximator with
    :param kwargs: Additional keyword arguments; accepted but ignored
    """

    def __init__(
        self,
        observation_space: spaces.Space,
        action_space: spaces.Space,
        with_bias: bool = True,
        squash_output: bool = False,
        degree: int = 3,
        basis: str = "chebyshev",
        initialization: str = "random",
        coeffs: Optional[Any] = None,
        **kwargs,
    ):
        super().__init__(
            observation_space,
            action_space,
            squash_output=isinstance(action_space, spaces.Box) and squash_output,
        )
        self.degree = degree
        self.basis = basis
        self.initialization = initialization
        self.initialization_coeffs = coeffs

        self.features_extractor = self.make_features_extractor()
        self.features_dim = self.features_extractor.features_dim
        self.training_mode = False

        self.actor = PolynomialARSApproximator(
            feature_dim=self.features_dim,
            degree=self.degree,
            basis=self.basis,
            initialization=self.initialization,
            coeffs=self.initialization_coeffs,
        )
        self.degree = self.actor.degree  # adapt degree in case coefficients were passed

    def _get_constructor_parameters(self) -> dict[str, Any]:
        """
        Return the keyword arguments needed to re-create this policy.

        super()._get_constructor_parameters() is not used because it adds normalize_images,
        which this constructor does not support. Only arguments that ``__init__`` accepts
        and that are stored on the instance are returned.

        :return: Constructor keyword arguments
        """
        return dict(
            observation_space=self.observation_space,
            action_space=self.action_space,
            squash_output=self.squash_output,
            degree=self.degree,
            basis=self.basis,
            initialization=self.initialization,
            coeffs=self.initialization_coeffs,
        )

    def train(self, mode: bool = True):
        """
        Set approximator training mode, i.e. activate or deactivate computation graph.

        :param mode: Training mode flag handed to the approximator
        :return: self, matching the ``torch.nn.Module.train`` contract
        """
        self.training_mode = mode
        self.actor.set_training_mode(mode)
        return self

    def forward(self, obs: PyTorchObs) -> th.Tensor:
        """
        Extract features from the observation and evaluate the approximator on them.

        :param obs: Observation
        :return: Output of the approximator
        """
        features = self.extract_features(obs, self.features_extractor)
        return self.actor.forward(features)

    def _predict(self, observation: PyTorchObs, deterministic: bool = True) -> th.Tensor:
        """
        Get the action for a given observation.

        :param observation: The observation
        :param deterministic: Ignored
        :return: Output of the approximator
        """
        # Non deterministic action does not really make sense for ARS, we ignore this parameter for now.
        # NOTE(review): unlike forward(), the observation is passed to the approximator
        # without extract_features() -- confirm that this is intended.
        return self.actor.forward(observation)

    def parameters(self):
        """
        Returns polynomial coefficients.

        The coefficient object of the approximator is returned directly, not an iterator
        over parameters.
        """
        return self.actor.policy.coeffs

    def load_from_vector(self, vector: np.ndarray) -> None:
        """
        Load parameters from a 1D vector.

        :param vector: New coefficient values (NumPy array or tensor); they are copied
            in place into the existing coefficients without tracking gradients.
        """
        with th.no_grad():
            # Tensor.copy_ requires a tensor source; as_tensor converts NumPy arrays and
            # leaves tensors untouched.
            self.actor.policy.coeffs.copy_(th.as_tensor(vector))

    def parameters_to_vector(self) -> np.ndarray:
        """
        Convert the parameters to a 1D vector.

        :return: Detached copy of the coefficients on the CPU as NumPy array
        """
        return self.parameters().detach().cpu().numpy()