import gymnasium
import torch

from compression_autoencoder.envs.multireward_mountaincar import (
    MultiRewardMountainCarWrapper,
)
from compression_autoencoder.envs.multireward_reacher import MultiRewardReacherWrapper
from compression_autoencoder.utils.custom_layers import MinMaxScaler, StandardScalerFromUniform

# One row per environment id: (env id, scaler class, lower bounds, upper bounds).
# The bounds are passed to the scaler as float32 tensors via ``lb`` / ``ub``.
# NOTE(review): the single-element +/-2.5 bounds on the "Autoencoder*" ids
# presumably describe a 1-D encoded observation -- confirm against the envs.
_SCALER_SPECS = (
    (
        "MountainCarContinuous-v0",
        StandardScalerFromUniform,
        [-1.2, -0.07],
        [0.6, 0.07],
    ),
    (
        "AutoencoderMountainCarContinuous-v0",
        StandardScalerFromUniform,
        [-2.5],
        [2.5],
    ),
    (
        "AutoencoderReacher-v5",
        StandardScalerFromUniform,
        [-2.5],
        [2.5],
    ),
    (
        "Reacher-v5",
        MinMaxScaler,
        [-1, -1, -1, -1, -5, -5],
        [1, 1, 1, 1, 5, 5],
    ),
)

# Input scaler instance for each supported environment id. Every id gets its
# own scaler object and its own bound tensors (nothing is shared between keys).
INPUT_SCALERS = {
    env_id: scaler_cls(
        lb=torch.tensor(lower, dtype=torch.float32),
        ub=torch.tensor(upper, dtype=torch.float32),
    )
    for env_id, scaler_cls, lower, upper in _SCALER_SPECS
}

# Multi-reward wrapper class for each supported environment id.
# The values are the wrapper classes themselves (not instances), so the value
# type is ``type[gymnasium.Wrapper]``; callers instantiate the class around an
# environment. With the annotation describing classes, the per-entry
# ``# type: ignore`` comments are no longer needed.
# NOTE(review): assumes both wrappers subclass ``gymnasium.Wrapper`` -- confirm.
WRAPPER_CLASSES: dict[str, type[gymnasium.Wrapper]] = {
    "MountainCarContinuous-v0": MultiRewardMountainCarWrapper,
    "Reacher-v5": MultiRewardReacherWrapper,
}
