import math
import torch
import numpy as np
import torch.nn.functional as F

from plyfile import PlyData
from pointnet2_ops import pointnet2_utils

def xyz_norm(centroids):
    """Center a point set on its mean and scale it into the unit sphere.

    Input:
        centroids: 2-D array of points, one point per row.
    Return:
        A new array: the points minus their per-column mean, divided by the
        largest Euclidean row norm of the centered points. The input array is
        not modified.
    """
    centered = centroids - np.mean(centroids, axis=0)
    # Farthest distance of any centered point from the origin.
    radius = np.max(np.sqrt(np.sum(centered**2, axis=1)))
    return centered / radius

def translate_pointcloud(pointcloud):
    """Apply a random per-axis scale and shift to a point cloud.

    Input:
        pointcloud: array whose last dimension has 3 entries (x, y, z).
    Return:
        float32 array `pointcloud * s + t`, where `s` is drawn uniformly from
        [0.8, 1.25) and `t` from [-0.1, 0.1), independently per axis, using
        NumPy's global random state.
    """
    # Draw the scale first, then the shift, so the RNG stream is consumed in a fixed order.
    axis_scale = np.random.uniform(low=0.8, high=1.25, size=[3])
    axis_shift = np.random.uniform(low=-0.1, high=0.1, size=[3])
    scaled = np.multiply(pointcloud, axis_scale)
    shifted = np.add(scaled, axis_shift)
    return shifted.astype('float32')

def normalize_xyz_scales(centroids, scales):
    """Center Gaussian centroids, scale them into the unit sphere, and rescale sizes to match.

    Input:
        centroids: 2-D array of positions, one row per Gaussian.
        scales: array of Gaussian sizes; divided by the same radius as the
            centroids so sizes stay consistent with the normalized positions.
    Return:
        (centroids_n, scales_n), both float32. `centroids_n` has zero mean and
        its farthest point lies at distance 1 from the origin.
    """
    mean_c = centroids.mean(axis=0, keepdims=True)
    c_centered = centroids - mean_c
    # The radius must be measured on the *centered* points; measuring it on the
    # raw centroids would leave the result dependent on the cloud's offset.
    max_radius = np.max(np.sqrt(np.sum(c_centered**2, axis=1)))
    centroids_n = c_centered / max_radius

    scales_n = scales / max_radius

    return centroids_n.astype('float32'), scales_n.astype('float32')

def augment_gs(centroids_n, scales_n,
               scale_low=0.8, scale_high=1.25,
               shift_low=-0.1, shift_high=0.1):
    """Randomly stretch and shift Gaussian centroids, stretching their sizes alike.

    Input:
        centroids_n: array of positions with 3 entries in the last dimension.
        scales_n: array of sizes with 3 entries in the last dimension.
        scale_low, scale_high: bounds of the uniform per-axis stretch factor.
        shift_low, shift_high: bounds of the uniform per-axis shift.
    Return:
        (centroids_aug, scales_aug), both float32. The same per-axis stretch
        is applied to positions and sizes; the shift only affects positions.
    """
    # Stretch is drawn before shift; both come from NumPy's global random state.
    stretch = np.random.uniform(scale_low, scale_high, size=[3])
    offset = np.random.uniform(shift_low, shift_high, size=[3])

    moved = (centroids_n * stretch + offset).astype('float32')
    resized = (scales_n * stretch).astype('float32')
    return moved, resized

def index_points(points, idx):
    """
    Gather, for every batch element, the rows of `points` selected by `idx`.

    Input:
        points: input points data, [B, N, C]
        idx: sample index data, [B, S] (additional trailing index dimensions
            are gathered the same way)
    Return:
        new_points:, indexed points data, [B, S, C]; always a new tensor
    Raises:
        ValueError: if `points` has an empty batch or point dimension, or if
            any index lies outside [0, N).
    """
    B = points.shape[0]
    N = points.shape[1]

    # Defensive checks
    if B == 0 or N == 0:
        raise ValueError(f"[index_points] Empty tensor detected: points.shape={points.shape}, idx.shape={idx.shape}")

    # An empty index set has nothing to validate (max()/min() raise on empty tensors).
    if idx.numel() > 0 and (idx.max() >= N or idx.min() < 0):
        raise ValueError(f"[index_points] Invalid index in fps: "
                         f"max idx={idx.max().item()}, N={N}, idx.shape={idx.shape}")

    # Batch ids shaped [B, 1, ..., 1] broadcast against idx during indexing, so
    # no explicit repeat is needed. Advanced indexing already returns a copy,
    # so `points` does not have to be cloned first.
    batch_shape = [idx.shape[0]] + [1] * (idx.dim() - 1)
    batch_indices = torch.arange(B, dtype=torch.long, device=points.device).view(batch_shape)
    return points[batch_indices, idx, :]

def loadply(dir, max_sh_degree, activate=True):
    """Load a Gaussian-splat ``.ply`` file into a dict of tensors.

    Input:
        dir: path (or file object) handed to ``PlyData.read``.
        max_sh_degree: highest spherical-harmonics degree to load (>= 0).
        activate: if True, return ``exp(scale)`` and ``sigmoid(opacity)``
            instead of the raw stored values.
    Return:
        dict with, for N vertices in the file's first element:
            "xyz":     [N, 3]
            "opacity": [N, 1]
            "sh":      [N, 3, (max_sh_degree + 1) ** 2]  (index 0 is the DC term)
            "scale":   [N, number of ``scale_*`` properties]
            "q":       [N, number of ``rot*`` properties], L2-normalized per row
    Raises:
        AssertionError: if ``max_sh_degree`` is negative or the file stores
            fewer ``f_rest_*`` properties than the requested degree needs.
    """
    assert max_sh_degree >= 0, "max_sh_degree must >= 0"
    vertex = PlyData.read(dir).elements[0]
    prop_names = [p.name for p in vertex.properties]

    def _numbered(prefix):
        # Properties such as scale_0, scale_1, ... ordered by their numeric suffix.
        names = [n for n in prop_names if n.startswith(prefix)]
        return sorted(names, key=lambda n: int(n.split('_')[-1]))

    def _columns(names, rows):
        # One float tensor column per named property.
        table = torch.zeros((rows, len(names)))
        for col, name in enumerate(names):
            table[:, col] = torch.tensor(vertex[name])
        return table

    xyz = torch.stack((torch.tensor(vertex["x"]),
                       torch.tensor(vertex["y"]),
                       torch.tensor(vertex["z"])), dim=1)
    num_points = xyz.shape[0]

    opacities = torch.tensor(vertex["opacity"])[..., None]

    features = torch.zeros((num_points, 3, (max_sh_degree + 1) ** 2))
    features[:, 0, 0] = torch.tensor(vertex["f_dc_0"])
    features[:, 1, 0] = torch.tensor(vertex["f_dc_1"])
    features[:, 2, 0] = torch.tensor(vertex["f_dc_2"])
    if max_sh_degree > 0:
        extra_f_names = _numbered("f_rest_")
        wanted = (max_sh_degree + 1) ** 2 - 1  # non-DC coefficients per channel
        assert len(extra_f_names) >= 3 * (max_sh_degree + 1) ** 2 - 3, f"max_sh_degree must < {len(extra_f_names)}"
        # f_rest_* is laid out channel-major: all coefficients of channel 0,
        # then channel 1, then channel 2. The per-channel stride therefore
        # comes from the file, which may store a higher degree than requested.
        # NOTE(review): channel-major layout is the reference 3DGS export
        # convention; confirm for files produced by other tools.
        stored = len(extra_f_names) // 3
        for rgb in range(3):
            for coef in range(wanted):
                attr_name = extra_f_names[rgb * stored + coef]
                features[:, rgb, coef + 1] = torch.tensor(vertex[attr_name])

    scales = _columns(_numbered("scale_"), num_points)
    if activate:
        scales = torch.exp(scales)
        opacities = torch.sigmoid(opacities)

    rots = _columns(_numbered("rot"), num_points)
    rots = torch.nn.functional.normalize(rots)

    return {
            "xyz": xyz,                 # [N, 3]
            "opacity": opacities,       # [N, 1]
            "sh": features,             # [N, 3, (max_sh_degree + 1) ** 2]
            "scale": scales,            # [N, len(scale_* properties)]
            "q": rots                   # [N, len(rot* properties)]
            }


def load_view_gs(dir,a):
    """Read a Gaussian-splat ``.ply`` and pack its attributes into one 2-D array.

    Input:
        dir: path (or file object) handed to ``PlyData.read``; the ``'vertex'``
            element is used.
        a: unused.
    Return:
        Array of shape [n, 3 + 1 + S + R + 3] with columns, in order:
        xyz, sigmoid(opacity), exp(scale_*), normalized rot_*, f_dc_0..2,
        where S and R are the number of ``scale_*`` / ``rot*`` properties.
        The scale/rotation/SH buffers use NumPy's default float dtype, so the
        concatenated result is float64.
    """
    vertex = PlyData.read(dir)['vertex']

    def _f32(name):
        return vertex[name].astype(np.float32)

    def _numbered_table(prefix, rows):
        # Columns for prefix_0, prefix_1, ... ordered by numeric suffix.
        names = sorted(
            (prop.name for prop in vertex.properties if prop.name.startswith(prefix)),
            key=lambda name: int(name.split("_")[-1]),
        )
        table = np.zeros((rows, len(names)))
        for col, name in enumerate(names):
            table[:, col] = _f32(name)
        return table

    # Gaussian centers, [n, 3]
    centroids = np.stack([_f32(axis) for axis in ('x', 'y', 'z')], axis=-1)
    n = centroids.shape[0]

    # Raw (pre-sigmoid) opacity, [n, 1]
    opacity = _f32('opacity').reshape(-1, 1)

    # Raw (pre-exp) scales
    scales = _numbered_table("scale_", n)

    # Rotation components, normalized per row; epsilon guards all-zero rows
    rots = _numbered_table("rot", n)
    rots = rots / (np.linalg.norm(rots, axis=1, keepdims=True) + 1e-9)

    # DC spherical-harmonic term, one column per color channel
    sh_base = np.zeros((n, 3))
    for channel, name in enumerate(('f_dc_0', 'f_dc_1', 'f_dc_2')):
        sh_base[:, channel] = _f32(name)

    # Activations: exp for scale, logistic sigmoid for opacity
    scales = np.exp(scales)
    opacity = 1 / (1 + np.exp(-opacity))

    return np.concatenate([centroids, opacity, scales, rots, sh_base], axis=1)



def nomal_gs(gs_tensor):
    """Rescale a packed Gaussian tensor so its bounding box fits in [-1, 1].

    Input:
        gs_tensor: 2-D tensor; columns 0:3 are read as xyz and columns 4:7 as
            log-scales (they are exponentiated, multiplied, and logged again).
    Return:
        A clone of `gs_tensor` in which xyz is shifted to the bounding-box
        center and multiplied by ``2 / longest_box_side``, and columns 4:7 are
        ``log(exp(scale) * factor)``. Other columns are copied unchanged.
    """
    positions = gs_tensor[:, :3]
    lower, _ = positions.min(dim=0)
    upper, _ = positions.max(dim=0)

    box_center = (lower + upper) / 2
    longest_side = (upper - lower).max()
    scale_factor = 2.0 / longest_side

    result = gs_tensor.clone()
    result[:, :3] = (positions - box_center) * scale_factor
    # Apply the same factor to the sizes, which are kept in log space.
    result[:, 4:7] = torch.log(torch.exp(gs_tensor[:, 4:7]) * (scale_factor))
    return result

def center_gs(gs_tensor):
    """Shift a packed Gaussian tensor so its xyz bounding box is centered at the origin.

    Input:
        gs_tensor: 2-D tensor whose columns 0:3 are xyz.
    Return:
        (gs_centered, max_abs): a clone of the input with centered xyz, and a
        length-3 tensor holding the largest absolute centered coordinate per
        axis.
    """
    coords = gs_tensor[:, :3]
    lower, _ = coords.min(dim=0)
    upper, _ = coords.max(dim=0)
    shifted = coords - (lower + upper) / 2.0

    result = gs_tensor.clone()
    result[:, :3] = shifted

    # Per-axis half extent: [max_abs_x, max_abs_y, max_abs_z]
    half_extent, _ = shifted.abs().max(dim=0)
    return result, half_extent


def fps(gs_tensor, N):
    """Sample `N` rows from a single unbatched Gaussian tensor via `gs_xyz_fps`.

    Input:
        gs_tensor: 2-D tensor [P, C]; a batch dimension of size 1 is added
            before sampling and removed afterwards.
        N: number of rows to sample.
    Return:
        The sampled tensor [N, C]; the sample indices are discarded.
    """
    batched = gs_tensor.unsqueeze(0)
    sampled = gs_xyz_fps(batched, N)[1]
    return sampled.squeeze(0)


def gs_xyz_fps(gs_tensor, num):
    """
    Sample `num` rows per batch element, using the first three channels as xyz.

    Input:
        gs_tensor: (B, P, C)  expected
        num: desired sample count S
    Return:
        fps_idx: (B, S) long indices
        gs_fps_tensor: (B, S, C) sampled rows

    Behavior by input size:
        P == 0:   all-zero indices and an all-zero (B, S, C) tensor.
        P <  num: no sampling; the P rows are cycled until S rows are filled,
                  and fps_idx holds the matching source rows (j % P), so that
                  gs_tensor[b, fps_idx[b]] equals gs_fps_tensor[b].
        P >= num: indices come from pointnet2_utils.furthest_point_sample on
                  gs_tensor[:, :, :3]; rows are gathered with index_points.
    Raises:
        RuntimeError: if furthest_point_sample fails (original error chained).
    """
    B, P, C = gs_tensor.shape
    device = gs_tensor.device

    # no point
    if P == 0:
        fps_idx = torch.zeros((B, num), dtype=torch.long, device=device)
        gs_fps_tensor = gs_tensor.new_zeros((B, num, C))
        return fps_idx, gs_fps_tensor

    if P < num:
        # Too few points to sample from: pad by cycling through the existing
        # rows. Indices wrap modulo P so they stay valid for the input tensor.
        wrapped = torch.arange(num, dtype=torch.long, device=device) % P
        fps_idx = wrapped.unsqueeze(0).repeat(B, 1)
        pts_rep = gs_tensor[:, wrapped, :].contiguous()  # (B, num, C)
        return fps_idx, pts_rep

    try:
        fps_idx = pointnet2_utils.furthest_point_sample(gs_tensor[:, :, :3].contiguous(), num).long()
    except Exception as e:
        raise RuntimeError(f"[gs_xyz_fps] furthest_point_sample failed: {e}; gs_tensor.shape={gs_tensor.shape}") from e

    gs_fps_tensor = index_points(gs_tensor.contiguous(), fps_idx)
    return fps_idx, gs_fps_tensor

def select_near_half_frustum_points(gs_xyz, cam_info, ):
    """Return indices of points in the camera frustum that lie in its nearer depth half.

    Input:
        gs_xyz: [N, 3] point positions.
        cam_info: dict with keys "cam_pos", "cam_R", "fov_deg", "near", "far",
            "aspect_ratio". Points are mapped as ``cam_R @ (p - cam_pos)`` and
            the third resulting coordinate is treated as depth.
    Return:
        1-D long tensor of indices into `gs_xyz` (ascending). A point qualifies
        if near < z < far, |x/z| < tan(fov/2) * aspect_ratio, |y/z| < tan(fov/2),
        and z is strictly below the midpoint of the depth range spanned by the
        in-frustum points. Empty if no point is inside the frustum.
    """
    device = gs_xyz.device

    # Camera parameters (all six keys are required)
    cam_pos, R, fov_deg, near, far, aspect_ratio = (
        cam_info[key]
        for key in ("cam_pos", "cam_R", "fov_deg", "near", "far", "aspect_ratio")
    )

    # Offsets from the camera position, transformed by cam_R
    cam_xyz = R.matmul((gs_xyz - cam_pos).t()).t()
    horiz = cam_xyz[:, 0]
    vert = cam_xyz[:, 1]
    depth = cam_xyz[:, 2]

    # Frustum test; points with invalid depth are rejected by the depth term
    half_fov_tan = math.tan(math.radians(fov_deg) / 2.0)
    depth_ok = (depth > near) & (depth < far)
    horiz_ok = torch.abs(horiz / depth) < half_fov_tan * aspect_ratio
    vert_ok = torch.abs(vert / depth) < half_fov_tan
    inside = depth_ok & horiz_ok & vert_ok

    if not inside.any():
        return torch.empty(0, dtype=torch.long, device=device)

    # Keep only the nearer half of the depth range occupied by in-frustum points
    inside_depth = depth[inside]
    nearest = inside_depth.min()
    depth_span = inside_depth.max() - nearest
    cutoff = nearest + depth_span * 0.5
    near_mask = inside & (depth < cutoff)

    return near_mask.nonzero(as_tuple=True)[0]

def generate_external_lookat_camera(gs_xyz, k=1, fov_deg=60.0, aspect_ratio=1.0, near=0.1, far=10, distance_scale=2.5):
    """Create a random camera on a sphere around the origin, oriented along the origin axis.

    Input:
        gs_xyz: tensor used only to pick the device for the new tensors.
        k: the camera is placed at distance ``3 * k`` from the origin.
        fov_deg, aspect_ratio, near, far: copied into the returned dict.
        distance_scale: unused.
    Return:
        dict with "cam_pos" (3,), "cam_R" (3, 3), "fov_deg", "near", "far",
        "aspect_ratio". The columns of "cam_R" are right, up, and the unit
        vector pointing from the camera toward the origin.
    """
    device = gs_xyz.device
    orbit_radius = 3 * k

    # Random direction on the unit sphere (one torch.randn draw)
    direction = F.normalize(torch.randn(3, device=device), dim=0)
    cam_pos = direction * orbit_radius

    # `toward_origin` is the unit vector o - cam_pos; `forward` is its negation,
    # i.e. it points from the origin out through the camera.
    toward_origin = F.normalize(-cam_pos, dim=0)
    forward = -toward_origin

    # Pick a reference up vector that is not (nearly) parallel to `forward`.
    if torch.abs(torch.dot(forward, torch.tensor([0., 1., 0.], device=device))) > 0.99:
        up_guess = torch.tensor([1., 0., 0.], device=device)
    else:
        up_guess = torch.tensor([0., 1., 0.], device=device)

    right = F.normalize(torch.cross(up_guess, forward, dim=0), dim=0)
    up = torch.cross(forward, right, dim=0)

    # NOTE(review): the axes are stacked as *columns*, while the frustum
    # selectors in this file apply cam_R as ``cam_R @ (p - cam_pos)`` —
    # confirm which rotation convention is intended.
    R = torch.stack([right, up, -forward], dim=1)

    return dict(
        cam_pos=cam_pos,
        cam_R=R,
        fov_deg=fov_deg,
        near=near,
        far=far,
        aspect_ratio=aspect_ratio,
    )

def select_choice_frustum_points(gs_xyz, cam_info):
    """Return indices of points inside the camera frustum's nearer depth half.

    Input:
        gs_xyz: [N, 3] point positions.
        cam_info: dict with keys "cam_pos", "cam_R", "fov_deg", "near", "far",
            "aspect_ratio". Points are mapped as ``cam_R @ (p - cam_pos)`` and
            the third resulting coordinate is treated as depth.
    Return:
        1-D long tensor of indices into `gs_xyz`, in ascending order. A point
        qualifies if near < z < far, |x/z| < tan(fov/2) * aspect_ratio,
        |y/z| < tan(fov/2), and z is strictly below the midpoint of the depth
        range spanned by the in-frustum points. Empty if the frustum is empty.
    """
    device = gs_xyz.device

    # load camera parameters
    cam_pos: torch.Tensor = cam_info["cam_pos"]
    R: torch.Tensor       = cam_info["cam_R"]
    fov_deg: float        = cam_info["fov_deg"]
    near: float           = cam_info["near"]
    far: float            = cam_info["far"]
    aspect_ratio: float   = cam_info["aspect_ratio"]

    # offsets from the camera position, transformed by cam_R
    points_cam = (R.matmul((gs_xyz - cam_pos).t())).t()
    x, y, z = points_cam[:, 0], points_cam[:, 1], points_cam[:, 2]

    # view frustum filtering; points with invalid depth are rejected by valid_z
    tan_half_fov = math.tan(math.radians(fov_deg) / 2.0)
    valid_z = (z > near) & (z < far)
    valid_x = torch.abs(x / z) < tan_half_fov * aspect_ratio
    valid_y = torch.abs(y / z) < tan_half_fov
    in_frustum = valid_z & valid_x & valid_y

    # no points in frustum → return empty index
    if not in_frustum.any():
        return torch.empty(0, dtype=torch.long, device=device)

    # keep the nearer half of the depth range covered by in-frustum points
    z_values = z[in_frustum]
    z_min = z_values.min()
    near_threshold = z_min + (z_values.max() - z_min) * 0.5
    near_half_mask = in_frustum & (z < near_threshold)

    all_indices = torch.arange(gs_xyz.shape[0], device=device)
    return all_indices[near_half_mask]
