#!/usr/bin/env python3
import fcntl
import os
import queue
import subprocess
import tempfile
import threading
import skimage.io
import minerl
import getpass
from pathlib import Path

import gym
from gym.wrappers import TimeLimit

from stable_baselines.common.policies import CnnPolicy
from stable_baselines.common.vec_env import SubprocVecEnv, VecFrameStack, VecEnv
from stable_baselines.common.vec_env.subproc_vec_env import _worker
from stable_baselines import PPO2
from stable_baselines.bench import Monitor

# import logging
# import coloredlogs
# coloredlogs.install(logging.INFO)


# Discrete action table used by ConvertActions: the agent picks an index into
# this list and the entry lists the named sub-actions applied together.
# ConvertActions sets a name to 1 when it is a Discrete key of the wrapped
# env's action space, and translates the "camera_*" names into a "camera"
# entry. The number of rows defines the size of the exposed Discrete space.
ACTION_MAPPING = [
    ["forward", "attack"],
    ["forward", "jump"],
    ["attack", "sneak"],
    ["attack"],
    ["attack", "camera_left"],
    ["attack", "camera_right"],
    ["forward"],
    ["forward", "camera_left"],
    ["forward", "camera_right"],
    ["camera_left"],
    ["camera_right"]
]


class QueuePipe:
    """Pipe-like endpoint assembled from two queue objects.

    Values passed to ``send`` are put on one queue and ``recv`` takes values
    from the other, so two instances built with the queues swapped form a
    bidirectional channel.
    """

    def __init__(self, send_q, req_q):
        """Store the outgoing queue (``send_q``) and the incoming one (``req_q``)."""
        self._req_q = req_q
        self._send_q = send_q

    def send(self, value):
        """Put ``value`` on the outgoing queue and return what ``put`` returns."""
        outcome = self._send_q.put(value)
        return outcome

    def recv(self):
        """Take the next value from the incoming queue and return it."""
        received = self._req_q.get()
        return received

    def close(self):
        """Do nothing; present so the object can be closed like a pipe end."""
        return None


class ThreadedVecEnv(SubprocVecEnv):
    """Vectorised environment that runs each sub-environment in a thread.

    Reuses the ``_worker`` loop imported from stable_baselines' SubprocVecEnv
    module, but starts it in daemon threads instead of processes and connects
    it through ``QueuePipe`` pairs instead of real pipes.
    """

    class _EnvFnHolder:
        """Carries an env factory as ``var``.

        ``_worker`` receives this object as its third argument; it is expected
        to call ``.var()`` to build the environment (this is what the previous
        ``env_fn.var = env_fn`` assignment provided). Using a holder avoids
        mutating the caller's callable, which also failed for callables that
        do not accept attribute assignment (e.g. bound methods).
        """

        def __init__(self, var):
            self.var = var

    def __init__(self, env_fns, start_method=None):
        """Start one worker thread per environment factory.

        :param env_fns: non-empty iterable of zero-argument callables, each
            returning one environment.
        :param start_method: unused; accepted so the signature matches the
            parent class.
        :raises ValueError: if ``env_fns`` is empty.
        """
        env_fns = list(env_fns)
        if not env_fns:
            raise ValueError("ThreadedVecEnv needs at least one environment factory")

        self.waiting = False
        self.closed = False

        # One pair of single-slot queues per environment: the main-side end
        # sends on the first queue and receives on the second, the worker-side
        # end does the opposite.
        remotes = []
        work_remotes = []
        for _ in env_fns:
            to_worker = queue.Queue(1)
            from_worker = queue.Queue(1)
            remotes.append(QueuePipe(to_worker, from_worker))
            work_remotes.append(QueuePipe(from_worker, to_worker))
        self.remotes = tuple(remotes)
        self.work_remotes = tuple(work_remotes)

        # Kept under the name ``processes`` although these are threads.
        # NOTE(review): presumably the inherited SubprocVecEnv methods look the
        # workers up under this name -- confirm against the installed version.
        self.processes = []
        for work_remote, remote, env_fn in zip(self.work_remotes, self.remotes, env_fns):
            args = (work_remote, remote, ThreadedVecEnv._EnvFnHolder(env_fn))
            # daemon=True: if the main process crashes, we should not cause things to hang
            process = threading.Thread(target=_worker, args=args, daemon=True)
            process.start()
            self.processes.append(process)

        # Ask the first worker for the spaces shared by all environments.
        self.remotes[0].send(('get_spaces', None))
        observation_space, action_space = self.remotes[0].recv()
        VecEnv.__init__(self, len(env_fns), observation_space, action_space)


class FrameSkip(gym.Wrapper):
    """Return every `skip`-th frame and repeat given action during skip.
    Note that this wrapper does not "maximize" over the skipped frames.

    Rewards of the skipped frames are summed; the observation, done flag and
    info of the last executed inner step are returned.
    """
    def __init__(self, env, skip=4):
        """
        :param env: environment to wrap.
        :param skip: number of inner steps per outer step; must be >= 1.
        :raises ValueError: if ``skip`` is smaller than 1.
        """
        # With skip < 1 the loop in step() would never run and its return
        # statement would fail on unbound locals, so reject it up front.
        if skip < 1:
            raise ValueError("skip must be at least 1, got {!r}".format(skip))
        super().__init__(env)

        self._skip = skip

    def step(self, action):
        """Apply ``action`` up to ``skip`` times, stopping early when done.

        :return: ``(obs, total_reward, done, info)`` where ``total_reward`` is
            the sum of the rewards of all inner steps that were executed.
        """
        total_reward = 0.0
        for _ in range(self._skip):
            obs, reward, done, info = self.env.step(action)
            total_reward += reward
            if done:
                # Do not step an environment that has already finished.
                break
        return obs, total_reward, done, info


class ConvertActions(gym.ActionWrapper):
    """Expose a Discrete action space on top of a dict-style action space.

    Each discrete index selects a row of ``ACTION_MAPPING``; ``action``
    converts that row into a dict action for the wrapped environment.
    """

    def __init__(self, env):
        super(ConvertActions, self).__init__(env)
        # Remember the original space: its Discrete keys define the base dict.
        self._old_action_scape = env.action_space
        self.action_space = gym.spaces.Discrete(len(ACTION_MAPPING))

    def action(self, action):
        """Translate a discrete index into a dict action.

        Every Discrete key of the original action space starts at 0; names in
        the selected row that match such a key are set to 1, and the
        ``camera_*`` names set the ``"camera"`` entry to a two-element list.
        """
        camera_moves = {
            "camera_left": (0, -1),
            "camera_right": (0, 1),
            "camera_up": (1, 0),
            "camera_down": (-1, 0),
        }

        converted = {}
        for key, space in self._old_action_scape.spaces.items():
            if isinstance(space, gym.spaces.Discrete):
                converted[key] = 0

        for name in ACTION_MAPPING[action]:
            if name in converted:
                converted[name] = 1
            move = camera_moves.get(name)
            if move is not None:
                # Fresh list per call so callers never share a mutable value.
                converted["camera"] = list(move)
        return converted


class RemoveNoop(gym.Wrapper):
    """Wrapper that deletes ``noop`` and ``noop_func`` from the env's action space.

    The attributes are removed once, at construction time, from the action
    space object of the environment being wrapped.
    """

    def __init__(self, env):
        super(RemoveNoop, self).__init__(env)
        for attr_name in ("noop", "noop_func"):
            delattr(env.action_space, attr_name)


class SelectPOVWrapper(gym.ObservationWrapper):
    """Reduce dict observations to their ``"pov"`` entry.

    The observation space is narrowed the same way, to the ``"pov"`` sub-space
    of the wrapped environment's observation space.
    """

    def __init__(self, env):
        super().__init__(env)
        pov_space = env.observation_space["pov"]
        self.observation_space = pov_space

    def observation(self, observation):
        """Return only the ``"pov"`` item of ``observation``."""
        pov = observation["pov"]
        return pov


def plot(run, n_steps=250):
    """Roll out the current policy and save the observed frames as PNG files.

    Uses the module-level ``env`` and ``model`` objects, so it can only be
    called after both have been created by the training script.

    :param run: identifier of the run; images go to ``imgs/<run>/``.
    :param n_steps: number of environment steps to record (default 250).
    """
    out_dir = "imgs/{}".format(run)
    # The target directory does not depend on the step, so create it once.
    os.makedirs(out_dir, exist_ok=True)

    obs = env.reset()
    for step in range(n_steps):
        action, _states = model.predict(obs)
        obs, rewards, dones, info = env.step(action)
        for env_idx, frames in enumerate(obs):
            # Only channels from index 9 onwards are saved.
            # NOTE(review): presumably the newest RGB frame of a 4-frame stack
            # of 3-channel images -- confirm against the VecFrameStack setup.
            skimage.io.imsave(
                "{}/{}_{:05d}_{}.png".format(out_dir, env_idx, step, rewards[env_idx]),
                frames[:, :, 9:])
    env.reset()


def gym_sync_create(env_string):
    """Create a gym environment while holding a per-user exclusive file lock.

    The lock file lives in the system temp directory and is named after the
    current user; an exclusive ``flock`` is held on it for the duration of
    ``gym.make`` and released afterwards, also when ``gym.make`` raises.

    :param env_string: environment id passed to ``gym.make``.
    :return: the environment returned by ``gym.make``.
    """
    # tempfile.tempdir is None until it is set explicitly or gettempdir() has
    # been called, so ask gettempdir() for the directory instead.
    lock_path = os.path.join(tempfile.gettempdir(), "minecraft.%s.lock" % getpass.getuser())
    with open(lock_path, "wb") as lock_file:
        fcntl.flock(lock_file.fileno(), fcntl.LOCK_EX)
        try:
            return gym.make(env_string)
        finally:
            fcntl.flock(lock_file.fileno(), fcntl.LOCK_UN)


def reprioritze_env(env):
    """Lower the priority of the Minecraft process behind ``env``.

    Sets niceness 19 on the process referenced by
    ``env.instance.minecraft_process`` and runs ``renice 19`` on every thread
    id of each of its (recursive) child processes.

    :param env: environment exposing ``instance.minecraft_process.pid``.
    :return: ``env`` itself, so the call can be nested in a wrapper chain.
    """
    import psutil

    renice_cmd = ["renice", "19"]
    minecraft = psutil.Process(env.instance.minecraft_process.pid)
    minecraft.nice(19)
    for descendant in minecraft.children(recursive=True):
        for thread in descendant.threads():
            subprocess.call(renice_cmd + [str(thread.id)])
    return env


# Training entry point: PPO2 on MineRLTreechop-v0 using thread-backed parallel environments.
if __name__ == "__main__":
    # number of parallel workers
    n_cpu = 4

    # number of experimental repeats
    n_repeats = 5

    def make_env(max_episode_steps):
        """Build one fully wrapped MineRLTreechop-v0 environment.

        The wrapper order matches the original one-line expression: noop
        removal, POV selection, discrete actions, frame skip, Monitor and
        finally the time limit.
        """
        raw_env = reprioritze_env(gym_sync_create('MineRLTreechop-v0'))
        wrapped = FrameSkip(ConvertActions(SelectPOVWrapper(RemoveNoop(raw_env))))
        monitored = Monitor(wrapped, None, allow_early_resets=True)
        return TimeLimit(monitored, max_episode_steps=max_episode_steps)

    for repeat in range(n_repeats):
        # NOTE(review): the original expression used `i` inside a lambda that
        # was created in a comprehension which also looped over `i`. Because
        # of late binding every environment saw the comprehension's final
        # value (n_cpu - 1) and got a 4000-step limit in every repeat. Scaling
        # the limit with the repeat index is assumed to be the intent -- confirm.
        max_episode_steps = 1000 * (repeat + 1)

        # init envs (the default argument binds the limit at creation time)
        env = ThreadedVecEnv(
            [lambda steps=max_episode_steps: make_env(steps) for _ in range(n_cpu)])
        env = VecFrameStack(env, n_stack=4)

        tensorboard_log = os.path.join(str(Path.home()), "test_run")

        if os.path.exists("trained.pkl"):
            # load pre-trained model; `load` returns a new model instead of
            # updating the instance it is called on, so its result must be kept
            model = PPO2.load("trained", env=env, verbose=1, tensorboard_log=tensorboard_log)
        else:
            # init PPO model
            model = PPO2(CnnPolicy, env, verbose=1, tensorboard_log=tensorboard_log)

        # optimize for 200000 steps
        model.learn(total_timesteps=200000, reset_num_timesteps=False)

        # save trained model parameters
        model.save("trained")

        # plot some frames
        # plot(repeat)

        # Shut the workers down before the next repeat creates new
        # environments. Deliberately not in a `finally`: after a failure the
        # process exits anyway and the worker threads are daemons.
        env.close()
