from typing import Set

import spconv
# Only spconv >= 2.2 is expected to expose `constants.SPCONV_USE_DIRECT_TABLE`;
# turn it off there. Compare (major, minor) as integers: the former
# `float(__version__[2:])` parsed "minor.patch" and mis-ordered e.g. 2.2.0 / 2.2.10.
if tuple(int(part) for part in spconv.__version__.split(".")[:2]) >= (2, 2):
    spconv.constants.SPCONV_USE_DIRECT_TABLE = False
    
try:
    import spconv.pytorch as spconv
except:
    import spconv as spconv
import torch
import torch.nn as nn
from opencood.models.sub_modules.sst_ops import scatter_v2
from scipy.sparse.csgraph import connected_components
try:
    from torchex import connected_components as cc_gpu
except ImportError:
    cc_gpu = None


def find_all_spconv_keys(model: nn.Module, prefix="") -> Set[str]:
    """
    Find the state-dict keys of all spconv weights that need to be transposed.

    Walks the module tree recursively via ``named_children`` and records
    ``"<dotted.module.path>.weight"`` for every descendant that is an instance
    of ``spconv.conv.SparseConvolution``. ``model`` itself is not checked,
    only its descendants.

    Args:
        model: module whose descendants are searched.
        prefix: dotted path of ``model`` inside the root module ("" for the root).

    Returns:
        Set of dotted parameter names, e.g. ``"backbone.conv1.weight"``.
    """
    found_keys: Set[str] = set()
    for name, child in model.named_children():
        child_prefix = f"{prefix}.{name}" if prefix != "" else name

        if isinstance(child, spconv.conv.SparseConvolution):
            found_keys.add(f"{child_prefix}.weight")

        # Recurse with the module path (not the ".weight" key) so submodules of a
        # sparse conv, if any, get correct dotted names.
        found_keys.update(find_all_spconv_keys(child, prefix=child_prefix))

    return found_keys

def replace_feature(out, new_features):
    """
    Return ``out`` carrying ``new_features``, across spconv API versions.

    If ``out`` exposes a ``replace_feature`` method (spconv 2.x behaviour), its
    result is returned. Otherwise ``out.features`` is overwritten in place and
    ``out`` itself is returned.
    """
    if "replace_feature" not in dir(out):
        # Older API: mutate the tensor object directly.
        out.features = new_features
        return out
    # spconv 2.x behaviour
    return out.replace_feature(new_features)
    
def voxel2point(voxel_features, voxel_num_points):
    """
    Flatten padded per-voxel point buffers into a flat list of valid points.

    Args:
        voxel_features (torch.Tensor[N_v, Max_p, C]): sampled point features inside
            each voxel (C was documented as 4). Slot ``j`` of voxel ``i`` is valid
            iff ``j < voxel_num_points[i]``.
        voxel_num_points (torch.Tensor[N_v]): valid point count per voxel (<= Max_p).

    Returns:
        point_feats (torch.Tensor[N_p, C]): features of the valid points in
            voxel-major order, where N_p is the total number of valid slots.
        voxel2point_inds (torch.Tensor[N_p], int64): index of the voxel each point
            belongs to, on the same device as ``voxel_features``.
    """
    voxel_num, max_points_num = voxel_features.shape[:2]
    device = voxel_features.device

    # valid_2d[i, j] is True iff slot j of voxel i holds a real (non-padding) point.
    slot_ids = torch.arange(max_points_num, device=device)
    valid_2d = slot_ids.unsqueeze(0) < voxel_num_points.to(device).unsqueeze(1)

    # reshape (not view) so non-contiguous inputs also work.
    point_feats = voxel_features.reshape(-1, voxel_features.shape[-1])[valid_2d.reshape(-1)]

    # nonzero() enumerates True entries in row-major order, i.e. the same order as
    # point_feats; its row coordinate is the owning voxel index.
    voxel2point_inds = valid_2d.nonzero(as_tuple=True)[0]

    return point_feats, voxel2point_inds

def filter_almost_empty(coors, min_points):
    """
    Mask rows whose exact coordinate row occurs at least ``min_points`` times.

    Args:
        coors (torch.Tensor[N, D]): one coordinate row per point.
        min_points (int): minimum number of rows sharing a coordinate to keep it.

    Returns:
        torch.Tensor[N] (bool): True where the row's coordinate is shared by
        at least ``min_points`` rows (the row itself included).
    """
    _, group_of_row, rows_per_group = torch.unique(
        coors, return_inverse=True, return_counts=True, dim=0
    )
    return rows_per_group[group_of_row] >= min_points

def find_connected_componets_gpu(points, batch_idx, dist):
    """
    Label connected components of ``points`` with the optional torchex kernel.

    Args:
        points: non-empty point tensor passed straight to torchex.
        batch_idx: per-point batch index passed straight to torchex.
        dist: linking distance passed straight to torchex.

    Returns:
        Per-point component ids, checked to be contiguous ``0..K-1``.

    Raises:
        RuntimeError: if ``torchex.connected_components`` could not be imported.
        AssertionError: if ``points`` is empty or the ids are not contiguous.
    """
    assert len(points) > 0
    if cc_gpu is None:
        # An `assert` would vanish under `python -O`, and the call below would then
        # fail with an opaque "'NoneType' object is not callable".
        raise RuntimeError(
            "GPU clustering requested but torchex.connected_components could not be imported"
        )
    # NOTE(review): meaning of the positional args (100, 2, False) is defined by torchex — confirm there.
    components_inds = cc_gpu(points, batch_idx, dist, 100, 2, False)
    assert len(torch.unique(components_inds)) == components_inds.max().item() + 1
    return components_inds

def find_connected_componets(points, batch_idx, dist):
    """
    Label xy-connected components of ``points`` separately for every sample.

    Within one sample, two points are linked when their Euclidean distance in
    the xy plane (columns 0 and 1) is strictly below ``dist``. Components are
    found on CPU with SciPy. Labels are offset per sample so they are unique
    across the batch and contiguous from 0.

    Args:
        points (torch.Tensor[N, >=2]): point coordinates; only x/y are used.
        batch_idx (torch.Tensor[N]): sample index of every point.
        dist (float): linking threshold.

    Returns:
        torch.Tensor[N] with ``batch_idx``'s dtype: global component id per point.
        Empty input yields an empty tensor.
    """
    device = points.device
    components_inds = torch.full_like(batch_idx, -1)
    if batch_idx.numel() == 0:
        # Nothing to label (and .max() on an empty tensor would raise).
        return components_inds

    base = 0
    # Visit only batch ids that actually occur, in ascending order.
    for i in torch.unique(batch_idx).tolist():
        batch_mask = batch_idx == i
        this_points = points[batch_mask]
        dist_mat = this_points[:, None, :2] - this_points[None, :, :2]  # only care about xy
        dist_mat = (dist_mat ** 2).sum(2) ** 0.5
        adj_mat = (dist_mat < dist).cpu().numpy()
        c_inds = connected_components(adj_mat, directed=False)[1]
        c_inds = torch.from_numpy(c_inds).to(device).int() + base
        base = c_inds.max().item() + 1
        components_inds[batch_mask] = c_inds

    assert len(torch.unique(components_inds)) == components_inds.max().item() + 1

    return components_inds

def find_connected_componets_single_batch(points, batch_idx, dist):
    """
    Label xy-connected components of ``points``, treating them as one sample.

    Two points are linked when their Euclidean distance in the xy plane
    (columns 0 and 1) is strictly below ``dist``. Components are computed on
    CPU with SciPy.

    Args:
        points (torch.Tensor[N, >=2]): point coordinates; only x/y are used.
        batch_idx: ignored; kept so the signature matches find_connected_componets.
        dist (float): linking threshold.

    Returns:
        torch.Tensor[N] (int32) on ``points.device``: SciPy component label per point.
    """
    xy = points[:, :2]
    offsets = xy[:, None, :] - xy[None, :, :]
    pairwise = (offsets ** 2).sum(2) ** 0.5
    linked = (pairwise < dist).cpu().numpy()
    _, labels = connected_components(linked, directed=False)
    return torch.from_numpy(labels).to(points.device).int()

class ClusterAssigner(nn.Module):
    '''Assign every point to a cluster of xy-connected, sufficiently populated voxels.

    Points are voxelized at ``cluster_voxel_size``. Voxels with fewer than
    ``min_points`` points are dropped. Per-voxel centers are computed with
    ``scatter_v2(mode='avg')``, and centers closer than ``connected_dist`` in xy
    are linked into connected components.
    '''

    def __init__(
        self,
        cluster_voxel_size,
        min_points,
        point_cloud_range,
        connected_dist,
        gpu_clustering=(False, False),
    ):
        """
        Args:
            cluster_voxel_size (sequence of float): voxel edge lengths, broadcast
                against the first 3 point coordinates.
            min_points (int): voxels with fewer points are dropped (unless every
                voxel would be dropped, in which case all points are kept).
            point_cloud_range (sequence of float): its first 3 entries are used as
                the voxel-grid origin (presumably [x_min, y_min, z_min, ...] — confirm).
            connected_dist (float): xy distance below which voxel centers are linked.
            gpu_clustering (tuple[bool, bool]): only index 1 is read here; when True,
                eval mode uses the torchex GPU kernel.
        """
        super().__init__()
        self.cluster_voxel_size = cluster_voxel_size
        self.min_points = min_points
        self.connected_dist = connected_dist
        self.point_cloud_range = point_cloud_range
        self.gpu_clustering = gpu_clustering

    @torch.no_grad()
    def forward(self, points, batch_idx):
        """
        Cluster ``points`` and return a (batch, cluster) id for every kept point.

        Args:
            points (torch.Tensor[N, >=3]): points whose first 3 columns are xyz.
            batch_idx (torch.Tensor[N]): sample index of each point.

        Returns:
            cluster_inds_per_point (torch.Tensor[N_valid, 2]): ``[batch_idx, cluster_id]``
                for every point selected by ``valid_mask``.
            valid_mask (torch.Tensor[N], bool): which input points were kept.
        """
        batch_idx = batch_idx.int()
        voxel_size = torch.tensor(self.cluster_voxel_size, device=points.device)
        pc_range = torch.tensor(self.point_cloud_range, device=points.device)
        # Integer voxel coordinates in xyz order, prefixed by the batch index.
        coors = torch.div(points[:, :3] - pc_range[None, :3], voxel_size[None, :], rounding_mode='floor').int()
        coors = torch.cat([batch_idx[:, None], coors], dim=1)

        valid_mask = filter_almost_empty(coors, min_points=self.min_points)
        if not valid_mask.any():
            # No voxel reaches min_points: keep every point rather than returning nothing.
            valid_mask = ~valid_mask

        points = points[valid_mask]
        batch_idx = batch_idx[valid_mask]
        coors = coors[valid_mask]

        sampled_centers, voxel_coors, inv_inds = scatter_v2(points, coors, mode='avg', return_inv=True)
        # Column 0 of the coordinates is the batch index (see the torch.cat above).
        voxel_batch_idx = voxel_coors[:, 0]
        dist = self.connected_dist
        if self.training:
            cluster_inds = find_connected_componets(sampled_centers, voxel_batch_idx, dist)
        elif self.gpu_clustering[1]:
            cluster_inds = find_connected_componets_gpu(sampled_centers, voxel_batch_idx, dist)
        elif torch.unique(voxel_batch_idx).numel() > 1:
            # The single-batch path ignores batch ids and could link nearby points of
            # different samples, so label each sample separately instead.
            cluster_inds = find_connected_componets(sampled_centers, voxel_batch_idx, dist)
        else:
            cluster_inds = find_connected_componets_single_batch(sampled_centers, voxel_batch_idx, dist)
        assert len(cluster_inds) == len(sampled_centers)

        # Broadcast voxel-level cluster ids back to the kept points.
        cluster_inds_per_point = cluster_inds[inv_inds]
        cluster_inds_per_point = torch.stack([batch_idx, cluster_inds_per_point], 1)
        return cluster_inds_per_point, valid_mask