import torch
import torch.nn as nn
import torch.nn.functional as F
import timm
from timm.models.layers import DropPath, trunc_normal_
import numpy as np
from .build import MODELS
from utils import misc
from utils.checkpoint import get_missing_parameters_message, get_unexpected_parameters_message
from utils.logger import *
import random
from knn_cuda import KNN
from extensions.chamfer_dist import ChamferDistanceL1, ChamferDistanceL2
#from models.GPT import GPT_extractor, GPT_generator
import math
from models.z_order import *

class Encoder_large(nn.Module):  # Embedding module
    """Embed every point patch into a single feature vector (large variant).

    A shared per-point MLP (kernel-size-1 ``Conv1d``) is applied to each patch.
    Its output is max-pooled over the patch, broadcast back and concatenated
    with the per-point features. A second MLP and another max-pool then give
    one vector per patch.

    Shapes:
        input  ``point_groups``: (B, G, N, 3)
        output: (B, G, encoder_channel)
    """

    def __init__(self, encoder_channel):
        super().__init__()
        self.encoder_channel = encoder_channel
        stage_one = [
            nn.Conv1d(3, 256, 1),
            nn.BatchNorm1d(256),
            nn.ReLU(inplace=True),
            nn.Conv1d(256, 512, 1),
            nn.BatchNorm1d(512),
            nn.ReLU(inplace=True),
            nn.Conv1d(512, 1024, 1),
        ]
        self.first_conv = nn.Sequential(*stage_one)
        stage_two = [
            nn.Conv1d(2048, 2048, 1),
            nn.BatchNorm1d(2048),
            nn.ReLU(inplace=True),
            nn.Conv1d(2048, self.encoder_channel, 1),
        ]
        self.second_conv = nn.Sequential(*stage_two)

    def forward(self, point_groups):
        '''
            point_groups : B G N 3
            -----------------
            feature_global : B G C
        '''
        batch, groups, npts, _ = point_groups.shape
        # Conv1d expects (batch, channels, length): fold B and G together, put xyz on the channel axis
        pts = point_groups.reshape(batch * groups, npts, 3).permute(0, 2, 1)
        local_feat = self.first_conv(pts)  # BG 1024 N
        pooled = local_feat.max(dim=2, keepdim=True).values  # BG 1024 1
        fused = torch.cat([pooled.expand(-1, -1, npts), local_feat], dim=1)  # BG 2048 N
        patch_feat = self.second_conv(fused).max(dim=2).values  # BG C
        return patch_feat.reshape(batch, groups, self.encoder_channel)

class Encoder_small(nn.Module):  # Embedding module
    """Embed every point patch into a single feature vector (small variant).

    A shared per-point MLP (kernel-size-1 ``Conv1d``) produces 256-d features.
    Their max over the patch is broadcast back and concatenated (512-d). A
    second MLP followed by a max-pool over the patch gives one vector per patch.

    Shapes:
        input  ``point_groups``: (B, G, N, 3)
        output: (B, G, encoder_channel)
    """

    def __init__(self, encoder_channel):
        super().__init__()
        self.encoder_channel = encoder_channel
        stage_one = [
            nn.Conv1d(3, 128, 1),
            nn.BatchNorm1d(128),
            nn.ReLU(inplace=True),
            nn.Conv1d(128, 256, 1),
        ]
        self.first_conv = nn.Sequential(*stage_one)
        stage_two = [
            nn.Conv1d(512, 512, 1),
            nn.BatchNorm1d(512),
            nn.ReLU(inplace=True),
            nn.Conv1d(512, self.encoder_channel, 1),
        ]
        self.second_conv = nn.Sequential(*stage_two)

    def forward(self, point_groups):
        '''
            point_groups : B G N 3
            -----------------
            feature_global : B G C
        '''
        batch, groups, npts, _ = point_groups.shape
        # Conv1d expects (batch, channels, length): fold B and G together, put xyz on the channel axis
        pts = point_groups.reshape(batch * groups, npts, 3).permute(0, 2, 1)
        local_feat = self.first_conv(pts)  # BG 256 N
        pooled = local_feat.max(dim=2, keepdim=True).values  # BG 256 1
        fused = torch.cat([pooled.expand(-1, -1, npts), local_feat], dim=1)  # BG 512 N
        patch_feat = self.second_conv(fused).max(dim=2).values  # BG C
        return patch_feat.reshape(batch, groups, self.encoder_channel)


class Group(nn.Module):
    """Split a point cloud into ``num_group`` patches of ``group_size`` points.

    Patch centers are chosen with ``misc.fps``. Each center gathers its
    ``group_size`` nearest neighbours with ``KNN``. Patches and centers are then
    reordered so that consecutive patches are spatially close.
    """

    def __init__(self, num_group, group_size):
        super().__init__()
        self.num_group = num_group
        self.group_size = group_size
        self.knn = KNN(k=self.group_size, transpose_mode=True)
        # NOTE(review): knn_2 is not used anywhere in this class.
        self.knn_2 = KNN(k=1, transpose_mode=True)

    def simplied_morton_sorting(self, xyz, center):
        '''
        Simplifying the Morton code sorting to iterate and set the nearest patch to the last patch as the next patch, we found this to be more efficient.

        Starting from patch 0, the next patch is always the closest patch that
        has not been visited yet, measured from the most recently chosen patch.

        Args:
            xyz: (B, N, 3) point cloud; only its batch size and device are used.
            center: (B, G, 3) patch centers, G == ``self.num_group``.

        Returns:
            LongTensor of shape (B * G,): indices into the flattened (B * G)
            patch axis, batch-major (batch b occupies entries b*G .. b*G+G-1).
        '''
        batch_size = xyz.shape[0]
        num_group = self.num_group
        inf = float("inf")

        # distances[b, i, j] = |center_i - center_j|; rows are the "from" patch,
        # columns are the candidates.
        distances = torch.cdist(center, center)
        # a patch is never its own neighbour
        distances.diagonal(dim1=1, dim2=2).fill_(inf)
        # patch 0 is the starting point, so it can never be chosen again
        distances[:, :, 0] = inf

        batch_idx = torch.arange(batch_size, device=distances.device)
        current = torch.zeros(batch_size, dtype=torch.long, device=distances.device)
        order = [current]
        for _ in range(num_group - 1):
            # distances from the last chosen patch to every candidate: (B, G)
            current = torch.argmin(distances[batch_idx, current], dim=-1)
            order.append(current)
            # mark the chosen patch as visited by masking its column; this
            # replaces the original transpose/copy round trip with an in-place write
            distances[batch_idx, :, current] = inf

        sorted_local = torch.stack(order, dim=-1)  # B G
        idx_base = torch.arange(batch_size, device=distances.device).view(-1, 1) * num_group
        return (sorted_local + idx_base).view(-1).to(xyz.device)

    def morton_sorting(self, xyz, center):
        '''
        Order patches by the z-order (Morton) value of their centers.

        Drop-in alternative to ``simplied_morton_sorting``: it returns indices in
        the same flattened, batch-major (B * G,) format. Runs ``get_z_values`` on
        the CPU, once per batch element.
        '''
        batch_size = xyz.shape[0]
        per_batch = []
        for index in range(batch_size):
            points = center[index]
            z = np.asarray(get_z_values(points.detach().cpu().numpy()))
            per_batch.append(np.argsort(z[:self.num_group]))
        all_indices = torch.from_numpy(np.stack(per_batch)).to(
            device=xyz.device, dtype=torch.long)

        idx_base = torch.arange(
            0, batch_size, device=xyz.device).view(-1, 1) * self.num_group
        sorted_indices = (all_indices + idx_base).view(-1)
        return sorted_indices

    def forward(self, xyz):
        '''
            input: B N 3
            ---------------------------
            output: B G M 3
            center : B G 3

        The returned neighborhoods are expressed relative to their centers, and
        both outputs are reordered by ``simplied_morton_sorting``.
        '''
        batch_size, num_points, _ = xyz.shape
        # fps the centers out
        center = misc.fps(xyz, self.num_group)  # B G 3
        # knn to get the neighborhood
        _, idx = self.knn(xyz, center)  # B G M
        assert idx.size(1) == self.num_group
        assert idx.size(2) == self.group_size
        # offset per-batch indices so they address the flattened (B*N) point axis
        idx_base = torch.arange(
            0, batch_size, device=xyz.device).view(-1, 1, 1) * num_points
        idx = (idx + idx_base).view(-1)
        neighborhood = xyz.view(batch_size * num_points, -1)[idx, :]
        neighborhood = neighborhood.view(
            batch_size, self.num_group, self.group_size, 3).contiguous()
        # normalize: express each patch relative to its center
        neighborhood = neighborhood - center.unsqueeze(2)

        # morton_sorting can be used instead; it returns the same index format
        sorted_indices = self.simplied_morton_sorting(xyz, center)

        neighborhood = neighborhood.view(
            batch_size * self.num_group, self.group_size, 3)[sorted_indices]
        neighborhood = neighborhood.view(
            batch_size, self.num_group, self.group_size, 3).contiguous()
        center = center.view(batch_size * self.num_group, 3)[sorted_indices]
        center = center.view(batch_size, self.num_group, 3).contiguous()

        return neighborhood, center


class Block(nn.Module):
    """Transformer block shared by the GPT extractor and generator.

    Inputs are sequence-first ``(L, B, E)`` because ``nn.MultiheadAttention`` is
    built without ``batch_first``. When ``last_block`` is True the block also
    scores the tokens and samples a boolean token mask from those scores.
    """

    def __init__(self, embed_dim, num_heads):
        super(Block, self).__init__()
        self.ln_1 = nn.LayerNorm(embed_dim)
        self.ln_2 = nn.LayerNorm(embed_dim)
        self.attn = nn.MultiheadAttention(embed_dim, num_heads)
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, embed_dim * 4),
            nn.GELU(),
            nn.Linear(embed_dim * 4, embed_dim),
        )
        self.num_heads = num_heads
        self.embed_dim = embed_dim

    def score_assignment_step(self, attn, v):
        """
        Token Score Assignment Step.

        :param attn: head-averaged attention weights [B x T x T] (query, key),
            as returned by ``nn.MultiheadAttention`` with ``need_weights=True``.
        :param v: token features [B x T x E]; their L2 norm scales the score.
        :return: significance scores [B x (T-2)] for tokens 2..T-1 (unsorted;
            the first two sequence positions are dropped).
        """
        v_norm = torch.linalg.norm(v, ord=2, dim=2)  # [B x T]
        # column 0: attention every query pays to the first token of the sequence
        significance_score = attn[:, :, 0] * v_norm  # [B x T]
        return significance_score[:, 2:]  # [B x T-2]

    def sampling(self, significance_score, mask_ratio):
        """
        Sample a boolean token mask from significance scores (Gumbel-max trick).

        ``log_softmax(score / 0.5)`` is perturbed with Gumbel noise and sorted
        ascending. The ``int(mask_ratio * K)`` lowest entries (clamped to
        ``[0, K]``) are set to True, so low-scoring tokens are masked more often.

        :param significance_score: [B x K] scores.
        :param mask_ratio: fraction of the K tokens to mark.
        :return: bool mask [B x K].
        """
        B, K = significance_score.shape
        device = significance_score.device
        # clamp so a negative ratio cannot turn the slice below into "all but a few"
        mask_num = min(max(int(mask_ratio * K), 0), K)

        temperature = 0.5
        # log_softmax stays finite where log(softmax(...)) would underflow to -inf
        log_probs = torch.log_softmax(significance_score / temperature, dim=1)
        beta = torch.rand(B, K, device=device)
        # keep beta away from 0 so the Gumbel noise is finite
        beta = beta.clamp_min(torch.finfo(beta.dtype).tiny)
        gumbel = -torch.log(-torch.log(beta))
        perturbed = log_probs + gumbel

        tokens_to_pick_ind = torch.argsort(perturbed, dim=1)[:, :mask_num]
        mask = torch.zeros_like(significance_score, dtype=torch.bool)
        mask[torch.arange(B, device=device).view(-1, 1), tokens_to_pick_ind] = True
        return mask

    def forward(self, x, attn_mask, mask_ratio, last_block):
        """Run attention + MLP on ``x`` of shape (L, B, E).

        Returns ``(x, mask, score)``. ``mask`` and ``score`` are None unless
        ``last_block`` is True, in which case they have shape (B, L-2).
        """
        # NOTE: x is replaced by its normalised version, so the attention
        # residual is added to ln_1(x) rather than to the raw input. Kept as-is:
        # switching to the standard pre-LN residual would change the function
        # computed by already-trained weights.
        x = self.ln_1(x)
        # attention weights are only needed to score tokens in the last block;
        # skipping them elsewhere avoids materialising a [B x L x L] matrix.
        a, attn_weights = self.attn(x, x, x, attn_mask=attn_mask, need_weights=bool(last_block))
        mask = score = None
        if last_block:
            score = self.score_assignment_step(attn_weights, x.permute(1, 0, 2))
            mask = self.sampling(score, mask_ratio)
        x = x + a
        x = x + self.mlp(self.ln_2(x))
        return x, mask, score


class GPT_extractor(nn.Module):
    """Causal transformer that encodes an ordered sequence of patch tokens.

    A learned SOS token is prepended. With ``classify=True`` the extractor also
    builds a feature from token 1 plus a max-pool over tokens 2.. and feeds it
    to a classification head.
    """

    def __init__(
        self, embed_dim, num_heads, num_layers, num_classes, trans_dim, group_size, pretrained=False
    ):
        super(GPT_extractor, self).__init__()

        self.embed_dim = embed_dim
        self.trans_dim = trans_dim
        self.group_size = group_size
        self.depth = num_layers

        # start of sequence token
        self.sos = torch.nn.Parameter(torch.zeros(embed_dim))
        nn.init.normal_(self.sos)

        self.layers = nn.ModuleList()
        for _ in range(num_layers):
            self.layers.append(Block(embed_dim, num_heads))

        self.ln_f = nn.LayerNorm(embed_dim)
        # prediction head
        # NOTE(review): not used by forward() in this class, but it registers
        # parameters, so removing it would change the state-dict keys.
        self.increase_dim = nn.Sequential(
            nn.Conv1d(self.trans_dim, 3*(self.group_size), 1)
        )

        if pretrained == False:
            self.cls_head_finetune = nn.Sequential(
                nn.Linear(self.trans_dim * 2, 256),
                nn.BatchNorm1d(256),
                nn.ReLU(inplace=True),
                nn.Dropout(0.5),
                nn.Linear(256, 256),
                nn.BatchNorm1d(256),
                nn.ReLU(inplace=True),
                nn.Dropout(0.5),
                nn.Linear(256, num_classes)
            )

            self.cls_norm = nn.LayerNorm(self.trans_dim)

    def forward(self, h, pos, attn_mask, mask_ratio, classify=False):
        """
        Run the causal transformer over ``h``.

        Args:
            h: (B, L, C) input tokens.
            pos: (B, S, C) positional embeddings added before every layer, where
                S is the sequence length after the SOS token is handled
                (S == L when ``classify`` is False, L + 1 otherwise).
            attn_mask: (S, S) attention mask forwarded to every Block.
            mask_ratio: forwarded to the last Block's token sampling.
            classify: also compute classification outputs.

        Returns:
            ``classify=False``: encoded tokens (B, S, C).
            ``classify=True``: ``(concat_f, encoded_points, logits, mask, score)``.
            ``mask``/``score`` come from the last Block and are None when there
            are no layers.
        """
        batch, _, _ = h.shape

        # sequence-first layout expected by nn.MultiheadAttention
        h = h.transpose(0, 1)
        pos = pos.transpose(0, 1)

        # prepend sos token
        sos = self.sos.view(1, 1, -1).expand(1, batch, -1)
        if not classify:
            # shift right: the last input token is dropped to make room for SOS
            h = torch.cat([sos, h[:-1, :, :]], dim=0)
        else:
            h = torch.cat([sos, h], dim=0)

        # transformer; only the last layer scores tokens and samples a mask
        mask = score = None
        for idx, layer in enumerate(self.layers):
            h = h + pos
            h, mask, score = layer(h, attn_mask, mask_ratio, idx == self.depth - 1)

        h = self.ln_f(h)

        encoded_points = h.transpose(0, 1)
        if not classify:
            return encoded_points

        h = self.cls_norm(encoded_points)
        # token 0 is SOS; token 1 plus a max-pool over tokens 2.. feed the head
        concat_f = torch.cat([h[:, 1], h[:, 2:].max(1)[0]], dim=-1)
        ret = self.cls_head_finetune(concat_f)
        return concat_f, encoded_points, ret, mask, score


class GPT_generator(nn.Module):
    """Causal transformer head that decodes token features into point patches."""

    def __init__(
        self, embed_dim, num_heads, num_layers, trans_dim, group_size
    ):
        super(GPT_generator, self).__init__()

        self.embed_dim = embed_dim
        self.trans_dim = trans_dim
        self.group_size = group_size

        # NOTE(review): sos is never read by forward(); it is still a registered parameter.
        self.sos = torch.nn.Parameter(torch.zeros(embed_dim))
        nn.init.normal_(self.sos)

        self.layers = nn.ModuleList(
            [Block(embed_dim, num_heads) for _ in range(num_layers)]
        )

        self.ln_f = nn.LayerNorm(embed_dim)
        # kernel-size-1 Conv1d: one linear map from trans_dim to group_size xyz triples
        self.increase_dim = nn.Sequential(
            nn.Conv1d(self.trans_dim, 3 * self.group_size, 1)
        )

    def forward(self, h, pos, attn_mask):
        """
        Decode per-token point patches.

        Args:
            h: (B, L, C) token features.
            pos: (B, L, C) positional embeddings, re-added before every layer.
            attn_mask: attention mask forwarded unchanged to every Block.

        Returns:
            (B * L, group_size, 3) tensor of generated points.
        """
        batch, length, _ = h.shape

        tokens = h.transpose(0, 1)  # L B C
        pos_seq = pos.transpose(0, 1)  # L B C

        for blk in self.layers:
            # mask_ratio 0 and last_block False: no token scoring or sampling here
            tokens, _, _ = blk(tokens + pos_seq, attn_mask, 0, False)

        tokens = self.ln_f(tokens)

        # The conv treats the sequence axis as its batch and the batch axis as
        # its length; with kernel size 1 this is a per-token linear map.
        coords = self.increase_dim(tokens.transpose(1, 2))  # L 3*gs B
        coords = coords.transpose(1, 2).transpose(0, 1)  # B L 3*gs
        return coords.reshape(batch * length, -1, 3)


class PositionEmbeddingCoordsSine(nn.Module):
    """Sinusoidal encoding of continuous coordinates.

    Each of the ``n_dim`` coordinates is expanded into ``num_pos_feats``
    alternating sin/cos features whose frequencies follow a geometric
    progression with base ``temperature``. Unused output channels are padded
    with zeros, so the result always has ``d_model`` channels.

    Args:
        n_dim: Number of input dimensions, e.g. 2 for image coordinates.
        d_model: Number of dimensions to encode into.
        temperature: base of the frequency progression (1.0 makes every
            frequency equal to 1).
        scale: coordinate multiplier (combined with 2*pi); ``None`` means 1.0.
    """

    def __init__(self, n_dim: int = 1, d_model: int = 256, temperature=10000, scale=None):
        super().__init__()

        self.n_dim = n_dim
        # largest even feature count per coordinate, so sin/cos pairs line up
        self.num_pos_feats = (d_model // n_dim) // 2 * 2
        self.temperature = temperature
        self.padding = d_model - self.n_dim * self.num_pos_feats
        self.scale = (1.0 if scale is None else scale) * 2 * math.pi

    def forward(self, xyz: torch.Tensor) -> torch.Tensor:
        """
        Args:
            xyz: Point positions (*, d_in)

        Returns:
            pos_emb (*, d_out)
        """
        assert xyz.shape[-1] == self.n_dim

        feat_idx = torch.arange(self.num_pos_feats,
                                dtype=torch.float32, device=xyz.device)
        # features 2j and 2j+1 share one frequency
        pair_idx = torch.div(feat_idx, 2, rounding_mode='trunc')
        freqs = self.temperature ** (2 * pair_idx / self.num_pos_feats)

        scaled = xyz * self.scale
        angles = scaled.unsqueeze(-1) / freqs  # (*, n_dim, num_pos_feats)
        interleaved = torch.stack(
            (angles[..., 0::2].sin(), angles[..., 1::2].cos()), dim=-1)
        pos_emb = interleaved.reshape(*scaled.shape[:-1], -1)

        # zero-fill the channels that do not divide evenly across coordinates
        return F.pad(pos_emb, (0, self.padding))


@MODELS.register_module()
class PointGPT(nn.Module):
    """PointGPT: patch tokenizer, causal GPT extractor and patch generator.

    ``forward`` in ``'standard'`` mode classifies the cloud and computes a
    Chamfer (L1 + L2) loss for next-patch generation. Any other mode reorders
    tokens with a caller-supplied visibility ``mask`` and only classifies.
    """

    def __init__(self, config, **kwargs):
        super().__init__()
        self.config = config

        self.trans_dim = config.trans_dim
        self.depth = config.depth
        self.decoder_depth = config.decoder_depth
        self.drop_path_rate = config.drop_path_rate
        self.cls_dim = config.cls_dim
        self.num_heads = config.num_heads

        self.group_size = config.group_size
        self.num_group = config.num_group
        self.encoder_dims = config.encoder_dims

        self.group_divider = Group(
            num_group=self.num_group, group_size=self.group_size)

        assert self.encoder_dims in [384, 768, 1024]
        if self.encoder_dims == 384:
            self.encoder = Encoder_small(encoder_channel=self.encoder_dims)
        else:
            self.encoder = Encoder_large(encoder_channel=self.encoder_dims)

        # NOTE(review): the third positional argument is ``temperature`` (not
        # ``scale``), so every sinusoid frequency here is 1.0. Kept unchanged
        # because trained weights expect this embedding.
        self.pos_embed = PositionEmbeddingCoordsSine(3, self.encoder_dims, 1.0)

        self.blocks = GPT_extractor(
            embed_dim=self.encoder_dims,
            num_heads=self.num_heads,
            num_layers=self.depth,
            num_classes=config.cls_dim,
            trans_dim=self.trans_dim,
            group_size=self.group_size
        )

        self.generator_blocks = GPT_generator(
            embed_dim=self.encoder_dims,
            num_heads=self.num_heads,
            num_layers=self.decoder_depth,
            trans_dim=self.trans_dim,
            group_size=self.group_size
        )

        self.cls_token = nn.Parameter(torch.zeros(1, 1, self.trans_dim))
        self.cls_pos = nn.Parameter(torch.randn(1, 1, self.trans_dim))

        self.mask_ratio = config.mask_ratio
        self.mask_token = nn.Parameter(torch.zeros(1, 1, self.trans_dim))

        self.sos_pos = nn.Parameter(torch.zeros(1, 1, self.trans_dim))

        self.norm = nn.LayerNorm(self.trans_dim)

        self.build_loss_func()

        self.use_sample_contras_loss = config.use_sample_contras_loss
        self.use_class_contras_loss = config.use_class_contras_loss

        trunc_normal_(self.cls_token, std=.02)
        trunc_normal_(self.cls_pos, std=.02)
        trunc_normal_(self.mask_token, std=.02)

    def build_loss_func(self, loss_type='cdl12'):
        """Create the classification and Chamfer reconstruction losses."""
        self.loss_ce = nn.CrossEntropyLoss()
        if loss_type == "cdl1":
            self.loss_func_p = ChamferDistanceL1().cuda()
        elif loss_type == 'cdl2':
            self.loss_func_p = ChamferDistanceL2().cuda()
        elif loss_type == 'cdl12':
            self.loss_func_p1 = ChamferDistanceL1().cuda()
            self.loss_func_p2 = ChamferDistanceL2().cuda()
        else:
            raise NotImplementedError

    def class_contras_loss(self, feat1, feat2, feature_class, gt, logit_scale=0.7):
        """Cross-entropy of cosine similarities between features and class prototypes.

        Averages the loss for ``feat1`` and ``feat2``; rows of ``feature_class``
        are indexed by the labels in ``gt``.
        """
        class_dir = F.normalize(feature_class, dim=1).T
        logits_S = (F.normalize(feat1, dim=1) @ class_dir) / logit_scale
        logits_M = (F.normalize(feat2, dim=1) @ class_dir) / logit_scale
        loss = (F.cross_entropy(logits_S, gt.long()) + F.cross_entropy(logits_M, gt.long())) / 2
        return loss

    def contras_loss(self, feat1, feat2, gt, logit_scale=0.7):
        """Symmetric sample-level contrastive loss between ``feat1`` and ``feat2``.

        The positive for sample i is the (feat1_i, feat2_i) pair on the diagonal;
        off-diagonal entries are feat1-feat1 similarities. Other samples with the
        same label are excluded as negatives by filling them with -1e9.
        """
        logits_S = (F.normalize(feat1, dim=1) @ F.normalize(feat1, dim=1).T) / logit_scale
        logits_M = (F.normalize(feat1, dim=1) @ F.normalize(feat2, dim=1).T) / logit_scale
        mask_SM = torch.eye(logits_S.shape[0], dtype=torch.bool, device=logits_S.device)
        logits = torch.where(mask_SM, logits_M, logits_S)

        labels = torch.arange(logits.shape[0], device=logits.device)

        gt = gt.view(-1, 1)  # (batch_size, 1)
        mask = (gt == gt.T).to(logits.device)  # (batch_size, batch_size)
        mask.fill_diagonal_(False)
        logits = logits.masked_fill(mask, float('-1e9'))

        loss = (F.cross_entropy(logits, labels) + F.cross_entropy(logits.T, labels)) / 2
        return loss

    def get_standard_loss_acc(self, ret, gt):
        """Return (cross-entropy loss, accuracy in percent) for logits ``ret``."""
        loss = self.loss_ce(ret, gt.long())
        pred = ret.argmax(-1)
        acc = (pred == gt).sum() / float(gt.size(0))
        return loss, acc * 100

    def get_loss_acc(self, feature_s, feature_m, feature_class, ret_s, ret_m, gt, use_sample_loss):
        """Combined loss and accuracy (percent, computed from ``ret_s``).

        loss = CE(ret_s) [+ CE(ret_m) if given]
               [+ 0.6 * sample contrastive, if ``use_sample_loss``]
               [+ 0.1 * class contrastive, if ``self.use_class_contras_loss``]
        Note the sample term is gated by the ``use_sample_loss`` argument, not by
        ``self.use_sample_contras_loss``.
        """
        loss = self.loss_ce(ret_s, gt.long())

        if ret_m is not None:
            loss = loss + self.loss_ce(ret_m, gt.long())

        logit_scale_num = 0.5
        if use_sample_loss:
            a = 0.6
            sample_loss = self.contras_loss(feature_s, feature_m, gt, logit_scale_num)
            loss = a * sample_loss + loss
        if self.use_class_contras_loss:
            b = 0.1
            class_loss = self.class_contras_loss(feature_s, feature_m, feature_class, gt, logit_scale_num)
            loss = b * class_loss + loss

        pred = ret_s.argmax(-1)
        acc = (pred == gt).sum() / float(gt.size(0))
        return loss, acc * 100

    def load_model_from_ckpt(self, bert_ckpt_path):
        """Load weights from ``ckpt['base_model']``, or initialise from scratch.

        Occurrences of ``module.`` are removed from the keys, and
        ``GPT_Transformer.`` / ``base_model.`` prefixes are stripped. Loading is
        non-strict; missing and unexpected keys are logged. When the path is
        None, ``_init_weights`` is applied to every submodule instead.
        """
        if bert_ckpt_path is None:
            print_log('Training from scratch!!!', logger='Transformer')
            self.apply(self._init_weights)
            return

        # torch.load unpickles the file: only load trusted checkpoints.
        # map_location='cpu' avoids restoring tensors onto the device they were
        # saved from; load_state_dict copies them onto this model's parameters.
        ckpt = torch.load(bert_ckpt_path, map_location='cpu')
        base_ckpt = {k.replace("module.", ""): v for k,
                     v in ckpt['base_model'].items()}

        for k in list(base_ckpt.keys()):
            if k.startswith('GPT_Transformer'):
                base_ckpt[k[len('GPT_Transformer.'):]] = base_ckpt.pop(k)
            elif k.startswith('base_model'):
                base_ckpt[k[len('base_model.'):]] = base_ckpt.pop(k)

        incompatible = self.load_state_dict(base_ckpt, strict=False)

        if incompatible.missing_keys:
            print_log('missing_keys', logger='Transformer')
            print_log(
                get_missing_parameters_message(incompatible.missing_keys),
                logger='Transformer'
            )
        if incompatible.unexpected_keys:
            print_log('unexpected_keys', logger='Transformer')
            print_log(
                get_unexpected_parameters_message(
                    incompatible.unexpected_keys),
                logger='Transformer'
            )
        print_log(
            f'[Transformer] Successful Loading the ckpt from {bert_ckpt_path}', logger='Transformer')

    def _init_weights(self, m):
        """Truncated-normal (std 0.02) weights for Linear/Conv1d; identity LayerNorm."""
        if isinstance(m, (nn.Linear, nn.Conv1d)):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def _mask_center_rand(self, center, noaug=False):
        '''
            center : B G 3
            --------------
            mask : B G (bool)

        Marks ``int(mask_ratio * G)`` randomly chosen patches per sample as
        True (using NumPy's global RNG). Sets ``self.num_mask`` as a side effect.
        '''
        B, G, _ = center.shape
        # skip the mask (returned on center's device, like the random branch)
        if noaug or self.mask_ratio == 0:
            return torch.zeros(center.shape[:2], dtype=torch.bool, device=center.device)

        self.num_mask = int(self.mask_ratio * G)

        overall_mask = np.zeros([B, G])
        for i in range(B):
            mask = np.hstack([
                np.zeros(G - self.num_mask),
                np.ones(self.num_mask),
            ])
            np.random.shuffle(mask)
            overall_mask[i, :] = mask

        return torch.from_numpy(overall_mask).to(device=center.device, dtype=torch.bool)  # B G

    @staticmethod
    def _causal_mask(size, device):
        """Boolean (size, size) mask, True strictly above the diagonal.

        For ``nn.MultiheadAttention`` a True entry forbids attention, so every
        position attends only to itself and earlier positions.
        """
        return torch.ones(size, size, dtype=torch.bool, device=device).triu(1)

    def forward(self, pts, mode='standard', mask=None, vis=False):
        """Group ``pts`` into patches, then run the standard or masked pipeline.

        Returns:
            ``mode == 'standard'``: ``(concat_f, logits, mask, chamfer_loss)``,
            or ``(absolute_neighborhood, center, score)`` when ``vis`` is True.
            Otherwise: ``(concat_f, logits, mask, None)``; ``mask`` (B, G) bool is
            required as input in that case.
            The returned ``mask`` is the one sampled by the extractor's last block.
        """
        neighborhood, center = self.group_divider(pts)
        group_input_tokens = self.encoder(neighborhood)  # B G C

        if mode == 'standard':
            return self._forward_standard(neighborhood, center, group_input_tokens, vis)
        return self._forward_masked(center, group_input_tokens, mask)

    def _forward_standard(self, neighborhood, center, group_input_tokens, vis):
        """Classify and compute the next-patch Chamfer reconstruction loss."""
        B, G, _ = group_input_tokens.size()
        device = group_input_tokens.device

        # positional sequence: [cls_pos, sos_pos, pos(center_1..G)]
        pos = torch.cat((
            self.cls_pos.expand(B, -1, -1),
            self.sos_pos.expand(B, -1, -1),
            self.pos_embed(center),
        ), dim=1)

        # generator positions: first absolute center, then unit directions between
        # consecutive centers (F.normalize keeps coincident centers from giving NaN)
        relative_direction = F.normalize(center[:, 1:, :] - center[:, :-1, :], dim=-1)
        position = torch.cat([center[:, :1, :], relative_direction], dim=1)
        pos_relative = self.pos_embed(position)

        x = torch.cat((self.cls_token.expand(B, -1, -1), group_input_tokens), dim=1)

        # transformer
        concat_f, encoded_features, ret, mask, score = self.blocks(
            x, pos, self._causal_mask(G + 2, device), self.mask_ratio, classify=True)

        # absolute patch coordinates (also what the visualisation path returns)
        neighborhood = neighborhood + center.unsqueeze(2)
        if vis:  # visualization
            return neighborhood, center, score

        # generator input: SOS output followed by outputs of patches 1..G-1
        # (skipping the class token and the last patch); output i is compared
        # against patch i below
        encoded_features = torch.cat(
            [encoded_features[:, :1, :], encoded_features[:, 2:-1, :]], dim=1)
        generated_points = self.generator_blocks(
            encoded_features, pos_relative, self._causal_mask(G, device))

        gt_points = neighborhood.reshape(
            B * (self.num_group), self.group_size, 3)

        loss1 = self.loss_func_p1(generated_points, gt_points)
        loss2 = self.loss_func_p2(generated_points, gt_points)
        return concat_f, ret, mask, loss1 + loss2

    def _forward_masked(self, center, group_input_tokens, mask):
        """Classify with visible tokens first and ``mask_token`` placeholders after."""
        if mask is None:
            raise ValueError("forward() requires `mask` when mode != 'standard'")
        B, G, C = group_input_tokens.size()

        vis_tokens = group_input_tokens[~mask].reshape(B, -1, C)
        M = G - vis_tokens.shape[1]
        mask_tokens = self.mask_token.expand(B, M, -1)
        full_tokens = torch.cat([vis_tokens, mask_tokens], dim=1)
        x = torch.cat((self.cls_token.expand(B, -1, -1), full_tokens), dim=1)

        # positions follow the same visible-then-masked order as the tokens
        vis_pos = self.pos_embed(center[~mask]).reshape(B, -1, C)
        mask_pos = self.pos_embed(center[mask]).reshape(B, -1, C)
        pos = torch.cat((
            self.cls_pos.expand(B, -1, -1),
            self.sos_pos.expand(B, -1, -1),
            vis_pos,
            mask_pos,
        ), dim=1)

        # transformer (returns 5 values when classify=True)
        concat_f, _, ret, sampled_mask, _ = self.blocks(
            x, pos, self._causal_mask(G + 2, group_input_tokens.device),
            self.mask_ratio, classify=True)

        # no reconstruction loss is computed in this mode
        return concat_f, ret, sampled_mask, None
