import gym
import torch as th
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor

from ocr.tools import Tensor
from ocr.transformer.transformer_module import Transformer_Module


class TransformerFeaturesExtractor(BaseFeaturesExtractor):
    """SB3 feature extractor: run a SLATE model, then a transformer on its output.

    The feature dimension reported to stable-baselines3 is the transformer's
    ``config.d_model``.

    Args:
        observation_space: The environment observation space, forwarded to
            ``BaseFeaturesExtractor``.
        slate: A callable SLATE model applied to raw observations.
        transformer_kwargs: Keyword arguments used to build the
            ``Transformer_Module`` head.
    """

    def __init__(self, observation_space: gym.Space, slate, transformer_kwargs):
        # The head is built first because its d_model determines features_dim.
        head = Transformer_Module(**transformer_kwargs)
        d_model = head.config.d_model
        super().__init__(observation_space, features_dim=d_model)
        self.transformer = head

        # SLATE has bool parameters, which break SB3's polyak (soft target)
        # updates. Storing it inside a plain list keeps nn.Module from
        # registering it as a submodule. As a result it is excluded from
        # parameters()/state_dict(), and .to()/.train()/.eval() calls on this
        # extractor do not propagate to it.
        self.slate = [slate]

    def forward(self, observations: Tensor) -> Tensor:
        """Encode ``observations`` with SLATE, then apply the transformer head.

        NOTE(review): the output is presumably (batch, d_model), matching
        features_dim. This depends on Transformer_Module and should be confirmed.
        """
        encoded = self.slate[0](observations)
        return self.transformer(encoded)


class SLATEMean(BaseFeaturesExtractor):
    """SB3 feature extractor that averages SLATE outputs over dimension 1.

    The feature dimension reported to stable-baselines3 is
    ``slate._config.slotattr.slot_size``.

    Args:
        observation_space: The environment observation space, forwarded to
            ``BaseFeaturesExtractor``.
        slate: A callable SLATE model applied to raw observations.
    """

    def __init__(self, observation_space: gym.Space, slate):
        slot_size = slate._config.slotattr.slot_size
        super().__init__(observation_space, features_dim=slot_size)

        # Wrapped in a plain list so nn.Module does not register SLATE as a
        # submodule. Its (bool) parameters would otherwise break SB3's polyak
        # updates. SLATE is therefore not part of parameters()/state_dict().
        self.slate = [slate]

    def forward(self, observations: th.Tensor, *args, **kwargs) -> th.Tensor:
        """Run SLATE on ``observations`` and average the result along dim 1.

        Extra positional/keyword arguments are accepted and ignored.

        NOTE(review): dim 1 is presumably the slot axis of a
        (batch, num_slots, slot_size) output. Confirm this against SLATE.
        """
        slots = self.slate[0](observations)
        return slots.mean(dim=1)
