from typing import Dict
import gym
import numpy as np
from gym import spaces
import pygame
import math
import cvxpy as cp

from action_masking.provably_safe_env.utils_geometry import (
    circle_contains_point,
    get_boundary_distance_from_line,
    order_vertices_clockwise,
)

from action_masking.util.sets import Zonotope


class SeekerCircleEnv(gym.Env):
    """2-D point-agent navigation task with a single circular obstacle.

    The agent lives in the square ``[-size, size]^2`` and moves by adding the
    action to its position. An episode ends when the agent reaches the goal
    circle, touches/leaves the arena boundary, or enters the obstacle.

    Observation layout (7 values):
    ``(agent_x, agent_y, obstacle_x, obstacle_y, obstacle_radius, goal_x, goal_y)``.
    """

    metadata = {"render_modes": ["human", "rgb_array"], "render_fps": 10}

    def __init__(
        self,
        render_mode: str = None,
        size: float = 10,
        seed: int = None,
        randomize: bool = True,
        render_safe_input_set: bool = False,
        template_input_set: Zonotope = None,
    ) -> None:
        """Create the environment.

        Args:
            render_mode: ``None``, ``"human"`` (pygame window) or ``"rgb_array"``.
            size: Half side length of the square arena ``[-size, size]^2``.
            seed: Optional seed applied to NumPy's global RNG.
            randomize: If False, ``reset`` uses a fixed layout (agent near
                ``(-6, -6)``, goal at ``(6, 6)``, obstacle of radius 4 at the
                origin) instead of sampling one.
            render_safe_input_set: Whether rendering also draws the safe input
                set computed by ``calc_safe_input_set_zono``.
            template_input_set: Template zonotope for the safe input set; a
                unit box is used when None.
        """
        super().__init__()
        self.window = None
        self.clock = None
        self.window_size = 512
        self.render_mode = render_mode
        self.render_safe_input_set = render_safe_input_set
        self.template_input_set = template_input_set
        self.randomize = randomize

        self.rnd_seed = seed
        if seed is not None:
            # Seeds NumPy's *global* RNG, which all sampling in this class uses.
            np.random.seed(self.rnd_seed)

        self.size = size
        # Each action is a per-axis position increment declared in [-1, 1]
        # (note: ``step`` does not clip it).
        self.action_space = spaces.Box(-1, 1, (2,), dtype=np.float32)

        # observations (ag_x, ag_y, ob_x, ob_y, ob_r, goal_x, goal_y)
        self.observation_shape = (7,)
        self.observation_space = spaces.Box(
            low=-self.size,
            high=self.size,
            shape=self.observation_shape,
            dtype=np.float32,
        )

        self._agent_position = np.zeros(2)
        self._obstacle_position = np.zeros(2)
        # Scalar, like everywhere else, so _get_obs works before the first reset.
        self._obstacle_radius = 0.0
        self._goal_position = np.zeros(2)
        self._goal_radius = 0.05 * self.size

        assert render_mode is None or render_mode in self.metadata["render_modes"]

    def reset(self, seed: int = None, options: Dict = None):
        """Start a new episode and return the initial observation.

        In randomized mode the agent and goal are sampled uniformly in the
        arena (at least ``0.5 * size`` apart) and the obstacle is placed so that
        it covers a point of the agent-goal segment while containing neither
        the agent nor the goal. Unusable layouts are re-sampled.

        Args:
            seed: Optional seed for NumPy's global RNG, applied once per call.
            options: Unused; accepted for Gym API compatibility.

        Returns:
            The observation vector only (old Gym API, no ``info`` dict).
        """
        # Seed exactly once, before any sampling: this makes the fixed layout's
        # jitter reproducible too, and lets re-sampling below draw *new* values
        # instead of replaying the same rejected layout forever.
        if seed is not None:
            np.random.seed(seed)

        if not self.randomize:
            self._agent_position = np.zeros(2) - 6 + np.random.uniform(-0.1, 0.1, 2)
            self._goal_position = np.zeros(2) + 6
            self._obstacle_position = np.zeros(2)
            self._obstacle_radius = 4
        else:
            while not self._sample_random_layout():
                print("Repeating env.reset()")

        if self.render_mode == "human":
            self._render_frame()

        return self._get_obs()

    def _sample_random_layout(self) -> bool:
        """Draw a random agent/goal/obstacle layout into the environment state.

        Returns:
            True if the layout is usable; False if it must be re-sampled (empty
            radius interval, or the obstacle contains the agent or the goal).
        """
        size = self.size
        self._agent_position = np.random.uniform(-size, size, (2,))

        # Re-draw the goal until it is far enough away from the agent.
        min_goal_distance = 0.5 * size
        self._goal_position = np.random.uniform(-size, size, (2,))
        while (
            np.linalg.norm(self._agent_position - self._goal_position)
            < min_goal_distance
        ):
            self._goal_position = np.random.uniform(-size, size, (2,))

        # Random point on the middle part of the agent-goal segment.
        direction = self._goal_position - self._agent_position
        point_temp = self._agent_position + np.random.uniform(0.3, 0.7) * direction

        # Unit normal to the agent-goal segment (non-zero: endpoints are apart).
        normal = np.array([direction[1], -direction[0]])
        normal = normal / np.linalg.norm(normal)

        # Distances from point_temp to the arena border along +normal / -normal.
        a_positive = get_boundary_distance_from_line(point_temp, normal, size)
        a_negative = -get_boundary_distance_from_line(point_temp, -normal, size)

        # Offset the obstacle centre from the segment along the normal.
        obstacle_pos = (
            point_temp + np.random.uniform(a_negative * 0.2, a_positive * 0.2) * normal
        )

        # radius > min_radius  => the obstacle covers point_temp (blocks the
        #                         straight path);
        # radius < max_radius  => agent and goal stay outside the obstacle.
        min_radius = np.linalg.norm(obstacle_pos - point_temp)
        max_radius = min(
            np.linalg.norm(obstacle_pos - self._agent_position),
            np.linalg.norm(obstacle_pos - self._goal_position),
        )

        eps = size / 100
        low, high = min_radius + eps, max_radius - eps
        if low >= high:
            # Empty interval; np.random.uniform is undefined for low > high.
            return False

        self._obstacle_position = obstacle_pos
        self._obstacle_radius = np.random.uniform(low, high)

        # Explicit safeguard on top of the radius bounds above.
        return not (
            circle_contains_point(
                self._obstacle_position, self._obstacle_radius, self._agent_position
            )
            or circle_contains_point(
                self._obstacle_position, self._obstacle_radius, self._goal_position
            )
        )

    def step(self, action: np.ndarray):
        """Apply ``action`` as a position increment and advance one step.

        The action is added as-is (it is not clipped to ``action_space``).

        Returns:
            ``(obs, reward, done, info)`` (old Gym API). The reward is the
            negative goal distance divided by ``size``; it is replaced by
            ``-10 * size`` on collision and by ``10 * size`` on reaching the
            goal (the goal reward wins if both happen in the same step).
        """
        self._agent_position = self._agent_position + action

        dist_to_goal = np.linalg.norm(self._goal_position - self._agent_position)
        reward = -dist_to_goal / self.size
        done = False

        if self._check_collision():
            reward = -10 * self.size
            done = True

        if self._goal_reached():
            reward = 10 * self.size
            done = True

        if self.render_mode == "human":
            self._render_frame()

        return self._get_obs(), reward, done, self._get_info()

    def render(self, mode: str = "human"):
        """Render the current state.

        ``mode`` is ignored; behaviour follows ``self.render_mode``.

        Returns:
            The RGB frame when not rendering to a window, otherwise None.
        """
        return self._render_frame()

    def _goal_reached(self) -> bool:
        """Return True if the agent is strictly inside the goal circle."""
        return bool(
            np.linalg.norm(self._goal_position - self._agent_position)
            < self._goal_radius
        )

    def _check_collision(self) -> bool:
        """Return True if the agent touched/left the arena or is in the obstacle."""
        if (np.abs(self._agent_position) >= self.size).any():
            return True
        return bool(
            circle_contains_point(
                self._obstacle_position, self._obstacle_radius, self._agent_position
            )
        )

    def _get_obs(self) -> np.ndarray:
        """Return ``(ag_x, ag_y, ob_x, ob_y, ob_r, goal_x, goal_y)`` as a 1-D array."""
        return np.concatenate(
            [
                self._agent_position,
                self._obstacle_position,
                # ravel accepts both a scalar and a 1-element array radius.
                np.ravel(self._obstacle_radius),
                self._goal_position,
            ]
        )

    def _get_info(self):
        """Return the auxiliary info dict (current distance to the goal)."""
        return {
            "distance": np.linalg.norm(self._goal_position - self._agent_position),
        }

    def _to_pixels(self, point) -> tuple:
        """Map a world point in ``[-size, size]^2`` to integer window pixels."""
        return (
            int((point[0] + self.size) / (2 * self.size) * self.window_size),
            int((point[1] + self.size) / (2 * self.size) * self.window_size),
        )

    def _length_to_pixels(self, length: float) -> int:
        """Map a world-space length to an integer pixel length."""
        return int(length / (2 * self.size) * self.window_size)

    def _render_frame(
        self,
    ) -> np.ndarray:
        """Draw the scene; show it in the window or return it as an array.

        Returns:
            An ``(H, W, 3)`` RGB array when ``render_mode`` is not ``"human"``,
            otherwise None (the frame is shown in the pygame window).
        """
        # Lazily create the pygame window and clock for human rendering.
        if self.render_mode == "human":
            if self.window is None:
                pygame.init()
                pygame.display.init()
                self.window = pygame.display.set_mode(
                    (self.window_size, self.window_size)
                )
                pygame.display.set_caption("Seeker")
            if self.clock is None:
                self.clock = pygame.time.Clock()

        canvas = pygame.Surface((self.window_size, self.window_size))
        canvas.fill((255, 255, 255))

        # Agent: fixed-size 5 px dot.
        pygame.draw.circle(
            canvas, (0, 0, 255), self._to_pixels(self._agent_position), 5
        )

        # Obstacle and goal: radii scaled from world units.
        pygame.draw.circle(
            canvas,
            (255, 0, 0),
            self._to_pixels(self._obstacle_position),
            self._length_to_pixels(self._obstacle_radius),
        )
        pygame.draw.circle(
            canvas,
            (0, 255, 0),
            self._to_pixels(self._goal_position),
            self._length_to_pixels(self._goal_radius),
        )

        if self.render_safe_input_set:
            template = (
                self.template_input_set
                if self.template_input_set is not None
                else Zonotope.from_unit_box(2)
            )
            safe_input_set = self.calc_safe_input_set_zono(template, mode="vol_max")

            # Shift from input (displacement) space to next-position space.
            safe_input_set.c += self._agent_position[:, np.newaxis]
            safe_input_set_polygon = np.array(
                [self._to_pixels(point) for point in safe_input_set.vertices.T]
            )
            safe_input_set_polygon = order_vertices_clockwise(safe_input_set_polygon)
            pygame.draw.polygon(canvas, (0, 0, 0), safe_input_set_polygon, 1)

        if self.render_mode == "human":
            self.window.blit(canvas, canvas.get_rect())
            pygame.event.pump()
            pygame.display.update()
            # Limit the frame rate.
            self.clock.tick(self.metadata["render_fps"])
            return None

        # rgb_array (or no render mode): pixels3d is (W, H, 3); transpose to (H, W, 3).
        return np.transpose(
            np.array(pygame.surfarray.pixels3d(canvas)), axes=(1, 0, 2)
        )

    def close(self):
        """Close the render window, if one was opened."""
        if self.window is not None:
            pygame.display.quit()
            pygame.quit()
            # Drop the dead handles so a later render re-creates them.
            self.window = None
            self.clock = None

    def calc_safe_input_set_zono(
        self, tU: Zonotope, mode: str = "vol_max", debug: bool = False
    ) -> Zonotope:
        """Compute a safe input set as a scaled version of the template ``tU``.

        Solves a convex program for a zonotope ``U = <c_U, G_U>`` whose
        generators are the template generators ``tU.G`` scaled column-wise by
        non-negative factors ``p`` (the template centre is not used), subject to

        * ``U`` inside the action box ``[-1, 1]^2``,
        * the next-state set ``x + U`` inside the arena ``[-size, size]^2``,
        * ``x + U`` inside the half-space bounded by the obstacle's tangent
          line facing the agent (so it can touch the obstacle at most on its
          boundary).

        Args:
            tU: Template input zonotope; only its generator matrix is used.
            mode: Objective; only ``"vol_max"`` (maximise the geometric mean of
                the scalings ``p``) is implemented.
            debug: If True, show a matplotlib plot of the resulting state set,
                the obstacle and the half-space.

        Returns:
            The safe input set in input (displacement) coordinates.

        Raises:
            NotImplementedError: For an unknown ``mode``.
            ValueError: If the problem is not solved to optimality, or the
                agent sits exactly at the obstacle centre.
        """
        if mode != "vol_max":
            raise NotImplementedError(f"Mode {mode} not implemented")

        n_U = 2
        nG_U = tU.G.shape[1]

        c_U = cp.Variable(n_U)
        G_U = cp.Variable((n_U, nG_U))

        # Scaling factors for the template generators: G_U = tU.G @ diag(p).
        p = cp.Variable((nG_U, 1), nonneg=True)
        c0 = G_U - tU.G @ cp.diag(p) == 0

        # Per-axis half-width of a zonotope <c, G> is sum_j |G_ij|.
        half_width_U = cp.sum(cp.abs(G_U), axis=1)

        # U_phi subset of U_F: input set inside the action box [-1, 1]^2.
        c11 = c_U + half_width_U <= 1
        c12 = -c_U + half_width_U <= 1

        # X_1 subset of X_s: next-state set inside the arena [-size, size]^2.
        c_1 = self._agent_position + c_U
        c21 = c_1 + half_width_U <= self.size
        c22 = -c_1 + half_width_U <= self.size

        # X_1 cap Obstacle = {}: require X_1 subset of {x | a^T x <= b}, where
        # a points from the agent to the obstacle centre.
        a = self._obstacle_position - self._agent_position
        a_norm = np.linalg.norm(a)
        if a_norm == 0:
            # Direction undefined (agent at obstacle centre).
            raise ValueError("Safe input set cannot be calculated")
        a = a / a_norm
        b = np.dot(a, self._obstacle_position - a * self._obstacle_radius)
        # Support function of X_1 in direction a must not exceed b.
        c3 = a @ c_1 + cp.sum(cp.abs(a @ G_U)) <= b

        constraints = [c0, c11, c12, c21, c22, c3]
        objective = cp.Maximize(cp.geo_mean(p))

        problem = cp.Problem(objective, constraints)
        problem.solve(cp.CLARABEL)

        # Check before touching the solution: on failure .value is None.
        if problem.status not in [cp.OPTIMAL]:
            raise ValueError("Safe input set cannot be calculated")

        if debug:
            self._plot_safe_input_set(G_U.value, c_U.value, a, b)

        return Zonotope(G_U.value, c_U.value)

    def _plot_safe_input_set(
        self, G: np.ndarray, c: np.ndarray, a: np.ndarray, b: float
    ) -> None:
        """Debug plot: next-state set, obstacle, agent and half-space a^T x <= b."""
        import matplotlib.pyplot as plt

        # Work on copies so shifting the centre cannot alias the solver values.
        state_set = Zonotope(np.copy(G), np.copy(c))
        state_set.c += self._agent_position[:, np.newaxis]

        _, ax = plt.subplots()
        vertices = state_set.vertices
        ax.plot(vertices[0, :], vertices[1, :], "k", label="state set")

        # Tangent point of the half-space boundary on the obstacle.
        point = self._obstacle_position - self._obstacle_radius * a
        ax.plot(point[0], point[1], "ro", label="Point")

        # Boundary line of {x | a^T x <= b}, clipped to the arena.
        if np.isclose(a[1], 0.0):
            # Vertical line x = b / a[0]; a is a unit vector, so a[0] != 0.
            x = np.full(2, b / a[0])
            y = np.array([-self.size, self.size])
        else:
            x = np.linspace(-self.size, self.size, 100)
            y = (b - a[0] * x) / a[1]
            mask = (y >= -self.size) & (y <= self.size)
            x, y = x[mask], y[mask]
        ax.plot(x, y, "r", label="Halfspace")

        # Half-space normal a, drawn at the tangent point.
        ax.quiver(point[0], point[1], a[0], a[1], color="g", label="a")

        circle = plt.Circle(
            self._obstacle_position,
            self._obstacle_radius,
            color="r",
            fill=False,
            label="Obstacle",
        )
        ax.add_artist(circle)

        ax.plot(
            self._agent_position[0], self._agent_position[1], "bo", label="Agent"
        )

        ax.set_aspect("equal")
        plt.legend()
        plt.show()


if __name__ == "__main__":
    # Manual demo: show the safe-input-set debug plot, then render frames.
    env = SeekerCircleEnv(render_mode="human", render_safe_input_set=True, seed=1)
    try:
        env.reset()
        env.calc_safe_input_set_zono(Zonotope.from_unit_box(2), debug=True)
        for _ in range(1000):
            env.render()
    finally:
        # Release the pygame window even if rendering fails or is interrupted.
        env.close()
