import gym
from gym import spaces
import numpy as np
from mpl_toolkits.mplot3d import Axes3D
import matplotlib.pyplot as plt


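# Puddle World is a continuous-state, 2-dimensional world with
# (x, y) ∈ [0, 1]^2 and 2 puddles: (1) the segment from [0.45, 0.4] to
# [0.45, 0.8], and (2) the segment from [0.1, 0.75] to [0.45, 0.75], each with
# radius 0.1. The goal is the region (x, y) ∈ ([0.95, 1.0], [0.95, 1.0]).
# The agent receives a reward of −1 − 400∗d on each time step, where d denotes
# the distance between the agent’s position and the center of the puddle. The
# task is episodic and undiscounted (γ = 1.0). The agent can select an action
# to move 0.05 + ζ, ζ ∼ N(µ = 0, σ^2 = 0.01).
#
# NOTE: the code below differs from this description in places: the motion
# noise is drawn uniformly from [-noise_scale, noise_scale], the goal test is
# an L1 distance to `goal` below `goal_threshold`, and the puddle penalty is
# 400 * max(0, radius - distance to the puddle's segment).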

# https://papers.nips.cc/paper/1109-generalization-in-reinforcement-learning-successful-examples-using-sparse-coarse-coding.pdf

# Maps each discrete action id to the (dx, dy) displacement it applies to the
# agent's position: 0 = -x, 1 = +x, 2 = -y, 3 = +y, 4 = stay in place.
ACTIONS = {
    action_id: np.array(displacement)
    for action_id, displacement in enumerate((
        (-0.05, 0.0),
        (0.05, 0.0),
        (0.0, -0.05),
        (0.0, 0.05),
        (0.0, 0.0),
    ))
}


def estimate_value(states,
                   policy,
                   num_rollouts=100,
                   max_rollout_length=1000,
                   environment=None):
    """Estimate state values by Monte Carlo rollouts of ``policy``.

    For every entry of ``states`` the environment is reset to that state
    ``num_rollouts`` times and ``policy`` is followed until the environment
    reports a terminal step or ``max_rollout_length`` steps have been taken.
    The undiscounted sum of the rewards of each rollout is recorded.

    Args:
        states: Sequence of start states; each is passed to
            ``environment.reset(position=...)``.
        policy: Object exposing ``action(observation)``.
        num_rollouts: Number of rollouts performed per start state.
        max_rollout_length: Upper bound on the number of steps per rollout.
        environment: Environment to roll out in. A fresh ``PuddleWorld()`` is
            created when omitted.

    Returns:
        Array of shape ``(len(states), num_rollouts)`` holding the return of
        every rollout.
    """
    env = PuddleWorld() if environment is None else environment

    values = np.full((len(states), num_rollouts), np.nan)

    for state_i, state in enumerate(states):
        print(f"state {state_i}/{len(states)}: {state}")
        for iteration in range(num_rollouts):
            print(f"iteration {iteration}/{num_rollouts}")
            rewards = []
            observation = env.reset(position=state)
            for _ in range(max_rollout_length):
                chosen_action = policy.action(observation)
                observation, reward, terminal, _ = env.step(chosen_action)
                rewards.append(reward)
                if terminal:
                    break
            # Undiscounted return of this rollout.
            values[state_i, iteration] = np.sum(rewards)

    return values


def distance_to_segment(start, end, point):
    """Return the Euclidean distance from ``point`` to a line segment.

    The point is projected onto the line through ``start`` and ``end``; the
    projection parameter is clamped to ``[0, 1]`` so that points beyond either
    end are measured against the nearest endpoint.

    Args:
        start: Coordinates of the first endpoint of the segment.
        end: Coordinates of the second endpoint of the segment.
        point: A single point, or an array of points whose last axis holds
            the coordinates.

    Returns:
        numpy array of distances with the shape of ``point`` minus its last
        axis (0-dimensional for a single point).
    """
    start = np.asarray(start, dtype=float)
    end = np.asarray(end, dtype=float)
    point = np.asarray(point, dtype=float)

    segment = end - start
    offset = point - start

    squared_length = np.einsum('...i,...i->...', segment, segment)
    projection = np.einsum('...i,...i->...', segment, offset)

    # A zero-length segment is a single point: the projection is 0 there, so
    # dividing by 1 instead of 0 yields t == 0 and avoids a NaN result.
    safe_length = np.where(squared_length > 0, squared_length, 1.0)
    t = np.clip(projection / safe_length, 0.0, 1.0)

    # t == 0 selects `start`, t == 1 selects `end`, anything in between the
    # orthogonal projection of the point onto the segment.
    closest = start + np.asarray(t)[..., None] * segment

    return np.asarray(np.linalg.norm(point - closest, ord=2, axis=-1))


class PuddleWorld(gym.Env):
    """Puddle-world navigation task on the unit square [0, 1] x [0, 1].

    The state is the agent's (x, y) position. Each step applies the
    displacement of one of the five discrete actions in ``ACTIONS`` plus
    uniform noise in ``[-noise_scale, noise_scale]`` per coordinate, and the
    result is clipped back into the unit square.

    Every step costs 1 plus a puddle penalty of
    ``puddle_cost_weight * 400 * max(0, radius - distance)``, taken for the
    puddle with the largest such value. An episode terminates once the L1
    distance between the position and ``goal`` is below ``goal_threshold``.

    Args:
        reset_fn: Either a callable returning a start position, or a fixed
            start position, used by ``reset()`` when no position is given.
        goal: Coordinates of the goal.
        goal_threshold: L1 distance to ``goal`` below which a state is
            terminal.
        noise_scale: Half-width of the uniform noise added to every move.
        puddles: Sequence of ``(start, end, radius)`` triples, each describing
            a puddle as a line segment with a radius. May be empty.
        puddle_cost_weight: Multiplier applied to the puddle penalty.
    """

    metadata = {'render.modes': ('human', 'rgb_array')}

    def __init__(self,
                 reset_fn=lambda: np.random.uniform(
                     low=[0.1, 0.45], high=[0.3, 0.65]),
                 goal=(1.0, 1.0),
                 goal_threshold=0.1,
                 noise_scale=0.01,
                 puddles=(
                     ((0.45, 0.4), (0.45, 0.8), 0.1),
                     ((0.1, 0.75), (0.45, 0.75), 0.1)),
                 puddle_cost_weight=1.0):
        self.reset_fn = reset_fn
        self.goal = np.array(goal, dtype=float)
        self.goal_threshold = goal_threshold
        self.noise_scale = noise_scale
        self.puddles = puddles
        self.puddle_cost_weight = puddle_cost_weight

        self.action_space = spaces.Discrete(5)
        self.observation_space = spaces.Box(0.0, 1.0, shape=(2,))

        # Set by reset(); None until the first episode has been started.
        self.position = None
        self.viewer = None

    def _at_goal(self, position):
        """Return True when ``position`` is within the goal threshold (L1)."""
        return bool(
            np.linalg.norm(position - self.goal, ord=1) < self.goal_threshold)

    def step(self, action):
        """Advance the environment by one action.

        Args:
            action: Index into ``ACTIONS``; either a plain int or any object
                exposing ``.item()`` (numpy scalar, single-element array).

        Returns:
            Tuple ``(observation, reward, done, info)`` where ``observation``
            is a copy of the new position and ``info`` is an empty dict.

        Raises:
            RuntimeError: If called before ``reset()``.
            ValueError: If ``action`` is not in the action space.
        """
        if self.position is None:
            raise RuntimeError("reset() must be called before step().")

        # Unwrap numpy scalars / single-element arrays; plain ints pass as-is.
        if hasattr(action, 'item'):
            action = action.item()
        if not self.action_space.contains(action):
            raise ValueError(f"Invalid action: {action}.")

        noise = np.random.uniform(
            low=-self.noise_scale, high=self.noise_scale, size=(2, ))
        # Build a new array rather than updating in place, so no array shared
        # with a caller is ever mutated.
        self.position = np.clip(
            self.position + (ACTIONS[action] + noise), 0.0, 1.0)

        reward = self._compute_reward(self.position)
        done = self._at_goal(self.position)
        info = {}

        return self.position.copy(), reward, done, info

    def _compute_reward(self, position):
        """Return the (negative) reward for being at ``position``.

        ``position`` may be a single point or an array of points whose last
        axis holds the coordinates; the result has the matching leading shape.
        """
        position = np.asarray(position, dtype=float)

        if len(self.puddles) > 0:
            puddle_costs = [
                np.maximum(
                    0,
                    puddle_radius
                    - distance_to_segment(puddle_start, puddle_end, position)
                ) * 400
                for puddle_start, puddle_end, puddle_radius in self.puddles
            ]
            # Only the most costly puddle counts; overlaps are not summed.
            worst_puddle_cost = np.max(puddle_costs, axis=0)
        else:
            # No puddles configured: only the constant step cost applies.
            worst_puddle_cost = np.zeros(position.shape[:-1])

        puddle_cost = self.puddle_cost_weight * worst_puddle_cost
        constant_cost = 1.0
        return -(constant_cost + puddle_cost)

    def uniform_random_state(self):
        """Sample a state from the observation space that is not a goal state."""
        state = self.observation_space.sample()
        while self._at_goal(state):
            state = self.observation_space.sample()
        return state

    def reset(self, position=None):
        """Start a new episode and return a copy of the initial position.

        Args:
            position: Start position to use. When omitted, ``reset_fn`` is
                called (if callable) or used directly as the start position.
        """
        if position is None:
            position = (
                self.reset_fn() if callable(self.reset_fn) else self.reset_fn)

        # Always take a private float copy so the caller's data is never
        # aliased and integer inputs can accumulate fractional moves.
        self.position = np.array(position, dtype=float)

        return self.position.copy()

    def render(self, mode='human', close=False):
        """Draw the reward landscape and the agent.

        Args:
            mode: ``'human'`` or ``'rgb_array'``; the latter asks the viewer
                to return the frame as an array.
            close: When True, close the viewer (if any) and return None.

        Returns:
            Whatever ``viewer.render`` returns for the requested mode.
        """
        if close:
            if self.viewer is not None:
                self.viewer.close()
                self.viewer = None
            return

        screen_width = 400
        screen_height = 400

        if self.viewer is None:
            from gym.envs.classic_control import rendering
            from gym_puddle.shapes.image import Image
            self.viewer = rendering.Viewer(screen_width, screen_height)

            import pyglet
            img_width = 100
            img_height = 100
            fformat = 'RGB'

            # Evaluate the reward on a regular grid in a single call; row j
            # corresponds to y = j / img_height, column i to x = i / img_width.
            xs = np.arange(img_width) / img_width
            ys = np.arange(img_height) / img_height
            grid_x, grid_y = np.meshgrid(xs, ys)
            rewards = self._compute_reward(
                np.stack([grid_x, grid_y], axis=-1))

            # Same grey value in every colour channel.
            pixels = np.empty((img_height, img_width, len(fformat)))
            pixels[...] = np.asarray(rewards)[..., None]

            # Rescale to 0..255; a flat reward surface stays all zero instead
            # of dividing by zero.
            pixels -= pixels.min()
            peak = pixels.max()
            if peak > 0:
                pixels *= 255. / peak
            pixels = np.floor(pixels)

            img = pyglet.image.create(img_width, img_height)
            img.format = fformat

            # NOTE(review): image data is handed over as raw bytes, which is
            # what pyglet is expected to want under Python 3 - confirm against
            # the installed pyglet version.
            data = pixels.astype(np.uint8).tobytes()

            img.set_data(fformat, img_width * len(fformat), data)
            bg_image = Image(img, screen_width, screen_height)
            bg_image.set_color(1.0, 1.0, 1.0)

            self.viewer.add_geom(bg_image)

            thickness = 5
            agent_polygon = rendering.FilledPolygon([
                (-thickness, -thickness),
                (-thickness, thickness),
                (thickness, thickness),
                (thickness, -thickness)])
            agent_polygon.set_color(0.0, 1.0, 0.0)
            self.agenttrans = rendering.Transform()
            agent_polygon.add_attr(self.agenttrans)
            self.viewer.add_geom(agent_polygon)

        self.agenttrans.set_translation(
            self.position[0]*screen_width, self.position[1]*screen_height)

        return self.viewer.render(return_rgb_array=(mode == 'rgb_array'))
