import numpy as np
from shapely.geometry import Point, Polygon, MultiPolygon


def add_score(translations_pred, rotations_pred, sizes_pred, translations_gt, rotations_gt, sizes_gt):
    """
    Instance-wise ADD score.

    For every timestep the predicted and ground-truth boxes are expanded into
    their 8 corners, the Euclidean distance between each pair of corresponding
    corners is computed, and those 8 distances are averaged. The score is the
    mean of the per-timestep averages over the whole trajectory.

    Args:
        translations_pred: per-timestep predicted (x, y, z).
        rotations_pred: per-timestep predicted (roll, pitch, yaw).
        sizes_pred: per-timestep predicted sizes, unpacked as (height, width, length).
        translations_gt: per-timestep ground-truth (x, y, z).
        rotations_gt: per-timestep ground-truth (roll, pitch, yaw).
        sizes_gt: per-timestep ground-truth sizes, unpacked as (height, width, length).

    Returns:
        The ADD score averaged over corners and timesteps.
    """
    mean_distances = []

    # Iterate over timesteps (the number of timesteps is taken from the prediction)
    for i in range(len(translations_pred)):
        x_pred, y_pred, z_pred = translations_pred[i]
        roll_pred, pitch_pred, yaw_pred = rotations_pred[i]
        height_pred, width_pred, length_pred = sizes_pred[i]
        x_gt, y_gt, z_gt = translations_gt[i]
        roll_gt, pitch_gt, yaw_gt = rotations_gt[i]
        height_gt, width_gt, length_gt = sizes_gt[i]

        # Corners of the g.t. and predicted boxes, one corner per row
        corners_gt = get_rotated_bounding_box(x_gt, y_gt, z_gt, width_gt, length_gt, height_gt, roll_gt, pitch_gt, yaw_gt)
        corners_pred = get_rotated_bounding_box(x_pred, y_pred, z_pred, width_pred, length_pred, height_pred, roll_pred, pitch_pred, yaw_pred)

        # Distance between corresponding corners: reduce over the coordinate
        # axis (axis=1) so that we get one distance per corner. Reducing over
        # axis=0 would instead mix all corners into one norm per coordinate.
        distance = np.linalg.norm(corners_gt - corners_pred, axis=1)

        # Average over corners
        mean_distances.append(np.mean(distance))

    # Average over the trajectory
    add = np.mean(mean_distances)

    return add


def euler_to_rotation_matrix(roll, pitch, yaw):
    """ Build the rotation matrix R = R_z(yaw) . R_y(pitch) . R_x(roll) from Euler angles. """
    cos_r, sin_r = np.cos(roll), np.sin(roll)
    cos_p, sin_p = np.cos(pitch), np.sin(pitch)
    cos_y, sin_y = np.cos(yaw), np.sin(yaw)

    # Elementary rotation about the x axis (roll)
    about_x = np.array([
        [1, 0, 0],
        [0, cos_r, -sin_r],
        [0, sin_r, cos_r],
    ])

    # Elementary rotation about the y axis (pitch)
    about_y = np.array([
        [cos_p, 0, sin_p],
        [0, 1, 0],
        [-sin_p, 0, cos_p],
    ])

    # Elementary rotation about the z axis (yaw)
    about_z = np.array([
        [cos_y, -sin_y, 0],
        [sin_y, cos_y, 0],
        [0, 0, 1],
    ])

    # Roll is applied first, then pitch, then yaw
    pitch_after_roll = about_y.dot(about_x)
    return about_z.dot(pitch_after_roll)


def get_rotated_bounding_box(x, y, z, width, length, height, roll, pitch, yaw):
    """ Return the 8 corners (one per row) of a 3D box given its centre, size and Euler rotation. """
    # Half extents, ordered to match the local axes: (length, width, height)
    half_extents = np.array([length / 2, width / 2, height / 2])

    # Sign pattern of each corner relative to the centre, columns are
    # (length, width, height). Rows 0-3 form the lower face, rows 4-7 the
    # upper face, and row k + 4 sits directly above row k.
    corner_signs = np.array([
        [-1,  1, -1],
        [-1, -1, -1],
        [ 1, -1, -1],
        [ 1,  1, -1],
        [-1,  1,  1],
        [-1, -1,  1],
        [ 1, -1,  1],
        [ 1,  1,  1],
    ])
    local_corners = corner_signs * half_extents

    # Rotate every corner (row vectors, hence the transpose), then move the
    # box to its centre position
    orientation = euler_to_rotation_matrix(roll, pitch, yaw)
    centre = np.array([x, y, z])
    return local_corners.dot(orientation.T) + centre


def get_axes(corners1, corners2):
    """
    Get the candidate separating axes for a 3D SAT test between two boxes.

    For two oriented boxes the separating axis theorem requires testing the
    three edge directions of each box (which are also its face normals) and the
    cross products of every edge direction of the first box with every edge
    direction of the second box, i.e. up to 15 axes.

    The corners are expected in the order produced by get_rotated_bounding_box,
    where corners 1, 3 and 4 are the three neighbours of corner 0 along the
    width, length and height of the box respectively.

    Args:
        corners1: the 8 corners of the first box, one corner per row.
        corners2: the 8 corners of the second box, one corner per row.

    Returns:
        A list of unit-length axes. Degenerate (near zero-length) axes, e.g.
        cross products of parallel edges, are skipped instead of being
        normalized into NaNs.
    """
    eps = 1e-9

    def _unit_edge_directions(corners):
        # Unit vectors along the three mutually orthogonal edges meeting at corner 0
        corners = np.asarray(corners, dtype=float)
        directions = []
        for neighbour in (1, 3, 4):
            edge = corners[neighbour] - corners[0]
            norm = np.linalg.norm(edge)
            if norm > eps:
                directions.append(edge / norm)
        return directions

    edges1 = _unit_edge_directions(corners1)
    edges2 = _unit_edge_directions(corners2)

    # Face normals of both boxes
    axes = edges1 + edges2

    # Edge-edge axes: needed to detect separation between skewed boxes
    for edge1 in edges1:
        for edge2 in edges2:
            axis = np.cross(edge1, edge2)
            norm = np.linalg.norm(axis)
            # Parallel edges give a (near) zero vector that carries no information
            if norm > eps:
                axes.append(axis / norm)

    return axes


def project_corners(axes, corners):
    """ Return, for each axis in order, the dot products of all corners with that axis. """
    return [np.dot(corners, direction) for direction in axes]


def check_collision_3d(corners1, corners2):
    """ Separating-axis collision test between two boxes given by their corners. """
    # Candidate separating axes for this pair of boxes
    candidate_axes = get_axes(corners1, corners2)

    # One array of projected corners per axis, for each box
    shadows1 = project_corners(candidate_axes, corners1)
    shadows2 = project_corners(candidate_axes, corners2)

    for shadow1, shadow2 in zip(shadows1, shadows2):
        low1, high1 = min(shadow1), max(shadow1)
        low2, high2 = min(shadow2), max(shadow2)

        # A gap between the two intervals on any axis proves the boxes are apart
        if high1 < low2 or high2 < low1:
            return False

    # No axis separates the boxes, so they are reported as colliding
    return True


def is_collision(translations_i, rotations_i, sizes_i, translations_j, rotations_j, sizes_j):
    """ Return True if the boxes of instances i and j collide at any timestep of instance i. """

    def _box_corners_at(translations, rotations, sizes, step):
        # Corners of one instance's rotated bounding box at the given timestep;
        # sizes are stored as (height, width, length)
        cx, cy, cz = translations[step]
        roll, pitch, yaw = rotations[step]
        height, width, length = sizes[step]
        return get_rotated_bounding_box(cx, cy, cz, width, length, height, roll, pitch, yaw)

    # Stops at the first timestep where the two boxes collide
    return any(
        check_collision_3d(
            _box_corners_at(translations_i, rotations_i, sizes_i, step),
            _box_corners_at(translations_j, rotations_j, sizes_j, step),
        )
        for step in range(len(translations_i))
    )


def instance_motion_category(translations_history, translations_forecast, timesteps_history, timesteps_forecast, return_linear=False):
    """
    Classify a trajectory (history followed by forecast) as 'static', 'linear' or 'nonlinear'.

    The average velocity is taken between the first and last point. If its norm
    is below 0.5 the instance is 'static' and the reported value is that norm.
    Otherwise the trajectory is compared with constant-velocity motion from the
    first point: the reported value is the mean point-wise distance to that
    linear reference, and the instance is 'linear' when this error does not
    exceed the start-to-end displacement divided by the number of points.

    Returns:
        (category, value), or (category, value, reference) when return_linear
        is True. The reference is the linear trajectory, except for 'static'
        instances where the stacked input trajectory itself is returned.
    """
    times = np.hstack((timesteps_history, timesteps_forecast))
    path = np.vstack((translations_history, translations_forecast))

    start, end = path[0], path[-1]
    velocity = (end - start) / (times[-1] - times[0])
    speed = np.linalg.norm(velocity)

    if speed < 0.5:
        category, value, reference = 'static', speed, path
    else:
        # Constant-velocity trajectory sampled at the same timesteps
        elapsed = times - times[0]
        reference = start + elapsed[:, None] * velocity
        value = np.mean(np.linalg.norm(reference - path, axis=1))

        # Tolerance: the linear fit may deviate by up to the average length of 1 step
        tolerance = np.linalg.norm(end - start) / len(path)
        category = 'linear' if value <= tolerance else 'nonlinear'

    if return_linear:
        return category, value, reference
    return category, value


def out_of_map(translations, geom):
    """
    Return True as soon as one (x, y) position is not contained in the map geometry.

    Only the first two columns of translations are used. geom must be a Polygon
    or a MultiPolygon; for a MultiPolygon a position counts as inside when at
    least one member polygon contains it.

    Raises:
        TypeError: if a position is checked against an unsupported geometry type.
    """
    for x, y in translations[:, :2]:
        location = Point(x, y)

        if isinstance(geom, Polygon):
            inside = geom.contains(location)
        elif isinstance(geom, MultiPolygon):
            inside = any(part.contains(location) for part in geom.geoms)
        else:
            raise TypeError(f'Unsupported geometry type: {type(geom)}')

        if not inside:
            return True

    # Every position lies inside the map
    return False 