# based on https://github.com/jannerm/trajectory-transformer/blob/master/trajectory/utils/rendering.py

from pathlib import Path

import mujoco_py as mjc
import numpy as np
import skvideo.io
import torch


def view(viewer, dim=256, render_kwargs=None):
    """Configure the viewer's camera, render one frame and return its pixels.

    Args:
        viewer: Render context exposing ``cam``, ``render(*dim)`` and
            ``read_pixels(*dim, depth=False)``.
        dim: Frame size. A single integer (Python ``int`` or NumPy integer
            scalar) is expanded to ``(dim, dim)``; any other value is
            unpacked as-is into the viewer calls.
        render_kwargs: Mapping of camera attribute names to values. The
            ``"lookat"`` entry is copied element-wise into
            ``viewer.cam.lookat``; every other entry is assigned with
            ``setattr``. When ``None``, a fixed default camera is used.

    Returns:
        The array returned by ``viewer.read_pixels`` with its first axis
        reversed.
    """
    if render_kwargs is None:
        render_kwargs = {
            "trackbodyid": 2,
            "distance": 1,
            "lookat": [1, -0.5, 1],
            "elevation": -15,
            "azimuth": 135,
        }
    for key, val in render_kwargs.items():
        if key == "lookat":
            # lookat is written in place rather than rebound on the camera
            viewer.cam.lookat[:] = val[:]
        else:
            setattr(viewer.cam, key, val)

    # isinstance also accepts NumPy integer scalars (e.g. a size computed with
    # NumPy), which an exact `type(dim) == int` check would let fall through
    # to the `*dim` unpacking below and crash.
    if isinstance(dim, (int, np.integer)):
        dim = (int(dim), int(dim))

    viewer.render(*dim)
    data = viewer.read_pixels(*dim, depth=False)
    # reverse the first axis (presumably the frame is read bottom row first)
    return data[::-1, :, :]


def save_video(filename, video_frames, fps=60, video_format="mp4"):
    """Write ``video_frames`` to ``filename`` with ``skvideo.io.vwrite``.

    Args:
        filename: Output path; missing parent directories are created.
        video_frames: Frames handed unchanged to ``skvideo.io.vwrite``.
        fps: Frame rate; must equal its own integer value.
        video_format: Value passed as the ``-f`` output option.
    """
    assert fps == int(fps), fps
    target = Path(filename)
    target.parent.mkdir(parents=True, exist_ok=True)

    input_options = {"-r": str(int(fps))}
    output_options = {
        "-f": video_format,
        # '-pix_fmt=yuv420p' needed for osx https://github.com/scikit-video/scikit-video/issues/74
        "-pix_fmt": "yuv420p",
    }
    skvideo.io.vwrite(
        target,
        video_frames,
        inputdict=input_options,
        outputdict=output_options,
    )


def save_videos(filename, *video_frames, **kwargs):
    """Concatenate several videos along axis 2 and save them as one file.

    Args:
        filename: Output path forwarded to ``save_video``.
        *video_frames: Videos to join, each laid out as [ N x H x W x C ].
        **kwargs: Extra keyword arguments forwarded to ``save_video``.
    """
    # with the [ N x H x W x C ] layout, axis 2 is W
    combined = np.concatenate(video_frames, axis=2)
    save_video(filename, combined, **kwargs)


def to_np(x, dtype=None):
    """Convert a torch tensor to a NumPy array, optionally casting the result.

    Args:
        x: A torch tensor (detached and moved to the CPU before conversion)
            or any other object, which is passed through untouched.
        dtype: When given, the value is rebuilt with ``np.array(..., dtype=dtype)``.

    Returns:
        The converted value; non-tensor input with ``dtype=None`` is returned
        as-is.
    """
    value = x.detach().cpu().numpy() if torch.is_tensor(x) else x
    if dtype is None:
        return value
    return np.array(value, dtype=dtype)


def set_state(env, state):
    """Pad ``state`` to a full (qpos, qvel) vector and load it into ``env``.

    The padding steps run in this order:

    1. If ``"ant"`` is in ``env.name``, one zero is prepended.
    2. If the state is exactly one entry short of qpos or of qpos + qvel,
       one zero is prepended (presumably the omitted root x position).
    3. If the state now has exactly qpos entries, zero velocities are appended.
    4. For ant environments whose state is still longer than qpos + qvel,
       one more zero is prepended and the result is cut to qpos + qvel.

    Args:
        env: Object exposing ``name``, ``sim.data.qpos``, ``sim.data.qvel``
            and ``set_state(qpos, qvel)``.
        state: 1-D NumPy array holding the (possibly partial) state.

    Raises:
        AssertionError: If the padded state does not have qpos + qvel entries.
    """
    n_pos = env.sim.data.qpos.size
    n_vel = env.sim.data.qvel.size
    n_full = n_pos + n_vel
    is_ant = "ant" in env.name

    if is_ant:
        state = np.concatenate([np.zeros(1), state])

    if state.size in (n_pos - 1, n_full - 1):
        state = np.concatenate([np.zeros(1), state])

    if state.size == n_pos:
        # positions only: fill in zero velocities
        state = np.concatenate([state, np.zeros(n_vel)])

    if is_ant and state.size > n_full:
        state = np.concatenate([np.zeros(1), state])[:n_full]

    assert state.size == n_full

    env.set_state(state[:n_pos], state[n_pos:])


def rollout_from_state(env, state, actions):
    """Reset ``env`` to ``state``, apply ``actions`` and collect observations.

    Args:
        env: Object exposing ``sim.data.qpos``, ``set_state(qpos, qvel)``,
            ``_get_obs()`` and ``step(action)`` returning a 4-tuple.
        state: Full state vector; its first ``qpos.size`` entries are passed
            as positions and the rest as velocities.
        actions: Sequence of actions applied one after another.

    Returns:
        Stacked array with ``len(actions) + 1`` rows: the initial observation
        followed by one row per action. Stepping stops at the first
        terminal step and the remaining rows are zeros.
    """
    n_pos = env.sim.data.qpos.size
    env.set_state(state[:n_pos], state[n_pos:])

    trajectory = [env._get_obs()]
    for action in actions:
        obs, _reward, done, _info = env.step(action)
        trajectory.append(obs)
        if done:
            break

    # if terminated early, pad with zeros
    n_missing = len(actions) + 1 - len(trajectory)
    trajectory.extend(np.zeros(obs.size) for _ in range(n_missing))
    return np.stack(trajectory)


class DebugRenderer:
    """Renderer placeholder that draws nothing.

    Exposes the same ``render`` / ``render_plan`` / ``render_rollout`` method
    names as ``Renderer`` while ignoring every argument it receives.
    """

    def __init__(self, *args, **kwargs):
        """Accept and discard any constructor arguments."""

    def render(self, *args, **kwargs):
        """Ignore all arguments and return a 10 x 10 x 3 array of zeros."""
        blank_shape = (10, 10, 3)
        return np.zeros(blank_shape)

    def render_plan(self, *args, **kwargs):
        """Ignore all arguments and do nothing."""
        return None

    def render_rollout(self, *args, **kwargs):
        """Ignore all arguments and do nothing."""
        return None


class Renderer:
    """Offscreen MuJoCo renderer for environment states, plans and rollouts."""

    def __init__(self, env, observation_dim=None, action_dim=None):
        """Unwrap the environment and create an offscreen render context.

        Args:
            env: An environment object, or a string handed to
                ``load_environment``.
            observation_dim: Observation size; when falsy it is taken from
                the product of ``env.observation_space.shape``.
            action_dim: Action size; when falsy it is taken from the product
                of ``env.action_space.shape``.
        """
        if isinstance(env, str):
            # NOTE(review): `load_environment` is not defined or imported in
            # the visible part of this file — confirm it is in scope before
            # constructing a Renderer from a name.
            self.env = load_environment(env)
        else:
            self.env = env

        self.unwrap_env()
        self.observation_dim = observation_dim or np.prod(self.env.observation_space.shape)
        self.action_dim = action_dim or np.prod(self.env.action_space.shape)
        self.viewer = mjc.MjRenderContextOffscreen(self.env.sim)

    def __call__(self, *args, **kwargs):
        """Shorthand for ``renders``."""
        return self.renders(*args, **kwargs)

    def unwrap_env(self):
        """Replace ``self.env`` with its innermost ``.env`` and keep the name.

        The ``name`` read from the outermost object is re-attached to the
        innermost one, since ``set_state`` reads ``env.name``.
        """
        env = self.env
        name = env.name
        while hasattr(env, "env"):
            env = env.env
        self.env = env
        self.env.name = name

    def render(self, observation, dim=256, render_kwargs=None):
        """Load ``observation`` into the simulator and render one frame.

        Args:
            observation: State to render; torch tensors are converted with
                ``to_np`` and the result is passed to ``set_state``.
            dim: Frame size. A single integer (Python or NumPy) is expanded
                to ``(dim, dim)``; any other value is unpacked as-is.
            render_kwargs: Camera attributes to apply. ``"lookat"`` is copied
                element-wise; other keys are assigned with ``setattr``. When
                ``None``, a fixed default camera is used.

        Returns:
            The array read from the viewer with its first axis reversed.
        """
        observation = to_np(observation)

        if render_kwargs is None:
            render_kwargs = {
                "trackbodyid": 2,
                "distance": 3,
                "lookat": [0, -0.5, 1],
                "elevation": -20,
            }

        for key, val in render_kwargs.items():
            if key == "lookat":
                self.viewer.cam.lookat[:] = val[:]
            else:
                setattr(self.viewer.cam, key, val)

        set_state(self.env, observation)

        # isinstance also accepts NumPy integer scalars, which an exact type
        # comparison would let fall through to the `*dim` unpacking below.
        if isinstance(dim, (int, np.integer)):
            dim = (int(dim), int(dim))

        self.viewer.render(*dim)
        data = self.viewer.read_pixels(*dim, depth=False)
        # reverse the first axis (presumably the frame is read bottom row first)
        return data[::-1, :, :]

    def renders(self, observations, **kwargs):
        """Render every observation and stack the frames along a new axis 0.

        Args:
            observations: Iterable of states, each passed to ``render``.
            **kwargs: Forwarded to ``render`` for every frame.
        """
        images = [self.render(observation, **kwargs) for observation in observations]
        return np.stack(images, axis=0)

    def render_plan(self, savepath, sequence, state, fps=30):
        """
        Save a video of a planned sequence next to the rollout obtained by
        executing the plan's actions in the environment from `state`.

        state : np.array[ observation_dim ]
        sequence : np.array[ horizon x transition_dim ]
            as usual, sequence is ordered as [ s_t, a_t, r_t, V_t, ... ]

        Nothing is written when `sequence` has length 1.
        """

        if len(sequence) == 1:
            return

        sequence = to_np(sequence)

        # compare to ground truth rollout using actions from sequence
        actions = sequence[:-1, self.observation_dim : self.observation_dim + self.action_dim]
        rollout_states = rollout_from_state(self.env, state, actions)

        videos = [
            self.renders(sequence[:, : self.observation_dim]),
            self.renders(rollout_states),
        ]

        save_videos(savepath, *videos, fps=fps)

    def render_rollout(self, savepath, states, **video_kwargs):
        """Render ``states`` and save the frames as a video at ``savepath``.

        Args:
            savepath: Output path forwarded to ``save_video``.
            states: Iterable of states rendered with ``renders``.
            **video_kwargs: Forwarded to ``save_video``.
        """
        images = self(states)
        save_video(savepath, images, **video_kwargs)

    def renders_rollout_actions(self, state, actions, **kwargs):
        """Step the environment through ``actions`` and capture each frame.

        Args:
            state: Simulator state passed to ``env.sim.set_state`` after a reset.
            actions: Actions applied in order; stepping stops at the first
                terminal step, which is not rendered.
            **kwargs: Forwarded to ``view`` (e.g. ``dim``, ``render_kwargs``).

        Returns:
            Frames stacked along a new axis 0, starting with the initial state.
        """
        self.env.reset()
        self.env.sim.set_state(state)
        # `view` takes the render context as its first positional argument;
        # calling it with keyword arguments only raises TypeError.
        imgs = [view(self.viewer, **kwargs)]
        for act in actions:
            _, _, term, _ = self.env.step(act)
            if term:
                break
            imgs.append(view(self.viewer, **kwargs))

        return np.stack(imgs, axis=0)

    def render_rollout_actions(self, savepath, state, actions, **video_kwargs):
        """Render an action rollout from ``state`` and save it as a video.

        Args:
            savepath: Output path forwarded to ``save_video``.
            state: Simulator state the rollout starts from.
            actions: Actions applied in order.
            **video_kwargs: Forwarded to ``save_video``.
        """
        images = self.renders_rollout_actions(state, actions)
        save_video(savepath, images, **video_kwargs)
