from functools import partial
import numpy as np
import torch
import torch.nn as nn
from typing import Set
import spconv
from torch.cuda.amp import autocast

# spconv >= 2.2: disable the SPCONV_USE_DIRECT_TABLE option.
# Compare (major, minor) as integers: slicing off the first two characters and
# parsing the rest as a float mis-reads e.g. "2.2.1" as 2.1 and "1.3.0" as 3.0.
_SPCONV_MAJOR_MINOR = tuple(int(part) for part in spconv.__version__.split(".")[:2])
if _SPCONV_MAJOR_MINOR >= (2, 2):
    spconv.constants.SPCONV_USE_DIRECT_TABLE = False

try:
    import spconv.pytorch as spconv
except ImportError:
    # Fallback for spconv layouts that have no `spconv.pytorch` subpackage.
    import spconv as spconv

try:
    import torch_scatter
except Exception:
    # In case someone doesn't want to use the dynamic voxel VFE and hasn't
    # installed torch_scatter. The broad catch is deliberate best-effort
    # (presumably so a broken binary build is tolerated too); code that calls
    # torch_scatter will fail with NameError if this import did not succeed.
    pass

from .vis_utils import vis_bev_map

# from navsim.agents.transfuser.vis_utils import vis_tensor, vis_bev_map


class PointPreprocess(nn.Module):
    """Turn a batched point cloud into a dense bird's-eye-view (BEV) feature map.

    Points are encoded per voxel by ``DynamicVoxelVFE``, scattered into a dense
    channels-first ``(B, C, Z, Y, X)`` grid via ``spconv.SparseConvTensor.dense()``,
    and the vertical (Z) axis is then folded into the channel axis.
    """

    def __init__(self, model_cfg, **kwargs):
        super().__init__()
        pc_range = np.array(model_cfg.point_cloud_range)
        voxel_size = np.array(model_cfg.voxel_size)
        # Number of voxels along x, y, z. The float division can land a hair off
        # an integer (e.g. 1023.9999999 instead of 1024); round so the encoder and
        # the sparse tensor both see an exact integer grid, not a fractional one.
        grid_size = np.round((pc_range[3:6] - pc_range[0:3]) / voxel_size).astype(
            np.int64
        )
        self.model_cfg = model_cfg
        self.point_preprocessor = DynamicVoxelVFE(model_cfg, grid_size)
        # Grid is ordered (X, Y, Z); the sparse tensor's spatial shape is (Z, Y, X).
        self.sparse_shape = grid_size[::-1]

    def forward(self, points, batch_size, bev_semantic_label):
        """Voxelize ``points`` and return a dense BEV feature map.

        Args:
            points: point tensor with ``[batch_idx, x, y, z, *extra]`` columns, as
                consumed by ``DynamicVoxelVFE``.
            batch_size: number of samples in the batch.
            bev_semantic_label: not used by this method; kept for interface
                compatibility with existing callers.

        Returns:
            ``(B, C * Z, Y, X)`` tensor: the densified voxel grid with its
            vertical axis folded into channels.
        """
        voxel_features, voxel_coords = self.point_preprocessor(points)

        dense_voxels = spconv.SparseConvTensor(
            features=voxel_features,
            indices=voxel_coords.int(),
            spatial_shape=self.sparse_shape,
            batch_size=batch_size,
        ).dense()
        # Channels-first dense grid (B, C, Z, Y, X): fold Z into channels.
        batch, channels, depth, height, width = dense_voxels.size()
        return dense_voxels.view(batch, channels * depth, height, width)


class PFNLayerV2(nn.Module):
    """Point-wise Linear -> optional BatchNorm1d -> ReLU, then a max over each voxel.

    Non-final layers emit ``out_channels // 2`` channels per point and append the
    voxel-wise max (broadcast back to every point of that voxel), giving
    ``2 * (out_channels // 2)`` channels per point. The final layer
    (``last_layer=True``) returns only the pooled per-voxel feature with
    ``out_channels`` channels.
    """

    def __init__(self, in_channels, out_channels, use_norm=True, last_layer=False):
        super().__init__()

        self.last_vfe = last_layer
        self.use_norm = use_norm
        # Half the width for intermediate layers: the other half is supplied by
        # the concatenated voxel max in `forward`.
        width = out_channels if self.last_vfe else out_channels // 2

        # A bias is redundant when BatchNorm follows, so it is only used without norm.
        self.linear = nn.Linear(in_channels, width, bias=not self.use_norm)
        if self.use_norm:
            self.norm = nn.BatchNorm1d(width, eps=1e-3, momentum=0.01)

        self.relu = nn.ReLU()

    def forward(self, inputs, unq_inv):
        """Apply the layer.

        Args:
            inputs: per-point features, one row per point.
            unq_inv: for each point, the index of the voxel it belongs to.

        Returns:
            Per-voxel max features if this is the last layer, otherwise per-point
            features concatenated with their voxel's max.
        """
        hidden = self.linear(inputs)
        if self.use_norm:
            hidden = self.norm(hidden)
        hidden = self.relu(hidden)

        pooled, _ = torch_scatter.scatter_max(hidden, unq_inv, dim=0)
        if not self.last_vfe:
            return torch.cat([hidden, pooled[unq_inv, :]], dim=1)
        return pooled


class DynamicVoxelVFE(nn.Module):
    """Dynamic voxel feature encoder: one learned feature vector per occupied voxel.

    Every in-range point is assigned to a voxel on the fly (no per-voxel point
    cap). Each point's features are augmented with its offset from the mean of
    the points in its voxel (``f_cluster``) and from its voxel's center
    (``f_center``), optionally its distance to the origin, and then passed
    through a stack of ``PFNLayerV2`` layers; the last one max-pools the points
    of each voxel into a single feature.

    Args:
        model_cfg: config exposing ``num_point_features``, ``num_filters``,
            ``voxel_size`` (x, y, z) and ``point_cloud_range`` (mins at indices
            0-2). ``num_point_features`` presumably counts xyz plus the extra
            per-point columns, excluding the batch index -- confirm with callers.
        grid_size: number of voxels along x, y, z; rounded to integers here.
        use_norm: use BatchNorm in the PFN layers.
        with_distance: append each point's Euclidean distance to the origin.
        use_absolute_xyz: keep the raw xyz columns as features.
    """

    def __init__(
        self,
        model_cfg,
        grid_size,
        use_norm=True,
        with_distance=False,
        use_absolute_xyz=False,
        **kwargs,
    ):
        super().__init__()
        self.model_cfg = model_cfg
        self.use_norm = use_norm
        self.with_distance = with_distance
        self.use_absolute_xyz = use_absolute_xyz
        num_point_features = self.model_cfg.num_point_features
        voxel_size = self.model_cfg.voxel_size
        point_cloud_range = self.model_cfg.point_cloud_range

        # f_cluster and f_center add 3 channels each; without absolute xyz the raw
        # xyz columns are dropped from the features, hence +3 instead of +6.
        num_point_features += 6 if self.use_absolute_xyz else 3
        if self.with_distance:
            num_point_features += 1

        self.num_filters = self.model_cfg.num_filters
        assert len(self.num_filters) > 0
        num_filters = [num_point_features] + list(self.num_filters)

        pfn_layers = []
        for i in range(len(num_filters) - 1):
            pfn_layers.append(
                PFNLayerV2(
                    num_filters[i],
                    num_filters[i + 1],
                    self.use_norm,
                    last_layer=(i >= len(num_filters) - 2),
                )
            )
        self.pfn_layers = nn.ModuleList(pfn_layers)

        self.voxel_x = voxel_size[0]
        self.voxel_y = voxel_size[1]
        self.voxel_z = voxel_size[2]
        # Center of voxel 0 along each axis.
        self.x_offset = self.voxel_x / 2 + point_cloud_range[0]
        self.y_offset = self.voxel_y / 2 + point_cloud_range[1]
        self.z_offset = self.voxel_z / 2 + point_cloud_range[2]

        # Integer strides for packing (batch, x, y, z) into one linear voxel key.
        grid_size = self._round_grid_size(grid_size)
        self.scale_xyz = int(grid_size[0] * grid_size[1] * grid_size[2])
        self.scale_yz = int(grid_size[1] * grid_size[2])
        self.scale_z = int(grid_size[2])

        self.grid_size = torch.tensor(grid_size)
        self.voxel_size = torch.tensor(voxel_size)
        self.point_cloud_range = torch.tensor(point_cloud_range)

    @staticmethod
    def _round_grid_size(grid_size):
        """Return ``grid_size`` as an int64 numpy array rounded to the nearest integer.

        ``(range_max - range_min) / voxel_size`` can land slightly off an integer in
        floating point; left unrounded, the key strides would be fractional.
        """
        return np.round(np.asarray(grid_size, dtype=np.float64)).astype(np.int64)

    def get_output_feature_dim(self):
        return self.num_filters[-1]

    def _in_range_points(self, points):
        """Return the points inside the grid and their int64 (x, y, z) voxel indices."""
        points_coords = torch.floor(
            (points[:, [1, 2, 3]] - self.point_cloud_range[[0, 1, 2]])
            / self.voxel_size[[0, 1, 2]]
        ).long()
        mask = ((points_coords >= 0) & (points_coords < self.grid_size[[0, 1, 2]])).all(
            dim=1
        )
        return points[mask], points_coords[mask]

    def _encode_voxel_keys(self, batch_idx, points_coords):
        """Pack ``(batch, x, y, z)`` into one int64 key per point.

        int64 keeps ``batch * scale_xyz`` from overflowing for large grids/batches.
        """
        return (
            batch_idx.long() * self.scale_xyz
            + points_coords[:, 0] * self.scale_yz
            + points_coords[:, 1] * self.scale_z
            + points_coords[:, 2]
        )

    def _decode_voxel_keys(self, keys):
        """Invert ``_encode_voxel_keys``; returns int32 ``(batch, z, y, x)`` rows."""
        voxel_coords = torch.stack(
            (
                keys // self.scale_xyz,
                (keys % self.scale_xyz) // self.scale_yz,
                (keys % self.scale_yz) // self.scale_z,
                keys % self.scale_z,
            ),
            dim=1,
        )
        # Keys decode to (batch, x, y, z); reorder to (batch, z, y, x).
        return voxel_coords[:, [0, 3, 2, 1]].int()

    def forward(self, points, **kwargs):
        """Encode ``points`` into per-voxel features.

        Args:
            points: ``(N, 4 + E)`` tensor with columns ``[batch_idx, x, y, z, *extra]``.
                Points outside the grid are dropped.

        Returns:
            ``(features, voxel_coords)``: ``features`` has one row per occupied
            voxel with ``num_filters[-1]`` channels; ``voxel_coords`` is an int32
            ``(M, 4)`` tensor of ``(batch_idx, z, y, x)`` indices in the same row
            order (ascending voxel key).
        """
        # Keep the cached range/size tensors on the input's device and dtype.
        self.grid_size = self.grid_size.type_as(points)
        self.voxel_size = self.voxel_size.type_as(points)
        self.point_cloud_range = self.point_cloud_range.type_as(points)

        points, points_coords = self._in_range_points(points)
        points_xyz = points[:, [1, 2, 3]].contiguous()

        keys = self._encode_voxel_keys(points[:, 0], points_coords)
        # unq_inv maps each point to the row of its voxel in unq_coords.
        unq_coords, unq_inv = torch.unique(keys, return_inverse=True)

        points_mean = torch_scatter.scatter_mean(points_xyz, unq_inv, dim=0)
        f_cluster = points_xyz - points_mean[unq_inv, :]

        # Offset of each point from the center of the voxel it falls in.
        voxel_step = points_xyz.new_tensor([self.voxel_x, self.voxel_y, self.voxel_z])
        voxel_origin = points_xyz.new_tensor(
            [self.x_offset, self.y_offset, self.z_offset]
        )
        f_center = points_xyz - (
            points_coords.to(points_xyz.dtype) * voxel_step + voxel_origin
        )

        if self.use_absolute_xyz:
            features = [points[:, 1:], f_cluster, f_center]
        else:
            features = [points[:, 4:], f_cluster, f_center]
        if self.with_distance:
            features.append(torch.norm(points[:, 1:4], 2, dim=1, keepdim=True))
        features = torch.cat(features, dim=-1)

        for pfn in self.pfn_layers:
            features = pfn(features, unq_inv)

        return features, self._decode_voxel_keys(unq_coords)


class DynamicVoxelVFE_sim(nn.Module):
    """Parameter-free dynamic voxelizer whose only feature is voxel occupancy.

    Each occupied voxel gets a single feature: its point count clamped to
    ``max_hit`` and divided by ``max_hit``, so it saturates at 1.0.

    Args:
        model_cfg: config exposing ``voxel_size`` (x, y, z) and
            ``point_cloud_range`` (mins at indices 0-2).
        grid_size: number of voxels along x, y, z; rounded to integers here.
        max_hit: point count at which the occupancy feature reaches 1.0.
    """

    def __init__(
        self,
        model_cfg,
        grid_size,
        max_hit=5,
        **kwargs,
    ):
        super().__init__()
        self.model_cfg = model_cfg
        self.max_hit = max_hit
        voxel_size = self.model_cfg.voxel_size
        point_cloud_range = self.model_cfg.point_cloud_range

        self.voxel_x = voxel_size[0]
        self.voxel_y = voxel_size[1]
        self.voxel_z = voxel_size[2]
        self.x_offset = self.voxel_x / 2 + point_cloud_range[0]
        self.y_offset = self.voxel_y / 2 + point_cloud_range[1]
        self.z_offset = self.voxel_z / 2 + point_cloud_range[2]

        # (range_max - range_min) / voxel_size can land slightly off an integer;
        # round so the key strides below are exact integers.
        grid_size = np.round(np.asarray(grid_size, dtype=np.float64)).astype(np.int64)
        self.scale_xyz = int(grid_size[0] * grid_size[1] * grid_size[2])
        self.scale_yz = int(grid_size[1] * grid_size[2])
        self.scale_z = int(grid_size[2])

        self.grid_size = torch.tensor(grid_size)
        self.voxel_size = torch.tensor(voxel_size)
        self.point_cloud_range = torch.tensor(point_cloud_range)

    def forward(self, points, **kwargs):
        """Voxelize ``points`` and return per-voxel occupancy.

        Args:
            points: tensor whose columns start with ``[batch_idx, x, y, z]``;
                points outside the grid are dropped.

        Returns:
            ``(features, voxel_coords)``: ``features`` is a float ``(M, 1)`` tensor
            of ``min(count, max_hit) / max_hit``; ``voxel_coords`` is an int32
            ``(M, 4)`` tensor of ``(batch_idx, z, y, x)`` in ascending key order.
        """
        # Keep the cached range/size tensors on the input's device and dtype.
        self.grid_size = self.grid_size.type_as(points)
        self.voxel_size = self.voxel_size.type_as(points)
        self.point_cloud_range = self.point_cloud_range.type_as(points)

        points_coords = torch.floor(
            (points[:, [1, 2, 3]] - self.point_cloud_range[[0, 1, 2]])
            / self.voxel_size[[0, 1, 2]]
        ).long()
        mask = ((points_coords >= 0) & (points_coords < self.grid_size[[0, 1, 2]])).all(
            dim=1
        )
        points = points[mask]
        points_coords = points_coords[mask]

        # One int64 key per point; int64 avoids overflow of batch * scale_xyz.
        merge_coords = (
            points[:, 0].long() * self.scale_xyz
            + points_coords[:, 0] * self.scale_yz
            + points_coords[:, 1] * self.scale_z
            + points_coords[:, 2]
        )

        unq_coords, unq_cnt = torch.unique(merge_coords, return_counts=True)

        voxel_coords = torch.stack(
            (
                unq_coords // self.scale_xyz,
                (unq_coords % self.scale_xyz) // self.scale_yz,
                (unq_coords % self.scale_yz) // self.scale_z,
                unq_coords % self.scale_z,
            ),
            dim=1,
        )
        # Keys decode to (batch, x, y, z); reorder to (batch, z, y, x).
        voxel_coords = voxel_coords[:, [0, 3, 2, 1]].int()

        features = unq_cnt.float().clamp(max=self.max_hit) / self.max_hit
        return features.unsqueeze(1), voxel_coords
