import numpy as np
import cv2
import open3d as o3d
from PIL import Image
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
from skimage.measure import label, regionprops, find_contours

# Sign pattern (+1 / -1) of the eight box corners along the three local axes.
# create_local_box3d multiplies each row by the half-extents (l/2, h/2, w/2).
size_to_corner = np.array(
    [
        (1, 1, 1),
        (1, 1, -1),
        (-1, 1, -1),
        (-1, 1, 1),
        (1, -1, 1),
        (1, -1, -1),
        (-1, -1, -1),
        (-1, -1, 1),
    ]
)

def alpha2ry(alpha, u, f, cu):
    """Convert ``alpha`` to ``ry``.

    Adds ``arctan2(u - cu, f)`` to ``alpha`` and passes the sum through
    ``check_angle_range``.
    """
    ray_angle = np.arctan2(u - cu, f)
    return check_angle_range(alpha + ray_angle)

def ry2alpha(ry, u, f, cu):
    """Convert ``ry`` to ``alpha``.

    Subtracts ``arctan2(u - cu, f)`` from ``ry`` and passes the difference
    through ``check_angle_range``.
    """
    ray_angle = np.arctan2(u - cu, f)
    return check_angle_range(ry - ray_angle)


def check_angle_range(angle):
    """Wrap an angle (radians) into the closed interval [-pi, pi].

    Values already inside the interval, including both end points, are
    returned unchanged. Values outside are shifted by as many whole turns
    (multiples of 2*pi) as needed, not just one. Non-finite values (NaN,
    +/-inf) are returned as they are.

    Args:
        angle: a scalar, or an array-like of angles.

    Returns:
        The wrapped scalar for scalar input; a float ndarray of the same
        shape for array-like input.
    """
    two_pi = 2 * np.pi

    if np.ndim(angle) == 0:
        if not np.isfinite(angle):
            return angle
        if angle > np.pi:
            # Number of whole turns needed to bring the value down to <= pi.
            turns = float(np.ceil((angle - np.pi) / two_pi))
            angle = angle - turns * two_pi
        elif angle < -np.pi:
            turns = float(np.ceil((-np.pi - angle) / two_pi))
            angle = angle + turns * two_pi
        return angle

    values = np.asarray(angle, dtype=float)
    finite = np.isfinite(values)
    # Replace non-finite entries by 0 while computing so no inf - inf occurs.
    safe = np.where(finite, values, 0.0)
    turns_down = np.maximum(np.ceil((safe - np.pi) / two_pi), 0)
    turns_up = np.maximum(np.ceil((-np.pi - safe) / two_pi), 0)
    wrapped = safe - (turns_down - turns_up) * two_pi
    return np.where(finite, wrapped, values)


def find_min_max_indices(mask):
    """Return the extreme row/column indices of the entries equal to 1.

    Returns:
        ``(min_row, min_col, max_row, max_col)`` over all positions where
        ``mask == 1``.
    """
    hits = np.argwhere(mask == 1)
    (top, left), (bottom, right) = hits.min(axis=0), hits.max(axis=0)
    return top, left, bottom, right


def rotate_y(angle):
    """Build the 3x3 rotation matrix for a rotation by ``angle`` about the y axis.

    The result is ``[[cos, 0, sin], [0, 1, 0], [-sin, 0, cos]]``.
    """
    c, s = np.cos(angle), np.sin(angle)
    rot = np.eye(3)
    # Diagonal entries (0, 0) and (2, 2) carry the cosine; (1, 1) stays 1.
    rot[[0, 2], [0, 2]] = c
    rot[0, 2] = s
    rot[2, 0] = -s
    return rot


def normalize(v):
    """Scale ``v`` to unit norm; a zero-norm input is returned unchanged."""
    length = np.linalg.norm(v)
    return v if length == 0 else v / length



def point_to_plane_distance(ground_equ, x, y, z):
    """Distance from the point ``(x, y, z)`` to the plane ``A*x + B*y + C*z + D = 0``.

    ``ground_equ`` unpacks to the four plane coefficients ``(A, B, C, D)``.

    NOTE: a later definition of the same name in this file rebinds it.
    """
    a, b, c, d = ground_equ
    plane_value = a * x + b * y + c * z + d
    normal_norm = np.sqrt(a ** 2 + b ** 2 + c ** 2)
    return abs(plane_value) / normal_norm

def fit_plane_x(points):
    """Least-squares fit of ``z = a*x + b*y + d`` to ``points``.

    Columns 0, 1 and 2 of ``points`` are used as x, y and z. The fit is
    solved with ``np.linalg.lstsq`` instead of explicitly inverting the
    normal equations, which is better conditioned and also reports
    rank-deficient (e.g. collinear) inputs reliably.

    Args:
        points: 2-D array with at least three columns.

    Returns:
        ``(a, b, c, d, mse)`` where ``c`` is fixed to -1 (so the plane is
        ``a*x + b*y + c*z + d = 0``) and ``mse`` is the mean squared
        residual in z. All five values are ``np.inf`` when the plane
        cannot be determined.
    """
    degenerate = (np.inf, np.inf, np.inf, np.inf, np.inf)
    n_points = points.shape[0]
    # Three unknowns need at least three points.
    if n_points < 3:
        return degenerate

    design = np.c_[points[:, [0, 1]], np.ones(n_points)]
    target = points[:, 2]
    try:
        coeffs, _, rank, _ = np.linalg.lstsq(design, target, rcond=None)
    except np.linalg.LinAlgError:
        return degenerate
    if rank < design.shape[1]:
        return degenerate

    a, b, d = coeffs
    c = -1
    mse = np.mean((target - design.dot(coeffs)) ** 2)
    return a, b, c, d, mse

def fit_plane_y(points):
    """Least-squares fit of ``y = a*x + c*z + d`` to ``points``.

    Columns 0, 1 and 2 of ``points`` are used as x, y and z. The fit is
    solved with ``np.linalg.lstsq`` instead of explicitly inverting the
    normal equations, which is better conditioned and also reports
    rank-deficient (e.g. collinear) inputs reliably.

    Args:
        points: 2-D array with at least three columns.

    Returns:
        ``(a, b, c, d, mse)`` where ``b`` is fixed to -1 (so the plane is
        ``a*x + b*y + c*z + d = 0``) and ``mse`` is the mean squared
        residual in y. All five values are ``np.inf`` when the plane
        cannot be determined.
    """
    degenerate = (np.inf, np.inf, np.inf, np.inf, np.inf)
    n_points = points.shape[0]
    # Three unknowns need at least three points.
    if n_points < 3:
        return degenerate

    design = np.c_[points[:, [0, 2]], np.ones(n_points)]
    target = points[:, 1]
    try:
        coeffs, _, rank, _ = np.linalg.lstsq(design, target, rcond=None)
    except np.linalg.LinAlgError:
        return degenerate
    if rank < design.shape[1]:
        return degenerate

    a, c, d = coeffs
    b = -1
    mse = np.mean((target - design.dot(coeffs)) ** 2)
    return a, b, c, d, mse

def fit_plane_z(points):
    """Least-squares fit of ``x = b*y + c*z + d`` to ``points``.

    Columns 0, 1 and 2 of ``points`` are used as x, y and z. The fit is
    solved with ``np.linalg.lstsq`` instead of explicitly inverting the
    normal equations, which is better conditioned and also reports
    rank-deficient (e.g. collinear) inputs reliably.

    Args:
        points: 2-D array with at least three columns.

    Returns:
        ``(a, b, c, d, mse)`` where ``a`` is fixed to -1 (so the plane is
        ``a*x + b*y + c*z + d = 0``) and ``mse`` is the mean squared
        residual in x. All five values are ``np.inf`` when the plane
        cannot be determined.
    """
    degenerate = (np.inf, np.inf, np.inf, np.inf, np.inf)
    n_points = points.shape[0]
    # Three unknowns need at least three points.
    if n_points < 3:
        return degenerate

    design = np.c_[points[:, [1, 2]], np.ones(n_points)]
    target = points[:, 0]
    try:
        coeffs, _, rank, _ = np.linalg.lstsq(design, target, rcond=None)
    except np.linalg.LinAlgError:
        return degenerate
    if rank < design.shape[1]:
        return degenerate

    b, c, d = coeffs
    a = -1
    mse = np.mean((target - design.dot(coeffs)) ** 2)
    return a, b, c, d, mse


def extract_ground(points):
    """Fit three plane parameterisations and return the best one.

    Runs ``fit_plane_x``, ``fit_plane_y`` and ``fit_plane_z`` on ``points``
    and returns the coefficients ``[a, b, c, d]`` of the fit whose error
    (fifth return value) is strictly smaller than both others. When no fit
    is strictly best, the ``fit_plane_z`` result is returned.
    """
    fits = (fit_plane_x(points), fit_plane_y(points), fit_plane_z(points))
    err_x, err_y, err_z = (fit[4] for fit in fits)

    if err_x < err_y and err_x < err_z:
        best = fits[0]
    elif err_y < err_x and err_y < err_z:
        best = fits[1]
    else:
        best = fits[2]
    return np.array(best[:4])


def point_to_plane_distance(ground_equ, x, y, z):
    """Return the unsigned distance between a point and a plane.

    Args:
        ground_equ: four plane coefficients ``(A, B, C, D)`` of
            ``A*x + B*y + C*z + D = 0``.
        x, y, z: coordinates of the point.

    NOTE: this rebinds an earlier, identical definition of the same name
    in this file.
    """
    nx, ny, nz, offset = ground_equ
    numerator = abs(nx * x + ny * y + nz * z + offset)
    denominator = np.sqrt(nx ** 2 + ny ** 2 + nz ** 2)
    return numerator / denominator


def erode_mask(mask, k_vertical, k_horizontal):
    """
    Erode a mask separately along the vertical and horizontal directions.

    The mask is eroded ``k_vertical`` times with a 3x1 kernel and
    ``k_horizontal`` times with a 1x3 kernel; the result is the logical AND
    of both erosions, cast back to the input dtype.
    """
    out_dtype = mask.dtype
    work = mask.astype(np.float32)
    column_kernel = np.ones((3, 1), np.uint8)
    row_kernel = np.ones((1, 3), np.uint8)
    shrunk_v = cv2.erode(work, column_kernel, iterations=k_vertical)
    shrunk_h = cv2.erode(work, row_kernel, iterations=k_horizontal)
    return np.logical_and(shrunk_v, shrunk_h).astype(out_dtype)


def adaptive_erode_mask(mask, k_vertical, k_vertical_min, k_horizontal, k_horizontal_min):
    """
    Erode a mask with iteration counts chosen from the mask's extent.

    The extent of the pixels equal to 1 is taken from
    ``find_min_max_indices``. If the row extent (max_row - min_row) is
    below 10, ``k_vertical_min`` iterations are used instead of
    ``k_vertical``; likewise for the column extent and ``k_horizontal_min``.
    The result is the logical AND of the vertical (3x1 kernel) and
    horizontal (1x3 kernel) erosions, cast back to the input dtype.
    """
    out_dtype = mask.dtype
    work = mask.astype(np.float32)

    top, left, bottom, right = find_min_max_indices(work)
    iters_v = k_vertical_min if bottom - top < 10 else k_vertical
    iters_h = k_horizontal_min if right - left < 10 else k_horizontal

    shrunk_v = cv2.erode(work, np.ones((3, 1), np.uint8), iterations=iters_v)
    shrunk_h = cv2.erode(work, np.ones((1, 3), np.uint8), iterations=iters_h)
    return np.logical_and(shrunk_v, shrunk_h).astype(out_dtype)


def remove_statistical_outliers_o3d(pts, neighbors=20, std_ratio=2.0):
    """Boolean mask over ``pts`` from Open3D's statistical outlier removal.

    The mask has one entry per row of ``pts`` and is True at the indices
    returned by ``PointCloud.remove_statistical_outlier``.
    """
    cloud = o3d.geometry.PointCloud()
    cloud.points = o3d.utility.Vector3dVector(pts)
    _, kept = cloud.remove_statistical_outlier(nb_neighbors=neighbors, std_ratio=std_ratio)
    keep_mask = np.zeros(pts.shape[0], dtype=bool)
    keep_mask[kept] = True
    return keep_mask


def remove_radius_outliers_o3d(pts, nb_points: int, radius: float):
    """Boolean mask over ``pts`` from Open3D's radius outlier removal.

    Points with fewer than ``nb_points`` neighbours inside a sphere of the
    given ``radius`` are treated as outliers. The mask has one entry per
    row of ``pts`` and is True at the indices returned by
    ``PointCloud.remove_radius_outlier``.
    """
    cloud = o3d.geometry.PointCloud()
    cloud.points = o3d.utility.Vector3dVector(pts)
    _, kept = cloud.remove_radius_outlier(nb_points=nb_points, radius=radius)
    keep_mask = np.zeros(pts.shape[0], dtype=bool)
    keep_mask[kept] = True
    return keep_mask


def uniform_downsample(pts, every_k_points=100):
    """Downsample ``pts`` with Open3D's ``uniform_down_sample``.

    Returns the points of the downsampled cloud as a NumPy array.
    """
    cloud = o3d.geometry.PointCloud()
    cloud.points = o3d.utility.Vector3dVector(pts)
    sampled = cloud.uniform_down_sample(every_k_points)
    return np.asarray(sampled.points)



def voxel_downsample(pts, voxel_size=0.5):
    """Downsample ``pts`` with Open3D's ``voxel_down_sample``.

    Returns the points of the downsampled cloud as a NumPy array.
    """
    cloud = o3d.geometry.PointCloud()
    cloud.points = o3d.utility.Vector3dVector(pts)
    sampled = cloud.voxel_down_sample(voxel_size)
    return np.asarray(sampled.points)

def remove_outliers_by_percentile_custom(pts, q_min=10, q_max=95):
    """Keep points whose column-2 value lies in a percentile band.

    A point is kept when its value in column 2 is strictly greater than
    the ``q_min`` percentile and at most the ``q_max`` percentile of that
    column.

    Returns:
        ``(filtered_points, inlier_mask)``.
    """
    values = pts[:, 2]
    lower = np.percentile(values, q_min)
    upper = np.percentile(values, q_max)
    keep = (values > lower) & (values <= upper)
    return pts[keep], keep


def remove_outliers_by_percentile(pts, q_min=10, q_max=95):
    """Keep points that lie inside percentile bands on all three columns.

    Column 2 uses the ``(q_min, q_max]`` percentile band; columns 0 and 1
    use the fixed ``(5, 95]`` band. All percentiles are computed on the
    full, unfiltered input.

    Returns:
        ``(filtered_points, inlier_mask)``.
    """
    bands = ((2, q_min, q_max), (0, 5, 95), (1, 5, 95))
    keep = np.ones(len(pts), dtype=bool)
    for column, low_q, high_q in bands:
        values = pts[:, column]
        lower = np.percentile(values, low_q)
        upper = np.percentile(values, high_q)
        keep &= (values > lower) & (values <= upper)
    return pts[keep], keep


def range_filter(points, min_depth=1e-4, max_depth=80):
    """Keep points whose last column lies strictly between the two limits.

    Returns:
        ``(filtered_points, mask)``.
    """
    last_column = points[:, -1]
    keep = np.logical_and(last_column > min_depth, last_column < max_depth)
    return points[keep], keep



def create_local_box3d(l, h, w):
    """Return the eight corners of an ``l`` x ``h`` x ``w`` box centred at the origin.

    Each row of ``size_to_corner`` (a +/-1 sign pattern) is scaled by the
    half-extents ``(l/2, h/2, w/2)``, giving an (8, 3) array.
    """
    half_extents = np.array([0.5 * l, 0.5 * h, 0.5 * w]).reshape(1, 3)
    return half_extents * size_to_corner

def create_ego_box3d_rotated(cx, cy, cz, l, h, w, yaw, rot_mat):
    """Box corners rotated by ``rot_mat`` and translated to ``(cx, cy, cz)``.

    The local corners come from ``create_local_box3d(l, h, w)``. ``yaw`` is
    accepted but not used; the rotation is taken entirely from ``rot_mat``.
    """
    center = np.array([cx, cy, cz])
    corners = create_local_box3d(l, h, w)
    return (rot_mat @ corners.T).T + center

def create_ego_box3d(cx, cy, cz, l, h, w, yaw):
    """Box corners rotated by ``rotate_y(yaw)`` and translated to ``(cx, cy, cz)``.

    The local corners come from ``create_local_box3d(l, h, w)``.
    """
    center = np.array([cx, cy, cz])
    corners = create_local_box3d(l, h, w)
    rotation = rotate_y(yaw)
    return (rotation @ corners.T).T + center

def project_box3d_to_image(corners3d, intrinsic):
    """Project 3-D points with ``intrinsic`` and divide by the third coordinate.

    Args:
        corners3d: array whose rows are 3-D points.
        intrinsic: matrix applied as ``intrinsic @ corners3d.T``.

    Returns:
        Array with two rows (first two projected coordinates divided by the
        third) and one column per input point.
    """
    homogeneous = intrinsic @ corners3d.T
    planar, scale = homogeneous[:2], homogeneous[2]
    return planar / scale

def project_box3d_to_bev(corners3d):
    """Return columns 0 and 2 of the first four corners."""
    first_four = corners3d[:4]
    return first_four[:, [0, 2]]


def create_face_vertices(corners2d):
    """Gather corner columns in face order and return them as rows.

    ``corners2d`` is indexed as ``corners2d[:, k]`` for corner ``k``. The
    output has 20 rows: four corners for each of the four side faces,
    followed by the corner pairs (0, 2) and (1, 3).
    """
    side_faces = (
        (0, 1, 5, 4),  # front face
        (1, 2, 6, 5),  # left face
        (2, 3, 7, 6),  # back face
        (3, 0, 4, 7),  # right face
    )
    extra_pairs = ((0, 2), (1, 3))
    order = [corner for group in side_faces + extra_pairs for corner in group]
    return corners2d[:, order].T

def mask_to_boundary(mask, dilation_ratio=0.02):
    """
    Turn a binary mask into its boundary band.

    :param mask (numpy array, uint8): binary mask
    :param dilation_ratio (float): the band width in erosion iterations is
        dilation_ratio * image_diagonal, rounded and at least 1
    :return: boundary mask (numpy array), i.e. mask minus its eroded version
    """
    rows, cols = mask.shape
    diagonal = np.sqrt(rows ** 2 + cols ** 2)
    n_iterations = max(1, int(round(dilation_ratio * diagonal)))

    # A one-pixel zero border makes mask pixels touching the image edge
    # count as boundary as well.
    padded = cv2.copyMakeBorder(mask, 1, 1, 1, 1, cv2.BORDER_CONSTANT, value=0)
    eroded = cv2.erode(padded, np.ones((3, 3), dtype=np.uint8), iterations=n_iterations)
    interior = eroded[1:rows + 1, 1:cols + 1]
    return mask - interior


def calculate_iou(mask1, mask2):
    """Intersection-over-union of the pixels equal to 1 in two masks.

    Args:
        mask1, mask2: arrays of the same shape; only entries equal to 1
            count as foreground.

    Returns:
        The IoU as a float. When neither mask has any foreground pixel the
        union is empty and 0.0 is returned instead of dividing by zero.
    """
    fg1 = mask1 == 1
    fg2 = mask2 == 1
    inter = np.count_nonzero(np.logical_and(fg1, fg2))
    union = np.count_nonzero(fg1) + np.count_nonzero(fg2) - inter
    if union == 0:
        return 0.0
    return inter / union

def boundary_iou(gt, dt, dilation_ratio=0.005):
    """
    Compute boundary iou between two binary masks.

    The boundaries from ``mask_to_boundary`` are reduced to booleans
    (``> 0``) before combining them, so fixed-width integer products and
    sums cannot wrap around to zero.

    :param gt (numpy array, uint8): binary mask
    :param dt (numpy array, uint8): binary mask
    :param dilation_ratio (float): ratio to calculate dilation = dilation_ratio * image_diagonal
    :return: boundary iou (float); 0.0 when both boundaries are empty
    """
    gt_boundary = mask_to_boundary(gt, dilation_ratio) > 0
    dt_boundary = mask_to_boundary(dt, dilation_ratio) > 0
    inter = np.logical_and(gt_boundary, dt_boundary).sum()
    union = np.logical_or(gt_boundary, dt_boundary).sum()
    if union == 0:
        # Neither mask has a boundary pixel; avoid 0 / 0.
        return 0.0
    return inter / union

def box_iou(box1, box2):
    """IoU of two ``(x1, y1, x2, y2)`` boxes with inclusive pixel coordinates.

    Widths and heights are computed as ``max - min + 1``. Returns 0 when
    the boxes do not overlap.
    """
    left = max(box1[0], box2[0])
    top = max(box1[1], box2[1])
    right = min(box1[2], box2[2])
    bottom = min(box1[3], box2[3])

    overlap_w = max(right - left + 1, 0)
    overlap_h = max(bottom - top + 1, 0)
    inter = overlap_w * overlap_h
    if inter == 0:
        return 0

    def _inclusive_area(box):
        return abs((box[2] - box[0] + 1) * (box[3] - box[1] + 1))

    union = _inclusive_area(box1) + _inclusive_area(box2) - inter
    if union < 0:
        return 0
    return inter / union


def tracked_to_mask(tracked, visibility, width, height):
    """
    Rasterise visible tracked points into a binary mask.

    Args:
        tracked (np.array): Tracked points of shape (N, 2), as (x, y).
        visibility (np.array): Visibility of tracked points of shape (N, 1).
        width (int): Width of the mask.
        height (int): Height of the mask.

    Returns:
        np.array: uint8 mask of shape (height, width) with 1 at every
        visible point that rounds to a pixel inside the image.
    """
    pixels = np.round(tracked).astype(np.int32)
    canvas = np.zeros((height, width), dtype=np.uint8)

    for idx, (px, py) in enumerate(pixels):
        if not visibility[idx]:
            continue
        if px < 0 or px >= width or py < 0 or py >= height:
            continue
        canvas[int(py), int(px)] = 1
    return canvas


def tracked_to_box(tracked, visibility, width, height):
    """Bounding box of the visible tracked points that lie inside the image.

    Args:
        tracked: array of shape (N, 2) holding (x, y) points.
        visibility: per-point visibility flags. Any shape with N elements
            and any truthy/falsy dtype is accepted; it is flattened and
            cast to bool so that an (N, 1) or 0/1 integer array behaves
            like a 1-D boolean mask.
        width, height: image size; points must satisfy
            ``0 <= x < width`` and ``0 <= y < height``.

    Returns:
        ``(x_min, y_min, x_max, y_max)`` of the remaining points, or
        ``(0, 0, 0, 0)`` when fewer than two points remain.
    """
    tracked = np.asarray(tracked)
    visible = np.asarray(visibility).astype(bool).reshape(-1)
    xs, ys = tracked[:, 0], tracked[:, 1]
    inside = (xs >= 0) & (xs < width) & (ys >= 0) & (ys < height)

    kept = tracked[visible & inside]
    if len(kept) < 2:
        return 0, 0, 0, 0
    x_min, y_min = kept.min(axis=0)
    x_max, y_max = kept.max(axis=0)
    return x_min, y_min, x_max, y_max


def mask_to_broder(mask):
    """Draw the level-128 contours of ``mask`` onto a blank image.

    Contours come from ``find_contours(mask, 128)``. Each contour point is
    truncated to integer indices and the corresponding cell is set to 255.

    Returns:
        Float array with the same shape as ``mask``; 255 on contour cells
        and 0 elsewhere.
    """
    rows, cols = mask.shape
    canvas = np.zeros((rows, cols))
    for contour in find_contours(mask, 128):
        cells = contour.astype(int)
        canvas[cells[:, 0], cells[:, 1]] = 255
    return canvas


def mask_to_box(mask: np.ndarray):
    """Tight box around the non-zero pixels of a 2-D mask.

    Returns:
        ``(x_min, y_min, x_max, y_max)`` with x as the column index and y
        as the row index, or ``None`` when the mask has no non-zero pixel.
    """
    nonzero_rc = np.argwhere(mask)
    if not nonzero_rc.size:
        return None
    (y_min, x_min), (y_max, x_max) = nonzero_rc.min(axis=0), nonzero_rc.max(axis=0)
    return x_min, y_min, x_max, y_max


def area(boxes, add1=False):
    """Areas of ``[x_min, y_min, x_max, y_max]`` boxes stored as rows.

    With ``add1`` set, 1.0 is added to every width and height before
    multiplying.
    """
    widths = boxes[:, 2] - boxes[:, 0]
    heights = boxes[:, 3] - boxes[:, 1]
    if add1:
        widths = widths + 1.0
        heights = heights + 1.0
    return widths * heights



def _pairwise_overlap(lo1, hi1, lo2, hi2, add1):
    """Return the [N, M] overlap length of two sets of 1-D intervals."""
    upper = np.minimum(hi1, np.transpose(hi2))
    lower = np.maximum(lo1, np.transpose(lo2))
    if add1:
        # Not in-place: ``upper`` may have an integer dtype, and adding a
        # float in place would raise a casting error.
        upper = upper + 1.0
    return np.maximum(np.zeros(lower.shape), upper - lower)


def intersection(boxes1, boxes2, add1=False):
    """Compute pairwise intersection areas between boxes.

    Args:
        boxes1: a numpy array with shape [N, 4] holding N boxes as
            ``[x_min, y_min, x_max, y_max]``
        boxes2: a numpy array with shape [M, 4] holding M boxes
        add1: when True, 1.0 is added to the upper edge of every overlap
            before the lower edge is subtracted.

    Returns:
        a numpy array with shape [N, M] representing pairwise intersection area
    """
    x_min1, y_min1, x_max1, y_max1 = np.split(boxes1, 4, axis=1)
    x_min2, y_min2, x_max2, y_max2 = np.split(boxes2, 4, axis=1)

    intersect_heights = _pairwise_overlap(y_min1, y_max1, y_min2, y_max2, add1)
    intersect_widths = _pairwise_overlap(x_min1, x_max1, x_min2, x_max2, add1)
    return intersect_heights * intersect_widths


def get_iou_matrix(boxes1, boxes2, add1=False):
    """Pairwise intersection-over-union between two box collections.

    Both inputs are cast to float32 before ``intersection`` and ``area``
    are evaluated; ``add1`` is forwarded to both.

    Args:
        boxes1: a numpy array with shape [N, 4] holding N boxes.
        boxes2: a numpy array with shape [M, 4] holding M boxes.

    Returns:
        a numpy array with shape [N, M] representing pairwise iou scores.
    """
    first = boxes1.astype(np.float32)
    second = boxes2.astype(np.float32)
    inter = intersection(first, second, add1)
    areas_first = area(first, add1)
    areas_second = area(second, add1)
    union = areas_first[:, None] + areas_second[None, :] - inter
    return inter / union


def calculate_depth_aware_iou(mask1, mask2, depth_map, th):
    """IoU of two masks that only counts depth-consistent overlap pixels.

    A pixel covered by both masks counts towards intersection and union
    only when ``abs(d1 - d2) < th``; otherwise it is dropped from both.
    A pixel covered by exactly one mask always counts towards the union.

    NOTE(review): both depths are read from the same ``depth_map`` pixel,
    so ``d1 - d2`` is 0 for every finite depth. The check therefore only
    rejects overlap pixels with NaN/inf depth (or all of them when
    ``th <= 0``). Confirm whether two separate depth sources were intended.

    Args:
        mask1, mask2: arrays of the same shape as ``depth_map``; cast to bool.
        depth_map: array of per-pixel depth values.
        th: threshold on the absolute depth difference.

    Returns:
        ``intersection / (union + 1e-6)`` as a float.
    """
    mask1 = mask1.astype(bool)
    mask2 = mask2.astype(bool)

    both = np.logical_and(mask1, mask2)
    exactly_one = np.logical_xor(mask1, mask2)

    depth_both = depth_map[both]
    consistent = np.count_nonzero(np.abs(depth_both - depth_both) < th)

    depth_consistent_intersection = consistent
    depth_consistent_union = consistent + np.count_nonzero(exactly_one)
    return depth_consistent_intersection / (depth_consistent_union + 1e-6)

def get_depth_aware_iou_matrix(masks, query_masks, depth_map=None, use_depth=False, th=0.5):
    """Pairwise IoU between every query mask and every mask.

    Args:
        masks: sequence of masks (columns of the result).
        query_masks: sequence of query masks (rows of the result).
        depth_map: depth image forwarded to ``calculate_depth_aware_iou``.
        use_depth: when True and ``depth_map`` is given, the depth-aware
            mask IoU is used; otherwise the IoU of the masks' bounding
            boxes (``mask_to_box`` + ``box_iou``) is used.
        th: depth threshold forwarded to ``calculate_depth_aware_iou``.

    Returns:
        float32 array of shape ``(len(query_masks), len(masks))``. In the
        box branch, pairs involving an empty mask are left at 0.
    """
    iou_mat = np.zeros((len(query_masks), len(masks)), dtype=np.float32)

    if use_depth and depth_map is not None:
        for i, q_mask in enumerate(query_masks):
            for j, mask in enumerate(masks):
                iou_mat[i, j] = calculate_depth_aware_iou(q_mask, mask, depth_map, th)
        return iou_mat

    # Each bounding box is computed once instead of once per (query, mask) pair.
    query_boxes = [mask_to_box(q_mask) for q_mask in query_masks]
    mask_boxes = [mask_to_box(mask) for mask in masks]
    for i, q_box in enumerate(query_boxes):
        if q_box is None:
            continue
        for j, m_box in enumerate(mask_boxes):
            if m_box is not None:
                iou_mat[i, j] = box_iou(q_box, m_box)
    return iou_mat