import numpy as np
from gym import core, spaces


class DMCGymWrapper(core.Env):
    """Gym-style adapter around a dm_control-style environment.

    Any attribute not defined on this class is looked up on the wrapped
    ``env`` (see ``__getattr__``). In particular ``_from_pixels``,
    ``_height``, ``_width``, ``_camera_id``, ``_channels_first``,
    ``_frame_skip``, ``_domain``, ``physics``, ``dt``,
    ``_true_action_space``, ``_norm_action_space``, ``_observation_space``
    and ``_state_space`` are all read from ``env``.
    NOTE(review): presumably ``env`` is a dmc2gym-style wrapper whose
    ``step``/``reset`` return dm_control ``TimeStep`` objects -- confirm.
    """

    def __init__(
        self,
        env,
    ):
        self._env = env

    def __getattr__(self, name):
        """Forward unknown attribute lookups to the wrapped environment."""
        # ``__getattr__`` only runs when normal lookup fails. If ``_env``
        # itself is missing (e.g. on an instance being unpickled or copied
        # before ``__init__`` ran), ``self._env`` below would re-enter this
        # method forever; fail with a plain AttributeError instead.
        if name == "_env":
            raise AttributeError(name)
        return getattr(self._env, name)

    def _get_obs(self, time_step):
        """Return a rendered frame (pixel mode) or ``time_step.observation``."""
        if self._from_pixels:
            obs = self.render(
                height=self._height, width=self._width, camera_id=self._camera_id
            )
            if self._channels_first:
                # HWC -> CHW; copy so the result is contiguous.
                obs = obs.transpose(2, 0, 1).copy()
        else:
            obs = time_step.observation
        return obs

    def _convert_action(self, action):
        """Affinely map ``action`` from the normalized to the true action box."""
        action = action.astype(np.float64)
        true_delta = self._true_action_space.high - self._true_action_space.low
        norm_delta = self._norm_action_space.high - self._norm_action_space.low
        action = (action - self._norm_action_space.low) / norm_delta
        action = action * true_delta + self._true_action_space.low
        action = action.astype(np.float32)
        return action

    @property
    def observation_space(self):
        return self._observation_space

    @property
    def state_space(self):
        return self._state_space

    @property
    def action_space(self):
        # Agents act in the normalized space; ``step`` converts to the true one.
        return self._norm_action_space

    @property
    def reward_range(self):
        return 0, self._frame_skip

    def seed(self, seed):
        """Seed the action and observation spaces' samplers."""
        self._true_action_space.seed(seed)
        self._norm_action_space.seed(seed)
        self._observation_space.seed(seed)

    def _torso_xyz(self):
        """Return a copy of the torso geom's x/y/z entries from ``geom_xpos``."""
        # The dog model names its torso geom "collision_torso"; the other
        # domains use "torso".
        geom_name = "collision_torso" if "dog" in self._domain else "torso"
        return self.physics.named.data.geom_xpos[
            [geom_name], ["x", "y", "z"]
        ].copy()

    def step(self, action, render=False):
        """Apply ``action`` for up to ``_frame_skip`` underlying env steps.

        Args:
            action: action in the normalized action space.
            render: if True, also store a 64x64 frame, transposed with
                ``(2, 0, 1)``, under ``extra["render"]``.

        Returns:
            ``(obs, reward, done, extra)``. ``reward`` is the sum of the
            per-step rewards (``None`` counted as 0). ``extra`` holds the
            physics state before (``internal_state``/``ori_obs``) and after
            (``next_ori_obs``) the step, the last ``discount``, and for the
            cheetah/quadruped/humanoid/dog domains the torso planar
            ``coordinates``/``next_coordinates`` (y is fixed to 0 for cheetah).
        """
        assert self._norm_action_space.contains(action)
        action = self._convert_action(action)
        assert self._true_action_space.contains(action)
        reward = 0
        extra = {"internal_state": self._env.physics.get_state().copy()}
        xyz_before = self._torso_xyz()
        obsbefore = self.physics.get_state()

        for _ in range(self._frame_skip):
            time_step = self._env.step(action)
            reward += time_step.reward or 0
            done = time_step.last()
            if done:
                break

        xyz_after = self._torso_xyz()

        obs = self._get_obs(time_step)
        self.current_state = time_step.observation
        obsafter = self.physics.get_state()
        extra["discount"] = time_step.discount

        if render:
            extra["render"] = self.render(
                mode="rgb_array", width=64, height=64
            ).transpose(2, 0, 1)

        if self._domain in ["cheetah"]:
            extra["coordinates"] = np.array([xyz_before[0], 0.0])
            extra["next_coordinates"] = np.array([xyz_after[0], 0.0])
        elif self._domain in ["quadruped", "humanoid", "dog"]:
            extra["coordinates"] = np.array([xyz_before[0], xyz_before[1]])
            extra["next_coordinates"] = np.array([xyz_after[0], xyz_after[1]])
        extra["ori_obs"] = obsbefore
        extra["next_ori_obs"] = obsafter

        return obs, reward, done, extra

    def compute_reward(self, ob, next_ob, action=None):
        """Finite-difference velocity of column 0 between two observation batches.

        NOTE(review): presumably column 0 is the x position -- confirm.

        Returns:
            ``(reward, done)`` where ``done`` is all zeros.
        """
        xposbefore = ob[:, 0]
        xposafter = next_ob[:, 0]

        reward = (xposafter - xposbefore) / self.dt
        done = np.zeros_like(reward)

        return reward, done

    def reset(self):
        """Reset the wrapped env and return the initial observation."""
        time_step = self._env.reset()
        self.current_state = time_step.observation
        obs = self._get_obs(time_step)
        return obs

    def render(self, mode="rgb_array", height=None, width=None, camera_id=None):
        """Render an RGB array from the wrapped physics.

        ``height``/``width`` fall back to the env defaults when falsy;
        ``camera_id`` falls back to ``self._camera_id`` only when ``None``, so
        camera 0 can be requested explicitly.
        """
        assert mode == "rgb_array", "only support rgb_array mode, given %s" % mode
        height = height or self._height
        width = width or self._width
        if camera_id is None:
            camera_id = self._camera_id
        return self._env.physics.render(height=height, width=width, camera_id=camera_id)

    def plot_trajectory(self, trajectory, color, ax):
        """Draw one ``(T, 2)`` coordinate trajectory on ``ax``."""
        if self._domain in ["cheetah"]:
            trajectory = trajectory.copy()
            # Line whose width grows along the trajectory:
            # https://stackoverflow.com/a/20474765/2182622
            from matplotlib.collections import LineCollection

            linewidths = np.linspace(0.2, 1.2, len(trajectory))
            points = np.reshape(trajectory, (-1, 1, 2))
            segments = np.concatenate([points[:-1], points[1:]], axis=1)
            lc = LineCollection(segments, linewidths=linewidths, color=color)
            ax.add_collection(lc)
        else:
            ax.plot(trajectory[:, 0], trajectory[:, 1], color=color, linewidth=0.7)

    def plot_trajectories(self, trajectories, colors, plot_axis, ax):
        """Plot trajectories onto ``ax`` and set its limits.

        Args:
            plot_axis: ``"free"`` leaves the limits untouched; ``None`` uses a
                square window 1.2x the largest absolute coordinate; anything
                else is passed to ``ax.axis`` unchanged.
        """
        square_axis_limit = 0.0
        for trajectory, color in zip(trajectories, colors):
            trajectory = np.array(trajectory)
            self.plot_trajectory(trajectory, color, ax)

            square_axis_limit = max(
                square_axis_limit, np.max(np.abs(trajectory[:, :2]))
            )
        square_axis_limit = square_axis_limit * 1.2

        if isinstance(plot_axis, str) and plot_axis == "free":
            return

        if plot_axis is None:
            # Previously this indexed ``plot_axis`` while it was None
            # (TypeError); use the symmetric square window instead.
            plot_axis = [
                -square_axis_limit,
                square_axis_limit,
                -square_axis_limit,
                square_axis_limit,
            ]

        ax.axis(plot_axis)
        ax.set_aspect("equal")

    def render_trajectories(self, trajectories, colors, plot_axis, ax):
        """Extract coordinate trajectories from rollouts and plot them."""
        coordinates_trajectories = self._get_coordinates_trajectories(trajectories)
        self.plot_trajectories(coordinates_trajectories, colors, plot_axis, ax)

    def _get_coordinates_trajectories(self, trajectories):
        """Return per-rollout coordinates with the final next-coordinate appended.

        For cheetah (x only), each rollout's second column is overwritten with
        a distinct constant so the rollouts are stacked vertically in plots.
        """
        coordinates_trajectories = []
        for trajectory in trajectories:
            coordinates_trajectories.append(
                np.concatenate(
                    [
                        trajectory["env_infos"]["coordinates"],
                        [trajectory["env_infos"]["next_coordinates"][-1]],
                    ]
                )
            )
        if self._domain in ["cheetah"]:
            for i, traj in enumerate(coordinates_trajectories):
                traj[:, 1] = (i - len(coordinates_trajectories) / 2) / 1.25
        return coordinates_trajectories

    def calc_eval_metrics(self, trajectories, is_option_trajectories=False):
        """Coverage metrics over the torso coordinates visited by ``trajectories``.

        Coordinates are floored to unit cells and the number of distinct cells
        is reported as ``MjNumUniqueCoords``.
        """
        eval_metrics = {}

        # ``step`` records (x, y) for these domains; for cheetah the second
        # column is a constant 0, so only x is meaningful.
        coord_dim = 2 if self._domain in ["quadruped", "humanoid", "dog"] else 1

        coords = []
        for traj in trajectories:
            traj1 = traj["env_infos"]["coordinates"][:, :coord_dim]
            traj2 = traj["env_infos"]["next_coordinates"][-1:, :coord_dim]
            coords.append(traj1)
            coords.append(traj2)
        coords = np.concatenate(coords, axis=0)
        uniq_coords = np.unique(np.floor(coords), axis=0)
        eval_metrics.update(
            {
                "MjNumTrajs": len(trajectories),
                "MjAvgTrajLen": len(coords) / len(trajectories) - 1,
                "MjNumCoords": len(coords),
                "MjNumUniqueCoords": len(uniq_coords),
            }
        )

        return eval_metrics


class GeneralEnvWrapper:
    """Mixin adding trajectory bookkeeping, plotting and kitchen metrics.

    Intended to be combined (via ``super()``) with an environment class whose
    ``reset``/``step`` return dict states containing a ``"state"`` entry (and
    ``"image"`` when rendering). The combined class must provide
    ``get_state(state)`` to turn such a dict into the agent observation.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.reward_range = (-np.inf, np.inf)
        self.metadata = {}

    def reset(self):
        """Reset the underlying env and remember the raw state and observation."""
        state = super().reset()
        ob = self.get_state(state)
        self.last_state = state
        self.last_ob = ob
        return ob

    def step(self, action, render=False):
        """Step the underlying env and add coordinate/state info.

        ``info`` gains ``coordinates``/``next_coordinates`` (the first two
        entries of ``state["state"]`` before/after the step),
        ``ori_obs``/``next_ori_obs``, and ``render`` (the next state's image,
        transposed with ``(2, 0, 1)``) when ``render`` is True.
        """
        next_state, reward, done, info = super().step(action)
        ob = self.get_state(next_state)

        coords = self.last_state["state"][:2].copy()
        next_coords = next_state["state"][:2].copy()
        info["coordinates"] = coords
        info["next_coordinates"] = next_coords
        info["ori_obs"] = self.last_state["state"]
        info["next_ori_obs"] = next_state["state"]
        if render:
            info["render"] = next_state["image"].transpose(2, 0, 1)

        self.last_state = next_state
        self.last_ob = ob

        return ob, reward, done, info

    def plot_trajectory(self, trajectory, color, ax):
        """Draw one ``(T, 2)`` coordinate trajectory on ``ax``."""
        ax.plot(trajectory[:, 0], trajectory[:, 1], color=color, linewidth=0.7)

    def plot_trajectories(self, trajectories, colors, plot_axis, ax):
        """Plot trajectories onto ``ax`` and set its limits.

        Args:
            plot_axis: ``"free"`` leaves the limits untouched; ``None`` uses a
                square window 1.2x the largest absolute coordinate; anything
                else is passed to ``ax.axis`` unchanged.
        """
        square_axis_limit = 0.0
        for trajectory, color in zip(trajectories, colors):
            trajectory = np.array(trajectory)
            self.plot_trajectory(trajectory, color, ax)

            square_axis_limit = max(
                square_axis_limit, np.max(np.abs(trajectory[:, :2]))
            )
        square_axis_limit = square_axis_limit * 1.2

        # Compare to the sentinel only when it is a string, so an array-like
        # ``plot_axis`` is never compared elementwise against "free".
        if isinstance(plot_axis, str) and plot_axis == "free":
            return

        if plot_axis is None:
            plot_axis = [
                -square_axis_limit,
                square_axis_limit,
                -square_axis_limit,
                square_axis_limit,
            ]

        ax.axis(plot_axis)
        ax.set_aspect("equal")

    def render_trajectories(self, trajectories, colors, plot_axis, ax):
        """Extract coordinate trajectories from rollouts and plot them."""
        coordinates_trajectories = self._get_coordinates_trajectories(trajectories)
        self.plot_trajectories(coordinates_trajectories, colors, plot_axis, ax)

    def _get_coordinates_trajectories(self, trajectories):
        """Return per-rollout ``(T + 1, 2)`` coordinate arrays.

        Each result is the recorded ``coordinates`` followed by the final
        ``next_coordinates`` point. Three layouts are accepted: object arrays
        (presumably ragged per-step chunks of points), 2-D arrays, and >2-D
        arrays that are flattened to pairs.

        Raises:
            ValueError: if ``coordinates`` has an unsupported layout.
        """
        coordinates_trajectories = []
        for trajectory in trajectories:
            env_infos = trajectory["env_infos"]
            coords = env_infos["coordinates"]
            next_coords = env_infos["next_coordinates"]
            if coords.dtype == object:
                points = np.concatenate(
                    [np.concatenate(coords, axis=0), [next_coords[-1][-1]]]
                )
            elif coords.ndim == 2:
                points = np.concatenate([coords, [next_coords[-1]]])
            elif coords.ndim > 2:
                points = np.concatenate(
                    [coords.reshape(-1, 2), next_coords.reshape(-1, 2)[-1:]]
                )
            else:
                # Explicit raise instead of ``assert False``, which is removed
                # under ``python -O`` and would silently drop the trajectory.
                raise ValueError(
                    f"Unsupported coordinates layout: ndim={coords.ndim}, "
                    f"dtype={coords.dtype}"
                )
            coordinates_trajectories.append(points)
        return coordinates_trajectories

    def calc_eval_metrics(self, trajectories, is_option_trajectories, coord_dims=None):
        """Per-goal success (max over trajectories) for the six kitchen goals.

        Reads ``env_infos["metric_success_task_relevant/goal_{i}"]`` for
        ``i`` in 0..5 and reports ``KitchenTask<Goal>`` per goal plus their
        sum as ``KitchenOverall``. ``is_option_trajectories`` and
        ``coord_dims`` are accepted for interface compatibility and unused.
        """
        eval_metrics = {}

        goal_names = [
            "BottomBurner",
            "LightSwitch",
            "SlideCabinet",
            "HingeCabinet",
            "Microwave",
            "Kettle",
        ]

        sum_successes = 0
        for i, goal_name in enumerate(goal_names):
            success = 0
            for traj in trajectories:
                success = max(
                    success,
                    traj["env_infos"][f"metric_success_task_relevant/goal_{i}"].max(),
                )
            eval_metrics[f"KitchenTask{goal_name}"] = success
            sum_successes += success
        eval_metrics["KitchenOverall"] = sum_successes

        return eval_metrics


from gym.wrappers import FilterObservation, FlattenObservation
import gym


class FetchPushEnv(GeneralEnvWrapper):
    """Wrapper around gym's ``FetchPush-v1`` task with a 200-step time limit.

    ``reset``/``step`` return only the ``"observation"`` entry of the
    goal-env dict; the goal is kept in ``self.desired_goal`` (see
    ``reconstruct_desired_goal``). Plotting helpers are inherited from
    ``GeneralEnvWrapper``.
    """

    def __init__(self, *args, **kwargs):
        # ``GeneralEnvWrapper.__init__`` is not called: it assigns
        # ``self.reward_range``, which is a read-only property on this class.
        # https://robotics.farama.org/envs/fetch/push/
        env = gym.make("FetchPush-v1")
        # Indices into the flat observation used as plot coordinates;
        # presumably the block's x/y position per the Fetch observation
        # layout linked above -- confirm.
        self.block_pos = [3, 4]

        from gym.wrappers.time_limit import TimeLimit

        max_episode_steps = 200
        self.env = TimeLimit(env, max_episode_steps=max_episode_steps)
        self.desired_goal = None

        # Replace the dict observation space with a flat 25-dim Box, since
        # only the "observation" entry is returned to the agent.
        # NOTE(review): assumes that entry has 25 elements -- confirm.
        self.env.observation_space = gym.spaces.Box(
            low=np.full((25,), -np.inf), high=np.full((25,), np.inf), dtype=np.float32
        )

    @property
    def metadata(self):
        return self.env.metadata

    @property
    def observation_space(self):
        return self.env.observation_space

    @property
    def action_space(self):
        return self.env.action_space

    @property
    def reward_range(self):
        return self.env.reward_range

    def reconstruct_desired_goal(self, desired_goal, base_state):
        """Build an observation-shaped goal vector.

        Returns an array shaped like ``base_state`` that is zero except for
        entries 0:3 and 3:6, which are both set to ``desired_goal``.
        """
        g = np.zeros_like(base_state)
        g[:3] = desired_goal
        g[3:6] = desired_goal

        return g.copy()

    def reset(self):
        """Reset the env, refresh ``self.desired_goal`` and return the observation."""
        ob = self.env.reset()
        self.desired_goal = self.reconstruct_desired_goal(
            ob["desired_goal"], ob["observation"]
        )
        return ob["observation"]

    def step(self, action, render=False):
        """Step the env and return ``(observation, reward, done, info)``.

        The underlying env's info dict is replaced by a fresh one holding
        ``coordinates`` (``observation[self.block_pos]`` after the step),
        ``success`` (the env's ``is_success``) and, if ``render`` is True,
        an RGB frame transposed with ``(2, 0, 1)`` under ``render``.
        """
        next_state, reward, done, info = self.env.step(action)
        observation = next_state["observation"]
        desired_goal = next_state["desired_goal"]

        success = info["is_success"]
        info = {}

        info["coordinates"] = observation[self.block_pos]
        if render:
            rendered = self.env.render(mode="rgb_array")
            rendered = rendered.transpose(2, 0, 1)
            info["render"] = rendered
        info["success"] = success

        self.desired_goal = self.reconstruct_desired_goal(desired_goal, observation)
        return observation, reward, done, info

    def render(self, *args, **kwargs):
        return self.env.render(*args, **kwargs)

    def _get_coordinates_trajectories(self, trajectories):
        """Return a copy of each rollout's recorded ``coordinates`` array.

        ``step`` records no ``next_coordinates``, so unlike the base class no
        final point is appended.

        Raises:
            ValueError: if a rollout's coordinates are not 2-D.
        """
        coordinates_trajectories = []
        for trajectory in trajectories:
            coords = trajectory["env_infos"]["coordinates"]
            if coords.ndim != 2:
                raise ValueError(
                    f"Expected 2-D coordinates, got ndim={coords.ndim}"
                )
            coordinates_trajectories.append(np.concatenate([coords]))
        return coordinates_trajectories

    def calc_eval_metrics(self, trajectories, is_option_trajectories, coord_dims=None):
        """No evaluation metrics are defined for FetchPush; returns ``{}``."""
        return {}


import matplotlib.pyplot as plt

if __name__ == "__main__":
    # Smoke test: build an environment, reset it and save one rendered frame.
    # Uses FetchPushEnv, the concrete environment defined in this module
    # (no ``LifelongHopperEnv`` is defined or imported here).
    env = FetchPushEnv()
    env.reset()
    img = env.render(mode="rgb_array")

    plt.imshow(img)
    # Save before showing: after a blocking ``plt.show()`` returns, the
    # figure is closed, so a later ``savefig`` would write an empty image.
    plt.savefig("render.png")
    plt.show()
