import os
import cv2
import numpy as np
from pyquaternion import Quaternion
import math
from scipy.spatial.transform import Rotation as R

# Maps a raw semantic tag found in the BEV label image to the reduced class
# index used downstream (0 background, 1 vehicle, 2 road, 3 road line; see
# CLASS_TO_RGB). Tags that collapse to background are still listed explicitly
# so that no class is forgotten.
CLASSES_REMAP = {
    0: 0,    # Unlabeled
    10: 1,   # Vehicles
    8: 0,    # Sidewalk
    14: 0,   # Ground (non-drivable)
    22: 0,   # Terrain (non-drivable)
    7: 2,    # Road
    6: 3,    # Road line
    18: 0,   # Traffic light
    5: 0,    # Pole
    1: 0,    # Building
    4: 0,    # Pedestrian
    9: 0,    # Vegetation
}

# Display colour for each reduced class index, stored as a 3-element list.
# NOTE(review): the name suggests R, G, B channel order - confirm against the
# code that renders these colours.
CLASS_TO_RGB = {
    0: [0, 0, 0],          # Background
    1: [255, 255, 255],    # Vehicle
    2: [238, 123, 94],     # Road
    3: [41, 132, 199],     # Road line
}

def chmk_dir(path):
    """Create directory ``path`` (and any missing parents) if it is not there yet.

    Calling this on an already existing directory is a no-op.

    Args:
        path: Directory path to create.

    Raises:
        FileExistsError: If ``path`` exists but is not a directory.
        OSError: Any other error reported by ``os.makedirs``.
    """
    # exist_ok=True avoids the check-then-create race of
    # "if not exists: makedirs" when several workers prepare the same
    # output directory at the same time.
    os.makedirs(path, exist_ok=True)

def get_transformation_matrix_bkp(calibrate_param):
    """Build a 4x4 extrinsic matrix from a calibration record.

    The rotation is a quaternion ``[w, x, y, z]``. For camera parameters the
    x and y components are forced to 0 before the matrix is built, so only
    the w/z part of the rotation is kept.

    Args:
        calibrate_param: Mapping with a ``'rotation'`` entry (indexable
            ``[w, x, y, z]``) and a ``'translation'`` entry (3 values).
            The mapping and its entries are not modified.

    Returns:
        A 4x4 matrix whose upper-left 3x3 block is the rotation and whose
        last column holds ``translation``.
    """
    rotation, translation = calibrate_param['rotation'], calibrate_param['translation']
    # Build a fresh [w, 0, 0, z] instead of zeroing x/y in place, so the
    # caller's calibration record keeps its original rotation.
    yaw_only = [rotation[0], 0, 0, rotation[3]]
    cam_extrinsic = Quaternion(yaw_only).transformation_matrix
    cam_extrinsic[:3, -1] = translation
    return cam_extrinsic

def get_transformation_matrix(ego_pose, calibrate_param):
    """Compute a camera's world pose and its camera-to-ego matrix.

    The calibration rotation is a quaternion ``[w, x, y, z]``; for camera
    parameters the x and y components are treated as 0, and a fixed
    90-degree offset is added to the resulting yaw (the "angle correction").

    Args:
        ego_pose: Mapping with ``'rotation'`` (quaternion) and
            ``'translation'`` (3 values) of the ego vehicle.
        calibrate_param: Mapping with ``'rotation'`` (indexable
            ``[w, x, y, z]``) and ``'translation'`` (3 values) of the camera
            relative to the ego vehicle. Not modified by this function.

    Returns:
        A tuple ``(cam_pose_list, cam2ego)`` where ``cam_pose_list`` is
        ``[x, y, z, roll, yaw, pitch]`` (angles in degrees) of the camera
        after composing with the ego pose, and ``cam2ego`` is the 4x4
        camera-to-ego matrix.
    """
    # Ego transformation matrix. covert_degree returns [roll, yaw, pitch];
    # from_euler('xyz', ...) wants the angles in x, y, z order.
    ego2world = np.eye(4)
    roll, yaw, pitch = covert_degree(ego_pose['rotation'])
    ego2world[:3, :3] = R.from_euler('xyz', [roll, pitch, yaw], degrees=True).as_matrix()
    ego2world[:3, 3] = ego_pose['translation'][:]

    # Camera transformation matrix with corrected angle. A fresh [w, 0, 0, z]
    # is built rather than zeroing x/y in place, so the caller's calibration
    # record is left untouched.
    rotation, translation = calibrate_param['rotation'], calibrate_param['translation']
    yaw_only = [float(rotation[0]), 0.0, 0.0, float(rotation[3])]
    roll, yaw, pitch = covert_degree(yaw_only)
    yaw = yaw + 90
    cam2ego = np.eye(4)
    cam2ego[:3, :3] = R.from_euler('xyz', [roll, pitch, yaw], degrees=True).as_matrix()
    cam2ego[:3, -1] = translation

    # Compose and read the pose back as [x, y, z, roll, yaw, pitch].
    cam2world = np.dot(ego2world, cam2ego)
    roll, pitch, yaw = R.from_matrix(cam2world[:3, :3]).as_euler('xyz', degrees=True).tolist()
    cam_pose_list = cam2world[:3, 3].tolist() + [roll, yaw, pitch]
    return cam_pose_list, cam2ego

def get_lidar_world_pose(ego_pose, calibrated_pose):
    """Return the lidar pose as ``[x, y, z, roll, yaw, pitch]`` (degrees).

    Both arguments are mappings with a ``'rotation'`` quaternion and a
    ``'translation'``; the lidar calibration is composed with the ego pose.
    """

    def _pose_matrix(pose):
        # covert_degree yields [roll, yaw, pitch]; from_euler('xyz', ...)
        # expects the angles in x, y, z order.
        roll_deg, yaw_deg, pitch_deg = covert_degree(pose['rotation'])
        transform = np.eye(4)
        transform[:3, :3] = R.from_euler(
            'xyz', [roll_deg, pitch_deg, yaw_deg], degrees=True
        ).as_matrix()
        transform[:3, 3] = pose['translation'][:]
        return transform

    ego2world = _pose_matrix(ego_pose)
    lidar2ego = _pose_matrix(calibrated_pose)
    lidar2world = np.dot(ego2world, lidar2ego)

    # as_euler('xyz') gives roll, pitch, yaw; the output list stores them
    # as roll, yaw, pitch after the position.
    euler = R.from_matrix(lidar2world[:3, :3]).as_euler('xyz', degrees=True)
    roll, pitch, yaw = euler.tolist()
    return lidar2world[:3, 3].tolist() + [roll, yaw, pitch]

def covert_degree(quaternion):
    """Convert a quaternion to ``[roll, yaw, pitch]`` angles in degrees.

    Args:
        quaternion: Either a ``Quaternion`` instance or anything the
            ``Quaternion`` constructor accepts (e.g. a ``[w, x, y, z]``
            sequence).

    Returns:
        A list of three Python floats ordered ``[roll, yaw, pitch]``. Note
        that yaw comes before pitch in this list.
    """
    # isinstance (rather than an exact type comparison) also lets Quaternion
    # subclasses be used directly without re-wrapping.
    if not isinstance(quaternion, Quaternion):
        quaternion = Quaternion(quaternion)
    yaw, pitch, roll = quaternion.yaw_pitch_roll
    return [float(angle * 180 / math.pi) for angle in (roll, yaw, pitch)]

def covert_extent_order(extent):
    """Halve ``extent`` and swap its first two entries, returning a list.

    ``extent`` must support ``/ 2`` producing a new array (e.g. a numpy
    array); the input itself is left unchanged.
    """
    half = extent / 2
    first, second = half[0], half[1]
    half[0] = second
    half[1] = first
    return half.tolist()

def get_ego_annotations(ego_pose, ann_boxes, range=75):
    """Collect nearby vehicle annotations around the ego position.

    Only boxes named "vehicle.car" or "vehicle.emergency.police" whose
    planar (x, y) distance to ``ego_pose[:2]`` is below ``range`` are kept.
    Boxes are used as given; no translation/rotation into a sensor frame is
    applied here.

    Returns a dict keyed by the box's index in ``ann_boxes``; each value has
    the keys 'angle', 'center', 'extent', 'location' and 'obj_type'.
    """
    kept_types = ("vehicle.car", "vehicle.emergency.police")  # TODO: only consider these two type
    ego_xy = np.array(ego_pose[:2])
    annotations = {}

    for idx, box in enumerate(ann_boxes):
        if box.name not in kept_types:
            continue
        planar_dist = np.linalg.norm(ego_xy - box.center[:2])
        if not (planar_dist < range):
            continue
        annotations[idx] = {
            'angle': covert_degree(box.orientation),
            'center': [0, 0, 0],
            'extent': covert_extent_order(box.wlh),
            'location': box.center.tolist(),
            'obj_type': box.name,
        }

    return annotations

def get_seg_label(bev_data_path, dsize=(256, 256)):
    """Load a BEV semantic image and remap it to the reduced class indices.

    Args:
        bev_data_path: Path to a ``.npy`` file, or a ``.npz`` archive whose
            array is stored under the key ``"arr_0"``.
        dsize: Output size passed to ``cv2.resize`` as ``(width, height)``.
            Defaults to ``(256, 256)``.

    Returns:
        A ``uint8`` array with the shape of the resized image, where each
        pixel holds ``CLASSES_REMAP[tag]`` for tags present in
        ``CLASSES_REMAP`` and 0 for any other value.
    """
    # NOTE(review): allow_pickle=True lets np.load unpickle arbitrary Python
    # objects; only call this on trusted files.
    loaded = np.load(bev_data_path, allow_pickle=True)

    # Some data comes in .npz format. Read the array out inside a context
    # manager so the underlying archive file handle is closed.
    if isinstance(loaded, np.lib.npyio.NpzFile):
        with loaded as archive:
            bev_pix = archive["arr_0"]
    else:
        bev_pix = loaded

    # Nearest-neighbour interpolation keeps the values as valid class tags.
    bev_pix = cv2.resize(bev_pix, dsize=dsize, interpolation=cv2.INTER_NEAREST)

    new_bev_pix = np.zeros(bev_pix.shape, dtype=np.uint8)
    for key, value in CLASSES_REMAP.items():
        new_bev_pix[bev_pix == key] = value

    return new_bev_pix