
import numpy as np
import torch
import torch.nn as nn
import sys
import os
from transformers import (
    RobertaModel, RobertaTokenizerFast, CLIPTextModel, CLIPTokenizerFast
)

# Directory containing this file, and its parent directory.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
ROOT_DIR = os.path.dirname(BASE_DIR)
# Make this file's directory importable by appending it to the module
# search path.
sys.path.append(BASE_DIR)

from .backbone_module import Pointnet2Backbone
from .modules import (
    PointsObjClsModule, FPSModule, GeneralSamplingModule,
    PredictHead, ClsAgnosticPredictHead, PositionEmbeddingLearned
)
from .encoder_decoder_layers import (
    BiEncoder, BiEncoderLayer, BiDecoderLayer
)


class GroupFreeLangDetector(nn.Module):
    """
    A Group-Free model for 3D language grounding.

    Args:
        num_class: int (default: 485)
            Number of semantics classes to predict over -- size of softmax classifier
        num_heading_bin: int (default: 1)
            Forwarded to the prediction heads.
        num_size_cluster: int (default: 485)
            Must equal ``mean_size_arr.shape[0]``.
        mean_size_arr: array (default: ones of shape (485, 3))
            Forwarded to ``PredictHead`` (unused when ``size_cls_agnostic``).
        input_feature_dim: (default: 0)
            Input dim in the feature descriptor for each point.  If the point cloud is Nx9, this
            value should be 6 as in an Nx9 point cloud, 3 of the channels are xyz, and 6 are feature descriptors
        width: (default: 2)
            PointNet backbone width ratio
        bn_momentum: float (default: 0.1)
            Momentum assigned to every BatchNorm1d/BatchNorm2d in the model.
        sync_bn: bool (default: False)
            If True, batch-norm layers are converted to SyncBatchNorm.
        num_proposal: int (default: 512)
            Number of proposals/detections generated from the network. Each proposal is a 3D OBB with a semantic class.
        sampling: (default: kps)
            Initial object candidate sampling method, 'fps' or 'kps'.
        dropout, activation, nhead, dim_feedforward:
            Forwarded to the cross-encoder and decoder layers.
        num_decoder_layers: int (default: 12)
            Number of decoder layers; one prediction head is built per layer.
        self_position_embedding: (default: loc_learned)
            One of 'none', 'xyz_learned', 'loc_learned'; selects what is
            passed to the decoder layers as query position.
        size_cls_agnostic: bool (default: False)
            Use ``ClsAgnosticPredictHead`` instead of ``PredictHead``.
        text_encoder_type: (default: roberta)
            'roberta' (roberta-base) or 'clip' (openai/clip-vit-base-patch32).
        d_model: int (default: 288)
            Feature width shared by the text projector, encoder and decoder.
        cross_position_embedding: (default: None)
            Accepted for interface compatibility; not used in this class.
    """

    def __init__(self, num_class=485, num_heading_bin=1, num_size_cluster=485,
                 mean_size_arr=np.ones((485, 3)), input_feature_dim=0, width=2,
                 bn_momentum=0.1, sync_bn=False, num_proposal=512,
                 sampling='kps', dropout=0.1, activation="relu",
                 nhead=8, num_decoder_layers=12, dim_feedforward=2048,
                 self_position_embedding='loc_learned',
                 size_cls_agnostic=False, text_encoder_type='roberta',
                 d_model=288, cross_position_embedding=None):
        """Build the backbones, cross-encoder, decoder and heads.

        Raises:
            AssertionError: if ``mean_size_arr`` does not have
                ``num_size_cluster`` rows.
            NotImplementedError: for an unknown ``text_encoder_type`` or
                ``sampling`` value.
        """
        super().__init__()

        self.num_class = num_class
        self.num_heading_bin = num_heading_bin
        self.num_size_cluster = num_size_cluster
        self.mean_size_arr = mean_size_arr
        assert mean_size_arr.shape[0] == self.num_size_cluster, (
            "mean_size_arr must have num_size_cluster rows"
        )
        self.input_feature_dim = input_feature_dim
        self.num_proposal = num_proposal
        self.bn_momentum = bn_momentum
        self.sync_bn = sync_bn
        self.width = width
        self.nhead = nhead
        self.sampling = sampling
        self.num_decoder_layers = num_decoder_layers
        self.dim_feedforward = dim_feedforward
        self.self_position_embedding = self_position_embedding
        self.size_cls_agnostic = size_cls_agnostic

        # Backbone point feature learning
        self.backbone_net = Pointnet2Backbone(
            input_feature_dim=self.input_feature_dim,
            width=self.width
        )
        # map_location='cpu' lets a checkpoint saved from a GPU be read on
        # any host; load_state_dict copies the values into the backbone's
        # own parameters, so the resulting weights are the same.
        # strict=False: keys that do not match the backbone are ignored.
        self.backbone_net.load_state_dict(torch.load(
            "./dataset/language_grounding/checkpoints/group_free/ckpt_epoch_300.pth",
            map_location='cpu'
        ), strict=False)

        # Text Encoder
        if text_encoder_type == 'roberta':
            t_type = "roberta-base"
            self.tokenizer = RobertaTokenizerFast.from_pretrained(t_type)
            self.text_encoder = RobertaModel.from_pretrained(t_type)
        elif text_encoder_type == 'clip':
            t_type = "openai/clip-vit-base-patch32"
            self.tokenizer = CLIPTokenizerFast.from_pretrained(t_type)
            self.text_encoder = CLIPTextModel.from_pretrained(t_type)
        else:
            raise NotImplementedError(
                f"Unsupported text_encoder_type: {text_encoder_type!r}"
            )
        # Project text features from the encoder's hidden size to d_model.
        self.text_projector = nn.Sequential(
            nn.Linear(self.text_encoder.config.hidden_size, d_model),
            nn.LayerNorm(d_model, eps=1e-12),
            nn.Dropout(0.1)
        )

        # Cross-encoder
        self.pos_embed = PositionEmbeddingLearned(3, d_model)
        bi_layer = BiEncoderLayer(
            d_model, dropout, activation, nhead, dim_feedforward,
            self_attend=False
        )
        self.cross_encoder = BiEncoder(bi_layer, 3)

        # Query sampling method
        if self.sampling == 'fps':
            self.fps_module = FPSModule(num_proposal)
        elif self.sampling == 'kps':
            self.points_obj_cls = PointsObjClsModule(d_model)
            self.gsample_module = GeneralSamplingModule()
        else:
            raise NotImplementedError(
                f"Unsupported sampling method: {self.sampling!r}"
            )

        def make_head():
            """Build one prediction head of the configured flavour."""
            if self.size_cls_agnostic:
                return ClsAgnosticPredictHead(
                    num_class, num_heading_bin, num_proposal, d_model
                )
            return PredictHead(
                num_class, num_heading_bin, num_size_cluster,
                mean_size_arr, num_proposal, d_model
            )

        # Proposal (layer for size and center)
        self.proposal_head = make_head()

        # Transformer Decoder Projection
        # NOTE: decoder_key_proj is not used by forward() in this class; it
        # is kept so the set of registered parameters stays unchanged.
        self.decoder_key_proj = nn.Conv1d(d_model, d_model, kernel_size=1)
        self.decoder_query_proj = nn.Conv1d(d_model, d_model, kernel_size=1)

        # Transformer decoder layers
        self.decoder = nn.ModuleList([
            BiDecoderLayer(
                d_model, nhead, dim_feedforward, dropout, activation,
                self_position_embedding
            )
            for _ in range(self.num_decoder_layers)
        ])

        # Prediction Head (one per decoder layer)
        self.prediction_heads = nn.ModuleList([
            make_head() for _ in range(self.num_decoder_layers)
        ])

        # Init
        # self.init_weights()
        self.init_bn_momentum()
        if self.sync_bn:
            # Child modules are swapped in place on this (non-BN) module.
            nn.SyncBatchNorm.convert_sync_batchnorm(self)

    def _run_backbones(self, inputs):
        """Run visual and text backbones.

        Args:
            inputs: dict with 'point_clouds' (tensor fed to the point
                backbone) and 'text' (batch of strings for the tokenizer).

        Returns:
            end_points: dict produced by the point backbone, extended with
                'seed_inds', 'seed_xyz', 'seed_features' (aliases of the
                'fp2_*' entries), 'text_feats' and 'text_attention_mask'.
        """
        # Visual encoder
        end_points = self.backbone_net(inputs['point_clouds'], end_points={})
        end_points['seed_inds'] = end_points['fp2_inds']
        end_points['seed_xyz'] = end_points['fp2_xyz']
        end_points['seed_features'] = end_points['fp2_features']
        # Text encoder (tokens are moved to the point cloud's device)
        tokenized = self.tokenizer.batch_encode_plus(
            inputs['text'], padding="longest", return_tensors="pt"
        ).to(inputs['point_clouds'].device)
        encoded_text = self.text_encoder(**tokenized)
        text_feats = self.text_projector(encoded_text.last_hidden_state)
        # Invert attention mask that we get from huggingface
        # because it's the opposite in pytorch transformer:
        # after inversion, True marks padded tokens.
        text_attention_mask = tokenized.attention_mask.ne(1).bool()
        end_points['text_feats'] = text_feats
        end_points['text_attention_mask'] = text_attention_mask
        return end_points

    def _generate_queries(self, xyz, features, end_points):
        """Sample ``num_proposal`` query points from the seed points.

        With 'fps' the sampling is delegated to ``fps_module``.  With 'kps'
        the points with the highest sigmoid objectness score are selected
        and gathered through ``gsample_module``; the raw logits are stored
        under 'seeds_obj_cls_logits'.

        Returns:
            end_points, extended with 'query_points_xyz',
            'query_points_feature' and 'query_points_sample_inds'.

        Raises:
            NotImplementedError: if ``self.sampling`` is neither option.
        """
        if self.sampling == 'fps':
            xyz, features, sample_inds = self.fps_module(xyz, features)
        elif self.sampling == 'kps':
            points_obj_cls_logits = self.points_obj_cls(features)
            end_points['seeds_obj_cls_logits'] = points_obj_cls_logits
            # Indices of the top-scoring points, cast to int32.
            objectness = torch.sigmoid(points_obj_cls_logits).squeeze(1)
            sample_inds = torch.topk(objectness, self.num_proposal)[1].int()
            xyz, features, sample_inds = self.gsample_module(
                xyz, features, sample_inds
            )
        else:
            raise NotImplementedError(
                f"Unsupported sampling method: {self.sampling!r}"
            )
        end_points['query_points_xyz'] = xyz  # (B, V, 3)
        end_points['query_points_feature'] = features  # (B, F, V)
        end_points['query_points_sample_inds'] = sample_inds  # (B, V)
        return end_points

    def forward(self, inputs):
        """
        Forward pass.

        Args:
            inputs: dict
                {point_clouds, text}

                point_clouds: Variable(torch.cuda.FloatTensor)
                    (B, N, 3 + input_channels) tensor
                    Point cloud to run predicts on
                    Each point in the point-cloud MUST
                    be formatted as (x, y, z, features...)

                text: batch of strings, one utterance per point cloud

        Returns:
            end_points: dict
                Intermediate tensors plus whatever the proposal and
                prediction heads store under the prefixes 'proposal_',
                '{i}head_' and, for the final decoder layer, 'last_'.
        """
        # Run backbones
        end_points = self._run_backbones(inputs)
        points_xyz = end_points['fp2_xyz']  # (B, points, 3)
        points_features = end_points['fp2_features']  # (B, F, points)
        text_feats = end_points['text_feats']  # (B, L, F)
        text_padding_mask = end_points['text_attention_mask']  # (B, L)

        # All-False mask (no point is marked as padding).  Built once,
        # directly on the target device, and reused by the cross-encoder
        # and every decoder layer.
        points_padding_mask = torch.zeros(
            points_xyz.size(0), points_xyz.size(1),
            dtype=torch.bool, device=points_xyz.device
        )

        # Cross-encoder
        points_features, text_feats = self.cross_encoder(
            vis_feats=points_features.transpose(1, 2).contiguous(),
            pos_feats=self.pos_embed(points_xyz).transpose(1, 2).contiguous(),
            padding_mask=points_padding_mask,
            text_feats=text_feats,
            text_padding_mask=text_padding_mask
        )
        points_features = points_features.transpose(1, 2)
        points_features = points_features.contiguous()  # (B, F, points)
        end_points['seed_features'] = points_features

        # Query Points Generation
        end_points = self._generate_queries(
            points_xyz, points_features, end_points
        )
        cluster_feature = end_points['query_points_feature']  # (B, F, V)
        cluster_xyz = end_points['query_points_xyz']  # (B, V, 3)

        # Proposals (one for each query)
        proposal_center, proposal_size = self.proposal_head(
            cluster_feature,
            base_xyz=cluster_xyz,
            end_points=end_points,
            prefix='proposal_'
        )
        # Detached copies: box estimates only act as position hints for the
        # next stage and do not carry gradients.
        base_xyz = proposal_center.detach().clone()  # (B, V, 3)
        base_size = proposal_size.detach().clone()  # (B, V, 3)

        # Transformer Decoder and Prediction
        query = self.decoder_query_proj(cluster_feature)
        query = query.transpose(1, 2).contiguous()  # (B, V, F)

        # The point features attended to by the decoder are the same for
        # every layer, so the (B, points, F) copy is made once.
        # NOTE(review): assumes decoder layers do not modify it in place.
        points_memory = points_features.transpose(1, 2).contiguous()
        last_layer = self.num_decoder_layers - 1

        for i in range(self.num_decoder_layers):
            prefix = 'last_' if i == last_layer else f'{i}head_'

            # Position Embedding for Self-Attention
            if self.self_position_embedding == 'none':
                query_pos = None
            elif self.self_position_embedding == 'xyz_learned':
                query_pos = base_xyz
            elif self.self_position_embedding == 'loc_learned':
                query_pos = torch.cat([base_xyz, base_size], -1)
            else:
                raise NotImplementedError(
                    "Unsupported self_position_embedding: "
                    f"{self.self_position_embedding!r}"
                )

            # Transformer Decoder Layer
            query = self.decoder[i](
                query, points_memory,
                text_feats, query_pos,
                points_padding_mask,
                text_padding_mask
            )  # (B, V, F)

            # Prediction (always relative to the sampled query locations)
            base_xyz, base_size = self.prediction_heads[i](
                query.transpose(1, 2).contiguous(),  # (B, F, V)
                base_xyz=cluster_xyz,
                end_points=end_points,
                prefix=prefix
            )
            base_xyz = base_xyz.detach().clone()
            base_size = base_size.detach().clone()

        return end_points

    def init_weights(self):
        """Initialize transformer with xavier.

        Only decoder parameters with more than one dimension are touched;
        1-D parameters (e.g. biases) are left as they are.
        """
        for m in self.decoder.parameters():
            if m.dim() > 1:
                nn.init.xavier_uniform_(m)

    def init_bn_momentum(self):
        """Set ``bn_momentum`` on every BatchNorm1d/BatchNorm2d submodule."""
        for m in self.modules():
            if isinstance(m, (nn.BatchNorm2d, nn.BatchNorm1d)):
                m.momentum = self.bn_momentum


if __name__ == "__main__":
    # Smoke test: build the detector with default settings on the GPU, run
    # one forward pass on random points and two sentences, then print the
    # elapsed time and the outputs of the final decoder layer.
    from time import time

    net = GroupFreeLangDetector().cuda()
    # CUDA work is queued asynchronously: drain pending work before starting
    # the timer, and wait for the forward pass to finish before stopping it,
    # so the printed duration covers the actual computation.
    torch.cuda.synchronize()
    t = time()
    out = net({
        'point_clouds': torch.rand(2, 3500, 3).cuda(),
        'text': ['this is not a boy', 'but this is']
    })
    torch.cuda.synchronize()
    print(time() - t)
    print('')
    for key, value in out.items():
        if 'last' in key:
            print(key, value.shape)
    print(out['last_center'][0, 0])
    print(out['last_pred_size'][0, 0])
