import socket
import gym, minerl
import struct
from tqdm import tqdm

from train.common.config import Config
from train.common.utils import import_module
from train.envs.minerl_env import *
from minerl.env import MineRLEnv

try:
    import cPickle as pickle
except ImportError:
    import pickle


# =====================================================================

def _send_message(socket, type, data):
    """Send one framed request to the server and block until its reply arrives.

    The request is the dict ``{"type": type, "data": data}`` pickled with
    protocol 4, preceded by a 4-byte big-endian unsigned length header.

    Returns whatever ``_read_response`` yields for the reply (``None`` if the
    connection was closed before a reply length arrived).
    """
    payload = pickle.dumps({"type": type, "data": data}, protocol=4)
    header = struct.pack('>I', len(payload))
    socket.sendall(b"".join((header, payload)))
    reply = _read_response(socket)
    return reply


def _read_response(socket):
    """Read one length-prefixed, pickled message from ``socket``.

    Wire format: a 4-byte big-endian unsigned length followed by that many
    bytes of pickle data (the format written by ``_send_message``).

    Returns the unpickled object, or ``None`` if the peer closed the
    connection before a complete message (header or body) was received.

    SECURITY: ``pickle.loads`` can execute arbitrary code embedded in the
    payload; only talk to trusted env servers.
    """
    raw_msglen = recvall(socket, 4)
    if not raw_msglen:
        return None
    msglen = struct.unpack('>I', raw_msglen)[0]
    # Read the message data
    message = recvall(socket, msglen)
    if message is None:
        # Peer closed mid-message: treat it like EOF on the header instead of
        # crashing inside pickle.loads(None).
        return None
    # encoding="bytes" makes pickle load Python 2 ``str`` instances as bytes.
    return pickle.loads(message, encoding="bytes")


def recvall(sock, n):
    """Receive exactly ``n`` bytes from ``sock``.

    Returns the received data as ``bytes``, or ``None`` if the peer closes the
    connection (``recv`` returns an empty chunk) before ``n`` bytes arrived;
    any partial data is discarded in that case.
    """
    # bytearray grows in place; repeated ``bytes += chunk`` copies the whole
    # buffer every time and is quadratic for large observations.
    buf = bytearray()
    while len(buf) < n:
        # Ask only for what is still missing so we never read into the next message.
        packet = sock.recv(n - len(buf))
        if not packet:
            return None
        buf += packet
    return bytes(buf)


# =====================================================================

class RemoteEnv(object):
    """Client-side proxy for an environment hosted by a remote env server.

    Action/observation spaces come from the locally registered gym spec for
    ``envname``; ``reset``/``step``/``render``/``close`` are forwarded over
    ``socket`` with ``_send_message`` and return the server's response.
    """

    def __init__(self, envname, socket):
        specs = gym.envs.registration.spec(envname)
        self.action_space = specs._kwargs["action_space"]
        self.observation_space = specs._kwargs["observation_space"]
        # Alias ``noop`` to the space's ``no_op`` method. This sets the attribute
        # on the action-space object taken from the spec's kwargs, so it is
        # presumably visible to other users of that spec too — verify.
        self.action_space.noop = self.action_space.no_op
        self.socket = socket
        # Seed forwarded with every subsequent reset; None means "send no seed".
        self.env_seed = None
        self.reward_range = MineRLEnv.reward_range
        self.metadata = MineRLEnv.metadata

    def version(self):
        """Request version info from the server and return its response."""
        return _send_message(self.socket, "version", {})

    def seed(self, seed=None):
        """Store ``seed`` locally; it is sent with every later ``reset``.

        ``None`` clears the stored seed so resets send no seed at all.
        """
        self.env_seed = seed

    def reset(self):
        """Ask the server to reset the env (with the stored seed, if any)."""
        # Compare against None so that a seed of 0 is still forwarded.
        payload = {} if self.env_seed is None else {"seed": self.env_seed}
        return _send_message(self.socket, "reset", payload)

    def step(self, action):
        """Forward ``action`` to the server and return its reply unchanged."""
        return _send_message(self.socket, "step", action)

    def render(self, mode="human"):
        """Ask the server to render in ``mode`` and return its reply."""
        return _send_message(self.socket, "render", {"mode": mode})

    def close(self):
        """Tell the server to close the env, then close the local socket."""
        try:
            _send_message(self.socket, "close", None)
        finally:
            # Release the socket even if the server is already gone.
            self.socket.close()


# =====================================================================

class RemoteGym(object):
    """Connection to a remote env server; creates ``RemoteEnv`` proxies.

    All envs returned by ``make`` share this object's single socket.
    """

    def __init__(self, host="localhost", port=9999):
        self.host = host
        self.port = port
        self.socket = self._connect()

    def _connect(self):
        """Open a TCP connection to ``(self.host, self.port)``.

        Raises:
            ConnectionError: if the connection cannot be established. The
                underlying ``OSError`` is chained as ``__cause__``.
        """
        try:
            # create_connection resolves the host (IPv4 or IPv6) and closes any
            # socket it opened when connecting fails, so nothing leaks on error.
            return socket.create_connection((self.host, self.port))
        except OSError as err:
            raise ConnectionError("Connection error") from err

    def version(self):
        """Request version info from the server and return its response."""
        response = _send_message(self.socket, "version", None)
        return response

    def make(self, envname):
        """Ask the server to create ``envname``; return a proxy on this socket.

        Raises:
            Exception: if the server's response is falsy (including ``None``
                when the connection closed).
        """
        response = _send_message(self.socket, "make", envname)
        if not response:
            raise Exception("unable to make environment")
        return RemoteEnv(envname, self.socket)


# =====================================================================

def main():
    """Interactive console for manually exercising a MineRL environment.

    Reads commands from stdin until ``close`` (or end of input):
      make    - create the env via ``make_minerl`` and reset it
      step    - step every env once with the no-op action and print the result
      render  - render the env
      reset   - reset the env and print the result
      timeit  - run 10000 no-op step+render iterations under a tqdm bar
    Commands other than ``make`` print "make first" until an env exists.
    """
    env_name = "MineRLObtainDiamond-v0"
    n_envs = 1

    # NOTE(review): constructing a RemoteEnv aliases ``action_space.noop`` on the
    # action space from the registered gym spec. Kept in case ``make_minerl``
    # relies on that alias; the object itself has no socket and is not used.
    RemoteEnv(env_name, None)
    env = None
    noop = None

    config = Config('configs/experiment/config.meta.json')
    dataset = import_module(config.bc.dataset)

    input_space = dataset.INPUT_SPACE

    def read_command():
        # Treat end of input (e.g. piped stdin, Ctrl-D) like an explicit 'close'.
        try:
            return input(" -> ")
        except EOFError:
            return 'close'

    print("Enter 'close' to exit")
    try:
        message = read_command()
        while message != 'close':
            if message == "make":
                if env is not None:
                    # Don't leak the previously created env(s) when re-making.
                    env.close()
                    env, noop = None, None
                env = make_minerl(env_name, n_envs, dataset.SEQ_LENGTH, dataset.DATA_TRANSFORM, input_space,
                                  use_reprio=False, env_server=True)
                # env.seed(1234) # TODO implement seeding of multiple envs
                env.reset()
                noop = env.action_space.noop()
            elif message in ("step", "render", "reset", "timeit") and env is None:
                print("make first")
            elif message == "step":
                print(env.step((noop,) * n_envs))
            elif message == "render":
                env.render()
            elif message == "reset":
                print(env.reset())
            elif message == "timeit":
                actions = (noop,) * n_envs
                for _ in tqdm(range(10000)):
                    env.step(actions)
                    env.render()

            message = read_command()
    finally:
        # Only a real env (created by 'make') has a connection to close.
        if env is not None:
            env.close()


if __name__ == "__main__":
    # Start the interactive console only when run as a script, not on import.
    main()
