import numpy as np
import torch
import torch.nn.functional as F
import torch.nn as nn
from transformers import (
    RobertaModel, RobertaTokenizerFast, CLIPTextModel, CLIPTokenizerFast
)

from .backbone_module import Pointnet2Backbone
from .modules import ClassPredictHead, PositionEmbeddingLearned

from .encoder_decoder_layers import (
    PosTransformerEncoderLayer, BiDecoderLayer,
    BiEncoder, BiEncoderLayer
)

import ipdb
st = ipdb.set_trace


class GroupFreeGTModulator2(nn.Module):
    """
    Language-conditioned scorer over a given set of 3D boxes.

    Each box in ``inputs['all_bboxes']`` becomes one query: the box is
    embedded with ``PositionEmbeddingLearned`` (concatenated with a 32-d
    class embedding when ``use_gt_class`` is set), projected to
    ``d_model`` channels and refined by ``num_decoder_layers``
    ``DecoderLayerL`` layers that attend to the projected text tokens.
    A ``ClassPredictHead`` is applied after every layer.

    Args:
        num_class (int): forwarded to every ``ClassPredictHead``
        num_size_cluster (int): must equal ``mean_size_arr.shape[0]``
        mean_size_arr (ndarray): stored; only its first dim is checked
        bn_momentum (float): momentum written to BatchNorm1d/2d layers
        sync_bn (bool): convert BatchNorm layers to SyncBatchNorm
        dropout, activation, nhead, dim_feedforward: forwarded to the
            decoder layers
        num_decoder_layers (int): number of decoder layers (and heads)
        self_position_embedding (str): 'none', 'xyz_learned' or
            'loc_learned'
        text_encoder_type (str): 'roberta' or 'clip'
        d_model (int): channel width of queries and text tokens
        contrastive_align_loss (bool): also emit L2-normalized 64-d
            projections of text tokens and queries
        use_gt_class (bool): append a learned class embedding, looked up
            from ``inputs['all_classes']``, to every box embedding
        num_obj_classes (int): size of the class-embedding table
        freeze_text_encoder (bool): disable gradients of the text encoder

    All other arguments are only stored on the module (``sa_lang`` and
    ``sa_vis`` are ignored) by this class.
    """

    def __init__(self, num_class=485, num_heading_bin=1, num_size_cluster=485,
                 mean_size_arr=np.ones((485, 3)), input_feature_dim=0, width=2,
                 bn_momentum=0.1, sync_bn=False, num_proposal=512,
                 sampling='kps', dropout=0.1, activation="relu",
                 nhead=8, num_decoder_layers=12, dim_feedforward=2048,
                 self_position_embedding='loc_learned',
                 size_cls_agnostic=False, text_encoder_type='roberta',
                 d_model=288, contrastive_align_loss=False,
                 contrastive_hungarian=False, sa_lang=True, sa_vis=True,
                 use_gt_box=False, cross_attend=True,
                 gt_with_bbox_loss=False, gt_with_bbox_sampling=False,
                 use_gt_class=False, num_obj_classes=485, freeze_text_encoder=False):
        """Initialize layers."""
        super().__init__()

        self.num_class = num_class
        self.num_heading_bin = num_heading_bin
        self.num_size_cluster = num_size_cluster
        self.mean_size_arr = mean_size_arr
        assert (mean_size_arr.shape[0] == self.num_size_cluster)
        self.input_feature_dim = input_feature_dim
        self.num_proposal = num_proposal
        self.bn_momentum = bn_momentum
        self.sync_bn = sync_bn
        self.width = width
        self.nhead = nhead
        self.sampling = sampling
        self.num_decoder_layers = num_decoder_layers
        self.dim_feedforward = dim_feedforward
        self.self_position_embedding = self_position_embedding
        self.size_cls_agnostic = size_cls_agnostic
        self.d_model = d_model
        self.cross_attend = cross_attend
        self.contrastive_hungarian = contrastive_hungarian
        self.use_gt_box = use_gt_box
        self.gt_with_bbox_loss = gt_with_bbox_loss
        self.gt_with_bbox_sampling = gt_with_bbox_sampling
        self.use_gt_class = use_gt_class

        # Text Encoder
        if text_encoder_type == 'roberta':
            t_type = "roberta-base"
            self.tokenizer = RobertaTokenizerFast.from_pretrained(t_type)
            self.text_encoder = RobertaModel.from_pretrained(t_type)
        elif text_encoder_type == 'clip':
            t_type = "openai/clip-vit-base-patch32"
            self.tokenizer = CLIPTokenizerFast.from_pretrained(t_type)
            self.text_encoder = CLIPTextModel.from_pretrained(t_type)
        else:
            raise NotImplementedError(
                f"Unsupported text_encoder_type: {text_encoder_type!r}"
            )

        if freeze_text_encoder:
            print("Freezing text encoder")
            for p in self.text_encoder.parameters():
                p.requires_grad = False

        # This class has no ``use_detected_boxes`` option: reading that
        # attribute here raised AttributeError, so only ``use_gt_class``
        # decides whether class embeddings widen the query features.
        if self.use_gt_class:
            self.class_embeddings = nn.Embedding(num_obj_classes, 32)
            d_model_query = d_model + 32
        else:
            d_model_query = d_model

        self.text_projector = nn.Sequential(
            nn.Linear(self.text_encoder.config.hidden_size, d_model),
            nn.LayerNorm(d_model, eps=1e-12),
            nn.Dropout(0.1)
        )
        self.posembed = PositionEmbeddingLearned(6, d_model)

        self.contrastive_align_loss = contrastive_align_loss
        if contrastive_align_loss:
            self.contrastive_align_projection_image = nn.Linear(d_model, 64)
            self.contrastive_align_projection_text = nn.Linear(d_model, 64)

        # Proposal head, applied to the un-projected query features
        self.proposal_head = ClassPredictHead(num_class, d_model_query)

        # Transformer Decoder Projection
        self.decoder_query_proj = nn.Conv1d(
            d_model_query, d_model, kernel_size=1
        )

        # Transformer decoder layers
        self.decoder = nn.ModuleList()
        for _ in range(self.num_decoder_layers):
            self.decoder.append(DecoderLayerL(
                d_model, nhead, dim_feedforward, dropout, activation,
                self_position_embedding
            ))

        # Prediction Head (one per decoder layer)
        self.prediction_heads = nn.ModuleList()
        for _ in range(self.num_decoder_layers):
            self.prediction_heads.append(ClassPredictHead(
                num_class, d_model
            ))

        # Init
        # self.init_weights()
        self.init_bn_momentum()
        if self.sync_bn:
            nn.SyncBatchNorm.convert_sync_batchnorm(self)

    def _run_backbones(self, inputs):
        """
        Tokenize and encode ``inputs['text'].``

        Returns:
            dict with 'text_feats' (projected token features),
            'text_attention_mask' (True on padded tokens) and 'tokenized'
        """
        end_points = {}
        # Tokens are moved to the device of the point cloud tensor
        tokenized = self.tokenizer.batch_encode_plus(
            inputs['text'], padding="longest", return_tensors="pt"
        ).to(inputs['point_clouds'].device)
        encoded_text = self.text_encoder(**tokenized)
        text_feats = self.text_projector(encoded_text.last_hidden_state)
        # Invert attention mask that we get from huggingface
        # because it's the opposite in pytorch transformer
        text_attention_mask = tokenized.attention_mask.ne(1).bool()
        end_points['text_feats'] = text_feats
        end_points['text_attention_mask'] = text_attention_mask
        end_points['tokenized'] = tokenized
        return end_points

    def forward(self, inputs):
        """
        Forward pass.

        Args:
            inputs: dict with keys
                text (list): ['text0', 'text1', ...], len(text) = B
                point_clouds (tensor): only its device is read here
                all_bboxes (tensor): (B, V, 6) boxes, xyz then size
                all_bbox_label_mask (bool tensor): (B, V); it is
                    inverted to obtain the query padding mask
                all_classes (long tensor): (B, V), read only when
                    ``use_gt_class`` is set

        Returns:
            end_points: dict filled by the text encoder, the proposal
                head and the per-layer prediction heads

        Raises:
            NotImplementedError: unknown ``self_position_embedding``
        """
        # Run backbones
        end_points = self._run_backbones(inputs)
        text_feats = end_points['text_feats']  # (B, L, F)
        text_padding_mask = end_points['text_attention_mask']  # (B, L)
        end_points["text_memory"] = text_feats
        if self.contrastive_align_loss:
            end_points['proj_tokens'] = F.normalize(
                self.contrastive_align_projection_text(text_feats),
                p=2, dim=-1
            )

        # Query features: box embedding (+ class embedding if enabled)
        base_box = inputs['all_bboxes']
        base_xyz = base_box[:, :, :3]
        cluster_feature = self.posembed(base_box)
        if self.use_gt_class:
            # ``class_embeddings`` (and the wider ``decoder_query_proj``)
            # only exist when class features were enabled in __init__
            query_class_features = self.class_embeddings(
                inputs['all_classes']
            )  # (B, V, 32)
            cluster_feature = torch.cat((
                cluster_feature,
                query_class_features.transpose(1, 2)
            ), 1)

        # Transformer Decoder and Prediction
        query = self.decoder_query_proj(cluster_feature)
        query = query.transpose(1, 2).contiguous()  # (B, V, F)
        if self.contrastive_align_loss:
            end_points['proposal_proj_queries'] = F.normalize(
                self.contrastive_align_projection_image(query), p=2, dim=-1
            )

        # Proposals (one for each query)
        self.proposal_head(cluster_feature, end_points, prefix='proposal_')
        query_mask = ~inputs['all_bbox_label_mask']
        end_points['query_mask'] = query_mask

        # Position input for self-attention is the same for every layer
        if self.self_position_embedding == 'none':
            query_pos = None
        elif self.self_position_embedding == 'xyz_learned':
            query_pos = base_xyz
        elif self.self_position_embedding == 'loc_learned':
            query_pos = base_box
        else:
            raise NotImplementedError

        last_layer = self.num_decoder_layers - 1
        for i in range(self.num_decoder_layers):
            prefix = 'last_' if i == last_layer else f'{i}head_'

            # Transformer Decoder Layer
            query = self.decoder[i](
                query,
                text_feats, query_pos,
                query_mask,
                text_padding_mask
            )  # (B, V, F)

            if self.contrastive_align_loss:
                end_points[f'{prefix}proj_queries'] = F.normalize(
                    self.contrastive_align_projection_image(query),
                    p=2, dim=-1
                )

            # Prediction (the head writes into end_points)
            self.prediction_heads[i](
                query.transpose(1, 2).contiguous(),
                end_points, prefix
            )

        return end_points

    def init_weights(self):
        """Initialize transformer with xavier."""
        for m in self.decoder.parameters():
            if m.dim() > 1:
                nn.init.xavier_uniform_(m)

    def init_bn_momentum(self):
        """Initialize batch-norm momentum."""
        for m in self.modules():
            if isinstance(m, (nn.BatchNorm2d, nn.BatchNorm1d)):
                m.momentum = self.bn_momentum


class DecoderLayerL(nn.Module):
    """Decoder layer: query self-attention, then cross-attention to text."""

    def __init__(self, d_model, n_heads, dim_feedforward=2048, dropout=0.1,
                 activation="relu",
                 self_position_embedding='loc_learned',
                 use_oriented_boxes=False):
        """
        Build the sub-layers; d_model is the encoder dimension.

        ``dim_feedforward`` and ``activation`` are forwarded to the
        self-attention layer only; the language feed-forward block below
        uses a fixed hidden width of 1024 and ReLU.
        """
        super().__init__()

        # Self attention over the queries
        self.self_attention = PosTransformerEncoderLayer(
            d_model=d_model, nhead=n_heads,
            dim_feedforward=dim_feedforward,
            dropout=dropout, activation=activation
        )

        # Cross attention from queries to language tokens
        self.cross_l = nn.MultiheadAttention(
            d_model, n_heads, dropout=dropout
        )
        self.dropout_l = nn.Dropout(dropout)
        self.norm_l = nn.LayerNorm(d_model)
        ffn_stages = [
            nn.Linear(d_model, 1024),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(1024, d_model),
            nn.Dropout(dropout),
        ]
        self.ffn_l = nn.Sequential(*ffn_stages)
        self.norm_l2 = nn.LayerNorm(d_model)

        # Positional embedding: input width depends on the chosen scheme
        if self_position_embedding == 'xyz_learned':
            pos_dim = 3
        elif self_position_embedding == 'loc_learned':
            pos_dim = 9 if use_oriented_boxes else 6
        else:
            pos_dim = None
        if pos_dim is None:
            self.self_posembed = None
        else:
            self.self_posembed = PositionEmbeddingLearned(pos_dim, d_model)

    def forward(self, query, lang_feats, query_pos,
                padding_mask, text_key_padding_mask):
        """
        Forward pass.

        Args:
            query: (B, N, F)
            lang_feats: (B, L, F)
            query_pos: (B, N, 3or6), ignored when no position embedding
                was configured
            padding_mask: (B, N) (for query)
            text_key_padding_mask: (B, L)

        Returns:
            query: (B, N, F)
        """
        # Positional term, brought to the (B, N, F) layout of the query
        if self.self_posembed is None:
            pos = torch.zeros_like(query)
        else:
            pos = self.self_posembed(query_pos)
            pos = pos.transpose(1, 2).contiguous()

        # Self attention (sub-layer works sequence-first)
        attended = self.self_attention(
            query.transpose(0, 1),
            pos.transpose(0, 1),
            src_key_padding_mask=padding_mask
        )
        attended = attended.transpose(0, 1)

        # Cross attend to language; keys and values are the text tokens
        lang_seq_first = lang_feats.transpose(0, 1)
        lang_context, _ = self.cross_l(
            query=(attended + pos).transpose(0, 1),
            key=lang_seq_first,
            value=lang_seq_first,
            attn_mask=None,
            key_padding_mask=text_key_padding_mask  # (B, L)
        )
        lang_context = lang_context.transpose(0, 1)

        # Residual + norm, then feed-forward with residual + norm
        fused = self.norm_l(attended + self.dropout_l(lang_context))
        return self.norm_l2(fused + self.ffn_l(fused))


class GroupFreeGTModulator(nn.Module):
    """
    Language-conditioned scorer over a given set of 3D boxes.

    Each box in ``inputs['all_bboxes']`` becomes one query: the box is
    embedded with ``PositionEmbeddingLearned`` (concatenated with a 32-d
    class feature when ``use_gt_class`` or ``use_detected_boxes`` is
    set), projected to ``d_model`` channels and refined by
    ``num_decoder_layers`` ``DecoderLayerL`` layers that attend to the
    projected text tokens. After every layer a ``ClassPredictHead`` is
    applied and an L2-normalized 64-d projection of the queries is
    stored for contrastive alignment with the text tokens.

    Args:
        num_class (int): forwarded to every ``ClassPredictHead``
        num_size_cluster (int): must equal ``mean_size_arr.shape[0]``
        mean_size_arr (ndarray): stored; only its first dim is checked
        bn_momentum (float): momentum written to BatchNorm1d/2d layers
        sync_bn (bool): convert BatchNorm layers to SyncBatchNorm
        dropout, activation, nhead, dim_feedforward: forwarded to the
            decoder layers
        num_decoder_layers (int): number of decoder layers (and heads)
        self_position_embedding (str): 'none', 'xyz_learned' or
            'loc_learned'
        text_encoder_type (str): 'roberta' or 'clip'
        d_model (int): channel width of queries and text tokens
        use_gt_class / use_detected_boxes (bool): if either is set, a
            32-d class feature is appended to every box embedding
        use_logits (bool): class features come from a bias-free linear
            map of ``inputs['all_logits']`` instead of an embedding
            lookup of ``inputs['all_classes']``
        use_oriented_boxes (bool): boxes have 9 instead of 6 values
        num_obj_classes (int): input size of the class-feature layer
        freeze_text_encoder (bool): disable gradients of the text encoder

    All other arguments are only stored on the module or ignored by
    this class.
    """

    def __init__(self, num_class=485, num_heading_bin=1, num_size_cluster=485,
                 mean_size_arr=np.ones((485, 3)), input_feature_dim=0, width=2,
                 bn_momentum=0.1, sync_bn=False, num_proposal=512,
                 sampling='kps', dropout=0.1, activation="relu",
                 nhead=8, num_decoder_layers=12, dim_feedforward=2048,
                 self_position_embedding='loc_learned',
                 size_cls_agnostic=False, text_encoder_type='roberta',
                 d_model=288, contrastive_align_loss=False,
                 contrastive_hungarian=False, sa_lang=True, sa_vis=True,
                 use_gt_box=False, cross_attend=True,
                 gt_with_bbox_loss=False, gt_with_bbox_sampling=False,
                 use_gt_class=False, num_obj_classes=485,
                 freeze_text_encoder=False, use_logits=False,
                 use_oriented_boxes=False, use_detected_boxes=False):
        """Initialize layers."""
        super().__init__()

        self.num_class = num_class
        self.num_heading_bin = num_heading_bin
        self.num_size_cluster = num_size_cluster
        self.mean_size_arr = mean_size_arr
        assert (mean_size_arr.shape[0] == self.num_size_cluster)
        self.input_feature_dim = input_feature_dim
        self.num_proposal = num_proposal
        self.bn_momentum = bn_momentum
        self.sync_bn = sync_bn
        self.width = width
        self.nhead = nhead
        self.sampling = sampling
        self.num_decoder_layers = num_decoder_layers
        self.dim_feedforward = dim_feedforward
        self.self_position_embedding = self_position_embedding
        self.size_cls_agnostic = size_cls_agnostic
        self.d_model = d_model
        self.cross_attend = cross_attend
        self.contrastive_hungarian = contrastive_hungarian
        self.use_gt_box = use_gt_box
        self.use_detected_boxes = use_detected_boxes
        self.gt_with_bbox_loss = gt_with_bbox_loss
        self.gt_with_bbox_sampling = gt_with_bbox_sampling
        self.use_gt_class = use_gt_class
        self.use_logits = use_logits
        self.use_oriented_bboxes = use_oriented_boxes

        # Text Encoder
        if text_encoder_type == 'roberta':
            t_type = "roberta-base"
            self.tokenizer = RobertaTokenizerFast.from_pretrained(t_type)
            self.text_encoder = RobertaModel.from_pretrained(t_type)
        elif text_encoder_type == 'clip':
            t_type = "openai/clip-vit-base-patch32"
            self.tokenizer = CLIPTokenizerFast.from_pretrained(t_type)
            self.text_encoder = CLIPTextModel.from_pretrained(t_type)
        else:
            raise NotImplementedError(
                f"Unsupported text_encoder_type: {text_encoder_type!r}"
            )

        if freeze_text_encoder:
            print("Freezing text encoder")
            for p in self.text_encoder.parameters():
                p.requires_grad = False

        # Optional class features widen the un-projected query by 32
        if self.use_gt_class or self.use_detected_boxes:
            if use_logits:
                self.class_embeddings = nn.Linear(
                    num_obj_classes, 32, bias=False
                )
            else:
                self.class_embeddings = nn.Embedding(num_obj_classes, 32)
            d_model_query = d_model + 32
        else:
            d_model_query = d_model

        self.text_projector = nn.Sequential(
            nn.Linear(self.text_encoder.config.hidden_size, d_model),
            nn.LayerNorm(d_model, eps=1e-12),
            nn.Dropout(0.1)
        )

        if self.use_oriented_bboxes:
            self.posembed = PositionEmbeddingLearned(9, d_model)
        else:
            self.posembed = PositionEmbeddingLearned(6, d_model)

        # Transformer Decoder Projection
        self.decoder_query_proj = nn.Conv1d(
            d_model_query, d_model, kernel_size=1
        )

        # Transformer decoder layers
        self.decoder = nn.ModuleList()
        for _ in range(self.num_decoder_layers):
            self.decoder.append(DecoderLayerL(
                d_model, nhead, dim_feedforward, dropout, activation,
                self_position_embedding,
                use_oriented_boxes=self.use_oriented_bboxes
            ))

        # Prediction Head (one per decoder layer)
        self.prediction_heads = nn.ModuleList()
        for _ in range(self.num_decoder_layers):
            self.prediction_heads.append(ClassPredictHead(
                num_class, d_model
            ))
        # Built unconditionally: ``contrastive_align_loss`` is not
        # consulted by this class
        self.contrastive_align_projection_image = nn.Linear(d_model, 64)
        self.contrastive_align_projection_text = nn.Linear(d_model, 64)

        # Init
        # self.init_weights()
        self.init_bn_momentum()
        if self.sync_bn:
            nn.SyncBatchNorm.convert_sync_batchnorm(self)

    def _run_backbones(self, inputs):
        """
        Tokenize and encode ``inputs['text']``.

        Returns:
            dict with 'text_feats' (projected token features),
            'text_attention_mask' (True on padded tokens) and 'tokenized'
        """
        end_points = {}
        # Tokens are moved to the device of the point cloud tensor
        tokenized = self.tokenizer.batch_encode_plus(
            inputs['text'], padding="longest", return_tensors="pt"
        ).to(inputs['point_clouds'].device)
        encoded_text = self.text_encoder(**tokenized)
        text_feats = self.text_projector(encoded_text.last_hidden_state)
        # Invert attention mask that we get from huggingface
        # because it's the opposite in pytorch transformer
        text_attention_mask = tokenized.attention_mask.ne(1).bool()
        end_points['text_feats'] = text_feats
        end_points['text_attention_mask'] = text_attention_mask
        end_points['tokenized'] = tokenized
        return end_points

    def forward(self, inputs):
        """
        Forward pass.

        Args:
            inputs: dict with keys
                text (list): ['text0', 'text1', ...], len(text) = B
                point_clouds (tensor): only its device is read here
                all_bboxes (tensor): (B, V, 6 or 9) boxes, xyz first
                all_bbox_label_mask (bool tensor): (B, V); it is
                    inverted to obtain the query padding mask
                all_classes / all_logits (tensor): class input, read
                    only when class features are enabled

        Returns:
            end_points: dict filled by the text encoder and the
                per-layer prediction heads, plus 'proj_tokens' and
                '<prefix>proj_queries'

        Raises:
            NotImplementedError: unknown ``self_position_embedding``
        """
        # Run backbones
        end_points = self._run_backbones(inputs)
        text_feats = end_points['text_feats']  # (B, L, F)
        text_padding_mask = end_points['text_attention_mask']  # (B, L)
        end_points["text_memory"] = text_feats
        end_points['proj_tokens'] = F.normalize(
            self.contrastive_align_projection_text(text_feats), p=2, dim=-1
        )

        # Query features: box embedding (+ class features if enabled)
        base_box = inputs['all_bboxes']
        base_xyz = base_box[:, :, :3]
        cluster_feature = self.posembed(base_box)
        if self.use_gt_class or self.use_detected_boxes:
            # ``class_embeddings`` (and the wider ``decoder_query_proj``)
            # only exist under this condition, see __init__; calling it
            # unconditionally crashed with the default flags
            class_key = 'all_logits' if self.use_logits else 'all_classes'
            query_class_features = self.class_embeddings(
                inputs[class_key]
            )  # (B, V, 32)
            cluster_feature = torch.cat((
                cluster_feature,
                query_class_features.transpose(1, 2)
            ), 1)

        # Transformer Decoder and Prediction
        query = self.decoder_query_proj(cluster_feature)
        query = query.transpose(1, 2).contiguous()  # (B, V, F)

        query_mask = ~inputs['all_bbox_label_mask']
        end_points['query_mask'] = query_mask

        # Position input for self-attention is the same for every layer
        if self.self_position_embedding == 'none':
            query_pos = None
        elif self.self_position_embedding == 'xyz_learned':
            query_pos = base_xyz
        elif self.self_position_embedding == 'loc_learned':
            query_pos = base_box
        else:
            raise NotImplementedError

        last_layer = self.num_decoder_layers - 1
        for i in range(self.num_decoder_layers):
            prefix = 'last_' if i == last_layer else f'{i}head_'

            # Transformer Decoder Layer
            query = self.decoder[i](
                query,
                text_feats, query_pos,
                query_mask,
                text_padding_mask
            )  # (B, V, F)

            # Prediction (the head writes into end_points)
            self.prediction_heads[i](
                query.transpose(1, 2).contiguous(),
                end_points, prefix
            )
            end_points[f'{prefix}proj_queries'] = F.normalize(
                self.contrastive_align_projection_image(query), p=2, dim=-1
            )

        return end_points

    def init_weights(self):
        """Initialize transformer with xavier."""
        for m in self.decoder.parameters():
            if m.dim() > 1:
                nn.init.xavier_uniform_(m)

    def init_bn_momentum(self):
        """Initialize batch-norm momentum."""
        for m in self.modules():
            if isinstance(m, (nn.BatchNorm2d, nn.BatchNorm1d)):
                m.momentum = self.bn_momentum


class GroupFreeGTModulatorVis(GroupFreeGTModulator):
    """
    ``GroupFreeGTModulator`` variant whose queries also see the scene.

    On top of the parent model this class adds a ``Pointnet2Backbone``
    visual encoder (weights loaded non-strictly from a checkpoint), a
    3-layer ``BiEncoder`` that fuses point and text features when
    ``cross_attend`` is set, and replaces the decoder stack with
    ``BiDecoderLayer`` layers that receive both the point features and
    the text features.

    Args:
        input_feature_dim (int): forwarded to ``Pointnet2Backbone``
        width (int): forwarded to ``Pointnet2Backbone``
        sa_lang (bool): forwarded to ``BiEncoderLayer`` as
            ``self_attend_lang``
        sa_vis (bool): forwarded to ``BiEncoderLayer`` as
            ``self_attend_vis``
        cross_attend (bool): run the cross-encoder in ``forward``

    All other arguments are forwarded unchanged to
    ``GroupFreeGTModulator``.
    """

    def __init__(self, num_class=485, num_heading_bin=1, num_size_cluster=485,
                 mean_size_arr=np.ones((485, 3)), input_feature_dim=0, width=2,
                 bn_momentum=0.1, sync_bn=False, num_proposal=512,
                 sampling='kps', dropout=0.1, activation="relu",
                 nhead=8, num_decoder_layers=12, dim_feedforward=2048,
                 self_position_embedding='loc_learned',
                 size_cls_agnostic=False, text_encoder_type='roberta',
                 d_model=288, contrastive_align_loss=False,
                 contrastive_hungarian=False, sa_lang=True, sa_vis=True,
                 use_gt_box=False, cross_attend=True,
                 gt_with_bbox_loss=False, gt_with_bbox_sampling=False,
                 use_gt_class=False, num_obj_classes=485,
                 freeze_text_encoder=False):
        """Initialize layers."""
        super().__init__(
            num_class=num_class,
            num_heading_bin=num_heading_bin,
            num_size_cluster=num_size_cluster,
            mean_size_arr=mean_size_arr,
            input_feature_dim=input_feature_dim,
            width=width,
            bn_momentum=bn_momentum,
            sync_bn=sync_bn,
            num_proposal=num_proposal,
            sampling=sampling,
            dropout=dropout,
            activation=activation,
            nhead=nhead,
            num_decoder_layers=num_decoder_layers,
            dim_feedforward=dim_feedforward,
            self_position_embedding=self_position_embedding,
            size_cls_agnostic=size_cls_agnostic,
            text_encoder_type=text_encoder_type,
            d_model=d_model,
            contrastive_align_loss=contrastive_align_loss,
            contrastive_hungarian=contrastive_hungarian,
            sa_lang=sa_lang,
            sa_vis=sa_vis,
            use_gt_box=use_gt_box,
            cross_attend=cross_attend,
            gt_with_bbox_loss=gt_with_bbox_loss,
            gt_with_bbox_sampling=gt_with_bbox_sampling,
            use_gt_class=use_gt_class,
            num_obj_classes=num_obj_classes,
            freeze_text_encoder=freeze_text_encoder
        )

        # Backbone point feature learning
        self.backbone_net = Pointnet2Backbone(
            input_feature_dim=self.input_feature_dim,
            width=self.width
        )
        # map_location='cpu' lets a checkpoint saved on GPU load on a
        # CPU-only host; load_state_dict copies into the module's params
        self.backbone_net.load_state_dict(torch.load(
            "./dataset/language_grounding/gf_detector_l6o256.pth",
            map_location='cpu'
        ), strict=False)

        # Cross-encoder
        self.pos_embed = PositionEmbeddingLearned(3, d_model)
        bi_layer = BiEncoderLayer(
            d_model, dropout, activation, nhead, dim_feedforward,
            self_attend_lang=sa_lang, self_attend_vis=sa_vis
        )
        self.cross_encoder = BiEncoder(bi_layer, 3)

        # Transformer decoder layers (replace the parent's stack)
        self.decoder = nn.ModuleList()
        for _ in range(self.num_decoder_layers):
            self.decoder.append(BiDecoderLayer(
                d_model, nhead, dim_feedforward, dropout, activation,
                self_position_embedding
            ))

        # Init (again, so that the modules added above are covered)
        # self.init_weights()
        self.init_bn_momentum()
        if self.sync_bn:
            nn.SyncBatchNorm.convert_sync_batchnorm(self)

    def _run_backbones(self, inputs):
        """Run visual and text backbones."""
        # Visual encoder
        end_points = self.backbone_net(inputs['point_clouds'], end_points={})
        end_points['seed_inds'] = end_points['fp2_inds']
        end_points['seed_xyz'] = end_points['fp2_xyz']
        end_points['seed_features'] = end_points['fp2_features']
        # Text encoder: the parent's implementation returns 'text_feats',
        # 'text_attention_mask' and 'tokenized'
        end_points.update(super()._run_backbones(inputs))
        return end_points

    def forward(self, inputs):
        """
        Forward pass.

        Args:
            inputs: dict with keys
                point_clouds (tensor): (B, Npoint, 3 + input_channels)
                text (list): ['text0', 'text1', ...], len(text) = B
                all_bboxes (tensor): (B, V, 6) boxes, xyz then size
                all_bbox_label_mask (bool tensor): (B, V); it is
                    inverted to obtain the query padding mask
                all_classes (long tensor): (B, V), read only when class
                    features are enabled

        Returns:
            end_points: dict

        Raises:
            NotImplementedError: unknown ``self_position_embedding``
        """
        # Run backbones
        end_points = self._run_backbones(inputs)
        points_xyz = end_points['fp2_xyz']  # (B, points, 3)
        points_features = end_points['fp2_features']  # (B, F, points)
        text_feats = end_points['text_feats']  # (B, L, F)
        text_padding_mask = end_points['text_attention_mask']  # (B, L)

        # Cross-encoder
        if self.cross_attend:
            # No visual point is masked out
            points_mask = torch.zeros(
                points_xyz.shape[:2],
                dtype=torch.bool, device=points_xyz.device
            )
            points_features, text_feats = self.cross_encoder(
                vis_feats=points_features.transpose(1, 2).contiguous(),
                pos_feats=self.pos_embed(
                    points_xyz
                ).transpose(1, 2).contiguous(),
                padding_mask=points_mask,
                text_feats=text_feats,
                text_padding_mask=text_padding_mask,
                end_points=end_points
            )
            points_features = points_features.transpose(1, 2)
            points_features = points_features.contiguous()  # (B, F, points)
        end_points["text_memory"] = text_feats
        end_points['seed_features'] = points_features
        end_points['proj_tokens'] = F.normalize(
            self.contrastive_align_projection_text(text_feats), p=2, dim=-1
        )

        # Query features: box embedding (+ class embedding if enabled)
        base_xyz = inputs['all_bboxes'][:, :, :3]
        base_size = inputs['all_bboxes'][:, :, 3:]
        base_box = torch.cat([base_xyz, base_size], -1)
        cluster_feature = self.posembed(base_box)
        if self.use_gt_class or self.use_detected_boxes:
            # ``class_embeddings`` (and the wider ``decoder_query_proj``)
            # are only created by the parent under this condition;
            # calling it unconditionally crashed with the default flags
            query_class_features = self.class_embeddings(
                inputs['all_classes']
            )  # (B, V, 32)
            cluster_feature = torch.cat((
                cluster_feature,
                query_class_features.transpose(1, 2)
            ), 1)

        # Transformer Decoder and Prediction
        query = self.decoder_query_proj(cluster_feature)
        query = query.transpose(1, 2).contiguous()  # (B, V, F)

        query_mask = ~inputs['all_bbox_label_mask']
        end_points['query_mask'] = query_mask

        # Position input for self-attention is the same for every layer
        if self.self_position_embedding == 'none':
            query_pos = None
        elif self.self_position_embedding == 'xyz_learned':
            query_pos = base_xyz
        elif self.self_position_embedding == 'loc_learned':
            query_pos = base_box
        else:
            raise NotImplementedError

        # Visual memory in (B, points, F) layout, shared by all layers
        vis_memory = points_features.transpose(1, 2).contiguous()

        last_layer = self.num_decoder_layers - 1
        for i in range(self.num_decoder_layers):
            prefix = 'last_' if i == last_layer else f'{i}head_'

            # Transformer Decoder Layer
            query = self.decoder[i](
                query, vis_memory,
                text_feats, query_pos,
                query_mask,
                text_padding_mask
            )  # (B, V, F)

            # Prediction (the head writes into end_points)
            self.prediction_heads[i](
                query.transpose(1, 2).contiguous(),
                end_points, prefix
            )
            end_points[f'{prefix}proj_queries'] = F.normalize(
                self.contrastive_align_projection_image(query), p=2, dim=-1
            )

        return end_points
