#!/usr/bin/python
#
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import glob
import pickle
import time as timer

import adept_envs
import click
import gym
import numpy as np
import skvideo.io
from mjrl.utils.gym_env import GymEnv
from parse_mjl import parse_mjl_logs, viz_parsed_mjl_logs

# headless renderer
render_buffer = []  # rendering buffer


def viewer(
    env,
    mode="initialize",
    filename="video",
    frame_size=(640, 480),
    camera_id=0,
    render=None,
):
    if render == "onscreen":
        env.mj_render()

    elif render == "offscreen":

        global render_buffer
        if mode == "initialize":
            render_buffer = []
            mode = "render"

        if mode == "render":
            curr_frame = env.render(mode="rgb_array")
            render_buffer.append(curr_frame)

        if mode == "save":
            skvideo.io.vwrite(filename, np.asarray(render_buffer))
            print("\noffscreen buffer saved", filename)

    elif render == "None":
        pass

    else:
        print("unknown render: ", render)


# view demos (physics ignored)
def render_demos(env, data, filename="demo_rendering.mp4", render=None):
    FPS = 30
    render_skip = max(
        1, round(1.0 / (FPS * env.sim.model.opt.timestep * env.frame_skip))
    )
    t0 = timer.time()

    viewer(env, mode="initialize", render=render)
    for i_frame in range(data["ctrl"].shape[0]):
        env.sim.data.qpos[:] = data["qpos"][i_frame].copy()
        env.sim.data.qvel[:] = data["qvel"][i_frame].copy()
        env.sim.forward()
        if i_frame % render_skip == 0:
            viewer(env, mode="render", render=render)
            print(i_frame, end=", ", flush=True)

    viewer(env, mode="save", filename=filename, render=render)
    print("time taken = %f" % (timer.time() - t0))


# playback demos and get data(physics respected)
def gather_training_data(env, data, filename="demo_playback.mp4", render=None):
    env = env.env
    FPS = 30
    render_skip = max(
        1, round(1.0 / (FPS * env.sim.model.opt.timestep * env.frame_skip))
    )
    t0 = timer.time()

    # initialize
    env.reset()
    init_qpos = data["qpos"][0].copy()
    init_qvel = data["qvel"][0].copy()
    act_mid = env.act_mid
    act_rng = env.act_amp

    # prepare env
    env.sim.data.qpos[:] = init_qpos
    env.sim.data.qvel[:] = init_qvel
    env.sim.forward()
    viewer(env, mode="initialize", render=render)

    # step the env and gather data
    path_obs = None
    for i_frame in range(data["ctrl"].shape[0] - 1):
        # Reset every time step
        # if i_frame % 1 == 0:
        #     qp = data['qpos'][i_frame].copy()
        #     qv = data['qvel'][i_frame].copy()
        #     env.sim.data.qpos[:] = qp
        #     env.sim.data.qvel[:] = qv
        #     env.sim.forward()

        obs = env._get_obs()

        # Construct the action
        # ctrl = (data['qpos'][i_frame + 1][:9] - obs[:9]) / (env.skip * env.model.opt.timestep)
        ctrl = (data["ctrl"][i_frame] - obs[:9]) / (env.skip * env.model.opt.timestep)
        act = (ctrl - act_mid) / act_rng
        act = np.clip(act, -0.999, 0.999)
        next_obs, reward, done, env_info = env.step(act)
        if path_obs is None:
            path_obs = obs
            path_act = act
        else:
            path_obs = np.vstack((path_obs, obs))
            path_act = np.vstack((path_act, act))

        # render when needed to maintain FPS
        if i_frame % render_skip == 0:
            viewer(env, mode="render", render=render)
            print(i_frame, end=", ", flush=True)

    # finalize
    if render:
        viewer(env, mode="save", filename=filename, render=render)

    t1 = timer.time()
    print("time taken = %f" % (t1 - t0))

    # note that <init_qpos, init_qvel> are one step away from <path_obs[0], path_act[0]>
    return path_obs, path_act, init_qpos, init_qvel


# MAIN =========================================================
@click.command(help="parse tele-op demos")
@click.option("--env", "-e", type=str, help="gym env name", required=True)
@click.option(
    "--demo_dir", "-d", type=str, help="directory with tele-op logs", required=True
)
@click.option(
    "--skip", "-s", type=int, help="number of frames to skip (1:no skip)", default=1
)
@click.option("--graph", "-g", type=bool, help="plot logs", default=False)
@click.option("--save_logs", "-l", type=bool, help="save logs", default=False)
@click.option("--view", "-v", type=str, help="render/playback", default="render")
@click.option("--render", "-r", type=str, help="onscreen/offscreen", default="onscreen")
def main(env, demo_dir, skip, graph, save_logs, view, render):

    gym_env = gym.make(env)
    paths = []
    print("Scanning demo_dir: " + demo_dir + "=========")
    for ind, file in enumerate(glob.glob(demo_dir + "*.mjl")):

        # process logs
        print("processing: " + file, end=": ")

        data = parse_mjl_logs(file, skip)

        print("log duration %0.2f" % (data["time"][-1] - data["time"][0]))

        # plot logs
        if graph:
            print("plotting: " + file)
            viz_parsed_mjl_logs(data)

        # save logs
        if save_logs:
            pickle.dump(data, open(file[:-4] + ".pkl", "wb"))

        # render logs to video
        if view == "render":
            render_demos(
                gym_env,
                data,
                filename=data["logName"][:-4] + "_demo_render.mp4",
                render=render,
            )

        # playback logs and gather data
        elif view == "playback":
            try:
                obs, act, init_qpos, init_qvel = gather_training_data(
                    gym_env,
                    data,
                    filename=data["logName"][:-4] + "_playback.mp4",
                    render=render,
                )
            except Exception as e:
                print(e)
                continue
            path = {
                "observations": obs,
                "actions": act,
                "goals": obs,
                "init_qpos": init_qpos,
                "init_qvel": init_qvel,
            }
            paths.append(path)
            # accept = input('accept demo?')
            # if accept == 'n':
            #     continue
            pickle.dump(path, open(demo_dir + env + str(ind) + "_path.pkl", "wb"))
            print(demo_dir + env + file + "_path.pkl")


if __name__ == "__main__":
    main()
