import itertools
import logging as log
from typing import Optional, Union, List, Dict, Sequence, Iterable, Collection, Callable

import torch
import torch.nn as nn
import torch.nn.functional as F
from loguru import logger

def get_normalized_directions(directions):
    """Remap direction components from [-1, 1] to [0, 1].

    The spherical-harmonics encoding consuming these values expects inputs
    in the range [0, 1].

    Args:
        directions: batch of directions whose components lie in [-1, 1].

    Returns:
        ``(directions + 1) / 2``, with the same shape as the input.
    """
    shifted = directions + 1.0
    return shifted / 2.0


def normalize_aabb(pts, aabb):
    """Affinely map points so that ``aabb[0]`` lands on -1 and ``aabb[1]`` on +1.

    Args:
        pts: point coordinates, broadcastable against one row of ``aabb``.
        aabb: two-row box description; row 0 is mapped to -1, row 1 to +1.

    Returns:
        The rescaled coordinates. Points outside the box fall outside [-1, 1].
    """
    first_corner = aabb[0]
    second_corner = aabb[1]
    scale = 2.0 / (second_corner - first_corner)
    return (pts - first_corner) * scale - 1.0

def grid_sample_wrapper(grid: torch.Tensor, coords: torch.Tensor, align_corners: bool = True) -> torch.Tensor:
    """Sample a 2D/3D feature grid at a flat list of normalized coordinates.

    Thin wrapper around ``F.grid_sample`` (bilinear mode, border padding) that
    adds missing batch dimensions and returns the result as a point list.

    Args:
        grid: feature grid of shape ``[B, feature_dim, *reso]``; the batch
            dimension may be omitted and is then added here.
        coords: sample positions of shape ``[n, grid_dim]`` or
            ``[B, n, grid_dim]`` with ``grid_dim`` in {2, 3}. Positions follow
            the ``F.grid_sample`` convention (normalized to [-1, 1]).
        align_corners: forwarded to ``F.grid_sample``.

    Returns:
        Interpolated features of shape ``[B, n, feature_dim]`` with *all*
        size-1 dimensions squeezed away (so e.g. ``[n, feature_dim]`` for
        ``B == 1``, or ``[n]`` when additionally ``feature_dim == 1``).

    Raises:
        NotImplementedError: if ``coords.shape[-1]`` is neither 2 nor 3.
    """
    grid_dim = coords.shape[-1]
    if grid_dim not in (2, 3):
        raise NotImplementedError(f"Grid-sample was called with {grid_dim}D data but is only "
                                  f"implemented for 2 and 3D data.")

    if grid.dim() == grid_dim + 1:
        # no batch dimension present, need to add it
        grid = grid.unsqueeze(0)
    if coords.dim() == 2:
        coords = coords.unsqueeze(0)

    # Insert singleton spatial dims so grid_sample sees [B, 1, (1,) n, grid_dim].
    coords = coords.view([coords.shape[0]] + [1] * (grid_dim - 1) + list(coords.shape[1:]))  # B, 1, N, 2
    B, feature_dim = grid.shape[:2]
    n = coords.shape[-2]
    # F.grid_sample coordinate mapping, worked example:
    #   grid = [B, feature_dim, *res] = [1, 16, 301, 50] with align_corners == True
    #   H = 301: -1 -> 0, 1 -> 300; W = 50: -1 -> 0, 1 -> 49
    #   formula: index = (normalized_coord + 1) * (size - 1) / 2
    #   ts=150, t=150/300=0.5, normalized_t=0.5*2-1=0, index=(0+1)*(301-1)/2=150
    interp = F.grid_sample(
        grid,  # [B, feature_dim, reso, ...]
        coords,  # [B, 1, ..., n, grid_dim]
        align_corners=align_corners,
        mode='bilinear', padding_mode='border')
    interp = interp.view(B, feature_dim, n).transpose(-1, -2)  # [B, n, feature_dim]

    # Build the NaN mask while the layout is still [B, n, feature_dim]: after
    # squeeze() the tensor may be 1-D (any(dim=1) would raise) or 3-D (dim=1
    # would be the point axis), so counting there is unreliable.
    nan_mask = torch.isnan(interp)
    interp = interp.squeeze()  # [B?, n, feature_dim?]

    if nan_mask.any():
        # Reports the number of sample points that contain at least one NaN feature.
        logger.warning(f"interp feature {interp.shape} has {nan_mask.any(dim=-1).sum()} nan")
    return interp

def init_grid_param(
        grid_nd: int,
        in_dim: int,
        out_dim: int,
        reso: Sequence[int],
        a: float = 0.1,
        b: float = 0.5):
    """Create one learnable feature grid per ``grid_nd``-sized subset of input axes.

    Axis subsets are enumerated with ``itertools.combinations(range(in_dim),
    grid_nd)``. When ``in_dim == 5`` the last subset is discarded (for
    ``grid_nd == 2`` that is the ``(3, 4)`` plane) and exactly nine grids are
    expected to remain.

    Args:
        grid_nd: number of axes spanned by each grid.
        in_dim: number of input coordinate axes; must equal ``len(reso)``.
        out_dim: feature channels stored in every grid.
        reso: resolution along each input axis.
        a: lower bound of the uniform initialisation.
        b: upper bound of the uniform initialisation.

    Returns:
        ``nn.ParameterList`` of tensors shaped ``[1, out_dim, *reso_of_axes]``,
        with the axis resolutions listed in reverse axis order.
    """
    assert in_dim == len(reso), "Resolution must have same number of elements as input-dimension"
    assert grid_nd <= in_dim

    five_axis_input = in_dim == 5
    axis_subsets = list(itertools.combinations(range(in_dim), grid_nd))
    if five_axis_input:
        del axis_subsets[-1:]

    # Axes 3 and 4 are treated as time axes once the input has at least 4 dims.
    time_axes = {3, 4}
    with_time = in_dim >= 4

    planes = nn.ParameterList()
    for axes in axis_subsets:
        shape = [1, out_dim]
        shape.extend(reso[axis] for axis in reversed(axes))
        plane = nn.Parameter(torch.empty(shape))

        if with_time and not time_axes.isdisjoint(axes):
            # Time planes are initialised to 1 (original author's note: to
            # amplify the response of the dynamic dimensions).
            nn.init.ones_(plane)
        else:
            nn.init.uniform_(plane, a=a, b=b)

        planes.append(plane)

    if five_axis_input:
        assert len(planes) == 9

    return planes


def interpolate_ms_features(pts: torch.Tensor,
                            ms_grids: Collection[Iterable[nn.Module]],
                            grid_dimensions: int,
                            concat_features: bool,
                            num_levels: Optional[int],
                            ) -> torch.Tensor:
    """Query every plane of every scale and fuse the sampled features.

    Within one scale, the features sampled from all planes are multiplied
    element-wise. Across scales, the per-scale results are either concatenated
    along the last dimension or summed.

    Args:
        pts: coordinates whose last dimension indexes the input axes
            (presumably ``[n, in_dim]`` -- TODO confirm against callers).
        ms_grids: one plane collection per scale; plane ``i`` of a scale is
            paired with the ``i``-th axis combination of
            ``itertools.combinations(range(pts.shape[-1]), grid_dimensions)``.
        grid_dimensions: number of axes per plane.
        concat_features: concatenate the scales if True, otherwise sum them.
        num_levels: number of leading scales to use; ``None`` means all.

    Returns:
        The fused features, ``[n, feature_dim * levels]`` when concatenating
        or ``[n, feature_dim]`` when summing.
    """
    axis_subsets = list(itertools.combinations(
        range(pts.shape[-1]), grid_dimensions)
    )
    if num_levels is None:
        num_levels = len(ms_grids)

    summed = 0.
    collected = []
    planes: nn.ParameterList
    for planes in ms_grids[:num_levels]:
        # Running element-wise product over the planes of this scale.
        level_features = 1.
        for plane_idx, plane in enumerate(planes):
            axes = axis_subsets[plane_idx]
            channels = plane.shape[1]  # plane shape: 1, out_dim, *reso
            sampled = grid_sample_wrapper(plane, pts[..., axes])
            level_features = level_features * sampled.view(-1, channels)

        # Merge this scale into the multi-scale result.
        if concat_features:
            collected.append(level_features)
        else:
            summed = summed + level_features

    if concat_features:
        return torch.cat(collected, dim=-1)
    return summed

#! Concatenate t1t2, instead of Product
def interpolate_ms_features_v2(pts: torch.Tensor,
                            ms_grids: Collection[Iterable[nn.Module]],
                            grid_dimensions: int,
                            concat_features: bool,
                            num_levels: Optional[int],
                            ) -> torch.Tensor:
    """Multi-scale plane lookup that keeps two time axes as separate branches.

    Instead of multiplying all planes into a single product, two products are
    built per scale and concatenated along the feature dimension:

    * branch t1 = planes 0, 1, 4 multiplied with planes 2, 5, 7
    * branch t2 = planes 0, 1, 4 multiplied with planes 3, 6, 8

    The plane indices are hard-coded. They match the ordering of
    ``itertools.combinations(range(5), 2)`` (0/1/4: pairs of axes 0-2,
    2/5/7: pairs with axis 3, 3/6/8: pairs with axis 4), so this presumably
    expects 5-D input with 2-D planes -- TODO confirm. Planes with any other
    index are sampled but do not contribute.

    Args:
        pts: coordinates whose last dimension indexes the input axes.
        ms_grids: one plane collection per scale.
        grid_dimensions: number of axes per plane.
        concat_features: concatenate the scales if True, otherwise sum them.
        num_levels: number of leading scales to use; ``None`` means all.

    Returns:
        The fused features; each scale contributes ``2 * feature_dim`` channels.
    """
    axis_subsets = list(itertools.combinations(
        range(pts.shape[-1]), grid_dimensions)
    )
    if num_levels is None:
        num_levels = len(ms_grids)

    # Plane-index groups (loop invariant).
    shared_planes = (0, 1, 4)
    t1_planes = (2, 5, 7)
    t2_planes = (3, 6, 8)

    summed = 0.
    collected = []
    planes: nn.ParameterList
    for planes in ms_grids[:num_levels]:
        product_t1 = 1.
        product_t2 = 1.
        for plane_idx, plane in enumerate(planes):
            axes = axis_subsets[plane_idx]
            channels = plane.shape[1]  # plane shape: 1, out_dim, *reso
            sampled = grid_sample_wrapper(plane, pts[..., axes]).view(-1, channels)

            # Shared planes feed both branches, time planes only their own.
            is_shared = plane_idx in shared_planes
            if is_shared or plane_idx in t1_planes:
                product_t1 = product_t1 * sampled
            if is_shared or plane_idx in t2_planes:
                product_t2 = product_t2 * sampled

        # Join the two time branches along the feature dimension.
        level_features = torch.cat([product_t1, product_t2], dim=-1)

        # Merge this scale into the multi-scale result.
        if concat_features:
            collected.append(level_features)
        else:
            summed = summed + level_features

    if concat_features:
        return torch.cat(collected, dim=-1)
    return summed


class HexPlaneField(nn.Module):
    """Multi-resolution, plane-factorised feature field.

    Holds one ``nn.ParameterList`` of feature planes per resolution multiplier
    and turns (position, timestamp) queries into feature vectors by sampling
    those planes.

    Attributes set in ``__init__``:
        aabb: frozen ``[2, 3]`` parameter; row 0 is ``+bounds`` and row 1 is
            ``-bounds``. NOTE(review): ``normalize_aabb`` maps row 0 to -1 and
            row 1 to +1, so normalized coordinates are sign-flipped relative
            to world coordinates -- confirm this is intended.
        grids: ``nn.ModuleList`` with one plane list per resolution multiplier.
        feat_dim: width of the feature vector returned by ``forward``.
    """

    def __init__(
        self,
        bounds,
        planeconfig,
        multires,
        concat_time_axis = False
    ) -> None:
        """Build the plane grids.

        Args:
            bounds: scalar half-extent of the initial, origin-centred box.
            planeconfig: dict with the keys ``"grid_dimensions"``,
                ``"input_coordinate_dim"``, ``"output_coordinate_dim"`` and
                ``"resolution"`` (one entry per input axis).
            multires: iterable of resolution multipliers, one per scale. Only
                the first three (spatial) resolutions are multiplied.
            concat_time_axis: if True, features are produced by
                ``interpolate_ms_features_v2`` and ``feat_dim`` is doubled.
        """
        super().__init__()
        aabb = torch.tensor([[bounds,bounds,bounds],
                             [-bounds,-bounds,-bounds]])
        self.aabb = nn.Parameter(aabb, requires_grad=False)
        self.grid_config =  [planeconfig]
        self.multiscale_res_multipliers = multires
        self.concat_features = True
        self.concat_time_axis = concat_time_axis

        # 1. Init planes
        self.grids = nn.ModuleList()
        self.feat_dim = 0
        for res in self.multiscale_res_multipliers:
            # initialize coordinate grid
            config = self.grid_config[0].copy()
            # Resolution fix: multi-res only on spatial planes
            config["resolution"] = [
                r * res for r in config["resolution"][:3]
            ] + config["resolution"][3:]
            gp = init_grid_param(
                grid_nd=config["grid_dimensions"],
                in_dim=config["input_coordinate_dim"],
                out_dim=config["output_coordinate_dim"],
                reso=config["resolution"],
            )
            # shape[1] is out-dim - Concatenate over feature len for each scale
            if self.concat_features:
                self.feat_dim += gp[-1].shape[1]
            else:
                self.feat_dim = gp[-1].shape[1]

            self.grids.append(gp)
        #! Concatenate t1t2, instead of Product
        if self.concat_time_axis:
            self.feat_dim *= 2

    @property
    def get_aabb(self):
        """Return the two rows of the bounding box as ``(aabb[0], aabb[1])``."""
        return self.aabb[0], self.aabb[1]

    def set_aabb(self,xyz_max, xyz_min, margin_ratio = 0.1):
        """Replace the bounding box with one fitted to the given extents.

        The box keeps the centre of ``[xyz_min, xyz_max]`` and enlarges each
        half-extent by ``margin_ratio``. Row 0 of the stored box is the upper
        corner, row 1 the lower corner.

        Args:
            xyz_max: per-axis maximum (array-like supporting arithmetic, e.g.
                a NumPy array or a tensor).
            xyz_min: per-axis minimum, same type as ``xyz_max``.
            margin_ratio: fraction of the half-extent added on each side.
        """
        center = (xyz_min + xyz_max) / 2
        half_range = (xyz_max - xyz_min) / 2
        margin = half_range * margin_ratio
        new_half_range = half_range + margin

        # Prefer CUDA as before, but stay on the current device when no GPU is
        # available instead of crashing inside .cuda().
        if torch.cuda.is_available():
            target_device = torch.device("cuda")
        else:
            target_device = self.aabb.device

        # as_tensor + stack accepts NumPy arrays as well as tensors, whereas
        # torch.tensor([...]) rejects a list of multi-element tensors.
        aabb = torch.stack([
            torch.as_tensor(center + new_half_range, dtype=torch.float32),
            torch.as_tensor(center - new_half_range, dtype=torch.float32),
        ]).detach().to(target_device)

        self.aabb = nn.Parameter(aabb,requires_grad=False)
        print(f"Voxel Plane: set aabb={self.aabb}, margin={margin}")

    def get_density(self, pts: torch.Tensor, timestamps: Optional[torch.Tensor] = None):
        """Look up the plane features for a batch of points.

        Points are normalized with the bounding box, clamped just inside
        [-1, 1], concatenated with ``timestamps`` and sampled from all planes.
        Several sanity checks emit warnings (out-of-range or non-finite
        coordinates, NaN in the first scale's planes); they never raise.

        Args:
            pts: ``[n, 3]`` world-space positions.
            timestamps: ``[n, k]`` time coordinates, expected in [-1, 1].
                If ``None``, no time coordinates are appended and the time
                checks are skipped.

        Returns:
            ``[n, feat_dim]`` interpolated features, or an empty ``[0, 1]``
            tensor when no features were produced.
        """
        pts = normalize_aabb(pts, self.aabb) # pts -> [-1,1]

        # Check spatial coordinates against [-1, 1].
        over = (torch.abs(pts) > 1.0).any(dim=1).sum().item()
        total = pts.shape[0]
        # Guard total == 0 so an empty batch does not divide by zero.
        if total > 0 and over / total > 0.1:
            logger.warning(f"invalid points {over}/{total}, min_pts={pts.min().item()} max_pts={pts.max().item()}" )
        if torch.isnan(pts).any():
            logger.warning(f"pts {pts.shape} has {torch.isnan(pts).any(dim=1).sum()} nan, min_pts={pts.min().item()} max_pts={pts.max().item()}")
        if torch.isinf(pts).any():
            logger.warning(f"pts {pts.shape} has {torch.isinf(pts).any(dim=1).sum()} inf, min_pts={pts.min().item()} max_pts={pts.max().item()}")

        if timestamps is not None:
            # Check time coordinates against [-1, 1].
            over = (torch.abs(timestamps) > 1.0).any(dim=1).sum().item()
            total = timestamps.shape[0]
            if total > 0 and over / total > 0.1:
                logger.warning(f"invalid timestamps {over}/{total}")
            if torch.isnan(timestamps).any():
                logger.warning(f"timestamps {timestamps.shape} has {torch.isnan(timestamps).any(dim=1).sum()} nan, min_pts={timestamps.min().item()} max_pts={timestamps.max().item()}")
            if torch.isinf(timestamps).any():
                logger.warning(f"timestamps {timestamps.shape} has {torch.isinf(timestamps).any(dim=1).sum()} inf, min_pts={timestamps.min().item()} max_pts={timestamps.max().item()}")

        # Check the planes of the first scale for NaN values.
        for i, p in enumerate(self.grids[0]):
            if torch.isnan(p).any():
                logger.warning(f"grid plane #{i} has {torch.isnan(p).sum()} nan")

        # Keep samples strictly inside the grid; timestamps are not clamped.
        pts = torch.clamp(pts, -1.0 + 1e-6, 1.0 - 1e-6)
        if timestamps is not None:
            pts = torch.cat((pts, timestamps), dim=-1)  # [n, 3 + k]

        pts = pts.reshape(-1, pts.shape[-1])
        if not self.concat_time_axis:
            features = interpolate_ms_features(
                pts, ms_grids=self.grids,  # noqa
                grid_dimensions=self.grid_config[0]["grid_dimensions"],
                concat_features=self.concat_features, num_levels=None)
        else:
            features = interpolate_ms_features_v2(
                pts, ms_grids=self.grids,  # noqa
                grid_dimensions=self.grid_config[0]["grid_dimensions"],
                concat_features=self.concat_features, num_levels=None)

        if len(features) < 1:
            features = torch.zeros((0, 1)).to(features.device)

        if torch.isnan(features).any():
            logger.warning(f"interpolate_ms_features {features.shape} has {torch.isnan(features).sum()} nan")

        return features

    def forward(self,
                pts: torch.Tensor,
                timestamps: Optional[torch.Tensor] = None):
        """Return the plane features for ``pts`` at ``timestamps`` (see ``get_density``)."""
        features = self.get_density(pts, timestamps)

        return features
