import torch
import numpy as np
import random
import math
from queue import PriorityQueue
from scipy.interpolate import CubicSpline
from utils.vis_utils import lpf_1d
from utils.joint2hml import map_rootpos_to_norm_rootvel


# A* algorithm function
def _plan_astar(costs, start, goal):
    """Search a 2-D cost grid for a cheap 4-connected path using A*.

    Moving into a cell costs ``costs[cell]``; the heuristic is the Euclidean
    distance (in cells) to the goal. The heuristic only guarantees the
    cheapest path when every cell cost is at least 1.

    Args:
        costs: 2-D array indexed by an ``(i, j)`` tuple and exposing ``.shape``.
        start: ``(i, j)`` index pair of the start cell.
        goal: ``(i, j)`` index pair of the goal cell.

    Returns:
        ``(path, total_cost)`` where ``path`` is the list of cells from the
        one after ``start`` up to and including ``goal`` (the start cell is
        not part of the path; it is empty when ``start == goal``) and
        ``total_cost`` is the accumulated cost of those cells.
        The string ``"not found"`` is returned when the goal cannot be
        reached, e.g. because it lies outside the grid.
    """
    grid_shape = costs.shape
    start_point = tuple(start)
    goal_point = tuple(goal)
    goal_array = np.array(goal_point)

    def heuristic(point):
        return np.linalg.norm(np.array(point) - goal_array)

    def get_neighbors(point):
        neighbors = []
        for step in [(0, -1), (0, 1), (-1, 0), (1, 0)]:  # 4-connected moves
            node_position = (point[0] + step[0], point[1] + step[1])
            inside = (0 <= node_position[0] < grid_shape[0]
                      and 0 <= node_position[1] < grid_shape[1])
            if inside:
                neighbors.append(node_position)
        return neighbors

    # Priority queue of (f_score, cell) pairs still to be explored.
    open_list = PriorityQueue()
    open_list.put((0, start_point))
    came_from = {}                       # cell -> predecessor on the best known path
    g_score = {start_point: 0}           # best known cost from start to each cell
    f_score = {start_point: heuristic(start_point)}

    while not open_list.empty():
        queued_f, current = open_list.get()

        if current == goal_point:
            # Read the cost BEFORE backtracking: the walk below ends with
            # `current` at the start cell, whose g_score is always 0.
            total_cost = g_score[goal_point]
            path = []
            while current in came_from:
                path.append(current)
                current = came_from[current]
            return path[::-1], total_cost

        if queued_f > f_score[current]:
            # Stale entry: a cheaper route to this cell was queued (and
            # expanded) after this one was pushed, so expanding it again
            # cannot improve anything.
            continue

        for neighbor in get_neighbors(current):
            tentative_g_score = g_score[current] + costs[neighbor]
            if neighbor not in g_score or tentative_g_score < g_score[neighbor]:
                came_from[neighbor] = current
                g_score[neighbor] = tentative_g_score
                f_score[neighbor] = tentative_g_score + heuristic(neighbor)
                open_list.put((f_score[neighbor], neighbor))

    return "not found"


def _block_passed_area(costs, path, costinc=10):
    """Raise the cost of the 5x5 neighbourhood around the middle of a path.

    The first and last 15 cells of ``path`` are left untouched. Every grid
    cell within two steps (per axis) of a remaining path cell is increased
    by ``costinc`` exactly once, however many path cells it is close to.

    ``costs`` is modified in place and also returned.
    """
    offsets = (-2, -1, 0, 1, 2)
    n_rows, n_cols = costs.shape[0], costs.shape[1]

    # Collect cells in a set first so overlapping neighbourhoods count once.
    touched = set()
    for px, py in path[15:-15]:
        for dx in offsets:
            cx = px + dx
            if cx < 0 or cx >= n_rows:
                continue
            for dy in offsets:
                cy = py + dy
                if cy < 0 or cy >= n_cols:
                    continue
                touched.add((cx, cy))

    for cx, cy in touched:
        costs[cx, cy] += costinc

    return costs


def _sample_per_frame_pos(future_traj, future_frames):
    """Pick ``future_frames`` evenly spaced points along ``future_traj``.

    Frame ``t`` takes the point at index ``round(t * N / future_frames)``
    where ``N`` is the number of points. For short paths (``N`` at most half
    of ``future_frames``) that rounded index can reach ``N``, so it is
    clamped to the last valid index instead of raising ``IndexError``.

    Args:
        future_traj: sequence of points; a tensor, or anything
            ``torch.as_tensor`` accepts (e.g. a list of ``(i, j)`` tuples).
        future_frames: number of samples to return.

    Returns:
        Tensor with ``future_frames`` entries along the first dimension.
    """
    future_traj = torch.as_tensor(future_traj)
    n_points = future_traj.shape[0]
    last_index = n_points - 1
    idx = torch.LongTensor([
        min(int(round(t * n_points / future_frames)), last_index)
        for t in range(future_frames)
    ])
    return future_traj[idx]


def plan_trajectory(scene_height, observed_traj, destination, input_frames, output_frames,
                    max_candidates=20):
    """Plan a future 2-D trajectory towards ``destination`` for every batch element.

    For each element, A* is run repeatedly on a cost map derived from
    ``scene_height``. After every run the area around the found path is made
    more expensive so the next run may choose another route. Routes that are
    fewer than ``len_th`` steps longer than the shortest one found so far are
    kept as candidates; one candidate is drawn uniformly at random and
    resampled to ``output_frames`` positions.

    NOTE(review): the following is inferred from how the values are used and
    should be confirmed against the callers:
      * ``scene_height[b]`` is a 2-D map and all positions are expressed in
        its index coordinates;
      * ``observed_traj[b]`` is either one position ``[2]`` or a sequence
        ``[input_frames, 2]`` whose last row is the current position.

    Args:
        scene_height: batch of height maps; cost is ``1.03 ** (220 - height)``
            with Gaussian noise added to the exponent.
        observed_traj: batch of observed positions (see note above).
        destination: batch of goal positions.
        input_frames: number of observed frames copied into the result.
        output_frames: number of planned frames appended after them.
        max_candidates: upper bound on A* runs per batch element, so the
            search ends even when penalising the path never produces a route
            that is ``len_th`` steps longer.

    Returns:
        Float tensor ``[bs, input_frames + output_frames, 2]``.

    Raises:
        ValueError: if ``max_candidates`` is smaller than 1.
        RuntimeError: if no path to the destination exists for an element.
    """
    assert len(scene_height) == len(observed_traj) == len(destination)
    if max_candidates < 1:
        raise ValueError("max_candidates must be at least 1")

    len_th = 1000
    costs = 1.03 ** (220 - scene_height + np.random.randn(*scene_height.shape) * 3)
    bs = len(scene_height)
    res = torch.zeros(bs, input_frames + output_frames, 2)

    def to_cell(position):
        # A* uses the cells as dict keys and grid indices, so they must be
        # plain Python ints (tensor elements hash by identity).
        return tuple(int(round(float(v))) for v in position)

    for b in range(bs):
        observed = torch.as_tensor(observed_traj[b])
        start = to_cell(observed if observed.dim() == 1 else observed[-1])
        goal = to_cell(destination[b])
        cost_map = costs[b]

        min_len = float("inf")
        future_trajs = []
        for _ in range(max_candidates):
            planned = _plan_astar(cost_map, start, goal)
            if isinstance(planned, str):  # "not found" sentinel
                break
            cur_traj = planned[0]  # _plan_astar returns (path, cost)
            cur_len = len(cur_traj)
            if cur_len >= min_len + len_th:
                break
            future_trajs.append(cur_traj)
            min_len = min(min_len, cur_len)
            # Penalise the cost map (not the height map) so the next search
            # is pushed towards a different route.
            cost_map = _block_passed_area(cost_map, cur_traj)

        if not future_trajs:
            raise RuntimeError(f"no path found from {start} to {goal} for batch element {b}")

        # random.choice never indexes past the end, unlike randint(0, len).
        # An empty path means start == goal: stay on the goal cell.
        future_traj = random.choice(future_trajs) or [goal]
        future_traj = torch.tensor(future_traj, dtype=res.dtype)
        res[b, :input_frames] = observed
        res[b, input_frames:] = _sample_per_frame_pos(future_traj, output_frames)

    return res


def calculate_overwrite(traj, refering_joints, refering_r_vel, input_frames, output_frames):
    """Derive per-frame rotational and linear root velocities from a planned trajectory.

    The heading is estimated from displacements over 10-frame windows,
    unwrapped so consecutive headings stay close, differenced, damped where
    the speed is low, upsampled with ``bezier_11to101`` and prepended with the
    reference rotational velocity. The cumulative heading is turned into a
    4-component rotation which, together with the padded trajectory, is
    passed to ``map_rootpos_to_norm_rootvel``.

    Args:
        traj: ``[bs, T, 2]`` positions. Frame index 158 is read, so ``T`` must
            be at least 159. NOTE(review): the shape notes in this function
            mention 160 both for ``traj`` and for ``traj`` padded by one
            frame — confirm the real length.
        refering_joints: reference joints with leading dimension 1; only
            ``refering_joints[0]`` is forwarded to the helper.
        refering_r_vel: reference rotational velocity with leading dimension 1,
            repeated over the batch. Presumably ``[1, input_frames - 1]`` so
            that the concatenation below has ``input_frames + output_frames``
            entries — TODO confirm.
        input_frames: number of observed frames.
        output_frames: number of planned frames.

    Returns:
        ``torch.cat([r_velocity, l_velocity], dim=-1)`` where ``r_velocity``
        has a trailing dimension of 1; the layout of ``l_velocity`` is
        whatever ``map_rootpos_to_norm_rootvel`` returns.
    """
    bs = len(traj)                          # traj: [bs, T, 2] (see docstring about T)
    v_th = 0.8                              # speed at which the heading change is fully trusted
    assert refering_joints.shape[0] == 1 and refering_r_vel.shape[0] == 1

    # Displacement over 10-frame windows ending at frames 58, 68, ..., 158.
    d_traj = traj[:, 58::10] - traj[:, 48:158:10]
    # Append the last window (148 -> 158) once more to get 12 entries.
    d_traj = torch.cat([d_traj, (traj[:, 158] - traj[:, 148]).unsqueeze(1)], dim=1)  # bs, 12, 2
    vel = torch.norm(d_traj, p=2, dim=-1)  # bs, 12
    # Heading from the displacement components; 1e-6 avoids division by zero.
    rot = -torch.arctan(d_traj[:, :, 0] / (d_traj[:, :, 1] + 1e-6))  # bs, 12
    # arctan only covers half a turn, so for every step pick among
    # rot, rot + pi and rot - pi the value closest to the previous
    # (already adjusted) heading.
    for t in range(1, 12):
        diff = rot[:, t - 1] - rot[:, t]
        diffs = torch.stack([diff.abs(), (diff - math.pi).abs(), (diff + math.pi).abs()])
        min_indices = torch.argmin(diffs, dim=0)
        rot[:, t] = torch.stack([rot[:, t], rot[:, t] + math.pi, rot[:, t] - math.pi], dim=1)[
            torch.arange(bs), min_indices]

    # In some cases, the rotation angle changes aggressively, but the speed is very low.
    # Thus the r_vel of root position can not represent that of the human facing.
    # The mean speed of neighbouring windows, clamped to v_th and divided by v_th
    # (so at most 1), scales the heading change down in those cases.
    confidence = torch.clamp((vel[:, 1:] + vel[:, :-1]) / 2, max=v_th) / v_th
    rot_vel = confidence * (rot[:, 1:] - rot[:, :-1])  # bs, 11

    # Upsample 11 -> 101 values per sample; the value is duplicated into two
    # columns because bezier_11to101 expects 2-D points. The division by 10
    # presumably converts a per-10-frame change into a per-frame one.
    rot_vel = torch.stack([bezier_11to101(rot_vel[i, :, None].repeat(1, 2))[:, 0] / 10 for i in range(bs)], dim=0)      # bs, 101
    refering_r_vel = refering_r_vel.repeat(bs, 1)
    # Reference rotational velocity followed by the last output_frames planned values.
    rot_vel = torch.cat([refering_r_vel, rot_vel[:, -output_frames:]], dim=1)                # bs, 159

    # Integrate the rotational velocity, starting from angle 0 at the first frame.
    r_rot_ang = torch.zeros(bs, input_frames+output_frames+1).to(traj.device)               # bs, 160
    r_rot_ang[..., 1:] = rot_vel
    r_rot_ang = torch.cumsum(r_rot_ang, dim=-1)

    # Components 0 and 2 hold cos / sin of the angle; components 1 and 3 stay 0.
    r_rot_quat = torch.zeros(r_rot_ang.shape + (4,)).to(traj.device)                # [bs, 160, 4]
    r_rot_quat[..., 0] = torch.cos(r_rot_ang)
    r_rot_quat[..., 2] = torch.sin(r_rot_ang)

    # Repeat the final position once so the trajectory has one more frame.
    traj_pad = torch.cat([traj, traj[:, -1:]], dim=1)                        # [bs, 160, 2]
    r_velocity = torch.Tensor(rot_vel)[:, :, None]
    l_velocity = torch.stack([
        map_rootpos_to_norm_rootvel(traj_pad[i], refering_joints[0], r_rot_quat[i], output_frames)
        for i in range(bs)], dim=0)

    return torch.cat([r_velocity, l_velocity], dim=-1)


def _interplot_11to101(points):
    """Linearly densify ``points`` by inserting 9 values between neighbours.

    ``n`` input points become ``10 * (n - 1) + 1`` output points (11 -> 101);
    every original point is kept. The result is returned as a NumPy array.
    """
    gap = 9  # number of points inserted between two consecutive inputs
    dense = []
    for left, right in zip(points[:-1], points[1:]):
        delta = (right - left) / (gap + 1)
        dense.append(left)
        dense.extend(left + k * delta for k in range(1, gap + 1))

    # The loop only emits the left end of each segment, so close with the last point.
    dense.append(points[-1])

    return np.array(dense)


def bezier_11to101(points):
    """Smooth and upsample a 2-D polyline (11 -> 101 points for 11 inputs).

    Despite the name, the curve is built from cubic splines: each coordinate
    is fitted as a function of the low-pass-filtered cumulative distance
    along ``points`` and then evaluated on a 10x denser parameter grid.

    Args:
        points: ``[n, 2]`` array-like of 2-D points.

    Returns:
        ``torch.Tensor`` with ``10 * (n - 1) + 1`` rows and 2 columns.
    """
    # Cumulative Euclidean length along the polyline, starting at 0.
    segment_len = np.sqrt(np.sum(np.diff(points, axis=0) ** 2, axis=1))
    arc = np.insert(np.cumsum(segment_len), 0, 0)
    # Tiny increasing ramp, presumably so repeated points still yield distinct
    # parameter values for the spline fit.
    arc += np.linspace(0, 0.0001, len(arc))
    arc = lpf_1d(torch.Tensor(arc)).detach().numpy()

    # One spline per coordinate, parameterised by the filtered distance.
    spline_x = CubicSpline(arc, points[:, 0])
    spline_y = CubicSpline(arc, points[:, 1])

    # Evaluate on the densified parameter grid; the original knots are kept.
    dense_arc = _interplot_11to101(arc)
    smooth_points = np.stack((spline_x(dense_arc), spline_y(dense_arc)), axis=1)

    return torch.Tensor(smooth_points)
