from typing import Any, Dict, Optional, Type, Union

import gym
import torch.nn as nn
from stable_baselines3.common.buffers import DictRolloutBuffer, RolloutBuffer
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback
from stable_baselines3.common.vec_env import VecEnv
from stable_baselines3.ppo.policies import (ActorCriticCnnPolicy,
                                            ActorCriticPolicy)
from stable_baselines3.ppo.ppo import PPO

from ...utils.pre_conditions import (CNNGridWorldActionPreConditionWrapper,
                                     GridWorldActionPreConditionWrapper)
from ..classifier import MaskRolloutBuffer, PPOClassifier
from .policy import ActorCriticPolicyCombinedWrapper


class PPOCombined(PPOClassifier):
    """PPO variant whose policy is produced by ``ActorCriticPolicyCombinedWrapper``.

    On construction the environment is wrapped in an action pre-condition
    wrapper, and the wrapped (single) environment is handed to the policy
    through ``policy_kwargs["env"]``.
    """

    def __init__(
        self,
        policy: Union[str, Type[ActorCriticPolicy]],
        *args,
        env: Union[GymEnv, str],
        buffer_class=MaskRolloutBuffer,
        classifier=None,
        train_classifier=True,
        balanced_sampling=True,
        policy_kwargs: Optional[Dict[str, Any]] = None,
        wrapper_class: Optional[Type[
            gym.core.Wrapper
        ]] = None,
        **kwargs,
    ):
        """Wrap the policy and environment, then initialise the PPO machinery.

        :param policy: Policy (class or name) passed to
            ``ActorCriticPolicyCombinedWrapper``; the result is what PPO uses.
        :param args: Extra positional arguments forwarded to ``PPO.__init__``
            after ``env``.
        :param env: Environment to train on. It is wrapped with
            ``wrapper_class``; for a ``VecEnv`` every entry of ``env.envs`` is
            wrapped in place.
        :param buffer_class: Rollout buffer class stored on the instance. When
            falsy, ``DictRolloutBuffer`` is chosen for ``gym.spaces.Dict``
            observation spaces and ``RolloutBuffer`` otherwise.
        :param classifier: Stored as ``self.classifier``; if it is an
            ``nn.Module`` the policy wrapper is told so via
            ``is_cnn_classifier``.
        :param train_classifier: Stored as ``self.train_classifier``
            (presumably consumed by the parent class -- not visible here).
        :param balanced_sampling: Stored as ``self.balanced_sampling``
            (presumably consumed by the parent class -- not visible here).
        :param policy_kwargs: Extra keyword arguments for the policy. The dict
            is copied, so the caller's object is never modified; an ``"env"``
            entry is added unless one is already present.
        :param wrapper_class: Environment wrapper to apply. Defaults to
            ``CNNGridWorldActionPreConditionWrapper`` when the wrapped policy
            subclasses ``ActorCriticCnnPolicy``, else
            ``GridWorldActionPreConditionWrapper``.
        :param kwargs: Extra keyword arguments forwarded to ``PPO.__init__``.
        """
        self.classifier = classifier
        self.train_classifier = train_classifier
        self.balanced_sampling = balanced_sampling

        self.is_cnn_classifier = isinstance(self.classifier, nn.Module)
        policy = ActorCriticPolicyCombinedWrapper(
            policy, self.is_cnn_classifier
        )

        if wrapper_class is None:
            # NOTE(review): assumes the policy wrapper returns a class;
            # ``issubclass`` raises TypeError otherwise -- confirm.
            wrapper_class = (
                CNNGridWorldActionPreConditionWrapper
                if issubclass(policy, ActorCriticCnnPolicy)
                else GridWorldActionPreConditionWrapper
            )

        if isinstance(env, VecEnv):
            # NOTE(review): relies on the VecEnv exposing an ``envs`` list;
            # the VecEnv's own spaces are not recomputed after wrapping.
            env.envs = [wrapper_class(e) for e in env.envs]
            single_env = env.envs[0]
        else:
            env = wrapper_class(env)
            single_env = env

        # ``self.observation_space`` is only assigned inside ``PPO.__init__``
        # below, so the fallback has to look at the environment directly.
        self.buffer_class = buffer_class or (
            DictRolloutBuffer
            if isinstance(env.observation_space, gym.spaces.Dict)
            else RolloutBuffer
        )

        # Copy so that adding "env" never leaks into the caller's dict (which
        # may be reused to build other models).
        policy_kwargs = dict(policy_kwargs) if policy_kwargs is not None else {}
        policy_kwargs.setdefault("env", single_env)

        # ``PPO.__init__`` is called explicitly, so ``PPOClassifier.__init__``
        # is not executed for this class.
        PPO.__init__(self, policy, env, *args, policy_kwargs=policy_kwargs, **kwargs)

    def train(self) -> None:
        """Update the policy's classifier weighting, then run the parent update.

        ``self.policy.update_classifier_proba_weight`` receives the current
        ``self._current_progress_remaining`` before ``super().train()`` runs.
        """
        self.policy.update_classifier_proba_weight(self._current_progress_remaining)
        return super().train()

    def learn(
        self,
        total_timesteps: int,
        callback: MaybeCallback = None,
        log_interval: int = 1,
        eval_env: Optional[GymEnv] = None,
        eval_freq: int = -1,
        n_eval_episodes: int = 5,
        tb_log_name: str = "PPOCombined",
        eval_log_path: Optional[str] = None,
        reset_num_timesteps: bool = True,
    ) -> "PPOCombined":
        """Forward every argument unchanged to the parent ``learn``.

        The default ``tb_log_name`` here is ``"PPOCombined"``; see the parent
        class for the meaning of the individual parameters.

        :return: Whatever the parent ``learn`` returns.
        """
        return super().learn(
            total_timesteps=total_timesteps,
            callback=callback,
            log_interval=log_interval,
            eval_env=eval_env,
            eval_freq=eval_freq,
            n_eval_episodes=n_eval_episodes,
            tb_log_name=tb_log_name,
            eval_log_path=eval_log_path,
            reset_num_timesteps=reset_num_timesteps,
        )
