import math
from typing import Type, Union

import gym
import numpy as np
import torch as th
from stable_baselines3.common.buffers import RolloutBuffer
from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common.distributions import CategoricalDistribution
from stable_baselines3.common.policies import ActorCriticPolicy
from stable_baselines3.common.save_util import (load_from_zip_file,
                                                recursive_setattr)
from stable_baselines3.common.type_aliases import GymEnv
from stable_baselines3.common.utils import (check_for_correct_spaces,
                                            get_schedule_fn, obs_as_tensor)
from stable_baselines3.common.vec_env import VecEnv
from stable_baselines3.ppo.policies import ActorCriticPolicy
from stable_baselines3.ppo.ppo import PPO

from ...utils.features_extractor import ResizeFeatureExtractors


class PPOPolicyReuse(PPO):
    """PPO variant that can act with a previously trained ("reused") policy.

    During rollout collection, one random draw per step decides whether the
    actions for all environments of that step come from the reuse policy
    (loaded via :meth:`load_policy_reuse`) or from the policy being trained.
    The values and log-probabilities written to the rollout buffer are always
    computed by the policy being trained, including for actions the reuse
    policy chose.

    :param policy: Policy class or policy name, forwarded to :class:`PPO`.
    :param env: Training environment or its id, forwarded to :class:`PPO`.
    :param policy_reuse_epsilon: Probability of acting with the reuse policy.
        It is wrapped with ``get_schedule_fn`` and re-evaluated with the
        current progress remaining at the start of every ``train()`` call.
    :param kwargs: Forwarded unchanged to :class:`PPO`.
    """

    def __init__(
        self,
        policy: Union[str, Type[ActorCriticPolicy]],
        env: Union[GymEnv, str],
        policy_reuse_epsilon: float = 0.5,
        **kwargs
    ):
        # No reuse policy until load_policy_reuse() is called; rollouts fall
        # back to the learner's own policy in that case.
        self.policy_reuse = None
        self.policy_reuse_epsilon = policy_reuse_epsilon

        self._policy_reuse_epsilon_schedule = get_schedule_fn(self.policy_reuse_epsilon)
        # Value at the very start of training (progress remaining == 1).
        self._policy_reuse_epsilon = self._policy_reuse_epsilon_schedule(1.)

        super().__init__(policy, env, **kwargs)

    def load_policy_reuse(self, path):
        """Load a saved PPO model and use its policy as the reuse policy.

        The loaded policy's features extractor is wrapped in
        ``ResizeFeatureExtractors``. If its number of actions differs from the
        learner's, its action head is patched so that sampled actions always
        index into the learner's action space:

        * reuse policy has more actions: logits of the actions the learner does
          not have are set to the lowest finite value, so they are effectively
          never sampled;
        * reuse policy has fewer actions: the missing logits are filled, per
          sample, with the mean of the existing logits.

        :param path: Path to a model saved with ``PPO.save``.
        :raises ValueError: if the action counts differ and the loaded policy
            does not use a ``CategoricalDistribution``.
        """
        # Load directly onto this model's device so rollouts never mix devices.
        reuse_policy = PPO.load(path, device=self.device).policy

        reuse_policy.features_extractor = ResizeFeatureExtractors(reuse_policy.features_extractor)
        # The wrapper may hold tensors created on the default device; move the
        # whole policy so everything lives on ``self.device``.
        reuse_policy = reuse_policy.to(self.device)

        own_dim = self.policy.action_dist.action_dim
        reuse_dim = reuse_policy.action_dist.action_dim

        if own_dim != reuse_dim:
            # Logit rebalancing only makes sense for categorical actions; fail
            # at load time instead of in the middle of a rollout.
            if not isinstance(reuse_policy.action_dist, CategoricalDistribution):
                raise ValueError("Inapplicable action distribution")

            if own_dim < reuse_dim:
                # True for reuse-policy actions that the learner does not have.
                unknown_actions = th.arange(reuse_dim, device=self.device) >= own_dim

                def rebalance_function(logits):
                    # Mask on the logits' device to avoid CPU/GPU mismatches.
                    return logits.masked_fill(
                        unknown_actions.to(logits.device), th.finfo(logits.dtype).min
                    )
            else:
                unknown_actions_num = own_dim - reuse_dim

                def rebalance_function(logits):
                    filler = logits.mean(dim=1, keepdim=True).expand(-1, unknown_actions_num)
                    return th.cat([logits, filler], dim=1)

            def _get_action_dist_from_latent_override(latent_pi, *_args, **_kwargs):
                # Extra arguments are accepted and ignored so the override does
                # not break if the caller passes more than ``latent_pi``.
                logits = rebalance_function(reuse_policy.action_net(latent_pi))
                return reuse_policy.action_dist.proba_distribution(action_logits=logits)

            reuse_policy._get_action_dist_from_latent = _get_action_dist_from_latent_override

        self.policy_reuse = reuse_policy
        self.policy_reuse.set_training_mode(False)

    def update_policy_reuse_epsilon(self, current_progress_remaining: float) -> None:
        """Re-evaluate the reuse-probability schedule at the given progress."""
        self._policy_reuse_epsilon = self._policy_reuse_epsilon_schedule(current_progress_remaining)

    def collect_rollouts(
        self,
        env: VecEnv,
        callback: BaseCallback,
        rollout_buffer: RolloutBuffer,
        n_rollout_steps: int,
    ) -> bool:
        """
        Collect experiences using the current policy and fill a ``RolloutBuffer``.
        The term rollout here refers to the model-free notion and should not
        be used with the concept of rollout used in model-based RL or planning.

        At each step, with probability ``self._policy_reuse_epsilon`` (and only
        if a reuse policy has been loaded) the actions are sampled from the
        reuse policy. Values and log-probabilities are always computed by
        ``self.policy``.

        :param env: The training environment
        :param callback: Callback that will be called at each step
            (and at the beginning and end of the rollout)
        :param rollout_buffer: Buffer to fill with rollouts
        :param n_rollout_steps: Number of experiences to collect per environment
        :return: True if function returned with at least `n_rollout_steps`
            collected, False if callback terminated rollout prematurely.
        """
        assert self._last_obs is not None, "No previous observation was provided"
        # Switch to eval mode (this affects batch norm / dropout)
        self.policy.set_training_mode(False)

        n_steps = 0
        rollout_buffer.reset()
        # Sample new weights for the state dependent exploration
        if self.use_sde:
            self.policy.reset_noise(env.num_envs)

        callback.on_rollout_start()

        while n_steps < n_rollout_steps:
            if self.use_sde and self.sde_sample_freq > 0 and n_steps % self.sde_sample_freq == 0:
                # Sample a new noise matrix
                self.policy.reset_noise(env.num_envs)

            with th.no_grad():
                # Convert to pytorch tensor or to TensorDict
                obs_tensor = obs_as_tensor(self._last_obs, self.device)
                # One draw per step decides for all envs at once; without a
                # loaded reuse policy the learner always acts.
                if self.policy_reuse is not None and self._policy_reuse_epsilon > np.random.random():
                    actions, _values, _log_probs = self.policy_reuse.forward(obs_tensor)
                    # Score the reused actions under the learner.
                    values, log_probs, _ = self.policy.evaluate_actions(obs_tensor, actions)
                else:
                    actions, values, log_probs = self.policy.forward(obs_tensor)
            actions = actions.cpu().numpy()

            # Rescale and perform action
            clipped_actions = actions
            # Clip the actions to avoid out of bound error
            if isinstance(self.action_space, gym.spaces.Box):
                clipped_actions = np.clip(actions, self.action_space.low, self.action_space.high)

            new_obs, rewards, dones, infos = env.step(clipped_actions)

            self.num_timesteps += env.num_envs

            # Give access to local variables
            callback.update_locals(locals())
            if callback.on_step() is False:
                return False

            self._update_info_buffer(infos)
            n_steps += 1

            if isinstance(self.action_space, gym.spaces.Discrete):
                # Reshape in case of discrete action
                actions = actions.reshape(-1, 1)

            # Handle timeout by bootstraping with value function
            # see GitHub issue #633
            for idx, done in enumerate(dones):
                if (
                    done
                    and infos[idx].get("terminal_observation") is not None
                    and infos[idx].get("TimeLimit.truncated", False)
                ):
                    terminal_obs = self.policy.obs_to_tensor(infos[idx]["terminal_observation"])[0]
                    with th.no_grad():
                        terminal_value = self.policy.predict_values(terminal_obs)[0]
                    rewards[idx] += self.gamma * terminal_value

            rollout_buffer.add(self._last_obs, actions, rewards, self._last_episode_starts, values, log_probs)
            self._last_obs = new_obs
            self._last_episode_starts = dones

        with th.no_grad():
            # Compute value for the last timestep
            values = self.policy.predict_values(obs_as_tensor(new_obs, self.device))

        rollout_buffer.compute_returns_and_advantage(last_values=values, dones=dones)

        callback.on_rollout_end()

        return True

    def train(self) -> None:
        """
        Update policy using the currently gathered rollout buffer.

        The reuse-probability schedule is refreshed first, so the next rollout
        uses the epsilon for the current training progress.
        """
        self.update_policy_reuse_epsilon(self._current_progress_remaining)
        return super().train()
