from typing import Optional, Type, Union
from stable_baselines3.common.policies import ActorCriticPolicy
import torch as th
from stable_baselines3.ppo.ppo import PPO
from stable_baselines3.common.save_util import load_from_zip_file, recursive_setattr
from stable_baselines3.common.utils import check_for_correct_spaces

from ...utils.features_extractor import ResizeFeatureExtractors
from .policy import ActorCriticPolicyWarmStartWrapper

class PPOWarmStart(PPO):
    """PPO variant whose policy can be warm-started from a previously saved PPO model.

    The user-supplied ``policy`` is wrapped with ``ActorCriticPolicyWarmStartWrapper``
    before being handed to ``PPO``; weights from a saved model can then be
    transferred into this model's policy with :meth:`load_policy`.
    """

    def __init__(
        self,
        policy: Union[str, Type[ActorCriticPolicy]],
        *args,
        **kwargs,
    ):
        """Wrap ``policy`` and forward all arguments to ``PPO.__init__``.

        Args:
            policy: A policy alias string or an ``ActorCriticPolicy`` subclass.
            *args: Positional arguments forwarded unchanged to ``PPO.__init__``.
            **kwargs: Keyword arguments forwarded unchanged to ``PPO.__init__``.
        """
        # NOTE(review): assumes the wrapper returns something PPO accepts as its
        # ``policy`` argument (including when given a string alias) — confirm in
        # .policy.ActorCriticPolicyWarmStartWrapper.
        policy = ActorCriticPolicyWarmStartWrapper(policy)
        super().__init__(policy, *args, **kwargs)

    def load_policy(self, path: str, **load_kwargs) -> None:
        """Load the policy of a saved PPO model and reuse it in this model's policy.

        The saved model is loaded onto ``self.device`` unless a ``device`` is
        passed explicitly, so the source policy lives on the same device as the
        policy it is reused into.

        Args:
            path: Location of a model saved with ``PPO.save``.
            **load_kwargs: Extra keyword arguments forwarded to ``PPO.load``
                (e.g. ``custom_objects``, ``device``).
        """
        # PPO.load defaults to device="auto", which may pick a GPU even when this
        # model runs on CPU; default to our own device to avoid a cross-device
        # mismatch and needless GPU allocation.
        load_kwargs.setdefault("device", self.device)
        # Plain PPO.load builds a plain PPO instance, so this class's policy
        # wrapper is not applied to the loaded model's policy class.
        source_policy = PPO.load(path, **load_kwargs).policy
        # NOTE(review): presumably transfers compatible weights from
        # ``source_policy``; exact rules live in the wrapper's load_policy_reuse.
        self.policy.load_policy_reuse(source_policy)
