import numpy as np
from matplotlib.path import Path
from shapely.geometry import Polygon
import matplotlib.patches as patches
from scipy.spatial import ConvexHull


def projection(points, lidar2image):
    """Project groups of 3-D points into every view with per-view 3x4 matrices.

    Args:
        points: array of shape (N, P, 3): N groups of P points each.
        lidar2image: array of shape (V, >=3, >=4); only the top-left 3x4 block
            of each view's matrix is used.

    Returns:
        points_2d: (V, N, P, 2) perspective-divided (u, v) coordinates. Entries
            whose (eps-shifted) depth is not positive are set to -1.
        mask_z: (V, N, P) bool, True where the eps-shifted depth is > 0.
        z: (V, N, P) projected depth with eps (1e-6) added.
    """
    N, P = points.shape[0], points.shape[1]

    points_hom = np.concatenate(
        [points, np.ones((N, P, 1), dtype=np.float32)], axis=2
    )  # (N, P, 4)

    # Contract each view's 3x4 matrix directly against the homogeneous points.
    # Broadcasting over V avoids materialising V copies of the point cloud.
    points_2d = np.einsum(
        "vij,npj->vinp", lidar2image[:, :3, :4], points_hom
    )  # (V, 3, N, P)

    eps = 1e-6
    z = points_2d[:, 2, :, :] + eps  # (V, N, P)
    mask_z = z > 0

    # A depth of exactly 0 gives inf/nan here. Those entries are masked out
    # below, so silence the warnings instead of guarding the division.
    with np.errstate(divide="ignore", invalid="ignore"):
        u = points_2d[:, 0, :, :] / z  # (V, N, P)
        v = points_2d[:, 1, :, :] / z  # (V, N, P)
    uv = np.stack([u, v], axis=-1)  # (V, N, P, 2)

    # Select rather than multiply by the mask: inf * 0 is nan, which would
    # leak through a multiplicative mask.
    points_2d = np.where(np.expand_dims(mask_z, axis=-1), uv, -1.0)  # (V, N, P, 2)
    return points_2d, mask_z, z


def projection_2d(lidar2image, points, H, W, return_z=False):
    """Compute per-view, per-group 2-D pixel boxes around projected points.

    Every group of points is projected into every view via ``projection``.
    Only projections flagged valid by its depth mask take part in the
    min/max reductions.

    Args:
        lidar2image: per-view projection matrices, forwarded to ``projection``.
        points: point groups forwarded to ``projection`` (expected (N, P, 3)).
        H: image height; row coordinates are clipped to ``[0, H - 1]``.
        W: image width; column coordinates are clipped to ``[0, W - 1]``.
        return_z: when True, also return the depth array from ``projection``.

    Returns:
        int64 array of shape (V, N, 4) holding ``[v_min, u_min, v_max, u_max]``.
        Values are truncated toward zero by ``astype`` after clipping. A box
        whose clipped extent is empty in either direction is all zeros; this
        includes groups with no valid projection. With ``return_z`` the depth
        array is returned as a second value.
    """
    points_2d, mask_z, z = projection(points, lidar2image)

    BIG_POS, BIG_NEG = 1.0e10, -1.0e10
    keep = np.expand_dims(mask_z, axis=-1)  # (V, N, P, 1)
    # Invalid points get sentinels that can never win their reduction.
    lo_candidates = np.where(keep, points_2d, np.full_like(points_2d, BIG_POS))
    hi_candidates = np.where(keep, points_2d, np.full_like(points_2d, BIG_NEG))

    lo = lo_candidates.min(axis=-2)  # (V, N, 2) -> [u_min, v_min]
    hi = hi_candidates.max(axis=-2)  # (V, N, 2) -> [u_max, v_max]

    def _to_pixel(values, size):
        # Clamp into the image, then truncate to integer pixel indices.
        return np.clip(values, 0, size - 1).astype(np.int64)

    u_lo, u_hi = _to_pixel(lo[..., 0], W), _to_pixel(hi[..., 0], W)
    v_lo, v_hi = _to_pixel(lo[..., 1], H), _to_pixel(hi[..., 1], H)

    box_ok = (u_hi > u_lo) & (v_hi > v_lo)
    boxes = np.stack([v_lo, u_lo, v_hi, u_hi], axis=-1)  # [v, n, 4]
    boxes[~box_ok] = 0
    return (boxes, z) if return_z else boxes


def sort_vertices_counterclockwise(vertices):
    """Order vertices by increasing polar angle around their centroid.

    Angles come from ``np.arctan2`` (range -pi..pi) using only the first two
    columns. All columns of each row are returned in the new order.
    """
    centroid = np.mean(vertices, axis=0)
    offsets_x = vertices[:, 0] - centroid[0]
    offsets_y = vertices[:, 1] - centroid[1]
    order = np.argsort(np.arctan2(offsets_y, offsets_x))
    return vertices[order]


def expand_polygon(vertices, expansion_factor=2.0):
    """Grow a polygon outward with shapely's ``buffer``.

    Args:
        vertices: sequence of points accepted by ``shapely.geometry.Polygon``.
        expansion_factor: passed as the ``buffer`` distance. Despite the name,
            it is an absolute offset in the vertices' units, not a multiplier.

    Returns:
        np.ndarray of the buffered exterior's coordinates with the last
        (closing) coordinate dropped.
    """
    # join_style=2 selects shapely's mitre (sharp-corner) join.
    grown = Polygon(vertices).buffer(expansion_factor, join_style=2)
    ring_coords = grown.exterior.coords
    return np.array(ring_coords[:-1])


def sample_from_hist(data, num_samples=1000, bins=20, rng=None):
    """Draw values from the empirical distribution of ``data`` via a histogram.

    A histogram of ``data`` is built. Each draw picks a bin with probability
    proportional to its count and returns that bin's LEFT edge, so samples are
    quantised to bin edges rather than spread within a bin.

    Args:
        data: array-like of finite numbers; must be non-empty.
        num_samples: number of values to draw.
        bins: forwarded to ``np.histogram`` (defaults to the previous
            hard-coded 20).
        rng: optional generator exposing ``random(size)``, e.g.
            ``np.random.default_rng(seed)``. When None, the global
            ``np.random`` state is used, as before.

    Returns:
        1-D array of ``num_samples`` left bin edges.

    Raises:
        ValueError: if ``data`` is empty.
    """
    data = np.asarray(data)
    if data.size == 0:
        raise ValueError("sample_from_hist: data must be non-empty")

    counts, bin_edges = np.histogram(data, bins=bins)
    # Integer counts make the normalised CDF end at exactly 1.0, so every
    # draw in [0, 1) maps to a real bin.
    cdf = np.cumsum(counts) / counts.sum()

    random_probs = (
        np.random.rand(num_samples) if rng is None else rng.random(num_samples)
    )
    sampled_indices = np.searchsorted(cdf, random_probs)
    # Defensive clamp: never step past the last bin's left edge.
    sampled_indices = np.minimum(sampled_indices, len(counts) - 1)
    return bin_edges[sampled_indices]


def get_frustum_corners_world(
    x_min, x_max, y_min, y_max, K, R, t, z_near=5, z_far=50.0, include_top=False
):
    """Back-project corners of an image box to two depths and map them with (R, t).

    For each selected corner (u, v), the ray ``K^-1 @ [u, v, 1]`` is scaled by
    ``z_near`` and then ``z_far``. Each resulting point ``p`` is mapped as
    ``R @ (p - t)``.

    NOTE(review): if (R, t) are world->camera extrinsics
    (p_cam = R_wc @ p_w + t), the inverse is ``R_wc.T @ (p_cam - t)``. This
    code applies ``R`` directly, so callers presumably pass the
    camera->world rotation. Confirm against the call sites.

    Args:
        x_min, x_max, y_min, y_max: image-box bounds in pixels.
        K: 3x3 intrinsic matrix (inverted here).
        R: 3x3 matrix applied to the translated camera-frame points.
        t: translation of length 3; a (3, 1) column vector is also accepted.
        z_near, z_far: scale factors applied to the back-projected rays.
        include_top: when True, also include the top-left and top-right
            corners. The default keeps only the two bottom corners.

    Returns:
        (2 * C, 3) array with C = 2 (default) or 4 (``include_top=True``).
        Rows are the near-depth corners followed by the far-depth corners.
        Within each depth the order is [top-left, top-right,] bottom-right,
        bottom-left.
    """
    corners = [(x_max, y_max), (x_min, y_max)]  # bottom-right, bottom-left
    if include_top:
        corners = [(x_min, y_min), (x_max, y_min)] + corners  # top-left, top-right

    uv1 = np.array([[u, v, 1.0] for u, v in corners], dtype=np.float32)  # (C, 3)
    rays = uv1 @ np.linalg.inv(K).T  # (C, 3): each row is K^-1 @ [u, v, 1]

    # Near-plane rows first, then far-plane rows, matching the original order.
    pts_cam = np.concatenate([rays * z for z in (z_near, z_far)], axis=0)  # (2C, 3)

    t_vec = np.reshape(t, 3)
    # Row-wise equivalent of R @ (p - t) for every point p.
    return (pts_cam - t_vec) @ np.asarray(R).T


def lidar_xy_to_top_np(
    x, y, res=0.1, side_range=(-52.0, 52 - 0.05), fwd_range=(-52.0, 52 - 0.05)
):
    """Map LiDAR (x, y) coordinates to bird's-eye-view image (column, row) indices.

    Column comes from ``-y / res`` and row from ``-x / res``. Both are
    truncated toward zero by ``astype(np.int32)`` and then shifted: columns
    by ``-floor(side_range[0] / res)``, rows by ``+floor(fwd_range[1] / res)``.
    No bounds checking is done.

    Returns:
        Tuple ``(x_img, y_img)`` of int32 column and row indices.
    """
    col_shift = int(np.floor(side_range[0] / res))
    row_shift = int(np.floor(fwd_range[1] / res))
    cols = (-y / res).astype(np.int32) - col_shift
    rows = (-x / res).astype(np.int32) + row_shift
    return cols, rows


def point_cloud_2_top_np(
    points,
    res=0.1,
    zres=1.0,
    side_range=(-52.0, 52 - 0.05),
    fwd_range=(-52.0, 52 - 0.05),
    height_range=(-2.0, 1.0),
):
    """Rasterise a point cloud into a bird's-eye-view height/reflectance tensor.

    Args:
        points: array of shape (M, >=4) with columns x, y, z, reflectance.
        res: horizontal cell size.
        zres: thickness of each height slice.
        side_range, fwd_range: (min, max) extents. Points are kept when
            ``fwd_range[0] < x < fwd_range[1]`` and
            ``-side_range[1] < y < -side_range[0]``.
        height_range: (min, max) height. Slices start at ``height_range[0]``
            and step by ``zres`` (``np.arange``).

    Returns:
        uint8 array of shape (rows + 1, cols + 1, z_max + 1).
        - Channel ``i < z_max`` holds ``z - height_range[0]`` of points in slice i.
        - Channel ``z_max`` holds reflectance.
        All values are scaled by the global maximum to 0..255. If several
        points hit the same cell, only one write survives (NumPy fancy
        assignment). Negative values are clipped to 0. An all-zero raster
        returns zeros instead of dividing by zero.
    """
    x_points = points[:, 0]
    y_points = points[:, 1]
    z_points = points[:, 2]
    reflectance = points[:, 3]

    x_max = int((side_range[1] - side_range[0]) / res)  # column count - 1
    y_max = int((fwd_range[1] - fwd_range[0]) / res)  # row count - 1
    z_max = int((height_range[1] - height_range[0]) / zres)  # index of reflectance channel
    top = np.zeros([y_max + 1, x_max + 1, z_max + 1], dtype=np.float32)

    f_filt = (x_points > fwd_range[0]) & (x_points < fwd_range[1])
    s_filt = (y_points > -side_range[1]) & (y_points < -side_range[0])
    filt = f_filt & s_filt

    # Offsets that shift the image origin to the corner of the window.
    col_offset = int(np.floor(side_range[0] / res))
    row_offset = int(np.floor(fwd_range[1] / res))

    for i, height in enumerate(np.arange(height_range[0], height_range[1], zres)):
        sel = filt & (z_points >= height) & (z_points < height + zres)

        # Image column follows -y and image row follows -x in the LiDAR frame.
        x_img = (-y_points[sel] / res).astype(np.int32) - col_offset
        y_img = (-x_points[sel] / res).astype(np.int32) + row_offset

        # NOTE: if the height range is not a multiple of zres, np.arange can
        # yield one extra slice whose index equals z_max. Its heights are then
        # immediately overwritten by reflectance on the next line.
        top[y_img, x_img, i] = z_points[sel] - height_range[0]
        top[y_img, x_img, z_max] = reflectance[sel]

    peak = np.max(top)
    if peak <= 0:
        # Nothing to normalise against: avoid 0/0 -> nan -> undefined uint8 cast.
        return np.zeros(top.shape, dtype=np.uint8)
    # Clip so negative inputs (e.g. negative reflectance) do not wrap on the cast.
    return np.clip(top / peak * 255, 0, 255).astype(np.uint8)


def show_polygon(x, y, color):
    """Build an unfilled matplotlib patch outlining the convex hull of (x, y) points.

    Duplicate (x, y) pairs are removed before ``ConvexHull`` is computed. The
    outline follows ``hull.vertices`` order, which scipy documents as
    counterclockwise for 2-D hulls (the previous local name said "clockwise").

    Returns:
        ``matplotlib.patches.Polygon`` (closed, no fill) edged in ``color``.
    """
    unique_xy = np.array(list({*zip(x, y)}))
    hull = ConvexHull(unique_xy)
    outline = [tuple(unique_xy[idx]) for idx in hull.vertices]
    return patches.Polygon(outline, closed=True, fill=False, edgecolor=color)


def get_bounding_box_vertices(min_point, max_point):
    """Return the 8 corners of an axis-aligned box as an (8, 3) array.

    Corners are enumerated with z varying fastest, then y, then x: row index
    ``4 * ix + 2 * iy + iz``, where each ``i*`` is 0 for the min and 1 for the
    max coordinate.
    """
    (x0, y0, z0), (x1, y1, z1) = min_point, max_point
    return np.array(
        [[x, y, z] for x in (x0, x1) for y in (y0, y1) for z in (z0, z1)]
    )


def get_bounding_box_faces(vertices):
    """Group 8 box corners into the six quadrilateral faces.

    The face labels assume ``vertices`` follows the ordering produced by
    ``get_bounding_box_vertices`` (z fastest, then y, then x). Under that
    ordering, consecutive corners of each quad share a box edge.

    Returns:
        List of six 4-element lists in the order -X, +X, -Y, +Y, -Z, +Z.
    """
    face_corner_ids = (
        (0, 1, 3, 2),  # -X face
        (4, 5, 7, 6),  # +X face
        (0, 1, 5, 4),  # -Y face
        (2, 3, 7, 6),  # +Y face
        (0, 2, 6, 4),  # -Z face
        (1, 3, 7, 5),  # +Z face
    )
    return [[vertices[k] for k in ids] for ids in face_corner_ids]
