from typing import List, Tuple

import torch
import numpy as np
import spconv.pytorch as spconv
import absl.flags as flags
from epic_ops.ball_query import ball_query
from epic_ops.ccl import connected_components_labeling
from epic_ops.nms import nms
from epic_ops.reduce import segmented_reduce
from epic_ops.voxelize import voxelize

# Default torch device name used by this module when a helper tensor has to
# be allocated without an explicit device.
device = 'cuda'
# absl flag registry; hyper-parameters such as ``ball_query_radius`` and
# ``min_num_points_per_proposal`` are read from it at call time.
FLAGS = flags.FLAGS
def segmented_voxelize(
    pt_xyz: torch.Tensor,
    pt_features: torch.Tensor,
    segment_offsets: torch.Tensor,
    segment_indices: torch.Tensor,
    num_points_per_segment: torch.Tensor,
    score_fullscale: float,
    score_scale: float,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Voxelize each segment independently inside its own fixed-size grid.

    Every segment is centred on its mean, rescaled so that its largest
    extent roughly fills ``score_fullscale`` voxels (the scale is capped at
    ``score_scale``), shifted by a random offset and finally voxelized with
    unit voxel size over ``[0, score_fullscale]`` on each axis.

    Args:
        pt_xyz: Point coordinates; presumably ``(N, 3)`` with the points of
            each segment stored contiguously (TODO confirm with caller).
        pt_features: Per-point features averaged into voxels.
        segment_offsets: Segment boundaries; ``segment_offsets[i]`` and
            ``segment_offsets[i + 1]`` delimit segment ``i``.
        segment_indices: For every point, the index of its segment.
        num_points_per_segment: Number of points of every segment.
        score_fullscale: Edge length (in voxels) of the per-segment grid.
        score_scale: Upper bound applied to the per-segment scale factor.

    Returns:
        ``(voxel_features, voxel_coords, pc_voxel_id)`` where
        ``voxel_coords`` has the segment index prepended as first column and
        ``pc_voxel_id`` is the point-to-voxel mapping returned by
        ``voxelize``.
    """
    seg_begin = segment_offsets[:-1]
    seg_end = segment_offsets[1:]

    def reduce_per_segment(values: torch.Tensor, mode: str) -> torch.Tensor:
        return segmented_reduce(values, seg_begin, seg_end, mode=mode)

    # Centre every segment on its own centroid.
    centroids = reduce_per_segment(pt_xyz, "sum") / num_points_per_segment[:, None]
    centered = pt_xyz - centroids[segment_indices]

    lower = reduce_per_segment(centered, "min")
    upper = reduce_per_segment(centered, "max")

    # Scale that maps the largest extent onto the grid, minus a small margin,
    # never exceeding ``score_scale``.
    relative_extent = (upper - lower) / score_fullscale
    scale = 1. / relative_extent.max(-1)[0] - 0.01
    scale = torch.clamp(scale, min=None, max=score_scale)

    scale_column = scale[..., None]
    min_xyz = lower * scale_column
    max_xyz = upper * scale_column
    scaled_points = centered * scale[segment_indices][..., None]

    # Random placement inside the grid: the first term spreads segments that
    # are smaller than the grid, the second one handles segments that are
    # larger.  The two random draws must stay in this order.
    range_xyz = max_xyz - min_xyz
    rand_options = dict(dtype=min_xyz.dtype, device=min_xyz.device)
    slack_shift = torch.clamp(
        score_fullscale - range_xyz - 0.001, min=0
    ) * torch.rand(3, **rand_options)
    overflow_shift = torch.clamp(
        score_fullscale - range_xyz + 0.001, max=0
    ) * torch.rand(3, **rand_options)
    offsets = -min_xyz + slack_shift + overflow_shift
    scaled_points += offsets[segment_indices]

    grid_extent = float(score_fullscale)
    grid_device = scaled_points.device
    voxel_features, voxel_xyz, voxel_segment_ids, pc_voxel_id = voxelize(
        scaled_points,
        pt_features,
        batch_offsets=segment_offsets.long(),
        voxel_size=torch.as_tensor([1., 1., 1.], device=grid_device),
        points_range_min=torch.as_tensor([0., 0., 0.], device=grid_device),
        points_range_max=torch.as_tensor(
            [grid_extent, grid_extent, grid_extent], device=grid_device
        ),
        reduction="mean",
    )
    voxel_coords = torch.cat([voxel_segment_ids[:, None], voxel_xyz], dim=1)

    return voxel_features, voxel_coords, pc_voxel_id

def apply_voxelization(pc, voxel_size):
    """Voxelize a single point cloud with mean-pooled features.

    Args:
        pc: Point tensor whose first three columns are the xyz coordinates;
            the full rows are used as the features that get averaged per
            voxel.
        voxel_size: Per-axis voxel edge lengths (anything accepted by
            ``torch.as_tensor``).

    Returns:
        A tuple ``(voxel_features, voxel_coords, voxel_coords_range,
        pc_voxel_id)``.  The first, second and fourth entries are detached
        copies of the ``voxelize`` outputs.  ``voxel_coords_range`` is a new
        tensor built from a Python list (so it lives on the CPU) holding, per
        axis, ``max(voxel_coords) + 1`` clamped to at least 128.

    Raises:
        AssertionError: If ``voxelize`` reports a point without a voxel.
    """
    num_points = pc.shape[0]
    pt_xyz = pc[:, :3]
    target_device = pt_xyz.device

    # Pad the bounding box slightly so that boundary points fall inside it.
    points_range_min = pt_xyz.min(0)[0] - 1e-4
    points_range_max = pt_xyz.max(0)[0] + 1e-4
    voxel_features, voxel_coords, _, pc_voxel_id = voxelize(
        pt_xyz, pc,
        batch_offsets=torch.as_tensor(
            [0, num_points], dtype=torch.int64, device=target_device
        ),
        voxel_size=torch.as_tensor(voxel_size, device=target_device),
        points_range_min=points_range_min,
        points_range_max=points_range_max,
        reduction="mean",
    )
    assert (pc_voxel_id >= 0).all(), "every point must be assigned to a voxel"

    voxel_coords_range = (voxel_coords.max(0)[0] + 1).clamp(min=128, max=None)

    # NOTE(review): the outputs are returned detached from autograd, exactly
    # as the former ``torch.tensor(t)`` copy-construction did, but without
    # the UserWarning that call emits.  Confirm no gradient is expected to
    # flow through the voxel features.
    return (
        voxel_features.detach().clone(),
        voxel_coords.detach().clone(),
        torch.tensor(voxel_coords_range.tolist()),
        pc_voxel_id.detach().clone(),
    )

def get_voxel_feat_per_batch(feat_pts_batch, voxel_size, num_points_per_batch=None, batch_offsets=None):
    """Voxelize every sample of a batch and pack the result for spconv.

    Args:
        feat_pts_batch: Sequence of per-sample point tensors; each one is
            passed to ``apply_voxelization``.
        voxel_size: Per-axis voxel edge lengths forwarded to
            ``apply_voxelization``.
        num_points_per_batch: Unused; kept for backward compatibility.
        batch_offsets: Unused; kept for backward compatibility.

    Returns:
        ``(voxel_tensor, pc_voxel_id_batch)`` where ``voxel_tensor`` is a
        ``spconv.SparseConvTensor`` whose indices carry the sample index as
        first column, and ``pc_voxel_id_batch`` maps every point of the
        concatenated batch to its row in the concatenated voxel features.

    Raises:
        ValueError: If ``feat_pts_batch`` is empty.
    """
    batch_size = len(feat_pts_batch)
    if batch_size == 0:
        raise ValueError("feat_pts_batch must contain at least one point cloud")

    voxel_feat_parts = []
    voxel_coord_parts = []
    voxel_index_parts = []
    voxel_range_parts = []
    pc_voxel_id_parts = []
    num_voxel_offset = 0
    for sample_idx, feat_pts in enumerate(feat_pts_batch):
        voxel_feat, voxel_coords, voxel_coords_range, pc_voxel_id = apply_voxelization(
            feat_pts, voxel_size
        )
        num_voxels = voxel_coords.shape[0]

        voxel_feat_parts.append(voxel_feat)
        voxel_coord_parts.append(voxel_coords)
        # Allocate on the coords' own device instead of a hard-coded 'cuda'
        # so that the concatenation below never mixes devices.
        voxel_index_parts.append(
            torch.full(
                (num_voxels,), sample_idx,
                dtype=torch.int32, device=voxel_coords.device,
            )
        )
        voxel_range_parts.append(voxel_coords_range)
        # Shift the per-sample voxel ids so they index the concatenated
        # voxel list (apply_voxelization guarantees the ids are >= 0).
        pc_voxel_id_parts.append(pc_voxel_id + num_voxel_offset)
        num_voxel_offset += num_voxels

    # Concatenate once instead of growing the tensors inside the loop.
    voxel_feat_batch = torch.cat(voxel_feat_parts, dim=0)
    voxel_coords_batch = torch.cat([
        torch.cat(voxel_index_parts, dim=0)[:, None],
        torch.cat(voxel_coord_parts, dim=0),
    ], dim=-1)
    pc_voxel_id_batch = torch.cat(pc_voxel_id_parts, dim=0)

    # The shared spatial shape is the per-axis maximum over all samples.
    spatial_shape = (
        torch.cat(voxel_range_parts, dim=0).reshape(-1, 3).max(0)[0].tolist()
    )

    voxel_tensor = spconv.SparseConvTensor(
        voxel_feat_batch, voxel_coords_batch,
        spatial_shape=spatial_shape,
        batch_size=batch_size,
    )
    return voxel_tensor, pc_voxel_id_batch

def cluster_proposals(
    pt_xyz: torch.Tensor,
    batch_indices: torch.Tensor,
    batch_offsets: torch.Tensor,
    sem_preds: torch.Tensor,
    ball_query_radius: float,
    max_num_points_per_query: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Group points into connected components of a ball-query graph.

    Each point queries its neighbours within ``ball_query_radius`` (at most
    ``max_num_points_per_query`` of them); ``sem_preds`` is passed as both
    the point and the query labels.  The neighbour lists are then handed to
    connected-components labeling.

    Args:
        pt_xyz: Point coordinates, used both as points and as queries.
        batch_indices: Per-point batch index; its dtype is reused for the
            index ranges built here.
        batch_offsets: Batch boundaries forwarded to ``ball_query``.
        sem_preds: Per-point semantic predictions.
        ball_query_radius: Neighbourhood radius.
        max_num_points_per_query: Neighbour slots reserved per point.

    Returns:
        ``(sorted_cc_labels, sorted_indices)``: the component labels in
        ascending order and the permutation that sorts them.
    """
    neighbor_indices, neighbor_counts = ball_query(
        pt_xyz,
        pt_xyz,
        batch_indices,
        batch_offsets,
        ball_query_radius,
        max_num_points_per_query,
        point_labels=sem_preds,
        query_labels=sem_preds,
    )

    # Point i owns the slots [i * max, i * max + count_i) of the flattened
    # neighbour table.
    point_ids = torch.arange(
        pt_xyz.shape[0], dtype=batch_indices.dtype, device=pt_xyz.device
    )
    slot_begin = point_ids * max_num_points_per_query
    slot_end = slot_begin + neighbor_counts
    slot_ranges = torch.stack([slot_begin, slot_end], dim=1).view(-1)

    cc_labels = connected_components_labeling(
        slot_ranges, neighbor_indices.view(-1), compacted=False
    )

    ordering = torch.sort(cc_labels)
    return ordering.values, ordering.indices

def _counts_to_offsets(counts, device):
    """Return int32 offsets ``[0, c0, c0 + c1, ...]`` built from *counts*."""
    offsets = torch.zeros((counts.shape[0] + 1,), dtype=torch.int32, device=device)
    offsets[1:] = counts.cumsum(0)
    return offsets


def do_clusting(pt_xyz, batch_indices, sem_preds, offset_preds):
    """Build instance proposals from the original and the shifted points.

    Clustering runs twice via ``cluster_proposals``: once on ``pt_xyz`` and
    once on ``pt_xyz + offset_preds``.  The two label sets are concatenated
    (the second one shifted so the labels do not collide), proposals with
    fewer than ``FLAGS.min_num_points_per_proposal`` points are dropped and
    the remaining proposal ids are re-compacted.

    Args:
        pt_xyz: Point coordinates.
        batch_indices: Per-point batch index; equal values are expected to
            be consecutive because ``torch.unique_consecutive`` is used.
        sem_preds: Per-point semantic predictions.
        offset_preds: Per-point offsets added to ``pt_xyz`` for the second
            clustering pass.

    Returns:
        ``(sorted_indices, proposal_offsets, proposal_indices,
        num_points_per_proposal)``: the point indices grouped by proposal,
        the int32 start/end offsets of every proposal inside that list, the
        compact proposal id of every entry and the size of every proposal.
    """
    # Allocate helper tensors next to the inputs instead of on the
    # module-level default device, which may be a different GPU.
    target_device = pt_xyz.device

    _, batch_indices_compact, num_points_per_batch = torch.unique_consecutive(
        batch_indices, return_inverse=True, return_counts=True
    )
    batch_indices_compact = batch_indices_compact.int()
    batch_offsets = _counts_to_offsets(num_points_per_batch, target_device)

    sorted_cc_labels, sorted_indices = cluster_proposals(
        pt_xyz, batch_indices_compact, batch_offsets, sem_preds,
        FLAGS.ball_query_radius, FLAGS.max_num_points_per_query,
    )
    sorted_cc_labels_shift, sorted_indices_shift = cluster_proposals(
        pt_xyz + offset_preds, batch_indices_compact, batch_offsets, sem_preds,
        FLAGS.ball_query_radius, FLAGS.max_num_points_per_query_shift,
    )

    # Shift the second label set so it cannot collide with the first one.
    sorted_cc_labels = torch.cat([
        sorted_cc_labels,
        sorted_cc_labels_shift + sorted_cc_labels.shape[0],
    ], dim=0)
    sorted_indices = torch.cat([sorted_indices, sorted_indices_shift], dim=0)

    # Compact the proposal ids.
    _, proposal_indices, num_points_per_proposal = torch.unique_consecutive(
        sorted_cc_labels, return_inverse=True, return_counts=True
    )

    # Remove small proposals; the per-proposal mask is broadcast to points.
    valid_proposal_mask = (
        num_points_per_proposal >= FLAGS.min_num_points_per_proposal
    )
    valid_point_mask = valid_proposal_mask[proposal_indices]
    sorted_indices = sorted_indices[valid_point_mask]

    # Re-compact the proposal ids after filtering.
    _, proposal_indices, num_points_per_proposal = torch.unique_consecutive(
        proposal_indices[valid_point_mask],
        return_inverse=True, return_counts=True,
    )

    proposal_offsets = _counts_to_offsets(num_points_per_proposal, target_device)
    return sorted_indices, proposal_offsets, proposal_indices, num_points_per_proposal

#def get_voxel_info_per_batch():
