import os
import time
import pathlib

from sys import platform
import subprocess
import socket
import json
import numpy as np
import gym
from gym import spaces
from gym.spaces.tuple import Tuple as TupleSpace
from gym.spaces import Dict, Box
import atexit
from IPython import embed

class GodotGymWrapper(gym.Env):
    def __init__(
        self,
        env_path=None,
        port=11008,
        show_window=False,
        seed=0,
        framerate=None,
        action_repeat=None,
        map_name=None,
        nb_agents=1,
        agent_type="countbased",
        episodic=False,
        *args,
        **kw
    ):
        """Create the underlying GodotEnv and derive gym action/observation spaces.

        All named parameters are forwarded unchanged to ``GodotEnv``; any extra
        ``*args`` / ``**kw`` go to ``gym.Env.__init__``.

        The action space exposed to the agent is built from the spaces reported
        by the Godot side:

        * a single discrete branch  -> that ``Discrete`` space,
        * only continuous branches  -> one ``Box(-1, 1, (n_continuous,))``,
        * both                      -> ``Tuple([discrete, continuous])``.

        Raises:
            NotImplementedError: if Godot reports more than one discrete branch.
            AssertionError: if a non-discrete branch is not a Box bounded by
                [-1, 1].
        """
        super().__init__(*args, **kw)

        self._env = GodotEnv(
            env_path=env_path,
            port=port,
            show_window=show_window,
            seed=seed,
            framerate=framerate,
            action_repeat=action_repeat,
            map_name=map_name,
            nb_agents=nb_agents,
            agent_type=agent_type,
            episodic=episodic,
        )

        self.episodic = episodic
        self.agent_type = agent_type

        # Split the action branches reported by Godot into discrete / continuous.
        self.action_names = self._env.action_space.keys()
        self.discrete_actions = []
        self.continuous_actions = []
        for action_name in self.action_names:
            branch_space = self._env.action_space[action_name]
            if isinstance(branch_space, spaces.Discrete):
                self.discrete_actions.append(action_name)
            else:
                self.continuous_actions.append(action_name)
                assert isinstance(branch_space, spaces.Box)
                # ``low`` / ``high`` are arrays: compare element-wise and reduce,
                # otherwise a Box with more than one element raises
                # "truth value of an array is ambiguous".
                assert np.all(branch_space.low == -1) and np.all(branch_space.high == 1)
        if len(self.discrete_actions) > 1:
            raise NotImplementedError("mlacademy doesn't support multiple discrete action branches")

        discrete_space = None
        continuous_space = None
        if len(self.discrete_actions) == 1:
            discrete_space = self._env.action_space[self.discrete_actions[0]]
        if len(self.continuous_actions) > 0:
            # One scalar in [-1, 1] per continuous branch.
            continuous_space = gym.spaces.Box(-1, 1, (len(self.continuous_actions),))

        if discrete_space is None:
            self.action_space = continuous_space
        elif continuous_space is None:
            self.action_space = discrete_space
        else:
            self.action_space = TupleSpace([discrete_space, continuous_space])

        self.obs_names = self._env.observation_space.keys()

        # Layout of the flat observation vector as sliced by get_obs():
        # 3*8*8 = 192 image values followed by 21 vector values.
        self.observation_space = Dict({
            "image": Box(low=-np.inf, high=np.inf, shape=(192,)),
            "vector_obs": Box(low=-np.inf, high=np.inf, shape=(21,)),
        })

        # get_obs() emits a "heatmap" entry for every agent type containing
        # "augmented", so the declared space uses the same test.
        if "augmented" in agent_type:
            self.observation_space["heatmap"] = Box(low=-np.inf, high=np.inf, shape=(2, 32, 32))

        self.nb_agents = nb_agents
        print(self.observation_space)

    def step(self, action):
        """Translate a gym action into per-agent dicts and advance the game.

        When both discrete and continuous branches exist, ``action`` is indexed
        as ``(discrete, continuous)``; otherwise it is used as-is for whichever
        kind is present.

        NOTE(review): only the first element of the continuous action is
        iterated per agent, so a leading batch axis is presumably expected
        (``[batch][agent][branch]``) -- confirm against the caller.

        Returns:
            ``(obs, reward, done, info)``; for a single agent the reward, done
            and info entries are unwrapped from their one-element lists.
        """
        has_discrete = len(self.discrete_actions) > 0
        has_continuous = len(self.continuous_actions) > 0

        per_agent = [{} for _ in range(self.nb_agents)]

        if has_discrete and has_continuous:
            discrete_part = action[0]
            continuous_part = action[1]
        else:
            discrete_part = continuous_part = action

        if has_discrete:
            # A single agent supplies a bare action; wrap it so it iterates.
            if self.nb_agents <= 1:
                discrete_part = [discrete_part]
            branch_name = self.discrete_actions[0]
            per_agent = [{branch_name: choice} for choice in discrete_part]

        if has_continuous:
            first_entry = list(continuous_part)[0]
            for agent_index, agent_values in enumerate(first_entry):
                named_values = {
                    name: float(agent_values[position])
                    for position, name in enumerate(self.continuous_actions)
                }
                per_agent[agent_index].update(named_values)

        next_obs, reward, done, info = self._env.step(per_agent)
        self.obs = [np.array(entry['obs']) for entry in next_obs]
        obs = self.get_obs()

        if self.nb_agents == 1:
            reward, done, info = reward[0], done[0], info[0]

        return obs, reward, done, info

    def reset(self):
        """Reset the Godot game and return the first structured observation.

        The raw per-agent ``'obs'`` vectors are cached on ``self.obs`` and then
        split into named parts by ``get_obs``.
        """
        raw_obs = self._env.reset()
        self.obs = [np.array(entry['obs']) for entry in raw_obs]
        return self.get_obs()

    def get_obs(self):
        """Split each cached flat observation into named parts.

        Every vector in ``self.obs`` is sliced as:

        * ``image``      -- the first 3*8*8 values,
        * ``vector_obs`` -- the next 21 values,
        * ``heatmap``    -- only when the agent type contains "augmented": the
          remaining values reshaped to 32x32, stacked with a one-hot 32x32 map
          of the agent position, giving shape (2, 32, 32).

        The position map is only built when the heatmap is requested, so
        non-augmented agents no longer fail on a position outside the grid.

        Returns:
            An object ``np.ndarray`` holding one dict per agent.
        """
        image_len = 3 * 8 * 8
        vector_end = image_len + 21
        grid = 32
        augmented = "augmented" in self.agent_type

        all_obs = []
        for flat in self.obs:
            obs = {
                "image": flat[:image_len],
                "vector_obs": flat[image_len:vector_end],
            }

            if augmented:
                # Entries 6 and 7 of the vector observation are used as grid
                # indices (presumably the agent's cell -- TODO confirm on the
                # Godot side).
                agent_pos = obs["vector_obs"][[6, 7]]
                pos_map = np.zeros((grid, grid))
                pos_map[int(agent_pos[0]), int(agent_pos[1])] = 1

                heatmap = flat[vector_end:].reshape(grid, grid)
                obs["heatmap"] = np.stack([heatmap, pos_map])

            all_obs.append(obs)

        return np.array(all_obs)
    
    def render(self, mode='human'):
        """Rendering through the gym API is not implemented; always raises NotImplementedError."""
        raise NotImplementedError


class GodotEnv:
    MAJOR_VERSION = "0"
    MINOR_VERSION = "1"
    DEFAULT_PORT = 11008
    DEFAULT_TIMEOUT = 60

    def __init__(
        self,
        env_path=None,
        port=11008,
        show_window=False,
        seed=0,
        framerate=None,
        action_repeat=None,
        map_name=None,
        nb_agents=1,
        episodic=False,
        agent_type="countbased"
    ):
        """Launch (or wait for) a Godot game and perform the initial protocol exchange.

        With ``env_path`` set, the binary is validated and started with the
        given options. Without it, the user is asked to press PLAY in the Godot
        editor and ``port`` is forced to ``DEFAULT_PORT``.

        Afterwards the TCP server is opened, the handshake is sent, the
        action/observation spaces are fetched and an ``atexit`` hook is
        registered to close the connection.
        """
        self.proc = None
        if env_path is None:
            # In this mode the caller-supplied port is ignored.
            port = GodotEnv.DEFAULT_PORT
            print(
                "No game binary has been provided, please press PLAY in the Godot editor"
            )
        else:
            self.check_platform(env_path)
            self._launch_env(
                env_path,
                port,
                show_window,
                framerate,
                seed,
                action_repeat,
                map_name,
                nb_agents,
                agent_type,
                episodic,
            )

        self.port = port
        self.connection = self._start_server()
        self.num_envs = None
        self._handshake()
        self._get_env_info()

        atexit.register(self._close)

    def check_platform(self, filename: str):
        """Validate that ``filename`` is a game binary usable on this platform.

        Linux expects a ``.x86_64`` suffix and Windows a ``.exe`` suffix; macOS
        and any other platform are rejected. The file must also exist.

        The checks raise ``AssertionError`` explicitly (instead of using
        ``assert`` statements) so they are not stripped under ``python -O``
        while callers still see the same exception type.

        Raises:
            AssertionError: unsupported platform, wrong suffix or missing file.
        """
        suffix = pathlib.Path(filename).suffix

        if platform in ("linux", "linux2"):
            expected_suffix = ".x86_64"
        elif platform == "win32":
            expected_suffix = ".exe"
        elif platform == "darwin":
            raise AssertionError("mac is not supported, yet")
        else:
            raise AssertionError(f"unknown platform {platform}")

        if suffix != expected_suffix:
            raise AssertionError(
                f"incorrect file suffix for filename {filename} suffix {suffix}"
            )

        if not os.path.exists(filename):
            raise AssertionError(f"game binary {filename} does not exist")

    def from_numpy(self, action):
        """Convert per-agent action dicts into JSON-serialisable values.

        ``np.ndarray`` values become nested lists; every other value is passed
        through ``int()`` (so non-array floats are truncated).

        Returns:
            A new list with one plain dict per entry of ``action``.
        """
        return [
            {
                name: value.tolist() if isinstance(value, np.ndarray) else int(value)
                for name, value in agent_action.items()
            }
            for agent_action in action
        ]

    def step(self, action):
        """Send one action per agent to Godot and return the resulting transition.

        Args:
            action: JSON-serialisable list with one action dict per agent.

        Returns:
            ``(obs, reward, done, info)`` taken from the response. ``info`` is
            filtered per agent: keys starting with ``"_"`` are dropped and
            non-empty list values are replaced by their first element.
        """
        self._send_as_json({"type": "action", "action": action})
        response = self._get_json_dict()

        response["obs"] = self._process_obs(response["obs"])

        info_to_use = []
        for agent_info in response["info"]:
            public_info = {}
            for key, value in agent_info.items():
                # startswith() is also safe for an empty key, where key[0]
                # would raise IndexError.
                if key.startswith("_"):
                    continue
                # Empty lists are passed through instead of raising IndexError.
                if isinstance(value, list) and value:
                    value = value[0]
                public_info[key] = value
            info_to_use.append(public_info)

        return (
            response["obs"],
            response["reward"],
            np.array(response["done"]).tolist(),
            info_to_use,
        )

    def _process_obs(self, response_obs: list):
        """Decode hex-encoded 2D observations in place.

        Args:
            response_obs: list with one observation dict per agent. Keys
                containing ``"2d"`` hold hex strings that are decoded into
                float32 arrays shaped like the matching observation space.

        Returns:
            The same list, with the encoded entries replaced by arrays. An
            empty list is returned unchanged.
        """
        if not response_obs:
            return response_obs

        # The first agent's keys decide which entries are encoded; every agent
        # is expected to carry the same keys.
        encoded_keys = [k for k in response_obs[0] if "2d" in k]
        for k in encoded_keys:
            shape = self.observation_space[k].shape
            for sub in response_obs:
                sub[k] = self.decode_2d_obs_from_string(sub[k], shape)

        return response_obs

    def reset(self):
        """Ask Godot to reset and return the first observations.

        Returns:
            ``np.ndarray`` built from the per-agent observation dicts of the
            response.

        Raises:
            AssertionError: if the response is not of type ``"reset"``.
        """
        # NOTE: a pending observation may still sit in the socket buffer at
        # this point; draining it first (see _clear_socket) is not done.
        self._send_as_json({"type": "reset"})
        response = self._get_json_dict()
        decoded = self._process_obs(response["obs"])
        response["obs"] = decoded
        assert response["type"] == "reset"
        return np.array(response["obs"])

    def call(self, method):
        """Send a ``"call"`` message naming ``method`` and return the ``"returns"`` field of the reply."""
        self._send_as_json({"type": "call", "method": method})
        reply = self._get_json_dict()
        return reply["returns"]

    def close(self):
        """Tell Godot to shut down, then close the connection.

        The socket is closed and the ``atexit`` hook removed even when sending
        the close message fails (for example because the peer is already gone);
        the send error still propagates to the caller afterwards.
        """
        message = {
            "type": "close",
        }
        try:
            self._send_as_json(message)
            print("close message sent")
            # Pause before tearing the socket down (presumably to let Godot
            # read the message -- TODO confirm).
            time.sleep(1.0)
        finally:
            self.connection.close()
            try:
                atexit.unregister(self._close)
            except Exception as e:
                print("exception unregistering close method", e)

    def _close(self):
        """atexit hook registered in ``__init__``: report the unclean exit, then call ``close()``."""
        print("exit was not clean, using atexit to close env")
        self.close()

    def _launch_env(self, env_path, port, show_window, framerate, seed, action_repeat, map_name, nb_agents, agent_type, episodic):
        """Start the exported Godot binary as a child process.

        The command line is built as an argument list rather than a string
        split on spaces, so a binary path or map name containing spaces is
        passed through intact. Options whose value is ``None`` are omitted.
        Any falsy ``show_window`` runs the game without window or render loop.

        The process handle is stored on ``self.proc``.
        """
        launch_cmd = [str(env_path), f"--port={port}", f"--env_seed={seed}"]

        if framerate is not None:
            launch_cmd.append(f"--fixed-fps={framerate}")
            launch_cmd.append(f"--steps_per_second={framerate}")

        optional_flags = (
            ("action_repeat", action_repeat),
            ("map_name", map_name),
            ("nb_agents", nb_agents),
            ("agent_type", agent_type),
            ("episodic", episodic),
        )
        for flag, value in optional_flags:
            if value is not None:
                launch_cmd.append(f"--{flag}={value}")

        if not show_window:
            launch_cmd.extend(["--disable-render-loop", "--no-window"])

        self.proc = subprocess.Popen(
            launch_cmd,
            start_new_session=True,
        )

    def _start_server(self):
        """Listen on ``self.port`` and return the first accepted connection.

        The Godot game (exported binary or editor session) connects to this
        process as a client. Waiting is bounded by ``DEFAULT_TIMEOUT`` seconds,
        after which ``accept`` raises ``socket.timeout``.

        The listening socket is closed as soon as the single client has
        connected (or accepting failed); the returned connection is unaffected.
        """
        print(f"waiting for remote GODOT connection on port {self.port}")

        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
            if os.name == "posix":
                # Allow re-binding right after a previous run, whose port may
                # still be in TIME_WAIT. Restricted to POSIX because the option
                # has different semantics on Windows.
                sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)

            # "localhost" was not working on a Windows VM, hence the literal IP.
            server_address = ("127.0.0.1", self.port)
            sock.bind(server_address)

            # Exactly one client is expected.
            sock.listen(1)
            sock.settimeout(GodotEnv.DEFAULT_TIMEOUT)
            connection, client_address = sock.accept()

        print("connection established")
        return connection

    def _handshake(self):
        """Send the protocol handshake carrying this client's major/minor version."""
        self._send_as_json(
            dict(
                type="handshake",
                major_version=GodotEnv.MAJOR_VERSION,
                minor_version=GodotEnv.MINOR_VERSION,
            )
        )

    def _get_env_info(self):
        """Query Godot for its spaces and build the matching gym spaces.

        Sets ``self.action_space`` and ``self.observation_space`` to
        ``spaces.Dict`` instances keyed like the response:

        * actions: ``"discrete"`` -> ``Discrete(size)``,
          ``"continuous"`` -> ``Box(-1, 1, (size,))``;
        * observations: ``"box"`` -> float32 ``Box(-1, 1, size)``,
          ``"discrete"`` -> ``Discrete(size)``.

        Raises:
            AssertionError: wrong response type, or an unsupported action or
                observation space kind. Raised explicitly so the checks
                survive ``python -O``.
            NotImplementedError: for ``"repeated"`` observation spaces.
        """
        self._send_as_json({"type": "env_info"})

        json_dict = self._get_json_dict()
        if json_dict["type"] != "env_info":
            raise AssertionError(
                f"expected an env_info response, got {json_dict['type']}"
            )

        # Each entry describes one action head.
        action_spaces = {}
        print("action space", json_dict["action_space"])
        for k, v in json_dict["action_space"].items():
            action_type = v["action_type"]
            if action_type == "discrete":
                action_spaces[k] = spaces.Discrete(v["size"])
            elif action_type == "continuous":
                action_spaces[k] = spaces.Box(low=-1.0, high=1.0, shape=(v["size"],))
            else:
                message = f"action space {action_type} is not supported"
                print(message)
                raise AssertionError(message)
        self.action_space = spaces.Dict(action_spaces)

        observation_spaces = {}
        print("observation space", json_dict["observation_space"])
        for k, v in json_dict["observation_space"].items():
            space_kind = v["space"]
            if space_kind == "box":
                observation_spaces[k] = spaces.Box(
                    low=-1.0,
                    high=1.0,
                    shape=v["size"],
                    dtype=np.float32,
                )
            elif space_kind == "discrete":
                observation_spaces[k] = spaces.Discrete(v["size"])
            elif space_kind == "repeated":
                raise NotImplementedError("We don't support repeated state spaces for now")
            else:
                message = f"observation space {space_kind} is not supported"
                print(message)
                raise AssertionError(message)
        self.observation_space = spaces.Dict(observation_spaces)

    @staticmethod
    def decode_2d_obs_from_string(
        hex_string,
        shape,
    ):
        return (
            np.frombuffer(bytes.fromhex(hex_string), dtype=np.float16)
            .reshape(shape)
            .astype(np.float32)[:, :, :]  # TODO remove the alpha channel
        )

    def _send_as_json(self, dictionary):
        """Serialise ``dictionary`` to JSON text and send it as one framed message."""
        self._send_string(json.dumps(dictionary))

    def _get_json_dict(self):
        """Receive one framed message and parse it as JSON.

        Failures are reported as exceptions instead of opening an interactive
        IPython shell and returning ``None``.

        Raises:
            TimeoutError: if no message was received (``_get_data`` returned
                ``None``).
            ValueError: if the payload is not valid JSON; the original
                ``json.JSONDecodeError`` is chained as the cause.
        """
        data = self._get_data()
        if data is None:
            raise TimeoutError("no response received from the Godot environment")

        try:
            return json.loads(data)
        except json.JSONDecodeError as err:
            # Only a prefix of the payload is echoed to keep the message short.
            raise ValueError(
                f"malformed JSON received from the Godot environment: {data[:200]!r}"
            ) from err

    def _get_obs(self):
        """Return the next raw message from the socket, exactly as ``_get_data`` yields it (not parsed)."""
        return self._get_data()

    def _clear_socket(self):
        """Discard everything currently buffered on the connection.

        The socket is switched to non-blocking mode and read until no more
        data is pending (``BlockingIOError``) or the peer closed (empty read).
        Blocking mode is always restored, including when ``recv`` fails with
        another error, which then propagates to the caller.
        """
        self.connection.setblocking(False)
        try:
            while self.connection.recv(4096):
                pass
        except BlockingIOError:
            # Nothing left to read right now: the buffer is drained.
            pass
        finally:
            self.connection.setblocking(True)

    def _get_data(self):
        """Receive one length-prefixed message and return it as text.

        Wire format: a 4-byte little-endian byte count followed by that many
        bytes of text. Both parts are read until complete, so short reads
        cannot corrupt the length or run into the next message. The payload is
        decoded once in full, so multi-byte characters split across reads are
        handled.

        Returns:
            The decoded message, or ``None`` if the socket timed out.

        Raises:
            ConnectionError: if the peer closes the connection before a
                complete message has arrived.
        """

        def read_exact(size):
            # Collect exactly ``size`` bytes; recv() may return fewer per call.
            chunks = []
            remaining = size
            while remaining > 0:
                chunk = self.connection.recv(remaining)
                if not chunk:
                    # An empty read means the peer closed the connection.
                    raise ConnectionError("connection closed by the Godot environment")
                chunks.append(chunk)
                remaining -= len(chunk)
            return b"".join(chunks)

        try:
            length = int.from_bytes(read_exact(4), "little")
            return read_exact(length).decode()
        except socket.timeout as e:
            print("env timed out", e)

        return None

    def _send_string(self, string):
        """Send ``string`` as one length-prefixed message.

        The 4-byte little-endian prefix is the length of the encoded payload in
        bytes, not the number of characters, so non-ASCII text is framed
        correctly.
        """
        payload = string.encode()
        self.connection.sendall(len(payload).to_bytes(4, "little") + payload)

    def _send_action(self, action):
        """Send ``action`` unchanged through ``_send_string``, so it must already be a string."""
        self._send_string(action)


if __name__ == "__main__":
    # Smoke test: connect to a game started from the Godot editor and drive it
    # with random actions.
    env = GodotEnv()
    try:
        print("observation space", env.observation_space)
        print("action space", env.action_space)
        obs = env.reset()

        # reset() returns one observation per agent. ``env.num_envs`` is never
        # populated (it stays None), so it cannot be used as the agent count.
        nb_agents = len(obs)

        for _ in range(1000):
            # Sampled actions hold numpy values; convert them so they can be
            # serialised to JSON.
            actions = env.from_numpy(
                [env.action_space.sample() for _ in range(nb_agents)]
            )
            obs, reward, done, info = env.step(actions)
    finally:
        env.close()
