import torch
import torch.nn as nn
import torch.nn.parallel
import torch.utils.data
from torch.autograd import Variable
import numpy as np
import torch.nn.functional as F
from .pointnet2_utils import *
from spikingjelly.clock_driven.neuron import MultiStepLIFNode, ParametricLIFNode, MultiStepParametricLIFNode
import spikingjelly.clock_driven.surrogate as surrogate
import math
from dataclasses import dataclass

def spike_rate(inp):
    """Return the fraction of non-zero elements of ``inp`` as a Python float.

    Args:
        inp: tensor of any shape and dtype.

    Returns:
        ``count_nonzero(inp) / inp.numel()``; ``0.0`` for an empty tensor.
    """
    total = inp.numel()
    if total == 0:
        # An empty tensor has no spikes; avoid ZeroDivisionError.
        return 0.0
    # count_nonzero avoids materialising an (nnz, ndim) index tensor (torch.nonzero),
    # and dividing Python ints keeps full double precision instead of rounding the
    # ratio through a float32 tensor.
    return torch.count_nonzero(inp).item() / total


class ReLUX(nn.Module):
    """Clipped ReLU that limits activations to the closed interval ``[0, thre]``."""

    def __init__(self, thre=4):
        super().__init__()
        # Upper saturation bound applied in ``forward``.
        self.thre = thre

    def forward(self, input):
        return input.clamp(min=0, max=self.thre)

# Shared module-level clipped ReLU that clamps its input to [0, 4].
relu4 = ReLUX(thre=4)


class multispike(torch.autograd.Function):
    """Multi-level spike function.

    Forward rounds ``clamp(input, 0, lens)`` to the nearest integer (halves round up),
    so outputs are integer-valued levels between 0 and ``lens``.
    Backward is a straight-through estimator: the incoming gradient passes unchanged
    where ``0 < input < lens`` and is zeroed elsewhere.
    """

    @staticmethod
    def forward(ctx, input, lens=4):
        ctx.save_for_backward(input)
        ctx.lens = lens
        # Clamp with the same ``lens`` that backward uses. It used to be hard-wired to 4
        # (via relu4), so a non-default ``lens`` changed only the gradient window.
        return torch.floor(torch.clamp(input, 0, lens) + 0.5)

    @staticmethod
    def backward(ctx, grad_output):
        input, = ctx.saved_tensors
        inside = (input > 0) & (input < ctx.lens)
        # ``lens`` is a plain Python number, so it receives no gradient.
        return grad_output * inside.to(grad_output.dtype), None


class Multispike(nn.Module):
    """``nn.Module`` wrapper around a spike ``torch.autograd.Function``.

    Args:
        lens: passed to ``spike.apply`` as its second argument (for the default
            ``multispike`` this is its ``lens`` parameter).
        spike: the autograd Function to apply; defaults to ``multispike``.
    """

    def __init__(self, lens=4, spike=multispike):
        super().__init__()
        self.lens = lens
        self.spike = spike

    def forward(self, inputs):
        # Forward ``lens`` explicitly. Previously it was stored but never passed,
        # so the Function always fell back to its own default.
        return self.spike.apply(inputs, self.lens)


class PosE_Geo(nn.Module):
    """Parameter-free sinusoidal encoding of neighbourhood coordinates for each time step.

    Each coordinate is scaled by ``beta`` and divided by ``alpha ** (j / feat_dim)`` for
    ``j in range(feat_dim)``, where ``feat_dim = out_dim // (in_dim * 2)``. The result is
    further divided by the 1-based time step ``t`` and passed through ``sin`` and ``cos``.
    """

    def __init__(self, in_dim, out_dim, alpha, beta):
        super().__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.alpha, self.beta = alpha, beta

    def forward(self, knn_xyz, knn_x):
        """Encode ``knn_xyz``.

        Args:
            knn_xyz: tensor unpacked as (B, in_dim, G, K).
            knn_x: only ``knn_x.shape[0]`` is read, as the number of time steps T.

        Returns:
            Tensor of shape (T, B, out_dim, G, K). ``out_dim`` must equal
            ``in_dim * 2 * feat_dim``, otherwise the final ``view`` fails.
        """
        T = knn_x.shape[0]
        B, _, G, K = knn_xyz.shape
        feat_dim = self.out_dim // (self.in_dim * 2)

        # Create helper ranges on the input's device rather than a hard-coded .cuda(),
        # so the module works on CPU and on any GPU index.
        device = knn_xyz.device
        time_range = torch.arange(1, T + 1, device=device, dtype=torch.float32).reshape(T, 1, 1, 1, 1, 1)
        feat_range = torch.arange(feat_dim, device=device, dtype=torch.float32)
        dim_embed = torch.pow(self.alpha, feat_range / feat_dim)

        # (B, in_dim, G, K, feat_dim) -> broadcast over time -> (T, B, in_dim, G, K, feat_dim)
        div_embed = torch.div(self.beta * knn_xyz.unsqueeze(-1), dim_embed).unsqueeze(0) / time_range
        position_embed = torch.cat([torch.sin(div_embed), torch.cos(div_embed)], -1)

        # (T, B, in_dim, G, K, 2*feat_dim) -> (T, B, in_dim, 2*feat_dim, G, K) -> (T, B, out_dim, G, K)
        position_embed = position_embed.permute(0, 1, 2, 5, 3, 4).contiguous()
        return position_embed.view(T, B, self.out_dim, G, K)


class PosE_Geo_Global(nn.Module):
    """Parameter-free sinusoidal encoding of per-point coordinates for each time step.

    Each coordinate is scaled by ``beta``, divided by ``alpha ** (j / feat_dim)``
    (``feat_dim = out_dim // (in_dim * 2)``) and passed through ``sin`` and ``cos``.
    The result is then divided by the 1-based time step ``t``.
    """

    def __init__(self, in_dim, out_dim, alpha, beta):
        super().__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.alpha, self.beta = alpha, beta

    def forward(self, xyz, x):
        """Encode ``xyz``.

        Args:
            xyz: tensor unpacked as (B, in_dim, N).
            x: only ``x.shape[0]`` is read, as the number of time steps T.

        Returns:
            Tensor of shape (T, B, out_dim, N).
        """
        T = x.shape[0]
        B, _, N = xyz.shape
        feat_dim = self.out_dim // (self.in_dim * 2)

        # Create helper ranges on the input's device instead of a hard-coded .cuda().
        device = xyz.device
        time_range = torch.arange(1, T + 1, device=device, dtype=torch.float32).reshape(T, 1, 1, 1, 1)
        feat_range = torch.arange(feat_dim, device=device, dtype=torch.float32)
        dim_embed = torch.pow(self.alpha, feat_range / feat_dim)

        # (B, in_dim, N, feat_dim) -> (B, in_dim, N, 2*feat_dim)
        div_embed = torch.div(self.beta * xyz.unsqueeze(-1), dim_embed)
        position_embed = torch.cat([torch.sin(div_embed), torch.cos(div_embed)], -1)

        # NOTE(review): unlike PosE_Geo, the time step here scales the amplitude (applied
        # after sin/cos) rather than the frequency. Kept as-is; confirm this is intended.
        position_embed = position_embed.unsqueeze(0) / time_range  # (T, B, in_dim, N, 2*feat_dim)
        position_embed = position_embed.permute(0, 1, 2, 4, 3).contiguous()
        return position_embed.view(T, B, self.out_dim, N)



class FPS_kNN(nn.Module):
    """Pick ``group_num`` centres by farthest point sampling and gather their ``k_neighbors`` nearest points.

    Uses ``farthest_point_sample``, ``index_points`` and ``knn_point`` from ``pointnet2_utils``.
    """

    def __init__(self, group_num, k_neighbors):
        super().__init__()
        self.group_num = group_num
        self.k_neighbors = k_neighbors

    def forward(self, xyz, x):
        """Return ``(centre_xyz, centre_x, neighbour_xyz, neighbour_x)``.

        ``xyz`` must be rank 3; presumably (B, N, 3) point coordinates — TODO confirm.
        ``x`` holds the per-point features that are gathered with the same indices.
        """
        # Unpacked only to require a rank-3 ``xyz``; the sizes themselves are unused.
        _batch, _num_points, _ = xyz.shape

        # Farthest point sampling picks the group centres.
        centre_idx = farthest_point_sample(xyz, self.group_num).long()
        centre_xyz = index_points(xyz, centre_idx)
        centre_x = index_points(x, centre_idx)

        # k nearest neighbours of every centre, gathered from coordinates then features.
        neighbour_idx = knn_point(self.k_neighbors, xyz, centre_xyz)
        neighbour_xyz, neighbour_x = (index_points(source, neighbour_idx) for source in (xyz, x))

        return centre_xyz, centre_x, neighbour_xyz, neighbour_x


class Local_op(nn.Module):
    """Encode each grouped neighbourhood into one feature vector per group.

    Adds a sinusoidal geometry encoding to the neighbour features and spikes them
    through a parametric LIF node. It then applies a shared 1x1 conv + BN and pools over
    the K neighbours with ``max + mean``.

    Args:
        in_channels: channel width of the incoming neighbour features.
        out_channels: channel width produced by the 1x1 conv.
        alpha, beta: constants for the ``PosE_Geo`` positional encoding.
    """

    def __init__(self, in_channels, out_channels, alpha, beta):
        super(Local_op, self).__init__()
        self.conv1 = nn.Conv1d(in_channels, out_channels, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm1d(out_channels)
        self.lif1 = MultiStepParametricLIFNode(init_tau=2.0, surrogate_function=surrogate.ATan(), detach_reset=True,
                                               backend='cupy')
        # Use the caller's alpha/beta. These were previously ignored and hard-coded to
        # 1000/100, which are also the values the Point_PN_* defaults pass down.
        self.geo_extract = PosE_Geo(3, in_channels, alpha=alpha, beta=beta)

    def forward(self, knn_xyz, knn_x):
        """Aggregate neighbourhoods.

        Args:
            knn_xyz: neighbour coordinates passed to ``PosE_Geo`` (unpacked there as
                (B, 3, G, K)).
            knn_x: neighbour features laid out as (T, B, G, K, C).

        Returns:
            Tensor of shape (T, B, out_channels, G).
        """
        knn_x = knn_x.permute(0, 1, 4, 2, 3).contiguous()  # (T, B, C, G, K)
        pe = self.geo_extract(knn_xyz, knn_x)
        knn_x = self.lif1(pe + knn_x)
        T, B, C, G, K = knn_x.shape

        # (T, B, C, G, K) -> (T, B, G, C, K) -> (T, B*G, C, K): one 1-D signal of length K per group.
        x = knn_x.permute(0, 1, 3, 2, 4).reshape(T, B * G, C, K)
        x = self.bn1(self.conv1(x.flatten(0, 1))).reshape(T, B * G, -1, K)

        # Pool over the K neighbours.
        x = x.max(-1)[0] + x.mean(-1)  # (T, B*G, out_channels)
        return x.reshape(T, B, G, -1).permute(0, 1, 3, 2)


# Local Geometry Aggregation
class LGA(nn.Module):
    """Local Geometry Aggregation stage.

    Normalises neighbourhood coordinates, encodes each neighbourhood with ``Local_op``
    and refines the result with ``block_num`` transformer blocks.

    Args:
        out_dim: output channel width of this stage.
        alpha, beta: positional-encoding constants forwarded to ``Local_op``.
        block_num: number of transformer blocks.
        dim_expansion: 1 or 2. It selects ``Local_op``'s input width (``2 * out_dim`` or
            ``out_dim``) to match the feature concatenation in ``forward``.
        type: ``'mn40'`` or ``'scan'`` neighbourhood normalisation; any other value
            leaves ``knn_xyz`` un-normalised.
        stage: stages 0 and 1 use ``Local_Transformer``, later stages ``Global_Transformer``.

    Raises:
        ValueError: if ``dim_expansion`` is not 1 or 2 (previously an UnboundLocalError).
    """

    def __init__(self, out_dim, alpha, beta, block_num, dim_expansion, type, stage):
        super().__init__()
        self.type = type
        # forward() concatenates neighbour features with their centre's features, which
        # doubles the incoming width (out_dim / dim_expansion when built by EncP).
        if dim_expansion == 1:
            expand = 2
        elif dim_expansion == 2:
            expand = 1
        else:
            raise ValueError(f"LGA: dim_expansion must be 1 or 2, got {dim_expansion!r}")
        self.local_op = Local_op(out_dim * expand, out_dim, alpha, beta)
        blocks = [
            Local_Transformer(out_dim, share_planes=8, nsample=16) if stage < 2 else Global_Transformer(out_dim)
            for _ in range(block_num)
        ]
        self.transformer = nn.Sequential(*blocks)

    def forward(self, lc_xyz, lc_x, knn_xyz, knn_x):
        """Run the stage.

        Args:
            lc_xyz: centre coordinates; presumably (B, G, 3) — TODO confirm.
            lc_x: centre features, reshaped below to (T, B, G, 1, C_centre).
            knn_xyz: neighbour coordinates; presumably (B, G, K, 3) — TODO confirm.
            knn_x: neighbour features, unpacked as (T, B, G, K, C).

        Returns:
            Output of the last transformer block (or of ``Local_op`` if ``block_num`` is 0).
        """
        # Normalization
        if self.type == 'mn40':
            # One scalar std over the whole batch of centred neighbourhoods.
            mean_xyz = lc_xyz.unsqueeze(dim=-2)
            std_xyz = torch.std(knn_xyz - mean_xyz)
            knn_xyz = (knn_xyz - mean_xyz) / (std_xyz + 1e-5)

        elif self.type == 'scan':
            # Out-of-place ops. The previous in-place -= and /= could mutate the caller's
            # tensor when the permuted view was already contiguous. They also overwrote
            # a tensor that abs() saved for backward.
            knn_xyz = knn_xyz.permute(0, 3, 1, 2)
            knn_xyz = knn_xyz - lc_xyz.permute(0, 2, 1).unsqueeze(-1)
            knn_xyz = knn_xyz / torch.abs(knn_xyz).max(dim=-1, keepdim=True)[0]
            knn_xyz = knn_xyz.permute(0, 2, 3, 1).contiguous()

        # Feature Expansion: append each centre's features to all of its K neighbours.
        T, B, G, K, C = knn_x.shape
        knn_x = torch.cat([knn_x, lc_x.reshape(T, B, G, 1, -1).repeat(1, 1, 1, K, 1)], dim=-1)
        knn_xyz = knn_xyz.permute(0, 3, 1, 2).contiguous()  # B C G K
        # Local Feature
        x = self.local_op(knn_xyz, knn_x)
        # Global Feature
        for layer in self.transformer:
            x = layer(x, lc_xyz)
        return x


# Pooling
class Pooling(nn.Module):
    """Collapse the last axis by adding its maximum to its mean.

    ``out_dim`` is accepted for interface compatibility but is not used.
    """

    def __init__(self, out_dim):
        super().__init__()

    def forward(self, knn_x_w):
        peak = knn_x_w.max(dim=-1).values
        average = knn_x_w.mean(dim=-1)
        return peak + average


class Linear0Layer(nn.Module):
    """Pointwise (1x1) Conv1d followed by BatchNorm1d, applied to every time step.

    Input and output are 4-D ``(T, B, C, N)``; T and B are merged for the conv/BN.
    """

    def __init__(self, in_channels, out_channels, kernel_size=1, bias=True):
        super().__init__()
        self.conv = nn.Conv1d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, bias=bias)
        self.bn = nn.BatchNorm1d(out_channels)

    def forward(self, x):
        steps, batch, _, points = x.shape
        merged = x.flatten(0, 1)
        return self.bn(self.conv(merged)).reshape(steps, batch, -1, points)


class Linear1Layer(nn.Module):
    """Spiking bottleneck MLP on ``(T, B, C, N)`` tensors.

    Pipeline: LIF -> 1x1 conv (C -> C/4) -> BN -> LIF -> 1x1 conv (C/4 -> C) -> BN.
    ``groups`` applies to the first conv only.
    """

    def __init__(self, in_channels, kernel_size=1, groups=1, bias=True):
        super().__init__()
        hidden = int(in_channels / 4)

        self.conv1 = nn.Conv1d(in_channels=in_channels, out_channels=hidden,
                               kernel_size=kernel_size, groups=groups, bias=bias)
        self.bn1 = nn.BatchNorm1d(hidden)
        self.act1 = MultiStepParametricLIFNode(init_tau=2.0, surrogate_function=surrogate.ATan(), detach_reset=True,
                                               backend='cupy')

        self.conv2 = nn.Conv1d(in_channels=hidden, out_channels=in_channels,
                               kernel_size=kernel_size, bias=bias)
        self.bn2 = nn.BatchNorm1d(in_channels)
        self.act2 = MultiStepParametricLIFNode(init_tau=2.0, surrogate_function=surrogate.ATan(), detach_reset=True,
                                               backend='cupy')

    def forward(self, x):
        steps, batch, _, points = x.shape
        out = x
        # Each stage: spike, merge (T, B) for the conv/BN, then restore (T, B, C', N).
        for spike, conv, norm in ((self.act1, self.conv1, self.bn1), (self.act2, self.conv2, self.bn2)):
            out = norm(conv(spike(out).flatten(0, 1))).reshape(steps, batch, -1, points)
        return out


class Local_Attention(nn.Module):
    """Spiking local attention over point features, with a sinusoidal position encoding.

    Q, K and V are 1x1 conv + BN + LIF projections. Keys are gathered from each point's
    neighbourhood by ``sample_and_group_attn`` (from pointnet2_utils). The summed
    key-minus-query difference is spiked and used to gate V elementwise.
    """

    def __init__(self, channels, share_planes=8, nsample=16):
        super(Local_Attention, self).__init__()

        self.nsample = nsample  # neighbourhood size passed to sample_and_group_attn
        self.share_planes = share_planes  # stored but not used anywhere in this class
        self.mid_planes = channels
        self.out_planes = channels
        self.start_lif = MultiStepParametricLIFNode(init_tau=2.0, surrogate_function=surrogate.ATan(), detach_reset=True,
                                               backend='cupy')
        self.q_conv = nn.Conv1d(self.out_planes, self.mid_planes, 1, bias=False)
        self.k_conv = nn.Conv1d(self.out_planes, self.mid_planes, 1, bias=False)
        # self.q_conv.weight = self.k_conv.weight
        # self.q_conv.bias = self.k_conv.bias
        self.q_bn = nn.BatchNorm1d(self.mid_planes)
        self.k_bn = nn.BatchNorm1d(self.mid_planes)
        self.q_lif = MultiStepParametricLIFNode(init_tau=2.0, surrogate_function=surrogate.ATan(), detach_reset=True,
                                                backend='cupy')
        self.k_lif = MultiStepParametricLIFNode(init_tau=2.0, surrogate_function=surrogate.ATan(), detach_reset=True,
                                                backend='cupy')
        # Unlike q_conv/k_conv, the value projection keeps its bias.
        self.v_conv = nn.Conv1d(channels, channels, 1)
        self.v_bn = nn.BatchNorm1d(channels)
        self.v_lif = MultiStepParametricLIFNode(init_tau=2.0, surrogate_function=surrogate.ATan(), detach_reset=True,
                                                backend='cupy')
        # Parameter-free sinusoidal position encoding of the 3-D coordinates.
        self.geo_extract_global = PosE_Geo_Global(3, self.out_planes, alpha=1000, beta=100)
        # The bare triple-quoted string below is a no-op: a disabled learned-PE variant.
        '''
        self.pe_conv = nn.Conv1d(3, self.out_planes, 1)
        self.pe_bn = nn.BatchNorm1d(self.out_planes)
        self.pe_lif = MultiStepParametricLIFNode(init_tau=2.0, surrogate_function=surrogate.ATan(),
                                                 detach_reset=True, backend='cupy', v_threshold=0.1)
        '''
        self.attn_lif = MultiStepParametricLIFNode(init_tau=2.0, surrogate_function=surrogate.ATan(), detach_reset=True,
                                                   backend='cupy')
        self.trans_conv = nn.Conv1d(channels, channels, 1)
        self.after_norm = nn.BatchNorm1d(channels)

    def forward(self, x, xyz):
        """Apply local spiking attention.

        Args:
            x: point features, unpacked below as (T, B, C, N).
            xyz: point coordinates. They are permuted to (B, 3, N) for the position
                encoder, so presumably (B, N, 3) — TODO confirm.

        Returns:
            Tensor of shape (T, B, channels, N) (output of trans_conv + after_norm).
        """
        T, B, _, N = x.shape
        # pe = xyz.permute(0, 2, 1).unsqueeze(0).repeat(T, 1, 1, 1)
        # pe = self.pe_conv(pe.flatten(0, 1))
        # pe = self.pe_bn(pe).reshape(T, B, -1, N)
        pe = self.geo_extract_global(xyz.permute(0, 2, 1), x)
        x = self.start_lif(x + pe)

        # Query projection: (T, B, C, N)
        x_q = self.q_conv(x.flatten(0, 1))
        x_q = self.q_bn(x_q).reshape(T, B, -1, N)
        x_q = self.q_lif(x_q)
        # print("x_q", spike_rate(x_q))

        # Key projection: (T, B, C, N) before grouping
        x_k = self.k_conv(x.flatten(0, 1))
        x_k = self.k_bn(x_k).reshape(T, B, -1, N)
        x_k = self.k_lif(x_k)
        # print("x_k", spike_rate(x_k))

        # Gather neighbour keys. Judging by the unsqueeze(-1)/sum(-1) below, the result
        # presumably carries a trailing neighbour axis, i.e. (T, B, C, N, nsample) — TODO confirm.
        # NOTE(review): attn_v receives x_k (not x_v), and sample_points is never used.
        sample_points, x_k = sample_and_group_attn(N, self.nsample, xyz, attn_k=x_k, attn_v=x_k)

        # Value projection: (T, B, C, N)
        x_v = self.v_conv(x.flatten(0, 1))
        x_v = self.v_bn(x_v).reshape(T, B, -1, N)
        x_v = self.v_lif(x_v)

        # T, B, C, N, K = pe.shape
        # Difference between each neighbour's key and the point's own query.
        w = x_k - x_q.unsqueeze(-1)  # + pe
        # w = self.pe_lif(w)  # T B C N K
        # print("w", spike_rate(w))
        # Sum over the last (neighbour) axis, then spike to get the attention mask.
        w = self.attn_lif(w.sum(-1))
        # print("w2", spike_rate(w))
        # Elementwise gating of the values (no softmax and no matrix product).
        x = (x_v * w).view(T, B, -1, N)
        # print("local", spike_rate(x))
        x = self.after_norm(self.trans_conv(x.flatten(0, 1))).reshape(T, B, -1, N)
        return x


class Global_Attention(nn.Module):
    """Multi-head spiking global attention over all N points.

    Per head, Q and K are reduced to one scalar per point (the mean over the head's
    channels). A LIF node with ``v_threshold=0.1`` spikes the pairwise absolute
    difference ``|q - k|``, and the resulting (N, N) map mixes the value spikes through a
    matrix product with no softmax.
    """

    def __init__(self, channels):
        super(Global_Attention, self).__init__()
        self.start_lif = MultiStepParametricLIFNode(init_tau=2.0, surrogate_function=surrogate.ATan(), detach_reset=True,
                                               backend='cupy')
        # The bare triple-quoted string below is a no-op: a disabled learned-PE variant.
        '''
        self.pe_conv = nn.Conv1d(3, channels, 1)
        self.pe_bn = nn.BatchNorm1d(channels)
        self.pe_lif = MultiStepParametricLIFNode(init_tau=2.0, surrogate_function=surrogate.ATan(),
                                                      detach_reset=True, backend='cupy')
        '''
        self.geo_extract_global = PosE_Geo_Global(3, channels, alpha=1000, beta=100)
        self.q_conv = nn.Conv1d(channels, channels, 1, bias=False)
        self.k_conv = nn.Conv1d(channels, channels, 1, bias=False)
        self.v_conv = nn.Conv1d(channels, channels, 1, bias=False)
        # Q and K share one conv weight. Both convs are bias-free, so the bias line
        # just assigns None; Q and K still have separate BN and LIF layers.
        self.q_conv.weight = self.k_conv.weight
        self.q_conv.bias = self.k_conv.bias
        self.q_bn = nn.BatchNorm1d(channels)
        self.k_bn = nn.BatchNorm1d(channels)
        self.v_bn = nn.BatchNorm1d(channels)
        self.q_lif = MultiStepParametricLIFNode(init_tau=2.0, surrogate_function=surrogate.ATan(), detach_reset=True,
                                                backend='cupy')
        self.k_lif = MultiStepParametricLIFNode(init_tau=2.0, surrogate_function=surrogate.ATan(), detach_reset=True,
                                                backend='cupy')
        self.v_lif = MultiStepParametricLIFNode(init_tau=2.0, surrogate_function=surrogate.ATan(), detach_reset=True,
                                                backend='cupy')
        self.trans_conv = nn.Conv1d(channels, channels, 1)
        self.after_norm = nn.BatchNorm1d(channels)
        # Lower firing threshold (0.1) than the other LIF nodes in this class.
        self.attn_lif = MultiStepParametricLIFNode(init_tau=2.0, surrogate_function=surrogate.ATan(), detach_reset=True,
                                                   backend='cupy', v_threshold=0.1)
        self.act = MultiStepParametricLIFNode(init_tau=2.0, surrogate_function=surrogate.ATan(), detach_reset=True,
                                                backend='cupy')
        # Fixed head count; forward reshapes C into (6, C // 6), so C must be divisible by 6.
        self.num_heads = 6
        # Per-head channel count, used to turn per-head channel sums into means.
        self.scale = channels // self.num_heads

    def forward(self, x, xyz):
        """Apply global spiking attention.

        Args:
            x: point features, unpacked below as (T, B, C, N).
            xyz: point coordinates. They are permuted to (B, 3, N) for the position
                encoder, so presumably (B, N, 3) — TODO confirm.

        Returns:
            Tensor of shape (T, B, C, N) (output of trans_conv + after_norm).
        """
        T, B, C, N = x.shape
        # pe = xyz.permute(0, 2, 1).unsqueeze(0).repeat(T, 1, 1, 1)
        # pe = self.pe_conv(pe.flatten(0, 1))
        # pe = self.pe_bn(pe).reshape(T, B, -1, N)
        # pe = self.pe_lif(pe)  # T B C N K
        pe = self.geo_extract_global(xyz.permute(0, 2, 1), x)
        x = self.start_lif(x + pe)  # + pe
        # t, b, n, c
        # Q/K/V projections, each reshaped to (T, B, heads, C // heads, N).
        x_q = self.q_conv(x.flatten(0, 1))
        x_q = self.q_bn(x_q).reshape(T, B, -1, N).reshape(T, B, self.num_heads, C // self.num_heads, N)
        x_q = self.q_lif(x_q)  # .permute(0, 1, 3, 2)

        x_k = self.k_conv(x.flatten(0, 1))
        x_k = self.k_bn(x_k).reshape(T, B, -1, N).reshape(T, B, self.num_heads, C // self.num_heads, N)
        x_k = self.k_lif(x_k)

        x_v = self.v_conv(x.flatten(0, 1))
        x_v = self.v_bn(x_v).reshape(T, B, -1, N).reshape(T, B, self.num_heads, C // self.num_heads, N)
        x_v = self.v_lif(x_v)

        # Mean over each head's channels -> one scalar per point per head.
        k = torch.sum(x_k, dim=3, keepdim=True) / self.scale   # T B head 1 N
        q = torch.sum(x_q, dim=3, keepdim=True) / self.scale  # T B head 1 N

        # (T, B, heads, N, N) with tvd[..., i, j] = |q_j - k_i|.
        tvd = torch.abs(q - k.transpose(-2, -1))
        attnmap = self.attn_lif(tvd)
        # print(spike_rate(attnmap))
        # t, b, c, n
        # (T, B, heads, d, N) @ (T, B, heads, N, N) -> (T, B, heads, d, N). Output at
        # point j sums the values of points i weighted by attnmap[..., i, j]; the sum is
        # unnormalised, then spiked by self.act and merged back to (T, B, C, N).
        x_r = self.act(x_v @ attnmap).reshape(T, B, -1, N)
        # print(spike_rate(x_r))
        x_r = self.after_norm(self.trans_conv(x_r.flatten(0, 1))).reshape(T, B, -1, N)
        return x_r


class Global_Transformer(nn.Module):
    """Transformer block: residual global spiking attention, then a residual bottleneck MLP."""

    def __init__(self, channels):
        super().__init__()
        self.attn = Global_Attention(channels)
        self.mlp = Linear1Layer(channels)

    def forward(self, x, xyz):
        hidden = x + self.attn(x, xyz)
        return hidden + self.mlp(hidden)


class Local_Transformer(nn.Module):
    """Transformer block: residual local spiking attention, then a residual bottleneck MLP."""

    def __init__(self, channels, share_planes=8, nsample=16):
        super().__init__()
        self.attn = Local_Attention(channels, share_planes=share_planes, nsample=nsample)
        self.mlp = Linear1Layer(channels)

    def forward(self, x, xyz):
        hidden = x + self.attn(x, xyz)
        return hidden + self.mlp(hidden)


class EncP(nn.Module):
    """Hierarchical spiking point encoder.

    A raw pointwise embedding is followed by ``num_stages`` stages. Each stage halves the
    number of groups (FPS + kNN) and multiplies the channel width by ``dim_expansion[stage]``
    (LGA).
    """

    def __init__(self, in_channels, input_points, num_stages, embed_dim, k_neighbors, alpha, beta, LGA_block,
                 dim_expansion, type):
        super().__init__()
        self.input_points = input_points
        self.num_stages = num_stages
        self.embed_dim = embed_dim
        self.alpha, self.beta = alpha, beta

        # Raw-point embedding, then one sampler/aggregator pair per stage.
        self.raw_point_embed = Linear0Layer(in_channels, self.embed_dim, bias=False)
        self.FPS_kNN_list = nn.ModuleList()
        self.LGA_list = nn.ModuleList()

        width = self.embed_dim
        groups = self.input_points
        for stage in range(self.num_stages):
            expansion = dim_expansion[stage]
            width *= expansion
            groups //= 2
            self.FPS_kNN_list.append(FPS_kNN(groups, k_neighbors))
            self.LGA_list.append(LGA(width, self.alpha, self.beta, LGA_block[stage], expansion, type, stage))

    def forward(self, xyz, x):
        """Encode points; ``x`` is (T, B, in_channels, N) as unpacked by ``Linear0Layer``."""
        features = self.raw_point_embed(x)
        for stage in range(self.num_stages):
            sampler = self.FPS_kNN_list[stage]
            aggregator = self.LGA_list[stage]
            # The samplers gather along the point axis, so pass features as (T, B, N, C).
            xyz, lc_x, knn_xyz, knn_x = sampler(xyz, features.permute(0, 1, 3, 2).contiguous())
            features = aggregator(xyz, lc_x, knn_xyz, knn_x)
        return features


# Parametric Network for ModelNet40
class Point_PN_mn40(nn.Module):
    """Spiking point-cloud classifier configured for ModelNet40.

    ``forward`` takes ``x`` of rank 3, ``(B, in_channels, N)``. The same tensor,
    transposed to ``(B, N, in_channels)``, is used as the point coordinates, so
    ``in_channels`` is presumably 3 here. Returns logits of shape ``(B, class_num)``,
    averaged over the ``self.step`` time steps.
    """

    def __init__(self, in_channels=3, class_num=40, input_points=1024, num_stages=4, embed_dim=48, k_neighbors=32,
                 beta=100, alpha=1000, LGA_block=(1, 1, 1, 1), dim_expansion=(2, 2, 2, 1), type='mn40'):
        # LGA_block / dim_expansion defaults are tuples rather than lists, so no mutable
        # default object is shared between instances. Lists are still accepted.
        super().__init__()
        # Parametric Encoder
        self.EncP = EncP(in_channels, input_points, num_stages, embed_dim, k_neighbors, alpha, beta, LGA_block,
                         dim_expansion, type)
        # Encoder output width = embed_dim * prod(dim_expansion).
        self.out_channel = embed_dim
        for factor in dim_expansion:
            self.out_channel *= factor
        self.conv_fuse = nn.Conv1d(self.out_channel, self.out_channel, kernel_size=1)

        self.fc1 = nn.Linear(self.out_channel, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, class_num)
        self.dropout = nn.Dropout(p=0.4)
        self.bn_fuse = nn.BatchNorm1d(self.out_channel)
        self.bn1 = nn.BatchNorm1d(512)
        self.bn2 = nn.BatchNorm1d(256)
        self.lif1 = MultiStepParametricLIFNode(init_tau=2.0, surrogate_function=surrogate.ATan(), detach_reset=True,
                                               backend='cupy')
        self.lif2 = MultiStepParametricLIFNode(init_tau=2.0, surrogate_function=surrogate.ATan(), detach_reset=True,
                                               backend='cupy')
        self.lif_fuse = MultiStepParametricLIFNode(init_tau=2.0, surrogate_function=surrogate.ATan(), detach_reset=True,
                                               backend='cupy')
        # Number of time steps T; the input is repeated along a new leading axis.
        self.step = 1

    def forward(self, x):
        xyz = x.permute(0, 2, 1).contiguous()
        x = x.repeat(self.step, 1, 1, 1)  # (T, B, C, N)
        # xyz: point coordinates
        # x: point features
        # Parametric Encoder
        x = self.EncP(xyz, x)
        T, B, C, N = x.shape
        x = self.lif_fuse(x)
        x = self.bn_fuse(self.conv_fuse(x.flatten(0, 1))).reshape(T, B, -1, N)
        # Global max + mean pooling over the remaining points.
        x = x.max(-1)[0] + x.mean(-1)
        x = self.bn1(self.fc1(self.lif1(x).flatten(0, 1))).reshape(T, B, -1)
        x = self.bn2(self.dropout(self.fc2(self.lif2(x).flatten(0, 1)))).reshape(T, B, -1)
        x = self.fc3(x)
        # Average the logits over time steps.
        x = x.mean(0)
        return x


# Parametric Network for ScanObjectNN
class Point_PN_scan(nn.Module):
    """Spiking point-cloud classifier configured for ScanObjectNN.

    ``forward(x, xyz)`` takes point features ``x`` and coordinates ``xyz`` separately.
    ``x`` is repeated ``self.step`` times along a new leading time axis, so it is
    presumably ``(B, in_channels, N)`` — TODO confirm. Returns logits averaged over time.
    """

    def __init__(self, in_channels=4, class_num=15, input_points=1024, num_stages=4, embed_dim=48, k_neighbors=40,
                 beta=100, alpha=1000, LGA_block=(2, 1, 1, 1), dim_expansion=(2, 2, 2, 1), type='scan'):
        # LGA_block / dim_expansion defaults are tuples rather than lists, so no mutable
        # default object is shared between instances. Lists are still accepted.
        super().__init__()
        # Parametric Encoder
        self.EncP = EncP(in_channels, input_points, num_stages, embed_dim, k_neighbors, alpha, beta, LGA_block,
                         dim_expansion, type)
        # Encoder output width = embed_dim * prod(dim_expansion).
        self.out_channel = embed_dim
        for factor in dim_expansion:
            self.out_channel *= factor
        self.conv_fuse = nn.Conv1d(self.out_channel, self.out_channel, kernel_size=1)
        self.bn_fuse = nn.BatchNorm1d(self.out_channel)
        self.lif_fuse = MultiStepParametricLIFNode(init_tau=2.0, surrogate_function=surrogate.ATan(), detach_reset=True,
                                                   backend='cupy')

        self.fc1 = nn.Linear(self.out_channel, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, class_num)
        self.dropout = nn.Dropout(p=0.4)
        self.bn1 = nn.BatchNorm1d(512)
        self.bn2 = nn.BatchNorm1d(256)
        self.lif1 = MultiStepParametricLIFNode(init_tau=2.0, surrogate_function=surrogate.ATan(), detach_reset=True,
                                               backend='cupy')
        self.lif2 = MultiStepParametricLIFNode(init_tau=2.0, surrogate_function=surrogate.ATan(), detach_reset=True,
                                               backend='cupy')
        # Number of time steps T; the input is repeated along a new leading axis.
        self.step = 1

    def forward(self, x, xyz):
        x = x.repeat(self.step, 1, 1, 1)
        # xyz: point coordinates
        # x: point features
        # Parametric Encoder
        x = self.EncP(xyz, x)
        T, B, C, N = x.shape
        x = self.lif_fuse(x)
        x = self.bn_fuse(self.conv_fuse(x.flatten(0, 1))).reshape(T, B, -1, N)
        # Global max + mean pooling over the remaining points.
        x = x.max(-1)[0] + x.mean(-1)
        x = self.bn1(self.fc1(self.lif1(x).flatten(0, 1))).reshape(T, B, -1)
        x = self.bn2(self.dropout(self.fc2(self.lif2(x).flatten(0, 1)))).reshape(T, B, -1)
        x = self.fc3(x)
        # Average the logits over time steps.
        x = x.mean([0])
        return x


if __name__ == '__main__':
    # Smoke test: build the ModelNet40 model on the GPU and push one random batch
    # of 5 clouds (3 channels x 1024 points) through it.
    net = Point_PN_mn40().cuda()
    points = torch.randn(5, 3, 1024).cuda()
    logits = net(points)