from typing import Optional, Tuple

import gym
from gym.envs.mujoco.reacher import ReacherEnv
import numpy as np
import torch

from offline_rl.rewards.reward_model import RewardModel


class CustomReacherEnv(ReacherEnv):
    """A customized version of the reacher env.

    Customization includes frame skip, changing the obs to allow for simulation from it,
    making the info dict json serializable, setting a finite horizon independent of the
    gym wrapper for doing so, and other changes.

    Args:
        frame_skip: Number of frames to skip between timesteps.
        max_timesteps: The maximum number of timesteps to take in the env per episode.
        obs_mode: The mode for the observation. Options:
            sim: Returns an observation that allows for simulating the mujoco simulator (default).
            original: Returns the original observation from the environment.
        terminate_when_unhealthy: If True, terminates the episode when healthy state bounds are exceeded.
        healthy_velocity_range: Tuple of min/max velocity values that define healthy bounds.
            These exist because without them rllib sometimes errors out with nan gradients when there are
            very large velocity values.
    """
    def __init__(
            self,
            frame_skip: int = 5,
            max_timesteps: int = 100,
            obs_mode: str = "sim",
            terminate_when_unhealthy: bool = True,
            healthy_velocity_range: Tuple[float, float] = (-50, 50),
    ):
        # These have to be stored before super init b/c it calls step.
        self.max_timesteps = max_timesteps
        self.t = 0
        self.obs_mode = obs_mode
        self.terminate_when_unhealthy = terminate_when_unhealthy
        self.healthy_velocity_range = healthy_velocity_range
        # Steps taken while the base class is still constructing the env are not
        # part of any episode, so they must not count towards the horizon.
        self._construction_complete = False

        super().__init__()

        self._construction_complete = True
        # Start from a clean counter even if reset() is not called before stepping.
        self.t = 0

        # Overwrite frame skip after calling super init.
        self.frame_skip = frame_skip
        # Keep the advertised frame rate in sync with the new frame skip.
        self.metadata["video.frames_per_second"] = int(np.round(1.0 / self.dt))

    def is_healthy(self) -> bool:
        """Returns True if the simulator is in a healthy state.

        The state is healthy when every joint velocity lies strictly inside
        ``healthy_velocity_range``. NaN velocities fail both comparisons and are
        therefore reported as unhealthy.
        """
        min_velocity, max_velocity = self.healthy_velocity_range
        velocity = self.sim.data.qvel.flat[:]
        healthy_velocity = np.all(np.logical_and(min_velocity < velocity, velocity < max_velocity))

        # Convert from np.bool_ so the return value matches the annotation.
        return bool(healthy_velocity)

    def reset(self) -> np.ndarray:
        """Resets the environment and the episode timestep counter."""
        self.t = 0
        return super().reset()

    def step(self, *args, **kwargs) -> Tuple:
        """Steps the base env, then applies the health and horizon termination rules.

        Also fixes a non-json-writable element in the info of the base env.

        Returns:
            The ``(obs, reward, done, info)`` tuple from the base env, with ``done``
            forced to True when the state is unhealthy (if enabled) or when
            ``max_timesteps`` steps have been taken since the last reset.
        """
        obs, reward, done, info = super().step(*args, **kwargs)
        info["reward_ctrl"] = float(info["reward_ctrl"])

        if self.terminate_when_unhealthy and not self.is_healthy():
            done = True

        # Only count steps that belong to an episode (see __init__).
        if self._construction_complete:
            self.t += 1
            if self.t >= self.max_timesteps:
                done = True
        return obs, reward, done, info

    def _get_obs(self) -> np.ndarray:
        """Optionally overwrite the observation for simulation purposes.

        Raises:
            ValueError: If ``obs_mode`` is neither "sim" nor "original".
        """
        if self.obs_mode == "sim":
            # Full joint positions and velocities, followed by the
            # fingertip-to-target displacement.
            return np.concatenate([
                self.sim.data.qpos.flat[:],
                self.sim.data.qvel.flat[:],
                self.get_body_com("fingertip") - self.get_body_com("target"),
            ])
        elif self.obs_mode == "original":
            return super()._get_obs()
        else:
            raise ValueError(f"Invalid observation mode: {self.obs_mode}")


class CustomReacherEnvRewardModel(RewardModel):
    """Reward model for custom Reacher environment.

    The reward is a weighted sum of four terms: the negative distance to the goal,
    the negative squared action magnitude, an indicator for being within
    ``GOAL_REACHED_THRESHOLD`` of the goal, and a shaping term
    ``shaping_discount * d(next_state) - d(state)``, where ``d`` is the distance
    to the goal.

    The last three entries of each state are read as the fingertip-to-target
    displacement, which is how ``CustomReacherEnv`` lays out its "sim" observation.

    Args:
        obs_space: The observation space used in the environment.
        act_space: The action space used in the environment.
        reward_dist_factor: Weight on the distance from goal reward term.
        reward_ctrl_factor: Weight on the control reward term.
        reward_goal_factor: Weight on reaching the goal.
        shaping_factor: The value to scale the shaping.
        shaping_discount: The discount factor used in potential shaping.
    """
    # At this threshold around 2% of initial states are next to the goal.
    GOAL_REACHED_THRESHOLD = 0.05

    def __init__(
            self,
            obs_space: gym.spaces.Space,
            act_space: gym.spaces.Space,
            reward_dist_factor: float,
            reward_ctrl_factor: float,
            reward_goal_factor: float,
            shaping_factor: float,
            shaping_discount: float,
    ):
        self.obs_space = obs_space
        self.act_space = act_space
        self.reward_dist_factor = reward_dist_factor
        self.reward_ctrl_factor = reward_ctrl_factor
        self.reward_goal_factor = reward_goal_factor
        self.shaping_factor = shaping_factor
        self.shaping_discount = shaping_discount

    @property
    def observation_space(self) -> gym.spaces.Space:
        """The observation space passed at construction."""
        return self.obs_space

    @property
    def action_space(self) -> gym.spaces.Space:
        """The action space passed at construction."""
        return self.act_space

    def reward(
            self,
            states: torch.Tensor,
            actions: torch.Tensor,
            next_states: Optional[torch.Tensor],
            terminals: Optional[torch.Tensor],
    ) -> torch.Tensor:
        """Computes the reward for the environment.

        See base class for documentation on args and return value.

        ``states``, ``actions`` and ``next_states`` are indexed as
        ``(batch, features)``; the result has shape ``(batch, 1)``.
        ``terminals`` is unused. ``next_states`` may be None only when
        ``shaping_factor`` is zero.

        Raises:
            ValueError: If ``next_states`` is None while ``shaping_factor`` is nonzero.
        """
        del terminals
        states_dists = states[:, -3:].norm(dim=-1, keepdim=True)
        dist_rewards = -states_dists
        ctrl_rewards = -actions.square().sum(dim=1, keepdim=True).to(states.dtype)
        # Cast the indicator so this term is computed in the same dtype as the others.
        goal_rewards = (states_dists < self.GOAL_REACHED_THRESHOLD).to(states.dtype)

        rewards = self.reward_dist_factor * dist_rewards \
            + self.reward_ctrl_factor * ctrl_rewards \
            + self.reward_goal_factor * goal_rewards

        if next_states is None:
            if self.shaping_factor != 0:
                raise ValueError("next_states is required when shaping_factor is nonzero.")
            # Shaping is disabled, so there is nothing left to add.
            return rewards

        next_states_dists = next_states[:, -3:].norm(dim=-1, keepdim=True)
        shaping_rewards = self.shaping_discount * next_states_dists - states_dists

        return rewards + self.shaping_factor * shaping_rewards
