from functools import partial
import numpy as np
import torch
import torch.nn as nn
from typing import Set
import spconv
from torch.cuda.amp import autocast

def _spconv_major_minor(version):
    """Parse the leading ``major.minor`` of a version string into an int tuple.

    ``"2.3.6"`` -> ``(2, 3)``; a missing minor component counts as 0 and an
    unparseable string yields ``(0, 0)``.
    """
    import re

    match = re.match(r"\s*(\d+)(?:\.(\d+))?", str(version))
    if match is None:
        return (0, 0)
    return int(match.group(1)), int(match.group(2) or 0)


# For spconv >= 2.2, turn off SPCONV_USE_DIRECT_TABLE. The version is compared
# as a (major, minor) tuple: slicing off the first two characters and parsing
# the rest as a float dropped the major version and mis-read e.g. "2.2.0" as 2.0.
if _spconv_major_minor(getattr(spconv, "__version__", "0.0")) >= (2, 2):
    spconv.constants.SPCONV_USE_DIRECT_TABLE = False

try:
    # spconv 2.x layout: the PyTorch API lives under ``spconv.pytorch``.
    import spconv.pytorch as spconv
except ImportError:
    # Fall back to the top-level package (older spconv layout).
    import spconv as spconv

try:
    import torch_scatter
except Exception:
    # Best-effort: torch_scatter is only needed by the dynamic voxel VFE
    # (DynamicVoxelVFE / PFNLayerV2), so in case someone does not use it and
    # has not installed torch_scatter, the module still imports.
    pass

# from navsim.agents.transfuser.vis_utils import vis_tensor, vis_bev_map


def find_all_spconv_keys(model: nn.Module, prefix="") -> Set[str]:
    """
    Collect the state-dict keys of every spconv ``SparseConvolution`` weight
    below ``model`` (these are the weights that need transposing).

    Args:
        model: module whose children are searched recursively.
        prefix: dotted path of ``model`` inside the root module ("" at the root).

    Returns:
        Set of keys of the form ``"<dotted.path>.weight"``.
    """
    keys: Set[str] = set()
    for child_name, submodule in model.named_children():
        qualified = child_name if prefix == "" else f"{prefix}.{child_name}"

        if isinstance(submodule, spconv.conv.SparseConvolution):
            qualified += ".weight"
            keys.add(qualified)

        # Recurse with the (possibly ".weight"-suffixed) path, as before.
        keys |= find_all_spconv_keys(submodule, prefix=qualified)

    return keys


def replace_feature(out, new_features):
    """Return a sparse tensor equal to ``out`` but carrying ``new_features``.

    If ``out`` exposes a ``replace_feature`` method (spconv 2.x), its result is
    returned. Otherwise ``out.features`` is overwritten in place and ``out``
    itself is returned.
    """
    if "replace_feature" not in out.__dir__():
        # Older spconv: mutate in place.
        out.features = new_features
        return out
    # spconv 2.x behaviour
    return out.replace_feature(new_features)


def post_act_block(
    in_channels,
    out_channels,
    kernel_size,
    indice_key=None,
    stride=1,
    padding=0,
    conv_type="subm",
    norm_fn=None,
):
    """Build a sparse ``conv -> norm -> ReLU`` sequence.

    Args:
        in_channels: input feature width.
        out_channels: output feature width (also passed to ``norm_fn``).
        kernel_size: convolution kernel size.
        indice_key: spconv index key shared between convs.
        stride: used only by ``conv_type="spconv"``.
        padding: used only by ``conv_type="spconv"``.
        conv_type: ``"subm"`` (SubMConv3d), ``"spconv"`` (SparseConv3d) or
            ``"inverseconv"`` (SparseInverseConv3d).
        norm_fn: factory called as ``norm_fn(out_channels)``; must be given.

    Returns:
        ``spconv.SparseSequential(conv, norm, ReLU)``.

    Raises:
        NotImplementedError: for an unknown ``conv_type``.
    """
    # Lazily-evaluated constructors so only the selected conv is built.
    builders = {
        "subm": lambda: spconv.SubMConv3d(
            in_channels, out_channels, kernel_size, bias=False, indice_key=indice_key
        ),
        "spconv": lambda: spconv.SparseConv3d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            bias=False,
            indice_key=indice_key,
        ),
        "inverseconv": lambda: spconv.SparseInverseConv3d(
            in_channels, out_channels, kernel_size, indice_key=indice_key, bias=False
        ),
    }
    if conv_type not in builders:
        raise NotImplementedError

    conv = builders[conv_type]()
    return spconv.SparseSequential(conv, norm_fn(out_channels), nn.ReLU())


class SparseBasicBlock(spconv.SparseModule):
    """Residual block built from two 3x3x3 submanifold sparse convolutions.

    Data flow: conv1 -> bn1 -> ReLU -> conv2 -> bn2, then add the input
    (passed through ``downsample`` when given) and apply ReLU.
    """

    expansion = 1

    def __init__(
        self,
        inplanes,
        planes,
        stride=1,
        bias=None,
        norm_fn=None,
        downsample=None,
        indice_key=None,
    ):
        super().__init__()

        assert norm_fn is not None
        # When ``bias`` is not given it resolves to ``norm_fn is not None``,
        # which is always True after the assert above.
        use_bias = (norm_fn is not None) if bias is None else bias

        def subm3x3(c_in):
            # Both convs share geometry, bias setting and indice_key.
            return spconv.SubMConv3d(
                c_in,
                planes,
                kernel_size=3,
                stride=stride,
                padding=1,
                bias=use_bias,
                indice_key=indice_key,
            )

        self.conv1 = subm3x3(inplanes)
        self.bn1 = norm_fn(planes)
        self.relu = nn.ReLU()
        self.conv2 = subm3x3(planes)
        self.bn2 = norm_fn(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        hidden = self.conv1(x)
        hidden = replace_feature(hidden, self.relu(self.bn1(hidden.features)))

        hidden = self.conv2(hidden)
        hidden = replace_feature(hidden, self.bn2(hidden.features))

        shortcut = x if self.downsample is None else self.downsample(x)

        return replace_feature(
            hidden, self.relu(hidden.features + shortcut.features)
        )


class FeatureInfo:
    """Ordered list of per-stage feature metadata dictionaries."""

    def __init__(self):
        self.info = []

    def add_info(self, name, num_chs, reduction, index):
        """Append one stage record.

        Example record: {'module': 'layer2', 'num_chs': 128, 'reduction': 8, 'index': 2}
        """
        record = dict(module=name, num_chs=num_chs, reduction=reduction, index=index)
        self.info.append(record)


class VoxelBackBone8x(nn.Module):
    """Dense 2D conv stages over a BEV grid produced by dynamic voxelization.

    ``preprocess`` voxelizes points with ``DynamicVoxelVFE`` and densifies them
    into a BEV map (depth folded into channels). ``blocks`` holds five Conv2d
    stages named ``layer0`` .. ``layer4``. This class defines no ``forward``.
    """

    def __init__(self, model_cfg, **kwargs):
        super().__init__()
        pc_range = model_cfg.point_cloud_range
        grid_size = (pc_range[3:6] - pc_range[0:3]) / model_cfg.voxel_size
        self.model_cfg = model_cfg
        self.point_preprocessor = DynamicVoxelVFE(model_cfg, grid_size)

        # grid_size is (x, y, z); reversed to (z, y, x) to match the
        # (batch, z, y, x) voxel coordinates returned by the VFE.
        self.sparse_shape = grid_size[::-1]

        layer_nums = [1, 2, 3, 5, 5]
        layer_strides = [1, 1, 2, 2, 2]
        num_filters = [64, 64, 128, 128, 256]
        in_channels = [self.point_preprocessor.num_filters[-1], *num_filters[:-1]]

        self.blocks = nn.ModuleDict()
        self.return_layers = {}
        self.feature_info = FeatureInfo()
        stage_specs = zip(in_channels, num_filters, layer_strides, layer_nums)
        for level, (c_in, c_out, stride, num_extra) in enumerate(stage_specs):
            name = f"layer{level}"
            self.blocks[name] = self._make_stage(c_in, c_out, stride, num_extra)
            self.return_layers[name] = level
            # NOTE(review): ``reduction`` records this stage's own stride, not
            # the cumulative one; confirm that is what consumers expect.
            self.feature_info.add_info(
                name, c_out, reduction=stride, index=level + 1
            )

    @staticmethod
    def _make_stage(c_in, c_out, stride, num_extra_convs):
        """ZeroPad2d + strided 3x3 conv/BN/ReLU, then extra 3x3 conv/BN/ReLU triples."""

        def conv_bn_relu(cin, **conv_kwargs):
            return [
                nn.Conv2d(cin, c_out, kernel_size=3, bias=False, **conv_kwargs),
                nn.BatchNorm2d(c_out, eps=1e-3, momentum=0.01),
                nn.ReLU(),
            ]

        layers = [nn.ZeroPad2d(1), *conv_bn_relu(c_in, stride=stride, padding=0)]
        for _ in range(num_extra_convs):
            layers += conv_bn_relu(c_out, padding=1)
        return nn.Sequential(*layers)

    def preprocess(self, points, batch_size):
        """Voxelize ``points`` and return a dense map of shape (B, C * D, H, W)."""
        voxel_features, voxel_coords = self.point_preprocessor(points)

        dense = spconv.SparseConvTensor(
            features=voxel_features,
            indices=voxel_coords.int(),
            spatial_shape=self.sparse_shape,
            batch_size=batch_size,
        ).dense()
        batch, channels, depth, rows, cols = dense.size()
        return dense.view(batch, channels * depth, rows, cols)


class PointPreprocess(nn.Module):
    """Voxelize raw points with ``DynamicVoxelVFE`` and return a dense BEV map.

    The z axis of the voxel grid is folded into the channel dimension.
    """

    def __init__(self, model_cfg, **kwargs):
        super().__init__()
        pc_range = model_cfg.point_cloud_range
        grid_size = (pc_range[3:6] - pc_range[0:3]) / model_cfg.voxel_size
        self.model_cfg = model_cfg
        self.point_preprocessor = DynamicVoxelVFE(model_cfg, grid_size)
        # grid_size is (x, y, z); reversed to (z, y, x) to match the
        # (batch, z, y, x) voxel coordinates returned by the VFE.
        self.sparse_shape = grid_size[::-1]

    def forward(self, points, batch_size):
        """Return a dense (B, C * D, L, W) feature map for ``points``."""
        feats, coords = self.point_preprocessor(points)

        sparse = spconv.SparseConvTensor(
            features=feats,
            indices=coords.int(),
            spatial_shape=self.sparse_shape,
            batch_size=batch_size,
        )
        grid = sparse.dense()
        n, c, depth, length, width = grid.size()
        return grid.view(n, c * depth, length, width)


class VoxelResBackBone8x(nn.Module):
    """Sparse 3D residual backbone (dynamic-voxel VFE + spconv) producing a BEV map.

    Points are voxelized by ``DynamicVoxelVFE``. Then ``conv_input`` and
    ``conv1`` .. ``conv4`` run (conv2-conv4 each start with a stride-2 sparse
    conv). ``conv_out`` applies a (3, 1, 1) stride-(2, 1, 1) conv along z.
    The remaining depth is folded into channels.
    """

    def __init__(self, model_cfg, **kwargs):
        """
        Args:
            model_cfg: config with ``point_cloud_range``, ``voxel_size`` and the
                fields ``DynamicVoxelVFE`` reads.
            use_bias (kwarg): forwarded as ``bias`` to every SparseBasicBlock.
            last_pad (kwarg): padding of the final z-downsampling conv (default 0).
        """
        super().__init__()
        use_bias = kwargs.get("use_bias", None)
        last_pad = kwargs.get("last_pad", 0)
        norm_fn = partial(nn.BatchNorm1d, eps=1e-3, momentum=0.01)
        grid_size = (
            model_cfg.point_cloud_range[3:6] - model_cfg.point_cloud_range[0:3]
        ) / model_cfg.voxel_size

        self.point_preprocessor = DynamicVoxelVFE(model_cfg, grid_size)

        # (x, y, z) grid reversed to spconv's (z, y, x) order, plus one z slice.
        self.sparse_shape = grid_size[::-1] + [1, 0, 0]

        # Input width follows the VFE's output width (its last num_filters
        # entry) so the first conv always matches the voxel features.
        vfe_channels = self.point_preprocessor.num_filters[-1]
        self.conv_input = spconv.SparseSequential(
            spconv.SubMConv3d(
                vfe_channels, 16, 3, padding=1, bias=False, indice_key="subm1"
            ),
            norm_fn(16),
            nn.ReLU(),
        )

        self.conv1 = spconv.SparseSequential(
            SparseBasicBlock(16, 16, bias=use_bias, norm_fn=norm_fn, indice_key="res1"),
            SparseBasicBlock(16, 16, bias=use_bias, norm_fn=norm_fn, indice_key="res1"),
        )

        # conv2..conv4: stride-2 sparse conv followed by two residual blocks.
        self.conv2 = spconv.SparseSequential(
            post_act_block(
                16,
                32,
                3,
                norm_fn=norm_fn,
                stride=2,
                padding=1,
                indice_key="spconv2",
                conv_type="spconv",
            ),
            SparseBasicBlock(32, 32, bias=use_bias, norm_fn=norm_fn, indice_key="res2"),
            SparseBasicBlock(32, 32, bias=use_bias, norm_fn=norm_fn, indice_key="res2"),
        )

        self.conv3 = spconv.SparseSequential(
            post_act_block(
                32,
                64,
                3,
                norm_fn=norm_fn,
                stride=2,
                padding=1,
                indice_key="spconv3",
                conv_type="spconv",
            ),
            SparseBasicBlock(64, 64, bias=use_bias, norm_fn=norm_fn, indice_key="res3"),
            SparseBasicBlock(64, 64, bias=use_bias, norm_fn=norm_fn, indice_key="res3"),
        )

        self.conv4 = spconv.SparseSequential(
            post_act_block(
                64,
                128,
                3,
                norm_fn=norm_fn,
                stride=2,
                padding=(0, 1, 1),  # no padding along z
                indice_key="spconv4",
                conv_type="spconv",
            ),
            SparseBasicBlock(
                128, 128, bias=use_bias, norm_fn=norm_fn, indice_key="res4"
            ),
            SparseBasicBlock(
                128, 128, bias=use_bias, norm_fn=norm_fn, indice_key="res4"
            ),
        )

        # Downsample along z only: kernel (3, 1, 1), stride (2, 1, 1).
        self.conv_out = spconv.SparseSequential(
            spconv.SparseConv3d(
                128,
                128,
                (3, 1, 1),
                stride=(2, 1, 1),
                padding=last_pad,
                bias=False,
                indice_key="spconv_down2",
            ),
            norm_fn(128),
            nn.ReLU(),
        )
        # NOTE(review): hard-coded; the actual BEV width is 128 * remaining
        # depth of conv_out's output. Confirm it matches the configured grid.
        self.num_bev_features = 256

    def forward(self, points, batch_size):
        """Encode a batch of points into a dense BEV feature map.

        Args:
            points: 2-D tensor whose column 0 is the batch index and columns
                1-3 are x, y, z (the layout DynamicVoxelVFE reads).
            batch_size: number of samples in the batch.

        Returns:
            Tensor of shape (N, C * D, H, W), where (C, D, H, W) are the
            channel / z / y / x sizes of ``conv_out``'s densified output.
        """
        if self.training:
            return self._encode_points(points, batch_size)

        # Eval mode runs under autocast targeting float32 and feeds float32
        # features, presumably to keep the sparse convs in fp32 even inside an
        # outer reduced-precision autocast region (NOTE(review): confirm).
        with autocast(enabled=True, dtype=torch.float32):
            return self._encode_points(points, batch_size, cast_to_float=True)

    def _encode_points(self, points, batch_size, cast_to_float=False):
        """VFE -> sparse conv stack -> dense (N, C * D, H, W) map."""
        voxel_features, voxel_coords = self.point_preprocessor(points)
        if cast_to_float:
            voxel_features = voxel_features.float()

        x = spconv.SparseConvTensor(
            features=voxel_features,
            indices=voxel_coords.int(),
            spatial_shape=self.sparse_shape,
            batch_size=batch_size,
        )
        x = self.conv_input(x)
        for stage in (self.conv1, self.conv2, self.conv3, self.conv4, self.conv_out):
            x = stage(x)

        bev_features = x.dense()
        n, c, d, h, w = bev_features.shape
        return bev_features.view(n, c * d, h, w)


class PFNLayerV2(nn.Module):
    """Point-wise Linear (+BatchNorm) + ReLU followed by a per-voxel max-pool.

    A non-final layer projects each point to ``out_channels // 2`` channels and
    concatenates its voxel's max back onto every point. Its output width is
    therefore ``2 * (out_channels // 2)``. The final layer (``last_layer=True``)
    returns only the per-voxel max, one row per voxel.
    """

    def __init__(self, in_channels, out_channels, use_norm=True, last_layer=False):
        super().__init__()

        self.last_vfe = last_layer
        self.use_norm = use_norm
        width = out_channels if self.last_vfe else out_channels // 2

        # Bias only when no BatchNorm follows the projection.
        self.linear = nn.Linear(in_channels, width, bias=not self.use_norm)
        if self.use_norm:
            self.norm = nn.BatchNorm1d(width, eps=1e-3, momentum=0.01)

        self.relu = nn.ReLU()

    def forward(self, inputs, unq_inv):
        """
        Args:
            inputs: per-point features, one row per point.
            unq_inv: voxel index of each row (the inverse indices from
                ``torch.unique`` in DynamicVoxelVFE).
        """
        feats = self.linear(inputs)
        if self.use_norm:
            feats = self.norm(feats)
        feats = self.relu(feats)

        voxel_max = torch_scatter.scatter_max(feats, unq_inv, dim=0)[0]
        if self.last_vfe:
            return voxel_max
        return torch.cat([feats, voxel_max[unq_inv, :]], dim=1)


class DynamicVoxelVFE(nn.Module):
    """Dynamic voxel feature encoder.

    Each point is assigned to the voxel it falls into; points outside the grid
    are dropped. Per-point features are augmented with offsets to the voxel's
    point mean and to the voxel center. A stack of ``PFNLayerV2`` layers then
    reduces each voxel's points to one feature vector.
    """

    def __init__(
        self,
        model_cfg,
        grid_size,
        use_norm=True,
        with_distance=False,
        use_absolute_xyz=False,
        **kwargs,
    ):
        """
        Args:
            model_cfg: config providing ``num_point_features``, ``voxel_size``,
                ``point_cloud_range`` and ``num_filters``.
            grid_size: number of voxels along (x, y, z). Rounded to integers
                here so the voxel-key arithmetic in ``forward`` is exact.
            use_norm: use BatchNorm inside the PFN layers.
            with_distance: append the L2 norm of each point's xyz as a feature.
            use_absolute_xyz: keep every raw column after the batch index
                (absolute xyz included) instead of only columns 4 onwards.
        """
        super().__init__()
        self.model_cfg = model_cfg
        self.use_norm = use_norm
        self.with_distance = with_distance
        self.use_absolute_xyz = use_absolute_xyz
        num_point_features = self.model_cfg.num_point_features
        voxel_size = self.model_cfg.voxel_size
        point_cloud_range = self.model_cfg.point_cloud_range

        # +3 cluster offsets +3 center offsets; without absolute xyz the three
        # xyz columns are dropped (assumes num_point_features counts every
        # column after the batch index, xyz included -- TODO confirm).
        num_point_features += 6 if self.use_absolute_xyz else 3
        if self.with_distance:
            num_point_features += 1

        self.num_filters = self.model_cfg.num_filters
        assert len(self.num_filters) > 0
        num_filters = [num_point_features] + list(self.num_filters)

        pfn_layers = []
        for i in range(len(num_filters) - 1):
            in_filters = num_filters[i]
            out_filters = num_filters[i + 1]
            pfn_layers.append(
                PFNLayerV2(
                    in_filters,
                    out_filters,
                    self.use_norm,
                    last_layer=(i >= len(num_filters) - 2),
                )
            )
        self.pfn_layers = nn.ModuleList(pfn_layers)

        self.voxel_x = voxel_size[0]
        self.voxel_y = voxel_size[1]
        self.voxel_z = voxel_size[2]
        self.x_offset = self.voxel_x / 2 + point_cloud_range[0]
        self.y_offset = self.voxel_y / 2 + point_cloud_range[1]
        self.z_offset = self.voxel_z / 2 + point_cloud_range[2]

        # Callers compute grid_size as (range extent) / voxel_size, i.e. a float
        # array. Float scales turned the voxel keys into float32 tensors, which
        # merge distinct voxels once keys exceed 2**24. Round to ints so keys are
        # exact int64 values.
        grid_size = np.round(np.asarray(grid_size, dtype=np.float64)).astype(np.int64)

        self.scale_xyz = int(grid_size[0] * grid_size[1] * grid_size[2])
        self.scale_yz = int(grid_size[1] * grid_size[2])
        self.scale_z = int(grid_size[2])

        self.grid_size = torch.tensor(grid_size)
        self.voxel_size = torch.tensor(voxel_size)
        self.point_cloud_range = torch.tensor(point_cloud_range)

    def get_output_feature_dim(self):
        return self.num_filters[-1]

    def _encode_voxel_keys(self, batch_idx, coords):
        """Pack (batch, x, y, z) voxel indices into one exact int64 key per point."""
        return (
            batch_idx.long() * self.scale_xyz
            + coords[:, 0].long() * self.scale_yz
            + coords[:, 1].long() * self.scale_z
            + coords[:, 2].long()
        )

    def _decode_voxel_keys(self, keys):
        """Unpack int64 keys into an int32 (M, 4) tensor ordered (batch, z, y, x)."""
        keys = keys.long()
        voxel_coords = torch.stack(
            (
                keys // self.scale_xyz,
                (keys % self.scale_xyz) // self.scale_yz,
                (keys % self.scale_yz) // self.scale_z,
                keys % self.scale_z,
            ),
            dim=1,
        )
        # (batch, x, y, z) -> (batch, z, y, x)
        return voxel_coords[:, [0, 3, 2, 1]].int()

    def forward(self, points, **kwargs):
        """Voxelize ``points`` and encode one feature vector per occupied voxel.

        Args:
            points: 2-D tensor with columns (batch_idx, x, y, z, extra...).

        Returns:
            features: output of the last PFN layer, one row per occupied voxel.
            voxel_coords: int32 tensor (M, 4) of (batch_idx, z, y, x), row-aligned
                with ``features``.
        """
        self.grid_size = self.grid_size.type_as(points)
        self.voxel_size = self.voxel_size.type_as(points)
        self.point_cloud_range = self.point_cloud_range.type_as(points)

        points_coords = torch.floor(
            (points[:, [1, 2, 3]] - self.point_cloud_range[[0, 1, 2]])
            / self.voxel_size[[0, 1, 2]]
        ).int()
        # Keep only points whose voxel index lies inside the grid.
        mask = ((points_coords >= 0) & (points_coords < self.grid_size[[0, 1, 2]])).all(
            dim=1
        )
        points = points[mask]
        points_coords = points_coords[mask]
        points_xyz = points[:, [1, 2, 3]].contiguous()

        merge_coords = self._encode_voxel_keys(points[:, 0], points_coords)
        # unq_coords is sorted; unq_inv maps every point to its voxel row.
        unq_coords, unq_inv = torch.unique(merge_coords, return_inverse=True)

        points_mean = torch_scatter.scatter_mean(points_xyz, unq_inv, dim=0)
        f_cluster = points_xyz - points_mean[unq_inv, :]

        # Offset of each point from its voxel center.
        f_center = torch.zeros_like(points_xyz)
        f_center[:, 0] = points_xyz[:, 0] - (
            points_coords[:, 0].to(points_xyz.dtype) * self.voxel_x + self.x_offset
        )
        f_center[:, 1] = points_xyz[:, 1] - (
            points_coords[:, 1].to(points_xyz.dtype) * self.voxel_y + self.y_offset
        )
        f_center[:, 2] = points_xyz[:, 2] - (
            points_coords[:, 2].to(points_xyz.dtype) * self.voxel_z + self.z_offset
        )

        if self.use_absolute_xyz:
            features = [points[:, 1:], f_cluster, f_center]
        else:
            features = [points[:, 4:], f_cluster, f_center]

        if self.with_distance:
            points_dist = torch.norm(points[:, 1:4], 2, dim=1, keepdim=True)
            features.append(points_dist)
        features = torch.cat(features, dim=-1)

        for pfn in self.pfn_layers:
            features = pfn(features, unq_inv)

        voxel_coords = self._decode_voxel_keys(unq_coords)
        return features, voxel_coords
