"""This file contains classes related to the PointMazeEnv."""
import gym
import numpy as np
import torch
from typing import Optional, Tuple

from imitation.envs.examples.airl_envs.point_maze_env import PointMazeEnv

from offline_rl.rewards.reward_model import RewardModel


class JsonWritablePointMazeEnv(PointMazeEnv):
    """A PointMazeEnv whose step info can be written to json.

    The environment frame skip can additionally be chosen at construction time.

    Args:
        frame_skip: Number of frames to skip between timesteps.
    """
    def __init__(self, *args, frame_skip: int = 5, **kwargs):
        super().__init__(*args, **kwargs)

        # The frame skip can only be replaced once the base class has been initialized.
        self.frame_skip = frame_skip

        # Recompute the advertised video frame rate from the timestep duration.
        # NOTE(review): presumably self.dt depends on frame_skip -- confirm against the base env.
        frames_per_second = int(np.round(1.0 / self.dt))
        self.metadata["video.frames_per_second"] = frames_per_second

    def step(self, *args, **kwargs) -> Tuple:
        """Steps the base env and casts info["reward_ctrl"] to a built-in float.

        All arguments are forwarded unchanged to the base env's step.

        Returns:
            The (obs, reward, done, info) tuple of the base env, where info["reward_ctrl"]
            has been replaced by its float conversion to make it json-serializable.
        """
        transition = super().step(*args, **kwargs)
        obs, reward, done, info = transition

        # The info dict is updated in place.
        ctrl_reward = info["reward_ctrl"]
        info["reward_ctrl"] = float(ctrl_reward)
        return obs, reward, done, info


# The position of the target / goal location in the PointMazeEnv. PointMazeEnvRewardModel.reward
# measures the distance from the first three state components to this point.
# NOTE(review): presumably (x, y, z) coordinates mirroring the target placement in the base env --
# confirm against PointMazeEnv.
TARGET_LOCATION = [0.3, 0.5, 0.0]


class PointMazeEnvRewardModel(RewardModel):
    """The ground-truth reward model for the PointMazeEnv.

    This class exists in order to adhere to certain reward-learning interfaces that assume a base reward model,
    as well as to allow for experimenting with different rewards if necessary.

    Args:
        obs_space: The observation space used in the environment.
        act_space: The action space used in the environment.
    """
    def __init__(self, obs_space: gym.spaces.Space, act_space: gym.spaces.Space):
        # NOTE(review): RewardModel.__init__ is not invoked here; confirm the base class needs no setup.
        self.obs_space = obs_space
        self.act_space = act_space

    @property
    def observation_space(self) -> gym.spaces.Space:
        """The observation space passed at construction."""
        return self.obs_space

    @property
    def action_space(self) -> gym.spaces.Space:
        """The action space passed at construction."""
        return self.act_space

    def reward(
            self,
            states: torch.Tensor,
            actions: torch.Tensor,
            next_states: Optional[torch.Tensor],
            terminals: Optional[torch.Tensor],
    ) -> torch.Tensor:
        """Computes the ground-truth reward of the environment.

        The reward is the negative Euclidean distance between the first three state components and
        TARGET_LOCATION, plus 0.001 times the negative sum of squared action components.

        Args:
            states: States whose last dimension starts with the three components compared to the target.
                Any number of leading (batch) dimensions is accepted.
            actions: Actions whose last dimension holds the action components; the leading dimensions
                must be broadcastable with those of states.
            next_states: Unused by this reward (presumably required by the RewardModel interface).
            terminals: Unused by this reward (presumably required by the RewardModel interface).

        Returns:
            The rewards in the dtype of states, with the last dimension reduced to size 1.
        """
        # Build the target directly in the dtype of states. Creating it in the default dtype (float32
        # unless reconfigured) and casting afterwards would round the coordinates before widening, so
        # float64 states located exactly at the target would receive a small non-zero distance.
        target = torch.tensor(TARGET_LOCATION, dtype=states.dtype, device=states.device)

        # Reduce over the last dimension so unbatched and multi-dimensional batches work as well;
        # for 2D inputs this matches reducing over dim=1.
        reward_dist = -torch.norm(states[..., :3] - target, dim=-1, keepdim=True)
        reward_ctrl = -(actions**2).sum(dim=-1, keepdim=True).to(reward_dist.dtype)
        return reward_dist + 0.001 * reward_ctrl
