import numpy as np
import heapq
from scipy.interpolate import RegularGridInterpolator

def _cell_centers(maze_size, resolution):
    """Return the coordinate of each grid-cell centre along one axis.

    Evaluated as ``j * cell_size - maze_size/2 + cell_size/2`` (same operation
    order as the per-cell formula) so obstacle classification is reproducible.
    """
    cell_size = maze_size / resolution
    return np.arange(resolution) * cell_size - maze_size / 2 + cell_size / 2


def _u_maze_free_mask(centers):
    """Return a boolean grid indexed ``[row=y, col=x]`` that is True on free cells.

    A cell is blocked when its centre lies inside the square [-0.5, 0.5]^2
    but not inside the open quadrant ``x <= 0 and y >= 0``.
    """
    ys, xs = np.meshgrid(centers, centers, indexing="ij")
    in_block = (-0.5 <= xs) & (xs <= 0.5) & (-0.5 <= ys) & (ys <= 0.5)
    in_opening = (xs <= 0) & (ys >= 0)
    return ~(in_block & ~in_opening)


def _grid_dijkstra(free, source):
    """Return 8-connected shortest-path lengths (in cell units) from ``source``.

    Orthogonal steps cost 1, diagonal steps cost sqrt(2), and only cells whose
    ``free`` entry is True are entered. Cells never reached remain ``inf``.
    """
    rows, cols = free.shape
    diag = float(np.sqrt(2.0))
    steps = (
        (0, 1, 1.0), (1, 0, 1.0), (0, -1, 1.0), (-1, 0, 1.0),
        (1, 1, diag), (1, -1, diag), (-1, 1, diag), (-1, -1, diag),
    )
    dist = np.full(free.shape, np.inf)
    si, sj = source
    dist[si, sj] = 0.0
    heap = [(0.0, si, sj)]  # a single-element list is already a valid heap
    while heap:
        d, i, j = heapq.heappop(heap)
        if d > dist[i, j]:
            continue  # stale queue entry; a shorter path was already settled
        for di, dj, cost in steps:
            ni, nj = i + di, j + dj
            if 0 <= ni < rows and 0 <= nj < cols and free[ni, nj]:
                nd = d + cost
                if nd < dist[ni, nj]:
                    dist[ni, nj] = nd
                    heapq.heappush(heap, (nd, ni, nj))
    return dist


def create_distance_map(maze_size=3.0, resolution=100, goal_pos=(1.0, 1.0)):
    """
    Create a distance map that gives the shortest grid-path distance to the goal
    for any point in the U-maze.

    The maze spans [-maze_size/2, maze_size/2] on both axes and is discretised
    into ``resolution x resolution`` cells. A cell is an obstacle when its
    centre lies in the square [-0.5, 0.5]^2 outside the open quadrant
    ``x <= 0, y >= 0``. Distances are computed with 8-connected Dijkstra from
    the goal's cell and then interpolated bilinearly between cell centres,
    using only reachable corners so points beside a wall stay finite.

    Args:
        maze_size: Size of the maze (from -maze_size/2 to maze_size/2)
        resolution: Number of cells in each dimension (must be >= 2 for the
            linear interpolator)
        goal_pos: The goal position (x, y)

    Returns:
        A function mapping an (x, y) position to a float path distance in maze
        units. It returns ``inf`` for positions outside the maze and for
        positions whose surrounding cell centres are all unreachable.
    """
    cell_size = maze_size / resolution
    half = maze_size / 2
    centers = _cell_centers(maze_size, resolution)
    free = _u_maze_free_mask(centers)

    def coord_to_idx(x, y):
        i = int((y + half) / cell_size)
        j = int((x + half) / cell_size)
        return max(0, min(i, resolution - 1)), max(0, min(j, resolution - 1))

    # Grid distances in cell units, converted to maze units.
    distance_grid = _grid_dijkstra(free, coord_to_idx(*goal_pos)) * cell_size

    # Renormalised bilinear interpolation: interpolate the distances (with
    # unreachable cells zeroed) and the reachability mask on the same grid,
    # then divide. This averages only over reachable corners instead of
    # letting a single inf corner turn the whole result into inf/NaN.
    reachable = np.isfinite(distance_grid)
    value_interp = RegularGridInterpolator(
        (centers, centers), np.where(reachable, distance_grid, 0.0)
    )
    weight_interp = RegularGridInterpolator(
        (centers, centers), reachable.astype(float)
    )
    lo, hi = centers[0], centers[-1]

    def distance(pos):
        x, y = float(pos[0]), float(pos[1])
        if not (-half <= x <= half and -half <= y <= half):
            return float("inf")  # outside the maze
        # Points within half a cell of the border lie outside the centre grid;
        # clamp them onto it (nearest-edge value) rather than extrapolating.
        query = np.array([[min(max(y, lo), hi), min(max(x, lo), hi)]])
        weight = float(weight_interp(query)[0])
        if weight <= 1e-12:
            return float("inf")  # no reachable cell contributes here
        return float(value_interp(query)[0]) / weight

    return distance

def custom_reward_function(state, goal_position, distance_map, prev_distance=None,
                           reward_scale=1.0, reach_reward=10.0, reach_threshold=0.2,
                           *, velocity_penalty=0.0):
    """
    Custom dense reward function for U-maze.

    The dense reward is the decrease in path distance since the previous call,
    scaled by ``reward_scale``. Reaching the goal (Euclidean distance within
    ``reach_threshold``) returns ``reach_reward`` instead.

    Args:
        state: The current state [x, y, vx, vy]
        goal_position: The goal position (x, y)
        distance_map: Function mapping positions to path distances
        prev_distance: Previous distance to goal (for measuring progress)
        reward_scale: Scaling factor for the dense reward
        reach_reward: Bonus for reaching the goal
        reach_threshold: Distance threshold for reaching the goal
        velocity_penalty: Optional coefficient c; when non-zero,
            ``c * (vx**2 + vy**2)`` is subtracted from the dense reward
            (requires ``state`` to have at least four entries)

    Returns:
        reward: The computed reward
        current_distance: Current distance to the goal (for next call)
    """
    position = np.array([state[0], state[1]], dtype=float)

    # Get the path distance to goal
    current_distance = distance_map(position)

    # Check if goal reached
    goal = np.asarray(goal_position, dtype=float)
    if np.linalg.norm(position - goal) <= reach_threshold:
        return reach_reward, current_distance

    # First call: no previous distance to measure progress against
    if prev_distance is None:
        return 0.0, current_distance

    # Dense reward based on progress in path distance
    progress = prev_distance - current_distance
    if not np.isfinite(progress):
        # The map reports inf for unreachable/out-of-bounds positions; a raw
        # difference would be +/-inf or NaN and would poison learning.
        return 0.0, current_distance
    reward = progress * reward_scale

    if velocity_penalty:
        vx, vy = state[2], state[3]
        reward -= velocity_penalty * (vx * vx + vy * vy)

    return reward, current_distance