from typing import Dict, List, Optional, Tuple

import numpy as np

from pointcloud.toolkits.waymo_helpers import (
    image_heights,
    image_widths,
)


class CameraRigSpec:
    """Per-camera calibration tables for a multi-camera rig.

    Every table is a dict keyed by camera name:

    - the intrinsic matrix ``K``;
    - the camera-to-ego extrinsic;
    - the image size as ``(height, width)``;
    - the distortion coefficients.

    The constructor makes shallow copies of the containers, so later mutation
    of the caller's list/dicts does not change this object.
    """

    def __init__(self, camera_names: List[str], intrinsics: Dict[str, np.ndarray],
                 extrinsics_cam_to_ego: Dict[str, np.ndarray], image_sizes: Dict[str, Tuple[int, int]],
                 distortion_coeffs: Optional[Dict[str, np.ndarray]] = None):
        """Store shallow copies of the per-camera tables.

        Args:
            camera_names: Camera names, in rig order.
            intrinsics: Camera name -> intrinsic matrix.
            extrinsics_cam_to_ego: Camera name -> camera-to-ego transform.
            image_sizes: Camera name -> ``(height, width)``.
            distortion_coeffs: Optional camera name -> distortion coefficients.
                ``None`` (or empty) means no entries. ``get_dist_coeffs`` then
                falls back to five zeros.
        """
        self.camera_names = list(camera_names)
        self.intrinsics = dict(intrinsics)
        self.extrinsics_cam_to_ego = dict(extrinsics_cam_to_ego)
        self.image_sizes = dict(image_sizes)
        self.distortion_coeffs = dict(distortion_coeffs or {})

    @staticmethod
    def default_name_to_idx() -> Dict[str, int]:
        """Return a fresh mapping from default camera name to list index.

        A new dict is built on each call, so callers may mutate it freely.
        """
        return {
            'front': 0,
            'front_left': 1,
            'front_right': 2,
            'side_left': 3,
            'side_right': 4,
        }

    @staticmethod
    def get_adjacent_cameras() -> Dict[str, List[str]]:
        """Return a fresh mapping from each default camera to its listed neighbours.

        NOTE(review): presumably "adjacent" means overlapping fields of view;
        confirm against the rig geometry.
        """
        return {
            'front': ['front_left', 'front_right'],
            'front_left': ['front', 'side_left'],
            'front_right': ['front', 'side_right'],
            'side_left': ['front_left'],
            'side_right': ['front_right'],
        }

    @staticmethod
    def _split_intrinsics(entry) -> Tuple[np.ndarray, np.ndarray]:
        """Return ``(K, dist_coeffs)`` as float32 copies parsed from one intrinsics entry."""
        # ``.astype`` copies by default, so the stored arrays never alias the caller's input.
        if isinstance(entry, dict):
            K = np.asarray(entry.get('K', np.eye(3, dtype=np.float32))).astype(np.float32)
            dist = np.asarray(entry.get('dist_coeffs', np.zeros(5, dtype=np.float32))).astype(np.float32)
            return K, dist
        return np.asarray(entry).astype(np.float32), np.zeros(5, dtype=np.float32)

    @classmethod
    def from_waymo(cls, extrinsics_list, intrinsics_list, camera_names: List[str]) -> 'CameraRigSpec':
        """Build a rig spec from per-camera calibration lists indexed by camera index.

        A camera's index comes from ``default_name_to_idx()``. A name not in that
        mapping falls back to its position in ``camera_names``.

        Args:
            extrinsics_list: Indexed by camera index. Each entry is converted
                to a float32 array and stored as that camera's cam-to-ego
                transform.
            intrinsics_list: Indexed by camera index. Each entry is one of:

                - a dict with optional ``'K'`` (default: 3x3 identity) and
                  ``'dist_coeffs'`` (default: five zeros) keys;
                - a bare matrix-like value used as ``K``, with zero distortion.
            camera_names: Any iterable of camera names to include, in order.
                It is materialized once, so iterators and generators work.

        Returns:
            A new ``CameraRigSpec`` containing exactly the requested cameras.
            Image sizes are read from ``image_heights`` / ``image_widths`` at
            each camera's index.

        Raises:
            Whatever indexing the input lists or size tables raises for an
            out-of-range camera index (``IndexError`` for plain lists).
        """
        # Materialize once: the loop below and ``cls(...)`` both consume the
        # names, which would silently drop them for a one-shot iterator.
        names = list(camera_names)
        name_to_idx = cls.default_name_to_idx()

        intrinsics: Dict[str, np.ndarray] = {}
        extrinsics: Dict[str, np.ndarray] = {}
        sizes: Dict[str, Tuple[int, int]] = {}
        distortions: Dict[str, np.ndarray] = {}

        for cam in names:
            cam_idx = name_to_idx.get(cam)
            if cam_idx is None:
                # NOTE: unknown names use their position in ``names``. That
                # position may coincide with a known camera's index. This is
                # kept for backward compatibility.
                cam_idx = names.index(cam)
            K, dist = cls._split_intrinsics(intrinsics_list[cam_idx])
            intrinsics[cam] = K
            distortions[cam] = dist
            extrinsics[cam] = np.asarray(extrinsics_list[cam_idx]).astype(np.float32)
            sizes[cam] = (int(image_heights[cam_idx]), int(image_widths[cam_idx]))

        return cls(camera_names=names, intrinsics=intrinsics,
                   extrinsics_cam_to_ego=extrinsics, image_sizes=sizes,
                   distortion_coeffs=distortions)

    def get_K(self, cam: str) -> np.ndarray:
        """Return the intrinsic matrix for ``cam``; raises ``KeyError`` if unknown."""
        return self.intrinsics[cam]

    def get_size(self, cam: str) -> Tuple[int, int]:
        """Return ``(height, width)`` for ``cam``; raises ``KeyError`` if unknown."""
        return self.image_sizes[cam]

    def get_cam_to_ego(self, cam: str) -> np.ndarray:
        """Return the camera-to-ego transform for ``cam``; raises ``KeyError`` if unknown."""
        return self.extrinsics_cam_to_ego[cam]

    def get_dist_coeffs(self, cam: str) -> np.ndarray:
        """Return distortion coefficients for ``cam``, or five float32 zeros if none are stored."""
        return self.distortion_coeffs.get(cam, np.zeros(5, dtype=np.float32))


class CameraTransform:
    """Stateless helpers for composing rigid poses and mapping points through them."""

    @staticmethod
    def apply_transform(
        original_camera_to_world: np.ndarray,
        translation: np.ndarray,
        rotation_matrix: np.ndarray,
    ) -> np.ndarray:
        """Left-compose a rigid motion onto a camera-to-world pose.

        Let ``R`` be the top-left 3x3 block of ``original_camera_to_world`` and
        ``t`` the first three entries of its last column. These are the only
        entries read. The result has:

        - rotation ``rotation_matrix @ R``;
        - translation ``rotation_matrix @ t + translation``.

        Returns:
            A new 4x4 float32 matrix whose bottom row is ``[0, 0, 0, 1]``.
        """
        rot = original_camera_to_world[:3, :3]
        pos = original_camera_to_world[:3, 3]

        pose = np.eye(4, dtype=np.float32)
        pose[:3, :3] = rotation_matrix @ rot
        pose[:3, 3] = rotation_matrix @ pos + translation
        return pose

    @staticmethod
    def points_to_world(points: np.ndarray, camera_to_world: np.ndarray) -> np.ndarray:
        """Map row-vector points through a homogeneous transform.

        Each row of ``points`` gets a trailing 1 appended. It is then multiplied
        on the left by ``camera_to_world``, and the first three components of
        each result are kept.

        Args:
            points: 2-D array with one point per row (presumably shape (N, 3)).
            camera_to_world: Matrix applied to each homogeneous point
                (presumably 4x4).

        Returns:
            Array with one transformed point per input row, first three
            components only.
        """
        count = points.shape[0]
        ones_column = np.ones((count, 1), dtype=np.float32)
        homogeneous = np.concatenate([points, ones_column], axis=1)
        mapped = camera_to_world @ homogeneous.T
        return mapped.T[:, :3]


