import torch
import torch.nn as nn
from gymnasium import spaces
from torchvision.models import resnet18, ResNet18_Weights
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor


class MetaWorldResNet(BaseFeaturesExtractor):
    """Features extractor backed by a pretrained torchvision ResNet-18.

    The network's final fully connected layer is replaced with an identity,
    so the extractor outputs the 512-dimensional pooled features instead of
    class logits. Observations are passed through the preprocessing
    transform bundled with the default ResNet-18 weights before encoding.

    Args:
        observation_space: Observation space forwarded to
            ``BaseFeaturesExtractor``.
        mode: ``'train'`` to run the backbone with autograd left as the
            caller configured it, or ``'eval'`` to run it under
            ``torch.no_grad()`` so no gradients reach the backbone.

    Raises:
        ValueError: If ``mode`` is not ``'train'`` or ``'eval'``.
    """

    _VALID_MODES = ('train', 'eval')

    def __init__(self, observation_space: spaces.Box, mode: str = 'train'):
        # Fail fast on a bad mode, before the pretrained weights are loaded.
        if mode not in self._VALID_MODES:
            raise ValueError(
                f"mode must be one of {self._VALID_MODES}, got {mode!r}"
            )

        super().__init__(observation_space, features_dim=512)

        self.resnet = resnet18(weights=ResNet18_Weights.DEFAULT)
        # Drop the classification head; the backbone then emits the
        # 512-dimensional features declared as ``features_dim`` above.
        self.resnet.fc = nn.Identity()
        self.preprocess = ResNet18_Weights.DEFAULT.transforms()
        self.mode = mode

    def _encode(self, obs: torch.Tensor) -> torch.Tensor:
        """Preprocess ``obs`` and run it through the ResNet backbone."""
        return self.resnet(self.preprocess(obs))

    def forward(self, obs: torch.Tensor) -> torch.Tensor:
        """Encode observations into feature vectors.

        Args:
            obs: Batch of image observations. Assumed to be in a layout and
                value range the weights' preprocessing transform accepts
                -- TODO confirm against the environment wrapper.

        Returns:
            The backbone output for ``obs``.

        Raises:
            ValueError: If ``self.mode`` was reassigned to an unsupported
                value after construction.
        """
        if self.mode == 'train':
            return self._encode(obs)
        if self.mode == 'eval':
            # NOTE: no_grad only disables gradient tracking; it does not
            # switch the module between train()/eval() behaviour.
            with torch.no_grad():
                return self._encode(obs)
        # Previously an unknown mode fell through and returned None.
        raise ValueError(
            f"mode must be one of {self._VALID_MODES}, got {self.mode!r}"
        )
