from copy import deepcopy
import json

import numpy as np
import open3d as o3d
import torch
import torch.nn.functional as F
import torch.nn as nn
import sys
import os

from transformers import (
    RobertaModel, RobertaTokenizerFast, CLIPTextModel, CLIPTokenizerFast
)

# Append this file's directory to sys.path so that the bare-name imports
# further down can be resolved relative to it.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
ROOT_DIR = os.path.dirname(BASE_DIR)  # parent directory of BASE_DIR
sys.path.append(BASE_DIR)

from src.scannet_classes import  SCANNET_OBJECTS
from .backbone_module import Pointnet2Backbone
from .modules import (
    ClassPredictHead, PointsObjClsModule, FPSModule, GeneralSamplingModule,
    PredictHead, ClsAgnosticPredictHead, PositionEmbeddingLearned
)
from .encoder_decoder_layers import (
    BiEncoder, BiEncoderLayer, BiDecoderLayer
)
from . import utils
from sunrgbd.sunrgbd_utils import extract_pc_in_box3d
from utils import pc_util
import ipdb
# Debugging shortcut: calling st() invokes ipdb.set_trace.
st = ipdb.set_trace


# Object name -> position of that name in SCANNET_OBJECTS.
OBJ_CONCEPTS = {
    name: index for index, name in enumerate(SCANNET_OBJECTS)
}


class GroupFreeModulator(nn.Module):
    """
    A Group-Free model for 3D language grounding.

    Args:
        num_class (int): number of semantic classes to predict
        num_heading_bin (int): number of heading (angle) bins used
        num_size_cluster (int): number of size "classes"
        mean_size_arr (ndarray): mean size for each size class
        input_feature_dim (int): feat_dim of pointcloud (without xyz)
        width (int): PointNet++ backbone width ratio
        num_proposal (int): Number of proposals generated
        sampling (str): initial object candidate sampling method
    """

    def __init__(self, num_class=485, num_heading_bin=1, num_size_cluster=485,
                 mean_size_arr=np.ones((485, 3)), input_feature_dim=0, width=2,
                 bn_momentum=0.1, sync_bn=False, num_proposal=512,
                 sampling='kps', dropout=0.1, activation="relu",
                 nhead=8, num_decoder_layers=12, dim_feedforward=2048,
                 self_position_embedding='loc_learned',
                 size_cls_agnostic=False, text_encoder_type='roberta',
                 d_model=288, contrastive_align_loss=False,
                 contrastive_hungarian=False, sa_lang=True, sa_vis=True,
                 use_gt_box=False, cross_attend=True,
                 gt_with_bbox_loss=False, gt_with_bbox_sampling=False,
                 use_gt_class=False, num_obj_classes=485,
                 train_viewpoint_module=False,
                 train_viewpoint_prototype=False, teacher_forcing=False,
                 butd=False, use_class_for_butd=True, use_glove=False,
                 use_butd_in_encoder=True, use_butd_enc_attn=False):
        """
        Initialize layers.

        Besides the arguments described in the class docstring:

        Args:
            bn_momentum (float): momentum written to every BatchNorm1d/2d
            sync_bn (bool): convert batch norms to SyncBatchNorm
            dropout, activation, nhead, dim_feedforward: forwarded to the
                cross-encoder and decoder layers
            num_decoder_layers (int): number of decoder layers (one
                prediction head is created per layer)
            self_position_embedding (str): 'none', 'xyz_learned' or
                'loc_learned'
            size_cls_agnostic (bool): use ClsAgnosticPredictHead instead
                of PredictHead
            text_encoder_type (str): 'roberta' or 'clip'
            d_model (int): feature width of the transformer
            contrastive_align_loss (bool): add the two 64-d projections
                used for query/token alignment
            contrastive_hungarian (bool): when set, class-agnostic heads
                are built with compute_sem_scores=False
            sa_lang, sa_vis (bool): forwarded to BiEncoderLayer as
                self_attend_lang / self_attend_vis
            use_gt_box (bool): queries come from inputs['all_bboxes']
                instead of being sampled
            cross_attend (bool): run the cross-encoder in forward()
            gt_with_bbox_loss (bool): with use_gt_box, keep a box head
                instead of a class-only head
            gt_with_bbox_sampling (bool): with use_gt_box, pool the point
                features that fall inside every box
            use_gt_class (bool): append a 32-d class embedding to every
                query feature
            num_obj_classes (int): size of the class embedding table
            train_viewpoint_module (bool): create the viewpoint layers
            train_viewpoint_prototype (bool): with train_viewpoint_module,
                use view prototypes and two ViewEncoders rather than a
                learned viewpoint query
            teacher_forcing (bool): stored; read in forward()
            butd (bool): create embeddings for externally detected boxes
            use_class_for_butd (bool): also embed the detected class ids
            use_glove (bool): load the class embeddings from
                'scannet_glove.json' and keep them frozen
            use_butd_in_encoder, use_butd_enc_attn (bool): select how the
                detected boxes enter the encoder

        Raises:
            NotImplementedError: unknown text_encoder_type or sampling.

        Note:
            Loads a backbone checkpoint from a relative path when
            input_feature_dim is 3 or 131, and pretrained text-encoder
            weights through ``from_pretrained``.
        """
        super().__init__()

        self.num_class = num_class
        self.num_heading_bin = num_heading_bin
        self.num_size_cluster = num_size_cluster
        self.mean_size_arr = mean_size_arr
        assert (mean_size_arr.shape[0] == self.num_size_cluster)
        self.input_feature_dim = input_feature_dim
        self.num_proposal = num_proposal
        self.bn_momentum = bn_momentum
        self.sync_bn = sync_bn
        self.width = width
        self.nhead = nhead
        self.sampling = sampling
        self.num_decoder_layers = num_decoder_layers
        self.dim_feedforward = dim_feedforward
        self.self_position_embedding = self_position_embedding
        self.size_cls_agnostic = size_cls_agnostic
        self.d_model = d_model
        self.cross_attend = cross_attend
        self.contrastive_hungarian = contrastive_hungarian
        self.use_gt_box = use_gt_box
        self.gt_with_bbox_loss = gt_with_bbox_loss
        self.gt_with_bbox_sampling = gt_with_bbox_sampling
        self.use_gt_class = use_gt_class
        self.train_viewpoint_module = train_viewpoint_module
        self.train_viewpoint_prototype = train_viewpoint_prototype
        self.teacher_forcing = teacher_forcing
        self.butd = butd
        self.use_class_for_butd = use_class_for_butd
        self.use_glove = use_glove
        self.use_butd_in_encoder = use_butd_in_encoder
        self.use_butd_enc_attn = use_butd_enc_attn

        # Backbone point feature learning
        self.backbone_net = Pointnet2Backbone(
            input_feature_dim=self.input_feature_dim,
            width=self.width
        )
        # map_location='cpu' lets checkpoints that were saved from a GPU be
        # read on any host; load_state_dict copies into the module's own
        # parameters afterwards, so the final placement is unaffected.
        if input_feature_dim == 3:
            self.backbone_net.load_state_dict(torch.load(
                "./dataset/language_grounding/gf_detector_l6o256.pth",
                map_location='cpu'
            ), strict=False)
        elif input_feature_dim == 131:
            self.backbone_net.load_state_dict(torch.load(
                "./dataset/language_grounding/imp_checkpoints/gf_2dfeat.pth",
                map_location='cpu'
            ), strict=False)

        # Text Encoder
        if text_encoder_type == 'roberta':
            t_type = "roberta-base"
            self.tokenizer = RobertaTokenizerFast.from_pretrained(t_type)
            self.text_encoder = RobertaModel.from_pretrained(t_type)
        elif text_encoder_type == 'clip':
            t_type = "openai/clip-vit-base-patch32"
            self.tokenizer = CLIPTokenizerFast.from_pretrained(t_type)
            self.text_encoder = CLIPTextModel.from_pretrained(t_type)
        else:
            raise NotImplementedError

        # Query features get 32 extra channels when class embeddings are
        # concatenated to them
        if self.use_gt_class:
            self.class_embeddings = nn.Embedding(num_obj_classes, 32)
            d_model_query = d_model + 32
        else:
            d_model_query = d_model

        self.text_projector = nn.Sequential(
            nn.Linear(self.text_encoder.config.hidden_size, d_model),
            nn.LayerNorm(d_model, eps=1e-12),
            nn.Dropout(0.1)
        )

        if self.train_viewpoint_module:
            if self.train_viewpoint_prototype:
                self.view_prototypes = nn.Embedding(len(OBJ_CONCEPTS), 32)
                self.pano_view_encoder = ViewEncoder()
                self.side_view_encoder = deepcopy(self.pano_view_encoder)
            else:
                self.viewpoint_embed = nn.Embedding(1, self.d_model)
                # Positional embeddings
                if self_position_embedding == 'xyz_learned':
                    self.decoder_posembed = PositionEmbeddingLearned(3, d_model)
                elif self_position_embedding == 'loc_learned':
                    self.decoder_posembed = PositionEmbeddingLearned(6, d_model)
                else:
                    self.decoder_posembed = None

                # predict euler angle
                self.viewpoint_predict_head = nn.Linear(d_model, 3)

        if self.butd:
            if not self.use_butd_enc_attn and self.use_butd_in_encoder:
                self.points_box_features = PositionEmbeddingLearned(7, 64)
                self.point_feature_proj = nn.Sequential(
                    nn.Linear(d_model + 64, d_model),
                    nn.LayerNorm(d_model, eps=1e-12),
                    nn.Dropout(0.1)
                )
            if self.use_class_for_butd:
                if self.use_glove:
                    with open('scannet_glove.json') as fid:
                        w2v = torch.as_tensor(json.load(fid))
                    self.class_embeddings = nn.Embedding(num_obj_classes, 300)
                    self.class_embeddings.weight.data.copy_(w2v)
                    # Freeze the pretrained vectors. The flag has to be set
                    # on the parameter: assigning `requires_grad` on the
                    # module only creates an unrelated attribute.
                    self.class_embeddings.weight.requires_grad = False
                    self.emb_projector = nn.Sequential(
                        nn.Linear(300, 32),
                        nn.LayerNorm(32, eps=1e-12),
                        nn.Dropout(0.1)
                    )
                else:
                    self.class_embeddings = nn.Embedding(num_obj_classes, 32)
                # Box and class features are concatenated to d_model channels
                self.box_embeddings = PositionEmbeddingLearned(6, d_model-32)
            else:
                self.box_embeddings = PositionEmbeddingLearned(6, d_model)

        # Cross-encoder
        self.pos_embed = PositionEmbeddingLearned(3, d_model)
        bi_layer = BiEncoderLayer(
            d_model, dropout, activation, nhead, dim_feedforward,
            self_attend_lang=sa_lang, self_attend_vis=sa_vis,
            use_butd_enc_attn=self.use_butd_enc_attn
        )
        self.cross_encoder = BiEncoder(bi_layer, 3)

        # Query sampling method
        if not self.use_gt_box:
            if self.sampling == 'fps':
                self.fps_module = FPSModule(num_proposal)
            elif self.sampling == 'kps':
                self.points_obj_cls = PointsObjClsModule(d_model)
                self.gsample_module = GeneralSamplingModule()
            else:
                raise NotImplementedError

        # Proposal (layer for size and center)
        if self.use_gt_box and not self.gt_with_bbox_loss:
            self.proposal_head = ClassPredictHead(num_class, d_model_query)
        elif self.size_cls_agnostic:
            self.proposal_head = ClsAgnosticPredictHead(
                num_class, num_heading_bin, num_proposal, d_model_query,
                objectness=False, heading=False,
                compute_sem_scores=not contrastive_hungarian
            )
        else:
            self.proposal_head = PredictHead(
                num_class, num_heading_bin, num_size_cluster,
                mean_size_arr, num_proposal, d_model_query
            )

        # Transformer Decoder Projection
        self.decoder_query_proj = nn.Conv1d(d_model_query, d_model, kernel_size=1)

        # Transformer decoder layers. With the learned viewpoint query the
        # position embedding is added in forward(), so the layers get None.
        if not train_viewpoint_module or train_viewpoint_prototype:
            layer_position_embedding = self_position_embedding
        else:
            layer_position_embedding = None
        self.decoder = nn.ModuleList()
        for _ in range(self.num_decoder_layers):
            self.decoder.append(BiDecoderLayer(
                d_model, nhead, dim_feedforward, dropout, activation,
                layer_position_embedding, butd=self.butd
            ))

        # Prediction Head
        self.prediction_heads = nn.ModuleList()
        for _ in range(self.num_decoder_layers):
            if self.use_gt_box and not self.gt_with_bbox_loss:
                self.prediction_heads.append(ClassPredictHead(
                    num_class, d_model
                ))
            elif self.size_cls_agnostic:
                self.prediction_heads.append(ClsAgnosticPredictHead(
                    num_class, num_heading_bin, num_proposal, d_model,
                    objectness=False, heading=False,
                    compute_sem_scores=not contrastive_hungarian
                ))
            else:
                self.prediction_heads.append(PredictHead(
                    num_class, num_heading_bin, num_size_cluster,
                    mean_size_arr, num_proposal, d_model
                ))

        self.contrastive_align_loss = contrastive_align_loss
        if contrastive_align_loss:
            self.contrastive_align_projection_image = nn.Linear(d_model, 64)
            self.contrastive_align_projection_text = nn.Linear(d_model, 64)

        # Init
        # self.init_weights()
        self.init_bn_momentum()
        if self.sync_bn:
            nn.SyncBatchNorm.convert_sync_batchnorm(self)

    def _run_backbones(self, inputs):
        """
        Encode the point cloud and the sentences.

        Args:
            inputs (dict): must hold 'point_clouds' (tensor) and 'text'
                (list of str)

        Returns:
            end_points (dict): the backbone outputs plus the 'seed_*'
                aliases of the 'fp2_*' entries, 'text_feats',
                'text_attention_mask' and 'tokenized'
        """
        point_clouds = inputs['point_clouds']

        # Visual encoder; the seed_* entries alias the fp2_* outputs
        end_points = self.backbone_net(point_clouds, end_points={})
        for suffix in ('inds', 'xyz', 'features'):
            end_points[f'seed_{suffix}'] = end_points[f'fp2_{suffix}']

        # Text encoder, run on the same device as the points
        tokenized = self.tokenizer.batch_encode_plus(
            inputs['text'], padding="longest", return_tensors="pt"
        )
        tokenized = tokenized.to(point_clouds.device)
        token_states = self.text_encoder(**tokenized).last_hidden_state
        end_points['text_feats'] = self.text_projector(token_states)

        # The tokenizer marks real tokens with 1; the stored mask is the
        # inverse, i.e. True on padding positions
        padding_mask = tokenized.attention_mask.ne(1).bool()
        end_points['text_attention_mask'] = padding_mask
        end_points['tokenized'] = tokenized
        return end_points

    def _generate_queries(self, xyz, features, end_points, inputs=None):
        """
        Pick the decoder queries and store them in ``end_points``.

        Args:
            xyz (tensor): (B, points, 3) seed coordinates
            features (tensor): (B, F, points) seed features
            end_points (dict): updated in place and returned
            inputs (dict): read for 'all_classes' when use_gt_class is
                set and for 'all_bboxes' / 'all_bbox_label_mask' when
                use_gt_box is set

        Returns:
            end_points with 'query_points_xyz', 'query_points_feature'
            and 'query_points_sample_inds' (None for ground-truth boxes);
            ground-truth boxes additionally write 'query_mask', and 'kps'
            sampling writes 'seeds_obj_cls_logits'.

        Raises:
            NotImplementedError: unknown sampling method.
        """
        # Class embeddings that get concatenated to box-based queries
        if self.use_gt_class:
            query_class_features = self.class_embeddings(inputs['all_classes'])
            if self.use_glove:
                query_class_features = self.emb_projector(query_class_features)

        # Queries sampled from the seed points
        if not self.use_gt_box:
            if self.sampling == 'fps':
                xyz, features, sample_inds = self.fps_module(xyz, features)
            elif self.sampling == 'kps':
                objectness_logits = self.points_obj_cls(features)
                end_points['seeds_obj_cls_logits'] = objectness_logits
                objectness = torch.sigmoid(objectness_logits).squeeze(1)
                # keep the num_proposal highest-scoring seeds
                _, top_inds = torch.topk(objectness, self.num_proposal)
                xyz, features, sample_inds = self.gsample_module(
                    xyz, features, top_inds.int()
                )
            else:
                raise NotImplementedError
            end_points['query_points_xyz'] = xyz  # (B, V, 3)
            end_points['query_points_feature'] = features  # (B, F, V)
            end_points['query_points_sample_inds'] = sample_inds  # (B, V)
            return end_points

        # Queries anchored on the provided boxes
        if self.gt_with_bbox_sampling:
            # average the features of the seeds lying inside every box
            features = features.transpose(1, 2).contiguous()
            boxes = inputs['all_bboxes']
            center_xyz = boxes[:, :, :3]
            half_size = boxes[:, :, 3:] / 2
            box_min = center_xyz - half_size
            box_max = center_xyz + half_size
            inside = torch.stack([
                (box_min[k].unsqueeze(0) <= pts.unsqueeze(1))
                & (pts.unsqueeze(1) <= box_max[k].unsqueeze(0))
                for k, pts in enumerate(xyz)
            ]).all(-1).transpose(1, 2)  # (B, boxes, points)
            pooled = []
            for k, box_members in enumerate(inside):
                per_box = []
                for members in box_members:
                    if members.any():
                        per_box.append(features[k][members].mean(0))
                    else:
                        # boxes without any seed get an all-zero feature
                        per_box.append(
                            torch.zeros(features.size(-1)).to(features.device)
                        )
                pooled.append(torch.stack(per_box))
            sampled_features = torch.stack(pooled)  # (B, boxes, F)
        else:
            # interpolate seed features around every box center
            center_xyz = inputs['all_bboxes'][:, :, :3]
            radius = 1.2  # hyperparam
            n_sample = 8
            sampled_features, _ = utils.pc_feature_interpolation(
                xyz.transpose(1, 2), features, center_xyz, radius=radius,
                nsample=n_sample, use_losses=False
            )

        if self.use_gt_class:
            sampled_features = torch.cat(
                (sampled_features, query_class_features), dim=-1)

        end_points['query_mask'] = ~inputs['all_bbox_label_mask']
        end_points['query_points_xyz'] = center_xyz  # (B, V, 3)
        end_points['query_points_feature'] = sampled_features.transpose(1, 2)  # (B, F, V)
        end_points['query_points_sample_inds'] = None  # (B, V)
        return end_points

    def forward(self, inputs):
        """
        Forward pass.

        Args:
            inputs: dict
                point_clouds (tensor): (B, Npoint, 3 + input_channels)
                text (list): ['text0', 'text1', ...], len(text) = B

                Further keys are only read in the matching mode:
                points_to_boxes, all_detected_bbox_label_mask,
                all_detected_class_ids, all_detected_boxes (butd);
                all_bboxes (use_gt_box); center_label, size_gts,
                positive_map, target_name (train_viewpoint_prototype).

        Returns:
            end_points: dict holding the backbone outputs, 'text_memory',
                'seed_features', the query entries and whatever the
                proposal / prediction heads write under the prefixes
                'proposal_', '{i}head_' and 'last_'.
        """
        # Run backbones
        end_points = self._run_backbones(inputs)
        points_inds = end_points['fp2_inds']
        points_xyz = end_points['fp2_xyz']  # (B, points, 3)
        points_features = end_points['fp2_features']  # (B, F, points)
        text_feats = end_points['text_feats']  # (B, L, F)
        text_padding_mask = end_points['text_attention_mask']  # (B, L)

        if self.butd:
            if self.use_butd_in_encoder and not self.use_butd_enc_attn:
                # features concatenated with points
                points_to_boxes = torch.gather(
                    inputs['points_to_boxes'], 1,
                    points_inds[..., None].repeat(1, 1, 7).long()
                )
                detected_points_box_features = self.points_box_features(points_to_boxes)
                points_features = torch.cat([
                    points_features,
                    detected_points_box_features,
                ], 1)
                points_features = self.point_feature_proj(
                    points_features.transpose(1, 2)).transpose(1, 2)

            # features which queries would attend to
            detected_mask = ~inputs['all_detected_bbox_label_mask']
            if self.use_class_for_butd:
                detected_class_ids = inputs['all_detected_class_ids']
                detected_class_features = self.class_embeddings(detected_class_ids)  # B, Q, 32/300
                if self.use_glove:
                    detected_class_features = self.emb_projector(detected_class_features)
                detected_class_features = detected_class_features.transpose(1, 2)
            detected_boxes = inputs['all_detected_boxes']
            detected_box_features = self.box_embeddings(detected_boxes)
            if self.use_class_for_butd:
                detected_feats = torch.cat([detected_box_features, detected_class_features], 1)
            else:
                detected_feats = detected_box_features

        # Cross-encoder
        if self.cross_attend:
            points_features, text_feats = self.cross_encoder(
                vis_feats=points_features.transpose(1, 2).contiguous(),
                pos_feats=self.pos_embed(points_xyz).transpose(1, 2).contiguous(),
                # no visual token is padded
                padding_mask=torch.zeros(
                    len(points_xyz), points_xyz.size(1)
                ).to(points_xyz.device).bool(),
                text_feats=text_feats,
                text_padding_mask=text_padding_mask,
                end_points=end_points,
                detected_feats=detected_feats.transpose(1, 2).contiguous() if self.use_butd_enc_attn else None,
                detected_mask=detected_mask if self.use_butd_enc_attn else None,
            )
            points_features = points_features.transpose(1, 2)
            points_features = points_features.contiguous()  # (B, F, points)
        end_points["text_memory"] = text_feats
        end_points['seed_features'] = points_features
        if self.contrastive_align_loss:
            proj_tokens = F.normalize(
                self.contrastive_align_projection_text(text_feats), p=2, dim=-1
            )
            end_points['proj_tokens'] = proj_tokens

        # Query Points Generation
        end_points = self._generate_queries(
            points_xyz, points_features, end_points, inputs
        )
        cluster_feature = end_points['query_points_feature']  # (B, F, V)
        cluster_xyz = end_points['query_points_xyz']  # (B, V, 3)

        # Transformer Decoder and Prediction
        query = self.decoder_query_proj(cluster_feature)
        query = query.transpose(1, 2).contiguous()  # (B, V, F)
        if self.contrastive_align_loss:
            end_points['proposal_proj_queries'] = F.normalize(
                self.contrastive_align_projection_image(query), p=2, dim=-1
            )

        # Proposals (one for each query)
        if self.use_gt_box:
            if not self.gt_with_bbox_loss:
                self.proposal_head(cluster_feature, end_points, prefix='proposal_')
            else:
                self.proposal_head(
                    cluster_feature,
                    base_xyz=cluster_xyz,
                    end_points=end_points,
                    prefix='proposal_'
                )
            # the provided boxes drive the position embedding
            base_xyz = inputs['all_bboxes'][:, :, :3]
            base_size = inputs['all_bboxes'][:, :, 3:]
            query_mask = end_points['query_mask']
        else:
            proposal_center, proposal_size = self.proposal_head(
                cluster_feature,
                base_xyz=cluster_xyz,
                end_points=end_points,
                prefix='proposal_'
            )
            base_xyz = proposal_center.detach().clone()  # (B, V, 3)
            base_size = proposal_size.detach().clone()  # (B, V, 3)
            query_mask = None

        for i in range(self.num_decoder_layers):
            prefix = 'last_' if i == self.num_decoder_layers-1 else f'{i}head_'

            # Position Embedding for Self-Attention
            if self.self_position_embedding == 'none':
                query_pos = None
            elif self.self_position_embedding == 'xyz_learned':
                query_pos = base_xyz
            elif self.self_position_embedding == 'loc_learned':
                query_pos = torch.cat([base_xyz, base_size], -1)
            else:
                raise NotImplementedError

            if self.train_viewpoint_module and query_pos is not None and not self.train_viewpoint_prototype:
                # add query pos now only rather than in decoder
                query_pos = self.decoder_posembed(query_pos)
                query = query + query_pos.transpose(1, 2)
                query_pos = None

                # add viewpoint query
                bs = query.shape[0]
                viewpoint_query = self.viewpoint_embed.weight.unsqueeze(0).repeat(bs, 1, 1)
                query = torch.cat([query, viewpoint_query], 1)

            # Transformer Decoder Layer
            query = self.decoder[i](
                query, points_features.transpose(1, 2).contiguous(),
                text_feats, query_pos,
                query_mask,
                text_padding_mask,
                detected_feats=detected_feats.transpose(1, 2).contiguous() if self.butd else None,
                detected_mask=detected_mask if self.butd else None
            )  # (B, V, F)

            if self.train_viewpoint_module and not self.train_viewpoint_prototype:
                # split the trailing viewpoint token off the object queries
                viewpoint_query = query[:, -1][:, None]
                query = query[:, :-1]
                pred_viewpoint = self.viewpoint_predict_head(viewpoint_query)
                end_points[f'{prefix}pred_viewpoint'] = pred_viewpoint.squeeze(1)

            if self.contrastive_align_loss:
                end_points[f'{prefix}proj_queries'] = F.normalize(
                    self.contrastive_align_projection_image(query), p=2, dim=-1
                )

            # Prediction
            if self.use_gt_box and not self.gt_with_bbox_loss:
                self.prediction_heads[i](
                    query.transpose(1, 2).contiguous(),
                    end_points, prefix
                )
            elif self.use_gt_box:
                self.prediction_heads[i](
                    query.transpose(1, 2).contiguous(),  # (B, F, V)
                    base_xyz=cluster_xyz,
                    end_points=end_points,
                    prefix=prefix
                )
            else:
                base_xyz, base_size = self.prediction_heads[i](
                    query.transpose(1, 2).contiguous(),  # (B, F, V)
                    base_xyz=cluster_xyz,
                    end_points=end_points,
                    prefix=prefix
                )

                if self.train_viewpoint_prototype:
                    # find the most confident anchor box
                    if self.teacher_forcing: #and np.random.rand()>0.5 and inputs["train"]:
                        gt_center = inputs['center_label'][:, :, 0:3]  # (B, K, 3)
                        gt_size = inputs['size_gts']  # (B, K2,3)
                        anchor_boxes = torch.cat([gt_center, gt_size], dim=-1)[:, 0]
                        anchor_centers = anchor_boxes[:, :3]
                        anchor_boxes = pc_util.box2points(anchor_boxes.cpu().numpy())
                    else:
                        sem_scores = end_points[f'{prefix}sem_cls_scores'].softmax(-1)
                        pred_bbox_ = torch.cat([base_xyz, base_size], dim=-1)
                        positive_map = torch.clone(inputs['positive_map'])  # (B, K, 256)
                        positive_map[positive_map > 0] = 1
                        # 1st element in positive map here is expected to be anchor
                        positive_map = positive_map[:, :1]
                        scores = (
                            sem_scores.unsqueeze(1)
                            * positive_map.unsqueeze(2)
                        ).sum(-1)
                        top = torch.topk(scores, 1)[1].to(torch.int64)
                        anchor_boxes = torch.gather(pred_bbox_, 1, top.repeat(1, 1, 6))
                        anchor_centers = anchor_boxes.squeeze(1)[:, :3]
                        anchor_boxes = pc_util.box2points(
                            anchor_boxes.squeeze(1).detach().cpu().numpy()
                        )

                    obj_concept = inputs['target_name']
                    view_concept = ['front' if 'front' in sent else 'back' for sent in inputs['text']]
                    # Allocated on the device of the predictions instead of
                    # a hard-coded .cuda(), so CPU and non-default-GPU runs
                    # work; rows of samples without points stay zero.
                    pred_viewpoint = torch.zeros(
                        len(obj_concept), 12, device=base_xyz.device
                    )
                    for bid in range(len(obj_concept)):
                        obj_pc = extract_pc_in_box3d(points_xyz[bid].detach().cpu().numpy(), anchor_boxes[bid])[0]
                        if len(obj_pc) != 0:
                            # express the points relative to the anchor center
                            obj_pc = obj_pc - anchor_centers[bid].detach().cpu().numpy()
                            scores = self._viewpoint(
                                obj_concept[bid],
                                view_concept[bid],
                                torch.from_numpy(obj_pc).to(base_xyz.device)
                            )
                            pred_viewpoint[bid] = scores

                    end_points[f'{prefix}pred_viewpoint'] = pred_viewpoint

                # the next layer's position embedding must not back-propagate
                # into this layer's box prediction
                base_xyz = base_xyz.detach().clone()
                base_size = base_size.detach().clone()

        return end_points

    def init_weights(self):
        """Xavier-initialize every decoder parameter with more than one dim."""
        for param in self.decoder.parameters():
            # vectors and scalars (e.g. biases) are left untouched
            if param.dim() <= 1:
                continue
            nn.init.xavier_uniform_(param)

    def init_bn_momentum(self):
        """Write ``self.bn_momentum`` to every BatchNorm1d/2d sub-module."""
        batch_norm_types = (nn.BatchNorm2d, nn.BatchNorm1d)
        batch_norms = (
            layer for layer in self.modules()
            if isinstance(layer, batch_norm_types)
        )
        for layer in batch_norms:
            layer.momentum = self.bn_momentum

    def _viewpoint(self, obj_concept, view_concept, pc):
        """
        Score 12 z-rotations of an object's points against its prototype.

        Args:
            obj_concept (str): key into OBJ_CONCEPTS selecting the prototype
            view_concept: not used by this method
            pc (tensor): object points; the first dim is the point count

        Returns:
            tensor with one score per rotation (10 * cosine similarity);
            all zeros when ``pc`` holds no points.
        """
        if pc.shape[0] == 0:
            return torch.zeros(12).to(pc.device)

        pcd = o3d.geometry.PointCloud()
        # Rotate along z
        rotated = [
            pc_util.rot_z(pc, theta)
            for theta in torch.linspace(-180, 180, 12)
        ]

        def render(camera):
            # one 84-sized depth image per rotation, stacked along dim 0
            images = [
                pc_util.create_depth_image(
                    camera(view, pcd).unsqueeze(0).float(), 84
                )
                for view in rotated
            ]
            return torch.cat(images)

        # Views: each camera has its own encoder; the two codes are
        # concatenated channel-wise
        pano_code = self.pano_view_encoder(render(pc_util.cam0))
        side_code = self.side_view_encoder(render(pc_util.cam1))
        views = torch.cat((pano_code, side_code), 1)

        # Similarities to prototype
        concept_index = torch.as_tensor(OBJ_CONCEPTS[obj_concept])
        prototype = self.view_prototypes(concept_index.long().to(pc.device))
        similarity = F.cosine_similarity(views, prototype.unsqueeze(0), dim=1)
        return 10 * similarity


class ViewEncoder(nn.Module):
    """Two-layer CNN that turns a single-channel image into a 16-d code."""

    def __init__(self):
        """Create the convolutional stack and the final linear layer."""
        super().__init__()
        conv_layers = [
            nn.Conv2d(1, 32, 5, padding=2),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 16, 5, padding=2),
            nn.ReLU(),
        ]
        self.conv_encoder = nn.Sequential(*conv_layers)
        self.fc_encoder = nn.Linear(16, 16)

    def forward(self, img):
        """Encode an image tensor (b, 1, 84, 84) into a (b, 16) tensor."""
        feature_map = self.conv_encoder(img)
        # average over the two spatial dims, one after the other
        pooled = feature_map.mean(2).mean(2)
        return self.fc_encoder(pooled)


if __name__ == "__main__":
    # Smoke test: one timed forward pass on random points (needs a GPU),
    # then the shapes of the last decoder layer's outputs.
    from time import time

    model = GroupFreeModulator().cuda()
    batch = {
        'point_clouds': torch.rand(2, 3500, 3).cuda(),
        'text': ['this is not a boy', 'but this is']
    }
    started = time()
    result = model(batch)
    print(time() - started)
    print('')
    last_layer_items = (
        (name, tensor) for name, tensor in result.items() if 'last' in name
    )
    for name, tensor in last_layer_items:
        print(name, tensor.shape)
    print(result['last_center'][0, 0])
    print(result['last_pred_size'][0, 0])
