import logging

import torch
import torch.nn as nn
import torch.nn.functional as F
import MinkowskiEngine as ME
from MinkowskiEngine.MinkowskiKernelGenerator import KernelGenerator

from models.common import MinkowskiLayerNorm
from models.sparse_ops import (
    direct_dot_product_shared_with_key_cuda,
    scalar_attention_cuda,
)


class EfficientPointTransformerLayerV2Shared(nn.Module):
    """Multi-head local self-attention on a MinkowskiEngine sparse tensor.

    Neighbourhoods are the (key, query) pairs of the kernel map of a
    ``kernel_size`` stencil at stride 1. For every pair, one score per head
    is computed from L2-normalised queries, keys and a relative-position
    encoding whose width is ``attn_channels`` (the same encoding is used for
    every head). Those scores then aggregate the values. The exact
    similarity / aggregation formulas live in ``models.sparse_ops``.

    ``forward`` returns the key/query map and relative positions, so a later
    layer on the same coordinates can reuse them instead of recomputing.
    """

    def __init__(
        self,
        in_channels,
        out_channels=None,
        kernel_size=3,
        stride=1,
        num_heads=8,
        bias=True,
        dimension=3
    ):
        """Build the layer.

        Args:
            in_channels: number of input feature channels.
            out_channels: number of output channels; defaults to
                ``in_channels``. Must be divisible by ``num_heads``.
            kernel_size: neighbourhood stencil size passed to
                ``KernelGenerator``.
            stride: must be 1 (asserted).
            num_heads: number of attention heads.
            bias: whether the qkv convolution and the output projection use
                a bias term.
            dimension: spatial dimension of the sparse tensor; also the
                coordinate width expected by the relative-position MLP.
        """
        out_channels = in_channels if out_channels is None else out_channels
        assert out_channels % num_heads == 0
        assert stride == 1
        super().__init__()

        self.out_channels = out_channels
        self.attn_channels = out_channels // num_heads
        self.num_heads = num_heads

        # A single 1x1 sparse convolution produces q, k and v together.
        # MinkowskiToFeature turns the result into a plain feature tensor.
        self.to_qkv = nn.Sequential(
            ME.MinkowskiConvolution(
                in_channels, 3 * out_channels, kernel_size=1, bias=bias, dimension=dimension
            ),
            ME.MinkowskiToFeature()
        )
        self.to_out = nn.Linear(out_channels, out_channels, bias=bias)
        self.rel_pos_mlp = self._make_rel_pos_mlp(dimension, self.attn_channels)

        # Informational hint only. It cannot fire while stride is asserted
        # to be 1 above.
        if kernel_size == 3 and stride == 2:
            logging.info("Recommend to use kernel size 5 instead of 3 for stride 2.")
        self.kernel_size = kernel_size
        self.kernel_generator = KernelGenerator(kernel_size=kernel_size, stride=stride, dimension=dimension)
        self.kernel_volume = self.kernel_generator.kernel_volume

    @staticmethod
    def _make_rel_pos_mlp(dimension, attn_channels):
        """Return the MLP that maps ``dimension``-D relative positions to
        ``attn_channels``-D encodings.

        Its input is a dense ``torch.Tensor`` (a difference of point
        coordinates), so normalisation uses ``nn.BatchNorm1d``.
        ``ME.MinkowskiBatchNorm`` expects a SparseTensor and fails on
        dense input.
        """
        return nn.Sequential(
            nn.Linear(dimension, dimension, bias=False),
            nn.BatchNorm1d(dimension),
            nn.ReLU(inplace=True),
            nn.Linear(dimension, attn_channels, bias=False),
            nn.BatchNorm1d(attn_channels),
            nn.ReLU(inplace=True),
            nn.Linear(attn_channels, attn_channels)
        )

    @torch.no_grad()
    def key_query_indices_from_kernel_map(self, kernel_map):
        """Concatenate the kernel map's per-offset index tensors along the
        last dim.

        Each value is expected to be a 2-row integer tensor
        (row 0: key/input index, row 1: query/output index).
        """
        return torch.cat(list(kernel_map.values()), -1)

    @torch.no_grad()
    def key_query_map_from_kernel_map(self, kernel_map):
        """Flatten a kernel map into a single 2-row key/query map.

        Row 0 stores ``key_index * kernel_volume + kernel_offset_index``, so
        both the key index and the kernel offset can be decoded later.
        Row 1 is the query index. ``kernel_map`` itself is not modified.
        """
        kq_map = []
        for kernel_idx, in_out in kernel_map.items():
            # Encode into a copy so the caller's kernel map is not mutated.
            encoded = in_out.clone()
            encoded[0] = in_out[0] * self.kernel_volume + kernel_idx
            kq_map.append(encoded)
        return torch.cat(kq_map, -1)

    @torch.no_grad()
    def key_query_indices_from_key_query_map(self, kq_map):
        """Decode a key/query map into plain (key_index, query_index) rows."""
        kq_indices = kq_map.clone()
        kq_indices[0] = kq_indices[0] // self.kernel_volume
        return kq_indices

    @torch.no_grad()
    def query_neighbor_indices_from_key_query_map(self, kq_map):
        """Decode a key/query map into (query_index, kernel_offset_index) rows."""
        qn_indices = kq_map.clone()
        qn_indices[0] = kq_map[1]
        qn_indices[1] = kq_map[0] % self.kernel_volume
        return qn_indices

    @torch.no_grad()
    def get_kernel_map_and_out_key(self, stensor):
        """Compute the kernel map for ``stensor`` and the output coordinate key.

        Returns:
            ``(kernel_map, out_key)``, as produced by the tensor's
            coordinate manager for this layer's kernel generator.
        """
        cm = stensor.coordinate_manager
        in_key = stensor.coordinate_key
        out_key = cm.stride(in_key, self.kernel_generator.kernel_stride)
        region_type, region_offset, _ = self.kernel_generator.get_kernel(stensor.tensor_stride, False)
        kernel_map = cm.kernel_map(
            in_key,
            out_key,
            self.kernel_generator.kernel_stride,
            self.kernel_generator.kernel_size,
            self.kernel_generator.kernel_dilation,
            region_type=region_type,
            region_offset=region_offset,
        )
        return kernel_map, out_key

    @torch.no_grad()
    def get_relative_position(self, points, kq_indices):
        """Return ``points[query] - points[key]`` for every (key, query) pair."""
        kq_indices = kq_indices.long()
        return points[kq_indices[1]] - points[kq_indices[0]]  # query - key

    def forward(self, stensor, points, kq_map=None, rel_pos=None):
        """Apply local self-attention to ``stensor``.

        Args:
            stensor: input ``ME.SparseTensor``.
            points: per-point coordinates whose last dim equals ``dimension``.
                Assumed row-aligned with ``stensor``; only equal length is
                checked.
            kq_map: optional key/query map from a previous call on the same
                coordinates. When ``None`` it is computed from the kernel map.
            rel_pos: relative positions matching ``kq_map``. Required when
                ``kq_map`` is given.

        Returns:
            ``(out, kq_map, rel_pos)``:
            - ``out`` is an ``ME.SparseTensor`` with ``out_channels`` features.
            - ``kq_map`` and ``rel_pos`` can be passed to a following layer
              to skip recomputation.
        """
        assert len(stensor) == len(points)
        cm = stensor.coordinate_manager

        # Query, key and value, each shaped (num_points, num_heads, attn_channels).
        q, k, v = self.to_qkv(stensor).split(self.out_channels, dim=-1)
        q = q.view(-1, self.num_heads, self.attn_channels).contiguous()
        k = k.view(-1, self.num_heads, self.attn_channels).contiguous()
        v = v.view(-1, self.num_heads, self.attn_channels).contiguous()

        # Neighbourhood pairs and relative positions: computed here or reused.
        if kq_map is None:
            kernel_map, out_key = self.get_kernel_map_and_out_key(stensor)
            kq_map = self.key_query_map_from_kernel_map(kernel_map)
            kq_indices = self.key_query_indices_from_key_query_map(kq_map)
            rel_pos = self.get_relative_position(points, kq_indices)
        else:
            assert rel_pos is not None
            out_key = cm.stride(
                stensor.coordinate_key,
                self.kernel_generator.kernel_stride
            )
            kq_indices = self.key_query_indices_from_key_query_map(kq_map)

        # Relative positional encoding, width attn_channels, used for every head.
        rel_pos_enc = self.rel_pos_mlp(rel_pos)

        # One similarity score per (pair, head), from L2-normalised inputs.
        num_pairs = kq_indices.shape[1]
        num_points = len(q)
        sim = q.new_zeros(num_pairs, self.num_heads)
        norm_q = F.normalize(q, dim=-1)
        norm_k = F.normalize(k, dim=-1)
        norm_pos_enc = F.normalize(rel_pos_enc, dim=-1)
        sim = direct_dot_product_shared_with_key_cuda(norm_q, norm_k, norm_pos_enc, sim, kq_indices)

        # Aggregate values with the scores, then project back to out_channels.
        out_F = v.new_zeros(num_points, self.num_heads, self.attn_channels)
        out_F = scalar_attention_cuda(sim, v, out_F, kq_indices)
        out_F = self.to_out(out_F.view(-1, self.out_channels))

        return (
            ME.SparseTensor(out_F, coordinate_map_key=out_key, coordinate_manager=cm),
            kq_map,
            rel_pos
        )