import math
import logging

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils import spectral_norm
import MinkowskiEngine as ME
from MinkowskiEngine.MinkowskiKernelGenerator import KernelGenerator

from models.common import MinkowskiLayerNorm
from models.sparse_ops import (
    direct_dot_product_cuda,
    direct_dot_product_shared_cuda,
    dot_product_cuda,
    dot_product_with_key_cuda,
    dot_product_sample_shared_cuda,
    dot_product_sample_cuda,
    dot_product_sample_with_key_cuda,
    dot_product_intra_cuda,
    dot_product_intra_inter_cuda,
    scalar_attention_cuda,
    dot_product_key_cuda,
    add_sum_squares_cuda,
)


class TransformerLayerBase(nn.Module):
    """Common projections and kernel-map helpers for the local attention layers.

    Sets up 1x1 sparse convolutions that produce query and value features, a
    dense output projection, and a ``KernelGenerator`` describing the local
    neighbourhood. Subclasses implement ``forward``.

    Args:
        in_channels: Number of input feature channels.
        out_channels: Number of output channels; defaults to ``in_channels``.
            Must be divisible by ``num_heads``.
        kernel_size: Neighbourhood size handed to ``KernelGenerator``.
        stride: Stride of the query convolution and of the kernel generator.
        num_heads: Number of attention heads.
        bias: Whether the projections carry a bias term.
        dimension: Spatial dimension handed to MinkowskiEngine.
    """

    def __init__(
        self,
        in_channels,
        out_channels=None,
        kernel_size=3,
        stride=1,
        num_heads=8,
        bias=True,
        dimension=3
    ):
        out_channels = in_channels if out_channels is None else out_channels
        assert out_channels % num_heads == 0
        super(TransformerLayerBase, self).__init__()

        self.out_channels = out_channels
        self.attn_channels = out_channels // num_heads  # channels per head
        self.num_heads = num_heads

        # Only the query projection is strided; values stay on the input coordinates.
        self.to_query = nn.Sequential(
            ME.MinkowskiConvolution(in_channels, self.out_channels, kernel_size=1, stride=stride, bias=bias, dimension=dimension),
            ME.MinkowskiToFeature()
        )
        self.to_value = nn.Sequential(
            ME.MinkowskiConvolution(in_channels, self.out_channels, kernel_size=1, bias=bias, dimension=dimension),
            ME.MinkowskiToFeature()
        )
        self.to_out = nn.Linear(out_channels, out_channels, bias=bias)

        # just for information
        if kernel_size == 3 and stride == 2:
            logging.info("Recommend to use kernel size 5 instead of 3 for stride 2.")
        self.kernel_size = kernel_size
        self.kernel_generator = KernelGenerator(kernel_size=kernel_size, stride=stride, dimension=dimension)
        self.kernel_volume = self.kernel_generator.kernel_volume

    @torch.no_grad()
    def key_query_indices_from_kernel_map(self, kernel_map):
        """Concatenate all per-kernel index pairs along the last axis.

        Args:
            kernel_map: Mapping from kernel index to a 2-row index tensor.
                Elsewhere in this file row 0 is used as the key index and
                row 1 as the query index.

        Returns:
            A single 2-row tensor holding every pair, in dict iteration order.
        """
        return torch.cat(list(kernel_map.values()), -1)

    @torch.no_grad()
    def key_query_map_from_kernel_map(self, kernel_map):
        """Concatenate index pairs, packing the kernel index into row 0.

        Row 0 of the result is ``key_index * kernel_volume + kernel_idx``;
        row 1 is copied unchanged. The tensors stored in ``kernel_map`` are
        left untouched (new tensors are built instead of writing in place),
        so the same map can safely be reused by the caller.

        Args:
            kernel_map: Mapping from kernel index to a 2-row index tensor.

        Returns:
            A 2-row tensor with the packed row 0 and the original row 1.
        """
        packed = [
            torch.stack([in_out[0] * self.kernel_volume + kernel_idx, in_out[1]])
            for kernel_idx, in_out in kernel_map.items()
        ]
        return torch.cat(packed, -1)

    @torch.no_grad()
    def key_query_indices_from_key_query_map(self, kq_map):
        """Undo the packing of ``key_query_map_from_kernel_map``.

        Returns a copy of ``kq_map`` whose row 0 is floor-divided by
        ``kernel_volume``, i.e. the plain key index.
        """
        kq_indices = kq_map.clone()
        kq_indices[0] = kq_indices[0] // self.kernel_volume

        return kq_indices

    @torch.no_grad()
    def get_kernel_map_and_out_key(self, stensor):
        """Query the coordinate manager for this layer's kernel map.

        Args:
            stensor: Input sparse tensor; its coordinate manager, coordinate
                key and tensor stride are used.

        Returns:
            ``(kernel_map, out_key)`` where ``out_key`` is the coordinate key
            obtained by striding the input key with the kernel stride.
        """
        cm = stensor.coordinate_manager
        in_key = stensor.coordinate_key
        out_key = cm.stride(in_key, self.kernel_generator.kernel_stride)
        region_type, region_offset, _ = self.kernel_generator.get_kernel(stensor.tensor_stride, False)
        kernel_map = cm.kernel_map(
            in_key,
            out_key,
            self.kernel_generator.kernel_stride,
            self.kernel_generator.kernel_size,
            self.kernel_generator.kernel_dilation,
            region_type=region_type,
            region_offset=region_offset,
        )

        return kernel_map, out_key


class PointTransformerLayer(TransformerLayerBase):
    """Local attention scored against an MLP encoding of relative positions.

    For every key/query pair in the kernel map, the relative position
    ``points[query] - points[key]`` is encoded by ``rel_pos_mlp`` and the
    attention score is computed from the L2-normalised query and the
    L2-normalised encoding by ``direct_dot_product_cuda``.

    Constructor arguments are forwarded unchanged to ``TransformerLayerBase``.
    """

    def __init__(
        self,
        in_channels,
        out_channels=None,
        kernel_size=3,
        stride=1,
        num_heads=8,
        bias=True,
        dimension=3
    ):
        super(PointTransformerLayer, self).__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            num_heads,
            bias,
            dimension
        )
        # The MLP runs on a plain dense tensor of relative positions, so the
        # normalisation layers must be dense too. ME.MinkowskiBatchNorm reads
        # ``input.F`` and therefore cannot follow an nn.Linear.
        # NOTE(review): this changes the state_dict keys of the norm layers
        # (no ".bn" infix) relative to the previous definition.
        self.rel_pos_mlp = nn.Sequential(
            nn.Linear(dimension, dimension, bias=False),
            nn.BatchNorm1d(dimension),
            nn.ReLU(inplace=True),
            nn.Linear(dimension, self.out_channels, bias=False),
            nn.BatchNorm1d(self.out_channels),
            nn.ReLU(inplace=True),
            nn.Linear(self.out_channels, self.out_channels)
        )

    @torch.no_grad()
    def get_relative_position(self, points, kq_indices):
        """Return ``points[query] - points[key]`` for every index pair.

        Args:
            points: Indexable per-point positions (same length as the input
                sparse tensor).
            kq_indices: 2-row index tensor; row 0 = key, row 1 = query.
        """
        kq_indices = kq_indices.long()
        return points[kq_indices[1]] - points[kq_indices[0]]  # query - key

    def forward(self, stensor, points):
        """Apply the layer.

        Args:
            stensor: Input ``ME.SparseTensor``.
            points: Per-point positions, one row per row of ``stensor``.

        Returns:
            ``ME.SparseTensor`` on the strided coordinate key.
        """
        assert len(stensor) == len(points)

        # query and value, reshaped to (rows, heads, channels per head)
        q = self.to_query(stensor).view(-1, self.num_heads, self.attn_channels).contiguous()
        v = self.to_value(stensor).view(-1, self.num_heads, self.attn_channels).contiguous()
        num_queries = len(q)

        # kernel map
        kernel_map, out_key = self.get_kernel_map_and_out_key(stensor)
        kq_indices = self.key_query_indices_from_kernel_map(kernel_map)

        # relative positional encodings
        rel_pos = self.get_relative_position(points, kq_indices)
        rel_pos_enc = self.rel_pos_mlp(rel_pos)
        rel_pos_enc = rel_pos_enc.view(-1, self.num_heads, self.attn_channels).contiguous()

        # dot-product similarity; the CUDA op fills the pre-allocated buffer
        attn = stensor._F.new_zeros((kq_indices.shape[1], self.num_heads))
        norm_q = F.normalize(q, p=2, dim=-1)
        norm_pos_enc = F.normalize(rel_pos_enc, p=2, dim=-1)
        attn = direct_dot_product_cuda(norm_q, norm_pos_enc, attn, kq_indices)

        # aggregation & projection
        out_F = stensor._F.new_zeros((num_queries, self.num_heads, self.attn_channels))
        out_F = scalar_attention_cuda(attn, v, out_F, kq_indices)
        out_F = self.to_out(out_F.view(-1, self.out_channels).contiguous())

        return ME.SparseTensor(out_F, coordinate_map_key=out_key, coordinate_manager=stensor.coordinate_manager)
        
        
class EfficientPointTransformerLayer(TransformerLayerBase):
    """Local attention with a learned per-kernel-offset positional table.

    ``intra_pos_mlp`` encodes ``points`` and the result is added to the input
    features before the query/value projections. Attention scores come from
    ``dot_product_cuda`` applied to the L2-normalised query and the
    L2-normalised ``inter_pos_enc`` table (one entry per kernel offset and
    head).

    Constructor arguments are forwarded unchanged to ``TransformerLayerBase``.
    """

    def __init__(
        self,
        in_channels,
        out_channels=None,
        kernel_size=3,
        stride=1,
        num_heads=8,
        bias=True,
        dimension=3
    ):
        super(EfficientPointTransformerLayer, self).__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            num_heads,
            bias,
            dimension
        )
        self.inter_pos_enc = nn.Parameter(
            torch.empty(self.kernel_volume, self.num_heads, self.attn_channels, dtype=torch.float32)
        )
        # The MLP is fed the dense ``points`` tensor, so it needs dense
        # normalisation layers. ME.MinkowskiBatchNorm reads ``input.F`` and
        # cannot follow an nn.Linear.
        # NOTE(review): this changes the state_dict keys of the norm layers
        # (no ".bn" infix) relative to the previous definition.
        self.intra_pos_mlp = nn.Sequential(
            nn.Linear(dimension, dimension, bias=False),
            nn.BatchNorm1d(dimension),
            nn.ReLU(inplace=True),
            nn.Linear(dimension, in_channels, bias=False),
            nn.BatchNorm1d(in_channels),
            nn.ReLU(inplace=True),
            nn.Linear(in_channels, in_channels)
        )
        # Initialised after the MLP, matching the original construction order.
        nn.init.normal_(self.inter_pos_enc, 0, 1)

    def forward(self, stensor, points, kq_map=None):
        """Apply the layer.

        Args:
            stensor: Input ``ME.SparseTensor``.
            points: Per-point positions, one row per row of ``stensor``.
            kq_map: Optional packed key/query map as produced by
                ``key_query_map_from_kernel_map``; computed when ``None``.

        Returns:
            ``(output sparse tensor, kq_map)``; the returned map can be passed
            back in through ``kq_map`` to skip the kernel-map lookup.
        """
        assert len(stensor) == len(points)

        # query and value with intra-voxel relative positional encodings
        intra_pos_enc = self.intra_pos_mlp(points)
        stensor = stensor + intra_pos_enc
        q = self.to_query(stensor).view(-1, self.num_heads, self.attn_channels).contiguous()
        v = self.to_value(stensor).view(-1, self.num_heads, self.attn_channels).contiguous()
        num_queries = len(q)

        # kernel map
        if kq_map is None:
            kernel_map, out_key = self.get_kernel_map_and_out_key(stensor)
            kq_map = self.key_query_map_from_kernel_map(kernel_map)
        else:
            cm = stensor.coordinate_manager
            out_key = cm.stride(stensor.coordinate_key, self.kernel_generator.kernel_stride)

        # dot-product similarity; the CUDA op fills the pre-allocated buffer
        attn = stensor._F.new_zeros((kq_map.shape[1], self.num_heads))
        norm_q = F.normalize(q, p=2, dim=-1)
        norm_pos_enc = F.normalize(self.inter_pos_enc, p=2, dim=-1)
        attn = dot_product_cuda(norm_q, norm_pos_enc, attn, kq_map)

        # aggregation & projection (needs plain key indices, not the packed map)
        kq_indices = self.key_query_indices_from_key_query_map(kq_map)
        out_F = stensor._F.new_zeros((num_queries, self.num_heads, self.attn_channels))
        out_F = scalar_attention_cuda(attn, v, out_F, kq_indices)
        out_F = self.to_out(out_F.view(-1, self.out_channels).contiguous())

        return ME.SparseTensor(out_F, coordinate_map_key=out_key, coordinate_manager=stensor.coordinate_manager), kq_map


class EfficientPointTransformerLayerInterOnly(TransformerLayerBase):
    """Local attention that uses only the learned per-kernel-offset table.

    Unlike ``EfficientPointTransformerLayer`` no positional MLP is applied to
    the input; ``points`` is only used for the length check in ``forward``.
    """

    def __init__(
        self,
        in_channels,
        out_channels=None,
        kernel_size=3,
        stride=1,
        num_heads=8,
        bias=True,
        dimension=3
    ):
        super().__init__(in_channels, out_channels, kernel_size, stride, num_heads, bias, dimension)
        table_shape = (self.kernel_volume, self.num_heads, self.attn_channels)
        self.inter_pos_enc = nn.Parameter(torch.FloatTensor(*table_shape))
        nn.init.normal_(self.inter_pos_enc, 0, 1)

    def forward(self, stensor, points, kq_map=None):
        """Apply the layer.

        Args:
            stensor: Input ``ME.SparseTensor``.
            points: Per-point positions; only its length is checked.
            kq_map: Optional packed key/query map; computed when ``None``.

        Returns:
            ``(output sparse tensor, kq_map)``.
        """
        assert len(stensor) == len(points)
        heads, width = self.num_heads, self.attn_channels
        manager = stensor.coordinate_manager

        # per-head query and value features
        query = self.to_query(stensor).view(-1, heads, width).contiguous()
        value = self.to_value(stensor).view(-1, heads, width).contiguous()

        # reuse the supplied map when there is one, otherwise build it
        if kq_map is not None:
            out_key = manager.stride(stensor.coordinate_key, self.kernel_generator.kernel_stride)
        else:
            kernel_map, out_key = self.get_kernel_map_and_out_key(stensor)
            kq_map = self.key_query_map_from_kernel_map(kernel_map)

        # similarity between unit-norm queries and unit-norm table entries
        scores = stensor._F.new(kq_map.shape[1], heads).zero_()
        unit_query = F.normalize(query, p=2, dim=-1)
        unit_table = F.normalize(self.inter_pos_enc, p=2, dim=-1)
        scores = dot_product_cuda(unit_query, unit_table, scores, kq_map)

        # weighted aggregation of values followed by the output projection
        pair_indices = self.key_query_indices_from_key_query_map(kq_map)
        gathered = stensor._F.new(len(query), heads, width).zero_()
        gathered = scalar_attention_cuda(scores, value, gathered, pair_indices)
        projected = self.to_out(gathered.view(-1, self.out_channels).contiguous())

        result = ME.SparseTensor(projected, coordinate_map_key=out_key, coordinate_manager=manager)
        return result, kq_map


class EfficientPointTransformerLayerWithKey(EfficientPointTransformerLayer):
    """Variant with an explicit key projection (fused q/k/v convolution).

    The separate ``to_query`` / ``to_value`` projections of the parent are
    removed and replaced by a single ``to_qkv`` convolution. Only
    ``stride == 1`` is supported.
    """

    def __init__(
        self,
        in_channels,
        out_channels=None,
        kernel_size=3,
        stride=1,
        num_heads=8,
        bias=True,
        dimension=3
    ):
        assert stride == 1
        super(EfficientPointTransformerLayerWithKey, self).__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            num_heads,
            bias,
            dimension
        )
        # Drop the parent's projections so they do not linger as unused parameters.
        delattr(self, "to_query")
        delattr(self, "to_value")
        self.to_qkv = nn.Sequential(
            ME.MinkowskiConvolution(in_channels, 3 * self.out_channels, kernel_size=1, bias=bias, dimension=dimension),
            ME.MinkowskiToFeature()
        )

    def forward(self, stensor, points, kq_map=None):
        """Apply the layer.

        Args:
            stensor: Input ``ME.SparseTensor``.
            points: Per-point positions, one row per row of ``stensor``.
            kq_map: Optional packed key/query map; computed when ``None``.

        Returns:
            ``(output sparse tensor, kq_map)``.
        """
        assert len(stensor) == len(points)

        # query, key and value with intra-voxel relative positional encodings
        intra_pos_enc = self.intra_pos_mlp(points)
        stensor = stensor + intra_pos_enc
        q, k, v = self.to_qkv(stensor).split(self.out_channels, dim=-1)
        q = q.view(-1, self.num_heads, self.attn_channels).contiguous()
        k = k.view(-1, self.num_heads, self.attn_channels).contiguous()
        v = v.view(-1, self.num_heads, self.attn_channels).contiguous()
        N = len(q)

        # kernel map
        if kq_map is None:
            kernel_map, out_key = self.get_kernel_map_and_out_key(stensor)
            kq_map = self.key_query_map_from_kernel_map(kernel_map)
        else:
            cm = stensor.coordinate_manager
            out_key = cm.stride(stensor.coordinate_key, self.kernel_generator.kernel_stride)

        # dot-product similarity
        M = kq_map.shape[1]
        attn = q.new_zeros((M, self.num_heads))
        norm_q = F.normalize(q, p=2, dim=-1)
        attn = dot_product_with_key_cuda(norm_q, k, self.inter_pos_enc, attn, kq_map)

        # Normaliser: 2 * <k, p> + |k|^2 + |p|^2, presumably |k + p|^2 per pair
        # (the exact semantics live in the CUDA ops -- TODO confirm).
        ss_kp = q.new_zeros((M, self.num_heads))
        ss_kp = dot_product_key_cuda(k, self.inter_pos_enc, ss_kp, kq_map[0])
        ss_kp = 2 * ss_kp
        ss_k = k.square().sum(dim=-1, keepdim=False) # (N, H)
        ss_p = self.inter_pos_enc.square().sum(dim=-1, keepdim=False) # (K, H)
        ss_kp = add_sum_squares_cuda(ss_k, ss_p, ss_kp, kq_map[0])

        # Clamp before the square root: cancellation can leave ss_kp at zero
        # or slightly negative, which would otherwise give NaN/Inf weights.
        # The floor matches F.normalize's default eps (norm >= 1e-12).
        eps = 1e-12
        norm_kp = torch.sqrt(ss_kp.clamp_min(eps * eps))
        attn = attn / norm_kp

        # aggregation & projection
        kq_indices = self.key_query_indices_from_key_query_map(kq_map)
        out_F = stensor._F.new_zeros((N, self.num_heads, self.attn_channels))
        out_F = scalar_attention_cuda(attn, v, out_F, kq_indices)
        out_F = self.to_out(out_F.view(-1, self.out_channels).contiguous())

        return ME.SparseTensor(out_F, coordinate_map_key=out_key, coordinate_manager=stensor.coordinate_manager), kq_map


class EfficientPointTransformerLayerWithKeyApprox(EfficientPointTransformerLayerWithKey):
    """Keyed variant that L2-normalises query, key and positional table separately.

    Instead of dividing by a per-pair norm, all three inputs to
    ``dot_product_with_key_cuda`` are normalised up front.
    """

    def forward(self, stensor, points, kq_map=None):
        """Apply the layer and return ``(output sparse tensor, kq_map)``."""
        assert len(stensor) == len(points)
        heads, width = self.num_heads, self.attn_channels

        # add the intra-voxel positional encoding, then project to q / k / v
        stensor = stensor + self.intra_pos_mlp(points)
        query, key, value = self.to_qkv(stensor).split(self.out_channels, dim=-1)
        query = query.view(-1, heads, width).contiguous()
        key = key.view(-1, heads, width).contiguous()
        value = value.view(-1, heads, width).contiguous()

        manager = stensor.coordinate_manager
        # reuse the supplied map when there is one, otherwise build it
        if kq_map is not None:
            out_key = manager.stride(stensor.coordinate_key, self.kernel_generator.kernel_stride)
        else:
            kernel_map, out_key = self.get_kernel_map_and_out_key(stensor)
            kq_map = self.key_query_map_from_kernel_map(kernel_map)

        # similarity from unit-norm query, key and positional table
        num_pairs = kq_map.shape[1]
        scores = query.new(num_pairs, heads).zero_()
        unit_query = F.normalize(query, p=2, dim=-1)
        unit_key = F.normalize(key, p=2, dim=-1)
        unit_table = F.normalize(self.inter_pos_enc, p=2, dim=-1)
        scores = dot_product_with_key_cuda(unit_query, unit_key, unit_table, scores, kq_map)

        # weighted aggregation of values followed by the output projection
        pair_indices = self.key_query_indices_from_key_query_map(kq_map)
        gathered = stensor._F.new(len(query), heads, width).zero_()
        gathered = scalar_attention_cuda(scores, value, gathered, pair_indices)
        projected = self.to_out(gathered.view(-1, self.out_channels).contiguous())

        result = ME.SparseTensor(projected, coordinate_map_key=out_key, coordinate_manager=manager)
        return result, kq_map


class EfficientPointTransformerLayerWithKeySoftmax(EfficientPointTransformerLayerWithKey):
    """Keyed variant whose weights are a softmax over each query's neighbours.

    Similarities are arranged as a sparse ``(query, kernel offset, head)``
    tensor and normalised with ``torch.sparse.softmax`` along the kernel
    offset axis.
    """

    @torch.no_grad()
    def query_neighbor_indices_from_key_query_map(self, kq_map):
        """Convert a packed key/query map into (query, kernel offset) pairs.

        Row 0 of the result is the query index (row 1 of ``kq_map``); row 1 is
        the kernel offset recovered as ``kq_map[0] % kernel_volume``.
        """
        qn_indices = kq_map.clone()
        qn_indices[0] = kq_map[1]
        qn_indices[1] = kq_map[0] % self.kernel_volume

        return qn_indices

    def forward(self, stensor, points, kq_map=None):
        """Apply the layer.

        Args:
            stensor: Input ``ME.SparseTensor``.
            points: Per-point positions, one row per row of ``stensor``.
            kq_map: Optional packed key/query map; computed when ``None``.

        Returns:
            ``(output sparse tensor, kq_map)``.
        """
        assert len(stensor) == len(points)
        cm = stensor.coordinate_manager

        # query, key and value with intra-voxel relative positional encodings
        intra_pos_enc = self.intra_pos_mlp(points)
        stensor = stensor + intra_pos_enc
        q, k, v = self.to_qkv(stensor).split(self.out_channels, dim=-1)
        q = q.view(-1, self.num_heads, self.attn_channels).contiguous()
        k = k.view(-1, self.num_heads, self.attn_channels).contiguous()
        v = v.view(-1, self.num_heads, self.attn_channels).contiguous()
        M = len(q)

        # kernel map
        if kq_map is None:
            kernel_map, out_key = self.get_kernel_map_and_out_key(stensor)
            kq_map = self.key_query_map_from_kernel_map(kernel_map)
        else:
            out_key = cm.stride(stensor.coordinate_key, self.kernel_generator.kernel_stride)
        kq_indices = self.key_query_indices_from_key_query_map(kq_map)

        # dot-product similarity, one row per column of kq_map
        sim = stensor._F.new_zeros((kq_map.shape[1], self.num_heads))
        norm_q = F.normalize(q, dim=-1)
        norm_k = F.normalize(k, dim=-1)
        sim = dot_product_with_key_cuda(norm_q, norm_k, self.inter_pos_enc, sim, kq_map)

        # attention weights (softmax)
        sim = sim / math.sqrt(self.attn_channels)
        qn_indices = self.query_neighbor_indices_from_key_query_map(kq_map)
        qn_keys = qn_indices[0] * self.kernel_volume + qn_indices[1]
        _, order = torch.sort(qn_keys)
        sorted_kq_indices = torch.index_select(kq_indices, 1, order)
        sorted_qn_indices = torch.index_select(qn_indices, 1, order)
        # The similarities must follow the same permutation as the indices;
        # otherwise each weight ends up attached to a different pair.
        sorted_sim = torch.index_select(sim, 0, order)
        size = torch.Size([M, self.kernel_volume, self.num_heads])
        sp = torch.sparse_coo_tensor(sorted_qn_indices.long(), sorted_sim, size)
        sp = torch.sparse.softmax(sp, dim=1).coalesce()
        attn = sp.values()

        # coalesce() must not have reordered or merged entries, since attn is
        # consumed positionally together with sorted_kq_indices below
        assert torch.equal(sorted_qn_indices.long(), sp.indices())

        # aggregation & projection
        out_F = stensor._F.new_zeros((M, self.num_heads, self.attn_channels))
        out_F = scalar_attention_cuda(attn, v, out_F, sorted_kq_indices)
        out_F = self.to_out(out_F.view(-1, self.out_channels).contiguous())

        return ME.SparseTensor(out_F, coordinate_map_key=out_key, coordinate_manager=cm), kq_map


class EfficientPointTransformerLayerWithKeyScaled(EfficientPointTransformerLayerWithKeySoftmax):
    """Keyed variant that multiplies similarities by a learned per-head scale.

    No softmax is applied here: the scaled similarities are passed directly
    to ``scalar_attention_cuda``.
    """

    def __init__(
        self,
        in_channels,
        out_channels=None,
        kernel_size=3,
        stride=1,
        num_heads=8,
        bias=True,
        dimension=3
    ):
        super().__init__(in_channels, out_channels, kernel_size, stride, num_heads, bias, dimension)
        # one multiplicative scale per head, broadcast over all pairs
        self.scales = nn.Parameter(torch.FloatTensor(1, self.num_heads))
        nn.init.normal_(self.scales, 0, 1)

    @torch.no_grad()
    def query_neighbor_indices_from_key_query_map(self, kq_map):
        """Return (query index, kernel offset) rows derived from ``kq_map``."""
        pairs = kq_map.clone()
        pairs[0] = kq_map[1]
        pairs[1] = kq_map[0] % self.kernel_volume
        return pairs

    def forward(self, stensor, points, kq_map=None):
        """Apply the layer and return ``(output sparse tensor, kq_map)``."""
        assert len(stensor) == len(points)
        manager = stensor.coordinate_manager
        heads, width = self.num_heads, self.attn_channels

        # add the intra-voxel positional encoding, then project to q / k / v
        stensor = stensor + self.intra_pos_mlp(points)
        query, key, value = self.to_qkv(stensor).split(self.out_channels, dim=-1)
        query = query.view(-1, heads, width).contiguous()
        key = key.view(-1, heads, width).contiguous()
        value = value.view(-1, heads, width).contiguous()

        # reuse the supplied map when there is one, otherwise build it
        if kq_map is not None:
            out_key = manager.stride(stensor.coordinate_key, self.kernel_generator.kernel_stride)
        else:
            kernel_map, out_key = self.get_kernel_map_and_out_key(stensor)
            kq_map = self.key_query_map_from_kernel_map(kernel_map)

        # similarity from unit-norm inputs, then the learned per-head scale
        scores = stensor._F.new(kq_map.shape[1], heads).zero_()
        unit_query = F.normalize(query, dim=-1)
        unit_key = F.normalize(key, dim=-1)
        unit_table = F.normalize(self.inter_pos_enc, dim=-1)
        scores = dot_product_with_key_cuda(unit_query, unit_key, unit_table, scores, kq_map)
        scores = scores * self.scales

        # weighted aggregation of values followed by the output projection
        pair_indices = self.key_query_indices_from_key_query_map(kq_map)
        gathered = stensor._F.new(len(query), heads, width).zero_()
        gathered = scalar_attention_cuda(scores, value, gathered, pair_indices)
        projected = self.to_out(gathered.view(-1, self.out_channels).contiguous())

        result = ME.SparseTensor(projected, coordinate_map_key=out_key, coordinate_manager=manager)
        return result, kq_map

    
class EfficientPointTransformerLayerLinear(TransformerLayerBase):
    """Local attention with one linear positional encoder for both position kinds.

    ``to_pos_enc`` encodes the per-point ``points`` as well as the fixed table
    of integer kernel offsets (``inter_pos``); both encodings are handed to
    ``dot_product_intra_inter_cuda`` together with the layer-normalised query.
    Requires an odd ``kernel_size``.
    """

    def __init__(
        self,
        in_channels,
        out_channels=None,
        kernel_size=3,
        stride=1,
        num_heads=8,
        bias=True,
        dimension=3
    ):
        assert kernel_size % 2 == 1
        super().__init__(in_channels, out_channels, kernel_size, stride, num_heads, bias, dimension)
        # replaces the base-class query projection with one ending in LayerNorm
        query_conv = ME.MinkowskiConvolution(
            in_channels, self.out_channels, kernel_size=1, stride=stride, bias=bias, dimension=dimension
        )
        self.to_query = nn.Sequential(query_conv, ME.MinkowskiToFeature(), nn.LayerNorm(self.out_channels))
        self.to_pos_enc = nn.Linear(dimension, self.out_channels, bias=False)
        # fixed (non-trainable) table of kernel offsets
        self.inter_pos = nn.Parameter(self.enumerate_inter_pos(), requires_grad=False)

    def enumerate_inter_pos(self):
        """List every integer offset ``[x, y, z]`` of the kernel.

        Each axis runs from ``+kernel_size // 2`` down to ``-kernel_size // 2``
        with x varying fastest and z slowest.
        """
        # note that MinkowskiEngine uses key - query but we need query - key
        reach = self.kernel_size // 2
        descending = range(reach, -reach - 1, -1)
        offsets = [[dx, dy, dz] for dz in descending for dy in descending for dx in descending]
        return torch.FloatTensor(offsets)

    def forward(self, stensor, points, kq_map=None):
        """Apply the layer and return ``(output sparse tensor, kq_map)``."""
        assert len(stensor) == len(points)
        heads, width = self.num_heads, self.attn_channels
        manager = stensor.coordinate_manager

        # per-head query and value features
        query = self.to_query(stensor).view(-1, heads, width).contiguous()
        value = self.to_value(stensor).view(-1, heads, width).contiguous()

        # reuse the supplied map when there is one, otherwise build it
        if kq_map is not None:
            out_key = manager.stride(stensor.coordinate_key, self.kernel_generator.kernel_stride)
        else:
            kernel_map, out_key = self.get_kernel_map_and_out_key(stensor)
            kq_map = self.key_query_map_from_kernel_map(kernel_map)
        pair_indices = self.key_query_indices_from_key_query_map(kq_map)

        # similarity from the query and the two positional encodings
        scores = stensor._F.new(kq_map.shape[1], heads).zero_()
        point_enc = self.to_pos_enc(points).view(-1, heads, width).contiguous()
        offset_enc = self.to_pos_enc(self.inter_pos).view(-1, heads, width).contiguous()
        scores = dot_product_intra_inter_cuda(query, point_enc, offset_enc, scores, kq_map)

        # weighted aggregation of values followed by the output projection
        gathered = stensor._F.new(len(query), heads, width).zero_()
        gathered = scalar_attention_cuda(scores, value, gathered, pair_indices)
        projected = self.to_out(gathered.view(-1, self.out_channels).contiguous())

        result = ME.SparseTensor(projected, coordinate_map_key=out_key, coordinate_manager=manager)
        return result, kq_map
    
    
class EfficientPointTransformerLayerLinearSigmoid(EfficientPointTransformerLayerLinear):
    """Linear-positional variant with sigmoid-squashed attention scores.

    The query projection is rebuilt without the trailing LayerNorm used by the
    parent class, and the similarities pass through ``torch.sigmoid`` before
    aggregation.
    """

    def __init__(
        self,
        in_channels,
        out_channels=None,
        kernel_size=3,
        stride=1,
        num_heads=8,
        bias=True,
        dimension=3
    ):
        super().__init__(in_channels, out_channels, kernel_size, stride, num_heads, bias, dimension)
        # replaces the parent's query projection (no LayerNorm here)
        query_conv = ME.MinkowskiConvolution(
            in_channels, self.out_channels, kernel_size=1, stride=stride, bias=bias, dimension=dimension
        )
        self.to_query = nn.Sequential(query_conv, ME.MinkowskiToFeature())

    def forward(self, stensor, points, kq_map=None):
        """Apply the layer and return ``(output sparse tensor, kq_map)``."""
        assert len(stensor) == len(points)
        heads, width = self.num_heads, self.attn_channels
        manager = stensor.coordinate_manager

        # per-head query and value features
        query = self.to_query(stensor).view(-1, heads, width).contiguous()
        value = self.to_value(stensor).view(-1, heads, width).contiguous()

        # reuse the supplied map when there is one, otherwise build it
        if kq_map is not None:
            out_key = manager.stride(stensor.coordinate_key, self.kernel_generator.kernel_stride)
        else:
            kernel_map, out_key = self.get_kernel_map_and_out_key(stensor)
            kq_map = self.key_query_map_from_kernel_map(kernel_map)
        pair_indices = self.key_query_indices_from_key_query_map(kq_map)

        # similarity from the query and the two positional encodings, squashed to (0, 1)
        scores = stensor._F.new(kq_map.shape[1], heads).zero_()
        point_enc = self.to_pos_enc(points).view(-1, heads, width).contiguous()
        offset_enc = self.to_pos_enc(self.inter_pos).view(-1, heads, width).contiguous()
        scores = dot_product_intra_inter_cuda(query, point_enc, offset_enc, scores, kq_map)
        scores = torch.sigmoid(scores)

        # weighted aggregation of values followed by the output projection
        gathered = stensor._F.new(len(query), heads, width).zero_()
        gathered = scalar_attention_cuda(scores, value, gathered, pair_indices)
        projected = self.to_out(gathered.view(-1, self.out_channels).contiguous())

        result = ME.SparseTensor(projected, coordinate_map_key=out_key, coordinate_manager=manager)
        return result, kq_map
    
    
class EfficientPointTransformerLayerLipschitzShared(TransformerLayerBase):
    """Local attention using a spectrally-normalised MLP evaluated on a grid.

    Relative positions are quantised onto a regular grid of
    ``granularity`` cells per axis; the positional MLP is evaluated once per
    forward pass on the grid centres (``sampled_rel_pos``) and each key/query
    pair looks up the encoding of its cell. In this "shared" variant the MLP
    outputs ``attn_channels`` features.

    Args:
        granularity: Number of grid cells per axis.
        (remaining arguments are forwarded to ``TransformerLayerBase``)

    Requires an odd ``kernel_size``. The grid construction and the flat
    sample index are written for three spatial axes.
    """

    def __init__(
        self,
        in_channels,
        out_channels=None,
        kernel_size=3,
        granularity=7,
        stride=1,
        num_heads=8,
        bias=True,
        dimension=3
    ):
        assert kernel_size % 2 == 1
        super(EfficientPointTransformerLayerLipschitzShared, self).__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            num_heads,
            bias,
            dimension
        )
        self.granularity = granularity
        self.num_samples = int(granularity ** dimension)
        self.rel_pos_bound = self.kernel_size // 2 + 1 # if kernel_size = 3, bound = 2.

        # Seven spectrally-normalised linear layers separated by LeakyReLU(0.1):
        # dimension -> dimension -> attn_channels -> ... -> attn_channels.
        widths = [dimension, dimension] + [self.attn_channels] * 6
        layers = []
        for position, (fan_in, fan_out) in enumerate(zip(widths, widths[1:])):
            if position:
                layers.append(nn.LeakyReLU(0.1, inplace=True))
            layers.append(spectral_norm(nn.Linear(fan_in, fan_out)))
        self.rel_pos_mlp = nn.Sequential(*layers)
        self.register_buffer("sampled_rel_pos", self.sample_relative_position())

    @torch.no_grad()
    def sample_relative_position(self):
        """Return the centres of all grid cells as a ``(granularity**3, 3)`` tensor.

        Cell centres are spaced evenly inside ``(-rel_pos_bound, rel_pos_bound)``
        on every axis. Column 0 varies fastest, which matches the flat index
        computed by ``get_sample_indices``.
        """
        grid_length = 2 * self.rel_pos_bound / self.granularity
        half = grid_length / 2
        pos_1d = torch.linspace(
            -self.rel_pos_bound + half,
            self.rel_pos_bound - half,
            self.granularity,
            dtype=torch.float32
        )
        grid_x, grid_y, grid_z = torch.meshgrid([pos_1d, pos_1d, pos_1d])
        return torch.stack([grid_z.reshape(-1), grid_y.reshape(-1), grid_x.reshape(-1)], dim=1)

    def forward_sampled_rel_pos(self):
        """Encode the grid centres with ``rel_pos_mlp``.

        While the layer is in training mode, every sub-module of the MLP is
        switched to eval mode first and is not switched back afterwards.
        """
        if self.training:
            # NOTE(review): the original expression always evaluated to False;
            # presumably this is meant to stop the spectral-norm power
            # iteration from running during this pass -- confirm the intent.
            for m in self.rel_pos_mlp:
                m.training = False
        return self.rel_pos_mlp(self.sampled_rel_pos)

    @torch.no_grad()
    def get_relative_position_from_kernel_map(self, points, kernel_map):
        """Return ``(points[query] - points[key], kq_indices)`` for a kernel map."""
        kq_indices = self.key_query_indices_from_kernel_map(kernel_map)
        rel_pos = points[kq_indices[1].long()] - points[kq_indices[0].long()] # where, points = 3D points / stride

        return rel_pos, kq_indices

    @torch.no_grad()
    def get_sample_indices(self, rel_pos):
        """Map relative positions to flat grid-cell indices.

        The input is expected to lie strictly inside
        ``(-rel_pos_bound, rel_pos_bound)`` on every axis; it is not modified
        (the shift and scale are done out of place).

        Args:
            rel_pos: ``(num_pairs, 3)`` tensor of relative positions.

        Returns:
            Int tensor of length ``num_pairs`` with values
            ``c0 + granularity * c1 + granularity**2 * c2``.
        """
        cells_per_unit = self.granularity / (2 * self.rel_pos_bound)
        shifted = rel_pos + self.rel_pos_bound # (0, 2 * bound)
        cells = torch.floor(shifted * cells_per_unit).int() # quantization: (0, .., granularity - 1)
        return cells[:, 0] + self.granularity * cells[:, 1] + self.granularity**2 * cells[:, 2]

    def forward(self, stensor, points, kq_indices=None, sq_indices=None):
        """Apply the layer.

        Args:
            stensor: Input ``ME.SparseTensor``.
            points: Per-point positions, one row per row of ``stensor``.
            kq_indices: Optional 2-row key/query index tensor.
            sq_indices: Optional 2-row (sample index, query index) tensor.
                Both index tensors are recomputed unless both are given.

        Returns:
            ``(output sparse tensor, kq_indices, sq_indices)``.
        """
        assert len(stensor) == len(points)

        # query and value
        q = self.to_query(stensor).view(-1, self.num_heads, self.attn_channels).contiguous()
        v = self.to_value(stensor).view(-1, self.num_heads, self.attn_channels).contiguous()
        num_queries = len(q)

        # relative positional encodings
        if kq_indices is None or sq_indices is None:
            kernel_map, out_key = self.get_kernel_map_and_out_key(stensor)
            rel_pos, kq_indices = self.get_relative_position_from_kernel_map(points, kernel_map)
            s_indices = self.get_sample_indices(rel_pos)
            sq_indices = torch.vstack([s_indices, kq_indices[1]])
        else:
            cm = stensor.coordinate_manager
            out_key = cm.stride(stensor.coordinate_key, self.kernel_generator.kernel_stride)

        # dot-product similarity; the CUDA op fills the pre-allocated buffer
        rel_pos_enc = self.forward_sampled_rel_pos()
        attn = stensor._F.new_zeros((kq_indices.shape[1], self.num_heads))
        norm_q = F.normalize(q, p=2, dim=-1)
        norm_pos_enc = F.normalize(rel_pos_enc, p=2, dim=-1)
        attn = dot_product_sample_shared_cuda(norm_q, norm_pos_enc, attn, sq_indices)

        # aggregation & projection
        out_F = stensor._F.new_zeros((num_queries, self.num_heads, self.attn_channels))
        out_F = scalar_attention_cuda(attn, v, out_F, kq_indices)
        out_F = self.to_out(out_F.view(-1, self.out_channels).contiguous())

        return ME.SparseTensor(out_F, coordinate_map_key=out_key, coordinate_manager=stensor.coordinate_manager), kq_indices, sq_indices
    

class EfficientPointTransformerLayerLipschitz(EfficientPointTransformerLayerLipschitzShared):
    """Lipschitz layer whose positional encoding has one slice per head.

    The parent's ``rel_pos_mlp`` is replaced by a spectrally normalised MLP
    ending in ``out_channels`` features; ``forward`` reshapes that output to
    ``(-1, num_heads, attn_channels)`` before computing similarities.
    """

    def __init__(
        self,
        in_channels,
        out_channels=None,
        kernel_size=3,
        granularity=7,
        stride=1,
        num_heads=8,
        bias=True,
        dimension=3
    ):
        """Build the parent layer, then override ``rel_pos_mlp``.

        All arguments are forwarded unchanged to the parent constructor.
        """
        super().__init__(
            in_channels, out_channels, kernel_size, granularity,
            stride, num_heads, bias, dimension
        )

        quarter = self.out_channels // 4
        half = self.out_channels // 2

        # Hidden widths; every hidden layer is followed by a LeakyReLU.
        hidden = (dimension, dimension, quarter, half, half, half, half)
        layers = []
        for fan_in, fan_out in zip(hidden, hidden[1:]):
            layers += [
                spectral_norm(nn.Linear(fan_in, fan_out)),
                nn.LeakyReLU(0.1, inplace=True),
            ]
        # Output projection has no activation after it.
        layers.append(spectral_norm(nn.Linear(half, self.out_channels)))
        self.rel_pos_mlp = nn.Sequential(*layers)  # overriding

    def forward(self, stensor, points, kq_indices=None, sq_indices=None):
        """Run the layer on a sparse tensor and its associated points.

        Args:
            stensor: input ``ME.SparseTensor``.
            points: per-row point data; must have the same length as
                ``stensor``. Only consumed when the index tensors are rebuilt.
            kq_indices: optional cached index tensor whose second dimension is
                the number of attended pairs.
            sq_indices: optional cached tensor stacking the sample indices on
                top of ``kq_indices[1]``.

        Returns:
            ``(output, kq_indices, sq_indices)`` where ``output`` is an
            ``ME.SparseTensor`` and the index tensors are the ones used here.
        """
        assert len(stensor) == len(points)

        heads, head_dim = self.num_heads, self.attn_channels

        # Project the input and split the channels into heads.
        q = self.to_query(stensor).view(-1, heads, head_dim).contiguous()
        v = self.to_value(stensor).view(-1, heads, head_dim).contiguous()
        num_queries = len(q)

        # The cache is only trusted when *both* index tensors were supplied.
        if kq_indices is not None and sq_indices is not None:
            out_key = stensor.coordinate_manager.stride(
                stensor.coordinate_key, self.kernel_generator.kernel_stride
            )
        else:
            kernel_map, out_key = self.get_kernel_map_and_out_key(stensor)
            rel_pos, kq_indices = self.get_relative_position_from_kernel_map(points, kernel_map)
            sample_ids = self.get_sample_indices(rel_pos)
            sq_indices = torch.vstack([sample_ids, kq_indices[1]])

        # Per-head positional encodings, L2-normalised like the queries.
        pos_enc = self.forward_sampled_rel_pos()
        pos_enc = pos_enc.view(-1, heads, head_dim).contiguous()
        unit_pos_enc = F.normalize(pos_enc, p=2, dim=-1)
        unit_q = F.normalize(q, p=2, dim=-1)

        scores = stensor._F.new_zeros(kq_indices.shape[1], heads)
        scores = dot_product_sample_cuda(unit_q, unit_pos_enc, scores, sq_indices)

        # Aggregate the values with those scores, then project.
        gathered = stensor._F.new_zeros(num_queries, heads, head_dim)
        gathered = scalar_attention_cuda(scores, v, gathered, kq_indices)
        out_F = self.to_out(gathered.view(-1, self.out_channels).contiguous())

        result = ME.SparseTensor(
            out_F,
            coordinate_map_key=out_key,
            coordinate_manager=stensor.coordinate_manager,
        )
        return result, kq_indices, sq_indices
    

class EfficientPointTransformerLayerMappingShared(EfficientPointTransformerLayerLipschitzShared):
    """Shared-encoding layer using a Linear/BatchNorm/ReLU positional MLP.

    Replaces the parent's ``rel_pos_mlp`` with an MLP whose output width is
    ``attn_channels`` and overrides ``forward_sampled_rel_pos`` to apply that
    MLP directly to ``self.sampled_rel_pos``.
    """

    def __init__(
        self,
        in_channels,
        out_channels=None,
        kernel_size=3,
        granularity=7,
        stride=1,
        num_heads=8,
        bias=True,
        dimension=3
    ):
        """Build the parent layer, then override ``rel_pos_mlp``.

        All arguments are forwarded unchanged to the parent constructor.
        """
        super().__init__(
            in_channels, out_channels, kernel_size, granularity,
            stride, num_heads, bias, dimension
        )

        head_dim = self.attn_channels

        # Six bias-free Linear -> BatchNorm -> ReLU stages ...
        widths = (dimension, dimension, head_dim, head_dim, head_dim, head_dim, head_dim)
        layers = []
        for fan_in, fan_out in zip(widths, widths[1:]):
            layers += [
                nn.Linear(fan_in, fan_out, bias=False),
                ME.MinkowskiBatchNorm(fan_out),
                nn.ReLU(inplace=True),
            ]
        # ... followed by a plain (biased) output projection.
        layers.append(nn.Linear(head_dim, head_dim))
        self.rel_pos_mlp = nn.Sequential(*layers)  # overriding

    def forward_sampled_rel_pos(self):
        """Return the positional encodings of the sampled relative positions."""
        encodings = self.rel_pos_mlp(self.sampled_rel_pos)  # overriding
        return encodings

    
class EfficientPointTransformerLayerMapping(EfficientPointTransformerLayerLipschitz):
    """Per-head layer using a Linear/BatchNorm/ReLU positional MLP.

    Replaces the parent's ``rel_pos_mlp`` with an MLP ending in
    ``out_channels`` features and overrides ``forward_sampled_rel_pos`` to
    apply that MLP directly to ``self.sampled_rel_pos``.
    """

    def __init__(
        self,
        in_channels,
        out_channels=None,
        kernel_size=3,
        granularity=7,
        stride=1,
        num_heads=8,
        bias=True,
        dimension=3
    ):
        """Build the parent layer, then override ``rel_pos_mlp``.

        All arguments are forwarded unchanged to the parent constructor.
        """
        super().__init__(
            in_channels, out_channels, kernel_size, granularity,
            stride, num_heads, bias, dimension
        )

        quarter = self.out_channels // 4
        half = self.out_channels // 2

        # NOTE(review): ME.MinkowskiBatchNorm sits between plain nn.Linear
        # layers here; confirm it accepts whatever type
        # ``self.sampled_rel_pos`` is.
        widths = (dimension, dimension, quarter, half, half, half, half)
        layers = []
        for fan_in, fan_out in zip(widths, widths[1:]):
            layers += [
                nn.Linear(fan_in, fan_out, bias=False),
                ME.MinkowskiBatchNorm(fan_out),
                nn.ReLU(inplace=True),
            ]
        # Final projection keeps its bias and has no norm/activation.
        layers.append(nn.Linear(half, self.out_channels))
        self.rel_pos_mlp = nn.Sequential(*layers)  # overriding

    def forward_sampled_rel_pos(self):
        """Return the positional encodings of the sampled relative positions."""
        encodings = self.rel_pos_mlp(self.sampled_rel_pos)  # overriding
        return encodings
        
        
class EfficientPointTransformerLayerMappingLighter(EfficientPointTransformerLayerMapping):
    """Mapping layer with a shallower positional-encoding MLP.

    Only ``rel_pos_mlp`` differs from the parent: two Linear/BatchNorm/ReLU
    stages followed by an output projection of width ``out_channels``.
    """

    def __init__(
        self,
        in_channels,
        out_channels=None,
        kernel_size=3,
        granularity=7,
        stride=1,
        num_heads=8,
        bias=True,
        dimension=3
    ):
        """Build the parent layer, then override ``rel_pos_mlp``.

        All arguments are forwarded unchanged to the parent constructor.
        """
        super().__init__(
            in_channels, out_channels, kernel_size, granularity,
            stride, num_heads, bias, dimension
        )

        width = self.out_channels
        stages = [
            # stage 1: stay in coordinate space
            nn.Linear(dimension, dimension, bias=False),
            ME.MinkowskiBatchNorm(dimension),
            nn.ReLU(inplace=True),
            # stage 2: lift to the full channel width
            nn.Linear(dimension, width, bias=False),
            ME.MinkowskiBatchNorm(width),
            nn.ReLU(inplace=True),
            # output projection (biased, no norm/activation)
            nn.Linear(width, width),
        ]
        self.rel_pos_mlp = nn.Sequential(*stages)  # overriding
        
        
class EfficientPointTransformerLayerMappingFiner(EfficientPointTransformerLayerMapping):
    """Mapping layer whose granularity is derived from the kernel size.

    The ``granularity`` argument is kept for signature compatibility but is
    ignored: ``4 * (kernel_size // 2 + 1) + 1`` is passed to the parent
    constructor instead.
    """

    def __init__(
        self,
        in_channels,
        out_channels=None,
        kernel_size=3,
        granularity=7,
        stride=1,
        num_heads=8,
        bias=True,
        dimension=3
    ):
        """Forward all arguments to the parent, overriding ``granularity``."""
        finer_granularity = 4 * (kernel_size // 2 + 1) + 1
        super().__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            granularity=finer_granularity,
            stride=stride,
            num_heads=num_heads,
            bias=bias,
            dimension=dimension,
        )
        
        
class EfficientPointTransformerLayerLipschitzFiner(EfficientPointTransformerLayerLipschitz):
    """Lipschitz layer whose granularity is derived from the kernel size.

    The ``granularity`` argument is kept for signature compatibility but is
    ignored: ``4 * (kernel_size // 2 + 1) + 1`` is passed to the parent
    constructor instead.
    """

    def __init__(
        self,
        in_channels,
        out_channels=None,
        kernel_size=3,
        granularity=7,
        stride=1,
        num_heads=8,
        bias=True,
        dimension=3
    ):
        """Forward all arguments to the parent, overriding ``granularity``."""
        finer_granularity = 4 * (kernel_size // 2 + 1) + 1
        super().__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            granularity=finer_granularity,
            stride=stride,
            num_heads=num_heads,
            bias=bias,
            dimension=dimension,
        )
        
        
class EfficientPointTransformerLayerMHSA(EfficientPointTransformerLayerMappingFiner):
    """Attention variant with explicit keys and a sparse softmax.

    Keys and values come from a single fused 1x1 projection
    (``to_key_value``). Scores are produced by
    ``dot_product_sample_with_key_cuda`` from queries, keys and per-head
    positional encodings, scaled by ``1 / sqrt(attn_channels)`` and
    normalised with ``torch.sparse.softmax`` along the ``kq_indices[0]`` axis.

    The ``granularity`` argument is forwarded to the parent, which replaces it
    with a value derived from ``kernel_size``.
    """

    def __init__(
        self,
        in_channels,
        out_channels=None,
        kernel_size=3,
        granularity=7,
        stride=1,
        num_heads=8,
        bias=True,
        dimension=3
    ):
        """Build the parent layer and swap ``to_value`` for ``to_key_value``."""
        super(EfficientPointTransformerLayerMHSA, self).__init__(
            in_channels,
            out_channels,
            kernel_size,
            granularity,
            stride,
            num_heads,
            bias,
            dimension
        )
        # Keys and values share one projection, so the stand-alone value
        # projection is removed.
        delattr(self, "to_value")
        self.to_key_value = nn.Sequential(
            ME.MinkowskiConvolution(in_channels, 2 * self.out_channels, kernel_size=1, bias=bias, dimension=dimension),
            ME.MinkowskiToFeature()
        )

    def forward(self, stensor, points, kq_indices=None, sq_indices=None):
        """Run the layer on a sparse tensor and its associated points.

        Args:
            stensor: input ``ME.SparseTensor``.
            points: per-row point data; must have the same length as
                ``stensor``. Only consumed when the index tensors are rebuilt.
            kq_indices: optional cached index tensor with two rows; row 1
                indexes queries (it sizes the ``num_queries`` axis of the
                sparse score tensor), row 0 is the axis the softmax runs over.
            sq_indices: optional cached tensor stacking the sample indices on
                top of ``kq_indices[1]``.

        Returns:
            ``(output, kq_indices, sq_indices)`` where ``output`` is an
            ``ME.SparseTensor``. Both index tensors are always populated on
            return so they can be passed back in to skip the rebuild.
        """
        assert len(stensor) == len(points)

        # query, key and value
        q = self.to_query(stensor)
        kv = self.to_key_value(stensor)
        k, v = kv.split(self.out_channels, dim=1)
        q = q.view(-1, self.num_heads, self.attn_channels).contiguous()
        k = k.view(-1, self.num_heads, self.attn_channels).contiguous()
        v = v.view(-1, self.num_heads, self.attn_channels).contiguous()
        num_queries = len(q)

        # relative positional encodings
        if kq_indices is None or sq_indices is None:
            kernel_map, out_key = self.get_kernel_map_and_out_key(stensor)
            rel_pos, kq_indices = self.get_relative_position_from_kernel_map(points, kernel_map)
            s_indices = self.get_sample_indices(rel_pos)
            # Populate sq_indices in the same layout the sibling layers use.
            # Previously it was left as None here, so the returned "cache"
            # always forced the next call to rebuild the kernel map.
            sq_indices = torch.vstack([s_indices, kq_indices[1]])
            skq_indices = torch.vstack([s_indices, kq_indices])
        else:
            cm = stensor.coordinate_manager
            out_key = cm.stride(stensor.coordinate_key, self.kernel_generator.kernel_stride)
            skq_indices = torch.vstack([sq_indices[0], kq_indices])

        # dot-product similarity
        rel_pos_enc = self.forward_sampled_rel_pos()
        rel_pos_enc = rel_pos_enc.view(-1, self.num_heads, self.attn_channels).contiguous()
        attn = stensor._F.new_zeros(kq_indices.shape[1], self.num_heads)
        attn = dot_product_sample_with_key_cuda(q, k, rel_pos_enc, attn, skq_indices)
        attn = attn / math.sqrt(self.attn_channels)

        # Softmax along dim 0 (the kq_indices[0] axis) for every
        # (query, head) pair, over the stored entries only.
        attn = torch.sparse_coo_tensor(
            kq_indices,
            attn,
            (kq_indices[0].max() + 1, num_queries, self.num_heads)
        )
        attn = torch.sparse.softmax(attn, dim=0)
        attn = attn.coalesce()
        # Use the coalesced indices so they stay aligned with the values.
        kq_indices_ = attn.indices().int()
        attn = attn.values()

        # aggregation & projection
        out_F = stensor._F.new_zeros(num_queries, self.num_heads, self.attn_channels)
        out_F = scalar_attention_cuda(attn, v, out_F, kq_indices_)
        out_F = self.to_out(out_F.view(-1, self.out_channels).contiguous())

        return ME.SparseTensor(out_F, coordinate_map_key=out_key, coordinate_manager=stensor.coordinate_manager), kq_indices, sq_indices
    
    
class EfficientPointTransformerLayerMappingFinerWithKey(EfficientPointTransformerLayerMHSA):
    """Key-aware layer using normalised similarities instead of a softmax.

    Reuses the fused ``to_key_value`` projection of the parent but feeds
    L2-normalised queries, keys and positional encodings to
    ``dot_product_sample_with_key_cuda`` and uses the result directly as
    aggregation weights (no scaling, no softmax). ``rel_pos_mlp`` is replaced
    by a shallower Linear/BatchNorm/ReLU stack ending in ``out_channels``.
    """

    def __init__(
        self,
        in_channels,
        out_channels=None,
        kernel_size=3,
        granularity=7,
        stride=1,
        num_heads=8,
        bias=True,
        dimension=3
    ):
        """Build the parent layer, then override ``rel_pos_mlp``.

        All arguments are forwarded unchanged to the parent constructor.
        """
        super(EfficientPointTransformerLayerMappingFinerWithKey, self).__init__(
            in_channels,
            out_channels,
            kernel_size,
            granularity,
            stride,
            num_heads,
            bias,
            dimension
        )
        self.rel_pos_mlp = nn.Sequential(
            nn.Linear(dimension, dimension, bias=False),
            ME.MinkowskiBatchNorm(dimension),
            nn.ReLU(inplace=True),
            nn.Linear(dimension, self.out_channels // 4, bias=False),
            ME.MinkowskiBatchNorm(self.out_channels // 4),
            nn.ReLU(inplace=True),
            nn.Linear(self.out_channels // 4, self.out_channels // 2, bias=False),
            ME.MinkowskiBatchNorm(self.out_channels // 2),
            nn.ReLU(inplace=True),
            nn.Linear(self.out_channels // 2, self.out_channels)
        ) # overriding

    def forward(self, stensor, points, kq_indices=None, sq_indices=None):
        """Run the layer on a sparse tensor and its associated points.

        Args:
            stensor: input ``ME.SparseTensor``.
            points: per-row point data; must have the same length as
                ``stensor``. Only consumed when the index tensors are rebuilt.
            kq_indices: optional cached index tensor whose second dimension is
                the number of attended pairs.
            sq_indices: optional cached tensor stacking the sample indices on
                top of ``kq_indices[1]``.

        Returns:
            ``(output, kq_indices, sq_indices)`` where ``output`` is an
            ``ME.SparseTensor``. Both index tensors are always populated on
            return so they can be passed back in to skip the rebuild.
        """
        assert len(stensor) == len(points)

        # query, key and value
        q = self.to_query(stensor)
        kv = self.to_key_value(stensor)
        k, v = kv.split(self.out_channels, dim=1)
        q = q.view(-1, self.num_heads, self.attn_channels).contiguous()
        k = k.view(-1, self.num_heads, self.attn_channels).contiguous()
        v = v.view(-1, self.num_heads, self.attn_channels).contiguous()
        num_queries = len(q)

        # relative positional encodings
        if kq_indices is None or sq_indices is None:
            kernel_map, out_key = self.get_kernel_map_and_out_key(stensor)
            rel_pos, kq_indices = self.get_relative_position_from_kernel_map(points, kernel_map)
            s_indices = self.get_sample_indices(rel_pos)
            # Populate sq_indices in the same layout the sibling layers use.
            # Previously it was left as None here, so the returned "cache"
            # always forced the next call to rebuild the kernel map.
            sq_indices = torch.vstack([s_indices, kq_indices[1]])
            skq_indices = torch.vstack([s_indices, kq_indices])
        else:
            cm = stensor.coordinate_manager
            out_key = cm.stride(stensor.coordinate_key, self.kernel_generator.kernel_stride)
            skq_indices = torch.vstack([sq_indices[0], kq_indices])

        # dot-product similarity on L2-normalised inputs
        rel_pos_enc = self.forward_sampled_rel_pos()
        rel_pos_enc = rel_pos_enc.view(-1, self.num_heads, self.attn_channels).contiguous()
        norm_q = F.normalize(q, dim=-1)
        norm_k = F.normalize(k, dim=-1)
        norm_pos_enc = F.normalize(rel_pos_enc, dim=-1)
        attn = stensor._F.new_zeros(kq_indices.shape[1], self.num_heads)
        attn = dot_product_sample_with_key_cuda(norm_q, norm_k, norm_pos_enc, attn, skq_indices)

        # aggregation & projection
        out_F = stensor._F.new_zeros(num_queries, self.num_heads, self.attn_channels)
        out_F = scalar_attention_cuda(attn, v, out_F, kq_indices)
        out_F = self.to_out(out_F.view(-1, self.out_channels).contiguous())

        return ME.SparseTensor(out_F, coordinate_map_key=out_key, coordinate_manager=stensor.coordinate_manager), kq_indices, sq_indices
    
    
class EfficientPointTransformerLayerToken(EfficientPointTransformerLayerMappingFiner):
    """Layer whose positional encodings are a learned lookup table.

    The constructor calls ``TransformerLayerBase.__init__`` directly, skipping
    the intermediate parent constructors, and registers ``rel_pos_enc``: a
    parameter of shape ``(granularity ** dimension, num_heads, attn_channels)``
    that ``forward`` uses in place of an MLP output.
    """

    def __init__(
        self,
        in_channels,
        out_channels=None,
        kernel_size=3,
        granularity=7,
        stride=1,
        num_heads=8,
        bias=True,
        dimension=3
    ):
        """Initialise the base projections and the positional-encoding table."""
        # Deliberately bypasses the intermediate classes in the hierarchy.
        TransformerLayerBase.__init__(
            self, in_channels, out_channels, kernel_size, stride, num_heads, bias, dimension
        )

        self.granularity = granularity
        self.num_samples = int(granularity ** dimension)
        self.rel_pos_bound = self.kernel_size // 2 + 1  # if kernel_size = 3, bound = 2.

        # One encoding per quantised relative position and head.
        table = torch.FloatTensor(self.num_samples, self.num_heads, self.attn_channels)
        self.rel_pos_enc = nn.Parameter(table)
        nn.init.normal_(self.rel_pos_enc, mean=0, std=1)

    def forward(self, stensor, points, kq_indices=None, sq_indices=None):
        """Run the layer on a sparse tensor and its associated points.

        Args:
            stensor: input ``ME.SparseTensor``.
            points: per-row point data; must have the same length as
                ``stensor``. Only consumed when the index tensors are rebuilt.
            kq_indices: optional cached index tensor whose second dimension is
                the number of attended pairs.
            sq_indices: optional cached tensor stacking the sample indices on
                top of ``kq_indices[1]``.

        Returns:
            ``(output, kq_indices, sq_indices)`` where ``output`` is an
            ``ME.SparseTensor`` and the index tensors are the ones used here.
        """
        assert len(stensor) == len(points)

        heads, head_dim = self.num_heads, self.attn_channels

        # Project the input and split the channels into heads.
        q = self.to_query(stensor).view(-1, heads, head_dim).contiguous()
        v = self.to_value(stensor).view(-1, heads, head_dim).contiguous()
        num_queries = len(q)

        # The cache is only trusted when *both* index tensors were supplied.
        if kq_indices is not None and sq_indices is not None:
            out_key = stensor.coordinate_manager.stride(
                stensor.coordinate_key, self.kernel_generator.kernel_stride
            )
        else:
            kernel_map, out_key = self.get_kernel_map_and_out_key(stensor)
            rel_pos, kq_indices = self.get_relative_position_from_kernel_map(points, kernel_map)
            sample_ids = self.get_sample_indices(rel_pos)
            sq_indices = torch.vstack([sample_ids, kq_indices[1]])

        # Similarity between L2-normalised queries and table entries.
        unit_q = F.normalize(q, p=2, dim=-1)
        unit_table = F.normalize(self.rel_pos_enc, p=2, dim=-1)
        scores = stensor._F.new_zeros(kq_indices.shape[1], heads)
        scores = dot_product_sample_cuda(unit_q, unit_table, scores, sq_indices)

        # Aggregate the values with those scores, then project.
        gathered = stensor._F.new_zeros(num_queries, heads, head_dim)
        gathered = scalar_attention_cuda(scores, v, gathered, kq_indices)
        out_F = self.to_out(gathered.view(-1, self.out_channels).contiguous())

        result = ME.SparseTensor(
            out_F,
            coordinate_map_key=out_key,
            coordinate_manager=stensor.coordinate_manager,
        )
        return result, kq_indices, sq_indices
    

class PointTransformerLayerShared(PointTransformerLayer):
    """Point-transformer layer with a positional encoding of head width.

    ``rel_pos_mlp`` maps relative positions to ``attn_channels`` features, and
    ``forward`` passes the L2-normalised result together with the normalised
    queries to ``direct_dot_product_shared_cuda``.
    """

    def __init__(
        self,
        in_channels,
        out_channels=None,
        kernel_size=3,
        stride=1,
        num_heads=8,
        bias=True,
        dimension=3
    ):
        """Build the parent layer, then define ``rel_pos_mlp``.

        All arguments are forwarded unchanged to the parent constructor.
        """
        super().__init__(
            in_channels, out_channels, kernel_size, stride, num_heads, bias, dimension
        )

        head_dim = self.attn_channels
        stages = [
            # stage 1: stay in coordinate space
            nn.Linear(dimension, dimension, bias=False),
            ME.MinkowskiBatchNorm(dimension),
            nn.ReLU(inplace=True),
            # stage 2: lift to the per-head width
            nn.Linear(dimension, head_dim, bias=False),
            ME.MinkowskiBatchNorm(head_dim),
            nn.ReLU(inplace=True),
            # output projection (biased, no norm/activation)
            nn.Linear(head_dim, head_dim),
        ]
        self.rel_pos_mlp = nn.Sequential(*stages)

    def forward(self, stensor, points, kq_indices=None, rel_pos=None):
        """Run the layer on a sparse tensor and its associated points.

        Args:
            stensor: input ``ME.SparseTensor``.
            points: per-row point data; must have the same length as
                ``stensor``. Only consumed when the cache is rebuilt.
            kq_indices: optional cached index tensor whose second dimension is
                the number of attended pairs.
            rel_pos: optional cached relative positions matching
                ``kq_indices``.

        Returns:
            ``(output, kq_indices, rel_pos)`` where ``output`` is an
            ``ME.SparseTensor`` and the other two are the values used here.
        """
        assert len(stensor) == len(points)

        heads, head_dim = self.num_heads, self.attn_channels

        # Project the input and split the channels into heads.
        q = self.to_query(stensor).view(-1, heads, head_dim).contiguous()
        v = self.to_value(stensor).view(-1, heads, head_dim).contiguous()
        num_queries = len(q)

        # The cache is only trusted when *both* entries were supplied.
        if kq_indices is not None and rel_pos is not None:
            out_key = stensor.coordinate_manager.stride(
                stensor.coordinate_key, self.kernel_generator.kernel_stride
            )
        else:
            kernel_map, out_key = self.get_kernel_map_and_out_key(stensor)
            kq_indices = self.key_query_indices_from_kernel_map(kernel_map)
            rel_pos = self.get_relative_position(points, kq_indices)

        # Encode the relative positions, then compare with the queries.
        unit_pos_enc = F.normalize(self.rel_pos_mlp(rel_pos), p=2, dim=-1)
        unit_q = F.normalize(q, p=2, dim=-1)
        scores = stensor._F.new_zeros(kq_indices.shape[1], heads)
        scores = direct_dot_product_shared_cuda(unit_q, unit_pos_enc, scores, kq_indices)

        # Aggregate the values with those scores, then project.
        gathered = stensor._F.new_zeros(num_queries, heads, head_dim)
        gathered = scalar_attention_cuda(scores, v, gathered, kq_indices)
        out_F = self.to_out(gathered.view(-1, self.out_channels).contiguous())

        result = ME.SparseTensor(
            out_F,
            coordinate_map_key=out_key,
            coordinate_manager=stensor.coordinate_manager,
        )
        return result, kq_indices, rel_pos