import torch as th
import torch.nn as nn
import torchvision.transforms as T
from stable_baselines3.common.torch_layers import NatureCNN


class ResizeFeatureExtractors(nn.Module):
    """Wrap a ``NatureCNN`` so observations are resized before being encoded.

    ``NatureCNN`` builds its final linear layer from the spatial size of the
    observation space it was constructed with, so it only accepts inputs of
    exactly that height and width. This wrapper resizes incoming image batches
    to that ``(H, W)`` and then delegates to the wrapped extractor.

    :param features_extractor: The ``NatureCNN`` to wrap. Its observation
        space is channel-first ``(C, H, W)``, which NatureCNN itself requires.
    """

    def __init__(self, features_extractor: NatureCNN):
        super().__init__()

        self.features_extractor = features_extractor
        # Resize to the exact (H, W) of the extractor's observation space.
        # Passing a single int to ``T.Resize`` only matches the *smaller* edge
        # and keeps the aspect ratio. Non-square inputs (or a non-square
        # observation space) would then come out at a spatial size the
        # NatureCNN linear layer was not built for.
        height, width = features_extractor._observation_space.shape[1:]
        self.resizer = T.Resize((height, width))

    @property
    def features_dim(self) -> int:
        """Number of output features, forwarded from the wrapped extractor."""
        return self.features_extractor.features_dim

    def forward(self, observations: th.Tensor) -> th.Tensor:
        """Resize ``observations`` to the extractor's input size, then encode them.

        :param observations: Image batch whose last two dims are spatial;
            presumably ``(N, C, H_in, W_in)`` -- TODO confirm with callers.
        :return: The wrapped extractor's output for the resized batch.
        """
        return self.features_extractor(self.resizer(observations))
