import numpy as np
from d4rl.pointmaze import q_iteration
from d4rl.pointmaze.gridcraft import grid_env
from d4rl.pointmaze.gridcraft import grid_spec


# Shared 2-D constant vectors (float32). Treat them as read-only: they are
# module-level objects, so in-place mutation would affect every user.
ZEROS = np.zeros((2,), dtype=np.float32)
# Must be all ones: scaled copies such as ``-1000 * ONES`` are used as
# "no target yet" placeholders and would collapse to the origin otherwise.
ONES = np.ones((2,), dtype=np.float32)


class WaypointController:
    """PD controller that steers a Maze2d agent along a planned grid path.

    Whenever the requested target falls into a different grid cell than the
    one currently planned for, a new waypoint list is computed with
    Q-iteration on the grid environment built from ``maze_str``.
    ``get_action`` then tracks the waypoints one after another.
    """

    def __init__(self, maze_str, solve_thresh=0.1, p_gain=10.0, d_gain=-1.0, action_ub=0.1):
        """
        Waypoint controller for Maze2d environments.
        Args:
            maze_str (str): String representation of the maze.
            solve_thresh (float): Distance threshold for reaching waypoints.
            p_gain (float): Proportional gain for controller.
            d_gain (float): Derivative gain for controller (multiplies the
                velocity passed to ``get_action``).
            action_ub (float): Magnitude bound for the action; every action
                component is clipped to ``[-action_ub, action_ub]``.
        """
        self.maze_str = maze_str
        # Placeholder "no target yet" value. It is built explicitly so that it
        # really is (-1000, -1000) and does not match a gridified target.
        self._target = np.full((2,), -1000.0, dtype=np.float32)

        self.p_gain = p_gain
        self.d_gain = d_gain
        self.solve_thresh = solve_thresh
        # Per-step displacement below which the agent counts as "at rest".
        self.vel_thresh = 0.1
        self.action_ub = action_ub

        self._waypoint_idx = 0
        self._waypoints = []
        # Private copy so the module-level constant is never aliased.
        self._waypoint_prev_loc = ZEROS.copy()

        self.env = grid_env.GridEnv(grid_spec.spec_from_string(maze_str))

    def current_waypoint(self):
        """
        Return the waypoint that is currently being tracked.
        Raises:
            IndexError: If no path has been planned yet.
        """
        return self._waypoints[self._waypoint_idx]

    def get_action(self, location, velocity, target):
        """
        Generate an action based on current location, velocity, and target.
        Args:
            location (array-like): Current agent position (x, y).
            velocity (array-like): Current agent velocity (x, y).
            target (array-like): Target location (x, y).
        Returns:
            action (np.ndarray): Clipped PD action.
            done (bool): Whether the task is completed, i.e. the agent is
                within ``solve_thresh`` of the target and moved less than
                ``vel_thresh`` since the previous call.
        """
        # Accept any array-like input; arithmetic below needs ndarrays.
        location = np.asarray(location)
        velocity = np.asarray(velocity)

        # Replan when there is no plan yet or the target moved to another cell.
        target_cell = np.array(self.gridify_state(target))
        if not self._waypoints or np.linalg.norm(self._target - target_cell) > 1e-3:
            self._new_target(location, target)

        dist = np.linalg.norm(location - self._target)
        # Displacement since the previous call, used as a "has stopped" test.
        vel = self._waypoint_prev_loc - location
        vel_norm = np.linalg.norm(vel)
        task_not_solved = (dist >= self.solve_thresh) or (vel_norm >= self.vel_thresh)

        if task_not_solved:
            next_wpnt = self._waypoints[self._waypoint_idx]
        else:
            next_wpnt = self._target

        # Compute control
        prop = next_wpnt - location
        action = self.p_gain * prop + self.d_gain * velocity

        # Apply action constraints
        action = np.clip(action, -self.action_ub, self.action_ub)

        dist_next_wpnt = np.linalg.norm(location - next_wpnt)
        reached_wpnt = (dist_next_wpnt < self.solve_thresh) and (vel_norm < self.vel_thresh)
        last_idx = len(self._waypoints) - 1
        # Never advance past the final waypoint: the index must stay valid for
        # the lookup above and for current_waypoint().
        if task_not_solved and reached_wpnt and self._waypoint_idx < last_idx:
            self._waypoint_idx += 1
            if self._waypoint_idx == last_idx:
                # Sanity check: a complete plan ends at the target.
                assert np.linalg.norm(self._waypoints[self._waypoint_idx] - self._target) <= self.solve_thresh

        # Keep a copy so later in-place changes by the caller cannot alter it.
        self._waypoint_prev_loc = location.copy()
        return action, (not task_not_solved)

    def gridify_state(self, state):
        """
        Convert continuous state into discrete grid indices.
        Args:
            state (array-like): Continuous state; only the first two entries
                are used.
        Returns:
            (tuple): Grid indices (x, y), each rounded to the nearest integer.
        """
        return (int(round(state[0])), int(round(state[1])))

    def _new_target(self, start, target):
        """
        Recompute waypoints from start to target.

        The target cell is temporarily marked as a reward cell, Q-values are
        computed, and the greedy policy is rolled out from the start cell for
        at most 100 steps. Controller state is only updated once planning has
        succeeded.
        Args:
            start (array-like): Start position.
            target (array-like): Target position.
        """
        start = self.gridify_state(start)
        start_idx = self.env.gs.xy_to_idx(start)
        target = self.gridify_state(target)
        target_idx = self.env.gs.xy_to_idx(target)

        self.env.gs[target] = grid_spec.REWARD
        try:
            q_values = q_iteration.q_iteration(env=self.env, num_itrs=50, discount=0.99)

            max_ts = 100
            s = start_idx
            waypoints = []
            for _ in range(max_ts):
                a = np.argmax(q_values[s])
                new_s, _ = self.env.step_stateless(s, a)

                waypoint = self.env.gs.idx_to_xy(new_s)
                if new_s != target_idx:
                    # Jitter intermediate waypoints by up to 0.2 per axis
                    # (toward smaller coordinates); the final one stays exact.
                    waypoint = waypoint - np.random.uniform(size=(2,)) * 0.2
                waypoints.append(waypoint)
                s = new_s
                if new_s == target_idx:
                    break
        finally:
            # Always undo the temporary reward marking, even if planning
            # raised, so the shared grid is not left modified.
            self.env.gs[target] = grid_spec.EMPTY

        self._waypoint_idx = 0
        self._waypoints = waypoints
        # Stored as arrays so later arithmetic does not depend on tuple
        # broadcasting.
        self._waypoint_prev_loc = np.array(start)
        self._target = np.array(target)