# %%
import gym
from matplotlib import pyplot as plt
from gym import error, spaces, utils
from mpl_toolkits.axes_grid1.axes_divider import make_axes_locatable
from matplotlib.ticker import FuncFormatter
import matplotlib.colors as colors

import numpy as np
from PIL import Image
# Env from Maximilian Igl
# First dimension: x, second dimension: y
# Layouts are defined by 'WALKABLE' (np array) and a map with constraints

# Unit displacements indexed 0-3.
# NOTE(review): MOVE is not referenced anywhere in the visible part of this
# file (the environment uses ACTIONS below). The earlier "No-op"/"Turn" labels
# did not match the vectors, so they are described literally here.
MOVE = {
    0: np.array([0, 1]),  # +1 along the second axis
    1: np.array([0, -1]),  # -1 along the second axis
    2: np.array([1, 0]),  # +1 along the first axis
    3: np.array([-1, 0]),  # -1 along the first axis
}

# Action id -> displacement added to the taxi location in Taxi.step().
# Actions 0 and 5 do not move the taxi; 5 is handled separately as
# pickup / drop-off.
ACTIONS = {
    0: np.array([0,0]), # No-op
    1: np.array([0,-1]), # -1 along the second axis
    2: np.array([0,1]), # +1 along the second axis
    3: np.array([1,0]), # +1 along the first axis
    4: np.array([-1,0]), # -1 along the first axis
    5: np.array([0,0]), # Pickup / drop-off (no displacement)
}

# NOTE(review): not referenced in the visible code; matches the layout's shape.
shape = (8, 8)

# Walkability grid: 1 = walkable cell, 0 = wall. The outer ring is all wall,
# so the playable area is the 6x6 interior (indices 1..6 on both axes).
TAXI_ROOMS_LAYOUT = np.array([
    [0, 0, 0, 0, 0, 0, 0, 0],
    [0, 1, 1, 0, 1, 1, 1, 0],
    [0, 1, 1, 0, 1, 1, 1, 0],
    [0, 1, 1, 1, 1, 1, 1, 0],
    [0, 1, 1, 1, 1, 1, 1, 0],
    [0, 1, 0, 1, 0, 1, 1, 0],
    [0, 1, 0, 1, 0, 1, 1, 0],
    [0, 0, 0, 0, 0, 0, 0, 0],
])

# Landmark cells, given as indices into TAXI_ROOMS_LAYOUT. Taxi.set_task()
# picks start, pickup and goal cells from these four.
LOCATIONS = {
    'R': np.array([1, 1]),
    'Y': np.array([6, 1]),
    'B': np.array([6, 5]),
    'G': np.array([1, 6])
}

# ASCII map with the four landmark letters.
# NOTE(review): not referenced in the visible code; presumably kept from the
# classic Taxi environment for reference - confirm before removing.
MAP = [
    "+---------+",
    "|R: | : :G|",
    "| : : : : |",
    "| : : : : |",
    "| | : | : |",
    "|Y| : |B: |",
    "+---------+",
]


class Taxi(gym.Env):
    """Taxi-style grid world on the 8x8 ``TAXI_ROOMS_LAYOUT``.

    The taxi moves over walkable cells, picks up a passenger (unless the task
    starts with the passenger already on board) and drops it off at the goal
    landmark.

    * State: ``self.state['loc']`` (cell index into the layout) and
      ``self.state['pas']`` (whether the passenger is on board).
    * Observation: one-hot vector of length 72 = 2 * 36, i.e. the index of
      the 6x6 interior cell, offset by 36 when the passenger is on board.
    * Actions: the six keys of ``ACTIONS`` (0 no-op, 1-4 moves,
      5 pickup / drop-off).
    * Tasks: integers decoded by :meth:`set_task`.

    Episodes are padded to ``timestep_limit`` steps: after a successful
    drop-off the environment keeps returning all-zero observations with zero
    reward and only reports ``done`` once the limit is reached.
    """

    # render() also honours mode='rgb_array', so it is advertised here too.
    metadata = {'render.modes': ['human', 'rgb_array'],
                'video.frames_per_second': 20}

    # Number of task ids that randomizeCorrect() draws from.
    _N_TASKS = 60
    # Task ids (all < 48, i.e. "passenger not on board" tasks) whose start
    # landmark equals their goal landmark under set_task()'s decoding; they
    # are never drawn by randomizeCorrect().
    _START_ON_GOAL_TASKS = frozenset({0, 4, 8, 13, 17, 21, 26, 30, 34, 39, 43, 47})

    # Greedy action id -> (x_dir, y_dir) of the arrow in plot coordinates.
    # The plot's x axis is the second location axis and its y axis the first,
    # hence actions 1/2 map to x and actions 3/4 map to y. Actions 0 and 5 do
    # not move the taxi and get no arrow.
    _ARROW_DIRECTIONS = {1: (-1, 0), 2: (1, 0), 3: (0, 1), 4: (0, -1)}

    def __init__(self,
                 reward_per_timestep=-0.1,
                 discount=0.95,
                 random_move_prob=0.0):
        """Create the environment and draw an initial random task.

        :param reward_per_timestep: reward for every step that is not a
            successful drop-off.
        :param discount: discount factor, used only to compute the timeout
            penalty ``reward_per_timestep / (1 - discount)``; must not be 1.
        :param random_move_prob: probability that the requested action is
            replaced by a uniformly random one in :meth:`step`.
        """
        super().__init__()

        self.viewer = None
        self.walkable = TAXI_ROOMS_LAYOUT
        # grid and map are aliases of the same walkability array.
        self.grid = self.map = self.walkable
        self.timestep_limit = 1500
        self.timesteps = 0

        self.random_move_prob = random_move_prob

        self.actions = ACTIONS

        self.one_hot = True
        self.observation_space = spaces.Box(low=0., high=1., shape=(72,))
        self.action_space = spaces.Discrete(6)
        self.state_map, self.state_inverse_map = self._create_coord_mapping()

        self.state = {
            'loc': (0, 0),
            'pas': False
        }

        self.task_data = {
            'pic': np.array([0, 0]),
            'gol': np.array([0, 0]),
            'start': np.array([0, 0]),
            'pas': False
        }

        # done: the episode (including padding) is over and reset() is needed.
        # inner_done: the passenger has been delivered; remaining steps are
        # padding.
        self.done = True
        self.inner_done = True

        self.realgoal = None
        self.randomizeCorrect()

        self.reward_goal_found = 2
        self.reward_per_timestep = reward_per_timestep
        self.discount = discount

    def randomizeCorrect(self):
        """Draw a random task (never one that starts on its goal) and reset."""
        candidates = [task for task in range(self._N_TASKS)
                      if task not in self._START_ON_GOAL_TASKS]
        self.realgoal = np.random.choice(candidates)
        self.set_task(self.realgoal)
        self.reset()

    def set_task(self, task):
        """Decode ``task`` into start / pickup / goal cells in ``task_data``.

        Landmarks are indexed in the order R, Y, B, G.

        * ``task < 48`` (passenger waiting): ``goal = task % 4``, the pickup
          landmark is one of the three landmarks other than the goal
          (selected by ``(task // 4) % 3``) and ``start = (task // 12) % 4``.
        * otherwise (passenger already on board): with ``t = task - 48``,
          ``goal = t % 4`` and the start is one of the three landmarks other
          than the goal (selected by ``(t // 4) % 3``); the pickup cell is
          set to ``[0, 0]``, a wall cell the taxi can never stand on.

        Also resets the step counter. The taxi itself is only moved to the
        start cell by :meth:`reset`.
        """
        choices = ['R', 'Y', 'B', 'G']
        self.task = task
        self.timesteps = 0
        if task < 48:
            goal = task % 4
            pic = (task // 4) % 3
            start = ((task // 4) // 3) % 4

            self.task_data['gol'] = LOCATIONS[choices[goal]]
            self.task_data['start'] = LOCATIONS[choices[start]]
            self.task_data['pas'] = False
            # goal + pic + 1 (mod 4) skips the goal landmark itself.
            self.task_data['pic'] = LOCATIONS[choices[(goal + pic + 1) % 4]]
        else:
            task -= 48
            goal = task % 4
            start = (task // 4) % 3
            self.task_data['gol'] = LOCATIONS[choices[goal]]
            # goal + start + 1 (mod 4) skips the goal landmark itself.
            self.task_data['start'] = LOCATIONS[choices[(goal + start + 1) % 4]]
            self.task_data['pas'] = True
            self.task_data['pic'] = np.array([0, 0])

    def _create_coord_mapping(self):
        """
        Create mapping from coordinates to state id and inverse mapping
        :return: Mapping from coords to id (ids 0..35, passenger bit not
            included), mapping from id to coords (ids 0..71; id and id + 36
            map to the same cell)
        """
        state_map = {}
        state_inverse_map = {}
        for i in range(1, 7):
            for j in range(1, 7):
                # Same indexing as create_observation(): strip the wall ring,
                # then the second axis is the major one.
                idx = (j - 1) * 6 + (i - 1)
                state_map[(i, j)] = idx
                state_inverse_map[idx] = (i, j)
                state_inverse_map[idx + 36] = (i, j)
        return state_map, state_inverse_map

    def reset(self):
        """Restart the episode on task ``self.realgoal``.

        :return: the initial one-hot observation.
        """
        self.done = False
        self.inner_done = False
        self.timesteps = 0
        self.set_task(self.realgoal)
        self.state['pas'] = self.task_data['pas']
        self.state['loc'] = self.task_data['start']
        return self.create_observation()

    def step(self, action):
        """Advance the environment by one step.

        :param action: action id in ``0..5`` (anything ``int()`` accepts).
        :return: ``(obs, reward, done, info)``; ``info`` holds references to
            the live ``state`` and ``task_data`` dicts.
        :raises RuntimeError: if the episode is over and ``reset`` was not
            called.
        :raises ValueError: if the action to perform is not a key of
            ``self.actions``.
        """
        if self.done:
            raise RuntimeError(f"Environment must be reset {self.timesteps}")

        if self.inner_done:
            # Padding phase after a successful drop-off: zero observation and
            # reward until the step limit is reached.
            self.timesteps += 1
            if self.timesteps >= self.timestep_limit:
                self.done = True
            return np.zeros(72), 0, self.done, {'state': self.state, 'task': self.task_data}

        action_to_perform = int(action)
        if np.random.rand() < self.random_move_prob:
            action_to_perform = np.random.randint(0, 6)
        if action_to_perform not in self.actions:
            # Previously surfaced as KeyError / UnboundLocalError.
            raise ValueError(
                f"Invalid action {action!r}; expected one of {sorted(self.actions)}")

        reward = self.reward_per_timestep
        if action_to_perform < 5:
            # No-op (0) or movement (1-4); moves into walls are ignored.
            loc = self.state['loc']
            new_loc = loc + self.actions[action_to_perform]
            if self.check_inside_area(new_loc) and self.check_walkable(new_loc):
                self.state['loc'] = new_loc
        elif self.check_pickup_possible():
            self.state['pas'] = True
        elif self.check_dropoff_possible():
            # Only inner_done is set here; done follows at the step limit.
            self.inner_done = True
            reward = self.reward_goal_found

        self.timesteps += 1
        if self.timesteps >= self.timestep_limit:
            # Always end the episode at the limit, also when the drop-off
            # happened on this very step; otherwise that episode would run
            # one step longer than every other one.
            self.done = True
            if not self.inner_done:
                # Timeout penalty: value of receiving the per-step reward
                # forever under the given discount.
                reward = self.reward_per_timestep * 1 / (1 - self.discount)

        obs = np.zeros(72) if self.inner_done else self.create_observation()
        return obs, reward, self.done, {'state': self.state, 'task': self.task_data}

    def create_observation(self, loc=None, pas=None):
        """Return the one-hot observation for ``(loc, pas)``.

        Both arguments default to the current state. The index is
        ``pas * 36 + (loc[1] - 1) * 6 + (loc[0] - 1)``.
        """
        if loc is None:
            loc = self.state['loc']
        if pas is None:
            pas = self.state['pas']
        obs = np.zeros(72)
        # Remove wall from state
        x = loc[0] - 1
        y = loc[1] - 1
        obs[pas * 36 + y * 6 + x] = 1
        return obs

    def check_pickup_possible(self):
        """True if the taxi is on the pickup cell without a passenger."""
        return np.all(self.state['loc'] == self.task_data['pic']) and not self.state['pas']

    def check_dropoff_possible(self):
        """True if the taxi is on the goal cell with the passenger on board."""
        return np.all(self.state['loc'] == self.task_data['gol']) and self.state['pas']

    def check_walkable(self, loc):
        """True if ``loc`` indexes a walkable (value 1) layout cell."""
        return self.walkable[tuple(loc)] == 1

    def check_inside_area(self, loc):
        """True if ``loc`` lies within the bounds of the layout array."""
        return (0 <= loc[0] < self.walkable.shape[0] and
                0 <= loc[1] < self.walkable.shape[1])

    def _cell_color(self, cell):
        """Return the RGB colour used by render() for layout cell ``cell``."""
        if cell == tuple(self.state["loc"]):
            # Agent: yellow with passenger, magenta without.
            return [1, 1, 0] if self.state["pas"] else [1, 0, 1]
        if cell == tuple(self.task_data["gol"]):
            return [0, 0, 1]
        if (not self.state["pas"]) and cell == tuple(self.task_data["pic"]):
            # Waiting passenger.
            return [0, 1, 1]
        if self.walkable[cell[0], cell[1]] == 0:
            return [0, 0, 0]
        return [1, 1, 1]

    def render(self, mode='human', close=False):
        """Draw the grid with gym's classic-control viewer.

        :param mode: ``'rgb_array'`` makes the viewer return the frame as an
            array; any other value renders to the window.
        :param close: accepted for API compatibility; not used.
        """
        if self.viewer is None:
            # Imported lazily so the environment works without a display.
            from gym.envs.classic_control import rendering
            max_x, max_y = self.walkable.shape
            square_size = 75

            screen_height = square_size * max_x
            screen_width = square_size * max_y
            self.viewer = rendering.Viewer(screen_width, screen_height)
            self.viewer.square_map = {}
            for i in range(max_x):
                for j in range(max_y):
                    # First layout axis runs top-to-bottom on screen, second
                    # axis left-to-right.
                    l = j * square_size
                    r = l + square_size
                    t = max_x * square_size - i * square_size
                    b = t - square_size
                    square = rendering.FilledPolygon([(l, b), (l, t), (r, t), (r, b)])
                    self.viewer.add_geom(square)
                    self.viewer.square_map[(i, j)] = square

        for cell, square in self.viewer.square_map.items():
            square.set_color(*self._cell_color(cell))

        return self.viewer.render(return_rgb_array=mode == 'rgb_array')

    def get_mdp_description(self):
        """Return the tabular MDP of the current task.

        Transitions are deterministic (``random_move_prob`` is not taken
        into account). Rows of wall cells stay all-zero, and the drop-off
        transition has no successor (all-zero row), i.e. it is terminal.

        :return: ``(transition_probs, reward_distribution)`` with shapes
            ``(72, 6, 72)`` and ``(72, 6)``.
        """
        n_states = self.observation_space.shape[0]
        n_actions = self.action_space.n

        transition_probs = np.zeros((n_states, n_actions, n_states))
        reward_distribution = np.zeros((n_states, n_actions))

        for state in range(n_states):
            cell = self.state_inverse_map[state]
            pas = state >= 36

            # If it's a wall continue
            if not self.check_walkable(cell):
                continue

            curr_coords = np.array(cell)
            at_pickup = np.all(curr_coords == self.task_data['pic'])
            at_goal = np.all(curr_coords == self.task_data['gol'])

            for action in range(n_actions):
                new_coords = curr_coords
                new_pas = pas
                if action < 5:
                    # Movement: blocked moves leave the taxi in place.
                    candidate = curr_coords + self.actions[action]
                    if self.check_inside_area(candidate) and self.check_walkable(candidate):
                        new_coords = candidate
                elif action == 5 and at_pickup and not pas:
                    # Pickup only changes the passenger bit.
                    new_pas = True

                next_state = self.state_map[tuple(new_coords)] + new_pas * 36
                transition_probs[state, action, next_state] = 1
                reward_distribution[state, action] = self.reward_per_timestep

            if at_goal and pas:
                # Drop-off: goal reward and no successor state.
                reward_distribution[state, 5] = self.reward_goal_found
                transition_probs[state, 5, state] = 0

        return transition_probs, reward_distribution

    def get_plot_arrow_params(self, max_val, max_index):
        """Build quiver inputs for the greedy action of each cell.

        :param max_val: per-cell probability of the greedy action.
        :param max_index: per-cell greedy action id.
        :return: ``(x_pos, y_pos, x_dir, y_dir, color)`` arrays; entries of
            wall cells stay zero.
        """
        entries = len(max_val)
        x_pos, y_pos, x_dir, y_dir, color = (np.zeros(entries) for _ in range(5))

        for i in range(entries):
            # Inverse of the observation index: first axis is the minor one.
            coords = ((i % 6) + 1, (i // 6) + 1)
            if self.map[coords] == 0:
                continue
            x_dir[i], y_dir[i] = self._ARROW_DIRECTIONS.get(int(max_index[i]), (0, 0))

            x_pos[i] = coords[1] + 0.5
            y_pos[i] = coords[0] + 0.5
            color[i] = max_val[i]
        return x_pos, y_pos, x_dir, y_dir, color

    def create_coords(self):
        """Return the 36 interior cells ordered by observation index."""
        return [((i % 6) + 1, (i // 6) + 1) for i in range(36)]

    def plot_params(self, option_probs, policy_probs, name_prefix="", no_blocks=False):
        """Save one ``<name_prefix>_option<i>.png`` figure per option.

        Each figure shows, for the states without (left column) and with
        (right column) passenger, the greedy actions (top row) and the
        option probabilities (bottom row).

        :param option_probs: indexed as ``[state, option]``; states 0..35 are
            "no passenger", 36.. are "passenger".
        :param policy_probs: indexed as ``[state, option, action]``.
        :param name_prefix: prefix of the output file names.
        :param no_blocks: selects ``self.map`` instead of ``self.grid``; both
            are the same array in this class.
        """
        grid = self.map if no_blocks else self.grid
        n_options = option_probs.shape[1]
        coords = self.create_coords()
        # TODO Refactor taxi and four envs a bit
        #  Mental note: Plot probabilities is probably good to have in env (since plots are different or call it plot
        #  values and use it for baseline as well)
        #  Add ploting for baseline too
        halves = (("NoPassenger", slice(None, 36)), ("Passenger", slice(36, None)))
        for option in range(n_options):
            # Only rows 0 and 2 of the 3x2 grid are drawn into.
            fig, axes = plt.subplots(nrows=3, ncols=2, figsize=(16, 9))

            for col, (passenger_status, rows) in enumerate(halves):
                selected_policy_probs = policy_probs[rows, option, :]
                selected_option_probs = option_probs[rows, option]
                max_val = np.max(selected_policy_probs, axis=1)
                max_index = np.argmax(selected_policy_probs, axis=1)

                plot_policy(ax=axes[0][col], arrow_data=self.get_plot_arrow_params(max_val, max_index), grid=grid,
                            title_suffix=passenger_status)
                plot_options(ax=axes[2][col], probs=selected_option_probs, coords=coords, grid=grid,
                             title_suffix=passenger_status)

            plt.tight_layout()
            plt.savefig(name_prefix + "_option{}".format(option) + ".png", bbox_inches='tight', dpi=150)
            # Release the figure; otherwise one figure per option stays alive.
            plt.close(fig)


def create_grid_plot(ax, grid, color_map="binary"):
    """Draw ``grid`` on ``ax`` as a colour-mapped matrix with cell grid lines.

    With the default ``"binary"`` colour map the values are inverted
    (``1 - grid``) before plotting. The colour scale runs from 0 to the
    largest plotted value, but at least to 1.

    :return: the image object returned by ``ax.matshow``.
    """
    if color_map == "binary":
        grid = 1 - grid
    n_rows, n_cols = grid.shape[0], grid.shape[1]
    upper = max(float(np.max(grid)), 1)
    lower = 0

    image = ax.matshow(
        np.flip(grid, 0),
        cmap=plt.get_cmap(color_map),
        extent=[0, n_cols, 0, n_rows],
        vmin=lower,
        vmax=upper,
    )

    # Major ticks sit on cell borders (they carry the grid lines), minor
    # ticks on cell centres; major labels are hidden.
    ax.set_xticks(np.arange(0, n_cols))
    ax.set_xticks(np.arange(0.5, n_cols + 0.5), minor=True)
    plt.setp(ax.get_xmajorticklabels(), visible=False)

    ax.set_yticks(np.arange(0, n_rows))
    ax.set_yticks(np.arange(0.5, n_rows + 0.5), minor=True)
    ax.invert_yaxis()
    plt.setp(ax.get_ymajorticklabels(), visible=False)

    ax.tick_params(axis='both', which='both', length=0)
    ax.grid(color="black")
    return image


def create_grid_plot_values(ax, grid, color_map, coords, probs):
    """Fill the non-zero cells listed in ``coords`` with ``probs`` and plot.

    Cells whose grid value is 0 are masked and skipped. Every other listed
    cell gets its probability written into the grid and printed at the cell
    centre as an integer percentage.

    :return: the image object returned by ``create_grid_plot``.
    """
    masked = np.ma.masked_array(grid, grid == 0)
    for i, cell in enumerate(coords):
        row, col = cell[0], cell[1]
        if masked[row, col] != 0:
            masked[row][col] = probs[i]
            label = "{}".format(int(probs[i] * 100))
            ax.text(col + 0.5, row + 0.5, label, horizontalalignment='center',
                    verticalalignment='center', color="black", fontsize=7)
    return create_grid_plot(ax, masked, color_map=color_map)


def plot_options(ax, probs, coords, grid, title_suffix="", colorbar_size='10%'):
    """Plot per-cell option probabilities on ``ax`` with a percent colour bar.

    The plain grid is drawn first, then the probability overlay on a float
    copy of ``grid``.
    """
    create_grid_plot(ax, grid)
    float_grid = np.array(grid, float)
    image = create_grid_plot_values(ax, float_grid, "YlGn", coords, probs)

    # Colour bar in its own axes to the right of the plot.
    colorbar_axes = make_axes_locatable(ax).append_axes("right", size=colorbar_size, pad=0.025)
    percent = FuncFormatter(lambda y, _: '{:.0%}'.format(y))
    plt.colorbar(image, cax=colorbar_axes, ax=ax, format=percent, ticks=np.arange(0.0, 1.1, 0.25))

    ax.set_title(("Probability of choosing an option in states" + title_suffix))

def plot_policy(ax, arrow_data, grid, title_suffix="", values=True, headwidth=9, headlength=20, colorbar_size='10%'):
    """Plot the greedy action of every cell as a coloured arrow on ``ax``.

    :param arrow_data: ``(x_pos, y_pos, x_dir, y_dir, color)`` arrays.
    :param values: if true, also print each arrow's colour value as an
        integer percentage next to the arrow.
    """
    create_grid_plot(ax, grid)
    x_pos, y_pos, x_dir, y_dir, color = arrow_data

    color_norm = colors.Normalize(vmin=color.min(), vmax=color.max())
    quiv = ax.quiver(
        x_pos, y_pos, x_dir, y_dir, color,
        cmap=plt.get_cmap("viridis"),
        norm=color_norm,
        angles='xy',
        scale_units='xy',
        scale=1,
        pivot='middle',
        clim=(0.3, 1),
        headwidth=headwidth,
        headaxislength=headlength,
        headlength=headlength,
    )
    divider = make_axes_locatable(ax)

    if values:
        for i, x in enumerate(x_pos):
            y = y_pos[i]
            # Shift the label off the arrow: sideways when there is no
            # horizontal component, otherwise along y.
            if x_dir[i] == 0:
                x = x - 0.25
            else:
                y = y - 0.25
            ax.text(x, y, "%2d" % (color[i] * 100), horizontalalignment='center',
                    verticalalignment='center', color="black", fontsize=7)

    colorbar_axes = divider.append_axes("right", size=colorbar_size, pad=0.025)
    percent = FuncFormatter(lambda y, _: '{:.0%}'.format(y))
    plt.colorbar(quiv, cax=colorbar_axes, ax=ax, format=percent, ticks=np.arange(0.3, 1.1, 0.1))
    ax.set_title(("Maximum likelihood actions in states" + title_suffix))
# register(
#     id='Taxi-v1',
#     entry_point='environments.taxi:Taxi',
#     max_episode_steps=50,
#     kwargs={'add_action_in_obs': False,
#             'image_obs': False}
# )


def env_set_task():
    """Self-check for ``Taxi.set_task``.

    Decodes every task id and asserts that all (goal, pickup, start,
    passenger) combinations are distinct, that tasks below 48 start without
    the passenger and never have the pickup cell on the goal, and that the
    remaining tasks start with the passenger and never start on the goal.

    NOTE(review): relies on ``env.n_envs``, which the ``Taxi`` class visible
    in this file does not define - confirm where it is supposed to come from.
    """
    env = Taxi()
    seen = []
    for e in range(env.n_envs):
        env.set_task(e)
        print(f"Env {e} Goal {env.task_data['gol']} Pic {env.task_data['pic']} "
              f"Start {env.task_data['start']} Pass {env.task_data['pas']}")
        t = (list(env.task_data['gol']), list(env.task_data['pic']), list(env.task_data['start']), env.task_data['pas'])
        assert t not in seen
        seen.append(t)

        if e < 48:
            assert not env.task_data['pas']
            # Pickup must differ from the goal. This used to compare the
            # boolean 'pas' flag with the goal cell, which checked nothing.
            assert not np.all(env.task_data['pic'] == env.task_data['gol'])
        else:
            assert env.task_data['pas']
            assert not np.all(env.task_data['start'] == env.task_data['gol'])
    print("Test env set_task successful")


def envs_to_sample():
    """Self-check that task sampling skips excluded tasks and is near-uniform.

    NOTE(review): uses ``Taxi(exclude_envs=...)``, ``sample_task()``,
    ``get_task()`` and ``n_envs``, none of which are defined by the ``Taxi``
    class visible in this file - confirm against the intended class API.
    """
    from collections import Counter

    np.random.seed(42)
    envs_to_exclude = (1, 3, 4, 20, 57, 47, 30)
    env = Taxi(exclude_envs=envs_to_exclude)
    samples = 50000
    task_counts = Counter()

    for _ in range(samples):
        env.set_task(env.sample_task())
        assert env.get_task() not in envs_to_exclude, f"{env.get_task()} in envs_to_exclude"
        task_counts.update([env.get_task()])

    # Empirical frequency per sampled task (each count reduced by one, as in
    # the original check).
    counts = np.array(list(task_counts.values()))
    percentages = (counts - 1) / samples
    assert np.allclose(percentages, 1.0/env.n_envs, atol=5e-2)
    assert len(percentages) == env.n_envs
    print("Test envs_to_sample successful")


if __name__ == '__main__':
    # Run the two self-checks, then report every sampleable task that starts
    # on its goal without the passenger.
    env_set_task()
    envs_to_sample()

    env = Taxi()
    # Layout copy with the landmarks R, Y, B, G marked as 2, 3, 4, 5.
    # It is not used further below.
    b = np.copy(TAXI_ROOMS_LAYOUT)
    for marker, landmark in enumerate(('R', 'Y', 'B', 'G'), start=2):
        b[LOCATIONS[landmark][0], LOCATIONS[landmark][1]] = marker

    env_runs = 100
    env.seed(42)
    np.random.seed(42)
    # NOTE(review): env.envs_to_sample is not defined by the Taxi class
    # visible in this file - confirm where it is supposed to come from.
    for e in env.envs_to_sample:

        env.set_task(e)
        passenger_waiting = not env.task_data['pas']
        if passenger_waiting and np.all(env.task_data['gol'] == env.task_data['start']):
            print(f"Task {e}\nStart/Goal: {env.task_data['gol']}\nPic: {env.task_data['pic']}")
        env.reset()
        # Disabled random-rollout smoke test, kept for manual debugging:
        # init_pass = env.state["pas"]
        # print("Goal {}, Pass {} Start {}".format(env.task_data["gol"], env.task_data["pic"], env.state))
        # print("Picked Up {}".format(init_pass))
        # env.render()
        # while True:
        #     state, reward, done, info = env.step(env.action_space.sample())
        #     env.render()
        #     if done:
        #         break
