"""Wraps the autonomous verification environment to output a `ground_truth_mode`
value in the `info` dict in each step.

The mode is computed using the learned model that the environment authors
provide in their repository. Because that model is learned, the reported mode
is an approximation and not a perfect ground truth.
"""
import autonomous_car_verification.simulator.Car
from autonomous_car_verification.simulator.Car import World
from autonomous_car_verification.simulator.plot_trajectories_7a import ComposedModePredictor
from autonomous_car_verification.simulator.plot_trajectories_7a import Modes
from autonomous_car_verification.simulator.plot_trajectories_7a import normalize
from swmpo_experiments.ground_truth_wrapper import GroundTruthWrapper
from gymnasium.envs.registration import register
from gymnasium.envs.registration import WrapperSpec
from pathlib import Path
import inspect


class GroundTruthWorld(World):
    """`World` that reports a predicted mode in the `info` dict.

    After every `reset` and `step`, the returned observation is normalized
    and passed to a pre-trained `ComposedModePredictor`. The predicted mode is
    stored under `info["ground_truth_mode"]` as its integer position in
    `Modes`, so both methods report the value in the same format.
    """

    def __init__(
        self,
        pretrained_dir: Path,
        **kwargs,
    ):
        """Build the world and load the pre-trained mode predictor.

        Args:
            pretrained_dir: Directory containing the predictor's `.yml`
                files listed below.
            **kwargs: Forwarded unchanged to `World.__init__`.
        """
        super().__init__(**kwargs)

        # Load pre-trained mode predictor
        mode_predictor = ComposedModePredictor(
            pretrained_dir/'big.yml',
            pretrained_dir/'straight_little.yml',
            pretrained_dir/'square_right_little.yml',
            pretrained_dir/'square_left_little.yml',
            pretrained_dir/'sharp_right_little.yml',
            pretrained_dir/'sharp_left_little.yml',
            # NOTE(review): the meaning of this positional flag is defined by
            # ComposedModePredictor; kept exactly as in the original call.
            True,
        )
        self.mode_predictor = mode_predictor

    def _ground_truth_mode(self, obs) -> int:
        """Return the predicted mode for `obs` as its index within `Modes`."""
        observation = normalize(obs)
        mode = self.mode_predictor.predict(observation)
        return list(Modes).index(mode)

    def reset(self, *args, **kwargs):
        """Reset the world and add `ground_truth_mode` to the returned info.

        All arguments are forwarded to `World.reset`.

        Returns:
            The `(obs, info)` pair from the parent class, with
            `info["ground_truth_mode"]` set to the integer mode index.
        """
        obs, info = super().reset(*args, **kwargs)
        # Store the integer index (not the raw predictor output) so that the
        # value has the same type as the one reported by `step`.
        info["ground_truth_mode"] = self._ground_truth_mode(obs)
        return obs, info

    def step(self, *args, **kwargs):
        """Advance the world and add `ground_truth_mode` to the returned info.

        All arguments are forwarded to `World.step`.

        Returns:
            The `(obs, reward, terminated, truncated, info)` tuple from the
            parent class, with `info["ground_truth_mode"]` set to the integer
            mode index.
        """
        obs, reward, terminated, truncated, info = super().step(
            *args,
            **kwargs,
        )
        info["ground_truth_mode"] = self._ground_truth_mode(obs)
        return obs, reward, terminated, truncated, info


# The predictor's files are looked up in the directory that contains the
# simulator's `Car` module.
pretrained_dir = Path(
    inspect.getfile(autonomous_car_verification.simulator.Car)
).parent

# Register the environment at import time. `GroundTruthWorld` receives
# `pretrained_dir` as a keyword argument, and the registered environment is
# additionally wrapped with `GroundTruthWrapper`.
register(
    id="GroundTruthAutonomousCar-v0",
    entry_point="swmpo_experiments.autonomous_driving_utils.ground_truth_wrapper:GroundTruthWorld",
    kwargs={"pretrained_dir": pretrained_dir},
    additional_wrappers=(
        WrapperSpec(
            name=GroundTruthWrapper.__name__,
            entry_point="swmpo_experiments.ground_truth_wrapper:GroundTruthWrapper",
            kwargs={
                "extrinsic_reward_scale": 1.0,
                "mode_n": 5,
                "exploration_window_size": 60,
            },
        ),
    ),
)
