"""Official implementation of PointNext
PointNeXt: Revisiting PointNet++ with Improved Training and Scaling Strategies
https://arxiv.org/abs/2206.04670
Guocheng Qian, Yuchen Li, Houwen Peng, Jinjie Mai, Hasan Abed Al Kader Hammoud, Mohamed Elhoseiny, Bernard Ghanem
"""
from typing import List, Type
import logging
import torch
import torch.nn as nn
from ..build import MODELS
from ..layers import create_convblock1d, create_convblock2d, create_act, CHANNEL_MAP, \
    create_grouper, furthest_point_sample, random_sample, three_interpolation, get_aggregation_feautres
from torch.autograd import Variable
import torch.nn.functional as F


def get_reduction_fn(reduction):
    """Return a callable that reduces a tensor over its last dimension.

    Args:
        reduction: name of the reduction, one of ``'sum'``, ``'max'``,
            ``'mean'`` or ``'avg'`` (alias of ``'mean'``). Matching is
            case-insensitive.

    Returns:
        A function taking a tensor and returning it reduced along ``dim=-1``
        with ``keepdim=False``. For ``'max'`` only the values are returned
        (the indices from ``torch.max`` are dropped).

    Raises:
        ValueError: if ``reduction`` is not one of the supported names.
    """
    name = reduction.lower()
    if name == 'avg':
        name = 'mean'

    if name == 'max':
        return lambda x: torch.max(x, dim=-1, keepdim=False)[0]
    if name == 'mean':
        return lambda x: torch.mean(x, dim=-1, keepdim=False)
    if name == 'sum':
        return lambda x: torch.sum(x, dim=-1, keepdim=False)
    # Explicit raise (instead of ``assert``) so the check survives ``python -O``.
    raise ValueError(
        f"unsupported reduction {reduction!r}; expected one of 'sum', 'max', 'mean', 'avg'")


class LocalAggregation(nn.Module):
    """Local aggregation layer for a set.

    Set abstraction layer abstracts features from a larger set to a smaller set.
    Local aggregation layer aggregates features from the same set.

    Args:
        channels: channel widths of the stacked 2D conv blocks. The first entry
            is remapped through ``CHANNEL_MAP[feature_type]``; the remapping is
            done on a private copy, the caller's list is not modified.
        norm_args: normalization config forwarded to ``create_convblock2d``.
        act_args: activation config forwarded to ``create_convblock2d``.
        group_args: grouping config forwarded to ``create_grouper``.
        conv_args: extra keyword arguments for ``create_convblock2d``;
            ``None`` means no extra arguments.
        feature_type: key into ``CHANNEL_MAP``; also passed to
            ``get_aggregation_feautres``.
        reduction: pooling applied over the last dimension after the convs,
            see ``get_reduction_fn``.
        last_act: if False, the last conv block is created without activation.
        **kwargs: not used; a warning is logged when any are given.
    """

    def __init__(self,
                 channels: List[int],
                 norm_args={'norm': 'bn1d'},
                 act_args={'act': 'relu'},
                 group_args={'NAME': 'ballquery', 'radius': 0.1, 'nsample': 16},
                 conv_args=None,
                 feature_type='dp_fj',
                 reduction='max',
                 last_act=True,
                 **kwargs
                 ):
        super().__init__()
        if kwargs:
            logging.warning(f"kwargs: {kwargs} are not used in {__class__.__name__}")
        # Work on a copy so the caller's list is not silently rewritten.
        channels = list(channels)
        # ``**None`` is a TypeError; treat a missing conv config as "no extras".
        conv_args = {} if conv_args is None else conv_args
        channels[0] = CHANNEL_MAP[feature_type](channels[0])

        last_idx = len(channels) - 2
        convs = []
        for i, (c_in, c_out) in enumerate(zip(channels[:-1], channels[1:])):  # #layers in each blocks
            # Only the final layer may drop its activation (when last_act is False).
            drop_act = i == last_idx and not last_act
            convs.append(create_convblock2d(c_in, c_out,
                                            norm_args=norm_args,
                                            act_args=None if drop_act else act_args,
                                            **conv_args))
        self.convs = nn.Sequential(*convs)
        self.grouper = create_grouper(group_args)
        self.reduction = reduction.lower()
        self.pool = get_reduction_fn(self.reduction)
        self.feature_type = feature_type

    def forward(self, pf) -> torch.Tensor:
        """Aggregate neighborhood features of every point in the set.

        Args:
            pf: pair ``(p, f)`` of positions and features.

        Returns:
            The features after grouping, the conv stack and pooling over the
            last dimension.
        """
        # p: position, f: feature
        p, f = pf
        # neighborhood_features: the set is both query and support here
        dp, fj = self.grouper(p, p, f)
        fj = get_aggregation_feautres(p, dp, f, fj, self.feature_type)
        f = self.pool(self.convs(fj))
        # DEBUG neighbor numbers.
        # if f.shape[-1] != 1:
        #     query_xyz, support_xyz = p, p
        #     radius = self.grouper.radius
        #     dist = torch.cdist(query_xyz.cpu(), support_xyz.cpu())
        #     points = len(dist[dist < radius]) / (dist.shape[0] * dist.shape[1])
        #     logging.info(
        #         f'query size: {query_xyz.shape}, support size: {support_xyz.shape}, radius: {radius}, num_neighbors: {points}')
        # DEBUG end
        return f


class SetAbstraction(nn.Module):
    """The modified set abstraction module in PointNet++ with residual connection support.

    On top of the sampling / grouping / conv / pooling pipeline, this variant
    owns two learnable logit vectors:

    * ``GS_thetas`` (``opt_num`` entries): Gumbel-softmax weights over channel
      width options, used by :meth:`masking` to softly zero trailing channels.
    * ``kd_GS_thetas`` (``teacher_length`` entries): Gumbel-softmax weights
      applied to the feature-matching term computed against ``all_output``.

    Args:
        in_channels: number of input feature channels.
        out_channels: number of output feature channels.
        layers: number of conv blocks.
        stride: downsampling ratio; ``stride == 1`` on a non-head layer means
            no sampling (``all_aggr``).
        group_args: grouping config forwarded to ``create_grouper``. When
            ``all_aggr`` is set its ``nsample`` and ``radius`` attributes are
            overwritten with ``None``, so it must support attribute assignment
            (a plain ``dict`` does not).
        norm_args: normalization config for the conv blocks.
        act_args: activation config for the conv blocks.
        conv_args: extra keyword arguments for the conv blocks; ``None`` means
            no extra arguments.
        sampler: ``'fps'`` or ``'random'`` (case-insensitive).
        feature_type: key into ``CHANNEL_MAP``; also passed to
            ``get_aggregation_feautres``.
        use_res: request a residual connection (disabled for head and
            ``all_aggr`` layers).
        is_head: if True the layer is a plain 1D conv stack without grouping.
        **kwargs: accepted and ignored.
    """

    def __init__(self,
                 in_channels, out_channels,
                 layers=1,
                 stride=1,
                 group_args={'NAME': 'ballquery',
                             'radius': 0.1, 'nsample': 16},
                 norm_args={'norm': 'bn1d'},
                 act_args={'act': 'relu'},
                 conv_args=None,
                 sampler='fps',
                 feature_type='dp_fj',
                 use_res=False,
                 is_head=False,
                 **kwargs,
                 ):
        super().__init__()
        self.stride = stride
        self.is_head = is_head
        self.all_aggr = not is_head and stride == 1
        self.use_res = use_res and not self.all_aggr and not self.is_head
        self.feature_type = feature_type
        # opt_num is used both as the number of width options and as the
        # channel step between two consecutive options (see ``masking``).
        self.opt_num = 8
        self.teacher_length = 49
        self.kd_GS_thetas = nn.Parameter(torch.ones(self.teacher_length) * (1 / self.teacher_length))
        self.GS_thetas = nn.Parameter(torch.Tensor([1.0 / self.opt_num for _ in range(self.opt_num)]))

        # ``**None`` is a TypeError; treat a missing conv config as "no extras".
        conv_args = {} if conv_args is None else conv_args

        mid_channel = out_channels // 2 if stride > 1 else out_channels
        channels = [in_channels] + [mid_channel] * \
                   (layers - 1) + [out_channels]
        channels[0] = in_channels if is_head else CHANNEL_MAP[feature_type](channels[0])

        if self.use_res:
            # Project the skip path only when the widths differ.
            self.skipconv = create_convblock1d(
                in_channels, channels[-1], norm_args=None, act_args=None) if in_channels != channels[
                -1] else nn.Identity()
            self.act = create_act(act_args)

        # actually, one can use local aggregation layer to replace the following
        create_conv = create_convblock1d if is_head else create_convblock2d
        last_idx = len(channels) - 2
        convs = []
        for i, (c_in, c_out) in enumerate(zip(channels[:-1], channels[1:])):
            # The last layer has no activation when it is followed by the
            # residual add or when this is the head layer.
            drop_act = i == last_idx and (self.use_res or is_head)
            convs.append(create_conv(c_in, c_out,
                                     norm_args=norm_args if not is_head else None,
                                     act_args=None if drop_act else act_args,
                                     **conv_args))
        self.convs = nn.Sequential(*convs)
        if not is_head:
            if self.all_aggr:
                group_args.nsample = None
                group_args.radius = None
            self.grouper = create_grouper(group_args)
            self.pool = lambda x: torch.max(x, dim=-1, keepdim=False)[0]
            if sampler.lower() == 'fps':
                self.sample_fn = furthest_point_sample
            elif sampler.lower() == 'random':
                self.sample_fn = random_sample

    def masking(self, x, temperature):
        """Softly mask trailing channels of ``x`` with Gumbel-softmax weights.

        Option ``i`` keeps the first ``C - i * opt_num`` channels and zeroes
        the rest; the applied mask is the Gumbel-softmax-weighted sum of all
        ``opt_num`` option masks.

        Args:
            x: 3-D tensor whose second dimension (``C``) is the channel axis.
            temperature: Gumbel-softmax temperature.

        Returns:
            ``(masked_x, effective_output_channel)`` where the second item is
            the weighted sum of the kept channel counts (a 0-dim tensor).
        """
        _, channel_in, _ = x.shape
        soft_mask_variables = nn.functional.gumbel_softmax(self.GS_thetas, temperature)

        # The mask is identical for every batch item and point, so a single
        # per-channel vector is built on x's device (works on CPU and GPU)
        # and broadcast in the final multiply.
        channel_mask = torch.zeros(channel_in, device=x.device)
        effective_output_channel = 0
        for i in range(self.opt_num):
            dropped = i * self.opt_num
            kept = channel_in - dropped
            option_mask = torch.cat((torch.ones(kept, device=x.device),
                                     torch.zeros(dropped, device=x.device)))
            channel_mask = channel_mask + soft_mask_variables[i] * option_mask
            effective_output_channel = effective_output_channel + soft_mask_variables[i] * kept

        x = x.mul(channel_mask.view(1, channel_in, 1))

        return x, effective_output_channel

    def calculate_irf_block_cost(self, effective_input_channel, effective_output_channel):
        """Return the cost proxy ``in * out + out * out + out * out`` for this block."""
        pw_cost = effective_input_channel * effective_output_channel
        dw_cost = effective_output_channel * effective_output_channel
        pwl_cost = effective_output_channel * effective_output_channel
        return pw_cost + dw_cost + pwl_cost

    def forward(self, pf):
        """Run the layer on the 7-item state passed between encoder blocks.

        Args:
            pf: ``(p, f, all_output, cost_accumulate, kl_accumulate,
                temperature, effective_input_channel)``.

        Returns:
            A tuple with the same layout, where the last item is the effective
            output channel count of this layer. On non-head layers with a
            non-``None`` ``all_output``, the returned ``all_output`` is the
            flattened, L2-normalized version computed here.
        """
        p, f, all_output, cost_accumulate, kl_accumulate, temperature, effective_input_channel = pf
        if self.is_head:
            f = self.convs(f)  # (n, c)
            effective_output_channel = f.size(1)
        else:
            if not self.all_aggr:
                idx = self.sample_fn(p, p.shape[1] // self.stride).long()
                new_p = torch.gather(p, 1, idx.unsqueeze(-1).expand(-1, -1, 3))
            else:
                new_p = p
            # DEBUG neighbor numbers.
            # query_xyz, support_xyz = new_p, p
            # radius = self.grouper.radius
            # dist = torch.cdist(query_xyz.cpu(), support_xyz.cpu())
            # points = len(dist[dist < radius]) / (dist.shape[0] * dist.shape[1])
            # logging.info(f'query size: {query_xyz.shape}, support size: {support_xyz.shape}, radius: {radius}, num_neighbors: {points}')
            # DEBUG end
            # NOTE(review): ``idx`` only exists when sampling happened; an
            # ``all_aggr`` layer with a 'df' feature type would hit a NameError here.
            if self.use_res or 'df' in self.feature_type:
                fi = torch.gather(
                    f, -1, idx.unsqueeze(1).expand(-1, f.shape[1], -1))
                if self.use_res:
                    identity = self.skipconv(fi)
            else:
                fi = None
            dp, fj = self.grouper(new_p, p, f)
            fj = get_aggregation_feautres(new_p, dp, fi, fj, feature_type=self.feature_type)
            f = self.pool(self.convs(fj))
            if all_output is not None:
                soft_mask_variables = nn.functional.gumbel_softmax(self.kd_GS_thetas, temperature)
                # Classification variant: compare per-sample flattened and
                # L2-normalized features. (A segmentation variant built on the
                # per-point squared-feature map lives in InvResMLP.forward.)
                student = f.reshape(f.shape[0], -1)
                student = student / torch.norm(student, dim=-1)[:, None]
                # all_output is intentionally rebound: the normalized version
                # is what gets returned to the following blocks.
                all_output = all_output.reshape(all_output.shape[0], -1)
                all_output = all_output / torch.norm(all_output, dim=-1)[:, None]
                kl_loss = torch.norm(all_output - student) * soft_mask_variables[:, None]
                kl_accumulate = kl_accumulate + kl_loss.mean().abs()
            f, effective_output_channel = self.masking(f, temperature)
            cost = self.calculate_irf_block_cost(effective_input_channel, effective_output_channel)
            cost_accumulate = cost + cost_accumulate
            if self.use_res:
                f = self.act(f + identity)
            p = new_p

        return p, f, all_output, cost_accumulate, kl_accumulate, temperature, effective_output_channel


class FeaturePropogation(nn.Module):
    """The Feature Propogation module in PointNet++.

    Two modes, selected by ``upsample``:

    * ``upsample=True``: interpolate coarse features onto the fine point set
      (``three_interpolation``), concatenate with the skip features when
      present, and apply the ``convs`` stack.
    * ``upsample=False``: concatenate the features with a projected global
      (mean-pooled) feature and apply ``linear1``.
    """

    def __init__(self, mlp,
                 upsample=True,
                 norm_args={'norm': 'bn1d'},
                 act_args={'act': 'relu'}
                 ):
        """
        Args:
            mlp: [current_channels, next_channels, next_channels]. The list is
                copied; the caller's object is not modified.
            upsample: select the interpolation mode (True) or the
                global-feature mode (False).
            norm_args: normalization config for the conv blocks.
            act_args: activation config for the conv blocks.
        """
        super().__init__()
        # Copy: the non-upsample branch rewrites mlp[1] below and must not
        # leak that change into the caller's list.
        mlp = list(mlp)
        if not upsample:
            self.linear2 = nn.Sequential(
                nn.Linear(mlp[0], mlp[1]), nn.ReLU(inplace=True))
            # linear1 is built for an input twice as wide as linear2's output.
            mlp[1] *= 2
            self.linear1 = nn.Sequential(*[
                create_convblock1d(c_in, c_out, norm_args=norm_args, act_args=act_args)
                for c_in, c_out in zip(mlp[1:-1], mlp[2:])
            ])
        else:
            self.convs = nn.Sequential(*[
                create_convblock1d(c_in, c_out, norm_args=norm_args, act_args=act_args)
                for c_in, c_out in zip(mlp[:-1], mlp[1:])
            ])

        self.pool = lambda x: torch.mean(x, dim=-1, keepdim=False)

    def forward(self, pf1, pf2=None):
        """Propagate features.

        Args:
            pf1: ``(p1, f1)`` of the point set the output is defined on.
            pf2: optional ``(p2, f2)`` whose features are interpolated onto
                ``p1``. When ``None`` the global-feature mode is used (the
                module must have been built with ``upsample=False``).

        Returns:
            The propagated feature tensor.
        """
        # pfb1 is with the same size of upsampled points
        if pf2 is None:
            _, f = pf1  # (B, N, 3), (B, C, N)
            f_global = self.pool(f)
            f = torch.cat(
                (f, self.linear2(f_global).unsqueeze(-1).expand(-1, -1, f.shape[-1])), dim=1)
            f = self.linear1(f)
        else:
            p1, f1 = pf1
            p2, f2 = pf2
            if f1 is not None:
                f = self.convs(
                    torch.cat((f1, three_interpolation(p1, p2, f2)), dim=1))
            else:
                f = self.convs(three_interpolation(p1, p2, f2))
        return f


class InvResMLP(nn.Module):
    """Inverted-residual MLP block with soft channel masking.

    A ``LocalAggregation`` layer followed by point-wise 1D convs, a
    Gumbel-softmax channel mask (``GS_thetas``), an optional residual add and
    the activation. ``kd_GS_thetas`` weights the feature-matching term that is
    accumulated when a teacher output is supplied.

    Args:
        in_channels: number of input (and output) channels.
        norm_args: normalization config for the conv blocks.
        act_args: activation config for the conv blocks.
        aggr_args: keyword arguments forwarded to ``LocalAggregation``.
        group_args: grouping config forwarded to ``LocalAggregation``.
        conv_args: extra keyword arguments for the conv blocks; ``None`` means
            no extra arguments.
        expansion: width multiplier of the hidden point-wise layer.
        use_res: add the input back when the last dimensions match.
        num_posconvs: number of point-wise convs after the aggregation
            (``< 1``: none, ``1``: one, otherwise two).
        less_act: if True, no point-wise conv gets an activation.
        **kwargs: forwarded to ``LocalAggregation``.
    """

    # Number of channel-width options and the channel step between two
    # consecutive options used by ``masking``.
    NUM_WIDTH_OPTIONS = 4
    WIDTH_STEP = 4

    def __init__(self,
                 in_channels,
                 norm_args=None,
                 act_args=None,
                 aggr_args={'feature_type': 'dp_fj', "reduction": 'max'},
                 group_args={'NAME': 'ballquery'},
                 conv_args=None,
                 expansion=1,
                 use_res=True,
                 num_posconvs=2,
                 less_act=False,
                 **kwargs
                 ):
        super().__init__()
        self.use_res = use_res
        # ``**None`` is a TypeError; treat a missing conv config as "no extras".
        conv_args = {} if conv_args is None else conv_args
        mid_channels = int(in_channels * expansion)
        self.convs = LocalAggregation([in_channels, in_channels],
                                      norm_args=norm_args, act_args=act_args if num_posconvs > 0 else None,
                                      group_args=group_args, conv_args=conv_args,
                                      **aggr_args, **kwargs)
        if num_posconvs < 1:
            channels = []
        elif num_posconvs == 1:
            channels = [in_channels, in_channels]
        else:
            channels = [in_channels, mid_channels, in_channels]
        # point wise after depth wise conv (without last layer)
        last_idx = len(channels) - 2
        pwconv = []
        for i, (c_in, c_out) in enumerate(zip(channels[:-1], channels[1:])):
            # The last point-wise conv never has an activation (it is applied
            # after the residual add); ``less_act`` removes it everywhere.
            keep_act = i != last_idx and not less_act
            pwconv.append(create_convblock1d(c_in, c_out,
                                             norm_args=norm_args,
                                             act_args=act_args if keep_act else None,
                                             **conv_args))
        self.pwconv = nn.Sequential(*pwconv)
        self.act = create_act(act_args)
        self.teacher_length = 49
        self.kd_GS_thetas = nn.Parameter(torch.ones(self.teacher_length) * (1 / self.teacher_length))
        self.GS_thetas = nn.Parameter(
            torch.Tensor([1.0 / self.NUM_WIDTH_OPTIONS for _ in range(self.NUM_WIDTH_OPTIONS)]))

    def masking(self, x, temperature):
        """Softly mask trailing channels of ``x`` with Gumbel-softmax weights.

        Option ``i`` keeps the first ``C - i * WIDTH_STEP`` channels and zeroes
        the rest; the applied mask is the Gumbel-softmax-weighted sum of all
        ``NUM_WIDTH_OPTIONS`` option masks.

        Args:
            x: 3-D tensor whose second dimension (``C``) is the channel axis.
            temperature: Gumbel-softmax temperature.

        Returns:
            ``(masked_x, effective_output_channel)`` where the second item is
            the weighted sum of the kept channel counts (a 0-dim tensor).
        """
        _, channel_in, _ = x.shape
        soft_mask_variables = nn.functional.gumbel_softmax(self.GS_thetas, temperature)

        # The mask is identical for every batch item and point, so a single
        # per-channel vector is built on x's device (works on CPU and GPU)
        # and broadcast in the final multiply.
        channel_mask = torch.zeros(channel_in, device=x.device)
        effective_output_channel = 0
        for i in range(self.NUM_WIDTH_OPTIONS):
            dropped = i * self.WIDTH_STEP
            kept = channel_in - dropped
            option_mask = torch.cat((torch.ones(kept, device=x.device),
                                     torch.zeros(dropped, device=x.device)))
            channel_mask = channel_mask + soft_mask_variables[i] * option_mask
            effective_output_channel = effective_output_channel + soft_mask_variables[i] * kept

        x = x.mul(channel_mask.view(1, channel_in, 1))

        return x, effective_output_channel

    def calculate_irf_block_cost(self, effective_input_channel, effective_output_channel):
        """Return the cost proxy ``in * out + out * out + out * out`` for this block."""
        pw_cost = effective_input_channel * effective_output_channel
        dw_cost = effective_output_channel * effective_output_channel
        pwl_cost = effective_output_channel * effective_output_channel
        return pw_cost + dw_cost + pwl_cost

    def forward(self, pf):
        """Run the block on the 7-item state passed between encoder blocks.

        Args:
            pf: ``(p, f, all_output, cost_accumulate, kl_accumulate,
                temperature, effective_input_channel)``.

        Returns:
            A list with the same layout, where the last item is the effective
            output channel count of this block. With a non-``None``
            ``all_output`` the returned ``all_output`` is the flattened,
            L2-normalized version computed here.
        """
        p, f, all_output, cost_accumulate, kl_accumulate, temperature, effective_input_channel = pf
        identity = f
        f = self.convs([p, f])
        f = self.pwconv(f)
        # Sampled unconditionally, as before, so the random stream consumed by
        # this block does not depend on whether a teacher output is given.
        soft_mask_variables = nn.functional.gumbel_softmax(self.kd_GS_thetas, temperature)
        if all_output is not None:
            # Student map: channel-wise sum of squared features.
            student = (f ** 2).sum(1)
            # NOTE(review): upsample_linear is not defined in the visible part
            # of this file — presumably provided elsewhere; verify.
            student = upsample_linear(student, all_output.size(1))
            student = student.reshape(-1) / torch.norm(student)
            # all_output is intentionally rebound: the normalized version is
            # what gets returned to the following blocks.
            all_output = all_output.reshape(all_output.shape[0], -1)
            all_output = all_output / torch.norm(all_output, dim=-1)[:, None]
            # The term is skipped when the flattened sizes do not line up.
            if all_output.size(1) == student.size(0):
                kl_loss = torch.norm(all_output - student[None, :]) * soft_mask_variables[:, None]
                kl_accumulate = kl_accumulate + kl_loss.mean().abs()
        f, effective_output_channel = self.masking(f, temperature)
        cost = self.calculate_irf_block_cost(effective_input_channel, effective_output_channel)
        cost_accumulate = cost + cost_accumulate
        if f.shape[-1] == identity.shape[-1] and self.use_res:
            f += identity
        f = self.act(f)
        return [p, f, all_output, cost_accumulate, kl_accumulate, temperature, effective_output_channel]


class ResBlock(nn.Module):
    """Residual block built from a single three-layer ``LocalAggregation``.

    The aggregation runs without activation; the activation created from
    ``act_args`` is applied after the (optional) residual add.
    """

    def __init__(self,
                 in_channels,
                 norm_args=None,
                 act_args=None,
                 aggr_args={'feature_type': 'dp_fj', "reduction": 'max'},
                 group_args={'NAME': 'ballquery'},
                 conv_args=None,
                 expansion=1,
                 use_res=True,
                 **kwargs
                 ):
        super().__init__()
        self.use_res = use_res
        hidden_width = in_channels * expansion
        widths = [in_channels, in_channels, hidden_width, in_channels]
        self.convs = LocalAggregation(
            widths,
            norm_args=norm_args,
            act_args=None,
            group_args=group_args,
            conv_args=conv_args,
            **aggr_args,
            **kwargs,
        )
        self.act = create_act(act_args)

    def forward(self, pf):
        """Map ``[p, f]`` to ``[p, f_out]``; positions pass through untouched."""
        p, f = pf
        shortcut = f
        f = self.convs([p, f])
        # The skip is added in place, and only when the last dims agree.
        if f.shape[-1] == shortcut.shape[-1] and self.use_res:
            f += shortcut
        return [p, self.act(f)]


@MODELS.register_module()
class PointNextEncoder(nn.Module):
    r"""The Encoder for PointNext 
    `"PointNeXt: Revisiting PointNet++ with Improved Training and Scaling Strategies".
    <https://arxiv.org/abs/2206.04670>`_.
    .. note::
        For an example of using :obj:`PointNextEncoder`, see
        `examples/segmentation/main.py <https://github.com/guochengqian/PointNeXt/blob/master/cfgs/s3dis/README.md>`_.
    Args:
        in_channels (int, optional): input channels . Defaults to 4.
        width (int, optional): width of network, the output mlp of the stem MLP. Defaults to 32.
        blocks (List[int], optional): # of blocks per stage (including the SA block). Defaults to [1, 4, 7, 4, 4].
        strides (List[int], optional): the downsampling ratio of each stage. Defaults to [4, 4, 4, 4].
        block (strorType[InvResMLP], optional): the block to use for depth scaling. Defaults to 'InvResMLP'.
        nsample (intorList[int], optional): the number of neighbors to query for each block. Defaults to 32.
        radius (floatorList[float], optional): the initial radius. Defaults to 0.1.
        aggr_args (_type_, optional): the args for local aggregataion. Defaults to {'feature_type': 'dp_fj', "reduction": 'max'}.
        group_args (_type_, optional): the args for grouping. Defaults to {'NAME': 'ballquery'}.
            Its ``radius`` / ``nsample`` attributes are assigned per block, so it must support attribute assignment.
        norm_args (_type_, optional): the args for normalization layer. Defaults to {'norm': 'bn'}.
        act_args (_type_, optional): the args for activation layer. Defaults to {'act': 'relu'}.
        expansion (int, optional): the expansion ratio of the InvResMLP block. Defaults to 4.
        sa_layers (int, optional): the number of MLP layers to use in the SA block. Defaults to 1.
        sa_use_res (bool, optional): wheter to use residual connection in SA block. Set to True only for PointNeXt-S. 
        temperature (float, optional): Gumbel-softmax temperature used by :meth:`forward` when none is passed. Defaults to 5.0.
    """

    def __init__(self,
                 in_channels: int = 4,
                 width: int = 32,
                 blocks: List[int] = [1, 4, 7, 4, 4],
                 strides: List[int] = [4, 4, 4, 4],
                 block: str or Type[InvResMLP] = 'InvResMLP',
                 nsample: int or List[int] = 32,
                 radius: float or List[float] = 0.1,
                 aggr_args: dict = {'feature_type': 'dp_fj', "reduction": 'max'},
                 group_args: dict = {'NAME': 'ballquery'},
                 sa_layers: int = 1,
                 sa_use_res: bool = False,
                 temperature: float = 5.0,
                 **kwargs
                 ):
        super().__init__()
        if isinstance(block, str):
            # NOTE(security): ``eval`` resolves the block class by name from
            # this module's namespace; only pass trusted config strings here.
            block = eval(block)
        self.blocks = blocks
        self.strides = strides
        self.in_channels = in_channels
        self.aggr_args = aggr_args
        self.norm_args = kwargs.get('norm_args', {'norm': 'bn'})
        self.act_args = kwargs.get('act_args', {'act': 'relu'})
        self.conv_args = kwargs.get('conv_args', None)
        self.sampler = kwargs.get('sampler', 'fps')
        self.expansion = kwargs.get('expansion', 4)
        self.sa_layers = sa_layers
        self.sa_use_res = sa_use_res
        self.use_res = kwargs.get('use_res', True)
        # Kept as the fallback temperature for ``forward``.
        self.temperature = temperature
        radius_scaling = kwargs.get('radius_scaling', 2)
        nsample_scaling = kwargs.get('nsample_scaling', 1)

        self.radii = self._to_full_list(radius, radius_scaling)
        self.nsample = self._to_full_list(nsample, nsample_scaling)
        logging.info(f'radius: {self.radii},\n nsample: {self.nsample}')

        # double width after downsampling.
        channels = []
        for stride in strides:
            if stride != 1:
                width *= 2
            channels.append(width)
        # NOTE(review): stages are indexed by ``blocks`` while ``channels`` and
        # the scalar radius/nsample lists are sized by ``strides``; the two
        # lists are expected to have the same length.
        encoder = []
        for i in range(len(blocks)):
            group_args.radius = self.radii[i]
            group_args.nsample = self.nsample[i]
            encoder.append(self._make_enc(
                block, channels[i], blocks[i], stride=strides[i], group_args=group_args,
                is_head=i == 0 and strides[i] == 1
            ))
        self.encoder = nn.Sequential(*encoder)
        self.out_channels = channels[-1]
        self.channel_list = channels

    def _to_full_list(self, param, param_scaling=1):
        """Expand ``param`` (radius or nsample) into one list per stage.

        Args:
            param: either a scalar (initial value, scaled after every
                downsampling stage) or a list with one entry per stage, each
                entry being a scalar or a per-block list.
            param_scaling: factor applied to a scalar ``param`` after each
                stage whose stride is not 1.

        Returns:
            A list of lists: for stage ``i``, ``self.blocks[i]`` values.
        """
        # param can be: radius, nsample
        param_list = []
        if isinstance(param, List):
            # make param a full list
            for i, value in enumerate(param):
                value = [value] if not isinstance(value, List) else value
                if len(value) != self.blocks[i]:
                    # Pad with the last value; a new list is built so nested
                    # lists owned by the caller are not extended in place.
                    value = value + [value[-1]] * (self.blocks[i] - len(value))
                param_list.append(value)
        else:  # radius is a scalar (in this case, only initial raidus is provide), then create a list (radius for each block)
            for i, stride in enumerate(self.strides):
                if stride == 1:
                    param_list.append([param] * self.blocks[i])
                else:
                    # First block (the SA layer) uses the current value, the
                    # remaining blocks of the stage use the scaled one.
                    param_list.append(
                        [param] + [param * param_scaling] * (self.blocks[i] - 1))
                    param *= param_scaling
        return param_list

    def _make_enc(self, block, channels, blocks, stride, group_args, is_head=False):
        """Build one encoder stage: a ``SetAbstraction`` followed by ``blocks - 1`` blocks.

        ``group_args.radius`` / ``group_args.nsample`` hold the per-block lists
        on entry and are overwritten with the scalar for each layer before that
        layer is constructed. ``self.in_channels`` is updated to ``channels``.
        """
        layers = []
        radii = group_args.radius
        nsample = group_args.nsample
        group_args.radius = radii[0]
        group_args.nsample = nsample[0]
        layers.append(SetAbstraction(self.in_channels, channels,
                                     self.sa_layers if not is_head else 1, stride,
                                     group_args=group_args,
                                     sampler=self.sampler,
                                     norm_args=self.norm_args, act_args=self.act_args, conv_args=self.conv_args,
                                     is_head=is_head, use_res=self.sa_use_res, **self.aggr_args
                                     ))
        self.in_channels = channels
        for i in range(1, blocks):
            group_args.radius = radii[i]
            group_args.nsample = nsample[i]
            layers.append(block(self.in_channels,
                                aggr_args=self.aggr_args,
                                norm_args=self.norm_args, act_args=self.act_args, group_args=group_args,
                                conv_args=self.conv_args, expansion=self.expansion,
                                use_res=self.use_res
                                ))
        return nn.Sequential(*layers)

    def forward_cls_feat(self, p0, all_output, cost_accumulate, kl_accumulate, temperature, f0=None):
        """Encode for classification.

        Args:
            p0: positions, or a mapping with key ``'pos'`` and optional ``'x'``.
            all_output: per-stage teacher outputs (indexed by stage), or
                ``None`` to run without them.
            cost_accumulate: running cost value, extended by every block.
            kl_accumulate: running feature-matching value.
            temperature: Gumbel-softmax temperature.
            f0: input features; defaults to the transposed positions.

        Returns:
            ``(features, cost_accumulate, kl_accumulate, effective_channels)``
            where ``features`` is the last stage output with its last
            dimension squeezed and ``effective_channels`` sums the effective
            channel count reported after each stage.
        """
        if hasattr(p0, 'keys'):
            p0, f0 = p0['pos'], p0.get('x', None)
        if f0 is None:
            f0 = p0.clone().transpose(1, 2).contiguous()
        effective_channel = 3
        effective_channels_cou = 0
        for i, stage in enumerate(self.encoder):
            # Each stage is matched against its own teacher output, if any.
            teacher_out = None if all_output is None else all_output[i]
            p0, f0, _, cost_accumulate, kl_accumulate, temperature, effective_channel = \
                stage([p0, f0, teacher_out, cost_accumulate, kl_accumulate, temperature, effective_channel])
            effective_channels_cou = effective_channel + effective_channels_cou
        return f0.squeeze(-1), cost_accumulate, kl_accumulate, effective_channels_cou

    def forward_seg_feat(self, p0, all_output, cost_accumulate, kl_accumulate, temperature, f0=None):
        """Encode for segmentation, keeping every stage's positions and features.

        Args:
            p0: positions, or a mapping with key ``'pos'`` and optional ``'x'``.
            all_output: teacher output threaded through all stages (each stage
                receives what the previous one returned), or ``None``.
            cost_accumulate: running cost value, extended by every block.
            kl_accumulate: running feature-matching value.
            temperature: Gumbel-softmax temperature.
            f0: input features; defaults to the transposed positions.

        Returns:
            ``(p, f, cost_accumulate, kl_accumulate, effective_channel)`` where
            ``p`` and ``f`` are lists starting with the inputs followed by one
            entry per stage.
        """
        if hasattr(p0, 'keys'):
            p0, f0 = p0['pos'], p0.get('x', None)
        if f0 is None:
            f0 = p0.clone().transpose(1, 2).contiguous()
        p, f = [p0], [f0]
        effective_channel = 4
        for stage in self.encoder:
            _p, _f, all_output, cost_accumulate, kl_accumulate, temperature, effective_channel = \
                stage([p[-1], f[-1], all_output, cost_accumulate, kl_accumulate, temperature, effective_channel])
            p.append(_p)
            f.append(_f)
        return p, f, cost_accumulate, kl_accumulate, effective_channel

    def forward(self, p0, f0=None, all_output=None, temperature=None):
        """Run :meth:`forward_seg_feat` with zero-initialised accumulators.

        Args:
            p0: positions, or a mapping with key ``'pos'`` and optional ``'x'``.
            f0: input features (optional).
            all_output: teacher output (optional).
            temperature: Gumbel-softmax temperature; falls back to the value
                given to the constructor when ``None``.

        Returns:
            Whatever :meth:`forward_seg_feat` returns.
        """
        if temperature is None:
            temperature = self.temperature
        # Arguments are passed by keyword: forward_seg_feat's positional order
        # differs from this method's.
        return self.forward_seg_feat(p0,
                                     all_output=all_output,
                                     cost_accumulate=0.0,
                                     kl_accumulate=0.0,
                                     temperature=temperature,
                                     f0=f0)


@MODELS.register_module()
class PointNextDecoder(nn.Module):
    """Decoder made of ``FeaturePropogation`` stages, applied from the coarsest level up."""

    def __init__(self,
                 encoder_channel_list: List[int],
                 decoder_layers: int = 2,
                 decoder_stages: int = 4,
                 **kwargs
                 ):
        super().__init__()
        self.decoder_layers = decoder_layers
        self.in_channels = encoder_channel_list[-1]

        skip_channels = encoder_channel_list[:-1]
        if len(skip_channels) < decoder_stages:
            # Not enough encoder levels: use the raw input width as the first skip.
            skip_channels.insert(0, kwargs.get('in_channels', 3))
        # the output channel after interpolation
        fp_channels = encoder_channel_list[:decoder_stages]

        num_stages = len(fp_channels)
        stages = [[] for _ in range(num_stages)]
        # Stages are created back to front because each one consumes the
        # in_channels left behind by the stage created just before it.
        for back in range(1, num_stages + 1):
            stages[-back] = self._make_dec(skip_channels[-back], fp_channels[-back])
        self.decoder = nn.Sequential(*stages)
        self.out_channels = fp_channels[-num_stages]

    def _make_dec(self, skip_channels, fp_channels):
        """Create one stage and set ``self.in_channels`` to its output width."""
        widths = [skip_channels + self.in_channels]
        widths += [fp_channels] * self.decoder_layers
        stage = nn.Sequential(FeaturePropogation(widths))
        self.in_channels = fp_channels
        return stage

    def forward(self, p, f):
        """Propagate features coarse-to-fine.

        ``f`` is updated in place level by level; the entry of the finest
        level reached is returned.
        """
        num_stages = len(self.decoder)
        for back in range(1, num_stages + 1):
            cur = -back
            finer = cur - 1
            stage = self.decoder[cur]
            propagated = stage[0]([p[finer], f[finer]], [p[cur], f[cur]])
            # Remaining modules of the stage (if any) take [p, f] and return [p, f].
            f[finer] = stage[1:]([p[cur], propagated])[1]
        return f[-num_stages - 1]


@MODELS.register_module()
class PointNextPartDecoder(nn.Module):
    """Part-segmentation decoder: FeaturePropogation stages plus a shape-category embedding.

    Each decoder stage is a FeaturePropogation module optionally followed by
    `blocks - 1` extra `block` modules. In `forward`, a per-shape embedding
    (built according to `cls_map`) is concatenated to the features used as the
    skip input of the last decoding step.
    """

    def __init__(self,
                 encoder_channel_list: List[int],
                 decoder_layers: int = 2,
                 decoder_blocks: List[int] = [1, 1, 1, 1],
                 decoder_strides: List[int] = [4, 4, 4, 4],
                 act_args: str = 'relu',
                 cls_map='pointnet2',
                 num_classes: int = 16,
                 cls2partembed=None,
                 **kwargs
                 ):
        """
        Args:
            encoder_channel_list: channel widths of the encoder levels. The last
                entry is the input width of the first decoding step; all other
                entries serve both as skip widths and as stage output widths.
            decoder_layers: number of times the output width is repeated in the
                channel list of each FeaturePropogation module.
            decoder_blocks: number of modules per decoder stage (the
                FeaturePropogation module counts as the first one).
            decoder_strides: per-stage strides, only used to expand scalar
                `radius` / `nsample` values into per-block lists.
            act_args: activation config used for the class/global embedding convs.
            cls_map: how the shape category is embedded; one of 'curvenet',
                'pointnet2', 'pointnext', 'pointnext1'.
            num_classes: width of the one-hot shape-category vector
                ('curvenet' and 'pointnet2' modes).
            cls2partembed: lookup tensor indexed by class label ('pointnext' and
                'pointnext1' modes); its rows are expected to have 50 entries to
                match the channel counts used below -- confirm against the caller.
            **kwargs: optional 'conv_args', 'radius_scaling' (2),
                'nsample_scaling' (1), 'block' ('InvResMLP'), 'norm_args',
                'expansion' (4), 'radius' (0.1), 'nsample' (16), 'use_res' (True),
                'group_args' and 'aggr_args'.
        """
        super().__init__()
        self.decoder_layers = decoder_layers
        self.in_channels = encoder_channel_list[-1]
        skip_channels = encoder_channel_list[:-1]
        fp_channels = encoder_channel_list[:-1]

        # the following is for decoder blocks
        self.conv_args = kwargs.get('conv_args', None)
        radius_scaling = kwargs.get('radius_scaling', 2)
        nsample_scaling = kwargs.get('nsample_scaling', 1)
        block = kwargs.get('block', 'InvResMLP')
        if isinstance(block, str):
            # SECURITY NOTE(review): eval() executes whatever expression the
            # config string contains; only use with trusted configuration.
            block = eval(block)
        self.blocks = decoder_blocks
        self.strides = decoder_strides
        self.norm_args = kwargs.get('norm_args', {'norm': 'bn'})
        # NOTE(review): `act_args` is a named parameter, so it can never show up
        # in **kwargs and this lookup always yields the default below. Kept
        # as is to preserve the behavior of existing configs/checkpoints.
        self.act_args = kwargs.get('act_args', {'act': 'relu'})
        self.expansion = kwargs.get('expansion', 4)
        radius = kwargs.get('radius', 0.1)
        nsample = kwargs.get('nsample', 16)
        self.radii = self._to_full_list(radius, radius_scaling)
        self.nsample = self._to_full_list(nsample, nsample_scaling)
        self.cls_map = cls_map
        self.num_classes = num_classes
        self.use_res = kwargs.get('use_res', True)
        # NOTE(review): group_args is updated below through attribute assignment
        # (group_args.radius = ...), which a plain dict -- including this default --
        # does not support; callers need to pass an attribute-style mapping.
        group_args = kwargs.get('group_args', {'NAME': 'ballquery'})
        self.aggr_args = kwargs.get('aggr_args',
                                    {'feature_type': 'dp_fj', "reduction": 'max'}
                                    )
        if self.cls_map == 'curvenet':
            # global features
            self.global_conv2 = nn.Sequential(
                create_convblock1d(fp_channels[-1] * 2, 128,
                                   norm_args=None,
                                   act_args=act_args))
            self.global_conv1 = nn.Sequential(
                create_convblock1d(fp_channels[-2] * 2, 64,
                                   norm_args=None,
                                   act_args=act_args))
            # two global embeddings + one-hot shape category label; the one-hot
            # built in forward() has num_classes channels (previously fixed to 16)
            skip_channels[0] += 64 + 128 + num_classes
        elif self.cls_map == 'pointnet2':
            # input width must match the num_classes-wide one-hot built in forward()
            self.convc = nn.Sequential(create_convblock1d(num_classes, 64,
                                                          norm_args=None,
                                                          act_args=act_args))
            skip_channels[0] += 64  # shape categories labels

        elif self.cls_map == 'pointnext':
            self.global_conv2 = nn.Sequential(
                create_convblock1d(fp_channels[-1] * 2, 128,
                                   norm_args=None,
                                   act_args=act_args))
            self.global_conv1 = nn.Sequential(
                create_convblock1d(fp_channels[-2] * 2, 64,
                                   norm_args=None,
                                   act_args=act_args))
            skip_channels[0] += 64 + 128 + 50  # shape categories labels
            self.cls2partembed = cls2partembed
        elif self.cls_map == 'pointnext1':
            self.convc = nn.Sequential(create_convblock1d(50, 64,
                                                          norm_args=None,
                                                          act_args=act_args))
            skip_channels[0] += 64  # shape categories labels
            self.cls2partembed = cls2partembed

        n_decoder_stages = len(fp_channels)
        decoder = [[] for _ in range(n_decoder_stages)]
        # built deepest-first: each stage's input width is the previous stage's output width
        for i in range(-1, -n_decoder_stages - 1, -1):
            group_args.radius = self.radii[i]
            group_args.nsample = self.nsample[i]
            decoder[i] = self._make_dec(
                skip_channels[i], fp_channels[i], group_args=group_args, block=block, blocks=self.blocks[i])

        self.decoder = nn.Sequential(*decoder)
        self.out_channels = fp_channels[-n_decoder_stages]

    def _make_dec(self, skip_channels, fp_channels, group_args=None, block=None, blocks=1):
        """Build one decoder stage and advance `self.in_channels` to `fp_channels`.

        The stage starts with a FeaturePropogation module and is followed by
        `blocks - 1` instances of `block`. `group_args.radius` and
        `group_args.nsample` must hold per-block lists on entry; they are
        overwritten with the scalar entry of each extra block before that
        block is created (the same `group_args` object is reused throughout).
        """
        layers = []
        radii = group_args.radius
        nsample = group_args.nsample
        mlp = [skip_channels + self.in_channels] + \
              [fp_channels] * self.decoder_layers
        layers.append(FeaturePropogation(mlp, act_args=self.act_args))
        self.in_channels = fp_channels
        for i in range(1, blocks):
            group_args.radius = radii[i]
            group_args.nsample = nsample[i]
            layers.append(block(self.in_channels,
                                aggr_args=self.aggr_args,
                                norm_args=self.norm_args, act_args=self.act_args, group_args=group_args,
                                conv_args=self.conv_args, expansion=self.expansion,
                                use_res=self.use_res
                                ))
        return nn.Sequential(*layers)

    def _to_full_list(self, param, param_scaling=1):
        """Expand `param` (radius or nsample) into one list per stage with one value per block.

        Args:
            param: either a list with one entry per stage (each entry a scalar or
                a list, padded with its last value up to the stage's block count),
                or a scalar initial value.
            param_scaling: factor applied to a scalar `param` after every stage
                whose stride is not 1.

        Returns:
            A list of lists. Nested input lists are copied, so the caller's data
            is never modified.
        """
        # param can be: radius, nsample
        param_list = []
        if isinstance(param, List):
            # make param a full list
            for i, value in enumerate(param):
                # copy before padding so the caller's (possibly shared) config
                # list is not extended in place
                value = list(value) if isinstance(value, List) else [value]
                if len(value) != self.blocks[i]:
                    value += [value[-1]] * (self.blocks[i] - len(value))
                param_list.append(value)
        else:  # param is a scalar (only the initial value is provided): create a list with one value per block
            for i, stride in enumerate(self.strides):
                if stride == 1:
                    param_list.append([param] * self.blocks[i])
                else:
                    param_list.append(
                        [param] + [param * param_scaling] * (self.blocks[i] - 1))
                    param *= param_scaling
        return param_list

    def forward(self, p, f, cls_label):
        """Decode the per-level features and inject the shape-category embedding.

        Args:
            p: per-level list indexed in step with `f` (presumably point
                positions -- confirm against the encoder); the first two
                dimensions of `p[0]` are read as batch size and number of points.
            f: per-level feature list; decoded entries are written back in place.
            cls_label: shape-category labels. Used as a scatter index along
                dim 1 in the one-hot modes (so presumably an integer tensor of
                shape (B, 1) -- confirm), and flattened to one label per sample
                in the `cls2partembed` modes.

        Returns:
            The entry of `f` at index `-len(self.decoder) - 1` after decoding.

        Raises:
            ValueError: if `self.cls_map` is not one of the supported modes.
        """
        B, N = p[0].shape[0:2]
        if self.cls_map == 'curvenet':
            emb1 = self.global_conv1(f[-2])
            emb1 = emb1.max(dim=-1, keepdim=True)[0]  # bs, 64, 1
            emb2 = self.global_conv2(f[-1])
            emb2 = emb2.max(dim=-1, keepdim=True)[0]  # bs, 128, 1
            cls_one_hot = torch.zeros((B, self.num_classes), device=p[0].device)
            cls_one_hot = cls_one_hot.scatter_(1, cls_label, 1).unsqueeze(-1)
            cls_one_hot = torch.cat((emb1, emb2, cls_one_hot), dim=1)
            cls_one_hot = cls_one_hot.expand(-1, -1, N)
        elif self.cls_map == 'pointnet2':
            cls_one_hot = torch.zeros((B, self.num_classes), device=p[0].device)
            cls_one_hot = cls_one_hot.scatter_(1, cls_label, 1).unsqueeze(-1).repeat(1, 1, N)
            cls_one_hot = self.convc(cls_one_hot)
        elif self.cls_map == 'pointnext':
            emb1 = self.global_conv1(f[-2])
            emb1 = emb1.max(dim=-1, keepdim=True)[0]  # bs, 64, 1
            emb2 = self.global_conv2(f[-1])
            emb2 = emb2.max(dim=-1, keepdim=True)[0]  # bs, 128, 1
            self.cls2partembed = self.cls2partembed.to(p[0].device)
            # reshape(-1) instead of squeeze(): squeeze() would also drop the
            # batch axis when B == 1 and break the concatenation below
            cls_one_hot = self.cls2partembed[cls_label.reshape(-1)].unsqueeze(-1)
            cls_one_hot = torch.cat((emb1, emb2, cls_one_hot), dim=1)
            cls_one_hot = cls_one_hot.expand(-1, -1, N)
        elif self.cls_map == 'pointnext1':
            self.cls2partembed = self.cls2partembed.to(p[0].device)
            # reshape(-1) keeps the batch axis even for a single-sample batch
            cls_one_hot = self.cls2partembed[cls_label.reshape(-1)].unsqueeze(-1).expand(-1, -1, N)
            cls_one_hot = self.convc(cls_one_hot)
        else:
            # without an embedding the last decoding step below cannot run
            raise ValueError(f"unsupported cls_map: {self.cls_map!r}")

        # all stages except the first (shallowest) one, deepest level first
        for i in range(-1, -len(self.decoder), -1):
            f[i - 1] = self.decoder[i][1:](
                [p[i-1], self.decoder[i][0]([p[i - 1], f[i - 1]], [p[i], f[i]])])[1]

        # TODO: study where to add this ?
        # last step: the category embedding is concatenated to the skip features f[1]
        f[-len(self.decoder) - 1] = self.decoder[0][1:](
            [p[1], self.decoder[0][0]([p[1], torch.cat([cls_one_hot, f[1]], 1)], [p[2], f[2]])])[1]

        return f[-len(self.decoder) - 1]

def pad_tensor(tensor, new_size):
    """Zero-pad `tensor` along dimension 1 up to length `new_size`.

    Args:
        tensor: tensor with at least two dimensions.
        new_size: requested length of dimension 1.

    Returns:
        A new tensor with zeros appended after the original values along
        dimension 1, keeping the dtype and device of `tensor`; or `tensor`
        itself (not a copy) when dimension 1 is already at least `new_size` long.
    """
    # number of entries that still have to be appended
    padding_size = new_size - tensor.size(1)
    if padding_size <= 0:
        return tensor
    pad_shape = list(tensor.shape)
    pad_shape[1] = padding_size
    # new_zeros inherits dtype and device from `tensor`, so the concatenation
    # does not silently promote e.g. half or integer inputs to float32
    padding = tensor.new_zeros(pad_shape)
    return torch.cat((tensor, padding), dim=1)


def upsample_linear(point_cloud, new_size):
    """Stretch a flattened tensor by linear interpolation and zero-pad it to `new_size`.

    All values of `point_cloud` are flattened into a single 1-D signal, which
    is upsampled by the integer factor
    `new_size // (point_cloud.size(0) * point_cloud.size(1))`. Entries still
    missing after interpolation are filled with zeros at the end.

    Args:
        point_cloud: floating-point tensor of shape (C, N).
        new_size: requested length of the output signal. It should be at least
            C * N; a smaller value gives an integer factor of 0, which
            interpolation is not expected to accept.

    Returns:
        Tensor of shape (1, new_size) on the same device as `point_cloud`.
    """
    # integer upsampling factor; the remainder is covered by the zero padding below
    scale_factor = new_size // (point_cloud.size(1) * point_cloud.size(0))

    # F.interpolate in 'linear' mode expects (batch, channel, length)
    signal = point_cloud.contiguous().view(1, -1).unsqueeze(1)
    upsampled = F.interpolate(signal, scale_factor=scale_factor,
                              mode='linear', align_corners=False).squeeze(1)

    missing = new_size - upsampled.size(1)
    if missing > 0:
        # append zeros on the right of the last dimension
        upsampled = F.pad(upsampled, (0, missing))
    return upsampled