import json
import torch
import numpy as np
import matplotlib.pyplot as plt
from typing import Any, List, Tuple, Dict
from sklearn.decomposition import PCA
from plot.utils.geometry import rigid_points_registration_numpy
from plot.utils.processing import normalize, check_angle_range, voxel_downsample, range_filter, remove_outliers_by_percentile
from roma import rigid_points_registration


# Per-class (lower, upper) percentile bounds handed to
# remove_outliers_by_percentile in merge_and_filter_pseudo_lidars.
# Classes without their own entry use the "Others" bounds.
filter_thresholds = dict(
    Car=[5, 95],
    Pedestrian=[10, 70],
    Others=[10, 90],
)

def read_kitti_calib(calib_file):
    """Read a 3x3 intrinsic matrix from a KITTI-style calibration file.

    The third line of the file (index 2) is parsed as ``<tag> v0 ... v11``.
    In the KITTI object layout this is presumably the P2 projection matrix;
    confirm this for other datasets. The 12 values are reshaped to 3x4, and
    the left 3x3 block is returned.

    Args:
        calib_file: path to the calibration text file.

    Returns:
        (3, 3) float32 numpy array.

    Raises:
        IndexError: if the file has fewer than three lines.
        ValueError: if that line does not hold exactly 12 numbers after the tag.
    """
    with open(calib_file, encoding="utf-8") as f:
        # split() with no argument collapses runs of whitespace. split(' ')
        # produced empty tokens on double spaces and broke the float conversion.
        fields = f.read().splitlines()[2].split()[1:]
    K = np.array(fields, dtype=np.float32).reshape(3, 4)
    return K[:3, :3]


def scale_intrinsics(K, osize, nsize):
    """Rescale an intrinsic matrix for an image resized from ``osize`` to ``nsize``.

    The sizes are indexed like ``(height, width)``. The ratio at index 1 scales
    the first row of ``K``, and the ratio at index 0 scales the second row.
    A modified copy is returned; ``K`` itself is left untouched.
    """
    height_ratio = nsize[0] / osize[0]
    width_ratio = nsize[1] / osize[1]
    scaled_K = K.copy()
    scaled_K[0] *= width_ratio   # x-related row
    scaled_K[1] *= height_ratio  # y-related row
    return scaled_K


def read_kitt_gt_label(label_file):
    """Parse a KITTI-style label file into a list of per-object records.

    Lines whose class is ``Misc`` or ``DontCare`` are skipped, as are lines
    whose truncation value (column 1) exceeds 0.8 and blank lines.

    Args:
        label_file: path to the whitespace-separated label text file.

    Returns:
        List of ``[cls_type, col3, x1, y1, x2, y2, h, w, l, x, y, z, col14]``.
        Entries after ``cls_type`` are floats. Columns 3 and 14 are presumably
        KITTI's ``alpha`` and ``rotation_y``. The occlusion column (2) is
        read but not returned.

    Raises:
        ValueError: if a non-blank line has fewer than three fields or a
            non-numeric value in a parsed numeric column.
        IndexError: if a kept line has fewer than 15 fields.
    """
    with open(label_file, encoding="utf-8") as f:
        lines = f.read().splitlines()

    labels = []
    for line in lines:
        # split() tolerates repeated/trailing whitespace; blank lines yield [].
        annots = line.split()
        if not annots:
            continue
        cls_type, truncation, _occlusion = annots[:3]

        if cls_type in ('Misc', 'DontCare'):
            continue

        if float(truncation) > 0.8:
            continue

        x1, y1, x2, y2 = map(float, annots[4:8])
        h, w, l = map(float, annots[8:11])
        x, y, z = map(float, annots[11:14])
        labels.append([cls_type, float(annots[3]), x1, y1, x2, y2, h, w, l, x, y, z, float(annots[14])])
    return labels


def registration(source_pts, target_pts, sample_indices=None, ransac_inlier_threshold=0.1, compute_scale=False):
    """Fit a (scaled) rigid transform on a subset of correspondences and score it on all of them.

    Args:
        source_pts, target_pts: numpy arrays of corresponding points (row i of
            one matches row i of the other). Presumably (N, 3); TODO confirm.
        sample_indices: optional indices of the correspondences used for the
            fit; all points are used when None.
        ransac_inlier_threshold: residual below which a point counts as an inlier.
        compute_scale: forwarded as ``compute_scaling`` to the solver.

    Returns:
        (trf, inlier_count, cost). ``trf`` is a 4x4 matrix holding ``s * R``
        and ``t``. ``inlier_count`` is the number of points (over ALL points)
        whose residual is below the threshold. ``cost`` is the mean residual
        over all points.
    """
    if sample_indices is None:
        src_fit, tgt_fit = source_pts, target_pts
    else:
        src_fit, tgt_fit = source_pts[sample_indices], target_pts[sample_indices]

    rot, trans, scale = rigid_points_registration_numpy(src_fit, tgt_fit, weights=None, compute_scaling=compute_scale)

    transform = np.eye(4)
    transform[:3, :3] = scale * rot
    transform[:3, -1] = trans

    # Apply the transform to every source point and measure per-point error.
    moved = (scale * (rot @ source_pts.T) + trans[:, None]).T
    errors = np.linalg.norm(moved - target_pts, axis=1)
    n_inliers = (errors < ransac_inlier_threshold).sum()
    return transform, n_inliers, errors.mean()

def registration_gpu(source_pts, target_pts, sample_indices=None, ransac_inlier_threshold=0.1, compute_scale=False):
    """Torch counterpart of ``registration`` that uses roma's rigid solver.

    Args:
        source_pts, target_pts: torch tensors of corresponding points (row i
            matches row i).
        sample_indices: optional indices of the correspondences used for the
            fit; all points are used when None.
        ransac_inlier_threshold: residual below which a point counts as an inlier.
        compute_scale: also estimate a uniform scale; otherwise the scale is 1.

    Returns:
        (trf, inlier_count, cost), with the same meaning as in ``registration``.
        ``inlier_count`` and ``cost`` are tensors.

    NOTE(review): ``trf`` comes from ``torch.eye(4)``, i.e. torch's default
    device/dtype (CPU/float32 unless overridden). The rotation and translation,
    which may be float64 and on CUDA, are therefore copied and cast into it.
    """
    if sample_indices is None:
        src_fit, tgt_fit = source_pts, target_pts
    else:
        src_fit, tgt_fit = source_pts[sample_indices], target_pts[sample_indices]

    if not compute_scale:
        rot, trans = rigid_points_registration(src_fit, tgt_fit)
        scale = 1
    else:
        rot, trans, scale = rigid_points_registration(src_fit, tgt_fit, compute_scaling=True)

    transform = torch.eye(4)
    transform[:3, :3] = scale * rot
    transform[:3, -1] = trans

    # Score the hypothesis on every correspondence, not only the sampled ones.
    moved = (scale * (rot @ source_pts.T) + trans[:, None]).T
    errors = torch.linalg.norm(moved - target_pts, dim=1)
    n_inliers = (errors < ransac_inlier_threshold).sum()
    return transform, n_inliers, errors.mean()


def ransac_registration_gpu(source_pts, target_pts, 
                        ransac_iters=300, 
                        ransac_inlier_threshold=0.1, ransac_stop_threshold=0.05,
                        compute_scale=False, n_samples=5):
    """Robustly estimate a (scaled) rigid transform source -> target with RANSAC on CUDA.

    Each iteration fits ``registration_gpu`` to ``n_samples`` randomly drawn
    correspondences and scores the hypothesis on all points.

    Args:
        source_pts, target_pts: numpy arrays of corresponding points (row i
            matches row i). They are moved to the GPU with ``.cuda()``.
        ransac_iters: maximum number of hypotheses evaluated.
        ransac_inlier_threshold: residual below which a point counts as an inlier.
        ransac_stop_threshold: stop early once the best mean residual is below
            this value and the best hypothesis has more than 4 inliers.
        compute_scale: also estimate a uniform scale factor.
        n_samples: correspondences drawn per hypothesis.

    Returns:
        (best_trf, best_inlier_count, best_cost). ``best_trf`` is a 4x4 numpy
        array; the other two are Python numbers. If no hypothesis was ever
        accepted (e.g. ``ransac_iters=0``), returns ``(identity, -inf, inf)``
        instead of crashing.

    Raises:
        AssertionError: if there are not more than ``n_samples + 1`` points.
    """
    assert len(source_pts) > n_samples + 1
    # Fallback only reachable when asserts are stripped (python -O).
    n_samples = n_samples if len(source_pts) > n_samples else 3
    best_cost = torch.inf
    best_inlier_count = -torch.inf
    best_trf = torch.eye(4)
    source_pts = torch.from_numpy(source_pts).cuda()
    target_pts = torch.from_numpy(target_pts).cuda()
    num_pts = source_pts.shape[0]
    # Draw distinct correspondences whenever possible. Duplicates in a minimal
    # set leave fewer independent constraints and give degenerate fits.
    sample_with_replacement = num_pts < n_samples

    for _ in range(ransac_iters):
        # sub-sample points for registration
        sample_indices = np.random.choice(num_pts, size=n_samples, replace=sample_with_replacement)
        # rigid registration with select points
        trf, inlier_count, cost = registration_gpu(source_pts, target_pts, sample_indices, ransac_inlier_threshold, compute_scale)
        # A hypothesis replaces the current best only if it improves BOTH the
        # inlier count and the mean residual (a stricter rule than classic RANSAC).
        if (inlier_count > best_inlier_count) and (cost < best_cost):
            best_inlier_count, best_trf, best_cost = inlier_count, trf, cost
        # early stopping
        if (best_cost < ransac_stop_threshold) and (best_inlier_count > 4):
            break

    # The best values stay plain floats (-inf / inf) if nothing was accepted;
    # only tensors can be moved off the device.
    inlier_out = best_inlier_count.cpu().item() if torch.is_tensor(best_inlier_count) else best_inlier_count
    cost_out = best_cost.cpu().item() if torch.is_tensor(best_cost) else best_cost
    return best_trf.cpu().numpy(), inlier_out, cost_out


def ransac_registration(source_pts, target_pts, 
                        ransac_iters=300, 
                        ransac_inlier_threshold=0.1, ransac_stop_threshold=0.05,
                        compute_scale=False, n_samples=5):
    """Robustly estimate a (scaled) rigid transform source -> target with RANSAC (numpy).

    Each iteration fits ``registration`` to ``n_samples`` randomly drawn
    correspondences and scores the hypothesis on all points.

    Args:
        source_pts, target_pts: numpy arrays of corresponding points (row i
            matches row i).
        ransac_iters: maximum number of hypotheses evaluated.
        ransac_inlier_threshold: residual below which a point counts as an inlier.
        ransac_stop_threshold: stop early once the best mean residual is below
            this value and the best hypothesis has more than 4 inliers.
        compute_scale: also estimate a uniform scale factor.
        n_samples: correspondences drawn per hypothesis.

    Returns:
        (best_trf, best_inlier_count, best_cost) for the best accepted
        hypothesis, or ``(identity, -inf, inf)`` if none was accepted.

    Raises:
        AssertionError: if there are not more than ``n_samples + 1`` points.
    """
    assert len(source_pts) > n_samples + 1
    # Fallback only reachable when asserts are stripped (python -O).
    n_samples = n_samples if len(source_pts) > n_samples else 3
    best_cost = np.inf
    best_inlier_count = -np.inf
    best_trf = np.eye(4)
    num_pts = source_pts.shape[0]
    # Draw distinct correspondences whenever possible. Duplicates in a minimal
    # set leave fewer independent constraints and give degenerate fits.
    sample_with_replacement = num_pts < n_samples

    for _ in range(ransac_iters):
        # sub-sample points for registration
        sample_indices = np.random.choice(num_pts, size=n_samples, replace=sample_with_replacement)
        # rigid registration with select points
        trf, inlier_count, cost = registration(source_pts, target_pts, sample_indices, ransac_inlier_threshold, compute_scale)
        # A hypothesis replaces the current best only if it improves BOTH the
        # inlier count and the mean residual (a stricter rule than classic RANSAC).
        if (inlier_count > best_inlier_count) and (cost < best_cost):
            best_inlier_count, best_trf, best_cost = inlier_count, trf, cost
        # early stopping
        if (best_cost < ransac_stop_threshold) and (best_inlier_count > 4):
            break
    return best_trf, best_inlier_count, best_cost




def calculate_yaw_from_position_changes(p1, p2):
    """Return the angle (radians) of the displacement from ``p1`` to ``p2``.

    Only the first two components of each point are used:
    ``atan2(p2[1] - p1[1], p2[0] - p1[0])``.
    """
    delta_first = p2[0] - p1[0]
    delta_second = p2[1] - p1[1]
    return np.arctan2(delta_second, delta_first)


def merge_object_masks(masks):
    """Union a stack of per-object masks into a single boolean mask.

    Args:
        masks: array-like of shape (N, H, W); any non-zero entry counts as set.

    Returns:
        (H, W) boolean array that is True wherever at least one mask is set.
        It is all False when N == 0.

    Raises:
        ValueError: if ``masks`` is not three-dimensional.
    """
    masks = np.asarray(masks)
    if masks.ndim != 3:
        raise ValueError(f"expected masks of shape (N, H, W), got shape {masks.shape}")
    # Vectorised logical OR over the object axis. np.any treats non-zero values
    # as True, which matches casting each mask with astype(bool).
    return np.any(masks, axis=0)



def compute_principal_directions(points, n_components=2):
    """Run PCA on a point set and return its principal axes.

    The first principal component (largest variance) is the dominant direction
    of the object, e.g. a car's length. The second is the sideways direction,
    e.g. its width.

    Args:
        points: (N, D) array of points. ``compute_dir_with_PCA`` passes
            bird's-eye-view (x, z) points, so D = 2 there.
        n_components: number of principal components kept (default 2).

    Returns:
        eigenvectors: (D, n_components) array whose columns are the principal
            axes, ordered by decreasing variance.
        eigenvalues: (n_components,) array with the variance along each axis.
        centroid: the per-axis MINIMUM of ``points`` (not the mean).
    """
    # The per-axis minimum is used as the reference point. PCA centres the data
    # internally, so this shift does not change the returned axes or variances;
    # it only determines the returned ``centroid``.
    centroid = np.min(points, axis=0)
    centered_points = points - centroid
    pca = PCA(n_components)
    pca.fit(centered_points)
    # principal directions (columns)
    eigenvectors = pca.components_.T
    # variance along each axis
    eigenvalues = pca.explained_variance_
    return eigenvectors, eigenvalues, centroid


def compute_dir_with_PCA(points):
    """Estimate an object's yaw from the dominant bird's-eye-view direction of its points.

    Columns 0 and 2 (x, z) of ``points`` form the BEV point set. The first
    principal axis is normalised, and its angle relative to the first BEV axis
    is range-checked and negated.

    Returns:
        (yaw, elongation). ``elongation`` is the first PCA variance divided by
        the second. With fewer than two points the fixed fallback
        ``(-1.56, 1.0)`` is returned (-1.56 is roughly -pi/2).

    NOTE(review): if the second variance is 0 (e.g. exactly two distinct
    points), the ratio becomes inf or nan.
    """
    bev = points[:, [0, 2]]
    if len(bev) < 2:
        return -1.56, 1.0

    axes, variances, _ = compute_principal_directions(bev)
    main_axis = normalize(axes[:, 0])
    angle = np.arctan2(main_axis[1], main_axis[0])
    return -check_angle_range(angle), variances[0] / variances[1]

def compute_dir_from_object_motion(cur_pts: np.ndarray, adj_pts: np.ndarray, trf_mat: np.ndarray, cur_index: int, adj_index: int):
    """Estimate an object's yaw from the motion of its points between two frames.

    Args:
        cur_pts: object points in the current frame; columns 0 and 2 are used.
        adj_pts: object points in an adjacent frame.
        trf_mat: optional 4x4 matrix applied to ``adj_pts`` first (its
            ``[:3, :3]`` block and ``[:3, 3]`` column); None skips this step.
            Presumably it maps the adjacent frame into the current one; TODO
            confirm with the caller.
        cur_index, adj_index: frame indices; only their order matters.

    Returns:
        The negated ``arctan2(d1, d0)`` of the displacement between the per-axis
        minimum (column 0, column 2) corners. The displacement runs from the
        earlier frame to the later one; when the indices are equal, ``adj`` is
        treated as the earlier frame.
    """
    if trf_mat is not None:
        rot, shift = trf_mat[:3, :3], trf_mat[:3, -1:]
        adj_pts = ((rot @ adj_pts.T) + shift).T

    cur_ref = np.min(cur_pts[:, [0, 2]], axis=0)
    adj_ref = np.min(adj_pts[:, [0, 2]], axis=0)

    # Orient the displacement from the earlier frame to the later one.
    start, end = (cur_ref, adj_ref) if adj_index > cur_index else (adj_ref, cur_ref)
    return -np.arctan2(end[1] - start[1], end[0] - start[0])


def compute_size_from_axis_aligned_points(points):
    """Return the extents and centre of the axis-aligned bounding box of ``points``.

    Args:
        points: (N, >=3) array; only the first three columns are used.

    Returns:
        ``([L, H, W], [x3d, y3d, z3d])``. L, H and W are the box extents along
        columns 0, 1 and 2, and (x3d, y3d, z3d) is the box centre.

    Raises:
        ValueError: if ``points`` is empty (no min/max to take).
    """
    x_min, x_max = points[:, 0].min(), points[:, 0].max()
    y_min, y_max = points[:, 1].min(), points[:, 1].max()
    z_min, z_max = points[:, 2].min(), points[:, 2].max()
    L, H, W = x_max - x_min, y_max - y_min, z_max - z_min
    x3d, y3d, z3d = (x_min+x_max)/2, (y_min+y_max)/2, (z_min+z_max)/2
    return [L, H, W], [x3d, y3d, z3d]


def compute_size_from_axis_aligned_points_frontal(points):
    """Same as ``compute_size_from_axis_aligned_points`` but returns the centre as a tuple.

    Returns:
        ``([L, H, W], (x3d, y3d, z3d))``.
    """
    # Delegate so the two variants cannot drift apart; only the container
    # type of the centre differs.
    size, center = compute_size_from_axis_aligned_points(points)
    return size, tuple(center)



def merge_and_filter_pseudo_lidars(list_of_lidars, cls_name: str, min_depth=1e-4, max_depth=80):
    """Concatenate pseudo-LiDAR point sets and remove range and percentile outliers.

    Args:
        list_of_lidars: sequence of point arrays concatenated along axis 0.
        cls_name: object class name. It is title-cased and looked up in
            ``filter_thresholds``, falling back to the "Others" percentiles.
        min_depth, max_depth: bounds forwarded to ``range_filter``.

    Returns:
        ``(merged, cleaned)``. ``merged`` holds all concatenated points.
        ``cleaned`` is the most-filtered set that still has at least 10 points:
        ``merged`` itself if range filtering leaves fewer than 10, otherwise the
        range-filtered set if percentile filtering leaves fewer than 10.
    """
    pseudo_lidars = np.concatenate(list_of_lidars, axis=0)
    filtered_pseudo_lidars, _ = range_filter(pseudo_lidars, min_depth, max_depth)
    if len(filtered_pseudo_lidars) < 10:
        return pseudo_lidars, pseudo_lidars
    # Use the title-cased name for BOTH the membership test and the lookup, so
    # that e.g. "car" / "CAR" map to "Car" instead of silently using "Others".
    q_min, q_max = filter_thresholds.get(cls_name.title(), filter_thresholds["Others"])
    cleaned_pseudo_lidars, _ = remove_outliers_by_percentile(filtered_pseudo_lidars, q_min, q_max)
    if len(cleaned_pseudo_lidars) < 10:
        return pseudo_lidars, filtered_pseudo_lidars
    return pseudo_lidars, cleaned_pseudo_lidars
