from torch import nn
import torch
from opencood.utils.sparse_utils import scatter_v2

class Voxel2PointScatterNeck(nn.Module):
    """Map per-voxel features back onto the individual points.

    A memory-efficient voxel-to-point scatter built on torch_scatter-style
    inverse indices. Each point gathers the feature of the voxel it was
    assigned to (via ``voxel2point_inds``). Optionally, the point also gets
    its offset from that voxel's centre appended.

    Args:
        point_cloud_range: Sequence whose first three entries are used as the
            (x, y, z) origin of the voxel grid. Read only when ``with_xyz``.
        voxel_size: Three voxel extents in (x, y, z) order. Read only when
            ``with_xyz``.
        with_xyz (bool): Append the point-to-voxel-centre offset to the
            gathered feature.
        normalize_local_xyz (bool): Divide that offset by half the voxel size.
    """

    def __init__(
        self,
        point_cloud_range=None,
        voxel_size=None,
        with_xyz=True,
        normalize_local_xyz=False,
    ):
        super().__init__()
        self.point_cloud_range = point_cloud_range
        self.voxel_size = voxel_size
        self.with_xyz = with_xyz
        self.normalize_local_xyz = normalize_local_xyz

    def forward(self, batch_dict):
        """Gather voxel features for every point and store them in ``batch_dict``.

        Reads from ``batch_dict``:
            points (torch.Tensor): (N, C_point); the first three columns are
                compared against xyz voxel centres.
            point_coords (torch.Tensor): (N, 4) voxel coordinates; columns
                3, 2, 1 hold the x, y, z voxel indices.
            voxel_features: object whose ``.features`` is an (M, C_feature)
                tensor, padded and reordered; rows that are -1 in every
                channel mark dropped voxels.
            voxel2point_inds (torch.Tensor): (N,) voxel row for each point.

        Writes to ``batch_dict``:
            point_feats (torch.Tensor): features of points whose voxel was not
                dropped. The shape is (N', C_feature + 3) when ``with_xyz``,
                otherwise (N', C_feature).
            point_mask (torch.Tensor): (N,) boolean mask selecting those points.

        Returns:
            dict: the same ``batch_dict``, updated in place.
        """
        points = batch_dict['points']
        coords = batch_dict['point_coords']
        voxel_table = batch_dict['voxel_features'].features
        point2voxel = batch_dict['voxel2point_inds']

        assert points.size(0) == coords.size(0) == point2voxel.size(-1)
        float_dtype, target_device = voxel_table.dtype, voxel_table.device

        # The indices are the inverse map of the torch_scatter pooling, so each
        # point picks up the row of the voxel it was pooled into.
        gathered = voxel_table[point2voxel]
        # Dropped voxels are padded with -1 in every channel; discard their points.
        keep = ~((gathered == -1).all(1))
        kept_feats = gathered[keep]

        if not self.with_xyz:
            batch_dict['point_feats'] = kept_feats
            batch_dict['point_mask'] = keep
            return batch_dict

        cell = torch.tensor(self.voxel_size, dtype=float_dtype, device=target_device).reshape(1, 3)
        origin = torch.tensor(self.point_cloud_range[:3], dtype=float_dtype, device=target_device).reshape(1, 3)
        half_cell = cell / 2

        # Reorder columns (3, 2, 1) so the voxel index reads as x, y, z.
        xyz_index = coords[keep][:, [3, 2, 1]].to(float_dtype).to(target_device)
        centers = (xyz_index + 0.5) * cell + origin
        offsets = points[keep][:, :3] - centers

        if self.normalize_local_xyz:
            offsets = offsets / half_cell
        elif self.training:
            # The bound is only checked for raw (unnormalized) offsets while training.
            assert (offsets.abs() < half_cell + 1e-3).all(), 'Holds in training. However, in test, this is not always True because of lack of point range clip'

        batch_dict['point_feats'] = torch.cat([kept_feats, offsets], 1)
        batch_dict['point_mask'] = keep
        return batch_dict
    
    
class VoteSegHead(nn.Module):
    """Per-point segmentation and centre-voting head.

    A shared MLP (``pre_seg_conv``) maps point features to hidden features.
    Two linear heads then predict classification logits and vote offsets
    from those hidden features.

    Args:
        in_channel (int): Channel count of ``batch_dict['point_feats']``.
        num_classes (int): Number of foreground classes; one background class
            is added internally, so ``self.num_classes == num_classes + 1``.
        hidden_dim (int): Width of the hidden MLP features.
        dropout_ratio (float): Dropout probability applied before the
            classification layer; values ``<= 0`` disable dropout.
        train_cfg (dict, optional): Training options. Only the
            ``'centroid_offset'`` key is read, by ``get_vote_target``.
            Defaults to an empty dict.
    """

    def __init__(self,
                 in_channel,
                 num_classes,
                 hidden_dim,
                 dropout_ratio=0.5,
                 train_cfg=None):
        super(VoteSegHead, self).__init__()
        # Previously never initialised, which made get_vote_target raise AttributeError.
        self.train_cfg = train_cfg if train_cfg is not None else {}

        self.pre_seg_conv = nn.Sequential(
            nn.Linear(in_channel, hidden_dim, bias=False),
            nn.BatchNorm1d(hidden_dim, eps=1e-3, momentum=0.01),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim, bias=False)
        )
        if dropout_ratio > 0:
            self.dropout = nn.Dropout(dropout_ratio)
        else:
            self.dropout = None
        self.num_classes = num_classes + 1  # +1 for background class
        self.conv_seg = nn.Linear(hidden_dim, self.num_classes)
        # Three offset components per class (background included).
        self.voting = nn.Linear(hidden_dim, self.num_classes * 3)

    def cls_seg(self, feats):
        """Classify each point.

        Applies dropout (when enabled) and then the classification layer.

        Args:
            feats (torch.Tensor): (N, hidden_dim) hidden point features.

        Returns:
            torch.Tensor: (N, self.num_classes) logits.
        """
        if self.dropout is not None:
            # The dropped-out features must feed the classifier; the previous
            # code discarded this result, so dropout had no effect.
            feats = self.dropout(feats)
        return self.conv_seg(feats)

    def forward(self, batch_dict):
        """Predict per-point logits and votes.

        Reads ``batch_dict['point_feats']`` of shape (N, in_channel). Writes
        the following keys:

        * ``'class_logits'``: (N, self.num_classes).
        * ``'vote_preds'``: (N, self.num_classes * 3), in the encoded
          (signed square-root) space.
        * ``'vote_offsets'``: ``vote_preds`` decoded back to raw offsets.

        Returns:
            dict: the same ``batch_dict``, updated in place.
        """
        hidden_feats = self.pre_seg_conv(batch_dict['point_feats'])     # [N, C_h]
        logits = self.cls_seg(hidden_feats)                             # [N, num_classes]
        vote_preds = self.voting(hidden_feats)                          # [N, num_classes * 3]
        batch_dict.update({
            'class_logits': logits,
            'vote_preds': vote_preds,
            'vote_offsets': self.decode_vote_targets(vote_preds),
        })
        return batch_dict

    def get_vote_target(self, inbox_inds, points, bboxes):
        """Build encoded vote regression targets.

        Args:
            inbox_inds (torch.Tensor): (N,) index of the box containing each
                point; negative values mark background points.
            points (torch.Tensor): (N, D) point coordinates. D must broadcast
                against the box centres; presumably D == 3 (TODO confirm).
            bboxes: Object exposing a ``gravity_center`` tensor with one row
                per box. It is only used when ``train_cfg['centroid_offset']``
                is falsy.

        Returns:
            tuple: ``(target, vote_mask)``. ``target`` is the encoded
            centre-minus-point delta, zero for background points.
            ``vote_mask`` is ``True`` for foreground points.
        """
        bg_mask = inbox_inds < 0
        fg_mask = ~bg_mask
        # getattr fallback keeps instances created without __init__ working.
        train_cfg = getattr(self, 'train_cfg', None) or {}
        if train_cfg.get('centroid_offset', False):
            centroid, _, inv = scatter_v2(points, inbox_inds, mode='avg', return_inv=True)
            center_per_point = centroid[inv]
        else:
            centers = bboxes.gravity_center
            # Gather only foreground rows. Indexing with -1 would crash when a
            # frame has no boxes, and background deltas are zeroed below anyway.
            center_per_point = centers.new_zeros((inbox_inds.size(0), centers.size(-1)))
            center_per_point[fg_mask] = centers[inbox_inds[fg_mask]]
        delta = center_per_point.to(points.device) - points
        delta[bg_mask] = 0
        target = self.encode_vote_targets(delta)
        return target, fg_mask

    def encode_vote_targets(self, delta):
        """Compress offsets with a signed square root (inverse of ``decode_vote_targets``)."""
        return torch.sign(delta) * (delta.abs() ** 0.5)

    def decode_vote_targets(self, preds):
        """Map encoded predictions back to offsets: ``p * |p|`` undoes the signed sqrt."""
        return preds * preds.abs()