import torch
import numpy as np
from plyfile import PlyData, PlyElement
from pathlib import Path
from typing import Any, Dict, List, Tuple
from scipy.spatial.transform import Rotation as R
from gsplat import rasterization


class Gaussian:
    """A set of 3D Gaussians stored as bias-shifted "hidden" parameters.

    The public ``get_*`` properties apply the activations and biases configured
    in :meth:`setup_functions` to the hidden tensors (``_xyz``, ``_scaling``,
    ``_rotation``, ``_opacity``, ``_features_dc`` / ``_features_rest``); the
    ``from_*`` setters and :meth:`load_ply` perform the inverse mapping.

    Positions are stored normalized by ``aabb``: ``get_xyz`` returns
    ``_xyz * aabb[3:] + aabb[:3]``, so ``aabb[:3]`` acts as an offset and
    ``aabb[3:]`` as a per-axis extent.
    """

    # Default 3x3 axis change: save_ply applies it and load_ply undoes it.
    # A tuple (not a list) so the shared default argument can never be mutated.
    _DEFAULT_PLY_TRANSFORM = ((1, 0, 0), (0, 0, -1), (0, 1, 0))

    def __init__(
            self,
            aabb: list,
            sh_degree: int = 0,
            mininum_kernel_size: float = 0.0,
            scaling_bias: float = 0.01,
            opacity_bias: float = 0.1,
            scaling_activation: str = "exp",
            device='cuda'
    ):
        """Create an empty Gaussian set (all hidden tensors start as ``None``).

        Args:
            aabb: Six numbers; ``aabb[:3]`` is the position offset and
                ``aabb[3:]`` the per-axis extent used to (de)normalize ``_xyz``.
            sh_degree: Spherical-harmonics degree expected by :meth:`load_ply`.
            mininum_kernel_size: Lower bound folded into every scale as
                ``sqrt(s**2 + k**2)`` (spelling kept for backward compatibility).
            scaling_bias: Scale a zero hidden ``_scaling`` maps to, before the
                minimum-kernel adjustment.
            opacity_bias: Opacity a zero hidden ``_opacity`` maps to.
            scaling_activation: ``"exp"`` or ``"softplus"``.
            device: Torch device for all tensors.

        Raises:
            ValueError: if ``scaling_activation`` is not supported.
        """
        self.init_params = {
            'aabb': aabb,
            'sh_degree': sh_degree,
            'mininum_kernel_size': mininum_kernel_size,
            'scaling_bias': scaling_bias,
            'opacity_bias': opacity_bias,
            'scaling_activation': scaling_activation,
        }

        self.sh_degree = sh_degree
        self.active_sh_degree = sh_degree
        self.mininum_kernel_size = mininum_kernel_size
        self.scaling_bias = scaling_bias
        # Raw probability for now; setup_functions() replaces it with its logit tensor.
        self.opacity_bias = opacity_bias
        self.scaling_activation_type = scaling_activation
        self.device = device
        self.aabb = torch.tensor(aabb, dtype=torch.float32, device=device)
        self.setup_functions()

        # Hidden (pre-activation) parameters, filled by the from_* setters or load_ply.
        self._xyz = None
        self._features_dc = None
        self._features_rest = None
        self._scaling = None
        self._rotation = None
        self._opacity = None

        # Constants used by the rendering utilities below.
        self._GS_LWH_TO_XYZ_SCALE = (0.90, 0.90, 0.88)
        self._OPACITY_THRESHOLD = 0.01

    def setup_functions(self):
        """Install activation functions and precompute the parameter biases.

        Chooses the scale activation from ``scaling_activation_type``, uses
        sigmoid/logit for opacity and L2 normalization for quaternions, and
        converts ``scaling_bias`` / ``opacity_bias`` into the pre-activation
        offsets ``scale_bias`` and ``opacity_bias``.

        Raises:
            ValueError: if ``scaling_activation_type`` is not ``"exp"`` or
                ``"softplus"``.
        """
        def build_covariance_from_scaling_rotation(scaling, scaling_modifier, rotation):
            # Covariance = M @ M^T with M = R @ diag(scale); returned in packed 6-value form.
            L = build_scaling_rotation(scaling_modifier * scaling, rotation)
            actual_covariance = L @ L.transpose(1, 2)
            symm = strip_symmetric(actual_covariance)
            return symm

        if self.scaling_activation_type == "exp":
            self.scaling_activation = torch.exp
            self.inverse_scaling_activation = torch.log
        elif self.scaling_activation_type == "softplus":
            self.scaling_activation = torch.nn.functional.softplus
            # Inverse softplus log(exp(x) - 1), written as x + log(1 - exp(-x)) for stability.
            self.inverse_scaling_activation = lambda x: x + torch.log(-torch.expm1(-x))
        else:
            raise ValueError(
                f"Unsupported scaling_activation {self.scaling_activation_type!r}; "
                "expected 'exp' or 'softplus'."
            )

        self.covariance_activation = build_covariance_from_scaling_rotation

        self.opacity_activation = torch.sigmoid
        self.inverse_opacity_activation = inverse_sigmoid

        self.rotation_activation = torch.nn.functional.normalize

        self.scale_bias = self.inverse_scaling_activation(torch.tensor(self.scaling_bias)).to(self.device)
        # Identity quaternion (w, x, y, z) = (1, 0, 0, 0): a zero hidden rotation means "no rotation".
        self.rots_bias = torch.zeros((4)).to(self.device)
        self.rots_bias[0] = 1
        self.opacity_bias = self.inverse_opacity_activation(torch.tensor(self.opacity_bias)).to(self.device)

        # 180-degree rotation about the Z axis, to match the reference implementation.
        rotation_fix_z_180 = R.from_euler('z', 180, degrees=True).as_matrix()
        self._TRANSFORM_FIX_Z_180 = np.eye(4, dtype=np.float32)
        self._TRANSFORM_FIX_Z_180[:3, :3] = rotation_fix_z_180

    @property
    def get_scaling(self):
        """Per-axis scales ``sqrt(act(_scaling + scale_bias)**2 + mininum_kernel_size**2)``."""
        scales = self.scaling_activation(self._scaling + self.scale_bias)
        scales = torch.square(scales) + self.mininum_kernel_size ** 2
        scales = torch.sqrt(scales)
        return scales

    @property
    def get_rotation(self):
        """Unit quaternions: ``_rotation + rots_bias`` normalized along the last axis."""
        return self.rotation_activation(self._rotation + self.rots_bias[None, :])

    @property
    def get_xyz(self):
        """World-space positions ``_xyz * aabb[3:] + aabb[:3]``."""
        return self._xyz * self.aabb[None, 3:] + self.aabb[None, :3]

    @property
    def get_features(self):
        """SH colour coefficients with DC first, shape ``(N, 1 + K, 3)``.

        Returns ``_features_dc`` alone when ``_features_rest`` is ``None``.
        """
        if self._features_rest is None:
            return self._features_dc
        # DC is (N, 1, 3) and the rest is (N, K, 3): coefficients stack along dim 1.
        return torch.cat((self._features_dc, self._features_rest), dim=1)

    @property
    def get_opacity(self):
        """Opacities in (0, 1): ``sigmoid(_opacity + opacity_bias)``."""
        return self.opacity_activation(self._opacity + self.opacity_bias)

    def get_covariance(self, scaling_modifier=1):
        """Packed 6-value covariances built from the scales (times ``scaling_modifier``) and rotations."""
        return self.covariance_activation(self.get_scaling, scaling_modifier, self._rotation + self.rots_bias[None, :])

    def from_scaling(self, scales):
        """Set hidden ``_scaling`` so that ``get_scaling`` returns ``scales``.

        NOTE(review): scales smaller than ``mininum_kernel_size`` produce NaN
        (square root of a negative number).
        """
        scales = torch.sqrt(torch.square(scales) - self.mininum_kernel_size ** 2)
        self._scaling = self.inverse_scaling_activation(scales) - self.scale_bias

    def from_rotation(self, rots):
        """Set hidden ``_rotation`` from quaternions by removing ``rots_bias``."""
        self._rotation = rots - self.rots_bias[None, :]

    def from_xyz(self, xyz):
        """Set hidden ``_xyz`` by normalizing world positions with ``aabb``."""
        self._xyz = (xyz - self.aabb[None, :3]) / self.aabb[None, 3:]

    def from_features(self, features):
        """Set the DC colour coefficients (``_features_rest`` is left untouched)."""
        self._features_dc = features

    def from_opacity(self, opacities):
        """Set hidden ``_opacity`` so that ``get_opacity`` returns ``opacities``."""
        self._opacity = self.inverse_opacity_activation(opacities) - self.opacity_bias

    def construct_list_of_attributes(self):
        """PLY vertex property names written by :meth:`save_ply`, in column order."""
        l = ['x', 'y', 'z', 'nx', 'ny', 'nz']
        # One f_dc_i per DC colour value (shape[1] * shape[2] of _features_dc).
        for i in range(self._features_dc.shape[1] * self._features_dc.shape[2]):
            l.append('f_dc_{}'.format(i))
        l.append('opacity')
        for i in range(self._scaling.shape[1]):
            l.append('scale_{}'.format(i))
        for i in range(self._rotation.shape[1]):
            l.append('rot_{}'.format(i))
        return l

    def save_ply(self, path, transform=_DEFAULT_PLY_TRANSFORM):
        """Write the Gaussians to a 3DGS-style PLY file.

        Positions, zero normals, DC colour coefficients, opacity logits, log
        scales and quaternions (``_rotation + rots_bias``, not re-normalized
        here) are written as float32 vertex properties. Higher-order SH
        coefficients (``_features_rest``) are not written.

        Args:
            path: Output file path.
            transform: Optional 3x3 matrix ``T`` applied to positions
                (``xyz @ T.T``) and rotations (``T @ R``); ``None`` writes raw values.
        """
        xyz = self.get_xyz.detach().cpu().numpy()
        normals = np.zeros_like(xyz)
        f_dc = self._features_dc.detach().transpose(1, 2).flatten(start_dim=1).contiguous().cpu().numpy()
        opacities = inverse_sigmoid(self.get_opacity).detach().cpu().numpy()
        scale = torch.log(self.get_scaling).detach().cpu().numpy()
        rotation = (self._rotation + self.rots_bias[None, :]).detach().cpu().numpy()

        if transform is not None:
            import utils3d
            transform = np.array(transform)
            xyz = np.matmul(xyz, transform.T)
            rotation = utils3d.numpy.quaternion_to_matrix(rotation)
            rotation = np.matmul(transform, rotation)
            rotation = utils3d.numpy.matrix_to_quaternion(rotation)

        dtype_full = [(attribute, 'f4') for attribute in self.construct_list_of_attributes()]

        elements = np.empty(xyz.shape[0], dtype=dtype_full)
        attributes = np.concatenate((xyz, normals, f_dc, opacities, scale, rotation), axis=1)
        elements[:] = list(map(tuple, attributes))
        el = PlyElement.describe(elements, 'vertex')
        PlyData([el]).write(path)

    def load_ply(self, path, transform=_DEFAULT_PLY_TRANSFORM):
        """Load Gaussians from a 3DGS-style PLY file into the hidden tensors.

        Args:
            path: PLY file to read.
            transform: Optional 3x3 matrix that :meth:`save_ply` applied; its
                inverse (its transpose, assuming it is orthogonal) is applied
                here. Pass ``None`` to load raw coordinates.

        Raises:
            ValueError: if ``sh_degree > 0`` and the number of ``f_rest_*``
                properties is not ``3 * ((sh_degree + 1) ** 2 - 1)``.
        """
        plydata = PlyData.read(path)
        vertices = plydata.elements[0]

        xyz = np.stack((np.asarray(vertices["x"]),
                        np.asarray(vertices["y"]),
                        np.asarray(vertices["z"])), axis=1)
        num_points = xyz.shape[0]

        def read_indexed(prefix):
            # Gather properties named "<prefix>..._<i>" as columns sorted by numeric suffix i.
            names = sorted(
                (p.name for p in vertices.properties if p.name.startswith(prefix)),
                key=lambda name: int(name.split('_')[-1]),
            )
            values = np.zeros((num_points, len(names)))
            for column, name in enumerate(names):
                values[:, column] = np.asarray(vertices[name])
            return values

        opacities = np.asarray(vertices["opacity"])[..., np.newaxis]

        features_dc = np.zeros((num_points, 3, 1))
        features_dc[:, 0, 0] = np.asarray(vertices["f_dc_0"])
        features_dc[:, 1, 0] = np.asarray(vertices["f_dc_1"])
        features_dc[:, 2, 0] = np.asarray(vertices["f_dc_2"])

        if self.sh_degree > 0:
            num_rest_coeffs = (self.sh_degree + 1) ** 2 - 1
            features_extra = read_indexed("f_rest_")
            if features_extra.shape[1] != 3 * num_rest_coeffs:
                raise ValueError(
                    f"Expected {3 * num_rest_coeffs} f_rest_* properties for sh_degree="
                    f"{self.sh_degree}, found {features_extra.shape[1]}."
                )
            # (P, 3 * K) -> (P, 3, K), where K counts the SH coefficients excluding DC.
            features_extra = features_extra.reshape((num_points, 3, num_rest_coeffs))

        scales = read_indexed("scale_")
        rots = read_indexed("rot")

        if transform is not None:
            import utils3d
            transform = np.array(transform)
            # Undo save_ply. Positions were written as xyz @ T.T, so xyz @ T inverts that
            # for orthogonal T. Rotations were written as T @ R, so T.T @ R inverts them.
            xyz = np.matmul(xyz, transform)
            rotation = utils3d.numpy.quaternion_to_matrix(rots)
            rotation = np.matmul(transform.T, rotation)
            rots = utils3d.numpy.matrix_to_quaternion(rotation)

        # Convert stored values (opacity logits, log scales) to activated attributes ...
        xyz = torch.tensor(xyz, dtype=torch.float, device=self.device)
        features_dc = torch.tensor(features_dc, dtype=torch.float, device=self.device).transpose(1, 2).contiguous()
        opacities = torch.sigmoid(torch.tensor(opacities, dtype=torch.float, device=self.device))
        scales = torch.exp(torch.tensor(scales, dtype=torch.float, device=self.device))
        rots = torch.tensor(rots, dtype=torch.float, device=self.device)

        # ... then into the hidden parameters via the regular setters.
        self.from_xyz(xyz)
        self._features_dc = features_dc
        if self.sh_degree > 0:
            self._features_rest = torch.tensor(
                features_extra, dtype=torch.float, device=self.device
            ).transpose(1, 2).contiguous()
        else:
            self._features_rest = None
        self.from_opacity(opacities)
        self.from_scaling(scales)
        self.from_rotation(rots)

    # --- Extensions merged from gs_loader.py and gs_renderer.py ---
    def compute_scale_from_visible_points(self, object_lwh_real: np.ndarray, opacity_threshold: float = 0.01) -> torch.Tensor:
        """Per-axis factors that resize the visible Gaussians to a target box size.

        Gaussians with opacity above ``opacity_threshold`` define an axis-aligned
        bounding box. The result is ``_GS_LWH_TO_XYZ_SCALE * object_lwh_real / box_size``,
        with each box side floored at 1e-6.

        Returns:
            A float32 tensor of shape (3,). It is all ones when no Gaussian is visible.
        """
        opacity = self.get_opacity.squeeze(-1)
        visible = opacity > opacity_threshold
        if not torch.any(visible):
            return torch.ones(3, device=self.device, dtype=torch.float32)

        visible_xyz = self.get_xyz[visible].detach().cpu().numpy()
        min_xyz = np.min(visible_xyz, axis=0)
        max_xyz = np.max(visible_xyz, axis=0)
        object_size_np = np.maximum(max_xyz - min_xyz, 1e-6)

        real_size = torch.tensor(object_lwh_real, device=self.device, dtype=torch.float32)
        current_size = torch.tensor(object_size_np, device=self.device, dtype=torch.float32)
        scale_wlh = torch.tensor(self._GS_LWH_TO_XYZ_SCALE, device=self.device, dtype=torch.float32)
        scale = scale_wlh * real_size / current_size
        return scale

    @staticmethod
    def load_and_transform_ply_simple(
        ply_path: Path,
        target_lwh: np.ndarray,
        target_object_to_ego: np.ndarray,
        opacity_threshold: float = 0.01,
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Load a Gaussian PLY as a coloured point cloud placed in the ego frame.

        Points whose sigmoid(opacity) is at most ``opacity_threshold`` are dropped.
        The rest are centred on their bounding box and scaled per axis to
        ``target_lwh`` (times a fixed shrink factor). They are then mapped by the
        4x4 ``target_object_to_ego``. Colours come from the DC SH term,
        ``f_dc * SH_C0 + 0.5``, clipped to [0, 1].

        Returns:
            ``(points_ego, colors)``, each of shape (M, 3). Both are empty when no
            point is visible.
        """
        from plyfile import PlyData as _PlyData

        ply = _PlyData.read(str(ply_path))
        el = ply.elements[0]

        x = np.asarray(el["x"], dtype=np.float32)
        y = np.asarray(el["y"], dtype=np.float32)
        z = np.asarray(el["z"], dtype=np.float32)
        points = np.stack([x, y, z], axis=1)

        f0 = np.asarray(el["f_dc_0"], dtype=np.float32)
        f1 = np.asarray(el["f_dc_1"], dtype=np.float32)
        f2 = np.asarray(el["f_dc_2"], dtype=np.float32)
        fdc = np.stack([f0, f1, f2], axis=1)
        # Degree-0 real spherical-harmonic constant 1 / (2 * sqrt(pi)).
        SH_C0 = 0.28209479177387814
        colors = (fdc * SH_C0) + 0.5
        colors = np.clip(colors, 0.0, 1.0)

        op_logit = np.asarray(el["opacity"], dtype=np.float32)
        opacity = 1.0 / (1.0 + np.exp(-op_logit))
        visible = opacity > opacity_threshold

        if not np.any(visible):
            return np.zeros((0, 3), dtype=np.float32), np.zeros((0, 3), dtype=np.float32)

        points = points[visible]
        colors = colors[visible]

        min_xyz = np.min(points, axis=0)
        max_xyz = np.max(points, axis=0)
        center = (min_xyz + max_xyz) / 2.0
        current_lwh = max_xyz - min_xyz
        current_lwh[current_lwh < 1e-6] = 1e-6

        # NOTE(review): these shrink factors differ from _GS_LWH_TO_XYZ_SCALE
        # (0.90, 0.90, 0.88) used by compute_scale_from_visible_points. Confirm
        # whether the difference is intentional.
        scale_wlh = np.array([0.96, 0.96, 0.92], dtype=np.float32)
        scale_factors = (target_lwh / current_lwh) * scale_wlh
        points_centered = points - center
        points_scaled = points_centered * scale_factors

        ones = np.ones((points_scaled.shape[0], 1), dtype=np.float32)
        points_h = np.hstack([points_scaled, ones])
        points_ego = (target_object_to_ego @ points_h.T).T[:, :3]
        return points_ego, colors

    @staticmethod
    def _empty_render(H: int, W: int) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """Return a black uint8 RGB image, a zero uint8 alpha map and a zero float32 depth map."""
        return (
            np.zeros((H, W, 3), dtype=np.uint8),
            np.zeros((H, W), dtype=np.uint8),
            np.zeros((H, W), dtype=np.float32),
        )

    @staticmethod
    def render_objects_for_camera(
        device: torch.device,
        camera_pose_w2c: np.ndarray,
        K: np.ndarray,
        dist_coeffs: np.ndarray,
        W: int,
        H: int,
        objects_to_render: List[Dict[str, Any]],
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """Rasterize several Gaussian objects into one camera view and composite them.

        Each entry of ``objects_to_render`` must provide ``"ply_path"``,
        ``"object_lwh"`` and ``"object_to_world"`` (a 4x4 matrix). Objects whose
        PLY file does not exist are skipped. Objects that fail to load or render
        are logged and skipped, so one bad object does not abort the frame.
        Layers are composited from the farthest object origin to the nearest.

        Args:
            device: Torch device used for loading and rasterization.
            camera_pose_w2c: 4x4 world-to-camera matrix.
            K: 3x3 pinhole intrinsics (fx, fy, cx, cy are read from it).
            dist_coeffs: Ignored; rendering always uses zero distortion.
            W: Output image width.
            H: Output image height.
            objects_to_render: Per-object dicts as described above.

        Returns:
            ``(rgb_uint8 (H, W, 3), alpha_uint8 (H, W), depth_float32 (H, W))``.
            All three are zeros when nothing was rendered.
        """
        import logging

        if len(objects_to_render) == 0:
            return Gaussian._empty_render(H, W)

        logger = logging.getLogger(__name__)

        fx, fy, cx, cy = float(K[0, 0]), float(K[1, 1]), float(K[0, 2]), float(K[1, 2])
        Ks = torch.tensor([[fx, 0.0, cx], [0.0, fy, cy], [0.0, 0.0, 1.0]], device=device, dtype=torch.float32).unsqueeze(0)
        # Ignore the externally supplied distortion parameters; always render with zero distortion.
        radial_coeffs = torch.zeros((1, 6), device=device, dtype=torch.float32)
        tangential_coeffs = torch.zeros((1, 2), device=device, dtype=torch.float32)

        color_list: List[torch.Tensor] = []
        alpha_list: List[torch.Tensor] = []
        depth_list: List[torch.Tensor] = []
        distance_list: List[float] = []

        # 180-degree rotation about Z, applied in object space to match the reference implementation.
        rotation_fix_z_180 = R.from_euler('z', 180, degrees=True).as_matrix()
        transform_fix_z_180 = np.eye(4, dtype=np.float32)
        transform_fix_z_180[:3, :3] = rotation_fix_z_180

        for obj in objects_to_render:
            ply_path = Path(obj["ply_path"])  # type: ignore
            if not ply_path.exists():
                continue
            try:
                model = Gaussian(
                    aabb=[-0.5, -0.5, -0.5, 1.0, 1.0, 1.0],
                    scaling_bias=0.004,
                    opacity_bias=0.1,
                    scaling_activation="softplus",
                    device=device,
                )
                model.load_ply(str(ply_path), transform=None)

                scale_factors = model.compute_scale_from_visible_points(np.asarray(obj["object_lwh"], dtype=np.float32), opacity_threshold=0.01)

                object_to_world = np.asarray(obj["object_to_world"], dtype=np.float32)
                object_to_world = object_to_world @ transform_fix_z_180
                object_to_camera = camera_pose_w2c @ object_to_world
                # Distance from the camera to the object origin, used for back-to-front ordering.
                dist = float(np.linalg.norm(object_to_camera[:3, 3]))

                o2c = torch.from_numpy(object_to_camera).to(torch.float32).to(device).unsqueeze(0)

                # The object-to-camera matrix is passed as the view matrix, so means stay in object space.
                colors_with_depth, alphas, meta = rasterization(
                    model.get_xyz * scale_factors,
                    model.get_rotation,
                    model.get_scaling * scale_factors,
                    model.get_opacity.squeeze(-1),
                    model.get_features,
                    o2c,
                    Ks,
                    W,
                    H,
                    sh_degree=0,
                    render_mode="RGB+D",
                    rasterize_mode='antialiased',
                    packed=False,
                    with_ut=True,
                    radial_coeffs=radial_coeffs,
                    tangential_coeffs=tangential_coeffs,
                )

                distance_list.append(dist)

                output_tensor = colors_with_depth[0]
                color_list.append(output_tensor[..., :3])
                alpha_list.append(alphas[0, ..., 0])
                depth_list.append(output_tensor[..., 3])
            except Exception:
                # Best-effort per object, but record why an object was dropped.
                logger.warning("Skipping Gaussian object %s: load/render failed", ply_path, exc_info=True)
                continue

        if len(color_list) == 0:
            return Gaussian._empty_render(H, W)

        # Farthest first, so nearer objects are painted over farther ones.
        order = np.argsort(np.asarray(distance_list))[::-1]
        color_np = [color_list[i].detach().cpu().numpy() for i in order]
        alpha_np = [alpha_list[i].detach().cpu().numpy() for i in order]
        depth_np = [depth_list[i].detach().cpu().numpy() for i in order]

        canvas = np.zeros((H, W, 3), dtype=np.float32)
        alpha_canvas = np.zeros((H, W), dtype=np.float32)
        depth_canvas = np.zeros((H, W), dtype=np.float32)
        # NOTE(review): gsplat documents "RGB+D" depth as *accumulated* (alpha-weighted)
        # depth. Without a background, its colours are alpha-accumulated (effectively
        # premultiplied). If that holds for the gsplat version in use, `c * a3` applies
        # alpha twice and `d` is not a true depth. Confirm before changing the
        # blending, because downstream consumers may rely on the current output.
        for c, a, d in zip(color_np, alpha_np, depth_np):
            a = np.clip(a, 0.0, 1.0)
            if not np.any(a > 0):
                continue
            a3 = a[..., None]
            canvas = c * a3 + canvas * (1.0 - a3)
            alpha_canvas = a + alpha_canvas * (1.0 - a)
            # Keep the smallest positive depth seen so far at each covered pixel.
            update_mask = (a > 0) & ((depth_canvas <= 0) | (d < depth_canvas))
            if np.any(update_mask):
                depth_canvas[update_mask] = d[update_mask]

        canvas_u8 = (np.clip(canvas, 0.0, 1.0) * 255).astype(np.uint8)
        alpha_u8 = (np.clip(alpha_canvas, 0.0, 1.0) * 255).astype(np.uint8)
        return canvas_u8, alpha_u8, depth_canvas


def inverse_sigmoid(x):
    """Return the logit ``log(x / (1 - x))``, the inverse of ``torch.sigmoid``.

    Inputs of exactly 0 or 1 give ``-inf`` / ``inf``; values outside [0, 1] give NaN.
    """
    odds = x / (1 - x)
    return torch.log(odds)


def strip_lowerdiag(L):
    """Pack six entries of each 3x3 matrix in a batch into a (N, 6) tensor.

    Despite the name, the entries come from the upper triangle, in the order
    (0,0), (0,1), (0,2), (1,1), (1,2), (2,2). For symmetric matrices this
    carries the same information as the lower triangle.

    Args:
        L: Tensor of shape (N, 3, 3).

    Returns:
        A float32 tensor of shape (N, 6) on ``L``'s device.
    """
    upper_index_pairs = ((0, 0), (0, 1), (0, 2), (1, 1), (1, 2), (2, 2))
    columns = [L[:, row, col] for row, col in upper_index_pairs]
    return torch.stack(columns, dim=1).to(dtype=torch.float)


def strip_symmetric(sym):
    """Compress a batch of symmetric 3x3 matrices to their 6 unique entries.

    This is a thin alias for ``strip_lowerdiag``; see it for the column order.
    """
    packed = strip_lowerdiag(sym)
    return packed


def build_rotation(r):
    """Convert a batch of quaternions to 3x3 rotation matrices.

    Args:
        r: Tensor of shape (N, 4). Column 0 is used as the real part, so the
            order is (w, x, y, z). Rows need not be unit length, because each
            row is divided by its Euclidean norm first.

    Returns:
        A tensor of shape (N, 3, 3) in the default floating dtype, on ``r``'s device.
    """
    squares = r * r
    norm = torch.sqrt(squares[:, 0] + squares[:, 1] + squares[:, 2] + squares[:, 3])
    unit = r / norm.unsqueeze(1)
    w, x, y, z = (unit[:, k] for k in range(4))

    # Row-major entries of the standard quaternion-to-matrix formula.
    entries = (
        1 - 2 * (y * y + z * z), 2 * (x * y - w * z), 2 * (x * z + w * y),
        2 * (x * y + w * z), 1 - 2 * (x * x + z * z), 2 * (y * z - w * x),
        2 * (x * z - w * y), 2 * (y * z + w * x), 1 - 2 * (x * x + y * y),
    )
    matrices = torch.stack(entries, dim=1).reshape(unit.shape[0], 3, 3)
    return matrices.to(dtype=torch.get_default_dtype(), device=r.device)


def build_scaling_rotation(s, r):
    """Return ``R @ diag(s)`` for each item in a batch.

    The ``covariance`` helper in ``Gaussian.setup_functions`` forms
    ``M @ M^T`` from this product.

    Args:
        s: Tensor of shape (N, >=3); only the first three columns are used as
            per-axis scales.
        r: Tensor of shape (N, 4) of quaternions, forwarded to ``build_rotation``.

    Returns:
        A tensor of shape (N, 3, 3).
    """
    rotation = build_rotation(r)
    scale_matrix = torch.zeros((s.shape[0], 3, 3), dtype=torch.float, device=s.device)
    for axis in range(3):
        scale_matrix[:, axis, axis] = s[:, axis]
    return rotation @ scale_matrix
